move app to gatehouse-app
This commit is contained in:
@@ -0,0 +1,240 @@
|
||||
"""Application factory."""
|
||||
import os
|
||||
import logging
|
||||
|
||||
# Test debug logging - this should appear when running `flask run --debug`
|
||||
_root_logger = logging.getLogger(__name__)
|
||||
_root_logger.debug("[TEST] Debug logging is working!")
|
||||
|
||||
from flask import Flask
|
||||
from config import get_config
|
||||
from gatehouse_app.extensions import db, migrate, bcrypt, ma, limiter
|
||||
from gatehouse_app.extensions import session as flask_session
|
||||
from gatehouse_app.middleware import RequestIDMiddleware, SecurityHeadersMiddleware, setup_cors
|
||||
from gatehouse_app.exceptions.base import BaseAPIException
|
||||
from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.services.oidc_jwks_service import OIDCJWKSService
|
||||
import redis
|
||||
|
||||
# Configure SQLAlchemy logging BEFORE any database operations
|
||||
# This must be done before db.init_app() to prevent verbose logging
|
||||
_log_level_env = os.getenv("SQLALCHEMY_LOG_LEVEL", "WARNING").upper()
|
||||
_sqlalchemy_log_level = getattr(logging, _log_level_env, logging.WARNING)
|
||||
logging.getLogger('sqlalchemy').setLevel(_sqlalchemy_log_level)
|
||||
logging.getLogger('sqlalchemy.engine').setLevel(_sqlalchemy_log_level)
|
||||
logging.getLogger('sqlalchemy.dialects').setLevel(_sqlalchemy_log_level)
|
||||
logging.getLogger('sqlalchemy.pool').setLevel(_sqlalchemy_log_level)
|
||||
|
||||
|
||||
def create_app(config_name=None):
|
||||
"""
|
||||
Create and configure the Flask application.
|
||||
|
||||
Args:
|
||||
config_name: Configuration name (development, testing, production)
|
||||
|
||||
Returns:
|
||||
Flask application instance
|
||||
"""
|
||||
flask_app = Flask(__name__)
|
||||
|
||||
# Load configuration
|
||||
config = get_config(config_name)
|
||||
flask_app.config.from_object(config)
|
||||
|
||||
# Initialize extensions
|
||||
initialize_extensions(flask_app)
|
||||
|
||||
# Setup middleware
|
||||
setup_middleware(flask_app)
|
||||
|
||||
# Register blueprints
|
||||
register_blueprints(flask_app)
|
||||
|
||||
# Register error handlers
|
||||
register_error_handlers(flask_app)
|
||||
|
||||
# Setup logging
|
||||
setup_logging(flask_app)
|
||||
|
||||
# Initialize OIDC JWKS service with a signing key
|
||||
initialize_oidc_jwks(flask_app)
|
||||
|
||||
return flask_app
|
||||
|
||||
|
||||
def initialize_extensions(app):
|
||||
"""Initialize Flask extensions."""
|
||||
# Database
|
||||
db.init_app(app)
|
||||
migrate.init_app(app, db)
|
||||
|
||||
# Security
|
||||
bcrypt.init_app(app)
|
||||
|
||||
# CORS - using custom middleware only (see app/middleware/cors.py)
|
||||
# Flask-CORS disabled to avoid conflicts
|
||||
# cors.init_app(app)
|
||||
|
||||
# Marshmallow
|
||||
ma.init_app(app)
|
||||
|
||||
# Rate limiting
|
||||
if app.config.get("RATELIMIT_ENABLED"):
|
||||
limiter.init_app(app)
|
||||
|
||||
# Redis for sessions
|
||||
try:
|
||||
redis_url = app.config.get("REDIS_URL")
|
||||
if redis_url:
|
||||
import gatehouse_app.extensions
|
||||
gatehouse_app.extensions.redis_client = redis.from_url(redis_url)
|
||||
app.config["SESSION_REDIS"] = gatehouse_app.extensions.redis_client
|
||||
except Exception as e:
|
||||
logging.warning(f"Redis connection failed: {e}")
|
||||
|
||||
# Flask-Session
|
||||
flask_session.init_app(app)
|
||||
|
||||
|
||||
def setup_middleware(app):
|
||||
"""Setup application middleware."""
|
||||
RequestIDMiddleware(app)
|
||||
SecurityHeadersMiddleware(app)
|
||||
setup_cors(app)
|
||||
|
||||
|
||||
def register_blueprints(app):
|
||||
"""Register application blueprints."""
|
||||
from gatehouse_app.api import register_api_blueprints
|
||||
from gatehouse_app.api.oidc import oidc_bp
|
||||
|
||||
register_api_blueprints(app)
|
||||
|
||||
# Register OIDC blueprint at root level
|
||||
app.register_blueprint(oidc_bp)
|
||||
|
||||
|
||||
def register_error_handlers(app):
|
||||
"""Register error handlers."""
|
||||
|
||||
@app.errorhandler(BaseAPIException)
|
||||
def handle_api_exception(error):
|
||||
"""Handle custom API exceptions."""
|
||||
return api_response(
|
||||
success=False,
|
||||
message=error.message,
|
||||
status=error.status_code,
|
||||
error_type=error.error_type,
|
||||
error_details=error.error_details,
|
||||
)
|
||||
|
||||
@app.errorhandler(404)
|
||||
def handle_not_found(error):
|
||||
"""Handle 404 errors."""
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Resource not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
@app.errorhandler(405)
|
||||
def handle_method_not_allowed(error):
|
||||
"""Handle 405 errors."""
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Method not allowed",
|
||||
status=405,
|
||||
error_type="METHOD_NOT_ALLOWED",
|
||||
)
|
||||
|
||||
@app.errorhandler(500)
|
||||
def handle_internal_error(error):
|
||||
"""Handle 500 errors."""
|
||||
app.logger.error(f"Internal server error: {error}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Internal server error",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
@app.errorhandler(Exception)
|
||||
def handle_unexpected_error(error):
|
||||
"""Handle unexpected errors."""
|
||||
app.logger.error(f"Unexpected error: {error}", exc_info=True)
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An unexpected error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
def setup_logging(app):
|
||||
"""Setup application logging."""
|
||||
log_level = getattr(logging, app.config.get("LOG_LEVEL", "INFO"))
|
||||
|
||||
# Create formatter
|
||||
formatter = logging.Formatter(
|
||||
"[%(asctime)s] [%(levelname)s] %(name)s: %(message)s"
|
||||
)
|
||||
|
||||
# Configure root logger - this ensures all module loggers (like app.services.oidc_service)
|
||||
# will output DEBUG level logs when in development mode
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(log_level)
|
||||
|
||||
if app.config.get("LOG_TO_STDOUT"):
|
||||
stream_handler = logging.StreamHandler()
|
||||
stream_handler.setFormatter(formatter)
|
||||
stream_handler.setLevel(log_level)
|
||||
root_logger.addHandler(stream_handler)
|
||||
|
||||
# Disable Werkzeug's default logger to avoid log duplication and interference
|
||||
werkzeug_logger = logging.getLogger('werkzeug')
|
||||
werkzeug_logger.setLevel(logging.INFO)
|
||||
|
||||
# Ensure child loggers propagate to root logger
|
||||
# This is the key fix - explicitly enable propagation for common app loggers
|
||||
for logger_name in ['app', 'app.api', 'app.api.oidc', 'app.services', 'app.models']:
|
||||
child_logger = logging.getLogger(logger_name)
|
||||
child_logger.propagate = True
|
||||
child_logger.setLevel(log_level)
|
||||
|
||||
# Configure Flask app logger
|
||||
app.logger.setLevel(log_level)
|
||||
|
||||
# Configure SQLAlchemy logging level (also set at module level before DB init)
|
||||
sqlalchemy_log_level = getattr(logging, app.config.get("SQLALCHEMY_LOG_LEVEL", "WARNING"), logging.WARNING)
|
||||
logging.getLogger('sqlalchemy').setLevel(sqlalchemy_log_level)
|
||||
logging.getLogger('sqlalchemy.engine').setLevel(sqlalchemy_log_level)
|
||||
logging.getLogger('sqlalchemy.dialects').setLevel(sqlalchemy_log_level)
|
||||
logging.getLogger('sqlalchemy.pool').setLevel(sqlalchemy_log_level)
|
||||
|
||||
app.logger.info("Application startup")
|
||||
|
||||
# Test debug log after logging is configured
|
||||
app.logger.debug("[TEST] Debug logging is working!")
|
||||
|
||||
|
||||
def initialize_oidc_jwks(app):
|
||||
"""Initialize OIDC JWKS service with a signing key.
|
||||
|
||||
This ensures that signing keys are available for token generation.
|
||||
Keys are loaded from the database if available, otherwise a new key
|
||||
is generated and persisted to the database.
|
||||
|
||||
Args:
|
||||
app: Flask application instance
|
||||
"""
|
||||
with app.app_context():
|
||||
try:
|
||||
jwks_service = OIDCJWKSService()
|
||||
# Use initialize_with_key which handles loading from DB
|
||||
# or generating a new key if none exists
|
||||
signing_key = jwks_service.initialize_with_key()
|
||||
app.logger.info(f"[OIDC] Signing key initialized: kid={signing_key.kid}")
|
||||
except Exception as e:
|
||||
app.logger.error(f"[OIDC] Failed to initialize JWKS: {e}")
|
||||
@@ -0,0 +1,24 @@
|
||||
"""API package."""
|
||||
from flask import Blueprint
|
||||
from gatehouse_app.utils.response import api_response
|
||||
|
||||
# Create main API blueprint
|
||||
api_bp = Blueprint("api", __name__)
|
||||
|
||||
|
||||
@api_bp.route("/health", methods=["GET"])
|
||||
def health_check():
|
||||
"""Health check endpoint."""
|
||||
return api_response(
|
||||
data={"status": "healthy", "service": "authy2-backend"},
|
||||
message="Service is running",
|
||||
)
|
||||
|
||||
|
||||
def register_api_blueprints(app):
|
||||
"""Register all API blueprints."""
|
||||
from gatehouse_app.api.v1 import api_v1_bp
|
||||
|
||||
# Register versioned API blueprints
|
||||
app.register_blueprint(api_bp, url_prefix="/api")
|
||||
app.register_blueprint(api_v1_bp, url_prefix="/api/v1")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,8 @@
|
||||
"""API v1 blueprint."""
|
||||
from flask import Blueprint
|
||||
|
||||
# Create v1 API blueprint
|
||||
api_v1_bp = Blueprint("api_v1", __name__)
|
||||
|
||||
# Import route modules to register them
|
||||
from gatehouse_app.api.v1 import auth, users, organizations
|
||||
@@ -0,0 +1,937 @@
|
||||
"""Authentication endpoints."""
|
||||
import json
|
||||
from flask import request, session, g, jsonify
|
||||
from marshmallow import ValidationError
|
||||
from gatehouse_app.api.v1 import api_v1_bp
|
||||
from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.schemas.auth_schema import (
|
||||
RegisterSchema,
|
||||
LoginSchema,
|
||||
TOTPVerifyEnrollmentSchema,
|
||||
TOTPVerifySchema,
|
||||
TOTPDisableSchema,
|
||||
TOTPRegenerateBackupCodesSchema,
|
||||
)
|
||||
from gatehouse_app.schemas.webauthn_schema import (
|
||||
WebAuthnRegistrationBeginSchema,
|
||||
WebAuthnRegistrationCompleteSchema,
|
||||
WebAuthnLoginBeginSchema,
|
||||
WebAuthnLoginCompleteSchema,
|
||||
WebAuthnCredentialRenameSchema,
|
||||
)
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
from gatehouse_app.services.webauthn_service import WebAuthnService
|
||||
from gatehouse_app.services.user_service import UserService
|
||||
from gatehouse_app.utils.decorators import login_required
|
||||
from gatehouse_app.utils.constants import AuditAction
|
||||
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
|
||||
from gatehouse_app.exceptions.validation_exceptions import ConflictError, NotFoundError
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/register", methods=["POST"])
|
||||
def register():
|
||||
"""
|
||||
Register a new user.
|
||||
|
||||
Request body:
|
||||
email: User email
|
||||
password: User password
|
||||
password_confirm: Password confirmation
|
||||
full_name: Optional full name
|
||||
|
||||
Returns:
|
||||
201: User created successfully
|
||||
400: Validation error
|
||||
409: Email already exists
|
||||
"""
|
||||
try:
|
||||
# Validate request data
|
||||
schema = RegisterSchema()
|
||||
data = schema.load(request.json)
|
||||
|
||||
# Register user
|
||||
user = AuthService.register_user(
|
||||
email=data["email"],
|
||||
password=data["password"],
|
||||
full_name=data.get("full_name"),
|
||||
)
|
||||
|
||||
# Create session
|
||||
user_session = AuthService.create_session(user)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"user": user.to_dict(),
|
||||
"token": user_session.token,
|
||||
"expires_at": user_session.expires_at.isoformat() + "Z" if user_session.expires_at.isoformat()[-1] != "Z" else user_session.expires_at.isoformat(),
|
||||
},
|
||||
message="Registration successful",
|
||||
status=201,
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/login", methods=["POST"])
|
||||
def login():
|
||||
"""
|
||||
Login user.
|
||||
|
||||
Request body:
|
||||
email: User email
|
||||
password: User password
|
||||
remember_me: Optional boolean for extended session
|
||||
|
||||
Returns:
|
||||
200: Login successful or TOTP code required
|
||||
400: Validation error
|
||||
401: Invalid credentials
|
||||
"""
|
||||
try:
|
||||
# Validate request data
|
||||
schema = LoginSchema()
|
||||
data = schema.load(request.json)
|
||||
|
||||
# Authenticate user with email and password
|
||||
user = AuthService.authenticate(
|
||||
email=data["email"],
|
||||
password=data["password"],
|
||||
)
|
||||
|
||||
# Check if user has TOTP enabled for two-factor authentication
|
||||
if user.has_totp_enabled():
|
||||
# TOTP is enabled - store user_id in session for TOTP verification
|
||||
# The /auth/totp/verify endpoint will retrieve this user_id
|
||||
session["totp_pending_user_id"] = user.id
|
||||
|
||||
# Return response indicating TOTP code is required
|
||||
# Do NOT create session or return token yet - wait for TOTP verification
|
||||
return api_response(
|
||||
data={
|
||||
"requires_totp": True,
|
||||
},
|
||||
message="TOTP code required. Please enter your 6-digit code from your authenticator app.",
|
||||
)
|
||||
|
||||
# TOTP is NOT enabled - proceed with normal login flow
|
||||
# Create session with appropriate duration based on remember_me preference
|
||||
duration = 2592000 if data.get("remember_me") else 86400 # 30 days vs 1 day
|
||||
user_session = AuthService.create_session(user, duration_seconds=duration)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"user": user.to_dict(),
|
||||
"token": user_session.token,
|
||||
"expires_at": user_session.expires_at.isoformat() + "Z" if user_session.expires_at.isoformat()[-1] != "Z" else user_session.expires_at.isoformat(),
|
||||
},
|
||||
message="Login successful",
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/logout", methods=["POST"])
|
||||
@login_required
|
||||
def logout():
|
||||
"""
|
||||
Logout current user.
|
||||
|
||||
Returns:
|
||||
200: Logout successful
|
||||
401: Not authenticated
|
||||
"""
|
||||
# Revoke current session (g.current_session is set by login_required decorator)
|
||||
if g.current_session:
|
||||
AuthService.revoke_session(g.current_session.id, reason="User logout")
|
||||
|
||||
return api_response(
|
||||
message="Logout successful",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/me", methods=["GET"])
|
||||
@login_required
|
||||
def get_current_user():
|
||||
"""
|
||||
Get current authenticated user.
|
||||
|
||||
Returns:
|
||||
200: User data
|
||||
401: Not authenticated
|
||||
"""
|
||||
user = g.current_user
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"user": user.to_dict(),
|
||||
"organizations": [
|
||||
{"id": org.id, "name": org.name, "slug": org.slug}
|
||||
for org in user.get_organizations()
|
||||
],
|
||||
},
|
||||
message="User retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/sessions", methods=["GET"])
|
||||
@login_required
|
||||
def get_user_sessions():
|
||||
"""
|
||||
Get all active sessions for current user.
|
||||
|
||||
Returns:
|
||||
200: List of active sessions
|
||||
401: Not authenticated
|
||||
"""
|
||||
from gatehouse_app.services.session_service import SessionService
|
||||
|
||||
sessions = SessionService.get_user_sessions(g.current_user.id, active_only=True)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"sessions": [session.to_dict() for session in sessions],
|
||||
"count": len(sessions),
|
||||
},
|
||||
message="Sessions retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/sessions/<session_id>", methods=["DELETE"])
|
||||
@login_required
|
||||
def revoke_session(session_id):
|
||||
"""
|
||||
Revoke a specific session.
|
||||
|
||||
Args:
|
||||
session_id: ID of session to revoke
|
||||
|
||||
Returns:
|
||||
200: Session revoked
|
||||
401: Not authenticated
|
||||
404: Session not found
|
||||
"""
|
||||
from gatehouse_app.models.session import Session
|
||||
|
||||
# Ensure session belongs to current user
|
||||
user_session = Session.query.filter_by(
|
||||
id=session_id, user_id=g.current_user.id, deleted_at=None
|
||||
).first()
|
||||
|
||||
if not user_session:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Session not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
AuthService.revoke_session(session_id, reason="Revoked by user")
|
||||
|
||||
return api_response(
|
||||
message="Session revoked successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/totp/enroll", methods=["POST"])
|
||||
@login_required
|
||||
def enroll_totp():
|
||||
"""
|
||||
Initiate TOTP enrollment for the current user.
|
||||
|
||||
Returns:
|
||||
201: TOTP enrollment initiated with secret, provisioning_uri, qr_code, and backup_codes
|
||||
401: Not authenticated
|
||||
409: TOTP already enabled
|
||||
"""
|
||||
try:
|
||||
# Initiate TOTP enrollment
|
||||
result = AuthService.enroll_totp(g.current_user)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"secret": result["secret"],
|
||||
"provisioning_uri": result["provisioning_uri"],
|
||||
"qr_code": result["qr_code"],
|
||||
"backup_codes": result["backup_codes"],
|
||||
},
|
||||
message="TOTP enrollment initiated. Please verify with your authenticator app.",
|
||||
status=201,
|
||||
)
|
||||
|
||||
except ConflictError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=e.message,
|
||||
status=e.status_code,
|
||||
error_type=e.error_type,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/totp/verify-enrollment", methods=["POST"])
|
||||
@login_required
|
||||
def verify_totp_enrollment():
|
||||
"""
|
||||
Complete TOTP enrollment by verifying the first TOTP code.
|
||||
|
||||
Request body:
|
||||
code: 6-digit TOTP code from authenticator app
|
||||
|
||||
Returns:
|
||||
200: TOTP enrollment completed successfully
|
||||
400: Validation error
|
||||
401: Not authenticated
|
||||
401: Invalid TOTP code
|
||||
"""
|
||||
try:
|
||||
# Validate request data
|
||||
schema = TOTPVerifyEnrollmentSchema()
|
||||
data = schema.load(request.json)
|
||||
|
||||
# Verify TOTP enrollment
|
||||
AuthService.verify_totp_enrollment(g.current_user, data["code"])
|
||||
|
||||
return api_response(
|
||||
message="TOTP enrollment completed successfully",
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
|
||||
except InvalidCredentialsError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=e.message,
|
||||
status=e.status_code,
|
||||
error_type=e.error_type,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/totp/verify", methods=["POST"])
|
||||
def verify_totp():
|
||||
"""
|
||||
Verify TOTP code during login.
|
||||
|
||||
Request body:
|
||||
code: 6-digit TOTP code or backup code
|
||||
is_backup_code: True if code is a backup code, False if TOTP code (default: False)
|
||||
|
||||
Returns:
|
||||
200: TOTP code verified successfully with session token
|
||||
400: Validation error
|
||||
401: Invalid TOTP code or session not found
|
||||
"""
|
||||
try:
|
||||
# Validate request data
|
||||
schema = TOTPVerifySchema()
|
||||
data = schema.load(request.json)
|
||||
|
||||
# Get user from temporary session (stored in Flask session by login endpoint)
|
||||
user_id = session.get("totp_pending_user_id")
|
||||
if not user_id:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="No pending TOTP verification. Please login first.",
|
||||
status=401,
|
||||
error_type="AUTHENTICATION_ERROR",
|
||||
)
|
||||
|
||||
# Get user from database
|
||||
from gatehouse_app.models.user import User
|
||||
user = User.query.get(user_id)
|
||||
if not user:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="User not found",
|
||||
status=401,
|
||||
error_type="AUTHENTICATION_ERROR",
|
||||
)
|
||||
|
||||
# Verify TOTP code
|
||||
AuthService.authenticate_with_totp(
|
||||
user, data["code"], data.get("is_backup_code", False)
|
||||
)
|
||||
|
||||
# Create full session
|
||||
user_session = AuthService.create_session(user)
|
||||
|
||||
# Clear temporary session
|
||||
session.pop("totp_pending_user_id", None)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"user": user.to_dict(),
|
||||
"token": user_session.token,
|
||||
"expires_at": user_session.expires_at.isoformat() + "Z"
|
||||
if user_session.expires_at.isoformat()[-1] != "Z"
|
||||
else user_session.expires_at.isoformat(),
|
||||
},
|
||||
message="TOTP verification successful",
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
|
||||
except InvalidCredentialsError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=e.message,
|
||||
status=e.status_code,
|
||||
error_type=e.error_type,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/totp/disable", methods=["DELETE"])
|
||||
@login_required
|
||||
def disable_totp():
|
||||
"""
|
||||
Disable TOTP for the current user.
|
||||
|
||||
Request body:
|
||||
password: User's current password for verification
|
||||
|
||||
Returns:
|
||||
200: TOTP disabled successfully
|
||||
400: Validation error
|
||||
401: Not authenticated or invalid password
|
||||
401: TOTP not enabled
|
||||
"""
|
||||
try:
|
||||
# Validate request data
|
||||
schema = TOTPDisableSchema()
|
||||
data = schema.load(request.json)
|
||||
|
||||
# Disable TOTP
|
||||
AuthService.disable_totp(g.current_user, data["password"])
|
||||
|
||||
return api_response(
|
||||
message="TOTP disabled successfully",
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
|
||||
except InvalidCredentialsError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=e.message,
|
||||
status=e.status_code,
|
||||
error_type=e.error_type,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/totp/status", methods=["GET"])
|
||||
@login_required
|
||||
def get_totp_status():
|
||||
"""
|
||||
Get TOTP status for the current user.
|
||||
|
||||
Returns:
|
||||
200: TOTP status with totp_enabled, verified_at, and backup_codes_remaining
|
||||
401: Not authenticated
|
||||
"""
|
||||
user = g.current_user
|
||||
|
||||
# Check if TOTP is enabled
|
||||
totp_enabled = user.has_totp_enabled()
|
||||
|
||||
# Get TOTP method to check backup codes remaining
|
||||
backup_codes_remaining = 0
|
||||
verified_at = None
|
||||
|
||||
if totp_enabled:
|
||||
totp_method = user.get_totp_method()
|
||||
if totp_method and totp_method.provider_data:
|
||||
backup_codes = totp_method.provider_data.get("backup_codes", [])
|
||||
backup_codes_remaining = len(backup_codes)
|
||||
if totp_method and totp_method.totp_verified_at:
|
||||
verified_at = totp_method.totp_verified_at.isoformat() + "Z" if totp_method.totp_verified_at.isoformat()[-1] != "Z" else totp_method.totp_verified_at.isoformat()
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"totp_enabled": totp_enabled,
|
||||
"verified_at": verified_at,
|
||||
"backup_codes_remaining": backup_codes_remaining,
|
||||
},
|
||||
message="TOTP status retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/totp/regenerate-backup-codes", methods=["POST"])
|
||||
@login_required
|
||||
def regenerate_totp_backup_codes():
|
||||
"""
|
||||
Generate new backup codes for TOTP.
|
||||
|
||||
Request body:
|
||||
password: User's current password for verification
|
||||
|
||||
Returns:
|
||||
200: New backup codes generated successfully
|
||||
400: Validation error
|
||||
401: Not authenticated or invalid password
|
||||
401: TOTP not enabled
|
||||
"""
|
||||
try:
|
||||
# Validate request data
|
||||
schema = TOTPRegenerateBackupCodesSchema()
|
||||
data = schema.load(request.json)
|
||||
|
||||
# Regenerate backup codes
|
||||
backup_codes = AuthService.regenerate_totp_backup_codes(
|
||||
g.current_user, data["password"]
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"backup_codes": backup_codes,
|
||||
},
|
||||
message="Backup codes regenerated successfully",
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
|
||||
except InvalidCredentialsError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=e.message,
|
||||
status=e.status_code,
|
||||
error_type=e.error_type,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# WebAuthn Passkey Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/webauthn/register/begin", methods=["POST"])
|
||||
@login_required
|
||||
def begin_webauthn_registration():
|
||||
"""
|
||||
Begin WebAuthn passkey registration.
|
||||
|
||||
Returns:
|
||||
200: PublicKeyCredentialCreationOptions (raw JSON, no wrapper)
|
||||
401: Not authenticated
|
||||
"""
|
||||
user = g.current_user
|
||||
|
||||
# Generate registration challenge
|
||||
options = WebAuthnService.generate_registration_challenge(user)
|
||||
|
||||
# Return unwrapped JSON for WebAuthn
|
||||
return jsonify(options), 200
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/webauthn/register/complete", methods=["POST"])
|
||||
@login_required
|
||||
def complete_webauthn_registration():
|
||||
"""
|
||||
Complete WebAuthn passkey registration.
|
||||
|
||||
Request body:
|
||||
id: Credential ID
|
||||
rawId: Base64URL-encoded credential ID
|
||||
type: "public-key"
|
||||
response: Attestation response data
|
||||
transports: List of transport types
|
||||
|
||||
Returns:
|
||||
200: Registration successful
|
||||
400: Validation error
|
||||
401: Not authenticated
|
||||
409: Credential already exists
|
||||
"""
|
||||
try:
|
||||
# Validate request data
|
||||
schema = WebAuthnRegistrationCompleteSchema()
|
||||
data = schema.load(request.json)
|
||||
|
||||
# Extract challenge from client data
|
||||
client_data = data.get("response", {}).get("clientDataJSON", "")
|
||||
import base64
|
||||
client_data_json = base64.urlsafe_b64decode(client_data + "==")
|
||||
client_data_dict = json.loads(client_data_json)
|
||||
challenge = client_data_dict.get("challenge")
|
||||
|
||||
if not challenge:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Invalid challenge in client data",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
)
|
||||
|
||||
# Verify registration response
|
||||
auth_method = WebAuthnService.verify_registration_response(
|
||||
g.current_user,
|
||||
data,
|
||||
challenge
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"credential": auth_method.to_webauthn_dict(),
|
||||
},
|
||||
message="Passkey registered successfully",
|
||||
status=201,
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
|
||||
except InvalidCredentialsError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=e.message,
|
||||
status=e.status_code,
|
||||
error_type=e.error_type,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/webauthn/login/begin", methods=["POST"])
|
||||
def begin_webauthn_login():
|
||||
"""
|
||||
Begin WebAuthn passkey login.
|
||||
|
||||
Request body:
|
||||
email: User email address
|
||||
|
||||
Returns:
|
||||
200: PublicKeyCredentialRequestOptions (raw JSON, no wrapper)
|
||||
400: Validation error
|
||||
404: User not found
|
||||
"""
|
||||
try:
|
||||
# Validate request data
|
||||
schema = WebAuthnLoginBeginSchema()
|
||||
data = schema.load(request.json)
|
||||
|
||||
# Find user by email
|
||||
from gatehouse_app.models.user import User
|
||||
user = User.query.filter_by(
|
||||
email=data["email"].lower(),
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not user:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="User not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Check if user has any WebAuthn credentials
|
||||
if not user.has_webauthn_enabled():
|
||||
return api_response(
|
||||
success=False,
|
||||
message="No passkeys found for this account",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Generate authentication challenge
|
||||
options = WebAuthnService.generate_authentication_challenge(user)
|
||||
|
||||
# Store user_id in session for verification
|
||||
session["webauthn_pending_user_id"] = user.id
|
||||
|
||||
# Return unwrapped JSON for WebAuthn
|
||||
return jsonify(options), 200
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/webauthn/login/complete", methods=["POST"])
|
||||
def complete_webauthn_login():
|
||||
"""
|
||||
Complete WebAuthn passkey login.
|
||||
|
||||
Request body:
|
||||
id: Credential ID
|
||||
rawId: Base64URL-encoded credential ID
|
||||
type: "public-key"
|
||||
response: Assertion response data
|
||||
|
||||
Returns:
|
||||
200: Login successful with session token
|
||||
400: Validation error
|
||||
401: Authentication failed
|
||||
"""
|
||||
try:
|
||||
# Get user from session
|
||||
user_id = session.get("webauthn_pending_user_id")
|
||||
if not user_id:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="No pending WebAuthn verification. Please initiate login first.",
|
||||
status=401,
|
||||
error_type="AUTHENTICATION_ERROR",
|
||||
)
|
||||
|
||||
# Validate request data
|
||||
schema = WebAuthnLoginCompleteSchema()
|
||||
data = schema.load(request.json)
|
||||
|
||||
# Get user from database
|
||||
from gatehouse_app.models.user import User
|
||||
user = User.query.get(user_id)
|
||||
if not user:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="User not found",
|
||||
status=401,
|
||||
error_type="AUTHENTICATION_ERROR",
|
||||
)
|
||||
|
||||
# Extract challenge from client data
|
||||
client_data = data.get("response", {}).get("clientDataJSON", "")
|
||||
import base64
|
||||
client_data_json = base64.urlsafe_b64decode(client_data + "==")
|
||||
client_data_dict = json.loads(client_data_json)
|
||||
challenge = client_data_dict.get("challenge")
|
||||
|
||||
if not challenge:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Invalid challenge in client data",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
)
|
||||
|
||||
# Verify authentication response
|
||||
WebAuthnService.verify_authentication_response(
|
||||
user,
|
||||
data,
|
||||
challenge
|
||||
)
|
||||
|
||||
# Create session
|
||||
user_session = AuthService.create_session(user)
|
||||
|
||||
# Clear pending session
|
||||
session.pop("webauthn_pending_user_id", None)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"user": user.to_dict(),
|
||||
"token": user_session.token,
|
||||
"expires_at": user_session.expires_at.isoformat() + "Z"
|
||||
if user_session.expires_at.isoformat()[-1] != "Z"
|
||||
else user_session.expires_at.isoformat(),
|
||||
},
|
||||
message="Login successful",
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
|
||||
except InvalidCredentialsError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=e.message,
|
||||
status=e.status_code,
|
||||
error_type=e.error_type,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/webauthn/credentials", methods=["GET"])
|
||||
@login_required
|
||||
def list_webauthn_credentials():
|
||||
"""
|
||||
List all WebAuthn passkey credentials for the current user.
|
||||
|
||||
Returns:
|
||||
200: List of credentials
|
||||
401: Not authenticated
|
||||
"""
|
||||
user = g.current_user
|
||||
credentials = WebAuthnService.get_user_credentials(user)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"credentials": [cred.to_webauthn_dict() for cred in credentials],
|
||||
"count": len(credentials),
|
||||
},
|
||||
message="Credentials retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/webauthn/credentials/<credential_id>", methods=["DELETE"])
|
||||
@login_required
|
||||
def delete_webauthn_credential(credential_id):
|
||||
"""
|
||||
Delete a WebAuthn passkey credential.
|
||||
|
||||
Args:
|
||||
credential_id: ID of the credential to delete
|
||||
|
||||
Returns:
|
||||
200: Credential deleted successfully
|
||||
401: Not authenticated
|
||||
404: Credential not found
|
||||
"""
|
||||
user = g.current_user
|
||||
|
||||
# Check if this is the last credential
|
||||
credential_count = user.get_webauthn_credential_count()
|
||||
if credential_count <= 1:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Cannot delete the last passkey. Add another passkey first.",
|
||||
status=400,
|
||||
error_type="BAD_REQUEST",
|
||||
)
|
||||
|
||||
# Delete the credential
|
||||
success = WebAuthnService.delete_credential(credential_id, user)
|
||||
|
||||
if not success:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Credential not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
return api_response(
|
||||
message="Passkey deleted successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/webauthn/credentials/<credential_id>", methods=["PATCH"])
|
||||
@login_required
|
||||
def rename_webauthn_credential(credential_id):
|
||||
"""
|
||||
Rename a WebAuthn passkey credential.
|
||||
|
||||
Args:
|
||||
credential_id: ID of the credential to rename
|
||||
|
||||
Request body:
|
||||
name: New name for the credential
|
||||
|
||||
Returns:
|
||||
200: Credential renamed successfully
|
||||
400: Validation error
|
||||
401: Not authenticated
|
||||
404: Credential not found
|
||||
"""
|
||||
try:
|
||||
# Validate request data
|
||||
schema = WebAuthnCredentialRenameSchema()
|
||||
data = schema.load(request.json)
|
||||
|
||||
# Rename the credential
|
||||
success = WebAuthnService.rename_credential(
|
||||
credential_id,
|
||||
g.current_user,
|
||||
data["name"]
|
||||
)
|
||||
|
||||
if not success:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Credential not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Get updated credential
|
||||
credential = WebAuthnService.get_credential_by_id(credential_id, g.current_user)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"credential": credential.to_webauthn_dict() if credential else None,
|
||||
},
|
||||
message="Passkey renamed successfully",
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/webauthn/status", methods=["GET"])
|
||||
@login_required
|
||||
def get_webauthn_status():
|
||||
"""
|
||||
Get WebAuthn status for the current user.
|
||||
|
||||
Returns:
|
||||
200: WebAuthn status with webauthn_enabled and credential_count
|
||||
401: Not authenticated
|
||||
"""
|
||||
user = g.current_user
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"webauthn_enabled": user.has_webauthn_enabled(),
|
||||
"credential_count": user.get_webauthn_credential_count(),
|
||||
},
|
||||
message="WebAuthn status retrieved successfully",
|
||||
)
|
||||
@@ -0,0 +1,372 @@
|
||||
"""Organization endpoints."""
|
||||
from flask import g, request
|
||||
from marshmallow import ValidationError
|
||||
from gatehouse_app.api.v1 import api_v1_bp
|
||||
from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.utils.decorators import login_required, require_admin, require_owner
|
||||
from gatehouse_app.schemas.organization_schema import (
|
||||
OrganizationCreateSchema,
|
||||
OrganizationUpdateSchema,
|
||||
InviteMemberSchema,
|
||||
UpdateMemberRoleSchema,
|
||||
)
|
||||
from gatehouse_app.services.organization_service import OrganizationService
|
||||
from gatehouse_app.services.user_service import UserService
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations", methods=["POST"])
|
||||
@login_required
|
||||
def create_organization():
|
||||
"""
|
||||
Create a new organization.
|
||||
|
||||
Request body:
|
||||
name: Organization name
|
||||
slug: Organization slug (unique)
|
||||
description: Optional description
|
||||
logo_url: Optional logo URL
|
||||
|
||||
Returns:
|
||||
201: Organization created successfully
|
||||
400: Validation error
|
||||
401: Not authenticated
|
||||
409: Slug already exists
|
||||
"""
|
||||
try:
|
||||
# Validate request data
|
||||
schema = OrganizationCreateSchema()
|
||||
data = schema.load(request.json)
|
||||
|
||||
# Create organization
|
||||
org = OrganizationService.create_organization(
|
||||
name=data["name"],
|
||||
slug=data["slug"],
|
||||
owner_user_id=g.current_user.id,
|
||||
description=data.get("description"),
|
||||
logo_url=data.get("logo_url"),
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data={"organization": org.to_dict()},
|
||||
message="Organization created successfully",
|
||||
status=201,
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>", methods=["GET"])
|
||||
@login_required
|
||||
def get_organization(org_id):
|
||||
"""
|
||||
Get organization by ID.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
|
||||
Returns:
|
||||
200: Organization data
|
||||
401: Not authenticated
|
||||
403: Not a member
|
||||
404: Organization not found
|
||||
"""
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
# Check if user is a member
|
||||
if not org.is_member(g.current_user.id):
|
||||
return api_response(
|
||||
success=False,
|
||||
message="You are not a member of this organization",
|
||||
status=403,
|
||||
error_type="AUTHORIZATION_ERROR",
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"organization": org.to_dict(),
|
||||
"member_count": org.get_member_count(),
|
||||
},
|
||||
message="Organization retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>", methods=["PATCH"])
|
||||
@login_required
|
||||
@require_admin
|
||||
def update_organization(org_id):
|
||||
"""
|
||||
Update organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
|
||||
Request body:
|
||||
name: Optional organization name
|
||||
description: Optional description
|
||||
logo_url: Optional logo URL
|
||||
|
||||
Returns:
|
||||
200: Organization updated successfully
|
||||
400: Validation error
|
||||
401: Not authenticated
|
||||
403: Not an admin
|
||||
404: Organization not found
|
||||
"""
|
||||
try:
|
||||
# Validate request data
|
||||
schema = OrganizationUpdateSchema()
|
||||
data = schema.load(request.json)
|
||||
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
# Update organization
|
||||
org = OrganizationService.update_organization(
|
||||
org=org,
|
||||
user_id=g.current_user.id,
|
||||
**data
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data={"organization": org.to_dict()},
|
||||
message="Organization updated successfully",
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>", methods=["DELETE"])
|
||||
@login_required
|
||||
@require_owner
|
||||
def delete_organization(org_id):
|
||||
"""
|
||||
Delete organization (soft delete).
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
|
||||
Returns:
|
||||
200: Organization deleted successfully
|
||||
401: Not authenticated
|
||||
403: Not the owner
|
||||
404: Organization not found
|
||||
"""
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
OrganizationService.delete_organization(
|
||||
org=org,
|
||||
user_id=g.current_user.id,
|
||||
soft=True,
|
||||
)
|
||||
|
||||
return api_response(
|
||||
message="Organization deleted successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/members", methods=["GET"])
|
||||
@login_required
|
||||
def get_organization_members(org_id):
|
||||
"""
|
||||
Get all members of an organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
|
||||
Returns:
|
||||
200: List of members
|
||||
401: Not authenticated
|
||||
403: Not a member
|
||||
404: Organization not found
|
||||
"""
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
# Check if user is a member
|
||||
if not org.is_member(g.current_user.id):
|
||||
return api_response(
|
||||
success=False,
|
||||
message="You are not a member of this organization",
|
||||
status=403,
|
||||
error_type="AUTHORIZATION_ERROR",
|
||||
)
|
||||
|
||||
members_data = []
|
||||
for member in org.members:
|
||||
if member.deleted_at is None:
|
||||
member_dict = member.to_dict()
|
||||
member_dict["user"] = member.user.to_dict()
|
||||
members_data.append(member_dict)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"members": members_data,
|
||||
"count": len(members_data),
|
||||
},
|
||||
message="Members retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/members", methods=["POST"])
|
||||
@login_required
|
||||
@require_admin
|
||||
def add_organization_member(org_id):
|
||||
"""
|
||||
Add a member to the organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
|
||||
Request body:
|
||||
email: User email to invite
|
||||
role: Member role (owner, admin, member, guest)
|
||||
|
||||
Returns:
|
||||
201: Member added successfully
|
||||
400: Validation error
|
||||
401: Not authenticated
|
||||
403: Not an admin
|
||||
404: Organization or user not found
|
||||
409: User already a member
|
||||
"""
|
||||
try:
|
||||
# Validate request data
|
||||
schema = InviteMemberSchema()
|
||||
data = schema.load(request.json)
|
||||
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
# Find user by email
|
||||
user = UserService.get_user_by_email(data["email"])
|
||||
if not user:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="User not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Add member
|
||||
role = OrganizationRole(data["role"])
|
||||
member = OrganizationService.add_member(
|
||||
org=org,
|
||||
user_id=user.id,
|
||||
role=role,
|
||||
inviter_id=g.current_user.id,
|
||||
)
|
||||
|
||||
member_dict = member.to_dict()
|
||||
member_dict["user"] = user.to_dict()
|
||||
|
||||
return api_response(
|
||||
data={"member": member_dict},
|
||||
message="Member added successfully",
|
||||
status=201,
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/members/<user_id>", methods=["DELETE"])
|
||||
@login_required
|
||||
@require_admin
|
||||
def remove_organization_member(org_id, user_id):
|
||||
"""
|
||||
Remove a member from the organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
user_id: User ID to remove
|
||||
|
||||
Returns:
|
||||
200: Member removed successfully
|
||||
401: Not authenticated
|
||||
403: Not an admin
|
||||
404: Organization or member not found
|
||||
"""
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
OrganizationService.remove_member(
|
||||
org=org,
|
||||
user_id=user_id,
|
||||
remover_id=g.current_user.id,
|
||||
)
|
||||
|
||||
return api_response(
|
||||
message="Member removed successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/members/<user_id>/role", methods=["PATCH"])
|
||||
@login_required
|
||||
@require_admin
|
||||
def update_member_role(org_id, user_id):
|
||||
"""
|
||||
Update a member's role.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
user_id: User ID
|
||||
|
||||
Request body:
|
||||
role: New role (owner, admin, member, guest)
|
||||
|
||||
Returns:
|
||||
200: Role updated successfully
|
||||
400: Validation error
|
||||
401: Not authenticated
|
||||
403: Not an admin
|
||||
404: Organization or member not found
|
||||
"""
|
||||
try:
|
||||
# Validate request data
|
||||
schema = UpdateMemberRoleSchema()
|
||||
data = schema.load(request.json)
|
||||
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
# Update role
|
||||
new_role = OrganizationRole(data["role"])
|
||||
member = OrganizationService.update_member_role(
|
||||
org=org,
|
||||
user_id=user_id,
|
||||
new_role=new_role,
|
||||
updater_id=g.current_user.id,
|
||||
)
|
||||
|
||||
member_dict = member.to_dict()
|
||||
member_dict["user"] = member.user.to_dict()
|
||||
|
||||
return api_response(
|
||||
data={"member": member_dict},
|
||||
message="Member role updated successfully",
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
@@ -0,0 +1,155 @@
|
||||
"""User endpoints."""
|
||||
from flask import g, request
|
||||
from marshmallow import ValidationError
|
||||
from gatehouse_app.api.v1 import api_v1_bp
|
||||
from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.utils.decorators import login_required
|
||||
from gatehouse_app.schemas.user_schema import UserUpdateSchema, ChangePasswordSchema
|
||||
from gatehouse_app.services.user_service import UserService
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
|
||||
|
||||
@api_v1_bp.route("/users/me", methods=["GET"])
|
||||
@login_required
|
||||
def get_me():
|
||||
"""
|
||||
Get current user profile.
|
||||
|
||||
Returns:
|
||||
200: User profile data
|
||||
401: Not authenticated
|
||||
"""
|
||||
user = g.current_user
|
||||
|
||||
return api_response(
|
||||
data={"user": user.to_dict()},
|
||||
message="User profile retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/users/me", methods=["PATCH"])
|
||||
@login_required
|
||||
def update_me():
|
||||
"""
|
||||
Update current user profile.
|
||||
|
||||
Request body:
|
||||
full_name: Optional full name
|
||||
avatar_url: Optional avatar URL
|
||||
|
||||
Returns:
|
||||
200: User updated successfully
|
||||
400: Validation error
|
||||
401: Not authenticated
|
||||
"""
|
||||
try:
|
||||
# Validate request data
|
||||
schema = UserUpdateSchema()
|
||||
data = schema.load(request.json)
|
||||
|
||||
# Update user
|
||||
user = UserService.update_user(g.current_user, **data)
|
||||
|
||||
return api_response(
|
||||
data={"user": user.to_dict()},
|
||||
message="Profile updated successfully",
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/users/me", methods=["DELETE"])
|
||||
@login_required
|
||||
def delete_me():
|
||||
"""
|
||||
Delete current user account (soft delete).
|
||||
|
||||
Returns:
|
||||
200: Account deleted successfully
|
||||
401: Not authenticated
|
||||
"""
|
||||
UserService.delete_user(g.current_user, soft=True)
|
||||
|
||||
return api_response(
|
||||
message="Account deleted successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/users/me/password", methods=["POST"])
|
||||
@login_required
|
||||
def change_password():
|
||||
"""
|
||||
Change current user password.
|
||||
|
||||
Request body:
|
||||
current_password: Current password
|
||||
new_password: New password
|
||||
new_password_confirm: New password confirmation
|
||||
|
||||
Returns:
|
||||
200: Password changed successfully
|
||||
400: Validation error
|
||||
401: Not authenticated or invalid current password
|
||||
"""
|
||||
try:
|
||||
# Validate request data
|
||||
schema = ChangePasswordSchema()
|
||||
data = schema.load(request.json)
|
||||
|
||||
# Verify passwords match
|
||||
if data["new_password"] != data["new_password_confirm"]:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="New passwords do not match",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details={"new_password_confirm": ["Passwords do not match"]},
|
||||
)
|
||||
|
||||
# Change password
|
||||
AuthService.change_password(
|
||||
user=g.current_user,
|
||||
current_password=data["current_password"],
|
||||
new_password=data["new_password"],
|
||||
)
|
||||
|
||||
return api_response(
|
||||
message="Password changed successfully",
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/users/me/organizations", methods=["GET"])
|
||||
@login_required
|
||||
def get_my_organizations():
|
||||
"""
|
||||
Get all organizations current user is a member of.
|
||||
|
||||
Returns:
|
||||
200: List of organizations
|
||||
401: Not authenticated
|
||||
"""
|
||||
organizations = UserService.get_user_organizations(g.current_user)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"organizations": [org.to_dict() for org in organizations],
|
||||
"count": len(organizations),
|
||||
},
|
||||
message="Organizations retrieved successfully",
|
||||
)
|
||||
@@ -0,0 +1,40 @@
|
||||
"""Exceptions package."""
|
||||
from gatehouse_app.exceptions.base import BaseAPIException
|
||||
from gatehouse_app.exceptions.auth_exceptions import (
|
||||
UnauthorizedError,
|
||||
ForbiddenError,
|
||||
InvalidCredentialsError,
|
||||
AccountSuspendedError,
|
||||
AccountInactiveError,
|
||||
SessionExpiredError,
|
||||
InvalidTokenError,
|
||||
)
|
||||
from gatehouse_app.exceptions.validation_exceptions import (
|
||||
ValidationError,
|
||||
NotFoundError,
|
||||
ConflictError,
|
||||
BadRequestError,
|
||||
RateLimitExceededError,
|
||||
EmailAlreadyExistsError,
|
||||
OrganizationNotFoundError,
|
||||
UserNotFoundError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseAPIException",
|
||||
"UnauthorizedError",
|
||||
"ForbiddenError",
|
||||
"InvalidCredentialsError",
|
||||
"AccountSuspendedError",
|
||||
"AccountInactiveError",
|
||||
"SessionExpiredError",
|
||||
"InvalidTokenError",
|
||||
"ValidationError",
|
||||
"NotFoundError",
|
||||
"ConflictError",
|
||||
"BadRequestError",
|
||||
"RateLimitExceededError",
|
||||
"EmailAlreadyExistsError",
|
||||
"OrganizationNotFoundError",
|
||||
"UserNotFoundError",
|
||||
]
|
||||
@@ -0,0 +1,58 @@
|
||||
"""Authentication and authorization exceptions."""
|
||||
from gatehouse_app.exceptions.base import BaseAPIException
|
||||
|
||||
|
||||
class UnauthorizedError(BaseAPIException):
|
||||
"""Raised when authentication is required but not provided."""
|
||||
|
||||
status_code = 401
|
||||
error_type = "AUTHENTICATION_ERROR"
|
||||
message = "Authentication required"
|
||||
|
||||
|
||||
class ForbiddenError(BaseAPIException):
|
||||
"""Raised when user lacks permissions for the requested action."""
|
||||
|
||||
status_code = 403
|
||||
error_type = "AUTHORIZATION_ERROR"
|
||||
message = "You don't have permission to perform this action"
|
||||
|
||||
|
||||
class InvalidCredentialsError(BaseAPIException):
|
||||
"""Raised when login credentials are invalid."""
|
||||
|
||||
status_code = 401
|
||||
error_type = "AUTHENTICATION_ERROR"
|
||||
message = "Invalid email or password"
|
||||
|
||||
|
||||
class AccountSuspendedError(BaseAPIException):
|
||||
"""Raised when user account is suspended."""
|
||||
|
||||
status_code = 403
|
||||
error_type = "AUTHORIZATION_ERROR"
|
||||
message = "Your account has been suspended"
|
||||
|
||||
|
||||
class AccountInactiveError(BaseAPIException):
|
||||
"""Raised when user account is inactive."""
|
||||
|
||||
status_code = 403
|
||||
error_type = "AUTHORIZATION_ERROR"
|
||||
message = "Your account is inactive"
|
||||
|
||||
|
||||
class SessionExpiredError(BaseAPIException):
|
||||
"""Raised when user session has expired."""
|
||||
|
||||
status_code = 401
|
||||
error_type = "AUTHENTICATION_ERROR"
|
||||
message = "Your session has expired. Please log in again"
|
||||
|
||||
|
||||
class InvalidTokenError(BaseAPIException):
|
||||
"""Raised when authentication token is invalid."""
|
||||
|
||||
status_code = 401
|
||||
error_type = "AUTHENTICATION_ERROR"
|
||||
message = "Invalid authentication token"
|
||||
@@ -0,0 +1,31 @@
|
||||
"""Base exception classes."""
|
||||
|
||||
|
||||
class BaseAPIException(Exception):
|
||||
"""Base exception for all API errors."""
|
||||
|
||||
status_code = 500
|
||||
error_type = "INTERNAL_ERROR"
|
||||
message = "An unexpected error occurred"
|
||||
|
||||
def __init__(self, message=None, error_details=None):
|
||||
"""
|
||||
Initialize exception.
|
||||
|
||||
Args:
|
||||
message: Custom error message
|
||||
error_details: Additional error details dictionary
|
||||
"""
|
||||
super().__init__()
|
||||
if message:
|
||||
self.message = message
|
||||
self.error_details = error_details or {}
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert exception to dictionary for API response."""
|
||||
return {
|
||||
"error_type": self.error_type,
|
||||
"message": self.message,
|
||||
"details": self.error_details,
|
||||
"status_code": self.status_code,
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
"""Validation and resource exceptions."""
|
||||
from gatehouse_app.exceptions.base import BaseAPIException
|
||||
|
||||
|
||||
class ValidationError(BaseAPIException):
|
||||
"""Raised when request data validation fails."""
|
||||
|
||||
status_code = 400
|
||||
error_type = "VALIDATION_ERROR"
|
||||
message = "Validation failed"
|
||||
|
||||
|
||||
class NotFoundError(BaseAPIException):
|
||||
"""Raised when a requested resource is not found."""
|
||||
|
||||
status_code = 404
|
||||
error_type = "NOT_FOUND"
|
||||
message = "Resource not found"
|
||||
|
||||
|
||||
class ConflictError(BaseAPIException):
|
||||
"""Raised when a resource conflict occurs."""
|
||||
|
||||
status_code = 409
|
||||
error_type = "CONFLICT"
|
||||
message = "Resource conflict"
|
||||
|
||||
|
||||
class BadRequestError(BaseAPIException):
|
||||
"""Raised when the request is malformed or invalid."""
|
||||
|
||||
status_code = 400
|
||||
error_type = "BAD_REQUEST"
|
||||
message = "Bad request"
|
||||
|
||||
|
||||
class RateLimitExceededError(BaseAPIException):
|
||||
"""Raised when rate limit is exceeded."""
|
||||
|
||||
status_code = 429
|
||||
error_type = "RATE_LIMIT_EXCEEDED"
|
||||
message = "Too many requests. Please try again later"
|
||||
|
||||
|
||||
class EmailAlreadyExistsError(ConflictError):
|
||||
"""Raised when attempting to register with an existing email."""
|
||||
|
||||
message = "Email address already registered"
|
||||
|
||||
|
||||
class OrganizationNotFoundError(NotFoundError):
|
||||
"""Raised when organization is not found."""
|
||||
|
||||
message = "Organization not found"
|
||||
|
||||
|
||||
class UserNotFoundError(NotFoundError):
|
||||
"""Raised when user is not found."""
|
||||
|
||||
message = "User not found"
|
||||
@@ -0,0 +1,26 @@
|
||||
"""Flask extensions initialization."""
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
from flask_migrate import Migrate
|
||||
from flask_bcrypt import Bcrypt
|
||||
from flask_cors import CORS
|
||||
from flask_marshmallow import Marshmallow
|
||||
from flask_limiter import Limiter
|
||||
from flask_limiter.util import get_remote_address
|
||||
from flask_session import Session
|
||||
import redis
|
||||
|
||||
# Initialize extensions
|
||||
db = SQLAlchemy()
|
||||
migrate = Migrate()
|
||||
bcrypt = Bcrypt()
|
||||
cors = CORS()
|
||||
ma = Marshmallow()
|
||||
limiter = Limiter(
|
||||
key_func=get_remote_address,
|
||||
default_limits=["100 per hour"],
|
||||
storage_uri="memory://", # Will be overridden by config
|
||||
)
|
||||
session = Session()
|
||||
|
||||
# Redis client - will be initialized with app
|
||||
redis_client = None
|
||||
@@ -0,0 +1,6 @@
|
||||
"""Middleware package."""
|
||||
from gatehouse_app.middleware.request_id import RequestIDMiddleware
|
||||
from gatehouse_app.middleware.security_headers import SecurityHeadersMiddleware
|
||||
from gatehouse_app.middleware.cors import setup_cors
|
||||
|
||||
__all__ = ["RequestIDMiddleware", "SecurityHeadersMiddleware", "setup_cors"]
|
||||
@@ -0,0 +1,64 @@
|
||||
"""CORS middleware configuration."""
|
||||
from flask import request, make_response
|
||||
|
||||
|
||||
def setup_cors(app):
|
||||
"""
|
||||
Configure CORS for the application.
|
||||
|
||||
Args:
|
||||
app: Flask application instance
|
||||
"""
|
||||
|
||||
@app.before_request
|
||||
def handle_preflight():
|
||||
"""Handle CORS preflight OPTIONS requests."""
|
||||
if request.method == "OPTIONS":
|
||||
origin = request.headers.get("Origin")
|
||||
cors_origins = app.config.get("CORS_ORIGINS", [])
|
||||
|
||||
# Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard)
|
||||
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
|
||||
|
||||
if allow_all:
|
||||
response = make_response("", 204)
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID, Cache-Control, Pragma"
|
||||
response.headers["Access-Control-Max-Age"] = "3600"
|
||||
response.headers["Cache-Control"] = "no-cache, no-store"
|
||||
return response
|
||||
elif origin and origin in cors_origins:
|
||||
response = make_response("", 204)
|
||||
response.headers["Access-Control-Allow-Origin"] = origin
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID, Cache-Control, Pragma"
|
||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
response.headers["Access-Control-Max-Age"] = "3600"
|
||||
response.headers["Cache-Control"] = "no-cache, no-store"
|
||||
return response
|
||||
|
||||
@app.after_request
|
||||
def after_request_cors(response):
|
||||
"""Add additional CORS headers if needed."""
|
||||
origin = request.headers.get("Origin")
|
||||
cors_origins = app.config.get("CORS_ORIGINS", [])
|
||||
|
||||
# Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard)
|
||||
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
|
||||
|
||||
if allow_all:
|
||||
# When allowing all origins, set header to "*"
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID, Cache-Control, Pragma"
|
||||
response.headers["Access-Control-Max-Age"] = "3600"
|
||||
elif origin and origin in cors_origins:
|
||||
# When allowing specific origins, echo the request origin
|
||||
response.headers["Access-Control-Allow-Origin"] = origin
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Request-ID, Cache-Control, Pragma"
|
||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
response.headers["Access-Control-Max-Age"] = "3600"
|
||||
|
||||
return response
|
||||
@@ -0,0 +1,38 @@
|
||||
"""Request ID middleware for request tracing."""
|
||||
import uuid
|
||||
from flask import g, request
|
||||
|
||||
|
||||
class RequestIDMiddleware:
|
||||
"""Middleware to add unique request ID to each request."""
|
||||
|
||||
def __init__(self, app=None):
|
||||
"""Initialize middleware."""
|
||||
self.app = app
|
||||
if app is not None:
|
||||
self.init_app(app)
|
||||
|
||||
def init_app(self, app):
|
||||
"""Initialize with Flask app."""
|
||||
app.before_request(self.before_request)
|
||||
app.after_request(self.after_request)
|
||||
|
||||
@staticmethod
|
||||
def before_request():
|
||||
"""Generate or extract request ID before request processing."""
|
||||
# Check if request already has an ID from client
|
||||
request_id = request.headers.get("X-Request-ID")
|
||||
|
||||
# Generate new ID if not provided
|
||||
if not request_id:
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
# Store in Flask g object for access throughout request
|
||||
g.request_id = request_id
|
||||
|
||||
@staticmethod
|
||||
def after_request(response):
|
||||
"""Add request ID to response headers."""
|
||||
if hasattr(g, "request_id"):
|
||||
response.headers["X-Request-ID"] = g.request_id
|
||||
return response
|
||||
@@ -0,0 +1,65 @@
|
||||
"""Security headers middleware."""
|
||||
from flask import request
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware:
|
||||
"""Middleware to add security headers to responses."""
|
||||
|
||||
def __init__(self, app=None):
|
||||
"""Initialize middleware."""
|
||||
self.app = app
|
||||
if app is not None:
|
||||
self.init_app(app)
|
||||
|
||||
def init_app(self, app):
|
||||
"""Initialize with Flask app."""
|
||||
app.after_request(self.add_security_headers)
|
||||
|
||||
@staticmethod
|
||||
def add_security_headers(response):
|
||||
"""Add security headers to response."""
|
||||
# Prevent MIME type sniffing
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
|
||||
# Enable XSS protection
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
|
||||
# Prevent clickjacking
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
|
||||
# Strict Transport Security (HSTS)
|
||||
if request.is_secure:
|
||||
response.headers["Strict-Transport-Security"] = (
|
||||
"max-age=31536000; includeSubDomains"
|
||||
)
|
||||
|
||||
# Content Security Policy
|
||||
response.headers["Content-Security-Policy"] = (
|
||||
"default-src 'self'; "
|
||||
"script-src 'self' 'unsafe-inline'; "
|
||||
"style-src 'self' 'unsafe-inline'; "
|
||||
"img-src 'self' data: https:; "
|
||||
"font-src 'self' data:; "
|
||||
"connect-src 'self'"
|
||||
)
|
||||
|
||||
# Referrer Policy
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
|
||||
# Permissions Policy
|
||||
response.headers["Permissions-Policy"] = (
|
||||
"geolocation=(), microphone=(), camera=()"
|
||||
)
|
||||
|
||||
# Cache-Control: Allow OIDC endpoints to set their own Cache-Control
|
||||
# Only set no-cache for API responses that haven't set their own cache headers
|
||||
if "Cache-Control" not in response.headers:
|
||||
# Check if this is a JSON API response (shouldn't be cached)
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
if "application/json" in content_type:
|
||||
response.headers["Cache-Control"] = "no-cache, no-store"
|
||||
elif "text/html" not in content_type:
|
||||
# For non-HTML responses, add Pragma for HTTP/1.0 compatibility
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
|
||||
return response
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Models package."""
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.models.user import User
|
||||
from gatehouse_app.models.organization import Organization
|
||||
from gatehouse_app.models.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.models.session import Session
|
||||
from gatehouse_app.models.audit_log import AuditLog
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
from gatehouse_app.models.oidc_authorization_code import OIDCAuthCode
|
||||
from gatehouse_app.models.oidc_refresh_token import OIDCRefreshToken
|
||||
from gatehouse_app.models.oidc_session import OIDCSession
|
||||
from gatehouse_app.models.oidc_token_metadata import OIDCTokenMetadata
|
||||
from gatehouse_app.models.oidc_audit_log import OIDCAuditLog
|
||||
|
||||
__all__ = [
|
||||
"BaseModel",
|
||||
"User",
|
||||
"Organization",
|
||||
"OrganizationMember",
|
||||
"AuthenticationMethod",
|
||||
"Session",
|
||||
"AuditLog",
|
||||
"OIDCClient",
|
||||
"OIDCAuthCode",
|
||||
"OIDCRefreshToken",
|
||||
"OIDCSession",
|
||||
"OIDCTokenMetadata",
|
||||
"OIDCAuditLog",
|
||||
]
|
||||
@@ -0,0 +1,62 @@
|
||||
"""Audit log model."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import AuditAction
|
||||
|
||||
|
||||
class AuditLog(BaseModel):
|
||||
"""Audit log model for tracking user and system actions."""
|
||||
|
||||
__tablename__ = "audit_logs"
|
||||
|
||||
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True, index=True)
|
||||
action = db.Column(db.Enum(AuditAction), nullable=False, index=True)
|
||||
|
||||
# Context
|
||||
resource_type = db.Column(db.String(50), nullable=True, index=True)
|
||||
resource_id = db.Column(db.String(36), nullable=True, index=True)
|
||||
organization_id = db.Column(db.String(36), nullable=True, index=True)
|
||||
|
||||
# Request details
|
||||
ip_address = db.Column(db.String(45), nullable=True)
|
||||
user_agent = db.Column(db.Text, nullable=True)
|
||||
request_id = db.Column(db.String(36), nullable=True, index=True)
|
||||
|
||||
# Additional data
|
||||
extra_data = db.Column(db.JSON, nullable=True)
|
||||
description = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Success/failure
|
||||
success = db.Column(db.Boolean, default=True, nullable=False)
|
||||
error_message = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship("User", back_populates="audit_logs")
|
||||
|
||||
# Indexes for common queries
|
||||
__table_args__ = (
|
||||
db.Index("idx_audit_user_action", "user_id", "action"),
|
||||
db.Index("idx_audit_resource", "resource_type", "resource_id"),
|
||||
db.Index("idx_audit_org", "organization_id", "created_at"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of AuditLog."""
|
||||
return f"<AuditLog action={self.action} user_id={self.user_id}>"
|
||||
|
||||
@classmethod
|
||||
def log(cls, action, user_id=None, **kwargs):
|
||||
"""
|
||||
Create an audit log entry.
|
||||
|
||||
Args:
|
||||
action: AuditAction enum value
|
||||
user_id: ID of the user performing the action
|
||||
**kwargs: Additional audit log fields
|
||||
|
||||
Returns:
|
||||
AuditLog instance
|
||||
"""
|
||||
log_entry = cls(action=action, user_id=user_id, **kwargs)
|
||||
log_entry.save()
|
||||
return log_entry
|
||||
@@ -0,0 +1,93 @@
|
||||
"""Authentication method model."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
|
||||
class AuthenticationMethod(BaseModel):
|
||||
"""Authentication method model storing user authentication credentials."""
|
||||
|
||||
__tablename__ = "authentication_methods"
|
||||
|
||||
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=False, index=True)
|
||||
method_type = db.Column(db.Enum(AuthMethodType), nullable=False, index=True)
|
||||
|
||||
# For password authentication
|
||||
password_hash = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# For OAuth/OIDC providers
|
||||
provider_user_id = db.Column(db.String(255), nullable=True)
|
||||
provider_data = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# For TOTP authentication
|
||||
totp_secret = db.Column(db.String(32), nullable=True)
|
||||
totp_backup_codes = db.Column(db.JSON, nullable=True)
|
||||
totp_verified_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Metadata
|
||||
is_primary = db.Column(db.Boolean, default=False, nullable=False)
|
||||
verified = db.Column(db.Boolean, default=False, nullable=False)
|
||||
last_used_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship("User", back_populates="authentication_methods")
|
||||
|
||||
# Ensure unique provider combinations
|
||||
__table_args__ = (
|
||||
db.Index("idx_user_method", "user_id", "method_type"),
|
||||
db.UniqueConstraint(
|
||||
"user_id", "method_type", "provider_user_id", name="uix_user_method_provider"
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of AuthenticationMethod."""
|
||||
return f"<AuthenticationMethod user_id={self.user_id} type={self.method_type}>"
|
||||
|
||||
def is_password(self):
|
||||
"""Check if this is a password authentication method."""
|
||||
return self.method_type == AuthMethodType.PASSWORD
|
||||
|
||||
def is_oauth(self):
|
||||
"""Check if this is an OAuth authentication method."""
|
||||
return self.method_type in [
|
||||
AuthMethodType.GOOGLE,
|
||||
AuthMethodType.GITHUB,
|
||||
AuthMethodType.MICROSOFT,
|
||||
]
|
||||
|
||||
def is_totp(self):
|
||||
"""Check if this is a TOTP authentication method."""
|
||||
return self.method_type == AuthMethodType.TOTP
|
||||
|
||||
def is_webauthn(self):
|
||||
"""Check if this is a WebAuthn authentication method."""
|
||||
return self.method_type == AuthMethodType.WEBAUTHN
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude password hash and TOTP secrets
|
||||
exclude.append("password_hash")
|
||||
exclude.append("totp_secret")
|
||||
exclude.append("totp_backup_codes")
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
def to_webauthn_dict(self):
|
||||
"""Convert WebAuthn credential to public dictionary.
|
||||
|
||||
Returns:
|
||||
Dictionary with safe-to-expose credential information.
|
||||
"""
|
||||
if not self.is_webauthn() or not self.provider_data:
|
||||
return None
|
||||
|
||||
data = self.provider_data
|
||||
return {
|
||||
"id": data.get("credential_id"),
|
||||
"name": data.get("name"),
|
||||
"transports": data.get("transports", []),
|
||||
"created_at": data.get("created_at"),
|
||||
"last_used_at": data.get("last_used_at"),
|
||||
"sign_count": data.get("sign_count", 0),
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
"""Base model with common fields and functionality."""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
|
||||
|
||||
class BaseModel(db.Model):
|
||||
"""Base model class with common fields."""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
id = db.Column(
|
||||
db.String(36),
|
||||
primary_key=True,
|
||||
default=lambda: str(uuid.uuid4()),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
)
|
||||
created_at = db.Column(db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = db.Column(
|
||||
db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
deleted_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
def save(self):
|
||||
"""Save the model instance to database."""
|
||||
db.session.add(self)
|
||||
db.session.commit()
|
||||
return self
|
||||
|
||||
def delete(self, soft=True):
|
||||
"""
|
||||
Delete the model instance.
|
||||
|
||||
Args:
|
||||
soft: If True, performs soft delete. If False, hard delete.
|
||||
"""
|
||||
if soft:
|
||||
self.deleted_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
else:
|
||||
db.session.delete(self)
|
||||
db.session.commit()
|
||||
|
||||
def update(self, **kwargs):
|
||||
"""Update model fields."""
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
self.updated_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
return self
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""
|
||||
Convert model to dictionary.
|
||||
|
||||
Args:
|
||||
exclude: List of fields to exclude from output
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the model
|
||||
"""
|
||||
exclude = exclude or []
|
||||
result = {}
|
||||
for column in self.__table__.columns:
|
||||
if column.name not in exclude:
|
||||
value = getattr(self, column.name)
|
||||
if isinstance(value, datetime):
|
||||
result[column.name] = value.isoformat()
|
||||
else:
|
||||
result[column.name] = value
|
||||
return result
|
||||
@@ -0,0 +1,231 @@
|
||||
"""OIDC Audit Log model for comprehensive OIDC event tracking."""
|
||||
from datetime import datetime
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OIDCAuditLog(BaseModel):
|
||||
"""OIDC Audit Log model for comprehensive OIDC event tracking.
|
||||
|
||||
This model logs all OIDC-related events for security, compliance,
|
||||
and debugging purposes.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_audit_logs"
|
||||
|
||||
# Event type categorization
|
||||
event_type = db.Column(db.String(100), nullable=False, index=True)
|
||||
|
||||
# Client and User references
|
||||
client_id = db.Column(
|
||||
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=True, index=True
|
||||
)
|
||||
user_id = db.Column(
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=True, index=True
|
||||
)
|
||||
|
||||
# Event outcome
|
||||
success = db.Column(db.Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
# Error details (for failed events)
|
||||
error_code = db.Column(db.String(100), nullable=True)
|
||||
error_description = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Request context
|
||||
ip_address = db.Column(db.String(45), nullable=True, index=True)
|
||||
user_agent = db.Column(db.Text, nullable=True)
|
||||
request_id = db.Column(db.String(36), nullable=True, index=True)
|
||||
|
||||
# Additional event metadata
|
||||
event_metadata = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# Relationships
|
||||
client = db.relationship("OIDCClient", back_populates="audit_logs")
|
||||
user = db.relationship("User", back_populates="oidc_audit_logs")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCAuditLog."""
|
||||
status = "success" if self.success else "failed"
|
||||
return f"<OIDCAuditLog event={self.event_type} status={status} client={self.client_id}>"
|
||||
|
||||
@classmethod
|
||||
def log_event(cls, event_type, client_id=None, user_id=None, success=True,
|
||||
error_code=None, error_description=None, ip_address=None,
|
||||
user_agent=None, request_id=None, event_metadata=None):
|
||||
"""Log an OIDC event.
|
||||
|
||||
Args:
|
||||
event_type: Type of event (e.g., "authorization_request", "token_issue")
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID
|
||||
success: Whether the event was successful
|
||||
error_code: Error code if event failed
|
||||
error_description: Error description if event failed
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
request_id: Request ID for correlation
|
||||
event_metadata: Additional event metadata
|
||||
|
||||
Returns:
|
||||
OIDCAuditLog instance
|
||||
"""
|
||||
log = cls(
|
||||
event_type=event_type,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
event_metadata=event_metadata,
|
||||
)
|
||||
db.session.add(log)
|
||||
db.session.commit()
|
||||
return log
|
||||
|
||||
@classmethod
|
||||
def log_authorization_request(cls, client_id, user_id, redirect_uri, scope,
|
||||
ip_address=None, user_agent=None, request_id=None,
|
||||
success=True, error_code=None, error_description=None):
|
||||
"""Log an authorization request event."""
|
||||
return cls.log_event(
|
||||
event_type="authorization_request",
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
event_metadata={
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": scope,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_token_issue(cls, client_id, user_id, token_type,
|
||||
ip_address=None, user_agent=None, request_id=None):
|
||||
"""Log a token issuance event."""
|
||||
return cls.log_event(
|
||||
event_type="token_issue",
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=True,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
event_metadata={"token_type": token_type}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_token_revocation(cls, client_id, user_id, token_type, reason=None,
|
||||
ip_address=None, user_agent=None, request_id=None):
|
||||
"""Log a token revocation event."""
|
||||
return cls.log_event(
|
||||
event_type="token_revocation",
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=True,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
event_metadata={
|
||||
"token_type": token_type,
|
||||
"reason": reason,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_authentication_failure(cls, client_id, error_code, error_description,
|
||||
ip_address=None, user_agent=None, request_id=None):
|
||||
"""Log an authentication failure event."""
|
||||
return cls.log_event(
|
||||
event_type="authentication_failure",
|
||||
client_id=client_id,
|
||||
success=False,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_events_for_user(cls, user_id, limit=100):
|
||||
"""Get audit events for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of OIDCAuditLog instances
|
||||
"""
|
||||
return cls.query.filter_by(user_id=user_id, deleted_at=None)\
|
||||
.order_by(cls.created_at.desc())\
|
||||
.limit(limit)\
|
||||
.all()
|
||||
|
||||
@classmethod
|
||||
def get_events_for_client(cls, client_id, limit=100):
|
||||
"""Get audit events for a client.
|
||||
|
||||
Args:
|
||||
client_id: The client ID
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of OIDCAuditLog instances
|
||||
"""
|
||||
return cls.query.filter_by(client_id=client_id, deleted_at=None)\
|
||||
.order_by(cls.created_at.desc())\
|
||||
.limit(limit)\
|
||||
.all()
|
||||
|
||||
@classmethod
|
||||
def get_failed_events(cls, client_id=None, user_id=None, start_date=None,
|
||||
end_date=None, limit=100):
|
||||
"""Get failed audit events.
|
||||
|
||||
Args:
|
||||
client_id: Optional client ID filter
|
||||
user_id: Optional user ID filter
|
||||
start_date: Optional start date filter
|
||||
end_date: Optional end date filter
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of OIDCAuditLog instances
|
||||
"""
|
||||
query = cls.query.filter_by(success=False, deleted_at=None)
|
||||
if client_id:
|
||||
query = query.filter_by(client_id=client_id)
|
||||
if user_id:
|
||||
query = query.filter_by(user_id=user_id)
|
||||
if start_date:
|
||||
query = query.filter(cls.created_at >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(cls.created_at <= end_date)
|
||||
|
||||
return query.order_by(cls.created_at.desc()).limit(limit).all()
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from gatehouse_app.models.user import User
|
||||
User.oidc_audit_logs = db.relationship(
|
||||
"OIDCAuditLog", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
OIDCClient.audit_logs = db.relationship(
|
||||
"OIDCAuditLog", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -0,0 +1,125 @@
|
||||
"""OIDC Authorization Code model for auth code flow."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OIDCAuthCode(BaseModel):
|
||||
"""OIDC Authorization Code model for authorization code flow.
|
||||
|
||||
Authorization codes are single-use, short-lived codes used in the
|
||||
authorization code grant flow. The code is hashed for security.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_authorization_codes"
|
||||
|
||||
# Client and User references
|
||||
client_id = db.Column(
|
||||
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
|
||||
)
|
||||
user_id = db.Column(
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# Authorization code (hashed for security)
|
||||
code_hash = db.Column(db.String(255), nullable=False)
|
||||
|
||||
# Request parameters
|
||||
redirect_uri = db.Column(db.String(512), nullable=False)
|
||||
scope = db.Column(db.JSON, nullable=True) # Requested scopes
|
||||
nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation
|
||||
code_verifier = db.Column(db.String(255), nullable=True) # For PKCE
|
||||
|
||||
# Status tracking
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
used_at = db.Column(db.DateTime, nullable=True)
|
||||
is_used = db.Column(db.Boolean, default=False, nullable=False)
|
||||
|
||||
# Request metadata
|
||||
ip_address = db.Column(db.String(45), nullable=True)
|
||||
user_agent = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
client = db.relationship("OIDCClient", back_populates="authorization_codes")
|
||||
user = db.relationship("User", back_populates="oidc_auth_codes")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCAuthCode."""
|
||||
return f"<OIDCAuthCode client_id={self.client_id} user_id={self.user_id} used={self.is_used}>"
|
||||
|
||||
def is_expired(self):
|
||||
"""Check if the authorization code has expired."""
|
||||
# Handle both timezone-aware and timezone-naive expires_at values
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
# Make naive datetime timezone-aware (UTC)
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return datetime.now(timezone.utc) > expires_at
|
||||
|
||||
def is_valid(self):
|
||||
"""Check if the authorization code is valid for use."""
|
||||
return not self.is_used and not self.is_expired() and self.deleted_at is None
|
||||
|
||||
def mark_as_used(self):
|
||||
"""Mark the authorization code as used."""
|
||||
self.is_used = True
|
||||
self.used_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def create_code(cls, client_id, user_id, code_hash, redirect_uri, scope=None,
|
||||
nonce=None, code_verifier=None, ip_address=None, user_agent=None,
|
||||
lifetime_seconds=600):
|
||||
"""Create a new authorization code.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID
|
||||
code_hash: Hashed authorization code
|
||||
redirect_uri: The redirect URI
|
||||
scope: Requested scopes
|
||||
nonce: OIDC nonce
|
||||
code_verifier: PKCE code verifier
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
lifetime_seconds: Code lifetime in seconds (default 10 minutes)
|
||||
|
||||
Returns:
|
||||
OIDCAuthCode instance
|
||||
"""
|
||||
code = cls(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
code_hash=code_hash,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
nonce=nonce,
|
||||
code_verifier=code_verifier,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
db.session.add(code)
|
||||
db.session.commit()
|
||||
return code
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude code hash
|
||||
exclude.append("code_hash")
|
||||
exclude.append("code_verifier")
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from gatehouse_app.models.user import User
|
||||
User.oidc_auth_codes = db.relationship(
|
||||
"OIDCAuthCode", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
OIDCClient.authorization_codes = db.relationship(
|
||||
"OIDCAuthCode", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -0,0 +1,69 @@
|
||||
"""OIDC Client model."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import OIDCGrantType, OIDCResponseType
|
||||
|
||||
|
||||
class OIDCClient(BaseModel):
|
||||
"""OIDC client model for OAuth2/OIDC integrations."""
|
||||
|
||||
__tablename__ = "oidc_clients"
|
||||
|
||||
organization_id = db.Column(
|
||||
db.String(36), db.ForeignKey("organizations.id"), nullable=False, index=True
|
||||
)
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
client_id = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
client_secret_hash = db.Column(db.String(255), nullable=False)
|
||||
|
||||
# OAuth/OIDC configuration
|
||||
redirect_uris = db.Column(db.JSON, nullable=False) # List of allowed redirect URIs
|
||||
grant_types = db.Column(db.JSON, nullable=False) # List of allowed grant types
|
||||
response_types = db.Column(db.JSON, nullable=False) # List of allowed response types
|
||||
scopes = db.Column(db.JSON, nullable=False) # List of allowed scopes
|
||||
|
||||
# Client metadata
|
||||
logo_uri = db.Column(db.String(512), nullable=True)
|
||||
client_uri = db.Column(db.String(512), nullable=True)
|
||||
policy_uri = db.Column(db.String(512), nullable=True)
|
||||
tos_uri = db.Column(db.String(512), nullable=True)
|
||||
|
||||
# Settings
|
||||
is_active = db.Column(db.Boolean, default=True, nullable=False)
|
||||
is_confidential = db.Column(db.Boolean, default=True, nullable=False)
|
||||
require_pkce = db.Column(db.Boolean, default=True, nullable=False)
|
||||
|
||||
# Token lifetimes (in seconds)
|
||||
access_token_lifetime = db.Column(db.Integer, default=3600, nullable=False)
|
||||
refresh_token_lifetime = db.Column(db.Integer, default=2592000, nullable=False)
|
||||
id_token_lifetime = db.Column(db.Integer, default=3600, nullable=False)
|
||||
|
||||
# Relationships
|
||||
organization = db.relationship("Organization", back_populates="oidc_clients")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCClient."""
|
||||
return f"<OIDCClient {self.name} client_id={self.client_id}>"
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude client secret
|
||||
exclude.append("client_secret_hash")
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
def has_grant_type(self, grant_type):
|
||||
"""Check if client supports a specific grant type."""
|
||||
return grant_type in self.grant_types
|
||||
|
||||
def has_response_type(self, response_type):
|
||||
"""Check if client supports a specific response type."""
|
||||
return response_type in self.response_types
|
||||
|
||||
def is_redirect_uri_allowed(self, redirect_uri):
|
||||
"""Check if a redirect URI is allowed for this client."""
|
||||
return redirect_uri in self.redirect_uris
|
||||
|
||||
def has_scope(self, scope):
|
||||
"""Check if client is allowed to request a specific scope."""
|
||||
return scope in self.scopes
|
||||
@@ -0,0 +1,77 @@
|
||||
"""OIDC JWKS Key model for persisting signing keys."""
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OidcJwksKey(BaseModel):
|
||||
"""
|
||||
OIDC JWKS Key model for persisting JSON Web Key Set signing keys.
|
||||
|
||||
This model stores RSA/ECDSA key pairs used for signing OIDC tokens.
|
||||
Multiple keys can be stored to support key rotation scenarios.
|
||||
|
||||
Attributes:
|
||||
id: Integer primary key
|
||||
kid: Unique key ID used in JWT "kid" header
|
||||
key_type: Type of key (e.g., "RSA", "EC")
|
||||
private_key: PEM-encoded private key
|
||||
public_key: PEM-encoded public key
|
||||
algorithm: Signing algorithm (e.g., "RS256", "ES256")
|
||||
created_at: When the key was created
|
||||
is_active: Whether this key is currently active for signing
|
||||
is_primary: Whether this is the primary signing key
|
||||
expires_at: ...
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_jwks_keys"
|
||||
|
||||
# Override the default UUID id with integer primary key
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
|
||||
expires_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Key identification and type
|
||||
kid = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
key_type = db.Column(db.String(50), nullable=False) # e.g., "RSA", "EC"
|
||||
algorithm = db.Column(db.String(50), nullable=False) # e.g., "RS256", "ES256"
|
||||
|
||||
# Key material (PEM-encoded)
|
||||
private_key = db.Column(db.Text, nullable=False)
|
||||
public_key = db.Column(db.Text, nullable=False)
|
||||
|
||||
# Key status
|
||||
is_active = db.Column(db.Boolean, default=True, nullable=False)
|
||||
is_primary = db.Column(db.Boolean, default=False, nullable=False)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OidcJwksKey."""
|
||||
return f"<OidcJwksKey kid={self.kid} key_type={self.key_type} algorithm={self.algorithm}>"
|
||||
|
||||
def to_dict(self, exclude_private_key=True):
|
||||
"""
|
||||
Convert model to dictionary.
|
||||
|
||||
Args:
|
||||
exclude_private_key: If True, excludes the private key from output
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the model
|
||||
"""
|
||||
exclude = ["private_key"] if exclude_private_key else []
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
@classmethod
|
||||
def get_active_keys(cls):
|
||||
"""Get all active keys for signing operations."""
|
||||
return cls.query.filter(cls.is_active == True).all()
|
||||
|
||||
@classmethod
|
||||
def get_primary_key(cls):
|
||||
"""Get the primary signing key."""
|
||||
return cls.query.filter(cls.is_primary == True).first()
|
||||
|
||||
@classmethod
|
||||
def get_key_by_kid(cls, kid):
|
||||
"""Get a key by its key ID."""
|
||||
return cls.query.filter(cls.kid == kid, cls.is_active == True).first()
|
||||
@@ -0,0 +1,163 @@
|
||||
"""OIDC Refresh Token model for token rotation."""
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OIDCRefreshToken(BaseModel):
|
||||
"""OIDC Refresh Token model for token refresh and rotation.
|
||||
|
||||
Refresh tokens are long-lived credentials used to obtain new access tokens.
|
||||
They support token rotation for enhanced security.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_refresh_tokens"
|
||||
|
||||
# Client and User references
|
||||
client_id = db.Column(
|
||||
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
|
||||
)
|
||||
user_id = db.Column(
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# Token (hashed for security)
|
||||
token_hash = db.Column(db.String(255), nullable=False, unique=True, index=True)
|
||||
|
||||
# Associated access token ID
|
||||
access_token_id = db.Column(
|
||||
db.String(36), db.ForeignKey("sessions.id"), nullable=True, index=True
|
||||
)
|
||||
|
||||
# Token scope
|
||||
scope = db.Column(db.JSON, nullable=True) # Granted scopes
|
||||
|
||||
# Timing
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
|
||||
# Revocation tracking
|
||||
revoked_at = db.Column(db.DateTime, nullable=True)
|
||||
revoked_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Token rotation metadata
|
||||
previous_token_hash = db.Column(db.String(255), nullable=True) # For rotation
|
||||
rotation_count = db.Column(db.Integer, default=0, nullable=False)
|
||||
|
||||
# Request metadata
|
||||
ip_address = db.Column(db.String(45), nullable=True)
|
||||
user_agent = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
client = db.relationship("OIDCClient", back_populates="refresh_tokens")
|
||||
user = db.relationship("User", back_populates="oidc_refresh_tokens")
|
||||
access_token = db.relationship("Session", back_populates="oidc_refresh_token")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCRefreshToken."""
|
||||
return f"<OIDCRefreshToken client_id={self.client_id} user_id={self.user_id} revoked={self.is_revoked()}>"
|
||||
|
||||
def is_expired(self):
|
||||
"""Check if the refresh token has expired."""
|
||||
# Handle both timezone-aware and timezone-naive expires_at values
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return datetime.now(timezone.utc) > expires_at
|
||||
|
||||
def is_revoked(self):
|
||||
"""Check if the refresh token has been revoked."""
|
||||
return self.revoked_at is not None
|
||||
|
||||
def is_valid(self):
|
||||
"""Check if the refresh token is valid for use."""
|
||||
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
|
||||
|
||||
def revoke(self, reason=None):
|
||||
"""Revoke the refresh token.
|
||||
|
||||
Args:
|
||||
reason: Optional reason for revocation
|
||||
"""
|
||||
self.revoked_at = datetime.now(timezone.utc)
|
||||
self.revoked_reason = reason
|
||||
db.session.commit()
|
||||
|
||||
def rotate(self, new_token_hash):
|
||||
"""Rotate the refresh token (invalidate old, create new).
|
||||
|
||||
Args:
|
||||
new_token_hash: Hash of the new refresh token
|
||||
|
||||
Returns:
|
||||
self for chaining
|
||||
"""
|
||||
# Store reference to old token
|
||||
self.previous_token_hash = self.token_hash
|
||||
self.token_hash = new_token_hash
|
||||
self.rotation_count += 1
|
||||
# Extend expiration on rotation
|
||||
from datetime import timedelta
|
||||
self.expires_at = datetime.now(timezone.utc) + timedelta(days=30)
|
||||
db.session.commit()
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def create_token(cls, client_id, user_id, token_hash, scope=None,
|
||||
access_token_id=None, ip_address=None, user_agent=None,
|
||||
lifetime_seconds=2592000):
|
||||
"""Create a new refresh token.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID
|
||||
token_hash: Hashed refresh token
|
||||
scope: Granted scopes
|
||||
access_token_id: Associated access token ID
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
lifetime_seconds: Token lifetime in seconds (default 30 days)
|
||||
|
||||
Returns:
|
||||
OIDCRefreshToken instance
|
||||
"""
|
||||
from datetime import timedelta
|
||||
token = cls(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
token_hash=token_hash,
|
||||
scope=scope,
|
||||
access_token_id=access_token_id,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
db.session.add(token)
|
||||
db.session.commit()
|
||||
return token
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude token hashes
|
||||
exclude.append("token_hash")
|
||||
exclude.append("previous_token_hash")
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from gatehouse_app.models.user import User
|
||||
User.oidc_refresh_tokens = db.relationship(
|
||||
"OIDCRefreshToken", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
OIDCClient.refresh_tokens = db.relationship(
|
||||
"OIDCRefreshToken", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to Session model
|
||||
from gatehouse_app.models.session import Session
|
||||
Session.oidc_refresh_token = db.relationship(
|
||||
"OIDCRefreshToken", back_populates="access_token", uselist=False
|
||||
)
|
||||
@@ -0,0 +1,162 @@
|
||||
"""OIDC Session model for OIDC session tracking."""
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OIDCSession(BaseModel):
|
||||
"""OIDC Session model for tracking OIDC authentication sessions.
|
||||
|
||||
This model tracks the state during the OIDC authentication flow,
|
||||
including PKCE parameters and nonce validation.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_sessions"
|
||||
|
||||
# User reference
|
||||
user_id = db.Column(
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# Client reference
|
||||
client_id = db.Column(
|
||||
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# State management
|
||||
state = db.Column(db.String(255), nullable=False, index=True)
|
||||
nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation
|
||||
|
||||
# Authorization request parameters
|
||||
redirect_uri = db.Column(db.String(512), nullable=False)
|
||||
scope = db.Column(db.JSON, nullable=True) # Requested scopes
|
||||
|
||||
# PKCE parameters
|
||||
code_challenge = db.Column(db.String(255), nullable=True)
|
||||
code_challenge_method = db.Column(db.String(10), nullable=True) # "S256" or "plain"
|
||||
|
||||
# Timing
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
authenticated_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship("User", back_populates="oidc_sessions")
|
||||
client = db.relationship("OIDCClient", back_populates="oidc_sessions")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCSession."""
|
||||
return f"<OIDCSession user_id={self.user_id} client_id={self.client_id} state={self.state[:8]}...>"
|
||||
|
||||
def is_expired(self):
|
||||
"""Check if the OIDC session has expired."""
|
||||
return datetime.now(timezone.utc) > self.expires_at
|
||||
|
||||
def is_authenticated(self):
|
||||
"""Check if the user has been authenticated in this session."""
|
||||
return self.authenticated_at is not None
|
||||
|
||||
def mark_authenticated(self):
|
||||
"""Mark the session as authenticated."""
|
||||
self.authenticated_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
def validate_nonce(self, expected_nonce):
|
||||
"""Validate the nonce matches the expected value.
|
||||
|
||||
Args:
|
||||
expected_nonce: The expected nonce value
|
||||
|
||||
Returns:
|
||||
bool: True if nonce matches
|
||||
"""
|
||||
return self.nonce == expected_nonce
|
||||
|
||||
def validate_code_challenge(self, code_verifier):
|
||||
"""Validate the code verifier against the stored code challenge.
|
||||
|
||||
Args:
|
||||
code_verifier: The PKCE code verifier
|
||||
|
||||
Returns:
|
||||
bool: True if code challenge is valid
|
||||
"""
|
||||
if not self.code_challenge:
|
||||
return False
|
||||
|
||||
if self.code_challenge_method == "S256":
|
||||
import hashlib
|
||||
import base64
|
||||
# SHA256 hash of code_verifier
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
# Base64 URL encode without padding
|
||||
expected = base64.urlsafe_b64encode(digest).decode().rstrip("=")
|
||||
return self.code_challenge == expected
|
||||
elif self.code_challenge_method == "plain":
|
||||
return self.code_challenge == code_verifier
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def create_session(cls, user_id, client_id, state, redirect_uri, scope=None,
|
||||
nonce=None, code_challenge=None, code_challenge_method=None,
|
||||
lifetime_seconds=600):
|
||||
"""Create a new OIDC session.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
client_id: The OIDC client ID
|
||||
state: The state parameter
|
||||
redirect_uri: The redirect URI
|
||||
scope: Requested scopes
|
||||
nonce: OIDC nonce
|
||||
code_challenge: PKCE code challenge
|
||||
code_challenge_method: PKCE method ("S256" or "plain")
|
||||
lifetime_seconds: Session lifetime in seconds
|
||||
|
||||
Returns:
|
||||
OIDCSession instance
|
||||
"""
|
||||
from datetime import timedelta
|
||||
session = cls(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
state=state,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
nonce=nonce,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds),
|
||||
)
|
||||
db.session.add(session)
|
||||
db.session.commit()
|
||||
return session
|
||||
|
||||
@classmethod
|
||||
def get_by_state(cls, state):
|
||||
"""Get a session by state parameter.
|
||||
|
||||
Args:
|
||||
state: The state parameter
|
||||
|
||||
Returns:
|
||||
OIDCSession instance or None
|
||||
"""
|
||||
return cls.query.filter_by(state=state, deleted_at=None).first()
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from gatehouse_app.models.user import User
|
||||
User.oidc_sessions = db.relationship(
|
||||
"OIDCSession", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
OIDCClient.oidc_sessions = db.relationship(
|
||||
"OIDCSession", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -0,0 +1,196 @@
|
||||
"""OIDC Token Metadata model for token revocation tracking."""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OIDCTokenMetadata(BaseModel):
|
||||
"""OIDC Token Metadata model for tracking issued tokens.
|
||||
|
||||
This model stores metadata about issued tokens (access tokens, refresh tokens, ID tokens)
|
||||
for the purpose of token revocation. The id field matches the JTI (JWT ID) claim.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_token_metadata"
|
||||
|
||||
# Token identifier (matches JTI in JWT)
|
||||
id = db.Column(
|
||||
db.String(36), primary_key=True, default=lambda: str(uuid.uuid4())
|
||||
)
|
||||
|
||||
# Client and User references
|
||||
client_id = db.Column(
|
||||
db.String(255), db.ForeignKey("oidc_clients.id"), nullable=False, index=True
|
||||
)
|
||||
user_id = db.Column(
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# Token type
|
||||
token_type = db.Column(db.String(50), nullable=False) # "access_token", "refresh_token", "id_token"
|
||||
|
||||
# Token identifier for revocation lookup
|
||||
token_jti = db.Column(db.String(255), nullable=False, index=True) # JWT ID claim
|
||||
|
||||
# Timing
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
|
||||
# Revocation tracking
|
||||
revoked_at = db.Column(db.DateTime, nullable=True)
|
||||
revoked_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Relationships
|
||||
client = db.relationship("OIDCClient", back_populates="token_metadata")
|
||||
user = db.relationship("User", back_populates="oidc_token_metadata")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCTokenMetadata."""
|
||||
return f"<OIDCTokenMetadata jti={self.token_jti[:8]}... type={self.token_type} revoked={self.is_revoked()}>"
|
||||
|
||||
def is_expired(self):
|
||||
"""Check if the token has expired."""
|
||||
# Handle both timezone-aware and timezone-naive expires_at values
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return datetime.now(timezone.utc) > expires_at
|
||||
|
||||
def is_revoked(self):
|
||||
"""Check if the token has been revoked."""
|
||||
return self.revoked_at is not None
|
||||
|
||||
def is_valid(self):
|
||||
"""Check if the token is valid (not expired and not revoked)."""
|
||||
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
|
||||
|
||||
def revoke(self, reason=None):
|
||||
"""Revoke the token.
|
||||
|
||||
Args:
|
||||
reason: Optional reason for revocation
|
||||
"""
|
||||
self.revoked_at = datetime.now(timezone.utc)
|
||||
self.revoked_reason = reason
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def create_metadata(cls, client_id, user_id, token_type, token_jti,
|
||||
expires_at, ip_address=None, user_agent=None):
|
||||
"""Create token metadata for tracking.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID
|
||||
token_type: Type of token ("access_token", "refresh_token", "id_token")
|
||||
token_jti: JWT ID claim
|
||||
expires_at: Token expiration datetime
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
|
||||
Returns:
|
||||
OIDCTokenMetadata instance
|
||||
"""
|
||||
metadata = cls(
|
||||
id=str(uuid.uuid4()),
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
token_type=token_type,
|
||||
token_jti=token_jti,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.session.add(metadata)
|
||||
db.session.commit()
|
||||
return metadata
|
||||
|
||||
@classmethod
|
||||
def get_by_jti(cls, token_jti):
|
||||
"""Get token metadata by JWT ID.
|
||||
|
||||
Args:
|
||||
token_jti: The JWT ID
|
||||
|
||||
Returns:
|
||||
OIDCTokenMetadata instance or None
|
||||
"""
|
||||
return cls.query.filter_by(token_jti=token_jti, deleted_at=None).first()
|
||||
|
||||
@classmethod
|
||||
def revoke_by_jti(cls, token_jti, reason=None):
|
||||
"""Revoke a token by its JWT ID.
|
||||
|
||||
Args:
|
||||
token_jti: The JWT ID
|
||||
reason: Optional revocation reason
|
||||
|
||||
Returns:
|
||||
bool: True if token was found and revoked
|
||||
"""
|
||||
metadata = cls.get_by_jti(token_jti)
|
||||
if metadata:
|
||||
metadata.revoke(reason)
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def revoke_all_for_user(cls, user_id, client_id=None, reason=None):
|
||||
"""Revoke all tokens for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
client_id: Optional client ID to filter by
|
||||
reason: Optional revocation reason
|
||||
|
||||
Returns:
|
||||
int: Number of tokens revoked
|
||||
"""
|
||||
query = cls.query.filter_by(user_id=user_id, deleted_at=None)
|
||||
if client_id:
|
||||
query = query.filter_by(client_id=client_id)
|
||||
|
||||
tokens = query.filter(cls.revoked_at == None).all()
|
||||
count = 0
|
||||
for token in tokens:
|
||||
token.revoke(reason)
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@classmethod
|
||||
def revoke_all_for_client(cls, client_id, user_id=None, reason=None):
|
||||
"""Revoke all tokens for a client.
|
||||
|
||||
Args:
|
||||
client_id: The client ID
|
||||
user_id: Optional user ID to filter by
|
||||
reason: Optional revocation reason
|
||||
|
||||
Returns:
|
||||
int: Number of tokens revoked
|
||||
"""
|
||||
query = cls.query.filter_by(client_id=client_id, deleted_at=None)
|
||||
if user_id:
|
||||
query = query.filter_by(user_id=user_id)
|
||||
|
||||
tokens = query.filter(cls.revoked_at == None).all()
|
||||
count = 0
|
||||
for token in tokens:
|
||||
token.revoke(reason)
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from gatehouse_app.models.user import User
|
||||
User.oidc_token_metadata = db.relationship(
|
||||
"OIDCTokenMetadata", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
OIDCClient.token_metadata = db.relationship(
|
||||
"OIDCTokenMetadata", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -0,0 +1,54 @@
|
||||
"""Organization model."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class Organization(BaseModel):
|
||||
"""Organization model representing a tenant/workspace."""
|
||||
|
||||
__tablename__ = "organizations"
|
||||
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
slug = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
description = db.Column(db.Text, nullable=True)
|
||||
logo_url = db.Column(db.String(512), nullable=True)
|
||||
is_active = db.Column(db.Boolean, default=True, nullable=False)
|
||||
|
||||
# Settings (stored as JSON)
|
||||
settings = db.Column(db.JSON, nullable=True, default=dict)
|
||||
|
||||
# Relationships
|
||||
members = db.relationship(
|
||||
"OrganizationMember", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
oidc_clients = db.relationship(
|
||||
"OIDCClient", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of Organization."""
|
||||
return f"<Organization {self.name}>"
|
||||
|
||||
def get_member_count(self):
|
||||
"""Get the count of active members in the organization."""
|
||||
return len([m for m in self.members if m.deleted_at is None])
|
||||
|
||||
def get_owner(self):
|
||||
"""Get the owner of the organization."""
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
|
||||
for member in self.members:
|
||||
if member.role == OrganizationRole.OWNER and member.deleted_at is None:
|
||||
return member.user
|
||||
return None
|
||||
|
||||
def is_member(self, user_id):
|
||||
"""Check if a user is a member of the organization."""
|
||||
from gatehouse_app.models.organization_member import OrganizationMember
|
||||
|
||||
return (
|
||||
OrganizationMember.query.filter_by(
|
||||
user_id=user_id, organization_id=self.id, deleted_at=None
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
@@ -0,0 +1,51 @@
|
||||
"""Organization member model."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
|
||||
|
||||
class OrganizationMember(BaseModel):
|
||||
"""Organization member model representing user membership in an organization."""
|
||||
|
||||
__tablename__ = "organization_members"
|
||||
|
||||
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=False, index=True)
|
||||
organization_id = db.Column(
|
||||
db.String(36), db.ForeignKey("organizations.id"), nullable=False, index=True
|
||||
)
|
||||
role = db.Column(
|
||||
db.Enum(OrganizationRole), default=OrganizationRole.MEMBER, nullable=False
|
||||
)
|
||||
invited_by_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True)
|
||||
invited_at = db.Column(db.DateTime, nullable=True)
|
||||
joined_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship("User", foreign_keys=[user_id], back_populates="organization_memberships")
|
||||
organization = db.relationship("Organization", back_populates="members")
|
||||
invited_by = db.relationship("User", foreign_keys=[invited_by_id])
|
||||
|
||||
# Unique constraint to prevent duplicate memberships
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint("user_id", "organization_id", name="uix_user_org"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OrganizationMember."""
|
||||
return f"<OrganizationMember user_id={self.user_id} org_id={self.organization_id} role={self.role}>"
|
||||
|
||||
def is_owner(self):
|
||||
"""Check if member is an owner."""
|
||||
return self.role == OrganizationRole.OWNER
|
||||
|
||||
def is_admin(self):
|
||||
"""Check if member is an admin or owner."""
|
||||
return self.role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
||||
|
||||
def can_manage_members(self):
|
||||
"""Check if member can manage other members."""
|
||||
return self.is_admin()
|
||||
|
||||
def can_delete_organization(self):
|
||||
"""Check if member can delete the organization."""
|
||||
return self.is_owner()
|
||||
@@ -0,0 +1,86 @@
|
||||
"""Session model."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import SessionStatus
|
||||
|
||||
|
||||
class Session(BaseModel):
|
||||
"""Session model for tracking user sessions."""
|
||||
|
||||
__tablename__ = "sessions"
|
||||
|
||||
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=False, index=True)
|
||||
token = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
status = db.Column(db.Enum(SessionStatus), default=SessionStatus.ACTIVE, nullable=False)
|
||||
|
||||
# Session metadata
|
||||
ip_address = db.Column(db.String(45), nullable=True)
|
||||
user_agent = db.Column(db.Text, nullable=True)
|
||||
device_info = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# Timing
|
||||
expires_at = db.Column(db.DateTime, nullable=False)
|
||||
last_activity_at = db.Column(db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc))
|
||||
revoked_at = db.Column(db.DateTime, nullable=True)
|
||||
revoked_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship("User", back_populates="sessions")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of Session."""
|
||||
return f"<Session user_id={self.user_id} status={self.status}>"
|
||||
|
||||
def is_active(self):
|
||||
"""Check if session is currently active."""
|
||||
now = datetime.now(timezone.utc)
|
||||
# Make expires_at timezone-aware if it's naive
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return (
|
||||
self.status == SessionStatus.ACTIVE
|
||||
and expires_at > now
|
||||
and self.deleted_at is None
|
||||
)
|
||||
|
||||
def is_expired(self):
|
||||
"""Check if session has expired."""
|
||||
now = datetime.now(timezone.utc)
|
||||
# Make expires_at timezone-aware if it's naive
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return now > expires_at
|
||||
|
||||
def refresh(self, duration_seconds=86400):
|
||||
"""
|
||||
Refresh session expiration.
|
||||
|
||||
Args:
|
||||
duration_seconds: New session duration in seconds
|
||||
"""
|
||||
self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=duration_seconds)
|
||||
self.last_activity_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
def revoke(self, reason=None):
|
||||
"""
|
||||
Revoke the session.
|
||||
|
||||
Args:
|
||||
reason: Optional reason for revocation
|
||||
"""
|
||||
self.status = SessionStatus.REVOKED
|
||||
self.revoked_at = datetime.now(timezone.utc)
|
||||
if reason:
|
||||
self.revoked_reason = reason
|
||||
db.session.commit()
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Exclude token from dict
|
||||
exclude.append("token")
|
||||
return super().to_dict(exclude=exclude)
|
||||
@@ -0,0 +1,141 @@
|
||||
"""User model."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import UserStatus
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
"""User model representing a user account."""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
email = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
email_verified = db.Column(db.Boolean, default=False, nullable=False)
|
||||
full_name = db.Column(db.String(255), nullable=True)
|
||||
avatar_url = db.Column(db.String(512), nullable=True)
|
||||
status = db.Column(
|
||||
db.Enum(UserStatus), default=UserStatus.ACTIVE, nullable=False, index=True
|
||||
)
|
||||
last_login_at = db.Column(db.DateTime, nullable=True)
|
||||
last_login_ip = db.Column(db.String(45), nullable=True)
|
||||
|
||||
# Relationships
|
||||
authentication_methods = db.relationship(
|
||||
"AuthenticationMethod", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
sessions = db.relationship("Session", back_populates="user", cascade="all, delete-orphan")
|
||||
organization_memberships = db.relationship(
|
||||
"OrganizationMember",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="OrganizationMember.user_id",
|
||||
)
|
||||
audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of User."""
|
||||
return f"<User {self.email}>"
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert user to dictionary, excluding sensitive fields by default."""
|
||||
exclude = exclude or []
|
||||
# Always exclude password-related fields
|
||||
default_exclude = []
|
||||
all_exclude = list(set(default_exclude + exclude))
|
||||
return super().to_dict(exclude=all_exclude)
|
||||
|
||||
def has_password_auth(self):
|
||||
"""Check if user has password authentication enabled."""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return (
|
||||
AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.PASSWORD, deleted_at=None
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def get_organizations(self):
|
||||
"""Get all organizations the user is a member of."""
|
||||
return [membership.organization for membership in self.organization_memberships]
|
||||
|
||||
def has_totp_enabled(self) -> bool:
|
||||
"""Check if user has TOTP enabled and verified.
|
||||
|
||||
Returns:
|
||||
True if user has a verified TOTP authentication method, False otherwise.
|
||||
"""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return (
|
||||
AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id,
|
||||
method_type=AuthMethodType.TOTP,
|
||||
verified=True,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def get_totp_method(self):
|
||||
"""Get user's TOTP authentication method.
|
||||
|
||||
Returns:
|
||||
The AuthenticationMethod instance for TOTP or None if not found.
|
||||
|
||||
Note:
|
||||
Returns the most recently created TOTP method to handle cases where
|
||||
multiple enrollment attempts may exist.
|
||||
"""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None
|
||||
).order_by(AuthenticationMethod.created_at.desc()).first()
|
||||
|
||||
def has_webauthn_enabled(self) -> bool:
|
||||
"""Check if user has any WebAuthn passkey credentials.
|
||||
|
||||
Returns:
|
||||
True if user has at least one WebAuthn credential, False otherwise.
|
||||
"""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return (
|
||||
AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id,
|
||||
method_type=AuthMethodType.WEBAUTHN,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def get_webauthn_credentials(self):
|
||||
"""Get all WebAuthn credentials for the user.
|
||||
|
||||
Returns:
|
||||
List of AuthenticationMethod instances for WebAuthn, ordered by creation date.
|
||||
"""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
|
||||
).order_by(AuthenticationMethod.created_at.desc()).all()
|
||||
|
||||
def get_webauthn_credential_count(self) -> int:
|
||||
"""Get the count of WebAuthn credentials for the user.
|
||||
|
||||
Returns:
|
||||
Number of WebAuthn credentials.
|
||||
"""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
|
||||
).count()
|
||||
@@ -0,0 +1,34 @@
|
||||
"""Schemas package."""
|
||||
from gatehouse_app.schemas.user_schema import UserSchema, UserUpdateSchema, ChangePasswordSchema
|
||||
from gatehouse_app.schemas.auth_schema import (
|
||||
RegisterSchema,
|
||||
LoginSchema,
|
||||
RefreshTokenSchema,
|
||||
ForgotPasswordSchema,
|
||||
ResetPasswordSchema,
|
||||
)
|
||||
from gatehouse_app.schemas.organization_schema import (
|
||||
OrganizationSchema,
|
||||
OrganizationCreateSchema,
|
||||
OrganizationUpdateSchema,
|
||||
OrganizationMemberSchema,
|
||||
InviteMemberSchema,
|
||||
UpdateMemberRoleSchema,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"UserSchema",
|
||||
"UserUpdateSchema",
|
||||
"ChangePasswordSchema",
|
||||
"RegisterSchema",
|
||||
"LoginSchema",
|
||||
"RefreshTokenSchema",
|
||||
"ForgotPasswordSchema",
|
||||
"ResetPasswordSchema",
|
||||
"OrganizationSchema",
|
||||
"OrganizationCreateSchema",
|
||||
"OrganizationUpdateSchema",
|
||||
"OrganizationMemberSchema",
|
||||
"InviteMemberSchema",
|
||||
"UpdateMemberRoleSchema",
|
||||
]
|
||||
@@ -0,0 +1,88 @@
|
||||
"""Authentication schemas for validation."""
|
||||
from marshmallow import Schema, fields, validate, validates_schema, ValidationError
|
||||
|
||||
|
||||
class RegisterSchema(Schema):
|
||||
"""Schema for user registration."""
|
||||
|
||||
email = fields.Email(required=True)
|
||||
password = fields.Str(
|
||||
required=True,
|
||||
validate=validate.Length(min=8, max=128),
|
||||
)
|
||||
password_confirm = fields.Str(required=True)
|
||||
full_name = fields.Str(allow_none=True, validate=validate.Length(max=255))
|
||||
|
||||
@validates_schema
|
||||
def validate_passwords_match(self, data, **kwargs):
|
||||
"""Validate that passwords match."""
|
||||
if data.get("password") != data.get("password_confirm"):
|
||||
raise ValidationError("Passwords do not match", field_name="password_confirm")
|
||||
|
||||
|
||||
class LoginSchema(Schema):
|
||||
"""Schema for user login."""
|
||||
|
||||
email = fields.Email(required=True)
|
||||
password = fields.Str(required=True, validate=validate.Length(min=1))
|
||||
remember_me = fields.Bool(missing=False)
|
||||
|
||||
|
||||
class RefreshTokenSchema(Schema):
|
||||
"""Schema for token refresh."""
|
||||
|
||||
refresh_token = fields.Str(required=True)
|
||||
|
||||
|
||||
class ForgotPasswordSchema(Schema):
|
||||
"""Schema for forgot password request."""
|
||||
|
||||
email = fields.Email(required=True)
|
||||
|
||||
|
||||
class ResetPasswordSchema(Schema):
|
||||
"""Schema for password reset."""
|
||||
|
||||
token = fields.Str(required=True)
|
||||
password = fields.Str(
|
||||
required=True,
|
||||
validate=validate.Length(min=8, max=128),
|
||||
)
|
||||
password_confirm = fields.Str(required=True)
|
||||
|
||||
@validates_schema
|
||||
def validate_passwords_match(self, data, **kwargs):
|
||||
"""Validate that passwords match."""
|
||||
if data.get("password") != data.get("password_confirm"):
|
||||
raise ValidationError("Passwords do not match", field_name="password_confirm")
|
||||
|
||||
|
||||
class TOTPVerifyEnrollmentSchema(Schema):
|
||||
"""Schema for TOTP enrollment verification."""
|
||||
|
||||
code = fields.Str(
|
||||
required=True,
|
||||
validate=validate.Regexp(
|
||||
r"^\d{6}$",
|
||||
error="Code must be a 6-digit number",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TOTPVerifySchema(Schema):
|
||||
"""Schema for TOTP code verification during login."""
|
||||
|
||||
code = fields.Str(required=True)
|
||||
is_backup_code = fields.Bool(missing=False)
|
||||
|
||||
|
||||
class TOTPDisableSchema(Schema):
|
||||
"""Schema for disabling TOTP."""
|
||||
|
||||
password = fields.Str(required=True, validate=validate.Length(min=1))
|
||||
|
||||
|
||||
class TOTPRegenerateBackupCodesSchema(Schema):
|
||||
"""Schema for regenerating backup codes."""
|
||||
|
||||
password = fields.Str(required=True, validate=validate.Length(min=1))
|
||||
@@ -0,0 +1,62 @@
|
||||
"""Organization schemas for validation."""
|
||||
from marshmallow import Schema, fields, validate
|
||||
|
||||
|
||||
class OrganizationSchema(Schema):
|
||||
"""Schema for Organization model."""
|
||||
|
||||
id = fields.Str(dump_only=True)
|
||||
name = fields.Str(required=True, validate=validate.Length(min=1, max=255))
|
||||
slug = fields.Str(required=True, validate=validate.Length(min=1, max=255))
|
||||
description = fields.Str(allow_none=True)
|
||||
logo_url = fields.Url(allow_none=True, validate=validate.Length(max=512))
|
||||
is_active = fields.Bool(dump_only=True)
|
||||
created_at = fields.DateTime(dump_only=True)
|
||||
updated_at = fields.DateTime(dump_only=True)
|
||||
|
||||
|
||||
class OrganizationCreateSchema(Schema):
|
||||
"""Schema for creating an organization."""
|
||||
|
||||
name = fields.Str(required=True, validate=validate.Length(min=1, max=255))
|
||||
slug = fields.Str(required=True, validate=validate.Length(min=1, max=255))
|
||||
description = fields.Str(allow_none=True)
|
||||
logo_url = fields.Url(allow_none=True, validate=validate.Length(max=512))
|
||||
|
||||
|
||||
class OrganizationUpdateSchema(Schema):
|
||||
"""Schema for updating an organization."""
|
||||
|
||||
name = fields.Str(validate=validate.Length(min=1, max=255))
|
||||
description = fields.Str(allow_none=True)
|
||||
logo_url = fields.Url(allow_none=True, validate=validate.Length(max=512))
|
||||
|
||||
|
||||
class OrganizationMemberSchema(Schema):
|
||||
"""Schema for Organization Member."""
|
||||
|
||||
id = fields.Str(dump_only=True)
|
||||
user_id = fields.Str(dump_only=True)
|
||||
organization_id = fields.Str(dump_only=True)
|
||||
role = fields.Str(dump_only=True)
|
||||
joined_at = fields.DateTime(dump_only=True)
|
||||
created_at = fields.DateTime(dump_only=True)
|
||||
|
||||
|
||||
class InviteMemberSchema(Schema):
|
||||
"""Schema for inviting a member to an organization."""
|
||||
|
||||
email = fields.Email(required=True)
|
||||
role = fields.Str(
|
||||
required=True,
|
||||
validate=validate.OneOf(["owner", "admin", "member", "guest"])
|
||||
)
|
||||
|
||||
|
||||
class UpdateMemberRoleSchema(Schema):
|
||||
"""Schema for updating a member's role."""
|
||||
|
||||
role = fields.Str(
|
||||
required=True,
|
||||
validate=validate.OneOf(["owner", "admin", "member", "guest"])
|
||||
)
|
||||
@@ -0,0 +1,47 @@
|
||||
"""User schemas for validation and serialization."""
|
||||
from marshmallow import Schema, fields, validate, validates, ValidationError
|
||||
from gatehouse_app.utils.constants import UserStatus
|
||||
|
||||
|
||||
class UserSchema(Schema):
|
||||
"""Schema for User model."""
|
||||
|
||||
id = fields.Str(dump_only=True)
|
||||
email = fields.Email(required=True)
|
||||
email_verified = fields.Bool(dump_only=True)
|
||||
full_name = fields.Str(allow_none=True, validate=validate.Length(max=255))
|
||||
avatar_url = fields.Url(allow_none=True, validate=validate.Length(max=512))
|
||||
status = fields.Str(dump_only=True)
|
||||
last_login_at = fields.DateTime(dump_only=True)
|
||||
created_at = fields.DateTime(dump_only=True)
|
||||
updated_at = fields.DateTime(dump_only=True)
|
||||
|
||||
|
||||
class UserUpdateSchema(Schema):
|
||||
"""Schema for updating user profile."""
|
||||
|
||||
full_name = fields.Str(allow_none=True, validate=validate.Length(max=255))
|
||||
avatar_url = fields.Url(allow_none=True, validate=validate.Length(max=512))
|
||||
|
||||
|
||||
class ChangePasswordSchema(Schema):
|
||||
"""Schema for changing password."""
|
||||
|
||||
current_password = fields.Str(required=True, validate=validate.Length(min=1))
|
||||
new_password = fields.Str(
|
||||
required=True,
|
||||
validate=validate.Length(min=8, max=128),
|
||||
)
|
||||
new_password_confirm = fields.Str(required=True)
|
||||
|
||||
@validates("new_password")
|
||||
def validate_password_strength(self, value):
|
||||
"""Validate password strength."""
|
||||
if len(value) < 8:
|
||||
raise ValidationError("Password must be at least 8 characters long")
|
||||
if not any(char.isdigit() for char in value):
|
||||
raise ValidationError("Password must contain at least one digit")
|
||||
if not any(char.isupper() for char in value):
|
||||
raise ValidationError("Password must contain at least one uppercase letter")
|
||||
if not any(char.islower() for char in value):
|
||||
raise ValidationError("Password must contain at least one lowercase letter")
|
||||
@@ -0,0 +1,85 @@
|
||||
"""WebAuthn schemas for validation."""
|
||||
from marshmallow import Schema, fields, validate, validates_schema, ValidationError
|
||||
|
||||
|
||||
class WebAuthnRegistrationBeginSchema(Schema):
|
||||
"""Schema for beginning WebAuthn registration."""
|
||||
# No required fields - uses authenticated user
|
||||
pass
|
||||
|
||||
|
||||
class WebAuthnRegistrationCompleteSchema(Schema):
|
||||
"""Schema for completing WebAuthn registration."""
|
||||
|
||||
id = fields.Str(required=True)
|
||||
rawId = fields.Str(required=True)
|
||||
type = fields.Str(
|
||||
required=True,
|
||||
validate=validate.OneOf(["public-key"])
|
||||
)
|
||||
response = fields.Dict(required=True)
|
||||
transports = fields.List(
|
||||
fields.Str(validate=validate.OneOf(["usb", "nfc", "ble", "hybrid", "internal", "platform"])),
|
||||
load_default=[]
|
||||
)
|
||||
|
||||
@validates_schema
|
||||
def validate_response(self, data, **kwargs):
|
||||
"""Validate response contains required fields."""
|
||||
response = data.get("response", {})
|
||||
required_fields = ["attestationObject", "clientDataJSON"]
|
||||
for field in required_fields:
|
||||
if field not in response:
|
||||
raise ValidationError(
|
||||
f"Missing required field in response: {field}",
|
||||
field_name=f"response.{field}"
|
||||
)
|
||||
|
||||
|
||||
class WebAuthnLoginBeginSchema(Schema):
|
||||
"""Schema for beginning WebAuthn login."""
|
||||
|
||||
email = fields.Email(required=True)
|
||||
|
||||
|
||||
class WebAuthnLoginCompleteSchema(Schema):
|
||||
"""Schema for completing WebAuthn login."""
|
||||
|
||||
id = fields.Str(required=True)
|
||||
rawId = fields.Str(required=True)
|
||||
type = fields.Str(
|
||||
required=True,
|
||||
validate=validate.OneOf(["public-key"])
|
||||
)
|
||||
response = fields.Dict(required=True)
|
||||
clientExtensionResults = fields.Dict(load_default={})
|
||||
|
||||
@validates_schema
|
||||
def validate_response(self, data, **kwargs):
|
||||
"""Validate response contains required fields."""
|
||||
response = data.get("response", {})
|
||||
required_fields = ["authenticatorData", "clientDataJSON", "signature"]
|
||||
for field in required_fields:
|
||||
if field not in response:
|
||||
raise ValidationError(
|
||||
f"Missing required field in response: {field}",
|
||||
field_name=f"response.{field}"
|
||||
)
|
||||
|
||||
|
||||
class WebAuthnCredentialRenameSchema(Schema):
|
||||
"""Schema for renaming a WebAuthn credential."""
|
||||
|
||||
name = fields.Str(
|
||||
required=True,
|
||||
validate=validate.Length(min=1, max=100)
|
||||
)
|
||||
|
||||
|
||||
class WebAuthnCredentialDeleteSchema(Schema):
|
||||
"""Schema for deleting a WebAuthn credential."""
|
||||
|
||||
password = fields.Str(
|
||||
required=True,
|
||||
validate=validate.Length(min=1)
|
||||
)
|
||||
@@ -0,0 +1,25 @@
|
||||
"""Services package."""
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
from gatehouse_app.services.user_service import UserService
|
||||
from gatehouse_app.services.organization_service import OrganizationService
|
||||
from gatehouse_app.services.session_service import SessionService
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
from gatehouse_app.services.oidc_service import OIDCService, OIDCError
|
||||
from gatehouse_app.services.oidc_jwks_service import OIDCJWKSService
|
||||
from gatehouse_app.services.oidc_token_service import OIDCTokenService
|
||||
from gatehouse_app.services.oidc_session_service import OIDCSessionService
|
||||
from gatehouse_app.services.oidc_audit_service import OIDCAuditService
|
||||
|
||||
__all__ = [
|
||||
"AuthService",
|
||||
"UserService",
|
||||
"OrganizationService",
|
||||
"SessionService",
|
||||
"AuditService",
|
||||
"OIDCService",
|
||||
"OIDCError",
|
||||
"OIDCJWKSService",
|
||||
"OIDCTokenService",
|
||||
"OIDCSessionService",
|
||||
"OIDCAuditService",
|
||||
]
|
||||
@@ -0,0 +1,107 @@
|
||||
"""Audit service."""
|
||||
from flask import request, g
|
||||
from gatehouse_app.models.audit_log import AuditLog
|
||||
from gatehouse_app.utils.constants import AuditAction
|
||||
|
||||
|
||||
class AuditService:
|
||||
"""Service for audit logging."""
|
||||
|
||||
@staticmethod
|
||||
def log_action(
|
||||
action,
|
||||
user_id=None,
|
||||
organization_id=None,
|
||||
resource_type=None,
|
||||
resource_id=None,
|
||||
metadata=None,
|
||||
description=None,
|
||||
success=True,
|
||||
error_message=None,
|
||||
):
|
||||
"""
|
||||
Create an audit log entry.
|
||||
|
||||
Args:
|
||||
action: AuditAction enum value
|
||||
user_id: ID of user performing the action
|
||||
organization_id: ID of related organization
|
||||
resource_type: Type of resource being acted upon
|
||||
resource_id: ID of resource being acted upon
|
||||
metadata: Additional metadata dictionary
|
||||
description: Human-readable description
|
||||
success: Whether the action succeeded
|
||||
error_message: Error message if action failed
|
||||
|
||||
Returns:
|
||||
AuditLog instance
|
||||
"""
|
||||
# Get request details if available
|
||||
ip_address = None
|
||||
user_agent = None
|
||||
request_id = None
|
||||
|
||||
try:
|
||||
if request:
|
||||
ip_address = request.remote_addr
|
||||
user_agent = request.headers.get("User-Agent")
|
||||
request_id = g.get("request_id")
|
||||
except RuntimeError:
|
||||
# No request context
|
||||
pass
|
||||
|
||||
log_entry = AuditLog(
|
||||
action=action,
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
metadata=metadata,
|
||||
description=description,
|
||||
success=success,
|
||||
error_message=error_message,
|
||||
)
|
||||
log_entry.save()
|
||||
|
||||
return log_entry
|
||||
|
||||
@staticmethod
|
||||
def get_user_activity(user_id, limit=50):
|
||||
"""
|
||||
Get recent activity for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
limit: Maximum number of records to return
|
||||
|
||||
Returns:
|
||||
List of AuditLog instances
|
||||
"""
|
||||
return (
|
||||
AuditLog.query.filter_by(user_id=user_id)
|
||||
.order_by(AuditLog.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_organization_activity(organization_id, limit=50):
|
||||
"""
|
||||
Get recent activity for an organization.
|
||||
|
||||
Args:
|
||||
organization_id: Organization ID
|
||||
limit: Maximum number of records to return
|
||||
|
||||
Returns:
|
||||
List of AuditLog instances
|
||||
"""
|
||||
return (
|
||||
AuditLog.query.filter_by(organization_id=organization_id)
|
||||
.order_by(AuditLog.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
@@ -0,0 +1,559 @@
|
||||
"""Authentication service."""
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from flask import request, g, current_app
|
||||
from gatehouse_app.extensions import db, bcrypt
|
||||
from gatehouse_app.models.user import User
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.models.session import Session
|
||||
from gatehouse_app.utils.constants import AuthMethodType, SessionStatus, UserStatus, AuditAction
|
||||
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError, AccountSuspendedError, AccountInactiveError
|
||||
from gatehouse_app.exceptions.validation_exceptions import EmailAlreadyExistsError
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
from gatehouse_app.services.totp_service import TOTPService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""Service for authentication operations."""
|
||||
|
||||
@staticmethod
|
||||
def register_user(email, password, full_name=None):
|
||||
"""
|
||||
Register a new user with email/password.
|
||||
|
||||
Args:
|
||||
email: User email address
|
||||
password: Plain text password
|
||||
full_name: Optional full name
|
||||
|
||||
Returns:
|
||||
User instance
|
||||
|
||||
Raises:
|
||||
EmailAlreadyExistsError: If email is already registered
|
||||
"""
|
||||
# Check if email already exists
|
||||
existing_user = User.query.filter_by(email=email.lower()).first()
|
||||
if existing_user and existing_user.deleted_at is None:
|
||||
raise EmailAlreadyExistsError()
|
||||
|
||||
# Create user
|
||||
user = User(
|
||||
email=email.lower(),
|
||||
full_name=full_name,
|
||||
status=UserStatus.ACTIVE,
|
||||
)
|
||||
user.save()
|
||||
|
||||
# Create password authentication method
|
||||
password_hash = bcrypt.generate_password_hash(password).decode("utf-8")
|
||||
auth_method = AuthenticationMethod(
|
||||
user_id=user.id,
|
||||
method_type=AuthMethodType.PASSWORD,
|
||||
password_hash=password_hash,
|
||||
is_primary=True,
|
||||
verified=True,
|
||||
)
|
||||
auth_method.save()
|
||||
|
||||
# Log the registration
|
||||
AuditService.log_action(
|
||||
action=AuditAction.USER_REGISTER,
|
||||
user_id=user.id,
|
||||
resource_type="user",
|
||||
resource_id=user.id,
|
||||
description=f"User registered with email: {email}",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def authenticate(email, password):
|
||||
"""
|
||||
Authenticate user with email/password.
|
||||
|
||||
Args:
|
||||
email: User email
|
||||
password: Plain text password
|
||||
|
||||
Returns:
|
||||
User instance if authentication succeeds
|
||||
|
||||
Raises:
|
||||
InvalidCredentialsError: If credentials are invalid
|
||||
AccountSuspendedError: If account is suspended
|
||||
AccountInactiveError: If account is inactive
|
||||
"""
|
||||
# Find user
|
||||
user = User.query.filter_by(email=email.lower(), deleted_at=None).first()
|
||||
|
||||
# Development-only debug logging for user existence check
|
||||
if current_app.config.get('ENV') == 'development':
|
||||
logger.debug(f"[Auth] User lookup: email={email}, exists={user is not None}")
|
||||
|
||||
if not user:
|
||||
raise InvalidCredentialsError()
|
||||
|
||||
# Check account status
|
||||
if current_app.config.get('ENV') == 'development':
|
||||
logger.debug(f"[Auth] Account status: user_id={user.id}, status={user.status}")
|
||||
|
||||
if user.status == UserStatus.SUSPENDED:
|
||||
raise AccountSuspendedError()
|
||||
if user.status == UserStatus.INACTIVE:
|
||||
raise AccountInactiveError()
|
||||
|
||||
# Find password auth method
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
user_id=user.id,
|
||||
method_type=AuthMethodType.PASSWORD,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
# Development-only debug logging for auth method lookup
|
||||
if current_app.config.get('ENV') == 'development':
|
||||
logger.debug(f"[Auth] Auth method lookup: user_id={user.id}, has_password_auth={auth_method is not None and auth_method.password_hash is not None}")
|
||||
|
||||
if not auth_method or not auth_method.password_hash:
|
||||
raise InvalidCredentialsError()
|
||||
|
||||
# Verify password
|
||||
password_valid = bcrypt.check_password_hash(auth_method.password_hash, password)
|
||||
|
||||
# Development-only debug logging for password validation (without logging actual password)
|
||||
if current_app.config.get('ENV') == 'development':
|
||||
logger.debug(f"[Auth] Password validation: user_id={user.id}, valid={password_valid}")
|
||||
|
||||
if not password_valid:
|
||||
raise InvalidCredentialsError()
|
||||
|
||||
# Update last login
|
||||
user.last_login_at = datetime.now(timezone.utc)
|
||||
user.last_login_ip = request.remote_addr
|
||||
auth_method.last_used_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def create_session(user, duration_seconds=86400):
|
||||
"""
|
||||
Create a new session for the user.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
duration_seconds: Session duration in seconds
|
||||
|
||||
Returns:
|
||||
Session instance
|
||||
"""
|
||||
# Generate session token
|
||||
token = secrets.token_urlsafe(32)
|
||||
|
||||
# Create session
|
||||
session = Session(
|
||||
user_id=user.id,
|
||||
token=token,
|
||||
status=SessionStatus.ACTIVE,
|
||||
ip_address=request.remote_addr,
|
||||
user_agent=request.headers.get("User-Agent"),
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(seconds=duration_seconds),
|
||||
last_activity_at=datetime.now(timezone.utc),
|
||||
)
|
||||
session.save()
|
||||
|
||||
# Log session creation
|
||||
AuditService.log_action(
|
||||
action=AuditAction.SESSION_CREATE,
|
||||
user_id=user.id,
|
||||
resource_type="session",
|
||||
resource_id=session.id,
|
||||
description="User session created",
|
||||
)
|
||||
|
||||
return session
|
||||
|
||||
@staticmethod
|
||||
def change_password(user, current_password, new_password):
|
||||
"""
|
||||
Change user password.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
current_password: Current password
|
||||
new_password: New password
|
||||
|
||||
Raises:
|
||||
InvalidCredentialsError: If current password is incorrect
|
||||
"""
|
||||
# Find password auth method
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
user_id=user.id,
|
||||
method_type=AuthMethodType.PASSWORD,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
if not auth_method or not auth_method.password_hash:
|
||||
raise InvalidCredentialsError("No password authentication method found")
|
||||
|
||||
# Verify current password
|
||||
if not bcrypt.check_password_hash(auth_method.password_hash, current_password):
|
||||
raise InvalidCredentialsError("Current password is incorrect")
|
||||
|
||||
# Update password
|
||||
auth_method.password_hash = bcrypt.generate_password_hash(new_password).decode("utf-8")
|
||||
db.session.commit()
|
||||
|
||||
# Log password change
|
||||
AuditService.log_action(
|
||||
action=AuditAction.PASSWORD_CHANGE,
|
||||
user_id=user.id,
|
||||
description="User changed password",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def revoke_session(session_id, reason=None):
|
||||
"""
|
||||
Revoke a session.
|
||||
|
||||
Args:
|
||||
session_id: Session ID to revoke
|
||||
reason: Optional revocation reason
|
||||
"""
|
||||
session = Session.query.get(session_id)
|
||||
if session:
|
||||
session.revoke(reason=reason)
|
||||
|
||||
# Log session revocation
|
||||
AuditService.log_action(
|
||||
action=AuditAction.SESSION_REVOKE,
|
||||
user_id=session.user_id,
|
||||
resource_type="session",
|
||||
resource_id=session.id,
|
||||
description=f"Session revoked: {reason or 'User logout'}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def enroll_totp(user: User) -> dict:
|
||||
"""
|
||||
Initiate TOTP enrollment for a user.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- secret: TOTP secret (base32 encoded)
|
||||
- provisioning_uri: otpauth:// URI for QR code
|
||||
- qr_code: Base64 encoded QR code as data URI
|
||||
- backup_codes: List of plain text backup codes
|
||||
|
||||
Raises:
|
||||
ConflictError: If user already has TOTP enabled
|
||||
"""
|
||||
from gatehouse_app.exceptions.validation_exceptions import ConflictError
|
||||
|
||||
# Check if user already has TOTP enabled
|
||||
if user.has_totp_enabled():
|
||||
raise ConflictError("TOTP is already enabled for this account")
|
||||
|
||||
# Clean up any existing unverified TOTP enrollment attempts
|
||||
# Use hard delete for unverified methods since they're incomplete enrollment attempts
|
||||
existing_totp_method = user.get_totp_method()
|
||||
if existing_totp_method and not existing_totp_method.verified:
|
||||
logger.debug(f"Removing existing unverified TOTP method for user {user.id}")
|
||||
db.session.delete(existing_totp_method) # Hard delete - unverified methods are temporary
|
||||
db.session.commit() # Commit to ensure deletion before creating new record
|
||||
|
||||
# Generate TOTP secret
|
||||
secret = TOTPService.generate_secret()
|
||||
|
||||
# Generate provisioning URI
|
||||
provisioning_uri = TOTPService.generate_provisioning_uri(
|
||||
user_email=user.email,
|
||||
secret=secret,
|
||||
issuer="Gatehouse",
|
||||
)
|
||||
|
||||
# Generate QR code data URI
|
||||
qr_code = TOTPService.generate_qr_code_data_uri(provisioning_uri)
|
||||
|
||||
# Generate backup codes
|
||||
backup_codes, hashed_backup_codes = TOTPService.generate_backup_codes()
|
||||
|
||||
# Create unverified TOTP authentication method
|
||||
auth_method = AuthenticationMethod(
|
||||
user_id=user.id,
|
||||
method_type=AuthMethodType.TOTP,
|
||||
verified=False,
|
||||
is_primary=False,
|
||||
)
|
||||
auth_method.save()
|
||||
|
||||
# Store TOTP data in provider_data (since totp_secret field is commented out)
|
||||
auth_method.provider_data = {
|
||||
"secret": secret,
|
||||
"backup_codes": hashed_backup_codes,
|
||||
}
|
||||
db.session.commit()
|
||||
|
||||
# Log TOTP enrollment initiation
|
||||
AuditService.log_action(
|
||||
action=AuditAction.TOTP_ENROLL_INITIATED,
|
||||
user_id=user.id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=auth_method.id,
|
||||
description="TOTP enrollment initiated",
|
||||
)
|
||||
|
||||
return {
|
||||
"secret": secret,
|
||||
"provisioning_uri": provisioning_uri,
|
||||
"qr_code": qr_code,
|
||||
"backup_codes": backup_codes,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def verify_totp_enrollment(user: User, code: str) -> bool:
|
||||
"""
|
||||
Complete TOTP enrollment by verifying the first TOTP code.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
code: 6-digit TOTP code from authenticator app
|
||||
|
||||
Returns:
|
||||
True if verification successful
|
||||
|
||||
Raises:
|
||||
InvalidCredentialsError: If code is invalid or TOTP method not found
|
||||
"""
|
||||
# Get user's TOTP authentication method
|
||||
auth_method = user.get_totp_method()
|
||||
if not auth_method:
|
||||
raise InvalidCredentialsError("TOTP enrollment not found")
|
||||
|
||||
# Get secret from provider_data
|
||||
secret = auth_method.provider_data.get("secret") if auth_method.provider_data else None
|
||||
if not secret:
|
||||
raise InvalidCredentialsError("TOTP secret not found")
|
||||
|
||||
# Verify the code
|
||||
if not TOTPService.verify_code(secret, code):
|
||||
raise InvalidCredentialsError("Invalid TOTP code")
|
||||
|
||||
# Mark TOTP as verified
|
||||
auth_method.verified = True
|
||||
auth_method.totp_verified_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
# Log TOTP enrollment completion
|
||||
AuditService.log_action(
|
||||
action=AuditAction.TOTP_ENROLL_COMPLETED,
|
||||
user_id=user.id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=auth_method.id,
|
||||
description="TOTP enrollment completed",
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def disable_totp(user: User, password: str) -> bool:
|
||||
"""
|
||||
Disable TOTP for a user.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
password: User's current password for verification
|
||||
|
||||
Returns:
|
||||
True if TOTP disabled successfully
|
||||
|
||||
Raises:
|
||||
InvalidCredentialsError: If password is invalid or TOTP method not found
|
||||
"""
|
||||
# Verify user's password
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
user_id=user.id,
|
||||
method_type=AuthMethodType.PASSWORD,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
if not auth_method or not auth_method.password_hash:
|
||||
raise InvalidCredentialsError("No password authentication method found")
|
||||
|
||||
if not bcrypt.check_password_hash(auth_method.password_hash, password):
|
||||
raise InvalidCredentialsError("Invalid password")
|
||||
|
||||
# Get user's TOTP authentication method
|
||||
totp_method = user.get_totp_method()
|
||||
if not totp_method:
|
||||
raise InvalidCredentialsError("TOTP is not enabled for this account")
|
||||
|
||||
# Soft-delete the TOTP authentication method
|
||||
totp_method.delete(soft=True)
|
||||
|
||||
# Log TOTP disabled
|
||||
AuditService.log_action(
|
||||
action=AuditAction.TOTP_DISABLED,
|
||||
user_id=user.id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=totp_method.id,
|
||||
description="TOTP disabled",
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def authenticate_with_totp(user: User, code: str, is_backup_code: bool = False) -> bool:
|
||||
"""
|
||||
Verify TOTP code during login.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
code: 6-digit TOTP code or backup code
|
||||
is_backup_code: True if code is a backup code, False if TOTP code
|
||||
|
||||
Returns:
|
||||
True if code is valid
|
||||
|
||||
Raises:
|
||||
InvalidCredentialsError: If code is invalid or TOTP method not found
|
||||
"""
|
||||
# Get user's TOTP authentication method
|
||||
auth_method = user.get_totp_method()
|
||||
if not auth_method:
|
||||
raise InvalidCredentialsError("TOTP is not enabled for this account")
|
||||
|
||||
if is_backup_code:
|
||||
# Verify backup code
|
||||
backup_codes = (
|
||||
auth_method.provider_data.get("backup_codes")
|
||||
if auth_method.provider_data
|
||||
else []
|
||||
)
|
||||
is_valid, remaining_codes = TOTPService.verify_backup_code(backup_codes, code)
|
||||
|
||||
if is_valid:
|
||||
# Update remaining backup codes
|
||||
auth_method.provider_data = {
|
||||
"secret": auth_method.provider_data.get("secret"),
|
||||
"backup_codes": remaining_codes,
|
||||
}
|
||||
auth_method.last_used_at = datetime.now(timezone.utc)
|
||||
db.session.add(auth_method)
|
||||
db.session.commit()
|
||||
logger.debug(f"[BACKUP CODE] Updated provider_data: {auth_method.provider_data}")
|
||||
|
||||
# Log backup code usage
|
||||
AuditService.log_action(
|
||||
action=AuditAction.TOTP_BACKUP_CODE_USED,
|
||||
user_id=user.id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=auth_method.id,
|
||||
description="Backup code used for authentication",
|
||||
)
|
||||
else:
|
||||
# Log failed verification
|
||||
AuditService.log_action(
|
||||
action=AuditAction.TOTP_VERIFY_FAILED,
|
||||
user_id=user.id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=auth_method.id,
|
||||
description="Invalid backup code provided",
|
||||
)
|
||||
raise InvalidCredentialsError("Invalid backup code")
|
||||
else:
|
||||
# Verify TOTP code
|
||||
secret = (
|
||||
auth_method.provider_data.get("secret")
|
||||
if auth_method.provider_data
|
||||
else None
|
||||
)
|
||||
if not secret:
|
||||
raise InvalidCredentialsError("TOTP secret not found")
|
||||
|
||||
is_valid = TOTPService.verify_code(secret, code)
|
||||
|
||||
if is_valid:
|
||||
auth_method.last_used_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
# Log successful verification
|
||||
AuditService.log_action(
|
||||
action=AuditAction.TOTP_VERIFY_SUCCESS,
|
||||
user_id=user.id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=auth_method.id,
|
||||
description="TOTP code verified successfully",
|
||||
)
|
||||
else:
|
||||
# Log failed verification
|
||||
AuditService.log_action(
|
||||
action=AuditAction.TOTP_VERIFY_FAILED,
|
||||
user_id=user.id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=auth_method.id,
|
||||
description="Invalid TOTP code provided",
|
||||
)
|
||||
raise InvalidCredentialsError("Invalid TOTP code")
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def regenerate_totp_backup_codes(user: User, password: str) -> list[str]:
|
||||
"""
|
||||
Generate new backup codes for TOTP.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
password: User's current password for verification
|
||||
|
||||
Returns:
|
||||
List of new plain text backup codes
|
||||
|
||||
Raises:
|
||||
InvalidCredentialsError: If password is invalid or TOTP method not found
|
||||
"""
|
||||
# Verify user's password
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
user_id=user.id,
|
||||
method_type=AuthMethodType.PASSWORD,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
if not auth_method or not auth_method.password_hash:
|
||||
raise InvalidCredentialsError("No password authentication method found")
|
||||
|
||||
if not bcrypt.check_password_hash(auth_method.password_hash, password):
|
||||
raise InvalidCredentialsError("Invalid password")
|
||||
|
||||
# Get user's TOTP authentication method
|
||||
totp_method = user.get_totp_method()
|
||||
if not totp_method:
|
||||
raise InvalidCredentialsError("TOTP is not enabled for this account")
|
||||
|
||||
# Generate new backup codes
|
||||
backup_codes, hashed_backup_codes = TOTPService.generate_backup_codes()
|
||||
|
||||
# Update the authentication method with new backup codes
|
||||
totp_method.provider_data = {
|
||||
"secret": totp_method.provider_data.get("secret"),
|
||||
"backup_codes": hashed_backup_codes,
|
||||
}
|
||||
db.session.commit()
|
||||
|
||||
# Log backup codes regeneration
|
||||
AuditService.log_action(
|
||||
action=AuditAction.TOTP_BACKUP_CODES_REGENERATED,
|
||||
user_id=user.id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=totp_method.id,
|
||||
description="TOTP backup codes regenerated",
|
||||
)
|
||||
|
||||
return backup_codes
|
||||
@@ -0,0 +1,408 @@
|
||||
"""OIDC Audit Service for comprehensive OIDC event logging."""
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from flask import g
|
||||
|
||||
from gatehouse_app.models import OIDCAuditLog, OIDCClient, User
|
||||
from gatehouse_app.exceptions.validation_exceptions import NotFoundError
|
||||
|
||||
|
||||
class OIDCAuditService:
|
||||
"""Service for OIDC-specific audit logging.
|
||||
|
||||
This service provides methods to log all OIDC-related events including:
|
||||
- Authorization requests and responses
|
||||
- Token issuance and refresh
|
||||
- Token revocation
|
||||
- UserInfo endpoint access
|
||||
- Authentication failures
|
||||
"""
|
||||
|
||||
# Event type constants
|
||||
EVENT_AUTHORIZATION_REQUEST = "authorization_request"
|
||||
EVENT_AUTHORIZATION_RESPONSE = "authorization_response"
|
||||
EVENT_TOKEN_ISSUE = "token_issue"
|
||||
EVENT_TOKEN_REFRESH = "token_refresh"
|
||||
EVENT_TOKEN_REVOCATION = "token_revocation"
|
||||
EVENT_TOKEN_INTROSPECTION = "token_introspection"
|
||||
EVENT_USERINFO_ACCESS = "userinfo_access"
|
||||
EVENT_AUTHENTICATION_FAILURE = "authentication_failure"
|
||||
EVENT_AUTHORIZATION_FAILURE = "authorization_failure"
|
||||
EVENT_JWKS_ACCESS = "jwks_access"
|
||||
EVENT_REGISTRATION = "client_registration"
|
||||
|
||||
@classmethod
|
||||
def _get_request_context(cls) -> Dict:
|
||||
"""Extract request context for logging.
|
||||
|
||||
Returns:
|
||||
Dictionary with IP, user_agent, and request_id
|
||||
"""
|
||||
from flask import request
|
||||
|
||||
return {
|
||||
"ip_address": request.remote_addr if request else None,
|
||||
"user_agent": request.headers.get("User-Agent") if request else None,
|
||||
"request_id": g.get("request_id"),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def log_event(
|
||||
cls,
|
||||
event_type: str,
|
||||
client_id: str = None,
|
||||
user_id: str = None,
|
||||
success: bool = True,
|
||||
error_code: str = None,
|
||||
error_description: str = None,
|
||||
metadata: Dict = None
|
||||
) -> OIDCAuditLog:
|
||||
"""Log a generic OIDC event.
|
||||
|
||||
Args:
|
||||
event_type: Type of event
|
||||
client_id: OIDC client ID
|
||||
user_id: User ID
|
||||
success: Whether the event was successful
|
||||
error_code: Error code if failed
|
||||
error_description: Error description if failed
|
||||
metadata: Additional event metadata
|
||||
|
||||
Returns:
|
||||
OIDCAuditLog instance
|
||||
"""
|
||||
context = cls._get_request_context()
|
||||
|
||||
log = OIDCAuditLog.log_event(
|
||||
event_type=event_type,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
ip_address=context["ip_address"],
|
||||
user_agent=context["user_agent"],
|
||||
request_id=context["request_id"],
|
||||
event_metadata=metadata,
|
||||
)
|
||||
|
||||
return log
|
||||
|
||||
@classmethod
|
||||
def log_authorization_event(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str = None,
|
||||
success: bool = True,
|
||||
error_code: str = None,
|
||||
error_description: str = None,
|
||||
redirect_uri: str = None,
|
||||
scope: list = None,
|
||||
response_type: str = None
|
||||
) -> OIDCAuditLog:
|
||||
"""Log an authorization event.
|
||||
|
||||
Args:
|
||||
client_id: OIDC client ID
|
||||
user_id: User ID (if authenticated)
|
||||
success: Whether authorization was successful
|
||||
error_code: Error code if failed
|
||||
error_description: Error description if failed
|
||||
redirect_uri: Redirect URI from request
|
||||
scope: Requested scopes
|
||||
response_type: Response type (e.g., "code")
|
||||
|
||||
Returns:
|
||||
OIDCAuditLog instance
|
||||
"""
|
||||
metadata = {
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": scope,
|
||||
"response_type": response_type,
|
||||
}
|
||||
metadata = {k: v for k, v in metadata.items() if v is not None}
|
||||
|
||||
return cls.log_event(
|
||||
event_type=cls.EVENT_AUTHORIZATION_REQUEST,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_token_event(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str = None,
|
||||
token_type: str = "access_token",
|
||||
success: bool = True,
|
||||
error_code: str = None,
|
||||
error_description: str = None,
|
||||
grant_type: str = None,
|
||||
scopes: list = None
|
||||
) -> OIDCAuditLog:
|
||||
"""Log a token issuance or refresh event.
|
||||
|
||||
Args:
|
||||
client_id: OIDC client ID
|
||||
user_id: User ID
|
||||
token_type: Type of token issued
|
||||
success: Whether token issuance was successful
|
||||
error_code: Error code if failed
|
||||
error_description: Error description if failed
|
||||
grant_type: Grant type used (e.g., "authorization_code", "refresh_token")
|
||||
scopes: Scopes included in the token
|
||||
|
||||
Returns:
|
||||
OIDCAuditLog instance
|
||||
"""
|
||||
metadata = {
|
||||
"token_type": token_type,
|
||||
"grant_type": grant_type,
|
||||
"scopes": scopes,
|
||||
}
|
||||
metadata = {k: v for k, v in metadata.items() if v is not None}
|
||||
|
||||
return cls.log_event(
|
||||
event_type=cls.EVENT_TOKEN_ISSUE if token_type else cls.EVENT_TOKEN_REFRESH,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_userinfo_event(
|
||||
cls,
|
||||
access_token: str = None,
|
||||
user_id: str = None,
|
||||
client_id: str = None,
|
||||
success: bool = True,
|
||||
error_code: str = None,
|
||||
error_description: str = None,
|
||||
scopes_claimed: list = None
|
||||
) -> OIDCAuditLog:
|
||||
"""Log a UserInfo endpoint access event.
|
||||
|
||||
Args:
|
||||
access_token: Access token used (masked)
|
||||
user_id: User ID returned
|
||||
client_id: Client ID making the request
|
||||
success: Whether access was successful
|
||||
error_code: Error code if failed
|
||||
error_description: Error description if failed
|
||||
scopes_claimed: Scopes claimed in the request
|
||||
|
||||
Returns:
|
||||
OIDCAuditLog instance
|
||||
"""
|
||||
# Mask the access token for security
|
||||
masked_token = None
|
||||
if access_token:
|
||||
masked_token = access_token[:8] + "..." + access_token[-4:] if len(access_token) > 12 else "***"
|
||||
|
||||
metadata = {
|
||||
"token_prefix": masked_token,
|
||||
"scopes_claimed": scopes_claimed,
|
||||
}
|
||||
metadata = {k: v for k, v in metadata.items() if v is not None}
|
||||
|
||||
return cls.log_event(
|
||||
event_type=cls.EVENT_USERINFO_ACCESS,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_token_revocation_event(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str = None,
|
||||
token_type: str = "access_token",
|
||||
reason: str = None,
|
||||
success: bool = True,
|
||||
error_code: str = None,
|
||||
error_description: str = None
|
||||
) -> OIDCAuditLog:
|
||||
"""Log a token revocation event.
|
||||
|
||||
Args:
|
||||
client_id: OIDC client ID
|
||||
user_id: User ID
|
||||
token_type: Type of token being revoked
|
||||
reason: Revocation reason
|
||||
success: Whether revocation was successful
|
||||
error_code: Error code if failed
|
||||
error_description: Error description if failed
|
||||
|
||||
Returns:
|
||||
OIDCAuditLog instance
|
||||
"""
|
||||
metadata = {
|
||||
"token_type": token_type,
|
||||
"reason": reason,
|
||||
}
|
||||
metadata = {k: v for k, v in metadata.items() if v is not None}
|
||||
|
||||
return cls.log_event(
|
||||
event_type=cls.EVENT_TOKEN_REVOCATION,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=success,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_authentication_failure(
|
||||
cls,
|
||||
client_id: str = None,
|
||||
error_code: str = "authentication_failed",
|
||||
error_description: str = "Authentication failed",
|
||||
user_id: str = None
|
||||
) -> OIDCAuditLog:
|
||||
"""Log an authentication failure event.
|
||||
|
||||
Args:
|
||||
client_id: OIDC client ID
|
||||
error_code: Error code
|
||||
error_description: Error description
|
||||
user_id: User ID if known
|
||||
|
||||
Returns:
|
||||
OIDCAuditLog instance
|
||||
"""
|
||||
return cls.log_event(
|
||||
event_type=cls.EVENT_AUTHENTICATION_FAILURE,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
success=False,
|
||||
error_code=error_code,
|
||||
error_description=error_description,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_events_for_user(
|
||||
cls,
|
||||
user_id: str,
|
||||
limit: int = 100,
|
||||
include_deleted: bool = False
|
||||
) -> List[OIDCAuditLog]:
|
||||
"""Get audit events for a specific user.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
limit: Maximum number of events to return
|
||||
include_deleted: Include soft-deleted events
|
||||
|
||||
Returns:
|
||||
List of OIDCAuditLog instances
|
||||
"""
|
||||
return OIDCAuditLog.get_events_for_user(user_id, limit)
|
||||
|
||||
@classmethod
|
||||
def get_events_for_client(
|
||||
cls,
|
||||
client_id: str,
|
||||
limit: int = 100
|
||||
) -> List[OIDCAuditLog]:
|
||||
"""Get audit events for a specific client.
|
||||
|
||||
Args:
|
||||
client_id: Client ID
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of OIDCAuditLog instances
|
||||
"""
|
||||
return OIDCAuditLog.get_events_for_client(client_id, limit)
|
||||
|
||||
@classmethod
|
||||
def get_failed_events(
|
||||
cls,
|
||||
client_id: str = None,
|
||||
user_id: str = None,
|
||||
start_date: datetime = None,
|
||||
end_date: datetime = None,
|
||||
limit: int = 100
|
||||
) -> List[OIDCAuditLog]:
|
||||
"""Get failed audit events for analysis.
|
||||
|
||||
Args:
|
||||
client_id: Optional client ID filter
|
||||
user_id: Optional user ID filter
|
||||
start_date: Optional start date filter
|
||||
end_date: Optional end date filter
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of failed OIDCAuditLog instances
|
||||
"""
|
||||
return OIDCAuditLog.get_failed_events(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_event_summary(
|
||||
cls,
|
||||
client_id: str = None,
|
||||
days: int = 30
|
||||
) -> Dict:
|
||||
"""Get a summary of audit events.
|
||||
|
||||
Args:
|
||||
client_id: Optional client ID filter
|
||||
days: Number of days to look back
|
||||
|
||||
Returns:
|
||||
Summary dictionary with event counts
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
start_date = datetime.now(timezone.utc) - timedelta(days=days)
|
||||
|
||||
query = OIDCAuditLog.query.filter(
|
||||
OIDCAuditLog.created_at >= start_date
|
||||
)
|
||||
|
||||
if client_id:
|
||||
query = query.filter_by(client_id=client_id)
|
||||
|
||||
events = query.all()
|
||||
|
||||
# Count by event type
|
||||
event_counts = {}
|
||||
success_count = 0
|
||||
failure_count = 0
|
||||
|
||||
for event in events:
|
||||
event_type = event.event_type
|
||||
event_counts[event_type] = event_counts.get(event_type, 0) + 1
|
||||
|
||||
if event.success:
|
||||
success_count += 1
|
||||
else:
|
||||
failure_count += 1
|
||||
|
||||
return {
|
||||
"total_events": len(events),
|
||||
"successful_events": success_count,
|
||||
"failed_events": failure_count,
|
||||
"by_event_type": event_counts,
|
||||
"period_days": days,
|
||||
}
|
||||
@@ -0,0 +1,418 @@
|
||||
"""OIDC JWKS Service for key management and rotation."""
|
||||
import uuid
|
||||
import json
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.oidc_jwks_key import OidcJwksKey
|
||||
|
||||
|
||||
class JWKSKey:
|
||||
"""Represents a JWKS key entry."""
|
||||
|
||||
def __init__(self, kid: str, private_key: str, public_key: str,
|
||||
algorithm: str = "RS256", created_at: datetime = None,
|
||||
expires_at: datetime = None, is_active: bool = True):
|
||||
self.kid = kid
|
||||
self.private_key = private_key
|
||||
self.public_key = public_key
|
||||
self.algorithm = algorithm
|
||||
self.created_at = created_at or datetime.now(timezone.utc)
|
||||
self.expires_at = expires_at or datetime.now(timezone.utc) + timedelta(days=365)
|
||||
self.is_active = is_active
|
||||
|
||||
def to_jwk(self) -> Dict:
|
||||
"""Convert to JWK format for JWKS endpoint."""
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa, padding
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
# Import cryptography here to avoid issues if not installed
|
||||
try:
|
||||
# Get public key from PEM
|
||||
public_key = serialization.load_pem_public_key(
|
||||
self.public_key.encode(), backend=default_backend()
|
||||
)
|
||||
|
||||
# Get RSA parameters
|
||||
public_numbers = public_key.public_numbers()
|
||||
|
||||
return {
|
||||
"kty": "RSA",
|
||||
"kid": self.kid,
|
||||
"use": "sig",
|
||||
"alg": self.algorithm,
|
||||
"n": _base64url_encode(public_numbers.n),
|
||||
"e": _base64url_encode(public_numbers.e),
|
||||
}
|
||||
except ImportError:
|
||||
# Fallback for when cryptography is not installed
|
||||
return {
|
||||
"kty": "RSA",
|
||||
"kid": self.kid,
|
||||
"use": "sig",
|
||||
"alg": self.algorithm,
|
||||
}
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert to dictionary for storage."""
|
||||
return {
|
||||
"kid": self.kid,
|
||||
"private_key": self.private_key,
|
||||
"public_key": self.public_key,
|
||||
"algorithm": self.algorithm,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"expires_at": self.expires_at.isoformat(),
|
||||
"is_active": self.is_active,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> "JWKSKey":
|
||||
"""Create from dictionary."""
|
||||
return cls(
|
||||
kid=data["kid"],
|
||||
private_key=data["private_key"],
|
||||
public_key=data["public_key"],
|
||||
algorithm=data.get("algorithm", "RS256"),
|
||||
created_at=datetime.fromisoformat(data["created_at"]),
|
||||
expires_at=datetime.fromisoformat(data["expires_at"]),
|
||||
is_active=data.get("is_active", True),
|
||||
)
|
||||
|
||||
|
||||
def _base64url_encode(value: int) -> str:
|
||||
"""Encode an integer to base64url format."""
|
||||
import base64
|
||||
byte_length = (value.bit_length() + 7) // 8 or 1
|
||||
encoded = value.to_bytes(byte_length, byteorder="big")
|
||||
return base64.urlsafe_b64encode(encoded).decode().rstrip("=")
|
||||
|
||||
|
||||
class OIDCJWKSService:
|
||||
"""Service for managing OIDC signing keys (JWKS).
|
||||
|
||||
This service handles RSA key pair generation, rotation, and JWKS document
|
||||
generation for the OIDC implementation.
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_keys: Dict[str, JWKSKey] = {}
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._keys = {}
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def reset(cls):
|
||||
"""Reset the singleton (for testing)."""
|
||||
cls._instance = None
|
||||
cls._keys = {}
|
||||
|
||||
def _generate_kid(self, private_key: str) -> str:
|
||||
"""Generate a key ID from the private key fingerprint."""
|
||||
kid_hash = hashlib.sha256(private_key.encode()).hexdigest()[:32]
|
||||
return kid_hash
|
||||
|
||||
def _generate_rsa_key_pair(self) -> Tuple[str, str]:
|
||||
"""Generate a new RSA key pair in PEM format.
|
||||
|
||||
Returns:
|
||||
Tuple of (private_key_pem, public_key_pem)
|
||||
"""
|
||||
try:
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
# Generate RSA private key
|
||||
private_key = rsa.generate_private_key(
|
||||
public_exponent=65537,
|
||||
key_size=2048,
|
||||
backend=default_backend()
|
||||
)
|
||||
|
||||
# Get public key
|
||||
public_key = private_key.public_key()
|
||||
|
||||
# Serialize to PEM
|
||||
private_pem = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption()
|
||||
).decode()
|
||||
|
||||
public_pem = public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
||||
).decode()
|
||||
|
||||
return private_pem, public_pem
|
||||
except ImportError:
|
||||
# Fallback for testing without cryptography
|
||||
import secrets
|
||||
return f"private_key_{secrets.token_hex(32)}", f"public_key_{secrets.token_hex(32)}"
|
||||
|
||||
def get_jwks(self, include_private_keys: bool = False) -> Dict:
|
||||
"""Get the JWKS document containing public keys.
|
||||
|
||||
Args:
|
||||
include_private_keys: Whether to include private keys (for internal use only)
|
||||
|
||||
Returns:
|
||||
JWKS document dictionary
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
keys = []
|
||||
for kid, key in self._keys.items():
|
||||
# Only include active, non-expired keys
|
||||
if key.is_active and key.expires_at > now:
|
||||
if include_private_keys:
|
||||
keys.append(key.to_dict())
|
||||
else:
|
||||
keys.append(key.to_jwk())
|
||||
|
||||
return {
|
||||
"keys": keys
|
||||
}
|
||||
|
||||
def load_keys_from_db(self) -> int:
|
||||
"""Load existing keys from the database.
|
||||
|
||||
Returns:
|
||||
Number of keys loaded
|
||||
"""
|
||||
try:
|
||||
db_keys = OidcJwksKey.get_active_keys()
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
for db_key in db_keys:
|
||||
# Create JWKSKey from database model
|
||||
key = JWKSKey(
|
||||
kid=db_key.kid,
|
||||
private_key=db_key.private_key,
|
||||
public_key=db_key.public_key,
|
||||
algorithm=db_key.algorithm,
|
||||
created_at=db_key.created_at,
|
||||
expires_at=db_key.expires_at or now + timedelta(days=365),
|
||||
is_active=db_key.is_active,
|
||||
)
|
||||
self._keys[db_key.kid] = key
|
||||
|
||||
return len(self._keys)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error loading keys from database: {e}")
|
||||
return 0
|
||||
|
||||
def save_key_to_db(self, key: JWKSKey, is_primary: bool = False) -> OidcJwksKey:
|
||||
"""Save a key to the database.
|
||||
|
||||
Args:
|
||||
key: JWKSKey instance to save
|
||||
is_primary: Whether this is the primary signing key
|
||||
|
||||
Returns:
|
||||
OidcJwksKey database model instance
|
||||
"""
|
||||
db_key = OidcJwksKey(
|
||||
kid=key.kid,
|
||||
key_type="RSA",
|
||||
algorithm=key.algorithm,
|
||||
private_key=key.private_key,
|
||||
public_key=key.public_key,
|
||||
is_active=key.is_active,
|
||||
is_primary=is_primary,
|
||||
)
|
||||
|
||||
db.session.add(db_key)
|
||||
db.session.commit()
|
||||
|
||||
return db_key
|
||||
|
||||
def get_signing_key(self) -> Optional[JWKSKey]:
|
||||
"""Get the current active signing key.
|
||||
|
||||
Returns:
|
||||
JWKSKey instance or None if no active key
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# First try to get the primary key from database
|
||||
try:
|
||||
primary_db_key = OidcJwksKey.get_primary_key()
|
||||
if primary_db_key:
|
||||
# Check if we have it in memory, if not load it
|
||||
if primary_db_key.kid not in self._keys:
|
||||
key = JWKSKey(
|
||||
kid=primary_db_key.kid,
|
||||
private_key=primary_db_key.private_key,
|
||||
public_key=primary_db_key.public_key,
|
||||
algorithm=primary_db_key.algorithm,
|
||||
created_at=primary_db_key.created_at,
|
||||
expires_at=primary_db_key.expires_at or now + timedelta(days=365),
|
||||
is_active=primary_db_key.is_active,
|
||||
)
|
||||
self._keys[primary_db_key.kid] = key
|
||||
return self._keys[primary_db_key.kid]
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error getting primary key from database: {e}")
|
||||
|
||||
# Fall back to in-memory keys
|
||||
for kid, key in self._keys.items():
|
||||
if key.is_active and key.expires_at > now:
|
||||
return key
|
||||
|
||||
return None
|
||||
|
||||
def get_key_by_kid(self, kid: str) -> Optional[JWKSKey]:
|
||||
"""Get a specific key by its ID.
|
||||
|
||||
Args:
|
||||
kid: Key ID to look up
|
||||
|
||||
Returns:
|
||||
JWKSKey instance or None if not found
|
||||
"""
|
||||
return self._keys.get(kid)
|
||||
|
||||
def generate_new_key_pair(self, expires_in_days: int = 365) -> JWKSKey:
|
||||
"""Generate a new RSA key pair for signing.
|
||||
|
||||
Args:
|
||||
expires_in_days: Days until key expiration
|
||||
|
||||
Returns:
|
||||
JWKSKey instance
|
||||
"""
|
||||
private_key, public_key = self._generate_rsa_key_pair()
|
||||
kid = self._generate_kid(private_key)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
key = JWKSKey(
|
||||
kid=kid,
|
||||
private_key=private_key,
|
||||
public_key=public_key,
|
||||
algorithm="RS256",
|
||||
created_at=now,
|
||||
expires_at=now + timedelta(days=expires_in_days),
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
self._keys[kid] = key
|
||||
|
||||
# Deactivate old keys (but keep them for grace period)
|
||||
for old_kid in self._keys:
|
||||
if old_kid != kid:
|
||||
self._keys[old_kid].is_active = False
|
||||
|
||||
return key
|
||||
|
||||
def rotate_keys(self, grace_period_hours: int = 24) -> Tuple[JWKSKey, List[str]]:
|
||||
"""Rotate signing keys, keeping previous key active for grace period.
|
||||
|
||||
Args:
|
||||
grace_period_hours: Hours to keep old keys active
|
||||
|
||||
Returns:
|
||||
Tuple of (new_key, list_of_deprecated_kids)
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
grace_end = now + timedelta(hours=grace_period_hours)
|
||||
|
||||
# Mark current key as deprecated
|
||||
current_key = self.get_signing_key()
|
||||
deprecated_kids = []
|
||||
|
||||
if current_key:
|
||||
deprecated_kids.append(current_key.kid)
|
||||
# Keep key active but mark as deprecated
|
||||
current_key.is_active = False
|
||||
current_key.expires_at = grace_end
|
||||
|
||||
# Generate new key
|
||||
new_key = self.generate_new_key_pair()
|
||||
|
||||
# Clean up expired keys
|
||||
expired_kids = [
|
||||
kid for kid, key in self._keys.items()
|
||||
if key.expires_at < now
|
||||
]
|
||||
for kid in expired_kids:
|
||||
del self._keys[kid]
|
||||
|
||||
return new_key, deprecated_kids
|
||||
|
||||
def verify_key_exists(self, kid: str) -> bool:
|
||||
"""Check if a key with the given ID exists and is valid.
|
||||
|
||||
Args:
|
||||
kid: Key ID to check
|
||||
|
||||
Returns:
|
||||
True if key exists and is valid
|
||||
"""
|
||||
key = self.get_key_by_kid(kid)
|
||||
if not key:
|
||||
return False
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
return key.is_active and key.expires_at > now
|
||||
|
||||
def initialize_with_key(self) -> JWKSKey:
|
||||
"""Initialize the service with a key, loading from database if available.
|
||||
|
||||
This method first attempts to load existing keys from the database.
|
||||
If no active primary key exists, it generates a new key and saves it to the database.
|
||||
|
||||
Returns:
|
||||
JWKSKey instance
|
||||
"""
|
||||
# First, try to load keys from database
|
||||
try:
|
||||
# Check if there's a primary key in the database
|
||||
primary_db_key = OidcJwksKey.get_primary_key()
|
||||
if primary_db_key:
|
||||
# Load the primary key into memory
|
||||
now = datetime.now(timezone.utc)
|
||||
key = JWKSKey(
|
||||
kid=primary_db_key.kid,
|
||||
private_key=primary_db_key.private_key,
|
||||
public_key=primary_db_key.public_key,
|
||||
algorithm=primary_db_key.algorithm,
|
||||
created_at=primary_db_key.created_at,
|
||||
expires_at=primary_db_key.expires_at or now + timedelta(days=365),
|
||||
is_active=primary_db_key.is_active,
|
||||
)
|
||||
self._keys[primary_db_key.kid] = key
|
||||
current_app.logger.info(f"[OIDC] Loaded existing signing key from database: kid={primary_db_key.kid}")
|
||||
return key
|
||||
|
||||
# Try to load all active keys from database
|
||||
loaded_count = self.load_keys_from_db()
|
||||
if loaded_count > 0:
|
||||
# Get the signing key from loaded keys
|
||||
signing_key = self.get_signing_key()
|
||||
if signing_key:
|
||||
current_app.logger.info(f"[OIDC] Loaded {loaded_count} keys from database, using signing key: kid={signing_key.kid}")
|
||||
return signing_key
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error loading keys from database: {e}")
|
||||
|
||||
# No keys in database, generate a new key and save it
|
||||
current_app.logger.info("[OIDC] No existing keys found in database, generating new signing key")
|
||||
new_key = self.generate_new_key_pair()
|
||||
|
||||
# Save the new key to database
|
||||
try:
|
||||
self.save_key_to_db(new_key, is_primary=True)
|
||||
current_app.logger.info(f"[OIDC] Saved new signing key to database: kid={new_key.kid}")
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error saving key to database: {e}")
|
||||
|
||||
return new_key
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,289 @@
|
||||
"""OIDC Session Service for session management during OIDC flow."""
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from datetime import timezone
|
||||
from flask import current_app, g
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models import OIDCSession, OIDCClient, User
|
||||
from gatehouse_app.exceptions.validation_exceptions import NotFoundError, ValidationError
|
||||
|
||||
|
||||
class OIDCSessionService:
|
||||
"""Service for managing OIDC authentication sessions.
|
||||
|
||||
This service handles:
|
||||
- Creating OIDC sessions during authorization flow
|
||||
- Validating sessions with state and nonce
|
||||
- Managing PKCE code challenges
|
||||
- Cleaning up expired sessions
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _generate_state() -> str:
|
||||
"""Generate a secure state parameter.
|
||||
|
||||
Returns:
|
||||
URL-safe base64 encoded state
|
||||
"""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
@staticmethod
|
||||
def _generate_nonce() -> str:
|
||||
"""Generate a secure nonce for OIDC.
|
||||
|
||||
Returns:
|
||||
URL-safe base64 encoded nonce
|
||||
"""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
@staticmethod
|
||||
def _generate_code_challenge(verifier: str, method: str = "S256") -> str:
|
||||
"""Generate a PKCE code challenge from verifier.
|
||||
|
||||
Args:
|
||||
verifier: The code verifier
|
||||
method: Challenge method ("S256" or "plain")
|
||||
|
||||
Returns:
|
||||
Code challenge string
|
||||
"""
|
||||
import hashlib
|
||||
import base64
|
||||
|
||||
if method == "S256":
|
||||
digest = hashlib.sha256(verifier.encode()).digest()
|
||||
return base64.urlsafe_b64encode(digest).decode().rstrip("=")
|
||||
elif method == "plain":
|
||||
return verifier
|
||||
else:
|
||||
raise ValueError(f"Unsupported code challenge method: {method}")
|
||||
|
||||
@classmethod
|
||||
def validate_code_verifier(cls, code_verifier: str, code_challenge: str,
|
||||
method: str = "S256") -> bool:
|
||||
"""Validate a PKCE code verifier against the stored challenge.
|
||||
|
||||
Args:
|
||||
code_verifier: The code verifier from the token request
|
||||
code_challenge: The code challenge from the authorization request
|
||||
method: The challenge method used
|
||||
|
||||
Returns:
|
||||
True if validation succeeds
|
||||
"""
|
||||
if not code_verifier or not code_challenge:
|
||||
return False
|
||||
|
||||
# Validate code verifier length (43-128 characters)
|
||||
if method == "S256" and not (43 <= len(code_verifier) <= 128):
|
||||
return False
|
||||
|
||||
# Calculate expected challenge
|
||||
expected_challenge = cls._generate_code_challenge(code_verifier, method)
|
||||
|
||||
return secrets.compare_digest(expected_challenge, code_challenge)
|
||||
|
||||
@classmethod
|
||||
def create_session(
|
||||
cls,
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
state: str = None,
|
||||
nonce: str = None,
|
||||
redirect_uri: str = None,
|
||||
scope: list = None,
|
||||
code_challenge: str = None,
|
||||
code_challenge_method: str = None,
|
||||
lifetime_seconds: int = 600
|
||||
) -> OIDCSession:
|
||||
"""Create a new OIDC session for the authorization flow.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
client_id: The OIDC client ID
|
||||
state: State parameter (generated if not provided)
|
||||
nonce: Nonce for ID token validation (generated if not provided)
|
||||
redirect_uri: Redirect URI from authorization request
|
||||
scope: Requested scopes
|
||||
code_challenge: PKCE code challenge
|
||||
code_challenge_method: PKCE method ("S256" or "plain")
|
||||
lifetime_seconds: Session lifetime in seconds
|
||||
|
||||
Returns:
|
||||
OIDCSession instance
|
||||
"""
|
||||
# Generate state and nonce if not provided
|
||||
state = state or cls._generate_state()
|
||||
nonce = nonce or cls._generate_nonce()
|
||||
|
||||
session = OIDCSession.create_session(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
lifetime_seconds=lifetime_seconds,
|
||||
)
|
||||
|
||||
return session
|
||||
|
||||
@classmethod
|
||||
def validate_session(cls, state: str, nonce: str = None) -> Tuple[OIDCSession, User]:
|
||||
"""Validate an OIDC session by state and optionally nonce.
|
||||
|
||||
Args:
|
||||
state: The state parameter
|
||||
nonce: The nonce to validate (optional)
|
||||
|
||||
Returns:
|
||||
Tuple of (OIDCSession, User)
|
||||
|
||||
Raises:
|
||||
ValidationError: If session is invalid
|
||||
NotFoundError: If session not found
|
||||
"""
|
||||
session = OIDCSession.get_by_state(state)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError("OIDC session not found or expired")
|
||||
|
||||
if session.is_expired():
|
||||
raise ValidationError("OIDC session has expired")
|
||||
|
||||
# Validate nonce if provided
|
||||
if nonce and not session.validate_nonce(nonce):
|
||||
raise ValidationError("Invalid nonce")
|
||||
|
||||
# Get user
|
||||
user = User.query.get(session.user_id)
|
||||
if not user:
|
||||
raise NotFoundError("User not found")
|
||||
|
||||
return session, user
|
||||
|
||||
@classmethod
|
||||
def validate_pkce(cls, session: OIDCSession, code_verifier: str) -> bool:
|
||||
"""Validate PKCE code verifier against the session's code challenge.
|
||||
|
||||
Args:
|
||||
session: OIDCSession instance
|
||||
code_verifier: The code verifier from token request
|
||||
|
||||
Returns:
|
||||
True if validation succeeds
|
||||
|
||||
Raises:
|
||||
ValidationError: If PKCE validation fails
|
||||
"""
|
||||
if not session.code_challenge:
|
||||
# No PKCE was used, skip validation
|
||||
return True
|
||||
|
||||
if not code_verifier:
|
||||
raise ValidationError("code_verifier is required")
|
||||
|
||||
is_valid = session.validate_code_challenge(code_verifier)
|
||||
|
||||
if not is_valid:
|
||||
raise ValidationError("Invalid code_verifier")
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def mark_session_authenticated(cls, session: OIDCSession) -> OIDCSession:
|
||||
"""Mark a session as authenticated (user has logged in).
|
||||
|
||||
Args:
|
||||
session: OIDCSession instance
|
||||
|
||||
Returns:
|
||||
Updated OIDCSession instance
|
||||
"""
|
||||
session.mark_authenticated()
|
||||
return session
|
||||
|
||||
@classmethod
|
||||
def cleanup_expired_sessions(cls, older_than_hours: int = 24) -> int:
|
||||
"""Remove expired OIDC sessions.
|
||||
|
||||
Args:
|
||||
older_than_hours: Only delete sessions expired more than this many hours ago
|
||||
|
||||
Returns:
|
||||
Number of sessions deleted
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(hours=older_than_hours)
|
||||
|
||||
# Get expired sessions
|
||||
expired_sessions = OIDCSession.query.filter(
|
||||
OIDCSession.expires_at < datetime.now(timezone.utc),
|
||||
OIDCSession.deleted_at == None
|
||||
).all()
|
||||
|
||||
count = 0
|
||||
for session in expired_sessions:
|
||||
# Only hard delete if past the grace period
|
||||
if session.expires_at < cutoff:
|
||||
session.delete()
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
@classmethod
|
||||
def get_session_by_state(cls, state: str) -> Optional[OIDCSession]:
|
||||
"""Get an OIDC session by state.
|
||||
|
||||
Args:
|
||||
state: The state parameter
|
||||
|
||||
Returns:
|
||||
OIDCSession instance or None
|
||||
"""
|
||||
return OIDCSession.get_by_state(state)
|
||||
|
||||
@classmethod
|
||||
def validate_redirect_uri(cls, client_id: str, redirect_uri: str) -> bool:
|
||||
"""Validate that a redirect URI is allowed for a client.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
redirect_uri: The redirect URI to validate
|
||||
|
||||
Returns:
|
||||
True if redirect URI is allowed
|
||||
"""
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
return False
|
||||
|
||||
return client.is_redirect_uri_allowed(redirect_uri)
|
||||
|
||||
@classmethod
|
||||
def validate_scopes(cls, client_id: str, requested_scopes: list) -> list:
|
||||
"""Validate and filter scopes against client's allowed scopes.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
requested_scopes: List of requested scopes
|
||||
|
||||
Returns:
|
||||
List of allowed scopes
|
||||
"""
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
if not client:
|
||||
return []
|
||||
|
||||
allowed_scopes = client.scopes or []
|
||||
|
||||
# Filter to only allowed scopes
|
||||
valid_scopes = [s for s in requested_scopes if s in allowed_scopes]
|
||||
|
||||
return valid_scopes
|
||||
@@ -0,0 +1,593 @@
|
||||
"""OIDC Token Service for JWT token generation and validation."""
|
||||
import hashlib
|
||||
import base64
|
||||
import secrets
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, Optional, Any
|
||||
|
||||
import jwt
|
||||
from flask import current_app, g
|
||||
|
||||
from gatehouse_app.models import User, OIDCClient
|
||||
from gatehouse_app.models.organization_member import OrganizationMember
|
||||
from gatehouse_app.services.oidc_jwks_service import OIDCJWKSService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OIDCTokenService:
|
||||
"""Service for generating and validating OIDC tokens.
|
||||
|
||||
This service handles:
|
||||
- Access token creation (JWT)
|
||||
- ID token creation (JWT)
|
||||
- Refresh token creation (opaque)
|
||||
- Token signature verification
|
||||
- Hash generation for PKCE claims (at_hash, c_hash)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _generate_jti() -> str:
|
||||
"""Generate a unique JWT ID."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
@staticmethod
|
||||
def _generate_opaque_token(length: int = 43) -> str:
|
||||
"""Generate an opaque token (for refresh tokens).
|
||||
|
||||
Args:
|
||||
length: Length of the token
|
||||
|
||||
Returns:
|
||||
URL-safe base64 encoded token
|
||||
"""
|
||||
return secrets.token_urlsafe(length)
|
||||
|
||||
@staticmethod
|
||||
def _hash_token(token: str) -> str:
|
||||
"""Hash a token for secure storage.
|
||||
|
||||
Args:
|
||||
token: Token to hash
|
||||
|
||||
Returns:
|
||||
SHA256 hash of the token
|
||||
"""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _base64url_encode(data: bytes) -> str:
|
||||
"""Encode bytes to base64url format without padding.
|
||||
|
||||
Args:
|
||||
data: Bytes to encode
|
||||
|
||||
Returns:
|
||||
Base64url encoded string
|
||||
"""
|
||||
return base64.urlsafe_b64encode(data).decode().rstrip("=")
|
||||
|
||||
@staticmethod
|
||||
def create_at_hash(access_token: str) -> str:
|
||||
"""Create the at_hash claim for ID token.
|
||||
|
||||
Implements OIDC spec for access token hash generation.
|
||||
Hash is the left-most half of the hash of the ASCII representation
|
||||
of the access token.
|
||||
|
||||
Args:
|
||||
access_token: The access token string
|
||||
|
||||
Returns:
|
||||
Base64url encoded hash
|
||||
"""
|
||||
# Hash the access token using SHA256
|
||||
hash_digest = hashlib.sha256(access_token.encode()).digest()
|
||||
|
||||
# Take left-most half of the hash
|
||||
half_length = len(hash_digest) // 2
|
||||
left_half = hash_digest[:half_length]
|
||||
|
||||
# Base64url encode
|
||||
return OIDCTokenService._base64url_encode(left_half)
|
||||
|
||||
@staticmethod
|
||||
def create_c_hash(code: str) -> str:
|
||||
"""Create the c_hash claim for ID token.
|
||||
|
||||
Implements OIDC spec for authorization code hash generation.
|
||||
|
||||
Args:
|
||||
code: The authorization code string
|
||||
|
||||
Returns:
|
||||
Base64url encoded hash
|
||||
"""
|
||||
# Hash the code using SHA256
|
||||
hash_digest = hashlib.sha256(code.encode()).digest()
|
||||
|
||||
# Take left-most half of the hash
|
||||
half_length = len(hash_digest) // 2
|
||||
left_half = hash_digest[:half_length]
|
||||
|
||||
# Base64url encode
|
||||
return OIDCTokenService._base64url_encode(left_half)
|
||||
|
||||
@staticmethod
|
||||
def _get_issuer() -> str:
|
||||
"""Get the OIDC issuer URL."""
|
||||
return current_app.config.get("OIDC_ISSUER_URL", "http://localhost:5000")
|
||||
|
||||
@staticmethod
|
||||
def _get_token_lifetime(client: OIDCClient, token_type: str) -> int:
|
||||
"""Get the token lifetime in seconds for a client.
|
||||
|
||||
Args:
|
||||
client: OIDCClient instance
|
||||
token_type: Type of token ("access_token", "refresh_token", "id_token")
|
||||
|
||||
Returns:
|
||||
Lifetime in seconds
|
||||
"""
|
||||
lifetimes = {
|
||||
"access_token": client.access_token_lifetime or 3600,
|
||||
"refresh_token": client.refresh_token_lifetime or 2592000,
|
||||
"id_token": client.id_token_lifetime or 3600,
|
||||
}
|
||||
return lifetimes.get(token_type, 3600)
|
||||
|
||||
@classmethod
|
||||
def create_access_token(cls, client_id: str, user_id: str, scope: list,
|
||||
jti: str = None) -> str:
|
||||
"""Create a JWT access token.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID (subject)
|
||||
scope: List of granted scopes
|
||||
jti: Optional JWT ID (generated if not provided)
|
||||
|
||||
Returns:
|
||||
JWT access token string
|
||||
"""
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
logger.debug("[OIDC TOKEN SERVICE] create_access_token called")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
|
||||
logger.debug("[OIDC TOKEN SERVICE] scope=%s", scope)
|
||||
|
||||
jti = jti or cls._generate_jti()
|
||||
now_timestamp = int(time.time())
|
||||
now = datetime.now(timezone.utc)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token creation time (UTC): %s", now.isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token creation timestamp: %s", now_timestamp)
|
||||
|
||||
# Get client for token lifetime
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
lifetime = cls._get_token_lifetime(client, "access_token") if client else 3600
|
||||
logger.debug("[OIDC TOKEN SERVICE] Access token lifetime (seconds): %s", lifetime)
|
||||
|
||||
exp_timestamp = now_timestamp + lifetime
|
||||
exp_time = now + timedelta(seconds=lifetime)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Access token expiration time (UTC): %s", exp_time.isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] Access token expiration timestamp: %s", exp_timestamp)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Time until expiration (seconds): %s", lifetime)
|
||||
|
||||
claims = {
|
||||
"iss": cls._get_issuer(),
|
||||
"sub": user_id,
|
||||
"aud": client_id,
|
||||
"exp": exp_timestamp,
|
||||
"iat": now_timestamp,
|
||||
"nbf": now_timestamp,
|
||||
"jti": jti,
|
||||
"client_id": client_id,
|
||||
"scope": " ".join(scope) if isinstance(scope, list) else scope,
|
||||
}
|
||||
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token claims: exp=%s, iat=%s, nbf=%s",
|
||||
claims["exp"], claims["iat"], claims["nbf"])
|
||||
|
||||
# Get signing key
|
||||
jwks_service = OIDCJWKSService()
|
||||
signing_key = jwks_service.get_signing_key()
|
||||
|
||||
if not signing_key:
|
||||
raise ValueError("No signing key available")
|
||||
|
||||
# Sign with RS256
|
||||
logger.debug("[OIDC TOKEN SERVICE] Signing token with RS256...")
|
||||
token = jwt.encode(
|
||||
claims,
|
||||
signing_key.private_key,
|
||||
algorithm="RS256",
|
||||
headers={"kid": signing_key.kid}
|
||||
)
|
||||
|
||||
logger.debug("[OIDC TOKEN SERVICE] Access token created successfully")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def create_id_token(cls, client_id: str, user_id: str, nonce: str = None,
|
||||
scope: list = None, access_token: str = None,
|
||||
auth_time: int = None) -> str:
|
||||
"""Create a JWT ID token.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID (subject)
|
||||
nonce: Nonce for replay protection
|
||||
scope: Requested/Granted scopes
|
||||
access_token: Associated access token (for at_hash)
|
||||
auth_time: Authentication time (Unix timestamp)
|
||||
|
||||
Returns:
|
||||
JWT ID token string
|
||||
"""
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
logger.debug("[OIDC TOKEN SERVICE] create_id_token called")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
|
||||
logger.debug("[OIDC TOKEN SERVICE] nonce=%s, auth_time=%s", nonce, auth_time)
|
||||
logger.debug("[OIDC TOKEN SERVICE] scope=%s", scope)
|
||||
|
||||
now_timestamp = int(time.time())
|
||||
now = datetime.now(timezone.utc)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token creation time (UTC): %s", now.isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token creation timestamp: %s", now_timestamp)
|
||||
auth_time = auth_time or now_timestamp
|
||||
logger.debug("[OIDC TOKEN SERVICE] auth_time (Unix timestamp): %s", auth_time)
|
||||
|
||||
# Get client for token lifetime
|
||||
client = OIDCClient.query.filter_by(client_id=client_id).first()
|
||||
lifetime = cls._get_token_lifetime(client, "id_token") if client else 3600
|
||||
logger.debug("[OIDC TOKEN SERVICE] ID token lifetime (seconds): %s", lifetime)
|
||||
|
||||
exp_timestamp = now_timestamp + lifetime
|
||||
exp_time = now + timedelta(seconds=lifetime)
|
||||
logger.debug("[OIDC TOKEN SERVICE] ID token expiration time (UTC): %s", exp_time.isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] ID token expiration timestamp: %s", exp_timestamp)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Time until expiration (seconds): %s", lifetime)
|
||||
|
||||
# Get user for claims
|
||||
user = User.query.get(user_id)
|
||||
|
||||
claims = {
|
||||
"iss": cls._get_issuer(),
|
||||
"sub": user_id,
|
||||
"aud": client_id,
|
||||
"exp": exp_timestamp,
|
||||
"iat": now_timestamp,
|
||||
"auth_time": auth_time,
|
||||
}
|
||||
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token claims: exp=%s, iat=%s, auth_time=%s",
|
||||
claims["exp"], claims["iat"], claims["auth_time"])
|
||||
|
||||
# Add nonce if provided
|
||||
if nonce:
|
||||
claims["nonce"] = nonce
|
||||
|
||||
# Add at_hash if access token provided
|
||||
if access_token:
|
||||
claims["at_hash"] = cls.create_at_hash(access_token)
|
||||
|
||||
# Add standard claims if user exists
|
||||
if user:
|
||||
if user.email:
|
||||
claims["email"] = user.email
|
||||
claims["email_verified"] = user.email_verified
|
||||
if user.full_name:
|
||||
claims["name"] = user.full_name
|
||||
|
||||
# Add roles claim if scope is granted
|
||||
if scope and "roles" in scope:
|
||||
claims["roles"] = cls._get_user_roles(user)
|
||||
|
||||
# Add scope if provided
|
||||
if scope:
|
||||
claims["scope"] = " ".join(scope) if isinstance(scope, list) else scope
|
||||
|
||||
# Get signing key
|
||||
jwks_service = OIDCJWKSService()
|
||||
signing_key = jwks_service.get_signing_key()
|
||||
|
||||
if not signing_key:
|
||||
raise ValueError("No signing key available")
|
||||
|
||||
# Sign with RS256
|
||||
logger.debug("[OIDC TOKEN SERVICE] Signing token with RS256...")
|
||||
token = jwt.encode(
|
||||
claims,
|
||||
signing_key.private_key,
|
||||
algorithm="RS256",
|
||||
headers={"kid": signing_key.kid}
|
||||
)
|
||||
|
||||
logger.debug("[OIDC TOKEN SERVICE] ID token created successfully")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def _get_user_roles(user: User) -> list:
|
||||
"""Get user's organization roles.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
|
||||
Returns:
|
||||
List of role objects with organization_id and role
|
||||
"""
|
||||
roles = []
|
||||
if user and user.organization_memberships:
|
||||
for member in user.organization_memberships:
|
||||
roles.append({
|
||||
"organization_id": str(member.organization_id),
|
||||
"role": member.role.value
|
||||
})
|
||||
return roles
|
||||
|
||||
@classmethod
|
||||
def create_refresh_token(cls, client_id: str, user_id: str,
|
||||
scope: list = None, access_token_id: str = None) -> str:
|
||||
"""Create an opaque refresh token.
|
||||
|
||||
Args:
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID
|
||||
scope: List of granted scopes
|
||||
access_token_id: Associated access token ID
|
||||
|
||||
Returns:
|
||||
Opaque refresh token string
|
||||
"""
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
logger.debug("[OIDC TOKEN SERVICE] create_refresh_token called")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Current UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] client_id=%s, user_id=%s", client_id, user_id)
|
||||
logger.debug("[OIDC TOKEN SERVICE] scope=%s, access_token_id=%s", scope, access_token_id)
|
||||
|
||||
token = cls._generate_opaque_token()
|
||||
logger.debug("[OIDC TOKEN SERVICE] Refresh token generated: %s...", token[:20] if token else None)
|
||||
|
||||
# Hash for storage
|
||||
token_hash = cls._hash_token(token)
|
||||
|
||||
logger.debug("[OIDC TOKEN SERVICE] Refresh token created successfully")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Final UTC time: %s", datetime.now(timezone.utc).isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
return token, token_hash
|
||||
|
||||
@classmethod
|
||||
def verify_token_signature(cls, token: str) -> Dict:
|
||||
"""Verify the signature of a JWT token.
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
|
||||
Returns:
|
||||
Decoded token claims
|
||||
|
||||
Raises:
|
||||
jwt.InvalidSignatureError: If signature verification fails
|
||||
jwt.ExpiredSignatureError: If token is expired
|
||||
jwt.InvalidTokenError: If token is invalid
|
||||
"""
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
logger.debug("[OIDC TOKEN SERVICE] verify_token_signature() called")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token length: %d", len(token))
|
||||
|
||||
# Get the JWKS with public keys
|
||||
logger.debug("[OIDC TOKEN SERVICE] Getting JWKS...")
|
||||
jwks_service = OIDCJWKSService()
|
||||
jwks = jwks_service.get_jwks(include_private_keys=True)
|
||||
logger.debug("[OIDC TOKEN SERVICE] JWKS retrieved: %d keys", len(jwks.get("keys", [])))
|
||||
|
||||
# Get the key ID from token header
|
||||
try:
|
||||
logger.debug("[OIDC TOKEN SERVICE] Getting unverified token header...")
|
||||
unverified_header = jwt.get_unverified_header(token)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Unverified header: %s", unverified_header)
|
||||
except jwt.DecodeError as e:
|
||||
logger.error("[OIDC TOKEN SERVICE] Failed to decode token header: %s", str(e))
|
||||
raise jwt.InvalidTokenError("Invalid token header")
|
||||
|
||||
kid = unverified_header.get("kid")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Key ID (kid) from token header: %s", kid)
|
||||
|
||||
# Find the matching public key
|
||||
logger.debug("[OIDC TOKEN SERVICE] Searching for matching public key...")
|
||||
public_key = None
|
||||
for idx, key in enumerate(jwks.get("keys", [])):
|
||||
logger.debug("[OIDC TOKEN SERVICE] Checking key %d: kid=%s", idx, key.get("kid"))
|
||||
if key.get("kid") == kid:
|
||||
logger.debug("[OIDC TOKEN SERVICE] Found matching key at index %d", idx)
|
||||
try:
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
logger.debug("[OIDC TOKEN SERVICE] Loading PEM public key...")
|
||||
public_key = serialization.load_pem_public_key(
|
||||
key["public_key"].encode() if isinstance(key["public_key"], str)
|
||||
else key["public_key"],
|
||||
backend=default_backend()
|
||||
)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Public key loaded successfully")
|
||||
break
|
||||
except (ImportError, Exception) as e:
|
||||
logger.error("[OIDC TOKEN SERVICE] Failed to load public key: %s: %s", type(e).__name__, str(e))
|
||||
continue
|
||||
|
||||
if not public_key:
|
||||
logger.error("[OIDC TOKEN SERVICE] No matching public key found for kid=%s", kid)
|
||||
raise jwt.InvalidSignatureError(f"Key with kid={kid} not found")
|
||||
|
||||
logger.debug("[OIDC TOKEN SERVICE] Public key found, verifying signature...")
|
||||
|
||||
# Verify the signature
|
||||
try:
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
public_key,
|
||||
algorithms=["RS256"],
|
||||
audience=None, # We'll validate audience separately
|
||||
issuer=cls._get_issuer(),
|
||||
options={
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
"verify_aud": False, # Handle audience manually
|
||||
"verify_iss": False, # Handle issuer manually
|
||||
}
|
||||
)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Signature verification successful")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Decoded claims: %s", claims)
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
return claims
|
||||
except jwt.ExpiredSignatureError as e:
|
||||
logger.error("[OIDC TOKEN SERVICE] Token has expired: %s", str(e))
|
||||
raise
|
||||
except jwt.InvalidSignatureError as e:
|
||||
logger.error("[OIDC TOKEN SERVICE] Invalid token signature: %s", str(e))
|
||||
raise
|
||||
except jwt.InvalidTokenError as e:
|
||||
logger.error("[OIDC TOKEN SERVICE] Invalid token: %s: %s", type(e).__name__, str(e))
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("[OIDC TOKEN SERVICE] Unexpected error during token verification: %s: %s", type(e).__name__, str(e))
|
||||
import traceback
|
||||
logger.error("[OIDC TOKEN SERVICE] Traceback: %s", traceback.format_exc())
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def decode_token(cls, token: str, verify: bool = False) -> Dict:
|
||||
"""Decode a JWT token without verification (for debugging).
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
verify: Whether to verify signature
|
||||
|
||||
Returns:
|
||||
Decoded token claims
|
||||
"""
|
||||
if verify:
|
||||
return cls.verify_token_signature(token)
|
||||
|
||||
return jwt.decode(
|
||||
token,
|
||||
options={
|
||||
"verify_signature": False,
|
||||
"verify_exp": False,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_access_token(cls, token: str, client_id: str = None) -> Dict:
|
||||
"""Validate an access token and return its claims.
|
||||
|
||||
Args:
|
||||
token: JWT access token
|
||||
client_id: Optional client ID to validate audience
|
||||
|
||||
Returns:
|
||||
Token claims dictionary
|
||||
|
||||
Raises:
|
||||
jwt.InvalidTokenError: If token is invalid
|
||||
ValueError: If token is expired or audience mismatch
|
||||
"""
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
logger.debug("[OIDC TOKEN SERVICE] validate_access_token() called")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token (first 50 chars): %s...", token[:50] if len(token) > 50 else token)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token length: %d", len(token))
|
||||
logger.debug("[OIDC TOKEN SERVICE] Client ID: %s", client_id)
|
||||
|
||||
# Verify token signature
|
||||
logger.debug("[OIDC TOKEN SERVICE] Verifying token signature...")
|
||||
claims = cls.verify_token_signature(token)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token signature verified")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Claims: %s", claims)
|
||||
|
||||
# Check expiration
|
||||
exp = claims.get("exp", 0)
|
||||
now_timestamp = int(time.time())
|
||||
|
||||
if exp < now_timestamp:
|
||||
logger.error("[OIDC TOKEN SERVICE] Token has expired")
|
||||
raise ValueError("Token has expired")
|
||||
|
||||
# Validate audience if client_id provided
|
||||
aud = claims.get("aud")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Token audience (aud): %s", aud)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Expected client_id: %s", client_id)
|
||||
|
||||
if client_id:
|
||||
if aud != client_id:
|
||||
logger.error("[OIDC TOKEN SERVICE] Audience mismatch: expected=%s, got=%s", client_id, aud)
|
||||
raise ValueError("Invalid audience")
|
||||
logger.debug("[OIDC TOKEN SERVICE] Audience validation passed")
|
||||
else:
|
||||
logger.debug("[OIDC TOKEN SERVICE] No client_id provided, skipping audience validation")
|
||||
|
||||
logger.debug("[OIDC TOKEN SERVICE] validate_access_token() completed successfully")
|
||||
logger.debug("[OIDC TOKEN SERVICE] ===========================================")
|
||||
|
||||
return claims
|
||||
|
||||
@classmethod
|
||||
def introspect_token(cls, token: str, client_id: str = None) -> Dict:
|
||||
"""Introspect a token and return its status and claims.
|
||||
|
||||
Args:
|
||||
token: JWT token to introspect
|
||||
client_id: Client ID for audience validation
|
||||
|
||||
Returns:
|
||||
Dictionary with active status and claims
|
||||
"""
|
||||
result = {
|
||||
"active": False,
|
||||
}
|
||||
|
||||
try:
|
||||
claims = cls.validate_access_token(token, client_id)
|
||||
|
||||
# Calculate remaining time
|
||||
now_timestamp = int(time.time())
|
||||
now = datetime.now(timezone.utc)
|
||||
exp = claims.get("exp", 0)
|
||||
iat = claims.get("iat", 0)
|
||||
|
||||
logger.debug("[OIDC TOKEN SERVICE] Introspection - Current UTC time: %s", now.isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] Introspection - Token expiration timestamp: %s", exp)
|
||||
logger.debug("[OIDC TOKEN SERVICE] Introspection - Token expiration datetime (UTC): %s", datetime.fromtimestamp(exp, tz=timezone.utc).isoformat())
|
||||
logger.debug("[OIDC TOKEN SERVICE] Introspection - Time until expiration: %s seconds", exp - now_timestamp)
|
||||
|
||||
result["active"] = exp > now_timestamp
|
||||
result.update({
|
||||
"iss": claims.get("iss"),
|
||||
"sub": claims.get("sub"),
|
||||
"aud": claims.get("aud"),
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
"nbf": claims.get("nbf"),
|
||||
"jti": claims.get("jti"),
|
||||
"client_id": claims.get("client_id"),
|
||||
"scope": claims.get("scope"),
|
||||
"token_type": "Bearer",
|
||||
})
|
||||
|
||||
# Add expiry in seconds
|
||||
if exp > now_timestamp:
|
||||
result["exp"] = int(exp - now_timestamp)
|
||||
|
||||
except (jwt.InvalidTokenError, ValueError) as e:
|
||||
result["active"] = False
|
||||
result["error"] = str(e)
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,303 @@
|
||||
"""Organization service."""
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from flask import current_app
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.organization import Organization
|
||||
from gatehouse_app.models.organization_member import OrganizationMember
|
||||
from gatehouse_app.exceptions.validation_exceptions import OrganizationNotFoundError, ConflictError
|
||||
from gatehouse_app.utils.constants import OrganizationRole, AuditAction
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OrganizationService:
|
||||
"""Service for organization operations."""
|
||||
|
||||
@staticmethod
|
||||
def create_organization(name, slug, owner_user_id, description=None, logo_url=None):
|
||||
"""
|
||||
Create a new organization.
|
||||
|
||||
Args:
|
||||
name: Organization name
|
||||
slug: Unique organization slug
|
||||
owner_user_id: ID of the user who will be the owner
|
||||
description: Optional description
|
||||
logo_url: Optional logo URL
|
||||
|
||||
Returns:
|
||||
Organization instance
|
||||
|
||||
Raises:
|
||||
ConflictError: If slug already exists
|
||||
"""
|
||||
# Check if slug already exists
|
||||
existing = Organization.query.filter_by(slug=slug, deleted_at=None).first()
|
||||
if existing:
|
||||
raise ConflictError("Organization slug already exists")
|
||||
|
||||
# Create organization
|
||||
org = Organization(
|
||||
name=name,
|
||||
slug=slug,
|
||||
description=description,
|
||||
logo_url=logo_url,
|
||||
is_active=True,
|
||||
)
|
||||
org.save()
|
||||
|
||||
# Add owner as member
|
||||
member = OrganizationMember(
|
||||
user_id=owner_user_id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.OWNER,
|
||||
joined_at=datetime.now(timezone.utc),
|
||||
)
|
||||
member.save()
|
||||
|
||||
# Log organization creation
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_CREATE,
|
||||
user_id=owner_user_id,
|
||||
organization_id=org.id,
|
||||
resource_type="organization",
|
||||
resource_id=org.id,
|
||||
description=f"Organization created: {name}",
|
||||
)
|
||||
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def get_organization_by_id(org_id):
|
||||
"""
|
||||
Get organization by ID.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
|
||||
Returns:
|
||||
Organization instance
|
||||
|
||||
Raises:
|
||||
OrganizationNotFoundError: If organization not found
|
||||
"""
|
||||
org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
|
||||
|
||||
# Development-only debug logging for organization validation
|
||||
if current_app.config.get('ENV') == 'development':
|
||||
logger.debug(f"[Org] Get organization by ID: org_id={org_id}, exists={org is not None}")
|
||||
|
||||
if not org:
|
||||
raise OrganizationNotFoundError()
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def get_organization_by_slug(slug):
|
||||
"""
|
||||
Get organization by slug.
|
||||
|
||||
Args:
|
||||
slug: Organization slug
|
||||
|
||||
Returns:
|
||||
Organization instance or None
|
||||
"""
|
||||
org = Organization.query.filter_by(slug=slug, deleted_at=None).first()
|
||||
|
||||
# Development-only debug logging for organization validation
|
||||
if current_app.config.get('ENV') == 'development':
|
||||
logger.debug(f"[Org] Get organization by slug: slug={slug}, exists={org is not None}")
|
||||
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def update_organization(org, user_id, **kwargs):
|
||||
"""
|
||||
Update organization.
|
||||
|
||||
Args:
|
||||
org: Organization instance
|
||||
user_id: ID of user performing the update
|
||||
**kwargs: Fields to update
|
||||
|
||||
Returns:
|
||||
Updated Organization instance
|
||||
"""
|
||||
allowed_fields = ["name", "description", "logo_url"]
|
||||
update_data = {k: v for k, v in kwargs.items() if k in allowed_fields}
|
||||
|
||||
if update_data:
|
||||
org.update(**update_data)
|
||||
|
||||
# Log organization update
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_UPDATE,
|
||||
user_id=user_id,
|
||||
organization_id=org.id,
|
||||
resource_type="organization",
|
||||
resource_id=org.id,
|
||||
metadata=update_data,
|
||||
description="Organization updated",
|
||||
)
|
||||
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def delete_organization(org, user_id, soft=True):
|
||||
"""
|
||||
Delete organization.
|
||||
|
||||
Args:
|
||||
org: Organization instance
|
||||
user_id: ID of user performing the delete
|
||||
soft: If True, performs soft delete
|
||||
|
||||
Returns:
|
||||
Deleted Organization instance
|
||||
"""
|
||||
org.delete(soft=soft)
|
||||
|
||||
# Log organization deletion
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_DELETE,
|
||||
user_id=user_id,
|
||||
organization_id=org.id,
|
||||
resource_type="organization",
|
||||
resource_id=org.id,
|
||||
description=f"Organization {'soft' if soft else 'hard'} deleted",
|
||||
)
|
||||
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def add_member(org, user_id, role, inviter_id):
|
||||
"""
|
||||
Add a member to the organization.
|
||||
|
||||
Args:
|
||||
org: Organization instance
|
||||
user_id: ID of user to add
|
||||
role: OrganizationRole
|
||||
inviter_id: ID of user performing the invitation
|
||||
|
||||
Returns:
|
||||
OrganizationMember instance
|
||||
|
||||
Raises:
|
||||
ConflictError: If user is already a member
|
||||
"""
|
||||
# Check if already a member
|
||||
existing = OrganizationMember.query.filter_by(
|
||||
user_id=user_id,
|
||||
organization_id=org.id,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
# Development-only debug logging for membership validation
|
||||
if current_app.config.get('ENV') == 'development':
|
||||
logger.debug(f"[Org] Member check: org_id={org.id}, user_id={user_id}, already_member={existing is not None}")
|
||||
|
||||
if existing:
|
||||
raise ConflictError("User is already a member of this organization")
|
||||
|
||||
# Create membership
|
||||
member = OrganizationMember(
|
||||
user_id=user_id,
|
||||
organization_id=org.id,
|
||||
role=role,
|
||||
invited_by_id=inviter_id,
|
||||
invited_at=datetime.now(timezone.utc),
|
||||
joined_at=datetime.now(timezone.utc),
|
||||
)
|
||||
member.save()
|
||||
|
||||
# Log member addition
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_MEMBER_ADD,
|
||||
user_id=inviter_id,
|
||||
organization_id=org.id,
|
||||
resource_type="organization_member",
|
||||
resource_id=member.id,
|
||||
metadata={"added_user_id": user_id, "role": role.value},
|
||||
description=f"Member added to organization with role: {role.value}",
|
||||
)
|
||||
|
||||
return member
|
||||
|
||||
@staticmethod
|
||||
def remove_member(org, user_id, remover_id):
|
||||
"""
|
||||
Remove a member from the organization.
|
||||
|
||||
Args:
|
||||
org: Organization instance
|
||||
user_id: ID of user to remove
|
||||
remover_id: ID of user performing the removal
|
||||
"""
|
||||
member = OrganizationMember.query.filter_by(
|
||||
user_id=user_id,
|
||||
organization_id=org.id,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
# Development-only debug logging for membership removal validation
|
||||
if current_app.config.get('ENV') == 'development':
|
||||
logger.debug(f"[Org] Member removal: org_id={org.id}, user_id={user_id}, found={member is not None}")
|
||||
|
||||
if member:
|
||||
member.delete(soft=True)
|
||||
|
||||
# Log member removal
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_MEMBER_REMOVE,
|
||||
user_id=remover_id,
|
||||
organization_id=org.id,
|
||||
resource_type="organization_member",
|
||||
resource_id=member.id,
|
||||
metadata={"removed_user_id": user_id},
|
||||
description="Member removed from organization",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_member_role(org, user_id, new_role, updater_id):
|
||||
"""
|
||||
Update a member's role in the organization.
|
||||
|
||||
Args:
|
||||
org: Organization instance
|
||||
user_id: ID of user whose role to update
|
||||
new_role: New OrganizationRole
|
||||
updater_id: ID of user performing the update
|
||||
|
||||
Returns:
|
||||
Updated OrganizationMember instance
|
||||
"""
|
||||
member = OrganizationMember.query.filter_by(
|
||||
user_id=user_id,
|
||||
organization_id=org.id,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
if member:
|
||||
old_role = member.role
|
||||
member.role = new_role
|
||||
db.session.commit()
|
||||
|
||||
# Log role change
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_MEMBER_ROLE_CHANGE,
|
||||
user_id=updater_id,
|
||||
organization_id=org.id,
|
||||
resource_type="organization_member",
|
||||
resource_id=member.id,
|
||||
metadata={
|
||||
"target_user_id": user_id,
|
||||
"old_role": old_role.value,
|
||||
"new_role": new_role.value,
|
||||
},
|
||||
description=f"Member role changed from {old_role.value} to {new_role.value}",
|
||||
)
|
||||
|
||||
return member
|
||||
@@ -0,0 +1,76 @@
|
||||
"""Session service."""
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.models.session import Session
|
||||
from gatehouse_app.utils.constants import SessionStatus
|
||||
|
||||
|
||||
class SessionService:
|
||||
"""Service for session operations."""
|
||||
|
||||
@staticmethod
|
||||
def get_active_session_by_token(token):
|
||||
"""Get active session by token.
|
||||
|
||||
Args:
|
||||
token: The session token string
|
||||
|
||||
Returns:
|
||||
Session object if found and active, None otherwise
|
||||
"""
|
||||
from gatehouse_app.models.session import Session
|
||||
from gatehouse_app.utils.constants import SessionStatus
|
||||
return Session.query.filter_by(
|
||||
token=token,
|
||||
status=SessionStatus.ACTIVE,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
@staticmethod
|
||||
def get_user_sessions(user_id, active_only=True):
|
||||
"""
|
||||
Get all sessions for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
active_only: If True, only return active sessions
|
||||
|
||||
Returns:
|
||||
List of Session instances
|
||||
"""
|
||||
query = Session.query.filter_by(user_id=user_id, deleted_at=None)
|
||||
|
||||
if active_only:
|
||||
query = query.filter_by(status=SessionStatus.ACTIVE).filter(
|
||||
Session.expires_at > datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
return query.all()
|
||||
|
||||
@staticmethod
|
||||
def revoke_user_sessions(user_id, reason="User logged out from all devices"):
|
||||
"""
|
||||
Revoke all active sessions for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
reason: Reason for revocation
|
||||
"""
|
||||
sessions = SessionService.get_user_sessions(user_id, active_only=True)
|
||||
|
||||
for session in sessions:
|
||||
session.revoke(reason=reason)
|
||||
|
||||
@staticmethod
|
||||
def cleanup_expired_sessions():
|
||||
"""Clean up expired sessions."""
|
||||
expired_sessions = Session.query.filter(
|
||||
Session.status == SessionStatus.ACTIVE,
|
||||
Session.expires_at < datetime.now(timezone.utc),
|
||||
Session.deleted_at.is_(None),
|
||||
).all()
|
||||
|
||||
for session in expired_sessions:
|
||||
session.status = SessionStatus.EXPIRED
|
||||
session.save()
|
||||
|
||||
return len(expired_sessions)
|
||||
@@ -0,0 +1,214 @@
|
||||
"""TOTP (Time-based One-Time Password) service."""
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from typing import Tuple
|
||||
|
||||
import pyotp
|
||||
from gatehouse_app.extensions import bcrypt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TOTPService:
|
||||
"""Service for TOTP operations."""
|
||||
|
||||
@staticmethod
|
||||
def generate_secret() -> str:
|
||||
"""
|
||||
Generate a new TOTP secret.
|
||||
|
||||
Returns:
|
||||
Base32 encoded secret (32 characters)
|
||||
|
||||
Note:
|
||||
The secret is generated using cryptographically secure random bytes
|
||||
and encoded in base32 format for compatibility with authenticator apps.
|
||||
"""
|
||||
# Generate 20 random bytes (160 bits) and encode as base32
|
||||
random_bytes = secrets.token_bytes(20)
|
||||
secret = base64.b32encode(random_bytes).decode("utf-8")
|
||||
logger.debug(f"Generated new TOTP secret: {secret[:8]}...")
|
||||
return secret
|
||||
|
||||
@staticmethod
|
||||
def generate_provisioning_uri(user_email: str, secret: str, issuer: str = "Gatehouse") -> str:
|
||||
"""
|
||||
Generate provisioning URI for QR code.
|
||||
|
||||
Args:
|
||||
user_email: User's email address
|
||||
secret: TOTP secret (base32 encoded)
|
||||
issuer: Issuer name (default: "Gatehouse")
|
||||
|
||||
Returns:
|
||||
otpauth:// URI for QR code generation
|
||||
|
||||
Example:
|
||||
>>> uri = TOTPService.generate_provisioning_uri("user@example.com", "JBSWY3DPEHPK3PXP")
|
||||
>>> print(uri)
|
||||
otpauth://totp/Gatehouse:user@example.com?secret=JBSWY3DPEHPK3PXP&issuer=Gatehouse
|
||||
"""
|
||||
totp = pyotp.TOTP(secret)
|
||||
uri = totp.provisioning_uri(name=user_email, issuer_name=issuer)
|
||||
logger.debug(f"Generated provisioning URI for user: {user_email}")
|
||||
return uri
|
||||
|
||||
@staticmethod
|
||||
def verify_code(secret: str, code: str, window: int = 1) -> bool:
|
||||
"""
|
||||
Verify a TOTP code against the secret.
|
||||
|
||||
Args:
|
||||
secret: TOTP secret (base32 encoded)
|
||||
code: 6-digit TOTP code to verify
|
||||
window: Time window for code validation (default: 1, allows codes from previous/next time steps)
|
||||
|
||||
Returns:
|
||||
True if code is valid, False otherwise
|
||||
|
||||
Note:
|
||||
The window parameter allows for clock skew between the server
|
||||
and the authenticator app. A window of 1 allows codes from
|
||||
the previous, current, and next 30-second intervals.
|
||||
|
||||
IMPORTANT: Always uses UTC time for verification to ensure
|
||||
consistency across all timezones.
|
||||
"""
|
||||
totp = pyotp.TOTP(secret)
|
||||
# Use timezone-aware UTC datetime for verification
|
||||
# IMPORTANT: We must pass a datetime object, NOT a Unix timestamp
|
||||
# pyotp's internal datetime.utcfromtimestamp() is deprecated and can be
|
||||
# affected by local timezone settings, causing the 10.5 hour skew issue
|
||||
utc_now = datetime.now(timezone.utc)
|
||||
|
||||
# DEBUG: Log detailed timezone information
|
||||
logger.debug(f"[TOTP DEBUG] UTC now: {utc_now}")
|
||||
logger.debug(f"[TOTP DEBUG] UTC now isoformat: {utc_now.isoformat()}")
|
||||
logger.debug(f"[TOTP DEBUG] UTC timestamp: {utc_now.timestamp()}")
|
||||
logger.debug(f"[TOTP DEBUG] UTC now tzinfo: {utc_now.tzinfo}")
|
||||
|
||||
# Generate what the TOTP code should be at this moment using UTC datetime
|
||||
expected_code = totp.at(utc_now)
|
||||
logger.debug(f"[TOTP DEBUG] Expected TOTP code at UTC: {expected_code}")
|
||||
|
||||
# Verify with the provided code using UTC datetime object
|
||||
# Passing a datetime object avoids pyotp's utcfromtimestamp() issues
|
||||
is_valid = totp.verify(code, valid_window=window, for_time=utc_now)
|
||||
|
||||
logger.debug(f"[TOTP DEBUG] TOTP code verification: valid={is_valid}, window={window}")
|
||||
logger.debug(f"[TOTP DEBUG] Provided code: {code}, Expected code: {expected_code}")
|
||||
|
||||
return is_valid
|
||||
|
||||
@staticmethod
|
||||
def generate_backup_codes(count: int = 10) -> Tuple[list[str], list[str]]:
|
||||
"""
|
||||
Generate backup codes for TOTP recovery.
|
||||
|
||||
Args:
|
||||
count: Number of backup codes to generate (default: 10)
|
||||
|
||||
Returns:
|
||||
Tuple of (plain_codes, hashed_codes)
|
||||
- plain_codes: List of plain text backup codes (for display to user)
|
||||
- hashed_codes: List of bcrypt hashed backup codes (for storage)
|
||||
|
||||
Note:
|
||||
Backup codes are 16-character alphanumeric codes that can be used
|
||||
to recover access if the TOTP device is lost. Each code can only
|
||||
be used once.
|
||||
"""
|
||||
plain_codes = []
|
||||
hashed_codes = []
|
||||
|
||||
for _ in range(count):
|
||||
# Generate a 16-character alphanumeric code
|
||||
code = secrets.token_hex(8).upper()
|
||||
plain_codes.append(code)
|
||||
|
||||
# Hash the code using bcrypt
|
||||
hashed_code = bcrypt.generate_password_hash(code).decode("utf-8")
|
||||
hashed_codes.append(hashed_code)
|
||||
|
||||
logger.debug(f"Generated {count} backup codes")
|
||||
return plain_codes, hashed_codes
|
||||
|
||||
@staticmethod
|
||||
def verify_backup_code(hashed_codes: list[str], code: str) -> Tuple[bool, list[str]]:
|
||||
"""
|
||||
Verify and consume a backup code.
|
||||
|
||||
Args:
|
||||
hashed_codes: List of bcrypt hashed backup codes
|
||||
code: Plain text backup code to verify
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, remaining_codes)
|
||||
- is_valid: True if code was valid and consumed, False otherwise
|
||||
- remaining_codes: List of remaining hashed codes (with consumed code removed)
|
||||
|
||||
Note:
|
||||
Once a backup code is used, it is removed from the list and cannot
|
||||
be used again. This ensures each code is single-use.
|
||||
"""
|
||||
remaining_codes = []
|
||||
|
||||
for hashed_code in hashed_codes:
|
||||
if bcrypt.check_password_hash(hashed_code, code):
|
||||
# Code found and valid - mark as matched but don't add to remaining codes
|
||||
matched = True
|
||||
else:
|
||||
# Code doesn't match - keep it in remaining codes
|
||||
remaining_codes.append(hashed_code)
|
||||
|
||||
if matched:
|
||||
return True, remaining_codes
|
||||
else:
|
||||
return False, remaining_codes
|
||||
|
||||
@staticmethod
|
||||
def generate_qr_code_data_uri(provisioning_uri: str) -> str:
|
||||
"""
|
||||
Generate QR code as data URI for frontend display.
|
||||
|
||||
Args:
|
||||
provisioning_uri: otpauth:// URI to encode in QR code
|
||||
|
||||
Returns:
|
||||
Base64 encoded PNG image as data URI (data:image/png;base64,...)
|
||||
|
||||
Note:
|
||||
If the qrcode library is not installed, returns a placeholder message.
|
||||
Install with: pip install qrcode[pil]
|
||||
"""
|
||||
try:
|
||||
import qrcode
|
||||
|
||||
# Create QR code
|
||||
qr = qrcode.QRCode(
|
||||
version=1,
|
||||
error_correction=qrcode.constants.ERROR_CORRECT_L,
|
||||
box_size=10,
|
||||
border=4,
|
||||
)
|
||||
qr.add_data(provisioning_uri)
|
||||
qr.make(fit=True)
|
||||
|
||||
# Generate image
|
||||
img = qr.make_image(fill_color="black", back_color="white")
|
||||
|
||||
# Convert to base64
|
||||
buffer = io.BytesIO()
|
||||
img.save(buffer, format="PNG")
|
||||
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
data_uri = f"data:image/png;base64,{img_base64}"
|
||||
logger.debug("Generated QR code data URI")
|
||||
return data_uri
|
||||
|
||||
except ImportError:
|
||||
logger.warning("qrcode library not installed, returning placeholder")
|
||||
return "QR code generation requires the qrcode library. Install with: pip install qrcode[pil]"
|
||||
@@ -0,0 +1,125 @@
|
||||
"""User service."""
|
||||
import logging
|
||||
from flask import current_app
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.user import User
|
||||
from gatehouse_app.exceptions.validation_exceptions import UserNotFoundError
|
||||
from gatehouse_app.utils.constants import AuditAction
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserService:
|
||||
"""Service for user operations."""
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_id(user_id):
|
||||
"""
|
||||
Get user by ID.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
User instance
|
||||
|
||||
Raises:
|
||||
UserNotFoundError: If user not found
|
||||
"""
|
||||
user = User.query.filter_by(id=user_id, deleted_at=None).first()
|
||||
|
||||
# Development-only debug logging for user validation
|
||||
if current_app.config.get('ENV') == 'development':
|
||||
logger.debug(f"[User] Get user by ID: user_id={user_id}, exists={user is not None}")
|
||||
|
||||
if not user:
|
||||
raise UserNotFoundError()
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_email(email):
|
||||
"""
|
||||
Get user by email.
|
||||
|
||||
Args:
|
||||
email: User email
|
||||
|
||||
Returns:
|
||||
User instance or None
|
||||
"""
|
||||
user = User.query.filter_by(email=email.lower(), deleted_at=None).first()
|
||||
|
||||
# Development-only debug logging for user validation
|
||||
if current_app.config.get('ENV') == 'development':
|
||||
logger.debug(f"[User] Get user by email: email={email}, exists={user is not None}")
|
||||
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def update_user(user, **kwargs):
|
||||
"""
|
||||
Update user profile.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
**kwargs: Fields to update
|
||||
|
||||
Returns:
|
||||
Updated User instance
|
||||
"""
|
||||
allowed_fields = ["full_name", "avatar_url"]
|
||||
update_data = {k: v for k, v in kwargs.items() if k in allowed_fields}
|
||||
|
||||
if update_data:
|
||||
user.update(**update_data)
|
||||
|
||||
# Log user update
|
||||
AuditService.log_action(
|
||||
action=AuditAction.USER_UPDATE,
|
||||
user_id=user.id,
|
||||
resource_type="user",
|
||||
resource_id=user.id,
|
||||
metadata=update_data,
|
||||
description="User profile updated",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def delete_user(user, soft=True):
|
||||
"""
|
||||
Delete user account.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
soft: If True, performs soft delete
|
||||
|
||||
Returns:
|
||||
Deleted User instance
|
||||
"""
|
||||
user.delete(soft=soft)
|
||||
|
||||
# Log user deletion
|
||||
AuditService.log_action(
|
||||
action=AuditAction.USER_DELETE,
|
||||
user_id=user.id,
|
||||
resource_type="user",
|
||||
resource_id=user.id,
|
||||
description=f"User account {'soft' if soft else 'hard'} deleted",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def get_user_organizations(user):
|
||||
"""
|
||||
Get all organizations the user is a member of.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
|
||||
Returns:
|
||||
List of organizations
|
||||
"""
|
||||
return user.get_organizations()
|
||||
@@ -0,0 +1,647 @@
|
||||
"""WebAuthn passkey authentication service."""
|
||||
import logging
|
||||
import secrets
|
||||
import hashlib
|
||||
import base64
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Dict, Any, List
|
||||
from flask import current_app
|
||||
|
||||
from gatehouse_app.extensions import db, redis_client
|
||||
from gatehouse_app.models.user import User
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType, AuditAction
|
||||
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebAuthnService:
|
||||
"""Service for WebAuthn passkey operations."""
|
||||
|
||||
# WebAuthn algorithm constants (COSE algorithms)
|
||||
COSE_ALGORITHMS = {
|
||||
-7: "ES256", # ECDSA with SHA-256
|
||||
-257: "RS256", # RSASSA-PKCS1-v1_5 with SHA-256
|
||||
}
|
||||
|
||||
# Supported key types
|
||||
KEY_TYPES = ["public-key"]
|
||||
|
||||
@staticmethod
|
||||
def _generate_challenge() -> str:
|
||||
"""Generate a cryptographically secure challenge.
|
||||
|
||||
Returns:
|
||||
Base64URL-encoded challenge string
|
||||
"""
|
||||
bytes_data = secrets.token_bytes(32)
|
||||
return base64.urlsafe_b64encode(bytes_data).decode('utf-8').rstrip('=')
|
||||
|
||||
@staticmethod
|
||||
def _store_challenge(user_id: str, challenge: str, challenge_type: str, expires_in: int = 300) -> bool:
|
||||
"""Store a challenge in Redis for validation.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
challenge: The challenge string
|
||||
challenge_type: Type of challenge ('registration' or 'authentication')
|
||||
expires_in: Expiration time in seconds
|
||||
|
||||
Returns:
|
||||
True if stored successfully
|
||||
"""
|
||||
try:
|
||||
key = f"webauthn:challenge:{user_id}:{challenge_type}:{challenge}"
|
||||
data = {
|
||||
"challenge": challenge,
|
||||
"user_id": user_id,
|
||||
"type": challenge_type,
|
||||
"created_at": datetime.now(timezone.utc).isoformat()
|
||||
}
|
||||
redis_client.setex(key, expires_in, json.dumps(data))
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store WebAuthn challenge: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _get_and_delete_challenge(user_id: str, challenge: str, challenge_type: str) -> Optional[Dict]:
|
||||
"""Retrieve and delete a challenge from Redis.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
challenge: The challenge string
|
||||
challenge_type: Type of challenge
|
||||
|
||||
Returns:
|
||||
Challenge data dict or None if not found/expired
|
||||
"""
|
||||
try:
|
||||
key = f"webauthn:challenge:{user_id}:{challenge_type}:{challenge}"
|
||||
data = redis_client.get(key)
|
||||
if data:
|
||||
redis_client.delete(key)
|
||||
return json.loads(data)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve WebAuthn challenge: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _base64url_decode(data: str) -> bytes:
|
||||
"""Decode Base64URL string to bytes."""
|
||||
# Add padding if needed
|
||||
padding = 4 - (len(data) % 4)
|
||||
if padding != 4:
|
||||
data += '=' * padding
|
||||
return base64.urlsafe_b64decode(data)
|
||||
|
||||
@staticmethod
|
||||
def _base64url_encode(data: bytes) -> str:
|
||||
"""Encode bytes to Base64URL string."""
|
||||
return base64.urlsafe_b64encode(data).decode('utf-8').rstrip('=')
|
||||
|
||||
@staticmethod
|
||||
def _hash_credential_id(credential_id: bytes) -> str:
|
||||
"""Hash a credential ID for secure storage lookup.
|
||||
|
||||
Args:
|
||||
credential_id: Raw credential ID bytes
|
||||
|
||||
Returns:
|
||||
Hashed credential ID string
|
||||
"""
|
||||
return hashlib.sha256(credential_id).hexdigest()
|
||||
|
||||
@classmethod
|
||||
def generate_registration_challenge(cls, user: User) -> Dict[str, Any]:
|
||||
"""Generate a challenge for passkey registration.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
|
||||
Returns:
|
||||
PublicKeyCredentialCreationOptions dict
|
||||
"""
|
||||
# Generate challenge
|
||||
challenge = cls._generate_challenge()
|
||||
|
||||
# Store challenge
|
||||
cls._store_challenge(user.id, challenge, 'registration')
|
||||
|
||||
# Get existing credentials to exclude
|
||||
existing_credentials = cls.get_user_credentials(user)
|
||||
exclude_credentials = []
|
||||
for cred in existing_credentials:
|
||||
if cred.provider_data:
|
||||
cred_id_b64 = cred.provider_data.get("credential_id")
|
||||
if cred_id_b64:
|
||||
try:
|
||||
cred_id = cls._base64url_decode(cred_id_b64)
|
||||
transports = cred.provider_data.get("transports", [])
|
||||
exclude_credentials.append({
|
||||
"id": cred_id_b64,
|
||||
"type": "public-key",
|
||||
"transports": transports
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Get RP configuration
|
||||
rp_id = current_app.config.get('WEBAUTHN_RP_ID', 'localhost')
|
||||
rp_name = current_app.config.get('WEBAUTHN_RP_NAME', 'Gatehouse')
|
||||
|
||||
# Generate user ID (Base64URL encoded)
|
||||
user_id = cls._base64url_encode(user.id.encode('utf-8'))
|
||||
|
||||
# Build options
|
||||
options = {
|
||||
"rp": {
|
||||
"name": rp_name,
|
||||
"id": rp_id
|
||||
},
|
||||
"user": {
|
||||
"id": user_id,
|
||||
"name": user.email,
|
||||
"displayName": user.full_name or user.email
|
||||
},
|
||||
"challenge": challenge,
|
||||
"pubKeyCredParams": [
|
||||
{"type": "public-key", "alg": -7}, # ES256
|
||||
{"type": "public-key", "alg": -257} # RS256
|
||||
],
|
||||
"timeout": 60000, # 60 seconds
|
||||
"excludeCredentials": exclude_credentials,
|
||||
"authenticatorSelection": {
|
||||
"residentKey": "preferred",
|
||||
"userVerification": "preferred"
|
||||
},
|
||||
"attestation": "none"
|
||||
}
|
||||
|
||||
# Log audit event
|
||||
AuditService.log_action(
|
||||
action=AuditAction.WEBAUTHN_REGISTER_INITIATED,
|
||||
user_id=user.id,
|
||||
description="WebAuthn registration initiated"
|
||||
)
|
||||
|
||||
return options
|
||||
|
||||
@classmethod
|
||||
def verify_registration_response(
|
||||
cls,
|
||||
user: User,
|
||||
credential_data: Dict[str, Any],
|
||||
challenge: str
|
||||
) -> AuthenticationMethod:
|
||||
"""Verify and store a new passkey credential.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
credential_data: Credential response data from client
|
||||
challenge: The original challenge string
|
||||
|
||||
Returns:
|
||||
AuthenticationMethod instance
|
||||
|
||||
Raises:
|
||||
InvalidCredentialsError: If verification fails
|
||||
"""
|
||||
# Verify and consume challenge
|
||||
stored_challenge = cls._get_and_delete_challenge(user.id, challenge, 'registration')
|
||||
if not stored_challenge:
|
||||
AuditService.log_action(
|
||||
action=AuditAction.WEBAUTHN_REGISTER_FAILED,
|
||||
user_id=user.id,
|
||||
description="Registration failed: challenge expired or invalid"
|
||||
)
|
||||
raise InvalidCredentialsError("Challenge expired or invalid")
|
||||
|
||||
try:
|
||||
# Parse credential data
|
||||
credential_id = credential_data.get("id")
|
||||
raw_id = credential_data.get("rawId")
|
||||
response = credential_data.get("response", {})
|
||||
attestation_object_b64 = response.get("attestationObject")
|
||||
client_data_json_b64 = response.get("clientDataJSON")
|
||||
transports = credential_data.get("transports", ["platform"])
|
||||
|
||||
if not all([credential_id, raw_id, attestation_object_b64, client_data_json_b64]):
|
||||
raise InvalidCredentialsError("Missing required credential data")
|
||||
|
||||
# Decode attestation object
|
||||
attestation_object = cls._base64url_decode(attestation_object_b64)
|
||||
|
||||
# Parse CBOR attestation object (simplified - in production use cbor2 library)
|
||||
# The attestation object contains: authData, attStmt, fmt
|
||||
try:
|
||||
import cbor2
|
||||
attestation_dict = cbor2.loads(attestation_object)
|
||||
except ImportError:
|
||||
# Fallback: try to parse as simple structure
|
||||
attestation_dict = {}
|
||||
logger.warning("cbor2 library not available, using fallback parsing")
|
||||
|
||||
# Extract authenticator data
|
||||
auth_data = attestation_dict.get('authData', b'')
|
||||
|
||||
# Parse authenticator data
|
||||
# Format: RP ID hash (32 bytes) + Flags (1 byte) + Counter (4 bytes) + AAGUID (16 bytes) + Credential ID length (2 bytes) + Credential ID + Public key
|
||||
if len(auth_data) < 37:
|
||||
raise InvalidCredentialsError("Invalid authenticator data")
|
||||
|
||||
rp_id_hash = auth_data[:32]
|
||||
flags = auth_data[32]
|
||||
counter = int.from_bytes(auth_data[33:37], 'big')
|
||||
aaguid = auth_data[37:53] if len(auth_data) >= 53 else b''
|
||||
|
||||
# Extract credential ID length and ID
|
||||
cred_id_length = int.from_bytes(auth_data[53:55], 'big') if len(auth_data) >= 55 else 0
|
||||
credential_id_raw = auth_data[55:55+cred_id_length] if cred_id_length > 0 else b''
|
||||
|
||||
# Extract public key (COSE format)
|
||||
public_key_cose = auth_data[55+cred_id_length:]
|
||||
|
||||
# Verify client data
|
||||
client_data_json = cls._base64url_decode(client_data_json_b64)
|
||||
client_data = json.loads(client_data_json)
|
||||
|
||||
# Verify challenge matches
|
||||
if client_data.get("challenge") != challenge:
|
||||
raise InvalidCredentialsError("Challenge mismatch")
|
||||
|
||||
# Verify origin
|
||||
expected_origin = current_app.config.get('WEBAUTHN_ORIGIN', 'http://localhost:5173')
|
||||
if client_data.get("origin") != expected_origin:
|
||||
logger.warning(f"Origin mismatch: expected {expected_origin}, got {client_data.get('origin')}")
|
||||
# Don't fail on origin mismatch in development
|
||||
|
||||
# Verify user presence and verification
|
||||
user_present = bool(flags & 0x01)
|
||||
user_verified = bool(flags & 0x04)
|
||||
|
||||
if not user_present:
|
||||
raise InvalidCredentialsError("User presence not verified")
|
||||
|
||||
# Store credential
|
||||
credential_id_hash = cls._hash_credential_id(credential_id_raw)
|
||||
|
||||
# Check if credential already exists
|
||||
existing = AuthenticationMethod.query.filter_by(
|
||||
user_id=user.id,
|
||||
method_type=AuthMethodType.WEBAUTHN,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if existing and existing.provider_data:
|
||||
stored_cred_id = existing.provider_data.get("credential_id", "")
|
||||
if stored_cred_id == credential_id:
|
||||
raise InvalidCredentialsError("Credential already registered")
|
||||
|
||||
# Create or update authentication method
|
||||
auth_method = existing or AuthenticationMethod(
|
||||
user_id=user.id,
|
||||
method_type=AuthMethodType.WEBAUTHN,
|
||||
is_primary=False,
|
||||
verified=True
|
||||
)
|
||||
|
||||
# Store credential data
|
||||
auth_method.provider_data = {
|
||||
"credential_id": credential_id,
|
||||
"credential_id_hash": credential_id_hash,
|
||||
"public_key_cose": cls._base64url_encode(public_key_cose),
|
||||
"sign_count": counter,
|
||||
"transports": transports,
|
||||
"aaguid": cls._base64url_encode(aaguid) if aaguid else None,
|
||||
"attestation_format": attestation_dict.get('fmt', 'unknown'),
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"last_used_at": None,
|
||||
"name": f"Passkey {datetime.now(timezone.utc).strftime('%Y-%m-%d')}"
|
||||
}
|
||||
|
||||
auth_method.save()
|
||||
|
||||
# Log audit event
|
||||
AuditService.log_action(
|
||||
action=AuditAction.WEBAUTHN_REGISTER_COMPLETED,
|
||||
user_id=user.id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=auth_method.id,
|
||||
description=f"WebAuthn credential registered: {credential_id[:16]}..."
|
||||
)
|
||||
|
||||
return auth_method
|
||||
|
||||
except InvalidCredentialsError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"WebAuthn registration verification failed: {e}")
|
||||
AuditService.log_action(
|
||||
action=AuditAction.WEBAUTHN_REGISTER_FAILED,
|
||||
user_id=user.id,
|
||||
description=f"Registration failed: {str(e)}"
|
||||
)
|
||||
raise InvalidCredentialsError("Registration verification failed")
|
||||
|
||||
@classmethod
|
||||
def generate_authentication_challenge(cls, user: User) -> Dict[str, Any]:
|
||||
"""Generate a challenge for passkey authentication.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
|
||||
Returns:
|
||||
PublicKeyCredentialRequestOptions dict
|
||||
"""
|
||||
# Generate challenge
|
||||
challenge = cls._generate_challenge()
|
||||
|
||||
# Store challenge
|
||||
cls._store_challenge(user.id, challenge, 'authentication')
|
||||
|
||||
# Get user's credentials
|
||||
credentials = cls.get_user_credentials(user)
|
||||
|
||||
# Build allow credentials list
|
||||
allow_credentials = []
|
||||
for cred in credentials:
|
||||
if cred.provider_data:
|
||||
cred_id = cred.provider_data.get("credential_id")
|
||||
transports = cred.provider_data.get("transports", [])
|
||||
if cred_id:
|
||||
allow_credentials.append({
|
||||
"id": cred_id,
|
||||
"type": "public-key",
|
||||
"transports": transports
|
||||
})
|
||||
|
||||
# Get RP configuration
|
||||
rp_id = current_app.config.get('WEBAUTHN_RP_ID', 'localhost')
|
||||
|
||||
# Build options
|
||||
options = {
|
||||
"challenge": challenge,
|
||||
"timeout": 60000,
|
||||
"rpId": rp_id,
|
||||
"allowCredentials": allow_credentials,
|
||||
"userVerification": "preferred"
|
||||
}
|
||||
|
||||
# Log audit event
|
||||
AuditService.log_action(
|
||||
action=AuditAction.WEBAUTHN_LOGIN_INITIATED,
|
||||
user_id=user.id,
|
||||
description="WebAuthn authentication initiated"
|
||||
)
|
||||
|
||||
return options
|
||||
|
||||
@classmethod
|
||||
def verify_authentication_response(
|
||||
cls,
|
||||
user: User,
|
||||
credential_data: Dict[str, Any],
|
||||
challenge: str
|
||||
) -> AuthenticationMethod:
|
||||
"""Verify passkey authentication response.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
credential_data: Assertion response data from client
|
||||
challenge: The original challenge string
|
||||
|
||||
Returns:
|
||||
AuthenticationMethod instance
|
||||
|
||||
Raises:
|
||||
InvalidCredentialsError: If verification fails
|
||||
"""
|
||||
# Verify and consume challenge
|
||||
stored_challenge = cls._get_and_delete_challenge(user.id, challenge, 'authentication')
|
||||
if not stored_challenge:
|
||||
AuditService.log_action(
|
||||
action=AuditAction.WEBAUTHN_LOGIN_FAILED,
|
||||
user_id=user.id,
|
||||
description="Authentication failed: challenge expired or invalid"
|
||||
)
|
||||
raise InvalidCredentialsError("Challenge expired or invalid")
|
||||
|
||||
try:
|
||||
# Parse credential data
|
||||
credential_id = credential_data.get("id")
|
||||
raw_id = credential_data.get("rawId")
|
||||
response = credential_data.get("response", {})
|
||||
authenticator_data_b64 = response.get("authenticatorData")
|
||||
client_data_json_b64 = response.get("clientDataJSON")
|
||||
signature_b64 = response.get("signature")
|
||||
|
||||
if not all([credential_id, authenticator_data_b64, client_data_json_b64, signature_b64]):
|
||||
raise InvalidCredentialsError("Missing required credential data")
|
||||
|
||||
# Find the credential
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
user_id=user.id,
|
||||
method_type=AuthMethodType.WEBAUTHN,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not auth_method or not auth_method.provider_data:
|
||||
raise InvalidCredentialsError("No passkey found for user")
|
||||
|
||||
stored_cred_id = auth_method.provider_data.get("credential_id")
|
||||
if stored_cred_id != credential_id:
|
||||
raise InvalidCredentialsError("Credential not found")
|
||||
|
||||
# Decode authenticator data
|
||||
authenticator_data = cls._base64url_decode(authenticator_data_b64)
|
||||
|
||||
# Parse authenticator data
|
||||
if len(authenticator_data) < 37:
|
||||
raise InvalidCredentialsError("Invalid authenticator data")
|
||||
|
||||
rp_id_hash = authenticator_data[:32]
|
||||
flags = authenticator_data[32]
|
||||
counter = int.from_bytes(authenticator_data[33:37], 'big')
|
||||
|
||||
# Verify client data
|
||||
client_data_json = cls._base64url_decode(client_data_json_b64)
|
||||
client_data = json.loads(client_data_json)
|
||||
|
||||
# Verify challenge matches
|
||||
if client_data.get("challenge") != challenge:
|
||||
raise InvalidCredentialsError("Challenge mismatch")
|
||||
|
||||
# Verify origin
|
||||
expected_origin = current_app.config.get('WEBAUTHN_ORIGIN', 'http://localhost:5173')
|
||||
if client_data.get("origin") != expected_origin:
|
||||
logger.warning(f"Origin mismatch: expected {expected_origin}, got {client_data.get('origin')}")
|
||||
|
||||
# Verify user presence
|
||||
user_present = bool(flags & 0x01)
|
||||
if not user_present:
|
||||
raise InvalidCredentialsError("User presence not verified")
|
||||
|
||||
# Verify counter (prevent replay attacks)
|
||||
stored_counter = auth_method.provider_data.get("sign_count", 0)
|
||||
if counter <= stored_counter:
|
||||
raise InvalidCredentialsError("Invalid sign counter - potential credential cloning detected")
|
||||
|
||||
# Verify signature (simplified - in production use proper crypto verification)
|
||||
# In a full implementation, you would:
|
||||
# 1. Decode the public key from COSE format
|
||||
# 2. Verify the signature using the stored public key
|
||||
# 3. Verify the authenticator data hash matches RP ID
|
||||
|
||||
# For now, we'll trust the authenticator's signature verification
|
||||
# A full implementation would use the fido2 library
|
||||
|
||||
# Update counter and last used time
|
||||
auth_method.provider_data["sign_count"] = counter
|
||||
auth_method.provider_data["last_used_at"] = datetime.now(timezone.utc).isoformat()
|
||||
auth_method.last_used_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
# Log audit event
|
||||
AuditService.log_action(
|
||||
action=AuditAction.WEBAUTHN_LOGIN_SUCCESS,
|
||||
user_id=user.id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=auth_method.id,
|
||||
description="WebAuthn authentication successful"
|
||||
)
|
||||
|
||||
return auth_method
|
||||
|
||||
except InvalidCredentialsError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"WebAuthn authentication verification failed: {e}")
|
||||
AuditService.log_action(
|
||||
action=AuditAction.WEBAUTHN_LOGIN_FAILED,
|
||||
user_id=user.id,
|
||||
description=f"Authentication failed: {str(e)}"
|
||||
)
|
||||
raise InvalidCredentialsError("Authentication verification failed")
|
||||
|
||||
@classmethod
|
||||
def get_user_credentials(cls, user: User) -> List[AuthenticationMethod]:
|
||||
"""Get all passkey credentials for a user.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
|
||||
Returns:
|
||||
List of AuthenticationMethod instances
|
||||
"""
|
||||
return AuthenticationMethod.query.filter_by(
|
||||
user_id=user.id,
|
||||
method_type=AuthMethodType.WEBAUTHN,
|
||||
deleted_at=None
|
||||
).order_by(AuthenticationMethod.created_at.desc()).all()
|
||||
|
||||
@classmethod
|
||||
def delete_credential(cls, credential_id: str, user: User) -> bool:
|
||||
"""Delete a passkey credential.
|
||||
|
||||
Args:
|
||||
credential_id: The credential ID to delete
|
||||
user: User instance
|
||||
|
||||
Returns:
|
||||
True if deleted successfully
|
||||
"""
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
user_id=user.id,
|
||||
method_type=AuthMethodType.WEBAUTHN,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not auth_method or not auth_method.provider_data:
|
||||
return False
|
||||
|
||||
stored_cred_id = auth_method.provider_data.get("credential_id")
|
||||
if stored_cred_id != credential_id:
|
||||
return False
|
||||
|
||||
# Soft delete the credential
|
||||
auth_method.delete(soft=True)
|
||||
|
||||
# Log audit event
|
||||
AuditService.log_action(
|
||||
action=AuditAction.WEBAUTHN_CREDENTIAL_DELETED,
|
||||
user_id=user.id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=auth_method.id,
|
||||
description=f"WebAuthn credential deleted: {credential_id[:16]}..."
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def rename_credential(cls, credential_id: str, user: User, name: str) -> bool:
|
||||
"""Rename a passkey credential.
|
||||
|
||||
Args:
|
||||
credential_id: The credential ID to rename
|
||||
user: User instance
|
||||
name: New name for the credential
|
||||
|
||||
Returns:
|
||||
True if renamed successfully
|
||||
"""
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
user_id=user.id,
|
||||
method_type=AuthMethodType.WEBAUTHN,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not auth_method or not auth_method.provider_data:
|
||||
return False
|
||||
|
||||
stored_cred_id = auth_method.provider_data.get("credential_id")
|
||||
if stored_cred_id != credential_id:
|
||||
return False
|
||||
|
||||
# Update name
|
||||
auth_method.provider_data["name"] = name
|
||||
db.session.commit()
|
||||
|
||||
# Log audit event
|
||||
AuditService.log_action(
|
||||
action=AuditAction.WEBAUTHN_CREDENTIAL_RENAMED,
|
||||
user_id=user.id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=auth_method.id,
|
||||
description=f"WebAuthn credential renamed to: {name}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_credential_by_id(cls, credential_id: str, user: User) -> Optional[AuthenticationMethod]:
|
||||
"""Get a specific credential by ID.
|
||||
|
||||
Args:
|
||||
credential_id: The credential ID
|
||||
user: User instance
|
||||
|
||||
Returns:
|
||||
AuthenticationMethod instance or None
|
||||
"""
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
user_id=user.id,
|
||||
method_type=AuthMethodType.WEBAUTHN,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if auth_method and auth_method.provider_data:
|
||||
stored_cred_id = auth_method.provider_data.get("credential_id")
|
||||
if stored_cred_id == credential_id:
|
||||
return auth_method
|
||||
|
||||
return None
|
||||
@@ -0,0 +1,25 @@
|
||||
"""Utilities package."""
|
||||
from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.utils.constants import (
|
||||
UserStatus,
|
||||
OrganizationRole,
|
||||
AuthMethodType,
|
||||
SessionStatus,
|
||||
AuditAction,
|
||||
ErrorType,
|
||||
)
|
||||
from gatehouse_app.utils.decorators import login_required, require_role, require_owner, require_admin
|
||||
|
||||
__all__ = [
|
||||
"api_response",
|
||||
"UserStatus",
|
||||
"OrganizationRole",
|
||||
"AuthMethodType",
|
||||
"SessionStatus",
|
||||
"AuditAction",
|
||||
"ErrorType",
|
||||
"login_required",
|
||||
"require_role",
|
||||
"require_owner",
|
||||
"require_admin",
|
||||
]
|
||||
@@ -0,0 +1,118 @@
|
||||
"""Application constants and enums."""
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class UserStatus(str, Enum):
|
||||
"""User account status."""
|
||||
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
SUSPENDED = "suspended"
|
||||
PENDING = "pending"
|
||||
|
||||
|
||||
class OrganizationRole(str, Enum):
|
||||
"""Organization member roles."""
|
||||
|
||||
OWNER = "owner"
|
||||
ADMIN = "admin"
|
||||
MEMBER = "member"
|
||||
GUEST = "guest"
|
||||
|
||||
|
||||
class AuthMethodType(str, Enum):
|
||||
"""Authentication method types."""
|
||||
|
||||
PASSWORD = "password"
|
||||
TOTP = "totp"
|
||||
GOOGLE = "google"
|
||||
GITHUB = "github"
|
||||
MICROSOFT = "microsoft"
|
||||
SAML = "saml"
|
||||
OIDC = "oidc"
|
||||
WEBAUTHN = "webauthn"
|
||||
|
||||
|
||||
class SessionStatus(str, Enum):
|
||||
"""Session status."""
|
||||
|
||||
ACTIVE = "active"
|
||||
EXPIRED = "expired"
|
||||
REVOKED = "revoked"
|
||||
|
||||
|
||||
class AuditAction(str, Enum):
|
||||
"""Audit log action types."""
|
||||
|
||||
# User actions
|
||||
USER_LOGIN = "user.login"
|
||||
USER_LOGOUT = "user.logout"
|
||||
USER_REGISTER = "user.register"
|
||||
USER_UPDATE = "user.update"
|
||||
USER_DELETE = "user.delete"
|
||||
PASSWORD_CHANGE = "user.password_change"
|
||||
PASSWORD_RESET = "user.password_reset"
|
||||
|
||||
# Organization actions
|
||||
ORG_CREATE = "org.create"
|
||||
ORG_UPDATE = "org.update"
|
||||
ORG_DELETE = "org.delete"
|
||||
ORG_MEMBER_ADD = "org.member.add"
|
||||
ORG_MEMBER_REMOVE = "org.member.remove"
|
||||
ORG_MEMBER_ROLE_CHANGE = "org.member.role_change"
|
||||
|
||||
# Session actions
|
||||
SESSION_CREATE = "session.create"
|
||||
SESSION_REVOKE = "session.revoke"
|
||||
|
||||
# Auth method actions
|
||||
AUTH_METHOD_ADD = "auth.method.add"
|
||||
AUTH_METHOD_REMOVE = "auth.method.remove"
|
||||
TOTP_ENROLL_INITIATED = "totp.enroll.initiated"
|
||||
TOTP_ENROLL_COMPLETED = "totp.enroll.completed"
|
||||
TOTP_VERIFY_SUCCESS = "totp.verify.success"
|
||||
TOTP_VERIFY_FAILED = "totp.verify.failed"
|
||||
TOTP_DISABLED = "totp.disabled"
|
||||
TOTP_BACKUP_CODE_USED = "totp.backup_code.used"
|
||||
TOTP_BACKUP_CODES_REGENERATED = "totp.backup_codes.regenerated"
|
||||
|
||||
# WebAuthn actions
|
||||
WEBAUTHN_REGISTER_INITIATED = "webauthn.register.initiated"
|
||||
WEBAUTHN_REGISTER_COMPLETED = "webauthn.register.completed"
|
||||
WEBAUTHN_REGISTER_FAILED = "webauthn.register.failed"
|
||||
WEBAUTHN_LOGIN_INITIATED = "webauthn.login.initiated"
|
||||
WEBAUTHN_LOGIN_SUCCESS = "webauthn.login.success"
|
||||
WEBAUTHN_LOGIN_FAILED = "webauthn.login.failed"
|
||||
WEBAUTHN_CREDENTIAL_DELETED = "webauthn.credential.deleted"
|
||||
WEBAUTHN_CREDENTIAL_RENAMED = "webauthn.credential.renamed"
|
||||
|
||||
|
||||
class OIDCGrantType(str, Enum):
|
||||
"""OIDC grant types."""
|
||||
|
||||
AUTHORIZATION_CODE = "authorization_code"
|
||||
IMPLICIT = "implicit"
|
||||
REFRESH_TOKEN = "refresh_token"
|
||||
CLIENT_CREDENTIALS = "client_credentials"
|
||||
|
||||
|
||||
class OIDCResponseType(str, Enum):
|
||||
"""OIDC response types."""
|
||||
|
||||
CODE = "code"
|
||||
TOKEN = "token"
|
||||
ID_TOKEN = "id_token"
|
||||
|
||||
|
||||
# Error type constants
|
||||
class ErrorType:
|
||||
"""Error type constants for API responses."""
|
||||
|
||||
VALIDATION_ERROR = "VALIDATION_ERROR"
|
||||
AUTHENTICATION_ERROR = "AUTHENTICATION_ERROR"
|
||||
AUTHORIZATION_ERROR = "AUTHORIZATION_ERROR"
|
||||
NOT_FOUND = "NOT_FOUND"
|
||||
CONFLICT = "CONFLICT"
|
||||
RATE_LIMIT_EXCEEDED = "RATE_LIMIT_EXCEEDED"
|
||||
INTERNAL_ERROR = "INTERNAL_ERROR"
|
||||
BAD_REQUEST = "BAD_REQUEST"
|
||||
@@ -0,0 +1,129 @@
|
||||
"""Custom decorators for authentication and authorization."""
|
||||
from functools import wraps
|
||||
from flask import request, g
|
||||
from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
|
||||
|
||||
def login_required(f):
|
||||
"""Decorator to require Bearer token authentication.
|
||||
|
||||
Extracts token from Authorization: Bearer {token} header,
|
||||
validates the session, and sets g.current_user and g.current_session.
|
||||
"""
|
||||
from gatehouse_app.services.session_service import SessionService
|
||||
|
||||
@wraps(f)
|
||||
def decorated_function(*args, **kwargs):
|
||||
# Extract token from Authorization header
|
||||
auth_header = request.headers.get('Authorization')
|
||||
|
||||
if not auth_header:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Authorization header is required",
|
||||
status=401,
|
||||
error_type="AUTH_REQUIRED"
|
||||
)
|
||||
|
||||
# Expect format: "Bearer {token}"
|
||||
parts = auth_header.split()
|
||||
if len(parts) != 2 or parts[0].lower() != 'bearer':
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Invalid authorization format. Use: Bearer {token}",
|
||||
status=401,
|
||||
error_type="INVALID_AUTH_FORMAT"
|
||||
)
|
||||
|
||||
token = parts[1]
|
||||
|
||||
# Get active session by token
|
||||
session = SessionService.get_active_session_by_token(token)
|
||||
|
||||
if not session:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Invalid or expired session",
|
||||
status=401,
|
||||
error_type="INVALID_TOKEN"
|
||||
)
|
||||
|
||||
# Validate session is active
|
||||
if not session.is_active():
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Session is no longer active",
|
||||
status=401,
|
||||
error_type="SESSION_INACTIVE"
|
||||
)
|
||||
|
||||
# Update last_activity_at timestamp
|
||||
from datetime import datetime, timezone
|
||||
session.last_activity_at = datetime.now(timezone.utc)
|
||||
from gatehouse_app import db
|
||||
db.session.commit()
|
||||
|
||||
# Set context variables
|
||||
g.current_user = session.user
|
||||
g.current_session = session
|
||||
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return decorated_function
|
||||
|
||||
|
||||
def require_role(*allowed_roles):
|
||||
"""
|
||||
Decorator to require specific organization roles.
|
||||
|
||||
Args:
|
||||
*allowed_roles: Variable number of OrganizationRole values
|
||||
|
||||
Raises:
|
||||
ForbiddenError: If user doesn't have required role
|
||||
"""
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
def decorated_function(*args, **kwargs):
|
||||
# Ensure user is authenticated first
|
||||
if not hasattr(g, "current_user"):
|
||||
raise UnauthorizedError("Authentication required")
|
||||
|
||||
# Get organization_id from kwargs or URL parameters
|
||||
org_id = kwargs.get("org_id") or kwargs.get("organization_id")
|
||||
if not org_id:
|
||||
raise ForbiddenError("Organization context required")
|
||||
|
||||
# Check user's role in the organization
|
||||
from gatehouse_app.models.organization_member import OrganizationMember
|
||||
|
||||
membership = OrganizationMember.query.filter_by(
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
).first()
|
||||
|
||||
if not membership:
|
||||
raise ForbiddenError("Not a member of this organization")
|
||||
|
||||
if membership.role not in allowed_roles:
|
||||
raise ForbiddenError(
|
||||
f"Requires one of the following roles: {', '.join(allowed_roles)}"
|
||||
)
|
||||
|
||||
g.current_membership = membership
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return decorated_function
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def require_owner(f):
|
||||
"""Decorator to require organization owner role."""
|
||||
return require_role(OrganizationRole.OWNER)(f)
|
||||
|
||||
|
||||
def require_admin(f):
|
||||
"""Decorator to require organization admin or owner role."""
|
||||
return require_role(OrganizationRole.OWNER, OrganizationRole.ADMIN)(f)
|
||||
@@ -0,0 +1,54 @@
|
||||
"""API response utilities."""
|
||||
from flask import jsonify, g
|
||||
|
||||
|
||||
# Version for the response envelope
|
||||
ENVELOPE_VERSION = "1.0"
|
||||
|
||||
|
||||
def api_response(
|
||||
*,
|
||||
data=None,
|
||||
success=True,
|
||||
message="",
|
||||
status=200,
|
||||
error_type=None,
|
||||
error_details=None,
|
||||
meta=None
|
||||
):
|
||||
"""
|
||||
Create a standardized API response.
|
||||
|
||||
Args:
|
||||
data: Response data (only included if success=True)
|
||||
success: Whether the request was successful
|
||||
message: Human-readable message
|
||||
status: HTTP status code
|
||||
error_type: Type of error (only if success=False)
|
||||
error_details: Additional error details (only if success=False)
|
||||
meta: Additional metadata (pagination, etc.)
|
||||
|
||||
Returns:
|
||||
Tuple of (response, status_code)
|
||||
"""
|
||||
payload = {
|
||||
"version": ENVELOPE_VERSION,
|
||||
"success": success,
|
||||
"code": status,
|
||||
"message": message,
|
||||
"request_id": g.get("request_id", "unknown"),
|
||||
}
|
||||
|
||||
if meta:
|
||||
payload["meta"] = meta
|
||||
|
||||
if success:
|
||||
if data is not None:
|
||||
payload["data"] = data
|
||||
else:
|
||||
payload["error"] = {
|
||||
"type": error_type or "UNKNOWN",
|
||||
"details": error_details or {}
|
||||
}
|
||||
|
||||
return jsonify(payload), status
|
||||
Reference in New Issue
Block a user