From 07193a2d2e8dd29376b770c8ebc39a015c520c80 Mon Sep 17 00:00:00 2001 From: James Bhattarai Date: Sun, 1 Mar 2026 12:40:48 +0545 Subject: [PATCH] Chore: Refractor Models into organized file/folder --- gatehouse_app/models/__init__.py | 162 ++++++++++---- gatehouse_app/models/auth/__init__.py | 20 ++ gatehouse_app/models/{ => auth}/audit_log.py | 8 +- .../{ => auth}/authentication_method.py | 196 ++++++++-------- .../models/auth/email_verification_token.py | 68 ++++++ .../models/auth/password_reset_token.py | 69 ++++++ gatehouse_app/models/oidc/__init__.py | 18 ++ .../models/{ => oidc}/oidc_audit_log.py | 133 ++++++----- .../{ => oidc}/oidc_authorization_code.py | 66 +++--- .../models/{ => oidc}/oidc_client.py | 37 +++- gatehouse_app/models/oidc/oidc_jwks_key.py | 76 +++++++ .../models/{ => oidc}/oidc_refresh_token.py | 74 +++---- .../models/{ => oidc}/oidc_session.py | 71 +++--- .../models/{ => oidc}/oidc_token_metadata.py | 94 ++++---- gatehouse_app/models/oidc_jwks_key.py | 77 ------- gatehouse_app/models/organization/__init__.py | 27 +++ .../models/{ => organization}/department.py | 74 ++++--- .../organization/department_cert_policy.py | 76 +++++++ .../models/organization/org_invite_token.py | 77 +++++++ .../models/{ => organization}/organization.py | 4 +- .../{ => organization}/organization_member.py | 20 +- .../models/{ => organization}/principal.py | 127 +++++------ gatehouse_app/models/security/__init__.py | 12 + .../{ => security}/mfa_policy_compliance.py | 21 +- .../organization_security_policy.py | 12 +- .../{ => security}/user_security_policy.py | 22 +- gatehouse_app/models/ssh_ca/__init__.py | 17 ++ gatehouse_app/models/{ => ssh_ca}/ca.py | 120 +++++----- .../{ => ssh_ca}/certificate_audit_log.py | 46 ++-- .../models/{ => ssh_ca}/ssh_certificate.py | 97 ++++---- gatehouse_app/models/{ => ssh_ca}/ssh_key.py | 76 +++---- gatehouse_app/models/user.py | 179 +-------------- gatehouse_app/models/user/__init__.py | 5 + gatehouse_app/models/{ => user}/session.py | 17 +- gatehouse_app/models/user/user.py | 209 ++++++++++++++++++ 35 files changed, 1475 insertions(+), 932 deletions(-) create mode 100644 gatehouse_app/models/auth/__init__.py rename gatehouse_app/models/{ => auth}/audit_log.py (92%) rename gatehouse_app/models/{ => auth}/authentication_method.py (74%) create mode 100644 gatehouse_app/models/auth/email_verification_token.py create mode 100644 gatehouse_app/models/auth/password_reset_token.py create mode 100644 gatehouse_app/models/oidc/__init__.py rename gatehouse_app/models/{ => oidc}/oidc_audit_log.py (68%) rename gatehouse_app/models/{ => oidc}/oidc_authorization_code.py (65%) rename gatehouse_app/models/{ => oidc}/oidc_client.py (63%) create mode 100644 gatehouse_app/models/oidc/oidc_jwks_key.py rename gatehouse_app/models/{ => oidc}/oidc_refresh_token.py (67%) rename gatehouse_app/models/{ => oidc}/oidc_session.py (70%) rename gatehouse_app/models/{ => oidc}/oidc_token_metadata.py (67%) delete mode 100644 gatehouse_app/models/oidc_jwks_key.py create mode 100644 gatehouse_app/models/organization/__init__.py rename gatehouse_app/models/{ => organization}/department.py (80%) create mode 100644 gatehouse_app/models/organization/department_cert_policy.py create mode 100644 gatehouse_app/models/organization/org_invite_token.py rename gatehouse_app/models/{ => organization}/organization.py (94%) rename gatehouse_app/models/{ => organization}/organization_member.py (78%) rename gatehouse_app/models/{ => organization}/principal.py (69%) create mode 100644 gatehouse_app/models/security/__init__.py rename gatehouse_app/models/{ => security}/mfa_policy_compliance.py (76%) rename gatehouse_app/models/{ => security}/organization_security_policy.py (83%) rename gatehouse_app/models/{ => security}/user_security_policy.py (67%) create mode 100644 gatehouse_app/models/ssh_ca/__init__.py rename gatehouse_app/models/{ => ssh_ca}/ca.py (72%) rename gatehouse_app/models/{ => ssh_ca}/certificate_audit_log.py (78%) rename gatehouse_app/models/{ => ssh_ca}/ssh_certificate.py (70%) rename gatehouse_app/models/{ => ssh_ca}/ssh_key.py (55%) create mode 100644 gatehouse_app/models/user/__init__.py rename gatehouse_app/models/{ => user}/session.py (86%) create mode 100644 gatehouse_app/models/user/user.py diff --git a/gatehouse_app/models/__init__.py b/gatehouse_app/models/__init__.py index 2b01a60..f15764c 100644 --- a/gatehouse_app/models/__init__.py +++ b/gatehouse_app/models/__init__.py @@ -1,76 +1,150 @@ -"""Models package.""" -from gatehouse_app.models.base import BaseModel -from gatehouse_app.models.user import User -from gatehouse_app.models.organization import Organization -from gatehouse_app.models.organization_member import OrganizationMember -from gatehouse_app.models.authentication_method import ( +"""Models package. + +Sub-packages +------------ +models.user — User, Session +models.organization — Organization, OrganizationMember, Department, + DepartmentMembership, DepartmentPrincipal, + DepartmentCertPolicy, Principal, PrincipalMembership, + OrgInviteToken +models.auth — AuthenticationMethod, ApplicationProviderConfig, + OrganizationProviderOverride, OAuthState, + AuditLog, PasswordResetToken, EmailVerificationToken +models.oidc — OIDCClient, OIDCAuthCode, OIDCRefreshToken, OIDCSession, + OIDCTokenMetadata, OIDCAuditLog, OidcJwksKey +models.ssh_ca — CA, KeyType, CertType, CaType, CAPermission, + SSHKey, SSHCertificate, CertificateStatus, + CertificateAuditLog +models.security — OrganizationSecurityPolicy, UserSecurityPolicy, + MfaPolicyCompliance + +All names are re-exported here so that existing code using the flat import +style (``from gatehouse_app.models import X``) or the old per-file style +(``from gatehouse_app.models.user import User``) continue to work unchanged. +""" + +# ── Base ────────────────────────────────────────────────────────────────────── +from gatehouse_app.models.base import BaseModel # noqa: F401 + +# ── User ────────────────────────────────────────────────────────────────────── +from gatehouse_app.models.user.user import User # noqa: F401 +from gatehouse_app.models.user.session import Session # noqa: F401 + +# ── Organization ────────────────────────────────────────────────────────────── +from gatehouse_app.models.organization.organization import Organization # noqa: F401 +from gatehouse_app.models.organization.organization_member import ( # noqa: F401 + OrganizationMember, +) +from gatehouse_app.models.organization.department import ( # noqa: F401 + Department, + DepartmentMembership, + DepartmentPrincipal, +) +from gatehouse_app.models.organization.department_cert_policy import ( # noqa: F401 + DepartmentCertPolicy, + STANDARD_EXTENSIONS, +) +from gatehouse_app.models.organization.principal import ( # noqa: F401 + Principal, + PrincipalMembership, +) +from gatehouse_app.models.organization.org_invite_token import OrgInviteToken # noqa: F401 + +# ── Auth ────────────────────────────────────────────────────────────────────── +from gatehouse_app.models.auth.authentication_method import ( # noqa: F401 AuthenticationMethod, ApplicationProviderConfig, OrganizationProviderOverride, OAuthState, ) -from gatehouse_app.models.session import Session -from gatehouse_app.models.audit_log import AuditLog -from gatehouse_app.models.oidc_client import OIDCClient -from gatehouse_app.models.oidc_authorization_code import OIDCAuthCode -from gatehouse_app.models.oidc_refresh_token import OIDCRefreshToken -from gatehouse_app.models.oidc_session import OIDCSession -from gatehouse_app.models.oidc_token_metadata import OIDCTokenMetadata -from gatehouse_app.models.oidc_audit_log import OIDCAuditLog -from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy -from gatehouse_app.models.user_security_policy import UserSecurityPolicy -from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance -from gatehouse_app.models.department import ( - Department, - DepartmentMembership, - DepartmentPrincipal, +from gatehouse_app.models.auth.audit_log import AuditLog # noqa: F401 +from gatehouse_app.models.auth.password_reset_token import PasswordResetToken # noqa: F401 +from gatehouse_app.models.auth.email_verification_token import ( # noqa: F401 + EmailVerificationToken, ) -from gatehouse_app.models.principal import ( - Principal, - PrincipalMembership, + +# ── OIDC ────────────────────────────────────────────────────────────────────── +from gatehouse_app.models.oidc.oidc_client import OIDCClient # noqa: F401 +from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode # noqa: F401 +from gatehouse_app.models.oidc.oidc_refresh_token import OIDCRefreshToken # noqa: F401 +from gatehouse_app.models.oidc.oidc_session import OIDCSession # noqa: F401 +from gatehouse_app.models.oidc.oidc_token_metadata import OIDCTokenMetadata # noqa: F401 +from gatehouse_app.models.oidc.oidc_audit_log import OIDCAuditLog # noqa: F401 +from gatehouse_app.models.oidc.oidc_jwks_key import OidcJwksKey # noqa: F401 + +# ── SSH / CA ────────────────────────────────────────────────────────────────── +from gatehouse_app.models.ssh_ca.ca import ( # noqa: F401 + CA, + KeyType, + CertType, + CaType, + CAPermission, +) +from gatehouse_app.models.ssh_ca.ssh_key import SSHKey # noqa: F401 +from gatehouse_app.models.ssh_ca.ssh_certificate import ( # noqa: F401 + SSHCertificate, + CertificateStatus, +) +from gatehouse_app.models.ssh_ca.certificate_audit_log import ( # noqa: F401 + CertificateAuditLog, +) + +# ── Security ────────────────────────────────────────────────────────────────── +from gatehouse_app.models.security.organization_security_policy import ( # noqa: F401 + OrganizationSecurityPolicy, +) +from gatehouse_app.models.security.user_security_policy import ( # noqa: F401 + UserSecurityPolicy, +) +from gatehouse_app.models.security.mfa_policy_compliance import ( # noqa: F401 + MfaPolicyCompliance, ) -from gatehouse_app.models.ssh_key import SSHKey -from gatehouse_app.models.ca import CA, KeyType, CertType, CAPermission -from gatehouse_app.models.ssh_certificate import SSHCertificate, CertificateStatus -from gatehouse_app.models.certificate_audit_log import CertificateAuditLog -from gatehouse_app.models.password_reset_token import PasswordResetToken -from gatehouse_app.models.email_verification_token import EmailVerificationToken -from gatehouse_app.models.org_invite_token import OrgInviteToken __all__ = [ + # Base "BaseModel", + # User "User", + "Session", + # Organization "Organization", "OrganizationMember", + "Department", + "DepartmentMembership", + "DepartmentPrincipal", + "DepartmentCertPolicy", + "STANDARD_EXTENSIONS", + "Principal", + "PrincipalMembership", + "OrgInviteToken", + # Auth "AuthenticationMethod", "ApplicationProviderConfig", "OrganizationProviderOverride", "OAuthState", - "Session", "AuditLog", + "PasswordResetToken", + "EmailVerificationToken", + # OIDC "OIDCClient", "OIDCAuthCode", "OIDCRefreshToken", "OIDCSession", "OIDCTokenMetadata", "OIDCAuditLog", - "OrganizationSecurityPolicy", - "UserSecurityPolicy", - "MfaPolicyCompliance", - "Department", - "DepartmentMembership", - "DepartmentPrincipal", - "Principal", - "PrincipalMembership", - "SSHKey", + "OidcJwksKey", + # SSH / CA "CA", "KeyType", "CertType", + "CaType", "CAPermission", + "SSHKey", "SSHCertificate", "CertificateStatus", "CertificateAuditLog", - "PasswordResetToken", - "EmailVerificationToken", - "OrgInviteToken", + # Security + "OrganizationSecurityPolicy", + "UserSecurityPolicy", + "MfaPolicyCompliance", ] diff --git a/gatehouse_app/models/auth/__init__.py b/gatehouse_app/models/auth/__init__.py new file mode 100644 index 0000000..e28b467 --- /dev/null +++ b/gatehouse_app/models/auth/__init__.py @@ -0,0 +1,20 @@ +"""Auth subpackage — authentication methods, tokens, and audit logs.""" +from gatehouse_app.models.auth.authentication_method import ( + AuthenticationMethod, + ApplicationProviderConfig, + OrganizationProviderOverride, + OAuthState, +) +from gatehouse_app.models.auth.audit_log import AuditLog +from gatehouse_app.models.auth.password_reset_token import PasswordResetToken +from gatehouse_app.models.auth.email_verification_token import EmailVerificationToken + +__all__ = [ + "AuthenticationMethod", + "ApplicationProviderConfig", + "OrganizationProviderOverride", + "OAuthState", + "AuditLog", + "PasswordResetToken", + "EmailVerificationToken", +] diff --git a/gatehouse_app/models/audit_log.py b/gatehouse_app/models/auth/audit_log.py similarity index 92% rename from gatehouse_app/models/audit_log.py rename to gatehouse_app/models/auth/audit_log.py index 3e3cea1..849f915 100644 --- a/gatehouse_app/models/audit_log.py +++ b/gatehouse_app/models/auth/audit_log.py @@ -26,14 +26,13 @@ class AuditLog(BaseModel): extra_data = db.Column(db.JSON, nullable=True) description = db.Column(db.Text, nullable=True) - # Success/failure + # Outcome success = db.Column(db.Boolean, default=True, nullable=False) error_message = db.Column(db.Text, nullable=True) # Relationships user = db.relationship("User", back_populates="audit_logs") - # Indexes for common queries __table_args__ = ( db.Index("idx_audit_user_action", "user_id", "action"), db.Index("idx_audit_resource", "resource_type", "resource_id"), @@ -45,9 +44,8 @@ class AuditLog(BaseModel): return f"" @classmethod - def log(cls, action, user_id=None, **kwargs): - """ - Create an audit log entry. + def log(cls, action, user_id=None, **kwargs) -> "AuditLog": + """Create an audit log entry. Args: action: AuditAction enum value diff --git a/gatehouse_app/models/authentication_method.py b/gatehouse_app/models/auth/authentication_method.py similarity index 74% rename from gatehouse_app/models/authentication_method.py rename to gatehouse_app/models/auth/authentication_method.py index 3766d52..3fbd7c0 100644 --- a/gatehouse_app/models/authentication_method.py +++ b/gatehouse_app/models/auth/authentication_method.py @@ -1,4 +1,4 @@ -"""Authentication method model.""" +"""Authentication method model — user credentials and OAuth provider config.""" from datetime import datetime, timedelta, timezone import secrets from gatehouse_app.extensions import db @@ -35,7 +35,6 @@ class AuthenticationMethod(BaseModel): # Relationships user = db.relationship("User", back_populates="authentication_methods") - # Ensure unique provider combinations __table_args__ = ( db.Index("idx_user_method", "user_id", "method_type"), db.UniqueConstraint( @@ -45,13 +44,15 @@ class AuthenticationMethod(BaseModel): def __repr__(self): """String representation of AuthenticationMethod.""" - return f"" + return ( + f"" + ) - def is_password(self): + def is_password(self) -> bool: """Check if this is a password authentication method.""" return self.method_type == AuthMethodType.PASSWORD - def is_oauth(self): + def is_oauth(self) -> bool: """Check if this is an OAuth authentication method.""" return self.method_type in [ AuthMethodType.GOOGLE, @@ -59,32 +60,32 @@ class AuthenticationMethod(BaseModel): AuthMethodType.MICROSOFT, ] - def is_totp(self): + def is_totp(self) -> bool: """Check if this is a TOTP authentication method.""" return self.method_type == AuthMethodType.TOTP - def is_webauthn(self): + def is_webauthn(self) -> bool: """Check if this is a WebAuthn authentication method.""" return self.method_type == AuthMethodType.WEBAUTHN def to_dict(self, exclude=None): """Convert to dictionary, excluding sensitive fields.""" exclude = exclude or [] - # Always exclude password hash and TOTP secrets - exclude.append("password_hash") - exclude.append("totp_secret") - exclude.append("totp_backup_codes") + # Always exclude credential material + for field in ("password_hash", "totp_secret", "totp_backup_codes"): + if field not in exclude: + exclude.append(field) return super().to_dict(exclude=exclude) def to_webauthn_dict(self): """Convert WebAuthn credential to public dictionary. - + Returns: - Dictionary with safe-to-expose credential information. + Dictionary with safe-to-expose credential information, or None. """ if not self.is_webauthn() or not self.provider_data: return None - + data = self.provider_data return { "id": data.get("credential_id"), @@ -98,26 +99,26 @@ class AuthenticationMethod(BaseModel): class ApplicationProviderConfig(BaseModel): """Application-wide OAuth provider configuration. - - This model stores OAuth provider credentials at the application level, - allowing users to authenticate without needing to specify an organization first. + + Stores OAuth provider credentials at the application level, allowing users + to authenticate without needing to specify an organization first. """ __tablename__ = "application_provider_configs" # Provider identification provider_type = db.Column(db.String(50), nullable=False, unique=True, index=True) - - # OAuth credentials (encrypted) + + # OAuth credentials (client_secret encrypted at rest) client_id = db.Column(db.String(255), nullable=False) client_secret_encrypted = db.Column(db.String(512), nullable=True) - + # Provider status is_enabled = db.Column(db.Boolean, default=True, nullable=False) - + # Default redirect URL default_redirect_url = db.Column(db.String(2048), nullable=True) - + # Provider-specific settings (JSON) additional_config = db.Column(db.JSON, nullable=True) @@ -126,28 +127,34 @@ class ApplicationProviderConfig(BaseModel): "OrganizationProviderOverride", back_populates="application_config", foreign_keys="OrganizationProviderOverride.provider_type", - primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type", - cascade="all, delete-orphan" + primaryjoin=( + "ApplicationProviderConfig.provider_type" + "==OrganizationProviderOverride.provider_type" + ), + cascade="all, delete-orphan", ) def __repr__(self): """String representation of ApplicationProviderConfig.""" - return f"" + return ( + f"" + ) - def set_client_secret(self, plaintext_secret: str): + def set_client_secret(self, plaintext_secret: str) -> None: """Encrypt and store client secret. - + Args: plaintext_secret: The plaintext OAuth client secret """ if plaintext_secret: self.client_secret_encrypted = encrypt(plaintext_secret) - def get_client_secret(self) -> str: + def get_client_secret(self) -> str | None: """Decrypt and return client secret. - + Returns: - The plaintext OAuth client secret + The plaintext OAuth client secret, or None if not set. """ if self.client_secret_encrypted: return decrypt(self.client_secret_encrypted) @@ -156,37 +163,38 @@ class ApplicationProviderConfig(BaseModel): def to_dict(self, exclude=None): """Convert to dictionary, excluding sensitive fields.""" exclude = exclude or [] - # Always exclude encrypted client secret - exclude.append("client_secret_encrypted") + if "client_secret_encrypted" not in exclude: + exclude.append("client_secret_encrypted") return super().to_dict(exclude=exclude) class OrganizationProviderOverride(BaseModel): """Organization-specific OAuth configuration overrides. - - This model allows organizations to override application-level OAuth settings - for enterprise SSO scenarios or custom provider configurations. + + Allows organizations to override application-level OAuth settings for + enterprise SSO scenarios or custom provider configurations. """ __tablename__ = "organization_provider_overrides" - # References organization_id = db.Column( - db.String(36), db.ForeignKey("organizations.id"), - nullable=False, index=True + db.String(36), + db.ForeignKey("organizations.id"), + nullable=False, + index=True, ) provider_type = db.Column(db.String(50), nullable=False, index=True) - - # Override OAuth credentials (encrypted, nullable - only if overriding) + + # Override OAuth credentials (encrypted, nullable — only set when overriding) client_id = db.Column(db.String(255), nullable=True) client_secret_encrypted = db.Column(db.String(512), nullable=True) - + # Provider status is_enabled = db.Column(db.Boolean, default=True, nullable=False) - + # Redirect URL override redirect_url_override = db.Column(db.String(2048), nullable=True) - + # Provider-specific settings override (JSON) additional_config = db.Column(db.JSON, nullable=True) @@ -196,37 +204,33 @@ class OrganizationProviderOverride(BaseModel): "ApplicationProviderConfig", back_populates="organization_overrides", foreign_keys=[provider_type], - primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type", - viewonly=True + primaryjoin=( + "ApplicationProviderConfig.provider_type" + "==OrganizationProviderOverride.provider_type" + ), + viewonly=True, ) - # Unique constraint on (organization_id, provider_type) __table_args__ = ( db.UniqueConstraint( - "organization_id", "provider_type", - name="uix_org_provider_type" + "organization_id", "provider_type", name="uix_org_provider_type" ), ) def __repr__(self): """String representation of OrganizationProviderOverride.""" - return f"" + return ( + f"" + ) - def set_client_secret(self, plaintext_secret: str): - """Encrypt and store client secret override. - - Args: - plaintext_secret: The plaintext OAuth client secret - """ + def set_client_secret(self, plaintext_secret: str) -> None: + """Encrypt and store client secret override.""" if plaintext_secret: self.client_secret_encrypted = encrypt(plaintext_secret) - def get_client_secret(self) -> str: - """Decrypt and return client secret override. - - Returns: - The plaintext OAuth client secret - """ + def get_client_secret(self) -> str | None: + """Decrypt and return client secret override.""" if self.client_secret_encrypted: return decrypt(self.client_secret_encrypted) return None @@ -234,53 +238,52 @@ class OrganizationProviderOverride(BaseModel): def to_dict(self, exclude=None): """Convert to dictionary, excluding sensitive fields.""" exclude = exclude or [] - # Always exclude encrypted client secret - exclude.append("client_secret_encrypted") + if "client_secret_encrypted" not in exclude: + exclude.append("client_secret_encrypted") return super().to_dict(exclude=exclude) class OAuthState(BaseModel): """OAuth flow state tracking. - - This model tracks OAuth authentication flow state, including PKCE parameters - and organization context (which is now optional to support login flows where - the organization isn't known until after authentication). + + Tracks OAuth authentication flow state, including PKCE parameters and + organization context (which is optional to support login flows where the + organization isn't known until after authentication). """ __tablename__ = "oauth_states" # OAuth state parameter (unique, used for CSRF protection) state = db.Column(db.String(64), unique=True, nullable=False, index=True) - + # Flow type: "login", "register", "link" flow_type = db.Column(db.String(50), nullable=False) - + # Provider type provider_type = db.Column(db.String(50), nullable=False) - - # User context (optional - not set for login/register flows) + + # User context (optional — not set for login/register flows) user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True) - - # Organization context (NOW OPTIONAL - for SSO discovery or post-auth) + + # Organization context (optional — for SSO discovery or post-auth) organization_id = db.Column( - db.String(36), db.ForeignKey("organizations.id"), - nullable=True, index=True + db.String(36), db.ForeignKey("organizations.id"), nullable=True, index=True ) - + # PKCE parameters nonce = db.Column(db.String(128), nullable=True) code_verifier = db.Column(db.String(128), nullable=True) code_challenge = db.Column(db.String(128), nullable=True) - + # OAuth parameters redirect_uri = db.Column(db.String(2048), nullable=True) - + # Post-auth redirect (for frontend routing) return_url = db.Column(db.String(2048), nullable=True) - + # Additional state data extra_data = db.Column(db.JSON, nullable=True) - + # Expiration and usage tracking expires_at = db.Column(db.DateTime, nullable=False, index=True) used = db.Column(db.Boolean, default=False, nullable=False) @@ -291,7 +294,10 @@ class OAuthState(BaseModel): def __repr__(self): """String representation of OAuthState.""" - return f"" + return ( + f"" + ) @classmethod def create_state( @@ -306,10 +312,10 @@ class OAuthState(BaseModel): code_challenge: str = None, nonce: str = None, extra_data: dict = None, - lifetime_seconds: int = 600 - ): - """Create a new OAuth state with auto-generated state parameter. - + lifetime_seconds: int = 600, + ) -> "OAuthState": + """Create a new OAuth state with an auto-generated state parameter. + Args: flow_type: Type of flow ("login", "register", "link") provider_type: OAuth provider type @@ -322,13 +328,13 @@ class OAuthState(BaseModel): nonce: OpenID Connect nonce extra_data: Additional state data lifetime_seconds: How long the state is valid (default 10 minutes) - + Returns: New OAuthState instance """ state = secrets.token_urlsafe(32) expires_at = datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds) - + oauth_state = cls( state=state, flow_type=flow_type, @@ -342,31 +348,30 @@ class OAuthState(BaseModel): nonce=nonce, extra_data=extra_data, expires_at=expires_at, - used=False + used=False, ) oauth_state.save() return oauth_state def is_valid(self) -> bool: """Check if the OAuth state is still valid. - + Returns: - True if state hasn't expired and hasn't been used + True if state hasn't expired and hasn't been used. """ now = datetime.now(timezone.utc) - # Make expires_at timezone-aware if it's naive (database returns naive datetimes) expires_at = self.expires_at if expires_at.tzinfo is None: expires_at = expires_at.replace(tzinfo=timezone.utc) return not self.used and expires_at > now - def mark_used(self): + def mark_used(self) -> None: """Mark the state as used to prevent replay attacks.""" self.used = True self.save() @classmethod - def cleanup_expired(cls): + def cleanup_expired(cls) -> None: """Remove expired OAuth states.""" now = datetime.now(timezone.utc) cls.query.filter(cls.expires_at < now).delete() @@ -375,6 +380,7 @@ class OAuthState(BaseModel): def to_dict(self, exclude=None): """Convert to dictionary, excluding sensitive fields.""" exclude = exclude or [] - # Exclude code_verifier as it's sensitive - exclude.append("code_verifier") + # code_verifier must never be exposed + if "code_verifier" not in exclude: + exclude.append("code_verifier") return super().to_dict(exclude=exclude) diff --git a/gatehouse_app/models/auth/email_verification_token.py b/gatehouse_app/models/auth/email_verification_token.py new file mode 100644 index 0000000..9f40682 --- /dev/null +++ b/gatehouse_app/models/auth/email_verification_token.py @@ -0,0 +1,68 @@ +"""Email verification token model.""" +import secrets +from datetime import datetime, timezone, timedelta + +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel + + +class EmailVerificationToken(BaseModel): + """Single-use token for verifying a user's email address.""" + + __tablename__ = "email_verification_tokens" + + user_id = db.Column( + db.String(36), + db.ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + token = db.Column(db.String(128), unique=True, nullable=False, index=True) + expires_at = db.Column(db.DateTime, nullable=False) + used_at = db.Column(db.DateTime, nullable=True) + + user = db.relationship( + "User", + backref=db.backref("email_verification_tokens", cascade="all, delete-orphan"), + ) + + @classmethod + def generate(cls, user_id: str, ttl_hours: int = 24) -> "EmailVerificationToken": + """Create a new verification token for a user. + + Any existing unused tokens for this user are invalidated first. + """ + cls.query.filter_by(user_id=user_id, used_at=None).delete() + db.session.flush() + + token_value = secrets.token_urlsafe(48) + instance = cls( + user_id=user_id, + token=token_value, + expires_at=datetime.now(timezone.utc) + timedelta(hours=ttl_hours), + ) + db.session.add(instance) + db.session.commit() + return instance + + @property + def is_valid(self) -> bool: + """Return True if the token has not been used and has not expired.""" + if self.used_at is not None: + return False + now = datetime.now(timezone.utc) + expires = self.expires_at + if expires.tzinfo is None: + expires = expires.replace(tzinfo=timezone.utc) + return now < expires + + def consume(self) -> None: + """Mark the token as used.""" + self.used_at = datetime.now(timezone.utc) + db.session.commit() + + def __repr__(self) -> str: + return ( + f"" + ) diff --git a/gatehouse_app/models/auth/password_reset_token.py b/gatehouse_app/models/auth/password_reset_token.py new file mode 100644 index 0000000..53072ef --- /dev/null +++ b/gatehouse_app/models/auth/password_reset_token.py @@ -0,0 +1,69 @@ +"""Password reset token model.""" +import secrets +from datetime import datetime, timezone, timedelta + +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel + + +class PasswordResetToken(BaseModel): + """Single-use token for resetting a user's password.""" + + __tablename__ = "password_reset_tokens" + + user_id = db.Column( + db.String(36), + db.ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + token = db.Column(db.String(128), unique=True, nullable=False, index=True) + expires_at = db.Column(db.DateTime, nullable=False) + used_at = db.Column(db.DateTime, nullable=True) + + user = db.relationship( + "User", + backref=db.backref("password_reset_tokens", cascade="all, delete-orphan"), + ) + + @classmethod + def generate(cls, user_id: str, ttl_hours: int = 2) -> "PasswordResetToken": + """Create a new password reset token for a user. + + Any existing unused tokens for this user are invalidated first. + """ + # Invalidate any existing unused tokens for this user + cls.query.filter_by(user_id=user_id, used_at=None).delete() + db.session.flush() + + token_value = secrets.token_urlsafe(48) + instance = cls( + user_id=user_id, + token=token_value, + expires_at=datetime.now(timezone.utc) + timedelta(hours=ttl_hours), + ) + db.session.add(instance) + db.session.commit() + return instance + + @property + def is_valid(self) -> bool: + """Return True if the token has not been used and has not expired.""" + if self.used_at is not None: + return False + now = datetime.now(timezone.utc) + expires = self.expires_at + if expires.tzinfo is None: + expires = expires.replace(tzinfo=timezone.utc) + return now < expires + + def consume(self) -> None: + """Mark the token as used.""" + self.used_at = datetime.now(timezone.utc) + db.session.commit() + + def __repr__(self) -> str: + return ( + f"" + ) diff --git a/gatehouse_app/models/oidc/__init__.py b/gatehouse_app/models/oidc/__init__.py new file mode 100644 index 0000000..2cb08da --- /dev/null +++ b/gatehouse_app/models/oidc/__init__.py @@ -0,0 +1,18 @@ +"""OIDC subpackage — clients, tokens, sessions, and audit logs.""" +from gatehouse_app.models.oidc.oidc_client import OIDCClient +from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode +from gatehouse_app.models.oidc.oidc_refresh_token import OIDCRefreshToken +from gatehouse_app.models.oidc.oidc_session import OIDCSession +from gatehouse_app.models.oidc.oidc_token_metadata import OIDCTokenMetadata +from gatehouse_app.models.oidc.oidc_audit_log import OIDCAuditLog +from gatehouse_app.models.oidc.oidc_jwks_key import OidcJwksKey + +__all__ = [ + "OIDCClient", + "OIDCAuthCode", + "OIDCRefreshToken", + "OIDCSession", + "OIDCTokenMetadata", + "OIDCAuditLog", + "OidcJwksKey", +] diff --git a/gatehouse_app/models/oidc_audit_log.py b/gatehouse_app/models/oidc/oidc_audit_log.py similarity index 68% rename from gatehouse_app/models/oidc_audit_log.py rename to gatehouse_app/models/oidc/oidc_audit_log.py index 39b21a5..c0ae557 100644 --- a/gatehouse_app/models/oidc_audit_log.py +++ b/gatehouse_app/models/oidc/oidc_audit_log.py @@ -1,5 +1,4 @@ """OIDC Audit Log model for comprehensive OIDC event tracking.""" -from datetime import datetime from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel @@ -7,8 +6,7 @@ from gatehouse_app.models.base import BaseModel class OIDCAuditLog(BaseModel): """OIDC Audit Log model for comprehensive OIDC event tracking. - This model logs all OIDC-related events for security, compliance, - and debugging purposes. + Logs all OIDC-related events for security, compliance, and debugging. """ __tablename__ = "oidc_audit_logs" @@ -46,16 +44,29 @@ class OIDCAuditLog(BaseModel): def __repr__(self): """String representation of OIDCAuditLog.""" status = "success" if self.success else "failed" - return f"" + return ( + f"" + ) @classmethod - def log_event(cls, event_type, client_id=None, user_id=None, success=True, - error_code=None, error_description=None, ip_address=None, - user_agent=None, request_id=None, event_metadata=None): + def log_event( + cls, + event_type: str, + client_id: str = None, + user_id: str = None, + success: bool = True, + error_code: str = None, + error_description: str = None, + ip_address: str = None, + user_agent: str = None, + request_id: str = None, + event_metadata: dict = None, + ) -> "OIDCAuditLog": """Log an OIDC event. Args: - event_type: Type of event (e.g., "authorization_request", "token_issue") + event_type: Type of event (e.g., "authorization_request") client_id: The OIDC client ID user_id: The user ID success: Whether the event was successful @@ -86,9 +97,19 @@ class OIDCAuditLog(BaseModel): return log @classmethod - def log_authorization_request(cls, client_id, user_id, redirect_uri, scope, - ip_address=None, user_agent=None, request_id=None, - success=True, error_code=None, error_description=None): + def log_authorization_request( + cls, + client_id: str, + user_id: str, + redirect_uri: str, + scope, + ip_address: str = None, + user_agent: str = None, + request_id: str = None, + success: bool = True, + error_code: str = None, + error_description: str = None, + ) -> "OIDCAuditLog": """Log an authorization request event.""" return cls.log_event( event_type="authorization_request", @@ -100,15 +121,19 @@ class OIDCAuditLog(BaseModel): ip_address=ip_address, user_agent=user_agent, request_id=request_id, - event_metadata={ - "redirect_uri": redirect_uri, - "scope": scope, - } + event_metadata={"redirect_uri": redirect_uri, "scope": scope}, ) @classmethod - def log_token_issue(cls, client_id, user_id, token_type, - ip_address=None, user_agent=None, request_id=None): + def log_token_issue( + cls, + client_id: str, + user_id: str, + token_type: str, + ip_address: str = None, + user_agent: str = None, + request_id: str = None, + ) -> "OIDCAuditLog": """Log a token issuance event.""" return cls.log_event( event_type="token_issue", @@ -118,12 +143,20 @@ class OIDCAuditLog(BaseModel): ip_address=ip_address, user_agent=user_agent, request_id=request_id, - event_metadata={"token_type": token_type} + event_metadata={"token_type": token_type}, ) @classmethod - def log_token_revocation(cls, client_id, user_id, token_type, reason=None, - ip_address=None, user_agent=None, request_id=None): + def log_token_revocation( + cls, + client_id: str, + user_id: str, + token_type: str, + reason: str = None, + ip_address: str = None, + user_agent: str = None, + request_id: str = None, + ) -> "OIDCAuditLog": """Log a token revocation event.""" return cls.log_event( event_type="token_revocation", @@ -133,15 +166,19 @@ class OIDCAuditLog(BaseModel): ip_address=ip_address, user_agent=user_agent, request_id=request_id, - event_metadata={ - "token_type": token_type, - "reason": reason, - } + event_metadata={"token_type": token_type, "reason": reason}, ) @classmethod - def log_authentication_failure(cls, client_id, error_code, error_description, - ip_address=None, user_agent=None, request_id=None): + def log_authentication_failure( + cls, + client_id: str, + error_code: str, + error_description: str, + ip_address: str = None, + user_agent: str = None, + request_id: str = None, + ) -> "OIDCAuditLog": """Log an authentication failure event.""" return cls.log_event( event_type="authentication_failure", @@ -155,7 +192,7 @@ class OIDCAuditLog(BaseModel): ) @classmethod - def get_events_for_user(cls, user_id, limit=100): + def get_events_for_user(cls, user_id: str, limit: int = 100) -> list: """Get audit events for a user. Args: @@ -165,13 +202,15 @@ class OIDCAuditLog(BaseModel): Returns: List of OIDCAuditLog instances """ - return cls.query.filter_by(user_id=user_id, deleted_at=None)\ - .order_by(cls.created_at.desc())\ - .limit(limit)\ + return ( + cls.query.filter_by(user_id=user_id, deleted_at=None) + .order_by(cls.created_at.desc()) + .limit(limit) .all() + ) @classmethod - def get_events_for_client(cls, client_id, limit=100): + def get_events_for_client(cls, client_id: str, limit: int = 100) -> list: """Get audit events for a client. Args: @@ -181,14 +220,22 @@ class OIDCAuditLog(BaseModel): Returns: List of OIDCAuditLog instances """ - return cls.query.filter_by(client_id=client_id, deleted_at=None)\ - .order_by(cls.created_at.desc())\ - .limit(limit)\ + return ( + cls.query.filter_by(client_id=client_id, deleted_at=None) + .order_by(cls.created_at.desc()) + .limit(limit) .all() + ) @classmethod - def get_failed_events(cls, client_id=None, user_id=None, start_date=None, - end_date=None, limit=100): + def get_failed_events( + cls, + client_id: str = None, + user_id: str = None, + start_date=None, + end_date=None, + limit: int = 100, + ) -> list: """Get failed audit events. Args: @@ -210,22 +257,8 @@ class OIDCAuditLog(BaseModel): query = query.filter(cls.created_at >= start_date) if end_date: query = query.filter(cls.created_at <= end_date) - return query.order_by(cls.created_at.desc()).limit(limit).all() def to_dict(self, exclude=None): """Convert to dictionary.""" return super().to_dict(exclude=exclude) - - -# Add relationship back to User model -from gatehouse_app.models.user import User -User.oidc_audit_logs = db.relationship( - "OIDCAuditLog", back_populates="user", cascade="all, delete-orphan" -) - -# Add relationship back to OIDCClient model -from gatehouse_app.models.oidc_client import OIDCClient -OIDCClient.audit_logs = db.relationship( - "OIDCAuditLog", back_populates="client", cascade="all, delete-orphan" -) diff --git a/gatehouse_app/models/oidc_authorization_code.py b/gatehouse_app/models/oidc/oidc_authorization_code.py similarity index 65% rename from gatehouse_app/models/oidc_authorization_code.py rename to gatehouse_app/models/oidc/oidc_authorization_code.py index 640078e..3884592 100644 --- a/gatehouse_app/models/oidc_authorization_code.py +++ b/gatehouse_app/models/oidc/oidc_authorization_code.py @@ -1,14 +1,14 @@ -"""OIDC Authorization Code model for auth code flow.""" +"""OIDC Authorization Code model for the authorization code grant flow.""" from datetime import datetime, timedelta, timezone from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel class OIDCAuthCode(BaseModel): - """OIDC Authorization Code model for authorization code flow. + """OIDC Authorization Code model for the authorization code grant flow. - Authorization codes are single-use, short-lived codes used in the - authorization code grant flow. The code is hashed for security. + Authorization codes are single-use, short-lived codes. The code itself is + hashed before storage so that a database breach cannot replay codes. """ __tablename__ = "oidc_authorization_codes" @@ -26,9 +26,9 @@ class OIDCAuthCode(BaseModel): # Request parameters redirect_uri = db.Column(db.String(512), nullable=False) - scope = db.Column(db.JSON, nullable=True) # Requested scopes - nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation - code_verifier = db.Column(db.String(255), nullable=True) # For PKCE + scope = db.Column(db.JSON, nullable=True) + nonce = db.Column(db.String(255), nullable=True) + code_verifier = db.Column(db.String(255), nullable=True) # Status tracking expires_at = db.Column(db.DateTime, nullable=False, index=True) @@ -39,37 +39,48 @@ class OIDCAuthCode(BaseModel): ip_address = db.Column(db.String(45), nullable=True) user_agent = db.Column(db.Text, nullable=True) - # Relationships + # Relationships — back_populates declared on User and OIDCClient client = db.relationship("OIDCClient", back_populates="authorization_codes") user = db.relationship("User", back_populates="oidc_auth_codes") def __repr__(self): """String representation of OIDCAuthCode.""" - return f"" + return ( + f"" + ) - def is_expired(self): + def is_expired(self) -> bool: """Check if the authorization code has expired.""" - # Handle both timezone-aware and timezone-naive expires_at values expires_at = self.expires_at if expires_at.tzinfo is None: - # Make naive datetime timezone-aware (UTC) expires_at = expires_at.replace(tzinfo=timezone.utc) return datetime.now(timezone.utc) > expires_at - def is_valid(self): + def is_valid(self) -> bool: """Check if the authorization code is valid for use.""" return not self.is_used and not self.is_expired() and self.deleted_at is None - def mark_as_used(self): + def mark_as_used(self) -> None: """Mark the authorization code as used.""" self.is_used = True self.used_at = datetime.now(timezone.utc) db.session.commit() @classmethod - def create_code(cls, client_id, user_id, code_hash, redirect_uri, scope=None, - nonce=None, code_verifier=None, ip_address=None, user_agent=None, - lifetime_seconds=600): + def create_code( + cls, + client_id: str, + user_id: str, + code_hash: str, + redirect_uri: str, + scope=None, + nonce: str = None, + code_verifier: str = None, + ip_address: str = None, + user_agent: str = None, + lifetime_seconds: int = 600, + ) -> "OIDCAuthCode": """Create a new authorization code. Args: @@ -79,7 +90,7 @@ class OIDCAuthCode(BaseModel): redirect_uri: The redirect URI scope: Requested scopes nonce: OIDC nonce - code_verifier: PKCE code verifier + code_verifier: PKCE code verifier (stored hashed server-side) ip_address: Client IP address user_agent: Client user agent lifetime_seconds: Code lifetime in seconds (default 10 minutes) @@ -106,20 +117,7 @@ class OIDCAuthCode(BaseModel): def to_dict(self, exclude=None): """Convert to dictionary, excluding sensitive fields.""" exclude = exclude or [] - # Always exclude code hash - exclude.append("code_hash") - exclude.append("code_verifier") + for field in ("code_hash", "code_verifier"): + if field not in exclude: + exclude.append(field) return super().to_dict(exclude=exclude) - - -# Add relationship back to User model -from gatehouse_app.models.user import User -User.oidc_auth_codes = db.relationship( - "OIDCAuthCode", back_populates="user", cascade="all, delete-orphan" -) - -# Add relationship back to OIDCClient model -from gatehouse_app.models.oidc_client import OIDCClient -OIDCClient.authorization_codes = db.relationship( - "OIDCAuthCode", back_populates="client", cascade="all, delete-orphan" -) diff --git a/gatehouse_app/models/oidc_client.py b/gatehouse_app/models/oidc/oidc_client.py similarity index 63% rename from gatehouse_app/models/oidc_client.py rename to gatehouse_app/models/oidc/oidc_client.py index a446983..03c0b18 100644 --- a/gatehouse_app/models/oidc_client.py +++ b/gatehouse_app/models/oidc/oidc_client.py @@ -17,10 +17,10 @@ class OIDCClient(BaseModel): client_secret_hash = db.Column(db.String(255), nullable=False) # OAuth/OIDC configuration - redirect_uris = db.Column(db.JSON, nullable=False) # List of allowed redirect URIs - grant_types = db.Column(db.JSON, nullable=False) # List of allowed grant types - response_types = db.Column(db.JSON, nullable=False) # List of allowed response types - scopes = db.Column(db.JSON, nullable=False) # List of allowed scopes + redirect_uris = db.Column(db.JSON, nullable=False) # Allowed redirect URIs + grant_types = db.Column(db.JSON, nullable=False) # Allowed grant types + response_types = db.Column(db.JSON, nullable=False) # Allowed response types + scopes = db.Column(db.JSON, nullable=False) # Allowed scopes # Client metadata logo_uri = db.Column(db.String(512), nullable=True) @@ -41,6 +41,23 @@ class OIDCClient(BaseModel): # Relationships organization = db.relationship("Organization", back_populates="oidc_clients") + # OIDC sub-resource relationships (declared here, not monkey-patched elsewhere) + authorization_codes = db.relationship( + "OIDCAuthCode", back_populates="client", cascade="all, delete-orphan" + ) + refresh_tokens = db.relationship( + "OIDCRefreshToken", back_populates="client", cascade="all, delete-orphan" + ) + oidc_sessions = db.relationship( + "OIDCSession", back_populates="client", cascade="all, delete-orphan" + ) + token_metadata = db.relationship( + "OIDCTokenMetadata", back_populates="client", cascade="all, delete-orphan" + ) + audit_logs = db.relationship( + "OIDCAuditLog", back_populates="client", cascade="all, delete-orphan" + ) + def __repr__(self): """String representation of OIDCClient.""" return f"" @@ -48,22 +65,22 @@ class OIDCClient(BaseModel): def to_dict(self, exclude=None): """Convert to dictionary, excluding sensitive fields.""" exclude = exclude or [] - # Always exclude client secret - exclude.append("client_secret_hash") + if "client_secret_hash" not in exclude: + exclude.append("client_secret_hash") return super().to_dict(exclude=exclude) - def has_grant_type(self, grant_type): + def has_grant_type(self, grant_type) -> bool: """Check if client supports a specific grant type.""" return grant_type in self.grant_types - def has_response_type(self, response_type): + def has_response_type(self, response_type) -> bool: """Check if client supports a specific response type.""" return response_type in self.response_types - def is_redirect_uri_allowed(self, redirect_uri): + def is_redirect_uri_allowed(self, redirect_uri: str) -> bool: """Check if a redirect URI is allowed for this client.""" return redirect_uri in self.redirect_uris - def has_scope(self, scope): + def has_scope(self, scope: str) -> bool: """Check if client is allowed to request a specific scope.""" return scope in self.scopes diff --git a/gatehouse_app/models/oidc/oidc_jwks_key.py b/gatehouse_app/models/oidc/oidc_jwks_key.py new file mode 100644 index 0000000..f8fa982 --- /dev/null +++ b/gatehouse_app/models/oidc/oidc_jwks_key.py @@ -0,0 +1,76 @@ +"""OIDC JWKS Key model for persisting signing keys.""" +from datetime import datetime, timezone +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel + + +class OidcJwksKey(BaseModel): + """OIDC JWKS Key model for persisting JSON Web Key Set signing keys. + + Stores RSA/ECDSA key pairs used for signing OIDC tokens. Multiple keys can + be stored to support key rotation scenarios. + + Attributes: + kid: Unique key ID used in JWT ``kid`` header + key_type: Type of key (e.g., "RSA", "EC") + private_key: PEM-encoded private key (never exposed in API responses) + public_key: PEM-encoded public key + algorithm: Signing algorithm (e.g., "RS256", "ES256") + is_active: Whether this key is currently used for signing/verification + is_primary: Whether this is the primary signing key + expires_at: Optional expiry for key rotation enforcement + """ + + __tablename__ = "oidc_jwks_keys" + + # Override the default UUID id with integer primary key for JWKS key sets + id = db.Column(db.Integer, primary_key=True) + + expires_at = db.Column(db.DateTime, nullable=True) + + # Key identification and type + kid = db.Column(db.String(255), unique=True, nullable=False, index=True) + key_type = db.Column(db.String(50), nullable=False) # e.g., "RSA", "EC" + algorithm = db.Column(db.String(50), nullable=False) # e.g., "RS256", "ES256" + + # Key material (PEM-encoded) — private_key must never be returned by API + private_key = db.Column(db.Text, nullable=False) + public_key = db.Column(db.Text, nullable=False) + + # Key status + is_active = db.Column(db.Boolean, default=True, nullable=False) + is_primary = db.Column(db.Boolean, default=False, nullable=False) + + def __repr__(self): + """String representation of OidcJwksKey.""" + return ( + f"" + ) + + def to_dict(self, exclude_private_key: bool = True): + """Convert model to dictionary. + + Args: + exclude_private_key: If True (default), excludes the private key. + + Returns: + Dictionary representation of the model + """ + exclude = ["private_key"] if exclude_private_key else [] + return super().to_dict(exclude=exclude) + + @classmethod + def get_active_keys(cls) -> list: + """Get all active keys for signing operations.""" + return cls.query.filter_by(is_active=True).all() + + @classmethod + def get_primary_key(cls) -> "OidcJwksKey | None": + """Get the primary signing key.""" + return cls.query.filter_by(is_primary=True).first() + + @classmethod + def get_key_by_kid(cls, kid: str) -> "OidcJwksKey | None": + """Get an active key by its key ID.""" + return cls.query.filter_by(kid=kid, is_active=True).first() diff --git a/gatehouse_app/models/oidc_refresh_token.py b/gatehouse_app/models/oidc/oidc_refresh_token.py similarity index 67% rename from gatehouse_app/models/oidc_refresh_token.py rename to gatehouse_app/models/oidc/oidc_refresh_token.py index a6459ea..3e1228f 100644 --- a/gatehouse_app/models/oidc_refresh_token.py +++ b/gatehouse_app/models/oidc/oidc_refresh_token.py @@ -1,5 +1,5 @@ """OIDC Refresh Token model for token rotation.""" -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel @@ -8,7 +8,8 @@ class OIDCRefreshToken(BaseModel): """OIDC Refresh Token model for token refresh and rotation. Refresh tokens are long-lived credentials used to obtain new access tokens. - They support token rotation for enhanced security. + They support token rotation for enhanced security — each use invalidates + the old token and issues a new one. """ __tablename__ = "oidc_refresh_tokens" @@ -21,16 +22,14 @@ class OIDCRefreshToken(BaseModel): db.String(36), db.ForeignKey("users.id"), nullable=False, index=True ) - # Token (hashed for security) + # Token (hashed for security — never store plaintext refresh tokens) token_hash = db.Column(db.String(255), nullable=False, unique=True, index=True) - # Associated access token ID (stores JWT JTI string — no FK to sessions) - access_token_id = db.Column( - db.String(255), nullable=True, index=True - ) + # Associated access token JTI (no FK — stored as string for lightweight lookup) + access_token_id = db.Column(db.String(255), nullable=True, index=True) # Token scope - scope = db.Column(db.JSON, nullable=True) # Granted scopes + scope = db.Column(db.JSON, nullable=True) # Timing expires_at = db.Column(db.DateTime, nullable=False, index=True) @@ -40,7 +39,7 @@ class OIDCRefreshToken(BaseModel): revoked_reason = db.Column(db.String(255), nullable=True) # Token rotation metadata - previous_token_hash = db.Column(db.String(255), nullable=True) # For rotation + previous_token_hash = db.Column(db.String(255), nullable=True) rotation_count = db.Column(db.Integer, default=0, nullable=False) # Request metadata @@ -53,25 +52,27 @@ class OIDCRefreshToken(BaseModel): def __repr__(self): """String representation of OIDCRefreshToken.""" - return f"" + return ( + f"" + ) - def is_expired(self): + def is_expired(self) -> bool: """Check if the refresh token has expired.""" - # Handle both timezone-aware and timezone-naive expires_at values expires_at = self.expires_at if expires_at.tzinfo is None: expires_at = expires_at.replace(tzinfo=timezone.utc) return datetime.now(timezone.utc) > expires_at - def is_revoked(self): + def is_revoked(self) -> bool: """Check if the refresh token has been revoked.""" return self.revoked_at is not None - def is_valid(self): + def is_valid(self) -> bool: """Check if the refresh token is valid for use.""" return not self.is_revoked() and not self.is_expired() and self.deleted_at is None - def revoke(self, reason=None): + def revoke(self, reason: str = None) -> None: """Revoke the refresh token. Args: @@ -81,8 +82,8 @@ class OIDCRefreshToken(BaseModel): self.revoked_reason = reason db.session.commit() - def rotate(self, new_token_hash): - """Rotate the refresh token (invalidate old, create new). + def rotate(self, new_token_hash: str) -> "OIDCRefreshToken": + """Rotate the refresh token — invalidate the old hash, store the new one. Args: new_token_hash: Hash of the new refresh token @@ -90,20 +91,25 @@ class OIDCRefreshToken(BaseModel): Returns: self for chaining """ - # Store reference to old token self.previous_token_hash = self.token_hash self.token_hash = new_token_hash self.rotation_count += 1 - # Extend expiration on rotation - from datetime import timedelta self.expires_at = datetime.now(timezone.utc) + timedelta(days=30) db.session.commit() return self @classmethod - def create_token(cls, client_id, user_id, token_hash, scope=None, - access_token_id=None, ip_address=None, user_agent=None, - lifetime_seconds=2592000): + def create_token( + cls, + client_id: str, + user_id: str, + token_hash: str, + scope=None, + access_token_id: str = None, + ip_address: str = None, + user_agent: str = None, + lifetime_seconds: int = 2592000, + ) -> "OIDCRefreshToken": """Create a new refresh token. Args: @@ -111,7 +117,7 @@ class OIDCRefreshToken(BaseModel): user_id: The user ID token_hash: Hashed refresh token scope: Granted scopes - access_token_id: Associated access token ID + access_token_id: Associated access token JTI ip_address: Client IP address user_agent: Client user agent lifetime_seconds: Token lifetime in seconds (default 30 days) @@ -119,7 +125,6 @@ class OIDCRefreshToken(BaseModel): Returns: OIDCRefreshToken instance """ - from datetime import timedelta token = cls( client_id=client_id, user_id=user_id, @@ -137,20 +142,7 @@ class OIDCRefreshToken(BaseModel): def to_dict(self, exclude=None): """Convert to dictionary, excluding sensitive fields.""" exclude = exclude or [] - # Always exclude token hashes - exclude.append("token_hash") - exclude.append("previous_token_hash") + for field in ("token_hash", "previous_token_hash"): + if field not in exclude: + exclude.append(field) return super().to_dict(exclude=exclude) - - -# Add relationship back to User model -from gatehouse_app.models.user import User -User.oidc_refresh_tokens = db.relationship( - "OIDCRefreshToken", back_populates="user", cascade="all, delete-orphan" -) - -# Add relationship back to OIDCClient model -from gatehouse_app.models.oidc_client import OIDCClient -OIDCClient.refresh_tokens = db.relationship( - "OIDCRefreshToken", back_populates="client", cascade="all, delete-orphan" -) diff --git a/gatehouse_app/models/oidc_session.py b/gatehouse_app/models/oidc/oidc_session.py similarity index 70% rename from gatehouse_app/models/oidc_session.py rename to gatehouse_app/models/oidc/oidc_session.py index 8d6a88b..5768bd6 100644 --- a/gatehouse_app/models/oidc_session.py +++ b/gatehouse_app/models/oidc/oidc_session.py @@ -1,5 +1,7 @@ """OIDC Session model for OIDC session tracking.""" -from datetime import datetime, timezone +import hashlib +import base64 +from datetime import datetime, timedelta, timezone from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel @@ -7,8 +9,8 @@ from gatehouse_app.models.base import BaseModel class OIDCSession(BaseModel): """OIDC Session model for tracking OIDC authentication sessions. - This model tracks the state during the OIDC authentication flow, - including PKCE parameters and nonce validation. + Tracks the state during the OIDC authorization flow, including PKCE + parameters and nonce validation. """ __tablename__ = "oidc_sessions" @@ -25,11 +27,11 @@ class OIDCSession(BaseModel): # State management state = db.Column(db.String(255), nullable=False, index=True) - nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation + nonce = db.Column(db.String(255), nullable=True) # Authorization request parameters redirect_uri = db.Column(db.String(512), nullable=False) - scope = db.Column(db.JSON, nullable=True) # Requested scopes + scope = db.Column(db.JSON, nullable=True) # PKCE parameters code_challenge = db.Column(db.String(255), nullable=True) @@ -45,50 +47,52 @@ class OIDCSession(BaseModel): def __repr__(self): """String representation of OIDCSession.""" - return f"" + return ( + f"" + ) - def is_expired(self): + def is_expired(self) -> bool: """Check if the OIDC session has expired.""" - return datetime.now(timezone.utc) > self.expires_at + expires_at = self.expires_at + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + return datetime.now(timezone.utc) > expires_at - def is_authenticated(self): + def is_authenticated(self) -> bool: """Check if the user has been authenticated in this session.""" return self.authenticated_at is not None - def mark_authenticated(self): + def mark_authenticated(self) -> None: """Mark the session as authenticated.""" self.authenticated_at = datetime.now(timezone.utc) db.session.commit() - def validate_nonce(self, expected_nonce): + def validate_nonce(self, expected_nonce: str) -> bool: """Validate the nonce matches the expected value. Args: expected_nonce: The expected nonce value Returns: - bool: True if nonce matches + True if nonce matches """ return self.nonce == expected_nonce - def validate_code_challenge(self, code_verifier): + def validate_code_challenge(self, code_verifier: str) -> bool: """Validate the code verifier against the stored code challenge. Args: code_verifier: The PKCE code verifier Returns: - bool: True if code challenge is valid + True if the challenge is satisfied """ if not self.code_challenge: return False if self.code_challenge_method == "S256": - import hashlib - import base64 - # SHA256 hash of code_verifier digest = hashlib.sha256(code_verifier.encode()).digest() - # Base64 URL encode without padding expected = base64.urlsafe_b64encode(digest).decode().rstrip("=") return self.code_challenge == expected elif self.code_challenge_method == "plain": @@ -97,9 +101,18 @@ class OIDCSession(BaseModel): return False @classmethod - def create_session(cls, user_id, client_id, state, redirect_uri, scope=None, - nonce=None, code_challenge=None, code_challenge_method=None, - lifetime_seconds=600): + def create_session( + cls, + user_id: str, + client_id: str, + state: str, + redirect_uri: str, + scope=None, + nonce: str = None, + code_challenge: str = None, + code_challenge_method: str = None, + lifetime_seconds: int = 600, + ) -> "OIDCSession": """Create a new OIDC session. Args: @@ -116,7 +129,6 @@ class OIDCSession(BaseModel): Returns: OIDCSession instance """ - from datetime import timedelta session = cls( user_id=user_id, client_id=client_id, @@ -133,7 +145,7 @@ class OIDCSession(BaseModel): return session @classmethod - def get_by_state(cls, state): + def get_by_state(cls, state: str) -> "OIDCSession | None": """Get a session by state parameter. Args: @@ -147,16 +159,3 @@ class OIDCSession(BaseModel): def to_dict(self, exclude=None): """Convert to dictionary.""" return super().to_dict(exclude=exclude) - - -# Add relationship back to User model -from gatehouse_app.models.user import User -User.oidc_sessions = db.relationship( - "OIDCSession", back_populates="user", cascade="all, delete-orphan" -) - -# Add relationship back to OIDCClient model -from gatehouse_app.models.oidc_client import OIDCClient -OIDCClient.oidc_sessions = db.relationship( - "OIDCSession", back_populates="client", cascade="all, delete-orphan" -) diff --git a/gatehouse_app/models/oidc_token_metadata.py b/gatehouse_app/models/oidc/oidc_token_metadata.py similarity index 67% rename from gatehouse_app/models/oidc_token_metadata.py rename to gatehouse_app/models/oidc/oidc_token_metadata.py index 2c6c7a8..be8c862 100644 --- a/gatehouse_app/models/oidc_token_metadata.py +++ b/gatehouse_app/models/oidc/oidc_token_metadata.py @@ -8,13 +8,14 @@ from gatehouse_app.models.base import BaseModel class OIDCTokenMetadata(BaseModel): """OIDC Token Metadata model for tracking issued tokens. - This model stores metadata about issued tokens (access tokens, refresh tokens, ID tokens) - for the purpose of token revocation. The id field matches the JTI (JWT ID) claim. + Stores metadata about issued tokens (access, refresh, ID) for revocation. + The ``id`` field on this model intentionally overrides the BaseModel UUID + to store the JWT JTI directly as the primary key for O(1) revocation checks. """ __tablename__ = "oidc_token_metadata" - # Token identifier (matches JTI in JWT) + # Primary key = JTI so revocation lookups are always a PK scan id = db.Column( db.String(36), primary_key=True, default=lambda: str(uuid.uuid4()) ) @@ -27,11 +28,11 @@ class OIDCTokenMetadata(BaseModel): db.String(36), db.ForeignKey("users.id"), nullable=False, index=True ) - # Token type - token_type = db.Column(db.String(50), nullable=False) # "access_token", "refresh_token", "id_token" + # Token type: "access_token", "refresh_token", or "id_token" + token_type = db.Column(db.String(50), nullable=False) - # Token identifier for revocation lookup - token_jti = db.Column(db.String(255), nullable=False, index=True) # JWT ID claim + # JWT ID claim (indexed for fast lookup when id != jti) + token_jti = db.Column(db.String(255), nullable=False, index=True) # Timing expires_at = db.Column(db.DateTime, nullable=False, index=True) @@ -46,25 +47,27 @@ class OIDCTokenMetadata(BaseModel): def __repr__(self): """String representation of OIDCTokenMetadata.""" - return f"" + return ( + f"" + ) - def is_expired(self): + def is_expired(self) -> bool: """Check if the token has expired.""" - # Handle both timezone-aware and timezone-naive expires_at values expires_at = self.expires_at if expires_at.tzinfo is None: expires_at = expires_at.replace(tzinfo=timezone.utc) return datetime.now(timezone.utc) > expires_at - def is_revoked(self): + def is_revoked(self) -> bool: """Check if the token has been revoked.""" return self.revoked_at is not None - def is_valid(self): + def is_valid(self) -> bool: """Check if the token is valid (not expired and not revoked).""" return not self.is_revoked() and not self.is_expired() and self.deleted_at is None - def revoke(self, reason=None): + def revoke(self, reason: str = None) -> None: """Revoke the token. Args: @@ -75,8 +78,16 @@ class OIDCTokenMetadata(BaseModel): db.session.commit() @classmethod - def create_metadata(cls, client_id, user_id, token_type, token_jti, - expires_at, ip_address=None, user_agent=None): + def create_metadata( + cls, + client_id: str, + user_id: str, + token_type: str, + token_jti: str, + expires_at, + ip_address: str = None, + user_agent: str = None, + ) -> "OIDCTokenMetadata": """Create token metadata for tracking. Args: @@ -85,8 +96,8 @@ class OIDCTokenMetadata(BaseModel): token_type: Type of token ("access_token", "refresh_token", "id_token") token_jti: JWT ID claim expires_at: Token expiration datetime - ip_address: Client IP address - user_agent: Client user agent + ip_address: Client IP address (unused column, kept for API compat) + user_agent: Client user agent (unused column, kept for API compat) Returns: OIDCTokenMetadata instance @@ -104,7 +115,7 @@ class OIDCTokenMetadata(BaseModel): return metadata @classmethod - def get_by_jti(cls, token_jti): + def get_by_jti(cls, token_jti: str) -> "OIDCTokenMetadata | None": """Get token metadata by JWT ID. Args: @@ -116,7 +127,7 @@ class OIDCTokenMetadata(BaseModel): return cls.query.filter_by(token_jti=token_jti, deleted_at=None).first() @classmethod - def revoke_by_jti(cls, token_jti, reason=None): + def revoke_by_jti(cls, token_jti: str, reason: str = None) -> bool: """Revoke a token by its JWT ID. Args: @@ -124,7 +135,7 @@ class OIDCTokenMetadata(BaseModel): reason: Optional revocation reason Returns: - bool: True if token was found and revoked + True if token was found and revoked, False otherwise """ metadata = cls.get_by_jti(token_jti) if metadata: @@ -133,47 +144,53 @@ class OIDCTokenMetadata(BaseModel): return False @classmethod - def revoke_all_for_user(cls, user_id, client_id=None, reason=None): + def revoke_all_for_user( + cls, user_id: str, client_id: str = None, reason: str = None + ) -> int: """Revoke all tokens for a user. Args: user_id: The user ID - client_id: Optional client ID to filter by + client_id: Optional client ID filter reason: Optional revocation reason Returns: - int: Number of tokens revoked + Number of tokens revoked """ - query = cls.query.filter_by(user_id=user_id, deleted_at=None) + query = cls.query.filter_by(user_id=user_id, deleted_at=None).filter( + cls.revoked_at.is_(None) + ) if client_id: query = query.filter_by(client_id=client_id) - tokens = query.filter(cls.revoked_at == None).all() count = 0 - for token in tokens: + for token in query.all(): token.revoke(reason) count += 1 return count @classmethod - def revoke_all_for_client(cls, client_id, user_id=None, reason=None): + def revoke_all_for_client( + cls, client_id: str, user_id: str = None, reason: str = None + ) -> int: """Revoke all tokens for a client. Args: client_id: The client ID - user_id: Optional user ID to filter by + user_id: Optional user ID filter reason: Optional revocation reason Returns: - int: Number of tokens revoked + Number of tokens revoked """ - query = cls.query.filter_by(client_id=client_id, deleted_at=None) + query = cls.query.filter_by(client_id=client_id, deleted_at=None).filter( + cls.revoked_at.is_(None) + ) if user_id: query = query.filter_by(user_id=user_id) - tokens = query.filter(cls.revoked_at == None).all() count = 0 - for token in tokens: + for token in query.all(): token.revoke(reason) count += 1 return count @@ -181,16 +198,3 @@ class OIDCTokenMetadata(BaseModel): def to_dict(self, exclude=None): """Convert to dictionary.""" return super().to_dict(exclude=exclude) - - -# Add relationship back to User model -from gatehouse_app.models.user import User -User.oidc_token_metadata = db.relationship( - "OIDCTokenMetadata", back_populates="user", cascade="all, delete-orphan" -) - -# Add relationship back to OIDCClient model -from gatehouse_app.models.oidc_client import OIDCClient -OIDCClient.token_metadata = db.relationship( - "OIDCTokenMetadata", back_populates="client", cascade="all, delete-orphan" -) diff --git a/gatehouse_app/models/oidc_jwks_key.py b/gatehouse_app/models/oidc_jwks_key.py deleted file mode 100644 index 07dcb80..0000000 --- a/gatehouse_app/models/oidc_jwks_key.py +++ /dev/null @@ -1,77 +0,0 @@ -"""OIDC JWKS Key model for persisting signing keys.""" -from datetime import datetime, timezone -from gatehouse_app.extensions import db -from gatehouse_app.models.base import BaseModel - - -class OidcJwksKey(BaseModel): - """ - OIDC JWKS Key model for persisting JSON Web Key Set signing keys. - - This model stores RSA/ECDSA key pairs used for signing OIDC tokens. - Multiple keys can be stored to support key rotation scenarios. - - Attributes: - id: Integer primary key - kid: Unique key ID used in JWT "kid" header - key_type: Type of key (e.g., "RSA", "EC") - private_key: PEM-encoded private key - public_key: PEM-encoded public key - algorithm: Signing algorithm (e.g., "RS256", "ES256") - created_at: When the key was created - is_active: Whether this key is currently active for signing - is_primary: Whether this is the primary signing key - expires_at: ... - """ - - __tablename__ = "oidc_jwks_keys" - - # Override the default UUID id with integer primary key - id = db.Column(db.Integer, primary_key=True) - - expires_at = db.Column(db.DateTime, nullable=True) - - # Key identification and type - kid = db.Column(db.String(255), unique=True, nullable=False, index=True) - key_type = db.Column(db.String(50), nullable=False) # e.g., "RSA", "EC" - algorithm = db.Column(db.String(50), nullable=False) # e.g., "RS256", "ES256" - - # Key material (PEM-encoded) - private_key = db.Column(db.Text, nullable=False) - public_key = db.Column(db.Text, nullable=False) - - # Key status - is_active = db.Column(db.Boolean, default=True, nullable=False) - is_primary = db.Column(db.Boolean, default=False, nullable=False) - - def __repr__(self): - """String representation of OidcJwksKey.""" - return f"" - - def to_dict(self, exclude_private_key=True): - """ - Convert model to dictionary. - - Args: - exclude_private_key: If True, excludes the private key from output - - Returns: - Dictionary representation of the model - """ - exclude = ["private_key"] if exclude_private_key else [] - return super().to_dict(exclude=exclude) - - @classmethod - def get_active_keys(cls): - """Get all active keys for signing operations.""" - return cls.query.filter(cls.is_active == True).all() - - @classmethod - def get_primary_key(cls): - """Get the primary signing key.""" - return cls.query.filter(cls.is_primary == True).first() - - @classmethod - def get_key_by_kid(cls, kid): - """Get a key by its key ID.""" - return cls.query.filter(cls.kid == kid, cls.is_active == True).first() \ No newline at end of file diff --git a/gatehouse_app/models/organization/__init__.py b/gatehouse_app/models/organization/__init__.py new file mode 100644 index 0000000..aa33f8e --- /dev/null +++ b/gatehouse_app/models/organization/__init__.py @@ -0,0 +1,27 @@ +"""Organization subpackage.""" +from gatehouse_app.models.organization.organization import Organization +from gatehouse_app.models.organization.organization_member import OrganizationMember +from gatehouse_app.models.organization.department import ( + Department, + DepartmentMembership, + DepartmentPrincipal, +) +from gatehouse_app.models.organization.department_cert_policy import ( + DepartmentCertPolicy, + STANDARD_EXTENSIONS, +) +from gatehouse_app.models.organization.principal import Principal, PrincipalMembership +from gatehouse_app.models.organization.org_invite_token import OrgInviteToken + +__all__ = [ + "Organization", + "OrganizationMember", + "Department", + "DepartmentMembership", + "DepartmentPrincipal", + "DepartmentCertPolicy", + "STANDARD_EXTENSIONS", + "Principal", + "PrincipalMembership", + "OrgInviteToken", +] diff --git a/gatehouse_app/models/department.py b/gatehouse_app/models/organization/department.py similarity index 80% rename from gatehouse_app/models/department.py rename to gatehouse_app/models/organization/department.py index 30d1a0f..800780b 100644 --- a/gatehouse_app/models/department.py +++ b/gatehouse_app/models/organization/department.py @@ -1,14 +1,15 @@ +"""Department, DepartmentMembership, and DepartmentPrincipal models.""" from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel class Department(BaseModel): """Department model representing an organizational unit for SSH access control. - + Departments are used to group users and assign SSH principals (access levels) to them. A user can be a member of multiple departments, and each department can have multiple principals assigned. - + Example: - Department: "Engineering" - Members: user1@example.com, user2@example.com @@ -39,12 +40,15 @@ class Department(BaseModel): back_populates="department", cascade="all, delete-orphan", ) + cert_policy = db.relationship( + "DepartmentCertPolicy", + back_populates="department", + uselist=False, + cascade="all, delete-orphan", + ) - # Unique constraint: department name per organization __table_args__ = ( - db.UniqueConstraint( - "organization_id", "name", name="uix_org_dept_name" - ), + db.UniqueConstraint("organization_id", "name", name="uix_org_dept_name"), ) def __repr__(self): @@ -55,47 +59,46 @@ class Department(BaseModel): """Convert department to dictionary.""" exclude = exclude or [] data = super().to_dict(exclude=exclude) - - # Add member count data["member_count"] = len([m for m in self.memberships if m.deleted_at is None]) - - # Add principal count data["principal_count"] = len([p for p in self.principal_links if p.deleted_at is None]) - return data - def get_members(self, active_only=True): + def get_members(self, active_only: bool = True): """Get all members of this department. - + Args: active_only: If True, exclude soft-deleted members - + Returns: List of DepartmentMembership objects """ if active_only: return [m for m in self.memberships if m.deleted_at is None] - return self.memberships + return list(self.memberships) - def get_principals(self, active_only=True): + def get_principals(self, active_only: bool = True): """Get all principals assigned to this department. - + Args: active_only: If True, exclude soft-deleted principals - + Returns: List of Principal objects via DepartmentPrincipal """ if active_only: - return [p.principal for p in self.principal_links if p.deleted_at is None and p.principal.deleted_at is None] + return [ + p.principal + for p in self.principal_links + if p.deleted_at is None and p.principal.deleted_at is None + ] return [p.principal for p in self.principal_links] - def is_member(self, user_id): + def is_member(self, user_id: str) -> bool: """Check if a user is a member of this department. - + Args: user_id: ID of the user to check - + Returns: True if user is an active member, False otherwise """ @@ -108,14 +111,14 @@ class Department(BaseModel): is not None ) - def get_member_count(self): + def get_member_count(self) -> int: """Get the count of active members in this department.""" return len(self.get_members(active_only=True)) class DepartmentMembership(BaseModel): """Department membership model representing user membership in a department. - + When a user is added to a department, they become eligible for SSH principals assigned to that department. """ @@ -139,24 +142,23 @@ class DepartmentMembership(BaseModel): user = db.relationship("User", back_populates="department_memberships") department = db.relationship("Department", back_populates="memberships") - # Unique constraint: user can only be member of a department once __table_args__ = ( - db.UniqueConstraint( - "user_id", "department_id", name="uix_user_dept" - ), + db.UniqueConstraint("user_id", "department_id", name="uix_user_dept"), ) def __repr__(self): """String representation of DepartmentMembership.""" - return f"" + return ( + f"" + ) class DepartmentPrincipal(BaseModel): """Department principal assignment model. - + Represents the assignment of principals to departments. All members of a department get access to its assigned principals (transitively). - + Example: - Department: "Engineering" - Principal: "eng-prod-servers" @@ -182,13 +184,13 @@ class DepartmentPrincipal(BaseModel): department = db.relationship("Department", back_populates="principal_links") principal = db.relationship("Principal", back_populates="department_links") - # Unique constraint: principal can only be assigned to a department once __table_args__ = ( - db.UniqueConstraint( - "department_id", "principal_id", name="uix_dept_principal" - ), + db.UniqueConstraint("department_id", "principal_id", name="uix_dept_principal"), ) def __repr__(self): """String representation of DepartmentPrincipal.""" - return f"" + return ( + f"" + ) diff --git a/gatehouse_app/models/organization/department_cert_policy.py b/gatehouse_app/models/organization/department_cert_policy.py new file mode 100644 index 0000000..357329f --- /dev/null +++ b/gatehouse_app/models/organization/department_cert_policy.py @@ -0,0 +1,76 @@ +"""DepartmentCertPolicy — per-department SSH certificate issuance rules.""" +from datetime import datetime, timezone +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel + + +# Standard SSH certificate extensions +STANDARD_EXTENSIONS = [ + "permit-X11-forwarding", + "permit-agent-forwarding", + "permit-pty", + "permit-port-forwarding", + "permit-user-rc", +] + + +class DepartmentCertPolicy(BaseModel): + """SSH certificate policy for a department. + + Controls: + - Whether members may choose their own expiry date (up to ``max_expiry_hours``) + - Default expiry hours when the user doesn't (or can't) pick + - Maximum expiry hours (hard ceiling, even for admins signing on behalf) + - Which SSH certificate extensions are granted to members of this department + - Any custom extensions the admin wants to add beyond the standard five + + Inherits ``id``, ``created_at``, ``updated_at``, and ``deleted_at`` from + :class:`BaseModel` so soft-delete and the standard timestamp behaviour are + consistent with every other model in the application. + """ + + __tablename__ = "department_cert_policies" + + department_id = db.Column( + db.String(36), + db.ForeignKey("departments.id"), + nullable=False, + unique=True, + index=True, + ) + + # Expiry control + allow_user_expiry = db.Column(db.Boolean, nullable=False, default=False) + default_expiry_hours = db.Column(db.Integer, nullable=False, default=1) + max_expiry_hours = db.Column(db.Integer, nullable=False, default=24) + + # Extensions — list of extension name strings + allowed_extensions = db.Column( + db.JSON, + nullable=False, + default=lambda: list(STANDARD_EXTENSIONS), + ) + # Admin-defined extras beyond the standard five + custom_extensions = db.Column(db.JSON, nullable=False, default=list) + + # Relationship back to department + department = db.relationship("Department", back_populates="cert_policy", uselist=False) + + def __repr__(self): + return ( + f"" + ) + + def all_extensions(self) -> list: + """Return the full list of enabled extensions (allowed + custom).""" + return list((self.allowed_extensions or []) + (self.custom_extensions or [])) + + def to_dict(self, exclude=None): + """Convert to dictionary.""" + exclude = exclude or [] + data = super().to_dict(exclude=exclude) + # Augment with computed / convenience fields not in the base columns + data["all_extensions"] = self.all_extensions() + data["standard_extensions"] = STANDARD_EXTENSIONS + return data diff --git a/gatehouse_app/models/organization/org_invite_token.py b/gatehouse_app/models/organization/org_invite_token.py new file mode 100644 index 0000000..2830b25 --- /dev/null +++ b/gatehouse_app/models/organization/org_invite_token.py @@ -0,0 +1,77 @@ +"""Organization invite token model.""" +import secrets +from datetime import datetime, timezone, timedelta + +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel + + +class OrgInviteToken(BaseModel): + """Token-based invitation to join an organization.""" + + __tablename__ = "org_invite_tokens" + + organization_id = db.Column( + db.String(36), + db.ForeignKey("organizations.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + invited_by_id = db.Column( + db.String(36), + db.ForeignKey("users.id", ondelete="SET NULL"), + nullable=True, + ) + email = db.Column(db.String(255), nullable=False, index=True) + role = db.Column(db.String(64), nullable=False, default="member") + token = db.Column(db.String(128), unique=True, nullable=False, index=True) + expires_at = db.Column(db.DateTime, nullable=False) + accepted_at = db.Column(db.DateTime, nullable=True) + + organization = db.relationship( + "Organization", + backref=db.backref("invite_tokens", cascade="all, delete-orphan"), + ) + invited_by = db.relationship("User", foreign_keys=[invited_by_id]) + + @classmethod + def generate( + cls, + organization_id: str, + email: str, + role: str = "member", + invited_by_id: str = None, + ttl_days: int = 7, + ) -> "OrgInviteToken": + """Create a new invite token for an organization.""" + token_value = secrets.token_urlsafe(48) + instance = cls( + organization_id=organization_id, + email=email.lower(), + role=role, + invited_by_id=invited_by_id, + token=token_value, + expires_at=datetime.now(timezone.utc) + timedelta(days=ttl_days), + ) + db.session.add(instance) + db.session.commit() + return instance + + @property + def is_valid(self) -> bool: + """Return True if the token is unused and not expired.""" + if self.accepted_at is not None: + return False + now = datetime.now(timezone.utc) + expires = self.expires_at + if expires.tzinfo is None: + expires = expires.replace(tzinfo=timezone.utc) + return now < expires + + def accept(self) -> None: + """Mark the invite as accepted.""" + self.accepted_at = datetime.now(timezone.utc) + db.session.commit() + + def __repr__(self) -> str: + return f"" diff --git a/gatehouse_app/models/organization.py b/gatehouse_app/models/organization/organization.py similarity index 94% rename from gatehouse_app/models/organization.py rename to gatehouse_app/models/organization/organization.py index a6fa756..9be5c65 100644 --- a/gatehouse_app/models/organization.py +++ b/gatehouse_app/models/organization/organization.py @@ -61,9 +61,9 @@ class Organization(BaseModel): return member.user return None - def is_member(self, user_id): + def is_member(self, user_id: str) -> bool: """Check if a user is a member of the organization.""" - from gatehouse_app.models.organization_member import OrganizationMember + from gatehouse_app.models.organization.organization_member import OrganizationMember return ( OrganizationMember.query.filter_by( diff --git a/gatehouse_app/models/organization_member.py b/gatehouse_app/models/organization/organization_member.py similarity index 78% rename from gatehouse_app/models/organization_member.py rename to gatehouse_app/models/organization/organization_member.py index 3247082..3b02242 100644 --- a/gatehouse_app/models/organization_member.py +++ b/gatehouse_app/models/organization/organization_member.py @@ -1,4 +1,4 @@ -"""Organization member model.""" +"""Organization member model.""" from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel from gatehouse_app.utils.constants import OrganizationRole @@ -21,31 +21,35 @@ class OrganizationMember(BaseModel): joined_at = db.Column(db.DateTime, nullable=True) # Relationships - user = db.relationship("User", foreign_keys=[user_id], back_populates="organization_memberships") + user = db.relationship( + "User", foreign_keys=[user_id], back_populates="organization_memberships" + ) organization = db.relationship("Organization", back_populates="members") invited_by = db.relationship("User", foreign_keys=[invited_by_id]) - # Unique constraint to prevent duplicate memberships __table_args__ = ( db.UniqueConstraint("user_id", "organization_id", name="uix_user_org"), ) def __repr__(self): """String representation of OrganizationMember.""" - return f"" + return ( + f"" + ) - def is_owner(self): + def is_owner(self) -> bool: """Check if member is an owner.""" return self.role == OrganizationRole.OWNER - def is_admin(self): + def is_admin(self) -> bool: """Check if member is an admin or owner.""" return self.role in [OrganizationRole.OWNER, OrganizationRole.ADMIN] - def can_manage_members(self): + def can_manage_members(self) -> bool: """Check if member can manage other members.""" return self.is_admin() - def can_delete_organization(self): + def can_delete_organization(self) -> bool: """Check if member can delete the organization.""" return self.is_owner() diff --git a/gatehouse_app/models/principal.py b/gatehouse_app/models/organization/principal.py similarity index 69% rename from gatehouse_app/models/principal.py rename to gatehouse_app/models/organization/principal.py index 7783ec3..8f87fe3 100644 --- a/gatehouse_app/models/principal.py +++ b/gatehouse_app/models/organization/principal.py @@ -1,14 +1,16 @@ +"""Principal and PrincipalMembership models.""" from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel class Principal(BaseModel): """Principal model representing an SSH principal (access level/role). - + In SSH CA terminology, a principal is a string like "eng-prod-servers" or "devops-admins" that represents a set of machines or access level. Users - can be granted access to principals, either directly or via department membership. - + can be granted access to principals, either directly or via department + membership. + Example: - Principal: "eng-prod-servers" - Users with this principal can SSH to prod servers @@ -39,11 +41,8 @@ class Principal(BaseModel): cascade="all, delete-orphan", ) - # Unique constraint: principal name per organization __table_args__ = ( - db.UniqueConstraint( - "organization_id", "name", name="uix_org_principal_name" - ), + db.UniqueConstraint("organization_id", "name", name="uix_org_principal_name"), ) def __repr__(self): @@ -54,79 +53,80 @@ class Principal(BaseModel): """Convert principal to dictionary.""" exclude = exclude or [] data = super().to_dict(exclude=exclude) - - # Add member count - data["direct_member_count"] = len([m for m in self.memberships if m.deleted_at is None]) - - # Add department count - data["department_count"] = len([d for d in self.department_links if d.deleted_at is None]) - + data["direct_member_count"] = len( + [m for m in self.memberships if m.deleted_at is None] + ) + data["department_count"] = len( + [d for d in self.department_links if d.deleted_at is None] + ) return data - def get_members(self, active_only=True): + def get_members(self, active_only: bool = True): """Get all users who are directly assigned to this principal. - + Does NOT include users who get access via department membership. - + Args: active_only: If True, exclude soft-deleted members - + Returns: List of PrincipalMembership objects """ if active_only: return [m for m in self.memberships if m.deleted_at is None] - return self.memberships + return list(self.memberships) - def get_all_members(self, active_only=True): + def get_all_members(self, active_only: bool = True): """Get all users who have access to this principal. - + Includes both direct members and users via department membership. - + Args: active_only: If True, exclude soft-deleted members - + Returns: Set of User objects with access to this principal """ - from gatehouse_app.models.user import User - - all_users = set() - - # Add direct members + all_users: set = set() + + # Direct members for membership in self.get_members(active_only=active_only): - if membership.user.deleted_at is None or not active_only: + if not active_only or membership.user.deleted_at is None: all_users.add(membership.user) - - # Add members via department assignment + + # Members via department assignment for dept_link in self.department_links: if dept_link.deleted_at is None or not active_only: for dept_member in dept_link.department.get_members(active_only=active_only): - if dept_member.user.deleted_at is None or not active_only: + if not active_only or dept_member.user.deleted_at is None: all_users.add(dept_member.user) - + return all_users - def get_departments(self, active_only=True): + def get_departments(self, active_only: bool = True): """Get all departments this principal is assigned to. - + Args: active_only: If True, exclude soft-deleted departments - + Returns: List of Department objects """ if active_only: - return [d.department for d in self.department_links if d.deleted_at is None and d.department.deleted_at is None] + return [ + d.department + for d in self.department_links + if d.deleted_at is None and d.department.deleted_at is None + ] return [d.department for d in self.department_links] - def is_member(self, user_id, include_via_department=True): + def is_member(self, user_id: str, include_via_department: bool = True) -> bool: """Check if a user has access to this principal. - + Args: user_id: ID of the user to check include_via_department: If True, check department memberships too - + Returns: True if user has access to this principal """ @@ -139,54 +139,49 @@ class Principal(BaseModel): ).first() is not None ) - + if has_direct: return True - - # Check department membership if requested + if not include_via_department: return False - - # Get all departments this principal is assigned to - depts = self.get_departments(active_only=True) - dept_ids = [d.id for d in depts] - + + # Check department membership + dept_ids = [d.id for d in self.get_departments(active_only=True)] if not dept_ids: return False - - # Check if user is a member of any of these departments - from gatehouse_app.models.department import DepartmentMembership - + + from gatehouse_app.models.organization.department import DepartmentMembership + return ( DepartmentMembership.query.filter( DepartmentMembership.user_id == user_id, DepartmentMembership.department_id.in_(dept_ids), - DepartmentMembership.deleted_at == None, + DepartmentMembership.deleted_at.is_(None), ).first() is not None ) - def get_member_count(self, include_via_department=True): + def get_member_count(self, include_via_department: bool = True) -> int: """Get the count of active members with access to this principal. - + Args: include_via_department: If True, include members via department - + Returns: Count of members """ if not include_via_department: return len(self.get_members(active_only=True)) - return len(self.get_all_members(active_only=True)) class PrincipalMembership(BaseModel): """Principal membership model representing direct user assignment to a principal. - - When a user is assigned directly to a principal, they get access to that principal - for SSH authentication. This is in addition to any principals they get via - department membership. + + When a user is assigned directly to a principal, they get access to that + principal for SSH authentication. This is in addition to any principals + they get via department membership. """ __tablename__ = "principal_memberships" @@ -208,13 +203,13 @@ class PrincipalMembership(BaseModel): user = db.relationship("User", back_populates="principal_memberships") principal = db.relationship("Principal", back_populates="memberships") - # Unique constraint: user can only be member of a principal once __table_args__ = ( - db.UniqueConstraint( - "user_id", "principal_id", name="uix_user_principal" - ), + db.UniqueConstraint("user_id", "principal_id", name="uix_user_principal"), ) def __repr__(self): """String representation of PrincipalMembership.""" - return f"" + return ( + f"" + ) diff --git a/gatehouse_app/models/security/__init__.py b/gatehouse_app/models/security/__init__.py new file mode 100644 index 0000000..d24aef5 --- /dev/null +++ b/gatehouse_app/models/security/__init__.py @@ -0,0 +1,12 @@ +"""Security subpackage — organization and user security policies, MFA compliance.""" +from gatehouse_app.models.security.organization_security_policy import ( + OrganizationSecurityPolicy, +) +from gatehouse_app.models.security.user_security_policy import UserSecurityPolicy +from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance + +__all__ = [ + "OrganizationSecurityPolicy", + "UserSecurityPolicy", + "MfaPolicyCompliance", +] diff --git a/gatehouse_app/models/mfa_policy_compliance.py b/gatehouse_app/models/security/mfa_policy_compliance.py similarity index 76% rename from gatehouse_app/models/mfa_policy_compliance.py rename to gatehouse_app/models/security/mfa_policy_compliance.py index 6ecd217..5c2de13 100644 --- a/gatehouse_app/models/mfa_policy_compliance.py +++ b/gatehouse_app/models/security/mfa_policy_compliance.py @@ -1,4 +1,4 @@ -"""MfaPolicyCompliance model.""" +"""MfaPolicyCompliance model — per-user per-organization MFA compliance tracking.""" from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel from gatehouse_app.utils.constants import MfaComplianceStatus @@ -7,7 +7,8 @@ from gatehouse_app.utils.constants import MfaComplianceStatus class MfaPolicyCompliance(BaseModel): """MFA policy compliance tracking per user per organization. - Tracks each user's MFA compliance state separately for each organization membership. + Tracks each user's MFA compliance state separately for each organization + membership. One row per (user, org) pair. """ __tablename__ = "mfa_policy_compliance" @@ -25,13 +26,13 @@ class MfaPolicyCompliance(BaseModel): default=MfaComplianceStatus.NOT_APPLICABLE, ) - # Snapshot of org policy at the time this record became active + # Snapshot of org policy version when this record became active policy_version = db.Column(db.Integer, nullable=False) # When policy started applying to this user applied_at = db.Column(db.DateTime, nullable=True) - # Final deadline for this user to comply (per user, not global) + # Final deadline for this user to comply deadline_at = db.Column(db.DateTime, nullable=True) # When they became compliant under this policy_version @@ -45,9 +46,7 @@ class MfaPolicyCompliance(BaseModel): notification_count = db.Column(db.Integer, nullable=False, default=0) __table_args__ = ( - db.UniqueConstraint( - "user_id", "organization_id", name="uix_user_org_compliance" - ), + db.UniqueConstraint("user_id", "organization_id", name="uix_user_org_compliance"), ) # Relationships @@ -58,9 +57,11 @@ class MfaPolicyCompliance(BaseModel): def __repr__(self): """String representation of MfaPolicyCompliance.""" - return f"" + return ( + f"" + ) def to_dict(self, exclude=None): """Convert to dictionary.""" - exclude = exclude or [] - return super().to_dict(exclude=exclude) \ No newline at end of file + return super().to_dict(exclude=exclude or []) diff --git a/gatehouse_app/models/organization_security_policy.py b/gatehouse_app/models/security/organization_security_policy.py similarity index 83% rename from gatehouse_app/models/organization_security_policy.py rename to gatehouse_app/models/security/organization_security_policy.py index 991b72d..593781a 100644 --- a/gatehouse_app/models/organization_security_policy.py +++ b/gatehouse_app/models/security/organization_security_policy.py @@ -39,15 +39,19 @@ class OrganizationSecurityPolicy(BaseModel): # Relationships organization = db.relationship( - "Organization", back_populates="security_policy", foreign_keys=[organization_id] + "Organization", + back_populates="security_policy", + foreign_keys=[organization_id], ) updated_by_user = db.relationship("User", foreign_keys=[updated_by_user_id]) def __repr__(self): """String representation of OrganizationSecurityPolicy.""" - return f"" + return ( + f"" + ) def to_dict(self, exclude=None): """Convert to dictionary.""" - exclude = exclude or [] - return super().to_dict(exclude=exclude) \ No newline at end of file + return super().to_dict(exclude=exclude or []) diff --git a/gatehouse_app/models/user_security_policy.py b/gatehouse_app/models/security/user_security_policy.py similarity index 67% rename from gatehouse_app/models/user_security_policy.py rename to gatehouse_app/models/security/user_security_policy.py index d765575..a96ef84 100644 --- a/gatehouse_app/models/user_security_policy.py +++ b/gatehouse_app/models/security/user_security_policy.py @@ -1,4 +1,4 @@ -"""UserSecurityPolicy model.""" +"""UserSecurityPolicy model — per-user MFA overrides.""" from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel from gatehouse_app.utils.constants import MfaRequirementOverride @@ -7,7 +7,7 @@ from gatehouse_app.utils.constants import MfaRequirementOverride class UserSecurityPolicy(BaseModel): """User security policy model for per-user MFA overrides. - Stores per user overrides of organization level MFA requirements. + Stores per-user overrides of organization-level MFA requirements. """ __tablename__ = "user_security_policies" @@ -25,29 +25,27 @@ class UserSecurityPolicy(BaseModel): default=MfaRequirementOverride.INHERIT, ) - # If override is REQUIRED and you want to force a specific factor set + # If override is REQUIRED, optionally force a specific factor set force_totp = db.Column(db.Boolean, nullable=False, default=False) force_webauthn = db.Column(db.Boolean, nullable=False, default=False) __table_args__ = ( - db.UniqueConstraint( - "user_id", "organization_id", name="uix_user_org_policy" - ), + db.UniqueConstraint("user_id", "organization_id", name="uix_user_org_policy"), ) # Relationships user = db.relationship( "User", back_populates="security_policies", foreign_keys=[user_id] ) - organization = db.relationship( - "Organization", foreign_keys=[organization_id] - ) + organization = db.relationship("Organization", foreign_keys=[organization_id]) def __repr__(self): """String representation of UserSecurityPolicy.""" - return f"" + return ( + f"" + ) def to_dict(self, exclude=None): """Convert to dictionary.""" - exclude = exclude or [] - return super().to_dict(exclude=exclude) \ No newline at end of file + return super().to_dict(exclude=exclude or []) diff --git a/gatehouse_app/models/ssh_ca/__init__.py b/gatehouse_app/models/ssh_ca/__init__.py new file mode 100644 index 0000000..d6932b3 --- /dev/null +++ b/gatehouse_app/models/ssh_ca/__init__.py @@ -0,0 +1,17 @@ +"""SSH/CA subpackage — certificate authorities, SSH keys, and certificates.""" +from gatehouse_app.models.ssh_ca.ca import CA, KeyType, CertType, CaType, CAPermission +from gatehouse_app.models.ssh_ca.ssh_key import SSHKey +from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus +from gatehouse_app.models.ssh_ca.certificate_audit_log import CertificateAuditLog + +__all__ = [ + "CA", + "KeyType", + "CertType", + "CaType", + "CAPermission", + "SSHKey", + "SSHCertificate", + "CertificateStatus", + "CertificateAuditLog", +] diff --git a/gatehouse_app/models/ca.py b/gatehouse_app/models/ssh_ca/ca.py similarity index 72% rename from gatehouse_app/models/ca.py rename to gatehouse_app/models/ssh_ca/ca.py index bfcef44..337f7c8 100644 --- a/gatehouse_app/models/ca.py +++ b/gatehouse_app/models/ssh_ca/ca.py @@ -1,13 +1,13 @@ """Certificate Authority (CA) model.""" from enum import Enum -from datetime import datetime +from datetime import datetime, timezone from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel class KeyType(str, Enum): """SSH CA key types.""" - + ED25519 = "ed25519" RSA = "rsa" ECDSA = "ecdsa" @@ -15,7 +15,7 @@ class KeyType(str, Enum): class CertType(str, Enum): """SSH certificate types.""" - + USER = "user" HOST = "host" @@ -29,7 +29,7 @@ class CaType(str, Enum): class CA(BaseModel): """Certificate Authority (CA) model for SSH certificate signing. - + Each organization can have multiple CAs for different purposes (e.g., production vs. staging). Private keys are encrypted at rest and should be protected with KMS. @@ -43,12 +43,12 @@ class CA(BaseModel): nullable=True, # NULL for the global system-config CA index=True, ) - - # CA name and description + + # CA identity name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=True) - # CA signing type: 'user' signs user certificates, 'host' signs host certificates + # CA signing type: 'user' signs user certificates, 'host' signs host certs ca_type = db.Column( db.Enum(CaType, values_callable=lambda x: [e.value for e in x]), default=CaType.USER, @@ -61,43 +61,33 @@ class CA(BaseModel): default=KeyType.ED25519, nullable=False, ) - - # Private key (encrypted at rest by database/KMS) - # Format: PEM-encoded private key + + # Private key — PEM-encoded, encrypted at rest by database/KMS private_key = db.Column(db.Text, nullable=False) - - # Public key (PEM format) + + # Public key — PEM format public_key = db.Column(db.Text, nullable=False) - + # SHA256 fingerprint of the public key fingerprint = db.Column(db.String(255), nullable=False, unique=True) - + # CRL (Certificate Revocation List) configuration crl_enabled = db.Column(db.Boolean, default=True, nullable=False) crl_endpoint = db.Column(db.String(512), nullable=True) - - # Default certificate validity in hours - # Can be overridden per certificate request - default_cert_validity_hours = db.Column( - db.Integer, - default=1, - nullable=False, - ) - + + # Default certificate validity in hours (overridable per request) + default_cert_validity_hours = db.Column(db.Integer, default=1, nullable=False) + # Maximum validity duration allowed - max_cert_validity_hours = db.Column( - db.Integer, - default=24, - nullable=False, - ) - + max_cert_validity_hours = db.Column(db.Integer, default=24, nullable=False) + # CA status is_active = db.Column(db.Boolean, default=True, nullable=False, index=True) - + # Key rotation tracking rotated_at = db.Column(db.DateTime, nullable=True) rotation_reason = db.Column(db.String(255), nullable=True) - + # Relationships organization = db.relationship("Organization", back_populates="cas") certificates = db.relationship( @@ -112,54 +102,53 @@ class CA(BaseModel): ) __table_args__ = ( - db.UniqueConstraint( - "organization_id", "name", name="uix_org_ca_name" - ), + db.UniqueConstraint("organization_id", "name", name="uix_org_ca_name"), db.Index("idx_ca_org_active", "organization_id", "is_active"), ) def __repr__(self): """String representation of CA.""" - return f"" + return ( + f"" + ) def to_dict(self, exclude=None): - """Convert CA to dictionary.""" + """Convert CA to dictionary, never exposing the private key.""" exclude = exclude or [] - # Never expose private key in API responses - exclude.extend(["private_key"]) + if "private_key" not in exclude: + exclude.append("private_key") data = super().to_dict(exclude=exclude) - + # Add computed fields data["total_certs"] = len([c for c in self.certificates if c.deleted_at is None]) - data["active_certs"] = len([ - c for c in self.certificates - if c.deleted_at is None and not c.revoked - ]) - data["revoked_certs"] = len([ - c for c in self.certificates - if c.deleted_at is None and c.revoked - ]) - + data["active_certs"] = len( + [c for c in self.certificates if c.deleted_at is None and not c.revoked] + ) + data["revoked_certs"] = len( + [c for c in self.certificates if c.deleted_at is None and c.revoked] + ) return data - def get_active_certificates(self): - """Get all active (non-revoked) certificates issued by this CA. - - Returns: - List of non-revoked SSHCertificate objects - """ + def get_active_certificates(self) -> list: + """Get all active (non-revoked) certificates issued by this CA.""" return [ - c for c in self.certificates - if c.deleted_at is None and not c.revoked + c for c in self.certificates if c.deleted_at is None and not c.revoked ] - def rotate_key(self, new_private_key, new_public_key, new_fingerprint, reason=None): + def rotate_key( + self, + new_private_key: str, + new_public_key: str, + new_fingerprint: str, + reason: str = None, + ) -> None: """Rotate the CA's key pair. - + This should only be done in carefully controlled circumstances. - All existing certificates remain valid but no new certs can be - signed with the old key. - + All existing certificates remain valid but no new certificates can be + signed with the old key after rotation. + Args: new_private_key: New PEM-encoded private key new_public_key: New PEM-encoded public key @@ -169,7 +158,7 @@ class CA(BaseModel): self.private_key = new_private_key self.public_key = new_public_key self.fingerprint = new_fingerprint - self.rotated_at = datetime.utcnow() + self.rotated_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow() self.rotation_reason = reason self.save() @@ -178,7 +167,7 @@ class CAPermission(BaseModel): """Per-user CA permission model. Controls which users are allowed to sign certificates against a specific CA. - When a CA has any permission rows the signing endpoint enforces the list; + When a CA has any permission rows, the signing endpoint enforces the list; CAs with no rows are open to all org members (backwards-compatible default). Permission values: @@ -212,7 +201,10 @@ class CAPermission(BaseModel): ) def __repr__(self): - return f"" + return ( + f"" + ) def to_dict(self, exclude=None): data = super().to_dict(exclude=exclude or []) diff --git a/gatehouse_app/models/certificate_audit_log.py b/gatehouse_app/models/ssh_ca/certificate_audit_log.py similarity index 78% rename from gatehouse_app/models/certificate_audit_log.py rename to gatehouse_app/models/ssh_ca/certificate_audit_log.py index 0d3274b..02f24d3 100644 --- a/gatehouse_app/models/certificate_audit_log.py +++ b/gatehouse_app/models/ssh_ca/certificate_audit_log.py @@ -1,14 +1,14 @@ -"""Certificate audit log model.""" +"""Certificate audit log model — tracks SSH certificate lifecycle events.""" from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel class CertificateAuditLog(BaseModel): """Audit log for SSH certificate lifecycle events. - - Tracks all operations on SSH certificates: signing, revocation, - validation, etc. This is separate from the general AuditLog to - provide detailed certificate operation tracking. + + Tracks all operations on SSH certificates: signing, revocation, validation, + etc. Kept separate from the general AuditLog to provide detailed certificate + operation tracking without polluting the main audit stream. """ __tablename__ = "certificate_audit_logs" @@ -20,33 +20,33 @@ class CertificateAuditLog(BaseModel): nullable=False, index=True, ) - - # The user who performed the action (can be null for system actions) + + # The user who performed the action (null for system actions) user_id = db.Column( db.String(36), db.ForeignKey("users.id"), nullable=True, index=True, ) - + # Action type (e.g., "signed", "revoked", "validated", "requested") action = db.Column(db.String(50), nullable=False, index=True) - + # Request details ip_address = db.Column(db.String(45), nullable=True) user_agent = db.Column(db.String(512), nullable=True) request_id = db.Column(db.String(36), nullable=True) - + # Detailed message message = db.Column(db.Text, nullable=True) - + # Additional context extra_data = db.Column(db.JSON, nullable=True) - - # Success/failure + + # Outcome success = db.Column(db.Boolean, default=True, nullable=False) error_message = db.Column(db.Text, nullable=True) - + # Relationships certificate = db.relationship("SSHCertificate", back_populates="audit_logs") user = db.relationship("User") @@ -58,18 +58,26 @@ class CertificateAuditLog(BaseModel): def __repr__(self): """String representation of CertificateAuditLog.""" - return f"" + return ( + f"" + ) @classmethod - def log(cls, certificate_id, action, user_id=None, **kwargs): + def log( + cls, + certificate_id: str, + action: str, + user_id: str = None, + **kwargs, + ) -> "CertificateAuditLog": """Create a certificate audit log entry. - + Args: certificate_id: ID of the certificate action: Action type (e.g., "signed", "revoked") user_id: ID of the user performing the action (optional) **kwargs: Additional fields (ip_address, user_agent, message, etc.) - + Returns: CertificateAuditLog instance """ @@ -77,7 +85,7 @@ class CertificateAuditLog(BaseModel): certificate_id=certificate_id, action=action, user_id=user_id, - **kwargs + **kwargs, ) log_entry.save() return log_entry diff --git a/gatehouse_app/models/ssh_certificate.py b/gatehouse_app/models/ssh_ca/ssh_certificate.py similarity index 70% rename from gatehouse_app/models/ssh_certificate.py rename to gatehouse_app/models/ssh_ca/ssh_certificate.py index 58226e9..f226a69 100644 --- a/gatehouse_app/models/ssh_certificate.py +++ b/gatehouse_app/models/ssh_ca/ssh_certificate.py @@ -1,26 +1,26 @@ -"""SSH Certificate model.""" +"""SSH Certificate model — signed SSH user/host certificates.""" from enum import Enum from datetime import datetime, timezone from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel -from gatehouse_app.models.ca import CertType +from gatehouse_app.models.ssh_ca.ca import CertType class CertificateStatus(str, Enum): """SSH certificate lifecycle status.""" - - REQUESTED = "requested" # Waiting for signing - ISSUED = "issued" # Signed and valid - REVOKED = "revoked" # Manually revoked - EXPIRED = "expired" # Validity period ended - SUPERSEDED = "superseded" # Replaced by newer cert + + REQUESTED = "requested" # Waiting for signing + ISSUED = "issued" # Signed and valid + REVOKED = "revoked" # Manually revoked + EXPIRED = "expired" # Validity period ended + SUPERSEDED = "superseded" # Replaced by newer certificate class SSHCertificate(BaseModel): """SSH Certificate model representing a signed SSH user/host certificate. - + Certificates are issued by a CA and associated with an SSH public key. - They include principals (access levels), validity periods, and other + They include principals (access levels), validity periods, and standard OpenSSH certificate metadata. """ @@ -45,10 +45,10 @@ class SSHCertificate(BaseModel): nullable=False, index=True, ) - - # Certificate content (full signed certificate in OpenSSH format) + + # Certificate content — full signed certificate in OpenSSH format certificate = db.Column(db.Text, nullable=False) - + # Certificate metadata serial = db.Column(db.String(255), nullable=False, unique=True, index=True) key_id = db.Column(db.String(255), nullable=False) # Usually user email @@ -57,19 +57,19 @@ class SSHCertificate(BaseModel): default=CertType.USER, nullable=False, ) - - # Principals (JSON list) - e.g., ["prod-servers", "dev-servers"] + + # Principals — JSON list, e.g., ["prod-servers", "dev-servers"] principals = db.Column(db.JSON, nullable=False, default=list) - + # Validity period valid_after = db.Column(db.DateTime, nullable=False) valid_before = db.Column(db.DateTime, nullable=False) - + # Revocation status revoked = db.Column(db.Boolean, default=False, nullable=False, index=True) revoked_at = db.Column(db.DateTime, nullable=True) revoke_reason = db.Column(db.String(255), nullable=True) - + # Status tracking status = db.Column( db.Enum(CertificateStatus, values_callable=lambda x: [e.value for e in x]), @@ -77,19 +77,17 @@ class SSHCertificate(BaseModel): nullable=False, index=True, ) - + # Request metadata request_ip = db.Column(db.String(45), nullable=True) request_user_agent = db.Column(db.String(512), nullable=True) - - # Critical options (JSON) - OpenSSH critical options - # See: https://man.openbsd.org/ssh-cert + + # Critical options — OpenSSH critical options (JSON) critical_options = db.Column(db.JSON, nullable=True, default=dict) - - # Extensions (JSON) - OpenSSH extensions - # Common ones: permit-X11-forwarding, permit-agent-forwarding, permit-pty, etc. + + # Extensions — OpenSSH extensions (JSON) extensions = db.Column(db.JSON, nullable=True, default=dict) - + # Relationships ca = db.relationship("CA", back_populates="certificates") user = db.relationship("User", back_populates="ssh_certificates") @@ -115,67 +113,64 @@ class SSHCertificate(BaseModel): return f"" def to_dict(self, exclude=None): - """Convert certificate to dictionary.""" + """Convert certificate to dictionary. + + The raw ``certificate`` blob is excluded by default (it is large and + callers can request it explicitly by removing it from the exclude list). + """ exclude = exclude or [] - # Optionally exclude the certificate content (it's large) if "certificate" not in exclude: exclude.append("certificate") data = super().to_dict(exclude=exclude) - - # Add computed fields data["is_valid"] = self.is_valid() data["days_until_expiry"] = self.days_until_expiry() - return data - def is_valid(self): + def _aware(self, dt: datetime) -> datetime: + """Return a timezone-aware UTC datetime.""" + return dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt + + def is_valid(self) -> bool: """Check if certificate is currently valid. - + Returns: True if certificate is issued, not revoked, and within validity period """ if self.revoked or self.status == CertificateStatus.REVOKED: return False - now = datetime.now(timezone.utc) - valid_after = self.valid_after.replace(tzinfo=timezone.utc) if self.valid_after.tzinfo is None else self.valid_after - valid_before = self.valid_before.replace(tzinfo=timezone.utc) if self.valid_before.tzinfo is None else self.valid_before - return valid_after <= now <= valid_before + return self._aware(self.valid_after) <= now <= self._aware(self.valid_before) - def is_expired(self): + def is_expired(self) -> bool: """Check if certificate has expired. - + Returns: True if current time is past valid_before """ - now = datetime.now(timezone.utc) - valid_before = self.valid_before.replace(tzinfo=timezone.utc) if self.valid_before.tzinfo is None else self.valid_before - return now > valid_before + return datetime.now(timezone.utc) > self._aware(self.valid_before) - def days_until_expiry(self): + def days_until_expiry(self) -> int: """Get number of days until certificate expires. - + Returns: Number of days remaining (negative if already expired) """ - now = datetime.now(timezone.utc) - valid_before = self.valid_before.replace(tzinfo=timezone.utc) if self.valid_before.tzinfo is None else self.valid_before - delta = valid_before - now + delta = self._aware(self.valid_before) - datetime.now(timezone.utc) return delta.days + (1 if delta.seconds > 0 else 0) - def revoke(self, reason=None): + def revoke(self, reason: str = None) -> None: """Revoke this certificate. - + Args: reason: Optional reason for revocation """ self.revoked = True - self.revoked_at = datetime.utcnow() + self.revoked_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow() self.revoke_reason = reason self.status = CertificateStatus.REVOKED self.save() - def mark_expired(self): + def mark_expired(self) -> None: """Mark certificate as expired when validity period ends.""" self.status = CertificateStatus.EXPIRED self.save() diff --git a/gatehouse_app/models/ssh_key.py b/gatehouse_app/models/ssh_ca/ssh_key.py similarity index 55% rename from gatehouse_app/models/ssh_key.py rename to gatehouse_app/models/ssh_ca/ssh_key.py index 0d6fbca..218fd99 100644 --- a/gatehouse_app/models/ssh_key.py +++ b/gatehouse_app/models/ssh_ca/ssh_key.py @@ -1,14 +1,14 @@ -"""SSH Key model.""" -from datetime import datetime +"""SSH Key model — user SSH public keys registered for certificate signing.""" +from datetime import datetime, timezone from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel class SSHKey(BaseModel): """SSH Key model representing a user's SSH public key. - - This model stores SSH public keys that users register for certificate signing. - Users must verify ownership of the key before it can be used for signing certificates. + + Users register SSH public keys for certificate signing. Keys must be + verified (owner proved possession) before they can be used. """ __tablename__ = "ssh_keys" @@ -19,33 +19,29 @@ class SSHKey(BaseModel): nullable=False, index=True, ) - - # SSH key payload in OpenSSH format (e.g., "ssh-rsa AAAAB3Nz...") + + # SSH key payload in OpenSSH format (e.g., "ssh-ed25519 AAAAB3Nz...") payload = db.Column(db.Text, nullable=False, unique=True) - - # SHA256 fingerprint for quick comparison + + # SHA256 fingerprint for quick comparison and deduplication fingerprint = db.Column(db.String(255), nullable=False, unique=True, index=True) - - # Optional description for the key (e.g., "My laptop key") + + # Optional human-readable description (e.g., "My laptop key") description = db.Column(db.String(255), nullable=True) - + # Verification status verified = db.Column(db.Boolean, default=False, nullable=False, index=True) verified_at = db.Column(db.DateTime, nullable=True) - - # Verification challenge + + # Verification challenge — shown to user once, cleared after verification verify_text = db.Column(db.String(255), nullable=True) verify_text_created_at = db.Column(db.DateTime, nullable=True) - - # Key type extracted from the key (ssh-rsa, ssh-ed25519, etc.) - key_type = db.Column(db.String(50), nullable=True) - - # Key bits/length - key_bits = db.Column(db.Integer, nullable=True) - - # Comment from the key (usually email or key name) + + # Key metadata extracted from the key + key_type = db.Column(db.String(50), nullable=True) # ssh-rsa, ssh-ed25519, etc. + key_bits = db.Column(db.Integer, nullable=True) # key length key_comment = db.Column(db.String(255), nullable=True) - + # Relationships user = db.relationship("User", back_populates="ssh_keys") certificates = db.relationship( @@ -64,33 +60,39 @@ class SSHKey(BaseModel): return f"" def to_dict(self, exclude=None): - """Convert SSH key to dictionary.""" + """Convert SSH key to dictionary. + + ``payload`` and ``verify_text`` are never exposed through the API. + """ exclude = exclude or [] - exclude.extend(["payload", "verify_text"]) # Never expose these in API + for field in ("payload", "verify_text"): + if field not in exclude: + exclude.append(field) data = super().to_dict(exclude=exclude) - - # Add computed fields data["cert_count"] = len([c for c in self.certificates if c.deleted_at is None]) - return data - def mark_verified(self): - """Mark this SSH key as verified.""" + def mark_verified(self) -> None: + """Mark this SSH key as verified and clear the challenge.""" self.verified = True - self.verified_at = datetime.utcnow() + self.verified_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow() + self.verify_text = None self.save() - def needs_verification_refresh(self, max_age_hours=24): + def needs_verification_refresh(self, max_age_hours: int = 24) -> bool: """Check if verification challenge needs to be refreshed. - + Args: max_age_hours: Maximum age of verification challenge in hours - + Returns: - True if verification challenge is stale + True if verification challenge is stale or missing """ if not self.verify_text_created_at: return True - - age = datetime.utcnow() - self.verify_text_created_at + age = datetime.now(timezone.utc) - self.verify_text_created_at.replace( + tzinfo=timezone.utc + ) if self.verify_text_created_at.tzinfo is None else ( + datetime.now(timezone.utc) - self.verify_text_created_at + ) return age.total_seconds() > (max_age_hours * 3600) diff --git a/gatehouse_app/models/user.py b/gatehouse_app/models/user.py index a5b3aff..eecf942 100644 --- a/gatehouse_app/models/user.py +++ b/gatehouse_app/models/user.py @@ -1,177 +1,4 @@ -"""User model.""" -from gatehouse_app.extensions import db -from gatehouse_app.models.base import BaseModel -from gatehouse_app.utils.constants import UserStatus +"""Backward-compatibility shim — import from gatehouse_app.models.user.user instead.""" +from gatehouse_app.models.user.user import User # noqa: F401 - -class User(BaseModel): - """User model representing a user account.""" - - __tablename__ = "users" - - email = db.Column(db.String(255), unique=True, nullable=False, index=True) - email_verified = db.Column(db.Boolean, default=False, nullable=False) - full_name = db.Column(db.String(255), nullable=True) - avatar_url = db.Column(db.String(512), nullable=True) - status = db.Column( - db.Enum(UserStatus), default=UserStatus.ACTIVE, nullable=False, index=True - ) - last_login_at = db.Column(db.DateTime, nullable=True) - last_login_ip = db.Column(db.String(45), nullable=True) - - # Relationships - authentication_methods = db.relationship( - "AuthenticationMethod", back_populates="user", cascade="all, delete-orphan" - ) - sessions = db.relationship("Session", back_populates="user", cascade="all, delete-orphan") - organization_memberships = db.relationship( - "OrganizationMember", - back_populates="user", - cascade="all, delete-orphan", - foreign_keys="OrganizationMember.user_id", - ) - audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan") - security_policies = db.relationship( - "UserSecurityPolicy", - back_populates="user", - cascade="all, delete-orphan", - foreign_keys="UserSecurityPolicy.user_id", - ) - mfa_compliance = db.relationship( - "MfaPolicyCompliance", - back_populates="user", - cascade="all, delete-orphan", - foreign_keys="MfaPolicyCompliance.user_id", - ) - department_memberships = db.relationship( - "DepartmentMembership", - back_populates="user", - cascade="all, delete-orphan", - foreign_keys="DepartmentMembership.user_id", - ) - principal_memberships = db.relationship( - "PrincipalMembership", - back_populates="user", - cascade="all, delete-orphan", - foreign_keys="PrincipalMembership.user_id", - ) - ssh_keys = db.relationship( - "SSHKey", - back_populates="user", - cascade="all, delete-orphan", - foreign_keys="SSHKey.user_id", - ) - ssh_certificates = db.relationship( - "SSHCertificate", - back_populates="user", - cascade="all, delete-orphan", - foreign_keys="SSHCertificate.user_id", - ) - - def __repr__(self): - """String representation of User.""" - return f"" - - def to_dict(self, exclude=None): - """Convert user to dictionary, excluding sensitive fields by default.""" - exclude = exclude or [] - # Always exclude password-related fields - default_exclude = [] - all_exclude = list(set(default_exclude + exclude)) - return super().to_dict(exclude=all_exclude) - - def has_password_auth(self): - """Check if user has password authentication enabled.""" - from gatehouse_app.models.authentication_method import AuthenticationMethod - from gatehouse_app.utils.constants import AuthMethodType - - return ( - AuthenticationMethod.query.filter_by( - user_id=self.id, method_type=AuthMethodType.PASSWORD, deleted_at=None - ).first() - is not None - ) - - def get_organizations(self): - """Get all organizations the user is a member of.""" - return [membership.organization for membership in self.organization_memberships] - - def has_totp_enabled(self) -> bool: - """Check if user has TOTP enabled and verified. - - Returns: - True if user has a verified TOTP authentication method, False otherwise. - """ - from gatehouse_app.models.authentication_method import AuthenticationMethod - from gatehouse_app.utils.constants import AuthMethodType - - return ( - AuthenticationMethod.query.filter_by( - user_id=self.id, - method_type=AuthMethodType.TOTP, - verified=True, - deleted_at=None, - ).first() - is not None - ) - - def get_totp_method(self): - """Get user's TOTP authentication method. - - Returns: - The AuthenticationMethod instance for TOTP or None if not found. - - Note: - Returns the most recently created TOTP method to handle cases where - multiple enrollment attempts may exist. - """ - from gatehouse_app.models.authentication_method import AuthenticationMethod - from gatehouse_app.utils.constants import AuthMethodType - - return AuthenticationMethod.query.filter_by( - user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None - ).order_by(AuthenticationMethod.created_at.desc()).first() - - def has_webauthn_enabled(self) -> bool: - """Check if user has any WebAuthn passkey credentials. - - Returns: - True if user has at least one WebAuthn credential, False otherwise. - """ - from gatehouse_app.models.authentication_method import AuthenticationMethod - from gatehouse_app.utils.constants import AuthMethodType - - return ( - AuthenticationMethod.query.filter_by( - user_id=self.id, - method_type=AuthMethodType.WEBAUTHN, - deleted_at=None, - ).first() - is not None - ) - - def get_webauthn_credentials(self): - """Get all WebAuthn credentials for the user. - - Returns: - List of AuthenticationMethod instances for WebAuthn, ordered by creation date. - """ - from gatehouse_app.models.authentication_method import AuthenticationMethod - from gatehouse_app.utils.constants import AuthMethodType - - return AuthenticationMethod.query.filter_by( - user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None - ).order_by(AuthenticationMethod.created_at.desc()).all() - - def get_webauthn_credential_count(self) -> int: - """Get the count of WebAuthn credentials for the user. - - Returns: - Number of WebAuthn credentials. - """ - from gatehouse_app.models.authentication_method import AuthenticationMethod - from gatehouse_app.utils.constants import AuthMethodType - - return AuthenticationMethod.query.filter_by( - user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None - ).count() +__all__ = ["User"] diff --git a/gatehouse_app/models/user/__init__.py b/gatehouse_app/models/user/__init__.py new file mode 100644 index 0000000..1d05e9a --- /dev/null +++ b/gatehouse_app/models/user/__init__.py @@ -0,0 +1,5 @@ +"""User subpackage.""" +from gatehouse_app.models.user.user import User +from gatehouse_app.models.user.session import Session + +__all__ = ["User", "Session"] diff --git a/gatehouse_app/models/session.py b/gatehouse_app/models/user/session.py similarity index 86% rename from gatehouse_app/models/session.py rename to gatehouse_app/models/user/session.py index 0290a20..9a78830 100644 --- a/gatehouse_app/models/session.py +++ b/gatehouse_app/models/user/session.py @@ -21,7 +21,9 @@ class Session(BaseModel): # Timing expires_at = db.Column(db.DateTime, nullable=False) - last_activity_at = db.Column(db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc)) + last_activity_at = db.Column( + db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc) + ) revoked_at = db.Column(db.DateTime, nullable=True) revoked_reason = db.Column(db.String(255), nullable=True) @@ -38,7 +40,6 @@ class Session(BaseModel): def is_active(self): """Check if session is currently active.""" now = datetime.now(timezone.utc) - # Make expires_at timezone-aware if it's naive expires_at = self.expires_at if expires_at.tzinfo is None: expires_at = expires_at.replace(tzinfo=timezone.utc) @@ -51,15 +52,13 @@ class Session(BaseModel): def is_expired(self): """Check if session has expired.""" now = datetime.now(timezone.utc) - # Make expires_at timezone-aware if it's naive expires_at = self.expires_at if expires_at.tzinfo is None: expires_at = expires_at.replace(tzinfo=timezone.utc) return now > expires_at - def refresh(self, duration_seconds=86400): - """ - Refresh session expiration. + def refresh(self, duration_seconds: int = 86400): + """Refresh session expiration. Args: duration_seconds: New session duration in seconds @@ -68,9 +67,8 @@ class Session(BaseModel): self.last_activity_at = datetime.now(timezone.utc) db.session.commit() - def revoke(self, reason=None): - """ - Revoke the session. + def revoke(self, reason: str = None): + """Revoke the session. Args: reason: Optional reason for revocation @@ -84,6 +82,5 @@ class Session(BaseModel): def to_dict(self, exclude=None): """Convert to dictionary, excluding sensitive fields.""" exclude = exclude or [] - # Exclude token from dict exclude.append("token") return super().to_dict(exclude=exclude) diff --git a/gatehouse_app/models/user/user.py b/gatehouse_app/models/user/user.py new file mode 100644 index 0000000..c2fb1c8 --- /dev/null +++ b/gatehouse_app/models/user/user.py @@ -0,0 +1,209 @@ +"""User model.""" +from gatehouse_app.extensions import db +from gatehouse_app.models.base import BaseModel +from gatehouse_app.utils.constants import UserStatus + + +class User(BaseModel): + """User model representing a user account.""" + + __tablename__ = "users" + + email = db.Column(db.String(255), unique=True, nullable=False, index=True) + email_verified = db.Column(db.Boolean, default=False, nullable=False) + full_name = db.Column(db.String(255), nullable=True) + avatar_url = db.Column(db.String(512), nullable=True) + status = db.Column( + db.Enum(UserStatus), default=UserStatus.ACTIVE, nullable=False, index=True + ) + last_login_at = db.Column(db.DateTime, nullable=True) + last_login_ip = db.Column(db.String(45), nullable=True) + + # Account activation (email-link flow) + activated = db.Column(db.Boolean, default=True, nullable=False) + activation_key = db.Column(db.String(128), unique=True, nullable=True, index=True) + + # Relationships – defined here only for models that don't circular-import + authentication_methods = db.relationship( + "AuthenticationMethod", back_populates="user", cascade="all, delete-orphan" + ) + sessions = db.relationship("Session", back_populates="user", cascade="all, delete-orphan") + organization_memberships = db.relationship( + "OrganizationMember", + back_populates="user", + cascade="all, delete-orphan", + foreign_keys="OrganizationMember.user_id", + ) + audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan") + security_policies = db.relationship( + "UserSecurityPolicy", + back_populates="user", + cascade="all, delete-orphan", + foreign_keys="UserSecurityPolicy.user_id", + ) + mfa_compliance = db.relationship( + "MfaPolicyCompliance", + back_populates="user", + cascade="all, delete-orphan", + foreign_keys="MfaPolicyCompliance.user_id", + ) + department_memberships = db.relationship( + "DepartmentMembership", + back_populates="user", + cascade="all, delete-orphan", + foreign_keys="DepartmentMembership.user_id", + ) + principal_memberships = db.relationship( + "PrincipalMembership", + back_populates="user", + cascade="all, delete-orphan", + foreign_keys="PrincipalMembership.user_id", + ) + ssh_keys = db.relationship( + "SSHKey", + back_populates="user", + cascade="all, delete-orphan", + foreign_keys="SSHKey.user_id", + ) + ssh_certificates = db.relationship( + "SSHCertificate", + back_populates="user", + cascade="all, delete-orphan", + foreign_keys="SSHCertificate.user_id", + ) + ca_permissions = db.relationship( + "CAPermission", + back_populates="user", + cascade="all, delete-orphan", + foreign_keys="CAPermission.user_id", + ) + + # OIDC relationships – registered here (no monkey-patching needed) + oidc_auth_codes = db.relationship( + "OIDCAuthCode", back_populates="user", cascade="all, delete-orphan" + ) + oidc_refresh_tokens = db.relationship( + "OIDCRefreshToken", back_populates="user", cascade="all, delete-orphan" + ) + oidc_sessions = db.relationship( + "OIDCSession", back_populates="user", cascade="all, delete-orphan" + ) + oidc_token_metadata = db.relationship( + "OIDCTokenMetadata", back_populates="user", cascade="all, delete-orphan" + ) + oidc_audit_logs = db.relationship( + "OIDCAuditLog", back_populates="user", cascade="all, delete-orphan" + ) + + def __repr__(self): + """String representation of User.""" + return f"" + + def to_dict(self, exclude=None): + """Convert user to dictionary, excluding sensitive fields by default.""" + exclude = exclude or [] + return super().to_dict(exclude=exclude) + + def has_password_auth(self): + """Check if user has password authentication enabled.""" + from gatehouse_app.models.auth.authentication_method import AuthenticationMethod + from gatehouse_app.utils.constants import AuthMethodType + + return ( + AuthenticationMethod.query.filter_by( + user_id=self.id, method_type=AuthMethodType.PASSWORD, deleted_at=None + ).first() + is not None + ) + + def get_organizations(self): + """Get all organizations the user is a member of.""" + return [membership.organization for membership in self.organization_memberships] + + def has_totp_enabled(self) -> bool: + """Check if user has TOTP enabled and verified. + + Returns: + True if user has a verified TOTP authentication method, False otherwise. + """ + from gatehouse_app.models.auth.authentication_method import AuthenticationMethod + from gatehouse_app.utils.constants import AuthMethodType + + return ( + AuthenticationMethod.query.filter_by( + user_id=self.id, + method_type=AuthMethodType.TOTP, + verified=True, + deleted_at=None, + ).first() + is not None + ) + + def get_totp_method(self): + """Get user's TOTP authentication method. + + Returns: + The AuthenticationMethod instance for TOTP or None if not found. + + Note: + Returns the most recently created TOTP method to handle cases where + multiple enrollment attempts may exist. + """ + from gatehouse_app.models.auth.authentication_method import AuthenticationMethod + from gatehouse_app.utils.constants import AuthMethodType + + return ( + AuthenticationMethod.query.filter_by( + user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None + ) + .order_by(AuthenticationMethod.created_at.desc()) + .first() + ) + + def has_webauthn_enabled(self) -> bool: + """Check if user has any WebAuthn passkey credentials. + + Returns: + True if user has at least one WebAuthn credential, False otherwise. + """ + from gatehouse_app.models.auth.authentication_method import AuthenticationMethod + from gatehouse_app.utils.constants import AuthMethodType + + return ( + AuthenticationMethod.query.filter_by( + user_id=self.id, + method_type=AuthMethodType.WEBAUTHN, + deleted_at=None, + ).first() + is not None + ) + + def get_webauthn_credentials(self): + """Get all WebAuthn credentials for the user. + + Returns: + List of AuthenticationMethod instances for WebAuthn, ordered by creation date. + """ + from gatehouse_app.models.auth.authentication_method import AuthenticationMethod + from gatehouse_app.utils.constants import AuthMethodType + + return ( + AuthenticationMethod.query.filter_by( + user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None + ) + .order_by(AuthenticationMethod.created_at.desc()) + .all() + ) + + def get_webauthn_credential_count(self) -> int: + """Get the count of WebAuthn credentials for the user. + + Returns: + Number of WebAuthn credentials. + """ + from gatehouse_app.models.auth.authentication_method import AuthenticationMethod + from gatehouse_app.utils.constants import AuthMethodType + + return AuthenticationMethod.query.filter_by( + user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None + ).count()