google login works
This commit is contained in:
@@ -2,12 +2,17 @@
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, Dict, Any
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models import User, AuthenticationMethod
|
||||
from gatehouse_app.models.authentication_method import (
|
||||
OAuthState,
|
||||
ApplicationProviderConfig,
|
||||
OrganizationProviderOverride
|
||||
)
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
@@ -25,95 +30,12 @@ class ExternalAuthError(Exception):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class OAuthState(BaseModel):
|
||||
"""Temporary OAuth state storage for secure flow management."""
|
||||
|
||||
__tablename__ = "oauth_states"
|
||||
|
||||
# State identifier (used in OAuth redirects)
|
||||
state = db.Column(db.String(64), unique=True, nullable=False, index=True)
|
||||
|
||||
# Flow type
|
||||
flow_type = db.Column(db.String(50), nullable=False) # 'link', 'login', 'register'
|
||||
|
||||
# User context
|
||||
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True, index=True)
|
||||
organization_id = db.Column(
|
||||
db.String(36), db.ForeignKey("organizations.id"), nullable=True, index=True
|
||||
)
|
||||
|
||||
# Provider information
|
||||
provider_type = db.Column(db.String(50), nullable=False)
|
||||
|
||||
# OAuth parameters
|
||||
nonce = db.Column(db.String(128), nullable=True)
|
||||
code_verifier = db.Column(db.String(128), nullable=True)
|
||||
code_challenge = db.Column(db.String(128), nullable=True)
|
||||
redirect_uri = db.Column(db.String(2048), nullable=True)
|
||||
|
||||
# Additional state data
|
||||
extra_data = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# Expiration
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
|
||||
# Status
|
||||
used = db.Column(db.Boolean, default=False, nullable=False)
|
||||
|
||||
@classmethod
|
||||
def create_state(
|
||||
cls,
|
||||
flow_type: str,
|
||||
provider_type: AuthMethodType,
|
||||
user_id: str = None,
|
||||
organization_id: str = None,
|
||||
redirect_uri: str = None,
|
||||
nonce: str = None,
|
||||
code_verifier: str = None,
|
||||
code_challenge: str = None,
|
||||
extra_data: dict = None,
|
||||
lifetime_seconds: int = 600,
|
||||
) -> "OAuthState":
|
||||
"""Create a new OAuth state record."""
|
||||
state = secrets.token_urlsafe(32)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds)
|
||||
|
||||
return cls.create(
|
||||
state=state,
|
||||
flow_type=flow_type,
|
||||
provider_type=provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type,
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
redirect_uri=redirect_uri,
|
||||
nonce=nonce or secrets.token_urlsafe(16),
|
||||
code_verifier=code_verifier,
|
||||
code_challenge=code_challenge,
|
||||
extra_data=extra_data,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if state is still valid."""
|
||||
return (
|
||||
not self.used
|
||||
and self.expires_at
|
||||
and self.expires_at.replace(tzinfo=timezone.utc) > datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
def mark_used(self):
|
||||
"""Mark state as used."""
|
||||
self.used = True
|
||||
self.save()
|
||||
|
||||
@classmethod
|
||||
def cleanup_expired(cls):
|
||||
"""Remove expired states."""
|
||||
cls.query.filter(cls.expires_at < datetime.now(timezone.utc)).delete()
|
||||
db.session.commit()
|
||||
|
||||
|
||||
class ExternalProviderConfig(BaseModel):
|
||||
"""OAuth provider configuration per organization."""
|
||||
"""OAuth provider configuration per organization.
|
||||
|
||||
DEPRECATED: This model is maintained for backward compatibility only.
|
||||
Use ApplicationProviderConfig and OrganizationProviderOverride instead.
|
||||
"""
|
||||
|
||||
__tablename__ = "external_provider_configs"
|
||||
|
||||
@@ -198,31 +120,594 @@ class ExternalProviderConfig(BaseModel):
|
||||
return data
|
||||
|
||||
|
||||
class ProviderConfigAdapter:
|
||||
"""
|
||||
Adapter to provide a unified interface for provider configuration.
|
||||
|
||||
This merges application-level config with optional organization overrides,
|
||||
presenting a single config object that works with existing OAuth flow code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_config: ApplicationProviderConfig,
|
||||
org_override: Optional[OrganizationProviderOverride] = None
|
||||
):
|
||||
"""
|
||||
Initialize adapter with app config and optional org override.
|
||||
|
||||
Args:
|
||||
app_config: Application-level provider configuration
|
||||
org_override: Optional organization-specific override
|
||||
"""
|
||||
self.app_config = app_config
|
||||
self.org_override = org_override
|
||||
self.provider_type = app_config.provider_type
|
||||
|
||||
@property
|
||||
def client_id(self) -> str:
|
||||
"""Get effective client ID (override takes precedence)."""
|
||||
if self.org_override and self.org_override.client_id:
|
||||
return self.org_override.client_id
|
||||
return self.app_config.client_id
|
||||
|
||||
def get_client_secret(self) -> str:
|
||||
"""Get effective client secret (override takes precedence)."""
|
||||
if self.org_override and self.org_override.client_secret_encrypted:
|
||||
return self.org_override.get_client_secret()
|
||||
return self.app_config.get_client_secret()
|
||||
|
||||
@property
|
||||
def auth_url(self) -> str:
|
||||
"""Get authorization URL from app config."""
|
||||
# Provider endpoints are not overridable
|
||||
return self._get_provider_endpoint('auth_url')
|
||||
|
||||
@property
|
||||
def token_url(self) -> str:
|
||||
"""Get token URL from app config."""
|
||||
return self._get_provider_endpoint('token_url')
|
||||
|
||||
@property
|
||||
def userinfo_url(self) -> str:
|
||||
"""Get userinfo URL from app config."""
|
||||
return self._get_provider_endpoint('userinfo_url')
|
||||
|
||||
@property
|
||||
def jwks_url(self) -> str:
|
||||
"""Get JWKS URL from app config."""
|
||||
return self._get_provider_endpoint('jwks_url')
|
||||
|
||||
@property
|
||||
def scopes(self) -> list:
|
||||
"""Get effective scopes (merged from app config and override)."""
|
||||
base_scopes = self.app_config.additional_config.get('scopes', []) if self.app_config.additional_config else []
|
||||
if self.org_override and self.org_override.additional_config:
|
||||
override_scopes = self.org_override.additional_config.get('scopes')
|
||||
if override_scopes is not None:
|
||||
return override_scopes
|
||||
return base_scopes or ['openid', 'profile', 'email']
|
||||
|
||||
@property
|
||||
def redirect_uris(self) -> list:
|
||||
"""Get effective redirect URIs."""
|
||||
# Use override redirect URL if present, otherwise app default
|
||||
if self.org_override and self.org_override.redirect_url_override:
|
||||
return [self.org_override.redirect_url_override]
|
||||
if self.app_config.default_redirect_url:
|
||||
return [self.app_config.default_redirect_url]
|
||||
return []
|
||||
|
||||
@property
|
||||
def settings(self) -> dict:
|
||||
"""Get merged settings (app config + org override)."""
|
||||
settings = {}
|
||||
if self.app_config.additional_config:
|
||||
settings.update(self.app_config.additional_config)
|
||||
if self.org_override and self.org_override.additional_config:
|
||||
settings.update(self.org_override.additional_config)
|
||||
return settings
|
||||
|
||||
@property
|
||||
def is_active(self) -> bool:
|
||||
"""Check if provider is active (both app and org must be enabled)."""
|
||||
app_enabled = self.app_config.is_enabled
|
||||
org_enabled = True if not self.org_override else self.org_override.is_enabled
|
||||
return app_enabled and org_enabled
|
||||
|
||||
def is_redirect_uri_allowed(self, uri: str) -> bool:
|
||||
"""Check if redirect URI is allowed."""
|
||||
return uri in self.redirect_uris
|
||||
|
||||
def _get_provider_endpoint(self, endpoint_name: str) -> Optional[str]:
|
||||
"""
|
||||
Get provider endpoint from app config additional_config.
|
||||
|
||||
For application-wide configs, endpoints are stored in additional_config JSON.
|
||||
"""
|
||||
if not self.app_config.additional_config:
|
||||
return None
|
||||
return self.app_config.additional_config.get(endpoint_name)
|
||||
|
||||
|
||||
class ExternalAuthService:
|
||||
"""Service for external authentication operations."""
|
||||
|
||||
@classmethod
|
||||
def get_provider_config(
|
||||
cls,
|
||||
organization_id: str,
|
||||
provider_type: AuthMethodType,
|
||||
) -> ExternalProviderConfig:
|
||||
"""Get provider configuration for organization."""
|
||||
organization_id: Optional[str] = None,
|
||||
) -> ProviderConfigAdapter:
|
||||
"""
|
||||
Get provider configuration for authentication.
|
||||
|
||||
This method retrieves application-wide provider configuration and merges
|
||||
it with organization-specific overrides if present. Both the application
|
||||
config and organization override (if present) must be enabled for the
|
||||
provider to be considered active.
|
||||
|
||||
Configuration Precedence:
|
||||
1. Application-level config provides the baseline configuration
|
||||
2. Organization override can override client_id and client_secret (for SSO)
|
||||
3. Both must be enabled for the provider to work
|
||||
|
||||
Args:
|
||||
provider_type: The OAuth provider type (google, github, etc.)
|
||||
organization_id: Optional organization ID for override lookup
|
||||
|
||||
Returns:
|
||||
ProviderConfigAdapter: Unified config object with merged settings
|
||||
|
||||
Raises:
|
||||
ExternalAuthError: If provider is not configured or disabled
|
||||
"""
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
config = ExternalProviderConfig.query.filter_by(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type_str,
|
||||
is_active=True,
|
||||
|
||||
# Get application-wide config
|
||||
app_config = ApplicationProviderConfig.query.filter_by(
|
||||
provider_type=provider_type_str
|
||||
).first()
|
||||
|
||||
if not config:
|
||||
|
||||
if not app_config:
|
||||
raise ExternalAuthError(
|
||||
f"{provider_type_str.title()} OAuth is not configured for this organization",
|
||||
f"{provider_type_str.title()} OAuth is not configured for this application",
|
||||
"PROVIDER_NOT_CONFIGURED",
|
||||
400,
|
||||
)
|
||||
|
||||
if not app_config.is_enabled:
|
||||
raise ExternalAuthError(
|
||||
f"{provider_type_str.title()} OAuth is currently disabled",
|
||||
"PROVIDER_DISABLED",
|
||||
400,
|
||||
)
|
||||
|
||||
# Check for organization-specific override
|
||||
org_override = None
|
||||
if organization_id:
|
||||
org_override = OrganizationProviderOverride.query.filter_by(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type_str
|
||||
).first()
|
||||
|
||||
# If override exists but is disabled, provider is not available for this org
|
||||
if org_override and not org_override.is_enabled:
|
||||
raise ExternalAuthError(
|
||||
f"{provider_type_str.title()} OAuth is disabled for this organization",
|
||||
"PROVIDER_DISABLED_FOR_ORG",
|
||||
400,
|
||||
)
|
||||
|
||||
# Return adapter with merged configuration
|
||||
return ProviderConfigAdapter(app_config, org_override)
|
||||
|
||||
# ==================== Application-Wide Provider Management ====================
|
||||
|
||||
@classmethod
|
||||
def create_app_provider_config(
|
||||
cls,
|
||||
provider_type: str,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
**kwargs
|
||||
) -> ApplicationProviderConfig:
|
||||
"""
|
||||
Create application-wide provider configuration.
|
||||
|
||||
Args:
|
||||
provider_type: Provider type (google, github, etc.)
|
||||
client_id: OAuth client ID
|
||||
client_secret: OAuth client secret
|
||||
**kwargs: Additional config (auth_url, token_url, userinfo_url, scopes, etc.)
|
||||
|
||||
Returns:
|
||||
ApplicationProviderConfig: Created configuration
|
||||
|
||||
Raises:
|
||||
ExternalAuthError: If provider already exists
|
||||
"""
|
||||
# Check if provider already exists
|
||||
existing = ApplicationProviderConfig.query.filter_by(
|
||||
provider_type=provider_type
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
raise ExternalAuthError(
|
||||
f"Provider {provider_type} already exists",
|
||||
"PROVIDER_EXISTS",
|
||||
400
|
||||
)
|
||||
|
||||
# Build additional_config with endpoints and settings
|
||||
additional_config = {}
|
||||
for key in ['auth_url', 'token_url', 'userinfo_url', 'jwks_url', 'scopes']:
|
||||
if key in kwargs:
|
||||
additional_config[key] = kwargs.pop(key)
|
||||
|
||||
# Add any extra settings
|
||||
if 'settings' in kwargs:
|
||||
additional_config.update(kwargs.pop('settings'))
|
||||
|
||||
# Create new config
|
||||
config = ApplicationProviderConfig(
|
||||
provider_type=provider_type,
|
||||
client_id=client_id,
|
||||
is_enabled=kwargs.get('is_enabled', True),
|
||||
default_redirect_url=kwargs.get('default_redirect_url'),
|
||||
additional_config=additional_config
|
||||
)
|
||||
|
||||
# Set encrypted secret
|
||||
config.set_client_secret(client_secret)
|
||||
config.save()
|
||||
|
||||
logger.info(f"Created application provider config for {provider_type}")
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def update_app_provider_config(
|
||||
cls,
|
||||
provider_type: str,
|
||||
**updates
|
||||
) -> ApplicationProviderConfig:
|
||||
"""
|
||||
Update application-wide provider configuration.
|
||||
|
||||
Args:
|
||||
provider_type: Provider type to update
|
||||
**updates: Fields to update (client_id, client_secret, is_enabled, etc.)
|
||||
|
||||
Returns:
|
||||
ApplicationProviderConfig: Updated configuration
|
||||
|
||||
Raises:
|
||||
ExternalAuthError: If provider not found
|
||||
"""
|
||||
config = ApplicationProviderConfig.query.filter_by(
|
||||
provider_type=provider_type
|
||||
).first()
|
||||
|
||||
if not config:
|
||||
raise ExternalAuthError(
|
||||
f"Provider {provider_type} not found",
|
||||
"PROVIDER_NOT_FOUND",
|
||||
404
|
||||
)
|
||||
|
||||
# Update simple fields
|
||||
if 'client_id' in updates:
|
||||
config.client_id = updates['client_id']
|
||||
|
||||
if 'client_secret' in updates:
|
||||
config.set_client_secret(updates['client_secret'])
|
||||
|
||||
if 'is_enabled' in updates:
|
||||
config.is_enabled = updates['is_enabled']
|
||||
|
||||
if 'default_redirect_url' in updates:
|
||||
config.default_redirect_url = updates['default_redirect_url']
|
||||
|
||||
# Update additional_config JSON fields
|
||||
if config.additional_config is None:
|
||||
config.additional_config = {}
|
||||
|
||||
for key in ['auth_url', 'token_url', 'userinfo_url', 'jwks_url', 'scopes']:
|
||||
if key in updates:
|
||||
config.additional_config[key] = updates[key]
|
||||
|
||||
if 'settings' in updates:
|
||||
config.additional_config.update(updates['settings'])
|
||||
|
||||
config.save()
|
||||
logger.info(f"Updated application provider config for {provider_type}")
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def get_app_provider_config(cls, provider_type: str) -> ApplicationProviderConfig:
|
||||
"""
|
||||
Get application-wide provider configuration.
|
||||
|
||||
Args:
|
||||
provider_type: Provider type to retrieve
|
||||
|
||||
Returns:
|
||||
ApplicationProviderConfig: Provider configuration
|
||||
|
||||
Raises:
|
||||
ExternalAuthError: If provider not found
|
||||
"""
|
||||
config = ApplicationProviderConfig.query.filter_by(
|
||||
provider_type=provider_type
|
||||
).first()
|
||||
|
||||
if not config:
|
||||
raise ExternalAuthError(
|
||||
f"Provider {provider_type} not found",
|
||||
"PROVIDER_NOT_FOUND",
|
||||
404
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def list_app_provider_configs(cls) -> list:
|
||||
"""
|
||||
List all application-wide provider configurations.
|
||||
|
||||
Returns:
|
||||
list: List of provider configuration dictionaries
|
||||
"""
|
||||
configs = ApplicationProviderConfig.query.all()
|
||||
return [config.to_dict() for config in configs]
|
||||
|
||||
@classmethod
|
||||
def delete_app_provider_config(cls, provider_type: str) -> bool:
|
||||
"""
|
||||
Delete application-wide provider configuration.
|
||||
|
||||
Args:
|
||||
provider_type: Provider type to delete
|
||||
|
||||
Returns:
|
||||
bool: True if deleted successfully
|
||||
|
||||
Raises:
|
||||
ExternalAuthError: If provider not found
|
||||
"""
|
||||
config = ApplicationProviderConfig.query.filter_by(
|
||||
provider_type=provider_type
|
||||
).first()
|
||||
|
||||
if not config:
|
||||
raise ExternalAuthError(
|
||||
f"Provider {provider_type} not found",
|
||||
"PROVIDER_NOT_FOUND",
|
||||
404
|
||||
)
|
||||
|
||||
config.delete()
|
||||
logger.info(f"Deleted application provider config for {provider_type}")
|
||||
return True
|
||||
|
||||
# ==================== Organization Provider Override Management ====================
|
||||
|
||||
@classmethod
|
||||
def create_org_provider_override(
|
||||
cls,
|
||||
organization_id: str,
|
||||
provider_type: str,
|
||||
**kwargs
|
||||
) -> OrganizationProviderOverride:
|
||||
"""
|
||||
Create organization-specific provider override (for SSO scenarios).
|
||||
|
||||
Args:
|
||||
organization_id: Organization ID
|
||||
provider_type: Provider type to override
|
||||
**kwargs: Override fields (client_id, client_secret, redirect_url, etc.)
|
||||
|
||||
Returns:
|
||||
OrganizationProviderOverride: Created override
|
||||
|
||||
Raises:
|
||||
ExternalAuthError: If provider doesn't exist or override already exists
|
||||
"""
|
||||
# Verify app-level provider exists
|
||||
app_config = ApplicationProviderConfig.query.filter_by(
|
||||
provider_type=provider_type
|
||||
).first()
|
||||
|
||||
if not app_config:
|
||||
raise ExternalAuthError(
|
||||
f"Application provider {provider_type} must be configured first",
|
||||
"PROVIDER_NOT_CONFIGURED",
|
||||
400
|
||||
)
|
||||
|
||||
# Check if override already exists
|
||||
existing = OrganizationProviderOverride.query.filter_by(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
raise ExternalAuthError(
|
||||
f"Override for {provider_type} already exists for this organization",
|
||||
"OVERRIDE_EXISTS",
|
||||
400
|
||||
)
|
||||
|
||||
# Build additional_config from kwargs
|
||||
additional_config = {}
|
||||
if 'settings' in kwargs:
|
||||
additional_config.update(kwargs.pop('settings'))
|
||||
if 'scopes' in kwargs:
|
||||
additional_config['scopes'] = kwargs.pop('scopes')
|
||||
|
||||
# Create override
|
||||
override = OrganizationProviderOverride(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type,
|
||||
client_id=kwargs.get('client_id'),
|
||||
is_enabled=kwargs.get('is_enabled', True),
|
||||
redirect_url_override=kwargs.get('redirect_url_override'),
|
||||
additional_config=additional_config if additional_config else None
|
||||
)
|
||||
|
||||
# Set encrypted secret if provided
|
||||
if 'client_secret' in kwargs:
|
||||
override.set_client_secret(kwargs['client_secret'])
|
||||
|
||||
override.save()
|
||||
logger.info(f"Created org override for {provider_type} in org {organization_id}")
|
||||
return override
|
||||
|
||||
@classmethod
|
||||
def update_org_provider_override(
|
||||
cls,
|
||||
organization_id: str,
|
||||
provider_type: str,
|
||||
**updates
|
||||
) -> OrganizationProviderOverride:
|
||||
"""
|
||||
Update organization-specific provider override.
|
||||
|
||||
Args:
|
||||
organization_id: Organization ID
|
||||
provider_type: Provider type
|
||||
**updates: Fields to update
|
||||
|
||||
Returns:
|
||||
OrganizationProviderOverride: Updated override
|
||||
|
||||
Raises:
|
||||
ExternalAuthError: If override not found
|
||||
"""
|
||||
override = OrganizationProviderOverride.query.filter_by(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type
|
||||
).first()
|
||||
|
||||
if not override:
|
||||
raise ExternalAuthError(
|
||||
f"Override for {provider_type} not found for this organization",
|
||||
"OVERRIDE_NOT_FOUND",
|
||||
404
|
||||
)
|
||||
|
||||
# Update simple fields
|
||||
if 'client_id' in updates:
|
||||
override.client_id = updates['client_id']
|
||||
|
||||
if 'client_secret' in updates:
|
||||
override.set_client_secret(updates['client_secret'])
|
||||
|
||||
if 'is_enabled' in updates:
|
||||
override.is_enabled = updates['is_enabled']
|
||||
|
||||
if 'redirect_url_override' in updates:
|
||||
override.redirect_url_override = updates['redirect_url_override']
|
||||
|
||||
# Update additional_config
|
||||
if 'settings' in updates or 'scopes' in updates:
|
||||
if override.additional_config is None:
|
||||
override.additional_config = {}
|
||||
|
||||
if 'settings' in updates:
|
||||
override.additional_config.update(updates['settings'])
|
||||
if 'scopes' in updates:
|
||||
override.additional_config['scopes'] = updates['scopes']
|
||||
|
||||
override.save()
|
||||
logger.info(f"Updated org override for {provider_type} in org {organization_id}")
|
||||
return override
|
||||
|
||||
@classmethod
|
||||
def get_org_provider_override(
|
||||
cls,
|
||||
organization_id: str,
|
||||
provider_type: str
|
||||
) -> OrganizationProviderOverride:
|
||||
"""
|
||||
Get organization-specific provider override.
|
||||
|
||||
Args:
|
||||
organization_id: Organization ID
|
||||
provider_type: Provider type
|
||||
|
||||
Returns:
|
||||
OrganizationProviderOverride: Provider override
|
||||
|
||||
Raises:
|
||||
ExternalAuthError: If override not found
|
||||
"""
|
||||
override = OrganizationProviderOverride.query.filter_by(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type
|
||||
).first()
|
||||
|
||||
if not override:
|
||||
raise ExternalAuthError(
|
||||
f"Override for {provider_type} not found for this organization",
|
||||
"OVERRIDE_NOT_FOUND",
|
||||
404
|
||||
)
|
||||
|
||||
return override
|
||||
|
||||
@classmethod
|
||||
def list_org_provider_overrides(cls, organization_id: str) -> list:
|
||||
"""
|
||||
List all provider overrides for an organization.
|
||||
|
||||
Args:
|
||||
organization_id: Organization ID
|
||||
|
||||
Returns:
|
||||
list: List of override configuration dictionaries
|
||||
"""
|
||||
overrides = OrganizationProviderOverride.query.filter_by(
|
||||
organization_id=organization_id
|
||||
).all()
|
||||
return [override.to_dict() for override in overrides]
|
||||
|
||||
@classmethod
|
||||
def delete_org_provider_override(
|
||||
cls,
|
||||
organization_id: str,
|
||||
provider_type: str
|
||||
) -> bool:
|
||||
"""
|
||||
Delete organization-specific provider override.
|
||||
|
||||
Args:
|
||||
organization_id: Organization ID
|
||||
provider_type: Provider type
|
||||
|
||||
Returns:
|
||||
bool: True if deleted successfully
|
||||
|
||||
Raises:
|
||||
ExternalAuthError: If override not found
|
||||
"""
|
||||
override = OrganizationProviderOverride.query.filter_by(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type
|
||||
).first()
|
||||
|
||||
if not override:
|
||||
raise ExternalAuthError(
|
||||
f"Override for {provider_type} not found for this organization",
|
||||
"OVERRIDE_NOT_FOUND",
|
||||
404
|
||||
)
|
||||
|
||||
override.delete()
|
||||
logger.info(f"Deleted org override for {provider_type} in org {organization_id}")
|
||||
return True
|
||||
|
||||
# ==================== OAuth Flow Methods (Updated for New Architecture) ====================
|
||||
|
||||
@classmethod
|
||||
def initiate_link_flow(
|
||||
@@ -240,8 +725,8 @@ class ExternalAuthService:
|
||||
"""
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
# Get provider config
|
||||
config = cls.get_provider_config(organization_id, provider_type)
|
||||
# Get provider config (with org override if applicable)
|
||||
config = cls.get_provider_config(provider_type, organization_id)
|
||||
|
||||
# Validate redirect URI
|
||||
if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri):
|
||||
@@ -261,13 +746,13 @@ class ExternalAuthService:
|
||||
provider_type=provider_type,
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
redirect_uri=redirect_uri or config.redirect_uris[0],
|
||||
redirect_uri=redirect_uri or config.redirect_uris[0] if config.redirect_uris else None,
|
||||
code_verifier=code_verifier,
|
||||
code_challenge=code_challenge,
|
||||
lifetime_seconds=600,
|
||||
)
|
||||
|
||||
# Build authorization URL (simplified - in production would use provider-specific implementation)
|
||||
# Build authorization URL
|
||||
auth_url = cls._build_authorization_url(
|
||||
config=config,
|
||||
state=state,
|
||||
@@ -338,12 +823,12 @@ class ExternalAuthService:
|
||||
400,
|
||||
)
|
||||
|
||||
# Get provider config
|
||||
# Get provider config (with org override if applicable)
|
||||
config = cls.get_provider_config(
|
||||
state_record.organization_id, provider_type
|
||||
provider_type, state_record.organization_id
|
||||
)
|
||||
|
||||
# Exchange code for tokens (simplified - in production would use provider-specific implementation)
|
||||
# Exchange code for tokens
|
||||
tokens = cls._exchange_code(
|
||||
config=config,
|
||||
code=authorization_code,
|
||||
@@ -440,8 +925,8 @@ class ExternalAuthService:
|
||||
400,
|
||||
)
|
||||
|
||||
# Get provider config
|
||||
config = cls.get_provider_config(organization_id, provider_type)
|
||||
# Get provider config (with org override if applicable)
|
||||
config = cls.get_provider_config(provider_type, organization_id)
|
||||
|
||||
# Exchange code for tokens
|
||||
tokens = cls._exchange_code(
|
||||
@@ -606,6 +1091,8 @@ class ExternalAuthService:
|
||||
if m.method_type in external_providers or str(m.method_type) in [p.value for p in external_providers]
|
||||
]
|
||||
|
||||
# ==================== Helper Methods ====================
|
||||
|
||||
@staticmethod
|
||||
def _compute_s256_challenge(verifier: str) -> str:
|
||||
"""Compute S256 code challenge from verifier."""
|
||||
@@ -616,8 +1103,8 @@ class ExternalAuthService:
|
||||
return base64.urlsafe_b64encode(digest).decode().rstrip("=")
|
||||
|
||||
@staticmethod
|
||||
def _build_authorization_url(config: ExternalProviderConfig, state: OAuthState) -> str:
|
||||
"""Build authorization URL (simplified - provider-specific in production)."""
|
||||
def _build_authorization_url(config: ProviderConfigAdapter, state: OAuthState) -> str:
|
||||
"""Build authorization URL using the provider config adapter."""
|
||||
from urllib.parse import urlencode
|
||||
|
||||
params = {
|
||||
@@ -637,11 +1124,22 @@ class ExternalAuthService:
|
||||
params["code_challenge"] = state.code_challenge
|
||||
params["code_challenge_method"] = "S256"
|
||||
|
||||
return f"{config.auth_url}?{urlencode(params)}"
|
||||
full_url = f"{config.auth_url}?{urlencode(params)}"
|
||||
|
||||
# DIAGNOSTIC LOGGING: Show exact URL being built
|
||||
logger.info(
|
||||
f"[PKCE DEBUG] Building authorization URL:\n"
|
||||
f" provider_type: {config.provider_type}\n"
|
||||
f" state.code_challenge: {state.code_challenge[:20] if state.code_challenge else 'None'}...\n"
|
||||
f" params has code_challenge: {'code_challenge' in params}\n"
|
||||
f" Full URL: {full_url}"
|
||||
)
|
||||
|
||||
return full_url
|
||||
|
||||
@staticmethod
|
||||
def _exchange_code(config: ExternalProviderConfig, code: str, redirect_uri: str, code_verifier: str = None) -> dict:
|
||||
"""Exchange authorization code for tokens (simplified - provider-specific in production)."""
|
||||
def _exchange_code(config: ProviderConfigAdapter, code: str, redirect_uri: str, code_verifier: str = None) -> dict:
|
||||
"""Exchange authorization code for tokens using the provider config adapter."""
|
||||
import requests
|
||||
|
||||
data = {
|
||||
@@ -655,14 +1153,29 @@ class ExternalAuthService:
|
||||
if code_verifier:
|
||||
data["code_verifier"] = code_verifier
|
||||
|
||||
# Log token exchange request (without secrets)
|
||||
logger.debug(
|
||||
f"Token exchange request: url={config.token_url}, "
|
||||
f"client_id={config.client_id}, redirect_uri={redirect_uri}, "
|
||||
f"has_code_verifier={bool(code_verifier)}"
|
||||
)
|
||||
|
||||
response = requests.post(config.token_url, data=data)
|
||||
|
||||
# Log response details for debugging
|
||||
if response.status_code != 200:
|
||||
logger.error(
|
||||
f"Token exchange failed: status={response.status_code}, "
|
||||
f"response={response.text}"
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def _get_user_info(config: ExternalProviderConfig, access_token: str) -> dict:
|
||||
"""Get user info from provider (simplified - provider-specific in production)."""
|
||||
def _get_user_info(config: ProviderConfigAdapter, access_token: str) -> dict:
|
||||
"""Get user info from provider using the provider config adapter."""
|
||||
import requests
|
||||
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
@@ -758,4 +1271,4 @@ class ExternalAuthService:
|
||||
else:
|
||||
result["id_token"] = None
|
||||
|
||||
return result
|
||||
return result
|
||||
|
||||
@@ -1,20 +1,22 @@
|
||||
"""OAuth flow service for handling external authentication flows."""
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from flask import current_app, request, g
|
||||
from flask import current_app, request, g, redirect
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models import User, AuthenticationMethod
|
||||
from gatehouse_app.models.authentication_method import OAuthState
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.models.oidc_authorization_code import OIDCAuthCode
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
from gatehouse_app.services.external_auth_service import (
|
||||
ExternalAuthService,
|
||||
ExternalAuthError,
|
||||
OAuthState,
|
||||
ExternalProviderConfig,
|
||||
)
|
||||
|
||||
@@ -43,11 +45,14 @@ class OAuthFlowService:
|
||||
state_data: dict = None,
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Initiate OAuth login flow.
|
||||
Initiate OAuth login flow without requiring organization_id upfront.
|
||||
|
||||
This method initiates the OAuth flow using application-wide provider configuration.
|
||||
The organization context is determined after successful authentication.
|
||||
|
||||
Args:
|
||||
provider_type: The authentication provider type
|
||||
organization_id: Optional organization context for SSO
|
||||
organization_id: Optional organization hint for SSO discovery
|
||||
redirect_uri: Optional custom redirect URI
|
||||
state_data: Additional state data to include
|
||||
|
||||
@@ -65,8 +70,8 @@ class OAuthFlowService:
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
try:
|
||||
# Get provider config
|
||||
config = ExternalAuthService.get_provider_config(organization_id, provider_type)
|
||||
# Get provider config (application-wide, no organization required)
|
||||
config = ExternalAuthService.get_provider_config(provider_type, organization_id)
|
||||
|
||||
# Validate redirect URI
|
||||
if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri):
|
||||
@@ -76,9 +81,19 @@ class OAuthFlowService:
|
||||
400,
|
||||
)
|
||||
|
||||
# Generate PKCE
|
||||
code_verifier = secrets.token_urlsafe(32)
|
||||
code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier)
|
||||
# Generate PKCE parameters (Google web applications don't use PKCE)
|
||||
code_verifier = None
|
||||
code_challenge = None
|
||||
if provider_type_str not in ['google']:
|
||||
code_verifier = secrets.token_urlsafe(32)
|
||||
code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier)
|
||||
|
||||
# DIAGNOSTIC LOGGING: Show PKCE decision
|
||||
logger.info(
|
||||
f"[PKCE DEBUG] Provider type check: provider_type_str='{provider_type_str}', "
|
||||
f"is_google={provider_type_str in ['google']}, "
|
||||
f"will_skip_pkce={provider_type_str in ['google']}"
|
||||
)
|
||||
|
||||
# Create OAuth state for login flow
|
||||
state = OAuthState.create_state(
|
||||
@@ -92,6 +107,15 @@ class OAuthFlowService:
|
||||
lifetime_seconds=600,
|
||||
)
|
||||
|
||||
# DIAGNOSTIC LOGGING: Verify state object
|
||||
logger.info(
|
||||
f"[PKCE DEBUG] Created OAuthState object:\n"
|
||||
f" state.id: {state.id}\n"
|
||||
f" state.provider_type: {state.provider_type}\n"
|
||||
f" state.code_challenge: {state.code_challenge}\n"
|
||||
f" state.code_verifier: {state.code_verifier[:20] if state.code_verifier else None}..."
|
||||
)
|
||||
|
||||
# Build authorization URL
|
||||
auth_url = ExternalAuthService._build_authorization_url(
|
||||
config=config,
|
||||
@@ -100,7 +124,13 @@ class OAuthFlowService:
|
||||
|
||||
logger.info(
|
||||
f"OAuth login flow initiated for provider={provider_type_str}, "
|
||||
f"org_id={organization_id}, state_id={state.id}"
|
||||
f"org_id={organization_id}, state_token={state.state}, state_record_id={state.id}"
|
||||
)
|
||||
logger.info(
|
||||
f"[PKCE DEBUG] FINAL CHECK: code_challenge={code_challenge}, "
|
||||
f"code_verifier={code_verifier[:20] if code_verifier else None}..., "
|
||||
f"auth_url_has_challenge={'code_challenge=' in auth_url}, "
|
||||
f"returned_auth_url={auth_url}"
|
||||
)
|
||||
|
||||
return auth_url, state.state
|
||||
@@ -129,11 +159,11 @@ class OAuthFlowService:
|
||||
redirect_uri: str = None,
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Initiate OAuth registration flow.
|
||||
Initiate OAuth registration flow without requiring organization_id upfront.
|
||||
|
||||
Args:
|
||||
provider_type: The authentication provider type
|
||||
organization_id: Optional organization context
|
||||
organization_id: Optional organization hint
|
||||
redirect_uri: Optional custom redirect URI
|
||||
|
||||
Returns:
|
||||
@@ -142,8 +172,8 @@ class OAuthFlowService:
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
try:
|
||||
# Get provider config
|
||||
config = ExternalAuthService.get_provider_config(organization_id, provider_type)
|
||||
# Get provider config (application-wide, no organization required)
|
||||
config = ExternalAuthService.get_provider_config(provider_type, organization_id)
|
||||
|
||||
# Validate redirect URI
|
||||
if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri):
|
||||
@@ -153,9 +183,19 @@ class OAuthFlowService:
|
||||
400,
|
||||
)
|
||||
|
||||
# Generate PKCE
|
||||
code_verifier = secrets.token_urlsafe(32)
|
||||
code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier)
|
||||
# Generate PKCE parameters (Google web applications don't use PKCE)
|
||||
code_verifier = None
|
||||
code_challenge = None
|
||||
if provider_type_str not in ['google']:
|
||||
code_verifier = secrets.token_urlsafe(32)
|
||||
code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier)
|
||||
|
||||
# DIAGNOSTIC LOGGING: Show PKCE decision for register flow
|
||||
logger.info(
|
||||
f"[PKCE DEBUG] Register flow - Provider type check: provider_type_str='{provider_type_str}', "
|
||||
f"is_google={provider_type_str in ['google']}, "
|
||||
f"will_skip_pkce={provider_type_str in ['google']}"
|
||||
)
|
||||
|
||||
# Create OAuth state for register flow
|
||||
state = OAuthState.create_state(
|
||||
@@ -168,6 +208,14 @@ class OAuthFlowService:
|
||||
lifetime_seconds=600,
|
||||
)
|
||||
|
||||
# DIAGNOSTIC LOGGING: Verify state object for register flow
|
||||
logger.info(
|
||||
f"[PKCE DEBUG] Register flow - Created OAuthState:\n"
|
||||
f" state.id: {state.id}\n"
|
||||
f" state.code_challenge: {state.code_challenge}\n"
|
||||
f" state.code_verifier: {state.code_verifier[:20] if state.code_verifier else None}..."
|
||||
)
|
||||
|
||||
# Build authorization URL
|
||||
auth_url = ExternalAuthService._build_authorization_url(
|
||||
config=config,
|
||||
@@ -178,6 +226,9 @@ class OAuthFlowService:
|
||||
f"OAuth register flow initiated for provider={provider_type_str}, "
|
||||
f"org_id={organization_id}, state_id={state.id}"
|
||||
)
|
||||
logger.info(
|
||||
f"[PKCE DEBUG] Register flow - FINAL: auth_url_has_challenge={'code_challenge=' in auth_url}"
|
||||
)
|
||||
|
||||
return auth_url, state.state
|
||||
|
||||
@@ -245,6 +296,17 @@ class OAuthFlowService:
|
||||
|
||||
# Validate state
|
||||
state_record = OAuthState.query.filter_by(state=state).first()
|
||||
|
||||
# Log validation details for debugging
|
||||
if state_record:
|
||||
logger.debug(
|
||||
f"State validation: found=True, used={state_record.used}, "
|
||||
f"expires_at={state_record.expires_at}, now={datetime.now(timezone.utc)}, "
|
||||
f"is_valid={state_record.is_valid()}"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"State validation: state token not found in database: {state}")
|
||||
|
||||
if not state_record or not state_record.is_valid():
|
||||
AuditService.log_external_auth_login_failed(
|
||||
organization_id=state_record.organization_id if state_record else None,
|
||||
@@ -299,24 +361,175 @@ class OAuthFlowService:
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
) -> dict:
|
||||
"""Handle login flow callback."""
|
||||
"""
|
||||
Handle login flow callback with organization discovery.
|
||||
|
||||
This method:
|
||||
1. Exchanges the authorization code for tokens
|
||||
2. Gets user info from the OAuth provider
|
||||
3. Looks up the user by provider_user_id
|
||||
4. Determines which organization(s) the user belongs to
|
||||
5. Creates a session or returns org selection needed
|
||||
"""
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
try:
|
||||
# Authenticate with provider
|
||||
user, session_data = ExternalAuthService.authenticate_with_provider(
|
||||
provider_type=provider_type,
|
||||
organization_id=state_record.organization_id,
|
||||
authorization_code=authorization_code,
|
||||
state=state_record.state,
|
||||
# Get provider config (application-wide)
|
||||
config = ExternalAuthService.get_provider_config(
|
||||
provider_type, state_record.organization_id
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Exchanging code with PKCE: state_record.code_verifier={state_record.code_verifier[:20] if state_record.code_verifier else None}..."
|
||||
)
|
||||
|
||||
# Exchange code for tokens
|
||||
tokens = ExternalAuthService._exchange_code(
|
||||
config=config,
|
||||
code=authorization_code,
|
||||
redirect_uri=redirect_uri,
|
||||
code_verifier=state_record.code_verifier,
|
||||
)
|
||||
|
||||
# Get user info from provider
|
||||
user_info = ExternalAuthService._get_user_info(
|
||||
config=config,
|
||||
access_token=tokens["access_token"],
|
||||
)
|
||||
|
||||
# Look up user by provider_user_id
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
method_type=provider_type,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
).first()
|
||||
|
||||
if not auth_method:
|
||||
# User doesn't exist - check if email matches existing user
|
||||
existing_user = User.query.filter_by(
|
||||
email=user_info["email"]
|
||||
).first()
|
||||
|
||||
if existing_user:
|
||||
AuditService.log_external_auth_login_failed(
|
||||
organization_id=state_record.organization_id,
|
||||
provider_type=provider_type_str,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
email=user_info["email"],
|
||||
failure_reason="email_exists",
|
||||
error_message=f"An account with email {user_info['email']} already exists",
|
||||
)
|
||||
raise OAuthFlowError(
|
||||
f"An account with email {user_info['email']} already exists. "
|
||||
"Please log in with your password and link your account from settings.",
|
||||
"EMAIL_EXISTS",
|
||||
400,
|
||||
)
|
||||
|
||||
AuditService.log_external_auth_login_failed(
|
||||
organization_id=state_record.organization_id,
|
||||
provider_type=provider_type_str,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
email=user_info["email"],
|
||||
failure_reason="account_not_found",
|
||||
error_message="No Gatehouse account matches this external account",
|
||||
)
|
||||
raise OAuthFlowError(
|
||||
"No Gatehouse account matches this external account. Please register first.",
|
||||
"ACCOUNT_NOT_FOUND",
|
||||
404,
|
||||
)
|
||||
|
||||
user = auth_method.user
|
||||
|
||||
# Update provider data
|
||||
auth_method.provider_data = ExternalAuthService._encrypt_provider_data(
|
||||
tokens, user_info
|
||||
)
|
||||
auth_method.last_used_at = datetime.utcnow()
|
||||
auth_method.save()
|
||||
|
||||
# Get user's organizations
|
||||
user_orgs = user.get_organizations()
|
||||
|
||||
# Determine target organization
|
||||
target_org = None
|
||||
|
||||
# Priority 1: Use organization_id from state if provided (org hint)
|
||||
if state_record.organization_id:
|
||||
target_org = next(
|
||||
(org for org in user_orgs if org.id == state_record.organization_id),
|
||||
None
|
||||
)
|
||||
|
||||
# Priority 2: If user has exactly one organization, use it
|
||||
if not target_org and len(user_orgs) == 1:
|
||||
target_org = user_orgs[0]
|
||||
|
||||
# Priority 3: No organization or multiple organizations - need selection
|
||||
if not target_org:
|
||||
# Mark state as used
|
||||
state_record.mark_used()
|
||||
|
||||
logger.info(
|
||||
f"OAuth login requires org selection for user={user.id}, "
|
||||
f"provider={provider_type_str}, org_count={len(user_orgs)}"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"flow_type": "login",
|
||||
"requires_org_selection": True,
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
},
|
||||
"available_organizations": [
|
||||
{
|
||||
"id": org.id,
|
||||
"name": org.name,
|
||||
"slug": org.slug if hasattr(org, 'slug') else None,
|
||||
}
|
||||
for org in user_orgs
|
||||
],
|
||||
"state": state_record.state,
|
||||
}
|
||||
|
||||
# Create session for the target org
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
session = AuthService.create_session(
|
||||
user=user,
|
||||
is_compliance_only=False,
|
||||
)
|
||||
|
||||
# Mark state as used
|
||||
state_record.mark_used()
|
||||
|
||||
# Audit log - login success
|
||||
AuditService.log_external_auth_login(
|
||||
user_id=user.id,
|
||||
organization_id=target_org.id,
|
||||
provider_type=provider_type_str,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
auth_method_id=auth_method.id,
|
||||
session_id=session.id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth login successful for user={user.id}, "
|
||||
f"provider={provider_type_str}, org_id={state_record.organization_id}"
|
||||
f"provider={provider_type_str}, org_id={target_org.id}"
|
||||
)
|
||||
|
||||
# Build session dict with token (to_dict() excludes token for security)
|
||||
session_dict = session.to_dict()
|
||||
session_dict["token"] = session.token
|
||||
# Calculate expires_in handling naive datetime from database
|
||||
expires_at = session.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
now = datetime.now(timezone.utc)
|
||||
session_dict["expires_in"] = int((expires_at - now).total_seconds())
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"flow_type": "login",
|
||||
@@ -324,9 +537,9 @@ class OAuthFlowService:
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"organization_id": state_record.organization_id,
|
||||
"organization_id": target_org.id,
|
||||
},
|
||||
"session": session_data,
|
||||
"session": session_dict,
|
||||
}
|
||||
|
||||
except ExternalAuthError as e:
|
||||
@@ -335,6 +548,19 @@ class OAuthFlowService:
|
||||
f"provider={provider_type_str}, error={e.message}"
|
||||
)
|
||||
raise
|
||||
except OAuthFlowError:
|
||||
# Re-raise OAuthFlowError as-is
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unexpected error in OAuth login callback: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise OAuthFlowError(
|
||||
"An unexpected error occurred during login",
|
||||
"INTERNAL_ERROR",
|
||||
500,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _handle_link_callback(
|
||||
@@ -387,13 +613,17 @@ class OAuthFlowService:
|
||||
authorization_code: str,
|
||||
redirect_uri: str,
|
||||
) -> dict:
|
||||
"""Handle registration flow callback."""
|
||||
"""
|
||||
Handle registration flow callback.
|
||||
|
||||
Creates a new user account and prompts for organization creation/selection.
|
||||
"""
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
try:
|
||||
# Get provider config
|
||||
# Get provider config (application-wide)
|
||||
config = ExternalAuthService.get_provider_config(
|
||||
state_record.organization_id, provider_type
|
||||
provider_type, state_record.organization_id
|
||||
)
|
||||
|
||||
# Exchange code for tokens
|
||||
@@ -429,6 +659,7 @@ class OAuthFlowService:
|
||||
email=user_info["email"],
|
||||
full_name=user_info.get("name", ""),
|
||||
status="active",
|
||||
email_verified=user_info.get("email_verified", False),
|
||||
)
|
||||
user.save()
|
||||
|
||||
@@ -440,6 +671,7 @@ class OAuthFlowService:
|
||||
provider_data=ExternalAuthService._encrypt_provider_data(tokens, user_info),
|
||||
verified=user_info.get("email_verified", False),
|
||||
is_primary=True,
|
||||
last_used_at=datetime.utcnow(),
|
||||
)
|
||||
auth_method.save()
|
||||
|
||||
@@ -475,23 +707,48 @@ class OAuthFlowService:
|
||||
f"provider={provider_type_str}, user_id={user.id}"
|
||||
)
|
||||
|
||||
# Create session
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
session = AuthService.create_session(
|
||||
user=user,
|
||||
organization_id=state_record.organization_id,
|
||||
)
|
||||
# If organization_id hint was provided and valid, create session for that org
|
||||
if state_record.organization_id:
|
||||
from gatehouse_app.models.organization import Organization
|
||||
org = Organization.query.get(state_record.organization_id)
|
||||
if org:
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
session = AuthService.create_session(
|
||||
user=user,
|
||||
is_compliance_only=False,
|
||||
)
|
||||
# Build session dict with token (to_dict() excludes token for security)
|
||||
session_dict = session.to_dict()
|
||||
session_dict["token"] = session.token
|
||||
# Calculate expires_in handling naive datetime from database
|
||||
expires_at = session.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
now = datetime.now(timezone.utc)
|
||||
session_dict["expires_in"] = int((expires_at - now).total_seconds())
|
||||
return {
|
||||
"success": True,
|
||||
"flow_type": "register",
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"organization_id": org.id,
|
||||
},
|
||||
"session": session_dict,
|
||||
}
|
||||
|
||||
# No organization hint or invalid - need to create/select org
|
||||
return {
|
||||
"success": True,
|
||||
"flow_type": "register",
|
||||
"requires_org_creation": True,
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"organization_id": state_record.organization_id,
|
||||
},
|
||||
"session": session.to_dict(),
|
||||
"state": state_record.state,
|
||||
}
|
||||
|
||||
except ExternalAuthError as e:
|
||||
@@ -500,6 +757,19 @@ class OAuthFlowService:
|
||||
f"provider={provider_type_str}, error={e.message}"
|
||||
)
|
||||
raise
|
||||
except OAuthFlowError:
|
||||
# Re-raise OAuthFlowError as-is
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unexpected error in OAuth registration callback: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise OAuthFlowError(
|
||||
"An unexpected error occurred during registration",
|
||||
"INTERNAL_ERROR",
|
||||
500,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_state(cls, state: str) -> Optional[OAuthState]:
|
||||
@@ -522,3 +792,232 @@ class OAuthFlowService:
|
||||
"""Remove expired OAuth states."""
|
||||
OAuthState.cleanup_expired()
|
||||
logger.info("Expired OAuth states cleaned up")
|
||||
|
||||
@classmethod
|
||||
def generate_authorization_code(
|
||||
cls,
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
scope: list = None,
|
||||
nonce: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
lifetime_seconds: int = 600,
|
||||
) -> str:
|
||||
"""
|
||||
Generate an authorization code for external OAuth applications.
|
||||
|
||||
This method creates a short-lived, single-use authorization code that can be
|
||||
exchanged for a session token by external applications like oauth2-proxy.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
client_id: The client ID (e.g., 'oauth2-proxy', 'bookstack')
|
||||
redirect_uri: The redirect URI
|
||||
scope: Requested scopes
|
||||
nonce: OIDC nonce for validation
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
lifetime_seconds: Code lifetime in seconds (default 10 minutes)
|
||||
|
||||
Returns:
|
||||
The authorization code (plain text, not hashed)
|
||||
"""
|
||||
# Generate a secure random code
|
||||
code = secrets.token_urlsafe(32)
|
||||
code_hash = hashlib.sha256(code.encode()).hexdigest()
|
||||
|
||||
# Create the authorization code record
|
||||
OIDCAuthCode.create_code(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
code_hash=code_hash,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
nonce=nonce,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
lifetime_seconds=lifetime_seconds,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Generated authorization code for user={user_id}, client={client_id}"
|
||||
)
|
||||
|
||||
return code
|
||||
|
||||
@classmethod
|
||||
def exchange_authorization_code(
|
||||
cls,
|
||||
code: str,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
ip_address: str = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Exchange an authorization code for a session token.
|
||||
|
||||
This method validates and consumes the authorization code, then creates
|
||||
a session for the user.
|
||||
|
||||
Args:
|
||||
code: The authorization code
|
||||
client_id: The client ID
|
||||
redirect_uri: The redirect URI (must match original request)
|
||||
ip_address: Client IP address
|
||||
|
||||
Returns:
|
||||
Dict with session token and user info
|
||||
"""
|
||||
# Hash the provided code for lookup
|
||||
code_hash = hashlib.sha256(code.encode()).hexdigest()
|
||||
|
||||
# Find the authorization code record
|
||||
auth_code = OIDCAuthCode.query.filter_by(
|
||||
client_id=client_id,
|
||||
code_hash=code_hash,
|
||||
).first()
|
||||
|
||||
if not auth_code:
|
||||
raise OAuthFlowError(
|
||||
"Invalid authorization code",
|
||||
"INVALID_CODE",
|
||||
400,
|
||||
)
|
||||
|
||||
# Validate the code
|
||||
if not auth_code.is_valid():
|
||||
if auth_code.is_used:
|
||||
raise OAuthFlowError(
|
||||
"Authorization code has already been used",
|
||||
"CODE_USED",
|
||||
400,
|
||||
)
|
||||
else:
|
||||
raise OAuthFlowError(
|
||||
"Authorization code has expired",
|
||||
"CODE_EXPIRED",
|
||||
400,
|
||||
)
|
||||
|
||||
# Validate redirect URI
|
||||
if auth_code.redirect_uri != redirect_uri:
|
||||
raise OAuthFlowError(
|
||||
"Redirect URI mismatch",
|
||||
"INVALID_REDIRECT_URI",
|
||||
400,
|
||||
)
|
||||
|
||||
# Get the user
|
||||
from gatehouse_app.models import User
|
||||
user = User.query.get(auth_code.user_id)
|
||||
if not user:
|
||||
raise OAuthFlowError(
|
||||
"User not found",
|
||||
"USER_NOT_FOUND",
|
||||
404,
|
||||
)
|
||||
|
||||
# Determine organization
|
||||
from gatehouse_app.models.organization import Organization
|
||||
from gatehouse_app.models.organization_member import OrganizationMember
|
||||
|
||||
# Get user's organizations
|
||||
user_orgs = user.get_organizations()
|
||||
|
||||
# Determine target organization
|
||||
target_org = None
|
||||
|
||||
# Priority 1: Use organization_id from auth code if available
|
||||
# Priority 2: If user has exactly one organization, use it
|
||||
if not target_org and len(user_orgs) == 1:
|
||||
target_org = user_orgs[0]
|
||||
|
||||
if not target_org:
|
||||
raise OAuthFlowError(
|
||||
"User does not have a default organization. Organization selection required.",
|
||||
"ORG_SELECTION_REQUIRED",
|
||||
400,
|
||||
)
|
||||
|
||||
# Create session
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
session = AuthService.create_session(
|
||||
user=user,
|
||||
is_compliance_only=False,
|
||||
)
|
||||
|
||||
# Mark the code as used
|
||||
auth_code.mark_as_used()
|
||||
|
||||
# Build session dict
|
||||
session_dict = session.to_dict()
|
||||
session_dict["token"] = session.token
|
||||
expires_at = session.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
now = datetime.now(timezone.utc)
|
||||
session_dict["expires_in"] = int((expires_at - now).total_seconds())
|
||||
|
||||
logger.info(
|
||||
f"Authorization code exchanged for session: user={user.id}, "
|
||||
f"org_id={target_org.id}, client={client_id}"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"token": session_dict["token"],
|
||||
"expires_in": session_dict["expires_in"],
|
||||
"token_type": "Bearer",
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"organization_id": target_org.id,
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create_redirect_response(
|
||||
cls,
|
||||
redirect_uri: str,
|
||||
authorization_code: str,
|
||||
state: str = None,
|
||||
):
|
||||
"""
|
||||
Create a redirect response with authorization code.
|
||||
|
||||
Args:
|
||||
redirect_uri: The redirect URI
|
||||
authorization_code: The authorization code
|
||||
state: Optional state parameter
|
||||
|
||||
Returns:
|
||||
Flask redirect response
|
||||
"""
|
||||
from urllib.parse import urlencode, urlparse, urlunparse
|
||||
|
||||
# Parse the redirect URI
|
||||
parsed = urlparse(redirect_uri)
|
||||
|
||||
# Build query parameters
|
||||
params = {"code": authorization_code}
|
||||
if state:
|
||||
params["state"] = state
|
||||
|
||||
# Reconstruct URL with query parameters
|
||||
redirect_url = urlunparse((
|
||||
parsed.scheme,
|
||||
parsed.netloc,
|
||||
parsed.path,
|
||||
parsed.params,
|
||||
urlencode(params),
|
||||
parsed.fragment,
|
||||
))
|
||||
|
||||
logger.info(
|
||||
f"Redirecting to {parsed.scheme}://{parsed.netloc} with authorization code"
|
||||
)
|
||||
|
||||
return redirect(redirect_url)
|
||||
|
||||
Reference in New Issue
Block a user