improved

Auth Dependencies for FastAPI Demo App

The FastAPI Demo App now secures its APIs with FastAPI dependencies instead of middleware. The core authentication logic is unchanged, but this approach better aligns with FastAPI best practices and gives more granular control over which routes require authentication.

Before:

Previously, a middleware ran on every request and used path matching to decide which routes required authentication.

# src/middleware/session_auth_middleware.py
from datetime import datetime, timedelta
from fastapi import Request, Response, status
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Awaitable, Callable, Optional
from wristband.fastapi_auth import TokenData

from auth.wristband import wristband_auth
from models.session_data import SessionData
from utils.csrf import update_csrf_cookie


class SessionCookieAuthMiddleware(BaseHTTPMiddleware):
    """Enforce session-cookie authentication and CSRF checks on protected paths.

    Requests whose path starts with ``/api/auth/`` or ``/static/``, or is
    exactly ``/api/hello``, are passed through without any checks. Every other
    request must carry an authenticated session and an ``X-CSRF-TOKEN`` header
    matching the session's CSRF token.
    """

    async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
        """Authenticate the request, refresh tokens if needed, then call downstream.

        Returns:
            A bare 401 response if the session is not authenticated or the
            token refresh fails; a bare 403 response if the CSRF header is
            missing or does not match; otherwise the downstream response, with
            the session persisted via ``request.state.session.update`` and the
            CSRF cookie refreshed via ``update_csrf_cookie``.
        """
        path: str = request.url.path

        # Public routes skip authentication entirely.
        if path.startswith(("/api/auth/", "/static/")) or path == "/api/hello":
            return await call_next(request)

        session_data: SessionData = request.state.session.get()
        if not session_data.is_authenticated:
            return Response(status_code=status.HTTP_401_UNAUTHORIZED)

        header_csrf_token = request.headers.get("X-CSRF-TOKEN")
        if not session_data.csrf_token or not header_csrf_token or session_data.csrf_token != header_csrf_token:
            return Response(status_code=status.HTTP_403_FORBIDDEN)

        # Only a failed token refresh is an authentication failure. Keep the
        # try narrow so exceptions raised by downstream route handlers still
        # propagate as server errors instead of being disguised as 401s.
        try:
            new_token_data: Optional[TokenData] = await wristband_auth.refresh_token_if_expired(
                session_data.refresh_token, session_data.expires_at
            )
        except Exception:
            return Response(status_code=status.HTTP_401_UNAUTHORIZED)

        if new_token_data:
            session_data.access_token = new_token_data.access_token
            session_data.refresh_token = new_token_data.refresh_token
            session_data.expires_at = new_token_data.expires_at

        response: Response = await call_next(request)
        request.state.session.update(response, session_data)
        update_csrf_cookie(response, session_data.csrf_token)
        return response

After:

The middleware logic now lives in a dependency that can be injected at either the router or the endpoint level.

# src/auth/session_auth_dependencies.py
from fastapi import HTTPException, Request, Response, status
from typing import Optional
from wristband.fastapi_auth import TokenData

from auth.wristband import wristband_auth
from models.schemas import SessionData
from utils.csrf import update_csrf_cookie


async def require_session_auth(request: Request, response: Response) -> None:
    """FastAPI dependency that enforces session authentication and CSRF checks.

    Attach it with ``Depends(require_session_auth)`` on a router or endpoint.
    It reads the session from ``request.state.session.get()``, which must be
    populated upstream (presumably by a session middleware; confirm). It then
    refreshes the access token when ``wristband_auth.refresh_token_if_expired``
    returns new token data. Finally, it persists the session and CSRF cookie
    onto ``response``.

    Args:
        request: The incoming request.
        response: FastAPI's injected "temporal" response. FastAPI merges its
            headers/cookies into the route's final response, except when the
            route returns a ``Response`` object directly.

    Raises:
        HTTPException: 401 if the session is not authenticated or the token
            refresh fails; 403 if the ``X-CSRF-TOKEN`` header is missing or
            does not match the session's CSRF token.
    """
    session_data: SessionData = request.state.session.get()
    if not session_data.is_authenticated:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)

    header_csrf_token = request.headers.get("X-CSRF-TOKEN")
    if not session_data.csrf_token or not header_csrf_token or session_data.csrf_token != header_csrf_token:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)

    # Only a failed token refresh is an authentication failure. Keep the try
    # narrow so unrelated errors (e.g. while persisting the session) are not
    # disguised as 401s, and chain the cause for debugging.
    try:
        new_token_data: Optional[TokenData] = await wristband_auth.refresh_token_if_expired(
            session_data.refresh_token, session_data.expires_at
        )
    except Exception as exc:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from exc

    if new_token_data:
        session_data.access_token = new_token_data.access_token
        session_data.refresh_token = new_token_data.refresh_token
        session_data.expires_at = new_token_data.expires_at

    request.state.session.update(response, session_data)
    update_csrf_cookie(response, session_data.csrf_token)
# src/routes/session_routes.py
from fastapi import APIRouter, Depends, HTTPException, Request, status

from auth.session_auth_dependencies import require_session_auth
from models.schemas import SessionData, SessionResponse

# Router for the session endpoints. Routes are registered with an empty path
# (""), so the URL prefix is presumably supplied where this router is included; confirm.
router = APIRouter()


@router.get("", dependencies=[Depends(require_session_auth)], response_model=SessionResponse)
async def get_session(request: Request) -> SessionResponse:
    """Return the current session's tenant ID, user ID, and session metadata.

    ``require_session_auth`` runs first and rejects unauthenticated or
    CSRF-mismatched requests before this handler executes.

    Raises:
        HTTPException: 500 if reading the session or building the response fails.
    """
    try:
        session_data: SessionData = request.state.session.get()
        return SessionResponse(
            tenant_id=session_data.tenant_id,
            user_id=session_data.user_id,
            metadata=session_data,
        )
    except Exception as exc:
        # FastAPI's HTTPException handler does not log, so record the real
        # error before replacing it with a generic 500.
        import logging

        logging.getLogger(__name__).exception("Failed to build session response")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) from exc