Feat: Added CA-merged with Securid-Principals, Depart, Client-CLI

This commit is contained in:
2026-02-27 21:59:01 +05:45
parent 92fd57447d
commit b2212ab4d6
29 changed files with 3718 additions and 53 deletions
+110 -34
View File
@@ -2,6 +2,7 @@
import base64
from datetime import datetime
import os
import sys
import webbrowser
import requests
import argparse
@@ -22,13 +23,12 @@ import base64
load_dotenv()
# Get the API_URL from the environment variables
SIGN_URL = os.getenv("SIGN_URL", "http://localhost:1234")
SIGN_URL = os.getenv("SIGN_URL", "http://localhost:5000")
LISTENER_HOST_NAME = "127.0.0.1"
LISTENER_SERVER_PORT = 8250
CA_API_HOST = "127.0.0.1"
CA_SERVER_PORT = 1234
CACHE_FILE = 'token_cache.json' ###need to change it to secure location and permissions if used in production
CERT_FILE_PATH = "/tmp/ssl-cert"
CACHE_FILE = os.path.expanduser('~/.gatehouse/token_cache.json')
os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True)
CERT_FILE_PATH = "/tmp/ssh-cert"
CHALLENGE_FILE_PATH = "/tmp/challenge.txt"
CHALLENGE_SIG_FILE_PATH = "/tmp/challenge.txt.sig"
@@ -116,7 +116,7 @@ def decode_and_validate_token(token):
return True # Token is valid
except Exception as e:
logger.error(f"Token validation failed: {e}")
logger.debug(f"Token validation failed: {e}")
return False
def request_token():
@@ -129,11 +129,26 @@ def request_token():
logger.debug("Token loaded from cache: %s", token)
# Validate the cached token, if it exists
if token and decode_and_validate_token(token):
if token:
try:
if decode_and_validate_token(token):
logger.info("Cached token is valid. Using cached token.")
return token
logger.info("No valid cached token found, proceeding to request a new token.")
except Exception:
pass
# Try opaque token via /auth/me
try:
r = requests.get(
f"{SIGN_URL}/api/v1/auth/me",
headers={"Authorization": f"Bearer {token}"},
timeout=5,
)
if r.status_code == 200:
logger.info("Cached session token is valid. Using cached token.")
return token
except Exception:
pass
logger.info("Cached token is expired or invalid, requesting a new token.")
token = ""
# Prepare the redirect URL for the token request
@@ -141,7 +156,7 @@ def request_token():
logger.info("Redirect URL: %s", redirect_url)
# Construct the token request URL
token_url = f"{SIGN_URL}/token_please?redirect_url={redirect_url}"
token_url = f"{SIGN_URL}/api/v1/token_please?redirect_url={redirect_url}"
logger.info("Token request URL: %s", token_url)
# Start the web server to handle the token response
@@ -168,10 +183,10 @@ def get_activated_ssh_key():
'Authorization': f'Bearer {token}',
'Content-Type': 'application/json'
}
response = requests.get(f"{SIGN_URL}/api/ssh-keys", headers=headers)
response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys", headers=headers)
if response.status_code == 200:
keys = response.json().get('ssh_keys', [])
keys = response.json().get('keys', [])
verified_keys = [key for key in keys if key['verified']]
if not verified_keys:
@@ -179,8 +194,19 @@ def get_activated_ssh_key():
exit(1)
if len(verified_keys) > 1:
logger.error("Multiple verified SSH keys found. Please specify CERT_ID.")
exit(1)
# If running interactively, let the user pick; otherwise use the most recently added key
if sys.stdout.isatty():
print("\nMultiple verified SSH keys found. Please choose one:")
for i, k in enumerate(verified_keys):
print(f" [{i+1}] {k['id'][:8]}... fingerprint={k.get('fingerprint','?')} name={k.get('key_comment','?')}")
try:
choice = int(input("Enter number: ").strip()) - 1
if 0 <= choice < len(verified_keys):
return verified_keys[choice]['id']
except (ValueError, EOFError):
pass
logger.info("Multiple verified SSH keys found; using the most recently added one.")
verified_keys.sort(key=lambda k: k.get('created_at', ''), reverse=True)
return verified_keys[0]['id']
@@ -193,26 +219,35 @@ def get_activated_ssh_key():
exit(1)
def request_certificate():
def request_certificate(principals=None):
CERT_ID = os.getenv("CERT_ID") or get_activated_ssh_key()
if not principals:
env_principals = os.getenv("PRINCIPALS")
if env_principals:
principals = [p.strip() for p in env_principals.split(',')]
else:
principals = [os.getlogin()]
headers = {
'content-type': 'application/json',
"Authorization": "bearer " + token
}
payload = {
'cert_id': CERT_ID
'cert_id': CERT_ID,
'principals': principals,
}
try:
response = requests.post(f"{SIGN_URL}/sign_cert", json=payload, headers=headers)
response = requests.post(f"{SIGN_URL}/api/v1/ssh/sign", json=payload, headers=headers)
if response.status_code == 200:
if response.status_code == 201:
json_result = response.json()
with open(CERT_FILE_PATH, 'w') as f:
f.write(json_result['certificate'])
logger.info(f"Certificate signed successfully, located at {CERT_FILE_PATH}")
logger.info(f"Valid for principals: {', '.join(json_result.get('principals', principals))}")
logger.info("You can login to your destination server with the following command")
logger.info(f"\tssh user@server -o CertificateFile={CERT_FILE_PATH}")
else:
@@ -238,14 +273,14 @@ def generate_and_sign_challenge(ssh_key_file,key_id):
# Send the POST request
response = requests.get(
f"http://{CA_API_HOST}:{CA_SERVER_PORT}/api/ssh-key/{key_id}/validationData",
f"{SIGN_URL}/api/v1/ssh/keys/{key_id}/verify",
headers=headers
)
if response.status_code!=200:
logger.error(f"Server returned unexpected code {response.status_code}")
return False
challenge_text=response.json()['validationText']+"\n"
challenge_text=response.json().get('challenge_text', response.json().get('validationText', ''))+"\n"
except Exception as e:
logger.error(f"Unable to fetch SSH Key validation data {e}")
@@ -291,7 +326,7 @@ def submit_signature_validation(signature, key_id):
# Send the POST request
response = requests.post(
f"http://{CA_API_HOST}:{CA_SERVER_PORT}/api/ssh-key/{key_id}/validate",
f"{SIGN_URL}/api/v1/ssh/keys/{key_id}/verify",
headers=headers,
json=payload
)
@@ -317,12 +352,12 @@ def remove_ssh_key(key_id=None):
}
# List keys first
response = requests.get(f"{SIGN_URL}/api/ssh-keys", headers=headers)
response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys", headers=headers)
if response.status_code != 200:
logger.error(f"Failed to list SSH keys: {response.status_code} - {response.text}")
exit(1)
keys = response.json().get('ssh_keys', [])
keys = response.json().get('keys', [])
if not keys:
logger.info("No SSH keys found for your user.")
return
@@ -359,7 +394,7 @@ def remove_ssh_key(key_id=None):
exit(1)
for k in keys_to_delete:
del_response = requests.delete(f"{SIGN_URL}/api/ssh-key/{k['id']}", headers=headers)
del_response = requests.delete(f"{SIGN_URL}/api/v1/ssh/keys/{k['id']}", headers=headers)
if del_response.status_code == 200:
logger.info(f"Key {k['id']} removed successfully.")
else:
@@ -383,18 +418,38 @@ def add_ssh_key(ssh_key_file):
'Content-Type': 'application/json'
}
ssh_key = ssh_key_file.read().decode('utf-8')
if hasattr(ssh_key_file, 'read'):
# File object (e.g. argparse.FileType('rb'))
key_bytes = ssh_key_file.read()
key_path = ssh_key_file.name
elif isinstance(ssh_key_file, bytes):
key_bytes = ssh_key_file
key_path = None
else:
# String path
key_path = str(ssh_key_file)
with open(key_path, 'rb') as f:
key_bytes = f.read()
ssh_key = key_bytes.decode('utf-8').strip()
payload = {
'description': 'Added via gatehouse CLI tool',
'key': ssh_key
}
response = requests.post(f"{SIGN_URL}/api/ssh-key/add", json=payload, headers=headers)
response = requests.post(f"{SIGN_URL}/api/v1/ssh/keys", json=payload, headers=headers)
if response.status_code == 200:
ssh_key_id=response.json()['key_id']
if response.status_code == 201:
ssh_key_id=response.json()['id']
logger.info(f"SSH key {ssh_key_id} added successfully")
generate_and_sign_challenge(ssh_key_file.name,ssh_key_id)
if key_path:
# Strip .pub suffix to get the private key path for signing
private_key_path = key_path[:-4] if key_path.endswith('.pub') else key_path
generate_and_sign_challenge(private_key_path, ssh_key_id)
else:
logger.warning("No key file path available — skipping auto-verification. "
"Run with -k <path> to enable automatic key verification.")
else:
logger.error(f"Failed to add SSH key: {response.status_code} - {response.text}")
@@ -431,13 +486,15 @@ if __name__ == "__main__":
parser.add_argument("-a", "--add-key", action='store_true', default=False, help="Add SSH key to the server")
parser.add_argument("-c", "--check-cert", action='store_true', default=False, help="Check the certificate, if it's valid exit 0, if it's invalid exit 1")
parser.add_argument("-r", "--request-cert", action='store_true', default=False, help="Request that gatehouse sign a new certificate for you based on an SSH public key on file in your profile")
parser.add_argument("--principals", nargs='+', metavar='PRINCIPAL', help="Unix usernames for the certificate (default: current OS user)")
parser.add_argument("--clear-cache", action='store_true', default=False, help="Remove the cached authentication token")
parser.add_argument("--remove-key", nargs='?', const='', metavar='KEY_ID', help="Remove an SSH key from your profile. Omit KEY_ID to pick interactively.")
parser.add_argument("--list-keys", action='store_true', default=False, help="List SSH keys in your profile")
args = parser.parse_args()
# Ensure that one of --check-cert, --request-cert, or --add-key is provided
if not (args.check_cert or args.request_cert or args.add_key or args.clear_cache or args.remove_key is not None):
parser.error("At least one of --check-cert, --request-cert, --add-key, --validate-key, or --clear-cache must be provided.")
if not (args.check_cert or args.request_cert or args.add_key or args.clear_cache
or args.remove_key is not None or args.list_keys):
parser.error("At least one of --check-cert, --request-cert, --add-key, --list-keys, --remove-key, or --clear-cache must be provided.")
# Retrieve SSH key from environment variables if not provided via CLI
@@ -456,6 +513,25 @@ if __name__ == "__main__":
remove_ssh_key(args.remove_key if args.remove_key else None)
exit(0)
if args.list_keys:
request_token()
response = requests.get(
f"{SIGN_URL}/api/v1/ssh/keys",
headers={"Authorization": f"Bearer {token}"},
)
if response.status_code == 200:
data = response.json()
keys = data.get('keys', [])
if not keys:
print("No SSH keys found in your profile.")
else:
for k in keys:
verified = "✓ verified" if k.get('verified') else "✗ unverified"
print(f" {k['id']} {verified} {k.get('description', '')} (added {k['created_at'][:10]})")
else:
logger.error(f"Failed to list SSH keys: {response.status_code} - {response.text}")
exit(0)
if args.add_key:
request_token()
@@ -476,10 +552,10 @@ if __name__ == "__main__":
if args.force:
request_token()
logger.info("Forcing renewal of certificate")
request_certificate()
request_certificate(principals=args.principals)
if checkCert() == 1:
request_token()
request_certificate()
request_certificate(principals=args.principals)
exit(0)
+114
View File
@@ -0,0 +1,114 @@
[default]
# Certificate validity period (in hours)
# Default: 1 hour
cert_validity_hours=1
# Maximum certificate validity allowed (in hours)
# Default: 24 hours
# Prevents users from requesting certificates valid longer than this
max_cert_validity_hours=24
# Certificate Request Limits
# Maximum number of certificates per user
max_certs_per_user=100
# Certificate revocation list (CRL) configuration
crl_enabled=true
# CRL endpoint URL - set to your domain where CRL is served
crl_endpoint=https://ca.example.com/crl
# CRL refresh interval (in hours)
crl_refresh_hours=24
# CA Key Configuration
# Default key type for new CAs (ed25519, rsa, ecdsa)
default_key_type=ed25519
# RSA key size (if using RSA)
rsa_key_bits=4096
# Private key encryption
# Method: kms (AWS Key Management Service) or local (for development only)
private_key_encryption=kms
# AWS KMS Key ID (only used if private_key_encryption=kms)
aws_kms_key_id=${SSH_CA_KMS_KEY_ID}
# SSH Certificate Extensions
# Default extensions to add to certificates
extensions_enabled=true
extensions=permit-X11-forwarding,permit-agent-forwarding,permit-pty,permit-port-forwarding,permit-user-rc
# Critical Options
# Critical options to add to certificates (rarely needed)
critical_options_enabled=false
# Certificate Field Limits
# Maximum number of principals per certificate (SSH limitation is 256)
max_principals_per_cert=256
# Maximum length for key_id field
max_key_id_length=255
# Logging Configuration
# Log level for SSH CA operations (DEBUG, INFO, WARNING, ERROR)
log_level=INFO
# Audit Configuration
# Log all certificate signing operations
audit_enabled=true
# Security Configuration
# Require SSH key verification before issuing certificates
require_key_verification=true
# Verification challenge max age (in hours)
verification_challenge_max_age=24
# Rate limiting for certificate signing
# Max certificates per minute per user
rate_limit_certs_per_minute=5
# Request timeout (in seconds)
request_timeout=30
# Cleanup Configuration
# Automatically delete unverified SSH keys after this many days
auto_delete_unverified_days=30
# Archive expired certificates after this many days
archive_expired_days=365
# CLI OAuth Configuration (for secuird-cli.py compatibility)
# OAuth token endpoint for CLI clients
oauth_token_endpoint=/api/v1/oauth2/token
# OAuth userinfo endpoint for CLI clients
oauth_userinfo_endpoint=/api/v1/oauth2/userinfo
[development]
# Override settings for development environment
private_key_encryption=local
ca_key_path=/home/james/cory/secuird/certs/ca-users
log_level=DEBUG
cert_validity_hours=24
max_cert_validity_hours=720
rate_limit_certs_per_minute=100
require_key_verification=false
[production]
# Override settings for production environment
private_key_encryption=kms
log_level=WARNING
cert_validity_hours=1
max_cert_validity_hours=24
rate_limit_certs_per_minute=5
require_key_verification=true
[testing]
# Override settings for testing environment
private_key_encryption=local
log_level=DEBUG
cert_validity_hours=1
max_cert_validity_hours=24
rate_limit_certs_per_minute=100
require_key_verification=true
audit_enabled=false
+6
View File
@@ -2,6 +2,9 @@
import os
import logging
from dotenv import load_dotenv
load_dotenv(dotenv_path=os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), '.env'))
# Test debug logging - this should appear when running `flask run --debug`
_root_logger = logging.getLogger(__name__)
_root_logger.debug("[TEST] Debug logging is working!")
@@ -239,3 +242,6 @@ def initialize_oidc_jwks(app):
app.logger.info(f"[OIDC] Signing key initialized: kid={signing_key.kid}")
except Exception as e:
app.logger.error(f"[OIDC] Failed to initialize JWKS: {e}")
# Create default app instance for gunicorn/wsgi
app = create_app()
+3 -1
View File
@@ -5,4 +5,6 @@ from flask import Blueprint
api_v1_bp = Blueprint("api_v1", __name__)
# Import route modules to register them
from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth
from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth, departments, principals, ssh
api_v1_bp.register_blueprint(ssh.ssh_bp)
+130 -6
View File
@@ -46,6 +46,33 @@ def _pop_oidc_bridge(oauth_state: str) -> str | None:
pass
return None
def _store_cli_redirect(oauth_state: str, redirect_url: str) -> None:
"""Store CLI redirect_url keyed by OAuth state (for /token_please flow)."""
try:
import gatehouse_app.extensions as _ext
rc = _ext.redis_client
if rc is not None:
rc.setex(f"oauth_cli_redirect:{oauth_state}", _OAUTH_BRIDGE_TTL, redirect_url)
except Exception:
pass
def _pop_cli_redirect(oauth_state: str) -> str | None:
"""Retrieve and delete CLI redirect_url for the given OAuth state."""
try:
import gatehouse_app.extensions as _ext
rc = _ext.redis_client
if rc is not None:
key = f"oauth_cli_redirect:{oauth_state}"
val = rc.get(key)
if val:
rc.delete(key)
return val.decode() if isinstance(val, bytes) else val
except Exception:
pass
return None
logger = logging.getLogger(__name__)
@@ -69,6 +96,71 @@ def get_provider_type(provider: str) -> AuthMethodType:
return PROVIDER_TYPE_MAP[provider_lower]
@api_v1_bp.route("/token_please", methods=["GET"])
def token_please():
"""
CLI token acquisition endpoint.
Initiates an OAuth login flow and, on success, redirects the user's browser
to the CLI's local callback server (redirect_url) with the session token
appended, e.g.: http://127.0.0.1:8250/?token=<SESSION_TOKEN>
This endpoint is designed for CLI clients that:
1. Start a local HTTP server on LISTENER_SERVER_PORT (e.g. 8250)
2. Open a browser to /api/v1/token_please?redirect_url=http://127.0.0.1:8250/?token=
3. Wait for the browser to POST the token back to their local server
Query parameters:
redirect_url: Local callback URL where the token will be appended
provider: OAuth provider to use (default: 'google')
"""
from urllib.parse import urlencode
from flask import current_app, redirect as flask_redirect
redirect_url = request.args.get("redirect_url", "").strip()
provider = request.args.get("provider", "google").lower()
if not redirect_url:
return api_response(
success=False,
message="redirect_url query parameter is required",
status=400,
error_type="MISSING_REDIRECT_URL",
)
# Validate redirect_url is localhost/127.0.0.1 (security: prevent open redirect)
from urllib.parse import urlparse as _urlparse
parsed = _urlparse(redirect_url)
if parsed.hostname not in ("localhost", "127.0.0.1"):
return api_response(
success=False,
message="redirect_url must point to localhost",
status=400,
error_type="INVALID_REDIRECT_URL",
)
try:
provider_type = get_provider_type(provider)
auth_url, state = OAuthFlowService.initiate_login_flow(
provider_type=provider_type,
organization_id=None,
redirect_uri=None,
)
except (OAuthFlowError, ExternalAuthError) as e:
return api_response(
success=False,
message=getattr(e, "message", str(e)),
status=getattr(e, "status_code", 400),
error_type=getattr(e, "error_type", "OAUTH_ERROR"),
)
# Store the CLI redirect URL so the callback can use it
_store_cli_redirect(state, redirect_url)
logger.info(f"CLI token_please: provider={provider}, redirect_url={redirect_url}, redirecting to OAuth")
return flask_redirect(auth_url, code=302)
# =============================================================================
# Provider Configuration Endpoints (Admin)
# =============================================================================
@@ -575,8 +667,6 @@ def initiate_oauth_authorize(provider: str):
"state": "state_token"
}
"""
provider_type = get_provider_type(provider)
# Get query parameters - organization_id is now optional
flow = request.args.get("flow", "login")
redirect_uri = request.args.get("redirect_uri")
@@ -592,7 +682,7 @@ def initiate_oauth_authorize(provider: str):
)
try:
# Initiate flow - organization_id is now optional
provider_type = get_provider_type(provider)
if flow == "login":
auth_url, state = OAuthFlowService.initiate_login_flow(
provider_type=provider_type,
@@ -626,6 +716,13 @@ def initiate_oauth_authorize(provider: str):
status=e.status_code,
error_type=e.error_type,
)
except ExternalAuthError as e:
return api_response(
success=False,
message=e.message,
status=e.status_code,
error_type=e.error_type,
)
@api_v1_bp.route("/auth/external/<provider>/callback", methods=["GET"])
@@ -666,8 +763,19 @@ def handle_oauth_callback(provider: str):
frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:8080")
frontend_callback = f"{frontend_url}/oauth/callback"
# Check if this is a CLI /token_please flow — retrieve stored redirect_url
cli_redirect_url = _pop_cli_redirect(state) if state else None
def redirect_error(message: str, error_type: str = "OAUTH_ERROR"):
"""Redirect to frontend with error params."""
"""Redirect to frontend (or CLI) with error params."""
if cli_redirect_url:
# CLI flow: return a plain error page instead of redirecting back
from flask import make_response
return make_response(
f"<html><body><h2>Authentication Error</h2><p>{message}</p>"
f"<p>You may close this window.</p></body></html>",
400,
)
params = {"error": message, "error_type": error_type}
if state:
params["state"] = state
@@ -706,8 +814,11 @@ def handle_oauth_callback(provider: str):
# Recover oidc_session_id if this was triggered from an OIDC bridge flow
oidc_session_id = _pop_oidc_bridge(state)
# Organization selection / creation flows are not supported in CLI mode
# (fall through to token redirect with whatever session we have)
# Organization selection needed (user belongs to multiple orgs)
if result.get("requires_org_selection"):
if result.get("requires_org_selection") and not cli_redirect_url:
import json
orgs = json.dumps(result.get("available_organizations", []))
params = {
@@ -722,7 +833,7 @@ def handle_oauth_callback(provider: str):
return flask_redirect(f"{frontend_callback}?{urlencode(params)}", code=302)
# Organization creation needed (new user via OAuth with no org)
if result.get("requires_org_creation"):
if result.get("requires_org_creation") and not cli_redirect_url:
params = {
"requires_org_creation": "1",
"state": result["state"],
@@ -751,6 +862,19 @@ def handle_oauth_callback(provider: str):
user_info = result.get("user", {})
if user_info.get("email"):
params["email"] = user_info["email"]
# ── CLI /token_please flow: redirect to the CLI's local callback ─────
if cli_redirect_url:
# The CLI expects: http://127.0.0.1:8250/?token=<TOKEN>
# cli_redirect_url already ends with "token=" so just append the value
cli_final_url = cli_redirect_url + token
logger.info(
f"CLI token_please success: provider={provider}, user={user_info.get('email')}, "
f"redirecting to CLI callback"
)
return flask_redirect(cli_final_url, code=302)
# ── Frontend flow ─────────────────────────────────────────────────────
# Pass oidc_session_id through so the frontend can complete the OIDC flow
if oidc_session_id:
params["oidc_session_id"] = oidc_session_id
+554 -2
View File
@@ -13,8 +13,6 @@ from gatehouse_app.schemas.organization_schema import (
from gatehouse_app.services.organization_service import OrganizationService
from gatehouse_app.services.user_service import UserService
from gatehouse_app.utils.constants import OrganizationRole
########jb- need to implement departs, principals
@api_v1_bp.route("/organizations", methods=["POST"])
@login_required
@full_access_required
@@ -378,3 +376,557 @@ def update_member_role(org_id, user_id):
error_type="VALIDATION_ERROR",
error_details=e.messages,
)
@api_v1_bp.route("/organizations/<org_id>/audit-logs", methods=["GET"])
@login_required
@full_access_required
def get_organization_audit_logs(org_id):
"""
Get audit logs for an organization.
Query params:
page: Page number (default 1)
per_page: Results per page (default 50, max 200)
action: Filter by action type
Returns:
200: List of audit log entries
401: Not authenticated
403: Not a member / insufficient permissions
404: Organization not found
"""
from gatehouse_app.models.audit_log import AuditLog
# Ensure org exists and user is a member (full_access_required handles this)
OrganizationService.get_organization_by_id(org_id)
page = int(request.args.get("page", 1))
per_page = min(int(request.args.get("per_page", 50)), 200)
action_filter = request.args.get("action")
query = AuditLog.query.filter_by(organization_id=org_id)
if action_filter:
query = query.filter_by(action=action_filter)
query = query.order_by(AuditLog.created_at.desc())
total = query.count()
logs = query.offset((page - 1) * per_page).limit(per_page).all()
def log_to_dict(log):
return {
"id": log.id,
"action": log.action.value if log.action else None,
"user_id": log.user_id,
"user_email": log.user.email if log.user else None,
"user": {"id": log.user.id, "email": log.user.email, "full_name": log.user.full_name} if log.user else None,
"organization_id": log.organization_id,
"resource_type": log.resource_type,
"resource_id": log.resource_id,
"ip_address": log.ip_address,
"user_agent": log.user_agent,
"request_id": log.request_id,
"description": log.description,
"success": log.success,
"error_message": log.error_message,
"metadata": log.extra_data,
"created_at": log.created_at.isoformat() if log.created_at else None,
"updated_at": log.updated_at.isoformat() if log.updated_at else None,
}
return api_response(
data={
"audit_logs": [log_to_dict(log) for log in logs],
"count": total,
"page": page,
"per_page": per_page,
"pages": (total + per_page - 1) // per_page,
},
message="Audit logs retrieved successfully",
)
# ============================================================================
# Organization Invite Tokens
# ============================================================================
@api_v1_bp.route("/organizations/<org_id>/invites", methods=["POST"])
@login_required
@require_admin
def create_org_invite(org_id):
"""Create an invite token for an organization.
Request body:
email: Email address to invite
role: Role to assign (default: member)
Returns:
201: Invite created
400: Validation error
403: Not an admin
404: Organization not found
"""
from gatehouse_app.models import OrgInviteToken, Organization
from gatehouse_app.services.notification_service import NotificationService
from flask import current_app
org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
if not org:
return api_response(success=False, message="Organization not found", status=404)
data = request.get_json() or {}
email = (data.get("email") or "").strip().lower()
role = (data.get("role") or "member").strip()
if not email:
return api_response(success=False, message="Email is required", status=400, error_type="VALIDATION_ERROR")
invite = OrgInviteToken.generate(
organization_id=org_id,
email=email,
role=role,
invited_by_id=g.current_user.id,
)
app_url = current_app.config.get("APP_URL", "http://localhost:8080")
invite_link = f"{app_url}/invite?token={invite.token}"
NotificationService._send_email(
to_address=email,
subject=f"You're invited to join {org.name} on Gatehouse",
body=(
f"You've been invited to join {org.name} on Gatehouse.\n\n"
f"Click the link below to accept the invitation (valid for 7 days):\n"
f"{invite_link}\n\n"
f"Gatehouse Security Team"
),
)
return api_response(
data={"invite": {"id": invite.id, "email": invite.email, "role": invite.role, "expires_at": invite.expires_at.isoformat() + "Z"}},
message="Invite sent successfully",
status=201,
)
@api_v1_bp.route("/invites/<token>", methods=["GET"])
def get_invite(token):
"""Get invite details by token.
Returns:
200: Invite details (org name, email)
400: Invalid or expired token
"""
from gatehouse_app.models import OrgInviteToken
invite = OrgInviteToken.query.filter_by(token=token).first()
if not invite or not invite.is_valid:
return api_response(success=False, message="This invitation link is invalid or has expired.", status=400, error_type="INVALID_TOKEN")
return api_response(
data={
"email": invite.email,
"organization": {"id": invite.organization_id, "name": invite.organization.name},
"role": invite.role,
},
message="Invite found",
)
@api_v1_bp.route("/invites/<token>/accept", methods=["POST"])
def accept_invite(token):
"""Accept an organization invite.
Creates the user account (if not already registered) and adds them
to the organization.
Request body:
full_name: User's display name
password: Password for new account (if not already registered)
password_confirm: Password confirmation
Returns:
200: Invite accepted, returns user token
400: Invalid/expired token or validation error
409: Already a member
"""
from gatehouse_app.models import OrgInviteToken, User
from gatehouse_app.services.auth_service import AuthService
from gatehouse_app.services.organization_service import OrganizationService
from gatehouse_app.utils.constants import OrganizationRole
invite = OrgInviteToken.query.filter_by(token=token).first()
if not invite or not invite.is_valid:
return api_response(success=False, message="This invitation link is invalid or has expired.", status=400, error_type="INVALID_TOKEN")
data = request.get_json() or {}
full_name = data.get("full_name") or ""
password = data.get("password") or ""
password_confirm = data.get("password_confirm") or ""
user = User.query.filter_by(email=invite.email, deleted_at=None).first()
if not user:
# Register a new user
if not password:
return api_response(success=False, message="Password is required for new accounts.", status=400, error_type="VALIDATION_ERROR")
if password != password_confirm:
return api_response(success=False, message="Passwords do not match.", status=400, error_type="VALIDATION_ERROR")
if len(password) < 8:
return api_response(success=False, message="Password must be at least 8 characters.", status=400, error_type="VALIDATION_ERROR")
try:
user = AuthService.register_user(email=invite.email, password=password, full_name=full_name or None)
except Exception as exc:
return api_response(success=False, message=str(exc), status=400, error_type="REGISTRATION_ERROR")
# Add to org
role_value = invite.role
try:
org_role = OrganizationRole(role_value)
except ValueError:
org_role = OrganizationRole.MEMBER
try:
OrganizationService.add_member(
org=invite.organization,
user_id=user.id,
role=org_role,
inviter_id=invite.invited_by_id,
)
except Exception:
pass # Already a member is fine
invite.accept()
user_session = AuthService.create_session(user)
return api_response(
data={
"user": user.to_dict(),
"token": user_session.token,
"expires_at": user_session.expires_at.isoformat() + "Z",
},
message="Invitation accepted. Welcome!",
)
# ============================================================================
# Organization OIDC Clients
# ============================================================================
@api_v1_bp.route("/organizations/<org_id>/clients", methods=["GET"])
@login_required
def list_org_clients(org_id):
"""List OIDC clients for an organization.
Returns:
200: List of OIDC clients
403: Not a member
404: Organization not found
"""
from gatehouse_app.models import OIDCClient, Organization
org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
if not org:
return api_response(success=False, message="Organization not found", status=404)
clients = OIDCClient.query.filter_by(organization_id=org_id, is_active=True).all()
def client_to_dict(c):
return {
"id": c.id,
"name": c.name,
"client_id": c.client_id,
"redirect_uris": c.redirect_uris,
"scopes": c.scopes,
"grant_types": c.grant_types,
"is_active": c.is_active,
"created_at": c.created_at.isoformat() + "Z",
}
return api_response(
data={"clients": [client_to_dict(c) for c in clients], "count": len(clients)},
message="Clients retrieved successfully",
)
@api_v1_bp.route("/organizations/<org_id>/clients", methods=["POST"])
@login_required
@require_admin
def create_org_client(org_id):
"""Create a new OIDC client for an organization.
Request body:
name: Client name
redirect_uris: List of allowed redirect URIs (newline or comma separated string)
Returns:
201: Client created with client_id and client_secret
403: Not an admin
404: Organization not found
"""
import secrets as _secrets
from gatehouse_app.extensions import bcrypt
from gatehouse_app.models import OIDCClient, Organization
org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
if not org:
return api_response(success=False, message="Organization not found", status=404)
data = request.get_json() or {}
name = (data.get("name") or "").strip()
redirect_uris_raw = data.get("redirect_uris") or []
if not name:
return api_response(success=False, message="Client name is required", status=400, error_type="VALIDATION_ERROR")
if isinstance(redirect_uris_raw, str):
redirect_uris = [u.strip() for u in redirect_uris_raw.replace(",", "\n").splitlines() if u.strip()]
else:
redirect_uris = [u.strip() for u in redirect_uris_raw if isinstance(u, str) and u.strip()]
if not redirect_uris:
return api_response(success=False, message="At least one redirect URI is required", status=400, error_type="VALIDATION_ERROR")
client_id = _secrets.token_hex(16)
client_secret = _secrets.token_urlsafe(32)
client = OIDCClient(
organization_id=org_id,
name=name,
client_id=client_id,
client_secret_hash=bcrypt.generate_password_hash(client_secret).decode("utf-8"),
redirect_uris=redirect_uris,
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
scopes=["openid", "profile", "email"],
is_active=True,
is_confidential=True,
)
from gatehouse_app.extensions import db
db.session.add(client)
db.session.commit()
return api_response(
data={
"client": {
"id": client.id,
"name": client.name,
"client_id": client.client_id,
"client_secret": client_secret, # Only returned once
"redirect_uris": client.redirect_uris,
"scopes": client.scopes,
"created_at": client.created_at.isoformat() + "Z",
}
},
message="OIDC client created successfully",
status=201,
)
@api_v1_bp.route("/organizations/<org_id>/clients/<client_id>", methods=["DELETE"])
@login_required
@require_admin
def delete_org_client(org_id, client_id):
"""Deactivate an OIDC client.
Returns:
200: Client deactivated
403: Not an admin
404: Client not found
"""
from gatehouse_app.models import OIDCClient
from gatehouse_app.extensions import db
client = OIDCClient.query.filter_by(id=client_id, organization_id=org_id).first()
if not client:
return api_response(success=False, message="Client not found", status=404)
client.is_active = False
db.session.commit()
return api_response(data={}, message="Client deactivated successfully")
@api_v1_bp.route("/organizations/<org_id>/members/<user_id>/send-mfa-reminder", methods=["POST"])
@login_required
@require_admin
def send_mfa_reminder(org_id, user_id):
"""Send an MFA reminder email to a specific member.
Returns:
200: Reminder sent (or silently skipped if no deadline record)
403: Not an admin
404: Member not found
"""
from gatehouse_app.models import User, MfaPolicyCompliance, OrganizationSecurityPolicy
from gatehouse_app.services.notification_service import NotificationService
user = User.query.filter_by(id=user_id, deleted_at=None).first()
if not user:
return api_response(success=False, message="User not found", status=404)
compliance = MfaPolicyCompliance.query.filter_by(
user_id=user_id, organization_id=org_id
).first()
policy = OrganizationSecurityPolicy.query.filter_by(organization_id=org_id).first()
if compliance and policy and compliance.deadline_at:
NotificationService.send_mfa_deadline_reminder(user, compliance, policy)
else:
# No compliance deadline — send a generic nudge
NotificationService._send_email(
to_address=user.email,
subject="Reminder: Set up multi-factor authentication",
body=(
f"Hi {user.full_name or user.email},\n\n"
"Your organization administrator has asked you to set up "
"multi-factor authentication (MFA) on your Gatehouse account.\n\n"
"Please log in and configure MFA as soon as possible.\n\n"
"Gatehouse Security Team"
),
)
return api_response(data={}, message="Reminder sent successfully")
# =============================================================================
# System-wide Audit Log (admin view) + User self audit
# =============================================================================
def _audit_log_to_dict(log):
"""Serialize an AuditLog record to a dict."""
return {
"id": log.id,
"action": log.action.value if log.action else None,
"user_id": log.user_id,
"user": (
{"id": log.user.id, "email": log.user.email, "full_name": log.user.full_name}
if log.user else None
),
"organization_id": log.organization_id,
"resource_type": log.resource_type,
"resource_id": log.resource_id,
"ip_address": log.ip_address,
"user_agent": log.user_agent,
"request_id": log.request_id,
"description": log.description,
"success": log.success,
"error_message": log.error_message,
"metadata": log.extra_data,
"created_at": log.created_at.isoformat() if log.created_at else None,
"updated_at": log.updated_at.isoformat() if log.updated_at else None,
}
@api_v1_bp.route("/audit-logs", methods=["GET"])
@login_required
def get_system_audit_logs():
"""
Get all audit logs (system-wide). Any authenticated user can query
their own logs; org owners/admins also see org-scoped logs; this
endpoint returns ALL logs for users who own at least one org
(acting as an admin view).
Query params:
page page number (default 1)
per_page results per page (default 50, max 200)
action filter by AuditAction value
user_id filter by user id
resource_type filter by resource type
success "true"/"false"
q free-text search on description
"""
from gatehouse_app.models.audit_log import AuditLog
from gatehouse_app.models.organization_member import OrganizationMember
current_user = g.current_user
page = max(1, int(request.args.get("page", 1)))
per_page = min(int(request.args.get("per_page", 50)), 200)
# Check if the user is an owner of any org to grant admin-level access
is_admin = OrganizationMember.query.filter_by(
user_id=current_user.id, role="OWNER"
).first() is not None
query = AuditLog.query
if not is_admin:
# Non-admins can only see their own logs
query = query.filter(AuditLog.user_id == current_user.id)
# Optional filters
action_filter = request.args.get("action")
if action_filter:
query = query.filter(AuditLog.action == action_filter)
user_id_filter = request.args.get("user_id")
if user_id_filter:
query = query.filter(AuditLog.user_id == user_id_filter)
resource_type_filter = request.args.get("resource_type")
if resource_type_filter:
query = query.filter(AuditLog.resource_type == resource_type_filter)
success_filter = request.args.get("success")
if success_filter is not None:
query = query.filter(AuditLog.success == (success_filter.lower() == "true"))
q = request.args.get("q", "").strip()
if q:
query = query.filter(AuditLog.description.ilike(f"%{q}%"))
query = query.order_by(AuditLog.created_at.desc())
total = query.count()
logs = query.offset((page - 1) * per_page).limit(per_page).all()
return api_response(
data={
"audit_logs": [_audit_log_to_dict(log) for log in logs],
"count": total,
"page": page,
"per_page": per_page,
"pages": (total + per_page - 1) // per_page,
"is_admin_view": is_admin,
},
message="Audit logs retrieved",
)
@api_v1_bp.route("/auth/audit-logs", methods=["GET"])
@login_required
def get_my_audit_logs():
"""
Get audit logs for the currently authenticated user only.
Query params:
page page number (default 1)
per_page results per page (default 50, max 200)
action filter by AuditAction value
"""
from gatehouse_app.models.audit_log import AuditLog
current_user = g.current_user
page = max(1, int(request.args.get("page", 1)))
per_page = min(int(request.args.get("per_page", 50)), 200)
query = AuditLog.query.filter(AuditLog.user_id == current_user.id)
action_filter = request.args.get("action")
if action_filter:
query = query.filter(AuditLog.action == action_filter)
query = query.order_by(AuditLog.created_at.desc())
total = query.count()
logs = query.offset((page - 1) * per_page).limit(per_page).all()
return api_response(
data={
"audit_logs": [_audit_log_to_dict(log) for log in logs],
"count": total,
"page": page,
"per_page": per_page,
"pages": (total + per_page - 1) // per_page,
},
message="Activity retrieved",
)
+615
View File
@@ -0,0 +1,615 @@
"""SSH Key and Certificate API routes."""
from flask import Blueprint, request, jsonify, g
from sqlalchemy.exc import IntegrityError
from gatehouse_app.services.ssh_key_service import SSHKeyService
from gatehouse_app.services.ssh_ca_signing_service import (
SSHCASigningService,
SSHCertificateSigningRequest,
)
from gatehouse_app.exceptions import (
SSHKeyError,
SSHKeyNotFoundError,
SSHCertificateError,
ValidationError,
SSHKeyAlreadyExistsError,
)
from gatehouse_app.utils.constants import AuditAction
from gatehouse_app.models import AuditLog
from gatehouse_app.utils.decorators import login_required
ssh_bp = Blueprint('ssh', __name__, url_prefix='/ssh')
ssh_key_service = SSHKeyService()
ssh_ca_service = SSHCASigningService()
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _get_org_ca_for_user(user):
"""Return the active DB CA for the user's first org, or None."""
try:
from gatehouse_app.models.ca import CA
org_ids = [m.organization_id for m in user.organization_memberships]
if not org_ids:
return None
return CA.query.filter(
CA.organization_id.in_(org_ids),
CA.is_active == True, # noqa: E712
).first()
except Exception:
return None
def _get_or_create_system_ca():
"""
Return a CA DB record representing the config-file CA.
This is used as the ``ca_id`` FK when persisting certificates that were
signed by the globally-configured CA key (not an org-specific DB CA).
The record is created on first use and has no ``organization_id``.
"""
from gatehouse_app.extensions import db
from gatehouse_app.models.ca import CA, KeyType
from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config
from gatehouse_app.utils.crypto import compute_ssh_fingerprint
import os
try:
existing = CA.query.filter_by(name="system-config-ca").first()
if existing:
return existing
cfg = get_ssh_ca_config()
key_path = cfg.get_str("ca_key_path", "").strip()
pub_key_path = key_path + ".pub"
if not os.path.exists(pub_key_path):
return None
with open(pub_key_path) as f:
pub_key = f.read().strip()
# Load private key for the record (stored but not actually used for signing here)
priv_key = ""
if os.path.exists(key_path):
with open(key_path) as f:
priv_key = f.read()
fingerprint = compute_ssh_fingerprint(pub_key)
# Check by fingerprint in case it was created under a different name
existing_by_fp = CA.query.filter_by(fingerprint=fingerprint).first()
if existing_by_fp:
return existing_by_fp
system_ca = CA(
name="system-config-ca",
description="Global CA loaded from etc/ssh_ca.conf (ca_key_path)",
key_type=KeyType.ED25519,
private_key=priv_key,
public_key=pub_key,
fingerprint=fingerprint,
is_active=True,
default_cert_validity_hours=24,
max_cert_validity_hours=720,
)
# organization_id is nullable=False in schema — we need a dummy org or
# need to allow NULL. Use None; the DB constraint will tell us quickly.
# If the migration enforces NOT NULL we'll catch the error gracefully.
db.session.add(system_ca)
db.session.commit()
return system_ca
except Exception as exc:
import logging
logging.getLogger(__name__).warning(
f"Could not upsert system-config-ca: {exc}"
)
try:
db.session.rollback()
except Exception:
pass
return None
def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=None):
"""Save a signed certificate to the ssh_certificates table.
Args:
user_id: UUID of the user
ssh_key_id: UUID of the SSH key that was signed
ca: CA model instance (may be None cert still returned but not persisted)
signing_response: SSHCertificateSigningResponse
request_ip: Client IP address
Returns:
SSHCertificate instance or None if persistence failed
"""
if ca is None:
return None
try:
from gatehouse_app.extensions import db
from gatehouse_app.models.ssh_certificate import SSHCertificate, CertificateStatus
from gatehouse_app.models.ca import CertType
cert_record = SSHCertificate(
ca_id=ca.id,
user_id=user_id,
ssh_key_id=ssh_key_id,
certificate=signing_response.certificate,
serial=signing_response.serial,
key_id=str(ssh_key_id),
cert_type=CertType.USER,
principals=signing_response.principals,
valid_after=signing_response.valid_after,
valid_before=signing_response.valid_before,
revoked=False,
status=CertificateStatus.ISSUED,
request_ip=request_ip,
)
db.session.add(cert_record)
db.session.commit()
return cert_record
except Exception as exc:
import logging
logging.getLogger(__name__).warning(
f"Failed to persist certificate to DB: {exc}"
)
try:
from gatehouse_app.extensions import db as _db
_db.session.rollback()
except Exception:
pass
return None
@ssh_bp.route('/keys', methods=['GET'])
@login_required
def list_ssh_keys():
"""Get all SSH keys for current user."""
user_id = g.current_user.id
keys = ssh_key_service.get_user_ssh_keys(user_id)
return jsonify({
'keys': [k.to_dict() for k in keys],
'count': len(keys),
}), 200
@ssh_bp.route('/keys', methods=['POST'])
@login_required
def add_ssh_key():
"""Add a new SSH public key for current user."""
user_id = g.current_user.id
data = request.get_json()
if not data:
return jsonify({'error': 'No JSON data provided'}), 400
public_key = data.get('public_key') or data.get('key')
description = data.get('description')
if not public_key:
return jsonify({'error': 'public_key is required'}), 400
try:
ssh_key = ssh_key_service.add_ssh_key(
user_id=user_id,
public_key=public_key,
description=description,
)
# Audit log
AuditLog.log(
action=AuditAction.SSH_KEY_ADDED,
user_id=user_id,
resource_type='SSHKey',
resource_id=ssh_key.id,
ip_address=request.remote_addr,
)
return jsonify(ssh_key.to_dict()), 201
except SSHKeyAlreadyExistsError as e:
return jsonify({'error': e.message, 'code': 'SSH_KEY_ALREADY_EXISTS'}), 409
except IntegrityError:
return jsonify({'error': 'SSH key already exists', 'code': 'SSH_KEY_ALREADY_EXISTS'}), 409
except SSHKeyError as e:
return jsonify({'error': str(e)}), 400
except ValidationError as e:
return jsonify({'error': str(e)}), 400
@ssh_bp.route('/keys/<key_id>', methods=['GET'])
@login_required
def get_ssh_key(key_id):
"""Get a specific SSH key."""
user_id = g.current_user.id
try:
ssh_key = ssh_key_service.get_ssh_key(key_id)
# Check ownership
if ssh_key.user_id != user_id:
return jsonify({'error': 'Forbidden'}), 403
return jsonify(ssh_key.to_dict()), 200
except SSHKeyNotFoundError:
return jsonify({'error': 'SSH key not found'}), 404
@ssh_bp.route('/keys/<key_id>', methods=['DELETE'])
@login_required
def delete_ssh_key(key_id):
"""Delete an SSH key."""
user_id = g.current_user.id
try:
ssh_key = ssh_key_service.get_ssh_key(key_id)
# Check ownership
if ssh_key.user_id != user_id:
return jsonify({'error': 'Forbidden'}), 403
ssh_key_service.delete_ssh_key(key_id)
# Audit log
AuditLog.log(
action=AuditAction.SSH_KEY_DELETED,
user_id=user_id,
resource_type='SSHKey',
resource_id=key_id,
ip_address=request.remote_addr,
)
return jsonify({'status': 'deleted'}), 200
except SSHKeyNotFoundError:
return jsonify({'error': 'SSH key not found'}), 404
@ssh_bp.route('/keys/<key_id>/verify', methods=['GET', 'POST'])
@login_required
def verify_ssh_key(key_id):
"""Generate or verify SSH key ownership challenge."""
user_id = g.current_user.id
try:
ssh_key = ssh_key_service.get_ssh_key(key_id)
# Check ownership
if ssh_key.user_id != user_id:
return jsonify({'error': 'Forbidden'}), 403
# Handle GET request - return challenge
if request.method == 'GET':
challenge = ssh_key_service.generate_verification_challenge(key_id)
return jsonify({
'challenge_text': challenge,
'validationText': challenge, # Backwards compatibility
'key_id': key_id,
}), 200
# Handle POST request - verify signature
data = request.get_json() or {}
action = data.get('action', 'verify_signature')
if action == 'verify_signature':
# Verify signature
signature = data.get('signature')
if not signature:
return jsonify({'error': 'signature is required'}), 400
try:
verified = ssh_key_service.verify_ssh_key_ownership(key_id, signature)
# Audit log
AuditLog.log(
action=AuditAction.SSH_KEY_VERIFIED,
user_id=user_id,
resource_type='SSHKey',
resource_id=key_id,
ip_address=request.remote_addr,
success=verified,
)
return jsonify({'verified': verified}), 200
except Exception as e:
AuditLog.log(
action=AuditAction.SSH_KEY_VALIDATION_FAILED,
user_id=user_id,
resource_type='SSHKey',
resource_id=key_id,
ip_address=request.remote_addr,
success=False,
error_message=str(e),
)
return jsonify({'error': str(e)}), 400
else: # generate_challenge
# Generate verification challenge
challenge = ssh_key_service.generate_verification_challenge(key_id)
return jsonify({
'challenge_text': challenge,
'challenge': challenge, # Both for compatibility
}), 200
except SSHKeyNotFoundError:
return jsonify({'error': 'SSH key not found'}), 404
@ssh_bp.route('/keys/<key_id>/update-description', methods=['PATCH'])
@login_required
def update_ssh_key_description(key_id):
"""Update SSH key description."""
user_id = g.current_user.id
data = request.get_json()
if not data or 'description' not in data:
return jsonify({'error': 'description is required'}), 400
try:
ssh_key = ssh_key_service.get_ssh_key(key_id)
# Check ownership
if ssh_key.user_id != user_id:
return jsonify({'error': 'Forbidden'}), 403
updated_key = ssh_key_service.update_ssh_key_description(
key_id,
data['description']
)
return jsonify(updated_key.to_dict()), 200
except SSHKeyNotFoundError:
return jsonify({'error': 'SSH key not found'}), 404
@ssh_bp.route('/sign', methods=['POST'])
@login_required
def sign_certificate():
"""Sign an SSH certificate for the current user."""
user = g.current_user
user_id = user.id
data = request.get_json()
if not data:
return jsonify({'error': 'No JSON data provided'}), 400
try:
principals = data.get('principals', [])
cert_type = data.get('cert_type', 'user')
# Accept both 'key_id' and 'cert_id' (from CLI)
key_id = data.get('key_id') or data.get('cert_id')
expiry_hours = data.get('expiry_hours')
if not principals:
return jsonify({'error': 'principals is required'}), 400
# If key_id not specified, use first verified key
if not key_id:
verified_keys = ssh_key_service.get_user_verified_ssh_keys(user_id)
if not verified_keys:
return jsonify({'error': 'No verified SSH keys found'}), 400
key_id = verified_keys[0].id
# Get the SSH key
ssh_key = ssh_key_service.get_ssh_key(key_id)
if ssh_key.user_id != user_id:
return jsonify({'error': 'Forbidden'}), 403
if not ssh_key.verified:
return jsonify({'error': 'SSH key is not verified'}), 400
# Resolve which CA to use: org DB CA > config-file CA
db_ca = _get_org_ca_for_user(user)
ca_private_key = db_ca.private_key if db_ca else None # None → signing service uses config
# Create signing request
signing_request = SSHCertificateSigningRequest(
ssh_public_key=ssh_key.payload,
principals=principals,
cert_type=cert_type,
key_id=key_id,
expiry_hours=int(expiry_hours) if expiry_hours else None,
)
# Validate request
validation_errors = signing_request.validate()
if validation_errors:
return jsonify({'errors': validation_errors}), 400
# Sign the certificate (pass ca_private_key=None → service loads from config)
response = ssh_ca_service.sign_certificate(signing_request, ca_private_key=ca_private_key)
# Persist certificate to DB
# If user's org has no DB CA, use the system-config-ca record
ca_for_db = db_ca or _get_or_create_system_ca()
cert_record = _persist_certificate(
user_id=user_id,
ssh_key_id=key_id,
ca=ca_for_db,
signing_response=response,
request_ip=request.remote_addr,
)
# Audit log
AuditLog.log(
action=AuditAction.SSH_CERT_ISSUED,
user_id=user_id,
resource_type='SSHCertificate',
resource_id=cert_record.id if cert_record else key_id,
ip_address=request.remote_addr,
description=f'Certificate issued for principals: {", ".join(principals)}',
)
result = {
'certificate': response.certificate,
'serial': response.serial,
'principals': response.principals,
'valid_after': response.valid_after.isoformat() if response.valid_after else None,
'valid_before': response.valid_before.isoformat() if response.valid_before else None,
}
if cert_record:
result['cert_id'] = str(cert_record.id)
return jsonify(result), 201
except SSHKeyNotFoundError:
return jsonify({'error': 'SSH key not found'}), 404
except SSHCertificateError as e:
AuditLog.log(
action=AuditAction.SSH_CERT_FAILED,
user_id=user_id,
resource_type='SSHCertificate',
ip_address=request.remote_addr,
success=False,
error_message=str(e),
)
return jsonify({'error': str(e)}), 400
except Exception as e:
AuditLog.log(
action=AuditAction.SSH_CERT_FAILED,
user_id=user_id,
resource_type='SSHCertificate',
ip_address=request.remote_addr,
success=False,
error_message=str(e),
)
return jsonify({'error': 'Certificate signing failed: ' + str(e)}), 500
@ssh_bp.route('/certificates', methods=['GET'])
@login_required
def list_certificates():
"""List all SSH certificates issued for the current user."""
user_id = g.current_user.id
try:
from gatehouse_app.models.ssh_certificate import SSHCertificate
certs = (
SSHCertificate.query
.filter_by(user_id=user_id, deleted_at=None)
.order_by(SSHCertificate.created_at.desc())
.all()
)
return jsonify({
'certificates': [c.to_dict() for c in certs],
'count': len(certs),
}), 200
except Exception as e:
return jsonify({'error': str(e)}), 500
@ssh_bp.route('/certificates/<cert_id>', methods=['GET'])
@login_required
def get_certificate(cert_id):
"""Get a specific issued certificate (metadata only)."""
user_id = g.current_user.id
try:
from gatehouse_app.models.ssh_certificate import SSHCertificate
cert = SSHCertificate.query.filter_by(id=cert_id, deleted_at=None).first()
if not cert:
return jsonify({'error': 'Certificate not found'}), 404
if cert.user_id != user_id:
return jsonify({'error': 'Forbidden'}), 403
# Include full certificate text in single-fetch endpoint
data = cert.to_dict()
data['certificate'] = cert.certificate
return jsonify(data), 200
except Exception as e:
return jsonify({'error': str(e)}), 500
@ssh_bp.route('/certificates/<cert_id>/revoke', methods=['POST'])
@login_required
def revoke_certificate(cert_id):
"""Revoke an issued certificate."""
user_id = g.current_user.id
data = request.get_json() or {}
reason = data.get('reason', 'User requested revocation')
try:
from gatehouse_app.models.ssh_certificate import SSHCertificate
cert = SSHCertificate.query.filter_by(id=cert_id, deleted_at=None).first()
if not cert:
return jsonify({'error': 'Certificate not found'}), 404
if cert.user_id != user_id:
return jsonify({'error': 'Forbidden'}), 403
if cert.revoked:
return jsonify({'error': 'Certificate is already revoked'}), 409
cert.revoke(reason=reason)
AuditLog.log(
action=AuditAction.SSH_CERT_REVOKED,
user_id=user_id,
resource_type='SSHCertificate',
resource_id=cert_id,
ip_address=request.remote_addr,
description=f'Revoked: {reason}',
)
return jsonify({'status': 'revoked', 'cert_id': cert_id, 'reason': reason}), 200
except Exception as e:
return jsonify({'error': str(e)}), 500
@ssh_bp.route('/ca/public-key', methods=['GET'])
@login_required
def get_ca_public_key():
"""
Return the CA public key for this user's organization.
Server admins should add this key to their host's ``TrustedUserCAKeys``
directive so that certificates issued by gatehouse are trusted.
Query parameters:
format: 'openssh' (default) or 'text' affects Content-Type only
Returns:
{ "public_key": "ssh-ed25519 AAAA...",
"fingerprint": "SHA256:...",
"ca_name": "..." }
"""
user = g.current_user
# Try org CA first
db_ca = _get_org_ca_for_user(user)
if db_ca:
return jsonify({
'public_key': db_ca.public_key,
'fingerprint': db_ca.fingerprint,
'ca_name': db_ca.name,
'source': 'db',
}), 200
# Fall back to config-file CA
try:
from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config
import os
cfg = get_ssh_ca_config()
key_path = cfg.get_str('ca_key_path', '').strip() + '.pub'
if os.path.exists(key_path):
with open(key_path) as f:
pub_key = f.read().strip()
from gatehouse_app.utils.crypto import compute_ssh_fingerprint
return jsonify({
'public_key': pub_key,
'fingerprint': compute_ssh_fingerprint(pub_key),
'ca_name': 'system-config-ca',
'source': 'config',
}), 200
except Exception as e:
return jsonify({'error': f'Could not load CA public key: {e}'}), 500
return jsonify({'error': 'No CA configured for this organization'}), 404
+271
View File
@@ -0,0 +1,271 @@
"""SSH CA Configuration Manager.
Handles loading and managing SSH CA configuration from etc/ssh_ca.conf
and environment variables.
"""
import os
import configparser
from pathlib import Path
from typing import Optional, Union
class SSHCAConfig:
"""Configuration manager for SSH CA settings.
Loads configuration from:
1. etc/ssh_ca.conf file
2. Environment variables (override config file)
3. Application environment-specific defaults
Example:
config = SSHCAConfig()
cert_hours = config.get_int('cert_validity_hours')
kms_key = config.get_str('aws_kms_key_id')
"""
# Configuration file location (relative to project root)
DEFAULT_CONFIG_FILE = "etc/ssh_ca.conf"
# Default values if config file is missing
DEFAULTS = {
'cert_validity_hours': '1',
'max_cert_validity_hours': '24',
'max_certs_per_user': '100',
'crl_enabled': 'true',
'crl_endpoint': 'https://ca.example.com/crl',
'crl_refresh_hours': '24',
'default_key_type': 'ed25519',
'rsa_key_bits': '4096',
'private_key_encryption': 'kms',
'aws_kms_key_id': '',
'extensions_enabled': 'true',
'extensions': 'permit-X11-forwarding,permit-agent-forwarding,permit-pty,permit-port-forwarding,permit-user-rc',
'critical_options_enabled': 'false',
'max_principals_per_cert': '256',
'max_key_id_length': '255',
'log_level': 'INFO',
'audit_enabled': 'true',
'require_key_verification': 'true',
'verification_challenge_max_age': '24',
'rate_limit_certs_per_minute': '5',
'request_timeout': '30',
'auto_delete_unverified_days': '30',
'archive_expired_days': '365',
'oauth_token_endpoint': '/api/v1/oauth2/token',
'oauth_userinfo_endpoint': '/api/v1/oauth2/userinfo',
'ca_key_path': '',
}
def __init__(self, config_file: Optional[str] = None, environment: Optional[str] = None):
"""Initialize SSH CA configuration.
Args:
config_file: Path to config file (default: etc/ssh_ca.conf)
environment: Environment name (development, production, testing)
Default: value of FLASK_ENV or 'development'
"""
self.config = configparser.ConfigParser()
# Determine environment
if environment is None:
environment = os.environ.get('FLASK_ENV', 'development')
self.environment = environment
# Load config file
if config_file is None:
# Try to find config file relative to this module
module_dir = Path(__file__).parent.parent.parent
config_file = module_dir / self.DEFAULT_CONFIG_FILE
self.config_file = config_file
self._load_config()
def _load_config(self):
"""Load configuration from file and apply environment-specific overrides."""
# Set defaults
self.config['default'] = self.DEFAULTS.copy()
# Load config file if it exists
if Path(self.config_file).exists():
self.config.read(self.config_file)
# Apply environment-specific configuration
if self.environment in self.config:
for key, value in self.config[self.environment].items():
self.config['default'][key] = value
def get_str(self, key: str, default: Optional[str] = None) -> str:
"""Get a string configuration value.
First checks environment variables (SSH_CA_<KEY>), then config file.
Args:
key: Configuration key
default: Default value if not found
Returns:
Configuration value as string
"""
env_key = f"SSH_CA_{key.upper()}"
# Check environment variable first
if env_key in os.environ:
return os.environ[env_key]
# Check config file
if key in self.config['default']:
value = self.config['default'][key]
# Handle environment variable substitution
return os.path.expandvars(value)
# Return default
if default is not None:
return default
return self.DEFAULTS.get(key, '')
def get_int(self, key: str, default: Optional[int] = None) -> int:
"""Get an integer configuration value.
Args:
key: Configuration key
default: Default value if not found
Returns:
Configuration value as integer
Raises:
ValueError: If value cannot be converted to integer
"""
str_value = self.get_str(key)
if not str_value:
if default is not None:
return default
raise ValueError(f"No value found for {key}")
try:
return int(str_value)
except ValueError:
if default is not None:
return default
raise ValueError(f"Configuration {key}={str_value} is not a valid integer")
def get_bool(self, key: str, default: Optional[bool] = None) -> bool:
"""Get a boolean configuration value.
Args:
key: Configuration key
default: Default value if not found
Returns:
Configuration value as boolean
"""
str_value = self.get_str(key)
if not str_value:
if default is not None:
return default
return False
return str_value.lower() in ('true', '1', 'yes', 'on')
def get_list(self, key: str, delimiter: str = ',', default: Optional[list] = None) -> list:
"""Get a comma-separated list configuration value.
Args:
key: Configuration key
delimiter: Delimiter between items (default: comma)
default: Default value if not found
Returns:
Configuration value as list of strings
"""
str_value = self.get_str(key)
if not str_value:
if default is not None:
return default
return []
return [item.strip() for item in str_value.split(delimiter) if item.strip()]
def validate_config(self) -> list:
"""Validate SSH CA configuration.
Returns:
List of validation error messages (empty if valid)
"""
errors = []
# Check cert validity hours
try:
validity = self.get_int('cert_validity_hours')
max_validity = self.get_int('max_cert_validity_hours')
if validity > max_validity:
errors.append(
f"cert_validity_hours ({validity}) > max_cert_validity_hours ({max_validity})"
)
except ValueError as e:
errors.append(f"Invalid cert validity hours: {e}")
# Check key type
valid_key_types = ['ed25519', 'rsa', 'ecdsa']
key_type = self.get_str('default_key_type', 'ed25519')
if key_type not in valid_key_types:
errors.append(f"Invalid key type: {key_type}. Must be one of {valid_key_types}")
# Check encryption method
valid_methods = ['kms', 'local']
encryption = self.get_str('private_key_encryption', 'kms')
if encryption not in valid_methods:
errors.append(f"Invalid private_key_encryption: {encryption}. Must be one of {valid_methods}")
# Warn if using local encryption in production
if encryption == 'local' and self.environment == 'production':
errors.append("WARNING: Using local key encryption in production! Use KMS instead.")
# Check KMS key ID if using KMS
if encryption == 'kms':
kms_key = self.get_str('aws_kms_key_id', '').strip()
if not kms_key:
errors.append("aws_kms_key_id not set but private_key_encryption=kms")
# Check principals limit
max_principals = self.get_int('max_principals_per_cert')
if max_principals > 256:
errors.append(f"max_principals_per_cert ({max_principals}) exceeds SSH limit of 256")
return errors
def to_dict(self) -> dict:
"""Export current configuration as dictionary.
"""
return dict(self.config['default'])
def __repr__(self):
"""String representation of configuration."""
return f"<SSHCAConfig environment={self.environment} file={self.config_file}>"
# Global configuration instance
_config_instance = None
def get_ssh_ca_config() -> SSHCAConfig:
"""Get the global SSH CA configuration instance.
This function uses a singleton pattern to ensure only one
configuration instance is created and reused.
Returns:
SSHCAConfig instance
"""
global _config_instance
if _config_instance is None:
_config_instance = SSHCAConfig()
return _config_instance
def reset_config_instance():
"""Reset the global configuration instance.
"""
global _config_instance
_config_instance = None
+29
View File
@@ -19,6 +19,21 @@ from gatehouse_app.exceptions.validation_exceptions import (
OrganizationNotFoundError,
UserNotFoundError,
)
from gatehouse_app.exceptions.ssh_exceptions import (
SSHCAError,
SSHKeyError,
SSHKeyNotFoundError,
SSHKeyAlreadyExistsError,
SSHKeyNotVerifiedError,
SSHCertificateError,
SSHCertificateNotFoundError,
CAError,
CANotFoundError,
PrincipalError,
PrincipalNotFoundError,
DepartmentError,
DepartmentNotFoundError,
)
__all__ = [
"BaseAPIException",
@@ -37,4 +52,18 @@ __all__ = [
"EmailAlreadyExistsError",
"OrganizationNotFoundError",
"UserNotFoundError",
"SSHCAError",
"SSHKeyError",
"SSHKeyNotFoundError",
"SSHKeyAlreadyExistsError",
"SSHKeyNotVerifiedError",
"SSHCertificateError",
"SSHCertificateNotFoundError",
"CAError",
"CANotFoundError",
"PrincipalError",
"PrincipalNotFoundError",
"DepartmentError",
"DepartmentNotFoundError",
]
+2 -1
View File
@@ -16,9 +16,10 @@ class BaseAPIException(Exception):
message: Custom error message
error_details: Additional error details dictionary
"""
super().__init__()
super().__init__(self.message)
if message:
self.message = message
super().__init__(message) # update args so str(e) works
self.error_details = error_details or {}
def to_dict(self):
@@ -0,0 +1,93 @@
"""SSH-specific exceptions."""
from gatehouse_app.exceptions.base import BaseAPIException
class SSHCAError(BaseAPIException):
"""Base exception for SSH CA operations."""
status_code = 500
error_type = "SSH_CA_ERROR"
class SSHKeyError(BaseAPIException):
"""Exception for SSH key operations."""
status_code = 400
error_type = "SSH_KEY_ERROR"
class SSHKeyNotFoundError(BaseAPIException):
"""SSH key not found."""
status_code = 404
error_type = "SSH_KEY_NOT_FOUND"
class SSHKeyAlreadyExistsError(BaseAPIException):
"""SSH key already exists (duplicate fingerprint)."""
status_code = 409
error_type = "SSH_KEY_ALREADY_EXISTS"
class SSHKeyNotVerifiedError(BaseAPIException):
"""SSH key has not been verified."""
status_code = 400
error_type = "SSH_KEY_NOT_VERIFIED"
class SSHCertificateError(BaseAPIException):
"""Exception for SSH certificate operations."""
status_code = 400
error_type = "SSH_CERT_ERROR"
class SSHCertificateNotFoundError(BaseAPIException):
"""SSH certificate not found."""
status_code = 404
error_type = "SSH_CERT_NOT_FOUND"
class CAError(BaseAPIException):
"""Exception for Certificate Authority operations."""
status_code = 400
error_type = "CA_ERROR"
class CANotFoundError(BaseAPIException):
"""Certificate Authority not found."""
status_code = 404
error_type = "CA_NOT_FOUND"
class PrincipalError(BaseAPIException):
"""Exception for principal operations."""
status_code = 400
error_type = "PRINCIPAL_ERROR"
class PrincipalNotFoundError(BaseAPIException):
"""Principal not found."""
status_code = 404
error_type = "PRINCIPAL_NOT_FOUND"
class DepartmentError(BaseAPIException):
"""Exception for department operations."""
status_code = 400
error_type = "DEPARTMENT_ERROR"
class DepartmentNotFoundError(BaseAPIException):
"""Department not found."""
status_code = 404
error_type = "DEPARTMENT_NOT_FOUND"
+17
View File
@@ -29,6 +29,13 @@ from gatehouse_app.models.principal import (
Principal,
PrincipalMembership,
)
from gatehouse_app.models.ssh_key import SSHKey
from gatehouse_app.models.ca import CA, KeyType, CertType
from gatehouse_app.models.ssh_certificate import SSHCertificate, CertificateStatus
from gatehouse_app.models.certificate_audit_log import CertificateAuditLog
from gatehouse_app.models.password_reset_token import PasswordResetToken
from gatehouse_app.models.email_verification_token import EmailVerificationToken
from gatehouse_app.models.org_invite_token import OrgInviteToken
__all__ = [
"BaseModel",
@@ -55,4 +62,14 @@ __all__ = [
"DepartmentPrincipal",
"Principal",
"PrincipalMembership",
"SSHKey",
"CA",
"KeyType",
"CertType",
"SSHCertificate",
"CertificateStatus",
"CertificateAuditLog",
"PasswordResetToken",
"EmailVerificationToken",
"OrgInviteToken",
]
+155
View File
@@ -0,0 +1,155 @@
"""Certificate Authority (CA) model."""
from enum import Enum
from datetime import datetime
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class KeyType(str, Enum):
"""SSH CA key types."""
ED25519 = "ed25519"
RSA = "rsa"
ECDSA = "ecdsa"
class CertType(str, Enum):
"""SSH certificate types."""
USER = "user"
HOST = "host"
class CA(BaseModel):
"""Certificate Authority (CA) model for SSH certificate signing.
Each organization can have multiple CAs for different purposes
(e.g., production vs. staging). Private keys are encrypted at rest
and should be protected with KMS.
"""
__tablename__ = "cas"
organization_id = db.Column(
db.String(36),
db.ForeignKey("organizations.id"),
nullable=True, # NULL for the global system-config CA
index=True,
)
# CA name and description
name = db.Column(db.String(255), nullable=False)
description = db.Column(db.Text, nullable=True)
# Key type (ED25519, RSA, ECDSA)
key_type = db.Column(
db.Enum(KeyType, values_callable=lambda x: [e.value for e in x]),
default=KeyType.ED25519,
nullable=False,
)
# Private key (encrypted at rest by database/KMS)
# Format: PEM-encoded private key
private_key = db.Column(db.Text, nullable=False)
# Public key (PEM format)
public_key = db.Column(db.Text, nullable=False)
# SHA256 fingerprint of the public key
fingerprint = db.Column(db.String(255), nullable=False, unique=True)
# CRL (Certificate Revocation List) configuration
crl_enabled = db.Column(db.Boolean, default=True, nullable=False)
crl_endpoint = db.Column(db.String(512), nullable=True)
# Default certificate validity in hours
# Can be overridden per certificate request
default_cert_validity_hours = db.Column(
db.Integer,
default=1,
nullable=False,
)
# Maximum validity duration allowed
max_cert_validity_hours = db.Column(
db.Integer,
default=24,
nullable=False,
)
# CA status
is_active = db.Column(db.Boolean, default=True, nullable=False, index=True)
# Key rotation tracking
rotated_at = db.Column(db.DateTime, nullable=True)
rotation_reason = db.Column(db.String(255), nullable=True)
# Relationships
organization = db.relationship("Organization", back_populates="cas")
certificates = db.relationship(
"SSHCertificate",
back_populates="ca",
cascade="all, delete-orphan",
)
__table_args__ = (
db.UniqueConstraint(
"organization_id", "name", name="uix_org_ca_name"
),
db.Index("idx_ca_org_active", "organization_id", "is_active"),
)
def __repr__(self):
"""String representation of CA."""
return f"<CA {self.name} (org_id={self.organization_id}, type={self.key_type})>"
def to_dict(self, exclude=None):
"""Convert CA to dictionary."""
exclude = exclude or []
# Never expose private key in API responses
exclude.extend(["private_key"])
data = super().to_dict(exclude=exclude)
# Add computed fields
data["total_certs"] = len([c for c in self.certificates if c.deleted_at is None])
data["active_certs"] = len([
c for c in self.certificates
if c.deleted_at is None and not c.revoked
])
data["revoked_certs"] = len([
c for c in self.certificates
if c.deleted_at is None and c.revoked
])
return data
def get_active_certificates(self):
"""Get all active (non-revoked) certificates issued by this CA.
Returns:
List of non-revoked SSHCertificate objects
"""
return [
c for c in self.certificates
if c.deleted_at is None and not c.revoked
]
def rotate_key(self, new_private_key, new_public_key, new_fingerprint, reason=None):
"""Rotate the CA's key pair.
This should only be done in carefully controlled circumstances.
All existing certificates remain valid but no new certs can be
signed with the old key.
Args:
new_private_key: New PEM-encoded private key
new_public_key: New PEM-encoded public key
new_fingerprint: SHA256 fingerprint of new public key
reason: Optional reason for rotation
"""
self.private_key = new_private_key
self.public_key = new_public_key
self.fingerprint = new_fingerprint
self.rotated_at = datetime.utcnow()
self.rotation_reason = reason
self.save()
@@ -0,0 +1,83 @@
"""Certificate audit log model."""
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class CertificateAuditLog(BaseModel):
"""Audit log for SSH certificate lifecycle events.
Tracks all operations on SSH certificates: signing, revocation,
validation, etc. This is separate from the general AuditLog to
provide detailed certificate operation tracking.
"""
__tablename__ = "certificate_audit_logs"
# Reference to the certificate
certificate_id = db.Column(
db.String(36),
db.ForeignKey("ssh_certificates.id"),
nullable=False,
index=True,
)
# The user who performed the action (can be null for system actions)
user_id = db.Column(
db.String(36),
db.ForeignKey("users.id"),
nullable=True,
index=True,
)
# Action type (e.g., "signed", "revoked", "validated", "requested")
action = db.Column(db.String(50), nullable=False, index=True)
# Request details
ip_address = db.Column(db.String(45), nullable=True)
user_agent = db.Column(db.String(512), nullable=True)
request_id = db.Column(db.String(36), nullable=True)
# Detailed message
message = db.Column(db.Text, nullable=True)
# Additional context
extra_data = db.Column(db.JSON, nullable=True)
# Success/failure
success = db.Column(db.Boolean, default=True, nullable=False)
error_message = db.Column(db.Text, nullable=True)
# Relationships
certificate = db.relationship("SSHCertificate", back_populates="audit_logs")
user = db.relationship("User")
__table_args__ = (
db.Index("idx_cert_audit_cert_action", "certificate_id", "action"),
db.Index("idx_cert_audit_user", "user_id", "created_at"),
)
def __repr__(self):
"""String representation of CertificateAuditLog."""
return f"<CertificateAuditLog cert_id={self.certificate_id} action={self.action}>"
@classmethod
def log(cls, certificate_id, action, user_id=None, **kwargs):
"""Create a certificate audit log entry.
Args:
certificate_id: ID of the certificate
action: Action type (e.g., "signed", "revoked")
user_id: ID of the user performing the action (optional)
**kwargs: Additional fields (ip_address, user_agent, message, etc.)
Returns:
CertificateAuditLog instance
"""
log_entry = cls(
certificate_id=certificate_id,
action=action,
user_id=user_id,
**kwargs
)
log_entry.save()
return log_entry
+3
View File
@@ -40,6 +40,9 @@ class Organization(BaseModel):
principals = db.relationship(
"Principal", back_populates="organization", cascade="all, delete-orphan"
)
cas = db.relationship(
"CA", back_populates="organization", cascade="all, delete-orphan"
)
def __repr__(self):
"""String representation of Organization."""
+175
View File
@@ -0,0 +1,175 @@
"""SSH Certificate model."""
from enum import Enum
from datetime import datetime
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
from gatehouse_app.models.ca import CertType
class CertificateStatus(str, Enum):
"""SSH certificate lifecycle status."""
REQUESTED = "requested" # Waiting for signing
ISSUED = "issued" # Signed and valid
REVOKED = "revoked" # Manually revoked
EXPIRED = "expired" # Validity period ended
SUPERSEDED = "superseded" # Replaced by newer cert
class SSHCertificate(BaseModel):
"""SSH Certificate model representing a signed SSH user/host certificate.
Certificates are issued by a CA and associated with an SSH public key.
They include principals (access levels), validity periods, and other
OpenSSH certificate metadata.
"""
__tablename__ = "ssh_certificates"
# Certificate relationships
ca_id = db.Column(
db.String(36),
db.ForeignKey("cas.id"),
nullable=False,
index=True,
)
user_id = db.Column(
db.String(36),
db.ForeignKey("users.id"),
nullable=False,
index=True,
)
ssh_key_id = db.Column(
db.String(36),
db.ForeignKey("ssh_keys.id"),
nullable=False,
index=True,
)
# Certificate content (full signed certificate in OpenSSH format)
certificate = db.Column(db.Text, nullable=False)
# Certificate metadata
serial = db.Column(db.String(255), nullable=False, unique=True, index=True)
key_id = db.Column(db.String(255), nullable=False) # Usually user email
cert_type = db.Column(
db.Enum(CertType, values_callable=lambda x: [e.value for e in x]),
default=CertType.USER,
nullable=False,
)
# Principals (JSON list) - e.g., ["prod-servers", "dev-servers"]
principals = db.Column(db.JSON, nullable=False, default=list)
# Validity period
valid_after = db.Column(db.DateTime, nullable=False)
valid_before = db.Column(db.DateTime, nullable=False)
# Revocation status
revoked = db.Column(db.Boolean, default=False, nullable=False, index=True)
revoked_at = db.Column(db.DateTime, nullable=True)
revoke_reason = db.Column(db.String(255), nullable=True)
# Status tracking
status = db.Column(
db.Enum(CertificateStatus, values_callable=lambda x: [e.value for e in x]),
default=CertificateStatus.ISSUED,
nullable=False,
index=True,
)
# Request metadata
request_ip = db.Column(db.String(45), nullable=True)
request_user_agent = db.Column(db.String(512), nullable=True)
# Critical options (JSON) - OpenSSH critical options
# See: https://man.openbsd.org/ssh-cert
critical_options = db.Column(db.JSON, nullable=True, default=dict)
# Extensions (JSON) - OpenSSH extensions
# Common ones: permit-X11-forwarding, permit-agent-forwarding, permit-pty, etc.
extensions = db.Column(db.JSON, nullable=True, default=dict)
# Relationships
ca = db.relationship("CA", back_populates="certificates")
user = db.relationship("User", back_populates="ssh_certificates")
ssh_key = db.relationship(
"SSHKey",
back_populates="certificates",
foreign_keys="SSHCertificate.ssh_key_id",
)
audit_logs = db.relationship(
"CertificateAuditLog",
back_populates="certificate",
cascade="all, delete-orphan",
)
__table_args__ = (
db.Index("idx_cert_user_status", "user_id", "status"),
db.Index("idx_cert_validity", "valid_after", "valid_before"),
db.Index("idx_cert_revoked", "revoked", "revoked_at"),
)
def __repr__(self):
"""String representation of SSHCertificate."""
return f"<SSHCertificate serial={self.serial[:16]}... user_id={self.user_id}>"
def to_dict(self, exclude=None):
"""Convert certificate to dictionary."""
exclude = exclude or []
# Optionally exclude the certificate content (it's large)
if "certificate" not in exclude:
exclude.append("certificate")
data = super().to_dict(exclude=exclude)
# Add computed fields
data["is_valid"] = self.is_valid()
data["days_until_expiry"] = self.days_until_expiry()
return data
def is_valid(self):
"""Check if certificate is currently valid.
Returns:
True if certificate is issued, not revoked, and within validity period
"""
if self.revoked or self.status == CertificateStatus.REVOKED:
return False
now = datetime.utcnow()
return self.valid_after <= now <= self.valid_before
def is_expired(self):
"""Check if certificate has expired.
Returns:
True if current time is past valid_before
"""
return datetime.utcnow() > self.valid_before
def days_until_expiry(self):
"""Get number of days until certificate expires.
Returns:
Number of days remaining (negative if already expired)
"""
delta = self.valid_before - datetime.utcnow()
return delta.days + (1 if delta.seconds > 0 else 0)
def revoke(self, reason=None):
"""Revoke this certificate.
Args:
reason: Optional reason for revocation
"""
self.revoked = True
self.revoked_at = datetime.utcnow()
self.revoke_reason = reason
self.status = CertificateStatus.REVOKED
self.save()
def mark_expired(self):
"""Mark certificate as expired when validity period ends."""
self.status = CertificateStatus.EXPIRED
self.save()
+96
View File
@@ -0,0 +1,96 @@
"""SSH Key model."""
from datetime import datetime
from gatehouse_app.extensions import db
from gatehouse_app.models.base import BaseModel
class SSHKey(BaseModel):
"""SSH Key model representing a user's SSH public key.
This model stores SSH public keys that users register for certificate signing.
Users must verify ownership of the key before it can be used for signing certificates.
"""
__tablename__ = "ssh_keys"
user_id = db.Column(
db.String(36),
db.ForeignKey("users.id"),
nullable=False,
index=True,
)
# SSH key payload in OpenSSH format (e.g., "ssh-rsa AAAAB3Nz...")
payload = db.Column(db.Text, nullable=False, unique=True)
# SHA256 fingerprint for quick comparison
fingerprint = db.Column(db.String(255), nullable=False, unique=True, index=True)
# Optional description for the key (e.g., "My laptop key")
description = db.Column(db.String(255), nullable=True)
# Verification status
verified = db.Column(db.Boolean, default=False, nullable=False, index=True)
verified_at = db.Column(db.DateTime, nullable=True)
# Verification challenge
verify_text = db.Column(db.String(255), nullable=True)
verify_text_created_at = db.Column(db.DateTime, nullable=True)
# Key type extracted from the key (ssh-rsa, ssh-ed25519, etc.)
key_type = db.Column(db.String(50), nullable=True)
# Key bits/length
key_bits = db.Column(db.Integer, nullable=True)
# Comment from the key (usually email or key name)
key_comment = db.Column(db.String(255), nullable=True)
# Relationships
user = db.relationship("User", back_populates="ssh_keys")
certificates = db.relationship(
"SSHCertificate",
back_populates="ssh_key",
cascade="all, delete-orphan",
foreign_keys="SSHCertificate.ssh_key_id",
)
__table_args__ = (
db.Index("idx_ssh_key_user_verified", "user_id", "verified"),
)
def __repr__(self):
"""String representation of SSHKey."""
return f"<SSHKey {self.fingerprint[:16]}... user_id={self.user_id}>"
def to_dict(self, exclude=None):
"""Convert SSH key to dictionary."""
exclude = exclude or []
exclude.extend(["payload", "verify_text"]) # Never expose these in API
data = super().to_dict(exclude=exclude)
# Add computed fields
data["cert_count"] = len([c for c in self.certificates if c.deleted_at is None])
return data
def mark_verified(self):
"""Mark this SSH key as verified."""
self.verified = True
self.verified_at = datetime.utcnow()
self.save()
def needs_verification_refresh(self, max_age_hours=24):
"""Check if verification challenge needs to be refreshed.
Args:
max_age_hours: Maximum age of verification challenge in hours
Returns:
True if verification challenge is stale
"""
if not self.verify_text_created_at:
return True
age = datetime.utcnow() - self.verify_text_created_at
return age.total_seconds() > (max_age_hours * 3600)
+12
View File
@@ -55,6 +55,18 @@ class User(BaseModel):
cascade="all, delete-orphan",
foreign_keys="PrincipalMembership.user_id",
)
ssh_keys = db.relationship(
"SSHKey",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="SSHKey.user_id",
)
ssh_certificates = db.relationship(
"SSHCertificate",
back_populates="user",
cascade="all, delete-orphan",
foreign_keys="SSHCertificate.user_id",
)
def __repr__(self):
"""String representation of User."""
+24 -3
View File
@@ -12,7 +12,7 @@ from gatehouse_app.models import User, AuthenticationMethod
from gatehouse_app.models.authentication_method import OAuthState
from gatehouse_app.models.base import BaseModel
from gatehouse_app.models.oidc_authorization_code import OIDCAuthCode
from gatehouse_app.utils.constants import AuthMethodType
from gatehouse_app.utils.constants import AuthMethodType, AuditAction
from gatehouse_app.services.audit_service import AuditService
from gatehouse_app.services.external_auth_service import (
ExternalAuthService,
@@ -139,7 +139,7 @@ class OAuthFlowService:
except ExternalAuthError as e:
# Log failed initiation
AuditService.log_action(
action="external_auth.login.initiated",
action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED,
organization_id=organization_id,
metadata={
"provider_type": provider_type_str,
@@ -236,7 +236,7 @@ class OAuthFlowService:
except ExternalAuthError as e:
AuditService.log_action(
action="external_auth.register.initiated",
action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED,
organization_id=organization_id,
metadata={
"provider_type": provider_type_str,
@@ -399,6 +399,27 @@ class OAuthFlowService:
access_token=tokens["access_token"],
)
if not user_info.get("provider_user_id"):
raise OAuthFlowError(
"Provider did not return a user identifier (sub claim). "
"Cannot complete authentication.",
"MISSING_PROVIDER_USER_ID",
400,
)
if not user_info.get("email"):
raise OAuthFlowError(
"Provider did not return an email address. "
"Cannot complete authentication.",
"MISSING_EMAIL",
400,
)
logger.debug(
f"Got user_info from provider: sub={user_info['provider_user_id']}, "
f"email={user_info['email']}, email_verified={user_info.get('email_verified')}"
)
# Look up user by provider_user_id
auth_method = AuthenticationMethod.query.filter_by(
method_type=provider_type,
@@ -0,0 +1,333 @@
"""SSH Certificate Authority signing service.
Handles SSH certificate signing operations, leveraging sshkey-tools library.
This service is a Gatehouse-integrated version of the secuird/ssh_ca.py logic.
"""
import logging
import os
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional
from sshkey_tools.cert import SSHCertificate, CertificateFields
from sshkey_tools.keys import PublicKey, PrivateKey
from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config
from gatehouse_app.exceptions import SSHCAError, ValidationError
from gatehouse_app.utils.crypto import compute_ssh_fingerprint
logger = logging.getLogger(__name__)
class SSHCASigningError(Exception):
"""SSH CA signing operation error."""
pass
class SSHCertificateSigningRequest:
"""Represents an SSH certificate signing request."""
def __init__(
self,
ssh_public_key: str,
principals: List[str],
key_id: str,
cert_type: str = "user",
expiry_hours: Optional[int] = None,
critical_options: Optional[Dict[str, str]] = None,
extensions: Optional[List[str]] = None,
):
"""Initialize signing request.
Args:
ssh_public_key: Public key in OpenSSH format (e.g., "ssh-ed25519 AAAA...")
principals: List of principals (e.g., ["prod-servers", "staging"])
key_id: Key identifier (usually user email)
cert_type: Certificate type - "user" or "host" (default: user)
expiry_hours: Certificate validity in hours
critical_options: Critical options dict
extensions: List of extensions (e.g., ["permit-pty", "permit-agent-forwarding"])
"""
self.ssh_public_key = ssh_public_key
self.principals = principals or []
self.key_id = key_id
self.cert_type = cert_type
self.expiry_hours = expiry_hours
self.critical_options = critical_options or {}
self.extensions = extensions or []
def validate(self) -> List[str]:
"""Validate the signing request.
Returns:
List of validation errors (empty if valid)
"""
errors = []
config = get_ssh_ca_config()
# Validate cert type
if self.cert_type not in ("user", "host"):
errors.append(f"Invalid cert_type: {self.cert_type}. Must be 'user' or 'host'")
# Validate SSH public key
if not self.ssh_public_key or len(self.ssh_public_key) < 16:
errors.append("SSH public key is missing or invalid")
else:
try:
PublicKey.from_string(self.ssh_public_key)
except Exception as e:
errors.append(f"SSH public key is not valid: {str(e)}")
# Validate principals
if not self.principals or len(self.principals) == 0:
errors.append("At least one principal is required")
else:
max_principals = config.get_int('max_principals_per_cert')
if len(self.principals) > max_principals:
errors.append(
f"Too many principals ({len(self.principals)}). "
f"Maximum is {max_principals}"
)
# Validate key_id
if not self.key_id or len(self.key_id) < 5:
errors.append("key_id is missing or too short (minimum 5 characters)")
else:
max_id_len = config.get_int('max_key_id_length')
if len(self.key_id) > max_id_len:
errors.append(f"key_id exceeds maximum length of {max_id_len}")
# Validate expiry_hours
if self.expiry_hours is not None:
if not isinstance(self.expiry_hours, int) or self.expiry_hours <= 0:
errors.append("expiry_hours must be a positive integer")
else:
max_validity = config.get_int('max_cert_validity_hours')
if self.expiry_hours > max_validity:
errors.append(
f"Requested expiry ({self.expiry_hours}h) exceeds "
f"maximum allowed ({max_validity}h)"
)
return errors
class SSHCertificateSigningResponse:
"""Represents a signed SSH certificate response."""
def __init__(
self,
certificate: str,
serial: str,
valid_after: datetime,
valid_before: datetime,
principals: Optional[List[str]] = None,
):
"""Initialize signing response.
Args:
certificate: Full certificate in OpenSSH format
serial: Certificate serial number
valid_after: Validity start datetime
valid_before: Validity end datetime
principals: List of principals the cert was issued for
"""
self.certificate = certificate
self.serial = serial
self.valid_after = valid_after
self.valid_before = valid_before
self.principals = principals or []
def to_dict(self) -> Dict[str, Any]:
"""Convert response to dictionary."""
return {
'certificate': self.certificate,
'serial': self.serial,
'valid_after': self.valid_after.isoformat(),
'valid_before': self.valid_before.isoformat(),
}
class SSHCASigningService:
"""Service for signing SSH certificates.
This service handles all SSH certificate signing operations.
It uses configuration from ssh_ca_config to apply rules and limits.
"""
def __init__(self):
"""Initialize the SSH CA signing service."""
self.config = get_ssh_ca_config()
self.logger = logger
def _load_ca_key_from_config(self) -> str:
"""Load CA private key from config (local file or env var).
Returns:
CA private key in PEM/OpenSSH format as string
Raises:
SSHCASigningError: If key cannot be loaded
"""
# Check env var first
key_content = os.environ.get('SSH_CA_PRIVATE_KEY')
if key_content:
return key_content
# Load from file path
key_path = self.config.get_str('ca_key_path', '').strip()
if not key_path:
raise SSHCASigningError(
"CA private key not configured. Set SSH_CA_PRIVATE_KEY env var "
"or ca_key_path in etc/ssh_ca.conf"
)
key_path = os.path.expandvars(os.path.expanduser(key_path))
if not os.path.exists(key_path):
raise SSHCASigningError(f"CA private key file not found: {key_path}")
with open(key_path, 'r') as f:
return f.read()
def sign_certificate(
self,
signing_request: SSHCertificateSigningRequest,
ca_private_key: Optional[str] = None,
) -> SSHCertificateSigningResponse:
"""Sign an SSH certificate.
Args:
signing_request: SSHCertificateSigningRequest instance
ca_private_key: CA private key in PEM format. If not provided,
loaded from config (ca_key_path or SSH_CA_PRIVATE_KEY env var)
Returns:
SSHCertificateSigningResponse with signed certificate
Raises:
SSHCASigningError: If signing fails
ValidationError: If request is invalid
"""
# Validate request
errors = signing_request.validate()
if errors:
error_msg = "; ".join(errors)
self.logger.error(f"Certificate signing validation failed: {error_msg}")
raise ValidationError(f"Certificate signing validation failed: {error_msg}")
# Load CA key if not provided
if ca_private_key is None:
ca_private_key = self._load_ca_key_from_config()
try:
# Parse CA private key
try:
ca_key = PrivateKey.from_string(ca_private_key)
except Exception as e:
self.logger.error(f"Failed to load CA private key: {str(e)}")
raise SSHCASigningError(f"Invalid CA private key: {str(e)}")
# Parse user's public key
try:
user_pub_key = PublicKey.from_string(signing_request.ssh_public_key)
except Exception as e:
self.logger.error(f"Failed to parse user public key: {str(e)}")
raise SSHCASigningError(f"Invalid user public key: {str(e)}")
# Create certificate
certificate = SSHCertificate.create(
subject_pubkey=user_pub_key,
ca_privkey=ca_key,
)
# Set validity period
now = datetime.utcnow()
expiry_hours = signing_request.expiry_hours or self.config.get_int('cert_validity_hours')
valid_before = now + timedelta(hours=expiry_hours)
# Set certificate fields
cert_type = 1 if signing_request.cert_type == "user" else 0
certificate.fields.cert_type = cert_type
certificate.fields.key_id = signing_request.key_id
certificate.fields.principals = signing_request.principals
certificate.fields.valid_after = now
certificate.fields.valid_before = valid_before
# Set extensions
extensions = signing_request.extensions
if not extensions and self.config.get_bool('extensions_enabled'):
extensions = self.config.get_list('extensions')
certificate.fields.extensions = extensions or []
certificate.fields.critical_options = signing_request.critical_options or {}
# Validate certificate before signing
if not certificate.can_sign():
raise SSHCASigningError("Certificate cannot be signed")
# Sign the certificate
certificate.sign()
# Verify the certificate
try:
certificate.verify(ca_key.public_key, raise_on_error=True)
except Exception as e:
self.logger.error(f"Certificate verification failed: {str(e)}")
raise SSHCASigningError(f"Certificate verification failed: {str(e)}")
# Extract serial from certificate
serial = str(certificate.fields.serial).split(":")[-1].strip() if hasattr(certificate.fields.serial, '__str__') else str(certificate.fields.serial)
# Build response
cert_string = certificate.to_string()
self.logger.info(
f"Successfully signed certificate: serial={serial}, "
f"key_id={signing_request.key_id}, principals={signing_request.principals}"
)
return SSHCertificateSigningResponse(
certificate=cert_string,
serial=serial,
valid_after=now,
valid_before=valid_before,
principals=signing_request.principals,
)
except (SSHCASigningError, ValidationError):
raise
except Exception as e:
self.logger.error(f"Unexpected error during certificate signing: {str(e)}", exc_info=True)
raise SSHCASigningError(f"Error signing certificate: {str(e)}")
def verify_ca_key(self, ca_private_key: str) -> Dict[str, Any]:
"""Verify a CA private key is valid and extract metadata.
Args:
ca_private_key: CA private key in PEM format
Returns:
Dictionary with key metadata (fingerprint, key_type, etc.)
Raises:
SSHCASigningError: If key is invalid
"""
try:
ca_key = PrivateKey.from_string(ca_private_key)
pub_key = ca_key.public_key
# Compute fingerprint
fingerprint = compute_ssh_fingerprint(pub_key.to_string())
# Get key type
key_type = pub_key.keytype if hasattr(pub_key, 'keytype') else 'unknown'
return {
'fingerprint': fingerprint,
'key_type': key_type,
'public_key': pub_key.to_string(),
'valid': True,
}
except Exception as e:
self.logger.error(f"CA key verification failed: {str(e)}")
raise SSHCASigningError(f"Invalid CA key: {str(e)}")
+373
View File
@@ -0,0 +1,373 @@
"""SSH Key management service."""
import base64
import logging
import os
import secrets
import subprocess
import tempfile
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any
from gatehouse_app.extensions import db
from gatehouse_app.models import SSHKey, User
from gatehouse_app.exceptions import (
SSHKeyError,
SSHKeyNotFoundError,
SSHKeyAlreadyExistsError,
SSHKeyNotVerifiedError,
ValidationError,
UserNotFoundError,
)
from gatehouse_app.utils.crypto import (
compute_ssh_fingerprint,
verify_ssh_key_format,
extract_ssh_key_type,
extract_ssh_key_comment,
)
from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config
logger = logging.getLogger(__name__)
class SSHKeyService:
"""Service for managing SSH keys."""
def __init__(self):
"""Initialize SSH key service."""
self.config = get_ssh_ca_config()
def add_ssh_key(
self,
user_id: str,
public_key: str,
description: Optional[str] = None,
) -> SSHKey:
"""Add an SSH public key for a user.
Args:
user_id: ID of the user
public_key: SSH public key in OpenSSH format
description: Optional description of the key
Returns:
Created SSHKey instance
Raises:
UserNotFoundError: If user doesn't exist
SSHKeyError: If key format is invalid
SSHKeyAlreadyExistsError: If key already exists
"""
# Verify user exists
user = User.query.get(user_id)
if not user:
raise UserNotFoundError(f"User {user_id} not found")
# Validate key format
if not verify_ssh_key_format(public_key):
raise SSHKeyError("Invalid SSH public key format")
# Compute fingerprint
try:
fingerprint = compute_ssh_fingerprint(public_key)
except Exception as e:
logger.error(f"Failed to compute fingerprint: {str(e)}")
raise SSHKeyError(f"Failed to compute key fingerprint: {str(e)}")
# Check for duplicate (including soft-deleted records — fingerprint is unique in DB)
existing = SSHKey.query.filter_by(fingerprint=fingerprint).first()
if existing:
if existing.deleted_at is not None:
# Restore the soft-deleted key: clear deleted_at and update fields
existing.deleted_at = None
existing.user_id = user_id
existing.description = description or existing.description
existing.verified = False
existing.verified_at = None
existing.verify_text = None
existing.verify_text_created_at = None
db.session.commit()
logger.info(
f"Restored soft-deleted SSH key for user {user_id}: "
f"fingerprint={fingerprint}"
)
return existing
raise SSHKeyAlreadyExistsError(
f"SSH key with fingerprint {fingerprint} already exists"
)
# Extract metadata
key_type = extract_ssh_key_type(public_key)
key_comment = extract_ssh_key_comment(public_key)
# Create SSH key record
ssh_key = SSHKey(
user_id=user_id,
payload=public_key,
fingerprint=fingerprint,
description=description,
key_type=key_type,
key_comment=key_comment,
verified=False,
)
ssh_key.save()
logger.info(
f"SSH key added for user {user_id}: "
f"fingerprint={fingerprint}, type={key_type}"
)
return ssh_key
def get_ssh_key(self, key_id: str) -> SSHKey:
"""Get an SSH key by ID.
Args:
key_id: SSH key ID
Returns:
SSHKey instance
Raises:
SSHKeyNotFoundError: If key not found
"""
key = SSHKey.query.filter_by(id=key_id, deleted_at=None).first()
if not key:
raise SSHKeyNotFoundError(f"SSH key {key_id} not found")
return key
def get_user_ssh_keys(self, user_id: str) -> List[SSHKey]:
"""Get all SSH keys for a user.
Args:
user_id: User ID
Returns:
List of SSHKey instances
"""
return SSHKey.query.filter_by(user_id=user_id, deleted_at=None).all()
def get_user_verified_ssh_keys(self, user_id: str) -> List[SSHKey]:
"""Get all verified SSH keys for a user.
Args:
user_id: User ID
Returns:
List of verified SSHKey instances
"""
return SSHKey.query.filter_by(
user_id=user_id,
verified=True,
deleted_at=None,
).all()
def delete_ssh_key(self, key_id: str) -> None:
"""Soft-delete an SSH key.
Args:
key_id: SSH key ID
Raises:
SSHKeyNotFoundError: If key not found
"""
key = self.get_ssh_key(key_id)
key.delete()
logger.info(f"SSH key deleted: {key_id}")
def generate_verification_challenge(self, key_id: str) -> str:
"""Generate a verification challenge for an SSH key.
The user must sign this challenge text with their private key
to prove key ownership.
Args:
key_id: SSH key ID
Returns:
Verification challenge text
Raises:
SSHKeyNotFoundError: If key not found
"""
key = self.get_ssh_key(key_id)
# Generate random challenge
challenge = secrets.token_hex(32)
challenge_text = f"Please sign this to verify SSH key ownership: {challenge}"
# Store challenge
key.verify_text = challenge_text
key.verify_text_created_at = datetime.utcnow()
key.save()
logger.info(f"Generated verification challenge for SSH key {key_id}")
return challenge_text
def verify_ssh_key_ownership(
self,
key_id: str,
signature: str,
) -> bool:
"""Verify SSH key ownership via signature.
The user must sign the verification challenge with their private key.
We verify the signature using the public key.
Args:
key_id: SSH key ID
signature: Base64-encoded signature of the challenge
Returns:
True if signature is valid
Raises:
SSHKeyNotFoundError: If key not found
SSHKeyNotVerifiedError: If challenge is stale or missing
SSHKeyError: If verification fails
"""
key = self.get_ssh_key(key_id)
# Check if challenge exists and is not stale
if not key.verify_text or not key.verify_text_created_at:
raise SSHKeyNotVerifiedError("No verification challenge generated")
max_age = self.config.get_int('verification_challenge_max_age')
age = datetime.utcnow() - key.verify_text_created_at
if age.total_seconds() > (max_age * 3600):
raise SSHKeyNotVerifiedError("Verification challenge has expired")
try:
# Verify the SSH signature using ssh-keygen -Y verify.
# The CLI signs the challenge with: ssh-keygen -Y sign -f <key> -n file <challenge>
# We verify with: ssh-keygen -Y verify -f <allowed_signers> -I <identity> -n file -s <sig> < <message>
#
# allowed_signers format: "<identity> <keytype> <pubkey>"
# We use the key fingerprint as the identity.
sig_bytes = base64.b64decode(signature)
challenge_text = key.verify_text + "\n"
with tempfile.TemporaryDirectory() as tmpdir:
allowed_signers_path = os.path.join(tmpdir, "allowed_signers")
sig_path = os.path.join(tmpdir, "message.sig")
message_path = os.path.join(tmpdir, "message.txt")
identity = key.fingerprint
# Write the allowed_signers file
with open(allowed_signers_path, "w") as f:
f.write(f"{identity} {key.payload}\n")
# Write the signature file
with open(sig_path, "wb") as f:
f.write(sig_bytes)
# Write the challenge message
with open(message_path, "w") as f:
f.write(challenge_text)
result = subprocess.run(
[
"ssh-keygen", "-Y", "verify",
"-f", allowed_signers_path,
"-I", identity,
"-n", "file",
"-s", sig_path,
],
stdin=open(message_path, "rb"),
capture_output=True,
timeout=10,
)
if result.returncode != 0:
stderr = result.stderr.decode(errors="replace").strip()
logger.warning(f"SSH signature verification failed for key {key_id}: {stderr}")
raise SSHKeyError(f"Signature verification failed: {stderr}")
key.mark_verified()
logger.info(f"SSH key verified: {key_id}")
return True
except Exception as e:
logger.error(f"SSH key verification failed: {str(e)}")
raise SSHKeyError(f"Signature verification failed: {str(e)}")
def get_key_fingerprint(self, key_id: str) -> str:
"""Get the fingerprint of an SSH key.
Args:
key_id: SSH key ID
Returns:
Fingerprint string
Raises:
SSHKeyNotFoundError: If key not found
"""
key = self.get_ssh_key(key_id)
return key.fingerprint
def update_ssh_key_description(self, key_id: str, description: str) -> SSHKey:
"""Update the description of an SSH key.
Args:
key_id: SSH key ID
description: New description
Returns:
Updated SSHKey instance
Raises:
SSHKeyNotFoundError: If key not found
"""
key = self.get_ssh_key(key_id)
key.description = description
key.save()
return key
def cleanup_expired_challenges(self) -> int:
"""Clean up expired verification challenges.
Returns:
Number of challenges cleaned
"""
max_age = self.config.get_int('verification_challenge_max_age')
threshold = datetime.utcnow() - timedelta(hours=max_age)
expired = SSHKey.query.filter(
SSHKey.verify_text_created_at < threshold,
SSHKey.verify_text_created_at.isnot(None),
SSHKey.deleted_at.is_(None),
).update({"verify_text": None, "verify_text_created_at": None})
db.session.commit()
logger.info(f"Cleaned up {expired} expired verification challenges")
return expired
def cleanup_unverified_keys(self) -> int:
"""Delete unverified SSH keys older than configured days.
Returns:
Number of keys deleted
"""
days = self.config.get_int('auto_delete_unverified_days')
threshold = datetime.utcnow() - timedelta(days=days)
old_unverified = SSHKey.query.filter(
SSHKey.verified == False,
SSHKey.created_at < threshold,
SSHKey.deleted_at.is_(None),
).all()
count = 0
for key in old_unverified:
key.delete()
count += 1
logger.info(f"Deleted {count} unverified SSH keys older than {days} days")
return count
+41
View File
@@ -12,6 +12,16 @@ class UserStatus(str, Enum):
COMPLIANCE_SUSPENDED = "compliance_suspended"
class Role(str, Enum):
"""Generic role definitions (hierarchy: Admin > Manager > Member > Viewer > Guest)."""
ADMIN = "admin"
MANAGER = "manager"
MEMBER = "member"
VIEWER = "viewer"
GUEST = "guest"
class OrganizationRole(str, Enum):
"""Organization member roles."""
@@ -105,6 +115,37 @@ class AuditAction(str, Enum):
EXTERNAL_AUTH_CONFIG_UPDATE = "external_auth.config.update"
EXTERNAL_AUTH_CONFIG_DELETE = "external_auth.config.delete"
# SSH Key and Certificate actions
SSH_KEY_ADDED = "ssh.key.added"
SSH_KEY_VERIFIED = "ssh.key.verified"
SSH_KEY_DELETED = "ssh.key.deleted"
SSH_KEY_VALIDATION_FAILED = "ssh.key.validation.failed"
SSH_CERT_REQUESTED = "ssh.cert.requested"
SSH_CERT_ISSUED = "ssh.cert.issued"
SSH_CERT_FAILED = "ssh.cert.failed"
SSH_CERT_REVOKED = "ssh.cert.revoked"
SSH_CERT_EXPIRED = "ssh.cert.expired"
# CA actions
CA_CREATED = "ca.created"
CA_UPDATED = "ca.updated"
CA_DELETED = "ca.deleted"
CA_KEY_ROTATED = "ca.key.rotated"
# Principal actions
PRINCIPAL_CREATED = "principal.created"
PRINCIPAL_UPDATED = "principal.updated"
PRINCIPAL_DELETED = "principal.deleted"
PRINCIPAL_MEMBER_ADDED = "principal.member.added"
PRINCIPAL_MEMBER_REMOVED = "principal.member.removed"
# Department actions
DEPARTMENT_CREATED = "department.created"
DEPARTMENT_UPDATED = "department.updated"
DEPARTMENT_DELETED = "department.deleted"
DEPARTMENT_MEMBER_ADDED = "department.member.added"
DEPARTMENT_MEMBER_REMOVED = "department.member.removed"
class OIDCGrantType(str, Enum):
"""OIDC grant types."""
+128
View File
@@ -0,0 +1,128 @@
"""Cryptographic utilities for SSH operations."""
import hashlib
import base64
from typing import Optional
def compute_ssh_fingerprint(public_key_str: str, hash_algorithm: str = "sha256") -> str:
"""Compute the fingerprint of an SSH public key.
Args:
public_key_str: SSH public key in OpenSSH format
hash_algorithm: Hash algorithm to use (sha256, sha1, md5)
Returns:
Fingerprint string in the format "algorithm:hex_digest"
Example:
>>> key = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKp2..."
>>> fp = compute_ssh_fingerprint(key)
>>> print(fp)
sha256:Kb+...
"""
if not public_key_str:
raise ValueError("Public key string is empty")
# Parse OpenSSH format: "ssh-ed25519 <base64> [comment]"
parts = public_key_str.strip().split()
if len(parts) < 2:
raise ValueError("Invalid OpenSSH public key format")
try:
# The base64-encoded key is the second part
key_bytes = base64.b64decode(parts[1])
except Exception as e:
raise ValueError(f"Failed to decode public key: {str(e)}")
# Compute hash
if hash_algorithm == "sha256":
digest = hashlib.sha256(key_bytes).digest()
# SSH format uses base64 encoding without padding
fingerprint = base64.b64encode(digest).decode().rstrip('=')
elif hash_algorithm == "sha1":
digest = hashlib.sha1(key_bytes).hexdigest()
fingerprint = digest
elif hash_algorithm == "md5":
digest = hashlib.md5(key_bytes).hexdigest()
# Format as colons
fingerprint = ':'.join(digest[i:i+2] for i in range(0, len(digest), 2))
else:
raise ValueError(f"Unsupported hash algorithm: {hash_algorithm}")
return f"{hash_algorithm}:{fingerprint}"
def verify_ssh_key_format(public_key_str: str) -> bool:
"""Verify that a string is in valid OpenSSH public key format.
Args:
public_key_str: Potential SSH public key
Returns:
True if valid OpenSSH format, False otherwise
"""
if not public_key_str or not isinstance(public_key_str, str):
return False
parts = public_key_str.strip().split()
# Must have at least key type and key material
if len(parts) < 2:
return False
key_type = parts[0]
# Valid key types
valid_types = [
'ssh-rsa',
'ssh-ed25519',
'ecdsa-sha2-nistp256',
'ecdsa-sha2-nistp384',
'ecdsa-sha2-nistp521',
'ssh-dss',
]
if key_type not in valid_types:
return False
# Try to decode base64
try:
base64.b64decode(parts[1])
return True
except Exception:
return False
def extract_ssh_key_type(public_key_str: str) -> Optional[str]:
"""Extract the key type from an OpenSSH public key.
Args:
public_key_str: SSH public key in OpenSSH format
Returns:
Key type (e.g., "ssh-ed25519") or None if invalid
"""
if not verify_ssh_key_format(public_key_str):
return None
return public_key_str.strip().split()[0]
def extract_ssh_key_comment(public_key_str: str) -> Optional[str]:
"""Extract the comment from an OpenSSH public key.
Args:
public_key_str: SSH public key in OpenSSH format
Returns:
Comment string or None if not present
"""
if not verify_ssh_key_format(public_key_str):
return None
parts = public_key_str.strip().split()
if len(parts) >= 3:
# Everything after the second part is the comment
return ' '.join(parts[2:])
return None
@@ -0,0 +1,173 @@
"""Add SSH CA models: SSHKey, SSHCertificate, CA, CertificateAuditLog.
Revision ID: 007
Revises: 006
Create Date: 2026-02-27 11:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '007'
down_revision = '006'
branch_labels = None
depends_on = None
def upgrade():
# ### CA table ###
op.create_table('cas',
sa.Column('organization_id', sa.String(length=36), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('key_type', sa.Enum('ed25519', 'rsa', 'ecdsa', name='ca_key_type_enum'), nullable=False),
sa.Column('private_key', sa.Text(), nullable=False),
sa.Column('public_key', sa.Text(), nullable=False),
sa.Column('fingerprint', sa.String(length=255), nullable=False),
sa.Column('crl_enabled', sa.Boolean(), nullable=False),
sa.Column('crl_endpoint', sa.String(length=512), nullable=True),
sa.Column('default_cert_validity_hours', sa.Integer(), nullable=False),
sa.Column('max_cert_validity_hours', sa.Integer(), nullable=False),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('rotated_at', sa.DateTime(), nullable=True),
sa.Column('rotation_reason', sa.String(length=255), nullable=True),
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('fingerprint'),
sa.UniqueConstraint('organization_id', 'name', name='uix_org_ca_name')
)
op.create_index(op.f('ix_cas_organization_id'), 'cas', ['organization_id'], unique=False)
op.create_index('idx_ca_org_active', 'cas', ['organization_id', 'is_active'], unique=False)
# ### SSHKey table ###
op.create_table('ssh_keys',
sa.Column('user_id', sa.String(length=36), nullable=False),
sa.Column('payload', sa.Text(), nullable=False),
sa.Column('fingerprint', sa.String(length=255), nullable=False),
sa.Column('description', sa.String(length=255), nullable=True),
sa.Column('verified', sa.Boolean(), nullable=False),
sa.Column('verified_at', sa.DateTime(), nullable=True),
sa.Column('verify_text', sa.String(length=255), nullable=True),
sa.Column('verify_text_created_at', sa.DateTime(), nullable=True),
sa.Column('key_type', sa.String(length=50), nullable=True),
sa.Column('key_bits', sa.Integer(), nullable=True),
sa.Column('key_comment', sa.String(length=255), nullable=True),
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('payload'),
sa.UniqueConstraint('fingerprint')
)
op.create_index(op.f('ix_ssh_keys_user_id'), 'ssh_keys', ['user_id'], unique=False)
op.create_index(op.f('ix_ssh_keys_fingerprint'), 'ssh_keys', ['fingerprint'], unique=False)
op.create_index(op.f('ix_ssh_keys_verified'), 'ssh_keys', ['verified'], unique=False)
op.create_index('idx_ssh_key_user_verified', 'ssh_keys', ['user_id', 'verified'], unique=False)
# ### SSHCertificate table ###
op.create_table('ssh_certificates',
sa.Column('ca_id', sa.String(length=36), nullable=False),
sa.Column('user_id', sa.String(length=36), nullable=False),
sa.Column('ssh_key_id', sa.String(length=36), nullable=False),
sa.Column('certificate', sa.Text(), nullable=False),
sa.Column('serial', sa.String(length=255), nullable=False),
sa.Column('key_id', sa.String(length=255), nullable=False),
sa.Column('cert_type', sa.Enum('user', 'host', name='ssh_cert_type_enum'), nullable=False),
sa.Column('principals', sa.JSON(), nullable=False),
sa.Column('valid_after', sa.DateTime(), nullable=False),
sa.Column('valid_before', sa.DateTime(), nullable=False),
sa.Column('revoked', sa.Boolean(), nullable=False),
sa.Column('revoked_at', sa.DateTime(), nullable=True),
sa.Column('revoke_reason', sa.String(length=255), nullable=True),
sa.Column('status', sa.Enum('requested', 'issued', 'revoked', 'expired', 'superseded', name='ssh_cert_status_enum'), nullable=False),
sa.Column('request_ip', sa.String(length=45), nullable=True),
sa.Column('request_user_agent', sa.String(length=512), nullable=True),
sa.Column('critical_options', sa.JSON(), nullable=True),
sa.Column('extensions', sa.JSON(), nullable=True),
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['ca_id'], ['cas.id'], ),
sa.ForeignKeyConstraint(['ssh_key_id'], ['ssh_keys.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id'),
sa.UniqueConstraint('serial')
)
op.create_index(op.f('ix_ssh_certificates_ca_id'), 'ssh_certificates', ['ca_id'], unique=False)
op.create_index(op.f('ix_ssh_certificates_user_id'), 'ssh_certificates', ['user_id'], unique=False)
op.create_index(op.f('ix_ssh_certificates_ssh_key_id'), 'ssh_certificates', ['ssh_key_id'], unique=False)
op.create_index(op.f('ix_ssh_certificates_serial'), 'ssh_certificates', ['serial'], unique=False)
op.create_index(op.f('ix_ssh_certificates_revoked'), 'ssh_certificates', ['revoked'], unique=False)
op.create_index(op.f('ix_ssh_certificates_status'), 'ssh_certificates', ['status'], unique=False)
op.create_index('idx_cert_user_status', 'ssh_certificates', ['user_id', 'status'], unique=False)
op.create_index('idx_cert_validity', 'ssh_certificates', ['valid_after', 'valid_before'], unique=False)
op.create_index('idx_cert_revoked', 'ssh_certificates', ['revoked', 'revoked_at'], unique=False)
# ### CertificateAuditLog table ###
op.create_table('certificate_audit_logs',
sa.Column('certificate_id', sa.String(length=36), nullable=False),
sa.Column('user_id', sa.String(length=36), nullable=True),
sa.Column('action', sa.String(length=50), nullable=False),
sa.Column('ip_address', sa.String(length=45), nullable=True),
sa.Column('user_agent', sa.String(length=512), nullable=True),
sa.Column('request_id', sa.String(length=36), nullable=True),
sa.Column('message', sa.Text(), nullable=True),
sa.Column('extra_data', sa.JSON(), nullable=True),
sa.Column('success', sa.Boolean(), nullable=False),
sa.Column('error_message', sa.Text(), nullable=True),
sa.Column('id', sa.String(length=36), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['certificate_id'], ['ssh_certificates.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('id')
)
op.create_index(op.f('ix_certificate_audit_logs_certificate_id'), 'certificate_audit_logs', ['certificate_id'], unique=False)
op.create_index(op.f('ix_certificate_audit_logs_user_id'), 'certificate_audit_logs', ['user_id'], unique=False)
op.create_index(op.f('ix_certificate_audit_logs_action'), 'certificate_audit_logs', ['action'], unique=False)
op.create_index('idx_cert_audit_cert_action', 'certificate_audit_logs', ['certificate_id', 'action'], unique=False)
op.create_index('idx_cert_audit_user', 'certificate_audit_logs', ['user_id', 'created_at'], unique=False)
def downgrade():
op.drop_index('idx_cert_audit_user', table_name='certificate_audit_logs')
op.drop_index('idx_cert_audit_cert_action', table_name='certificate_audit_logs')
op.drop_index(op.f('ix_certificate_audit_logs_action'), table_name='certificate_audit_logs')
op.drop_index(op.f('ix_certificate_audit_logs_user_id'), table_name='certificate_audit_logs')
op.drop_index(op.f('ix_certificate_audit_logs_certificate_id'), table_name='certificate_audit_logs')
op.drop_table('certificate_audit_logs')
op.drop_index('idx_cert_revoked', table_name='ssh_certificates')
op.drop_index('idx_cert_validity', table_name='ssh_certificates')
op.drop_index('idx_cert_user_status', table_name='ssh_certificates')
op.drop_index(op.f('ix_ssh_certificates_status'), table_name='ssh_certificates')
op.drop_index(op.f('ix_ssh_certificates_revoked'), table_name='ssh_certificates')
op.drop_index(op.f('ix_ssh_certificates_serial'), table_name='ssh_certificates')
op.drop_index(op.f('ix_ssh_certificates_ssh_key_id'), table_name='ssh_certificates')
op.drop_index(op.f('ix_ssh_certificates_user_id'), table_name='ssh_certificates')
op.drop_index(op.f('ix_ssh_certificates_ca_id'), table_name='ssh_certificates')
op.drop_table('ssh_certificates')
op.drop_index('idx_ssh_key_user_verified', table_name='ssh_keys')
op.drop_index(op.f('ix_ssh_keys_verified'), table_name='ssh_keys')
op.drop_index(op.f('ix_ssh_keys_fingerprint'), table_name='ssh_keys')
op.drop_index(op.f('ix_ssh_keys_user_id'), table_name='ssh_keys')
op.drop_table('ssh_keys')
op.drop_index('idx_ca_org_active', table_name='cas')
op.drop_index(op.f('ix_cas_organization_id'), table_name='cas')
op.drop_table('cas')
@@ -0,0 +1,53 @@
"""Add TOTP and WEBAUTHN to authmethodtype enum.
Revision ID: 008
Revises: 007
Create Date: 2026-02-27 15:00:00.000000
The original migration (001_base) created authmethodtype with only:
PASSWORD, GOOGLE, GITHUB, MICROSOFT, SAML, OIDC
This migration adds the missing TOTP and WEBAUTHN values so
has_totp_enabled() and has_webauthn_enabled() queries work correctly.
"""
from alembic import op
import sqlalchemy as sa
revision = '008'
down_revision = '007'
branch_labels = None
depends_on = None
def upgrade():
# Add TOTP to the enum (idempotent approach using DO block)
op.execute("""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM pg_enum
WHERE enumlabel = 'TOTP'
AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'authmethodtype')
) THEN
ALTER TYPE authmethodtype ADD VALUE 'TOTP';
END IF;
END$$;
""")
op.execute("""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM pg_enum
WHERE enumlabel = 'WEBAUTHN'
AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'authmethodtype')
) THEN
ALTER TYPE authmethodtype ADD VALUE 'WEBAUTHN';
END IF;
END$$;
""")
def downgrade():
# PostgreSQL does not support removing enum values; downgrade is a no-op.
pass
@@ -0,0 +1,61 @@
"""Sync auditaction enum with all AuditAction Python enum values.
Revision ID: 009
Revises: 008
Create Date: 2026-02-27 15:20:00.000000
The auditaction DB enum was only created with the initial 17 values from 001_base.py.
All TOTP, WebAuthn, OAuth, SSH, CA, Principal, and Department audit actions were added
to the Python enum but never synced to the DB type.
"""
from alembic import op
revision = '009'
down_revision = '008'
branch_labels = None
depends_on = None
MISSING_VALUES = [
'TOTP_ENROLL_INITIATED', 'TOTP_ENROLL_COMPLETED', 'TOTP_VERIFY_SUCCESS',
'TOTP_VERIFY_FAILED', 'TOTP_DISABLED', 'TOTP_BACKUP_CODE_USED',
'TOTP_BACKUP_CODES_REGENERATED', 'WEBAUTHN_REGISTER_INITIATED',
'WEBAUTHN_REGISTER_COMPLETED', 'WEBAUTHN_REGISTER_FAILED',
'WEBAUTHN_LOGIN_INITIATED', 'WEBAUTHN_LOGIN_SUCCESS', 'WEBAUTHN_LOGIN_FAILED',
'WEBAUTHN_CREDENTIAL_DELETED', 'WEBAUTHN_CREDENTIAL_RENAMED',
'ORG_SECURITY_POLICY_UPDATE', 'USER_SECURITY_POLICY_OVERRIDE_UPDATE',
'MFA_POLICY_USER_SUSPENDED', 'MFA_POLICY_USER_COMPLIANT',
'EXTERNAL_AUTH_LINK_INITIATED', 'EXTERNAL_AUTH_LINK_COMPLETED',
'EXTERNAL_AUTH_LINK_FAILED', 'EXTERNAL_AUTH_UNLINK', 'EXTERNAL_AUTH_LOGIN',
'EXTERNAL_AUTH_LOGIN_FAILED', 'EXTERNAL_AUTH_TOKEN_REFRESH',
'EXTERNAL_AUTH_CONFIG_CREATE', 'EXTERNAL_AUTH_CONFIG_UPDATE',
'EXTERNAL_AUTH_CONFIG_DELETE', 'SSH_KEY_ADDED', 'SSH_KEY_VERIFIED',
'SSH_KEY_DELETED', 'SSH_KEY_VALIDATION_FAILED', 'SSH_CERT_REQUESTED',
'SSH_CERT_ISSUED', 'SSH_CERT_FAILED', 'SSH_CERT_REVOKED', 'SSH_CERT_EXPIRED',
'CA_CREATED', 'CA_UPDATED', 'CA_DELETED', 'CA_KEY_ROTATED',
'PRINCIPAL_CREATED', 'PRINCIPAL_UPDATED', 'PRINCIPAL_DELETED',
'PRINCIPAL_MEMBER_ADDED', 'PRINCIPAL_MEMBER_REMOVED',
'DEPARTMENT_CREATED', 'DEPARTMENT_UPDATED', 'DEPARTMENT_DELETED',
'DEPARTMENT_MEMBER_ADDED', 'DEPARTMENT_MEMBER_REMOVED',
]
def upgrade():
for val in MISSING_VALUES:
op.execute(f"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM pg_enum
WHERE enumlabel = '{val}'
AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'auditaction')
) THEN
ALTER TYPE auditaction ADD VALUE '{val}';
END IF;
END$$;
""")
def downgrade():
# PostgreSQL does not support removing enum values; downgrade is a no-op.
pass
@@ -0,0 +1,33 @@
"""Make CA.organization_id nullable (system CA) and add cert_id to sign response
Revision ID: 012_ca_nullable_org_and_cert_serial
Revises: 011_org_invite_tokens
Create Date: 2025-01-01 00:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = '012_ca_nullable_org'
down_revision = '011_org_invite_tokens'
branch_labels = None
depends_on = None
def upgrade():
# Allow CA records without an org (e.g. the global system-config CA)
with op.batch_alter_table('cas', schema=None) as batch_op:
batch_op.alter_column(
'organization_id',
existing_type=sa.String(36),
nullable=True,
)
def downgrade():
with op.batch_alter_table('cas', schema=None) as batch_op:
batch_op.alter_column(
'organization_id',
existing_type=sa.String(36),
nullable=False,
)
+3
View File
@@ -48,3 +48,6 @@ Flask-Limiter==3.5.0
# Logging
python-json-logger==2.0.7
qrcode[pil]
# SSH CA Certificate signing
sshkey-tools==0.11.0
+22
View File
@@ -20,3 +20,25 @@ watchdog==3.0.0
# Documentation
sphinx==7.2.6
# Web framework & Database
Flask==3.0.0
Flask-SQLAlchemy==3.1.1
Flask-Migrate==4.0.5
sqlalchemy-cockroachdb==2.0.3
# Utilities
colorlog==6.8.0
coloredlogs==15.0.1
prettytable==3.10.2
tabulate==0.9.0
requests==2.31.0
pytz==2023.3
python-dotenv==1.0.0
pydantic==2.5.0
PyJWT==2.8.0
cryptography==41.0.7
pycryptodome==3.20.0
psycopg2==2.9.9
sshkey-tools==0.10.3
sendgrid==6.11.0