"""CORS middleware configuration."""
import base64
import json
import logging
from urllib.parse import parse_qs

from flask import request, make_response

from gatehouse_app.models import OIDCClient

logger = logging.getLogger(__name__)

ALLOWED_METHODS = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
ALLOWED_HEADERS = (
    "Content-Type, Authorization, X-Requested-With, X-Request-ID, "
    "Cache-Control, Pragma, X-WebAuthn-Session-Token"
)

# OIDC endpoints whose POST requests identify the client via Basic auth or body.
_OIDC_CLIENT_POST_ENDPOINTS = ("/oidc/token", "/oidc/revoke", "/oidc/introspect")


def _normalize_origins(cors_origins):
    """Return a CORS origins setting as a list of origin strings.

    Accepts a list/tuple/set of origins, a single origin string, or a
    comma-separated string of origins (``"*"`` becomes ``["*"]``). Anything
    else (e.g. ``None``) yields an empty list, meaning no origin is allowed.

    Centralizing this keeps ``_is_origin_allowed`` and ``_cors_origin_header``
    in agreement and guarantees exact matching: a plain string config is
    never tested with substring ``in``.
    """
    if isinstance(cors_origins, str):
        # Origins (scheme://host[:port]) never contain commas, so splitting
        # on "," is safe.
        return [part.strip() for part in cors_origins.split(",") if part.strip()]
    if isinstance(cors_origins, (list, tuple, set, frozenset)):
        return list(cors_origins)
    return []


def _is_origin_allowed(origin, cors_origins):
    """Return True if the origin is permitted by the CORS config.

    Handles wildcard (``"*"``, alone or inside a collection) and explicit
    origin lists; explicit entries must match the origin exactly.
    A missing/empty origin is never allowed.
    """
    if not origin:
        return False
    allowed = _normalize_origins(cors_origins)
    return "*" in allowed or origin in allowed


def _cors_origin_header(cors_origins, request_origin):
    """Return the value for Access-Control-Allow-Origin, or None to omit it.

    Per the CORS spec, browsers reject ``*`` when credentials are involved,
    so under a wildcard config the request origin is echoed back; ``*`` is
    only returned when the request carried no Origin. For explicit lists the
    origin is echoed only on an exact match.

    NOTE(review): a wildcard config combined with
    CORS_SUPPORTS_CREDENTIALS lets any site make credentialed requests;
    prefer an explicit origin list outside development.
    """
    allowed = _normalize_origins(cors_origins)
    if "*" in allowed:
        return request_origin or "*"
    if request_origin and request_origin in allowed:
        return request_origin
    return None


def _client_id_from_basic_auth():
    """Return the username part of a Basic Authorization header, if any."""
    auth_header = request.headers.get("Authorization", "")
    if not auth_header.startswith("Basic "):
        return None
    try:
        decoded = base64.b64decode(auth_header[6:]).decode("utf-8")
    except ValueError:  # binascii.Error and UnicodeDecodeError subclass it
        return None
    client_id, _, _ = decoded.partition(":")
    return client_id or None


def _client_id_from_form():
    """Return ``client_id`` from a form-encoded request body, if present."""
    if request.form:
        return request.form.get("client_id") or None
    return None


def _client_id_from_json():
    """Return ``client_id`` from a JSON object request body, if present."""
    if not request.is_json:
        return None
    # silent=True returns None on malformed JSON instead of raising.
    payload = request.get_json(silent=True)
    if isinstance(payload, dict):
        return payload.get("client_id") or None
    return None


def _client_id_from_bearer_token():
    """Return the ``client_id`` claim from a Bearer JWT's payload.

    The token is decoded WITHOUT signature verification; the result is only
    suitable for choosing CORS origins, never for authentication.
    """
    auth_header = request.headers.get("Authorization", "")
    if not auth_header.startswith("Bearer "):
        return None
    try:
        payload_b64 = auth_header[7:].split(".")[1]
        # Restore the base64 padding that JWT encoding strips.
        payload_b64 += "=" * (-len(payload_b64) % 4)
        payload = json.loads(base64.urlsafe_b64decode(payload_b64))
    except (IndexError, ValueError, RecursionError):
        return None
    if isinstance(payload, dict):
        return payload.get("client_id")
    return None


