290 lines
9.1 KiB
Python
290 lines
9.1 KiB
Python
"""OIDC Session Service for session management during OIDC flow."""
|
|
import secrets
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, Optional, Tuple
|
|
|
|
from datetime import timezone
|
|
from flask import current_app, g
|
|
|
|
from gatehouse_app.extensions import db
|
|
from gatehouse_app.models import OIDCSession, OIDCClient, User
|
|
from gatehouse_app.exceptions.validation_exceptions import NotFoundError, ValidationError
|
|
|
|
|
|
class OIDCSessionService:
    """Service for managing OIDC authentication sessions.

    This service handles:
    - Creating OIDC sessions during authorization flow
    - Validating sessions with state and nonce
    - Managing PKCE code challenges
    - Cleaning up expired sessions

    Model types in annotations are written as string forward references so
    they are not evaluated when the class body executes.
    """

    @staticmethod
    def _generate_state() -> str:
        """Generate a secure state parameter.

        Returns:
            URL-safe base64 encoded state (from 32 random bytes)
        """
        return secrets.token_urlsafe(32)

    @staticmethod
    def _generate_nonce() -> str:
        """Generate a secure nonce for OIDC.

        Returns:
            URL-safe base64 encoded nonce (from 32 random bytes)
        """
        return secrets.token_urlsafe(32)

    @staticmethod
    def _generate_code_challenge(verifier: str, method: str = "S256") -> str:
        """Generate a PKCE code challenge from verifier.

        Args:
            verifier: The code verifier
            method: Challenge method ("S256" or "plain")

        Returns:
            Code challenge string. For "S256" this is the unpadded base64url
            encoding of SHA-256(verifier); for "plain" it is the verifier itself.

        Raises:
            ValueError: If ``method`` is not "S256" or "plain"
        """
        import hashlib
        import base64

        if method == "S256":
            digest = hashlib.sha256(verifier.encode()).digest()
            # PKCE uses base64url *without* padding.
            return base64.urlsafe_b64encode(digest).decode().rstrip("=")
        elif method == "plain":
            return verifier
        else:
            raise ValueError(f"Unsupported code challenge method: {method}")

    @classmethod
    def validate_code_verifier(cls, code_verifier: str, code_challenge: str,
                               method: str = "S256") -> bool:
        """Validate a PKCE code verifier against the stored challenge.

        Fails closed: any malformed input, unsupported method, or mismatch
        yields ``False`` rather than an exception.

        Args:
            code_verifier: The code verifier from the token request
            code_challenge: The code challenge from the authorization request
            method: The challenge method used ("S256" or "plain")

        Returns:
            True if validation succeeds, False otherwise
        """
        import re

        if not code_verifier or not code_challenge:
            return False
        if not isinstance(code_verifier, str) or not isinstance(code_challenge, str):
            return False

        # RFC 7636 section 4.1: the verifier is 43-128 characters drawn from
        # the unreserved set, regardless of the challenge method.
        if not re.fullmatch(r"[A-Za-z0-9\-._~]{43,128}", code_verifier):
            return False

        try:
            expected_challenge = cls._generate_code_challenge(code_verifier, method)
        except ValueError:
            # Unsupported (or missing) method: reject instead of crashing.
            return False

        # Compare as bytes: compare_digest raises TypeError on non-ASCII str,
        # and code_challenge originates from untrusted request data.
        return secrets.compare_digest(
            expected_challenge.encode("utf-8"),
            code_challenge.encode("utf-8"),
        )

    @classmethod
    def create_session(
        cls,
        user_id: str,
        client_id: str,
        state: Optional[str] = None,
        nonce: Optional[str] = None,
        redirect_uri: Optional[str] = None,
        scope: Optional[list] = None,
        code_challenge: Optional[str] = None,
        code_challenge_method: Optional[str] = None,
        lifetime_seconds: int = 600
    ) -> "OIDCSession":
        """Create a new OIDC session for the authorization flow.

        Args:
            user_id: The user ID
            client_id: The OIDC client ID
            state: State parameter (generated if not provided or empty)
            nonce: Nonce for ID token validation (generated if not provided or empty)
            redirect_uri: Redirect URI from authorization request
            scope: Requested scopes
            code_challenge: PKCE code challenge
            code_challenge_method: PKCE method ("S256" or "plain")
            lifetime_seconds: Session lifetime in seconds

        Returns:
            OIDCSession instance created by ``OIDCSession.create_session``
        """
        # Generate state and nonce if not provided
        state = state or cls._generate_state()
        nonce = nonce or cls._generate_nonce()

        return OIDCSession.create_session(
            user_id=user_id,
            client_id=client_id,
            state=state,
            nonce=nonce,
            redirect_uri=redirect_uri,
            scope=scope,
            code_challenge=code_challenge,
            code_challenge_method=code_challenge_method,
            lifetime_seconds=lifetime_seconds,
        )

    @classmethod
    def validate_session(cls, state: str, nonce: Optional[str] = None) -> Tuple["OIDCSession", "User"]:
        """Validate an OIDC session by state and optionally nonce.

        Args:
            state: The state parameter
            nonce: The nonce to validate (optional; skipped when falsy)

        Returns:
            Tuple of (OIDCSession, User)

        Raises:
            ValidationError: If the session has expired or the nonce is invalid
            NotFoundError: If the session or its user is not found
        """
        session = OIDCSession.get_by_state(state)

        if not session:
            raise NotFoundError("OIDC session not found or expired")

        if session.is_expired():
            raise ValidationError("OIDC session has expired")

        # Validate nonce if provided
        if nonce and not session.validate_nonce(nonce):
            raise ValidationError("Invalid nonce")

        user = User.query.get(session.user_id)
        if not user:
            raise NotFoundError("User not found")

        return session, user

    @classmethod
    def validate_pkce(cls, session: "OIDCSession", code_verifier: str) -> bool:
        """Validate PKCE code verifier against the session's code challenge.

        Args:
            session: OIDCSession instance
            code_verifier: The code verifier from token request

        Returns:
            True if validation succeeds (or if the session used no PKCE)

        Raises:
            ValidationError: If PKCE validation fails
        """
        if not session.code_challenge:
            # No PKCE was used, skip validation
            return True

        if not code_verifier:
            raise ValidationError("code_verifier is required")

        if not session.validate_code_challenge(code_verifier):
            raise ValidationError("Invalid code_verifier")

        return True

    @classmethod
    def mark_session_authenticated(cls, session: "OIDCSession") -> "OIDCSession":
        """Mark a session as authenticated (user has logged in).

        Args:
            session: OIDCSession instance

        Returns:
            The same OIDCSession instance, after ``mark_authenticated()``
        """
        session.mark_authenticated()
        return session

    @classmethod
    def cleanup_expired_sessions(cls, older_than_hours: int = 24) -> int:
        """Remove expired OIDC sessions that are past a grace period.

        Only sessions whose ``deleted_at`` is NULL are considered. Each
        matching session's ``delete()`` is called.
        NOTE(review): whether ``delete()`` is a hard or soft delete is defined
        by the model, not here; the ``deleted_at`` filter suggests soft-delete
        support — confirm.

        Args:
            older_than_hours: Only delete sessions expired more than this many hours ago

        Returns:
            Number of sessions deleted
        """
        now = datetime.now(timezone.utc)
        cutoff = now - timedelta(hours=older_than_hours)
        # A session must be both expired and past the cutoff; min() preserves
        # that when older_than_hours is negative. Filtering in the database
        # also avoids comparing DB datetimes against an aware value in Python.
        threshold = min(now, cutoff)

        stale_sessions = OIDCSession.query.filter(
            OIDCSession.expires_at < threshold,
            OIDCSession.deleted_at.is_(None),
        ).all()

        for stale in stale_sessions:
            stale.delete()

        return len(stale_sessions)

    @classmethod
    def get_session_by_state(cls, state: str) -> Optional["OIDCSession"]:
        """Get an OIDC session by state.

        Args:
            state: The state parameter

        Returns:
            OIDCSession instance or None
        """
        return OIDCSession.get_by_state(state)

    @classmethod
    def validate_redirect_uri(cls, client_id: str, redirect_uri: str) -> bool:
        """Validate that a redirect URI is allowed for a client.

        Args:
            client_id: The OIDC client ID
            redirect_uri: The redirect URI to validate

        Returns:
            True if redirect URI is allowed; False for an empty URI or unknown client
        """
        if not redirect_uri:
            return False

        client = OIDCClient.query.filter_by(client_id=client_id).first()
        if not client:
            return False

        return client.is_redirect_uri_allowed(redirect_uri)

    @classmethod
    def validate_scopes(cls, client_id: str, requested_scopes: list) -> list:
        """Validate and filter scopes against client's allowed scopes.

        Args:
            client_id: The OIDC client ID
            requested_scopes: Requested scopes, either as a list or as a
                space-delimited string (the OAuth 2.0 ``scope`` wire format).
                None is treated as no scopes.

        Returns:
            The requested scopes the client is allowed, in request order
        """
        if not requested_scopes:
            return []
        if isinstance(requested_scopes, str):
            # Iterating a str would yield single characters; split instead.
            requested_scopes = requested_scopes.split()

        client = OIDCClient.query.filter_by(client_id=client_id).first()
        if not client:
            return []

        allowed_scopes = set(client.scopes or [])

        # Filter to only allowed scopes, preserving request order
        return [s for s in requested_scopes if s in allowed_scopes]