oidc-client mk1
This commit is contained in:
@@ -1,6 +1,12 @@
|
||||
"""CORS middleware configuration."""
|
||||
import base64
|
||||
import json
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from flask import request, make_response
|
||||
|
||||
from gatehouse_app.models import OIDCClient
|
||||
|
||||
ALLOWED_METHODS = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
ALLOWED_HEADERS = (
|
||||
"Content-Type, Authorization, X-Requested-With, X-Request-ID, "
|
||||
@@ -40,6 +46,85 @@ def _cors_origin_header(cors_origins, request_origin):
|
||||
return None
|
||||
|
||||
|
||||
def _get_oidc_client_id_from_request():
|
||||
"""Extract client_id from OIDC endpoint requests."""
|
||||
path = request.path
|
||||
|
||||
# POST to /oidc/token, /oidc/revoke, /oidc/introspect
|
||||
if request.method == "POST" and any(
|
||||
path.endswith(ep) for ep in ("/oidc/token", "/oidc/revoke", "/oidc/introspect")
|
||||
):
|
||||
# Try Basic Auth header first
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Basic "):
|
||||
try:
|
||||
decoded = base64.b64decode(auth_header[6:]).decode("utf-8")
|
||||
client_id, _, _ = decoded.partition(":")
|
||||
if client_id:
|
||||
return client_id
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Try form body
|
||||
if request.form:
|
||||
client_id = request.form.get("client_id")
|
||||
if client_id:
|
||||
return client_id
|
||||
|
||||
# Try JSON body
|
||||
if request.is_json:
|
||||
try:
|
||||
client_id = request.json.get("client_id")
|
||||
if client_id:
|
||||
return client_id
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
# GET/POST to /oidc/userinfo
|
||||
if path.endswith("/oidc/userinfo"):
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token = auth_header[7:]
|
||||
try:
|
||||
payload_b64 = token.split(".")[1]
|
||||
padding = 4 - len(payload_b64) % 4
|
||||
if padding != 4:
|
||||
payload_b64 += "=" * padding
|
||||
payload = json.loads(base64.urlsafe_b64decode(payload_b64))
|
||||
return payload.get("client_id")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_effective_cors_origins(app, request):
|
||||
"""Get effective CORS origins, checking per-client config for OIDC endpoints."""
|
||||
global_origins = app.config.get("CORS_ORIGINS", [])
|
||||
|
||||
if "/oidc/" not in request.path:
|
||||
return global_origins
|
||||
|
||||
try:
|
||||
client_id = _get_oidc_client_id_from_request()
|
||||
if not client_id:
|
||||
return global_origins
|
||||
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
return global_origins
|
||||
|
||||
effective = client.get_effective_origins()
|
||||
if effective is not None:
|
||||
return effective
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return global_origins
|
||||
|
||||
|
||||
def setup_cors(app):
|
||||
"""
|
||||
Configure CORS for the application.
|
||||
@@ -54,7 +139,7 @@ def setup_cors(app):
|
||||
"""Handle CORS preflight OPTIONS requests."""
|
||||
if request.method == "OPTIONS":
|
||||
origin = request.headers.get("Origin")
|
||||
cors_origins = app.config.get("CORS_ORIGINS", [])
|
||||
cors_origins = _get_effective_cors_origins(app, request)
|
||||
|
||||
if not _is_origin_allowed(origin, cors_origins):
|
||||
return None
|
||||
@@ -73,7 +158,7 @@ def setup_cors(app):
|
||||
def after_request_cors(response):
|
||||
"""Add CORS headers to non-preflight responses."""
|
||||
origin = request.headers.get("Origin")
|
||||
cors_origins = app.config.get("CORS_ORIGINS", [])
|
||||
cors_origins = _get_effective_cors_origins(app, request)
|
||||
|
||||
allow_origin = _cors_origin_header(cors_origins, origin)
|
||||
if allow_origin:
|
||||
|
||||
Reference in New Issue
Block a user