3 Commits

Author SHA1 Message Date
coryHawkvelt 803bf4f4f2 refactor: consolidate user and superadmin sessions into unified model 2026-04-28 20:54:15 +09:30
coryHawkvelt 5abbadff9a Improve auditing 2026-04-28 17:17:54 +09:30
coryHawkvelt 63a3109a82 oidc-client mk1 2026-04-27 02:44:32 +09:30
29 changed files with 1715 additions and 134 deletions
+240
View File
@@ -0,0 +1,240 @@
# Per-Client CORS Origins for OIDC Endpoints
## Overview
Gatehouse OIDC now supports **per-client CORS origins**. This allows each OIDC client to declare which browser origins are permitted to make cross-origin requests to OIDC endpoints (`/oidc/token`, `/oidc/revoke`, `/oidc/userinfo`, `/oidc/introspect`).
Previously, CORS was controlled by a single server-wide `CORS_ORIGINS` environment variable. If your SPA's origin wasn't in that list, the browser would block requests to OIDC endpoints — even if your OIDC client was properly configured.
## How It Works
### The Problem
When a browser-based SPA (e.g., running at `http://localhost:8080`) exchanges an authorization code for tokens, it makes a POST request to `/oidc/token`. The browser sends a preflight OPTIONS request first, and the server must respond with CORS headers allowing the SPA's origin.
Previously, if `http://localhost:8080` wasn't in the server's `CORS_ORIGINS` env var, the preflight would fail and the SPA couldn't get tokens.
### The Solution
Each OIDC client can now declare its own `allowed_cors_origins`. When a request hits an OIDC endpoint, the server checks the client's CORS configuration first, then falls back to the global config.
## Configuration
### Setting CORS Origins on an OIDC Client
When creating or updating an OIDC client, set the `allowed_cors_origins` field:
```json
{
"name": "My SPA",
"client_id": "oidc_myapp",
"redirect_uris": ["http://localhost:8080/callback", "https://app.example.com/callback"],
"allowed_cors_origins": ["http://localhost:8080", "https://app.example.com"],
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"scopes": ["openid", "profile", "email"]
}
```
### Auto-Derive from Redirect URIs
Set `allowed_cors_origins` to `["+"]` to automatically derive CORS origins from the client's `redirect_uris`. The server extracts the scheme, hostname, and port from each redirect URI.
```json
{
"redirect_uris": ["http://localhost:8080/callback", "https://app.example.com/callback"],
"allowed_cors_origins": ["+"]
}
```
This is equivalent to:
```json
{
"allowed_cors_origins": ["http://localhost:8080", "https://app.example.com"]
}
```
### Use Global Config (Default)
Set `allowed_cors_origins` to `null` (or omit it) to use the server's global `CORS_ORIGINS` config. This is the default behavior for existing clients.
```json
{
"allowed_cors_origins": null
}
```
### Allow All Origins (Not Recommended)
Set `allowed_cors_origins` to `["*"]` to allow any origin. **This is not recommended for production.**
```json
{
"allowed_cors_origins": ["*"]
}
```
## Affected Endpoints
The following OIDC endpoints support per-client CORS:
| Endpoint | Method | How Client is Identified |
|---|---|---|
| `/oidc/token` | POST | `client_id` in request body or Basic Auth header |
| `/oidc/revoke` | POST | `client_id` in request body or Basic Auth header |
| `/oidc/introspect` | POST | `client_id` in request body or Basic Auth header |
| `/oidc/userinfo` | GET/POST | `client_id` extracted from Bearer token |
## SPA Integration Guide
### Step 1: Register Your OIDC Client
Register your SPA as an OIDC client with the correct redirect URIs and CORS origins:
```json
{
"name": "My React App",
"redirect_uris": ["http://localhost:3000/callback"],
"allowed_cors_origins": ["http://localhost:3000"],
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"scopes": ["openid", "profile", "email"],
"is_confidential": false,
"require_pkce": true
}
```
### Step 2: Use PKCE (Required for Public Clients)
Gatehouse requires PKCE for public clients. Generate a code verifier and challenge before redirecting to the authorize endpoint:
```javascript
// Generate PKCE
const codeVerifier = generateRandomString(128);
const codeChallenge = await sha256(codeVerifier);
const state = generateRandomString(32);
// Store verifier for later
sessionStorage.setItem('pkce_verifier', codeVerifier);
// Redirect to authorize
const authUrl = new URL('https://api.example.com/api/v1/oidc/authorize');
authUrl.searchParams.set('response_type', 'code');
authUrl.searchParams.set('client_id', 'oidc_myapp');
authUrl.searchParams.set('redirect_uri', 'http://localhost:3000/callback');
authUrl.searchParams.set('scope', 'openid profile email');
authUrl.searchParams.set('state', state);
authUrl.searchParams.set('code_challenge', codeChallenge);
authUrl.searchParams.set('code_challenge_method', 'S256');
window.location.href = authUrl.toString();
```
### Step 3: Exchange Code for Tokens
After the user authenticates and is redirected back to your callback page, exchange the authorization code for tokens:
```javascript
// Extract code from URL
const params = new URLSearchParams(window.location.search);
const code = params.get('code');
const state = params.get('state');
// Verify state matches
if (state !== sessionStorage.getItem('pkce_state')) {
throw new Error('State mismatch');
}
// Exchange code for tokens
const response = await fetch('https://api.example.com/api/v1/oidc/token', {
method: 'POST',
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
},
body: new URLSearchParams({
grant_type: 'authorization_code',
code: code,
redirect_uri: 'http://localhost:3000/callback',
client_id: 'oidc_myapp',
code_verifier: sessionStorage.getItem('pkce_verifier'),
}),
});
const tokens = await response.json();
// tokens.access_token, tokens.id_token, tokens.refresh_token
```
The server will return CORS headers because `http://localhost:3000` is in the client's `allowed_cors_origins`.
### Step 4: Refresh Tokens
```javascript
const response = await fetch('https://api.example.com/api/v1/oidc/token', {
method: 'POST',
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
},
body: new URLSearchParams({
grant_type: 'refresh_token',
refresh_token: storedRefreshToken,
client_id: 'oidc_myapp',
}),
});
```
### Step 5: Call UserInfo
```javascript
const response = await fetch('https://api.example.com/api/v1/oidc/userinfo', {
headers: {
'Authorization': `Bearer ${accessToken}`,
},
});
const userInfo = await response.json();
```
## Troubleshooting
### "CORS error" when exchanging code for tokens
**Cause**: Your SPA's origin is not in the client's `allowed_cors_origins` or the server's global `CORS_ORIGINS`.
**Fix**: Add your SPA's origin to the client's `allowed_cors_origins`:
```json
{
"allowed_cors_origins": ["http://localhost:3000"]
}
```
### "CORS error" on preflight OPTIONS request
**Cause**: The preflight request doesn't carry client credentials, so the server can't identify which client to check CORS origins for. It falls back to the global `CORS_ORIGINS`.
**Fix**: Either add your origin to the global `CORS_ORIGINS` env var, or ensure the actual POST request (after preflight) includes the `client_id` in the request body.
### CORS works for `/oidc/token` but not `/oidc/userinfo`
**Cause**: The userinfo endpoint identifies the client from the Bearer token. If the token doesn't contain a `client_id` claim, the server falls back to global config.
**Fix**: Ensure your access tokens include the `client_id` claim (this is the default behavior).
## API Reference
### OIDCClient Fields
| Field | Type | Description |
|---|---|---|
| `allowed_cors_origins` | `string[]` or `null` | List of allowed browser origins. `null` = use global config. `["+"]` = auto-derive from redirect URIs. `["*"]` = allow all (not recommended). |
### CORS Headers Returned
When a request's origin matches the client's allowed origins:
```
Access-Control-Allow-Origin: <request-origin>
Access-Control-Allow-Methods: GET, POST, PUT, PATCH, DELETE, OPTIONS
Access-Control-Allow-Headers: Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma, X-WebAuthn-Session-Token
Access-Control-Allow-Credentials: true
Access-Control-Max-Age: 3600
```
+56
View File
@@ -10,6 +10,8 @@ from gatehouse_app.models import Department, DepartmentMembership
from gatehouse_app.services.organization_service import OrganizationService from gatehouse_app.services.organization_service import OrganizationService
from gatehouse_app.services.user_service import UserService from gatehouse_app.services.user_service import UserService
from gatehouse_app.extensions import db from gatehouse_app.extensions import db
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.services.audit_service import AuditService
class DepartmentCreateSchema(Schema): class DepartmentCreateSchema(Schema):
@@ -127,6 +129,15 @@ def create_department(org_id):
db.session.add(dept) db.session.add(dept)
db.session.commit() db.session.commit()
AuditService.log_action(
action=AuditAction.DEPARTMENT_CREATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="department",
resource_id=str(dept.id),
description=f"Department '{dept.name}' created",
)
return api_response( return api_response(
data={"department": dept.to_dict()}, data={"department": dept.to_dict()},
message="Department created successfully", message="Department created successfully",
@@ -255,6 +266,15 @@ def update_department(org_id, dept_id):
db.session.commit() db.session.commit()
AuditService.log_action(
action=AuditAction.DEPARTMENT_UPDATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="department",
resource_id=str(dept.id),
description=f"Department '{dept.name}' updated",
)
return api_response( return api_response(
data={"department": dept.to_dict()}, data={"department": dept.to_dict()},
message="Department updated successfully", message="Department updated successfully",
@@ -308,6 +328,15 @@ def delete_department(org_id, dept_id):
dept.deleted_at = db.func.now() dept.deleted_at = db.func.now()
db.session.commit() db.session.commit()
AuditService.log_action(
action=AuditAction.DEPARTMENT_DELETED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="department",
resource_id=str(dept.id),
description=f"Department '{dept.name}' deleted",
)
return api_response( return api_response(
message="Department deleted successfully", message="Department deleted successfully",
) )
@@ -461,6 +490,15 @@ def add_department_member(org_id, dept_id):
db.session.commit() db.session.commit()
AuditService.log_action(
action=AuditAction.DEPARTMENT_MEMBER_ADDED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="user",
resource_id=str(user.id),
description=f"Added user {user.email} to department '{dept.name}'",
)
member_dict = membership.to_dict() member_dict = membership.to_dict()
member_dict["user"] = user.to_dict() member_dict["user"] = user.to_dict()
@@ -533,6 +571,15 @@ def remove_department_member(org_id, dept_id, user_id):
membership.deleted_at = db.func.now() membership.deleted_at = db.func.now()
db.session.commit() db.session.commit()
AuditService.log_action(
action=AuditAction.DEPARTMENT_MEMBER_REMOVED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="user",
resource_id=str(user_id),
description=f"Removed user from department '{dept.name}'",
)
return api_response( return api_response(
message="Member removed successfully", message="Member removed successfully",
) )
@@ -699,5 +746,14 @@ def set_dept_cert_policy(org_id, dept_id):
db.session.commit() db.session.commit()
AuditService.log_action(
action=AuditAction.DEPARTMENT_CERT_POLICY_UPDATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="department",
resource_id=str(dept_id),
description=f"Certificate policy updated for department '{dept.name}'",
)
return api_response(data={"cert_policy": policy.to_dict()}, message="Certificate policy saved") return api_response(data={"cert_policy": policy.to_dict()}, message="Certificate policy saved")
@@ -3,6 +3,8 @@ from flask import g, request
from gatehouse_app.api.v1 import api_v1_bp from gatehouse_app.api.v1 import api_v1_bp
from gatehouse_app.utils.response import api_response from gatehouse_app.utils.response import api_response
from gatehouse_app.utils.decorators import login_required from gatehouse_app.utils.decorators import login_required
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.services.audit_service import AuditService
@api_v1_bp.route("/admin/oauth/providers", methods=["GET"]) @api_v1_bp.route("/admin/oauth/providers", methods=["GET"])
@@ -78,6 +80,14 @@ def admin_configure_app_provider(provider: str):
db.session.add(cfg) db.session.add(cfg)
db.session.commit() db.session.commit()
AuditService.log_action(
action=AuditAction.EXTERNAL_AUTH_CONFIG_UPDATE if cfg else AuditAction.EXTERNAL_AUTH_CONFIG_CREATE,
user_id=g.current_user.id,
resource_type="oauth_provider",
resource_id=provider,
description=f"OAuth provider '{provider}' configured (enabled={cfg.is_enabled})",
)
return api_response( return api_response(
data={"provider": {"id": provider, "client_id": cfg.client_id, "is_enabled": cfg.is_enabled}}, data={"provider": {"id": provider, "client_id": cfg.client_id, "is_enabled": cfg.is_enabled}},
message=f"{provider.capitalize()} OAuth provider configured successfully", message=f"{provider.capitalize()} OAuth provider configured successfully",
@@ -104,4 +114,13 @@ def admin_delete_app_provider(provider: str):
return api_response(success=False, message=f"Provider '{provider}' is not configured", status=404, error_type="NOT_FOUND") return api_response(success=False, message=f"Provider '{provider}' is not configured", status=404, error_type="NOT_FOUND")
cfg.delete() cfg.delete()
AuditService.log_action(
action=AuditAction.EXTERNAL_AUTH_CONFIG_DELETE,
user_id=g.current_user.id,
resource_type="oauth_provider",
resource_id=provider,
description=f"OAuth provider '{provider}' configuration removed",
)
return api_response(message=f"{provider.capitalize()} OAuth provider configuration removed") return api_response(message=f"{provider.capitalize()} OAuth provider configuration removed")
+15
View File
@@ -26,6 +26,9 @@ from gatehouse_app.exceptions.auth_exceptions import (
AccountSuspendedError, AccountSuspendedError,
AccountInactiveError, AccountInactiveError,
) )
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.services.audit_service import AuditService
from gatehouse_app.services.oidc_audit_service import OIDCAuditService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -849,6 +852,18 @@ def oidc_register():
) )
client.save() client.save()
OIDCAuditService.log_event(
event_type="client_registration",
client_id=client_id,
user_id=g.current_user.id if hasattr(g, "current_user") else None,
success=True,
metadata={
"client_name": client_name,
"redirect_uris": redirect_uris,
"organization_id": str(organization.id),
},
)
response = jsonify({ response = jsonify({
"client_id": client_id, "client_id": client_id,
"client_secret": client_secret, "client_secret": client_secret,
+33 -4
View File
@@ -8,6 +8,8 @@ from gatehouse_app.utils.decorators import login_required, require_admin, full_a
from gatehouse_app.models.organization import OrganizationApiKey from gatehouse_app.models.organization import OrganizationApiKey
from gatehouse_app.services.organization_service import OrganizationService from gatehouse_app.services.organization_service import OrganizationService
from gatehouse_app.extensions import db from gatehouse_app.extensions import db
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.services.audit_service import AuditService
class ApiKeyCreateSchema(Schema): class ApiKeyCreateSchema(Schema):
@@ -130,7 +132,16 @@ def create_api_key(org_id):
name=data["name"], name=data["name"],
description=data.get("description"), description=data.get("description"),
) )
AuditService.log_action(
action=AuditAction.ORG_API_KEY_CREATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="api_key",
resource_id=str(api_key.id),
description=f"API key '{api_key.name}' created",
)
# Return the key data with the plain text key (only on creation) # Return the key data with the plain text key (only on creation)
key_dict = api_key.to_dict() key_dict = api_key.to_dict()
key_dict["key"] = plain_key # Include plain text only on creation key_dict["key"] = plain_key # Include plain text only on creation
@@ -219,9 +230,18 @@ def update_api_key(org_id, key_id):
api_key.name = data["name"] api_key.name = data["name"]
if "description" in data: if "description" in data:
api_key.description = data["description"] api_key.description = data["description"]
api_key.save() api_key.save()
AuditService.log_action(
action=AuditAction.ORG_API_KEY_UPDATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="api_key",
resource_id=str(api_key.id),
description=f"API key '{api_key.name}' updated",
)
return api_response( return api_response(
data={"api_key": api_key.to_dict()}, data={"api_key": api_key.to_dict()},
message="API key updated successfully", message="API key updated successfully",
@@ -293,7 +313,16 @@ def delete_api_key(org_id, key_id):
# Soft delete the API key # Soft delete the API key
api_key.delete(soft=True) api_key.delete(soft=True)
AuditService.log_action(
action=AuditAction.ORG_API_KEY_DELETED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="api_key",
resource_id=str(api_key.id),
description=f"API key '{api_key.name}' deleted",
)
return api_response( return api_response(
message="API key deleted successfully", message="API key deleted successfully",
) )
+21
View File
@@ -6,6 +6,8 @@ from gatehouse_app.utils.response import api_response
from gatehouse_app.utils.decorators import login_required, require_admin from gatehouse_app.utils.decorators import login_required, require_admin
from gatehouse_app.extensions import db from gatehouse_app.extensions import db
from gatehouse_app.api.v1.organizations._helpers import _get_system_ca_dict from gatehouse_app.api.v1.organizations._helpers import _get_system_ca_dict
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.services.audit_service import AuditService
@api_v1_bp.route("/organizations/<org_id>/cas", methods=["GET"]) @api_v1_bp.route("/organizations/<org_id>/cas", methods=["GET"])
@@ -66,6 +68,16 @@ def update_org_ca(org_id, ca_id):
ca.max_cert_validity_hours = data["max_cert_validity_hours"] ca.max_cert_validity_hours = data["max_cert_validity_hours"]
db.session.commit() db.session.commit()
AuditService.log_action(
action=AuditAction.CA_UPDATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="CA",
resource_id=ca_id,
description=f"CA '{ca.name}' updated",
)
return api_response(data={"ca": ca.to_dict()}, message="CA updated successfully") return api_response(data={"ca": ca.to_dict()}, message="CA updated successfully")
except ValidationError as e: except ValidationError as e:
return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages)
@@ -150,6 +162,15 @@ def create_org_ca(org_id):
return api_response(success=False, message="A CA with that name already exists in this organization (it may have been recently deleted — choose a different name).", status=400, error_type="DUPLICATE_NAME") return api_response(success=False, message="A CA with that name already exists in this organization (it may have been recently deleted — choose a different name).", status=400, error_type="DUPLICATE_NAME")
raise raise
AuditService.log_action(
action=AuditAction.CA_CREATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="CA",
resource_id=str(ca.id),
description=f"CA '{ca.name}' created",
)
return api_response(data={"ca": ca.to_dict()}, message="CA created successfully", status=201) return api_response(data={"ca": ca.to_dict()}, message="CA created successfully", status=201)
except MaValidationError as e: except MaValidationError as e:
return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages)
@@ -5,6 +5,8 @@ from gatehouse_app.api.v1 import api_v1_bp
from gatehouse_app.utils.response import api_response from gatehouse_app.utils.response import api_response
from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required
from gatehouse_app.extensions import db, bcrypt from gatehouse_app.extensions import db, bcrypt
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.services.audit_service import AuditService
@api_v1_bp.route("/organizations/<org_id>/clients", methods=["GET"]) @api_v1_bp.route("/organizations/<org_id>/clients", methods=["GET"])
@@ -28,6 +30,7 @@ def list_org_clients(org_id):
"redirect_uris": c.redirect_uris, "redirect_uris": c.redirect_uris,
"scopes": c.scopes, "scopes": c.scopes,
"grant_types": c.grant_types, "grant_types": c.grant_types,
"allowed_cors_origins": c.allowed_cors_origins,
"is_active": c.is_active, "is_active": c.is_active,
"created_at": c.created_at.isoformat() + "Z", "created_at": c.created_at.isoformat() + "Z",
} }
@@ -78,6 +81,15 @@ def create_org_client(org_id):
db.session.add(client) db.session.add(client)
db.session.commit() db.session.commit()
AuditService.log_action(
action=AuditAction.ORG_CLIENT_CREATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="oidc_client",
resource_id=str(client.id),
description=f"OIDC client '{client.name}' created",
)
return api_response( return api_response(
data={ data={
"client": { "client": {
@@ -125,6 +137,15 @@ def update_org_client(org_id, client_id):
db.session.commit() db.session.commit()
AuditService.log_action(
action=AuditAction.ORG_CLIENT_UPDATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="oidc_client",
resource_id=str(client.id),
description=f"OIDC client '{client.name}' updated",
)
return api_response( return api_response(
data={ data={
"client": { "client": {
@@ -154,4 +175,14 @@ def delete_org_client(org_id, client_id):
client.is_active = False client.is_active = False
db.session.commit() db.session.commit()
AuditService.log_action(
action=AuditAction.ORG_CLIENT_DEACTIVATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="oidc_client",
resource_id=str(client.id),
description=f"OIDC client '{client.name}' deactivated",
)
return api_response(data={}, message="Client deactivated successfully") return api_response(data={}, message="Client deactivated successfully")
@@ -7,6 +7,8 @@ from gatehouse_app.utils.response import api_response
from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required
from gatehouse_app.schemas.organization_schema import OrganizationCreateSchema, OrganizationUpdateSchema from gatehouse_app.schemas.organization_schema import OrganizationCreateSchema, OrganizationUpdateSchema
from gatehouse_app.services.organization_service import OrganizationService from gatehouse_app.services.organization_service import OrganizationService
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.services.audit_service import AuditService
@api_v1_bp.route("/organizations", methods=["POST"]) @api_v1_bp.route("/organizations", methods=["POST"])
@@ -32,6 +34,14 @@ def create_organization():
description=data.get("description"), description=data.get("description"),
logo_url=data.get("logo_url"), logo_url=data.get("logo_url"),
) )
AuditService.log_action(
action=AuditAction.ORG_CREATE,
user_id=g.current_user.id,
organization_id=org.id,
resource_type="organization",
resource_id=str(org.id),
description=f"Organization '{org.name}' created",
)
return api_response(data={"organization": org.to_dict()}, message="Organization created successfully", status=201) return api_response(data={"organization": org.to_dict()}, message="Organization created successfully", status=201)
except ValidationError as e: except ValidationError as e:
return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages)
@@ -60,6 +70,14 @@ def update_organization(org_id):
data = schema.load(request.json) data = schema.load(request.json)
org = OrganizationService.get_organization_by_id(org_id) org = OrganizationService.get_organization_by_id(org_id)
org = OrganizationService.update_organization(org=org, user_id=g.current_user.id, **data) org = OrganizationService.update_organization(org=org, user_id=g.current_user.id, **data)
AuditService.log_action(
action=AuditAction.ORG_UPDATE,
user_id=g.current_user.id,
organization_id=org.id,
resource_type="organization",
resource_id=str(org.id),
description=f"Organization '{org.name}' updated",
)
return api_response(data={"organization": org.to_dict()}, message="Organization updated successfully") return api_response(data={"organization": org.to_dict()}, message="Organization updated successfully")
except ValidationError as e: except ValidationError as e:
return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages)
@@ -92,4 +110,12 @@ def delete_organization(org_id):
) )
OrganizationService.force_delete_organization(org=org, user_id=caller.id) OrganizationService.force_delete_organization(org=org, user_id=caller.id)
AuditService.log_action(
action=AuditAction.ORG_DELETE,
user_id=caller.id,
organization_id=org.id,
resource_type="organization",
resource_id=str(org.id),
description=f"Organization '{org.name}' deleted",
)
return api_response(message="Organization deleted successfully") return api_response(message="Organization deleted successfully")
@@ -136,6 +136,15 @@ def cancel_org_invite(org_id, invite_id):
return api_response(success=False, message="Invite not found", status=404) return api_response(success=False, message="Invite not found", status=404)
invite.delete(soft=True) invite.delete(soft=True)
AuditService.log_action(
action=AuditAction.ORG_INVITE_CANCELLED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="org_invite",
resource_id=invite.id,
metadata={"invited_email": invite.email, "role": invite.role},
description=f"Invitation for {invite.email} cancelled",
)
return api_response(data={}, message="Invite cancelled") return api_response(data={}, message="Invite cancelled")
+35 -1
View File
@@ -7,7 +7,8 @@ from gatehouse_app.utils.decorators import login_required, require_admin, full_a
from gatehouse_app.schemas.organization_schema import InviteMemberSchema, UpdateMemberRoleSchema from gatehouse_app.schemas.organization_schema import InviteMemberSchema, UpdateMemberRoleSchema
from gatehouse_app.services.organization_service import OrganizationService from gatehouse_app.services.organization_service import OrganizationService
from gatehouse_app.services.user_service import UserService from gatehouse_app.services.user_service import UserService
from gatehouse_app.utils.constants import OrganizationRole from gatehouse_app.utils.constants import AuditAction, OrganizationRole
from gatehouse_app.services.audit_service import AuditService
@api_v1_bp.route("/organizations/<org_id>/members", methods=["GET"]) @api_v1_bp.route("/organizations/<org_id>/members", methods=["GET"])
@@ -43,6 +44,14 @@ def add_organization_member(org_id):
role = OrganizationRole(data["role"]) role = OrganizationRole(data["role"])
member = OrganizationService.add_member(org=org, user_id=user.id, role=role, inviter_id=g.current_user.id) member = OrganizationService.add_member(org=org, user_id=user.id, role=role, inviter_id=g.current_user.id)
AuditService.log_action(
action=AuditAction.ORG_MEMBER_ADD,
user_id=g.current_user.id,
organization_id=org.id,
resource_type="user",
resource_id=str(user.id),
description=f"Added user {user.email} to organization with role {role.value}",
)
member_dict = member.to_dict() member_dict = member.to_dict()
member_dict["user"] = user.to_dict() member_dict["user"] = user.to_dict()
return api_response(data={"member": member_dict}, message="Member added successfully", status=201) return api_response(data={"member": member_dict}, message="Member added successfully", status=201)
@@ -60,6 +69,14 @@ def remove_organization_member(org_id, user_id):
OrganizationService.remove_member(org=org, user_id=user_id, remover_id=g.current_user.id) OrganizationService.remove_member(org=org, user_id=user_id, remover_id=g.current_user.id)
except ValueError as e: except ValueError as e:
return api_response(success=False, message=str(e), status=403, error_type="OWNER_PROTECTION") return api_response(success=False, message=str(e), status=403, error_type="OWNER_PROTECTION")
AuditService.log_action(
action=AuditAction.ORG_MEMBER_REMOVE,
user_id=g.current_user.id,
organization_id=org.id,
resource_type="user",
resource_id=str(user_id),
description=f"Removed user {user_id} from organization",
)
return api_response(message="Member removed successfully") return api_response(message="Member removed successfully")
@@ -74,6 +91,14 @@ def update_member_role(org_id, user_id):
org = OrganizationService.get_organization_by_id(org_id) org = OrganizationService.get_organization_by_id(org_id)
new_role = OrganizationRole(data["role"]) new_role = OrganizationRole(data["role"])
member = OrganizationService.update_member_role(org=org, user_id=user_id, new_role=new_role, updater_id=g.current_user.id) member = OrganizationService.update_member_role(org=org, user_id=user_id, new_role=new_role, updater_id=g.current_user.id)
AuditService.log_action(
action=AuditAction.ORG_MEMBER_ROLE_CHANGE,
user_id=g.current_user.id,
organization_id=org.id,
resource_type="user",
resource_id=str(user_id),
description=f"Changed role for user {user_id} to {new_role.value}",
)
member_dict = member.to_dict() member_dict = member.to_dict()
member_dict["user"] = member.user.to_dict() member_dict["user"] = member.user.to_dict()
return api_response(data={"member": member_dict}, message="Member role updated successfully") return api_response(data={"member": member_dict}, message="Member role updated successfully")
@@ -180,4 +205,13 @@ def send_mfa_reminder(org_id, user_id):
html_body=html_body, html_body=html_body,
) )
AuditService.log_action(
action=AuditAction.ORG_MFA_REMINDER_SENT,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="user",
resource_id=str(user_id),
description=f"MFA reminder sent to {user.email}",
)
return api_response(data={}, message="Reminder sent successfully") return api_response(data={}, message="Reminder sent successfully")
+22 -1
View File
@@ -3,8 +3,9 @@ from flask import g, request
from gatehouse_app.api.v1 import api_v1_bp from gatehouse_app.api.v1 import api_v1_bp
from gatehouse_app.utils.response import api_response from gatehouse_app.utils.response import api_response
from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required
from gatehouse_app.utils.constants import OrganizationRole from gatehouse_app.utils.constants import AuditAction, OrganizationRole
from gatehouse_app.extensions import db from gatehouse_app.extensions import db
from gatehouse_app.services.audit_service import AuditService
@api_v1_bp.route("/organizations/<org_id>/roles", methods=["GET"]) @api_v1_bp.route("/organizations/<org_id>/roles", methods=["GET"])
@@ -59,6 +60,16 @@ def assign_role_to_member(org_id, role_name):
membership.role = new_role membership.role = new_role
db.session.commit() db.session.commit()
AuditService.log_action(
action=AuditAction.ORG_MEMBER_ROLE_CHANGE,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="user",
resource_id=str(target_user_id),
description=f"Role changed to {new_role.value} for user {target_user_id}",
)
return api_response(data={"user_id": target_user_id, "role": new_role.value}, message=f"Role updated to {new_role.value}") return api_response(data={"user_id": target_user_id, "role": new_role.value}, message=f"Role updated to {new_role.value}")
@@ -82,4 +93,14 @@ def remove_role_from_member(org_id, role_name, user_id):
org = OrganizationService.get_organization_by_id(org_id) org = OrganizationService.get_organization_by_id(org_id)
OrganizationService.remove_member(org=org, user_id=user_id, remover_id=g.current_user.id) OrganizationService.remove_member(org=org, user_id=user_id, remover_id=g.current_user.id)
AuditService.log_action(
action=AuditAction.ORG_MEMBER_REMOVE,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="user",
resource_id=str(user_id),
description=f"Member {user_id} removed from organization via role removal",
)
return api_response(data={"user_id": user_id}, message="Member removed from organization") return api_response(data={"user_id": user_id}, message="Member removed from organization")
+65
View File
@@ -10,6 +10,8 @@ from gatehouse_app.services.organization_service import OrganizationService
from gatehouse_app.services.user_service import UserService from gatehouse_app.services.user_service import UserService
from gatehouse_app.exceptions import OrganizationNotFoundError from gatehouse_app.exceptions import OrganizationNotFoundError
from gatehouse_app.extensions import db from gatehouse_app.extensions import db
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.services.audit_service import AuditService
class PrincipalCreateSchema(Schema): class PrincipalCreateSchema(Schema):
@@ -127,6 +129,15 @@ def create_principal(org_id):
db.session.add(principal) db.session.add(principal)
db.session.commit() db.session.commit()
AuditService.log_action(
action=AuditAction.PRINCIPAL_CREATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="principal",
resource_id=str(principal.id),
description=f"Principal '{principal.name}' created",
)
return api_response( return api_response(
data={"principal": principal.to_dict()}, data={"principal": principal.to_dict()},
message="Principal created successfully", message="Principal created successfully",
@@ -255,6 +266,15 @@ def update_principal(org_id, principal_id):
db.session.commit() db.session.commit()
AuditService.log_action(
action=AuditAction.PRINCIPAL_UPDATED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="principal",
resource_id=str(principal.id),
description=f"Principal '{principal.name}' updated",
)
return api_response( return api_response(
data={"principal": principal.to_dict()}, data={"principal": principal.to_dict()},
message="Principal updated successfully", message="Principal updated successfully",
@@ -308,6 +328,15 @@ def delete_principal(org_id, principal_id):
principal.deleted_at = db.func.now() principal.deleted_at = db.func.now()
db.session.commit() db.session.commit()
AuditService.log_action(
action=AuditAction.PRINCIPAL_DELETED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="principal",
resource_id=str(principal.id),
description=f"Principal '{principal.name}' deleted",
)
return api_response( return api_response(
message="Principal deleted successfully", message="Principal deleted successfully",
) )
@@ -476,6 +505,15 @@ def add_principal_member(org_id, principal_id):
db.session.commit() db.session.commit()
AuditService.log_action(
action=AuditAction.PRINCIPAL_MEMBER_ADDED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="user",
resource_id=str(user.id),
description=f"Added user {user.email} to principal '{principal.name}'",
)
member_dict = membership.to_dict() member_dict = membership.to_dict()
member_dict["user"] = user.to_dict() member_dict["user"] = user.to_dict()
@@ -548,6 +586,15 @@ def remove_principal_member(org_id, principal_id, user_id):
membership.deleted_at = db.func.now() membership.deleted_at = db.func.now()
db.session.commit() db.session.commit()
AuditService.log_action(
action=AuditAction.PRINCIPAL_MEMBER_REMOVED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="user",
resource_id=str(user_id),
description=f"Removed user from principal '{principal.name}'",
)
return api_response( return api_response(
message="Member removed successfully", message="Member removed successfully",
) )
@@ -697,6 +744,15 @@ def link_principal_to_department(org_id, principal_id, dept_id):
error_type="SERVER_ERROR", error_type="SERVER_ERROR",
) )
AuditService.log_action(
action=AuditAction.PRINCIPAL_DEPARTMENT_LINKED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="principal",
resource_id=str(principal_id),
description=f"Principal '{principal.name}' linked to department '{dept.name}'",
)
return api_response( return api_response(
data={ data={
"principal": principal.to_dict(), "principal": principal.to_dict(),
@@ -774,6 +830,15 @@ def unlink_principal_from_department(org_id, principal_id, dept_id):
link.deleted_at = db.func.now() link.deleted_at = db.func.now()
db.session.commit() db.session.commit()
AuditService.log_action(
action=AuditAction.PRINCIPAL_DEPARTMENT_UNLINKED,
user_id=g.current_user.id,
organization_id=org_id,
resource_type="principal",
resource_id=str(principal_id),
description=f"Principal '{principal.name}' unlinked from department '{dept.name}'",
)
return api_response( return api_response(
message="Principal unlinked from department successfully", message="Principal unlinked from department successfully",
) )
+2
View File
@@ -9,6 +9,7 @@ from gatehouse_app.utils.response import api_response
from gatehouse_app.services.superadmin_auth_service import SuperadminAuthService from gatehouse_app.services.superadmin_auth_service import SuperadminAuthService
from gatehouse_app.decorators.superadmin import superadmin_required, superadmin_audit_log from gatehouse_app.decorators.superadmin import superadmin_required, superadmin_audit_log
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
from gatehouse_app.utils.constants import AuditAction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -105,6 +106,7 @@ def login():
@superadmin_bp.route("/auth/logout", methods=["POST"]) @superadmin_bp.route("/auth/logout", methods=["POST"])
@superadmin_required @superadmin_required
@superadmin_audit_log(action=AuditAction.USER_LOGOUT, resource_type="session")
def logout(): def logout():
"""Superadmin logout endpoint. """Superadmin logout endpoint.
+10 -6
View File
@@ -15,7 +15,7 @@ def superadmin_required(f):
"""Decorator to require superadmin Bearer token authentication. """Decorator to require superadmin Bearer token authentication.
Extracts token from Authorization: Bearer {token} header, Extracts token from Authorization: Bearer {token} header,
validates the session against SuperadminSession table, validates the session against the unified sessions table,
and sets g.current_superadmin and g.superadmin_session. and sets g.current_superadmin and g.superadmin_session.
Returns 401 if no valid session, 403 if not a superadmin. Returns 401 if no valid session, 403 if not a superadmin.
@@ -46,10 +46,14 @@ def superadmin_required(f):
token = parts[1] token = parts[1]
# Import here to avoid circular imports # Import here to avoid circular imports
from gatehouse_app.models.superadmin import SuperadminSession, Superadmin from gatehouse_app.models.user.session import Session
from gatehouse_app.models.superadmin import Superadmin
from gatehouse_app.utils.constants import SessionType
# Get active session by token # Get active session by token, scoped to superadmin
session = SuperadminSession.query.filter_by(token=token).first() session = Session.query.filter_by(
token=token, owner_type=SessionType.SUPERADMIN
).first()
if not session: if not session:
return api_response( return api_response(
@@ -68,8 +72,8 @@ def superadmin_required(f):
error_type="SESSION_INACTIVE" error_type="SESSION_INACTIVE"
) )
# Get the superadmin # Get the superadmin by owner_id
superadmin = session.superadmin superadmin = Superadmin.query.get(session.owner_id)
if not superadmin: if not superadmin:
return api_response( return api_response(
success=False, success=False,
+87 -2
View File
@@ -1,6 +1,12 @@
"""CORS middleware configuration.""" """CORS middleware configuration."""
import base64
import json
from urllib.parse import parse_qs
from flask import request, make_response from flask import request, make_response
from gatehouse_app.models import OIDCClient
ALLOWED_METHODS = "GET, POST, PUT, PATCH, DELETE, OPTIONS" ALLOWED_METHODS = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
ALLOWED_HEADERS = ( ALLOWED_HEADERS = (
"Content-Type, Authorization, X-Requested-With, X-Request-ID, " "Content-Type, Authorization, X-Requested-With, X-Request-ID, "
@@ -40,6 +46,85 @@ def _cors_origin_header(cors_origins, request_origin):
return None return None
def _get_oidc_client_id_from_request():
"""Extract client_id from OIDC endpoint requests."""
path = request.path
# POST to /oidc/token, /oidc/revoke, /oidc/introspect
if request.method == "POST" and any(
path.endswith(ep) for ep in ("/oidc/token", "/oidc/revoke", "/oidc/introspect")
):
# Try Basic Auth header first
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Basic "):
try:
decoded = base64.b64decode(auth_header[6:]).decode("utf-8")
client_id, _, _ = decoded.partition(":")
if client_id:
return client_id
except Exception:
pass
# Try form body
if request.form:
client_id = request.form.get("client_id")
if client_id:
return client_id
# Try JSON body
if request.is_json:
try:
client_id = request.json.get("client_id")
if client_id:
return client_id
except Exception:
pass
return None
# GET/POST to /oidc/userinfo
if path.endswith("/oidc/userinfo"):
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
token = auth_header[7:]
try:
payload_b64 = token.split(".")[1]
padding = 4 - len(payload_b64) % 4
if padding != 4:
payload_b64 += "=" * padding
payload = json.loads(base64.urlsafe_b64decode(payload_b64))
return payload.get("client_id")
except Exception:
return None
return None
def _get_effective_cors_origins(app, request):
"""Get effective CORS origins, checking per-client config for OIDC endpoints."""
global_origins = app.config.get("CORS_ORIGINS", [])
if "/oidc/" not in request.path:
return global_origins
try:
client_id = _get_oidc_client_id_from_request()
if not client_id:
return global_origins
client = OIDCClient.query.filter_by(client_id=client_id).first()
if not client:
return global_origins
effective = client.get_effective_origins()
if effective is not None:
return effective
except Exception:
pass
return global_origins
def setup_cors(app): def setup_cors(app):
""" """
Configure CORS for the application. Configure CORS for the application.
@@ -54,7 +139,7 @@ def setup_cors(app):
"""Handle CORS preflight OPTIONS requests.""" """Handle CORS preflight OPTIONS requests."""
if request.method == "OPTIONS": if request.method == "OPTIONS":
origin = request.headers.get("Origin") origin = request.headers.get("Origin")
cors_origins = app.config.get("CORS_ORIGINS", []) cors_origins = _get_effective_cors_origins(app, request)
if not _is_origin_allowed(origin, cors_origins): if not _is_origin_allowed(origin, cors_origins):
return None return None
@@ -73,7 +158,7 @@ def setup_cors(app):
def after_request_cors(response): def after_request_cors(response):
"""Add CORS headers to non-preflight responses.""" """Add CORS headers to non-preflight responses."""
origin = request.headers.get("Origin") origin = request.headers.get("Origin")
cors_origins = app.config.get("CORS_ORIGINS", []) cors_origins = _get_effective_cors_origins(app, request)
allow_origin = _cors_origin_header(cors_origins, origin) allow_origin = _cors_origin_header(cors_origins, origin)
if allow_origin: if allow_origin:
-2
View File
@@ -118,7 +118,6 @@ from gatehouse_app.models.zerotier import ( # noqa: F401
from gatehouse_app.models.superadmin import ( # noqa: F401 from gatehouse_app.models.superadmin import ( # noqa: F401
Superadmin, Superadmin,
SuperadminSession, SuperadminSession,
SuperadminSessionStatus,
) )
from gatehouse_app.models.superadmin_audit_log import SuperadminAuditLog # noqa: F401 from gatehouse_app.models.superadmin_audit_log import SuperadminAuditLog # noqa: F401
from gatehouse_app.models.security.user_security_policy import ( # noqa: F401 from gatehouse_app.models.security.user_security_policy import ( # noqa: F401
@@ -186,6 +185,5 @@ __all__ = [
# Superadmin # Superadmin
"Superadmin", "Superadmin",
"SuperadminSession", "SuperadminSession",
"SuperadminSessionStatus",
"SuperadminAuditLog", "SuperadminAuditLog",
] ]
+34
View File
@@ -1,4 +1,6 @@
"""OIDC Client model.""" """OIDC Client model."""
from urllib.parse import urlparse
from gatehouse_app.extensions import db from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import OIDCGrantType, OIDCResponseType from gatehouse_app.utils.constants import OIDCGrantType, OIDCResponseType
@@ -21,6 +23,7 @@ class OIDCClient(BaseModel):
grant_types = db.Column(db.JSON, nullable=False) # Allowed grant types grant_types = db.Column(db.JSON, nullable=False) # Allowed grant types
response_types = db.Column(db.JSON, nullable=False) # Allowed response types response_types = db.Column(db.JSON, nullable=False) # Allowed response types
scopes = db.Column(db.JSON, nullable=False) # Allowed scopes scopes = db.Column(db.JSON, nullable=False) # Allowed scopes
allowed_cors_origins = db.Column(db.JSON, nullable=True, default=None) # Per-client CORS origins
# Client metadata # Client metadata
logo_uri = db.Column(db.String(512), nullable=True) logo_uri = db.Column(db.String(512), nullable=True)
@@ -81,6 +84,37 @@ class OIDCClient(BaseModel):
"""Check if a redirect URI is allowed for this client.""" """Check if a redirect URI is allowed for this client."""
return redirect_uri in self.redirect_uris return redirect_uri in self.redirect_uris
def get_effective_origins(self) -> list | None:
"""Get effective CORS origins for this client.
Returns None to signal "use global config", a derived list from
redirect_uris when "+" is present, or the configured list as-is.
"""
if self.allowed_cors_origins is None:
return None
if "+" in self.allowed_cors_origins:
origins = set()
for uri in self.redirect_uris:
parsed = urlparse(uri)
if parsed.scheme and parsed.hostname:
port = f":{parsed.port}" if parsed.port else ""
origins.add(f"{parsed.scheme}://{parsed.hostname}{port}")
return sorted(origins)
return list(self.allowed_cors_origins)
def is_origin_allowed(self, origin: str) -> bool | None:
"""Check if a browser origin is allowed for CORS.
Returns True/False when a per-client list is configured,
or None to defer to the global CORS policy.
"""
effective = self.get_effective_origins()
if effective is None:
return None
if "*" in effective:
return True
return origin in effective
def has_scope(self, scope: str) -> bool: def has_scope(self, scope: str) -> bool:
"""Check if client is allowed to request a specific scope.""" """Check if client is allowed to request a specific scope."""
return scope in self.scopes return scope in self.scopes
+2 -2
View File
@@ -1,5 +1,5 @@
"""Superadmin models.""" """Superadmin models."""
from gatehouse_app.models.superadmin.superadmin import Superadmin from gatehouse_app.models.superadmin.superadmin import Superadmin
from gatehouse_app.models.superadmin.superadmin_session import SuperadminSession, SuperadminSessionStatus from gatehouse_app.models.user.session import Session as SuperadminSession
__all__ = ["Superadmin", "SuperadminSession", "SuperadminSessionStatus"] __all__ = ["Superadmin", "SuperadminSession"]
@@ -23,11 +23,15 @@ class Superadmin(BaseModel):
is_active = db.Column(db.Boolean, default=True, nullable=False) is_active = db.Column(db.Boolean, default=True, nullable=False)
last_login_at = db.Column(db.DateTime, nullable=True) last_login_at = db.Column(db.DateTime, nullable=True)
# Relationship to sessions # Relationship to sessions (unified model, scoped to superadmin owner_type)
sessions = db.relationship( sessions = db.relationship(
"SuperadminSession", "Session",
back_populates="superadmin", primaryjoin=(
cascade="all, delete-orphan" "and_(Superadmin.id == foreign(Session.owner_id), "
"Session.owner_type == 'superadmin')"
),
cascade="all, delete-orphan",
lazy="dynamic",
) )
# Relationship to audit logs # Relationship to audit logs
@@ -1,80 +0,0 @@
"""Superadmin session model."""
import logging
from datetime import datetime, timezone, timedelta
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
logger = logging.getLogger(__name__)
class SuperadminSessionStatus:
"""Session status constants."""
ACTIVE = "active"
REVOKED = "revoked"
EXPIRED = "expired"
class SuperadminSession(BaseModel):
"""Session model for superadmin authentication."""
__tablename__ = "superadmin_sessions"
superadmin_id = db.Column(
db.String(36),
db.ForeignKey("superadmins.id"),
nullable=False,
index=True
)
token = db.Column(db.String(255), unique=True, nullable=False, index=True)
expires_at = db.Column(db.DateTime, nullable=False)
last_activity_at = db.Column(
db.DateTime,
nullable=False,
default=lambda: datetime.now(timezone.utc)
)
ip_address = db.Column(db.String(45), nullable=True)
user_agent = db.Column(db.Text, nullable=True)
revoked_at = db.Column(db.DateTime, nullable=True)
revoked_reason = db.Column(db.String(255), nullable=True)
# Relationship
superadmin = db.relationship("Superadmin", back_populates="sessions")
def __repr__(self):
return f"<SuperadminSession superadmin_id={self.superadmin_id}>"
def is_active(self):
"""Check if session is currently active."""
now = datetime.now(timezone.utc)
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return (
self.deleted_at is None
and self.revoked_at is None
and expires_at > now
)
def is_expired(self):
"""Check if session has expired."""
now = datetime.now(timezone.utc)
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return now > expires_at
def revoke(self, reason: str = None):
"""Revoke the session."""
self.revoked_at = datetime.now(timezone.utc)
if reason:
self.revoked_reason = reason
from gatehouse_app import db
db.session.commit()
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
exclude.append("token")
return super().to_dict(exclude=exclude)
+54 -12
View File
@@ -3,15 +3,24 @@ from datetime import datetime, timedelta, timezone
from flask import current_app from flask import current_app
from gatehouse_app.extensions import db from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import SessionStatus from gatehouse_app.utils.constants import SessionStatus, SessionType
class Session(BaseModel): class Session(BaseModel):
"""Session model for tracking user sessions.""" """Session model for tracking user and superadmin sessions."""
__tablename__ = "sessions" __tablename__ = "sessions"
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=False, index=True) # Owner discriminator — determines which table the owner_id references
owner_type = db.Column(
db.String(20), nullable=False, default=SessionType.USER, index=True
)
owner_id = db.Column(db.String(36), nullable=False, index=True)
# Legacy column kept for backward compatibility during migration;
# new code should use owner_id / owner_type.
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True, index=True)
token = db.Column(db.String(255), unique=True, nullable=False, index=True) token = db.Column(db.String(255), unique=True, nullable=False, index=True)
status = db.Column(db.Enum(SessionStatus), default=SessionStatus.ACTIVE, nullable=False) status = db.Column(db.Enum(SessionStatus), default=SessionStatus.ACTIVE, nullable=False)
@@ -34,21 +43,37 @@ class Session(BaseModel):
# Relationships # Relationships
user = db.relationship("User", back_populates="sessions") user = db.relationship("User", back_populates="sessions")
# Composite index for owner-scoped queries
__table_args__ = (
db.Index("ix_sessions_owner_type_owner_id", "owner_type", "owner_id"),
)
# ---- Convenience properties ------------------------------------------------
@property
def is_user(self):
return self.owner_type == SessionType.USER
@property
def is_superadmin(self):
return self.owner_type == SessionType.SUPERADMIN
# ---- Core methods ----------------------------------------------------------
def __repr__(self): def __repr__(self):
"""String representation of Session.""" return f"<Session owner_type={self.owner_type} owner_id={self.owner_id} status={self.status}>"
return f"<Session user_id={self.user_id} status={self.status}>"
def is_active(self): def is_active(self):
"""Check if session is currently active. """Check if session is currently active.
Sessions are evaluated against two independent timeouts: User sessions are evaluated against two independent timeouts:
- Idle timeout: expires if no request has been made within - Idle timeout: expires if no request has been made within
SESSION_IDLE_TIMEOUT seconds (default 15 min). SESSION_IDLE_TIMEOUT seconds (default 15 min).
- Absolute timeout: expires if SESSION_ABSOLUTE_TIMEOUT seconds - Absolute timeout: expires if SESSION_ABSOLUTE_TIMEOUT seconds
have elapsed since the session was created (default 8 h), have elapsed since the session was created (default 8 h).
regardless of activity.
A session must satisfy *both* constraints to remain active. Superadmin sessions use absolute timeout only (no idle timeout).
A session must satisfy *all* applicable constraints to remain active.
""" """
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
created_at = self.created_at created_at = self.created_at
@@ -59,12 +84,21 @@ class Session(BaseModel):
if last_activity_at.tzinfo is None: if last_activity_at.tzinfo is None:
last_activity_at = last_activity_at.replace(tzinfo=timezone.utc) last_activity_at = last_activity_at.replace(tzinfo=timezone.utc)
idle_timeout = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
absolute_timeout = current_app.config.get("SESSION_ABSOLUTE_TIMEOUT", 28800) absolute_timeout = current_app.config.get("SESSION_ABSOLUTE_TIMEOUT", 28800)
idle_expires_at = last_activity_at + timedelta(seconds=idle_timeout)
absolute_expires_at = created_at + timedelta(seconds=absolute_timeout) absolute_expires_at = created_at + timedelta(seconds=absolute_timeout)
if self.is_superadmin:
# Superadmin: absolute timeout only
return (
self.status == SessionStatus.ACTIVE
and now < absolute_expires_at
and self.deleted_at is None
)
# User: idle + absolute timeout
idle_timeout = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
idle_expires_at = last_activity_at + timedelta(seconds=idle_timeout)
return ( return (
self.status == SessionStatus.ACTIVE self.status == SessionStatus.ACTIVE
and now < idle_expires_at and now < idle_expires_at
@@ -83,6 +117,8 @@ class Session(BaseModel):
capped so that the session never exceeds the absolute lifetime capped so that the session never exceeds the absolute lifetime
(``created_at + absolute timeout``). (``created_at + absolute timeout``).
Superadmin sessions only update last_activity_at (no sliding window).
Args: Args:
duration_seconds: Override for the idle timeout. When *None* duration_seconds: Override for the idle timeout. When *None*
(the common case), the value is read from (the common case), the value is read from
@@ -90,6 +126,12 @@ class Session(BaseModel):
""" """
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
if self.is_superadmin:
# Superadmin: just bump last_activity_at, no sliding window
self.last_activity_at = now
db.session.commit()
return
if duration_seconds is None: if duration_seconds is None:
duration_seconds = current_app.config.get("SESSION_IDLE_TIMEOUT", 900) duration_seconds = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
+3 -1
View File
@@ -8,7 +8,7 @@ from gatehouse_app.extensions import db, bcrypt
from gatehouse_app.models.user.user import User from gatehouse_app.models.user.user import User
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.models.user.session import Session from gatehouse_app.models.user.session import Session
from gatehouse_app.utils.constants import AuthMethodType, SessionStatus, UserStatus, AuditAction from gatehouse_app.utils.constants import AuthMethodType, SessionStatus, SessionType, UserStatus, AuditAction
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError, AccountSuspendedError, AccountInactiveError from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError, AccountSuspendedError, AccountInactiveError
from gatehouse_app.exceptions.validation_exceptions import EmailAlreadyExistsError from gatehouse_app.exceptions.validation_exceptions import EmailAlreadyExistsError
from gatehouse_app.services.audit_service import AuditService from gatehouse_app.services.audit_service import AuditService
@@ -165,6 +165,8 @@ class AuthService:
# Create session # Create session
session = Session( session = Session(
owner_type=SessionType.USER,
owner_id=user.id,
user_id=user.id, user_id=user.id,
token=token, token=token,
status=SessionStatus.ACTIVE, status=SessionStatus.ACTIVE,
+64 -11
View File
@@ -1,7 +1,7 @@
"""Session service.""" """Session service."""
from datetime import datetime, timezone from datetime import datetime, timezone
from gatehouse_app.models.user.session import Session from gatehouse_app.models.user.session import Session
from gatehouse_app.utils.constants import SessionStatus from gatehouse_app.utils.constants import SessionStatus, SessionType
class SessionService: class SessionService:
@@ -28,18 +28,22 @@ class SessionService:
).first() ).first()
@staticmethod @staticmethod
def get_user_sessions(user_id, active_only=True): def get_owner_sessions(owner_type, owner_id, active_only=True):
""" """Get all sessions for an owner (user or superadmin).
Get all sessions for a user.
Args: Args:
user_id: User ID owner_type: SessionType.USER or SessionType.SUPERADMIN
owner_id: Owner ID
active_only: If True, only return active sessions active_only: If True, only return active sessions
Returns: Returns:
List of Session instances List of Session instances
""" """
query = Session.query.filter_by(user_id=user_id, deleted_at=None) query = Session.query.filter_by(
owner_type=owner_type,
owner_id=owner_id,
deleted_at=None,
)
if active_only: if active_only:
query = query.filter_by(status=SessionStatus.ACTIVE).filter( query = query.filter_by(status=SessionStatus.ACTIVE).filter(
@@ -49,18 +53,67 @@ class SessionService:
return query.all() return query.all()
@staticmethod @staticmethod
def revoke_user_sessions(user_id, reason="User logged out from all devices"): def get_user_sessions(user_id, active_only=True):
"""Get all sessions for a user.
Args:
user_id: User ID
active_only: If True, only return active sessions
Returns:
List of Session instances
""" """
Revoke all active sessions for a user. return SessionService.get_owner_sessions(
SessionType.USER, user_id, active_only=active_only
)
@staticmethod
def get_superadmin_sessions(superadmin_id, active_only=True):
"""Get all sessions for a superadmin.
Args:
superadmin_id: Superadmin ID
active_only: If True, only return active sessions
Returns:
List of Session instances
"""
return SessionService.get_owner_sessions(
SessionType.SUPERADMIN, superadmin_id, active_only=active_only
)
@staticmethod
def revoke_owner_sessions(owner_type, owner_id, reason="Logged out from all devices"):
"""Revoke all active sessions for an owner.
Args:
owner_type: SessionType.USER or SessionType.SUPERADMIN
owner_id: Owner ID
reason: Reason for revocation
"""
sessions = SessionService.get_owner_sessions(owner_type, owner_id, active_only=True)
for session in sessions:
session.revoke(reason=reason)
@staticmethod
def revoke_user_sessions(user_id, reason="User logged out from all devices"):
"""Revoke all active sessions for a user.
Args: Args:
user_id: User ID user_id: User ID
reason: Reason for revocation reason: Reason for revocation
""" """
sessions = SessionService.get_user_sessions(user_id, active_only=True) SessionService.revoke_owner_sessions(SessionType.USER, user_id, reason=reason)
for session in sessions: @staticmethod
session.revoke(reason=reason) def revoke_superadmin_sessions(superadmin_id, reason="Superadmin logged out"):
"""Revoke all active sessions for a superadmin.
Args:
superadmin_id: Superadmin ID
reason: Reason for revocation
"""
SessionService.revoke_owner_sessions(SessionType.SUPERADMIN, superadmin_id, reason=reason)
@staticmethod @staticmethod
def cleanup_expired_sessions(): def cleanup_expired_sessions():
@@ -6,7 +6,9 @@ from typing import Optional
from flask import request, current_app from flask import request, current_app
from gatehouse_app.extensions import db, bcrypt from gatehouse_app.extensions import db, bcrypt
from gatehouse_app.models.superadmin import Superadmin, SuperadminSession from gatehouse_app.models.superadmin import Superadmin
from gatehouse_app.models.user.session import Session
from gatehouse_app.utils.constants import SessionType
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
@@ -70,15 +72,17 @@ class SuperadminAuthService:
duration_seconds: Session duration in seconds (default 8 hours) duration_seconds: Session duration in seconds (default 8 hours)
Returns: Returns:
SuperadminSession instance Session instance
""" """
# Generate secure token # Generate secure token
token = secrets.token_urlsafe(32) token = secrets.token_urlsafe(32)
# Create session # Create session using unified model
session = SuperadminSession( session = Session(
superadmin_id=superadmin_id, owner_type=SessionType.SUPERADMIN,
owner_id=superadmin_id,
token=token, token=token,
status="active",
expires_at=datetime.now(timezone.utc) + timedelta(seconds=duration_seconds), expires_at=datetime.now(timezone.utc) + timedelta(seconds=duration_seconds),
last_activity_at=datetime.now(timezone.utc), last_activity_at=datetime.now(timezone.utc),
ip_address=request.remote_addr, ip_address=request.remote_addr,
@@ -97,7 +101,9 @@ class SuperadminAuthService:
session_id: Session ID to revoke session_id: Session ID to revoke
reason: Optional revocation reason reason: Optional revocation reason
""" """
session = SuperadminSession.query.get(session_id) session = Session.query.filter_by(
id=session_id, owner_type=SessionType.SUPERADMIN
).first()
if session: if session:
session.revoke(reason=reason) session.revoke(reason=reason)
logger.info(f"[SuperadminAuth] Session {session_id} revoked: {reason or 'No reason'}") logger.info(f"[SuperadminAuth] Session {session_id} revoked: {reason or 'No reason'}")
@@ -111,9 +117,11 @@ class SuperadminAuthService:
except_token: Optional token to keep (current session) except_token: Optional token to keep (current session)
reason: Optional revocation reason reason: Optional revocation reason
""" """
query = SuperadminSession.query.filter_by(superadmin_id=superadmin_id) query = Session.query.filter_by(
owner_type=SessionType.SUPERADMIN, owner_id=superadmin_id
)
if except_token: if except_token:
query = query.filter(SuperadminSession.token != except_token) query = query.filter(Session.token != except_token)
sessions = query.all() sessions = query.all()
for session in sessions: for session in sessions:
+28
View File
@@ -52,6 +52,13 @@ class SessionStatus(str, Enum):
REVOKED = "revoked" REVOKED = "revoked"
class SessionType(str, Enum):
"""Session owner type discriminator."""
USER = "user"
SUPERADMIN = "superadmin"
class AuditAction(str, Enum): class AuditAction(str, Enum):
"""Audit log action types.""" """Audit log action types."""
@@ -154,6 +161,27 @@ class AuditAction(str, Enum):
DEPARTMENT_DELETED = "department.deleted" DEPARTMENT_DELETED = "department.deleted"
DEPARTMENT_MEMBER_ADDED = "department.member.added" DEPARTMENT_MEMBER_ADDED = "department.member.added"
DEPARTMENT_MEMBER_REMOVED = "department.member.removed" DEPARTMENT_MEMBER_REMOVED = "department.member.removed"
DEPARTMENT_CERT_POLICY_UPDATED = "department.cert_policy.updated"
# Organization invite actions
ORG_INVITE_CANCELLED = "org.invite.cancelled"
# MFA reminder
ORG_MFA_REMINDER_SENT = "org.mfa_reminder.sent"
# API key actions
ORG_API_KEY_CREATED = "org.api_key.created"
ORG_API_KEY_UPDATED = "org.api_key.updated"
ORG_API_KEY_DELETED = "org.api_key.deleted"
# OIDC client actions
ORG_CLIENT_CREATED = "org.client.created"
ORG_CLIENT_UPDATED = "org.client.updated"
ORG_CLIENT_DEACTIVATED = "org.client.deactivated"
# Principal department link actions
PRINCIPAL_DEPARTMENT_LINKED = "principal.department.linked"
PRINCIPAL_DEPARTMENT_UNLINKED = "principal.department.unlinked"
class OIDCGrantType(str, Enum): class OIDCGrantType(str, Enum):
@@ -0,0 +1,24 @@
"""Add allowed_cors_origins to oidc_clients.
Revision ID: b7e3f1a92c4d
Revises: a1b2c3d4e5f6
Create Date: 2026-04-27 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'b7e3f1a92c4d'
down_revision = 'a1b2c3d4e5f6'
branch_labels = None
depends_on = None
def upgrade():
op.add_column('oidc_clients', sa.Column('allowed_cors_origins', sa.JSON(), nullable=True))
def downgrade():
op.drop_column('oidc_clients', 'allowed_cors_origins')
@@ -0,0 +1,122 @@
"""Consolidate user and superadmin sessions into unified sessions table.
Revision ID: c8d2e4f6a1b3
Revises: b7e3f1a92c4d
Create Date: 2026-04-28 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'c8d2e4f6a1b3'
down_revision = 'b7e3f1a92c4d'
branch_labels = None
depends_on = None
def upgrade():
# 1. Add new columns (nullable initially for data migration)
op.add_column('sessions', sa.Column('owner_type', sa.String(20), nullable=True))
op.add_column('sessions', sa.Column('owner_id', sa.String(36), nullable=True))
# 2. Backfill existing user sessions: owner_type = 'user', owner_id = user_id
op.execute("""
UPDATE sessions
SET owner_type = 'user',
owner_id = user_id
WHERE owner_type IS NULL
""")
# 3. Migrate superadmin sessions into the sessions table
op.execute("""
INSERT INTO sessions (
id, owner_type, owner_id, token, status,
ip_address, user_agent, device_info,
expires_at, last_activity_at, revoked_at, revoked_reason,
is_compliance_only, created_at, updated_at, deleted_at
)
SELECT
id, 'superadmin', superadmin_id, token, 'active',
ip_address, user_agent, NULL,
expires_at, last_activity_at, revoked_at, revoked_reason,
FALSE, created_at, updated_at, deleted_at
FROM superadmin_sessions
""")
# 4. Make owner_type and owner_id NOT NULL
op.alter_column('sessions', 'owner_type', nullable=False)
op.alter_column('sessions', 'owner_id', nullable=False)
# 5. Make user_id nullable (no longer the sole owner reference)
op.alter_column('sessions', 'user_id', nullable=True)
# 6. Create indexes for efficient owner-scoped queries
op.create_index(
'ix_sessions_owner_type_owner_id',
'sessions',
['owner_type', 'owner_id']
)
op.create_index(
'ix_sessions_owner_type',
'sessions',
['owner_type']
)
op.create_index(
'ix_sessions_owner_id',
'sessions',
['owner_id']
)
# 7. Drop the now-redundant superadmin_sessions table
op.drop_table('superadmin_sessions')
def downgrade():
# 1. Recreate superadmin_sessions table
op.create_table(
'superadmin_sessions',
sa.Column('id', sa.String(36), primary_key=True),
sa.Column('superadmin_id', sa.String(36), sa.ForeignKey('superadmins.id'), nullable=False, index=True),
sa.Column('token', sa.String(255), unique=True, nullable=False, index=True),
sa.Column('expires_at', sa.DateTime, nullable=False),
sa.Column('last_activity_at', sa.DateTime, nullable=False),
sa.Column('ip_address', sa.String(45), nullable=True),
sa.Column('user_agent', sa.Text, nullable=True),
sa.Column('revoked_at', sa.DateTime, nullable=True),
sa.Column('revoked_reason', sa.String(255), nullable=True),
sa.Column('created_at', sa.DateTime, nullable=False),
sa.Column('updated_at', sa.DateTime, nullable=False),
sa.Column('deleted_at', sa.DateTime, nullable=True),
)
# 2. Move superadmin sessions back to superadmin_sessions
op.execute("""
INSERT INTO superadmin_sessions (
id, superadmin_id, token, expires_at, last_activity_at,
ip_address, user_agent, revoked_at, revoked_reason,
created_at, updated_at, deleted_at
)
SELECT
id, owner_id, token, expires_at, last_activity_at,
ip_address, user_agent, revoked_at, revoked_reason,
created_at, updated_at, deleted_at
FROM sessions
WHERE owner_type = 'superadmin'
""")
# 3. Remove superadmin sessions from sessions table
op.execute("DELETE FROM sessions WHERE owner_type = 'superadmin'")
# 4. Drop indexes
op.drop_index('ix_sessions_owner_id', table_name='sessions')
op.drop_index('ix_sessions_owner_type', table_name='sessions')
op.drop_index('ix_sessions_owner_type_owner_id', table_name='sessions')
# 5. Remove new columns
op.drop_column('sessions', 'owner_id')
op.drop_column('sessions', 'owner_type')
# 6. Make user_id NOT NULL again
op.alter_column('sessions', 'user_id', nullable=False)
@@ -0,0 +1,186 @@
"""Superadmin session timeout integration tests.
Validates the absolute-only timeout policy for superadmin sessions.
Superadmin sessions do NOT have idle timeout — only absolute timeout.
"""
import pytest
import uuid
from datetime import datetime, timedelta, timezone
from tests.integration.client.base import ApiError
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def assert_success(response: dict, message_contains: str = "") -> dict:
"""Assert that an api_response-wrapped payload succeeded."""
data = response.get("data", {})
assert response.get("success") is not False, (
f"Expected success but got error: {response.get('message')}"
)
if message_contains:
assert message_contains.lower() in response.get("message", "").lower(), (
f"Expected message to contain '{message_contains}' but got: {response.get('message')}"
)
return data
def _get_session_row(integration_app, token: str):
"""Look up the Session model row for a given bearer token."""
from gatehouse_app.models.user.session import Session
with integration_app.app_context():
return Session.query.filter_by(token=token).first()
def _touch_session(integration_app, session_id: str, **updates):
"""Directly update columns on a Session row.
Only use this to simulate the passage of time — never to assert
internal state.
"""
from gatehouse_app.models.user.session import Session
with integration_app.app_context():
sess = Session.query.get(session_id)
for attr, value in updates.items():
setattr(sess, attr, value)
from gatehouse_app import db
db.session.commit()
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def superadmin_credentials(integration_app):
"""Create a superadmin and return login credentials."""
from gatehouse_app.services.superadmin_auth_service import SuperadminAuthService
email = f"admin_{uuid.uuid4().hex[:8]}@gatehouse.local"
password = "SuperAdmin123!"
with integration_app.app_context():
sa = SuperadminAuthService.create_superadmin(
email=email,
credential=password,
full_name="Test Superadmin",
)
return {"id": str(sa.id), "email": email, "password": password}
@pytest.fixture
def logged_in_superadmin(integration_client, superadmin_credentials, integration_app):
"""Log in as superadmin and return session metadata.
Returns dict with ``superadmin``, ``token``, ``session_id``, ``session_row``.
"""
creds = superadmin_credentials
resp = integration_client.post(
"/api/v1/superadmin/auth/login",
data={"email": creds["email"], "password": creds["password"]},
)
data = assert_success(resp)
token = data["token"]
session_row = _get_session_row(integration_app, token)
assert session_row is not None, "Session row should exist after superadmin login"
return {
"superadmin": creds,
"token": token,
"session_id": session_row.id,
"session_row": session_row,
}
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestSuperadminSessionTimeouts:
"""Absolute-only timeout behavior for superadmin sessions."""
def test_superadmin_session_valid_before_timeout(
self, integration_client, logged_in_superadmin,
):
"""SA-SESS-01 — Fresh superadmin session is accepted."""
integration_client.set_token(logged_in_superadmin["token"])
result = integration_client.get("/api/v1/superadmin/auth/me")
data = assert_success(result)
assert "superadmin" in data
def test_absolute_timeout_rejects_superadmin(
self, integration_client, logged_in_superadmin, integration_app,
):
"""SA-SESS-02 — Superadmin session rejected after absolute timeout.
Push ``created_at`` far into the past. The session must be
rejected even though ``last_activity_at`` is fresh.
"""
_touch_session(
integration_app,
logged_in_superadmin["session_id"],
created_at=datetime.now(timezone.utc) - timedelta(days=1),
last_activity_at=datetime.now(timezone.utc),
)
integration_client.set_token(logged_in_superadmin["token"])
with pytest.raises(ApiError) as exc_info:
integration_client.get("/api/v1/superadmin/auth/me")
assert exc_info.value.status_code == 401
def test_idle_timeout_does_NOT_reject_superadmin(
self, integration_client, logged_in_superadmin, integration_app,
):
"""SA-SESS-03 — Superadmin sessions have NO idle timeout.
Push ``last_activity_at`` far into the past but keep
``created_at`` recent. The session should still be valid
because superadmin sessions only use absolute timeout.
"""
_touch_session(
integration_app,
logged_in_superadmin["session_id"],
last_activity_at=datetime.now(timezone.utc) - timedelta(hours=1),
)
integration_client.set_token(logged_in_superadmin["token"])
result = integration_client.get("/api/v1/superadmin/auth/me")
data = assert_success(result)
assert "superadmin" in data
def test_revoked_superadmin_session_rejected(
self, integration_client, logged_in_superadmin,
):
"""SA-SESS-04 — Revoked superadmin session is rejected."""
integration_client.set_token(logged_in_superadmin["token"])
# Logout revokes the session
integration_client.post("/api/v1/superadmin/auth/logout")
integration_client.clear_token()
# Try using the old token
integration_client.set_token(logged_in_superadmin["token"])
with pytest.raises(ApiError) as exc_info:
integration_client.get("/api/v1/superadmin/auth/me")
assert exc_info.value.status_code == 401
def test_superadmin_session_has_owner_type(
self, integration_app, logged_in_superadmin,
):
"""SA-SESS-05 — Superadmin session row has owner_type='superadmin'."""
from gatehouse_app.models.user.session import Session
from gatehouse_app.utils.constants import SessionType
with integration_app.app_context():
sess = Session.query.get(logged_in_superadmin["session_id"])
assert sess is not None
assert sess.owner_type == SessionType.SUPERADMIN
assert sess.owner_id == logged_in_superadmin["superadmin"]["id"]
+503
View File
@@ -0,0 +1,503 @@
"""Unit tests for per-client CORS feature.
WHAT: Tests for per-client CORS origin resolution, including OIDCClient
model methods, request client_id extraction, effective origin
resolution, and integration with the CORS middleware.
WHY: Per-client CORS prevents one OIDC client from making cross-origin
requests meant for another; misconfiguration breaks browser flows.
EXPECTED: Correct origin derivation, proper client_id extraction, and
correct CORS headers on OIDC endpoints.
"""
import base64
import json
from unittest.mock import patch
import pytest
from flask import Flask, request as flask_request
import gatehouse_app.middleware.cors as cors_module
from gatehouse_app.middleware.cors import (
_get_oidc_client_id_from_request,
_get_effective_cors_origins,
setup_cors,
)
# ---------------------------------------------------------------------------
# Helper: build a lightweight stub that quacks like OIDCClient
# ---------------------------------------------------------------------------
class StubClient:
"""Minimal stand-in for OIDCClient -- no SQLAlchemy, no DB needed."""
def __init__(self, *, allowed_cors_origins=None, redirect_uris=None):
self.allowed_cors_origins = allowed_cors_origins
self.redirect_uris = redirect_uris or []
def get_effective_origins(self):
from urllib.parse import urlparse
if self.allowed_cors_origins is None:
return None
if "+" in self.allowed_cors_origins:
origins = set()
for uri in self.redirect_uris:
parsed = urlparse(uri)
if parsed.scheme and parsed.hostname:
port = f":{parsed.port}" if parsed.port else ""
origins.add(f"{parsed.scheme}://{parsed.hostname}{port}")
return sorted(origins)
return list(self.allowed_cors_origins)
def is_origin_allowed(self, origin):
effective = self.get_effective_origins()
if effective is None:
return None
if "*" in effective:
return True
return origin in effective
def _basic_auth_header(client_id, secret="secret"):
"""Return a 'Basic <b64>' Authorization header value."""
return "Basic " + base64.b64encode(f"{client_id}:{secret}".encode()).decode()
# ---------------------------------------------------------------------------
# OIDCClient.get_effective_origins
# ---------------------------------------------------------------------------
class TestGetEffectiveOrigins:
def test_returns_none_when_allowed_cors_origins_is_none(self):
"""TEST: PCORS-GE-01 -- None config signals 'use global'."""
client = StubClient(allowed_cors_origins=None)
assert client.get_effective_origins() is None
def test_derives_from_redirect_uris_when_plus_sign(self):
"""TEST: PCORS-GE-02 -- '+' in list derives origins from redirect_uris."""
client = StubClient(
allowed_cors_origins=["+"],
redirect_uris=[
"https://app.example.com/callback",
"http://localhost:3000/callback",
],
)
assert client.get_effective_origins() == [
"http://localhost:3000",
"https://app.example.com",
]
def test_derives_with_port(self):
"""TEST: PCORS-GE-03 -- Non-standard ports are preserved."""
client = StubClient(
allowed_cors_origins=["+"],
redirect_uris=["https://app.example.com:8443/cb"],
)
assert client.get_effective_origins() == ["https://app.example.com:8443"]
def test_deduplicates_derived_origins(self):
"""TEST: PCORS-GE-04 -- Duplicate redirect URIs produce unique origins."""
client = StubClient(
allowed_cors_origins=["+"],
redirect_uris=[
"https://app.example.com/cb1",
"https://app.example.com/cb2",
],
)
assert client.get_effective_origins() == ["https://app.example.com"]
def test_returns_list_as_is_when_normal_list(self):
"""TEST: PCORS-GE-05 -- Normal list is returned unchanged."""
origins = ["https://a.com", "https://b.com"]
client = StubClient(allowed_cors_origins=origins)
assert client.get_effective_origins() == origins
def test_returns_wildcard_list_as_is(self):
"""TEST: PCORS-GE-06 -- ['*'] is returned (handled downstream)."""
client = StubClient(allowed_cors_origins=["*"])
assert client.get_effective_origins() == ["*"]
def test_empty_redirect_uris_with_plus_returns_empty(self):
"""TEST: PCORS-GE-07 -- '+' with empty redirect_uris yields empty list."""
client = StubClient(
allowed_cors_origins=["+"],
redirect_uris=[],
)
assert client.get_effective_origins() == []
def test_skips_malformed_redirect_uris(self):
"""TEST: PCORS-GE-08 -- URIs without scheme/host are skipped."""
client = StubClient(
allowed_cors_origins=["+"],
redirect_uris=["not-a-uri", "https://good.com/cb"],
)
assert client.get_effective_origins() == ["https://good.com"]
# ---------------------------------------------------------------------------
# OIDCClient.is_origin_allowed
# ---------------------------------------------------------------------------
class TestIsOriginAllowed:
def test_returns_none_when_no_per_client_config(self):
"""TEST: PCORS-IO-01 -- None config defers to global CORS."""
client = StubClient(allowed_cors_origins=None)
assert client.is_origin_allowed("https://anything.com") is None
def test_returns_true_when_wildcard(self):
"""TEST: PCORS-IO-02 -- '*' in effective origins allows any origin."""
client = StubClient(allowed_cors_origins=["*"])
assert client.is_origin_allowed("https://evil.com") is True
def test_returns_true_for_matching_origin(self):
"""TEST: PCORS-IO-03 -- Matching origin is allowed."""
client = StubClient(
allowed_cors_origins=["https://app.example.com", "https://other.com"],
)
assert client.is_origin_allowed("https://app.example.com") is True
def test_returns_false_for_non_matching_origin(self):
"""TEST: PCORS-IO-04 -- Non-matching origin is rejected."""
client = StubClient(allowed_cors_origins=["https://app.example.com"])
assert client.is_origin_allowed("https://evil.com") is False
def test_returns_false_for_empty_list(self):
"""TEST: PCORS-IO-05 -- Empty list rejects everything."""
client = StubClient(allowed_cors_origins=[])
assert client.is_origin_allowed("https://anything.com") is False
# ---------------------------------------------------------------------------
# _get_oidc_client_id_from_request
# ---------------------------------------------------------------------------
class TestGetOidcClientIdFromRequest:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
def test_extracts_from_basic_auth(self, app):
"""TEST: PCORS-CI-01 -- Basic Auth header yields client_id."""
with app.test_request_context(
"/oidc/token",
method="POST",
headers={"Authorization": _basic_auth_header("my-client")},
):
assert _get_oidc_client_id_from_request() == "my-client"
def test_extracts_from_form_body(self, app):
"""TEST: PCORS-CI-02 -- Form-encoded body yields client_id."""
with app.test_request_context(
"/oidc/token",
method="POST",
data={"client_id": "form-client", "grant_type": "client_credentials"},
):
assert _get_oidc_client_id_from_request() == "form-client"
def test_extracts_from_json_body(self, app):
"""TEST: PCORS-CI-03 -- JSON body yields client_id."""
with app.test_request_context(
"/oidc/token",
method="POST",
data=json.dumps({"client_id": "json-client", "grant_type": "client_credentials"}),
content_type="application/json",
):
assert _get_oidc_client_id_from_request() == "json-client"
def test_extracts_from_bearer_jwt(self, app):
"""TEST: PCORS-CI-04 -- Bearer JWT payload yields client_id."""
payload = base64.urlsafe_b64encode(
json.dumps({"client_id": "jwt-client"}).encode()
).rstrip(b"=").decode()
token = f"eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.{payload}.sig"
with app.test_request_context(
"/oidc/userinfo",
method="GET",
headers={"Authorization": f"Bearer {token}"},
):
assert _get_oidc_client_id_from_request() == "jwt-client"
def test_returns_none_for_non_oidc_endpoint(self, app):
"""TEST: PCORS-CI-05 -- Non-OIDC path returns None."""
with app.test_request_context(
"/api/v1/users",
method="GET",
headers={"Authorization": _basic_auth_header("x")},
):
assert _get_oidc_client_id_from_request() is None
def test_returns_none_when_no_client_id_found(self, app):
"""TEST: PCORS-CI-06 -- OIDC token endpoint with no credentials returns None."""
with app.test_request_context(
"/oidc/token",
method="POST",
data={"grant_type": "client_credentials"},
):
assert _get_oidc_client_id_from_request() is None
def test_extracts_from_revoke_endpoint(self, app):
"""TEST: PCORS-CI-07 -- /oidc/revoke also accepts Basic Auth."""
with app.test_request_context(
"/oidc/revoke",
method="POST",
headers={"Authorization": _basic_auth_header("rev-client")},
):
assert _get_oidc_client_id_from_request() == "rev-client"
def test_extracts_from_introspect_endpoint(self, app):
"""TEST: PCORS-CI-08 -- /oidc/introspect also accepts Basic Auth."""
with app.test_request_context(
"/oidc/introspect",
method="POST",
headers={"Authorization": _basic_auth_header("int-client")},
):
assert _get_oidc_client_id_from_request() == "int-client"
def test_returns_none_for_options_preflight(self, app):
"""TEST: PCORS-CI-09 -- OPTIONS preflight cannot carry client credentials."""
with app.test_request_context(
"/oidc/token",
method="OPTIONS",
headers={
"Origin": "https://app.com",
"Access-Control-Request-Method": "POST",
},
):
assert _get_oidc_client_id_from_request() is None
# ---------------------------------------------------------------------------
# _get_effective_cors_origins
# ---------------------------------------------------------------------------
class TestGetEffectiveCorsOrigins:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
app.config["CORS_ORIGINS"] = ["https://global.com"]
return app
def test_global_config_for_non_oidc_endpoint(self, app):
"""TEST: PCORS-EO-01 -- Non-OIDC path always uses global config."""
with app.test_request_context("/api/v1/users", method="GET"):
result = _get_effective_cors_origins(app, flask_request)
assert result == ["https://global.com"]
def test_per_client_origins_for_oidc_endpoint(self, app):
"""TEST: PCORS-EO-02 -- OIDC endpoint with configured client uses per-client origins."""
fake_client = StubClient(allowed_cors_origins=["https://client.com"])
with app.test_request_context(
"/oidc/token",
method="POST",
headers={"Authorization": _basic_auth_header("test-client")},
):
with patch.object(cors_module, "OIDCClient") as MockModel:
MockModel.query.filter_by.return_value.first.return_value = fake_client
result = _get_effective_cors_origins(app, flask_request)
assert result == ["https://client.com"]
def test_fallback_when_client_not_found(self, app):
"""TEST: PCORS-EO-03 -- Unknown client_id falls back to global config."""
with app.test_request_context(
"/oidc/token",
method="POST",
headers={"Authorization": _basic_auth_header("unknown")},
):
with patch.object(cors_module, "OIDCClient") as MockModel:
MockModel.query.filter_by.return_value.first.return_value = None
result = _get_effective_cors_origins(app, flask_request)
assert result == ["https://global.com"]
def test_fallback_when_allowed_cors_origins_is_none(self, app):
"""TEST: PCORS-EO-04 -- Client with None origins falls back to global."""
fake_client = StubClient(allowed_cors_origins=None)
with app.test_request_context(
"/oidc/token",
method="POST",
headers={"Authorization": _basic_auth_header("test-client")},
):
with patch.object(cors_module, "OIDCClient") as MockModel:
MockModel.query.filter_by.return_value.first.return_value = fake_client
result = _get_effective_cors_origins(app, flask_request)
assert result == ["https://global.com"]
def test_fallback_on_db_error(self, app):
"""TEST: PCORS-EO-05 -- Database exception falls back to global config."""
with app.test_request_context(
"/oidc/token",
method="POST",
headers={"Authorization": _basic_auth_header("test-client")},
):
with patch.object(cors_module, "OIDCClient") as MockModel:
MockModel.query.filter_by.side_effect = Exception("DB down")
result = _get_effective_cors_origins(app, flask_request)
assert result == ["https://global.com"]
def test_fallback_when_no_client_id_extracted(self, app):
"""TEST: PCORS-EO-06 -- OIDC path with no credentials falls back to global."""
with app.test_request_context(
"/oidc/token",
method="POST",
data={"grant_type": "client_credentials"},
):
result = _get_effective_cors_origins(app, flask_request)
assert result == ["https://global.com"]
# ---------------------------------------------------------------------------
# Integration: OIDC endpoint CORS headers
# ---------------------------------------------------------------------------
class TestOidcEndpointCorsIntegration:
@pytest.fixture
def app_with_global_and_client(self):
"""Flask app with global CORS and route stubs for integration tests."""
app = Flask(__name__)
app.config["TESTING"] = True
app.config["CORS_ORIGINS"] = ["https://global.com"]
app.config["CORS_SUPPORTS_CREDENTIALS"] = True
@app.route("/oidc/token", methods=["POST", "OPTIONS"])
def oidc_token():
return {"status": "ok"}, 200
@app.route("/api/v1/users", methods=["GET", "OPTIONS"])
def api_users():
return {"users": []}, 200
setup_cors(app)
return app
def test_post_oidc_with_per_client_origin_includes_cors_headers(
self, app_with_global_and_client
):
"""TEST: PCORS-INT-01 -- POST to /oidc/token with per-client origin and
Basic Auth includes CORS headers. Per-client CORS applies to the actual
request (which carries credentials), not the preflight."""
fake_client = StubClient(allowed_cors_origins=["https://client-app.com"])
with patch.object(cors_module, "OIDCClient") as MockModel:
MockModel.query.filter_by.return_value.first.return_value = fake_client
with app_with_global_and_client.test_client() as client:
resp = client.post(
"/oidc/token",
headers={
"Origin": "https://client-app.com",
"Authorization": _basic_auth_header("test-client"),
},
)
assert resp.status_code == 200
assert (
resp.headers.get("Access-Control-Allow-Origin")
== "https://client-app.com"
)
assert resp.headers.get("Access-Control-Allow-Credentials") == "true"
def test_post_oidc_with_non_matching_per_client_origin_no_cors_headers(
self, app_with_global_and_client
):
"""TEST: PCORS-INT-02 -- POST with origin not in per-client list has no
CORS headers."""
fake_client = StubClient(allowed_cors_origins=["https://allowed.com"])
with patch.object(cors_module, "OIDCClient") as MockModel:
MockModel.query.filter_by.return_value.first.return_value = fake_client
with app_with_global_and_client.test_client() as client:
resp = client.post(
"/oidc/token",
headers={
"Origin": "https://evil.com",
"Authorization": _basic_auth_header("test-client"),
},
)
assert resp.status_code == 200
assert resp.headers.get("Access-Control-Allow-Origin") is None
def test_post_oidc_wildcard_client_echoes_origin(self, app_with_global_and_client):
"""TEST: PCORS-INT-03 -- Client with '*' echoes the request origin."""
fake_client = StubClient(allowed_cors_origins=["*"])
with patch.object(cors_module, "OIDCClient") as MockModel:
MockModel.query.filter_by.return_value.first.return_value = fake_client
with app_with_global_and_client.test_client() as client:
resp = client.post(
"/oidc/token",
headers={
"Origin": "https://any-origin.com",
"Authorization": _basic_auth_header("test-client"),
},
)
assert resp.status_code == 200
assert (
resp.headers.get("Access-Control-Allow-Origin")
== "https://any-origin.com"
)
def test_preflight_oidc_falls_back_to_global(self, app_with_global_and_client):
"""TEST: PCORS-INT-04 -- OPTIONS preflight cannot carry client credentials,
so it uses global CORS config. A preflight from a per-client-only origin
that is not in the global list will not receive CORS headers."""
fake_client = StubClient(allowed_cors_origins=["https://client-app.com"])
with patch.object(cors_module, "OIDCClient") as MockModel:
MockModel.query.filter_by.return_value.first.return_value = fake_client
with app_with_global_and_client.test_client() as client:
resp = client.options(
"/oidc/token",
headers={
"Origin": "https://client-app.com",
"Access-Control-Request-Method": "POST",
},
)
# Origin not in global list; no CORS headers on preflight
assert resp.headers.get("Access-Control-Allow-Origin") is None
def test_preflight_oidc_with_global_origin_succeeds(self, app_with_global_and_client):
"""TEST: PCORS-INT-05 -- OPTIONS preflight from a globally-allowed origin
returns 204 with CORS headers even for OIDC endpoints."""
with app_with_global_and_client.test_client() as client:
resp = client.options(
"/oidc/token",
headers={
"Origin": "https://global.com",
"Access-Control-Request-Method": "POST",
},
)
assert resp.status_code == 204
assert resp.headers.get("Access-Control-Allow-Origin") == "https://global.com"
assert resp.headers.get("Access-Control-Allow-Credentials") == "true"
def test_non_oidc_endpoint_uses_global_cors(self, app_with_global_and_client):
"""TEST: PCORS-INT-06 -- Non-OIDC endpoint uses global CORS config."""
with app_with_global_and_client.test_client() as client:
resp = client.options(
"/api/v1/users",
headers={
"Origin": "https://global.com",
"Access-Control-Request-Method": "GET",
},
)
assert resp.status_code == 204
assert resp.headers.get("Access-Control-Allow-Origin") == "https://global.com"
def test_post_oidc_no_auth_uses_global_cors(self, app_with_global_and_client):
"""TEST: PCORS-INT-07 -- POST to OIDC endpoint without credentials uses
global CORS (cannot identify client)."""
with app_with_global_and_client.test_client() as client:
resp = client.post(
"/oidc/token",
headers={"Origin": "https://global.com"},
)
assert resp.status_code == 200
assert (
resp.headers.get("Access-Control-Allow-Origin") == "https://global.com"
)