def _get_oidc_client_id_from_request():
    """Extract client_id from OIDC endpoint requests (best effort).

    - POST to ``/oidc/token``, ``/oidc/revoke`` or ``/oidc/introspect``: the
      username of a Basic ``Authorization`` header, else ``client_id`` from
      the form body, else ``client_id`` from a JSON body.
    - Any request to ``/oidc/userinfo``: the ``client_id`` claim of the
      Bearer token payload (unverified).

    Returns:
        The client_id, or None when it cannot be determined.
    """
    path = request.path

    if request.method == "POST" and path.endswith(_OIDC_CLIENT_POST_ENDPOINTS):
        return (
            _client_id_from_basic_auth()
            or _client_id_from_form()
            or _client_id_from_json()
            or None
        )

    if path.endswith("/oidc/userinfo"):
        return _client_id_from_bearer_token()

    return None


def _get_effective_cors_origins(app, request):
    """Get effective CORS origins, checking per-client config for OIDC endpoints.

    Paths without ``/oidc/`` use ``app.config["CORS_ORIGINS"]`` (default
    ``[]``). For OIDC paths the requesting client is looked up and its
    ``get_effective_origins()`` result is used unless it is None. Any lookup
    failure falls back to the global setting (and is logged).
    """
    global_origins = app.config.get("CORS_ORIGINS", [])
    if "/oidc/" not in request.path:
        return global_origins

    try:
        client_id = _get_oidc_client_id_from_request()
        if not client_id:
            return global_origins
        client = OIDCClient.query.filter_by(client_id=client_id).first()
        if not client:
            return global_origins
        effective = client.get_effective_origins()
    except Exception:
        # Best effort: a failed lookup must not break the request, but it
        # should not be invisible either.
        logger.warning(
            "Per-client CORS origin lookup failed; using global origins",
            exc_info=True,
        )
        return global_origins

    return global_origins if effective is None else effective


def setup_cors(app):
    """
    Configure CORS for the application.

    Registers a ``before_request`` hook that answers OPTIONS preflights from
    allowed origins, and an ``after_request`` hook that adds CORS headers to
    every response.

    Args:
        app: Flask application instance
    """
    supports_credentials = app.config.get("CORS_SUPPORTS_CREDENTIALS", True)

    @app.before_request
    def handle_preflight():
        """Handle CORS preflight OPTIONS requests."""
        if request.method != "OPTIONS":
            return None

        origin = request.headers.get("Origin")
        # NOTE: browsers send preflights without an Authorization header or a
        # body, so for OIDC endpoints no client_id is found here and the
        # global origins decide the preflight.
        cors_origins = _get_effective_cors_origins(app, request)
        if not _is_origin_allowed(origin, cors_origins):
            # Fall through to normal routing; no Allow-Origin will be added.
            return None

        response = make_response("", 204)
        response.headers["Access-Control-Allow-Origin"] = _cors_origin_header(
            cors_origins, origin
        )
        response.headers["Access-Control-Allow-Methods"] = ALLOWED_METHODS
        response.headers["Access-Control-Allow-Headers"] = ALLOWED_HEADERS
        if supports_credentials:
            response.headers["Access-Control-Allow-Credentials"] = "true"
        response.headers["Access-Control-Max-Age"] = "3600"
        response.headers["Cache-Control"] = "no-cache, no-store"
        return response

    @app.after_request
    def after_request_cors(response):
        """Add CORS headers to every response (preflight replies included)."""
        origin = request.headers.get("Origin")
        cors_origins = _get_effective_cors_origins(app, request)

        # The CORS headers depend on the request's Origin, so caches must key
        # on it; otherwise one origin's response can be replayed to another.
        response.vary.add("Origin")

        allow_origin = _cors_origin_header(cors_origins, origin)
        if allow_origin:
            response.headers["Access-Control-Allow-Origin"] = allow_origin
            response.headers["Access-Control-Allow-Methods"] = ALLOWED_METHODS
            response.headers["Access-Control-Allow-Headers"] = ALLOWED_HEADERS
            if supports_credentials:
                response.headers["Access-Control-Allow-Credentials"] = "true"
            response.headers["Access-Control-Max-Age"] = "3600"
        return response