oidc-client mk1
This commit is contained in:
@@ -28,6 +28,7 @@ def list_org_clients(org_id):
|
||||
"redirect_uris": c.redirect_uris,
|
||||
"scopes": c.scopes,
|
||||
"grant_types": c.grant_types,
|
||||
"allowed_cors_origins": c.allowed_cors_origins,
|
||||
"is_active": c.is_active,
|
||||
"created_at": c.created_at.isoformat() + "Z",
|
||||
}
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
"""CORS middleware configuration."""
|
||||
import base64
|
||||
import json
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from flask import request, make_response
|
||||
|
||||
from gatehouse_app.models import OIDCClient
|
||||
|
||||
ALLOWED_METHODS = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
ALLOWED_HEADERS = (
|
||||
"Content-Type, Authorization, X-Requested-With, X-Request-ID, "
|
||||
@@ -40,6 +46,85 @@ def _cors_origin_header(cors_origins, request_origin):
|
||||
return None
|
||||
|
||||
|
||||
def _get_oidc_client_id_from_request():
|
||||
"""Extract client_id from OIDC endpoint requests."""
|
||||
path = request.path
|
||||
|
||||
# POST to /oidc/token, /oidc/revoke, /oidc/introspect
|
||||
if request.method == "POST" and any(
|
||||
path.endswith(ep) for ep in ("/oidc/token", "/oidc/revoke", "/oidc/introspect")
|
||||
):
|
||||
# Try Basic Auth header first
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Basic "):
|
||||
try:
|
||||
decoded = base64.b64decode(auth_header[6:]).decode("utf-8")
|
||||
client_id, _, _ = decoded.partition(":")
|
||||
if client_id:
|
||||
return client_id
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Try form body
|
||||
if request.form:
|
||||
client_id = request.form.get("client_id")
|
||||
if client_id:
|
||||
return client_id
|
||||
|
||||
# Try JSON body
|
||||
if request.is_json:
|
||||
try:
|
||||
client_id = request.json.get("client_id")
|
||||
if client_id:
|
||||
return client_id
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
# GET/POST to /oidc/userinfo
|
||||
if path.endswith("/oidc/userinfo"):
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token = auth_header[7:]
|
||||
try:
|
||||
payload_b64 = token.split(".")[1]
|
||||
padding = 4 - len(payload_b64) % 4
|
||||
if padding != 4:
|
||||
payload_b64 += "=" * padding
|
||||
payload = json.loads(base64.urlsafe_b64decode(payload_b64))
|
||||
return payload.get("client_id")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_effective_cors_origins(app, request):
|
||||
"""Get effective CORS origins, checking per-client config for OIDC endpoints."""
|
||||
global_origins = app.config.get("CORS_ORIGINS", [])
|
||||
|
||||
if "/oidc/" not in request.path:
|
||||
return global_origins
|
||||
|
||||
try:
|
||||
client_id = _get_oidc_client_id_from_request()
|
||||
if not client_id:
|
||||
return global_origins
|
||||
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
return global_origins
|
||||
|
||||
effective = client.get_effective_origins()
|
||||
if effective is not None:
|
||||
return effective
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return global_origins
|
||||
|
||||
|
||||
def setup_cors(app):
|
||||
"""
|
||||
Configure CORS for the application.
|
||||
@@ -54,7 +139,7 @@ def setup_cors(app):
|
||||
"""Handle CORS preflight OPTIONS requests."""
|
||||
if request.method == "OPTIONS":
|
||||
origin = request.headers.get("Origin")
|
||||
cors_origins = app.config.get("CORS_ORIGINS", [])
|
||||
cors_origins = _get_effective_cors_origins(app, request)
|
||||
|
||||
if not _is_origin_allowed(origin, cors_origins):
|
||||
return None
|
||||
@@ -73,7 +158,7 @@ def setup_cors(app):
|
||||
def after_request_cors(response):
|
||||
"""Add CORS headers to non-preflight responses."""
|
||||
origin = request.headers.get("Origin")
|
||||
cors_origins = app.config.get("CORS_ORIGINS", [])
|
||||
cors_origins = _get_effective_cors_origins(app, request)
|
||||
|
||||
allow_origin = _cors_origin_header(cors_origins, origin)
|
||||
if allow_origin:
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
"""OIDC Client model."""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import OIDCGrantType, OIDCResponseType
|
||||
@@ -21,6 +23,7 @@ class OIDCClient(BaseModel):
|
||||
grant_types = db.Column(db.JSON, nullable=False) # Allowed grant types
|
||||
response_types = db.Column(db.JSON, nullable=False) # Allowed response types
|
||||
scopes = db.Column(db.JSON, nullable=False) # Allowed scopes
|
||||
allowed_cors_origins = db.Column(db.JSON, nullable=True, default=None) # Per-client CORS origins
|
||||
|
||||
# Client metadata
|
||||
logo_uri = db.Column(db.String(512), nullable=True)
|
||||
@@ -81,6 +84,37 @@ class OIDCClient(BaseModel):
|
||||
"""Check if a redirect URI is allowed for this client."""
|
||||
return redirect_uri in self.redirect_uris
|
||||
|
||||
def get_effective_origins(self) -> list | None:
|
||||
"""Get effective CORS origins for this client.
|
||||
|
||||
Returns None to signal "use global config", a derived list from
|
||||
redirect_uris when "+" is present, or the configured list as-is.
|
||||
"""
|
||||
if self.allowed_cors_origins is None:
|
||||
return None
|
||||
if "+" in self.allowed_cors_origins:
|
||||
origins = set()
|
||||
for uri in self.redirect_uris:
|
||||
parsed = urlparse(uri)
|
||||
if parsed.scheme and parsed.hostname:
|
||||
port = f":{parsed.port}" if parsed.port else ""
|
||||
origins.add(f"{parsed.scheme}://{parsed.hostname}{port}")
|
||||
return sorted(origins)
|
||||
return list(self.allowed_cors_origins)
|
||||
|
||||
def is_origin_allowed(self, origin: str) -> bool | None:
|
||||
"""Check if a browser origin is allowed for CORS.
|
||||
|
||||
Returns True/False when a per-client list is configured,
|
||||
or None to defer to the global CORS policy.
|
||||
"""
|
||||
effective = self.get_effective_origins()
|
||||
if effective is None:
|
||||
return None
|
||||
if "*" in effective:
|
||||
return True
|
||||
return origin in effective
|
||||
|
||||
def has_scope(self, scope: str) -> bool:
|
||||
"""Check if client is allowed to request a specific scope."""
|
||||
return scope in self.scopes
|
||||
|
||||
Reference in New Issue
Block a user