oidc-client mk1

This commit is contained in:
2026-04-27 02:44:32 +09:30
parent 02e95a4199
commit 63a3109a82
6 changed files with 889 additions and 2 deletions
@@ -28,6 +28,7 @@ def list_org_clients(org_id):
"redirect_uris": c.redirect_uris,
"scopes": c.scopes,
"grant_types": c.grant_types,
"allowed_cors_origins": c.allowed_cors_origins,
"is_active": c.is_active,
"created_at": c.created_at.isoformat() + "Z",
}
+87 -2
View File
@@ -1,6 +1,12 @@
"""CORS middleware configuration."""
import base64
import json
from urllib.parse import parse_qs
from flask import request, make_response
from gatehouse_app.models import OIDCClient
ALLOWED_METHODS = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
ALLOWED_HEADERS = (
"Content-Type, Authorization, X-Requested-With, X-Request-ID, "
@@ -40,6 +46,85 @@ def _cors_origin_header(cors_origins, request_origin):
return None
def _get_oidc_client_id_from_request():
"""Extract client_id from OIDC endpoint requests."""
path = request.path
# POST to /oidc/token, /oidc/revoke, /oidc/introspect
if request.method == "POST" and any(
path.endswith(ep) for ep in ("/oidc/token", "/oidc/revoke", "/oidc/introspect")
):
# Try Basic Auth header first
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Basic "):
try:
decoded = base64.b64decode(auth_header[6:]).decode("utf-8")
client_id, _, _ = decoded.partition(":")
if client_id:
return client_id
except Exception:
pass
# Try form body
if request.form:
client_id = request.form.get("client_id")
if client_id:
return client_id
# Try JSON body
if request.is_json:
try:
client_id = request.json.get("client_id")
if client_id:
return client_id
except Exception:
pass
return None
# GET/POST to /oidc/userinfo
if path.endswith("/oidc/userinfo"):
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
token = auth_header[7:]
try:
payload_b64 = token.split(".")[1]
padding = 4 - len(payload_b64) % 4
if padding != 4:
payload_b64 += "=" * padding
payload = json.loads(base64.urlsafe_b64decode(payload_b64))
return payload.get("client_id")
except Exception:
return None
return None
def _get_effective_cors_origins(app, request):
"""Get effective CORS origins, checking per-client config for OIDC endpoints."""
global_origins = app.config.get("CORS_ORIGINS", [])
if "/oidc/" not in request.path:
return global_origins
try:
client_id = _get_oidc_client_id_from_request()
if not client_id:
return global_origins
client = OIDCClient.query.filter_by(client_id=client_id).first()
if not client:
return global_origins
effective = client.get_effective_origins()
if effective is not None:
return effective
except Exception:
pass
return global_origins
def setup_cors(app):
"""
Configure CORS for the application.
@@ -54,7 +139,7 @@ def setup_cors(app):
"""Handle CORS preflight OPTIONS requests."""
if request.method == "OPTIONS":
origin = request.headers.get("Origin")
cors_origins = app.config.get("CORS_ORIGINS", [])
cors_origins = _get_effective_cors_origins(app, request)
if not _is_origin_allowed(origin, cors_origins):
return None
@@ -73,7 +158,7 @@ def setup_cors(app):
def after_request_cors(response):
"""Add CORS headers to non-preflight responses."""
origin = request.headers.get("Origin")
cors_origins = app.config.get("CORS_ORIGINS", [])
cors_origins = _get_effective_cors_origins(app, request)
allow_origin = _cors_origin_header(cors_origins, origin)
if allow_origin:
+34
View File
@@ -1,4 +1,6 @@
"""OIDC Client model."""
from urllib.parse import urlparse
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import OIDCGrantType, OIDCResponseType
@@ -21,6 +23,7 @@ class OIDCClient(BaseModel):
grant_types = db.Column(db.JSON, nullable=False) # Allowed grant types
response_types = db.Column(db.JSON, nullable=False) # Allowed response types
scopes = db.Column(db.JSON, nullable=False) # Allowed scopes
allowed_cors_origins = db.Column(db.JSON, nullable=True, default=None) # Per-client CORS origins
# Client metadata
logo_uri = db.Column(db.String(512), nullable=True)
@@ -81,6 +84,37 @@ class OIDCClient(BaseModel):
"""Check if a redirect URI is allowed for this client."""
return redirect_uri in self.redirect_uris
def get_effective_origins(self) -> list | None:
"""Get effective CORS origins for this client.
Returns None to signal "use global config", a derived list from
redirect_uris when "+" is present, or the configured list as-is.
"""
if self.allowed_cors_origins is None:
return None
if "+" in self.allowed_cors_origins:
origins = set()
for uri in self.redirect_uris:
parsed = urlparse(uri)
if parsed.scheme and parsed.hostname:
port = f":{parsed.port}" if parsed.port else ""
origins.add(f"{parsed.scheme}://{parsed.hostname}{port}")
return sorted(origins)
return list(self.allowed_cors_origins)
def is_origin_allowed(self, origin: str) -> bool | None:
"""Check if a browser origin is allowed for CORS.
Returns True/False when a per-client list is configured,
or None to defer to the global CORS policy.
"""
effective = self.get_effective_origins()
if effective is None:
return None
if "*" in effective:
return True
return origin in effective
def has_scope(self, scope: str) -> bool:
"""Check if client is allowed to request a specific scope."""
return scope in self.scopes