google login works
This commit is contained in:
@@ -1,7 +1,10 @@
|
||||
"""Authentication method model."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import secrets
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
from gatehouse_app.utils.encryption import encrypt, decrypt
|
||||
|
||||
|
||||
class AuthenticationMethod(BaseModel):
|
||||
@@ -91,3 +94,287 @@ class AuthenticationMethod(BaseModel):
|
||||
"last_used_at": data.get("last_used_at"),
|
||||
"sign_count": data.get("sign_count", 0),
|
||||
}
|
||||
|
||||
|
||||
class ApplicationProviderConfig(BaseModel):
|
||||
"""Application-wide OAuth provider configuration.
|
||||
|
||||
This model stores OAuth provider credentials at the application level,
|
||||
allowing users to authenticate without needing to specify an organization first.
|
||||
"""
|
||||
|
||||
__tablename__ = "application_provider_configs"
|
||||
|
||||
# Provider identification
|
||||
provider_type = db.Column(db.String(50), nullable=False, unique=True, index=True)
|
||||
|
||||
# OAuth credentials (encrypted)
|
||||
client_id = db.Column(db.String(255), nullable=False)
|
||||
client_secret_encrypted = db.Column(db.String(512), nullable=True)
|
||||
|
||||
# Provider status
|
||||
is_enabled = db.Column(db.Boolean, default=True, nullable=False)
|
||||
|
||||
# Default redirect URL
|
||||
default_redirect_url = db.Column(db.String(2048), nullable=True)
|
||||
|
||||
# Provider-specific settings (JSON)
|
||||
additional_config = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# Relationships
|
||||
organization_overrides = db.relationship(
|
||||
"OrganizationProviderOverride",
|
||||
back_populates="application_config",
|
||||
foreign_keys="OrganizationProviderOverride.provider_type",
|
||||
primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type",
|
||||
cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of ApplicationProviderConfig."""
|
||||
return f"<ApplicationProviderConfig provider={self.provider_type} enabled={self.is_enabled}>"
|
||||
|
||||
def set_client_secret(self, plaintext_secret: str):
|
||||
"""Encrypt and store client secret.
|
||||
|
||||
Args:
|
||||
plaintext_secret: The plaintext OAuth client secret
|
||||
"""
|
||||
if plaintext_secret:
|
||||
self.client_secret_encrypted = encrypt(plaintext_secret)
|
||||
|
||||
def get_client_secret(self) -> str:
|
||||
"""Decrypt and return client secret.
|
||||
|
||||
Returns:
|
||||
The plaintext OAuth client secret
|
||||
"""
|
||||
if self.client_secret_encrypted:
|
||||
return decrypt(self.client_secret_encrypted)
|
||||
return None
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude encrypted client secret
|
||||
exclude.append("client_secret_encrypted")
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
class OrganizationProviderOverride(BaseModel):
|
||||
"""Organization-specific OAuth configuration overrides.
|
||||
|
||||
This model allows organizations to override application-level OAuth settings
|
||||
for enterprise SSO scenarios or custom provider configurations.
|
||||
"""
|
||||
|
||||
__tablename__ = "organization_provider_overrides"
|
||||
|
||||
# References
|
||||
organization_id = db.Column(
|
||||
db.String(36), db.ForeignKey("organizations.id"),
|
||||
nullable=False, index=True
|
||||
)
|
||||
provider_type = db.Column(db.String(50), nullable=False, index=True)
|
||||
|
||||
# Override OAuth credentials (encrypted, nullable - only if overriding)
|
||||
client_id = db.Column(db.String(255), nullable=True)
|
||||
client_secret_encrypted = db.Column(db.String(512), nullable=True)
|
||||
|
||||
# Provider status
|
||||
is_enabled = db.Column(db.Boolean, default=True, nullable=False)
|
||||
|
||||
# Redirect URL override
|
||||
redirect_url_override = db.Column(db.String(2048), nullable=True)
|
||||
|
||||
# Provider-specific settings override (JSON)
|
||||
additional_config = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# Relationships
|
||||
organization = db.relationship("Organization", backref="provider_overrides")
|
||||
application_config = db.relationship(
|
||||
"ApplicationProviderConfig",
|
||||
back_populates="organization_overrides",
|
||||
foreign_keys=[provider_type],
|
||||
primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type",
|
||||
viewonly=True
|
||||
)
|
||||
|
||||
# Unique constraint on (organization_id, provider_type)
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint(
|
||||
"organization_id", "provider_type",
|
||||
name="uix_org_provider_type"
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OrganizationProviderOverride."""
|
||||
return f"<OrganizationProviderOverride org={self.organization_id} provider={self.provider_type}>"
|
||||
|
||||
def set_client_secret(self, plaintext_secret: str):
|
||||
"""Encrypt and store client secret override.
|
||||
|
||||
Args:
|
||||
plaintext_secret: The plaintext OAuth client secret
|
||||
"""
|
||||
if plaintext_secret:
|
||||
self.client_secret_encrypted = encrypt(plaintext_secret)
|
||||
|
||||
def get_client_secret(self) -> str:
|
||||
"""Decrypt and return client secret override.
|
||||
|
||||
Returns:
|
||||
The plaintext OAuth client secret
|
||||
"""
|
||||
if self.client_secret_encrypted:
|
||||
return decrypt(self.client_secret_encrypted)
|
||||
return None
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude encrypted client secret
|
||||
exclude.append("client_secret_encrypted")
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
class OAuthState(BaseModel):
|
||||
"""OAuth flow state tracking.
|
||||
|
||||
This model tracks OAuth authentication flow state, including PKCE parameters
|
||||
and organization context (which is now optional to support login flows where
|
||||
the organization isn't known until after authentication).
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_states"
|
||||
|
||||
# OAuth state parameter (unique, used for CSRF protection)
|
||||
state = db.Column(db.String(64), unique=True, nullable=False, index=True)
|
||||
|
||||
# Flow type: "login", "register", "link"
|
||||
flow_type = db.Column(db.String(50), nullable=False)
|
||||
|
||||
# Provider type
|
||||
provider_type = db.Column(db.String(50), nullable=False)
|
||||
|
||||
# User context (optional - not set for login/register flows)
|
||||
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True)
|
||||
|
||||
# Organization context (NOW OPTIONAL - for SSO discovery or post-auth)
|
||||
organization_id = db.Column(
|
||||
db.String(36), db.ForeignKey("organizations.id"),
|
||||
nullable=True, index=True
|
||||
)
|
||||
|
||||
# PKCE parameters
|
||||
nonce = db.Column(db.String(128), nullable=True)
|
||||
code_verifier = db.Column(db.String(128), nullable=True)
|
||||
code_challenge = db.Column(db.String(128), nullable=True)
|
||||
|
||||
# OAuth parameters
|
||||
redirect_uri = db.Column(db.String(2048), nullable=True)
|
||||
|
||||
# Post-auth redirect (for frontend routing)
|
||||
return_url = db.Column(db.String(2048), nullable=True)
|
||||
|
||||
# Additional state data
|
||||
extra_data = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# Expiration and usage tracking
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
used = db.Column(db.Boolean, default=False, nullable=False)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship("User", backref="oauth_states")
|
||||
organization = db.relationship("Organization", backref="oauth_states")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OAuthState."""
|
||||
return f"<OAuthState state={self.state[:8]}... flow={self.flow_type} provider={self.provider_type}>"
|
||||
|
||||
@classmethod
|
||||
def create_state(
|
||||
cls,
|
||||
flow_type: str,
|
||||
provider_type: str,
|
||||
user_id: str = None,
|
||||
organization_id: str = None,
|
||||
redirect_uri: str = None,
|
||||
return_url: str = None,
|
||||
code_verifier: str = None,
|
||||
code_challenge: str = None,
|
||||
nonce: str = None,
|
||||
extra_data: dict = None,
|
||||
lifetime_seconds: int = 600
|
||||
):
|
||||
"""Create a new OAuth state with auto-generated state parameter.
|
||||
|
||||
Args:
|
||||
flow_type: Type of flow ("login", "register", "link")
|
||||
provider_type: OAuth provider type
|
||||
user_id: Optional user ID for authenticated flows
|
||||
organization_id: Optional organization ID
|
||||
redirect_uri: OAuth callback URI
|
||||
return_url: Post-auth redirect destination
|
||||
code_verifier: PKCE code verifier
|
||||
code_challenge: PKCE code challenge
|
||||
nonce: OpenID Connect nonce
|
||||
extra_data: Additional state data
|
||||
lifetime_seconds: How long the state is valid (default 10 minutes)
|
||||
|
||||
Returns:
|
||||
New OAuthState instance
|
||||
"""
|
||||
state = secrets.token_urlsafe(32)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds)
|
||||
|
||||
oauth_state = cls(
|
||||
state=state,
|
||||
flow_type=flow_type,
|
||||
provider_type=provider_type,
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
redirect_uri=redirect_uri,
|
||||
return_url=return_url,
|
||||
code_verifier=code_verifier,
|
||||
code_challenge=code_challenge,
|
||||
nonce=nonce,
|
||||
extra_data=extra_data,
|
||||
expires_at=expires_at,
|
||||
used=False
|
||||
)
|
||||
oauth_state.save()
|
||||
return oauth_state
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the OAuth state is still valid.
|
||||
|
||||
Returns:
|
||||
True if state hasn't expired and hasn't been used
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
# Make expires_at timezone-aware if it's naive (database returns naive datetimes)
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return not self.used and expires_at > now
|
||||
|
||||
def mark_used(self):
|
||||
"""Mark the state as used to prevent replay attacks."""
|
||||
self.used = True
|
||||
self.save()
|
||||
|
||||
@classmethod
|
||||
def cleanup_expired(cls):
|
||||
"""Remove expired OAuth states."""
|
||||
now = datetime.now(timezone.utc)
|
||||
cls.query.filter(cls.expires_at < now).delete()
|
||||
db.session.commit()
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Exclude code_verifier as it's sensitive
|
||||
exclude.append("code_verifier")
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
Reference in New Issue
Block a user