move app to gatehouse-app

This commit is contained in:
2026-01-15 03:40:29 +10:30
parent 5e4cffcf73
commit 2c0aaf484b
69 changed files with 1569 additions and 294 deletions
+240
View File
@@ -0,0 +1,240 @@
"""Application factory."""
import os
import logging
# Test debug logging - this should appear when running `flask run --debug`
_root_logger = logging.getLogger(__name__)
_root_logger.debug("[TEST] Debug logging is working!")
from flask import Flask
from config import get_config
from gatehouse_app.extensions import db, migrate, bcrypt, ma, limiter
from gatehouse_app.extensions import session as flask_session
from gatehouse_app.middleware import RequestIDMiddleware, SecurityHeadersMiddleware, setup_cors
from gatehouse_app.exceptions.base import BaseAPIException
from gatehouse_app.utils.response import api_response
from gatehouse_app.services.oidc_jwks_service import OIDCJWKSService
import redis
# Configure SQLAlchemy logging BEFORE any database operations
# This must be done before db.init_app() to prevent verbose logging
_log_level_env = os.getenv("SQLALCHEMY_LOG_LEVEL", "WARNING").upper()
_sqlalchemy_log_level = getattr(logging, _log_level_env, logging.WARNING)
logging.getLogger('sqlalchemy').setLevel(_sqlalchemy_log_level)
logging.getLogger('sqlalchemy.engine').setLevel(_sqlalchemy_log_level)
logging.getLogger('sqlalchemy.dialects').setLevel(_sqlalchemy_log_level)
logging.getLogger('sqlalchemy.pool').setLevel(_sqlalchemy_log_level)
def create_app(config_name=None):
"""
Create and configure the Flask application.
Args:
config_name: Configuration name (development, testing, production)
Returns:
Flask application instance
"""
flask_app = Flask(__name__)
# Load configuration
config = get_config(config_name)
flask_app.config.from_object(config)
# Initialize extensions
initialize_extensions(flask_app)
# Setup middleware
setup_middleware(flask_app)
# Register blueprints
register_blueprints(flask_app)
# Register error handlers
register_error_handlers(flask_app)
# Setup logging
setup_logging(flask_app)
# Initialize OIDC JWKS service with a signing key
initialize_oidc_jwks(flask_app)
return flask_app
def initialize_extensions(app):
"""Initialize Flask extensions."""
# Database
db.init_app(app)
migrate.init_app(app, db)
# Security
bcrypt.init_app(app)
# CORS - using custom middleware only (see app/middleware/cors.py)
# Flask-CORS disabled to avoid conflicts
# cors.init_app(app)
# Marshmallow
ma.init_app(app)
# Rate limiting
if app.config.get("RATELIMIT_ENABLED"):
limiter.init_app(app)
# Redis for sessions
try:
redis_url = app.config.get("REDIS_URL")
if redis_url:
import gatehouse_app.extensions
gatehouse_app.extensions.redis_client = redis.from_url(redis_url)
app.config["SESSION_REDIS"] = gatehouse_app.extensions.redis_client
except Exception as e:
logging.warning(f"Redis connection failed: {e}")
# Flask-Session
flask_session.init_app(app)
def setup_middleware(app):
"""Setup application middleware."""
RequestIDMiddleware(app)
SecurityHeadersMiddleware(app)
setup_cors(app)
def register_blueprints(app):
"""Register application blueprints."""
from gatehouse_app.api import register_api_blueprints
from gatehouse_app.api.oidc import oidc_bp
register_api_blueprints(app)
# Register OIDC blueprint at root level
app.register_blueprint(oidc_bp)
def register_error_handlers(app):
"""Register error handlers."""
@app.errorhandler(BaseAPIException)
def handle_api_exception(error):
"""Handle custom API exceptions."""
return api_response(
success=False,
message=error.message,
status=error.status_code,
error_type=error.error_type,
error_details=error.error_details,
)
@app.errorhandler(404)
def handle_not_found(error):
"""Handle 404 errors."""
return api_response(
success=False,
message="Resource not found",
status=404,
error_type="NOT_FOUND",
)
@app.errorhandler(405)
def handle_method_not_allowed(error):
"""Handle 405 errors."""
return api_response(
success=False,
message="Method not allowed",
status=405,
error_type="METHOD_NOT_ALLOWED",
)
@app.errorhandler(500)
def handle_internal_error(error):
"""Handle 500 errors."""
app.logger.error(f"Internal server error: {error}")
return api_response(
success=False,
message="Internal server error",
status=500,
error_type="INTERNAL_ERROR",
)
@app.errorhandler(Exception)
def handle_unexpected_error(error):
"""Handle unexpected errors."""
app.logger.error(f"Unexpected error: {error}", exc_info=True)
return api_response(
success=False,
message="An unexpected error occurred",
status=500,
error_type="INTERNAL_ERROR",
)
def setup_logging(app):
"""Setup application logging."""
log_level = getattr(logging, app.config.get("LOG_LEVEL", "INFO"))
# Create formatter
formatter = logging.Formatter(
"[%(asctime)s] [%(levelname)s] %(name)s: %(message)s"
)
# Configure root logger - this ensures all module loggers (like app.services.oidc_service)
# will output DEBUG level logs when in development mode
root_logger = logging.getLogger()
root_logger.setLevel(log_level)
if app.config.get("LOG_TO_STDOUT"):
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
stream_handler.setLevel(log_level)
root_logger.addHandler(stream_handler)
# Disable Werkzeug's default logger to avoid log duplication and interference
werkzeug_logger = logging.getLogger('werkzeug')
werkzeug_logger.setLevel(logging.INFO)
# Ensure child loggers propagate to root logger
# This is the key fix - explicitly enable propagation for common app loggers
for logger_name in ['app', 'app.api', 'app.api.oidc', 'app.services', 'app.models']:
child_logger = logging.getLogger(logger_name)
child_logger.propagate = True
child_logger.setLevel(log_level)
# Configure Flask app logger
app.logger.setLevel(log_level)
# Configure SQLAlchemy logging level (also set at module level before DB init)
sqlalchemy_log_level = getattr(logging, app.config.get("SQLALCHEMY_LOG_LEVEL", "WARNING"), logging.WARNING)
logging.getLogger('sqlalchemy').setLevel(sqlalchemy_log_level)
logging.getLogger('sqlalchemy.engine').setLevel(sqlalchemy_log_level)
logging.getLogger('sqlalchemy.dialects').setLevel(sqlalchemy_log_level)
logging.getLogger('sqlalchemy.pool').setLevel(sqlalchemy_log_level)
app.logger.info("Application startup")
# Test debug log after logging is configured
app.logger.debug("[TEST] Debug logging is working!")
def initialize_oidc_jwks(app):
"""Initialize OIDC JWKS service with a signing key.
This ensures that signing keys are available for token generation.
Keys are loaded from the database if available, otherwise a new key
is generated and persisted to the database.
Args:
app: Flask application instance
"""
with app.app_context():
try:
jwks_service = OIDCJWKSService()
# Use initialize_with_key which handles loading from DB
# or generating a new key if none exists
signing_key = jwks_service.initialize_with_key()
app.logger.info(f"[OIDC] Signing key initialized: kid={signing_key.kid}")
except Exception as e:
app.logger.error(f"[OIDC] Failed to initialize JWKS: {e}")
+24
View File
@@ -0,0 +1,24 @@
"""API package."""
from flask import Blueprint
from gatehouse_app.utils.response import api_response
# Create main API blueprint
api_bp = Blueprint("api", __name__)
@api_bp.route("/health", methods=["GET"])
def health_check():
"""Health check endpoint."""
return api_response(
data={"status": "healthy", "service": "authy2-backend"},
message="Service is running",
)
def register_api_blueprints(app):
"""Register all API blueprints."""
from gatehouse_app.api.v1 import api_v1_bp
# Register versioned API blueprints
app.register_blueprint(api_bp, url_prefix="/api")
app.register_blueprint(api_v1_bp, url_prefix="/api/v1")
File diff suppressed because it is too large Load Diff
+8
View File
@@ -0,0 +1,8 @@
"""API v1 blueprint."""
from flask import Blueprint
# Create v1 API blueprint
api_v1_bp = Blueprint("api_v1", __name__)
# Import route modules to register them
from gatehouse_app.api.v1 import auth, users, organizations
+937
View File
@@ -0,0 +1,937 @@
"""Authentication endpoints."""
import json
from flask import request, session, g, jsonify
from marshmallow import ValidationError
from gatehouse_app.api.v1 import api_v1_bp
from gatehouse_app.utils.response import api_response
from gatehouse_app.schemas.auth_schema import (
RegisterSchema,
LoginSchema,
TOTPVerifyEnrollmentSchema,
TOTPVerifySchema,
TOTPDisableSchema,
TOTPRegenerateBackupCodesSchema,
)
from gatehouse_app.schemas.webauthn_schema import (
WebAuthnRegistrationBeginSchema,
WebAuthnRegistrationCompleteSchema,
WebAuthnLoginBeginSchema,
WebAuthnLoginCompleteSchema,
WebAuthnCredentialRenameSchema,
)
from gatehouse_app.services.auth_service import AuthService
from gatehouse_app.services.webauthn_service import WebAuthnService
from gatehouse_app.services.user_service import UserService
from gatehouse_app.utils.decorators import login_required
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
from gatehouse_app.exceptions.validation_exceptions import ConflictError, NotFoundError
@api_v1_bp.route("/auth/register", methods=["POST"])
def register():
"""
Register a new user.
Request body:
email: User email
password: User password
password_confirm: Password confirmation
full_name: Optional full name
Returns:
201: User created successfully
400: Validation error
409: Email already exists
"""
try:
# Validate request data
schema = RegisterSchema()
data = schema.load(request.json)
# Register user
user = AuthService.register_user(
email=data["email"],
password=data["password"],
full_name=data.get("full_name"),
)
# Create session
user_session = AuthService.create_session(user)
return api_response(
data={
"user": user.to_dict(),
"token": user_session.token,
"expires_at": user_session.expires_at.isoformat() + "Z" if user_session.expires_at.isoformat()[-1] != "Z" else user_session.expires_at.isoformat(),
},
message="Registration successful",
status=201,
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
@api_v1_bp.route("/auth/login", methods=["POST"])
def login():
"""
Login user.
Request body:
email: User email
password: User password
remember_me: Optional boolean for extended session
Returns:
200: Login successful or TOTP code required
400: Validation error
401: Invalid credentials
"""
try:
# Validate request data
schema = LoginSchema()
data = schema.load(request.json)
# Authenticate user with email and password
user = AuthService.authenticate(
email=data["email"],
password=data["password"],
)
# Check if user has TOTP enabled for two-factor authentication
if user.has_totp_enabled():
# TOTP is enabled - store user_id in session for TOTP verification
# The /auth/totp/verify endpoint will retrieve this user_id
session["totp_pending_user_id"] = user.id
# Return response indicating TOTP code is required
# Do NOT create session or return token yet - wait for TOTP verification
return api_response(
data={
"requires_totp": True,
},
message="TOTP code required. Please enter your 6-digit code from your authenticator app.",
)
# TOTP is NOT enabled - proceed with normal login flow
# Create session with appropriate duration based on remember_me preference
duration = 2592000 if data.get("remember_me") else 86400 # 30 days vs 1 day
user_session = AuthService.create_session(user, duration_seconds=duration)
return api_response(
data={
"user": user.to_dict(),
"token": user_session.token,
"expires_at": user_session.expires_at.isoformat() + "Z" if user_session.expires_at.isoformat()[-1] != "Z" else user_session.expires_at.isoformat(),
},
message="Login successful",
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
@api_v1_bp.route("/auth/logout", methods=["POST"])
@login_required
def logout():
"""
Logout current user.
Returns:
200: Logout successful
401: Not authenticated
"""
# Revoke current session (g.current_session is set by login_required decorator)
if g.current_session:
AuthService.revoke_session(g.current_session.id, reason="User logout")
return api_response(
message="Logout successful",
)
@api_v1_bp.route("/auth/me", methods=["GET"])
@login_required
def get_current_user():
"""
Get current authenticated user.
Returns:
200: User data
401: Not authenticated
"""
user = g.current_user
return api_response(
data={
"user": user.to_dict(),
"organizations": [
{"id": org.id, "name": org.name, "slug": org.slug}
for org in user.get_organizations()
],
},
message="User retrieved successfully",
)
@api_v1_bp.route("/auth/sessions", methods=["GET"])
@login_required
def get_user_sessions():
"""
Get all active sessions for current user.
Returns:
200: List of active sessions
401: Not authenticated
"""
from gatehouse_app.services.session_service import SessionService
sessions = SessionService.get_user_sessions(g.current_user.id, active_only=True)
return api_response(
data={
"sessions": [session.to_dict() for session in sessions],
"count": len(sessions),
},
message="Sessions retrieved successfully",
)
@api_v1_bp.route("/auth/sessions/<session_id>", methods=["DELETE"])
@login_required
def revoke_session(session_id):
"""
Revoke a specific session.
Args:
session_id: ID of session to revoke
Returns:
200: Session revoked
401: Not authenticated
404: Session not found
"""
from gatehouse_app.models.session import Session
# Ensure session belongs to current user
user_session = Session.query.filter_by(
id=session_id, user_id=g.current_user.id, deleted_at=None
).first()
if not user_session:
return api_response(
success=False,
message="Session not found",
status=404,
error_type="NOT_FOUND",
)
AuthService.revoke_session(session_id, reason="Revoked by user")
return api_response(
message="Session revoked successfully",
)
@api_v1_bp.route("/auth/totp/enroll", methods=["POST"])
@login_required
def enroll_totp():
"""
Initiate TOTP enrollment for the current user.
Returns:
201: TOTP enrollment initiated with secret, provisioning_uri, qr_code, and backup_codes
401: Not authenticated
409: TOTP already enabled
"""
try:
# Initiate TOTP enrollment
result = AuthService.enroll_totp(g.current_user)
return api_response(
data={
"secret": result["secret"],
"provisioning_uri": result["provisioning_uri"],
"qr_code": result["qr_code"],
"backup_codes": result["backup_codes"],
},
message="TOTP enrollment initiated. Please verify with your authenticator app.",
status=201,
)
except ConflictError as e:
return api_response(
success=False,
message=e.message,
status=e.status_code,
error_type=e.error_type,
)
@api_v1_bp.route("/auth/totp/verify-enrollment", methods=["POST"])
@login_required
def verify_totp_enrollment():
"""
Complete TOTP enrollment by verifying the first TOTP code.
Request body:
code: 6-digit TOTP code from authenticator app
Returns:
200: TOTP enrollment completed successfully
400: Validation error
401: Not authenticated
401: Invalid TOTP code
"""
try:
# Validate request data
schema = TOTPVerifyEnrollmentSchema()
data = schema.load(request.json)
# Verify TOTP enrollment
AuthService.verify_totp_enrollment(g.current_user, data["code"])
return api_response(
message="TOTP enrollment completed successfully",
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
except InvalidCredentialsError as e:
return api_response(
success=False,
message=e.message,
status=e.status_code,
error_type=e.error_type,
)
@api_v1_bp.route("/auth/totp/verify", methods=["POST"])
def verify_totp():
"""
Verify TOTP code during login.
Request body:
code: 6-digit TOTP code or backup code
is_backup_code: True if code is a backup code, False if TOTP code (default: False)
Returns:
200: TOTP code verified successfully with session token
400: Validation error
401: Invalid TOTP code or session not found
"""
try:
# Validate request data
schema = TOTPVerifySchema()
data = schema.load(request.json)
# Get user from temporary session (stored in Flask session by login endpoint)
user_id = session.get("totp_pending_user_id")
if not user_id:
return api_response(
success=False,
message="No pending TOTP verification. Please login first.",
status=401,
error_type="AUTHENTICATION_ERROR",
)
# Get user from database
from gatehouse_app.models.user import User
user = User.query.get(user_id)
if not user:
return api_response(
success=False,
message="User not found",
status=401,
error_type="AUTHENTICATION_ERROR",
)
# Verify TOTP code
AuthService.authenticate_with_totp(
user, data["code"], data.get("is_backup_code", False)
)
# Create full session
user_session = AuthService.create_session(user)
# Clear temporary session
session.pop("totp_pending_user_id", None)
return api_response(
data={
"user": user.to_dict(),
"token": user_session.token,
"expires_at": user_session.expires_at.isoformat() + "Z"
if user_session.expires_at.isoformat()[-1] != "Z"
else user_session.expires_at.isoformat(),
},
message="TOTP verification successful",
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
except InvalidCredentialsError as e:
return api_response(
success=False,
message=e.message,
status=e.status_code,
error_type=e.error_type,
)
@api_v1_bp.route("/auth/totp/disable", methods=["DELETE"])
@login_required
def disable_totp():
"""
Disable TOTP for the current user.
Request body:
password: User's current password for verification
Returns:
200: TOTP disabled successfully
400: Validation error
401: Not authenticated or invalid password
401: TOTP not enabled
"""
try:
# Validate request data
schema = TOTPDisableSchema()
data = schema.load(request.json)
# Disable TOTP
AuthService.disable_totp(g.current_user, data["password"])
return api_response(
message="TOTP disabled successfully",
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
except InvalidCredentialsError as e:
return api_response(
success=False,
message=e.message,
status=e.status_code,
error_type=e.error_type,
)
@api_v1_bp.route("/auth/totp/status", methods=["GET"])
@login_required
def get_totp_status():
"""
Get TOTP status for the current user.
Returns:
200: TOTP status with totp_enabled, verified_at, and backup_codes_remaining
401: Not authenticated
"""
user = g.current_user
# Check if TOTP is enabled
totp_enabled = user.has_totp_enabled()
# Get TOTP method to check backup codes remaining
backup_codes_remaining = 0
verified_at = None
if totp_enabled:
totp_method = user.get_totp_method()
if totp_method and totp_method.provider_data:
backup_codes = totp_method.provider_data.get("backup_codes", [])
backup_codes_remaining = len(backup_codes)
if totp_method and totp_method.totp_verified_at:
verified_at = totp_method.totp_verified_at.isoformat() + "Z" if totp_method.totp_verified_at.isoformat()[-1] != "Z" else totp_method.totp_verified_at.isoformat()
return api_response(
data={
"totp_enabled": totp_enabled,
"verified_at": verified_at,
"backup_codes_remaining": backup_codes_remaining,
},
message="TOTP status retrieved successfully",
)
@api_v1_bp.route("/auth/totp/regenerate-backup-codes", methods=["POST"])
@login_required
def regenerate_totp_backup_codes():
"""
Generate new backup codes for TOTP.
Request body:
password: User's current password for verification
Returns:
200: New backup codes generated successfully
400: Validation error
401: Not authenticated or invalid password
401: TOTP not enabled
"""
try:
# Validate request data
schema = TOTPRegenerateBackupCodesSchema()
data = schema.load(request.json)
# Regenerate backup codes
backup_codes = AuthService.regenerate_totp_backup_codes(
g.current_user, data["password"]
)
return api_response(
data={
"backup_codes": backup_codes,
},
message="Backup codes regenerated successfully",
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
except InvalidCredentialsError as e:
return api_response(
success=False,
message=e.message,
status=e.status_code,
error_type=e.error_type,
)
# =============================================================================
# WebAuthn Passkey Endpoints
# =============================================================================
@api_v1_bp.route("/auth/webauthn/register/begin", methods=["POST"])
@login_required
def begin_webauthn_registration():
"""
Begin WebAuthn passkey registration.
Returns:
200: PublicKeyCredentialCreationOptions (raw JSON, no wrapper)
401: Not authenticated
"""
user = g.current_user
# Generate registration challenge
options = WebAuthnService.generate_registration_challenge(user)
# Return unwrapped JSON for WebAuthn
return jsonify(options), 200
@api_v1_bp.route("/auth/webauthn/register/complete", methods=["POST"])
@login_required
def complete_webauthn_registration():
"""
Complete WebAuthn passkey registration.
Request body:
id: Credential ID
rawId: Base64URL-encoded credential ID
type: "public-key"
response: Attestation response data
transports: List of transport types
Returns:
200: Registration successful
400: Validation error
401: Not authenticated
409: Credential already exists
"""
try:
# Validate request data
schema = WebAuthnRegistrationCompleteSchema()
data = schema.load(request.json)
# Extract challenge from client data
client_data = data.get("response", {}).get("clientDataJSON", "")
import base64
client_data_json = base64.urlsafe_b64decode(client_data + "==")
client_data_dict = json.loads(client_data_json)
challenge = client_data_dict.get("challenge")
if not challenge:
return api_response(
success=False,
message="Invalid challenge in client data",
status=400,
error_type="VALIDATION_ERROR",
)
# Verify registration response
auth_method = WebAuthnService.verify_registration_response(
g.current_user,
data,
challenge
)
return api_response(
data={
"credential": auth_method.to_webauthn_dict(),
},
message="Passkey registered successfully",
status=201,
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
except InvalidCredentialsError as e:
return api_response(
success=False,
message=e.message,
status=e.status_code,
error_type=e.error_type,
)
@api_v1_bp.route("/auth/webauthn/login/begin", methods=["POST"])
def begin_webauthn_login():
"""
Begin WebAuthn passkey login.
Request body:
email: User email address
Returns:
200: PublicKeyCredentialRequestOptions (raw JSON, no wrapper)
400: Validation error
404: User not found
"""
try:
# Validate request data
schema = WebAuthnLoginBeginSchema()
data = schema.load(request.json)
# Find user by email
from gatehouse_app.models.user import User
user = User.query.filter_by(
email=data["email"].lower(),
deleted_at=None
).first()
if not user:
return api_response(
success=False,
message="User not found",
status=404,
error_type="NOT_FOUND",
)
# Check if user has any WebAuthn credentials
if not user.has_webauthn_enabled():
return api_response(
success=False,
message="No passkeys found for this account",
status=404,
error_type="NOT_FOUND",
)
# Generate authentication challenge
options = WebAuthnService.generate_authentication_challenge(user)
# Store user_id in session for verification
session["webauthn_pending_user_id"] = user.id
# Return unwrapped JSON for WebAuthn
return jsonify(options), 200
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
@api_v1_bp.route("/auth/webauthn/login/complete", methods=["POST"])
def complete_webauthn_login():
"""
Complete WebAuthn passkey login.
Request body:
id: Credential ID
rawId: Base64URL-encoded credential ID
type: "public-key"
response: Assertion response data
Returns:
200: Login successful with session token
400: Validation error
401: Authentication failed
"""
try:
# Get user from session
user_id = session.get("webauthn_pending_user_id")
if not user_id:
return api_response(
success=False,
message="No pending WebAuthn verification. Please initiate login first.",
status=401,
error_type="AUTHENTICATION_ERROR",
)
# Validate request data
schema = WebAuthnLoginCompleteSchema()
data = schema.load(request.json)
# Get user from database
from gatehouse_app.models.user import User
user = User.query.get(user_id)
if not user:
return api_response(
success=False,
message="User not found",
status=401,
error_type="AUTHENTICATION_ERROR",
)
# Extract challenge from client data
client_data = data.get("response", {}).get("clientDataJSON", "")
import base64
client_data_json = base64.urlsafe_b64decode(client_data + "==")
client_data_dict = json.loads(client_data_json)
challenge = client_data_dict.get("challenge")
if not challenge:
return api_response(
success=False,
message="Invalid challenge in client data",
status=400,
error_type="VALIDATION_ERROR",
)
# Verify authentication response
WebAuthnService.verify_authentication_response(
user,
data,
challenge
)
# Create session
user_session = AuthService.create_session(user)
# Clear pending session
session.pop("webauthn_pending_user_id", None)
return api_response(
data={
"user": user.to_dict(),
"token": user_session.token,
"expires_at": user_session.expires_at.isoformat() + "Z"
if user_session.expires_at.isoformat()[-1] != "Z"
else user_session.expires_at.isoformat(),
},
message="Login successful",
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
except InvalidCredentialsError as e:
return api_response(
success=False,
message=e.message,
status=e.status_code,
error_type=e.error_type,
)
@api_v1_bp.route("/auth/webauthn/credentials", methods=["GET"])
@login_required
def list_webauthn_credentials():
"""
List all WebAuthn passkey credentials for the current user.
Returns:
200: List of credentials
401: Not authenticated
"""
user = g.current_user
credentials = WebAuthnService.get_user_credentials(user)
return api_response(
data={
"credentials": [cred.to_webauthn_dict() for cred in credentials],
"count": len(credentials),
},
message="Credentials retrieved successfully",
)
@api_v1_bp.route("/auth/webauthn/credentials/<credential_id>", methods=["DELETE"])
@login_required
def delete_webauthn_credential(credential_id):
"""
Delete a WebAuthn passkey credential.
Args:
credential_id: ID of the credential to delete
Returns:
200: Credential deleted successfully
401: Not authenticated
404: Credential not found
"""
user = g.current_user
# Check if this is the last credential
credential_count = user.get_webauthn_credential_count()
if credential_count <= 1:
return api_response(
success=False,
message="Cannot delete the last passkey. Add another passkey first.",
status=400,
error_type="BAD_REQUEST",
)
# Delete the credential
success = WebAuthnService.delete_credential(credential_id, user)
if not success:
return api_response(
success=False,
message="Credential not found",
status=404,
error_type="NOT_FOUND",
)
return api_response(
message="Passkey deleted successfully",
)
@api_v1_bp.route("/auth/webauthn/credentials/<credential_id>", methods=["PATCH"])
@login_required
def rename_webauthn_credential(credential_id):
"""
Rename a WebAuthn passkey credential.
Args:
credential_id: ID of the credential to rename
Request body:
name: New name for the credential
Returns:
200: Credential renamed successfully
400: Validation error
401: Not authenticated
404: Credential not found
"""
try:
# Validate request data
schema = WebAuthnCredentialRenameSchema()
data = schema.load(request.json)
# Rename the credential
success = WebAuthnService.rename_credential(
credential_id,
g.current_user,
data["name"]
)
if not success:
return api_response(
success=False,
message="Credential not found",
status=404,
error_type="NOT_FOUND",
)
# Get updated credential
credential = WebAuthnService.get_credential_by_id(credential_id, g.current_user)
return api_response(
data={
"credential": credential.to_webauthn_dict() if credential else None,
},
message="Passkey renamed successfully",
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
@api_v1_bp.route("/auth/webauthn/status", methods=["GET"])
@login_required
def get_webauthn_status():
"""
Get WebAuthn status for the current user.
Returns:
200: WebAuthn status with webauthn_enabled and credential_count
401: Not authenticated
"""
user = g.current_user
return api_response(
data={
"webauthn_enabled": user.has_webauthn_enabled(),
"credential_count": user.get_webauthn_credential_count(),
},
message="WebAuthn status retrieved successfully",
)
+372
View File
@@ -0,0 +1,372 @@
"""Organization endpoints."""
from flask import g, request
from marshmallow import ValidationError
from gatehouse_app.api.v1 import api_v1_bp
from gatehouse_app.utils.response import api_response
from gatehouse_app.utils.decorators import login_required, require_admin, require_owner
from gatehouse_app.schemas.organization_schema import (
OrganizationCreateSchema,
OrganizationUpdateSchema,
InviteMemberSchema,
UpdateMemberRoleSchema,
)
from gatehouse_app.services.organization_service import OrganizationService
from gatehouse_app.services.user_service import UserService
from gatehouse_app.utils.constants import OrganizationRole
@api_v1_bp.route("/organizations", methods=["POST"])
@login_required
def create_organization():
"""
Create a new organization.
Request body:
name: Organization name
slug: Organization slug (unique)
description: Optional description
logo_url: Optional logo URL
Returns:
201: Organization created successfully
400: Validation error
401: Not authenticated
409: Slug already exists
"""
try:
# Validate request data
schema = OrganizationCreateSchema()
data = schema.load(request.json)
# Create organization
org = OrganizationService.create_organization(
name=data["name"],
slug=data["slug"],
owner_user_id=g.current_user.id,
description=data.get("description"),
logo_url=data.get("logo_url"),
)
return api_response(
data={"organization": org.to_dict()},
message="Organization created successfully",
status=201,
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
@api_v1_bp.route("/organizations/<org_id>", methods=["GET"])
@login_required
def get_organization(org_id):
"""
Get organization by ID.
Args:
org_id: Organization ID
Returns:
200: Organization data
401: Not authenticated
403: Not a member
404: Organization not found
"""
org = OrganizationService.get_organization_by_id(org_id)
# Check if user is a member
if not org.is_member(g.current_user.id):
return api_response(
success=False,
message="You are not a member of this organization",
status=403,
error_type="AUTHORIZATION_ERROR",
)
return api_response(
data={
"organization": org.to_dict(),
"member_count": org.get_member_count(),
},
message="Organization retrieved successfully",
)
@api_v1_bp.route("/organizations/<org_id>", methods=["PATCH"])
@login_required
@require_admin
def update_organization(org_id):
"""
Update organization.
Args:
org_id: Organization ID
Request body:
name: Optional organization name
description: Optional description
logo_url: Optional logo URL
Returns:
200: Organization updated successfully
400: Validation error
401: Not authenticated
403: Not an admin
404: Organization not found
"""
try:
# Validate request data
schema = OrganizationUpdateSchema()
data = schema.load(request.json)
org = OrganizationService.get_organization_by_id(org_id)
# Update organization
org = OrganizationService.update_organization(
org=org,
user_id=g.current_user.id,
**data
)
return api_response(
data={"organization": org.to_dict()},
message="Organization updated successfully",
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
@api_v1_bp.route("/organizations/<org_id>", methods=["DELETE"])
@login_required
@require_owner
def delete_organization(org_id):
"""
Delete organization (soft delete).
Args:
org_id: Organization ID
Returns:
200: Organization deleted successfully
401: Not authenticated
403: Not the owner
404: Organization not found
"""
org = OrganizationService.get_organization_by_id(org_id)
OrganizationService.delete_organization(
org=org,
user_id=g.current_user.id,
soft=True,
)
return api_response(
message="Organization deleted successfully",
)
@api_v1_bp.route("/organizations/<org_id>/members", methods=["GET"])
@login_required
def get_organization_members(org_id):
"""
Get all members of an organization.
Args:
org_id: Organization ID
Returns:
200: List of members
401: Not authenticated
403: Not a member
404: Organization not found
"""
org = OrganizationService.get_organization_by_id(org_id)
# Check if user is a member
if not org.is_member(g.current_user.id):
return api_response(
success=False,
message="You are not a member of this organization",
status=403,
error_type="AUTHORIZATION_ERROR",
)
members_data = []
for member in org.members:
if member.deleted_at is None:
member_dict = member.to_dict()
member_dict["user"] = member.user.to_dict()
members_data.append(member_dict)
return api_response(
data={
"members": members_data,
"count": len(members_data),
},
message="Members retrieved successfully",
)
@api_v1_bp.route("/organizations/<org_id>/members", methods=["POST"])
@login_required
@require_admin
def add_organization_member(org_id):
"""
Add a member to the organization.
Args:
org_id: Organization ID
Request body:
email: User email to invite
role: Member role (owner, admin, member, guest)
Returns:
201: Member added successfully
400: Validation error
401: Not authenticated
403: Not an admin
404: Organization or user not found
409: User already a member
"""
try:
# Validate request data
schema = InviteMemberSchema()
data = schema.load(request.json)
org = OrganizationService.get_organization_by_id(org_id)
# Find user by email
user = UserService.get_user_by_email(data["email"])
if not user:
return api_response(
success=False,
message="User not found",
status=404,
error_type="NOT_FOUND",
)
# Add member
role = OrganizationRole(data["role"])
member = OrganizationService.add_member(
org=org,
user_id=user.id,
role=role,
inviter_id=g.current_user.id,
)
member_dict = member.to_dict()
member_dict["user"] = user.to_dict()
return api_response(
data={"member": member_dict},
message="Member added successfully",
status=201,
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
@api_v1_bp.route("/organizations/<org_id>/members/<user_id>", methods=["DELETE"])
@login_required
@require_admin
def remove_organization_member(org_id, user_id):
"""
Remove a member from the organization.
Args:
org_id: Organization ID
user_id: User ID to remove
Returns:
200: Member removed successfully
401: Not authenticated
403: Not an admin
404: Organization or member not found
"""
org = OrganizationService.get_organization_by_id(org_id)
OrganizationService.remove_member(
org=org,
user_id=user_id,
remover_id=g.current_user.id,
)
return api_response(
message="Member removed successfully",
)
@api_v1_bp.route("/organizations/<org_id>/members/<user_id>/role", methods=["PATCH"])
@login_required
@require_admin
def update_member_role(org_id, user_id):
"""
Update a member's role.
Args:
org_id: Organization ID
user_id: User ID
Request body:
role: New role (owner, admin, member, guest)
Returns:
200: Role updated successfully
400: Validation error
401: Not authenticated
403: Not an admin
404: Organization or member not found
"""
try:
# Validate request data
schema = UpdateMemberRoleSchema()
data = schema.load(request.json)
org = OrganizationService.get_organization_by_id(org_id)
# Update role
new_role = OrganizationRole(data["role"])
member = OrganizationService.update_member_role(
org=org,
user_id=user_id,
new_role=new_role,
updater_id=g.current_user.id,
)
member_dict = member.to_dict()
member_dict["user"] = member.user.to_dict()
return api_response(
data={"member": member_dict},
message="Member role updated successfully",
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
+155
View File
@@ -0,0 +1,155 @@
"""User endpoints."""
from flask import g, request
from marshmallow import ValidationError
from gatehouse_app.api.v1 import api_v1_bp
from gatehouse_app.utils.response import api_response
from gatehouse_app.utils.decorators import login_required
from gatehouse_app.schemas.user_schema import UserUpdateSchema, ChangePasswordSchema
from gatehouse_app.services.user_service import UserService
from gatehouse_app.services.auth_service import AuthService
@api_v1_bp.route("/users/me", methods=["GET"])
@login_required
def get_me():
"""
Get current user profile.
Returns:
200: User profile data
401: Not authenticated
"""
user = g.current_user
return api_response(
data={"user": user.to_dict()},
message="User profile retrieved successfully",
)
@api_v1_bp.route("/users/me", methods=["PATCH"])
@login_required
def update_me():
"""
Update current user profile.
Request body:
full_name: Optional full name
avatar_url: Optional avatar URL
Returns:
200: User updated successfully
400: Validation error
401: Not authenticated
"""
try:
# Validate request data
schema = UserUpdateSchema()
data = schema.load(request.json)
# Update user
user = UserService.update_user(g.current_user, **data)
return api_response(
data={"user": user.to_dict()},
message="Profile updated successfully",
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
@api_v1_bp.route("/users/me", methods=["DELETE"])
@login_required
def delete_me():
"""
Delete current user account (soft delete).
Returns:
200: Account deleted successfully
401: Not authenticated
"""
UserService.delete_user(g.current_user, soft=True)
return api_response(
message="Account deleted successfully",
)
@api_v1_bp.route("/users/me/password", methods=["POST"])
@login_required
def change_password():
"""
Change current user password.
Request body:
current_password: Current password
new_password: New password
new_password_confirm: New password confirmation
Returns:
200: Password changed successfully
400: Validation error
401: Not authenticated or invalid current password
"""
try:
# Validate request data
schema = ChangePasswordSchema()
data = schema.load(request.json)
# Verify passwords match
if data["new_password"] != data["new_password_confirm"]:
return api_response(
success=False,
message="New passwords do not match",
status=400,
error_type="VALIDATION_ERROR",
error_details={"new_password_confirm": ["Passwords do not match"]},
)
# Change password
AuthService.change_password(
user=g.current_user,
current_password=data["current_password"],
new_password=data["new_password"],
)
return api_response(
message="Password changed successfully",
)
except ValidationError as e:
return api_response(
success=False,
message="Validation failed",
status=400,
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
@api_v1_bp.route("/users/me/organizations", methods=["GET"])
@login_required
def get_my_organizations():
"""
Get all organizations current user is a member of.
Returns:
200: List of organizations
401: Not authenticated
"""
organizations = UserService.get_user_organizations(g.current_user)
return api_response(
data={
"organizations": [org.to_dict() for org in organizations],
"count": len(organizations),
},
message="Organizations retrieved successfully",
)
+40
View File
@@ -0,0 +1,40 @@
"""Exceptions package."""
from gatehouse_app.exceptions.base import BaseAPIException
from gatehouse_app.exceptions.auth_exceptions import (
UnauthorizedError,
ForbiddenError,
InvalidCredentialsError,
AccountSuspendedError,
AccountInactiveError,
SessionExpiredError,
InvalidTokenError,
)
from gatehouse_app.exceptions.validation_exceptions import (
ValidationError,
NotFoundError,
ConflictError,
BadRequestError,
RateLimitExceededError,
EmailAlreadyExistsError,
OrganizationNotFoundError,
UserNotFoundError,
)
__all__ = [
"BaseAPIException",
"UnauthorizedError",
"ForbiddenError",
"InvalidCredentialsError",
"AccountSuspendedError",
"AccountInactiveError",
"SessionExpiredError",
"InvalidTokenError",
"ValidationError",
"NotFoundError",
"ConflictError",
"BadRequestError",
"RateLimitExceededError",
"EmailAlreadyExistsError",
"OrganizationNotFoundError",
"UserNotFoundError",
]
@@ -0,0 +1,58 @@
"""Authentication and authorization exceptions."""
from gatehouse_app.exceptions.base import BaseAPIException
class UnauthorizedError(BaseAPIException):
"""Raised when authentication is required but not provided."""
status_code = 401
error_type = "AUTHENTICATION_ERROR"
message = "Authentication required"
class ForbiddenError(BaseAPIException):
"""Raised when user lacks permissions for the requested action."""
status_code = 403
error_type = "AUTHORIZATION_ERROR"
message = "You don't have permission to perform this action"
class InvalidCredentialsError(BaseAPIException):
"""Raised when login credentials are invalid."""
status_code = 401
error_type = "AUTHENTICATION_ERROR"
message = "Invalid email or password"
class AccountSuspendedError(BaseAPIException):
"""Raised when user account is suspended."""
status_code = 403
error_type = "AUTHORIZATION_ERROR"
message = "Your account has been suspended"
class AccountInactiveError(BaseAPIException):
"""Raised when user account is inactive."""
status_code = 403
error_type = "AUTHORIZATION_ERROR"
message = "Your account is inactive"
class SessionExpiredError(BaseAPIException):
"""Raised when user session has expired."""
status_code = 401
error_type = "AUTHENTICATION_ERROR"
message = "Your session has expired. Please log in again"
class InvalidTokenError(BaseAPIException):
"""Raised when authentication token is invalid."""
status_code = 401
error_type = "AUTHENTICATION_ERROR"
message = "Invalid authentication token"
+31
View File
@@ -0,0 +1,31 @@
"""Base exception classes."""
class BaseAPIException(Exception):
"""Base exception for all API errors."""
status_code = 500
error_type = "INTERNAL_ERROR"
message = "An unexpected error occurred"
def __init__(self, message=None, error_details=None):
"""
Initialize exception.
Args:
message: Custom error message
error_details: Additional error details dictionary
"""
super().__init__()
if message:
self.message = message
self.error_details = error_details or {}
def to_dict(self):
"""Convert exception to dictionary for API response."""
return {
"error_type": self.error_type,
"message": self.message,
"details": self.error_details,
"status_code": self.status_code,
}
@@ -0,0 +1,60 @@
"""Validation and resource exceptions."""
from gatehouse_app.exceptions.base import BaseAPIException
class ValidationError(BaseAPIException):
"""Raised when request data validation fails."""
status_code = 400
error_type = "VALIDATION_ERROR"
message = "Validation failed"
class NotFoundError(BaseAPIException):
"""Raised when a requested resource is not found."""
status_code = 404
error_type = "NOT_FOUND"
message = "Resource not found"
class ConflictError(BaseAPIException):
"""Raised when a resource conflict occurs."""
status_code = 409
error_type = "CONFLICT"
message = "Resource conflict"
class BadRequestError(BaseAPIException):
"""Raised when the request is malformed or invalid."""
status_code = 400
error_type = "BAD_REQUEST"
message = "Bad request"
class RateLimitExceededError(BaseAPIException):
"""Raised when rate limit is exceeded."""
status_code = 429
error_type = "RATE_LIMIT_EXCEEDED"
message = "Too many requests. Please try again later"
class EmailAlreadyExistsError(ConflictError):
"""Raised when attempting to register with an existing email."""
message = "Email address already registered"
class OrganizationNotFoundError(NotFoundError):
"""Raised when organization is not found."""
message = "Organization not found"
class UserNotFoundError(NotFoundError):
"""Raised when user is not found."""
message = "User not found"
+26
View File
@@ -0,0 +1,26 @@
"""Flask extensions initialization."""
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate
from flask_bcrypt import Bcrypt
from flask_cors import CORS
from flask_marshmallow import Marshmallow
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from flask_session import Session
import redis
# Initialize extensions
db = SQLAlchemy()
migrate = Migrate()
bcrypt = Bcrypt()
cors = CORS()
ma = Marshmallow()
limiter = Limiter(
key_func=get_remote_address,
default_limits=["100 per hour"],
storage_uri="memory://", # Will be overridden by config
)
session = Session()
# Redis client - will be initialized with app
redis_client = None
+6
View File
@@ -0,0 +1,6 @@
"""Middleware package."""
from gatehouse_app.middleware.request_id import RequestIDMiddleware
from gatehouse_app.middleware.security_headers import SecurityHeadersMiddleware
from gatehouse_app.middleware.cors import setup_cors
__all__ = ["RequestIDMiddleware", "SecurityHeadersMiddleware", "setup_cors"]
+64
View File
@@ -0,0 +1,64 @@
"""CORS middleware configuration."""
from flask import request, make_response
def setup_cors(app):
"""
Configure CORS for the application.
Args:
app: Flask application instance
"""
@app.before_request
def handle_preflight():
"""Handle CORS preflight OPTIONS requests."""
if request.method == "OPTIONS":
origin = request.headers.get("Origin")
cors_origins = app.config.get("CORS_ORIGINS", [])
# Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard)
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
if allow_all:
response = make_response("", 204)
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID, Cache-Control, Pragma"
response.headers["Access-Control-Max-Age"] = "3600"
response.headers["Cache-Control"] = "no-cache, no-store"
return response
elif origin and origin in cors_origins:
response = make_response("", 204)
response.headers["Access-Control-Allow-Origin"] = origin
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID, Cache-Control, Pragma"
response.headers["Access-Control-Allow-Credentials"] = "true"
response.headers["Access-Control-Max-Age"] = "3600"
response.headers["Cache-Control"] = "no-cache, no-store"
return response
@app.after_request
def after_request_cors(response):
"""Add additional CORS headers if needed."""
origin = request.headers.get("Origin")
cors_origins = app.config.get("CORS_ORIGINS", [])
# Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard)
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
if allow_all:
# When allowing all origins, set header to "*"
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID, Cache-Control, Pragma"
response.headers["Access-Control-Max-Age"] = "3600"
elif origin and origin in cors_origins:
# When allowing specific origins, echo the request origin
response.headers["Access-Control-Allow-Origin"] = origin
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID, Cache-Control, Pragma"
response.headers["Access-Control-Allow-Credentials"] = "true"
response.headers["Access-Control-Max-Age"] = "3600"
return response
+38
View File
@@ -0,0 +1,38 @@
"""Request ID middleware for request tracing."""
import uuid
from flask import g, request
class RequestIDMiddleware:
"""Middleware to add unique request ID to each request."""
def __init__(self, app=None):
"""Initialize middleware."""
self.app = app
if app is not None:
self.init_app(app)
def init_app(self, app):
"""Initialize with Flask app."""
app.before_request(self.before_request)
app.after_request(self.after_request)
@staticmethod
def before_request():
"""Generate or extract request ID before request processing."""
# Check if request already has an ID from client
request_id = request.headers.get("X-Request-ID")
# Generate new ID if not provided
if not request_id:
request_id = str(uuid.uuid4())
# Store in Flask g object for access throughout request
g.request_id = request_id
@staticmethod
def after_request(response):
"""Add request ID to response headers."""
if hasattr(g, "request_id"):
response.headers["X-Request-ID"] = g.request_id
return response
@@ -0,0 +1,65 @@
"""Security headers middleware."""
from flask import request
class SecurityHeadersMiddleware:
"""Middleware to add security headers to responses."""
def __init__(self, app=None):
"""Initialize middleware."""
self.app = app
if app is not None:
self.init_app(app)
def init_app(self, app):
"""Initialize with Flask app."""
app.after_request(self.add_security_headers)
@staticmethod
def add_security_headers(response):
"""Add security headers to response."""
# Prevent MIME type sniffing
response.headers["X-Content-Type-Options"] = "nosniff"
# Enable XSS protection
response.headers["X-XSS-Protection"] = "1; mode=block"
# Prevent clickjacking
response.headers["X-Frame-Options"] = "DENY"
# Strict Transport Security (HSTS)
if request.is_secure:
response.headers["Strict-Transport-Security"] = (
"max-age=31536000; includeSubDomains"
)
# Content Security Policy
response.headers["Content-Security-Policy"] = (
"default-src 'self'; "
"script-src 'self' 'unsafe-inline'; "
"style-src 'self' 'unsafe-inline'; "
"img-src 'self' data: https:; "
"font-src 'self' data:; "
"connect-src 'self'"
)
# Referrer Policy
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
# Permissions Policy
response.headers["Permissions-Policy"] = (
"geolocation=(), microphone=(), camera=()"
)
# Cache-Control: Allow OIDC endpoints to set their own Cache-Control
# Only set no-cache for API responses that haven't set their own cache headers
if "Cache-Control" not in response.headers:
# Check if this is a JSON API response (shouldn't be cached)
content_type = response.headers.get("Content-Type", "")
if "application/json" in content_type:
response.headers["Cache-Control"] = "no-cache, no-store"
elif "text/html" not in content_type:
# For non-HTML responses, add Pragma for HTTP/1.0 compatibility
response.headers["Pragma"] = "no-cache"
return response
+30
View File
@@ -0,0 +1,30 @@
"""Models package."""
from gatehouse_app.models.base import BaseModel
from gatehouse_app.models.user import User
from gatehouse_app.models.organization import Organization
from gatehouse_app.models.organization_member import OrganizationMember
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.models.session import Session
from gatehouse_app.models.audit_log import AuditLog
from gatehouse_app.models.oidc_client import OIDCClient
from gatehouse_app.models.oidc_authorization_code import OIDCAuthCode
from gatehouse_app.models.oidc_refresh_token import OIDCRefreshToken
from gatehouse_app.models.oidc_session import OIDCSession
from gatehouse_app.models.oidc_token_metadata import OIDCTokenMetadata
from gatehouse_app.models.oidc_audit_log import OIDCAuditLog
__all__ = [
"BaseModel",
"User",
"Organization",
"OrganizationMember",
"AuthenticationMethod",
"Session",
"AuditLog",
"OIDCClient",
"OIDCAuthCode",
"OIDCRefreshToken",
"OIDCSession",
"OIDCTokenMetadata",
"OIDCAuditLog",
]
+62
View File
@@ -0,0 +1,62 @@
"""Audit log model."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import AuditAction
class AuditLog(BaseModel):
"""Audit log model for tracking user and system actions."""
__tablename__ = "audit_logs"
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True, index=True)
action = db.Column(db.Enum(AuditAction), nullable=False, index=True)
# Context
resource_type = db.Column(db.String(50), nullable=True, index=True)
resource_id = db.Column(db.String(36), nullable=True, index=True)
organization_id = db.Column(db.String(36), nullable=True, index=True)
# Request details
ip_address = db.Column(db.String(45), nullable=True)
user_agent = db.Column(db.Text, nullable=True)
request_id = db.Column(db.String(36), nullable=True, index=True)
# Additional data
extra_data = db.Column(db.JSON, nullable=True)
description = db.Column(db.Text, nullable=True)
# Success/failure
success = db.Column(db.Boolean, default=True, nullable=False)
error_message = db.Column(db.Text, nullable=True)
# Relationships
user = db.relationship("User", back_populates="audit_logs")
# Indexes for common queries
__table_args__ = (
db.Index("idx_audit_user_action", "user_id", "action"),
db.Index("idx_audit_resource", "resource_type", "resource_id"),
db.Index("idx_audit_org", "organization_id", "created_at"),
)
def __repr__(self):
"""String representation of AuditLog."""
return f"<AuditLog action={self.action} user_id={self.user_id}>"
@classmethod
def log(cls, action, user_id=None, **kwargs):
"""
Create an audit log entry.
Args:
action: AuditAction enum value
user_id: ID of the user performing the action
**kwargs: Additional audit log fields
Returns:
AuditLog instance
"""
log_entry = cls(action=action, user_id=user_id, **kwargs)
log_entry.save()
return log_entry
@@ -0,0 +1,93 @@
"""Authentication method model."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import AuthMethodType
class AuthenticationMethod(BaseModel):
"""Authentication method model storing user authentication credentials."""
__tablename__ = "authentication_methods"
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=False, index=True)
method_type = db.Column(db.Enum(AuthMethodType), nullable=False, index=True)
# For password authentication
password_hash = db.Column(db.String(255), nullable=True)
# For OAuth/OIDC providers
provider_user_id = db.Column(db.String(255), nullable=True)
provider_data = db.Column(db.JSON, nullable=True)
# For TOTP authentication
totp_secret = db.Column(db.String(32), nullable=True)
totp_backup_codes = db.Column(db.JSON, nullable=True)
totp_verified_at = db.Column(db.DateTime, nullable=True)
# Metadata
is_primary = db.Column(db.Boolean, default=False, nullable=False)
verified = db.Column(db.Boolean, default=False, nullable=False)
last_used_at = db.Column(db.DateTime, nullable=True)
# Relationships
user = db.relationship("User", back_populates="authentication_methods")
# Ensure unique provider combinations
__table_args__ = (
db.Index("idx_user_method", "user_id", "method_type"),
db.UniqueConstraint(
"user_id", "method_type", "provider_user_id", name="uix_user_method_provider"
),
)
def __repr__(self):
"""String representation of AuthenticationMethod."""
return f"<AuthenticationMethod user_id={self.user_id} type={self.method_type}>"
def is_password(self):
"""Check if this is a password authentication method."""
return self.method_type == AuthMethodType.PASSWORD
def is_oauth(self):
"""Check if this is an OAuth authentication method."""
return self.method_type in [
AuthMethodType.GOOGLE,
AuthMethodType.GITHUB,
AuthMethodType.MICROSOFT,
]
def is_totp(self):
"""Check if this is a TOTP authentication method."""
return self.method_type == AuthMethodType.TOTP
def is_webauthn(self):
"""Check if this is a WebAuthn authentication method."""
return self.method_type == AuthMethodType.WEBAUTHN
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Always exclude password hash and TOTP secrets
exclude.append("password_hash")
exclude.append("totp_secret")
exclude.append("totp_backup_codes")
return super().to_dict(exclude=exclude)
def to_webauthn_dict(self):
"""Convert WebAuthn credential to public dictionary.
Returns:
Dictionary with safe-to-expose credential information.
"""
if not self.is_webauthn() or not self.provider_data:
return None
data = self.provider_data
return {
"id": data.get("credential_id"),
"name": data.get("name"),
"transports": data.get("transports", []),
"created_at": data.get("created_at"),
"last_used_at": data.get("last_used_at"),
"sign_count": data.get("sign_count", 0),
}
+73
View File
@@ -0,0 +1,73 @@
"""Base model with common fields and functionality."""
import uuid
from datetime import datetime, timezone
from gatehouse_app.extensions import db
class BaseModel(db.Model):
"""Base model class with common fields."""
__abstract__ = True
id = db.Column(
db.String(36),
primary_key=True,
default=lambda: str(uuid.uuid4()),
unique=True,
nullable=False,
)
created_at = db.Column(db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc))
updated_at = db.Column(
db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)
)
deleted_at = db.Column(db.DateTime, nullable=True)
def save(self):
"""Save the model instance to database."""
db.session.add(self)
db.session.commit()
return self
def delete(self, soft=True):
"""
Delete the model instance.
Args:
soft: If True, performs soft delete. If False, hard delete.
"""
if soft:
self.deleted_at = datetime.now(timezone.utc)
db.session.commit()
else:
db.session.delete(self)
db.session.commit()
def update(self, **kwargs):
"""Update model fields."""
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
self.updated_at = datetime.now(timezone.utc)
db.session.commit()
return self
def to_dict(self, exclude=None):
"""
Convert model to dictionary.
Args:
exclude: List of fields to exclude from output
Returns:
Dictionary representation of the model
"""
exclude = exclude or []
result = {}
for column in self.__table__.columns:
if column.name not in exclude:
value = getattr(self, column.name)
if isinstance(value, datetime):
result[column.name] = value.isoformat()
else:
result[column.name] = value
return result
+231
View File
@@ -0,0 +1,231 @@
"""OIDC Audit Log model for comprehensive OIDC event tracking."""
from datetime import datetime
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OIDCAuditLog(BaseModel):
"""OIDC Audit Log model for comprehensive OIDC event tracking.
This model logs all OIDC-related events for security, compliance,
and debugging purposes.
"""
__tablename__ = "oidc_audit_logs"
# Event type categorization
event_type = db.Column(db.String(100), nullable=False, index=True)
# Client and User references
client_id = db.Column(
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=True, index=True
)
user_id = db.Column(
db.String(36), db.ForeignKey("users.id"), nullable=True, index=True
)
# Event outcome
success = db.Column(db.Boolean, default=True, nullable=False, index=True)
# Error details (for failed events)
error_code = db.Column(db.String(100), nullable=True)
error_description = db.Column(db.Text, nullable=True)
# Request context
ip_address = db.Column(db.String(45), nullable=True, index=True)
user_agent = db.Column(db.Text, nullable=True)
request_id = db.Column(db.String(36), nullable=True, index=True)
# Additional event metadata
event_metadata = db.Column(db.JSON, nullable=True)
# Relationships
client = db.relationship("OIDCClient", back_populates="audit_logs")
user = db.relationship("User", back_populates="oidc_audit_logs")
def __repr__(self):
"""String representation of OIDCAuditLog."""
status = "success" if self.success else "failed"
return f"<OIDCAuditLog event={self.event_type} status={status} client={self.client_id}>"
@classmethod
def log_event(cls, event_type, client_id=None, user_id=None, success=True,
error_code=None, error_description=None, ip_address=None,
user_agent=None, request_id=None, event_metadata=None):
"""Log an OIDC event.
Args:
event_type: Type of event (e.g., "authorization_request", "token_issue")
client_id: The OIDC client ID
user_id: The user ID
success: Whether the event was successful
error_code: Error code if event failed
error_description: Error description if event failed
ip_address: Client IP address
user_agent: Client user agent
request_id: Request ID for correlation
event_metadata: Additional event metadata
Returns:
OIDCAuditLog instance
"""
log = cls(
event_type=event_type,
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
event_metadata=event_metadata,
)
db.session.add(log)
db.session.commit()
return log
@classmethod
def log_authorization_request(cls, client_id, user_id, redirect_uri, scope,
ip_address=None, user_agent=None, request_id=None,
success=True, error_code=None, error_description=None):
"""Log an authorization request event."""
return cls.log_event(
event_type="authorization_request",
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
event_metadata={
"redirect_uri": redirect_uri,
"scope": scope,
}
)
@classmethod
def log_token_issue(cls, client_id, user_id, token_type,
ip_address=None, user_agent=None, request_id=None):
"""Log a token issuance event."""
return cls.log_event(
event_type="token_issue",
client_id=client_id,
user_id=user_id,
success=True,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
event_metadata={"token_type": token_type}
)
@classmethod
def log_token_revocation(cls, client_id, user_id, token_type, reason=None,
ip_address=None, user_agent=None, request_id=None):
"""Log a token revocation event."""
return cls.log_event(
event_type="token_revocation",
client_id=client_id,
user_id=user_id,
success=True,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
event_metadata={
"token_type": token_type,
"reason": reason,
}
)
@classmethod
def log_authentication_failure(cls, client_id, error_code, error_description,
ip_address=None, user_agent=None, request_id=None):
"""Log an authentication failure event."""
return cls.log_event(
event_type="authentication_failure",
client_id=client_id,
success=False,
error_code=error_code,
error_description=error_description,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
)
@classmethod
def get_events_for_user(cls, user_id, limit=100):
"""Get audit events for a user.
Args:
user_id: The user ID
limit: Maximum number of events to return
Returns:
List of OIDCAuditLog instances
"""
return cls.query.filter_by(user_id=user_id, deleted_at=None)\
.order_by(cls.created_at.desc())\
.limit(limit)\
.all()
@classmethod
def get_events_for_client(cls, client_id, limit=100):
"""Get audit events for a client.
Args:
client_id: The client ID
limit: Maximum number of events to return
Returns:
List of OIDCAuditLog instances
"""
return cls.query.filter_by(client_id=client_id, deleted_at=None)\
.order_by(cls.created_at.desc())\
.limit(limit)\
.all()
@classmethod
def get_failed_events(cls, client_id=None, user_id=None, start_date=None,
end_date=None, limit=100):
"""Get failed audit events.
Args:
client_id: Optional client ID filter
user_id: Optional user ID filter
start_date: Optional start date filter
end_date: Optional end date filter
limit: Maximum number of events to return
Returns:
List of OIDCAuditLog instances
"""
query = cls.query.filter_by(success=False, deleted_at=None)
if client_id:
query = query.filter_by(client_id=client_id)
if user_id:
query = query.filter_by(user_id=user_id)
if start_date:
query = query.filter(cls.created_at >= start_date)
if end_date:
query = query.filter(cls.created_at <= end_date)
return query.order_by(cls.created_at.desc()).limit(limit).all()
def to_dict(self, exclude=None):
"""Convert to dictionary."""
return super().to_dict(exclude=exclude)
# Add relationship back to User model
from gatehouse_app.models.user import User
User.oidc_audit_logs = db.relationship(
"OIDCAuditLog", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from gatehouse_app.models.oidc_client import OIDCClient
OIDCClient.audit_logs = db.relationship(
"OIDCAuditLog", back_populates="client", cascade="all, delete-orphan"
)
@@ -0,0 +1,125 @@
"""OIDC Authorization Code model for auth code flow."""
from datetime import datetime, timedelta, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OIDCAuthCode(BaseModel):
"""OIDC Authorization Code model for authorization code flow.
Authorization codes are single-use, short-lived codes used in the
authorization code grant flow. The code is hashed for security.
"""
__tablename__ = "oidc_authorization_codes"
# Client and User references
client_id = db.Column(
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
)
user_id = db.Column(
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
)
# Authorization code (hashed for security)
code_hash = db.Column(db.String(255), nullable=False)
# Request parameters
redirect_uri = db.Column(db.String(512), nullable=False)
scope = db.Column(db.JSON, nullable=True) # Requested scopes
nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation
code_verifier = db.Column(db.String(255), nullable=True) # For PKCE
# Status tracking
expires_at = db.Column(db.DateTime, nullable=False, index=True)
used_at = db.Column(db.DateTime, nullable=True)
is_used = db.Column(db.Boolean, default=False, nullable=False)
# Request metadata
ip_address = db.Column(db.String(45), nullable=True)
user_agent = db.Column(db.Text, nullable=True)
# Relationships
client = db.relationship("OIDCClient", back_populates="authorization_codes")
user = db.relationship("User", back_populates="oidc_auth_codes")
def __repr__(self):
"""String representation of OIDCAuthCode."""
return f"<OIDCAuthCode client_id={self.client_id} user_id={self.user_id} used={self.is_used}>"
def is_expired(self):
"""Check if the authorization code has expired."""
# Handle both timezone-aware and timezone-naive expires_at values
expires_at = self.expires_at
if expires_at.tzinfo is None:
# Make naive datetime timezone-aware (UTC)
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
def is_valid(self):
"""Check if the authorization code is valid for use."""
return not self.is_used and not self.is_expired() and self.deleted_at is None
def mark_as_used(self):
"""Mark the authorization code as used."""
self.is_used = True
self.used_at = datetime.now(timezone.utc)
db.session.commit()
@classmethod
def create_code(cls, client_id, user_id, code_hash, redirect_uri, scope=None,
nonce=None, code_verifier=None, ip_address=None, user_agent=None,
lifetime_seconds=600):
"""Create a new authorization code.
Args:
client_id: The OIDC client ID
user_id: The user ID
code_hash: Hashed authorization code
redirect_uri: The redirect URI
scope: Requested scopes
nonce: OIDC nonce
code_verifier: PKCE code verifier
ip_address: Client IP address
user_agent: Client user agent
lifetime_seconds: Code lifetime in seconds (default 10 minutes)
Returns:
OIDCAuthCode instance
"""
code = cls(
client_id=client_id,
user_id=user_id,
code_hash=code_hash,
redirect_uri=redirect_uri,
scope=scope,
nonce=nonce,
code_verifier=code_verifier,
expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
ip_address=ip_address,
user_agent=user_agent,
)
db.session.add(code)
db.session.commit()
return code
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Always exclude code hash
exclude.append("code_hash")
exclude.append("code_verifier")
return super().to_dict(exclude=exclude)
# Add relationship back to User model
from gatehouse_app.models.user import User
User.oidc_auth_codes = db.relationship(
"OIDCAuthCode", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from gatehouse_app.models.oidc_client import OIDCClient
OIDCClient.authorization_codes = db.relationship(
"OIDCAuthCode", back_populates="client", cascade="all, delete-orphan"
)
+69
View File
@@ -0,0 +1,69 @@
"""OIDC Client model."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import OIDCGrantType, OIDCResponseType
class OIDCClient(BaseModel):
"""OIDC client model for OAuth2/OIDC integrations."""
__tablename__ = "oidc_clients"
organization_id = db.Column(
db.String(36), db.ForeignKey("organizations.id"), nullable=False, index=True
)
name = db.Column(db.String(255), nullable=False)
client_id = db.Column(db.String(255), unique=True, nullable=False, index=True)
client_secret_hash = db.Column(db.String(255), nullable=False)
# OAuth/OIDC configuration
redirect_uris = db.Column(db.JSON, nullable=False) # List of allowed redirect URIs
grant_types = db.Column(db.JSON, nullable=False) # List of allowed grant types
response_types = db.Column(db.JSON, nullable=False) # List of allowed response types
scopes = db.Column(db.JSON, nullable=False) # List of allowed scopes
# Client metadata
logo_uri = db.Column(db.String(512), nullable=True)
client_uri = db.Column(db.String(512), nullable=True)
policy_uri = db.Column(db.String(512), nullable=True)
tos_uri = db.Column(db.String(512), nullable=True)
# Settings
is_active = db.Column(db.Boolean, default=True, nullable=False)
is_confidential = db.Column(db.Boolean, default=True, nullable=False)
require_pkce = db.Column(db.Boolean, default=True, nullable=False)
# Token lifetimes (in seconds)
access_token_lifetime = db.Column(db.Integer, default=3600, nullable=False)
refresh_token_lifetime = db.Column(db.Integer, default=2592000, nullable=False)
id_token_lifetime = db.Column(db.Integer, default=3600, nullable=False)
# Relationships
organization = db.relationship("Organization", back_populates="oidc_clients")
def __repr__(self):
"""String representation of OIDCClient."""
return f"<OIDCClient {self.name} client_id={self.client_id}>"
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Always exclude client secret
exclude.append("client_secret_hash")
return super().to_dict(exclude=exclude)
def has_grant_type(self, grant_type):
"""Check if client supports a specific grant type."""
return grant_type in self.grant_types
def has_response_type(self, response_type):
"""Check if client supports a specific response type."""
return response_type in self.response_types
def is_redirect_uri_allowed(self, redirect_uri):
"""Check if a redirect URI is allowed for this client."""
return redirect_uri in self.redirect_uris
def has_scope(self, scope):
"""Check if client is allowed to request a specific scope."""
return scope in self.scopes
+77
View File
@@ -0,0 +1,77 @@
"""OIDC JWKS Key model for persisting signing keys."""
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OidcJwksKey(BaseModel):
"""
OIDC JWKS Key model for persisting JSON Web Key Set signing keys.
This model stores RSA/ECDSA key pairs used for signing OIDC tokens.
Multiple keys can be stored to support key rotation scenarios.
Attributes:
id: Integer primary key
kid: Unique key ID used in JWT "kid" header
key_type: Type of key (e.g., "RSA", "EC")
private_key: PEM-encoded private key
public_key: PEM-encoded public key
algorithm: Signing algorithm (e.g., "RS256", "ES256")
created_at: When the key was created
is_active: Whether this key is currently active for signing
is_primary: Whether this is the primary signing key
expires_at: ...
"""
__tablename__ = "oidc_jwks_keys"
# Override the default UUID id with integer primary key
id = db.Column(db.Integer, primary_key=True)
expires_at = db.Column(db.DateTime, nullable=True)
# Key identification and type
kid = db.Column(db.String(255), unique=True, nullable=False, index=True)
key_type = db.Column(db.String(50), nullable=False) # e.g., "RSA", "EC"
algorithm = db.Column(db.String(50), nullable=False) # e.g., "RS256", "ES256"
# Key material (PEM-encoded)
private_key = db.Column(db.Text, nullable=False)
public_key = db.Column(db.Text, nullable=False)
# Key status
is_active = db.Column(db.Boolean, default=True, nullable=False)
is_primary = db.Column(db.Boolean, default=False, nullable=False)
def __repr__(self):
"""String representation of OidcJwksKey."""
return f"<OidcJwksKey kid={self.kid} key_type={self.key_type} algorithm={self.algorithm}>"
def to_dict(self, exclude_private_key=True):
"""
Convert model to dictionary.
Args:
exclude_private_key: If True, excludes the private key from output
Returns:
Dictionary representation of the model
"""
exclude = ["private_key"] if exclude_private_key else []
return super().to_dict(exclude=exclude)
@classmethod
def get_active_keys(cls):
"""Get all active keys for signing operations."""
return cls.query.filter(cls.is_active == True).all()
@classmethod
def get_primary_key(cls):
"""Get the primary signing key."""
return cls.query.filter(cls.is_primary == True).first()
@classmethod
def get_key_by_kid(cls, kid):
"""Get a key by its key ID."""
return cls.query.filter(cls.kid == kid, cls.is_active == True).first()
+163
View File
@@ -0,0 +1,163 @@
"""OIDC Refresh Token model for token rotation."""
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OIDCRefreshToken(BaseModel):
"""OIDC Refresh Token model for token refresh and rotation.
Refresh tokens are long-lived credentials used to obtain new access tokens.
They support token rotation for enhanced security.
"""
__tablename__ = "oidc_refresh_tokens"
# Client and User references
client_id = db.Column(
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
)
user_id = db.Column(
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
)
# Token (hashed for security)
token_hash = db.Column(db.String(255), nullable=False, unique=True, index=True)
# Associated access token ID
access_token_id = db.Column(
db.String(36), db.ForeignKey("sessions.id"), nullable=True, index=True
)
# Token scope
scope = db.Column(db.JSON, nullable=True) # Granted scopes
# Timing
expires_at = db.Column(db.DateTime, nullable=False, index=True)
# Revocation tracking
revoked_at = db.Column(db.DateTime, nullable=True)
revoked_reason = db.Column(db.String(255), nullable=True)
# Token rotation metadata
previous_token_hash = db.Column(db.String(255), nullable=True) # For rotation
rotation_count = db.Column(db.Integer, default=0, nullable=False)
# Request metadata
ip_address = db.Column(db.String(45), nullable=True)
user_agent = db.Column(db.Text, nullable=True)
# Relationships
client = db.relationship("OIDCClient", back_populates="refresh_tokens")
user = db.relationship("User", back_populates="oidc_refresh_tokens")
access_token = db.relationship("Session", back_populates="oidc_refresh_token")
def __repr__(self):
"""String representation of OIDCRefreshToken."""
return f"<OIDCRefreshToken client_id={self.client_id} user_id={self.user_id} revoked={self.is_revoked()}>"
def is_expired(self):
"""Check if the refresh token has expired."""
# Handle both timezone-aware and timezone-naive expires_at values
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
def is_revoked(self):
"""Check if the refresh token has been revoked."""
return self.revoked_at is not None
def is_valid(self):
"""Check if the refresh token is valid for use."""
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
def revoke(self, reason=None):
"""Revoke the refresh token.
Args:
reason: Optional reason for revocation
"""
self.revoked_at = datetime.now(timezone.utc)
self.revoked_reason = reason
db.session.commit()
def rotate(self, new_token_hash):
"""Rotate the refresh token (invalidate old, create new).
Args:
new_token_hash: Hash of the new refresh token
Returns:
self for chaining
"""
# Store reference to old token
self.previous_token_hash = self.token_hash
self.token_hash = new_token_hash
self.rotation_count += 1
# Extend expiration on rotation
from datetime import timedelta
self.expires_at = datetime.now(timezone.utc) + timedelta(days=30)
db.session.commit()
return self
@classmethod
def create_token(cls, client_id, user_id, token_hash, scope=None,
access_token_id=None, ip_address=None, user_agent=None,
lifetime_seconds=2592000):
"""Create a new refresh token.
Args:
client_id: The OIDC client ID
user_id: The user ID
token_hash: Hashed refresh token
scope: Granted scopes
access_token_id: Associated access token ID
ip_address: Client IP address
user_agent: Client user agent
lifetime_seconds: Token lifetime in seconds (default 30 days)
Returns:
OIDCRefreshToken instance
"""
from datetime import timedelta
token = cls(
client_id=client_id,
user_id=user_id,
token_hash=token_hash,
scope=scope,
access_token_id=access_token_id,
expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
ip_address=ip_address,
user_agent=user_agent,
)
db.session.add(token)
db.session.commit()
return token
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Always exclude token hashes
exclude.append("token_hash")
exclude.append("previous_token_hash")
return super().to_dict(exclude=exclude)
# Add relationship back to User model
from gatehouse_app.models.user import User
User.oidc_refresh_tokens = db.relationship(
"OIDCRefreshToken", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from gatehouse_app.models.oidc_client import OIDCClient
OIDCClient.refresh_tokens = db.relationship(
"OIDCRefreshToken", back_populates="client", cascade="all, delete-orphan"
)
# Add relationship back to Session model
from gatehouse_app.models.session import Session
Session.oidc_refresh_token = db.relationship(
"OIDCRefreshToken", back_populates="access_token", uselist=False
)
+162
View File
@@ -0,0 +1,162 @@
"""OIDC Session model for OIDC session tracking."""
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OIDCSession(BaseModel):
"""OIDC Session model for tracking OIDC authentication sessions.
This model tracks the state during the OIDC authentication flow,
including PKCE parameters and nonce validation.
"""
__tablename__ = "oidc_sessions"
# User reference
user_id = db.Column(
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
)
# Client reference
client_id = db.Column(
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
)
# State management
state = db.Column(db.String(255), nullable=False, index=True)
nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation
# Authorization request parameters
redirect_uri = db.Column(db.String(512), nullable=False)
scope = db.Column(db.JSON, nullable=True) # Requested scopes
# PKCE parameters
code_challenge = db.Column(db.String(255), nullable=True)
code_challenge_method = db.Column(db.String(10), nullable=True) # "S256" or "plain"
# Timing
expires_at = db.Column(db.DateTime, nullable=False, index=True)
authenticated_at = db.Column(db.DateTime, nullable=True)
# Relationships
user = db.relationship("User", back_populates="oidc_sessions")
client = db.relationship("OIDCClient", back_populates="oidc_sessions")
def __repr__(self):
"""String representation of OIDCSession."""
return f"<OIDCSession user_id={self.user_id} client_id={self.client_id} state={self.state[:8]}...>"
def is_expired(self):
"""Check if the OIDC session has expired."""
return datetime.now(timezone.utc) > self.expires_at
def is_authenticated(self):
"""Check if the user has been authenticated in this session."""
return self.authenticated_at is not None
def mark_authenticated(self):
"""Mark the session as authenticated."""
self.authenticated_at = datetime.now(timezone.utc)
db.session.commit()
def validate_nonce(self, expected_nonce):
"""Validate the nonce matches the expected value.
Args:
expected_nonce: The expected nonce value
Returns:
bool: True if nonce matches
"""
return self.nonce == expected_nonce
def validate_code_challenge(self, code_verifier):
"""Validate the code verifier against the stored code challenge.
Args:
code_verifier: The PKCE code verifier
Returns:
bool: True if code challenge is valid
"""
if not self.code_challenge:
return False
if self.code_challenge_method == "S256":
import hashlib
import base64
# SHA256 hash of code_verifier
digest = hashlib.sha256(code_verifier.encode()).digest()
# Base64 URL encode without padding
expected = base64.urlsafe_b64encode(digest).decode().rstrip("=")
return self.code_challenge == expected
elif self.code_challenge_method == "plain":
return self.code_challenge == code_verifier
return False
@classmethod
def create_session(cls, user_id, client_id, state, redirect_uri, scope=None,
nonce=None, code_challenge=None, code_challenge_method=None,
lifetime_seconds=600):
"""Create a new OIDC session.
Args:
user_id: The user ID
client_id: The OIDC client ID
state: The state parameter
redirect_uri: The redirect URI
scope: Requested scopes
nonce: OIDC nonce
code_challenge: PKCE code challenge
code_challenge_method: PKCE method ("S256" or "plain")
lifetime_seconds: Session lifetime in seconds
Returns:
OIDCSession instance
"""
from datetime import timedelta
session = cls(
user_id=user_id,
client_id=client_id,
state=state,
redirect_uri=redirect_uri,
scope=scope,
nonce=nonce,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
)
db.session.add(session)
db.session.commit()
return session
@classmethod
def get_by_state(cls, state):
"""Get a session by state parameter.
Args:
state: The state parameter
Returns:
OIDCSession instance or None
"""
return cls.query.filter_by(state=state, deleted_at=None).first()
def to_dict(self, exclude=None):
"""Convert to dictionary."""
return super().to_dict(exclude=exclude)
# Add relationship back to User model
from gatehouse_app.models.user import User
User.oidc_sessions = db.relationship(
"OIDCSession", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from gatehouse_app.models.oidc_client import OIDCClient
OIDCClient.oidc_sessions = db.relationship(
"OIDCSession", back_populates="client", cascade="all, delete-orphan"
)
+196
View File
@@ -0,0 +1,196 @@
"""OIDC Token Metadata model for token revocation tracking."""
import uuid
from datetime import datetime, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class OIDCTokenMetadata(BaseModel):
"""OIDC Token Metadata model for tracking issued tokens.
This model stores metadata about issued tokens (access tokens, refresh tokens, ID tokens)
for the purpose of token revocation. The id field matches the JTI (JWT ID) claim.
"""
__tablename__ = "oidc_token_metadata"
# Token identifier (matches JTI in JWT)
id = db.Column(
db.String(36), primary_key=True, default=lambda: str(uuid.uuid4())
)
# Client and User references
client_id = db.Column(
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
)
user_id = db.Column(
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
)
# Token type
token_type = db.Column(db.String(50), nullable=False) # "access_token", "refresh_token", "id_token"
# Token identifier for revocation lookup
token_jti = db.Column(db.String(255), nullable=False, index=True) # JWT ID claim
# Timing
expires_at = db.Column(db.DateTime, nullable=False, index=True)
# Revocation tracking
revoked_at = db.Column(db.DateTime, nullable=True)
revoked_reason = db.Column(db.String(255), nullable=True)
# Relationships
client = db.relationship("OIDCClient", back_populates="token_metadata")
user = db.relationship("User", back_populates="oidc_token_metadata")
def __repr__(self):
"""String representation of OIDCTokenMetadata."""
return f"<OIDCTokenMetadata jti={self.token_jti[:8]}... type={self.token_type} revoked={self.is_revoked()}>"
def is_expired(self):
"""Check if the token has expired."""
# Handle both timezone-aware and timezone-naive expires_at values
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > expires_at
def is_revoked(self):
"""Check if the token has been revoked."""
return self.revoked_at is not None
def is_valid(self):
"""Check if the token is valid (not expired and not revoked)."""
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
def revoke(self, reason=None):
"""Revoke the token.
Args:
reason: Optional reason for revocation
"""
self.revoked_at = datetime.now(timezone.utc)
self.revoked_reason = reason
db.session.commit()
@classmethod
def create_metadata(cls, client_id, user_id, token_type, token_jti,
expires_at, ip_address=None, user_agent=None):
"""Create token metadata for tracking.
Args:
client_id: The OIDC client ID
user_id: The user ID
token_type: Type of token ("access_token", "refresh_token", "id_token")
token_jti: JWT ID claim
expires_at: Token expiration datetime
ip_address: Client IP address
user_agent: Client user agent
Returns:
OIDCTokenMetadata instance
"""
metadata = cls(
id=str(uuid.uuid4()),
client_id=client_id,
user_id=user_id,
token_type=token_type,
token_jti=token_jti,
expires_at=expires_at,
)
db.session.add(metadata)
db.session.commit()
return metadata
@classmethod
def get_by_jti(cls, token_jti):
"""Get token metadata by JWT ID.
Args:
token_jti: The JWT ID
Returns:
OIDCTokenMetadata instance or None
"""
return cls.query.filter_by(token_jti=token_jti, deleted_at=None).first()
@classmethod
def revoke_by_jti(cls, token_jti, reason=None):
"""Revoke a token by its JWT ID.
Args:
token_jti: The JWT ID
reason: Optional revocation reason
Returns:
bool: True if token was found and revoked
"""
metadata = cls.get_by_jti(token_jti)
if metadata:
metadata.revoke(reason)
return True
return False
@classmethod
def revoke_all_for_user(cls, user_id, client_id=None, reason=None):
"""Revoke all tokens for a user.
Args:
user_id: The user ID
client_id: Optional client ID to filter by
reason: Optional revocation reason
Returns:
int: Number of tokens revoked
"""
query = cls.query.filter_by(user_id=user_id, deleted_at=None)
if client_id:
query = query.filter_by(client_id=client_id)
tokens = query.filter(cls.revoked_at == None).all()
count = 0
for token in tokens:
token.revoke(reason)
count += 1
return count
@classmethod
def revoke_all_for_client(cls, client_id, user_id=None, reason=None):
"""Revoke all tokens for a client.
Args:
client_id: The client ID
user_id: Optional user ID to filter by
reason: Optional revocation reason
Returns:
int: Number of tokens revoked
"""
query = cls.query.filter_by(client_id=client_id, deleted_at=None)
if user_id:
query = query.filter_by(user_id=user_id)
tokens = query.filter(cls.revoked_at == None).all()
count = 0
for token in tokens:
token.revoke(reason)
count += 1
return count
def to_dict(self, exclude=None):
"""Convert to dictionary."""
return super().to_dict(exclude=exclude)
# Add relationship back to User model
from gatehouse_app.models.user import User
User.oidc_token_metadata = db.relationship(
"OIDCTokenMetadata", back_populates="user", cascade="all, delete-orphan"
)
# Add relationship back to OIDCClient model
from gatehouse_app.models.oidc_client import OIDCClient
OIDCClient.token_metadata = db.relationship(
"OIDCTokenMetadata", back_populates="client", cascade="all, delete-orphan"
)
+54
View File
@@ -0,0 +1,54 @@
"""Organization model."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class Organization(BaseModel):
"""Organization model representing a tenant/workspace."""
__tablename__ = "organizations"
name = db.Column(db.String(255), nullable=False)
slug = db.Column(db.String(255), unique=True, nullable=False, index=True)
description = db.Column(db.Text, nullable=True)
logo_url = db.Column(db.String(512), nullable=True)
is_active = db.Column(db.Boolean, default=True, nullable=False)
# Settings (stored as JSON)
settings = db.Column(db.JSON, nullable=True, default=dict)
# Relationships
members = db.relationship(
"OrganizationMember", back_populates="organization", cascade="all, delete-orphan"
)
oidc_clients = db.relationship(
"OIDCClient", back_populates="organization", cascade="all, delete-orphan"
)
def __repr__(self):
"""String representation of Organization."""
return f"<Organization {self.name}>"
def get_member_count(self):
"""Get the count of active members in the organization."""
return len([m for m in self.members if m.deleted_at is None])
def get_owner(self):
"""Get the owner of the organization."""
from gatehouse_app.utils.constants import OrganizationRole
for member in self.members:
if member.role == OrganizationRole.OWNER and member.deleted_at is None:
return member.user
return None
def is_member(self, user_id):
"""Check if a user is a member of the organization."""
from gatehouse_app.models.organization_member import OrganizationMember
return (
OrganizationMember.query.filter_by(
user_id=user_id, organization_id=self.id, deleted_at=None
).first()
is not None
)
@@ -0,0 +1,51 @@
"""Organization member model."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import OrganizationRole
class OrganizationMember(BaseModel):
"""Organization member model representing user membership in an organization."""
__tablename__ = "organization_members"
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=False, index=True)
organization_id = db.Column(
db.String(36), db.ForeignKey("organizations.id"), nullable=False, index=True
)
role = db.Column(
db.Enum(OrganizationRole), default=OrganizationRole.MEMBER, nullable=False
)
invited_by_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True)
invited_at = db.Column(db.DateTime, nullable=True)
joined_at = db.Column(db.DateTime, nullable=True)
# Relationships
user = db.relationship("User", foreign_keys=[user_id], back_populates="organization_memberships")
organization = db.relationship("Organization", back_populates="members")
invited_by = db.relationship("User", foreign_keys=[invited_by_id])
# Unique constraint to prevent duplicate memberships
__table_args__ = (
db.UniqueConstraint("user_id", "organization_id", name="uix_user_org"),
)
def __repr__(self):
"""String representation of OrganizationMember."""
return f"<OrganizationMember user_id={self.user_id} org_id={self.organization_id} role={self.role}>"
def is_owner(self):
"""Check if member is an owner."""
return self.role == OrganizationRole.OWNER
def is_admin(self):
"""Check if member is an admin or owner."""
return self.role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
def can_manage_members(self):
"""Check if member can manage other members."""
return self.is_admin()
def can_delete_organization(self):
"""Check if member can delete the organization."""
return self.is_owner()
+86
View File
@@ -0,0 +1,86 @@
"""Session model."""
from datetime import datetime, timedelta, timezone
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import SessionStatus
class Session(BaseModel):
"""Session model for tracking user sessions."""
__tablename__ = "sessions"
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=False, index=True)
token = db.Column(db.String(255), unique=True, nullable=False, index=True)
status = db.Column(db.Enum(SessionStatus), default=SessionStatus.ACTIVE, nullable=False)
# Session metadata
ip_address = db.Column(db.String(45), nullable=True)
user_agent = db.Column(db.Text, nullable=True)
device_info = db.Column(db.JSON, nullable=True)
# Timing
expires_at = db.Column(db.DateTime, nullable=False)
last_activity_at = db.Column(db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc))
revoked_at = db.Column(db.DateTime, nullable=True)
revoked_reason = db.Column(db.String(255), nullable=True)
# Relationships
user = db.relationship("User", back_populates="sessions")
def __repr__(self):
"""String representation of Session."""
return f"<Session user_id={self.user_id} status={self.status}>"
def is_active(self):
"""Check if session is currently active."""
now = datetime.now(timezone.utc)
# Make expires_at timezone-aware if it's naive
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return (
self.status == SessionStatus.ACTIVE
and expires_at > now
and self.deleted_at is None
)
def is_expired(self):
"""Check if session has expired."""
now = datetime.now(timezone.utc)
# Make expires_at timezone-aware if it's naive
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return now > expires_at
def refresh(self, duration_seconds=86400):
"""
Refresh session expiration.
Args:
duration_seconds: New session duration in seconds
"""
self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=duration_seconds)
self.last_activity_at = datetime.now(timezone.utc)
db.session.commit()
def revoke(self, reason=None):
"""
Revoke the session.
Args:
reason: Optional reason for revocation
"""
self.status = SessionStatus.REVOKED
self.revoked_at = datetime.now(timezone.utc)
if reason:
self.revoked_reason = reason
db.session.commit()
def to_dict(self, exclude=None):
"""Convert to dictionary, excluding sensitive fields."""
exclude = exclude or []
# Exclude token from dict
exclude.append("token")
return super().to_dict(exclude=exclude)
+141
View File
@@ -0,0 +1,141 @@
"""User model."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.utils.constants import UserStatus
class User(BaseModel):
"""User model representing a user account."""
__tablename__ = "users"
email = db.Column(db.String(255), unique=True, nullable=False, index=True)
email_verified = db.Column(db.Boolean, default=False, nullable=False)
full_name = db.Column(db.String(255), nullable=True)
avatar_url = db.Column(db.String(512), nullable=True)
status = db.Column(
db.Enum(UserStatus), default=UserStatus.ACTIVE, nullable=False, index=True
)
last_login_at = db.Column(db.DateTime, nullable=True)
last_login_ip = db.Column(db.String(45), nullable=True)
# Relationships
authentication_methods = db.relationship(
"AuthenticationMethod", back_populates="user", cascade="all, delete-orphan"
)
sessions = db.relationship("Session", back_populates="user", cascade="all, delete-orphan")
organization_memberships = db.relationship(
"OrganizationMember",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="OrganizationMember.user_id",
)
audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan")
def __repr__(self):
"""String representation of User."""
return f"<User {self.email}>"
def to_dict(self, exclude=None):
"""Convert user to dictionary, excluding sensitive fields by default."""
exclude = exclude or []
# Always exclude password-related fields
default_exclude = []
all_exclude = list(set(default_exclude + exclude))
return super().to_dict(exclude=all_exclude)
def has_password_auth(self):
"""Check if user has password authentication enabled."""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.PASSWORD, deleted_at=None
).first()
is not None
)
def get_organizations(self):
"""Get all organizations the user is a member of."""
return [membership.organization for membership in self.organization_memberships]
def has_totp_enabled(self) -> bool:
"""Check if user has TOTP enabled and verified.
Returns:
True if user has a verified TOTP authentication method, False otherwise.
"""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id,
method_type=AuthMethodType.TOTP,
verified=True,
deleted_at=None,
).first()
is not None
)
def get_totp_method(self):
"""Get user's TOTP authentication method.
Returns:
The AuthenticationMethod instance for TOTP or None if not found.
Note:
Returns the most recently created TOTP method to handle cases where
multiple enrollment attempts may exist.
"""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None
).order_by(AuthenticationMethod.created_at.desc()).first()
def has_webauthn_enabled(self) -> bool:
"""Check if user has any WebAuthn passkey credentials.
Returns:
True if user has at least one WebAuthn credential, False otherwise.
"""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return (
AuthenticationMethod.query.filter_by(
user_id=self.id,
method_type=AuthMethodType.WEBAUTHN,
deleted_at=None,
).first()
is not None
)
def get_webauthn_credentials(self):
"""Get all WebAuthn credentials for the user.
Returns:
List of AuthenticationMethod instances for WebAuthn, ordered by creation date.
"""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
).order_by(AuthenticationMethod.created_at.desc()).all()
def get_webauthn_credential_count(self) -> int:
"""Get the count of WebAuthn credentials for the user.
Returns:
Number of WebAuthn credentials.
"""
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType
return AuthenticationMethod.query.filter_by(
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
).count()
+34
View File
@@ -0,0 +1,34 @@
"""Schemas package."""
from gatehouse_app.schemas.user_schema import UserSchema, UserUpdateSchema, ChangePasswordSchema
from gatehouse_app.schemas.auth_schema import (
RegisterSchema,
LoginSchema,
RefreshTokenSchema,
ForgotPasswordSchema,
ResetPasswordSchema,
)
from gatehouse_app.schemas.organization_schema import (
OrganizationSchema,
OrganizationCreateSchema,
OrganizationUpdateSchema,
OrganizationMemberSchema,
InviteMemberSchema,
UpdateMemberRoleSchema,
)
__all__ = [
"UserSchema",
"UserUpdateSchema",
"ChangePasswordSchema",
"RegisterSchema",
"LoginSchema",
"RefreshTokenSchema",
"ForgotPasswordSchema",
"ResetPasswordSchema",
"OrganizationSchema",
"OrganizationCreateSchema",
"OrganizationUpdateSchema",
"OrganizationMemberSchema",
"InviteMemberSchema",
"UpdateMemberRoleSchema",
]
+88
View File
@@ -0,0 +1,88 @@
"""Authentication schemas for validation."""
from marshmallow import Schema, fields, validate, validates_schema, ValidationError
class RegisterSchema(Schema):
"""Schema for user registration."""
email = fields.Email(required=True)
password = fields.Str(
required=True,
validate=validate.Length(min=8, max=128),
)
password_confirm = fields.Str(required=True)
full_name = fields.Str(allow_none=True, validate=validate.Length(max=255))
@validates_schema
def validate_passwords_match(self, data, **kwargs):
"""Validate that passwords match."""
if data.get("password") != data.get("password_confirm"):
raise ValidationError("Passwords do not match", field_name="password_confirm")
class LoginSchema(Schema):
"""Schema for user login."""
email = fields.Email(required=True)
password = fields.Str(required=True, validate=validate.Length(min=1))
remember_me = fields.Bool(missing=False)
class RefreshTokenSchema(Schema):
"""Schema for token refresh."""
refresh_token = fields.Str(required=True)
class ForgotPasswordSchema(Schema):
"""Schema for forgot password request."""
email = fields.Email(required=True)
class ResetPasswordSchema(Schema):
"""Schema for password reset."""
token = fields.Str(required=True)
password = fields.Str(
required=True,
validate=validate.Length(min=8, max=128),
)
password_confirm = fields.Str(required=True)
@validates_schema
def validate_passwords_match(self, data, **kwargs):
"""Validate that passwords match."""
if data.get("password") != data.get("password_confirm"):
raise ValidationError("Passwords do not match", field_name="password_confirm")
class TOTPVerifyEnrollmentSchema(Schema):
"""Schema for TOTP enrollment verification."""
code = fields.Str(
required=True,
validate=validate.Regexp(
r"^\d{6}$",
error="Code must be a 6-digit number",
),
)
class TOTPVerifySchema(Schema):
"""Schema for TOTP code verification during login."""
code = fields.Str(required=True)
is_backup_code = fields.Bool(missing=False)
class TOTPDisableSchema(Schema):
"""Schema for disabling TOTP."""
password = fields.Str(required=True, validate=validate.Length(min=1))
class TOTPRegenerateBackupCodesSchema(Schema):
"""Schema for regenerating backup codes."""
password = fields.Str(required=True, validate=validate.Length(min=1))
@@ -0,0 +1,62 @@
"""Organization schemas for validation."""
from marshmallow import Schema, fields, validate
class OrganizationSchema(Schema):
"""Schema for Organization model."""
id = fields.Str(dump_only=True)
name = fields.Str(required=True, validate=validate.Length(min=1, max=255))
slug = fields.Str(required=True, validate=validate.Length(min=1, max=255))
description = fields.Str(allow_none=True)
logo_url = fields.Url(allow_none=True, validate=validate.Length(max=512))
is_active = fields.Bool(dump_only=True)
created_at = fields.DateTime(dump_only=True)
updated_at = fields.DateTime(dump_only=True)
class OrganizationCreateSchema(Schema):
"""Schema for creating an organization."""
name = fields.Str(required=True, validate=validate.Length(min=1, max=255))
slug = fields.Str(required=True, validate=validate.Length(min=1, max=255))
description = fields.Str(allow_none=True)
logo_url = fields.Url(allow_none=True, validate=validate.Length(max=512))
class OrganizationUpdateSchema(Schema):
"""Schema for updating an organization."""
name = fields.Str(validate=validate.Length(min=1, max=255))
description = fields.Str(allow_none=True)
logo_url = fields.Url(allow_none=True, validate=validate.Length(max=512))
class OrganizationMemberSchema(Schema):
"""Schema for Organization Member."""
id = fields.Str(dump_only=True)
user_id = fields.Str(dump_only=True)
organization_id = fields.Str(dump_only=True)
role = fields.Str(dump_only=True)
joined_at = fields.DateTime(dump_only=True)
created_at = fields.DateTime(dump_only=True)
class InviteMemberSchema(Schema):
"""Schema for inviting a member to an organization."""
email = fields.Email(required=True)
role = fields.Str(
required=True,
validate=validate.OneOf(["owner", "admin", "member", "guest"])
)
class UpdateMemberRoleSchema(Schema):
"""Schema for updating a member's role."""
role = fields.Str(
required=True,
validate=validate.OneOf(["owner", "admin", "member", "guest"])
)
+47
View File
@@ -0,0 +1,47 @@
"""User schemas for validation and serialization."""
from marshmallow import Schema, fields, validate, validates, ValidationError
from gatehouse_app.utils.constants import UserStatus
class UserSchema(Schema):
"""Schema for User model."""
id = fields.Str(dump_only=True)
email = fields.Email(required=True)
email_verified = fields.Bool(dump_only=True)
full_name = fields.Str(allow_none=True, validate=validate.Length(max=255))
avatar_url = fields.Url(allow_none=True, validate=validate.Length(max=512))
status = fields.Str(dump_only=True)
last_login_at = fields.DateTime(dump_only=True)
created_at = fields.DateTime(dump_only=True)
updated_at = fields.DateTime(dump_only=True)
class UserUpdateSchema(Schema):
"""Schema for updating user profile."""
full_name = fields.Str(allow_none=True, validate=validate.Length(max=255))
avatar_url = fields.Url(allow_none=True, validate=validate.Length(max=512))
class ChangePasswordSchema(Schema):
"""Schema for changing password."""
current_password = fields.Str(required=True, validate=validate.Length(min=1))
new_password = fields.Str(
required=True,
validate=validate.Length(min=8, max=128),
)
new_password_confirm = fields.Str(required=True)
@validates("new_password")
def validate_password_strength(self, value):
"""Validate password strength."""
if len(value) < 8:
raise ValidationError("Password must be at least 8 characters long")
if not any(char.isdigit() for char in value):
raise ValidationError("Password must contain at least one digit")
if not any(char.isupper() for char in value):
raise ValidationError("Password must contain at least one uppercase letter")
if not any(char.islower() for char in value):
raise ValidationError("Password must contain at least one lowercase letter")
+85
View File
@@ -0,0 +1,85 @@
"""WebAuthn schemas for validation."""
from marshmallow import Schema, fields, validate, validates_schema, ValidationError
class WebAuthnRegistrationBeginSchema(Schema):
"""Schema for beginning WebAuthn registration."""
# No required fields - uses authenticated user
pass
class WebAuthnRegistrationCompleteSchema(Schema):
"""Schema for completing WebAuthn registration."""
id = fields.Str(required=True)
rawId = fields.Str(required=True)
type = fields.Str(
required=True,
validate=validate.OneOf(["public-key"])
)
response = fields.Dict(required=True)
transports = fields.List(
fields.Str(validate=validate.OneOf(["usb", "nfc", "ble", "hybrid", "internal", "platform"])),
load_default=[]
)
@validates_schema
def validate_response(self, data, **kwargs):
"""Validate response contains required fields."""
response = data.get("response", {})
required_fields = ["attestationObject", "clientDataJSON"]
for field in required_fields:
if field not in response:
raise ValidationError(
f"Missing required field in response: {field}",
field_name=f"response.{field}"
)
class WebAuthnLoginBeginSchema(Schema):
"""Schema for beginning WebAuthn login."""
email = fields.Email(required=True)
class WebAuthnLoginCompleteSchema(Schema):
"""Schema for completing WebAuthn login."""
id = fields.Str(required=True)
rawId = fields.Str(required=True)
type = fields.Str(
required=True,
validate=validate.OneOf(["public-key"])
)
response = fields.Dict(required=True)
clientExtensionResults = fields.Dict(load_default={})
@validates_schema
def validate_response(self, data, **kwargs):
"""Validate response contains required fields."""
response = data.get("response", {})
required_fields = ["authenticatorData", "clientDataJSON", "signature"]
for field in required_fields:
if field not in response:
raise ValidationError(
f"Missing required field in response: {field}",
field_name=f"response.{field}"
)
class WebAuthnCredentialRenameSchema(Schema):
"""Schema for renaming a WebAuthn credential."""
name = fields.Str(
required=True,
validate=validate.Length(min=1, max=100)
)
class WebAuthnCredentialDeleteSchema(Schema):
"""Schema for deleting a WebAuthn credential."""
password = fields.Str(
required=True,
validate=validate.Length(min=1)
)
+25
View File
@@ -0,0 +1,25 @@
"""Services package."""
from gatehouse_app.services.auth_service import AuthService
from gatehouse_app.services.user_service import UserService
from gatehouse_app.services.organization_service import OrganizationService
from gatehouse_app.services.session_service import SessionService
from gatehouse_app.services.audit_service import AuditService
from gatehouse_app.services.oidc_service import OIDCService, OIDCError
from gatehouse_app.services.oidc_jwks_service import OIDCJWKSService
from gatehouse_app.services.oidc_token_service import OIDCTokenService
from gatehouse_app.services.oidc_session_service import OIDCSessionService
from gatehouse_app.services.oidc_audit_service import OIDCAuditService
__all__ = [
"AuthService",
"UserService",
"OrganizationService",
"SessionService",
"AuditService",
"OIDCService",
"OIDCError",
"OIDCJWKSService",
"OIDCTokenService",
"OIDCSessionService",
"OIDCAuditService",
]
+107
View File
@@ -0,0 +1,107 @@
"""Audit service."""
from flask import request, g
from gatehouse_app.models.audit_log import AuditLog
from gatehouse_app.utils.constants import AuditAction
class AuditService:
"""Service for audit logging."""
@staticmethod
def log_action(
action,
user_id=None,
organization_id=None,
resource_type=None,
resource_id=None,
metadata=None,
description=None,
success=True,
error_message=None,
):
"""
Create an audit log entry.
Args:
action: AuditAction enum value
user_id: ID of user performing the action
organization_id: ID of related organization
resource_type: Type of resource being acted upon
resource_id: ID of resource being acted upon
metadata: Additional metadata dictionary
description: Human-readable description
success: Whether the action succeeded
error_message: Error message if action failed
Returns:
AuditLog instance
"""
# Get request details if available
ip_address = None
user_agent = None
request_id = None
try:
if request:
ip_address = request.remote_addr
user_agent = request.headers.get("User-Agent")
request_id = g.get("request_id")
except RuntimeError:
# No request context
pass
log_entry = AuditLog(
action=action,
user_id=user_id,
organization_id=organization_id,
resource_type=resource_type,
resource_id=resource_id,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
metadata=metadata,
description=description,
success=success,
error_message=error_message,
)
log_entry.save()
return log_entry
@staticmethod
def get_user_activity(user_id, limit=50):
"""
Get recent activity for a user.
Args:
user_id: User ID
limit: Maximum number of records to return
Returns:
List of AuditLog instances
"""
return (
AuditLog.query.filter_by(user_id=user_id)
.order_by(AuditLog.created_at.desc())
.limit(limit)
.all()
)
@staticmethod
def get_organization_activity(organization_id, limit=50):
"""
Get recent activity for an organization.
Args:
organization_id: Organization ID
limit: Maximum number of records to return
Returns:
List of AuditLog instances
"""
return (
AuditLog.query.filter_by(organization_id=organization_id)
.order_by(AuditLog.created_at.desc())
.limit(limit)
.all()
)
+559
View File
@@ -0,0 +1,559 @@
"""Authentication service."""
import logging
import secrets
from datetime import datetime, timedelta, timezone
from flask import request, g, current_app
from gatehouse_app.extensions import db, bcrypt
from gatehouse_app.models.user import User
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.models.session import Session
from gatehouse_app.utils.constants import AuthMethodType, SessionStatus, UserStatus, AuditAction
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError, AccountSuspendedError, AccountInactiveError
from gatehouse_app.exceptions.validation_exceptions import EmailAlreadyExistsError
from gatehouse_app.services.audit_service import AuditService
from gatehouse_app.services.totp_service import TOTPService
logger = logging.getLogger(__name__)
class AuthService:
"""Service for authentication operations."""
@staticmethod
def register_user(email, password, full_name=None):
"""
Register a new user with email/password.
Args:
email: User email address
password: Plain text password
full_name: Optional full name
Returns:
User instance
Raises:
EmailAlreadyExistsError: If email is already registered
"""
# Check if email already exists
existing_user = User.query.filter_by(email=email.lower()).first()
if existing_user and existing_user.deleted_at is None:
raise EmailAlreadyExistsError()
# Create user
user = User(
email=email.lower(),
full_name=full_name,
status=UserStatus.ACTIVE,
)
user.save()
# Create password authentication method
password_hash = bcrypt.generate_password_hash(password).decode("utf-8")
auth_method = AuthenticationMethod(
user_id=user.id,
method_type=AuthMethodType.PASSWORD,
password_hash=password_hash,
is_primary=True,
verified=True,
)
auth_method.save()
# Log the registration
AuditService.log_action(
action=AuditAction.USER_REGISTER,
user_id=user.id,
resource_type="user",
resource_id=user.id,
description=f"User registered with email: {email}",
)
return user
@staticmethod
def authenticate(email, password):
"""
Authenticate user with email/password.
Args:
email: User email
password: Plain text password
Returns:
User instance if authentication succeeds
Raises:
InvalidCredentialsError: If credentials are invalid
AccountSuspendedError: If account is suspended
AccountInactiveError: If account is inactive
"""
# Find user
user = User.query.filter_by(email=email.lower(), deleted_at=None).first()
# Development-only debug logging for user existence check
if current_app.config.get('ENV') == 'development':
logger.debug(f"[Auth] User lookup: email={email}, exists={user is not None}")
if not user:
raise InvalidCredentialsError()
# Check account status
if current_app.config.get('ENV') == 'development':
logger.debug(f"[Auth] Account status: user_id={user.id}, status={user.status}")
if user.status == UserStatus.SUSPENDED:
raise AccountSuspendedError()
if user.status == UserStatus.INACTIVE:
raise AccountInactiveError()
# Find password auth method
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.PASSWORD,
deleted_at=None,
).first()
# Development-only debug logging for auth method lookup
if current_app.config.get('ENV') == 'development':
logger.debug(f"[Auth] Auth method lookup: user_id={user.id}, has_password_auth={auth_method is not None and auth_method.password_hash is not None}")
if not auth_method or not auth_method.password_hash:
raise InvalidCredentialsError()
# Verify password
password_valid = bcrypt.check_password_hash(auth_method.password_hash, password)
# Development-only debug logging for password validation (without logging actual password)
if current_app.config.get('ENV') == 'development':
logger.debug(f"[Auth] Password validation: user_id={user.id}, valid={password_valid}")
if not password_valid:
raise InvalidCredentialsError()
# Update last login
user.last_login_at = datetime.now(timezone.utc)
user.last_login_ip = request.remote_addr
auth_method.last_used_at = datetime.now(timezone.utc)
db.session.commit()
return user
@staticmethod
def create_session(user, duration_seconds=86400):
"""
Create a new session for the user.
Args:
user: User instance
duration_seconds: Session duration in seconds
Returns:
Session instance
"""
# Generate session token
token = secrets.token_urlsafe(32)
# Create session
session = Session(
user_id=user.id,
token=token,
status=SessionStatus.ACTIVE,
ip_address=request.remote_addr,
user_agent=request.headers.get("User-Agent"),
expires_at=datetime.now(timezone.utc) + timedelta(seconds=duration_seconds),
last_activity_at=datetime.now(timezone.utc),
)
session.save()
# Log session creation
AuditService.log_action(
action=AuditAction.SESSION_CREATE,
user_id=user.id,
resource_type="session",
resource_id=session.id,
description="User session created",
)
return session
@staticmethod
def change_password(user, current_password, new_password):
"""
Change user password.
Args:
user: User instance
current_password: Current password
new_password: New password
Raises:
InvalidCredentialsError: If current password is incorrect
"""
# Find password auth method
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.PASSWORD,
deleted_at=None,
).first()
if not auth_method or not auth_method.password_hash:
raise InvalidCredentialsError("No password authentication method found")
# Verify current password
if not bcrypt.check_password_hash(auth_method.password_hash, current_password):
raise InvalidCredentialsError("Current password is incorrect")
# Update password
auth_method.password_hash = bcrypt.generate_password_hash(new_password).decode("utf-8")
db.session.commit()
# Log password change
AuditService.log_action(
action=AuditAction.PASSWORD_CHANGE,
user_id=user.id,
description="User changed password",
)
@staticmethod
def revoke_session(session_id, reason=None):
"""
Revoke a session.
Args:
session_id: Session ID to revoke
reason: Optional revocation reason
"""
session = Session.query.get(session_id)
if session:
session.revoke(reason=reason)
# Log session revocation
AuditService.log_action(
action=AuditAction.SESSION_REVOKE,
user_id=session.user_id,
resource_type="session",
resource_id=session.id,
description=f"Session revoked: {reason or 'User logout'}",
)
@staticmethod
def enroll_totp(user: User) -> dict:
"""
Initiate TOTP enrollment for a user.
Args:
user: User instance
Returns:
Dictionary containing:
- secret: TOTP secret (base32 encoded)
- provisioning_uri: otpauth:// URI for QR code
- qr_code: Base64 encoded QR code as data URI
- backup_codes: List of plain text backup codes
Raises:
ConflictError: If user already has TOTP enabled
"""
from gatehouse_app.exceptions.validation_exceptions import ConflictError
# Check if user already has TOTP enabled
if user.has_totp_enabled():
raise ConflictError("TOTP is already enabled for this account")
# Clean up any existing unverified TOTP enrollment attempts
# Use hard delete for unverified methods since they're incomplete enrollment attempts
existing_totp_method = user.get_totp_method()
if existing_totp_method and not existing_totp_method.verified:
logger.debug(f"Removing existing unverified TOTP method for user {user.id}")
db.session.delete(existing_totp_method) # Hard delete - unverified methods are temporary
db.session.commit() # Commit to ensure deletion before creating new record
# Generate TOTP secret
secret = TOTPService.generate_secret()
# Generate provisioning URI
provisioning_uri = TOTPService.generate_provisioning_uri(
user_email=user.email,
secret=secret,
issuer="Gatehouse",
)
# Generate QR code data URI
qr_code = TOTPService.generate_qr_code_data_uri(provisioning_uri)
# Generate backup codes
backup_codes, hashed_backup_codes = TOTPService.generate_backup_codes()
# Create unverified TOTP authentication method
auth_method = AuthenticationMethod(
user_id=user.id,
method_type=AuthMethodType.TOTP,
verified=False,
is_primary=False,
)
auth_method.save()
# Store TOTP data in provider_data (since totp_secret field is commented out)
auth_method.provider_data = {
"secret": secret,
"backup_codes": hashed_backup_codes,
}
db.session.commit()
# Log TOTP enrollment initiation
AuditService.log_action(
action=AuditAction.TOTP_ENROLL_INITIATED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="TOTP enrollment initiated",
)
return {
"secret": secret,
"provisioning_uri": provisioning_uri,
"qr_code": qr_code,
"backup_codes": backup_codes,
}
@staticmethod
def verify_totp_enrollment(user: User, code: str) -> bool:
"""
Complete TOTP enrollment by verifying the first TOTP code.
Args:
user: User instance
code: 6-digit TOTP code from authenticator app
Returns:
True if verification successful
Raises:
InvalidCredentialsError: If code is invalid or TOTP method not found
"""
# Get user's TOTP authentication method
auth_method = user.get_totp_method()
if not auth_method:
raise InvalidCredentialsError("TOTP enrollment not found")
# Get secret from provider_data
secret = auth_method.provider_data.get("secret") if auth_method.provider_data else None
if not secret:
raise InvalidCredentialsError("TOTP secret not found")
# Verify the code
if not TOTPService.verify_code(secret, code):
raise InvalidCredentialsError("Invalid TOTP code")
# Mark TOTP as verified
auth_method.verified = True
auth_method.totp_verified_at = datetime.now(timezone.utc)
db.session.commit()
# Log TOTP enrollment completion
AuditService.log_action(
action=AuditAction.TOTP_ENROLL_COMPLETED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="TOTP enrollment completed",
)
return True
@staticmethod
def disable_totp(user: User, password: str) -> bool:
"""
Disable TOTP for a user.
Args:
user: User instance
password: User's current password for verification
Returns:
True if TOTP disabled successfully
Raises:
InvalidCredentialsError: If password is invalid or TOTP method not found
"""
# Verify user's password
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.PASSWORD,
deleted_at=None,
).first()
if not auth_method or not auth_method.password_hash:
raise InvalidCredentialsError("No password authentication method found")
if not bcrypt.check_password_hash(auth_method.password_hash, password):
raise InvalidCredentialsError("Invalid password")
# Get user's TOTP authentication method
totp_method = user.get_totp_method()
if not totp_method:
raise InvalidCredentialsError("TOTP is not enabled for this account")
# Soft-delete the TOTP authentication method
totp_method.delete(soft=True)
# Log TOTP disabled
AuditService.log_action(
action=AuditAction.TOTP_DISABLED,
user_id=user.id,
resource_type="authentication_method",
resource_id=totp_method.id,
description="TOTP disabled",
)
return True
@staticmethod
def authenticate_with_totp(user: User, code: str, is_backup_code: bool = False) -> bool:
"""
Verify TOTP code during login.
Args:
user: User instance
code: 6-digit TOTP code or backup code
is_backup_code: True if code is a backup code, False if TOTP code
Returns:
True if code is valid
Raises:
InvalidCredentialsError: If code is invalid or TOTP method not found
"""
# Get user's TOTP authentication method
auth_method = user.get_totp_method()
if not auth_method:
raise InvalidCredentialsError("TOTP is not enabled for this account")
if is_backup_code:
# Verify backup code
backup_codes = (
auth_method.provider_data.get("backup_codes")
if auth_method.provider_data
else []
)
is_valid, remaining_codes = TOTPService.verify_backup_code(backup_codes, code)
if is_valid:
# Update remaining backup codes
auth_method.provider_data = {
"secret": auth_method.provider_data.get("secret"),
"backup_codes": remaining_codes,
}
auth_method.last_used_at = datetime.now(timezone.utc)
db.session.add(auth_method)
db.session.commit()
logger.debug(f"[BACKUP CODE] Updated provider_data: {auth_method.provider_data}")
# Log backup code usage
AuditService.log_action(
action=AuditAction.TOTP_BACKUP_CODE_USED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="Backup code used for authentication",
)
else:
# Log failed verification
AuditService.log_action(
action=AuditAction.TOTP_VERIFY_FAILED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="Invalid backup code provided",
)
raise InvalidCredentialsError("Invalid backup code")
else:
# Verify TOTP code
secret = (
auth_method.provider_data.get("secret")
if auth_method.provider_data
else None
)
if not secret:
raise InvalidCredentialsError("TOTP secret not found")
is_valid = TOTPService.verify_code(secret, code)
if is_valid:
auth_method.last_used_at = datetime.now(timezone.utc)
db.session.commit()
# Log successful verification
AuditService.log_action(
action=AuditAction.TOTP_VERIFY_SUCCESS,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="TOTP code verified successfully",
)
else:
# Log failed verification
AuditService.log_action(
action=AuditAction.TOTP_VERIFY_FAILED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="Invalid TOTP code provided",
)
raise InvalidCredentialsError("Invalid TOTP code")
return True
@staticmethod
def regenerate_totp_backup_codes(user: User, password: str) -> list[str]:
"""
Generate new backup codes for TOTP.
Args:
user: User instance
password: User's current password for verification
Returns:
List of new plain text backup codes
Raises:
InvalidCredentialsError: If password is invalid or TOTP method not found
"""
# Verify user's password
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.PASSWORD,
deleted_at=None,
).first()
if not auth_method or not auth_method.password_hash:
raise InvalidCredentialsError("No password authentication method found")
if not bcrypt.check_password_hash(auth_method.password_hash, password):
raise InvalidCredentialsError("Invalid password")
# Get user's TOTP authentication method
totp_method = user.get_totp_method()
if not totp_method:
raise InvalidCredentialsError("TOTP is not enabled for this account")
# Generate new backup codes
backup_codes, hashed_backup_codes = TOTPService.generate_backup_codes()
# Update the authentication method with new backup codes
totp_method.provider_data = {
"secret": totp_method.provider_data.get("secret"),
"backup_codes": hashed_backup_codes,
}
db.session.commit()
# Log backup codes regeneration
AuditService.log_action(
action=AuditAction.TOTP_BACKUP_CODES_REGENERATED,
user_id=user.id,
resource_type="authentication_method",
resource_id=totp_method.id,
description="TOTP backup codes regenerated",
)
return backup_codes
@@ -0,0 +1,408 @@
"""OIDC Audit Service for comprehensive OIDC event logging."""
from datetime import datetime, timezone
from typing import Dict, List, Optional
from flask import g
from gatehouse_app.models import OIDCAuditLog, OIDCClient, User
from gatehouse_app.exceptions.validation_exceptions import NotFoundError
class OIDCAuditService:
"""Service for OIDC-specific audit logging.
This service provides methods to log all OIDC-related events including:
- Authorization requests and responses
- Token issuance and refresh
- Token revocation
- UserInfo endpoint access
- Authentication failures
"""
# Event type constants
EVENT_AUTHORIZATION_REQUEST = "authorization_request"
EVENT_AUTHORIZATION_RESPONSE = "authorization_response"
EVENT_TOKEN_ISSUE = "token_issue"
EVENT_TOKEN_REFRESH = "token_refresh"
EVENT_TOKEN_REVOCATION = "token_revocation"
EVENT_TOKEN_INTROSPECTION = "token_introspection"
EVENT_USERINFO_ACCESS = "userinfo_access"
EVENT_AUTHENTICATION_FAILURE = "authentication_failure"
EVENT_AUTHORIZATION_FAILURE = "authorization_failure"
EVENT_JWKS_ACCESS = "jwks_access"
EVENT_REGISTRATION = "client_registration"
@classmethod
def _get_request_context(cls) -> Dict:
"""Extract request context for logging.
Returns:
Dictionary with IP, user_agent, and request_id
"""
from flask import request
return {
"ip_address": request.remote_addr if request else None,
"user_agent": request.headers.get("User-Agent") if request else None,
"request_id": g.get("request_id"),
}
@classmethod
def log_event(
cls,
event_type: str,
client_id: str = None,
user_id: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None,
metadata: Dict = None
) -> OIDCAuditLog:
"""Log a generic OIDC event.
Args:
event_type: Type of event
client_id: OIDC client ID
user_id: User ID
success: Whether the event was successful
error_code: Error code if failed
error_description: Error description if failed
metadata: Additional event metadata
Returns:
OIDCAuditLog instance
"""
context = cls._get_request_context()
log = OIDCAuditLog.log_event(
event_type=event_type,
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
ip_address=context["ip_address"],
user_agent=context["user_agent"],
request_id=context["request_id"],
event_metadata=metadata,
)
return log
@classmethod
def log_authorization_event(
cls,
client_id: str,
user_id: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None,
redirect_uri: str = None,
scope: list = None,
response_type: str = None
) -> OIDCAuditLog:
"""Log an authorization event.
Args:
client_id: OIDC client ID
user_id: User ID (if authenticated)
success: Whether authorization was successful
error_code: Error code if failed
error_description: Error description if failed
redirect_uri: Redirect URI from request
scope: Requested scopes
response_type: Response type (e.g., "code")
Returns:
OIDCAuditLog instance
"""
metadata = {
"redirect_uri": redirect_uri,
"scope": scope,
"response_type": response_type,
}
metadata = {k: v for k, v in metadata.items() if v is not None}
return cls.log_event(
event_type=cls.EVENT_AUTHORIZATION_REQUEST,
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
metadata=metadata,
)
@classmethod
def log_token_event(
cls,
client_id: str,
user_id: str = None,
token_type: str = "access_token",
success: bool = True,
error_code: str = None,
error_description: str = None,
grant_type: str = None,
scopes: list = None
) -> OIDCAuditLog:
"""Log a token issuance or refresh event.
Args:
client_id: OIDC client ID
user_id: User ID
token_type: Type of token issued
success: Whether token issuance was successful
error_code: Error code if failed
error_description: Error description if failed
grant_type: Grant type used (e.g., "authorization_code", "refresh_token")
scopes: Scopes included in the token
Returns:
OIDCAuditLog instance
"""
metadata = {
"token_type": token_type,
"grant_type": grant_type,
"scopes": scopes,
}
metadata = {k: v for k, v in metadata.items() if v is not None}
return cls.log_event(
event_type=cls.EVENT_TOKEN_ISSUE if token_type else cls.EVENT_TOKEN_REFRESH,
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
metadata=metadata,
)
@classmethod
def log_userinfo_event(
cls,
access_token: str = None,
user_id: str = None,
client_id: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None,
scopes_claimed: list = None
) -> OIDCAuditLog:
"""Log a UserInfo endpoint access event.
Args:
access_token: Access token used (masked)
user_id: User ID returned
client_id: Client ID making the request
success: Whether access was successful
error_code: Error code if failed
error_description: Error description if failed
scopes_claimed: Scopes claimed in the request
Returns:
OIDCAuditLog instance
"""
# Mask the access token for security
masked_token = None
if access_token:
masked_token = access_token[:8] + "..." + access_token[-4:] if len(access_token) > 12 else "***"
metadata = {
"token_prefix": masked_token,
"scopes_claimed": scopes_claimed,
}
metadata = {k: v for k, v in metadata.items() if v is not None}
return cls.log_event(
event_type=cls.EVENT_USERINFO_ACCESS,
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
metadata=metadata,
)
@classmethod
def log_token_revocation_event(
cls,
client_id: str,
user_id: str = None,
token_type: str = "access_token",
reason: str = None,
success: bool = True,
error_code: str = None,
error_description: str = None
) -> OIDCAuditLog:
"""Log a token revocation event.
Args:
client_id: OIDC client ID
user_id: User ID
token_type: Type of token being revoked
reason: Revocation reason
success: Whether revocation was successful
error_code: Error code if failed
error_description: Error description if failed
Returns:
OIDCAuditLog instance
"""
metadata = {
"token_type": token_type,
"reason": reason,
}
metadata = {k: v for k, v in metadata.items() if v is not None}
return cls.log_event(
event_type=cls.EVENT_TOKEN_REVOCATION,
client_id=client_id,
user_id=user_id,
success=success,
error_code=error_code,
error_description=error_description,
metadata=metadata,
)
@classmethod
def log_authentication_failure(
cls,
client_id: str = None,
error_code: str = "authentication_failed",
error_description: str = "Authentication failed",
user_id: str = None
) -> OIDCAuditLog:
"""Log an authentication failure event.
Args:
client_id: OIDC client ID
error_code: Error code
error_description: Error description
user_id: User ID if known
Returns:
OIDCAuditLog instance
"""
return cls.log_event(
event_type=cls.EVENT_AUTHENTICATION_FAILURE,
client_id=client_id,
user_id=user_id,
success=False,
error_code=error_code,
error_description=error_description,
)
@classmethod
def get_events_for_user(
cls,
user_id: str,
limit: int = 100,
include_deleted: bool = False
) -> List[OIDCAuditLog]:
"""Get audit events for a specific user.
Args:
user_id: User ID
limit: Maximum number of events to return
include_deleted: Include soft-deleted events
Returns:
List of OIDCAuditLog instances
"""
return OIDCAuditLog.get_events_for_user(user_id, limit)
@classmethod
def get_events_for_client(
cls,
client_id: str,
limit: int = 100
) -> List[OIDCAuditLog]:
"""Get audit events for a specific client.
Args:
client_id: Client ID
limit: Maximum number of events to return
Returns:
List of OIDCAuditLog instances
"""
return OIDCAuditLog.get_events_for_client(client_id, limit)
@classmethod
def get_failed_events(
cls,
client_id: str = None,
user_id: str = None,
start_date: datetime = None,
end_date: datetime = None,
limit: int = 100
) -> List[OIDCAuditLog]:
"""Get failed audit events for analysis.
Args:
client_id: Optional client ID filter
user_id: Optional user ID filter
start_date: Optional start date filter
end_date: Optional end date filter
limit: Maximum number of events to return
Returns:
List of failed OIDCAuditLog instances
"""
return OIDCAuditLog.get_failed_events(
client_id=client_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
limit=limit,
)
@classmethod
def get_event_summary(
cls,
client_id: str = None,
days: int = 30
) -> Dict:
"""Get a summary of audit events.
Args:
client_id: Optional client ID filter
days: Number of days to look back
Returns:
Summary dictionary with event counts
"""
from datetime import timedelta
start_date = datetime.now(timezone.utc) - timedelta(days=days)
query = OIDCAuditLog.query.filter(
OIDCAuditLog.created_at >= start_date
)
if client_id:
query = query.filter_by(client_id=client_id)
events = query.all()
# Count by event type
event_counts = {}
success_count = 0
failure_count = 0
for event in events:
event_type = event.event_type
event_counts[event_type] = event_counts.get(event_type, 0) + 1
if event.success:
success_count += 1
else:
failure_count += 1
return {
"total_events": len(events),
"successful_events": success_count,
"failed_events": failure_count,
"by_event_type": event_counts,
"period_days": days,
}
+418
View File
@@ -0,0 +1,418 @@
"""OIDC JWKS Service for key management and rotation."""
import uuid
import json
import hashlib
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, Tuple
from flask import current_app
from gatehouse_app.extensions import db
from gatehouse_app.models.oidc_jwks_key import OidcJwksKey
class JWKSKey:
"""Represents a JWKS key entry."""
def __init__(self, kid: str, private_key: str, public_key: str,
algorithm: str = "RS256", created_at: datetime = None,
expires_at: datetime = None, is_active: bool = True):
self.kid = kid
self.private_key = private_key
self.public_key = public_key
self.algorithm = algorithm
self.created_at = created_at or datetime.now(timezone.utc)
self.expires_at = expires_at or datetime.now(timezone.utc) + timedelta(days=365)
self.is_active = is_active
def to_jwk(self) -> Dict:
"""Convert to JWK format for JWKS endpoint."""
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.backends import default_backend
# Import cryptography here to avoid issues if not installed
try:
# Get public key from PEM
public_key = serialization.load_pem_public_key(
self.public_key.encode(), backend=default_backend()
)
# Get RSA parameters
public_numbers = public_key.public_numbers()
return {
"kty": "RSA",
"kid": self.kid,
"use": "sig",
"alg": self.algorithm,
"n": _base64url_encode(public_numbers.n),
"e": _base64url_encode(public_numbers.e),
}
except ImportError:
# Fallback for when cryptography is not installed
return {
"kty": "RSA",
"kid": self.kid,
"use": "sig",
"alg": self.algorithm,
}
def to_dict(self) -> Dict:
"""Convert to dictionary for storage."""
return {
"kid": self.kid,
"private_key": self.private_key,
"public_key": self.public_key,
"algorithm": self.algorithm,
"created_at": self.created_at.isoformat(),
"expires_at": self.expires_at.isoformat(),
"is_active": self.is_active,
}
@classmethod
def from_dict(cls, data: Dict) -> "JWKSKey":
"""Create from dictionary."""
return cls(
kid=data["kid"],
private_key=data["private_key"],
public_key=data["public_key"],
algorithm=data.get("algorithm", "RS256"),
created_at=datetime.fromisoformat(data["created_at"]),
expires_at=datetime.fromisoformat(data["expires_at"]),
is_active=data.get("is_active", True),
)
def _base64url_encode(value: int) -> str:
"""Encode an integer to base64url format."""
import base64
byte_length = (value.bit_length() + 7) // 8 or 1
encoded = value.to_bytes(byte_length, byteorder="big")
return base64.urlsafe_b64encode(encoded).decode().rstrip("=")
class OIDCJWKSService:
"""Service for managing OIDC signing keys (JWKS).
This service handles RSA key pair generation, rotation, and JWKS document
generation for the OIDC implementation.
"""
_instance = None
_keys: Dict[str, JWKSKey] = {}
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._keys = {}
return cls._instance
@classmethod
def reset(cls):
"""Reset the singleton (for testing)."""
cls._instance = None
cls._keys = {}
def _generate_kid(self, private_key: str) -> str:
"""Generate a key ID from the private key fingerprint."""
kid_hash = hashlib.sha256(private_key.encode()).hexdigest()[:32]
return kid_hash
def _generate_rsa_key_pair(self) -> Tuple[str, str]:
"""Generate a new RSA key pair in PEM format.
Returns:
Tuple of (private_key_pem, public_key_pem)
"""
try:
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
# Generate RSA private key
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
backend=default_backend()
)
# Get public key
public_key = private_key.public_key()
# Serialize to PEM
private_pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
).decode()
public_pem = public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
).decode()
return private_pem, public_pem
except ImportError:
# Fallback for testing without cryptography
import secrets
return f"private_key_{secrets.token_hex(32)}", f"public_key_{secrets.token_hex(32)}"
def get_jwks(self, include_private_keys: bool = False) -> Dict:
"""Get the JWKS document containing public keys.
Args:
include_private_keys: Whether to include private keys (for internal use only)
Returns:
JWKS document dictionary
"""
now = datetime.now(timezone.utc)
keys = []
for kid, key in self._keys.items():
# Only include active, non-expired keys
if key.is_active and key.expires_at > now:
if include_private_keys:
keys.append(key.to_dict())
else:
keys.append(key.to_jwk())
return {
"keys": keys
}
def load_keys_from_db(self) -> int:
"""Load existing keys from the database.
Returns:
Number of keys loaded
"""
try:
db_keys = OidcJwksKey.get_active_keys()
now = datetime.now(timezone.utc)
for db_key in db_keys:
# Create JWKSKey from database model
key = JWKSKey(
kid=db_key.kid,
private_key=db_key.private_key,
public_key=db_key.public_key,
algorithm=db_key.algorithm,
created_at=db_key.created_at,
expires_at=db_key.expires_at or now + timedelta(days=365),
is_active=db_key.is_active,
)
self._keys[db_key.kid] = key
return len(self._keys)
except Exception as e:
current_app.logger.error(f"Error loading keys from database: {e}")
return 0
def save_key_to_db(self, key: JWKSKey, is_primary: bool = False) -> OidcJwksKey:
"""Save a key to the database.
Args:
key: JWKSKey instance to save
is_primary: Whether this is the primary signing key
Returns:
OidcJwksKey database model instance
"""
db_key = OidcJwksKey(
kid=key.kid,
key_type="RSA",
algorithm=key.algorithm,
private_key=key.private_key,
public_key=key.public_key,
is_active=key.is_active,
is_primary=is_primary,
)
db.session.add(db_key)
db.session.commit()
return db_key
def get_signing_key(self) -> Optional[JWKSKey]:
"""Get the current active signing key.
Returns:
JWKSKey instance or None if no active key
"""
now = datetime.now(timezone.utc)
# First try to get the primary key from database
try:
primary_db_key = OidcJwksKey.get_primary_key()
if primary_db_key:
# Check if we have it in memory, if not load it
if primary_db_key.kid not in self._keys:
key = JWKSKey(
kid=primary_db_key.kid,
private_key=primary_db_key.private_key,
public_key=primary_db_key.public_key,
algorithm=primary_db_key.algorithm,
created_at=primary_db_key.created_at,
expires_at=primary_db_key.expires_at or now + timedelta(days=365),
is_active=primary_db_key.is_active,
)
self._keys[primary_db_key.kid] = key
return self._keys[primary_db_key.kid]
except Exception as e:
current_app.logger.error(f"Error getting primary key from database: {e}")
# Fall back to in-memory keys
for kid, key in self._keys.items():
if key.is_active and key.expires_at > now:
return key
return None
def get_key_by_kid(self, kid: str) -> Optional[JWKSKey]:
"""Get a specific key by its ID.
Args:
kid: Key ID to look up
Returns:
JWKSKey instance or None if not found
"""
return self._keys.get(kid)
def generate_new_key_pair(self, expires_in_days: int = 365) -> JWKSKey:
"""Generate a new RSA key pair for signing.
Args:
expires_in_days: Days until key expiration
Returns:
JWKSKey instance
"""
private_key, public_key = self._generate_rsa_key_pair()
kid = self._generate_kid(private_key)
now = datetime.now(timezone.utc)
key = JWKSKey(
kid=kid,
private_key=private_key,
public_key=public_key,
algorithm="RS256",
created_at=now,
expires_at=now + timedelta(days=expires_in_days),
is_active=True,
)
self._keys[kid] = key
# Deactivate old keys (but keep them for grace period)
for old_kid in self._keys:
if old_kid != kid:
self._keys[old_kid].is_active = False
return key
def rotate_keys(self, grace_period_hours: int = 24) -> Tuple[JWKSKey, List[str]]:
"""Rotate signing keys, keeping previous key active for grace period.
Args:
grace_period_hours: Hours to keep old keys active
Returns:
Tuple of (new_key, list_of_deprecated_kids)
"""
now = datetime.now(timezone.utc)
grace_end = now + timedelta(hours=grace_period_hours)
# Mark current key as deprecated
current_key = self.get_signing_key()
deprecated_kids = []
if current_key:
deprecated_kids.append(current_key.kid)
# Keep key active but mark as deprecated
current_key.is_active = False
current_key.expires_at = grace_end
# Generate new key
new_key = self.generate_new_key_pair()
# Clean up expired keys
expired_kids = [
kid for kid, key in self._keys.items()
if key.expires_at < now
]
for kid in expired_kids:
del self._keys[kid]
return new_key, deprecated_kids
def verify_key_exists(self, kid: str) -> bool:
"""Check if a key with the given ID exists and is valid.
Args:
kid: Key ID to check
Returns:
True if key exists and is valid
"""
key = self.get_key_by_kid(kid)
if not key:
return False
now = datetime.now(timezone.utc)
return key.is_active and key.expires_at > now
def initialize_with_key(self) -> JWKSKey:
"""Initialize the service with a key, loading from database if available.
This method first attempts to load existing keys from the database.
If no active primary key exists, it generates a new key and saves it to the database.
Returns:
JWKSKey instance
"""
# First, try to load keys from database
try:
# Check if there's a primary key in the database
primary_db_key = OidcJwksKey.get_primary_key()
if primary_db_key:
# Load the primary key into memory
now = datetime.now(timezone.utc)
key = JWKSKey(
kid=primary_db_key.kid,
private_key=primary_db_key.private_key,
public_key=primary_db_key.public_key,
algorithm=primary_db_key.algorithm,
created_at=primary_db_key.created_at,
expires_at=primary_db_key.expires_at or now + timedelta(days=365),
is_active=primary_db_key.is_active,
)
self._keys[primary_db_key.kid] = key
current_app.logger.info(f"[OIDC] Loaded existing signing key from database: kid={primary_db_key.kid}")
return key
# Try to load all active keys from database
loaded_count = self.load_keys_from_db()
if loaded_count > 0:
# Get the signing key from loaded keys
signing_key = self.get_signing_key()
if signing_key:
current_app.logger.info(f"[OIDC] Loaded {loaded_count} keys from database, using signing key: kid={signing_key.kid}")
return signing_key
except Exception as e:
current_app.logger.error(f"Error loading keys from database: {e}")
# No keys in database, generate a new key and save it
current_app.logger.info("[OIDC] No existing keys found in database, generating new signing key")
new_key = self.generate_new_key_pair()
# Save the new key to database
try:
self.save_key_to_db(new_key, is_primary=True)
current_app.logger.info(f"[OIDC] Saved new signing key to database: kid={new_key.kid}")
except Exception as e:
current_app.logger.error(f"Error saving key to database: {e}")
return new_key
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,289 @@
"""OIDC Session Service for session management during OIDC flow."""
import secrets
from datetime import datetime, timedelta
from typing import Dict, Optional, Tuple
from datetime import timezone
from flask import current_app, g
from gatehouse_app.extensions import db
from gatehouse_app.models import OIDCSession, OIDCClient, User
from gatehouse_app.exceptions.validation_exceptions import NotFoundError, ValidationError
class OIDCSessionService:
"""Service for managing OIDC authentication sessions.
This service handles:
- Creating OIDC sessions during authorization flow
- Validating sessions with state and nonce
- Managing PKCE code challenges
- Cleaning up expired sessions
"""
@staticmethod
def _generate_state() -> str:
"""Generate a secure state parameter.
Returns:
URL-safe base64 encoded state
"""
return secrets.token_urlsafe(32)
@staticmethod
def _generate_nonce() -> str:
"""Generate a secure nonce for OIDC.
Returns:
URL-safe base64 encoded nonce
"""
return secrets.token_urlsafe(32)
@staticmethod
def _generate_code_challenge(verifier: str, method: str = "S256") -> str:
"""Generate a PKCE code challenge from verifier.
Args:
verifier: The code verifier
method: Challenge method ("S256" or "plain")
Returns:
Code challenge string
"""
import hashlib
import base64
if method == "S256":
digest = hashlib.sha256(verifier.encode()).digest()
return base64.urlsafe_b64encode(digest).decode().rstrip("=")
elif method == "plain":
return verifier
else:
raise ValueError(f"Unsupported code challenge method: {method}")
@classmethod
def validate_code_verifier(cls, code_verifier: str, code_challenge: str,
method: str = "S256") -> bool:
"""Validate a PKCE code verifier against the stored challenge.
Args:
code_verifier: The code verifier from the token request
code_challenge: The code challenge from the authorization request
method: The challenge method used
Returns:
True if validation succeeds
"""
if not code_verifier or not code_challenge:
return False
# Validate code verifier length (43-128 characters)
if method == "S256" and not (43 <= len(code_verifier) <= 128):
return False
# Calculate expected challenge
expected_challenge = cls._generate_code_challenge(code_verifier, method)
return secrets.compare_digest(expected_challenge, code_challenge)
@classmethod
def create_session(
cls,
user_id: str,
client_id: str,
state: str = None,
nonce: str = None,
redirect_uri: str = None,
scope: list = None,
code_challenge: str = None,
code_challenge_method: str = None,
lifetime_seconds: int = 600
) -> OIDCSession:
"""Create a new OIDC session for the authorization flow.
Args:
user_id: The user ID
client_id: The OIDC client ID
state: State parameter (generated if not provided)
nonce: Nonce for ID token validation (generated if not provided)
redirect_uri: Redirect URI from authorization request
scope: Requested scopes
code_challenge: PKCE code challenge
code_challenge_method: PKCE method ("S256" or "plain")
lifetime_seconds: Session lifetime in seconds
Returns:
OIDCSession instance
"""
# Generate state and nonce if not provided
state = state or cls._generate_state()
nonce = nonce or cls._generate_nonce()
session = OIDCSession.create_session(
user_id=user_id,
client_id=client_id,
state=state,
nonce=nonce,
redirect_uri=redirect_uri,
scope=scope,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
lifetime_seconds=lifetime_seconds,
)
return session
@classmethod
def validate_session(cls, state: str, nonce: str = None) -> Tuple[OIDCSession, User]:
"""Validate an OIDC session by state and optionally nonce.
Args:
state: The state parameter
nonce: The nonce to validate (optional)
Returns:
Tuple of (OIDCSession, User)
Raises:
ValidationError: If session is invalid
NotFoundError: If session not found
"""
session = OIDCSession.get_by_state(state)
if not session:
raise NotFoundError("OIDC session not found or expired")
if session.is_expired():
raise ValidationError("OIDC session has expired")
# Validate nonce if provided
if nonce and not session.validate_nonce(nonce):
raise ValidationError("Invalid nonce")
# Get user
user = User.query.get(session.user_id)
if not user:
raise NotFoundError("User not found")
return session, user
@classmethod
def validate_pkce(cls, session: OIDCSession, code_verifier: str) -> bool:
"""Validate PKCE code verifier against the session's code challenge.
Args:
session: OIDCSession instance
code_verifier: The code verifier from token request
Returns:
True if validation succeeds
Raises:
ValidationError: If PKCE validation fails
"""
if not session.code_challenge:
# No PKCE was used, skip validation
return True
if not code_verifier:
raise ValidationError("code_verifier is required")
is_valid = session.validate_code_challenge(code_verifier)
if not is_valid:
raise ValidationError("Invalid code_verifier")
return True
@classmethod
def mark_session_authenticated(cls, session: OIDCSession) -> OIDCSession:
"""Mark a session as authenticated (user has logged in).
Args:
session: OIDCSession instance
Returns:
Updated OIDCSession instance
"""
session.mark_authenticated()
return session
@classmethod
def cleanup_expired_sessions(cls, older_than_hours: int = 24) -> int:
"""Remove expired OIDC sessions.
Args:
older_than_hours: Only delete sessions expired more than this many hours ago
Returns:
Number of sessions deleted
"""
from datetime import timedelta
cutoff = datetime.now(timezone.utc) - timedelta(hours=older_than_hours)
# Get expired sessions
expired_sessions = OIDCSession.query.filter(
OIDCSession.expires_at < datetime.now(timezone.utc),
OIDCSession.deleted_at == None
).all()
count = 0
for session in expired_sessions:
# Only hard delete if past the grace period
if session.expires_at < cutoff:
session.delete()
count += 1
return count
@classmethod
def get_session_by_state(cls, state: str) -> Optional[OIDCSession]:
"""Get an OIDC session by state.
Args:
state: The state parameter
Returns:
OIDCSession instance or None
"""
return OIDCSession.get_by_state(state)
@classmethod
def validate_redirect_uri(cls, client_id: str, redirect_uri: str) -> bool:
"""Validate that a redirect URI is allowed for a client.
Args:
client_id: The OIDC client ID
redirect_uri: The redirect URI to validate
Returns:
True if redirect URI is allowed
"""
client = OIDCClient.query.filter_by(client_id=client_id).first()
if not client:
return False
return client.is_redirect_uri_allowed(redirect_uri)
@classmethod
def validate_scopes(cls, client_id: str, requested_scopes: list) -> list:
"""Validate and filter scopes against client's allowed scopes.
Args:
client_id: The OIDC client ID
requested_scopes: List of requested scopes
Returns:
List of allowed scopes
"""
client = OIDCClient.query.filter_by(client_id=client_id).first()
if not client:
return []
allowed_scopes = client.scopes or []
# Filter to only allowed scopes
valid_scopes = [s for s in requested_scopes if s in allowed_scopes]
return valid_scopes
@@ -0,0 +1,593 @@
"""OIDC Token Service for JWT token generation and validation."""
import hashlib
import base64
import secrets
import logging
import time
from datetime import datetime, timedelta, timezone
from typing import Dict, Optional, Any
import jwt
from flask import current_app, g
from gatehouse_app.models import User, OIDCClient
from gatehouse_app.models.organization_member import OrganizationMember
from gatehouse_app.services.oidc_jwks_service import OIDCJWKSService
logger = logging.getLogger(__name__)
class OIDCTokenService:
"""Service for generating and validating OIDC tokens.
This service handles:
- Access token creation (JWT)
- ID token creation (JWT)
- Refresh token creation (opaque)
- Token signature verification
- Hash generation for PKCE claims (at_hash, c_hash)
"""
@staticmethod
def _generate_jti() -> str:
"""Generate a unique JWT ID."""
return secrets.token_urlsafe(32)
@staticmethod
def _generate_opaque_token(length: int = 43) -> str:
"""Generate an opaque token (for refresh tokens).
Args:
length: Length of the token
Returns:
URL-safe base64 encoded token
"""
return secrets.token_urlsafe(length)
@staticmethod
def _hash_token(token: str) -> str:
"""Hash a token for secure storage.
Args:
token: Token to hash
Returns:
SHA256 hash of the token
"""
return hashlib.sha256(token.encode()).hexdigest()
@staticmethod
def _base64url_encode(data: bytes) -> str:
"""Encode bytes to base64url format without padding.
Args:
data: Bytes to encode
Returns:
Base64url encoded string
"""
return base64.urlsafe_b64encode(data).decode().rstrip("=")
@staticmethod
def create_at_hash(access_token: str) -> str:
"""Create the at_hash claim for ID token.
Implements OIDC spec for access token hash generation.
Hash is the left-most half of the hash of the ASCII representation
of the access token.
Args:
access_token: The access token string
Returns:
Base64url encoded hash
"""
# Hash the access token using SHA256
hash_digest = hashlib.sha256(access_token.encode()).digest()
# Take left-most half of the hash
half_length = len(hash_digest) // 2
left_half = hash_digest[:half_length]
# Base64url encode
return OIDCTokenService._base64url_encode(left_half)
@staticmethod
def create_c_hash(code: str) -> str:
"""Create the c_hash claim for ID token.
Implements OIDC spec for authorization code hash generation.
Args:
code: The authorization code string
Returns:
Base64url encoded hash
"""
# Hash the code using SHA256
hash_digest = hashlib.sha256(code.encode()).digest()
# Take left-most half of the hash
half_length = len(hash_digest) // 2
left_half = hash_digest[:half_length]
# Base64url encode
return OIDCTokenService._base64url_encode(left_half)
@staticmethod
def _get_issuer() -> str:
"""Get the OIDC issuer URL."""
return current_app.config.get("OIDC_ISSUER_URL", "http://localhost:5000")
@staticmethod
def _get_token_lifetime(client: OIDCClient, token_type: str) -> int:
"""Get the token lifetime in seconds for a client.
Args:
client: OIDCClient instance
token_type: Type of token ("access_token", "refresh_token", "id_token")
Returns:
Lifetime in seconds
"""
lifetimes = {
"access_token": client.access_token_lifetime or 3600,
"refresh_token": client.refresh_token_lifetime or 2592000,
"id_token": client.id_token_lifetime or 3600,
}
return lifetimes.get(token_type, 3600)
@classmethod
def create_access_token(cls, client_id: str, user_id: str, scope: list,
jti: str = None) -> str:
"""Create a JWT access token.
Args:
client_id: The OIDC client ID
user_id: The user ID (subject)
scope: List of granted scopes
jti: Optional JWT ID (generated if not provided)
Returns:
JWT access token string
"""
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] create_access_token called")
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
logger.debug("[OIDC TOKEN SERVICE] scope=%s", scope)
jti = jti or cls._generate_jti()
now_timestamp = int(time.time())
now = datetime.now(timezone.utc)
logger.debug("[OIDC TOKEN SERVICE] Token creation time (UTC): %s", now.isoformat())
logger.debug("[OIDC TOKEN SERVICE] Token creation timestamp: %s", now_timestamp)
# Get client for token lifetime
client = OIDCClient.query.filter_by(client_id=client_id).first()
lifetime = cls._get_token_lifetime(client, "access_token") if client else 3600
logger.debug("[OIDC TOKEN SERVICE] Access token lifetime (seconds): %s", lifetime)
exp_timestamp = now_timestamp + lifetime
exp_time = now + timedelta(seconds=lifetime)
logger.debug("[OIDC TOKEN SERVICE] Access token expiration time (UTC): %s", exp_time.isoformat())
logger.debug("[OIDC TOKEN SERVICE] Access token expiration timestamp: %s", exp_timestamp)
logger.debug("[OIDC TOKEN SERVICE] Time until expiration (seconds): %s", lifetime)
claims = {
"iss": cls._get_issuer(),
"sub": user_id,
"aud": client_id,
"exp": exp_timestamp,
"iat": now_timestamp,
"nbf": now_timestamp,
"jti": jti,
"client_id": client_id,
"scope": " ".join(scope) if isinstance(scope, list) else scope,
}
logger.debug("[OIDC TOKEN SERVICE] Token claims: exp=%s, iat=%s, nbf=%s",
claims["exp"], claims["iat"], claims["nbf"])
# Get signing key
jwks_service = OIDCJWKSService()
signing_key = jwks_service.get_signing_key()
if not signing_key:
raise ValueError("No signing key available")
# Sign with RS256
logger.debug("[OIDC TOKEN SERVICE] Signing token with RS256...")
token = jwt.encode(
claims,
signing_key.private_key,
algorithm="RS256",
headers={"kid": signing_key.kid}
)
logger.debug("[OIDC TOKEN SERVICE] Access token created successfully")
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return token
@classmethod
def create_id_token(cls, client_id: str, user_id: str, nonce: str = None,
scope: list = None, access_token: str = None,
auth_time: int = None) -> str:
"""Create a JWT ID token.
Args:
client_id: The OIDC client ID
user_id: The user ID (subject)
nonce: Nonce for replay protection
scope: Requested/Granted scopes
access_token: Associated access token (for at_hash)
auth_time: Authentication time (Unix timestamp)
Returns:
JWT ID token string
"""
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] create_id_token called")
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
logger.debug("[OIDC TOKEN SERVICE] nonce=%s, auth_time=%s", nonce, auth_time)
logger.debug("[OIDC TOKEN SERVICE] scope=%s", scope)
now_timestamp = int(time.time())
now = datetime.now(timezone.utc)
logger.debug("[OIDC TOKEN SERVICE] Token creation time (UTC): %s", now.isoformat())
logger.debug("[OIDC TOKEN SERVICE] Token creation timestamp: %s", now_timestamp)
auth_time = auth_time or now_timestamp
logger.debug("[OIDC TOKEN SERVICE] auth_time (Unix timestamp): %s", auth_time)
# Get client for token lifetime
client = OIDCClient.query.filter_by(client_id=client_id).first()
lifetime = cls._get_token_lifetime(client, "id_token") if client else 3600
logger.debug("[OIDC TOKEN SERVICE] ID token lifetime (seconds): %s", lifetime)
exp_timestamp = now_timestamp + lifetime
exp_time = now + timedelta(seconds=lifetime)
logger.debug("[OIDC TOKEN SERVICE] ID token expiration time (UTC): %s", exp_time.isoformat())
logger.debug("[OIDC TOKEN SERVICE] ID token expiration timestamp: %s", exp_timestamp)
logger.debug("[OIDC TOKEN SERVICE] Time until expiration (seconds): %s", lifetime)
# Get user for claims
user = User.query.get(user_id)
claims = {
"iss": cls._get_issuer(),
"sub": user_id,
"aud": client_id,
"exp": exp_timestamp,
"iat": now_timestamp,
"auth_time": auth_time,
}
logger.debug("[OIDC TOKEN SERVICE] Token claims: exp=%s, iat=%s, auth_time=%s",
claims["exp"], claims["iat"], claims["auth_time"])
# Add nonce if provided
if nonce:
claims["nonce"] = nonce
# Add at_hash if access token provided
if access_token:
claims["at_hash"] = cls.create_at_hash(access_token)
# Add standard claims if user exists
if user:
if user.email:
claims["email"] = user.email
claims["email_verified"] = user.email_verified
if user.full_name:
claims["name"] = user.full_name
# Add roles claim if scope is granted
if scope and "roles" in scope:
claims["roles"] = cls._get_user_roles(user)
# Add scope if provided
if scope:
claims["scope"] = " ".join(scope) if isinstance(scope, list) else scope
# Get signing key
jwks_service = OIDCJWKSService()
signing_key = jwks_service.get_signing_key()
if not signing_key:
raise ValueError("No signing key available")
# Sign with RS256
logger.debug("[OIDC TOKEN SERVICE] Signing token with RS256...")
token = jwt.encode(
claims,
signing_key.private_key,
algorithm="RS256",
headers={"kid": signing_key.kid}
)
logger.debug("[OIDC TOKEN SERVICE] ID token created successfully")
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return token
@staticmethod
def _get_user_roles(user: User) -> list:
"""Get user's organization roles.
Args:
user: User instance
Returns:
List of role objects with organization_id and role
"""
roles = []
if user and user.organization_memberships:
for member in user.organization_memberships:
roles.append({
"organization_id": str(member.organization_id),
"role": member.role.value
})
return roles
@classmethod
def create_refresh_token(cls, client_id: str, user_id: str,
scope: list = None, access_token_id: str = None) -> str:
"""Create an opaque refresh token.
Args:
client_id: The OIDC client ID
user_id: The user ID
scope: List of granted scopes
access_token_id: Associated access token ID
Returns:
Opaque refresh token string
"""
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] create_refresh_token called")
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
logger.debug("[OIDC TOKEN SERVICE] scope=%s, access_token_id=%s", scope, access_token_id)
token = cls._generate_opaque_token()
logger.debug("[OIDC TOKEN SERVICE] Refresh token generated: %s...", token[:20] if token else None)
# Hash for storage
token_hash = cls._hash_token(token)
logger.debug("[OIDC TOKEN SERVICE] Refresh token created successfully")
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return token, token_hash
@classmethod
def verify_token_signature(cls, token: str) -> Dict:
"""Verify the signature of a JWT token.
Args:
token: JWT token string
Returns:
Decoded token claims
Raises:
jwt.InvalidSignatureError: If signature verification fails
jwt.ExpiredSignatureError: If token is expired
jwt.InvalidTokenError: If token is invalid
"""
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] verify_token_signature() called")
logger.debug("[OIDC TOKEN SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
logger.debug("[OIDC TOKEN SERVICE] Token length: %d", len(token))
# Get the JWKS with public keys
logger.debug("[OIDC TOKEN SERVICE] Getting JWKS...")
jwks_service = OIDCJWKSService()
jwks = jwks_service.get_jwks(include_private_keys=True)
logger.debug("[OIDC TOKEN SERVICE] JWKS retrieved: %d keys", len(jwks.get("keys", [])))
# Get the key ID from token header
try:
logger.debug("[OIDC TOKEN SERVICE] Getting unverified token header...")
unverified_header = jwt.get_unverified_header(token)
logger.debug("[OIDC TOKEN SERVICE] Unverified header: %s", unverified_header)
except jwt.DecodeError as e:
logger.error("[OIDC TOKEN SERVICE] Failed to decode token header: %s", str(e))
raise jwt.InvalidTokenError("Invalid token header")
kid = unverified_header.get("kid")
logger.debug("[OIDC TOKEN SERVICE] Key ID (kid) from token header: %s", kid)
# Find the matching public key
logger.debug("[OIDC TOKEN SERVICE] Searching for matching public key...")
public_key = None
for idx, key in enumerate(jwks.get("keys", [])):
logger.debug("[OIDC TOKEN SERVICE] Checking key %d: kid=%s", idx, key.get("kid"))
if key.get("kid") == kid:
logger.debug("[OIDC TOKEN SERVICE] Found matching key at index %d", idx)
try:
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
logger.debug("[OIDC TOKEN SERVICE] Loading PEM public key...")
public_key = serialization.load_pem_public_key(
key["public_key"].encode() if isinstance(key["public_key"], str)
else key["public_key"],
backend=default_backend()
)
logger.debug("[OIDC TOKEN SERVICE] Public key loaded successfully")
break
except (ImportError, Exception) as e:
logger.error("[OIDC TOKEN SERVICE] Failed to load public key: %s: %s", type(e).__name__, str(e))
continue
if not public_key:
logger.error("[OIDC TOKEN SERVICE] No matching public key found for kid=%s", kid)
raise jwt.InvalidSignatureError(f"Key with kid={kid} not found")
logger.debug("[OIDC TOKEN SERVICE] Public key found, verifying signature...")
# Verify the signature
try:
claims = jwt.decode(
token,
public_key,
algorithms=["RS256"],
audience=None, # We'll validate audience separately
issuer=cls._get_issuer(),
options={
"verify_signature": True,
"verify_exp": True,
"verify_aud": False, # Handle audience manually
"verify_iss": False, # Handle issuer manually
}
)
logger.debug("[OIDC TOKEN SERVICE] Signature verification successful")
logger.debug("[OIDC TOKEN SERVICE] Decoded claims: %s", claims)
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return claims
except jwt.ExpiredSignatureError as e:
logger.error("[OIDC TOKEN SERVICE] Token has expired: %s", str(e))
raise
except jwt.InvalidSignatureError as e:
logger.error("[OIDC TOKEN SERVICE] Invalid token signature: %s", str(e))
raise
except jwt.InvalidTokenError as e:
logger.error("[OIDC TOKEN SERVICE] Invalid token: %s: %s", type(e).__name__, str(e))
raise
except Exception as e:
logger.error("[OIDC TOKEN SERVICE] Unexpected error during token verification: %s: %s", type(e).__name__, str(e))
import traceback
logger.error("[OIDC TOKEN SERVICE] Traceback: %s", traceback.format_exc())
raise
@classmethod
def decode_token(cls, token: str, verify: bool = False) -> Dict:
"""Decode a JWT token without verification (for debugging).
Args:
token: JWT token string
verify: Whether to verify signature
Returns:
Decoded token claims
"""
if verify:
return cls.verify_token_signature(token)
return jwt.decode(
token,
options={
"verify_signature": False,
"verify_exp": False,
}
)
@classmethod
def validate_access_token(cls, token: str, client_id: str = None) -> Dict:
"""Validate an access token and return its claims.
Args:
token: JWT access token
client_id: Optional client ID to validate audience
Returns:
Token claims dictionary
Raises:
jwt.InvalidTokenError: If token is invalid
ValueError: If token is expired or audience mismatch
"""
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
logger.debug("[OIDC TOKEN SERVICE] validate_access_token() called")
logger.debug("[OIDC TOKEN SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
logger.debug("[OIDC TOKEN SERVICE] Token length: %d", len(token))
logger.debug("[OIDC TOKEN SERVICE] Client ID: %s", client_id)
# Verify token signature
logger.debug("[OIDC TOKEN SERVICE] Verifying token signature...")
claims = cls.verify_token_signature(token)
logger.debug("[OIDC TOKEN SERVICE] Token signature verified")
logger.debug("[OIDC TOKEN SERVICE] Claims: %s", claims)
# Check expiration
exp = claims.get("exp", 0)
now_timestamp = int(time.time())
if exp < now_timestamp:
logger.error("[OIDC TOKEN SERVICE] Token has expired")
raise ValueError("Token has expired")
# Validate audience if client_id provided
aud = claims.get("aud")
logger.debug("[OIDC TOKEN SERVICE] Token audience (aud): %s", aud)
logger.debug("[OIDC TOKEN SERVICE] Expected client_id: %s", client_id)
if client_id:
if aud != client_id:
logger.error("[OIDC TOKEN SERVICE] Audience mismatch: expected=%s, got=%s", client_id, aud)
raise ValueError("Invalid audience")
logger.debug("[OIDC TOKEN SERVICE] Audience validation passed")
else:
logger.debug("[OIDC TOKEN SERVICE] No client_id provided, skipping audience validation")
logger.debug("[OIDC TOKEN SERVICE] validate_access_token() completed successfully")
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
return claims
@classmethod
def introspect_token(cls, token: str, client_id: str = None) -> Dict:
"""Introspect a token and return its status and claims.
Args:
token: JWT token to introspect
client_id: Client ID for audience validation
Returns:
Dictionary with active status and claims
"""
result = {
"active": False,
}
try:
claims = cls.validate_access_token(token, client_id)
# Calculate remaining time
now_timestamp = int(time.time())
now = datetime.now(timezone.utc)
exp = claims.get("exp", 0)
iat = claims.get("iat", 0)
logger.debug("[OIDC TOKEN SERVICE] Introspection - Current UTC time: %s", now.isoformat())
logger.debug("[OIDC TOKEN SERVICE] Introspection - Token expiration timestamp: %s", exp)
logger.debug("[OIDC TOKEN SERVICE] Introspection - Token expiration datetime (UTC): %s", datetime.fromtimestamp(exp, tz=timezone.utc).isoformat())
logger.debug("[OIDC TOKEN SERVICE] Introspection - Time until expiration: %s seconds", exp - now_timestamp)
result["active"] = exp > now_timestamp
result.update({
"iss": claims.get("iss"),
"sub": claims.get("sub"),
"aud": claims.get("aud"),
"exp": exp,
"iat": iat,
"nbf": claims.get("nbf"),
"jti": claims.get("jti"),
"client_id": claims.get("client_id"),
"scope": claims.get("scope"),
"token_type": "Bearer",
})
# Add expiry in seconds
if exp > now_timestamp:
result["exp"] = int(exp - now_timestamp)
except (jwt.InvalidTokenError, ValueError) as e:
result["active"] = False
result["error"] = str(e)
return result
@@ -0,0 +1,303 @@
"""Organization service."""
import logging
from datetime import datetime, timezone
from flask import current_app
from gatehouse_app.extensions import db
from gatehouse_app.models.organization import Organization
from gatehouse_app.models.organization_member import OrganizationMember
from gatehouse_app.exceptions.validation_exceptions import OrganizationNotFoundError, ConflictError
from gatehouse_app.utils.constants import OrganizationRole, AuditAction
from gatehouse_app.services.audit_service import AuditService
logger = logging.getLogger(__name__)
class OrganizationService:
"""Service for organization operations."""
@staticmethod
def create_organization(name, slug, owner_user_id, description=None, logo_url=None):
"""
Create a new organization.
Args:
name: Organization name
slug: Unique organization slug
owner_user_id: ID of the user who will be the owner
description: Optional description
logo_url: Optional logo URL
Returns:
Organization instance
Raises:
ConflictError: If slug already exists
"""
# Check if slug already exists
existing = Organization.query.filter_by(slug=slug, deleted_at=None).first()
if existing:
raise ConflictError("Organization slug already exists")
# Create organization
org = Organization(
name=name,
slug=slug,
description=description,
logo_url=logo_url,
is_active=True,
)
org.save()
# Add owner as member
member = OrganizationMember(
user_id=owner_user_id,
organization_id=org.id,
role=OrganizationRole.OWNER,
joined_at=datetime.now(timezone.utc),
)
member.save()
# Log organization creation
AuditService.log_action(
action=AuditAction.ORG_CREATE,
user_id=owner_user_id,
organization_id=org.id,
resource_type="organization",
resource_id=org.id,
description=f"Organization created: {name}",
)
return org
@staticmethod
def get_organization_by_id(org_id):
"""
Get organization by ID.
Args:
org_id: Organization ID
Returns:
Organization instance
Raises:
OrganizationNotFoundError: If organization not found
"""
org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
# Development-only debug logging for organization validation
if current_app.config.get('ENV') == 'development':
logger.debug(f"[Org] Get organization by ID: org_id={org_id}, exists={org is not None}")
if not org:
raise OrganizationNotFoundError()
return org
@staticmethod
def get_organization_by_slug(slug):
"""
Get organization by slug.
Args:
slug: Organization slug
Returns:
Organization instance or None
"""
org = Organization.query.filter_by(slug=slug, deleted_at=None).first()
# Development-only debug logging for organization validation
if current_app.config.get('ENV') == 'development':
logger.debug(f"[Org] Get organization by slug: slug={slug}, exists={org is not None}")
return org
@staticmethod
def update_organization(org, user_id, **kwargs):
"""
Update organization.
Args:
org: Organization instance
user_id: ID of user performing the update
**kwargs: Fields to update
Returns:
Updated Organization instance
"""
allowed_fields = ["name", "description", "logo_url"]
update_data = {k: v for k, v in kwargs.items() if k in allowed_fields}
if update_data:
org.update(**update_data)
# Log organization update
AuditService.log_action(
action=AuditAction.ORG_UPDATE,
user_id=user_id,
organization_id=org.id,
resource_type="organization",
resource_id=org.id,
metadata=update_data,
description="Organization updated",
)
return org
@staticmethod
def delete_organization(org, user_id, soft=True):
"""
Delete organization.
Args:
org: Organization instance
user_id: ID of user performing the delete
soft: If True, performs soft delete
Returns:
Deleted Organization instance
"""
org.delete(soft=soft)
# Log organization deletion
AuditService.log_action(
action=AuditAction.ORG_DELETE,
user_id=user_id,
organization_id=org.id,
resource_type="organization",
resource_id=org.id,
description=f"Organization {'soft' if soft else 'hard'} deleted",
)
return org
@staticmethod
def add_member(org, user_id, role, inviter_id):
"""
Add a member to the organization.
Args:
org: Organization instance
user_id: ID of user to add
role: OrganizationRole
inviter_id: ID of user performing the invitation
Returns:
OrganizationMember instance
Raises:
ConflictError: If user is already a member
"""
# Check if already a member
existing = OrganizationMember.query.filter_by(
user_id=user_id,
organization_id=org.id,
deleted_at=None,
).first()
# Development-only debug logging for membership validation
if current_app.config.get('ENV') == 'development':
logger.debug(f"[Org] Member check: org_id={org.id}, user_id={user_id}, already_member={existing is not None}")
if existing:
raise ConflictError("User is already a member of this organization")
# Create membership
member = OrganizationMember(
user_id=user_id,
organization_id=org.id,
role=role,
invited_by_id=inviter_id,
invited_at=datetime.now(timezone.utc),
joined_at=datetime.now(timezone.utc),
)
member.save()
# Log member addition
AuditService.log_action(
action=AuditAction.ORG_MEMBER_ADD,
user_id=inviter_id,
organization_id=org.id,
resource_type="organization_member",
resource_id=member.id,
metadata={"added_user_id": user_id, "role": role.value},
description=f"Member added to organization with role: {role.value}",
)
return member
@staticmethod
def remove_member(org, user_id, remover_id):
"""
Remove a member from the organization.
Args:
org: Organization instance
user_id: ID of user to remove
remover_id: ID of user performing the removal
"""
member = OrganizationMember.query.filter_by(
user_id=user_id,
organization_id=org.id,
deleted_at=None,
).first()
# Development-only debug logging for membership removal validation
if current_app.config.get('ENV') == 'development':
logger.debug(f"[Org] Member removal: org_id={org.id}, user_id={user_id}, found={member is not None}")
if member:
member.delete(soft=True)
# Log member removal
AuditService.log_action(
action=AuditAction.ORG_MEMBER_REMOVE,
user_id=remover_id,
organization_id=org.id,
resource_type="organization_member",
resource_id=member.id,
metadata={"removed_user_id": user_id},
description="Member removed from organization",
)
@staticmethod
def update_member_role(org, user_id, new_role, updater_id):
"""
Update a member's role in the organization.
Args:
org: Organization instance
user_id: ID of user whose role to update
new_role: New OrganizationRole
updater_id: ID of user performing the update
Returns:
Updated OrganizationMember instance
"""
member = OrganizationMember.query.filter_by(
user_id=user_id,
organization_id=org.id,
deleted_at=None,
).first()
if member:
old_role = member.role
member.role = new_role
db.session.commit()
# Log role change
AuditService.log_action(
action=AuditAction.ORG_MEMBER_ROLE_CHANGE,
user_id=updater_id,
organization_id=org.id,
resource_type="organization_member",
resource_id=member.id,
metadata={
"target_user_id": user_id,
"old_role": old_role.value,
"new_role": new_role.value,
},
description=f"Member role changed from {old_role.value} to {new_role.value}",
)
return member
+76
View File
@@ -0,0 +1,76 @@
"""Session service."""
from datetime import datetime, timezone
from gatehouse_app.models.session import Session
from gatehouse_app.utils.constants import SessionStatus
class SessionService:
"""Service for session operations."""
@staticmethod
def get_active_session_by_token(token):
"""Get active session by token.
Args:
token: The session token string
Returns:
Session object if found and active, None otherwise
"""
from gatehouse_app.models.session import Session
from gatehouse_app.utils.constants import SessionStatus
return Session.query.filter_by(
token=token,
status=SessionStatus.ACTIVE,
deleted_at=None
).first()
@staticmethod
def get_user_sessions(user_id, active_only=True):
"""
Get all sessions for a user.
Args:
user_id: User ID
active_only: If True, only return active sessions
Returns:
List of Session instances
"""
query = Session.query.filter_by(user_id=user_id, deleted_at=None)
if active_only:
query = query.filter_by(status=SessionStatus.ACTIVE).filter(
Session.expires_at > datetime.now(timezone.utc)
)
return query.all()
@staticmethod
def revoke_user_sessions(user_id, reason="User logged out from all devices"):
"""
Revoke all active sessions for a user.
Args:
user_id: User ID
reason: Reason for revocation
"""
sessions = SessionService.get_user_sessions(user_id, active_only=True)
for session in sessions:
session.revoke(reason=reason)
@staticmethod
def cleanup_expired_sessions():
"""Clean up expired sessions."""
expired_sessions = Session.query.filter(
Session.status == SessionStatus.ACTIVE,
Session.expires_at < datetime.now(timezone.utc),
Session.deleted_at.is_(None),
).all()
for session in expired_sessions:
session.status = SessionStatus.EXPIRED
session.save()
return len(expired_sessions)
+214
View File
@@ -0,0 +1,214 @@
"""TOTP (Time-based One-Time Password) service."""
import base64
import io
import logging
import secrets
from datetime import datetime, timezone
from typing import Tuple
import pyotp
from gatehouse_app.extensions import bcrypt
logger = logging.getLogger(__name__)
class TOTPService:
"""Service for TOTP operations."""
@staticmethod
def generate_secret() -> str:
"""
Generate a new TOTP secret.
Returns:
Base32 encoded secret (32 characters)
Note:
The secret is generated using cryptographically secure random bytes
and encoded in base32 format for compatibility with authenticator apps.
"""
# Generate 20 random bytes (160 bits) and encode as base32
random_bytes = secrets.token_bytes(20)
secret = base64.b32encode(random_bytes).decode("utf-8")
logger.debug(f"Generated new TOTP secret: {secret[:8]}...")
return secret
@staticmethod
def generate_provisioning_uri(user_email: str, secret: str, issuer: str = "Gatehouse") -> str:
"""
Generate provisioning URI for QR code.
Args:
user_email: User's email address
secret: TOTP secret (base32 encoded)
issuer: Issuer name (default: "Gatehouse")
Returns:
otpauth:// URI for QR code generation
Example:
>>> uri = TOTPService.generate_provisioning_uri("user@example.com", "JBSWY3DPEHPK3PXP")
>>> print(uri)
otpauth://totp/Gatehouse:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse
"""
totp = pyotp.TOTP(secret)
uri = totp.provisioning_uri(name=user_email, issuer_name=issuer)
logger.debug(f"Generated provisioning URI for user: {user_email}")
return uri
@staticmethod
def verify_code(secret: str, code: str, window: int = 1) -> bool:
"""
Verify a TOTP code against the secret.
Args:
secret: TOTP secret (base32 encoded)
code: 6-digit TOTP code to verify
window: Time window for code validation (default: 1, allows codes from previous/next time steps)
Returns:
True if code is valid, False otherwise
Note:
The window parameter allows for clock skew between the server
and the authenticator app. A window of 1 allows codes from
the previous, current, and next 30-second intervals.
IMPORTANT: Always uses UTC time for verification to ensure
consistency across all timezones.
"""
totp = pyotp.TOTP(secret)
# Use timezone-aware UTC datetime for verification
# IMPORTANT: We must pass a datetime object, NOT a Unix timestamp
# pyotp's internal datetime.utcfromtimestamp() is deprecated and can be
# affected by local timezone settings, causing the 10.5 hour skew issue
utc_now = datetime.now(timezone.utc)
# DEBUG: Log detailed timezone information
logger.debug(f"[TOTP DEBUG] UTC now: {utc_now}")
logger.debug(f"[TOTP DEBUG] UTC now isoformat: {utc_now.isoformat()}")
logger.debug(f"[TOTP DEBUG] UTC timestamp: {utc_now.timestamp()}")
logger.debug(f"[TOTP DEBUG] UTC now tzinfo: {utc_now.tzinfo}")
# Generate what the TOTP code should be at this moment using UTC datetime
expected_code = totp.at(utc_now)
logger.debug(f"[TOTP DEBUG] Expected TOTP code at UTC: {expected_code}")
# Verify with the provided code using UTC datetime object
# Passing a datetime object avoids pyotp's utcfromtimestamp() issues
is_valid = totp.verify(code, valid_window=window, for_time=utc_now)
logger.debug(f"[TOTP DEBUG] TOTP code verification: valid={is_valid}, window={window}")
logger.debug(f"[TOTP DEBUG] Provided code: {code}, Expected code: {expected_code}")
return is_valid
@staticmethod
def generate_backup_codes(count: int = 10) -> Tuple[list[str], list[str]]:
"""
Generate backup codes for TOTP recovery.
Args:
count: Number of backup codes to generate (default: 10)
Returns:
Tuple of (plain_codes, hashed_codes)
- plain_codes: List of plain text backup codes (for display to user)
- hashed_codes: List of bcrypt hashed backup codes (for storage)
Note:
Backup codes are 16-character alphanumeric codes that can be used
to recover access if the TOTP device is lost. Each code can only
be used once.
"""
plain_codes = []
hashed_codes = []
for _ in range(count):
# Generate a 16-character alphanumeric code
code = secrets.token_hex(8).upper()
plain_codes.append(code)
# Hash the code using bcrypt
hashed_code = bcrypt.generate_password_hash(code).decode("utf-8")
hashed_codes.append(hashed_code)
logger.debug(f"Generated {count} backup codes")
return plain_codes, hashed_codes
@staticmethod
def verify_backup_code(hashed_codes: list[str], code: str) -> Tuple[bool, list[str]]:
"""
Verify and consume a backup code.
Args:
hashed_codes: List of bcrypt hashed backup codes
code: Plain text backup code to verify
Returns:
Tuple of (is_valid, remaining_codes)
- is_valid: True if code was valid and consumed, False otherwise
- remaining_codes: List of remaining hashed codes (with consumed code removed)
Note:
Once a backup code is used, it is removed from the list and cannot
be used again. This ensures each code is single-use.
"""
remaining_codes = []
for hashed_code in hashed_codes:
if bcrypt.check_password_hash(hashed_code, code):
# Code found and valid - mark as matched but don't add to remaining codes
matched = True
else:
# Code doesn't match - keep it in remaining codes
remaining_codes.append(hashed_code)
if matched:
return True, remaining_codes
else:
return False, remaining_codes
@staticmethod
def generate_qr_code_data_uri(provisioning_uri: str) -> str:
"""
Generate QR code as data URI for frontend display.
Args:
provisioning_uri: otpauth:// URI to encode in QR code
Returns:
Base64 encoded PNG image as data URI (data:image/png;base64,...)
Note:
If the qrcode library is not installed, returns a placeholder message.
Install with: pip install qrcode[pil]
"""
try:
import qrcode
# Create QR code
qr = qrcode.QRCode(
version=1,
error_correction=qrcode.constants.ERROR_CORRECT_L,
box_size=10,
border=4,
)
qr.add_data(provisioning_uri)
qr.make(fit=True)
# Generate image
img = qr.make_image(fill_color="black", back_color="white")
# Convert to base64
buffer = io.BytesIO()
img.save(buffer, format="PNG")
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
data_uri = f"data:image/png;base64,{img_base64}"
logger.debug("Generated QR code data URI")
return data_uri
except ImportError:
logger.warning("qrcode library not installed, returning placeholder")
return "QR code generation requires the qrcode library. Install with: pip install qrcode[pil]"
+125
View File
@@ -0,0 +1,125 @@
"""User service."""
import logging
from flask import current_app
from gatehouse_app.extensions import db
from gatehouse_app.models.user import User
from gatehouse_app.exceptions.validation_exceptions import UserNotFoundError
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.services.audit_service import AuditService
logger = logging.getLogger(__name__)
class UserService:
"""Service for user operations."""
@staticmethod
def get_user_by_id(user_id):
"""
Get user by ID.
Args:
user_id: User ID
Returns:
User instance
Raises:
UserNotFoundError: If user not found
"""
user = User.query.filter_by(id=user_id, deleted_at=None).first()
# Development-only debug logging for user validation
if current_app.config.get('ENV') == 'development':
logger.debug(f"[User] Get user by ID: user_id={user_id}, exists={user is not None}")
if not user:
raise UserNotFoundError()
return user
@staticmethod
def get_user_by_email(email):
"""
Get user by email.
Args:
email: User email
Returns:
User instance or None
"""
user = User.query.filter_by(email=email.lower(), deleted_at=None).first()
# Development-only debug logging for user validation
if current_app.config.get('ENV') == 'development':
logger.debug(f"[User] Get user by email: email={email}, exists={user is not None}")
return user
@staticmethod
def update_user(user, **kwargs):
"""
Update user profile.
Args:
user: User instance
**kwargs: Fields to update
Returns:
Updated User instance
"""
allowed_fields = ["full_name", "avatar_url"]
update_data = {k: v for k, v in kwargs.items() if k in allowed_fields}
if update_data:
user.update(**update_data)
# Log user update
AuditService.log_action(
action=AuditAction.USER_UPDATE,
user_id=user.id,
resource_type="user",
resource_id=user.id,
metadata=update_data,
description="User profile updated",
)
return user
@staticmethod
def delete_user(user, soft=True):
"""
Delete user account.
Args:
user: User instance
soft: If True, performs soft delete
Returns:
Deleted User instance
"""
user.delete(soft=soft)
# Log user deletion
AuditService.log_action(
action=AuditAction.USER_DELETE,
user_id=user.id,
resource_type="user",
resource_id=user.id,
description=f"User account {'soft' if soft else 'hard'} deleted",
)
return user
@staticmethod
def get_user_organizations(user):
"""
Get all organizations the user is a member of.
Args:
user: User instance
Returns:
List of organizations
"""
return user.get_organizations()
+647
View File
@@ -0,0 +1,647 @@
"""WebAuthn passkey authentication service."""
import logging
import secrets
import hashlib
import base64
import json
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any, List
from flask import current_app
from gatehouse_app.extensions import db, redis_client
from gatehouse_app.models.user import User
from gatehouse_app.models.authentication_method import AuthenticationMethod
from gatehouse_app.utils.constants import AuthMethodType, AuditAction
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
from gatehouse_app.services.audit_service import AuditService
logger = logging.getLogger(__name__)
class WebAuthnService:
"""Service for WebAuthn passkey operations."""
# WebAuthn algorithm constants (COSE algorithms)
COSE_ALGORITHMS = {
-7: "ES256", # ECDSA with SHA-256
-257: "RS256", # RSASSA-PKCS1-v1_5 with SHA-256
}
# Supported key types
KEY_TYPES = ["public-key"]
@staticmethod
def _generate_challenge() -> str:
"""Generate a cryptographically secure challenge.
Returns:
Base64URL-encoded challenge string
"""
bytes_data = secrets.token_bytes(32)
return base64.urlsafe_b64encode(bytes_data).decode('utf-8').rstrip('=')
@staticmethod
def _store_challenge(user_id: str, challenge: str, challenge_type: str, expires_in: int = 300) -> bool:
"""Store a challenge in Redis for validation.
Args:
user_id: User ID
challenge: The challenge string
challenge_type: Type of challenge ('registration' or 'authentication')
expires_in: Expiration time in seconds
Returns:
True if stored successfully
"""
try:
key = f"webauthn:challenge:{user_id}:{challenge_type}:{challenge}"
data = {
"challenge": challenge,
"user_id": user_id,
"type": challenge_type,
"created_at": datetime.now(timezone.utc).isoformat()
}
redis_client.setex(key, expires_in, json.dumps(data))
return True
except Exception as e:
logger.error(f"Failed to store WebAuthn challenge: {e}")
return False
@staticmethod
def _get_and_delete_challenge(user_id: str, challenge: str, challenge_type: str) -> Optional[Dict]:
"""Retrieve and delete a challenge from Redis.
Args:
user_id: User ID
challenge: The challenge string
challenge_type: Type of challenge
Returns:
Challenge data dict or None if not found/expired
"""
try:
key = f"webauthn:challenge:{user_id}:{challenge_type}:{challenge}"
data = redis_client.get(key)
if data:
redis_client.delete(key)
return json.loads(data)
return None
except Exception as e:
logger.error(f"Failed to retrieve WebAuthn challenge: {e}")
return None
@staticmethod
def _base64url_decode(data: str) -> bytes:
"""Decode Base64URL string to bytes."""
# Add padding if needed
padding = 4 - (len(data) % 4)
if padding != 4:
data += '=' * padding
return base64.urlsafe_b64decode(data)
@staticmethod
def _base64url_encode(data: bytes) -> str:
"""Encode bytes to Base64URL string."""
return base64.urlsafe_b64encode(data).decode('utf-8').rstrip('=')
@staticmethod
def _hash_credential_id(credential_id: bytes) -> str:
"""Hash a credential ID for secure storage lookup.
Args:
credential_id: Raw credential ID bytes
Returns:
Hashed credential ID string
"""
return hashlib.sha256(credential_id).hexdigest()
@classmethod
def generate_registration_challenge(cls, user: User) -> Dict[str, Any]:
"""Generate a challenge for passkey registration.
Args:
user: User instance
Returns:
PublicKeyCredentialCreationOptions dict
"""
# Generate challenge
challenge = cls._generate_challenge()
# Store challenge
cls._store_challenge(user.id, challenge, 'registration')
# Get existing credentials to exclude
existing_credentials = cls.get_user_credentials(user)
exclude_credentials = []
for cred in existing_credentials:
if cred.provider_data:
cred_id_b64 = cred.provider_data.get("credential_id")
if cred_id_b64:
try:
cred_id = cls._base64url_decode(cred_id_b64)
transports = cred.provider_data.get("transports", [])
exclude_credentials.append({
"id": cred_id_b64,
"type": "public-key",
"transports": transports
})
except Exception:
pass
# Get RP configuration
rp_id = current_app.config.get('WEBAUTHN_RP_ID', 'localhost')
rp_name = current_app.config.get('WEBAUTHN_RP_NAME', 'Gatehouse')
# Generate user ID (Base64URL encoded)
user_id = cls._base64url_encode(user.id.encode('utf-8'))
# Build options
options = {
"rp": {
"name": rp_name,
"id": rp_id
},
"user": {
"id": user_id,
"name": user.email,
"displayName": user.full_name or user.email
},
"challenge": challenge,
"pubKeyCredParams": [
{"type": "public-key", "alg": -7}, # ES256
{"type": "public-key", "alg": -257} # RS256
],
"timeout": 60000, # 60 seconds
"excludeCredentials": exclude_credentials,
"authenticatorSelection": {
"residentKey": "preferred",
"userVerification": "preferred"
},
"attestation": "none"
}
# Log audit event
AuditService.log_action(
action=AuditAction.WEBAUTHN_REGISTER_INITIATED,
user_id=user.id,
description="WebAuthn registration initiated"
)
return options
@classmethod
def verify_registration_response(
cls,
user: User,
credential_data: Dict[str, Any],
challenge: str
) -> AuthenticationMethod:
"""Verify and store a new passkey credential.
Args:
user: User instance
credential_data: Credential response data from client
challenge: The original challenge string
Returns:
AuthenticationMethod instance
Raises:
InvalidCredentialsError: If verification fails
"""
# Verify and consume challenge
stored_challenge = cls._get_and_delete_challenge(user.id, challenge, 'registration')
if not stored_challenge:
AuditService.log_action(
action=AuditAction.WEBAUTHN_REGISTER_FAILED,
user_id=user.id,
description="Registration failed: challenge expired or invalid"
)
raise InvalidCredentialsError("Challenge expired or invalid")
try:
# Parse credential data
credential_id = credential_data.get("id")
raw_id = credential_data.get("rawId")
response = credential_data.get("response", {})
attestation_object_b64 = response.get("attestationObject")
client_data_json_b64 = response.get("clientDataJSON")
transports = credential_data.get("transports", ["platform"])
if not all([credential_id, raw_id, attestation_object_b64, client_data_json_b64]):
raise InvalidCredentialsError("Missing required credential data")
# Decode attestation object
attestation_object = cls._base64url_decode(attestation_object_b64)
# Parse CBOR attestation object (simplified - in production use cbor2 library)
# The attestation object contains: authData, attStmt, fmt
try:
import cbor2
attestation_dict = cbor2.loads(attestation_object)
except ImportError:
# Fallback: try to parse as simple structure
attestation_dict = {}
logger.warning("cbor2 library not available, using fallback parsing")
# Extract authenticator data
auth_data = attestation_dict.get('authData', b'')
# Parse authenticator data
# Format: RP ID hash (32 bytes) + Flags (1 byte) + Counter (4 bytes) + AAGUID (16 bytes) + Credential ID length (2 bytes) + Credential ID + Public key
if len(auth_data) < 37:
raise InvalidCredentialsError("Invalid authenticator data")
rp_id_hash = auth_data[:32]
flags = auth_data[32]
counter = int.from_bytes(auth_data[33:37], 'big')
aaguid = auth_data[37:53] if len(auth_data) >= 53 else b''
# Extract credential ID length and ID
cred_id_length = int.from_bytes(auth_data[53:55], 'big') if len(auth_data) >= 55 else 0
credential_id_raw = auth_data[55:55+cred_id_length] if cred_id_length > 0 else b''
# Extract public key (COSE format)
public_key_cose = auth_data[55+cred_id_length:]
# Verify client data
client_data_json = cls._base64url_decode(client_data_json_b64)
client_data = json.loads(client_data_json)
# Verify challenge matches
if client_data.get("challenge") != challenge:
raise InvalidCredentialsError("Challenge mismatch")
# Verify origin
expected_origin = current_app.config.get('WEBAUTHN_ORIGIN', 'http://localhost:5173')
if client_data.get("origin") != expected_origin:
logger.warning(f"Origin mismatch: expected {expected_origin}, got {client_data.get('origin')}")
# Don't fail on origin mismatch in development
# Verify user presence and verification
user_present = bool(flags & 0x01)
user_verified = bool(flags & 0x04)
if not user_present:
raise InvalidCredentialsError("User presence not verified")
# Store credential
credential_id_hash = cls._hash_credential_id(credential_id_raw)
# Check if credential already exists
existing = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.WEBAUTHN,
deleted_at=None
).first()
if existing and existing.provider_data:
stored_cred_id = existing.provider_data.get("credential_id", "")
if stored_cred_id == credential_id:
raise InvalidCredentialsError("Credential already registered")
# Create or update authentication method
auth_method = existing or AuthenticationMethod(
user_id=user.id,
method_type=AuthMethodType.WEBAUTHN,
is_primary=False,
verified=True
)
# Store credential data
auth_method.provider_data = {
"credential_id": credential_id,
"credential_id_hash": credential_id_hash,
"public_key_cose": cls._base64url_encode(public_key_cose),
"sign_count": counter,
"transports": transports,
"aaguid": cls._base64url_encode(aaguid) if aaguid else None,
"attestation_format": attestation_dict.get('fmt', 'unknown'),
"created_at": datetime.now(timezone.utc).isoformat(),
"last_used_at": None,
"name": f"Passkey {datetime.now(timezone.utc).strftime('%Y-%m-%d')}"
}
auth_method.save()
# Log audit event
AuditService.log_action(
action=AuditAction.WEBAUTHN_REGISTER_COMPLETED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description=f"WebAuthn credential registered: {credential_id[:16]}..."
)
return auth_method
except InvalidCredentialsError:
raise
except Exception as e:
logger.error(f"WebAuthn registration verification failed: {e}")
AuditService.log_action(
action=AuditAction.WEBAUTHN_REGISTER_FAILED,
user_id=user.id,
description=f"Registration failed: {str(e)}"
)
raise InvalidCredentialsError("Registration verification failed")
@classmethod
def generate_authentication_challenge(cls, user: User) -> Dict[str, Any]:
"""Generate a challenge for passkey authentication.
Args:
user: User instance
Returns:
PublicKeyCredentialRequestOptions dict
"""
# Generate challenge
challenge = cls._generate_challenge()
# Store challenge
cls._store_challenge(user.id, challenge, 'authentication')
# Get user's credentials
credentials = cls.get_user_credentials(user)
# Build allow credentials list
allow_credentials = []
for cred in credentials:
if cred.provider_data:
cred_id = cred.provider_data.get("credential_id")
transports = cred.provider_data.get("transports", [])
if cred_id:
allow_credentials.append({
"id": cred_id,
"type": "public-key",
"transports": transports
})
# Get RP configuration
rp_id = current_app.config.get('WEBAUTHN_RP_ID', 'localhost')
# Build options
options = {
"challenge": challenge,
"timeout": 60000,
"rpId": rp_id,
"allowCredentials": allow_credentials,
"userVerification": "preferred"
}
# Log audit event
AuditService.log_action(
action=AuditAction.WEBAUTHN_LOGIN_INITIATED,
user_id=user.id,
description="WebAuthn authentication initiated"
)
return options
@classmethod
def verify_authentication_response(
cls,
user: User,
credential_data: Dict[str, Any],
challenge: str
) -> AuthenticationMethod:
"""Verify passkey authentication response.
Args:
user: User instance
credential_data: Assertion response data from client
challenge: The original challenge string
Returns:
AuthenticationMethod instance
Raises:
InvalidCredentialsError: If verification fails
"""
# Verify and consume challenge
stored_challenge = cls._get_and_delete_challenge(user.id, challenge, 'authentication')
if not stored_challenge:
AuditService.log_action(
action=AuditAction.WEBAUTHN_LOGIN_FAILED,
user_id=user.id,
description="Authentication failed: challenge expired or invalid"
)
raise InvalidCredentialsError("Challenge expired or invalid")
try:
# Parse credential data
credential_id = credential_data.get("id")
raw_id = credential_data.get("rawId")
response = credential_data.get("response", {})
authenticator_data_b64 = response.get("authenticatorData")
client_data_json_b64 = response.get("clientDataJSON")
signature_b64 = response.get("signature")
if not all([credential_id, authenticator_data_b64, client_data_json_b64, signature_b64]):
raise InvalidCredentialsError("Missing required credential data")
# Find the credential
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.WEBAUTHN,
deleted_at=None
).first()
if not auth_method or not auth_method.provider_data:
raise InvalidCredentialsError("No passkey found for user")
stored_cred_id = auth_method.provider_data.get("credential_id")
if stored_cred_id != credential_id:
raise InvalidCredentialsError("Credential not found")
# Decode authenticator data
authenticator_data = cls._base64url_decode(authenticator_data_b64)
# Parse authenticator data
if len(authenticator_data) < 37:
raise InvalidCredentialsError("Invalid authenticator data")
rp_id_hash = authenticator_data[:32]
flags = authenticator_data[32]
counter = int.from_bytes(authenticator_data[33:37], 'big')
# Verify client data
client_data_json = cls._base64url_decode(client_data_json_b64)
client_data = json.loads(client_data_json)
# Verify challenge matches
if client_data.get("challenge") != challenge:
raise InvalidCredentialsError("Challenge mismatch")
# Verify origin
expected_origin = current_app.config.get('WEBAUTHN_ORIGIN', 'http://localhost:5173')
if client_data.get("origin") != expected_origin:
logger.warning(f"Origin mismatch: expected {expected_origin}, got {client_data.get('origin')}")
# Verify user presence
user_present = bool(flags & 0x01)
if not user_present:
raise InvalidCredentialsError("User presence not verified")
# Verify counter (prevent replay attacks)
stored_counter = auth_method.provider_data.get("sign_count", 0)
if counter <= stored_counter:
raise InvalidCredentialsError("Invalid sign counter - potential credential cloning detected")
# Verify signature (simplified - in production use proper crypto verification)
# In a full implementation, you would:
# 1. Decode the public key from COSE format
# 2. Verify the signature using the stored public key
# 3. Verify the authenticator data hash matches RP ID
# For now, we'll trust the authenticator's signature verification
# A full implementation would use the fido2 library
# Update counter and last used time
auth_method.provider_data["sign_count"] = counter
auth_method.provider_data["last_used_at"] = datetime.now(timezone.utc).isoformat()
auth_method.last_used_at = datetime.now(timezone.utc)
db.session.commit()
# Log audit event
AuditService.log_action(
action=AuditAction.WEBAUTHN_LOGIN_SUCCESS,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description="WebAuthn authentication successful"
)
return auth_method
except InvalidCredentialsError:
raise
except Exception as e:
logger.error(f"WebAuthn authentication verification failed: {e}")
AuditService.log_action(
action=AuditAction.WEBAUTHN_LOGIN_FAILED,
user_id=user.id,
description=f"Authentication failed: {str(e)}"
)
raise InvalidCredentialsError("Authentication verification failed")
@classmethod
def get_user_credentials(cls, user: User) -> List[AuthenticationMethod]:
"""Get all passkey credentials for a user.
Args:
user: User instance
Returns:
List of AuthenticationMethod instances
"""
return AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.WEBAUTHN,
deleted_at=None
).order_by(AuthenticationMethod.created_at.desc()).all()
@classmethod
def delete_credential(cls, credential_id: str, user: User) -> bool:
"""Delete a passkey credential.
Args:
credential_id: The credential ID to delete
user: User instance
Returns:
True if deleted successfully
"""
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.WEBAUTHN,
deleted_at=None
).first()
if not auth_method or not auth_method.provider_data:
return False
stored_cred_id = auth_method.provider_data.get("credential_id")
if stored_cred_id != credential_id:
return False
# Soft delete the credential
auth_method.delete(soft=True)
# Log audit event
AuditService.log_action(
action=AuditAction.WEBAUTHN_CREDENTIAL_DELETED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description=f"WebAuthn credential deleted: {credential_id[:16]}..."
)
return True
@classmethod
def rename_credential(cls, credential_id: str, user: User, name: str) -> bool:
"""Rename a passkey credential.
Args:
credential_id: The credential ID to rename
user: User instance
name: New name for the credential
Returns:
True if renamed successfully
"""
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.WEBAUTHN,
deleted_at=None
).first()
if not auth_method or not auth_method.provider_data:
return False
stored_cred_id = auth_method.provider_data.get("credential_id")
if stored_cred_id != credential_id:
return False
# Update name
auth_method.provider_data["name"] = name
db.session.commit()
# Log audit event
AuditService.log_action(
action=AuditAction.WEBAUTHN_CREDENTIAL_RENAMED,
user_id=user.id,
resource_type="authentication_method",
resource_id=auth_method.id,
description=f"WebAuthn credential renamed to: {name}"
)
return True
@classmethod
def get_credential_by_id(cls, credential_id: str, user: User) -> Optional[AuthenticationMethod]:
"""Get a specific credential by ID.
Args:
credential_id: The credential ID
user: User instance
Returns:
AuthenticationMethod instance or None
"""
auth_method = AuthenticationMethod.query.filter_by(
user_id=user.id,
method_type=AuthMethodType.WEBAUTHN,
deleted_at=None
).first()
if auth_method and auth_method.provider_data:
stored_cred_id = auth_method.provider_data.get("credential_id")
if stored_cred_id == credential_id:
return auth_method
return None
+25
View File
@@ -0,0 +1,25 @@
"""Utilities package."""
from gatehouse_app.utils.response import api_response
from gatehouse_app.utils.constants import (
UserStatus,
OrganizationRole,
AuthMethodType,
SessionStatus,
AuditAction,
ErrorType,
)
from gatehouse_app.utils.decorators import login_required, require_role, require_owner, require_admin
__all__ = [
"api_response",
"UserStatus",
"OrganizationRole",
"AuthMethodType",
"SessionStatus",
"AuditAction",
"ErrorType",
"login_required",
"require_role",
"require_owner",
"require_admin",
]
+118
View File
@@ -0,0 +1,118 @@
"""Application constants and enums."""
from enum import Enum
class UserStatus(str, Enum):
"""User account status."""
ACTIVE = "active"
INACTIVE = "inactive"
SUSPENDED = "suspended"
PENDING = "pending"
class OrganizationRole(str, Enum):
"""Organization member roles."""
OWNER = "owner"
ADMIN = "admin"
MEMBER = "member"
GUEST = "guest"
class AuthMethodType(str, Enum):
"""Authentication method types."""
PASSWORD = "password"
TOTP = "totp"
GOOGLE = "google"
GITHUB = "github"
MICROSOFT = "microsoft"
SAML = "saml"
OIDC = "oidc"
WEBAUTHN = "webauthn"
class SessionStatus(str, Enum):
"""Session status."""
ACTIVE = "active"
EXPIRED = "expired"
REVOKED = "revoked"
class AuditAction(str, Enum):
"""Audit log action types."""
# User actions
USER_LOGIN = "user.login"
USER_LOGOUT = "user.logout"
USER_REGISTER = "user.register"
USER_UPDATE = "user.update"
USER_DELETE = "user.delete"
PASSWORD_CHANGE = "user.password_change"
PASSWORD_RESET = "user.password_reset"
# Organization actions
ORG_CREATE = "org.create"
ORG_UPDATE = "org.update"
ORG_DELETE = "org.delete"
ORG_MEMBER_ADD = "org.member.add"
ORG_MEMBER_REMOVE = "org.member.remove"
ORG_MEMBER_ROLE_CHANGE = "org.member.role_change"
# Session actions
SESSION_CREATE = "session.create"
SESSION_REVOKE = "session.revoke"
# Auth method actions
AUTH_METHOD_ADD = "auth.method.add"
AUTH_METHOD_REMOVE = "auth.method.remove"
TOTP_ENROLL_INITIATED = "totp.enroll.initiated"
TOTP_ENROLL_COMPLETED = "totp.enroll.completed"
TOTP_VERIFY_SUCCESS = "totp.verify.success"
TOTP_VERIFY_FAILED = "totp.verify.failed"
TOTP_DISABLED = "totp.disabled"
TOTP_BACKUP_CODE_USED = "totp.backup_code.used"
TOTP_BACKUP_CODES_REGENERATED = "totp.backup_codes.regenerated"
# WebAuthn actions
WEBAUTHN_REGISTER_INITIATED = "webauthn.register.initiated"
WEBAUTHN_REGISTER_COMPLETED = "webauthn.register.completed"
WEBAUTHN_REGISTER_FAILED = "webauthn.register.failed"
WEBAUTHN_LOGIN_INITIATED = "webauthn.login.initiated"
WEBAUTHN_LOGIN_SUCCESS = "webauthn.login.success"
WEBAUTHN_LOGIN_FAILED = "webauthn.login.failed"
WEBAUTHN_CREDENTIAL_DELETED = "webauthn.credential.deleted"
WEBAUTHN_CREDENTIAL_RENAMED = "webauthn.credential.renamed"
class OIDCGrantType(str, Enum):
"""OIDC grant types."""
AUTHORIZATION_CODE = "authorization_code"
IMPLICIT = "implicit"
REFRESH_TOKEN = "refresh_token"
CLIENT_CREDENTIALS = "client_credentials"
class OIDCResponseType(str, Enum):
"""OIDC response types."""
CODE = "code"
TOKEN = "token"
ID_TOKEN = "id_token"
# Error type constants
class ErrorType:
"""Error type constants for API responses."""
VALIDATION_ERROR = "VALIDATION_ERROR"
AUTHENTICATION_ERROR = "AUTHENTICATION_ERROR"
AUTHORIZATION_ERROR = "AUTHORIZATION_ERROR"
NOT_FOUND = "NOT_FOUND"
CONFLICT = "CONFLICT"
RATE_LIMIT_EXCEEDED = "RATE_LIMIT_EXCEEDED"
INTERNAL_ERROR = "INTERNAL_ERROR"
BAD_REQUEST = "BAD_REQUEST"
+129
View File
@@ -0,0 +1,129 @@
"""Custom decorators for authentication and authorization."""
from functools import wraps
from flask import request, g
from gatehouse_app.utils.response import api_response
from gatehouse_app.utils.constants import OrganizationRole
def login_required(f):
"""Decorator to require Bearer token authentication.
Extracts token from Authorization: Bearer {token} header,
validates the session, and sets g.current_user and g.current_session.
"""
from gatehouse_app.services.session_service import SessionService
@wraps(f)
def decorated_function(*args, **kwargs):
# Extract token from Authorization header
auth_header = request.headers.get('Authorization')
if not auth_header:
return api_response(
success=False,
message="Authorization header is required",
status=401,
error_type="AUTH_REQUIRED"
)
# Expect format: "Bearer {token}"
parts = auth_header.split()
if len(parts) != 2 or parts[0].lower() != 'bearer':
return api_response(
success=False,
message="Invalid authorization format. Use: Bearer {token}",
status=401,
error_type="INVALID_AUTH_FORMAT"
)
token = parts[1]
# Get active session by token
session = SessionService.get_active_session_by_token(token)
if not session:
return api_response(
success=False,
message="Invalid or expired session",
status=401,
error_type="INVALID_TOKEN"
)
# Validate session is active
if not session.is_active():
return api_response(
success=False,
message="Session is no longer active",
status=401,
error_type="SESSION_INACTIVE"
)
# Update last_activity_at timestamp
from datetime import datetime, timezone
session.last_activity_at = datetime.now(timezone.utc)
from gatehouse_app import db
db.session.commit()
# Set context variables
g.current_user = session.user
g.current_session = session
return f(*args, **kwargs)
return decorated_function
def require_role(*allowed_roles):
"""
Decorator to require specific organization roles.
Args:
*allowed_roles: Variable number of OrganizationRole values
Raises:
ForbiddenError: If user doesn't have required role
"""
def decorator(f):
@wraps(f)
def decorated_function(*args, **kwargs):
# Ensure user is authenticated first
if not hasattr(g, "current_user"):
raise UnauthorizedError("Authentication required")
# Get organization_id from kwargs or URL parameters
org_id = kwargs.get("org_id") or kwargs.get("organization_id")
if not org_id:
raise ForbiddenError("Organization context required")
# Check user's role in the organization
from gatehouse_app.models.organization_member import OrganizationMember
membership = OrganizationMember.query.filter_by(
user_id=g.current_user.id,
organization_id=org_id,
).first()
if not membership:
raise ForbiddenError("Not a member of this organization")
if membership.role not in allowed_roles:
raise ForbiddenError(
f"Requires one of the following roles: {', '.join(allowed_roles)}"
)
g.current_membership = membership
return f(*args, **kwargs)
return decorated_function
return decorator
def require_owner(f):
"""Decorator to require organization owner role."""
return require_role(OrganizationRole.OWNER)(f)
def require_admin(f):
"""Decorator to require organization admin or owner role."""
return require_role(OrganizationRole.OWNER, OrganizationRole.ADMIN)(f)
+54
View File
@@ -0,0 +1,54 @@
"""API response utilities."""
from flask import jsonify, g
# Version for the response envelope
ENVELOPE_VERSION = "1.0"
def api_response(
*,
data=None,
success=True,
message="",
status=200,
error_type=None,
error_details=None,
meta=None
):
"""
Create a standardized API response.
Args:
data: Response data (only included if success=True)
success: Whether the request was successful
message: Human-readable message
status: HTTP status code
error_type: Type of error (only if success=False)
error_details: Additional error details (only if success=False)
meta: Additional metadata (pagination, etc.)
Returns:
Tuple of (response, status_code)
"""
payload = {
"version": ENVELOPE_VERSION,
"success": success,
"code": status,
"message": message,
"request_id": g.get("request_id", "unknown"),
}
if meta:
payload["meta"] = meta
if success:
if data is not None:
payload["data"] = data
else:
payload["error"] = {
"type": error_type or "UNKNOWN",
"details": error_details or {}
}
return jsonify(payload), status