Chore: Refractor Models into organized file/folder
This commit is contained in:
@@ -1,76 +1,150 @@
|
||||
"""Models package."""
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.models.user import User
|
||||
from gatehouse_app.models.organization import Organization
|
||||
from gatehouse_app.models.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.authentication_method import (
|
||||
"""Models package.
|
||||
|
||||
Sub-packages
|
||||
------------
|
||||
models.user — User, Session
|
||||
models.organization — Organization, OrganizationMember, Department,
|
||||
DepartmentMembership, DepartmentPrincipal,
|
||||
DepartmentCertPolicy, Principal, PrincipalMembership,
|
||||
OrgInviteToken
|
||||
models.auth — AuthenticationMethod, ApplicationProviderConfig,
|
||||
OrganizationProviderOverride, OAuthState,
|
||||
AuditLog, PasswordResetToken, EmailVerificationToken
|
||||
models.oidc — OIDCClient, OIDCAuthCode, OIDCRefreshToken, OIDCSession,
|
||||
OIDCTokenMetadata, OIDCAuditLog, OidcJwksKey
|
||||
models.ssh_ca — CA, KeyType, CertType, CaType, CAPermission,
|
||||
SSHKey, SSHCertificate, CertificateStatus,
|
||||
CertificateAuditLog
|
||||
models.security — OrganizationSecurityPolicy, UserSecurityPolicy,
|
||||
MfaPolicyCompliance
|
||||
|
||||
All names are re-exported here so that existing code using the flat import
|
||||
style (``from gatehouse_app.models import X``) or the old per-file style
|
||||
(``from gatehouse_app.models.user import User``) continue to work unchanged.
|
||||
"""
|
||||
|
||||
# ── Base ──────────────────────────────────────────────────────────────────────
|
||||
from gatehouse_app.models.base import BaseModel # noqa: F401
|
||||
|
||||
# ── User ──────────────────────────────────────────────────────────────────────
|
||||
from gatehouse_app.models.user.user import User # noqa: F401
|
||||
from gatehouse_app.models.user.session import Session # noqa: F401
|
||||
|
||||
# ── Organization ──────────────────────────────────────────────────────────────
|
||||
from gatehouse_app.models.organization.organization import Organization # noqa: F401
|
||||
from gatehouse_app.models.organization.organization_member import ( # noqa: F401
|
||||
OrganizationMember,
|
||||
)
|
||||
from gatehouse_app.models.organization.department import ( # noqa: F401
|
||||
Department,
|
||||
DepartmentMembership,
|
||||
DepartmentPrincipal,
|
||||
)
|
||||
from gatehouse_app.models.organization.department_cert_policy import ( # noqa: F401
|
||||
DepartmentCertPolicy,
|
||||
STANDARD_EXTENSIONS,
|
||||
)
|
||||
from gatehouse_app.models.organization.principal import ( # noqa: F401
|
||||
Principal,
|
||||
PrincipalMembership,
|
||||
)
|
||||
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken # noqa: F401
|
||||
|
||||
# ── Auth ──────────────────────────────────────────────────────────────────────
|
||||
from gatehouse_app.models.auth.authentication_method import ( # noqa: F401
|
||||
AuthenticationMethod,
|
||||
ApplicationProviderConfig,
|
||||
OrganizationProviderOverride,
|
||||
OAuthState,
|
||||
)
|
||||
from gatehouse_app.models.session import Session
|
||||
from gatehouse_app.models.audit_log import AuditLog
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
from gatehouse_app.models.oidc_authorization_code import OIDCAuthCode
|
||||
from gatehouse_app.models.oidc_refresh_token import OIDCRefreshToken
|
||||
from gatehouse_app.models.oidc_session import OIDCSession
|
||||
from gatehouse_app.models.oidc_token_metadata import OIDCTokenMetadata
|
||||
from gatehouse_app.models.oidc_audit_log import OIDCAuditLog
|
||||
from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
|
||||
from gatehouse_app.models.user_security_policy import UserSecurityPolicy
|
||||
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
|
||||
from gatehouse_app.models.department import (
|
||||
Department,
|
||||
DepartmentMembership,
|
||||
DepartmentPrincipal,
|
||||
from gatehouse_app.models.auth.audit_log import AuditLog # noqa: F401
|
||||
from gatehouse_app.models.auth.password_reset_token import PasswordResetToken # noqa: F401
|
||||
from gatehouse_app.models.auth.email_verification_token import ( # noqa: F401
|
||||
EmailVerificationToken,
|
||||
)
|
||||
from gatehouse_app.models.principal import (
|
||||
Principal,
|
||||
PrincipalMembership,
|
||||
|
||||
# ── OIDC ──────────────────────────────────────────────────────────────────────
|
||||
from gatehouse_app.models.oidc.oidc_client import OIDCClient # noqa: F401
|
||||
from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode # noqa: F401
|
||||
from gatehouse_app.models.oidc.oidc_refresh_token import OIDCRefreshToken # noqa: F401
|
||||
from gatehouse_app.models.oidc.oidc_session import OIDCSession # noqa: F401
|
||||
from gatehouse_app.models.oidc.oidc_token_metadata import OIDCTokenMetadata # noqa: F401
|
||||
from gatehouse_app.models.oidc.oidc_audit_log import OIDCAuditLog # noqa: F401
|
||||
from gatehouse_app.models.oidc.oidc_jwks_key import OidcJwksKey # noqa: F401
|
||||
|
||||
# ── SSH / CA ──────────────────────────────────────────────────────────────────
|
||||
from gatehouse_app.models.ssh_ca.ca import ( # noqa: F401
|
||||
CA,
|
||||
KeyType,
|
||||
CertType,
|
||||
CaType,
|
||||
CAPermission,
|
||||
)
|
||||
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey # noqa: F401
|
||||
from gatehouse_app.models.ssh_ca.ssh_certificate import ( # noqa: F401
|
||||
SSHCertificate,
|
||||
CertificateStatus,
|
||||
)
|
||||
from gatehouse_app.models.ssh_ca.certificate_audit_log import ( # noqa: F401
|
||||
CertificateAuditLog,
|
||||
)
|
||||
|
||||
# ── Security ──────────────────────────────────────────────────────────────────
|
||||
from gatehouse_app.models.security.organization_security_policy import ( # noqa: F401
|
||||
OrganizationSecurityPolicy,
|
||||
)
|
||||
from gatehouse_app.models.security.user_security_policy import ( # noqa: F401
|
||||
UserSecurityPolicy,
|
||||
)
|
||||
from gatehouse_app.models.security.mfa_policy_compliance import ( # noqa: F401
|
||||
MfaPolicyCompliance,
|
||||
)
|
||||
from gatehouse_app.models.ssh_key import SSHKey
|
||||
from gatehouse_app.models.ca import CA, KeyType, CertType, CAPermission
|
||||
from gatehouse_app.models.ssh_certificate import SSHCertificate, CertificateStatus
|
||||
from gatehouse_app.models.certificate_audit_log import CertificateAuditLog
|
||||
from gatehouse_app.models.password_reset_token import PasswordResetToken
|
||||
from gatehouse_app.models.email_verification_token import EmailVerificationToken
|
||||
from gatehouse_app.models.org_invite_token import OrgInviteToken
|
||||
|
||||
__all__ = [
|
||||
# Base
|
||||
"BaseModel",
|
||||
# User
|
||||
"User",
|
||||
"Session",
|
||||
# Organization
|
||||
"Organization",
|
||||
"OrganizationMember",
|
||||
"Department",
|
||||
"DepartmentMembership",
|
||||
"DepartmentPrincipal",
|
||||
"DepartmentCertPolicy",
|
||||
"STANDARD_EXTENSIONS",
|
||||
"Principal",
|
||||
"PrincipalMembership",
|
||||
"OrgInviteToken",
|
||||
# Auth
|
||||
"AuthenticationMethod",
|
||||
"ApplicationProviderConfig",
|
||||
"OrganizationProviderOverride",
|
||||
"OAuthState",
|
||||
"Session",
|
||||
"AuditLog",
|
||||
"PasswordResetToken",
|
||||
"EmailVerificationToken",
|
||||
# OIDC
|
||||
"OIDCClient",
|
||||
"OIDCAuthCode",
|
||||
"OIDCRefreshToken",
|
||||
"OIDCSession",
|
||||
"OIDCTokenMetadata",
|
||||
"OIDCAuditLog",
|
||||
"OrganizationSecurityPolicy",
|
||||
"UserSecurityPolicy",
|
||||
"MfaPolicyCompliance",
|
||||
"Department",
|
||||
"DepartmentMembership",
|
||||
"DepartmentPrincipal",
|
||||
"Principal",
|
||||
"PrincipalMembership",
|
||||
"SSHKey",
|
||||
"OidcJwksKey",
|
||||
# SSH / CA
|
||||
"CA",
|
||||
"KeyType",
|
||||
"CertType",
|
||||
"CaType",
|
||||
"CAPermission",
|
||||
"SSHKey",
|
||||
"SSHCertificate",
|
||||
"CertificateStatus",
|
||||
"CertificateAuditLog",
|
||||
"PasswordResetToken",
|
||||
"EmailVerificationToken",
|
||||
"OrgInviteToken",
|
||||
# Security
|
||||
"OrganizationSecurityPolicy",
|
||||
"UserSecurityPolicy",
|
||||
"MfaPolicyCompliance",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
"""Auth subpackage — authentication methods, tokens, and audit logs."""
|
||||
from gatehouse_app.models.auth.authentication_method import (
|
||||
AuthenticationMethod,
|
||||
ApplicationProviderConfig,
|
||||
OrganizationProviderOverride,
|
||||
OAuthState,
|
||||
)
|
||||
from gatehouse_app.models.auth.audit_log import AuditLog
|
||||
from gatehouse_app.models.auth.password_reset_token import PasswordResetToken
|
||||
from gatehouse_app.models.auth.email_verification_token import EmailVerificationToken
|
||||
|
||||
__all__ = [
|
||||
"AuthenticationMethod",
|
||||
"ApplicationProviderConfig",
|
||||
"OrganizationProviderOverride",
|
||||
"OAuthState",
|
||||
"AuditLog",
|
||||
"PasswordResetToken",
|
||||
"EmailVerificationToken",
|
||||
]
|
||||
@@ -26,14 +26,13 @@ class AuditLog(BaseModel):
|
||||
extra_data = db.Column(db.JSON, nullable=True)
|
||||
description = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Success/failure
|
||||
# Outcome
|
||||
success = db.Column(db.Boolean, default=True, nullable=False)
|
||||
error_message = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship("User", back_populates="audit_logs")
|
||||
|
||||
# Indexes for common queries
|
||||
__table_args__ = (
|
||||
db.Index("idx_audit_user_action", "user_id", "action"),
|
||||
db.Index("idx_audit_resource", "resource_type", "resource_id"),
|
||||
@@ -45,9 +44,8 @@ class AuditLog(BaseModel):
|
||||
return f"<AuditLog action={self.action} user_id={self.user_id}>"
|
||||
|
||||
@classmethod
|
||||
def log(cls, action, user_id=None, **kwargs):
|
||||
"""
|
||||
Create an audit log entry.
|
||||
def log(cls, action, user_id=None, **kwargs) -> "AuditLog":
|
||||
"""Create an audit log entry.
|
||||
|
||||
Args:
|
||||
action: AuditAction enum value
|
||||
+101
-95
@@ -1,4 +1,4 @@
|
||||
"""Authentication method model."""
|
||||
"""Authentication method model — user credentials and OAuth provider config."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import secrets
|
||||
from gatehouse_app.extensions import db
|
||||
@@ -35,7 +35,6 @@ class AuthenticationMethod(BaseModel):
|
||||
# Relationships
|
||||
user = db.relationship("User", back_populates="authentication_methods")
|
||||
|
||||
# Ensure unique provider combinations
|
||||
__table_args__ = (
|
||||
db.Index("idx_user_method", "user_id", "method_type"),
|
||||
db.UniqueConstraint(
|
||||
@@ -45,13 +44,15 @@ class AuthenticationMethod(BaseModel):
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of AuthenticationMethod."""
|
||||
return f"<AuthenticationMethod user_id={self.user_id} type={self.method_type}>"
|
||||
return (
|
||||
f"<AuthenticationMethod user_id={self.user_id} type={self.method_type}>"
|
||||
)
|
||||
|
||||
def is_password(self):
|
||||
def is_password(self) -> bool:
|
||||
"""Check if this is a password authentication method."""
|
||||
return self.method_type == AuthMethodType.PASSWORD
|
||||
|
||||
def is_oauth(self):
|
||||
def is_oauth(self) -> bool:
|
||||
"""Check if this is an OAuth authentication method."""
|
||||
return self.method_type in [
|
||||
AuthMethodType.GOOGLE,
|
||||
@@ -59,32 +60,32 @@ class AuthenticationMethod(BaseModel):
|
||||
AuthMethodType.MICROSOFT,
|
||||
]
|
||||
|
||||
def is_totp(self):
|
||||
def is_totp(self) -> bool:
|
||||
"""Check if this is a TOTP authentication method."""
|
||||
return self.method_type == AuthMethodType.TOTP
|
||||
|
||||
def is_webauthn(self):
|
||||
def is_webauthn(self) -> bool:
|
||||
"""Check if this is a WebAuthn authentication method."""
|
||||
return self.method_type == AuthMethodType.WEBAUTHN
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude password hash and TOTP secrets
|
||||
exclude.append("password_hash")
|
||||
exclude.append("totp_secret")
|
||||
exclude.append("totp_backup_codes")
|
||||
# Always exclude credential material
|
||||
for field in ("password_hash", "totp_secret", "totp_backup_codes"):
|
||||
if field not in exclude:
|
||||
exclude.append(field)
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
def to_webauthn_dict(self):
|
||||
"""Convert WebAuthn credential to public dictionary.
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary with safe-to-expose credential information.
|
||||
Dictionary with safe-to-expose credential information, or None.
|
||||
"""
|
||||
if not self.is_webauthn() or not self.provider_data:
|
||||
return None
|
||||
|
||||
|
||||
data = self.provider_data
|
||||
return {
|
||||
"id": data.get("credential_id"),
|
||||
@@ -98,26 +99,26 @@ class AuthenticationMethod(BaseModel):
|
||||
|
||||
class ApplicationProviderConfig(BaseModel):
|
||||
"""Application-wide OAuth provider configuration.
|
||||
|
||||
This model stores OAuth provider credentials at the application level,
|
||||
allowing users to authenticate without needing to specify an organization first.
|
||||
|
||||
Stores OAuth provider credentials at the application level, allowing users
|
||||
to authenticate without needing to specify an organization first.
|
||||
"""
|
||||
|
||||
__tablename__ = "application_provider_configs"
|
||||
|
||||
# Provider identification
|
||||
provider_type = db.Column(db.String(50), nullable=False, unique=True, index=True)
|
||||
|
||||
# OAuth credentials (encrypted)
|
||||
|
||||
# OAuth credentials (client_secret encrypted at rest)
|
||||
client_id = db.Column(db.String(255), nullable=False)
|
||||
client_secret_encrypted = db.Column(db.String(512), nullable=True)
|
||||
|
||||
|
||||
# Provider status
|
||||
is_enabled = db.Column(db.Boolean, default=True, nullable=False)
|
||||
|
||||
|
||||
# Default redirect URL
|
||||
default_redirect_url = db.Column(db.String(2048), nullable=True)
|
||||
|
||||
|
||||
# Provider-specific settings (JSON)
|
||||
additional_config = db.Column(db.JSON, nullable=True)
|
||||
|
||||
@@ -126,28 +127,34 @@ class ApplicationProviderConfig(BaseModel):
|
||||
"OrganizationProviderOverride",
|
||||
back_populates="application_config",
|
||||
foreign_keys="OrganizationProviderOverride.provider_type",
|
||||
primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type",
|
||||
cascade="all, delete-orphan"
|
||||
primaryjoin=(
|
||||
"ApplicationProviderConfig.provider_type"
|
||||
"==OrganizationProviderOverride.provider_type"
|
||||
),
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of ApplicationProviderConfig."""
|
||||
return f"<ApplicationProviderConfig provider={self.provider_type} enabled={self.is_enabled}>"
|
||||
return (
|
||||
f"<ApplicationProviderConfig provider={self.provider_type} "
|
||||
f"enabled={self.is_enabled}>"
|
||||
)
|
||||
|
||||
def set_client_secret(self, plaintext_secret: str):
|
||||
def set_client_secret(self, plaintext_secret: str) -> None:
|
||||
"""Encrypt and store client secret.
|
||||
|
||||
|
||||
Args:
|
||||
plaintext_secret: The plaintext OAuth client secret
|
||||
"""
|
||||
if plaintext_secret:
|
||||
self.client_secret_encrypted = encrypt(plaintext_secret)
|
||||
|
||||
def get_client_secret(self) -> str:
|
||||
def get_client_secret(self) -> str | None:
|
||||
"""Decrypt and return client secret.
|
||||
|
||||
|
||||
Returns:
|
||||
The plaintext OAuth client secret
|
||||
The plaintext OAuth client secret, or None if not set.
|
||||
"""
|
||||
if self.client_secret_encrypted:
|
||||
return decrypt(self.client_secret_encrypted)
|
||||
@@ -156,37 +163,38 @@ class ApplicationProviderConfig(BaseModel):
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude encrypted client secret
|
||||
exclude.append("client_secret_encrypted")
|
||||
if "client_secret_encrypted" not in exclude:
|
||||
exclude.append("client_secret_encrypted")
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
class OrganizationProviderOverride(BaseModel):
|
||||
"""Organization-specific OAuth configuration overrides.
|
||||
|
||||
This model allows organizations to override application-level OAuth settings
|
||||
for enterprise SSO scenarios or custom provider configurations.
|
||||
|
||||
Allows organizations to override application-level OAuth settings for
|
||||
enterprise SSO scenarios or custom provider configurations.
|
||||
"""
|
||||
|
||||
__tablename__ = "organization_provider_overrides"
|
||||
|
||||
# References
|
||||
organization_id = db.Column(
|
||||
db.String(36), db.ForeignKey("organizations.id"),
|
||||
nullable=False, index=True
|
||||
db.String(36),
|
||||
db.ForeignKey("organizations.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
provider_type = db.Column(db.String(50), nullable=False, index=True)
|
||||
|
||||
# Override OAuth credentials (encrypted, nullable - only if overriding)
|
||||
|
||||
# Override OAuth credentials (encrypted, nullable — only set when overriding)
|
||||
client_id = db.Column(db.String(255), nullable=True)
|
||||
client_secret_encrypted = db.Column(db.String(512), nullable=True)
|
||||
|
||||
|
||||
# Provider status
|
||||
is_enabled = db.Column(db.Boolean, default=True, nullable=False)
|
||||
|
||||
|
||||
# Redirect URL override
|
||||
redirect_url_override = db.Column(db.String(2048), nullable=True)
|
||||
|
||||
|
||||
# Provider-specific settings override (JSON)
|
||||
additional_config = db.Column(db.JSON, nullable=True)
|
||||
|
||||
@@ -196,37 +204,33 @@ class OrganizationProviderOverride(BaseModel):
|
||||
"ApplicationProviderConfig",
|
||||
back_populates="organization_overrides",
|
||||
foreign_keys=[provider_type],
|
||||
primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type",
|
||||
viewonly=True
|
||||
primaryjoin=(
|
||||
"ApplicationProviderConfig.provider_type"
|
||||
"==OrganizationProviderOverride.provider_type"
|
||||
),
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
# Unique constraint on (organization_id, provider_type)
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint(
|
||||
"organization_id", "provider_type",
|
||||
name="uix_org_provider_type"
|
||||
"organization_id", "provider_type", name="uix_org_provider_type"
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OrganizationProviderOverride."""
|
||||
return f"<OrganizationProviderOverride org={self.organization_id} provider={self.provider_type}>"
|
||||
return (
|
||||
f"<OrganizationProviderOverride org={self.organization_id} "
|
||||
f"provider={self.provider_type}>"
|
||||
)
|
||||
|
||||
def set_client_secret(self, plaintext_secret: str):
|
||||
"""Encrypt and store client secret override.
|
||||
|
||||
Args:
|
||||
plaintext_secret: The plaintext OAuth client secret
|
||||
"""
|
||||
def set_client_secret(self, plaintext_secret: str) -> None:
|
||||
"""Encrypt and store client secret override."""
|
||||
if plaintext_secret:
|
||||
self.client_secret_encrypted = encrypt(plaintext_secret)
|
||||
|
||||
def get_client_secret(self) -> str:
|
||||
"""Decrypt and return client secret override.
|
||||
|
||||
Returns:
|
||||
The plaintext OAuth client secret
|
||||
"""
|
||||
def get_client_secret(self) -> str | None:
|
||||
"""Decrypt and return client secret override."""
|
||||
if self.client_secret_encrypted:
|
||||
return decrypt(self.client_secret_encrypted)
|
||||
return None
|
||||
@@ -234,53 +238,52 @@ class OrganizationProviderOverride(BaseModel):
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude encrypted client secret
|
||||
exclude.append("client_secret_encrypted")
|
||||
if "client_secret_encrypted" not in exclude:
|
||||
exclude.append("client_secret_encrypted")
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
class OAuthState(BaseModel):
|
||||
"""OAuth flow state tracking.
|
||||
|
||||
This model tracks OAuth authentication flow state, including PKCE parameters
|
||||
and organization context (which is now optional to support login flows where
|
||||
the organization isn't known until after authentication).
|
||||
|
||||
Tracks OAuth authentication flow state, including PKCE parameters and
|
||||
organization context (which is optional to support login flows where the
|
||||
organization isn't known until after authentication).
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_states"
|
||||
|
||||
# OAuth state parameter (unique, used for CSRF protection)
|
||||
state = db.Column(db.String(64), unique=True, nullable=False, index=True)
|
||||
|
||||
|
||||
# Flow type: "login", "register", "link"
|
||||
flow_type = db.Column(db.String(50), nullable=False)
|
||||
|
||||
|
||||
# Provider type
|
||||
provider_type = db.Column(db.String(50), nullable=False)
|
||||
|
||||
# User context (optional - not set for login/register flows)
|
||||
|
||||
# User context (optional — not set for login/register flows)
|
||||
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True)
|
||||
|
||||
# Organization context (NOW OPTIONAL - for SSO discovery or post-auth)
|
||||
|
||||
# Organization context (optional — for SSO discovery or post-auth)
|
||||
organization_id = db.Column(
|
||||
db.String(36), db.ForeignKey("organizations.id"),
|
||||
nullable=True, index=True
|
||||
db.String(36), db.ForeignKey("organizations.id"), nullable=True, index=True
|
||||
)
|
||||
|
||||
|
||||
# PKCE parameters
|
||||
nonce = db.Column(db.String(128), nullable=True)
|
||||
code_verifier = db.Column(db.String(128), nullable=True)
|
||||
code_challenge = db.Column(db.String(128), nullable=True)
|
||||
|
||||
|
||||
# OAuth parameters
|
||||
redirect_uri = db.Column(db.String(2048), nullable=True)
|
||||
|
||||
|
||||
# Post-auth redirect (for frontend routing)
|
||||
return_url = db.Column(db.String(2048), nullable=True)
|
||||
|
||||
|
||||
# Additional state data
|
||||
extra_data = db.Column(db.JSON, nullable=True)
|
||||
|
||||
|
||||
# Expiration and usage tracking
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
used = db.Column(db.Boolean, default=False, nullable=False)
|
||||
@@ -291,7 +294,10 @@ class OAuthState(BaseModel):
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OAuthState."""
|
||||
return f"<OAuthState state={self.state[:8]}... flow={self.flow_type} provider={self.provider_type}>"
|
||||
return (
|
||||
f"<OAuthState state={self.state[:8]}... "
|
||||
f"flow={self.flow_type} provider={self.provider_type}>"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_state(
|
||||
@@ -306,10 +312,10 @@ class OAuthState(BaseModel):
|
||||
code_challenge: str = None,
|
||||
nonce: str = None,
|
||||
extra_data: dict = None,
|
||||
lifetime_seconds: int = 600
|
||||
):
|
||||
"""Create a new OAuth state with auto-generated state parameter.
|
||||
|
||||
lifetime_seconds: int = 600,
|
||||
) -> "OAuthState":
|
||||
"""Create a new OAuth state with an auto-generated state parameter.
|
||||
|
||||
Args:
|
||||
flow_type: Type of flow ("login", "register", "link")
|
||||
provider_type: OAuth provider type
|
||||
@@ -322,13 +328,13 @@ class OAuthState(BaseModel):
|
||||
nonce: OpenID Connect nonce
|
||||
extra_data: Additional state data
|
||||
lifetime_seconds: How long the state is valid (default 10 minutes)
|
||||
|
||||
|
||||
Returns:
|
||||
New OAuthState instance
|
||||
"""
|
||||
state = secrets.token_urlsafe(32)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds)
|
||||
|
||||
|
||||
oauth_state = cls(
|
||||
state=state,
|
||||
flow_type=flow_type,
|
||||
@@ -342,31 +348,30 @@ class OAuthState(BaseModel):
|
||||
nonce=nonce,
|
||||
extra_data=extra_data,
|
||||
expires_at=expires_at,
|
||||
used=False
|
||||
used=False,
|
||||
)
|
||||
oauth_state.save()
|
||||
return oauth_state
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the OAuth state is still valid.
|
||||
|
||||
|
||||
Returns:
|
||||
True if state hasn't expired and hasn't been used
|
||||
True if state hasn't expired and hasn't been used.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
# Make expires_at timezone-aware if it's naive (database returns naive datetimes)
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return not self.used and expires_at > now
|
||||
|
||||
def mark_used(self):
|
||||
def mark_used(self) -> None:
|
||||
"""Mark the state as used to prevent replay attacks."""
|
||||
self.used = True
|
||||
self.save()
|
||||
|
||||
@classmethod
|
||||
def cleanup_expired(cls):
|
||||
def cleanup_expired(cls) -> None:
|
||||
"""Remove expired OAuth states."""
|
||||
now = datetime.now(timezone.utc)
|
||||
cls.query.filter(cls.expires_at < now).delete()
|
||||
@@ -375,6 +380,7 @@ class OAuthState(BaseModel):
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Exclude code_verifier as it's sensitive
|
||||
exclude.append("code_verifier")
|
||||
# code_verifier must never be exposed
|
||||
if "code_verifier" not in exclude:
|
||||
exclude.append("code_verifier")
|
||||
return super().to_dict(exclude=exclude)
|
||||
@@ -0,0 +1,68 @@
|
||||
"""Email verification token model."""
|
||||
import secrets
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class EmailVerificationToken(BaseModel):
|
||||
"""Single-use token for verifying a user's email address."""
|
||||
|
||||
__tablename__ = "email_verification_tokens"
|
||||
|
||||
user_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
token = db.Column(db.String(128), unique=True, nullable=False, index=True)
|
||||
expires_at = db.Column(db.DateTime, nullable=False)
|
||||
used_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
user = db.relationship(
|
||||
"User",
|
||||
backref=db.backref("email_verification_tokens", cascade="all, delete-orphan"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate(cls, user_id: str, ttl_hours: int = 24) -> "EmailVerificationToken":
|
||||
"""Create a new verification token for a user.
|
||||
|
||||
Any existing unused tokens for this user are invalidated first.
|
||||
"""
|
||||
cls.query.filter_by(user_id=user_id, used_at=None).delete()
|
||||
db.session.flush()
|
||||
|
||||
token_value = secrets.token_urlsafe(48)
|
||||
instance = cls(
|
||||
user_id=user_id,
|
||||
token=token_value,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(hours=ttl_hours),
|
||||
)
|
||||
db.session.add(instance)
|
||||
db.session.commit()
|
||||
return instance
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
"""Return True if the token has not been used and has not expired."""
|
||||
if self.used_at is not None:
|
||||
return False
|
||||
now = datetime.now(timezone.utc)
|
||||
expires = self.expires_at
|
||||
if expires.tzinfo is None:
|
||||
expires = expires.replace(tzinfo=timezone.utc)
|
||||
return now < expires
|
||||
|
||||
def consume(self) -> None:
|
||||
"""Mark the token as used."""
|
||||
self.used_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<EmailVerificationToken user_id={self.user_id} "
|
||||
f"used={self.used_at is not None}>"
|
||||
)
|
||||
@@ -0,0 +1,69 @@
|
||||
"""Password reset token model."""
|
||||
import secrets
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class PasswordResetToken(BaseModel):
|
||||
"""Single-use token for resetting a user's password."""
|
||||
|
||||
__tablename__ = "password_reset_tokens"
|
||||
|
||||
user_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
token = db.Column(db.String(128), unique=True, nullable=False, index=True)
|
||||
expires_at = db.Column(db.DateTime, nullable=False)
|
||||
used_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
user = db.relationship(
|
||||
"User",
|
||||
backref=db.backref("password_reset_tokens", cascade="all, delete-orphan"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate(cls, user_id: str, ttl_hours: int = 2) -> "PasswordResetToken":
|
||||
"""Create a new password reset token for a user.
|
||||
|
||||
Any existing unused tokens for this user are invalidated first.
|
||||
"""
|
||||
# Invalidate any existing unused tokens for this user
|
||||
cls.query.filter_by(user_id=user_id, used_at=None).delete()
|
||||
db.session.flush()
|
||||
|
||||
token_value = secrets.token_urlsafe(48)
|
||||
instance = cls(
|
||||
user_id=user_id,
|
||||
token=token_value,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(hours=ttl_hours),
|
||||
)
|
||||
db.session.add(instance)
|
||||
db.session.commit()
|
||||
return instance
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
"""Return True if the token has not been used and has not expired."""
|
||||
if self.used_at is not None:
|
||||
return False
|
||||
now = datetime.now(timezone.utc)
|
||||
expires = self.expires_at
|
||||
if expires.tzinfo is None:
|
||||
expires = expires.replace(tzinfo=timezone.utc)
|
||||
return now < expires
|
||||
|
||||
def consume(self) -> None:
|
||||
"""Mark the token as used."""
|
||||
self.used_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<PasswordResetToken user_id={self.user_id} "
|
||||
f"used={self.used_at is not None}>"
|
||||
)
|
||||
@@ -0,0 +1,18 @@
|
||||
"""OIDC subpackage — clients, tokens, sessions, and audit logs."""
|
||||
from gatehouse_app.models.oidc.oidc_client import OIDCClient
|
||||
from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode
|
||||
from gatehouse_app.models.oidc.oidc_refresh_token import OIDCRefreshToken
|
||||
from gatehouse_app.models.oidc.oidc_session import OIDCSession
|
||||
from gatehouse_app.models.oidc.oidc_token_metadata import OIDCTokenMetadata
|
||||
from gatehouse_app.models.oidc.oidc_audit_log import OIDCAuditLog
|
||||
from gatehouse_app.models.oidc.oidc_jwks_key import OidcJwksKey
|
||||
|
||||
__all__ = [
|
||||
"OIDCClient",
|
||||
"OIDCAuthCode",
|
||||
"OIDCRefreshToken",
|
||||
"OIDCSession",
|
||||
"OIDCTokenMetadata",
|
||||
"OIDCAuditLog",
|
||||
"OidcJwksKey",
|
||||
]
|
||||
+83
-50
@@ -1,5 +1,4 @@
|
||||
"""OIDC Audit Log model for comprehensive OIDC event tracking."""
|
||||
from datetime import datetime
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
@@ -7,8 +6,7 @@ from gatehouse_app.models.base import BaseModel
|
||||
class OIDCAuditLog(BaseModel):
|
||||
"""OIDC Audit Log model for comprehensive OIDC event tracking.
|
||||
|
||||
This model logs all OIDC-related events for security, compliance,
|
||||
and debugging purposes.
|
||||
Logs all OIDC-related events for security, compliance, and debugging.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_audit_logs"
|
||||
@@ -46,16 +44,29 @@ class OIDCAuditLog(BaseModel):
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCAuditLog."""
|
||||
status = "success" if self.success else "failed"
|
||||
return f"<OIDCAuditLog event={self.event_type} status={status} client={self.client_id}>"
|
||||
return (
|
||||
f"<OIDCAuditLog event={self.event_type} "
|
||||
f"status={status} client={self.client_id}>"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_event(cls, event_type, client_id=None, user_id=None, success=True,
|
||||
error_code=None, error_description=None, ip_address=None,
|
||||
user_agent=None, request_id=None, event_metadata=None):
|
||||
def log_event(
|
||||
cls,
|
||||
event_type: str,
|
||||
client_id: str = None,
|
||||
user_id: str = None,
|
||||
success: bool = True,
|
||||
error_code: str = None,
|
||||
error_description: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
request_id: str = None,
|
||||
event_metadata: dict = None,
|
||||
) -> "OIDCAuditLog":
|
||||
"""Log an OIDC event.
|
||||
|
||||
Args:
|
||||
event_type: Type of event (e.g., "authorization_request", "token_issue")
|
||||
event_type: Type of event (e.g., "authorization_request")
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID
|
||||
success: Whether the event was successful
|
||||
@@ -86,9 +97,19 @@ class OIDCAuditLog(BaseModel):
|
||||
return log
|
||||
|
||||
@classmethod
|
||||
def log_authorization_request(cls, client_id, user_id, redirect_uri, scope,
|
||||
ip_address=None, user_agent=None, request_id=None,
|
||||
success=True, error_code=None, error_description=None):
|
||||
def log_authorization_request(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
redirect_uri: str,
|
||||
scope,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
request_id: str = None,
|
||||
success: bool = True,
|
||||
error_code: str = None,
|
||||
error_description: str = None,
|
||||
) -> "OIDCAuditLog":
|
||||
"""Log an authorization request event."""
|
||||
return cls.log_event(
|
||||
event_type="authorization_request",
|
||||
@@ -100,15 +121,19 @@ class OIDCAuditLog(BaseModel):
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
event_metadata={
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": scope,
|
||||
}
|
||||
event_metadata={"redirect_uri": redirect_uri, "scope": scope},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_token_issue(cls, client_id, user_id, token_type,
|
||||
ip_address=None, user_agent=None, request_id=None):
|
||||
def log_token_issue(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
token_type: str,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
request_id: str = None,
|
||||
) -> "OIDCAuditLog":
|
||||
"""Log a token issuance event."""
|
||||
return cls.log_event(
|
||||
event_type="token_issue",
|
||||
@@ -118,12 +143,20 @@ class OIDCAuditLog(BaseModel):
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
event_metadata={"token_type": token_type}
|
||||
event_metadata={"token_type": token_type},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_token_revocation(cls, client_id, user_id, token_type, reason=None,
|
||||
ip_address=None, user_agent=None, request_id=None):
|
||||
def log_token_revocation(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
token_type: str,
|
||||
reason: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
request_id: str = None,
|
||||
) -> "OIDCAuditLog":
|
||||
"""Log a token revocation event."""
|
||||
return cls.log_event(
|
||||
event_type="token_revocation",
|
||||
@@ -133,15 +166,19 @@ class OIDCAuditLog(BaseModel):
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
event_metadata={
|
||||
"token_type": token_type,
|
||||
"reason": reason,
|
||||
}
|
||||
event_metadata={"token_type": token_type, "reason": reason},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_authentication_failure(cls, client_id, error_code, error_description,
|
||||
ip_address=None, user_agent=None, request_id=None):
|
||||
def log_authentication_failure(
|
||||
cls,
|
||||
client_id: str,
|
||||
error_code: str,
|
||||
error_description: str,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
request_id: str = None,
|
||||
) -> "OIDCAuditLog":
|
||||
"""Log an authentication failure event."""
|
||||
return cls.log_event(
|
||||
event_type="authentication_failure",
|
||||
@@ -155,7 +192,7 @@ class OIDCAuditLog(BaseModel):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_events_for_user(cls, user_id, limit=100):
|
||||
def get_events_for_user(cls, user_id: str, limit: int = 100) -> list:
|
||||
"""Get audit events for a user.
|
||||
|
||||
Args:
|
||||
@@ -165,13 +202,15 @@ class OIDCAuditLog(BaseModel):
|
||||
Returns:
|
||||
List of OIDCAuditLog instances
|
||||
"""
|
||||
return cls.query.filter_by(user_id=user_id, deleted_at=None)\
|
||||
.order_by(cls.created_at.desc())\
|
||||
.limit(limit)\
|
||||
return (
|
||||
cls.query.filter_by(user_id=user_id, deleted_at=None)
|
||||
.order_by(cls.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_events_for_client(cls, client_id, limit=100):
|
||||
def get_events_for_client(cls, client_id: str, limit: int = 100) -> list:
|
||||
"""Get audit events for a client.
|
||||
|
||||
Args:
|
||||
@@ -181,14 +220,22 @@ class OIDCAuditLog(BaseModel):
|
||||
Returns:
|
||||
List of OIDCAuditLog instances
|
||||
"""
|
||||
return cls.query.filter_by(client_id=client_id, deleted_at=None)\
|
||||
.order_by(cls.created_at.desc())\
|
||||
.limit(limit)\
|
||||
return (
|
||||
cls.query.filter_by(client_id=client_id, deleted_at=None)
|
||||
.order_by(cls.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_failed_events(cls, client_id=None, user_id=None, start_date=None,
|
||||
end_date=None, limit=100):
|
||||
def get_failed_events(
|
||||
cls,
|
||||
client_id: str = None,
|
||||
user_id: str = None,
|
||||
start_date=None,
|
||||
end_date=None,
|
||||
limit: int = 100,
|
||||
) -> list:
|
||||
"""Get failed audit events.
|
||||
|
||||
Args:
|
||||
@@ -210,22 +257,8 @@ class OIDCAuditLog(BaseModel):
|
||||
query = query.filter(cls.created_at >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(cls.created_at <= end_date)
|
||||
|
||||
return query.order_by(cls.created_at.desc()).limit(limit).all()
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from gatehouse_app.models.user import User
|
||||
User.oidc_audit_logs = db.relationship(
|
||||
"OIDCAuditLog", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
OIDCClient.audit_logs = db.relationship(
|
||||
"OIDCAuditLog", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
+32
-34
@@ -1,14 +1,14 @@
|
||||
"""OIDC Authorization Code model for auth code flow."""
|
||||
"""OIDC Authorization Code model for the authorization code grant flow."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OIDCAuthCode(BaseModel):
|
||||
"""OIDC Authorization Code model for authorization code flow.
|
||||
"""OIDC Authorization Code model for the authorization code grant flow.
|
||||
|
||||
Authorization codes are single-use, short-lived codes used in the
|
||||
authorization code grant flow. The code is hashed for security.
|
||||
Authorization codes are single-use, short-lived codes. The code itself is
|
||||
hashed before storage so that a database breach cannot replay codes.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_authorization_codes"
|
||||
@@ -26,9 +26,9 @@ class OIDCAuthCode(BaseModel):
|
||||
|
||||
# Request parameters
|
||||
redirect_uri = db.Column(db.String(512), nullable=False)
|
||||
scope = db.Column(db.JSON, nullable=True) # Requested scopes
|
||||
nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation
|
||||
code_verifier = db.Column(db.String(255), nullable=True) # For PKCE
|
||||
scope = db.Column(db.JSON, nullable=True)
|
||||
nonce = db.Column(db.String(255), nullable=True)
|
||||
code_verifier = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Status tracking
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
@@ -39,37 +39,48 @@ class OIDCAuthCode(BaseModel):
|
||||
ip_address = db.Column(db.String(45), nullable=True)
|
||||
user_agent = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
# Relationships — back_populates declared on User and OIDCClient
|
||||
client = db.relationship("OIDCClient", back_populates="authorization_codes")
|
||||
user = db.relationship("User", back_populates="oidc_auth_codes")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCAuthCode."""
|
||||
return f"<OIDCAuthCode client_id={self.client_id} user_id={self.user_id} used={self.is_used}>"
|
||||
return (
|
||||
f"<OIDCAuthCode client_id={self.client_id} "
|
||||
f"user_id={self.user_id} used={self.is_used}>"
|
||||
)
|
||||
|
||||
def is_expired(self):
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the authorization code has expired."""
|
||||
# Handle both timezone-aware and timezone-naive expires_at values
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
# Make naive datetime timezone-aware (UTC)
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return datetime.now(timezone.utc) > expires_at
|
||||
|
||||
def is_valid(self):
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the authorization code is valid for use."""
|
||||
return not self.is_used and not self.is_expired() and self.deleted_at is None
|
||||
|
||||
def mark_as_used(self):
|
||||
def mark_as_used(self) -> None:
|
||||
"""Mark the authorization code as used."""
|
||||
self.is_used = True
|
||||
self.used_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def create_code(cls, client_id, user_id, code_hash, redirect_uri, scope=None,
|
||||
nonce=None, code_verifier=None, ip_address=None, user_agent=None,
|
||||
lifetime_seconds=600):
|
||||
def create_code(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
code_hash: str,
|
||||
redirect_uri: str,
|
||||
scope=None,
|
||||
nonce: str = None,
|
||||
code_verifier: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
lifetime_seconds: int = 600,
|
||||
) -> "OIDCAuthCode":
|
||||
"""Create a new authorization code.
|
||||
|
||||
Args:
|
||||
@@ -79,7 +90,7 @@ class OIDCAuthCode(BaseModel):
|
||||
redirect_uri: The redirect URI
|
||||
scope: Requested scopes
|
||||
nonce: OIDC nonce
|
||||
code_verifier: PKCE code verifier
|
||||
code_verifier: PKCE code verifier (stored hashed server-side)
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
lifetime_seconds: Code lifetime in seconds (default 10 minutes)
|
||||
@@ -106,20 +117,7 @@ class OIDCAuthCode(BaseModel):
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude code hash
|
||||
exclude.append("code_hash")
|
||||
exclude.append("code_verifier")
|
||||
for field in ("code_hash", "code_verifier"):
|
||||
if field not in exclude:
|
||||
exclude.append(field)
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from gatehouse_app.models.user import User
|
||||
User.oidc_auth_codes = db.relationship(
|
||||
"OIDCAuthCode", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
OIDCClient.authorization_codes = db.relationship(
|
||||
"OIDCAuthCode", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -17,10 +17,10 @@ class OIDCClient(BaseModel):
|
||||
client_secret_hash = db.Column(db.String(255), nullable=False)
|
||||
|
||||
# OAuth/OIDC configuration
|
||||
redirect_uris = db.Column(db.JSON, nullable=False) # List of allowed redirect URIs
|
||||
grant_types = db.Column(db.JSON, nullable=False) # List of allowed grant types
|
||||
response_types = db.Column(db.JSON, nullable=False) # List of allowed response types
|
||||
scopes = db.Column(db.JSON, nullable=False) # List of allowed scopes
|
||||
redirect_uris = db.Column(db.JSON, nullable=False) # Allowed redirect URIs
|
||||
grant_types = db.Column(db.JSON, nullable=False) # Allowed grant types
|
||||
response_types = db.Column(db.JSON, nullable=False) # Allowed response types
|
||||
scopes = db.Column(db.JSON, nullable=False) # Allowed scopes
|
||||
|
||||
# Client metadata
|
||||
logo_uri = db.Column(db.String(512), nullable=True)
|
||||
@@ -41,6 +41,23 @@ class OIDCClient(BaseModel):
|
||||
# Relationships
|
||||
organization = db.relationship("Organization", back_populates="oidc_clients")
|
||||
|
||||
# OIDC sub-resource relationships (declared here, not monkey-patched elsewhere)
|
||||
authorization_codes = db.relationship(
|
||||
"OIDCAuthCode", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
refresh_tokens = db.relationship(
|
||||
"OIDCRefreshToken", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
oidc_sessions = db.relationship(
|
||||
"OIDCSession", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
token_metadata = db.relationship(
|
||||
"OIDCTokenMetadata", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
audit_logs = db.relationship(
|
||||
"OIDCAuditLog", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCClient."""
|
||||
return f"<OIDCClient {self.name} client_id={self.client_id}>"
|
||||
@@ -48,22 +65,22 @@ class OIDCClient(BaseModel):
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude client secret
|
||||
exclude.append("client_secret_hash")
|
||||
if "client_secret_hash" not in exclude:
|
||||
exclude.append("client_secret_hash")
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
def has_grant_type(self, grant_type):
|
||||
def has_grant_type(self, grant_type) -> bool:
|
||||
"""Check if client supports a specific grant type."""
|
||||
return grant_type in self.grant_types
|
||||
|
||||
def has_response_type(self, response_type):
|
||||
def has_response_type(self, response_type) -> bool:
|
||||
"""Check if client supports a specific response type."""
|
||||
return response_type in self.response_types
|
||||
|
||||
def is_redirect_uri_allowed(self, redirect_uri):
|
||||
def is_redirect_uri_allowed(self, redirect_uri: str) -> bool:
|
||||
"""Check if a redirect URI is allowed for this client."""
|
||||
return redirect_uri in self.redirect_uris
|
||||
|
||||
def has_scope(self, scope):
|
||||
def has_scope(self, scope: str) -> bool:
|
||||
"""Check if client is allowed to request a specific scope."""
|
||||
return scope in self.scopes
|
||||
@@ -0,0 +1,76 @@
|
||||
"""OIDC JWKS Key model for persisting signing keys."""
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OidcJwksKey(BaseModel):
|
||||
"""OIDC JWKS Key model for persisting JSON Web Key Set signing keys.
|
||||
|
||||
Stores RSA/ECDSA key pairs used for signing OIDC tokens. Multiple keys can
|
||||
be stored to support key rotation scenarios.
|
||||
|
||||
Attributes:
|
||||
kid: Unique key ID used in JWT ``kid`` header
|
||||
key_type: Type of key (e.g., "RSA", "EC")
|
||||
private_key: PEM-encoded private key (never exposed in API responses)
|
||||
public_key: PEM-encoded public key
|
||||
algorithm: Signing algorithm (e.g., "RS256", "ES256")
|
||||
is_active: Whether this key is currently used for signing/verification
|
||||
is_primary: Whether this is the primary signing key
|
||||
expires_at: Optional expiry for key rotation enforcement
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_jwks_keys"
|
||||
|
||||
# Override the default UUID id with integer primary key for JWKS key sets
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
|
||||
expires_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Key identification and type
|
||||
kid = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
key_type = db.Column(db.String(50), nullable=False) # e.g., "RSA", "EC"
|
||||
algorithm = db.Column(db.String(50), nullable=False) # e.g., "RS256", "ES256"
|
||||
|
||||
# Key material (PEM-encoded) — private_key must never be returned by API
|
||||
private_key = db.Column(db.Text, nullable=False)
|
||||
public_key = db.Column(db.Text, nullable=False)
|
||||
|
||||
# Key status
|
||||
is_active = db.Column(db.Boolean, default=True, nullable=False)
|
||||
is_primary = db.Column(db.Boolean, default=False, nullable=False)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OidcJwksKey."""
|
||||
return (
|
||||
f"<OidcJwksKey kid={self.kid} "
|
||||
f"key_type={self.key_type} algorithm={self.algorithm}>"
|
||||
)
|
||||
|
||||
def to_dict(self, exclude_private_key: bool = True):
|
||||
"""Convert model to dictionary.
|
||||
|
||||
Args:
|
||||
exclude_private_key: If True (default), excludes the private key.
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the model
|
||||
"""
|
||||
exclude = ["private_key"] if exclude_private_key else []
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
@classmethod
|
||||
def get_active_keys(cls) -> list:
|
||||
"""Get all active keys for signing operations."""
|
||||
return cls.query.filter_by(is_active=True).all()
|
||||
|
||||
@classmethod
|
||||
def get_primary_key(cls) -> "OidcJwksKey | None":
|
||||
"""Get the primary signing key."""
|
||||
return cls.query.filter_by(is_primary=True).first()
|
||||
|
||||
@classmethod
|
||||
def get_key_by_kid(cls, kid: str) -> "OidcJwksKey | None":
|
||||
"""Get an active key by its key ID."""
|
||||
return cls.query.filter_by(kid=kid, is_active=True).first()
|
||||
+33
-41
@@ -1,5 +1,5 @@
|
||||
"""OIDC Refresh Token model for token rotation."""
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
@@ -8,7 +8,8 @@ class OIDCRefreshToken(BaseModel):
|
||||
"""OIDC Refresh Token model for token refresh and rotation.
|
||||
|
||||
Refresh tokens are long-lived credentials used to obtain new access tokens.
|
||||
They support token rotation for enhanced security.
|
||||
They support token rotation for enhanced security — each use invalidates
|
||||
the old token and issues a new one.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_refresh_tokens"
|
||||
@@ -21,16 +22,14 @@ class OIDCRefreshToken(BaseModel):
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# Token (hashed for security)
|
||||
# Token (hashed for security — never store plaintext refresh tokens)
|
||||
token_hash = db.Column(db.String(255), nullable=False, unique=True, index=True)
|
||||
|
||||
# Associated access token ID (stores JWT JTI string — no FK to sessions)
|
||||
access_token_id = db.Column(
|
||||
db.String(255), nullable=True, index=True
|
||||
)
|
||||
# Associated access token JTI (no FK — stored as string for lightweight lookup)
|
||||
access_token_id = db.Column(db.String(255), nullable=True, index=True)
|
||||
|
||||
# Token scope
|
||||
scope = db.Column(db.JSON, nullable=True) # Granted scopes
|
||||
scope = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# Timing
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
@@ -40,7 +39,7 @@ class OIDCRefreshToken(BaseModel):
|
||||
revoked_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Token rotation metadata
|
||||
previous_token_hash = db.Column(db.String(255), nullable=True) # For rotation
|
||||
previous_token_hash = db.Column(db.String(255), nullable=True)
|
||||
rotation_count = db.Column(db.Integer, default=0, nullable=False)
|
||||
|
||||
# Request metadata
|
||||
@@ -53,25 +52,27 @@ class OIDCRefreshToken(BaseModel):
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCRefreshToken."""
|
||||
return f"<OIDCRefreshToken client_id={self.client_id} user_id={self.user_id} revoked={self.is_revoked()}>"
|
||||
return (
|
||||
f"<OIDCRefreshToken client_id={self.client_id} "
|
||||
f"user_id={self.user_id} revoked={self.is_revoked()}>"
|
||||
)
|
||||
|
||||
def is_expired(self):
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the refresh token has expired."""
|
||||
# Handle both timezone-aware and timezone-naive expires_at values
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return datetime.now(timezone.utc) > expires_at
|
||||
|
||||
def is_revoked(self):
|
||||
def is_revoked(self) -> bool:
|
||||
"""Check if the refresh token has been revoked."""
|
||||
return self.revoked_at is not None
|
||||
|
||||
def is_valid(self):
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the refresh token is valid for use."""
|
||||
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
|
||||
|
||||
def revoke(self, reason=None):
|
||||
def revoke(self, reason: str = None) -> None:
|
||||
"""Revoke the refresh token.
|
||||
|
||||
Args:
|
||||
@@ -81,8 +82,8 @@ class OIDCRefreshToken(BaseModel):
|
||||
self.revoked_reason = reason
|
||||
db.session.commit()
|
||||
|
||||
def rotate(self, new_token_hash):
|
||||
"""Rotate the refresh token (invalidate old, create new).
|
||||
def rotate(self, new_token_hash: str) -> "OIDCRefreshToken":
|
||||
"""Rotate the refresh token — invalidate the old hash, store the new one.
|
||||
|
||||
Args:
|
||||
new_token_hash: Hash of the new refresh token
|
||||
@@ -90,20 +91,25 @@ class OIDCRefreshToken(BaseModel):
|
||||
Returns:
|
||||
self for chaining
|
||||
"""
|
||||
# Store reference to old token
|
||||
self.previous_token_hash = self.token_hash
|
||||
self.token_hash = new_token_hash
|
||||
self.rotation_count += 1
|
||||
# Extend expiration on rotation
|
||||
from datetime import timedelta
|
||||
self.expires_at = datetime.now(timezone.utc) + timedelta(days=30)
|
||||
db.session.commit()
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def create_token(cls, client_id, user_id, token_hash, scope=None,
|
||||
access_token_id=None, ip_address=None, user_agent=None,
|
||||
lifetime_seconds=2592000):
|
||||
def create_token(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
token_hash: str,
|
||||
scope=None,
|
||||
access_token_id: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
lifetime_seconds: int = 2592000,
|
||||
) -> "OIDCRefreshToken":
|
||||
"""Create a new refresh token.
|
||||
|
||||
Args:
|
||||
@@ -111,7 +117,7 @@ class OIDCRefreshToken(BaseModel):
|
||||
user_id: The user ID
|
||||
token_hash: Hashed refresh token
|
||||
scope: Granted scopes
|
||||
access_token_id: Associated access token ID
|
||||
access_token_id: Associated access token JTI
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
lifetime_seconds: Token lifetime in seconds (default 30 days)
|
||||
@@ -119,7 +125,6 @@ class OIDCRefreshToken(BaseModel):
|
||||
Returns:
|
||||
OIDCRefreshToken instance
|
||||
"""
|
||||
from datetime import timedelta
|
||||
token = cls(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
@@ -137,20 +142,7 @@ class OIDCRefreshToken(BaseModel):
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude token hashes
|
||||
exclude.append("token_hash")
|
||||
exclude.append("previous_token_hash")
|
||||
for field in ("token_hash", "previous_token_hash"):
|
||||
if field not in exclude:
|
||||
exclude.append(field)
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from gatehouse_app.models.user import User
|
||||
User.oidc_refresh_tokens = db.relationship(
|
||||
"OIDCRefreshToken", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
OIDCClient.refresh_tokens = db.relationship(
|
||||
"OIDCRefreshToken", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -1,5 +1,7 @@
|
||||
"""OIDC Session model for OIDC session tracking."""
|
||||
from datetime import datetime, timezone
|
||||
import hashlib
|
||||
import base64
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
@@ -7,8 +9,8 @@ from gatehouse_app.models.base import BaseModel
|
||||
class OIDCSession(BaseModel):
|
||||
"""OIDC Session model for tracking OIDC authentication sessions.
|
||||
|
||||
This model tracks the state during the OIDC authentication flow,
|
||||
including PKCE parameters and nonce validation.
|
||||
Tracks the state during the OIDC authorization flow, including PKCE
|
||||
parameters and nonce validation.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_sessions"
|
||||
@@ -25,11 +27,11 @@ class OIDCSession(BaseModel):
|
||||
|
||||
# State management
|
||||
state = db.Column(db.String(255), nullable=False, index=True)
|
||||
nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation
|
||||
nonce = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Authorization request parameters
|
||||
redirect_uri = db.Column(db.String(512), nullable=False)
|
||||
scope = db.Column(db.JSON, nullable=True) # Requested scopes
|
||||
scope = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# PKCE parameters
|
||||
code_challenge = db.Column(db.String(255), nullable=True)
|
||||
@@ -45,50 +47,52 @@ class OIDCSession(BaseModel):
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCSession."""
|
||||
return f"<OIDCSession user_id={self.user_id} client_id={self.client_id} state={self.state[:8]}...>"
|
||||
return (
|
||||
f"<OIDCSession user_id={self.user_id} "
|
||||
f"client_id={self.client_id} state={self.state[:8]}...>"
|
||||
)
|
||||
|
||||
def is_expired(self):
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the OIDC session has expired."""
|
||||
return datetime.now(timezone.utc) > self.expires_at
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return datetime.now(timezone.utc) > expires_at
|
||||
|
||||
def is_authenticated(self):
|
||||
def is_authenticated(self) -> bool:
|
||||
"""Check if the user has been authenticated in this session."""
|
||||
return self.authenticated_at is not None
|
||||
|
||||
def mark_authenticated(self):
|
||||
def mark_authenticated(self) -> None:
|
||||
"""Mark the session as authenticated."""
|
||||
self.authenticated_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
def validate_nonce(self, expected_nonce):
|
||||
def validate_nonce(self, expected_nonce: str) -> bool:
|
||||
"""Validate the nonce matches the expected value.
|
||||
|
||||
Args:
|
||||
expected_nonce: The expected nonce value
|
||||
|
||||
Returns:
|
||||
bool: True if nonce matches
|
||||
True if nonce matches
|
||||
"""
|
||||
return self.nonce == expected_nonce
|
||||
|
||||
def validate_code_challenge(self, code_verifier):
|
||||
def validate_code_challenge(self, code_verifier: str) -> bool:
|
||||
"""Validate the code verifier against the stored code challenge.
|
||||
|
||||
Args:
|
||||
code_verifier: The PKCE code verifier
|
||||
|
||||
Returns:
|
||||
bool: True if code challenge is valid
|
||||
True if the challenge is satisfied
|
||||
"""
|
||||
if not self.code_challenge:
|
||||
return False
|
||||
|
||||
if self.code_challenge_method == "S256":
|
||||
import hashlib
|
||||
import base64
|
||||
# SHA256 hash of code_verifier
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
# Base64 URL encode without padding
|
||||
expected = base64.urlsafe_b64encode(digest).decode().rstrip("=")
|
||||
return self.code_challenge == expected
|
||||
elif self.code_challenge_method == "plain":
|
||||
@@ -97,9 +101,18 @@ class OIDCSession(BaseModel):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def create_session(cls, user_id, client_id, state, redirect_uri, scope=None,
|
||||
nonce=None, code_challenge=None, code_challenge_method=None,
|
||||
lifetime_seconds=600):
|
||||
def create_session(
|
||||
cls,
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
state: str,
|
||||
redirect_uri: str,
|
||||
scope=None,
|
||||
nonce: str = None,
|
||||
code_challenge: str = None,
|
||||
code_challenge_method: str = None,
|
||||
lifetime_seconds: int = 600,
|
||||
) -> "OIDCSession":
|
||||
"""Create a new OIDC session.
|
||||
|
||||
Args:
|
||||
@@ -116,7 +129,6 @@ class OIDCSession(BaseModel):
|
||||
Returns:
|
||||
OIDCSession instance
|
||||
"""
|
||||
from datetime import timedelta
|
||||
session = cls(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
@@ -133,7 +145,7 @@ class OIDCSession(BaseModel):
|
||||
return session
|
||||
|
||||
@classmethod
|
||||
def get_by_state(cls, state):
|
||||
def get_by_state(cls, state: str) -> "OIDCSession | None":
|
||||
"""Get a session by state parameter.
|
||||
|
||||
Args:
|
||||
@@ -147,16 +159,3 @@ class OIDCSession(BaseModel):
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from gatehouse_app.models.user import User
|
||||
User.oidc_sessions = db.relationship(
|
||||
"OIDCSession", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
OIDCClient.oidc_sessions = db.relationship(
|
||||
"OIDCSession", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
+49
-45
@@ -8,13 +8,14 @@ from gatehouse_app.models.base import BaseModel
|
||||
class OIDCTokenMetadata(BaseModel):
|
||||
"""OIDC Token Metadata model for tracking issued tokens.
|
||||
|
||||
This model stores metadata about issued tokens (access tokens, refresh tokens, ID tokens)
|
||||
for the purpose of token revocation. The id field matches the JTI (JWT ID) claim.
|
||||
Stores metadata about issued tokens (access, refresh, ID) for revocation.
|
||||
The ``id`` field on this model intentionally overrides the BaseModel UUID
|
||||
to store the JWT JTI directly as the primary key for O(1) revocation checks.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_token_metadata"
|
||||
|
||||
# Token identifier (matches JTI in JWT)
|
||||
# Primary key = JTI so revocation lookups are always a PK scan
|
||||
id = db.Column(
|
||||
db.String(36), primary_key=True, default=lambda: str(uuid.uuid4())
|
||||
)
|
||||
@@ -27,11 +28,11 @@ class OIDCTokenMetadata(BaseModel):
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# Token type
|
||||
token_type = db.Column(db.String(50), nullable=False) # "access_token", "refresh_token", "id_token"
|
||||
# Token type: "access_token", "refresh_token", or "id_token"
|
||||
token_type = db.Column(db.String(50), nullable=False)
|
||||
|
||||
# Token identifier for revocation lookup
|
||||
token_jti = db.Column(db.String(255), nullable=False, index=True) # JWT ID claim
|
||||
# JWT ID claim (indexed for fast lookup when id != jti)
|
||||
token_jti = db.Column(db.String(255), nullable=False, index=True)
|
||||
|
||||
# Timing
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
@@ -46,25 +47,27 @@ class OIDCTokenMetadata(BaseModel):
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCTokenMetadata."""
|
||||
return f"<OIDCTokenMetadata jti={self.token_jti[:8]}... type={self.token_type} revoked={self.is_revoked()}>"
|
||||
return (
|
||||
f"<OIDCTokenMetadata jti={self.token_jti[:8]}... "
|
||||
f"type={self.token_type} revoked={self.is_revoked()}>"
|
||||
)
|
||||
|
||||
def is_expired(self):
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the token has expired."""
|
||||
# Handle both timezone-aware and timezone-naive expires_at values
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return datetime.now(timezone.utc) > expires_at
|
||||
|
||||
def is_revoked(self):
|
||||
def is_revoked(self) -> bool:
|
||||
"""Check if the token has been revoked."""
|
||||
return self.revoked_at is not None
|
||||
|
||||
def is_valid(self):
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the token is valid (not expired and not revoked)."""
|
||||
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
|
||||
|
||||
def revoke(self, reason=None):
|
||||
def revoke(self, reason: str = None) -> None:
|
||||
"""Revoke the token.
|
||||
|
||||
Args:
|
||||
@@ -75,8 +78,16 @@ class OIDCTokenMetadata(BaseModel):
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def create_metadata(cls, client_id, user_id, token_type, token_jti,
|
||||
expires_at, ip_address=None, user_agent=None):
|
||||
def create_metadata(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
token_type: str,
|
||||
token_jti: str,
|
||||
expires_at,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
) -> "OIDCTokenMetadata":
|
||||
"""Create token metadata for tracking.
|
||||
|
||||
Args:
|
||||
@@ -85,8 +96,8 @@ class OIDCTokenMetadata(BaseModel):
|
||||
token_type: Type of token ("access_token", "refresh_token", "id_token")
|
||||
token_jti: JWT ID claim
|
||||
expires_at: Token expiration datetime
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
ip_address: Client IP address (unused column, kept for API compat)
|
||||
user_agent: Client user agent (unused column, kept for API compat)
|
||||
|
||||
Returns:
|
||||
OIDCTokenMetadata instance
|
||||
@@ -104,7 +115,7 @@ class OIDCTokenMetadata(BaseModel):
|
||||
return metadata
|
||||
|
||||
@classmethod
|
||||
def get_by_jti(cls, token_jti):
|
||||
def get_by_jti(cls, token_jti: str) -> "OIDCTokenMetadata | None":
|
||||
"""Get token metadata by JWT ID.
|
||||
|
||||
Args:
|
||||
@@ -116,7 +127,7 @@ class OIDCTokenMetadata(BaseModel):
|
||||
return cls.query.filter_by(token_jti=token_jti, deleted_at=None).first()
|
||||
|
||||
@classmethod
|
||||
def revoke_by_jti(cls, token_jti, reason=None):
|
||||
def revoke_by_jti(cls, token_jti: str, reason: str = None) -> bool:
|
||||
"""Revoke a token by its JWT ID.
|
||||
|
||||
Args:
|
||||
@@ -124,7 +135,7 @@ class OIDCTokenMetadata(BaseModel):
|
||||
reason: Optional revocation reason
|
||||
|
||||
Returns:
|
||||
bool: True if token was found and revoked
|
||||
True if token was found and revoked, False otherwise
|
||||
"""
|
||||
metadata = cls.get_by_jti(token_jti)
|
||||
if metadata:
|
||||
@@ -133,47 +144,53 @@ class OIDCTokenMetadata(BaseModel):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def revoke_all_for_user(cls, user_id, client_id=None, reason=None):
|
||||
def revoke_all_for_user(
|
||||
cls, user_id: str, client_id: str = None, reason: str = None
|
||||
) -> int:
|
||||
"""Revoke all tokens for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
client_id: Optional client ID to filter by
|
||||
client_id: Optional client ID filter
|
||||
reason: Optional revocation reason
|
||||
|
||||
Returns:
|
||||
int: Number of tokens revoked
|
||||
Number of tokens revoked
|
||||
"""
|
||||
query = cls.query.filter_by(user_id=user_id, deleted_at=None)
|
||||
query = cls.query.filter_by(user_id=user_id, deleted_at=None).filter(
|
||||
cls.revoked_at.is_(None)
|
||||
)
|
||||
if client_id:
|
||||
query = query.filter_by(client_id=client_id)
|
||||
|
||||
tokens = query.filter(cls.revoked_at == None).all()
|
||||
count = 0
|
||||
for token in tokens:
|
||||
for token in query.all():
|
||||
token.revoke(reason)
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@classmethod
|
||||
def revoke_all_for_client(cls, client_id, user_id=None, reason=None):
|
||||
def revoke_all_for_client(
|
||||
cls, client_id: str, user_id: str = None, reason: str = None
|
||||
) -> int:
|
||||
"""Revoke all tokens for a client.
|
||||
|
||||
Args:
|
||||
client_id: The client ID
|
||||
user_id: Optional user ID to filter by
|
||||
user_id: Optional user ID filter
|
||||
reason: Optional revocation reason
|
||||
|
||||
Returns:
|
||||
int: Number of tokens revoked
|
||||
Number of tokens revoked
|
||||
"""
|
||||
query = cls.query.filter_by(client_id=client_id, deleted_at=None)
|
||||
query = cls.query.filter_by(client_id=client_id, deleted_at=None).filter(
|
||||
cls.revoked_at.is_(None)
|
||||
)
|
||||
if user_id:
|
||||
query = query.filter_by(user_id=user_id)
|
||||
|
||||
tokens = query.filter(cls.revoked_at == None).all()
|
||||
count = 0
|
||||
for token in tokens:
|
||||
for token in query.all():
|
||||
token.revoke(reason)
|
||||
count += 1
|
||||
return count
|
||||
@@ -181,16 +198,3 @@ class OIDCTokenMetadata(BaseModel):
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from gatehouse_app.models.user import User
|
||||
User.oidc_token_metadata = db.relationship(
|
||||
"OIDCTokenMetadata", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
OIDCClient.token_metadata = db.relationship(
|
||||
"OIDCTokenMetadata", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -1,77 +0,0 @@
|
||||
"""OIDC JWKS Key model for persisting signing keys."""
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OidcJwksKey(BaseModel):
|
||||
"""
|
||||
OIDC JWKS Key model for persisting JSON Web Key Set signing keys.
|
||||
|
||||
This model stores RSA/ECDSA key pairs used for signing OIDC tokens.
|
||||
Multiple keys can be stored to support key rotation scenarios.
|
||||
|
||||
Attributes:
|
||||
id: Integer primary key
|
||||
kid: Unique key ID used in JWT "kid" header
|
||||
key_type: Type of key (e.g., "RSA", "EC")
|
||||
private_key: PEM-encoded private key
|
||||
public_key: PEM-encoded public key
|
||||
algorithm: Signing algorithm (e.g., "RS256", "ES256")
|
||||
created_at: When the key was created
|
||||
is_active: Whether this key is currently active for signing
|
||||
is_primary: Whether this is the primary signing key
|
||||
expires_at: ...
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_jwks_keys"
|
||||
|
||||
# Override the default UUID id with integer primary key
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
|
||||
expires_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Key identification and type
|
||||
kid = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
key_type = db.Column(db.String(50), nullable=False) # e.g., "RSA", "EC"
|
||||
algorithm = db.Column(db.String(50), nullable=False) # e.g., "RS256", "ES256"
|
||||
|
||||
# Key material (PEM-encoded)
|
||||
private_key = db.Column(db.Text, nullable=False)
|
||||
public_key = db.Column(db.Text, nullable=False)
|
||||
|
||||
# Key status
|
||||
is_active = db.Column(db.Boolean, default=True, nullable=False)
|
||||
is_primary = db.Column(db.Boolean, default=False, nullable=False)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OidcJwksKey."""
|
||||
return f"<OidcJwksKey kid={self.kid} key_type={self.key_type} algorithm={self.algorithm}>"
|
||||
|
||||
def to_dict(self, exclude_private_key=True):
|
||||
"""
|
||||
Convert model to dictionary.
|
||||
|
||||
Args:
|
||||
exclude_private_key: If True, excludes the private key from output
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the model
|
||||
"""
|
||||
exclude = ["private_key"] if exclude_private_key else []
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
@classmethod
|
||||
def get_active_keys(cls):
|
||||
"""Get all active keys for signing operations."""
|
||||
return cls.query.filter(cls.is_active == True).all()
|
||||
|
||||
@classmethod
|
||||
def get_primary_key(cls):
|
||||
"""Get the primary signing key."""
|
||||
return cls.query.filter(cls.is_primary == True).first()
|
||||
|
||||
@classmethod
|
||||
def get_key_by_kid(cls, kid):
|
||||
"""Get a key by its key ID."""
|
||||
return cls.query.filter(cls.kid == kid, cls.is_active == True).first()
|
||||
@@ -0,0 +1,27 @@
|
||||
"""Organization subpackage."""
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.organization.department import (
|
||||
Department,
|
||||
DepartmentMembership,
|
||||
DepartmentPrincipal,
|
||||
)
|
||||
from gatehouse_app.models.organization.department_cert_policy import (
|
||||
DepartmentCertPolicy,
|
||||
STANDARD_EXTENSIONS,
|
||||
)
|
||||
from gatehouse_app.models.organization.principal import Principal, PrincipalMembership
|
||||
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
|
||||
|
||||
__all__ = [
|
||||
"Organization",
|
||||
"OrganizationMember",
|
||||
"Department",
|
||||
"DepartmentMembership",
|
||||
"DepartmentPrincipal",
|
||||
"DepartmentCertPolicy",
|
||||
"STANDARD_EXTENSIONS",
|
||||
"Principal",
|
||||
"PrincipalMembership",
|
||||
"OrgInviteToken",
|
||||
]
|
||||
+38
-36
@@ -1,14 +1,15 @@
|
||||
"""Department, DepartmentMembership, and DepartmentPrincipal models."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class Department(BaseModel):
|
||||
"""Department model representing an organizational unit for SSH access control.
|
||||
|
||||
|
||||
Departments are used to group users and assign SSH principals (access levels)
|
||||
to them. A user can be a member of multiple departments, and each department
|
||||
can have multiple principals assigned.
|
||||
|
||||
|
||||
Example:
|
||||
- Department: "Engineering"
|
||||
- Members: user1@example.com, user2@example.com
|
||||
@@ -39,12 +40,15 @@ class Department(BaseModel):
|
||||
back_populates="department",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
cert_policy = db.relationship(
|
||||
"DepartmentCertPolicy",
|
||||
back_populates="department",
|
||||
uselist=False,
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Unique constraint: department name per organization
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint(
|
||||
"organization_id", "name", name="uix_org_dept_name"
|
||||
),
|
||||
db.UniqueConstraint("organization_id", "name", name="uix_org_dept_name"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
@@ -55,47 +59,46 @@ class Department(BaseModel):
|
||||
"""Convert department to dictionary."""
|
||||
exclude = exclude or []
|
||||
data = super().to_dict(exclude=exclude)
|
||||
|
||||
# Add member count
|
||||
data["member_count"] = len([m for m in self.memberships if m.deleted_at is None])
|
||||
|
||||
# Add principal count
|
||||
data["principal_count"] = len([p for p in self.principal_links if p.deleted_at is None])
|
||||
|
||||
return data
|
||||
|
||||
def get_members(self, active_only=True):
|
||||
def get_members(self, active_only: bool = True):
|
||||
"""Get all members of this department.
|
||||
|
||||
|
||||
Args:
|
||||
active_only: If True, exclude soft-deleted members
|
||||
|
||||
|
||||
Returns:
|
||||
List of DepartmentMembership objects
|
||||
"""
|
||||
if active_only:
|
||||
return [m for m in self.memberships if m.deleted_at is None]
|
||||
return self.memberships
|
||||
return list(self.memberships)
|
||||
|
||||
def get_principals(self, active_only=True):
|
||||
def get_principals(self, active_only: bool = True):
|
||||
"""Get all principals assigned to this department.
|
||||
|
||||
|
||||
Args:
|
||||
active_only: If True, exclude soft-deleted principals
|
||||
|
||||
|
||||
Returns:
|
||||
List of Principal objects via DepartmentPrincipal
|
||||
"""
|
||||
if active_only:
|
||||
return [p.principal for p in self.principal_links if p.deleted_at is None and p.principal.deleted_at is None]
|
||||
return [
|
||||
p.principal
|
||||
for p in self.principal_links
|
||||
if p.deleted_at is None and p.principal.deleted_at is None
|
||||
]
|
||||
return [p.principal for p in self.principal_links]
|
||||
|
||||
def is_member(self, user_id):
|
||||
def is_member(self, user_id: str) -> bool:
|
||||
"""Check if a user is a member of this department.
|
||||
|
||||
|
||||
Args:
|
||||
user_id: ID of the user to check
|
||||
|
||||
|
||||
Returns:
|
||||
True if user is an active member, False otherwise
|
||||
"""
|
||||
@@ -108,14 +111,14 @@ class Department(BaseModel):
|
||||
is not None
|
||||
)
|
||||
|
||||
def get_member_count(self):
|
||||
def get_member_count(self) -> int:
|
||||
"""Get the count of active members in this department."""
|
||||
return len(self.get_members(active_only=True))
|
||||
|
||||
|
||||
class DepartmentMembership(BaseModel):
|
||||
"""Department membership model representing user membership in a department.
|
||||
|
||||
|
||||
When a user is added to a department, they become eligible for SSH principals
|
||||
assigned to that department.
|
||||
"""
|
||||
@@ -139,24 +142,23 @@ class DepartmentMembership(BaseModel):
|
||||
user = db.relationship("User", back_populates="department_memberships")
|
||||
department = db.relationship("Department", back_populates="memberships")
|
||||
|
||||
# Unique constraint: user can only be member of a department once
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint(
|
||||
"user_id", "department_id", name="uix_user_dept"
|
||||
),
|
||||
db.UniqueConstraint("user_id", "department_id", name="uix_user_dept"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of DepartmentMembership."""
|
||||
return f"<DepartmentMembership user_id={self.user_id} dept_id={self.department_id}>"
|
||||
return (
|
||||
f"<DepartmentMembership user_id={self.user_id} dept_id={self.department_id}>"
|
||||
)
|
||||
|
||||
|
||||
class DepartmentPrincipal(BaseModel):
|
||||
"""Department principal assignment model.
|
||||
|
||||
|
||||
Represents the assignment of principals to departments. All members of a
|
||||
department get access to its assigned principals (transitively).
|
||||
|
||||
|
||||
Example:
|
||||
- Department: "Engineering"
|
||||
- Principal: "eng-prod-servers"
|
||||
@@ -182,13 +184,13 @@ class DepartmentPrincipal(BaseModel):
|
||||
department = db.relationship("Department", back_populates="principal_links")
|
||||
principal = db.relationship("Principal", back_populates="department_links")
|
||||
|
||||
# Unique constraint: principal can only be assigned to a department once
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint(
|
||||
"department_id", "principal_id", name="uix_dept_principal"
|
||||
),
|
||||
db.UniqueConstraint("department_id", "principal_id", name="uix_dept_principal"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of DepartmentPrincipal."""
|
||||
return f"<DepartmentPrincipal dept_id={self.department_id} principal_id={self.principal_id}>"
|
||||
return (
|
||||
f"<DepartmentPrincipal dept_id={self.department_id} "
|
||||
f"principal_id={self.principal_id}>"
|
||||
)
|
||||
@@ -0,0 +1,76 @@
|
||||
"""DepartmentCertPolicy — per-department SSH certificate issuance rules."""
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
# Standard SSH certificate extensions
|
||||
STANDARD_EXTENSIONS = [
|
||||
"permit-X11-forwarding",
|
||||
"permit-agent-forwarding",
|
||||
"permit-pty",
|
||||
"permit-port-forwarding",
|
||||
"permit-user-rc",
|
||||
]
|
||||
|
||||
|
||||
class DepartmentCertPolicy(BaseModel):
|
||||
"""SSH certificate policy for a department.
|
||||
|
||||
Controls:
|
||||
- Whether members may choose their own expiry date (up to ``max_expiry_hours``)
|
||||
- Default expiry hours when the user doesn't (or can't) pick
|
||||
- Maximum expiry hours (hard ceiling, even for admins signing on behalf)
|
||||
- Which SSH certificate extensions are granted to members of this department
|
||||
- Any custom extensions the admin wants to add beyond the standard five
|
||||
|
||||
Inherits ``id``, ``created_at``, ``updated_at``, and ``deleted_at`` from
|
||||
:class:`BaseModel` so soft-delete and the standard timestamp behaviour are
|
||||
consistent with every other model in the application.
|
||||
"""
|
||||
|
||||
__tablename__ = "department_cert_policies"
|
||||
|
||||
department_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("departments.id"),
|
||||
nullable=False,
|
||||
unique=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Expiry control
|
||||
allow_user_expiry = db.Column(db.Boolean, nullable=False, default=False)
|
||||
default_expiry_hours = db.Column(db.Integer, nullable=False, default=1)
|
||||
max_expiry_hours = db.Column(db.Integer, nullable=False, default=24)
|
||||
|
||||
# Extensions — list of extension name strings
|
||||
allowed_extensions = db.Column(
|
||||
db.JSON,
|
||||
nullable=False,
|
||||
default=lambda: list(STANDARD_EXTENSIONS),
|
||||
)
|
||||
# Admin-defined extras beyond the standard five
|
||||
custom_extensions = db.Column(db.JSON, nullable=False, default=list)
|
||||
|
||||
# Relationship back to department
|
||||
department = db.relationship("Department", back_populates="cert_policy", uselist=False)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"<DepartmentCertPolicy dept={self.department_id} "
|
||||
f"allow_user_expiry={self.allow_user_expiry}>"
|
||||
)
|
||||
|
||||
def all_extensions(self) -> list:
|
||||
"""Return the full list of enabled extensions (allowed + custom)."""
|
||||
return list((self.allowed_extensions or []) + (self.custom_extensions or []))
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
exclude = exclude or []
|
||||
data = super().to_dict(exclude=exclude)
|
||||
# Augment with computed / convenience fields not in the base columns
|
||||
data["all_extensions"] = self.all_extensions()
|
||||
data["standard_extensions"] = STANDARD_EXTENSIONS
|
||||
return data
|
||||
@@ -0,0 +1,77 @@
|
||||
"""Organization invite token model."""
|
||||
import secrets
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OrgInviteToken(BaseModel):
|
||||
"""Token-based invitation to join an organization."""
|
||||
|
||||
__tablename__ = "org_invite_tokens"
|
||||
|
||||
organization_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("organizations.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
invited_by_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
email = db.Column(db.String(255), nullable=False, index=True)
|
||||
role = db.Column(db.String(64), nullable=False, default="member")
|
||||
token = db.Column(db.String(128), unique=True, nullable=False, index=True)
|
||||
expires_at = db.Column(db.DateTime, nullable=False)
|
||||
accepted_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
organization = db.relationship(
|
||||
"Organization",
|
||||
backref=db.backref("invite_tokens", cascade="all, delete-orphan"),
|
||||
)
|
||||
invited_by = db.relationship("User", foreign_keys=[invited_by_id])
|
||||
|
||||
@classmethod
|
||||
def generate(
|
||||
cls,
|
||||
organization_id: str,
|
||||
email: str,
|
||||
role: str = "member",
|
||||
invited_by_id: str = None,
|
||||
ttl_days: int = 7,
|
||||
) -> "OrgInviteToken":
|
||||
"""Create a new invite token for an organization."""
|
||||
token_value = secrets.token_urlsafe(48)
|
||||
instance = cls(
|
||||
organization_id=organization_id,
|
||||
email=email.lower(),
|
||||
role=role,
|
||||
invited_by_id=invited_by_id,
|
||||
token=token_value,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=ttl_days),
|
||||
)
|
||||
db.session.add(instance)
|
||||
db.session.commit()
|
||||
return instance
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
"""Return True if the token is unused and not expired."""
|
||||
if self.accepted_at is not None:
|
||||
return False
|
||||
now = datetime.now(timezone.utc)
|
||||
expires = self.expires_at
|
||||
if expires.tzinfo is None:
|
||||
expires = expires.replace(tzinfo=timezone.utc)
|
||||
return now < expires
|
||||
|
||||
def accept(self) -> None:
|
||||
"""Mark the invite as accepted."""
|
||||
self.accepted_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<OrgInviteToken org={self.organization_id} email={self.email}>"
|
||||
+2
-2
@@ -61,9 +61,9 @@ class Organization(BaseModel):
|
||||
return member.user
|
||||
return None
|
||||
|
||||
def is_member(self, user_id):
|
||||
def is_member(self, user_id: str) -> bool:
|
||||
"""Check if a user is a member of the organization."""
|
||||
from gatehouse_app.models.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
|
||||
return (
|
||||
OrganizationMember.query.filter_by(
|
||||
+12
-8
@@ -1,4 +1,4 @@
|
||||
"""Organization member model."""
|
||||
"""Organization member model."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
@@ -21,31 +21,35 @@ class OrganizationMember(BaseModel):
|
||||
joined_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship("User", foreign_keys=[user_id], back_populates="organization_memberships")
|
||||
user = db.relationship(
|
||||
"User", foreign_keys=[user_id], back_populates="organization_memberships"
|
||||
)
|
||||
organization = db.relationship("Organization", back_populates="members")
|
||||
invited_by = db.relationship("User", foreign_keys=[invited_by_id])
|
||||
|
||||
# Unique constraint to prevent duplicate memberships
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint("user_id", "organization_id", name="uix_user_org"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OrganizationMember."""
|
||||
return f"<OrganizationMember user_id={self.user_id} org_id={self.organization_id} role={self.role}>"
|
||||
return (
|
||||
f"<OrganizationMember user_id={self.user_id} "
|
||||
f"org_id={self.organization_id} role={self.role}>"
|
||||
)
|
||||
|
||||
def is_owner(self):
|
||||
def is_owner(self) -> bool:
|
||||
"""Check if member is an owner."""
|
||||
return self.role == OrganizationRole.OWNER
|
||||
|
||||
def is_admin(self):
|
||||
def is_admin(self) -> bool:
|
||||
"""Check if member is an admin or owner."""
|
||||
return self.role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
||||
|
||||
def can_manage_members(self):
|
||||
def can_manage_members(self) -> bool:
|
||||
"""Check if member can manage other members."""
|
||||
return self.is_admin()
|
||||
|
||||
def can_delete_organization(self):
|
||||
def can_delete_organization(self) -> bool:
|
||||
"""Check if member can delete the organization."""
|
||||
return self.is_owner()
|
||||
+61
-66
@@ -1,14 +1,16 @@
|
||||
"""Principal and PrincipalMembership models."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class Principal(BaseModel):
|
||||
"""Principal model representing an SSH principal (access level/role).
|
||||
|
||||
|
||||
In SSH CA terminology, a principal is a string like "eng-prod-servers" or
|
||||
"devops-admins" that represents a set of machines or access level. Users
|
||||
can be granted access to principals, either directly or via department membership.
|
||||
|
||||
can be granted access to principals, either directly or via department
|
||||
membership.
|
||||
|
||||
Example:
|
||||
- Principal: "eng-prod-servers"
|
||||
- Users with this principal can SSH to prod servers
|
||||
@@ -39,11 +41,8 @@ class Principal(BaseModel):
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Unique constraint: principal name per organization
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint(
|
||||
"organization_id", "name", name="uix_org_principal_name"
|
||||
),
|
||||
db.UniqueConstraint("organization_id", "name", name="uix_org_principal_name"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
@@ -54,79 +53,80 @@ class Principal(BaseModel):
|
||||
"""Convert principal to dictionary."""
|
||||
exclude = exclude or []
|
||||
data = super().to_dict(exclude=exclude)
|
||||
|
||||
# Add member count
|
||||
data["direct_member_count"] = len([m for m in self.memberships if m.deleted_at is None])
|
||||
|
||||
# Add department count
|
||||
data["department_count"] = len([d for d in self.department_links if d.deleted_at is None])
|
||||
|
||||
data["direct_member_count"] = len(
|
||||
[m for m in self.memberships if m.deleted_at is None]
|
||||
)
|
||||
data["department_count"] = len(
|
||||
[d for d in self.department_links if d.deleted_at is None]
|
||||
)
|
||||
return data
|
||||
|
||||
def get_members(self, active_only=True):
|
||||
def get_members(self, active_only: bool = True):
|
||||
"""Get all users who are directly assigned to this principal.
|
||||
|
||||
|
||||
Does NOT include users who get access via department membership.
|
||||
|
||||
|
||||
Args:
|
||||
active_only: If True, exclude soft-deleted members
|
||||
|
||||
|
||||
Returns:
|
||||
List of PrincipalMembership objects
|
||||
"""
|
||||
if active_only:
|
||||
return [m for m in self.memberships if m.deleted_at is None]
|
||||
return self.memberships
|
||||
return list(self.memberships)
|
||||
|
||||
def get_all_members(self, active_only=True):
|
||||
def get_all_members(self, active_only: bool = True):
|
||||
"""Get all users who have access to this principal.
|
||||
|
||||
|
||||
Includes both direct members and users via department membership.
|
||||
|
||||
|
||||
Args:
|
||||
active_only: If True, exclude soft-deleted members
|
||||
|
||||
|
||||
Returns:
|
||||
Set of User objects with access to this principal
|
||||
"""
|
||||
from gatehouse_app.models.user import User
|
||||
|
||||
all_users = set()
|
||||
|
||||
# Add direct members
|
||||
all_users: set = set()
|
||||
|
||||
# Direct members
|
||||
for membership in self.get_members(active_only=active_only):
|
||||
if membership.user.deleted_at is None or not active_only:
|
||||
if not active_only or membership.user.deleted_at is None:
|
||||
all_users.add(membership.user)
|
||||
|
||||
# Add members via department assignment
|
||||
|
||||
# Members via department assignment
|
||||
for dept_link in self.department_links:
|
||||
if dept_link.deleted_at is None or not active_only:
|
||||
for dept_member in dept_link.department.get_members(active_only=active_only):
|
||||
if dept_member.user.deleted_at is None or not active_only:
|
||||
if not active_only or dept_member.user.deleted_at is None:
|
||||
all_users.add(dept_member.user)
|
||||
|
||||
|
||||
return all_users
|
||||
|
||||
def get_departments(self, active_only=True):
|
||||
def get_departments(self, active_only: bool = True):
|
||||
"""Get all departments this principal is assigned to.
|
||||
|
||||
|
||||
Args:
|
||||
active_only: If True, exclude soft-deleted departments
|
||||
|
||||
|
||||
Returns:
|
||||
List of Department objects
|
||||
"""
|
||||
if active_only:
|
||||
return [d.department for d in self.department_links if d.deleted_at is None and d.department.deleted_at is None]
|
||||
return [
|
||||
d.department
|
||||
for d in self.department_links
|
||||
if d.deleted_at is None and d.department.deleted_at is None
|
||||
]
|
||||
return [d.department for d in self.department_links]
|
||||
|
||||
def is_member(self, user_id, include_via_department=True):
|
||||
def is_member(self, user_id: str, include_via_department: bool = True) -> bool:
|
||||
"""Check if a user has access to this principal.
|
||||
|
||||
|
||||
Args:
|
||||
user_id: ID of the user to check
|
||||
include_via_department: If True, check department memberships too
|
||||
|
||||
|
||||
Returns:
|
||||
True if user has access to this principal
|
||||
"""
|
||||
@@ -139,54 +139,49 @@ class Principal(BaseModel):
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
|
||||
if has_direct:
|
||||
return True
|
||||
|
||||
# Check department membership if requested
|
||||
|
||||
if not include_via_department:
|
||||
return False
|
||||
|
||||
# Get all departments this principal is assigned to
|
||||
depts = self.get_departments(active_only=True)
|
||||
dept_ids = [d.id for d in depts]
|
||||
|
||||
|
||||
# Check department membership
|
||||
dept_ids = [d.id for d in self.get_departments(active_only=True)]
|
||||
if not dept_ids:
|
||||
return False
|
||||
|
||||
# Check if user is a member of any of these departments
|
||||
from gatehouse_app.models.department import DepartmentMembership
|
||||
|
||||
|
||||
from gatehouse_app.models.organization.department import DepartmentMembership
|
||||
|
||||
return (
|
||||
DepartmentMembership.query.filter(
|
||||
DepartmentMembership.user_id == user_id,
|
||||
DepartmentMembership.department_id.in_(dept_ids),
|
||||
DepartmentMembership.deleted_at == None,
|
||||
DepartmentMembership.deleted_at.is_(None),
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def get_member_count(self, include_via_department=True):
|
||||
def get_member_count(self, include_via_department: bool = True) -> int:
|
||||
"""Get the count of active members with access to this principal.
|
||||
|
||||
|
||||
Args:
|
||||
include_via_department: If True, include members via department
|
||||
|
||||
|
||||
Returns:
|
||||
Count of members
|
||||
"""
|
||||
if not include_via_department:
|
||||
return len(self.get_members(active_only=True))
|
||||
|
||||
return len(self.get_all_members(active_only=True))
|
||||
|
||||
|
||||
class PrincipalMembership(BaseModel):
|
||||
"""Principal membership model representing direct user assignment to a principal.
|
||||
|
||||
When a user is assigned directly to a principal, they get access to that principal
|
||||
for SSH authentication. This is in addition to any principals they get via
|
||||
department membership.
|
||||
|
||||
When a user is assigned directly to a principal, they get access to that
|
||||
principal for SSH authentication. This is in addition to any principals
|
||||
they get via department membership.
|
||||
"""
|
||||
|
||||
__tablename__ = "principal_memberships"
|
||||
@@ -208,13 +203,13 @@ class PrincipalMembership(BaseModel):
|
||||
user = db.relationship("User", back_populates="principal_memberships")
|
||||
principal = db.relationship("Principal", back_populates="memberships")
|
||||
|
||||
# Unique constraint: user can only be member of a principal once
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint(
|
||||
"user_id", "principal_id", name="uix_user_principal"
|
||||
),
|
||||
db.UniqueConstraint("user_id", "principal_id", name="uix_user_principal"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of PrincipalMembership."""
|
||||
return f"<PrincipalMembership user_id={self.user_id} principal_id={self.principal_id}>"
|
||||
return (
|
||||
f"<PrincipalMembership user_id={self.user_id} "
|
||||
f"principal_id={self.principal_id}>"
|
||||
)
|
||||
@@ -0,0 +1,12 @@
|
||||
"""Security subpackage — organization and user security policies, MFA compliance."""
|
||||
from gatehouse_app.models.security.organization_security_policy import (
|
||||
OrganizationSecurityPolicy,
|
||||
)
|
||||
from gatehouse_app.models.security.user_security_policy import UserSecurityPolicy
|
||||
from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance
|
||||
|
||||
__all__ = [
|
||||
"OrganizationSecurityPolicy",
|
||||
"UserSecurityPolicy",
|
||||
"MfaPolicyCompliance",
|
||||
]
|
||||
+11
-10
@@ -1,4 +1,4 @@
|
||||
"""MfaPolicyCompliance model."""
|
||||
"""MfaPolicyCompliance model — per-user per-organization MFA compliance tracking."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import MfaComplianceStatus
|
||||
@@ -7,7 +7,8 @@ from gatehouse_app.utils.constants import MfaComplianceStatus
|
||||
class MfaPolicyCompliance(BaseModel):
|
||||
"""MFA policy compliance tracking per user per organization.
|
||||
|
||||
Tracks each user's MFA compliance state separately for each organization membership.
|
||||
Tracks each user's MFA compliance state separately for each organization
|
||||
membership. One row per (user, org) pair.
|
||||
"""
|
||||
|
||||
__tablename__ = "mfa_policy_compliance"
|
||||
@@ -25,13 +26,13 @@ class MfaPolicyCompliance(BaseModel):
|
||||
default=MfaComplianceStatus.NOT_APPLICABLE,
|
||||
)
|
||||
|
||||
# Snapshot of org policy at the time this record became active
|
||||
# Snapshot of org policy version when this record became active
|
||||
policy_version = db.Column(db.Integer, nullable=False)
|
||||
|
||||
# When policy started applying to this user
|
||||
applied_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Final deadline for this user to comply (per user, not global)
|
||||
# Final deadline for this user to comply
|
||||
deadline_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# When they became compliant under this policy_version
|
||||
@@ -45,9 +46,7 @@ class MfaPolicyCompliance(BaseModel):
|
||||
notification_count = db.Column(db.Integer, nullable=False, default=0)
|
||||
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint(
|
||||
"user_id", "organization_id", name="uix_user_org_compliance"
|
||||
),
|
||||
db.UniqueConstraint("user_id", "organization_id", name="uix_user_org_compliance"),
|
||||
)
|
||||
|
||||
# Relationships
|
||||
@@ -58,9 +57,11 @@ class MfaPolicyCompliance(BaseModel):
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of MfaPolicyCompliance."""
|
||||
return f"<MfaPolicyCompliance user={self.user_id} org={self.organization_id} status={self.status}>"
|
||||
return (
|
||||
f"<MfaPolicyCompliance user={self.user_id} "
|
||||
f"org={self.organization_id} status={self.status}>"
|
||||
)
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
exclude = exclude or []
|
||||
return super().to_dict(exclude=exclude)
|
||||
return super().to_dict(exclude=exclude or [])
|
||||
+8
-4
@@ -39,15 +39,19 @@ class OrganizationSecurityPolicy(BaseModel):
|
||||
|
||||
# Relationships
|
||||
organization = db.relationship(
|
||||
"Organization", back_populates="security_policy", foreign_keys=[organization_id]
|
||||
"Organization",
|
||||
back_populates="security_policy",
|
||||
foreign_keys=[organization_id],
|
||||
)
|
||||
updated_by_user = db.relationship("User", foreign_keys=[updated_by_user_id])
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OrganizationSecurityPolicy."""
|
||||
return f"<OrganizationSecurityPolicy org={self.organization_id} mode={self.mfa_policy_mode}>"
|
||||
return (
|
||||
f"<OrganizationSecurityPolicy "
|
||||
f"org={self.organization_id} mode={self.mfa_policy_mode}>"
|
||||
)
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
exclude = exclude or []
|
||||
return super().to_dict(exclude=exclude)
|
||||
return super().to_dict(exclude=exclude or [])
|
||||
+10
-12
@@ -1,4 +1,4 @@
|
||||
"""UserSecurityPolicy model."""
|
||||
"""UserSecurityPolicy model — per-user MFA overrides."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import MfaRequirementOverride
|
||||
@@ -7,7 +7,7 @@ from gatehouse_app.utils.constants import MfaRequirementOverride
|
||||
class UserSecurityPolicy(BaseModel):
|
||||
"""User security policy model for per-user MFA overrides.
|
||||
|
||||
Stores per user overrides of organization level MFA requirements.
|
||||
Stores per-user overrides of organization-level MFA requirements.
|
||||
"""
|
||||
|
||||
__tablename__ = "user_security_policies"
|
||||
@@ -25,29 +25,27 @@ class UserSecurityPolicy(BaseModel):
|
||||
default=MfaRequirementOverride.INHERIT,
|
||||
)
|
||||
|
||||
# If override is REQUIRED and you want to force a specific factor set
|
||||
# If override is REQUIRED, optionally force a specific factor set
|
||||
force_totp = db.Column(db.Boolean, nullable=False, default=False)
|
||||
force_webauthn = db.Column(db.Boolean, nullable=False, default=False)
|
||||
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint(
|
||||
"user_id", "organization_id", name="uix_user_org_policy"
|
||||
),
|
||||
db.UniqueConstraint("user_id", "organization_id", name="uix_user_org_policy"),
|
||||
)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship(
|
||||
"User", back_populates="security_policies", foreign_keys=[user_id]
|
||||
)
|
||||
organization = db.relationship(
|
||||
"Organization", foreign_keys=[organization_id]
|
||||
)
|
||||
organization = db.relationship("Organization", foreign_keys=[organization_id])
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of UserSecurityPolicy."""
|
||||
return f"<UserSecurityPolicy user={self.user_id} org={self.organization_id} mode={self.mfa_override_mode}>"
|
||||
return (
|
||||
f"<UserSecurityPolicy user={self.user_id} "
|
||||
f"org={self.organization_id} mode={self.mfa_override_mode}>"
|
||||
)
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
exclude = exclude or []
|
||||
return super().to_dict(exclude=exclude)
|
||||
return super().to_dict(exclude=exclude or [])
|
||||
@@ -0,0 +1,17 @@
|
||||
"""SSH/CA subpackage — certificate authorities, SSH keys, and certificates."""
|
||||
from gatehouse_app.models.ssh_ca.ca import CA, KeyType, CertType, CaType, CAPermission
|
||||
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey
|
||||
from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus
|
||||
from gatehouse_app.models.ssh_ca.certificate_audit_log import CertificateAuditLog
|
||||
|
||||
__all__ = [
|
||||
"CA",
|
||||
"KeyType",
|
||||
"CertType",
|
||||
"CaType",
|
||||
"CAPermission",
|
||||
"SSHKey",
|
||||
"SSHCertificate",
|
||||
"CertificateStatus",
|
||||
"CertificateAuditLog",
|
||||
]
|
||||
@@ -1,13 +1,13 @@
|
||||
"""Certificate Authority (CA) model."""
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class KeyType(str, Enum):
|
||||
"""SSH CA key types."""
|
||||
|
||||
|
||||
ED25519 = "ed25519"
|
||||
RSA = "rsa"
|
||||
ECDSA = "ecdsa"
|
||||
@@ -15,7 +15,7 @@ class KeyType(str, Enum):
|
||||
|
||||
class CertType(str, Enum):
|
||||
"""SSH certificate types."""
|
||||
|
||||
|
||||
USER = "user"
|
||||
HOST = "host"
|
||||
|
||||
@@ -29,7 +29,7 @@ class CaType(str, Enum):
|
||||
|
||||
class CA(BaseModel):
|
||||
"""Certificate Authority (CA) model for SSH certificate signing.
|
||||
|
||||
|
||||
Each organization can have multiple CAs for different purposes
|
||||
(e.g., production vs. staging). Private keys are encrypted at rest
|
||||
and should be protected with KMS.
|
||||
@@ -43,12 +43,12 @@ class CA(BaseModel):
|
||||
nullable=True, # NULL for the global system-config CA
|
||||
index=True,
|
||||
)
|
||||
|
||||
# CA name and description
|
||||
|
||||
# CA identity
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
description = db.Column(db.Text, nullable=True)
|
||||
|
||||
# CA signing type: 'user' signs user certificates, 'host' signs host certificates
|
||||
# CA signing type: 'user' signs user certificates, 'host' signs host certs
|
||||
ca_type = db.Column(
|
||||
db.Enum(CaType, values_callable=lambda x: [e.value for e in x]),
|
||||
default=CaType.USER,
|
||||
@@ -61,43 +61,33 @@ class CA(BaseModel):
|
||||
default=KeyType.ED25519,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Private key (encrypted at rest by database/KMS)
|
||||
# Format: PEM-encoded private key
|
||||
|
||||
# Private key — PEM-encoded, encrypted at rest by database/KMS
|
||||
private_key = db.Column(db.Text, nullable=False)
|
||||
|
||||
# Public key (PEM format)
|
||||
|
||||
# Public key — PEM format
|
||||
public_key = db.Column(db.Text, nullable=False)
|
||||
|
||||
|
||||
# SHA256 fingerprint of the public key
|
||||
fingerprint = db.Column(db.String(255), nullable=False, unique=True)
|
||||
|
||||
|
||||
# CRL (Certificate Revocation List) configuration
|
||||
crl_enabled = db.Column(db.Boolean, default=True, nullable=False)
|
||||
crl_endpoint = db.Column(db.String(512), nullable=True)
|
||||
|
||||
# Default certificate validity in hours
|
||||
# Can be overridden per certificate request
|
||||
default_cert_validity_hours = db.Column(
|
||||
db.Integer,
|
||||
default=1,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
# Default certificate validity in hours (overridable per request)
|
||||
default_cert_validity_hours = db.Column(db.Integer, default=1, nullable=False)
|
||||
|
||||
# Maximum validity duration allowed
|
||||
max_cert_validity_hours = db.Column(
|
||||
db.Integer,
|
||||
default=24,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
max_cert_validity_hours = db.Column(db.Integer, default=24, nullable=False)
|
||||
|
||||
# CA status
|
||||
is_active = db.Column(db.Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
|
||||
# Key rotation tracking
|
||||
rotated_at = db.Column(db.DateTime, nullable=True)
|
||||
rotation_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
|
||||
# Relationships
|
||||
organization = db.relationship("Organization", back_populates="cas")
|
||||
certificates = db.relationship(
|
||||
@@ -112,54 +102,53 @@ class CA(BaseModel):
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint(
|
||||
"organization_id", "name", name="uix_org_ca_name"
|
||||
),
|
||||
db.UniqueConstraint("organization_id", "name", name="uix_org_ca_name"),
|
||||
db.Index("idx_ca_org_active", "organization_id", "is_active"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of CA."""
|
||||
return f"<CA {self.name} (org_id={self.organization_id}, type={self.key_type})>"
|
||||
return (
|
||||
f"<CA {self.name} "
|
||||
f"(org_id={self.organization_id}, type={self.key_type})>"
|
||||
)
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert CA to dictionary."""
|
||||
"""Convert CA to dictionary, never exposing the private key."""
|
||||
exclude = exclude or []
|
||||
# Never expose private key in API responses
|
||||
exclude.extend(["private_key"])
|
||||
if "private_key" not in exclude:
|
||||
exclude.append("private_key")
|
||||
data = super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add computed fields
|
||||
data["total_certs"] = len([c for c in self.certificates if c.deleted_at is None])
|
||||
data["active_certs"] = len([
|
||||
c for c in self.certificates
|
||||
if c.deleted_at is None and not c.revoked
|
||||
])
|
||||
data["revoked_certs"] = len([
|
||||
c for c in self.certificates
|
||||
if c.deleted_at is None and c.revoked
|
||||
])
|
||||
|
||||
data["active_certs"] = len(
|
||||
[c for c in self.certificates if c.deleted_at is None and not c.revoked]
|
||||
)
|
||||
data["revoked_certs"] = len(
|
||||
[c for c in self.certificates if c.deleted_at is None and c.revoked]
|
||||
)
|
||||
return data
|
||||
|
||||
def get_active_certificates(self):
|
||||
"""Get all active (non-revoked) certificates issued by this CA.
|
||||
|
||||
Returns:
|
||||
List of non-revoked SSHCertificate objects
|
||||
"""
|
||||
def get_active_certificates(self) -> list:
|
||||
"""Get all active (non-revoked) certificates issued by this CA."""
|
||||
return [
|
||||
c for c in self.certificates
|
||||
if c.deleted_at is None and not c.revoked
|
||||
c for c in self.certificates if c.deleted_at is None and not c.revoked
|
||||
]
|
||||
|
||||
def rotate_key(self, new_private_key, new_public_key, new_fingerprint, reason=None):
|
||||
def rotate_key(
|
||||
self,
|
||||
new_private_key: str,
|
||||
new_public_key: str,
|
||||
new_fingerprint: str,
|
||||
reason: str = None,
|
||||
) -> None:
|
||||
"""Rotate the CA's key pair.
|
||||
|
||||
|
||||
This should only be done in carefully controlled circumstances.
|
||||
All existing certificates remain valid but no new certs can be
|
||||
signed with the old key.
|
||||
|
||||
All existing certificates remain valid but no new certificates can be
|
||||
signed with the old key after rotation.
|
||||
|
||||
Args:
|
||||
new_private_key: New PEM-encoded private key
|
||||
new_public_key: New PEM-encoded public key
|
||||
@@ -169,7 +158,7 @@ class CA(BaseModel):
|
||||
self.private_key = new_private_key
|
||||
self.public_key = new_public_key
|
||||
self.fingerprint = new_fingerprint
|
||||
self.rotated_at = datetime.utcnow()
|
||||
self.rotated_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow()
|
||||
self.rotation_reason = reason
|
||||
self.save()
|
||||
|
||||
@@ -178,7 +167,7 @@ class CAPermission(BaseModel):
|
||||
"""Per-user CA permission model.
|
||||
|
||||
Controls which users are allowed to sign certificates against a specific CA.
|
||||
When a CA has any permission rows the signing endpoint enforces the list;
|
||||
When a CA has any permission rows, the signing endpoint enforces the list;
|
||||
CAs with no rows are open to all org members (backwards-compatible default).
|
||||
|
||||
Permission values:
|
||||
@@ -212,7 +201,10 @@ class CAPermission(BaseModel):
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<CAPermission ca_id={self.ca_id} user_id={self.user_id} permission={self.permission}>"
|
||||
return (
|
||||
f"<CAPermission ca_id={self.ca_id} "
|
||||
f"user_id={self.user_id} permission={self.permission}>"
|
||||
)
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
data = super().to_dict(exclude=exclude or [])
|
||||
+27
-19
@@ -1,14 +1,14 @@
|
||||
"""Certificate audit log model."""
|
||||
"""Certificate audit log model — tracks SSH certificate lifecycle events."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class CertificateAuditLog(BaseModel):
|
||||
"""Audit log for SSH certificate lifecycle events.
|
||||
|
||||
Tracks all operations on SSH certificates: signing, revocation,
|
||||
validation, etc. This is separate from the general AuditLog to
|
||||
provide detailed certificate operation tracking.
|
||||
|
||||
Tracks all operations on SSH certificates: signing, revocation, validation,
|
||||
etc. Kept separate from the general AuditLog to provide detailed certificate
|
||||
operation tracking without polluting the main audit stream.
|
||||
"""
|
||||
|
||||
__tablename__ = "certificate_audit_logs"
|
||||
@@ -20,33 +20,33 @@ class CertificateAuditLog(BaseModel):
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# The user who performed the action (can be null for system actions)
|
||||
|
||||
# The user who performed the action (null for system actions)
|
||||
user_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("users.id"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
|
||||
# Action type (e.g., "signed", "revoked", "validated", "requested")
|
||||
action = db.Column(db.String(50), nullable=False, index=True)
|
||||
|
||||
|
||||
# Request details
|
||||
ip_address = db.Column(db.String(45), nullable=True)
|
||||
user_agent = db.Column(db.String(512), nullable=True)
|
||||
request_id = db.Column(db.String(36), nullable=True)
|
||||
|
||||
|
||||
# Detailed message
|
||||
message = db.Column(db.Text, nullable=True)
|
||||
|
||||
|
||||
# Additional context
|
||||
extra_data = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# Success/failure
|
||||
|
||||
# Outcome
|
||||
success = db.Column(db.Boolean, default=True, nullable=False)
|
||||
error_message = db.Column(db.Text, nullable=True)
|
||||
|
||||
|
||||
# Relationships
|
||||
certificate = db.relationship("SSHCertificate", back_populates="audit_logs")
|
||||
user = db.relationship("User")
|
||||
@@ -58,18 +58,26 @@ class CertificateAuditLog(BaseModel):
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of CertificateAuditLog."""
|
||||
return f"<CertificateAuditLog cert_id={self.certificate_id} action={self.action}>"
|
||||
return (
|
||||
f"<CertificateAuditLog cert_id={self.certificate_id} action={self.action}>"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log(cls, certificate_id, action, user_id=None, **kwargs):
|
||||
def log(
|
||||
cls,
|
||||
certificate_id: str,
|
||||
action: str,
|
||||
user_id: str = None,
|
||||
**kwargs,
|
||||
) -> "CertificateAuditLog":
|
||||
"""Create a certificate audit log entry.
|
||||
|
||||
|
||||
Args:
|
||||
certificate_id: ID of the certificate
|
||||
action: Action type (e.g., "signed", "revoked")
|
||||
user_id: ID of the user performing the action (optional)
|
||||
**kwargs: Additional fields (ip_address, user_agent, message, etc.)
|
||||
|
||||
|
||||
Returns:
|
||||
CertificateAuditLog instance
|
||||
"""
|
||||
@@ -77,7 +85,7 @@ class CertificateAuditLog(BaseModel):
|
||||
certificate_id=certificate_id,
|
||||
action=action,
|
||||
user_id=user_id,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
log_entry.save()
|
||||
return log_entry
|
||||
+46
-51
@@ -1,26 +1,26 @@
|
||||
"""SSH Certificate model."""
|
||||
"""SSH Certificate model — signed SSH user/host certificates."""
|
||||
from enum import Enum
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.models.ca import CertType
|
||||
from gatehouse_app.models.ssh_ca.ca import CertType
|
||||
|
||||
|
||||
class CertificateStatus(str, Enum):
|
||||
"""SSH certificate lifecycle status."""
|
||||
|
||||
REQUESTED = "requested" # Waiting for signing
|
||||
ISSUED = "issued" # Signed and valid
|
||||
REVOKED = "revoked" # Manually revoked
|
||||
EXPIRED = "expired" # Validity period ended
|
||||
SUPERSEDED = "superseded" # Replaced by newer cert
|
||||
|
||||
REQUESTED = "requested" # Waiting for signing
|
||||
ISSUED = "issued" # Signed and valid
|
||||
REVOKED = "revoked" # Manually revoked
|
||||
EXPIRED = "expired" # Validity period ended
|
||||
SUPERSEDED = "superseded" # Replaced by newer certificate
|
||||
|
||||
|
||||
class SSHCertificate(BaseModel):
|
||||
"""SSH Certificate model representing a signed SSH user/host certificate.
|
||||
|
||||
|
||||
Certificates are issued by a CA and associated with an SSH public key.
|
||||
They include principals (access levels), validity periods, and other
|
||||
They include principals (access levels), validity periods, and standard
|
||||
OpenSSH certificate metadata.
|
||||
"""
|
||||
|
||||
@@ -45,10 +45,10 @@ class SSHCertificate(BaseModel):
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Certificate content (full signed certificate in OpenSSH format)
|
||||
|
||||
# Certificate content — full signed certificate in OpenSSH format
|
||||
certificate = db.Column(db.Text, nullable=False)
|
||||
|
||||
|
||||
# Certificate metadata
|
||||
serial = db.Column(db.String(255), nullable=False, unique=True, index=True)
|
||||
key_id = db.Column(db.String(255), nullable=False) # Usually user email
|
||||
@@ -57,19 +57,19 @@ class SSHCertificate(BaseModel):
|
||||
default=CertType.USER,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Principals (JSON list) - e.g., ["prod-servers", "dev-servers"]
|
||||
|
||||
# Principals — JSON list, e.g., ["prod-servers", "dev-servers"]
|
||||
principals = db.Column(db.JSON, nullable=False, default=list)
|
||||
|
||||
|
||||
# Validity period
|
||||
valid_after = db.Column(db.DateTime, nullable=False)
|
||||
valid_before = db.Column(db.DateTime, nullable=False)
|
||||
|
||||
|
||||
# Revocation status
|
||||
revoked = db.Column(db.Boolean, default=False, nullable=False, index=True)
|
||||
revoked_at = db.Column(db.DateTime, nullable=True)
|
||||
revoke_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
|
||||
# Status tracking
|
||||
status = db.Column(
|
||||
db.Enum(CertificateStatus, values_callable=lambda x: [e.value for e in x]),
|
||||
@@ -77,19 +77,17 @@ class SSHCertificate(BaseModel):
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
|
||||
# Request metadata
|
||||
request_ip = db.Column(db.String(45), nullable=True)
|
||||
request_user_agent = db.Column(db.String(512), nullable=True)
|
||||
|
||||
# Critical options (JSON) - OpenSSH critical options
|
||||
# See: https://man.openbsd.org/ssh-cert
|
||||
|
||||
# Critical options — OpenSSH critical options (JSON)
|
||||
critical_options = db.Column(db.JSON, nullable=True, default=dict)
|
||||
|
||||
# Extensions (JSON) - OpenSSH extensions
|
||||
# Common ones: permit-X11-forwarding, permit-agent-forwarding, permit-pty, etc.
|
||||
|
||||
# Extensions — OpenSSH extensions (JSON)
|
||||
extensions = db.Column(db.JSON, nullable=True, default=dict)
|
||||
|
||||
|
||||
# Relationships
|
||||
ca = db.relationship("CA", back_populates="certificates")
|
||||
user = db.relationship("User", back_populates="ssh_certificates")
|
||||
@@ -115,67 +113,64 @@ class SSHCertificate(BaseModel):
|
||||
return f"<SSHCertificate serial={self.serial[:16]}... user_id={self.user_id}>"
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert certificate to dictionary."""
|
||||
"""Convert certificate to dictionary.
|
||||
|
||||
The raw ``certificate`` blob is excluded by default (it is large and
|
||||
callers can request it explicitly by removing it from the exclude list).
|
||||
"""
|
||||
exclude = exclude or []
|
||||
# Optionally exclude the certificate content (it's large)
|
||||
if "certificate" not in exclude:
|
||||
exclude.append("certificate")
|
||||
data = super().to_dict(exclude=exclude)
|
||||
|
||||
# Add computed fields
|
||||
data["is_valid"] = self.is_valid()
|
||||
data["days_until_expiry"] = self.days_until_expiry()
|
||||
|
||||
return data
|
||||
|
||||
def is_valid(self):
|
||||
def _aware(self, dt: datetime) -> datetime:
|
||||
"""Return a timezone-aware UTC datetime."""
|
||||
return dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if certificate is currently valid.
|
||||
|
||||
|
||||
Returns:
|
||||
True if certificate is issued, not revoked, and within validity period
|
||||
"""
|
||||
if self.revoked or self.status == CertificateStatus.REVOKED:
|
||||
return False
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
valid_after = self.valid_after.replace(tzinfo=timezone.utc) if self.valid_after.tzinfo is None else self.valid_after
|
||||
valid_before = self.valid_before.replace(tzinfo=timezone.utc) if self.valid_before.tzinfo is None else self.valid_before
|
||||
return valid_after <= now <= valid_before
|
||||
return self._aware(self.valid_after) <= now <= self._aware(self.valid_before)
|
||||
|
||||
def is_expired(self):
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if certificate has expired.
|
||||
|
||||
|
||||
Returns:
|
||||
True if current time is past valid_before
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
valid_before = self.valid_before.replace(tzinfo=timezone.utc) if self.valid_before.tzinfo is None else self.valid_before
|
||||
return now > valid_before
|
||||
return datetime.now(timezone.utc) > self._aware(self.valid_before)
|
||||
|
||||
def days_until_expiry(self):
|
||||
def days_until_expiry(self) -> int:
|
||||
"""Get number of days until certificate expires.
|
||||
|
||||
|
||||
Returns:
|
||||
Number of days remaining (negative if already expired)
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
valid_before = self.valid_before.replace(tzinfo=timezone.utc) if self.valid_before.tzinfo is None else self.valid_before
|
||||
delta = valid_before - now
|
||||
delta = self._aware(self.valid_before) - datetime.now(timezone.utc)
|
||||
return delta.days + (1 if delta.seconds > 0 else 0)
|
||||
|
||||
def revoke(self, reason=None):
|
||||
def revoke(self, reason: str = None) -> None:
|
||||
"""Revoke this certificate.
|
||||
|
||||
|
||||
Args:
|
||||
reason: Optional reason for revocation
|
||||
"""
|
||||
self.revoked = True
|
||||
self.revoked_at = datetime.utcnow()
|
||||
self.revoked_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow()
|
||||
self.revoke_reason = reason
|
||||
self.status = CertificateStatus.REVOKED
|
||||
self.save()
|
||||
|
||||
def mark_expired(self):
|
||||
def mark_expired(self) -> None:
|
||||
"""Mark certificate as expired when validity period ends."""
|
||||
self.status = CertificateStatus.EXPIRED
|
||||
self.save()
|
||||
@@ -1,14 +1,14 @@
|
||||
"""SSH Key model."""
|
||||
from datetime import datetime
|
||||
"""SSH Key model — user SSH public keys registered for certificate signing."""
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class SSHKey(BaseModel):
|
||||
"""SSH Key model representing a user's SSH public key.
|
||||
|
||||
This model stores SSH public keys that users register for certificate signing.
|
||||
Users must verify ownership of the key before it can be used for signing certificates.
|
||||
|
||||
Users register SSH public keys for certificate signing. Keys must be
|
||||
verified (owner proved possession) before they can be used.
|
||||
"""
|
||||
|
||||
__tablename__ = "ssh_keys"
|
||||
@@ -19,33 +19,29 @@ class SSHKey(BaseModel):
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# SSH key payload in OpenSSH format (e.g., "ssh-rsa AAAAB3Nz...")
|
||||
|
||||
# SSH key payload in OpenSSH format (e.g., "ssh-ed25519 AAAAB3Nz...")
|
||||
payload = db.Column(db.Text, nullable=False, unique=True)
|
||||
|
||||
# SHA256 fingerprint for quick comparison
|
||||
|
||||
# SHA256 fingerprint for quick comparison and deduplication
|
||||
fingerprint = db.Column(db.String(255), nullable=False, unique=True, index=True)
|
||||
|
||||
# Optional description for the key (e.g., "My laptop key")
|
||||
|
||||
# Optional human-readable description (e.g., "My laptop key")
|
||||
description = db.Column(db.String(255), nullable=True)
|
||||
|
||||
|
||||
# Verification status
|
||||
verified = db.Column(db.Boolean, default=False, nullable=False, index=True)
|
||||
verified_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Verification challenge
|
||||
|
||||
# Verification challenge — shown to user once, cleared after verification
|
||||
verify_text = db.Column(db.String(255), nullable=True)
|
||||
verify_text_created_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Key type extracted from the key (ssh-rsa, ssh-ed25519, etc.)
|
||||
key_type = db.Column(db.String(50), nullable=True)
|
||||
|
||||
# Key bits/length
|
||||
key_bits = db.Column(db.Integer, nullable=True)
|
||||
|
||||
# Comment from the key (usually email or key name)
|
||||
|
||||
# Key metadata extracted from the key
|
||||
key_type = db.Column(db.String(50), nullable=True) # ssh-rsa, ssh-ed25519, etc.
|
||||
key_bits = db.Column(db.Integer, nullable=True) # key length
|
||||
key_comment = db.Column(db.String(255), nullable=True)
|
||||
|
||||
|
||||
# Relationships
|
||||
user = db.relationship("User", back_populates="ssh_keys")
|
||||
certificates = db.relationship(
|
||||
@@ -64,33 +60,39 @@ class SSHKey(BaseModel):
|
||||
return f"<SSHKey {self.fingerprint[:16]}... user_id={self.user_id}>"
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert SSH key to dictionary."""
|
||||
"""Convert SSH key to dictionary.
|
||||
|
||||
``payload`` and ``verify_text`` are never exposed through the API.
|
||||
"""
|
||||
exclude = exclude or []
|
||||
exclude.extend(["payload", "verify_text"]) # Never expose these in API
|
||||
for field in ("payload", "verify_text"):
|
||||
if field not in exclude:
|
||||
exclude.append(field)
|
||||
data = super().to_dict(exclude=exclude)
|
||||
|
||||
# Add computed fields
|
||||
data["cert_count"] = len([c for c in self.certificates if c.deleted_at is None])
|
||||
|
||||
return data
|
||||
|
||||
def mark_verified(self):
|
||||
"""Mark this SSH key as verified."""
|
||||
def mark_verified(self) -> None:
|
||||
"""Mark this SSH key as verified and clear the challenge."""
|
||||
self.verified = True
|
||||
self.verified_at = datetime.utcnow()
|
||||
self.verified_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow()
|
||||
self.verify_text = None
|
||||
self.save()
|
||||
|
||||
def needs_verification_refresh(self, max_age_hours=24):
|
||||
def needs_verification_refresh(self, max_age_hours: int = 24) -> bool:
|
||||
"""Check if verification challenge needs to be refreshed.
|
||||
|
||||
|
||||
Args:
|
||||
max_age_hours: Maximum age of verification challenge in hours
|
||||
|
||||
|
||||
Returns:
|
||||
True if verification challenge is stale
|
||||
True if verification challenge is stale or missing
|
||||
"""
|
||||
if not self.verify_text_created_at:
|
||||
return True
|
||||
|
||||
age = datetime.utcnow() - self.verify_text_created_at
|
||||
age = datetime.now(timezone.utc) - self.verify_text_created_at.replace(
|
||||
tzinfo=timezone.utc
|
||||
) if self.verify_text_created_at.tzinfo is None else (
|
||||
datetime.now(timezone.utc) - self.verify_text_created_at
|
||||
)
|
||||
return age.total_seconds() > (max_age_hours * 3600)
|
||||
@@ -1,177 +1,4 @@
|
||||
"""User model."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import UserStatus
|
||||
"""Backward-compatibility shim — import from gatehouse_app.models.user.user instead."""
|
||||
from gatehouse_app.models.user.user import User # noqa: F401
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
"""User model representing a user account."""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
email = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
email_verified = db.Column(db.Boolean, default=False, nullable=False)
|
||||
full_name = db.Column(db.String(255), nullable=True)
|
||||
avatar_url = db.Column(db.String(512), nullable=True)
|
||||
status = db.Column(
|
||||
db.Enum(UserStatus), default=UserStatus.ACTIVE, nullable=False, index=True
|
||||
)
|
||||
last_login_at = db.Column(db.DateTime, nullable=True)
|
||||
last_login_ip = db.Column(db.String(45), nullable=True)
|
||||
|
||||
# Relationships
|
||||
authentication_methods = db.relationship(
|
||||
"AuthenticationMethod", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
sessions = db.relationship("Session", back_populates="user", cascade="all, delete-orphan")
|
||||
organization_memberships = db.relationship(
|
||||
"OrganizationMember",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="OrganizationMember.user_id",
|
||||
)
|
||||
audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan")
|
||||
security_policies = db.relationship(
|
||||
"UserSecurityPolicy",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="UserSecurityPolicy.user_id",
|
||||
)
|
||||
mfa_compliance = db.relationship(
|
||||
"MfaPolicyCompliance",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="MfaPolicyCompliance.user_id",
|
||||
)
|
||||
department_memberships = db.relationship(
|
||||
"DepartmentMembership",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="DepartmentMembership.user_id",
|
||||
)
|
||||
principal_memberships = db.relationship(
|
||||
"PrincipalMembership",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="PrincipalMembership.user_id",
|
||||
)
|
||||
ssh_keys = db.relationship(
|
||||
"SSHKey",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="SSHKey.user_id",
|
||||
)
|
||||
ssh_certificates = db.relationship(
|
||||
"SSHCertificate",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="SSHCertificate.user_id",
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of User."""
|
||||
return f"<User {self.email}>"
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert user to dictionary, excluding sensitive fields by default."""
|
||||
exclude = exclude or []
|
||||
# Always exclude password-related fields
|
||||
default_exclude = []
|
||||
all_exclude = list(set(default_exclude + exclude))
|
||||
return super().to_dict(exclude=all_exclude)
|
||||
|
||||
def has_password_auth(self):
|
||||
"""Check if user has password authentication enabled."""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return (
|
||||
AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.PASSWORD, deleted_at=None
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def get_organizations(self):
|
||||
"""Get all organizations the user is a member of."""
|
||||
return [membership.organization for membership in self.organization_memberships]
|
||||
|
||||
def has_totp_enabled(self) -> bool:
|
||||
"""Check if user has TOTP enabled and verified.
|
||||
|
||||
Returns:
|
||||
True if user has a verified TOTP authentication method, False otherwise.
|
||||
"""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return (
|
||||
AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id,
|
||||
method_type=AuthMethodType.TOTP,
|
||||
verified=True,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def get_totp_method(self):
|
||||
"""Get user's TOTP authentication method.
|
||||
|
||||
Returns:
|
||||
The AuthenticationMethod instance for TOTP or None if not found.
|
||||
|
||||
Note:
|
||||
Returns the most recently created TOTP method to handle cases where
|
||||
multiple enrollment attempts may exist.
|
||||
"""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None
|
||||
).order_by(AuthenticationMethod.created_at.desc()).first()
|
||||
|
||||
def has_webauthn_enabled(self) -> bool:
|
||||
"""Check if user has any WebAuthn passkey credentials.
|
||||
|
||||
Returns:
|
||||
True if user has at least one WebAuthn credential, False otherwise.
|
||||
"""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return (
|
||||
AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id,
|
||||
method_type=AuthMethodType.WEBAUTHN,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def get_webauthn_credentials(self):
|
||||
"""Get all WebAuthn credentials for the user.
|
||||
|
||||
Returns:
|
||||
List of AuthenticationMethod instances for WebAuthn, ordered by creation date.
|
||||
"""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
|
||||
).order_by(AuthenticationMethod.created_at.desc()).all()
|
||||
|
||||
def get_webauthn_credential_count(self) -> int:
|
||||
"""Get the count of WebAuthn credentials for the user.
|
||||
|
||||
Returns:
|
||||
Number of WebAuthn credentials.
|
||||
"""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
|
||||
).count()
|
||||
__all__ = ["User"]
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
"""User subpackage."""
|
||||
from gatehouse_app.models.user.user import User
|
||||
from gatehouse_app.models.user.session import Session
|
||||
|
||||
__all__ = ["User", "Session"]
|
||||
@@ -21,7 +21,9 @@ class Session(BaseModel):
|
||||
|
||||
# Timing
|
||||
expires_at = db.Column(db.DateTime, nullable=False)
|
||||
last_activity_at = db.Column(db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc))
|
||||
last_activity_at = db.Column(
|
||||
db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
revoked_at = db.Column(db.DateTime, nullable=True)
|
||||
revoked_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
@@ -38,7 +40,6 @@ class Session(BaseModel):
|
||||
def is_active(self):
|
||||
"""Check if session is currently active."""
|
||||
now = datetime.now(timezone.utc)
|
||||
# Make expires_at timezone-aware if it's naive
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
@@ -51,15 +52,13 @@ class Session(BaseModel):
|
||||
def is_expired(self):
|
||||
"""Check if session has expired."""
|
||||
now = datetime.now(timezone.utc)
|
||||
# Make expires_at timezone-aware if it's naive
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return now > expires_at
|
||||
|
||||
def refresh(self, duration_seconds=86400):
|
||||
"""
|
||||
Refresh session expiration.
|
||||
def refresh(self, duration_seconds: int = 86400):
|
||||
"""Refresh session expiration.
|
||||
|
||||
Args:
|
||||
duration_seconds: New session duration in seconds
|
||||
@@ -68,9 +67,8 @@ class Session(BaseModel):
|
||||
self.last_activity_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
def revoke(self, reason=None):
|
||||
"""
|
||||
Revoke the session.
|
||||
def revoke(self, reason: str = None):
|
||||
"""Revoke the session.
|
||||
|
||||
Args:
|
||||
reason: Optional reason for revocation
|
||||
@@ -84,6 +82,5 @@ class Session(BaseModel):
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Exclude token from dict
|
||||
exclude.append("token")
|
||||
return super().to_dict(exclude=exclude)
|
||||
@@ -0,0 +1,209 @@
|
||||
"""User model."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import UserStatus
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
"""User model representing a user account."""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
email = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
email_verified = db.Column(db.Boolean, default=False, nullable=False)
|
||||
full_name = db.Column(db.String(255), nullable=True)
|
||||
avatar_url = db.Column(db.String(512), nullable=True)
|
||||
status = db.Column(
|
||||
db.Enum(UserStatus), default=UserStatus.ACTIVE, nullable=False, index=True
|
||||
)
|
||||
last_login_at = db.Column(db.DateTime, nullable=True)
|
||||
last_login_ip = db.Column(db.String(45), nullable=True)
|
||||
|
||||
# Account activation (email-link flow)
|
||||
activated = db.Column(db.Boolean, default=True, nullable=False)
|
||||
activation_key = db.Column(db.String(128), unique=True, nullable=True, index=True)
|
||||
|
||||
# Relationships – defined here only for models that don't circular-import
|
||||
authentication_methods = db.relationship(
|
||||
"AuthenticationMethod", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
sessions = db.relationship("Session", back_populates="user", cascade="all, delete-orphan")
|
||||
organization_memberships = db.relationship(
|
||||
"OrganizationMember",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="OrganizationMember.user_id",
|
||||
)
|
||||
audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan")
|
||||
security_policies = db.relationship(
|
||||
"UserSecurityPolicy",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="UserSecurityPolicy.user_id",
|
||||
)
|
||||
mfa_compliance = db.relationship(
|
||||
"MfaPolicyCompliance",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="MfaPolicyCompliance.user_id",
|
||||
)
|
||||
department_memberships = db.relationship(
|
||||
"DepartmentMembership",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="DepartmentMembership.user_id",
|
||||
)
|
||||
principal_memberships = db.relationship(
|
||||
"PrincipalMembership",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="PrincipalMembership.user_id",
|
||||
)
|
||||
ssh_keys = db.relationship(
|
||||
"SSHKey",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="SSHKey.user_id",
|
||||
)
|
||||
ssh_certificates = db.relationship(
|
||||
"SSHCertificate",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="SSHCertificate.user_id",
|
||||
)
|
||||
ca_permissions = db.relationship(
|
||||
"CAPermission",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="CAPermission.user_id",
|
||||
)
|
||||
|
||||
# OIDC relationships – registered here (no monkey-patching needed)
|
||||
oidc_auth_codes = db.relationship(
|
||||
"OIDCAuthCode", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
oidc_refresh_tokens = db.relationship(
|
||||
"OIDCRefreshToken", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
oidc_sessions = db.relationship(
|
||||
"OIDCSession", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
oidc_token_metadata = db.relationship(
|
||||
"OIDCTokenMetadata", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
oidc_audit_logs = db.relationship(
|
||||
"OIDCAuditLog", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of User."""
|
||||
return f"<User {self.email}>"
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert user to dictionary, excluding sensitive fields by default."""
|
||||
exclude = exclude or []
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
def has_password_auth(self):
|
||||
"""Check if user has password authentication enabled."""
|
||||
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return (
|
||||
AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.PASSWORD, deleted_at=None
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def get_organizations(self):
|
||||
"""Get all organizations the user is a member of."""
|
||||
return [membership.organization for membership in self.organization_memberships]
|
||||
|
||||
def has_totp_enabled(self) -> bool:
|
||||
"""Check if user has TOTP enabled and verified.
|
||||
|
||||
Returns:
|
||||
True if user has a verified TOTP authentication method, False otherwise.
|
||||
"""
|
||||
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return (
|
||||
AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id,
|
||||
method_type=AuthMethodType.TOTP,
|
||||
verified=True,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def get_totp_method(self):
|
||||
"""Get user's TOTP authentication method.
|
||||
|
||||
Returns:
|
||||
The AuthenticationMethod instance for TOTP or None if not found.
|
||||
|
||||
Note:
|
||||
Returns the most recently created TOTP method to handle cases where
|
||||
multiple enrollment attempts may exist.
|
||||
"""
|
||||
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return (
|
||||
AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None
|
||||
)
|
||||
.order_by(AuthenticationMethod.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
def has_webauthn_enabled(self) -> bool:
|
||||
"""Check if user has any WebAuthn passkey credentials.
|
||||
|
||||
Returns:
|
||||
True if user has at least one WebAuthn credential, False otherwise.
|
||||
"""
|
||||
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return (
|
||||
AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id,
|
||||
method_type=AuthMethodType.WEBAUTHN,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def get_webauthn_credentials(self):
|
||||
"""Get all WebAuthn credentials for the user.
|
||||
|
||||
Returns:
|
||||
List of AuthenticationMethod instances for WebAuthn, ordered by creation date.
|
||||
"""
|
||||
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return (
|
||||
AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
|
||||
)
|
||||
.order_by(AuthenticationMethod.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_webauthn_credential_count(self) -> int:
|
||||
"""Get the count of WebAuthn credentials for the user.
|
||||
|
||||
Returns:
|
||||
Number of WebAuthn credentials.
|
||||
"""
|
||||
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
|
||||
).count()
|
||||
Reference in New Issue
Block a user