oidc-client mk1
This commit is contained in:
@@ -0,0 +1,240 @@
|
||||
# Per-Client CORS Origins for OIDC Endpoints
|
||||
|
||||
## Overview
|
||||
|
||||
Gatehouse OIDC now supports **per-client CORS origins**. This allows each OIDC client to declare which browser origins are permitted to make cross-origin requests to OIDC endpoints (`/oidc/token`, `/oidc/revoke`, `/oidc/userinfo`, `/oidc/introspect`).
|
||||
|
||||
Previously, CORS was controlled by a single server-wide `CORS_ORIGINS` environment variable. If your SPA's origin wasn't in that list, the browser would block requests to OIDC endpoints — even if your OIDC client was properly configured.
|
||||
|
||||
## How It Works
|
||||
|
||||
### The Problem
|
||||
|
||||
When a browser-based SPA (e.g., running at `http://localhost:8080`) exchanges an authorization code for tokens, it makes a POST request to `/oidc/token`. The browser sends a preflight OPTIONS request first, and the server must respond with CORS headers allowing the SPA's origin.
|
||||
|
||||
Previously, if `http://localhost:8080` wasn't in the server's `CORS_ORIGINS` env var, the preflight would fail and the SPA couldn't get tokens.
|
||||
|
||||
### The Solution
|
||||
|
||||
Each OIDC client can now declare its own `allowed_cors_origins`. When a request hits an OIDC endpoint, the server checks the client's CORS configuration first, then falls back to the global config.
|
||||
|
||||
## Configuration
|
||||
|
||||
### Setting CORS Origins on an OIDC Client
|
||||
|
||||
When creating or updating an OIDC client, set the `allowed_cors_origins` field:
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "My SPA",
|
||||
"client_id": "oidc_myapp",
|
||||
"redirect_uris": ["http://localhost:8080/callback", "https://app.example.com/callback"],
|
||||
"allowed_cors_origins": ["http://localhost:8080", "https://app.example.com"],
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"scopes": ["openid", "profile", "email"]
|
||||
}
|
||||
```
|
||||
|
||||
### Auto-Derive from Redirect URIs
|
||||
|
||||
Set `allowed_cors_origins` to `["+"]` to automatically derive CORS origins from the client's `redirect_uris`. The server extracts the scheme, hostname, and port from each redirect URI.
|
||||
|
||||
```json
|
||||
{
|
||||
"redirect_uris": ["http://localhost:8080/callback", "https://app.example.com/callback"],
|
||||
"allowed_cors_origins": ["+"]
|
||||
}
|
||||
```
|
||||
|
||||
This is equivalent to:
|
||||
|
||||
```json
|
||||
{
|
||||
"allowed_cors_origins": ["http://localhost:8080", "https://app.example.com"]
|
||||
}
|
||||
```
|
||||
|
||||
### Use Global Config (Default)
|
||||
|
||||
Set `allowed_cors_origins` to `null` (or omit it) to use the server's global `CORS_ORIGINS` config. This is the default behavior for existing clients.
|
||||
|
||||
```json
|
||||
{
|
||||
"allowed_cors_origins": null
|
||||
}
|
||||
```
|
||||
|
||||
### Allow All Origins (Not Recommended)
|
||||
|
||||
Set `allowed_cors_origins` to `["*"]` to allow any origin. **This is not recommended for production.**
|
||||
|
||||
```json
|
||||
{
|
||||
"allowed_cors_origins": ["*"]
|
||||
}
|
||||
```
|
||||
|
||||
## Affected Endpoints
|
||||
|
||||
The following OIDC endpoints support per-client CORS:
|
||||
|
||||
| Endpoint | Method | How Client is Identified |
|
||||
|---|---|---|
|
||||
| `/oidc/token` | POST | `client_id` in request body or Basic Auth header |
|
||||
| `/oidc/revoke` | POST | `client_id` in request body or Basic Auth header |
|
||||
| `/oidc/introspect` | POST | `client_id` in request body or Basic Auth header |
|
||||
| `/oidc/userinfo` | GET/POST | `client_id` extracted from Bearer token |
|
||||
|
||||
## SPA Integration Guide
|
||||
|
||||
### Step 1: Register Your OIDC Client
|
||||
|
||||
Register your SPA as an OIDC client with the correct redirect URIs and CORS origins:
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "My React App",
|
||||
"redirect_uris": ["http://localhost:3000/callback"],
|
||||
"allowed_cors_origins": ["http://localhost:3000"],
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"scopes": ["openid", "profile", "email"],
|
||||
"is_confidential": false,
|
||||
"require_pkce": true
|
||||
}
|
||||
```
|
||||
|
||||
### Step 2: Use PKCE (Required for Public Clients)
|
||||
|
||||
Gatehouse requires PKCE for public clients. Generate a code verifier and challenge before redirecting to the authorize endpoint:
|
||||
|
||||
```javascript
|
||||
// Generate PKCE
|
||||
const codeVerifier = generateRandomString(128);
|
||||
const codeChallenge = await sha256(codeVerifier);
|
||||
const state = generateRandomString(32);
|
||||
|
||||
// Store verifier for later
|
||||
sessionStorage.setItem('pkce_verifier', codeVerifier);
|
||||
|
||||
// Redirect to authorize
|
||||
const authUrl = new URL('https://api.example.com/api/v1/oidc/authorize');
|
||||
authUrl.searchParams.set('response_type', 'code');
|
||||
authUrl.searchParams.set('client_id', 'oidc_myapp');
|
||||
authUrl.searchParams.set('redirect_uri', 'http://localhost:3000/callback');
|
||||
authUrl.searchParams.set('scope', 'openid profile email');
|
||||
authUrl.searchParams.set('state', state);
|
||||
authUrl.searchParams.set('code_challenge', codeChallenge);
|
||||
authUrl.searchParams.set('code_challenge_method', 'S256');
|
||||
|
||||
window.location.href = authUrl.toString();
|
||||
```
|
||||
|
||||
### Step 3: Exchange Code for Tokens
|
||||
|
||||
After the user authenticates and is redirected back to your callback page, exchange the authorization code for tokens:
|
||||
|
||||
```javascript
|
||||
// Extract code from URL
|
||||
const params = new URLSearchParams(window.location.search);
|
||||
const code = params.get('code');
|
||||
const state = params.get('state');
|
||||
|
||||
// Verify state matches
|
||||
if (state !== sessionStorage.getItem('pkce_state')) {
|
||||
throw new Error('State mismatch');
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
const response = await fetch('https://api.example.com/api/v1/oidc/token', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
},
|
||||
body: new URLSearchParams({
|
||||
grant_type: 'authorization_code',
|
||||
code: code,
|
||||
redirect_uri: 'http://localhost:3000/callback',
|
||||
client_id: 'oidc_myapp',
|
||||
code_verifier: sessionStorage.getItem('pkce_verifier'),
|
||||
}),
|
||||
});
|
||||
|
||||
const tokens = await response.json();
|
||||
// tokens.access_token, tokens.id_token, tokens.refresh_token
|
||||
```
|
||||
|
||||
The server will return CORS headers because `http://localhost:3000` is in the client's `allowed_cors_origins`.
|
||||
|
||||
### Step 4: Refresh Tokens
|
||||
|
||||
```javascript
|
||||
const response = await fetch('https://api.example.com/api/v1/oidc/token', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
},
|
||||
body: new URLSearchParams({
|
||||
grant_type: 'refresh_token',
|
||||
refresh_token: storedRefreshToken,
|
||||
client_id: 'oidc_myapp',
|
||||
}),
|
||||
});
|
||||
```
|
||||
|
||||
### Step 5: Call UserInfo
|
||||
|
||||
```javascript
|
||||
const response = await fetch('https://api.example.com/api/v1/oidc/userinfo', {
|
||||
headers: {
|
||||
'Authorization': `Bearer ${accessToken}`,
|
||||
},
|
||||
});
|
||||
const userInfo = await response.json();
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "CORS error" when exchanging code for tokens
|
||||
|
||||
**Cause**: Your SPA's origin is not in the client's `allowed_cors_origins` or the server's global `CORS_ORIGINS`.
|
||||
|
||||
**Fix**: Add your SPA's origin to the client's `allowed_cors_origins`:
|
||||
```json
|
||||
{
|
||||
"allowed_cors_origins": ["http://localhost:3000"]
|
||||
}
|
||||
```
|
||||
|
||||
### "CORS error" on preflight OPTIONS request
|
||||
|
||||
**Cause**: The preflight request doesn't carry client credentials, so the server can't identify which client to check CORS origins for. It falls back to the global `CORS_ORIGINS`.
|
||||
|
||||
**Fix**: Either add your origin to the global `CORS_ORIGINS` env var, or ensure the actual POST request (after preflight) includes the `client_id` in the request body.
|
||||
|
||||
### CORS works for `/oidc/token` but not `/oidc/userinfo`
|
||||
|
||||
**Cause**: The userinfo endpoint identifies the client from the Bearer token. If the token doesn't contain a `client_id` claim, the server falls back to global config.
|
||||
|
||||
**Fix**: Ensure your access tokens include the `client_id` claim (this is the default behavior).
|
||||
|
||||
## API Reference
|
||||
|
||||
### OIDCClient Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|---|---|---|
|
||||
| `allowed_cors_origins` | `string[]` or `null` | List of allowed browser origins. `null` = use global config. `["+"]` = auto-derive from redirect URIs. `["*"]` = allow all (not recommended). |
|
||||
|
||||
### CORS Headers Returned
|
||||
|
||||
When a request's origin matches the client's allowed origins:
|
||||
|
||||
```
|
||||
Access-Control-Allow-Origin: <request-origin>
|
||||
Access-Control-Allow-Methods: GET, POST, PUT, PATCH, DELETE, OPTIONS
|
||||
Access-Control-Allow-Headers: Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma, X-WebAuthn-Session-Token
|
||||
Access-Control-Allow-Credentials: true
|
||||
Access-Control-Max-Age: 3600
|
||||
```
|
||||
@@ -28,6 +28,7 @@ def list_org_clients(org_id):
|
||||
"redirect_uris": c.redirect_uris,
|
||||
"scopes": c.scopes,
|
||||
"grant_types": c.grant_types,
|
||||
"allowed_cors_origins": c.allowed_cors_origins,
|
||||
"is_active": c.is_active,
|
||||
"created_at": c.created_at.isoformat() + "Z",
|
||||
}
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
"""CORS middleware configuration."""
|
||||
import base64
|
||||
import json
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from flask import request, make_response
|
||||
|
||||
from gatehouse_app.models import OIDCClient
|
||||
|
||||
ALLOWED_METHODS = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
ALLOWED_HEADERS = (
|
||||
"Content-Type, Authorization, X-Requested-With, X-Request-ID, "
|
||||
@@ -40,6 +46,85 @@ def _cors_origin_header(cors_origins, request_origin):
|
||||
return None
|
||||
|
||||
|
||||
def _get_oidc_client_id_from_request():
|
||||
"""Extract client_id from OIDC endpoint requests."""
|
||||
path = request.path
|
||||
|
||||
# POST to /oidc/token, /oidc/revoke, /oidc/introspect
|
||||
if request.method == "POST" and any(
|
||||
path.endswith(ep) for ep in ("/oidc/token", "/oidc/revoke", "/oidc/introspect")
|
||||
):
|
||||
# Try Basic Auth header first
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Basic "):
|
||||
try:
|
||||
decoded = base64.b64decode(auth_header[6:]).decode("utf-8")
|
||||
client_id, _, _ = decoded.partition(":")
|
||||
if client_id:
|
||||
return client_id
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Try form body
|
||||
if request.form:
|
||||
client_id = request.form.get("client_id")
|
||||
if client_id:
|
||||
return client_id
|
||||
|
||||
# Try JSON body
|
||||
if request.is_json:
|
||||
try:
|
||||
client_id = request.json.get("client_id")
|
||||
if client_id:
|
||||
return client_id
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
# GET/POST to /oidc/userinfo
|
||||
if path.endswith("/oidc/userinfo"):
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token = auth_header[7:]
|
||||
try:
|
||||
payload_b64 = token.split(".")[1]
|
||||
padding = 4 - len(payload_b64) % 4
|
||||
if padding != 4:
|
||||
payload_b64 += "=" * padding
|
||||
payload = json.loads(base64.urlsafe_b64decode(payload_b64))
|
||||
return payload.get("client_id")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_effective_cors_origins(app, request):
|
||||
"""Get effective CORS origins, checking per-client config for OIDC endpoints."""
|
||||
global_origins = app.config.get("CORS_ORIGINS", [])
|
||||
|
||||
if "/oidc/" not in request.path:
|
||||
return global_origins
|
||||
|
||||
try:
|
||||
client_id = _get_oidc_client_id_from_request()
|
||||
if not client_id:
|
||||
return global_origins
|
||||
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
return global_origins
|
||||
|
||||
effective = client.get_effective_origins()
|
||||
if effective is not None:
|
||||
return effective
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return global_origins
|
||||
|
||||
|
||||
def setup_cors(app):
|
||||
"""
|
||||
Configure CORS for the application.
|
||||
@@ -54,7 +139,7 @@ def setup_cors(app):
|
||||
"""Handle CORS preflight OPTIONS requests."""
|
||||
if request.method == "OPTIONS":
|
||||
origin = request.headers.get("Origin")
|
||||
cors_origins = app.config.get("CORS_ORIGINS", [])
|
||||
cors_origins = _get_effective_cors_origins(app, request)
|
||||
|
||||
if not _is_origin_allowed(origin, cors_origins):
|
||||
return None
|
||||
@@ -73,7 +158,7 @@ def setup_cors(app):
|
||||
def after_request_cors(response):
|
||||
"""Add CORS headers to non-preflight responses."""
|
||||
origin = request.headers.get("Origin")
|
||||
cors_origins = app.config.get("CORS_ORIGINS", [])
|
||||
cors_origins = _get_effective_cors_origins(app, request)
|
||||
|
||||
allow_origin = _cors_origin_header(cors_origins, origin)
|
||||
if allow_origin:
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
"""OIDC Client model."""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import OIDCGrantType, OIDCResponseType
|
||||
@@ -21,6 +23,7 @@ class OIDCClient(BaseModel):
|
||||
grant_types = db.Column(db.JSON, nullable=False) # Allowed grant types
|
||||
response_types = db.Column(db.JSON, nullable=False) # Allowed response types
|
||||
scopes = db.Column(db.JSON, nullable=False) # Allowed scopes
|
||||
allowed_cors_origins = db.Column(db.JSON, nullable=True, default=None) # Per-client CORS origins
|
||||
|
||||
# Client metadata
|
||||
logo_uri = db.Column(db.String(512), nullable=True)
|
||||
@@ -81,6 +84,37 @@ class OIDCClient(BaseModel):
|
||||
"""Check if a redirect URI is allowed for this client."""
|
||||
return redirect_uri in self.redirect_uris
|
||||
|
||||
def get_effective_origins(self) -> list | None:
|
||||
"""Get effective CORS origins for this client.
|
||||
|
||||
Returns None to signal "use global config", a derived list from
|
||||
redirect_uris when "+" is present, or the configured list as-is.
|
||||
"""
|
||||
if self.allowed_cors_origins is None:
|
||||
return None
|
||||
if "+" in self.allowed_cors_origins:
|
||||
origins = set()
|
||||
for uri in self.redirect_uris:
|
||||
parsed = urlparse(uri)
|
||||
if parsed.scheme and parsed.hostname:
|
||||
port = f":{parsed.port}" if parsed.port else ""
|
||||
origins.add(f"{parsed.scheme}://{parsed.hostname}{port}")
|
||||
return sorted(origins)
|
||||
return list(self.allowed_cors_origins)
|
||||
|
||||
def is_origin_allowed(self, origin: str) -> bool | None:
|
||||
"""Check if a browser origin is allowed for CORS.
|
||||
|
||||
Returns True/False when a per-client list is configured,
|
||||
or None to defer to the global CORS policy.
|
||||
"""
|
||||
effective = self.get_effective_origins()
|
||||
if effective is None:
|
||||
return None
|
||||
if "*" in effective:
|
||||
return True
|
||||
return origin in effective
|
||||
|
||||
def has_scope(self, scope: str) -> bool:
|
||||
"""Check if client is allowed to request a specific scope."""
|
||||
return scope in self.scopes
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
"""Add allowed_cors_origins to oidc_clients.
|
||||
|
||||
Revision ID: b7e3f1a92c4d
|
||||
Revises: a1b2c3d4e5f6
|
||||
Create Date: 2026-04-27 00:00:00.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'b7e3f1a92c4d'
|
||||
down_revision = 'a1b2c3d4e5f6'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.add_column('oidc_clients', sa.Column('allowed_cors_origins', sa.JSON(), nullable=True))
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_column('oidc_clients', 'allowed_cors_origins')
|
||||
@@ -0,0 +1,503 @@
|
||||
"""Unit tests for per-client CORS feature.
|
||||
|
||||
WHAT: Tests for per-client CORS origin resolution, including OIDCClient
|
||||
model methods, request client_id extraction, effective origin
|
||||
resolution, and integration with the CORS middleware.
|
||||
WHY: Per-client CORS prevents one OIDC client from making cross-origin
|
||||
requests meant for another; misconfiguration breaks browser flows.
|
||||
EXPECTED: Correct origin derivation, proper client_id extraction, and
|
||||
correct CORS headers on OIDC endpoints.
|
||||
"""
|
||||
import base64
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, request as flask_request
|
||||
|
||||
import gatehouse_app.middleware.cors as cors_module
|
||||
from gatehouse_app.middleware.cors import (
|
||||
_get_oidc_client_id_from_request,
|
||||
_get_effective_cors_origins,
|
||||
setup_cors,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: build a lightweight stub that quacks like OIDCClient
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class StubClient:
|
||||
"""Minimal stand-in for OIDCClient -- no SQLAlchemy, no DB needed."""
|
||||
|
||||
def __init__(self, *, allowed_cors_origins=None, redirect_uris=None):
|
||||
self.allowed_cors_origins = allowed_cors_origins
|
||||
self.redirect_uris = redirect_uris or []
|
||||
|
||||
def get_effective_origins(self):
|
||||
from urllib.parse import urlparse
|
||||
|
||||
if self.allowed_cors_origins is None:
|
||||
return None
|
||||
if "+" in self.allowed_cors_origins:
|
||||
origins = set()
|
||||
for uri in self.redirect_uris:
|
||||
parsed = urlparse(uri)
|
||||
if parsed.scheme and parsed.hostname:
|
||||
port = f":{parsed.port}" if parsed.port else ""
|
||||
origins.add(f"{parsed.scheme}://{parsed.hostname}{port}")
|
||||
return sorted(origins)
|
||||
return list(self.allowed_cors_origins)
|
||||
|
||||
def is_origin_allowed(self, origin):
|
||||
effective = self.get_effective_origins()
|
||||
if effective is None:
|
||||
return None
|
||||
if "*" in effective:
|
||||
return True
|
||||
return origin in effective
|
||||
|
||||
|
||||
def _basic_auth_header(client_id, secret="secret"):
|
||||
"""Return a 'Basic <b64>' Authorization header value."""
|
||||
return "Basic " + base64.b64encode(f"{client_id}:{secret}".encode()).decode()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OIDCClient.get_effective_origins
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetEffectiveOrigins:
|
||||
def test_returns_none_when_allowed_cors_origins_is_none(self):
|
||||
"""TEST: PCORS-GE-01 -- None config signals 'use global'."""
|
||||
client = StubClient(allowed_cors_origins=None)
|
||||
assert client.get_effective_origins() is None
|
||||
|
||||
def test_derives_from_redirect_uris_when_plus_sign(self):
|
||||
"""TEST: PCORS-GE-02 -- '+' in list derives origins from redirect_uris."""
|
||||
client = StubClient(
|
||||
allowed_cors_origins=["+"],
|
||||
redirect_uris=[
|
||||
"https://app.example.com/callback",
|
||||
"http://localhost:3000/callback",
|
||||
],
|
||||
)
|
||||
assert client.get_effective_origins() == [
|
||||
"http://localhost:3000",
|
||||
"https://app.example.com",
|
||||
]
|
||||
|
||||
def test_derives_with_port(self):
|
||||
"""TEST: PCORS-GE-03 -- Non-standard ports are preserved."""
|
||||
client = StubClient(
|
||||
allowed_cors_origins=["+"],
|
||||
redirect_uris=["https://app.example.com:8443/cb"],
|
||||
)
|
||||
assert client.get_effective_origins() == ["https://app.example.com:8443"]
|
||||
|
||||
def test_deduplicates_derived_origins(self):
|
||||
"""TEST: PCORS-GE-04 -- Duplicate redirect URIs produce unique origins."""
|
||||
client = StubClient(
|
||||
allowed_cors_origins=["+"],
|
||||
redirect_uris=[
|
||||
"https://app.example.com/cb1",
|
||||
"https://app.example.com/cb2",
|
||||
],
|
||||
)
|
||||
assert client.get_effective_origins() == ["https://app.example.com"]
|
||||
|
||||
def test_returns_list_as_is_when_normal_list(self):
|
||||
"""TEST: PCORS-GE-05 -- Normal list is returned unchanged."""
|
||||
origins = ["https://a.com", "https://b.com"]
|
||||
client = StubClient(allowed_cors_origins=origins)
|
||||
assert client.get_effective_origins() == origins
|
||||
|
||||
def test_returns_wildcard_list_as_is(self):
|
||||
"""TEST: PCORS-GE-06 -- ['*'] is returned (handled downstream)."""
|
||||
client = StubClient(allowed_cors_origins=["*"])
|
||||
assert client.get_effective_origins() == ["*"]
|
||||
|
||||
def test_empty_redirect_uris_with_plus_returns_empty(self):
|
||||
"""TEST: PCORS-GE-07 -- '+' with empty redirect_uris yields empty list."""
|
||||
client = StubClient(
|
||||
allowed_cors_origins=["+"],
|
||||
redirect_uris=[],
|
||||
)
|
||||
assert client.get_effective_origins() == []
|
||||
|
||||
def test_skips_malformed_redirect_uris(self):
|
||||
"""TEST: PCORS-GE-08 -- URIs without scheme/host are skipped."""
|
||||
client = StubClient(
|
||||
allowed_cors_origins=["+"],
|
||||
redirect_uris=["not-a-uri", "https://good.com/cb"],
|
||||
)
|
||||
assert client.get_effective_origins() == ["https://good.com"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OIDCClient.is_origin_allowed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsOriginAllowed:
|
||||
def test_returns_none_when_no_per_client_config(self):
|
||||
"""TEST: PCORS-IO-01 -- None config defers to global CORS."""
|
||||
client = StubClient(allowed_cors_origins=None)
|
||||
assert client.is_origin_allowed("https://anything.com") is None
|
||||
|
||||
def test_returns_true_when_wildcard(self):
|
||||
"""TEST: PCORS-IO-02 -- '*' in effective origins allows any origin."""
|
||||
client = StubClient(allowed_cors_origins=["*"])
|
||||
assert client.is_origin_allowed("https://evil.com") is True
|
||||
|
||||
def test_returns_true_for_matching_origin(self):
|
||||
"""TEST: PCORS-IO-03 -- Matching origin is allowed."""
|
||||
client = StubClient(
|
||||
allowed_cors_origins=["https://app.example.com", "https://other.com"],
|
||||
)
|
||||
assert client.is_origin_allowed("https://app.example.com") is True
|
||||
|
||||
def test_returns_false_for_non_matching_origin(self):
|
||||
"""TEST: PCORS-IO-04 -- Non-matching origin is rejected."""
|
||||
client = StubClient(allowed_cors_origins=["https://app.example.com"])
|
||||
assert client.is_origin_allowed("https://evil.com") is False
|
||||
|
||||
def test_returns_false_for_empty_list(self):
|
||||
"""TEST: PCORS-IO-05 -- Empty list rejects everything."""
|
||||
client = StubClient(allowed_cors_origins=[])
|
||||
assert client.is_origin_allowed("https://anything.com") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_oidc_client_id_from_request
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetOidcClientIdFromRequest:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
def test_extracts_from_basic_auth(self, app):
|
||||
"""TEST: PCORS-CI-01 -- Basic Auth header yields client_id."""
|
||||
with app.test_request_context(
|
||||
"/oidc/token",
|
||||
method="POST",
|
||||
headers={"Authorization": _basic_auth_header("my-client")},
|
||||
):
|
||||
assert _get_oidc_client_id_from_request() == "my-client"
|
||||
|
||||
def test_extracts_from_form_body(self, app):
|
||||
"""TEST: PCORS-CI-02 -- Form-encoded body yields client_id."""
|
||||
with app.test_request_context(
|
||||
"/oidc/token",
|
||||
method="POST",
|
||||
data={"client_id": "form-client", "grant_type": "client_credentials"},
|
||||
):
|
||||
assert _get_oidc_client_id_from_request() == "form-client"
|
||||
|
||||
def test_extracts_from_json_body(self, app):
|
||||
"""TEST: PCORS-CI-03 -- JSON body yields client_id."""
|
||||
with app.test_request_context(
|
||||
"/oidc/token",
|
||||
method="POST",
|
||||
data=json.dumps({"client_id": "json-client", "grant_type": "client_credentials"}),
|
||||
content_type="application/json",
|
||||
):
|
||||
assert _get_oidc_client_id_from_request() == "json-client"
|
||||
|
||||
def test_extracts_from_bearer_jwt(self, app):
|
||||
"""TEST: PCORS-CI-04 -- Bearer JWT payload yields client_id."""
|
||||
payload = base64.urlsafe_b64encode(
|
||||
json.dumps({"client_id": "jwt-client"}).encode()
|
||||
).rstrip(b"=").decode()
|
||||
token = f"eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.{payload}.sig"
|
||||
with app.test_request_context(
|
||||
"/oidc/userinfo",
|
||||
method="GET",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
):
|
||||
assert _get_oidc_client_id_from_request() == "jwt-client"
|
||||
|
||||
def test_returns_none_for_non_oidc_endpoint(self, app):
|
||||
"""TEST: PCORS-CI-05 -- Non-OIDC path returns None."""
|
||||
with app.test_request_context(
|
||||
"/api/v1/users",
|
||||
method="GET",
|
||||
headers={"Authorization": _basic_auth_header("x")},
|
||||
):
|
||||
assert _get_oidc_client_id_from_request() is None
|
||||
|
||||
def test_returns_none_when_no_client_id_found(self, app):
|
||||
"""TEST: PCORS-CI-06 -- OIDC token endpoint with no credentials returns None."""
|
||||
with app.test_request_context(
|
||||
"/oidc/token",
|
||||
method="POST",
|
||||
data={"grant_type": "client_credentials"},
|
||||
):
|
||||
assert _get_oidc_client_id_from_request() is None
|
||||
|
||||
def test_extracts_from_revoke_endpoint(self, app):
|
||||
"""TEST: PCORS-CI-07 -- /oidc/revoke also accepts Basic Auth."""
|
||||
with app.test_request_context(
|
||||
"/oidc/revoke",
|
||||
method="POST",
|
||||
headers={"Authorization": _basic_auth_header("rev-client")},
|
||||
):
|
||||
assert _get_oidc_client_id_from_request() == "rev-client"
|
||||
|
||||
def test_extracts_from_introspect_endpoint(self, app):
|
||||
"""TEST: PCORS-CI-08 -- /oidc/introspect also accepts Basic Auth."""
|
||||
with app.test_request_context(
|
||||
"/oidc/introspect",
|
||||
method="POST",
|
||||
headers={"Authorization": _basic_auth_header("int-client")},
|
||||
):
|
||||
assert _get_oidc_client_id_from_request() == "int-client"
|
||||
|
||||
def test_returns_none_for_options_preflight(self, app):
|
||||
"""TEST: PCORS-CI-09 -- OPTIONS preflight cannot carry client credentials."""
|
||||
with app.test_request_context(
|
||||
"/oidc/token",
|
||||
method="OPTIONS",
|
||||
headers={
|
||||
"Origin": "https://app.com",
|
||||
"Access-Control-Request-Method": "POST",
|
||||
},
|
||||
):
|
||||
assert _get_oidc_client_id_from_request() is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_effective_cors_origins
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetEffectiveCorsOrigins:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["CORS_ORIGINS"] = ["https://global.com"]
|
||||
return app
|
||||
|
||||
def test_global_config_for_non_oidc_endpoint(self, app):
|
||||
"""TEST: PCORS-EO-01 -- Non-OIDC path always uses global config."""
|
||||
with app.test_request_context("/api/v1/users", method="GET"):
|
||||
result = _get_effective_cors_origins(app, flask_request)
|
||||
assert result == ["https://global.com"]
|
||||
|
||||
def test_per_client_origins_for_oidc_endpoint(self, app):
|
||||
"""TEST: PCORS-EO-02 -- OIDC endpoint with configured client uses per-client origins."""
|
||||
fake_client = StubClient(allowed_cors_origins=["https://client.com"])
|
||||
|
||||
with app.test_request_context(
|
||||
"/oidc/token",
|
||||
method="POST",
|
||||
headers={"Authorization": _basic_auth_header("test-client")},
|
||||
):
|
||||
with patch.object(cors_module, "OIDCClient") as MockModel:
|
||||
MockModel.query.filter_by.return_value.first.return_value = fake_client
|
||||
result = _get_effective_cors_origins(app, flask_request)
|
||||
assert result == ["https://client.com"]
|
||||
|
||||
def test_fallback_when_client_not_found(self, app):
|
||||
"""TEST: PCORS-EO-03 -- Unknown client_id falls back to global config."""
|
||||
with app.test_request_context(
|
||||
"/oidc/token",
|
||||
method="POST",
|
||||
headers={"Authorization": _basic_auth_header("unknown")},
|
||||
):
|
||||
with patch.object(cors_module, "OIDCClient") as MockModel:
|
||||
MockModel.query.filter_by.return_value.first.return_value = None
|
||||
result = _get_effective_cors_origins(app, flask_request)
|
||||
assert result == ["https://global.com"]
|
||||
|
||||
def test_fallback_when_allowed_cors_origins_is_none(self, app):
|
||||
"""TEST: PCORS-EO-04 -- Client with None origins falls back to global."""
|
||||
fake_client = StubClient(allowed_cors_origins=None)
|
||||
|
||||
with app.test_request_context(
|
||||
"/oidc/token",
|
||||
method="POST",
|
||||
headers={"Authorization": _basic_auth_header("test-client")},
|
||||
):
|
||||
with patch.object(cors_module, "OIDCClient") as MockModel:
|
||||
MockModel.query.filter_by.return_value.first.return_value = fake_client
|
||||
result = _get_effective_cors_origins(app, flask_request)
|
||||
assert result == ["https://global.com"]
|
||||
|
||||
def test_fallback_on_db_error(self, app):
|
||||
"""TEST: PCORS-EO-05 -- Database exception falls back to global config."""
|
||||
with app.test_request_context(
|
||||
"/oidc/token",
|
||||
method="POST",
|
||||
headers={"Authorization": _basic_auth_header("test-client")},
|
||||
):
|
||||
with patch.object(cors_module, "OIDCClient") as MockModel:
|
||||
MockModel.query.filter_by.side_effect = Exception("DB down")
|
||||
result = _get_effective_cors_origins(app, flask_request)
|
||||
assert result == ["https://global.com"]
|
||||
|
||||
def test_fallback_when_no_client_id_extracted(self, app):
|
||||
"""TEST: PCORS-EO-06 -- OIDC path with no credentials falls back to global."""
|
||||
with app.test_request_context(
|
||||
"/oidc/token",
|
||||
method="POST",
|
||||
data={"grant_type": "client_credentials"},
|
||||
):
|
||||
result = _get_effective_cors_origins(app, flask_request)
|
||||
assert result == ["https://global.com"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: OIDC endpoint CORS headers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestOidcEndpointCorsIntegration:
|
||||
@pytest.fixture
|
||||
def app_with_global_and_client(self):
|
||||
"""Flask app with global CORS and route stubs for integration tests."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["CORS_ORIGINS"] = ["https://global.com"]
|
||||
app.config["CORS_SUPPORTS_CREDENTIALS"] = True
|
||||
|
||||
@app.route("/oidc/token", methods=["POST", "OPTIONS"])
|
||||
def oidc_token():
|
||||
return {"status": "ok"}, 200
|
||||
|
||||
@app.route("/api/v1/users", methods=["GET", "OPTIONS"])
|
||||
def api_users():
|
||||
return {"users": []}, 200
|
||||
|
||||
setup_cors(app)
|
||||
return app
|
||||
|
||||
def test_post_oidc_with_per_client_origin_includes_cors_headers(
|
||||
self, app_with_global_and_client
|
||||
):
|
||||
"""TEST: PCORS-INT-01 -- POST to /oidc/token with per-client origin and
|
||||
Basic Auth includes CORS headers. Per-client CORS applies to the actual
|
||||
request (which carries credentials), not the preflight."""
|
||||
fake_client = StubClient(allowed_cors_origins=["https://client-app.com"])
|
||||
|
||||
with patch.object(cors_module, "OIDCClient") as MockModel:
|
||||
MockModel.query.filter_by.return_value.first.return_value = fake_client
|
||||
|
||||
with app_with_global_and_client.test_client() as client:
|
||||
resp = client.post(
|
||||
"/oidc/token",
|
||||
headers={
|
||||
"Origin": "https://client-app.com",
|
||||
"Authorization": _basic_auth_header("test-client"),
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert (
|
||||
resp.headers.get("Access-Control-Allow-Origin")
|
||||
== "https://client-app.com"
|
||||
)
|
||||
assert resp.headers.get("Access-Control-Allow-Credentials") == "true"
|
||||
|
||||
def test_post_oidc_with_non_matching_per_client_origin_no_cors_headers(
|
||||
self, app_with_global_and_client
|
||||
):
|
||||
"""TEST: PCORS-INT-02 -- POST with origin not in per-client list has no
|
||||
CORS headers."""
|
||||
fake_client = StubClient(allowed_cors_origins=["https://allowed.com"])
|
||||
|
||||
with patch.object(cors_module, "OIDCClient") as MockModel:
|
||||
MockModel.query.filter_by.return_value.first.return_value = fake_client
|
||||
|
||||
with app_with_global_and_client.test_client() as client:
|
||||
resp = client.post(
|
||||
"/oidc/token",
|
||||
headers={
|
||||
"Origin": "https://evil.com",
|
||||
"Authorization": _basic_auth_header("test-client"),
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") is None
|
||||
|
||||
def test_post_oidc_wildcard_client_echoes_origin(self, app_with_global_and_client):
|
||||
"""TEST: PCORS-INT-03 -- Client with '*' echoes the request origin."""
|
||||
fake_client = StubClient(allowed_cors_origins=["*"])
|
||||
|
||||
with patch.object(cors_module, "OIDCClient") as MockModel:
|
||||
MockModel.query.filter_by.return_value.first.return_value = fake_client
|
||||
|
||||
with app_with_global_and_client.test_client() as client:
|
||||
resp = client.post(
|
||||
"/oidc/token",
|
||||
headers={
|
||||
"Origin": "https://any-origin.com",
|
||||
"Authorization": _basic_auth_header("test-client"),
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert (
|
||||
resp.headers.get("Access-Control-Allow-Origin")
|
||||
== "https://any-origin.com"
|
||||
)
|
||||
|
||||
def test_preflight_oidc_falls_back_to_global(self, app_with_global_and_client):
|
||||
"""TEST: PCORS-INT-04 -- OPTIONS preflight cannot carry client credentials,
|
||||
so it uses global CORS config. A preflight from a per-client-only origin
|
||||
that is not in the global list will not receive CORS headers."""
|
||||
fake_client = StubClient(allowed_cors_origins=["https://client-app.com"])
|
||||
|
||||
with patch.object(cors_module, "OIDCClient") as MockModel:
|
||||
MockModel.query.filter_by.return_value.first.return_value = fake_client
|
||||
|
||||
with app_with_global_and_client.test_client() as client:
|
||||
resp = client.options(
|
||||
"/oidc/token",
|
||||
headers={
|
||||
"Origin": "https://client-app.com",
|
||||
"Access-Control-Request-Method": "POST",
|
||||
},
|
||||
)
|
||||
# Origin not in global list; no CORS headers on preflight
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") is None
|
||||
|
||||
def test_preflight_oidc_with_global_origin_succeeds(self, app_with_global_and_client):
|
||||
"""TEST: PCORS-INT-05 -- OPTIONS preflight from a globally-allowed origin
|
||||
returns 204 with CORS headers even for OIDC endpoints."""
|
||||
with app_with_global_and_client.test_client() as client:
|
||||
resp = client.options(
|
||||
"/oidc/token",
|
||||
headers={
|
||||
"Origin": "https://global.com",
|
||||
"Access-Control-Request-Method": "POST",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") == "https://global.com"
|
||||
assert resp.headers.get("Access-Control-Allow-Credentials") == "true"
|
||||
|
||||
def test_non_oidc_endpoint_uses_global_cors(self, app_with_global_and_client):
|
||||
"""TEST: PCORS-INT-06 -- Non-OIDC endpoint uses global CORS config."""
|
||||
with app_with_global_and_client.test_client() as client:
|
||||
resp = client.options(
|
||||
"/api/v1/users",
|
||||
headers={
|
||||
"Origin": "https://global.com",
|
||||
"Access-Control-Request-Method": "GET",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") == "https://global.com"
|
||||
|
||||
def test_post_oidc_no_auth_uses_global_cors(self, app_with_global_and_client):
|
||||
"""TEST: PCORS-INT-07 -- POST to OIDC endpoint without credentials uses
|
||||
global CORS (cannot identify client)."""
|
||||
with app_with_global_and_client.test_client() as client:
|
||||
resp = client.post(
|
||||
"/oidc/token",
|
||||
headers={"Origin": "https://global.com"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert (
|
||||
resp.headers.get("Access-Control-Allow-Origin") == "https://global.com"
|
||||
)
|
||||
Reference in New Issue
Block a user