Chore: Refractor Models into organized file/folder

This commit is contained in:
2026-03-01 12:40:48 +05:45
parent 58432da1c8
commit 07193a2d2e
35 changed files with 1475 additions and 932 deletions
+18
View File
@@ -0,0 +1,18 @@
"""OIDC subpackage — clients, tokens, sessions, and audit logs."""
from gatehouse_app.models.oidc.oidc_client import OIDCClient
from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode
from gatehouse_app.models.oidc.oidc_refresh_token import OIDCRefreshToken
from gatehouse_app.models.oidc.oidc_session import OIDCSession
from gatehouse_app.models.oidc.oidc_token_metadata import OIDCTokenMetadata
from gatehouse_app.models.oidc.oidc_audit_log import OIDCAuditLog
from gatehouse_app.models.oidc.oidc_jwks_key import OidcJwksKey
__all__ = [
"OIDCClient",
"OIDCAuthCode",
"OIDCRefreshToken",
"OIDCSession",
"OIDCTokenMetadata",
"OIDCAuditLog",
"OidcJwksKey",
]
+264
View File
@@ -0,0 +1,264 @@
"""OIDC Audit Log model for comprehensive OIDC event tracking."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OIDCAuditLog(BaseModel):
"""OIDC Audit Log model for comprehensive OIDC event tracking.
Logs all OIDC-related events for security, compliance, and debugging.
"""
__tablename__ = "oidc_audit_logs"
# Event type categorization
event_type = db.Column(db.String(100), nullable=False, index=True)
# Client and User references
client_id = db.Column(
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=True, index=True
)
user_id = db.Column(
db.String(36), db.ForeignKey("users.id"), nullable=True, index=True
)
# Event outcome
success = db.Column(db.Boolean, default=True, nullable=False, index=True)
# Error details (for failed events)
error_code = db.Column(db.String(100), nullable=True)
error_description = db.Column(db.Text, nullable=True)
# Request context
ip_address = db.Column(db.String(45), nullable=True, index=True)
user_agent = db.Column(db.Text, nullable=True)
request_id = db.Column(db.String(36), nullable=True, index=True)
# Additional event metadata
event_metadata = db.Column(db.JSON, nullable=True)
# Relationships
client = db.relationship("OIDCClient", back_populates="audit_logs")
user = db.relationship("User", back_populates="oidc_audit_logs")
def __repr__(self):
"""String representation of OIDCAuditLog."""
status = "success" if self.success else "failed"
return (
f"<OIDCAuditLog event={self.event_type} "
f"status={status} client={self.client_id}>"
)
@classmethod
def log_event(
cls,
event_type: str,
client_id: str = None,
user_id: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None,
ip_address: str = None,
user_agent: str = None,
request_id: str = None,
event_metadata: dict = None,
) -> "OIDCAuditLog":
"""Log an OIDC event.
Args:
event_type: Type of event (e.g., "authorization_request")
client_id: The OIDC client ID
user_id: The user ID
success: Whether the event was successful
error_code: Error code if event failed
error_description: Error description if event failed
ip_address: Client IP address
user_agent: Client user agent
request_id: Request ID for correlation
event_metadata: Additional event metadata
Returns:
OIDCAuditLog instance
"""
log = cls(
event_type=event_type,
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
event_metadata=event_metadata,
)
db.session.add(log)
db.session.commit()
return log
@classmethod
def log_authorization_request(
cls,
client_id: str,
user_id: str,
redirect_uri: str,
scope,
ip_address: str = None,
user_agent: str = None,
request_id: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None,
) -> "OIDCAuditLog":
"""Log an authorization request event."""
return cls.log_event(
event_type="authorization_request",
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
event_metadata={"redirect_uri": redirect_uri, "scope": scope},
)
@classmethod
def log_token_issue(
cls,
client_id: str,
user_id: str,
token_type: str,
ip_address: str = None,
user_agent: str = None,
request_id: str = None,
) -> "OIDCAuditLog":
"""Log a token issuance event."""
return cls.log_event(
event_type="token_issue",
client_id=client_id,
user_id=user_id,
success=True,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
event_metadata={"token_type": token_type},
)
@classmethod
def log_token_revocation(
cls,
client_id: str,
user_id: str,
token_type: str,
reason: str = None,
ip_address: str = None,
user_agent: str = None,
request_id: str = None,
) -> "OIDCAuditLog":
"""Log a token revocation event."""
return cls.log_event(
event_type="token_revocation",
client_id=client_id,
user_id=user_id,
success=True,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
event_metadata={"token_type": token_type, "reason": reason},
)
@classmethod
def log_authentication_failure(
cls,
client_id: str,
error_code: str,
error_description: str,
ip_address: str = None,
user_agent: str = None,
request_id: str = None,
) -> "OIDCAuditLog":
"""Log an authentication failure event."""
return cls.log_event(
event_type="authentication_failure",
client_id=client_id,
success=False,
error_code=error_code,
error_description=error_description,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
)
@classmethod
def get_events_for_user(cls, user_id: str, limit: int = 100) -> list:
"""Get audit events for a user.
Args:
user_id: The user ID
limit: Maximum number of events to return
Returns:
List of OIDCAuditLog instances
"""
return (
cls.query.filter_by(user_id=user_id, deleted_at=None)
.order_by(cls.created_at.desc())
.limit(limit)
.all()
)
@classmethod
def get_events_for_client(cls, client_id: str, limit: int = 100) -> list:
"""Get audit events for a client.
Args:
client_id: The client ID
limit: Maximum number of events to return
Returns:
List of OIDCAuditLog instances
"""
return (
cls.query.filter_by(client_id=client_id, deleted_at=None)
.order_by(cls.created_at.desc())
.limit(limit)
.all()
)
@classmethod
def get_failed_events(
cls,
client_id: str = None,
user_id: str = None,
start_date=None,
end_date=None,
limit: int = 100,
) -> list:
"""Get failed audit events.
Args:
client_id: Optional client ID filter
user_id: Optional user ID filter
start_date: Optional start date filter
end_date: Optional end date filter
limit: Maximum number of events to return
Returns:
List of OIDCAuditLog instances
"""
query = cls.query.filter_by(success=False, deleted_at=None)
if client_id:
query = query.filter_by(client_id=client_id)
if user_id:
query = query.filter_by(user_id=user_id)
if start_date:
query = query.filter(cls.created_at >= start_date)
if end_date:
query = query.filter(cls.created_at <= end_date)
return query.order_by(cls.created_at.desc()).limit(limit).all()
def to_dict(self, exclude=None):
"""Convert to dictionary."""
return super().to_dict(exclude=exclude)
@@ -0,0 +1,123 @@
"""OIDC Authorization Code model for the authorization code grant flow."""
from datetime import datetime, timedelta, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OIDCAuthCode(BaseModel):
"""OIDC Authorization Code model for the authorization code grant flow.
Authorization codes are single-use, short-lived codes. The code itself is
hashed before storage so that a database breach cannot replay codes.
"""
__tablename__ = "oidc_authorization_codes"
# Client and User references
client_id = db.Column(
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
)
user_id = db.Column(
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
)
# Authorization code (hashed for security)
code_hash = db.Column(db.String(255), nullable=False)
# Request parameters
redirect_uri = db.Column(db.String(512), nullable=False)
scope = db.Column(db.JSON, nullable=True)
nonce = db.Column(db.String(255), nullable=True)
code_verifier = db.Column(db.String(255), nullable=True)
# Status tracking
expires_at = db.Column(db.DateTime, nullable=False, index=True)
used_at = db.Column(db.DateTime, nullable=True)
is_used = db.Column(db.Boolean, default=False, nullable=False)
# Request metadata
ip_address = db.Column(db.String(45), nullable=True)
user_agent = db.Column(db.Text, nullable=True)
# Relationships — back_populates declared on User and OIDCClient
client = db.relationship("OIDCClient", back_populates="authorization_codes")
user = db.relationship("User", back_populates="oidc_auth_codes")
def __repr__(self):
"""String representation of OIDCAuthCode."""
return (
f"<OIDCAuthCode client_id={self.client_id} "
f"user_id={self.user_id} used={self.is_used}>"
)
def is_expired(self) -> bool:
"""Check if the authorization code has expired."""
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
def is_valid(self) -> bool:
"""Check if the authorization code is valid for use."""
return not self.is_used and not self.is_expired() and self.deleted_at is None
def mark_as_used(self) -> None:
"""Mark the authorization code as used."""
self.is_used = True
self.used_at = datetime.now(timezone.utc)
db.session.commit()
@classmethod
def create_code(
cls,
client_id: str,
user_id: str,
code_hash: str,
redirect_uri: str,
scope=None,
nonce: str = None,
code_verifier: str = None,
ip_address: str = None,
user_agent: str = None,
lifetime_seconds: int = 600,
) -> "OIDCAuthCode":
"""Create a new authorization code.
Args:
client_id: The OIDC client ID
user_id: The user ID
code_hash: Hashed authorization code
redirect_uri: The redirect URI
scope: Requested scopes
nonce: OIDC nonce
code_verifier: PKCE code verifier (stored hashed server-side)
ip_address: Client IP address
user_agent: Client user agent
lifetime_seconds: Code lifetime in seconds (default 10 minutes)
Returns:
OIDCAuthCode instance
"""
code = cls(
client_id=client_id,
user_id=user_id,
code_hash=code_hash,
redirect_uri=redirect_uri,
scope=scope,
nonce=nonce,
code_verifier=code_verifier,
expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
ip_address=ip_address,
user_agent=user_agent,
)
db.session.add(code)
db.session.commit()
return code
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
for field in ("code_hash", "code_verifier"):
if field not in exclude:
exclude.append(field)
return super().to_dict(exclude=exclude)
+86
View File
@@ -0,0 +1,86 @@
"""OIDC Client model."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import OIDCGrantType, OIDCResponseType
class OIDCClient(BaseModel):
"""OIDC client model for OAuth2/OIDC integrations."""
__tablename__ = "oidc_clients"
organization_id = db.Column(
db.String(36), db.ForeignKey("organizations.id"), nullable=False, index=True
)
name = db.Column(db.String(255), nullable=False)
client_id = db.Column(db.String(255), unique=True, nullable=False, index=True)
client_secret_hash = db.Column(db.String(255), nullable=False)
# OAuth/OIDC configuration
redirect_uris = db.Column(db.JSON, nullable=False) # Allowed redirect URIs
grant_types = db.Column(db.JSON, nullable=False) # Allowed grant types
response_types = db.Column(db.JSON, nullable=False) # Allowed response types
scopes = db.Column(db.JSON, nullable=False) # Allowed scopes
# Client metadata
logo_uri = db.Column(db.String(512), nullable=True)
client_uri = db.Column(db.String(512), nullable=True)
policy_uri = db.Column(db.String(512), nullable=True)
tos_uri = db.Column(db.String(512), nullable=True)
# Settings
is_active = db.Column(db.Boolean, default=True, nullable=False)
is_confidential = db.Column(db.Boolean, default=True, nullable=False)
require_pkce = db.Column(db.Boolean, default=True, nullable=False)
# Token lifetimes (in seconds)
access_token_lifetime = db.Column(db.Integer, default=3600, nullable=False)
refresh_token_lifetime = db.Column(db.Integer, default=2592000, nullable=False)
id_token_lifetime = db.Column(db.Integer, default=3600, nullable=False)
# Relationships
organization = db.relationship("Organization", back_populates="oidc_clients")
# OIDC sub-resource relationships (declared here, not monkey-patched elsewhere)
authorization_codes = db.relationship(
"OIDCAuthCode", back_populates="client", cascade="all, delete-orphan"
)
refresh_tokens = db.relationship(
"OIDCRefreshToken", back_populates="client", cascade="all, delete-orphan"
)
oidc_sessions = db.relationship(
"OIDCSession", back_populates="client", cascade="all, delete-orphan"
)
token_metadata = db.relationship(
"OIDCTokenMetadata", back_populates="client", cascade="all, delete-orphan"
)
audit_logs = db.relationship(
"OIDCAuditLog", back_populates="client", cascade="all, delete-orphan"
)
def __repr__(self):
"""String representation of OIDCClient."""
return f"<OIDCClient {self.name} client_id={self.client_id}>"
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
if "client_secret_hash" not in exclude:
exclude.append("client_secret_hash")
return super().to_dict(exclude=exclude)
def has_grant_type(self, grant_type) -> bool:
"""Check if client supports a specific grant type."""
return grant_type in self.grant_types
def has_response_type(self, response_type) -> bool:
"""Check if client supports a specific response type."""
return response_type in self.response_types
def is_redirect_uri_allowed(self, redirect_uri: str) -> bool:
"""Check if a redirect URI is allowed for this client."""
return redirect_uri in self.redirect_uris
def has_scope(self, scope: str) -> bool:
"""Check if client is allowed to request a specific scope."""
return scope in self.scopes
@@ -0,0 +1,76 @@
"""OIDC JWKS Key model for persisting signing keys."""
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OidcJwksKey(BaseModel):
"""OIDC JWKS Key model for persisting JSON Web Key Set signing keys.
Stores RSA/ECDSA key pairs used for signing OIDC tokens. Multiple keys can
be stored to support key rotation scenarios.
Attributes:
kid: Unique key ID used in JWT ``kid`` header
key_type: Type of key (e.g., "RSA", "EC")
private_key: PEM-encoded private key (never exposed in API responses)
public_key: PEM-encoded public key
algorithm: Signing algorithm (e.g., "RS256", "ES256")
is_active: Whether this key is currently used for signing/verification
is_primary: Whether this is the primary signing key
expires_at: Optional expiry for key rotation enforcement
"""
__tablename__ = "oidc_jwks_keys"
# Override the default UUID id with integer primary key for JWKS key sets
id = db.Column(db.Integer, primary_key=True)
expires_at = db.Column(db.DateTime, nullable=True)
# Key identification and type
kid = db.Column(db.String(255), unique=True, nullable=False, index=True)
key_type = db.Column(db.String(50), nullable=False) # e.g., "RSA", "EC"
algorithm = db.Column(db.String(50), nullable=False) # e.g., "RS256", "ES256"
# Key material (PEM-encoded) — private_key must never be returned by API
private_key = db.Column(db.Text, nullable=False)
public_key = db.Column(db.Text, nullable=False)
# Key status
is_active = db.Column(db.Boolean, default=True, nullable=False)
is_primary = db.Column(db.Boolean, default=False, nullable=False)
def __repr__(self):
"""String representation of OidcJwksKey."""
return (
f"<OidcJwksKey kid={self.kid} "
f"key_type={self.key_type} algorithm={self.algorithm}>"
)
def to_dict(self, exclude_private_key: bool = True):
"""Convert model to dictionary.
Args:
exclude_private_key: If True (default), excludes the private key.
Returns:
Dictionary representation of the model
"""
exclude = ["private_key"] if exclude_private_key else []
return super().to_dict(exclude=exclude)
@classmethod
def get_active_keys(cls) -> list:
"""Get all active keys for signing operations."""
return cls.query.filter_by(is_active=True).all()
@classmethod
def get_primary_key(cls) -> "OidcJwksKey | None":
"""Get the primary signing key."""
return cls.query.filter_by(is_primary=True).first()
@classmethod
def get_key_by_kid(cls, kid: str) -> "OidcJwksKey | None":
"""Get an active key by its key ID."""
return cls.query.filter_by(kid=kid, is_active=True).first()
@@ -0,0 +1,148 @@
"""OIDC Refresh Token model for token rotation."""
from datetime import datetime, timedelta, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OIDCRefreshToken(BaseModel):
"""OIDC Refresh Token model for token refresh and rotation.
Refresh tokens are long-lived credentials used to obtain new access tokens.
They support token rotation for enhanced security — each use invalidates
the old token and issues a new one.
"""
__tablename__ = "oidc_refresh_tokens"
# Client and User references
client_id = db.Column(
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
)
user_id = db.Column(
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
)
# Token (hashed for security — never store plaintext refresh tokens)
token_hash = db.Column(db.String(255), nullable=False, unique=True, index=True)
# Associated access token JTI (no FK — stored as string for lightweight lookup)
access_token_id = db.Column(db.String(255), nullable=True, index=True)
# Token scope
scope = db.Column(db.JSON, nullable=True)
# Timing
expires_at = db.Column(db.DateTime, nullable=False, index=True)
# Revocation tracking
revoked_at = db.Column(db.DateTime, nullable=True)
revoked_reason = db.Column(db.String(255), nullable=True)
# Token rotation metadata
previous_token_hash = db.Column(db.String(255), nullable=True)
rotation_count = db.Column(db.Integer, default=0, nullable=False)
# Request metadata
ip_address = db.Column(db.String(45), nullable=True)
user_agent = db.Column(db.Text, nullable=True)
# Relationships
client = db.relationship("OIDCClient", back_populates="refresh_tokens")
user = db.relationship("User", back_populates="oidc_refresh_tokens")
def __repr__(self):
"""String representation of OIDCRefreshToken."""
return (
f"<OIDCRefreshToken client_id={self.client_id} "
f"user_id={self.user_id} revoked={self.is_revoked()}>"
)
def is_expired(self) -> bool:
"""Check if the refresh token has expired."""
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
def is_revoked(self) -> bool:
"""Check if the refresh token has been revoked."""
return self.revoked_at is not None
def is_valid(self) -> bool:
"""Check if the refresh token is valid for use."""
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
def revoke(self, reason: str = None) -> None:
"""Revoke the refresh token.
Args:
reason: Optional reason for revocation
"""
self.revoked_at = datetime.now(timezone.utc)
self.revoked_reason = reason
db.session.commit()
def rotate(self, new_token_hash: str) -> "OIDCRefreshToken":
"""Rotate the refresh token — invalidate the old hash, store the new one.
Args:
new_token_hash: Hash of the new refresh token
Returns:
self for chaining
"""
self.previous_token_hash = self.token_hash
self.token_hash = new_token_hash
self.rotation_count += 1
self.expires_at = datetime.now(timezone.utc) + timedelta(days=30)
db.session.commit()
return self
@classmethod
def create_token(
cls,
client_id: str,
user_id: str,
token_hash: str,
scope=None,
access_token_id: str = None,
ip_address: str = None,
user_agent: str = None,
lifetime_seconds: int = 2592000,
) -> "OIDCRefreshToken":
"""Create a new refresh token.
Args:
client_id: The OIDC client ID
user_id: The user ID
token_hash: Hashed refresh token
scope: Granted scopes
access_token_id: Associated access token JTI
ip_address: Client IP address
user_agent: Client user agent
lifetime_seconds: Token lifetime in seconds (default 30 days)
Returns:
OIDCRefreshToken instance
"""
token = cls(
client_id=client_id,
user_id=user_id,
token_hash=token_hash,
scope=scope,
access_token_id=access_token_id,
expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
ip_address=ip_address,
user_agent=user_agent,
)
db.session.add(token)
db.session.commit()
return token
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
for field in ("token_hash", "previous_token_hash"):
if field not in exclude:
exclude.append(field)
return super().to_dict(exclude=exclude)
+161
View File
@@ -0,0 +1,161 @@
"""OIDC Session model for OIDC session tracking."""
import hashlib
import base64
from datetime import datetime, timedelta, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OIDCSession(BaseModel):
"""OIDC Session model for tracking OIDC authentication sessions.
Tracks the state during the OIDC authorization flow, including PKCE
parameters and nonce validation.
"""
__tablename__ = "oidc_sessions"
# User reference
user_id = db.Column(
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
)
# Client reference
client_id = db.Column(
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
)
# State management
state = db.Column(db.String(255), nullable=False, index=True)
nonce = db.Column(db.String(255), nullable=True)
# Authorization request parameters
redirect_uri = db.Column(db.String(512), nullable=False)
scope = db.Column(db.JSON, nullable=True)
# PKCE parameters
code_challenge = db.Column(db.String(255), nullable=True)
code_challenge_method = db.Column(db.String(10), nullable=True) # "S256" or "plain"
# Timing
expires_at = db.Column(db.DateTime, nullable=False, index=True)
authenticated_at = db.Column(db.DateTime, nullable=True)
# Relationships
user = db.relationship("User", back_populates="oidc_sessions")
client = db.relationship("OIDCClient", back_populates="oidc_sessions")
def __repr__(self):
"""String representation of OIDCSession."""
return (
f"<OIDCSession user_id={self.user_id} "
f"client_id={self.client_id} state={self.state[:8]}...>"
)
def is_expired(self) -> bool:
"""Check if the OIDC session has expired."""
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
def is_authenticated(self) -> bool:
"""Check if the user has been authenticated in this session."""
return self.authenticated_at is not None
def mark_authenticated(self) -> None:
"""Mark the session as authenticated."""
self.authenticated_at = datetime.now(timezone.utc)
db.session.commit()
def validate_nonce(self, expected_nonce: str) -> bool:
"""Validate the nonce matches the expected value.
Args:
expected_nonce: The expected nonce value
Returns:
True if nonce matches
"""
return self.nonce == expected_nonce
def validate_code_challenge(self, code_verifier: str) -> bool:
"""Validate the code verifier against the stored code challenge.
Args:
code_verifier: The PKCE code verifier
Returns:
True if the challenge is satisfied
"""
if not self.code_challenge:
return False
if self.code_challenge_method == "S256":
digest = hashlib.sha256(code_verifier.encode()).digest()
expected = base64.urlsafe_b64encode(digest).decode().rstrip("=")
return self.code_challenge == expected
elif self.code_challenge_method == "plain":
return self.code_challenge == code_verifier
return False
@classmethod
def create_session(
cls,
user_id: str,
client_id: str,
state: str,
redirect_uri: str,
scope=None,
nonce: str = None,
code_challenge: str = None,
code_challenge_method: str = None,
lifetime_seconds: int = 600,
) -> "OIDCSession":
"""Create a new OIDC session.
Args:
user_id: The user ID
client_id: The OIDC client ID
state: The state parameter
redirect_uri: The redirect URI
scope: Requested scopes
nonce: OIDC nonce
code_challenge: PKCE code challenge
code_challenge_method: PKCE method ("S256" or "plain")
lifetime_seconds: Session lifetime in seconds
Returns:
OIDCSession instance
"""
session = cls(
user_id=user_id,
client_id=client_id,
state=state,
redirect_uri=redirect_uri,
scope=scope,
nonce=nonce,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
)
db.session.add(session)
db.session.commit()
return session
@classmethod
def get_by_state(cls, state: str) -> "OIDCSession | None":
"""Get a session by state parameter.
Args:
state: The state parameter
Returns:
OIDCSession instance or None
"""
return cls.query.filter_by(state=state, deleted_at=None).first()
def to_dict(self, exclude=None):
"""Convert to dictionary."""
return super().to_dict(exclude=exclude)
@@ -0,0 +1,200 @@
"""OIDC Token Metadata model for token revocation tracking."""
import uuid
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OIDCTokenMetadata(BaseModel):
"""OIDC Token Metadata model for tracking issued tokens.
Stores metadata about issued tokens (access, refresh, ID) for revocation.
The ``id`` field on this model intentionally overrides the BaseModel UUID
to store the JWT JTI directly as the primary key for O(1) revocation checks.
"""
__tablename__ = "oidc_token_metadata"
# Primary key = JTI so revocation lookups are always a PK scan
id = db.Column(
db.String(36), primary_key=True, default=lambda: str(uuid.uuid4())
)
# Client and User references
client_id = db.Column(
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
)
user_id = db.Column(
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
)
# Token type: "access_token", "refresh_token", or "id_token"
token_type = db.Column(db.String(50), nullable=False)
# JWT ID claim (indexed for fast lookup when id != jti)
token_jti = db.Column(db.String(255), nullable=False, index=True)
# Timing
expires_at = db.Column(db.DateTime, nullable=False, index=True)
# Revocation tracking
revoked_at = db.Column(db.DateTime, nullable=True)
revoked_reason = db.Column(db.String(255), nullable=True)
# Relationships
client = db.relationship("OIDCClient", back_populates="token_metadata")
user = db.relationship("User", back_populates="oidc_token_metadata")
def __repr__(self):
"""String representation of OIDCTokenMetadata."""
return (
f"<OIDCTokenMetadata jti={self.token_jti[:8]}... "
f"type={self.token_type} revoked={self.is_revoked()}>"
)
def is_expired(self) -> bool:
"""Check if the token has expired."""
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
def is_revoked(self) -> bool:
"""Check if the token has been revoked."""
return self.revoked_at is not None
def is_valid(self) -> bool:
"""Check if the token is valid (not expired and not revoked)."""
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
def revoke(self, reason: str = None) -> None:
"""Revoke the token.
Args:
reason: Optional reason for revocation
"""
self.revoked_at = datetime.now(timezone.utc)
self.revoked_reason = reason
db.session.commit()
@classmethod
def create_metadata(
cls,
client_id: str,
user_id: str,
token_type: str,
token_jti: str,
expires_at,
ip_address: str = None,
user_agent: str = None,
) -> "OIDCTokenMetadata":
"""Create token metadata for tracking.
Args:
client_id: The OIDC client ID
user_id: The user ID
token_type: Type of token ("access_token", "refresh_token", "id_token")
token_jti: JWT ID claim
expires_at: Token expiration datetime
ip_address: Client IP address (unused column, kept for API compat)
user_agent: Client user agent (unused column, kept for API compat)
Returns:
OIDCTokenMetadata instance
"""
metadata = cls(
id=str(uuid.uuid4()),
client_id=client_id,
user_id=user_id,
token_type=token_type,
token_jti=token_jti,
expires_at=expires_at,
)
db.session.add(metadata)
db.session.commit()
return metadata
@classmethod
def get_by_jti(cls, token_jti: str) -> "OIDCTokenMetadata | None":
"""Get token metadata by JWT ID.
Args:
token_jti: The JWT ID
Returns:
OIDCTokenMetadata instance or None
"""
return cls.query.filter_by(token_jti=token_jti, deleted_at=None).first()
@classmethod
def revoke_by_jti(cls, token_jti: str, reason: str = None) -> bool:
"""Revoke a token by its JWT ID.
Args:
token_jti: The JWT ID
reason: Optional revocation reason
Returns:
True if token was found and revoked, False otherwise
"""
metadata = cls.get_by_jti(token_jti)
if metadata:
metadata.revoke(reason)
return True
return False
@classmethod
def revoke_all_for_user(
cls, user_id: str, client_id: str = None, reason: str = None
) -> int:
"""Revoke all tokens for a user.
Args:
user_id: The user ID
client_id: Optional client ID filter
reason: Optional revocation reason
Returns:
Number of tokens revoked
"""
query = cls.query.filter_by(user_id=user_id, deleted_at=None).filter(
cls.revoked_at.is_(None)
)
if client_id:
query = query.filter_by(client_id=client_id)
count = 0
for token in query.all():
token.revoke(reason)
count += 1
return count
@classmethod
def revoke_all_for_client(
cls, client_id: str, user_id: str = None, reason: str = None
) -> int:
"""Revoke all tokens for a client.
Args:
client_id: The client ID
user_id: Optional user ID filter
reason: Optional revocation reason
Returns:
Number of tokens revoked
"""
query = cls.query.filter_by(client_id=client_id, deleted_at=None).filter(
cls.revoked_at.is_(None)
)
if user_id:
query = query.filter_by(user_id=user_id)
count = 0
for token in query.all():
token.revoke(reason)
count += 1
return count
def to_dict(self, exclude=None):
"""Convert to dictionary."""
return super().to_dict(exclude=exclude)