Feat(Chore, Fix): Refractor, Half Baked Deletion + Admin Privilege
Refractor Codes into sub file/folders Admin can remove users'/members mfa/2fa, unlink account from oauth provider Admin can add/reset password Different Email (OIDC + Manual)-Same Account; (Block Linking and authorize if available)
This commit is contained in:
@@ -4,7 +4,7 @@ from gatehouse_app.services.user_service import UserService
|
||||
from gatehouse_app.services.organization_service import OrganizationService
|
||||
from gatehouse_app.services.session_service import SessionService
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
from gatehouse_app.services.oidc_service import OIDCService, OIDCError
|
||||
from gatehouse_app.services.oidc import OIDCService, OIDCError
|
||||
from gatehouse_app.services.oidc_jwks_service import OIDCJWKSService
|
||||
from gatehouse_app.services.oidc_token_service import OIDCTokenService
|
||||
from gatehouse_app.services.oidc_session_service import OIDCSessionService
|
||||
|
||||
@@ -388,7 +388,7 @@ class AuthService:
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
password: User's current password for verification
|
||||
password: User's current password for verification (ignored for OAuth-only users)
|
||||
|
||||
Returns:
|
||||
True if TOTP disabled successfully
|
||||
@@ -396,18 +396,21 @@ class AuthService:
|
||||
Raises:
|
||||
InvalidCredentialsError: If password is invalid or TOTP method not found
|
||||
"""
|
||||
# Verify user's password
|
||||
# Verify user's password — only required when the user actually has one.
|
||||
# OAuth-only users have no PASSWORD auth method; they authenticate via their
|
||||
# identity provider so there is nothing to check here.
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
user_id=user.id,
|
||||
method_type=AuthMethodType.PASSWORD,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
if not auth_method or not auth_method.password_hash:
|
||||
raise InvalidCredentialsError("No password authentication method found")
|
||||
|
||||
if not bcrypt.check_password_hash(auth_method.password_hash, password):
|
||||
raise InvalidCredentialsError("Invalid password")
|
||||
if auth_method and auth_method.password_hash:
|
||||
# Password-based account: a password must be supplied and must match.
|
||||
if not password:
|
||||
raise InvalidCredentialsError("Password is required")
|
||||
if not bcrypt.check_password_hash(auth_method.password_hash, password):
|
||||
raise InvalidCredentialsError("Invalid password")
|
||||
|
||||
# Get user's TOTP authentication method
|
||||
totp_method = user.get_totp_method()
|
||||
|
||||
@@ -0,0 +1,168 @@
|
||||
"""ExternalAuthService — public facade re-exporting the full API."""
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from gatehouse_app.models import AuthenticationMethod, User
|
||||
from gatehouse_app.models.auth.authentication_method import (
|
||||
ApplicationProviderConfig,
|
||||
OrganizationProviderOverride,
|
||||
OAuthState,
|
||||
)
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
from gatehouse_app.services.external_auth.models import (
|
||||
ExternalAuthError,
|
||||
ExternalProviderConfig,
|
||||
ProviderConfigAdapter,
|
||||
)
|
||||
from gatehouse_app.services.external_auth import app_provider, org_override, linking
|
||||
from gatehouse_app.services.external_auth._helpers import (
|
||||
_compute_s256_challenge,
|
||||
_build_authorization_url,
|
||||
_exchange_code,
|
||||
_get_user_info,
|
||||
_encrypt_provider_data,
|
||||
_decrypt_provider_data,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExternalAuthService:
|
||||
"""Service for external authentication operations."""
|
||||
|
||||
# ── Provider config lookup ──────────────────────────────────────────────
|
||||
|
||||
@classmethod
|
||||
def get_provider_config(
|
||||
cls,
|
||||
provider_type: AuthMethodType,
|
||||
organization_id: Optional[str] = None,
|
||||
) -> ProviderConfigAdapter:
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
app_config = ApplicationProviderConfig.query.filter_by(
|
||||
provider_type=provider_type_str
|
||||
).first()
|
||||
|
||||
if not app_config:
|
||||
raise ExternalAuthError(
|
||||
f"{provider_type_str.title()} OAuth is not configured for this application",
|
||||
"PROVIDER_NOT_CONFIGURED",
|
||||
400,
|
||||
)
|
||||
|
||||
if not app_config.is_enabled:
|
||||
raise ExternalAuthError(
|
||||
f"{provider_type_str.title()} OAuth is currently disabled",
|
||||
"PROVIDER_DISABLED",
|
||||
400,
|
||||
)
|
||||
|
||||
org_override_obj = None
|
||||
if organization_id:
|
||||
org_override_obj = OrganizationProviderOverride.query.filter_by(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type_str,
|
||||
).first()
|
||||
|
||||
if org_override_obj and not org_override_obj.is_enabled:
|
||||
raise ExternalAuthError(
|
||||
f"{provider_type_str.title()} OAuth is disabled for this organization",
|
||||
"PROVIDER_DISABLED_FOR_ORG",
|
||||
400,
|
||||
)
|
||||
|
||||
return ProviderConfigAdapter(app_config, org_override_obj)
|
||||
|
||||
# ── App-wide provider config ────────────────────────────────────────────
|
||||
|
||||
@classmethod
|
||||
def create_app_provider_config(cls, provider_type, client_id, client_secret, **kwargs):
|
||||
return app_provider.create_app_provider_config(provider_type, client_id, client_secret, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def update_app_provider_config(cls, provider_type, **updates):
|
||||
return app_provider.update_app_provider_config(provider_type, **updates)
|
||||
|
||||
@classmethod
|
||||
def get_app_provider_config(cls, provider_type):
|
||||
return app_provider.get_app_provider_config(provider_type)
|
||||
|
||||
@classmethod
|
||||
def list_app_provider_configs(cls):
|
||||
return app_provider.list_app_provider_configs()
|
||||
|
||||
@classmethod
|
||||
def delete_app_provider_config(cls, provider_type):
|
||||
return app_provider.delete_app_provider_config(provider_type)
|
||||
|
||||
# ── Org override management ─────────────────────────────────────────────
|
||||
|
||||
@classmethod
|
||||
def create_org_provider_override(cls, organization_id, provider_type, **kwargs):
|
||||
return org_override.create_org_provider_override(organization_id, provider_type, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def update_org_provider_override(cls, organization_id, provider_type, **updates):
|
||||
return org_override.update_org_provider_override(organization_id, provider_type, **updates)
|
||||
|
||||
@classmethod
|
||||
def get_org_provider_override(cls, organization_id, provider_type):
|
||||
return org_override.get_org_provider_override(organization_id, provider_type)
|
||||
|
||||
@classmethod
|
||||
def list_org_provider_overrides(cls, organization_id):
|
||||
return org_override.list_org_provider_overrides(organization_id)
|
||||
|
||||
@classmethod
|
||||
def delete_org_provider_override(cls, organization_id, provider_type):
|
||||
return org_override.delete_org_provider_override(organization_id, provider_type)
|
||||
|
||||
# ── OAuth link / auth flows ─────────────────────────────────────────────
|
||||
|
||||
@classmethod
|
||||
def initiate_link_flow(cls, user_id, provider_type, organization_id, redirect_uri=None):
|
||||
return linking.initiate_link_flow(cls.get_provider_config, user_id, provider_type, organization_id, redirect_uri)
|
||||
|
||||
@classmethod
|
||||
def complete_link_flow(cls, provider_type, authorization_code, state, redirect_uri):
|
||||
return linking.complete_link_flow(cls.get_provider_config, provider_type, authorization_code, state, redirect_uri)
|
||||
|
||||
@classmethod
|
||||
def authenticate_with_provider(cls, provider_type, organization_id, authorization_code, state, redirect_uri):
|
||||
return linking.authenticate_with_provider(cls.get_provider_config, provider_type, organization_id, authorization_code, state, redirect_uri)
|
||||
|
||||
@classmethod
|
||||
def unlink_provider(cls, user_id, provider_type, organization_id=None):
|
||||
return linking.unlink_provider(user_id, provider_type, organization_id)
|
||||
|
||||
@classmethod
|
||||
def get_linked_accounts(cls, user_id):
|
||||
return linking.get_linked_accounts(user_id)
|
||||
|
||||
# ── Static helpers (kept as class methods for backward compatibility) ───
|
||||
|
||||
@staticmethod
|
||||
def _compute_s256_challenge(verifier: str) -> str:
|
||||
return _compute_s256_challenge(verifier)
|
||||
|
||||
@staticmethod
|
||||
def _build_authorization_url(config, state) -> str:
|
||||
return _build_authorization_url(config, state)
|
||||
|
||||
@staticmethod
|
||||
def _exchange_code(config, code, redirect_uri, code_verifier=None) -> dict:
|
||||
return _exchange_code(config, code, redirect_uri, code_verifier)
|
||||
|
||||
@staticmethod
|
||||
def _get_user_info(config, access_token) -> dict:
|
||||
return _get_user_info(config, access_token)
|
||||
|
||||
@staticmethod
|
||||
def _encrypt_provider_data(tokens, user_info) -> dict:
|
||||
return _encrypt_provider_data(tokens, user_info)
|
||||
|
||||
@staticmethod
|
||||
def _decrypt_provider_data(provider_data) -> dict:
|
||||
return _decrypt_provider_data(provider_data)
|
||||
@@ -0,0 +1,183 @@
|
||||
"""Static helper methods for OAuth flows."""
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _compute_s256_challenge(verifier: str) -> str:
|
||||
import hashlib
|
||||
import base64
|
||||
digest = hashlib.sha256(verifier.encode()).digest()
|
||||
return base64.urlsafe_b64encode(digest).decode().rstrip("=")
|
||||
|
||||
|
||||
def _build_authorization_url(config, state) -> str:
|
||||
from urllib.parse import urlencode
|
||||
provider = (config.provider_type or "").lower()
|
||||
|
||||
params = {
|
||||
"client_id": config.client_id,
|
||||
"redirect_uri": state.redirect_uri,
|
||||
"response_type": "code",
|
||||
"scope": " ".join(config.scopes or ["openid", "profile", "email"]),
|
||||
"state": state.state,
|
||||
}
|
||||
|
||||
if provider == "google":
|
||||
params["access_type"] = (
|
||||
config.settings.get("access_type", "offline") if config.settings else "offline"
|
||||
)
|
||||
params["prompt"] = (
|
||||
config.settings.get("prompt", "consent") if config.settings else "consent"
|
||||
)
|
||||
elif provider == "microsoft":
|
||||
params["prompt"] = (
|
||||
config.settings.get("prompt", "select_account") if config.settings else "select_account"
|
||||
)
|
||||
else:
|
||||
if config.settings:
|
||||
if "prompt" in config.settings:
|
||||
params["prompt"] = config.settings["prompt"]
|
||||
if "access_type" in config.settings:
|
||||
params["access_type"] = config.settings["access_type"]
|
||||
|
||||
if state.nonce:
|
||||
params["nonce"] = state.nonce
|
||||
|
||||
if state.code_challenge:
|
||||
params["code_challenge"] = state.code_challenge
|
||||
params["code_challenge_method"] = "S256"
|
||||
|
||||
full_url = f"{config.auth_url}?{urlencode(params)}"
|
||||
|
||||
logger.info(
|
||||
f"[PKCE DEBUG] Building authorization URL:\n"
|
||||
f" provider_type: {config.provider_type}\n"
|
||||
f" state.code_challenge: {state.code_challenge[:20] if state.code_challenge else 'None'}...\n"
|
||||
f" params has code_challenge: {'code_challenge' in params}\n"
|
||||
f" Full URL: {full_url}"
|
||||
)
|
||||
|
||||
return full_url
|
||||
|
||||
|
||||
def _exchange_code(config, code: str, redirect_uri: str, code_verifier: str = None) -> dict:
|
||||
import requests
|
||||
|
||||
data = {
|
||||
"client_id": config.client_id,
|
||||
"client_secret": config.get_client_secret(),
|
||||
"code": code,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
|
||||
if code_verifier:
|
||||
data["code_verifier"] = code_verifier
|
||||
|
||||
logger.debug(
|
||||
f"Token exchange request: url={config.token_url}, "
|
||||
f"client_id={config.client_id}, redirect_uri={redirect_uri}, "
|
||||
f"has_code_verifier={bool(code_verifier)}"
|
||||
)
|
||||
|
||||
response = requests.post(config.token_url, data=data)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(
|
||||
f"Token exchange failed: status={response.status_code}, "
|
||||
f"response={response.text}"
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def _get_user_info(config, access_token: str) -> dict:
|
||||
import re
|
||||
import requests
|
||||
|
||||
provider = (config.provider_type or "").lower()
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
response = requests.get(config.userinfo_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
|
||||
if provider == "microsoft":
|
||||
email_verified = data.get("email_verified", True)
|
||||
else:
|
||||
email_verified = data.get("email_verified", False)
|
||||
|
||||
sub = data.get("sub")
|
||||
|
||||
raw_email = data.get("email")
|
||||
if not raw_email and sub:
|
||||
if re.match(r"^[^@\s]+@[^@\s]+\.[^@\s]+$", sub):
|
||||
raw_email = sub
|
||||
email_verified = True
|
||||
else:
|
||||
raw_email = f"{sub}@{provider or 'oauth'}.local"
|
||||
email_verified = False
|
||||
|
||||
raw_name = data.get("name") or data.get("display_name")
|
||||
if not raw_name and raw_email:
|
||||
raw_name = raw_email.split("@")[0]
|
||||
|
||||
return {
|
||||
"provider_user_id": sub,
|
||||
"email": raw_email,
|
||||
"email_verified": email_verified,
|
||||
"name": raw_name,
|
||||
"first_name": data.get("given_name"),
|
||||
"last_name": data.get("family_name"),
|
||||
"picture": data.get("picture"),
|
||||
"raw_data": data,
|
||||
}
|
||||
|
||||
|
||||
def _encrypt_provider_data(tokens: dict, user_info: dict) -> dict:
|
||||
from gatehouse_app.utils.encryption import encrypt
|
||||
|
||||
return {
|
||||
"access_token": encrypt(tokens.get("access_token")) if tokens.get("access_token") else None,
|
||||
"token_type": tokens.get("token_type", "Bearer"),
|
||||
"expires_in": tokens.get("expires_in"),
|
||||
"refresh_token": encrypt(tokens.get("refresh_token")) if tokens.get("refresh_token") else None,
|
||||
"scope": tokens.get("scope", []),
|
||||
"id_token": encrypt(tokens.get("id_token")) if tokens.get("id_token") else None,
|
||||
"email": user_info.get("email"),
|
||||
"name": user_info.get("name"),
|
||||
"picture": user_info.get("picture"),
|
||||
"raw_data": user_info.get("raw_data", {}),
|
||||
}
|
||||
|
||||
|
||||
def _decrypt_provider_data(provider_data: dict) -> dict:
|
||||
from gatehouse_app.utils.encryption import decrypt
|
||||
|
||||
if not provider_data:
|
||||
return {}
|
||||
|
||||
result = {
|
||||
"token_type": provider_data.get("token_type", "Bearer"),
|
||||
"expires_in": provider_data.get("expires_in"),
|
||||
"scope": provider_data.get("scope", []),
|
||||
"email": provider_data.get("email"),
|
||||
"name": provider_data.get("name"),
|
||||
"picture": provider_data.get("picture"),
|
||||
"raw_data": provider_data.get("raw_data", {}),
|
||||
}
|
||||
|
||||
for field in ("access_token", "refresh_token", "id_token"):
|
||||
value = provider_data.get(field)
|
||||
if value:
|
||||
try:
|
||||
result[field] = decrypt(value)
|
||||
except Exception:
|
||||
result[field] = value
|
||||
else:
|
||||
result[field] = None
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,125 @@
|
||||
"""Application-wide provider configuration management."""
|
||||
import logging
|
||||
|
||||
from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig
|
||||
from gatehouse_app.services.external_auth.models import ExternalAuthError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_app_provider_config(
|
||||
provider_type: str,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
**kwargs,
|
||||
) -> ApplicationProviderConfig:
|
||||
existing = ApplicationProviderConfig.query.filter_by(
|
||||
provider_type=provider_type
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
raise ExternalAuthError(
|
||||
f"Provider {provider_type} already exists",
|
||||
"PROVIDER_EXISTS",
|
||||
400,
|
||||
)
|
||||
|
||||
additional_config = {}
|
||||
for key in ['auth_url', 'token_url', 'userinfo_url', 'jwks_url', 'scopes']:
|
||||
if key in kwargs:
|
||||
additional_config[key] = kwargs.pop(key)
|
||||
|
||||
if 'settings' in kwargs:
|
||||
additional_config.update(kwargs.pop('settings'))
|
||||
|
||||
config = ApplicationProviderConfig(
|
||||
provider_type=provider_type,
|
||||
client_id=client_id,
|
||||
is_enabled=kwargs.get('is_enabled', True),
|
||||
default_redirect_url=kwargs.get('default_redirect_url'),
|
||||
additional_config=additional_config,
|
||||
)
|
||||
config.set_client_secret(client_secret)
|
||||
config.save()
|
||||
|
||||
logger.info(f"Created application provider config for {provider_type}")
|
||||
return config
|
||||
|
||||
|
||||
def update_app_provider_config(
|
||||
provider_type: str,
|
||||
**updates,
|
||||
) -> ApplicationProviderConfig:
|
||||
config = ApplicationProviderConfig.query.filter_by(
|
||||
provider_type=provider_type
|
||||
).first()
|
||||
|
||||
if not config:
|
||||
raise ExternalAuthError(
|
||||
f"Provider {provider_type} not found",
|
||||
"PROVIDER_NOT_FOUND",
|
||||
404,
|
||||
)
|
||||
|
||||
if 'client_id' in updates:
|
||||
config.client_id = updates['client_id']
|
||||
|
||||
if 'client_secret' in updates:
|
||||
config.set_client_secret(updates['client_secret'])
|
||||
|
||||
if 'is_enabled' in updates:
|
||||
config.is_enabled = updates['is_enabled']
|
||||
|
||||
if 'default_redirect_url' in updates:
|
||||
config.default_redirect_url = updates['default_redirect_url']
|
||||
|
||||
if config.additional_config is None:
|
||||
config.additional_config = {}
|
||||
|
||||
for key in ['auth_url', 'token_url', 'userinfo_url', 'jwks_url', 'scopes']:
|
||||
if key in updates:
|
||||
config.additional_config[key] = updates[key]
|
||||
|
||||
if 'settings' in updates:
|
||||
config.additional_config.update(updates['settings'])
|
||||
|
||||
config.save()
|
||||
logger.info(f"Updated application provider config for {provider_type}")
|
||||
return config
|
||||
|
||||
|
||||
def get_app_provider_config(provider_type: str) -> ApplicationProviderConfig:
|
||||
config = ApplicationProviderConfig.query.filter_by(
|
||||
provider_type=provider_type
|
||||
).first()
|
||||
|
||||
if not config:
|
||||
raise ExternalAuthError(
|
||||
f"Provider {provider_type} not found",
|
||||
"PROVIDER_NOT_FOUND",
|
||||
404,
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def list_app_provider_configs() -> list:
|
||||
configs = ApplicationProviderConfig.query.all()
|
||||
return [config.to_dict() for config in configs]
|
||||
|
||||
|
||||
def delete_app_provider_config(provider_type: str) -> bool:
|
||||
config = ApplicationProviderConfig.query.filter_by(
|
||||
provider_type=provider_type
|
||||
).first()
|
||||
|
||||
if not config:
|
||||
raise ExternalAuthError(
|
||||
f"Provider {provider_type} not found",
|
||||
"PROVIDER_NOT_FOUND",
|
||||
404,
|
||||
)
|
||||
|
||||
config.delete()
|
||||
logger.info(f"Deleted application provider config for {provider_type}")
|
||||
return True
|
||||
@@ -0,0 +1,339 @@
|
||||
"""Account linking, authentication, and unlinking flows."""
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from gatehouse_app.models import User, AuthenticationMethod
|
||||
from gatehouse_app.models.auth.authentication_method import OAuthState
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
from gatehouse_app.services.external_auth.models import ExternalAuthError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def initiate_link_flow(
|
||||
get_provider_config,
|
||||
user_id: str,
|
||||
provider_type: AuthMethodType,
|
||||
organization_id: str,
|
||||
redirect_uri: str = None,
|
||||
) -> Tuple[str, str]:
|
||||
from gatehouse_app.services.external_auth._helpers import (
|
||||
_compute_s256_challenge,
|
||||
_build_authorization_url,
|
||||
)
|
||||
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
config = get_provider_config(provider_type, organization_id)
|
||||
|
||||
if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri):
|
||||
raise ExternalAuthError("Invalid redirect URI", "INVALID_REDIRECT_URI", 400)
|
||||
|
||||
code_verifier = None
|
||||
code_challenge = None
|
||||
if provider_type_str not in ('google', 'microsoft'):
|
||||
code_verifier = secrets.token_urlsafe(32)
|
||||
code_challenge = _compute_s256_challenge(code_verifier)
|
||||
|
||||
state = OAuthState.create_state(
|
||||
flow_type="link",
|
||||
provider_type=provider_type,
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
redirect_uri=redirect_uri or (config.redirect_uris[0] if config.redirect_uris else None),
|
||||
code_verifier=code_verifier,
|
||||
code_challenge=code_challenge,
|
||||
lifetime_seconds=600,
|
||||
)
|
||||
|
||||
auth_url = _build_authorization_url(config=config, state=state)
|
||||
|
||||
AuditService.log_external_auth_link_initiated(
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type_str,
|
||||
state_id=state.id,
|
||||
)
|
||||
|
||||
return auth_url, state.state
|
||||
|
||||
|
||||
def complete_link_flow(
|
||||
get_provider_config,
|
||||
provider_type: AuthMethodType,
|
||||
authorization_code: str,
|
||||
state: str,
|
||||
redirect_uri: str,
|
||||
) -> AuthenticationMethod:
|
||||
from gatehouse_app.services.external_auth._helpers import (
|
||||
_exchange_code,
|
||||
_get_user_info,
|
||||
_encrypt_provider_data,
|
||||
)
|
||||
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
state_record = OAuthState.query.filter_by(state=state).first()
|
||||
if not state_record or not state_record.is_valid():
|
||||
AuditService.log_external_auth_link_failed(
|
||||
user_id=None,
|
||||
organization_id=None,
|
||||
provider_type=provider_type_str,
|
||||
error_message="Invalid or expired OAuth state",
|
||||
failure_reason="invalid_state",
|
||||
)
|
||||
raise ExternalAuthError("Invalid or expired OAuth state", "INVALID_STATE", 400)
|
||||
|
||||
if state_record.flow_type != "link":
|
||||
AuditService.log_external_auth_link_failed(
|
||||
user_id=state_record.user_id,
|
||||
organization_id=state_record.organization_id,
|
||||
provider_type=provider_type_str,
|
||||
error_message="Invalid flow type for this operation",
|
||||
failure_reason="invalid_flow_type",
|
||||
)
|
||||
raise ExternalAuthError("Invalid flow type for this operation", "INVALID_FLOW_TYPE", 400)
|
||||
|
||||
if state_record.provider_type != provider_type_str:
|
||||
AuditService.log_external_auth_link_failed(
|
||||
user_id=state_record.user_id,
|
||||
organization_id=state_record.organization_id,
|
||||
provider_type=provider_type_str,
|
||||
error_message="Provider mismatch",
|
||||
failure_reason="provider_mismatch",
|
||||
)
|
||||
raise ExternalAuthError("Provider mismatch", "PROVIDER_MISMATCH", 400)
|
||||
|
||||
config = get_provider_config(provider_type, state_record.organization_id)
|
||||
|
||||
tokens = _exchange_code(
|
||||
config=config,
|
||||
code=authorization_code,
|
||||
redirect_uri=redirect_uri,
|
||||
code_verifier=state_record.code_verifier,
|
||||
)
|
||||
|
||||
user_info = _get_user_info(config=config, access_token=tokens["access_token"])
|
||||
|
||||
user = User.query.get(state_record.user_id)
|
||||
if not user:
|
||||
AuditService.log_external_auth_link_failed(
|
||||
user_id=None,
|
||||
organization_id=state_record.organization_id,
|
||||
provider_type=provider_type_str,
|
||||
error_message="User not found",
|
||||
failure_reason="user_not_found",
|
||||
)
|
||||
raise ExternalAuthError("User not found", "USER_NOT_FOUND", 400)
|
||||
|
||||
conflicting = AuthenticationMethod.query.filter(
|
||||
AuthenticationMethod.method_type == provider_type,
|
||||
AuthenticationMethod.provider_user_id == user_info["provider_user_id"],
|
||||
AuthenticationMethod.user_id != user.id,
|
||||
AuthenticationMethod.deleted_at == None,
|
||||
).first()
|
||||
if conflicting:
|
||||
raise ExternalAuthError(
|
||||
f"This {provider_type_str} account is already linked to a different Gatehouse user.",
|
||||
"PROVIDER_ALREADY_LINKED",
|
||||
409,
|
||||
)
|
||||
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
user_id=user.id,
|
||||
method_type=provider_type,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
).first()
|
||||
|
||||
if auth_method:
|
||||
# Restore the row if it was previously soft-deleted (re-linking after admin unlink)
|
||||
auth_method.deleted_at = None
|
||||
auth_method.provider_data = _encrypt_provider_data(tokens, user_info)
|
||||
auth_method.verified = user_info.get("email_verified", False)
|
||||
auth_method.last_used_at = datetime.utcnow()
|
||||
auth_method.save()
|
||||
else:
|
||||
auth_method = AuthenticationMethod(
|
||||
user_id=user.id,
|
||||
method_type=provider_type,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
provider_data=_encrypt_provider_data(tokens, user_info),
|
||||
verified=user_info.get("email_verified", False),
|
||||
is_primary=False,
|
||||
last_used_at=datetime.utcnow(),
|
||||
)
|
||||
auth_method.save()
|
||||
|
||||
state_record.mark_used()
|
||||
|
||||
AuditService.log_external_auth_link_completed(
|
||||
user_id=user.id,
|
||||
organization_id=state_record.organization_id,
|
||||
provider_type=provider_type_str,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
auth_method_id=auth_method.id,
|
||||
)
|
||||
|
||||
return auth_method
|
||||
|
||||
|
||||
def authenticate_with_provider(
|
||||
get_provider_config,
|
||||
provider_type: AuthMethodType,
|
||||
organization_id: str,
|
||||
authorization_code: str,
|
||||
state: str,
|
||||
redirect_uri: str,
|
||||
) -> Tuple[User, dict]:
|
||||
from gatehouse_app.services.external_auth._helpers import (
|
||||
_exchange_code,
|
||||
_get_user_info,
|
||||
_encrypt_provider_data,
|
||||
)
|
||||
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
state_record = OAuthState.query.filter_by(state=state).first()
|
||||
if not state_record or not state_record.is_valid():
|
||||
AuditService.log_external_auth_login_failed(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type_str,
|
||||
failure_reason="invalid_state",
|
||||
error_message="Invalid or expired OAuth state",
|
||||
)
|
||||
raise ExternalAuthError("Invalid or expired OAuth state", "INVALID_STATE", 400)
|
||||
|
||||
config = get_provider_config(provider_type, organization_id)
|
||||
|
||||
tokens = _exchange_code(
|
||||
config=config,
|
||||
code=authorization_code,
|
||||
redirect_uri=redirect_uri,
|
||||
code_verifier=state_record.code_verifier,
|
||||
)
|
||||
|
||||
user_info = _get_user_info(config=config, access_token=tokens["access_token"])
|
||||
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
method_type=provider_type,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
).first()
|
||||
|
||||
if not auth_method:
|
||||
existing_user = User.query.filter_by(email=user_info["email"]).first()
|
||||
|
||||
if existing_user:
|
||||
AuditService.log_external_auth_login_failed(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type_str,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
email=user_info["email"],
|
||||
failure_reason="email_exists",
|
||||
error_message=f"An account with email {user_info['email']} already exists",
|
||||
)
|
||||
raise ExternalAuthError(
|
||||
f"An account with email {user_info['email']} already exists. "
|
||||
"Please log in with your password and link your Google account from settings.",
|
||||
"EMAIL_EXISTS",
|
||||
400,
|
||||
)
|
||||
|
||||
AuditService.log_external_auth_login_failed(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type_str,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
email=user_info["email"],
|
||||
failure_reason="account_not_found",
|
||||
error_message="No Gatehouse account matches this external account",
|
||||
)
|
||||
raise ExternalAuthError(
|
||||
"No Gatehouse account matches this external account. Please register first.",
|
||||
"ACCOUNT_NOT_FOUND",
|
||||
400,
|
||||
)
|
||||
|
||||
user = auth_method.user
|
||||
auth_method.provider_data = _encrypt_provider_data(tokens, user_info)
|
||||
auth_method.last_used_at = datetime.utcnow()
|
||||
auth_method.save()
|
||||
|
||||
state_record.mark_used()
|
||||
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
session = AuthService.create_session(user=user, organization_id=organization_id)
|
||||
|
||||
AuditService.log_external_auth_login(
|
||||
user_id=user.id,
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type_str,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
auth_method_id=auth_method.id,
|
||||
session_id=session.id,
|
||||
)
|
||||
|
||||
return user, session.to_dict()
|
||||
|
||||
|
||||
def unlink_provider(
|
||||
user_id: str,
|
||||
provider_type: AuthMethodType,
|
||||
organization_id: str = None,
|
||||
) -> bool:
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
user_id=user_id,
|
||||
method_type=provider_type,
|
||||
).first()
|
||||
|
||||
if not auth_method:
|
||||
raise ExternalAuthError("Provider not linked", "PROVIDER_NOT_LINKED", 400)
|
||||
|
||||
other_methods = AuthenticationMethod.query.filter_by(user_id=user_id).count()
|
||||
if other_methods <= 1:
|
||||
raise ExternalAuthError(
|
||||
"Cannot unlink the last authentication method",
|
||||
"CANNOT_UNLINK_LAST",
|
||||
400,
|
||||
)
|
||||
|
||||
provider_user_id = auth_method.provider_user_id
|
||||
auth_method_id = auth_method.id
|
||||
auth_method.delete()
|
||||
|
||||
AuditService.log_external_auth_unlink(
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type_str,
|
||||
provider_user_id=provider_user_id,
|
||||
auth_method_id=auth_method_id,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_linked_accounts(user_id: str) -> list:
|
||||
from gatehouse_app.utils.constants import AuthMethodType as AMT
|
||||
|
||||
methods = AuthenticationMethod.query.filter_by(user_id=user_id, deleted_at=None).all()
|
||||
|
||||
external_providers = [AMT.GOOGLE, AMT.GITHUB, AMT.MICROSOFT]
|
||||
|
||||
return [
|
||||
{
|
||||
"id": m.id,
|
||||
"provider_type": m.method_type.value if hasattr(m.method_type, 'value') else str(m.method_type),
|
||||
"provider_user_id": m.provider_user_id,
|
||||
"email": m.provider_data.get("email") if m.provider_data else None,
|
||||
"name": m.provider_data.get("name") if m.provider_data else None,
|
||||
"picture": m.provider_data.get("picture") if m.provider_data else None,
|
||||
"verified": m.verified,
|
||||
"linked_at": m.created_at.isoformat() if m.created_at else None,
|
||||
"last_used_at": m.last_used_at.isoformat() if m.last_used_at else None,
|
||||
}
|
||||
for m in methods
|
||||
if m.method_type in external_providers
|
||||
or str(m.method_type) in [p.value for p in external_providers]
|
||||
]
|
||||
@@ -0,0 +1,173 @@
|
||||
"""External auth models and adapter classes."""
|
||||
from typing import Optional
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.models.auth.authentication_method import (
|
||||
ApplicationProviderConfig,
|
||||
OrganizationProviderOverride,
|
||||
)
|
||||
|
||||
|
||||
class ExternalAuthError(Exception):
|
||||
"""Base exception for external auth errors."""
|
||||
|
||||
def __init__(self, message: str, error_type: str, status_code: int = 400):
|
||||
self.message = message
|
||||
self.error_type = error_type
|
||||
self.status_code = status_code
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ExternalProviderConfig(BaseModel):
|
||||
"""OAuth provider configuration per organization.
|
||||
|
||||
DEPRECATED: This model is maintained for backward compatibility only.
|
||||
Use ApplicationProviderConfig and OrganizationProviderOverride instead.
|
||||
"""
|
||||
|
||||
__tablename__ = "external_provider_configs"
|
||||
|
||||
organization_id = db.Column(
|
||||
db.String(36), db.ForeignKey("organizations.id"), nullable=False, index=True
|
||||
)
|
||||
provider_type = db.Column(db.String(50), nullable=False, index=True)
|
||||
client_id = db.Column(db.String(255), nullable=False)
|
||||
client_secret_encrypted = db.Column(db.String(512), nullable=True)
|
||||
auth_url = db.Column(db.String(2048), nullable=False)
|
||||
token_url = db.Column(db.String(2048), nullable=False)
|
||||
userinfo_url = db.Column(db.String(2048), nullable=True)
|
||||
jwks_url = db.Column(db.String(2048), nullable=True)
|
||||
scopes = db.Column(db.JSON, nullable=False, default=list)
|
||||
redirect_uris = db.Column(db.JSON, nullable=False, default=list)
|
||||
settings = db.Column(db.JSON, nullable=True)
|
||||
is_active = db.Column(db.Boolean, default=True, nullable=False)
|
||||
|
||||
organization = db.relationship(
|
||||
"Organization", back_populates="external_provider_configs"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
db.Index("idx_provider_config_org", "organization_id", "provider_type"),
|
||||
db.UniqueConstraint(
|
||||
"organization_id",
|
||||
"provider_type",
|
||||
name="uix_org_provider_type",
|
||||
),
|
||||
)
|
||||
|
||||
def get_client_secret(self) -> str:
|
||||
from gatehouse_app.utils.encryption import decrypt
|
||||
if self.client_secret_encrypted:
|
||||
return decrypt(self.client_secret_encrypted)
|
||||
return None
|
||||
|
||||
def set_client_secret(self, secret: str):
|
||||
from gatehouse_app.utils.encryption import encrypt
|
||||
self.client_secret_encrypted = encrypt(secret)
|
||||
|
||||
def is_redirect_uri_allowed(self, uri: str) -> bool:
|
||||
return uri in (self.redirect_uris or [])
|
||||
|
||||
def to_dict(self, include_secrets: bool = False) -> dict:
|
||||
data = {
|
||||
"id": self.id,
|
||||
"organization_id": self.organization_id,
|
||||
"provider_type": self.provider_type,
|
||||
"client_id": self.client_id,
|
||||
"auth_url": self.auth_url,
|
||||
"token_url": self.token_url,
|
||||
"userinfo_url": self.userinfo_url,
|
||||
"scopes": self.scopes,
|
||||
"redirect_uris": self.redirect_uris,
|
||||
"is_active": self.is_active,
|
||||
"settings": self.settings,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
if include_secrets and self.client_secret_encrypted:
|
||||
data["client_secret"] = self.get_client_secret()
|
||||
return data
|
||||
|
||||
|
||||
class ProviderConfigAdapter:
|
||||
"""Unified interface for provider configuration.
|
||||
|
||||
Merges application-level config with optional organization overrides.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_config: ApplicationProviderConfig,
|
||||
org_override: Optional[OrganizationProviderOverride] = None,
|
||||
):
|
||||
self.app_config = app_config
|
||||
self.org_override = org_override
|
||||
self.provider_type = app_config.provider_type
|
||||
|
||||
@property
|
||||
def client_id(self) -> str:
|
||||
if self.org_override and self.org_override.client_id:
|
||||
return self.org_override.client_id
|
||||
return self.app_config.client_id
|
||||
|
||||
def get_client_secret(self) -> str:
|
||||
if self.org_override and self.org_override.client_secret_encrypted:
|
||||
return self.org_override.get_client_secret()
|
||||
return self.app_config.get_client_secret()
|
||||
|
||||
@property
|
||||
def auth_url(self) -> str:
|
||||
return self._get_provider_endpoint('auth_url')
|
||||
|
||||
@property
|
||||
def token_url(self) -> str:
|
||||
return self._get_provider_endpoint('token_url')
|
||||
|
||||
@property
|
||||
def userinfo_url(self) -> str:
|
||||
return self._get_provider_endpoint('userinfo_url')
|
||||
|
||||
@property
|
||||
def jwks_url(self) -> str:
|
||||
return self._get_provider_endpoint('jwks_url')
|
||||
|
||||
@property
|
||||
def scopes(self) -> list:
|
||||
base_scopes = self.app_config.additional_config.get('scopes', []) if self.app_config.additional_config else []
|
||||
if self.org_override and self.org_override.additional_config:
|
||||
override_scopes = self.org_override.additional_config.get('scopes')
|
||||
if override_scopes is not None:
|
||||
return override_scopes
|
||||
return base_scopes or ['openid', 'profile', 'email']
|
||||
|
||||
@property
|
||||
def redirect_uris(self) -> list:
|
||||
if self.org_override and self.org_override.redirect_url_override:
|
||||
return [self.org_override.redirect_url_override]
|
||||
if self.app_config.default_redirect_url:
|
||||
return [self.app_config.default_redirect_url]
|
||||
return []
|
||||
|
||||
@property
|
||||
def settings(self) -> dict:
|
||||
settings = {}
|
||||
if self.app_config.additional_config:
|
||||
settings.update(self.app_config.additional_config)
|
||||
if self.org_override and self.org_override.additional_config:
|
||||
settings.update(self.org_override.additional_config)
|
||||
return settings
|
||||
|
||||
@property
|
||||
def is_active(self) -> bool:
|
||||
app_enabled = self.app_config.is_enabled
|
||||
org_enabled = True if not self.org_override else self.org_override.is_enabled
|
||||
return app_enabled and org_enabled
|
||||
|
||||
def is_redirect_uri_allowed(self, uri: str) -> bool:
|
||||
return uri in self.redirect_uris
|
||||
|
||||
def _get_provider_endpoint(self, endpoint_name: str) -> Optional[str]:
|
||||
if not self.app_config.additional_config:
|
||||
return None
|
||||
return self.app_config.additional_config.get(endpoint_name)
|
||||
@@ -0,0 +1,147 @@
|
||||
"""Organization-specific provider override management."""
|
||||
import logging
|
||||
|
||||
from gatehouse_app.models.auth.authentication_method import (
|
||||
ApplicationProviderConfig,
|
||||
OrganizationProviderOverride,
|
||||
)
|
||||
from gatehouse_app.services.external_auth.models import ExternalAuthError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_org_provider_override(
|
||||
organization_id: str,
|
||||
provider_type: str,
|
||||
**kwargs,
|
||||
) -> OrganizationProviderOverride:
|
||||
app_config = ApplicationProviderConfig.query.filter_by(
|
||||
provider_type=provider_type
|
||||
).first()
|
||||
|
||||
if not app_config:
|
||||
raise ExternalAuthError(
|
||||
f"Application provider {provider_type} must be configured first",
|
||||
"PROVIDER_NOT_CONFIGURED",
|
||||
400,
|
||||
)
|
||||
|
||||
existing = OrganizationProviderOverride.query.filter_by(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type,
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
raise ExternalAuthError(
|
||||
f"Override for {provider_type} already exists for this organization",
|
||||
"OVERRIDE_EXISTS",
|
||||
400,
|
||||
)
|
||||
|
||||
additional_config = {}
|
||||
if 'settings' in kwargs:
|
||||
additional_config.update(kwargs.pop('settings'))
|
||||
if 'scopes' in kwargs:
|
||||
additional_config['scopes'] = kwargs.pop('scopes')
|
||||
|
||||
override = OrganizationProviderOverride(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type,
|
||||
client_id=kwargs.get('client_id'),
|
||||
is_enabled=kwargs.get('is_enabled', True),
|
||||
redirect_url_override=kwargs.get('redirect_url_override'),
|
||||
additional_config=additional_config if additional_config else None,
|
||||
)
|
||||
|
||||
if 'client_secret' in kwargs:
|
||||
override.set_client_secret(kwargs['client_secret'])
|
||||
|
||||
override.save()
|
||||
logger.info(f"Created org override for {provider_type} in org {organization_id}")
|
||||
return override
|
||||
|
||||
|
||||
def update_org_provider_override(
|
||||
organization_id: str,
|
||||
provider_type: str,
|
||||
**updates,
|
||||
) -> OrganizationProviderOverride:
|
||||
override = OrganizationProviderOverride.query.filter_by(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type,
|
||||
).first()
|
||||
|
||||
if not override:
|
||||
raise ExternalAuthError(
|
||||
f"Override for {provider_type} not found for this organization",
|
||||
"OVERRIDE_NOT_FOUND",
|
||||
404,
|
||||
)
|
||||
|
||||
if 'client_id' in updates:
|
||||
override.client_id = updates['client_id']
|
||||
|
||||
if 'client_secret' in updates:
|
||||
override.set_client_secret(updates['client_secret'])
|
||||
|
||||
if 'is_enabled' in updates:
|
||||
override.is_enabled = updates['is_enabled']
|
||||
|
||||
if 'redirect_url_override' in updates:
|
||||
override.redirect_url_override = updates['redirect_url_override']
|
||||
|
||||
if 'settings' in updates or 'scopes' in updates:
|
||||
if override.additional_config is None:
|
||||
override.additional_config = {}
|
||||
if 'settings' in updates:
|
||||
override.additional_config.update(updates['settings'])
|
||||
if 'scopes' in updates:
|
||||
override.additional_config['scopes'] = updates['scopes']
|
||||
|
||||
override.save()
|
||||
logger.info(f"Updated org override for {provider_type} in org {organization_id}")
|
||||
return override
|
||||
|
||||
|
||||
def get_org_provider_override(
|
||||
organization_id: str,
|
||||
provider_type: str,
|
||||
) -> OrganizationProviderOverride:
|
||||
override = OrganizationProviderOverride.query.filter_by(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type,
|
||||
).first()
|
||||
|
||||
if not override:
|
||||
raise ExternalAuthError(
|
||||
f"Override for {provider_type} not found for this organization",
|
||||
"OVERRIDE_NOT_FOUND",
|
||||
404,
|
||||
)
|
||||
|
||||
return override
|
||||
|
||||
|
||||
def list_org_provider_overrides(organization_id: str) -> list:
|
||||
overrides = OrganizationProviderOverride.query.filter_by(
|
||||
organization_id=organization_id
|
||||
).all()
|
||||
return [override.to_dict() for override in overrides]
|
||||
|
||||
|
||||
def delete_org_provider_override(organization_id: str, provider_type: str) -> bool:
|
||||
override = OrganizationProviderOverride.query.filter_by(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type,
|
||||
).first()
|
||||
|
||||
if not override:
|
||||
raise ExternalAuthError(
|
||||
f"Override for {provider_type} not found for this organization",
|
||||
"OVERRIDE_NOT_FOUND",
|
||||
404,
|
||||
)
|
||||
|
||||
override.delete()
|
||||
logger.info(f"Deleted org override for {provider_type} in org {organization_id}")
|
||||
return True
|
||||
File diff suppressed because it is too large
Load Diff
@@ -295,6 +295,7 @@ Gatehouse Security Team
|
||||
|
||||
Returns True if the email was sent successfully, False otherwise.
|
||||
If EMAIL_ENABLED is False, logs the email body instead (simulation mode).
|
||||
All SMTP exceptions are caught and logged — this method never raises.
|
||||
"""
|
||||
import smtplib
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
@@ -310,17 +311,37 @@ Gatehouse Security Team
|
||||
)
|
||||
return False
|
||||
|
||||
smtp_host = current_app.config.get(NotificationService.SMTP_HOST_KEY, "localhost")
|
||||
smtp_port = int(current_app.config.get(NotificationService.SMTP_PORT_KEY, 587))
|
||||
smtp_host = current_app.config.get(NotificationService.SMTP_HOST_KEY, "")
|
||||
smtp_port_raw = current_app.config.get(NotificationService.SMTP_PORT_KEY, 587)
|
||||
smtp_username = current_app.config.get(NotificationService.SMTP_USERNAME_KEY)
|
||||
smtp_password = current_app.config.get(NotificationService.SMTP_PASSWORD_KEY)
|
||||
from_address = current_app.config.get(
|
||||
NotificationService.FROM_ADDRESS_KEY, ""
|
||||
)
|
||||
|
||||
# Guard: refuse to attempt a connection when critical config is missing.
|
||||
# This surfaces a clear log message instead of a confusing socket error.
|
||||
missing = [k for k, v in [
|
||||
("SMTP_HOST", smtp_host),
|
||||
("FROM_ADDRESS", from_address),
|
||||
] if not v]
|
||||
if missing:
|
||||
logger.error(
|
||||
f"[EMAIL] Cannot send — missing config: {', '.join(missing)}. "
|
||||
f"Would have sent to: {to_address} | Subject: {subject}"
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
smtp_port = int(smtp_port_raw)
|
||||
except (TypeError, ValueError):
|
||||
logger.error(f"[EMAIL] Invalid SMTP_PORT value: {smtp_port_raw!r}")
|
||||
return False
|
||||
|
||||
smtp_use_tls = current_app.config.get(
|
||||
NotificationService.SMTP_USE_TLS_KEY,
|
||||
smtp_port not in (25, 1025),
|
||||
)
|
||||
from_address = current_app.config.get(
|
||||
NotificationService.FROM_ADDRESS_KEY, "noreply@gatehouse.local"
|
||||
)
|
||||
|
||||
try:
|
||||
msg = MIMEMultipart("alternative")
|
||||
|
||||
@@ -0,0 +1,209 @@
|
||||
"""OAuthFlowService — public facade and handle_callback dispatcher."""
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from gatehouse_app.models.auth.authentication_method import OAuthState
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
from gatehouse_app.services.external_auth import ExternalAuthService
|
||||
from gatehouse_app.services.external_auth.models import ExternalAuthError
|
||||
|
||||
from gatehouse_app.services.oauth_flow.login import OAuthFlowError, initiate_login_flow, handle_login_callback
|
||||
from gatehouse_app.services.oauth_flow.register import initiate_register_flow, handle_register_callback
|
||||
from gatehouse_app.services.oauth_flow.code import (
|
||||
generate_authorization_code,
|
||||
exchange_authorization_code,
|
||||
create_redirect_response,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthFlowService:
|
||||
"""Service for managing OAuth authentication flows."""
|
||||
|
||||
@classmethod
|
||||
def initiate_login_flow(
|
||||
cls,
|
||||
provider_type: AuthMethodType,
|
||||
organization_id: str = None,
|
||||
redirect_uri: str = None,
|
||||
state_data: dict = None,
|
||||
) -> Tuple[str, str]:
|
||||
return initiate_login_flow(provider_type, organization_id, redirect_uri, state_data)
|
||||
|
||||
@classmethod
|
||||
def initiate_register_flow(
|
||||
cls,
|
||||
provider_type: AuthMethodType,
|
||||
organization_id: str = None,
|
||||
redirect_uri: str = None,
|
||||
) -> Tuple[str, str]:
|
||||
return initiate_register_flow(provider_type, organization_id, redirect_uri)
|
||||
|
||||
@classmethod
|
||||
def handle_callback(
|
||||
cls,
|
||||
provider_type: AuthMethodType,
|
||||
authorization_code: str,
|
||||
state: str,
|
||||
redirect_uri: str = None,
|
||||
error: str = None,
|
||||
error_description: str = None,
|
||||
) -> dict:
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
try:
|
||||
from flask import request
|
||||
ip_address = request.remote_addr if request else None
|
||||
user_agent = request.headers.get("User-Agent") if request else None
|
||||
except RuntimeError:
|
||||
ip_address = None
|
||||
user_agent = None
|
||||
|
||||
if error:
|
||||
AuditService.log_external_auth_login_failed(
|
||||
organization_id=None,
|
||||
provider_type=provider_type_str,
|
||||
failure_reason=error,
|
||||
error_message=error_description or error,
|
||||
)
|
||||
raise OAuthFlowError(
|
||||
error_description or f"OAuth error: {error}",
|
||||
error.upper() if error else "OAUTH_ERROR",
|
||||
400,
|
||||
)
|
||||
|
||||
state_record = OAuthState.query.filter_by(state=state).first()
|
||||
|
||||
if state_record:
|
||||
logger.debug(
|
||||
f"State validation: found=True, used={state_record.used}, "
|
||||
f"expires_at={state_record.expires_at}, is_valid={state_record.is_valid()}"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"State validation: state token not found in database: {state}")
|
||||
|
||||
if not state_record or not state_record.is_valid():
|
||||
AuditService.log_external_auth_login_failed(
|
||||
organization_id=state_record.organization_id if state_record else None,
|
||||
provider_type=provider_type_str,
|
||||
failure_reason="invalid_state",
|
||||
error_message="Invalid or expired OAuth state",
|
||||
)
|
||||
raise OAuthFlowError("Invalid or expired OAuth state", "INVALID_STATE", 400)
|
||||
|
||||
effective_redirect = redirect_uri or state_record.redirect_uri
|
||||
|
||||
if state_record.flow_type == "login":
|
||||
return handle_login_callback(
|
||||
provider_type=provider_type,
|
||||
state_record=state_record,
|
||||
authorization_code=authorization_code,
|
||||
redirect_uri=effective_redirect,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
elif state_record.flow_type == "link":
|
||||
return cls._handle_link_callback(
|
||||
provider_type=provider_type,
|
||||
state_record=state_record,
|
||||
authorization_code=authorization_code,
|
||||
redirect_uri=effective_redirect,
|
||||
)
|
||||
elif state_record.flow_type == "register":
|
||||
return handle_register_callback(
|
||||
provider_type=provider_type,
|
||||
state_record=state_record,
|
||||
authorization_code=authorization_code,
|
||||
redirect_uri=effective_redirect,
|
||||
)
|
||||
else:
|
||||
raise OAuthFlowError(
|
||||
f"Unknown flow type: {state_record.flow_type}",
|
||||
"INVALID_FLOW_TYPE",
|
||||
400,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _handle_link_callback(
|
||||
cls,
|
||||
provider_type: AuthMethodType,
|
||||
state_record: OAuthState,
|
||||
authorization_code: str,
|
||||
redirect_uri: str,
|
||||
) -> dict:
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
try:
|
||||
auth_method = ExternalAuthService.complete_link_flow(
|
||||
provider_type=provider_type,
|
||||
authorization_code=authorization_code,
|
||||
state=state_record.state,
|
||||
redirect_uri=redirect_uri,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth link successful for user={state_record.user_id}, "
|
||||
f"provider={provider_type_str}, auth_method_id={auth_method.id}"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"flow_type": "link",
|
||||
"linked_account": {
|
||||
"id": auth_method.id,
|
||||
"provider_type": provider_type_str,
|
||||
"provider_user_id": auth_method.provider_user_id,
|
||||
"verified": auth_method.verified,
|
||||
},
|
||||
}
|
||||
|
||||
except ExternalAuthError as e:
|
||||
logger.warning(
|
||||
f"OAuth link failed for state={state_record.id}, "
|
||||
f"provider={provider_type_str}, error={e.message}"
|
||||
)
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def validate_state(cls, state: str) -> Optional[OAuthState]:
|
||||
state_record = OAuthState.query.filter_by(state=state).first()
|
||||
if state_record and state_record.is_valid():
|
||||
return state_record
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def cleanup_expired_states(cls):
|
||||
OAuthState.cleanup_expired()
|
||||
logger.info("Expired OAuth states cleaned up")
|
||||
|
||||
@classmethod
|
||||
def generate_authorization_code(
|
||||
cls,
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
scope: list = None,
|
||||
nonce: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
lifetime_seconds: int = 600,
|
||||
) -> str:
|
||||
return generate_authorization_code(
|
||||
user_id, client_id, redirect_uri, scope, nonce, ip_address, user_agent, lifetime_seconds
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def exchange_authorization_code(
|
||||
cls,
|
||||
code: str,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
ip_address: str = None,
|
||||
) -> dict:
|
||||
return exchange_authorization_code(code, client_id, redirect_uri, ip_address)
|
||||
|
||||
@classmethod
|
||||
def create_redirect_response(cls, redirect_uri: str, authorization_code: str, state: str = None):
|
||||
return create_redirect_response(redirect_uri, authorization_code, state)
|
||||
@@ -0,0 +1,141 @@
|
||||
"""Authorization code generation, exchange, and redirect helpers."""
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode
|
||||
from gatehouse_app.services.oauth_flow.login import OAuthFlowError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_authorization_code(
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
scope: list = None,
|
||||
nonce: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
lifetime_seconds: int = 600,
|
||||
) -> str:
|
||||
code = secrets.token_urlsafe(32)
|
||||
code_hash = hashlib.sha256(code.encode()).hexdigest()
|
||||
|
||||
OIDCAuthCode.create_code(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
code_hash=code_hash,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
nonce=nonce,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
lifetime_seconds=lifetime_seconds,
|
||||
)
|
||||
|
||||
logger.info(f"Generated authorization code for user={user_id}, client={client_id}")
|
||||
return code
|
||||
|
||||
|
||||
def exchange_authorization_code(
|
||||
code: str,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
ip_address: str = None,
|
||||
) -> dict:
|
||||
code_hash = hashlib.sha256(code.encode()).hexdigest()
|
||||
|
||||
auth_code = OIDCAuthCode.query.filter_by(
|
||||
client_id=client_id,
|
||||
code_hash=code_hash,
|
||||
).first()
|
||||
|
||||
if not auth_code:
|
||||
raise OAuthFlowError("Invalid authorization code", "INVALID_CODE", 400)
|
||||
|
||||
if not auth_code.is_valid():
|
||||
if auth_code.is_used:
|
||||
raise OAuthFlowError(
|
||||
"Authorization code has already been used", "CODE_USED", 400
|
||||
)
|
||||
else:
|
||||
raise OAuthFlowError("Authorization code has expired", "CODE_EXPIRED", 400)
|
||||
|
||||
if auth_code.redirect_uri != redirect_uri:
|
||||
raise OAuthFlowError("Redirect URI mismatch", "INVALID_REDIRECT_URI", 400)
|
||||
|
||||
from gatehouse_app.models import User
|
||||
user = User.query.get(auth_code.user_id)
|
||||
if not user:
|
||||
raise OAuthFlowError("User not found", "USER_NOT_FOUND", 404)
|
||||
|
||||
user_orgs = user.get_organizations()
|
||||
target_org = None
|
||||
if len(user_orgs) == 1:
|
||||
target_org = user_orgs[0]
|
||||
|
||||
if not target_org:
|
||||
raise OAuthFlowError(
|
||||
"User does not have a default organization. Organization selection required.",
|
||||
"ORG_SELECTION_REQUIRED",
|
||||
400,
|
||||
)
|
||||
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
session = AuthService.create_session(user=user, is_compliance_only=False)
|
||||
auth_code.mark_as_used()
|
||||
|
||||
session_dict = session.to_dict()
|
||||
session_dict["token"] = session.token
|
||||
expires_at = session.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
now = datetime.now(timezone.utc)
|
||||
session_dict["expires_in"] = int((expires_at - now).total_seconds())
|
||||
|
||||
logger.info(
|
||||
f"Authorization code exchanged for session: user={user.id}, "
|
||||
f"org_id={target_org.id}, client={client_id}"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"token": session_dict["token"],
|
||||
"expires_in": session_dict["expires_in"],
|
||||
"token_type": "Bearer",
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"organization_id": target_org.id,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def create_redirect_response(
|
||||
redirect_uri: str,
|
||||
authorization_code: str,
|
||||
state: str = None,
|
||||
):
|
||||
from urllib.parse import urlencode, urlparse, urlunparse
|
||||
from flask import redirect
|
||||
|
||||
parsed = urlparse(redirect_uri)
|
||||
params = {"code": authorization_code}
|
||||
if state:
|
||||
params["state"] = state
|
||||
|
||||
redirect_url = urlunparse((
|
||||
parsed.scheme,
|
||||
parsed.netloc,
|
||||
parsed.path,
|
||||
parsed.params,
|
||||
urlencode(params),
|
||||
parsed.fragment,
|
||||
))
|
||||
|
||||
logger.info(f"Redirecting to {parsed.scheme}://{parsed.netloc} with authorization code")
|
||||
return redirect(redirect_url)
|
||||
@@ -0,0 +1,410 @@
|
||||
"""Login flow: initiate and handle OAuth login callback."""
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from gatehouse_app.models import User, AuthenticationMethod
|
||||
from gatehouse_app.models.auth.authentication_method import OAuthState
|
||||
from gatehouse_app.utils.constants import AuthMethodType, AuditAction
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
from gatehouse_app.services.external_auth import ExternalAuthService
|
||||
from gatehouse_app.services.external_auth.models import ExternalAuthError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthFlowError(Exception):
|
||||
def __init__(self, message: str, error_type: str, status_code: int = 400):
|
||||
self.message = message
|
||||
self.error_type = error_type
|
||||
self.status_code = status_code
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def initiate_login_flow(
|
||||
provider_type: AuthMethodType,
|
||||
organization_id: str = None,
|
||||
redirect_uri: str = None,
|
||||
state_data: dict = None,
|
||||
) -> Tuple[str, str]:
|
||||
try:
|
||||
from flask import request
|
||||
except Exception:
|
||||
request = None
|
||||
|
||||
try:
|
||||
ip_address = request.remote_addr if request else None
|
||||
user_agent = request.headers.get("User-Agent") if request else None
|
||||
except RuntimeError:
|
||||
ip_address = None
|
||||
user_agent = None
|
||||
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
try:
|
||||
config = ExternalAuthService.get_provider_config(provider_type, organization_id)
|
||||
|
||||
if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri):
|
||||
raise OAuthFlowError("Invalid redirect URI", "INVALID_REDIRECT_URI", 400)
|
||||
|
||||
code_verifier = None
|
||||
code_challenge = None
|
||||
if provider_type_str not in ['google', 'microsoft']:
|
||||
code_verifier = secrets.token_urlsafe(32)
|
||||
code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier)
|
||||
|
||||
logger.info(
|
||||
f"[PKCE DEBUG] Provider type check: provider_type_str='{provider_type_str}', "
|
||||
f"is_google={provider_type_str in ['google']}, "
|
||||
f"will_skip_pkce={provider_type_str in ['google', 'microsoft']}"
|
||||
)
|
||||
|
||||
state = OAuthState.create_state(
|
||||
flow_type="login",
|
||||
provider_type=provider_type,
|
||||
organization_id=organization_id,
|
||||
redirect_uri=redirect_uri or (config.redirect_uris[0] if config.redirect_uris else None),
|
||||
code_verifier=code_verifier,
|
||||
code_challenge=code_challenge,
|
||||
extra_data=state_data,
|
||||
lifetime_seconds=600,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[PKCE DEBUG] Created OAuthState object:\n"
|
||||
f" state.id: {state.id}\n"
|
||||
f" state.provider_type: {state.provider_type}\n"
|
||||
f" state.code_challenge: {state.code_challenge}\n"
|
||||
f" state.code_verifier: {state.code_verifier[:20] if state.code_verifier else None}..."
|
||||
)
|
||||
|
||||
auth_url = ExternalAuthService._build_authorization_url(config=config, state=state)
|
||||
|
||||
logger.info(
|
||||
f"OAuth login flow initiated for provider={provider_type_str}, "
|
||||
f"org_id={organization_id}, state_token={state.state}, state_record_id={state.id}"
|
||||
)
|
||||
logger.info(
|
||||
f"[PKCE DEBUG] FINAL CHECK: code_challenge={code_challenge}, "
|
||||
f"code_verifier={code_verifier[:20] if code_verifier else None}..., "
|
||||
f"auth_url_has_challenge={'code_challenge=' in auth_url}, "
|
||||
f"returned_auth_url={auth_url}"
|
||||
)
|
||||
|
||||
return auth_url, state.state
|
||||
|
||||
except ExternalAuthError as e:
|
||||
AuditService.log_action(
|
||||
action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED,
|
||||
organization_id=organization_id,
|
||||
metadata={
|
||||
"provider_type": provider_type_str,
|
||||
"failure_reason": e.error_type,
|
||||
"ip_address": ip_address,
|
||||
},
|
||||
description=f"OAuth login initiation failed: {e.message}",
|
||||
success=False,
|
||||
error_message=e.message,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def handle_login_callback(
|
||||
provider_type: AuthMethodType,
|
||||
state_record: OAuthState,
|
||||
authorization_code: str,
|
||||
redirect_uri: str,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
) -> dict:
|
||||
from gatehouse_app.services.external_auth._helpers import _encrypt_provider_data
|
||||
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
try:
|
||||
config = ExternalAuthService.get_provider_config(
|
||||
provider_type, state_record.organization_id
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Exchanging code with PKCE: state_record.code_verifier="
|
||||
f"{state_record.code_verifier[:20] if state_record.code_verifier else None}..."
|
||||
)
|
||||
|
||||
tokens = ExternalAuthService._exchange_code(
|
||||
config=config,
|
||||
code=authorization_code,
|
||||
redirect_uri=redirect_uri,
|
||||
code_verifier=state_record.code_verifier,
|
||||
)
|
||||
|
||||
user_info = ExternalAuthService._get_user_info(
|
||||
config=config,
|
||||
access_token=tokens["access_token"],
|
||||
)
|
||||
|
||||
if not user_info.get("provider_user_id"):
|
||||
raise OAuthFlowError(
|
||||
"Provider did not return a user identifier (sub claim). "
|
||||
"Cannot complete authentication.",
|
||||
"MISSING_PROVIDER_USER_ID",
|
||||
400,
|
||||
)
|
||||
|
||||
if not user_info.get("email"):
|
||||
raise OAuthFlowError(
|
||||
"Provider did not return an email address. "
|
||||
"Cannot complete authentication.",
|
||||
"MISSING_EMAIL",
|
||||
400,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Got user_info from provider: sub={user_info['provider_user_id']}, "
|
||||
f"email={user_info['email']}, email_verified={user_info.get('email_verified')}"
|
||||
)
|
||||
|
||||
# Find the active auth method for this provider identity.
|
||||
# Order by created_at DESC so that an explicitly linked (newer) row wins
|
||||
# over an older auto-created primary row when the same Google identity
|
||||
# was linked to a second profile.
|
||||
auth_method = (
|
||||
AuthenticationMethod.query
|
||||
.filter_by(
|
||||
method_type=provider_type,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
deleted_at=None,
|
||||
)
|
||||
.order_by(AuthenticationMethod.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if not auth_method:
|
||||
deleted_method = (
|
||||
AuthenticationMethod.query
|
||||
.filter_by(
|
||||
method_type=provider_type,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
)
|
||||
.order_by(AuthenticationMethod.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if deleted_method:
|
||||
logger.info(
|
||||
f"OAuth login: restoring previously unlinked {provider_type_str} "
|
||||
f"auth method for user {deleted_method.user_id}"
|
||||
)
|
||||
deleted_method.deleted_at = None
|
||||
deleted_method.provider_data = _encrypt_provider_data(tokens, user_info)
|
||||
deleted_method.last_used_at = datetime.utcnow()
|
||||
deleted_method.save()
|
||||
auth_method = deleted_method
|
||||
|
||||
else:
|
||||
existing_user = User.query.filter_by(email=user_info["email"], deleted_at=None).first()
|
||||
|
||||
if existing_user:
|
||||
logger.info(
|
||||
f"OAuth login: email {user_info['email']} matches existing user "
|
||||
f"{existing_user.id}, auto-linking {provider_type_str} account"
|
||||
)
|
||||
auth_method = AuthenticationMethod(
|
||||
user_id=existing_user.id,
|
||||
method_type=provider_type,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
provider_data=_encrypt_provider_data(tokens, user_info),
|
||||
verified=user_info.get("email_verified", False),
|
||||
is_primary=False,
|
||||
last_used_at=datetime.utcnow(),
|
||||
)
|
||||
auth_method.save()
|
||||
user = existing_user
|
||||
else:
|
||||
logger.info(
|
||||
f"OAuth login: no account for {user_info['email']}, "
|
||||
f"auto-creating user via {provider_type_str}"
|
||||
)
|
||||
user = User(
|
||||
email=user_info["email"],
|
||||
full_name=user_info.get("name", ""),
|
||||
status="active",
|
||||
email_verified=user_info.get("email_verified", False),
|
||||
)
|
||||
user.save()
|
||||
|
||||
auth_method = AuthenticationMethod(
|
||||
user_id=user.id,
|
||||
method_type=provider_type,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
provider_data=_encrypt_provider_data(tokens, user_info),
|
||||
verified=user_info.get("email_verified", False),
|
||||
is_primary=True,
|
||||
last_used_at=datetime.utcnow(),
|
||||
)
|
||||
auth_method.save()
|
||||
|
||||
AuditService.log_action(
|
||||
action="user.register",
|
||||
user_id=user.id,
|
||||
organization_id=state_record.organization_id,
|
||||
resource_type="user",
|
||||
resource_id=user.id,
|
||||
metadata={
|
||||
"provider_type": provider_type_str,
|
||||
"provider_user_id": user_info["provider_user_id"],
|
||||
"auto_registered": True,
|
||||
},
|
||||
description=f"User auto-registered via {provider_type_str} OAuth",
|
||||
success=True,
|
||||
)
|
||||
else:
|
||||
auth_method.provider_data = _encrypt_provider_data(tokens, user_info)
|
||||
auth_method.last_used_at = datetime.utcnow()
|
||||
auth_method.save()
|
||||
|
||||
user = auth_method.user
|
||||
|
||||
user_orgs = user.get_organizations()
|
||||
target_org = None
|
||||
|
||||
if state_record.organization_id:
|
||||
target_org = next(
|
||||
(org for org in user_orgs if org.id == state_record.organization_id),
|
||||
None,
|
||||
)
|
||||
|
||||
if not target_org and len(user_orgs) == 1:
|
||||
target_org = user_orgs[0]
|
||||
|
||||
if not target_org and len(user_orgs) > 1:
|
||||
# Multiple orgs and none specified in the OAuth state — pick the one the
|
||||
# user joined most recently (highest created_at on their membership row).
|
||||
# Users can switch organisations inside the app after logging in.
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember as _OM
|
||||
latest_membership = (
|
||||
_OM.query
|
||||
.filter_by(user_id=user.id, deleted_at=None)
|
||||
.order_by(_OM.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
if latest_membership:
|
||||
target_org = latest_membership.organization
|
||||
else:
|
||||
target_org = user_orgs[0]
|
||||
|
||||
if not target_org and len(user_orgs) == 0:
|
||||
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
|
||||
from gatehouse_app.services.auth_service import AuthService as _AS
|
||||
_now = datetime.now(timezone.utc)
|
||||
_session = _AS.create_session(user=user, is_compliance_only=False)
|
||||
_session_dict = _session.to_dict()
|
||||
_session_dict["token"] = _session.token
|
||||
_expires_at = _session.expires_at
|
||||
if _expires_at.tzinfo is None:
|
||||
_expires_at = _expires_at.replace(tzinfo=timezone.utc)
|
||||
_session_dict["expires_in"] = int((_expires_at - _now).total_seconds())
|
||||
|
||||
_pending = OrgInviteToken.query.filter(
|
||||
OrgInviteToken.email == user.email,
|
||||
OrgInviteToken.accepted_at.is_(None),
|
||||
OrgInviteToken.expires_at > _now,
|
||||
OrgInviteToken.deleted_at.is_(None),
|
||||
).all()
|
||||
_pending_list = [
|
||||
{
|
||||
"token": inv.token,
|
||||
"organization": {"id": str(inv.organization_id), "name": inv.organization.name},
|
||||
"role": inv.role,
|
||||
"expires_at": inv.expires_at.isoformat(),
|
||||
}
|
||||
for inv in _pending
|
||||
]
|
||||
|
||||
state_record.mark_used()
|
||||
logger.info(
|
||||
f"OAuth login: user {user.id} has no org, redirecting to org-setup "
|
||||
f"(pending_invites={len(_pending_list)})"
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"flow_type": "login",
|
||||
"requires_org_creation": True,
|
||||
"user": {"id": user.id, "email": user.email, "full_name": user.full_name},
|
||||
"session": _session_dict,
|
||||
"pending_invites": _pending_list,
|
||||
"state": state_record.state,
|
||||
}
|
||||
|
||||
if not target_org:
|
||||
state_record.mark_used()
|
||||
logger.info(
|
||||
f"OAuth login requires org selection for user={user.id}, "
|
||||
f"provider={provider_type_str}, org_count={len(user_orgs)}"
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"flow_type": "login",
|
||||
"requires_org_selection": True,
|
||||
"user": {"id": user.id, "email": user.email, "full_name": user.full_name},
|
||||
"available_organizations": [
|
||||
{
|
||||
"id": org.id,
|
||||
"name": org.name,
|
||||
"slug": org.slug if hasattr(org, "slug") else None,
|
||||
}
|
||||
for org in user_orgs
|
||||
],
|
||||
"state": state_record.state,
|
||||
}
|
||||
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
session = AuthService.create_session(user=user, is_compliance_only=False)
|
||||
state_record.mark_used()
|
||||
|
||||
AuditService.log_external_auth_login(
|
||||
user_id=user.id,
|
||||
organization_id=target_org.id,
|
||||
provider_type=provider_type_str,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
auth_method_id=auth_method.id,
|
||||
session_id=session.id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth login successful for user={user.id}, "
|
||||
f"provider={provider_type_str}, org_id={target_org.id}"
|
||||
)
|
||||
|
||||
session_dict = session.to_dict()
|
||||
session_dict["token"] = session.token
|
||||
expires_at = session.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
now = datetime.now(timezone.utc)
|
||||
session_dict["expires_in"] = int((expires_at - now).total_seconds())
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"flow_type": "login",
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"organization_id": target_org.id,
|
||||
},
|
||||
"session": session_dict,
|
||||
}
|
||||
|
||||
except ExternalAuthError as e:
|
||||
logger.warning(
|
||||
f"OAuth login failed for state={state_record.id}, "
|
||||
f"provider={provider_type_str}, error={e.message}"
|
||||
)
|
||||
raise
|
||||
except OAuthFlowError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in OAuth login callback: {str(e)}", exc_info=True)
|
||||
raise OAuthFlowError("An unexpected error occurred during login", "INTERNAL_ERROR", 500)
|
||||
@@ -0,0 +1,248 @@
|
||||
"""Registration flow: initiate and handle OAuth register callback."""
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from gatehouse_app.models import User, AuthenticationMethod
|
||||
from gatehouse_app.models.auth.authentication_method import OAuthState
|
||||
from gatehouse_app.utils.constants import AuthMethodType, AuditAction
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
from gatehouse_app.services.external_auth import ExternalAuthService
|
||||
from gatehouse_app.services.external_auth.models import ExternalAuthError
|
||||
from gatehouse_app.services.oauth_flow.login import OAuthFlowError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def initiate_register_flow(
|
||||
provider_type: AuthMethodType,
|
||||
organization_id: str = None,
|
||||
redirect_uri: str = None,
|
||||
) -> Tuple[str, str]:
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
try:
|
||||
config = ExternalAuthService.get_provider_config(provider_type, organization_id)
|
||||
|
||||
if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri):
|
||||
raise OAuthFlowError("Invalid redirect URI", "INVALID_REDIRECT_URI", 400)
|
||||
|
||||
code_verifier = None
|
||||
code_challenge = None
|
||||
if provider_type_str not in ['google', 'microsoft']:
|
||||
code_verifier = secrets.token_urlsafe(32)
|
||||
code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier)
|
||||
|
||||
logger.info(
|
||||
f"[PKCE DEBUG] Register flow - Provider type check: provider_type_str='{provider_type_str}', "
|
||||
f"is_google={provider_type_str in ['google']}, "
|
||||
f"will_skip_pkce={provider_type_str in ['google', 'microsoft']}"
|
||||
)
|
||||
|
||||
state = OAuthState.create_state(
|
||||
flow_type="register",
|
||||
provider_type=provider_type,
|
||||
organization_id=organization_id,
|
||||
redirect_uri=redirect_uri or (config.redirect_uris[0] if config.redirect_uris else None),
|
||||
code_verifier=code_verifier,
|
||||
code_challenge=code_challenge,
|
||||
lifetime_seconds=600,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[PKCE DEBUG] Register flow - Created OAuthState:\n"
|
||||
f" state.id: {state.id}\n"
|
||||
f" state.code_challenge: {state.code_challenge}\n"
|
||||
f" state.code_verifier: {state.code_verifier[:20] if state.code_verifier else None}..."
|
||||
)
|
||||
|
||||
auth_url = ExternalAuthService._build_authorization_url(config=config, state=state)
|
||||
|
||||
logger.info(
|
||||
f"OAuth register flow initiated for provider={provider_type_str}, "
|
||||
f"org_id={organization_id}, state_id={state.id}"
|
||||
)
|
||||
logger.info(
|
||||
f"[PKCE DEBUG] Register flow - FINAL: auth_url_has_challenge={'code_challenge=' in auth_url}"
|
||||
)
|
||||
|
||||
return auth_url, state.state
|
||||
|
||||
except ExternalAuthError as e:
|
||||
AuditService.log_action(
|
||||
action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED,
|
||||
organization_id=organization_id,
|
||||
metadata={
|
||||
"provider_type": provider_type_str,
|
||||
"failure_reason": e.error_type,
|
||||
},
|
||||
description=f"OAuth registration initiation failed: {e.message}",
|
||||
success=False,
|
||||
error_message=e.message,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def handle_register_callback(
|
||||
provider_type: AuthMethodType,
|
||||
state_record: OAuthState,
|
||||
authorization_code: str,
|
||||
redirect_uri: str,
|
||||
) -> dict:
|
||||
from gatehouse_app.services.external_auth._helpers import _encrypt_provider_data
|
||||
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
try:
|
||||
config = ExternalAuthService.get_provider_config(
|
||||
provider_type, state_record.organization_id
|
||||
)
|
||||
|
||||
tokens = ExternalAuthService._exchange_code(
|
||||
config=config,
|
||||
code=authorization_code,
|
||||
redirect_uri=redirect_uri,
|
||||
code_verifier=state_record.code_verifier,
|
||||
)
|
||||
|
||||
user_info = ExternalAuthService._get_user_info(
|
||||
config=config,
|
||||
access_token=tokens["access_token"],
|
||||
)
|
||||
|
||||
existing_user = User.query.filter_by(email=user_info["email"]).first()
|
||||
if existing_user:
|
||||
raise OAuthFlowError(
|
||||
f"An account with email {user_info['email']} already exists. "
|
||||
"Please log in with your password and link your Google account from settings.",
|
||||
"EMAIL_EXISTS",
|
||||
400,
|
||||
)
|
||||
|
||||
user = User(
|
||||
email=user_info["email"],
|
||||
full_name=user_info.get("name", ""),
|
||||
status="active",
|
||||
email_verified=user_info.get("email_verified", False),
|
||||
)
|
||||
user.save()
|
||||
|
||||
auth_method = AuthenticationMethod(
|
||||
user_id=user.id,
|
||||
method_type=provider_type,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
provider_data=_encrypt_provider_data(tokens, user_info),
|
||||
verified=user_info.get("email_verified", False),
|
||||
is_primary=True,
|
||||
last_used_at=datetime.utcnow(),
|
||||
)
|
||||
auth_method.save()
|
||||
|
||||
state_record.mark_used()
|
||||
|
||||
AuditService.log_action(
|
||||
action="user.register",
|
||||
user_id=user.id,
|
||||
organization_id=state_record.organization_id,
|
||||
resource_type="user",
|
||||
resource_id=user.id,
|
||||
metadata={
|
||||
"provider_type": provider_type_str,
|
||||
"provider_user_id": user_info["provider_user_id"],
|
||||
"auth_method_id": auth_method.id,
|
||||
},
|
||||
description=f"User registered via {provider_type_str}",
|
||||
success=True,
|
||||
)
|
||||
|
||||
AuditService.log_external_auth_link_completed(
|
||||
user_id=user.id,
|
||||
organization_id=state_record.organization_id,
|
||||
provider_type=provider_type_str,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
auth_method_id=auth_method.id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth registration successful for email={user_info['email']}, "
|
||||
f"provider={provider_type_str}, user_id={user.id}"
|
||||
)
|
||||
|
||||
if state_record.organization_id:
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
org = Organization.query.get(state_record.organization_id)
|
||||
if org:
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
session = AuthService.create_session(user=user, is_compliance_only=False)
|
||||
session_dict = session.to_dict()
|
||||
session_dict["token"] = session.token
|
||||
expires_at = session.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
now = datetime.now(timezone.utc)
|
||||
session_dict["expires_in"] = int((expires_at - now).total_seconds())
|
||||
return {
|
||||
"success": True,
|
||||
"flow_type": "register",
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"organization_id": org.id,
|
||||
},
|
||||
"session": session_dict,
|
||||
}
|
||||
|
||||
from gatehouse_app.services.auth_service import AuthService as _AS
|
||||
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
|
||||
_session = _AS.create_session(user=user, is_compliance_only=False)
|
||||
_session_dict = _session.to_dict()
|
||||
_session_dict["token"] = _session.token
|
||||
_expires_at = _session.expires_at
|
||||
if _expires_at.tzinfo is None:
|
||||
_expires_at = _expires_at.replace(tzinfo=timezone.utc)
|
||||
_now = datetime.now(timezone.utc)
|
||||
_session_dict["expires_in"] = int((_expires_at - _now).total_seconds())
|
||||
|
||||
_pending = OrgInviteToken.query.filter(
|
||||
OrgInviteToken.email == user.email,
|
||||
OrgInviteToken.accepted_at.is_(None),
|
||||
OrgInviteToken.expires_at > _now,
|
||||
OrgInviteToken.deleted_at.is_(None),
|
||||
).all()
|
||||
_pending_list = [
|
||||
{
|
||||
"token": inv.token,
|
||||
"organization": {"id": str(inv.organization_id), "name": inv.organization.name},
|
||||
"role": inv.role,
|
||||
"expires_at": inv.expires_at.isoformat(),
|
||||
}
|
||||
for inv in _pending
|
||||
]
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"flow_type": "register",
|
||||
"requires_org_creation": True,
|
||||
"user": {"id": user.id, "email": user.email, "full_name": user.full_name},
|
||||
"session": _session_dict,
|
||||
"pending_invites": _pending_list,
|
||||
"state": state_record.state,
|
||||
}
|
||||
|
||||
except ExternalAuthError as e:
|
||||
logger.warning(
|
||||
f"OAuth registration failed for state={state_record.id}, "
|
||||
f"provider={provider_type_str}, error={e.message}"
|
||||
)
|
||||
raise
|
||||
except OAuthFlowError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in OAuth registration callback: {str(e)}", exc_info=True)
|
||||
raise OAuthFlowError(
|
||||
"An unexpected error occurred during registration",
|
||||
"INTERNAL_ERROR",
|
||||
500,
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,150 @@
|
||||
"""OIDCService — public facade over the oidc sub-package."""
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from gatehouse_app.exceptions.auth_exceptions import InvalidTokenError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OIDCError(Exception):
|
||||
def __init__(self, error: str, error_description: str = None, status_code: int = 400):
|
||||
self.error = error
|
||||
self.error_description = error_description
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class InvalidClientError(OIDCError):
|
||||
def __init__(self, error_description: str = "Invalid client"):
|
||||
super().__init__("invalid_client", error_description, 401)
|
||||
|
||||
|
||||
class InvalidGrantError(OIDCError):
|
||||
def __init__(self, error_description: str = "Invalid grant"):
|
||||
super().__init__("invalid_grant", error_description, 400)
|
||||
|
||||
|
||||
class InvalidRequestError(OIDCError):
|
||||
def __init__(self, error_description: str = "Invalid request"):
|
||||
super().__init__("invalid_request", error_description, 400)
|
||||
|
||||
|
||||
from gatehouse_app.services.oidc import auth_code as _auth_code
|
||||
from gatehouse_app.services.oidc import tokens as _tokens
|
||||
from gatehouse_app.services.oidc import userinfo as _userinfo
|
||||
|
||||
|
||||
class OIDCService:
|
||||
"""Main OIDC service handling all OpenID Connect operations."""
|
||||
|
||||
@staticmethod
|
||||
def _generate_code() -> str:
|
||||
import secrets
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
@staticmethod
|
||||
def _hash_value(value: str) -> str:
|
||||
import hashlib
|
||||
return hashlib.sha256(value.encode()).hexdigest()
|
||||
|
||||
@classmethod
|
||||
def generate_authorization_code(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
redirect_uri: str,
|
||||
scope: list,
|
||||
state: str,
|
||||
nonce: str,
|
||||
code_challenge: str = None,
|
||||
code_challenge_method: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
) -> str:
|
||||
return _auth_code.generate_authorization_code(
|
||||
client_id, user_id, redirect_uri, scope, state, nonce,
|
||||
code_challenge, code_challenge_method, ip_address, user_agent,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_authorization_code(
|
||||
cls,
|
||||
code: str,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
code_verifier: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
) -> Tuple[Dict, object]:
|
||||
return _auth_code.validate_authorization_code(
|
||||
code, client_id, redirect_uri, code_verifier, ip_address, user_agent
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _compute_code_challenge(cls, verifier: str, method: str = "S256") -> str:
|
||||
return _auth_code._compute_code_challenge(verifier, method)
|
||||
|
||||
@classmethod
|
||||
def generate_tokens(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
scope: list,
|
||||
nonce: str = None,
|
||||
refresh_token: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
auth_time: int = None,
|
||||
) -> Dict:
|
||||
return _tokens.generate_tokens(
|
||||
client_id, user_id, scope, nonce, refresh_token, ip_address, user_agent, auth_time
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def refresh_access_token(
|
||||
cls,
|
||||
refresh_token: str,
|
||||
client_id: str,
|
||||
scope: list = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
) -> Dict:
|
||||
return _tokens.refresh_access_token(refresh_token, client_id, scope, ip_address, user_agent)
|
||||
|
||||
@classmethod
|
||||
def validate_access_token(cls, token: str, client_id: str = None) -> Dict:
|
||||
return _tokens.validate_access_token(token, client_id)
|
||||
|
||||
@classmethod
|
||||
def revoke_token(
|
||||
cls,
|
||||
token: str,
|
||||
client_id: str,
|
||||
token_type_hint: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
) -> bool:
|
||||
return _tokens.revoke_token(token, client_id, token_type_hint, ip_address, user_agent)
|
||||
|
||||
@classmethod
|
||||
def introspect_token(
|
||||
cls,
|
||||
token: str,
|
||||
client_id: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
) -> Dict:
|
||||
return _tokens.introspect_token(token, client_id, ip_address, user_agent)
|
||||
|
||||
@classmethod
|
||||
def get_jwks(cls) -> Dict:
|
||||
from gatehouse_app.services.oidc_jwks_service import OIDCJWKSService
|
||||
return OIDCJWKSService().get_jwks()
|
||||
|
||||
@classmethod
|
||||
def get_userinfo(cls, access_token: str) -> Dict:
|
||||
return _userinfo.get_userinfo(access_token, cls.validate_access_token)
|
||||
|
||||
@staticmethod
|
||||
def _get_user_roles(user) -> list:
|
||||
return _userinfo._get_user_roles(user)
|
||||
@@ -0,0 +1,196 @@
|
||||
"""OIDC authorization code generation and validation."""
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from gatehouse_app.models import User, OIDCAuthCode
|
||||
from gatehouse_app.exceptions.validation_exceptions import ValidationError, NotFoundError
|
||||
from gatehouse_app.services.oidc_audit_service import OIDCAuditService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _hash_value(value: str) -> str:
|
||||
import hashlib
|
||||
return hashlib.sha256(value.encode()).hexdigest()
|
||||
|
||||
|
||||
def _compute_code_challenge(verifier: str, method: str = "S256") -> str:
|
||||
import hashlib
|
||||
import base64
|
||||
if method == "S256":
|
||||
digest = hashlib.sha256(verifier.encode()).digest()
|
||||
return base64.urlsafe_b64encode(digest).decode().rstrip("=")
|
||||
return verifier
|
||||
|
||||
|
||||
def generate_authorization_code(
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
redirect_uri: str,
|
||||
scope: list,
|
||||
state: str,
|
||||
nonce: str,
|
||||
code_challenge: str = None,
|
||||
code_challenge_method: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
) -> str:
|
||||
import secrets
|
||||
|
||||
from gatehouse_app.models import OIDCClient
|
||||
|
||||
logger.debug("[OIDC SERVICE] generate_authorization_code called")
|
||||
logger.debug("[OIDC SERVICE] client_id=%s, user_id=%s", client_id, user_id)
|
||||
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
|
||||
if current_app.config.get('ENV') == 'development':
|
||||
logger.debug(f"[OIDC] Generate auth code - Client validation: client_id={client_id}, exists={client is not None}")
|
||||
|
||||
if not client:
|
||||
raise NotFoundError("Client not found")
|
||||
|
||||
if not client.is_active:
|
||||
raise ValidationError("Client is not active")
|
||||
|
||||
if not client.is_redirect_uri_allowed(redirect_uri):
|
||||
raise ValidationError("Invalid redirect_uri")
|
||||
|
||||
allowed_scopes = client.scopes or []
|
||||
valid_scopes = [s for s in scope if s in allowed_scopes]
|
||||
|
||||
if not valid_scopes:
|
||||
raise ValidationError("Invalid scopes")
|
||||
|
||||
code = secrets.token_urlsafe(32)
|
||||
code_hash = _hash_value(code)
|
||||
|
||||
auth_code = OIDCAuthCode.create_code(
|
||||
client_id=client.id,
|
||||
user_id=user_id,
|
||||
code_hash=code_hash,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=valid_scopes,
|
||||
nonce=nonce,
|
||||
code_verifier=code_challenge,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
lifetime_seconds=600,
|
||||
)
|
||||
logger.debug("[OIDC SERVICE] Auth code created, expires_at=%s", auth_code.expires_at.isoformat())
|
||||
|
||||
OIDCAuditService.log_authorization_event(
|
||||
client_id=client.id,
|
||||
user_id=user_id,
|
||||
success=True,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=valid_scopes,
|
||||
)
|
||||
|
||||
return code
|
||||
|
||||
|
||||
def validate_authorization_code(
|
||||
code: str,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
code_verifier: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
) -> Tuple[Dict, User]:
|
||||
from gatehouse_app.models import OIDCClient
|
||||
from gatehouse_app.exceptions.auth_exceptions import InvalidTokenError
|
||||
|
||||
logger.debug("[OIDC SERVICE] validate_authorization_code called, client_id=%s", client_id)
|
||||
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
logger.error(f"[OIDC] Validate auth code - Client not found: client_id={client_id}")
|
||||
from gatehouse_app.services.oidc import InvalidGrantError
|
||||
raise InvalidGrantError("Invalid client")
|
||||
|
||||
code_hash = _hash_value(code)
|
||||
auth_code = OIDCAuthCode.query.filter_by(
|
||||
code_hash=code_hash,
|
||||
client_id=client.id,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
if not auth_code:
|
||||
OIDCAuditService.log_authorization_event(
|
||||
client_id=client.id,
|
||||
success=False,
|
||||
error_code="invalid_grant",
|
||||
error_description="Invalid or expired authorization code",
|
||||
)
|
||||
from gatehouse_app.services.oidc import InvalidGrantError
|
||||
raise InvalidGrantError("Invalid or expired authorization code")
|
||||
|
||||
if auth_code.is_used:
|
||||
OIDCAuditService.log_authorization_event(
|
||||
client_id=client.id,
|
||||
user_id=auth_code.user_id,
|
||||
success=False,
|
||||
error_code="invalid_grant",
|
||||
error_description="Authorization code already used",
|
||||
)
|
||||
from gatehouse_app.services.oidc import InvalidGrantError
|
||||
raise InvalidGrantError("Authorization code already used")
|
||||
|
||||
expires_at = auth_code.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
logger.debug(
|
||||
"[OIDC SERVICE] Time until expiration (seconds): %s",
|
||||
(expires_at - datetime.now(timezone.utc)).total_seconds(),
|
||||
)
|
||||
|
||||
if auth_code.is_expired():
|
||||
OIDCAuditService.log_authorization_event(
|
||||
client_id=client.id,
|
||||
user_id=auth_code.user_id,
|
||||
success=False,
|
||||
error_code="invalid_grant",
|
||||
error_description="Authorization code expired",
|
||||
)
|
||||
from gatehouse_app.services.oidc import InvalidGrantError
|
||||
raise InvalidGrantError("Authorization code expired")
|
||||
|
||||
if auth_code.redirect_uri != redirect_uri:
|
||||
from gatehouse_app.services.oidc import InvalidGrantError
|
||||
raise InvalidGrantError("Invalid redirect_uri")
|
||||
|
||||
if client.require_pkce and auth_code.code_verifier:
|
||||
if not code_verifier:
|
||||
raise ValidationError("code_verifier is required")
|
||||
expected_challenge = _compute_code_challenge(code_verifier, "S256")
|
||||
if expected_challenge != auth_code.code_verifier:
|
||||
OIDCAuditService.log_authorization_event(
|
||||
client_id=client.id,
|
||||
user_id=auth_code.user_id,
|
||||
success=False,
|
||||
error_code="invalid_grant",
|
||||
error_description="Invalid code_verifier",
|
||||
)
|
||||
from gatehouse_app.services.oidc import InvalidGrantError
|
||||
raise InvalidGrantError("Invalid code_verifier")
|
||||
|
||||
auth_code.mark_as_used()
|
||||
|
||||
user = User.query.get(auth_code.user_id)
|
||||
if not user:
|
||||
from gatehouse_app.services.oidc import InvalidGrantError
|
||||
raise InvalidGrantError("User not found")
|
||||
|
||||
claims = {
|
||||
"user_id": auth_code.user_id,
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": auth_code.scope,
|
||||
"nonce": auth_code.nonce,
|
||||
}
|
||||
|
||||
return claims, user
|
||||
@@ -0,0 +1,321 @@
|
||||
"""OIDC token generation, refresh, validation, revocation, and introspection."""
|
||||
import hashlib
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, Optional
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from gatehouse_app.models import OIDCClient, OIDCRefreshToken, OIDCTokenMetadata
|
||||
from gatehouse_app.services.oidc_token_service import OIDCTokenService
|
||||
from gatehouse_app.services.oidc_audit_service import OIDCAuditService
|
||||
from gatehouse_app.exceptions.auth_exceptions import InvalidTokenError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_tokens(
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
scope: list,
|
||||
nonce: str = None,
|
||||
refresh_token: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
auth_time: int = None,
|
||||
) -> Dict:
|
||||
logger.debug("[OIDC SERVICE] generate_tokens called: client_id=%s, user_id=%s", client_id, user_id)
|
||||
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
from gatehouse_app.services.oidc import InvalidClientError
|
||||
raise InvalidClientError()
|
||||
|
||||
access_token_jti = OIDCTokenService._generate_jti()
|
||||
access_token = OIDCTokenService.create_access_token(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
scope=scope,
|
||||
jti=access_token_jti,
|
||||
)
|
||||
|
||||
id_token = OIDCTokenService.create_id_token(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
nonce=nonce,
|
||||
scope=scope,
|
||||
access_token=access_token,
|
||||
auth_time=auth_time,
|
||||
)
|
||||
|
||||
final_refresh_token = None
|
||||
if "refresh_token" in (client.grant_types or []):
|
||||
if refresh_token:
|
||||
refresh_token_obj = OIDCRefreshToken.query.filter_by(
|
||||
token_hash=hashlib.sha256(refresh_token.encode()).hexdigest(),
|
||||
deleted_at=None,
|
||||
).first()
|
||||
if refresh_token_obj and refresh_token_obj.is_valid():
|
||||
new_refresh, new_hash = OIDCTokenService.create_refresh_token(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
scope=scope,
|
||||
access_token_id=access_token_jti,
|
||||
)
|
||||
refresh_token_obj.rotate(new_hash)
|
||||
final_refresh_token = new_refresh
|
||||
else:
|
||||
final_refresh_token, refresh_hash = OIDCTokenService.create_refresh_token(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
scope=scope,
|
||||
access_token_id=access_token_jti,
|
||||
)
|
||||
OIDCRefreshToken.create_token(
|
||||
client_id=client.id,
|
||||
user_id=user_id,
|
||||
token_hash=refresh_hash,
|
||||
scope=scope,
|
||||
access_token_id=access_token_jti,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
lifetime_seconds=client.refresh_token_lifetime or 2592000,
|
||||
)
|
||||
|
||||
access_token_expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
seconds=client.access_token_lifetime or 3600
|
||||
)
|
||||
OIDCTokenMetadata.create_metadata(
|
||||
client_id=client.id,
|
||||
user_id=user_id,
|
||||
token_type="access_token",
|
||||
token_jti=access_token_jti,
|
||||
expires_at=access_token_expires_at,
|
||||
)
|
||||
|
||||
id_token_jti = OIDCTokenService._generate_jti()
|
||||
id_token_expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
seconds=client.id_token_lifetime or 3600
|
||||
)
|
||||
OIDCTokenMetadata.create_metadata(
|
||||
client_id=client.id,
|
||||
user_id=user_id,
|
||||
token_type="id_token",
|
||||
token_jti=id_token_jti,
|
||||
expires_at=id_token_expires_at,
|
||||
)
|
||||
|
||||
OIDCAuditService.log_token_event(
|
||||
client_id=client.id,
|
||||
user_id=user_id,
|
||||
token_type="access_token",
|
||||
success=True,
|
||||
grant_type="authorization_code",
|
||||
scopes=scope,
|
||||
)
|
||||
|
||||
result = {
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": client.access_token_lifetime or 3600,
|
||||
"id_token": id_token,
|
||||
}
|
||||
if final_refresh_token:
|
||||
result["refresh_token"] = final_refresh_token
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def refresh_access_token(
|
||||
refresh_token: str,
|
||||
client_id: str,
|
||||
scope: list = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
) -> Dict:
|
||||
logger.debug("[OIDC SERVICE] refresh_access_token called, client_id=%s", client_id)
|
||||
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
from gatehouse_app.services.oidc import InvalidClientError
|
||||
raise InvalidClientError()
|
||||
|
||||
token_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
|
||||
refresh_token_obj = OIDCRefreshToken.query.filter_by(
|
||||
token_hash=token_hash,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
if not refresh_token_obj:
|
||||
OIDCAuditService.log_token_event(
|
||||
client_id=client.id,
|
||||
success=False,
|
||||
error_code="invalid_grant",
|
||||
error_description="Invalid refresh token",
|
||||
)
|
||||
from gatehouse_app.services.oidc import InvalidGrantError
|
||||
raise InvalidGrantError("Invalid refresh token")
|
||||
|
||||
if not refresh_token_obj.is_valid():
|
||||
OIDCAuditService.log_token_event(
|
||||
client_id=client.id,
|
||||
user_id=refresh_token_obj.user_id,
|
||||
success=False,
|
||||
error_code="invalid_grant",
|
||||
error_description="Refresh token expired or revoked",
|
||||
)
|
||||
from gatehouse_app.services.oidc import InvalidGrantError
|
||||
raise InvalidGrantError("Refresh token expired or revoked")
|
||||
|
||||
if refresh_token_obj.client_id != client.id:
|
||||
from gatehouse_app.services.oidc import InvalidGrantError
|
||||
raise InvalidGrantError("Client mismatch")
|
||||
|
||||
granted_scope = scope or (refresh_token_obj.scope or [])
|
||||
|
||||
access_token_jti = OIDCTokenService._generate_jti()
|
||||
access_token = OIDCTokenService.create_access_token(
|
||||
client_id=client_id,
|
||||
user_id=refresh_token_obj.user_id,
|
||||
scope=granted_scope,
|
||||
jti=access_token_jti,
|
||||
)
|
||||
|
||||
id_token = OIDCTokenService.create_id_token(
|
||||
client_id=client_id,
|
||||
user_id=refresh_token_obj.user_id,
|
||||
scope=granted_scope,
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
new_refresh, new_hash = OIDCTokenService.create_refresh_token(
|
||||
client_id=client_id,
|
||||
user_id=refresh_token_obj.user_id,
|
||||
scope=granted_scope,
|
||||
access_token_id=access_token_jti,
|
||||
)
|
||||
refresh_token_obj.rotate(new_hash)
|
||||
|
||||
access_token_expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
seconds=client.access_token_lifetime or 3600
|
||||
)
|
||||
OIDCTokenMetadata.create_metadata(
|
||||
client_id=client.id,
|
||||
user_id=refresh_token_obj.user_id,
|
||||
token_type="access_token",
|
||||
token_jti=access_token_jti,
|
||||
expires_at=access_token_expires_at,
|
||||
)
|
||||
|
||||
OIDCAuditService.log_token_event(
|
||||
client_id=client.id,
|
||||
user_id=refresh_token_obj.user_id,
|
||||
token_type="access_token",
|
||||
success=True,
|
||||
grant_type="refresh_token",
|
||||
scopes=granted_scope,
|
||||
)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": client.access_token_lifetime or 3600,
|
||||
"id_token": id_token,
|
||||
"refresh_token": new_refresh,
|
||||
}
|
||||
|
||||
|
||||
def validate_access_token(token: str, client_id: str = None) -> Dict:
|
||||
logger.debug("[OIDC SERVICE] validate_access_token() called")
|
||||
|
||||
try:
|
||||
claims = OIDCTokenService.validate_access_token(token, client_id)
|
||||
logger.debug("[OIDC SERVICE] Token validation successful")
|
||||
return claims
|
||||
except Exception as e:
|
||||
logger.error("[OIDC SERVICE] Token validation failed: %s: %s", type(e).__name__, str(e))
|
||||
_client_db_id = None
|
||||
if client_id:
|
||||
_c = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
_client_db_id = _c.id if _c else None
|
||||
OIDCAuditService.log_event(
|
||||
event_type="token_validation",
|
||||
client_id=_client_db_id,
|
||||
success=False,
|
||||
error_code="invalid_token",
|
||||
error_description=str(e),
|
||||
)
|
||||
raise InvalidTokenError(str(e))
|
||||
|
||||
|
||||
def revoke_token(
|
||||
token: str,
|
||||
client_id: str,
|
||||
token_type_hint: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
) -> bool:
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
from gatehouse_app.services.oidc import InvalidClientError
|
||||
raise InvalidClientError()
|
||||
|
||||
revoked = False
|
||||
token_hash = hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
if token_type_hint in (None, "refresh_token"):
|
||||
refresh_token_obj = OIDCRefreshToken.query.filter_by(
|
||||
token_hash=token_hash,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
if refresh_token_obj:
|
||||
refresh_token_obj.revoke(reason="revoked_by_client")
|
||||
revoked = True
|
||||
OIDCAuditService.log_token_revocation_event(
|
||||
client_id=client.id,
|
||||
user_id=refresh_token_obj.user_id,
|
||||
token_type="refresh_token",
|
||||
reason="revoked_by_client",
|
||||
)
|
||||
|
||||
if not revoked or token_type_hint in (None, "access_token"):
|
||||
try:
|
||||
claims = OIDCTokenService.decode_token(token)
|
||||
jti = claims.get("jti")
|
||||
if jti:
|
||||
revoked_at = OIDCTokenMetadata.revoke_by_jti(jti, reason="revoked_by_client")
|
||||
if revoked_at:
|
||||
revoked = True
|
||||
OIDCAuditService.log_token_revocation_event(
|
||||
client_id=client.id,
|
||||
user_id=claims.get("sub"),
|
||||
token_type="access_token",
|
||||
reason="revoked_by_client",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return revoked
|
||||
|
||||
|
||||
def introspect_token(
|
||||
token: str,
|
||||
client_id: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
) -> Dict:
|
||||
result = OIDCTokenService.introspect_token(token, client_id)
|
||||
|
||||
_client_db_id = None
|
||||
if client_id:
|
||||
_ic = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
_client_db_id = _ic.id if _ic else None
|
||||
OIDCAuditService.log_event(
|
||||
event_type="token_introspection",
|
||||
client_id=_client_db_id,
|
||||
user_id=result.get("sub"),
|
||||
success=result.get("active", False),
|
||||
metadata={"active": result.get("active")},
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,65 @@
|
||||
"""OIDC userinfo endpoint logic."""
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
from gatehouse_app.models import User
|
||||
from gatehouse_app.exceptions.validation_exceptions import NotFoundError
|
||||
from gatehouse_app.services.oidc_audit_service import OIDCAuditService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_userinfo(access_token: str, validate_access_token_fn) -> Dict:
|
||||
logger.debug("[OIDC SERVICE] get_userinfo() called")
|
||||
|
||||
claims = validate_access_token_fn(access_token)
|
||||
user_id = claims.get("sub")
|
||||
|
||||
user = User.query.get(user_id)
|
||||
if not user:
|
||||
logger.error("[OIDC SERVICE] User not found in database: user_id=%s", user_id)
|
||||
raise NotFoundError("User not found")
|
||||
|
||||
scope_str = claims.get("scope", "")
|
||||
scopes = scope_str.split() if scope_str else []
|
||||
|
||||
userinfo = {"sub": user_id}
|
||||
|
||||
if "profile" in scopes and user.full_name:
|
||||
userinfo["name"] = user.full_name
|
||||
|
||||
if "email" in scopes:
|
||||
userinfo["email"] = user.email
|
||||
userinfo["email_verified"] = user.email_verified
|
||||
|
||||
if "roles" in scopes:
|
||||
userinfo["roles"] = _get_user_roles(user)
|
||||
|
||||
_userinfo_client_id_str = claims.get("client_id")
|
||||
_userinfo_client_db_id = None
|
||||
if _userinfo_client_id_str:
|
||||
from gatehouse_app.models import OIDCClient
|
||||
_uc = OIDCClient.query.filter_by(client_id=_userinfo_client_id_str).first()
|
||||
_userinfo_client_db_id = _uc.id if _uc else None
|
||||
|
||||
OIDCAuditService.log_userinfo_event(
|
||||
access_token=access_token,
|
||||
user_id=user_id,
|
||||
client_id=_userinfo_client_db_id,
|
||||
success=True,
|
||||
scopes_claimed=scopes,
|
||||
)
|
||||
|
||||
return userinfo
|
||||
|
||||
|
||||
def _get_user_roles(user: User) -> list:
|
||||
roles = []
|
||||
if not user or not user.organization_memberships:
|
||||
return roles
|
||||
for member in user.organization_memberships:
|
||||
roles.append({
|
||||
"organization_id": str(member.organization_id),
|
||||
"role": member.role.value,
|
||||
})
|
||||
return roles
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,6 @@
|
||||
"""Organization service."""
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from flask import current_app
|
||||
from gatehouse_app.extensions import db
|
||||
@@ -157,6 +158,12 @@ class OrganizationService:
|
||||
Returns:
|
||||
Deleted Organization instance
|
||||
"""
|
||||
if soft:
|
||||
# Mangle slug so it can be reused
|
||||
original_slug = org.slug
|
||||
org.slug = f"{original_slug}__deleted_{uuid.uuid4().hex[:8]}"
|
||||
org.is_active = False
|
||||
|
||||
org.delete(soft=soft)
|
||||
|
||||
# Log organization deletion
|
||||
@@ -174,11 +181,16 @@ class OrganizationService:
|
||||
@staticmethod
|
||||
def force_delete_organization(org, user_id):
|
||||
"""
|
||||
Force-delete an organization and all its members in a single atomic operation.
|
||||
Force-delete an organization and ALL associated data in a single atomic
|
||||
operation.
|
||||
|
||||
All active memberships are soft-deleted before the organization itself
|
||||
is soft-deleted, preventing orphaned membership rows and avoiding any
|
||||
cascade deadlocks.
|
||||
Cleans up:
|
||||
- All active memberships (soft-deleted)
|
||||
- MFA policy compliance records for this org
|
||||
- User security policy overrides for this org
|
||||
- Pending invite tokens for this org
|
||||
- OIDC clients for this org
|
||||
- The organization slug is mangled so the same slug can be reused
|
||||
|
||||
Args:
|
||||
org: Organization instance
|
||||
@@ -187,31 +199,71 @@ class OrganizationService:
|
||||
Returns:
|
||||
Deleted Organization instance
|
||||
"""
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance
|
||||
from gatehouse_app.models.security.user_security_policy import UserSecurityPolicy
|
||||
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
member_count = 0
|
||||
cleanup_counts = {}
|
||||
|
||||
# Soft-delete all active memberships first.
|
||||
# 1. Soft-delete all active memberships first.
|
||||
for member in org.members:
|
||||
if member.deleted_at is None:
|
||||
member.deleted_at = now
|
||||
member_count += 1
|
||||
|
||||
# Now soft-delete the organization itself.
|
||||
org.delete(soft=True)
|
||||
# 2. Remove MFA compliance records for this org so the compliance job
|
||||
# doesn't accidentally process stale records for a deleted org.
|
||||
compliance_records = MfaPolicyCompliance.query.filter_by(
|
||||
organization_id=org.id,
|
||||
).filter(MfaPolicyCompliance.deleted_at == None).all()
|
||||
for record in compliance_records:
|
||||
record.deleted_at = now
|
||||
cleanup_counts["compliance_records"] = len(compliance_records)
|
||||
|
||||
# Log with member count for audit trail.
|
||||
# 3. Remove user security policy overrides for this org.
|
||||
user_policies = UserSecurityPolicy.query.filter_by(
|
||||
organization_id=org.id,
|
||||
).filter(UserSecurityPolicy.deleted_at == None).all()
|
||||
for policy in user_policies:
|
||||
policy.deleted_at = now
|
||||
cleanup_counts["user_security_policies"] = len(user_policies)
|
||||
|
||||
# 4. Remove pending invite tokens for this org.
|
||||
pending_invites = OrgInviteToken.query.filter_by(
|
||||
organization_id=org.id,
|
||||
).filter(OrgInviteToken.accepted_at == None, OrgInviteToken.deleted_at == None).all()
|
||||
for invite in pending_invites:
|
||||
invite.deleted_at = now
|
||||
cleanup_counts["pending_invites"] = len(pending_invites)
|
||||
|
||||
# 5. Mangle the slug so the same slug can be reused for a new org.
|
||||
# Format: "original-slug__deleted_<short-uuid>"
|
||||
original_slug = org.slug
|
||||
org.slug = f"{original_slug}__deleted_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 6. Now soft-delete the organization itself.
|
||||
org.deleted_at = now
|
||||
org.is_active = False
|
||||
db.session.commit()
|
||||
|
||||
# Log with member count and cleanup summary for audit trail.
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_DELETE,
|
||||
user_id=user_id,
|
||||
organization_id=org.id,
|
||||
resource_type="organization",
|
||||
resource_id=org.id,
|
||||
metadata={"members_removed": member_count},
|
||||
metadata={
|
||||
"members_removed": member_count,
|
||||
"original_slug": original_slug,
|
||||
**cleanup_counts,
|
||||
},
|
||||
description=(
|
||||
f"Organization deleted by owner; "
|
||||
f"{member_count} membership(s) removed."
|
||||
f"Organization '{original_slug}' deleted by owner; "
|
||||
f"{member_count} membership(s) removed, "
|
||||
f"{cleanup_counts.get('compliance_records', 0)} compliance record(s) cleaned."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user