Merge pull request #34 from CoryHawkless/cory-wip-session
fix(cors): handle wildcard origin with credentials and add unit tests
This commit is contained in:
@@ -116,7 +116,7 @@ def login():
|
||||
|
||||
remember_me = data.get("remember_me", False)
|
||||
policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me)
|
||||
duration = 2592000 if remember_me else 86400
|
||||
duration = current_app.config.get("SESSION_ABSOLUTE_TIMEOUT", 28800) if remember_me else None
|
||||
is_compliance_only = policy_result.create_compliance_only_session
|
||||
|
||||
user_session = AuthService.create_session(user, duration_seconds=duration, is_compliance_only=is_compliance_only)
|
||||
@@ -227,6 +227,32 @@ def revoke_session(session_id):
|
||||
return api_response(message="Session revoked successfully")
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/sessions/refresh", methods=["POST"])
|
||||
@login_required
|
||||
def refresh_session():
|
||||
"""Extend the current session's idle window.
|
||||
|
||||
The server already refreshes the session on every authenticated
|
||||
request, but this endpoint exists so the frontend can proactively
|
||||
keep a session alive (e.g. a heartbeat while the user is reading
|
||||
a long page with no API calls).
|
||||
|
||||
Returns the new ``expires_at`` so the frontend can display a
|
||||
countdown or warning before the absolute cap.
|
||||
"""
|
||||
session = g.current_session
|
||||
session.refresh()
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"expires_at": session.expires_at.isoformat() + "Z"
|
||||
if session.expires_at.isoformat()[-1] != "Z"
|
||||
else session.expires_at.isoformat(),
|
||||
},
|
||||
message="Session refreshed",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/token", methods=["GET"])
|
||||
@login_required
|
||||
def get_token():
|
||||
@@ -246,7 +272,8 @@ def get_token():
|
||||
parsed_redirect = urlparse(redirect_url)
|
||||
redirect_origin = f"{parsed_redirect.scheme}://{parsed_redirect.netloc}"
|
||||
|
||||
if redirect_origin not in allowed_origins:
|
||||
wildcard = "*" in allowed_origins
|
||||
if not wildcard and redirect_origin not in allowed_origins:
|
||||
return api_response(success=False, message="Redirect URL is not allowed.", status=400, error_type="INVALID_REDIRECT")
|
||||
|
||||
sep = "&" if "?" in redirect_url else "?"
|
||||
|
||||
@@ -190,8 +190,8 @@ def select_organization():
|
||||
if not member:
|
||||
return api_response(success=False, message="You are not a member of this organization", status=403, error_type="FORBIDDEN")
|
||||
|
||||
from gatehouse_app.services.session_service import SessionService
|
||||
session = SessionService.create_session(user=user, organization_id=organization_id)
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
session = AuthService.create_session(user=user)
|
||||
state_record.mark_used()
|
||||
|
||||
provider_type_val = state_record.provider_type.value if isinstance(state_record.provider_type, _AuthMethodType) else state_record.provider_type
|
||||
|
||||
@@ -1,6 +1,44 @@
|
||||
"""CORS middleware configuration."""
|
||||
from flask import request, make_response
|
||||
|
||||
ALLOWED_METHODS = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
ALLOWED_HEADERS = (
|
||||
"Content-Type, Authorization, X-Requested-With, X-Request-ID, "
|
||||
"Cache-Control, Pragma, X-WebAuthn-Session-Token"
|
||||
)
|
||||
|
||||
|
||||
def _is_origin_allowed(origin, cors_origins):
|
||||
"""Return True if the origin is permitted by the CORS config.
|
||||
|
||||
Handles both wildcard ("*") and explicit origin lists.
|
||||
"""
|
||||
if not origin:
|
||||
return False
|
||||
if cors_origins == "*":
|
||||
return True
|
||||
if isinstance(cors_origins, list):
|
||||
if "*" in cors_origins:
|
||||
return True
|
||||
return origin in cors_origins
|
||||
return False
|
||||
|
||||
|
||||
def _cors_origin_header(cors_origins, request_origin):
|
||||
"""Return the value for Access-Control-Allow-Origin.
|
||||
|
||||
Per the CORS spec, browsers reject ``*`` when credentials are involved,
|
||||
so we echo the request origin when wildcard + credentials is configured.
|
||||
"""
|
||||
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
|
||||
if allow_all and request_origin:
|
||||
return request_origin
|
||||
if allow_all:
|
||||
return "*"
|
||||
if request_origin and request_origin in cors_origins:
|
||||
return request_origin
|
||||
return None
|
||||
|
||||
|
||||
def setup_cors(app):
|
||||
"""
|
||||
@@ -9,6 +47,7 @@ def setup_cors(app):
|
||||
Args:
|
||||
app: Flask application instance
|
||||
"""
|
||||
supports_credentials = app.config.get("CORS_SUPPORTS_CREDENTIALS", True)
|
||||
|
||||
@app.before_request
|
||||
def handle_preflight():
|
||||
@@ -16,49 +55,33 @@ def setup_cors(app):
|
||||
if request.method == "OPTIONS":
|
||||
origin = request.headers.get("Origin")
|
||||
cors_origins = app.config.get("CORS_ORIGINS", [])
|
||||
|
||||
# Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard)
|
||||
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
|
||||
|
||||
if allow_all:
|
||||
response = make_response("", 204)
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma"
|
||||
response.headers["Access-Control-Max-Age"] = "3600"
|
||||
response.headers["Cache-Control"] = "no-cache, no-store"
|
||||
return response
|
||||
elif origin and origin in cors_origins:
|
||||
response = make_response("", 204)
|
||||
response.headers["Access-Control-Allow-Origin"] = origin
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma, X-WebAuthn-Session-Token"
|
||||
|
||||
if not _is_origin_allowed(origin, cors_origins):
|
||||
return None
|
||||
|
||||
response = make_response("", 204)
|
||||
response.headers["Access-Control-Allow-Origin"] = _cors_origin_header(cors_origins, origin)
|
||||
response.headers["Access-Control-Allow-Methods"] = ALLOWED_METHODS
|
||||
response.headers["Access-Control-Allow-Headers"] = ALLOWED_HEADERS
|
||||
if supports_credentials:
|
||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
response.headers["Access-Control-Max-Age"] = "3600"
|
||||
response.headers["Cache-Control"] = "no-cache, no-store"
|
||||
return response
|
||||
response.headers["Access-Control-Max-Age"] = "3600"
|
||||
response.headers["Cache-Control"] = "no-cache, no-store"
|
||||
return response
|
||||
|
||||
@app.after_request
|
||||
def after_request_cors(response):
|
||||
"""Add additional CORS headers if needed."""
|
||||
"""Add CORS headers to non-preflight responses."""
|
||||
origin = request.headers.get("Origin")
|
||||
cors_origins = app.config.get("CORS_ORIGINS", [])
|
||||
|
||||
# Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard)
|
||||
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
|
||||
|
||||
if allow_all:
|
||||
# When allowing all origins, set header to "*"
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma"
|
||||
response.headers["Access-Control-Max-Age"] = "3600"
|
||||
elif origin and origin in cors_origins:
|
||||
# When allowing specific origins, echo the request origin
|
||||
response.headers["Access-Control-Allow-Origin"] = origin
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma, X-WebAuthn-Session-Token"
|
||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
allow_origin = _cors_origin_header(cors_origins, origin)
|
||||
if allow_origin:
|
||||
response.headers["Access-Control-Allow-Origin"] = allow_origin
|
||||
response.headers["Access-Control-Allow-Methods"] = ALLOWED_METHODS
|
||||
response.headers["Access-Control-Allow-Headers"] = ALLOWED_HEADERS
|
||||
if supports_credentials:
|
||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
response.headers["Access-Control-Max-Age"] = "3600"
|
||||
|
||||
return response
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Session model."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from flask import current_app
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import SessionStatus
|
||||
@@ -38,33 +39,71 @@ class Session(BaseModel):
|
||||
return f"<Session user_id={self.user_id} status={self.status}>"
|
||||
|
||||
def is_active(self):
|
||||
"""Check if session is currently active."""
|
||||
"""Check if session is currently active.
|
||||
|
||||
Sessions are evaluated against two independent timeouts:
|
||||
- Idle timeout: expires if no request has been made within
|
||||
SESSION_IDLE_TIMEOUT seconds (default 15 min).
|
||||
- Absolute timeout: expires if SESSION_ABSOLUTE_TIMEOUT seconds
|
||||
have elapsed since the session was created (default 8 h),
|
||||
regardless of activity.
|
||||
|
||||
A session must satisfy *both* constraints to remain active.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
created_at = self.created_at
|
||||
last_activity_at = self.last_activity_at
|
||||
|
||||
if created_at.tzinfo is None:
|
||||
created_at = created_at.replace(tzinfo=timezone.utc)
|
||||
if last_activity_at.tzinfo is None:
|
||||
last_activity_at = last_activity_at.replace(tzinfo=timezone.utc)
|
||||
|
||||
idle_timeout = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
|
||||
absolute_timeout = current_app.config.get("SESSION_ABSOLUTE_TIMEOUT", 28800)
|
||||
|
||||
idle_expires_at = last_activity_at + timedelta(seconds=idle_timeout)
|
||||
absolute_expires_at = created_at + timedelta(seconds=absolute_timeout)
|
||||
|
||||
return (
|
||||
self.status == SessionStatus.ACTIVE
|
||||
and expires_at > now
|
||||
and now < idle_expires_at
|
||||
and now < absolute_expires_at
|
||||
and self.deleted_at is None
|
||||
)
|
||||
|
||||
def is_expired(self):
|
||||
"""Check if session has expired."""
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return now > expires_at
|
||||
"""Check if session has expired (either idle or absolute)."""
|
||||
return not self.is_active() and self.status != SessionStatus.REVOKED
|
||||
|
||||
def refresh(self, duration_seconds: int = 86400):
|
||||
"""Refresh session expiration.
|
||||
def refresh(self, duration_seconds: int = None):
|
||||
"""Extend the session expiration using a sliding window.
|
||||
|
||||
The new ``expires_at`` is set to *now + idle timeout*, but is
|
||||
capped so that the session never exceeds the absolute lifetime
|
||||
(``created_at + absolute timeout``).
|
||||
|
||||
Args:
|
||||
duration_seconds: New session duration in seconds
|
||||
duration_seconds: Override for the idle timeout. When *None*
|
||||
(the common case), the value is read from
|
||||
``SESSION_IDLE_TIMEOUT`` in the Flask config.
|
||||
"""
|
||||
self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=duration_seconds)
|
||||
self.last_activity_at = datetime.now(timezone.utc)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
if duration_seconds is None:
|
||||
duration_seconds = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
|
||||
|
||||
absolute_timeout = current_app.config.get("SESSION_ABSOLUTE_TIMEOUT", 28800)
|
||||
|
||||
idle_expires_at = now + timedelta(seconds=duration_seconds)
|
||||
|
||||
created_at = self.created_at
|
||||
if created_at.tzinfo is None:
|
||||
created_at = created_at.replace(tzinfo=timezone.utc)
|
||||
absolute_expires_at = created_at + timedelta(seconds=absolute_timeout)
|
||||
|
||||
self.expires_at = min(idle_expires_at, absolute_expires_at)
|
||||
self.last_activity_at = now
|
||||
db.session.commit()
|
||||
|
||||
def revoke(self, reason: str = None):
|
||||
|
||||
@@ -140,18 +140,26 @@ class AuthService:
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def create_session(user, duration_seconds=86400, is_compliance_only=False):
|
||||
def create_session(user, duration_seconds=None, is_compliance_only=False):
|
||||
"""
|
||||
Create a new session for the user.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
duration_seconds: Session duration in seconds
|
||||
duration_seconds: Session idle timeout in seconds.
|
||||
When None, defaults to SESSION_IDLE_TIMEOUT from config.
|
||||
The absolute lifetime is always enforced by Session.is_active()
|
||||
regardless of this value.
|
||||
is_compliance_only: Whether this is a compliance-only session (limited access)
|
||||
|
||||
Returns:
|
||||
Session instance
|
||||
"""
|
||||
from flask import current_app
|
||||
|
||||
if duration_seconds is None:
|
||||
duration_seconds = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
|
||||
|
||||
# Generate session token
|
||||
token = secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
@@ -263,7 +263,7 @@ def authenticate_with_provider(
|
||||
state_record.mark_used()
|
||||
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
session = AuthService.create_session(user=user, organization_id=organization_id)
|
||||
session = AuthService.create_session(user=user)
|
||||
|
||||
AuditService.log_external_auth_login(
|
||||
user_id=user.id,
|
||||
|
||||
@@ -10,10 +10,10 @@ class SessionService:
|
||||
@staticmethod
|
||||
def get_active_session_by_token(token):
|
||||
"""Get active session by token.
|
||||
|
||||
|
||||
Args:
|
||||
token: The session token string
|
||||
|
||||
|
||||
Returns:
|
||||
Session object if found and active, None otherwise
|
||||
"""
|
||||
@@ -23,6 +23,8 @@ class SessionService:
|
||||
token=token,
|
||||
status=SessionStatus.ACTIVE,
|
||||
deleted_at=None
|
||||
).filter(
|
||||
Session.expires_at > datetime.now(timezone.utc)
|
||||
).first()
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -138,7 +138,7 @@ class SuperadminAuthService:
|
||||
Dictionary with emergency session info
|
||||
"""
|
||||
from gatehouse_app.models.user.user import User
|
||||
from gatehouse_app.services.session_service import SessionService
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
|
||||
# Verify target user exists
|
||||
@@ -147,7 +147,7 @@ class SuperadminAuthService:
|
||||
raise ValueError(f"Target user not found: {target_user_id}")
|
||||
|
||||
# Create emergency session for the target user
|
||||
emergency_session = SessionService.create_session(
|
||||
emergency_session = AuthService.create_session(
|
||||
user=target_user,
|
||||
duration_seconds=duration_minutes * 60,
|
||||
is_compliance_only=False
|
||||
|
||||
@@ -59,11 +59,9 @@ def login_required(f):
|
||||
error_type="SESSION_INACTIVE"
|
||||
)
|
||||
|
||||
# Update last_activity_at timestamp
|
||||
from datetime import datetime, timezone
|
||||
session.last_activity_at = datetime.now(timezone.utc)
|
||||
from gatehouse_app import db
|
||||
db.session.commit()
|
||||
# Extend session via sliding window (updates last_activity_at
|
||||
# and recalculates expires_at within the idle / absolute caps).
|
||||
session.refresh()
|
||||
|
||||
# Set context variables
|
||||
g.current_user = session.user
|
||||
|
||||
Reference in New Issue
Block a user