move app to gatehouse-app
This commit is contained in:
@@ -0,0 +1,30 @@
|
||||
"""Models package."""
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.models.user import User
|
||||
from gatehouse_app.models.organization import Organization
|
||||
from gatehouse_app.models.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.models.session import Session
|
||||
from gatehouse_app.models.audit_log import AuditLog
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
from gatehouse_app.models.oidc_authorization_code import OIDCAuthCode
|
||||
from gatehouse_app.models.oidc_refresh_token import OIDCRefreshToken
|
||||
from gatehouse_app.models.oidc_session import OIDCSession
|
||||
from gatehouse_app.models.oidc_token_metadata import OIDCTokenMetadata
|
||||
from gatehouse_app.models.oidc_audit_log import OIDCAuditLog
|
||||
|
||||
__all__ = [
|
||||
"BaseModel",
|
||||
"User",
|
||||
"Organization",
|
||||
"OrganizationMember",
|
||||
"AuthenticationMethod",
|
||||
"Session",
|
||||
"AuditLog",
|
||||
"OIDCClient",
|
||||
"OIDCAuthCode",
|
||||
"OIDCRefreshToken",
|
||||
"OIDCSession",
|
||||
"OIDCTokenMetadata",
|
||||
"OIDCAuditLog",
|
||||
]
|
||||
@@ -0,0 +1,62 @@
|
||||
"""Audit log model."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import AuditAction
|
||||
|
||||
|
||||
class AuditLog(BaseModel):
|
||||
"""Audit log model for tracking user and system actions."""
|
||||
|
||||
__tablename__ = "audit_logs"
|
||||
|
||||
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True, index=True)
|
||||
action = db.Column(db.Enum(AuditAction), nullable=False, index=True)
|
||||
|
||||
# Context
|
||||
resource_type = db.Column(db.String(50), nullable=True, index=True)
|
||||
resource_id = db.Column(db.String(36), nullable=True, index=True)
|
||||
organization_id = db.Column(db.String(36), nullable=True, index=True)
|
||||
|
||||
# Request details
|
||||
ip_address = db.Column(db.String(45), nullable=True)
|
||||
user_agent = db.Column(db.Text, nullable=True)
|
||||
request_id = db.Column(db.String(36), nullable=True, index=True)
|
||||
|
||||
# Additional data
|
||||
extra_data = db.Column(db.JSON, nullable=True)
|
||||
description = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Success/failure
|
||||
success = db.Column(db.Boolean, default=True, nullable=False)
|
||||
error_message = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship("User", back_populates="audit_logs")
|
||||
|
||||
# Indexes for common queries
|
||||
__table_args__ = (
|
||||
db.Index("idx_audit_user_action", "user_id", "action"),
|
||||
db.Index("idx_audit_resource", "resource_type", "resource_id"),
|
||||
db.Index("idx_audit_org", "organization_id", "created_at"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of AuditLog."""
|
||||
return f"<AuditLog action={self.action} user_id={self.user_id}>"
|
||||
|
||||
@classmethod
|
||||
def log(cls, action, user_id=None, **kwargs):
|
||||
"""
|
||||
Create an audit log entry.
|
||||
|
||||
Args:
|
||||
action: AuditAction enum value
|
||||
user_id: ID of the user performing the action
|
||||
**kwargs: Additional audit log fields
|
||||
|
||||
Returns:
|
||||
AuditLog instance
|
||||
"""
|
||||
log_entry = cls(action=action, user_id=user_id, **kwargs)
|
||||
log_entry.save()
|
||||
return log_entry
|
||||
@@ -0,0 +1,93 @@
|
||||
"""Authentication method model."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
|
||||
class AuthenticationMethod(BaseModel):
|
||||
"""Authentication method model storing user authentication credentials."""
|
||||
|
||||
__tablename__ = "authentication_methods"
|
||||
|
||||
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=False, index=True)
|
||||
method_type = db.Column(db.Enum(AuthMethodType), nullable=False, index=True)
|
||||
|
||||
# For password authentication
|
||||
password_hash = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# For OAuth/OIDC providers
|
||||
provider_user_id = db.Column(db.String(255), nullable=True)
|
||||
provider_data = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# For TOTP authentication
|
||||
totp_secret = db.Column(db.String(32), nullable=True)
|
||||
totp_backup_codes = db.Column(db.JSON, nullable=True)
|
||||
totp_verified_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Metadata
|
||||
is_primary = db.Column(db.Boolean, default=False, nullable=False)
|
||||
verified = db.Column(db.Boolean, default=False, nullable=False)
|
||||
last_used_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship("User", back_populates="authentication_methods")
|
||||
|
||||
# Ensure unique provider combinations
|
||||
__table_args__ = (
|
||||
db.Index("idx_user_method", "user_id", "method_type"),
|
||||
db.UniqueConstraint(
|
||||
"user_id", "method_type", "provider_user_id", name="uix_user_method_provider"
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of AuthenticationMethod."""
|
||||
return f"<AuthenticationMethod user_id={self.user_id} type={self.method_type}>"
|
||||
|
||||
def is_password(self):
|
||||
"""Check if this is a password authentication method."""
|
||||
return self.method_type == AuthMethodType.PASSWORD
|
||||
|
||||
def is_oauth(self):
|
||||
"""Check if this is an OAuth authentication method."""
|
||||
return self.method_type in [
|
||||
AuthMethodType.GOOGLE,
|
||||
AuthMethodType.GITHUB,
|
||||
AuthMethodType.MICROSOFT,
|
||||
]
|
||||
|
||||
def is_totp(self):
|
||||
"""Check if this is a TOTP authentication method."""
|
||||
return self.method_type == AuthMethodType.TOTP
|
||||
|
||||
def is_webauthn(self):
|
||||
"""Check if this is a WebAuthn authentication method."""
|
||||
return self.method_type == AuthMethodType.WEBAUTHN
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude password hash and TOTP secrets
|
||||
exclude.append("password_hash")
|
||||
exclude.append("totp_secret")
|
||||
exclude.append("totp_backup_codes")
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
def to_webauthn_dict(self):
|
||||
"""Convert WebAuthn credential to public dictionary.
|
||||
|
||||
Returns:
|
||||
Dictionary with safe-to-expose credential information.
|
||||
"""
|
||||
if not self.is_webauthn() or not self.provider_data:
|
||||
return None
|
||||
|
||||
data = self.provider_data
|
||||
return {
|
||||
"id": data.get("credential_id"),
|
||||
"name": data.get("name"),
|
||||
"transports": data.get("transports", []),
|
||||
"created_at": data.get("created_at"),
|
||||
"last_used_at": data.get("last_used_at"),
|
||||
"sign_count": data.get("sign_count", 0),
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
"""Base model with common fields and functionality."""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
|
||||
|
||||
class BaseModel(db.Model):
|
||||
"""Base model class with common fields."""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
id = db.Column(
|
||||
db.String(36),
|
||||
primary_key=True,
|
||||
default=lambda: str(uuid.uuid4()),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
)
|
||||
created_at = db.Column(db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = db.Column(
|
||||
db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
deleted_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
def save(self):
|
||||
"""Save the model instance to database."""
|
||||
db.session.add(self)
|
||||
db.session.commit()
|
||||
return self
|
||||
|
||||
def delete(self, soft=True):
|
||||
"""
|
||||
Delete the model instance.
|
||||
|
||||
Args:
|
||||
soft: If True, performs soft delete. If False, hard delete.
|
||||
"""
|
||||
if soft:
|
||||
self.deleted_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
else:
|
||||
db.session.delete(self)
|
||||
db.session.commit()
|
||||
|
||||
def update(self, **kwargs):
|
||||
"""Update model fields."""
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
self.updated_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
return self
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""
|
||||
Convert model to dictionary.
|
||||
|
||||
Args:
|
||||
exclude: List of fields to exclude from output
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the model
|
||||
"""
|
||||
exclude = exclude or []
|
||||
result = {}
|
||||
for column in self.__table__.columns:
|
||||
if column.name not in exclude:
|
||||
value = getattr(self, column.name)
|
||||
if isinstance(value, datetime):
|
||||
result[column.name] = value.isoformat()
|
||||
else:
|
||||
result[column.name] = value
|
||||
return result
|
||||
@@ -0,0 +1,231 @@
|
||||
"""OIDC Audit Log model for comprehensive OIDC event tracking."""
|
||||
from datetime import datetime
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OIDCAuditLog(BaseModel):
|
||||
"""OIDC Audit Log model for comprehensive OIDC event tracking.
|
||||
|
||||
This model logs all OIDC-related events for security, compliance,
|
||||
and debugging purposes.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_audit_logs"
|
||||
|
||||
# Event type categorization
|
||||
event_type = db.Column(db.String(100), nullable=False, index=True)
|
||||
|
||||
# Client and User references
|
||||
client_id = db.Column(
|
||||
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=True, index=True
|
||||
)
|
||||
user_id = db.Column(
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=True, index=True
|
||||
)
|
||||
|
||||
# Event outcome
|
||||
success = db.Column(db.Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
# Error details (for failed events)
|
||||
error_code = db.Column(db.String(100), nullable=True)
|
||||
error_description = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Request context
|
||||
ip_address = db.Column(db.String(45), nullable=True, index=True)
|
||||
user_agent = db.Column(db.Text, nullable=True)
|
||||
request_id = db.Column(db.String(36), nullable=True, index=True)
|
||||
|
||||
# Additional event metadata
|
||||
event_metadata = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# Relationships
|
||||
client = db.relationship("OIDCClient", back_populates="audit_logs")
|
||||
user = db.relationship("User", back_populates="oidc_audit_logs")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCAuditLog."""
|
||||
status = "success" if self.success else "failed"
|
||||
return f"<OIDCAuditLog event={self.event_type} status={status} client={self.client_id}>"
|
||||
|
||||
@classmethod
|
||||
def log_event(cls, event_type, client_id=None, user_id=None, success=True,
|
||||
error_code=None, error_description=None, ip_address=None,
|
||||
user_agent=None, request_id=None, event_metadata=None):
|
||||
"""Log an OIDC event.
|
||||
|
||||
Args:
|
||||
event_type: Type of event (e.g., "authorization_request", "token_issue")
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID
|
||||
success: Whether the event was successful
|
||||
error_code: Error code if event failed
|
||||
error_description: Error description if event failed
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
request_id: Request ID for correlation
|
||||
event_metadata: Additional event metadata
|
||||
|
||||
Returns:
|
||||
OIDCAuditLog instance
|
||||
"""
|
||||
log = cls(
|
||||
event_type=event_type,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
event_metadata=event_metadata,
|
||||
)
|
||||
db.session.add(log)
|
||||
db.session.commit()
|
||||
return log
|
||||
|
||||
@classmethod
|
||||
def log_authorization_request(cls, client_id, user_id, redirect_uri, scope,
|
||||
ip_address=None, user_agent=None, request_id=None,
|
||||
success=True, error_code=None, error_description=None):
|
||||
"""Log an authorization request event."""
|
||||
return cls.log_event(
|
||||
event_type="authorization_request",
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
event_metadata={
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": scope,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_token_issue(cls, client_id, user_id, token_type,
|
||||
ip_address=None, user_agent=None, request_id=None):
|
||||
"""Log a token issuance event."""
|
||||
return cls.log_event(
|
||||
event_type="token_issue",
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=True,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
event_metadata={"token_type": token_type}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_token_revocation(cls, client_id, user_id, token_type, reason=None,
|
||||
ip_address=None, user_agent=None, request_id=None):
|
||||
"""Log a token revocation event."""
|
||||
return cls.log_event(
|
||||
event_type="token_revocation",
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=True,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
event_metadata={
|
||||
"token_type": token_type,
|
||||
"reason": reason,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_authentication_failure(cls, client_id, error_code, error_description,
|
||||
ip_address=None, user_agent=None, request_id=None):
|
||||
"""Log an authentication failure event."""
|
||||
return cls.log_event(
|
||||
event_type="authentication_failure",
|
||||
client_id=client_id,
|
||||
success=False,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_events_for_user(cls, user_id, limit=100):
|
||||
"""Get audit events for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of OIDCAuditLog instances
|
||||
"""
|
||||
return cls.query.filter_by(user_id=user_id, deleted_at=None)\
|
||||
.order_by(cls.created_at.desc())\
|
||||
.limit(limit)\
|
||||
.all()
|
||||
|
||||
@classmethod
|
||||
def get_events_for_client(cls, client_id, limit=100):
|
||||
"""Get audit events for a client.
|
||||
|
||||
Args:
|
||||
client_id: The client ID
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of OIDCAuditLog instances
|
||||
"""
|
||||
return cls.query.filter_by(client_id=client_id, deleted_at=None)\
|
||||
.order_by(cls.created_at.desc())\
|
||||
.limit(limit)\
|
||||
.all()
|
||||
|
||||
@classmethod
|
||||
def get_failed_events(cls, client_id=None, user_id=None, start_date=None,
|
||||
end_date=None, limit=100):
|
||||
"""Get failed audit events.
|
||||
|
||||
Args:
|
||||
client_id: Optional client ID filter
|
||||
user_id: Optional user ID filter
|
||||
start_date: Optional start date filter
|
||||
end_date: Optional end date filter
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of OIDCAuditLog instances
|
||||
"""
|
||||
query = cls.query.filter_by(success=False, deleted_at=None)
|
||||
if client_id:
|
||||
query = query.filter_by(client_id=client_id)
|
||||
if user_id:
|
||||
query = query.filter_by(user_id=user_id)
|
||||
if start_date:
|
||||
query = query.filter(cls.created_at >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(cls.created_at <= end_date)
|
||||
|
||||
return query.order_by(cls.created_at.desc()).limit(limit).all()
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from gatehouse_app.models.user import User
|
||||
User.oidc_audit_logs = db.relationship(
|
||||
"OIDCAuditLog", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
OIDCClient.audit_logs = db.relationship(
|
||||
"OIDCAuditLog", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -0,0 +1,125 @@
|
||||
"""OIDC Authorization Code model for auth code flow."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OIDCAuthCode(BaseModel):
|
||||
"""OIDC Authorization Code model for authorization code flow.
|
||||
|
||||
Authorization codes are single-use, short-lived codes used in the
|
||||
authorization code grant flow. The code is hashed for security.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_authorization_codes"
|
||||
|
||||
# Client and User references
|
||||
client_id = db.Column(
|
||||
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
|
||||
)
|
||||
user_id = db.Column(
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# Authorization code (hashed for security)
|
||||
code_hash = db.Column(db.String(255), nullable=False)
|
||||
|
||||
# Request parameters
|
||||
redirect_uri = db.Column(db.String(512), nullable=False)
|
||||
scope = db.Column(db.JSON, nullable=True) # Requested scopes
|
||||
nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation
|
||||
code_verifier = db.Column(db.String(255), nullable=True) # For PKCE
|
||||
|
||||
# Status tracking
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
used_at = db.Column(db.DateTime, nullable=True)
|
||||
is_used = db.Column(db.Boolean, default=False, nullable=False)
|
||||
|
||||
# Request metadata
|
||||
ip_address = db.Column(db.String(45), nullable=True)
|
||||
user_agent = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
client = db.relationship("OIDCClient", back_populates="authorization_codes")
|
||||
user = db.relationship("User", back_populates="oidc_auth_codes")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCAuthCode."""
|
||||
return f"<OIDCAuthCode client_id={self.client_id} user_id={self.user_id} used={self.is_used}>"
|
||||
|
||||
def is_expired(self):
|
||||
"""Check if the authorization code has expired."""
|
||||
# Handle both timezone-aware and timezone-naive expires_at values
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
# Make naive datetime timezone-aware (UTC)
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return datetime.now(timezone.utc) > expires_at
|
||||
|
||||
def is_valid(self):
|
||||
"""Check if the authorization code is valid for use."""
|
||||
return not self.is_used and not self.is_expired() and self.deleted_at is None
|
||||
|
||||
def mark_as_used(self):
|
||||
"""Mark the authorization code as used."""
|
||||
self.is_used = True
|
||||
self.used_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def create_code(cls, client_id, user_id, code_hash, redirect_uri, scope=None,
|
||||
nonce=None, code_verifier=None, ip_address=None, user_agent=None,
|
||||
lifetime_seconds=600):
|
||||
"""Create a new authorization code.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID
|
||||
code_hash: Hashed authorization code
|
||||
redirect_uri: The redirect URI
|
||||
scope: Requested scopes
|
||||
nonce: OIDC nonce
|
||||
code_verifier: PKCE code verifier
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
lifetime_seconds: Code lifetime in seconds (default 10 minutes)
|
||||
|
||||
Returns:
|
||||
OIDCAuthCode instance
|
||||
"""
|
||||
code = cls(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
code_hash=code_hash,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
nonce=nonce,
|
||||
code_verifier=code_verifier,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
db.session.add(code)
|
||||
db.session.commit()
|
||||
return code
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude code hash
|
||||
exclude.append("code_hash")
|
||||
exclude.append("code_verifier")
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from gatehouse_app.models.user import User
|
||||
User.oidc_auth_codes = db.relationship(
|
||||
"OIDCAuthCode", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
OIDCClient.authorization_codes = db.relationship(
|
||||
"OIDCAuthCode", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -0,0 +1,69 @@
|
||||
"""OIDC Client model."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import OIDCGrantType, OIDCResponseType
|
||||
|
||||
|
||||
class OIDCClient(BaseModel):
|
||||
"""OIDC client model for OAuth2/OIDC integrations."""
|
||||
|
||||
__tablename__ = "oidc_clients"
|
||||
|
||||
organization_id = db.Column(
|
||||
db.String(36), db.ForeignKey("organizations.id"), nullable=False, index=True
|
||||
)
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
client_id = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
client_secret_hash = db.Column(db.String(255), nullable=False)
|
||||
|
||||
# OAuth/OIDC configuration
|
||||
redirect_uris = db.Column(db.JSON, nullable=False) # List of allowed redirect URIs
|
||||
grant_types = db.Column(db.JSON, nullable=False) # List of allowed grant types
|
||||
response_types = db.Column(db.JSON, nullable=False) # List of allowed response types
|
||||
scopes = db.Column(db.JSON, nullable=False) # List of allowed scopes
|
||||
|
||||
# Client metadata
|
||||
logo_uri = db.Column(db.String(512), nullable=True)
|
||||
client_uri = db.Column(db.String(512), nullable=True)
|
||||
policy_uri = db.Column(db.String(512), nullable=True)
|
||||
tos_uri = db.Column(db.String(512), nullable=True)
|
||||
|
||||
# Settings
|
||||
is_active = db.Column(db.Boolean, default=True, nullable=False)
|
||||
is_confidential = db.Column(db.Boolean, default=True, nullable=False)
|
||||
require_pkce = db.Column(db.Boolean, default=True, nullable=False)
|
||||
|
||||
# Token lifetimes (in seconds)
|
||||
access_token_lifetime = db.Column(db.Integer, default=3600, nullable=False)
|
||||
refresh_token_lifetime = db.Column(db.Integer, default=2592000, nullable=False)
|
||||
id_token_lifetime = db.Column(db.Integer, default=3600, nullable=False)
|
||||
|
||||
# Relationships
|
||||
organization = db.relationship("Organization", back_populates="oidc_clients")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCClient."""
|
||||
return f"<OIDCClient {self.name} client_id={self.client_id}>"
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude client secret
|
||||
exclude.append("client_secret_hash")
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
def has_grant_type(self, grant_type):
|
||||
"""Check if client supports a specific grant type."""
|
||||
return grant_type in self.grant_types
|
||||
|
||||
def has_response_type(self, response_type):
|
||||
"""Check if client supports a specific response type."""
|
||||
return response_type in self.response_types
|
||||
|
||||
def is_redirect_uri_allowed(self, redirect_uri):
|
||||
"""Check if a redirect URI is allowed for this client."""
|
||||
return redirect_uri in self.redirect_uris
|
||||
|
||||
def has_scope(self, scope):
|
||||
"""Check if client is allowed to request a specific scope."""
|
||||
return scope in self.scopes
|
||||
@@ -0,0 +1,77 @@
|
||||
"""OIDC JWKS Key model for persisting signing keys."""
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OidcJwksKey(BaseModel):
|
||||
"""
|
||||
OIDC JWKS Key model for persisting JSON Web Key Set signing keys.
|
||||
|
||||
This model stores RSA/ECDSA key pairs used for signing OIDC tokens.
|
||||
Multiple keys can be stored to support key rotation scenarios.
|
||||
|
||||
Attributes:
|
||||
id: Integer primary key
|
||||
kid: Unique key ID used in JWT "kid" header
|
||||
key_type: Type of key (e.g., "RSA", "EC")
|
||||
private_key: PEM-encoded private key
|
||||
public_key: PEM-encoded public key
|
||||
algorithm: Signing algorithm (e.g., "RS256", "ES256")
|
||||
created_at: When the key was created
|
||||
is_active: Whether this key is currently active for signing
|
||||
is_primary: Whether this is the primary signing key
|
||||
expires_at: ...
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_jwks_keys"
|
||||
|
||||
# Override the default UUID id with integer primary key
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
|
||||
expires_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Key identification and type
|
||||
kid = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
key_type = db.Column(db.String(50), nullable=False) # e.g., "RSA", "EC"
|
||||
algorithm = db.Column(db.String(50), nullable=False) # e.g., "RS256", "ES256"
|
||||
|
||||
# Key material (PEM-encoded)
|
||||
private_key = db.Column(db.Text, nullable=False)
|
||||
public_key = db.Column(db.Text, nullable=False)
|
||||
|
||||
# Key status
|
||||
is_active = db.Column(db.Boolean, default=True, nullable=False)
|
||||
is_primary = db.Column(db.Boolean, default=False, nullable=False)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OidcJwksKey."""
|
||||
return f"<OidcJwksKey kid={self.kid} key_type={self.key_type} algorithm={self.algorithm}>"
|
||||
|
||||
def to_dict(self, exclude_private_key=True):
|
||||
"""
|
||||
Convert model to dictionary.
|
||||
|
||||
Args:
|
||||
exclude_private_key: If True, excludes the private key from output
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the model
|
||||
"""
|
||||
exclude = ["private_key"] if exclude_private_key else []
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
@classmethod
|
||||
def get_active_keys(cls):
|
||||
"""Get all active keys for signing operations."""
|
||||
return cls.query.filter(cls.is_active == True).all()
|
||||
|
||||
@classmethod
|
||||
def get_primary_key(cls):
|
||||
"""Get the primary signing key."""
|
||||
return cls.query.filter(cls.is_primary == True).first()
|
||||
|
||||
@classmethod
|
||||
def get_key_by_kid(cls, kid):
|
||||
"""Get a key by its key ID."""
|
||||
return cls.query.filter(cls.kid == kid, cls.is_active == True).first()
|
||||
@@ -0,0 +1,163 @@
|
||||
"""OIDC Refresh Token model for token rotation."""
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OIDCRefreshToken(BaseModel):
|
||||
"""OIDC Refresh Token model for token refresh and rotation.
|
||||
|
||||
Refresh tokens are long-lived credentials used to obtain new access tokens.
|
||||
They support token rotation for enhanced security.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_refresh_tokens"
|
||||
|
||||
# Client and User references
|
||||
client_id = db.Column(
|
||||
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
|
||||
)
|
||||
user_id = db.Column(
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# Token (hashed for security)
|
||||
token_hash = db.Column(db.String(255), nullable=False, unique=True, index=True)
|
||||
|
||||
# Associated access token ID
|
||||
access_token_id = db.Column(
|
||||
db.String(36), db.ForeignKey("sessions.id"), nullable=True, index=True
|
||||
)
|
||||
|
||||
# Token scope
|
||||
scope = db.Column(db.JSON, nullable=True) # Granted scopes
|
||||
|
||||
# Timing
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
|
||||
# Revocation tracking
|
||||
revoked_at = db.Column(db.DateTime, nullable=True)
|
||||
revoked_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Token rotation metadata
|
||||
previous_token_hash = db.Column(db.String(255), nullable=True) # For rotation
|
||||
rotation_count = db.Column(db.Integer, default=0, nullable=False)
|
||||
|
||||
# Request metadata
|
||||
ip_address = db.Column(db.String(45), nullable=True)
|
||||
user_agent = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
client = db.relationship("OIDCClient", back_populates="refresh_tokens")
|
||||
user = db.relationship("User", back_populates="oidc_refresh_tokens")
|
||||
access_token = db.relationship("Session", back_populates="oidc_refresh_token")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCRefreshToken."""
|
||||
return f"<OIDCRefreshToken client_id={self.client_id} user_id={self.user_id} revoked={self.is_revoked()}>"
|
||||
|
||||
def is_expired(self):
|
||||
"""Check if the refresh token has expired."""
|
||||
# Handle both timezone-aware and timezone-naive expires_at values
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return datetime.now(timezone.utc) > expires_at
|
||||
|
||||
def is_revoked(self):
|
||||
"""Check if the refresh token has been revoked."""
|
||||
return self.revoked_at is not None
|
||||
|
||||
def is_valid(self):
|
||||
"""Check if the refresh token is valid for use."""
|
||||
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
|
||||
|
||||
def revoke(self, reason=None):
|
||||
"""Revoke the refresh token.
|
||||
|
||||
Args:
|
||||
reason: Optional reason for revocation
|
||||
"""
|
||||
self.revoked_at = datetime.now(timezone.utc)
|
||||
self.revoked_reason = reason
|
||||
db.session.commit()
|
||||
|
||||
def rotate(self, new_token_hash):
|
||||
"""Rotate the refresh token (invalidate old, create new).
|
||||
|
||||
Args:
|
||||
new_token_hash: Hash of the new refresh token
|
||||
|
||||
Returns:
|
||||
self for chaining
|
||||
"""
|
||||
# Store reference to old token
|
||||
self.previous_token_hash = self.token_hash
|
||||
self.token_hash = new_token_hash
|
||||
self.rotation_count += 1
|
||||
# Extend expiration on rotation
|
||||
from datetime import timedelta
|
||||
self.expires_at = datetime.now(timezone.utc) + timedelta(days=30)
|
||||
db.session.commit()
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def create_token(cls, client_id, user_id, token_hash, scope=None,
|
||||
access_token_id=None, ip_address=None, user_agent=None,
|
||||
lifetime_seconds=2592000):
|
||||
"""Create a new refresh token.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID
|
||||
token_hash: Hashed refresh token
|
||||
scope: Granted scopes
|
||||
access_token_id: Associated access token ID
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
lifetime_seconds: Token lifetime in seconds (default 30 days)
|
||||
|
||||
Returns:
|
||||
OIDCRefreshToken instance
|
||||
"""
|
||||
from datetime import timedelta
|
||||
token = cls(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
token_hash=token_hash,
|
||||
scope=scope,
|
||||
access_token_id=access_token_id,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
db.session.add(token)
|
||||
db.session.commit()
|
||||
return token
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude token hashes
|
||||
exclude.append("token_hash")
|
||||
exclude.append("previous_token_hash")
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from gatehouse_app.models.user import User
|
||||
User.oidc_refresh_tokens = db.relationship(
|
||||
"OIDCRefreshToken", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
OIDCClient.refresh_tokens = db.relationship(
|
||||
"OIDCRefreshToken", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to Session model
|
||||
from gatehouse_app.models.session import Session
|
||||
Session.oidc_refresh_token = db.relationship(
|
||||
"OIDCRefreshToken", back_populates="access_token", uselist=False
|
||||
)
|
||||
@@ -0,0 +1,162 @@
|
||||
"""OIDC Session model for OIDC session tracking."""
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OIDCSession(BaseModel):
|
||||
"""OIDC Session model for tracking OIDC authentication sessions.
|
||||
|
||||
This model tracks the state during the OIDC authentication flow,
|
||||
including PKCE parameters and nonce validation.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_sessions"
|
||||
|
||||
# User reference
|
||||
user_id = db.Column(
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# Client reference
|
||||
client_id = db.Column(
|
||||
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# State management
|
||||
state = db.Column(db.String(255), nullable=False, index=True)
|
||||
nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation
|
||||
|
||||
# Authorization request parameters
|
||||
redirect_uri = db.Column(db.String(512), nullable=False)
|
||||
scope = db.Column(db.JSON, nullable=True) # Requested scopes
|
||||
|
||||
# PKCE parameters
|
||||
code_challenge = db.Column(db.String(255), nullable=True)
|
||||
code_challenge_method = db.Column(db.String(10), nullable=True) # "S256" or "plain"
|
||||
|
||||
# Timing
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
authenticated_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship("User", back_populates="oidc_sessions")
|
||||
client = db.relationship("OIDCClient", back_populates="oidc_sessions")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCSession."""
|
||||
return f"<OIDCSession user_id={self.user_id} client_id={self.client_id} state={self.state[:8]}...>"
|
||||
|
||||
def is_expired(self):
|
||||
"""Check if the OIDC session has expired."""
|
||||
return datetime.now(timezone.utc) > self.expires_at
|
||||
|
||||
def is_authenticated(self):
|
||||
"""Check if the user has been authenticated in this session."""
|
||||
return self.authenticated_at is not None
|
||||
|
||||
def mark_authenticated(self):
|
||||
"""Mark the session as authenticated."""
|
||||
self.authenticated_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
def validate_nonce(self, expected_nonce):
|
||||
"""Validate the nonce matches the expected value.
|
||||
|
||||
Args:
|
||||
expected_nonce: The expected nonce value
|
||||
|
||||
Returns:
|
||||
bool: True if nonce matches
|
||||
"""
|
||||
return self.nonce == expected_nonce
|
||||
|
||||
def validate_code_challenge(self, code_verifier):
|
||||
"""Validate the code verifier against the stored code challenge.
|
||||
|
||||
Args:
|
||||
code_verifier: The PKCE code verifier
|
||||
|
||||
Returns:
|
||||
bool: True if code challenge is valid
|
||||
"""
|
||||
if not self.code_challenge:
|
||||
return False
|
||||
|
||||
if self.code_challenge_method == "S256":
|
||||
import hashlib
|
||||
import base64
|
||||
# SHA256 hash of code_verifier
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
# Base64 URL encode without padding
|
||||
expected = base64.urlsafe_b64encode(digest).decode().rstrip("=")
|
||||
return self.code_challenge == expected
|
||||
elif self.code_challenge_method == "plain":
|
||||
return self.code_challenge == code_verifier
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def create_session(cls, user_id, client_id, state, redirect_uri, scope=None,
|
||||
nonce=None, code_challenge=None, code_challenge_method=None,
|
||||
lifetime_seconds=600):
|
||||
"""Create a new OIDC session.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
client_id: The OIDC client ID
|
||||
state: The state parameter
|
||||
redirect_uri: The redirect URI
|
||||
scope: Requested scopes
|
||||
nonce: OIDC nonce
|
||||
code_challenge: PKCE code challenge
|
||||
code_challenge_method: PKCE method ("S256" or "plain")
|
||||
lifetime_seconds: Session lifetime in seconds
|
||||
|
||||
Returns:
|
||||
OIDCSession instance
|
||||
"""
|
||||
from datetime import timedelta
|
||||
session = cls(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
state=state,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
nonce=nonce,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
|
||||
)
|
||||
db.session.add(session)
|
||||
db.session.commit()
|
||||
return session
|
||||
|
||||
@classmethod
|
||||
def get_by_state(cls, state):
|
||||
"""Get a session by state parameter.
|
||||
|
||||
Args:
|
||||
state: The state parameter
|
||||
|
||||
Returns:
|
||||
OIDCSession instance or None
|
||||
"""
|
||||
return cls.query.filter_by(state=state, deleted_at=None).first()
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from gatehouse_app.models.user import User
|
||||
User.oidc_sessions = db.relationship(
|
||||
"OIDCSession", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
OIDCClient.oidc_sessions = db.relationship(
|
||||
"OIDCSession", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -0,0 +1,196 @@
|
||||
"""OIDC Token Metadata model for token revocation tracking."""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OIDCTokenMetadata(BaseModel):
|
||||
"""OIDC Token Metadata model for tracking issued tokens.
|
||||
|
||||
This model stores metadata about issued tokens (access tokens, refresh tokens, ID tokens)
|
||||
for the purpose of token revocation. The id field matches the JTI (JWT ID) claim.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_token_metadata"
|
||||
|
||||
# Token identifier (matches JTI in JWT)
|
||||
id = db.Column(
|
||||
db.String(36), primary_key=True, default=lambda: str(uuid.uuid4())
|
||||
)
|
||||
|
||||
# Client and User references
|
||||
client_id = db.Column(
|
||||
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
|
||||
)
|
||||
user_id = db.Column(
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# Token type
|
||||
token_type = db.Column(db.String(50), nullable=False) # "access_token", "refresh_token", "id_token"
|
||||
|
||||
# Token identifier for revocation lookup
|
||||
token_jti = db.Column(db.String(255), nullable=False, index=True) # JWT ID claim
|
||||
|
||||
# Timing
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
|
||||
# Revocation tracking
|
||||
revoked_at = db.Column(db.DateTime, nullable=True)
|
||||
revoked_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Relationships
|
||||
client = db.relationship("OIDCClient", back_populates="token_metadata")
|
||||
user = db.relationship("User", back_populates="oidc_token_metadata")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCTokenMetadata."""
|
||||
return f"<OIDCTokenMetadata jti={self.token_jti[:8]}... type={self.token_type} revoked={self.is_revoked()}>"
|
||||
|
||||
def is_expired(self):
|
||||
"""Check if the token has expired."""
|
||||
# Handle both timezone-aware and timezone-naive expires_at values
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return datetime.now(timezone.utc) > expires_at
|
||||
|
||||
def is_revoked(self):
|
||||
"""Check if the token has been revoked."""
|
||||
return self.revoked_at is not None
|
||||
|
||||
def is_valid(self):
|
||||
"""Check if the token is valid (not expired and not revoked)."""
|
||||
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
|
||||
|
||||
def revoke(self, reason=None):
|
||||
"""Revoke the token.
|
||||
|
||||
Args:
|
||||
reason: Optional reason for revocation
|
||||
"""
|
||||
self.revoked_at = datetime.now(timezone.utc)
|
||||
self.revoked_reason = reason
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def create_metadata(cls, client_id, user_id, token_type, token_jti,
|
||||
expires_at, ip_address=None, user_agent=None):
|
||||
"""Create token metadata for tracking.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID
|
||||
token_type: Type of token ("access_token", "refresh_token", "id_token")
|
||||
token_jti: JWT ID claim
|
||||
expires_at: Token expiration datetime
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
|
||||
Returns:
|
||||
OIDCTokenMetadata instance
|
||||
"""
|
||||
metadata = cls(
|
||||
id=str(uuid.uuid4()),
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
token_type=token_type,
|
||||
token_jti=token_jti,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.session.add(metadata)
|
||||
db.session.commit()
|
||||
return metadata
|
||||
|
||||
@classmethod
|
||||
def get_by_jti(cls, token_jti):
|
||||
"""Get token metadata by JWT ID.
|
||||
|
||||
Args:
|
||||
token_jti: The JWT ID
|
||||
|
||||
Returns:
|
||||
OIDCTokenMetadata instance or None
|
||||
"""
|
||||
return cls.query.filter_by(token_jti=token_jti, deleted_at=None).first()
|
||||
|
||||
@classmethod
|
||||
def revoke_by_jti(cls, token_jti, reason=None):
|
||||
"""Revoke a token by its JWT ID.
|
||||
|
||||
Args:
|
||||
token_jti: The JWT ID
|
||||
reason: Optional revocation reason
|
||||
|
||||
Returns:
|
||||
bool: True if token was found and revoked
|
||||
"""
|
||||
metadata = cls.get_by_jti(token_jti)
|
||||
if metadata:
|
||||
metadata.revoke(reason)
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def revoke_all_for_user(cls, user_id, client_id=None, reason=None):
|
||||
"""Revoke all tokens for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
client_id: Optional client ID to filter by
|
||||
reason: Optional revocation reason
|
||||
|
||||
Returns:
|
||||
int: Number of tokens revoked
|
||||
"""
|
||||
query = cls.query.filter_by(user_id=user_id, deleted_at=None)
|
||||
if client_id:
|
||||
query = query.filter_by(client_id=client_id)
|
||||
|
||||
tokens = query.filter(cls.revoked_at == None).all()
|
||||
count = 0
|
||||
for token in tokens:
|
||||
token.revoke(reason)
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@classmethod
|
||||
def revoke_all_for_client(cls, client_id, user_id=None, reason=None):
|
||||
"""Revoke all tokens for a client.
|
||||
|
||||
Args:
|
||||
client_id: The client ID
|
||||
user_id: Optional user ID to filter by
|
||||
reason: Optional revocation reason
|
||||
|
||||
Returns:
|
||||
int: Number of tokens revoked
|
||||
"""
|
||||
query = cls.query.filter_by(client_id=client_id, deleted_at=None)
|
||||
if user_id:
|
||||
query = query.filter_by(user_id=user_id)
|
||||
|
||||
tokens = query.filter(cls.revoked_at == None).all()
|
||||
count = 0
|
||||
for token in tokens:
|
||||
token.revoke(reason)
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from gatehouse_app.models.user import User
|
||||
User.oidc_token_metadata = db.relationship(
|
||||
"OIDCTokenMetadata", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
OIDCClient.token_metadata = db.relationship(
|
||||
"OIDCTokenMetadata", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -0,0 +1,54 @@
|
||||
"""Organization model."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class Organization(BaseModel):
|
||||
"""Organization model representing a tenant/workspace."""
|
||||
|
||||
__tablename__ = "organizations"
|
||||
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
slug = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
description = db.Column(db.Text, nullable=True)
|
||||
logo_url = db.Column(db.String(512), nullable=True)
|
||||
is_active = db.Column(db.Boolean, default=True, nullable=False)
|
||||
|
||||
# Settings (stored as JSON)
|
||||
settings = db.Column(db.JSON, nullable=True, default=dict)
|
||||
|
||||
# Relationships
|
||||
members = db.relationship(
|
||||
"OrganizationMember", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
oidc_clients = db.relationship(
|
||||
"OIDCClient", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of Organization."""
|
||||
return f"<Organization {self.name}>"
|
||||
|
||||
def get_member_count(self):
|
||||
"""Get the count of active members in the organization."""
|
||||
return len([m for m in self.members if m.deleted_at is None])
|
||||
|
||||
def get_owner(self):
|
||||
"""Get the owner of the organization."""
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
|
||||
for member in self.members:
|
||||
if member.role == OrganizationRole.OWNER and member.deleted_at is None:
|
||||
return member.user
|
||||
return None
|
||||
|
||||
def is_member(self, user_id):
|
||||
"""Check if a user is a member of the organization."""
|
||||
from gatehouse_app.models.organization_member import OrganizationMember
|
||||
|
||||
return (
|
||||
OrganizationMember.query.filter_by(
|
||||
user_id=user_id, organization_id=self.id, deleted_at=None
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
@@ -0,0 +1,51 @@
|
||||
"""Organization member model."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
|
||||
|
||||
class OrganizationMember(BaseModel):
|
||||
"""Organization member model representing user membership in an organization."""
|
||||
|
||||
__tablename__ = "organization_members"
|
||||
|
||||
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=False, index=True)
|
||||
organization_id = db.Column(
|
||||
db.String(36), db.ForeignKey("organizations.id"), nullable=False, index=True
|
||||
)
|
||||
role = db.Column(
|
||||
db.Enum(OrganizationRole), default=OrganizationRole.MEMBER, nullable=False
|
||||
)
|
||||
invited_by_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True)
|
||||
invited_at = db.Column(db.DateTime, nullable=True)
|
||||
joined_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship("User", foreign_keys=[user_id], back_populates="organization_memberships")
|
||||
organization = db.relationship("Organization", back_populates="members")
|
||||
invited_by = db.relationship("User", foreign_keys=[invited_by_id])
|
||||
|
||||
# Unique constraint to prevent duplicate memberships
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint("user_id", "organization_id", name="uix_user_org"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OrganizationMember."""
|
||||
return f"<OrganizationMember user_id={self.user_id} org_id={self.organization_id} role={self.role}>"
|
||||
|
||||
def is_owner(self):
|
||||
"""Check if member is an owner."""
|
||||
return self.role == OrganizationRole.OWNER
|
||||
|
||||
def is_admin(self):
|
||||
"""Check if member is an admin or owner."""
|
||||
return self.role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
||||
|
||||
def can_manage_members(self):
|
||||
"""Check if member can manage other members."""
|
||||
return self.is_admin()
|
||||
|
||||
def can_delete_organization(self):
|
||||
"""Check if member can delete the organization."""
|
||||
return self.is_owner()
|
||||
@@ -0,0 +1,86 @@
|
||||
"""Session model."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import SessionStatus
|
||||
|
||||
|
||||
class Session(BaseModel):
|
||||
"""Session model for tracking user sessions."""
|
||||
|
||||
__tablename__ = "sessions"
|
||||
|
||||
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=False, index=True)
|
||||
token = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
status = db.Column(db.Enum(SessionStatus), default=SessionStatus.ACTIVE, nullable=False)
|
||||
|
||||
# Session metadata
|
||||
ip_address = db.Column(db.String(45), nullable=True)
|
||||
user_agent = db.Column(db.Text, nullable=True)
|
||||
device_info = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# Timing
|
||||
expires_at = db.Column(db.DateTime, nullable=False)
|
||||
last_activity_at = db.Column(db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc))
|
||||
revoked_at = db.Column(db.DateTime, nullable=True)
|
||||
revoked_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship("User", back_populates="sessions")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of Session."""
|
||||
return f"<Session user_id={self.user_id} status={self.status}>"
|
||||
|
||||
def is_active(self):
|
||||
"""Check if session is currently active."""
|
||||
now = datetime.now(timezone.utc)
|
||||
# Make expires_at timezone-aware if it's naive
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return (
|
||||
self.status == SessionStatus.ACTIVE
|
||||
and expires_at > now
|
||||
and self.deleted_at is None
|
||||
)
|
||||
|
||||
def is_expired(self):
|
||||
"""Check if session has expired."""
|
||||
now = datetime.now(timezone.utc)
|
||||
# Make expires_at timezone-aware if it's naive
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return now > expires_at
|
||||
|
||||
def refresh(self, duration_seconds=86400):
|
||||
"""
|
||||
Refresh session expiration.
|
||||
|
||||
Args:
|
||||
duration_seconds: New session duration in seconds
|
||||
"""
|
||||
self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=duration_seconds)
|
||||
self.last_activity_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
def revoke(self, reason=None):
|
||||
"""
|
||||
Revoke the session.
|
||||
|
||||
Args:
|
||||
reason: Optional reason for revocation
|
||||
"""
|
||||
self.status = SessionStatus.REVOKED
|
||||
self.revoked_at = datetime.now(timezone.utc)
|
||||
if reason:
|
||||
self.revoked_reason = reason
|
||||
db.session.commit()
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Exclude token from dict
|
||||
exclude.append("token")
|
||||
return super().to_dict(exclude=exclude)
|
||||
@@ -0,0 +1,141 @@
|
||||
"""User model."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import UserStatus
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
"""User model representing a user account."""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
email = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
email_verified = db.Column(db.Boolean, default=False, nullable=False)
|
||||
full_name = db.Column(db.String(255), nullable=True)
|
||||
avatar_url = db.Column(db.String(512), nullable=True)
|
||||
status = db.Column(
|
||||
db.Enum(UserStatus), default=UserStatus.ACTIVE, nullable=False, index=True
|
||||
)
|
||||
last_login_at = db.Column(db.DateTime, nullable=True)
|
||||
last_login_ip = db.Column(db.String(45), nullable=True)
|
||||
|
||||
# Relationships
|
||||
authentication_methods = db.relationship(
|
||||
"AuthenticationMethod", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
sessions = db.relationship("Session", back_populates="user", cascade="all, delete-orphan")
|
||||
organization_memberships = db.relationship(
|
||||
"OrganizationMember",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="OrganizationMember.user_id",
|
||||
)
|
||||
audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of User."""
|
||||
return f"<User {self.email}>"
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert user to dictionary, excluding sensitive fields by default."""
|
||||
exclude = exclude or []
|
||||
# Always exclude password-related fields
|
||||
default_exclude = []
|
||||
all_exclude = list(set(default_exclude + exclude))
|
||||
return super().to_dict(exclude=all_exclude)
|
||||
|
||||
def has_password_auth(self):
|
||||
"""Check if user has password authentication enabled."""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return (
|
||||
AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.PASSWORD, deleted_at=None
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def get_organizations(self):
|
||||
"""Get all organizations the user is a member of."""
|
||||
return [membership.organization for membership in self.organization_memberships]
|
||||
|
||||
def has_totp_enabled(self) -> bool:
|
||||
"""Check if user has TOTP enabled and verified.
|
||||
|
||||
Returns:
|
||||
True if user has a verified TOTP authentication method, False otherwise.
|
||||
"""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return (
|
||||
AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id,
|
||||
method_type=AuthMethodType.TOTP,
|
||||
verified=True,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def get_totp_method(self):
|
||||
"""Get user's TOTP authentication method.
|
||||
|
||||
Returns:
|
||||
The AuthenticationMethod instance for TOTP or None if not found.
|
||||
|
||||
Note:
|
||||
Returns the most recently created TOTP method to handle cases where
|
||||
multiple enrollment attempts may exist.
|
||||
"""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None
|
||||
).order_by(AuthenticationMethod.created_at.desc()).first()
|
||||
|
||||
def has_webauthn_enabled(self) -> bool:
|
||||
"""Check if user has any WebAuthn passkey credentials.
|
||||
|
||||
Returns:
|
||||
True if user has at least one WebAuthn credential, False otherwise.
|
||||
"""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return (
|
||||
AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id,
|
||||
method_type=AuthMethodType.WEBAUTHN,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def get_webauthn_credentials(self):
|
||||
"""Get all WebAuthn credentials for the user.
|
||||
|
||||
Returns:
|
||||
List of AuthenticationMethod instances for WebAuthn, ordered by creation date.
|
||||
"""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
|
||||
).order_by(AuthenticationMethod.created_at.desc()).all()
|
||||
|
||||
def get_webauthn_credential_count(self) -> int:
|
||||
"""Get the count of WebAuthn credentials for the user.
|
||||
|
||||
Returns:
|
||||
Number of WebAuthn credentials.
|
||||
"""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
|
||||
).count()
|
||||
Reference in New Issue
Block a user