improved
Auth Dependencies for FastAPI Demo App
20 days ago by Jim Verducci
The FastAPI Demo App now shows how to use Dependencies instead of Middleware for securing APIs. The core logic remains the same, but this approach better aligns with FastAPI best practices and provides more granular control over which routes require authentication.
Before:
Middleware would perform path matching to determine which routes required authentication.
# src/middleware/session_auth_middleware.py
from datetime import datetime, timedelta
from fastapi import Request, Response, status
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Awaitable, Callable, Optional
from wristband.fastapi_auth import TokenData
from auth.wristband import wristband_auth
from models.session_data import SessionData
from utils.csrf import update_csrf_cookie
class SessionCookieAuthMiddleware(BaseHTTPMiddleware):
    """Middleware that enforces session authentication and a CSRF header check.

    Requests whose path starts with ``/api/auth/`` or ``/static/``, or equals
    ``/api/hello``, are passed through untouched. Every other request must
    carry an authenticated session and a matching ``X-CSRF-TOKEN`` header.
    """

    async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
        """Authenticate the request, then forward it to the next handler.

        Returns:
            The downstream response for public or authorized requests;
            an empty 401 response when the session is not authenticated or
            anything inside the protected section raises; an empty 403
            response when the CSRF token is missing or does not match.
        """
        # Public routes bypass authentication entirely.
        url_path: str = request.url.path
        if url_path == "/api/hello" or url_path.startswith(("/api/auth/", "/static/")):
            return await call_next(request)

        session: SessionData = request.state.session.get()
        if not session.is_authenticated:
            return Response(status_code=status.HTTP_401_UNAUTHORIZED)

        # Both tokens must be present and equal for the request to proceed.
        csrf_header = request.headers.get("X-CSRF-TOKEN")
        csrf_matches = bool(session.csrf_token) and bool(csrf_header) and session.csrf_token == csrf_header
        if not csrf_matches:
            return Response(status_code=status.HTTP_403_FORBIDDEN)

        try:
            refreshed: Optional[TokenData] = await wristband_auth.refresh_token_if_expired(
                session.refresh_token, session.expires_at
            )
            if refreshed:
                # Carry the new token values into the session before it is saved.
                session.access_token = refreshed.access_token
                session.refresh_token = refreshed.refresh_token
                session.expires_at = refreshed.expires_at

            downstream: Response = await call_next(request)

            # Write the session and CSRF cookie onto the outgoing response.
            request.state.session.update(downstream, session)
            update_csrf_cookie(downstream, session.csrf_token)
            return downstream
        except Exception:
            # Any failure in the section above (including the downstream call)
            # is reported to the client as 401.
            return Response(status_code=status.HTTP_401_UNAUTHORIZED)
After:
Middleware logic now lives in a Dependency that gets injected at either the router or the endpoint level.
# src/auth/session_auth_dependencies.py
from fastapi import HTTPException, Request, Response, status
from typing import Optional
from wristband.fastapi_auth import TokenData
from auth.wristband import wristband_auth
from models.schemas import SessionData
from utils.csrf import update_csrf_cookie
async def require_session_auth(request: Request, response: Response) -> None:
    """FastAPI dependency that enforces session authentication and a CSRF check.

    Args:
        request: Incoming request; the session is read from ``request.state.session``
            and the CSRF token from the ``X-CSRF-TOKEN`` header.
        response: Outgoing response that the session and CSRF cookie are written to.

    Raises:
        HTTPException: 401 when the session is not authenticated or when the
            token refresh / session update step raises; 403 when the CSRF
            token is missing or does not match the session's token.
    """
    session: SessionData = request.state.session.get()
    if not session.is_authenticated:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)

    # Both tokens must be present and equal for the request to proceed.
    csrf_header = request.headers.get("X-CSRF-TOKEN")
    csrf_matches = bool(session.csrf_token) and bool(csrf_header) and session.csrf_token == csrf_header
    if not csrf_matches:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)

    try:
        refreshed: Optional[TokenData] = await wristband_auth.refresh_token_if_expired(
            session.refresh_token, session.expires_at
        )
        if refreshed:
            # Carry the new token values into the session before it is saved.
            session.access_token = refreshed.access_token
            session.refresh_token = refreshed.refresh_token
            session.expires_at = refreshed.expires_at

        # Write the session and CSRF cookie onto the outgoing response.
        request.state.session.update(response, session)
        update_csrf_cookie(response, session.csrf_token)
    except Exception:
        # Any failure while refreshing or saving is reported as 401.
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
# src/routes/session_routes.py
from fastapi import APIRouter, Depends, HTTPException, Request, status
from auth.session_auth_dependencies import require_session_auth
from models.schemas import SessionData, SessionResponse
router = APIRouter()


@router.get("", dependencies=[Depends(require_session_auth)], response_model=SessionResponse)
async def get_session(request: Request) -> SessionResponse:
    """Return the current session's tenant ID, user ID and session data.

    The route is guarded by the ``require_session_auth`` dependency declared
    in the decorator.

    Raises:
        HTTPException: 500 when reading the session or building the
            response raises.
    """
    try:
        current: SessionData = request.state.session.get()
        payload = SessionResponse(
            tenant_id=current.tenant_id,
            user_id=current.user_id,
            metadata=current,
        )
    except Exception:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
    return payload