@@ -0,0 +1,240 @@
|
||||
# Per-Client CORS Origins for OIDC Endpoints
|
||||
|
||||
## Overview
|
||||
|
||||
Gatehouse OIDC now supports **per-client CORS origins**. This allows each OIDC client to declare which browser origins are permitted to make cross-origin requests to OIDC endpoints (`/oidc/token`, `/oidc/revoke`, `/oidc/userinfo`, `/oidc/introspect`).
|
||||
|
||||
Previously, CORS was controlled by a single server-wide `CORS_ORIGINS` environment variable. If your SPA's origin wasn't in that list, the browser would block requests to OIDC endpoints — even if your OIDC client was properly configured.
|
||||
|
||||
## How It Works
|
||||
|
||||
### The Problem
|
||||
|
||||
When a browser-based SPA (e.g., running at `http://localhost:8080`) exchanges an authorization code for tokens, it makes a POST request to `/oidc/token`. The browser sends a preflight OPTIONS request first, and the server must respond with CORS headers allowing the SPA's origin.
|
||||
|
||||
Previously, if `http://localhost:8080` wasn't in the server's `CORS_ORIGINS` env var, the preflight would fail and the SPA couldn't get tokens.
|
||||
|
||||
### The Solution
|
||||
|
||||
Each OIDC client can now declare its own `allowed_cors_origins`. When a request hits an OIDC endpoint, the server checks the client's CORS configuration first, then falls back to the global config.
|
||||
|
||||
## Configuration
|
||||
|
||||
### Setting CORS Origins on an OIDC Client
|
||||
|
||||
When creating or updating an OIDC client, set the `allowed_cors_origins` field:
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "My SPA",
|
||||
"client_id": "oidc_myapp",
|
||||
"redirect_uris": ["http://localhost:8080/callback", "https://app.example.com/callback"],
|
||||
"allowed_cors_origins": ["http://localhost:8080", "https://app.example.com"],
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"scopes": ["openid", "profile", "email"]
|
||||
}
|
||||
```
|
||||
|
||||
### Auto-Derive from Redirect URIs
|
||||
|
||||
Set `allowed_cors_origins` to `["+"]` to automatically derive CORS origins from the client's `redirect_uris`. The server extracts the scheme, hostname, and port from each redirect URI.
|
||||
|
||||
```json
|
||||
{
|
||||
"redirect_uris": ["http://localhost:8080/callback", "https://app.example.com/callback"],
|
||||
"allowed_cors_origins": ["+"]
|
||||
}
|
||||
```
|
||||
|
||||
This is equivalent to:
|
||||
|
||||
```json
|
||||
{
|
||||
"allowed_cors_origins": ["http://localhost:8080", "https://app.example.com"]
|
||||
}
|
||||
```
|
||||
|
||||
### Use Global Config (Default)
|
||||
|
||||
Set `allowed_cors_origins` to `null` (or omit it) to use the server's global `CORS_ORIGINS` config. This is the default behavior for existing clients.
|
||||
|
||||
```json
|
||||
{
|
||||
"allowed_cors_origins": null
|
||||
}
|
||||
```
|
||||
|
||||
### Allow All Origins (Not Recommended)
|
||||
|
||||
Set `allowed_cors_origins` to `["*"]` to allow any origin. **This is not recommended for production.**
|
||||
|
||||
```json
|
||||
{
|
||||
"allowed_cors_origins": ["*"]
|
||||
}
|
||||
```
|
||||
|
||||
## Affected Endpoints
|
||||
|
||||
The following OIDC endpoints support per-client CORS:
|
||||
|
||||
| Endpoint | Method | How Client is Identified |
|
||||
|---|---|---|
|
||||
| `/oidc/token` | POST | `client_id` in request body or Basic Auth header |
|
||||
| `/oidc/revoke` | POST | `client_id` in request body or Basic Auth header |
|
||||
| `/oidc/introspect` | POST | `client_id` in request body or Basic Auth header |
|
||||
| `/oidc/userinfo` | GET/POST | `client_id` extracted from Bearer token |
|
||||
|
||||
## SPA Integration Guide
|
||||
|
||||
### Step 1: Register Your OIDC Client
|
||||
|
||||
Register your SPA as an OIDC client with the correct redirect URIs and CORS origins:
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "My React App",
|
||||
"redirect_uris": ["http://localhost:3000/callback"],
|
||||
"allowed_cors_origins": ["http://localhost:3000"],
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"scopes": ["openid", "profile", "email"],
|
||||
"is_confidential": false,
|
||||
"require_pkce": true
|
||||
}
|
||||
```
|
||||
|
||||
### Step 2: Use PKCE (Required for Public Clients)
|
||||
|
||||
Gatehouse requires PKCE for public clients. Generate a code verifier and challenge before redirecting to the authorize endpoint:
|
||||
|
||||
```javascript
|
||||
// Generate PKCE
|
||||
const codeVerifier = generateRandomString(128);
|
||||
const codeChallenge = await sha256(codeVerifier);
|
||||
const state = generateRandomString(32);
|
||||
|
||||
// Store verifier for later
|
||||
sessionStorage.setItem('pkce_verifier', codeVerifier);
|
||||
|
||||
// Redirect to authorize
|
||||
const authUrl = new URL('https://api.example.com/api/v1/oidc/authorize');
|
||||
authUrl.searchParams.set('response_type', 'code');
|
||||
authUrl.searchParams.set('client_id', 'oidc_myapp');
|
||||
authUrl.searchParams.set('redirect_uri', 'http://localhost:3000/callback');
|
||||
authUrl.searchParams.set('scope', 'openid profile email');
|
||||
authUrl.searchParams.set('state', state);
|
||||
authUrl.searchParams.set('code_challenge', codeChallenge);
|
||||
authUrl.searchParams.set('code_challenge_method', 'S256');
|
||||
|
||||
window.location.href = authUrl.toString();
|
||||
```
|
||||
|
||||
### Step 3: Exchange Code for Tokens
|
||||
|
||||
After the user authenticates and is redirected back to your callback page, exchange the authorization code for tokens:
|
||||
|
||||
```javascript
|
||||
// Extract code from URL
|
||||
const params = new URLSearchParams(window.location.search);
|
||||
const code = params.get('code');
|
||||
const state = params.get('state');
|
||||
|
||||
// Verify state matches
|
||||
if (state !== sessionStorage.getItem('pkce_state')) {
|
||||
throw new Error('State mismatch');
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
const response = await fetch('https://api.example.com/api/v1/oidc/token', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
},
|
||||
body: new URLSearchParams({
|
||||
grant_type: 'authorization_code',
|
||||
code: code,
|
||||
redirect_uri: 'http://localhost:3000/callback',
|
||||
client_id: 'oidc_myapp',
|
||||
code_verifier: sessionStorage.getItem('pkce_verifier'),
|
||||
}),
|
||||
});
|
||||
|
||||
const tokens = await response.json();
|
||||
// tokens.access_token, tokens.id_token, tokens.refresh_token
|
||||
```
|
||||
|
||||
The server will return CORS headers because `http://localhost:3000` is in the client's `allowed_cors_origins`.
|
||||
|
||||
### Step 4: Refresh Tokens
|
||||
|
||||
```javascript
|
||||
const response = await fetch('https://api.example.com/api/v1/oidc/token', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
},
|
||||
body: new URLSearchParams({
|
||||
grant_type: 'refresh_token',
|
||||
refresh_token: storedRefreshToken,
|
||||
client_id: 'oidc_myapp',
|
||||
}),
|
||||
});
|
||||
```
|
||||
|
||||
### Step 5: Call UserInfo
|
||||
|
||||
```javascript
|
||||
const response = await fetch('https://api.example.com/api/v1/oidc/userinfo', {
|
||||
headers: {
|
||||
'Authorization': `Bearer ${accessToken}`,
|
||||
},
|
||||
});
|
||||
const userInfo = await response.json();
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "CORS error" when exchanging code for tokens
|
||||
|
||||
**Cause**: Your SPA's origin is not in the client's `allowed_cors_origins` or the server's global `CORS_ORIGINS`.
|
||||
|
||||
**Fix**: Add your SPA's origin to the client's `allowed_cors_origins`:
|
||||
```json
|
||||
{
|
||||
"allowed_cors_origins": ["http://localhost:3000"]
|
||||
}
|
||||
```
|
||||
|
||||
### "CORS error" on preflight OPTIONS request
|
||||
|
||||
**Cause**: The preflight request doesn't carry client credentials, so the server can't identify which client to check CORS origins for. It falls back to the global `CORS_ORIGINS`.
|
||||
|
||||
**Fix**: Either add your origin to the global `CORS_ORIGINS` env var, or ensure the actual POST request (after preflight) includes the `client_id` in the request body.
|
||||
|
||||
### CORS works for `/oidc/token` but not `/oidc/userinfo`
|
||||
|
||||
**Cause**: The userinfo endpoint identifies the client from the Bearer token. If the token doesn't contain a `client_id` claim, the server falls back to global config.
|
||||
|
||||
**Fix**: Ensure your access tokens include the `client_id` claim (this is the default behavior).
|
||||
|
||||
## API Reference
|
||||
|
||||
### OIDCClient Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|---|---|---|
|
||||
| `allowed_cors_origins` | `string[]` or `null` | List of allowed browser origins. `null` = use global config. `["+"]` = auto-derive from redirect URIs. `["*"]` = allow all (not recommended). |
|
||||
|
||||
### CORS Headers Returned
|
||||
|
||||
When a request's origin matches the client's allowed origins:
|
||||
|
||||
```
|
||||
Access-Control-Allow-Origin: <request-origin>
|
||||
Access-Control-Allow-Methods: GET, POST, PUT, PATCH, DELETE, OPTIONS
|
||||
Access-Control-Allow-Headers: Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma, X-WebAuthn-Session-Token
|
||||
Access-Control-Allow-Credentials: true
|
||||
Access-Control-Max-Age: 3600
|
||||
```
|
||||
@@ -10,6 +10,8 @@ from gatehouse_app.models import Department, DepartmentMembership
|
||||
from gatehouse_app.services.organization_service import OrganizationService
|
||||
from gatehouse_app.services.user_service import UserService
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.utils.constants import AuditAction
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
|
||||
|
||||
class DepartmentCreateSchema(Schema):
|
||||
@@ -127,6 +129,15 @@ def create_department(org_id):
|
||||
db.session.add(dept)
|
||||
db.session.commit()
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.DEPARTMENT_CREATED,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="department",
|
||||
resource_id=str(dept.id),
|
||||
description=f"Department '{dept.name}' created",
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data={"department": dept.to_dict()},
|
||||
message="Department created successfully",
|
||||
@@ -255,6 +266,15 @@ def update_department(org_id, dept_id):
|
||||
|
||||
db.session.commit()
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.DEPARTMENT_UPDATED,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="department",
|
||||
resource_id=str(dept.id),
|
||||
description=f"Department '{dept.name}' updated",
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data={"department": dept.to_dict()},
|
||||
message="Department updated successfully",
|
||||
@@ -308,6 +328,15 @@ def delete_department(org_id, dept_id):
|
||||
dept.deleted_at = db.func.now()
|
||||
db.session.commit()
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.DEPARTMENT_DELETED,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="department",
|
||||
resource_id=str(dept.id),
|
||||
description=f"Department '{dept.name}' deleted",
|
||||
)
|
||||
|
||||
return api_response(
|
||||
message="Department deleted successfully",
|
||||
)
|
||||
@@ -461,6 +490,15 @@ def add_department_member(org_id, dept_id):
|
||||
|
||||
db.session.commit()
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.DEPARTMENT_MEMBER_ADDED,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="user",
|
||||
resource_id=str(user.id),
|
||||
description=f"Added user {user.email} to department '{dept.name}'",
|
||||
)
|
||||
|
||||
member_dict = membership.to_dict()
|
||||
member_dict["user"] = user.to_dict()
|
||||
|
||||
@@ -533,6 +571,15 @@ def remove_department_member(org_id, dept_id, user_id):
|
||||
membership.deleted_at = db.func.now()
|
||||
db.session.commit()
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.DEPARTMENT_MEMBER_REMOVED,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="user",
|
||||
resource_id=str(user_id),
|
||||
description=f"Removed user from department '{dept.name}'",
|
||||
)
|
||||
|
||||
return api_response(
|
||||
message="Member removed successfully",
|
||||
)
|
||||
@@ -699,5 +746,14 @@ def set_dept_cert_policy(org_id, dept_id):
|
||||
|
||||
db.session.commit()
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.DEPARTMENT_CERT_POLICY_UPDATED,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="department",
|
||||
resource_id=str(dept_id),
|
||||
description=f"Certificate policy updated for department '{dept.name}'",
|
||||
)
|
||||
|
||||
return api_response(data={"cert_policy": policy.to_dict()}, message="Certificate policy saved")
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@ from flask import g, request
|
||||
from gatehouse_app.api.v1 import api_v1_bp
|
||||
from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.utils.decorators import login_required
|
||||
from gatehouse_app.utils.constants import AuditAction
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
|
||||
|
||||
@api_v1_bp.route("/admin/oauth/providers", methods=["GET"])
|
||||
@@ -78,6 +80,14 @@ def admin_configure_app_provider(provider: str):
|
||||
db.session.add(cfg)
|
||||
db.session.commit()
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.EXTERNAL_AUTH_CONFIG_UPDATE if cfg else AuditAction.EXTERNAL_AUTH_CONFIG_CREATE,
|
||||
user_id=g.current_user.id,
|
||||
resource_type="oauth_provider",
|
||||
resource_id=provider,
|
||||
description=f"OAuth provider '{provider}' configured (enabled={cfg.is_enabled})",
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data={"provider": {"id": provider, "client_id": cfg.client_id, "is_enabled": cfg.is_enabled}},
|
||||
message=f"{provider.capitalize()} OAuth provider configured successfully",
|
||||
@@ -104,4 +114,13 @@ def admin_delete_app_provider(provider: str):
|
||||
return api_response(success=False, message=f"Provider '{provider}' is not configured", status=404, error_type="NOT_FOUND")
|
||||
|
||||
cfg.delete()
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.EXTERNAL_AUTH_CONFIG_DELETE,
|
||||
user_id=g.current_user.id,
|
||||
resource_type="oauth_provider",
|
||||
resource_id=provider,
|
||||
description=f"OAuth provider '{provider}' configuration removed",
|
||||
)
|
||||
|
||||
return api_response(message=f"{provider.capitalize()} OAuth provider configuration removed")
|
||||
|
||||
@@ -26,6 +26,9 @@ from gatehouse_app.exceptions.auth_exceptions import (
|
||||
AccountSuspendedError,
|
||||
AccountInactiveError,
|
||||
)
|
||||
from gatehouse_app.utils.constants import AuditAction
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
from gatehouse_app.services.oidc_audit_service import OIDCAuditService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -849,6 +852,18 @@ def oidc_register():
|
||||
)
|
||||
client.save()
|
||||
|
||||
OIDCAuditService.log_event(
|
||||
event_type="client_registration",
|
||||
client_id=client_id,
|
||||
user_id=g.current_user.id if hasattr(g, "current_user") else None,
|
||||
success=True,
|
||||
metadata={
|
||||
"client_name": client_name,
|
||||
"redirect_uris": redirect_uris,
|
||||
"organization_id": str(organization.id),
|
||||
},
|
||||
)
|
||||
|
||||
response = jsonify({
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
|
||||
@@ -8,6 +8,8 @@ from gatehouse_app.utils.decorators import login_required, require_admin, full_a
|
||||
from gatehouse_app.models.organization import OrganizationApiKey
|
||||
from gatehouse_app.services.organization_service import OrganizationService
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.utils.constants import AuditAction
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
|
||||
|
||||
class ApiKeyCreateSchema(Schema):
|
||||
@@ -130,7 +132,16 @@ def create_api_key(org_id):
|
||||
name=data["name"],
|
||||
description=data.get("description"),
|
||||
)
|
||||
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_API_KEY_CREATED,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="api_key",
|
||||
resource_id=str(api_key.id),
|
||||
description=f"API key '{api_key.name}' created",
|
||||
)
|
||||
|
||||
# Return the key data with the plain text key (only on creation)
|
||||
key_dict = api_key.to_dict()
|
||||
key_dict["key"] = plain_key # Include plain text only on creation
|
||||
@@ -219,9 +230,18 @@ def update_api_key(org_id, key_id):
|
||||
api_key.name = data["name"]
|
||||
if "description" in data:
|
||||
api_key.description = data["description"]
|
||||
|
||||
|
||||
api_key.save()
|
||||
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_API_KEY_UPDATED,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="api_key",
|
||||
resource_id=str(api_key.id),
|
||||
description=f"API key '{api_key.name}' updated",
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data={"api_key": api_key.to_dict()},
|
||||
message="API key updated successfully",
|
||||
@@ -293,7 +313,16 @@ def delete_api_key(org_id, key_id):
|
||||
|
||||
# Soft delete the API key
|
||||
api_key.delete(soft=True)
|
||||
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_API_KEY_DELETED,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="api_key",
|
||||
resource_id=str(api_key.id),
|
||||
description=f"API key '{api_key.name}' deleted",
|
||||
)
|
||||
|
||||
return api_response(
|
||||
message="API key deleted successfully",
|
||||
)
|
||||
|
||||
@@ -6,6 +6,8 @@ from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.utils.decorators import login_required, require_admin
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.api.v1.organizations._helpers import _get_system_ca_dict
|
||||
from gatehouse_app.utils.constants import AuditAction
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/cas", methods=["GET"])
|
||||
@@ -66,6 +68,16 @@ def update_org_ca(org_id, ca_id):
|
||||
ca.max_cert_validity_hours = data["max_cert_validity_hours"]
|
||||
|
||||
db.session.commit()
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.CA_UPDATED,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="CA",
|
||||
resource_id=ca_id,
|
||||
description=f"CA '{ca.name}' updated",
|
||||
)
|
||||
|
||||
return api_response(data={"ca": ca.to_dict()}, message="CA updated successfully")
|
||||
except ValidationError as e:
|
||||
return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages)
|
||||
@@ -150,6 +162,15 @@ def create_org_ca(org_id):
|
||||
return api_response(success=False, message="A CA with that name already exists in this organization (it may have been recently deleted — choose a different name).", status=400, error_type="DUPLICATE_NAME")
|
||||
raise
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.CA_CREATED,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="CA",
|
||||
resource_id=str(ca.id),
|
||||
description=f"CA '{ca.name}' created",
|
||||
)
|
||||
|
||||
return api_response(data={"ca": ca.to_dict()}, message="CA created successfully", status=201)
|
||||
except MaValidationError as e:
|
||||
return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages)
|
||||
|
||||
@@ -5,6 +5,8 @@ from gatehouse_app.api.v1 import api_v1_bp
|
||||
from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required
|
||||
from gatehouse_app.extensions import db, bcrypt
|
||||
from gatehouse_app.utils.constants import AuditAction
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/clients", methods=["GET"])
|
||||
@@ -28,6 +30,7 @@ def list_org_clients(org_id):
|
||||
"redirect_uris": c.redirect_uris,
|
||||
"scopes": c.scopes,
|
||||
"grant_types": c.grant_types,
|
||||
"allowed_cors_origins": c.allowed_cors_origins,
|
||||
"is_active": c.is_active,
|
||||
"created_at": c.created_at.isoformat() + "Z",
|
||||
}
|
||||
@@ -78,6 +81,15 @@ def create_org_client(org_id):
|
||||
db.session.add(client)
|
||||
db.session.commit()
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_CLIENT_CREATED,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="oidc_client",
|
||||
resource_id=str(client.id),
|
||||
description=f"OIDC client '{client.name}' created",
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"client": {
|
||||
@@ -125,6 +137,15 @@ def update_org_client(org_id, client_id):
|
||||
|
||||
db.session.commit()
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_CLIENT_UPDATED,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="oidc_client",
|
||||
resource_id=str(client.id),
|
||||
description=f"OIDC client '{client.name}' updated",
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"client": {
|
||||
@@ -154,4 +175,14 @@ def delete_org_client(org_id, client_id):
|
||||
|
||||
client.is_active = False
|
||||
db.session.commit()
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_CLIENT_DEACTIVATED,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="oidc_client",
|
||||
resource_id=str(client.id),
|
||||
description=f"OIDC client '{client.name}' deactivated",
|
||||
)
|
||||
|
||||
return api_response(data={}, message="Client deactivated successfully")
|
||||
|
||||
@@ -7,6 +7,8 @@ from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required
|
||||
from gatehouse_app.schemas.organization_schema import OrganizationCreateSchema, OrganizationUpdateSchema
|
||||
from gatehouse_app.services.organization_service import OrganizationService
|
||||
from gatehouse_app.utils.constants import AuditAction
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations", methods=["POST"])
|
||||
@@ -32,6 +34,14 @@ def create_organization():
|
||||
description=data.get("description"),
|
||||
logo_url=data.get("logo_url"),
|
||||
)
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_CREATE,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org.id,
|
||||
resource_type="organization",
|
||||
resource_id=str(org.id),
|
||||
description=f"Organization '{org.name}' created",
|
||||
)
|
||||
return api_response(data={"organization": org.to_dict()}, message="Organization created successfully", status=201)
|
||||
except ValidationError as e:
|
||||
return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages)
|
||||
@@ -60,6 +70,14 @@ def update_organization(org_id):
|
||||
data = schema.load(request.json)
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
org = OrganizationService.update_organization(org=org, user_id=g.current_user.id, **data)
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_UPDATE,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org.id,
|
||||
resource_type="organization",
|
||||
resource_id=str(org.id),
|
||||
description=f"Organization '{org.name}' updated",
|
||||
)
|
||||
return api_response(data={"organization": org.to_dict()}, message="Organization updated successfully")
|
||||
except ValidationError as e:
|
||||
return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages)
|
||||
@@ -92,4 +110,12 @@ def delete_organization(org_id):
|
||||
)
|
||||
|
||||
OrganizationService.force_delete_organization(org=org, user_id=caller.id)
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_DELETE,
|
||||
user_id=caller.id,
|
||||
organization_id=org.id,
|
||||
resource_type="organization",
|
||||
resource_id=str(org.id),
|
||||
description=f"Organization '{org.name}' deleted",
|
||||
)
|
||||
return api_response(message="Organization deleted successfully")
|
||||
|
||||
@@ -136,6 +136,15 @@ def cancel_org_invite(org_id, invite_id):
|
||||
return api_response(success=False, message="Invite not found", status=404)
|
||||
|
||||
invite.delete(soft=True)
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_INVITE_CANCELLED,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="org_invite",
|
||||
resource_id=invite.id,
|
||||
metadata={"invited_email": invite.email, "role": invite.role},
|
||||
description=f"Invitation for {invite.email} cancelled",
|
||||
)
|
||||
return api_response(data={}, message="Invite cancelled")
|
||||
|
||||
|
||||
@@ -243,6 +252,30 @@ def accept_invite(token):
|
||||
|
||||
invite.accept()
|
||||
|
||||
if invite.invited_by and invite.invited_by.email:
|
||||
from gatehouse_app.services.email_templates import build_invite_accepted_html
|
||||
from gatehouse_app.services.notification_service import NotificationService
|
||||
|
||||
member_display = user.full_name or user.email
|
||||
inviter_display = invite.invited_by.full_name or invite.invited_by.email
|
||||
org_link = f"{current_app.config.get('APP_URL', '')}/organizations/{invite.organization_id}"
|
||||
|
||||
html_body = build_invite_accepted_html(
|
||||
inviter_name=inviter_display,
|
||||
member_name=member_display,
|
||||
member_email=user.email,
|
||||
org_name=invite.organization.name,
|
||||
role=invite.role,
|
||||
org_link=org_link,
|
||||
)
|
||||
|
||||
NotificationService._send_email_async(
|
||||
to_address=invite.invited_by.email,
|
||||
subject=f"{member_display} accepted your invitation to {invite.organization.name}",
|
||||
body=f"{member_display} has accepted your invitation to join {invite.organization.name} on Secuird.",
|
||||
html_body=html_body,
|
||||
)
|
||||
|
||||
has_webauthn = user.has_webauthn_enabled()
|
||||
has_totp = user.has_totp_enabled()
|
||||
|
||||
|
||||
@@ -7,7 +7,8 @@ from gatehouse_app.utils.decorators import login_required, require_admin, full_a
|
||||
from gatehouse_app.schemas.organization_schema import InviteMemberSchema, UpdateMemberRoleSchema
|
||||
from gatehouse_app.services.organization_service import OrganizationService
|
||||
from gatehouse_app.services.user_service import UserService
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
from gatehouse_app.utils.constants import AuditAction, OrganizationRole
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/members", methods=["GET"])
|
||||
@@ -43,6 +44,14 @@ def add_organization_member(org_id):
|
||||
|
||||
role = OrganizationRole(data["role"])
|
||||
member = OrganizationService.add_member(org=org, user_id=user.id, role=role, inviter_id=g.current_user.id)
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_MEMBER_ADD,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org.id,
|
||||
resource_type="user",
|
||||
resource_id=str(user.id),
|
||||
description=f"Added user {user.email} to organization with role {role.value}",
|
||||
)
|
||||
member_dict = member.to_dict()
|
||||
member_dict["user"] = user.to_dict()
|
||||
return api_response(data={"member": member_dict}, message="Member added successfully", status=201)
|
||||
@@ -60,6 +69,14 @@ def remove_organization_member(org_id, user_id):
|
||||
OrganizationService.remove_member(org=org, user_id=user_id, remover_id=g.current_user.id)
|
||||
except ValueError as e:
|
||||
return api_response(success=False, message=str(e), status=403, error_type="OWNER_PROTECTION")
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_MEMBER_REMOVE,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org.id,
|
||||
resource_type="user",
|
||||
resource_id=str(user_id),
|
||||
description=f"Removed user {user_id} from organization",
|
||||
)
|
||||
return api_response(message="Member removed successfully")
|
||||
|
||||
|
||||
@@ -74,6 +91,14 @@ def update_member_role(org_id, user_id):
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
new_role = OrganizationRole(data["role"])
|
||||
member = OrganizationService.update_member_role(org=org, user_id=user_id, new_role=new_role, updater_id=g.current_user.id)
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_MEMBER_ROLE_CHANGE,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org.id,
|
||||
resource_type="user",
|
||||
resource_id=str(user_id),
|
||||
description=f"Changed role for user {user_id} to {new_role.value}",
|
||||
)
|
||||
member_dict = member.to_dict()
|
||||
member_dict["user"] = member.user.to_dict()
|
||||
return api_response(data={"member": member_dict}, message="Member role updated successfully")
|
||||
@@ -180,4 +205,13 @@ def send_mfa_reminder(org_id, user_id):
|
||||
html_body=html_body,
|
||||
)
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_MFA_REMINDER_SENT,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="user",
|
||||
resource_id=str(user_id),
|
||||
description=f"MFA reminder sent to {user.email}",
|
||||
)
|
||||
|
||||
return api_response(data={}, message="Reminder sent successfully")
|
||||
|
||||
@@ -3,8 +3,9 @@ from flask import g, request
|
||||
from gatehouse_app.api.v1 import api_v1_bp
|
||||
from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
from gatehouse_app.utils.constants import AuditAction, OrganizationRole
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/roles", methods=["GET"])
|
||||
@@ -59,6 +60,16 @@ def assign_role_to_member(org_id, role_name):
|
||||
|
||||
membership.role = new_role
|
||||
db.session.commit()
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_MEMBER_ROLE_CHANGE,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="user",
|
||||
resource_id=str(target_user_id),
|
||||
description=f"Role changed to {new_role.value} for user {target_user_id}",
|
||||
)
|
||||
|
||||
return api_response(data={"user_id": target_user_id, "role": new_role.value}, message=f"Role updated to {new_role.value}")
|
||||
|
||||
|
||||
@@ -82,4 +93,14 @@ def remove_role_from_member(org_id, role_name, user_id):
|
||||
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
OrganizationService.remove_member(org=org, user_id=user_id, remover_id=g.current_user.id)
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_MEMBER_REMOVE,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="user",
|
||||
resource_id=str(user_id),
|
||||
description=f"Member {user_id} removed from organization via role removal",
|
||||
)
|
||||
|
||||
return api_response(data={"user_id": user_id}, message="Member removed from organization")
|
||||
|
||||
@@ -10,6 +10,8 @@ from gatehouse_app.services.organization_service import OrganizationService
|
||||
from gatehouse_app.services.user_service import UserService
|
||||
from gatehouse_app.exceptions import OrganizationNotFoundError
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.utils.constants import AuditAction
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
|
||||
|
||||
class PrincipalCreateSchema(Schema):
|
||||
@@ -127,6 +129,15 @@ def create_principal(org_id):
|
||||
db.session.add(principal)
|
||||
db.session.commit()
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.PRINCIPAL_CREATED,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="principal",
|
||||
resource_id=str(principal.id),
|
||||
description=f"Principal '{principal.name}' created",
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data={"principal": principal.to_dict()},
|
||||
message="Principal created successfully",
|
||||
@@ -255,6 +266,15 @@ def update_principal(org_id, principal_id):
|
||||
|
||||
db.session.commit()
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.PRINCIPAL_UPDATED,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="principal",
|
||||
resource_id=str(principal.id),
|
||||
description=f"Principal '{principal.name}' updated",
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data={"principal": principal.to_dict()},
|
||||
message="Principal updated successfully",
|
||||
@@ -308,6 +328,15 @@ def delete_principal(org_id, principal_id):
|
||||
principal.deleted_at = db.func.now()
|
||||
db.session.commit()
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.PRINCIPAL_DELETED,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="principal",
|
||||
resource_id=str(principal.id),
|
||||
description=f"Principal '{principal.name}' deleted",
|
||||
)
|
||||
|
||||
return api_response(
|
||||
message="Principal deleted successfully",
|
||||
)
|
||||
@@ -476,6 +505,15 @@ def add_principal_member(org_id, principal_id):
|
||||
|
||||
db.session.commit()
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.PRINCIPAL_MEMBER_ADDED,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="user",
|
||||
resource_id=str(user.id),
|
||||
description=f"Added user {user.email} to principal '{principal.name}'",
|
||||
)
|
||||
|
||||
member_dict = membership.to_dict()
|
||||
member_dict["user"] = user.to_dict()
|
||||
|
||||
@@ -548,6 +586,15 @@ def remove_principal_member(org_id, principal_id, user_id):
|
||||
membership.deleted_at = db.func.now()
|
||||
db.session.commit()
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.PRINCIPAL_MEMBER_REMOVED,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="user",
|
||||
resource_id=str(user_id),
|
||||
description=f"Removed user from principal '{principal.name}'",
|
||||
)
|
||||
|
||||
return api_response(
|
||||
message="Member removed successfully",
|
||||
)
|
||||
@@ -697,6 +744,15 @@ def link_principal_to_department(org_id, principal_id, dept_id):
|
||||
error_type="SERVER_ERROR",
|
||||
)
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.PRINCIPAL_DEPARTMENT_LINKED,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="principal",
|
||||
resource_id=str(principal_id),
|
||||
description=f"Principal '{principal.name}' linked to department '{dept.name}'",
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"principal": principal.to_dict(),
|
||||
@@ -774,6 +830,15 @@ def unlink_principal_from_department(org_id, principal_id, dept_id):
|
||||
link.deleted_at = db.func.now()
|
||||
db.session.commit()
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.PRINCIPAL_DEPARTMENT_UNLINKED,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="principal",
|
||||
resource_id=str(principal_id),
|
||||
description=f"Principal '{principal.name}' unlinked from department '{dept.name}'",
|
||||
)
|
||||
|
||||
return api_response(
|
||||
message="Principal unlinked from department successfully",
|
||||
)
|
||||
|
||||
@@ -9,6 +9,7 @@ from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.services.superadmin_auth_service import SuperadminAuthService
|
||||
from gatehouse_app.decorators.superadmin import superadmin_required, superadmin_audit_log
|
||||
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
|
||||
from gatehouse_app.utils.constants import AuditAction
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -105,6 +106,7 @@ def login():
|
||||
|
||||
@superadmin_bp.route("/auth/logout", methods=["POST"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action=AuditAction.USER_LOGOUT, resource_type="session")
|
||||
def logout():
|
||||
"""Superadmin logout endpoint.
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ def superadmin_required(f):
|
||||
"""Decorator to require superadmin Bearer token authentication.
|
||||
|
||||
Extracts token from Authorization: Bearer {token} header,
|
||||
validates the session against SuperadminSession table,
|
||||
validates the session against the unified sessions table,
|
||||
and sets g.current_superadmin and g.superadmin_session.
|
||||
|
||||
Returns 401 if no valid session, 403 if not a superadmin.
|
||||
@@ -46,10 +46,14 @@ def superadmin_required(f):
|
||||
token = parts[1]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from gatehouse_app.models.superadmin import SuperadminSession, Superadmin
|
||||
from gatehouse_app.models.user.session import Session
|
||||
from gatehouse_app.models.superadmin import Superadmin
|
||||
from gatehouse_app.utils.constants import SessionType
|
||||
|
||||
# Get active session by token
|
||||
session = SuperadminSession.query.filter_by(token=token).first()
|
||||
# Get active session by token, scoped to superadmin
|
||||
session = Session.query.filter_by(
|
||||
token=token, owner_type=SessionType.SUPERADMIN
|
||||
).first()
|
||||
|
||||
if not session:
|
||||
return api_response(
|
||||
@@ -68,8 +72,8 @@ def superadmin_required(f):
|
||||
error_type="SESSION_INACTIVE"
|
||||
)
|
||||
|
||||
# Get the superadmin
|
||||
superadmin = session.superadmin
|
||||
# Get the superadmin by owner_id
|
||||
superadmin = Superadmin.query.get(session.owner_id)
|
||||
if not superadmin:
|
||||
return api_response(
|
||||
success=False,
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
"""CORS middleware configuration."""
|
||||
import base64
|
||||
import json
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from flask import request, make_response
|
||||
|
||||
from gatehouse_app.models import OIDCClient
|
||||
|
||||
ALLOWED_METHODS = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
ALLOWED_HEADERS = (
|
||||
"Content-Type, Authorization, X-Requested-With, X-Request-ID, "
|
||||
@@ -40,6 +46,85 @@ def _cors_origin_header(cors_origins, request_origin):
|
||||
return None
|
||||
|
||||
|
||||
def _get_oidc_client_id_from_request():
|
||||
"""Extract client_id from OIDC endpoint requests."""
|
||||
path = request.path
|
||||
|
||||
# POST to /oidc/token, /oidc/revoke, /oidc/introspect
|
||||
if request.method == "POST" and any(
|
||||
path.endswith(ep) for ep in ("/oidc/token", "/oidc/revoke", "/oidc/introspect")
|
||||
):
|
||||
# Try Basic Auth header first
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Basic "):
|
||||
try:
|
||||
decoded = base64.b64decode(auth_header[6:]).decode("utf-8")
|
||||
client_id, _, _ = decoded.partition(":")
|
||||
if client_id:
|
||||
return client_id
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Try form body
|
||||
if request.form:
|
||||
client_id = request.form.get("client_id")
|
||||
if client_id:
|
||||
return client_id
|
||||
|
||||
# Try JSON body
|
||||
if request.is_json:
|
||||
try:
|
||||
client_id = request.json.get("client_id")
|
||||
if client_id:
|
||||
return client_id
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
# GET/POST to /oidc/userinfo
|
||||
if path.endswith("/oidc/userinfo"):
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token = auth_header[7:]
|
||||
try:
|
||||
payload_b64 = token.split(".")[1]
|
||||
padding = 4 - len(payload_b64) % 4
|
||||
if padding != 4:
|
||||
payload_b64 += "=" * padding
|
||||
payload = json.loads(base64.urlsafe_b64decode(payload_b64))
|
||||
return payload.get("client_id")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_effective_cors_origins(app, request):
|
||||
"""Get effective CORS origins, checking per-client config for OIDC endpoints."""
|
||||
global_origins = app.config.get("CORS_ORIGINS", [])
|
||||
|
||||
if "/oidc/" not in request.path:
|
||||
return global_origins
|
||||
|
||||
try:
|
||||
client_id = _get_oidc_client_id_from_request()
|
||||
if not client_id:
|
||||
return global_origins
|
||||
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
return global_origins
|
||||
|
||||
effective = client.get_effective_origins()
|
||||
if effective is not None:
|
||||
return effective
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return global_origins
|
||||
|
||||
|
||||
def setup_cors(app):
|
||||
"""
|
||||
Configure CORS for the application.
|
||||
@@ -54,7 +139,7 @@ def setup_cors(app):
|
||||
"""Handle CORS preflight OPTIONS requests."""
|
||||
if request.method == "OPTIONS":
|
||||
origin = request.headers.get("Origin")
|
||||
cors_origins = app.config.get("CORS_ORIGINS", [])
|
||||
cors_origins = _get_effective_cors_origins(app, request)
|
||||
|
||||
if not _is_origin_allowed(origin, cors_origins):
|
||||
return None
|
||||
@@ -73,7 +158,7 @@ def setup_cors(app):
|
||||
def after_request_cors(response):
|
||||
"""Add CORS headers to non-preflight responses."""
|
||||
origin = request.headers.get("Origin")
|
||||
cors_origins = app.config.get("CORS_ORIGINS", [])
|
||||
cors_origins = _get_effective_cors_origins(app, request)
|
||||
|
||||
allow_origin = _cors_origin_header(cors_origins, origin)
|
||||
if allow_origin:
|
||||
|
||||
@@ -118,7 +118,6 @@ from gatehouse_app.models.zerotier import ( # noqa: F401
|
||||
from gatehouse_app.models.superadmin import ( # noqa: F401
|
||||
Superadmin,
|
||||
SuperadminSession,
|
||||
SuperadminSessionStatus,
|
||||
)
|
||||
from gatehouse_app.models.superadmin_audit_log import SuperadminAuditLog # noqa: F401
|
||||
from gatehouse_app.models.security.user_security_policy import ( # noqa: F401
|
||||
@@ -186,6 +185,5 @@ __all__ = [
|
||||
# Superadmin
|
||||
"Superadmin",
|
||||
"SuperadminSession",
|
||||
"SuperadminSessionStatus",
|
||||
"SuperadminAuditLog",
|
||||
]
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
"""OIDC Client model."""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import OIDCGrantType, OIDCResponseType
|
||||
@@ -21,6 +23,7 @@ class OIDCClient(BaseModel):
|
||||
grant_types = db.Column(db.JSON, nullable=False) # Allowed grant types
|
||||
response_types = db.Column(db.JSON, nullable=False) # Allowed response types
|
||||
scopes = db.Column(db.JSON, nullable=False) # Allowed scopes
|
||||
allowed_cors_origins = db.Column(db.JSON, nullable=True, default=None) # Per-client CORS origins
|
||||
|
||||
# Client metadata
|
||||
logo_uri = db.Column(db.String(512), nullable=True)
|
||||
@@ -81,6 +84,37 @@ class OIDCClient(BaseModel):
|
||||
"""Check if a redirect URI is allowed for this client."""
|
||||
return redirect_uri in self.redirect_uris
|
||||
|
||||
def get_effective_origins(self) -> list | None:
|
||||
"""Get effective CORS origins for this client.
|
||||
|
||||
Returns None to signal "use global config", a derived list from
|
||||
redirect_uris when "+" is present, or the configured list as-is.
|
||||
"""
|
||||
if self.allowed_cors_origins is None:
|
||||
return None
|
||||
if "+" in self.allowed_cors_origins:
|
||||
origins = set()
|
||||
for uri in self.redirect_uris:
|
||||
parsed = urlparse(uri)
|
||||
if parsed.scheme and parsed.hostname:
|
||||
port = f":{parsed.port}" if parsed.port else ""
|
||||
origins.add(f"{parsed.scheme}://{parsed.hostname}{port}")
|
||||
return sorted(origins)
|
||||
return list(self.allowed_cors_origins)
|
||||
|
||||
def is_origin_allowed(self, origin: str) -> bool | None:
|
||||
"""Check if a browser origin is allowed for CORS.
|
||||
|
||||
Returns True/False when a per-client list is configured,
|
||||
or None to defer to the global CORS policy.
|
||||
"""
|
||||
effective = self.get_effective_origins()
|
||||
if effective is None:
|
||||
return None
|
||||
if "*" in effective:
|
||||
return True
|
||||
return origin in effective
|
||||
|
||||
def has_scope(self, scope: str) -> bool:
|
||||
"""Check if client is allowed to request a specific scope."""
|
||||
return scope in self.scopes
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""Superadmin models."""
|
||||
from gatehouse_app.models.superadmin.superadmin import Superadmin
|
||||
from gatehouse_app.models.superadmin.superadmin_session import SuperadminSession, SuperadminSessionStatus
|
||||
from gatehouse_app.models.user.session import Session as SuperadminSession
|
||||
|
||||
__all__ = ["Superadmin", "SuperadminSession", "SuperadminSessionStatus"]
|
||||
__all__ = ["Superadmin", "SuperadminSession"]
|
||||
|
||||
@@ -23,11 +23,15 @@ class Superadmin(BaseModel):
|
||||
is_active = db.Column(db.Boolean, default=True, nullable=False)
|
||||
last_login_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Relationship to sessions
|
||||
# Relationship to sessions (unified model, scoped to superadmin owner_type)
|
||||
sessions = db.relationship(
|
||||
"SuperadminSession",
|
||||
back_populates="superadmin",
|
||||
cascade="all, delete-orphan"
|
||||
"Session",
|
||||
primaryjoin=(
|
||||
"and_(Superadmin.id == foreign(Session.owner_id), "
|
||||
"Session.owner_type == 'superadmin')"
|
||||
),
|
||||
cascade="all, delete-orphan",
|
||||
lazy="dynamic",
|
||||
)
|
||||
|
||||
# Relationship to audit logs
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
"""Superadmin session model."""
|
||||
import logging
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SuperadminSessionStatus:
|
||||
"""Session status constants."""
|
||||
ACTIVE = "active"
|
||||
REVOKED = "revoked"
|
||||
EXPIRED = "expired"
|
||||
|
||||
|
||||
class SuperadminSession(BaseModel):
|
||||
"""Session model for superadmin authentication."""
|
||||
|
||||
__tablename__ = "superadmin_sessions"
|
||||
|
||||
superadmin_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("superadmins.id"),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
token = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
expires_at = db.Column(db.DateTime, nullable=False)
|
||||
last_activity_at = db.Column(
|
||||
db.DateTime,
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
ip_address = db.Column(db.String(45), nullable=True)
|
||||
user_agent = db.Column(db.Text, nullable=True)
|
||||
revoked_at = db.Column(db.DateTime, nullable=True)
|
||||
revoked_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Relationship
|
||||
superadmin = db.relationship("Superadmin", back_populates="sessions")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<SuperadminSession superadmin_id={self.superadmin_id}>"
|
||||
|
||||
def is_active(self):
|
||||
"""Check if session is currently active."""
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return (
|
||||
self.deleted_at is None
|
||||
and self.revoked_at is None
|
||||
and expires_at > now
|
||||
)
|
||||
|
||||
def is_expired(self):
|
||||
"""Check if session has expired."""
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return now > expires_at
|
||||
|
||||
def revoke(self, reason: str = None):
|
||||
"""Revoke the session."""
|
||||
self.revoked_at = datetime.now(timezone.utc)
|
||||
if reason:
|
||||
self.revoked_reason = reason
|
||||
from gatehouse_app import db
|
||||
db.session.commit()
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
exclude.append("token")
|
||||
return super().to_dict(exclude=exclude)
|
||||
@@ -3,15 +3,24 @@ from datetime import datetime, timedelta, timezone
|
||||
from flask import current_app
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import SessionStatus
|
||||
from gatehouse_app.utils.constants import SessionStatus, SessionType
|
||||
|
||||
|
||||
class Session(BaseModel):
|
||||
"""Session model for tracking user sessions."""
|
||||
"""Session model for tracking user and superadmin sessions."""
|
||||
|
||||
__tablename__ = "sessions"
|
||||
|
||||
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=False, index=True)
|
||||
# Owner discriminator — determines which table the owner_id references
|
||||
owner_type = db.Column(
|
||||
db.String(20), nullable=False, default=SessionType.USER, index=True
|
||||
)
|
||||
owner_id = db.Column(db.String(36), nullable=False, index=True)
|
||||
|
||||
# Legacy column kept for backward compatibility during migration;
|
||||
# new code should use owner_id / owner_type.
|
||||
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True, index=True)
|
||||
|
||||
token = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
status = db.Column(db.Enum(SessionStatus), default=SessionStatus.ACTIVE, nullable=False)
|
||||
|
||||
@@ -34,21 +43,37 @@ class Session(BaseModel):
|
||||
# Relationships
|
||||
user = db.relationship("User", back_populates="sessions")
|
||||
|
||||
# Composite index for owner-scoped queries
|
||||
__table_args__ = (
|
||||
db.Index("ix_sessions_owner_type_owner_id", "owner_type", "owner_id"),
|
||||
)
|
||||
|
||||
# ---- Convenience properties ------------------------------------------------
|
||||
|
||||
@property
|
||||
def is_user(self):
|
||||
return self.owner_type == SessionType.USER
|
||||
|
||||
@property
|
||||
def is_superadmin(self):
|
||||
return self.owner_type == SessionType.SUPERADMIN
|
||||
|
||||
# ---- Core methods ----------------------------------------------------------
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of Session."""
|
||||
return f"<Session user_id={self.user_id} status={self.status}>"
|
||||
return f"<Session owner_type={self.owner_type} owner_id={self.owner_id} status={self.status}>"
|
||||
|
||||
def is_active(self):
|
||||
"""Check if session is currently active.
|
||||
|
||||
Sessions are evaluated against two independent timeouts:
|
||||
User sessions are evaluated against two independent timeouts:
|
||||
- Idle timeout: expires if no request has been made within
|
||||
SESSION_IDLE_TIMEOUT seconds (default 15 min).
|
||||
- Absolute timeout: expires if SESSION_ABSOLUTE_TIMEOUT seconds
|
||||
have elapsed since the session was created (default 8 h),
|
||||
regardless of activity.
|
||||
have elapsed since the session was created (default 8 h).
|
||||
|
||||
A session must satisfy *both* constraints to remain active.
|
||||
Superadmin sessions use absolute timeout only (no idle timeout).
|
||||
A session must satisfy *all* applicable constraints to remain active.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
created_at = self.created_at
|
||||
@@ -59,12 +84,21 @@ class Session(BaseModel):
|
||||
if last_activity_at.tzinfo is None:
|
||||
last_activity_at = last_activity_at.replace(tzinfo=timezone.utc)
|
||||
|
||||
idle_timeout = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
|
||||
absolute_timeout = current_app.config.get("SESSION_ABSOLUTE_TIMEOUT", 28800)
|
||||
|
||||
idle_expires_at = last_activity_at + timedelta(seconds=idle_timeout)
|
||||
absolute_expires_at = created_at + timedelta(seconds=absolute_timeout)
|
||||
|
||||
if self.is_superadmin:
|
||||
# Superadmin: absolute timeout only
|
||||
return (
|
||||
self.status == SessionStatus.ACTIVE
|
||||
and now < absolute_expires_at
|
||||
and self.deleted_at is None
|
||||
)
|
||||
|
||||
# User: idle + absolute timeout
|
||||
idle_timeout = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
|
||||
idle_expires_at = last_activity_at + timedelta(seconds=idle_timeout)
|
||||
|
||||
return (
|
||||
self.status == SessionStatus.ACTIVE
|
||||
and now < idle_expires_at
|
||||
@@ -83,6 +117,8 @@ class Session(BaseModel):
|
||||
capped so that the session never exceeds the absolute lifetime
|
||||
(``created_at + absolute timeout``).
|
||||
|
||||
Superadmin sessions only update last_activity_at (no sliding window).
|
||||
|
||||
Args:
|
||||
duration_seconds: Override for the idle timeout. When *None*
|
||||
(the common case), the value is read from
|
||||
@@ -90,6 +126,12 @@ class Session(BaseModel):
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
if self.is_superadmin:
|
||||
# Superadmin: just bump last_activity_at, no sliding window
|
||||
self.last_activity_at = now
|
||||
db.session.commit()
|
||||
return
|
||||
|
||||
if duration_seconds is None:
|
||||
duration_seconds = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from gatehouse_app.extensions import db, bcrypt
|
||||
from gatehouse_app.models.user.user import User
|
||||
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.models.user.session import Session
|
||||
from gatehouse_app.utils.constants import AuthMethodType, SessionStatus, UserStatus, AuditAction
|
||||
from gatehouse_app.utils.constants import AuthMethodType, SessionStatus, SessionType, UserStatus, AuditAction
|
||||
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError, AccountSuspendedError, AccountInactiveError
|
||||
from gatehouse_app.exceptions.validation_exceptions import EmailAlreadyExistsError
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
@@ -165,6 +165,8 @@ class AuthService:
|
||||
|
||||
# Create session
|
||||
session = Session(
|
||||
owner_type=SessionType.USER,
|
||||
owner_id=user.id,
|
||||
user_id=user.id,
|
||||
token=token,
|
||||
status=SessionStatus.ACTIVE,
|
||||
|
||||
@@ -562,3 +562,51 @@ def build_contact_enquiry_html(
|
||||
<p style="margin: 0; color: {TEXT_COLOR}; font-size: 14px; line-height: 1.6; white-space: pre-wrap;">{message_display}</p>
|
||||
'''
|
||||
return get_base_html(content, f"Secuird Website: {type_label}", f"New {type_label} from {submitter_email}")
|
||||
|
||||
|
||||
def build_invite_accepted_html(
|
||||
inviter_name: str,
|
||||
member_name: str,
|
||||
member_email: str,
|
||||
org_name: str,
|
||||
role: str,
|
||||
org_link: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Build invite accepted notification email.
|
||||
|
||||
Args:
|
||||
inviter_name: Name of the person who sent the invite
|
||||
member_name: Name of the person who accepted
|
||||
member_email: Email of the person who accepted
|
||||
org_name: Organization name
|
||||
role: Role assigned to the member
|
||||
org_link: Optional link to view the organization
|
||||
|
||||
Returns:
|
||||
HTML email string
|
||||
"""
|
||||
content = f'''
|
||||
<h2 style="margin: 0 0 20px 0; color: {TEXT_COLOR}; font-size: 20px; font-weight: 600;">Invitation Accepted</h2>
|
||||
<p style="margin: 0 0 20px 0; color: {TEXT_COLOR}; font-size: 15px; line-height: 1.6;">
|
||||
<strong>{member_name}</strong> has accepted your invitation to join <strong>{org_name}</strong> on Secuird.
|
||||
</p>
|
||||
{get_alert_box(f"<strong>{member_name}</strong> ({member_email}) has joined <strong>{org_name}</strong>", "success", "✅")}
|
||||
<table role="presentation" width="100%" cellspacing="0" cellpadding="0" style="margin: 20px 0; background-color: {BACKGROUND_COLOR}; border-radius: 8px;">
|
||||
<tr>
|
||||
<td style="padding: 20px;">
|
||||
<h3 style="margin: 0 0 16px 0; color: {TEXT_COLOR}; font-size: 14px; font-weight: 600;">Membership Details</h3>
|
||||
<table role="presentation" width="100%" cellspacing="0" cellpadding="0">
|
||||
{get_detail_row("Member", member_name)}
|
||||
{get_detail_row("Email", member_email)}
|
||||
{get_detail_row("Organization", org_name)}
|
||||
{get_detail_row("Role", role)}
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
'''
|
||||
if org_link:
|
||||
content += get_action_button(org_link, "View Organization", PRIMARY_COLOR)
|
||||
|
||||
return get_base_html(content, f"Invitation accepted: {org_name}", f"{member_name} has joined {org_name}")
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Session service."""
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.models.user.session import Session
|
||||
from gatehouse_app.utils.constants import SessionStatus
|
||||
from gatehouse_app.utils.constants import SessionStatus, SessionType
|
||||
|
||||
|
||||
class SessionService:
|
||||
@@ -28,18 +28,22 @@ class SessionService:
|
||||
).first()
|
||||
|
||||
@staticmethod
|
||||
def get_user_sessions(user_id, active_only=True):
|
||||
"""
|
||||
Get all sessions for a user.
|
||||
def get_owner_sessions(owner_type, owner_id, active_only=True):
|
||||
"""Get all sessions for an owner (user or superadmin).
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
owner_type: SessionType.USER or SessionType.SUPERADMIN
|
||||
owner_id: Owner ID
|
||||
active_only: If True, only return active sessions
|
||||
|
||||
Returns:
|
||||
List of Session instances
|
||||
"""
|
||||
query = Session.query.filter_by(user_id=user_id, deleted_at=None)
|
||||
query = Session.query.filter_by(
|
||||
owner_type=owner_type,
|
||||
owner_id=owner_id,
|
||||
deleted_at=None,
|
||||
)
|
||||
|
||||
if active_only:
|
||||
query = query.filter_by(status=SessionStatus.ACTIVE).filter(
|
||||
@@ -49,18 +53,67 @@ class SessionService:
|
||||
return query.all()
|
||||
|
||||
@staticmethod
|
||||
def revoke_user_sessions(user_id, reason="User logged out from all devices"):
|
||||
def get_user_sessions(user_id, active_only=True):
|
||||
"""Get all sessions for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
active_only: If True, only return active sessions
|
||||
|
||||
Returns:
|
||||
List of Session instances
|
||||
"""
|
||||
Revoke all active sessions for a user.
|
||||
return SessionService.get_owner_sessions(
|
||||
SessionType.USER, user_id, active_only=active_only
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_superadmin_sessions(superadmin_id, active_only=True):
|
||||
"""Get all sessions for a superadmin.
|
||||
|
||||
Args:
|
||||
superadmin_id: Superadmin ID
|
||||
active_only: If True, only return active sessions
|
||||
|
||||
Returns:
|
||||
List of Session instances
|
||||
"""
|
||||
return SessionService.get_owner_sessions(
|
||||
SessionType.SUPERADMIN, superadmin_id, active_only=active_only
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def revoke_owner_sessions(owner_type, owner_id, reason="Logged out from all devices"):
|
||||
"""Revoke all active sessions for an owner.
|
||||
|
||||
Args:
|
||||
owner_type: SessionType.USER or SessionType.SUPERADMIN
|
||||
owner_id: Owner ID
|
||||
reason: Reason for revocation
|
||||
"""
|
||||
sessions = SessionService.get_owner_sessions(owner_type, owner_id, active_only=True)
|
||||
for session in sessions:
|
||||
session.revoke(reason=reason)
|
||||
|
||||
@staticmethod
|
||||
def revoke_user_sessions(user_id, reason="User logged out from all devices"):
|
||||
"""Revoke all active sessions for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
reason: Reason for revocation
|
||||
"""
|
||||
sessions = SessionService.get_user_sessions(user_id, active_only=True)
|
||||
SessionService.revoke_owner_sessions(SessionType.USER, user_id, reason=reason)
|
||||
|
||||
for session in sessions:
|
||||
session.revoke(reason=reason)
|
||||
@staticmethod
|
||||
def revoke_superadmin_sessions(superadmin_id, reason="Superadmin logged out"):
|
||||
"""Revoke all active sessions for a superadmin.
|
||||
|
||||
Args:
|
||||
superadmin_id: Superadmin ID
|
||||
reason: Reason for revocation
|
||||
"""
|
||||
SessionService.revoke_owner_sessions(SessionType.SUPERADMIN, superadmin_id, reason=reason)
|
||||
|
||||
@staticmethod
|
||||
def cleanup_expired_sessions():
|
||||
|
||||
@@ -6,7 +6,9 @@ from typing import Optional
|
||||
|
||||
from flask import request, current_app
|
||||
from gatehouse_app.extensions import db, bcrypt
|
||||
from gatehouse_app.models.superadmin import Superadmin, SuperadminSession
|
||||
from gatehouse_app.models.superadmin import Superadmin
|
||||
from gatehouse_app.models.user.session import Session
|
||||
from gatehouse_app.utils.constants import SessionType
|
||||
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
|
||||
|
||||
|
||||
@@ -70,15 +72,17 @@ class SuperadminAuthService:
|
||||
duration_seconds: Session duration in seconds (default 8 hours)
|
||||
|
||||
Returns:
|
||||
SuperadminSession instance
|
||||
Session instance
|
||||
"""
|
||||
# Generate secure token
|
||||
token = secrets.token_urlsafe(32)
|
||||
|
||||
# Create session
|
||||
session = SuperadminSession(
|
||||
superadmin_id=superadmin_id,
|
||||
# Create session using unified model
|
||||
session = Session(
|
||||
owner_type=SessionType.SUPERADMIN,
|
||||
owner_id=superadmin_id,
|
||||
token=token,
|
||||
status="active",
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(seconds=duration_seconds),
|
||||
last_activity_at=datetime.now(timezone.utc),
|
||||
ip_address=request.remote_addr,
|
||||
@@ -97,7 +101,9 @@ class SuperadminAuthService:
|
||||
session_id: Session ID to revoke
|
||||
reason: Optional revocation reason
|
||||
"""
|
||||
session = SuperadminSession.query.get(session_id)
|
||||
session = Session.query.filter_by(
|
||||
id=session_id, owner_type=SessionType.SUPERADMIN
|
||||
).first()
|
||||
if session:
|
||||
session.revoke(reason=reason)
|
||||
logger.info(f"[SuperadminAuth] Session {session_id} revoked: {reason or 'No reason'}")
|
||||
@@ -111,9 +117,11 @@ class SuperadminAuthService:
|
||||
except_token: Optional token to keep (current session)
|
||||
reason: Optional revocation reason
|
||||
"""
|
||||
query = SuperadminSession.query.filter_by(superadmin_id=superadmin_id)
|
||||
query = Session.query.filter_by(
|
||||
owner_type=SessionType.SUPERADMIN, owner_id=superadmin_id
|
||||
)
|
||||
if except_token:
|
||||
query = query.filter(SuperadminSession.token != except_token)
|
||||
query = query.filter(Session.token != except_token)
|
||||
|
||||
sessions = query.all()
|
||||
for session in sessions:
|
||||
|
||||
@@ -52,6 +52,13 @@ class SessionStatus(str, Enum):
|
||||
REVOKED = "revoked"
|
||||
|
||||
|
||||
class SessionType(str, Enum):
|
||||
"""Session owner type discriminator."""
|
||||
|
||||
USER = "user"
|
||||
SUPERADMIN = "superadmin"
|
||||
|
||||
|
||||
class AuditAction(str, Enum):
|
||||
"""Audit log action types."""
|
||||
|
||||
@@ -154,6 +161,27 @@ class AuditAction(str, Enum):
|
||||
DEPARTMENT_DELETED = "department.deleted"
|
||||
DEPARTMENT_MEMBER_ADDED = "department.member.added"
|
||||
DEPARTMENT_MEMBER_REMOVED = "department.member.removed"
|
||||
DEPARTMENT_CERT_POLICY_UPDATED = "department.cert_policy.updated"
|
||||
|
||||
# Organization invite actions
|
||||
ORG_INVITE_CANCELLED = "org.invite.cancelled"
|
||||
|
||||
# MFA reminder
|
||||
ORG_MFA_REMINDER_SENT = "org.mfa_reminder.sent"
|
||||
|
||||
# API key actions
|
||||
ORG_API_KEY_CREATED = "org.api_key.created"
|
||||
ORG_API_KEY_UPDATED = "org.api_key.updated"
|
||||
ORG_API_KEY_DELETED = "org.api_key.deleted"
|
||||
|
||||
# OIDC client actions
|
||||
ORG_CLIENT_CREATED = "org.client.created"
|
||||
ORG_CLIENT_UPDATED = "org.client.updated"
|
||||
ORG_CLIENT_DEACTIVATED = "org.client.deactivated"
|
||||
|
||||
# Principal department link actions
|
||||
PRINCIPAL_DEPARTMENT_LINKED = "principal.department.linked"
|
||||
PRINCIPAL_DEPARTMENT_UNLINKED = "principal.department.unlinked"
|
||||
|
||||
|
||||
class OIDCGrantType(str, Enum):
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
"""Add allowed_cors_origins to oidc_clients.
|
||||
|
||||
Revision ID: b7e3f1a92c4d
|
||||
Revises: a1b2c3d4e5f6
|
||||
Create Date: 2026-04-27 00:00:00.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'b7e3f1a92c4d'
|
||||
down_revision = 'a1b2c3d4e5f6'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.add_column('oidc_clients', sa.Column('allowed_cors_origins', sa.JSON(), nullable=True))
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_column('oidc_clients', 'allowed_cors_origins')
|
||||
@@ -0,0 +1,122 @@
|
||||
"""Consolidate user and superadmin sessions into unified sessions table.
|
||||
|
||||
Revision ID: c8d2e4f6a1b3
|
||||
Revises: b7e3f1a92c4d
|
||||
Create Date: 2026-04-28 00:00:00.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'c8d2e4f6a1b3'
|
||||
down_revision = 'b7e3f1a92c4d'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# 1. Add new columns (nullable initially for data migration)
|
||||
op.add_column('sessions', sa.Column('owner_type', sa.String(20), nullable=True))
|
||||
op.add_column('sessions', sa.Column('owner_id', sa.String(36), nullable=True))
|
||||
|
||||
# 2. Backfill existing user sessions: owner_type = 'user', owner_id = user_id
|
||||
op.execute("""
|
||||
UPDATE sessions
|
||||
SET owner_type = 'user',
|
||||
owner_id = user_id
|
||||
WHERE owner_type IS NULL
|
||||
""")
|
||||
|
||||
# 3. Migrate superadmin sessions into the sessions table
|
||||
op.execute("""
|
||||
INSERT INTO sessions (
|
||||
id, owner_type, owner_id, token, status,
|
||||
ip_address, user_agent, device_info,
|
||||
expires_at, last_activity_at, revoked_at, revoked_reason,
|
||||
is_compliance_only, created_at, updated_at, deleted_at
|
||||
)
|
||||
SELECT
|
||||
id, 'superadmin', superadmin_id, token, 'active',
|
||||
ip_address, user_agent, NULL,
|
||||
expires_at, last_activity_at, revoked_at, revoked_reason,
|
||||
FALSE, created_at, updated_at, deleted_at
|
||||
FROM superadmin_sessions
|
||||
""")
|
||||
|
||||
# 4. Make owner_type and owner_id NOT NULL
|
||||
op.alter_column('sessions', 'owner_type', nullable=False)
|
||||
op.alter_column('sessions', 'owner_id', nullable=False)
|
||||
|
||||
# 5. Make user_id nullable (no longer the sole owner reference)
|
||||
op.alter_column('sessions', 'user_id', nullable=True)
|
||||
|
||||
# 6. Create indexes for efficient owner-scoped queries
|
||||
op.create_index(
|
||||
'ix_sessions_owner_type_owner_id',
|
||||
'sessions',
|
||||
['owner_type', 'owner_id']
|
||||
)
|
||||
op.create_index(
|
||||
'ix_sessions_owner_type',
|
||||
'sessions',
|
||||
['owner_type']
|
||||
)
|
||||
op.create_index(
|
||||
'ix_sessions_owner_id',
|
||||
'sessions',
|
||||
['owner_id']
|
||||
)
|
||||
|
||||
# 7. Drop the now-redundant superadmin_sessions table
|
||||
op.drop_table('superadmin_sessions')
|
||||
|
||||
|
||||
def downgrade():
|
||||
# 1. Recreate superadmin_sessions table
|
||||
op.create_table(
|
||||
'superadmin_sessions',
|
||||
sa.Column('id', sa.String(36), primary_key=True),
|
||||
sa.Column('superadmin_id', sa.String(36), sa.ForeignKey('superadmins.id'), nullable=False, index=True),
|
||||
sa.Column('token', sa.String(255), unique=True, nullable=False, index=True),
|
||||
sa.Column('expires_at', sa.DateTime, nullable=False),
|
||||
sa.Column('last_activity_at', sa.DateTime, nullable=False),
|
||||
sa.Column('ip_address', sa.String(45), nullable=True),
|
||||
sa.Column('user_agent', sa.Text, nullable=True),
|
||||
sa.Column('revoked_at', sa.DateTime, nullable=True),
|
||||
sa.Column('revoked_reason', sa.String(255), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime, nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime, nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime, nullable=True),
|
||||
)
|
||||
|
||||
# 2. Move superadmin sessions back to superadmin_sessions
|
||||
op.execute("""
|
||||
INSERT INTO superadmin_sessions (
|
||||
id, superadmin_id, token, expires_at, last_activity_at,
|
||||
ip_address, user_agent, revoked_at, revoked_reason,
|
||||
created_at, updated_at, deleted_at
|
||||
)
|
||||
SELECT
|
||||
id, owner_id, token, expires_at, last_activity_at,
|
||||
ip_address, user_agent, revoked_at, revoked_reason,
|
||||
created_at, updated_at, deleted_at
|
||||
FROM sessions
|
||||
WHERE owner_type = 'superadmin'
|
||||
""")
|
||||
|
||||
# 3. Remove superadmin sessions from sessions table
|
||||
op.execute("DELETE FROM sessions WHERE owner_type = 'superadmin'")
|
||||
|
||||
# 4. Drop indexes
|
||||
op.drop_index('ix_sessions_owner_id', table_name='sessions')
|
||||
op.drop_index('ix_sessions_owner_type', table_name='sessions')
|
||||
op.drop_index('ix_sessions_owner_type_owner_id', table_name='sessions')
|
||||
|
||||
# 5. Remove new columns
|
||||
op.drop_column('sessions', 'owner_id')
|
||||
op.drop_column('sessions', 'owner_type')
|
||||
|
||||
# 6. Make user_id NOT NULL again
|
||||
op.alter_column('sessions', 'user_id', nullable=False)
|
||||
+21
-1
@@ -148,8 +148,28 @@ def test_html_email():
|
||||
success = provider.send(message)
|
||||
print(f"Result: {'✅ SUCCESS' if success else '❌ FAILED'}")
|
||||
|
||||
# Test 8: Invite Accepted
|
||||
print("\n--- Test 8: Invite Accepted ---")
|
||||
html_body = email_templates.build_invite_accepted_html(
|
||||
inviter_name="Admin User",
|
||||
member_name="New Member",
|
||||
member_email="newmember@example.com",
|
||||
org_name="Acme Corporation",
|
||||
role="Member",
|
||||
org_link="https://secuird.tech/organizations/org-123",
|
||||
)
|
||||
message = EmailMessage(
|
||||
to="cory@hawkvelt.id.au",
|
||||
subject="New Member accepted your invitation to Acme Corporation",
|
||||
body="Plain text version: New Member has accepted your invitation.",
|
||||
html_body=html_body,
|
||||
from_address="Secuird <noreply@secuird.tech>",
|
||||
)
|
||||
success = provider.send(message)
|
||||
print(f"Result: {'✅ SUCCESS' if success else '❌ FAILED'}")
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("All 7 email templates sent!")
|
||||
print("All 8 email templates sent!")
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,186 @@
|
||||
"""Superadmin session timeout integration tests.
|
||||
|
||||
Validates the absolute-only timeout policy for superadmin sessions.
|
||||
Superadmin sessions do NOT have idle timeout — only absolute timeout.
|
||||
"""
|
||||
import pytest
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from tests.integration.client.base import ApiError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def assert_success(response: dict, message_contains: str = "") -> dict:
|
||||
"""Assert that an api_response-wrapped payload succeeded."""
|
||||
data = response.get("data", {})
|
||||
assert response.get("success") is not False, (
|
||||
f"Expected success but got error: {response.get('message')}"
|
||||
)
|
||||
if message_contains:
|
||||
assert message_contains.lower() in response.get("message", "").lower(), (
|
||||
f"Expected message to contain '{message_contains}' but got: {response.get('message')}"
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
def _get_session_row(integration_app, token: str):
|
||||
"""Look up the Session model row for a given bearer token."""
|
||||
from gatehouse_app.models.user.session import Session
|
||||
with integration_app.app_context():
|
||||
return Session.query.filter_by(token=token).first()
|
||||
|
||||
|
||||
def _touch_session(integration_app, session_id: str, **updates):
|
||||
"""Directly update columns on a Session row.
|
||||
|
||||
Only use this to simulate the passage of time — never to assert
|
||||
internal state.
|
||||
"""
|
||||
from gatehouse_app.models.user.session import Session
|
||||
with integration_app.app_context():
|
||||
sess = Session.query.get(session_id)
|
||||
for attr, value in updates.items():
|
||||
setattr(sess, attr, value)
|
||||
from gatehouse_app import db
|
||||
db.session.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def superadmin_credentials(integration_app):
|
||||
"""Create a superadmin and return login credentials."""
|
||||
from gatehouse_app.services.superadmin_auth_service import SuperadminAuthService
|
||||
|
||||
email = f"admin_{uuid.uuid4().hex[:8]}@gatehouse.local"
|
||||
password = "SuperAdmin123!"
|
||||
|
||||
with integration_app.app_context():
|
||||
sa = SuperadminAuthService.create_superadmin(
|
||||
email=email,
|
||||
credential=password,
|
||||
full_name="Test Superadmin",
|
||||
)
|
||||
return {"id": str(sa.id), "email": email, "password": password}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def logged_in_superadmin(integration_client, superadmin_credentials, integration_app):
|
||||
"""Log in as superadmin and return session metadata.
|
||||
|
||||
Returns dict with ``superadmin``, ``token``, ``session_id``, ``session_row``.
|
||||
"""
|
||||
creds = superadmin_credentials
|
||||
resp = integration_client.post(
|
||||
"/api/v1/superadmin/auth/login",
|
||||
data={"email": creds["email"], "password": creds["password"]},
|
||||
)
|
||||
data = assert_success(resp)
|
||||
token = data["token"]
|
||||
|
||||
session_row = _get_session_row(integration_app, token)
|
||||
assert session_row is not None, "Session row should exist after superadmin login"
|
||||
|
||||
return {
|
||||
"superadmin": creds,
|
||||
"token": token,
|
||||
"session_id": session_row.id,
|
||||
"session_row": session_row,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSuperadminSessionTimeouts:
|
||||
"""Absolute-only timeout behavior for superadmin sessions."""
|
||||
|
||||
def test_superadmin_session_valid_before_timeout(
|
||||
self, integration_client, logged_in_superadmin,
|
||||
):
|
||||
"""SA-SESS-01 — Fresh superadmin session is accepted."""
|
||||
integration_client.set_token(logged_in_superadmin["token"])
|
||||
result = integration_client.get("/api/v1/superadmin/auth/me")
|
||||
data = assert_success(result)
|
||||
assert "superadmin" in data
|
||||
|
||||
def test_absolute_timeout_rejects_superadmin(
|
||||
self, integration_client, logged_in_superadmin, integration_app,
|
||||
):
|
||||
"""SA-SESS-02 — Superadmin session rejected after absolute timeout.
|
||||
|
||||
Push ``created_at`` far into the past. The session must be
|
||||
rejected even though ``last_activity_at`` is fresh.
|
||||
"""
|
||||
_touch_session(
|
||||
integration_app,
|
||||
logged_in_superadmin["session_id"],
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
last_activity_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
integration_client.set_token(logged_in_superadmin["token"])
|
||||
with pytest.raises(ApiError) as exc_info:
|
||||
integration_client.get("/api/v1/superadmin/auth/me")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_idle_timeout_does_NOT_reject_superadmin(
|
||||
self, integration_client, logged_in_superadmin, integration_app,
|
||||
):
|
||||
"""SA-SESS-03 — Superadmin sessions have NO idle timeout.
|
||||
|
||||
Push ``last_activity_at`` far into the past but keep
|
||||
``created_at`` recent. The session should still be valid
|
||||
because superadmin sessions only use absolute timeout.
|
||||
"""
|
||||
_touch_session(
|
||||
integration_app,
|
||||
logged_in_superadmin["session_id"],
|
||||
last_activity_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
)
|
||||
|
||||
integration_client.set_token(logged_in_superadmin["token"])
|
||||
result = integration_client.get("/api/v1/superadmin/auth/me")
|
||||
data = assert_success(result)
|
||||
assert "superadmin" in data
|
||||
|
||||
def test_revoked_superadmin_session_rejected(
|
||||
self, integration_client, logged_in_superadmin,
|
||||
):
|
||||
"""SA-SESS-04 — Revoked superadmin session is rejected."""
|
||||
integration_client.set_token(logged_in_superadmin["token"])
|
||||
|
||||
# Logout revokes the session
|
||||
integration_client.post("/api/v1/superadmin/auth/logout")
|
||||
integration_client.clear_token()
|
||||
|
||||
# Try using the old token
|
||||
integration_client.set_token(logged_in_superadmin["token"])
|
||||
with pytest.raises(ApiError) as exc_info:
|
||||
integration_client.get("/api/v1/superadmin/auth/me")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_superadmin_session_has_owner_type(
|
||||
self, integration_app, logged_in_superadmin,
|
||||
):
|
||||
"""SA-SESS-05 — Superadmin session row has owner_type='superadmin'."""
|
||||
from gatehouse_app.models.user.session import Session
|
||||
from gatehouse_app.utils.constants import SessionType
|
||||
|
||||
with integration_app.app_context():
|
||||
sess = Session.query.get(logged_in_superadmin["session_id"])
|
||||
assert sess is not None
|
||||
assert sess.owner_type == SessionType.SUPERADMIN
|
||||
assert sess.owner_id == logged_in_superadmin["superadmin"]["id"]
|
||||
@@ -0,0 +1,503 @@
|
||||
"""Unit tests for per-client CORS feature.
|
||||
|
||||
WHAT: Tests for per-client CORS origin resolution, including OIDCClient
|
||||
model methods, request client_id extraction, effective origin
|
||||
resolution, and integration with the CORS middleware.
|
||||
WHY: Per-client CORS prevents one OIDC client from making cross-origin
|
||||
requests meant for another; misconfiguration breaks browser flows.
|
||||
EXPECTED: Correct origin derivation, proper client_id extraction, and
|
||||
correct CORS headers on OIDC endpoints.
|
||||
"""
|
||||
import base64
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, request as flask_request
|
||||
|
||||
import gatehouse_app.middleware.cors as cors_module
|
||||
from gatehouse_app.middleware.cors import (
|
||||
_get_oidc_client_id_from_request,
|
||||
_get_effective_cors_origins,
|
||||
setup_cors,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: build a lightweight stub that quacks like OIDCClient
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class StubClient:
|
||||
"""Minimal stand-in for OIDCClient -- no SQLAlchemy, no DB needed."""
|
||||
|
||||
def __init__(self, *, allowed_cors_origins=None, redirect_uris=None):
|
||||
self.allowed_cors_origins = allowed_cors_origins
|
||||
self.redirect_uris = redirect_uris or []
|
||||
|
||||
def get_effective_origins(self):
|
||||
from urllib.parse import urlparse
|
||||
|
||||
if self.allowed_cors_origins is None:
|
||||
return None
|
||||
if "+" in self.allowed_cors_origins:
|
||||
origins = set()
|
||||
for uri in self.redirect_uris:
|
||||
parsed = urlparse(uri)
|
||||
if parsed.scheme and parsed.hostname:
|
||||
port = f":{parsed.port}" if parsed.port else ""
|
||||
origins.add(f"{parsed.scheme}://{parsed.hostname}{port}")
|
||||
return sorted(origins)
|
||||
return list(self.allowed_cors_origins)
|
||||
|
||||
def is_origin_allowed(self, origin):
|
||||
effective = self.get_effective_origins()
|
||||
if effective is None:
|
||||
return None
|
||||
if "*" in effective:
|
||||
return True
|
||||
return origin in effective
|
||||
|
||||
|
||||
def _basic_auth_header(client_id, secret="secret"):
|
||||
"""Return a 'Basic <b64>' Authorization header value."""
|
||||
return "Basic " + base64.b64encode(f"{client_id}:{secret}".encode()).decode()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OIDCClient.get_effective_origins
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetEffectiveOrigins:
|
||||
def test_returns_none_when_allowed_cors_origins_is_none(self):
|
||||
"""TEST: PCORS-GE-01 -- None config signals 'use global'."""
|
||||
client = StubClient(allowed_cors_origins=None)
|
||||
assert client.get_effective_origins() is None
|
||||
|
||||
def test_derives_from_redirect_uris_when_plus_sign(self):
|
||||
"""TEST: PCORS-GE-02 -- '+' in list derives origins from redirect_uris."""
|
||||
client = StubClient(
|
||||
allowed_cors_origins=["+"],
|
||||
redirect_uris=[
|
||||
"https://app.example.com/callback",
|
||||
"http://localhost:3000/callback",
|
||||
],
|
||||
)
|
||||
assert client.get_effective_origins() == [
|
||||
"http://localhost:3000",
|
||||
"https://app.example.com",
|
||||
]
|
||||
|
||||
def test_derives_with_port(self):
|
||||
"""TEST: PCORS-GE-03 -- Non-standard ports are preserved."""
|
||||
client = StubClient(
|
||||
allowed_cors_origins=["+"],
|
||||
redirect_uris=["https://app.example.com:8443/cb"],
|
||||
)
|
||||
assert client.get_effective_origins() == ["https://app.example.com:8443"]
|
||||
|
||||
def test_deduplicates_derived_origins(self):
|
||||
"""TEST: PCORS-GE-04 -- Duplicate redirect URIs produce unique origins."""
|
||||
client = StubClient(
|
||||
allowed_cors_origins=["+"],
|
||||
redirect_uris=[
|
||||
"https://app.example.com/cb1",
|
||||
"https://app.example.com/cb2",
|
||||
],
|
||||
)
|
||||
assert client.get_effective_origins() == ["https://app.example.com"]
|
||||
|
||||
def test_returns_list_as_is_when_normal_list(self):
|
||||
"""TEST: PCORS-GE-05 -- Normal list is returned unchanged."""
|
||||
origins = ["https://a.com", "https://b.com"]
|
||||
client = StubClient(allowed_cors_origins=origins)
|
||||
assert client.get_effective_origins() == origins
|
||||
|
||||
def test_returns_wildcard_list_as_is(self):
|
||||
"""TEST: PCORS-GE-06 -- ['*'] is returned (handled downstream)."""
|
||||
client = StubClient(allowed_cors_origins=["*"])
|
||||
assert client.get_effective_origins() == ["*"]
|
||||
|
||||
def test_empty_redirect_uris_with_plus_returns_empty(self):
|
||||
"""TEST: PCORS-GE-07 -- '+' with empty redirect_uris yields empty list."""
|
||||
client = StubClient(
|
||||
allowed_cors_origins=["+"],
|
||||
redirect_uris=[],
|
||||
)
|
||||
assert client.get_effective_origins() == []
|
||||
|
||||
def test_skips_malformed_redirect_uris(self):
|
||||
"""TEST: PCORS-GE-08 -- URIs without scheme/host are skipped."""
|
||||
client = StubClient(
|
||||
allowed_cors_origins=["+"],
|
||||
redirect_uris=["not-a-uri", "https://good.com/cb"],
|
||||
)
|
||||
assert client.get_effective_origins() == ["https://good.com"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OIDCClient.is_origin_allowed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsOriginAllowed:
|
||||
def test_returns_none_when_no_per_client_config(self):
|
||||
"""TEST: PCORS-IO-01 -- None config defers to global CORS."""
|
||||
client = StubClient(allowed_cors_origins=None)
|
||||
assert client.is_origin_allowed("https://anything.com") is None
|
||||
|
||||
def test_returns_true_when_wildcard(self):
|
||||
"""TEST: PCORS-IO-02 -- '*' in effective origins allows any origin."""
|
||||
client = StubClient(allowed_cors_origins=["*"])
|
||||
assert client.is_origin_allowed("https://evil.com") is True
|
||||
|
||||
def test_returns_true_for_matching_origin(self):
|
||||
"""TEST: PCORS-IO-03 -- Matching origin is allowed."""
|
||||
client = StubClient(
|
||||
allowed_cors_origins=["https://app.example.com", "https://other.com"],
|
||||
)
|
||||
assert client.is_origin_allowed("https://app.example.com") is True
|
||||
|
||||
def test_returns_false_for_non_matching_origin(self):
|
||||
"""TEST: PCORS-IO-04 -- Non-matching origin is rejected."""
|
||||
client = StubClient(allowed_cors_origins=["https://app.example.com"])
|
||||
assert client.is_origin_allowed("https://evil.com") is False
|
||||
|
||||
def test_returns_false_for_empty_list(self):
|
||||
"""TEST: PCORS-IO-05 -- Empty list rejects everything."""
|
||||
client = StubClient(allowed_cors_origins=[])
|
||||
assert client.is_origin_allowed("https://anything.com") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_oidc_client_id_from_request
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetOidcClientIdFromRequest:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
def test_extracts_from_basic_auth(self, app):
|
||||
"""TEST: PCORS-CI-01 -- Basic Auth header yields client_id."""
|
||||
with app.test_request_context(
|
||||
"/oidc/token",
|
||||
method="POST",
|
||||
headers={"Authorization": _basic_auth_header("my-client")},
|
||||
):
|
||||
assert _get_oidc_client_id_from_request() == "my-client"
|
||||
|
||||
def test_extracts_from_form_body(self, app):
|
||||
"""TEST: PCORS-CI-02 -- Form-encoded body yields client_id."""
|
||||
with app.test_request_context(
|
||||
"/oidc/token",
|
||||
method="POST",
|
||||
data={"client_id": "form-client", "grant_type": "client_credentials"},
|
||||
):
|
||||
assert _get_oidc_client_id_from_request() == "form-client"
|
||||
|
||||
def test_extracts_from_json_body(self, app):
|
||||
"""TEST: PCORS-CI-03 -- JSON body yields client_id."""
|
||||
with app.test_request_context(
|
||||
"/oidc/token",
|
||||
method="POST",
|
||||
data=json.dumps({"client_id": "json-client", "grant_type": "client_credentials"}),
|
||||
content_type="application/json",
|
||||
):
|
||||
assert _get_oidc_client_id_from_request() == "json-client"
|
||||
|
||||
def test_extracts_from_bearer_jwt(self, app):
|
||||
"""TEST: PCORS-CI-04 -- Bearer JWT payload yields client_id."""
|
||||
payload = base64.urlsafe_b64encode(
|
||||
json.dumps({"client_id": "jwt-client"}).encode()
|
||||
).rstrip(b"=").decode()
|
||||
token = f"eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.{payload}.sig"
|
||||
with app.test_request_context(
|
||||
"/oidc/userinfo",
|
||||
method="GET",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
):
|
||||
assert _get_oidc_client_id_from_request() == "jwt-client"
|
||||
|
||||
def test_returns_none_for_non_oidc_endpoint(self, app):
|
||||
"""TEST: PCORS-CI-05 -- Non-OIDC path returns None."""
|
||||
with app.test_request_context(
|
||||
"/api/v1/users",
|
||||
method="GET",
|
||||
headers={"Authorization": _basic_auth_header("x")},
|
||||
):
|
||||
assert _get_oidc_client_id_from_request() is None
|
||||
|
||||
def test_returns_none_when_no_client_id_found(self, app):
|
||||
"""TEST: PCORS-CI-06 -- OIDC token endpoint with no credentials returns None."""
|
||||
with app.test_request_context(
|
||||
"/oidc/token",
|
||||
method="POST",
|
||||
data={"grant_type": "client_credentials"},
|
||||
):
|
||||
assert _get_oidc_client_id_from_request() is None
|
||||
|
||||
def test_extracts_from_revoke_endpoint(self, app):
|
||||
"""TEST: PCORS-CI-07 -- /oidc/revoke also accepts Basic Auth."""
|
||||
with app.test_request_context(
|
||||
"/oidc/revoke",
|
||||
method="POST",
|
||||
headers={"Authorization": _basic_auth_header("rev-client")},
|
||||
):
|
||||
assert _get_oidc_client_id_from_request() == "rev-client"
|
||||
|
||||
def test_extracts_from_introspect_endpoint(self, app):
|
||||
"""TEST: PCORS-CI-08 -- /oidc/introspect also accepts Basic Auth."""
|
||||
with app.test_request_context(
|
||||
"/oidc/introspect",
|
||||
method="POST",
|
||||
headers={"Authorization": _basic_auth_header("int-client")},
|
||||
):
|
||||
assert _get_oidc_client_id_from_request() == "int-client"
|
||||
|
||||
def test_returns_none_for_options_preflight(self, app):
|
||||
"""TEST: PCORS-CI-09 -- OPTIONS preflight cannot carry client credentials."""
|
||||
with app.test_request_context(
|
||||
"/oidc/token",
|
||||
method="OPTIONS",
|
||||
headers={
|
||||
"Origin": "https://app.com",
|
||||
"Access-Control-Request-Method": "POST",
|
||||
},
|
||||
):
|
||||
assert _get_oidc_client_id_from_request() is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_effective_cors_origins
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetEffectiveCorsOrigins:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["CORS_ORIGINS"] = ["https://global.com"]
|
||||
return app
|
||||
|
||||
def test_global_config_for_non_oidc_endpoint(self, app):
|
||||
"""TEST: PCORS-EO-01 -- Non-OIDC path always uses global config."""
|
||||
with app.test_request_context("/api/v1/users", method="GET"):
|
||||
result = _get_effective_cors_origins(app, flask_request)
|
||||
assert result == ["https://global.com"]
|
||||
|
||||
def test_per_client_origins_for_oidc_endpoint(self, app):
|
||||
"""TEST: PCORS-EO-02 -- OIDC endpoint with configured client uses per-client origins."""
|
||||
fake_client = StubClient(allowed_cors_origins=["https://client.com"])
|
||||
|
||||
with app.test_request_context(
|
||||
"/oidc/token",
|
||||
method="POST",
|
||||
headers={"Authorization": _basic_auth_header("test-client")},
|
||||
):
|
||||
with patch.object(cors_module, "OIDCClient") as MockModel:
|
||||
MockModel.query.filter_by.return_value.first.return_value = fake_client
|
||||
result = _get_effective_cors_origins(app, flask_request)
|
||||
assert result == ["https://client.com"]
|
||||
|
||||
def test_fallback_when_client_not_found(self, app):
|
||||
"""TEST: PCORS-EO-03 -- Unknown client_id falls back to global config."""
|
||||
with app.test_request_context(
|
||||
"/oidc/token",
|
||||
method="POST",
|
||||
headers={"Authorization": _basic_auth_header("unknown")},
|
||||
):
|
||||
with patch.object(cors_module, "OIDCClient") as MockModel:
|
||||
MockModel.query.filter_by.return_value.first.return_value = None
|
||||
result = _get_effective_cors_origins(app, flask_request)
|
||||
assert result == ["https://global.com"]
|
||||
|
||||
def test_fallback_when_allowed_cors_origins_is_none(self, app):
|
||||
"""TEST: PCORS-EO-04 -- Client with None origins falls back to global."""
|
||||
fake_client = StubClient(allowed_cors_origins=None)
|
||||
|
||||
with app.test_request_context(
|
||||
"/oidc/token",
|
||||
method="POST",
|
||||
headers={"Authorization": _basic_auth_header("test-client")},
|
||||
):
|
||||
with patch.object(cors_module, "OIDCClient") as MockModel:
|
||||
MockModel.query.filter_by.return_value.first.return_value = fake_client
|
||||
result = _get_effective_cors_origins(app, flask_request)
|
||||
assert result == ["https://global.com"]
|
||||
|
||||
def test_fallback_on_db_error(self, app):
|
||||
"""TEST: PCORS-EO-05 -- Database exception falls back to global config."""
|
||||
with app.test_request_context(
|
||||
"/oidc/token",
|
||||
method="POST",
|
||||
headers={"Authorization": _basic_auth_header("test-client")},
|
||||
):
|
||||
with patch.object(cors_module, "OIDCClient") as MockModel:
|
||||
MockModel.query.filter_by.side_effect = Exception("DB down")
|
||||
result = _get_effective_cors_origins(app, flask_request)
|
||||
assert result == ["https://global.com"]
|
||||
|
||||
def test_fallback_when_no_client_id_extracted(self, app):
|
||||
"""TEST: PCORS-EO-06 -- OIDC path with no credentials falls back to global."""
|
||||
with app.test_request_context(
|
||||
"/oidc/token",
|
||||
method="POST",
|
||||
data={"grant_type": "client_credentials"},
|
||||
):
|
||||
result = _get_effective_cors_origins(app, flask_request)
|
||||
assert result == ["https://global.com"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: OIDC endpoint CORS headers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestOidcEndpointCorsIntegration:
|
||||
@pytest.fixture
|
||||
def app_with_global_and_client(self):
|
||||
"""Flask app with global CORS and route stubs for integration tests."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["CORS_ORIGINS"] = ["https://global.com"]
|
||||
app.config["CORS_SUPPORTS_CREDENTIALS"] = True
|
||||
|
||||
@app.route("/oidc/token", methods=["POST", "OPTIONS"])
|
||||
def oidc_token():
|
||||
return {"status": "ok"}, 200
|
||||
|
||||
@app.route("/api/v1/users", methods=["GET", "OPTIONS"])
|
||||
def api_users():
|
||||
return {"users": []}, 200
|
||||
|
||||
setup_cors(app)
|
||||
return app
|
||||
|
||||
def test_post_oidc_with_per_client_origin_includes_cors_headers(
|
||||
self, app_with_global_and_client
|
||||
):
|
||||
"""TEST: PCORS-INT-01 -- POST to /oidc/token with per-client origin and
|
||||
Basic Auth includes CORS headers. Per-client CORS applies to the actual
|
||||
request (which carries credentials), not the preflight."""
|
||||
fake_client = StubClient(allowed_cors_origins=["https://client-app.com"])
|
||||
|
||||
with patch.object(cors_module, "OIDCClient") as MockModel:
|
||||
MockModel.query.filter_by.return_value.first.return_value = fake_client
|
||||
|
||||
with app_with_global_and_client.test_client() as client:
|
||||
resp = client.post(
|
||||
"/oidc/token",
|
||||
headers={
|
||||
"Origin": "https://client-app.com",
|
||||
"Authorization": _basic_auth_header("test-client"),
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert (
|
||||
resp.headers.get("Access-Control-Allow-Origin")
|
||||
== "https://client-app.com"
|
||||
)
|
||||
assert resp.headers.get("Access-Control-Allow-Credentials") == "true"
|
||||
|
||||
def test_post_oidc_with_non_matching_per_client_origin_no_cors_headers(
|
||||
self, app_with_global_and_client
|
||||
):
|
||||
"""TEST: PCORS-INT-02 -- POST with origin not in per-client list has no
|
||||
CORS headers."""
|
||||
fake_client = StubClient(allowed_cors_origins=["https://allowed.com"])
|
||||
|
||||
with patch.object(cors_module, "OIDCClient") as MockModel:
|
||||
MockModel.query.filter_by.return_value.first.return_value = fake_client
|
||||
|
||||
with app_with_global_and_client.test_client() as client:
|
||||
resp = client.post(
|
||||
"/oidc/token",
|
||||
headers={
|
||||
"Origin": "https://evil.com",
|
||||
"Authorization": _basic_auth_header("test-client"),
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") is None
|
||||
|
||||
def test_post_oidc_wildcard_client_echoes_origin(self, app_with_global_and_client):
|
||||
"""TEST: PCORS-INT-03 -- Client with '*' echoes the request origin."""
|
||||
fake_client = StubClient(allowed_cors_origins=["*"])
|
||||
|
||||
with patch.object(cors_module, "OIDCClient") as MockModel:
|
||||
MockModel.query.filter_by.return_value.first.return_value = fake_client
|
||||
|
||||
with app_with_global_and_client.test_client() as client:
|
||||
resp = client.post(
|
||||
"/oidc/token",
|
||||
headers={
|
||||
"Origin": "https://any-origin.com",
|
||||
"Authorization": _basic_auth_header("test-client"),
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert (
|
||||
resp.headers.get("Access-Control-Allow-Origin")
|
||||
== "https://any-origin.com"
|
||||
)
|
||||
|
||||
def test_preflight_oidc_falls_back_to_global(self, app_with_global_and_client):
|
||||
"""TEST: PCORS-INT-04 -- OPTIONS preflight cannot carry client credentials,
|
||||
so it uses global CORS config. A preflight from a per-client-only origin
|
||||
that is not in the global list will not receive CORS headers."""
|
||||
fake_client = StubClient(allowed_cors_origins=["https://client-app.com"])
|
||||
|
||||
with patch.object(cors_module, "OIDCClient") as MockModel:
|
||||
MockModel.query.filter_by.return_value.first.return_value = fake_client
|
||||
|
||||
with app_with_global_and_client.test_client() as client:
|
||||
resp = client.options(
|
||||
"/oidc/token",
|
||||
headers={
|
||||
"Origin": "https://client-app.com",
|
||||
"Access-Control-Request-Method": "POST",
|
||||
},
|
||||
)
|
||||
# Origin not in global list; no CORS headers on preflight
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") is None
|
||||
|
||||
def test_preflight_oidc_with_global_origin_succeeds(self, app_with_global_and_client):
|
||||
"""TEST: PCORS-INT-05 -- OPTIONS preflight from a globally-allowed origin
|
||||
returns 204 with CORS headers even for OIDC endpoints."""
|
||||
with app_with_global_and_client.test_client() as client:
|
||||
resp = client.options(
|
||||
"/oidc/token",
|
||||
headers={
|
||||
"Origin": "https://global.com",
|
||||
"Access-Control-Request-Method": "POST",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") == "https://global.com"
|
||||
assert resp.headers.get("Access-Control-Allow-Credentials") == "true"
|
||||
|
||||
def test_non_oidc_endpoint_uses_global_cors(self, app_with_global_and_client):
|
||||
"""TEST: PCORS-INT-06 -- Non-OIDC endpoint uses global CORS config."""
|
||||
with app_with_global_and_client.test_client() as client:
|
||||
resp = client.options(
|
||||
"/api/v1/users",
|
||||
headers={
|
||||
"Origin": "https://global.com",
|
||||
"Access-Control-Request-Method": "GET",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") == "https://global.com"
|
||||
|
||||
def test_post_oidc_no_auth_uses_global_cors(self, app_with_global_and_client):
|
||||
"""TEST: PCORS-INT-07 -- POST to OIDC endpoint without credentials uses
|
||||
global CORS (cannot identify client)."""
|
||||
with app_with_global_and_client.test_client() as client:
|
||||
resp = client.post(
|
||||
"/oidc/token",
|
||||
headers={"Origin": "https://global.com"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert (
|
||||
resp.headers.get("Access-Control-Allow-Origin") == "https://global.com"
|
||||
)
|
||||
Reference in New Issue
Block a user