Merge pull request #34 from CoryHawkless/cory-wip-session
fix(cors): handle wildcard origin with credentials and add unit tests
This commit is contained in:
@@ -154,6 +154,7 @@ Copy `.env.example` to `.env` and configure:
|
||||
- `POST /api/v1/auth/logout` - Logout
|
||||
- `GET /api/v1/auth/me` - Get current user
|
||||
- `GET /api/v1/auth/sessions` - Get user sessions
|
||||
- `POST /api/v1/auth/sessions/refresh` - Extend session idle window
|
||||
- `DELETE /api/v1/auth/sessions/:id` - Revoke session
|
||||
|
||||
### Users
|
||||
@@ -264,6 +265,52 @@ gunicorn -w 4 -b 0.0.0.0:8000 wsgi:app
|
||||
- Request ID tracking for audit trails
|
||||
|
||||
|
||||
## Session Management
|
||||
|
||||
Sessions are database-backed bearer tokens stored in PostgreSQL. Each session is created at login and validated on every authenticated request via the `login_required` decorator.
|
||||
|
||||
### Sliding Timeout
|
||||
|
||||
Sessions use a **sliding window** model with two independent limits:
|
||||
|
||||
| Timeout | Default | Env Var | Behaviour |
|
||||
|---------|---------|---------|-----------|
|
||||
| **Idle** | 15 min | `SESSION_IDLE_TIMEOUT` | Extends automatically on every request. If no request is made within this window the session expires. |
|
||||
| **Absolute** | 8 h | `SESSION_ABSOLUTE_TIMEOUT` | Hard cap measured from session creation. Activity cannot extend a session beyond this point. |
|
||||
|
||||
Every authenticated request resets the idle clock by calling `Session.refresh()`, which sets `expires_at = now + idle_timeout` — but never past `created_at + absolute_timeout`. This means:
|
||||
|
||||
- An active user stays logged in indefinitely **up to** the absolute cap.
|
||||
- An idle user is logged out after the idle timeout.
|
||||
- No session can survive longer than the absolute timeout regardless of activity.
|
||||
|
||||
### Configuration
|
||||
|
||||
Override defaults via environment variables:
|
||||
|
||||
```bash
|
||||
SESSION_IDLE_TIMEOUT=900 # seconds (15 min)
|
||||
SESSION_ABSOLUTE_TIMEOUT=28800 # seconds (8 h)
|
||||
```
|
||||
|
||||
### Cleanup
|
||||
|
||||
Expired sessions are soft-marked as `EXPIRED` by the `cleanup_sessions` job. Run it periodically via the job runner:
|
||||
|
||||
```bash
|
||||
python manage.py cleanup_sessions
|
||||
|
||||
# Or via the job runner (Docker):
|
||||
JOB_NAME=cleanup_sessions JOB_INTERVAL_SECONDS=300
|
||||
```
|
||||
|
||||
### Session Endpoints
|
||||
|
||||
- `GET /api/v1/auth/sessions` — List active sessions for the current user
|
||||
- `POST /api/v1/auth/sessions/refresh` — Extend the current session's idle window (returns new `expires_at`)
|
||||
- `DELETE /api/v1/auth/sessions/:id` — Revoke a specific session
|
||||
|
||||
|
||||
# Boostrap db
|
||||
python manage.py db upgrade
|
||||
|
||||
|
||||
@@ -48,6 +48,10 @@ class BaseConfig:
|
||||
seconds=int(os.getenv("MAX_SESSION_DURATION", "86400"))
|
||||
)
|
||||
|
||||
# Session timeout policy (seconds)
|
||||
SESSION_IDLE_TIMEOUT = int(os.getenv("SESSION_IDLE_TIMEOUT", "900"))
|
||||
SESSION_ABSOLUTE_TIMEOUT = int(os.getenv("SESSION_ABSOLUTE_TIMEOUT", "28800"))
|
||||
|
||||
# CORS
|
||||
CORS_ORIGINS = os.getenv(
|
||||
"CORS_ORIGINS",
|
||||
|
||||
@@ -116,7 +116,7 @@ def login():
|
||||
|
||||
remember_me = data.get("remember_me", False)
|
||||
policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me)
|
||||
duration = 2592000 if remember_me else 86400
|
||||
duration = current_app.config.get("SESSION_ABSOLUTE_TIMEOUT", 28800) if remember_me else None
|
||||
is_compliance_only = policy_result.create_compliance_only_session
|
||||
|
||||
user_session = AuthService.create_session(user, duration_seconds=duration, is_compliance_only=is_compliance_only)
|
||||
@@ -227,6 +227,32 @@ def revoke_session(session_id):
|
||||
return api_response(message="Session revoked successfully")
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/sessions/refresh", methods=["POST"])
|
||||
@login_required
|
||||
def refresh_session():
|
||||
"""Extend the current session's idle window.
|
||||
|
||||
The server already refreshes the session on every authenticated
|
||||
request, but this endpoint exists so the frontend can proactively
|
||||
keep a session alive (e.g. a heartbeat while the user is reading
|
||||
a long page with no API calls).
|
||||
|
||||
Returns the new ``expires_at`` so the frontend can display a
|
||||
countdown or warning before the absolute cap.
|
||||
"""
|
||||
session = g.current_session
|
||||
session.refresh()
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"expires_at": session.expires_at.isoformat() + "Z"
|
||||
if session.expires_at.isoformat()[-1] != "Z"
|
||||
else session.expires_at.isoformat(),
|
||||
},
|
||||
message="Session refreshed",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/token", methods=["GET"])
|
||||
@login_required
|
||||
def get_token():
|
||||
@@ -246,7 +272,8 @@ def get_token():
|
||||
parsed_redirect = urlparse(redirect_url)
|
||||
redirect_origin = f"{parsed_redirect.scheme}://{parsed_redirect.netloc}"
|
||||
|
||||
if redirect_origin not in allowed_origins:
|
||||
wildcard = "*" in allowed_origins
|
||||
if not wildcard and redirect_origin not in allowed_origins:
|
||||
return api_response(success=False, message="Redirect URL is not allowed.", status=400, error_type="INVALID_REDIRECT")
|
||||
|
||||
sep = "&" if "?" in redirect_url else "?"
|
||||
|
||||
@@ -190,8 +190,8 @@ def select_organization():
|
||||
if not member:
|
||||
return api_response(success=False, message="You are not a member of this organization", status=403, error_type="FORBIDDEN")
|
||||
|
||||
from gatehouse_app.services.session_service import SessionService
|
||||
session = SessionService.create_session(user=user, organization_id=organization_id)
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
session = AuthService.create_session(user=user)
|
||||
state_record.mark_used()
|
||||
|
||||
provider_type_val = state_record.provider_type.value if isinstance(state_record.provider_type, _AuthMethodType) else state_record.provider_type
|
||||
|
||||
@@ -1,6 +1,44 @@
|
||||
"""CORS middleware configuration."""
|
||||
from flask import request, make_response
|
||||
|
||||
ALLOWED_METHODS = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
ALLOWED_HEADERS = (
|
||||
"Content-Type, Authorization, X-Requested-With, X-Request-ID, "
|
||||
"Cache-Control, Pragma, X-WebAuthn-Session-Token"
|
||||
)
|
||||
|
||||
|
||||
def _is_origin_allowed(origin, cors_origins):
|
||||
"""Return True if the origin is permitted by the CORS config.
|
||||
|
||||
Handles both wildcard ("*") and explicit origin lists.
|
||||
"""
|
||||
if not origin:
|
||||
return False
|
||||
if cors_origins == "*":
|
||||
return True
|
||||
if isinstance(cors_origins, list):
|
||||
if "*" in cors_origins:
|
||||
return True
|
||||
return origin in cors_origins
|
||||
return False
|
||||
|
||||
|
||||
def _cors_origin_header(cors_origins, request_origin):
|
||||
"""Return the value for Access-Control-Allow-Origin.
|
||||
|
||||
Per the CORS spec, browsers reject ``*`` when credentials are involved,
|
||||
so we echo the request origin when wildcard + credentials is configured.
|
||||
"""
|
||||
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
|
||||
if allow_all and request_origin:
|
||||
return request_origin
|
||||
if allow_all:
|
||||
return "*"
|
||||
if request_origin and request_origin in cors_origins:
|
||||
return request_origin
|
||||
return None
|
||||
|
||||
|
||||
def setup_cors(app):
|
||||
"""
|
||||
@@ -9,6 +47,7 @@ def setup_cors(app):
|
||||
Args:
|
||||
app: Flask application instance
|
||||
"""
|
||||
supports_credentials = app.config.get("CORS_SUPPORTS_CREDENTIALS", True)
|
||||
|
||||
@app.before_request
|
||||
def handle_preflight():
|
||||
@@ -16,49 +55,33 @@ def setup_cors(app):
|
||||
if request.method == "OPTIONS":
|
||||
origin = request.headers.get("Origin")
|
||||
cors_origins = app.config.get("CORS_ORIGINS", [])
|
||||
|
||||
# Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard)
|
||||
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
|
||||
|
||||
if allow_all:
|
||||
response = make_response("", 204)
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma"
|
||||
response.headers["Access-Control-Max-Age"] = "3600"
|
||||
response.headers["Cache-Control"] = "no-cache, no-store"
|
||||
return response
|
||||
elif origin and origin in cors_origins:
|
||||
response = make_response("", 204)
|
||||
response.headers["Access-Control-Allow-Origin"] = origin
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma, X-WebAuthn-Session-Token"
|
||||
|
||||
if not _is_origin_allowed(origin, cors_origins):
|
||||
return None
|
||||
|
||||
response = make_response("", 204)
|
||||
response.headers["Access-Control-Allow-Origin"] = _cors_origin_header(cors_origins, origin)
|
||||
response.headers["Access-Control-Allow-Methods"] = ALLOWED_METHODS
|
||||
response.headers["Access-Control-Allow-Headers"] = ALLOWED_HEADERS
|
||||
if supports_credentials:
|
||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
response.headers["Access-Control-Max-Age"] = "3600"
|
||||
response.headers["Cache-Control"] = "no-cache, no-store"
|
||||
return response
|
||||
response.headers["Access-Control-Max-Age"] = "3600"
|
||||
response.headers["Cache-Control"] = "no-cache, no-store"
|
||||
return response
|
||||
|
||||
@app.after_request
|
||||
def after_request_cors(response):
|
||||
"""Add additional CORS headers if needed."""
|
||||
"""Add CORS headers to non-preflight responses."""
|
||||
origin = request.headers.get("Origin")
|
||||
cors_origins = app.config.get("CORS_ORIGINS", [])
|
||||
|
||||
# Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard)
|
||||
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
|
||||
|
||||
if allow_all:
|
||||
# When allowing all origins, set header to "*"
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma"
|
||||
response.headers["Access-Control-Max-Age"] = "3600"
|
||||
elif origin and origin in cors_origins:
|
||||
# When allowing specific origins, echo the request origin
|
||||
response.headers["Access-Control-Allow-Origin"] = origin
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma, X-WebAuthn-Session-Token"
|
||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
allow_origin = _cors_origin_header(cors_origins, origin)
|
||||
if allow_origin:
|
||||
response.headers["Access-Control-Allow-Origin"] = allow_origin
|
||||
response.headers["Access-Control-Allow-Methods"] = ALLOWED_METHODS
|
||||
response.headers["Access-Control-Allow-Headers"] = ALLOWED_HEADERS
|
||||
if supports_credentials:
|
||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
response.headers["Access-Control-Max-Age"] = "3600"
|
||||
|
||||
return response
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Session model."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from flask import current_app
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import SessionStatus
|
||||
@@ -38,33 +39,71 @@ class Session(BaseModel):
|
||||
return f"<Session user_id={self.user_id} status={self.status}>"
|
||||
|
||||
def is_active(self):
|
||||
"""Check if session is currently active."""
|
||||
"""Check if session is currently active.
|
||||
|
||||
Sessions are evaluated against two independent timeouts:
|
||||
- Idle timeout: expires if no request has been made within
|
||||
SESSION_IDLE_TIMEOUT seconds (default 15 min).
|
||||
- Absolute timeout: expires if SESSION_ABSOLUTE_TIMEOUT seconds
|
||||
have elapsed since the session was created (default 8 h),
|
||||
regardless of activity.
|
||||
|
||||
A session must satisfy *both* constraints to remain active.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
created_at = self.created_at
|
||||
last_activity_at = self.last_activity_at
|
||||
|
||||
if created_at.tzinfo is None:
|
||||
created_at = created_at.replace(tzinfo=timezone.utc)
|
||||
if last_activity_at.tzinfo is None:
|
||||
last_activity_at = last_activity_at.replace(tzinfo=timezone.utc)
|
||||
|
||||
idle_timeout = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
|
||||
absolute_timeout = current_app.config.get("SESSION_ABSOLUTE_TIMEOUT", 28800)
|
||||
|
||||
idle_expires_at = last_activity_at + timedelta(seconds=idle_timeout)
|
||||
absolute_expires_at = created_at + timedelta(seconds=absolute_timeout)
|
||||
|
||||
return (
|
||||
self.status == SessionStatus.ACTIVE
|
||||
and expires_at > now
|
||||
and now < idle_expires_at
|
||||
and now < absolute_expires_at
|
||||
and self.deleted_at is None
|
||||
)
|
||||
|
||||
def is_expired(self):
|
||||
"""Check if session has expired."""
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return now > expires_at
|
||||
"""Check if session has expired (either idle or absolute)."""
|
||||
return not self.is_active() and self.status != SessionStatus.REVOKED
|
||||
|
||||
def refresh(self, duration_seconds: int = 86400):
|
||||
"""Refresh session expiration.
|
||||
def refresh(self, duration_seconds: int = None):
|
||||
"""Extend the session expiration using a sliding window.
|
||||
|
||||
The new ``expires_at`` is set to *now + idle timeout*, but is
|
||||
capped so that the session never exceeds the absolute lifetime
|
||||
(``created_at + absolute timeout``).
|
||||
|
||||
Args:
|
||||
duration_seconds: New session duration in seconds
|
||||
duration_seconds: Override for the idle timeout. When *None*
|
||||
(the common case), the value is read from
|
||||
``SESSION_IDLE_TIMEOUT`` in the Flask config.
|
||||
"""
|
||||
self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=duration_seconds)
|
||||
self.last_activity_at = datetime.now(timezone.utc)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
if duration_seconds is None:
|
||||
duration_seconds = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
|
||||
|
||||
absolute_timeout = current_app.config.get("SESSION_ABSOLUTE_TIMEOUT", 28800)
|
||||
|
||||
idle_expires_at = now + timedelta(seconds=duration_seconds)
|
||||
|
||||
created_at = self.created_at
|
||||
if created_at.tzinfo is None:
|
||||
created_at = created_at.replace(tzinfo=timezone.utc)
|
||||
absolute_expires_at = created_at + timedelta(seconds=absolute_timeout)
|
||||
|
||||
self.expires_at = min(idle_expires_at, absolute_expires_at)
|
||||
self.last_activity_at = now
|
||||
db.session.commit()
|
||||
|
||||
def revoke(self, reason: str = None):
|
||||
|
||||
@@ -140,18 +140,26 @@ class AuthService:
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def create_session(user, duration_seconds=86400, is_compliance_only=False):
|
||||
def create_session(user, duration_seconds=None, is_compliance_only=False):
|
||||
"""
|
||||
Create a new session for the user.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
duration_seconds: Session duration in seconds
|
||||
duration_seconds: Session idle timeout in seconds.
|
||||
When None, defaults to SESSION_IDLE_TIMEOUT from config.
|
||||
The absolute lifetime is always enforced by Session.is_active()
|
||||
regardless of this value.
|
||||
is_compliance_only: Whether this is a compliance-only session (limited access)
|
||||
|
||||
Returns:
|
||||
Session instance
|
||||
"""
|
||||
from flask import current_app
|
||||
|
||||
if duration_seconds is None:
|
||||
duration_seconds = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
|
||||
|
||||
# Generate session token
|
||||
token = secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
@@ -263,7 +263,7 @@ def authenticate_with_provider(
|
||||
state_record.mark_used()
|
||||
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
session = AuthService.create_session(user=user, organization_id=organization_id)
|
||||
session = AuthService.create_session(user=user)
|
||||
|
||||
AuditService.log_external_auth_login(
|
||||
user_id=user.id,
|
||||
|
||||
@@ -10,10 +10,10 @@ class SessionService:
|
||||
@staticmethod
|
||||
def get_active_session_by_token(token):
|
||||
"""Get active session by token.
|
||||
|
||||
|
||||
Args:
|
||||
token: The session token string
|
||||
|
||||
|
||||
Returns:
|
||||
Session object if found and active, None otherwise
|
||||
"""
|
||||
@@ -23,6 +23,8 @@ class SessionService:
|
||||
token=token,
|
||||
status=SessionStatus.ACTIVE,
|
||||
deleted_at=None
|
||||
).filter(
|
||||
Session.expires_at > datetime.now(timezone.utc)
|
||||
).first()
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -138,7 +138,7 @@ class SuperadminAuthService:
|
||||
Dictionary with emergency session info
|
||||
"""
|
||||
from gatehouse_app.models.user.user import User
|
||||
from gatehouse_app.services.session_service import SessionService
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
|
||||
# Verify target user exists
|
||||
@@ -147,7 +147,7 @@ class SuperadminAuthService:
|
||||
raise ValueError(f"Target user not found: {target_user_id}")
|
||||
|
||||
# Create emergency session for the target user
|
||||
emergency_session = SessionService.create_session(
|
||||
emergency_session = AuthService.create_session(
|
||||
user=target_user,
|
||||
duration_seconds=duration_minutes * 60,
|
||||
is_compliance_only=False
|
||||
|
||||
@@ -59,11 +59,9 @@ def login_required(f):
|
||||
error_type="SESSION_INACTIVE"
|
||||
)
|
||||
|
||||
# Update last_activity_at timestamp
|
||||
from datetime import datetime, timezone
|
||||
session.last_activity_at = datetime.now(timezone.utc)
|
||||
from gatehouse_app import db
|
||||
db.session.commit()
|
||||
# Extend session via sliding window (updates last_activity_at
|
||||
# and recalculates expires_at within the idle / absolute caps).
|
||||
session.refresh()
|
||||
|
||||
# Set context variables
|
||||
g.current_user = session.user
|
||||
|
||||
@@ -153,6 +153,31 @@ def mfa_compliance_status():
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
@cli.command("cleanup_sessions")
|
||||
def cleanup_sessions():
|
||||
"""Clean up expired user sessions.
|
||||
|
||||
Marks sessions as EXPIRED when they have passed their expires_at
|
||||
timestamp. Safe to run frequently (e.g. every 5 minutes via job_runner).
|
||||
|
||||
Usage:
|
||||
python manage.py cleanup_sessions
|
||||
"""
|
||||
from gatehouse_app.services.session_service import SessionService
|
||||
|
||||
print("=" * 60)
|
||||
print("Session Cleanup Job")
|
||||
print("=" * 60)
|
||||
|
||||
from datetime import datetime, timezone
|
||||
print(f"Start time: {datetime.now(timezone.utc).isoformat()}")
|
||||
|
||||
count = SessionService.cleanup_expired_sessions()
|
||||
|
||||
print(f"Expired sessions marked: {count}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
@cli.command("configure_oauth")
|
||||
@click.argument("provider", required=False)
|
||||
@click.option("--client-id", default=None, help="OAuth client ID")
|
||||
|
||||
@@ -28,6 +28,7 @@ logger = logging.getLogger(__name__)
|
||||
JOB_COMMANDS = {
|
||||
"zerotier_reconciliation": "python manage.py run_zerotier_reconciliation",
|
||||
"mfa_compliance": "python manage.py run_mfa_compliance_job",
|
||||
"cleanup_sessions": "python manage.py cleanup_sessions",
|
||||
}
|
||||
|
||||
shutdown_requested = False
|
||||
|
||||
@@ -95,6 +95,10 @@ class AuthClient:
|
||||
"""Revoke a specific session belonging to the current user."""
|
||||
return self._client.delete(f"/auth/sessions/{session_id}")
|
||||
|
||||
def refresh_session(self) -> dict:
|
||||
"""Extend the current session's idle window."""
|
||||
return self._client.post("/auth/sessions/refresh")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Password recovery
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,213 @@
|
||||
"""Session timeout integration tests.
|
||||
|
||||
Validates the sliding-window session timeout policy: idle timeout,
|
||||
absolute timeout, and the interaction between activity and expiration.
|
||||
Every test exercises the *public API* — the only internal manipulation
|
||||
is back-dating timestamps in the database, since we cannot wait minutes
|
||||
inside a test run.
|
||||
"""
|
||||
import pytest
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from tests.integration.client.base import ApiError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def assert_success(response: dict, message_contains: str = "") -> dict:
|
||||
"""Assert that an api_response-wrapped payload succeeded."""
|
||||
data = response.get("data", {})
|
||||
assert response.get("success") is not False, (
|
||||
f"Expected success but got error: {response.get('message')}"
|
||||
)
|
||||
if message_contains:
|
||||
assert message_contains.lower() in response.get("message", "").lower(), (
|
||||
f"Expected message to contain '{message_contains}' but got: {response.get('message')}"
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
def _get_session_row(integration_app, token: str):
|
||||
"""Look up the Session model row for a given bearer token."""
|
||||
from gatehouse_app.models.user.session import Session
|
||||
with integration_app.app_context():
|
||||
return Session.query.filter_by(token=token).first()
|
||||
|
||||
|
||||
def _touch_session(integration_app, session_id: str, **updates):
|
||||
"""Directly update columns on a Session row.
|
||||
|
||||
Only use this to simulate the passage of time — never to assert
|
||||
internal state.
|
||||
"""
|
||||
from gatehouse_app.models.user.session import Session
|
||||
with integration_app.app_context():
|
||||
sess = Session.query.get(session_id)
|
||||
for attr, value in updates.items():
|
||||
setattr(sess, attr, value)
|
||||
from gatehouse_app import db
|
||||
db.session.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def logged_in_session(integration_client, create_test_user, integration_app):
|
||||
"""Register a user, log in via the API, and return the session metadata.
|
||||
|
||||
Returns dict with ``user``, ``token``, ``session_id``, ``session_row``.
|
||||
The ``session_row`` is a detached SQLAlchemy instance — re-query if
|
||||
you need fresh DB state.
|
||||
"""
|
||||
user = create_test_user(password="TestPass123!")
|
||||
integration_client.auth.login(
|
||||
email=user["email"], password="TestPass123!",
|
||||
)
|
||||
token = integration_client._token
|
||||
|
||||
session_row = _get_session_row(integration_app, token)
|
||||
assert session_row is not None, "Session row should exist after login"
|
||||
|
||||
return {
|
||||
"user": user,
|
||||
"token": token,
|
||||
"session_id": session_row.id,
|
||||
"session_row": session_row,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSessionTimeouts:
|
||||
"""Sliding-window timeout behavior exercised through the public API."""
|
||||
|
||||
def test_session_valid_before_timeout(
|
||||
self, integration_client, create_test_user,
|
||||
):
|
||||
"""SESS-01 — Fresh session is accepted.
|
||||
|
||||
A session that was just created should pass all auth checks.
|
||||
This is the baseline: if this fails, every other timeout test
|
||||
is meaningless.
|
||||
"""
|
||||
user = create_test_user(password="MyPass123!")
|
||||
integration_client.auth.login(email=user["email"], password="MyPass123!")
|
||||
|
||||
result = integration_client.auth.me()
|
||||
data = assert_success(result)
|
||||
assert data["user"]["email"] == user["email"]
|
||||
|
||||
def test_idle_timeout_rejects_token(
|
||||
self, integration_client, logged_in_session, integration_app,
|
||||
):
|
||||
"""SESS-02 — Session rejected after idle period elapses.
|
||||
|
||||
Push ``last_activity_at`` far enough into the past that the
|
||||
idle window has closed. The API must return 401.
|
||||
"""
|
||||
_touch_session(
|
||||
integration_app,
|
||||
logged_in_session["session_id"],
|
||||
last_activity_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
)
|
||||
|
||||
with pytest.raises(ApiError) as exc_info:
|
||||
integration_client.auth.me()
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_absolute_timeout_rejects_even_active_user(
|
||||
self, integration_client, logged_in_session, integration_app,
|
||||
):
|
||||
"""SESS-03 — Absolute cap overrides recent activity.
|
||||
|
||||
Push ``created_at`` into the past so the absolute window has
|
||||
elapsed, but keep ``last_activity_at`` fresh. The session
|
||||
must still be rejected — activity cannot extend past the
|
||||
absolute limit.
|
||||
"""
|
||||
_touch_session(
|
||||
integration_app,
|
||||
logged_in_session["session_id"],
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
last_activity_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
with pytest.raises(ApiError) as exc_info:
|
||||
integration_client.auth.me()
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_api_request_keeps_session_alive(
|
||||
self, integration_client, logged_in_session, integration_app,
|
||||
):
|
||||
"""SESS-04 — Hitting an API endpoint extends the session.
|
||||
|
||||
Back-date ``last_activity_at`` to *just* inside the idle
|
||||
window. A subsequent API call should succeed and the session
|
||||
should remain usable — the sliding window should have reset.
|
||||
"""
|
||||
from gatehouse_app.models.user.session import Session
|
||||
from gatehouse_app import db
|
||||
|
||||
# Back-date to 10 seconds ago — still inside the idle window.
|
||||
_touch_session(
|
||||
integration_app,
|
||||
logged_in_session["session_id"],
|
||||
last_activity_at=datetime.now(timezone.utc) - timedelta(seconds=10),
|
||||
)
|
||||
|
||||
# This request should succeed AND extend the session.
|
||||
result = integration_client.auth.me()
|
||||
assert_success(result)
|
||||
|
||||
# After the request, last_activity_at should be much closer to now.
|
||||
with integration_app.app_context():
|
||||
refreshed = Session.query.get(logged_in_session["session_id"])
|
||||
now = datetime.now(timezone.utc)
|
||||
# Allow for clock skew / commit latency — should be within 30s.
|
||||
diff = abs((now - refreshed.last_activity_at.replace(tzinfo=timezone.utc)).total_seconds())
|
||||
assert diff < 30, (
|
||||
f"last_activity_at should be near-now after API call, "
|
||||
f"but delta is {diff:.1f}s"
|
||||
)
|
||||
|
||||
def test_revoked_session_rejected(
|
||||
self, integration_client, logged_in_session,
|
||||
):
|
||||
"""SESS-05 — Revoked session is rejected regardless of timing.
|
||||
|
||||
Revoke via the API, then verify the token is dead. This
|
||||
mirrors AUTH-12 but is included here so the timeout test
|
||||
suite is self-contained.
|
||||
"""
|
||||
integration_client.auth.revoke_session(logged_in_session["session_id"])
|
||||
|
||||
with pytest.raises(ApiError) as exc_info:
|
||||
integration_client.auth.me()
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_refresh_endpoint_extends_session(
|
||||
self, integration_client, logged_in_session, integration_app,
|
||||
):
|
||||
"""SESS-06 — POST /auth/sessions/refresh extends the session.
|
||||
|
||||
The refresh endpoint exists so the frontend can proactively
|
||||
keep a session alive during idle UI periods. Verify it
|
||||
succeeds and returns a new ``expires_at``.
|
||||
"""
|
||||
result = integration_client.auth.refresh_session()
|
||||
data = assert_success(result, "session refreshed")
|
||||
|
||||
assert "expires_at" in data, "Response should include new expires_at"
|
||||
@@ -0,0 +1,205 @@
|
||||
"""Unit tests for ca_key_encryption module.
|
||||
|
||||
WHAT: Tests for the Fernet-based CA private key encryption/decryption
|
||||
utility functions.
|
||||
WHY: CA private keys are the most sensitive data in the system; we need
|
||||
to verify round-trip correctness, idempotency, and error handling.
|
||||
EXPECTED: All encrypt/decrypt operations produce correct plaintext.
|
||||
"""
|
||||
import os
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gatehouse_app.utils.ca_key_encryption import (
|
||||
CAKeyEncryptionError,
|
||||
_FERNET_PREFIX,
|
||||
_get_fernet,
|
||||
decrypt_ca_key,
|
||||
encrypt_ca_key,
|
||||
is_encrypted,
|
||||
reencrypt_ca_key,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared fixture
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SAMPLE_PEM = (
|
||||
"-----BEGIN OPENSSH PRIVATE KEY-----\n"
|
||||
"b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtz\n"
|
||||
"c2gtZWQyNTUxOQAAACBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBAAA\n"
|
||||
"-----END OPENSSH PRIVATE KEY-----"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _set_ca_encryption_key():
|
||||
"""Ensure CA_ENCRYPTION_KEY is set for every test."""
|
||||
with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": "test-secret-key-for-unit-tests"}):
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# encrypt / decrypt round-trip
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEncryptDecryptRoundTrip:
|
||||
"""Verify that encrypt -> decrypt returns the original plaintext."""
|
||||
|
||||
def test_basic_round_trip(self):
|
||||
"""TEST: ENC-RT-01 -- Encrypt then decrypt returns original PEM."""
|
||||
encrypted = encrypt_ca_key(SAMPLE_PEM)
|
||||
decrypted = decrypt_ca_key(encrypted)
|
||||
assert decrypted == SAMPLE_PEM
|
||||
|
||||
def test_encrypted_value_has_prefix(self):
|
||||
"""TEST: ENC-RT-02 -- Encrypted output carries the $fernet$ envelope."""
|
||||
encrypted = encrypt_ca_key(SAMPLE_PEM)
|
||||
assert encrypted.startswith(_FERNET_PREFIX)
|
||||
|
||||
def test_different_ciphertext_each_time(self):
|
||||
"""TEST: ENC-RT-03 -- Two encryptions of the same plaintext differ."""
|
||||
enc1 = encrypt_ca_key(SAMPLE_PEM)
|
||||
enc2 = encrypt_ca_key(SAMPLE_PEM)
|
||||
assert enc1 != enc2
|
||||
assert decrypt_ca_key(enc1) == SAMPLE_PEM
|
||||
assert decrypt_ca_key(enc2) == SAMPLE_PEM
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Idempotency
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIdempotency:
|
||||
"""The module must not double-encrypt or double-decrypt."""
|
||||
|
||||
def test_encrypt_idempotent(self):
|
||||
"""TEST: ENC-IDEM-01 -- Encrypting an already-encrypted value is a no-op."""
|
||||
encrypted = encrypt_ca_key(SAMPLE_PEM)
|
||||
double = encrypt_ca_key(encrypted)
|
||||
assert double == encrypted
|
||||
|
||||
def test_decrypt_plaintext_passthrough(self):
|
||||
"""TEST: ENC-IDEM-02 -- Decrypting a plaintext (legacy) value returns it as-is."""
|
||||
result = decrypt_ca_key(SAMPLE_PEM)
|
||||
assert result == SAMPLE_PEM
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_encrypted helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsEncrypted:
|
||||
def test_encrypted_value(self):
|
||||
"""TEST: ENC-IE-01 -- is_encrypted returns True for $fernet$ values."""
|
||||
encrypted = encrypt_ca_key(SAMPLE_PEM)
|
||||
assert is_encrypted(encrypted) is True
|
||||
|
||||
def test_plaintext_value(self):
|
||||
"""TEST: ENC-IE-02 -- is_encrypted returns False for plain PEM."""
|
||||
assert is_encrypted(SAMPLE_PEM) is False
|
||||
|
||||
def test_empty_string(self):
|
||||
"""TEST: ENC-IE-03 -- is_encrypted returns False for empty string."""
|
||||
assert is_encrypted("") is False
|
||||
|
||||
def test_none_value(self):
|
||||
"""TEST: ENC-IE-04 -- is_encrypted returns False for None."""
|
||||
assert is_encrypted(None) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestErrorHandling:
|
||||
def test_encrypt_empty_raises(self):
|
||||
"""TEST: ENC-ERR-01 -- Encrypting empty string raises CAKeyEncryptionError."""
|
||||
with pytest.raises(CAKeyEncryptionError, match="empty"):
|
||||
encrypt_ca_key("")
|
||||
|
||||
def test_decrypt_empty_raises(self):
|
||||
"""TEST: ENC-ERR-02 -- Decrypting empty string raises CAKeyEncryptionError."""
|
||||
with pytest.raises(CAKeyEncryptionError, match="empty"):
|
||||
decrypt_ca_key("")
|
||||
|
||||
def test_missing_key_raises(self):
|
||||
"""TEST: ENC-ERR-03 -- Operations fail when CA_ENCRYPTION_KEY is unset."""
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
os.environ.pop("CA_ENCRYPTION_KEY", None)
|
||||
with pytest.raises(CAKeyEncryptionError, match="not set"):
|
||||
encrypt_ca_key(SAMPLE_PEM)
|
||||
|
||||
def test_wrong_key_raises_on_decrypt(self):
|
||||
"""TEST: ENC-ERR-04 -- Decrypting with the wrong key raises."""
|
||||
encrypted = encrypt_ca_key(SAMPLE_PEM)
|
||||
with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": "wrong-key"}):
|
||||
with pytest.raises(CAKeyEncryptionError, match="decryption failed"):
|
||||
decrypt_ca_key(encrypted)
|
||||
|
||||
def test_corrupted_data_raises(self):
|
||||
"""TEST: ENC-ERR-05 -- Decrypting corrupted ciphertext raises."""
|
||||
with pytest.raises(CAKeyEncryptionError):
|
||||
decrypt_ca_key("$fernet$not-a-real-token")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reencrypt_ca_key -- key rotation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReencrypt:
|
||||
def test_reencrypt_round_trip(self):
|
||||
"""TEST: ENC-RE-01 -- Re-encrypted value decrypts with the new key."""
|
||||
old_key = "old-secret-key"
|
||||
new_key = "new-secret-key"
|
||||
encrypted = reencrypt_ca_key(SAMPLE_PEM, "any-old-key", old_key)
|
||||
reencrypted = reencrypt_ca_key(encrypted, old_key, new_key)
|
||||
|
||||
# Verify it decrypts with the new key
|
||||
with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": new_key}):
|
||||
decrypted = decrypt_ca_key(reencrypted)
|
||||
assert decrypted == SAMPLE_PEM
|
||||
|
||||
def test_reencrypt_plaintext_key(self):
|
||||
"""TEST: ENC-RE-02 -- Re-encrypting a legacy plaintext key works."""
|
||||
new_key = "brand-new-key"
|
||||
reencrypted = reencrypt_ca_key(SAMPLE_PEM, "any-old-key", new_key)
|
||||
with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": new_key}):
|
||||
decrypted = decrypt_ca_key(reencrypted)
|
||||
assert decrypted == SAMPLE_PEM
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thread safety
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestThreadSafety:
|
||||
"""Concurrent encrypt/decrypt calls must not corrupt state."""
|
||||
|
||||
def test_concurrent_encrypt_decrypt(self):
|
||||
"""TEST: ENC-TS-01 -- 50 threads encrypting/decrypting concurrently."""
|
||||
errors = []
|
||||
results = []
|
||||
|
||||
def worker(i):
|
||||
try:
|
||||
data = f"key-data-{i}"
|
||||
enc = encrypt_ca_key(data)
|
||||
dec = decrypt_ca_key(enc)
|
||||
results.append((i, dec))
|
||||
except Exception as exc:
|
||||
errors.append((i, exc))
|
||||
|
||||
threads = [threading.Thread(target=worker, args=(i,)) for i in range(50)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join(timeout=10)
|
||||
|
||||
assert not errors, f"Thread errors: {errors}"
|
||||
assert len(results) == 50
|
||||
for i, dec in results:
|
||||
assert dec == f"key-data-{i}", f"Thread {i}: expected 'key-data-{i}', got {dec!r}"
|
||||
@@ -0,0 +1,125 @@
|
||||
"""Unit tests for CORS middleware.
|
||||
|
||||
WHAT: Tests for the CORS middleware configuration, including wildcard
|
||||
origin handling, credentials support, and preflight responses.
|
||||
WHY: CORS misconfiguration can break browser clients or leak credentials.
|
||||
EXPECTED: Correct Access-Control-* headers for all origin configurations.
|
||||
"""
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from gatehouse_app.middleware.cors import (
|
||||
_is_origin_allowed,
|
||||
_cors_origin_header,
|
||||
setup_cors,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_origin_allowed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsOriginAllowed:
|
||||
def test_empty_origin_rejected(self):
|
||||
"""TEST: CORS-01 -- Empty origin is never allowed."""
|
||||
assert _is_origin_allowed("", ["https://example.com"]) is False
|
||||
assert _is_origin_allowed(None, "*") is False
|
||||
|
||||
def test_wildcard_string(self):
|
||||
"""TEST: CORS-02 -- Wildcard string allows any origin."""
|
||||
assert _is_origin_allowed("https://evil.com", "*") is True
|
||||
|
||||
def test_wildcard_in_list(self):
|
||||
"""TEST: CORS-03 -- Wildcard in list allows any origin."""
|
||||
assert _is_origin_allowed("https://evil.com", ["*", "https://example.com"]) is True
|
||||
|
||||
def test_explicit_origin_match(self):
|
||||
"""TEST: CORS-04 -- Explicit list matches exact origin."""
|
||||
origins = ["https://example.com", "http://localhost:3000"]
|
||||
assert _is_origin_allowed("https://example.com", origins) is True
|
||||
assert _is_origin_allowed("https://evil.com", origins) is False
|
||||
|
||||
def test_empty_origins_list(self):
|
||||
"""TEST: CORS-05 -- Empty list rejects everything."""
|
||||
assert _is_origin_allowed("https://example.com", []) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _cors_origin_header
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCorsOriginHeader:
|
||||
def test_wildcard_with_origin_echoes(self):
|
||||
"""TEST: CORS-HDR-01 -- Wildcard echoes request origin (for credentials)."""
|
||||
assert _cors_origin_header("*", "https://example.com") == "https://example.com"
|
||||
|
||||
def test_wildcard_without_origin(self):
|
||||
"""TEST: CORS-HDR-02 -- Wildcard with no origin returns *."""
|
||||
assert _cors_origin_header("*", None) == "*"
|
||||
|
||||
def test_wildcard_in_list_with_origin(self):
|
||||
"""TEST: CORS-HDR-03 -- Wildcard in list echoes request origin."""
|
||||
result = _cors_origin_header(["*", "https://example.com"], "https://any.com")
|
||||
assert result == "https://any.com"
|
||||
|
||||
def test_specific_origin_match(self):
|
||||
"""TEST: CORS-HDR-04 -- Matching origin is echoed."""
|
||||
origins = ["https://example.com"]
|
||||
assert _cors_origin_header(origins, "https://example.com") == "https://example.com"
|
||||
|
||||
def test_specific_origin_no_match(self):
|
||||
"""TEST: CORS-HDR-05 -- Non-matching origin returns None."""
|
||||
origins = ["https://example.com"]
|
||||
assert _cors_origin_header(origins, "https://evil.com") is None
|
||||
|
||||
def test_no_origin_no_match(self):
|
||||
"""TEST: CORS-HDR-06 -- No origin with specific list returns None."""
|
||||
origins = ["https://example.com"]
|
||||
assert _cors_origin_header(origins, None) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: preflight response
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPreflightIntegration:
|
||||
@pytest.fixture
|
||||
def app_wildcard(self):
|
||||
app = Flask(__name__)
|
||||
app.config["CORS_ORIGINS"] = "*"
|
||||
app.config["CORS_SUPPORTS_CREDENTIALS"] = True
|
||||
setup_cors(app)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def app_specific(self):
|
||||
app = Flask(__name__)
|
||||
app.config["CORS_ORIGINS"] = ["https://example.com"]
|
||||
app.config["CORS_SUPPORTS_CREDENTIALS"] = True
|
||||
setup_cors(app)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
def test_wildcard_preflight_echoes_origin(self, app_wildcard):
|
||||
"""TEST: CORS-PF-01 -- Wildcard preflight echoes request origin."""
|
||||
with app_wildcard.test_client() as client:
|
||||
resp = client.options("/", headers={"Origin": "https://example.com"})
|
||||
assert resp.status_code == 204
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") == "https://example.com"
|
||||
assert resp.headers.get("Access-Control-Allow-Credentials") == "true"
|
||||
|
||||
def test_specific_origin_preflight(self, app_specific):
|
||||
"""TEST: CORS-PF-02 -- Specific origin preflight allows matching origin."""
|
||||
with app_specific.test_client() as client:
|
||||
resp = client.options("/", headers={"Origin": "https://example.com"})
|
||||
assert resp.status_code == 204
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") == "https://example.com"
|
||||
assert resp.headers.get("Access-Control-Allow-Credentials") == "true"
|
||||
|
||||
def test_specific_origin_rejects_unknown(self, app_specific):
|
||||
"""TEST: CORS-PF-03 -- Non-matching origin gets no CORS headers."""
|
||||
with app_specific.test_client() as client:
|
||||
resp = client.options("/", headers={"Origin": "https://evil.com"})
|
||||
# No preflight handler runs, Flask returns default
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") is None
|
||||
@@ -0,0 +1,164 @@
|
||||
"""Unit tests for encryption module (general-purpose Fernet encryption).
|
||||
|
||||
WHAT: Tests for the PBKDF2-based Fernet encryption/decryption used for
|
||||
OAuth tokens and client secrets.
|
||||
WHY: These utilities protect access tokens and client secrets; we need
|
||||
to verify round-trip correctness and error handling.
|
||||
EXPECTED: All encrypt/decrypt operations produce correct plaintext.
|
||||
"""
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
|
||||
from gatehouse_app.utils.encryption import (
|
||||
SALT_LENGTH,
|
||||
_get_fernet_key,
|
||||
decrypt,
|
||||
encrypt,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared fixture
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SECRET_KEY = "test-encryption-secret-key"
|
||||
SAMPLE_DATA = "access_token=eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.payload"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# encrypt / decrypt round-trip
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEncryptDecryptRoundTrip:
|
||||
"""Verify that encrypt -> decrypt returns the original plaintext."""
|
||||
|
||||
def test_basic_round_trip(self):
|
||||
"""TEST: ENC-RT-01 -- Encrypt then decrypt returns original data."""
|
||||
encrypted = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
|
||||
decrypted = decrypt(encrypted, secret_key=SECRET_KEY)
|
||||
assert decrypted == SAMPLE_DATA
|
||||
|
||||
def test_encrypted_is_base64(self):
|
||||
"""TEST: ENC-RT-02 -- Encrypted output is valid base64."""
|
||||
import base64
|
||||
encrypted = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
|
||||
# Should not raise
|
||||
base64.urlsafe_b64decode(encrypted.encode())
|
||||
|
||||
def test_different_ciphertext_each_time(self):
|
||||
"""TEST: ENC-RT-03 -- Two encryptions of the same plaintext differ."""
|
||||
enc1 = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
|
||||
enc2 = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
|
||||
assert enc1 != enc2
|
||||
assert decrypt(enc1, secret_key=SECRET_KEY) == SAMPLE_DATA
|
||||
assert decrypt(enc2, secret_key=SECRET_KEY) == SAMPLE_DATA
|
||||
|
||||
def test_round_trip_unicode(self):
|
||||
"""TEST: ENC-RT-04 -- Unicode data round-trips correctly."""
|
||||
data = "token=cafe\u00e9\u00f1\u00fc"
|
||||
encrypted = encrypt(data, secret_key=SECRET_KEY)
|
||||
assert decrypt(encrypted, secret_key=SECRET_KEY) == data
|
||||
|
||||
def test_round_trip_long_data(self):
|
||||
"""TEST: ENC-RT-05 -- Large data round-trips correctly."""
|
||||
data = "x" * 10000
|
||||
encrypted = encrypt(data, secret_key=SECRET_KEY)
|
||||
assert decrypt(encrypted, secret_key=SECRET_KEY) == data
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Empty / edge inputs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEdgeCases:
|
||||
def test_encrypt_empty_returns_empty(self):
|
||||
"""TEST: ENC-EDGE-01 -- Encrypting empty string returns empty."""
|
||||
assert encrypt("", secret_key=SECRET_KEY) == ""
|
||||
|
||||
def test_decrypt_empty_returns_empty(self):
|
||||
"""TEST: ENC-EDGE-02 -- Decrypting empty string returns empty."""
|
||||
assert decrypt("", secret_key=SECRET_KEY) == ""
|
||||
|
||||
def test_missing_key_raises_on_encrypt(self):
|
||||
"""TEST: ENC-EDGE-03 -- Missing key raises ValueError on encrypt."""
|
||||
with pytest.raises(ValueError, match="Encryption key not configured"):
|
||||
encrypt("data", secret_key="")
|
||||
|
||||
def test_missing_key_raises_on_decrypt(self):
|
||||
"""TEST: ENC-EDGE-04 -- Missing key raises ValueError on decrypt."""
|
||||
with pytest.raises(ValueError, match="Encryption key not configured"):
|
||||
decrypt("something", secret_key="")
|
||||
|
||||
def test_wrong_key_raises_on_decrypt(self):
|
||||
"""TEST: ENC-EDGE-05 -- Wrong key raises ValueError on decrypt."""
|
||||
encrypted = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
|
||||
with pytest.raises(ValueError, match="Failed to decrypt"):
|
||||
decrypt(encrypted, secret_key="wrong-key")
|
||||
|
||||
def test_corrupted_data_raises(self):
|
||||
"""TEST: ENC-EDGE-06 -- Corrupted ciphertext raises ValueError."""
|
||||
import base64
|
||||
bad = base64.urlsafe_b64encode(b"not-valid-fernet-data").decode()
|
||||
with pytest.raises(ValueError):
|
||||
decrypt(bad, secret_key=SECRET_KEY)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_fernet_key — PBKDF2 derivation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestKeyDerivation:
|
||||
def test_same_salt_same_key(self):
|
||||
"""TEST: ENC-KD-01 -- Same salt produces the same derived key."""
|
||||
salt = b"\x00" * SALT_LENGTH
|
||||
key1 = _get_fernet_key(SECRET_KEY, salt=salt)
|
||||
key2 = _get_fernet_key(SECRET_KEY, salt=salt)
|
||||
assert key1 == key2
|
||||
|
||||
def test_different_salt_different_key(self):
|
||||
"""TEST: ENC-KD-02 -- Different salts produce different keys."""
|
||||
salt1 = b"\x00" * SALT_LENGTH
|
||||
salt2 = b"\xff" * SALT_LENGTH
|
||||
key1 = _get_fernet_key(SECRET_KEY, salt=salt1)
|
||||
key2 = _get_fernet_key(SECRET_KEY, salt=salt2)
|
||||
assert key1 != key2
|
||||
|
||||
def test_auto_salt_length(self):
|
||||
"""TEST: ENC-KD-03 -- Auto-generated salt is 16 bytes."""
|
||||
key = _get_fernet_key(SECRET_KEY)
|
||||
# If it didn't raise, the salt was valid
|
||||
assert len(key) > 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thread safety
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestThreadSafety:
|
||||
"""Concurrent encrypt/decrypt calls must not corrupt state."""
|
||||
|
||||
def test_concurrent_encrypt_decrypt(self):
|
||||
"""TEST: ENC-TS-01 -- 50 threads encrypting/decrypting concurrently."""
|
||||
errors = []
|
||||
results = []
|
||||
|
||||
def worker(i):
|
||||
try:
|
||||
data = f"token-{i}-secret"
|
||||
enc = encrypt(data, secret_key=SECRET_KEY)
|
||||
dec = decrypt(enc, secret_key=SECRET_KEY)
|
||||
results.append((i, dec))
|
||||
except Exception as exc:
|
||||
errors.append((i, exc))
|
||||
|
||||
threads = [threading.Thread(target=worker, args=(i,)) for i in range(50)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join(timeout=10)
|
||||
|
||||
assert not errors, f"Thread errors: {errors}"
|
||||
assert len(results) == 50
|
||||
for i, dec in results:
|
||||
assert dec == f"token-{i}-secret", f"Thread {i}: mismatch"
|
||||
Reference in New Issue
Block a user