diff --git a/README.md b/README.md index b917d1c..64fcbd7 100644 --- a/README.md +++ b/README.md @@ -154,6 +154,7 @@ Copy `.env.example` to `.env` and configure: - `POST /api/v1/auth/logout` - Logout - `GET /api/v1/auth/me` - Get current user - `GET /api/v1/auth/sessions` - Get user sessions +- `POST /api/v1/auth/sessions/refresh` - Extend session idle window - `DELETE /api/v1/auth/sessions/:id` - Revoke session ### Users @@ -264,6 +265,52 @@ gunicorn -w 4 -b 0.0.0.0:8000 wsgi:app - Request ID tracking for audit trails +## Session Management + +Sessions are database-backed bearer tokens stored in PostgreSQL. Each session is created at login and validated on every authenticated request via the `login_required` decorator. + +### Sliding Timeout + +Sessions use a **sliding window** model with two independent limits: + +| Timeout | Default | Env Var | Behaviour | +|---------|---------|---------|-----------| +| **Idle** | 15 min | `SESSION_IDLE_TIMEOUT` | Extends automatically on every request. If no request is made within this window the session expires. | +| **Absolute** | 8 h | `SESSION_ABSOLUTE_TIMEOUT` | Hard cap measured from session creation. Activity cannot extend a session beyond this point. | + +Every authenticated request resets the idle clock by calling `Session.refresh()`, which sets `expires_at = now + idle_timeout` — but never past `created_at + absolute_timeout`. This means: + +- An active user stays logged in indefinitely **up to** the absolute cap. +- An idle user is logged out after the idle timeout. +- No session can survive longer than the absolute timeout regardless of activity. + +### Configuration + +Override defaults via environment variables: + +```bash +SESSION_IDLE_TIMEOUT=900 # seconds (15 min) +SESSION_ABSOLUTE_TIMEOUT=28800 # seconds (8 h) +``` + +### Cleanup + +Expired sessions are soft-marked as `EXPIRED` by the `cleanup_sessions` job. Run it periodically via the job runner: + +```bash +python manage.py cleanup_sessions + +# Or via the job runner (Docker): +JOB_NAME=cleanup_sessions JOB_INTERVAL_SECONDS=300 +``` + +### Session Endpoints + +- `GET /api/v1/auth/sessions` — List active sessions for the current user +- `POST /api/v1/auth/sessions/refresh` — Extend the current session's idle window (returns new `expires_at`) +- `DELETE /api/v1/auth/sessions/:id` — Revoke a specific session + + # Boostrap db python manage.py db upgrade diff --git a/config/base.py b/config/base.py index 8735671..666f9fc 100644 --- a/config/base.py +++ b/config/base.py @@ -48,6 +48,10 @@ class BaseConfig: seconds=int(os.getenv("MAX_SESSION_DURATION", "86400")) ) + # Session timeout policy (seconds) + SESSION_IDLE_TIMEOUT = int(os.getenv("SESSION_IDLE_TIMEOUT", "900")) + SESSION_ABSOLUTE_TIMEOUT = int(os.getenv("SESSION_ABSOLUTE_TIMEOUT", "28800")) + # CORS CORS_ORIGINS = os.getenv( "CORS_ORIGINS", diff --git a/gatehouse_app/api/v1/auth/core.py b/gatehouse_app/api/v1/auth/core.py index 42f11c4..4fb0071 100644 --- a/gatehouse_app/api/v1/auth/core.py +++ b/gatehouse_app/api/v1/auth/core.py @@ -116,7 +116,7 @@ def login(): remember_me = data.get("remember_me", False) policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me) - duration = 2592000 if remember_me else 86400 + duration = current_app.config.get("SESSION_ABSOLUTE_TIMEOUT", 28800) if remember_me else None is_compliance_only = policy_result.create_compliance_only_session user_session = AuthService.create_session(user, duration_seconds=duration, is_compliance_only=is_compliance_only) @@ -227,6 +227,32 @@ def revoke_session(session_id): return api_response(message="Session revoked successfully") +@api_v1_bp.route("/auth/sessions/refresh", methods=["POST"]) +@login_required +def refresh_session(): + """Extend the current session's idle window. + + The server already refreshes the session on every authenticated + request, but this endpoint exists so the frontend can proactively + keep a session alive (e.g. a heartbeat while the user is reading + a long page with no API calls). + + Returns the new ``expires_at`` so the frontend can display a + countdown or warning before the absolute cap. + """ + session = g.current_session + session.refresh() + + return api_response( + data={ + "expires_at": session.expires_at.isoformat() + "Z" + if session.expires_at.isoformat()[-1] != "Z" + else session.expires_at.isoformat(), + }, + message="Session refreshed", + ) + + @api_v1_bp.route("/auth/token", methods=["GET"]) @login_required def get_token(): @@ -246,7 +272,8 @@ def get_token(): parsed_redirect = urlparse(redirect_url) redirect_origin = f"{parsed_redirect.scheme}://{parsed_redirect.netloc}" - if redirect_origin not in allowed_origins: + wildcard = "*" in allowed_origins + if not wildcard and redirect_origin not in allowed_origins: return api_response(success=False, message="Redirect URL is not allowed.", status=400, error_type="INVALID_REDIRECT") sep = "&" if "?" in redirect_url else "?" diff --git a/gatehouse_app/api/v1/external_auth/oauth.py b/gatehouse_app/api/v1/external_auth/oauth.py index 421b13b..4fa4988 100644 --- a/gatehouse_app/api/v1/external_auth/oauth.py +++ b/gatehouse_app/api/v1/external_auth/oauth.py @@ -190,8 +190,8 @@ def select_organization(): if not member: return api_response(success=False, message="You are not a member of this organization", status=403, error_type="FORBIDDEN") - from gatehouse_app.services.session_service import SessionService - session = SessionService.create_session(user=user, organization_id=organization_id) + from gatehouse_app.services.auth_service import AuthService + session = AuthService.create_session(user=user) state_record.mark_used() provider_type_val = state_record.provider_type.value if isinstance(state_record.provider_type, _AuthMethodType) else state_record.provider_type diff --git a/gatehouse_app/middleware/cors.py b/gatehouse_app/middleware/cors.py index defe68c..797d026 100644 --- a/gatehouse_app/middleware/cors.py +++ b/gatehouse_app/middleware/cors.py @@ -1,6 +1,44 @@ """CORS middleware configuration.""" from flask import request, make_response +ALLOWED_METHODS = "GET, POST, PUT, PATCH, DELETE, OPTIONS" +ALLOWED_HEADERS = ( + "Content-Type, Authorization, X-Requested-With, X-Request-ID, " + "Cache-Control, Pragma, X-WebAuthn-Session-Token" +) + + +def _is_origin_allowed(origin, cors_origins): + """Return True if the origin is permitted by the CORS config. + + Handles both wildcard ("*") and explicit origin lists. + """ + if not origin: + return False + if cors_origins == "*": + return True + if isinstance(cors_origins, list): + if "*" in cors_origins: + return True + return origin in cors_origins + return False + + +def _cors_origin_header(cors_origins, request_origin): + """Return the value for Access-Control-Allow-Origin. + + Per the CORS spec, browsers reject ``*`` when credentials are involved, + so we echo the request origin when wildcard + credentials is configured. + """ + allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins) + if allow_all and request_origin: + return request_origin + if allow_all: + return "*" + if request_origin and request_origin in cors_origins: + return request_origin + return None + def setup_cors(app): """ @@ -9,6 +47,7 @@ def setup_cors(app): Args: app: Flask application instance """ + supports_credentials = app.config.get("CORS_SUPPORTS_CREDENTIALS", True) @app.before_request def handle_preflight(): @@ -16,49 +55,33 @@ def setup_cors(app): if request.method == "OPTIONS": origin = request.headers.get("Origin") cors_origins = app.config.get("CORS_ORIGINS", []) - - # Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard) - allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins) - - if allow_all: - response = make_response("", 204) - response.headers["Access-Control-Allow-Origin"] = "*" - response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS" - response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma" - response.headers["Access-Control-Max-Age"] = "3600" - response.headers["Cache-Control"] = "no-cache, no-store" - return response - elif origin and origin in cors_origins: - response = make_response("", 204) - response.headers["Access-Control-Allow-Origin"] = origin - response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS" - response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma, X-WebAuthn-Session-Token" + + if not _is_origin_allowed(origin, cors_origins): + return None + + response = make_response("", 204) + response.headers["Access-Control-Allow-Origin"] = _cors_origin_header(cors_origins, origin) + response.headers["Access-Control-Allow-Methods"] = ALLOWED_METHODS + response.headers["Access-Control-Allow-Headers"] = ALLOWED_HEADERS + if supports_credentials: response.headers["Access-Control-Allow-Credentials"] = "true" - response.headers["Access-Control-Max-Age"] = "3600" - response.headers["Cache-Control"] = "no-cache, no-store" - return response + response.headers["Access-Control-Max-Age"] = "3600" + response.headers["Cache-Control"] = "no-cache, no-store" + return response @app.after_request def after_request_cors(response): - """Add additional CORS headers if needed.""" + """Add CORS headers to non-preflight responses.""" origin = request.headers.get("Origin") cors_origins = app.config.get("CORS_ORIGINS", []) - # Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard) - allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins) - - if allow_all: - # When allowing all origins, set header to "*" - response.headers["Access-Control-Allow-Origin"] = "*" - response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS" - response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma" - response.headers["Access-Control-Max-Age"] = "3600" - elif origin and origin in cors_origins: - # When allowing specific origins, echo the request origin - response.headers["Access-Control-Allow-Origin"] = origin - response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS" - response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma, X-WebAuthn-Session-Token" - response.headers["Access-Control-Allow-Credentials"] = "true" + allow_origin = _cors_origin_header(cors_origins, origin) + if allow_origin: + response.headers["Access-Control-Allow-Origin"] = allow_origin + response.headers["Access-Control-Allow-Methods"] = ALLOWED_METHODS + response.headers["Access-Control-Allow-Headers"] = ALLOWED_HEADERS + if supports_credentials: + response.headers["Access-Control-Allow-Credentials"] = "true" response.headers["Access-Control-Max-Age"] = "3600" return response diff --git a/gatehouse_app/models/user/session.py b/gatehouse_app/models/user/session.py index 9a78830..e6300dd 100644 --- a/gatehouse_app/models/user/session.py +++ b/gatehouse_app/models/user/session.py @@ -1,5 +1,6 @@ """Session model.""" from datetime import datetime, timedelta, timezone +from flask import current_app from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel from gatehouse_app.utils.constants import SessionStatus @@ -38,33 +39,71 @@ class Session(BaseModel): return f"" def is_active(self): - """Check if session is currently active.""" + """Check if session is currently active. + + Sessions are evaluated against two independent timeouts: + - Idle timeout: expires if no request has been made within + SESSION_IDLE_TIMEOUT seconds (default 15 min). + - Absolute timeout: expires if SESSION_ABSOLUTE_TIMEOUT seconds + have elapsed since the session was created (default 8 h), + regardless of activity. + + A session must satisfy *both* constraints to remain active. + """ now = datetime.now(timezone.utc) - expires_at = self.expires_at - if expires_at.tzinfo is None: - expires_at = expires_at.replace(tzinfo=timezone.utc) + created_at = self.created_at + last_activity_at = self.last_activity_at + + if created_at.tzinfo is None: + created_at = created_at.replace(tzinfo=timezone.utc) + if last_activity_at.tzinfo is None: + last_activity_at = last_activity_at.replace(tzinfo=timezone.utc) + + idle_timeout = current_app.config.get("SESSION_IDLE_TIMEOUT", 900) + absolute_timeout = current_app.config.get("SESSION_ABSOLUTE_TIMEOUT", 28800) + + idle_expires_at = last_activity_at + timedelta(seconds=idle_timeout) + absolute_expires_at = created_at + timedelta(seconds=absolute_timeout) + return ( self.status == SessionStatus.ACTIVE - and expires_at > now + and now < idle_expires_at + and now < absolute_expires_at and self.deleted_at is None ) def is_expired(self): - """Check if session has expired.""" - now = datetime.now(timezone.utc) - expires_at = self.expires_at - if expires_at.tzinfo is None: - expires_at = expires_at.replace(tzinfo=timezone.utc) - return now > expires_at + """Check if session has expired (either idle or absolute).""" + return not self.is_active() and self.status != SessionStatus.REVOKED - def refresh(self, duration_seconds: int = 86400): - """Refresh session expiration. + def refresh(self, duration_seconds: int = None): + """Extend the session expiration using a sliding window. + + The new ``expires_at`` is set to *now + idle timeout*, but is + capped so that the session never exceeds the absolute lifetime + (``created_at + absolute timeout``). Args: - duration_seconds: New session duration in seconds + duration_seconds: Override for the idle timeout. When *None* + (the common case), the value is read from + ``SESSION_IDLE_TIMEOUT`` in the Flask config. """ - self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=duration_seconds) - self.last_activity_at = datetime.now(timezone.utc) + now = datetime.now(timezone.utc) + + if duration_seconds is None: + duration_seconds = current_app.config.get("SESSION_IDLE_TIMEOUT", 900) + + absolute_timeout = current_app.config.get("SESSION_ABSOLUTE_TIMEOUT", 28800) + + idle_expires_at = now + timedelta(seconds=duration_seconds) + + created_at = self.created_at + if created_at.tzinfo is None: + created_at = created_at.replace(tzinfo=timezone.utc) + absolute_expires_at = created_at + timedelta(seconds=absolute_timeout) + + self.expires_at = min(idle_expires_at, absolute_expires_at) + self.last_activity_at = now db.session.commit() def revoke(self, reason: str = None): diff --git a/gatehouse_app/services/auth_service.py b/gatehouse_app/services/auth_service.py index 1eb48c0..c662ba8 100644 --- a/gatehouse_app/services/auth_service.py +++ b/gatehouse_app/services/auth_service.py @@ -140,18 +140,26 @@ class AuthService: return user @staticmethod - def create_session(user, duration_seconds=86400, is_compliance_only=False): + def create_session(user, duration_seconds=None, is_compliance_only=False): """ Create a new session for the user. Args: user: User instance - duration_seconds: Session duration in seconds + duration_seconds: Session idle timeout in seconds. + When None, defaults to SESSION_IDLE_TIMEOUT from config. + The absolute lifetime is always enforced by Session.is_active() + regardless of this value. is_compliance_only: Whether this is a compliance-only session (limited access) Returns: Session instance """ + from flask import current_app + + if duration_seconds is None: + duration_seconds = current_app.config.get("SESSION_IDLE_TIMEOUT", 900) + # Generate session token token = secrets.token_urlsafe(32) diff --git a/gatehouse_app/services/external_auth/linking.py b/gatehouse_app/services/external_auth/linking.py index 670220b..02473c0 100644 --- a/gatehouse_app/services/external_auth/linking.py +++ b/gatehouse_app/services/external_auth/linking.py @@ -263,7 +263,7 @@ def authenticate_with_provider( state_record.mark_used() from gatehouse_app.services.auth_service import AuthService - session = AuthService.create_session(user=user, organization_id=organization_id) + session = AuthService.create_session(user=user) AuditService.log_external_auth_login( user_id=user.id, diff --git a/gatehouse_app/services/session_service.py b/gatehouse_app/services/session_service.py index 7103285..e86cd6f 100644 --- a/gatehouse_app/services/session_service.py +++ b/gatehouse_app/services/session_service.py @@ -10,10 +10,10 @@ class SessionService: @staticmethod def get_active_session_by_token(token): """Get active session by token. - + Args: token: The session token string - + Returns: Session object if found and active, None otherwise """ @@ -23,6 +23,8 @@ class SessionService: token=token, status=SessionStatus.ACTIVE, deleted_at=None + ).filter( + Session.expires_at > datetime.now(timezone.utc) ).first() @staticmethod diff --git a/gatehouse_app/services/superadmin_auth_service.py b/gatehouse_app/services/superadmin_auth_service.py index dde6199..31e798b 100644 --- a/gatehouse_app/services/superadmin_auth_service.py +++ b/gatehouse_app/services/superadmin_auth_service.py @@ -138,7 +138,7 @@ class SuperadminAuthService: Dictionary with emergency session info """ from gatehouse_app.models.user.user import User - from gatehouse_app.services.session_service import SessionService + from gatehouse_app.services.auth_service import AuthService from gatehouse_app.services.audit_service import AuditService # Verify target user exists @@ -147,7 +147,7 @@ class SuperadminAuthService: raise ValueError(f"Target user not found: {target_user_id}") # Create emergency session for the target user - emergency_session = SessionService.create_session( + emergency_session = AuthService.create_session( user=target_user, duration_seconds=duration_minutes * 60, is_compliance_only=False diff --git a/gatehouse_app/utils/decorators.py b/gatehouse_app/utils/decorators.py index 5cbb649..8a90dbd 100644 --- a/gatehouse_app/utils/decorators.py +++ b/gatehouse_app/utils/decorators.py @@ -59,11 +59,9 @@ def login_required(f): error_type="SESSION_INACTIVE" ) - # Update last_activity_at timestamp - from datetime import datetime, timezone - session.last_activity_at = datetime.now(timezone.utc) - from gatehouse_app import db - db.session.commit() + # Extend session via sliding window (updates last_activity_at + # and recalculates expires_at within the idle / absolute caps). + session.refresh() # Set context variables g.current_user = session.user diff --git a/manage.py b/manage.py index 06d3fc7..9975216 100644 --- a/manage.py +++ b/manage.py @@ -153,6 +153,31 @@ def mfa_compliance_status(): print("=" * 60) +@cli.command("cleanup_sessions") +def cleanup_sessions(): + """Clean up expired user sessions. + + Marks sessions as EXPIRED when they have passed their expires_at + timestamp. Safe to run frequently (e.g. every 5 minutes via job_runner). + + Usage: + python manage.py cleanup_sessions + """ + from gatehouse_app.services.session_service import SessionService + + print("=" * 60) + print("Session Cleanup Job") + print("=" * 60) + + from datetime import datetime, timezone + print(f"Start time: {datetime.now(timezone.utc).isoformat()}") + + count = SessionService.cleanup_expired_sessions() + + print(f"Expired sessions marked: {count}") + print("=" * 60) + + @cli.command("configure_oauth") @click.argument("provider", required=False) @click.option("--client-id", default=None, help="OAuth client ID") diff --git a/scripts/job_runner.py b/scripts/job_runner.py index 8aa64c3..549d4de 100755 --- a/scripts/job_runner.py +++ b/scripts/job_runner.py @@ -28,6 +28,7 @@ logger = logging.getLogger(__name__) JOB_COMMANDS = { "zerotier_reconciliation": "python manage.py run_zerotier_reconciliation", "mfa_compliance": "python manage.py run_mfa_compliance_job", + "cleanup_sessions": "python manage.py cleanup_sessions", } shutdown_requested = False diff --git a/tests/integration/client/auth.py b/tests/integration/client/auth.py index 71bc325..181627a 100644 --- a/tests/integration/client/auth.py +++ b/tests/integration/client/auth.py @@ -95,6 +95,10 @@ class AuthClient: """Revoke a specific session belonging to the current user.""" return self._client.delete(f"/auth/sessions/{session_id}") + def refresh_session(self) -> dict: + """Extend the current session's idle window.""" + return self._client.post("/auth/sessions/refresh") + # ------------------------------------------------------------------ # Password recovery # ------------------------------------------------------------------ diff --git a/tests/integration/test_session_timeouts.py b/tests/integration/test_session_timeouts.py new file mode 100644 index 0000000..be4dcf3 --- /dev/null +++ b/tests/integration/test_session_timeouts.py @@ -0,0 +1,213 @@ +"""Session timeout integration tests. + +Validates the sliding-window session timeout policy: idle timeout, +absolute timeout, and the interaction between activity and expiration. +Every test exercises the *public API* — the only internal manipulation +is back-dating timestamps in the database, since we cannot wait minutes +inside a test run. +""" +import pytest +import uuid +from datetime import datetime, timedelta, timezone + +from tests.integration.client.base import ApiError + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def assert_success(response: dict, message_contains: str = "") -> dict: + """Assert that an api_response-wrapped payload succeeded.""" + data = response.get("data", {}) + assert response.get("success") is not False, ( + f"Expected success but got error: {response.get('message')}" + ) + if message_contains: + assert message_contains.lower() in response.get("message", "").lower(), ( + f"Expected message to contain '{message_contains}' but got: {response.get('message')}" + ) + return data + + +def _get_session_row(integration_app, token: str): + """Look up the Session model row for a given bearer token.""" + from gatehouse_app.models.user.session import Session + with integration_app.app_context(): + return Session.query.filter_by(token=token).first() + + +def _touch_session(integration_app, session_id: str, **updates): + """Directly update columns on a Session row. + + Only use this to simulate the passage of time — never to assert + internal state. + """ + from gatehouse_app.models.user.session import Session + with integration_app.app_context(): + sess = Session.query.get(session_id) + for attr, value in updates.items(): + setattr(sess, attr, value) + from gatehouse_app import db + db.session.commit() + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def logged_in_session(integration_client, create_test_user, integration_app): + """Register a user, log in via the API, and return the session metadata. + + Returns dict with ``user``, ``token``, ``session_id``, ``session_row``. + The ``session_row`` is a detached SQLAlchemy instance — re-query if + you need fresh DB state. + """ + user = create_test_user(password="TestPass123!") + integration_client.auth.login( + email=user["email"], password="TestPass123!", + ) + token = integration_client._token + + session_row = _get_session_row(integration_app, token) + assert session_row is not None, "Session row should exist after login" + + return { + "user": user, + "token": token, + "session_id": session_row.id, + "session_row": session_row, + } + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestSessionTimeouts: + """Sliding-window timeout behavior exercised through the public API.""" + + def test_session_valid_before_timeout( + self, integration_client, create_test_user, + ): + """SESS-01 — Fresh session is accepted. + + A session that was just created should pass all auth checks. + This is the baseline: if this fails, every other timeout test + is meaningless. + """ + user = create_test_user(password="MyPass123!") + integration_client.auth.login(email=user["email"], password="MyPass123!") + + result = integration_client.auth.me() + data = assert_success(result) + assert data["user"]["email"] == user["email"] + + def test_idle_timeout_rejects_token( + self, integration_client, logged_in_session, integration_app, + ): + """SESS-02 — Session rejected after idle period elapses. + + Push ``last_activity_at`` far enough into the past that the + idle window has closed. The API must return 401. + """ + _touch_session( + integration_app, + logged_in_session["session_id"], + last_activity_at=datetime.now(timezone.utc) - timedelta(hours=1), + ) + + with pytest.raises(ApiError) as exc_info: + integration_client.auth.me() + + assert exc_info.value.status_code == 401 + + def test_absolute_timeout_rejects_even_active_user( + self, integration_client, logged_in_session, integration_app, + ): + """SESS-03 — Absolute cap overrides recent activity. + + Push ``created_at`` into the past so the absolute window has + elapsed, but keep ``last_activity_at`` fresh. The session + must still be rejected — activity cannot extend past the + absolute limit. + """ + _touch_session( + integration_app, + logged_in_session["session_id"], + created_at=datetime.now(timezone.utc) - timedelta(days=1), + last_activity_at=datetime.now(timezone.utc), + ) + + with pytest.raises(ApiError) as exc_info: + integration_client.auth.me() + + assert exc_info.value.status_code == 401 + + def test_api_request_keeps_session_alive( + self, integration_client, logged_in_session, integration_app, + ): + """SESS-04 — Hitting an API endpoint extends the session. + + Back-date ``last_activity_at`` to *just* inside the idle + window. A subsequent API call should succeed and the session + should remain usable — the sliding window should have reset. + """ + from gatehouse_app.models.user.session import Session + from gatehouse_app import db + + # Back-date to 10 seconds ago — still inside the idle window. + _touch_session( + integration_app, + logged_in_session["session_id"], + last_activity_at=datetime.now(timezone.utc) - timedelta(seconds=10), + ) + + # This request should succeed AND extend the session. + result = integration_client.auth.me() + assert_success(result) + + # After the request, last_activity_at should be much closer to now. + with integration_app.app_context(): + refreshed = Session.query.get(logged_in_session["session_id"]) + now = datetime.now(timezone.utc) + # Allow for clock skew / commit latency — should be within 30s. + diff = abs((now - refreshed.last_activity_at.replace(tzinfo=timezone.utc)).total_seconds()) + assert diff < 30, ( + f"last_activity_at should be near-now after API call, " + f"but delta is {diff:.1f}s" + ) + + def test_revoked_session_rejected( + self, integration_client, logged_in_session, + ): + """SESS-05 — Revoked session is rejected regardless of timing. + + Revoke via the API, then verify the token is dead. This + mirrors AUTH-12 but is included here so the timeout test + suite is self-contained. + """ + integration_client.auth.revoke_session(logged_in_session["session_id"]) + + with pytest.raises(ApiError) as exc_info: + integration_client.auth.me() + + assert exc_info.value.status_code == 401 + + def test_refresh_endpoint_extends_session( + self, integration_client, logged_in_session, integration_app, + ): + """SESS-06 — POST /auth/sessions/refresh extends the session. + + The refresh endpoint exists so the frontend can proactively + keep a session alive during idle UI periods. Verify it + succeeds and returns a new ``expires_at``. + """ + result = integration_client.auth.refresh_session() + data = assert_success(result, "session refreshed") + + assert "expires_at" in data, "Response should include new expires_at" diff --git a/tests/unit/test_ca_key_encryption.py b/tests/unit/test_ca_key_encryption.py new file mode 100644 index 0000000..3f19353 --- /dev/null +++ b/tests/unit/test_ca_key_encryption.py @@ -0,0 +1,205 @@ +"""Unit tests for ca_key_encryption module. + +WHAT: Tests for the Fernet-based CA private key encryption/decryption + utility functions. +WHY: CA private keys are the most sensitive data in the system; we need + to verify round-trip correctness, idempotency, and error handling. +EXPECTED: All encrypt/decrypt operations produce correct plaintext. +""" +import os +import threading +from unittest.mock import patch + +import pytest + +from gatehouse_app.utils.ca_key_encryption import ( + CAKeyEncryptionError, + _FERNET_PREFIX, + _get_fernet, + decrypt_ca_key, + encrypt_ca_key, + is_encrypted, + reencrypt_ca_key, +) + + +# --------------------------------------------------------------------------- +# Shared fixture +# --------------------------------------------------------------------------- + +SAMPLE_PEM = ( + "-----BEGIN OPENSSH PRIVATE KEY-----\n" + "b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtz\n" + "c2gtZWQyNTUxOQAAACBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBAAA\n" + "-----END OPENSSH PRIVATE KEY-----" +) + + +@pytest.fixture(autouse=True) +def _set_ca_encryption_key(): + """Ensure CA_ENCRYPTION_KEY is set for every test.""" + with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": "test-secret-key-for-unit-tests"}): + yield + + +# --------------------------------------------------------------------------- +# encrypt / decrypt round-trip +# --------------------------------------------------------------------------- + +class TestEncryptDecryptRoundTrip: + """Verify that encrypt -> decrypt returns the original plaintext.""" + + def test_basic_round_trip(self): + """TEST: ENC-RT-01 -- Encrypt then decrypt returns original PEM.""" + encrypted = encrypt_ca_key(SAMPLE_PEM) + decrypted = decrypt_ca_key(encrypted) + assert decrypted == SAMPLE_PEM + + def test_encrypted_value_has_prefix(self): + """TEST: ENC-RT-02 -- Encrypted output carries the $fernet$ envelope.""" + encrypted = encrypt_ca_key(SAMPLE_PEM) + assert encrypted.startswith(_FERNET_PREFIX) + + def test_different_ciphertext_each_time(self): + """TEST: ENC-RT-03 -- Two encryptions of the same plaintext differ.""" + enc1 = encrypt_ca_key(SAMPLE_PEM) + enc2 = encrypt_ca_key(SAMPLE_PEM) + assert enc1 != enc2 + assert decrypt_ca_key(enc1) == SAMPLE_PEM + assert decrypt_ca_key(enc2) == SAMPLE_PEM + + +# --------------------------------------------------------------------------- +# Idempotency +# --------------------------------------------------------------------------- + +class TestIdempotency: + """The module must not double-encrypt or double-decrypt.""" + + def test_encrypt_idempotent(self): + """TEST: ENC-IDEM-01 -- Encrypting an already-encrypted value is a no-op.""" + encrypted = encrypt_ca_key(SAMPLE_PEM) + double = encrypt_ca_key(encrypted) + assert double == encrypted + + def test_decrypt_plaintext_passthrough(self): + """TEST: ENC-IDEM-02 -- Decrypting a plaintext (legacy) value returns it as-is.""" + result = decrypt_ca_key(SAMPLE_PEM) + assert result == SAMPLE_PEM + + +# --------------------------------------------------------------------------- +# is_encrypted helper +# --------------------------------------------------------------------------- + +class TestIsEncrypted: + def test_encrypted_value(self): + """TEST: ENC-IE-01 -- is_encrypted returns True for $fernet$ values.""" + encrypted = encrypt_ca_key(SAMPLE_PEM) + assert is_encrypted(encrypted) is True + + def test_plaintext_value(self): + """TEST: ENC-IE-02 -- is_encrypted returns False for plain PEM.""" + assert is_encrypted(SAMPLE_PEM) is False + + def test_empty_string(self): + """TEST: ENC-IE-03 -- is_encrypted returns False for empty string.""" + assert is_encrypted("") is False + + def test_none_value(self): + """TEST: ENC-IE-04 -- is_encrypted returns False for None.""" + assert is_encrypted(None) is False + + +# --------------------------------------------------------------------------- +# Error handling +# --------------------------------------------------------------------------- + +class TestErrorHandling: + def test_encrypt_empty_raises(self): + """TEST: ENC-ERR-01 -- Encrypting empty string raises CAKeyEncryptionError.""" + with pytest.raises(CAKeyEncryptionError, match="empty"): + encrypt_ca_key("") + + def test_decrypt_empty_raises(self): + """TEST: ENC-ERR-02 -- Decrypting empty string raises CAKeyEncryptionError.""" + with pytest.raises(CAKeyEncryptionError, match="empty"): + decrypt_ca_key("") + + def test_missing_key_raises(self): + """TEST: ENC-ERR-03 -- Operations fail when CA_ENCRYPTION_KEY is unset.""" + with patch.dict(os.environ, {}, clear=True): + os.environ.pop("CA_ENCRYPTION_KEY", None) + with pytest.raises(CAKeyEncryptionError, match="not set"): + encrypt_ca_key(SAMPLE_PEM) + + def test_wrong_key_raises_on_decrypt(self): + """TEST: ENC-ERR-04 -- Decrypting with the wrong key raises.""" + encrypted = encrypt_ca_key(SAMPLE_PEM) + with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": "wrong-key"}): + with pytest.raises(CAKeyEncryptionError, match="decryption failed"): + decrypt_ca_key(encrypted) + + def test_corrupted_data_raises(self): + """TEST: ENC-ERR-05 -- Decrypting corrupted ciphertext raises.""" + with pytest.raises(CAKeyEncryptionError): + decrypt_ca_key("$fernet$not-a-real-token") + + +# --------------------------------------------------------------------------- +# reencrypt_ca_key -- key rotation +# --------------------------------------------------------------------------- + +class TestReencrypt: + def test_reencrypt_round_trip(self): + """TEST: ENC-RE-01 -- Re-encrypted value decrypts with the new key.""" + old_key = "old-secret-key" + new_key = "new-secret-key" + encrypted = reencrypt_ca_key(SAMPLE_PEM, "any-old-key", old_key) + reencrypted = reencrypt_ca_key(encrypted, old_key, new_key) + + # Verify it decrypts with the new key + with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": new_key}): + decrypted = decrypt_ca_key(reencrypted) + assert decrypted == SAMPLE_PEM + + def test_reencrypt_plaintext_key(self): + """TEST: ENC-RE-02 -- Re-encrypting a legacy plaintext key works.""" + new_key = "brand-new-key" + reencrypted = reencrypt_ca_key(SAMPLE_PEM, "any-old-key", new_key) + with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": new_key}): + decrypted = decrypt_ca_key(reencrypted) + assert decrypted == SAMPLE_PEM + + +# --------------------------------------------------------------------------- +# Thread safety +# --------------------------------------------------------------------------- + +class TestThreadSafety: + """Concurrent encrypt/decrypt calls must not corrupt state.""" + + def test_concurrent_encrypt_decrypt(self): + """TEST: ENC-TS-01 -- 50 threads encrypting/decrypting concurrently.""" + errors = [] + results = [] + + def worker(i): + try: + data = f"key-data-{i}" + enc = encrypt_ca_key(data) + dec = decrypt_ca_key(enc) + results.append((i, dec)) + except Exception as exc: + errors.append((i, exc)) + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(50)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=10) + + assert not errors, f"Thread errors: {errors}" + assert len(results) == 50 + for i, dec in results: + assert dec == f"key-data-{i}", f"Thread {i}: expected 'key-data-{i}', got {dec!r}" diff --git a/tests/unit/test_cors.py b/tests/unit/test_cors.py new file mode 100644 index 0000000..46a55fe --- /dev/null +++ b/tests/unit/test_cors.py @@ -0,0 +1,125 @@ +"""Unit tests for CORS middleware. + +WHAT: Tests for the CORS middleware configuration, including wildcard + origin handling, credentials support, and preflight responses. +WHY: CORS misconfiguration can break browser clients or leak credentials. +EXPECTED: Correct Access-Control-* headers for all origin configurations. +""" +import pytest +from flask import Flask + +from gatehouse_app.middleware.cors import ( + _is_origin_allowed, + _cors_origin_header, + setup_cors, +) + + +# --------------------------------------------------------------------------- +# _is_origin_allowed +# --------------------------------------------------------------------------- + +class TestIsOriginAllowed: + def test_empty_origin_rejected(self): + """TEST: CORS-01 -- Empty origin is never allowed.""" + assert _is_origin_allowed("", ["https://example.com"]) is False + assert _is_origin_allowed(None, "*") is False + + def test_wildcard_string(self): + """TEST: CORS-02 -- Wildcard string allows any origin.""" + assert _is_origin_allowed("https://evil.com", "*") is True + + def test_wildcard_in_list(self): + """TEST: CORS-03 -- Wildcard in list allows any origin.""" + assert _is_origin_allowed("https://evil.com", ["*", "https://example.com"]) is True + + def test_explicit_origin_match(self): + """TEST: CORS-04 -- Explicit list matches exact origin.""" + origins = ["https://example.com", "http://localhost:3000"] + assert _is_origin_allowed("https://example.com", origins) is True + assert _is_origin_allowed("https://evil.com", origins) is False + + def test_empty_origins_list(self): + """TEST: CORS-05 -- Empty list rejects everything.""" + assert _is_origin_allowed("https://example.com", []) is False + + +# --------------------------------------------------------------------------- +# _cors_origin_header +# --------------------------------------------------------------------------- + +class TestCorsOriginHeader: + def test_wildcard_with_origin_echoes(self): + """TEST: CORS-HDR-01 -- Wildcard echoes request origin (for credentials).""" + assert _cors_origin_header("*", "https://example.com") == "https://example.com" + + def test_wildcard_without_origin(self): + """TEST: CORS-HDR-02 -- Wildcard with no origin returns *.""" + assert _cors_origin_header("*", None) == "*" + + def test_wildcard_in_list_with_origin(self): + """TEST: CORS-HDR-03 -- Wildcard in list echoes request origin.""" + result = _cors_origin_header(["*", "https://example.com"], "https://any.com") + assert result == "https://any.com" + + def test_specific_origin_match(self): + """TEST: CORS-HDR-04 -- Matching origin is echoed.""" + origins = ["https://example.com"] + assert _cors_origin_header(origins, "https://example.com") == "https://example.com" + + def test_specific_origin_no_match(self): + """TEST: CORS-HDR-05 -- Non-matching origin returns None.""" + origins = ["https://example.com"] + assert _cors_origin_header(origins, "https://evil.com") is None + + def test_no_origin_no_match(self): + """TEST: CORS-HDR-06 -- No origin with specific list returns None.""" + origins = ["https://example.com"] + assert _cors_origin_header(origins, None) is None + + +# --------------------------------------------------------------------------- +# Integration: preflight response +# --------------------------------------------------------------------------- + +class TestPreflightIntegration: + @pytest.fixture + def app_wildcard(self): + app = Flask(__name__) + app.config["CORS_ORIGINS"] = "*" + app.config["CORS_SUPPORTS_CREDENTIALS"] = True + setup_cors(app) + app.config["TESTING"] = True + return app + + @pytest.fixture + def app_specific(self): + app = Flask(__name__) + app.config["CORS_ORIGINS"] = ["https://example.com"] + app.config["CORS_SUPPORTS_CREDENTIALS"] = True + setup_cors(app) + app.config["TESTING"] = True + return app + + def test_wildcard_preflight_echoes_origin(self, app_wildcard): + """TEST: CORS-PF-01 -- Wildcard preflight echoes request origin.""" + with app_wildcard.test_client() as client: + resp = client.options("/", headers={"Origin": "https://example.com"}) + assert resp.status_code == 204 + assert resp.headers.get("Access-Control-Allow-Origin") == "https://example.com" + assert resp.headers.get("Access-Control-Allow-Credentials") == "true" + + def test_specific_origin_preflight(self, app_specific): + """TEST: CORS-PF-02 -- Specific origin preflight allows matching origin.""" + with app_specific.test_client() as client: + resp = client.options("/", headers={"Origin": "https://example.com"}) + assert resp.status_code == 204 + assert resp.headers.get("Access-Control-Allow-Origin") == "https://example.com" + assert resp.headers.get("Access-Control-Allow-Credentials") == "true" + + def test_specific_origin_rejects_unknown(self, app_specific): + """TEST: CORS-PF-03 -- Non-matching origin gets no CORS headers.""" + with app_specific.test_client() as client: + resp = client.options("/", headers={"Origin": "https://evil.com"}) + # No preflight handler runs, Flask returns default + assert resp.headers.get("Access-Control-Allow-Origin") is None diff --git a/tests/unit/test_encryption.py b/tests/unit/test_encryption.py new file mode 100644 index 0000000..d54ee19 --- /dev/null +++ b/tests/unit/test_encryption.py @@ -0,0 +1,164 @@ +"""Unit tests for encryption module (general-purpose Fernet encryption). + +WHAT: Tests for the PBKDF2-based Fernet encryption/decryption used for + OAuth tokens and client secrets. +WHY: These utilities protect access tokens and client secrets; we need + to verify round-trip correctness and error handling. +EXPECTED: All encrypt/decrypt operations produce correct plaintext. +""" +import threading + +import pytest + +from gatehouse_app.utils.encryption import ( + SALT_LENGTH, + _get_fernet_key, + decrypt, + encrypt, +) + + +# --------------------------------------------------------------------------- +# Shared fixture +# --------------------------------------------------------------------------- + +SECRET_KEY = "test-encryption-secret-key" +SAMPLE_DATA = "access_token=eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.payload" + + +# --------------------------------------------------------------------------- +# encrypt / decrypt round-trip +# --------------------------------------------------------------------------- + +class TestEncryptDecryptRoundTrip: + """Verify that encrypt -> decrypt returns the original plaintext.""" + + def test_basic_round_trip(self): + """TEST: ENC-RT-01 -- Encrypt then decrypt returns original data.""" + encrypted = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY) + decrypted = decrypt(encrypted, secret_key=SECRET_KEY) + assert decrypted == SAMPLE_DATA + + def test_encrypted_is_base64(self): + """TEST: ENC-RT-02 -- Encrypted output is valid base64.""" + import base64 + encrypted = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY) + # Should not raise + base64.urlsafe_b64decode(encrypted.encode()) + + def test_different_ciphertext_each_time(self): + """TEST: ENC-RT-03 -- Two encryptions of the same plaintext differ.""" + enc1 = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY) + enc2 = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY) + assert enc1 != enc2 + assert decrypt(enc1, secret_key=SECRET_KEY) == SAMPLE_DATA + assert decrypt(enc2, secret_key=SECRET_KEY) == SAMPLE_DATA + + def test_round_trip_unicode(self): + """TEST: ENC-RT-04 -- Unicode data round-trips correctly.""" + data = "token=cafe\u00e9\u00f1\u00fc" + encrypted = encrypt(data, secret_key=SECRET_KEY) + assert decrypt(encrypted, secret_key=SECRET_KEY) == data + + def test_round_trip_long_data(self): + """TEST: ENC-RT-05 -- Large data round-trips correctly.""" + data = "x" * 10000 + encrypted = encrypt(data, secret_key=SECRET_KEY) + assert decrypt(encrypted, secret_key=SECRET_KEY) == data + + +# --------------------------------------------------------------------------- +# Empty / edge inputs +# --------------------------------------------------------------------------- + +class TestEdgeCases: + def test_encrypt_empty_returns_empty(self): + """TEST: ENC-EDGE-01 -- Encrypting empty string returns empty.""" + assert encrypt("", secret_key=SECRET_KEY) == "" + + def test_decrypt_empty_returns_empty(self): + """TEST: ENC-EDGE-02 -- Decrypting empty string returns empty.""" + assert decrypt("", secret_key=SECRET_KEY) == "" + + def test_missing_key_raises_on_encrypt(self): + """TEST: ENC-EDGE-03 -- Missing key raises ValueError on encrypt.""" + with pytest.raises(ValueError, match="Encryption key not configured"): + encrypt("data", secret_key="") + + def test_missing_key_raises_on_decrypt(self): + """TEST: ENC-EDGE-04 -- Missing key raises ValueError on decrypt.""" + with pytest.raises(ValueError, match="Encryption key not configured"): + decrypt("something", secret_key="") + + def test_wrong_key_raises_on_decrypt(self): + """TEST: ENC-EDGE-05 -- Wrong key raises ValueError on decrypt.""" + encrypted = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY) + with pytest.raises(ValueError, match="Failed to decrypt"): + decrypt(encrypted, secret_key="wrong-key") + + def test_corrupted_data_raises(self): + """TEST: ENC-EDGE-06 -- Corrupted ciphertext raises ValueError.""" + import base64 + bad = base64.urlsafe_b64encode(b"not-valid-fernet-data").decode() + with pytest.raises(ValueError): + decrypt(bad, secret_key=SECRET_KEY) + + +# --------------------------------------------------------------------------- +# _get_fernet_key — PBKDF2 derivation +# --------------------------------------------------------------------------- + +class TestKeyDerivation: + def test_same_salt_same_key(self): + """TEST: ENC-KD-01 -- Same salt produces the same derived key.""" + salt = b"\x00" * SALT_LENGTH + key1 = _get_fernet_key(SECRET_KEY, salt=salt) + key2 = _get_fernet_key(SECRET_KEY, salt=salt) + assert key1 == key2 + + def test_different_salt_different_key(self): + """TEST: ENC-KD-02 -- Different salts produce different keys.""" + salt1 = b"\x00" * SALT_LENGTH + salt2 = b"\xff" * SALT_LENGTH + key1 = _get_fernet_key(SECRET_KEY, salt=salt1) + key2 = _get_fernet_key(SECRET_KEY, salt=salt2) + assert key1 != key2 + + def test_auto_salt_length(self): + """TEST: ENC-KD-03 -- Auto-generated salt is 16 bytes.""" + key = _get_fernet_key(SECRET_KEY) + # If it didn't raise, the salt was valid + assert len(key) > 0 + + +# --------------------------------------------------------------------------- +# Thread safety +# --------------------------------------------------------------------------- + +class TestThreadSafety: + """Concurrent encrypt/decrypt calls must not corrupt state.""" + + def test_concurrent_encrypt_decrypt(self): + """TEST: ENC-TS-01 -- 50 threads encrypting/decrypting concurrently.""" + errors = [] + results = [] + + def worker(i): + try: + data = f"token-{i}-secret" + enc = encrypt(data, secret_key=SECRET_KEY) + dec = decrypt(enc, secret_key=SECRET_KEY) + results.append((i, dec)) + except Exception as exc: + errors.append((i, exc)) + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(50)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=10) + + assert not errors, f"Thread errors: {errors}" + assert len(results) == 50 + for i, dec in results: + assert dec == f"token-{i}-secret", f"Thread {i}: mismatch"