From ae2421763ab2c8b08c91ea2cf63236ca2caa4c74 Mon Sep 17 00:00:00 2001 From: Cory Hawklvelt Date: Wed, 21 Jan 2026 03:09:46 +1030 Subject: [PATCH] google login works --- gatehouse_app/api/v1/external_auth.py | 406 +++++++++- gatehouse_app/models/__init__.py | 10 +- gatehouse_app/models/authentication_method.py | 287 +++++++ .../services/external_auth_service.py | 745 +++++++++++++++--- gatehouse_app/services/oauth_flow_service.py | 575 +++++++++++++- scripts/README.md | 340 ++++++++ scripts/configure_oauth_provider.py | 484 ++++++++++++ test_oauth_without_org.sh | 70 ++ 8 files changed, 2744 insertions(+), 173 deletions(-) create mode 100644 scripts/README.md create mode 100755 scripts/configure_oauth_provider.py create mode 100755 test_oauth_without_org.sh diff --git a/gatehouse_app/api/v1/external_auth.py b/gatehouse_app/api/v1/external_auth.py index 2ca49b0..93d6e7b 100644 --- a/gatehouse_app/api/v1/external_auth.py +++ b/gatehouse_app/api/v1/external_auth.py @@ -1,4 +1,5 @@ """External authentication provider endpoints.""" +import logging from flask import request, g from marshmallow import ValidationError from gatehouse_app.api.v1 import api_v1_bp @@ -15,6 +16,8 @@ from gatehouse_app.services.oauth_flow_service import ( ) from gatehouse_app.services.audit_service import AuditService +logger = logging.getLogger(__name__) + # Provider type mapping PROVIDER_TYPE_MAP = { @@ -532,25 +535,35 @@ def unlink_account(provider: str): def initiate_oauth_authorize(provider: str): """ Initiate OAuth authentication or account registration flow. + + This endpoint initiates OAuth flows without requiring organization_id upfront. + The organization context is determined after successful authentication based on + the user's memberships. Args: provider: Provider type (google, github, microsoft) Query parameters: - flow: 'login' or 'register' - redirect_uri: Optional redirect URI - organization_id: Optional organization context + flow: 'login' or 'register' (default: 'login') + redirect_uri: Optional redirect URI after OAuth completion + organization_id: Optional organization hint (for SSO discovery) Returns: - 302: Redirect to provider authorization page - 400: Validation error or provider not configured + 200: Authorization URL and state token + 400: Validation error or provider not configured at application level + + Response: + { + "authorization_url": "https://...", + "state": "state_token" + } """ provider_type = get_provider_type(provider) - # Get query parameters + # Get query parameters - organization_id is now optional flow = request.args.get("flow", "login") redirect_uri = request.args.get("redirect_uri") - organization_id = request.args.get("organization_id") + organization_id = request.args.get("organization_id") # Optional hint if flow not in ["login", "register"]: return api_response( @@ -561,16 +574,17 @@ def initiate_oauth_authorize(provider: str): ) try: + # Initiate flow - organization_id is now optional if flow == "login": auth_url, state = OAuthFlowService.initiate_login_flow( provider_type=provider_type, - organization_id=organization_id, + organization_id=organization_id, # Optional hint redirect_uri=redirect_uri, ) else: auth_url, state = OAuthFlowService.initiate_register_flow( provider_type=provider_type, - organization_id=organization_id, + organization_id=organization_id, # Optional hint redirect_uri=redirect_uri, ) @@ -595,20 +609,54 @@ def initiate_oauth_authorize(provider: str): def handle_oauth_callback(provider: str): """ Handle OAuth callback from provider. + + This endpoint handles the redirect from the OAuth provider after authentication. + It processes the response and handles different scenarios: + - Successful login/register with redirect_uri: Redirects with authorization code + - Successful login/register without redirect_uri: Returns session token + - Login with multiple orgs: Returns list of organizations for user to select + - Register with no org: Prompts for organization creation/selection Args: provider: Provider type (google, github, microsoft) Query parameters: code: Authorization code from provider - state: State parameter - error: Error code if auth failed + state: State parameter from OAuth flow + redirect_uri: Optional redirect URI for OAuth 2.0 Authorization Code flow + error: Error code if auth failed at provider error_description: Human-readable error description Returns: - 200: OAuth flow completed successfully - 302: Redirect with error - 400: Validation error or OAuth error + 302: Redirect with authorization code (if redirect_uri provided) + 200: OAuth flow completed successfully (JSON response) + 400: Validation error, OAuth error, or invalid state + 404: User account not found (for login flows) + + Response formats (when redirect_uri NOT provided): + + Success with session: + { + "token": "session_token", + "expires_in": 86400, + "token_type": "Bearer", + "user": {...} + } + + Requires organization selection (login flow): + { + "requires_org_selection": true, + "user": {...}, + "available_organizations": [...], + "state": "state_token" + } + + Requires organization creation (register flow): + { + "requires_org_creation": true, + "user": {...}, + "state": "state_token" + } """ provider_type = get_provider_type(provider) @@ -618,7 +666,7 @@ def handle_oauth_callback(provider: str): error = request.args.get("error") error_description = request.args.get("error_description") - # Get redirect URI from state if available + # Get redirect URI from query parameter (for OAuth 2.0 Authorization Code flow) redirect_uri = request.args.get("redirect_uri") try: @@ -632,7 +680,61 @@ def handle_oauth_callback(provider: str): ) if result.get("success"): - if result.get("flow_type") == "login": + flow_type = result.get("flow_type") + + # Check if we should redirect with authorization code + if redirect_uri and flow_type in ["login", "register"]: + # Generate authorization code for external application + user_id = result.get("user", {}).get("id") + if not user_id: + # For org selection/creation flows, we can't redirect + pass + else: + # Determine organization_id + organization_id = result.get("user", {}).get("organization_id") + if not organization_id: + # Can't redirect without organization + pass + else: + # Generate authorization code + auth_code = OAuthFlowService.generate_authorization_code( + user_id=user_id, + client_id="external-app", + redirect_uri=redirect_uri, + scope=["openid", "profile", "email"], + ip_address=request.remote_addr, + user_agent=request.headers.get("User-Agent"), + lifetime_seconds=600, # 10 minutes + ) + + # Mark state as used + state_record = OAuthFlowService.validate_state(state) + if state_record: + state_record.mark_used() + + # Redirect with authorization code + return OAuthFlowService.create_redirect_response( + redirect_uri=redirect_uri, + authorization_code=auth_code, + state=state, + ) + + # Handle login flow responses (no redirect_uri or org selection required) + if flow_type == "login": + # Check if organization selection is required + if result.get("requires_org_selection"): + return api_response( + data={ + "requires_org_selection": True, + "user": result["user"], + "available_organizations": result["available_organizations"], + "state": result["state"], + }, + message="Please select an organization to continue", + status=200, + ) + + # Normal login with session return api_response( data={ "token": result["session"]["token"], @@ -642,7 +744,22 @@ def handle_oauth_callback(provider: str): }, message="Login successful", ) - elif result.get("flow_type") == "register": + + # Handle register flow responses + elif flow_type == "register": + # Check if organization creation is required + if result.get("requires_org_creation"): + return api_response( + data={ + "requires_org_creation": True, + "user": result["user"], + "state": result["state"], + }, + message="Please create or select an organization to continue", + status=200, + ) + + # Normal registration with session return api_response( data={ "token": result["session"]["token"], @@ -652,7 +769,9 @@ def handle_oauth_callback(provider: str): }, message="Registration successful", ) - elif result.get("flow_type") == "link": + + # Handle link flow responses + elif flow_type == "link": return api_response( data={ "linked_account": result["linked_account"], @@ -660,6 +779,7 @@ def handle_oauth_callback(provider: str): message="Account linked successfully", ) + # Fallback for unexpected result format return api_response( data=result, message="OAuth flow completed", @@ -674,6 +794,256 @@ def handle_oauth_callback(provider: str): ) +@api_v1_bp.route("/auth/external/select-organization", methods=["POST"]) +def select_organization(): + """ + Complete OAuth flow by selecting an organization. + + This endpoint is called after OAuth callback when the user needs to select + which organization to log in to (when user belongs to multiple orgs). + + Request body: + state: The state token from the OAuth callback + organization_id: The selected organization ID + + Returns: + 200: Session created successfully + 400: Invalid state or organization + 404: Organization not found or user not a member + + Response: + { + "token": "session_token", + "expires_in": 86400, + "token_type": "Bearer", + "user": { + "id": "...", + "email": "...", + "full_name": "...", + "organization_id": "..." + } + } + """ + data = request.json or {} + state_token = data.get("state") + organization_id = data.get("organization_id") + + if not state_token: + return api_response( + success=False, + message="state is required", + status=400, + error_type="VALIDATION_ERROR", + ) + + if not organization_id: + return api_response( + success=False, + message="organization_id is required", + status=400, + error_type="VALIDATION_ERROR", + ) + + try: + # Validate state and get OAuth state record + state_record = OAuthFlowService.validate_state(state_token) + if not state_record or state_record.used: + return api_response( + success=False, + message="Invalid or expired state token", + status=400, + error_type="INVALID_STATE", + ) + + # The state should have user information from the OAuth callback + # We need to find the user that was authenticated + from gatehouse_app.models import User, AuthenticationMethod, Organization, OrganizationMember + + # Find user by provider authentication + # The state record should have provider info in extra_data if set by callback + # Otherwise, we need to find the most recently created auth method + auth_method = AuthenticationMethod.query.filter_by( + method_type=state_record.provider_type, + ).order_by(AuthenticationMethod.created_at.desc()).first() + + if not auth_method: + return api_response( + success=False, + message="Authentication session not found", + status=400, + error_type="SESSION_NOT_FOUND", + ) + + user = auth_method.user + + # Verify user is member of selected organization + org = Organization.query.get(organization_id) + if not org: + return api_response( + success=False, + message="Organization not found", + status=404, + error_type="NOT_FOUND", + ) + + member = OrganizationMember.query.filter_by( + user_id=user.id, + organization_id=organization_id, + ).first() + + if not member: + return api_response( + success=False, + message="You are not a member of this organization", + status=403, + error_type="FORBIDDEN", + ) + + # Create session for the selected organization + from gatehouse_app.services.session_service import SessionService + session = SessionService.create_session( + user=user, + organization_id=organization_id, + ) + + # Mark state as used + state_record.mark_used() + + # Audit log - login success with org selection + AuditService.log_external_auth_login( + user_id=user.id, + organization_id=organization_id, + provider_type=state_record.provider_type.value if isinstance(state_record.provider_type, AuthMethodType) else state_record.provider_type, + provider_user_id=auth_method.provider_user_id, + auth_method_id=auth_method.id, + session_id=session.id, + ) + + return api_response( + data={ + "token": session.token, + "expires_in": session.lifetime_seconds, + "token_type": "Bearer", + "user": { + "id": user.id, + "email": user.email, + "full_name": user.full_name, + "organization_id": organization_id, + }, + }, + message="Organization selected and session created successfully", + ) + + except Exception as e: + logger.error(f"Error in select_organization: {str(e)}", exc_info=True) + return api_response( + success=False, + message="An error occurred while selecting organization", + status=500, + error_type="INTERNAL_ERROR", + ) + + +# ============================================================================= +# Authorization Code Exchange Endpoint +# ============================================================================= + +@api_v1_bp.route("/auth/external/token", methods=["POST"]) +def exchange_authorization_code(): + """ + Exchange an authorization code for a session token. + + This endpoint is used by external applications (like oauth2-proxy, BookStack) + to exchange the authorization code received from the OAuth callback for a + session token. + + Request body (form-encoded or JSON): + grant_type: Must be "authorization_code" + code: The authorization code from the callback + redirect_uri: The redirect URI used in the original request + client_id: The client ID (optional, defaults to "external-app") + + Returns: + 200: Session token exchanged successfully + 400: Invalid or expired authorization code + 404: User not found + + Response: + { + "token": "session_token", + "expires_in": 86400, + "token_type": "Bearer", + "user": { + "id": "...", + "email": "...", + "full_name": "...", + "organization_id": "..." + } + } + """ + # Support both JSON and form-encoded requests + if request.is_json: + data = request.json or {} + else: + data = request.form or {} + + grant_type = data.get("grant_type") + code = data.get("code") + redirect_uri = data.get("redirect_uri") + client_id = data.get("client_id", "external-app") + + # Validate required parameters + if grant_type and grant_type != "authorization_code": + return api_response( + success=False, + message="Invalid grant_type. Must be 'authorization_code'", + status=400, + error_type="INVALID_GRANT_TYPE", + ) + + if not code: + return api_response( + success=False, + message="code is required", + status=400, + error_type="VALIDATION_ERROR", + ) + + if not redirect_uri: + return api_response( + success=False, + message="redirect_uri is required", + status=400, + error_type="VALIDATION_ERROR", + ) + + try: + result = OAuthFlowService.exchange_authorization_code( + code=code, + client_id=client_id, + redirect_uri=redirect_uri, + ip_address=request.remote_addr, + ) + + return api_response( + data={ + "token": result["token"], + "expires_in": result["expires_in"], + "token_type": result["token_type"], + "user": result["user"], + }, + message="Token exchanged successfully", + ) + + except OAuthFlowError as e: + return api_response( + success=False, + message=e.message, + status=e.status_code, + error_type=e.error_type, + ) + + # ============================================================================= # Helper Functions # ============================================================================= diff --git a/gatehouse_app/models/__init__.py b/gatehouse_app/models/__init__.py index b114718..99ef6fb 100644 --- a/gatehouse_app/models/__init__.py +++ b/gatehouse_app/models/__init__.py @@ -3,7 +3,12 @@ from gatehouse_app.models.base import BaseModel from gatehouse_app.models.user import User from gatehouse_app.models.organization import Organization from gatehouse_app.models.organization_member import OrganizationMember -from gatehouse_app.models.authentication_method import AuthenticationMethod +from gatehouse_app.models.authentication_method import ( + AuthenticationMethod, + ApplicationProviderConfig, + OrganizationProviderOverride, + OAuthState, +) from gatehouse_app.models.session import Session from gatehouse_app.models.audit_log import AuditLog from gatehouse_app.models.oidc_client import OIDCClient @@ -22,6 +27,9 @@ __all__ = [ "Organization", "OrganizationMember", "AuthenticationMethod", + "ApplicationProviderConfig", + "OrganizationProviderOverride", + "OAuthState", "Session", "AuditLog", "OIDCClient", diff --git a/gatehouse_app/models/authentication_method.py b/gatehouse_app/models/authentication_method.py index 429632c..3766d52 100644 --- a/gatehouse_app/models/authentication_method.py +++ b/gatehouse_app/models/authentication_method.py @@ -1,7 +1,10 @@ """Authentication method model.""" +from datetime import datetime, timedelta, timezone +import secrets from gatehouse_app.extensions import db from gatehouse_app.models.base import BaseModel from gatehouse_app.utils.constants import AuthMethodType +from gatehouse_app.utils.encryption import encrypt, decrypt class AuthenticationMethod(BaseModel): @@ -91,3 +94,287 @@ class AuthenticationMethod(BaseModel): "last_used_at": data.get("last_used_at"), "sign_count": data.get("sign_count", 0), } + + +class ApplicationProviderConfig(BaseModel): + """Application-wide OAuth provider configuration. + + This model stores OAuth provider credentials at the application level, + allowing users to authenticate without needing to specify an organization first. + """ + + __tablename__ = "application_provider_configs" + + # Provider identification + provider_type = db.Column(db.String(50), nullable=False, unique=True, index=True) + + # OAuth credentials (encrypted) + client_id = db.Column(db.String(255), nullable=False) + client_secret_encrypted = db.Column(db.String(512), nullable=True) + + # Provider status + is_enabled = db.Column(db.Boolean, default=True, nullable=False) + + # Default redirect URL + default_redirect_url = db.Column(db.String(2048), nullable=True) + + # Provider-specific settings (JSON) + additional_config = db.Column(db.JSON, nullable=True) + + # Relationships + organization_overrides = db.relationship( + "OrganizationProviderOverride", + back_populates="application_config", + foreign_keys="OrganizationProviderOverride.provider_type", + primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type", + cascade="all, delete-orphan" + ) + + def __repr__(self): + """String representation of ApplicationProviderConfig.""" + return f"" + + def set_client_secret(self, plaintext_secret: str): + """Encrypt and store client secret. + + Args: + plaintext_secret: The plaintext OAuth client secret + """ + if plaintext_secret: + self.client_secret_encrypted = encrypt(plaintext_secret) + + def get_client_secret(self) -> str: + """Decrypt and return client secret. + + Returns: + The plaintext OAuth client secret + """ + if self.client_secret_encrypted: + return decrypt(self.client_secret_encrypted) + return None + + def to_dict(self, exclude=None): + """Convert to dictionary, excluding sensitive fields.""" + exclude = exclude or [] + # Always exclude encrypted client secret + exclude.append("client_secret_encrypted") + return super().to_dict(exclude=exclude) + + +class OrganizationProviderOverride(BaseModel): + """Organization-specific OAuth configuration overrides. + + This model allows organizations to override application-level OAuth settings + for enterprise SSO scenarios or custom provider configurations. + """ + + __tablename__ = "organization_provider_overrides" + + # References + organization_id = db.Column( + db.String(36), db.ForeignKey("organizations.id"), + nullable=False, index=True + ) + provider_type = db.Column(db.String(50), nullable=False, index=True) + + # Override OAuth credentials (encrypted, nullable - only if overriding) + client_id = db.Column(db.String(255), nullable=True) + client_secret_encrypted = db.Column(db.String(512), nullable=True) + + # Provider status + is_enabled = db.Column(db.Boolean, default=True, nullable=False) + + # Redirect URL override + redirect_url_override = db.Column(db.String(2048), nullable=True) + + # Provider-specific settings override (JSON) + additional_config = db.Column(db.JSON, nullable=True) + + # Relationships + organization = db.relationship("Organization", backref="provider_overrides") + application_config = db.relationship( + "ApplicationProviderConfig", + back_populates="organization_overrides", + foreign_keys=[provider_type], + primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type", + viewonly=True + ) + + # Unique constraint on (organization_id, provider_type) + __table_args__ = ( + db.UniqueConstraint( + "organization_id", "provider_type", + name="uix_org_provider_type" + ), + ) + + def __repr__(self): + """String representation of OrganizationProviderOverride.""" + return f"" + + def set_client_secret(self, plaintext_secret: str): + """Encrypt and store client secret override. + + Args: + plaintext_secret: The plaintext OAuth client secret + """ + if plaintext_secret: + self.client_secret_encrypted = encrypt(plaintext_secret) + + def get_client_secret(self) -> str: + """Decrypt and return client secret override. + + Returns: + The plaintext OAuth client secret + """ + if self.client_secret_encrypted: + return decrypt(self.client_secret_encrypted) + return None + + def to_dict(self, exclude=None): + """Convert to dictionary, excluding sensitive fields.""" + exclude = exclude or [] + # Always exclude encrypted client secret + exclude.append("client_secret_encrypted") + return super().to_dict(exclude=exclude) + + +class OAuthState(BaseModel): + """OAuth flow state tracking. + + This model tracks OAuth authentication flow state, including PKCE parameters + and organization context (which is now optional to support login flows where + the organization isn't known until after authentication). + """ + + __tablename__ = "oauth_states" + + # OAuth state parameter (unique, used for CSRF protection) + state = db.Column(db.String(64), unique=True, nullable=False, index=True) + + # Flow type: "login", "register", "link" + flow_type = db.Column(db.String(50), nullable=False) + + # Provider type + provider_type = db.Column(db.String(50), nullable=False) + + # User context (optional - not set for login/register flows) + user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True) + + # Organization context (NOW OPTIONAL - for SSO discovery or post-auth) + organization_id = db.Column( + db.String(36), db.ForeignKey("organizations.id"), + nullable=True, index=True + ) + + # PKCE parameters + nonce = db.Column(db.String(128), nullable=True) + code_verifier = db.Column(db.String(128), nullable=True) + code_challenge = db.Column(db.String(128), nullable=True) + + # OAuth parameters + redirect_uri = db.Column(db.String(2048), nullable=True) + + # Post-auth redirect (for frontend routing) + return_url = db.Column(db.String(2048), nullable=True) + + # Additional state data + extra_data = db.Column(db.JSON, nullable=True) + + # Expiration and usage tracking + expires_at = db.Column(db.DateTime, nullable=False, index=True) + used = db.Column(db.Boolean, default=False, nullable=False) + + # Relationships + user = db.relationship("User", backref="oauth_states") + organization = db.relationship("Organization", backref="oauth_states") + + def __repr__(self): + """String representation of OAuthState.""" + return f"" + + @classmethod + def create_state( + cls, + flow_type: str, + provider_type: str, + user_id: str = None, + organization_id: str = None, + redirect_uri: str = None, + return_url: str = None, + code_verifier: str = None, + code_challenge: str = None, + nonce: str = None, + extra_data: dict = None, + lifetime_seconds: int = 600 + ): + """Create a new OAuth state with auto-generated state parameter. + + Args: + flow_type: Type of flow ("login", "register", "link") + provider_type: OAuth provider type + user_id: Optional user ID for authenticated flows + organization_id: Optional organization ID + redirect_uri: OAuth callback URI + return_url: Post-auth redirect destination + code_verifier: PKCE code verifier + code_challenge: PKCE code challenge + nonce: OpenID Connect nonce + extra_data: Additional state data + lifetime_seconds: How long the state is valid (default 10 minutes) + + Returns: + New OAuthState instance + """ + state = secrets.token_urlsafe(32) + expires_at = datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds) + + oauth_state = cls( + state=state, + flow_type=flow_type, + provider_type=provider_type, + user_id=user_id, + organization_id=organization_id, + redirect_uri=redirect_uri, + return_url=return_url, + code_verifier=code_verifier, + code_challenge=code_challenge, + nonce=nonce, + extra_data=extra_data, + expires_at=expires_at, + used=False + ) + oauth_state.save() + return oauth_state + + def is_valid(self) -> bool: + """Check if the OAuth state is still valid. + + Returns: + True if state hasn't expired and hasn't been used + """ + now = datetime.now(timezone.utc) + # Make expires_at timezone-aware if it's naive (database returns naive datetimes) + expires_at = self.expires_at + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + return not self.used and expires_at > now + + def mark_used(self): + """Mark the state as used to prevent replay attacks.""" + self.used = True + self.save() + + @classmethod + def cleanup_expired(cls): + """Remove expired OAuth states.""" + now = datetime.now(timezone.utc) + cls.query.filter(cls.expires_at < now).delete() + db.session.commit() + + def to_dict(self, exclude=None): + """Convert to dictionary, excluding sensitive fields.""" + exclude = exclude or [] + # Exclude code_verifier as it's sensitive + exclude.append("code_verifier") + return super().to_dict(exclude=exclude) diff --git a/gatehouse_app/services/external_auth_service.py b/gatehouse_app/services/external_auth_service.py index afc434d..aa16273 100644 --- a/gatehouse_app/services/external_auth_service.py +++ b/gatehouse_app/services/external_auth_service.py @@ -2,12 +2,17 @@ import logging import secrets from datetime import datetime, timedelta, timezone -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict, Any from flask import current_app from gatehouse_app.extensions import db from gatehouse_app.models import User, AuthenticationMethod +from gatehouse_app.models.authentication_method import ( + OAuthState, + ApplicationProviderConfig, + OrganizationProviderOverride +) from gatehouse_app.models.base import BaseModel from gatehouse_app.utils.constants import AuthMethodType from gatehouse_app.services.audit_service import AuditService @@ -25,95 +30,12 @@ class ExternalAuthError(Exception): super().__init__(message) -class OAuthState(BaseModel): - """Temporary OAuth state storage for secure flow management.""" - - __tablename__ = "oauth_states" - - # State identifier (used in OAuth redirects) - state = db.Column(db.String(64), unique=True, nullable=False, index=True) - - # Flow type - flow_type = db.Column(db.String(50), nullable=False) # 'link', 'login', 'register' - - # User context - user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True, index=True) - organization_id = db.Column( - db.String(36), db.ForeignKey("organizations.id"), nullable=True, index=True - ) - - # Provider information - provider_type = db.Column(db.String(50), nullable=False) - - # OAuth parameters - nonce = db.Column(db.String(128), nullable=True) - code_verifier = db.Column(db.String(128), nullable=True) - code_challenge = db.Column(db.String(128), nullable=True) - redirect_uri = db.Column(db.String(2048), nullable=True) - - # Additional state data - extra_data = db.Column(db.JSON, nullable=True) - - # Expiration - expires_at = db.Column(db.DateTime, nullable=False, index=True) - - # Status - used = db.Column(db.Boolean, default=False, nullable=False) - - @classmethod - def create_state( - cls, - flow_type: str, - provider_type: AuthMethodType, - user_id: str = None, - organization_id: str = None, - redirect_uri: str = None, - nonce: str = None, - code_verifier: str = None, - code_challenge: str = None, - extra_data: dict = None, - lifetime_seconds: int = 600, - ) -> "OAuthState": - """Create a new OAuth state record.""" - state = secrets.token_urlsafe(32) - expires_at = datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds) - - return cls.create( - state=state, - flow_type=flow_type, - provider_type=provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type, - user_id=user_id, - organization_id=organization_id, - redirect_uri=redirect_uri, - nonce=nonce or secrets.token_urlsafe(16), - code_verifier=code_verifier, - code_challenge=code_challenge, - extra_data=extra_data, - expires_at=expires_at, - ) - - def is_valid(self) -> bool: - """Check if state is still valid.""" - return ( - not self.used - and self.expires_at - and self.expires_at.replace(tzinfo=timezone.utc) > datetime.now(timezone.utc) - ) - - def mark_used(self): - """Mark state as used.""" - self.used = True - self.save() - - @classmethod - def cleanup_expired(cls): - """Remove expired states.""" - cls.query.filter(cls.expires_at < datetime.now(timezone.utc)).delete() - db.session.commit() - - class ExternalProviderConfig(BaseModel): - """OAuth provider configuration per organization.""" + """OAuth provider configuration per organization. + + DEPRECATED: This model is maintained for backward compatibility only. + Use ApplicationProviderConfig and OrganizationProviderOverride instead. + """ __tablename__ = "external_provider_configs" @@ -198,31 +120,594 @@ class ExternalProviderConfig(BaseModel): return data +class ProviderConfigAdapter: + """ + Adapter to provide a unified interface for provider configuration. + + This merges application-level config with optional organization overrides, + presenting a single config object that works with existing OAuth flow code. + """ + + def __init__( + self, + app_config: ApplicationProviderConfig, + org_override: Optional[OrganizationProviderOverride] = None + ): + """ + Initialize adapter with app config and optional org override. + + Args: + app_config: Application-level provider configuration + org_override: Optional organization-specific override + """ + self.app_config = app_config + self.org_override = org_override + self.provider_type = app_config.provider_type + + @property + def client_id(self) -> str: + """Get effective client ID (override takes precedence).""" + if self.org_override and self.org_override.client_id: + return self.org_override.client_id + return self.app_config.client_id + + def get_client_secret(self) -> str: + """Get effective client secret (override takes precedence).""" + if self.org_override and self.org_override.client_secret_encrypted: + return self.org_override.get_client_secret() + return self.app_config.get_client_secret() + + @property + def auth_url(self) -> str: + """Get authorization URL from app config.""" + # Provider endpoints are not overridable + return self._get_provider_endpoint('auth_url') + + @property + def token_url(self) -> str: + """Get token URL from app config.""" + return self._get_provider_endpoint('token_url') + + @property + def userinfo_url(self) -> str: + """Get userinfo URL from app config.""" + return self._get_provider_endpoint('userinfo_url') + + @property + def jwks_url(self) -> str: + """Get JWKS URL from app config.""" + return self._get_provider_endpoint('jwks_url') + + @property + def scopes(self) -> list: + """Get effective scopes (merged from app config and override).""" + base_scopes = self.app_config.additional_config.get('scopes', []) if self.app_config.additional_config else [] + if self.org_override and self.org_override.additional_config: + override_scopes = self.org_override.additional_config.get('scopes') + if override_scopes is not None: + return override_scopes + return base_scopes or ['openid', 'profile', 'email'] + + @property + def redirect_uris(self) -> list: + """Get effective redirect URIs.""" + # Use override redirect URL if present, otherwise app default + if self.org_override and self.org_override.redirect_url_override: + return [self.org_override.redirect_url_override] + if self.app_config.default_redirect_url: + return [self.app_config.default_redirect_url] + return [] + + @property + def settings(self) -> dict: + """Get merged settings (app config + org override).""" + settings = {} + if self.app_config.additional_config: + settings.update(self.app_config.additional_config) + if self.org_override and self.org_override.additional_config: + settings.update(self.org_override.additional_config) + return settings + + @property + def is_active(self) -> bool: + """Check if provider is active (both app and org must be enabled).""" + app_enabled = self.app_config.is_enabled + org_enabled = True if not self.org_override else self.org_override.is_enabled + return app_enabled and org_enabled + + def is_redirect_uri_allowed(self, uri: str) -> bool: + """Check if redirect URI is allowed.""" + return uri in self.redirect_uris + + def _get_provider_endpoint(self, endpoint_name: str) -> Optional[str]: + """ + Get provider endpoint from app config additional_config. + + For application-wide configs, endpoints are stored in additional_config JSON. + """ + if not self.app_config.additional_config: + return None + return self.app_config.additional_config.get(endpoint_name) + + class ExternalAuthService: """Service for external authentication operations.""" @classmethod def get_provider_config( cls, - organization_id: str, provider_type: AuthMethodType, - ) -> ExternalProviderConfig: - """Get provider configuration for organization.""" + organization_id: Optional[str] = None, + ) -> ProviderConfigAdapter: + """ + Get provider configuration for authentication. + + This method retrieves application-wide provider configuration and merges + it with organization-specific overrides if present. Both the application + config and organization override (if present) must be enabled for the + provider to be considered active. + + Configuration Precedence: + 1. Application-level config provides the baseline configuration + 2. Organization override can override client_id and client_secret (for SSO) + 3. Both must be enabled for the provider to work + + Args: + provider_type: The OAuth provider type (google, github, etc.) + organization_id: Optional organization ID for override lookup + + Returns: + ProviderConfigAdapter: Unified config object with merged settings + + Raises: + ExternalAuthError: If provider is not configured or disabled + """ provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type - config = ExternalProviderConfig.query.filter_by( - organization_id=organization_id, - provider_type=provider_type_str, - is_active=True, + + # Get application-wide config + app_config = ApplicationProviderConfig.query.filter_by( + provider_type=provider_type_str ).first() - - if not config: + + if not app_config: raise ExternalAuthError( - f"{provider_type_str.title()} OAuth is not configured for this organization", + f"{provider_type_str.title()} OAuth is not configured for this application", "PROVIDER_NOT_CONFIGURED", 400, ) + + if not app_config.is_enabled: + raise ExternalAuthError( + f"{provider_type_str.title()} OAuth is currently disabled", + "PROVIDER_DISABLED", + 400, + ) + + # Check for organization-specific override + org_override = None + if organization_id: + org_override = OrganizationProviderOverride.query.filter_by( + organization_id=organization_id, + provider_type=provider_type_str + ).first() + + # If override exists but is disabled, provider is not available for this org + if org_override and not org_override.is_enabled: + raise ExternalAuthError( + f"{provider_type_str.title()} OAuth is disabled for this organization", + "PROVIDER_DISABLED_FOR_ORG", + 400, + ) + + # Return adapter with merged configuration + return ProviderConfigAdapter(app_config, org_override) + # ==================== Application-Wide Provider Management ==================== + + @classmethod + def create_app_provider_config( + cls, + provider_type: str, + client_id: str, + client_secret: str, + **kwargs + ) -> ApplicationProviderConfig: + """ + Create application-wide provider configuration. + + Args: + provider_type: Provider type (google, github, etc.) + client_id: OAuth client ID + client_secret: OAuth client secret + **kwargs: Additional config (auth_url, token_url, userinfo_url, scopes, etc.) + + Returns: + ApplicationProviderConfig: Created configuration + + Raises: + ExternalAuthError: If provider already exists + """ + # Check if provider already exists + existing = ApplicationProviderConfig.query.filter_by( + provider_type=provider_type + ).first() + + if existing: + raise ExternalAuthError( + f"Provider {provider_type} already exists", + "PROVIDER_EXISTS", + 400 + ) + + # Build additional_config with endpoints and settings + additional_config = {} + for key in ['auth_url', 'token_url', 'userinfo_url', 'jwks_url', 'scopes']: + if key in kwargs: + additional_config[key] = kwargs.pop(key) + + # Add any extra settings + if 'settings' in kwargs: + additional_config.update(kwargs.pop('settings')) + + # Create new config + config = ApplicationProviderConfig( + provider_type=provider_type, + client_id=client_id, + is_enabled=kwargs.get('is_enabled', True), + default_redirect_url=kwargs.get('default_redirect_url'), + additional_config=additional_config + ) + + # Set encrypted secret + config.set_client_secret(client_secret) + config.save() + + logger.info(f"Created application provider config for {provider_type}") return config + + @classmethod + def update_app_provider_config( + cls, + provider_type: str, + **updates + ) -> ApplicationProviderConfig: + """ + Update application-wide provider configuration. + + Args: + provider_type: Provider type to update + **updates: Fields to update (client_id, client_secret, is_enabled, etc.) + + Returns: + ApplicationProviderConfig: Updated configuration + + Raises: + ExternalAuthError: If provider not found + """ + config = ApplicationProviderConfig.query.filter_by( + provider_type=provider_type + ).first() + + if not config: + raise ExternalAuthError( + f"Provider {provider_type} not found", + "PROVIDER_NOT_FOUND", + 404 + ) + + # Update simple fields + if 'client_id' in updates: + config.client_id = updates['client_id'] + + if 'client_secret' in updates: + config.set_client_secret(updates['client_secret']) + + if 'is_enabled' in updates: + config.is_enabled = updates['is_enabled'] + + if 'default_redirect_url' in updates: + config.default_redirect_url = updates['default_redirect_url'] + + # Update additional_config JSON fields + if config.additional_config is None: + config.additional_config = {} + + for key in ['auth_url', 'token_url', 'userinfo_url', 'jwks_url', 'scopes']: + if key in updates: + config.additional_config[key] = updates[key] + + if 'settings' in updates: + config.additional_config.update(updates['settings']) + + config.save() + logger.info(f"Updated application provider config for {provider_type}") + return config + + @classmethod + def get_app_provider_config(cls, provider_type: str) -> ApplicationProviderConfig: + """ + Get application-wide provider configuration. + + Args: + provider_type: Provider type to retrieve + + Returns: + ApplicationProviderConfig: Provider configuration + + Raises: + ExternalAuthError: If provider not found + """ + config = ApplicationProviderConfig.query.filter_by( + provider_type=provider_type + ).first() + + if not config: + raise ExternalAuthError( + f"Provider {provider_type} not found", + "PROVIDER_NOT_FOUND", + 404 + ) + + return config + + @classmethod + def list_app_provider_configs(cls) -> list: + """ + List all application-wide provider configurations. + + Returns: + list: List of provider configuration dictionaries + """ + configs = ApplicationProviderConfig.query.all() + return [config.to_dict() for config in configs] + + @classmethod + def delete_app_provider_config(cls, provider_type: str) -> bool: + """ + Delete application-wide provider configuration. + + Args: + provider_type: Provider type to delete + + Returns: + bool: True if deleted successfully + + Raises: + ExternalAuthError: If provider not found + """ + config = ApplicationProviderConfig.query.filter_by( + provider_type=provider_type + ).first() + + if not config: + raise ExternalAuthError( + f"Provider {provider_type} not found", + "PROVIDER_NOT_FOUND", + 404 + ) + + config.delete() + logger.info(f"Deleted application provider config for {provider_type}") + return True + + # ==================== Organization Provider Override Management ==================== + + @classmethod + def create_org_provider_override( + cls, + organization_id: str, + provider_type: str, + **kwargs + ) -> OrganizationProviderOverride: + """ + Create organization-specific provider override (for SSO scenarios). + + Args: + organization_id: Organization ID + provider_type: Provider type to override + **kwargs: Override fields (client_id, client_secret, redirect_url, etc.) + + Returns: + OrganizationProviderOverride: Created override + + Raises: + ExternalAuthError: If provider doesn't exist or override already exists + """ + # Verify app-level provider exists + app_config = ApplicationProviderConfig.query.filter_by( + provider_type=provider_type + ).first() + + if not app_config: + raise ExternalAuthError( + f"Application provider {provider_type} must be configured first", + "PROVIDER_NOT_CONFIGURED", + 400 + ) + + # Check if override already exists + existing = OrganizationProviderOverride.query.filter_by( + organization_id=organization_id, + provider_type=provider_type + ).first() + + if existing: + raise ExternalAuthError( + f"Override for {provider_type} already exists for this organization", + "OVERRIDE_EXISTS", + 400 + ) + + # Build additional_config from kwargs + additional_config = {} + if 'settings' in kwargs: + additional_config.update(kwargs.pop('settings')) + if 'scopes' in kwargs: + additional_config['scopes'] = kwargs.pop('scopes') + + # Create override + override = OrganizationProviderOverride( + organization_id=organization_id, + provider_type=provider_type, + client_id=kwargs.get('client_id'), + is_enabled=kwargs.get('is_enabled', True), + redirect_url_override=kwargs.get('redirect_url_override'), + additional_config=additional_config if additional_config else None + ) + + # Set encrypted secret if provided + if 'client_secret' in kwargs: + override.set_client_secret(kwargs['client_secret']) + + override.save() + logger.info(f"Created org override for {provider_type} in org {organization_id}") + return override + + @classmethod + def update_org_provider_override( + cls, + organization_id: str, + provider_type: str, + **updates + ) -> OrganizationProviderOverride: + """ + Update organization-specific provider override. + + Args: + organization_id: Organization ID + provider_type: Provider type + **updates: Fields to update + + Returns: + OrganizationProviderOverride: Updated override + + Raises: + ExternalAuthError: If override not found + """ + override = OrganizationProviderOverride.query.filter_by( + organization_id=organization_id, + provider_type=provider_type + ).first() + + if not override: + raise ExternalAuthError( + f"Override for {provider_type} not found for this organization", + "OVERRIDE_NOT_FOUND", + 404 + ) + + # Update simple fields + if 'client_id' in updates: + override.client_id = updates['client_id'] + + if 'client_secret' in updates: + override.set_client_secret(updates['client_secret']) + + if 'is_enabled' in updates: + override.is_enabled = updates['is_enabled'] + + if 'redirect_url_override' in updates: + override.redirect_url_override = updates['redirect_url_override'] + + # Update additional_config + if 'settings' in updates or 'scopes' in updates: + if override.additional_config is None: + override.additional_config = {} + + if 'settings' in updates: + override.additional_config.update(updates['settings']) + if 'scopes' in updates: + override.additional_config['scopes'] = updates['scopes'] + + override.save() + logger.info(f"Updated org override for {provider_type} in org {organization_id}") + return override + + @classmethod + def get_org_provider_override( + cls, + organization_id: str, + provider_type: str + ) -> OrganizationProviderOverride: + """ + Get organization-specific provider override. + + Args: + organization_id: Organization ID + provider_type: Provider type + + Returns: + OrganizationProviderOverride: Provider override + + Raises: + ExternalAuthError: If override not found + """ + override = OrganizationProviderOverride.query.filter_by( + organization_id=organization_id, + provider_type=provider_type + ).first() + + if not override: + raise ExternalAuthError( + f"Override for {provider_type} not found for this organization", + "OVERRIDE_NOT_FOUND", + 404 + ) + + return override + + @classmethod + def list_org_provider_overrides(cls, organization_id: str) -> list: + """ + List all provider overrides for an organization. + + Args: + organization_id: Organization ID + + Returns: + list: List of override configuration dictionaries + """ + overrides = OrganizationProviderOverride.query.filter_by( + organization_id=organization_id + ).all() + return [override.to_dict() for override in overrides] + + @classmethod + def delete_org_provider_override( + cls, + organization_id: str, + provider_type: str + ) -> bool: + """ + Delete organization-specific provider override. + + Args: + organization_id: Organization ID + provider_type: Provider type + + Returns: + bool: True if deleted successfully + + Raises: + ExternalAuthError: If override not found + """ + override = OrganizationProviderOverride.query.filter_by( + organization_id=organization_id, + provider_type=provider_type + ).first() + + if not override: + raise ExternalAuthError( + f"Override for {provider_type} not found for this organization", + "OVERRIDE_NOT_FOUND", + 404 + ) + + override.delete() + logger.info(f"Deleted org override for {provider_type} in org {organization_id}") + return True + + # ==================== OAuth Flow Methods (Updated for New Architecture) ==================== @classmethod def initiate_link_flow( @@ -240,8 +725,8 @@ class ExternalAuthService: """ provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type - # Get provider config - config = cls.get_provider_config(organization_id, provider_type) + # Get provider config (with org override if applicable) + config = cls.get_provider_config(provider_type, organization_id) # Validate redirect URI if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri): @@ -261,13 +746,13 @@ class ExternalAuthService: provider_type=provider_type, user_id=user_id, organization_id=organization_id, - redirect_uri=redirect_uri or config.redirect_uris[0], + redirect_uri=redirect_uri or config.redirect_uris[0] if config.redirect_uris else None, code_verifier=code_verifier, code_challenge=code_challenge, lifetime_seconds=600, ) - # Build authorization URL (simplified - in production would use provider-specific implementation) + # Build authorization URL auth_url = cls._build_authorization_url( config=config, state=state, @@ -338,12 +823,12 @@ class ExternalAuthService: 400, ) - # Get provider config + # Get provider config (with org override if applicable) config = cls.get_provider_config( - state_record.organization_id, provider_type + provider_type, state_record.organization_id ) - # Exchange code for tokens (simplified - in production would use provider-specific implementation) + # Exchange code for tokens tokens = cls._exchange_code( config=config, code=authorization_code, @@ -440,8 +925,8 @@ class ExternalAuthService: 400, ) - # Get provider config - config = cls.get_provider_config(organization_id, provider_type) + # Get provider config (with org override if applicable) + config = cls.get_provider_config(provider_type, organization_id) # Exchange code for tokens tokens = cls._exchange_code( @@ -606,6 +1091,8 @@ class ExternalAuthService: if m.method_type in external_providers or str(m.method_type) in [p.value for p in external_providers] ] + # ==================== Helper Methods ==================== + @staticmethod def _compute_s256_challenge(verifier: str) -> str: """Compute S256 code challenge from verifier.""" @@ -616,8 +1103,8 @@ class ExternalAuthService: return base64.urlsafe_b64encode(digest).decode().rstrip("=") @staticmethod - def _build_authorization_url(config: ExternalProviderConfig, state: OAuthState) -> str: - """Build authorization URL (simplified - provider-specific in production).""" + def _build_authorization_url(config: ProviderConfigAdapter, state: OAuthState) -> str: + """Build authorization URL using the provider config adapter.""" from urllib.parse import urlencode params = { @@ -637,11 +1124,22 @@ class ExternalAuthService: params["code_challenge"] = state.code_challenge params["code_challenge_method"] = "S256" - return f"{config.auth_url}?{urlencode(params)}" + full_url = f"{config.auth_url}?{urlencode(params)}" + + # DIAGNOSTIC LOGGING: Show exact URL being built + logger.info( + f"[PKCE DEBUG] Building authorization URL:\n" + f" provider_type: {config.provider_type}\n" + f" state.code_challenge: {state.code_challenge[:20] if state.code_challenge else 'None'}...\n" + f" params has code_challenge: {'code_challenge' in params}\n" + f" Full URL: {full_url}" + ) + + return full_url @staticmethod - def _exchange_code(config: ExternalProviderConfig, code: str, redirect_uri: str, code_verifier: str = None) -> dict: - """Exchange authorization code for tokens (simplified - provider-specific in production).""" + def _exchange_code(config: ProviderConfigAdapter, code: str, redirect_uri: str, code_verifier: str = None) -> dict: + """Exchange authorization code for tokens using the provider config adapter.""" import requests data = { @@ -655,14 +1153,29 @@ class ExternalAuthService: if code_verifier: data["code_verifier"] = code_verifier + # Log token exchange request (without secrets) + logger.debug( + f"Token exchange request: url={config.token_url}, " + f"client_id={config.client_id}, redirect_uri={redirect_uri}, " + f"has_code_verifier={bool(code_verifier)}" + ) + response = requests.post(config.token_url, data=data) + + # Log response details for debugging + if response.status_code != 200: + logger.error( + f"Token exchange failed: status={response.status_code}, " + f"response={response.text}" + ) + response.raise_for_status() return response.json() @staticmethod - def _get_user_info(config: ExternalProviderConfig, access_token: str) -> dict: - """Get user info from provider (simplified - provider-specific in production).""" + def _get_user_info(config: ProviderConfigAdapter, access_token: str) -> dict: + """Get user info from provider using the provider config adapter.""" import requests headers = {"Authorization": f"Bearer {access_token}"} @@ -758,4 +1271,4 @@ class ExternalAuthService: else: result["id_token"] = None - return result \ No newline at end of file + return result diff --git a/gatehouse_app/services/oauth_flow_service.py b/gatehouse_app/services/oauth_flow_service.py index 31f40da..29c45cd 100644 --- a/gatehouse_app/services/oauth_flow_service.py +++ b/gatehouse_app/services/oauth_flow_service.py @@ -1,20 +1,22 @@ """OAuth flow service for handling external authentication flows.""" +import hashlib import logging import secrets from datetime import datetime, timedelta, timezone from typing import Optional, Tuple -from flask import current_app, request, g +from flask import current_app, request, g, redirect from gatehouse_app.extensions import db from gatehouse_app.models import User, AuthenticationMethod +from gatehouse_app.models.authentication_method import OAuthState from gatehouse_app.models.base import BaseModel +from gatehouse_app.models.oidc_authorization_code import OIDCAuthCode from gatehouse_app.utils.constants import AuthMethodType from gatehouse_app.services.audit_service import AuditService from gatehouse_app.services.external_auth_service import ( ExternalAuthService, ExternalAuthError, - OAuthState, ExternalProviderConfig, ) @@ -43,11 +45,14 @@ class OAuthFlowService: state_data: dict = None, ) -> Tuple[str, str]: """ - Initiate OAuth login flow. + Initiate OAuth login flow without requiring organization_id upfront. + + This method initiates the OAuth flow using application-wide provider configuration. + The organization context is determined after successful authentication. Args: provider_type: The authentication provider type - organization_id: Optional organization context for SSO + organization_id: Optional organization hint for SSO discovery redirect_uri: Optional custom redirect URI state_data: Additional state data to include @@ -65,8 +70,8 @@ class OAuthFlowService: provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type try: - # Get provider config - config = ExternalAuthService.get_provider_config(organization_id, provider_type) + # Get provider config (application-wide, no organization required) + config = ExternalAuthService.get_provider_config(provider_type, organization_id) # Validate redirect URI if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri): @@ -76,9 +81,19 @@ class OAuthFlowService: 400, ) - # Generate PKCE - code_verifier = secrets.token_urlsafe(32) - code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier) + # Generate PKCE parameters (Google web applications don't use PKCE) + code_verifier = None + code_challenge = None + if provider_type_str not in ['google']: + code_verifier = secrets.token_urlsafe(32) + code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier) + + # DIAGNOSTIC LOGGING: Show PKCE decision + logger.info( + f"[PKCE DEBUG] Provider type check: provider_type_str='{provider_type_str}', " + f"is_google={provider_type_str in ['google']}, " + f"will_skip_pkce={provider_type_str in ['google']}" + ) # Create OAuth state for login flow state = OAuthState.create_state( @@ -92,6 +107,15 @@ class OAuthFlowService: lifetime_seconds=600, ) + # DIAGNOSTIC LOGGING: Verify state object + logger.info( + f"[PKCE DEBUG] Created OAuthState object:\n" + f" state.id: {state.id}\n" + f" state.provider_type: {state.provider_type}\n" + f" state.code_challenge: {state.code_challenge}\n" + f" state.code_verifier: {state.code_verifier[:20] if state.code_verifier else None}..." + ) + # Build authorization URL auth_url = ExternalAuthService._build_authorization_url( config=config, @@ -100,7 +124,13 @@ class OAuthFlowService: logger.info( f"OAuth login flow initiated for provider={provider_type_str}, " - f"org_id={organization_id}, state_id={state.id}" + f"org_id={organization_id}, state_token={state.state}, state_record_id={state.id}" + ) + logger.info( + f"[PKCE DEBUG] FINAL CHECK: code_challenge={code_challenge}, " + f"code_verifier={code_verifier[:20] if code_verifier else None}..., " + f"auth_url_has_challenge={'code_challenge=' in auth_url}, " + f"returned_auth_url={auth_url}" ) return auth_url, state.state @@ -129,11 +159,11 @@ class OAuthFlowService: redirect_uri: str = None, ) -> Tuple[str, str]: """ - Initiate OAuth registration flow. + Initiate OAuth registration flow without requiring organization_id upfront. Args: provider_type: The authentication provider type - organization_id: Optional organization context + organization_id: Optional organization hint redirect_uri: Optional custom redirect URI Returns: @@ -142,8 +172,8 @@ class OAuthFlowService: provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type try: - # Get provider config - config = ExternalAuthService.get_provider_config(organization_id, provider_type) + # Get provider config (application-wide, no organization required) + config = ExternalAuthService.get_provider_config(provider_type, organization_id) # Validate redirect URI if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri): @@ -153,9 +183,19 @@ class OAuthFlowService: 400, ) - # Generate PKCE - code_verifier = secrets.token_urlsafe(32) - code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier) + # Generate PKCE parameters (Google web applications don't use PKCE) + code_verifier = None + code_challenge = None + if provider_type_str not in ['google']: + code_verifier = secrets.token_urlsafe(32) + code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier) + + # DIAGNOSTIC LOGGING: Show PKCE decision for register flow + logger.info( + f"[PKCE DEBUG] Register flow - Provider type check: provider_type_str='{provider_type_str}', " + f"is_google={provider_type_str in ['google']}, " + f"will_skip_pkce={provider_type_str in ['google']}" + ) # Create OAuth state for register flow state = OAuthState.create_state( @@ -168,6 +208,14 @@ class OAuthFlowService: lifetime_seconds=600, ) + # DIAGNOSTIC LOGGING: Verify state object for register flow + logger.info( + f"[PKCE DEBUG] Register flow - Created OAuthState:\n" + f" state.id: {state.id}\n" + f" state.code_challenge: {state.code_challenge}\n" + f" state.code_verifier: {state.code_verifier[:20] if state.code_verifier else None}..." + ) + # Build authorization URL auth_url = ExternalAuthService._build_authorization_url( config=config, @@ -178,6 +226,9 @@ class OAuthFlowService: f"OAuth register flow initiated for provider={provider_type_str}, " f"org_id={organization_id}, state_id={state.id}" ) + logger.info( + f"[PKCE DEBUG] Register flow - FINAL: auth_url_has_challenge={'code_challenge=' in auth_url}" + ) return auth_url, state.state @@ -245,6 +296,17 @@ class OAuthFlowService: # Validate state state_record = OAuthState.query.filter_by(state=state).first() + + # Log validation details for debugging + if state_record: + logger.debug( + f"State validation: found=True, used={state_record.used}, " + f"expires_at={state_record.expires_at}, now={datetime.now(timezone.utc)}, " + f"is_valid={state_record.is_valid()}" + ) + else: + logger.warning(f"State validation: state token not found in database: {state}") + if not state_record or not state_record.is_valid(): AuditService.log_external_auth_login_failed( organization_id=state_record.organization_id if state_record else None, @@ -299,24 +361,175 @@ class OAuthFlowService: ip_address: str = None, user_agent: str = None, ) -> dict: - """Handle login flow callback.""" + """ + Handle login flow callback with organization discovery. + + This method: + 1. Exchanges the authorization code for tokens + 2. Gets user info from the OAuth provider + 3. Looks up the user by provider_user_id + 4. Determines which organization(s) the user belongs to + 5. Creates a session or returns org selection needed + """ provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type try: - # Authenticate with provider - user, session_data = ExternalAuthService.authenticate_with_provider( - provider_type=provider_type, - organization_id=state_record.organization_id, - authorization_code=authorization_code, - state=state_record.state, + # Get provider config (application-wide) + config = ExternalAuthService.get_provider_config( + provider_type, state_record.organization_id + ) + + logger.debug( + f"Exchanging code with PKCE: state_record.code_verifier={state_record.code_verifier[:20] if state_record.code_verifier else None}..." + ) + + # Exchange code for tokens + tokens = ExternalAuthService._exchange_code( + config=config, + code=authorization_code, redirect_uri=redirect_uri, + code_verifier=state_record.code_verifier, + ) + + # Get user info from provider + user_info = ExternalAuthService._get_user_info( + config=config, + access_token=tokens["access_token"], + ) + + # Look up user by provider_user_id + auth_method = AuthenticationMethod.query.filter_by( + method_type=provider_type, + provider_user_id=user_info["provider_user_id"], + ).first() + + if not auth_method: + # User doesn't exist - check if email matches existing user + existing_user = User.query.filter_by( + email=user_info["email"] + ).first() + + if existing_user: + AuditService.log_external_auth_login_failed( + organization_id=state_record.organization_id, + provider_type=provider_type_str, + provider_user_id=user_info["provider_user_id"], + email=user_info["email"], + failure_reason="email_exists", + error_message=f"An account with email {user_info['email']} already exists", + ) + raise OAuthFlowError( + f"An account with email {user_info['email']} already exists. " + "Please log in with your password and link your account from settings.", + "EMAIL_EXISTS", + 400, + ) + + AuditService.log_external_auth_login_failed( + organization_id=state_record.organization_id, + provider_type=provider_type_str, + provider_user_id=user_info["provider_user_id"], + email=user_info["email"], + failure_reason="account_not_found", + error_message="No Gatehouse account matches this external account", + ) + raise OAuthFlowError( + "No Gatehouse account matches this external account. Please register first.", + "ACCOUNT_NOT_FOUND", + 404, + ) + + user = auth_method.user + + # Update provider data + auth_method.provider_data = ExternalAuthService._encrypt_provider_data( + tokens, user_info + ) + auth_method.last_used_at = datetime.utcnow() + auth_method.save() + + # Get user's organizations + user_orgs = user.get_organizations() + + # Determine target organization + target_org = None + + # Priority 1: Use organization_id from state if provided (org hint) + if state_record.organization_id: + target_org = next( + (org for org in user_orgs if org.id == state_record.organization_id), + None + ) + + # Priority 2: If user has exactly one organization, use it + if not target_org and len(user_orgs) == 1: + target_org = user_orgs[0] + + # Priority 3: No organization or multiple organizations - need selection + if not target_org: + # Mark state as used + state_record.mark_used() + + logger.info( + f"OAuth login requires org selection for user={user.id}, " + f"provider={provider_type_str}, org_count={len(user_orgs)}" + ) + + return { + "success": True, + "flow_type": "login", + "requires_org_selection": True, + "user": { + "id": user.id, + "email": user.email, + "full_name": user.full_name, + }, + "available_organizations": [ + { + "id": org.id, + "name": org.name, + "slug": org.slug if hasattr(org, 'slug') else None, + } + for org in user_orgs + ], + "state": state_record.state, + } + + # Create session for the target org + from gatehouse_app.services.auth_service import AuthService + session = AuthService.create_session( + user=user, + is_compliance_only=False, + ) + + # Mark state as used + state_record.mark_used() + + # Audit log - login success + AuditService.log_external_auth_login( + user_id=user.id, + organization_id=target_org.id, + provider_type=provider_type_str, + provider_user_id=user_info["provider_user_id"], + auth_method_id=auth_method.id, + session_id=session.id, ) logger.info( f"OAuth login successful for user={user.id}, " - f"provider={provider_type_str}, org_id={state_record.organization_id}" + f"provider={provider_type_str}, org_id={target_org.id}" ) + # Build session dict with token (to_dict() excludes token for security) + session_dict = session.to_dict() + session_dict["token"] = session.token + # Calculate expires_in handling naive datetime from database + expires_at = session.expires_at + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + now = datetime.now(timezone.utc) + session_dict["expires_in"] = int((expires_at - now).total_seconds()) + return { "success": True, "flow_type": "login", @@ -324,9 +537,9 @@ class OAuthFlowService: "id": user.id, "email": user.email, "full_name": user.full_name, - "organization_id": state_record.organization_id, + "organization_id": target_org.id, }, - "session": session_data, + "session": session_dict, } except ExternalAuthError as e: @@ -335,6 +548,19 @@ class OAuthFlowService: f"provider={provider_type_str}, error={e.message}" ) raise + except OAuthFlowError: + # Re-raise OAuthFlowError as-is + raise + except Exception as e: + logger.error( + f"Unexpected error in OAuth login callback: {str(e)}", + exc_info=True + ) + raise OAuthFlowError( + "An unexpected error occurred during login", + "INTERNAL_ERROR", + 500, + ) @classmethod def _handle_link_callback( @@ -387,13 +613,17 @@ class OAuthFlowService: authorization_code: str, redirect_uri: str, ) -> dict: - """Handle registration flow callback.""" + """ + Handle registration flow callback. + + Creates a new user account and prompts for organization creation/selection. + """ provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type try: - # Get provider config + # Get provider config (application-wide) config = ExternalAuthService.get_provider_config( - state_record.organization_id, provider_type + provider_type, state_record.organization_id ) # Exchange code for tokens @@ -429,6 +659,7 @@ class OAuthFlowService: email=user_info["email"], full_name=user_info.get("name", ""), status="active", + email_verified=user_info.get("email_verified", False), ) user.save() @@ -440,6 +671,7 @@ class OAuthFlowService: provider_data=ExternalAuthService._encrypt_provider_data(tokens, user_info), verified=user_info.get("email_verified", False), is_primary=True, + last_used_at=datetime.utcnow(), ) auth_method.save() @@ -475,23 +707,48 @@ class OAuthFlowService: f"provider={provider_type_str}, user_id={user.id}" ) - # Create session - from gatehouse_app.services.auth_service import AuthService - session = AuthService.create_session( - user=user, - organization_id=state_record.organization_id, - ) + # If organization_id hint was provided and valid, create session for that org + if state_record.organization_id: + from gatehouse_app.models.organization import Organization + org = Organization.query.get(state_record.organization_id) + if org: + from gatehouse_app.services.auth_service import AuthService + session = AuthService.create_session( + user=user, + is_compliance_only=False, + ) + # Build session dict with token (to_dict() excludes token for security) + session_dict = session.to_dict() + session_dict["token"] = session.token + # Calculate expires_in handling naive datetime from database + expires_at = session.expires_at + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + now = datetime.now(timezone.utc) + session_dict["expires_in"] = int((expires_at - now).total_seconds()) + return { + "success": True, + "flow_type": "register", + "user": { + "id": user.id, + "email": user.email, + "full_name": user.full_name, + "organization_id": org.id, + }, + "session": session_dict, + } + # No organization hint or invalid - need to create/select org return { "success": True, "flow_type": "register", + "requires_org_creation": True, "user": { "id": user.id, "email": user.email, "full_name": user.full_name, - "organization_id": state_record.organization_id, }, - "session": session.to_dict(), + "state": state_record.state, } except ExternalAuthError as e: @@ -500,6 +757,19 @@ class OAuthFlowService: f"provider={provider_type_str}, error={e.message}" ) raise + except OAuthFlowError: + # Re-raise OAuthFlowError as-is + raise + except Exception as e: + logger.error( + f"Unexpected error in OAuth registration callback: {str(e)}", + exc_info=True + ) + raise OAuthFlowError( + "An unexpected error occurred during registration", + "INTERNAL_ERROR", + 500, + ) @classmethod def validate_state(cls, state: str) -> Optional[OAuthState]: @@ -522,3 +792,232 @@ class OAuthFlowService: """Remove expired OAuth states.""" OAuthState.cleanup_expired() logger.info("Expired OAuth states cleaned up") + + @classmethod + def generate_authorization_code( + cls, + user_id: str, + client_id: str, + redirect_uri: str, + scope: list = None, + nonce: str = None, + ip_address: str = None, + user_agent: str = None, + lifetime_seconds: int = 600, + ) -> str: + """ + Generate an authorization code for external OAuth applications. + + This method creates a short-lived, single-use authorization code that can be + exchanged for a session token by external applications like oauth2-proxy. + + Args: + user_id: The user ID + client_id: The client ID (e.g., 'oauth2-proxy', 'bookstack') + redirect_uri: The redirect URI + scope: Requested scopes + nonce: OIDC nonce for validation + ip_address: Client IP address + user_agent: Client user agent + lifetime_seconds: Code lifetime in seconds (default 10 minutes) + + Returns: + The authorization code (plain text, not hashed) + """ + # Generate a secure random code + code = secrets.token_urlsafe(32) + code_hash = hashlib.sha256(code.encode()).hexdigest() + + # Create the authorization code record + OIDCAuthCode.create_code( + client_id=client_id, + user_id=user_id, + code_hash=code_hash, + redirect_uri=redirect_uri, + scope=scope, + nonce=nonce, + ip_address=ip_address, + user_agent=user_agent, + lifetime_seconds=lifetime_seconds, + ) + + logger.info( + f"Generated authorization code for user={user_id}, client={client_id}" + ) + + return code + + @classmethod + def exchange_authorization_code( + cls, + code: str, + client_id: str, + redirect_uri: str, + ip_address: str = None, + ) -> dict: + """ + Exchange an authorization code for a session token. + + This method validates and consumes the authorization code, then creates + a session for the user. + + Args: + code: The authorization code + client_id: The client ID + redirect_uri: The redirect URI (must match original request) + ip_address: Client IP address + + Returns: + Dict with session token and user info + """ + # Hash the provided code for lookup + code_hash = hashlib.sha256(code.encode()).hexdigest() + + # Find the authorization code record + auth_code = OIDCAuthCode.query.filter_by( + client_id=client_id, + code_hash=code_hash, + ).first() + + if not auth_code: + raise OAuthFlowError( + "Invalid authorization code", + "INVALID_CODE", + 400, + ) + + # Validate the code + if not auth_code.is_valid(): + if auth_code.is_used: + raise OAuthFlowError( + "Authorization code has already been used", + "CODE_USED", + 400, + ) + else: + raise OAuthFlowError( + "Authorization code has expired", + "CODE_EXPIRED", + 400, + ) + + # Validate redirect URI + if auth_code.redirect_uri != redirect_uri: + raise OAuthFlowError( + "Redirect URI mismatch", + "INVALID_REDIRECT_URI", + 400, + ) + + # Get the user + from gatehouse_app.models import User + user = User.query.get(auth_code.user_id) + if not user: + raise OAuthFlowError( + "User not found", + "USER_NOT_FOUND", + 404, + ) + + # Determine organization + from gatehouse_app.models.organization import Organization + from gatehouse_app.models.organization_member import OrganizationMember + + # Get user's organizations + user_orgs = user.get_organizations() + + # Determine target organization + target_org = None + + # Priority 1: Use organization_id from auth code if available + # Priority 2: If user has exactly one organization, use it + if not target_org and len(user_orgs) == 1: + target_org = user_orgs[0] + + if not target_org: + raise OAuthFlowError( + "User does not have a default organization. Organization selection required.", + "ORG_SELECTION_REQUIRED", + 400, + ) + + # Create session + from gatehouse_app.services.auth_service import AuthService + session = AuthService.create_session( + user=user, + is_compliance_only=False, + ) + + # Mark the code as used + auth_code.mark_as_used() + + # Build session dict + session_dict = session.to_dict() + session_dict["token"] = session.token + expires_at = session.expires_at + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + now = datetime.now(timezone.utc) + session_dict["expires_in"] = int((expires_at - now).total_seconds()) + + logger.info( + f"Authorization code exchanged for session: user={user.id}, " + f"org_id={target_org.id}, client={client_id}" + ) + + return { + "success": True, + "token": session_dict["token"], + "expires_in": session_dict["expires_in"], + "token_type": "Bearer", + "user": { + "id": user.id, + "email": user.email, + "full_name": user.full_name, + "organization_id": target_org.id, + }, + } + + @classmethod + def create_redirect_response( + cls, + redirect_uri: str, + authorization_code: str, + state: str = None, + ): + """ + Create a redirect response with authorization code. + + Args: + redirect_uri: The redirect URI + authorization_code: The authorization code + state: Optional state parameter + + Returns: + Flask redirect response + """ + from urllib.parse import urlencode, urlparse, urlunparse + + # Parse the redirect URI + parsed = urlparse(redirect_uri) + + # Build query parameters + params = {"code": authorization_code} + if state: + params["state"] = state + + # Reconstruct URL with query parameters + redirect_url = urlunparse(( + parsed.scheme, + parsed.netloc, + parsed.path, + parsed.params, + urlencode(params), + parsed.fragment, + )) + + logger.info( + f"Redirecting to {parsed.scheme}://{parsed.netloc} with authorization code" + ) + + return redirect(redirect_url) diff --git a/scripts/README.md b/scripts/README.md new file mode 100644 index 0000000..683c66f --- /dev/null +++ b/scripts/README.md @@ -0,0 +1,340 @@ +# Gatehouse Scripts + +This directory contains utility scripts for managing and configuring Gatehouse. + +## OAuth Provider Configuration Script + +The [`configure_oauth_provider.py`](configure_oauth_provider.py:1) script allows administrators to easily configure OAuth providers at the application level. + +### Overview + +This script manages application-wide OAuth provider configurations using the new [`ApplicationProviderConfig`](../gatehouse_app/models/authentication_method.py:99) architecture. Unlike the deprecated organization-specific configuration, this allows users to authenticate with OAuth providers without needing to specify an organization first. + +### Prerequisites + +- Python 3.8+ +- Virtual environment with dependencies installed +- Flask app must be properly configured (`.env` or environment variables) + +### Quick Start + +```bash +# Activate virtual environment +cd gatehouse-api +source .venv/bin/activate + +# Create Google OAuth configuration +python scripts/configure_oauth_provider.py create google \ + --client-id "YOUR_CLIENT_ID" \ + --client-secret "YOUR_CLIENT_SECRET" \ + --redirect-url "http://localhost:5173/auth/callback" + +# List all configured providers +python scripts/configure_oauth_provider.py list + +# Show provider details +python scripts/configure_oauth_provider.py show google +``` + +### Commands + +#### `create` - Create a New Provider + +Create a new OAuth provider configuration at the application level. + +```bash +python scripts/configure_oauth_provider.py create PROVIDER [OPTIONS] +``` + +**Arguments:** +- `PROVIDER`: Provider type (google, github, microsoft) + +**Options:** +- `--client-id TEXT`: OAuth client ID (required, or via environment) +- `--client-secret TEXT`: OAuth client secret (required, or via environment) +- `--redirect-url TEXT`: Default redirect URL for callbacks +- `--disabled`: Create provider in disabled state +- `--settings KEY=VALUE`: Custom settings (can be specified multiple times) + +**Examples:** + +```bash +# Basic Google configuration +python scripts/configure_oauth_provider.py create google \ + --client-id "xxx.apps.googleusercontent.com" \ + --client-secret "GOCSPX-xxx" + +# With redirect URL +python scripts/configure_oauth_provider.py create google \ + --client-id "xxx" \ + --client-secret "yyy" \ + --redirect-url "https://app.example.com/auth/callback" + +# Create disabled initially +python scripts/configure_oauth_provider.py create github \ + --client-id "xxx" \ + --client-secret "yyy" \ + --disabled + +# With custom settings +python scripts/configure_oauth_provider.py create google \ + --client-id "xxx" \ + --client-secret "yyy" \ + --settings "hosted_domain=example.com" \ + --settings "prompt=consent" +``` + +#### `update` - Update Existing Provider + +Update an existing OAuth provider configuration. + +```bash +python scripts/configure_oauth_provider.py update PROVIDER [OPTIONS] +``` + +**Arguments:** +- `PROVIDER`: Provider type to update + +**Options:** +- `--client-id TEXT`: New OAuth client ID +- `--client-secret TEXT`: New OAuth client secret +- `--redirect-url TEXT`: New default redirect URL +- `--enabled true|false`: Enable or disable the provider +- `--settings KEY=VALUE`: Custom settings to update + +**Examples:** + +```bash +# Update client credentials +python scripts/configure_oauth_provider.py update google \ + --client-id "new-client-id" \ + --client-secret "new-secret" + +# Enable/disable provider +python scripts/configure_oauth_provider.py update google --enabled false +python scripts/configure_oauth_provider.py update google --enabled true + +# Update redirect URL +python scripts/configure_oauth_provider.py update google \ + --redirect-url "https://new-domain.com/auth/callback" +``` + +#### `list` - List All Providers + +List all configured OAuth providers with their status. + +```bash +python scripts/configure_oauth_provider.py list +``` + +**Example Output:** +``` +Configured OAuth Providers + + google - enabled + Client ID: 972920496362-xxx.apps.googleusercontent.com + Redirect URL: https://app.example.com/auth/callback + Created: 2026-01-20T13:00:00 + Auth URL: https://accounts.google.com/o/oauth2/v2/auth + Scopes: openid, profile, email + + github - disabled + Client ID: Iv1.xxx + Created: 2026-01-19T10:00:00 + Auth URL: https://github.com/login/oauth/authorize + Scopes: read:user, user:email +``` + +#### `show` - Show Provider Details + +Display detailed information about a specific OAuth provider. + +```bash +python scripts/configure_oauth_provider.py show PROVIDER +``` + +**Arguments:** +- `PROVIDER`: Provider type to display + +**Example:** + +```bash +python scripts/configure_oauth_provider.py show google +``` + +**Example Output:** +``` +Google OAuth Provider Details + +Basic Information: + Provider Type: google + Provider ID: 123e4567-e89b-12d3-a456-426614174000 + Client ID: 972920496362-xxx.apps.googleusercontent.com + Status: enabled + Default Redirect URL: https://app.example.com/auth/callback + +Timestamps: + Created: 2026-01-20T13:00:00 + Updated: 2026-01-20T14:30:00 + +OAuth Configuration: + Authorization URL: https://accounts.google.com/o/oauth2/v2/auth + Token URL: https://oauth2.googleapis.com/token + User Info URL: https://openidconnect.googleapis.com/v1/userinfo + JWKS URL: https://www.googleapis.com/oauth2/v3/certs + Scopes: openid, profile, email +``` + +#### `delete` - Delete Provider Configuration + +Remove an OAuth provider configuration. + +```bash +python scripts/configure_oauth_provider.py delete PROVIDER [OPTIONS] +``` + +**Arguments:** +- `PROVIDER`: Provider type to delete + +**Options:** +- `--yes`, `-y`: Skip confirmation prompt + +**Examples:** + +```bash +# Delete with confirmation prompt +python scripts/configure_oauth_provider.py delete google + +# Delete without confirmation +python scripts/configure_oauth_provider.py delete google --yes +``` + +### Environment Variables + +The script supports loading OAuth credentials from environment variables, which is useful for automation and CI/CD pipelines. + +**Supported Variables:** +- `{PROVIDER}_CLIENT_ID`: OAuth client ID +- `{PROVIDER}_CLIENT_SECRET`: OAuth client secret +- `{PROVIDER}_REDIRECT_URL`: Default redirect URL + +**Example:** + +```bash +# Export environment variables +export GOOGLE_CLIENT_ID="xxx.apps.googleusercontent.com" +export GOOGLE_CLIENT_SECRET="GOCSPX-xxx" +export GOOGLE_REDIRECT_URL="https://app.example.com/auth/callback" + +# Create provider using environment variables +python scripts/configure_oauth_provider.py create google + +# You can still override with command-line arguments +python scripts/configure_oauth_provider.py create google \ + --redirect-url "https://different.com/callback" +``` + +### Supported Providers + +The script comes with pre-configured endpoint information for: + +- **Google** (`google`) + - Authorization: `https://accounts.google.com/o/oauth2/v2/auth` + - Token: `https://oauth2.googleapis.com/token` + - User Info: `https://openidconnect.googleapis.com/v1/userinfo` + - Default Scopes: `openid, profile, email` + +- **GitHub** (`github`) + - Authorization: `https://github.com/login/oauth/authorize` + - Token: `https://github.com/login/oauth/access_token` + - User Info: `https://api.github.com/user` + - Default Scopes: `read:user, user:email` + +- **Microsoft** (`microsoft`) + - Authorization: `https://login.microsoftonline.com/common/oauth2/v2.0/authorize` + - Token: `https://login.microsoftonline.com/common/oauth2/v2.0/token` + - User Info: `https://graph.microsoft.com/oidc/userinfo` + - Default Scopes: `openid, profile, email` + +### Error Handling + +The script provides clear error messages and appropriate exit codes: + +- **Exit Code 0**: Success +- **Exit Code 1**: Error occurred + +**Common Errors:** + +1. **Provider Already Exists** + ``` + ✗ Failed to create provider: Provider google already exists + ℹ Use 'update' command to modify existing provider configuration. + ``` + +2. **Provider Not Found** + ``` + ✗ Failed to update provider: Provider google not found + ℹ Use 'create' command to add a new provider configuration. + ``` + +3. **Missing Credentials** + ``` + ✗ Client ID is required. Provide via --client-id or GOOGLE_CLIENT_ID environment variable. + ``` + +### Integration with Shell Scripts + +The [`configure-google-auth.sh`](../../docs/configure-google-auth.sh:1) script demonstrates how to integrate the Python script into a shell script for easier deployment: + +```bash +#!/bin/bash + +# Set credentials +GOOGLE_CLIENT_ID="xxx" +GOOGLE_CLIENT_SECRET="yyy" +REDIRECT_URL="https://app.example.com/callback" + +# Call Python script +cd gatehouse-api +python3 scripts/configure_oauth_provider.py create google \ + --client-id "$GOOGLE_CLIENT_ID" \ + --client-secret "$GOOGLE_CLIENT_SECRET" \ + --redirect-url "$REDIRECT_URL" +``` + +### API Service Methods + +The script uses the following [`ExternalAuthService`](../gatehouse_app/services/external_auth_service.py:1) methods: + +- [`create_app_provider_config()`](../gatehouse_app/services/external_auth_service.py:308) - Create provider configuration +- [`update_app_provider_config()`](../gatehouse_app/services/external_auth_service.py:369) - Update provider configuration +- [`get_app_provider_config()`](../gatehouse_app/services/external_auth_service.py:427) - Get single provider +- [`list_app_provider_configs()`](../gatehouse_app/services/external_auth_service.py:454) - List all providers +- [`delete_app_provider_config()`](../gatehouse_app/services/external_auth_service.py:465) - Delete provider configuration + +### Security Considerations + +1. **Client Secret Storage**: Client secrets are encrypted using the application's encryption key before storage in the database +2. **Environment Variables**: Be cautious when using environment variables in shared environments +3. **Secret Exposure**: The `show` command never displays the client secret (it's always excluded) +4. **Confirmation Prompts**: The `delete` command requires confirmation unless `--yes` flag is used + +### Troubleshooting + +**Database Connection Issues:** +- Ensure PostgreSQL is running and accessible +- Check `.env` file for correct `DATABASE_URL` +- Verify virtual environment is activated + +**Import Errors:** +- Activate the virtual environment: `source .venv/bin/activate` +- Install dependencies: `pip install -r requirements.txt` + +**Permission Issues:** +- Ensure script is executable: `chmod +x scripts/configure_oauth_provider.py` + +### Related Documentation + +- [External Auth Architecture](../../docs/external-auth-architecture.md) +- [Application-Wide OAuth Design](../../docs/external-auth-application-wide-design.md) +- [OAuth API Changes](../../docs/oauth-api-changes.md) diff --git a/scripts/configure_oauth_provider.py b/scripts/configure_oauth_provider.py new file mode 100755 index 0000000..e222266 --- /dev/null +++ b/scripts/configure_oauth_provider.py @@ -0,0 +1,484 @@ +#!/usr/bin/env python3 +""" +OAuth Provider Configuration Script for Gatehouse + +This script allows administrators to configure OAuth providers at the application level +using the new ApplicationProviderConfig architecture. + +Usage: + # Create a new provider configuration + python scripts/configure_oauth_provider.py create google \\ + --client-id "YOUR_CLIENT_ID" \\ + --client-secret "YOUR_CLIENT_SECRET" \\ + --redirect-url "http://localhost:5173/auth/callback" + + # List all configured providers + python scripts/configure_oauth_provider.py list + + # Show details of a specific provider + python scripts/configure_oauth_provider.py show google + + # Update a provider configuration + python scripts/configure_oauth_provider.py update google --enabled false + + # Delete a provider configuration + python scripts/configure_oauth_provider.py delete google + + # Use environment variables + GOOGLE_CLIENT_ID=xxx GOOGLE_CLIENT_SECRET=yyy \\ + python scripts/configure_oauth_provider.py create google +""" + +import os +import sys +import argparse +from typing import Optional, Dict, Any + +# Add the parent directory to the path for imports +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# Load environment variables from .env file before any other imports +# This ensures database and other configurations are available +from dotenv import load_dotenv +script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +env_file = os.path.join(script_dir, '.env') +if os.path.exists(env_file): + load_dotenv(env_file) + +# Import after path setup +from gatehouse_app import create_app +from gatehouse_app.services.external_auth_service import ExternalAuthService, ExternalAuthError + + +# Provider endpoint configurations +PROVIDER_DEFAULTS = { + "google": { + "auth_url": "https://accounts.google.com/o/oauth2/v2/auth", + "token_url": "https://oauth2.googleapis.com/token", + "userinfo_url": "https://openidconnect.googleapis.com/v1/userinfo", + "jwks_url": "https://www.googleapis.com/oauth2/v3/certs", + "scopes": ["openid", "profile", "email"], + }, + "github": { + "auth_url": "https://github.com/login/oauth/authorize", + "token_url": "https://github.com/login/oauth/access_token", + "userinfo_url": "https://api.github.com/user", + "scopes": ["read:user", "user:email"], + }, + "microsoft": { + "auth_url": "https://login.microsoftonline.com/common/oauth2/v2.0/authorize", + "token_url": "https://login.microsoftonline.com/common/oauth2/v2.0/token", + "userinfo_url": "https://graph.microsoft.com/oidc/userinfo", + "jwks_url": "https://login.microsoftonline.com/common/discovery/v2.0/keys", + "scopes": ["openid", "profile", "email"], + }, +} + + +class Colors: + """ANSI color codes for terminal output.""" + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKCYAN = '\033[96m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + + +def print_success(message: str): + """Print success message in green.""" + print(f"{Colors.OKGREEN}✓ {message}{Colors.ENDC}") + + +def print_error(message: str): + """Print error message in red.""" + print(f"{Colors.FAIL}✗ {message}{Colors.ENDC}", file=sys.stderr) + + +def print_warning(message: str): + """Print warning message in yellow.""" + print(f"{Colors.WARNING}⚠ {message}{Colors.ENDC}") + + +def print_info(message: str): + """Print info message in blue.""" + print(f"{Colors.OKBLUE}ℹ {message}{Colors.ENDC}") + + +def print_header(message: str): + """Print header message.""" + print(f"\n{Colors.BOLD}{Colors.HEADER}{message}{Colors.ENDC}") + + +def get_env_credentials(provider_type: str) -> Dict[str, Optional[str]]: + """ + Get OAuth credentials from environment variables. + + Supports the following patterns: + - {PROVIDER}_CLIENT_ID + - {PROVIDER}_CLIENT_SECRET + - {PROVIDER}_REDIRECT_URL + + Args: + provider_type: Provider type (google, github, microsoft) + + Returns: + Dictionary with client_id, client_secret, and redirect_url if found + """ + provider_upper = provider_type.upper() + return { + "client_id": os.environ.get(f"{provider_upper}_CLIENT_ID"), + "client_secret": os.environ.get(f"{provider_upper}_CLIENT_SECRET"), + "redirect_url": os.environ.get(f"{provider_upper}_REDIRECT_URL"), + } + + +def create_provider(args): + """Create a new OAuth provider configuration.""" + provider_type = args.provider.lower() + + print_header(f"Creating {provider_type.title()} OAuth Provider Configuration") + + # Get credentials from args or environment + env_creds = get_env_credentials(provider_type) + client_id = args.client_id or env_creds.get("client_id") + client_secret = args.client_secret or env_creds.get("client_secret") + redirect_url = args.redirect_url or env_creds.get("redirect_url") + + # Validation + if not client_id: + print_error(f"Client ID is required. Provide via --client-id or {provider_type.upper()}_CLIENT_ID environment variable.") + return 1 + + if not client_secret: + print_error(f"Client secret is required. Provide via --client-secret or {provider_type.upper()}_CLIENT_SECRET environment variable.") + return 1 + + # Get provider defaults + if provider_type not in PROVIDER_DEFAULTS: + print_error(f"Unknown provider: {provider_type}. Supported providers: {', '.join(PROVIDER_DEFAULTS.keys())}") + return 1 + + defaults = PROVIDER_DEFAULTS[provider_type] + + # Build configuration + config_data = { + "client_id": client_id, + "client_secret": client_secret, + "default_redirect_url": redirect_url, + "is_enabled": not args.disabled, + **defaults, + } + + # Add custom settings if provided + if args.settings: + settings = {} + for setting in args.settings: + try: + key, value = setting.split("=", 1) + settings[key] = value + except ValueError: + print_warning(f"Skipping invalid setting format: {setting}") + config_data["settings"] = settings + + try: + # Create the provider configuration + config = ExternalAuthService.create_app_provider_config( + provider_type=provider_type, + **config_data + ) + + print_success(f"{provider_type.title()} provider created successfully!") + print_info(f"Provider ID: {config.id}") + print_info(f"Client ID: {config.client_id}") + if redirect_url: + print_info(f"Default Redirect URL: {redirect_url}") + print_info(f"Enabled: {config.is_enabled}") + + return 0 + + except ExternalAuthError as e: + print_error(f"Failed to create provider: {e.message}") + if e.error_type == "PROVIDER_EXISTS": + print_info("Use 'update' command to modify existing provider configuration.") + return 1 + except Exception as e: + print_error(f"Unexpected error: {str(e)}") + return 1 + + +def update_provider(args): + """Update an existing OAuth provider configuration.""" + provider_type = args.provider.lower() + + print_header(f"Updating {provider_type.title()} OAuth Provider Configuration") + + # Build updates dictionary + updates = {} + + if args.client_id: + updates["client_id"] = args.client_id + + if args.client_secret: + updates["client_secret"] = args.client_secret + + if args.redirect_url: + updates["default_redirect_url"] = args.redirect_url + + if args.enabled is not None: + updates["is_enabled"] = args.enabled + + if args.settings: + settings = {} + for setting in args.settings: + try: + key, value = setting.split("=", 1) + settings[key] = value + except ValueError: + print_warning(f"Skipping invalid setting format: {setting}") + updates["settings"] = settings + + if not updates: + print_warning("No updates specified. Use --help to see available options.") + return 1 + + try: + config = ExternalAuthService.update_app_provider_config( + provider_type=provider_type, + **updates + ) + + print_success(f"{provider_type.title()} provider updated successfully!") + print_info(f"Provider ID: {config.id}") + print_info(f"Client ID: {config.client_id}") + if config.default_redirect_url: + print_info(f"Default Redirect URL: {config.default_redirect_url}") + print_info(f"Enabled: {config.is_enabled}") + + return 0 + + except ExternalAuthError as e: + print_error(f"Failed to update provider: {e.message}") + if e.error_type == "PROVIDER_NOT_FOUND": + print_info("Use 'create' command to add a new provider configuration.") + return 1 + except Exception as e: + print_error(f"Unexpected error: {str(e)}") + return 1 + + +def list_providers(args): + """List all configured OAuth providers.""" + print_header("Configured OAuth Providers") + + try: + configs = ExternalAuthService.list_app_provider_configs() + + if not configs: + print_info("No OAuth providers configured yet.") + print_info("Use 'create' command to add a provider.") + return 0 + + print() + for config in configs: + status = f"{Colors.OKGREEN}enabled{Colors.ENDC}" if config.get("is_enabled") else f"{Colors.WARNING}disabled{Colors.ENDC}" + print(f" {Colors.BOLD}{config['provider_type']}{Colors.ENDC} - {status}") + print(f" Client ID: {config['client_id']}") + if config.get('default_redirect_url'): + print(f" Redirect URL: {config['default_redirect_url']}") + print(f" Created: {config.get('created_at', 'N/A')}") + + # Show endpoint info if available + additional_config = config.get('additional_config', {}) + if additional_config: + if additional_config.get('auth_url'): + print(f" Auth URL: {additional_config['auth_url']}") + if additional_config.get('scopes'): + scopes = ', '.join(additional_config['scopes']) + print(f" Scopes: {scopes}") + print() + + return 0 + + except Exception as e: + print_error(f"Failed to list providers: {str(e)}") + return 1 + + +def show_provider(args): + """Show details of a specific OAuth provider.""" + provider_type = args.provider.lower() + + print_header(f"{provider_type.title()} OAuth Provider Details") + + try: + config = ExternalAuthService.get_app_provider_config(provider_type) + config_dict = config.to_dict() + + print() + print(f"{Colors.BOLD}Basic Information:{Colors.ENDC}") + print(f" Provider Type: {config_dict['provider_type']}") + print(f" Provider ID: {config_dict['id']}") + print(f" Client ID: {config_dict['client_id']}") + + status = f"{Colors.OKGREEN}enabled{Colors.ENDC}" if config_dict['is_enabled'] else f"{Colors.WARNING}disabled{Colors.ENDC}" + print(f" Status: {status}") + + if config_dict.get('default_redirect_url'): + print(f" Default Redirect URL: {config_dict['default_redirect_url']}") + + print() + print(f"{Colors.BOLD}Timestamps:{Colors.ENDC}") + print(f" Created: {config_dict.get('created_at', 'N/A')}") + print(f" Updated: {config_dict.get('updated_at', 'N/A')}") + + # Show additional configuration + additional_config = config_dict.get('additional_config', {}) + if additional_config: + print() + print(f"{Colors.BOLD}OAuth Configuration:{Colors.ENDC}") + + if additional_config.get('auth_url'): + print(f" Authorization URL: {additional_config['auth_url']}") + if additional_config.get('token_url'): + print(f" Token URL: {additional_config['token_url']}") + if additional_config.get('userinfo_url'): + print(f" User Info URL: {additional_config['userinfo_url']}") + if additional_config.get('jwks_url'): + print(f" JWKS URL: {additional_config['jwks_url']}") + if additional_config.get('scopes'): + scopes = ', '.join(additional_config['scopes']) + print(f" Scopes: {scopes}") + + # Show any custom settings + custom_settings = {k: v for k, v in additional_config.items() + if k not in ['auth_url', 'token_url', 'userinfo_url', 'jwks_url', 'scopes']} + if custom_settings: + print() + print(f"{Colors.BOLD}Custom Settings:{Colors.ENDC}") + for key, value in custom_settings.items(): + print(f" {key}: {value}") + + print() + return 0 + + except ExternalAuthError as e: + print_error(f"Failed to get provider: {e.message}") + return 1 + except Exception as e: + print_error(f"Unexpected error: {str(e)}") + return 1 + + +def delete_provider(args): + """Delete an OAuth provider configuration.""" + provider_type = args.provider.lower() + + print_header(f"Deleting {provider_type.title()} OAuth Provider Configuration") + + # Confirm deletion unless --yes flag is provided + if not args.yes: + print_warning("This will permanently delete the provider configuration.") + response = input(f"Are you sure you want to delete {provider_type}? (yes/no): ") + if response.lower() not in ['yes', 'y']: + print_info("Deletion cancelled.") + return 0 + + try: + ExternalAuthService.delete_app_provider_config(provider_type) + print_success(f"{provider_type.title()} provider deleted successfully!") + return 0 + + except ExternalAuthError as e: + print_error(f"Failed to delete provider: {e.message}") + return 1 + except Exception as e: + print_error(f"Unexpected error: {str(e)}") + return 1 + + +def main(): + """Main entry point for the script.""" + parser = argparse.ArgumentParser( + description="Configure OAuth providers for Gatehouse authentication", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Create Google OAuth configuration + %(prog)s create google --client-id "CLIENT_ID" --client-secret "SECRET" + + # Create with environment variables + GOOGLE_CLIENT_ID=xxx GOOGLE_CLIENT_SECRET=yyy %(prog)s create google + + # List all providers + %(prog)s list + + # Show provider details + %(prog)s show google + + # Update provider + %(prog)s update google --enabled true + + # Delete provider + %(prog)s delete google --yes + +Supported Providers: + - google + - github + - microsoft + """ + ) + + subparsers = parser.add_subparsers(dest="command", help="Command to execute") + subparsers.required = True + + # Create command + create_parser = subparsers.add_parser("create", help="Create a new OAuth provider configuration") + create_parser.add_argument("provider", help="Provider type (google, github, microsoft)") + create_parser.add_argument("--client-id", help="OAuth client ID") + create_parser.add_argument("--client-secret", help="OAuth client secret") + create_parser.add_argument("--redirect-url", help="Default redirect URL for OAuth callbacks") + create_parser.add_argument("--disabled", action="store_true", help="Create provider in disabled state") + create_parser.add_argument("--settings", action="append", help="Custom settings (key=value format)") + create_parser.set_defaults(func=create_provider) + + # Update command + update_parser = subparsers.add_parser("update", help="Update an existing OAuth provider configuration") + update_parser.add_argument("provider", help="Provider type to update") + update_parser.add_argument("--client-id", help="New OAuth client ID") + update_parser.add_argument("--client-secret", help="New OAuth client secret") + update_parser.add_argument("--redirect-url", help="New default redirect URL") + update_parser.add_argument("--enabled", type=lambda x: x.lower() in ['true', '1', 'yes'], + help="Enable or disable the provider (true/false)") + update_parser.add_argument("--settings", action="append", help="Custom settings to update (key=value format)") + update_parser.set_defaults(func=update_provider) + + # List command + list_parser = subparsers.add_parser("list", help="List all configured OAuth providers") + list_parser.set_defaults(func=list_providers) + + # Show command + show_parser = subparsers.add_parser("show", help="Show details of a specific OAuth provider") + show_parser.add_argument("provider", help="Provider type to show") + show_parser.set_defaults(func=show_provider) + + # Delete command + delete_parser = subparsers.add_parser("delete", help="Delete an OAuth provider configuration") + delete_parser.add_argument("provider", help="Provider type to delete") + delete_parser.add_argument("--yes", "-y", action="store_true", help="Skip confirmation prompt") + delete_parser.set_defaults(func=delete_provider) + + args = parser.parse_args() + + # Create Flask app context + app = create_app() + + with app.app_context(): + return args.func(args) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/test_oauth_without_org.sh b/test_oauth_without_org.sh new file mode 100755 index 0000000..fdf4d76 --- /dev/null +++ b/test_oauth_without_org.sh @@ -0,0 +1,70 @@ +#!/bin/bash + +# Test script to verify OAuth endpoints work without organization_id +# This tests the fix for the "Google OAuth is not configured for this organization" error + +API_BASE="http://localhost:5001/api/v1" + +echo "=== Testing OAuth Authorization Endpoint (without organization_id) ===" +echo "" +echo "1. Initiating Google OAuth login flow (NO organization_id)..." +RESPONSE=$(curl -s -X GET "${API_BASE}/auth/external/google/authorize?flow=login") +echo "Response: $RESPONSE" +echo "" + +# Check if we get an authorization URL +if echo "$RESPONSE" | grep -q "authorization_url"; then + echo "✅ SUCCESS: Got authorization URL without requiring organization_id" + AUTH_URL=$(echo "$RESPONSE" | jq -r '.data.authorization_url') + STATE=$(echo "$RESPONSE" | jq -r '.data.state') + echo "Authorization URL: $AUTH_URL" + echo "State: $STATE" +else + echo "❌ FAILED: Did not get authorization URL" + echo "Error: $(echo "$RESPONSE" | jq -r '.message')" +fi + +echo "" +echo "=== Testing with organization_id hint (should still work) ===" +echo "" +echo "2. Initiating Google OAuth login flow (WITH organization_id hint)..." +# You'll need to replace this with an actual organization ID from your database +ORG_ID="test-org-id" +RESPONSE=$(curl -s -X GET "${API_BASE}/auth/external/google/authorize?flow=login&organization_id=${ORG_ID}") +echo "Response: $RESPONSE" +echo "" + +if echo "$RESPONSE" | grep -q "authorization_url"; then + echo "✅ SUCCESS: OAuth works with organization_id hint (backward compatible)" +else + echo "⚠️ Note: This may fail if the organization ID doesn't exist or if app-level config is not set" +fi + +echo "" +echo "=== Testing Register Flow ===" +echo "" +echo "3. Initiating Google OAuth register flow (NO organization_id)..." +RESPONSE=$(curl -s -X GET "${API_BASE}/auth/external/google/authorize?flow=register") +echo "Response: $RESPONSE" +echo "" + +if echo "$RESPONSE" | grep -q "authorization_url"; then + echo "✅ SUCCESS: Register flow works without organization_id" +else + echo "❌ FAILED: Register flow did not work" + echo "Error: $(echo "$RESPONSE" | jq -r '.message')" +fi + +echo "" +echo "=== Summary ===" +echo "" +echo "The key fix addresses the error:" +echo " 'Google OAuth is not configured for this organization'" +echo "" +echo "Now OAuth flows work at the APPLICATION level, not requiring" +echo "an organization context during initial authentication." +echo "" +echo "After OAuth callback:" +echo " - Single org user → Automatic login" +echo " - Multi org user → Organization selection UI" +echo " - New user → Organization creation/selection UI"