Feat(Chore, Fix): Refractor, Half Baked Deletion + Admin Privilege
Refractor Codes into sub file/folders Admin can remove users'/members mfa/2fa, unlink account from oauth provider Admin can add/reset password Different Email (OIDC + Manual)-Same Account; (Block Linking and authorize if available)
This commit is contained in:
@@ -0,0 +1,209 @@
|
||||
"""OAuthFlowService — public facade and handle_callback dispatcher."""
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from gatehouse_app.models.auth.authentication_method import OAuthState
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
from gatehouse_app.services.external_auth import ExternalAuthService
|
||||
from gatehouse_app.services.external_auth.models import ExternalAuthError
|
||||
|
||||
from gatehouse_app.services.oauth_flow.login import OAuthFlowError, initiate_login_flow, handle_login_callback
|
||||
from gatehouse_app.services.oauth_flow.register import initiate_register_flow, handle_register_callback
|
||||
from gatehouse_app.services.oauth_flow.code import (
|
||||
generate_authorization_code,
|
||||
exchange_authorization_code,
|
||||
create_redirect_response,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthFlowService:
|
||||
"""Service for managing OAuth authentication flows."""
|
||||
|
||||
@classmethod
|
||||
def initiate_login_flow(
|
||||
cls,
|
||||
provider_type: AuthMethodType,
|
||||
organization_id: str = None,
|
||||
redirect_uri: str = None,
|
||||
state_data: dict = None,
|
||||
) -> Tuple[str, str]:
|
||||
return initiate_login_flow(provider_type, organization_id, redirect_uri, state_data)
|
||||
|
||||
@classmethod
|
||||
def initiate_register_flow(
|
||||
cls,
|
||||
provider_type: AuthMethodType,
|
||||
organization_id: str = None,
|
||||
redirect_uri: str = None,
|
||||
) -> Tuple[str, str]:
|
||||
return initiate_register_flow(provider_type, organization_id, redirect_uri)
|
||||
|
||||
@classmethod
|
||||
def handle_callback(
|
||||
cls,
|
||||
provider_type: AuthMethodType,
|
||||
authorization_code: str,
|
||||
state: str,
|
||||
redirect_uri: str = None,
|
||||
error: str = None,
|
||||
error_description: str = None,
|
||||
) -> dict:
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
try:
|
||||
from flask import request
|
||||
ip_address = request.remote_addr if request else None
|
||||
user_agent = request.headers.get("User-Agent") if request else None
|
||||
except RuntimeError:
|
||||
ip_address = None
|
||||
user_agent = None
|
||||
|
||||
if error:
|
||||
AuditService.log_external_auth_login_failed(
|
||||
organization_id=None,
|
||||
provider_type=provider_type_str,
|
||||
failure_reason=error,
|
||||
error_message=error_description or error,
|
||||
)
|
||||
raise OAuthFlowError(
|
||||
error_description or f"OAuth error: {error}",
|
||||
error.upper() if error else "OAUTH_ERROR",
|
||||
400,
|
||||
)
|
||||
|
||||
state_record = OAuthState.query.filter_by(state=state).first()
|
||||
|
||||
if state_record:
|
||||
logger.debug(
|
||||
f"State validation: found=True, used={state_record.used}, "
|
||||
f"expires_at={state_record.expires_at}, is_valid={state_record.is_valid()}"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"State validation: state token not found in database: {state}")
|
||||
|
||||
if not state_record or not state_record.is_valid():
|
||||
AuditService.log_external_auth_login_failed(
|
||||
organization_id=state_record.organization_id if state_record else None,
|
||||
provider_type=provider_type_str,
|
||||
failure_reason="invalid_state",
|
||||
error_message="Invalid or expired OAuth state",
|
||||
)
|
||||
raise OAuthFlowError("Invalid or expired OAuth state", "INVALID_STATE", 400)
|
||||
|
||||
effective_redirect = redirect_uri or state_record.redirect_uri
|
||||
|
||||
if state_record.flow_type == "login":
|
||||
return handle_login_callback(
|
||||
provider_type=provider_type,
|
||||
state_record=state_record,
|
||||
authorization_code=authorization_code,
|
||||
redirect_uri=effective_redirect,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
elif state_record.flow_type == "link":
|
||||
return cls._handle_link_callback(
|
||||
provider_type=provider_type,
|
||||
state_record=state_record,
|
||||
authorization_code=authorization_code,
|
||||
redirect_uri=effective_redirect,
|
||||
)
|
||||
elif state_record.flow_type == "register":
|
||||
return handle_register_callback(
|
||||
provider_type=provider_type,
|
||||
state_record=state_record,
|
||||
authorization_code=authorization_code,
|
||||
redirect_uri=effective_redirect,
|
||||
)
|
||||
else:
|
||||
raise OAuthFlowError(
|
||||
f"Unknown flow type: {state_record.flow_type}",
|
||||
"INVALID_FLOW_TYPE",
|
||||
400,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _handle_link_callback(
|
||||
cls,
|
||||
provider_type: AuthMethodType,
|
||||
state_record: OAuthState,
|
||||
authorization_code: str,
|
||||
redirect_uri: str,
|
||||
) -> dict:
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
try:
|
||||
auth_method = ExternalAuthService.complete_link_flow(
|
||||
provider_type=provider_type,
|
||||
authorization_code=authorization_code,
|
||||
state=state_record.state,
|
||||
redirect_uri=redirect_uri,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth link successful for user={state_record.user_id}, "
|
||||
f"provider={provider_type_str}, auth_method_id={auth_method.id}"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"flow_type": "link",
|
||||
"linked_account": {
|
||||
"id": auth_method.id,
|
||||
"provider_type": provider_type_str,
|
||||
"provider_user_id": auth_method.provider_user_id,
|
||||
"verified": auth_method.verified,
|
||||
},
|
||||
}
|
||||
|
||||
except ExternalAuthError as e:
|
||||
logger.warning(
|
||||
f"OAuth link failed for state={state_record.id}, "
|
||||
f"provider={provider_type_str}, error={e.message}"
|
||||
)
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def validate_state(cls, state: str) -> Optional[OAuthState]:
|
||||
state_record = OAuthState.query.filter_by(state=state).first()
|
||||
if state_record and state_record.is_valid():
|
||||
return state_record
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def cleanup_expired_states(cls):
|
||||
OAuthState.cleanup_expired()
|
||||
logger.info("Expired OAuth states cleaned up")
|
||||
|
||||
@classmethod
|
||||
def generate_authorization_code(
|
||||
cls,
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
scope: list = None,
|
||||
nonce: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
lifetime_seconds: int = 600,
|
||||
) -> str:
|
||||
return generate_authorization_code(
|
||||
user_id, client_id, redirect_uri, scope, nonce, ip_address, user_agent, lifetime_seconds
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def exchange_authorization_code(
|
||||
cls,
|
||||
code: str,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
ip_address: str = None,
|
||||
) -> dict:
|
||||
return exchange_authorization_code(code, client_id, redirect_uri, ip_address)
|
||||
|
||||
@classmethod
|
||||
def create_redirect_response(cls, redirect_uri: str, authorization_code: str, state: str = None):
|
||||
return create_redirect_response(redirect_uri, authorization_code, state)
|
||||
@@ -0,0 +1,141 @@
|
||||
"""Authorization code generation, exchange, and redirect helpers."""
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode
|
||||
from gatehouse_app.services.oauth_flow.login import OAuthFlowError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_authorization_code(
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
scope: list = None,
|
||||
nonce: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
lifetime_seconds: int = 600,
|
||||
) -> str:
|
||||
code = secrets.token_urlsafe(32)
|
||||
code_hash = hashlib.sha256(code.encode()).hexdigest()
|
||||
|
||||
OIDCAuthCode.create_code(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
code_hash=code_hash,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
nonce=nonce,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
lifetime_seconds=lifetime_seconds,
|
||||
)
|
||||
|
||||
logger.info(f"Generated authorization code for user={user_id}, client={client_id}")
|
||||
return code
|
||||
|
||||
|
||||
def exchange_authorization_code(
|
||||
code: str,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
ip_address: str = None,
|
||||
) -> dict:
|
||||
code_hash = hashlib.sha256(code.encode()).hexdigest()
|
||||
|
||||
auth_code = OIDCAuthCode.query.filter_by(
|
||||
client_id=client_id,
|
||||
code_hash=code_hash,
|
||||
).first()
|
||||
|
||||
if not auth_code:
|
||||
raise OAuthFlowError("Invalid authorization code", "INVALID_CODE", 400)
|
||||
|
||||
if not auth_code.is_valid():
|
||||
if auth_code.is_used:
|
||||
raise OAuthFlowError(
|
||||
"Authorization code has already been used", "CODE_USED", 400
|
||||
)
|
||||
else:
|
||||
raise OAuthFlowError("Authorization code has expired", "CODE_EXPIRED", 400)
|
||||
|
||||
if auth_code.redirect_uri != redirect_uri:
|
||||
raise OAuthFlowError("Redirect URI mismatch", "INVALID_REDIRECT_URI", 400)
|
||||
|
||||
from gatehouse_app.models import User
|
||||
user = User.query.get(auth_code.user_id)
|
||||
if not user:
|
||||
raise OAuthFlowError("User not found", "USER_NOT_FOUND", 404)
|
||||
|
||||
user_orgs = user.get_organizations()
|
||||
target_org = None
|
||||
if len(user_orgs) == 1:
|
||||
target_org = user_orgs[0]
|
||||
|
||||
if not target_org:
|
||||
raise OAuthFlowError(
|
||||
"User does not have a default organization. Organization selection required.",
|
||||
"ORG_SELECTION_REQUIRED",
|
||||
400,
|
||||
)
|
||||
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
session = AuthService.create_session(user=user, is_compliance_only=False)
|
||||
auth_code.mark_as_used()
|
||||
|
||||
session_dict = session.to_dict()
|
||||
session_dict["token"] = session.token
|
||||
expires_at = session.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
now = datetime.now(timezone.utc)
|
||||
session_dict["expires_in"] = int((expires_at - now).total_seconds())
|
||||
|
||||
logger.info(
|
||||
f"Authorization code exchanged for session: user={user.id}, "
|
||||
f"org_id={target_org.id}, client={client_id}"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"token": session_dict["token"],
|
||||
"expires_in": session_dict["expires_in"],
|
||||
"token_type": "Bearer",
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"organization_id": target_org.id,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def create_redirect_response(
|
||||
redirect_uri: str,
|
||||
authorization_code: str,
|
||||
state: str = None,
|
||||
):
|
||||
from urllib.parse import urlencode, urlparse, urlunparse
|
||||
from flask import redirect
|
||||
|
||||
parsed = urlparse(redirect_uri)
|
||||
params = {"code": authorization_code}
|
||||
if state:
|
||||
params["state"] = state
|
||||
|
||||
redirect_url = urlunparse((
|
||||
parsed.scheme,
|
||||
parsed.netloc,
|
||||
parsed.path,
|
||||
parsed.params,
|
||||
urlencode(params),
|
||||
parsed.fragment,
|
||||
))
|
||||
|
||||
logger.info(f"Redirecting to {parsed.scheme}://{parsed.netloc} with authorization code")
|
||||
return redirect(redirect_url)
|
||||
@@ -0,0 +1,410 @@
|
||||
"""Login flow: initiate and handle OAuth login callback."""
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from gatehouse_app.models import User, AuthenticationMethod
|
||||
from gatehouse_app.models.auth.authentication_method import OAuthState
|
||||
from gatehouse_app.utils.constants import AuthMethodType, AuditAction
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
from gatehouse_app.services.external_auth import ExternalAuthService
|
||||
from gatehouse_app.services.external_auth.models import ExternalAuthError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthFlowError(Exception):
|
||||
def __init__(self, message: str, error_type: str, status_code: int = 400):
|
||||
self.message = message
|
||||
self.error_type = error_type
|
||||
self.status_code = status_code
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def initiate_login_flow(
|
||||
provider_type: AuthMethodType,
|
||||
organization_id: str = None,
|
||||
redirect_uri: str = None,
|
||||
state_data: dict = None,
|
||||
) -> Tuple[str, str]:
|
||||
try:
|
||||
from flask import request
|
||||
except Exception:
|
||||
request = None
|
||||
|
||||
try:
|
||||
ip_address = request.remote_addr if request else None
|
||||
user_agent = request.headers.get("User-Agent") if request else None
|
||||
except RuntimeError:
|
||||
ip_address = None
|
||||
user_agent = None
|
||||
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
try:
|
||||
config = ExternalAuthService.get_provider_config(provider_type, organization_id)
|
||||
|
||||
if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri):
|
||||
raise OAuthFlowError("Invalid redirect URI", "INVALID_REDIRECT_URI", 400)
|
||||
|
||||
code_verifier = None
|
||||
code_challenge = None
|
||||
if provider_type_str not in ['google', 'microsoft']:
|
||||
code_verifier = secrets.token_urlsafe(32)
|
||||
code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier)
|
||||
|
||||
logger.info(
|
||||
f"[PKCE DEBUG] Provider type check: provider_type_str='{provider_type_str}', "
|
||||
f"is_google={provider_type_str in ['google']}, "
|
||||
f"will_skip_pkce={provider_type_str in ['google', 'microsoft']}"
|
||||
)
|
||||
|
||||
state = OAuthState.create_state(
|
||||
flow_type="login",
|
||||
provider_type=provider_type,
|
||||
organization_id=organization_id,
|
||||
redirect_uri=redirect_uri or (config.redirect_uris[0] if config.redirect_uris else None),
|
||||
code_verifier=code_verifier,
|
||||
code_challenge=code_challenge,
|
||||
extra_data=state_data,
|
||||
lifetime_seconds=600,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[PKCE DEBUG] Created OAuthState object:\n"
|
||||
f" state.id: {state.id}\n"
|
||||
f" state.provider_type: {state.provider_type}\n"
|
||||
f" state.code_challenge: {state.code_challenge}\n"
|
||||
f" state.code_verifier: {state.code_verifier[:20] if state.code_verifier else None}..."
|
||||
)
|
||||
|
||||
auth_url = ExternalAuthService._build_authorization_url(config=config, state=state)
|
||||
|
||||
logger.info(
|
||||
f"OAuth login flow initiated for provider={provider_type_str}, "
|
||||
f"org_id={organization_id}, state_token={state.state}, state_record_id={state.id}"
|
||||
)
|
||||
logger.info(
|
||||
f"[PKCE DEBUG] FINAL CHECK: code_challenge={code_challenge}, "
|
||||
f"code_verifier={code_verifier[:20] if code_verifier else None}..., "
|
||||
f"auth_url_has_challenge={'code_challenge=' in auth_url}, "
|
||||
f"returned_auth_url={auth_url}"
|
||||
)
|
||||
|
||||
return auth_url, state.state
|
||||
|
||||
except ExternalAuthError as e:
|
||||
AuditService.log_action(
|
||||
action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED,
|
||||
organization_id=organization_id,
|
||||
metadata={
|
||||
"provider_type": provider_type_str,
|
||||
"failure_reason": e.error_type,
|
||||
"ip_address": ip_address,
|
||||
},
|
||||
description=f"OAuth login initiation failed: {e.message}",
|
||||
success=False,
|
||||
error_message=e.message,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def handle_login_callback(
|
||||
provider_type: AuthMethodType,
|
||||
state_record: OAuthState,
|
||||
authorization_code: str,
|
||||
redirect_uri: str,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
) -> dict:
|
||||
from gatehouse_app.services.external_auth._helpers import _encrypt_provider_data
|
||||
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
try:
|
||||
config = ExternalAuthService.get_provider_config(
|
||||
provider_type, state_record.organization_id
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Exchanging code with PKCE: state_record.code_verifier="
|
||||
f"{state_record.code_verifier[:20] if state_record.code_verifier else None}..."
|
||||
)
|
||||
|
||||
tokens = ExternalAuthService._exchange_code(
|
||||
config=config,
|
||||
code=authorization_code,
|
||||
redirect_uri=redirect_uri,
|
||||
code_verifier=state_record.code_verifier,
|
||||
)
|
||||
|
||||
user_info = ExternalAuthService._get_user_info(
|
||||
config=config,
|
||||
access_token=tokens["access_token"],
|
||||
)
|
||||
|
||||
if not user_info.get("provider_user_id"):
|
||||
raise OAuthFlowError(
|
||||
"Provider did not return a user identifier (sub claim). "
|
||||
"Cannot complete authentication.",
|
||||
"MISSING_PROVIDER_USER_ID",
|
||||
400,
|
||||
)
|
||||
|
||||
if not user_info.get("email"):
|
||||
raise OAuthFlowError(
|
||||
"Provider did not return an email address. "
|
||||
"Cannot complete authentication.",
|
||||
"MISSING_EMAIL",
|
||||
400,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Got user_info from provider: sub={user_info['provider_user_id']}, "
|
||||
f"email={user_info['email']}, email_verified={user_info.get('email_verified')}"
|
||||
)
|
||||
|
||||
# Find the active auth method for this provider identity.
|
||||
# Order by created_at DESC so that an explicitly linked (newer) row wins
|
||||
# over an older auto-created primary row when the same Google identity
|
||||
# was linked to a second profile.
|
||||
auth_method = (
|
||||
AuthenticationMethod.query
|
||||
.filter_by(
|
||||
method_type=provider_type,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
deleted_at=None,
|
||||
)
|
||||
.order_by(AuthenticationMethod.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if not auth_method:
|
||||
deleted_method = (
|
||||
AuthenticationMethod.query
|
||||
.filter_by(
|
||||
method_type=provider_type,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
)
|
||||
.order_by(AuthenticationMethod.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if deleted_method:
|
||||
logger.info(
|
||||
f"OAuth login: restoring previously unlinked {provider_type_str} "
|
||||
f"auth method for user {deleted_method.user_id}"
|
||||
)
|
||||
deleted_method.deleted_at = None
|
||||
deleted_method.provider_data = _encrypt_provider_data(tokens, user_info)
|
||||
deleted_method.last_used_at = datetime.utcnow()
|
||||
deleted_method.save()
|
||||
auth_method = deleted_method
|
||||
|
||||
else:
|
||||
existing_user = User.query.filter_by(email=user_info["email"], deleted_at=None).first()
|
||||
|
||||
if existing_user:
|
||||
logger.info(
|
||||
f"OAuth login: email {user_info['email']} matches existing user "
|
||||
f"{existing_user.id}, auto-linking {provider_type_str} account"
|
||||
)
|
||||
auth_method = AuthenticationMethod(
|
||||
user_id=existing_user.id,
|
||||
method_type=provider_type,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
provider_data=_encrypt_provider_data(tokens, user_info),
|
||||
verified=user_info.get("email_verified", False),
|
||||
is_primary=False,
|
||||
last_used_at=datetime.utcnow(),
|
||||
)
|
||||
auth_method.save()
|
||||
user = existing_user
|
||||
else:
|
||||
logger.info(
|
||||
f"OAuth login: no account for {user_info['email']}, "
|
||||
f"auto-creating user via {provider_type_str}"
|
||||
)
|
||||
user = User(
|
||||
email=user_info["email"],
|
||||
full_name=user_info.get("name", ""),
|
||||
status="active",
|
||||
email_verified=user_info.get("email_verified", False),
|
||||
)
|
||||
user.save()
|
||||
|
||||
auth_method = AuthenticationMethod(
|
||||
user_id=user.id,
|
||||
method_type=provider_type,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
provider_data=_encrypt_provider_data(tokens, user_info),
|
||||
verified=user_info.get("email_verified", False),
|
||||
is_primary=True,
|
||||
last_used_at=datetime.utcnow(),
|
||||
)
|
||||
auth_method.save()
|
||||
|
||||
AuditService.log_action(
|
||||
action="user.register",
|
||||
user_id=user.id,
|
||||
organization_id=state_record.organization_id,
|
||||
resource_type="user",
|
||||
resource_id=user.id,
|
||||
metadata={
|
||||
"provider_type": provider_type_str,
|
||||
"provider_user_id": user_info["provider_user_id"],
|
||||
"auto_registered": True,
|
||||
},
|
||||
description=f"User auto-registered via {provider_type_str} OAuth",
|
||||
success=True,
|
||||
)
|
||||
else:
|
||||
auth_method.provider_data = _encrypt_provider_data(tokens, user_info)
|
||||
auth_method.last_used_at = datetime.utcnow()
|
||||
auth_method.save()
|
||||
|
||||
user = auth_method.user
|
||||
|
||||
user_orgs = user.get_organizations()
|
||||
target_org = None
|
||||
|
||||
if state_record.organization_id:
|
||||
target_org = next(
|
||||
(org for org in user_orgs if org.id == state_record.organization_id),
|
||||
None,
|
||||
)
|
||||
|
||||
if not target_org and len(user_orgs) == 1:
|
||||
target_org = user_orgs[0]
|
||||
|
||||
if not target_org and len(user_orgs) > 1:
|
||||
# Multiple orgs and none specified in the OAuth state — pick the one the
|
||||
# user joined most recently (highest created_at on their membership row).
|
||||
# Users can switch organisations inside the app after logging in.
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember as _OM
|
||||
latest_membership = (
|
||||
_OM.query
|
||||
.filter_by(user_id=user.id, deleted_at=None)
|
||||
.order_by(_OM.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
if latest_membership:
|
||||
target_org = latest_membership.organization
|
||||
else:
|
||||
target_org = user_orgs[0]
|
||||
|
||||
if not target_org and len(user_orgs) == 0:
|
||||
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
|
||||
from gatehouse_app.services.auth_service import AuthService as _AS
|
||||
_now = datetime.now(timezone.utc)
|
||||
_session = _AS.create_session(user=user, is_compliance_only=False)
|
||||
_session_dict = _session.to_dict()
|
||||
_session_dict["token"] = _session.token
|
||||
_expires_at = _session.expires_at
|
||||
if _expires_at.tzinfo is None:
|
||||
_expires_at = _expires_at.replace(tzinfo=timezone.utc)
|
||||
_session_dict["expires_in"] = int((_expires_at - _now).total_seconds())
|
||||
|
||||
_pending = OrgInviteToken.query.filter(
|
||||
OrgInviteToken.email == user.email,
|
||||
OrgInviteToken.accepted_at.is_(None),
|
||||
OrgInviteToken.expires_at > _now,
|
||||
OrgInviteToken.deleted_at.is_(None),
|
||||
).all()
|
||||
_pending_list = [
|
||||
{
|
||||
"token": inv.token,
|
||||
"organization": {"id": str(inv.organization_id), "name": inv.organization.name},
|
||||
"role": inv.role,
|
||||
"expires_at": inv.expires_at.isoformat(),
|
||||
}
|
||||
for inv in _pending
|
||||
]
|
||||
|
||||
state_record.mark_used()
|
||||
logger.info(
|
||||
f"OAuth login: user {user.id} has no org, redirecting to org-setup "
|
||||
f"(pending_invites={len(_pending_list)})"
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"flow_type": "login",
|
||||
"requires_org_creation": True,
|
||||
"user": {"id": user.id, "email": user.email, "full_name": user.full_name},
|
||||
"session": _session_dict,
|
||||
"pending_invites": _pending_list,
|
||||
"state": state_record.state,
|
||||
}
|
||||
|
||||
if not target_org:
|
||||
state_record.mark_used()
|
||||
logger.info(
|
||||
f"OAuth login requires org selection for user={user.id}, "
|
||||
f"provider={provider_type_str}, org_count={len(user_orgs)}"
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"flow_type": "login",
|
||||
"requires_org_selection": True,
|
||||
"user": {"id": user.id, "email": user.email, "full_name": user.full_name},
|
||||
"available_organizations": [
|
||||
{
|
||||
"id": org.id,
|
||||
"name": org.name,
|
||||
"slug": org.slug if hasattr(org, "slug") else None,
|
||||
}
|
||||
for org in user_orgs
|
||||
],
|
||||
"state": state_record.state,
|
||||
}
|
||||
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
session = AuthService.create_session(user=user, is_compliance_only=False)
|
||||
state_record.mark_used()
|
||||
|
||||
AuditService.log_external_auth_login(
|
||||
user_id=user.id,
|
||||
organization_id=target_org.id,
|
||||
provider_type=provider_type_str,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
auth_method_id=auth_method.id,
|
||||
session_id=session.id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth login successful for user={user.id}, "
|
||||
f"provider={provider_type_str}, org_id={target_org.id}"
|
||||
)
|
||||
|
||||
session_dict = session.to_dict()
|
||||
session_dict["token"] = session.token
|
||||
expires_at = session.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
now = datetime.now(timezone.utc)
|
||||
session_dict["expires_in"] = int((expires_at - now).total_seconds())
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"flow_type": "login",
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"organization_id": target_org.id,
|
||||
},
|
||||
"session": session_dict,
|
||||
}
|
||||
|
||||
except ExternalAuthError as e:
|
||||
logger.warning(
|
||||
f"OAuth login failed for state={state_record.id}, "
|
||||
f"provider={provider_type_str}, error={e.message}"
|
||||
)
|
||||
raise
|
||||
except OAuthFlowError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in OAuth login callback: {str(e)}", exc_info=True)
|
||||
raise OAuthFlowError("An unexpected error occurred during login", "INTERNAL_ERROR", 500)
|
||||
@@ -0,0 +1,248 @@
|
||||
"""Registration flow: initiate and handle OAuth register callback."""
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from gatehouse_app.models import User, AuthenticationMethod
|
||||
from gatehouse_app.models.auth.authentication_method import OAuthState
|
||||
from gatehouse_app.utils.constants import AuthMethodType, AuditAction
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
from gatehouse_app.services.external_auth import ExternalAuthService
|
||||
from gatehouse_app.services.external_auth.models import ExternalAuthError
|
||||
from gatehouse_app.services.oauth_flow.login import OAuthFlowError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def initiate_register_flow(
|
||||
provider_type: AuthMethodType,
|
||||
organization_id: str = None,
|
||||
redirect_uri: str = None,
|
||||
) -> Tuple[str, str]:
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
try:
|
||||
config = ExternalAuthService.get_provider_config(provider_type, organization_id)
|
||||
|
||||
if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri):
|
||||
raise OAuthFlowError("Invalid redirect URI", "INVALID_REDIRECT_URI", 400)
|
||||
|
||||
code_verifier = None
|
||||
code_challenge = None
|
||||
if provider_type_str not in ['google', 'microsoft']:
|
||||
code_verifier = secrets.token_urlsafe(32)
|
||||
code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier)
|
||||
|
||||
logger.info(
|
||||
f"[PKCE DEBUG] Register flow - Provider type check: provider_type_str='{provider_type_str}', "
|
||||
f"is_google={provider_type_str in ['google']}, "
|
||||
f"will_skip_pkce={provider_type_str in ['google', 'microsoft']}"
|
||||
)
|
||||
|
||||
state = OAuthState.create_state(
|
||||
flow_type="register",
|
||||
provider_type=provider_type,
|
||||
organization_id=organization_id,
|
||||
redirect_uri=redirect_uri or (config.redirect_uris[0] if config.redirect_uris else None),
|
||||
code_verifier=code_verifier,
|
||||
code_challenge=code_challenge,
|
||||
lifetime_seconds=600,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[PKCE DEBUG] Register flow - Created OAuthState:\n"
|
||||
f" state.id: {state.id}\n"
|
||||
f" state.code_challenge: {state.code_challenge}\n"
|
||||
f" state.code_verifier: {state.code_verifier[:20] if state.code_verifier else None}..."
|
||||
)
|
||||
|
||||
auth_url = ExternalAuthService._build_authorization_url(config=config, state=state)
|
||||
|
||||
logger.info(
|
||||
f"OAuth register flow initiated for provider={provider_type_str}, "
|
||||
f"org_id={organization_id}, state_id={state.id}"
|
||||
)
|
||||
logger.info(
|
||||
f"[PKCE DEBUG] Register flow - FINAL: auth_url_has_challenge={'code_challenge=' in auth_url}"
|
||||
)
|
||||
|
||||
return auth_url, state.state
|
||||
|
||||
except ExternalAuthError as e:
|
||||
AuditService.log_action(
|
||||
action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED,
|
||||
organization_id=organization_id,
|
||||
metadata={
|
||||
"provider_type": provider_type_str,
|
||||
"failure_reason": e.error_type,
|
||||
},
|
||||
description=f"OAuth registration initiation failed: {e.message}",
|
||||
success=False,
|
||||
error_message=e.message,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def handle_register_callback(
|
||||
provider_type: AuthMethodType,
|
||||
state_record: OAuthState,
|
||||
authorization_code: str,
|
||||
redirect_uri: str,
|
||||
) -> dict:
|
||||
from gatehouse_app.services.external_auth._helpers import _encrypt_provider_data
|
||||
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
try:
|
||||
config = ExternalAuthService.get_provider_config(
|
||||
provider_type, state_record.organization_id
|
||||
)
|
||||
|
||||
tokens = ExternalAuthService._exchange_code(
|
||||
config=config,
|
||||
code=authorization_code,
|
||||
redirect_uri=redirect_uri,
|
||||
code_verifier=state_record.code_verifier,
|
||||
)
|
||||
|
||||
user_info = ExternalAuthService._get_user_info(
|
||||
config=config,
|
||||
access_token=tokens["access_token"],
|
||||
)
|
||||
|
||||
existing_user = User.query.filter_by(email=user_info["email"]).first()
|
||||
if existing_user:
|
||||
raise OAuthFlowError(
|
||||
f"An account with email {user_info['email']} already exists. "
|
||||
"Please log in with your password and link your Google account from settings.",
|
||||
"EMAIL_EXISTS",
|
||||
400,
|
||||
)
|
||||
|
||||
user = User(
|
||||
email=user_info["email"],
|
||||
full_name=user_info.get("name", ""),
|
||||
status="active",
|
||||
email_verified=user_info.get("email_verified", False),
|
||||
)
|
||||
user.save()
|
||||
|
||||
auth_method = AuthenticationMethod(
|
||||
user_id=user.id,
|
||||
method_type=provider_type,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
provider_data=_encrypt_provider_data(tokens, user_info),
|
||||
verified=user_info.get("email_verified", False),
|
||||
is_primary=True,
|
||||
last_used_at=datetime.utcnow(),
|
||||
)
|
||||
auth_method.save()
|
||||
|
||||
state_record.mark_used()
|
||||
|
||||
AuditService.log_action(
|
||||
action="user.register",
|
||||
user_id=user.id,
|
||||
organization_id=state_record.organization_id,
|
||||
resource_type="user",
|
||||
resource_id=user.id,
|
||||
metadata={
|
||||
"provider_type": provider_type_str,
|
||||
"provider_user_id": user_info["provider_user_id"],
|
||||
"auth_method_id": auth_method.id,
|
||||
},
|
||||
description=f"User registered via {provider_type_str}",
|
||||
success=True,
|
||||
)
|
||||
|
||||
AuditService.log_external_auth_link_completed(
|
||||
user_id=user.id,
|
||||
organization_id=state_record.organization_id,
|
||||
provider_type=provider_type_str,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
auth_method_id=auth_method.id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"OAuth registration successful for email={user_info['email']}, "
|
||||
f"provider={provider_type_str}, user_id={user.id}"
|
||||
)
|
||||
|
||||
if state_record.organization_id:
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
org = Organization.query.get(state_record.organization_id)
|
||||
if org:
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
session = AuthService.create_session(user=user, is_compliance_only=False)
|
||||
session_dict = session.to_dict()
|
||||
session_dict["token"] = session.token
|
||||
expires_at = session.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
now = datetime.now(timezone.utc)
|
||||
session_dict["expires_in"] = int((expires_at - now).total_seconds())
|
||||
return {
|
||||
"success": True,
|
||||
"flow_type": "register",
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"organization_id": org.id,
|
||||
},
|
||||
"session": session_dict,
|
||||
}
|
||||
|
||||
from gatehouse_app.services.auth_service import AuthService as _AS
|
||||
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
|
||||
_session = _AS.create_session(user=user, is_compliance_only=False)
|
||||
_session_dict = _session.to_dict()
|
||||
_session_dict["token"] = _session.token
|
||||
_expires_at = _session.expires_at
|
||||
if _expires_at.tzinfo is None:
|
||||
_expires_at = _expires_at.replace(tzinfo=timezone.utc)
|
||||
_now = datetime.now(timezone.utc)
|
||||
_session_dict["expires_in"] = int((_expires_at - _now).total_seconds())
|
||||
|
||||
_pending = OrgInviteToken.query.filter(
|
||||
OrgInviteToken.email == user.email,
|
||||
OrgInviteToken.accepted_at.is_(None),
|
||||
OrgInviteToken.expires_at > _now,
|
||||
OrgInviteToken.deleted_at.is_(None),
|
||||
).all()
|
||||
_pending_list = [
|
||||
{
|
||||
"token": inv.token,
|
||||
"organization": {"id": str(inv.organization_id), "name": inv.organization.name},
|
||||
"role": inv.role,
|
||||
"expires_at": inv.expires_at.isoformat(),
|
||||
}
|
||||
for inv in _pending
|
||||
]
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"flow_type": "register",
|
||||
"requires_org_creation": True,
|
||||
"user": {"id": user.id, "email": user.email, "full_name": user.full_name},
|
||||
"session": _session_dict,
|
||||
"pending_invites": _pending_list,
|
||||
"state": state_record.state,
|
||||
}
|
||||
|
||||
except ExternalAuthError as e:
|
||||
logger.warning(
|
||||
f"OAuth registration failed for state={state_record.id}, "
|
||||
f"provider={provider_type_str}, error={e.message}"
|
||||
)
|
||||
raise
|
||||
except OAuthFlowError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in OAuth registration callback: {str(e)}", exc_info=True)
|
||||
raise OAuthFlowError(
|
||||
"An unexpected error occurred during registration",
|
||||
"INTERNAL_ERROR",
|
||||
500,
|
||||
)
|
||||
Reference in New Issue
Block a user