Chore: Refractor Models into organized file/folder

This commit is contained in:
2026-03-01 12:40:48 +05:45
parent 58432da1c8
commit 07193a2d2e
35 changed files with 1475 additions and 932 deletions
+118 -44
View File
@@ -1,76 +1,150 @@
"""Models package.""" """Models package.
from gatehouse_app.models.base import BaseModel
from gatehouse_app.models.user import User Sub-packages
from gatehouse_app.models.organization import Organization ------------
from gatehouse_app.models.organization_member import OrganizationMember models.user — User, Session
from gatehouse_app.models.authentication_method import ( models.organization — Organization, OrganizationMember, Department,
DepartmentMembership, DepartmentPrincipal,
DepartmentCertPolicy, Principal, PrincipalMembership,
OrgInviteToken
models.auth — AuthenticationMethod, ApplicationProviderConfig,
OrganizationProviderOverride, OAuthState,
AuditLog, PasswordResetToken, EmailVerificationToken
models.oidc — OIDCClient, OIDCAuthCode, OIDCRefreshToken, OIDCSession,
OIDCTokenMetadata, OIDCAuditLog, OidcJwksKey
models.ssh_ca — CA, KeyType, CertType, CaType, CAPermission,
SSHKey, SSHCertificate, CertificateStatus,
CertificateAuditLog
models.security — OrganizationSecurityPolicy, UserSecurityPolicy,
MfaPolicyCompliance
All names are re-exported here so that existing code using the flat import
style (``from gatehouse_app.models import X``) or the old per-file style
(``from gatehouse_app.models.user import User``) continue to work unchanged.
"""
# ── Base ──────────────────────────────────────────────────────────────────────
from gatehouse_app.models.base import BaseModel # noqa: F401
# ── User ──────────────────────────────────────────────────────────────────────
from gatehouse_app.models.user.user import User # noqa: F401
from gatehouse_app.models.user.session import Session # noqa: F401
# ── Organization ──────────────────────────────────────────────────────────────
from gatehouse_app.models.organization.organization import Organization # noqa: F401
from gatehouse_app.models.organization.organization_member import ( # noqa: F401
OrganizationMember,
)
from gatehouse_app.models.organization.department import ( # noqa: F401
Department,
DepartmentMembership,
DepartmentPrincipal,
)
from gatehouse_app.models.organization.department_cert_policy import ( # noqa: F401
DepartmentCertPolicy,
STANDARD_EXTENSIONS,
)
from gatehouse_app.models.organization.principal import ( # noqa: F401
Principal,
PrincipalMembership,
)
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken # noqa: F401
# ── Auth ──────────────────────────────────────────────────────────────────────
from gatehouse_app.models.auth.authentication_method import ( # noqa: F401
AuthenticationMethod, AuthenticationMethod,
ApplicationProviderConfig, ApplicationProviderConfig,
OrganizationProviderOverride, OrganizationProviderOverride,
OAuthState, OAuthState,
) )
from gatehouse_app.models.session import Session from gatehouse_app.models.auth.audit_log import AuditLog # noqa: F401
from gatehouse_app.models.audit_log import AuditLog from gatehouse_app.models.auth.password_reset_token import PasswordResetToken # noqa: F401
from gatehouse_app.models.oidc_client import OIDCClient from gatehouse_app.models.auth.email_verification_token import ( # noqa: F401
from gatehouse_app.models.oidc_authorization_code import OIDCAuthCode EmailVerificationToken,
from gatehouse_app.models.oidc_refresh_token import OIDCRefreshToken
from gatehouse_app.models.oidc_session import OIDCSession
from gatehouse_app.models.oidc_token_metadata import OIDCTokenMetadata
from gatehouse_app.models.oidc_audit_log import OIDCAuditLog
from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
from gatehouse_app.models.user_security_policy import UserSecurityPolicy
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
from gatehouse_app.models.department import (
Department,
DepartmentMembership,
DepartmentPrincipal,
) )
from gatehouse_app.models.principal import (
Principal, # ── OIDC ──────────────────────────────────────────────────────────────────────
PrincipalMembership, from gatehouse_app.models.oidc.oidc_client import OIDCClient # noqa: F401
from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode # noqa: F401
from gatehouse_app.models.oidc.oidc_refresh_token import OIDCRefreshToken # noqa: F401
from gatehouse_app.models.oidc.oidc_session import OIDCSession # noqa: F401
from gatehouse_app.models.oidc.oidc_token_metadata import OIDCTokenMetadata # noqa: F401
from gatehouse_app.models.oidc.oidc_audit_log import OIDCAuditLog # noqa: F401
from gatehouse_app.models.oidc.oidc_jwks_key import OidcJwksKey # noqa: F401
# ── SSH / CA ──────────────────────────────────────────────────────────────────
from gatehouse_app.models.ssh_ca.ca import ( # noqa: F401
CA,
KeyType,
CertType,
CaType,
CAPermission,
)
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey # noqa: F401
from gatehouse_app.models.ssh_ca.ssh_certificate import ( # noqa: F401
SSHCertificate,
CertificateStatus,
)
from gatehouse_app.models.ssh_ca.certificate_audit_log import ( # noqa: F401
CertificateAuditLog,
)
# ── Security ──────────────────────────────────────────────────────────────────
from gatehouse_app.models.security.organization_security_policy import ( # noqa: F401
OrganizationSecurityPolicy,
)
from gatehouse_app.models.security.user_security_policy import ( # noqa: F401
UserSecurityPolicy,
)
from gatehouse_app.models.security.mfa_policy_compliance import ( # noqa: F401
MfaPolicyCompliance,
) )
from gatehouse_app.models.ssh_key import SSHKey
from gatehouse_app.models.ca import CA, KeyType, CertType, CAPermission
from gatehouse_app.models.ssh_certificate import SSHCertificate, CertificateStatus
from gatehouse_app.models.certificate_audit_log import CertificateAuditLog
from gatehouse_app.models.password_reset_token import PasswordResetToken
from gatehouse_app.models.email_verification_token import EmailVerificationToken
from gatehouse_app.models.org_invite_token import OrgInviteToken
__all__ = [ __all__ = [
# Base
"BaseModel", "BaseModel",
# User
"User", "User",
"Session",
# Organization
"Organization", "Organization",
"OrganizationMember", "OrganizationMember",
"Department",
"DepartmentMembership",
"DepartmentPrincipal",
"DepartmentCertPolicy",
"STANDARD_EXTENSIONS",
"Principal",
"PrincipalMembership",
"OrgInviteToken",
# Auth
"AuthenticationMethod", "AuthenticationMethod",
"ApplicationProviderConfig", "ApplicationProviderConfig",
"OrganizationProviderOverride", "OrganizationProviderOverride",
"OAuthState", "OAuthState",
"Session",
"AuditLog", "AuditLog",
"PasswordResetToken",
"EmailVerificationToken",
# OIDC
"OIDCClient", "OIDCClient",
"OIDCAuthCode", "OIDCAuthCode",
"OIDCRefreshToken", "OIDCRefreshToken",
"OIDCSession", "OIDCSession",
"OIDCTokenMetadata", "OIDCTokenMetadata",
"OIDCAuditLog", "OIDCAuditLog",
"OrganizationSecurityPolicy", "OidcJwksKey",
"UserSecurityPolicy", # SSH / CA
"MfaPolicyCompliance",
"Department",
"DepartmentMembership",
"DepartmentPrincipal",
"Principal",
"PrincipalMembership",
"SSHKey",
"CA", "CA",
"KeyType", "KeyType",
"CertType", "CertType",
"CaType",
"CAPermission", "CAPermission",
"SSHKey",
"SSHCertificate", "SSHCertificate",
"CertificateStatus", "CertificateStatus",
"CertificateAuditLog", "CertificateAuditLog",
"PasswordResetToken", # Security
"EmailVerificationToken", "OrganizationSecurityPolicy",
"OrgInviteToken", "UserSecurityPolicy",
"MfaPolicyCompliance",
] ]
+20
View File
@@ -0,0 +1,20 @@
"""Auth subpackage — authentication methods, tokens, and audit logs."""
from gatehouse_app.models.auth.authentication_method import (
AuthenticationMethod,
ApplicationProviderConfig,
OrganizationProviderOverride,
OAuthState,
)
from gatehouse_app.models.auth.audit_log import AuditLog
from gatehouse_app.models.auth.password_reset_token import PasswordResetToken
from gatehouse_app.models.auth.email_verification_token import EmailVerificationToken
__all__ = [
"AuthenticationMethod",
"ApplicationProviderConfig",
"OrganizationProviderOverride",
"OAuthState",
"AuditLog",
"PasswordResetToken",
"EmailVerificationToken",
]
@@ -26,14 +26,13 @@ class AuditLog(BaseModel):
extra_data = db.Column(db.JSON, nullable=True) extra_data = db.Column(db.JSON, nullable=True)
description = db.Column(db.Text, nullable=True) description = db.Column(db.Text, nullable=True)
# Success/failure # Outcome
success = db.Column(db.Boolean, default=True, nullable=False) success = db.Column(db.Boolean, default=True, nullable=False)
error_message = db.Column(db.Text, nullable=True) error_message = db.Column(db.Text, nullable=True)
# Relationships # Relationships
user = db.relationship("User", back_populates="audit_logs") user = db.relationship("User", back_populates="audit_logs")
# Indexes for common queries
__table_args__ = ( __table_args__ = (
db.Index("idx_audit_user_action", "user_id", "action"), db.Index("idx_audit_user_action", "user_id", "action"),
db.Index("idx_audit_resource", "resource_type", "resource_id"), db.Index("idx_audit_resource", "resource_type", "resource_id"),
@@ -45,9 +44,8 @@ class AuditLog(BaseModel):
return f"<AuditLog action={self.action} user_id={self.user_id}>" return f"<AuditLog action={self.action} user_id={self.user_id}>"
@classmethod @classmethod
def log(cls, action, user_id=None, **kwargs): def log(cls, action, user_id=None, **kwargs) -> "AuditLog":
""" """Create an audit log entry.
Create an audit log entry.
Args: Args:
action: AuditAction enum value action: AuditAction enum value
@@ -1,4 +1,4 @@
"""Authentication method model.""" """Authentication method model — user credentials and OAuth provider config."""
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
import secrets import secrets
from gatehouse_app.extensions import db from gatehouse_app.extensions import db
@@ -35,7 +35,6 @@ class AuthenticationMethod(BaseModel):
# Relationships # Relationships
user = db.relationship("User", back_populates="authentication_methods") user = db.relationship("User", back_populates="authentication_methods")
# Ensure unique provider combinations
__table_args__ = ( __table_args__ = (
db.Index("idx_user_method", "user_id", "method_type"), db.Index("idx_user_method", "user_id", "method_type"),
db.UniqueConstraint( db.UniqueConstraint(
@@ -45,13 +44,15 @@ class AuthenticationMethod(BaseModel):
def __repr__(self): def __repr__(self):
"""String representation of AuthenticationMethod.""" """String representation of AuthenticationMethod."""
return f"<AuthenticationMethod user_id={self.user_id} type={self.method_type}>" return (
f"<AuthenticationMethod user_id={self.user_id} type={self.method_type}>"
)
def is_password(self): def is_password(self) -> bool:
"""Check if this is a password authentication method.""" """Check if this is a password authentication method."""
return self.method_type == AuthMethodType.PASSWORD return self.method_type == AuthMethodType.PASSWORD
def is_oauth(self): def is_oauth(self) -> bool:
"""Check if this is an OAuth authentication method.""" """Check if this is an OAuth authentication method."""
return self.method_type in [ return self.method_type in [
AuthMethodType.GOOGLE, AuthMethodType.GOOGLE,
@@ -59,28 +60,28 @@ class AuthenticationMethod(BaseModel):
AuthMethodType.MICROSOFT, AuthMethodType.MICROSOFT,
] ]
def is_totp(self): def is_totp(self) -> bool:
"""Check if this is a TOTP authentication method.""" """Check if this is a TOTP authentication method."""
return self.method_type == AuthMethodType.TOTP return self.method_type == AuthMethodType.TOTP
def is_webauthn(self): def is_webauthn(self) -> bool:
"""Check if this is a WebAuthn authentication method.""" """Check if this is a WebAuthn authentication method."""
return self.method_type == AuthMethodType.WEBAUTHN return self.method_type == AuthMethodType.WEBAUTHN
def to_dict(self, exclude=None): def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields.""" """Convert to dictionary, excluding sensitive fields."""
exclude = exclude or [] exclude = exclude or []
# Always exclude password hash and TOTP secrets # Always exclude credential material
exclude.append("password_hash") for field in ("password_hash", "totp_secret", "totp_backup_codes"):
exclude.append("totp_secret") if field not in exclude:
exclude.append("totp_backup_codes") exclude.append(field)
return super().to_dict(exclude=exclude) return super().to_dict(exclude=exclude)
def to_webauthn_dict(self): def to_webauthn_dict(self):
"""Convert WebAuthn credential to public dictionary. """Convert WebAuthn credential to public dictionary.
Returns: Returns:
Dictionary with safe-to-expose credential information. Dictionary with safe-to-expose credential information, or None.
""" """
if not self.is_webauthn() or not self.provider_data: if not self.is_webauthn() or not self.provider_data:
return None return None
@@ -99,8 +100,8 @@ class AuthenticationMethod(BaseModel):
class ApplicationProviderConfig(BaseModel): class ApplicationProviderConfig(BaseModel):
"""Application-wide OAuth provider configuration. """Application-wide OAuth provider configuration.
This model stores OAuth provider credentials at the application level, Stores OAuth provider credentials at the application level, allowing users
allowing users to authenticate without needing to specify an organization first. to authenticate without needing to specify an organization first.
""" """
__tablename__ = "application_provider_configs" __tablename__ = "application_provider_configs"
@@ -108,7 +109,7 @@ class ApplicationProviderConfig(BaseModel):
# Provider identification # Provider identification
provider_type = db.Column(db.String(50), nullable=False, unique=True, index=True) provider_type = db.Column(db.String(50), nullable=False, unique=True, index=True)
# OAuth credentials (encrypted) # OAuth credentials (client_secret encrypted at rest)
client_id = db.Column(db.String(255), nullable=False) client_id = db.Column(db.String(255), nullable=False)
client_secret_encrypted = db.Column(db.String(512), nullable=True) client_secret_encrypted = db.Column(db.String(512), nullable=True)
@@ -126,15 +127,21 @@ class ApplicationProviderConfig(BaseModel):
"OrganizationProviderOverride", "OrganizationProviderOverride",
back_populates="application_config", back_populates="application_config",
foreign_keys="OrganizationProviderOverride.provider_type", foreign_keys="OrganizationProviderOverride.provider_type",
primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type", primaryjoin=(
cascade="all, delete-orphan" "ApplicationProviderConfig.provider_type"
"==OrganizationProviderOverride.provider_type"
),
cascade="all, delete-orphan",
) )
def __repr__(self): def __repr__(self):
"""String representation of ApplicationProviderConfig.""" """String representation of ApplicationProviderConfig."""
return f"<ApplicationProviderConfig provider={self.provider_type} enabled={self.is_enabled}>" return (
f"<ApplicationProviderConfig provider={self.provider_type} "
f"enabled={self.is_enabled}>"
)
def set_client_secret(self, plaintext_secret: str): def set_client_secret(self, plaintext_secret: str) -> None:
"""Encrypt and store client secret. """Encrypt and store client secret.
Args: Args:
@@ -143,11 +150,11 @@ class ApplicationProviderConfig(BaseModel):
if plaintext_secret: if plaintext_secret:
self.client_secret_encrypted = encrypt(plaintext_secret) self.client_secret_encrypted = encrypt(plaintext_secret)
def get_client_secret(self) -> str: def get_client_secret(self) -> str | None:
"""Decrypt and return client secret. """Decrypt and return client secret.
Returns: Returns:
The plaintext OAuth client secret The plaintext OAuth client secret, or None if not set.
""" """
if self.client_secret_encrypted: if self.client_secret_encrypted:
return decrypt(self.client_secret_encrypted) return decrypt(self.client_secret_encrypted)
@@ -156,7 +163,7 @@ class ApplicationProviderConfig(BaseModel):
def to_dict(self, exclude=None): def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields.""" """Convert to dictionary, excluding sensitive fields."""
exclude = exclude or [] exclude = exclude or []
# Always exclude encrypted client secret if "client_secret_encrypted" not in exclude:
exclude.append("client_secret_encrypted") exclude.append("client_secret_encrypted")
return super().to_dict(exclude=exclude) return super().to_dict(exclude=exclude)
@@ -164,20 +171,21 @@ class ApplicationProviderConfig(BaseModel):
class OrganizationProviderOverride(BaseModel): class OrganizationProviderOverride(BaseModel):
"""Organization-specific OAuth configuration overrides. """Organization-specific OAuth configuration overrides.
This model allows organizations to override application-level OAuth settings Allows organizations to override application-level OAuth settings for
for enterprise SSO scenarios or custom provider configurations. enterprise SSO scenarios or custom provider configurations.
""" """
__tablename__ = "organization_provider_overrides" __tablename__ = "organization_provider_overrides"
# References
organization_id = db.Column( organization_id = db.Column(
db.String(36), db.ForeignKey("organizations.id"), db.String(36),
nullable=False, index=True db.ForeignKey("organizations.id"),
nullable=False,
index=True,
) )
provider_type = db.Column(db.String(50), nullable=False, index=True) provider_type = db.Column(db.String(50), nullable=False, index=True)
# Override OAuth credentials (encrypted, nullable - only if overriding) # Override OAuth credentials (encrypted, nullable only set when overriding)
client_id = db.Column(db.String(255), nullable=True) client_id = db.Column(db.String(255), nullable=True)
client_secret_encrypted = db.Column(db.String(512), nullable=True) client_secret_encrypted = db.Column(db.String(512), nullable=True)
@@ -196,37 +204,33 @@ class OrganizationProviderOverride(BaseModel):
"ApplicationProviderConfig", "ApplicationProviderConfig",
back_populates="organization_overrides", back_populates="organization_overrides",
foreign_keys=[provider_type], foreign_keys=[provider_type],
primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type", primaryjoin=(
viewonly=True "ApplicationProviderConfig.provider_type"
"==OrganizationProviderOverride.provider_type"
),
viewonly=True,
) )
# Unique constraint on (organization_id, provider_type)
__table_args__ = ( __table_args__ = (
db.UniqueConstraint( db.UniqueConstraint(
"organization_id", "provider_type", "organization_id", "provider_type", name="uix_org_provider_type"
name="uix_org_provider_type"
), ),
) )
def __repr__(self): def __repr__(self):
"""String representation of OrganizationProviderOverride.""" """String representation of OrganizationProviderOverride."""
return f"<OrganizationProviderOverride org={self.organization_id} provider={self.provider_type}>" return (
f"<OrganizationProviderOverride org={self.organization_id} "
f"provider={self.provider_type}>"
)
def set_client_secret(self, plaintext_secret: str): def set_client_secret(self, plaintext_secret: str) -> None:
"""Encrypt and store client secret override. """Encrypt and store client secret override."""
Args:
plaintext_secret: The plaintext OAuth client secret
"""
if plaintext_secret: if plaintext_secret:
self.client_secret_encrypted = encrypt(plaintext_secret) self.client_secret_encrypted = encrypt(plaintext_secret)
def get_client_secret(self) -> str: def get_client_secret(self) -> str | None:
"""Decrypt and return client secret override. """Decrypt and return client secret override."""
Returns:
The plaintext OAuth client secret
"""
if self.client_secret_encrypted: if self.client_secret_encrypted:
return decrypt(self.client_secret_encrypted) return decrypt(self.client_secret_encrypted)
return None return None
@@ -234,7 +238,7 @@ class OrganizationProviderOverride(BaseModel):
def to_dict(self, exclude=None): def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields.""" """Convert to dictionary, excluding sensitive fields."""
exclude = exclude or [] exclude = exclude or []
# Always exclude encrypted client secret if "client_secret_encrypted" not in exclude:
exclude.append("client_secret_encrypted") exclude.append("client_secret_encrypted")
return super().to_dict(exclude=exclude) return super().to_dict(exclude=exclude)
@@ -242,9 +246,9 @@ class OrganizationProviderOverride(BaseModel):
class OAuthState(BaseModel): class OAuthState(BaseModel):
"""OAuth flow state tracking. """OAuth flow state tracking.
This model tracks OAuth authentication flow state, including PKCE parameters Tracks OAuth authentication flow state, including PKCE parameters and
and organization context (which is now optional to support login flows where organization context (which is optional to support login flows where the
the organization isn't known until after authentication). organization isn't known until after authentication).
""" """
__tablename__ = "oauth_states" __tablename__ = "oauth_states"
@@ -258,13 +262,12 @@ class OAuthState(BaseModel):
# Provider type # Provider type
provider_type = db.Column(db.String(50), nullable=False) provider_type = db.Column(db.String(50), nullable=False)
# User context (optional - not set for login/register flows) # User context (optional not set for login/register flows)
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True) user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True)
# Organization context (NOW OPTIONAL - for SSO discovery or post-auth) # Organization context (optional — for SSO discovery or post-auth)
organization_id = db.Column( organization_id = db.Column(
db.String(36), db.ForeignKey("organizations.id"), db.String(36), db.ForeignKey("organizations.id"), nullable=True, index=True
nullable=True, index=True
) )
# PKCE parameters # PKCE parameters
@@ -291,7 +294,10 @@ class OAuthState(BaseModel):
def __repr__(self): def __repr__(self):
"""String representation of OAuthState.""" """String representation of OAuthState."""
return f"<OAuthState state={self.state[:8]}... flow={self.flow_type} provider={self.provider_type}>" return (
f"<OAuthState state={self.state[:8]}... "
f"flow={self.flow_type} provider={self.provider_type}>"
)
@classmethod @classmethod
def create_state( def create_state(
@@ -306,9 +312,9 @@ class OAuthState(BaseModel):
code_challenge: str = None, code_challenge: str = None,
nonce: str = None, nonce: str = None,
extra_data: dict = None, extra_data: dict = None,
lifetime_seconds: int = 600 lifetime_seconds: int = 600,
): ) -> "OAuthState":
"""Create a new OAuth state with auto-generated state parameter. """Create a new OAuth state with an auto-generated state parameter.
Args: Args:
flow_type: Type of flow ("login", "register", "link") flow_type: Type of flow ("login", "register", "link")
@@ -342,7 +348,7 @@ class OAuthState(BaseModel):
nonce=nonce, nonce=nonce,
extra_data=extra_data, extra_data=extra_data,
expires_at=expires_at, expires_at=expires_at,
used=False used=False,
) )
oauth_state.save() oauth_state.save()
return oauth_state return oauth_state
@@ -351,22 +357,21 @@ class OAuthState(BaseModel):
"""Check if the OAuth state is still valid. """Check if the OAuth state is still valid.
Returns: Returns:
True if state hasn't expired and hasn't been used True if state hasn't expired and hasn't been used.
""" """
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
# Make expires_at timezone-aware if it's naive (database returns naive datetimes)
expires_at = self.expires_at expires_at = self.expires_at
if expires_at.tzinfo is None: if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc) expires_at = expires_at.replace(tzinfo=timezone.utc)
return not self.used and expires_at > now return not self.used and expires_at > now
def mark_used(self): def mark_used(self) -> None:
"""Mark the state as used to prevent replay attacks.""" """Mark the state as used to prevent replay attacks."""
self.used = True self.used = True
self.save() self.save()
@classmethod @classmethod
def cleanup_expired(cls): def cleanup_expired(cls) -> None:
"""Remove expired OAuth states.""" """Remove expired OAuth states."""
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
cls.query.filter(cls.expires_at < now).delete() cls.query.filter(cls.expires_at < now).delete()
@@ -375,6 +380,7 @@ class OAuthState(BaseModel):
def to_dict(self, exclude=None): def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields.""" """Convert to dictionary, excluding sensitive fields."""
exclude = exclude or [] exclude = exclude or []
# Exclude code_verifier as it's sensitive # code_verifier must never be exposed
if "code_verifier" not in exclude:
exclude.append("code_verifier") exclude.append("code_verifier")
return super().to_dict(exclude=exclude) return super().to_dict(exclude=exclude)
@@ -0,0 +1,68 @@
"""Email verification token model."""
import secrets
from datetime import datetime, timezone, timedelta
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class EmailVerificationToken(BaseModel):
"""Single-use token for verifying a user's email address."""
__tablename__ = "email_verification_tokens"
user_id = db.Column(
db.String(36),
db.ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
token = db.Column(db.String(128), unique=True, nullable=False, index=True)
expires_at = db.Column(db.DateTime, nullable=False)
used_at = db.Column(db.DateTime, nullable=True)
user = db.relationship(
"User",
backref=db.backref("email_verification_tokens", cascade="all, delete-orphan"),
)
@classmethod
def generate(cls, user_id: str, ttl_hours: int = 24) -> "EmailVerificationToken":
"""Create a new verification token for a user.
Any existing unused tokens for this user are invalidated first.
"""
cls.query.filter_by(user_id=user_id, used_at=None).delete()
db.session.flush()
token_value = secrets.token_urlsafe(48)
instance = cls(
user_id=user_id,
token=token_value,
expires_at=datetime.now(timezone.utc) + timedelta(hours=ttl_hours),
)
db.session.add(instance)
db.session.commit()
return instance
@property
def is_valid(self) -> bool:
"""Return True if the token has not been used and has not expired."""
if self.used_at is not None:
return False
now = datetime.now(timezone.utc)
expires = self.expires_at
if expires.tzinfo is None:
expires = expires.replace(tzinfo=timezone.utc)
return now < expires
def consume(self) -> None:
"""Mark the token as used."""
self.used_at = datetime.now(timezone.utc)
db.session.commit()
def __repr__(self) -> str:
return (
f"<EmailVerificationToken user_id={self.user_id} "
f"used={self.used_at is not None}>"
)
@@ -0,0 +1,69 @@
"""Password reset token model."""
import secrets
from datetime import datetime, timezone, timedelta
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class PasswordResetToken(BaseModel):
"""Single-use token for resetting a user's password."""
__tablename__ = "password_reset_tokens"
user_id = db.Column(
db.String(36),
db.ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
token = db.Column(db.String(128), unique=True, nullable=False, index=True)
expires_at = db.Column(db.DateTime, nullable=False)
used_at = db.Column(db.DateTime, nullable=True)
user = db.relationship(
"User",
backref=db.backref("password_reset_tokens", cascade="all, delete-orphan"),
)
@classmethod
def generate(cls, user_id: str, ttl_hours: int = 2) -> "PasswordResetToken":
"""Create a new password reset token for a user.
Any existing unused tokens for this user are invalidated first.
"""
# Invalidate any existing unused tokens for this user
cls.query.filter_by(user_id=user_id, used_at=None).delete()
db.session.flush()
token_value = secrets.token_urlsafe(48)
instance = cls(
user_id=user_id,
token=token_value,
expires_at=datetime.now(timezone.utc) + timedelta(hours=ttl_hours),
)
db.session.add(instance)
db.session.commit()
return instance
@property
def is_valid(self) -> bool:
"""Return True if the token has not been used and has not expired."""
if self.used_at is not None:
return False
now = datetime.now(timezone.utc)
expires = self.expires_at
if expires.tzinfo is None:
expires = expires.replace(tzinfo=timezone.utc)
return now < expires
def consume(self) -> None:
"""Mark the token as used."""
self.used_at = datetime.now(timezone.utc)
db.session.commit()
def __repr__(self) -> str:
return (
f"<PasswordResetToken user_id={self.user_id} "
f"used={self.used_at is not None}>"
)
+18
View File
@@ -0,0 +1,18 @@
"""OIDC subpackage — clients, tokens, sessions, and audit logs."""
from gatehouse_app.models.oidc.oidc_client import OIDCClient
from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode
from gatehouse_app.models.oidc.oidc_refresh_token import OIDCRefreshToken
from gatehouse_app.models.oidc.oidc_session import OIDCSession
from gatehouse_app.models.oidc.oidc_token_metadata import OIDCTokenMetadata
from gatehouse_app.models.oidc.oidc_audit_log import OIDCAuditLog
from gatehouse_app.models.oidc.oidc_jwks_key import OidcJwksKey
__all__ = [
"OIDCClient",
"OIDCAuthCode",
"OIDCRefreshToken",
"OIDCSession",
"OIDCTokenMetadata",
"OIDCAuditLog",
"OidcJwksKey",
]
@@ -1,5 +1,4 @@
"""OIDC Audit Log model for comprehensive OIDC event tracking.""" """OIDC Audit Log model for comprehensive OIDC event tracking."""
from datetime import datetime
from gatehouse_app.extensions import db from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel from gatehouse_app.models.base import BaseModel
@@ -7,8 +6,7 @@ from gatehouse_app.models.base import BaseModel
class OIDCAuditLog(BaseModel): class OIDCAuditLog(BaseModel):
"""OIDC Audit Log model for comprehensive OIDC event tracking. """OIDC Audit Log model for comprehensive OIDC event tracking.
This model logs all OIDC-related events for security, compliance, Logs all OIDC-related events for security, compliance, and debugging.
and debugging purposes.
""" """
__tablename__ = "oidc_audit_logs" __tablename__ = "oidc_audit_logs"
@@ -46,16 +44,29 @@ class OIDCAuditLog(BaseModel):
def __repr__(self): def __repr__(self):
"""String representation of OIDCAuditLog.""" """String representation of OIDCAuditLog."""
status = "success" if self.success else "failed" status = "success" if self.success else "failed"
return f"<OIDCAuditLog event={self.event_type} status={status} client={self.client_id}>" return (
f"<OIDCAuditLog event={self.event_type} "
f"status={status} client={self.client_id}>"
)
@classmethod @classmethod
def log_event(cls, event_type, client_id=None, user_id=None, success=True, def log_event(
error_code=None, error_description=None, ip_address=None, cls,
user_agent=None, request_id=None, event_metadata=None): event_type: str,
client_id: str = None,
user_id: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None,
ip_address: str = None,
user_agent: str = None,
request_id: str = None,
event_metadata: dict = None,
) -> "OIDCAuditLog":
"""Log an OIDC event. """Log an OIDC event.
Args: Args:
event_type: Type of event (e.g., "authorization_request", "token_issue") event_type: Type of event (e.g., "authorization_request")
client_id: The OIDC client ID client_id: The OIDC client ID
user_id: The user ID user_id: The user ID
success: Whether the event was successful success: Whether the event was successful
@@ -86,9 +97,19 @@ class OIDCAuditLog(BaseModel):
return log return log
@classmethod @classmethod
def log_authorization_request(cls, client_id, user_id, redirect_uri, scope, def log_authorization_request(
ip_address=None, user_agent=None, request_id=None, cls,
success=True, error_code=None, error_description=None): client_id: str,
user_id: str,
redirect_uri: str,
scope,
ip_address: str = None,
user_agent: str = None,
request_id: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None,
) -> "OIDCAuditLog":
"""Log an authorization request event.""" """Log an authorization request event."""
return cls.log_event( return cls.log_event(
event_type="authorization_request", event_type="authorization_request",
@@ -100,15 +121,19 @@ class OIDCAuditLog(BaseModel):
ip_address=ip_address, ip_address=ip_address,
user_agent=user_agent, user_agent=user_agent,
request_id=request_id, request_id=request_id,
event_metadata={ event_metadata={"redirect_uri": redirect_uri, "scope": scope},
"redirect_uri": redirect_uri,
"scope": scope,
}
) )
@classmethod @classmethod
def log_token_issue(cls, client_id, user_id, token_type, def log_token_issue(
ip_address=None, user_agent=None, request_id=None): cls,
client_id: str,
user_id: str,
token_type: str,
ip_address: str = None,
user_agent: str = None,
request_id: str = None,
) -> "OIDCAuditLog":
"""Log a token issuance event.""" """Log a token issuance event."""
return cls.log_event( return cls.log_event(
event_type="token_issue", event_type="token_issue",
@@ -118,12 +143,20 @@ class OIDCAuditLog(BaseModel):
ip_address=ip_address, ip_address=ip_address,
user_agent=user_agent, user_agent=user_agent,
request_id=request_id, request_id=request_id,
event_metadata={"token_type": token_type} event_metadata={"token_type": token_type},
) )
@classmethod @classmethod
def log_token_revocation(cls, client_id, user_id, token_type, reason=None, def log_token_revocation(
ip_address=None, user_agent=None, request_id=None): cls,
client_id: str,
user_id: str,
token_type: str,
reason: str = None,
ip_address: str = None,
user_agent: str = None,
request_id: str = None,
) -> "OIDCAuditLog":
"""Log a token revocation event.""" """Log a token revocation event."""
return cls.log_event( return cls.log_event(
event_type="token_revocation", event_type="token_revocation",
@@ -133,15 +166,19 @@ class OIDCAuditLog(BaseModel):
ip_address=ip_address, ip_address=ip_address,
user_agent=user_agent, user_agent=user_agent,
request_id=request_id, request_id=request_id,
event_metadata={ event_metadata={"token_type": token_type, "reason": reason},
"token_type": token_type,
"reason": reason,
}
) )
@classmethod @classmethod
def log_authentication_failure(cls, client_id, error_code, error_description, def log_authentication_failure(
ip_address=None, user_agent=None, request_id=None): cls,
client_id: str,
error_code: str,
error_description: str,
ip_address: str = None,
user_agent: str = None,
request_id: str = None,
) -> "OIDCAuditLog":
"""Log an authentication failure event.""" """Log an authentication failure event."""
return cls.log_event( return cls.log_event(
event_type="authentication_failure", event_type="authentication_failure",
@@ -155,7 +192,7 @@ class OIDCAuditLog(BaseModel):
) )
@classmethod @classmethod
def get_events_for_user(cls, user_id, limit=100): def get_events_for_user(cls, user_id: str, limit: int = 100) -> list:
"""Get audit events for a user. """Get audit events for a user.
Args: Args:
@@ -165,13 +202,15 @@ class OIDCAuditLog(BaseModel):
Returns: Returns:
List of OIDCAuditLog instances List of OIDCAuditLog instances
""" """
return cls.query.filter_by(user_id=user_id, deleted_at=None)\ return (
.order_by(cls.created_at.desc())\ cls.query.filter_by(user_id=user_id, deleted_at=None)
.limit(limit)\ .order_by(cls.created_at.desc())
.limit(limit)
.all() .all()
)
@classmethod @classmethod
def get_events_for_client(cls, client_id, limit=100): def get_events_for_client(cls, client_id: str, limit: int = 100) -> list:
"""Get audit events for a client. """Get audit events for a client.
Args: Args:
@@ -181,14 +220,22 @@ class OIDCAuditLog(BaseModel):
Returns: Returns:
List of OIDCAuditLog instances List of OIDCAuditLog instances
""" """
return cls.query.filter_by(client_id=client_id, deleted_at=None)\ return (
.order_by(cls.created_at.desc())\ cls.query.filter_by(client_id=client_id, deleted_at=None)
.limit(limit)\ .order_by(cls.created_at.desc())
.limit(limit)
.all() .all()
)
@classmethod @classmethod
def get_failed_events(cls, client_id=None, user_id=None, start_date=None, def get_failed_events(
end_date=None, limit=100): cls,
client_id: str = None,
user_id: str = None,
start_date=None,
end_date=None,
limit: int = 100,
) -> list:
"""Get failed audit events. """Get failed audit events.
Args: Args:
@@ -210,22 +257,8 @@ class OIDCAuditLog(BaseModel):
query = query.filter(cls.created_at >= start_date) query = query.filter(cls.created_at >= start_date)
if end_date: if end_date:
query = query.filter(cls.created_at <= end_date) query = query.filter(cls.created_at <= end_date)
return query.order_by(cls.created_at.desc()).limit(limit).all() return query.order_by(cls.created_at.desc()).limit(limit).all()
def to_dict(self, exclude=None): def to_dict(self, exclude=None):
"""Convert to dictionary.""" """Convert to dictionary."""
return super().to_dict(exclude=exclude) return super().to_dict(exclude=exclude)
# Add relationship back to User model
from gatehouse_app.models.user import User
User.oidc_audit_logs = db.relationship(
"OIDCAuditLog", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from gatehouse_app.models.oidc_client import OIDCClient
OIDCClient.audit_logs = db.relationship(
"OIDCAuditLog", back_populates="client", cascade="all, delete-orphan"
)
@@ -1,14 +1,14 @@
"""OIDC Authorization Code model for auth code flow.""" """OIDC Authorization Code model for the authorization code grant flow."""
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from gatehouse_app.extensions import db from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel from gatehouse_app.models.base import BaseModel
class OIDCAuthCode(BaseModel): class OIDCAuthCode(BaseModel):
"""OIDC Authorization Code model for authorization code flow. """OIDC Authorization Code model for the authorization code grant flow.
Authorization codes are single-use, short-lived codes used in the Authorization codes are single-use, short-lived codes. The code itself is
authorization code grant flow. The code is hashed for security. hashed before storage so that a database breach cannot replay codes.
""" """
__tablename__ = "oidc_authorization_codes" __tablename__ = "oidc_authorization_codes"
@@ -26,9 +26,9 @@ class OIDCAuthCode(BaseModel):
# Request parameters # Request parameters
redirect_uri = db.Column(db.String(512), nullable=False) redirect_uri = db.Column(db.String(512), nullable=False)
scope = db.Column(db.JSON, nullable=True) # Requested scopes scope = db.Column(db.JSON, nullable=True)
nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation nonce = db.Column(db.String(255), nullable=True)
code_verifier = db.Column(db.String(255), nullable=True) # For PKCE code_verifier = db.Column(db.String(255), nullable=True)
# Status tracking # Status tracking
expires_at = db.Column(db.DateTime, nullable=False, index=True) expires_at = db.Column(db.DateTime, nullable=False, index=True)
@@ -39,37 +39,48 @@ class OIDCAuthCode(BaseModel):
ip_address = db.Column(db.String(45), nullable=True) ip_address = db.Column(db.String(45), nullable=True)
user_agent = db.Column(db.Text, nullable=True) user_agent = db.Column(db.Text, nullable=True)
# Relationships # Relationships — back_populates declared on User and OIDCClient
client = db.relationship("OIDCClient", back_populates="authorization_codes") client = db.relationship("OIDCClient", back_populates="authorization_codes")
user = db.relationship("User", back_populates="oidc_auth_codes") user = db.relationship("User", back_populates="oidc_auth_codes")
def __repr__(self): def __repr__(self):
"""String representation of OIDCAuthCode.""" """String representation of OIDCAuthCode."""
return f"<OIDCAuthCode client_id={self.client_id} user_id={self.user_id} used={self.is_used}>" return (
f"<OIDCAuthCode client_id={self.client_id} "
f"user_id={self.user_id} used={self.is_used}>"
)
def is_expired(self): def is_expired(self) -> bool:
"""Check if the authorization code has expired.""" """Check if the authorization code has expired."""
# Handle both timezone-aware and timezone-naive expires_at values
expires_at = self.expires_at expires_at = self.expires_at
if expires_at.tzinfo is None: if expires_at.tzinfo is None:
# Make naive datetime timezone-aware (UTC)
expires_at = expires_at.replace(tzinfo=timezone.utc) expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at return datetime.now(timezone.utc) > expires_at
def is_valid(self): def is_valid(self) -> bool:
"""Check if the authorization code is valid for use.""" """Check if the authorization code is valid for use."""
return not self.is_used and not self.is_expired() and self.deleted_at is None return not self.is_used and not self.is_expired() and self.deleted_at is None
def mark_as_used(self): def mark_as_used(self) -> None:
"""Mark the authorization code as used.""" """Mark the authorization code as used."""
self.is_used = True self.is_used = True
self.used_at = datetime.now(timezone.utc) self.used_at = datetime.now(timezone.utc)
db.session.commit() db.session.commit()
@classmethod @classmethod
def create_code(cls, client_id, user_id, code_hash, redirect_uri, scope=None, def create_code(
nonce=None, code_verifier=None, ip_address=None, user_agent=None, cls,
lifetime_seconds=600): client_id: str,
user_id: str,
code_hash: str,
redirect_uri: str,
scope=None,
nonce: str = None,
code_verifier: str = None,
ip_address: str = None,
user_agent: str = None,
lifetime_seconds: int = 600,
) -> "OIDCAuthCode":
"""Create a new authorization code. """Create a new authorization code.
Args: Args:
@@ -79,7 +90,7 @@ class OIDCAuthCode(BaseModel):
redirect_uri: The redirect URI redirect_uri: The redirect URI
scope: Requested scopes scope: Requested scopes
nonce: OIDC nonce nonce: OIDC nonce
code_verifier: PKCE code verifier code_verifier: PKCE code verifier (stored hashed server-side)
ip_address: Client IP address ip_address: Client IP address
user_agent: Client user agent user_agent: Client user agent
lifetime_seconds: Code lifetime in seconds (default 10 minutes) lifetime_seconds: Code lifetime in seconds (default 10 minutes)
@@ -106,20 +117,7 @@ class OIDCAuthCode(BaseModel):
def to_dict(self, exclude=None): def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields.""" """Convert to dictionary, excluding sensitive fields."""
exclude = exclude or [] exclude = exclude or []
# Always exclude code hash for field in ("code_hash", "code_verifier"):
exclude.append("code_hash") if field not in exclude:
exclude.append("code_verifier") exclude.append(field)
return super().to_dict(exclude=exclude) return super().to_dict(exclude=exclude)
# Add relationship back to User model
from gatehouse_app.models.user import User
User.oidc_auth_codes = db.relationship(
"OIDCAuthCode", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from gatehouse_app.models.oidc_client import OIDCClient
OIDCClient.authorization_codes = db.relationship(
"OIDCAuthCode", back_populates="client", cascade="all, delete-orphan"
)
@@ -17,10 +17,10 @@ class OIDCClient(BaseModel):
client_secret_hash = db.Column(db.String(255), nullable=False) client_secret_hash = db.Column(db.String(255), nullable=False)
# OAuth/OIDC configuration # OAuth/OIDC configuration
redirect_uris = db.Column(db.JSON, nullable=False) # List of allowed redirect URIs redirect_uris = db.Column(db.JSON, nullable=False) # Allowed redirect URIs
grant_types = db.Column(db.JSON, nullable=False) # List of allowed grant types grant_types = db.Column(db.JSON, nullable=False) # Allowed grant types
response_types = db.Column(db.JSON, nullable=False) # List of allowed response types response_types = db.Column(db.JSON, nullable=False) # Allowed response types
scopes = db.Column(db.JSON, nullable=False) # List of allowed scopes scopes = db.Column(db.JSON, nullable=False) # Allowed scopes
# Client metadata # Client metadata
logo_uri = db.Column(db.String(512), nullable=True) logo_uri = db.Column(db.String(512), nullable=True)
@@ -41,6 +41,23 @@ class OIDCClient(BaseModel):
# Relationships # Relationships
organization = db.relationship("Organization", back_populates="oidc_clients") organization = db.relationship("Organization", back_populates="oidc_clients")
# OIDC sub-resource relationships (declared here, not monkey-patched elsewhere)
authorization_codes = db.relationship(
"OIDCAuthCode", back_populates="client", cascade="all, delete-orphan"
)
refresh_tokens = db.relationship(
"OIDCRefreshToken", back_populates="client", cascade="all, delete-orphan"
)
oidc_sessions = db.relationship(
"OIDCSession", back_populates="client", cascade="all, delete-orphan"
)
token_metadata = db.relationship(
"OIDCTokenMetadata", back_populates="client", cascade="all, delete-orphan"
)
audit_logs = db.relationship(
"OIDCAuditLog", back_populates="client", cascade="all, delete-orphan"
)
def __repr__(self): def __repr__(self):
"""String representation of OIDCClient.""" """String representation of OIDCClient."""
return f"<OIDCClient {self.name} client_id={self.client_id}>" return f"<OIDCClient {self.name} client_id={self.client_id}>"
@@ -48,22 +65,22 @@ class OIDCClient(BaseModel):
def to_dict(self, exclude=None): def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields.""" """Convert to dictionary, excluding sensitive fields."""
exclude = exclude or [] exclude = exclude or []
# Always exclude client secret if "client_secret_hash" not in exclude:
exclude.append("client_secret_hash") exclude.append("client_secret_hash")
return super().to_dict(exclude=exclude) return super().to_dict(exclude=exclude)
def has_grant_type(self, grant_type): def has_grant_type(self, grant_type) -> bool:
"""Check if client supports a specific grant type.""" """Check if client supports a specific grant type."""
return grant_type in self.grant_types return grant_type in self.grant_types
def has_response_type(self, response_type): def has_response_type(self, response_type) -> bool:
"""Check if client supports a specific response type.""" """Check if client supports a specific response type."""
return response_type in self.response_types return response_type in self.response_types
def is_redirect_uri_allowed(self, redirect_uri): def is_redirect_uri_allowed(self, redirect_uri: str) -> bool:
"""Check if a redirect URI is allowed for this client.""" """Check if a redirect URI is allowed for this client."""
return redirect_uri in self.redirect_uris return redirect_uri in self.redirect_uris
def has_scope(self, scope): def has_scope(self, scope: str) -> bool:
"""Check if client is allowed to request a specific scope.""" """Check if client is allowed to request a specific scope."""
return scope in self.scopes return scope in self.scopes
@@ -0,0 +1,76 @@
"""OIDC JWKS Key model for persisting signing keys."""
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OidcJwksKey(BaseModel):
"""OIDC JWKS Key model for persisting JSON Web Key Set signing keys.
Stores RSA/ECDSA key pairs used for signing OIDC tokens. Multiple keys can
be stored to support key rotation scenarios.
Attributes:
kid: Unique key ID used in JWT ``kid`` header
key_type: Type of key (e.g., "RSA", "EC")
private_key: PEM-encoded private key (never exposed in API responses)
public_key: PEM-encoded public key
algorithm: Signing algorithm (e.g., "RS256", "ES256")
is_active: Whether this key is currently used for signing/verification
is_primary: Whether this is the primary signing key
expires_at: Optional expiry for key rotation enforcement
"""
__tablename__ = "oidc_jwks_keys"
# Override the default UUID id with integer primary key for JWKS key sets
id = db.Column(db.Integer, primary_key=True)
expires_at = db.Column(db.DateTime, nullable=True)
# Key identification and type
kid = db.Column(db.String(255), unique=True, nullable=False, index=True)
key_type = db.Column(db.String(50), nullable=False) # e.g., "RSA", "EC"
algorithm = db.Column(db.String(50), nullable=False) # e.g., "RS256", "ES256"
# Key material (PEM-encoded) — private_key must never be returned by API
private_key = db.Column(db.Text, nullable=False)
public_key = db.Column(db.Text, nullable=False)
# Key status
is_active = db.Column(db.Boolean, default=True, nullable=False)
is_primary = db.Column(db.Boolean, default=False, nullable=False)
def __repr__(self):
"""String representation of OidcJwksKey."""
return (
f"<OidcJwksKey kid={self.kid} "
f"key_type={self.key_type} algorithm={self.algorithm}>"
)
def to_dict(self, exclude_private_key: bool = True):
"""Convert model to dictionary.
Args:
exclude_private_key: If True (default), excludes the private key.
Returns:
Dictionary representation of the model
"""
exclude = ["private_key"] if exclude_private_key else []
return super().to_dict(exclude=exclude)
@classmethod
def get_active_keys(cls) -> list:
"""Get all active keys for signing operations."""
return cls.query.filter_by(is_active=True).all()
@classmethod
def get_primary_key(cls) -> "OidcJwksKey | None":
"""Get the primary signing key."""
return cls.query.filter_by(is_primary=True).first()
@classmethod
def get_key_by_kid(cls, kid: str) -> "OidcJwksKey | None":
"""Get an active key by its key ID."""
return cls.query.filter_by(kid=kid, is_active=True).first()
@@ -1,5 +1,5 @@
"""OIDC Refresh Token model for token rotation.""" """OIDC Refresh Token model for token rotation."""
from datetime import datetime, timezone from datetime import datetime, timedelta, timezone
from gatehouse_app.extensions import db from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel from gatehouse_app.models.base import BaseModel
@@ -8,7 +8,8 @@ class OIDCRefreshToken(BaseModel):
"""OIDC Refresh Token model for token refresh and rotation. """OIDC Refresh Token model for token refresh and rotation.
Refresh tokens are long-lived credentials used to obtain new access tokens. Refresh tokens are long-lived credentials used to obtain new access tokens.
They support token rotation for enhanced security. They support token rotation for enhanced security each use invalidates
the old token and issues a new one.
""" """
__tablename__ = "oidc_refresh_tokens" __tablename__ = "oidc_refresh_tokens"
@@ -21,16 +22,14 @@ class OIDCRefreshToken(BaseModel):
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
) )
# Token (hashed for security) # Token (hashed for security — never store plaintext refresh tokens)
token_hash = db.Column(db.String(255), nullable=False, unique=True, index=True) token_hash = db.Column(db.String(255), nullable=False, unique=True, index=True)
# Associated access token ID (stores JWT JTI string — no FK to sessions) # Associated access token JTI (no FK — stored as string for lightweight lookup)
access_token_id = db.Column( access_token_id = db.Column(db.String(255), nullable=True, index=True)
db.String(255), nullable=True, index=True
)
# Token scope # Token scope
scope = db.Column(db.JSON, nullable=True) # Granted scopes scope = db.Column(db.JSON, nullable=True)
# Timing # Timing
expires_at = db.Column(db.DateTime, nullable=False, index=True) expires_at = db.Column(db.DateTime, nullable=False, index=True)
@@ -40,7 +39,7 @@ class OIDCRefreshToken(BaseModel):
revoked_reason = db.Column(db.String(255), nullable=True) revoked_reason = db.Column(db.String(255), nullable=True)
# Token rotation metadata # Token rotation metadata
previous_token_hash = db.Column(db.String(255), nullable=True) # For rotation previous_token_hash = db.Column(db.String(255), nullable=True)
rotation_count = db.Column(db.Integer, default=0, nullable=False) rotation_count = db.Column(db.Integer, default=0, nullable=False)
# Request metadata # Request metadata
@@ -53,25 +52,27 @@ class OIDCRefreshToken(BaseModel):
def __repr__(self): def __repr__(self):
"""String representation of OIDCRefreshToken.""" """String representation of OIDCRefreshToken."""
return f"<OIDCRefreshToken client_id={self.client_id} user_id={self.user_id} revoked={self.is_revoked()}>" return (
f"<OIDCRefreshToken client_id={self.client_id} "
f"user_id={self.user_id} revoked={self.is_revoked()}>"
)
def is_expired(self): def is_expired(self) -> bool:
"""Check if the refresh token has expired.""" """Check if the refresh token has expired."""
# Handle both timezone-aware and timezone-naive expires_at values
expires_at = self.expires_at expires_at = self.expires_at
if expires_at.tzinfo is None: if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc) expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at return datetime.now(timezone.utc) > expires_at
def is_revoked(self): def is_revoked(self) -> bool:
"""Check if the refresh token has been revoked.""" """Check if the refresh token has been revoked."""
return self.revoked_at is not None return self.revoked_at is not None
def is_valid(self): def is_valid(self) -> bool:
"""Check if the refresh token is valid for use.""" """Check if the refresh token is valid for use."""
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
def revoke(self, reason=None): def revoke(self, reason: str = None) -> None:
"""Revoke the refresh token. """Revoke the refresh token.
Args: Args:
@@ -81,8 +82,8 @@ class OIDCRefreshToken(BaseModel):
self.revoked_reason = reason self.revoked_reason = reason
db.session.commit() db.session.commit()
def rotate(self, new_token_hash): def rotate(self, new_token_hash: str) -> "OIDCRefreshToken":
"""Rotate the refresh token (invalidate old, create new). """Rotate the refresh token invalidate the old hash, store the new one.
Args: Args:
new_token_hash: Hash of the new refresh token new_token_hash: Hash of the new refresh token
@@ -90,20 +91,25 @@ class OIDCRefreshToken(BaseModel):
Returns: Returns:
self for chaining self for chaining
""" """
# Store reference to old token
self.previous_token_hash = self.token_hash self.previous_token_hash = self.token_hash
self.token_hash = new_token_hash self.token_hash = new_token_hash
self.rotation_count += 1 self.rotation_count += 1
# Extend expiration on rotation
from datetime import timedelta
self.expires_at = datetime.now(timezone.utc) + timedelta(days=30) self.expires_at = datetime.now(timezone.utc) + timedelta(days=30)
db.session.commit() db.session.commit()
return self return self
@classmethod @classmethod
def create_token(cls, client_id, user_id, token_hash, scope=None, def create_token(
access_token_id=None, ip_address=None, user_agent=None, cls,
lifetime_seconds=2592000): client_id: str,
user_id: str,
token_hash: str,
scope=None,
access_token_id: str = None,
ip_address: str = None,
user_agent: str = None,
lifetime_seconds: int = 2592000,
) -> "OIDCRefreshToken":
"""Create a new refresh token. """Create a new refresh token.
Args: Args:
@@ -111,7 +117,7 @@ class OIDCRefreshToken(BaseModel):
user_id: The user ID user_id: The user ID
token_hash: Hashed refresh token token_hash: Hashed refresh token
scope: Granted scopes scope: Granted scopes
access_token_id: Associated access token ID access_token_id: Associated access token JTI
ip_address: Client IP address ip_address: Client IP address
user_agent: Client user agent user_agent: Client user agent
lifetime_seconds: Token lifetime in seconds (default 30 days) lifetime_seconds: Token lifetime in seconds (default 30 days)
@@ -119,7 +125,6 @@ class OIDCRefreshToken(BaseModel):
Returns: Returns:
OIDCRefreshToken instance OIDCRefreshToken instance
""" """
from datetime import timedelta
token = cls( token = cls(
client_id=client_id, client_id=client_id,
user_id=user_id, user_id=user_id,
@@ -137,20 +142,7 @@ class OIDCRefreshToken(BaseModel):
def to_dict(self, exclude=None): def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields.""" """Convert to dictionary, excluding sensitive fields."""
exclude = exclude or [] exclude = exclude or []
# Always exclude token hashes for field in ("token_hash", "previous_token_hash"):
exclude.append("token_hash") if field not in exclude:
exclude.append("previous_token_hash") exclude.append(field)
return super().to_dict(exclude=exclude) return super().to_dict(exclude=exclude)
# Add relationship back to User model
from gatehouse_app.models.user import User
User.oidc_refresh_tokens = db.relationship(
"OIDCRefreshToken", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from gatehouse_app.models.oidc_client import OIDCClient
OIDCClient.refresh_tokens = db.relationship(
"OIDCRefreshToken", back_populates="client", cascade="all, delete-orphan"
)
@@ -1,5 +1,7 @@
"""OIDC Session model for OIDC session tracking.""" """OIDC Session model for OIDC session tracking."""
from datetime import datetime, timezone import hashlib
import base64
from datetime import datetime, timedelta, timezone
from gatehouse_app.extensions import db from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel from gatehouse_app.models.base import BaseModel
@@ -7,8 +9,8 @@ from gatehouse_app.models.base import BaseModel
class OIDCSession(BaseModel): class OIDCSession(BaseModel):
"""OIDC Session model for tracking OIDC authentication sessions. """OIDC Session model for tracking OIDC authentication sessions.
This model tracks the state during the OIDC authentication flow, Tracks the state during the OIDC authorization flow, including PKCE
including PKCE parameters and nonce validation. parameters and nonce validation.
""" """
__tablename__ = "oidc_sessions" __tablename__ = "oidc_sessions"
@@ -25,11 +27,11 @@ class OIDCSession(BaseModel):
# State management # State management
state = db.Column(db.String(255), nullable=False, index=True) state = db.Column(db.String(255), nullable=False, index=True)
nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation nonce = db.Column(db.String(255), nullable=True)
# Authorization request parameters # Authorization request parameters
redirect_uri = db.Column(db.String(512), nullable=False) redirect_uri = db.Column(db.String(512), nullable=False)
scope = db.Column(db.JSON, nullable=True) # Requested scopes scope = db.Column(db.JSON, nullable=True)
# PKCE parameters # PKCE parameters
code_challenge = db.Column(db.String(255), nullable=True) code_challenge = db.Column(db.String(255), nullable=True)
@@ -45,50 +47,52 @@ class OIDCSession(BaseModel):
def __repr__(self): def __repr__(self):
"""String representation of OIDCSession.""" """String representation of OIDCSession."""
return f"<OIDCSession user_id={self.user_id} client_id={self.client_id} state={self.state[:8]}...>" return (
f"<OIDCSession user_id={self.user_id} "
f"client_id={self.client_id} state={self.state[:8]}...>"
)
def is_expired(self): def is_expired(self) -> bool:
"""Check if the OIDC session has expired.""" """Check if the OIDC session has expired."""
return datetime.now(timezone.utc) > self.expires_at expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
def is_authenticated(self): def is_authenticated(self) -> bool:
"""Check if the user has been authenticated in this session.""" """Check if the user has been authenticated in this session."""
return self.authenticated_at is not None return self.authenticated_at is not None
def mark_authenticated(self): def mark_authenticated(self) -> None:
"""Mark the session as authenticated.""" """Mark the session as authenticated."""
self.authenticated_at = datetime.now(timezone.utc) self.authenticated_at = datetime.now(timezone.utc)
db.session.commit() db.session.commit()
def validate_nonce(self, expected_nonce): def validate_nonce(self, expected_nonce: str) -> bool:
"""Validate the nonce matches the expected value. """Validate the nonce matches the expected value.
Args: Args:
expected_nonce: The expected nonce value expected_nonce: The expected nonce value
Returns: Returns:
bool: True if nonce matches True if nonce matches
""" """
return self.nonce == expected_nonce return self.nonce == expected_nonce
def validate_code_challenge(self, code_verifier): def validate_code_challenge(self, code_verifier: str) -> bool:
"""Validate the code verifier against the stored code challenge. """Validate the code verifier against the stored code challenge.
Args: Args:
code_verifier: The PKCE code verifier code_verifier: The PKCE code verifier
Returns: Returns:
bool: True if code challenge is valid True if the challenge is satisfied
""" """
if not self.code_challenge: if not self.code_challenge:
return False return False
if self.code_challenge_method == "S256": if self.code_challenge_method == "S256":
import hashlib
import base64
# SHA256 hash of code_verifier
digest = hashlib.sha256(code_verifier.encode()).digest() digest = hashlib.sha256(code_verifier.encode()).digest()
# Base64 URL encode without padding
expected = base64.urlsafe_b64encode(digest).decode().rstrip("=") expected = base64.urlsafe_b64encode(digest).decode().rstrip("=")
return self.code_challenge == expected return self.code_challenge == expected
elif self.code_challenge_method == "plain": elif self.code_challenge_method == "plain":
@@ -97,9 +101,18 @@ class OIDCSession(BaseModel):
return False return False
@classmethod @classmethod
def create_session(cls, user_id, client_id, state, redirect_uri, scope=None, def create_session(
nonce=None, code_challenge=None, code_challenge_method=None, cls,
lifetime_seconds=600): user_id: str,
client_id: str,
state: str,
redirect_uri: str,
scope=None,
nonce: str = None,
code_challenge: str = None,
code_challenge_method: str = None,
lifetime_seconds: int = 600,
) -> "OIDCSession":
"""Create a new OIDC session. """Create a new OIDC session.
Args: Args:
@@ -116,7 +129,6 @@ class OIDCSession(BaseModel):
Returns: Returns:
OIDCSession instance OIDCSession instance
""" """
from datetime import timedelta
session = cls( session = cls(
user_id=user_id, user_id=user_id,
client_id=client_id, client_id=client_id,
@@ -133,7 +145,7 @@ class OIDCSession(BaseModel):
return session return session
@classmethod @classmethod
def get_by_state(cls, state): def get_by_state(cls, state: str) -> "OIDCSession | None":
"""Get a session by state parameter. """Get a session by state parameter.
Args: Args:
@@ -147,16 +159,3 @@ class OIDCSession(BaseModel):
def to_dict(self, exclude=None): def to_dict(self, exclude=None):
"""Convert to dictionary.""" """Convert to dictionary."""
return super().to_dict(exclude=exclude) return super().to_dict(exclude=exclude)
# Add relationship back to User model
from gatehouse_app.models.user import User
User.oidc_sessions = db.relationship(
"OIDCSession", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from gatehouse_app.models.oidc_client import OIDCClient
OIDCClient.oidc_sessions = db.relationship(
"OIDCSession", back_populates="client", cascade="all, delete-orphan"
)
@@ -8,13 +8,14 @@ from gatehouse_app.models.base import BaseModel
class OIDCTokenMetadata(BaseModel): class OIDCTokenMetadata(BaseModel):
"""OIDC Token Metadata model for tracking issued tokens. """OIDC Token Metadata model for tracking issued tokens.
This model stores metadata about issued tokens (access tokens, refresh tokens, ID tokens) Stores metadata about issued tokens (access, refresh, ID) for revocation.
for the purpose of token revocation. The id field matches the JTI (JWT ID) claim. The ``id`` field on this model intentionally overrides the BaseModel UUID
to store the JWT JTI directly as the primary key for O(1) revocation checks.
""" """
__tablename__ = "oidc_token_metadata" __tablename__ = "oidc_token_metadata"
# Token identifier (matches JTI in JWT) # Primary key = JTI so revocation lookups are always a PK scan
id = db.Column( id = db.Column(
db.String(36), primary_key=True, default=lambda: str(uuid.uuid4()) db.String(36), primary_key=True, default=lambda: str(uuid.uuid4())
) )
@@ -27,11 +28,11 @@ class OIDCTokenMetadata(BaseModel):
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
) )
# Token type # Token type: "access_token", "refresh_token", or "id_token"
token_type = db.Column(db.String(50), nullable=False) # "access_token", "refresh_token", "id_token" token_type = db.Column(db.String(50), nullable=False)
# Token identifier for revocation lookup # JWT ID claim (indexed for fast lookup when id != jti)
token_jti = db.Column(db.String(255), nullable=False, index=True) # JWT ID claim token_jti = db.Column(db.String(255), nullable=False, index=True)
# Timing # Timing
expires_at = db.Column(db.DateTime, nullable=False, index=True) expires_at = db.Column(db.DateTime, nullable=False, index=True)
@@ -46,25 +47,27 @@ class OIDCTokenMetadata(BaseModel):
def __repr__(self): def __repr__(self):
"""String representation of OIDCTokenMetadata.""" """String representation of OIDCTokenMetadata."""
return f"<OIDCTokenMetadata jti={self.token_jti[:8]}... type={self.token_type} revoked={self.is_revoked()}>" return (
f"<OIDCTokenMetadata jti={self.token_jti[:8]}... "
f"type={self.token_type} revoked={self.is_revoked()}>"
)
def is_expired(self): def is_expired(self) -> bool:
"""Check if the token has expired.""" """Check if the token has expired."""
# Handle both timezone-aware and timezone-naive expires_at values
expires_at = self.expires_at expires_at = self.expires_at
if expires_at.tzinfo is None: if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc) expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at return datetime.now(timezone.utc) > expires_at
def is_revoked(self): def is_revoked(self) -> bool:
"""Check if the token has been revoked.""" """Check if the token has been revoked."""
return self.revoked_at is not None return self.revoked_at is not None
def is_valid(self): def is_valid(self) -> bool:
"""Check if the token is valid (not expired and not revoked).""" """Check if the token is valid (not expired and not revoked)."""
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
def revoke(self, reason=None): def revoke(self, reason: str = None) -> None:
"""Revoke the token. """Revoke the token.
Args: Args:
@@ -75,8 +78,16 @@ class OIDCTokenMetadata(BaseModel):
db.session.commit() db.session.commit()
@classmethod @classmethod
def create_metadata(cls, client_id, user_id, token_type, token_jti, def create_metadata(
expires_at, ip_address=None, user_agent=None): cls,
client_id: str,
user_id: str,
token_type: str,
token_jti: str,
expires_at,
ip_address: str = None,
user_agent: str = None,
) -> "OIDCTokenMetadata":
"""Create token metadata for tracking. """Create token metadata for tracking.
Args: Args:
@@ -85,8 +96,8 @@ class OIDCTokenMetadata(BaseModel):
token_type: Type of token ("access_token", "refresh_token", "id_token") token_type: Type of token ("access_token", "refresh_token", "id_token")
token_jti: JWT ID claim token_jti: JWT ID claim
expires_at: Token expiration datetime expires_at: Token expiration datetime
ip_address: Client IP address ip_address: Client IP address (unused column, kept for API compat)
user_agent: Client user agent user_agent: Client user agent (unused column, kept for API compat)
Returns: Returns:
OIDCTokenMetadata instance OIDCTokenMetadata instance
@@ -104,7 +115,7 @@ class OIDCTokenMetadata(BaseModel):
return metadata return metadata
@classmethod @classmethod
def get_by_jti(cls, token_jti): def get_by_jti(cls, token_jti: str) -> "OIDCTokenMetadata | None":
"""Get token metadata by JWT ID. """Get token metadata by JWT ID.
Args: Args:
@@ -116,7 +127,7 @@ class OIDCTokenMetadata(BaseModel):
return cls.query.filter_by(token_jti=token_jti, deleted_at=None).first() return cls.query.filter_by(token_jti=token_jti, deleted_at=None).first()
@classmethod @classmethod
def revoke_by_jti(cls, token_jti, reason=None): def revoke_by_jti(cls, token_jti: str, reason: str = None) -> bool:
"""Revoke a token by its JWT ID. """Revoke a token by its JWT ID.
Args: Args:
@@ -124,7 +135,7 @@ class OIDCTokenMetadata(BaseModel):
reason: Optional revocation reason reason: Optional revocation reason
Returns: Returns:
bool: True if token was found and revoked True if token was found and revoked, False otherwise
""" """
metadata = cls.get_by_jti(token_jti) metadata = cls.get_by_jti(token_jti)
if metadata: if metadata:
@@ -133,47 +144,53 @@ class OIDCTokenMetadata(BaseModel):
return False return False
@classmethod @classmethod
def revoke_all_for_user(cls, user_id, client_id=None, reason=None): def revoke_all_for_user(
cls, user_id: str, client_id: str = None, reason: str = None
) -> int:
"""Revoke all tokens for a user. """Revoke all tokens for a user.
Args: Args:
user_id: The user ID user_id: The user ID
client_id: Optional client ID to filter by client_id: Optional client ID filter
reason: Optional revocation reason reason: Optional revocation reason
Returns: Returns:
int: Number of tokens revoked Number of tokens revoked
""" """
query = cls.query.filter_by(user_id=user_id, deleted_at=None) query = cls.query.filter_by(user_id=user_id, deleted_at=None).filter(
cls.revoked_at.is_(None)
)
if client_id: if client_id:
query = query.filter_by(client_id=client_id) query = query.filter_by(client_id=client_id)
tokens = query.filter(cls.revoked_at == None).all()
count = 0 count = 0
for token in tokens: for token in query.all():
token.revoke(reason) token.revoke(reason)
count += 1 count += 1
return count return count
@classmethod @classmethod
def revoke_all_for_client(cls, client_id, user_id=None, reason=None): def revoke_all_for_client(
cls, client_id: str, user_id: str = None, reason: str = None
) -> int:
"""Revoke all tokens for a client. """Revoke all tokens for a client.
Args: Args:
client_id: The client ID client_id: The client ID
user_id: Optional user ID to filter by user_id: Optional user ID filter
reason: Optional revocation reason reason: Optional revocation reason
Returns: Returns:
int: Number of tokens revoked Number of tokens revoked
""" """
query = cls.query.filter_by(client_id=client_id, deleted_at=None) query = cls.query.filter_by(client_id=client_id, deleted_at=None).filter(
cls.revoked_at.is_(None)
)
if user_id: if user_id:
query = query.filter_by(user_id=user_id) query = query.filter_by(user_id=user_id)
tokens = query.filter(cls.revoked_at == None).all()
count = 0 count = 0
for token in tokens: for token in query.all():
token.revoke(reason) token.revoke(reason)
count += 1 count += 1
return count return count
@@ -181,16 +198,3 @@ class OIDCTokenMetadata(BaseModel):
def to_dict(self, exclude=None): def to_dict(self, exclude=None):
"""Convert to dictionary.""" """Convert to dictionary."""
return super().to_dict(exclude=exclude) return super().to_dict(exclude=exclude)
# Add relationship back to User model
from gatehouse_app.models.user import User
User.oidc_token_metadata = db.relationship(
"OIDCTokenMetadata", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from gatehouse_app.models.oidc_client import OIDCClient
OIDCClient.token_metadata = db.relationship(
"OIDCTokenMetadata", back_populates="client", cascade="all, delete-orphan"
)
-77
View File
@@ -1,77 +0,0 @@
"""OIDC JWKS Key model for persisting signing keys."""
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OidcJwksKey(BaseModel):
"""
OIDC JWKS Key model for persisting JSON Web Key Set signing keys.
This model stores RSA/ECDSA key pairs used for signing OIDC tokens.
Multiple keys can be stored to support key rotation scenarios.
Attributes:
id: Integer primary key
kid: Unique key ID used in JWT "kid" header
key_type: Type of key (e.g., "RSA", "EC")
private_key: PEM-encoded private key
public_key: PEM-encoded public key
algorithm: Signing algorithm (e.g., "RS256", "ES256")
created_at: When the key was created
is_active: Whether this key is currently active for signing
is_primary: Whether this is the primary signing key
expires_at: ...
"""
__tablename__ = "oidc_jwks_keys"
# Override the default UUID id with integer primary key
id = db.Column(db.Integer, primary_key=True)
expires_at = db.Column(db.DateTime, nullable=True)
# Key identification and type
kid = db.Column(db.String(255), unique=True, nullable=False, index=True)
key_type = db.Column(db.String(50), nullable=False) # e.g., "RSA", "EC"
algorithm = db.Column(db.String(50), nullable=False) # e.g., "RS256", "ES256"
# Key material (PEM-encoded)
private_key = db.Column(db.Text, nullable=False)
public_key = db.Column(db.Text, nullable=False)
# Key status
is_active = db.Column(db.Boolean, default=True, nullable=False)
is_primary = db.Column(db.Boolean, default=False, nullable=False)
def __repr__(self):
"""String representation of OidcJwksKey."""
return f"<OidcJwksKey kid={self.kid} key_type={self.key_type} algorithm={self.algorithm}>"
def to_dict(self, exclude_private_key=True):
"""
Convert model to dictionary.
Args:
exclude_private_key: If True, excludes the private key from output
Returns:
Dictionary representation of the model
"""
exclude = ["private_key"] if exclude_private_key else []
return super().to_dict(exclude=exclude)
@classmethod
def get_active_keys(cls):
"""Get all active keys for signing operations."""
return cls.query.filter(cls.is_active == True).all()
@classmethod
def get_primary_key(cls):
"""Get the primary signing key."""
return cls.query.filter(cls.is_primary == True).first()
@classmethod
def get_key_by_kid(cls, kid):
"""Get a key by its key ID."""
return cls.query.filter(cls.kid == kid, cls.is_active == True).first()
@@ -0,0 +1,27 @@
"""Organization subpackage."""
from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.models.organization.department import (
Department,
DepartmentMembership,
DepartmentPrincipal,
)
from gatehouse_app.models.organization.department_cert_policy import (
DepartmentCertPolicy,
STANDARD_EXTENSIONS,
)
from gatehouse_app.models.organization.principal import Principal, PrincipalMembership
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
__all__ = [
"Organization",
"OrganizationMember",
"Department",
"DepartmentMembership",
"DepartmentPrincipal",
"DepartmentCertPolicy",
"STANDARD_EXTENSIONS",
"Principal",
"PrincipalMembership",
"OrgInviteToken",
]
@@ -1,3 +1,4 @@
"""Department, DepartmentMembership, and DepartmentPrincipal models."""
from gatehouse_app.extensions import db from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel from gatehouse_app.models.base import BaseModel
@@ -39,12 +40,15 @@ class Department(BaseModel):
back_populates="department", back_populates="department",
cascade="all, delete-orphan", cascade="all, delete-orphan",
) )
cert_policy = db.relationship(
"DepartmentCertPolicy",
back_populates="department",
uselist=False,
cascade="all, delete-orphan",
)
# Unique constraint: department name per organization
__table_args__ = ( __table_args__ = (
db.UniqueConstraint( db.UniqueConstraint("organization_id", "name", name="uix_org_dept_name"),
"organization_id", "name", name="uix_org_dept_name"
),
) )
def __repr__(self): def __repr__(self):
@@ -55,16 +59,11 @@ class Department(BaseModel):
"""Convert department to dictionary.""" """Convert department to dictionary."""
exclude = exclude or [] exclude = exclude or []
data = super().to_dict(exclude=exclude) data = super().to_dict(exclude=exclude)
# Add member count
data["member_count"] = len([m for m in self.memberships if m.deleted_at is None]) data["member_count"] = len([m for m in self.memberships if m.deleted_at is None])
# Add principal count
data["principal_count"] = len([p for p in self.principal_links if p.deleted_at is None]) data["principal_count"] = len([p for p in self.principal_links if p.deleted_at is None])
return data return data
def get_members(self, active_only=True): def get_members(self, active_only: bool = True):
"""Get all members of this department. """Get all members of this department.
Args: Args:
@@ -75,9 +74,9 @@ class Department(BaseModel):
""" """
if active_only: if active_only:
return [m for m in self.memberships if m.deleted_at is None] return [m for m in self.memberships if m.deleted_at is None]
return self.memberships return list(self.memberships)
def get_principals(self, active_only=True): def get_principals(self, active_only: bool = True):
"""Get all principals assigned to this department. """Get all principals assigned to this department.
Args: Args:
@@ -87,10 +86,14 @@ class Department(BaseModel):
List of Principal objects via DepartmentPrincipal List of Principal objects via DepartmentPrincipal
""" """
if active_only: if active_only:
return [p.principal for p in self.principal_links if p.deleted_at is None and p.principal.deleted_at is None] return [
p.principal
for p in self.principal_links
if p.deleted_at is None and p.principal.deleted_at is None
]
return [p.principal for p in self.principal_links] return [p.principal for p in self.principal_links]
def is_member(self, user_id): def is_member(self, user_id: str) -> bool:
"""Check if a user is a member of this department. """Check if a user is a member of this department.
Args: Args:
@@ -108,7 +111,7 @@ class Department(BaseModel):
is not None is not None
) )
def get_member_count(self): def get_member_count(self) -> int:
"""Get the count of active members in this department.""" """Get the count of active members in this department."""
return len(self.get_members(active_only=True)) return len(self.get_members(active_only=True))
@@ -139,16 +142,15 @@ class DepartmentMembership(BaseModel):
user = db.relationship("User", back_populates="department_memberships") user = db.relationship("User", back_populates="department_memberships")
department = db.relationship("Department", back_populates="memberships") department = db.relationship("Department", back_populates="memberships")
# Unique constraint: user can only be member of a department once
__table_args__ = ( __table_args__ = (
db.UniqueConstraint( db.UniqueConstraint("user_id", "department_id", name="uix_user_dept"),
"user_id", "department_id", name="uix_user_dept"
),
) )
def __repr__(self): def __repr__(self):
"""String representation of DepartmentMembership.""" """String representation of DepartmentMembership."""
return f"<DepartmentMembership user_id={self.user_id} dept_id={self.department_id}>" return (
f"<DepartmentMembership user_id={self.user_id} dept_id={self.department_id}>"
)
class DepartmentPrincipal(BaseModel): class DepartmentPrincipal(BaseModel):
@@ -182,13 +184,13 @@ class DepartmentPrincipal(BaseModel):
department = db.relationship("Department", back_populates="principal_links") department = db.relationship("Department", back_populates="principal_links")
principal = db.relationship("Principal", back_populates="department_links") principal = db.relationship("Principal", back_populates="department_links")
# Unique constraint: principal can only be assigned to a department once
__table_args__ = ( __table_args__ = (
db.UniqueConstraint( db.UniqueConstraint("department_id", "principal_id", name="uix_dept_principal"),
"department_id", "principal_id", name="uix_dept_principal"
),
) )
def __repr__(self): def __repr__(self):
"""String representation of DepartmentPrincipal.""" """String representation of DepartmentPrincipal."""
return f"<DepartmentPrincipal dept_id={self.department_id} principal_id={self.principal_id}>" return (
f"<DepartmentPrincipal dept_id={self.department_id} "
f"principal_id={self.principal_id}>"
)
@@ -0,0 +1,76 @@
"""DepartmentCertPolicy — per-department SSH certificate issuance rules."""
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
# Standard SSH certificate extensions
STANDARD_EXTENSIONS = [
"permit-X11-forwarding",
"permit-agent-forwarding",
"permit-pty",
"permit-port-forwarding",
"permit-user-rc",
]
class DepartmentCertPolicy(BaseModel):
"""SSH certificate policy for a department.
Controls:
- Whether members may choose their own expiry date (up to ``max_expiry_hours``)
- Default expiry hours when the user doesn't (or can't) pick
- Maximum expiry hours (hard ceiling, even for admins signing on behalf)
- Which SSH certificate extensions are granted to members of this department
- Any custom extensions the admin wants to add beyond the standard five
Inherits ``id``, ``created_at``, ``updated_at``, and ``deleted_at`` from
:class:`BaseModel` so soft-delete and the standard timestamp behaviour are
consistent with every other model in the application.
"""
__tablename__ = "department_cert_policies"
department_id = db.Column(
db.String(36),
db.ForeignKey("departments.id"),
nullable=False,
unique=True,
index=True,
)
# Expiry control
allow_user_expiry = db.Column(db.Boolean, nullable=False, default=False)
default_expiry_hours = db.Column(db.Integer, nullable=False, default=1)
max_expiry_hours = db.Column(db.Integer, nullable=False, default=24)
# Extensions — list of extension name strings
allowed_extensions = db.Column(
db.JSON,
nullable=False,
default=lambda: list(STANDARD_EXTENSIONS),
)
# Admin-defined extras beyond the standard five
custom_extensions = db.Column(db.JSON, nullable=False, default=list)
# Relationship back to department
department = db.relationship("Department", back_populates="cert_policy", uselist=False)
def __repr__(self):
return (
f"<DepartmentCertPolicy dept={self.department_id} "
f"allow_user_expiry={self.allow_user_expiry}>"
)
def all_extensions(self) -> list:
"""Return the full list of enabled extensions (allowed + custom)."""
return list((self.allowed_extensions or []) + (self.custom_extensions or []))
def to_dict(self, exclude=None):
"""Convert to dictionary."""
exclude = exclude or []
data = super().to_dict(exclude=exclude)
# Augment with computed / convenience fields not in the base columns
data["all_extensions"] = self.all_extensions()
data["standard_extensions"] = STANDARD_EXTENSIONS
return data
@@ -0,0 +1,77 @@
"""Organization invite token model."""
import secrets
from datetime import datetime, timezone, timedelta
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OrgInviteToken(BaseModel):
"""Token-based invitation to join an organization."""
__tablename__ = "org_invite_tokens"
organization_id = db.Column(
db.String(36),
db.ForeignKey("organizations.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
invited_by_id = db.Column(
db.String(36),
db.ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
)
email = db.Column(db.String(255), nullable=False, index=True)
role = db.Column(db.String(64), nullable=False, default="member")
token = db.Column(db.String(128), unique=True, nullable=False, index=True)
expires_at = db.Column(db.DateTime, nullable=False)
accepted_at = db.Column(db.DateTime, nullable=True)
organization = db.relationship(
"Organization",
backref=db.backref("invite_tokens", cascade="all, delete-orphan"),
)
invited_by = db.relationship("User", foreign_keys=[invited_by_id])
@classmethod
def generate(
cls,
organization_id: str,
email: str,
role: str = "member",
invited_by_id: str = None,
ttl_days: int = 7,
) -> "OrgInviteToken":
"""Create a new invite token for an organization."""
token_value = secrets.token_urlsafe(48)
instance = cls(
organization_id=organization_id,
email=email.lower(),
role=role,
invited_by_id=invited_by_id,
token=token_value,
expires_at=datetime.now(timezone.utc) + timedelta(days=ttl_days),
)
db.session.add(instance)
db.session.commit()
return instance
@property
def is_valid(self) -> bool:
"""Return True if the token is unused and not expired."""
if self.accepted_at is not None:
return False
now = datetime.now(timezone.utc)
expires = self.expires_at
if expires.tzinfo is None:
expires = expires.replace(tzinfo=timezone.utc)
return now < expires
def accept(self) -> None:
"""Mark the invite as accepted."""
self.accepted_at = datetime.now(timezone.utc)
db.session.commit()
def __repr__(self) -> str:
return f"<OrgInviteToken org={self.organization_id} email={self.email}>"
@@ -61,9 +61,9 @@ class Organization(BaseModel):
return member.user return member.user
return None return None
def is_member(self, user_id): def is_member(self, user_id: str) -> bool:
"""Check if a user is a member of the organization.""" """Check if a user is a member of the organization."""
from gatehouse_app.models.organization_member import OrganizationMember from gatehouse_app.models.organization.organization_member import OrganizationMember
return ( return (
OrganizationMember.query.filter_by( OrganizationMember.query.filter_by(
@@ -21,31 +21,35 @@ class OrganizationMember(BaseModel):
joined_at = db.Column(db.DateTime, nullable=True) joined_at = db.Column(db.DateTime, nullable=True)
# Relationships # Relationships
user = db.relationship("User", foreign_keys=[user_id], back_populates="organization_memberships") user = db.relationship(
"User", foreign_keys=[user_id], back_populates="organization_memberships"
)
organization = db.relationship("Organization", back_populates="members") organization = db.relationship("Organization", back_populates="members")
invited_by = db.relationship("User", foreign_keys=[invited_by_id]) invited_by = db.relationship("User", foreign_keys=[invited_by_id])
# Unique constraint to prevent duplicate memberships
__table_args__ = ( __table_args__ = (
db.UniqueConstraint("user_id", "organization_id", name="uix_user_org"), db.UniqueConstraint("user_id", "organization_id", name="uix_user_org"),
) )
def __repr__(self): def __repr__(self):
"""String representation of OrganizationMember.""" """String representation of OrganizationMember."""
return f"<OrganizationMember user_id={self.user_id} org_id={self.organization_id} role={self.role}>" return (
f"<OrganizationMember user_id={self.user_id} "
f"org_id={self.organization_id} role={self.role}>"
)
def is_owner(self): def is_owner(self) -> bool:
"""Check if member is an owner.""" """Check if member is an owner."""
return self.role == OrganizationRole.OWNER return self.role == OrganizationRole.OWNER
def is_admin(self): def is_admin(self) -> bool:
"""Check if member is an admin or owner.""" """Check if member is an admin or owner."""
return self.role in [OrganizationRole.OWNER, OrganizationRole.ADMIN] return self.role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
def can_manage_members(self): def can_manage_members(self) -> bool:
"""Check if member can manage other members.""" """Check if member can manage other members."""
return self.is_admin() return self.is_admin()
def can_delete_organization(self): def can_delete_organization(self) -> bool:
"""Check if member can delete the organization.""" """Check if member can delete the organization."""
return self.is_owner() return self.is_owner()
@@ -1,3 +1,4 @@
"""Principal and PrincipalMembership models."""
from gatehouse_app.extensions import db from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel from gatehouse_app.models.base import BaseModel
@@ -7,7 +8,8 @@ class Principal(BaseModel):
In SSH CA terminology, a principal is a string like "eng-prod-servers" or In SSH CA terminology, a principal is a string like "eng-prod-servers" or
"devops-admins" that represents a set of machines or access level. Users "devops-admins" that represents a set of machines or access level. Users
can be granted access to principals, either directly or via department membership. can be granted access to principals, either directly or via department
membership.
Example: Example:
- Principal: "eng-prod-servers" - Principal: "eng-prod-servers"
@@ -39,11 +41,8 @@ class Principal(BaseModel):
cascade="all, delete-orphan", cascade="all, delete-orphan",
) )
# Unique constraint: principal name per organization
__table_args__ = ( __table_args__ = (
db.UniqueConstraint( db.UniqueConstraint("organization_id", "name", name="uix_org_principal_name"),
"organization_id", "name", name="uix_org_principal_name"
),
) )
def __repr__(self): def __repr__(self):
@@ -54,16 +53,15 @@ class Principal(BaseModel):
"""Convert principal to dictionary.""" """Convert principal to dictionary."""
exclude = exclude or [] exclude = exclude or []
data = super().to_dict(exclude=exclude) data = super().to_dict(exclude=exclude)
data["direct_member_count"] = len(
# Add member count [m for m in self.memberships if m.deleted_at is None]
data["direct_member_count"] = len([m for m in self.memberships if m.deleted_at is None]) )
data["department_count"] = len(
# Add department count [d for d in self.department_links if d.deleted_at is None]
data["department_count"] = len([d for d in self.department_links if d.deleted_at is None]) )
return data return data
def get_members(self, active_only=True): def get_members(self, active_only: bool = True):
"""Get all users who are directly assigned to this principal. """Get all users who are directly assigned to this principal.
Does NOT include users who get access via department membership. Does NOT include users who get access via department membership.
@@ -76,9 +74,9 @@ class Principal(BaseModel):
""" """
if active_only: if active_only:
return [m for m in self.memberships if m.deleted_at is None] return [m for m in self.memberships if m.deleted_at is None]
return self.memberships return list(self.memberships)
def get_all_members(self, active_only=True): def get_all_members(self, active_only: bool = True):
"""Get all users who have access to this principal. """Get all users who have access to this principal.
Includes both direct members and users via department membership. Includes both direct members and users via department membership.
@@ -89,25 +87,23 @@ class Principal(BaseModel):
Returns: Returns:
Set of User objects with access to this principal Set of User objects with access to this principal
""" """
from gatehouse_app.models.user import User all_users: set = set()
all_users = set() # Direct members
# Add direct members
for membership in self.get_members(active_only=active_only): for membership in self.get_members(active_only=active_only):
if membership.user.deleted_at is None or not active_only: if not active_only or membership.user.deleted_at is None:
all_users.add(membership.user) all_users.add(membership.user)
# Add members via department assignment # Members via department assignment
for dept_link in self.department_links: for dept_link in self.department_links:
if dept_link.deleted_at is None or not active_only: if dept_link.deleted_at is None or not active_only:
for dept_member in dept_link.department.get_members(active_only=active_only): for dept_member in dept_link.department.get_members(active_only=active_only):
if dept_member.user.deleted_at is None or not active_only: if not active_only or dept_member.user.deleted_at is None:
all_users.add(dept_member.user) all_users.add(dept_member.user)
return all_users return all_users
def get_departments(self, active_only=True): def get_departments(self, active_only: bool = True):
"""Get all departments this principal is assigned to. """Get all departments this principal is assigned to.
Args: Args:
@@ -117,10 +113,14 @@ class Principal(BaseModel):
List of Department objects List of Department objects
""" """
if active_only: if active_only:
return [d.department for d in self.department_links if d.deleted_at is None and d.department.deleted_at is None] return [
d.department
for d in self.department_links
if d.deleted_at is None and d.department.deleted_at is None
]
return [d.department for d in self.department_links] return [d.department for d in self.department_links]
def is_member(self, user_id, include_via_department=True): def is_member(self, user_id: str, include_via_department: bool = True) -> bool:
"""Check if a user has access to this principal. """Check if a user has access to this principal.
Args: Args:
@@ -143,30 +143,26 @@ class Principal(BaseModel):
if has_direct: if has_direct:
return True return True
# Check department membership if requested
if not include_via_department: if not include_via_department:
return False return False
# Get all departments this principal is assigned to # Check department membership
depts = self.get_departments(active_only=True) dept_ids = [d.id for d in self.get_departments(active_only=True)]
dept_ids = [d.id for d in depts]
if not dept_ids: if not dept_ids:
return False return False
# Check if user is a member of any of these departments from gatehouse_app.models.organization.department import DepartmentMembership
from gatehouse_app.models.department import DepartmentMembership
return ( return (
DepartmentMembership.query.filter( DepartmentMembership.query.filter(
DepartmentMembership.user_id == user_id, DepartmentMembership.user_id == user_id,
DepartmentMembership.department_id.in_(dept_ids), DepartmentMembership.department_id.in_(dept_ids),
DepartmentMembership.deleted_at == None, DepartmentMembership.deleted_at.is_(None),
).first() ).first()
is not None is not None
) )
def get_member_count(self, include_via_department=True): def get_member_count(self, include_via_department: bool = True) -> int:
"""Get the count of active members with access to this principal. """Get the count of active members with access to this principal.
Args: Args:
@@ -177,16 +173,15 @@ class Principal(BaseModel):
""" """
if not include_via_department: if not include_via_department:
return len(self.get_members(active_only=True)) return len(self.get_members(active_only=True))
return len(self.get_all_members(active_only=True)) return len(self.get_all_members(active_only=True))
class PrincipalMembership(BaseModel): class PrincipalMembership(BaseModel):
"""Principal membership model representing direct user assignment to a principal. """Principal membership model representing direct user assignment to a principal.
When a user is assigned directly to a principal, they get access to that principal When a user is assigned directly to a principal, they get access to that
for SSH authentication. This is in addition to any principals they get via principal for SSH authentication. This is in addition to any principals
department membership. they get via department membership.
""" """
__tablename__ = "principal_memberships" __tablename__ = "principal_memberships"
@@ -208,13 +203,13 @@ class PrincipalMembership(BaseModel):
user = db.relationship("User", back_populates="principal_memberships") user = db.relationship("User", back_populates="principal_memberships")
principal = db.relationship("Principal", back_populates="memberships") principal = db.relationship("Principal", back_populates="memberships")
# Unique constraint: user can only be member of a principal once
__table_args__ = ( __table_args__ = (
db.UniqueConstraint( db.UniqueConstraint("user_id", "principal_id", name="uix_user_principal"),
"user_id", "principal_id", name="uix_user_principal"
),
) )
def __repr__(self): def __repr__(self):
"""String representation of PrincipalMembership.""" """String representation of PrincipalMembership."""
return f"<PrincipalMembership user_id={self.user_id} principal_id={self.principal_id}>" return (
f"<PrincipalMembership user_id={self.user_id} "
f"principal_id={self.principal_id}>"
)
+12
View File
@@ -0,0 +1,12 @@
"""Security subpackage — organization and user security policies, MFA compliance."""
from gatehouse_app.models.security.organization_security_policy import (
OrganizationSecurityPolicy,
)
from gatehouse_app.models.security.user_security_policy import UserSecurityPolicy
from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance
__all__ = [
"OrganizationSecurityPolicy",
"UserSecurityPolicy",
"MfaPolicyCompliance",
]
@@ -1,4 +1,4 @@
"""MfaPolicyCompliance model.""" """MfaPolicyCompliance model — per-user per-organization MFA compliance tracking."""
from gatehouse_app.extensions import db from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import MfaComplianceStatus from gatehouse_app.utils.constants import MfaComplianceStatus
@@ -7,7 +7,8 @@ from gatehouse_app.utils.constants import MfaComplianceStatus
class MfaPolicyCompliance(BaseModel): class MfaPolicyCompliance(BaseModel):
"""MFA policy compliance tracking per user per organization. """MFA policy compliance tracking per user per organization.
Tracks each user's MFA compliance state separately for each organization membership. Tracks each user's MFA compliance state separately for each organization
membership. One row per (user, org) pair.
""" """
__tablename__ = "mfa_policy_compliance" __tablename__ = "mfa_policy_compliance"
@@ -25,13 +26,13 @@ class MfaPolicyCompliance(BaseModel):
default=MfaComplianceStatus.NOT_APPLICABLE, default=MfaComplianceStatus.NOT_APPLICABLE,
) )
# Snapshot of org policy at the time this record became active # Snapshot of org policy version when this record became active
policy_version = db.Column(db.Integer, nullable=False) policy_version = db.Column(db.Integer, nullable=False)
# When policy started applying to this user # When policy started applying to this user
applied_at = db.Column(db.DateTime, nullable=True) applied_at = db.Column(db.DateTime, nullable=True)
# Final deadline for this user to comply (per user, not global) # Final deadline for this user to comply
deadline_at = db.Column(db.DateTime, nullable=True) deadline_at = db.Column(db.DateTime, nullable=True)
# When they became compliant under this policy_version # When they became compliant under this policy_version
@@ -45,9 +46,7 @@ class MfaPolicyCompliance(BaseModel):
notification_count = db.Column(db.Integer, nullable=False, default=0) notification_count = db.Column(db.Integer, nullable=False, default=0)
__table_args__ = ( __table_args__ = (
db.UniqueConstraint( db.UniqueConstraint("user_id", "organization_id", name="uix_user_org_compliance"),
"user_id", "organization_id", name="uix_user_org_compliance"
),
) )
# Relationships # Relationships
@@ -58,9 +57,11 @@ class MfaPolicyCompliance(BaseModel):
def __repr__(self): def __repr__(self):
"""String representation of MfaPolicyCompliance.""" """String representation of MfaPolicyCompliance."""
return f"<MfaPolicyCompliance user={self.user_id} org={self.organization_id} status={self.status}>" return (
f"<MfaPolicyCompliance user={self.user_id} "
f"org={self.organization_id} status={self.status}>"
)
def to_dict(self, exclude=None): def to_dict(self, exclude=None):
"""Convert to dictionary.""" """Convert to dictionary."""
exclude = exclude or [] return super().to_dict(exclude=exclude or [])
return super().to_dict(exclude=exclude)
@@ -39,15 +39,19 @@ class OrganizationSecurityPolicy(BaseModel):
# Relationships # Relationships
organization = db.relationship( organization = db.relationship(
"Organization", back_populates="security_policy", foreign_keys=[organization_id] "Organization",
back_populates="security_policy",
foreign_keys=[organization_id],
) )
updated_by_user = db.relationship("User", foreign_keys=[updated_by_user_id]) updated_by_user = db.relationship("User", foreign_keys=[updated_by_user_id])
def __repr__(self): def __repr__(self):
"""String representation of OrganizationSecurityPolicy.""" """String representation of OrganizationSecurityPolicy."""
return f"<OrganizationSecurityPolicy org={self.organization_id} mode={self.mfa_policy_mode}>" return (
f"<OrganizationSecurityPolicy "
f"org={self.organization_id} mode={self.mfa_policy_mode}>"
)
def to_dict(self, exclude=None): def to_dict(self, exclude=None):
"""Convert to dictionary.""" """Convert to dictionary."""
exclude = exclude or [] return super().to_dict(exclude=exclude or [])
return super().to_dict(exclude=exclude)
@@ -1,4 +1,4 @@
"""UserSecurityPolicy model.""" """UserSecurityPolicy model — per-user MFA overrides."""
from gatehouse_app.extensions import db from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import MfaRequirementOverride from gatehouse_app.utils.constants import MfaRequirementOverride
@@ -7,7 +7,7 @@ from gatehouse_app.utils.constants import MfaRequirementOverride
class UserSecurityPolicy(BaseModel): class UserSecurityPolicy(BaseModel):
"""User security policy model for per-user MFA overrides. """User security policy model for per-user MFA overrides.
Stores per user overrides of organization level MFA requirements. Stores per-user overrides of organization-level MFA requirements.
""" """
__tablename__ = "user_security_policies" __tablename__ = "user_security_policies"
@@ -25,29 +25,27 @@ class UserSecurityPolicy(BaseModel):
default=MfaRequirementOverride.INHERIT, default=MfaRequirementOverride.INHERIT,
) )
# If override is REQUIRED and you want to force a specific factor set # If override is REQUIRED, optionally force a specific factor set
force_totp = db.Column(db.Boolean, nullable=False, default=False) force_totp = db.Column(db.Boolean, nullable=False, default=False)
force_webauthn = db.Column(db.Boolean, nullable=False, default=False) force_webauthn = db.Column(db.Boolean, nullable=False, default=False)
__table_args__ = ( __table_args__ = (
db.UniqueConstraint( db.UniqueConstraint("user_id", "organization_id", name="uix_user_org_policy"),
"user_id", "organization_id", name="uix_user_org_policy"
),
) )
# Relationships # Relationships
user = db.relationship( user = db.relationship(
"User", back_populates="security_policies", foreign_keys=[user_id] "User", back_populates="security_policies", foreign_keys=[user_id]
) )
organization = db.relationship( organization = db.relationship("Organization", foreign_keys=[organization_id])
"Organization", foreign_keys=[organization_id]
)
def __repr__(self): def __repr__(self):
"""String representation of UserSecurityPolicy.""" """String representation of UserSecurityPolicy."""
return f"<UserSecurityPolicy user={self.user_id} org={self.organization_id} mode={self.mfa_override_mode}>" return (
f"<UserSecurityPolicy user={self.user_id} "
f"org={self.organization_id} mode={self.mfa_override_mode}>"
)
def to_dict(self, exclude=None): def to_dict(self, exclude=None):
"""Convert to dictionary.""" """Convert to dictionary."""
exclude = exclude or [] return super().to_dict(exclude=exclude or [])
return super().to_dict(exclude=exclude)
+17
View File
@@ -0,0 +1,17 @@
"""SSH/CA subpackage — certificate authorities, SSH keys, and certificates."""
from gatehouse_app.models.ssh_ca.ca import CA, KeyType, CertType, CaType, CAPermission
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey
from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus
from gatehouse_app.models.ssh_ca.certificate_audit_log import CertificateAuditLog
__all__ = [
"CA",
"KeyType",
"CertType",
"CaType",
"CAPermission",
"SSHKey",
"SSHCertificate",
"CertificateStatus",
"CertificateAuditLog",
]
@@ -1,6 +1,6 @@
"""Certificate Authority (CA) model.""" """Certificate Authority (CA) model."""
from enum import Enum from enum import Enum
from datetime import datetime from datetime import datetime, timezone
from gatehouse_app.extensions import db from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel from gatehouse_app.models.base import BaseModel
@@ -44,11 +44,11 @@ class CA(BaseModel):
index=True, index=True,
) )
# CA name and description # CA identity
name = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False)
description = db.Column(db.Text, nullable=True) description = db.Column(db.Text, nullable=True)
# CA signing type: 'user' signs user certificates, 'host' signs host certificates # CA signing type: 'user' signs user certificates, 'host' signs host certs
ca_type = db.Column( ca_type = db.Column(
db.Enum(CaType, values_callable=lambda x: [e.value for e in x]), db.Enum(CaType, values_callable=lambda x: [e.value for e in x]),
default=CaType.USER, default=CaType.USER,
@@ -62,11 +62,10 @@ class CA(BaseModel):
nullable=False, nullable=False,
) )
# Private key (encrypted at rest by database/KMS) # Private key — PEM-encoded, encrypted at rest by database/KMS
# Format: PEM-encoded private key
private_key = db.Column(db.Text, nullable=False) private_key = db.Column(db.Text, nullable=False)
# Public key (PEM format) # Public key PEM format
public_key = db.Column(db.Text, nullable=False) public_key = db.Column(db.Text, nullable=False)
# SHA256 fingerprint of the public key # SHA256 fingerprint of the public key
@@ -76,20 +75,11 @@ class CA(BaseModel):
crl_enabled = db.Column(db.Boolean, default=True, nullable=False) crl_enabled = db.Column(db.Boolean, default=True, nullable=False)
crl_endpoint = db.Column(db.String(512), nullable=True) crl_endpoint = db.Column(db.String(512), nullable=True)
# Default certificate validity in hours # Default certificate validity in hours (overridable per request)
# Can be overridden per certificate request default_cert_validity_hours = db.Column(db.Integer, default=1, nullable=False)
default_cert_validity_hours = db.Column(
db.Integer,
default=1,
nullable=False,
)
# Maximum validity duration allowed # Maximum validity duration allowed
max_cert_validity_hours = db.Column( max_cert_validity_hours = db.Column(db.Integer, default=24, nullable=False)
db.Integer,
default=24,
nullable=False,
)
# CA status # CA status
is_active = db.Column(db.Boolean, default=True, nullable=False, index=True) is_active = db.Column(db.Boolean, default=True, nullable=False, index=True)
@@ -112,53 +102,52 @@ class CA(BaseModel):
) )
__table_args__ = ( __table_args__ = (
db.UniqueConstraint( db.UniqueConstraint("organization_id", "name", name="uix_org_ca_name"),
"organization_id", "name", name="uix_org_ca_name"
),
db.Index("idx_ca_org_active", "organization_id", "is_active"), db.Index("idx_ca_org_active", "organization_id", "is_active"),
) )
def __repr__(self): def __repr__(self):
"""String representation of CA.""" """String representation of CA."""
return f"<CA {self.name} (org_id={self.organization_id}, type={self.key_type})>" return (
f"<CA {self.name} "
f"(org_id={self.organization_id}, type={self.key_type})>"
)
def to_dict(self, exclude=None): def to_dict(self, exclude=None):
"""Convert CA to dictionary.""" """Convert CA to dictionary, never exposing the private key."""
exclude = exclude or [] exclude = exclude or []
# Never expose private key in API responses if "private_key" not in exclude:
exclude.extend(["private_key"]) exclude.append("private_key")
data = super().to_dict(exclude=exclude) data = super().to_dict(exclude=exclude)
# Add computed fields # Add computed fields
data["total_certs"] = len([c for c in self.certificates if c.deleted_at is None]) data["total_certs"] = len([c for c in self.certificates if c.deleted_at is None])
data["active_certs"] = len([ data["active_certs"] = len(
c for c in self.certificates [c for c in self.certificates if c.deleted_at is None and not c.revoked]
if c.deleted_at is None and not c.revoked )
]) data["revoked_certs"] = len(
data["revoked_certs"] = len([ [c for c in self.certificates if c.deleted_at is None and c.revoked]
c for c in self.certificates )
if c.deleted_at is None and c.revoked
])
return data return data
def get_active_certificates(self): def get_active_certificates(self) -> list:
"""Get all active (non-revoked) certificates issued by this CA. """Get all active (non-revoked) certificates issued by this CA."""
Returns:
List of non-revoked SSHCertificate objects
"""
return [ return [
c for c in self.certificates c for c in self.certificates if c.deleted_at is None and not c.revoked
if c.deleted_at is None and not c.revoked
] ]
def rotate_key(self, new_private_key, new_public_key, new_fingerprint, reason=None): def rotate_key(
self,
new_private_key: str,
new_public_key: str,
new_fingerprint: str,
reason: str = None,
) -> None:
"""Rotate the CA's key pair. """Rotate the CA's key pair.
This should only be done in carefully controlled circumstances. This should only be done in carefully controlled circumstances.
All existing certificates remain valid but no new certs can be All existing certificates remain valid but no new certificates can be
signed with the old key. signed with the old key after rotation.
Args: Args:
new_private_key: New PEM-encoded private key new_private_key: New PEM-encoded private key
@@ -169,7 +158,7 @@ class CA(BaseModel):
self.private_key = new_private_key self.private_key = new_private_key
self.public_key = new_public_key self.public_key = new_public_key
self.fingerprint = new_fingerprint self.fingerprint = new_fingerprint
self.rotated_at = datetime.utcnow() self.rotated_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow()
self.rotation_reason = reason self.rotation_reason = reason
self.save() self.save()
@@ -178,7 +167,7 @@ class CAPermission(BaseModel):
"""Per-user CA permission model. """Per-user CA permission model.
Controls which users are allowed to sign certificates against a specific CA. Controls which users are allowed to sign certificates against a specific CA.
When a CA has any permission rows the signing endpoint enforces the list; When a CA has any permission rows, the signing endpoint enforces the list;
CAs with no rows are open to all org members (backwards-compatible default). CAs with no rows are open to all org members (backwards-compatible default).
Permission values: Permission values:
@@ -212,7 +201,10 @@ class CAPermission(BaseModel):
) )
def __repr__(self): def __repr__(self):
return f"<CAPermission ca_id={self.ca_id} user_id={self.user_id} permission={self.permission}>" return (
f"<CAPermission ca_id={self.ca_id} "
f"user_id={self.user_id} permission={self.permission}>"
)
def to_dict(self, exclude=None): def to_dict(self, exclude=None):
data = super().to_dict(exclude=exclude or []) data = super().to_dict(exclude=exclude or [])
@@ -1,4 +1,4 @@
"""Certificate audit log model.""" """Certificate audit log model — tracks SSH certificate lifecycle events."""
from gatehouse_app.extensions import db from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel from gatehouse_app.models.base import BaseModel
@@ -6,9 +6,9 @@ from gatehouse_app.models.base import BaseModel
class CertificateAuditLog(BaseModel): class CertificateAuditLog(BaseModel):
"""Audit log for SSH certificate lifecycle events. """Audit log for SSH certificate lifecycle events.
Tracks all operations on SSH certificates: signing, revocation, Tracks all operations on SSH certificates: signing, revocation, validation,
validation, etc. This is separate from the general AuditLog to etc. Kept separate from the general AuditLog to provide detailed certificate
provide detailed certificate operation tracking. operation tracking without polluting the main audit stream.
""" """
__tablename__ = "certificate_audit_logs" __tablename__ = "certificate_audit_logs"
@@ -21,7 +21,7 @@ class CertificateAuditLog(BaseModel):
index=True, index=True,
) )
# The user who performed the action (can be null for system actions) # The user who performed the action (null for system actions)
user_id = db.Column( user_id = db.Column(
db.String(36), db.String(36),
db.ForeignKey("users.id"), db.ForeignKey("users.id"),
@@ -43,7 +43,7 @@ class CertificateAuditLog(BaseModel):
# Additional context # Additional context
extra_data = db.Column(db.JSON, nullable=True) extra_data = db.Column(db.JSON, nullable=True)
# Success/failure # Outcome
success = db.Column(db.Boolean, default=True, nullable=False) success = db.Column(db.Boolean, default=True, nullable=False)
error_message = db.Column(db.Text, nullable=True) error_message = db.Column(db.Text, nullable=True)
@@ -58,10 +58,18 @@ class CertificateAuditLog(BaseModel):
def __repr__(self): def __repr__(self):
"""String representation of CertificateAuditLog.""" """String representation of CertificateAuditLog."""
return f"<CertificateAuditLog cert_id={self.certificate_id} action={self.action}>" return (
f"<CertificateAuditLog cert_id={self.certificate_id} action={self.action}>"
)
@classmethod @classmethod
def log(cls, certificate_id, action, user_id=None, **kwargs): def log(
cls,
certificate_id: str,
action: str,
user_id: str = None,
**kwargs,
) -> "CertificateAuditLog":
"""Create a certificate audit log entry. """Create a certificate audit log entry.
Args: Args:
@@ -77,7 +85,7 @@ class CertificateAuditLog(BaseModel):
certificate_id=certificate_id, certificate_id=certificate_id,
action=action, action=action,
user_id=user_id, user_id=user_id,
**kwargs **kwargs,
) )
log_entry.save() log_entry.save()
return log_entry return log_entry
@@ -1,9 +1,9 @@
"""SSH Certificate model.""" """SSH Certificate model — signed SSH user/host certificates."""
from enum import Enum from enum import Enum
from datetime import datetime, timezone from datetime import datetime, timezone
from gatehouse_app.extensions import db from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel from gatehouse_app.models.base import BaseModel
from gatehouse_app.models.ca import CertType from gatehouse_app.models.ssh_ca.ca import CertType
class CertificateStatus(str, Enum): class CertificateStatus(str, Enum):
@@ -13,14 +13,14 @@ class CertificateStatus(str, Enum):
ISSUED = "issued" # Signed and valid ISSUED = "issued" # Signed and valid
REVOKED = "revoked" # Manually revoked REVOKED = "revoked" # Manually revoked
EXPIRED = "expired" # Validity period ended EXPIRED = "expired" # Validity period ended
SUPERSEDED = "superseded" # Replaced by newer cert SUPERSEDED = "superseded" # Replaced by newer certificate
class SSHCertificate(BaseModel): class SSHCertificate(BaseModel):
"""SSH Certificate model representing a signed SSH user/host certificate. """SSH Certificate model representing a signed SSH user/host certificate.
Certificates are issued by a CA and associated with an SSH public key. Certificates are issued by a CA and associated with an SSH public key.
They include principals (access levels), validity periods, and other They include principals (access levels), validity periods, and standard
OpenSSH certificate metadata. OpenSSH certificate metadata.
""" """
@@ -46,7 +46,7 @@ class SSHCertificate(BaseModel):
index=True, index=True,
) )
# Certificate content (full signed certificate in OpenSSH format) # Certificate content full signed certificate in OpenSSH format
certificate = db.Column(db.Text, nullable=False) certificate = db.Column(db.Text, nullable=False)
# Certificate metadata # Certificate metadata
@@ -58,7 +58,7 @@ class SSHCertificate(BaseModel):
nullable=False, nullable=False,
) )
# Principals (JSON list) - e.g., ["prod-servers", "dev-servers"] # Principals JSON list, e.g., ["prod-servers", "dev-servers"]
principals = db.Column(db.JSON, nullable=False, default=list) principals = db.Column(db.JSON, nullable=False, default=list)
# Validity period # Validity period
@@ -82,12 +82,10 @@ class SSHCertificate(BaseModel):
request_ip = db.Column(db.String(45), nullable=True) request_ip = db.Column(db.String(45), nullable=True)
request_user_agent = db.Column(db.String(512), nullable=True) request_user_agent = db.Column(db.String(512), nullable=True)
# Critical options (JSON) - OpenSSH critical options # Critical options OpenSSH critical options (JSON)
# See: https://man.openbsd.org/ssh-cert
critical_options = db.Column(db.JSON, nullable=True, default=dict) critical_options = db.Column(db.JSON, nullable=True, default=dict)
# Extensions (JSON) - OpenSSH extensions # Extensions OpenSSH extensions (JSON)
# Common ones: permit-X11-forwarding, permit-agent-forwarding, permit-pty, etc.
extensions = db.Column(db.JSON, nullable=True, default=dict) extensions = db.Column(db.JSON, nullable=True, default=dict)
# Relationships # Relationships
@@ -115,20 +113,24 @@ class SSHCertificate(BaseModel):
return f"<SSHCertificate serial={self.serial[:16]}... user_id={self.user_id}>" return f"<SSHCertificate serial={self.serial[:16]}... user_id={self.user_id}>"
def to_dict(self, exclude=None): def to_dict(self, exclude=None):
"""Convert certificate to dictionary.""" """Convert certificate to dictionary.
The raw ``certificate`` blob is excluded by default (it is large and
callers can request it explicitly by removing it from the exclude list).
"""
exclude = exclude or [] exclude = exclude or []
# Optionally exclude the certificate content (it's large)
if "certificate" not in exclude: if "certificate" not in exclude:
exclude.append("certificate") exclude.append("certificate")
data = super().to_dict(exclude=exclude) data = super().to_dict(exclude=exclude)
# Add computed fields
data["is_valid"] = self.is_valid() data["is_valid"] = self.is_valid()
data["days_until_expiry"] = self.days_until_expiry() data["days_until_expiry"] = self.days_until_expiry()
return data return data
def is_valid(self): def _aware(self, dt: datetime) -> datetime:
"""Return a timezone-aware UTC datetime."""
return dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt
def is_valid(self) -> bool:
"""Check if certificate is currently valid. """Check if certificate is currently valid.
Returns: Returns:
@@ -136,46 +138,39 @@ class SSHCertificate(BaseModel):
""" """
if self.revoked or self.status == CertificateStatus.REVOKED: if self.revoked or self.status == CertificateStatus.REVOKED:
return False return False
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
valid_after = self.valid_after.replace(tzinfo=timezone.utc) if self.valid_after.tzinfo is None else self.valid_after return self._aware(self.valid_after) <= now <= self._aware(self.valid_before)
valid_before = self.valid_before.replace(tzinfo=timezone.utc) if self.valid_before.tzinfo is None else self.valid_before
return valid_after <= now <= valid_before
def is_expired(self): def is_expired(self) -> bool:
"""Check if certificate has expired. """Check if certificate has expired.
Returns: Returns:
True if current time is past valid_before True if current time is past valid_before
""" """
now = datetime.now(timezone.utc) return datetime.now(timezone.utc) > self._aware(self.valid_before)
valid_before = self.valid_before.replace(tzinfo=timezone.utc) if self.valid_before.tzinfo is None else self.valid_before
return now > valid_before
def days_until_expiry(self): def days_until_expiry(self) -> int:
"""Get number of days until certificate expires. """Get number of days until certificate expires.
Returns: Returns:
Number of days remaining (negative if already expired) Number of days remaining (negative if already expired)
""" """
now = datetime.now(timezone.utc) delta = self._aware(self.valid_before) - datetime.now(timezone.utc)
valid_before = self.valid_before.replace(tzinfo=timezone.utc) if self.valid_before.tzinfo is None else self.valid_before
delta = valid_before - now
return delta.days + (1 if delta.seconds > 0 else 0) return delta.days + (1 if delta.seconds > 0 else 0)
def revoke(self, reason=None): def revoke(self, reason: str = None) -> None:
"""Revoke this certificate. """Revoke this certificate.
Args: Args:
reason: Optional reason for revocation reason: Optional reason for revocation
""" """
self.revoked = True self.revoked = True
self.revoked_at = datetime.utcnow() self.revoked_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow()
self.revoke_reason = reason self.revoke_reason = reason
self.status = CertificateStatus.REVOKED self.status = CertificateStatus.REVOKED
self.save() self.save()
def mark_expired(self): def mark_expired(self) -> None:
"""Mark certificate as expired when validity period ends.""" """Mark certificate as expired when validity period ends."""
self.status = CertificateStatus.EXPIRED self.status = CertificateStatus.EXPIRED
self.save() self.save()
@@ -1,5 +1,5 @@
"""SSH Key model.""" """SSH Key model — user SSH public keys registered for certificate signing."""
from datetime import datetime from datetime import datetime, timezone
from gatehouse_app.extensions import db from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel from gatehouse_app.models.base import BaseModel
@@ -7,8 +7,8 @@ from gatehouse_app.models.base import BaseModel
class SSHKey(BaseModel): class SSHKey(BaseModel):
"""SSH Key model representing a user's SSH public key. """SSH Key model representing a user's SSH public key.
This model stores SSH public keys that users register for certificate signing. Users register SSH public keys for certificate signing. Keys must be
Users must verify ownership of the key before it can be used for signing certificates. verified (owner proved possession) before they can be used.
""" """
__tablename__ = "ssh_keys" __tablename__ = "ssh_keys"
@@ -20,30 +20,26 @@ class SSHKey(BaseModel):
index=True, index=True,
) )
# SSH key payload in OpenSSH format (e.g., "ssh-rsa AAAAB3Nz...") # SSH key payload in OpenSSH format (e.g., "ssh-ed25519 AAAAB3Nz...")
payload = db.Column(db.Text, nullable=False, unique=True) payload = db.Column(db.Text, nullable=False, unique=True)
# SHA256 fingerprint for quick comparison # SHA256 fingerprint for quick comparison and deduplication
fingerprint = db.Column(db.String(255), nullable=False, unique=True, index=True) fingerprint = db.Column(db.String(255), nullable=False, unique=True, index=True)
# Optional description for the key (e.g., "My laptop key") # Optional human-readable description (e.g., "My laptop key")
description = db.Column(db.String(255), nullable=True) description = db.Column(db.String(255), nullable=True)
# Verification status # Verification status
verified = db.Column(db.Boolean, default=False, nullable=False, index=True) verified = db.Column(db.Boolean, default=False, nullable=False, index=True)
verified_at = db.Column(db.DateTime, nullable=True) verified_at = db.Column(db.DateTime, nullable=True)
# Verification challenge # Verification challenge — shown to user once, cleared after verification
verify_text = db.Column(db.String(255), nullable=True) verify_text = db.Column(db.String(255), nullable=True)
verify_text_created_at = db.Column(db.DateTime, nullable=True) verify_text_created_at = db.Column(db.DateTime, nullable=True)
# Key type extracted from the key (ssh-rsa, ssh-ed25519, etc.) # Key metadata extracted from the key
key_type = db.Column(db.String(50), nullable=True) key_type = db.Column(db.String(50), nullable=True) # ssh-rsa, ssh-ed25519, etc.
key_bits = db.Column(db.Integer, nullable=True) # key length
# Key bits/length
key_bits = db.Column(db.Integer, nullable=True)
# Comment from the key (usually email or key name)
key_comment = db.Column(db.String(255), nullable=True) key_comment = db.Column(db.String(255), nullable=True)
# Relationships # Relationships
@@ -64,33 +60,39 @@ class SSHKey(BaseModel):
return f"<SSHKey {self.fingerprint[:16]}... user_id={self.user_id}>" return f"<SSHKey {self.fingerprint[:16]}... user_id={self.user_id}>"
def to_dict(self, exclude=None): def to_dict(self, exclude=None):
"""Convert SSH key to dictionary.""" """Convert SSH key to dictionary.
``payload`` and ``verify_text`` are never exposed through the API.
"""
exclude = exclude or [] exclude = exclude or []
exclude.extend(["payload", "verify_text"]) # Never expose these in API for field in ("payload", "verify_text"):
if field not in exclude:
exclude.append(field)
data = super().to_dict(exclude=exclude) data = super().to_dict(exclude=exclude)
# Add computed fields
data["cert_count"] = len([c for c in self.certificates if c.deleted_at is None]) data["cert_count"] = len([c for c in self.certificates if c.deleted_at is None])
return data return data
def mark_verified(self): def mark_verified(self) -> None:
"""Mark this SSH key as verified.""" """Mark this SSH key as verified and clear the challenge."""
self.verified = True self.verified = True
self.verified_at = datetime.utcnow() self.verified_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow()
self.verify_text = None
self.save() self.save()
def needs_verification_refresh(self, max_age_hours=24): def needs_verification_refresh(self, max_age_hours: int = 24) -> bool:
"""Check if verification challenge needs to be refreshed. """Check if verification challenge needs to be refreshed.
Args: Args:
max_age_hours: Maximum age of verification challenge in hours max_age_hours: Maximum age of verification challenge in hours
Returns: Returns:
True if verification challenge is stale True if verification challenge is stale or missing
""" """
if not self.verify_text_created_at: if not self.verify_text_created_at:
return True return True
age = datetime.now(timezone.utc) - self.verify_text_created_at.replace(
age = datetime.utcnow() - self.verify_text_created_at tzinfo=timezone.utc
) if self.verify_text_created_at.tzinfo is None else (
datetime.now(timezone.utc) - self.verify_text_created_at
)
return age.total_seconds() > (max_age_hours * 3600) return age.total_seconds() > (max_age_hours * 3600)
+3 -176
View File
@@ -1,177 +1,4 @@
"""User model.""" """Backward-compatibility shim — import from gatehouse_app.models.user.user instead."""
from gatehouse_app.extensions import db from gatehouse_app.models.user.user import User # noqa: F401
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import UserStatus
__all__ = ["User"]
class User(BaseModel):
"""User model representing a user account."""
__tablename__ = "users"
email = db.Column(db.String(255), unique=True, nullable=False, index=True)
email_verified = db.Column(db.Boolean, default=False, nullable=False)
full_name = db.Column(db.String(255), nullable=True)
avatar_url = db.Column(db.String(512), nullable=True)
status = db.Column(
db.Enum(UserStatus), default=UserStatus.ACTIVE, nullable=False, index=True
)
last_login_at = db.Column(db.DateTime, nullable=True)
last_login_ip = db.Column(db.String(45), nullable=True)
# Relationships
authentication_methods = db.relationship(
"AuthenticationMethod", back_populates="user", cascade="all, delete-orphan"
)
sessions = db.relationship("Session", back_populates="user", cascade="all, delete-orphan")
organization_memberships = db.relationship(
"OrganizationMember",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="OrganizationMember.user_id",
)
audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan")
security_policies = db.relationship(
"UserSecurityPolicy",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="UserSecurityPolicy.user_id",
)
mfa_compliance = db.relationship(
"MfaPolicyCompliance",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="MfaPolicyCompliance.user_id",
)
department_memberships = db.relationship(
"DepartmentMembership",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="DepartmentMembership.user_id",
)
principal_memberships = db.relationship(
"PrincipalMembership",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="PrincipalMembership.user_id",
)
ssh_keys = db.relationship(
"SSHKey",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="SSHKey.user_id",
)
ssh_certificates = db.relationship(
"SSHCertificate",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="SSHCertificate.user_id",
)
def __repr__(self):
"""String representation of User."""
return f"<User {self.email}>"
def to_dict(self, exclude=None):
"""Convert user to dictionary, excluding sensitive fields by default."""
exclude = exclude or []
# Always exclude password-related fields
default_exclude = []
all_exclude = list(set(default_exclude + exclude))
return super().to_dict(exclude=all_exclude)
def has_password_auth(self):
"""Check if user has password authentication enabled."""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.PASSWORD, deleted_at=None
).first()
is not None
)
def get_organizations(self):
"""Get all organizations the user is a member of."""
return [membership.organization for membership in self.organization_memberships]
def has_totp_enabled(self) -> bool:
"""Check if user has TOTP enabled and verified.
Returns:
True if user has a verified TOTP authentication method, False otherwise.
"""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id,
method_type=AuthMethodType.TOTP,
verified=True,
deleted_at=None,
).first()
is not None
)
def get_totp_method(self):
"""Get user's TOTP authentication method.
Returns:
The AuthenticationMethod instance for TOTP or None if not found.
Note:
Returns the most recently created TOTP method to handle cases where
multiple enrollment attempts may exist.
"""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None
).order_by(AuthenticationMethod.created_at.desc()).first()
def has_webauthn_enabled(self) -> bool:
"""Check if user has any WebAuthn passkey credentials.
Returns:
True if user has at least one WebAuthn credential, False otherwise.
"""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id,
method_type=AuthMethodType.WEBAUTHN,
deleted_at=None,
).first()
is not None
)
def get_webauthn_credentials(self):
"""Get all WebAuthn credentials for the user.
Returns:
List of AuthenticationMethod instances for WebAuthn, ordered by creation date.
"""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
).order_by(AuthenticationMethod.created_at.desc()).all()
def get_webauthn_credential_count(self) -> int:
"""Get the count of WebAuthn credentials for the user.
Returns:
Number of WebAuthn credentials.
"""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
).count()
+5
View File
@@ -0,0 +1,5 @@
"""User subpackage."""
from gatehouse_app.models.user.user import User
from gatehouse_app.models.user.session import Session
__all__ = ["User", "Session"]
@@ -21,7 +21,9 @@ class Session(BaseModel):
# Timing # Timing
expires_at = db.Column(db.DateTime, nullable=False) expires_at = db.Column(db.DateTime, nullable=False)
last_activity_at = db.Column(db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc)) last_activity_at = db.Column(
db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc)
)
revoked_at = db.Column(db.DateTime, nullable=True) revoked_at = db.Column(db.DateTime, nullable=True)
revoked_reason = db.Column(db.String(255), nullable=True) revoked_reason = db.Column(db.String(255), nullable=True)
@@ -38,7 +40,6 @@ class Session(BaseModel):
def is_active(self): def is_active(self):
"""Check if session is currently active.""" """Check if session is currently active."""
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
# Make expires_at timezone-aware if it's naive
expires_at = self.expires_at expires_at = self.expires_at
if expires_at.tzinfo is None: if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc) expires_at = expires_at.replace(tzinfo=timezone.utc)
@@ -51,15 +52,13 @@ class Session(BaseModel):
def is_expired(self): def is_expired(self):
"""Check if session has expired.""" """Check if session has expired."""
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
# Make expires_at timezone-aware if it's naive
expires_at = self.expires_at expires_at = self.expires_at
if expires_at.tzinfo is None: if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc) expires_at = expires_at.replace(tzinfo=timezone.utc)
return now > expires_at return now > expires_at
def refresh(self, duration_seconds=86400): def refresh(self, duration_seconds: int = 86400):
""" """Refresh session expiration.
Refresh session expiration.
Args: Args:
duration_seconds: New session duration in seconds duration_seconds: New session duration in seconds
@@ -68,9 +67,8 @@ class Session(BaseModel):
self.last_activity_at = datetime.now(timezone.utc) self.last_activity_at = datetime.now(timezone.utc)
db.session.commit() db.session.commit()
def revoke(self, reason=None): def revoke(self, reason: str = None):
""" """Revoke the session.
Revoke the session.
Args: Args:
reason: Optional reason for revocation reason: Optional reason for revocation
@@ -84,6 +82,5 @@ class Session(BaseModel):
def to_dict(self, exclude=None): def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields.""" """Convert to dictionary, excluding sensitive fields."""
exclude = exclude or [] exclude = exclude or []
# Exclude token from dict
exclude.append("token") exclude.append("token")
return super().to_dict(exclude=exclude) return super().to_dict(exclude=exclude)
+209
View File
@@ -0,0 +1,209 @@
"""User model."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import UserStatus
class User(BaseModel):
"""User model representing a user account."""
__tablename__ = "users"
email = db.Column(db.String(255), unique=True, nullable=False, index=True)
email_verified = db.Column(db.Boolean, default=False, nullable=False)
full_name = db.Column(db.String(255), nullable=True)
avatar_url = db.Column(db.String(512), nullable=True)
status = db.Column(
db.Enum(UserStatus), default=UserStatus.ACTIVE, nullable=False, index=True
)
last_login_at = db.Column(db.DateTime, nullable=True)
last_login_ip = db.Column(db.String(45), nullable=True)
# Account activation (email-link flow)
activated = db.Column(db.Boolean, default=True, nullable=False)
activation_key = db.Column(db.String(128), unique=True, nullable=True, index=True)
# Relationships defined here only for models that don't circular-import
authentication_methods = db.relationship(
"AuthenticationMethod", back_populates="user", cascade="all, delete-orphan"
)
sessions = db.relationship("Session", back_populates="user", cascade="all, delete-orphan")
organization_memberships = db.relationship(
"OrganizationMember",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="OrganizationMember.user_id",
)
audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan")
security_policies = db.relationship(
"UserSecurityPolicy",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="UserSecurityPolicy.user_id",
)
mfa_compliance = db.relationship(
"MfaPolicyCompliance",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="MfaPolicyCompliance.user_id",
)
department_memberships = db.relationship(
"DepartmentMembership",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="DepartmentMembership.user_id",
)
principal_memberships = db.relationship(
"PrincipalMembership",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="PrincipalMembership.user_id",
)
ssh_keys = db.relationship(
"SSHKey",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="SSHKey.user_id",
)
ssh_certificates = db.relationship(
"SSHCertificate",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="SSHCertificate.user_id",
)
ca_permissions = db.relationship(
"CAPermission",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="CAPermission.user_id",
)
# OIDC relationships registered here (no monkey-patching needed)
oidc_auth_codes = db.relationship(
"OIDCAuthCode", back_populates="user", cascade="all, delete-orphan"
)
oidc_refresh_tokens = db.relationship(
"OIDCRefreshToken", back_populates="user", cascade="all, delete-orphan"
)
oidc_sessions = db.relationship(
"OIDCSession", back_populates="user", cascade="all, delete-orphan"
)
oidc_token_metadata = db.relationship(
"OIDCTokenMetadata", back_populates="user", cascade="all, delete-orphan"
)
oidc_audit_logs = db.relationship(
"OIDCAuditLog", back_populates="user", cascade="all, delete-orphan"
)
def __repr__(self):
"""String representation of User."""
return f"<User {self.email}>"
def to_dict(self, exclude=None):
"""Convert user to dictionary, excluding sensitive fields by default."""
exclude = exclude or []
return super().to_dict(exclude=exclude)
def has_password_auth(self):
"""Check if user has password authentication enabled."""
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.PASSWORD, deleted_at=None
).first()
is not None
)
def get_organizations(self):
"""Get all organizations the user is a member of."""
return [membership.organization for membership in self.organization_memberships]
def has_totp_enabled(self) -> bool:
"""Check if user has TOTP enabled and verified.
Returns:
True if user has a verified TOTP authentication method, False otherwise.
"""
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id,
method_type=AuthMethodType.TOTP,
verified=True,
deleted_at=None,
).first()
is not None
)
def get_totp_method(self):
"""Get user's TOTP authentication method.
Returns:
The AuthenticationMethod instance for TOTP or None if not found.
Note:
Returns the most recently created TOTP method to handle cases where
multiple enrollment attempts may exist.
"""
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None
)
.order_by(AuthenticationMethod.created_at.desc())
.first()
)
def has_webauthn_enabled(self) -> bool:
"""Check if user has any WebAuthn passkey credentials.
Returns:
True if user has at least one WebAuthn credential, False otherwise.
"""
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id,
method_type=AuthMethodType.WEBAUTHN,
deleted_at=None,
).first()
is not None
)
def get_webauthn_credentials(self):
"""Get all WebAuthn credentials for the user.
Returns:
List of AuthenticationMethod instances for WebAuthn, ordered by creation date.
"""
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
)
.order_by(AuthenticationMethod.created_at.desc())
.all()
)
def get_webauthn_credential_count(self) -> int:
"""Get the count of WebAuthn credentials for the user.
Returns:
Number of WebAuthn credentials.
"""
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
).count()