oidc-client mk1

This commit is contained in:
2026-04-27 02:44:32 +09:30
parent 02e95a4199
commit 63a3109a82
6 changed files with 889 additions and 2 deletions
+87 -2
View File
@@ -1,6 +1,12 @@
"""CORS middleware configuration."""
import base64
import json
from urllib.parse import parse_qs
from flask import request, make_response
from gatehouse_app.models import OIDCClient
ALLOWED_METHODS = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
ALLOWED_HEADERS = (
"Content-Type, Authorization, X-Requested-With, X-Request-ID, "
@@ -40,6 +46,85 @@ def _cors_origin_header(cors_origins, request_origin):
return None
def _get_oidc_client_id_from_request():
"""Extract client_id from OIDC endpoint requests."""
path = request.path
# POST to /oidc/token, /oidc/revoke, /oidc/introspect
if request.method == "POST" and any(
path.endswith(ep) for ep in ("/oidc/token", "/oidc/revoke", "/oidc/introspect")
):
# Try Basic Auth header first
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Basic "):
try:
decoded = base64.b64decode(auth_header[6:]).decode("utf-8")
client_id, _, _ = decoded.partition(":")
if client_id:
return client_id
except Exception:
pass
# Try form body
if request.form:
client_id = request.form.get("client_id")
if client_id:
return client_id
# Try JSON body
if request.is_json:
try:
client_id = request.json.get("client_id")
if client_id:
return client_id
except Exception:
pass
return None
# GET/POST to /oidc/userinfo
if path.endswith("/oidc/userinfo"):
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
token = auth_header[7:]
try:
payload_b64 = token.split(".")[1]
padding = 4 - len(payload_b64) % 4
if padding != 4:
payload_b64 += "=" * padding
payload = json.loads(base64.urlsafe_b64decode(payload_b64))
return payload.get("client_id")
except Exception:
return None
return None
def _get_effective_cors_origins(app, request):
"""Get effective CORS origins, checking per-client config for OIDC endpoints."""
global_origins = app.config.get("CORS_ORIGINS", [])
if "/oidc/" not in request.path:
return global_origins
try:
client_id = _get_oidc_client_id_from_request()
if not client_id:
return global_origins
client = OIDCClient.query.filter_by(client_id=client_id).first()
if not client:
return global_origins
effective = client.get_effective_origins()
if effective is not None:
return effective
except Exception:
pass
return global_origins
def setup_cors(app):
"""
Configure CORS for the application.
@@ -54,7 +139,7 @@ def setup_cors(app):
"""Handle CORS preflight OPTIONS requests."""
if request.method == "OPTIONS":
origin = request.headers.get("Origin")
cors_origins = app.config.get("CORS_ORIGINS", [])
cors_origins = _get_effective_cors_origins(app, request)
if not _is_origin_allowed(origin, cors_origins):
return None
@@ -73,7 +158,7 @@ def setup_cors(app):
def after_request_cors(response):
"""Add CORS headers to non-preflight responses."""
origin = request.headers.get("Origin")
cors_origins = app.config.get("CORS_ORIGINS", [])
cors_origins = _get_effective_cors_origins(app, request)
allow_origin = _cors_origin_header(cors_origins, origin)
if allow_origin: