can link google accounts!

This commit is contained in:
2026-01-20 15:54:00 +10:30
parent 900722d695
commit 4cf4a27c9a
17 changed files with 5325 additions and 4 deletions
+1 -1
View File
@@ -5,4 +5,4 @@ from flask import Blueprint
api_v1_bp = Blueprint("api_v1", __name__)
# Import route modules to register them
from gatehouse_app.api.v1 import auth, users, organizations, policies
from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth
+706
View File
@@ -0,0 +1,706 @@
"""External authentication provider endpoints."""
from flask import request, g
from marshmallow import ValidationError
from gatehouse_app.api.v1 import api_v1_bp
from gatehouse_app.utils.response import api_response
from gatehouse_app.utils.decorators import login_required
from gatehouse_app.utils.constants import AuthMethodType
from gatehouse_app.services.external_auth_service import (
ExternalAuthService,
ExternalAuthError,
)
from gatehouse_app.services.oauth_flow_service import (
OAuthFlowService,
OAuthFlowError,
)
from gatehouse_app.services.audit_service import AuditService
# Provider type mapping
PROVIDER_TYPE_MAP = {
"google": AuthMethodType.GOOGLE,
"github": AuthMethodType.GITHUB,
"microsoft": AuthMethodType.MICROSOFT,
}
def get_provider_type(provider: str) -> AuthMethodType:
"""Get AuthMethodType from provider string."""
provider_lower = provider.lower()
if provider_lower not in PROVIDER_TYPE_MAP:
raise ExternalAuthError(
f"Unsupported provider: {provider}",
"UNSUPPORTED_PROVIDER",
400,
)
return PROVIDER_TYPE_MAP[provider_lower]
# =============================================================================
# Provider Configuration Endpoints (Admin)
# =============================================================================
@api_v1_bp.route("/auth/external/providers", methods=["GET"])
@login_required
def list_providers():
"""
List available external authentication providers for current organization.
Returns:
200: List of providers with their configuration status
401: Not authenticated
"""
from gatehouse_app.models import Organization
from gatehouse_app.services.external_auth_service import ExternalProviderConfig
# Get user's primary organization
user_orgs = g.current_user.get_organizations()
if not user_orgs:
return api_response(
success=False,
message="No organizations found for user",
status=400,
error_type="BAD_REQUEST",
)
organization_id = user_orgs[0].id
# Get all configured providers for organization
configs = ExternalProviderConfig.query.filter_by(
organization_id=organization_id,
).all()
configured_providers = {c.provider_type.lower(): c for c in configs}
# Provider definitions
providers = [
{
"id": "google",
"name": "Google",
"type": "google",
"is_configured": "google" in configured_providers,
"is_active": configured_providers.get("google", {}).is_active if "google" in configured_providers else False,
"settings": {
"requires_domain": False,
"supports_refresh_tokens": True,
},
},
{
"id": "github",
"name": "GitHub",
"type": "github",
"is_configured": "github" in configured_providers,
"is_active": configured_providers.get("github", {}).is_active if "github" in configured_providers else False,
"settings": {
"requires_domain": False,
"supports_refresh_tokens": True,
},
},
{
"id": "microsoft",
"name": "Microsoft",
"type": "microsoft",
"is_configured": "microsoft" in configured_providers,
"is_active": configured_providers.get("microsoft", {}).is_active if "microsoft" in configured_providers else False,
"settings": {
"requires_domain": False,
"supports_refresh_tokens": True,
},
},
]
return api_response(
data={"providers": providers},
message="Providers retrieved successfully",
)
@api_v1_bp.route("/auth/external/providers/<provider>/config", methods=["GET"])
@login_required
def get_provider_config(provider: str):
"""
Get provider configuration (admin only).
Args:
provider: Provider type (google, github, microsoft)
Returns:
200: Provider configuration
401: Not authenticated
403: Not authorized (not admin)
404: Provider not configured
"""
from gatehouse_app.models import OrganizationMember
from gatehouse_app.utils.constants import OrganizationRole
from gatehouse_app.services.external_auth_service import ExternalProviderConfig
provider_type = get_provider_type(provider)
# Get user's primary organization
user_orgs = g.current_user.get_organizations()
if not user_orgs:
return api_response(
success=False,
message="No organizations found for user",
status=400,
error_type="BAD_REQUEST",
)
organization_id = user_orgs[0].id
# Check if user is admin
member = OrganizationMember.query.filter_by(
user_id=g.current_user.id,
organization_id=organization_id,
).first()
if not member or member.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]:
return api_response(
success=False,
message="Admin access required",
status=403,
error_type="FORBIDDEN",
)
# Get provider config
config = ExternalProviderConfig.query.filter_by(
organization_id=organization_id,
provider_type=provider_type.value,
).first()
if not config:
return api_response(
success=False,
message=f"{provider.title()} OAuth is not configured",
status=404,
error_type="NOT_FOUND",
)
return api_response(
data=config.to_dict(include_secrets=False),
message="Provider configuration retrieved successfully",
)
@api_v1_bp.route("/auth/external/providers/<provider>/config", methods=["POST"])
@login_required
def create_or_update_provider_config(provider: str):
"""
Create or update provider configuration (admin only).
Args:
provider: Provider type (google, github, microsoft)
Request body:
client_id: OAuth client ID
client_secret: OAuth client secret
scopes: List of OAuth scopes
redirect_uris: List of allowed redirect URIs
settings: Provider-specific settings
is_active: Whether the provider is active
Returns:
200: Provider configuration updated
201: Provider configuration created
400: Validation error
401: Not authenticated
403: Not authorized (not admin)
"""
from gatehouse_app.models import OrganizationMember
from gatehouse_app.utils.constants import OrganizationRole
from gatehouse_app.services.external_auth_service import ExternalProviderConfig
provider_type = get_provider_type(provider)
# Get user's primary organization
user_orgs = g.current_user.get_organizations()
if not user_orgs:
return api_response(
success=False,
message="No organizations found for user",
status=400,
error_type="BAD_REQUEST",
)
organization_id = user_orgs[0].id
# Check if user is admin
member = OrganizationMember.query.filter_by(
user_id=g.current_user.id,
organization_id=organization_id,
).first()
if not member or member.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]:
return api_response(
success=False,
message="Admin access required",
status=403,
error_type="FORBIDDEN",
)
# Validate request data
data = request.json or {}
client_id = data.get("client_id")
client_secret = data.get("client_secret")
if not client_id:
return api_response(
success=False,
message="client_id is required",
status=400,
error_type="VALIDATION_ERROR",
)
# Get or create config
config = ExternalProviderConfig.query.filter_by(
organization_id=organization_id,
provider_type=provider_type.value,
).first()
is_new = config is None
if config:
# Update existing
config.client_id = client_id
if client_secret:
config.set_client_secret(client_secret)
config.scopes = data.get("scopes", ["openid", "profile", "email"])
config.redirect_uris = data.get("redirect_uris", [])
config.settings = data.get("settings", {})
config.is_active = data.get("is_active", True)
config.save()
# Audit log - config update
AuditService.log_external_auth_config_update(
user_id=g.current_user.id,
organization_id=organization_id,
provider_type=provider_type.value,
config_id=config.id,
changes={
"client_id": "updated",
"client_secret": "updated" if client_secret else None,
"scopes": data.get("scopes"),
"redirect_uris": data.get("redirect_uris"),
"is_active": config.is_active,
},
)
else:
# Create new - get provider endpoints
auth_url, token_url, userinfo_url = _get_provider_endpoints(provider_type)
config = ExternalProviderConfig(
organization_id=organization_id,
provider_type=provider_type.value,
client_id=client_id,
client_secret_encrypted=None,
auth_url=auth_url,
token_url=token_url,
userinfo_url=userinfo_url,
scopes=data.get("scopes", ["openid", "profile", "email"]),
redirect_uris=data.get("redirect_uris", []),
settings=data.get("settings", {}),
is_active=data.get("is_active", True),
)
if client_secret:
config.set_client_secret(client_secret)
config.save()
# Audit log - config create
AuditService.log_external_auth_config_create(
user_id=g.current_user.id,
organization_id=organization_id,
provider_type=provider_type.value,
config_id=config.id,
)
return api_response(
data=config.to_dict(include_secrets=False),
message="Provider configuration saved successfully",
status=201 if is_new else 200,
)
@api_v1_bp.route("/auth/external/providers/<provider>/config", methods=["DELETE"])
@login_required
def delete_provider_config(provider: str):
"""
Delete provider configuration (admin only).
Args:
provider: Provider type (google, github, microsoft)
Returns:
200: Provider configuration deleted
401: Not authenticated
403: Not authorized (not admin)
404: Provider not configured
"""
from gatehouse_app.models import OrganizationMember
from gatehouse_app.utils.constants import OrganizationRole
from gatehouse_app.services.external_auth_service import ExternalProviderConfig
provider_type = get_provider_type(provider)
# Get user's primary organization
user_orgs = g.current_user.get_organizations()
if not user_orgs:
return api_response(
success=False,
message="No organizations found for user",
status=400,
error_type="BAD_REQUEST",
)
organization_id = user_orgs[0].id
# Check if user is admin
member = OrganizationMember.query.filter_by(
user_id=g.current_user.id,
organization_id=organization_id,
).first()
if not member or member.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]:
return api_response(
success=False,
message="Admin access required",
status=403,
error_type="FORBIDDEN",
)
# Get and delete config
config = ExternalProviderConfig.query.filter_by(
organization_id=organization_id,
provider_type=provider_type.value,
).first()
if not config:
return api_response(
success=False,
message=f"{provider.title()} OAuth is not configured",
status=404,
error_type="NOT_FOUND",
)
config_id = config.id
config.delete()
# Audit log - config delete
AuditService.log_external_auth_config_delete(
user_id=g.current_user.id,
organization_id=organization_id,
provider_type=provider_type.value,
config_id=config_id,
)
return api_response(
message=f"{provider.title()} provider configuration deleted successfully",
)
# =============================================================================
# Account Linking Endpoints
# =============================================================================
@api_v1_bp.route("/auth/external/linked-accounts", methods=["GET"])
@login_required
def list_linked_accounts():
"""
List all linked external accounts for the current user.
Returns:
200: List of linked accounts
401: Not authenticated
"""
linked_accounts = ExternalAuthService.get_linked_accounts(g.current_user.id)
# Check if user has other auth methods (for unlink availability)
from gatehouse_app.models import AuthenticationMethod
other_methods = AuthenticationMethod.query.filter_by(
user_id=g.current_user.id,
).count()
return api_response(
data={
"linked_accounts": linked_accounts,
"unlink_available": other_methods > 1,
},
message="Linked accounts retrieved successfully",
)
@api_v1_bp.route("/auth/external/<provider>/link", methods=["POST"])
@login_required
def initiate_link_account(provider: str):
"""
Initiate OAuth flow to link an external account.
Args:
provider: Provider type (google, github, microsoft)
Request body:
redirect_uri: Optional redirect URI after linking
Returns:
302: Redirect to provider authorization page
400: Validation error or provider not configured
401: Not authenticated
"""
provider_type = get_provider_type(provider)
# Get user's organization
user_orgs = g.current_user.get_organizations()
organization_id = user_orgs[0].id if user_orgs else None
# Get optional redirect URI
data = request.json or {}
redirect_uri = data.get("redirect_uri")
try:
# Initiate link flow
auth_url, state = ExternalAuthService.initiate_link_flow(
user_id=g.current_user.id,
provider_type=provider_type,
organization_id=organization_id,
redirect_uri=redirect_uri,
)
return api_response(
data={
"authorization_url": auth_url,
"state": state,
},
message="Link flow initiated. Redirect to authorization URL.",
)
except ExternalAuthError as e:
return api_response(
success=False,
message=e.message,
status=e.status_code,
error_type=e.error_type,
)
@api_v1_bp.route("/auth/external/<provider>/unlink", methods=["DELETE"])
@login_required
def unlink_account(provider: str):
"""
Unlink an external account from the user's profile.
Args:
provider: Provider type (google, github, microsoft)
Returns:
200: Account unlinked successfully
400: Validation error or cannot unlink last method
401: Not authenticated
404: Provider not linked
"""
provider_type = get_provider_type(provider)
# Get user's organization
user_orgs = g.current_user.get_organizations()
organization_id = user_orgs[0].id if user_orgs else None
try:
ExternalAuthService.unlink_provider(
user_id=g.current_user.id,
provider_type=provider_type,
organization_id=organization_id,
)
return api_response(
message=f"{provider.title()} account unlinked successfully",
)
except ExternalAuthError as e:
return api_response(
success=False,
message=e.message,
status=e.status_code,
error_type=e.error_type,
)
# =============================================================================
# OAuth Flow Endpoints
# =============================================================================
@api_v1_bp.route("/auth/external/<provider>/authorize", methods=["GET"])
def initiate_oauth_authorize(provider: str):
"""
Initiate OAuth authentication or account registration flow.
Args:
provider: Provider type (google, github, microsoft)
Query parameters:
flow: 'login' or 'register'
redirect_uri: Optional redirect URI
organization_id: Optional organization context
Returns:
302: Redirect to provider authorization page
400: Validation error or provider not configured
"""
provider_type = get_provider_type(provider)
# Get query parameters
flow = request.args.get("flow", "login")
redirect_uri = request.args.get("redirect_uri")
organization_id = request.args.get("organization_id")
if flow not in ["login", "register"]:
return api_response(
success=False,
message="Invalid flow type. Must be 'login' or 'register'",
status=400,
error_type="VALIDATION_ERROR",
)
try:
if flow == "login":
auth_url, state = OAuthFlowService.initiate_login_flow(
provider_type=provider_type,
organization_id=organization_id,
redirect_uri=redirect_uri,
)
else:
auth_url, state = OAuthFlowService.initiate_register_flow(
provider_type=provider_type,
organization_id=organization_id,
redirect_uri=redirect_uri,
)
return api_response(
data={
"authorization_url": auth_url,
"state": state,
},
message=f"OAuth {flow} flow initiated",
)
except OAuthFlowError as e:
return api_response(
success=False,
message=e.message,
status=e.status_code,
error_type=e.error_type,
)
@api_v1_bp.route("/auth/external/<provider>/callback", methods=["GET"])
def handle_oauth_callback(provider: str):
"""
Handle OAuth callback from provider.
Args:
provider: Provider type (google, github, microsoft)
Query parameters:
code: Authorization code from provider
state: State parameter
error: Error code if auth failed
error_description: Human-readable error description
Returns:
200: OAuth flow completed successfully
302: Redirect with error
400: Validation error or OAuth error
"""
provider_type = get_provider_type(provider)
# Get callback parameters
authorization_code = request.args.get("code")
state = request.args.get("state")
error = request.args.get("error")
error_description = request.args.get("error_description")
# Get redirect URI from state if available
redirect_uri = request.args.get("redirect_uri")
try:
result = OAuthFlowService.handle_callback(
provider_type=provider_type,
authorization_code=authorization_code,
state=state,
redirect_uri=redirect_uri,
error=error,
error_description=error_description,
)
if result.get("success"):
if result.get("flow_type") == "login":
return api_response(
data={
"token": result["session"]["token"],
"expires_in": result["session"].get("expires_in", 86400),
"token_type": "Bearer",
"user": result["user"],
},
message="Login successful",
)
elif result.get("flow_type") == "register":
return api_response(
data={
"token": result["session"]["token"],
"expires_in": result["session"].get("expires_in", 86400),
"token_type": "Bearer",
"user": result["user"],
},
message="Registration successful",
)
elif result.get("flow_type") == "link":
return api_response(
data={
"linked_account": result["linked_account"],
},
message="Account linked successfully",
)
return api_response(
data=result,
message="OAuth flow completed",
)
except OAuthFlowError as e:
return api_response(
success=False,
message=e.message,
status=e.status_code,
error_type=e.error_type,
)
# =============================================================================
# Helper Functions
# =============================================================================
def _get_provider_endpoints(provider_type: AuthMethodType):
"""Get OAuth endpoints for a provider."""
if provider_type == AuthMethodType.GOOGLE:
return (
"https://accounts.google.com/o/oauth2/v2/auth",
"https://oauth2.googleapis.com/token",
"https://www.googleapis.com/oauth2/v3/userinfo",
)
elif provider_type == AuthMethodType.GITHUB:
return (
"https://github.com/login/oauth/authorize",
"https://github.com/login/oauth/access_token",
"https://api.github.com/user",
)
elif provider_type == AuthMethodType.MICROSOFT:
return (
"https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
"https://login.microsoftonline.com/common/oauth2/v2.0/token",
"https://graph.microsoft.com/oidc/userinfo",
)
else:
raise ExternalAuthError(
f"Unsupported provider: {provider_type}",
"UNSUPPORTED_PROVIDER",
400,
)
+15
View File
@@ -22,6 +22,21 @@ class BaseModel(db.Model):
)
deleted_at = db.Column(db.DateTime, nullable=True)
@classmethod
def create(cls, **kwargs):
"""Create and save a new model instance.
Args:
**kwargs: Model field values
Returns:
The created model instance
"""
instance = cls(**kwargs)
db.session.add(instance)
db.session.commit()
return instance
def save(self):
"""Save the model instance to database."""
db.session.add(self)
+3
View File
@@ -24,6 +24,9 @@ class Organization(BaseModel):
oidc_clients = db.relationship(
"OIDCClient", back_populates="organization", cascade="all, delete-orphan"
)
external_provider_configs = db.relationship(
"ExternalProviderConfig", back_populates="organization", cascade="all, delete-orphan"
)
security_policy = db.relationship(
"OrganizationSecurityPolicy",
back_populates="organization",
+229
View File
@@ -105,3 +105,232 @@ class AuditService:
.limit(limit)
.all()
)
# External Authentication Provider Audit Methods
@staticmethod
def log_external_auth_link_initiated(
user_id: str,
organization_id: str,
provider_type: str,
state_id: str = None,
):
"""Log external auth account linking initiated event."""
return AuditService.log_action(
action=AuditAction.EXTERNAL_AUTH_LINK_INITIATED,
user_id=user_id,
organization_id=organization_id,
resource_type="oauth_state",
resource_id=state_id,
metadata={
"provider_type": provider_type,
},
description=f"External auth link initiated for {provider_type}",
success=True,
)
@staticmethod
def log_external_auth_link_completed(
user_id: str,
organization_id: str,
provider_type: str,
provider_user_id: str,
auth_method_id: str = None,
):
"""Log external auth account linking completed event."""
return AuditService.log_action(
action=AuditAction.EXTERNAL_AUTH_LINK_COMPLETED,
user_id=user_id,
organization_id=organization_id,
resource_type="authentication_method",
resource_id=auth_method_id,
metadata={
"provider_type": provider_type,
"provider_user_id": provider_user_id,
},
description=f"External auth account linked: {provider_type} ({provider_user_id})",
success=True,
)
@staticmethod
def log_external_auth_link_failed(
user_id: str,
organization_id: str,
provider_type: str,
error_message: str,
failure_reason: str = None,
):
"""Log external auth account linking failed event."""
return AuditService.log_action(
action=AuditAction.EXTERNAL_AUTH_LINK_FAILED,
user_id=user_id,
organization_id=organization_id,
metadata={
"provider_type": provider_type,
"failure_reason": failure_reason,
},
description=f"External auth link failed for {provider_type}: {error_message}",
success=False,
error_message=error_message,
)
@staticmethod
def log_external_auth_unlink(
user_id: str,
organization_id: str,
provider_type: str,
provider_user_id: str,
auth_method_id: str = None,
):
"""Log external auth account unlinking event."""
return AuditService.log_action(
action=AuditAction.EXTERNAL_AUTH_UNLINK,
user_id=user_id,
organization_id=organization_id,
resource_type="authentication_method",
resource_id=auth_method_id,
metadata={
"provider_type": provider_type,
"provider_user_id": provider_user_id,
},
description=f"External auth account unlinked: {provider_type} ({provider_user_id})",
success=True,
)
@staticmethod
def log_external_auth_login(
user_id: str,
organization_id: str,
provider_type: str,
provider_user_id: str,
auth_method_id: str = None,
session_id: str = None,
mfa_used: bool = False,
):
"""Log external auth login event."""
return AuditService.log_action(
action=AuditAction.EXTERNAL_AUTH_LOGIN,
user_id=user_id,
organization_id=organization_id,
resource_type="session",
resource_id=session_id,
metadata={
"provider_type": provider_type,
"provider_user_id": provider_user_id,
"auth_method_id": auth_method_id,
"mfa_used": mfa_used,
},
description=f"User logged in with {provider_type}",
success=True,
)
@staticmethod
def log_external_auth_login_failed(
organization_id: str,
provider_type: str,
provider_user_id: str = None,
email: str = None,
failure_reason: str = None,
error_message: str = None,
):
"""Log external auth login failed event."""
return AuditService.log_action(
action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED,
user_id=None, # Unknown user
organization_id=organization_id,
metadata={
"provider_type": provider_type,
"provider_user_id": provider_user_id,
"email": email,
"failure_reason": failure_reason,
},
description=f"Failed login attempt with {provider_type}: {failure_reason or error_message}",
success=False,
error_message=error_message or failure_reason,
)
@staticmethod
def log_external_auth_token_refresh(
user_id: str,
organization_id: str,
provider_type: str,
auth_method_id: str = None,
):
"""Log external auth token refresh event."""
return AuditService.log_action(
action=AuditAction.EXTERNAL_AUTH_TOKEN_REFRESH,
user_id=user_id,
organization_id=organization_id,
resource_type="authentication_method",
resource_id=auth_method_id,
metadata={
"provider_type": provider_type,
},
description=f"External auth token refreshed for {provider_type}",
success=True,
)
@staticmethod
def log_external_auth_config_create(
user_id: str,
organization_id: str,
provider_type: str,
config_id: str = None,
):
"""Log external auth provider config creation event."""
return AuditService.log_action(
action=AuditAction.EXTERNAL_AUTH_CONFIG_CREATE,
user_id=user_id,
organization_id=organization_id,
resource_type="external_provider_config",
resource_id=config_id,
metadata={
"provider_type": provider_type,
},
description=f"External auth provider config created: {provider_type}",
success=True,
)
@staticmethod
def log_external_auth_config_update(
user_id: str,
organization_id: str,
provider_type: str,
config_id: str = None,
changes: dict = None,
):
"""Log external auth provider config update event."""
return AuditService.log_action(
action=AuditAction.EXTERNAL_AUTH_CONFIG_UPDATE,
user_id=user_id,
organization_id=organization_id,
resource_type="external_provider_config",
resource_id=config_id,
metadata={
"provider_type": provider_type,
"changes": changes,
},
description=f"External auth provider config updated: {provider_type}",
success=True,
)
@staticmethod
def log_external_auth_config_delete(
user_id: str,
organization_id: str,
provider_type: str,
config_id: str = None,
):
"""Log external auth provider config deletion event."""
return AuditService.log_action(
action=AuditAction.EXTERNAL_AUTH_CONFIG_DELETE,
user_id=user_id,
organization_id=organization_id,
resource_type="external_provider_config",
resource_id=config_id,
metadata={
"provider_type": provider_type,
},
description=f"External auth provider config deleted: {provider_type}",
success=True,
)
@@ -0,0 +1,761 @@
"""External authentication provider service."""
import logging
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional, Tuple
from flask import current_app
from gatehouse_app.extensions import db
from gatehouse_app.models import User, AuthenticationMethod
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import AuthMethodType
from gatehouse_app.services.audit_service import AuditService
logger = logging.getLogger(__name__)
class ExternalAuthError(Exception):
"""Base exception for external auth errors."""
def __init__(self, message: str, error_type: str, status_code: int = 400):
self.message = message
self.error_type = error_type
self.status_code = status_code
super().__init__(message)
class OAuthState(BaseModel):
"""Temporary OAuth state storage for secure flow management."""
__tablename__ = "oauth_states"
# State identifier (used in OAuth redirects)
state = db.Column(db.String(64), unique=True, nullable=False, index=True)
# Flow type
flow_type = db.Column(db.String(50), nullable=False) # 'link', 'login', 'register'
# User context
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True, index=True)
organization_id = db.Column(
db.String(36), db.ForeignKey("organizations.id"), nullable=True, index=True
)
# Provider information
provider_type = db.Column(db.String(50), nullable=False)
# OAuth parameters
nonce = db.Column(db.String(128), nullable=True)
code_verifier = db.Column(db.String(128), nullable=True)
code_challenge = db.Column(db.String(128), nullable=True)
redirect_uri = db.Column(db.String(2048), nullable=True)
# Additional state data
extra_data = db.Column(db.JSON, nullable=True)
# Expiration
expires_at = db.Column(db.DateTime, nullable=False, index=True)
# Status
used = db.Column(db.Boolean, default=False, nullable=False)
@classmethod
def create_state(
cls,
flow_type: str,
provider_type: AuthMethodType,
user_id: str = None,
organization_id: str = None,
redirect_uri: str = None,
nonce: str = None,
code_verifier: str = None,
code_challenge: str = None,
extra_data: dict = None,
lifetime_seconds: int = 600,
) -> "OAuthState":
"""Create a new OAuth state record."""
state = secrets.token_urlsafe(32)
expires_at = datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds)
return cls.create(
state=state,
flow_type=flow_type,
provider_type=provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type,
user_id=user_id,
organization_id=organization_id,
redirect_uri=redirect_uri,
nonce=nonce or secrets.token_urlsafe(16),
code_verifier=code_verifier,
code_challenge=code_challenge,
extra_data=extra_data,
expires_at=expires_at,
)
def is_valid(self) -> bool:
"""Check if state is still valid."""
return (
not self.used
and self.expires_at
and self.expires_at.replace(tzinfo=timezone.utc) > datetime.now(timezone.utc)
)
def mark_used(self):
"""Mark state as used."""
self.used = True
self.save()
@classmethod
def cleanup_expired(cls):
"""Remove expired states."""
cls.query.filter(cls.expires_at < datetime.now(timezone.utc)).delete()
db.session.commit()
class ExternalProviderConfig(BaseModel):
"""OAuth provider configuration per organization."""
__tablename__ = "external_provider_configs"
# Organization reference
organization_id = db.Column(
db.String(36), db.ForeignKey("organizations.id"), nullable=False, index=True
)
# Provider type
provider_type = db.Column(db.String(50), nullable=False, index=True)
# OAuth credentials (client_secret is encrypted)
client_id = db.Column(db.String(255), nullable=False)
client_secret_encrypted = db.Column(db.String(512), nullable=True)
# Provider endpoints
auth_url = db.Column(db.String(2048), nullable=False)
token_url = db.Column(db.String(2048), nullable=False)
userinfo_url = db.Column(db.String(2048), nullable=True)
jwks_url = db.Column(db.String(2048), nullable=True)
# Configuration
scopes = db.Column(db.JSON, nullable=False, default=list)
redirect_uris = db.Column(db.JSON, nullable=False, default=list)
# Provider-specific settings
settings = db.Column(db.JSON, nullable=True)
# Status
is_active = db.Column(db.Boolean, default=True, nullable=False)
# Relationships
organization = db.relationship(
"Organization", back_populates="external_provider_configs"
)
# Indexes
__table_args__ = (
db.Index("idx_provider_config_org", "organization_id", "provider_type"),
db.UniqueConstraint(
"organization_id",
"provider_type",
name="uix_org_provider_type",
),
)
def get_client_secret(self) -> str:
"""Decrypt and return client secret."""
from gatehouse_app.utils.encryption import decrypt
if self.client_secret_encrypted:
return decrypt(self.client_secret_encrypted)
return None
def set_client_secret(self, secret: str):
"""Encrypt and store client secret."""
from gatehouse_app.utils.encryption import encrypt
self.client_secret_encrypted = encrypt(secret)
def is_redirect_uri_allowed(self, uri: str) -> bool:
"""Check if redirect URI is allowed."""
return uri in (self.redirect_uris or [])
def to_dict(self, include_secrets: bool = False) -> dict:
"""Convert to dictionary."""
data = {
"id": self.id,
"organization_id": self.organization_id,
"provider_type": self.provider_type,
"client_id": self.client_id,
"auth_url": self.auth_url,
"token_url": self.token_url,
"userinfo_url": self.userinfo_url,
"scopes": self.scopes,
"redirect_uris": self.redirect_uris,
"is_active": self.is_active,
"settings": self.settings,
"created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
if include_secrets and self.client_secret_encrypted:
data["client_secret"] = self.get_client_secret()
return data
class ExternalAuthService:
"""Service for external authentication operations."""
@classmethod
def get_provider_config(
cls,
organization_id: str,
provider_type: AuthMethodType,
) -> ExternalProviderConfig:
"""Get provider configuration for organization."""
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
config = ExternalProviderConfig.query.filter_by(
organization_id=organization_id,
provider_type=provider_type_str,
is_active=True,
).first()
if not config:
raise ExternalAuthError(
f"{provider_type_str.title()} OAuth is not configured for this organization",
"PROVIDER_NOT_CONFIGURED",
400,
)
return config
@classmethod
def initiate_link_flow(
cls,
user_id: str,
provider_type: AuthMethodType,
organization_id: str,
redirect_uri: str = None,
) -> Tuple[str, str]:
"""
Initiate account linking flow.
Returns:
Tuple of (redirect_url, state)
"""
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
# Get provider config
config = cls.get_provider_config(organization_id, provider_type)
# Validate redirect URI
if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri):
raise ExternalAuthError(
"Invalid redirect URI",
"INVALID_REDIRECT_URI",
400,
)
# Generate PKCE
code_verifier = secrets.token_urlsafe(32)
code_challenge = cls._compute_s256_challenge(code_verifier)
# Create OAuth state
state = OAuthState.create_state(
flow_type="link",
provider_type=provider_type,
user_id=user_id,
organization_id=organization_id,
redirect_uri=redirect_uri or config.redirect_uris[0],
code_verifier=code_verifier,
code_challenge=code_challenge,
lifetime_seconds=600,
)
# Build authorization URL (simplified - in production would use provider-specific implementation)
auth_url = cls._build_authorization_url(
config=config,
state=state,
)
# Audit log - link initiated
AuditService.log_external_auth_link_initiated(
user_id=user_id,
organization_id=organization_id,
provider_type=provider_type_str,
state_id=state.id,
)
return auth_url, state.state
@classmethod
def complete_link_flow(
cls,
provider_type: AuthMethodType,
authorization_code: str,
state: str,
redirect_uri: str,
) -> AuthenticationMethod:
"""Complete account linking flow."""
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
# Validate state
state_record = OAuthState.query.filter_by(state=state).first()
if not state_record or not state_record.is_valid():
AuditService.log_external_auth_link_failed(
user_id=None,
organization_id=None,
provider_type=provider_type_str,
error_message="Invalid or expired OAuth state",
failure_reason="invalid_state",
)
raise ExternalAuthError(
"Invalid or expired OAuth state",
"INVALID_STATE",
400,
)
if state_record.flow_type != "link":
AuditService.log_external_auth_link_failed(
user_id=state_record.user_id,
organization_id=state_record.organization_id,
provider_type=provider_type_str,
error_message="Invalid flow type for this operation",
failure_reason="invalid_flow_type",
)
raise ExternalAuthError(
"Invalid flow type for this operation",
"INVALID_FLOW_TYPE",
400,
)
if state_record.provider_type != provider_type_str:
AuditService.log_external_auth_link_failed(
user_id=state_record.user_id,
organization_id=state_record.organization_id,
provider_type=provider_type_str,
error_message="Provider mismatch",
failure_reason="provider_mismatch",
)
raise ExternalAuthError(
"Provider mismatch",
"PROVIDER_MISMATCH",
400,
)
# Get provider config
config = cls.get_provider_config(
state_record.organization_id, provider_type
)
# Exchange code for tokens (simplified - in production would use provider-specific implementation)
tokens = cls._exchange_code(
config=config,
code=authorization_code,
redirect_uri=redirect_uri,
code_verifier=state_record.code_verifier,
)
# Get user info
user_info = cls._get_user_info(
config=config,
access_token=tokens["access_token"],
)
# Get user
user = User.query.get(state_record.user_id)
if not user:
AuditService.log_external_auth_link_failed(
user_id=None,
organization_id=state_record.organization_id,
provider_type=provider_type_str,
error_message="User not found",
failure_reason="user_not_found",
)
raise ExternalAuthError(
"User not found",
"USER_NOT_FOUND",
400,
)
# Create or update authentication method
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=provider_type,
provider_user_id=user_info["provider_user_id"],
).first()
if auth_method:
# Update existing
auth_method.provider_data = cls._encrypt_provider_data(tokens, user_info)
auth_method.verified = user_info.get("email_verified", False)
auth_method.last_used_at = datetime.utcnow()
auth_method.save()
else:
# Create new
auth_method = AuthenticationMethod(
user_id=user.id,
method_type=provider_type,
provider_user_id=user_info["provider_user_id"],
provider_data=cls._encrypt_provider_data(tokens, user_info),
verified=user_info.get("email_verified", False),
is_primary=False,
last_used_at=datetime.utcnow(),
)
auth_method.save()
# Mark state as used
state_record.mark_used()
# Audit log - link completed
AuditService.log_external_auth_link_completed(
user_id=user.id,
organization_id=state_record.organization_id,
provider_type=provider_type_str,
provider_user_id=user_info["provider_user_id"],
auth_method_id=auth_method.id,
)
return auth_method
@classmethod
def authenticate_with_provider(
cls,
provider_type: AuthMethodType,
organization_id: str,
authorization_code: str,
state: str,
redirect_uri: str,
) -> Tuple[User, dict]:
"""Authenticate user with external provider and return tokens."""
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
# Validate state
state_record = OAuthState.query.filter_by(state=state).first()
if not state_record or not state_record.is_valid():
AuditService.log_external_auth_login_failed(
organization_id=organization_id,
provider_type=provider_type_str,
failure_reason="invalid_state",
error_message="Invalid or expired OAuth state",
)
raise ExternalAuthError(
"Invalid or expired OAuth state",
"INVALID_STATE",
400,
)
# Get provider config
config = cls.get_provider_config(organization_id, provider_type)
# Exchange code for tokens
tokens = cls._exchange_code(
config=config,
code=authorization_code,
redirect_uri=redirect_uri,
code_verifier=state_record.code_verifier,
)
# Get user info
user_info = cls._get_user_info(
config=config,
access_token=tokens["access_token"],
)
# Look up user by provider_user_id
auth_method = AuthenticationMethod.query.filter_by(
method_type=provider_type,
provider_user_id=user_info["provider_user_id"],
).first()
if not auth_method:
# Check if email matches existing user
existing_user = User.query.filter_by(
email=user_info["email"]
).first()
if existing_user:
AuditService.log_external_auth_login_failed(
organization_id=organization_id,
provider_type=provider_type_str,
provider_user_id=user_info["provider_user_id"],
email=user_info["email"],
failure_reason="email_exists",
error_message=f"An account with email {user_info['email']} already exists",
)
raise ExternalAuthError(
f"An account with email {user_info['email']} already exists. "
"Please log in with your password and link your Google account from settings.",
"EMAIL_EXISTS",
400,
)
AuditService.log_external_auth_login_failed(
organization_id=organization_id,
provider_type=provider_type_str,
provider_user_id=user_info["provider_user_id"],
email=user_info["email"],
failure_reason="account_not_found",
error_message="No Gatehouse account matches this external account",
)
raise ExternalAuthError(
"No Gatehouse account matches this external account. Please register first.",
"ACCOUNT_NOT_FOUND",
400,
)
user = auth_method.user
# Update tokens
auth_method.provider_data = cls._encrypt_provider_data(tokens, user_info)
auth_method.last_used_at = datetime.utcnow()
auth_method.save()
# Mark state as used
state_record.mark_used()
# Create session
from gatehouse_app.services.auth_service import AuthService
session = AuthService.create_session(
user=user,
organization_id=organization_id,
)
# Audit log - login success
AuditService.log_external_auth_login(
user_id=user.id,
organization_id=organization_id,
provider_type=provider_type_str,
provider_user_id=user_info["provider_user_id"],
auth_method_id=auth_method.id,
session_id=session.id,
)
return user, session.to_dict()
@classmethod
def unlink_provider(
cls,
user_id: str,
provider_type: AuthMethodType,
organization_id: str = None,
) -> bool:
"""Unlink external provider from user account."""
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
auth_method = AuthenticationMethod.query.filter_by(
user_id=user_id,
method_type=provider_type,
).first()
if not auth_method:
raise ExternalAuthError(
f"Provider not linked",
"PROVIDER_NOT_LINKED",
400,
)
# Check if this is the last auth method
other_methods = AuthenticationMethod.query.filter_by(
user_id=user_id,
).count()
if other_methods <= 1:
raise ExternalAuthError(
"Cannot unlink the last authentication method",
"CANNOT_UNLINK_LAST",
400,
)
provider_user_id = auth_method.provider_user_id
auth_method_id = auth_method.id
auth_method.delete()
# Audit log - unlink
AuditService.log_external_auth_unlink(
user_id=user_id,
organization_id=organization_id,
provider_type=provider_type_str,
provider_user_id=provider_user_id,
auth_method_id=auth_method_id,
)
return True
@classmethod
def get_linked_accounts(cls, user_id: str) -> list:
"""Get all linked external accounts for user."""
methods = AuthenticationMethod.query.filter_by(
user_id=user_id,
).all()
external_providers = [
AuthMethodType.GOOGLE,
AuthMethodType.GITHUB,
AuthMethodType.MICROSOFT,
]
return [
{
"id": m.id,
"provider_type": m.method_type.value if hasattr(m.method_type, 'value') else str(m.method_type),
"provider_user_id": m.provider_user_id,
"email": m.provider_data.get("email") if m.provider_data else None,
"name": m.provider_data.get("name") if m.provider_data else None,
"picture": m.provider_data.get("picture") if m.provider_data else None,
"verified": m.verified,
"linked_at": m.created_at.isoformat() if m.created_at else None,
"last_used_at": m.last_used_at.isoformat() if m.last_used_at else None,
}
for m in methods
if m.method_type in external_providers or str(m.method_type) in [p.value for p in external_providers]
]
@staticmethod
def _compute_s256_challenge(verifier: str) -> str:
"""Compute S256 code challenge from verifier."""
import hashlib
import base64
digest = hashlib.sha256(verifier.encode()).digest()
return base64.urlsafe_b64encode(digest).decode().rstrip("=")
@staticmethod
def _build_authorization_url(config: ExternalProviderConfig, state: OAuthState) -> str:
"""Build authorization URL (simplified - provider-specific in production)."""
from urllib.parse import urlencode
params = {
"client_id": config.client_id,
"redirect_uri": state.redirect_uri,
"response_type": "code",
"scope": " ".join(config.scopes or ["openid", "profile", "email"]),
"state": state.state,
"access_type": config.settings.get("access_type", "offline") if config.settings else "offline",
"prompt": config.settings.get("prompt", "consent") if config.settings else "consent",
}
if state.nonce:
params["nonce"] = state.nonce
if state.code_challenge:
params["code_challenge"] = state.code_challenge
params["code_challenge_method"] = "S256"
return f"{config.auth_url}?{urlencode(params)}"
@staticmethod
def _exchange_code(config: ExternalProviderConfig, code: str, redirect_uri: str, code_verifier: str = None) -> dict:
"""Exchange authorization code for tokens (simplified - provider-specific in production)."""
import requests
data = {
"client_id": config.client_id,
"client_secret": config.get_client_secret(),
"code": code,
"grant_type": "authorization_code",
"redirect_uri": redirect_uri,
}
if code_verifier:
data["code_verifier"] = code_verifier
response = requests.post(config.token_url, data=data)
response.raise_for_status()
return response.json()
@staticmethod
def _get_user_info(config: ExternalProviderConfig, access_token: str) -> dict:
"""Get user info from provider (simplified - provider-specific in production)."""
import requests
headers = {"Authorization": f"Bearer {access_token}"}
response = requests.get(config.userinfo_url, headers=headers)
response.raise_for_status()
data = response.json()
# Standardize user info
return {
"provider_user_id": data.get("sub"),
"email": data.get("email"),
"email_verified": data.get("email_verified", False),
"name": data.get("name"),
"first_name": data.get("given_name"),
"last_name": data.get("family_name"),
"picture": data.get("picture"),
"raw_data": data,
}
@staticmethod
def _encrypt_provider_data(tokens: dict, user_info: dict) -> dict:
"""Encrypt and store provider tokens and user info."""
from gatehouse_app.utils.encryption import encrypt
result = {
"access_token": encrypt(tokens.get("access_token")) if tokens.get("access_token") else None,
"token_type": tokens.get("token_type", "Bearer"),
"expires_in": tokens.get("expires_in"),
"refresh_token": encrypt(tokens.get("refresh_token")) if tokens.get("refresh_token") else None,
"scope": tokens.get("scope", []),
"id_token": encrypt(tokens.get("id_token")) if tokens.get("id_token") else None,
"email": user_info.get("email"),
"name": user_info.get("name"),
"picture": user_info.get("picture"),
"raw_data": user_info.get("raw_data", {}),
}
return result
@staticmethod
def _decrypt_provider_data(provider_data: dict) -> dict:
"""
Decrypt provider tokens from stored data.
This method handles backward compatibility with existing data where
access_token may be stored in plain text (unencrypted).
"""
from gatehouse_app.utils.encryption import decrypt
if not provider_data:
return {}
result = {
"token_type": provider_data.get("token_type", "Bearer"),
"expires_in": provider_data.get("expires_in"),
"scope": provider_data.get("scope", []),
"email": provider_data.get("email"),
"name": provider_data.get("name"),
"picture": provider_data.get("picture"),
"raw_data": provider_data.get("raw_data", {}),
}
# Decrypt access_token with backward compatibility
access_token = provider_data.get("access_token")
if access_token:
# Try to decrypt - if it fails, assume it's plain text (old data)
try:
result["access_token"] = decrypt(access_token)
except Exception:
# Access token is plain text (pre-encryption data)
result["access_token"] = access_token
else:
result["access_token"] = None
# Decrypt refresh_token
refresh_token = provider_data.get("refresh_token")
if refresh_token:
try:
result["refresh_token"] = decrypt(refresh_token)
except Exception:
result["refresh_token"] = refresh_token
else:
result["refresh_token"] = None
# Decrypt id_token
id_token = provider_data.get("id_token")
if id_token:
try:
result["id_token"] = decrypt(id_token)
except Exception:
result["id_token"] = id_token
else:
result["id_token"] = None
return result
@@ -0,0 +1,524 @@
"""OAuth flow service for handling external authentication flows."""
import logging
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional, Tuple
from flask import current_app, request, g
from gatehouse_app.extensions import db
from gatehouse_app.models import User, AuthenticationMethod
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import AuthMethodType
from gatehouse_app.services.audit_service import AuditService
from gatehouse_app.services.external_auth_service import (
ExternalAuthService,
ExternalAuthError,
OAuthState,
ExternalProviderConfig,
)
logger = logging.getLogger(__name__)
class OAuthFlowError(Exception):
"""Exception for OAuth flow errors."""
def __init__(self, message: str, error_type: str, status_code: int = 400):
self.message = message
self.error_type = error_type
self.status_code = status_code
super().__init__(message)
class OAuthFlowService:
"""Service for managing OAuth authentication flows."""
@classmethod
def initiate_login_flow(
cls,
provider_type: AuthMethodType,
organization_id: str = None,
redirect_uri: str = None,
state_data: dict = None,
) -> Tuple[str, str]:
"""
Initiate OAuth login flow.
Args:
provider_type: The authentication provider type
organization_id: Optional organization context for SSO
redirect_uri: Optional custom redirect URI
state_data: Additional state data to include
Returns:
Tuple of (authorization_url, state)
"""
# Get request context for audit logging
try:
ip_address = request.remote_addr if request else None
user_agent = request.headers.get("User-Agent") if request else None
except RuntimeError:
ip_address = None
user_agent = None
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
try:
# Get provider config
config = ExternalAuthService.get_provider_config(organization_id, provider_type)
# Validate redirect URI
if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri):
raise OAuthFlowError(
"Invalid redirect URI",
"INVALID_REDIRECT_URI",
400,
)
# Generate PKCE
code_verifier = secrets.token_urlsafe(32)
code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier)
# Create OAuth state for login flow
state = OAuthState.create_state(
flow_type="login",
provider_type=provider_type,
organization_id=organization_id,
redirect_uri=redirect_uri or (config.redirect_uris[0] if config.redirect_uris else None),
code_verifier=code_verifier,
code_challenge=code_challenge,
extra_data=state_data,
lifetime_seconds=600,
)
# Build authorization URL
auth_url = ExternalAuthService._build_authorization_url(
config=config,
state=state,
)
logger.info(
f"OAuth login flow initiated for provider={provider_type_str}, "
f"org_id={organization_id}, state_id={state.id}"
)
return auth_url, state.state
except ExternalAuthError as e:
# Log failed initiation
AuditService.log_action(
action="external_auth.login.initiated",
organization_id=organization_id,
metadata={
"provider_type": provider_type_str,
"failure_reason": e.error_type,
"ip_address": ip_address,
},
description=f"OAuth login initiation failed: {e.message}",
success=False,
error_message=e.message,
)
raise
@classmethod
def initiate_register_flow(
cls,
provider_type: AuthMethodType,
organization_id: str = None,
redirect_uri: str = None,
) -> Tuple[str, str]:
"""
Initiate OAuth registration flow.
Args:
provider_type: The authentication provider type
organization_id: Optional organization context
redirect_uri: Optional custom redirect URI
Returns:
Tuple of (authorization_url, state)
"""
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
try:
# Get provider config
config = ExternalAuthService.get_provider_config(organization_id, provider_type)
# Validate redirect URI
if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri):
raise OAuthFlowError(
"Invalid redirect URI",
"INVALID_REDIRECT_URI",
400,
)
# Generate PKCE
code_verifier = secrets.token_urlsafe(32)
code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier)
# Create OAuth state for register flow
state = OAuthState.create_state(
flow_type="register",
provider_type=provider_type,
organization_id=organization_id,
redirect_uri=redirect_uri or (config.redirect_uris[0] if config.redirect_uris else None),
code_verifier=code_verifier,
code_challenge=code_challenge,
lifetime_seconds=600,
)
# Build authorization URL
auth_url = ExternalAuthService._build_authorization_url(
config=config,
state=state,
)
logger.info(
f"OAuth register flow initiated for provider={provider_type_str}, "
f"org_id={organization_id}, state_id={state.id}"
)
return auth_url, state.state
except ExternalAuthError as e:
AuditService.log_action(
action="external_auth.register.initiated",
organization_id=organization_id,
metadata={
"provider_type": provider_type_str,
"failure_reason": e.error_type,
},
description=f"OAuth registration initiation failed: {e.message}",
success=False,
error_message=e.message,
)
raise
@classmethod
def handle_callback(
cls,
provider_type: AuthMethodType,
authorization_code: str,
state: str,
redirect_uri: str = None,
error: str = None,
error_description: str = None,
) -> dict:
"""
Handle OAuth callback from provider.
Args:
provider_type: The authentication provider type
authorization_code: Authorization code from provider
state: State parameter from provider
redirect_uri: Redirect URI used in the flow
error: Error code if auth failed
error_description: Human-readable error description
Returns:
Dict with flow result
"""
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
# Get request context for audit logging
try:
ip_address = request.remote_addr if request else None
user_agent = request.headers.get("User-Agent") if request else None
except RuntimeError:
ip_address = None
user_agent = None
# Handle error response from provider
if error:
AuditService.log_external_auth_login_failed(
organization_id=None,
provider_type=provider_type_str,
failure_reason=error,
error_message=error_description or error,
)
raise OAuthFlowError(
error_description or f"OAuth error: {error}",
error.upper() if error else "OAUTH_ERROR",
400,
)
# Validate state
state_record = OAuthState.query.filter_by(state=state).first()
if not state_record or not state_record.is_valid():
AuditService.log_external_auth_login_failed(
organization_id=state_record.organization_id if state_record else None,
provider_type=provider_type_str,
failure_reason="invalid_state",
error_message="Invalid or expired OAuth state",
)
raise OAuthFlowError(
"Invalid or expired OAuth state",
"INVALID_STATE",
400,
)
# Route to appropriate handler based on flow type
if state_record.flow_type == "login":
return cls._handle_login_callback(
provider_type=provider_type,
state_record=state_record,
authorization_code=authorization_code,
redirect_uri=redirect_uri or state_record.redirect_uri,
ip_address=ip_address,
user_agent=user_agent,
)
elif state_record.flow_type == "link":
return cls._handle_link_callback(
provider_type=provider_type,
state_record=state_record,
authorization_code=authorization_code,
redirect_uri=redirect_uri or state_record.redirect_uri,
)
elif state_record.flow_type == "register":
return cls._handle_register_callback(
provider_type=provider_type,
state_record=state_record,
authorization_code=authorization_code,
redirect_uri=redirect_uri or state_record.redirect_uri,
)
else:
raise OAuthFlowError(
f"Unknown flow type: {state_record.flow_type}",
"INVALID_FLOW_TYPE",
400,
)
@classmethod
def _handle_login_callback(
cls,
provider_type: AuthMethodType,
state_record: OAuthState,
authorization_code: str,
redirect_uri: str,
ip_address: str = None,
user_agent: str = None,
) -> dict:
"""Handle login flow callback."""
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
try:
# Authenticate with provider
user, session_data = ExternalAuthService.authenticate_with_provider(
provider_type=provider_type,
organization_id=state_record.organization_id,
authorization_code=authorization_code,
state=state_record.state,
redirect_uri=redirect_uri,
)
logger.info(
f"OAuth login successful for user={user.id}, "
f"provider={provider_type_str}, org_id={state_record.organization_id}"
)
return {
"success": True,
"flow_type": "login",
"user": {
"id": user.id,
"email": user.email,
"full_name": user.full_name,
"organization_id": state_record.organization_id,
},
"session": session_data,
}
except ExternalAuthError as e:
logger.warning(
f"OAuth login failed for state={state_record.id}, "
f"provider={provider_type_str}, error={e.message}"
)
raise
@classmethod
def _handle_link_callback(
cls,
provider_type: AuthMethodType,
state_record: OAuthState,
authorization_code: str,
redirect_uri: str,
) -> dict:
"""Handle account linking flow callback."""
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
try:
# Complete link flow
auth_method = ExternalAuthService.complete_link_flow(
provider_type=provider_type,
authorization_code=authorization_code,
state=state_record.state,
redirect_uri=redirect_uri,
)
logger.info(
f"OAuth link successful for user={state_record.user_id}, "
f"provider={provider_type_str}, auth_method_id={auth_method.id}"
)
return {
"success": True,
"flow_type": "link",
"linked_account": {
"id": auth_method.id,
"provider_type": provider_type_str,
"provider_user_id": auth_method.provider_user_id,
"verified": auth_method.verified,
},
}
except ExternalAuthError as e:
logger.warning(
f"OAuth link failed for state={state_record.id}, "
f"provider={provider_type_str}, error={e.message}"
)
raise
@classmethod
def _handle_register_callback(
cls,
provider_type: AuthMethodType,
state_record: OAuthState,
authorization_code: str,
redirect_uri: str,
) -> dict:
"""Handle registration flow callback."""
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
try:
# Get provider config
config = ExternalAuthService.get_provider_config(
state_record.organization_id, provider_type
)
# Exchange code for tokens
tokens = ExternalAuthService._exchange_code(
config=config,
code=authorization_code,
redirect_uri=redirect_uri,
code_verifier=state_record.code_verifier,
)
# Get user info
user_info = ExternalAuthService._get_user_info(
config=config,
access_token=tokens["access_token"],
)
# Check if user already exists by email
existing_user = User.query.filter_by(
email=user_info["email"]
).first()
if existing_user:
# User exists - suggest linking
raise OAuthFlowError(
f"An account with email {user_info['email']} already exists. "
"Please log in with your password and link your Google account from settings.",
"EMAIL_EXISTS",
400,
)
# Create new user
user = User(
email=user_info["email"],
full_name=user_info.get("name", ""),
status="active",
)
user.save()
# Create authentication method
auth_method = AuthenticationMethod(
user_id=user.id,
method_type=provider_type,
provider_user_id=user_info["provider_user_id"],
provider_data=ExternalAuthService._encrypt_provider_data(tokens, user_info),
verified=user_info.get("email_verified", False),
is_primary=True,
)
auth_method.save()
# Mark state as used
state_record.mark_used()
# Audit log - registration success
AuditService.log_action(
action="user.register",
user_id=user.id,
organization_id=state_record.organization_id,
resource_type="user",
resource_id=user.id,
metadata={
"provider_type": provider_type_str,
"provider_user_id": user_info["provider_user_id"],
"auth_method_id": auth_method.id,
},
description=f"User registered via {provider_type_str}",
success=True,
)
AuditService.log_external_auth_link_completed(
user_id=user.id,
organization_id=state_record.organization_id,
provider_type=provider_type_str,
provider_user_id=user_info["provider_user_id"],
auth_method_id=auth_method.id,
)
logger.info(
f"OAuth registration successful for email={user_info['email']}, "
f"provider={provider_type_str}, user_id={user.id}"
)
# Create session
from gatehouse_app.services.auth_service import AuthService
session = AuthService.create_session(
user=user,
organization_id=state_record.organization_id,
)
return {
"success": True,
"flow_type": "register",
"user": {
"id": user.id,
"email": user.email,
"full_name": user.full_name,
"organization_id": state_record.organization_id,
},
"session": session.to_dict(),
}
except ExternalAuthError as e:
logger.warning(
f"OAuth registration failed for state={state_record.id}, "
f"provider={provider_type_str}, error={e.message}"
)
raise
@classmethod
def validate_state(cls, state: str) -> Optional[OAuthState]:
"""
Validate and return OAuth state.
Args:
state: The state parameter to validate
Returns:
OAuthState if valid, None otherwise
"""
state_record = OAuthState.query.filter_by(state=state).first()
if state_record and state_record.is_valid():
return state_record
return None
@classmethod
def cleanup_expired_states(cls):
"""Remove expired OAuth states."""
OAuthState.cleanup_expired()
logger.info("Expired OAuth states cleaned up")
+12
View File
@@ -93,6 +93,18 @@ class AuditAction(str, Enum):
MFA_POLICY_USER_SUSPENDED = "mfa.policy.user_suspended"
MFA_POLICY_USER_COMPLIANT = "mfa.policy.user_compliant"
# External authentication provider actions
EXTERNAL_AUTH_LINK_INITIATED = "external_auth.link.initiated"
EXTERNAL_AUTH_LINK_COMPLETED = "external_auth.link.completed"
EXTERNAL_AUTH_LINK_FAILED = "external_auth.link.failed"
EXTERNAL_AUTH_UNLINK = "external_auth.unlink"
EXTERNAL_AUTH_LOGIN = "external_auth.login"
EXTERNAL_AUTH_LOGIN_FAILED = "external_auth.login.failed"
EXTERNAL_AUTH_TOKEN_REFRESH = "external_auth.token_refresh"
EXTERNAL_AUTH_CONFIG_CREATE = "external_auth.config.create"
EXTERNAL_AUTH_CONFIG_UPDATE = "external_auth.config.update"
EXTERNAL_AUTH_CONFIG_DELETE = "external_auth.config.delete"
class OIDCGrantType(str, Enum):
"""OIDC grant types."""
+112
View File
@@ -0,0 +1,112 @@
"""Encryption utilities for sensitive data."""
import base64
import os
from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
# Encryption key derivation settings
SALT_LENGTH = 16
KEY_ITERATIONS = 480000
def _get_fernet_key(secret_key: str, salt: bytes = None) -> bytes:
"""
Derive a Fernet key from a secret key using PBKDF2.
Args:
secret_key: The master secret key
salt: Optional salt bytes (will be generated if not provided)
Returns:
32-byte key suitable for Fernet encryption
"""
if salt is None:
salt = os.urandom(SALT_LENGTH)
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=salt,
iterations=KEY_ITERATIONS,
)
key = base64.urlsafe_b64encode(kdf.derive(secret_key.encode()))
return key
def encrypt(plaintext: str, secret_key: str = None) -> str:
"""
Encrypt a string using Fernet symmetric encryption.
Args:
plaintext: The string to encrypt
secret_key: The encryption key (uses app config if not provided)
Returns:
Base64-encoded encrypted string with salt prepended
"""
from flask import current_app
if not plaintext:
return ""
# Get secret key from app config or use provided key
if secret_key is None:
secret_key = current_app.config.get("ENCRYPTION_KEY", "")
if not secret_key:
raise ValueError("Encryption key not configured")
# Generate a random salt for this encryption
salt = os.urandom(SALT_LENGTH)
fernet_key = _get_fernet_key(secret_key, salt)
fernet = Fernet(fernet_key)
# Encrypt the plaintext
encrypted_bytes = fernet.encrypt(plaintext.encode())
# Combine salt + encrypted data and base64 encode
combined = salt + encrypted_bytes
return base64.urlsafe_b64encode(combined).decode()
def decrypt(encrypted_data: str, secret_key: str = None) -> str:
"""
Decrypt a string that was encrypted with the encrypt function.
Args:
encrypted_data: Base64-encoded encrypted string with salt prepended
secret_key: The encryption key (uses app config if not provided)
Returns:
The original plaintext string
"""
from flask import current_app
if not encrypted_data:
return ""
# Get secret key from app config or use provided key
if secret_key is None:
secret_key = current_app.config.get("ENCRYPTION_KEY", "")
if not secret_key:
raise ValueError("Encryption key not configured")
try:
# Decode from base64
combined = base64.urlsafe_b64decode(encrypted_data.encode())
# Extract salt and encrypted data
salt = combined[:SALT_LENGTH]
encrypted_bytes = combined[SALT_LENGTH:]
# Derive the key and decrypt
fernet_key = _get_fernet_key(secret_key, salt)
fernet = Fernet(fernet_key)
plaintext = fernet.decrypt(encrypted_bytes)
return plaintext.decode()
except (InvalidToken, ValueError):
raise ValueError("Failed to decrypt data - invalid key or corrupted data")