Chore: Refractor Models into organized file/folder
This commit is contained in:
@@ -0,0 +1,18 @@
|
||||
"""OIDC subpackage — clients, tokens, sessions, and audit logs."""
|
||||
from gatehouse_app.models.oidc.oidc_client import OIDCClient
|
||||
from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode
|
||||
from gatehouse_app.models.oidc.oidc_refresh_token import OIDCRefreshToken
|
||||
from gatehouse_app.models.oidc.oidc_session import OIDCSession
|
||||
from gatehouse_app.models.oidc.oidc_token_metadata import OIDCTokenMetadata
|
||||
from gatehouse_app.models.oidc.oidc_audit_log import OIDCAuditLog
|
||||
from gatehouse_app.models.oidc.oidc_jwks_key import OidcJwksKey
|
||||
|
||||
__all__ = [
|
||||
"OIDCClient",
|
||||
"OIDCAuthCode",
|
||||
"OIDCRefreshToken",
|
||||
"OIDCSession",
|
||||
"OIDCTokenMetadata",
|
||||
"OIDCAuditLog",
|
||||
"OidcJwksKey",
|
||||
]
|
||||
@@ -0,0 +1,264 @@
|
||||
"""OIDC Audit Log model for comprehensive OIDC event tracking."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OIDCAuditLog(BaseModel):
|
||||
"""OIDC Audit Log model for comprehensive OIDC event tracking.
|
||||
|
||||
Logs all OIDC-related events for security, compliance, and debugging.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_audit_logs"
|
||||
|
||||
# Event type categorization
|
||||
event_type = db.Column(db.String(100), nullable=False, index=True)
|
||||
|
||||
# Client and User references
|
||||
client_id = db.Column(
|
||||
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=True, index=True
|
||||
)
|
||||
user_id = db.Column(
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=True, index=True
|
||||
)
|
||||
|
||||
# Event outcome
|
||||
success = db.Column(db.Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
# Error details (for failed events)
|
||||
error_code = db.Column(db.String(100), nullable=True)
|
||||
error_description = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Request context
|
||||
ip_address = db.Column(db.String(45), nullable=True, index=True)
|
||||
user_agent = db.Column(db.Text, nullable=True)
|
||||
request_id = db.Column(db.String(36), nullable=True, index=True)
|
||||
|
||||
# Additional event metadata
|
||||
event_metadata = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# Relationships
|
||||
client = db.relationship("OIDCClient", back_populates="audit_logs")
|
||||
user = db.relationship("User", back_populates="oidc_audit_logs")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCAuditLog."""
|
||||
status = "success" if self.success else "failed"
|
||||
return (
|
||||
f"<OIDCAuditLog event={self.event_type} "
|
||||
f"status={status} client={self.client_id}>"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_event(
|
||||
cls,
|
||||
event_type: str,
|
||||
client_id: str = None,
|
||||
user_id: str = None,
|
||||
success: bool = True,
|
||||
error_code: str = None,
|
||||
error_description: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
request_id: str = None,
|
||||
event_metadata: dict = None,
|
||||
) -> "OIDCAuditLog":
|
||||
"""Log an OIDC event.
|
||||
|
||||
Args:
|
||||
event_type: Type of event (e.g., "authorization_request")
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID
|
||||
success: Whether the event was successful
|
||||
error_code: Error code if event failed
|
||||
error_description: Error description if event failed
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
request_id: Request ID for correlation
|
||||
event_metadata: Additional event metadata
|
||||
|
||||
Returns:
|
||||
OIDCAuditLog instance
|
||||
"""
|
||||
log = cls(
|
||||
event_type=event_type,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
event_metadata=event_metadata,
|
||||
)
|
||||
db.session.add(log)
|
||||
db.session.commit()
|
||||
return log
|
||||
|
||||
@classmethod
|
||||
def log_authorization_request(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
redirect_uri: str,
|
||||
scope,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
request_id: str = None,
|
||||
success: bool = True,
|
||||
error_code: str = None,
|
||||
error_description: str = None,
|
||||
) -> "OIDCAuditLog":
|
||||
"""Log an authorization request event."""
|
||||
return cls.log_event(
|
||||
event_type="authorization_request",
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
event_metadata={"redirect_uri": redirect_uri, "scope": scope},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_token_issue(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
token_type: str,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
request_id: str = None,
|
||||
) -> "OIDCAuditLog":
|
||||
"""Log a token issuance event."""
|
||||
return cls.log_event(
|
||||
event_type="token_issue",
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=True,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
event_metadata={"token_type": token_type},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_token_revocation(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
token_type: str,
|
||||
reason: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
request_id: str = None,
|
||||
) -> "OIDCAuditLog":
|
||||
"""Log a token revocation event."""
|
||||
return cls.log_event(
|
||||
event_type="token_revocation",
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=True,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
event_metadata={"token_type": token_type, "reason": reason},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_authentication_failure(
|
||||
cls,
|
||||
client_id: str,
|
||||
error_code: str,
|
||||
error_description: str,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
request_id: str = None,
|
||||
) -> "OIDCAuditLog":
|
||||
"""Log an authentication failure event."""
|
||||
return cls.log_event(
|
||||
event_type="authentication_failure",
|
||||
client_id=client_id,
|
||||
success=False,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_events_for_user(cls, user_id: str, limit: int = 100) -> list:
|
||||
"""Get audit events for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of OIDCAuditLog instances
|
||||
"""
|
||||
return (
|
||||
cls.query.filter_by(user_id=user_id, deleted_at=None)
|
||||
.order_by(cls.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_events_for_client(cls, client_id: str, limit: int = 100) -> list:
|
||||
"""Get audit events for a client.
|
||||
|
||||
Args:
|
||||
client_id: The client ID
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of OIDCAuditLog instances
|
||||
"""
|
||||
return (
|
||||
cls.query.filter_by(client_id=client_id, deleted_at=None)
|
||||
.order_by(cls.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_failed_events(
|
||||
cls,
|
||||
client_id: str = None,
|
||||
user_id: str = None,
|
||||
start_date=None,
|
||||
end_date=None,
|
||||
limit: int = 100,
|
||||
) -> list:
|
||||
"""Get failed audit events.
|
||||
|
||||
Args:
|
||||
client_id: Optional client ID filter
|
||||
user_id: Optional user ID filter
|
||||
start_date: Optional start date filter
|
||||
end_date: Optional end date filter
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of OIDCAuditLog instances
|
||||
"""
|
||||
query = cls.query.filter_by(success=False, deleted_at=None)
|
||||
if client_id:
|
||||
query = query.filter_by(client_id=client_id)
|
||||
if user_id:
|
||||
query = query.filter_by(user_id=user_id)
|
||||
if start_date:
|
||||
query = query.filter(cls.created_at >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(cls.created_at <= end_date)
|
||||
return query.order_by(cls.created_at.desc()).limit(limit).all()
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
return super().to_dict(exclude=exclude)
|
||||
@@ -0,0 +1,123 @@
|
||||
"""OIDC Authorization Code model for the authorization code grant flow."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OIDCAuthCode(BaseModel):
|
||||
"""OIDC Authorization Code model for the authorization code grant flow.
|
||||
|
||||
Authorization codes are single-use, short-lived codes. The code itself is
|
||||
hashed before storage so that a database breach cannot replay codes.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_authorization_codes"
|
||||
|
||||
# Client and User references
|
||||
client_id = db.Column(
|
||||
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
|
||||
)
|
||||
user_id = db.Column(
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# Authorization code (hashed for security)
|
||||
code_hash = db.Column(db.String(255), nullable=False)
|
||||
|
||||
# Request parameters
|
||||
redirect_uri = db.Column(db.String(512), nullable=False)
|
||||
scope = db.Column(db.JSON, nullable=True)
|
||||
nonce = db.Column(db.String(255), nullable=True)
|
||||
code_verifier = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Status tracking
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
used_at = db.Column(db.DateTime, nullable=True)
|
||||
is_used = db.Column(db.Boolean, default=False, nullable=False)
|
||||
|
||||
# Request metadata
|
||||
ip_address = db.Column(db.String(45), nullable=True)
|
||||
user_agent = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Relationships — back_populates declared on User and OIDCClient
|
||||
client = db.relationship("OIDCClient", back_populates="authorization_codes")
|
||||
user = db.relationship("User", back_populates="oidc_auth_codes")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCAuthCode."""
|
||||
return (
|
||||
f"<OIDCAuthCode client_id={self.client_id} "
|
||||
f"user_id={self.user_id} used={self.is_used}>"
|
||||
)
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the authorization code has expired."""
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return datetime.now(timezone.utc) > expires_at
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the authorization code is valid for use."""
|
||||
return not self.is_used and not self.is_expired() and self.deleted_at is None
|
||||
|
||||
def mark_as_used(self) -> None:
|
||||
"""Mark the authorization code as used."""
|
||||
self.is_used = True
|
||||
self.used_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def create_code(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
code_hash: str,
|
||||
redirect_uri: str,
|
||||
scope=None,
|
||||
nonce: str = None,
|
||||
code_verifier: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
lifetime_seconds: int = 600,
|
||||
) -> "OIDCAuthCode":
|
||||
"""Create a new authorization code.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID
|
||||
code_hash: Hashed authorization code
|
||||
redirect_uri: The redirect URI
|
||||
scope: Requested scopes
|
||||
nonce: OIDC nonce
|
||||
code_verifier: PKCE code verifier (stored hashed server-side)
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
lifetime_seconds: Code lifetime in seconds (default 10 minutes)
|
||||
|
||||
Returns:
|
||||
OIDCAuthCode instance
|
||||
"""
|
||||
code = cls(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
code_hash=code_hash,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
nonce=nonce,
|
||||
code_verifier=code_verifier,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
db.session.add(code)
|
||||
db.session.commit()
|
||||
return code
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
for field in ("code_hash", "code_verifier"):
|
||||
if field not in exclude:
|
||||
exclude.append(field)
|
||||
return super().to_dict(exclude=exclude)
|
||||
@@ -0,0 +1,86 @@
|
||||
"""OIDC Client model."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import OIDCGrantType, OIDCResponseType
|
||||
|
||||
|
||||
class OIDCClient(BaseModel):
|
||||
"""OIDC client model for OAuth2/OIDC integrations."""
|
||||
|
||||
__tablename__ = "oidc_clients"
|
||||
|
||||
organization_id = db.Column(
|
||||
db.String(36), db.ForeignKey("organizations.id"), nullable=False, index=True
|
||||
)
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
client_id = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
client_secret_hash = db.Column(db.String(255), nullable=False)
|
||||
|
||||
# OAuth/OIDC configuration
|
||||
redirect_uris = db.Column(db.JSON, nullable=False) # Allowed redirect URIs
|
||||
grant_types = db.Column(db.JSON, nullable=False) # Allowed grant types
|
||||
response_types = db.Column(db.JSON, nullable=False) # Allowed response types
|
||||
scopes = db.Column(db.JSON, nullable=False) # Allowed scopes
|
||||
|
||||
# Client metadata
|
||||
logo_uri = db.Column(db.String(512), nullable=True)
|
||||
client_uri = db.Column(db.String(512), nullable=True)
|
||||
policy_uri = db.Column(db.String(512), nullable=True)
|
||||
tos_uri = db.Column(db.String(512), nullable=True)
|
||||
|
||||
# Settings
|
||||
is_active = db.Column(db.Boolean, default=True, nullable=False)
|
||||
is_confidential = db.Column(db.Boolean, default=True, nullable=False)
|
||||
require_pkce = db.Column(db.Boolean, default=True, nullable=False)
|
||||
|
||||
# Token lifetimes (in seconds)
|
||||
access_token_lifetime = db.Column(db.Integer, default=3600, nullable=False)
|
||||
refresh_token_lifetime = db.Column(db.Integer, default=2592000, nullable=False)
|
||||
id_token_lifetime = db.Column(db.Integer, default=3600, nullable=False)
|
||||
|
||||
# Relationships
|
||||
organization = db.relationship("Organization", back_populates="oidc_clients")
|
||||
|
||||
# OIDC sub-resource relationships (declared here, not monkey-patched elsewhere)
|
||||
authorization_codes = db.relationship(
|
||||
"OIDCAuthCode", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
refresh_tokens = db.relationship(
|
||||
"OIDCRefreshToken", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
oidc_sessions = db.relationship(
|
||||
"OIDCSession", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
token_metadata = db.relationship(
|
||||
"OIDCTokenMetadata", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
audit_logs = db.relationship(
|
||||
"OIDCAuditLog", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCClient."""
|
||||
return f"<OIDCClient {self.name} client_id={self.client_id}>"
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
if "client_secret_hash" not in exclude:
|
||||
exclude.append("client_secret_hash")
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
def has_grant_type(self, grant_type) -> bool:
|
||||
"""Check if client supports a specific grant type."""
|
||||
return grant_type in self.grant_types
|
||||
|
||||
def has_response_type(self, response_type) -> bool:
|
||||
"""Check if client supports a specific response type."""
|
||||
return response_type in self.response_types
|
||||
|
||||
def is_redirect_uri_allowed(self, redirect_uri: str) -> bool:
|
||||
"""Check if a redirect URI is allowed for this client."""
|
||||
return redirect_uri in self.redirect_uris
|
||||
|
||||
def has_scope(self, scope: str) -> bool:
|
||||
"""Check if client is allowed to request a specific scope."""
|
||||
return scope in self.scopes
|
||||
@@ -0,0 +1,76 @@
|
||||
"""OIDC JWKS Key model for persisting signing keys."""
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OidcJwksKey(BaseModel):
|
||||
"""OIDC JWKS Key model for persisting JSON Web Key Set signing keys.
|
||||
|
||||
Stores RSA/ECDSA key pairs used for signing OIDC tokens. Multiple keys can
|
||||
be stored to support key rotation scenarios.
|
||||
|
||||
Attributes:
|
||||
kid: Unique key ID used in JWT ``kid`` header
|
||||
key_type: Type of key (e.g., "RSA", "EC")
|
||||
private_key: PEM-encoded private key (never exposed in API responses)
|
||||
public_key: PEM-encoded public key
|
||||
algorithm: Signing algorithm (e.g., "RS256", "ES256")
|
||||
is_active: Whether this key is currently used for signing/verification
|
||||
is_primary: Whether this is the primary signing key
|
||||
expires_at: Optional expiry for key rotation enforcement
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_jwks_keys"
|
||||
|
||||
# Override the default UUID id with integer primary key for JWKS key sets
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
|
||||
expires_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Key identification and type
|
||||
kid = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
key_type = db.Column(db.String(50), nullable=False) # e.g., "RSA", "EC"
|
||||
algorithm = db.Column(db.String(50), nullable=False) # e.g., "RS256", "ES256"
|
||||
|
||||
# Key material (PEM-encoded) — private_key must never be returned by API
|
||||
private_key = db.Column(db.Text, nullable=False)
|
||||
public_key = db.Column(db.Text, nullable=False)
|
||||
|
||||
# Key status
|
||||
is_active = db.Column(db.Boolean, default=True, nullable=False)
|
||||
is_primary = db.Column(db.Boolean, default=False, nullable=False)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OidcJwksKey."""
|
||||
return (
|
||||
f"<OidcJwksKey kid={self.kid} "
|
||||
f"key_type={self.key_type} algorithm={self.algorithm}>"
|
||||
)
|
||||
|
||||
def to_dict(self, exclude_private_key: bool = True):
|
||||
"""Convert model to dictionary.
|
||||
|
||||
Args:
|
||||
exclude_private_key: If True (default), excludes the private key.
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the model
|
||||
"""
|
||||
exclude = ["private_key"] if exclude_private_key else []
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
@classmethod
|
||||
def get_active_keys(cls) -> list:
|
||||
"""Get all active keys for signing operations."""
|
||||
return cls.query.filter_by(is_active=True).all()
|
||||
|
||||
@classmethod
|
||||
def get_primary_key(cls) -> "OidcJwksKey | None":
|
||||
"""Get the primary signing key."""
|
||||
return cls.query.filter_by(is_primary=True).first()
|
||||
|
||||
@classmethod
|
||||
def get_key_by_kid(cls, kid: str) -> "OidcJwksKey | None":
|
||||
"""Get an active key by its key ID."""
|
||||
return cls.query.filter_by(kid=kid, is_active=True).first()
|
||||
@@ -0,0 +1,148 @@
|
||||
"""OIDC Refresh Token model for token rotation."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OIDCRefreshToken(BaseModel):
|
||||
"""OIDC Refresh Token model for token refresh and rotation.
|
||||
|
||||
Refresh tokens are long-lived credentials used to obtain new access tokens.
|
||||
They support token rotation for enhanced security — each use invalidates
|
||||
the old token and issues a new one.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_refresh_tokens"
|
||||
|
||||
# Client and User references
|
||||
client_id = db.Column(
|
||||
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
|
||||
)
|
||||
user_id = db.Column(
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# Token (hashed for security — never store plaintext refresh tokens)
|
||||
token_hash = db.Column(db.String(255), nullable=False, unique=True, index=True)
|
||||
|
||||
# Associated access token JTI (no FK — stored as string for lightweight lookup)
|
||||
access_token_id = db.Column(db.String(255), nullable=True, index=True)
|
||||
|
||||
# Token scope
|
||||
scope = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# Timing
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
|
||||
# Revocation tracking
|
||||
revoked_at = db.Column(db.DateTime, nullable=True)
|
||||
revoked_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Token rotation metadata
|
||||
previous_token_hash = db.Column(db.String(255), nullable=True)
|
||||
rotation_count = db.Column(db.Integer, default=0, nullable=False)
|
||||
|
||||
# Request metadata
|
||||
ip_address = db.Column(db.String(45), nullable=True)
|
||||
user_agent = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
client = db.relationship("OIDCClient", back_populates="refresh_tokens")
|
||||
user = db.relationship("User", back_populates="oidc_refresh_tokens")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCRefreshToken."""
|
||||
return (
|
||||
f"<OIDCRefreshToken client_id={self.client_id} "
|
||||
f"user_id={self.user_id} revoked={self.is_revoked()}>"
|
||||
)
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the refresh token has expired."""
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return datetime.now(timezone.utc) > expires_at
|
||||
|
||||
def is_revoked(self) -> bool:
|
||||
"""Check if the refresh token has been revoked."""
|
||||
return self.revoked_at is not None
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the refresh token is valid for use."""
|
||||
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
|
||||
|
||||
def revoke(self, reason: str = None) -> None:
|
||||
"""Revoke the refresh token.
|
||||
|
||||
Args:
|
||||
reason: Optional reason for revocation
|
||||
"""
|
||||
self.revoked_at = datetime.now(timezone.utc)
|
||||
self.revoked_reason = reason
|
||||
db.session.commit()
|
||||
|
||||
def rotate(self, new_token_hash: str) -> "OIDCRefreshToken":
|
||||
"""Rotate the refresh token — invalidate the old hash, store the new one.
|
||||
|
||||
Args:
|
||||
new_token_hash: Hash of the new refresh token
|
||||
|
||||
Returns:
|
||||
self for chaining
|
||||
"""
|
||||
self.previous_token_hash = self.token_hash
|
||||
self.token_hash = new_token_hash
|
||||
self.rotation_count += 1
|
||||
self.expires_at = datetime.now(timezone.utc) + timedelta(days=30)
|
||||
db.session.commit()
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def create_token(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
token_hash: str,
|
||||
scope=None,
|
||||
access_token_id: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
lifetime_seconds: int = 2592000,
|
||||
) -> "OIDCRefreshToken":
|
||||
"""Create a new refresh token.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID
|
||||
token_hash: Hashed refresh token
|
||||
scope: Granted scopes
|
||||
access_token_id: Associated access token JTI
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
lifetime_seconds: Token lifetime in seconds (default 30 days)
|
||||
|
||||
Returns:
|
||||
OIDCRefreshToken instance
|
||||
"""
|
||||
token = cls(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
token_hash=token_hash,
|
||||
scope=scope,
|
||||
access_token_id=access_token_id,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
db.session.add(token)
|
||||
db.session.commit()
|
||||
return token
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
for field in ("token_hash", "previous_token_hash"):
|
||||
if field not in exclude:
|
||||
exclude.append(field)
|
||||
return super().to_dict(exclude=exclude)
|
||||
@@ -0,0 +1,161 @@
|
||||
"""OIDC Session model for OIDC session tracking."""
|
||||
import hashlib
|
||||
import base64
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OIDCSession(BaseModel):
|
||||
"""OIDC Session model for tracking OIDC authentication sessions.
|
||||
|
||||
Tracks the state during the OIDC authorization flow, including PKCE
|
||||
parameters and nonce validation.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_sessions"
|
||||
|
||||
# User reference
|
||||
user_id = db.Column(
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# Client reference
|
||||
client_id = db.Column(
|
||||
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# State management
|
||||
state = db.Column(db.String(255), nullable=False, index=True)
|
||||
nonce = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Authorization request parameters
|
||||
redirect_uri = db.Column(db.String(512), nullable=False)
|
||||
scope = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# PKCE parameters
|
||||
code_challenge = db.Column(db.String(255), nullable=True)
|
||||
code_challenge_method = db.Column(db.String(10), nullable=True) # "S256" or "plain"
|
||||
|
||||
# Timing
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
authenticated_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship("User", back_populates="oidc_sessions")
|
||||
client = db.relationship("OIDCClient", back_populates="oidc_sessions")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCSession."""
|
||||
return (
|
||||
f"<OIDCSession user_id={self.user_id} "
|
||||
f"client_id={self.client_id} state={self.state[:8]}...>"
|
||||
)
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the OIDC session has expired."""
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return datetime.now(timezone.utc) > expires_at
|
||||
|
||||
def is_authenticated(self) -> bool:
|
||||
"""Check if the user has been authenticated in this session."""
|
||||
return self.authenticated_at is not None
|
||||
|
||||
def mark_authenticated(self) -> None:
|
||||
"""Mark the session as authenticated."""
|
||||
self.authenticated_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
def validate_nonce(self, expected_nonce: str) -> bool:
|
||||
"""Validate the nonce matches the expected value.
|
||||
|
||||
Args:
|
||||
expected_nonce: The expected nonce value
|
||||
|
||||
Returns:
|
||||
True if nonce matches
|
||||
"""
|
||||
return self.nonce == expected_nonce
|
||||
|
||||
def validate_code_challenge(self, code_verifier: str) -> bool:
|
||||
"""Validate the code verifier against the stored code challenge.
|
||||
|
||||
Args:
|
||||
code_verifier: The PKCE code verifier
|
||||
|
||||
Returns:
|
||||
True if the challenge is satisfied
|
||||
"""
|
||||
if not self.code_challenge:
|
||||
return False
|
||||
|
||||
if self.code_challenge_method == "S256":
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
expected = base64.urlsafe_b64encode(digest).decode().rstrip("=")
|
||||
return self.code_challenge == expected
|
||||
elif self.code_challenge_method == "plain":
|
||||
return self.code_challenge == code_verifier
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def create_session(
|
||||
cls,
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
state: str,
|
||||
redirect_uri: str,
|
||||
scope=None,
|
||||
nonce: str = None,
|
||||
code_challenge: str = None,
|
||||
code_challenge_method: str = None,
|
||||
lifetime_seconds: int = 600,
|
||||
) -> "OIDCSession":
|
||||
"""Create a new OIDC session.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
client_id: The OIDC client ID
|
||||
state: The state parameter
|
||||
redirect_uri: The redirect URI
|
||||
scope: Requested scopes
|
||||
nonce: OIDC nonce
|
||||
code_challenge: PKCE code challenge
|
||||
code_challenge_method: PKCE method ("S256" or "plain")
|
||||
lifetime_seconds: Session lifetime in seconds
|
||||
|
||||
Returns:
|
||||
OIDCSession instance
|
||||
"""
|
||||
session = cls(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
state=state,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
nonce=nonce,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
|
||||
)
|
||||
db.session.add(session)
|
||||
db.session.commit()
|
||||
return session
|
||||
|
||||
@classmethod
|
||||
def get_by_state(cls, state: str) -> "OIDCSession | None":
|
||||
"""Get a session by state parameter.
|
||||
|
||||
Args:
|
||||
state: The state parameter
|
||||
|
||||
Returns:
|
||||
OIDCSession instance or None
|
||||
"""
|
||||
return cls.query.filter_by(state=state, deleted_at=None).first()
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
return super().to_dict(exclude=exclude)
|
||||
@@ -0,0 +1,200 @@
|
||||
"""OIDC Token Metadata model for token revocation tracking."""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OIDCTokenMetadata(BaseModel):
|
||||
"""OIDC Token Metadata model for tracking issued tokens.
|
||||
|
||||
Stores metadata about issued tokens (access, refresh, ID) for revocation.
|
||||
The ``id`` field on this model intentionally overrides the BaseModel UUID
|
||||
to store the JWT JTI directly as the primary key for O(1) revocation checks.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_token_metadata"
|
||||
|
||||
# Primary key = JTI so revocation lookups are always a PK scan
|
||||
id = db.Column(
|
||||
db.String(36), primary_key=True, default=lambda: str(uuid.uuid4())
|
||||
)
|
||||
|
||||
# Client and User references
|
||||
client_id = db.Column(
|
||||
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
|
||||
)
|
||||
user_id = db.Column(
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# Token type: "access_token", "refresh_token", or "id_token"
|
||||
token_type = db.Column(db.String(50), nullable=False)
|
||||
|
||||
# JWT ID claim (indexed for fast lookup when id != jti)
|
||||
token_jti = db.Column(db.String(255), nullable=False, index=True)
|
||||
|
||||
# Timing
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
|
||||
# Revocation tracking
|
||||
revoked_at = db.Column(db.DateTime, nullable=True)
|
||||
revoked_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Relationships
|
||||
client = db.relationship("OIDCClient", back_populates="token_metadata")
|
||||
user = db.relationship("User", back_populates="oidc_token_metadata")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCTokenMetadata."""
|
||||
return (
|
||||
f"<OIDCTokenMetadata jti={self.token_jti[:8]}... "
|
||||
f"type={self.token_type} revoked={self.is_revoked()}>"
|
||||
)
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the token has expired."""
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return datetime.now(timezone.utc) > expires_at
|
||||
|
||||
def is_revoked(self) -> bool:
|
||||
"""Check if the token has been revoked."""
|
||||
return self.revoked_at is not None
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the token is valid (not expired and not revoked)."""
|
||||
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
|
||||
|
||||
def revoke(self, reason: str = None) -> None:
|
||||
"""Revoke the token.
|
||||
|
||||
Args:
|
||||
reason: Optional reason for revocation
|
||||
"""
|
||||
self.revoked_at = datetime.now(timezone.utc)
|
||||
self.revoked_reason = reason
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def create_metadata(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
token_type: str,
|
||||
token_jti: str,
|
||||
expires_at,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
) -> "OIDCTokenMetadata":
|
||||
"""Create token metadata for tracking.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID
|
||||
token_type: Type of token ("access_token", "refresh_token", "id_token")
|
||||
token_jti: JWT ID claim
|
||||
expires_at: Token expiration datetime
|
||||
ip_address: Client IP address (unused column, kept for API compat)
|
||||
user_agent: Client user agent (unused column, kept for API compat)
|
||||
|
||||
Returns:
|
||||
OIDCTokenMetadata instance
|
||||
"""
|
||||
metadata = cls(
|
||||
id=str(uuid.uuid4()),
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
token_type=token_type,
|
||||
token_jti=token_jti,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.session.add(metadata)
|
||||
db.session.commit()
|
||||
return metadata
|
||||
|
||||
@classmethod
|
||||
def get_by_jti(cls, token_jti: str) -> "OIDCTokenMetadata | None":
|
||||
"""Get token metadata by JWT ID.
|
||||
|
||||
Args:
|
||||
token_jti: The JWT ID
|
||||
|
||||
Returns:
|
||||
OIDCTokenMetadata instance or None
|
||||
"""
|
||||
return cls.query.filter_by(token_jti=token_jti, deleted_at=None).first()
|
||||
|
||||
@classmethod
|
||||
def revoke_by_jti(cls, token_jti: str, reason: str = None) -> bool:
|
||||
"""Revoke a token by its JWT ID.
|
||||
|
||||
Args:
|
||||
token_jti: The JWT ID
|
||||
reason: Optional revocation reason
|
||||
|
||||
Returns:
|
||||
True if token was found and revoked, False otherwise
|
||||
"""
|
||||
metadata = cls.get_by_jti(token_jti)
|
||||
if metadata:
|
||||
metadata.revoke(reason)
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def revoke_all_for_user(
|
||||
cls, user_id: str, client_id: str = None, reason: str = None
|
||||
) -> int:
|
||||
"""Revoke all tokens for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
client_id: Optional client ID filter
|
||||
reason: Optional revocation reason
|
||||
|
||||
Returns:
|
||||
Number of tokens revoked
|
||||
"""
|
||||
query = cls.query.filter_by(user_id=user_id, deleted_at=None).filter(
|
||||
cls.revoked_at.is_(None)
|
||||
)
|
||||
if client_id:
|
||||
query = query.filter_by(client_id=client_id)
|
||||
|
||||
count = 0
|
||||
for token in query.all():
|
||||
token.revoke(reason)
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@classmethod
|
||||
def revoke_all_for_client(
|
||||
cls, client_id: str, user_id: str = None, reason: str = None
|
||||
) -> int:
|
||||
"""Revoke all tokens for a client.
|
||||
|
||||
Args:
|
||||
client_id: The client ID
|
||||
user_id: Optional user ID filter
|
||||
reason: Optional revocation reason
|
||||
|
||||
Returns:
|
||||
Number of tokens revoked
|
||||
"""
|
||||
query = cls.query.filter_by(client_id=client_id, deleted_at=None).filter(
|
||||
cls.revoked_at.is_(None)
|
||||
)
|
||||
if user_id:
|
||||
query = query.filter_by(user_id=user_id)
|
||||
|
||||
count = 0
|
||||
for token in query.all():
|
||||
token.revoke(reason)
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
return super().to_dict(exclude=exclude)
|
||||
Reference in New Issue
Block a user