Chore: Refractor Models into organized file/folder

This commit is contained in:
2026-03-01 12:40:48 +05:45
parent 58432da1c8
commit 07193a2d2e
35 changed files with 1475 additions and 932 deletions
+118 -44
View File
@@ -1,76 +1,150 @@
"""Models package."""
from gatehouse_app.models.base import BaseModel
from gatehouse_app.models.user import User
from gatehouse_app.models.organization import Organization
from gatehouse_app.models.organization_member import OrganizationMember
from gatehouse_app.models.authentication_method import (
"""Models package.
Sub-packages
------------
models.user — User, Session
models.organization — Organization, OrganizationMember, Department,
DepartmentMembership, DepartmentPrincipal,
DepartmentCertPolicy, Principal, PrincipalMembership,
OrgInviteToken
models.auth — AuthenticationMethod, ApplicationProviderConfig,
OrganizationProviderOverride, OAuthState,
AuditLog, PasswordResetToken, EmailVerificationToken
models.oidc — OIDCClient, OIDCAuthCode, OIDCRefreshToken, OIDCSession,
OIDCTokenMetadata, OIDCAuditLog, OidcJwksKey
models.ssh_ca — CA, KeyType, CertType, CaType, CAPermission,
SSHKey, SSHCertificate, CertificateStatus,
CertificateAuditLog
models.security — OrganizationSecurityPolicy, UserSecurityPolicy,
MfaPolicyCompliance
All names are re-exported here so that existing code using the flat import
style (``from gatehouse_app.models import X``) or the old per-file style
(``from gatehouse_app.models.user import User``) continue to work unchanged.
"""
# ── Base ──────────────────────────────────────────────────────────────────────
from gatehouse_app.models.base import BaseModel # noqa: F401
# ── User ──────────────────────────────────────────────────────────────────────
from gatehouse_app.models.user.user import User # noqa: F401
from gatehouse_app.models.user.session import Session # noqa: F401
# ── Organization ──────────────────────────────────────────────────────────────
from gatehouse_app.models.organization.organization import Organization # noqa: F401
from gatehouse_app.models.organization.organization_member import ( # noqa: F401
OrganizationMember,
)
from gatehouse_app.models.organization.department import ( # noqa: F401
Department,
DepartmentMembership,
DepartmentPrincipal,
)
from gatehouse_app.models.organization.department_cert_policy import ( # noqa: F401
DepartmentCertPolicy,
STANDARD_EXTENSIONS,
)
from gatehouse_app.models.organization.principal import ( # noqa: F401
Principal,
PrincipalMembership,
)
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken # noqa: F401
# ── Auth ──────────────────────────────────────────────────────────────────────
from gatehouse_app.models.auth.authentication_method import ( # noqa: F401
AuthenticationMethod,
ApplicationProviderConfig,
OrganizationProviderOverride,
OAuthState,
)
from gatehouse_app.models.session import Session
from gatehouse_app.models.audit_log import AuditLog
from gatehouse_app.models.oidc_client import OIDCClient
from gatehouse_app.models.oidc_authorization_code import OIDCAuthCode
from gatehouse_app.models.oidc_refresh_token import OIDCRefreshToken
from gatehouse_app.models.oidc_session import OIDCSession
from gatehouse_app.models.oidc_token_metadata import OIDCTokenMetadata
from gatehouse_app.models.oidc_audit_log import OIDCAuditLog
from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
from gatehouse_app.models.user_security_policy import UserSecurityPolicy
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
from gatehouse_app.models.department import (
Department,
DepartmentMembership,
DepartmentPrincipal,
from gatehouse_app.models.auth.audit_log import AuditLog # noqa: F401
from gatehouse_app.models.auth.password_reset_token import PasswordResetToken # noqa: F401
from gatehouse_app.models.auth.email_verification_token import ( # noqa: F401
EmailVerificationToken,
)
from gatehouse_app.models.principal import (
Principal,
PrincipalMembership,
# ── OIDC ──────────────────────────────────────────────────────────────────────
from gatehouse_app.models.oidc.oidc_client import OIDCClient # noqa: F401
from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode # noqa: F401
from gatehouse_app.models.oidc.oidc_refresh_token import OIDCRefreshToken # noqa: F401
from gatehouse_app.models.oidc.oidc_session import OIDCSession # noqa: F401
from gatehouse_app.models.oidc.oidc_token_metadata import OIDCTokenMetadata # noqa: F401
from gatehouse_app.models.oidc.oidc_audit_log import OIDCAuditLog # noqa: F401
from gatehouse_app.models.oidc.oidc_jwks_key import OidcJwksKey # noqa: F401
# ── SSH / CA ──────────────────────────────────────────────────────────────────
from gatehouse_app.models.ssh_ca.ca import ( # noqa: F401
CA,
KeyType,
CertType,
CaType,
CAPermission,
)
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey # noqa: F401
from gatehouse_app.models.ssh_ca.ssh_certificate import ( # noqa: F401
SSHCertificate,
CertificateStatus,
)
from gatehouse_app.models.ssh_ca.certificate_audit_log import ( # noqa: F401
CertificateAuditLog,
)
# ── Security ──────────────────────────────────────────────────────────────────
from gatehouse_app.models.security.organization_security_policy import ( # noqa: F401
OrganizationSecurityPolicy,
)
from gatehouse_app.models.security.user_security_policy import ( # noqa: F401
UserSecurityPolicy,
)
from gatehouse_app.models.security.mfa_policy_compliance import ( # noqa: F401
MfaPolicyCompliance,
)
from gatehouse_app.models.ssh_key import SSHKey
from gatehouse_app.models.ca import CA, KeyType, CertType, CAPermission
from gatehouse_app.models.ssh_certificate import SSHCertificate, CertificateStatus
from gatehouse_app.models.certificate_audit_log import CertificateAuditLog
from gatehouse_app.models.password_reset_token import PasswordResetToken
from gatehouse_app.models.email_verification_token import EmailVerificationToken
from gatehouse_app.models.org_invite_token import OrgInviteToken
__all__ = [
# Base
"BaseModel",
# User
"User",
"Session",
# Organization
"Organization",
"OrganizationMember",
"Department",
"DepartmentMembership",
"DepartmentPrincipal",
"DepartmentCertPolicy",
"STANDARD_EXTENSIONS",
"Principal",
"PrincipalMembership",
"OrgInviteToken",
# Auth
"AuthenticationMethod",
"ApplicationProviderConfig",
"OrganizationProviderOverride",
"OAuthState",
"Session",
"AuditLog",
"PasswordResetToken",
"EmailVerificationToken",
# OIDC
"OIDCClient",
"OIDCAuthCode",
"OIDCRefreshToken",
"OIDCSession",
"OIDCTokenMetadata",
"OIDCAuditLog",
"OrganizationSecurityPolicy",
"UserSecurityPolicy",
"MfaPolicyCompliance",
"Department",
"DepartmentMembership",
"DepartmentPrincipal",
"Principal",
"PrincipalMembership",
"SSHKey",
"OidcJwksKey",
# SSH / CA
"CA",
"KeyType",
"CertType",
"CaType",
"CAPermission",
"SSHKey",
"SSHCertificate",
"CertificateStatus",
"CertificateAuditLog",
"PasswordResetToken",
"EmailVerificationToken",
"OrgInviteToken",
# Security
"OrganizationSecurityPolicy",
"UserSecurityPolicy",
"MfaPolicyCompliance",
]
+20
View File
@@ -0,0 +1,20 @@
"""Auth subpackage — authentication methods, tokens, and audit logs."""
from gatehouse_app.models.auth.authentication_method import (
AuthenticationMethod,
ApplicationProviderConfig,
OrganizationProviderOverride,
OAuthState,
)
from gatehouse_app.models.auth.audit_log import AuditLog
from gatehouse_app.models.auth.password_reset_token import PasswordResetToken
from gatehouse_app.models.auth.email_verification_token import EmailVerificationToken
__all__ = [
"AuthenticationMethod",
"ApplicationProviderConfig",
"OrganizationProviderOverride",
"OAuthState",
"AuditLog",
"PasswordResetToken",
"EmailVerificationToken",
]
@@ -26,14 +26,13 @@ class AuditLog(BaseModel):
extra_data = db.Column(db.JSON, nullable=True)
description = db.Column(db.Text, nullable=True)
# Success/failure
# Outcome
success = db.Column(db.Boolean, default=True, nullable=False)
error_message = db.Column(db.Text, nullable=True)
# Relationships
user = db.relationship("User", back_populates="audit_logs")
# Indexes for common queries
__table_args__ = (
db.Index("idx_audit_user_action", "user_id", "action"),
db.Index("idx_audit_resource", "resource_type", "resource_id"),
@@ -45,9 +44,8 @@ class AuditLog(BaseModel):
return f"<AuditLog action={self.action} user_id={self.user_id}>"
@classmethod
def log(cls, action, user_id=None, **kwargs):
"""
Create an audit log entry.
def log(cls, action, user_id=None, **kwargs) -> "AuditLog":
"""Create an audit log entry.
Args:
action: AuditAction enum value
@@ -1,4 +1,4 @@
"""Authentication method model."""
"""Authentication method model — user credentials and OAuth provider config."""
from datetime import datetime, timedelta, timezone
import secrets
from gatehouse_app.extensions import db
@@ -35,7 +35,6 @@ class AuthenticationMethod(BaseModel):
# Relationships
user = db.relationship("User", back_populates="authentication_methods")
# Ensure unique provider combinations
__table_args__ = (
db.Index("idx_user_method", "user_id", "method_type"),
db.UniqueConstraint(
@@ -45,13 +44,15 @@ class AuthenticationMethod(BaseModel):
def __repr__(self):
"""String representation of AuthenticationMethod."""
return f"<AuthenticationMethod user_id={self.user_id} type={self.method_type}>"
return (
f"<AuthenticationMethod user_id={self.user_id} type={self.method_type}>"
)
def is_password(self):
def is_password(self) -> bool:
"""Check if this is a password authentication method."""
return self.method_type == AuthMethodType.PASSWORD
def is_oauth(self):
def is_oauth(self) -> bool:
"""Check if this is an OAuth authentication method."""
return self.method_type in [
AuthMethodType.GOOGLE,
@@ -59,32 +60,32 @@ class AuthenticationMethod(BaseModel):
AuthMethodType.MICROSOFT,
]
def is_totp(self):
def is_totp(self) -> bool:
"""Check if this is a TOTP authentication method."""
return self.method_type == AuthMethodType.TOTP
def is_webauthn(self):
def is_webauthn(self) -> bool:
"""Check if this is a WebAuthn authentication method."""
return self.method_type == AuthMethodType.WEBAUTHN
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Always exclude password hash and TOTP secrets
exclude.append("password_hash")
exclude.append("totp_secret")
exclude.append("totp_backup_codes")
# Always exclude credential material
for field in ("password_hash", "totp_secret", "totp_backup_codes"):
if field not in exclude:
exclude.append(field)
return super().to_dict(exclude=exclude)
def to_webauthn_dict(self):
"""Convert WebAuthn credential to public dictionary.
Returns:
Dictionary with safe-to-expose credential information.
Dictionary with safe-to-expose credential information, or None.
"""
if not self.is_webauthn() or not self.provider_data:
return None
data = self.provider_data
return {
"id": data.get("credential_id"),
@@ -98,26 +99,26 @@ class AuthenticationMethod(BaseModel):
class ApplicationProviderConfig(BaseModel):
"""Application-wide OAuth provider configuration.
This model stores OAuth provider credentials at the application level,
allowing users to authenticate without needing to specify an organization first.
Stores OAuth provider credentials at the application level, allowing users
to authenticate without needing to specify an organization first.
"""
__tablename__ = "application_provider_configs"
# Provider identification
provider_type = db.Column(db.String(50), nullable=False, unique=True, index=True)
# OAuth credentials (encrypted)
# OAuth credentials (client_secret encrypted at rest)
client_id = db.Column(db.String(255), nullable=False)
client_secret_encrypted = db.Column(db.String(512), nullable=True)
# Provider status
is_enabled = db.Column(db.Boolean, default=True, nullable=False)
# Default redirect URL
default_redirect_url = db.Column(db.String(2048), nullable=True)
# Provider-specific settings (JSON)
additional_config = db.Column(db.JSON, nullable=True)
@@ -126,28 +127,34 @@ class ApplicationProviderConfig(BaseModel):
"OrganizationProviderOverride",
back_populates="application_config",
foreign_keys="OrganizationProviderOverride.provider_type",
primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type",
cascade="all, delete-orphan"
primaryjoin=(
"ApplicationProviderConfig.provider_type"
"==OrganizationProviderOverride.provider_type"
),
cascade="all, delete-orphan",
)
def __repr__(self):
"""String representation of ApplicationProviderConfig."""
return f"<ApplicationProviderConfig provider={self.provider_type} enabled={self.is_enabled}>"
return (
f"<ApplicationProviderConfig provider={self.provider_type} "
f"enabled={self.is_enabled}>"
)
def set_client_secret(self, plaintext_secret: str):
def set_client_secret(self, plaintext_secret: str) -> None:
"""Encrypt and store client secret.
Args:
plaintext_secret: The plaintext OAuth client secret
"""
if plaintext_secret:
self.client_secret_encrypted = encrypt(plaintext_secret)
def get_client_secret(self) -> str:
def get_client_secret(self) -> str | None:
"""Decrypt and return client secret.
Returns:
The plaintext OAuth client secret
The plaintext OAuth client secret, or None if not set.
"""
if self.client_secret_encrypted:
return decrypt(self.client_secret_encrypted)
@@ -156,37 +163,38 @@ class ApplicationProviderConfig(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Always exclude encrypted client secret
exclude.append("client_secret_encrypted")
if "client_secret_encrypted" not in exclude:
exclude.append("client_secret_encrypted")
return super().to_dict(exclude=exclude)
class OrganizationProviderOverride(BaseModel):
"""Organization-specific OAuth configuration overrides.
This model allows organizations to override application-level OAuth settings
for enterprise SSO scenarios or custom provider configurations.
Allows organizations to override application-level OAuth settings for
enterprise SSO scenarios or custom provider configurations.
"""
__tablename__ = "organization_provider_overrides"
# References
organization_id = db.Column(
db.String(36), db.ForeignKey("organizations.id"),
nullable=False, index=True
db.String(36),
db.ForeignKey("organizations.id"),
nullable=False,
index=True,
)
provider_type = db.Column(db.String(50), nullable=False, index=True)
# Override OAuth credentials (encrypted, nullable - only if overriding)
# Override OAuth credentials (encrypted, nullable only set when overriding)
client_id = db.Column(db.String(255), nullable=True)
client_secret_encrypted = db.Column(db.String(512), nullable=True)
# Provider status
is_enabled = db.Column(db.Boolean, default=True, nullable=False)
# Redirect URL override
redirect_url_override = db.Column(db.String(2048), nullable=True)
# Provider-specific settings override (JSON)
additional_config = db.Column(db.JSON, nullable=True)
@@ -196,37 +204,33 @@ class OrganizationProviderOverride(BaseModel):
"ApplicationProviderConfig",
back_populates="organization_overrides",
foreign_keys=[provider_type],
primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type",
viewonly=True
primaryjoin=(
"ApplicationProviderConfig.provider_type"
"==OrganizationProviderOverride.provider_type"
),
viewonly=True,
)
# Unique constraint on (organization_id, provider_type)
__table_args__ = (
db.UniqueConstraint(
"organization_id", "provider_type",
name="uix_org_provider_type"
"organization_id", "provider_type", name="uix_org_provider_type"
),
)
def __repr__(self):
"""String representation of OrganizationProviderOverride."""
return f"<OrganizationProviderOverride org={self.organization_id} provider={self.provider_type}>"
return (
f"<OrganizationProviderOverride org={self.organization_id} "
f"provider={self.provider_type}>"
)
def set_client_secret(self, plaintext_secret: str):
"""Encrypt and store client secret override.
Args:
plaintext_secret: The plaintext OAuth client secret
"""
def set_client_secret(self, plaintext_secret: str) -> None:
"""Encrypt and store client secret override."""
if plaintext_secret:
self.client_secret_encrypted = encrypt(plaintext_secret)
def get_client_secret(self) -> str:
"""Decrypt and return client secret override.
Returns:
The plaintext OAuth client secret
"""
def get_client_secret(self) -> str | None:
"""Decrypt and return client secret override."""
if self.client_secret_encrypted:
return decrypt(self.client_secret_encrypted)
return None
@@ -234,53 +238,52 @@ class OrganizationProviderOverride(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Always exclude encrypted client secret
exclude.append("client_secret_encrypted")
if "client_secret_encrypted" not in exclude:
exclude.append("client_secret_encrypted")
return super().to_dict(exclude=exclude)
class OAuthState(BaseModel):
"""OAuth flow state tracking.
This model tracks OAuth authentication flow state, including PKCE parameters
and organization context (which is now optional to support login flows where
the organization isn't known until after authentication).
Tracks OAuth authentication flow state, including PKCE parameters and
organization context (which is optional to support login flows where the
organization isn't known until after authentication).
"""
__tablename__ = "oauth_states"
# OAuth state parameter (unique, used for CSRF protection)
state = db.Column(db.String(64), unique=True, nullable=False, index=True)
# Flow type: "login", "register", "link"
flow_type = db.Column(db.String(50), nullable=False)
# Provider type
provider_type = db.Column(db.String(50), nullable=False)
# User context (optional - not set for login/register flows)
# User context (optional not set for login/register flows)
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True)
# Organization context (NOW OPTIONAL - for SSO discovery or post-auth)
# Organization context (optional — for SSO discovery or post-auth)
organization_id = db.Column(
db.String(36), db.ForeignKey("organizations.id"),
nullable=True, index=True
db.String(36), db.ForeignKey("organizations.id"), nullable=True, index=True
)
# PKCE parameters
nonce = db.Column(db.String(128), nullable=True)
code_verifier = db.Column(db.String(128), nullable=True)
code_challenge = db.Column(db.String(128), nullable=True)
# OAuth parameters
redirect_uri = db.Column(db.String(2048), nullable=True)
# Post-auth redirect (for frontend routing)
return_url = db.Column(db.String(2048), nullable=True)
# Additional state data
extra_data = db.Column(db.JSON, nullable=True)
# Expiration and usage tracking
expires_at = db.Column(db.DateTime, nullable=False, index=True)
used = db.Column(db.Boolean, default=False, nullable=False)
@@ -291,7 +294,10 @@ class OAuthState(BaseModel):
def __repr__(self):
"""String representation of OAuthState."""
return f"<OAuthState state={self.state[:8]}... flow={self.flow_type} provider={self.provider_type}>"
return (
f"<OAuthState state={self.state[:8]}... "
f"flow={self.flow_type} provider={self.provider_type}>"
)
@classmethod
def create_state(
@@ -306,10 +312,10 @@ class OAuthState(BaseModel):
code_challenge: str = None,
nonce: str = None,
extra_data: dict = None,
lifetime_seconds: int = 600
):
"""Create a new OAuth state with auto-generated state parameter.
lifetime_seconds: int = 600,
) -> "OAuthState":
"""Create a new OAuth state with an auto-generated state parameter.
Args:
flow_type: Type of flow ("login", "register", "link")
provider_type: OAuth provider type
@@ -322,13 +328,13 @@ class OAuthState(BaseModel):
nonce: OpenID Connect nonce
extra_data: Additional state data
lifetime_seconds: How long the state is valid (default 10 minutes)
Returns:
New OAuthState instance
"""
state = secrets.token_urlsafe(32)
expires_at = datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds)
oauth_state = cls(
state=state,
flow_type=flow_type,
@@ -342,31 +348,30 @@ class OAuthState(BaseModel):
nonce=nonce,
extra_data=extra_data,
expires_at=expires_at,
used=False
used=False,
)
oauth_state.save()
return oauth_state
def is_valid(self) -> bool:
"""Check if the OAuth state is still valid.
Returns:
True if state hasn't expired and hasn't been used
True if state hasn't expired and hasn't been used.
"""
now = datetime.now(timezone.utc)
# Make expires_at timezone-aware if it's naive (database returns naive datetimes)
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return not self.used and expires_at > now
def mark_used(self):
def mark_used(self) -> None:
"""Mark the state as used to prevent replay attacks."""
self.used = True
self.save()
@classmethod
def cleanup_expired(cls):
def cleanup_expired(cls) -> None:
"""Remove expired OAuth states."""
now = datetime.now(timezone.utc)
cls.query.filter(cls.expires_at < now).delete()
@@ -375,6 +380,7 @@ class OAuthState(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Exclude code_verifier as it's sensitive
exclude.append("code_verifier")
# code_verifier must never be exposed
if "code_verifier" not in exclude:
exclude.append("code_verifier")
return super().to_dict(exclude=exclude)
@@ -0,0 +1,68 @@
"""Email verification token model."""
import secrets
from datetime import datetime, timezone, timedelta
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class EmailVerificationToken(BaseModel):
"""Single-use token for verifying a user's email address."""
__tablename__ = "email_verification_tokens"
user_id = db.Column(
db.String(36),
db.ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
token = db.Column(db.String(128), unique=True, nullable=False, index=True)
expires_at = db.Column(db.DateTime, nullable=False)
used_at = db.Column(db.DateTime, nullable=True)
user = db.relationship(
"User",
backref=db.backref("email_verification_tokens", cascade="all, delete-orphan"),
)
@classmethod
def generate(cls, user_id: str, ttl_hours: int = 24) -> "EmailVerificationToken":
"""Create a new verification token for a user.
Any existing unused tokens for this user are invalidated first.
"""
cls.query.filter_by(user_id=user_id, used_at=None).delete()
db.session.flush()
token_value = secrets.token_urlsafe(48)
instance = cls(
user_id=user_id,
token=token_value,
expires_at=datetime.now(timezone.utc) + timedelta(hours=ttl_hours),
)
db.session.add(instance)
db.session.commit()
return instance
@property
def is_valid(self) -> bool:
"""Return True if the token has not been used and has not expired."""
if self.used_at is not None:
return False
now = datetime.now(timezone.utc)
expires = self.expires_at
if expires.tzinfo is None:
expires = expires.replace(tzinfo=timezone.utc)
return now < expires
def consume(self) -> None:
"""Mark the token as used."""
self.used_at = datetime.now(timezone.utc)
db.session.commit()
def __repr__(self) -> str:
return (
f"<EmailVerificationToken user_id={self.user_id} "
f"used={self.used_at is not None}>"
)
@@ -0,0 +1,69 @@
"""Password reset token model."""
import secrets
from datetime import datetime, timezone, timedelta
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class PasswordResetToken(BaseModel):
"""Single-use token for resetting a user's password."""
__tablename__ = "password_reset_tokens"
user_id = db.Column(
db.String(36),
db.ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
token = db.Column(db.String(128), unique=True, nullable=False, index=True)
expires_at = db.Column(db.DateTime, nullable=False)
used_at = db.Column(db.DateTime, nullable=True)
user = db.relationship(
"User",
backref=db.backref("password_reset_tokens", cascade="all, delete-orphan"),
)
@classmethod
def generate(cls, user_id: str, ttl_hours: int = 2) -> "PasswordResetToken":
"""Create a new password reset token for a user.
Any existing unused tokens for this user are invalidated first.
"""
# Invalidate any existing unused tokens for this user
cls.query.filter_by(user_id=user_id, used_at=None).delete()
db.session.flush()
token_value = secrets.token_urlsafe(48)
instance = cls(
user_id=user_id,
token=token_value,
expires_at=datetime.now(timezone.utc) + timedelta(hours=ttl_hours),
)
db.session.add(instance)
db.session.commit()
return instance
@property
def is_valid(self) -> bool:
"""Return True if the token has not been used and has not expired."""
if self.used_at is not None:
return False
now = datetime.now(timezone.utc)
expires = self.expires_at
if expires.tzinfo is None:
expires = expires.replace(tzinfo=timezone.utc)
return now < expires
def consume(self) -> None:
"""Mark the token as used."""
self.used_at = datetime.now(timezone.utc)
db.session.commit()
def __repr__(self) -> str:
return (
f"<PasswordResetToken user_id={self.user_id} "
f"used={self.used_at is not None}>"
)
+18
View File
@@ -0,0 +1,18 @@
"""OIDC subpackage — clients, tokens, sessions, and audit logs."""
from gatehouse_app.models.oidc.oidc_client import OIDCClient
from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode
from gatehouse_app.models.oidc.oidc_refresh_token import OIDCRefreshToken
from gatehouse_app.models.oidc.oidc_session import OIDCSession
from gatehouse_app.models.oidc.oidc_token_metadata import OIDCTokenMetadata
from gatehouse_app.models.oidc.oidc_audit_log import OIDCAuditLog
from gatehouse_app.models.oidc.oidc_jwks_key import OidcJwksKey
__all__ = [
"OIDCClient",
"OIDCAuthCode",
"OIDCRefreshToken",
"OIDCSession",
"OIDCTokenMetadata",
"OIDCAuditLog",
"OidcJwksKey",
]
@@ -1,5 +1,4 @@
"""OIDC Audit Log model for comprehensive OIDC event tracking."""
from datetime import datetime
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
@@ -7,8 +6,7 @@ from gatehouse_app.models.base import BaseModel
class OIDCAuditLog(BaseModel):
"""OIDC Audit Log model for comprehensive OIDC event tracking.
This model logs all OIDC-related events for security, compliance,
and debugging purposes.
Logs all OIDC-related events for security, compliance, and debugging.
"""
__tablename__ = "oidc_audit_logs"
@@ -46,16 +44,29 @@ class OIDCAuditLog(BaseModel):
def __repr__(self):
"""String representation of OIDCAuditLog."""
status = "success" if self.success else "failed"
return f"<OIDCAuditLog event={self.event_type} status={status} client={self.client_id}>"
return (
f"<OIDCAuditLog event={self.event_type} "
f"status={status} client={self.client_id}>"
)
@classmethod
def log_event(cls, event_type, client_id=None, user_id=None, success=True,
error_code=None, error_description=None, ip_address=None,
user_agent=None, request_id=None, event_metadata=None):
def log_event(
cls,
event_type: str,
client_id: str = None,
user_id: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None,
ip_address: str = None,
user_agent: str = None,
request_id: str = None,
event_metadata: dict = None,
) -> "OIDCAuditLog":
"""Log an OIDC event.
Args:
event_type: Type of event (e.g., "authorization_request", "token_issue")
event_type: Type of event (e.g., "authorization_request")
client_id: The OIDC client ID
user_id: The user ID
success: Whether the event was successful
@@ -86,9 +97,19 @@ class OIDCAuditLog(BaseModel):
return log
@classmethod
def log_authorization_request(cls, client_id, user_id, redirect_uri, scope,
ip_address=None, user_agent=None, request_id=None,
success=True, error_code=None, error_description=None):
def log_authorization_request(
cls,
client_id: str,
user_id: str,
redirect_uri: str,
scope,
ip_address: str = None,
user_agent: str = None,
request_id: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None,
) -> "OIDCAuditLog":
"""Log an authorization request event."""
return cls.log_event(
event_type="authorization_request",
@@ -100,15 +121,19 @@ class OIDCAuditLog(BaseModel):
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
event_metadata={
"redirect_uri": redirect_uri,
"scope": scope,
}
event_metadata={"redirect_uri": redirect_uri, "scope": scope},
)
@classmethod
def log_token_issue(cls, client_id, user_id, token_type,
ip_address=None, user_agent=None, request_id=None):
def log_token_issue(
cls,
client_id: str,
user_id: str,
token_type: str,
ip_address: str = None,
user_agent: str = None,
request_id: str = None,
) -> "OIDCAuditLog":
"""Log a token issuance event."""
return cls.log_event(
event_type="token_issue",
@@ -118,12 +143,20 @@ class OIDCAuditLog(BaseModel):
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
event_metadata={"token_type": token_type}
event_metadata={"token_type": token_type},
)
@classmethod
def log_token_revocation(cls, client_id, user_id, token_type, reason=None,
ip_address=None, user_agent=None, request_id=None):
def log_token_revocation(
cls,
client_id: str,
user_id: str,
token_type: str,
reason: str = None,
ip_address: str = None,
user_agent: str = None,
request_id: str = None,
) -> "OIDCAuditLog":
"""Log a token revocation event."""
return cls.log_event(
event_type="token_revocation",
@@ -133,15 +166,19 @@ class OIDCAuditLog(BaseModel):
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
event_metadata={
"token_type": token_type,
"reason": reason,
}
event_metadata={"token_type": token_type, "reason": reason},
)
@classmethod
def log_authentication_failure(cls, client_id, error_code, error_description,
ip_address=None, user_agent=None, request_id=None):
def log_authentication_failure(
cls,
client_id: str,
error_code: str,
error_description: str,
ip_address: str = None,
user_agent: str = None,
request_id: str = None,
) -> "OIDCAuditLog":
"""Log an authentication failure event."""
return cls.log_event(
event_type="authentication_failure",
@@ -155,7 +192,7 @@ class OIDCAuditLog(BaseModel):
)
@classmethod
def get_events_for_user(cls, user_id, limit=100):
def get_events_for_user(cls, user_id: str, limit: int = 100) -> list:
"""Get audit events for a user.
Args:
@@ -165,13 +202,15 @@ class OIDCAuditLog(BaseModel):
Returns:
List of OIDCAuditLog instances
"""
return cls.query.filter_by(user_id=user_id, deleted_at=None)\
.order_by(cls.created_at.desc())\
.limit(limit)\
return (
cls.query.filter_by(user_id=user_id, deleted_at=None)
.order_by(cls.created_at.desc())
.limit(limit)
.all()
)
@classmethod
def get_events_for_client(cls, client_id, limit=100):
def get_events_for_client(cls, client_id: str, limit: int = 100) -> list:
"""Get audit events for a client.
Args:
@@ -181,14 +220,22 @@ class OIDCAuditLog(BaseModel):
Returns:
List of OIDCAuditLog instances
"""
return cls.query.filter_by(client_id=client_id, deleted_at=None)\
.order_by(cls.created_at.desc())\
.limit(limit)\
return (
cls.query.filter_by(client_id=client_id, deleted_at=None)
.order_by(cls.created_at.desc())
.limit(limit)
.all()
)
@classmethod
def get_failed_events(cls, client_id=None, user_id=None, start_date=None,
end_date=None, limit=100):
def get_failed_events(
cls,
client_id: str = None,
user_id: str = None,
start_date=None,
end_date=None,
limit: int = 100,
) -> list:
"""Get failed audit events.
Args:
@@ -210,22 +257,8 @@ class OIDCAuditLog(BaseModel):
query = query.filter(cls.created_at >= start_date)
if end_date:
query = query.filter(cls.created_at <= end_date)
return query.order_by(cls.created_at.desc()).limit(limit).all()
def to_dict(self, exclude=None):
"""Convert to dictionary."""
return super().to_dict(exclude=exclude)
# Add relationship back to User model
from gatehouse_app.models.user import User
User.oidc_audit_logs = db.relationship(
"OIDCAuditLog", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from gatehouse_app.models.oidc_client import OIDCClient
OIDCClient.audit_logs = db.relationship(
"OIDCAuditLog", back_populates="client", cascade="all, delete-orphan"
)
@@ -1,14 +1,14 @@
"""OIDC Authorization Code model for auth code flow."""
"""OIDC Authorization Code model for the authorization code grant flow."""
from datetime import datetime, timedelta, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OIDCAuthCode(BaseModel):
"""OIDC Authorization Code model for authorization code flow.
"""OIDC Authorization Code model for the authorization code grant flow.
Authorization codes are single-use, short-lived codes used in the
authorization code grant flow. The code is hashed for security.
Authorization codes are single-use, short-lived codes. The code itself is
hashed before storage so that a database breach cannot replay codes.
"""
__tablename__ = "oidc_authorization_codes"
@@ -26,9 +26,9 @@ class OIDCAuthCode(BaseModel):
# Request parameters
redirect_uri = db.Column(db.String(512), nullable=False)
scope = db.Column(db.JSON, nullable=True) # Requested scopes
nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation
code_verifier = db.Column(db.String(255), nullable=True) # For PKCE
scope = db.Column(db.JSON, nullable=True)
nonce = db.Column(db.String(255), nullable=True)
code_verifier = db.Column(db.String(255), nullable=True)
# Status tracking
expires_at = db.Column(db.DateTime, nullable=False, index=True)
@@ -39,37 +39,48 @@ class OIDCAuthCode(BaseModel):
ip_address = db.Column(db.String(45), nullable=True)
user_agent = db.Column(db.Text, nullable=True)
# Relationships
# Relationships — back_populates declared on User and OIDCClient
client = db.relationship("OIDCClient", back_populates="authorization_codes")
user = db.relationship("User", back_populates="oidc_auth_codes")
def __repr__(self):
"""String representation of OIDCAuthCode."""
return f"<OIDCAuthCode client_id={self.client_id} user_id={self.user_id} used={self.is_used}>"
return (
f"<OIDCAuthCode client_id={self.client_id} "
f"user_id={self.user_id} used={self.is_used}>"
)
def is_expired(self):
def is_expired(self) -> bool:
"""Check if the authorization code has expired."""
# Handle both timezone-aware and timezone-naive expires_at values
expires_at = self.expires_at
if expires_at.tzinfo is None:
# Make naive datetime timezone-aware (UTC)
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
def is_valid(self):
def is_valid(self) -> bool:
"""Check if the authorization code is valid for use."""
return not self.is_used and not self.is_expired() and self.deleted_at is None
def mark_as_used(self):
def mark_as_used(self) -> None:
"""Mark the authorization code as used."""
self.is_used = True
self.used_at = datetime.now(timezone.utc)
db.session.commit()
@classmethod
def create_code(cls, client_id, user_id, code_hash, redirect_uri, scope=None,
nonce=None, code_verifier=None, ip_address=None, user_agent=None,
lifetime_seconds=600):
def create_code(
cls,
client_id: str,
user_id: str,
code_hash: str,
redirect_uri: str,
scope=None,
nonce: str = None,
code_verifier: str = None,
ip_address: str = None,
user_agent: str = None,
lifetime_seconds: int = 600,
) -> "OIDCAuthCode":
"""Create a new authorization code.
Args:
@@ -79,7 +90,7 @@ class OIDCAuthCode(BaseModel):
redirect_uri: The redirect URI
scope: Requested scopes
nonce: OIDC nonce
code_verifier: PKCE code verifier
code_verifier: PKCE code verifier (stored hashed server-side)
ip_address: Client IP address
user_agent: Client user agent
lifetime_seconds: Code lifetime in seconds (default 10 minutes)
@@ -106,20 +117,7 @@ class OIDCAuthCode(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Always exclude code hash
exclude.append("code_hash")
exclude.append("code_verifier")
for field in ("code_hash", "code_verifier"):
if field not in exclude:
exclude.append(field)
return super().to_dict(exclude=exclude)
# Add relationship back to User model
from gatehouse_app.models.user import User
User.oidc_auth_codes = db.relationship(
"OIDCAuthCode", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from gatehouse_app.models.oidc_client import OIDCClient
OIDCClient.authorization_codes = db.relationship(
"OIDCAuthCode", back_populates="client", cascade="all, delete-orphan"
)
@@ -17,10 +17,10 @@ class OIDCClient(BaseModel):
client_secret_hash = db.Column(db.String(255), nullable=False)
# OAuth/OIDC configuration
redirect_uris = db.Column(db.JSON, nullable=False) # List of allowed redirect URIs
grant_types = db.Column(db.JSON, nullable=False) # List of allowed grant types
response_types = db.Column(db.JSON, nullable=False) # List of allowed response types
scopes = db.Column(db.JSON, nullable=False) # List of allowed scopes
redirect_uris = db.Column(db.JSON, nullable=False) # Allowed redirect URIs
grant_types = db.Column(db.JSON, nullable=False) # Allowed grant types
response_types = db.Column(db.JSON, nullable=False) # Allowed response types
scopes = db.Column(db.JSON, nullable=False) # Allowed scopes
# Client metadata
logo_uri = db.Column(db.String(512), nullable=True)
@@ -41,6 +41,23 @@ class OIDCClient(BaseModel):
# Relationships
organization = db.relationship("Organization", back_populates="oidc_clients")
# OIDC sub-resource relationships (declared here, not monkey-patched elsewhere)
authorization_codes = db.relationship(
"OIDCAuthCode", back_populates="client", cascade="all, delete-orphan"
)
refresh_tokens = db.relationship(
"OIDCRefreshToken", back_populates="client", cascade="all, delete-orphan"
)
oidc_sessions = db.relationship(
"OIDCSession", back_populates="client", cascade="all, delete-orphan"
)
token_metadata = db.relationship(
"OIDCTokenMetadata", back_populates="client", cascade="all, delete-orphan"
)
audit_logs = db.relationship(
"OIDCAuditLog", back_populates="client", cascade="all, delete-orphan"
)
def __repr__(self):
"""String representation of OIDCClient."""
return f"<OIDCClient {self.name} client_id={self.client_id}>"
@@ -48,22 +65,22 @@ class OIDCClient(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Always exclude client secret
exclude.append("client_secret_hash")
if "client_secret_hash" not in exclude:
exclude.append("client_secret_hash")
return super().to_dict(exclude=exclude)
def has_grant_type(self, grant_type):
def has_grant_type(self, grant_type) -> bool:
"""Check if client supports a specific grant type."""
return grant_type in self.grant_types
def has_response_type(self, response_type):
def has_response_type(self, response_type) -> bool:
"""Check if client supports a specific response type."""
return response_type in self.response_types
def is_redirect_uri_allowed(self, redirect_uri):
def is_redirect_uri_allowed(self, redirect_uri: str) -> bool:
"""Check if a redirect URI is allowed for this client."""
return redirect_uri in self.redirect_uris
def has_scope(self, scope):
def has_scope(self, scope: str) -> bool:
"""Check if client is allowed to request a specific scope."""
return scope in self.scopes
@@ -0,0 +1,76 @@
"""OIDC JWKS Key model for persisting signing keys."""
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OidcJwksKey(BaseModel):
"""OIDC JWKS Key model for persisting JSON Web Key Set signing keys.
Stores RSA/ECDSA key pairs used for signing OIDC tokens. Multiple keys can
be stored to support key rotation scenarios.
Attributes:
kid: Unique key ID used in JWT ``kid`` header
key_type: Type of key (e.g., "RSA", "EC")
private_key: PEM-encoded private key (never exposed in API responses)
public_key: PEM-encoded public key
algorithm: Signing algorithm (e.g., "RS256", "ES256")
is_active: Whether this key is currently used for signing/verification
is_primary: Whether this is the primary signing key
expires_at: Optional expiry for key rotation enforcement
"""
__tablename__ = "oidc_jwks_keys"
# Override the default UUID id with integer primary key for JWKS key sets
id = db.Column(db.Integer, primary_key=True)
expires_at = db.Column(db.DateTime, nullable=True)
# Key identification and type
kid = db.Column(db.String(255), unique=True, nullable=False, index=True)
key_type = db.Column(db.String(50), nullable=False) # e.g., "RSA", "EC"
algorithm = db.Column(db.String(50), nullable=False) # e.g., "RS256", "ES256"
# Key material (PEM-encoded) — private_key must never be returned by API
private_key = db.Column(db.Text, nullable=False)
public_key = db.Column(db.Text, nullable=False)
# Key status
is_active = db.Column(db.Boolean, default=True, nullable=False)
is_primary = db.Column(db.Boolean, default=False, nullable=False)
def __repr__(self):
"""String representation of OidcJwksKey."""
return (
f"<OidcJwksKey kid={self.kid} "
f"key_type={self.key_type} algorithm={self.algorithm}>"
)
def to_dict(self, exclude_private_key: bool = True):
"""Convert model to dictionary.
Args:
exclude_private_key: If True (default), excludes the private key.
Returns:
Dictionary representation of the model
"""
exclude = ["private_key"] if exclude_private_key else []
return super().to_dict(exclude=exclude)
@classmethod
def get_active_keys(cls) -> list:
"""Get all active keys for signing operations."""
return cls.query.filter_by(is_active=True).all()
@classmethod
def get_primary_key(cls) -> "OidcJwksKey | None":
"""Get the primary signing key."""
return cls.query.filter_by(is_primary=True).first()
@classmethod
def get_key_by_kid(cls, kid: str) -> "OidcJwksKey | None":
"""Get an active key by its key ID."""
return cls.query.filter_by(kid=kid, is_active=True).first()
@@ -1,5 +1,5 @@
"""OIDC Refresh Token model for token rotation."""
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
@@ -8,7 +8,8 @@ class OIDCRefreshToken(BaseModel):
"""OIDC Refresh Token model for token refresh and rotation.
Refresh tokens are long-lived credentials used to obtain new access tokens.
They support token rotation for enhanced security.
They support token rotation for enhanced security each use invalidates
the old token and issues a new one.
"""
__tablename__ = "oidc_refresh_tokens"
@@ -21,16 +22,14 @@ class OIDCRefreshToken(BaseModel):
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
)
# Token (hashed for security)
# Token (hashed for security — never store plaintext refresh tokens)
token_hash = db.Column(db.String(255), nullable=False, unique=True, index=True)
# Associated access token ID (stores JWT JTI string — no FK to sessions)
access_token_id = db.Column(
db.String(255), nullable=True, index=True
)
# Associated access token JTI (no FK — stored as string for lightweight lookup)
access_token_id = db.Column(db.String(255), nullable=True, index=True)
# Token scope
scope = db.Column(db.JSON, nullable=True) # Granted scopes
scope = db.Column(db.JSON, nullable=True)
# Timing
expires_at = db.Column(db.DateTime, nullable=False, index=True)
@@ -40,7 +39,7 @@ class OIDCRefreshToken(BaseModel):
revoked_reason = db.Column(db.String(255), nullable=True)
# Token rotation metadata
previous_token_hash = db.Column(db.String(255), nullable=True) # For rotation
previous_token_hash = db.Column(db.String(255), nullable=True)
rotation_count = db.Column(db.Integer, default=0, nullable=False)
# Request metadata
@@ -53,25 +52,27 @@ class OIDCRefreshToken(BaseModel):
def __repr__(self):
"""String representation of OIDCRefreshToken."""
return f"<OIDCRefreshToken client_id={self.client_id} user_id={self.user_id} revoked={self.is_revoked()}>"
return (
f"<OIDCRefreshToken client_id={self.client_id} "
f"user_id={self.user_id} revoked={self.is_revoked()}>"
)
def is_expired(self):
def is_expired(self) -> bool:
"""Check if the refresh token has expired."""
# Handle both timezone-aware and timezone-naive expires_at values
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
def is_revoked(self):
def is_revoked(self) -> bool:
"""Check if the refresh token has been revoked."""
return self.revoked_at is not None
def is_valid(self):
def is_valid(self) -> bool:
"""Check if the refresh token is valid for use."""
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
def revoke(self, reason=None):
def revoke(self, reason: str = None) -> None:
"""Revoke the refresh token.
Args:
@@ -81,8 +82,8 @@ class OIDCRefreshToken(BaseModel):
self.revoked_reason = reason
db.session.commit()
def rotate(self, new_token_hash):
"""Rotate the refresh token (invalidate old, create new).
def rotate(self, new_token_hash: str) -> "OIDCRefreshToken":
"""Rotate the refresh token invalidate the old hash, store the new one.
Args:
new_token_hash: Hash of the new refresh token
@@ -90,20 +91,25 @@ class OIDCRefreshToken(BaseModel):
Returns:
self for chaining
"""
# Store reference to old token
self.previous_token_hash = self.token_hash
self.token_hash = new_token_hash
self.rotation_count += 1
# Extend expiration on rotation
from datetime import timedelta
self.expires_at = datetime.now(timezone.utc) + timedelta(days=30)
db.session.commit()
return self
@classmethod
def create_token(cls, client_id, user_id, token_hash, scope=None,
access_token_id=None, ip_address=None, user_agent=None,
lifetime_seconds=2592000):
def create_token(
cls,
client_id: str,
user_id: str,
token_hash: str,
scope=None,
access_token_id: str = None,
ip_address: str = None,
user_agent: str = None,
lifetime_seconds: int = 2592000,
) -> "OIDCRefreshToken":
"""Create a new refresh token.
Args:
@@ -111,7 +117,7 @@ class OIDCRefreshToken(BaseModel):
user_id: The user ID
token_hash: Hashed refresh token
scope: Granted scopes
access_token_id: Associated access token ID
access_token_id: Associated access token JTI
ip_address: Client IP address
user_agent: Client user agent
lifetime_seconds: Token lifetime in seconds (default 30 days)
@@ -119,7 +125,6 @@ class OIDCRefreshToken(BaseModel):
Returns:
OIDCRefreshToken instance
"""
from datetime import timedelta
token = cls(
client_id=client_id,
user_id=user_id,
@@ -137,20 +142,7 @@ class OIDCRefreshToken(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Always exclude token hashes
exclude.append("token_hash")
exclude.append("previous_token_hash")
for field in ("token_hash", "previous_token_hash"):
if field not in exclude:
exclude.append(field)
return super().to_dict(exclude=exclude)
# Add relationship back to User model
from gatehouse_app.models.user import User
User.oidc_refresh_tokens = db.relationship(
"OIDCRefreshToken", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from gatehouse_app.models.oidc_client import OIDCClient
OIDCClient.refresh_tokens = db.relationship(
"OIDCRefreshToken", back_populates="client", cascade="all, delete-orphan"
)
@@ -1,5 +1,7 @@
"""OIDC Session model for OIDC session tracking."""
from datetime import datetime, timezone
import hashlib
import base64
from datetime import datetime, timedelta, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
@@ -7,8 +9,8 @@ from gatehouse_app.models.base import BaseModel
class OIDCSession(BaseModel):
"""OIDC Session model for tracking OIDC authentication sessions.
This model tracks the state during the OIDC authentication flow,
including PKCE parameters and nonce validation.
Tracks the state during the OIDC authorization flow, including PKCE
parameters and nonce validation.
"""
__tablename__ = "oidc_sessions"
@@ -25,11 +27,11 @@ class OIDCSession(BaseModel):
# State management
state = db.Column(db.String(255), nullable=False, index=True)
nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation
nonce = db.Column(db.String(255), nullable=True)
# Authorization request parameters
redirect_uri = db.Column(db.String(512), nullable=False)
scope = db.Column(db.JSON, nullable=True) # Requested scopes
scope = db.Column(db.JSON, nullable=True)
# PKCE parameters
code_challenge = db.Column(db.String(255), nullable=True)
@@ -45,50 +47,52 @@ class OIDCSession(BaseModel):
def __repr__(self):
"""String representation of OIDCSession."""
return f"<OIDCSession user_id={self.user_id} client_id={self.client_id} state={self.state[:8]}...>"
return (
f"<OIDCSession user_id={self.user_id} "
f"client_id={self.client_id} state={self.state[:8]}...>"
)
def is_expired(self):
def is_expired(self) -> bool:
"""Check if the OIDC session has expired."""
return datetime.now(timezone.utc) > self.expires_at
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
def is_authenticated(self):
def is_authenticated(self) -> bool:
"""Check if the user has been authenticated in this session."""
return self.authenticated_at is not None
def mark_authenticated(self):
def mark_authenticated(self) -> None:
"""Mark the session as authenticated."""
self.authenticated_at = datetime.now(timezone.utc)
db.session.commit()
def validate_nonce(self, expected_nonce):
def validate_nonce(self, expected_nonce: str) -> bool:
"""Validate the nonce matches the expected value.
Args:
expected_nonce: The expected nonce value
Returns:
bool: True if nonce matches
True if nonce matches
"""
return self.nonce == expected_nonce
def validate_code_challenge(self, code_verifier):
def validate_code_challenge(self, code_verifier: str) -> bool:
"""Validate the code verifier against the stored code challenge.
Args:
code_verifier: The PKCE code verifier
Returns:
bool: True if code challenge is valid
True if the challenge is satisfied
"""
if not self.code_challenge:
return False
if self.code_challenge_method == "S256":
import hashlib
import base64
# SHA256 hash of code_verifier
digest = hashlib.sha256(code_verifier.encode()).digest()
# Base64 URL encode without padding
expected = base64.urlsafe_b64encode(digest).decode().rstrip("=")
return self.code_challenge == expected
elif self.code_challenge_method == "plain":
@@ -97,9 +101,18 @@ class OIDCSession(BaseModel):
return False
@classmethod
def create_session(cls, user_id, client_id, state, redirect_uri, scope=None,
nonce=None, code_challenge=None, code_challenge_method=None,
lifetime_seconds=600):
def create_session(
cls,
user_id: str,
client_id: str,
state: str,
redirect_uri: str,
scope=None,
nonce: str = None,
code_challenge: str = None,
code_challenge_method: str = None,
lifetime_seconds: int = 600,
) -> "OIDCSession":
"""Create a new OIDC session.
Args:
@@ -116,7 +129,6 @@ class OIDCSession(BaseModel):
Returns:
OIDCSession instance
"""
from datetime import timedelta
session = cls(
user_id=user_id,
client_id=client_id,
@@ -133,7 +145,7 @@ class OIDCSession(BaseModel):
return session
@classmethod
def get_by_state(cls, state):
def get_by_state(cls, state: str) -> "OIDCSession | None":
"""Get a session by state parameter.
Args:
@@ -147,16 +159,3 @@ class OIDCSession(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary."""
return super().to_dict(exclude=exclude)
# Add relationship back to User model
from gatehouse_app.models.user import User
User.oidc_sessions = db.relationship(
"OIDCSession", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from gatehouse_app.models.oidc_client import OIDCClient
OIDCClient.oidc_sessions = db.relationship(
"OIDCSession", back_populates="client", cascade="all, delete-orphan"
)
@@ -8,13 +8,14 @@ from gatehouse_app.models.base import BaseModel
class OIDCTokenMetadata(BaseModel):
"""OIDC Token Metadata model for tracking issued tokens.
This model stores metadata about issued tokens (access tokens, refresh tokens, ID tokens)
for the purpose of token revocation. The id field matches the JTI (JWT ID) claim.
Stores metadata about issued tokens (access, refresh, ID) for revocation.
The ``id`` field on this model intentionally overrides the BaseModel UUID
to store the JWT JTI directly as the primary key for O(1) revocation checks.
"""
__tablename__ = "oidc_token_metadata"
# Token identifier (matches JTI in JWT)
# Primary key = JTI so revocation lookups are always a PK scan
id = db.Column(
db.String(36), primary_key=True, default=lambda: str(uuid.uuid4())
)
@@ -27,11 +28,11 @@ class OIDCTokenMetadata(BaseModel):
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
)
# Token type
token_type = db.Column(db.String(50), nullable=False) # "access_token", "refresh_token", "id_token"
# Token type: "access_token", "refresh_token", or "id_token"
token_type = db.Column(db.String(50), nullable=False)
# Token identifier for revocation lookup
token_jti = db.Column(db.String(255), nullable=False, index=True) # JWT ID claim
# JWT ID claim (indexed for fast lookup when id != jti)
token_jti = db.Column(db.String(255), nullable=False, index=True)
# Timing
expires_at = db.Column(db.DateTime, nullable=False, index=True)
@@ -46,25 +47,27 @@ class OIDCTokenMetadata(BaseModel):
def __repr__(self):
"""String representation of OIDCTokenMetadata."""
return f"<OIDCTokenMetadata jti={self.token_jti[:8]}... type={self.token_type} revoked={self.is_revoked()}>"
return (
f"<OIDCTokenMetadata jti={self.token_jti[:8]}... "
f"type={self.token_type} revoked={self.is_revoked()}>"
)
def is_expired(self):
def is_expired(self) -> bool:
"""Check if the token has expired."""
# Handle both timezone-aware and timezone-naive expires_at values
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
def is_revoked(self):
def is_revoked(self) -> bool:
"""Check if the token has been revoked."""
return self.revoked_at is not None
def is_valid(self):
def is_valid(self) -> bool:
"""Check if the token is valid (not expired and not revoked)."""
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
def revoke(self, reason=None):
def revoke(self, reason: str = None) -> None:
"""Revoke the token.
Args:
@@ -75,8 +78,16 @@ class OIDCTokenMetadata(BaseModel):
db.session.commit()
@classmethod
def create_metadata(cls, client_id, user_id, token_type, token_jti,
expires_at, ip_address=None, user_agent=None):
def create_metadata(
cls,
client_id: str,
user_id: str,
token_type: str,
token_jti: str,
expires_at,
ip_address: str = None,
user_agent: str = None,
) -> "OIDCTokenMetadata":
"""Create token metadata for tracking.
Args:
@@ -85,8 +96,8 @@ class OIDCTokenMetadata(BaseModel):
token_type: Type of token ("access_token", "refresh_token", "id_token")
token_jti: JWT ID claim
expires_at: Token expiration datetime
ip_address: Client IP address
user_agent: Client user agent
ip_address: Client IP address (unused column, kept for API compat)
user_agent: Client user agent (unused column, kept for API compat)
Returns:
OIDCTokenMetadata instance
@@ -104,7 +115,7 @@ class OIDCTokenMetadata(BaseModel):
return metadata
@classmethod
def get_by_jti(cls, token_jti):
def get_by_jti(cls, token_jti: str) -> "OIDCTokenMetadata | None":
"""Get token metadata by JWT ID.
Args:
@@ -116,7 +127,7 @@ class OIDCTokenMetadata(BaseModel):
return cls.query.filter_by(token_jti=token_jti, deleted_at=None).first()
@classmethod
def revoke_by_jti(cls, token_jti, reason=None):
def revoke_by_jti(cls, token_jti: str, reason: str = None) -> bool:
"""Revoke a token by its JWT ID.
Args:
@@ -124,7 +135,7 @@ class OIDCTokenMetadata(BaseModel):
reason: Optional revocation reason
Returns:
bool: True if token was found and revoked
True if token was found and revoked, False otherwise
"""
metadata = cls.get_by_jti(token_jti)
if metadata:
@@ -133,47 +144,53 @@ class OIDCTokenMetadata(BaseModel):
return False
@classmethod
def revoke_all_for_user(cls, user_id, client_id=None, reason=None):
def revoke_all_for_user(
cls, user_id: str, client_id: str = None, reason: str = None
) -> int:
"""Revoke all tokens for a user.
Args:
user_id: The user ID
client_id: Optional client ID to filter by
client_id: Optional client ID filter
reason: Optional revocation reason
Returns:
int: Number of tokens revoked
Number of tokens revoked
"""
query = cls.query.filter_by(user_id=user_id, deleted_at=None)
query = cls.query.filter_by(user_id=user_id, deleted_at=None).filter(
cls.revoked_at.is_(None)
)
if client_id:
query = query.filter_by(client_id=client_id)
tokens = query.filter(cls.revoked_at == None).all()
count = 0
for token in tokens:
for token in query.all():
token.revoke(reason)
count += 1
return count
@classmethod
def revoke_all_for_client(cls, client_id, user_id=None, reason=None):
def revoke_all_for_client(
cls, client_id: str, user_id: str = None, reason: str = None
) -> int:
"""Revoke all tokens for a client.
Args:
client_id: The client ID
user_id: Optional user ID to filter by
user_id: Optional user ID filter
reason: Optional revocation reason
Returns:
int: Number of tokens revoked
Number of tokens revoked
"""
query = cls.query.filter_by(client_id=client_id, deleted_at=None)
query = cls.query.filter_by(client_id=client_id, deleted_at=None).filter(
cls.revoked_at.is_(None)
)
if user_id:
query = query.filter_by(user_id=user_id)
tokens = query.filter(cls.revoked_at == None).all()
count = 0
for token in tokens:
for token in query.all():
token.revoke(reason)
count += 1
return count
@@ -181,16 +198,3 @@ class OIDCTokenMetadata(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary."""
return super().to_dict(exclude=exclude)
# Add relationship back to User model
from gatehouse_app.models.user import User
User.oidc_token_metadata = db.relationship(
"OIDCTokenMetadata", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from gatehouse_app.models.oidc_client import OIDCClient
OIDCClient.token_metadata = db.relationship(
"OIDCTokenMetadata", back_populates="client", cascade="all, delete-orphan"
)
-77
View File
@@ -1,77 +0,0 @@
"""OIDC JWKS Key model for persisting signing keys."""
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OidcJwksKey(BaseModel):
"""
OIDC JWKS Key model for persisting JSON Web Key Set signing keys.
This model stores RSA/ECDSA key pairs used for signing OIDC tokens.
Multiple keys can be stored to support key rotation scenarios.
Attributes:
id: Integer primary key
kid: Unique key ID used in JWT "kid" header
key_type: Type of key (e.g., "RSA", "EC")
private_key: PEM-encoded private key
public_key: PEM-encoded public key
algorithm: Signing algorithm (e.g., "RS256", "ES256")
created_at: When the key was created
is_active: Whether this key is currently active for signing
is_primary: Whether this is the primary signing key
expires_at: ...
"""
__tablename__ = "oidc_jwks_keys"
# Override the default UUID id with integer primary key
id = db.Column(db.Integer, primary_key=True)
expires_at = db.Column(db.DateTime, nullable=True)
# Key identification and type
kid = db.Column(db.String(255), unique=True, nullable=False, index=True)
key_type = db.Column(db.String(50), nullable=False) # e.g., "RSA", "EC"
algorithm = db.Column(db.String(50), nullable=False) # e.g., "RS256", "ES256"
# Key material (PEM-encoded)
private_key = db.Column(db.Text, nullable=False)
public_key = db.Column(db.Text, nullable=False)
# Key status
is_active = db.Column(db.Boolean, default=True, nullable=False)
is_primary = db.Column(db.Boolean, default=False, nullable=False)
def __repr__(self):
"""String representation of OidcJwksKey."""
return f"<OidcJwksKey kid={self.kid} key_type={self.key_type} algorithm={self.algorithm}>"
def to_dict(self, exclude_private_key=True):
"""
Convert model to dictionary.
Args:
exclude_private_key: If True, excludes the private key from output
Returns:
Dictionary representation of the model
"""
exclude = ["private_key"] if exclude_private_key else []
return super().to_dict(exclude=exclude)
@classmethod
def get_active_keys(cls):
"""Get all active keys for signing operations."""
return cls.query.filter(cls.is_active == True).all()
@classmethod
def get_primary_key(cls):
"""Get the primary signing key."""
return cls.query.filter(cls.is_primary == True).first()
@classmethod
def get_key_by_kid(cls, kid):
"""Get a key by its key ID."""
return cls.query.filter(cls.kid == kid, cls.is_active == True).first()
@@ -0,0 +1,27 @@
"""Organization subpackage."""
from gatehouse_app.models.organization.organization import Organization
from gatehouse_app.models.organization.organization_member import OrganizationMember
from gatehouse_app.models.organization.department import (
Department,
DepartmentMembership,
DepartmentPrincipal,
)
from gatehouse_app.models.organization.department_cert_policy import (
DepartmentCertPolicy,
STANDARD_EXTENSIONS,
)
from gatehouse_app.models.organization.principal import Principal, PrincipalMembership
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
__all__ = [
"Organization",
"OrganizationMember",
"Department",
"DepartmentMembership",
"DepartmentPrincipal",
"DepartmentCertPolicy",
"STANDARD_EXTENSIONS",
"Principal",
"PrincipalMembership",
"OrgInviteToken",
]
@@ -1,14 +1,15 @@
"""Department, DepartmentMembership, and DepartmentPrincipal models."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class Department(BaseModel):
"""Department model representing an organizational unit for SSH access control.
Departments are used to group users and assign SSH principals (access levels)
to them. A user can be a member of multiple departments, and each department
can have multiple principals assigned.
Example:
- Department: "Engineering"
- Members: user1@example.com, user2@example.com
@@ -39,12 +40,15 @@ class Department(BaseModel):
back_populates="department",
cascade="all, delete-orphan",
)
cert_policy = db.relationship(
"DepartmentCertPolicy",
back_populates="department",
uselist=False,
cascade="all, delete-orphan",
)
# Unique constraint: department name per organization
__table_args__ = (
db.UniqueConstraint(
"organization_id", "name", name="uix_org_dept_name"
),
db.UniqueConstraint("organization_id", "name", name="uix_org_dept_name"),
)
def __repr__(self):
@@ -55,47 +59,46 @@ class Department(BaseModel):
"""Convert department to dictionary."""
exclude = exclude or []
data = super().to_dict(exclude=exclude)
# Add member count
data["member_count"] = len([m for m in self.memberships if m.deleted_at is None])
# Add principal count
data["principal_count"] = len([p for p in self.principal_links if p.deleted_at is None])
return data
def get_members(self, active_only=True):
def get_members(self, active_only: bool = True):
"""Get all members of this department.
Args:
active_only: If True, exclude soft-deleted members
Returns:
List of DepartmentMembership objects
"""
if active_only:
return [m for m in self.memberships if m.deleted_at is None]
return self.memberships
return list(self.memberships)
def get_principals(self, active_only=True):
def get_principals(self, active_only: bool = True):
"""Get all principals assigned to this department.
Args:
active_only: If True, exclude soft-deleted principals
Returns:
List of Principal objects via DepartmentPrincipal
"""
if active_only:
return [p.principal for p in self.principal_links if p.deleted_at is None and p.principal.deleted_at is None]
return [
p.principal
for p in self.principal_links
if p.deleted_at is None and p.principal.deleted_at is None
]
return [p.principal for p in self.principal_links]
def is_member(self, user_id):
def is_member(self, user_id: str) -> bool:
"""Check if a user is a member of this department.
Args:
user_id: ID of the user to check
Returns:
True if user is an active member, False otherwise
"""
@@ -108,14 +111,14 @@ class Department(BaseModel):
is not None
)
def get_member_count(self):
def get_member_count(self) -> int:
"""Get the count of active members in this department."""
return len(self.get_members(active_only=True))
class DepartmentMembership(BaseModel):
"""Department membership model representing user membership in a department.
When a user is added to a department, they become eligible for SSH principals
assigned to that department.
"""
@@ -139,24 +142,23 @@ class DepartmentMembership(BaseModel):
user = db.relationship("User", back_populates="department_memberships")
department = db.relationship("Department", back_populates="memberships")
# Unique constraint: user can only be member of a department once
__table_args__ = (
db.UniqueConstraint(
"user_id", "department_id", name="uix_user_dept"
),
db.UniqueConstraint("user_id", "department_id", name="uix_user_dept"),
)
def __repr__(self):
"""String representation of DepartmentMembership."""
return f"<DepartmentMembership user_id={self.user_id} dept_id={self.department_id}>"
return (
f"<DepartmentMembership user_id={self.user_id} dept_id={self.department_id}>"
)
class DepartmentPrincipal(BaseModel):
"""Department principal assignment model.
Represents the assignment of principals to departments. All members of a
department get access to its assigned principals (transitively).
Example:
- Department: "Engineering"
- Principal: "eng-prod-servers"
@@ -182,13 +184,13 @@ class DepartmentPrincipal(BaseModel):
department = db.relationship("Department", back_populates="principal_links")
principal = db.relationship("Principal", back_populates="department_links")
# Unique constraint: principal can only be assigned to a department once
__table_args__ = (
db.UniqueConstraint(
"department_id", "principal_id", name="uix_dept_principal"
),
db.UniqueConstraint("department_id", "principal_id", name="uix_dept_principal"),
)
def __repr__(self):
"""String representation of DepartmentPrincipal."""
return f"<DepartmentPrincipal dept_id={self.department_id} principal_id={self.principal_id}>"
return (
f"<DepartmentPrincipal dept_id={self.department_id} "
f"principal_id={self.principal_id}>"
)
@@ -0,0 +1,76 @@
"""DepartmentCertPolicy — per-department SSH certificate issuance rules."""
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
# Standard SSH certificate extensions
STANDARD_EXTENSIONS = [
"permit-X11-forwarding",
"permit-agent-forwarding",
"permit-pty",
"permit-port-forwarding",
"permit-user-rc",
]
class DepartmentCertPolicy(BaseModel):
"""SSH certificate policy for a department.
Controls:
- Whether members may choose their own expiry date (up to ``max_expiry_hours``)
- Default expiry hours when the user doesn't (or can't) pick
- Maximum expiry hours (hard ceiling, even for admins signing on behalf)
- Which SSH certificate extensions are granted to members of this department
- Any custom extensions the admin wants to add beyond the standard five
Inherits ``id``, ``created_at``, ``updated_at``, and ``deleted_at`` from
:class:`BaseModel` so soft-delete and the standard timestamp behaviour are
consistent with every other model in the application.
"""
__tablename__ = "department_cert_policies"
department_id = db.Column(
db.String(36),
db.ForeignKey("departments.id"),
nullable=False,
unique=True,
index=True,
)
# Expiry control
allow_user_expiry = db.Column(db.Boolean, nullable=False, default=False)
default_expiry_hours = db.Column(db.Integer, nullable=False, default=1)
max_expiry_hours = db.Column(db.Integer, nullable=False, default=24)
# Extensions — list of extension name strings
allowed_extensions = db.Column(
db.JSON,
nullable=False,
default=lambda: list(STANDARD_EXTENSIONS),
)
# Admin-defined extras beyond the standard five
custom_extensions = db.Column(db.JSON, nullable=False, default=list)
# Relationship back to department
department = db.relationship("Department", back_populates="cert_policy", uselist=False)
def __repr__(self):
return (
f"<DepartmentCertPolicy dept={self.department_id} "
f"allow_user_expiry={self.allow_user_expiry}>"
)
def all_extensions(self) -> list:
"""Return the full list of enabled extensions (allowed + custom)."""
return list((self.allowed_extensions or []) + (self.custom_extensions or []))
def to_dict(self, exclude=None):
"""Convert to dictionary."""
exclude = exclude or []
data = super().to_dict(exclude=exclude)
# Augment with computed / convenience fields not in the base columns
data["all_extensions"] = self.all_extensions()
data["standard_extensions"] = STANDARD_EXTENSIONS
return data
@@ -0,0 +1,77 @@
"""Organization invite token model."""
import secrets
from datetime import datetime, timezone, timedelta
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OrgInviteToken(BaseModel):
"""Token-based invitation to join an organization."""
__tablename__ = "org_invite_tokens"
organization_id = db.Column(
db.String(36),
db.ForeignKey("organizations.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
invited_by_id = db.Column(
db.String(36),
db.ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
)
email = db.Column(db.String(255), nullable=False, index=True)
role = db.Column(db.String(64), nullable=False, default="member")
token = db.Column(db.String(128), unique=True, nullable=False, index=True)
expires_at = db.Column(db.DateTime, nullable=False)
accepted_at = db.Column(db.DateTime, nullable=True)
organization = db.relationship(
"Organization",
backref=db.backref("invite_tokens", cascade="all, delete-orphan"),
)
invited_by = db.relationship("User", foreign_keys=[invited_by_id])
@classmethod
def generate(
cls,
organization_id: str,
email: str,
role: str = "member",
invited_by_id: str = None,
ttl_days: int = 7,
) -> "OrgInviteToken":
"""Create a new invite token for an organization."""
token_value = secrets.token_urlsafe(48)
instance = cls(
organization_id=organization_id,
email=email.lower(),
role=role,
invited_by_id=invited_by_id,
token=token_value,
expires_at=datetime.now(timezone.utc) + timedelta(days=ttl_days),
)
db.session.add(instance)
db.session.commit()
return instance
@property
def is_valid(self) -> bool:
"""Return True if the token is unused and not expired."""
if self.accepted_at is not None:
return False
now = datetime.now(timezone.utc)
expires = self.expires_at
if expires.tzinfo is None:
expires = expires.replace(tzinfo=timezone.utc)
return now < expires
def accept(self) -> None:
"""Mark the invite as accepted."""
self.accepted_at = datetime.now(timezone.utc)
db.session.commit()
def __repr__(self) -> str:
return f"<OrgInviteToken org={self.organization_id} email={self.email}>"
@@ -61,9 +61,9 @@ class Organization(BaseModel):
return member.user
return None
def is_member(self, user_id):
def is_member(self, user_id: str) -> bool:
"""Check if a user is a member of the organization."""
from gatehouse_app.models.organization_member import OrganizationMember
from gatehouse_app.models.organization.organization_member import OrganizationMember
return (
OrganizationMember.query.filter_by(
@@ -1,4 +1,4 @@
"""Organization member model."""
"""Organization member model."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import OrganizationRole
@@ -21,31 +21,35 @@ class OrganizationMember(BaseModel):
joined_at = db.Column(db.DateTime, nullable=True)
# Relationships
user = db.relationship("User", foreign_keys=[user_id], back_populates="organization_memberships")
user = db.relationship(
"User", foreign_keys=[user_id], back_populates="organization_memberships"
)
organization = db.relationship("Organization", back_populates="members")
invited_by = db.relationship("User", foreign_keys=[invited_by_id])
# Unique constraint to prevent duplicate memberships
__table_args__ = (
db.UniqueConstraint("user_id", "organization_id", name="uix_user_org"),
)
def __repr__(self):
"""String representation of OrganizationMember."""
return f"<OrganizationMember user_id={self.user_id} org_id={self.organization_id} role={self.role}>"
return (
f"<OrganizationMember user_id={self.user_id} "
f"org_id={self.organization_id} role={self.role}>"
)
def is_owner(self):
def is_owner(self) -> bool:
"""Check if member is an owner."""
return self.role == OrganizationRole.OWNER
def is_admin(self):
def is_admin(self) -> bool:
"""Check if member is an admin or owner."""
return self.role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
def can_manage_members(self):
def can_manage_members(self) -> bool:
"""Check if member can manage other members."""
return self.is_admin()
def can_delete_organization(self):
def can_delete_organization(self) -> bool:
"""Check if member can delete the organization."""
return self.is_owner()
@@ -1,14 +1,16 @@
"""Principal and PrincipalMembership models."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class Principal(BaseModel):
"""Principal model representing an SSH principal (access level/role).
In SSH CA terminology, a principal is a string like "eng-prod-servers" or
"devops-admins" that represents a set of machines or access level. Users
can be granted access to principals, either directly or via department membership.
can be granted access to principals, either directly or via department
membership.
Example:
- Principal: "eng-prod-servers"
- Users with this principal can SSH to prod servers
@@ -39,11 +41,8 @@ class Principal(BaseModel):
cascade="all, delete-orphan",
)
# Unique constraint: principal name per organization
__table_args__ = (
db.UniqueConstraint(
"organization_id", "name", name="uix_org_principal_name"
),
db.UniqueConstraint("organization_id", "name", name="uix_org_principal_name"),
)
def __repr__(self):
@@ -54,79 +53,80 @@ class Principal(BaseModel):
"""Convert principal to dictionary."""
exclude = exclude or []
data = super().to_dict(exclude=exclude)
# Add member count
data["direct_member_count"] = len([m for m in self.memberships if m.deleted_at is None])
# Add department count
data["department_count"] = len([d for d in self.department_links if d.deleted_at is None])
data["direct_member_count"] = len(
[m for m in self.memberships if m.deleted_at is None]
)
data["department_count"] = len(
[d for d in self.department_links if d.deleted_at is None]
)
return data
def get_members(self, active_only=True):
def get_members(self, active_only: bool = True):
"""Get all users who are directly assigned to this principal.
Does NOT include users who get access via department membership.
Args:
active_only: If True, exclude soft-deleted members
Returns:
List of PrincipalMembership objects
"""
if active_only:
return [m for m in self.memberships if m.deleted_at is None]
return self.memberships
return list(self.memberships)
def get_all_members(self, active_only=True):
def get_all_members(self, active_only: bool = True):
"""Get all users who have access to this principal.
Includes both direct members and users via department membership.
Args:
active_only: If True, exclude soft-deleted members
Returns:
Set of User objects with access to this principal
"""
from gatehouse_app.models.user import User
all_users = set()
# Add direct members
all_users: set = set()
# Direct members
for membership in self.get_members(active_only=active_only):
if membership.user.deleted_at is None or not active_only:
if not active_only or membership.user.deleted_at is None:
all_users.add(membership.user)
# Add members via department assignment
# Members via department assignment
for dept_link in self.department_links:
if dept_link.deleted_at is None or not active_only:
for dept_member in dept_link.department.get_members(active_only=active_only):
if dept_member.user.deleted_at is None or not active_only:
if not active_only or dept_member.user.deleted_at is None:
all_users.add(dept_member.user)
return all_users
def get_departments(self, active_only=True):
def get_departments(self, active_only: bool = True):
"""Get all departments this principal is assigned to.
Args:
active_only: If True, exclude soft-deleted departments
Returns:
List of Department objects
"""
if active_only:
return [d.department for d in self.department_links if d.deleted_at is None and d.department.deleted_at is None]
return [
d.department
for d in self.department_links
if d.deleted_at is None and d.department.deleted_at is None
]
return [d.department for d in self.department_links]
def is_member(self, user_id, include_via_department=True):
def is_member(self, user_id: str, include_via_department: bool = True) -> bool:
"""Check if a user has access to this principal.
Args:
user_id: ID of the user to check
include_via_department: If True, check department memberships too
Returns:
True if user has access to this principal
"""
@@ -139,54 +139,49 @@ class Principal(BaseModel):
).first()
is not None
)
if has_direct:
return True
# Check department membership if requested
if not include_via_department:
return False
# Get all departments this principal is assigned to
depts = self.get_departments(active_only=True)
dept_ids = [d.id for d in depts]
# Check department membership
dept_ids = [d.id for d in self.get_departments(active_only=True)]
if not dept_ids:
return False
# Check if user is a member of any of these departments
from gatehouse_app.models.department import DepartmentMembership
from gatehouse_app.models.organization.department import DepartmentMembership
return (
DepartmentMembership.query.filter(
DepartmentMembership.user_id == user_id,
DepartmentMembership.department_id.in_(dept_ids),
DepartmentMembership.deleted_at == None,
DepartmentMembership.deleted_at.is_(None),
).first()
is not None
)
def get_member_count(self, include_via_department=True):
def get_member_count(self, include_via_department: bool = True) -> int:
"""Get the count of active members with access to this principal.
Args:
include_via_department: If True, include members via department
Returns:
Count of members
"""
if not include_via_department:
return len(self.get_members(active_only=True))
return len(self.get_all_members(active_only=True))
class PrincipalMembership(BaseModel):
"""Principal membership model representing direct user assignment to a principal.
When a user is assigned directly to a principal, they get access to that principal
for SSH authentication. This is in addition to any principals they get via
department membership.
When a user is assigned directly to a principal, they get access to that
principal for SSH authentication. This is in addition to any principals
they get via department membership.
"""
__tablename__ = "principal_memberships"
@@ -208,13 +203,13 @@ class PrincipalMembership(BaseModel):
user = db.relationship("User", back_populates="principal_memberships")
principal = db.relationship("Principal", back_populates="memberships")
# Unique constraint: user can only be member of a principal once
__table_args__ = (
db.UniqueConstraint(
"user_id", "principal_id", name="uix_user_principal"
),
db.UniqueConstraint("user_id", "principal_id", name="uix_user_principal"),
)
def __repr__(self):
"""String representation of PrincipalMembership."""
return f"<PrincipalMembership user_id={self.user_id} principal_id={self.principal_id}>"
return (
f"<PrincipalMembership user_id={self.user_id} "
f"principal_id={self.principal_id}>"
)
+12
View File
@@ -0,0 +1,12 @@
"""Security subpackage — organization and user security policies, MFA compliance."""
from gatehouse_app.models.security.organization_security_policy import (
OrganizationSecurityPolicy,
)
from gatehouse_app.models.security.user_security_policy import UserSecurityPolicy
from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance
__all__ = [
"OrganizationSecurityPolicy",
"UserSecurityPolicy",
"MfaPolicyCompliance",
]
@@ -1,4 +1,4 @@
"""MfaPolicyCompliance model."""
"""MfaPolicyCompliance model — per-user per-organization MFA compliance tracking."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import MfaComplianceStatus
@@ -7,7 +7,8 @@ from gatehouse_app.utils.constants import MfaComplianceStatus
class MfaPolicyCompliance(BaseModel):
"""MFA policy compliance tracking per user per organization.
Tracks each user's MFA compliance state separately for each organization membership.
Tracks each user's MFA compliance state separately for each organization
membership. One row per (user, org) pair.
"""
__tablename__ = "mfa_policy_compliance"
@@ -25,13 +26,13 @@ class MfaPolicyCompliance(BaseModel):
default=MfaComplianceStatus.NOT_APPLICABLE,
)
# Snapshot of org policy at the time this record became active
# Snapshot of org policy version when this record became active
policy_version = db.Column(db.Integer, nullable=False)
# When policy started applying to this user
applied_at = db.Column(db.DateTime, nullable=True)
# Final deadline for this user to comply (per user, not global)
# Final deadline for this user to comply
deadline_at = db.Column(db.DateTime, nullable=True)
# When they became compliant under this policy_version
@@ -45,9 +46,7 @@ class MfaPolicyCompliance(BaseModel):
notification_count = db.Column(db.Integer, nullable=False, default=0)
__table_args__ = (
db.UniqueConstraint(
"user_id", "organization_id", name="uix_user_org_compliance"
),
db.UniqueConstraint("user_id", "organization_id", name="uix_user_org_compliance"),
)
# Relationships
@@ -58,9 +57,11 @@ class MfaPolicyCompliance(BaseModel):
def __repr__(self):
"""String representation of MfaPolicyCompliance."""
return f"<MfaPolicyCompliance user={self.user_id} org={self.organization_id} status={self.status}>"
return (
f"<MfaPolicyCompliance user={self.user_id} "
f"org={self.organization_id} status={self.status}>"
)
def to_dict(self, exclude=None):
"""Convert to dictionary."""
exclude = exclude or []
return super().to_dict(exclude=exclude)
return super().to_dict(exclude=exclude or [])
@@ -39,15 +39,19 @@ class OrganizationSecurityPolicy(BaseModel):
# Relationships
organization = db.relationship(
"Organization", back_populates="security_policy", foreign_keys=[organization_id]
"Organization",
back_populates="security_policy",
foreign_keys=[organization_id],
)
updated_by_user = db.relationship("User", foreign_keys=[updated_by_user_id])
def __repr__(self):
"""String representation of OrganizationSecurityPolicy."""
return f"<OrganizationSecurityPolicy org={self.organization_id} mode={self.mfa_policy_mode}>"
return (
f"<OrganizationSecurityPolicy "
f"org={self.organization_id} mode={self.mfa_policy_mode}>"
)
def to_dict(self, exclude=None):
"""Convert to dictionary."""
exclude = exclude or []
return super().to_dict(exclude=exclude)
return super().to_dict(exclude=exclude or [])
@@ -1,4 +1,4 @@
"""UserSecurityPolicy model."""
"""UserSecurityPolicy model — per-user MFA overrides."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import MfaRequirementOverride
@@ -7,7 +7,7 @@ from gatehouse_app.utils.constants import MfaRequirementOverride
class UserSecurityPolicy(BaseModel):
"""User security policy model for per-user MFA overrides.
Stores per user overrides of organization level MFA requirements.
Stores per-user overrides of organization-level MFA requirements.
"""
__tablename__ = "user_security_policies"
@@ -25,29 +25,27 @@ class UserSecurityPolicy(BaseModel):
default=MfaRequirementOverride.INHERIT,
)
# If override is REQUIRED and you want to force a specific factor set
# If override is REQUIRED, optionally force a specific factor set
force_totp = db.Column(db.Boolean, nullable=False, default=False)
force_webauthn = db.Column(db.Boolean, nullable=False, default=False)
__table_args__ = (
db.UniqueConstraint(
"user_id", "organization_id", name="uix_user_org_policy"
),
db.UniqueConstraint("user_id", "organization_id", name="uix_user_org_policy"),
)
# Relationships
user = db.relationship(
"User", back_populates="security_policies", foreign_keys=[user_id]
)
organization = db.relationship(
"Organization", foreign_keys=[organization_id]
)
organization = db.relationship("Organization", foreign_keys=[organization_id])
def __repr__(self):
"""String representation of UserSecurityPolicy."""
return f"<UserSecurityPolicy user={self.user_id} org={self.organization_id} mode={self.mfa_override_mode}>"
return (
f"<UserSecurityPolicy user={self.user_id} "
f"org={self.organization_id} mode={self.mfa_override_mode}>"
)
def to_dict(self, exclude=None):
"""Convert to dictionary."""
exclude = exclude or []
return super().to_dict(exclude=exclude)
return super().to_dict(exclude=exclude or [])
+17
View File
@@ -0,0 +1,17 @@
"""SSH/CA subpackage — certificate authorities, SSH keys, and certificates."""
from gatehouse_app.models.ssh_ca.ca import CA, KeyType, CertType, CaType, CAPermission
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey
from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus
from gatehouse_app.models.ssh_ca.certificate_audit_log import CertificateAuditLog
__all__ = [
"CA",
"KeyType",
"CertType",
"CaType",
"CAPermission",
"SSHKey",
"SSHCertificate",
"CertificateStatus",
"CertificateAuditLog",
]
@@ -1,13 +1,13 @@
"""Certificate Authority (CA) model."""
from enum import Enum
from datetime import datetime
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class KeyType(str, Enum):
"""SSH CA key types."""
ED25519 = "ed25519"
RSA = "rsa"
ECDSA = "ecdsa"
@@ -15,7 +15,7 @@ class KeyType(str, Enum):
class CertType(str, Enum):
"""SSH certificate types."""
USER = "user"
HOST = "host"
@@ -29,7 +29,7 @@ class CaType(str, Enum):
class CA(BaseModel):
"""Certificate Authority (CA) model for SSH certificate signing.
Each organization can have multiple CAs for different purposes
(e.g., production vs. staging). Private keys are encrypted at rest
and should be protected with KMS.
@@ -43,12 +43,12 @@ class CA(BaseModel):
nullable=True, # NULL for the global system-config CA
index=True,
)
# CA name and description
# CA identity
name = db.Column(db.String(255), nullable=False)
description = db.Column(db.Text, nullable=True)
# CA signing type: 'user' signs user certificates, 'host' signs host certificates
# CA signing type: 'user' signs user certificates, 'host' signs host certs
ca_type = db.Column(
db.Enum(CaType, values_callable=lambda x: [e.value for e in x]),
default=CaType.USER,
@@ -61,43 +61,33 @@ class CA(BaseModel):
default=KeyType.ED25519,
nullable=False,
)
# Private key (encrypted at rest by database/KMS)
# Format: PEM-encoded private key
# Private key — PEM-encoded, encrypted at rest by database/KMS
private_key = db.Column(db.Text, nullable=False)
# Public key (PEM format)
# Public key PEM format
public_key = db.Column(db.Text, nullable=False)
# SHA256 fingerprint of the public key
fingerprint = db.Column(db.String(255), nullable=False, unique=True)
# CRL (Certificate Revocation List) configuration
crl_enabled = db.Column(db.Boolean, default=True, nullable=False)
crl_endpoint = db.Column(db.String(512), nullable=True)
# Default certificate validity in hours
# Can be overridden per certificate request
default_cert_validity_hours = db.Column(
db.Integer,
default=1,
nullable=False,
)
# Default certificate validity in hours (overridable per request)
default_cert_validity_hours = db.Column(db.Integer, default=1, nullable=False)
# Maximum validity duration allowed
max_cert_validity_hours = db.Column(
db.Integer,
default=24,
nullable=False,
)
max_cert_validity_hours = db.Column(db.Integer, default=24, nullable=False)
# CA status
is_active = db.Column(db.Boolean, default=True, nullable=False, index=True)
# Key rotation tracking
rotated_at = db.Column(db.DateTime, nullable=True)
rotation_reason = db.Column(db.String(255), nullable=True)
# Relationships
organization = db.relationship("Organization", back_populates="cas")
certificates = db.relationship(
@@ -112,54 +102,53 @@ class CA(BaseModel):
)
__table_args__ = (
db.UniqueConstraint(
"organization_id", "name", name="uix_org_ca_name"
),
db.UniqueConstraint("organization_id", "name", name="uix_org_ca_name"),
db.Index("idx_ca_org_active", "organization_id", "is_active"),
)
def __repr__(self):
"""String representation of CA."""
return f"<CA {self.name} (org_id={self.organization_id}, type={self.key_type})>"
return (
f"<CA {self.name} "
f"(org_id={self.organization_id}, type={self.key_type})>"
)
def to_dict(self, exclude=None):
"""Convert CA to dictionary."""
"""Convert CA to dictionary, never exposing the private key."""
exclude = exclude or []
# Never expose private key in API responses
exclude.extend(["private_key"])
if "private_key" not in exclude:
exclude.append("private_key")
data = super().to_dict(exclude=exclude)
# Add computed fields
data["total_certs"] = len([c for c in self.certificates if c.deleted_at is None])
data["active_certs"] = len([
c for c in self.certificates
if c.deleted_at is None and not c.revoked
])
data["revoked_certs"] = len([
c for c in self.certificates
if c.deleted_at is None and c.revoked
])
data["active_certs"] = len(
[c for c in self.certificates if c.deleted_at is None and not c.revoked]
)
data["revoked_certs"] = len(
[c for c in self.certificates if c.deleted_at is None and c.revoked]
)
return data
def get_active_certificates(self):
"""Get all active (non-revoked) certificates issued by this CA.
Returns:
List of non-revoked SSHCertificate objects
"""
def get_active_certificates(self) -> list:
"""Get all active (non-revoked) certificates issued by this CA."""
return [
c for c in self.certificates
if c.deleted_at is None and not c.revoked
c for c in self.certificates if c.deleted_at is None and not c.revoked
]
def rotate_key(self, new_private_key, new_public_key, new_fingerprint, reason=None):
def rotate_key(
self,
new_private_key: str,
new_public_key: str,
new_fingerprint: str,
reason: str = None,
) -> None:
"""Rotate the CA's key pair.
This should only be done in carefully controlled circumstances.
All existing certificates remain valid but no new certs can be
signed with the old key.
All existing certificates remain valid but no new certificates can be
signed with the old key after rotation.
Args:
new_private_key: New PEM-encoded private key
new_public_key: New PEM-encoded public key
@@ -169,7 +158,7 @@ class CA(BaseModel):
self.private_key = new_private_key
self.public_key = new_public_key
self.fingerprint = new_fingerprint
self.rotated_at = datetime.utcnow()
self.rotated_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow()
self.rotation_reason = reason
self.save()
@@ -178,7 +167,7 @@ class CAPermission(BaseModel):
"""Per-user CA permission model.
Controls which users are allowed to sign certificates against a specific CA.
When a CA has any permission rows the signing endpoint enforces the list;
When a CA has any permission rows, the signing endpoint enforces the list;
CAs with no rows are open to all org members (backwards-compatible default).
Permission values:
@@ -212,7 +201,10 @@ class CAPermission(BaseModel):
)
def __repr__(self):
return f"<CAPermission ca_id={self.ca_id} user_id={self.user_id} permission={self.permission}>"
return (
f"<CAPermission ca_id={self.ca_id} "
f"user_id={self.user_id} permission={self.permission}>"
)
def to_dict(self, exclude=None):
data = super().to_dict(exclude=exclude or [])
@@ -1,14 +1,14 @@
"""Certificate audit log model."""
"""Certificate audit log model — tracks SSH certificate lifecycle events."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class CertificateAuditLog(BaseModel):
"""Audit log for SSH certificate lifecycle events.
Tracks all operations on SSH certificates: signing, revocation,
validation, etc. This is separate from the general AuditLog to
provide detailed certificate operation tracking.
Tracks all operations on SSH certificates: signing, revocation, validation,
etc. Kept separate from the general AuditLog to provide detailed certificate
operation tracking without polluting the main audit stream.
"""
__tablename__ = "certificate_audit_logs"
@@ -20,33 +20,33 @@ class CertificateAuditLog(BaseModel):
nullable=False,
index=True,
)
# The user who performed the action (can be null for system actions)
# The user who performed the action (null for system actions)
user_id = db.Column(
db.String(36),
db.ForeignKey("users.id"),
nullable=True,
index=True,
)
# Action type (e.g., "signed", "revoked", "validated", "requested")
action = db.Column(db.String(50), nullable=False, index=True)
# Request details
ip_address = db.Column(db.String(45), nullable=True)
user_agent = db.Column(db.String(512), nullable=True)
request_id = db.Column(db.String(36), nullable=True)
# Detailed message
message = db.Column(db.Text, nullable=True)
# Additional context
extra_data = db.Column(db.JSON, nullable=True)
# Success/failure
# Outcome
success = db.Column(db.Boolean, default=True, nullable=False)
error_message = db.Column(db.Text, nullable=True)
# Relationships
certificate = db.relationship("SSHCertificate", back_populates="audit_logs")
user = db.relationship("User")
@@ -58,18 +58,26 @@ class CertificateAuditLog(BaseModel):
def __repr__(self):
"""String representation of CertificateAuditLog."""
return f"<CertificateAuditLog cert_id={self.certificate_id} action={self.action}>"
return (
f"<CertificateAuditLog cert_id={self.certificate_id} action={self.action}>"
)
@classmethod
def log(cls, certificate_id, action, user_id=None, **kwargs):
def log(
cls,
certificate_id: str,
action: str,
user_id: str = None,
**kwargs,
) -> "CertificateAuditLog":
"""Create a certificate audit log entry.
Args:
certificate_id: ID of the certificate
action: Action type (e.g., "signed", "revoked")
user_id: ID of the user performing the action (optional)
**kwargs: Additional fields (ip_address, user_agent, message, etc.)
Returns:
CertificateAuditLog instance
"""
@@ -77,7 +85,7 @@ class CertificateAuditLog(BaseModel):
certificate_id=certificate_id,
action=action,
user_id=user_id,
**kwargs
**kwargs,
)
log_entry.save()
return log_entry
@@ -1,26 +1,26 @@
"""SSH Certificate model."""
"""SSH Certificate model — signed SSH user/host certificates."""
from enum import Enum
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.models.ca import CertType
from gatehouse_app.models.ssh_ca.ca import CertType
class CertificateStatus(str, Enum):
"""SSH certificate lifecycle status."""
REQUESTED = "requested" # Waiting for signing
ISSUED = "issued" # Signed and valid
REVOKED = "revoked" # Manually revoked
EXPIRED = "expired" # Validity period ended
SUPERSEDED = "superseded" # Replaced by newer cert
REQUESTED = "requested" # Waiting for signing
ISSUED = "issued" # Signed and valid
REVOKED = "revoked" # Manually revoked
EXPIRED = "expired" # Validity period ended
SUPERSEDED = "superseded" # Replaced by newer certificate
class SSHCertificate(BaseModel):
"""SSH Certificate model representing a signed SSH user/host certificate.
Certificates are issued by a CA and associated with an SSH public key.
They include principals (access levels), validity periods, and other
They include principals (access levels), validity periods, and standard
OpenSSH certificate metadata.
"""
@@ -45,10 +45,10 @@ class SSHCertificate(BaseModel):
nullable=False,
index=True,
)
# Certificate content (full signed certificate in OpenSSH format)
# Certificate content full signed certificate in OpenSSH format
certificate = db.Column(db.Text, nullable=False)
# Certificate metadata
serial = db.Column(db.String(255), nullable=False, unique=True, index=True)
key_id = db.Column(db.String(255), nullable=False) # Usually user email
@@ -57,19 +57,19 @@ class SSHCertificate(BaseModel):
default=CertType.USER,
nullable=False,
)
# Principals (JSON list) - e.g., ["prod-servers", "dev-servers"]
# Principals JSON list, e.g., ["prod-servers", "dev-servers"]
principals = db.Column(db.JSON, nullable=False, default=list)
# Validity period
valid_after = db.Column(db.DateTime, nullable=False)
valid_before = db.Column(db.DateTime, nullable=False)
# Revocation status
revoked = db.Column(db.Boolean, default=False, nullable=False, index=True)
revoked_at = db.Column(db.DateTime, nullable=True)
revoke_reason = db.Column(db.String(255), nullable=True)
# Status tracking
status = db.Column(
db.Enum(CertificateStatus, values_callable=lambda x: [e.value for e in x]),
@@ -77,19 +77,17 @@ class SSHCertificate(BaseModel):
nullable=False,
index=True,
)
# Request metadata
request_ip = db.Column(db.String(45), nullable=True)
request_user_agent = db.Column(db.String(512), nullable=True)
# Critical options (JSON) - OpenSSH critical options
# See: https://man.openbsd.org/ssh-cert
# Critical options OpenSSH critical options (JSON)
critical_options = db.Column(db.JSON, nullable=True, default=dict)
# Extensions (JSON) - OpenSSH extensions
# Common ones: permit-X11-forwarding, permit-agent-forwarding, permit-pty, etc.
# Extensions OpenSSH extensions (JSON)
extensions = db.Column(db.JSON, nullable=True, default=dict)
# Relationships
ca = db.relationship("CA", back_populates="certificates")
user = db.relationship("User", back_populates="ssh_certificates")
@@ -115,67 +113,64 @@ class SSHCertificate(BaseModel):
return f"<SSHCertificate serial={self.serial[:16]}... user_id={self.user_id}>"
def to_dict(self, exclude=None):
"""Convert certificate to dictionary."""
"""Convert certificate to dictionary.
The raw ``certificate`` blob is excluded by default (it is large and
callers can request it explicitly by removing it from the exclude list).
"""
exclude = exclude or []
# Optionally exclude the certificate content (it's large)
if "certificate" not in exclude:
exclude.append("certificate")
data = super().to_dict(exclude=exclude)
# Add computed fields
data["is_valid"] = self.is_valid()
data["days_until_expiry"] = self.days_until_expiry()
return data
def is_valid(self):
def _aware(self, dt: datetime) -> datetime:
"""Return a timezone-aware UTC datetime."""
return dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt
def is_valid(self) -> bool:
"""Check if certificate is currently valid.
Returns:
True if certificate is issued, not revoked, and within validity period
"""
if self.revoked or self.status == CertificateStatus.REVOKED:
return False
now = datetime.now(timezone.utc)
valid_after = self.valid_after.replace(tzinfo=timezone.utc) if self.valid_after.tzinfo is None else self.valid_after
valid_before = self.valid_before.replace(tzinfo=timezone.utc) if self.valid_before.tzinfo is None else self.valid_before
return valid_after <= now <= valid_before
return self._aware(self.valid_after) <= now <= self._aware(self.valid_before)
def is_expired(self):
def is_expired(self) -> bool:
"""Check if certificate has expired.
Returns:
True if current time is past valid_before
"""
now = datetime.now(timezone.utc)
valid_before = self.valid_before.replace(tzinfo=timezone.utc) if self.valid_before.tzinfo is None else self.valid_before
return now > valid_before
return datetime.now(timezone.utc) > self._aware(self.valid_before)
def days_until_expiry(self):
def days_until_expiry(self) -> int:
"""Get number of days until certificate expires.
Returns:
Number of days remaining (negative if already expired)
"""
now = datetime.now(timezone.utc)
valid_before = self.valid_before.replace(tzinfo=timezone.utc) if self.valid_before.tzinfo is None else self.valid_before
delta = valid_before - now
delta = self._aware(self.valid_before) - datetime.now(timezone.utc)
return delta.days + (1 if delta.seconds > 0 else 0)
def revoke(self, reason=None):
def revoke(self, reason: str = None) -> None:
"""Revoke this certificate.
Args:
reason: Optional reason for revocation
"""
self.revoked = True
self.revoked_at = datetime.utcnow()
self.revoked_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow()
self.revoke_reason = reason
self.status = CertificateStatus.REVOKED
self.save()
def mark_expired(self):
def mark_expired(self) -> None:
"""Mark certificate as expired when validity period ends."""
self.status = CertificateStatus.EXPIRED
self.save()
@@ -1,14 +1,14 @@
"""SSH Key model."""
from datetime import datetime
"""SSH Key model — user SSH public keys registered for certificate signing."""
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class SSHKey(BaseModel):
"""SSH Key model representing a user's SSH public key.
This model stores SSH public keys that users register for certificate signing.
Users must verify ownership of the key before it can be used for signing certificates.
Users register SSH public keys for certificate signing. Keys must be
verified (owner proved possession) before they can be used.
"""
__tablename__ = "ssh_keys"
@@ -19,33 +19,29 @@ class SSHKey(BaseModel):
nullable=False,
index=True,
)
# SSH key payload in OpenSSH format (e.g., "ssh-rsa AAAAB3Nz...")
# SSH key payload in OpenSSH format (e.g., "ssh-ed25519 AAAAB3Nz...")
payload = db.Column(db.Text, nullable=False, unique=True)
# SHA256 fingerprint for quick comparison
# SHA256 fingerprint for quick comparison and deduplication
fingerprint = db.Column(db.String(255), nullable=False, unique=True, index=True)
# Optional description for the key (e.g., "My laptop key")
# Optional human-readable description (e.g., "My laptop key")
description = db.Column(db.String(255), nullable=True)
# Verification status
verified = db.Column(db.Boolean, default=False, nullable=False, index=True)
verified_at = db.Column(db.DateTime, nullable=True)
# Verification challenge
# Verification challenge — shown to user once, cleared after verification
verify_text = db.Column(db.String(255), nullable=True)
verify_text_created_at = db.Column(db.DateTime, nullable=True)
# Key type extracted from the key (ssh-rsa, ssh-ed25519, etc.)
key_type = db.Column(db.String(50), nullable=True)
# Key bits/length
key_bits = db.Column(db.Integer, nullable=True)
# Comment from the key (usually email or key name)
# Key metadata extracted from the key
key_type = db.Column(db.String(50), nullable=True) # ssh-rsa, ssh-ed25519, etc.
key_bits = db.Column(db.Integer, nullable=True) # key length
key_comment = db.Column(db.String(255), nullable=True)
# Relationships
user = db.relationship("User", back_populates="ssh_keys")
certificates = db.relationship(
@@ -64,33 +60,39 @@ class SSHKey(BaseModel):
return f"<SSHKey {self.fingerprint[:16]}... user_id={self.user_id}>"
def to_dict(self, exclude=None):
"""Convert SSH key to dictionary."""
"""Convert SSH key to dictionary.
``payload`` and ``verify_text`` are never exposed through the API.
"""
exclude = exclude or []
exclude.extend(["payload", "verify_text"]) # Never expose these in API
for field in ("payload", "verify_text"):
if field not in exclude:
exclude.append(field)
data = super().to_dict(exclude=exclude)
# Add computed fields
data["cert_count"] = len([c for c in self.certificates if c.deleted_at is None])
return data
def mark_verified(self):
"""Mark this SSH key as verified."""
def mark_verified(self) -> None:
"""Mark this SSH key as verified and clear the challenge."""
self.verified = True
self.verified_at = datetime.utcnow()
self.verified_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow()
self.verify_text = None
self.save()
def needs_verification_refresh(self, max_age_hours=24):
def needs_verification_refresh(self, max_age_hours: int = 24) -> bool:
"""Check if verification challenge needs to be refreshed.
Args:
max_age_hours: Maximum age of verification challenge in hours
Returns:
True if verification challenge is stale
True if verification challenge is stale or missing
"""
if not self.verify_text_created_at:
return True
age = datetime.utcnow() - self.verify_text_created_at
age = datetime.now(timezone.utc) - self.verify_text_created_at.replace(
tzinfo=timezone.utc
) if self.verify_text_created_at.tzinfo is None else (
datetime.now(timezone.utc) - self.verify_text_created_at
)
return age.total_seconds() > (max_age_hours * 3600)
+3 -176
View File
@@ -1,177 +1,4 @@
"""User model."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import UserStatus
"""Backward-compatibility shim — import from gatehouse_app.models.user.user instead."""
from gatehouse_app.models.user.user import User # noqa: F401
class User(BaseModel):
"""User model representing a user account."""
__tablename__ = "users"
email = db.Column(db.String(255), unique=True, nullable=False, index=True)
email_verified = db.Column(db.Boolean, default=False, nullable=False)
full_name = db.Column(db.String(255), nullable=True)
avatar_url = db.Column(db.String(512), nullable=True)
status = db.Column(
db.Enum(UserStatus), default=UserStatus.ACTIVE, nullable=False, index=True
)
last_login_at = db.Column(db.DateTime, nullable=True)
last_login_ip = db.Column(db.String(45), nullable=True)
# Relationships
authentication_methods = db.relationship(
"AuthenticationMethod", back_populates="user", cascade="all, delete-orphan"
)
sessions = db.relationship("Session", back_populates="user", cascade="all, delete-orphan")
organization_memberships = db.relationship(
"OrganizationMember",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="OrganizationMember.user_id",
)
audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan")
security_policies = db.relationship(
"UserSecurityPolicy",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="UserSecurityPolicy.user_id",
)
mfa_compliance = db.relationship(
"MfaPolicyCompliance",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="MfaPolicyCompliance.user_id",
)
department_memberships = db.relationship(
"DepartmentMembership",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="DepartmentMembership.user_id",
)
principal_memberships = db.relationship(
"PrincipalMembership",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="PrincipalMembership.user_id",
)
ssh_keys = db.relationship(
"SSHKey",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="SSHKey.user_id",
)
ssh_certificates = db.relationship(
"SSHCertificate",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="SSHCertificate.user_id",
)
def __repr__(self):
"""String representation of User."""
return f"<User {self.email}>"
def to_dict(self, exclude=None):
"""Convert user to dictionary, excluding sensitive fields by default."""
exclude = exclude or []
# Always exclude password-related fields
default_exclude = []
all_exclude = list(set(default_exclude + exclude))
return super().to_dict(exclude=all_exclude)
def has_password_auth(self):
"""Check if user has password authentication enabled."""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.PASSWORD, deleted_at=None
).first()
is not None
)
def get_organizations(self):
"""Get all organizations the user is a member of."""
return [membership.organization for membership in self.organization_memberships]
def has_totp_enabled(self) -> bool:
"""Check if user has TOTP enabled and verified.
Returns:
True if user has a verified TOTP authentication method, False otherwise.
"""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id,
method_type=AuthMethodType.TOTP,
verified=True,
deleted_at=None,
).first()
is not None
)
def get_totp_method(self):
"""Get user's TOTP authentication method.
Returns:
The AuthenticationMethod instance for TOTP or None if not found.
Note:
Returns the most recently created TOTP method to handle cases where
multiple enrollment attempts may exist.
"""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None
).order_by(AuthenticationMethod.created_at.desc()).first()
def has_webauthn_enabled(self) -> bool:
"""Check if user has any WebAuthn passkey credentials.
Returns:
True if user has at least one WebAuthn credential, False otherwise.
"""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id,
method_type=AuthMethodType.WEBAUTHN,
deleted_at=None,
).first()
is not None
)
def get_webauthn_credentials(self):
"""Get all WebAuthn credentials for the user.
Returns:
List of AuthenticationMethod instances for WebAuthn, ordered by creation date.
"""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
).order_by(AuthenticationMethod.created_at.desc()).all()
def get_webauthn_credential_count(self) -> int:
"""Get the count of WebAuthn credentials for the user.
Returns:
Number of WebAuthn credentials.
"""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
).count()
__all__ = ["User"]
+5
View File
@@ -0,0 +1,5 @@
"""User subpackage."""
from gatehouse_app.models.user.user import User
from gatehouse_app.models.user.session import Session
__all__ = ["User", "Session"]
@@ -21,7 +21,9 @@ class Session(BaseModel):
# Timing
expires_at = db.Column(db.DateTime, nullable=False)
last_activity_at = db.Column(db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc))
last_activity_at = db.Column(
db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc)
)
revoked_at = db.Column(db.DateTime, nullable=True)
revoked_reason = db.Column(db.String(255), nullable=True)
@@ -38,7 +40,6 @@ class Session(BaseModel):
def is_active(self):
"""Check if session is currently active."""
now = datetime.now(timezone.utc)
# Make expires_at timezone-aware if it's naive
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
@@ -51,15 +52,13 @@ class Session(BaseModel):
def is_expired(self):
"""Check if session has expired."""
now = datetime.now(timezone.utc)
# Make expires_at timezone-aware if it's naive
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return now > expires_at
def refresh(self, duration_seconds=86400):
"""
Refresh session expiration.
def refresh(self, duration_seconds: int = 86400):
"""Refresh session expiration.
Args:
duration_seconds: New session duration in seconds
@@ -68,9 +67,8 @@ class Session(BaseModel):
self.last_activity_at = datetime.now(timezone.utc)
db.session.commit()
def revoke(self, reason=None):
"""
Revoke the session.
def revoke(self, reason: str = None):
"""Revoke the session.
Args:
reason: Optional reason for revocation
@@ -84,6 +82,5 @@ class Session(BaseModel):
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Exclude token from dict
exclude.append("token")
return super().to_dict(exclude=exclude)
+209
View File
@@ -0,0 +1,209 @@
"""User model."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import UserStatus
class User(BaseModel):
"""User model representing a user account."""
__tablename__ = "users"
email = db.Column(db.String(255), unique=True, nullable=False, index=True)
email_verified = db.Column(db.Boolean, default=False, nullable=False)
full_name = db.Column(db.String(255), nullable=True)
avatar_url = db.Column(db.String(512), nullable=True)
status = db.Column(
db.Enum(UserStatus), default=UserStatus.ACTIVE, nullable=False, index=True
)
last_login_at = db.Column(db.DateTime, nullable=True)
last_login_ip = db.Column(db.String(45), nullable=True)
# Account activation (email-link flow)
activated = db.Column(db.Boolean, default=True, nullable=False)
activation_key = db.Column(db.String(128), unique=True, nullable=True, index=True)
# Relationships defined here only for models that don't circular-import
authentication_methods = db.relationship(
"AuthenticationMethod", back_populates="user", cascade="all, delete-orphan"
)
sessions = db.relationship("Session", back_populates="user", cascade="all, delete-orphan")
organization_memberships = db.relationship(
"OrganizationMember",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="OrganizationMember.user_id",
)
audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan")
security_policies = db.relationship(
"UserSecurityPolicy",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="UserSecurityPolicy.user_id",
)
mfa_compliance = db.relationship(
"MfaPolicyCompliance",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="MfaPolicyCompliance.user_id",
)
department_memberships = db.relationship(
"DepartmentMembership",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="DepartmentMembership.user_id",
)
principal_memberships = db.relationship(
"PrincipalMembership",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="PrincipalMembership.user_id",
)
ssh_keys = db.relationship(
"SSHKey",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="SSHKey.user_id",
)
ssh_certificates = db.relationship(
"SSHCertificate",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="SSHCertificate.user_id",
)
ca_permissions = db.relationship(
"CAPermission",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="CAPermission.user_id",
)
# OIDC relationships registered here (no monkey-patching needed)
oidc_auth_codes = db.relationship(
"OIDCAuthCode", back_populates="user", cascade="all, delete-orphan"
)
oidc_refresh_tokens = db.relationship(
"OIDCRefreshToken", back_populates="user", cascade="all, delete-orphan"
)
oidc_sessions = db.relationship(
"OIDCSession", back_populates="user", cascade="all, delete-orphan"
)
oidc_token_metadata = db.relationship(
"OIDCTokenMetadata", back_populates="user", cascade="all, delete-orphan"
)
oidc_audit_logs = db.relationship(
"OIDCAuditLog", back_populates="user", cascade="all, delete-orphan"
)
def __repr__(self):
"""String representation of User."""
return f"<User {self.email}>"
def to_dict(self, exclude=None):
"""Convert user to dictionary, excluding sensitive fields by default."""
exclude = exclude or []
return super().to_dict(exclude=exclude)
def has_password_auth(self):
"""Check if user has password authentication enabled."""
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.PASSWORD, deleted_at=None
).first()
is not None
)
def get_organizations(self):
"""Get all organizations the user is a member of."""
return [membership.organization for membership in self.organization_memberships]
def has_totp_enabled(self) -> bool:
"""Check if user has TOTP enabled and verified.
Returns:
True if user has a verified TOTP authentication method, False otherwise.
"""
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id,
method_type=AuthMethodType.TOTP,
verified=True,
deleted_at=None,
).first()
is not None
)
def get_totp_method(self):
"""Get user's TOTP authentication method.
Returns:
The AuthenticationMethod instance for TOTP or None if not found.
Note:
Returns the most recently created TOTP method to handle cases where
multiple enrollment attempts may exist.
"""
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None
)
.order_by(AuthenticationMethod.created_at.desc())
.first()
)
def has_webauthn_enabled(self) -> bool:
"""Check if user has any WebAuthn passkey credentials.
Returns:
True if user has at least one WebAuthn credential, False otherwise.
"""
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id,
method_type=AuthMethodType.WEBAUTHN,
deleted_at=None,
).first()
is not None
)
def get_webauthn_credentials(self):
"""Get all WebAuthn credentials for the user.
Returns:
List of AuthenticationMethod instances for WebAuthn, ordered by creation date.
"""
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
)
.order_by(AuthenticationMethod.created_at.desc())
.all()
)
def get_webauthn_credential_count(self) -> int:
"""Get the count of WebAuthn credentials for the user.
Returns:
Number of WebAuthn credentials.
"""
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
).count()