Merge pull request #37 from CoryHawkless/oidc-uplift

OIDC uplift
This commit is contained in:
2026-05-19 14:48:58 +09:30
committed by GitHub
31 changed files with 1808 additions and 135 deletions
+240
View File
@@ -0,0 +1,240 @@
# Per-Client CORS Origins for OIDC Endpoints
## Overview
Gatehouse OIDC now supports **per-client CORS origins**. This allows each OIDC client to declare which browser origins are permitted to make cross-origin requests to OIDC endpoints (`/oidc/token`, `/oidc/revoke`, `/oidc/userinfo`, `/oidc/introspect`).
Previously, CORS was controlled by a single server-wide `CORS_ORIGINS` environment variable. If your SPA's origin wasn't in that list, the browser would block requests to OIDC endpoints — even if your OIDC client was properly configured.
## How It Works
### The Problem
When a browser-based SPA (e.g., running at `http://localhost:8080`) exchanges an authorization code for tokens, it makes a POST request to `/oidc/token`. The browser sends a preflight OPTIONS request first, and the server must respond with CORS headers allowing the SPA's origin.
Previously, if `http://localhost:8080` wasn't in the server's `CORS_ORIGINS` env var, the preflight would fail and the SPA couldn't get tokens.
### The Solution
Each OIDC client can now declare its own `allowed_cors_origins`. When a request hits an OIDC endpoint, the server checks the client's CORS configuration first, then falls back to the global config.
## Configuration
### Setting CORS Origins on an OIDC Client
When creating or updating an OIDC client, set the `allowed_cors_origins` field:
```json
{
"name": "My SPA",
"client_id": "oidc_myapp",
"redirect_uris": ["http://localhost:8080/callback", "https://app.example.com/callback"],
"allowed_cors_origins": ["http://localhost:8080", "https://app.example.com"],
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"scopes": ["openid", "profile", "email"]
}
```
### Auto-Derive from Redirect URIs
Set `allowed_cors_origins` to `["+"]` to automatically derive CORS origins from the client's `redirect_uris`. The server extracts the scheme, hostname, and port from each redirect URI.
```json
{
"redirect_uris": ["http://localhost:8080/callback", "https://app.example.com/callback"],
"allowed_cors_origins": ["+"]
}
```
This is equivalent to:
```json
{
"allowed_cors_origins": ["http://localhost:8080", "https://app.example.com"]
}
```
### Use Global Config (Default)
Set `allowed_cors_origins` to `null` (or omit it) to use the server's global `CORS_ORIGINS` config. This is the default behavior for existing clients.
```json
{
"allowed_cors_origins": null
}
```
### Allow All Origins (Not Recommended)
Set `allowed_cors_origins` to `["*"]` to allow any origin. **This is not recommended for production.**
```json
{
"allowed_cors_origins": ["*"]
}
```
## Affected Endpoints
The following OIDC endpoints support per-client CORS:
| Endpoint | Method | How Client is Identified |
|---|---|---|
| `/oidc/token` | POST | `client_id` in request body or Basic Auth header |
| `/oidc/revoke` | POST | `client_id` in request body or Basic Auth header |
| `/oidc/introspect` | POST | `client_id` in request body or Basic Auth header |
| `/oidc/userinfo` | GET/POST | `client_id` extracted from Bearer token |
## SPA Integration Guide
### Step 1: Register Your OIDC Client
Register your SPA as an OIDC client with the correct redirect URIs and CORS origins:
```json
{
"name": "My React App",
"redirect_uris": ["http://localhost:3000/callback"],
"allowed_cors_origins": ["http://localhost:3000"],
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"scopes": ["openid", "profile", "email"],
"is_confidential": false,
"require_pkce": true
}
```
### Step 2: Use PKCE (Required for Public Clients)
Gatehouse requires PKCE for public clients. Generate a code verifier and challenge before redirecting to the authorize endpoint:
```javascript
// Generate PKCE
const codeVerifier = generateRandomString(128);
const codeChallenge = await sha256(codeVerifier);
const state = generateRandomString(32);
// Store verifier for later
sessionStorage.setItem('pkce_verifier', codeVerifier);
// Redirect to authorize
const authUrl = new URL('https://api.example.com/api/v1/oidc/authorize');
authUrl.searchParams.set('response_type', 'code');
authUrl.searchParams.set('client_id', 'oidc_myapp');
authUrl.searchParams.set('redirect_uri', 'http://localhost:3000/callback');
authUrl.searchParams.set('scope', 'openid profile email');
authUrl.searchParams.set('state', state);
authUrl.searchParams.set('code_challenge', codeChallenge);
authUrl.searchParams.set('code_challenge_method', 'S256');
window.location.href = authUrl.toString();
```
### Step 3: Exchange Code for Tokens
After the user authenticates and is redirected back to your callback page, exchange the authorization code for tokens:
```javascript
// Extract code from URL
const params = new URLSearchParams(window.location.search);
const code = params.get('code');
const state = params.get('state');
// Verify state matches
if (state !== sessionStorage.getItem('pkce_state')) {
throw new Error('State mismatch');
}
// Exchange code for tokens
const response = await fetch('https://api.example.com/api/v1/oidc/token', {
method: 'POST',
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
},
body: new URLSearchParams({
grant_type: 'authorization_code',
code: code,
redirect_uri: 'http://localhost:3000/callback',
client_id: 'oidc_myapp',
code_verifier: sessionStorage.getItem('pkce_verifier'),
}),
});
const tokens = await response.json();
// tokens.access_token, tokens.id_token, tokens.refresh_token
```
The server will return CORS headers because `http://localhost:3000` is in the client's `allowed_cors_origins`.
### Step 4: Refresh Tokens
```javascript
const response = await fetch('https://api.example.com/api/v1/oidc/token', {
method: 'POST',
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
},
body: new URLSearchParams({
grant_type: 'refresh_token',
refresh_token: storedRefreshToken,
client_id: 'oidc_myapp',
}),
});
```
### Step 5: Call UserInfo
```javascript
const response = await fetch('https://api.example.com/api/v1/oidc/userinfo', {
headers: {
'Authorization': `Bearer ${accessToken}`,
},
});
const userInfo = await response.json();
```
## Troubleshooting
### "CORS error" when exchanging code for tokens
**Cause**: Your SPA's origin is not in the client's `allowed_cors_origins` or the server's global `CORS_ORIGINS`.
**Fix**: Add your SPA's origin to the client's `allowed_cors_origins`:
```json
{
"allowed_cors_origins": ["http://localhost:3000"]
}
```
### "CORS error" on preflight OPTIONS request
**Cause**: The preflight request doesn't carry client credentials, so the server can't identify which client to check CORS origins for. It falls back to the global `CORS_ORIGINS`.
**Fix**: Either add your origin to the global `CORS_ORIGINS` env var, or ensure the actual POST request (after preflight) includes the `client_id` in the request body.
### CORS works for `/oidc/token` but not `/oidc/userinfo`
**Cause**: The userinfo endpoint identifies the client from the Bearer token. If the token doesn't contain a `client_id` claim, the server falls back to global config.
**Fix**: Ensure your access tokens include the `client_id` claim (this is the default behavior).
## API Reference
### OIDCClient Fields
| Field | Type | Description |
|---|---|---|
| `allowed_cors_origins` | `string[]` or `null` | List of allowed browser origins. `null` = use global config. `["+"]` = auto-derive from redirect URIs. `["*"]` = allow all (not recommended). |
### CORS Headers Returned
When a request's origin matches the client's allowed origins:
```
Access-Control-Allow-Origin: <request-origin>
Access-Control-Allow-Methods: GET, POST, PUT, PATCH, DELETE, OPTIONS
Access-Control-Allow-Headers: Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma, X-WebAuthn-Session-Token
Access-Control-Allow-Credentials: true
Access-Control-Max-Age: 3600
```
+56
View File
@@ -10,6 +10,8 @@ from gatehouse_app.models import Department, DepartmentMembership
from gatehouse_app.services.organization_service import OrganizationService
from gatehouse_app.services.user_service import UserService
from gatehouse_app.extensions import db
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.services.audit_service import AuditService
class DepartmentCreateSchema(Schema):
@@ -127,6 +129,15 @@ def create_department(org_id):
db.session.add(dept)
db.session.commit()
AuditService.log_action(
action=AuditAction.DEPARTMENT_CREATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="department",
resource_id=str(dept.id),
description=f"Department '{dept.name}' created",
)
return api_response(
data={"department": dept.to_dict()},
message="Department created successfully",
@@ -255,6 +266,15 @@ def update_department(org_id, dept_id):
db.session.commit()
AuditService.log_action(
action=AuditAction.DEPARTMENT_UPDATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="department",
resource_id=str(dept.id),
description=f"Department '{dept.name}' updated",
)
return api_response(
data={"department": dept.to_dict()},
message="Department updated successfully",
@@ -308,6 +328,15 @@ def delete_department(org_id, dept_id):
dept.deleted_at = db.func.now()
db.session.commit()
AuditService.log_action(
action=AuditAction.DEPARTMENT_DELETED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="department",
resource_id=str(dept.id),
description=f"Department '{dept.name}' deleted",
)
return api_response(
message="Department deleted successfully",
)
@@ -461,6 +490,15 @@ def add_department_member(org_id, dept_id):
db.session.commit()
AuditService.log_action(
action=AuditAction.DEPARTMENT_MEMBER_ADDED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="user",
resource_id=str(user.id),
description=f"Added user {user.email} to department '{dept.name}'",
)
member_dict = membership.to_dict()
member_dict["user"] = user.to_dict()
@@ -533,6 +571,15 @@ def remove_department_member(org_id, dept_id, user_id):
membership.deleted_at = db.func.now()
db.session.commit()
AuditService.log_action(
action=AuditAction.DEPARTMENT_MEMBER_REMOVED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="user",
resource_id=str(user_id),
description=f"Removed user from department '{dept.name}'",
)
return api_response(
message="Member removed successfully",
)
@@ -699,5 +746,14 @@ def set_dept_cert_policy(org_id, dept_id):
db.session.commit()
AuditService.log_action(
action=AuditAction.DEPARTMENT_CERT_POLICY_UPDATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="department",
resource_id=str(dept_id),
description=f"Certificate policy updated for department '{dept.name}'",
)
return api_response(data={"cert_policy": policy.to_dict()}, message="Certificate policy saved")
@@ -3,6 +3,8 @@ from flask import g, request
from gatehouse_app.api.v1 import api_v1_bp
from gatehouse_app.utils.response import api_response
from gatehouse_app.utils.decorators import login_required
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.services.audit_service import AuditService
@api_v1_bp.route("/admin/oauth/providers", methods=["GET"])
@@ -78,6 +80,14 @@ def admin_configure_app_provider(provider: str):
db.session.add(cfg)
db.session.commit()
AuditService.log_action(
action=AuditAction.EXTERNAL_AUTH_CONFIG_UPDATE if cfg else AuditAction.EXTERNAL_AUTH_CONFIG_CREATE,
user_id=g.current_user.id,
resource_type="oauth_provider",
resource_id=provider,
description=f"OAuth provider '{provider}' configured (enabled={cfg.is_enabled})",
)
return api_response(
data={"provider": {"id": provider, "client_id": cfg.client_id, "is_enabled": cfg.is_enabled}},
message=f"{provider.capitalize()} OAuth provider configured successfully",
@@ -104,4 +114,13 @@ def admin_delete_app_provider(provider: str):
return api_response(success=False, message=f"Provider '{provider}' is not configured", status=404, error_type="NOT_FOUND")
cfg.delete()
AuditService.log_action(
action=AuditAction.EXTERNAL_AUTH_CONFIG_DELETE,
user_id=g.current_user.id,
resource_type="oauth_provider",
resource_id=provider,
description=f"OAuth provider '{provider}' configuration removed",
)
return api_response(message=f"{provider.capitalize()} OAuth provider configuration removed")
+15
View File
@@ -26,6 +26,9 @@ from gatehouse_app.exceptions.auth_exceptions import (
AccountSuspendedError,
AccountInactiveError,
)
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.services.audit_service import AuditService
from gatehouse_app.services.oidc_audit_service import OIDCAuditService
logger = logging.getLogger(__name__)
@@ -849,6 +852,18 @@ def oidc_register():
)
client.save()
OIDCAuditService.log_event(
event_type="client_registration",
client_id=client_id,
user_id=g.current_user.id if hasattr(g, "current_user") else None,
success=True,
metadata={
"client_name": client_name,
"redirect_uris": redirect_uris,
"organization_id": str(organization.id),
},
)
response = jsonify({
"client_id": client_id,
"client_secret": client_secret,
+33 -4
View File
@@ -8,6 +8,8 @@ from gatehouse_app.utils.decorators import login_required, require_admin, full_a
from gatehouse_app.models.organization import OrganizationApiKey
from gatehouse_app.services.organization_service import OrganizationService
from gatehouse_app.extensions import db
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.services.audit_service import AuditService
class ApiKeyCreateSchema(Schema):
@@ -130,7 +132,16 @@ def create_api_key(org_id):
name=data["name"],
description=data.get("description"),
)
AuditService.log_action(
action=AuditAction.ORG_API_KEY_CREATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="api_key",
resource_id=str(api_key.id),
description=f"API key '{api_key.name}' created",
)
# Return the key data with the plain text key (only on creation)
key_dict = api_key.to_dict()
key_dict["key"] = plain_key # Include plain text only on creation
@@ -219,9 +230,18 @@ def update_api_key(org_id, key_id):
api_key.name = data["name"]
if "description" in data:
api_key.description = data["description"]
api_key.save()
AuditService.log_action(
action=AuditAction.ORG_API_KEY_UPDATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="api_key",
resource_id=str(api_key.id),
description=f"API key '{api_key.name}' updated",
)
return api_response(
data={"api_key": api_key.to_dict()},
message="API key updated successfully",
@@ -293,7 +313,16 @@ def delete_api_key(org_id, key_id):
# Soft delete the API key
api_key.delete(soft=True)
AuditService.log_action(
action=AuditAction.ORG_API_KEY_DELETED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="api_key",
resource_id=str(api_key.id),
description=f"API key '{api_key.name}' deleted",
)
return api_response(
message="API key deleted successfully",
)
+21
View File
@@ -6,6 +6,8 @@ from gatehouse_app.utils.response import api_response
from gatehouse_app.utils.decorators import login_required, require_admin
from gatehouse_app.extensions import db
from gatehouse_app.api.v1.organizations._helpers import _get_system_ca_dict
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.services.audit_service import AuditService
@api_v1_bp.route("/organizations/<org_id>/cas", methods=["GET"])
@@ -66,6 +68,16 @@ def update_org_ca(org_id, ca_id):
ca.max_cert_validity_hours = data["max_cert_validity_hours"]
db.session.commit()
AuditService.log_action(
action=AuditAction.CA_UPDATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="CA",
resource_id=ca_id,
description=f"CA '{ca.name}' updated",
)
return api_response(data={"ca": ca.to_dict()}, message="CA updated successfully")
except ValidationError as e:
return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages)
@@ -150,6 +162,15 @@ def create_org_ca(org_id):
return api_response(success=False, message="A CA with that name already exists in this organization (it may have been recently deleted — choose a different name).", status=400, error_type="DUPLICATE_NAME")
raise
AuditService.log_action(
action=AuditAction.CA_CREATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="CA",
resource_id=str(ca.id),
description=f"CA '{ca.name}' created",
)
return api_response(data={"ca": ca.to_dict()}, message="CA created successfully", status=201)
except MaValidationError as e:
return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages)
@@ -5,6 +5,8 @@ from gatehouse_app.api.v1 import api_v1_bp
from gatehouse_app.utils.response import api_response
from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required
from gatehouse_app.extensions import db, bcrypt
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.services.audit_service import AuditService
@api_v1_bp.route("/organizations/<org_id>/clients", methods=["GET"])
@@ -28,6 +30,7 @@ def list_org_clients(org_id):
"redirect_uris": c.redirect_uris,
"scopes": c.scopes,
"grant_types": c.grant_types,
"allowed_cors_origins": c.allowed_cors_origins,
"is_active": c.is_active,
"created_at": c.created_at.isoformat() + "Z",
}
@@ -78,6 +81,15 @@ def create_org_client(org_id):
db.session.add(client)
db.session.commit()
AuditService.log_action(
action=AuditAction.ORG_CLIENT_CREATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="oidc_client",
resource_id=str(client.id),
description=f"OIDC client '{client.name}' created",
)
return api_response(
data={
"client": {
@@ -125,6 +137,15 @@ def update_org_client(org_id, client_id):
db.session.commit()
AuditService.log_action(
action=AuditAction.ORG_CLIENT_UPDATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="oidc_client",
resource_id=str(client.id),
description=f"OIDC client '{client.name}' updated",
)
return api_response(
data={
"client": {
@@ -154,4 +175,14 @@ def delete_org_client(org_id, client_id):
client.is_active = False
db.session.commit()
AuditService.log_action(
action=AuditAction.ORG_CLIENT_DEACTIVATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="oidc_client",
resource_id=str(client.id),
description=f"OIDC client '{client.name}' deactivated",
)
return api_response(data={}, message="Client deactivated successfully")
@@ -7,6 +7,8 @@ from gatehouse_app.utils.response import api_response
from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required
from gatehouse_app.schemas.organization_schema import OrganizationCreateSchema, OrganizationUpdateSchema
from gatehouse_app.services.organization_service import OrganizationService
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.services.audit_service import AuditService
@api_v1_bp.route("/organizations", methods=["POST"])
@@ -32,6 +34,14 @@ def create_organization():
description=data.get("description"),
logo_url=data.get("logo_url"),
)
AuditService.log_action(
action=AuditAction.ORG_CREATE,
user_id=g.current_user.id,
organization_id=org.id,
resource_type="organization",
resource_id=str(org.id),
description=f"Organization '{org.name}' created",
)
return api_response(data={"organization": org.to_dict()}, message="Organization created successfully", status=201)
except ValidationError as e:
return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages)
@@ -60,6 +70,14 @@ def update_organization(org_id):
data = schema.load(request.json)
org = OrganizationService.get_organization_by_id(org_id)
org = OrganizationService.update_organization(org=org, user_id=g.current_user.id, **data)
AuditService.log_action(
action=AuditAction.ORG_UPDATE,
user_id=g.current_user.id,
organization_id=org.id,
resource_type="organization",
resource_id=str(org.id),
description=f"Organization '{org.name}' updated",
)
return api_response(data={"organization": org.to_dict()}, message="Organization updated successfully")
except ValidationError as e:
return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages)
@@ -92,4 +110,12 @@ def delete_organization(org_id):
)
OrganizationService.force_delete_organization(org=org, user_id=caller.id)
AuditService.log_action(
action=AuditAction.ORG_DELETE,
user_id=caller.id,
organization_id=org.id,
resource_type="organization",
resource_id=str(org.id),
description=f"Organization '{org.name}' deleted",
)
return api_response(message="Organization deleted successfully")
@@ -136,6 +136,15 @@ def cancel_org_invite(org_id, invite_id):
return api_response(success=False, message="Invite not found", status=404)
invite.delete(soft=True)
AuditService.log_action(
action=AuditAction.ORG_INVITE_CANCELLED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="org_invite",
resource_id=invite.id,
metadata={"invited_email": invite.email, "role": invite.role},
description=f"Invitation for {invite.email} cancelled",
)
return api_response(data={}, message="Invite cancelled")
@@ -243,6 +252,30 @@ def accept_invite(token):
invite.accept()
if invite.invited_by and invite.invited_by.email:
from gatehouse_app.services.email_templates import build_invite_accepted_html
from gatehouse_app.services.notification_service import NotificationService
member_display = user.full_name or user.email
inviter_display = invite.invited_by.full_name or invite.invited_by.email
org_link = f"{current_app.config.get('APP_URL', '')}/organizations/{invite.organization_id}"
html_body = build_invite_accepted_html(
inviter_name=inviter_display,
member_name=member_display,
member_email=user.email,
org_name=invite.organization.name,
role=invite.role,
org_link=org_link,
)
NotificationService._send_email_async(
to_address=invite.invited_by.email,
subject=f"{member_display} accepted your invitation to {invite.organization.name}",
body=f"{member_display} has accepted your invitation to join {invite.organization.name} on Secuird.",
html_body=html_body,
)
has_webauthn = user.has_webauthn_enabled()
has_totp = user.has_totp_enabled()
+35 -1
View File
@@ -7,7 +7,8 @@ from gatehouse_app.utils.decorators import login_required, require_admin, full_a
from gatehouse_app.schemas.organization_schema import InviteMemberSchema, UpdateMemberRoleSchema
from gatehouse_app.services.organization_service import OrganizationService
from gatehouse_app.services.user_service import UserService
from gatehouse_app.utils.constants import OrganizationRole
from gatehouse_app.utils.constants import AuditAction, OrganizationRole
from gatehouse_app.services.audit_service import AuditService
@api_v1_bp.route("/organizations/<org_id>/members", methods=["GET"])
@@ -43,6 +44,14 @@ def add_organization_member(org_id):
role = OrganizationRole(data["role"])
member = OrganizationService.add_member(org=org, user_id=user.id, role=role, inviter_id=g.current_user.id)
AuditService.log_action(
action=AuditAction.ORG_MEMBER_ADD,
user_id=g.current_user.id,
organization_id=org.id,
resource_type="user",
resource_id=str(user.id),
description=f"Added user {user.email} to organization with role {role.value}",
)
member_dict = member.to_dict()
member_dict["user"] = user.to_dict()
return api_response(data={"member": member_dict}, message="Member added successfully", status=201)
@@ -60,6 +69,14 @@ def remove_organization_member(org_id, user_id):
OrganizationService.remove_member(org=org, user_id=user_id, remover_id=g.current_user.id)
except ValueError as e:
return api_response(success=False, message=str(e), status=403, error_type="OWNER_PROTECTION")
AuditService.log_action(
action=AuditAction.ORG_MEMBER_REMOVE,
user_id=g.current_user.id,
organization_id=org.id,
resource_type="user",
resource_id=str(user_id),
description=f"Removed user {user_id} from organization",
)
return api_response(message="Member removed successfully")
@@ -74,6 +91,14 @@ def update_member_role(org_id, user_id):
org = OrganizationService.get_organization_by_id(org_id)
new_role = OrganizationRole(data["role"])
member = OrganizationService.update_member_role(org=org, user_id=user_id, new_role=new_role, updater_id=g.current_user.id)
AuditService.log_action(
action=AuditAction.ORG_MEMBER_ROLE_CHANGE,
user_id=g.current_user.id,
organization_id=org.id,
resource_type="user",
resource_id=str(user_id),
description=f"Changed role for user {user_id} to {new_role.value}",
)
member_dict = member.to_dict()
member_dict["user"] = member.user.to_dict()
return api_response(data={"member": member_dict}, message="Member role updated successfully")
@@ -180,4 +205,13 @@ def send_mfa_reminder(org_id, user_id):
html_body=html_body,
)
AuditService.log_action(
action=AuditAction.ORG_MFA_REMINDER_SENT,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="user",
resource_id=str(user_id),
description=f"MFA reminder sent to {user.email}",
)
return api_response(data={}, message="Reminder sent successfully")
+22 -1
View File
@@ -3,8 +3,9 @@ from flask import g, request
from gatehouse_app.api.v1 import api_v1_bp
from gatehouse_app.utils.response import api_response
from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required
from gatehouse_app.utils.constants import OrganizationRole
from gatehouse_app.utils.constants import AuditAction, OrganizationRole
from gatehouse_app.extensions import db
from gatehouse_app.services.audit_service import AuditService
@api_v1_bp.route("/organizations/<org_id>/roles", methods=["GET"])
@@ -59,6 +60,16 @@ def assign_role_to_member(org_id, role_name):
membership.role = new_role
db.session.commit()
AuditService.log_action(
action=AuditAction.ORG_MEMBER_ROLE_CHANGE,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="user",
resource_id=str(target_user_id),
description=f"Role changed to {new_role.value} for user {target_user_id}",
)
return api_response(data={"user_id": target_user_id, "role": new_role.value}, message=f"Role updated to {new_role.value}")
@@ -82,4 +93,14 @@ def remove_role_from_member(org_id, role_name, user_id):
org = OrganizationService.get_organization_by_id(org_id)
OrganizationService.remove_member(org=org, user_id=user_id, remover_id=g.current_user.id)
AuditService.log_action(
action=AuditAction.ORG_MEMBER_REMOVE,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="user",
resource_id=str(user_id),
description=f"Member {user_id} removed from organization via role removal",
)
return api_response(data={"user_id": user_id}, message="Member removed from organization")
+65
View File
@@ -10,6 +10,8 @@ from gatehouse_app.services.organization_service import OrganizationService
from gatehouse_app.services.user_service import UserService
from gatehouse_app.exceptions import OrganizationNotFoundError
from gatehouse_app.extensions import db
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.services.audit_service import AuditService
class PrincipalCreateSchema(Schema):
@@ -127,6 +129,15 @@ def create_principal(org_id):
db.session.add(principal)
db.session.commit()
AuditService.log_action(
action=AuditAction.PRINCIPAL_CREATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="principal",
resource_id=str(principal.id),
description=f"Principal '{principal.name}' created",
)
return api_response(
data={"principal": principal.to_dict()},
message="Principal created successfully",
@@ -255,6 +266,15 @@ def update_principal(org_id, principal_id):
db.session.commit()
AuditService.log_action(
action=AuditAction.PRINCIPAL_UPDATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="principal",
resource_id=str(principal.id),
description=f"Principal '{principal.name}' updated",
)
return api_response(
data={"principal": principal.to_dict()},
message="Principal updated successfully",
@@ -308,6 +328,15 @@ def delete_principal(org_id, principal_id):
principal.deleted_at = db.func.now()
db.session.commit()
AuditService.log_action(
action=AuditAction.PRINCIPAL_DELETED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="principal",
resource_id=str(principal.id),
description=f"Principal '{principal.name}' deleted",
)
return api_response(
message="Principal deleted successfully",
)
@@ -476,6 +505,15 @@ def add_principal_member(org_id, principal_id):
db.session.commit()
AuditService.log_action(
action=AuditAction.PRINCIPAL_MEMBER_ADDED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="user",
resource_id=str(user.id),
description=f"Added user {user.email} to principal '{principal.name}'",
)
member_dict = membership.to_dict()
member_dict["user"] = user.to_dict()
@@ -548,6 +586,15 @@ def remove_principal_member(org_id, principal_id, user_id):
membership.deleted_at = db.func.now()
db.session.commit()
AuditService.log_action(
action=AuditAction.PRINCIPAL_MEMBER_REMOVED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="user",
resource_id=str(user_id),
description=f"Removed user from principal '{principal.name}'",
)
return api_response(
message="Member removed successfully",
)
@@ -697,6 +744,15 @@ def link_principal_to_department(org_id, principal_id, dept_id):
error_type="SERVER_ERROR",
)
AuditService.log_action(
action=AuditAction.PRINCIPAL_DEPARTMENT_LINKED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="principal",
resource_id=str(principal_id),
description=f"Principal '{principal.name}' linked to department '{dept.name}'",
)
return api_response(
data={
"principal": principal.to_dict(),
@@ -774,6 +830,15 @@ def unlink_principal_from_department(org_id, principal_id, dept_id):
link.deleted_at = db.func.now()
db.session.commit()
AuditService.log_action(
action=AuditAction.PRINCIPAL_DEPARTMENT_UNLINKED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="principal",
resource_id=str(principal_id),
description=f"Principal '{principal.name}' unlinked from department '{dept.name}'",
)
return api_response(
message="Principal unlinked from department successfully",
)
+2
View File
@@ -9,6 +9,7 @@ from gatehouse_app.utils.response import api_response
from gatehouse_app.services.superadmin_auth_service import SuperadminAuthService
from gatehouse_app.decorators.superadmin import superadmin_required, superadmin_audit_log
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
from gatehouse_app.utils.constants import AuditAction
logger = logging.getLogger(__name__)
@@ -105,6 +106,7 @@ def login():
@superadmin_bp.route("/auth/logout", methods=["POST"])
@superadmin_required
@superadmin_audit_log(action=AuditAction.USER_LOGOUT, resource_type="session")
def logout():
"""Superadmin logout endpoint.
+10 -6
View File
@@ -15,7 +15,7 @@ def superadmin_required(f):
"""Decorator to require superadmin Bearer token authentication.
Extracts token from Authorization: Bearer {token} header,
validates the session against SuperadminSession table,
validates the session against the unified sessions table,
and sets g.current_superadmin and g.superadmin_session.
Returns 401 if no valid session, 403 if not a superadmin.
@@ -46,10 +46,14 @@ def superadmin_required(f):
token = parts[1]
# Import here to avoid circular imports
from gatehouse_app.models.superadmin import SuperadminSession, Superadmin
from gatehouse_app.models.user.session import Session
from gatehouse_app.models.superadmin import Superadmin
from gatehouse_app.utils.constants import SessionType
# Get active session by token
session = SuperadminSession.query.filter_by(token=token).first()
# Get active session by token, scoped to superadmin
session = Session.query.filter_by(
token=token, owner_type=SessionType.SUPERADMIN
).first()
if not session:
return api_response(
@@ -68,8 +72,8 @@ def superadmin_required(f):
error_type="SESSION_INACTIVE"
)
# Get the superadmin
superadmin = session.superadmin
# Get the superadmin by owner_id
superadmin = Superadmin.query.get(session.owner_id)
if not superadmin:
return api_response(
success=False,
+87 -2
View File
@@ -1,6 +1,12 @@
"""CORS middleware configuration."""
import base64
import json
from urllib.parse import parse_qs
from flask import request, make_response
from gatehouse_app.models import OIDCClient
ALLOWED_METHODS = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
ALLOWED_HEADERS = (
"Content-Type, Authorization, X-Requested-With, X-Request-ID, "
@@ -40,6 +46,85 @@ def _cors_origin_header(cors_origins, request_origin):
return None
def _get_oidc_client_id_from_request():
"""Extract client_id from OIDC endpoint requests."""
path = request.path
# POST to /oidc/token, /oidc/revoke, /oidc/introspect
if request.method == "POST" and any(
path.endswith(ep) for ep in ("/oidc/token", "/oidc/revoke", "/oidc/introspect")
):
# Try Basic Auth header first
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Basic "):
try:
decoded = base64.b64decode(auth_header[6:]).decode("utf-8")
client_id, _, _ = decoded.partition(":")
if client_id:
return client_id
except Exception:
pass
# Try form body
if request.form:
client_id = request.form.get("client_id")
if client_id:
return client_id
# Try JSON body
if request.is_json:
try:
client_id = request.json.get("client_id")
if client_id:
return client_id
except Exception:
pass
return None
# GET/POST to /oidc/userinfo
if path.endswith("/oidc/userinfo"):
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
token = auth_header[7:]
try:
payload_b64 = token.split(".")[1]
padding = 4 - len(payload_b64) % 4
if padding != 4:
payload_b64 += "=" * padding
payload = json.loads(base64.urlsafe_b64decode(payload_b64))
return payload.get("client_id")
except Exception:
return None
return None
def _get_effective_cors_origins(app, request):
"""Get effective CORS origins, checking per-client config for OIDC endpoints."""
global_origins = app.config.get("CORS_ORIGINS", [])
if "/oidc/" not in request.path:
return global_origins
try:
client_id = _get_oidc_client_id_from_request()
if not client_id:
return global_origins
client = OIDCClient.query.filter_by(client_id=client_id).first()
if not client:
return global_origins
effective = client.get_effective_origins()
if effective is not None:
return effective
except Exception:
pass
return global_origins
def setup_cors(app):
"""
Configure CORS for the application.
@@ -54,7 +139,7 @@ def setup_cors(app):
"""Handle CORS preflight OPTIONS requests."""
if request.method == "OPTIONS":
origin = request.headers.get("Origin")
cors_origins = app.config.get("CORS_ORIGINS", [])
cors_origins = _get_effective_cors_origins(app, request)
if not _is_origin_allowed(origin, cors_origins):
return None
@@ -73,7 +158,7 @@ def setup_cors(app):
def after_request_cors(response):
"""Add CORS headers to non-preflight responses."""
origin = request.headers.get("Origin")
cors_origins = app.config.get("CORS_ORIGINS", [])
cors_origins = _get_effective_cors_origins(app, request)
allow_origin = _cors_origin_header(cors_origins, origin)
if allow_origin:
-2
View File
@@ -118,7 +118,6 @@ from gatehouse_app.models.zerotier import ( # noqa: F401
from gatehouse_app.models.superadmin import ( # noqa: F401
Superadmin,
SuperadminSession,
SuperadminSessionStatus,
)
from gatehouse_app.models.superadmin_audit_log import SuperadminAuditLog # noqa: F401
from gatehouse_app.models.security.user_security_policy import ( # noqa: F401
@@ -186,6 +185,5 @@ __all__ = [
# Superadmin
"Superadmin",
"SuperadminSession",
"SuperadminSessionStatus",
"SuperadminAuditLog",
]
+34
View File
@@ -1,4 +1,6 @@
"""OIDC Client model."""
from urllib.parse import urlparse
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import OIDCGrantType, OIDCResponseType
@@ -21,6 +23,7 @@ class OIDCClient(BaseModel):
grant_types = db.Column(db.JSON, nullable=False) # Allowed grant types
response_types = db.Column(db.JSON, nullable=False) # Allowed response types
scopes = db.Column(db.JSON, nullable=False) # Allowed scopes
allowed_cors_origins = db.Column(db.JSON, nullable=True, default=None) # Per-client CORS origins
# Client metadata
logo_uri = db.Column(db.String(512), nullable=True)
@@ -81,6 +84,37 @@ class OIDCClient(BaseModel):
"""Check if a redirect URI is allowed for this client."""
return redirect_uri in self.redirect_uris
def get_effective_origins(self) -> list | None:
"""Get effective CORS origins for this client.
Returns None to signal "use global config", a derived list from
redirect_uris when "+" is present, or the configured list as-is.
"""
if self.allowed_cors_origins is None:
return None
if "+" in self.allowed_cors_origins:
origins = set()
for uri in self.redirect_uris:
parsed = urlparse(uri)
if parsed.scheme and parsed.hostname:
port = f":{parsed.port}" if parsed.port else ""
origins.add(f"{parsed.scheme}://{parsed.hostname}{port}")
return sorted(origins)
return list(self.allowed_cors_origins)
def is_origin_allowed(self, origin: str) -> bool | None:
"""Check if a browser origin is allowed for CORS.
Returns True/False when a per-client list is configured,
or None to defer to the global CORS policy.
"""
effective = self.get_effective_origins()
if effective is None:
return None
if "*" in effective:
return True
return origin in effective
def has_scope(self, scope: str) -> bool:
"""Check if client is allowed to request a specific scope."""
return scope in self.scopes
+2 -2
View File
@@ -1,5 +1,5 @@
"""Superadmin models."""
from gatehouse_app.models.superadmin.superadmin import Superadmin
from gatehouse_app.models.superadmin.superadmin_session import SuperadminSession, SuperadminSessionStatus
from gatehouse_app.models.user.session import Session as SuperadminSession
__all__ = ["Superadmin", "SuperadminSession", "SuperadminSessionStatus"]
__all__ = ["Superadmin", "SuperadminSession"]
@@ -23,11 +23,15 @@ class Superadmin(BaseModel):
is_active = db.Column(db.Boolean, default=True, nullable=False)
last_login_at = db.Column(db.DateTime, nullable=True)
# Relationship to sessions
# Relationship to sessions (unified model, scoped to superadmin owner_type)
sessions = db.relationship(
"SuperadminSession",
back_populates="superadmin",
cascade="all, delete-orphan"
"Session",
primaryjoin=(
"and_(Superadmin.id == foreign(Session.owner_id), "
"Session.owner_type == 'superadmin')"
),
cascade="all, delete-orphan",
lazy="dynamic",
)
# Relationship to audit logs
@@ -1,80 +0,0 @@
"""Superadmin session model."""
import logging
from datetime import datetime, timezone, timedelta
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
logger = logging.getLogger(__name__)
class SuperadminSessionStatus:
"""Session status constants."""
ACTIVE = "active"
REVOKED = "revoked"
EXPIRED = "expired"
class SuperadminSession(BaseModel):
"""Session model for superadmin authentication."""
__tablename__ = "superadmin_sessions"
superadmin_id = db.Column(
db.String(36),
db.ForeignKey("superadmins.id"),
nullable=False,
index=True
)
token = db.Column(db.String(255), unique=True, nullable=False, index=True)
expires_at = db.Column(db.DateTime, nullable=False)
last_activity_at = db.Column(
db.DateTime,
nullable=False,
default=lambda: datetime.now(timezone.utc)
)
ip_address = db.Column(db.String(45), nullable=True)
user_agent = db.Column(db.Text, nullable=True)
revoked_at = db.Column(db.DateTime, nullable=True)
revoked_reason = db.Column(db.String(255), nullable=True)
# Relationship
superadmin = db.relationship("Superadmin", back_populates="sessions")
def __repr__(self):
return f"<SuperadminSession superadmin_id={self.superadmin_id}>"
def is_active(self):
"""Check if session is currently active."""
now = datetime.now(timezone.utc)
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return (
self.deleted_at is None
and self.revoked_at is None
and expires_at > now
)
def is_expired(self):
"""Check if session has expired."""
now = datetime.now(timezone.utc)
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return now > expires_at
def revoke(self, reason: str = None):
"""Revoke the session."""
self.revoked_at = datetime.now(timezone.utc)
if reason:
self.revoked_reason = reason
from gatehouse_app import db
db.session.commit()
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
exclude.append("token")
return super().to_dict(exclude=exclude)
+54 -12
View File
@@ -3,15 +3,24 @@ from datetime import datetime, timedelta, timezone
from flask import current_app
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import SessionStatus
from gatehouse_app.utils.constants import SessionStatus, SessionType
class Session(BaseModel):
"""Session model for tracking user sessions."""
"""Session model for tracking user and superadmin sessions."""
__tablename__ = "sessions"
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=False, index=True)
# Owner discriminator — determines which table the owner_id references
owner_type = db.Column(
db.String(20), nullable=False, default=SessionType.USER, index=True
)
owner_id = db.Column(db.String(36), nullable=False, index=True)
# Legacy column kept for backward compatibility during migration;
# new code should use owner_id / owner_type.
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True, index=True)
token = db.Column(db.String(255), unique=True, nullable=False, index=True)
status = db.Column(db.Enum(SessionStatus), default=SessionStatus.ACTIVE, nullable=False)
@@ -34,21 +43,37 @@ class Session(BaseModel):
# Relationships
user = db.relationship("User", back_populates="sessions")
# Composite index for owner-scoped queries
__table_args__ = (
db.Index("ix_sessions_owner_type_owner_id", "owner_type", "owner_id"),
)
# ---- Convenience properties ------------------------------------------------
@property
def is_user(self):
return self.owner_type == SessionType.USER
@property
def is_superadmin(self):
return self.owner_type == SessionType.SUPERADMIN
# ---- Core methods ----------------------------------------------------------
def __repr__(self):
"""String representation of Session."""
return f"<Session user_id={self.user_id} status={self.status}>"
return f"<Session owner_type={self.owner_type} owner_id={self.owner_id} status={self.status}>"
def is_active(self):
"""Check if session is currently active.
Sessions are evaluated against two independent timeouts:
User sessions are evaluated against two independent timeouts:
- Idle timeout: expires if no request has been made within
SESSION_IDLE_TIMEOUT seconds (default 15 min).
- Absolute timeout: expires if SESSION_ABSOLUTE_TIMEOUT seconds
have elapsed since the session was created (default 8 h),
regardless of activity.
have elapsed since the session was created (default 8 h).
A session must satisfy *both* constraints to remain active.
Superadmin sessions use absolute timeout only (no idle timeout).
A session must satisfy *all* applicable constraints to remain active.
"""
now = datetime.now(timezone.utc)
created_at = self.created_at
@@ -59,12 +84,21 @@ class Session(BaseModel):
if last_activity_at.tzinfo is None:
last_activity_at = last_activity_at.replace(tzinfo=timezone.utc)
idle_timeout = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
absolute_timeout = current_app.config.get("SESSION_ABSOLUTE_TIMEOUT", 28800)
idle_expires_at = last_activity_at + timedelta(seconds=idle_timeout)
absolute_expires_at = created_at + timedelta(seconds=absolute_timeout)
if self.is_superadmin:
# Superadmin: absolute timeout only
return (
self.status == SessionStatus.ACTIVE
and now < absolute_expires_at
and self.deleted_at is None
)
# User: idle + absolute timeout
idle_timeout = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
idle_expires_at = last_activity_at + timedelta(seconds=idle_timeout)
return (
self.status == SessionStatus.ACTIVE
and now < idle_expires_at
@@ -83,6 +117,8 @@ class Session(BaseModel):
capped so that the session never exceeds the absolute lifetime
(``created_at + absolute timeout``).
Superadmin sessions only update last_activity_at (no sliding window).
Args:
duration_seconds: Override for the idle timeout. When *None*
(the common case), the value is read from
@@ -90,6 +126,12 @@ class Session(BaseModel):
"""
now = datetime.now(timezone.utc)
if self.is_superadmin:
# Superadmin: just bump last_activity_at, no sliding window
self.last_activity_at = now
db.session.commit()
return
if duration_seconds is None:
duration_seconds = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
+3 -1
View File
@@ -8,7 +8,7 @@ from gatehouse_app.extensions import db, bcrypt
from gatehouse_app.models.user.user import User
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.models.user.session import Session
from gatehouse_app.utils.constants import AuthMethodType, SessionStatus, UserStatus, AuditAction
from gatehouse_app.utils.constants import AuthMethodType, SessionStatus, SessionType, UserStatus, AuditAction
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError, AccountSuspendedError, AccountInactiveError
from gatehouse_app.exceptions.validation_exceptions import EmailAlreadyExistsError
from gatehouse_app.services.audit_service import AuditService
@@ -165,6 +165,8 @@ class AuthService:
# Create session
session = Session(
owner_type=SessionType.USER,
owner_id=user.id,
user_id=user.id,
token=token,
status=SessionStatus.ACTIVE,
+48
View File
@@ -562,3 +562,51 @@ def build_contact_enquiry_html(
<p style="margin: 0; color: {TEXT_COLOR}; font-size: 14px; line-height: 1.6; white-space: pre-wrap;">{message_display}</p>
'''
return get_base_html(content, f"Secuird Website: {type_label}", f"New {type_label} from {submitter_email}")
def build_invite_accepted_html(
inviter_name: str,
member_name: str,
member_email: str,
org_name: str,
role: str,
org_link: Optional[str] = None,
) -> str:
"""Build invite accepted notification email.
Args:
inviter_name: Name of the person who sent the invite
member_name: Name of the person who accepted
member_email: Email of the person who accepted
org_name: Organization name
role: Role assigned to the member
org_link: Optional link to view the organization
Returns:
HTML email string
"""
content = f'''
<h2 style="margin: 0 0 20px 0; color: {TEXT_COLOR}; font-size: 20px; font-weight: 600;">Invitation Accepted</h2>
<p style="margin: 0 0 20px 0; color: {TEXT_COLOR}; font-size: 15px; line-height: 1.6;">
<strong>{member_name}</strong> has accepted your invitation to join <strong>{org_name}</strong> on Secuird.
</p>
{get_alert_box(f"<strong>{member_name}</strong> ({member_email}) has joined <strong>{org_name}</strong>", "success", "")}
<table role="presentation" width="100%" cellspacing="0" cellpadding="0" style="margin: 20px 0; background-color: {BACKGROUND_COLOR}; border-radius: 8px;">
<tr>
<td style="padding: 20px;">
<h3 style="margin: 0 0 16px 0; color: {TEXT_COLOR}; font-size: 14px; font-weight: 600;">Membership Details</h3>
<table role="presentation" width="100%" cellspacing="0" cellpadding="0">
{get_detail_row("Member", member_name)}
{get_detail_row("Email", member_email)}
{get_detail_row("Organization", org_name)}
{get_detail_row("Role", role)}
</table>
</td>
</tr>
</table>
'''
if org_link:
content += get_action_button(org_link, "View Organization", PRIMARY_COLOR)
return get_base_html(content, f"Invitation accepted: {org_name}", f"{member_name} has joined {org_name}")
+64 -11
View File
@@ -1,7 +1,7 @@
"""Session service."""
from datetime import datetime, timezone
from gatehouse_app.models.user.session import Session
from gatehouse_app.utils.constants import SessionStatus
from gatehouse_app.utils.constants import SessionStatus, SessionType
class SessionService:
@@ -28,18 +28,22 @@ class SessionService:
).first()
@staticmethod
def get_user_sessions(user_id, active_only=True):
"""
Get all sessions for a user.
def get_owner_sessions(owner_type, owner_id, active_only=True):
"""Get all sessions for an owner (user or superadmin).
Args:
user_id: User ID
owner_type: SessionType.USER or SessionType.SUPERADMIN
owner_id: Owner ID
active_only: If True, only return active sessions
Returns:
List of Session instances
"""
query = Session.query.filter_by(user_id=user_id, deleted_at=None)
query = Session.query.filter_by(
owner_type=owner_type,
owner_id=owner_id,
deleted_at=None,
)
if active_only:
query = query.filter_by(status=SessionStatus.ACTIVE).filter(
@@ -49,18 +53,67 @@ class SessionService:
return query.all()
@staticmethod
def revoke_user_sessions(user_id, reason="User logged out from all devices"):
def get_user_sessions(user_id, active_only=True):
"""Get all sessions for a user.
Args:
user_id: User ID
active_only: If True, only return active sessions
Returns:
List of Session instances
"""
Revoke all active sessions for a user.
return SessionService.get_owner_sessions(
SessionType.USER, user_id, active_only=active_only
)
@staticmethod
def get_superadmin_sessions(superadmin_id, active_only=True):
"""Get all sessions for a superadmin.
Args:
superadmin_id: Superadmin ID
active_only: If True, only return active sessions
Returns:
List of Session instances
"""
return SessionService.get_owner_sessions(
SessionType.SUPERADMIN, superadmin_id, active_only=active_only
)
@staticmethod
def revoke_owner_sessions(owner_type, owner_id, reason="Logged out from all devices"):
"""Revoke all active sessions for an owner.
Args:
owner_type: SessionType.USER or SessionType.SUPERADMIN
owner_id: Owner ID
reason: Reason for revocation
"""
sessions = SessionService.get_owner_sessions(owner_type, owner_id, active_only=True)
for session in sessions:
session.revoke(reason=reason)
@staticmethod
def revoke_user_sessions(user_id, reason="User logged out from all devices"):
"""Revoke all active sessions for a user.
Args:
user_id: User ID
reason: Reason for revocation
"""
sessions = SessionService.get_user_sessions(user_id, active_only=True)
SessionService.revoke_owner_sessions(SessionType.USER, user_id, reason=reason)
for session in sessions:
session.revoke(reason=reason)
@staticmethod
def revoke_superadmin_sessions(superadmin_id, reason="Superadmin logged out"):
"""Revoke all active sessions for a superadmin.
Args:
superadmin_id: Superadmin ID
reason: Reason for revocation
"""
SessionService.revoke_owner_sessions(SessionType.SUPERADMIN, superadmin_id, reason=reason)
@staticmethod
def cleanup_expired_sessions():
@@ -6,7 +6,9 @@ from typing import Optional
from flask import request, current_app
from gatehouse_app.extensions import db, bcrypt
from gatehouse_app.models.superadmin import Superadmin, SuperadminSession
from gatehouse_app.models.superadmin import Superadmin
from gatehouse_app.models.user.session import Session
from gatehouse_app.utils.constants import SessionType
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
@@ -70,15 +72,17 @@ class SuperadminAuthService:
duration_seconds: Session duration in seconds (default 8 hours)
Returns:
SuperadminSession instance
Session instance
"""
# Generate secure token
token = secrets.token_urlsafe(32)
# Create session
session = SuperadminSession(
superadmin_id=superadmin_id,
# Create session using unified model
session = Session(
owner_type=SessionType.SUPERADMIN,
owner_id=superadmin_id,
token=token,
status="active",
expires_at=datetime.now(timezone.utc) + timedelta(seconds=duration_seconds),
last_activity_at=datetime.now(timezone.utc),
ip_address=request.remote_addr,
@@ -97,7 +101,9 @@ class SuperadminAuthService:
session_id: Session ID to revoke
reason: Optional revocation reason
"""
session = SuperadminSession.query.get(session_id)
session = Session.query.filter_by(
id=session_id, owner_type=SessionType.SUPERADMIN
).first()
if session:
session.revoke(reason=reason)
logger.info(f"[SuperadminAuth] Session {session_id} revoked: {reason or 'No reason'}")
@@ -111,9 +117,11 @@ class SuperadminAuthService:
except_token: Optional token to keep (current session)
reason: Optional revocation reason
"""
query = SuperadminSession.query.filter_by(superadmin_id=superadmin_id)
query = Session.query.filter_by(
owner_type=SessionType.SUPERADMIN, owner_id=superadmin_id
)
if except_token:
query = query.filter(SuperadminSession.token != except_token)
query = query.filter(Session.token != except_token)
sessions = query.all()
for session in sessions:
+28
View File
@@ -52,6 +52,13 @@ class SessionStatus(str, Enum):
REVOKED = "revoked"
class SessionType(str, Enum):
"""Session owner type discriminator."""
USER = "user"
SUPERADMIN = "superadmin"
class AuditAction(str, Enum):
"""Audit log action types."""
@@ -154,6 +161,27 @@ class AuditAction(str, Enum):
DEPARTMENT_DELETED = "department.deleted"
DEPARTMENT_MEMBER_ADDED = "department.member.added"
DEPARTMENT_MEMBER_REMOVED = "department.member.removed"
DEPARTMENT_CERT_POLICY_UPDATED = "department.cert_policy.updated"
# Organization invite actions
ORG_INVITE_CANCELLED = "org.invite.cancelled"
# MFA reminder
ORG_MFA_REMINDER_SENT = "org.mfa_reminder.sent"
# API key actions
ORG_API_KEY_CREATED = "org.api_key.created"
ORG_API_KEY_UPDATED = "org.api_key.updated"
ORG_API_KEY_DELETED = "org.api_key.deleted"
# OIDC client actions
ORG_CLIENT_CREATED = "org.client.created"
ORG_CLIENT_UPDATED = "org.client.updated"
ORG_CLIENT_DEACTIVATED = "org.client.deactivated"
# Principal department link actions
PRINCIPAL_DEPARTMENT_LINKED = "principal.department.linked"
PRINCIPAL_DEPARTMENT_UNLINKED = "principal.department.unlinked"
class OIDCGrantType(str, Enum):
@@ -0,0 +1,24 @@
"""Add allowed_cors_origins to oidc_clients.
Revision ID: b7e3f1a92c4d
Revises: a1b2c3d4e5f6
Create Date: 2026-04-27 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'b7e3f1a92c4d'
down_revision = 'a1b2c3d4e5f6'
branch_labels = None
depends_on = None
def upgrade():
op.add_column('oidc_clients', sa.Column('allowed_cors_origins', sa.JSON(), nullable=True))
def downgrade():
op.drop_column('oidc_clients', 'allowed_cors_origins')
@@ -0,0 +1,122 @@
"""Consolidate user and superadmin sessions into unified sessions table.
Revision ID: c8d2e4f6a1b3
Revises: b7e3f1a92c4d
Create Date: 2026-04-28 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'c8d2e4f6a1b3'
down_revision = 'b7e3f1a92c4d'
branch_labels = None
depends_on = None
def upgrade():
# 1. Add new columns (nullable initially for data migration)
op.add_column('sessions', sa.Column('owner_type', sa.String(20), nullable=True))
op.add_column('sessions', sa.Column('owner_id', sa.String(36), nullable=True))
# 2. Backfill existing user sessions: owner_type = 'user', owner_id = user_id
op.execute("""
UPDATE sessions
SET owner_type = 'user',
owner_id = user_id
WHERE owner_type IS NULL
""")
# 3. Migrate superadmin sessions into the sessions table
op.execute("""
INSERT INTO sessions (
id, owner_type, owner_id, token, status,
ip_address, user_agent, device_info,
expires_at, last_activity_at, revoked_at, revoked_reason,
is_compliance_only, created_at, updated_at, deleted_at
)
SELECT
id, 'superadmin', superadmin_id, token, 'active',
ip_address, user_agent, NULL,
expires_at, last_activity_at, revoked_at, revoked_reason,
FALSE, created_at, updated_at, deleted_at
FROM superadmin_sessions
""")
# 4. Make owner_type and owner_id NOT NULL
op.alter_column('sessions', 'owner_type', nullable=False)
op.alter_column('sessions', 'owner_id', nullable=False)
# 5. Make user_id nullable (no longer the sole owner reference)
op.alter_column('sessions', 'user_id', nullable=True)
# 6. Create indexes for efficient owner-scoped queries
op.create_index(
'ix_sessions_owner_type_owner_id',
'sessions',
['owner_type', 'owner_id']
)
op.create_index(
'ix_sessions_owner_type',
'sessions',
['owner_type']
)
op.create_index(
'ix_sessions_owner_id',
'sessions',
['owner_id']
)
# 7. Drop the now-redundant superadmin_sessions table
op.drop_table('superadmin_sessions')
def downgrade():
# 1. Recreate superadmin_sessions table
op.create_table(
'superadmin_sessions',
sa.Column('id', sa.String(36), primary_key=True),
sa.Column('superadmin_id', sa.String(36), sa.ForeignKey('superadmins.id'), nullable=False, index=True),
sa.Column('token', sa.String(255), unique=True, nullable=False, index=True),
sa.Column('expires_at', sa.DateTime, nullable=False),
sa.Column('last_activity_at', sa.DateTime, nullable=False),
sa.Column('ip_address', sa.String(45), nullable=True),
sa.Column('user_agent', sa.Text, nullable=True),
sa.Column('revoked_at', sa.DateTime, nullable=True),
sa.Column('revoked_reason', sa.String(255), nullable=True),
sa.Column('created_at', sa.DateTime, nullable=False),
sa.Column('updated_at', sa.DateTime, nullable=False),
sa.Column('deleted_at', sa.DateTime, nullable=True),
)
# 2. Move superadmin sessions back to superadmin_sessions
op.execute("""
INSERT INTO superadmin_sessions (
id, superadmin_id, token, expires_at, last_activity_at,
ip_address, user_agent, revoked_at, revoked_reason,
created_at, updated_at, deleted_at
)
SELECT
id, owner_id, token, expires_at, last_activity_at,
ip_address, user_agent, revoked_at, revoked_reason,
created_at, updated_at, deleted_at
FROM sessions
WHERE owner_type = 'superadmin'
""")
# 3. Remove superadmin sessions from sessions table
op.execute("DELETE FROM sessions WHERE owner_type = 'superadmin'")
# 4. Drop indexes
op.drop_index('ix_sessions_owner_id', table_name='sessions')
op.drop_index('ix_sessions_owner_type', table_name='sessions')
op.drop_index('ix_sessions_owner_type_owner_id', table_name='sessions')
# 5. Remove new columns
op.drop_column('sessions', 'owner_id')
op.drop_column('sessions', 'owner_type')
# 6. Make user_id NOT NULL again
op.alter_column('sessions', 'user_id', nullable=False)
+21 -1
View File
@@ -148,8 +148,28 @@ def test_html_email():
success = provider.send(message)
print(f"Result: {'✅ SUCCESS' if success else '❌ FAILED'}")
# Test 8: Invite Accepted
print("\n--- Test 8: Invite Accepted ---")
html_body = email_templates.build_invite_accepted_html(
inviter_name="Admin User",
member_name="New Member",
member_email="newmember@example.com",
org_name="Acme Corporation",
role="Member",
org_link="https://secuird.tech/organizations/org-123",
)
message = EmailMessage(
to="cory@hawkvelt.id.au",
subject="New Member accepted your invitation to Acme Corporation",
body="Plain text version: New Member has accepted your invitation.",
html_body=html_body,
from_address="Secuird <noreply@secuird.tech>",
)
success = provider.send(message)
print(f"Result: {'✅ SUCCESS' if success else '❌ FAILED'}")
print("\n" + "=" * 50)
print("All 7 email templates sent!")
print("All 8 email templates sent!")
print("=" * 50)
@@ -0,0 +1,186 @@
"""Superadmin session timeout integration tests.
Validates the absolute-only timeout policy for superadmin sessions.
Superadmin sessions do NOT have idle timeout only absolute timeout.
"""
import pytest
import uuid
from datetime import datetime, timedelta, timezone
from tests.integration.client.base import ApiError
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def assert_success(response: dict, message_contains: str = "") -> dict:
"""Assert that an api_response-wrapped payload succeeded."""
data = response.get("data", {})
assert response.get("success") is not False, (
f"Expected success but got error: {response.get('message')}"
)
if message_contains:
assert message_contains.lower() in response.get("message", "").lower(), (
f"Expected message to contain '{message_contains}' but got: {response.get('message')}"
)
return data
def _get_session_row(integration_app, token: str):
"""Look up the Session model row for a given bearer token."""
from gatehouse_app.models.user.session import Session
with integration_app.app_context():
return Session.query.filter_by(token=token).first()
def _touch_session(integration_app, session_id: str, **updates):
"""Directly update columns on a Session row.
Only use this to simulate the passage of time never to assert
internal state.
"""
from gatehouse_app.models.user.session import Session
with integration_app.app_context():
sess = Session.query.get(session_id)
for attr, value in updates.items():
setattr(sess, attr, value)
from gatehouse_app import db
db.session.commit()
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def superadmin_credentials(integration_app):
"""Create a superadmin and return login credentials."""
from gatehouse_app.services.superadmin_auth_service import SuperadminAuthService
email = f"admin_{uuid.uuid4().hex[:8]}@gatehouse.local"
password = "SuperAdmin123!"
with integration_app.app_context():
sa = SuperadminAuthService.create_superadmin(
email=email,
credential=password,
full_name="Test Superadmin",
)
return {"id": str(sa.id), "email": email, "password": password}
@pytest.fixture
def logged_in_superadmin(integration_client, superadmin_credentials, integration_app):
"""Log in as superadmin and return session metadata.
Returns dict with ``superadmin``, ``token``, ``session_id``, ``session_row``.
"""
creds = superadmin_credentials
resp = integration_client.post(
"/api/v1/superadmin/auth/login",
data={"email": creds["email"], "password": creds["password"]},
)
data = assert_success(resp)
token = data["token"]
session_row = _get_session_row(integration_app, token)
assert session_row is not None, "Session row should exist after superadmin login"
return {
"superadmin": creds,
"token": token,
"session_id": session_row.id,
"session_row": session_row,
}
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestSuperadminSessionTimeouts:
"""Absolute-only timeout behavior for superadmin sessions."""
def test_superadmin_session_valid_before_timeout(
self, integration_client, logged_in_superadmin,
):
"""SA-SESS-01 — Fresh superadmin session is accepted."""
integration_client.set_token(logged_in_superadmin["token"])
result = integration_client.get("/api/v1/superadmin/auth/me")
data = assert_success(result)
assert "superadmin" in data
def test_absolute_timeout_rejects_superadmin(
self, integration_client, logged_in_superadmin, integration_app,
):
"""SA-SESS-02 — Superadmin session rejected after absolute timeout.
Push ``created_at`` far into the past. The session must be
rejected even though ``last_activity_at`` is fresh.
"""
_touch_session(
integration_app,
logged_in_superadmin["session_id"],
created_at=datetime.now(timezone.utc) - timedelta(days=1),
last_activity_at=datetime.now(timezone.utc),
)
integration_client.set_token(logged_in_superadmin["token"])
with pytest.raises(ApiError) as exc_info:
integration_client.get("/api/v1/superadmin/auth/me")
assert exc_info.value.status_code == 401
def test_idle_timeout_does_NOT_reject_superadmin(
self, integration_client, logged_in_superadmin, integration_app,
):
"""SA-SESS-03 — Superadmin sessions have NO idle timeout.
Push ``last_activity_at`` far into the past but keep
``created_at`` recent. The session should still be valid
because superadmin sessions only use absolute timeout.
"""
_touch_session(
integration_app,
logged_in_superadmin["session_id"],
last_activity_at=datetime.now(timezone.utc) - timedelta(hours=1),
)
integration_client.set_token(logged_in_superadmin["token"])
result = integration_client.get("/api/v1/superadmin/auth/me")
data = assert_success(result)
assert "superadmin" in data
def test_revoked_superadmin_session_rejected(
self, integration_client, logged_in_superadmin,
):
"""SA-SESS-04 — Revoked superadmin session is rejected."""
integration_client.set_token(logged_in_superadmin["token"])
# Logout revokes the session
integration_client.post("/api/v1/superadmin/auth/logout")
integration_client.clear_token()
# Try using the old token
integration_client.set_token(logged_in_superadmin["token"])
with pytest.raises(ApiError) as exc_info:
integration_client.get("/api/v1/superadmin/auth/me")
assert exc_info.value.status_code == 401
def test_superadmin_session_has_owner_type(
self, integration_app, logged_in_superadmin,
):
"""SA-SESS-05 — Superadmin session row has owner_type='superadmin'."""
from gatehouse_app.models.user.session import Session
from gatehouse_app.utils.constants import SessionType
with integration_app.app_context():
sess = Session.query.get(logged_in_superadmin["session_id"])
assert sess is not None
assert sess.owner_type == SessionType.SUPERADMIN
assert sess.owner_id == logged_in_superadmin["superadmin"]["id"]
+503
View File
@@ -0,0 +1,503 @@
"""Unit tests for per-client CORS feature.
WHAT: Tests for per-client CORS origin resolution, including OIDCClient
model methods, request client_id extraction, effective origin
resolution, and integration with the CORS middleware.
WHY: Per-client CORS prevents one OIDC client from making cross-origin
requests meant for another; misconfiguration breaks browser flows.
EXPECTED: Correct origin derivation, proper client_id extraction, and
correct CORS headers on OIDC endpoints.
"""
import base64
import json
from unittest.mock import patch
import pytest
from flask import Flask, request as flask_request
import gatehouse_app.middleware.cors as cors_module
from gatehouse_app.middleware.cors import (
_get_oidc_client_id_from_request,
_get_effective_cors_origins,
setup_cors,
)
# ---------------------------------------------------------------------------
# Helper: build a lightweight stub that quacks like OIDCClient
# ---------------------------------------------------------------------------
class StubClient:
"""Minimal stand-in for OIDCClient -- no SQLAlchemy, no DB needed."""
def __init__(self, *, allowed_cors_origins=None, redirect_uris=None):
self.allowed_cors_origins = allowed_cors_origins
self.redirect_uris = redirect_uris or []
def get_effective_origins(self):
from urllib.parse import urlparse
if self.allowed_cors_origins is None:
return None
if "+" in self.allowed_cors_origins:
origins = set()
for uri in self.redirect_uris:
parsed = urlparse(uri)
if parsed.scheme and parsed.hostname:
port = f":{parsed.port}" if parsed.port else ""
origins.add(f"{parsed.scheme}://{parsed.hostname}{port}")
return sorted(origins)
return list(self.allowed_cors_origins)
def is_origin_allowed(self, origin):
effective = self.get_effective_origins()
if effective is None:
return None
if "*" in effective:
return True
return origin in effective
def _basic_auth_header(client_id, secret="secret"):
"""Return a 'Basic <b64>' Authorization header value."""
return "Basic " + base64.b64encode(f"{client_id}:{secret}".encode()).decode()
# ---------------------------------------------------------------------------
# OIDCClient.get_effective_origins
# ---------------------------------------------------------------------------
class TestGetEffectiveOrigins:
def test_returns_none_when_allowed_cors_origins_is_none(self):
"""TEST: PCORS-GE-01 -- None config signals 'use global'."""
client = StubClient(allowed_cors_origins=None)
assert client.get_effective_origins() is None
def test_derives_from_redirect_uris_when_plus_sign(self):
"""TEST: PCORS-GE-02 -- '+' in list derives origins from redirect_uris."""
client = StubClient(
allowed_cors_origins=["+"],
redirect_uris=[
"https://app.example.com/callback",
"http://localhost:3000/callback",
],
)
assert client.get_effective_origins() == [
"http://localhost:3000",
"https://app.example.com",
]
def test_derives_with_port(self):
"""TEST: PCORS-GE-03 -- Non-standard ports are preserved."""
client = StubClient(
allowed_cors_origins=["+"],
redirect_uris=["https://app.example.com:8443/cb"],
)
assert client.get_effective_origins() == ["https://app.example.com:8443"]
def test_deduplicates_derived_origins(self):
"""TEST: PCORS-GE-04 -- Duplicate redirect URIs produce unique origins."""
client = StubClient(
allowed_cors_origins=["+"],
redirect_uris=[
"https://app.example.com/cb1",
"https://app.example.com/cb2",
],
)
assert client.get_effective_origins() == ["https://app.example.com"]
def test_returns_list_as_is_when_normal_list(self):
"""TEST: PCORS-GE-05 -- Normal list is returned unchanged."""
origins = ["https://a.com", "https://b.com"]
client = StubClient(allowed_cors_origins=origins)
assert client.get_effective_origins() == origins
def test_returns_wildcard_list_as_is(self):
"""TEST: PCORS-GE-06 -- ['*'] is returned (handled downstream)."""
client = StubClient(allowed_cors_origins=["*"])
assert client.get_effective_origins() == ["*"]
def test_empty_redirect_uris_with_plus_returns_empty(self):
"""TEST: PCORS-GE-07 -- '+' with empty redirect_uris yields empty list."""
client = StubClient(
allowed_cors_origins=["+"],
redirect_uris=[],
)
assert client.get_effective_origins() == []
def test_skips_malformed_redirect_uris(self):
"""TEST: PCORS-GE-08 -- URIs without scheme/host are skipped."""
client = StubClient(
allowed_cors_origins=["+"],
redirect_uris=["not-a-uri", "https://good.com/cb"],
)
assert client.get_effective_origins() == ["https://good.com"]
# ---------------------------------------------------------------------------
# OIDCClient.is_origin_allowed
# ---------------------------------------------------------------------------
class TestIsOriginAllowed:
def test_returns_none_when_no_per_client_config(self):
"""TEST: PCORS-IO-01 -- None config defers to global CORS."""
client = StubClient(allowed_cors_origins=None)
assert client.is_origin_allowed("https://anything.com") is None
def test_returns_true_when_wildcard(self):
"""TEST: PCORS-IO-02 -- '*' in effective origins allows any origin."""
client = StubClient(allowed_cors_origins=["*"])
assert client.is_origin_allowed("https://evil.com") is True
def test_returns_true_for_matching_origin(self):
"""TEST: PCORS-IO-03 -- Matching origin is allowed."""
client = StubClient(
allowed_cors_origins=["https://app.example.com", "https://other.com"],
)
assert client.is_origin_allowed("https://app.example.com") is True
def test_returns_false_for_non_matching_origin(self):
"""TEST: PCORS-IO-04 -- Non-matching origin is rejected."""
client = StubClient(allowed_cors_origins=["https://app.example.com"])
assert client.is_origin_allowed("https://evil.com") is False
def test_returns_false_for_empty_list(self):
"""TEST: PCORS-IO-05 -- Empty list rejects everything."""
client = StubClient(allowed_cors_origins=[])
assert client.is_origin_allowed("https://anything.com") is False
# ---------------------------------------------------------------------------
# _get_oidc_client_id_from_request
# ---------------------------------------------------------------------------
class TestGetOidcClientIdFromRequest:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
def test_extracts_from_basic_auth(self, app):
"""TEST: PCORS-CI-01 -- Basic Auth header yields client_id."""
with app.test_request_context(
"/oidc/token",
method="POST",
headers={"Authorization": _basic_auth_header("my-client")},
):
assert _get_oidc_client_id_from_request() == "my-client"
def test_extracts_from_form_body(self, app):
"""TEST: PCORS-CI-02 -- Form-encoded body yields client_id."""
with app.test_request_context(
"/oidc/token",
method="POST",
data={"client_id": "form-client", "grant_type": "client_credentials"},
):
assert _get_oidc_client_id_from_request() == "form-client"
def test_extracts_from_json_body(self, app):
"""TEST: PCORS-CI-03 -- JSON body yields client_id."""
with app.test_request_context(
"/oidc/token",
method="POST",
data=json.dumps({"client_id": "json-client", "grant_type": "client_credentials"}),
content_type="application/json",
):
assert _get_oidc_client_id_from_request() == "json-client"
def test_extracts_from_bearer_jwt(self, app):
"""TEST: PCORS-CI-04 -- Bearer JWT payload yields client_id."""
payload = base64.urlsafe_b64encode(
json.dumps({"client_id": "jwt-client"}).encode()
).rstrip(b"=").decode()
token = f"eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.{payload}.sig"
with app.test_request_context(
"/oidc/userinfo",
method="GET",
headers={"Authorization": f"Bearer {token}"},
):
assert _get_oidc_client_id_from_request() == "jwt-client"
def test_returns_none_for_non_oidc_endpoint(self, app):
"""TEST: PCORS-CI-05 -- Non-OIDC path returns None."""
with app.test_request_context(
"/api/v1/users",
method="GET",
headers={"Authorization": _basic_auth_header("x")},
):
assert _get_oidc_client_id_from_request() is None
def test_returns_none_when_no_client_id_found(self, app):
"""TEST: PCORS-CI-06 -- OIDC token endpoint with no credentials returns None."""
with app.test_request_context(
"/oidc/token",
method="POST",
data={"grant_type": "client_credentials"},
):
assert _get_oidc_client_id_from_request() is None
def test_extracts_from_revoke_endpoint(self, app):
"""TEST: PCORS-CI-07 -- /oidc/revoke also accepts Basic Auth."""
with app.test_request_context(
"/oidc/revoke",
method="POST",
headers={"Authorization": _basic_auth_header("rev-client")},
):
assert _get_oidc_client_id_from_request() == "rev-client"
def test_extracts_from_introspect_endpoint(self, app):
"""TEST: PCORS-CI-08 -- /oidc/introspect also accepts Basic Auth."""
with app.test_request_context(
"/oidc/introspect",
method="POST",
headers={"Authorization": _basic_auth_header("int-client")},
):
assert _get_oidc_client_id_from_request() == "int-client"
def test_returns_none_for_options_preflight(self, app):
"""TEST: PCORS-CI-09 -- OPTIONS preflight cannot carry client credentials."""
with app.test_request_context(
"/oidc/token",
method="OPTIONS",
headers={
"Origin": "https://app.com",
"Access-Control-Request-Method": "POST",
},
):
assert _get_oidc_client_id_from_request() is None
# ---------------------------------------------------------------------------
# _get_effective_cors_origins
# ---------------------------------------------------------------------------
class TestGetEffectiveCorsOrigins:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
app.config["CORS_ORIGINS"] = ["https://global.com"]
return app
def test_global_config_for_non_oidc_endpoint(self, app):
"""TEST: PCORS-EO-01 -- Non-OIDC path always uses global config."""
with app.test_request_context("/api/v1/users", method="GET"):
result = _get_effective_cors_origins(app, flask_request)
assert result == ["https://global.com"]
def test_per_client_origins_for_oidc_endpoint(self, app):
"""TEST: PCORS-EO-02 -- OIDC endpoint with configured client uses per-client origins."""
fake_client = StubClient(allowed_cors_origins=["https://client.com"])
with app.test_request_context(
"/oidc/token",
method="POST",
headers={"Authorization": _basic_auth_header("test-client")},
):
with patch.object(cors_module, "OIDCClient") as MockModel:
MockModel.query.filter_by.return_value.first.return_value = fake_client
result = _get_effective_cors_origins(app, flask_request)
assert result == ["https://client.com"]
def test_fallback_when_client_not_found(self, app):
"""TEST: PCORS-EO-03 -- Unknown client_id falls back to global config."""
with app.test_request_context(
"/oidc/token",
method="POST",
headers={"Authorization": _basic_auth_header("unknown")},
):
with patch.object(cors_module, "OIDCClient") as MockModel:
MockModel.query.filter_by.return_value.first.return_value = None
result = _get_effective_cors_origins(app, flask_request)
assert result == ["https://global.com"]
def test_fallback_when_allowed_cors_origins_is_none(self, app):
"""TEST: PCORS-EO-04 -- Client with None origins falls back to global."""
fake_client = StubClient(allowed_cors_origins=None)
with app.test_request_context(
"/oidc/token",
method="POST",
headers={"Authorization": _basic_auth_header("test-client")},
):
with patch.object(cors_module, "OIDCClient") as MockModel:
MockModel.query.filter_by.return_value.first.return_value = fake_client
result = _get_effective_cors_origins(app, flask_request)
assert result == ["https://global.com"]
def test_fallback_on_db_error(self, app):
"""TEST: PCORS-EO-05 -- Database exception falls back to global config."""
with app.test_request_context(
"/oidc/token",
method="POST",
headers={"Authorization": _basic_auth_header("test-client")},
):
with patch.object(cors_module, "OIDCClient") as MockModel:
MockModel.query.filter_by.side_effect = Exception("DB down")
result = _get_effective_cors_origins(app, flask_request)
assert result == ["https://global.com"]
def test_fallback_when_no_client_id_extracted(self, app):
"""TEST: PCORS-EO-06 -- OIDC path with no credentials falls back to global."""
with app.test_request_context(
"/oidc/token",
method="POST",
data={"grant_type": "client_credentials"},
):
result = _get_effective_cors_origins(app, flask_request)
assert result == ["https://global.com"]
# ---------------------------------------------------------------------------
# Integration: OIDC endpoint CORS headers
# ---------------------------------------------------------------------------
class TestOidcEndpointCorsIntegration:
@pytest.fixture
def app_with_global_and_client(self):
"""Flask app with global CORS and route stubs for integration tests."""
app = Flask(__name__)
app.config["TESTING"] = True
app.config["CORS_ORIGINS"] = ["https://global.com"]
app.config["CORS_SUPPORTS_CREDENTIALS"] = True
@app.route("/oidc/token", methods=["POST", "OPTIONS"])
def oidc_token():
return {"status": "ok"}, 200
@app.route("/api/v1/users", methods=["GET", "OPTIONS"])
def api_users():
return {"users": []}, 200
setup_cors(app)
return app
def test_post_oidc_with_per_client_origin_includes_cors_headers(
self, app_with_global_and_client
):
"""TEST: PCORS-INT-01 -- POST to /oidc/token with per-client origin and
Basic Auth includes CORS headers. Per-client CORS applies to the actual
request (which carries credentials), not the preflight."""
fake_client = StubClient(allowed_cors_origins=["https://client-app.com"])
with patch.object(cors_module, "OIDCClient") as MockModel:
MockModel.query.filter_by.return_value.first.return_value = fake_client
with app_with_global_and_client.test_client() as client:
resp = client.post(
"/oidc/token",
headers={
"Origin": "https://client-app.com",
"Authorization": _basic_auth_header("test-client"),
},
)
assert resp.status_code == 200
assert (
resp.headers.get("Access-Control-Allow-Origin")
== "https://client-app.com"
)
assert resp.headers.get("Access-Control-Allow-Credentials") == "true"
def test_post_oidc_with_non_matching_per_client_origin_no_cors_headers(
self, app_with_global_and_client
):
"""TEST: PCORS-INT-02 -- POST with origin not in per-client list has no
CORS headers."""
fake_client = StubClient(allowed_cors_origins=["https://allowed.com"])
with patch.object(cors_module, "OIDCClient") as MockModel:
MockModel.query.filter_by.return_value.first.return_value = fake_client
with app_with_global_and_client.test_client() as client:
resp = client.post(
"/oidc/token",
headers={
"Origin": "https://evil.com",
"Authorization": _basic_auth_header("test-client"),
},
)
assert resp.status_code == 200
assert resp.headers.get("Access-Control-Allow-Origin") is None
def test_post_oidc_wildcard_client_echoes_origin(self, app_with_global_and_client):
"""TEST: PCORS-INT-03 -- Client with '*' echoes the request origin."""
fake_client = StubClient(allowed_cors_origins=["*"])
with patch.object(cors_module, "OIDCClient") as MockModel:
MockModel.query.filter_by.return_value.first.return_value = fake_client
with app_with_global_and_client.test_client() as client:
resp = client.post(
"/oidc/token",
headers={
"Origin": "https://any-origin.com",
"Authorization": _basic_auth_header("test-client"),
},
)
assert resp.status_code == 200
assert (
resp.headers.get("Access-Control-Allow-Origin")
== "https://any-origin.com"
)
def test_preflight_oidc_falls_back_to_global(self, app_with_global_and_client):
"""TEST: PCORS-INT-04 -- OPTIONS preflight cannot carry client credentials,
so it uses global CORS config. A preflight from a per-client-only origin
that is not in the global list will not receive CORS headers."""
fake_client = StubClient(allowed_cors_origins=["https://client-app.com"])
with patch.object(cors_module, "OIDCClient") as MockModel:
MockModel.query.filter_by.return_value.first.return_value = fake_client
with app_with_global_and_client.test_client() as client:
resp = client.options(
"/oidc/token",
headers={
"Origin": "https://client-app.com",
"Access-Control-Request-Method": "POST",
},
)
# Origin not in global list; no CORS headers on preflight
assert resp.headers.get("Access-Control-Allow-Origin") is None
def test_preflight_oidc_with_global_origin_succeeds(self, app_with_global_and_client):
"""TEST: PCORS-INT-05 -- OPTIONS preflight from a globally-allowed origin
returns 204 with CORS headers even for OIDC endpoints."""
with app_with_global_and_client.test_client() as client:
resp = client.options(
"/oidc/token",
headers={
"Origin": "https://global.com",
"Access-Control-Request-Method": "POST",
},
)
assert resp.status_code == 204
assert resp.headers.get("Access-Control-Allow-Origin") == "https://global.com"
assert resp.headers.get("Access-Control-Allow-Credentials") == "true"
def test_non_oidc_endpoint_uses_global_cors(self, app_with_global_and_client):
"""TEST: PCORS-INT-06 -- Non-OIDC endpoint uses global CORS config."""
with app_with_global_and_client.test_client() as client:
resp = client.options(
"/api/v1/users",
headers={
"Origin": "https://global.com",
"Access-Control-Request-Method": "GET",
},
)
assert resp.status_code == 204
assert resp.headers.get("Access-Control-Allow-Origin") == "https://global.com"
def test_post_oidc_no_auth_uses_global_cors(self, app_with_global_and_client):
"""TEST: PCORS-INT-07 -- POST to OIDC endpoint without credentials uses
global CORS (cannot identify client)."""
with app_with_global_and_client.test_client() as client:
resp = client.post(
"/oidc/token",
headers={"Origin": "https://global.com"},
)
assert resp.status_code == 200
assert (
resp.headers.get("Access-Control-Allow-Origin") == "https://global.com"
)