2026-01-08 01:00:26 +10:30
|
|
|
"""CORS middleware configuration."""
|
2026-04-27 02:44:32 +09:30
|
|
|
import base64
|
|
|
|
|
import json
|
|
|
|
|
from urllib.parse import parse_qs
|
|
|
|
|
|
2026-01-08 15:59:53 +10:30
|
|
|
from flask import request, make_response
|
2026-01-08 01:00:26 +10:30
|
|
|
|
2026-04-27 02:44:32 +09:30
|
|
|
from gatehouse_app.models import OIDCClient
|
|
|
|
|
|
2026-04-26 01:12:39 +09:30
|
|
|
# Value of the Access-Control-Allow-Methods header sent by the CORS hooks.
ALLOWED_METHODS = ", ".join(("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"))

# Value of the Access-Control-Allow-Headers header sent by the CORS hooks:
# the request headers a cross-origin caller is permitted to send.
ALLOWED_HEADERS = ", ".join(
    (
        "Content-Type",
        "Authorization",
        "X-Requested-With",
        "X-Request-ID",
        "Cache-Control",
        "Pragma",
        "X-WebAuthn-Session-Token",
    )
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Collection types accepted as an explicit list of allowed origins.
_ORIGIN_COLLECTION_TYPES = (list, tuple, set, frozenset)


def _parse_cors_origins(cors_origins):
    """Normalise a CORS origin setting into ``(allow_all, explicit_origins)``.

    Accepted shapes:

    * ``"*"`` -- every origin is allowed.
    * any other ``str`` -- a single origin, compared exactly.
    * a list/tuple/set/frozenset of origin strings -- exact membership; a
      ``"*"`` entry allows every origin.
    * anything else (e.g. ``None``) -- no origin is allowed.

    Both ``_is_origin_allowed`` and ``_cors_origin_header`` go through this
    helper so preflight and actual responses always agree.
    """
    if isinstance(cors_origins, str):
        if cors_origins == "*":
            return True, ()
        # Exact comparison only: never fall back to substring matching.
        return False, (cors_origins,)
    if isinstance(cors_origins, _ORIGIN_COLLECTION_TYPES):
        return "*" in cors_origins, tuple(cors_origins)
    return False, ()


def _is_origin_allowed(origin, cors_origins):
    """Return True if *origin* is permitted by the CORS config.

    Args:
        origin: Value of the request's ``Origin`` header; may be None/empty.
        cors_origins: ``"*"``, a single origin string, or a collection of
            origin strings (a ``"*"`` entry allows every origin).

    Returns:
        False when *origin* is missing. Otherwise True when the config allows
        every origin or lists *origin* exactly.
    """
    if not origin:
        return False
    allow_all, explicit = _parse_cors_origins(cors_origins)
    return allow_all or origin in explicit


def _cors_origin_header(cors_origins, request_origin):
    """Return the value for Access-Control-Allow-Origin, or None to omit it.

    With a wildcard configuration the request origin is echoed back whenever
    one is present, because browsers reject a literal ``*`` on credentialed
    requests. ``"*"`` itself is returned only when the request carried no
    ``Origin``. Otherwise the request origin is returned only if it is listed
    exactly (never by substring), using the same rule as
    ``_is_origin_allowed``.

    NOTE(review): echoing arbitrary origins means that, when credentials are
    enabled, any site can make credentialed requests under a ``"*"`` setting.
    Only configure ``"*"`` if that is intended.
    """
    allow_all, explicit = _parse_cors_origins(cors_origins)
    if allow_all:
        return request_origin or "*"
    if request_origin and request_origin in explicit:
        return request_origin
    return None
|
|
|
|
|
|
2026-01-08 01:00:26 +10:30
|
|
|
|
2026-04-27 02:44:32 +09:30
|
|
|
# OIDC endpoints where the client authenticates in the POST request itself.
_CLIENT_AUTH_ENDPOINTS = ("/oidc/token", "/oidc/revoke", "/oidc/introspect")


def _client_id_from_basic_auth(auth_header):
    """Return the username part of an HTTP Basic Authorization header, or None.

    None is returned when the header is not Basic, is not valid base64/UTF-8,
    or carries an empty username.

    NOTE(review): RFC 6749 section 2.3.1 form-urlencodes the client_id before
    it is placed in Basic credentials; it is NOT url-decoded here. Confirm
    this matches how the token endpoint itself parses the header.
    """
    if not auth_header.startswith("Basic "):
        return None
    try:
        decoded = base64.b64decode(auth_header[6:]).decode("utf-8")
    except ValueError:
        # binascii.Error and UnicodeDecodeError are both ValueError subclasses.
        return None
    client_id, _, _ = decoded.partition(":")
    return client_id or None


def _client_id_from_bearer_jwt(auth_header):
    """Return the ``client_id`` claim of a Bearer JWT's payload, or None.

    The signature is NOT verified. The value is only used to choose which
    client's CORS origins apply and must never be trusted for authorization.
    """
    if not auth_header.startswith("Bearer "):
        return None
    segments = auth_header[7:].split(".")
    if len(segments) < 2:
        return None
    payload_b64 = segments[1]
    # Restore the base64 padding that JWT encoding strips.
    payload_b64 += "=" * (-len(payload_b64) % 4)
    try:
        payload = json.loads(base64.urlsafe_b64decode(payload_b64))
    except ValueError:
        # Covers binascii.Error, JSONDecodeError and UnicodeDecodeError.
        return None
    if not isinstance(payload, dict):
        return None
    client_id = payload.get("client_id")
    return client_id if isinstance(client_id, str) and client_id else None


def _get_oidc_client_id_from_request():
    """Extract client_id from OIDC endpoint requests (best effort).

    Sources are checked in this order:

    * POST to ``/oidc/token``, ``/oidc/revoke`` or ``/oidc/introspect``: the
      HTTP Basic username, then the ``client_id`` form field, then the
      ``client_id`` key of a JSON object body.
    * ``/oidc/userinfo``: the (unverified) ``client_id`` claim of a Bearer JWT.

    Returns:
        The client_id, or None if none could be found. Malformed credentials,
        bodies and tokens yield None. Pathological input (e.g. JSON nested past
        the recursion limit) may still raise, so callers should treat this as
        best-effort.
    """
    path = request.path
    auth_header = request.headers.get("Authorization", "")

    if request.method == "POST" and path.endswith(_CLIENT_AUTH_ENDPOINTS):
        client_id = _client_id_from_basic_auth(auth_header)
        if client_id:
            return client_id

        client_id = request.form.get("client_id")
        if client_id:
            return client_id

        if request.is_json:
            # silent=True returns None for malformed JSON instead of raising.
            body = request.get_json(silent=True)
            if isinstance(body, dict) and body.get("client_id"):
                return body["client_id"]
        return None

    if path.endswith("/oidc/userinfo"):
        return _client_id_from_bearer_jwt(auth_header)

    return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_effective_cors_origins(app, request):
    """Get effective CORS origins, checking per-client config for OIDC endpoints.

    Requests whose path does not contain ``/oidc/`` use
    ``app.config["CORS_ORIGINS"]`` (default ``[]``). For OIDC paths, if a
    client_id can be extracted and matches an ``OIDCClient`` row, that
    client's ``get_effective_origins()`` result is used unless it is None.
    Any error during the lookup is logged and the global setting is used.

    NOTE(review): client_id extraction only inspects POST bodies/Basic auth
    and Bearer tokens, which a browser preflight (OPTIONS) does not carry. So
    preflights are always checked against the global origins, and a
    per-client origin not also present in CORS_ORIGINS may fail preflight.
    Confirm this is intended.

    Args:
        app: Flask application whose config and logger are used.
        request: The current request (only ``.path`` is read here).

    Returns:
        The CORS origin setting to apply (``"*"`` or a collection of origins).
    """
    global_origins = app.config.get("CORS_ORIGINS", [])
    if "/oidc/" not in request.path:
        return global_origins

    try:
        client_id = _get_oidc_client_id_from_request()
        client = (
            OIDCClient.query.filter_by(client_id=client_id).first()
            if client_id
            else None
        )
        effective = client.get_effective_origins() if client else None
    except Exception:
        # Best-effort by design: a bad token or DB error must not break CORS
        # handling, but it should not vanish silently either.
        app.logger.warning(
            "Per-client CORS origin lookup failed for %s; using global origins",
            request.path,
            exc_info=True,
        )
        return global_origins

    return global_origins if effective is None else effective
|
|
|
|
|
|
|
|
|
|
|
2026-01-08 15:59:53 +10:30
|
|
|
def setup_cors(app):
    """
    Configure CORS for the application.

    Registers two hooks on *app*:

    * a ``before_request`` hook that answers OPTIONS preflight requests from
      allowed origins with an empty 204 response carrying CORS headers, and
    * an ``after_request`` hook that always adds ``Vary: Origin`` and adds
      CORS headers whenever ``_cors_origin_header`` yields a value.

    ``CORS_SUPPORTS_CREDENTIALS`` (default True) controls whether
    ``Access-Control-Allow-Credentials: true`` is sent.

    NOTE(review): with credentials enabled and a ``"*"`` origin setting, the
    request origin is echoed back, so any site can make credentialed
    requests. Confirm that combination is never used in production.

    Args:
        app: Flask application instance
    """
    supports_credentials = app.config.get("CORS_SUPPORTS_CREDENTIALS", True)

    @app.before_request
    def handle_preflight():
        """Handle CORS preflight OPTIONS requests from allowed origins."""
        if request.method != "OPTIONS":
            return None

        origin = request.headers.get("Origin")
        cors_origins = _get_effective_cors_origins(app, request)
        if not _is_origin_allowed(origin, cors_origins):
            # Not an allowed cross-origin preflight: continue normal routing
            # without granting any CORS headers here.
            return None

        response = make_response("", 204)
        response.headers["Access-Control-Allow-Origin"] = _cors_origin_header(
            cors_origins, origin
        )
        response.headers["Access-Control-Allow-Methods"] = ALLOWED_METHODS
        response.headers["Access-Control-Allow-Headers"] = ALLOWED_HEADERS
        if supports_credentials:
            response.headers["Access-Control-Allow-Credentials"] = "true"
        response.headers["Access-Control-Max-Age"] = "3600"
        response.headers["Cache-Control"] = "no-cache, no-store"
        return response

    @app.after_request
    def after_request_cors(response):
        """Add CORS headers to outgoing responses.

        NOTE: Flask normally runs after_request hooks on responses returned
        from before_request hooks too, so this likely also runs on the
        preflight response above; every step here is idempotent.
        """
        # Access-Control-Allow-Origin depends on the request's Origin (it is
        # echoed or omitted), so caches must key on Origin as well. Otherwise
        # a response cached for one origin can be served to another.
        response.vary.add("Origin")

        origin = request.headers.get("Origin")
        cors_origins = _get_effective_cors_origins(app, request)
        allow_origin = _cors_origin_header(cors_origins, origin)
        if allow_origin:
            response.headers["Access-Control-Allow-Origin"] = allow_origin
            response.headers["Access-Control-Allow-Methods"] = ALLOWED_METHODS
            response.headers["Access-Control-Allow-Headers"] = ALLOWED_HEADERS
            if supports_credentials:
                response.headers["Access-Control-Allow-Credentials"] = "true"
            response.headers["Access-Control-Max-Age"] = "3600"

        return response
|