Merge pull request #34 from CoryHawkless/cory-wip-session

fix(cors): handle wildcard origin with credentials and add unit tests
This commit is contained in:
2026-04-26 22:34:50 +08:00
committed by GitHub
18 changed files with 953 additions and 68 deletions
+29 -2
View File
@@ -116,7 +116,7 @@ def login():
remember_me = data.get("remember_me", False)
policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me)
duration = 2592000 if remember_me else 86400
duration = current_app.config.get("SESSION_ABSOLUTE_TIMEOUT", 28800) if remember_me else None
is_compliance_only = policy_result.create_compliance_only_session
user_session = AuthService.create_session(user, duration_seconds=duration, is_compliance_only=is_compliance_only)
@@ -227,6 +227,32 @@ def revoke_session(session_id):
return api_response(message="Session revoked successfully")
@api_v1_bp.route("/auth/sessions/refresh", methods=["POST"])
@login_required
def refresh_session():
"""Extend the current session's idle window.
The server already refreshes the session on every authenticated
request, but this endpoint exists so the frontend can proactively
keep a session alive (e.g. a heartbeat while the user is reading
a long page with no API calls).
Returns the new ``expires_at`` so the frontend can display a
countdown or warning before the absolute cap.
"""
session = g.current_session
session.refresh()
return api_response(
data={
"expires_at": session.expires_at.isoformat() + "Z"
if session.expires_at.isoformat()[-1] != "Z"
else session.expires_at.isoformat(),
},
message="Session refreshed",
)
@api_v1_bp.route("/auth/token", methods=["GET"])
@login_required
def get_token():
@@ -246,7 +272,8 @@ def get_token():
parsed_redirect = urlparse(redirect_url)
redirect_origin = f"{parsed_redirect.scheme}://{parsed_redirect.netloc}"
if redirect_origin not in allowed_origins:
wildcard = "*" in allowed_origins
if not wildcard and redirect_origin not in allowed_origins:
return api_response(success=False, message="Redirect URL is not allowed.", status=400, error_type="INVALID_REDIRECT")
sep = "&" if "?" in redirect_url else "?"
+2 -2
View File
@@ -190,8 +190,8 @@ def select_organization():
if not member:
return api_response(success=False, message="You are not a member of this organization", status=403, error_type="FORBIDDEN")
from gatehouse_app.services.session_service import SessionService
session = SessionService.create_session(user=user, organization_id=organization_id)
from gatehouse_app.services.auth_service import AuthService
session = AuthService.create_session(user=user)
state_record.mark_used()
provider_type_val = state_record.provider_type.value if isinstance(state_record.provider_type, _AuthMethodType) else state_record.provider_type
+59 -36
View File
@@ -1,6 +1,44 @@
"""CORS middleware configuration."""
from flask import request, make_response
ALLOWED_METHODS = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
ALLOWED_HEADERS = (
"Content-Type, Authorization, X-Requested-With, X-Request-ID, "
"Cache-Control, Pragma, X-WebAuthn-Session-Token"
)
def _is_origin_allowed(origin, cors_origins):
"""Return True if the origin is permitted by the CORS config.
Handles both wildcard ("*") and explicit origin lists.
"""
if not origin:
return False
if cors_origins == "*":
return True
if isinstance(cors_origins, list):
if "*" in cors_origins:
return True
return origin in cors_origins
return False
def _cors_origin_header(cors_origins, request_origin):
"""Return the value for Access-Control-Allow-Origin.
Per the CORS spec, browsers reject ``*`` when credentials are involved,
so we echo the request origin when wildcard + credentials is configured.
"""
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
if allow_all and request_origin:
return request_origin
if allow_all:
return "*"
if request_origin and request_origin in cors_origins:
return request_origin
return None
def setup_cors(app):
"""
@@ -9,6 +47,7 @@ def setup_cors(app):
Args:
app: Flask application instance
"""
supports_credentials = app.config.get("CORS_SUPPORTS_CREDENTIALS", True)
@app.before_request
def handle_preflight():
@@ -16,49 +55,33 @@ def setup_cors(app):
if request.method == "OPTIONS":
origin = request.headers.get("Origin")
cors_origins = app.config.get("CORS_ORIGINS", [])
# Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard)
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
if allow_all:
response = make_response("", 204)
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma"
response.headers["Access-Control-Max-Age"] = "3600"
response.headers["Cache-Control"] = "no-cache, no-store"
return response
elif origin and origin in cors_origins:
response = make_response("", 204)
response.headers["Access-Control-Allow-Origin"] = origin
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma, X-WebAuthn-Session-Token"
if not _is_origin_allowed(origin, cors_origins):
return None
response = make_response("", 204)
response.headers["Access-Control-Allow-Origin"] = _cors_origin_header(cors_origins, origin)
response.headers["Access-Control-Allow-Methods"] = ALLOWED_METHODS
response.headers["Access-Control-Allow-Headers"] = ALLOWED_HEADERS
if supports_credentials:
response.headers["Access-Control-Allow-Credentials"] = "true"
response.headers["Access-Control-Max-Age"] = "3600"
response.headers["Cache-Control"] = "no-cache, no-store"
return response
response.headers["Access-Control-Max-Age"] = "3600"
response.headers["Cache-Control"] = "no-cache, no-store"
return response
@app.after_request
def after_request_cors(response):
"""Add additional CORS headers if needed."""
"""Add CORS headers to non-preflight responses."""
origin = request.headers.get("Origin")
cors_origins = app.config.get("CORS_ORIGINS", [])
# Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard)
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
if allow_all:
# When allowing all origins, set header to "*"
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma"
response.headers["Access-Control-Max-Age"] = "3600"
elif origin and origin in cors_origins:
# When allowing specific origins, echo the request origin
response.headers["Access-Control-Allow-Origin"] = origin
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma, X-WebAuthn-Session-Token"
response.headers["Access-Control-Allow-Credentials"] = "true"
allow_origin = _cors_origin_header(cors_origins, origin)
if allow_origin:
response.headers["Access-Control-Allow-Origin"] = allow_origin
response.headers["Access-Control-Allow-Methods"] = ALLOWED_METHODS
response.headers["Access-Control-Allow-Headers"] = ALLOWED_HEADERS
if supports_credentials:
response.headers["Access-Control-Allow-Credentials"] = "true"
response.headers["Access-Control-Max-Age"] = "3600"
return response
+55 -16
View File
@@ -1,5 +1,6 @@
"""Session model."""
from datetime import datetime, timedelta, timezone
from flask import current_app
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import SessionStatus
@@ -38,33 +39,71 @@ class Session(BaseModel):
return f"<Session user_id={self.user_id} status={self.status}>"
def is_active(self):
"""Check if session is currently active."""
"""Check if session is currently active.
Sessions are evaluated against two independent timeouts:
- Idle timeout: expires if no request has been made within
SESSION_IDLE_TIMEOUT seconds (default 15 min).
- Absolute timeout: expires if SESSION_ABSOLUTE_TIMEOUT seconds
have elapsed since the session was created (default 8 h),
regardless of activity.
A session must satisfy *both* constraints to remain active.
"""
now = datetime.now(timezone.utc)
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
created_at = self.created_at
last_activity_at = self.last_activity_at
if created_at.tzinfo is None:
created_at = created_at.replace(tzinfo=timezone.utc)
if last_activity_at.tzinfo is None:
last_activity_at = last_activity_at.replace(tzinfo=timezone.utc)
idle_timeout = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
absolute_timeout = current_app.config.get("SESSION_ABSOLUTE_TIMEOUT", 28800)
idle_expires_at = last_activity_at + timedelta(seconds=idle_timeout)
absolute_expires_at = created_at + timedelta(seconds=absolute_timeout)
return (
self.status == SessionStatus.ACTIVE
and expires_at > now
and now < idle_expires_at
and now < absolute_expires_at
and self.deleted_at is None
)
def is_expired(self):
"""Check if session has expired."""
now = datetime.now(timezone.utc)
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return now > expires_at
"""Check if session has expired (either idle or absolute)."""
return not self.is_active() and self.status != SessionStatus.REVOKED
def refresh(self, duration_seconds: int = 86400):
"""Refresh session expiration.
def refresh(self, duration_seconds: int = None):
"""Extend the session expiration using a sliding window.
The new ``expires_at`` is set to *now + idle timeout*, but is
capped so that the session never exceeds the absolute lifetime
(``created_at + absolute timeout``).
Args:
duration_seconds: New session duration in seconds
duration_seconds: Override for the idle timeout. When *None*
(the common case), the value is read from
``SESSION_IDLE_TIMEOUT`` in the Flask config.
"""
self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=duration_seconds)
self.last_activity_at = datetime.now(timezone.utc)
now = datetime.now(timezone.utc)
if duration_seconds is None:
duration_seconds = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
absolute_timeout = current_app.config.get("SESSION_ABSOLUTE_TIMEOUT", 28800)
idle_expires_at = now + timedelta(seconds=duration_seconds)
created_at = self.created_at
if created_at.tzinfo is None:
created_at = created_at.replace(tzinfo=timezone.utc)
absolute_expires_at = created_at + timedelta(seconds=absolute_timeout)
self.expires_at = min(idle_expires_at, absolute_expires_at)
self.last_activity_at = now
db.session.commit()
def revoke(self, reason: str = None):
+10 -2
View File
@@ -140,18 +140,26 @@ class AuthService:
return user
@staticmethod
def create_session(user, duration_seconds=86400, is_compliance_only=False):
def create_session(user, duration_seconds=None, is_compliance_only=False):
"""
Create a new session for the user.
Args:
user: User instance
duration_seconds: Session duration in seconds
duration_seconds: Session idle timeout in seconds.
When None, defaults to SESSION_IDLE_TIMEOUT from config.
The absolute lifetime is always enforced by Session.is_active()
regardless of this value.
is_compliance_only: Whether this is a compliance-only session (limited access)
Returns:
Session instance
"""
from flask import current_app
if duration_seconds is None:
duration_seconds = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
# Generate session token
token = secrets.token_urlsafe(32)
@@ -263,7 +263,7 @@ def authenticate_with_provider(
state_record.mark_used()
from gatehouse_app.services.auth_service import AuthService
session = AuthService.create_session(user=user, organization_id=organization_id)
session = AuthService.create_session(user=user)
AuditService.log_external_auth_login(
user_id=user.id,
+4 -2
View File
@@ -10,10 +10,10 @@ class SessionService:
@staticmethod
def get_active_session_by_token(token):
"""Get active session by token.
Args:
token: The session token string
Returns:
Session object if found and active, None otherwise
"""
@@ -23,6 +23,8 @@ class SessionService:
token=token,
status=SessionStatus.ACTIVE,
deleted_at=None
).filter(
Session.expires_at > datetime.now(timezone.utc)
).first()
@staticmethod
@@ -138,7 +138,7 @@ class SuperadminAuthService:
Dictionary with emergency session info
"""
from gatehouse_app.models.user.user import User
from gatehouse_app.services.session_service import SessionService
from gatehouse_app.services.auth_service import AuthService
from gatehouse_app.services.audit_service import AuditService
# Verify target user exists
@@ -147,7 +147,7 @@ class SuperadminAuthService:
raise ValueError(f"Target user not found: {target_user_id}")
# Create emergency session for the target user
emergency_session = SessionService.create_session(
emergency_session = AuthService.create_session(
user=target_user,
duration_seconds=duration_minutes * 60,
is_compliance_only=False
+3 -5
View File
@@ -59,11 +59,9 @@ def login_required(f):
error_type="SESSION_INACTIVE"
)
# Update last_activity_at timestamp
from datetime import datetime, timezone
session.last_activity_at = datetime.now(timezone.utc)
from gatehouse_app import db
db.session.commit()
# Extend session via sliding window (updates last_activity_at
# and recalculates expires_at within the idle / absolute caps).
session.refresh()
# Set context variables
g.current_user = session.user