From 02e95a419949ccdcc8ca70cb2129c85727eb1616 Mon Sep 17 00:00:00 2001 From: Cory Hawklvelt Date: Sun, 26 Apr 2026 18:36:58 +0930 Subject: [PATCH 1/4] feat(organizations): email inviter when membership invite is accepted When a user accepts an org invite, send a notification email to the person who sent the invite with membership details (member name, email, org name, role) and an optional View Organization button. Added build_invite_accepted_html() template to email_templates.py, wired it into the accept_invite() handler, and added a test case. --- gatehouse_app/api/v1/organizations/invites.py | 24 ++++++++++ gatehouse_app/services/email_templates.py | 48 +++++++++++++++++++ test_email.py | 22 ++++++++- 3 files changed, 93 insertions(+), 1 deletion(-) diff --git a/gatehouse_app/api/v1/organizations/invites.py b/gatehouse_app/api/v1/organizations/invites.py index 762d5a2..1ac718a 100644 --- a/gatehouse_app/api/v1/organizations/invites.py +++ b/gatehouse_app/api/v1/organizations/invites.py @@ -243,6 +243,30 @@ def accept_invite(token): invite.accept() + if invite.invited_by and invite.invited_by.email: + from gatehouse_app.services.email_templates import build_invite_accepted_html + from gatehouse_app.services.notification_service import NotificationService + + member_display = user.full_name or user.email + inviter_display = invite.invited_by.full_name or invite.invited_by.email + org_link = f"{current_app.config.get('APP_URL', '')}/organizations/{invite.organization_id}" + + html_body = build_invite_accepted_html( + inviter_name=inviter_display, + member_name=member_display, + member_email=user.email, + org_name=invite.organization.name, + role=invite.role, + org_link=org_link, + ) + + NotificationService._send_email_async( + to_address=invite.invited_by.email, + subject=f"{member_display} accepted your invitation to {invite.organization.name}", + body=f"{member_display} has accepted your invitation to join {invite.organization.name} on Secuird.", + html_body=html_body, + ) + has_webauthn = user.has_webauthn_enabled() has_totp = user.has_totp_enabled() diff --git a/gatehouse_app/services/email_templates.py b/gatehouse_app/services/email_templates.py index ec4d81d..4e6fc9c 100644 --- a/gatehouse_app/services/email_templates.py +++ b/gatehouse_app/services/email_templates.py @@ -562,3 +562,51 @@ def build_contact_enquiry_html(

{message_display}

''' return get_base_html(content, f"Secuird Website: {type_label}", f"New {type_label} from {submitter_email}") + + +def build_invite_accepted_html( + inviter_name: str, + member_name: str, + member_email: str, + org_name: str, + role: str, + org_link: Optional[str] = None, +) -> str: + """Build invite accepted notification email. + + Args: + inviter_name: Name of the person who sent the invite + member_name: Name of the person who accepted + member_email: Email of the person who accepted + org_name: Organization name + role: Role assigned to the member + org_link: Optional link to view the organization + + Returns: + HTML email string + """ + content = f''' +

Invitation Accepted

+

+ {member_name} has accepted your invitation to join {org_name} on Secuird. +

+ {get_alert_box(f"{member_name} ({member_email}) has joined {org_name}", "success", "✅")} + + + + +
+

Membership Details

+ + {get_detail_row("Member", member_name)} + {get_detail_row("Email", member_email)} + {get_detail_row("Organization", org_name)} + {get_detail_row("Role", role)} +
+
+ ''' + if org_link: + content += get_action_button(org_link, "View Organization", PRIMARY_COLOR) + + return get_base_html(content, f"Invitation accepted: {org_name}", f"{member_name} has joined {org_name}") + diff --git a/test_email.py b/test_email.py index 6061a5c..25bb98f 100644 --- a/test_email.py +++ b/test_email.py @@ -148,8 +148,28 @@ def test_html_email(): success = provider.send(message) print(f"Result: {'✅ SUCCESS' if success else '❌ FAILED'}") + # Test 8: Invite Accepted + print("\n--- Test 8: Invite Accepted ---") + html_body = email_templates.build_invite_accepted_html( + inviter_name="Admin User", + member_name="New Member", + member_email="newmember@example.com", + org_name="Acme Corporation", + role="Member", + org_link="https://secuird.tech/organizations/org-123", + ) + message = EmailMessage( + to="cory@hawkvelt.id.au", + subject="New Member accepted your invitation to Acme Corporation", + body="Plain text version: New Member has accepted your invitation.", + html_body=html_body, + from_address="Secuird ", + ) + success = provider.send(message) + print(f"Result: {'✅ SUCCESS' if success else '❌ FAILED'}") + print("\n" + "=" * 50) - print("All 7 email templates sent!") + print("All 8 email templates sent!") print("=" * 50) From 63a3109a825b400d4f64f73309eceea5877b558d Mon Sep 17 00:00:00 2001 From: Cory Hawklvelt Date: Mon, 27 Apr 2026 02:44:32 +0930 Subject: [PATCH 2/4] oidc-client mk1 --- docs/per-client-cors.md | 240 +++++++++ gatehouse_app/api/v1/organizations/clients.py | 1 + gatehouse_app/middleware/cors.py | 89 +++- gatehouse_app/models/oidc/oidc_client.py | 34 ++ ...e3f1a92c4d_add_oidc_client_cors_origins.py | 24 + tests/unit/test_per_client_cors.py | 503 ++++++++++++++++++ 6 files changed, 889 insertions(+), 2 deletions(-) create mode 100644 docs/per-client-cors.md create mode 100644 migrations/versions/b7e3f1a92c4d_add_oidc_client_cors_origins.py create mode 100644 tests/unit/test_per_client_cors.py diff --git a/docs/per-client-cors.md b/docs/per-client-cors.md new file mode 100644 index 0000000..0aca5d5 --- /dev/null +++ b/docs/per-client-cors.md @@ -0,0 +1,240 @@ +# Per-Client CORS Origins for OIDC Endpoints + +## Overview + +Gatehouse OIDC now supports **per-client CORS origins**. This allows each OIDC client to declare which browser origins are permitted to make cross-origin requests to OIDC endpoints (`/oidc/token`, `/oidc/revoke`, `/oidc/userinfo`, `/oidc/introspect`). + +Previously, CORS was controlled by a single server-wide `CORS_ORIGINS` environment variable. If your SPA's origin wasn't in that list, the browser would block requests to OIDC endpoints — even if your OIDC client was properly configured. + +## How It Works + +### The Problem + +When a browser-based SPA (e.g., running at `http://localhost:8080`) exchanges an authorization code for tokens, it makes a POST request to `/oidc/token`. The browser sends a preflight OPTIONS request first, and the server must respond with CORS headers allowing the SPA's origin. + +Previously, if `http://localhost:8080` wasn't in the server's `CORS_ORIGINS` env var, the preflight would fail and the SPA couldn't get tokens. + +### The Solution + +Each OIDC client can now declare its own `allowed_cors_origins`. When a request hits an OIDC endpoint, the server checks the client's CORS configuration first, then falls back to the global config. + +## Configuration + +### Setting CORS Origins on an OIDC Client + +When creating or updating an OIDC client, set the `allowed_cors_origins` field: + +```json +{ + "name": "My SPA", + "client_id": "oidc_myapp", + "redirect_uris": ["http://localhost:8080/callback", "https://app.example.com/callback"], + "allowed_cors_origins": ["http://localhost:8080", "https://app.example.com"], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "scopes": ["openid", "profile", "email"] +} +``` + +### Auto-Derive from Redirect URIs + +Set `allowed_cors_origins` to `["+"]` to automatically derive CORS origins from the client's `redirect_uris`. The server extracts the scheme, hostname, and port from each redirect URI. + +```json +{ + "redirect_uris": ["http://localhost:8080/callback", "https://app.example.com/callback"], + "allowed_cors_origins": ["+"] +} +``` + +This is equivalent to: + +```json +{ + "allowed_cors_origins": ["http://localhost:8080", "https://app.example.com"] +} +``` + +### Use Global Config (Default) + +Set `allowed_cors_origins` to `null` (or omit it) to use the server's global `CORS_ORIGINS` config. This is the default behavior for existing clients. + +```json +{ + "allowed_cors_origins": null +} +``` + +### Allow All Origins (Not Recommended) + +Set `allowed_cors_origins` to `["*"]` to allow any origin. **This is not recommended for production.** + +```json +{ + "allowed_cors_origins": ["*"] +} +``` + +## Affected Endpoints + +The following OIDC endpoints support per-client CORS: + +| Endpoint | Method | How Client is Identified | +|---|---|---| +| `/oidc/token` | POST | `client_id` in request body or Basic Auth header | +| `/oidc/revoke` | POST | `client_id` in request body or Basic Auth header | +| `/oidc/introspect` | POST | `client_id` in request body or Basic Auth header | +| `/oidc/userinfo` | GET/POST | `client_id` extracted from Bearer token | + +## SPA Integration Guide + +### Step 1: Register Your OIDC Client + +Register your SPA as an OIDC client with the correct redirect URIs and CORS origins: + +```json +{ + "name": "My React App", + "redirect_uris": ["http://localhost:3000/callback"], + "allowed_cors_origins": ["http://localhost:3000"], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "scopes": ["openid", "profile", "email"], + "is_confidential": false, + "require_pkce": true +} +``` + +### Step 2: Use PKCE (Required for Public Clients) + +Gatehouse requires PKCE for public clients. Generate a code verifier and challenge before redirecting to the authorize endpoint: + +```javascript +// Generate PKCE +const codeVerifier = generateRandomString(128); +const codeChallenge = await sha256(codeVerifier); +const state = generateRandomString(32); + +// Store verifier for later +sessionStorage.setItem('pkce_verifier', codeVerifier); + +// Redirect to authorize +const authUrl = new URL('https://api.example.com/api/v1/oidc/authorize'); +authUrl.searchParams.set('response_type', 'code'); +authUrl.searchParams.set('client_id', 'oidc_myapp'); +authUrl.searchParams.set('redirect_uri', 'http://localhost:3000/callback'); +authUrl.searchParams.set('scope', 'openid profile email'); +authUrl.searchParams.set('state', state); +authUrl.searchParams.set('code_challenge', codeChallenge); +authUrl.searchParams.set('code_challenge_method', 'S256'); + +window.location.href = authUrl.toString(); +``` + +### Step 3: Exchange Code for Tokens + +After the user authenticates and is redirected back to your callback page, exchange the authorization code for tokens: + +```javascript +// Extract code from URL +const params = new URLSearchParams(window.location.search); +const code = params.get('code'); +const state = params.get('state'); + +// Verify state matches +if (state !== sessionStorage.getItem('pkce_state')) { + throw new Error('State mismatch'); +} + +// Exchange code for tokens +const response = await fetch('https://api.example.com/api/v1/oidc/token', { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + }, + body: new URLSearchParams({ + grant_type: 'authorization_code', + code: code, + redirect_uri: 'http://localhost:3000/callback', + client_id: 'oidc_myapp', + code_verifier: sessionStorage.getItem('pkce_verifier'), + }), +}); + +const tokens = await response.json(); +// tokens.access_token, tokens.id_token, tokens.refresh_token +``` + +The server will return CORS headers because `http://localhost:3000` is in the client's `allowed_cors_origins`. + +### Step 4: Refresh Tokens + +```javascript +const response = await fetch('https://api.example.com/api/v1/oidc/token', { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + }, + body: new URLSearchParams({ + grant_type: 'refresh_token', + refresh_token: storedRefreshToken, + client_id: 'oidc_myapp', + }), +}); +``` + +### Step 5: Call UserInfo + +```javascript +const response = await fetch('https://api.example.com/api/v1/oidc/userinfo', { + headers: { + 'Authorization': `Bearer ${accessToken}`, + }, +}); +const userInfo = await response.json(); +``` + +## Troubleshooting + +### "CORS error" when exchanging code for tokens + +**Cause**: Your SPA's origin is not in the client's `allowed_cors_origins` or the server's global `CORS_ORIGINS`. + +**Fix**: Add your SPA's origin to the client's `allowed_cors_origins`: +```json +{ + "allowed_cors_origins": ["http://localhost:3000"] +} +``` + +### "CORS error" on preflight OPTIONS request + +**Cause**: The preflight request doesn't carry client credentials, so the server can't identify which client to check CORS origins for. It falls back to the global `CORS_ORIGINS`. + +**Fix**: Either add your origin to the global `CORS_ORIGINS` env var, or ensure the actual POST request (after preflight) includes the `client_id` in the request body. + +### CORS works for `/oidc/token` but not `/oidc/userinfo` + +**Cause**: The userinfo endpoint identifies the client from the Bearer token. If the token doesn't contain a `client_id` claim, the server falls back to global config. + +**Fix**: Ensure your access tokens include the `client_id` claim (this is the default behavior). + +## API Reference + +### OIDCClient Fields + +| Field | Type | Description | +|---|---|---| +| `allowed_cors_origins` | `string[]` or `null` | List of allowed browser origins. `null` = use global config. `["+"]` = auto-derive from redirect URIs. `["*"]` = allow all (not recommended). | + +### CORS Headers Returned + +When a request's origin matches the client's allowed origins: + +``` +Access-Control-Allow-Origin: +Access-Control-Allow-Methods: GET, POST, PUT, PATCH, DELETE, OPTIONS +Access-Control-Allow-Headers: Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma, X-WebAuthn-Session-Token +Access-Control-Allow-Credentials: true +Access-Control-Max-Age: 3600 +``` diff --git a/gatehouse_app/api/v1/organizations/clients.py b/gatehouse_app/api/v1/organizations/clients.py index 3817023..3cbbaba 100644 --- a/gatehouse_app/api/v1/organizations/clients.py +++ b/gatehouse_app/api/v1/organizations/clients.py @@ -28,6 +28,7 @@ def list_org_clients(org_id): "redirect_uris": c.redirect_uris, "scopes": c.scopes, "grant_types": c.grant_types, + "allowed_cors_origins": c.allowed_cors_origins, "is_active": c.is_active, "created_at": c.created_at.isoformat() + "Z", } diff --git a/gatehouse_app/middleware/cors.py b/gatehouse_app/middleware/cors.py index 797d026..a51e841 100644 --- a/gatehouse_app/middleware/cors.py +++ b/gatehouse_app/middleware/cors.py @@ -1,6 +1,12 @@ """CORS middleware configuration.""" +import base64 +import json +from urllib.parse import parse_qs + from flask import request, make_response +from gatehouse_app.models import OIDCClient + ALLOWED_METHODS = "GET, POST, PUT, PATCH, DELETE, OPTIONS" ALLOWED_HEADERS = ( "Content-Type, Authorization, X-Requested-With, X-Request-ID, " @@ -40,6 +46,85 @@ def _cors_origin_header(cors_origins, request_origin): return None +def _get_oidc_client_id_from_request(): + """Extract client_id from OIDC endpoint requests.""" + path = request.path + + # POST to /oidc/token, /oidc/revoke, /oidc/introspect + if request.method == "POST" and any( + path.endswith(ep) for ep in ("/oidc/token", "/oidc/revoke", "/oidc/introspect") + ): + # Try Basic Auth header first + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Basic "): + try: + decoded = base64.b64decode(auth_header[6:]).decode("utf-8") + client_id, _, _ = decoded.partition(":") + if client_id: + return client_id + except Exception: + pass + + # Try form body + if request.form: + client_id = request.form.get("client_id") + if client_id: + return client_id + + # Try JSON body + if request.is_json: + try: + client_id = request.json.get("client_id") + if client_id: + return client_id + except Exception: + pass + + return None + + # GET/POST to /oidc/userinfo + if path.endswith("/oidc/userinfo"): + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + token = auth_header[7:] + try: + payload_b64 = token.split(".")[1] + padding = 4 - len(payload_b64) % 4 + if padding != 4: + payload_b64 += "=" * padding + payload = json.loads(base64.urlsafe_b64decode(payload_b64)) + return payload.get("client_id") + except Exception: + return None + + return None + + +def _get_effective_cors_origins(app, request): + """Get effective CORS origins, checking per-client config for OIDC endpoints.""" + global_origins = app.config.get("CORS_ORIGINS", []) + + if "/oidc/" not in request.path: + return global_origins + + try: + client_id = _get_oidc_client_id_from_request() + if not client_id: + return global_origins + + client = OIDCClient.query.filter_by(client_id=client_id).first() + if not client: + return global_origins + + effective = client.get_effective_origins() + if effective is not None: + return effective + except Exception: + pass + + return global_origins + + def setup_cors(app): """ Configure CORS for the application. @@ -54,7 +139,7 @@ def setup_cors(app): """Handle CORS preflight OPTIONS requests.""" if request.method == "OPTIONS": origin = request.headers.get("Origin") - cors_origins = app.config.get("CORS_ORIGINS", []) + cors_origins = _get_effective_cors_origins(app, request) if not _is_origin_allowed(origin, cors_origins): return None @@ -73,7 +158,7 @@ def setup_cors(app): def after_request_cors(response): """Add CORS headers to non-preflight responses.""" origin = request.headers.get("Origin") - cors_origins = app.config.get("CORS_ORIGINS", []) + cors_origins = _get_effective_cors_origins(app, request) allow_origin = _cors_origin_header(cors_origins, origin) if allow_origin: diff --git a/gatehouse_app/models/oidc/oidc_client.py b/gatehouse_app/models/oidc/oidc_client.py index 03c0b18..e09772e 100644 --- a/gatehouse_app/models/oidc/oidc_client.py +++ b/gatehouse_app/models/oidc/oidc_client.py @@ -1,4 +1,6 @@ """OIDC Client model.""" +from urllib.parse import urlparse + from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel from gatehouse_app.utils.constants import OIDCGrantType, OIDCResponseType @@ -21,6 +23,7 @@ class OIDCClient(BaseModel): grant_types = db.Column(db.JSON, nullable=False) # Allowed grant types response_types = db.Column(db.JSON, nullable=False) # Allowed response types scopes = db.Column(db.JSON, nullable=False) # Allowed scopes + allowed_cors_origins = db.Column(db.JSON, nullable=True, default=None) # Per-client CORS origins # Client metadata logo_uri = db.Column(db.String(512), nullable=True) @@ -81,6 +84,37 @@ class OIDCClient(BaseModel): """Check if a redirect URI is allowed for this client.""" return redirect_uri in self.redirect_uris + def get_effective_origins(self) -> list | None: + """Get effective CORS origins for this client. + + Returns None to signal "use global config", a derived list from + redirect_uris when "+" is present, or the configured list as-is. + """ + if self.allowed_cors_origins is None: + return None + if "+" in self.allowed_cors_origins: + origins = set() + for uri in self.redirect_uris: + parsed = urlparse(uri) + if parsed.scheme and parsed.hostname: + port = f":{parsed.port}" if parsed.port else "" + origins.add(f"{parsed.scheme}://{parsed.hostname}{port}") + return sorted(origins) + return list(self.allowed_cors_origins) + + def is_origin_allowed(self, origin: str) -> bool | None: + """Check if a browser origin is allowed for CORS. + + Returns True/False when a per-client list is configured, + or None to defer to the global CORS policy. + """ + effective = self.get_effective_origins() + if effective is None: + return None + if "*" in effective: + return True + return origin in effective + def has_scope(self, scope: str) -> bool: """Check if client is allowed to request a specific scope.""" return scope in self.scopes diff --git a/migrations/versions/b7e3f1a92c4d_add_oidc_client_cors_origins.py b/migrations/versions/b7e3f1a92c4d_add_oidc_client_cors_origins.py new file mode 100644 index 0000000..304d5fd --- /dev/null +++ b/migrations/versions/b7e3f1a92c4d_add_oidc_client_cors_origins.py @@ -0,0 +1,24 @@ +"""Add allowed_cors_origins to oidc_clients. + +Revision ID: b7e3f1a92c4d +Revises: a1b2c3d4e5f6 +Create Date: 2026-04-27 00:00:00.000000 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'b7e3f1a92c4d' +down_revision = 'a1b2c3d4e5f6' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column('oidc_clients', sa.Column('allowed_cors_origins', sa.JSON(), nullable=True)) + + +def downgrade(): + op.drop_column('oidc_clients', 'allowed_cors_origins') diff --git a/tests/unit/test_per_client_cors.py b/tests/unit/test_per_client_cors.py new file mode 100644 index 0000000..1127531 --- /dev/null +++ b/tests/unit/test_per_client_cors.py @@ -0,0 +1,503 @@ +"""Unit tests for per-client CORS feature. + +WHAT: Tests for per-client CORS origin resolution, including OIDCClient + model methods, request client_id extraction, effective origin + resolution, and integration with the CORS middleware. +WHY: Per-client CORS prevents one OIDC client from making cross-origin + requests meant for another; misconfiguration breaks browser flows. +EXPECTED: Correct origin derivation, proper client_id extraction, and + correct CORS headers on OIDC endpoints. +""" +import base64 +import json +from unittest.mock import patch + +import pytest +from flask import Flask, request as flask_request + +import gatehouse_app.middleware.cors as cors_module +from gatehouse_app.middleware.cors import ( + _get_oidc_client_id_from_request, + _get_effective_cors_origins, + setup_cors, +) + + +# --------------------------------------------------------------------------- +# Helper: build a lightweight stub that quacks like OIDCClient +# --------------------------------------------------------------------------- + +class StubClient: + """Minimal stand-in for OIDCClient -- no SQLAlchemy, no DB needed.""" + + def __init__(self, *, allowed_cors_origins=None, redirect_uris=None): + self.allowed_cors_origins = allowed_cors_origins + self.redirect_uris = redirect_uris or [] + + def get_effective_origins(self): + from urllib.parse import urlparse + + if self.allowed_cors_origins is None: + return None + if "+" in self.allowed_cors_origins: + origins = set() + for uri in self.redirect_uris: + parsed = urlparse(uri) + if parsed.scheme and parsed.hostname: + port = f":{parsed.port}" if parsed.port else "" + origins.add(f"{parsed.scheme}://{parsed.hostname}{port}") + return sorted(origins) + return list(self.allowed_cors_origins) + + def is_origin_allowed(self, origin): + effective = self.get_effective_origins() + if effective is None: + return None + if "*" in effective: + return True + return origin in effective + + +def _basic_auth_header(client_id, secret="secret"): + """Return a 'Basic ' Authorization header value.""" + return "Basic " + base64.b64encode(f"{client_id}:{secret}".encode()).decode() + + +# --------------------------------------------------------------------------- +# OIDCClient.get_effective_origins +# --------------------------------------------------------------------------- + +class TestGetEffectiveOrigins: + def test_returns_none_when_allowed_cors_origins_is_none(self): + """TEST: PCORS-GE-01 -- None config signals 'use global'.""" + client = StubClient(allowed_cors_origins=None) + assert client.get_effective_origins() is None + + def test_derives_from_redirect_uris_when_plus_sign(self): + """TEST: PCORS-GE-02 -- '+' in list derives origins from redirect_uris.""" + client = StubClient( + allowed_cors_origins=["+"], + redirect_uris=[ + "https://app.example.com/callback", + "http://localhost:3000/callback", + ], + ) + assert client.get_effective_origins() == [ + "http://localhost:3000", + "https://app.example.com", + ] + + def test_derives_with_port(self): + """TEST: PCORS-GE-03 -- Non-standard ports are preserved.""" + client = StubClient( + allowed_cors_origins=["+"], + redirect_uris=["https://app.example.com:8443/cb"], + ) + assert client.get_effective_origins() == ["https://app.example.com:8443"] + + def test_deduplicates_derived_origins(self): + """TEST: PCORS-GE-04 -- Duplicate redirect URIs produce unique origins.""" + client = StubClient( + allowed_cors_origins=["+"], + redirect_uris=[ + "https://app.example.com/cb1", + "https://app.example.com/cb2", + ], + ) + assert client.get_effective_origins() == ["https://app.example.com"] + + def test_returns_list_as_is_when_normal_list(self): + """TEST: PCORS-GE-05 -- Normal list is returned unchanged.""" + origins = ["https://a.com", "https://b.com"] + client = StubClient(allowed_cors_origins=origins) + assert client.get_effective_origins() == origins + + def test_returns_wildcard_list_as_is(self): + """TEST: PCORS-GE-06 -- ['*'] is returned (handled downstream).""" + client = StubClient(allowed_cors_origins=["*"]) + assert client.get_effective_origins() == ["*"] + + def test_empty_redirect_uris_with_plus_returns_empty(self): + """TEST: PCORS-GE-07 -- '+' with empty redirect_uris yields empty list.""" + client = StubClient( + allowed_cors_origins=["+"], + redirect_uris=[], + ) + assert client.get_effective_origins() == [] + + def test_skips_malformed_redirect_uris(self): + """TEST: PCORS-GE-08 -- URIs without scheme/host are skipped.""" + client = StubClient( + allowed_cors_origins=["+"], + redirect_uris=["not-a-uri", "https://good.com/cb"], + ) + assert client.get_effective_origins() == ["https://good.com"] + + +# --------------------------------------------------------------------------- +# OIDCClient.is_origin_allowed +# --------------------------------------------------------------------------- + +class TestIsOriginAllowed: + def test_returns_none_when_no_per_client_config(self): + """TEST: PCORS-IO-01 -- None config defers to global CORS.""" + client = StubClient(allowed_cors_origins=None) + assert client.is_origin_allowed("https://anything.com") is None + + def test_returns_true_when_wildcard(self): + """TEST: PCORS-IO-02 -- '*' in effective origins allows any origin.""" + client = StubClient(allowed_cors_origins=["*"]) + assert client.is_origin_allowed("https://evil.com") is True + + def test_returns_true_for_matching_origin(self): + """TEST: PCORS-IO-03 -- Matching origin is allowed.""" + client = StubClient( + allowed_cors_origins=["https://app.example.com", "https://other.com"], + ) + assert client.is_origin_allowed("https://app.example.com") is True + + def test_returns_false_for_non_matching_origin(self): + """TEST: PCORS-IO-04 -- Non-matching origin is rejected.""" + client = StubClient(allowed_cors_origins=["https://app.example.com"]) + assert client.is_origin_allowed("https://evil.com") is False + + def test_returns_false_for_empty_list(self): + """TEST: PCORS-IO-05 -- Empty list rejects everything.""" + client = StubClient(allowed_cors_origins=[]) + assert client.is_origin_allowed("https://anything.com") is False + + +# --------------------------------------------------------------------------- +# _get_oidc_client_id_from_request +# --------------------------------------------------------------------------- + +class TestGetOidcClientIdFromRequest: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + return app + + def test_extracts_from_basic_auth(self, app): + """TEST: PCORS-CI-01 -- Basic Auth header yields client_id.""" + with app.test_request_context( + "/oidc/token", + method="POST", + headers={"Authorization": _basic_auth_header("my-client")}, + ): + assert _get_oidc_client_id_from_request() == "my-client" + + def test_extracts_from_form_body(self, app): + """TEST: PCORS-CI-02 -- Form-encoded body yields client_id.""" + with app.test_request_context( + "/oidc/token", + method="POST", + data={"client_id": "form-client", "grant_type": "client_credentials"}, + ): + assert _get_oidc_client_id_from_request() == "form-client" + + def test_extracts_from_json_body(self, app): + """TEST: PCORS-CI-03 -- JSON body yields client_id.""" + with app.test_request_context( + "/oidc/token", + method="POST", + data=json.dumps({"client_id": "json-client", "grant_type": "client_credentials"}), + content_type="application/json", + ): + assert _get_oidc_client_id_from_request() == "json-client" + + def test_extracts_from_bearer_jwt(self, app): + """TEST: PCORS-CI-04 -- Bearer JWT payload yields client_id.""" + payload = base64.urlsafe_b64encode( + json.dumps({"client_id": "jwt-client"}).encode() + ).rstrip(b"=").decode() + token = f"eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.{payload}.sig" + with app.test_request_context( + "/oidc/userinfo", + method="GET", + headers={"Authorization": f"Bearer {token}"}, + ): + assert _get_oidc_client_id_from_request() == "jwt-client" + + def test_returns_none_for_non_oidc_endpoint(self, app): + """TEST: PCORS-CI-05 -- Non-OIDC path returns None.""" + with app.test_request_context( + "/api/v1/users", + method="GET", + headers={"Authorization": _basic_auth_header("x")}, + ): + assert _get_oidc_client_id_from_request() is None + + def test_returns_none_when_no_client_id_found(self, app): + """TEST: PCORS-CI-06 -- OIDC token endpoint with no credentials returns None.""" + with app.test_request_context( + "/oidc/token", + method="POST", + data={"grant_type": "client_credentials"}, + ): + assert _get_oidc_client_id_from_request() is None + + def test_extracts_from_revoke_endpoint(self, app): + """TEST: PCORS-CI-07 -- /oidc/revoke also accepts Basic Auth.""" + with app.test_request_context( + "/oidc/revoke", + method="POST", + headers={"Authorization": _basic_auth_header("rev-client")}, + ): + assert _get_oidc_client_id_from_request() == "rev-client" + + def test_extracts_from_introspect_endpoint(self, app): + """TEST: PCORS-CI-08 -- /oidc/introspect also accepts Basic Auth.""" + with app.test_request_context( + "/oidc/introspect", + method="POST", + headers={"Authorization": _basic_auth_header("int-client")}, + ): + assert _get_oidc_client_id_from_request() == "int-client" + + def test_returns_none_for_options_preflight(self, app): + """TEST: PCORS-CI-09 -- OPTIONS preflight cannot carry client credentials.""" + with app.test_request_context( + "/oidc/token", + method="OPTIONS", + headers={ + "Origin": "https://app.com", + "Access-Control-Request-Method": "POST", + }, + ): + assert _get_oidc_client_id_from_request() is None + + +# --------------------------------------------------------------------------- +# _get_effective_cors_origins +# --------------------------------------------------------------------------- + +class TestGetEffectiveCorsOrigins: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + app.config["CORS_ORIGINS"] = ["https://global.com"] + return app + + def test_global_config_for_non_oidc_endpoint(self, app): + """TEST: PCORS-EO-01 -- Non-OIDC path always uses global config.""" + with app.test_request_context("/api/v1/users", method="GET"): + result = _get_effective_cors_origins(app, flask_request) + assert result == ["https://global.com"] + + def test_per_client_origins_for_oidc_endpoint(self, app): + """TEST: PCORS-EO-02 -- OIDC endpoint with configured client uses per-client origins.""" + fake_client = StubClient(allowed_cors_origins=["https://client.com"]) + + with app.test_request_context( + "/oidc/token", + method="POST", + headers={"Authorization": _basic_auth_header("test-client")}, + ): + with patch.object(cors_module, "OIDCClient") as MockModel: + MockModel.query.filter_by.return_value.first.return_value = fake_client + result = _get_effective_cors_origins(app, flask_request) + assert result == ["https://client.com"] + + def test_fallback_when_client_not_found(self, app): + """TEST: PCORS-EO-03 -- Unknown client_id falls back to global config.""" + with app.test_request_context( + "/oidc/token", + method="POST", + headers={"Authorization": _basic_auth_header("unknown")}, + ): + with patch.object(cors_module, "OIDCClient") as MockModel: + MockModel.query.filter_by.return_value.first.return_value = None + result = _get_effective_cors_origins(app, flask_request) + assert result == ["https://global.com"] + + def test_fallback_when_allowed_cors_origins_is_none(self, app): + """TEST: PCORS-EO-04 -- Client with None origins falls back to global.""" + fake_client = StubClient(allowed_cors_origins=None) + + with app.test_request_context( + "/oidc/token", + method="POST", + headers={"Authorization": _basic_auth_header("test-client")}, + ): + with patch.object(cors_module, "OIDCClient") as MockModel: + MockModel.query.filter_by.return_value.first.return_value = fake_client + result = _get_effective_cors_origins(app, flask_request) + assert result == ["https://global.com"] + + def test_fallback_on_db_error(self, app): + """TEST: PCORS-EO-05 -- Database exception falls back to global config.""" + with app.test_request_context( + "/oidc/token", + method="POST", + headers={"Authorization": _basic_auth_header("test-client")}, + ): + with patch.object(cors_module, "OIDCClient") as MockModel: + MockModel.query.filter_by.side_effect = Exception("DB down") + result = _get_effective_cors_origins(app, flask_request) + assert result == ["https://global.com"] + + def test_fallback_when_no_client_id_extracted(self, app): + """TEST: PCORS-EO-06 -- OIDC path with no credentials falls back to global.""" + with app.test_request_context( + "/oidc/token", + method="POST", + data={"grant_type": "client_credentials"}, + ): + result = _get_effective_cors_origins(app, flask_request) + assert result == ["https://global.com"] + + +# --------------------------------------------------------------------------- +# Integration: OIDC endpoint CORS headers +# --------------------------------------------------------------------------- + +class TestOidcEndpointCorsIntegration: + @pytest.fixture + def app_with_global_and_client(self): + """Flask app with global CORS and route stubs for integration tests.""" + app = Flask(__name__) + app.config["TESTING"] = True + app.config["CORS_ORIGINS"] = ["https://global.com"] + app.config["CORS_SUPPORTS_CREDENTIALS"] = True + + @app.route("/oidc/token", methods=["POST", "OPTIONS"]) + def oidc_token(): + return {"status": "ok"}, 200 + + @app.route("/api/v1/users", methods=["GET", "OPTIONS"]) + def api_users(): + return {"users": []}, 200 + + setup_cors(app) + return app + + def test_post_oidc_with_per_client_origin_includes_cors_headers( + self, app_with_global_and_client + ): + """TEST: PCORS-INT-01 -- POST to /oidc/token with per-client origin and + Basic Auth includes CORS headers. Per-client CORS applies to the actual + request (which carries credentials), not the preflight.""" + fake_client = StubClient(allowed_cors_origins=["https://client-app.com"]) + + with patch.object(cors_module, "OIDCClient") as MockModel: + MockModel.query.filter_by.return_value.first.return_value = fake_client + + with app_with_global_and_client.test_client() as client: + resp = client.post( + "/oidc/token", + headers={ + "Origin": "https://client-app.com", + "Authorization": _basic_auth_header("test-client"), + }, + ) + assert resp.status_code == 200 + assert ( + resp.headers.get("Access-Control-Allow-Origin") + == "https://client-app.com" + ) + assert resp.headers.get("Access-Control-Allow-Credentials") == "true" + + def test_post_oidc_with_non_matching_per_client_origin_no_cors_headers( + self, app_with_global_and_client + ): + """TEST: PCORS-INT-02 -- POST with origin not in per-client list has no + CORS headers.""" + fake_client = StubClient(allowed_cors_origins=["https://allowed.com"]) + + with patch.object(cors_module, "OIDCClient") as MockModel: + MockModel.query.filter_by.return_value.first.return_value = fake_client + + with app_with_global_and_client.test_client() as client: + resp = client.post( + "/oidc/token", + headers={ + "Origin": "https://evil.com", + "Authorization": _basic_auth_header("test-client"), + }, + ) + assert resp.status_code == 200 + assert resp.headers.get("Access-Control-Allow-Origin") is None + + def test_post_oidc_wildcard_client_echoes_origin(self, app_with_global_and_client): + """TEST: PCORS-INT-03 -- Client with '*' echoes the request origin.""" + fake_client = StubClient(allowed_cors_origins=["*"]) + + with patch.object(cors_module, "OIDCClient") as MockModel: + MockModel.query.filter_by.return_value.first.return_value = fake_client + + with app_with_global_and_client.test_client() as client: + resp = client.post( + "/oidc/token", + headers={ + "Origin": "https://any-origin.com", + "Authorization": _basic_auth_header("test-client"), + }, + ) + assert resp.status_code == 200 + assert ( + resp.headers.get("Access-Control-Allow-Origin") + == "https://any-origin.com" + ) + + def test_preflight_oidc_falls_back_to_global(self, app_with_global_and_client): + """TEST: PCORS-INT-04 -- OPTIONS preflight cannot carry client credentials, + so it uses global CORS config. A preflight from a per-client-only origin + that is not in the global list will not receive CORS headers.""" + fake_client = StubClient(allowed_cors_origins=["https://client-app.com"]) + + with patch.object(cors_module, "OIDCClient") as MockModel: + MockModel.query.filter_by.return_value.first.return_value = fake_client + + with app_with_global_and_client.test_client() as client: + resp = client.options( + "/oidc/token", + headers={ + "Origin": "https://client-app.com", + "Access-Control-Request-Method": "POST", + }, + ) + # Origin not in global list; no CORS headers on preflight + assert resp.headers.get("Access-Control-Allow-Origin") is None + + def test_preflight_oidc_with_global_origin_succeeds(self, app_with_global_and_client): + """TEST: PCORS-INT-05 -- OPTIONS preflight from a globally-allowed origin + returns 204 with CORS headers even for OIDC endpoints.""" + with app_with_global_and_client.test_client() as client: + resp = client.options( + "/oidc/token", + headers={ + "Origin": "https://global.com", + "Access-Control-Request-Method": "POST", + }, + ) + assert resp.status_code == 204 + assert resp.headers.get("Access-Control-Allow-Origin") == "https://global.com" + assert resp.headers.get("Access-Control-Allow-Credentials") == "true" + + def test_non_oidc_endpoint_uses_global_cors(self, app_with_global_and_client): + """TEST: PCORS-INT-06 -- Non-OIDC endpoint uses global CORS config.""" + with app_with_global_and_client.test_client() as client: + resp = client.options( + "/api/v1/users", + headers={ + "Origin": "https://global.com", + "Access-Control-Request-Method": "GET", + }, + ) + assert resp.status_code == 204 + assert resp.headers.get("Access-Control-Allow-Origin") == "https://global.com" + + def test_post_oidc_no_auth_uses_global_cors(self, app_with_global_and_client): + """TEST: PCORS-INT-07 -- POST to OIDC endpoint without credentials uses + global CORS (cannot identify client).""" + with app_with_global_and_client.test_client() as client: + resp = client.post( + "/oidc/token", + headers={"Origin": "https://global.com"}, + ) + assert resp.status_code == 200 + assert ( + resp.headers.get("Access-Control-Allow-Origin") == "https://global.com" + ) From 5abbadff9a871a22d97bdb9e8fb4459fbc1fca36 Mon Sep 17 00:00:00 2001 From: Cory Hawklvelt Date: Tue, 28 Apr 2026 17:17:54 +0930 Subject: [PATCH 3/4] Improve auditing --- gatehouse_app/api/v1/departments.py | 56 ++++++++++++++++ gatehouse_app/api/v1/external_auth/admin.py | 19 ++++++ gatehouse_app/api/v1/oidc.py | 15 +++++ .../api/v1/organizations/api_keys.py | 37 +++++++++-- gatehouse_app/api/v1/organizations/cas.py | 21 ++++++ gatehouse_app/api/v1/organizations/clients.py | 30 +++++++++ gatehouse_app/api/v1/organizations/core.py | 26 ++++++++ gatehouse_app/api/v1/organizations/invites.py | 9 +++ gatehouse_app/api/v1/organizations/members.py | 36 +++++++++- gatehouse_app/api/v1/organizations/roles.py | 23 ++++++- gatehouse_app/api/v1/principals.py | 65 +++++++++++++++++++ gatehouse_app/api/v1/superadmin/auth.py | 2 + gatehouse_app/utils/constants.py | 21 ++++++ 13 files changed, 354 insertions(+), 6 deletions(-) diff --git a/gatehouse_app/api/v1/departments.py b/gatehouse_app/api/v1/departments.py index 4ced781..117db75 100644 --- a/gatehouse_app/api/v1/departments.py +++ b/gatehouse_app/api/v1/departments.py @@ -10,6 +10,8 @@ from gatehouse_app.models import Department, DepartmentMembership from gatehouse_app.services.organization_service import OrganizationService from gatehouse_app.services.user_service import UserService from gatehouse_app.extensions import db +from gatehouse_app.utils.constants import AuditAction +from gatehouse_app.services.audit_service import AuditService class DepartmentCreateSchema(Schema): @@ -127,6 +129,15 @@ def create_department(org_id): db.session.add(dept) db.session.commit() + AuditService.log_action( + action=AuditAction.DEPARTMENT_CREATED, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="department", + resource_id=str(dept.id), + description=f"Department '{dept.name}' created", + ) + return api_response( data={"department": dept.to_dict()}, message="Department created successfully", @@ -255,6 +266,15 @@ def update_department(org_id, dept_id): db.session.commit() + AuditService.log_action( + action=AuditAction.DEPARTMENT_UPDATED, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="department", + resource_id=str(dept.id), + description=f"Department '{dept.name}' updated", + ) + return api_response( data={"department": dept.to_dict()}, message="Department updated successfully", @@ -308,6 +328,15 @@ def delete_department(org_id, dept_id): dept.deleted_at = db.func.now() db.session.commit() + AuditService.log_action( + action=AuditAction.DEPARTMENT_DELETED, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="department", + resource_id=str(dept.id), + description=f"Department '{dept.name}' deleted", + ) + return api_response( message="Department deleted successfully", ) @@ -461,6 +490,15 @@ def add_department_member(org_id, dept_id): db.session.commit() + AuditService.log_action( + action=AuditAction.DEPARTMENT_MEMBER_ADDED, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="user", + resource_id=str(user.id), + description=f"Added user {user.email} to department '{dept.name}'", + ) + member_dict = membership.to_dict() member_dict["user"] = user.to_dict() @@ -533,6 +571,15 @@ def remove_department_member(org_id, dept_id, user_id): membership.deleted_at = db.func.now() db.session.commit() + AuditService.log_action( + action=AuditAction.DEPARTMENT_MEMBER_REMOVED, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="user", + resource_id=str(user_id), + description=f"Removed user from department '{dept.name}'", + ) + return api_response( message="Member removed successfully", ) @@ -699,5 +746,14 @@ def set_dept_cert_policy(org_id, dept_id): db.session.commit() + AuditService.log_action( + action=AuditAction.DEPARTMENT_CERT_POLICY_UPDATED, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="department", + resource_id=str(dept_id), + description=f"Certificate policy updated for department '{dept.name}'", + ) + return api_response(data={"cert_policy": policy.to_dict()}, message="Certificate policy saved") diff --git a/gatehouse_app/api/v1/external_auth/admin.py b/gatehouse_app/api/v1/external_auth/admin.py index d1149b3..1a9de57 100644 --- a/gatehouse_app/api/v1/external_auth/admin.py +++ b/gatehouse_app/api/v1/external_auth/admin.py @@ -3,6 +3,8 @@ from flask import g, request from gatehouse_app.api.v1 import api_v1_bp from gatehouse_app.utils.response import api_response from gatehouse_app.utils.decorators import login_required +from gatehouse_app.utils.constants import AuditAction +from gatehouse_app.services.audit_service import AuditService @api_v1_bp.route("/admin/oauth/providers", methods=["GET"]) @@ -78,6 +80,14 @@ def admin_configure_app_provider(provider: str): db.session.add(cfg) db.session.commit() + AuditService.log_action( + action=AuditAction.EXTERNAL_AUTH_CONFIG_UPDATE if cfg else AuditAction.EXTERNAL_AUTH_CONFIG_CREATE, + user_id=g.current_user.id, + resource_type="oauth_provider", + resource_id=provider, + description=f"OAuth provider '{provider}' configured (enabled={cfg.is_enabled})", + ) + return api_response( data={"provider": {"id": provider, "client_id": cfg.client_id, "is_enabled": cfg.is_enabled}}, message=f"{provider.capitalize()} OAuth provider configured successfully", @@ -104,4 +114,13 @@ def admin_delete_app_provider(provider: str): return api_response(success=False, message=f"Provider '{provider}' is not configured", status=404, error_type="NOT_FOUND") cfg.delete() + + AuditService.log_action( + action=AuditAction.EXTERNAL_AUTH_CONFIG_DELETE, + user_id=g.current_user.id, + resource_type="oauth_provider", + resource_id=provider, + description=f"OAuth provider '{provider}' configuration removed", + ) + return api_response(message=f"{provider.capitalize()} OAuth provider configuration removed") diff --git a/gatehouse_app/api/v1/oidc.py b/gatehouse_app/api/v1/oidc.py index 07a9366..e4f01f4 100644 --- a/gatehouse_app/api/v1/oidc.py +++ b/gatehouse_app/api/v1/oidc.py @@ -26,6 +26,9 @@ from gatehouse_app.exceptions.auth_exceptions import ( AccountSuspendedError, AccountInactiveError, ) +from gatehouse_app.utils.constants import AuditAction +from gatehouse_app.services.audit_service import AuditService +from gatehouse_app.services.oidc_audit_service import OIDCAuditService logger = logging.getLogger(__name__) @@ -849,6 +852,18 @@ def oidc_register(): ) client.save() + OIDCAuditService.log_event( + event_type="client_registration", + client_id=client_id, + user_id=g.current_user.id if hasattr(g, "current_user") else None, + success=True, + metadata={ + "client_name": client_name, + "redirect_uris": redirect_uris, + "organization_id": str(organization.id), + }, + ) + response = jsonify({ "client_id": client_id, "client_secret": client_secret, diff --git a/gatehouse_app/api/v1/organizations/api_keys.py b/gatehouse_app/api/v1/organizations/api_keys.py index 90d83ee..878c180 100644 --- a/gatehouse_app/api/v1/organizations/api_keys.py +++ b/gatehouse_app/api/v1/organizations/api_keys.py @@ -8,6 +8,8 @@ from gatehouse_app.utils.decorators import login_required, require_admin, full_a from gatehouse_app.models.organization import OrganizationApiKey from gatehouse_app.services.organization_service import OrganizationService from gatehouse_app.extensions import db +from gatehouse_app.utils.constants import AuditAction +from gatehouse_app.services.audit_service import AuditService class ApiKeyCreateSchema(Schema): @@ -130,7 +132,16 @@ def create_api_key(org_id): name=data["name"], description=data.get("description"), ) - + + AuditService.log_action( + action=AuditAction.ORG_API_KEY_CREATED, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="api_key", + resource_id=str(api_key.id), + description=f"API key '{api_key.name}' created", + ) + # Return the key data with the plain text key (only on creation) key_dict = api_key.to_dict() key_dict["key"] = plain_key # Include plain text only on creation @@ -219,9 +230,18 @@ def update_api_key(org_id, key_id): api_key.name = data["name"] if "description" in data: api_key.description = data["description"] - + api_key.save() - + + AuditService.log_action( + action=AuditAction.ORG_API_KEY_UPDATED, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="api_key", + resource_id=str(api_key.id), + description=f"API key '{api_key.name}' updated", + ) + return api_response( data={"api_key": api_key.to_dict()}, message="API key updated successfully", @@ -293,7 +313,16 @@ def delete_api_key(org_id, key_id): # Soft delete the API key api_key.delete(soft=True) - + + AuditService.log_action( + action=AuditAction.ORG_API_KEY_DELETED, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="api_key", + resource_id=str(api_key.id), + description=f"API key '{api_key.name}' deleted", + ) + return api_response( message="API key deleted successfully", ) diff --git a/gatehouse_app/api/v1/organizations/cas.py b/gatehouse_app/api/v1/organizations/cas.py index ad1ac71..d1567b4 100644 --- a/gatehouse_app/api/v1/organizations/cas.py +++ b/gatehouse_app/api/v1/organizations/cas.py @@ -6,6 +6,8 @@ from gatehouse_app.utils.response import api_response from gatehouse_app.utils.decorators import login_required, require_admin from gatehouse_app.extensions import db from gatehouse_app.api.v1.organizations._helpers import _get_system_ca_dict +from gatehouse_app.utils.constants import AuditAction +from gatehouse_app.services.audit_service import AuditService @api_v1_bp.route("/organizations//cas", methods=["GET"]) @@ -66,6 +68,16 @@ def update_org_ca(org_id, ca_id): ca.max_cert_validity_hours = data["max_cert_validity_hours"] db.session.commit() + + AuditService.log_action( + action=AuditAction.CA_UPDATED, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="CA", + resource_id=ca_id, + description=f"CA '{ca.name}' updated", + ) + return api_response(data={"ca": ca.to_dict()}, message="CA updated successfully") except ValidationError as e: return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) @@ -150,6 +162,15 @@ def create_org_ca(org_id): return api_response(success=False, message="A CA with that name already exists in this organization (it may have been recently deleted — choose a different name).", status=400, error_type="DUPLICATE_NAME") raise + AuditService.log_action( + action=AuditAction.CA_CREATED, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="CA", + resource_id=str(ca.id), + description=f"CA '{ca.name}' created", + ) + return api_response(data={"ca": ca.to_dict()}, message="CA created successfully", status=201) except MaValidationError as e: return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) diff --git a/gatehouse_app/api/v1/organizations/clients.py b/gatehouse_app/api/v1/organizations/clients.py index 3cbbaba..8fb7334 100644 --- a/gatehouse_app/api/v1/organizations/clients.py +++ b/gatehouse_app/api/v1/organizations/clients.py @@ -5,6 +5,8 @@ from gatehouse_app.api.v1 import api_v1_bp from gatehouse_app.utils.response import api_response from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required from gatehouse_app.extensions import db, bcrypt +from gatehouse_app.utils.constants import AuditAction +from gatehouse_app.services.audit_service import AuditService @api_v1_bp.route("/organizations//clients", methods=["GET"]) @@ -79,6 +81,15 @@ def create_org_client(org_id): db.session.add(client) db.session.commit() + AuditService.log_action( + action=AuditAction.ORG_CLIENT_CREATED, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="oidc_client", + resource_id=str(client.id), + description=f"OIDC client '{client.name}' created", + ) + return api_response( data={ "client": { @@ -126,6 +137,15 @@ def update_org_client(org_id, client_id): db.session.commit() + AuditService.log_action( + action=AuditAction.ORG_CLIENT_UPDATED, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="oidc_client", + resource_id=str(client.id), + description=f"OIDC client '{client.name}' updated", + ) + return api_response( data={ "client": { @@ -155,4 +175,14 @@ def delete_org_client(org_id, client_id): client.is_active = False db.session.commit() + + AuditService.log_action( + action=AuditAction.ORG_CLIENT_DEACTIVATED, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="oidc_client", + resource_id=str(client.id), + description=f"OIDC client '{client.name}' deactivated", + ) + return api_response(data={}, message="Client deactivated successfully") diff --git a/gatehouse_app/api/v1/organizations/core.py b/gatehouse_app/api/v1/organizations/core.py index 940c089..6c785b9 100644 --- a/gatehouse_app/api/v1/organizations/core.py +++ b/gatehouse_app/api/v1/organizations/core.py @@ -7,6 +7,8 @@ from gatehouse_app.utils.response import api_response from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required from gatehouse_app.schemas.organization_schema import OrganizationCreateSchema, OrganizationUpdateSchema from gatehouse_app.services.organization_service import OrganizationService +from gatehouse_app.utils.constants import AuditAction +from gatehouse_app.services.audit_service import AuditService @api_v1_bp.route("/organizations", methods=["POST"]) @@ -32,6 +34,14 @@ def create_organization(): description=data.get("description"), logo_url=data.get("logo_url"), ) + AuditService.log_action( + action=AuditAction.ORG_CREATE, + user_id=g.current_user.id, + organization_id=org.id, + resource_type="organization", + resource_id=str(org.id), + description=f"Organization '{org.name}' created", + ) return api_response(data={"organization": org.to_dict()}, message="Organization created successfully", status=201) except ValidationError as e: return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) @@ -60,6 +70,14 @@ def update_organization(org_id): data = schema.load(request.json) org = OrganizationService.get_organization_by_id(org_id) org = OrganizationService.update_organization(org=org, user_id=g.current_user.id, **data) + AuditService.log_action( + action=AuditAction.ORG_UPDATE, + user_id=g.current_user.id, + organization_id=org.id, + resource_type="organization", + resource_id=str(org.id), + description=f"Organization '{org.name}' updated", + ) return api_response(data={"organization": org.to_dict()}, message="Organization updated successfully") except ValidationError as e: return api_response(success=False, message="Validation failed", status=400, error_type="VALIDATION_ERROR", error_details=e.messages) @@ -92,4 +110,12 @@ def delete_organization(org_id): ) OrganizationService.force_delete_organization(org=org, user_id=caller.id) + AuditService.log_action( + action=AuditAction.ORG_DELETE, + user_id=caller.id, + organization_id=org.id, + resource_type="organization", + resource_id=str(org.id), + description=f"Organization '{org.name}' deleted", + ) return api_response(message="Organization deleted successfully") diff --git a/gatehouse_app/api/v1/organizations/invites.py b/gatehouse_app/api/v1/organizations/invites.py index 1ac718a..b5b7ea8 100644 --- a/gatehouse_app/api/v1/organizations/invites.py +++ b/gatehouse_app/api/v1/organizations/invites.py @@ -136,6 +136,15 @@ def cancel_org_invite(org_id, invite_id): return api_response(success=False, message="Invite not found", status=404) invite.delete(soft=True) + AuditService.log_action( + action=AuditAction.ORG_INVITE_CANCELLED, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="org_invite", + resource_id=invite.id, + metadata={"invited_email": invite.email, "role": invite.role}, + description=f"Invitation for {invite.email} cancelled", + ) return api_response(data={}, message="Invite cancelled") diff --git a/gatehouse_app/api/v1/organizations/members.py b/gatehouse_app/api/v1/organizations/members.py index b42a99c..4ebe750 100644 --- a/gatehouse_app/api/v1/organizations/members.py +++ b/gatehouse_app/api/v1/organizations/members.py @@ -7,7 +7,8 @@ from gatehouse_app.utils.decorators import login_required, require_admin, full_a from gatehouse_app.schemas.organization_schema import InviteMemberSchema, UpdateMemberRoleSchema from gatehouse_app.services.organization_service import OrganizationService from gatehouse_app.services.user_service import UserService -from gatehouse_app.utils.constants import OrganizationRole +from gatehouse_app.utils.constants import AuditAction, OrganizationRole +from gatehouse_app.services.audit_service import AuditService @api_v1_bp.route("/organizations//members", methods=["GET"]) @@ -43,6 +44,14 @@ def add_organization_member(org_id): role = OrganizationRole(data["role"]) member = OrganizationService.add_member(org=org, user_id=user.id, role=role, inviter_id=g.current_user.id) + AuditService.log_action( + action=AuditAction.ORG_MEMBER_ADD, + user_id=g.current_user.id, + organization_id=org.id, + resource_type="user", + resource_id=str(user.id), + description=f"Added user {user.email} to organization with role {role.value}", + ) member_dict = member.to_dict() member_dict["user"] = user.to_dict() return api_response(data={"member": member_dict}, message="Member added successfully", status=201) @@ -60,6 +69,14 @@ def remove_organization_member(org_id, user_id): OrganizationService.remove_member(org=org, user_id=user_id, remover_id=g.current_user.id) except ValueError as e: return api_response(success=False, message=str(e), status=403, error_type="OWNER_PROTECTION") + AuditService.log_action( + action=AuditAction.ORG_MEMBER_REMOVE, + user_id=g.current_user.id, + organization_id=org.id, + resource_type="user", + resource_id=str(user_id), + description=f"Removed user {user_id} from organization", + ) return api_response(message="Member removed successfully") @@ -74,6 +91,14 @@ def update_member_role(org_id, user_id): org = OrganizationService.get_organization_by_id(org_id) new_role = OrganizationRole(data["role"]) member = OrganizationService.update_member_role(org=org, user_id=user_id, new_role=new_role, updater_id=g.current_user.id) + AuditService.log_action( + action=AuditAction.ORG_MEMBER_ROLE_CHANGE, + user_id=g.current_user.id, + organization_id=org.id, + resource_type="user", + resource_id=str(user_id), + description=f"Changed role for user {user_id} to {new_role.value}", + ) member_dict = member.to_dict() member_dict["user"] = member.user.to_dict() return api_response(data={"member": member_dict}, message="Member role updated successfully") @@ -180,4 +205,13 @@ def send_mfa_reminder(org_id, user_id): html_body=html_body, ) + AuditService.log_action( + action=AuditAction.ORG_MFA_REMINDER_SENT, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="user", + resource_id=str(user_id), + description=f"MFA reminder sent to {user.email}", + ) + return api_response(data={}, message="Reminder sent successfully") diff --git a/gatehouse_app/api/v1/organizations/roles.py b/gatehouse_app/api/v1/organizations/roles.py index 3c982e9..2f8451e 100644 --- a/gatehouse_app/api/v1/organizations/roles.py +++ b/gatehouse_app/api/v1/organizations/roles.py @@ -3,8 +3,9 @@ from flask import g, request from gatehouse_app.api.v1 import api_v1_bp from gatehouse_app.utils.response import api_response from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required -from gatehouse_app.utils.constants import OrganizationRole +from gatehouse_app.utils.constants import AuditAction, OrganizationRole from gatehouse_app.extensions import db +from gatehouse_app.services.audit_service import AuditService @api_v1_bp.route("/organizations//roles", methods=["GET"]) @@ -59,6 +60,16 @@ def assign_role_to_member(org_id, role_name): membership.role = new_role db.session.commit() + + AuditService.log_action( + action=AuditAction.ORG_MEMBER_ROLE_CHANGE, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="user", + resource_id=str(target_user_id), + description=f"Role changed to {new_role.value} for user {target_user_id}", + ) + return api_response(data={"user_id": target_user_id, "role": new_role.value}, message=f"Role updated to {new_role.value}") @@ -82,4 +93,14 @@ def remove_role_from_member(org_id, role_name, user_id): org = OrganizationService.get_organization_by_id(org_id) OrganizationService.remove_member(org=org, user_id=user_id, remover_id=g.current_user.id) + + AuditService.log_action( + action=AuditAction.ORG_MEMBER_REMOVE, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="user", + resource_id=str(user_id), + description=f"Member {user_id} removed from organization via role removal", + ) + return api_response(data={"user_id": user_id}, message="Member removed from organization") diff --git a/gatehouse_app/api/v1/principals.py b/gatehouse_app/api/v1/principals.py index 0da8315..3a71667 100644 --- a/gatehouse_app/api/v1/principals.py +++ b/gatehouse_app/api/v1/principals.py @@ -10,6 +10,8 @@ from gatehouse_app.services.organization_service import OrganizationService from gatehouse_app.services.user_service import UserService from gatehouse_app.exceptions import OrganizationNotFoundError from gatehouse_app.extensions import db +from gatehouse_app.utils.constants import AuditAction +from gatehouse_app.services.audit_service import AuditService class PrincipalCreateSchema(Schema): @@ -127,6 +129,15 @@ def create_principal(org_id): db.session.add(principal) db.session.commit() + AuditService.log_action( + action=AuditAction.PRINCIPAL_CREATED, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="principal", + resource_id=str(principal.id), + description=f"Principal '{principal.name}' created", + ) + return api_response( data={"principal": principal.to_dict()}, message="Principal created successfully", @@ -255,6 +266,15 @@ def update_principal(org_id, principal_id): db.session.commit() + AuditService.log_action( + action=AuditAction.PRINCIPAL_UPDATED, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="principal", + resource_id=str(principal.id), + description=f"Principal '{principal.name}' updated", + ) + return api_response( data={"principal": principal.to_dict()}, message="Principal updated successfully", @@ -308,6 +328,15 @@ def delete_principal(org_id, principal_id): principal.deleted_at = db.func.now() db.session.commit() + AuditService.log_action( + action=AuditAction.PRINCIPAL_DELETED, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="principal", + resource_id=str(principal.id), + description=f"Principal '{principal.name}' deleted", + ) + return api_response( message="Principal deleted successfully", ) @@ -476,6 +505,15 @@ def add_principal_member(org_id, principal_id): db.session.commit() + AuditService.log_action( + action=AuditAction.PRINCIPAL_MEMBER_ADDED, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="user", + resource_id=str(user.id), + description=f"Added user {user.email} to principal '{principal.name}'", + ) + member_dict = membership.to_dict() member_dict["user"] = user.to_dict() @@ -548,6 +586,15 @@ def remove_principal_member(org_id, principal_id, user_id): membership.deleted_at = db.func.now() db.session.commit() + AuditService.log_action( + action=AuditAction.PRINCIPAL_MEMBER_REMOVED, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="user", + resource_id=str(user_id), + description=f"Removed user from principal '{principal.name}'", + ) + return api_response( message="Member removed successfully", ) @@ -697,6 +744,15 @@ def link_principal_to_department(org_id, principal_id, dept_id): error_type="SERVER_ERROR", ) + AuditService.log_action( + action=AuditAction.PRINCIPAL_DEPARTMENT_LINKED, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="principal", + resource_id=str(principal_id), + description=f"Principal '{principal.name}' linked to department '{dept.name}'", + ) + return api_response( data={ "principal": principal.to_dict(), @@ -774,6 +830,15 @@ def unlink_principal_from_department(org_id, principal_id, dept_id): link.deleted_at = db.func.now() db.session.commit() + AuditService.log_action( + action=AuditAction.PRINCIPAL_DEPARTMENT_UNLINKED, + user_id=g.current_user.id, + organization_id=org_id, + resource_type="principal", + resource_id=str(principal_id), + description=f"Principal '{principal.name}' unlinked from department '{dept.name}'", + ) + return api_response( message="Principal unlinked from department successfully", ) diff --git a/gatehouse_app/api/v1/superadmin/auth.py b/gatehouse_app/api/v1/superadmin/auth.py index e45e63c..48fac09 100644 --- a/gatehouse_app/api/v1/superadmin/auth.py +++ b/gatehouse_app/api/v1/superadmin/auth.py @@ -9,6 +9,7 @@ from gatehouse_app.utils.response import api_response from gatehouse_app.services.superadmin_auth_service import SuperadminAuthService from gatehouse_app.decorators.superadmin import superadmin_required, superadmin_audit_log from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError +from gatehouse_app.utils.constants import AuditAction logger = logging.getLogger(__name__) @@ -105,6 +106,7 @@ def login(): @superadmin_bp.route("/auth/logout", methods=["POST"]) @superadmin_required +@superadmin_audit_log(action=AuditAction.USER_LOGOUT, resource_type="session") def logout(): """Superadmin logout endpoint. diff --git a/gatehouse_app/utils/constants.py b/gatehouse_app/utils/constants.py index 1d600b7..e31bcab 100644 --- a/gatehouse_app/utils/constants.py +++ b/gatehouse_app/utils/constants.py @@ -154,6 +154,27 @@ class AuditAction(str, Enum): DEPARTMENT_DELETED = "department.deleted" DEPARTMENT_MEMBER_ADDED = "department.member.added" DEPARTMENT_MEMBER_REMOVED = "department.member.removed" + DEPARTMENT_CERT_POLICY_UPDATED = "department.cert_policy.updated" + + # Organization invite actions + ORG_INVITE_CANCELLED = "org.invite.cancelled" + + # MFA reminder + ORG_MFA_REMINDER_SENT = "org.mfa_reminder.sent" + + # API key actions + ORG_API_KEY_CREATED = "org.api_key.created" + ORG_API_KEY_UPDATED = "org.api_key.updated" + ORG_API_KEY_DELETED = "org.api_key.deleted" + + # OIDC client actions + ORG_CLIENT_CREATED = "org.client.created" + ORG_CLIENT_UPDATED = "org.client.updated" + ORG_CLIENT_DEACTIVATED = "org.client.deactivated" + + # Principal department link actions + PRINCIPAL_DEPARTMENT_LINKED = "principal.department.linked" + PRINCIPAL_DEPARTMENT_UNLINKED = "principal.department.unlinked" class OIDCGrantType(str, Enum): From 803bf4f4f24d5aee7b6e25db2cb6f42177d7d43f Mon Sep 17 00:00:00 2001 From: Cory Hawklvelt Date: Tue, 28 Apr 2026 20:54:15 +0930 Subject: [PATCH 4/4] refactor: consolidate user and superadmin sessions into unified model --- gatehouse_app/decorators/superadmin.py | 16 +- gatehouse_app/models/__init__.py | 2 - gatehouse_app/models/superadmin/__init__.py | 4 +- gatehouse_app/models/superadmin/superadmin.py | 12 +- .../models/superadmin/superadmin_session.py | 80 -------- gatehouse_app/models/user/session.py | 66 +++++-- gatehouse_app/services/auth_service.py | 4 +- gatehouse_app/services/session_service.py | 75 +++++-- .../services/superadmin_auth_service.py | 24 ++- gatehouse_app/utils/constants.py | 7 + .../c8d2e4f6a1b3_consolidate_sessions.py | 122 ++++++++++++ .../test_superadmin_session_timeouts.py | 186 ++++++++++++++++++ 12 files changed, 472 insertions(+), 126 deletions(-) delete mode 100644 gatehouse_app/models/superadmin/superadmin_session.py create mode 100644 migrations/versions/c8d2e4f6a1b3_consolidate_sessions.py create mode 100644 tests/integration/test_superadmin_session_timeouts.py diff --git a/gatehouse_app/decorators/superadmin.py b/gatehouse_app/decorators/superadmin.py index c318a68..d7401ac 100644 --- a/gatehouse_app/decorators/superadmin.py +++ b/gatehouse_app/decorators/superadmin.py @@ -15,7 +15,7 @@ def superadmin_required(f): """Decorator to require superadmin Bearer token authentication. Extracts token from Authorization: Bearer {token} header, - validates the session against SuperadminSession table, + validates the session against the unified sessions table, and sets g.current_superadmin and g.superadmin_session. Returns 401 if no valid session, 403 if not a superadmin. @@ -46,10 +46,14 @@ def superadmin_required(f): token = parts[1] # Import here to avoid circular imports - from gatehouse_app.models.superadmin import SuperadminSession, Superadmin + from gatehouse_app.models.user.session import Session + from gatehouse_app.models.superadmin import Superadmin + from gatehouse_app.utils.constants import SessionType - # Get active session by token - session = SuperadminSession.query.filter_by(token=token).first() + # Get active session by token, scoped to superadmin + session = Session.query.filter_by( + token=token, owner_type=SessionType.SUPERADMIN + ).first() if not session: return api_response( @@ -68,8 +72,8 @@ def superadmin_required(f): error_type="SESSION_INACTIVE" ) - # Get the superadmin - superadmin = session.superadmin + # Get the superadmin by owner_id + superadmin = Superadmin.query.get(session.owner_id) if not superadmin: return api_response( success=False, diff --git a/gatehouse_app/models/__init__.py b/gatehouse_app/models/__init__.py index fe62307..3378ecb 100644 --- a/gatehouse_app/models/__init__.py +++ b/gatehouse_app/models/__init__.py @@ -118,7 +118,6 @@ from gatehouse_app.models.zerotier import ( # noqa: F401 from gatehouse_app.models.superadmin import ( # noqa: F401 Superadmin, SuperadminSession, - SuperadminSessionStatus, ) from gatehouse_app.models.superadmin_audit_log import SuperadminAuditLog # noqa: F401 from gatehouse_app.models.security.user_security_policy import ( # noqa: F401 @@ -186,6 +185,5 @@ __all__ = [ # Superadmin "Superadmin", "SuperadminSession", - "SuperadminSessionStatus", "SuperadminAuditLog", ] diff --git a/gatehouse_app/models/superadmin/__init__.py b/gatehouse_app/models/superadmin/__init__.py index ff9fddc..529038f 100644 --- a/gatehouse_app/models/superadmin/__init__.py +++ b/gatehouse_app/models/superadmin/__init__.py @@ -1,5 +1,5 @@ """Superadmin models.""" from gatehouse_app.models.superadmin.superadmin import Superadmin -from gatehouse_app.models.superadmin.superadmin_session import SuperadminSession, SuperadminSessionStatus +from gatehouse_app.models.user.session import Session as SuperadminSession -__all__ = ["Superadmin", "SuperadminSession", "SuperadminSessionStatus"] +__all__ = ["Superadmin", "SuperadminSession"] diff --git a/gatehouse_app/models/superadmin/superadmin.py b/gatehouse_app/models/superadmin/superadmin.py index 525b410..bdbd495 100644 --- a/gatehouse_app/models/superadmin/superadmin.py +++ b/gatehouse_app/models/superadmin/superadmin.py @@ -23,11 +23,15 @@ class Superadmin(BaseModel): is_active = db.Column(db.Boolean, default=True, nullable=False) last_login_at = db.Column(db.DateTime, nullable=True) - # Relationship to sessions + # Relationship to sessions (unified model, scoped to superadmin owner_type) sessions = db.relationship( - "SuperadminSession", - back_populates="superadmin", - cascade="all, delete-orphan" + "Session", + primaryjoin=( + "and_(Superadmin.id == foreign(Session.owner_id), " + "Session.owner_type == 'superadmin')" + ), + cascade="all, delete-orphan", + lazy="dynamic", ) # Relationship to audit logs diff --git a/gatehouse_app/models/superadmin/superadmin_session.py b/gatehouse_app/models/superadmin/superadmin_session.py deleted file mode 100644 index 8eef79b..0000000 --- a/gatehouse_app/models/superadmin/superadmin_session.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Superadmin session model.""" -import logging -from datetime import datetime, timezone, timedelta - -from gatehouse_app.extensions import db -from gatehouse_app.models.base import BaseModel - - -logger = logging.getLogger(__name__) - - -class SuperadminSessionStatus: - """Session status constants.""" - ACTIVE = "active" - REVOKED = "revoked" - EXPIRED = "expired" - - -class SuperadminSession(BaseModel): - """Session model for superadmin authentication.""" - - __tablename__ = "superadmin_sessions" - - superadmin_id = db.Column( - db.String(36), - db.ForeignKey("superadmins.id"), - nullable=False, - index=True - ) - token = db.Column(db.String(255), unique=True, nullable=False, index=True) - expires_at = db.Column(db.DateTime, nullable=False) - last_activity_at = db.Column( - db.DateTime, - nullable=False, - default=lambda: datetime.now(timezone.utc) - ) - ip_address = db.Column(db.String(45), nullable=True) - user_agent = db.Column(db.Text, nullable=True) - revoked_at = db.Column(db.DateTime, nullable=True) - revoked_reason = db.Column(db.String(255), nullable=True) - - # Relationship - superadmin = db.relationship("Superadmin", back_populates="sessions") - - def __repr__(self): - return f"" - - def is_active(self): - """Check if session is currently active.""" - now = datetime.now(timezone.utc) - expires_at = self.expires_at - if expires_at.tzinfo is None: - expires_at = expires_at.replace(tzinfo=timezone.utc) - return ( - self.deleted_at is None - and self.revoked_at is None - and expires_at > now - ) - - def is_expired(self): - """Check if session has expired.""" - now = datetime.now(timezone.utc) - expires_at = self.expires_at - if expires_at.tzinfo is None: - expires_at = expires_at.replace(tzinfo=timezone.utc) - return now > expires_at - - def revoke(self, reason: str = None): - """Revoke the session.""" - self.revoked_at = datetime.now(timezone.utc) - if reason: - self.revoked_reason = reason - from gatehouse_app import db - db.session.commit() - - def to_dict(self, exclude=None): - """Convert to dictionary, excluding sensitive fields.""" - exclude = exclude or [] - exclude.append("token") - return super().to_dict(exclude=exclude) diff --git a/gatehouse_app/models/user/session.py b/gatehouse_app/models/user/session.py index e6300dd..cf8ee32 100644 --- a/gatehouse_app/models/user/session.py +++ b/gatehouse_app/models/user/session.py @@ -3,15 +3,24 @@ from datetime import datetime, timedelta, timezone from flask import current_app from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel -from gatehouse_app.utils.constants import SessionStatus +from gatehouse_app.utils.constants import SessionStatus, SessionType class Session(BaseModel): - """Session model for tracking user sessions.""" + """Session model for tracking user and superadmin sessions.""" __tablename__ = "sessions" - user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=False, index=True) + # Owner discriminator — determines which table the owner_id references + owner_type = db.Column( + db.String(20), nullable=False, default=SessionType.USER, index=True + ) + owner_id = db.Column(db.String(36), nullable=False, index=True) + + # Legacy column kept for backward compatibility during migration; + # new code should use owner_id / owner_type. + user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True, index=True) + token = db.Column(db.String(255), unique=True, nullable=False, index=True) status = db.Column(db.Enum(SessionStatus), default=SessionStatus.ACTIVE, nullable=False) @@ -34,21 +43,37 @@ class Session(BaseModel): # Relationships user = db.relationship("User", back_populates="sessions") + # Composite index for owner-scoped queries + __table_args__ = ( + db.Index("ix_sessions_owner_type_owner_id", "owner_type", "owner_id"), + ) + + # ---- Convenience properties ------------------------------------------------ + + @property + def is_user(self): + return self.owner_type == SessionType.USER + + @property + def is_superadmin(self): + return self.owner_type == SessionType.SUPERADMIN + + # ---- Core methods ---------------------------------------------------------- + def __repr__(self): - """String representation of Session.""" - return f"" + return f"" def is_active(self): """Check if session is currently active. - Sessions are evaluated against two independent timeouts: + User sessions are evaluated against two independent timeouts: - Idle timeout: expires if no request has been made within SESSION_IDLE_TIMEOUT seconds (default 15 min). - Absolute timeout: expires if SESSION_ABSOLUTE_TIMEOUT seconds - have elapsed since the session was created (default 8 h), - regardless of activity. + have elapsed since the session was created (default 8 h). - A session must satisfy *both* constraints to remain active. + Superadmin sessions use absolute timeout only (no idle timeout). + A session must satisfy *all* applicable constraints to remain active. """ now = datetime.now(timezone.utc) created_at = self.created_at @@ -59,12 +84,21 @@ class Session(BaseModel): if last_activity_at.tzinfo is None: last_activity_at = last_activity_at.replace(tzinfo=timezone.utc) - idle_timeout = current_app.config.get("SESSION_IDLE_TIMEOUT", 900) absolute_timeout = current_app.config.get("SESSION_ABSOLUTE_TIMEOUT", 28800) - - idle_expires_at = last_activity_at + timedelta(seconds=idle_timeout) absolute_expires_at = created_at + timedelta(seconds=absolute_timeout) + if self.is_superadmin: + # Superadmin: absolute timeout only + return ( + self.status == SessionStatus.ACTIVE + and now < absolute_expires_at + and self.deleted_at is None + ) + + # User: idle + absolute timeout + idle_timeout = current_app.config.get("SESSION_IDLE_TIMEOUT", 900) + idle_expires_at = last_activity_at + timedelta(seconds=idle_timeout) + return ( self.status == SessionStatus.ACTIVE and now < idle_expires_at @@ -83,6 +117,8 @@ class Session(BaseModel): capped so that the session never exceeds the absolute lifetime (``created_at + absolute timeout``). + Superadmin sessions only update last_activity_at (no sliding window). + Args: duration_seconds: Override for the idle timeout. When *None* (the common case), the value is read from @@ -90,6 +126,12 @@ class Session(BaseModel): """ now = datetime.now(timezone.utc) + if self.is_superadmin: + # Superadmin: just bump last_activity_at, no sliding window + self.last_activity_at = now + db.session.commit() + return + if duration_seconds is None: duration_seconds = current_app.config.get("SESSION_IDLE_TIMEOUT", 900) diff --git a/gatehouse_app/services/auth_service.py b/gatehouse_app/services/auth_service.py index c662ba8..af438e0 100644 --- a/gatehouse_app/services/auth_service.py +++ b/gatehouse_app/services/auth_service.py @@ -8,7 +8,7 @@ from gatehouse_app.extensions import db, bcrypt from gatehouse_app.models.user.user import User from gatehouse_app.models.auth.authentication_method import AuthenticationMethod from gatehouse_app.models.user.session import Session -from gatehouse_app.utils.constants import AuthMethodType, SessionStatus, UserStatus, AuditAction +from gatehouse_app.utils.constants import AuthMethodType, SessionStatus, SessionType, UserStatus, AuditAction from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError, AccountSuspendedError, AccountInactiveError from gatehouse_app.exceptions.validation_exceptions import EmailAlreadyExistsError from gatehouse_app.services.audit_service import AuditService @@ -165,6 +165,8 @@ class AuthService: # Create session session = Session( + owner_type=SessionType.USER, + owner_id=user.id, user_id=user.id, token=token, status=SessionStatus.ACTIVE, diff --git a/gatehouse_app/services/session_service.py b/gatehouse_app/services/session_service.py index e86cd6f..6f3e595 100644 --- a/gatehouse_app/services/session_service.py +++ b/gatehouse_app/services/session_service.py @@ -1,7 +1,7 @@ """Session service.""" from datetime import datetime, timezone from gatehouse_app.models.user.session import Session -from gatehouse_app.utils.constants import SessionStatus +from gatehouse_app.utils.constants import SessionStatus, SessionType class SessionService: @@ -28,18 +28,22 @@ class SessionService: ).first() @staticmethod - def get_user_sessions(user_id, active_only=True): - """ - Get all sessions for a user. + def get_owner_sessions(owner_type, owner_id, active_only=True): + """Get all sessions for an owner (user or superadmin). Args: - user_id: User ID + owner_type: SessionType.USER or SessionType.SUPERADMIN + owner_id: Owner ID active_only: If True, only return active sessions Returns: List of Session instances """ - query = Session.query.filter_by(user_id=user_id, deleted_at=None) + query = Session.query.filter_by( + owner_type=owner_type, + owner_id=owner_id, + deleted_at=None, + ) if active_only: query = query.filter_by(status=SessionStatus.ACTIVE).filter( @@ -49,18 +53,67 @@ class SessionService: return query.all() @staticmethod - def revoke_user_sessions(user_id, reason="User logged out from all devices"): + def get_user_sessions(user_id, active_only=True): + """Get all sessions for a user. + + Args: + user_id: User ID + active_only: If True, only return active sessions + + Returns: + List of Session instances """ - Revoke all active sessions for a user. + return SessionService.get_owner_sessions( + SessionType.USER, user_id, active_only=active_only + ) + + @staticmethod + def get_superadmin_sessions(superadmin_id, active_only=True): + """Get all sessions for a superadmin. + + Args: + superadmin_id: Superadmin ID + active_only: If True, only return active sessions + + Returns: + List of Session instances + """ + return SessionService.get_owner_sessions( + SessionType.SUPERADMIN, superadmin_id, active_only=active_only + ) + + @staticmethod + def revoke_owner_sessions(owner_type, owner_id, reason="Logged out from all devices"): + """Revoke all active sessions for an owner. + + Args: + owner_type: SessionType.USER or SessionType.SUPERADMIN + owner_id: Owner ID + reason: Reason for revocation + """ + sessions = SessionService.get_owner_sessions(owner_type, owner_id, active_only=True) + for session in sessions: + session.revoke(reason=reason) + + @staticmethod + def revoke_user_sessions(user_id, reason="User logged out from all devices"): + """Revoke all active sessions for a user. Args: user_id: User ID reason: Reason for revocation """ - sessions = SessionService.get_user_sessions(user_id, active_only=True) + SessionService.revoke_owner_sessions(SessionType.USER, user_id, reason=reason) - for session in sessions: - session.revoke(reason=reason) + @staticmethod + def revoke_superadmin_sessions(superadmin_id, reason="Superadmin logged out"): + """Revoke all active sessions for a superadmin. + + Args: + superadmin_id: Superadmin ID + reason: Reason for revocation + """ + SessionService.revoke_owner_sessions(SessionType.SUPERADMIN, superadmin_id, reason=reason) @staticmethod def cleanup_expired_sessions(): diff --git a/gatehouse_app/services/superadmin_auth_service.py b/gatehouse_app/services/superadmin_auth_service.py index 31e798b..5d91ebb 100644 --- a/gatehouse_app/services/superadmin_auth_service.py +++ b/gatehouse_app/services/superadmin_auth_service.py @@ -6,7 +6,9 @@ from typing import Optional from flask import request, current_app from gatehouse_app.extensions import db, bcrypt -from gatehouse_app.models.superadmin import Superadmin, SuperadminSession +from gatehouse_app.models.superadmin import Superadmin +from gatehouse_app.models.user.session import Session +from gatehouse_app.utils.constants import SessionType from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError @@ -70,15 +72,17 @@ class SuperadminAuthService: duration_seconds: Session duration in seconds (default 8 hours) Returns: - SuperadminSession instance + Session instance """ # Generate secure token token = secrets.token_urlsafe(32) - # Create session - session = SuperadminSession( - superadmin_id=superadmin_id, + # Create session using unified model + session = Session( + owner_type=SessionType.SUPERADMIN, + owner_id=superadmin_id, token=token, + status="active", expires_at=datetime.now(timezone.utc) + timedelta(seconds=duration_seconds), last_activity_at=datetime.now(timezone.utc), ip_address=request.remote_addr, @@ -97,7 +101,9 @@ class SuperadminAuthService: session_id: Session ID to revoke reason: Optional revocation reason """ - session = SuperadminSession.query.get(session_id) + session = Session.query.filter_by( + id=session_id, owner_type=SessionType.SUPERADMIN + ).first() if session: session.revoke(reason=reason) logger.info(f"[SuperadminAuth] Session {session_id} revoked: {reason or 'No reason'}") @@ -111,9 +117,11 @@ class SuperadminAuthService: except_token: Optional token to keep (current session) reason: Optional revocation reason """ - query = SuperadminSession.query.filter_by(superadmin_id=superadmin_id) + query = Session.query.filter_by( + owner_type=SessionType.SUPERADMIN, owner_id=superadmin_id + ) if except_token: - query = query.filter(SuperadminSession.token != except_token) + query = query.filter(Session.token != except_token) sessions = query.all() for session in sessions: diff --git a/gatehouse_app/utils/constants.py b/gatehouse_app/utils/constants.py index e31bcab..c489b1d 100644 --- a/gatehouse_app/utils/constants.py +++ b/gatehouse_app/utils/constants.py @@ -52,6 +52,13 @@ class SessionStatus(str, Enum): REVOKED = "revoked" +class SessionType(str, Enum): + """Session owner type discriminator.""" + + USER = "user" + SUPERADMIN = "superadmin" + + class AuditAction(str, Enum): """Audit log action types.""" diff --git a/migrations/versions/c8d2e4f6a1b3_consolidate_sessions.py b/migrations/versions/c8d2e4f6a1b3_consolidate_sessions.py new file mode 100644 index 0000000..4a0af90 --- /dev/null +++ b/migrations/versions/c8d2e4f6a1b3_consolidate_sessions.py @@ -0,0 +1,122 @@ +"""Consolidate user and superadmin sessions into unified sessions table. + +Revision ID: c8d2e4f6a1b3 +Revises: b7e3f1a92c4d +Create Date: 2026-04-28 00:00:00.000000 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'c8d2e4f6a1b3' +down_revision = 'b7e3f1a92c4d' +branch_labels = None +depends_on = None + + +def upgrade(): + # 1. Add new columns (nullable initially for data migration) + op.add_column('sessions', sa.Column('owner_type', sa.String(20), nullable=True)) + op.add_column('sessions', sa.Column('owner_id', sa.String(36), nullable=True)) + + # 2. Backfill existing user sessions: owner_type = 'user', owner_id = user_id + op.execute(""" + UPDATE sessions + SET owner_type = 'user', + owner_id = user_id + WHERE owner_type IS NULL + """) + + # 3. Migrate superadmin sessions into the sessions table + op.execute(""" + INSERT INTO sessions ( + id, owner_type, owner_id, token, status, + ip_address, user_agent, device_info, + expires_at, last_activity_at, revoked_at, revoked_reason, + is_compliance_only, created_at, updated_at, deleted_at + ) + SELECT + id, 'superadmin', superadmin_id, token, 'active', + ip_address, user_agent, NULL, + expires_at, last_activity_at, revoked_at, revoked_reason, + FALSE, created_at, updated_at, deleted_at + FROM superadmin_sessions + """) + + # 4. Make owner_type and owner_id NOT NULL + op.alter_column('sessions', 'owner_type', nullable=False) + op.alter_column('sessions', 'owner_id', nullable=False) + + # 5. Make user_id nullable (no longer the sole owner reference) + op.alter_column('sessions', 'user_id', nullable=True) + + # 6. Create indexes for efficient owner-scoped queries + op.create_index( + 'ix_sessions_owner_type_owner_id', + 'sessions', + ['owner_type', 'owner_id'] + ) + op.create_index( + 'ix_sessions_owner_type', + 'sessions', + ['owner_type'] + ) + op.create_index( + 'ix_sessions_owner_id', + 'sessions', + ['owner_id'] + ) + + # 7. Drop the now-redundant superadmin_sessions table + op.drop_table('superadmin_sessions') + + +def downgrade(): + # 1. Recreate superadmin_sessions table + op.create_table( + 'superadmin_sessions', + sa.Column('id', sa.String(36), primary_key=True), + sa.Column('superadmin_id', sa.String(36), sa.ForeignKey('superadmins.id'), nullable=False, index=True), + sa.Column('token', sa.String(255), unique=True, nullable=False, index=True), + sa.Column('expires_at', sa.DateTime, nullable=False), + sa.Column('last_activity_at', sa.DateTime, nullable=False), + sa.Column('ip_address', sa.String(45), nullable=True), + sa.Column('user_agent', sa.Text, nullable=True), + sa.Column('revoked_at', sa.DateTime, nullable=True), + sa.Column('revoked_reason', sa.String(255), nullable=True), + sa.Column('created_at', sa.DateTime, nullable=False), + sa.Column('updated_at', sa.DateTime, nullable=False), + sa.Column('deleted_at', sa.DateTime, nullable=True), + ) + + # 2. Move superadmin sessions back to superadmin_sessions + op.execute(""" + INSERT INTO superadmin_sessions ( + id, superadmin_id, token, expires_at, last_activity_at, + ip_address, user_agent, revoked_at, revoked_reason, + created_at, updated_at, deleted_at + ) + SELECT + id, owner_id, token, expires_at, last_activity_at, + ip_address, user_agent, revoked_at, revoked_reason, + created_at, updated_at, deleted_at + FROM sessions + WHERE owner_type = 'superadmin' + """) + + # 3. Remove superadmin sessions from sessions table + op.execute("DELETE FROM sessions WHERE owner_type = 'superadmin'") + + # 4. Drop indexes + op.drop_index('ix_sessions_owner_id', table_name='sessions') + op.drop_index('ix_sessions_owner_type', table_name='sessions') + op.drop_index('ix_sessions_owner_type_owner_id', table_name='sessions') + + # 5. Remove new columns + op.drop_column('sessions', 'owner_id') + op.drop_column('sessions', 'owner_type') + + # 6. Make user_id NOT NULL again + op.alter_column('sessions', 'user_id', nullable=False) diff --git a/tests/integration/test_superadmin_session_timeouts.py b/tests/integration/test_superadmin_session_timeouts.py new file mode 100644 index 0000000..21410ae --- /dev/null +++ b/tests/integration/test_superadmin_session_timeouts.py @@ -0,0 +1,186 @@ +"""Superadmin session timeout integration tests. + +Validates the absolute-only timeout policy for superadmin sessions. +Superadmin sessions do NOT have idle timeout — only absolute timeout. +""" +import pytest +import uuid +from datetime import datetime, timedelta, timezone + +from tests.integration.client.base import ApiError + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def assert_success(response: dict, message_contains: str = "") -> dict: + """Assert that an api_response-wrapped payload succeeded.""" + data = response.get("data", {}) + assert response.get("success") is not False, ( + f"Expected success but got error: {response.get('message')}" + ) + if message_contains: + assert message_contains.lower() in response.get("message", "").lower(), ( + f"Expected message to contain '{message_contains}' but got: {response.get('message')}" + ) + return data + + +def _get_session_row(integration_app, token: str): + """Look up the Session model row for a given bearer token.""" + from gatehouse_app.models.user.session import Session + with integration_app.app_context(): + return Session.query.filter_by(token=token).first() + + +def _touch_session(integration_app, session_id: str, **updates): + """Directly update columns on a Session row. + + Only use this to simulate the passage of time — never to assert + internal state. + """ + from gatehouse_app.models.user.session import Session + with integration_app.app_context(): + sess = Session.query.get(session_id) + for attr, value in updates.items(): + setattr(sess, attr, value) + from gatehouse_app import db + db.session.commit() + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def superadmin_credentials(integration_app): + """Create a superadmin and return login credentials.""" + from gatehouse_app.services.superadmin_auth_service import SuperadminAuthService + + email = f"admin_{uuid.uuid4().hex[:8]}@gatehouse.local" + password = "SuperAdmin123!" + + with integration_app.app_context(): + sa = SuperadminAuthService.create_superadmin( + email=email, + credential=password, + full_name="Test Superadmin", + ) + return {"id": str(sa.id), "email": email, "password": password} + + +@pytest.fixture +def logged_in_superadmin(integration_client, superadmin_credentials, integration_app): + """Log in as superadmin and return session metadata. + + Returns dict with ``superadmin``, ``token``, ``session_id``, ``session_row``. + """ + creds = superadmin_credentials + resp = integration_client.post( + "/api/v1/superadmin/auth/login", + data={"email": creds["email"], "password": creds["password"]}, + ) + data = assert_success(resp) + token = data["token"] + + session_row = _get_session_row(integration_app, token) + assert session_row is not None, "Session row should exist after superadmin login" + + return { + "superadmin": creds, + "token": token, + "session_id": session_row.id, + "session_row": session_row, + } + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestSuperadminSessionTimeouts: + """Absolute-only timeout behavior for superadmin sessions.""" + + def test_superadmin_session_valid_before_timeout( + self, integration_client, logged_in_superadmin, + ): + """SA-SESS-01 — Fresh superadmin session is accepted.""" + integration_client.set_token(logged_in_superadmin["token"]) + result = integration_client.get("/api/v1/superadmin/auth/me") + data = assert_success(result) + assert "superadmin" in data + + def test_absolute_timeout_rejects_superadmin( + self, integration_client, logged_in_superadmin, integration_app, + ): + """SA-SESS-02 — Superadmin session rejected after absolute timeout. + + Push ``created_at`` far into the past. The session must be + rejected even though ``last_activity_at`` is fresh. + """ + _touch_session( + integration_app, + logged_in_superadmin["session_id"], + created_at=datetime.now(timezone.utc) - timedelta(days=1), + last_activity_at=datetime.now(timezone.utc), + ) + + integration_client.set_token(logged_in_superadmin["token"]) + with pytest.raises(ApiError) as exc_info: + integration_client.get("/api/v1/superadmin/auth/me") + + assert exc_info.value.status_code == 401 + + def test_idle_timeout_does_NOT_reject_superadmin( + self, integration_client, logged_in_superadmin, integration_app, + ): + """SA-SESS-03 — Superadmin sessions have NO idle timeout. + + Push ``last_activity_at`` far into the past but keep + ``created_at`` recent. The session should still be valid + because superadmin sessions only use absolute timeout. + """ + _touch_session( + integration_app, + logged_in_superadmin["session_id"], + last_activity_at=datetime.now(timezone.utc) - timedelta(hours=1), + ) + + integration_client.set_token(logged_in_superadmin["token"]) + result = integration_client.get("/api/v1/superadmin/auth/me") + data = assert_success(result) + assert "superadmin" in data + + def test_revoked_superadmin_session_rejected( + self, integration_client, logged_in_superadmin, + ): + """SA-SESS-04 — Revoked superadmin session is rejected.""" + integration_client.set_token(logged_in_superadmin["token"]) + + # Logout revokes the session + integration_client.post("/api/v1/superadmin/auth/logout") + integration_client.clear_token() + + # Try using the old token + integration_client.set_token(logged_in_superadmin["token"]) + with pytest.raises(ApiError) as exc_info: + integration_client.get("/api/v1/superadmin/auth/me") + + assert exc_info.value.status_code == 401 + + def test_superadmin_session_has_owner_type( + self, integration_app, logged_in_superadmin, + ): + """SA-SESS-05 — Superadmin session row has owner_type='superadmin'.""" + from gatehouse_app.models.user.session import Session + from gatehouse_app.utils.constants import SessionType + + with integration_app.app_context(): + sess = Session.query.get(logged_in_superadmin["session_id"]) + assert sess is not None + assert sess.owner_type == SessionType.SUPERADMIN + assert sess.owner_id == logged_in_superadmin["superadmin"]["id"]