"""Authentication method model.""" from datetime import datetime, timedelta, timezone import secrets from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel from gatehouse_app.utils.constants import AuthMethodType from gatehouse_app.utils.encryption import encrypt, decrypt class AuthenticationMethod(BaseModel): """Authentication method model storing user authentication credentials.""" __tablename__ = "authentication_methods" user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=False, index=True) method_type = db.Column(db.Enum(AuthMethodType), nullable=False, index=True) # For password authentication password_hash = db.Column(db.String(255), nullable=True) # For OAuth/OIDC providers provider_user_id = db.Column(db.String(255), nullable=True) provider_data = db.Column(db.JSON, nullable=True) # For TOTP authentication totp_secret = db.Column(db.String(32), nullable=True) totp_backup_codes = db.Column(db.JSON, nullable=True) totp_verified_at = db.Column(db.DateTime, nullable=True) # Metadata is_primary = db.Column(db.Boolean, default=False, nullable=False) verified = db.Column(db.Boolean, default=False, nullable=False) last_used_at = db.Column(db.DateTime, nullable=True) # Relationships user = db.relationship("User", back_populates="authentication_methods") # Ensure unique provider combinations __table_args__ = ( db.Index("idx_user_method", "user_id", "method_type"), db.UniqueConstraint( "user_id", "method_type", "provider_user_id", name="uix_user_method_provider" ), ) def __repr__(self): """String representation of AuthenticationMethod.""" return f"" def is_password(self): """Check if this is a password authentication method.""" return self.method_type == AuthMethodType.PASSWORD def is_oauth(self): """Check if this is an OAuth authentication method.""" return self.method_type in [ AuthMethodType.GOOGLE, AuthMethodType.GITHUB, AuthMethodType.MICROSOFT, ] def is_totp(self): """Check if this is a TOTP authentication method.""" return self.method_type == AuthMethodType.TOTP def is_webauthn(self): """Check if this is a WebAuthn authentication method.""" return self.method_type == AuthMethodType.WEBAUTHN def to_dict(self, exclude=None): """Convert to dictionary, excluding sensitive fields.""" exclude = exclude or [] # Always exclude password hash and TOTP secrets exclude.append("password_hash") exclude.append("totp_secret") exclude.append("totp_backup_codes") return super().to_dict(exclude=exclude) def to_webauthn_dict(self): """Convert WebAuthn credential to public dictionary. Returns: Dictionary with safe-to-expose credential information. """ if not self.is_webauthn() or not self.provider_data: return None data = self.provider_data return { "id": data.get("credential_id"), "name": data.get("name"), "transports": data.get("transports", []), "created_at": data.get("created_at"), "last_used_at": data.get("last_used_at"), "sign_count": data.get("sign_count", 0), } class ApplicationProviderConfig(BaseModel): """Application-wide OAuth provider configuration. This model stores OAuth provider credentials at the application level, allowing users to authenticate without needing to specify an organization first. """ __tablename__ = "application_provider_configs" # Provider identification provider_type = db.Column(db.String(50), nullable=False, unique=True, index=True) # OAuth credentials (encrypted) client_id = db.Column(db.String(255), nullable=False) client_secret_encrypted = db.Column(db.String(512), nullable=True) # Provider status is_enabled = db.Column(db.Boolean, default=True, nullable=False) # Default redirect URL default_redirect_url = db.Column(db.String(2048), nullable=True) # Provider-specific settings (JSON) additional_config = db.Column(db.JSON, nullable=True) # Relationships organization_overrides = db.relationship( "OrganizationProviderOverride", back_populates="application_config", foreign_keys="OrganizationProviderOverride.provider_type", primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type", cascade="all, delete-orphan" ) def __repr__(self): """String representation of ApplicationProviderConfig.""" return f"" def set_client_secret(self, plaintext_secret: str): """Encrypt and store client secret. Args: plaintext_secret: The plaintext OAuth client secret """ if plaintext_secret: self.client_secret_encrypted = encrypt(plaintext_secret) def get_client_secret(self) -> str: """Decrypt and return client secret. Returns: The plaintext OAuth client secret """ if self.client_secret_encrypted: return decrypt(self.client_secret_encrypted) return None def to_dict(self, exclude=None): """Convert to dictionary, excluding sensitive fields.""" exclude = exclude or [] # Always exclude encrypted client secret exclude.append("client_secret_encrypted") return super().to_dict(exclude=exclude) class OrganizationProviderOverride(BaseModel): """Organization-specific OAuth configuration overrides. This model allows organizations to override application-level OAuth settings for enterprise SSO scenarios or custom provider configurations. """ __tablename__ = "organization_provider_overrides" # References organization_id = db.Column( db.String(36), db.ForeignKey("organizations.id"), nullable=False, index=True ) provider_type = db.Column(db.String(50), nullable=False, index=True) # Override OAuth credentials (encrypted, nullable - only if overriding) client_id = db.Column(db.String(255), nullable=True) client_secret_encrypted = db.Column(db.String(512), nullable=True) # Provider status is_enabled = db.Column(db.Boolean, default=True, nullable=False) # Redirect URL override redirect_url_override = db.Column(db.String(2048), nullable=True) # Provider-specific settings override (JSON) additional_config = db.Column(db.JSON, nullable=True) # Relationships organization = db.relationship("Organization", backref="provider_overrides") application_config = db.relationship( "ApplicationProviderConfig", back_populates="organization_overrides", foreign_keys=[provider_type], primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type", viewonly=True ) # Unique constraint on (organization_id, provider_type) __table_args__ = ( db.UniqueConstraint( "organization_id", "provider_type", name="uix_org_provider_type" ), ) def __repr__(self): """String representation of OrganizationProviderOverride.""" return f"" def set_client_secret(self, plaintext_secret: str): """Encrypt and store client secret override. Args: plaintext_secret: The plaintext OAuth client secret """ if plaintext_secret: self.client_secret_encrypted = encrypt(plaintext_secret) def get_client_secret(self) -> str: """Decrypt and return client secret override. Returns: The plaintext OAuth client secret """ if self.client_secret_encrypted: return decrypt(self.client_secret_encrypted) return None def to_dict(self, exclude=None): """Convert to dictionary, excluding sensitive fields.""" exclude = exclude or [] # Always exclude encrypted client secret exclude.append("client_secret_encrypted") return super().to_dict(exclude=exclude) class OAuthState(BaseModel): """OAuth flow state tracking. This model tracks OAuth authentication flow state, including PKCE parameters and organization context (which is now optional to support login flows where the organization isn't known until after authentication). """ __tablename__ = "oauth_states" # OAuth state parameter (unique, used for CSRF protection) state = db.Column(db.String(64), unique=True, nullable=False, index=True) # Flow type: "login", "register", "link" flow_type = db.Column(db.String(50), nullable=False) # Provider type provider_type = db.Column(db.String(50), nullable=False) # User context (optional - not set for login/register flows) user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True) # Organization context (NOW OPTIONAL - for SSO discovery or post-auth) organization_id = db.Column( db.String(36), db.ForeignKey("organizations.id"), nullable=True, index=True ) # PKCE parameters nonce = db.Column(db.String(128), nullable=True) code_verifier = db.Column(db.String(128), nullable=True) code_challenge = db.Column(db.String(128), nullable=True) # OAuth parameters redirect_uri = db.Column(db.String(2048), nullable=True) # Post-auth redirect (for frontend routing) return_url = db.Column(db.String(2048), nullable=True) # Additional state data extra_data = db.Column(db.JSON, nullable=True) # Expiration and usage tracking expires_at = db.Column(db.DateTime, nullable=False, index=True) used = db.Column(db.Boolean, default=False, nullable=False) # Relationships user = db.relationship("User", backref="oauth_states") organization = db.relationship("Organization", backref="oauth_states") def __repr__(self): """String representation of OAuthState.""" return f"" @classmethod def create_state( cls, flow_type: str, provider_type: str, user_id: str = None, organization_id: str = None, redirect_uri: str = None, return_url: str = None, code_verifier: str = None, code_challenge: str = None, nonce: str = None, extra_data: dict = None, lifetime_seconds: int = 600 ): """Create a new OAuth state with auto-generated state parameter. Args: flow_type: Type of flow ("login", "register", "link") provider_type: OAuth provider type user_id: Optional user ID for authenticated flows organization_id: Optional organization ID redirect_uri: OAuth callback URI return_url: Post-auth redirect destination code_verifier: PKCE code verifier code_challenge: PKCE code challenge nonce: OpenID Connect nonce extra_data: Additional state data lifetime_seconds: How long the state is valid (default 10 minutes) Returns: New OAuthState instance """ state = secrets.token_urlsafe(32) expires_at = datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds) oauth_state = cls( state=state, flow_type=flow_type, provider_type=provider_type, user_id=user_id, organization_id=organization_id, redirect_uri=redirect_uri, return_url=return_url, code_verifier=code_verifier, code_challenge=code_challenge, nonce=nonce, extra_data=extra_data, expires_at=expires_at, used=False ) oauth_state.save() return oauth_state def is_valid(self) -> bool: """Check if the OAuth state is still valid. Returns: True if state hasn't expired and hasn't been used """ now = datetime.now(timezone.utc) # Make expires_at timezone-aware if it's naive (database returns naive datetimes) expires_at = self.expires_at if expires_at.tzinfo is None: expires_at = expires_at.replace(tzinfo=timezone.utc) return not self.used and expires_at > now def mark_used(self): """Mark the state as used to prevent replay attacks.""" self.used = True self.save() @classmethod def cleanup_expired(cls): """Remove expired OAuth states.""" now = datetime.now(timezone.utc) cls.query.filter(cls.expires_at < now).delete() db.session.commit() def to_dict(self, exclude=None): """Convert to dictionary, excluding sensitive fields.""" exclude = exclude or [] # Exclude code_verifier as it's sensitive exclude.append("code_verifier") return super().to_dict(exclude=exclude)