Feat: Added CA-merged with Securid-Principals, Depart, Client-CLI
This commit is contained in:
@@ -2,6 +2,9 @@
|
||||
import os
|
||||
import logging
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(dotenv_path=os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), '.env'))
|
||||
|
||||
# Test debug logging - this should appear when running `flask run --debug`
|
||||
_root_logger = logging.getLogger(__name__)
|
||||
_root_logger.debug("[TEST] Debug logging is working!")
|
||||
@@ -239,3 +242,6 @@ def initialize_oidc_jwks(app):
|
||||
app.logger.info(f"[OIDC] Signing key initialized: kid={signing_key.kid}")
|
||||
except Exception as e:
|
||||
app.logger.error(f"[OIDC] Failed to initialize JWKS: {e}")
|
||||
|
||||
# Create default app instance for gunicorn/wsgi
|
||||
app = create_app()
|
||||
|
||||
@@ -5,4 +5,6 @@ from flask import Blueprint
|
||||
api_v1_bp = Blueprint("api_v1", __name__)
|
||||
|
||||
# Import route modules to register them
|
||||
from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth
|
||||
from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth, departments, principals, ssh
|
||||
|
||||
api_v1_bp.register_blueprint(ssh.ssh_bp)
|
||||
|
||||
@@ -46,6 +46,33 @@ def _pop_oidc_bridge(oauth_state: str) -> str | None:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _store_cli_redirect(oauth_state: str, redirect_url: str) -> None:
|
||||
"""Store CLI redirect_url keyed by OAuth state (for /token_please flow)."""
|
||||
try:
|
||||
import gatehouse_app.extensions as _ext
|
||||
rc = _ext.redis_client
|
||||
if rc is not None:
|
||||
rc.setex(f"oauth_cli_redirect:{oauth_state}", _OAUTH_BRIDGE_TTL, redirect_url)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _pop_cli_redirect(oauth_state: str) -> str | None:
|
||||
"""Retrieve and delete CLI redirect_url for the given OAuth state."""
|
||||
try:
|
||||
import gatehouse_app.extensions as _ext
|
||||
rc = _ext.redis_client
|
||||
if rc is not None:
|
||||
key = f"oauth_cli_redirect:{oauth_state}"
|
||||
val = rc.get(key)
|
||||
if val:
|
||||
rc.delete(key)
|
||||
return val.decode() if isinstance(val, bytes) else val
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -69,6 +96,71 @@ def get_provider_type(provider: str) -> AuthMethodType:
|
||||
return PROVIDER_TYPE_MAP[provider_lower]
|
||||
|
||||
|
||||
@api_v1_bp.route("/token_please", methods=["GET"])
|
||||
def token_please():
|
||||
"""
|
||||
CLI token acquisition endpoint.
|
||||
|
||||
Initiates an OAuth login flow and, on success, redirects the user's browser
|
||||
to the CLI's local callback server (redirect_url) with the session token
|
||||
appended, e.g.: http://127.0.0.1:8250/?token=<SESSION_TOKEN>
|
||||
|
||||
This endpoint is designed for CLI clients that:
|
||||
1. Start a local HTTP server on LISTENER_SERVER_PORT (e.g. 8250)
|
||||
2. Open a browser to /api/v1/token_please?redirect_url=http://127.0.0.1:8250/?token=
|
||||
3. Wait for the browser to POST the token back to their local server
|
||||
|
||||
Query parameters:
|
||||
redirect_url: Local callback URL where the token will be appended
|
||||
provider: OAuth provider to use (default: 'google')
|
||||
"""
|
||||
from urllib.parse import urlencode
|
||||
from flask import current_app, redirect as flask_redirect
|
||||
|
||||
redirect_url = request.args.get("redirect_url", "").strip()
|
||||
provider = request.args.get("provider", "google").lower()
|
||||
|
||||
if not redirect_url:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="redirect_url query parameter is required",
|
||||
status=400,
|
||||
error_type="MISSING_REDIRECT_URL",
|
||||
)
|
||||
|
||||
# Validate redirect_url is localhost/127.0.0.1 (security: prevent open redirect)
|
||||
from urllib.parse import urlparse as _urlparse
|
||||
parsed = _urlparse(redirect_url)
|
||||
if parsed.hostname not in ("localhost", "127.0.0.1"):
|
||||
return api_response(
|
||||
success=False,
|
||||
message="redirect_url must point to localhost",
|
||||
status=400,
|
||||
error_type="INVALID_REDIRECT_URL",
|
||||
)
|
||||
|
||||
try:
|
||||
provider_type = get_provider_type(provider)
|
||||
auth_url, state = OAuthFlowService.initiate_login_flow(
|
||||
provider_type=provider_type,
|
||||
organization_id=None,
|
||||
redirect_uri=None,
|
||||
)
|
||||
except (OAuthFlowError, ExternalAuthError) as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=getattr(e, "message", str(e)),
|
||||
status=getattr(e, "status_code", 400),
|
||||
error_type=getattr(e, "error_type", "OAUTH_ERROR"),
|
||||
)
|
||||
|
||||
# Store the CLI redirect URL so the callback can use it
|
||||
_store_cli_redirect(state, redirect_url)
|
||||
|
||||
logger.info(f"CLI token_please: provider={provider}, redirect_url={redirect_url}, redirecting to OAuth")
|
||||
return flask_redirect(auth_url, code=302)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Provider Configuration Endpoints (Admin)
|
||||
# =============================================================================
|
||||
@@ -575,8 +667,6 @@ def initiate_oauth_authorize(provider: str):
|
||||
"state": "state_token"
|
||||
}
|
||||
"""
|
||||
provider_type = get_provider_type(provider)
|
||||
|
||||
# Get query parameters - organization_id is now optional
|
||||
flow = request.args.get("flow", "login")
|
||||
redirect_uri = request.args.get("redirect_uri")
|
||||
@@ -592,7 +682,7 @@ def initiate_oauth_authorize(provider: str):
|
||||
)
|
||||
|
||||
try:
|
||||
# Initiate flow - organization_id is now optional
|
||||
provider_type = get_provider_type(provider)
|
||||
if flow == "login":
|
||||
auth_url, state = OAuthFlowService.initiate_login_flow(
|
||||
provider_type=provider_type,
|
||||
@@ -626,6 +716,13 @@ def initiate_oauth_authorize(provider: str):
|
||||
status=e.status_code,
|
||||
error_type=e.error_type,
|
||||
)
|
||||
except ExternalAuthError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=e.message,
|
||||
status=e.status_code,
|
||||
error_type=e.error_type,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/external/<provider>/callback", methods=["GET"])
|
||||
@@ -666,8 +763,19 @@ def handle_oauth_callback(provider: str):
|
||||
frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:8080")
|
||||
frontend_callback = f"{frontend_url}/oauth/callback"
|
||||
|
||||
# Check if this is a CLI /token_please flow — retrieve stored redirect_url
|
||||
cli_redirect_url = _pop_cli_redirect(state) if state else None
|
||||
|
||||
def redirect_error(message: str, error_type: str = "OAUTH_ERROR"):
|
||||
"""Redirect to frontend with error params."""
|
||||
"""Redirect to frontend (or CLI) with error params."""
|
||||
if cli_redirect_url:
|
||||
# CLI flow: return a plain error page instead of redirecting back
|
||||
from flask import make_response
|
||||
return make_response(
|
||||
f"<html><body><h2>Authentication Error</h2><p>{message}</p>"
|
||||
f"<p>You may close this window.</p></body></html>",
|
||||
400,
|
||||
)
|
||||
params = {"error": message, "error_type": error_type}
|
||||
if state:
|
||||
params["state"] = state
|
||||
@@ -706,8 +814,11 @@ def handle_oauth_callback(provider: str):
|
||||
# Recover oidc_session_id if this was triggered from an OIDC bridge flow
|
||||
oidc_session_id = _pop_oidc_bridge(state)
|
||||
|
||||
# Organization selection / creation flows are not supported in CLI mode
|
||||
# (fall through to token redirect with whatever session we have)
|
||||
|
||||
# Organization selection needed (user belongs to multiple orgs)
|
||||
if result.get("requires_org_selection"):
|
||||
if result.get("requires_org_selection") and not cli_redirect_url:
|
||||
import json
|
||||
orgs = json.dumps(result.get("available_organizations", []))
|
||||
params = {
|
||||
@@ -722,7 +833,7 @@ def handle_oauth_callback(provider: str):
|
||||
return flask_redirect(f"{frontend_callback}?{urlencode(params)}", code=302)
|
||||
|
||||
# Organization creation needed (new user via OAuth with no org)
|
||||
if result.get("requires_org_creation"):
|
||||
if result.get("requires_org_creation") and not cli_redirect_url:
|
||||
params = {
|
||||
"requires_org_creation": "1",
|
||||
"state": result["state"],
|
||||
@@ -751,6 +862,19 @@ def handle_oauth_callback(provider: str):
|
||||
user_info = result.get("user", {})
|
||||
if user_info.get("email"):
|
||||
params["email"] = user_info["email"]
|
||||
|
||||
# ── CLI /token_please flow: redirect to the CLI's local callback ─────
|
||||
if cli_redirect_url:
|
||||
# The CLI expects: http://127.0.0.1:8250/?token=<TOKEN>
|
||||
# cli_redirect_url already ends with "token=" so just append the value
|
||||
cli_final_url = cli_redirect_url + token
|
||||
logger.info(
|
||||
f"CLI token_please success: provider={provider}, user={user_info.get('email')}, "
|
||||
f"redirecting to CLI callback"
|
||||
)
|
||||
return flask_redirect(cli_final_url, code=302)
|
||||
|
||||
# ── Frontend flow ─────────────────────────────────────────────────────
|
||||
# Pass oidc_session_id through so the frontend can complete the OIDC flow
|
||||
if oidc_session_id:
|
||||
params["oidc_session_id"] = oidc_session_id
|
||||
|
||||
@@ -13,8 +13,6 @@ from gatehouse_app.schemas.organization_schema import (
|
||||
from gatehouse_app.services.organization_service import OrganizationService
|
||||
from gatehouse_app.services.user_service import UserService
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
|
||||
########jb- need to implement departs, principals
|
||||
@api_v1_bp.route("/organizations", methods=["POST"])
|
||||
@login_required
|
||||
@full_access_required
|
||||
@@ -378,3 +376,557 @@ def update_member_role(org_id, user_id):
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/audit-logs", methods=["GET"])
|
||||
@login_required
|
||||
@full_access_required
|
||||
def get_organization_audit_logs(org_id):
|
||||
"""
|
||||
Get audit logs for an organization.
|
||||
|
||||
Query params:
|
||||
page: Page number (default 1)
|
||||
per_page: Results per page (default 50, max 200)
|
||||
action: Filter by action type
|
||||
|
||||
Returns:
|
||||
200: List of audit log entries
|
||||
401: Not authenticated
|
||||
403: Not a member / insufficient permissions
|
||||
404: Organization not found
|
||||
"""
|
||||
from gatehouse_app.models.audit_log import AuditLog
|
||||
|
||||
# Ensure org exists and user is a member (full_access_required handles this)
|
||||
OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
page = int(request.args.get("page", 1))
|
||||
per_page = min(int(request.args.get("per_page", 50)), 200)
|
||||
action_filter = request.args.get("action")
|
||||
|
||||
query = AuditLog.query.filter_by(organization_id=org_id)
|
||||
if action_filter:
|
||||
query = query.filter_by(action=action_filter)
|
||||
|
||||
query = query.order_by(AuditLog.created_at.desc())
|
||||
total = query.count()
|
||||
logs = query.offset((page - 1) * per_page).limit(per_page).all()
|
||||
|
||||
def log_to_dict(log):
|
||||
return {
|
||||
"id": log.id,
|
||||
"action": log.action.value if log.action else None,
|
||||
"user_id": log.user_id,
|
||||
"user_email": log.user.email if log.user else None,
|
||||
"user": {"id": log.user.id, "email": log.user.email, "full_name": log.user.full_name} if log.user else None,
|
||||
"organization_id": log.organization_id,
|
||||
"resource_type": log.resource_type,
|
||||
"resource_id": log.resource_id,
|
||||
"ip_address": log.ip_address,
|
||||
"user_agent": log.user_agent,
|
||||
"request_id": log.request_id,
|
||||
"description": log.description,
|
||||
"success": log.success,
|
||||
"error_message": log.error_message,
|
||||
"metadata": log.extra_data,
|
||||
"created_at": log.created_at.isoformat() if log.created_at else None,
|
||||
"updated_at": log.updated_at.isoformat() if log.updated_at else None,
|
||||
}
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"audit_logs": [log_to_dict(log) for log in logs],
|
||||
"count": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": (total + per_page - 1) // per_page,
|
||||
},
|
||||
message="Audit logs retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Organization Invite Tokens
|
||||
# ============================================================================
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/invites", methods=["POST"])
|
||||
@login_required
|
||||
@require_admin
|
||||
def create_org_invite(org_id):
|
||||
"""Create an invite token for an organization.
|
||||
|
||||
Request body:
|
||||
email: Email address to invite
|
||||
role: Role to assign (default: member)
|
||||
|
||||
Returns:
|
||||
201: Invite created
|
||||
400: Validation error
|
||||
403: Not an admin
|
||||
404: Organization not found
|
||||
"""
|
||||
from gatehouse_app.models import OrgInviteToken, Organization
|
||||
from gatehouse_app.services.notification_service import NotificationService
|
||||
from flask import current_app
|
||||
|
||||
org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
|
||||
if not org:
|
||||
return api_response(success=False, message="Organization not found", status=404)
|
||||
|
||||
data = request.get_json() or {}
|
||||
email = (data.get("email") or "").strip().lower()
|
||||
role = (data.get("role") or "member").strip()
|
||||
|
||||
if not email:
|
||||
return api_response(success=False, message="Email is required", status=400, error_type="VALIDATION_ERROR")
|
||||
|
||||
invite = OrgInviteToken.generate(
|
||||
organization_id=org_id,
|
||||
email=email,
|
||||
role=role,
|
||||
invited_by_id=g.current_user.id,
|
||||
)
|
||||
|
||||
app_url = current_app.config.get("APP_URL", "http://localhost:8080")
|
||||
invite_link = f"{app_url}/invite?token={invite.token}"
|
||||
|
||||
NotificationService._send_email(
|
||||
to_address=email,
|
||||
subject=f"You're invited to join {org.name} on Gatehouse",
|
||||
body=(
|
||||
f"You've been invited to join {org.name} on Gatehouse.\n\n"
|
||||
f"Click the link below to accept the invitation (valid for 7 days):\n"
|
||||
f"{invite_link}\n\n"
|
||||
f"Gatehouse Security Team"
|
||||
),
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data={"invite": {"id": invite.id, "email": invite.email, "role": invite.role, "expires_at": invite.expires_at.isoformat() + "Z"}},
|
||||
message="Invite sent successfully",
|
||||
status=201,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/invites/<token>", methods=["GET"])
|
||||
def get_invite(token):
|
||||
"""Get invite details by token.
|
||||
|
||||
Returns:
|
||||
200: Invite details (org name, email)
|
||||
400: Invalid or expired token
|
||||
"""
|
||||
from gatehouse_app.models import OrgInviteToken
|
||||
|
||||
invite = OrgInviteToken.query.filter_by(token=token).first()
|
||||
if not invite or not invite.is_valid:
|
||||
return api_response(success=False, message="This invitation link is invalid or has expired.", status=400, error_type="INVALID_TOKEN")
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"email": invite.email,
|
||||
"organization": {"id": invite.organization_id, "name": invite.organization.name},
|
||||
"role": invite.role,
|
||||
},
|
||||
message="Invite found",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/invites/<token>/accept", methods=["POST"])
|
||||
def accept_invite(token):
|
||||
"""Accept an organization invite.
|
||||
|
||||
Creates the user account (if not already registered) and adds them
|
||||
to the organization.
|
||||
|
||||
Request body:
|
||||
full_name: User's display name
|
||||
password: Password for new account (if not already registered)
|
||||
password_confirm: Password confirmation
|
||||
|
||||
Returns:
|
||||
200: Invite accepted, returns user token
|
||||
400: Invalid/expired token or validation error
|
||||
409: Already a member
|
||||
"""
|
||||
from gatehouse_app.models import OrgInviteToken, User
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
from gatehouse_app.services.organization_service import OrganizationService
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
|
||||
invite = OrgInviteToken.query.filter_by(token=token).first()
|
||||
if not invite or not invite.is_valid:
|
||||
return api_response(success=False, message="This invitation link is invalid or has expired.", status=400, error_type="INVALID_TOKEN")
|
||||
|
||||
data = request.get_json() or {}
|
||||
full_name = data.get("full_name") or ""
|
||||
password = data.get("password") or ""
|
||||
password_confirm = data.get("password_confirm") or ""
|
||||
|
||||
user = User.query.filter_by(email=invite.email, deleted_at=None).first()
|
||||
|
||||
if not user:
|
||||
# Register a new user
|
||||
if not password:
|
||||
return api_response(success=False, message="Password is required for new accounts.", status=400, error_type="VALIDATION_ERROR")
|
||||
if password != password_confirm:
|
||||
return api_response(success=False, message="Passwords do not match.", status=400, error_type="VALIDATION_ERROR")
|
||||
if len(password) < 8:
|
||||
return api_response(success=False, message="Password must be at least 8 characters.", status=400, error_type="VALIDATION_ERROR")
|
||||
try:
|
||||
user = AuthService.register_user(email=invite.email, password=password, full_name=full_name or None)
|
||||
except Exception as exc:
|
||||
return api_response(success=False, message=str(exc), status=400, error_type="REGISTRATION_ERROR")
|
||||
|
||||
# Add to org
|
||||
role_value = invite.role
|
||||
try:
|
||||
org_role = OrganizationRole(role_value)
|
||||
except ValueError:
|
||||
org_role = OrganizationRole.MEMBER
|
||||
|
||||
try:
|
||||
OrganizationService.add_member(
|
||||
org=invite.organization,
|
||||
user_id=user.id,
|
||||
role=org_role,
|
||||
inviter_id=invite.invited_by_id,
|
||||
)
|
||||
except Exception:
|
||||
pass # Already a member is fine
|
||||
|
||||
invite.accept()
|
||||
|
||||
user_session = AuthService.create_session(user)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"user": user.to_dict(),
|
||||
"token": user_session.token,
|
||||
"expires_at": user_session.expires_at.isoformat() + "Z",
|
||||
},
|
||||
message="Invitation accepted. Welcome!",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Organization OIDC Clients
|
||||
# ============================================================================
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/clients", methods=["GET"])
|
||||
@login_required
|
||||
def list_org_clients(org_id):
|
||||
"""List OIDC clients for an organization.
|
||||
|
||||
Returns:
|
||||
200: List of OIDC clients
|
||||
403: Not a member
|
||||
404: Organization not found
|
||||
"""
|
||||
from gatehouse_app.models import OIDCClient, Organization
|
||||
|
||||
org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
|
||||
if not org:
|
||||
return api_response(success=False, message="Organization not found", status=404)
|
||||
|
||||
clients = OIDCClient.query.filter_by(organization_id=org_id, is_active=True).all()
|
||||
|
||||
def client_to_dict(c):
|
||||
return {
|
||||
"id": c.id,
|
||||
"name": c.name,
|
||||
"client_id": c.client_id,
|
||||
"redirect_uris": c.redirect_uris,
|
||||
"scopes": c.scopes,
|
||||
"grant_types": c.grant_types,
|
||||
"is_active": c.is_active,
|
||||
"created_at": c.created_at.isoformat() + "Z",
|
||||
}
|
||||
|
||||
return api_response(
|
||||
data={"clients": [client_to_dict(c) for c in clients], "count": len(clients)},
|
||||
message="Clients retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/clients", methods=["POST"])
|
||||
@login_required
|
||||
@require_admin
|
||||
def create_org_client(org_id):
|
||||
"""Create a new OIDC client for an organization.
|
||||
|
||||
Request body:
|
||||
name: Client name
|
||||
redirect_uris: List of allowed redirect URIs (newline or comma separated string)
|
||||
|
||||
Returns:
|
||||
201: Client created with client_id and client_secret
|
||||
403: Not an admin
|
||||
404: Organization not found
|
||||
"""
|
||||
import secrets as _secrets
|
||||
from gatehouse_app.extensions import bcrypt
|
||||
from gatehouse_app.models import OIDCClient, Organization
|
||||
|
||||
org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
|
||||
if not org:
|
||||
return api_response(success=False, message="Organization not found", status=404)
|
||||
|
||||
data = request.get_json() or {}
|
||||
name = (data.get("name") or "").strip()
|
||||
redirect_uris_raw = data.get("redirect_uris") or []
|
||||
|
||||
if not name:
|
||||
return api_response(success=False, message="Client name is required", status=400, error_type="VALIDATION_ERROR")
|
||||
|
||||
if isinstance(redirect_uris_raw, str):
|
||||
redirect_uris = [u.strip() for u in redirect_uris_raw.replace(",", "\n").splitlines() if u.strip()]
|
||||
else:
|
||||
redirect_uris = [u.strip() for u in redirect_uris_raw if isinstance(u, str) and u.strip()]
|
||||
|
||||
if not redirect_uris:
|
||||
return api_response(success=False, message="At least one redirect URI is required", status=400, error_type="VALIDATION_ERROR")
|
||||
|
||||
client_id = _secrets.token_hex(16)
|
||||
client_secret = _secrets.token_urlsafe(32)
|
||||
|
||||
client = OIDCClient(
|
||||
organization_id=org_id,
|
||||
name=name,
|
||||
client_id=client_id,
|
||||
client_secret_hash=bcrypt.generate_password_hash(client_secret).decode("utf-8"),
|
||||
redirect_uris=redirect_uris,
|
||||
grant_types=["authorization_code", "refresh_token"],
|
||||
response_types=["code"],
|
||||
scopes=["openid", "profile", "email"],
|
||||
is_active=True,
|
||||
is_confidential=True,
|
||||
)
|
||||
from gatehouse_app.extensions import db
|
||||
db.session.add(client)
|
||||
db.session.commit()
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"client": {
|
||||
"id": client.id,
|
||||
"name": client.name,
|
||||
"client_id": client.client_id,
|
||||
"client_secret": client_secret, # Only returned once
|
||||
"redirect_uris": client.redirect_uris,
|
||||
"scopes": client.scopes,
|
||||
"created_at": client.created_at.isoformat() + "Z",
|
||||
}
|
||||
},
|
||||
message="OIDC client created successfully",
|
||||
status=201,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/clients/<client_id>", methods=["DELETE"])
|
||||
@login_required
|
||||
@require_admin
|
||||
def delete_org_client(org_id, client_id):
|
||||
"""Deactivate an OIDC client.
|
||||
|
||||
Returns:
|
||||
200: Client deactivated
|
||||
403: Not an admin
|
||||
404: Client not found
|
||||
"""
|
||||
from gatehouse_app.models import OIDCClient
|
||||
from gatehouse_app.extensions import db
|
||||
|
||||
client = OIDCClient.query.filter_by(id=client_id, organization_id=org_id).first()
|
||||
if not client:
|
||||
return api_response(success=False, message="Client not found", status=404)
|
||||
|
||||
client.is_active = False
|
||||
db.session.commit()
|
||||
|
||||
return api_response(data={}, message="Client deactivated successfully")
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/members/<user_id>/send-mfa-reminder", methods=["POST"])
|
||||
@login_required
|
||||
@require_admin
|
||||
def send_mfa_reminder(org_id, user_id):
|
||||
"""Send an MFA reminder email to a specific member.
|
||||
|
||||
Returns:
|
||||
200: Reminder sent (or silently skipped if no deadline record)
|
||||
403: Not an admin
|
||||
404: Member not found
|
||||
"""
|
||||
from gatehouse_app.models import User, MfaPolicyCompliance, OrganizationSecurityPolicy
|
||||
from gatehouse_app.services.notification_service import NotificationService
|
||||
|
||||
user = User.query.filter_by(id=user_id, deleted_at=None).first()
|
||||
if not user:
|
||||
return api_response(success=False, message="User not found", status=404)
|
||||
|
||||
compliance = MfaPolicyCompliance.query.filter_by(
|
||||
user_id=user_id, organization_id=org_id
|
||||
).first()
|
||||
policy = OrganizationSecurityPolicy.query.filter_by(organization_id=org_id).first()
|
||||
|
||||
if compliance and policy and compliance.deadline_at:
|
||||
NotificationService.send_mfa_deadline_reminder(user, compliance, policy)
|
||||
else:
|
||||
# No compliance deadline — send a generic nudge
|
||||
NotificationService._send_email(
|
||||
to_address=user.email,
|
||||
subject="Reminder: Set up multi-factor authentication",
|
||||
body=(
|
||||
f"Hi {user.full_name or user.email},\n\n"
|
||||
"Your organization administrator has asked you to set up "
|
||||
"multi-factor authentication (MFA) on your Gatehouse account.\n\n"
|
||||
"Please log in and configure MFA as soon as possible.\n\n"
|
||||
"Gatehouse Security Team"
|
||||
),
|
||||
)
|
||||
|
||||
return api_response(data={}, message="Reminder sent successfully")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# System-wide Audit Log (admin view) + User self audit
|
||||
# =============================================================================
|
||||
|
||||
def _audit_log_to_dict(log):
|
||||
"""Serialize an AuditLog record to a dict."""
|
||||
return {
|
||||
"id": log.id,
|
||||
"action": log.action.value if log.action else None,
|
||||
"user_id": log.user_id,
|
||||
"user": (
|
||||
{"id": log.user.id, "email": log.user.email, "full_name": log.user.full_name}
|
||||
if log.user else None
|
||||
),
|
||||
"organization_id": log.organization_id,
|
||||
"resource_type": log.resource_type,
|
||||
"resource_id": log.resource_id,
|
||||
"ip_address": log.ip_address,
|
||||
"user_agent": log.user_agent,
|
||||
"request_id": log.request_id,
|
||||
"description": log.description,
|
||||
"success": log.success,
|
||||
"error_message": log.error_message,
|
||||
"metadata": log.extra_data,
|
||||
"created_at": log.created_at.isoformat() if log.created_at else None,
|
||||
"updated_at": log.updated_at.isoformat() if log.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
@api_v1_bp.route("/audit-logs", methods=["GET"])
|
||||
@login_required
|
||||
def get_system_audit_logs():
|
||||
"""
|
||||
Get all audit logs (system-wide). Any authenticated user can query
|
||||
their own logs; org owners/admins also see org-scoped logs; this
|
||||
endpoint returns ALL logs for users who own at least one org
|
||||
(acting as an admin view).
|
||||
|
||||
Query params:
|
||||
page – page number (default 1)
|
||||
per_page – results per page (default 50, max 200)
|
||||
action – filter by AuditAction value
|
||||
user_id – filter by user id
|
||||
resource_type – filter by resource type
|
||||
success – "true"/"false"
|
||||
q – free-text search on description
|
||||
"""
|
||||
from gatehouse_app.models.audit_log import AuditLog
|
||||
from gatehouse_app.models.organization_member import OrganizationMember
|
||||
|
||||
current_user = g.current_user
|
||||
page = max(1, int(request.args.get("page", 1)))
|
||||
per_page = min(int(request.args.get("per_page", 50)), 200)
|
||||
|
||||
# Check if the user is an owner of any org to grant admin-level access
|
||||
is_admin = OrganizationMember.query.filter_by(
|
||||
user_id=current_user.id, role="OWNER"
|
||||
).first() is not None
|
||||
|
||||
query = AuditLog.query
|
||||
|
||||
if not is_admin:
|
||||
# Non-admins can only see their own logs
|
||||
query = query.filter(AuditLog.user_id == current_user.id)
|
||||
|
||||
# Optional filters
|
||||
action_filter = request.args.get("action")
|
||||
if action_filter:
|
||||
query = query.filter(AuditLog.action == action_filter)
|
||||
|
||||
user_id_filter = request.args.get("user_id")
|
||||
if user_id_filter:
|
||||
query = query.filter(AuditLog.user_id == user_id_filter)
|
||||
|
||||
resource_type_filter = request.args.get("resource_type")
|
||||
if resource_type_filter:
|
||||
query = query.filter(AuditLog.resource_type == resource_type_filter)
|
||||
|
||||
success_filter = request.args.get("success")
|
||||
if success_filter is not None:
|
||||
query = query.filter(AuditLog.success == (success_filter.lower() == "true"))
|
||||
|
||||
q = request.args.get("q", "").strip()
|
||||
if q:
|
||||
query = query.filter(AuditLog.description.ilike(f"%{q}%"))
|
||||
|
||||
query = query.order_by(AuditLog.created_at.desc())
|
||||
total = query.count()
|
||||
logs = query.offset((page - 1) * per_page).limit(per_page).all()
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"audit_logs": [_audit_log_to_dict(log) for log in logs],
|
||||
"count": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": (total + per_page - 1) // per_page,
|
||||
"is_admin_view": is_admin,
|
||||
},
|
||||
message="Audit logs retrieved",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/audit-logs", methods=["GET"])
|
||||
@login_required
|
||||
def get_my_audit_logs():
|
||||
"""
|
||||
Get audit logs for the currently authenticated user only.
|
||||
|
||||
Query params:
|
||||
page – page number (default 1)
|
||||
per_page – results per page (default 50, max 200)
|
||||
action – filter by AuditAction value
|
||||
"""
|
||||
from gatehouse_app.models.audit_log import AuditLog
|
||||
|
||||
current_user = g.current_user
|
||||
page = max(1, int(request.args.get("page", 1)))
|
||||
per_page = min(int(request.args.get("per_page", 50)), 200)
|
||||
|
||||
query = AuditLog.query.filter(AuditLog.user_id == current_user.id)
|
||||
|
||||
action_filter = request.args.get("action")
|
||||
if action_filter:
|
||||
query = query.filter(AuditLog.action == action_filter)
|
||||
|
||||
query = query.order_by(AuditLog.created_at.desc())
|
||||
total = query.count()
|
||||
logs = query.offset((page - 1) * per_page).limit(per_page).all()
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"audit_logs": [_audit_log_to_dict(log) for log in logs],
|
||||
"count": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": (total + per_page - 1) // per_page,
|
||||
},
|
||||
message="Activity retrieved",
|
||||
)
|
||||
|
||||
@@ -0,0 +1,615 @@
|
||||
"""SSH Key and Certificate API routes."""
|
||||
from flask import Blueprint, request, jsonify, g
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from gatehouse_app.services.ssh_key_service import SSHKeyService
|
||||
from gatehouse_app.services.ssh_ca_signing_service import (
|
||||
SSHCASigningService,
|
||||
SSHCertificateSigningRequest,
|
||||
)
|
||||
from gatehouse_app.exceptions import (
|
||||
SSHKeyError,
|
||||
SSHKeyNotFoundError,
|
||||
SSHCertificateError,
|
||||
ValidationError,
|
||||
SSHKeyAlreadyExistsError,
|
||||
)
|
||||
from gatehouse_app.utils.constants import AuditAction
|
||||
from gatehouse_app.models import AuditLog
|
||||
from gatehouse_app.utils.decorators import login_required
|
||||
|
||||
ssh_bp = Blueprint('ssh', __name__, url_prefix='/ssh')
|
||||
ssh_key_service = SSHKeyService()
|
||||
ssh_ca_service = SSHCASigningService()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _get_org_ca_for_user(user):
|
||||
"""Return the active DB CA for the user's first org, or None."""
|
||||
try:
|
||||
from gatehouse_app.models.ca import CA
|
||||
org_ids = [m.organization_id for m in user.organization_memberships]
|
||||
if not org_ids:
|
||||
return None
|
||||
return CA.query.filter(
|
||||
CA.organization_id.in_(org_ids),
|
||||
CA.is_active == True, # noqa: E712
|
||||
).first()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _get_or_create_system_ca():
|
||||
"""
|
||||
Return a CA DB record representing the config-file CA.
|
||||
|
||||
This is used as the ``ca_id`` FK when persisting certificates that were
|
||||
signed by the globally-configured CA key (not an org-specific DB CA).
|
||||
The record is created on first use and has no ``organization_id``.
|
||||
"""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.ca import CA, KeyType
|
||||
from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config
|
||||
from gatehouse_app.utils.crypto import compute_ssh_fingerprint
|
||||
import os
|
||||
|
||||
try:
|
||||
existing = CA.query.filter_by(name="system-config-ca").first()
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
cfg = get_ssh_ca_config()
|
||||
key_path = cfg.get_str("ca_key_path", "").strip()
|
||||
pub_key_path = key_path + ".pub"
|
||||
|
||||
if not os.path.exists(pub_key_path):
|
||||
return None
|
||||
|
||||
with open(pub_key_path) as f:
|
||||
pub_key = f.read().strip()
|
||||
|
||||
# Load private key for the record (stored but not actually used for signing here)
|
||||
priv_key = ""
|
||||
if os.path.exists(key_path):
|
||||
with open(key_path) as f:
|
||||
priv_key = f.read()
|
||||
|
||||
fingerprint = compute_ssh_fingerprint(pub_key)
|
||||
|
||||
# Check by fingerprint in case it was created under a different name
|
||||
existing_by_fp = CA.query.filter_by(fingerprint=fingerprint).first()
|
||||
if existing_by_fp:
|
||||
return existing_by_fp
|
||||
|
||||
system_ca = CA(
|
||||
name="system-config-ca",
|
||||
description="Global CA loaded from etc/ssh_ca.conf (ca_key_path)",
|
||||
key_type=KeyType.ED25519,
|
||||
private_key=priv_key,
|
||||
public_key=pub_key,
|
||||
fingerprint=fingerprint,
|
||||
is_active=True,
|
||||
default_cert_validity_hours=24,
|
||||
max_cert_validity_hours=720,
|
||||
)
|
||||
# organization_id is nullable=False in schema — we need a dummy org or
|
||||
# need to allow NULL. Use None; the DB constraint will tell us quickly.
|
||||
# If the migration enforces NOT NULL we'll catch the error gracefully.
|
||||
db.session.add(system_ca)
|
||||
db.session.commit()
|
||||
return system_ca
|
||||
except Exception as exc:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(
|
||||
f"Could not upsert system-config-ca: {exc}"
|
||||
)
|
||||
try:
|
||||
db.session.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=None):
|
||||
"""Save a signed certificate to the ssh_certificates table.
|
||||
|
||||
Args:
|
||||
user_id: UUID of the user
|
||||
ssh_key_id: UUID of the SSH key that was signed
|
||||
ca: CA model instance (may be None — cert still returned but not persisted)
|
||||
signing_response: SSHCertificateSigningResponse
|
||||
request_ip: Client IP address
|
||||
|
||||
Returns:
|
||||
SSHCertificate instance or None if persistence failed
|
||||
"""
|
||||
if ca is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.ssh_certificate import SSHCertificate, CertificateStatus
|
||||
from gatehouse_app.models.ca import CertType
|
||||
|
||||
cert_record = SSHCertificate(
|
||||
ca_id=ca.id,
|
||||
user_id=user_id,
|
||||
ssh_key_id=ssh_key_id,
|
||||
certificate=signing_response.certificate,
|
||||
serial=signing_response.serial,
|
||||
key_id=str(ssh_key_id),
|
||||
cert_type=CertType.USER,
|
||||
principals=signing_response.principals,
|
||||
valid_after=signing_response.valid_after,
|
||||
valid_before=signing_response.valid_before,
|
||||
revoked=False,
|
||||
status=CertificateStatus.ISSUED,
|
||||
request_ip=request_ip,
|
||||
)
|
||||
db.session.add(cert_record)
|
||||
db.session.commit()
|
||||
return cert_record
|
||||
except Exception as exc:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(
|
||||
f"Failed to persist certificate to DB: {exc}"
|
||||
)
|
||||
try:
|
||||
from gatehouse_app.extensions import db as _db
|
||||
_db.session.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@ssh_bp.route('/keys', methods=['GET'])
|
||||
@login_required
|
||||
def list_ssh_keys():
|
||||
"""Get all SSH keys for current user."""
|
||||
user_id = g.current_user.id
|
||||
|
||||
keys = ssh_key_service.get_user_ssh_keys(user_id)
|
||||
return jsonify({
|
||||
'keys': [k.to_dict() for k in keys],
|
||||
'count': len(keys),
|
||||
}), 200
|
||||
|
||||
|
||||
@ssh_bp.route('/keys', methods=['POST'])
|
||||
@login_required
|
||||
def add_ssh_key():
|
||||
"""Add a new SSH public key for current user."""
|
||||
user_id = g.current_user.id
|
||||
|
||||
data = request.get_json()
|
||||
if not data:
|
||||
return jsonify({'error': 'No JSON data provided'}), 400
|
||||
|
||||
public_key = data.get('public_key') or data.get('key')
|
||||
description = data.get('description')
|
||||
|
||||
if not public_key:
|
||||
return jsonify({'error': 'public_key is required'}), 400
|
||||
|
||||
try:
|
||||
ssh_key = ssh_key_service.add_ssh_key(
|
||||
user_id=user_id,
|
||||
public_key=public_key,
|
||||
description=description,
|
||||
)
|
||||
|
||||
# Audit log
|
||||
AuditLog.log(
|
||||
action=AuditAction.SSH_KEY_ADDED,
|
||||
user_id=user_id,
|
||||
resource_type='SSHKey',
|
||||
resource_id=ssh_key.id,
|
||||
ip_address=request.remote_addr,
|
||||
)
|
||||
|
||||
return jsonify(ssh_key.to_dict()), 201
|
||||
|
||||
except SSHKeyAlreadyExistsError as e:
|
||||
return jsonify({'error': e.message, 'code': 'SSH_KEY_ALREADY_EXISTS'}), 409
|
||||
except IntegrityError:
|
||||
return jsonify({'error': 'SSH key already exists', 'code': 'SSH_KEY_ALREADY_EXISTS'}), 409
|
||||
except SSHKeyError as e:
|
||||
return jsonify({'error': str(e)}), 400
|
||||
except ValidationError as e:
|
||||
return jsonify({'error': str(e)}), 400
|
||||
|
||||
|
||||
@ssh_bp.route('/keys/<key_id>', methods=['GET'])
|
||||
@login_required
|
||||
def get_ssh_key(key_id):
|
||||
"""Get a specific SSH key."""
|
||||
user_id = g.current_user.id
|
||||
|
||||
try:
|
||||
ssh_key = ssh_key_service.get_ssh_key(key_id)
|
||||
|
||||
# Check ownership
|
||||
if ssh_key.user_id != user_id:
|
||||
return jsonify({'error': 'Forbidden'}), 403
|
||||
|
||||
return jsonify(ssh_key.to_dict()), 200
|
||||
|
||||
except SSHKeyNotFoundError:
|
||||
return jsonify({'error': 'SSH key not found'}), 404
|
||||
|
||||
|
||||
@ssh_bp.route('/keys/<key_id>', methods=['DELETE'])
|
||||
@login_required
|
||||
def delete_ssh_key(key_id):
|
||||
"""Delete an SSH key."""
|
||||
user_id = g.current_user.id
|
||||
|
||||
try:
|
||||
ssh_key = ssh_key_service.get_ssh_key(key_id)
|
||||
|
||||
# Check ownership
|
||||
if ssh_key.user_id != user_id:
|
||||
return jsonify({'error': 'Forbidden'}), 403
|
||||
|
||||
ssh_key_service.delete_ssh_key(key_id)
|
||||
|
||||
# Audit log
|
||||
AuditLog.log(
|
||||
action=AuditAction.SSH_KEY_DELETED,
|
||||
user_id=user_id,
|
||||
resource_type='SSHKey',
|
||||
resource_id=key_id,
|
||||
ip_address=request.remote_addr,
|
||||
)
|
||||
|
||||
return jsonify({'status': 'deleted'}), 200
|
||||
|
||||
except SSHKeyNotFoundError:
|
||||
return jsonify({'error': 'SSH key not found'}), 404
|
||||
|
||||
|
||||
@ssh_bp.route('/keys/<key_id>/verify', methods=['GET', 'POST'])
|
||||
@login_required
|
||||
def verify_ssh_key(key_id):
|
||||
"""Generate or verify SSH key ownership challenge."""
|
||||
user_id = g.current_user.id
|
||||
|
||||
try:
|
||||
ssh_key = ssh_key_service.get_ssh_key(key_id)
|
||||
|
||||
# Check ownership
|
||||
if ssh_key.user_id != user_id:
|
||||
return jsonify({'error': 'Forbidden'}), 403
|
||||
|
||||
# Handle GET request - return challenge
|
||||
if request.method == 'GET':
|
||||
challenge = ssh_key_service.generate_verification_challenge(key_id)
|
||||
return jsonify({
|
||||
'challenge_text': challenge,
|
||||
'validationText': challenge, # Backwards compatibility
|
||||
'key_id': key_id,
|
||||
}), 200
|
||||
|
||||
# Handle POST request - verify signature
|
||||
data = request.get_json() or {}
|
||||
action = data.get('action', 'verify_signature')
|
||||
|
||||
if action == 'verify_signature':
|
||||
# Verify signature
|
||||
signature = data.get('signature')
|
||||
if not signature:
|
||||
return jsonify({'error': 'signature is required'}), 400
|
||||
|
||||
try:
|
||||
verified = ssh_key_service.verify_ssh_key_ownership(key_id, signature)
|
||||
|
||||
# Audit log
|
||||
AuditLog.log(
|
||||
action=AuditAction.SSH_KEY_VERIFIED,
|
||||
user_id=user_id,
|
||||
resource_type='SSHKey',
|
||||
resource_id=key_id,
|
||||
ip_address=request.remote_addr,
|
||||
success=verified,
|
||||
)
|
||||
|
||||
return jsonify({'verified': verified}), 200
|
||||
|
||||
except Exception as e:
|
||||
AuditLog.log(
|
||||
action=AuditAction.SSH_KEY_VALIDATION_FAILED,
|
||||
user_id=user_id,
|
||||
resource_type='SSHKey',
|
||||
resource_id=key_id,
|
||||
ip_address=request.remote_addr,
|
||||
success=False,
|
||||
error_message=str(e),
|
||||
)
|
||||
return jsonify({'error': str(e)}), 400
|
||||
|
||||
else: # generate_challenge
|
||||
# Generate verification challenge
|
||||
challenge = ssh_key_service.generate_verification_challenge(key_id)
|
||||
return jsonify({
|
||||
'challenge_text': challenge,
|
||||
'challenge': challenge, # Both for compatibility
|
||||
}), 200
|
||||
|
||||
except SSHKeyNotFoundError:
|
||||
return jsonify({'error': 'SSH key not found'}), 404
|
||||
|
||||
|
||||
@ssh_bp.route('/keys/<key_id>/update-description', methods=['PATCH'])
|
||||
@login_required
|
||||
def update_ssh_key_description(key_id):
|
||||
"""Update SSH key description."""
|
||||
user_id = g.current_user.id
|
||||
|
||||
data = request.get_json()
|
||||
if not data or 'description' not in data:
|
||||
return jsonify({'error': 'description is required'}), 400
|
||||
|
||||
try:
|
||||
ssh_key = ssh_key_service.get_ssh_key(key_id)
|
||||
|
||||
# Check ownership
|
||||
if ssh_key.user_id != user_id:
|
||||
return jsonify({'error': 'Forbidden'}), 403
|
||||
|
||||
updated_key = ssh_key_service.update_ssh_key_description(
|
||||
key_id,
|
||||
data['description']
|
||||
)
|
||||
|
||||
return jsonify(updated_key.to_dict()), 200
|
||||
|
||||
except SSHKeyNotFoundError:
|
||||
return jsonify({'error': 'SSH key not found'}), 404
|
||||
|
||||
|
||||
@ssh_bp.route('/sign', methods=['POST'])
|
||||
@login_required
|
||||
def sign_certificate():
|
||||
"""Sign an SSH certificate for the current user."""
|
||||
user = g.current_user
|
||||
user_id = user.id
|
||||
|
||||
data = request.get_json()
|
||||
if not data:
|
||||
return jsonify({'error': 'No JSON data provided'}), 400
|
||||
|
||||
try:
|
||||
principals = data.get('principals', [])
|
||||
cert_type = data.get('cert_type', 'user')
|
||||
# Accept both 'key_id' and 'cert_id' (from CLI)
|
||||
key_id = data.get('key_id') or data.get('cert_id')
|
||||
expiry_hours = data.get('expiry_hours')
|
||||
|
||||
if not principals:
|
||||
return jsonify({'error': 'principals is required'}), 400
|
||||
|
||||
# If key_id not specified, use first verified key
|
||||
if not key_id:
|
||||
verified_keys = ssh_key_service.get_user_verified_ssh_keys(user_id)
|
||||
if not verified_keys:
|
||||
return jsonify({'error': 'No verified SSH keys found'}), 400
|
||||
key_id = verified_keys[0].id
|
||||
|
||||
# Get the SSH key
|
||||
ssh_key = ssh_key_service.get_ssh_key(key_id)
|
||||
if ssh_key.user_id != user_id:
|
||||
return jsonify({'error': 'Forbidden'}), 403
|
||||
|
||||
if not ssh_key.verified:
|
||||
return jsonify({'error': 'SSH key is not verified'}), 400
|
||||
|
||||
# Resolve which CA to use: org DB CA > config-file CA
|
||||
db_ca = _get_org_ca_for_user(user)
|
||||
ca_private_key = db_ca.private_key if db_ca else None # None → signing service uses config
|
||||
|
||||
# Create signing request
|
||||
signing_request = SSHCertificateSigningRequest(
|
||||
ssh_public_key=ssh_key.payload,
|
||||
principals=principals,
|
||||
cert_type=cert_type,
|
||||
key_id=key_id,
|
||||
expiry_hours=int(expiry_hours) if expiry_hours else None,
|
||||
)
|
||||
|
||||
# Validate request
|
||||
validation_errors = signing_request.validate()
|
||||
if validation_errors:
|
||||
return jsonify({'errors': validation_errors}), 400
|
||||
|
||||
# Sign the certificate (pass ca_private_key=None → service loads from config)
|
||||
response = ssh_ca_service.sign_certificate(signing_request, ca_private_key=ca_private_key)
|
||||
|
||||
# Persist certificate to DB
|
||||
# If user's org has no DB CA, use the system-config-ca record
|
||||
ca_for_db = db_ca or _get_or_create_system_ca()
|
||||
cert_record = _persist_certificate(
|
||||
user_id=user_id,
|
||||
ssh_key_id=key_id,
|
||||
ca=ca_for_db,
|
||||
signing_response=response,
|
||||
request_ip=request.remote_addr,
|
||||
)
|
||||
|
||||
# Audit log
|
||||
AuditLog.log(
|
||||
action=AuditAction.SSH_CERT_ISSUED,
|
||||
user_id=user_id,
|
||||
resource_type='SSHCertificate',
|
||||
resource_id=cert_record.id if cert_record else key_id,
|
||||
ip_address=request.remote_addr,
|
||||
description=f'Certificate issued for principals: {", ".join(principals)}',
|
||||
)
|
||||
|
||||
result = {
|
||||
'certificate': response.certificate,
|
||||
'serial': response.serial,
|
||||
'principals': response.principals,
|
||||
'valid_after': response.valid_after.isoformat() if response.valid_after else None,
|
||||
'valid_before': response.valid_before.isoformat() if response.valid_before else None,
|
||||
}
|
||||
if cert_record:
|
||||
result['cert_id'] = str(cert_record.id)
|
||||
|
||||
return jsonify(result), 201
|
||||
|
||||
except SSHKeyNotFoundError:
|
||||
return jsonify({'error': 'SSH key not found'}), 404
|
||||
except SSHCertificateError as e:
|
||||
AuditLog.log(
|
||||
action=AuditAction.SSH_CERT_FAILED,
|
||||
user_id=user_id,
|
||||
resource_type='SSHCertificate',
|
||||
ip_address=request.remote_addr,
|
||||
success=False,
|
||||
error_message=str(e),
|
||||
)
|
||||
return jsonify({'error': str(e)}), 400
|
||||
except Exception as e:
|
||||
AuditLog.log(
|
||||
action=AuditAction.SSH_CERT_FAILED,
|
||||
user_id=user_id,
|
||||
resource_type='SSHCertificate',
|
||||
ip_address=request.remote_addr,
|
||||
success=False,
|
||||
error_message=str(e),
|
||||
)
|
||||
return jsonify({'error': 'Certificate signing failed: ' + str(e)}), 500
|
||||
|
||||
|
||||
@ssh_bp.route('/certificates', methods=['GET'])
|
||||
@login_required
|
||||
def list_certificates():
|
||||
"""List all SSH certificates issued for the current user."""
|
||||
user_id = g.current_user.id
|
||||
|
||||
try:
|
||||
from gatehouse_app.models.ssh_certificate import SSHCertificate
|
||||
certs = (
|
||||
SSHCertificate.query
|
||||
.filter_by(user_id=user_id, deleted_at=None)
|
||||
.order_by(SSHCertificate.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
return jsonify({
|
||||
'certificates': [c.to_dict() for c in certs],
|
||||
'count': len(certs),
|
||||
}), 200
|
||||
except Exception as e:
|
||||
return jsonify({'error': str(e)}), 500
|
||||
|
||||
|
||||
@ssh_bp.route('/certificates/<cert_id>', methods=['GET'])
|
||||
@login_required
|
||||
def get_certificate(cert_id):
|
||||
"""Get a specific issued certificate (metadata only)."""
|
||||
user_id = g.current_user.id
|
||||
|
||||
try:
|
||||
from gatehouse_app.models.ssh_certificate import SSHCertificate
|
||||
cert = SSHCertificate.query.filter_by(id=cert_id, deleted_at=None).first()
|
||||
if not cert:
|
||||
return jsonify({'error': 'Certificate not found'}), 404
|
||||
if cert.user_id != user_id:
|
||||
return jsonify({'error': 'Forbidden'}), 403
|
||||
# Include full certificate text in single-fetch endpoint
|
||||
data = cert.to_dict()
|
||||
data['certificate'] = cert.certificate
|
||||
return jsonify(data), 200
|
||||
except Exception as e:
|
||||
return jsonify({'error': str(e)}), 500
|
||||
|
||||
|
||||
@ssh_bp.route('/certificates/<cert_id>/revoke', methods=['POST'])
|
||||
@login_required
|
||||
def revoke_certificate(cert_id):
|
||||
"""Revoke an issued certificate."""
|
||||
user_id = g.current_user.id
|
||||
|
||||
data = request.get_json() or {}
|
||||
reason = data.get('reason', 'User requested revocation')
|
||||
|
||||
try:
|
||||
from gatehouse_app.models.ssh_certificate import SSHCertificate
|
||||
cert = SSHCertificate.query.filter_by(id=cert_id, deleted_at=None).first()
|
||||
if not cert:
|
||||
return jsonify({'error': 'Certificate not found'}), 404
|
||||
if cert.user_id != user_id:
|
||||
return jsonify({'error': 'Forbidden'}), 403
|
||||
if cert.revoked:
|
||||
return jsonify({'error': 'Certificate is already revoked'}), 409
|
||||
|
||||
cert.revoke(reason=reason)
|
||||
|
||||
AuditLog.log(
|
||||
action=AuditAction.SSH_CERT_REVOKED,
|
||||
user_id=user_id,
|
||||
resource_type='SSHCertificate',
|
||||
resource_id=cert_id,
|
||||
ip_address=request.remote_addr,
|
||||
description=f'Revoked: {reason}',
|
||||
)
|
||||
|
||||
return jsonify({'status': 'revoked', 'cert_id': cert_id, 'reason': reason}), 200
|
||||
except Exception as e:
|
||||
return jsonify({'error': str(e)}), 500
|
||||
|
||||
|
||||
@ssh_bp.route('/ca/public-key', methods=['GET'])
|
||||
@login_required
|
||||
def get_ca_public_key():
|
||||
"""
|
||||
Return the CA public key for this user's organization.
|
||||
|
||||
Server admins should add this key to their host's ``TrustedUserCAKeys``
|
||||
directive so that certificates issued by gatehouse are trusted.
|
||||
|
||||
Query parameters:
|
||||
format: 'openssh' (default) or 'text' — affects Content-Type only
|
||||
|
||||
Returns:
|
||||
{ "public_key": "ssh-ed25519 AAAA...",
|
||||
"fingerprint": "SHA256:...",
|
||||
"ca_name": "..." }
|
||||
"""
|
||||
user = g.current_user
|
||||
|
||||
# Try org CA first
|
||||
db_ca = _get_org_ca_for_user(user)
|
||||
if db_ca:
|
||||
return jsonify({
|
||||
'public_key': db_ca.public_key,
|
||||
'fingerprint': db_ca.fingerprint,
|
||||
'ca_name': db_ca.name,
|
||||
'source': 'db',
|
||||
}), 200
|
||||
|
||||
# Fall back to config-file CA
|
||||
try:
|
||||
from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config
|
||||
import os
|
||||
cfg = get_ssh_ca_config()
|
||||
key_path = cfg.get_str('ca_key_path', '').strip() + '.pub'
|
||||
if os.path.exists(key_path):
|
||||
with open(key_path) as f:
|
||||
pub_key = f.read().strip()
|
||||
from gatehouse_app.utils.crypto import compute_ssh_fingerprint
|
||||
return jsonify({
|
||||
'public_key': pub_key,
|
||||
'fingerprint': compute_ssh_fingerprint(pub_key),
|
||||
'ca_name': 'system-config-ca',
|
||||
'source': 'config',
|
||||
}), 200
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'Could not load CA public key: {e}'}), 500
|
||||
|
||||
return jsonify({'error': 'No CA configured for this organization'}), 404
|
||||
|
||||
|
||||
@@ -0,0 +1,271 @@
|
||||
"""SSH CA Configuration Manager.
|
||||
|
||||
Handles loading and managing SSH CA configuration from etc/ssh_ca.conf
|
||||
and environment variables.
|
||||
"""
|
||||
import os
|
||||
import configparser
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
class SSHCAConfig:
|
||||
"""Configuration manager for SSH CA settings.
|
||||
|
||||
Loads configuration from:
|
||||
1. etc/ssh_ca.conf file
|
||||
2. Environment variables (override config file)
|
||||
3. Application environment-specific defaults
|
||||
|
||||
Example:
|
||||
config = SSHCAConfig()
|
||||
cert_hours = config.get_int('cert_validity_hours')
|
||||
kms_key = config.get_str('aws_kms_key_id')
|
||||
"""
|
||||
|
||||
# Configuration file location (relative to project root)
|
||||
DEFAULT_CONFIG_FILE = "etc/ssh_ca.conf"
|
||||
|
||||
# Default values if config file is missing
|
||||
DEFAULTS = {
|
||||
'cert_validity_hours': '1',
|
||||
'max_cert_validity_hours': '24',
|
||||
'max_certs_per_user': '100',
|
||||
'crl_enabled': 'true',
|
||||
'crl_endpoint': 'https://ca.example.com/crl',
|
||||
'crl_refresh_hours': '24',
|
||||
'default_key_type': 'ed25519',
|
||||
'rsa_key_bits': '4096',
|
||||
'private_key_encryption': 'kms',
|
||||
'aws_kms_key_id': '',
|
||||
'extensions_enabled': 'true',
|
||||
'extensions': 'permit-X11-forwarding,permit-agent-forwarding,permit-pty,permit-port-forwarding,permit-user-rc',
|
||||
'critical_options_enabled': 'false',
|
||||
'max_principals_per_cert': '256',
|
||||
'max_key_id_length': '255',
|
||||
'log_level': 'INFO',
|
||||
'audit_enabled': 'true',
|
||||
'require_key_verification': 'true',
|
||||
'verification_challenge_max_age': '24',
|
||||
'rate_limit_certs_per_minute': '5',
|
||||
'request_timeout': '30',
|
||||
'auto_delete_unverified_days': '30',
|
||||
'archive_expired_days': '365',
|
||||
'oauth_token_endpoint': '/api/v1/oauth2/token',
|
||||
'oauth_userinfo_endpoint': '/api/v1/oauth2/userinfo',
|
||||
'ca_key_path': '',
|
||||
}
|
||||
|
||||
def __init__(self, config_file: Optional[str] = None, environment: Optional[str] = None):
|
||||
"""Initialize SSH CA configuration.
|
||||
|
||||
Args:
|
||||
config_file: Path to config file (default: etc/ssh_ca.conf)
|
||||
environment: Environment name (development, production, testing)
|
||||
Default: value of FLASK_ENV or 'development'
|
||||
"""
|
||||
self.config = configparser.ConfigParser()
|
||||
|
||||
# Determine environment
|
||||
if environment is None:
|
||||
environment = os.environ.get('FLASK_ENV', 'development')
|
||||
self.environment = environment
|
||||
|
||||
# Load config file
|
||||
if config_file is None:
|
||||
# Try to find config file relative to this module
|
||||
module_dir = Path(__file__).parent.parent.parent
|
||||
config_file = module_dir / self.DEFAULT_CONFIG_FILE
|
||||
|
||||
self.config_file = config_file
|
||||
self._load_config()
|
||||
|
||||
def _load_config(self):
|
||||
"""Load configuration from file and apply environment-specific overrides."""
|
||||
# Set defaults
|
||||
self.config['default'] = self.DEFAULTS.copy()
|
||||
|
||||
# Load config file if it exists
|
||||
if Path(self.config_file).exists():
|
||||
self.config.read(self.config_file)
|
||||
|
||||
# Apply environment-specific configuration
|
||||
if self.environment in self.config:
|
||||
for key, value in self.config[self.environment].items():
|
||||
self.config['default'][key] = value
|
||||
|
||||
def get_str(self, key: str, default: Optional[str] = None) -> str:
|
||||
"""Get a string configuration value.
|
||||
|
||||
First checks environment variables (SSH_CA_<KEY>), then config file.
|
||||
|
||||
Args:
|
||||
key: Configuration key
|
||||
default: Default value if not found
|
||||
|
||||
Returns:
|
||||
Configuration value as string
|
||||
"""
|
||||
env_key = f"SSH_CA_{key.upper()}"
|
||||
|
||||
# Check environment variable first
|
||||
if env_key in os.environ:
|
||||
return os.environ[env_key]
|
||||
|
||||
# Check config file
|
||||
if key in self.config['default']:
|
||||
value = self.config['default'][key]
|
||||
# Handle environment variable substitution
|
||||
return os.path.expandvars(value)
|
||||
|
||||
# Return default
|
||||
if default is not None:
|
||||
return default
|
||||
|
||||
return self.DEFAULTS.get(key, '')
|
||||
|
||||
def get_int(self, key: str, default: Optional[int] = None) -> int:
|
||||
"""Get an integer configuration value.
|
||||
|
||||
Args:
|
||||
key: Configuration key
|
||||
default: Default value if not found
|
||||
|
||||
Returns:
|
||||
Configuration value as integer
|
||||
|
||||
Raises:
|
||||
ValueError: If value cannot be converted to integer
|
||||
"""
|
||||
str_value = self.get_str(key)
|
||||
if not str_value:
|
||||
if default is not None:
|
||||
return default
|
||||
raise ValueError(f"No value found for {key}")
|
||||
|
||||
try:
|
||||
return int(str_value)
|
||||
except ValueError:
|
||||
if default is not None:
|
||||
return default
|
||||
raise ValueError(f"Configuration {key}={str_value} is not a valid integer")
|
||||
|
||||
def get_bool(self, key: str, default: Optional[bool] = None) -> bool:
|
||||
"""Get a boolean configuration value.
|
||||
|
||||
Args:
|
||||
key: Configuration key
|
||||
default: Default value if not found
|
||||
|
||||
Returns:
|
||||
Configuration value as boolean
|
||||
"""
|
||||
str_value = self.get_str(key)
|
||||
if not str_value:
|
||||
if default is not None:
|
||||
return default
|
||||
return False
|
||||
|
||||
return str_value.lower() in ('true', '1', 'yes', 'on')
|
||||
|
||||
def get_list(self, key: str, delimiter: str = ',', default: Optional[list] = None) -> list:
|
||||
"""Get a comma-separated list configuration value.
|
||||
|
||||
Args:
|
||||
key: Configuration key
|
||||
delimiter: Delimiter between items (default: comma)
|
||||
default: Default value if not found
|
||||
|
||||
Returns:
|
||||
Configuration value as list of strings
|
||||
"""
|
||||
str_value = self.get_str(key)
|
||||
if not str_value:
|
||||
if default is not None:
|
||||
return default
|
||||
return []
|
||||
|
||||
return [item.strip() for item in str_value.split(delimiter) if item.strip()]
|
||||
|
||||
def validate_config(self) -> list:
|
||||
"""Validate SSH CA configuration.
|
||||
|
||||
Returns:
|
||||
List of validation error messages (empty if valid)
|
||||
"""
|
||||
errors = []
|
||||
|
||||
# Check cert validity hours
|
||||
try:
|
||||
validity = self.get_int('cert_validity_hours')
|
||||
max_validity = self.get_int('max_cert_validity_hours')
|
||||
if validity > max_validity:
|
||||
errors.append(
|
||||
f"cert_validity_hours ({validity}) > max_cert_validity_hours ({max_validity})"
|
||||
)
|
||||
except ValueError as e:
|
||||
errors.append(f"Invalid cert validity hours: {e}")
|
||||
|
||||
# Check key type
|
||||
valid_key_types = ['ed25519', 'rsa', 'ecdsa']
|
||||
key_type = self.get_str('default_key_type', 'ed25519')
|
||||
if key_type not in valid_key_types:
|
||||
errors.append(f"Invalid key type: {key_type}. Must be one of {valid_key_types}")
|
||||
|
||||
# Check encryption method
|
||||
valid_methods = ['kms', 'local']
|
||||
encryption = self.get_str('private_key_encryption', 'kms')
|
||||
if encryption not in valid_methods:
|
||||
errors.append(f"Invalid private_key_encryption: {encryption}. Must be one of {valid_methods}")
|
||||
|
||||
# Warn if using local encryption in production
|
||||
if encryption == 'local' and self.environment == 'production':
|
||||
errors.append("WARNING: Using local key encryption in production! Use KMS instead.")
|
||||
|
||||
# Check KMS key ID if using KMS
|
||||
if encryption == 'kms':
|
||||
kms_key = self.get_str('aws_kms_key_id', '').strip()
|
||||
if not kms_key:
|
||||
errors.append("aws_kms_key_id not set but private_key_encryption=kms")
|
||||
|
||||
# Check principals limit
|
||||
max_principals = self.get_int('max_principals_per_cert')
|
||||
if max_principals > 256:
|
||||
errors.append(f"max_principals_per_cert ({max_principals}) exceeds SSH limit of 256")
|
||||
|
||||
return errors
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Export current configuration as dictionary.
|
||||
"""
|
||||
return dict(self.config['default'])
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of configuration."""
|
||||
return f"<SSHCAConfig environment={self.environment} file={self.config_file}>"
|
||||
|
||||
|
||||
# Global configuration instance
|
||||
_config_instance = None
|
||||
|
||||
|
||||
def get_ssh_ca_config() -> SSHCAConfig:
|
||||
"""Get the global SSH CA configuration instance.
|
||||
|
||||
This function uses a singleton pattern to ensure only one
|
||||
configuration instance is created and reused.
|
||||
|
||||
Returns:
|
||||
SSHCAConfig instance
|
||||
"""
|
||||
global _config_instance
|
||||
if _config_instance is None:
|
||||
_config_instance = SSHCAConfig()
|
||||
return _config_instance
|
||||
|
||||
|
||||
def reset_config_instance():
|
||||
"""Reset the global configuration instance.
|
||||
"""
|
||||
global _config_instance
|
||||
_config_instance = None
|
||||
@@ -19,6 +19,21 @@ from gatehouse_app.exceptions.validation_exceptions import (
|
||||
OrganizationNotFoundError,
|
||||
UserNotFoundError,
|
||||
)
|
||||
from gatehouse_app.exceptions.ssh_exceptions import (
|
||||
SSHCAError,
|
||||
SSHKeyError,
|
||||
SSHKeyNotFoundError,
|
||||
SSHKeyAlreadyExistsError,
|
||||
SSHKeyNotVerifiedError,
|
||||
SSHCertificateError,
|
||||
SSHCertificateNotFoundError,
|
||||
CAError,
|
||||
CANotFoundError,
|
||||
PrincipalError,
|
||||
PrincipalNotFoundError,
|
||||
DepartmentError,
|
||||
DepartmentNotFoundError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseAPIException",
|
||||
@@ -37,4 +52,18 @@ __all__ = [
|
||||
"EmailAlreadyExistsError",
|
||||
"OrganizationNotFoundError",
|
||||
"UserNotFoundError",
|
||||
"SSHCAError",
|
||||
"SSHKeyError",
|
||||
"SSHKeyNotFoundError",
|
||||
"SSHKeyAlreadyExistsError",
|
||||
"SSHKeyNotVerifiedError",
|
||||
"SSHCertificateError",
|
||||
"SSHCertificateNotFoundError",
|
||||
"CAError",
|
||||
"CANotFoundError",
|
||||
"PrincipalError",
|
||||
"PrincipalNotFoundError",
|
||||
"DepartmentError",
|
||||
"DepartmentNotFoundError",
|
||||
]
|
||||
|
||||
|
||||
@@ -16,9 +16,10 @@ class BaseAPIException(Exception):
|
||||
message: Custom error message
|
||||
error_details: Additional error details dictionary
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(self.message)
|
||||
if message:
|
||||
self.message = message
|
||||
super().__init__(message) # update args so str(e) works
|
||||
self.error_details = error_details or {}
|
||||
|
||||
def to_dict(self):
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
"""SSH-specific exceptions."""
|
||||
from gatehouse_app.exceptions.base import BaseAPIException
|
||||
|
||||
|
||||
class SSHCAError(BaseAPIException):
|
||||
"""Base exception for SSH CA operations."""
|
||||
|
||||
status_code = 500
|
||||
error_type = "SSH_CA_ERROR"
|
||||
|
||||
|
||||
class SSHKeyError(BaseAPIException):
|
||||
"""Exception for SSH key operations."""
|
||||
|
||||
status_code = 400
|
||||
error_type = "SSH_KEY_ERROR"
|
||||
|
||||
|
||||
class SSHKeyNotFoundError(BaseAPIException):
|
||||
"""SSH key not found."""
|
||||
|
||||
status_code = 404
|
||||
error_type = "SSH_KEY_NOT_FOUND"
|
||||
|
||||
|
||||
class SSHKeyAlreadyExistsError(BaseAPIException):
|
||||
"""SSH key already exists (duplicate fingerprint)."""
|
||||
|
||||
status_code = 409
|
||||
error_type = "SSH_KEY_ALREADY_EXISTS"
|
||||
|
||||
|
||||
class SSHKeyNotVerifiedError(BaseAPIException):
|
||||
"""SSH key has not been verified."""
|
||||
|
||||
status_code = 400
|
||||
error_type = "SSH_KEY_NOT_VERIFIED"
|
||||
|
||||
|
||||
class SSHCertificateError(BaseAPIException):
|
||||
"""Exception for SSH certificate operations."""
|
||||
|
||||
status_code = 400
|
||||
error_type = "SSH_CERT_ERROR"
|
||||
|
||||
|
||||
class SSHCertificateNotFoundError(BaseAPIException):
|
||||
"""SSH certificate not found."""
|
||||
|
||||
status_code = 404
|
||||
error_type = "SSH_CERT_NOT_FOUND"
|
||||
|
||||
|
||||
class CAError(BaseAPIException):
|
||||
"""Exception for Certificate Authority operations."""
|
||||
|
||||
status_code = 400
|
||||
error_type = "CA_ERROR"
|
||||
|
||||
|
||||
class CANotFoundError(BaseAPIException):
|
||||
"""Certificate Authority not found."""
|
||||
|
||||
status_code = 404
|
||||
error_type = "CA_NOT_FOUND"
|
||||
|
||||
|
||||
class PrincipalError(BaseAPIException):
|
||||
"""Exception for principal operations."""
|
||||
|
||||
status_code = 400
|
||||
error_type = "PRINCIPAL_ERROR"
|
||||
|
||||
|
||||
class PrincipalNotFoundError(BaseAPIException):
|
||||
"""Principal not found."""
|
||||
|
||||
status_code = 404
|
||||
error_type = "PRINCIPAL_NOT_FOUND"
|
||||
|
||||
|
||||
class DepartmentError(BaseAPIException):
|
||||
"""Exception for department operations."""
|
||||
|
||||
status_code = 400
|
||||
error_type = "DEPARTMENT_ERROR"
|
||||
|
||||
|
||||
class DepartmentNotFoundError(BaseAPIException):
|
||||
"""Department not found."""
|
||||
|
||||
status_code = 404
|
||||
error_type = "DEPARTMENT_NOT_FOUND"
|
||||
@@ -29,6 +29,13 @@ from gatehouse_app.models.principal import (
|
||||
Principal,
|
||||
PrincipalMembership,
|
||||
)
|
||||
from gatehouse_app.models.ssh_key import SSHKey
|
||||
from gatehouse_app.models.ca import CA, KeyType, CertType
|
||||
from gatehouse_app.models.ssh_certificate import SSHCertificate, CertificateStatus
|
||||
from gatehouse_app.models.certificate_audit_log import CertificateAuditLog
|
||||
from gatehouse_app.models.password_reset_token import PasswordResetToken
|
||||
from gatehouse_app.models.email_verification_token import EmailVerificationToken
|
||||
from gatehouse_app.models.org_invite_token import OrgInviteToken
|
||||
|
||||
__all__ = [
|
||||
"BaseModel",
|
||||
@@ -55,4 +62,14 @@ __all__ = [
|
||||
"DepartmentPrincipal",
|
||||
"Principal",
|
||||
"PrincipalMembership",
|
||||
"SSHKey",
|
||||
"CA",
|
||||
"KeyType",
|
||||
"CertType",
|
||||
"SSHCertificate",
|
||||
"CertificateStatus",
|
||||
"CertificateAuditLog",
|
||||
"PasswordResetToken",
|
||||
"EmailVerificationToken",
|
||||
"OrgInviteToken",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,155 @@
|
||||
"""Certificate Authority (CA) model."""
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class KeyType(str, Enum):
|
||||
"""SSH CA key types."""
|
||||
|
||||
ED25519 = "ed25519"
|
||||
RSA = "rsa"
|
||||
ECDSA = "ecdsa"
|
||||
|
||||
|
||||
class CertType(str, Enum):
|
||||
"""SSH certificate types."""
|
||||
|
||||
USER = "user"
|
||||
HOST = "host"
|
||||
|
||||
|
||||
class CA(BaseModel):
|
||||
"""Certificate Authority (CA) model for SSH certificate signing.
|
||||
|
||||
Each organization can have multiple CAs for different purposes
|
||||
(e.g., production vs. staging). Private keys are encrypted at rest
|
||||
and should be protected with KMS.
|
||||
"""
|
||||
|
||||
__tablename__ = "cas"
|
||||
|
||||
organization_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("organizations.id"),
|
||||
nullable=True, # NULL for the global system-config CA
|
||||
index=True,
|
||||
)
|
||||
|
||||
# CA name and description
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
description = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Key type (ED25519, RSA, ECDSA)
|
||||
key_type = db.Column(
|
||||
db.Enum(KeyType, values_callable=lambda x: [e.value for e in x]),
|
||||
default=KeyType.ED25519,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Private key (encrypted at rest by database/KMS)
|
||||
# Format: PEM-encoded private key
|
||||
private_key = db.Column(db.Text, nullable=False)
|
||||
|
||||
# Public key (PEM format)
|
||||
public_key = db.Column(db.Text, nullable=False)
|
||||
|
||||
# SHA256 fingerprint of the public key
|
||||
fingerprint = db.Column(db.String(255), nullable=False, unique=True)
|
||||
|
||||
# CRL (Certificate Revocation List) configuration
|
||||
crl_enabled = db.Column(db.Boolean, default=True, nullable=False)
|
||||
crl_endpoint = db.Column(db.String(512), nullable=True)
|
||||
|
||||
# Default certificate validity in hours
|
||||
# Can be overridden per certificate request
|
||||
default_cert_validity_hours = db.Column(
|
||||
db.Integer,
|
||||
default=1,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Maximum validity duration allowed
|
||||
max_cert_validity_hours = db.Column(
|
||||
db.Integer,
|
||||
default=24,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# CA status
|
||||
is_active = db.Column(db.Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
# Key rotation tracking
|
||||
rotated_at = db.Column(db.DateTime, nullable=True)
|
||||
rotation_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Relationships
|
||||
organization = db.relationship("Organization", back_populates="cas")
|
||||
certificates = db.relationship(
|
||||
"SSHCertificate",
|
||||
back_populates="ca",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint(
|
||||
"organization_id", "name", name="uix_org_ca_name"
|
||||
),
|
||||
db.Index("idx_ca_org_active", "organization_id", "is_active"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of CA."""
|
||||
return f"<CA {self.name} (org_id={self.organization_id}, type={self.key_type})>"
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert CA to dictionary."""
|
||||
exclude = exclude or []
|
||||
# Never expose private key in API responses
|
||||
exclude.extend(["private_key"])
|
||||
data = super().to_dict(exclude=exclude)
|
||||
|
||||
# Add computed fields
|
||||
data["total_certs"] = len([c for c in self.certificates if c.deleted_at is None])
|
||||
data["active_certs"] = len([
|
||||
c for c in self.certificates
|
||||
if c.deleted_at is None and not c.revoked
|
||||
])
|
||||
data["revoked_certs"] = len([
|
||||
c for c in self.certificates
|
||||
if c.deleted_at is None and c.revoked
|
||||
])
|
||||
|
||||
return data
|
||||
|
||||
def get_active_certificates(self):
|
||||
"""Get all active (non-revoked) certificates issued by this CA.
|
||||
|
||||
Returns:
|
||||
List of non-revoked SSHCertificate objects
|
||||
"""
|
||||
return [
|
||||
c for c in self.certificates
|
||||
if c.deleted_at is None and not c.revoked
|
||||
]
|
||||
|
||||
def rotate_key(self, new_private_key, new_public_key, new_fingerprint, reason=None):
|
||||
"""Rotate the CA's key pair.
|
||||
|
||||
This should only be done in carefully controlled circumstances.
|
||||
All existing certificates remain valid but no new certs can be
|
||||
signed with the old key.
|
||||
|
||||
Args:
|
||||
new_private_key: New PEM-encoded private key
|
||||
new_public_key: New PEM-encoded public key
|
||||
new_fingerprint: SHA256 fingerprint of new public key
|
||||
reason: Optional reason for rotation
|
||||
"""
|
||||
self.private_key = new_private_key
|
||||
self.public_key = new_public_key
|
||||
self.fingerprint = new_fingerprint
|
||||
self.rotated_at = datetime.utcnow()
|
||||
self.rotation_reason = reason
|
||||
self.save()
|
||||
@@ -0,0 +1,83 @@
|
||||
"""Certificate audit log model."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class CertificateAuditLog(BaseModel):
|
||||
"""Audit log for SSH certificate lifecycle events.
|
||||
|
||||
Tracks all operations on SSH certificates: signing, revocation,
|
||||
validation, etc. This is separate from the general AuditLog to
|
||||
provide detailed certificate operation tracking.
|
||||
"""
|
||||
|
||||
__tablename__ = "certificate_audit_logs"
|
||||
|
||||
# Reference to the certificate
|
||||
certificate_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("ssh_certificates.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# The user who performed the action (can be null for system actions)
|
||||
user_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("users.id"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Action type (e.g., "signed", "revoked", "validated", "requested")
|
||||
action = db.Column(db.String(50), nullable=False, index=True)
|
||||
|
||||
# Request details
|
||||
ip_address = db.Column(db.String(45), nullable=True)
|
||||
user_agent = db.Column(db.String(512), nullable=True)
|
||||
request_id = db.Column(db.String(36), nullable=True)
|
||||
|
||||
# Detailed message
|
||||
message = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Additional context
|
||||
extra_data = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# Success/failure
|
||||
success = db.Column(db.Boolean, default=True, nullable=False)
|
||||
error_message = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
certificate = db.relationship("SSHCertificate", back_populates="audit_logs")
|
||||
user = db.relationship("User")
|
||||
|
||||
__table_args__ = (
|
||||
db.Index("idx_cert_audit_cert_action", "certificate_id", "action"),
|
||||
db.Index("idx_cert_audit_user", "user_id", "created_at"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of CertificateAuditLog."""
|
||||
return f"<CertificateAuditLog cert_id={self.certificate_id} action={self.action}>"
|
||||
|
||||
@classmethod
|
||||
def log(cls, certificate_id, action, user_id=None, **kwargs):
|
||||
"""Create a certificate audit log entry.
|
||||
|
||||
Args:
|
||||
certificate_id: ID of the certificate
|
||||
action: Action type (e.g., "signed", "revoked")
|
||||
user_id: ID of the user performing the action (optional)
|
||||
**kwargs: Additional fields (ip_address, user_agent, message, etc.)
|
||||
|
||||
Returns:
|
||||
CertificateAuditLog instance
|
||||
"""
|
||||
log_entry = cls(
|
||||
certificate_id=certificate_id,
|
||||
action=action,
|
||||
user_id=user_id,
|
||||
**kwargs
|
||||
)
|
||||
log_entry.save()
|
||||
return log_entry
|
||||
@@ -40,6 +40,9 @@ class Organization(BaseModel):
|
||||
principals = db.relationship(
|
||||
"Principal", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
cas = db.relationship(
|
||||
"CA", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of Organization."""
|
||||
|
||||
@@ -0,0 +1,175 @@
|
||||
"""SSH Certificate model."""
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.models.ca import CertType
|
||||
|
||||
|
||||
class CertificateStatus(str, Enum):
|
||||
"""SSH certificate lifecycle status."""
|
||||
|
||||
REQUESTED = "requested" # Waiting for signing
|
||||
ISSUED = "issued" # Signed and valid
|
||||
REVOKED = "revoked" # Manually revoked
|
||||
EXPIRED = "expired" # Validity period ended
|
||||
SUPERSEDED = "superseded" # Replaced by newer cert
|
||||
|
||||
|
||||
class SSHCertificate(BaseModel):
|
||||
"""SSH Certificate model representing a signed SSH user/host certificate.
|
||||
|
||||
Certificates are issued by a CA and associated with an SSH public key.
|
||||
They include principals (access levels), validity periods, and other
|
||||
OpenSSH certificate metadata.
|
||||
"""
|
||||
|
||||
__tablename__ = "ssh_certificates"
|
||||
|
||||
# Certificate relationships
|
||||
ca_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("cas.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
user_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("users.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
ssh_key_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("ssh_keys.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Certificate content (full signed certificate in OpenSSH format)
|
||||
certificate = db.Column(db.Text, nullable=False)
|
||||
|
||||
# Certificate metadata
|
||||
serial = db.Column(db.String(255), nullable=False, unique=True, index=True)
|
||||
key_id = db.Column(db.String(255), nullable=False) # Usually user email
|
||||
cert_type = db.Column(
|
||||
db.Enum(CertType, values_callable=lambda x: [e.value for e in x]),
|
||||
default=CertType.USER,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Principals (JSON list) - e.g., ["prod-servers", "dev-servers"]
|
||||
principals = db.Column(db.JSON, nullable=False, default=list)
|
||||
|
||||
# Validity period
|
||||
valid_after = db.Column(db.DateTime, nullable=False)
|
||||
valid_before = db.Column(db.DateTime, nullable=False)
|
||||
|
||||
# Revocation status
|
||||
revoked = db.Column(db.Boolean, default=False, nullable=False, index=True)
|
||||
revoked_at = db.Column(db.DateTime, nullable=True)
|
||||
revoke_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Status tracking
|
||||
status = db.Column(
|
||||
db.Enum(CertificateStatus, values_callable=lambda x: [e.value for e in x]),
|
||||
default=CertificateStatus.ISSUED,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Request metadata
|
||||
request_ip = db.Column(db.String(45), nullable=True)
|
||||
request_user_agent = db.Column(db.String(512), nullable=True)
|
||||
|
||||
# Critical options (JSON) - OpenSSH critical options
|
||||
# See: https://man.openbsd.org/ssh-cert
|
||||
critical_options = db.Column(db.JSON, nullable=True, default=dict)
|
||||
|
||||
# Extensions (JSON) - OpenSSH extensions
|
||||
# Common ones: permit-X11-forwarding, permit-agent-forwarding, permit-pty, etc.
|
||||
extensions = db.Column(db.JSON, nullable=True, default=dict)
|
||||
|
||||
# Relationships
|
||||
ca = db.relationship("CA", back_populates="certificates")
|
||||
user = db.relationship("User", back_populates="ssh_certificates")
|
||||
ssh_key = db.relationship(
|
||||
"SSHKey",
|
||||
back_populates="certificates",
|
||||
foreign_keys="SSHCertificate.ssh_key_id",
|
||||
)
|
||||
audit_logs = db.relationship(
|
||||
"CertificateAuditLog",
|
||||
back_populates="certificate",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
db.Index("idx_cert_user_status", "user_id", "status"),
|
||||
db.Index("idx_cert_validity", "valid_after", "valid_before"),
|
||||
db.Index("idx_cert_revoked", "revoked", "revoked_at"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of SSHCertificate."""
|
||||
return f"<SSHCertificate serial={self.serial[:16]}... user_id={self.user_id}>"
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert certificate to dictionary."""
|
||||
exclude = exclude or []
|
||||
# Optionally exclude the certificate content (it's large)
|
||||
if "certificate" not in exclude:
|
||||
exclude.append("certificate")
|
||||
data = super().to_dict(exclude=exclude)
|
||||
|
||||
# Add computed fields
|
||||
data["is_valid"] = self.is_valid()
|
||||
data["days_until_expiry"] = self.days_until_expiry()
|
||||
|
||||
return data
|
||||
|
||||
def is_valid(self):
|
||||
"""Check if certificate is currently valid.
|
||||
|
||||
Returns:
|
||||
True if certificate is issued, not revoked, and within validity period
|
||||
"""
|
||||
if self.revoked or self.status == CertificateStatus.REVOKED:
|
||||
return False
|
||||
|
||||
now = datetime.utcnow()
|
||||
return self.valid_after <= now <= self.valid_before
|
||||
|
||||
def is_expired(self):
|
||||
"""Check if certificate has expired.
|
||||
|
||||
Returns:
|
||||
True if current time is past valid_before
|
||||
"""
|
||||
return datetime.utcnow() > self.valid_before
|
||||
|
||||
def days_until_expiry(self):
|
||||
"""Get number of days until certificate expires.
|
||||
|
||||
Returns:
|
||||
Number of days remaining (negative if already expired)
|
||||
"""
|
||||
delta = self.valid_before - datetime.utcnow()
|
||||
return delta.days + (1 if delta.seconds > 0 else 0)
|
||||
|
||||
def revoke(self, reason=None):
|
||||
"""Revoke this certificate.
|
||||
|
||||
Args:
|
||||
reason: Optional reason for revocation
|
||||
"""
|
||||
self.revoked = True
|
||||
self.revoked_at = datetime.utcnow()
|
||||
self.revoke_reason = reason
|
||||
self.status = CertificateStatus.REVOKED
|
||||
self.save()
|
||||
|
||||
def mark_expired(self):
|
||||
"""Mark certificate as expired when validity period ends."""
|
||||
self.status = CertificateStatus.EXPIRED
|
||||
self.save()
|
||||
@@ -0,0 +1,96 @@
|
||||
"""SSH Key model."""
|
||||
from datetime import datetime
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class SSHKey(BaseModel):
|
||||
"""SSH Key model representing a user's SSH public key.
|
||||
|
||||
This model stores SSH public keys that users register for certificate signing.
|
||||
Users must verify ownership of the key before it can be used for signing certificates.
|
||||
"""
|
||||
|
||||
__tablename__ = "ssh_keys"
|
||||
|
||||
user_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("users.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# SSH key payload in OpenSSH format (e.g., "ssh-rsa AAAAB3Nz...")
|
||||
payload = db.Column(db.Text, nullable=False, unique=True)
|
||||
|
||||
# SHA256 fingerprint for quick comparison
|
||||
fingerprint = db.Column(db.String(255), nullable=False, unique=True, index=True)
|
||||
|
||||
# Optional description for the key (e.g., "My laptop key")
|
||||
description = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Verification status
|
||||
verified = db.Column(db.Boolean, default=False, nullable=False, index=True)
|
||||
verified_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Verification challenge
|
||||
verify_text = db.Column(db.String(255), nullable=True)
|
||||
verify_text_created_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Key type extracted from the key (ssh-rsa, ssh-ed25519, etc.)
|
||||
key_type = db.Column(db.String(50), nullable=True)
|
||||
|
||||
# Key bits/length
|
||||
key_bits = db.Column(db.Integer, nullable=True)
|
||||
|
||||
# Comment from the key (usually email or key name)
|
||||
key_comment = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship("User", back_populates="ssh_keys")
|
||||
certificates = db.relationship(
|
||||
"SSHCertificate",
|
||||
back_populates="ssh_key",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="SSHCertificate.ssh_key_id",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
db.Index("idx_ssh_key_user_verified", "user_id", "verified"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of SSHKey."""
|
||||
return f"<SSHKey {self.fingerprint[:16]}... user_id={self.user_id}>"
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert SSH key to dictionary."""
|
||||
exclude = exclude or []
|
||||
exclude.extend(["payload", "verify_text"]) # Never expose these in API
|
||||
data = super().to_dict(exclude=exclude)
|
||||
|
||||
# Add computed fields
|
||||
data["cert_count"] = len([c for c in self.certificates if c.deleted_at is None])
|
||||
|
||||
return data
|
||||
|
||||
def mark_verified(self):
|
||||
"""Mark this SSH key as verified."""
|
||||
self.verified = True
|
||||
self.verified_at = datetime.utcnow()
|
||||
self.save()
|
||||
|
||||
def needs_verification_refresh(self, max_age_hours=24):
|
||||
"""Check if verification challenge needs to be refreshed.
|
||||
|
||||
Args:
|
||||
max_age_hours: Maximum age of verification challenge in hours
|
||||
|
||||
Returns:
|
||||
True if verification challenge is stale
|
||||
"""
|
||||
if not self.verify_text_created_at:
|
||||
return True
|
||||
|
||||
age = datetime.utcnow() - self.verify_text_created_at
|
||||
return age.total_seconds() > (max_age_hours * 3600)
|
||||
@@ -55,6 +55,18 @@ class User(BaseModel):
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="PrincipalMembership.user_id",
|
||||
)
|
||||
ssh_keys = db.relationship(
|
||||
"SSHKey",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="SSHKey.user_id",
|
||||
)
|
||||
ssh_certificates = db.relationship(
|
||||
"SSHCertificate",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="SSHCertificate.user_id",
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of User."""
|
||||
|
||||
@@ -12,7 +12,7 @@ from gatehouse_app.models import User, AuthenticationMethod
|
||||
from gatehouse_app.models.authentication_method import OAuthState
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.models.oidc_authorization_code import OIDCAuthCode
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
from gatehouse_app.utils.constants import AuthMethodType, AuditAction
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
from gatehouse_app.services.external_auth_service import (
|
||||
ExternalAuthService,
|
||||
@@ -139,7 +139,7 @@ class OAuthFlowService:
|
||||
except ExternalAuthError as e:
|
||||
# Log failed initiation
|
||||
AuditService.log_action(
|
||||
action="external_auth.login.initiated",
|
||||
action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED,
|
||||
organization_id=organization_id,
|
||||
metadata={
|
||||
"provider_type": provider_type_str,
|
||||
@@ -236,7 +236,7 @@ class OAuthFlowService:
|
||||
|
||||
except ExternalAuthError as e:
|
||||
AuditService.log_action(
|
||||
action="external_auth.register.initiated",
|
||||
action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED,
|
||||
organization_id=organization_id,
|
||||
metadata={
|
||||
"provider_type": provider_type_str,
|
||||
@@ -399,6 +399,27 @@ class OAuthFlowService:
|
||||
access_token=tokens["access_token"],
|
||||
)
|
||||
|
||||
if not user_info.get("provider_user_id"):
|
||||
raise OAuthFlowError(
|
||||
"Provider did not return a user identifier (sub claim). "
|
||||
"Cannot complete authentication.",
|
||||
"MISSING_PROVIDER_USER_ID",
|
||||
400,
|
||||
)
|
||||
|
||||
if not user_info.get("email"):
|
||||
raise OAuthFlowError(
|
||||
"Provider did not return an email address. "
|
||||
"Cannot complete authentication.",
|
||||
"MISSING_EMAIL",
|
||||
400,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Got user_info from provider: sub={user_info['provider_user_id']}, "
|
||||
f"email={user_info['email']}, email_verified={user_info.get('email_verified')}"
|
||||
)
|
||||
|
||||
# Look up user by provider_user_id
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
method_type=provider_type,
|
||||
|
||||
@@ -0,0 +1,333 @@
|
||||
"""SSH Certificate Authority signing service.
|
||||
|
||||
Handles SSH certificate signing operations, leveraging sshkey-tools library.
|
||||
This service is a Gatehouse-integrated version of the secuird/ssh_ca.py logic.
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from sshkey_tools.cert import SSHCertificate, CertificateFields
|
||||
from sshkey_tools.keys import PublicKey, PrivateKey
|
||||
|
||||
from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config
|
||||
from gatehouse_app.exceptions import SSHCAError, ValidationError
|
||||
from gatehouse_app.utils.crypto import compute_ssh_fingerprint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SSHCASigningError(Exception):
|
||||
"""SSH CA signing operation error."""
|
||||
pass
|
||||
|
||||
|
||||
class SSHCertificateSigningRequest:
|
||||
"""Represents an SSH certificate signing request."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ssh_public_key: str,
|
||||
principals: List[str],
|
||||
key_id: str,
|
||||
cert_type: str = "user",
|
||||
expiry_hours: Optional[int] = None,
|
||||
critical_options: Optional[Dict[str, str]] = None,
|
||||
extensions: Optional[List[str]] = None,
|
||||
):
|
||||
"""Initialize signing request.
|
||||
|
||||
Args:
|
||||
ssh_public_key: Public key in OpenSSH format (e.g., "ssh-ed25519 AAAA...")
|
||||
principals: List of principals (e.g., ["prod-servers", "staging"])
|
||||
key_id: Key identifier (usually user email)
|
||||
cert_type: Certificate type - "user" or "host" (default: user)
|
||||
expiry_hours: Certificate validity in hours
|
||||
critical_options: Critical options dict
|
||||
extensions: List of extensions (e.g., ["permit-pty", "permit-agent-forwarding"])
|
||||
"""
|
||||
self.ssh_public_key = ssh_public_key
|
||||
self.principals = principals or []
|
||||
self.key_id = key_id
|
||||
self.cert_type = cert_type
|
||||
self.expiry_hours = expiry_hours
|
||||
self.critical_options = critical_options or {}
|
||||
self.extensions = extensions or []
|
||||
|
||||
def validate(self) -> List[str]:
|
||||
"""Validate the signing request.
|
||||
|
||||
Returns:
|
||||
List of validation errors (empty if valid)
|
||||
"""
|
||||
errors = []
|
||||
config = get_ssh_ca_config()
|
||||
|
||||
# Validate cert type
|
||||
if self.cert_type not in ("user", "host"):
|
||||
errors.append(f"Invalid cert_type: {self.cert_type}. Must be 'user' or 'host'")
|
||||
|
||||
# Validate SSH public key
|
||||
if not self.ssh_public_key or len(self.ssh_public_key) < 16:
|
||||
errors.append("SSH public key is missing or invalid")
|
||||
else:
|
||||
try:
|
||||
PublicKey.from_string(self.ssh_public_key)
|
||||
except Exception as e:
|
||||
errors.append(f"SSH public key is not valid: {str(e)}")
|
||||
|
||||
# Validate principals
|
||||
if not self.principals or len(self.principals) == 0:
|
||||
errors.append("At least one principal is required")
|
||||
else:
|
||||
max_principals = config.get_int('max_principals_per_cert')
|
||||
if len(self.principals) > max_principals:
|
||||
errors.append(
|
||||
f"Too many principals ({len(self.principals)}). "
|
||||
f"Maximum is {max_principals}"
|
||||
)
|
||||
|
||||
# Validate key_id
|
||||
if not self.key_id or len(self.key_id) < 5:
|
||||
errors.append("key_id is missing or too short (minimum 5 characters)")
|
||||
else:
|
||||
max_id_len = config.get_int('max_key_id_length')
|
||||
if len(self.key_id) > max_id_len:
|
||||
errors.append(f"key_id exceeds maximum length of {max_id_len}")
|
||||
|
||||
# Validate expiry_hours
|
||||
if self.expiry_hours is not None:
|
||||
if not isinstance(self.expiry_hours, int) or self.expiry_hours <= 0:
|
||||
errors.append("expiry_hours must be a positive integer")
|
||||
else:
|
||||
max_validity = config.get_int('max_cert_validity_hours')
|
||||
if self.expiry_hours > max_validity:
|
||||
errors.append(
|
||||
f"Requested expiry ({self.expiry_hours}h) exceeds "
|
||||
f"maximum allowed ({max_validity}h)"
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
class SSHCertificateSigningResponse:
|
||||
"""Represents a signed SSH certificate response."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
certificate: str,
|
||||
serial: str,
|
||||
valid_after: datetime,
|
||||
valid_before: datetime,
|
||||
principals: Optional[List[str]] = None,
|
||||
):
|
||||
"""Initialize signing response.
|
||||
|
||||
Args:
|
||||
certificate: Full certificate in OpenSSH format
|
||||
serial: Certificate serial number
|
||||
valid_after: Validity start datetime
|
||||
valid_before: Validity end datetime
|
||||
principals: List of principals the cert was issued for
|
||||
"""
|
||||
self.certificate = certificate
|
||||
self.serial = serial
|
||||
self.valid_after = valid_after
|
||||
self.valid_before = valid_before
|
||||
self.principals = principals or []
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert response to dictionary."""
|
||||
return {
|
||||
'certificate': self.certificate,
|
||||
'serial': self.serial,
|
||||
'valid_after': self.valid_after.isoformat(),
|
||||
'valid_before': self.valid_before.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
class SSHCASigningService:
|
||||
"""Service for signing SSH certificates.
|
||||
|
||||
This service handles all SSH certificate signing operations.
|
||||
It uses configuration from ssh_ca_config to apply rules and limits.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the SSH CA signing service."""
|
||||
self.config = get_ssh_ca_config()
|
||||
self.logger = logger
|
||||
|
||||
def _load_ca_key_from_config(self) -> str:
|
||||
"""Load CA private key from config (local file or env var).
|
||||
|
||||
Returns:
|
||||
CA private key in PEM/OpenSSH format as string
|
||||
|
||||
Raises:
|
||||
SSHCASigningError: If key cannot be loaded
|
||||
"""
|
||||
# Check env var first
|
||||
key_content = os.environ.get('SSH_CA_PRIVATE_KEY')
|
||||
if key_content:
|
||||
return key_content
|
||||
|
||||
# Load from file path
|
||||
key_path = self.config.get_str('ca_key_path', '').strip()
|
||||
if not key_path:
|
||||
raise SSHCASigningError(
|
||||
"CA private key not configured. Set SSH_CA_PRIVATE_KEY env var "
|
||||
"or ca_key_path in etc/ssh_ca.conf"
|
||||
)
|
||||
|
||||
key_path = os.path.expandvars(os.path.expanduser(key_path))
|
||||
if not os.path.exists(key_path):
|
||||
raise SSHCASigningError(f"CA private key file not found: {key_path}")
|
||||
|
||||
with open(key_path, 'r') as f:
|
||||
return f.read()
|
||||
|
||||
def sign_certificate(
|
||||
self,
|
||||
signing_request: SSHCertificateSigningRequest,
|
||||
ca_private_key: Optional[str] = None,
|
||||
) -> SSHCertificateSigningResponse:
|
||||
"""Sign an SSH certificate.
|
||||
|
||||
Args:
|
||||
signing_request: SSHCertificateSigningRequest instance
|
||||
ca_private_key: CA private key in PEM format. If not provided,
|
||||
loaded from config (ca_key_path or SSH_CA_PRIVATE_KEY env var)
|
||||
|
||||
Returns:
|
||||
SSHCertificateSigningResponse with signed certificate
|
||||
|
||||
Raises:
|
||||
SSHCASigningError: If signing fails
|
||||
ValidationError: If request is invalid
|
||||
"""
|
||||
# Validate request
|
||||
errors = signing_request.validate()
|
||||
if errors:
|
||||
error_msg = "; ".join(errors)
|
||||
self.logger.error(f"Certificate signing validation failed: {error_msg}")
|
||||
raise ValidationError(f"Certificate signing validation failed: {error_msg}")
|
||||
|
||||
# Load CA key if not provided
|
||||
if ca_private_key is None:
|
||||
ca_private_key = self._load_ca_key_from_config()
|
||||
|
||||
try:
|
||||
# Parse CA private key
|
||||
try:
|
||||
ca_key = PrivateKey.from_string(ca_private_key)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to load CA private key: {str(e)}")
|
||||
raise SSHCASigningError(f"Invalid CA private key: {str(e)}")
|
||||
|
||||
# Parse user's public key
|
||||
try:
|
||||
user_pub_key = PublicKey.from_string(signing_request.ssh_public_key)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to parse user public key: {str(e)}")
|
||||
raise SSHCASigningError(f"Invalid user public key: {str(e)}")
|
||||
|
||||
# Create certificate
|
||||
certificate = SSHCertificate.create(
|
||||
subject_pubkey=user_pub_key,
|
||||
ca_privkey=ca_key,
|
||||
)
|
||||
|
||||
# Set validity period
|
||||
now = datetime.utcnow()
|
||||
expiry_hours = signing_request.expiry_hours or self.config.get_int('cert_validity_hours')
|
||||
valid_before = now + timedelta(hours=expiry_hours)
|
||||
|
||||
# Set certificate fields
|
||||
cert_type = 1 if signing_request.cert_type == "user" else 0
|
||||
|
||||
certificate.fields.cert_type = cert_type
|
||||
certificate.fields.key_id = signing_request.key_id
|
||||
certificate.fields.principals = signing_request.principals
|
||||
certificate.fields.valid_after = now
|
||||
certificate.fields.valid_before = valid_before
|
||||
|
||||
# Set extensions
|
||||
extensions = signing_request.extensions
|
||||
if not extensions and self.config.get_bool('extensions_enabled'):
|
||||
extensions = self.config.get_list('extensions')
|
||||
|
||||
certificate.fields.extensions = extensions or []
|
||||
certificate.fields.critical_options = signing_request.critical_options or {}
|
||||
|
||||
# Validate certificate before signing
|
||||
if not certificate.can_sign():
|
||||
raise SSHCASigningError("Certificate cannot be signed")
|
||||
|
||||
# Sign the certificate
|
||||
certificate.sign()
|
||||
|
||||
# Verify the certificate
|
||||
try:
|
||||
certificate.verify(ca_key.public_key, raise_on_error=True)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Certificate verification failed: {str(e)}")
|
||||
raise SSHCASigningError(f"Certificate verification failed: {str(e)}")
|
||||
|
||||
# Extract serial from certificate
|
||||
serial = str(certificate.fields.serial).split(":")[-1].strip() if hasattr(certificate.fields.serial, '__str__') else str(certificate.fields.serial)
|
||||
|
||||
# Build response
|
||||
cert_string = certificate.to_string()
|
||||
|
||||
self.logger.info(
|
||||
f"Successfully signed certificate: serial={serial}, "
|
||||
f"key_id={signing_request.key_id}, principals={signing_request.principals}"
|
||||
)
|
||||
|
||||
return SSHCertificateSigningResponse(
|
||||
certificate=cert_string,
|
||||
serial=serial,
|
||||
valid_after=now,
|
||||
valid_before=valid_before,
|
||||
principals=signing_request.principals,
|
||||
)
|
||||
|
||||
except (SSHCASigningError, ValidationError):
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.error(f"Unexpected error during certificate signing: {str(e)}", exc_info=True)
|
||||
raise SSHCASigningError(f"Error signing certificate: {str(e)}")
|
||||
|
||||
def verify_ca_key(self, ca_private_key: str) -> Dict[str, Any]:
|
||||
"""Verify a CA private key is valid and extract metadata.
|
||||
|
||||
Args:
|
||||
ca_private_key: CA private key in PEM format
|
||||
|
||||
Returns:
|
||||
Dictionary with key metadata (fingerprint, key_type, etc.)
|
||||
|
||||
Raises:
|
||||
SSHCASigningError: If key is invalid
|
||||
"""
|
||||
try:
|
||||
ca_key = PrivateKey.from_string(ca_private_key)
|
||||
pub_key = ca_key.public_key
|
||||
|
||||
# Compute fingerprint
|
||||
fingerprint = compute_ssh_fingerprint(pub_key.to_string())
|
||||
|
||||
# Get key type
|
||||
key_type = pub_key.keytype if hasattr(pub_key, 'keytype') else 'unknown'
|
||||
|
||||
return {
|
||||
'fingerprint': fingerprint,
|
||||
'key_type': key_type,
|
||||
'public_key': pub_key.to_string(),
|
||||
'valid': True,
|
||||
}
|
||||
except Exception as e:
|
||||
self.logger.error(f"CA key verification failed: {str(e)}")
|
||||
raise SSHCASigningError(f"Invalid CA key: {str(e)}")
|
||||
@@ -0,0 +1,373 @@
|
||||
"""SSH Key management service."""
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import subprocess
|
||||
import tempfile
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models import SSHKey, User
|
||||
from gatehouse_app.exceptions import (
|
||||
SSHKeyError,
|
||||
SSHKeyNotFoundError,
|
||||
SSHKeyAlreadyExistsError,
|
||||
SSHKeyNotVerifiedError,
|
||||
ValidationError,
|
||||
UserNotFoundError,
|
||||
)
|
||||
from gatehouse_app.utils.crypto import (
|
||||
compute_ssh_fingerprint,
|
||||
verify_ssh_key_format,
|
||||
extract_ssh_key_type,
|
||||
extract_ssh_key_comment,
|
||||
)
|
||||
from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SSHKeyService:
|
||||
"""Service for managing SSH keys."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize SSH key service."""
|
||||
self.config = get_ssh_ca_config()
|
||||
|
||||
def add_ssh_key(
|
||||
self,
|
||||
user_id: str,
|
||||
public_key: str,
|
||||
description: Optional[str] = None,
|
||||
) -> SSHKey:
|
||||
"""Add an SSH public key for a user.
|
||||
|
||||
Args:
|
||||
user_id: ID of the user
|
||||
public_key: SSH public key in OpenSSH format
|
||||
description: Optional description of the key
|
||||
|
||||
Returns:
|
||||
Created SSHKey instance
|
||||
|
||||
Raises:
|
||||
UserNotFoundError: If user doesn't exist
|
||||
SSHKeyError: If key format is invalid
|
||||
SSHKeyAlreadyExistsError: If key already exists
|
||||
"""
|
||||
# Verify user exists
|
||||
user = User.query.get(user_id)
|
||||
if not user:
|
||||
raise UserNotFoundError(f"User {user_id} not found")
|
||||
|
||||
# Validate key format
|
||||
if not verify_ssh_key_format(public_key):
|
||||
raise SSHKeyError("Invalid SSH public key format")
|
||||
|
||||
# Compute fingerprint
|
||||
try:
|
||||
fingerprint = compute_ssh_fingerprint(public_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to compute fingerprint: {str(e)}")
|
||||
raise SSHKeyError(f"Failed to compute key fingerprint: {str(e)}")
|
||||
|
||||
# Check for duplicate (including soft-deleted records — fingerprint is unique in DB)
|
||||
existing = SSHKey.query.filter_by(fingerprint=fingerprint).first()
|
||||
if existing:
|
||||
if existing.deleted_at is not None:
|
||||
# Restore the soft-deleted key: clear deleted_at and update fields
|
||||
existing.deleted_at = None
|
||||
existing.user_id = user_id
|
||||
existing.description = description or existing.description
|
||||
existing.verified = False
|
||||
existing.verified_at = None
|
||||
existing.verify_text = None
|
||||
existing.verify_text_created_at = None
|
||||
db.session.commit()
|
||||
logger.info(
|
||||
f"Restored soft-deleted SSH key for user {user_id}: "
|
||||
f"fingerprint={fingerprint}"
|
||||
)
|
||||
return existing
|
||||
raise SSHKeyAlreadyExistsError(
|
||||
f"SSH key with fingerprint {fingerprint} already exists"
|
||||
)
|
||||
|
||||
# Extract metadata
|
||||
key_type = extract_ssh_key_type(public_key)
|
||||
key_comment = extract_ssh_key_comment(public_key)
|
||||
|
||||
# Create SSH key record
|
||||
ssh_key = SSHKey(
|
||||
user_id=user_id,
|
||||
payload=public_key,
|
||||
fingerprint=fingerprint,
|
||||
description=description,
|
||||
key_type=key_type,
|
||||
key_comment=key_comment,
|
||||
verified=False,
|
||||
)
|
||||
|
||||
ssh_key.save()
|
||||
|
||||
logger.info(
|
||||
f"SSH key added for user {user_id}: "
|
||||
f"fingerprint={fingerprint}, type={key_type}"
|
||||
)
|
||||
|
||||
return ssh_key
|
||||
|
||||
def get_ssh_key(self, key_id: str) -> SSHKey:
|
||||
"""Get an SSH key by ID.
|
||||
|
||||
Args:
|
||||
key_id: SSH key ID
|
||||
|
||||
Returns:
|
||||
SSHKey instance
|
||||
|
||||
Raises:
|
||||
SSHKeyNotFoundError: If key not found
|
||||
"""
|
||||
key = SSHKey.query.filter_by(id=key_id, deleted_at=None).first()
|
||||
if not key:
|
||||
raise SSHKeyNotFoundError(f"SSH key {key_id} not found")
|
||||
return key
|
||||
|
||||
def get_user_ssh_keys(self, user_id: str) -> List[SSHKey]:
|
||||
"""Get all SSH keys for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
List of SSHKey instances
|
||||
"""
|
||||
return SSHKey.query.filter_by(user_id=user_id, deleted_at=None).all()
|
||||
|
||||
def get_user_verified_ssh_keys(self, user_id: str) -> List[SSHKey]:
|
||||
"""Get all verified SSH keys for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
List of verified SSHKey instances
|
||||
"""
|
||||
return SSHKey.query.filter_by(
|
||||
user_id=user_id,
|
||||
verified=True,
|
||||
deleted_at=None,
|
||||
).all()
|
||||
|
||||
def delete_ssh_key(self, key_id: str) -> None:
|
||||
"""Soft-delete an SSH key.
|
||||
|
||||
Args:
|
||||
key_id: SSH key ID
|
||||
|
||||
Raises:
|
||||
SSHKeyNotFoundError: If key not found
|
||||
"""
|
||||
key = self.get_ssh_key(key_id)
|
||||
key.delete()
|
||||
|
||||
logger.info(f"SSH key deleted: {key_id}")
|
||||
|
||||
def generate_verification_challenge(self, key_id: str) -> str:
|
||||
"""Generate a verification challenge for an SSH key.
|
||||
|
||||
The user must sign this challenge text with their private key
|
||||
to prove key ownership.
|
||||
|
||||
Args:
|
||||
key_id: SSH key ID
|
||||
|
||||
Returns:
|
||||
Verification challenge text
|
||||
|
||||
Raises:
|
||||
SSHKeyNotFoundError: If key not found
|
||||
"""
|
||||
key = self.get_ssh_key(key_id)
|
||||
|
||||
# Generate random challenge
|
||||
challenge = secrets.token_hex(32)
|
||||
challenge_text = f"Please sign this to verify SSH key ownership: {challenge}"
|
||||
|
||||
# Store challenge
|
||||
key.verify_text = challenge_text
|
||||
key.verify_text_created_at = datetime.utcnow()
|
||||
key.save()
|
||||
|
||||
logger.info(f"Generated verification challenge for SSH key {key_id}")
|
||||
|
||||
return challenge_text
|
||||
|
||||
def verify_ssh_key_ownership(
|
||||
self,
|
||||
key_id: str,
|
||||
signature: str,
|
||||
) -> bool:
|
||||
"""Verify SSH key ownership via signature.
|
||||
|
||||
The user must sign the verification challenge with their private key.
|
||||
We verify the signature using the public key.
|
||||
|
||||
Args:
|
||||
key_id: SSH key ID
|
||||
signature: Base64-encoded signature of the challenge
|
||||
|
||||
Returns:
|
||||
True if signature is valid
|
||||
|
||||
Raises:
|
||||
SSHKeyNotFoundError: If key not found
|
||||
SSHKeyNotVerifiedError: If challenge is stale or missing
|
||||
SSHKeyError: If verification fails
|
||||
"""
|
||||
key = self.get_ssh_key(key_id)
|
||||
|
||||
# Check if challenge exists and is not stale
|
||||
if not key.verify_text or not key.verify_text_created_at:
|
||||
raise SSHKeyNotVerifiedError("No verification challenge generated")
|
||||
|
||||
max_age = self.config.get_int('verification_challenge_max_age')
|
||||
age = datetime.utcnow() - key.verify_text_created_at
|
||||
if age.total_seconds() > (max_age * 3600):
|
||||
raise SSHKeyNotVerifiedError("Verification challenge has expired")
|
||||
|
||||
try:
|
||||
# Verify the SSH signature using ssh-keygen -Y verify.
|
||||
# The CLI signs the challenge with: ssh-keygen -Y sign -f <key> -n file <challenge>
|
||||
# We verify with: ssh-keygen -Y verify -f <allowed_signers> -I <identity> -n file -s <sig> < <message>
|
||||
#
|
||||
# allowed_signers format: "<identity> <keytype> <pubkey>"
|
||||
# We use the key fingerprint as the identity.
|
||||
|
||||
sig_bytes = base64.b64decode(signature)
|
||||
challenge_text = key.verify_text + "\n"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
allowed_signers_path = os.path.join(tmpdir, "allowed_signers")
|
||||
sig_path = os.path.join(tmpdir, "message.sig")
|
||||
message_path = os.path.join(tmpdir, "message.txt")
|
||||
|
||||
identity = key.fingerprint
|
||||
|
||||
# Write the allowed_signers file
|
||||
with open(allowed_signers_path, "w") as f:
|
||||
f.write(f"{identity} {key.payload}\n")
|
||||
|
||||
# Write the signature file
|
||||
with open(sig_path, "wb") as f:
|
||||
f.write(sig_bytes)
|
||||
|
||||
# Write the challenge message
|
||||
with open(message_path, "w") as f:
|
||||
f.write(challenge_text)
|
||||
|
||||
result = subprocess.run(
|
||||
[
|
||||
"ssh-keygen", "-Y", "verify",
|
||||
"-f", allowed_signers_path,
|
||||
"-I", identity,
|
||||
"-n", "file",
|
||||
"-s", sig_path,
|
||||
],
|
||||
stdin=open(message_path, "rb"),
|
||||
capture_output=True,
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
stderr = result.stderr.decode(errors="replace").strip()
|
||||
logger.warning(f"SSH signature verification failed for key {key_id}: {stderr}")
|
||||
raise SSHKeyError(f"Signature verification failed: {stderr}")
|
||||
|
||||
key.mark_verified()
|
||||
logger.info(f"SSH key verified: {key_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"SSH key verification failed: {str(e)}")
|
||||
raise SSHKeyError(f"Signature verification failed: {str(e)}")
|
||||
|
||||
def get_key_fingerprint(self, key_id: str) -> str:
|
||||
"""Get the fingerprint of an SSH key.
|
||||
|
||||
Args:
|
||||
key_id: SSH key ID
|
||||
|
||||
Returns:
|
||||
Fingerprint string
|
||||
|
||||
Raises:
|
||||
SSHKeyNotFoundError: If key not found
|
||||
"""
|
||||
key = self.get_ssh_key(key_id)
|
||||
return key.fingerprint
|
||||
|
||||
def update_ssh_key_description(self, key_id: str, description: str) -> SSHKey:
|
||||
"""Update the description of an SSH key.
|
||||
|
||||
Args:
|
||||
key_id: SSH key ID
|
||||
description: New description
|
||||
|
||||
Returns:
|
||||
Updated SSHKey instance
|
||||
|
||||
Raises:
|
||||
SSHKeyNotFoundError: If key not found
|
||||
"""
|
||||
key = self.get_ssh_key(key_id)
|
||||
key.description = description
|
||||
key.save()
|
||||
|
||||
return key
|
||||
|
||||
def cleanup_expired_challenges(self) -> int:
|
||||
"""Clean up expired verification challenges.
|
||||
|
||||
Returns:
|
||||
Number of challenges cleaned
|
||||
"""
|
||||
max_age = self.config.get_int('verification_challenge_max_age')
|
||||
threshold = datetime.utcnow() - timedelta(hours=max_age)
|
||||
|
||||
expired = SSHKey.query.filter(
|
||||
SSHKey.verify_text_created_at < threshold,
|
||||
SSHKey.verify_text_created_at.isnot(None),
|
||||
SSHKey.deleted_at.is_(None),
|
||||
).update({"verify_text": None, "verify_text_created_at": None})
|
||||
|
||||
db.session.commit()
|
||||
|
||||
logger.info(f"Cleaned up {expired} expired verification challenges")
|
||||
return expired
|
||||
|
||||
def cleanup_unverified_keys(self) -> int:
|
||||
"""Delete unverified SSH keys older than configured days.
|
||||
|
||||
Returns:
|
||||
Number of keys deleted
|
||||
"""
|
||||
days = self.config.get_int('auto_delete_unverified_days')
|
||||
threshold = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
old_unverified = SSHKey.query.filter(
|
||||
SSHKey.verified == False,
|
||||
SSHKey.created_at < threshold,
|
||||
SSHKey.deleted_at.is_(None),
|
||||
).all()
|
||||
|
||||
count = 0
|
||||
for key in old_unverified:
|
||||
key.delete()
|
||||
count += 1
|
||||
|
||||
logger.info(f"Deleted {count} unverified SSH keys older than {days} days")
|
||||
return count
|
||||
@@ -12,6 +12,16 @@ class UserStatus(str, Enum):
|
||||
COMPLIANCE_SUSPENDED = "compliance_suspended"
|
||||
|
||||
|
||||
class Role(str, Enum):
|
||||
"""Generic role definitions (hierarchy: Admin > Manager > Member > Viewer > Guest)."""
|
||||
|
||||
ADMIN = "admin"
|
||||
MANAGER = "manager"
|
||||
MEMBER = "member"
|
||||
VIEWER = "viewer"
|
||||
GUEST = "guest"
|
||||
|
||||
|
||||
class OrganizationRole(str, Enum):
|
||||
"""Organization member roles."""
|
||||
|
||||
@@ -105,6 +115,37 @@ class AuditAction(str, Enum):
|
||||
EXTERNAL_AUTH_CONFIG_UPDATE = "external_auth.config.update"
|
||||
EXTERNAL_AUTH_CONFIG_DELETE = "external_auth.config.delete"
|
||||
|
||||
# SSH Key and Certificate actions
|
||||
SSH_KEY_ADDED = "ssh.key.added"
|
||||
SSH_KEY_VERIFIED = "ssh.key.verified"
|
||||
SSH_KEY_DELETED = "ssh.key.deleted"
|
||||
SSH_KEY_VALIDATION_FAILED = "ssh.key.validation.failed"
|
||||
SSH_CERT_REQUESTED = "ssh.cert.requested"
|
||||
SSH_CERT_ISSUED = "ssh.cert.issued"
|
||||
SSH_CERT_FAILED = "ssh.cert.failed"
|
||||
SSH_CERT_REVOKED = "ssh.cert.revoked"
|
||||
SSH_CERT_EXPIRED = "ssh.cert.expired"
|
||||
|
||||
# CA actions
|
||||
CA_CREATED = "ca.created"
|
||||
CA_UPDATED = "ca.updated"
|
||||
CA_DELETED = "ca.deleted"
|
||||
CA_KEY_ROTATED = "ca.key.rotated"
|
||||
|
||||
# Principal actions
|
||||
PRINCIPAL_CREATED = "principal.created"
|
||||
PRINCIPAL_UPDATED = "principal.updated"
|
||||
PRINCIPAL_DELETED = "principal.deleted"
|
||||
PRINCIPAL_MEMBER_ADDED = "principal.member.added"
|
||||
PRINCIPAL_MEMBER_REMOVED = "principal.member.removed"
|
||||
|
||||
# Department actions
|
||||
DEPARTMENT_CREATED = "department.created"
|
||||
DEPARTMENT_UPDATED = "department.updated"
|
||||
DEPARTMENT_DELETED = "department.deleted"
|
||||
DEPARTMENT_MEMBER_ADDED = "department.member.added"
|
||||
DEPARTMENT_MEMBER_REMOVED = "department.member.removed"
|
||||
|
||||
|
||||
class OIDCGrantType(str, Enum):
|
||||
"""OIDC grant types."""
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
"""Cryptographic utilities for SSH operations."""
|
||||
import hashlib
|
||||
import base64
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def compute_ssh_fingerprint(public_key_str: str, hash_algorithm: str = "sha256") -> str:
|
||||
"""Compute the fingerprint of an SSH public key.
|
||||
|
||||
Args:
|
||||
public_key_str: SSH public key in OpenSSH format
|
||||
hash_algorithm: Hash algorithm to use (sha256, sha1, md5)
|
||||
|
||||
Returns:
|
||||
Fingerprint string in the format "algorithm:hex_digest"
|
||||
|
||||
Example:
|
||||
>>> key = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKp2..."
|
||||
>>> fp = compute_ssh_fingerprint(key)
|
||||
>>> print(fp)
|
||||
sha256:Kb+...
|
||||
"""
|
||||
if not public_key_str:
|
||||
raise ValueError("Public key string is empty")
|
||||
|
||||
# Parse OpenSSH format: "ssh-ed25519 <base64> [comment]"
|
||||
parts = public_key_str.strip().split()
|
||||
if len(parts) < 2:
|
||||
raise ValueError("Invalid OpenSSH public key format")
|
||||
|
||||
try:
|
||||
# The base64-encoded key is the second part
|
||||
key_bytes = base64.b64decode(parts[1])
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to decode public key: {str(e)}")
|
||||
|
||||
# Compute hash
|
||||
if hash_algorithm == "sha256":
|
||||
digest = hashlib.sha256(key_bytes).digest()
|
||||
# SSH format uses base64 encoding without padding
|
||||
fingerprint = base64.b64encode(digest).decode().rstrip('=')
|
||||
elif hash_algorithm == "sha1":
|
||||
digest = hashlib.sha1(key_bytes).hexdigest()
|
||||
fingerprint = digest
|
||||
elif hash_algorithm == "md5":
|
||||
digest = hashlib.md5(key_bytes).hexdigest()
|
||||
# Format as colons
|
||||
fingerprint = ':'.join(digest[i:i+2] for i in range(0, len(digest), 2))
|
||||
else:
|
||||
raise ValueError(f"Unsupported hash algorithm: {hash_algorithm}")
|
||||
|
||||
return f"{hash_algorithm}:{fingerprint}"
|
||||
|
||||
|
||||
def verify_ssh_key_format(public_key_str: str) -> bool:
|
||||
"""Verify that a string is in valid OpenSSH public key format.
|
||||
|
||||
Args:
|
||||
public_key_str: Potential SSH public key
|
||||
|
||||
Returns:
|
||||
True if valid OpenSSH format, False otherwise
|
||||
"""
|
||||
if not public_key_str or not isinstance(public_key_str, str):
|
||||
return False
|
||||
|
||||
parts = public_key_str.strip().split()
|
||||
|
||||
# Must have at least key type and key material
|
||||
if len(parts) < 2:
|
||||
return False
|
||||
|
||||
key_type = parts[0]
|
||||
|
||||
# Valid key types
|
||||
valid_types = [
|
||||
'ssh-rsa',
|
||||
'ssh-ed25519',
|
||||
'ecdsa-sha2-nistp256',
|
||||
'ecdsa-sha2-nistp384',
|
||||
'ecdsa-sha2-nistp521',
|
||||
'ssh-dss',
|
||||
]
|
||||
|
||||
if key_type not in valid_types:
|
||||
return False
|
||||
|
||||
# Try to decode base64
|
||||
try:
|
||||
base64.b64decode(parts[1])
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def extract_ssh_key_type(public_key_str: str) -> Optional[str]:
|
||||
"""Extract the key type from an OpenSSH public key.
|
||||
|
||||
Args:
|
||||
public_key_str: SSH public key in OpenSSH format
|
||||
|
||||
Returns:
|
||||
Key type (e.g., "ssh-ed25519") or None if invalid
|
||||
"""
|
||||
if not verify_ssh_key_format(public_key_str):
|
||||
return None
|
||||
|
||||
return public_key_str.strip().split()[0]
|
||||
|
||||
|
||||
def extract_ssh_key_comment(public_key_str: str) -> Optional[str]:
|
||||
"""Extract the comment from an OpenSSH public key.
|
||||
|
||||
Args:
|
||||
public_key_str: SSH public key in OpenSSH format
|
||||
|
||||
Returns:
|
||||
Comment string or None if not present
|
||||
"""
|
||||
if not verify_ssh_key_format(public_key_str):
|
||||
return None
|
||||
|
||||
parts = public_key_str.strip().split()
|
||||
if len(parts) >= 3:
|
||||
# Everything after the second part is the comment
|
||||
return ' '.join(parts[2:])
|
||||
|
||||
return None
|
||||
Reference in New Issue
Block a user