google login works
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
"""External authentication provider endpoints."""
|
"""External authentication provider endpoints."""
|
||||||
|
import logging
|
||||||
from flask import request, g
|
from flask import request, g
|
||||||
from marshmallow import ValidationError
|
from marshmallow import ValidationError
|
||||||
from gatehouse_app.api.v1 import api_v1_bp
|
from gatehouse_app.api.v1 import api_v1_bp
|
||||||
@@ -15,6 +16,8 @@ from gatehouse_app.services.oauth_flow_service import (
|
|||||||
)
|
)
|
||||||
from gatehouse_app.services.audit_service import AuditService
|
from gatehouse_app.services.audit_service import AuditService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Provider type mapping
|
# Provider type mapping
|
||||||
PROVIDER_TYPE_MAP = {
|
PROVIDER_TYPE_MAP = {
|
||||||
@@ -532,25 +535,35 @@ def unlink_account(provider: str):
|
|||||||
def initiate_oauth_authorize(provider: str):
|
def initiate_oauth_authorize(provider: str):
|
||||||
"""
|
"""
|
||||||
Initiate OAuth authentication or account registration flow.
|
Initiate OAuth authentication or account registration flow.
|
||||||
|
|
||||||
|
This endpoint initiates OAuth flows without requiring organization_id upfront.
|
||||||
|
The organization context is determined after successful authentication based on
|
||||||
|
the user's memberships.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
provider: Provider type (google, github, microsoft)
|
provider: Provider type (google, github, microsoft)
|
||||||
|
|
||||||
Query parameters:
|
Query parameters:
|
||||||
flow: 'login' or 'register'
|
flow: 'login' or 'register' (default: 'login')
|
||||||
redirect_uri: Optional redirect URI
|
redirect_uri: Optional redirect URI after OAuth completion
|
||||||
organization_id: Optional organization context
|
organization_id: Optional organization hint (for SSO discovery)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
302: Redirect to provider authorization page
|
200: Authorization URL and state token
|
||||||
400: Validation error or provider not configured
|
400: Validation error or provider not configured at application level
|
||||||
|
|
||||||
|
Response:
|
||||||
|
{
|
||||||
|
"authorization_url": "https://...",
|
||||||
|
"state": "state_token"
|
||||||
|
}
|
||||||
"""
|
"""
|
||||||
provider_type = get_provider_type(provider)
|
provider_type = get_provider_type(provider)
|
||||||
|
|
||||||
# Get query parameters
|
# Get query parameters - organization_id is now optional
|
||||||
flow = request.args.get("flow", "login")
|
flow = request.args.get("flow", "login")
|
||||||
redirect_uri = request.args.get("redirect_uri")
|
redirect_uri = request.args.get("redirect_uri")
|
||||||
organization_id = request.args.get("organization_id")
|
organization_id = request.args.get("organization_id") # Optional hint
|
||||||
|
|
||||||
if flow not in ["login", "register"]:
|
if flow not in ["login", "register"]:
|
||||||
return api_response(
|
return api_response(
|
||||||
@@ -561,16 +574,17 @@ def initiate_oauth_authorize(provider: str):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Initiate flow - organization_id is now optional
|
||||||
if flow == "login":
|
if flow == "login":
|
||||||
auth_url, state = OAuthFlowService.initiate_login_flow(
|
auth_url, state = OAuthFlowService.initiate_login_flow(
|
||||||
provider_type=provider_type,
|
provider_type=provider_type,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id, # Optional hint
|
||||||
redirect_uri=redirect_uri,
|
redirect_uri=redirect_uri,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
auth_url, state = OAuthFlowService.initiate_register_flow(
|
auth_url, state = OAuthFlowService.initiate_register_flow(
|
||||||
provider_type=provider_type,
|
provider_type=provider_type,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id, # Optional hint
|
||||||
redirect_uri=redirect_uri,
|
redirect_uri=redirect_uri,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -595,20 +609,54 @@ def initiate_oauth_authorize(provider: str):
|
|||||||
def handle_oauth_callback(provider: str):
|
def handle_oauth_callback(provider: str):
|
||||||
"""
|
"""
|
||||||
Handle OAuth callback from provider.
|
Handle OAuth callback from provider.
|
||||||
|
|
||||||
|
This endpoint handles the redirect from the OAuth provider after authentication.
|
||||||
|
It processes the response and handles different scenarios:
|
||||||
|
- Successful login/register with redirect_uri: Redirects with authorization code
|
||||||
|
- Successful login/register without redirect_uri: Returns session token
|
||||||
|
- Login with multiple orgs: Returns list of organizations for user to select
|
||||||
|
- Register with no org: Prompts for organization creation/selection
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
provider: Provider type (google, github, microsoft)
|
provider: Provider type (google, github, microsoft)
|
||||||
|
|
||||||
Query parameters:
|
Query parameters:
|
||||||
code: Authorization code from provider
|
code: Authorization code from provider
|
||||||
state: State parameter
|
state: State parameter from OAuth flow
|
||||||
error: Error code if auth failed
|
redirect_uri: Optional redirect URI for OAuth 2.0 Authorization Code flow
|
||||||
|
error: Error code if auth failed at provider
|
||||||
error_description: Human-readable error description
|
error_description: Human-readable error description
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
200: OAuth flow completed successfully
|
302: Redirect with authorization code (if redirect_uri provided)
|
||||||
302: Redirect with error
|
200: OAuth flow completed successfully (JSON response)
|
||||||
400: Validation error or OAuth error
|
400: Validation error, OAuth error, or invalid state
|
||||||
|
404: User account not found (for login flows)
|
||||||
|
|
||||||
|
Response formats (when redirect_uri NOT provided):
|
||||||
|
|
||||||
|
Success with session:
|
||||||
|
{
|
||||||
|
"token": "session_token",
|
||||||
|
"expires_in": 86400,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"user": {...}
|
||||||
|
}
|
||||||
|
|
||||||
|
Requires organization selection (login flow):
|
||||||
|
{
|
||||||
|
"requires_org_selection": true,
|
||||||
|
"user": {...},
|
||||||
|
"available_organizations": [...],
|
||||||
|
"state": "state_token"
|
||||||
|
}
|
||||||
|
|
||||||
|
Requires organization creation (register flow):
|
||||||
|
{
|
||||||
|
"requires_org_creation": true,
|
||||||
|
"user": {...},
|
||||||
|
"state": "state_token"
|
||||||
|
}
|
||||||
"""
|
"""
|
||||||
provider_type = get_provider_type(provider)
|
provider_type = get_provider_type(provider)
|
||||||
|
|
||||||
@@ -618,7 +666,7 @@ def handle_oauth_callback(provider: str):
|
|||||||
error = request.args.get("error")
|
error = request.args.get("error")
|
||||||
error_description = request.args.get("error_description")
|
error_description = request.args.get("error_description")
|
||||||
|
|
||||||
# Get redirect URI from state if available
|
# Get redirect URI from query parameter (for OAuth 2.0 Authorization Code flow)
|
||||||
redirect_uri = request.args.get("redirect_uri")
|
redirect_uri = request.args.get("redirect_uri")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -632,7 +680,61 @@ def handle_oauth_callback(provider: str):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if result.get("success"):
|
if result.get("success"):
|
||||||
if result.get("flow_type") == "login":
|
flow_type = result.get("flow_type")
|
||||||
|
|
||||||
|
# Check if we should redirect with authorization code
|
||||||
|
if redirect_uri and flow_type in ["login", "register"]:
|
||||||
|
# Generate authorization code for external application
|
||||||
|
user_id = result.get("user", {}).get("id")
|
||||||
|
if not user_id:
|
||||||
|
# For org selection/creation flows, we can't redirect
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# Determine organization_id
|
||||||
|
organization_id = result.get("user", {}).get("organization_id")
|
||||||
|
if not organization_id:
|
||||||
|
# Can't redirect without organization
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# Generate authorization code
|
||||||
|
auth_code = OAuthFlowService.generate_authorization_code(
|
||||||
|
user_id=user_id,
|
||||||
|
client_id="external-app",
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
scope=["openid", "profile", "email"],
|
||||||
|
ip_address=request.remote_addr,
|
||||||
|
user_agent=request.headers.get("User-Agent"),
|
||||||
|
lifetime_seconds=600, # 10 minutes
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mark state as used
|
||||||
|
state_record = OAuthFlowService.validate_state(state)
|
||||||
|
if state_record:
|
||||||
|
state_record.mark_used()
|
||||||
|
|
||||||
|
# Redirect with authorization code
|
||||||
|
return OAuthFlowService.create_redirect_response(
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
authorization_code=auth_code,
|
||||||
|
state=state,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle login flow responses (no redirect_uri or org selection required)
|
||||||
|
if flow_type == "login":
|
||||||
|
# Check if organization selection is required
|
||||||
|
if result.get("requires_org_selection"):
|
||||||
|
return api_response(
|
||||||
|
data={
|
||||||
|
"requires_org_selection": True,
|
||||||
|
"user": result["user"],
|
||||||
|
"available_organizations": result["available_organizations"],
|
||||||
|
"state": result["state"],
|
||||||
|
},
|
||||||
|
message="Please select an organization to continue",
|
||||||
|
status=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Normal login with session
|
||||||
return api_response(
|
return api_response(
|
||||||
data={
|
data={
|
||||||
"token": result["session"]["token"],
|
"token": result["session"]["token"],
|
||||||
@@ -642,7 +744,22 @@ def handle_oauth_callback(provider: str):
|
|||||||
},
|
},
|
||||||
message="Login successful",
|
message="Login successful",
|
||||||
)
|
)
|
||||||
elif result.get("flow_type") == "register":
|
|
||||||
|
# Handle register flow responses
|
||||||
|
elif flow_type == "register":
|
||||||
|
# Check if organization creation is required
|
||||||
|
if result.get("requires_org_creation"):
|
||||||
|
return api_response(
|
||||||
|
data={
|
||||||
|
"requires_org_creation": True,
|
||||||
|
"user": result["user"],
|
||||||
|
"state": result["state"],
|
||||||
|
},
|
||||||
|
message="Please create or select an organization to continue",
|
||||||
|
status=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Normal registration with session
|
||||||
return api_response(
|
return api_response(
|
||||||
data={
|
data={
|
||||||
"token": result["session"]["token"],
|
"token": result["session"]["token"],
|
||||||
@@ -652,7 +769,9 @@ def handle_oauth_callback(provider: str):
|
|||||||
},
|
},
|
||||||
message="Registration successful",
|
message="Registration successful",
|
||||||
)
|
)
|
||||||
elif result.get("flow_type") == "link":
|
|
||||||
|
# Handle link flow responses
|
||||||
|
elif flow_type == "link":
|
||||||
return api_response(
|
return api_response(
|
||||||
data={
|
data={
|
||||||
"linked_account": result["linked_account"],
|
"linked_account": result["linked_account"],
|
||||||
@@ -660,6 +779,7 @@ def handle_oauth_callback(provider: str):
|
|||||||
message="Account linked successfully",
|
message="Account linked successfully",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Fallback for unexpected result format
|
||||||
return api_response(
|
return api_response(
|
||||||
data=result,
|
data=result,
|
||||||
message="OAuth flow completed",
|
message="OAuth flow completed",
|
||||||
@@ -674,6 +794,256 @@ def handle_oauth_callback(provider: str):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api_v1_bp.route("/auth/external/select-organization", methods=["POST"])
|
||||||
|
def select_organization():
|
||||||
|
"""
|
||||||
|
Complete OAuth flow by selecting an organization.
|
||||||
|
|
||||||
|
This endpoint is called after OAuth callback when the user needs to select
|
||||||
|
which organization to log in to (when user belongs to multiple orgs).
|
||||||
|
|
||||||
|
Request body:
|
||||||
|
state: The state token from the OAuth callback
|
||||||
|
organization_id: The selected organization ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
200: Session created successfully
|
||||||
|
400: Invalid state or organization
|
||||||
|
404: Organization not found or user not a member
|
||||||
|
|
||||||
|
Response:
|
||||||
|
{
|
||||||
|
"token": "session_token",
|
||||||
|
"expires_in": 86400,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"user": {
|
||||||
|
"id": "...",
|
||||||
|
"email": "...",
|
||||||
|
"full_name": "...",
|
||||||
|
"organization_id": "..."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
data = request.json or {}
|
||||||
|
state_token = data.get("state")
|
||||||
|
organization_id = data.get("organization_id")
|
||||||
|
|
||||||
|
if not state_token:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="state is required",
|
||||||
|
status=400,
|
||||||
|
error_type="VALIDATION_ERROR",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not organization_id:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="organization_id is required",
|
||||||
|
status=400,
|
||||||
|
error_type="VALIDATION_ERROR",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Validate state and get OAuth state record
|
||||||
|
state_record = OAuthFlowService.validate_state(state_token)
|
||||||
|
if not state_record or state_record.used:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Invalid or expired state token",
|
||||||
|
status=400,
|
||||||
|
error_type="INVALID_STATE",
|
||||||
|
)
|
||||||
|
|
||||||
|
# The state should have user information from the OAuth callback
|
||||||
|
# We need to find the user that was authenticated
|
||||||
|
from gatehouse_app.models import User, AuthenticationMethod, Organization, OrganizationMember
|
||||||
|
|
||||||
|
# Find user by provider authentication
|
||||||
|
# The state record should have provider info in extra_data if set by callback
|
||||||
|
# Otherwise, we need to find the most recently created auth method
|
||||||
|
auth_method = AuthenticationMethod.query.filter_by(
|
||||||
|
method_type=state_record.provider_type,
|
||||||
|
).order_by(AuthenticationMethod.created_at.desc()).first()
|
||||||
|
|
||||||
|
if not auth_method:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Authentication session not found",
|
||||||
|
status=400,
|
||||||
|
error_type="SESSION_NOT_FOUND",
|
||||||
|
)
|
||||||
|
|
||||||
|
user = auth_method.user
|
||||||
|
|
||||||
|
# Verify user is member of selected organization
|
||||||
|
org = Organization.query.get(organization_id)
|
||||||
|
if not org:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Organization not found",
|
||||||
|
status=404,
|
||||||
|
error_type="NOT_FOUND",
|
||||||
|
)
|
||||||
|
|
||||||
|
member = OrganizationMember.query.filter_by(
|
||||||
|
user_id=user.id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not member:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="You are not a member of this organization",
|
||||||
|
status=403,
|
||||||
|
error_type="FORBIDDEN",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create session for the selected organization
|
||||||
|
from gatehouse_app.services.session_service import SessionService
|
||||||
|
session = SessionService.create_session(
|
||||||
|
user=user,
|
||||||
|
organization_id=organization_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mark state as used
|
||||||
|
state_record.mark_used()
|
||||||
|
|
||||||
|
# Audit log - login success with org selection
|
||||||
|
AuditService.log_external_auth_login(
|
||||||
|
user_id=user.id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
provider_type=state_record.provider_type.value if isinstance(state_record.provider_type, AuthMethodType) else state_record.provider_type,
|
||||||
|
provider_user_id=auth_method.provider_user_id,
|
||||||
|
auth_method_id=auth_method.id,
|
||||||
|
session_id=session.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return api_response(
|
||||||
|
data={
|
||||||
|
"token": session.token,
|
||||||
|
"expires_in": session.lifetime_seconds,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"user": {
|
||||||
|
"id": user.id,
|
||||||
|
"email": user.email,
|
||||||
|
"full_name": user.full_name,
|
||||||
|
"organization_id": organization_id,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
message="Organization selected and session created successfully",
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in select_organization: {str(e)}", exc_info=True)
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="An error occurred while selecting organization",
|
||||||
|
status=500,
|
||||||
|
error_type="INTERNAL_ERROR",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Authorization Code Exchange Endpoint
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@api_v1_bp.route("/auth/external/token", methods=["POST"])
|
||||||
|
def exchange_authorization_code():
|
||||||
|
"""
|
||||||
|
Exchange an authorization code for a session token.
|
||||||
|
|
||||||
|
This endpoint is used by external applications (like oauth2-proxy, BookStack)
|
||||||
|
to exchange the authorization code received from the OAuth callback for a
|
||||||
|
session token.
|
||||||
|
|
||||||
|
Request body (form-encoded or JSON):
|
||||||
|
grant_type: Must be "authorization_code"
|
||||||
|
code: The authorization code from the callback
|
||||||
|
redirect_uri: The redirect URI used in the original request
|
||||||
|
client_id: The client ID (optional, defaults to "external-app")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
200: Session token exchanged successfully
|
||||||
|
400: Invalid or expired authorization code
|
||||||
|
404: User not found
|
||||||
|
|
||||||
|
Response:
|
||||||
|
{
|
||||||
|
"token": "session_token",
|
||||||
|
"expires_in": 86400,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"user": {
|
||||||
|
"id": "...",
|
||||||
|
"email": "...",
|
||||||
|
"full_name": "...",
|
||||||
|
"organization_id": "..."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
# Support both JSON and form-encoded requests
|
||||||
|
if request.is_json:
|
||||||
|
data = request.json or {}
|
||||||
|
else:
|
||||||
|
data = request.form or {}
|
||||||
|
|
||||||
|
grant_type = data.get("grant_type")
|
||||||
|
code = data.get("code")
|
||||||
|
redirect_uri = data.get("redirect_uri")
|
||||||
|
client_id = data.get("client_id", "external-app")
|
||||||
|
|
||||||
|
# Validate required parameters
|
||||||
|
if grant_type and grant_type != "authorization_code":
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="Invalid grant_type. Must be 'authorization_code'",
|
||||||
|
status=400,
|
||||||
|
error_type="INVALID_GRANT_TYPE",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not code:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="code is required",
|
||||||
|
status=400,
|
||||||
|
error_type="VALIDATION_ERROR",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not redirect_uri:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message="redirect_uri is required",
|
||||||
|
status=400,
|
||||||
|
error_type="VALIDATION_ERROR",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = OAuthFlowService.exchange_authorization_code(
|
||||||
|
code=code,
|
||||||
|
client_id=client_id,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
ip_address=request.remote_addr,
|
||||||
|
)
|
||||||
|
|
||||||
|
return api_response(
|
||||||
|
data={
|
||||||
|
"token": result["token"],
|
||||||
|
"expires_in": result["expires_in"],
|
||||||
|
"token_type": result["token_type"],
|
||||||
|
"user": result["user"],
|
||||||
|
},
|
||||||
|
message="Token exchanged successfully",
|
||||||
|
)
|
||||||
|
|
||||||
|
except OAuthFlowError as e:
|
||||||
|
return api_response(
|
||||||
|
success=False,
|
||||||
|
message=e.message,
|
||||||
|
status=e.status_code,
|
||||||
|
error_type=e.error_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Helper Functions
|
# Helper Functions
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|||||||
@@ -3,7 +3,12 @@ from gatehouse_app.models.base import BaseModel
|
|||||||
from gatehouse_app.models.user import User
|
from gatehouse_app.models.user import User
|
||||||
from gatehouse_app.models.organization import Organization
|
from gatehouse_app.models.organization import Organization
|
||||||
from gatehouse_app.models.organization_member import OrganizationMember
|
from gatehouse_app.models.organization_member import OrganizationMember
|
||||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
from gatehouse_app.models.authentication_method import (
|
||||||
|
AuthenticationMethod,
|
||||||
|
ApplicationProviderConfig,
|
||||||
|
OrganizationProviderOverride,
|
||||||
|
OAuthState,
|
||||||
|
)
|
||||||
from gatehouse_app.models.session import Session
|
from gatehouse_app.models.session import Session
|
||||||
from gatehouse_app.models.audit_log import AuditLog
|
from gatehouse_app.models.audit_log import AuditLog
|
||||||
from gatehouse_app.models.oidc_client import OIDCClient
|
from gatehouse_app.models.oidc_client import OIDCClient
|
||||||
@@ -22,6 +27,9 @@ __all__ = [
|
|||||||
"Organization",
|
"Organization",
|
||||||
"OrganizationMember",
|
"OrganizationMember",
|
||||||
"AuthenticationMethod",
|
"AuthenticationMethod",
|
||||||
|
"ApplicationProviderConfig",
|
||||||
|
"OrganizationProviderOverride",
|
||||||
|
"OAuthState",
|
||||||
"Session",
|
"Session",
|
||||||
"AuditLog",
|
"AuditLog",
|
||||||
"OIDCClient",
|
"OIDCClient",
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
"""Authentication method model."""
|
"""Authentication method model."""
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
import secrets
|
||||||
from gatehouse_app.extensions import db
|
from gatehouse_app.extensions import db
|
||||||
from gatehouse_app.models.base import BaseModel
|
from gatehouse_app.models.base import BaseModel
|
||||||
from gatehouse_app.utils.constants import AuthMethodType
|
from gatehouse_app.utils.constants import AuthMethodType
|
||||||
|
from gatehouse_app.utils.encryption import encrypt, decrypt
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationMethod(BaseModel):
|
class AuthenticationMethod(BaseModel):
|
||||||
@@ -91,3 +94,287 @@ class AuthenticationMethod(BaseModel):
|
|||||||
"last_used_at": data.get("last_used_at"),
|
"last_used_at": data.get("last_used_at"),
|
||||||
"sign_count": data.get("sign_count", 0),
|
"sign_count": data.get("sign_count", 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ApplicationProviderConfig(BaseModel):
|
||||||
|
"""Application-wide OAuth provider configuration.
|
||||||
|
|
||||||
|
This model stores OAuth provider credentials at the application level,
|
||||||
|
allowing users to authenticate without needing to specify an organization first.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "application_provider_configs"
|
||||||
|
|
||||||
|
# Provider identification
|
||||||
|
provider_type = db.Column(db.String(50), nullable=False, unique=True, index=True)
|
||||||
|
|
||||||
|
# OAuth credentials (encrypted)
|
||||||
|
client_id = db.Column(db.String(255), nullable=False)
|
||||||
|
client_secret_encrypted = db.Column(db.String(512), nullable=True)
|
||||||
|
|
||||||
|
# Provider status
|
||||||
|
is_enabled = db.Column(db.Boolean, default=True, nullable=False)
|
||||||
|
|
||||||
|
# Default redirect URL
|
||||||
|
default_redirect_url = db.Column(db.String(2048), nullable=True)
|
||||||
|
|
||||||
|
# Provider-specific settings (JSON)
|
||||||
|
additional_config = db.Column(db.JSON, nullable=True)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
organization_overrides = db.relationship(
|
||||||
|
"OrganizationProviderOverride",
|
||||||
|
back_populates="application_config",
|
||||||
|
foreign_keys="OrganizationProviderOverride.provider_type",
|
||||||
|
primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type",
|
||||||
|
cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
"""String representation of ApplicationProviderConfig."""
|
||||||
|
return f"<ApplicationProviderConfig provider={self.provider_type} enabled={self.is_enabled}>"
|
||||||
|
|
||||||
|
def set_client_secret(self, plaintext_secret: str):
|
||||||
|
"""Encrypt and store client secret.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plaintext_secret: The plaintext OAuth client secret
|
||||||
|
"""
|
||||||
|
if plaintext_secret:
|
||||||
|
self.client_secret_encrypted = encrypt(plaintext_secret)
|
||||||
|
|
||||||
|
def get_client_secret(self) -> str:
|
||||||
|
"""Decrypt and return client secret.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The plaintext OAuth client secret
|
||||||
|
"""
|
||||||
|
if self.client_secret_encrypted:
|
||||||
|
return decrypt(self.client_secret_encrypted)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def to_dict(self, exclude=None):
|
||||||
|
"""Convert to dictionary, excluding sensitive fields."""
|
||||||
|
exclude = exclude or []
|
||||||
|
# Always exclude encrypted client secret
|
||||||
|
exclude.append("client_secret_encrypted")
|
||||||
|
return super().to_dict(exclude=exclude)
|
||||||
|
|
||||||
|
|
||||||
|
class OrganizationProviderOverride(BaseModel):
|
||||||
|
"""Organization-specific OAuth configuration overrides.
|
||||||
|
|
||||||
|
This model allows organizations to override application-level OAuth settings
|
||||||
|
for enterprise SSO scenarios or custom provider configurations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "organization_provider_overrides"
|
||||||
|
|
||||||
|
# References
|
||||||
|
organization_id = db.Column(
|
||||||
|
db.String(36), db.ForeignKey("organizations.id"),
|
||||||
|
nullable=False, index=True
|
||||||
|
)
|
||||||
|
provider_type = db.Column(db.String(50), nullable=False, index=True)
|
||||||
|
|
||||||
|
# Override OAuth credentials (encrypted, nullable - only if overriding)
|
||||||
|
client_id = db.Column(db.String(255), nullable=True)
|
||||||
|
client_secret_encrypted = db.Column(db.String(512), nullable=True)
|
||||||
|
|
||||||
|
# Provider status
|
||||||
|
is_enabled = db.Column(db.Boolean, default=True, nullable=False)
|
||||||
|
|
||||||
|
# Redirect URL override
|
||||||
|
redirect_url_override = db.Column(db.String(2048), nullable=True)
|
||||||
|
|
||||||
|
# Provider-specific settings override (JSON)
|
||||||
|
additional_config = db.Column(db.JSON, nullable=True)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
organization = db.relationship("Organization", backref="provider_overrides")
|
||||||
|
application_config = db.relationship(
|
||||||
|
"ApplicationProviderConfig",
|
||||||
|
back_populates="organization_overrides",
|
||||||
|
foreign_keys=[provider_type],
|
||||||
|
primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type",
|
||||||
|
viewonly=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Unique constraint on (organization_id, provider_type)
|
||||||
|
__table_args__ = (
|
||||||
|
db.UniqueConstraint(
|
||||||
|
"organization_id", "provider_type",
|
||||||
|
name="uix_org_provider_type"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
"""String representation of OrganizationProviderOverride."""
|
||||||
|
return f"<OrganizationProviderOverride org={self.organization_id} provider={self.provider_type}>"
|
||||||
|
|
||||||
|
def set_client_secret(self, plaintext_secret: str):
|
||||||
|
"""Encrypt and store client secret override.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plaintext_secret: The plaintext OAuth client secret
|
||||||
|
"""
|
||||||
|
if plaintext_secret:
|
||||||
|
self.client_secret_encrypted = encrypt(plaintext_secret)
|
||||||
|
|
||||||
|
def get_client_secret(self) -> str:
|
||||||
|
"""Decrypt and return client secret override.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The plaintext OAuth client secret
|
||||||
|
"""
|
||||||
|
if self.client_secret_encrypted:
|
||||||
|
return decrypt(self.client_secret_encrypted)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def to_dict(self, exclude=None):
|
||||||
|
"""Convert to dictionary, excluding sensitive fields."""
|
||||||
|
exclude = exclude or []
|
||||||
|
# Always exclude encrypted client secret
|
||||||
|
exclude.append("client_secret_encrypted")
|
||||||
|
return super().to_dict(exclude=exclude)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthState(BaseModel):
|
||||||
|
"""OAuth flow state tracking.
|
||||||
|
|
||||||
|
This model tracks OAuth authentication flow state, including PKCE parameters
|
||||||
|
and organization context (which is now optional to support login flows where
|
||||||
|
the organization isn't known until after authentication).
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "oauth_states"
|
||||||
|
|
||||||
|
# OAuth state parameter (unique, used for CSRF protection)
|
||||||
|
state = db.Column(db.String(64), unique=True, nullable=False, index=True)
|
||||||
|
|
||||||
|
# Flow type: "login", "register", "link"
|
||||||
|
flow_type = db.Column(db.String(50), nullable=False)
|
||||||
|
|
||||||
|
# Provider type
|
||||||
|
provider_type = db.Column(db.String(50), nullable=False)
|
||||||
|
|
||||||
|
# User context (optional - not set for login/register flows)
|
||||||
|
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True)
|
||||||
|
|
||||||
|
# Organization context (NOW OPTIONAL - for SSO discovery or post-auth)
|
||||||
|
organization_id = db.Column(
|
||||||
|
db.String(36), db.ForeignKey("organizations.id"),
|
||||||
|
nullable=True, index=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# PKCE parameters
|
||||||
|
nonce = db.Column(db.String(128), nullable=True)
|
||||||
|
code_verifier = db.Column(db.String(128), nullable=True)
|
||||||
|
code_challenge = db.Column(db.String(128), nullable=True)
|
||||||
|
|
||||||
|
# OAuth parameters
|
||||||
|
redirect_uri = db.Column(db.String(2048), nullable=True)
|
||||||
|
|
||||||
|
# Post-auth redirect (for frontend routing)
|
||||||
|
return_url = db.Column(db.String(2048), nullable=True)
|
||||||
|
|
||||||
|
# Additional state data
|
||||||
|
extra_data = db.Column(db.JSON, nullable=True)
|
||||||
|
|
||||||
|
# Expiration and usage tracking
|
||||||
|
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||||
|
used = db.Column(db.Boolean, default=False, nullable=False)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
user = db.relationship("User", backref="oauth_states")
|
||||||
|
organization = db.relationship("Organization", backref="oauth_states")
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
"""String representation of OAuthState."""
|
||||||
|
return f"<OAuthState state={self.state[:8]}... flow={self.flow_type} provider={self.provider_type}>"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_state(
|
||||||
|
cls,
|
||||||
|
flow_type: str,
|
||||||
|
provider_type: str,
|
||||||
|
user_id: str = None,
|
||||||
|
organization_id: str = None,
|
||||||
|
redirect_uri: str = None,
|
||||||
|
return_url: str = None,
|
||||||
|
code_verifier: str = None,
|
||||||
|
code_challenge: str = None,
|
||||||
|
nonce: str = None,
|
||||||
|
extra_data: dict = None,
|
||||||
|
lifetime_seconds: int = 600
|
||||||
|
):
|
||||||
|
"""Create a new OAuth state with auto-generated state parameter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
flow_type: Type of flow ("login", "register", "link")
|
||||||
|
provider_type: OAuth provider type
|
||||||
|
user_id: Optional user ID for authenticated flows
|
||||||
|
organization_id: Optional organization ID
|
||||||
|
redirect_uri: OAuth callback URI
|
||||||
|
return_url: Post-auth redirect destination
|
||||||
|
code_verifier: PKCE code verifier
|
||||||
|
code_challenge: PKCE code challenge
|
||||||
|
nonce: OpenID Connect nonce
|
||||||
|
extra_data: Additional state data
|
||||||
|
lifetime_seconds: How long the state is valid (default 10 minutes)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New OAuthState instance
|
||||||
|
"""
|
||||||
|
state = secrets.token_urlsafe(32)
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds)
|
||||||
|
|
||||||
|
oauth_state = cls(
|
||||||
|
state=state,
|
||||||
|
flow_type=flow_type,
|
||||||
|
provider_type=provider_type,
|
||||||
|
user_id=user_id,
|
||||||
|
organization_id=organization_id,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
return_url=return_url,
|
||||||
|
code_verifier=code_verifier,
|
||||||
|
code_challenge=code_challenge,
|
||||||
|
nonce=nonce,
|
||||||
|
extra_data=extra_data,
|
||||||
|
expires_at=expires_at,
|
||||||
|
used=False
|
||||||
|
)
|
||||||
|
oauth_state.save()
|
||||||
|
return oauth_state
|
||||||
|
|
||||||
|
def is_valid(self) -> bool:
|
||||||
|
"""Check if the OAuth state is still valid.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if state hasn't expired and hasn't been used
|
||||||
|
"""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
# Make expires_at timezone-aware if it's naive (database returns naive datetimes)
|
||||||
|
expires_at = self.expires_at
|
||||||
|
if expires_at.tzinfo is None:
|
||||||
|
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||||
|
return not self.used and expires_at > now
|
||||||
|
|
||||||
|
def mark_used(self):
|
||||||
|
"""Mark the state as used to prevent replay attacks."""
|
||||||
|
self.used = True
|
||||||
|
self.save()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def cleanup_expired(cls):
|
||||||
|
"""Remove expired OAuth states."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
cls.query.filter(cls.expires_at < now).delete()
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
def to_dict(self, exclude=None):
|
||||||
|
"""Convert to dictionary, excluding sensitive fields."""
|
||||||
|
exclude = exclude or []
|
||||||
|
# Exclude code_verifier as it's sensitive
|
||||||
|
exclude.append("code_verifier")
|
||||||
|
return super().to_dict(exclude=exclude)
|
||||||
|
|||||||
@@ -2,12 +2,17 @@
|
|||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple, Dict, Any
|
||||||
|
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
|
|
||||||
from gatehouse_app.extensions import db
|
from gatehouse_app.extensions import db
|
||||||
from gatehouse_app.models import User, AuthenticationMethod
|
from gatehouse_app.models import User, AuthenticationMethod
|
||||||
|
from gatehouse_app.models.authentication_method import (
|
||||||
|
OAuthState,
|
||||||
|
ApplicationProviderConfig,
|
||||||
|
OrganizationProviderOverride
|
||||||
|
)
|
||||||
from gatehouse_app.models.base import BaseModel
|
from gatehouse_app.models.base import BaseModel
|
||||||
from gatehouse_app.utils.constants import AuthMethodType
|
from gatehouse_app.utils.constants import AuthMethodType
|
||||||
from gatehouse_app.services.audit_service import AuditService
|
from gatehouse_app.services.audit_service import AuditService
|
||||||
@@ -25,95 +30,12 @@ class ExternalAuthError(Exception):
|
|||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
class OAuthState(BaseModel):
|
|
||||||
"""Temporary OAuth state storage for secure flow management."""
|
|
||||||
|
|
||||||
__tablename__ = "oauth_states"
|
|
||||||
|
|
||||||
# State identifier (used in OAuth redirects)
|
|
||||||
state = db.Column(db.String(64), unique=True, nullable=False, index=True)
|
|
||||||
|
|
||||||
# Flow type
|
|
||||||
flow_type = db.Column(db.String(50), nullable=False) # 'link', 'login', 'register'
|
|
||||||
|
|
||||||
# User context
|
|
||||||
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True, index=True)
|
|
||||||
organization_id = db.Column(
|
|
||||||
db.String(36), db.ForeignKey("organizations.id"), nullable=True, index=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Provider information
|
|
||||||
provider_type = db.Column(db.String(50), nullable=False)
|
|
||||||
|
|
||||||
# OAuth parameters
|
|
||||||
nonce = db.Column(db.String(128), nullable=True)
|
|
||||||
code_verifier = db.Column(db.String(128), nullable=True)
|
|
||||||
code_challenge = db.Column(db.String(128), nullable=True)
|
|
||||||
redirect_uri = db.Column(db.String(2048), nullable=True)
|
|
||||||
|
|
||||||
# Additional state data
|
|
||||||
extra_data = db.Column(db.JSON, nullable=True)
|
|
||||||
|
|
||||||
# Expiration
|
|
||||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
|
||||||
|
|
||||||
# Status
|
|
||||||
used = db.Column(db.Boolean, default=False, nullable=False)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create_state(
|
|
||||||
cls,
|
|
||||||
flow_type: str,
|
|
||||||
provider_type: AuthMethodType,
|
|
||||||
user_id: str = None,
|
|
||||||
organization_id: str = None,
|
|
||||||
redirect_uri: str = None,
|
|
||||||
nonce: str = None,
|
|
||||||
code_verifier: str = None,
|
|
||||||
code_challenge: str = None,
|
|
||||||
extra_data: dict = None,
|
|
||||||
lifetime_seconds: int = 600,
|
|
||||||
) -> "OAuthState":
|
|
||||||
"""Create a new OAuth state record."""
|
|
||||||
state = secrets.token_urlsafe(32)
|
|
||||||
expires_at = datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds)
|
|
||||||
|
|
||||||
return cls.create(
|
|
||||||
state=state,
|
|
||||||
flow_type=flow_type,
|
|
||||||
provider_type=provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type,
|
|
||||||
user_id=user_id,
|
|
||||||
organization_id=organization_id,
|
|
||||||
redirect_uri=redirect_uri,
|
|
||||||
nonce=nonce or secrets.token_urlsafe(16),
|
|
||||||
code_verifier=code_verifier,
|
|
||||||
code_challenge=code_challenge,
|
|
||||||
extra_data=extra_data,
|
|
||||||
expires_at=expires_at,
|
|
||||||
)
|
|
||||||
|
|
||||||
def is_valid(self) -> bool:
|
|
||||||
"""Check if state is still valid."""
|
|
||||||
return (
|
|
||||||
not self.used
|
|
||||||
and self.expires_at
|
|
||||||
and self.expires_at.replace(tzinfo=timezone.utc) > datetime.now(timezone.utc)
|
|
||||||
)
|
|
||||||
|
|
||||||
def mark_used(self):
|
|
||||||
"""Mark state as used."""
|
|
||||||
self.used = True
|
|
||||||
self.save()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def cleanup_expired(cls):
|
|
||||||
"""Remove expired states."""
|
|
||||||
cls.query.filter(cls.expires_at < datetime.now(timezone.utc)).delete()
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
|
|
||||||
class ExternalProviderConfig(BaseModel):
|
class ExternalProviderConfig(BaseModel):
|
||||||
"""OAuth provider configuration per organization."""
|
"""OAuth provider configuration per organization.
|
||||||
|
|
||||||
|
DEPRECATED: This model is maintained for backward compatibility only.
|
||||||
|
Use ApplicationProviderConfig and OrganizationProviderOverride instead.
|
||||||
|
"""
|
||||||
|
|
||||||
__tablename__ = "external_provider_configs"
|
__tablename__ = "external_provider_configs"
|
||||||
|
|
||||||
@@ -198,31 +120,594 @@ class ExternalProviderConfig(BaseModel):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderConfigAdapter:
|
||||||
|
"""
|
||||||
|
Adapter to provide a unified interface for provider configuration.
|
||||||
|
|
||||||
|
This merges application-level config with optional organization overrides,
|
||||||
|
presenting a single config object that works with existing OAuth flow code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
app_config: ApplicationProviderConfig,
|
||||||
|
org_override: Optional[OrganizationProviderOverride] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize adapter with app config and optional org override.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_config: Application-level provider configuration
|
||||||
|
org_override: Optional organization-specific override
|
||||||
|
"""
|
||||||
|
self.app_config = app_config
|
||||||
|
self.org_override = org_override
|
||||||
|
self.provider_type = app_config.provider_type
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client_id(self) -> str:
|
||||||
|
"""Get effective client ID (override takes precedence)."""
|
||||||
|
if self.org_override and self.org_override.client_id:
|
||||||
|
return self.org_override.client_id
|
||||||
|
return self.app_config.client_id
|
||||||
|
|
||||||
|
def get_client_secret(self) -> str:
|
||||||
|
"""Get effective client secret (override takes precedence)."""
|
||||||
|
if self.org_override and self.org_override.client_secret_encrypted:
|
||||||
|
return self.org_override.get_client_secret()
|
||||||
|
return self.app_config.get_client_secret()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def auth_url(self) -> str:
|
||||||
|
"""Get authorization URL from app config."""
|
||||||
|
# Provider endpoints are not overridable
|
||||||
|
return self._get_provider_endpoint('auth_url')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def token_url(self) -> str:
|
||||||
|
"""Get token URL from app config."""
|
||||||
|
return self._get_provider_endpoint('token_url')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def userinfo_url(self) -> str:
|
||||||
|
"""Get userinfo URL from app config."""
|
||||||
|
return self._get_provider_endpoint('userinfo_url')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def jwks_url(self) -> str:
|
||||||
|
"""Get JWKS URL from app config."""
|
||||||
|
return self._get_provider_endpoint('jwks_url')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def scopes(self) -> list:
|
||||||
|
"""Get effective scopes (merged from app config and override)."""
|
||||||
|
base_scopes = self.app_config.additional_config.get('scopes', []) if self.app_config.additional_config else []
|
||||||
|
if self.org_override and self.org_override.additional_config:
|
||||||
|
override_scopes = self.org_override.additional_config.get('scopes')
|
||||||
|
if override_scopes is not None:
|
||||||
|
return override_scopes
|
||||||
|
return base_scopes or ['openid', 'profile', 'email']
|
||||||
|
|
||||||
|
@property
|
||||||
|
def redirect_uris(self) -> list:
|
||||||
|
"""Get effective redirect URIs."""
|
||||||
|
# Use override redirect URL if present, otherwise app default
|
||||||
|
if self.org_override and self.org_override.redirect_url_override:
|
||||||
|
return [self.org_override.redirect_url_override]
|
||||||
|
if self.app_config.default_redirect_url:
|
||||||
|
return [self.app_config.default_redirect_url]
|
||||||
|
return []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def settings(self) -> dict:
|
||||||
|
"""Get merged settings (app config + org override)."""
|
||||||
|
settings = {}
|
||||||
|
if self.app_config.additional_config:
|
||||||
|
settings.update(self.app_config.additional_config)
|
||||||
|
if self.org_override and self.org_override.additional_config:
|
||||||
|
settings.update(self.org_override.additional_config)
|
||||||
|
return settings
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_active(self) -> bool:
|
||||||
|
"""Check if provider is active (both app and org must be enabled)."""
|
||||||
|
app_enabled = self.app_config.is_enabled
|
||||||
|
org_enabled = True if not self.org_override else self.org_override.is_enabled
|
||||||
|
return app_enabled and org_enabled
|
||||||
|
|
||||||
|
def is_redirect_uri_allowed(self, uri: str) -> bool:
|
||||||
|
"""Check if redirect URI is allowed."""
|
||||||
|
return uri in self.redirect_uris
|
||||||
|
|
||||||
|
def _get_provider_endpoint(self, endpoint_name: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Get provider endpoint from app config additional_config.
|
||||||
|
|
||||||
|
For application-wide configs, endpoints are stored in additional_config JSON.
|
||||||
|
"""
|
||||||
|
if not self.app_config.additional_config:
|
||||||
|
return None
|
||||||
|
return self.app_config.additional_config.get(endpoint_name)
|
||||||
|
|
||||||
|
|
||||||
class ExternalAuthService:
|
class ExternalAuthService:
|
||||||
"""Service for external authentication operations."""
|
"""Service for external authentication operations."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_provider_config(
|
def get_provider_config(
|
||||||
cls,
|
cls,
|
||||||
organization_id: str,
|
|
||||||
provider_type: AuthMethodType,
|
provider_type: AuthMethodType,
|
||||||
) -> ExternalProviderConfig:
|
organization_id: Optional[str] = None,
|
||||||
"""Get provider configuration for organization."""
|
) -> ProviderConfigAdapter:
|
||||||
|
"""
|
||||||
|
Get provider configuration for authentication.
|
||||||
|
|
||||||
|
This method retrieves application-wide provider configuration and merges
|
||||||
|
it with organization-specific overrides if present. Both the application
|
||||||
|
config and organization override (if present) must be enabled for the
|
||||||
|
provider to be considered active.
|
||||||
|
|
||||||
|
Configuration Precedence:
|
||||||
|
1. Application-level config provides the baseline configuration
|
||||||
|
2. Organization override can override client_id and client_secret (for SSO)
|
||||||
|
3. Both must be enabled for the provider to work
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_type: The OAuth provider type (google, github, etc.)
|
||||||
|
organization_id: Optional organization ID for override lookup
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ProviderConfigAdapter: Unified config object with merged settings
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ExternalAuthError: If provider is not configured or disabled
|
||||||
|
"""
|
||||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||||
config = ExternalProviderConfig.query.filter_by(
|
|
||||||
organization_id=organization_id,
|
# Get application-wide config
|
||||||
provider_type=provider_type_str,
|
app_config = ApplicationProviderConfig.query.filter_by(
|
||||||
is_active=True,
|
provider_type=provider_type_str
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not config:
|
if not app_config:
|
||||||
raise ExternalAuthError(
|
raise ExternalAuthError(
|
||||||
f"{provider_type_str.title()} OAuth is not configured for this organization",
|
f"{provider_type_str.title()} OAuth is not configured for this application",
|
||||||
"PROVIDER_NOT_CONFIGURED",
|
"PROVIDER_NOT_CONFIGURED",
|
||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not app_config.is_enabled:
|
||||||
|
raise ExternalAuthError(
|
||||||
|
f"{provider_type_str.title()} OAuth is currently disabled",
|
||||||
|
"PROVIDER_DISABLED",
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for organization-specific override
|
||||||
|
org_override = None
|
||||||
|
if organization_id:
|
||||||
|
org_override = OrganizationProviderOverride.query.filter_by(
|
||||||
|
organization_id=organization_id,
|
||||||
|
provider_type=provider_type_str
|
||||||
|
).first()
|
||||||
|
|
||||||
|
# If override exists but is disabled, provider is not available for this org
|
||||||
|
if org_override and not org_override.is_enabled:
|
||||||
|
raise ExternalAuthError(
|
||||||
|
f"{provider_type_str.title()} OAuth is disabled for this organization",
|
||||||
|
"PROVIDER_DISABLED_FOR_ORG",
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return adapter with merged configuration
|
||||||
|
return ProviderConfigAdapter(app_config, org_override)
|
||||||
|
|
||||||
|
# ==================== Application-Wide Provider Management ====================
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_app_provider_config(
|
||||||
|
cls,
|
||||||
|
provider_type: str,
|
||||||
|
client_id: str,
|
||||||
|
client_secret: str,
|
||||||
|
**kwargs
|
||||||
|
) -> ApplicationProviderConfig:
|
||||||
|
"""
|
||||||
|
Create application-wide provider configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_type: Provider type (google, github, etc.)
|
||||||
|
client_id: OAuth client ID
|
||||||
|
client_secret: OAuth client secret
|
||||||
|
**kwargs: Additional config (auth_url, token_url, userinfo_url, scopes, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApplicationProviderConfig: Created configuration
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ExternalAuthError: If provider already exists
|
||||||
|
"""
|
||||||
|
# Check if provider already exists
|
||||||
|
existing = ApplicationProviderConfig.query.filter_by(
|
||||||
|
provider_type=provider_type
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
raise ExternalAuthError(
|
||||||
|
f"Provider {provider_type} already exists",
|
||||||
|
"PROVIDER_EXISTS",
|
||||||
|
400
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build additional_config with endpoints and settings
|
||||||
|
additional_config = {}
|
||||||
|
for key in ['auth_url', 'token_url', 'userinfo_url', 'jwks_url', 'scopes']:
|
||||||
|
if key in kwargs:
|
||||||
|
additional_config[key] = kwargs.pop(key)
|
||||||
|
|
||||||
|
# Add any extra settings
|
||||||
|
if 'settings' in kwargs:
|
||||||
|
additional_config.update(kwargs.pop('settings'))
|
||||||
|
|
||||||
|
# Create new config
|
||||||
|
config = ApplicationProviderConfig(
|
||||||
|
provider_type=provider_type,
|
||||||
|
client_id=client_id,
|
||||||
|
is_enabled=kwargs.get('is_enabled', True),
|
||||||
|
default_redirect_url=kwargs.get('default_redirect_url'),
|
||||||
|
additional_config=additional_config
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set encrypted secret
|
||||||
|
config.set_client_secret(client_secret)
|
||||||
|
config.save()
|
||||||
|
|
||||||
|
logger.info(f"Created application provider config for {provider_type}")
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def update_app_provider_config(
|
||||||
|
cls,
|
||||||
|
provider_type: str,
|
||||||
|
**updates
|
||||||
|
) -> ApplicationProviderConfig:
|
||||||
|
"""
|
||||||
|
Update application-wide provider configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_type: Provider type to update
|
||||||
|
**updates: Fields to update (client_id, client_secret, is_enabled, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApplicationProviderConfig: Updated configuration
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ExternalAuthError: If provider not found
|
||||||
|
"""
|
||||||
|
config = ApplicationProviderConfig.query.filter_by(
|
||||||
|
provider_type=provider_type
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not config:
|
||||||
|
raise ExternalAuthError(
|
||||||
|
f"Provider {provider_type} not found",
|
||||||
|
"PROVIDER_NOT_FOUND",
|
||||||
|
404
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update simple fields
|
||||||
|
if 'client_id' in updates:
|
||||||
|
config.client_id = updates['client_id']
|
||||||
|
|
||||||
|
if 'client_secret' in updates:
|
||||||
|
config.set_client_secret(updates['client_secret'])
|
||||||
|
|
||||||
|
if 'is_enabled' in updates:
|
||||||
|
config.is_enabled = updates['is_enabled']
|
||||||
|
|
||||||
|
if 'default_redirect_url' in updates:
|
||||||
|
config.default_redirect_url = updates['default_redirect_url']
|
||||||
|
|
||||||
|
# Update additional_config JSON fields
|
||||||
|
if config.additional_config is None:
|
||||||
|
config.additional_config = {}
|
||||||
|
|
||||||
|
for key in ['auth_url', 'token_url', 'userinfo_url', 'jwks_url', 'scopes']:
|
||||||
|
if key in updates:
|
||||||
|
config.additional_config[key] = updates[key]
|
||||||
|
|
||||||
|
if 'settings' in updates:
|
||||||
|
config.additional_config.update(updates['settings'])
|
||||||
|
|
||||||
|
config.save()
|
||||||
|
logger.info(f"Updated application provider config for {provider_type}")
|
||||||
|
return config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_app_provider_config(cls, provider_type: str) -> ApplicationProviderConfig:
|
||||||
|
"""
|
||||||
|
Get application-wide provider configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_type: Provider type to retrieve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApplicationProviderConfig: Provider configuration
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ExternalAuthError: If provider not found
|
||||||
|
"""
|
||||||
|
config = ApplicationProviderConfig.query.filter_by(
|
||||||
|
provider_type=provider_type
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not config:
|
||||||
|
raise ExternalAuthError(
|
||||||
|
f"Provider {provider_type} not found",
|
||||||
|
"PROVIDER_NOT_FOUND",
|
||||||
|
404
|
||||||
|
)
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_app_provider_configs(cls) -> list:
|
||||||
|
"""
|
||||||
|
List all application-wide provider configurations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: List of provider configuration dictionaries
|
||||||
|
"""
|
||||||
|
configs = ApplicationProviderConfig.query.all()
|
||||||
|
return [config.to_dict() for config in configs]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def delete_app_provider_config(cls, provider_type: str) -> bool:
|
||||||
|
"""
|
||||||
|
Delete application-wide provider configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_type: Provider type to delete
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if deleted successfully
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ExternalAuthError: If provider not found
|
||||||
|
"""
|
||||||
|
config = ApplicationProviderConfig.query.filter_by(
|
||||||
|
provider_type=provider_type
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not config:
|
||||||
|
raise ExternalAuthError(
|
||||||
|
f"Provider {provider_type} not found",
|
||||||
|
"PROVIDER_NOT_FOUND",
|
||||||
|
404
|
||||||
|
)
|
||||||
|
|
||||||
|
config.delete()
|
||||||
|
logger.info(f"Deleted application provider config for {provider_type}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# ==================== Organization Provider Override Management ====================
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_org_provider_override(
|
||||||
|
cls,
|
||||||
|
organization_id: str,
|
||||||
|
provider_type: str,
|
||||||
|
**kwargs
|
||||||
|
) -> OrganizationProviderOverride:
|
||||||
|
"""
|
||||||
|
Create organization-specific provider override (for SSO scenarios).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
organization_id: Organization ID
|
||||||
|
provider_type: Provider type to override
|
||||||
|
**kwargs: Override fields (client_id, client_secret, redirect_url, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OrganizationProviderOverride: Created override
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ExternalAuthError: If provider doesn't exist or override already exists
|
||||||
|
"""
|
||||||
|
# Verify app-level provider exists
|
||||||
|
app_config = ApplicationProviderConfig.query.filter_by(
|
||||||
|
provider_type=provider_type
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not app_config:
|
||||||
|
raise ExternalAuthError(
|
||||||
|
f"Application provider {provider_type} must be configured first",
|
||||||
|
"PROVIDER_NOT_CONFIGURED",
|
||||||
|
400
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if override already exists
|
||||||
|
existing = OrganizationProviderOverride.query.filter_by(
|
||||||
|
organization_id=organization_id,
|
||||||
|
provider_type=provider_type
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
raise ExternalAuthError(
|
||||||
|
f"Override for {provider_type} already exists for this organization",
|
||||||
|
"OVERRIDE_EXISTS",
|
||||||
|
400
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build additional_config from kwargs
|
||||||
|
additional_config = {}
|
||||||
|
if 'settings' in kwargs:
|
||||||
|
additional_config.update(kwargs.pop('settings'))
|
||||||
|
if 'scopes' in kwargs:
|
||||||
|
additional_config['scopes'] = kwargs.pop('scopes')
|
||||||
|
|
||||||
|
# Create override
|
||||||
|
override = OrganizationProviderOverride(
|
||||||
|
organization_id=organization_id,
|
||||||
|
provider_type=provider_type,
|
||||||
|
client_id=kwargs.get('client_id'),
|
||||||
|
is_enabled=kwargs.get('is_enabled', True),
|
||||||
|
redirect_url_override=kwargs.get('redirect_url_override'),
|
||||||
|
additional_config=additional_config if additional_config else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set encrypted secret if provided
|
||||||
|
if 'client_secret' in kwargs:
|
||||||
|
override.set_client_secret(kwargs['client_secret'])
|
||||||
|
|
||||||
|
override.save()
|
||||||
|
logger.info(f"Created org override for {provider_type} in org {organization_id}")
|
||||||
|
return override
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def update_org_provider_override(
|
||||||
|
cls,
|
||||||
|
organization_id: str,
|
||||||
|
provider_type: str,
|
||||||
|
**updates
|
||||||
|
) -> OrganizationProviderOverride:
|
||||||
|
"""
|
||||||
|
Update organization-specific provider override.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
organization_id: Organization ID
|
||||||
|
provider_type: Provider type
|
||||||
|
**updates: Fields to update
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OrganizationProviderOverride: Updated override
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ExternalAuthError: If override not found
|
||||||
|
"""
|
||||||
|
override = OrganizationProviderOverride.query.filter_by(
|
||||||
|
organization_id=organization_id,
|
||||||
|
provider_type=provider_type
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not override:
|
||||||
|
raise ExternalAuthError(
|
||||||
|
f"Override for {provider_type} not found for this organization",
|
||||||
|
"OVERRIDE_NOT_FOUND",
|
||||||
|
404
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update simple fields
|
||||||
|
if 'client_id' in updates:
|
||||||
|
override.client_id = updates['client_id']
|
||||||
|
|
||||||
|
if 'client_secret' in updates:
|
||||||
|
override.set_client_secret(updates['client_secret'])
|
||||||
|
|
||||||
|
if 'is_enabled' in updates:
|
||||||
|
override.is_enabled = updates['is_enabled']
|
||||||
|
|
||||||
|
if 'redirect_url_override' in updates:
|
||||||
|
override.redirect_url_override = updates['redirect_url_override']
|
||||||
|
|
||||||
|
# Update additional_config
|
||||||
|
if 'settings' in updates or 'scopes' in updates:
|
||||||
|
if override.additional_config is None:
|
||||||
|
override.additional_config = {}
|
||||||
|
|
||||||
|
if 'settings' in updates:
|
||||||
|
override.additional_config.update(updates['settings'])
|
||||||
|
if 'scopes' in updates:
|
||||||
|
override.additional_config['scopes'] = updates['scopes']
|
||||||
|
|
||||||
|
override.save()
|
||||||
|
logger.info(f"Updated org override for {provider_type} in org {organization_id}")
|
||||||
|
return override
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_org_provider_override(
|
||||||
|
cls,
|
||||||
|
organization_id: str,
|
||||||
|
provider_type: str
|
||||||
|
) -> OrganizationProviderOverride:
|
||||||
|
"""
|
||||||
|
Get organization-specific provider override.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
organization_id: Organization ID
|
||||||
|
provider_type: Provider type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OrganizationProviderOverride: Provider override
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ExternalAuthError: If override not found
|
||||||
|
"""
|
||||||
|
override = OrganizationProviderOverride.query.filter_by(
|
||||||
|
organization_id=organization_id,
|
||||||
|
provider_type=provider_type
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not override:
|
||||||
|
raise ExternalAuthError(
|
||||||
|
f"Override for {provider_type} not found for this organization",
|
||||||
|
"OVERRIDE_NOT_FOUND",
|
||||||
|
404
|
||||||
|
)
|
||||||
|
|
||||||
|
return override
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_org_provider_overrides(cls, organization_id: str) -> list:
|
||||||
|
"""
|
||||||
|
List all provider overrides for an organization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
organization_id: Organization ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: List of override configuration dictionaries
|
||||||
|
"""
|
||||||
|
overrides = OrganizationProviderOverride.query.filter_by(
|
||||||
|
organization_id=organization_id
|
||||||
|
).all()
|
||||||
|
return [override.to_dict() for override in overrides]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def delete_org_provider_override(
|
||||||
|
cls,
|
||||||
|
organization_id: str,
|
||||||
|
provider_type: str
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Delete organization-specific provider override.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
organization_id: Organization ID
|
||||||
|
provider_type: Provider type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if deleted successfully
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ExternalAuthError: If override not found
|
||||||
|
"""
|
||||||
|
override = OrganizationProviderOverride.query.filter_by(
|
||||||
|
organization_id=organization_id,
|
||||||
|
provider_type=provider_type
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not override:
|
||||||
|
raise ExternalAuthError(
|
||||||
|
f"Override for {provider_type} not found for this organization",
|
||||||
|
"OVERRIDE_NOT_FOUND",
|
||||||
|
404
|
||||||
|
)
|
||||||
|
|
||||||
|
override.delete()
|
||||||
|
logger.info(f"Deleted org override for {provider_type} in org {organization_id}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# ==================== OAuth Flow Methods (Updated for New Architecture) ====================
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def initiate_link_flow(
|
def initiate_link_flow(
|
||||||
@@ -240,8 +725,8 @@ class ExternalAuthService:
|
|||||||
"""
|
"""
|
||||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||||
|
|
||||||
# Get provider config
|
# Get provider config (with org override if applicable)
|
||||||
config = cls.get_provider_config(organization_id, provider_type)
|
config = cls.get_provider_config(provider_type, organization_id)
|
||||||
|
|
||||||
# Validate redirect URI
|
# Validate redirect URI
|
||||||
if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri):
|
if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri):
|
||||||
@@ -261,13 +746,13 @@ class ExternalAuthService:
|
|||||||
provider_type=provider_type,
|
provider_type=provider_type,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
redirect_uri=redirect_uri or config.redirect_uris[0],
|
redirect_uri=redirect_uri or config.redirect_uris[0] if config.redirect_uris else None,
|
||||||
code_verifier=code_verifier,
|
code_verifier=code_verifier,
|
||||||
code_challenge=code_challenge,
|
code_challenge=code_challenge,
|
||||||
lifetime_seconds=600,
|
lifetime_seconds=600,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build authorization URL (simplified - in production would use provider-specific implementation)
|
# Build authorization URL
|
||||||
auth_url = cls._build_authorization_url(
|
auth_url = cls._build_authorization_url(
|
||||||
config=config,
|
config=config,
|
||||||
state=state,
|
state=state,
|
||||||
@@ -338,12 +823,12 @@ class ExternalAuthService:
|
|||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get provider config
|
# Get provider config (with org override if applicable)
|
||||||
config = cls.get_provider_config(
|
config = cls.get_provider_config(
|
||||||
state_record.organization_id, provider_type
|
provider_type, state_record.organization_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Exchange code for tokens (simplified - in production would use provider-specific implementation)
|
# Exchange code for tokens
|
||||||
tokens = cls._exchange_code(
|
tokens = cls._exchange_code(
|
||||||
config=config,
|
config=config,
|
||||||
code=authorization_code,
|
code=authorization_code,
|
||||||
@@ -440,8 +925,8 @@ class ExternalAuthService:
|
|||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get provider config
|
# Get provider config (with org override if applicable)
|
||||||
config = cls.get_provider_config(organization_id, provider_type)
|
config = cls.get_provider_config(provider_type, organization_id)
|
||||||
|
|
||||||
# Exchange code for tokens
|
# Exchange code for tokens
|
||||||
tokens = cls._exchange_code(
|
tokens = cls._exchange_code(
|
||||||
@@ -606,6 +1091,8 @@ class ExternalAuthService:
|
|||||||
if m.method_type in external_providers or str(m.method_type) in [p.value for p in external_providers]
|
if m.method_type in external_providers or str(m.method_type) in [p.value for p in external_providers]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# ==================== Helper Methods ====================
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _compute_s256_challenge(verifier: str) -> str:
|
def _compute_s256_challenge(verifier: str) -> str:
|
||||||
"""Compute S256 code challenge from verifier."""
|
"""Compute S256 code challenge from verifier."""
|
||||||
@@ -616,8 +1103,8 @@ class ExternalAuthService:
|
|||||||
return base64.urlsafe_b64encode(digest).decode().rstrip("=")
|
return base64.urlsafe_b64encode(digest).decode().rstrip("=")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_authorization_url(config: ExternalProviderConfig, state: OAuthState) -> str:
|
def _build_authorization_url(config: ProviderConfigAdapter, state: OAuthState) -> str:
|
||||||
"""Build authorization URL (simplified - provider-specific in production)."""
|
"""Build authorization URL using the provider config adapter."""
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
@@ -637,11 +1124,22 @@ class ExternalAuthService:
|
|||||||
params["code_challenge"] = state.code_challenge
|
params["code_challenge"] = state.code_challenge
|
||||||
params["code_challenge_method"] = "S256"
|
params["code_challenge_method"] = "S256"
|
||||||
|
|
||||||
return f"{config.auth_url}?{urlencode(params)}"
|
full_url = f"{config.auth_url}?{urlencode(params)}"
|
||||||
|
|
||||||
|
# DIAGNOSTIC LOGGING: Show exact URL being built
|
||||||
|
logger.info(
|
||||||
|
f"[PKCE DEBUG] Building authorization URL:\n"
|
||||||
|
f" provider_type: {config.provider_type}\n"
|
||||||
|
f" state.code_challenge: {state.code_challenge[:20] if state.code_challenge else 'None'}...\n"
|
||||||
|
f" params has code_challenge: {'code_challenge' in params}\n"
|
||||||
|
f" Full URL: {full_url}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return full_url
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _exchange_code(config: ExternalProviderConfig, code: str, redirect_uri: str, code_verifier: str = None) -> dict:
|
def _exchange_code(config: ProviderConfigAdapter, code: str, redirect_uri: str, code_verifier: str = None) -> dict:
|
||||||
"""Exchange authorization code for tokens (simplified - provider-specific in production)."""
|
"""Exchange authorization code for tokens using the provider config adapter."""
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
@@ -655,14 +1153,29 @@ class ExternalAuthService:
|
|||||||
if code_verifier:
|
if code_verifier:
|
||||||
data["code_verifier"] = code_verifier
|
data["code_verifier"] = code_verifier
|
||||||
|
|
||||||
|
# Log token exchange request (without secrets)
|
||||||
|
logger.debug(
|
||||||
|
f"Token exchange request: url={config.token_url}, "
|
||||||
|
f"client_id={config.client_id}, redirect_uri={redirect_uri}, "
|
||||||
|
f"has_code_verifier={bool(code_verifier)}"
|
||||||
|
)
|
||||||
|
|
||||||
response = requests.post(config.token_url, data=data)
|
response = requests.post(config.token_url, data=data)
|
||||||
|
|
||||||
|
# Log response details for debugging
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(
|
||||||
|
f"Token exchange failed: status={response.status_code}, "
|
||||||
|
f"response={response.text}"
|
||||||
|
)
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_user_info(config: ExternalProviderConfig, access_token: str) -> dict:
|
def _get_user_info(config: ProviderConfigAdapter, access_token: str) -> dict:
|
||||||
"""Get user info from provider (simplified - provider-specific in production)."""
|
"""Get user info from provider using the provider config adapter."""
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
headers = {"Authorization": f"Bearer {access_token}"}
|
headers = {"Authorization": f"Bearer {access_token}"}
|
||||||
@@ -758,4 +1271,4 @@ class ExternalAuthService:
|
|||||||
else:
|
else:
|
||||||
result["id_token"] = None
|
result["id_token"] = None
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -1,20 +1,22 @@
|
|||||||
"""OAuth flow service for handling external authentication flows."""
|
"""OAuth flow service for handling external authentication flows."""
|
||||||
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
from flask import current_app, request, g
|
from flask import current_app, request, g, redirect
|
||||||
|
|
||||||
from gatehouse_app.extensions import db
|
from gatehouse_app.extensions import db
|
||||||
from gatehouse_app.models import User, AuthenticationMethod
|
from gatehouse_app.models import User, AuthenticationMethod
|
||||||
|
from gatehouse_app.models.authentication_method import OAuthState
|
||||||
from gatehouse_app.models.base import BaseModel
|
from gatehouse_app.models.base import BaseModel
|
||||||
|
from gatehouse_app.models.oidc_authorization_code import OIDCAuthCode
|
||||||
from gatehouse_app.utils.constants import AuthMethodType
|
from gatehouse_app.utils.constants import AuthMethodType
|
||||||
from gatehouse_app.services.audit_service import AuditService
|
from gatehouse_app.services.audit_service import AuditService
|
||||||
from gatehouse_app.services.external_auth_service import (
|
from gatehouse_app.services.external_auth_service import (
|
||||||
ExternalAuthService,
|
ExternalAuthService,
|
||||||
ExternalAuthError,
|
ExternalAuthError,
|
||||||
OAuthState,
|
|
||||||
ExternalProviderConfig,
|
ExternalProviderConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -43,11 +45,14 @@ class OAuthFlowService:
|
|||||||
state_data: dict = None,
|
state_data: dict = None,
|
||||||
) -> Tuple[str, str]:
|
) -> Tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Initiate OAuth login flow.
|
Initiate OAuth login flow without requiring organization_id upfront.
|
||||||
|
|
||||||
|
This method initiates the OAuth flow using application-wide provider configuration.
|
||||||
|
The organization context is determined after successful authentication.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
provider_type: The authentication provider type
|
provider_type: The authentication provider type
|
||||||
organization_id: Optional organization context for SSO
|
organization_id: Optional organization hint for SSO discovery
|
||||||
redirect_uri: Optional custom redirect URI
|
redirect_uri: Optional custom redirect URI
|
||||||
state_data: Additional state data to include
|
state_data: Additional state data to include
|
||||||
|
|
||||||
@@ -65,8 +70,8 @@ class OAuthFlowService:
|
|||||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get provider config
|
# Get provider config (application-wide, no organization required)
|
||||||
config = ExternalAuthService.get_provider_config(organization_id, provider_type)
|
config = ExternalAuthService.get_provider_config(provider_type, organization_id)
|
||||||
|
|
||||||
# Validate redirect URI
|
# Validate redirect URI
|
||||||
if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri):
|
if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri):
|
||||||
@@ -76,9 +81,19 @@ class OAuthFlowService:
|
|||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate PKCE
|
# Generate PKCE parameters (Google web applications don't use PKCE)
|
||||||
code_verifier = secrets.token_urlsafe(32)
|
code_verifier = None
|
||||||
code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier)
|
code_challenge = None
|
||||||
|
if provider_type_str not in ['google']:
|
||||||
|
code_verifier = secrets.token_urlsafe(32)
|
||||||
|
code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier)
|
||||||
|
|
||||||
|
# DIAGNOSTIC LOGGING: Show PKCE decision
|
||||||
|
logger.info(
|
||||||
|
f"[PKCE DEBUG] Provider type check: provider_type_str='{provider_type_str}', "
|
||||||
|
f"is_google={provider_type_str in ['google']}, "
|
||||||
|
f"will_skip_pkce={provider_type_str in ['google']}"
|
||||||
|
)
|
||||||
|
|
||||||
# Create OAuth state for login flow
|
# Create OAuth state for login flow
|
||||||
state = OAuthState.create_state(
|
state = OAuthState.create_state(
|
||||||
@@ -92,6 +107,15 @@ class OAuthFlowService:
|
|||||||
lifetime_seconds=600,
|
lifetime_seconds=600,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# DIAGNOSTIC LOGGING: Verify state object
|
||||||
|
logger.info(
|
||||||
|
f"[PKCE DEBUG] Created OAuthState object:\n"
|
||||||
|
f" state.id: {state.id}\n"
|
||||||
|
f" state.provider_type: {state.provider_type}\n"
|
||||||
|
f" state.code_challenge: {state.code_challenge}\n"
|
||||||
|
f" state.code_verifier: {state.code_verifier[:20] if state.code_verifier else None}..."
|
||||||
|
)
|
||||||
|
|
||||||
# Build authorization URL
|
# Build authorization URL
|
||||||
auth_url = ExternalAuthService._build_authorization_url(
|
auth_url = ExternalAuthService._build_authorization_url(
|
||||||
config=config,
|
config=config,
|
||||||
@@ -100,7 +124,13 @@ class OAuthFlowService:
|
|||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"OAuth login flow initiated for provider={provider_type_str}, "
|
f"OAuth login flow initiated for provider={provider_type_str}, "
|
||||||
f"org_id={organization_id}, state_id={state.id}"
|
f"org_id={organization_id}, state_token={state.state}, state_record_id={state.id}"
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[PKCE DEBUG] FINAL CHECK: code_challenge={code_challenge}, "
|
||||||
|
f"code_verifier={code_verifier[:20] if code_verifier else None}..., "
|
||||||
|
f"auth_url_has_challenge={'code_challenge=' in auth_url}, "
|
||||||
|
f"returned_auth_url={auth_url}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return auth_url, state.state
|
return auth_url, state.state
|
||||||
@@ -129,11 +159,11 @@ class OAuthFlowService:
|
|||||||
redirect_uri: str = None,
|
redirect_uri: str = None,
|
||||||
) -> Tuple[str, str]:
|
) -> Tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Initiate OAuth registration flow.
|
Initiate OAuth registration flow without requiring organization_id upfront.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
provider_type: The authentication provider type
|
provider_type: The authentication provider type
|
||||||
organization_id: Optional organization context
|
organization_id: Optional organization hint
|
||||||
redirect_uri: Optional custom redirect URI
|
redirect_uri: Optional custom redirect URI
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -142,8 +172,8 @@ class OAuthFlowService:
|
|||||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get provider config
|
# Get provider config (application-wide, no organization required)
|
||||||
config = ExternalAuthService.get_provider_config(organization_id, provider_type)
|
config = ExternalAuthService.get_provider_config(provider_type, organization_id)
|
||||||
|
|
||||||
# Validate redirect URI
|
# Validate redirect URI
|
||||||
if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri):
|
if redirect_uri and not config.is_redirect_uri_allowed(redirect_uri):
|
||||||
@@ -153,9 +183,19 @@ class OAuthFlowService:
|
|||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate PKCE
|
# Generate PKCE parameters (Google web applications don't use PKCE)
|
||||||
code_verifier = secrets.token_urlsafe(32)
|
code_verifier = None
|
||||||
code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier)
|
code_challenge = None
|
||||||
|
if provider_type_str not in ['google']:
|
||||||
|
code_verifier = secrets.token_urlsafe(32)
|
||||||
|
code_challenge = ExternalAuthService._compute_s256_challenge(code_verifier)
|
||||||
|
|
||||||
|
# DIAGNOSTIC LOGGING: Show PKCE decision for register flow
|
||||||
|
logger.info(
|
||||||
|
f"[PKCE DEBUG] Register flow - Provider type check: provider_type_str='{provider_type_str}', "
|
||||||
|
f"is_google={provider_type_str in ['google']}, "
|
||||||
|
f"will_skip_pkce={provider_type_str in ['google']}"
|
||||||
|
)
|
||||||
|
|
||||||
# Create OAuth state for register flow
|
# Create OAuth state for register flow
|
||||||
state = OAuthState.create_state(
|
state = OAuthState.create_state(
|
||||||
@@ -168,6 +208,14 @@ class OAuthFlowService:
|
|||||||
lifetime_seconds=600,
|
lifetime_seconds=600,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# DIAGNOSTIC LOGGING: Verify state object for register flow
|
||||||
|
logger.info(
|
||||||
|
f"[PKCE DEBUG] Register flow - Created OAuthState:\n"
|
||||||
|
f" state.id: {state.id}\n"
|
||||||
|
f" state.code_challenge: {state.code_challenge}\n"
|
||||||
|
f" state.code_verifier: {state.code_verifier[:20] if state.code_verifier else None}..."
|
||||||
|
)
|
||||||
|
|
||||||
# Build authorization URL
|
# Build authorization URL
|
||||||
auth_url = ExternalAuthService._build_authorization_url(
|
auth_url = ExternalAuthService._build_authorization_url(
|
||||||
config=config,
|
config=config,
|
||||||
@@ -178,6 +226,9 @@ class OAuthFlowService:
|
|||||||
f"OAuth register flow initiated for provider={provider_type_str}, "
|
f"OAuth register flow initiated for provider={provider_type_str}, "
|
||||||
f"org_id={organization_id}, state_id={state.id}"
|
f"org_id={organization_id}, state_id={state.id}"
|
||||||
)
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[PKCE DEBUG] Register flow - FINAL: auth_url_has_challenge={'code_challenge=' in auth_url}"
|
||||||
|
)
|
||||||
|
|
||||||
return auth_url, state.state
|
return auth_url, state.state
|
||||||
|
|
||||||
@@ -245,6 +296,17 @@ class OAuthFlowService:
|
|||||||
|
|
||||||
# Validate state
|
# Validate state
|
||||||
state_record = OAuthState.query.filter_by(state=state).first()
|
state_record = OAuthState.query.filter_by(state=state).first()
|
||||||
|
|
||||||
|
# Log validation details for debugging
|
||||||
|
if state_record:
|
||||||
|
logger.debug(
|
||||||
|
f"State validation: found=True, used={state_record.used}, "
|
||||||
|
f"expires_at={state_record.expires_at}, now={datetime.now(timezone.utc)}, "
|
||||||
|
f"is_valid={state_record.is_valid()}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(f"State validation: state token not found in database: {state}")
|
||||||
|
|
||||||
if not state_record or not state_record.is_valid():
|
if not state_record or not state_record.is_valid():
|
||||||
AuditService.log_external_auth_login_failed(
|
AuditService.log_external_auth_login_failed(
|
||||||
organization_id=state_record.organization_id if state_record else None,
|
organization_id=state_record.organization_id if state_record else None,
|
||||||
@@ -299,24 +361,175 @@ class OAuthFlowService:
|
|||||||
ip_address: str = None,
|
ip_address: str = None,
|
||||||
user_agent: str = None,
|
user_agent: str = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Handle login flow callback."""
|
"""
|
||||||
|
Handle login flow callback with organization discovery.
|
||||||
|
|
||||||
|
This method:
|
||||||
|
1. Exchanges the authorization code for tokens
|
||||||
|
2. Gets user info from the OAuth provider
|
||||||
|
3. Looks up the user by provider_user_id
|
||||||
|
4. Determines which organization(s) the user belongs to
|
||||||
|
5. Creates a session or returns org selection needed
|
||||||
|
"""
|
||||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Authenticate with provider
|
# Get provider config (application-wide)
|
||||||
user, session_data = ExternalAuthService.authenticate_with_provider(
|
config = ExternalAuthService.get_provider_config(
|
||||||
provider_type=provider_type,
|
provider_type, state_record.organization_id
|
||||||
organization_id=state_record.organization_id,
|
)
|
||||||
authorization_code=authorization_code,
|
|
||||||
state=state_record.state,
|
logger.debug(
|
||||||
|
f"Exchanging code with PKCE: state_record.code_verifier={state_record.code_verifier[:20] if state_record.code_verifier else None}..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Exchange code for tokens
|
||||||
|
tokens = ExternalAuthService._exchange_code(
|
||||||
|
config=config,
|
||||||
|
code=authorization_code,
|
||||||
redirect_uri=redirect_uri,
|
redirect_uri=redirect_uri,
|
||||||
|
code_verifier=state_record.code_verifier,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get user info from provider
|
||||||
|
user_info = ExternalAuthService._get_user_info(
|
||||||
|
config=config,
|
||||||
|
access_token=tokens["access_token"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Look up user by provider_user_id
|
||||||
|
auth_method = AuthenticationMethod.query.filter_by(
|
||||||
|
method_type=provider_type,
|
||||||
|
provider_user_id=user_info["provider_user_id"],
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not auth_method:
|
||||||
|
# User doesn't exist - check if email matches existing user
|
||||||
|
existing_user = User.query.filter_by(
|
||||||
|
email=user_info["email"]
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if existing_user:
|
||||||
|
AuditService.log_external_auth_login_failed(
|
||||||
|
organization_id=state_record.organization_id,
|
||||||
|
provider_type=provider_type_str,
|
||||||
|
provider_user_id=user_info["provider_user_id"],
|
||||||
|
email=user_info["email"],
|
||||||
|
failure_reason="email_exists",
|
||||||
|
error_message=f"An account with email {user_info['email']} already exists",
|
||||||
|
)
|
||||||
|
raise OAuthFlowError(
|
||||||
|
f"An account with email {user_info['email']} already exists. "
|
||||||
|
"Please log in with your password and link your account from settings.",
|
||||||
|
"EMAIL_EXISTS",
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
|
||||||
|
AuditService.log_external_auth_login_failed(
|
||||||
|
organization_id=state_record.organization_id,
|
||||||
|
provider_type=provider_type_str,
|
||||||
|
provider_user_id=user_info["provider_user_id"],
|
||||||
|
email=user_info["email"],
|
||||||
|
failure_reason="account_not_found",
|
||||||
|
error_message="No Gatehouse account matches this external account",
|
||||||
|
)
|
||||||
|
raise OAuthFlowError(
|
||||||
|
"No Gatehouse account matches this external account. Please register first.",
|
||||||
|
"ACCOUNT_NOT_FOUND",
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
|
||||||
|
user = auth_method.user
|
||||||
|
|
||||||
|
# Update provider data
|
||||||
|
auth_method.provider_data = ExternalAuthService._encrypt_provider_data(
|
||||||
|
tokens, user_info
|
||||||
|
)
|
||||||
|
auth_method.last_used_at = datetime.utcnow()
|
||||||
|
auth_method.save()
|
||||||
|
|
||||||
|
# Get user's organizations
|
||||||
|
user_orgs = user.get_organizations()
|
||||||
|
|
||||||
|
# Determine target organization
|
||||||
|
target_org = None
|
||||||
|
|
||||||
|
# Priority 1: Use organization_id from state if provided (org hint)
|
||||||
|
if state_record.organization_id:
|
||||||
|
target_org = next(
|
||||||
|
(org for org in user_orgs if org.id == state_record.organization_id),
|
||||||
|
None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Priority 2: If user has exactly one organization, use it
|
||||||
|
if not target_org and len(user_orgs) == 1:
|
||||||
|
target_org = user_orgs[0]
|
||||||
|
|
||||||
|
# Priority 3: No organization or multiple organizations - need selection
|
||||||
|
if not target_org:
|
||||||
|
# Mark state as used
|
||||||
|
state_record.mark_used()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"OAuth login requires org selection for user={user.id}, "
|
||||||
|
f"provider={provider_type_str}, org_count={len(user_orgs)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"flow_type": "login",
|
||||||
|
"requires_org_selection": True,
|
||||||
|
"user": {
|
||||||
|
"id": user.id,
|
||||||
|
"email": user.email,
|
||||||
|
"full_name": user.full_name,
|
||||||
|
},
|
||||||
|
"available_organizations": [
|
||||||
|
{
|
||||||
|
"id": org.id,
|
||||||
|
"name": org.name,
|
||||||
|
"slug": org.slug if hasattr(org, 'slug') else None,
|
||||||
|
}
|
||||||
|
for org in user_orgs
|
||||||
|
],
|
||||||
|
"state": state_record.state,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create session for the target org
|
||||||
|
from gatehouse_app.services.auth_service import AuthService
|
||||||
|
session = AuthService.create_session(
|
||||||
|
user=user,
|
||||||
|
is_compliance_only=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mark state as used
|
||||||
|
state_record.mark_used()
|
||||||
|
|
||||||
|
# Audit log - login success
|
||||||
|
AuditService.log_external_auth_login(
|
||||||
|
user_id=user.id,
|
||||||
|
organization_id=target_org.id,
|
||||||
|
provider_type=provider_type_str,
|
||||||
|
provider_user_id=user_info["provider_user_id"],
|
||||||
|
auth_method_id=auth_method.id,
|
||||||
|
session_id=session.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"OAuth login successful for user={user.id}, "
|
f"OAuth login successful for user={user.id}, "
|
||||||
f"provider={provider_type_str}, org_id={state_record.organization_id}"
|
f"provider={provider_type_str}, org_id={target_org.id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Build session dict with token (to_dict() excludes token for security)
|
||||||
|
session_dict = session.to_dict()
|
||||||
|
session_dict["token"] = session.token
|
||||||
|
# Calculate expires_in handling naive datetime from database
|
||||||
|
expires_at = session.expires_at
|
||||||
|
if expires_at.tzinfo is None:
|
||||||
|
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
session_dict["expires_in"] = int((expires_at - now).total_seconds())
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"flow_type": "login",
|
"flow_type": "login",
|
||||||
@@ -324,9 +537,9 @@ class OAuthFlowService:
|
|||||||
"id": user.id,
|
"id": user.id,
|
||||||
"email": user.email,
|
"email": user.email,
|
||||||
"full_name": user.full_name,
|
"full_name": user.full_name,
|
||||||
"organization_id": state_record.organization_id,
|
"organization_id": target_org.id,
|
||||||
},
|
},
|
||||||
"session": session_data,
|
"session": session_dict,
|
||||||
}
|
}
|
||||||
|
|
||||||
except ExternalAuthError as e:
|
except ExternalAuthError as e:
|
||||||
@@ -335,6 +548,19 @@ class OAuthFlowService:
|
|||||||
f"provider={provider_type_str}, error={e.message}"
|
f"provider={provider_type_str}, error={e.message}"
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
except OAuthFlowError:
|
||||||
|
# Re-raise OAuthFlowError as-is
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Unexpected error in OAuth login callback: {str(e)}",
|
||||||
|
exc_info=True
|
||||||
|
)
|
||||||
|
raise OAuthFlowError(
|
||||||
|
"An unexpected error occurred during login",
|
||||||
|
"INTERNAL_ERROR",
|
||||||
|
500,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _handle_link_callback(
|
def _handle_link_callback(
|
||||||
@@ -387,13 +613,17 @@ class OAuthFlowService:
|
|||||||
authorization_code: str,
|
authorization_code: str,
|
||||||
redirect_uri: str,
|
redirect_uri: str,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Handle registration flow callback."""
|
"""
|
||||||
|
Handle registration flow callback.
|
||||||
|
|
||||||
|
Creates a new user account and prompts for organization creation/selection.
|
||||||
|
"""
|
||||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get provider config
|
# Get provider config (application-wide)
|
||||||
config = ExternalAuthService.get_provider_config(
|
config = ExternalAuthService.get_provider_config(
|
||||||
state_record.organization_id, provider_type
|
provider_type, state_record.organization_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Exchange code for tokens
|
# Exchange code for tokens
|
||||||
@@ -429,6 +659,7 @@ class OAuthFlowService:
|
|||||||
email=user_info["email"],
|
email=user_info["email"],
|
||||||
full_name=user_info.get("name", ""),
|
full_name=user_info.get("name", ""),
|
||||||
status="active",
|
status="active",
|
||||||
|
email_verified=user_info.get("email_verified", False),
|
||||||
)
|
)
|
||||||
user.save()
|
user.save()
|
||||||
|
|
||||||
@@ -440,6 +671,7 @@ class OAuthFlowService:
|
|||||||
provider_data=ExternalAuthService._encrypt_provider_data(tokens, user_info),
|
provider_data=ExternalAuthService._encrypt_provider_data(tokens, user_info),
|
||||||
verified=user_info.get("email_verified", False),
|
verified=user_info.get("email_verified", False),
|
||||||
is_primary=True,
|
is_primary=True,
|
||||||
|
last_used_at=datetime.utcnow(),
|
||||||
)
|
)
|
||||||
auth_method.save()
|
auth_method.save()
|
||||||
|
|
||||||
@@ -475,23 +707,48 @@ class OAuthFlowService:
|
|||||||
f"provider={provider_type_str}, user_id={user.id}"
|
f"provider={provider_type_str}, user_id={user.id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create session
|
# If organization_id hint was provided and valid, create session for that org
|
||||||
from gatehouse_app.services.auth_service import AuthService
|
if state_record.organization_id:
|
||||||
session = AuthService.create_session(
|
from gatehouse_app.models.organization import Organization
|
||||||
user=user,
|
org = Organization.query.get(state_record.organization_id)
|
||||||
organization_id=state_record.organization_id,
|
if org:
|
||||||
)
|
from gatehouse_app.services.auth_service import AuthService
|
||||||
|
session = AuthService.create_session(
|
||||||
|
user=user,
|
||||||
|
is_compliance_only=False,
|
||||||
|
)
|
||||||
|
# Build session dict with token (to_dict() excludes token for security)
|
||||||
|
session_dict = session.to_dict()
|
||||||
|
session_dict["token"] = session.token
|
||||||
|
# Calculate expires_in handling naive datetime from database
|
||||||
|
expires_at = session.expires_at
|
||||||
|
if expires_at.tzinfo is None:
|
||||||
|
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
session_dict["expires_in"] = int((expires_at - now).total_seconds())
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"flow_type": "register",
|
||||||
|
"user": {
|
||||||
|
"id": user.id,
|
||||||
|
"email": user.email,
|
||||||
|
"full_name": user.full_name,
|
||||||
|
"organization_id": org.id,
|
||||||
|
},
|
||||||
|
"session": session_dict,
|
||||||
|
}
|
||||||
|
|
||||||
|
# No organization hint or invalid - need to create/select org
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"flow_type": "register",
|
"flow_type": "register",
|
||||||
|
"requires_org_creation": True,
|
||||||
"user": {
|
"user": {
|
||||||
"id": user.id,
|
"id": user.id,
|
||||||
"email": user.email,
|
"email": user.email,
|
||||||
"full_name": user.full_name,
|
"full_name": user.full_name,
|
||||||
"organization_id": state_record.organization_id,
|
|
||||||
},
|
},
|
||||||
"session": session.to_dict(),
|
"state": state_record.state,
|
||||||
}
|
}
|
||||||
|
|
||||||
except ExternalAuthError as e:
|
except ExternalAuthError as e:
|
||||||
@@ -500,6 +757,19 @@ class OAuthFlowService:
|
|||||||
f"provider={provider_type_str}, error={e.message}"
|
f"provider={provider_type_str}, error={e.message}"
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
except OAuthFlowError:
|
||||||
|
# Re-raise OAuthFlowError as-is
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Unexpected error in OAuth registration callback: {str(e)}",
|
||||||
|
exc_info=True
|
||||||
|
)
|
||||||
|
raise OAuthFlowError(
|
||||||
|
"An unexpected error occurred during registration",
|
||||||
|
"INTERNAL_ERROR",
|
||||||
|
500,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_state(cls, state: str) -> Optional[OAuthState]:
|
def validate_state(cls, state: str) -> Optional[OAuthState]:
|
||||||
@@ -522,3 +792,232 @@ class OAuthFlowService:
|
|||||||
"""Remove expired OAuth states."""
|
"""Remove expired OAuth states."""
|
||||||
OAuthState.cleanup_expired()
|
OAuthState.cleanup_expired()
|
||||||
logger.info("Expired OAuth states cleaned up")
|
logger.info("Expired OAuth states cleaned up")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate_authorization_code(
|
||||||
|
cls,
|
||||||
|
user_id: str,
|
||||||
|
client_id: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
scope: list = None,
|
||||||
|
nonce: str = None,
|
||||||
|
ip_address: str = None,
|
||||||
|
user_agent: str = None,
|
||||||
|
lifetime_seconds: int = 600,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Generate an authorization code for external OAuth applications.
|
||||||
|
|
||||||
|
This method creates a short-lived, single-use authorization code that can be
|
||||||
|
exchanged for a session token by external applications like oauth2-proxy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
client_id: The client ID (e.g., 'oauth2-proxy', 'bookstack')
|
||||||
|
redirect_uri: The redirect URI
|
||||||
|
scope: Requested scopes
|
||||||
|
nonce: OIDC nonce for validation
|
||||||
|
ip_address: Client IP address
|
||||||
|
user_agent: Client user agent
|
||||||
|
lifetime_seconds: Code lifetime in seconds (default 10 minutes)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The authorization code (plain text, not hashed)
|
||||||
|
"""
|
||||||
|
# Generate a secure random code
|
||||||
|
code = secrets.token_urlsafe(32)
|
||||||
|
code_hash = hashlib.sha256(code.encode()).hexdigest()
|
||||||
|
|
||||||
|
# Create the authorization code record
|
||||||
|
OIDCAuthCode.create_code(
|
||||||
|
client_id=client_id,
|
||||||
|
user_id=user_id,
|
||||||
|
code_hash=code_hash,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
scope=scope,
|
||||||
|
nonce=nonce,
|
||||||
|
ip_address=ip_address,
|
||||||
|
user_agent=user_agent,
|
||||||
|
lifetime_seconds=lifetime_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Generated authorization code for user={user_id}, client={client_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return code
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def exchange_authorization_code(
|
||||||
|
cls,
|
||||||
|
code: str,
|
||||||
|
client_id: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
ip_address: str = None,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Exchange an authorization code for a session token.
|
||||||
|
|
||||||
|
This method validates and consumes the authorization code, then creates
|
||||||
|
a session for the user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: The authorization code
|
||||||
|
client_id: The client ID
|
||||||
|
redirect_uri: The redirect URI (must match original request)
|
||||||
|
ip_address: Client IP address
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with session token and user info
|
||||||
|
"""
|
||||||
|
# Hash the provided code for lookup
|
||||||
|
code_hash = hashlib.sha256(code.encode()).hexdigest()
|
||||||
|
|
||||||
|
# Find the authorization code record
|
||||||
|
auth_code = OIDCAuthCode.query.filter_by(
|
||||||
|
client_id=client_id,
|
||||||
|
code_hash=code_hash,
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not auth_code:
|
||||||
|
raise OAuthFlowError(
|
||||||
|
"Invalid authorization code",
|
||||||
|
"INVALID_CODE",
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate the code
|
||||||
|
if not auth_code.is_valid():
|
||||||
|
if auth_code.is_used:
|
||||||
|
raise OAuthFlowError(
|
||||||
|
"Authorization code has already been used",
|
||||||
|
"CODE_USED",
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise OAuthFlowError(
|
||||||
|
"Authorization code has expired",
|
||||||
|
"CODE_EXPIRED",
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate redirect URI
|
||||||
|
if auth_code.redirect_uri != redirect_uri:
|
||||||
|
raise OAuthFlowError(
|
||||||
|
"Redirect URI mismatch",
|
||||||
|
"INVALID_REDIRECT_URI",
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the user
|
||||||
|
from gatehouse_app.models import User
|
||||||
|
user = User.query.get(auth_code.user_id)
|
||||||
|
if not user:
|
||||||
|
raise OAuthFlowError(
|
||||||
|
"User not found",
|
||||||
|
"USER_NOT_FOUND",
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine organization
|
||||||
|
from gatehouse_app.models.organization import Organization
|
||||||
|
from gatehouse_app.models.organization_member import OrganizationMember
|
||||||
|
|
||||||
|
# Get user's organizations
|
||||||
|
user_orgs = user.get_organizations()
|
||||||
|
|
||||||
|
# Determine target organization
|
||||||
|
target_org = None
|
||||||
|
|
||||||
|
# Priority 1: Use organization_id from auth code if available
|
||||||
|
# Priority 2: If user has exactly one organization, use it
|
||||||
|
if not target_org and len(user_orgs) == 1:
|
||||||
|
target_org = user_orgs[0]
|
||||||
|
|
||||||
|
if not target_org:
|
||||||
|
raise OAuthFlowError(
|
||||||
|
"User does not have a default organization. Organization selection required.",
|
||||||
|
"ORG_SELECTION_REQUIRED",
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create session
|
||||||
|
from gatehouse_app.services.auth_service import AuthService
|
||||||
|
session = AuthService.create_session(
|
||||||
|
user=user,
|
||||||
|
is_compliance_only=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mark the code as used
|
||||||
|
auth_code.mark_as_used()
|
||||||
|
|
||||||
|
# Build session dict
|
||||||
|
session_dict = session.to_dict()
|
||||||
|
session_dict["token"] = session.token
|
||||||
|
expires_at = session.expires_at
|
||||||
|
if expires_at.tzinfo is None:
|
||||||
|
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
session_dict["expires_in"] = int((expires_at - now).total_seconds())
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Authorization code exchanged for session: user={user.id}, "
|
||||||
|
f"org_id={target_org.id}, client={client_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"token": session_dict["token"],
|
||||||
|
"expires_in": session_dict["expires_in"],
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"user": {
|
||||||
|
"id": user.id,
|
||||||
|
"email": user.email,
|
||||||
|
"full_name": user.full_name,
|
||||||
|
"organization_id": target_org.id,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_redirect_response(
|
||||||
|
cls,
|
||||||
|
redirect_uri: str,
|
||||||
|
authorization_code: str,
|
||||||
|
state: str = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a redirect response with authorization code.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redirect_uri: The redirect URI
|
||||||
|
authorization_code: The authorization code
|
||||||
|
state: Optional state parameter
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Flask redirect response
|
||||||
|
"""
|
||||||
|
from urllib.parse import urlencode, urlparse, urlunparse
|
||||||
|
|
||||||
|
# Parse the redirect URI
|
||||||
|
parsed = urlparse(redirect_uri)
|
||||||
|
|
||||||
|
# Build query parameters
|
||||||
|
params = {"code": authorization_code}
|
||||||
|
if state:
|
||||||
|
params["state"] = state
|
||||||
|
|
||||||
|
# Reconstruct URL with query parameters
|
||||||
|
redirect_url = urlunparse((
|
||||||
|
parsed.scheme,
|
||||||
|
parsed.netloc,
|
||||||
|
parsed.path,
|
||||||
|
parsed.params,
|
||||||
|
urlencode(params),
|
||||||
|
parsed.fragment,
|
||||||
|
))
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Redirecting to {parsed.scheme}://{parsed.netloc} with authorization code"
|
||||||
|
)
|
||||||
|
|
||||||
|
return redirect(redirect_url)
|
||||||
|
|||||||
@@ -0,0 +1,340 @@
|
|||||||
|
# Gatehouse Scripts
|
||||||
|
|
||||||
|
This directory contains utility scripts for managing and configuring Gatehouse.
|
||||||
|
|
||||||
|
## OAuth Provider Configuration Script
|
||||||
|
|
||||||
|
The [`configure_oauth_provider.py`](configure_oauth_provider.py:1) script allows administrators to easily configure OAuth providers at the application level.
|
||||||
|
|
||||||
|
### Overview
|
||||||
|
|
||||||
|
This script manages application-wide OAuth provider configurations using the new [`ApplicationProviderConfig`](../gatehouse_app/models/authentication_method.py:99) architecture. Unlike the deprecated organization-specific configuration, this allows users to authenticate with OAuth providers without needing to specify an organization first.
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
- Python 3.8+
|
||||||
|
- Virtual environment with dependencies installed
|
||||||
|
- Flask app must be properly configured (`.env` or environment variables)
|
||||||
|
|
||||||
|
### Quick Start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Activate virtual environment
|
||||||
|
cd gatehouse-api
|
||||||
|
source .venv/bin/activate
|
||||||
|
|
||||||
|
# Create Google OAuth configuration
|
||||||
|
python scripts/configure_oauth_provider.py create google \
|
||||||
|
--client-id "YOUR_CLIENT_ID" \
|
||||||
|
--client-secret "YOUR_CLIENT_SECRET" \
|
||||||
|
--redirect-url "http://localhost:5173/auth/callback"
|
||||||
|
|
||||||
|
# List all configured providers
|
||||||
|
python scripts/configure_oauth_provider.py list
|
||||||
|
|
||||||
|
# Show provider details
|
||||||
|
python scripts/configure_oauth_provider.py show google
|
||||||
|
```
|
||||||
|
|
||||||
|
### Commands
|
||||||
|
|
||||||
|
#### `create` - Create a New Provider
|
||||||
|
|
||||||
|
Create a new OAuth provider configuration at the application level.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/configure_oauth_provider.py create PROVIDER [OPTIONS]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Arguments:**
|
||||||
|
- `PROVIDER`: Provider type (google, github, microsoft)
|
||||||
|
|
||||||
|
**Options:**
|
||||||
|
- `--client-id TEXT`: OAuth client ID (required, or via environment)
|
||||||
|
- `--client-secret TEXT`: OAuth client secret (required, or via environment)
|
||||||
|
- `--redirect-url TEXT`: Default redirect URL for callbacks
|
||||||
|
- `--disabled`: Create provider in disabled state
|
||||||
|
- `--settings KEY=VALUE`: Custom settings (can be specified multiple times)
|
||||||
|
|
||||||
|
**Examples:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Basic Google configuration
|
||||||
|
python scripts/configure_oauth_provider.py create google \
|
||||||
|
--client-id "xxx.apps.googleusercontent.com" \
|
||||||
|
--client-secret "GOCSPX-xxx"
|
||||||
|
|
||||||
|
# With redirect URL
|
||||||
|
python scripts/configure_oauth_provider.py create google \
|
||||||
|
--client-id "xxx" \
|
||||||
|
--client-secret "yyy" \
|
||||||
|
--redirect-url "https://app.example.com/auth/callback"
|
||||||
|
|
||||||
|
# Create disabled initially
|
||||||
|
python scripts/configure_oauth_provider.py create github \
|
||||||
|
--client-id "xxx" \
|
||||||
|
--client-secret "yyy" \
|
||||||
|
--disabled
|
||||||
|
|
||||||
|
# With custom settings
|
||||||
|
python scripts/configure_oauth_provider.py create google \
|
||||||
|
--client-id "xxx" \
|
||||||
|
--client-secret "yyy" \
|
||||||
|
--settings "hosted_domain=example.com" \
|
||||||
|
--settings "prompt=consent"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### `update` - Update Existing Provider
|
||||||
|
|
||||||
|
Update an existing OAuth provider configuration.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/configure_oauth_provider.py update PROVIDER [OPTIONS]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Arguments:**
|
||||||
|
- `PROVIDER`: Provider type to update
|
||||||
|
|
||||||
|
**Options:**
|
||||||
|
- `--client-id TEXT`: New OAuth client ID
|
||||||
|
- `--client-secret TEXT`: New OAuth client secret
|
||||||
|
- `--redirect-url TEXT`: New default redirect URL
|
||||||
|
- `--enabled true|false`: Enable or disable the provider
|
||||||
|
- `--settings KEY=VALUE`: Custom settings to update
|
||||||
|
|
||||||
|
**Examples:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Update client credentials
|
||||||
|
python scripts/configure_oauth_provider.py update google \
|
||||||
|
--client-id "new-client-id" \
|
||||||
|
--client-secret "new-secret"
|
||||||
|
|
||||||
|
# Enable/disable provider
|
||||||
|
python scripts/configure_oauth_provider.py update google --enabled false
|
||||||
|
python scripts/configure_oauth_provider.py update google --enabled true
|
||||||
|
|
||||||
|
# Update redirect URL
|
||||||
|
python scripts/configure_oauth_provider.py update google \
|
||||||
|
--redirect-url "https://new-domain.com/auth/callback"
|
||||||
|
```
|
||||||
|
|
||||||
|
#### `list` - List All Providers
|
||||||
|
|
||||||
|
List all configured OAuth providers with their status.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/configure_oauth_provider.py list
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example Output:**
|
||||||
|
```
|
||||||
|
Configured OAuth Providers
|
||||||
|
|
||||||
|
google - enabled
|
||||||
|
Client ID: 972920496362-xxx.apps.googleusercontent.com
|
||||||
|
Redirect URL: https://app.example.com/auth/callback
|
||||||
|
Created: 2026-01-20T13:00:00
|
||||||
|
Auth URL: https://accounts.google.com/o/oauth2/v2/auth
|
||||||
|
Scopes: openid, profile, email
|
||||||
|
|
||||||
|
github - disabled
|
||||||
|
Client ID: Iv1.xxx
|
||||||
|
Created: 2026-01-19T10:00:00
|
||||||
|
Auth URL: https://github.com/login/oauth/authorize
|
||||||
|
Scopes: read:user, user:email
|
||||||
|
```
|
||||||
|
|
||||||
|
#### `show` - Show Provider Details
|
||||||
|
|
||||||
|
Display detailed information about a specific OAuth provider.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/configure_oauth_provider.py show PROVIDER
|
||||||
|
```
|
||||||
|
|
||||||
|
**Arguments:**
|
||||||
|
- `PROVIDER`: Provider type to display
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/configure_oauth_provider.py show google
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example Output:**
|
||||||
|
```
|
||||||
|
Google OAuth Provider Details
|
||||||
|
|
||||||
|
Basic Information:
|
||||||
|
Provider Type: google
|
||||||
|
Provider ID: 123e4567-e89b-12d3-a456-426614174000
|
||||||
|
Client ID: 972920496362-xxx.apps.googleusercontent.com
|
||||||
|
Status: enabled
|
||||||
|
Default Redirect URL: https://app.example.com/auth/callback
|
||||||
|
|
||||||
|
Timestamps:
|
||||||
|
Created: 2026-01-20T13:00:00
|
||||||
|
Updated: 2026-01-20T14:30:00
|
||||||
|
|
||||||
|
OAuth Configuration:
|
||||||
|
Authorization URL: https://accounts.google.com/o/oauth2/v2/auth
|
||||||
|
Token URL: https://oauth2.googleapis.com/token
|
||||||
|
User Info URL: https://openidconnect.googleapis.com/v1/userinfo
|
||||||
|
JWKS URL: https://www.googleapis.com/oauth2/v3/certs
|
||||||
|
Scopes: openid, profile, email
|
||||||
|
```
|
||||||
|
|
||||||
|
#### `delete` - Delete Provider Configuration
|
||||||
|
|
||||||
|
Remove an OAuth provider configuration.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/configure_oauth_provider.py delete PROVIDER [OPTIONS]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Arguments:**
|
||||||
|
- `PROVIDER`: Provider type to delete
|
||||||
|
|
||||||
|
**Options:**
|
||||||
|
- `--yes`, `-y`: Skip confirmation prompt
|
||||||
|
|
||||||
|
**Examples:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Delete with confirmation prompt
|
||||||
|
python scripts/configure_oauth_provider.py delete google
|
||||||
|
|
||||||
|
# Delete without confirmation
|
||||||
|
python scripts/configure_oauth_provider.py delete google --yes
|
||||||
|
```
|
||||||
|
|
||||||
|
### Environment Variables
|
||||||
|
|
||||||
|
The script supports loading OAuth credentials from environment variables, which is useful for automation and CI/CD pipelines.
|
||||||
|
|
||||||
|
**Supported Variables:**
|
||||||
|
- `{PROVIDER}_CLIENT_ID`: OAuth client ID
|
||||||
|
- `{PROVIDER}_CLIENT_SECRET`: OAuth client secret
|
||||||
|
- `{PROVIDER}_REDIRECT_URL`: Default redirect URL
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Export environment variables
|
||||||
|
export GOOGLE_CLIENT_ID="xxx.apps.googleusercontent.com"
|
||||||
|
export GOOGLE_CLIENT_SECRET="GOCSPX-xxx"
|
||||||
|
export GOOGLE_REDIRECT_URL="https://app.example.com/auth/callback"
|
||||||
|
|
||||||
|
# Create provider using environment variables
|
||||||
|
python scripts/configure_oauth_provider.py create google
|
||||||
|
|
||||||
|
# You can still override with command-line arguments
|
||||||
|
python scripts/configure_oauth_provider.py create google \
|
||||||
|
--redirect-url "https://different.com/callback"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Supported Providers
|
||||||
|
|
||||||
|
The script comes with pre-configured endpoint information for:
|
||||||
|
|
||||||
|
- **Google** (`google`)
|
||||||
|
- Authorization: `https://accounts.google.com/o/oauth2/v2/auth`
|
||||||
|
- Token: `https://oauth2.googleapis.com/token`
|
||||||
|
- User Info: `https://openidconnect.googleapis.com/v1/userinfo`
|
||||||
|
- Default Scopes: `openid, profile, email`
|
||||||
|
|
||||||
|
- **GitHub** (`github`)
|
||||||
|
- Authorization: `https://github.com/login/oauth/authorize`
|
||||||
|
- Token: `https://github.com/login/oauth/access_token`
|
||||||
|
- User Info: `https://api.github.com/user`
|
||||||
|
- Default Scopes: `read:user, user:email`
|
||||||
|
|
||||||
|
- **Microsoft** (`microsoft`)
|
||||||
|
- Authorization: `https://login.microsoftonline.com/common/oauth2/v2.0/authorize`
|
||||||
|
- Token: `https://login.microsoftonline.com/common/oauth2/v2.0/token`
|
||||||
|
- User Info: `https://graph.microsoft.com/oidc/userinfo`
|
||||||
|
- Default Scopes: `openid, profile, email`
|
||||||
|
|
||||||
|
### Error Handling
|
||||||
|
|
||||||
|
The script provides clear error messages and appropriate exit codes:
|
||||||
|
|
||||||
|
- **Exit Code 0**: Success
|
||||||
|
- **Exit Code 1**: Error occurred
|
||||||
|
|
||||||
|
**Common Errors:**
|
||||||
|
|
||||||
|
1. **Provider Already Exists**
|
||||||
|
```
|
||||||
|
✗ Failed to create provider: Provider google already exists
|
||||||
|
ℹ Use 'update' command to modify existing provider configuration.
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Provider Not Found**
|
||||||
|
```
|
||||||
|
✗ Failed to update provider: Provider google not found
|
||||||
|
ℹ Use 'create' command to add a new provider configuration.
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Missing Credentials**
|
||||||
|
```
|
||||||
|
✗ Client ID is required. Provide via --client-id or GOOGLE_CLIENT_ID environment variable.
|
||||||
|
```
|
||||||
|
|
||||||
|
### Integration with Shell Scripts
|
||||||
|
|
||||||
|
The [`configure-google-auth.sh`](../../docs/configure-google-auth.sh:1) script demonstrates how to integrate the Python script into a shell script for easier deployment:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Set credentials
|
||||||
|
GOOGLE_CLIENT_ID="xxx"
|
||||||
|
GOOGLE_CLIENT_SECRET="yyy"
|
||||||
|
REDIRECT_URL="https://app.example.com/callback"
|
||||||
|
|
||||||
|
# Call Python script
|
||||||
|
cd gatehouse-api
|
||||||
|
python3 scripts/configure_oauth_provider.py create google \
|
||||||
|
--client-id "$GOOGLE_CLIENT_ID" \
|
||||||
|
--client-secret "$GOOGLE_CLIENT_SECRET" \
|
||||||
|
--redirect-url "$REDIRECT_URL"
|
||||||
|
```
|
||||||
|
|
||||||
|
### API Service Methods
|
||||||
|
|
||||||
|
The script uses the following [`ExternalAuthService`](../gatehouse_app/services/external_auth_service.py:1) methods:
|
||||||
|
|
||||||
|
- [`create_app_provider_config()`](../gatehouse_app/services/external_auth_service.py:308) - Create provider configuration
|
||||||
|
- [`update_app_provider_config()`](../gatehouse_app/services/external_auth_service.py:369) - Update provider configuration
|
||||||
|
- [`get_app_provider_config()`](../gatehouse_app/services/external_auth_service.py:427) - Get single provider
|
||||||
|
- [`list_app_provider_configs()`](../gatehouse_app/services/external_auth_service.py:454) - List all providers
|
||||||
|
- [`delete_app_provider_config()`](../gatehouse_app/services/external_auth_service.py:465) - Delete provider configuration
|
||||||
|
|
||||||
|
### Security Considerations
|
||||||
|
|
||||||
|
1. **Client Secret Storage**: Client secrets are encrypted using the application's encryption key before storage in the database
|
||||||
|
2. **Environment Variables**: Be cautious when using environment variables in shared environments
|
||||||
|
3. **Secret Exposure**: The `show` command never displays the client secret (it's always excluded)
|
||||||
|
4. **Confirmation Prompts**: The `delete` command requires confirmation unless `--yes` flag is used
|
||||||
|
|
||||||
|
### Troubleshooting
|
||||||
|
|
||||||
|
**Database Connection Issues:**
|
||||||
|
- Ensure PostgreSQL is running and accessible
|
||||||
|
- Check `.env` file for correct `DATABASE_URL`
|
||||||
|
- Verify virtual environment is activated
|
||||||
|
|
||||||
|
**Import Errors:**
|
||||||
|
- Activate the virtual environment: `source .venv/bin/activate`
|
||||||
|
- Install dependencies: `pip install -r requirements.txt`
|
||||||
|
|
||||||
|
**Permission Issues:**
|
||||||
|
- Ensure script is executable: `chmod +x scripts/configure_oauth_provider.py`
|
||||||
|
|
||||||
|
### Related Documentation
|
||||||
|
|
||||||
|
- [External Auth Architecture](../../docs/external-auth-architecture.md)
|
||||||
|
- [Application-Wide OAuth Design](../../docs/external-auth-application-wide-design.md)
|
||||||
|
- [OAuth API Changes](../../docs/oauth-api-changes.md)
|
||||||
Executable
+484
@@ -0,0 +1,484 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
OAuth Provider Configuration Script for Gatehouse
|
||||||
|
|
||||||
|
This script allows administrators to configure OAuth providers at the application level
|
||||||
|
using the new ApplicationProviderConfig architecture.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Create a new provider configuration
|
||||||
|
python scripts/configure_oauth_provider.py create google \\
|
||||||
|
--client-id "YOUR_CLIENT_ID" \\
|
||||||
|
--client-secret "YOUR_CLIENT_SECRET" \\
|
||||||
|
--redirect-url "http://localhost:5173/auth/callback"
|
||||||
|
|
||||||
|
# List all configured providers
|
||||||
|
python scripts/configure_oauth_provider.py list
|
||||||
|
|
||||||
|
# Show details of a specific provider
|
||||||
|
python scripts/configure_oauth_provider.py show google
|
||||||
|
|
||||||
|
# Update a provider configuration
|
||||||
|
python scripts/configure_oauth_provider.py update google --enabled false
|
||||||
|
|
||||||
|
# Delete a provider configuration
|
||||||
|
python scripts/configure_oauth_provider.py delete google
|
||||||
|
|
||||||
|
# Use environment variables
|
||||||
|
GOOGLE_CLIENT_ID=xxx GOOGLE_CLIENT_SECRET=yyy \\
|
||||||
|
python scripts/configure_oauth_provider.py create google
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
# Add the parent directory to the path for imports
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
# Load environment variables from .env file before any other imports
|
||||||
|
# This ensures database and other configurations are available
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
env_file = os.path.join(script_dir, '.env')
|
||||||
|
if os.path.exists(env_file):
|
||||||
|
load_dotenv(env_file)
|
||||||
|
|
||||||
|
# Import after path setup
|
||||||
|
from gatehouse_app import create_app
|
||||||
|
from gatehouse_app.services.external_auth_service import ExternalAuthService, ExternalAuthError
|
||||||
|
|
||||||
|
|
||||||
|
# Provider endpoint configurations
|
||||||
|
PROVIDER_DEFAULTS = {
|
||||||
|
"google": {
|
||||||
|
"auth_url": "https://accounts.google.com/o/oauth2/v2/auth",
|
||||||
|
"token_url": "https://oauth2.googleapis.com/token",
|
||||||
|
"userinfo_url": "https://openidconnect.googleapis.com/v1/userinfo",
|
||||||
|
"jwks_url": "https://www.googleapis.com/oauth2/v3/certs",
|
||||||
|
"scopes": ["openid", "profile", "email"],
|
||||||
|
},
|
||||||
|
"github": {
|
||||||
|
"auth_url": "https://github.com/login/oauth/authorize",
|
||||||
|
"token_url": "https://github.com/login/oauth/access_token",
|
||||||
|
"userinfo_url": "https://api.github.com/user",
|
||||||
|
"scopes": ["read:user", "user:email"],
|
||||||
|
},
|
||||||
|
"microsoft": {
|
||||||
|
"auth_url": "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
|
||||||
|
"token_url": "https://login.microsoftonline.com/common/oauth2/v2.0/token",
|
||||||
|
"userinfo_url": "https://graph.microsoft.com/oidc/userinfo",
|
||||||
|
"jwks_url": "https://login.microsoftonline.com/common/discovery/v2.0/keys",
|
||||||
|
"scopes": ["openid", "profile", "email"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Colors:
|
||||||
|
"""ANSI color codes for terminal output."""
|
||||||
|
HEADER = '\033[95m'
|
||||||
|
OKBLUE = '\033[94m'
|
||||||
|
OKCYAN = '\033[96m'
|
||||||
|
OKGREEN = '\033[92m'
|
||||||
|
WARNING = '\033[93m'
|
||||||
|
FAIL = '\033[91m'
|
||||||
|
ENDC = '\033[0m'
|
||||||
|
BOLD = '\033[1m'
|
||||||
|
UNDERLINE = '\033[4m'
|
||||||
|
|
||||||
|
|
||||||
|
def print_success(message: str):
|
||||||
|
"""Print success message in green."""
|
||||||
|
print(f"{Colors.OKGREEN}✓ {message}{Colors.ENDC}")
|
||||||
|
|
||||||
|
|
||||||
|
def print_error(message: str):
|
||||||
|
"""Print error message in red."""
|
||||||
|
print(f"{Colors.FAIL}✗ {message}{Colors.ENDC}", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
def print_warning(message: str):
|
||||||
|
"""Print warning message in yellow."""
|
||||||
|
print(f"{Colors.WARNING}⚠ {message}{Colors.ENDC}")
|
||||||
|
|
||||||
|
|
||||||
|
def print_info(message: str):
|
||||||
|
"""Print info message in blue."""
|
||||||
|
print(f"{Colors.OKBLUE}ℹ {message}{Colors.ENDC}")
|
||||||
|
|
||||||
|
|
||||||
|
def print_header(message: str):
|
||||||
|
"""Print header message."""
|
||||||
|
print(f"\n{Colors.BOLD}{Colors.HEADER}{message}{Colors.ENDC}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_env_credentials(provider_type: str) -> Dict[str, Optional[str]]:
|
||||||
|
"""
|
||||||
|
Get OAuth credentials from environment variables.
|
||||||
|
|
||||||
|
Supports the following patterns:
|
||||||
|
- {PROVIDER}_CLIENT_ID
|
||||||
|
- {PROVIDER}_CLIENT_SECRET
|
||||||
|
- {PROVIDER}_REDIRECT_URL
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_type: Provider type (google, github, microsoft)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with client_id, client_secret, and redirect_url if found
|
||||||
|
"""
|
||||||
|
provider_upper = provider_type.upper()
|
||||||
|
return {
|
||||||
|
"client_id": os.environ.get(f"{provider_upper}_CLIENT_ID"),
|
||||||
|
"client_secret": os.environ.get(f"{provider_upper}_CLIENT_SECRET"),
|
||||||
|
"redirect_url": os.environ.get(f"{provider_upper}_REDIRECT_URL"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_provider(args):
|
||||||
|
"""Create a new OAuth provider configuration."""
|
||||||
|
provider_type = args.provider.lower()
|
||||||
|
|
||||||
|
print_header(f"Creating {provider_type.title()} OAuth Provider Configuration")
|
||||||
|
|
||||||
|
# Get credentials from args or environment
|
||||||
|
env_creds = get_env_credentials(provider_type)
|
||||||
|
client_id = args.client_id or env_creds.get("client_id")
|
||||||
|
client_secret = args.client_secret or env_creds.get("client_secret")
|
||||||
|
redirect_url = args.redirect_url or env_creds.get("redirect_url")
|
||||||
|
|
||||||
|
# Validation
|
||||||
|
if not client_id:
|
||||||
|
print_error(f"Client ID is required. Provide via --client-id or {provider_type.upper()}_CLIENT_ID environment variable.")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
if not client_secret:
|
||||||
|
print_error(f"Client secret is required. Provide via --client-secret or {provider_type.upper()}_CLIENT_SECRET environment variable.")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
# Get provider defaults
|
||||||
|
if provider_type not in PROVIDER_DEFAULTS:
|
||||||
|
print_error(f"Unknown provider: {provider_type}. Supported providers: {', '.join(PROVIDER_DEFAULTS.keys())}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
defaults = PROVIDER_DEFAULTS[provider_type]
|
||||||
|
|
||||||
|
# Build configuration
|
||||||
|
config_data = {
|
||||||
|
"client_id": client_id,
|
||||||
|
"client_secret": client_secret,
|
||||||
|
"default_redirect_url": redirect_url,
|
||||||
|
"is_enabled": not args.disabled,
|
||||||
|
**defaults,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add custom settings if provided
|
||||||
|
if args.settings:
|
||||||
|
settings = {}
|
||||||
|
for setting in args.settings:
|
||||||
|
try:
|
||||||
|
key, value = setting.split("=", 1)
|
||||||
|
settings[key] = value
|
||||||
|
except ValueError:
|
||||||
|
print_warning(f"Skipping invalid setting format: {setting}")
|
||||||
|
config_data["settings"] = settings
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create the provider configuration
|
||||||
|
config = ExternalAuthService.create_app_provider_config(
|
||||||
|
provider_type=provider_type,
|
||||||
|
**config_data
|
||||||
|
)
|
||||||
|
|
||||||
|
print_success(f"{provider_type.title()} provider created successfully!")
|
||||||
|
print_info(f"Provider ID: {config.id}")
|
||||||
|
print_info(f"Client ID: {config.client_id}")
|
||||||
|
if redirect_url:
|
||||||
|
print_info(f"Default Redirect URL: {redirect_url}")
|
||||||
|
print_info(f"Enabled: {config.is_enabled}")
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
except ExternalAuthError as e:
|
||||||
|
print_error(f"Failed to create provider: {e.message}")
|
||||||
|
if e.error_type == "PROVIDER_EXISTS":
|
||||||
|
print_info("Use 'update' command to modify existing provider configuration.")
|
||||||
|
return 1
|
||||||
|
except Exception as e:
|
||||||
|
print_error(f"Unexpected error: {str(e)}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
def update_provider(args):
|
||||||
|
"""Update an existing OAuth provider configuration."""
|
||||||
|
provider_type = args.provider.lower()
|
||||||
|
|
||||||
|
print_header(f"Updating {provider_type.title()} OAuth Provider Configuration")
|
||||||
|
|
||||||
|
# Build updates dictionary
|
||||||
|
updates = {}
|
||||||
|
|
||||||
|
if args.client_id:
|
||||||
|
updates["client_id"] = args.client_id
|
||||||
|
|
||||||
|
if args.client_secret:
|
||||||
|
updates["client_secret"] = args.client_secret
|
||||||
|
|
||||||
|
if args.redirect_url:
|
||||||
|
updates["default_redirect_url"] = args.redirect_url
|
||||||
|
|
||||||
|
if args.enabled is not None:
|
||||||
|
updates["is_enabled"] = args.enabled
|
||||||
|
|
||||||
|
if args.settings:
|
||||||
|
settings = {}
|
||||||
|
for setting in args.settings:
|
||||||
|
try:
|
||||||
|
key, value = setting.split("=", 1)
|
||||||
|
settings[key] = value
|
||||||
|
except ValueError:
|
||||||
|
print_warning(f"Skipping invalid setting format: {setting}")
|
||||||
|
updates["settings"] = settings
|
||||||
|
|
||||||
|
if not updates:
|
||||||
|
print_warning("No updates specified. Use --help to see available options.")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
config = ExternalAuthService.update_app_provider_config(
|
||||||
|
provider_type=provider_type,
|
||||||
|
**updates
|
||||||
|
)
|
||||||
|
|
||||||
|
print_success(f"{provider_type.title()} provider updated successfully!")
|
||||||
|
print_info(f"Provider ID: {config.id}")
|
||||||
|
print_info(f"Client ID: {config.client_id}")
|
||||||
|
if config.default_redirect_url:
|
||||||
|
print_info(f"Default Redirect URL: {config.default_redirect_url}")
|
||||||
|
print_info(f"Enabled: {config.is_enabled}")
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
except ExternalAuthError as e:
|
||||||
|
print_error(f"Failed to update provider: {e.message}")
|
||||||
|
if e.error_type == "PROVIDER_NOT_FOUND":
|
||||||
|
print_info("Use 'create' command to add a new provider configuration.")
|
||||||
|
return 1
|
||||||
|
except Exception as e:
|
||||||
|
print_error(f"Unexpected error: {str(e)}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
def list_providers(args):
|
||||||
|
"""List all configured OAuth providers."""
|
||||||
|
print_header("Configured OAuth Providers")
|
||||||
|
|
||||||
|
try:
|
||||||
|
configs = ExternalAuthService.list_app_provider_configs()
|
||||||
|
|
||||||
|
if not configs:
|
||||||
|
print_info("No OAuth providers configured yet.")
|
||||||
|
print_info("Use 'create' command to add a provider.")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
print()
|
||||||
|
for config in configs:
|
||||||
|
status = f"{Colors.OKGREEN}enabled{Colors.ENDC}" if config.get("is_enabled") else f"{Colors.WARNING}disabled{Colors.ENDC}"
|
||||||
|
print(f" {Colors.BOLD}{config['provider_type']}{Colors.ENDC} - {status}")
|
||||||
|
print(f" Client ID: {config['client_id']}")
|
||||||
|
if config.get('default_redirect_url'):
|
||||||
|
print(f" Redirect URL: {config['default_redirect_url']}")
|
||||||
|
print(f" Created: {config.get('created_at', 'N/A')}")
|
||||||
|
|
||||||
|
# Show endpoint info if available
|
||||||
|
additional_config = config.get('additional_config', {})
|
||||||
|
if additional_config:
|
||||||
|
if additional_config.get('auth_url'):
|
||||||
|
print(f" Auth URL: {additional_config['auth_url']}")
|
||||||
|
if additional_config.get('scopes'):
|
||||||
|
scopes = ', '.join(additional_config['scopes'])
|
||||||
|
print(f" Scopes: {scopes}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print_error(f"Failed to list providers: {str(e)}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
def show_provider(args):
|
||||||
|
"""Show details of a specific OAuth provider."""
|
||||||
|
provider_type = args.provider.lower()
|
||||||
|
|
||||||
|
print_header(f"{provider_type.title()} OAuth Provider Details")
|
||||||
|
|
||||||
|
try:
|
||||||
|
config = ExternalAuthService.get_app_provider_config(provider_type)
|
||||||
|
config_dict = config.to_dict()
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(f"{Colors.BOLD}Basic Information:{Colors.ENDC}")
|
||||||
|
print(f" Provider Type: {config_dict['provider_type']}")
|
||||||
|
print(f" Provider ID: {config_dict['id']}")
|
||||||
|
print(f" Client ID: {config_dict['client_id']}")
|
||||||
|
|
||||||
|
status = f"{Colors.OKGREEN}enabled{Colors.ENDC}" if config_dict['is_enabled'] else f"{Colors.WARNING}disabled{Colors.ENDC}"
|
||||||
|
print(f" Status: {status}")
|
||||||
|
|
||||||
|
if config_dict.get('default_redirect_url'):
|
||||||
|
print(f" Default Redirect URL: {config_dict['default_redirect_url']}")
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(f"{Colors.BOLD}Timestamps:{Colors.ENDC}")
|
||||||
|
print(f" Created: {config_dict.get('created_at', 'N/A')}")
|
||||||
|
print(f" Updated: {config_dict.get('updated_at', 'N/A')}")
|
||||||
|
|
||||||
|
# Show additional configuration
|
||||||
|
additional_config = config_dict.get('additional_config', {})
|
||||||
|
if additional_config:
|
||||||
|
print()
|
||||||
|
print(f"{Colors.BOLD}OAuth Configuration:{Colors.ENDC}")
|
||||||
|
|
||||||
|
if additional_config.get('auth_url'):
|
||||||
|
print(f" Authorization URL: {additional_config['auth_url']}")
|
||||||
|
if additional_config.get('token_url'):
|
||||||
|
print(f" Token URL: {additional_config['token_url']}")
|
||||||
|
if additional_config.get('userinfo_url'):
|
||||||
|
print(f" User Info URL: {additional_config['userinfo_url']}")
|
||||||
|
if additional_config.get('jwks_url'):
|
||||||
|
print(f" JWKS URL: {additional_config['jwks_url']}")
|
||||||
|
if additional_config.get('scopes'):
|
||||||
|
scopes = ', '.join(additional_config['scopes'])
|
||||||
|
print(f" Scopes: {scopes}")
|
||||||
|
|
||||||
|
# Show any custom settings
|
||||||
|
custom_settings = {k: v for k, v in additional_config.items()
|
||||||
|
if k not in ['auth_url', 'token_url', 'userinfo_url', 'jwks_url', 'scopes']}
|
||||||
|
if custom_settings:
|
||||||
|
print()
|
||||||
|
print(f"{Colors.BOLD}Custom Settings:{Colors.ENDC}")
|
||||||
|
for key, value in custom_settings.items():
|
||||||
|
print(f" {key}: {value}")
|
||||||
|
|
||||||
|
print()
|
||||||
|
return 0
|
||||||
|
|
||||||
|
except ExternalAuthError as e:
|
||||||
|
print_error(f"Failed to get provider: {e.message}")
|
||||||
|
return 1
|
||||||
|
except Exception as e:
|
||||||
|
print_error(f"Unexpected error: {str(e)}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
def delete_provider(args):
|
||||||
|
"""Delete an OAuth provider configuration."""
|
||||||
|
provider_type = args.provider.lower()
|
||||||
|
|
||||||
|
print_header(f"Deleting {provider_type.title()} OAuth Provider Configuration")
|
||||||
|
|
||||||
|
# Confirm deletion unless --yes flag is provided
|
||||||
|
if not args.yes:
|
||||||
|
print_warning("This will permanently delete the provider configuration.")
|
||||||
|
response = input(f"Are you sure you want to delete {provider_type}? (yes/no): ")
|
||||||
|
if response.lower() not in ['yes', 'y']:
|
||||||
|
print_info("Deletion cancelled.")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
ExternalAuthService.delete_app_provider_config(provider_type)
|
||||||
|
print_success(f"{provider_type.title()} provider deleted successfully!")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
except ExternalAuthError as e:
|
||||||
|
print_error(f"Failed to delete provider: {e.message}")
|
||||||
|
return 1
|
||||||
|
except Exception as e:
|
||||||
|
print_error(f"Unexpected error: {str(e)}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main entry point for the script."""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Configure OAuth providers for Gatehouse authentication",
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
epilog="""
|
||||||
|
Examples:
|
||||||
|
# Create Google OAuth configuration
|
||||||
|
%(prog)s create google --client-id "CLIENT_ID" --client-secret "SECRET"
|
||||||
|
|
||||||
|
# Create with environment variables
|
||||||
|
GOOGLE_CLIENT_ID=xxx GOOGLE_CLIENT_SECRET=yyy %(prog)s create google
|
||||||
|
|
||||||
|
# List all providers
|
||||||
|
%(prog)s list
|
||||||
|
|
||||||
|
# Show provider details
|
||||||
|
%(prog)s show google
|
||||||
|
|
||||||
|
# Update provider
|
||||||
|
%(prog)s update google --enabled true
|
||||||
|
|
||||||
|
# Delete provider
|
||||||
|
%(prog)s delete google --yes
|
||||||
|
|
||||||
|
Supported Providers:
|
||||||
|
- google
|
||||||
|
- github
|
||||||
|
- microsoft
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
subparsers = parser.add_subparsers(dest="command", help="Command to execute")
|
||||||
|
subparsers.required = True
|
||||||
|
|
||||||
|
# Create command
|
||||||
|
create_parser = subparsers.add_parser("create", help="Create a new OAuth provider configuration")
|
||||||
|
create_parser.add_argument("provider", help="Provider type (google, github, microsoft)")
|
||||||
|
create_parser.add_argument("--client-id", help="OAuth client ID")
|
||||||
|
create_parser.add_argument("--client-secret", help="OAuth client secret")
|
||||||
|
create_parser.add_argument("--redirect-url", help="Default redirect URL for OAuth callbacks")
|
||||||
|
create_parser.add_argument("--disabled", action="store_true", help="Create provider in disabled state")
|
||||||
|
create_parser.add_argument("--settings", action="append", help="Custom settings (key=value format)")
|
||||||
|
create_parser.set_defaults(func=create_provider)
|
||||||
|
|
||||||
|
# Update command
|
||||||
|
update_parser = subparsers.add_parser("update", help="Update an existing OAuth provider configuration")
|
||||||
|
update_parser.add_argument("provider", help="Provider type to update")
|
||||||
|
update_parser.add_argument("--client-id", help="New OAuth client ID")
|
||||||
|
update_parser.add_argument("--client-secret", help="New OAuth client secret")
|
||||||
|
update_parser.add_argument("--redirect-url", help="New default redirect URL")
|
||||||
|
update_parser.add_argument("--enabled", type=lambda x: x.lower() in ['true', '1', 'yes'],
|
||||||
|
help="Enable or disable the provider (true/false)")
|
||||||
|
update_parser.add_argument("--settings", action="append", help="Custom settings to update (key=value format)")
|
||||||
|
update_parser.set_defaults(func=update_provider)
|
||||||
|
|
||||||
|
# List command
|
||||||
|
list_parser = subparsers.add_parser("list", help="List all configured OAuth providers")
|
||||||
|
list_parser.set_defaults(func=list_providers)
|
||||||
|
|
||||||
|
# Show command
|
||||||
|
show_parser = subparsers.add_parser("show", help="Show details of a specific OAuth provider")
|
||||||
|
show_parser.add_argument("provider", help="Provider type to show")
|
||||||
|
show_parser.set_defaults(func=show_provider)
|
||||||
|
|
||||||
|
# Delete command
|
||||||
|
delete_parser = subparsers.add_parser("delete", help="Delete an OAuth provider configuration")
|
||||||
|
delete_parser.add_argument("provider", help="Provider type to delete")
|
||||||
|
delete_parser.add_argument("--yes", "-y", action="store_true", help="Skip confirmation prompt")
|
||||||
|
delete_parser.set_defaults(func=delete_provider)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Create Flask app context
|
||||||
|
app = create_app()
|
||||||
|
|
||||||
|
with app.app_context():
|
||||||
|
return args.func(args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
Executable
+70
@@ -0,0 +1,70 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Test script to verify OAuth endpoints work without organization_id
|
||||||
|
# This tests the fix for the "Google OAuth is not configured for this organization" error
|
||||||
|
|
||||||
|
API_BASE="http://localhost:5001/api/v1"
|
||||||
|
|
||||||
|
echo "=== Testing OAuth Authorization Endpoint (without organization_id) ==="
|
||||||
|
echo ""
|
||||||
|
echo "1. Initiating Google OAuth login flow (NO organization_id)..."
|
||||||
|
RESPONSE=$(curl -s -X GET "${API_BASE}/auth/external/google/authorize?flow=login")
|
||||||
|
echo "Response: $RESPONSE"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Check if we get an authorization URL
|
||||||
|
if echo "$RESPONSE" | grep -q "authorization_url"; then
|
||||||
|
echo "✅ SUCCESS: Got authorization URL without requiring organization_id"
|
||||||
|
AUTH_URL=$(echo "$RESPONSE" | jq -r '.data.authorization_url')
|
||||||
|
STATE=$(echo "$RESPONSE" | jq -r '.data.state')
|
||||||
|
echo "Authorization URL: $AUTH_URL"
|
||||||
|
echo "State: $STATE"
|
||||||
|
else
|
||||||
|
echo "❌ FAILED: Did not get authorization URL"
|
||||||
|
echo "Error: $(echo "$RESPONSE" | jq -r '.message')"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=== Testing with organization_id hint (should still work) ==="
|
||||||
|
echo ""
|
||||||
|
echo "2. Initiating Google OAuth login flow (WITH organization_id hint)..."
|
||||||
|
# You'll need to replace this with an actual organization ID from your database
|
||||||
|
ORG_ID="test-org-id"
|
||||||
|
RESPONSE=$(curl -s -X GET "${API_BASE}/auth/external/google/authorize?flow=login&organization_id=${ORG_ID}")
|
||||||
|
echo "Response: $RESPONSE"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
if echo "$RESPONSE" | grep -q "authorization_url"; then
|
||||||
|
echo "✅ SUCCESS: OAuth works with organization_id hint (backward compatible)"
|
||||||
|
else
|
||||||
|
echo "⚠️ Note: This may fail if the organization ID doesn't exist or if app-level config is not set"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=== Testing Register Flow ==="
|
||||||
|
echo ""
|
||||||
|
echo "3. Initiating Google OAuth register flow (NO organization_id)..."
|
||||||
|
RESPONSE=$(curl -s -X GET "${API_BASE}/auth/external/google/authorize?flow=register")
|
||||||
|
echo "Response: $RESPONSE"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
if echo "$RESPONSE" | grep -q "authorization_url"; then
|
||||||
|
echo "✅ SUCCESS: Register flow works without organization_id"
|
||||||
|
else
|
||||||
|
echo "❌ FAILED: Register flow did not work"
|
||||||
|
echo "Error: $(echo "$RESPONSE" | jq -r '.message')"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=== Summary ==="
|
||||||
|
echo ""
|
||||||
|
echo "The key fix addresses the error:"
|
||||||
|
echo " 'Google OAuth is not configured for this organization'"
|
||||||
|
echo ""
|
||||||
|
echo "Now OAuth flows work at the APPLICATION level, not requiring"
|
||||||
|
echo "an organization context during initial authentication."
|
||||||
|
echo ""
|
||||||
|
echo "After OAuth callback:"
|
||||||
|
echo " - Single org user → Automatic login"
|
||||||
|
echo " - Multi org user → Organization selection UI"
|
||||||
|
echo " - New user → Organization creation/selection UI"
|
||||||
Reference in New Issue
Block a user