can link google accounts!
This commit is contained in:
@@ -105,3 +105,232 @@ class AuditService:
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
# External Authentication Provider Audit Methods
|
||||
|
||||
@staticmethod
|
||||
def log_external_auth_link_initiated(
|
||||
user_id: str,
|
||||
organization_id: str,
|
||||
provider_type: str,
|
||||
state_id: str = None,
|
||||
):
|
||||
"""Log external auth account linking initiated event."""
|
||||
return AuditService.log_action(
|
||||
action=AuditAction.EXTERNAL_AUTH_LINK_INITIATED,
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
resource_type="oauth_state",
|
||||
resource_id=state_id,
|
||||
metadata={
|
||||
"provider_type": provider_type,
|
||||
},
|
||||
description=f"External auth link initiated for {provider_type}",
|
||||
success=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def log_external_auth_link_completed(
|
||||
user_id: str,
|
||||
organization_id: str,
|
||||
provider_type: str,
|
||||
provider_user_id: str,
|
||||
auth_method_id: str = None,
|
||||
):
|
||||
"""Log external auth account linking completed event."""
|
||||
return AuditService.log_action(
|
||||
action=AuditAction.EXTERNAL_AUTH_LINK_COMPLETED,
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=auth_method_id,
|
||||
metadata={
|
||||
"provider_type": provider_type,
|
||||
"provider_user_id": provider_user_id,
|
||||
},
|
||||
description=f"External auth account linked: {provider_type} ({provider_user_id})",
|
||||
success=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def log_external_auth_link_failed(
|
||||
user_id: str,
|
||||
organization_id: str,
|
||||
provider_type: str,
|
||||
error_message: str,
|
||||
failure_reason: str = None,
|
||||
):
|
||||
"""Log external auth account linking failed event."""
|
||||
return AuditService.log_action(
|
||||
action=AuditAction.EXTERNAL_AUTH_LINK_FAILED,
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
metadata={
|
||||
"provider_type": provider_type,
|
||||
"failure_reason": failure_reason,
|
||||
},
|
||||
description=f"External auth link failed for {provider_type}: {error_message}",
|
||||
success=False,
|
||||
error_message=error_message,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def log_external_auth_unlink(
|
||||
user_id: str,
|
||||
organization_id: str,
|
||||
provider_type: str,
|
||||
provider_user_id: str,
|
||||
auth_method_id: str = None,
|
||||
):
|
||||
"""Log external auth account unlinking event."""
|
||||
return AuditService.log_action(
|
||||
action=AuditAction.EXTERNAL_AUTH_UNLINK,
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=auth_method_id,
|
||||
metadata={
|
||||
"provider_type": provider_type,
|
||||
"provider_user_id": provider_user_id,
|
||||
},
|
||||
description=f"External auth account unlinked: {provider_type} ({provider_user_id})",
|
||||
success=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def log_external_auth_login(
|
||||
user_id: str,
|
||||
organization_id: str,
|
||||
provider_type: str,
|
||||
provider_user_id: str,
|
||||
auth_method_id: str = None,
|
||||
session_id: str = None,
|
||||
mfa_used: bool = False,
|
||||
):
|
||||
"""Log external auth login event."""
|
||||
return AuditService.log_action(
|
||||
action=AuditAction.EXTERNAL_AUTH_LOGIN,
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
resource_type="session",
|
||||
resource_id=session_id,
|
||||
metadata={
|
||||
"provider_type": provider_type,
|
||||
"provider_user_id": provider_user_id,
|
||||
"auth_method_id": auth_method_id,
|
||||
"mfa_used": mfa_used,
|
||||
},
|
||||
description=f"User logged in with {provider_type}",
|
||||
success=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def log_external_auth_login_failed(
|
||||
organization_id: str,
|
||||
provider_type: str,
|
||||
provider_user_id: str = None,
|
||||
email: str = None,
|
||||
failure_reason: str = None,
|
||||
error_message: str = None,
|
||||
):
|
||||
"""Log external auth login failed event."""
|
||||
return AuditService.log_action(
|
||||
action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED,
|
||||
user_id=None, # Unknown user
|
||||
organization_id=organization_id,
|
||||
metadata={
|
||||
"provider_type": provider_type,
|
||||
"provider_user_id": provider_user_id,
|
||||
"email": email,
|
||||
"failure_reason": failure_reason,
|
||||
},
|
||||
description=f"Failed login attempt with {provider_type}: {failure_reason or error_message}",
|
||||
success=False,
|
||||
error_message=error_message or failure_reason,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def log_external_auth_token_refresh(
|
||||
user_id: str,
|
||||
organization_id: str,
|
||||
provider_type: str,
|
||||
auth_method_id: str = None,
|
||||
):
|
||||
"""Log external auth token refresh event."""
|
||||
return AuditService.log_action(
|
||||
action=AuditAction.EXTERNAL_AUTH_TOKEN_REFRESH,
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=auth_method_id,
|
||||
metadata={
|
||||
"provider_type": provider_type,
|
||||
},
|
||||
description=f"External auth token refreshed for {provider_type}",
|
||||
success=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def log_external_auth_config_create(
|
||||
user_id: str,
|
||||
organization_id: str,
|
||||
provider_type: str,
|
||||
config_id: str = None,
|
||||
):
|
||||
"""Log external auth provider config creation event."""
|
||||
return AuditService.log_action(
|
||||
action=AuditAction.EXTERNAL_AUTH_CONFIG_CREATE,
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
resource_type="external_provider_config",
|
||||
resource_id=config_id,
|
||||
metadata={
|
||||
"provider_type": provider_type,
|
||||
},
|
||||
description=f"External auth provider config created: {provider_type}",
|
||||
success=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def log_external_auth_config_update(
|
||||
user_id: str,
|
||||
organization_id: str,
|
||||
provider_type: str,
|
||||
config_id: str = None,
|
||||
changes: dict = None,
|
||||
):
|
||||
"""Log external auth provider config update event."""
|
||||
return AuditService.log_action(
|
||||
action=AuditAction.EXTERNAL_AUTH_CONFIG_UPDATE,
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
resource_type="external_provider_config",
|
||||
resource_id=config_id,
|
||||
metadata={
|
||||
"provider_type": provider_type,
|
||||
"changes": changes,
|
||||
},
|
||||
description=f"External auth provider config updated: {provider_type}",
|
||||
success=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def log_external_auth_config_delete(
|
||||
user_id: str,
|
||||
organization_id: str,
|
||||
provider_type: str,
|
||||
config_id: str = None,
|
||||
):
|
||||
"""Log external auth provider config deletion event."""
|
||||
return AuditService.log_action(
|
||||
action=AuditAction.EXTERNAL_AUTH_CONFIG_DELETE,
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
resource_type="external_provider_config",
|
||||
resource_id=config_id,
|
||||
metadata={
|
||||
"provider_type": provider_type,
|
||||
},
|
||||
description=f"External auth provider config deleted: {provider_type}",
|
||||
success=True,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,761 @@
|
||||
"""External authentication provider service."""
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models import User, AuthenticationMethod
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExternalAuthError(Exception):
|
||||
"""Base exception for external auth errors."""
|
||||
|
||||
def __init__(self, message: str, error_type: str, status_code: int = 400):
|
||||
self.message = message
|
||||
self.error_type = error_type
|
||||
self.status_code = status_code
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class OAuthState(BaseModel):
|
||||
"""Temporary OAuth state storage for secure flow management."""
|
||||
|
||||
__tablename__ = "oauth_states"
|
||||
|
||||
# State identifier (used in OAuth redirects)
|
||||
state = db.Column(db.String(64), unique=True, nullable=False, index=True)
|
||||
|
||||
# Flow type
|
||||
flow_type = db.Column(db.String(50), nullable=False) # 'link', 'login', 'register'
|
||||
|
||||
# User context
|
||||
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True, index=True)
|
||||
organization_id = db.Column(
|
||||
db.String(36), db.ForeignKey("organizations.id"), nullable=True, index=True
|
||||
)
|
||||
|
||||
# Provider information
|
||||
provider_type = db.Column(db.String(50), nullable=False)
|
||||
|
||||
# OAuth parameters
|
||||
nonce = db.Column(db.String(128), nullable=True)
|
||||
code_verifier = db.Column(db.String(128), nullable=True)
|
||||
code_challenge = db.Column(db.String(128), nullable=True)
|
||||
redirect_uri = db.Column(db.String(2048), nullable=True)
|
||||
|
||||
# Additional state data
|
||||
extra_data = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# Expiration
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
|
||||
# Status
|
||||
used = db.Column(db.Boolean, default=False, nullable=False)
|
||||
|
||||
@classmethod
|
||||
def create_state(
|
||||
cls,
|
||||
flow_type: str,
|
||||
provider_type: AuthMethodType,
|
||||
user_id: str = None,
|
||||
organization_id: str = None,
|
||||
redirect_uri: str = None,
|
||||
nonce: str = None,
|
||||
code_verifier: str = None,
|
||||
code_challenge: str = None,
|
||||
extra_data: dict = None,
|
||||
lifetime_seconds: int = 600,
|
||||
) -> "OAuthState":
|
||||
"""Create a new OAuth state record."""
|
||||
state = secrets.token_urlsafe(32)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds)
|
||||
|
||||
return cls.create(
|
||||
state=state,
|
||||
flow_type=flow_type,
|
||||
provider_type=provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type,
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
redirect_uri=redirect_uri,
|
||||
nonce=nonce or secrets.token_urlsafe(16),
|
||||
code_verifier=code_verifier,
|
||||
code_challenge=code_challenge,
|
||||
extra_data=extra_data,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if state is still valid."""
|
||||
return (
|
||||
not self.used
|
||||
and self.expires_at
|
||||
and self.expires_at.replace(tzinfo=timezone.utc) > datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
def mark_used(self):
|
||||
"""Mark state as used."""
|
||||
self.used = True
|
||||
self.save()
|
||||
|
||||
@classmethod
|
||||
def cleanup_expired(cls):
|
||||
"""Remove expired states."""
|
||||
cls.query.filter(cls.expires_at < datetime.now(timezone.utc)).delete()
|
||||
db.session.commit()
|
||||
|
||||
|
||||
class ExternalProviderConfig(BaseModel):
|
||||
"""OAuth provider configuration per organization."""
|
||||
|
||||
__tablename__ = "external_provider_configs"
|
||||
|
||||
# Organization reference
|
||||
organization_id = db.Column(
|
||||
db.String(36), db.ForeignKey("organizations.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# Provider type
|
||||
provider_type = db.Column(db.String(50), nullable=False, index=True)
|
||||
|
||||
# OAuth credentials (client_secret is encrypted)
|
||||
client_id = db.Column(db.String(255), nullable=False)
|
||||
client_secret_encrypted = db.Column(db.String(512), nullable=True)
|
||||
|
||||
# Provider endpoints
|
||||
auth_url = db.Column(db.String(2048), nullable=False)
|
||||
token_url = db.Column(db.String(2048), nullable=False)
|
||||
userinfo_url = db.Column(db.String(2048), nullable=True)
|
||||
jwks_url = db.Column(db.String(2048), nullable=True)
|
||||
|
||||
# Configuration
|
||||
scopes = db.Column(db.JSON, nullable=False, default=list)
|
||||
redirect_uris = db.Column(db.JSON, nullable=False, default=list)
|
||||
|
||||
# Provider-specific settings
|
||||
settings = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# Status
|
||||
is_active = db.Column(db.Boolean, default=True, nullable=False)
|
||||
|
||||
# Relationships
|
||||
organization = db.relationship(
|
||||
"Organization", back_populates="external_provider_configs"
|
||||
)
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
db.Index("idx_provider_config_org", "organization_id", "provider_type"),
|
||||
db.UniqueConstraint(
|
||||
"organization_id",
|
||||
"provider_type",
|
||||
name="uix_org_provider_type",
|
||||
),
|
||||
)
|
||||
|
||||
def get_client_secret(self) -> str:
|
||||
"""Decrypt and return client secret."""
|
||||
from gatehouse_app.utils.encryption import decrypt
|
||||
if self.client_secret_encrypted:
|
||||
return decrypt(self.client_secret_encrypted)
|
||||
return None
|
||||
|
||||
def set_client_secret(self, secret: str):
|
||||
"""Encrypt and store client secret."""
|
||||
from gatehouse_app.utils.encryption import encrypt
|
||||
self.client_secret_encrypted = encrypt(secret)
|
||||
|
||||
def is_redirect_uri_allowed(self, uri: str) -> bool:
|
||||
"""Check if redirect URI is allowed."""
|
||||
return uri in (self.redirect_uris or [])
|
||||
|
||||
def to_dict(self, include_secrets: bool = False) -> dict:
|
||||
"""Convert to dictionary."""
|
||||
data = {
|
||||
"id": self.id,
|
||||
"organization_id": self.organization_id,
|
||||
"provider_type": self.provider_type,
|
||||
"client_id": self.client_id,
|
||||
"auth_url": self.auth_url,
|
||||
"token_url": self.token_url,
|
||||
"userinfo_url": self.userinfo_url,
|
||||
"scopes": self.scopes,
|
||||
"redirect_uris": self.redirect_uris,
|
||||
"is_active": self.is_active,
|
||||
"settings": self.settings,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
if include_secrets and self.client_secret_encrypted:
|
||||
data["client_secret"] = self.get_client_secret()
|
||||
return data
|
||||
|
||||
|
||||
class ExternalAuthService:
|
||||
"""Service for external authentication operations."""
|
||||
|
||||
@classmethod
|
||||
def get_provider_config(
|
||||
cls,
|
||||
organization_id: str,
|
||||
provider_type: AuthMethodType,
|
||||
) -> ExternalProviderConfig:
|
||||
"""Get provider configuration for organization."""
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
config = ExternalProviderConfig.query.filter_by(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type_str,
|
||||
is_active=True,
|
||||
).first()
|
||||
|
||||
if not config:
|
||||
raise ExternalAuthError(
|
||||
f"{provider_type_str.title()} OAuth is not configured for this organization",
|
||||
"PROVIDER_NOT_CONFIGURED",
|
||||
400,
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def initiate_link_flow(
|
||||
cls,
|
||||
user_id: str,
|
||||
provider_type: AuthMethodType,
|
||||
organization_id: str,
|
||||
redirect_uri: str = None,
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Initiate account linking flow.
|
||||
|
||||
Returns:
|
||||
Tuple of (redirect_url, state)
|
||||
"""
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
# Get provider config
|
||||
config = cls.get_provider_config(organization_id, provider_type)
|
||||
|
||||
# Validate redirect URI
|
||||
if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri):
|
||||
raise ExternalAuthError(
|
||||
"Invalid redirect URI",
|
||||
"INVALID_REDIRECT_URI",
|
||||
400,
|
||||
)
|
||||
|
||||
# Generate PKCE
|
||||
code_verifier = secrets.token_urlsafe(32)
|
||||
code_challenge = cls._compute_s256_challenge(code_verifier)
|
||||
|
||||
# Create OAuth state
|
||||
state = OAuthState.create_state(
|
||||
flow_type="link",
|
||||
provider_type=provider_type,
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
redirect_uri=redirect_uri or config.redirect_uris[0],
|
||||
code_verifier=code_verifier,
|
||||
code_challenge=code_challenge,
|
||||
lifetime_seconds=600,
|
||||
)
|
||||
|
||||
# Build authorization URL (simplified - in production would use provider-specific implementation)
|
||||
auth_url = cls._build_authorization_url(
|
||||
config=config,
|
||||
state=state,
|
||||
)
|
||||
|
||||
# Audit log - link initiated
|
||||
AuditService.log_external_auth_link_initiated(
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type_str,
|
||||
state_id=state.id,
|
||||
)
|
||||
|
||||
return auth_url, state.state
|
||||
|
||||
@classmethod
|
||||
def complete_link_flow(
|
||||
cls,
|
||||
provider_type: AuthMethodType,
|
||||
authorization_code: str,
|
||||
state: str,
|
||||
redirect_uri: str,
|
||||
) -> AuthenticationMethod:
|
||||
"""Complete account linking flow."""
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
# Validate state
|
||||
state_record = OAuthState.query.filter_by(state=state).first()
|
||||
if not state_record or not state_record.is_valid():
|
||||
AuditService.log_external_auth_link_failed(
|
||||
user_id=None,
|
||||
organization_id=None,
|
||||
provider_type=provider_type_str,
|
||||
error_message="Invalid or expired OAuth state",
|
||||
failure_reason="invalid_state",
|
||||
)
|
||||
raise ExternalAuthError(
|
||||
"Invalid or expired OAuth state",
|
||||
"INVALID_STATE",
|
||||
400,
|
||||
)
|
||||
|
||||
if state_record.flow_type != "link":
|
||||
AuditService.log_external_auth_link_failed(
|
||||
user_id=state_record.user_id,
|
||||
organization_id=state_record.organization_id,
|
||||
provider_type=provider_type_str,
|
||||
error_message="Invalid flow type for this operation",
|
||||
failure_reason="invalid_flow_type",
|
||||
)
|
||||
raise ExternalAuthError(
|
||||
"Invalid flow type for this operation",
|
||||
"INVALID_FLOW_TYPE",
|
||||
400,
|
||||
)
|
||||
|
||||
if state_record.provider_type != provider_type_str:
|
||||
AuditService.log_external_auth_link_failed(
|
||||
user_id=state_record.user_id,
|
||||
organization_id=state_record.organization_id,
|
||||
provider_type=provider_type_str,
|
||||
error_message="Provider mismatch",
|
||||
failure_reason="provider_mismatch",
|
||||
)
|
||||
raise ExternalAuthError(
|
||||
"Provider mismatch",
|
||||
"PROVIDER_MISMATCH",
|
||||
400,
|
||||
)
|
||||
|
||||
# Get provider config
|
||||
config = cls.get_provider_config(
|
||||
state_record.organization_id, provider_type
|
||||
)
|
||||
|
||||
# Exchange code for tokens (simplified - in production would use provider-specific implementation)
|
||||
tokens = cls._exchange_code(
|
||||
config=config,
|
||||
code=authorization_code,
|
||||
redirect_uri=redirect_uri,
|
||||
code_verifier=state_record.code_verifier,
|
||||
)
|
||||
|
||||
# Get user info
|
||||
user_info = cls._get_user_info(
|
||||
config=config,
|
||||
access_token=tokens["access_token"],
|
||||
)
|
||||
|
||||
# Get user
|
||||
user = User.query.get(state_record.user_id)
|
||||
if not user:
|
||||
AuditService.log_external_auth_link_failed(
|
||||
user_id=None,
|
||||
organization_id=state_record.organization_id,
|
||||
provider_type=provider_type_str,
|
||||
error_message="User not found",
|
||||
failure_reason="user_not_found",
|
||||
)
|
||||
raise ExternalAuthError(
|
||||
"User not found",
|
||||
"USER_NOT_FOUND",
|
||||
400,
|
||||
)
|
||||
|
||||
# Create or update authentication method
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
user_id=user.id,
|
||||
method_type=provider_type,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
).first()
|
||||
|
||||
if auth_method:
|
||||
# Update existing
|
||||
auth_method.provider_data = cls._encrypt_provider_data(tokens, user_info)
|
||||
auth_method.verified = user_info.get("email_verified", False)
|
||||
auth_method.last_used_at = datetime.utcnow()
|
||||
auth_method.save()
|
||||
else:
|
||||
# Create new
|
||||
auth_method = AuthenticationMethod(
|
||||
user_id=user.id,
|
||||
method_type=provider_type,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
provider_data=cls._encrypt_provider_data(tokens, user_info),
|
||||
verified=user_info.get("email_verified", False),
|
||||
is_primary=False,
|
||||
last_used_at=datetime.utcnow(),
|
||||
)
|
||||
auth_method.save()
|
||||
|
||||
# Mark state as used
|
||||
state_record.mark_used()
|
||||
|
||||
# Audit log - link completed
|
||||
AuditService.log_external_auth_link_completed(
|
||||
user_id=user.id,
|
||||
organization_id=state_record.organization_id,
|
||||
provider_type=provider_type_str,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
auth_method_id=auth_method.id,
|
||||
)
|
||||
|
||||
return auth_method
|
||||
|
||||
@classmethod
|
||||
def authenticate_with_provider(
|
||||
cls,
|
||||
provider_type: AuthMethodType,
|
||||
organization_id: str,
|
||||
authorization_code: str,
|
||||
state: str,
|
||||
redirect_uri: str,
|
||||
) -> Tuple[User, dict]:
|
||||
"""Authenticate user with external provider and return tokens."""
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
# Validate state
|
||||
state_record = OAuthState.query.filter_by(state=state).first()
|
||||
if not state_record or not state_record.is_valid():
|
||||
AuditService.log_external_auth_login_failed(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type_str,
|
||||
failure_reason="invalid_state",
|
||||
error_message="Invalid or expired OAuth state",
|
||||
)
|
||||
raise ExternalAuthError(
|
||||
"Invalid or expired OAuth state",
|
||||
"INVALID_STATE",
|
||||
400,
|
||||
)
|
||||
|
||||
# Get provider config
|
||||
config = cls.get_provider_config(organization_id, provider_type)
|
||||
|
||||
# Exchange code for tokens
|
||||
tokens = cls._exchange_code(
|
||||
config=config,
|
||||
code=authorization_code,
|
||||
redirect_uri=redirect_uri,
|
||||
code_verifier=state_record.code_verifier,
|
||||
)
|
||||
|
||||
# Get user info
|
||||
user_info = cls._get_user_info(
|
||||
config=config,
|
||||
access_token=tokens["access_token"],
|
||||
)
|
||||
|
||||
# Look up user by provider_user_id
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
method_type=provider_type,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
).first()
|
||||
|
||||
if not auth_method:
|
||||
# Check if email matches existing user
|
||||
existing_user = User.query.filter_by(
|
||||
email=user_info["email"]
|
||||
).first()
|
||||
|
||||
if existing_user:
|
||||
AuditService.log_external_auth_login_failed(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type_str,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
email=user_info["email"],
|
||||
failure_reason="email_exists",
|
||||
error_message=f"An account with email {user_info['email']} already exists",
|
||||
)
|
||||
raise ExternalAuthError(
|
||||
f"An account with email {user_info['email']} already exists. "
|
||||
"Please log in with your password and link your Google account from settings.",
|
||||
"EMAIL_EXISTS",
|
||||
400,
|
||||
)
|
||||
|
||||
AuditService.log_external_auth_login_failed(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type_str,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
email=user_info["email"],
|
||||
failure_reason="account_not_found",
|
||||
error_message="No Gatehouse account matches this external account",
|
||||
)
|
||||
raise ExternalAuthError(
|
||||
"No Gatehouse account matches this external account. Please register first.",
|
||||
"ACCOUNT_NOT_FOUND",
|
||||
400,
|
||||
)
|
||||
|
||||
user = auth_method.user
|
||||
|
||||
# Update tokens
|
||||
auth_method.provider_data = cls._encrypt_provider_data(tokens, user_info)
|
||||
auth_method.last_used_at = datetime.utcnow()
|
||||
auth_method.save()
|
||||
|
||||
# Mark state as used
|
||||
state_record.mark_used()
|
||||
|
||||
# Create session
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
session = AuthService.create_session(
|
||||
user=user,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
# Audit log - login success
|
||||
AuditService.log_external_auth_login(
|
||||
user_id=user.id,
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type_str,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
auth_method_id=auth_method.id,
|
||||
session_id=session.id,
|
||||
)
|
||||
|
||||
return user, session.to_dict()
|
||||
|
||||
@classmethod
|
||||
def unlink_provider(
|
||||
cls,
|
||||
user_id: str,
|
||||
provider_type: AuthMethodType,
|
||||
organization_id: str = None,
|
||||
) -> bool:
|
||||
"""Unlink external provider from user account."""
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
user_id=user_id,
|
||||
method_type=provider_type,
|
||||
).first()
|
||||
|
||||
if not auth_method:
|
||||
raise ExternalAuthError(
|
||||
f"Provider not linked",
|
||||
"PROVIDER_NOT_LINKED",
|
||||
400,
|
||||
)
|
||||
|
||||
# Check if this is the last auth method
|
||||
other_methods = AuthenticationMethod.query.filter_by(
|
||||
user_id=user_id,
|
||||
).count()
|
||||
|
||||
if other_methods <= 1:
|
||||
raise ExternalAuthError(
|
||||
"Cannot unlink the last authentication method",
|
||||
"CANNOT_UNLINK_LAST",
|
||||
400,
|
||||
)
|
||||
|
||||
provider_user_id = auth_method.provider_user_id
|
||||
auth_method_id = auth_method.id
|
||||
auth_method.delete()
|
||||
|
||||
# Audit log - unlink
|
||||
AuditService.log_external_auth_unlink(
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type_str,
|
||||
provider_user_id=provider_user_id,
|
||||
auth_method_id=auth_method_id,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_linked_accounts(cls, user_id: str) -> list:
|
||||
"""Get all linked external accounts for user."""
|
||||
methods = AuthenticationMethod.query.filter_by(
|
||||
user_id=user_id,
|
||||
).all()
|
||||
|
||||
external_providers = [
|
||||
AuthMethodType.GOOGLE,
|
||||
AuthMethodType.GITHUB,
|
||||
AuthMethodType.MICROSOFT,
|
||||
]
|
||||
|
||||
return [
|
||||
{
|
||||
"id": m.id,
|
||||
"provider_type": m.method_type.value if hasattr(m.method_type, 'value') else str(m.method_type),
|
||||
"provider_user_id": m.provider_user_id,
|
||||
"email": m.provider_data.get("email") if m.provider_data else None,
|
||||
"name": m.provider_data.get("name") if m.provider_data else None,
|
||||
"picture": m.provider_data.get("picture") if m.provider_data else None,
|
||||
"verified": m.verified,
|
||||
"linked_at": m.created_at.isoformat() if m.created_at else None,
|
||||
"last_used_at": m.last_used_at.isoformat() if m.last_used_at else None,
|
||||
}
|
||||
for m in methods
|
||||
if m.method_type in external_providers or str(m.method_type) in [p.value for p in external_providers]
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _compute_s256_challenge(verifier: str) -> str:
|
||||
"""Compute S256 code challenge from verifier."""
|
||||
import hashlib
|
||||
import base64
|
||||
|
||||
digest = hashlib.sha256(verifier.encode()).digest()
|
||||
return base64.urlsafe_b64encode(digest).decode().rstrip("=")
|
||||
|
||||
@staticmethod
|
||||
def _build_authorization_url(config: ExternalProviderConfig, state: OAuthState) -> str:
|
||||
"""Build authorization URL (simplified - provider-specific in production)."""
|
||||
from urllib.parse import urlencode
|
||||
|
||||
params = {
|
||||
"client_id": config.client_id,
|
||||
"redirect_uri": state.redirect_uri,
|
||||
"response_type": "code",
|
||||
"scope": " ".join(config.scopes or ["openid", "profile", "email"]),
|
||||
"state": state.state,
|
||||
"access_type": config.settings.get("access_type", "offline") if config.settings else "offline",
|
||||
"prompt": config.settings.get("prompt", "consent") if config.settings else "consent",
|
||||
}
|
||||
|
||||
if state.nonce:
|
||||
params["nonce"] = state.nonce
|
||||
|
||||
if state.code_challenge:
|
||||
params["code_challenge"] = state.code_challenge
|
||||
params["code_challenge_method"] = "S256"
|
||||
|
||||
return f"{config.auth_url}?{urlencode(params)}"
|
||||
|
||||
@staticmethod
|
||||
def _exchange_code(config: ExternalProviderConfig, code: str, redirect_uri: str, code_verifier: str = None) -> dict:
|
||||
"""Exchange authorization code for tokens (simplified - provider-specific in production)."""
|
||||
import requests
|
||||
|
||||
data = {
|
||||
"client_id": config.client_id,
|
||||
"client_secret": config.get_client_secret(),
|
||||
"code": code,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
|
||||
if code_verifier:
|
||||
data["code_verifier"] = code_verifier
|
||||
|
||||
response = requests.post(config.token_url, data=data)
|
||||
response.raise_for_status()
|
||||
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def _get_user_info(config: ExternalProviderConfig, access_token: str) -> dict:
|
||||
"""Get user info from provider (simplified - provider-specific in production)."""
|
||||
import requests
|
||||
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
response = requests.get(config.userinfo_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
|
||||
# Standardize user info
|
||||
return {
|
||||
"provider_user_id": data.get("sub"),
|
||||
"email": data.get("email"),
|
||||
"email_verified": data.get("email_verified", False),
|
||||
"name": data.get("name"),
|
||||
"first_name": data.get("given_name"),
|
||||
"last_name": data.get("family_name"),
|
||||
"picture": data.get("picture"),
|
||||
"raw_data": data,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _encrypt_provider_data(tokens: dict, user_info: dict) -> dict:
|
||||
"""Encrypt and store provider tokens and user info."""
|
||||
from gatehouse_app.utils.encryption import encrypt
|
||||
|
||||
result = {
|
||||
"access_token": encrypt(tokens.get("access_token")) if tokens.get("access_token") else None,
|
||||
"token_type": tokens.get("token_type", "Bearer"),
|
||||
"expires_in": tokens.get("expires_in"),
|
||||
"refresh_token": encrypt(tokens.get("refresh_token")) if tokens.get("refresh_token") else None,
|
||||
"scope": tokens.get("scope", []),
|
||||
"id_token": encrypt(tokens.get("id_token")) if tokens.get("id_token") else None,
|
||||
"email": user_info.get("email"),
|
||||
"name": user_info.get("name"),
|
||||
"picture": user_info.get("picture"),
|
||||
"raw_data": user_info.get("raw_data", {}),
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _decrypt_provider_data(provider_data: dict) -> dict:
|
||||
"""
|
||||
Decrypt provider tokens from stored data.
|
||||
|
||||
This method handles backward compatibility with existing data where
|
||||
access_token may be stored in plain text (unencrypted).
|
||||
"""
|
||||
from gatehouse_app.utils.encryption import decrypt
|
||||
|
||||
if not provider_data:
|
||||
return {}
|
||||
|
||||
result = {
|
||||
"token_type": provider_data.get("token_type", "Bearer"),
|
||||
"expires_in": provider_data.get("expires_in"),
|
||||
"scope": provider_data.get("scope", []),
|
||||
"email": provider_data.get("email"),
|
||||
"name": provider_data.get("name"),
|
||||
"picture": provider_data.get("picture"),
|
||||
"raw_data": provider_data.get("raw_data", {}),
|
||||
}
|
||||
|
||||
# Decrypt access_token with backward compatibility
|
||||
access_token = provider_data.get("access_token")
|
||||
if access_token:
|
||||
# Try to decrypt - if it fails, assume it's plain text (old data)
|
||||
try:
|
||||
result["access_token"] = decrypt(access_token)
|
||||
except Exception:
|
||||
# Access token is plain text (pre-encryption data)
|
||||
result["access_token"] = access_token
|
||||
else:
|
||||
result["access_token"] = None
|
||||
|
||||
# Decrypt refresh_token
|
||||
refresh_token = provider_data.get("refresh_token")
|
||||
if refresh_token:
|
||||
try:
|
||||
result["refresh_token"] = decrypt(refresh_token)
|
||||
except Exception:
|
||||
result["refresh_token"] = refresh_token
|
||||
else:
|
||||
result["refresh_token"] = None
|
||||
|
||||
# Decrypt id_token
|
||||
id_token = provider_data.get("id_token")
|
||||
if id_token:
|
||||
try:
|
||||
result["id_token"] = decrypt(id_token)
|
||||
except Exception:
|
||||
result["id_token"] = id_token
|
||||
else:
|
||||
result["id_token"] = None
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,524 @@
|
||||
"""OAuth flow service for handling external authentication flows."""
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from flask import current_app, request, g
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models import User, AuthenticationMethod
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
from gatehouse_app.services.external_auth_service import (
|
||||
ExternalAuthService,
|
||||
ExternalAuthError,
|
||||
OAuthState,
|
||||
ExternalProviderConfig,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthFlowError(Exception):
|
||||
"""Exception for OAuth flow errors."""
|
||||
|
||||
def __init__(self, message: str, error_type: str, status_code: int = 400):
|
||||
self.message = message
|
||||
self.error_type = error_type
|
||||
self.status_code = status_code
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class OAuthFlowService:
|
||||
"""Service for managing OAuth authentication flows."""
|
||||
|
||||
@classmethod
|
||||
def initiate_login_flow(
|
||||
cls,
|
||||
provider_type: AuthMethodType,
|
||||
organization_id: str = None,
|
||||
redirect_uri: str = None,
|
||||
state_data: dict = None,
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Initiate OAuth login flow.
|
||||
|
||||
Args:
|
||||
provider_type: The authentication provider type
|
||||
organization_id: Optional organization context for SSO
|
||||
redirect_uri: Optional custom redirect URI
|
||||
state_data: Additional state data to include
|
||||
|
||||
Returns:
|
||||
Tuple of (authorization_url, state)
|
||||
"""
|
||||
# Get request context for audit logging
|
||||
try:
|
||||
ip_address = request.remote_addr if request else None
|
||||
user_agent = request.headers.get("User-Agent") if request else None
|
||||
except RuntimeError:
|
||||
ip_address = None
|
||||
user_agent = None
|
||||
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
try:
|
||||
# Get provider config
|
||||
config = ExternalAuthService.get_provider_config(organization_id, provider_type)
|
||||
|
||||
# Validate redirect URI
|
||||
if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri):
|
||||
raise OAuthFlowError(
|
||||
"Invalid redirect URI",
|
||||
"INVALID_REDIRECT_URI",
|
||||
400,
|
||||
)
|
||||
|
||||
# Generate PKCE
|
||||
code_verifier = secrets.token_urlsafe(32)
|
||||
code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier)
|
||||
|
||||
# Create OAuth state for login flow
|
||||
state = OAuthState.create_state(
|
||||
flow_type="login",
|
||||
provider_type=provider_type,
|
||||
organization_id=organization_id,
|
||||
redirect_uri=redirect_uri or (config.redirect_uris[0] if config.redirect_uris else None),
|
||||
code_verifier=code_verifier,
|
||||
code_challenge=code_challenge,
|
||||
extra_data=state_data,
|
||||
lifetime_seconds=600,
|
||||
)
|
||||
|
||||
# Build authorization URL
|
||||
auth_url = ExternalAuthService._build_authorization_url(
|
||||
config=config,
|
||||
state=state,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth login flow initiated for provider={provider_type_str}, "
|
||||
f"org_id={organization_id}, state_id={state.id}"
|
||||
)
|
||||
|
||||
return auth_url, state.state
|
||||
|
||||
except ExternalAuthError as e:
|
||||
# Log failed initiation
|
||||
AuditService.log_action(
|
||||
action="external_auth.login.initiated",
|
||||
organization_id=organization_id,
|
||||
metadata={
|
||||
"provider_type": provider_type_str,
|
||||
"failure_reason": e.error_type,
|
||||
"ip_address": ip_address,
|
||||
},
|
||||
description=f"OAuth login initiation failed: {e.message}",
|
||||
success=False,
|
||||
error_message=e.message,
|
||||
)
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def initiate_register_flow(
|
||||
cls,
|
||||
provider_type: AuthMethodType,
|
||||
organization_id: str = None,
|
||||
redirect_uri: str = None,
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Initiate OAuth registration flow.
|
||||
|
||||
Args:
|
||||
provider_type: The authentication provider type
|
||||
organization_id: Optional organization context
|
||||
redirect_uri: Optional custom redirect URI
|
||||
|
||||
Returns:
|
||||
Tuple of (authorization_url, state)
|
||||
"""
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
try:
|
||||
# Get provider config
|
||||
config = ExternalAuthService.get_provider_config(organization_id, provider_type)
|
||||
|
||||
# Validate redirect URI
|
||||
if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri):
|
||||
raise OAuthFlowError(
|
||||
"Invalid redirect URI",
|
||||
"INVALID_REDIRECT_URI",
|
||||
400,
|
||||
)
|
||||
|
||||
# Generate PKCE
|
||||
code_verifier = secrets.token_urlsafe(32)
|
||||
code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier)
|
||||
|
||||
# Create OAuth state for register flow
|
||||
state = OAuthState.create_state(
|
||||
flow_type="register",
|
||||
provider_type=provider_type,
|
||||
organization_id=organization_id,
|
||||
redirect_uri=redirect_uri or (config.redirect_uris[0] if config.redirect_uris else None),
|
||||
code_verifier=code_verifier,
|
||||
code_challenge=code_challenge,
|
||||
lifetime_seconds=600,
|
||||
)
|
||||
|
||||
# Build authorization URL
|
||||
auth_url = ExternalAuthService._build_authorization_url(
|
||||
config=config,
|
||||
state=state,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth register flow initiated for provider={provider_type_str}, "
|
||||
f"org_id={organization_id}, state_id={state.id}"
|
||||
)
|
||||
|
||||
return auth_url, state.state
|
||||
|
||||
except ExternalAuthError as e:
|
||||
AuditService.log_action(
|
||||
action="external_auth.register.initiated",
|
||||
organization_id=organization_id,
|
||||
metadata={
|
||||
"provider_type": provider_type_str,
|
||||
"failure_reason": e.error_type,
|
||||
},
|
||||
description=f"OAuth registration initiation failed: {e.message}",
|
||||
success=False,
|
||||
error_message=e.message,
|
||||
)
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def handle_callback(
|
||||
cls,
|
||||
provider_type: AuthMethodType,
|
||||
authorization_code: str,
|
||||
state: str,
|
||||
redirect_uri: str = None,
|
||||
error: str = None,
|
||||
error_description: str = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Handle OAuth callback from provider.
|
||||
|
||||
Args:
|
||||
provider_type: The authentication provider type
|
||||
authorization_code: Authorization code from provider
|
||||
state: State parameter from provider
|
||||
redirect_uri: Redirect URI used in the flow
|
||||
error: Error code if auth failed
|
||||
error_description: Human-readable error description
|
||||
|
||||
Returns:
|
||||
Dict with flow result
|
||||
"""
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
# Get request context for audit logging
|
||||
try:
|
||||
ip_address = request.remote_addr if request else None
|
||||
user_agent = request.headers.get("User-Agent") if request else None
|
||||
except RuntimeError:
|
||||
ip_address = None
|
||||
user_agent = None
|
||||
|
||||
# Handle error response from provider
|
||||
if error:
|
||||
AuditService.log_external_auth_login_failed(
|
||||
organization_id=None,
|
||||
provider_type=provider_type_str,
|
||||
failure_reason=error,
|
||||
error_message=error_description or error,
|
||||
)
|
||||
raise OAuthFlowError(
|
||||
error_description or f"OAuth error: {error}",
|
||||
error.upper() if error else "OAUTH_ERROR",
|
||||
400,
|
||||
)
|
||||
|
||||
# Validate state
|
||||
state_record = OAuthState.query.filter_by(state=state).first()
|
||||
if not state_record or not state_record.is_valid():
|
||||
AuditService.log_external_auth_login_failed(
|
||||
organization_id=state_record.organization_id if state_record else None,
|
||||
provider_type=provider_type_str,
|
||||
failure_reason="invalid_state",
|
||||
error_message="Invalid or expired OAuth state",
|
||||
)
|
||||
raise OAuthFlowError(
|
||||
"Invalid or expired OAuth state",
|
||||
"INVALID_STATE",
|
||||
400,
|
||||
)
|
||||
|
||||
# Route to appropriate handler based on flow type
|
||||
if state_record.flow_type == "login":
|
||||
return cls._handle_login_callback(
|
||||
provider_type=provider_type,
|
||||
state_record=state_record,
|
||||
authorization_code=authorization_code,
|
||||
redirect_uri=redirect_uri or state_record.redirect_uri,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
elif state_record.flow_type == "link":
|
||||
return cls._handle_link_callback(
|
||||
provider_type=provider_type,
|
||||
state_record=state_record,
|
||||
authorization_code=authorization_code,
|
||||
redirect_uri=redirect_uri or state_record.redirect_uri,
|
||||
)
|
||||
elif state_record.flow_type == "register":
|
||||
return cls._handle_register_callback(
|
||||
provider_type=provider_type,
|
||||
state_record=state_record,
|
||||
authorization_code=authorization_code,
|
||||
redirect_uri=redirect_uri or state_record.redirect_uri,
|
||||
)
|
||||
else:
|
||||
raise OAuthFlowError(
|
||||
f"Unknown flow type: {state_record.flow_type}",
|
||||
"INVALID_FLOW_TYPE",
|
||||
400,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _handle_login_callback(
|
||||
cls,
|
||||
provider_type: AuthMethodType,
|
||||
state_record: OAuthState,
|
||||
authorization_code: str,
|
||||
redirect_uri: str,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
) -> dict:
|
||||
"""Handle login flow callback."""
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
try:
|
||||
# Authenticate with provider
|
||||
user, session_data = ExternalAuthService.authenticate_with_provider(
|
||||
provider_type=provider_type,
|
||||
organization_id=state_record.organization_id,
|
||||
authorization_code=authorization_code,
|
||||
state=state_record.state,
|
||||
redirect_uri=redirect_uri,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth login successful for user={user.id}, "
|
||||
f"provider={provider_type_str}, org_id={state_record.organization_id}"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"flow_type": "login",
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"organization_id": state_record.organization_id,
|
||||
},
|
||||
"session": session_data,
|
||||
}
|
||||
|
||||
except ExternalAuthError as e:
|
||||
logger.warning(
|
||||
f"OAuth login failed for state={state_record.id}, "
|
||||
f"provider={provider_type_str}, error={e.message}"
|
||||
)
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def _handle_link_callback(
|
||||
cls,
|
||||
provider_type: AuthMethodType,
|
||||
state_record: OAuthState,
|
||||
authorization_code: str,
|
||||
redirect_uri: str,
|
||||
) -> dict:
|
||||
"""Handle account linking flow callback."""
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
try:
|
||||
# Complete link flow
|
||||
auth_method = ExternalAuthService.complete_link_flow(
|
||||
provider_type=provider_type,
|
||||
authorization_code=authorization_code,
|
||||
state=state_record.state,
|
||||
redirect_uri=redirect_uri,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth link successful for user={state_record.user_id}, "
|
||||
f"provider={provider_type_str}, auth_method_id={auth_method.id}"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"flow_type": "link",
|
||||
"linked_account": {
|
||||
"id": auth_method.id,
|
||||
"provider_type": provider_type_str,
|
||||
"provider_user_id": auth_method.provider_user_id,
|
||||
"verified": auth_method.verified,
|
||||
},
|
||||
}
|
||||
|
||||
except ExternalAuthError as e:
|
||||
logger.warning(
|
||||
f"OAuth link failed for state={state_record.id}, "
|
||||
f"provider={provider_type_str}, error={e.message}"
|
||||
)
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def _handle_register_callback(
|
||||
cls,
|
||||
provider_type: AuthMethodType,
|
||||
state_record: OAuthState,
|
||||
authorization_code: str,
|
||||
redirect_uri: str,
|
||||
) -> dict:
|
||||
"""Handle registration flow callback."""
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
try:
|
||||
# Get provider config
|
||||
config = ExternalAuthService.get_provider_config(
|
||||
state_record.organization_id, provider_type
|
||||
)
|
||||
|
||||
# Exchange code for tokens
|
||||
tokens = ExternalAuthService._exchange_code(
|
||||
config=config,
|
||||
code=authorization_code,
|
||||
redirect_uri=redirect_uri,
|
||||
code_verifier=state_record.code_verifier,
|
||||
)
|
||||
|
||||
# Get user info
|
||||
user_info = ExternalAuthService._get_user_info(
|
||||
config=config,
|
||||
access_token=tokens["access_token"],
|
||||
)
|
||||
|
||||
# Check if user already exists by email
|
||||
existing_user = User.query.filter_by(
|
||||
email=user_info["email"]
|
||||
).first()
|
||||
|
||||
if existing_user:
|
||||
# User exists - suggest linking
|
||||
raise OAuthFlowError(
|
||||
f"An account with email {user_info['email']} already exists. "
|
||||
"Please log in with your password and link your Google account from settings.",
|
||||
"EMAIL_EXISTS",
|
||||
400,
|
||||
)
|
||||
|
||||
# Create new user
|
||||
user = User(
|
||||
email=user_info["email"],
|
||||
full_name=user_info.get("name", ""),
|
||||
status="active",
|
||||
)
|
||||
user.save()
|
||||
|
||||
# Create authentication method
|
||||
auth_method = AuthenticationMethod(
|
||||
user_id=user.id,
|
||||
method_type=provider_type,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
provider_data=ExternalAuthService._encrypt_provider_data(tokens, user_info),
|
||||
verified=user_info.get("email_verified", False),
|
||||
is_primary=True,
|
||||
)
|
||||
auth_method.save()
|
||||
|
||||
# Mark state as used
|
||||
state_record.mark_used()
|
||||
|
||||
# Audit log - registration success
|
||||
AuditService.log_action(
|
||||
action="user.register",
|
||||
user_id=user.id,
|
||||
organization_id=state_record.organization_id,
|
||||
resource_type="user",
|
||||
resource_id=user.id,
|
||||
metadata={
|
||||
"provider_type": provider_type_str,
|
||||
"provider_user_id": user_info["provider_user_id"],
|
||||
"auth_method_id": auth_method.id,
|
||||
},
|
||||
description=f"User registered via {provider_type_str}",
|
||||
success=True,
|
||||
)
|
||||
|
||||
AuditService.log_external_auth_link_completed(
|
||||
user_id=user.id,
|
||||
organization_id=state_record.organization_id,
|
||||
provider_type=provider_type_str,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
auth_method_id=auth_method.id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth registration successful for email={user_info['email']}, "
|
||||
f"provider={provider_type_str}, user_id={user.id}"
|
||||
)
|
||||
|
||||
# Create session
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
session = AuthService.create_session(
|
||||
user=user,
|
||||
organization_id=state_record.organization_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"flow_type": "register",
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"organization_id": state_record.organization_id,
|
||||
},
|
||||
"session": session.to_dict(),
|
||||
}
|
||||
|
||||
except ExternalAuthError as e:
|
||||
logger.warning(
|
||||
f"OAuth registration failed for state={state_record.id}, "
|
||||
f"provider={provider_type_str}, error={e.message}"
|
||||
)
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def validate_state(cls, state: str) -> Optional[OAuthState]:
|
||||
"""
|
||||
Validate and return OAuth state.
|
||||
|
||||
Args:
|
||||
state: The state parameter to validate
|
||||
|
||||
Returns:
|
||||
OAuthState if valid, None otherwise
|
||||
"""
|
||||
state_record = OAuthState.query.filter_by(state=state).first()
|
||||
if state_record and state_record.is_valid():
|
||||
return state_record
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def cleanup_expired_states(cls):
|
||||
"""Remove expired OAuth states."""
|
||||
OAuthState.cleanup_expired()
|
||||
logger.info("Expired OAuth states cleaned up")
|
||||
Reference in New Issue
Block a user