Feat: Added CA-merged with Securid-Principals, Depart, Client-CLI
This commit is contained in:
+115
-39
@@ -2,6 +2,7 @@
|
||||
import base64
|
||||
from datetime import datetime
|
||||
import os
|
||||
import sys
|
||||
import webbrowser
|
||||
import requests
|
||||
import argparse
|
||||
@@ -22,13 +23,12 @@ import base64
|
||||
load_dotenv()
|
||||
|
||||
# Get the API_URL from the environment variables
|
||||
SIGN_URL = os.getenv("SIGN_URL", "http://localhost:1234")
|
||||
SIGN_URL = os.getenv("SIGN_URL", "http://localhost:5000")
|
||||
LISTENER_HOST_NAME = "127.0.0.1"
|
||||
LISTENER_SERVER_PORT = 8250
|
||||
CA_API_HOST = "127.0.0.1"
|
||||
CA_SERVER_PORT = 1234
|
||||
CACHE_FILE = 'token_cache.json' ###need to change it to secure location and permissions if used in production
|
||||
CERT_FILE_PATH = "/tmp/ssl-cert"
|
||||
CACHE_FILE = os.path.expanduser('~/.gatehouse/token_cache.json')
|
||||
os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True)
|
||||
CERT_FILE_PATH = "/tmp/ssh-cert"
|
||||
CHALLENGE_FILE_PATH = "/tmp/challenge.txt"
|
||||
CHALLENGE_SIG_FILE_PATH = "/tmp/challenge.txt.sig"
|
||||
|
||||
@@ -116,7 +116,7 @@ def decode_and_validate_token(token):
|
||||
return True # Token is valid
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Token validation failed: {e}")
|
||||
logger.debug(f"Token validation failed: {e}")
|
||||
return False
|
||||
|
||||
def request_token():
|
||||
@@ -129,19 +129,34 @@ def request_token():
|
||||
logger.debug("Token loaded from cache: %s", token)
|
||||
|
||||
# Validate the cached token, if it exists
|
||||
if token and decode_and_validate_token(token):
|
||||
logger.info("Cached token is valid. Using cached token.")
|
||||
return token
|
||||
|
||||
logger.info("No valid cached token found, proceeding to request a new token.")
|
||||
token = ""
|
||||
if token:
|
||||
try:
|
||||
if decode_and_validate_token(token):
|
||||
logger.info("Cached token is valid. Using cached token.")
|
||||
return token
|
||||
except Exception:
|
||||
pass
|
||||
# Try opaque token via /auth/me
|
||||
try:
|
||||
r = requests.get(
|
||||
f"{SIGN_URL}/api/v1/auth/me",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
timeout=5,
|
||||
)
|
||||
if r.status_code == 200:
|
||||
logger.info("Cached session token is valid. Using cached token.")
|
||||
return token
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("Cached token is expired or invalid, requesting a new token.")
|
||||
token = ""
|
||||
|
||||
# Prepare the redirect URL for the token request
|
||||
redirect_url = f"http://{LISTENER_HOST_NAME}:{LISTENER_SERVER_PORT}/?token="
|
||||
logger.info("Redirect URL: %s", redirect_url)
|
||||
|
||||
# Construct the token request URL
|
||||
token_url = f"{SIGN_URL}/token_please?redirect_url={redirect_url}"
|
||||
token_url = f"{SIGN_URL}/api/v1/token_please?redirect_url={redirect_url}"
|
||||
logger.info("Token request URL: %s", token_url)
|
||||
|
||||
# Start the web server to handle the token response
|
||||
@@ -168,10 +183,10 @@ def get_activated_ssh_key():
|
||||
'Authorization': f'Bearer {token}',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
response = requests.get(f"{SIGN_URL}/api/ssh-keys", headers=headers)
|
||||
response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys", headers=headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
keys = response.json().get('ssh_keys', [])
|
||||
keys = response.json().get('keys', [])
|
||||
verified_keys = [key for key in keys if key['verified']]
|
||||
|
||||
if not verified_keys:
|
||||
@@ -179,8 +194,19 @@ def get_activated_ssh_key():
|
||||
exit(1)
|
||||
|
||||
if len(verified_keys) > 1:
|
||||
logger.error("Multiple verified SSH keys found. Please specify CERT_ID.")
|
||||
exit(1)
|
||||
# If running interactively, let the user pick; otherwise use the most recently added key
|
||||
if sys.stdout.isatty():
|
||||
print("\nMultiple verified SSH keys found. Please choose one:")
|
||||
for i, k in enumerate(verified_keys):
|
||||
print(f" [{i+1}] {k['id'][:8]}... fingerprint={k.get('fingerprint','?')} name={k.get('key_comment','?')}")
|
||||
try:
|
||||
choice = int(input("Enter number: ").strip()) - 1
|
||||
if 0 <= choice < len(verified_keys):
|
||||
return verified_keys[choice]['id']
|
||||
except (ValueError, EOFError):
|
||||
pass
|
||||
logger.info("Multiple verified SSH keys found; using the most recently added one.")
|
||||
verified_keys.sort(key=lambda k: k.get('created_at', ''), reverse=True)
|
||||
|
||||
return verified_keys[0]['id']
|
||||
|
||||
@@ -193,26 +219,35 @@ def get_activated_ssh_key():
|
||||
exit(1)
|
||||
|
||||
|
||||
def request_certificate():
|
||||
def request_certificate(principals=None):
|
||||
CERT_ID = os.getenv("CERT_ID") or get_activated_ssh_key()
|
||||
|
||||
if not principals:
|
||||
env_principals = os.getenv("PRINCIPALS")
|
||||
if env_principals:
|
||||
principals = [p.strip() for p in env_principals.split(',')]
|
||||
else:
|
||||
principals = [os.getlogin()]
|
||||
|
||||
headers = {
|
||||
'content-type': 'application/json',
|
||||
"Authorization": "bearer " + token
|
||||
}
|
||||
|
||||
payload = {
|
||||
'cert_id': CERT_ID
|
||||
'cert_id': CERT_ID,
|
||||
'principals': principals,
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(f"{SIGN_URL}/sign_cert", json=payload, headers=headers)
|
||||
response = requests.post(f"{SIGN_URL}/api/v1/ssh/sign", json=payload, headers=headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
if response.status_code == 201:
|
||||
json_result = response.json()
|
||||
with open(CERT_FILE_PATH, 'w') as f:
|
||||
f.write(json_result['certificate'])
|
||||
logger.info(f"Certificate signed successfully, located at {CERT_FILE_PATH}")
|
||||
logger.info(f"Valid for principals: {', '.join(json_result.get('principals', principals))}")
|
||||
logger.info("You can login to your destination server with the following command")
|
||||
logger.info(f"\tssh user@server -o CertificateFile={CERT_FILE_PATH}")
|
||||
else:
|
||||
@@ -238,14 +273,14 @@ def generate_and_sign_challenge(ssh_key_file,key_id):
|
||||
|
||||
# Send the POST request
|
||||
response = requests.get(
|
||||
f"http://{CA_API_HOST}:{CA_SERVER_PORT}/api/ssh-key/{key_id}/validationData",
|
||||
f"{SIGN_URL}/api/v1/ssh/keys/{key_id}/verify",
|
||||
headers=headers
|
||||
)
|
||||
if response.status_code!=200:
|
||||
logger.error(f"Server returned unexpected code {response.status_code}")
|
||||
return False
|
||||
|
||||
challenge_text=response.json()['validationText']+"\n"
|
||||
challenge_text=response.json().get('challenge_text', response.json().get('validationText', ''))+"\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unable to fetch SSH Key validation data {e}")
|
||||
@@ -291,7 +326,7 @@ def submit_signature_validation(signature, key_id):
|
||||
|
||||
# Send the POST request
|
||||
response = requests.post(
|
||||
f"http://{CA_API_HOST}:{CA_SERVER_PORT}/api/ssh-key/{key_id}/validate",
|
||||
f"{SIGN_URL}/api/v1/ssh/keys/{key_id}/verify",
|
||||
headers=headers,
|
||||
json=payload
|
||||
)
|
||||
@@ -317,12 +352,12 @@ def remove_ssh_key(key_id=None):
|
||||
}
|
||||
|
||||
# List keys first
|
||||
response = requests.get(f"{SIGN_URL}/api/ssh-keys", headers=headers)
|
||||
response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys", headers=headers)
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to list SSH keys: {response.status_code} - {response.text}")
|
||||
exit(1)
|
||||
|
||||
keys = response.json().get('ssh_keys', [])
|
||||
keys = response.json().get('keys', [])
|
||||
if not keys:
|
||||
logger.info("No SSH keys found for your user.")
|
||||
return
|
||||
@@ -359,7 +394,7 @@ def remove_ssh_key(key_id=None):
|
||||
exit(1)
|
||||
|
||||
for k in keys_to_delete:
|
||||
del_response = requests.delete(f"{SIGN_URL}/api/ssh-key/{k['id']}", headers=headers)
|
||||
del_response = requests.delete(f"{SIGN_URL}/api/v1/ssh/keys/{k['id']}", headers=headers)
|
||||
if del_response.status_code == 200:
|
||||
logger.info(f"Key {k['id']} removed successfully.")
|
||||
else:
|
||||
@@ -381,20 +416,40 @@ def add_ssh_key(ssh_key_file):
|
||||
headers = {
|
||||
'Authorization': f'Bearer {token}',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
ssh_key = ssh_key_file.read().decode('utf-8')
|
||||
}
|
||||
|
||||
if hasattr(ssh_key_file, 'read'):
|
||||
# File object (e.g. argparse.FileType('rb'))
|
||||
key_bytes = ssh_key_file.read()
|
||||
key_path = ssh_key_file.name
|
||||
elif isinstance(ssh_key_file, bytes):
|
||||
key_bytes = ssh_key_file
|
||||
key_path = None
|
||||
else:
|
||||
# String path
|
||||
key_path = str(ssh_key_file)
|
||||
with open(key_path, 'rb') as f:
|
||||
key_bytes = f.read()
|
||||
|
||||
ssh_key = key_bytes.decode('utf-8').strip()
|
||||
|
||||
payload = {
|
||||
'description': 'Added via gatehouse CLI tool',
|
||||
'key': ssh_key
|
||||
}
|
||||
|
||||
response = requests.post(f"{SIGN_URL}/api/ssh-key/add", json=payload, headers=headers)
|
||||
response = requests.post(f"{SIGN_URL}/api/v1/ssh/keys", json=payload, headers=headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
ssh_key_id=response.json()['key_id']
|
||||
if response.status_code == 201:
|
||||
ssh_key_id=response.json()['id']
|
||||
logger.info(f"SSH key {ssh_key_id} added successfully")
|
||||
generate_and_sign_challenge(ssh_key_file.name,ssh_key_id)
|
||||
if key_path:
|
||||
# Strip .pub suffix to get the private key path for signing
|
||||
private_key_path = key_path[:-4] if key_path.endswith('.pub') else key_path
|
||||
generate_and_sign_challenge(private_key_path, ssh_key_id)
|
||||
else:
|
||||
logger.warning("No key file path available — skipping auto-verification. "
|
||||
"Run with -k <path> to enable automatic key verification.")
|
||||
else:
|
||||
logger.error(f"Failed to add SSH key: {response.status_code} - {response.text}")
|
||||
|
||||
@@ -431,13 +486,15 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-a", "--add-key", action='store_true', default=False, help="Add SSH key to the server")
|
||||
parser.add_argument("-c", "--check-cert", action='store_true', default=False, help="Check the certificate, if it's valid exit 0, if it's invalid exit 1")
|
||||
parser.add_argument("-r", "--request-cert", action='store_true', default=False, help="Request that gatehouse sign a new certificate for you based on an SSH public key on file in your profile")
|
||||
parser.add_argument("--principals", nargs='+', metavar='PRINCIPAL', help="Unix usernames for the certificate (default: current OS user)")
|
||||
parser.add_argument("--clear-cache", action='store_true', default=False, help="Remove the cached authentication token")
|
||||
parser.add_argument("--remove-key", nargs='?', const='', metavar='KEY_ID', help="Remove an SSH key from your profile. Omit KEY_ID to pick interactively.")
|
||||
parser.add_argument("--list-keys", action='store_true', default=False, help="List SSH keys in your profile")
|
||||
|
||||
args = parser.parse_args()
|
||||
# Ensure that one of --check-cert, --request-cert, or --add-key is provided
|
||||
if not (args.check_cert or args.request_cert or args.add_key or args.clear_cache or args.remove_key is not None):
|
||||
parser.error("At least one of --check-cert, --request-cert, --add-key, --validate-key, or --clear-cache must be provided.")
|
||||
if not (args.check_cert or args.request_cert or args.add_key or args.clear_cache
|
||||
or args.remove_key is not None or args.list_keys):
|
||||
parser.error("At least one of --check-cert, --request-cert, --add-key, --list-keys, --remove-key, or --clear-cache must be provided.")
|
||||
|
||||
|
||||
# Retrieve SSH key from environment variables if not provided via CLI
|
||||
@@ -456,6 +513,25 @@ if __name__ == "__main__":
|
||||
remove_ssh_key(args.remove_key if args.remove_key else None)
|
||||
exit(0)
|
||||
|
||||
if args.list_keys:
|
||||
request_token()
|
||||
response = requests.get(
|
||||
f"{SIGN_URL}/api/v1/ssh/keys",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
keys = data.get('keys', [])
|
||||
if not keys:
|
||||
print("No SSH keys found in your profile.")
|
||||
else:
|
||||
for k in keys:
|
||||
verified = "✓ verified" if k.get('verified') else "✗ unverified"
|
||||
print(f" {k['id']} {verified} {k.get('description', '')} (added {k['created_at'][:10]})")
|
||||
else:
|
||||
logger.error(f"Failed to list SSH keys: {response.status_code} - {response.text}")
|
||||
exit(0)
|
||||
|
||||
if args.add_key:
|
||||
request_token()
|
||||
|
||||
@@ -476,10 +552,10 @@ if __name__ == "__main__":
|
||||
if args.force:
|
||||
request_token()
|
||||
logger.info("Forcing renewal of certificate")
|
||||
request_certificate()
|
||||
request_certificate(principals=args.principals)
|
||||
|
||||
if checkCert() == 1:
|
||||
request_token()
|
||||
request_certificate()
|
||||
request_certificate(principals=args.principals)
|
||||
|
||||
exit(0)
|
||||
|
||||
+114
@@ -0,0 +1,114 @@
|
||||
|
||||
[default]
|
||||
# Certificate validity period (in hours)
|
||||
# Default: 1 hour
|
||||
cert_validity_hours=1
|
||||
|
||||
# Maximum certificate validity allowed (in hours)
|
||||
# Default: 24 hours
|
||||
# Prevents users from requesting certificates valid longer than this
|
||||
max_cert_validity_hours=24
|
||||
|
||||
# Certificate Request Limits
|
||||
# Maximum number of certificates per user
|
||||
max_certs_per_user=100
|
||||
|
||||
# Certificate revocation list (CRL) configuration
|
||||
crl_enabled=true
|
||||
# CRL endpoint URL - set to your domain where CRL is served
|
||||
crl_endpoint=https://ca.example.com/crl
|
||||
# CRL refresh interval (in hours)
|
||||
crl_refresh_hours=24
|
||||
|
||||
# CA Key Configuration
|
||||
# Default key type for new CAs (ed25519, rsa, ecdsa)
|
||||
default_key_type=ed25519
|
||||
|
||||
# RSA key size (if using RSA)
|
||||
rsa_key_bits=4096
|
||||
|
||||
# Private key encryption
|
||||
# Method: kms (AWS Key Management Service) or local (for development only)
|
||||
private_key_encryption=kms
|
||||
# AWS KMS Key ID (only used if private_key_encryption=kms)
|
||||
aws_kms_key_id=${SSH_CA_KMS_KEY_ID}
|
||||
|
||||
# SSH Certificate Extensions
|
||||
# Default extensions to add to certificates
|
||||
extensions_enabled=true
|
||||
extensions=permit-X11-forwarding,permit-agent-forwarding,permit-pty,permit-port-forwarding,permit-user-rc
|
||||
|
||||
# Critical Options
|
||||
# Critical options to add to certificates (rarely needed)
|
||||
critical_options_enabled=false
|
||||
|
||||
# Certificate Field Limits
|
||||
# Maximum number of principals per certificate (SSH limitation is 256)
|
||||
max_principals_per_cert=256
|
||||
|
||||
# Maximum length for key_id field
|
||||
max_key_id_length=255
|
||||
|
||||
# Logging Configuration
|
||||
# Log level for SSH CA operations (DEBUG, INFO, WARNING, ERROR)
|
||||
log_level=INFO
|
||||
|
||||
# Audit Configuration
|
||||
# Log all certificate signing operations
|
||||
audit_enabled=true
|
||||
|
||||
# Security Configuration
|
||||
# Require SSH key verification before issuing certificates
|
||||
require_key_verification=true
|
||||
|
||||
# Verification challenge max age (in hours)
|
||||
verification_challenge_max_age=24
|
||||
|
||||
# Rate limiting for certificate signing
|
||||
# Max certificates per minute per user
|
||||
rate_limit_certs_per_minute=5
|
||||
|
||||
# Request timeout (in seconds)
|
||||
request_timeout=30
|
||||
|
||||
# Cleanup Configuration
|
||||
# Automatically delete unverified SSH keys after this many days
|
||||
auto_delete_unverified_days=30
|
||||
|
||||
# Archive expired certificates after this many days
|
||||
archive_expired_days=365
|
||||
|
||||
# CLI OAuth Configuration (for secuird-cli.py compatibility)
|
||||
# OAuth token endpoint for CLI clients
|
||||
oauth_token_endpoint=/api/v1/oauth2/token
|
||||
# OAuth userinfo endpoint for CLI clients
|
||||
oauth_userinfo_endpoint=/api/v1/oauth2/userinfo
|
||||
|
||||
[development]
|
||||
# Override settings for development environment
|
||||
private_key_encryption=local
|
||||
ca_key_path=/home/james/cory/secuird/certs/ca-users
|
||||
log_level=DEBUG
|
||||
cert_validity_hours=24
|
||||
max_cert_validity_hours=720
|
||||
rate_limit_certs_per_minute=100
|
||||
require_key_verification=false
|
||||
|
||||
[production]
|
||||
# Override settings for production environment
|
||||
private_key_encryption=kms
|
||||
log_level=WARNING
|
||||
cert_validity_hours=1
|
||||
max_cert_validity_hours=24
|
||||
rate_limit_certs_per_minute=5
|
||||
require_key_verification=true
|
||||
|
||||
[testing]
|
||||
# Override settings for testing environment
|
||||
private_key_encryption=local
|
||||
log_level=DEBUG
|
||||
cert_validity_hours=1
|
||||
max_cert_validity_hours=24
|
||||
rate_limit_certs_per_minute=100
|
||||
require_key_verification=true
|
||||
audit_enabled=false
|
||||
@@ -2,6 +2,9 @@
|
||||
import os
|
||||
import logging
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(dotenv_path=os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), '.env'))
|
||||
|
||||
# Test debug logging - this should appear when running `flask run --debug`
|
||||
_root_logger = logging.getLogger(__name__)
|
||||
_root_logger.debug("[TEST] Debug logging is working!")
|
||||
@@ -239,3 +242,6 @@ def initialize_oidc_jwks(app):
|
||||
app.logger.info(f"[OIDC] Signing key initialized: kid={signing_key.kid}")
|
||||
except Exception as e:
|
||||
app.logger.error(f"[OIDC] Failed to initialize JWKS: {e}")
|
||||
|
||||
# Create default app instance for gunicorn/wsgi
|
||||
app = create_app()
|
||||
|
||||
@@ -5,4 +5,6 @@ from flask import Blueprint
|
||||
api_v1_bp = Blueprint("api_v1", __name__)
|
||||
|
||||
# Import route modules to register them
|
||||
from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth
|
||||
from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth, departments, principals, ssh
|
||||
|
||||
api_v1_bp.register_blueprint(ssh.ssh_bp)
|
||||
|
||||
@@ -46,6 +46,33 @@ def _pop_oidc_bridge(oauth_state: str) -> str | None:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _store_cli_redirect(oauth_state: str, redirect_url: str) -> None:
|
||||
"""Store CLI redirect_url keyed by OAuth state (for /token_please flow)."""
|
||||
try:
|
||||
import gatehouse_app.extensions as _ext
|
||||
rc = _ext.redis_client
|
||||
if rc is not None:
|
||||
rc.setex(f"oauth_cli_redirect:{oauth_state}", _OAUTH_BRIDGE_TTL, redirect_url)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _pop_cli_redirect(oauth_state: str) -> str | None:
|
||||
"""Retrieve and delete CLI redirect_url for the given OAuth state."""
|
||||
try:
|
||||
import gatehouse_app.extensions as _ext
|
||||
rc = _ext.redis_client
|
||||
if rc is not None:
|
||||
key = f"oauth_cli_redirect:{oauth_state}"
|
||||
val = rc.get(key)
|
||||
if val:
|
||||
rc.delete(key)
|
||||
return val.decode() if isinstance(val, bytes) else val
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -69,6 +96,71 @@ def get_provider_type(provider: str) -> AuthMethodType:
|
||||
return PROVIDER_TYPE_MAP[provider_lower]
|
||||
|
||||
|
||||
@api_v1_bp.route("/token_please", methods=["GET"])
|
||||
def token_please():
|
||||
"""
|
||||
CLI token acquisition endpoint.
|
||||
|
||||
Initiates an OAuth login flow and, on success, redirects the user's browser
|
||||
to the CLI's local callback server (redirect_url) with the session token
|
||||
appended, e.g.: http://127.0.0.1:8250/?token=<SESSION_TOKEN>
|
||||
|
||||
This endpoint is designed for CLI clients that:
|
||||
1. Start a local HTTP server on LISTENER_SERVER_PORT (e.g. 8250)
|
||||
2. Open a browser to /api/v1/token_please?redirect_url=http://127.0.0.1:8250/?token=
|
||||
3. Wait for the browser to POST the token back to their local server
|
||||
|
||||
Query parameters:
|
||||
redirect_url: Local callback URL where the token will be appended
|
||||
provider: OAuth provider to use (default: 'google')
|
||||
"""
|
||||
from urllib.parse import urlencode
|
||||
from flask import current_app, redirect as flask_redirect
|
||||
|
||||
redirect_url = request.args.get("redirect_url", "").strip()
|
||||
provider = request.args.get("provider", "google").lower()
|
||||
|
||||
if not redirect_url:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="redirect_url query parameter is required",
|
||||
status=400,
|
||||
error_type="MISSING_REDIRECT_URL",
|
||||
)
|
||||
|
||||
# Validate redirect_url is localhost/127.0.0.1 (security: prevent open redirect)
|
||||
from urllib.parse import urlparse as _urlparse
|
||||
parsed = _urlparse(redirect_url)
|
||||
if parsed.hostname not in ("localhost", "127.0.0.1"):
|
||||
return api_response(
|
||||
success=False,
|
||||
message="redirect_url must point to localhost",
|
||||
status=400,
|
||||
error_type="INVALID_REDIRECT_URL",
|
||||
)
|
||||
|
||||
try:
|
||||
provider_type = get_provider_type(provider)
|
||||
auth_url, state = OAuthFlowService.initiate_login_flow(
|
||||
provider_type=provider_type,
|
||||
organization_id=None,
|
||||
redirect_uri=None,
|
||||
)
|
||||
except (OAuthFlowError, ExternalAuthError) as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=getattr(e, "message", str(e)),
|
||||
status=getattr(e, "status_code", 400),
|
||||
error_type=getattr(e, "error_type", "OAUTH_ERROR"),
|
||||
)
|
||||
|
||||
# Store the CLI redirect URL so the callback can use it
|
||||
_store_cli_redirect(state, redirect_url)
|
||||
|
||||
logger.info(f"CLI token_please: provider={provider}, redirect_url={redirect_url}, redirecting to OAuth")
|
||||
return flask_redirect(auth_url, code=302)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Provider Configuration Endpoints (Admin)
|
||||
# =============================================================================
|
||||
@@ -575,8 +667,6 @@ def initiate_oauth_authorize(provider: str):
|
||||
"state": "state_token"
|
||||
}
|
||||
"""
|
||||
provider_type = get_provider_type(provider)
|
||||
|
||||
# Get query parameters - organization_id is now optional
|
||||
flow = request.args.get("flow", "login")
|
||||
redirect_uri = request.args.get("redirect_uri")
|
||||
@@ -592,7 +682,7 @@ def initiate_oauth_authorize(provider: str):
|
||||
)
|
||||
|
||||
try:
|
||||
# Initiate flow - organization_id is now optional
|
||||
provider_type = get_provider_type(provider)
|
||||
if flow == "login":
|
||||
auth_url, state = OAuthFlowService.initiate_login_flow(
|
||||
provider_type=provider_type,
|
||||
@@ -626,6 +716,13 @@ def initiate_oauth_authorize(provider: str):
|
||||
status=e.status_code,
|
||||
error_type=e.error_type,
|
||||
)
|
||||
except ExternalAuthError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=e.message,
|
||||
status=e.status_code,
|
||||
error_type=e.error_type,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/external/<provider>/callback", methods=["GET"])
|
||||
@@ -666,8 +763,19 @@ def handle_oauth_callback(provider: str):
|
||||
frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:8080")
|
||||
frontend_callback = f"{frontend_url}/oauth/callback"
|
||||
|
||||
# Check if this is a CLI /token_please flow — retrieve stored redirect_url
|
||||
cli_redirect_url = _pop_cli_redirect(state) if state else None
|
||||
|
||||
def redirect_error(message: str, error_type: str = "OAUTH_ERROR"):
|
||||
"""Redirect to frontend with error params."""
|
||||
"""Redirect to frontend (or CLI) with error params."""
|
||||
if cli_redirect_url:
|
||||
# CLI flow: return a plain error page instead of redirecting back
|
||||
from flask import make_response
|
||||
return make_response(
|
||||
f"<html><body><h2>Authentication Error</h2><p>{message}</p>"
|
||||
f"<p>You may close this window.</p></body></html>",
|
||||
400,
|
||||
)
|
||||
params = {"error": message, "error_type": error_type}
|
||||
if state:
|
||||
params["state"] = state
|
||||
@@ -706,8 +814,11 @@ def handle_oauth_callback(provider: str):
|
||||
# Recover oidc_session_id if this was triggered from an OIDC bridge flow
|
||||
oidc_session_id = _pop_oidc_bridge(state)
|
||||
|
||||
# Organization selection / creation flows are not supported in CLI mode
|
||||
# (fall through to token redirect with whatever session we have)
|
||||
|
||||
# Organization selection needed (user belongs to multiple orgs)
|
||||
if result.get("requires_org_selection"):
|
||||
if result.get("requires_org_selection") and not cli_redirect_url:
|
||||
import json
|
||||
orgs = json.dumps(result.get("available_organizations", []))
|
||||
params = {
|
||||
@@ -722,7 +833,7 @@ def handle_oauth_callback(provider: str):
|
||||
return flask_redirect(f"{frontend_callback}?{urlencode(params)}", code=302)
|
||||
|
||||
# Organization creation needed (new user via OAuth with no org)
|
||||
if result.get("requires_org_creation"):
|
||||
if result.get("requires_org_creation") and not cli_redirect_url:
|
||||
params = {
|
||||
"requires_org_creation": "1",
|
||||
"state": result["state"],
|
||||
@@ -751,6 +862,19 @@ def handle_oauth_callback(provider: str):
|
||||
user_info = result.get("user", {})
|
||||
if user_info.get("email"):
|
||||
params["email"] = user_info["email"]
|
||||
|
||||
# ── CLI /token_please flow: redirect to the CLI's local callback ─────
|
||||
if cli_redirect_url:
|
||||
# The CLI expects: http://127.0.0.1:8250/?token=<TOKEN>
|
||||
# cli_redirect_url already ends with "token=" so just append the value
|
||||
cli_final_url = cli_redirect_url + token
|
||||
logger.info(
|
||||
f"CLI token_please success: provider={provider}, user={user_info.get('email')}, "
|
||||
f"redirecting to CLI callback"
|
||||
)
|
||||
return flask_redirect(cli_final_url, code=302)
|
||||
|
||||
# ── Frontend flow ─────────────────────────────────────────────────────
|
||||
# Pass oidc_session_id through so the frontend can complete the OIDC flow
|
||||
if oidc_session_id:
|
||||
params["oidc_session_id"] = oidc_session_id
|
||||
|
||||
@@ -13,8 +13,6 @@ from gatehouse_app.schemas.organization_schema import (
|
||||
from gatehouse_app.services.organization_service import OrganizationService
|
||||
from gatehouse_app.services.user_service import UserService
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
|
||||
########jb- need to implement departs, principals
|
||||
@api_v1_bp.route("/organizations", methods=["POST"])
|
||||
@login_required
|
||||
@full_access_required
|
||||
@@ -378,3 +376,557 @@ def update_member_role(org_id, user_id):
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/audit-logs", methods=["GET"])
|
||||
@login_required
|
||||
@full_access_required
|
||||
def get_organization_audit_logs(org_id):
|
||||
"""
|
||||
Get audit logs for an organization.
|
||||
|
||||
Query params:
|
||||
page: Page number (default 1)
|
||||
per_page: Results per page (default 50, max 200)
|
||||
action: Filter by action type
|
||||
|
||||
Returns:
|
||||
200: List of audit log entries
|
||||
401: Not authenticated
|
||||
403: Not a member / insufficient permissions
|
||||
404: Organization not found
|
||||
"""
|
||||
from gatehouse_app.models.audit_log import AuditLog
|
||||
|
||||
# Ensure org exists and user is a member (full_access_required handles this)
|
||||
OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
page = int(request.args.get("page", 1))
|
||||
per_page = min(int(request.args.get("per_page", 50)), 200)
|
||||
action_filter = request.args.get("action")
|
||||
|
||||
query = AuditLog.query.filter_by(organization_id=org_id)
|
||||
if action_filter:
|
||||
query = query.filter_by(action=action_filter)
|
||||
|
||||
query = query.order_by(AuditLog.created_at.desc())
|
||||
total = query.count()
|
||||
logs = query.offset((page - 1) * per_page).limit(per_page).all()
|
||||
|
||||
def log_to_dict(log):
|
||||
return {
|
||||
"id": log.id,
|
||||
"action": log.action.value if log.action else None,
|
||||
"user_id": log.user_id,
|
||||
"user_email": log.user.email if log.user else None,
|
||||
"user": {"id": log.user.id, "email": log.user.email, "full_name": log.user.full_name} if log.user else None,
|
||||
"organization_id": log.organization_id,
|
||||
"resource_type": log.resource_type,
|
||||
"resource_id": log.resource_id,
|
||||
"ip_address": log.ip_address,
|
||||
"user_agent": log.user_agent,
|
||||
"request_id": log.request_id,
|
||||
"description": log.description,
|
||||
"success": log.success,
|
||||
"error_message": log.error_message,
|
||||
"metadata": log.extra_data,
|
||||
"created_at": log.created_at.isoformat() if log.created_at else None,
|
||||
"updated_at": log.updated_at.isoformat() if log.updated_at else None,
|
||||
}
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"audit_logs": [log_to_dict(log) for log in logs],
|
||||
"count": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": (total + per_page - 1) // per_page,
|
||||
},
|
||||
message="Audit logs retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Organization Invite Tokens
|
||||
# ============================================================================
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/invites", methods=["POST"])
|
||||
@login_required
|
||||
@require_admin
|
||||
def create_org_invite(org_id):
|
||||
"""Create an invite token for an organization.
|
||||
|
||||
Request body:
|
||||
email: Email address to invite
|
||||
role: Role to assign (default: member)
|
||||
|
||||
Returns:
|
||||
201: Invite created
|
||||
400: Validation error
|
||||
403: Not an admin
|
||||
404: Organization not found
|
||||
"""
|
||||
from gatehouse_app.models import OrgInviteToken, Organization
|
||||
from gatehouse_app.services.notification_service import NotificationService
|
||||
from flask import current_app
|
||||
|
||||
org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
|
||||
if not org:
|
||||
return api_response(success=False, message="Organization not found", status=404)
|
||||
|
||||
data = request.get_json() or {}
|
||||
email = (data.get("email") or "").strip().lower()
|
||||
role = (data.get("role") or "member").strip()
|
||||
|
||||
if not email:
|
||||
return api_response(success=False, message="Email is required", status=400, error_type="VALIDATION_ERROR")
|
||||
|
||||
invite = OrgInviteToken.generate(
|
||||
organization_id=org_id,
|
||||
email=email,
|
||||
role=role,
|
||||
invited_by_id=g.current_user.id,
|
||||
)
|
||||
|
||||
app_url = current_app.config.get("APP_URL", "http://localhost:8080")
|
||||
invite_link = f"{app_url}/invite?token={invite.token}"
|
||||
|
||||
NotificationService._send_email(
|
||||
to_address=email,
|
||||
subject=f"You're invited to join {org.name} on Gatehouse",
|
||||
body=(
|
||||
f"You've been invited to join {org.name} on Gatehouse.\n\n"
|
||||
f"Click the link below to accept the invitation (valid for 7 days):\n"
|
||||
f"{invite_link}\n\n"
|
||||
f"Gatehouse Security Team"
|
||||
),
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data={"invite": {"id": invite.id, "email": invite.email, "role": invite.role, "expires_at": invite.expires_at.isoformat() + "Z"}},
|
||||
message="Invite sent successfully",
|
||||
status=201,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/invites/<token>", methods=["GET"])
|
||||
def get_invite(token):
|
||||
"""Get invite details by token.
|
||||
|
||||
Returns:
|
||||
200: Invite details (org name, email)
|
||||
400: Invalid or expired token
|
||||
"""
|
||||
from gatehouse_app.models import OrgInviteToken
|
||||
|
||||
invite = OrgInviteToken.query.filter_by(token=token).first()
|
||||
if not invite or not invite.is_valid:
|
||||
return api_response(success=False, message="This invitation link is invalid or has expired.", status=400, error_type="INVALID_TOKEN")
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"email": invite.email,
|
||||
"organization": {"id": invite.organization_id, "name": invite.organization.name},
|
||||
"role": invite.role,
|
||||
},
|
||||
message="Invite found",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/invites/<token>/accept", methods=["POST"])
|
||||
def accept_invite(token):
|
||||
"""Accept an organization invite.
|
||||
|
||||
Creates the user account (if not already registered) and adds them
|
||||
to the organization.
|
||||
|
||||
Request body:
|
||||
full_name: User's display name
|
||||
password: Password for new account (if not already registered)
|
||||
password_confirm: Password confirmation
|
||||
|
||||
Returns:
|
||||
200: Invite accepted, returns user token
|
||||
400: Invalid/expired token or validation error
|
||||
409: Already a member
|
||||
"""
|
||||
from gatehouse_app.models import OrgInviteToken, User
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
from gatehouse_app.services.organization_service import OrganizationService
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
|
||||
invite = OrgInviteToken.query.filter_by(token=token).first()
|
||||
if not invite or not invite.is_valid:
|
||||
return api_response(success=False, message="This invitation link is invalid or has expired.", status=400, error_type="INVALID_TOKEN")
|
||||
|
||||
data = request.get_json() or {}
|
||||
full_name = data.get("full_name") or ""
|
||||
password = data.get("password") or ""
|
||||
password_confirm = data.get("password_confirm") or ""
|
||||
|
||||
user = User.query.filter_by(email=invite.email, deleted_at=None).first()
|
||||
|
||||
if not user:
|
||||
# Register a new user
|
||||
if not password:
|
||||
return api_response(success=False, message="Password is required for new accounts.", status=400, error_type="VALIDATION_ERROR")
|
||||
if password != password_confirm:
|
||||
return api_response(success=False, message="Passwords do not match.", status=400, error_type="VALIDATION_ERROR")
|
||||
if len(password) < 8:
|
||||
return api_response(success=False, message="Password must be at least 8 characters.", status=400, error_type="VALIDATION_ERROR")
|
||||
try:
|
||||
user = AuthService.register_user(email=invite.email, password=password, full_name=full_name or None)
|
||||
except Exception as exc:
|
||||
return api_response(success=False, message=str(exc), status=400, error_type="REGISTRATION_ERROR")
|
||||
|
||||
# Add to org
|
||||
role_value = invite.role
|
||||
try:
|
||||
org_role = OrganizationRole(role_value)
|
||||
except ValueError:
|
||||
org_role = OrganizationRole.MEMBER
|
||||
|
||||
try:
|
||||
OrganizationService.add_member(
|
||||
org=invite.organization,
|
||||
user_id=user.id,
|
||||
role=org_role,
|
||||
inviter_id=invite.invited_by_id,
|
||||
)
|
||||
except Exception:
|
||||
pass # Already a member is fine
|
||||
|
||||
invite.accept()
|
||||
|
||||
user_session = AuthService.create_session(user)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"user": user.to_dict(),
|
||||
"token": user_session.token,
|
||||
"expires_at": user_session.expires_at.isoformat() + "Z",
|
||||
},
|
||||
message="Invitation accepted. Welcome!",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Organization OIDC Clients
|
||||
# ============================================================================
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/clients", methods=["GET"])
|
||||
@login_required
|
||||
def list_org_clients(org_id):
|
||||
"""List OIDC clients for an organization.
|
||||
|
||||
Returns:
|
||||
200: List of OIDC clients
|
||||
403: Not a member
|
||||
404: Organization not found
|
||||
"""
|
||||
from gatehouse_app.models import OIDCClient, Organization
|
||||
|
||||
org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
|
||||
if not org:
|
||||
return api_response(success=False, message="Organization not found", status=404)
|
||||
|
||||
clients = OIDCClient.query.filter_by(organization_id=org_id, is_active=True).all()
|
||||
|
||||
def client_to_dict(c):
|
||||
return {
|
||||
"id": c.id,
|
||||
"name": c.name,
|
||||
"client_id": c.client_id,
|
||||
"redirect_uris": c.redirect_uris,
|
||||
"scopes": c.scopes,
|
||||
"grant_types": c.grant_types,
|
||||
"is_active": c.is_active,
|
||||
"created_at": c.created_at.isoformat() + "Z",
|
||||
}
|
||||
|
||||
return api_response(
|
||||
data={"clients": [client_to_dict(c) for c in clients], "count": len(clients)},
|
||||
message="Clients retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/clients", methods=["POST"])
|
||||
@login_required
|
||||
@require_admin
|
||||
def create_org_client(org_id):
|
||||
"""Create a new OIDC client for an organization.
|
||||
|
||||
Request body:
|
||||
name: Client name
|
||||
redirect_uris: List of allowed redirect URIs (newline or comma separated string)
|
||||
|
||||
Returns:
|
||||
201: Client created with client_id and client_secret
|
||||
403: Not an admin
|
||||
404: Organization not found
|
||||
"""
|
||||
import secrets as _secrets
|
||||
from gatehouse_app.extensions import bcrypt
|
||||
from gatehouse_app.models import OIDCClient, Organization
|
||||
|
||||
org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
|
||||
if not org:
|
||||
return api_response(success=False, message="Organization not found", status=404)
|
||||
|
||||
data = request.get_json() or {}
|
||||
name = (data.get("name") or "").strip()
|
||||
redirect_uris_raw = data.get("redirect_uris") or []
|
||||
|
||||
if not name:
|
||||
return api_response(success=False, message="Client name is required", status=400, error_type="VALIDATION_ERROR")
|
||||
|
||||
if isinstance(redirect_uris_raw, str):
|
||||
redirect_uris = [u.strip() for u in redirect_uris_raw.replace(",", "\n").splitlines() if u.strip()]
|
||||
else:
|
||||
redirect_uris = [u.strip() for u in redirect_uris_raw if isinstance(u, str) and u.strip()]
|
||||
|
||||
if not redirect_uris:
|
||||
return api_response(success=False, message="At least one redirect URI is required", status=400, error_type="VALIDATION_ERROR")
|
||||
|
||||
client_id = _secrets.token_hex(16)
|
||||
client_secret = _secrets.token_urlsafe(32)
|
||||
|
||||
client = OIDCClient(
|
||||
organization_id=org_id,
|
||||
name=name,
|
||||
client_id=client_id,
|
||||
client_secret_hash=bcrypt.generate_password_hash(client_secret).decode("utf-8"),
|
||||
redirect_uris=redirect_uris,
|
||||
grant_types=["authorization_code", "refresh_token"],
|
||||
response_types=["code"],
|
||||
scopes=["openid", "profile", "email"],
|
||||
is_active=True,
|
||||
is_confidential=True,
|
||||
)
|
||||
from gatehouse_app.extensions import db
|
||||
db.session.add(client)
|
||||
db.session.commit()
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"client": {
|
||||
"id": client.id,
|
||||
"name": client.name,
|
||||
"client_id": client.client_id,
|
||||
"client_secret": client_secret, # Only returned once
|
||||
"redirect_uris": client.redirect_uris,
|
||||
"scopes": client.scopes,
|
||||
"created_at": client.created_at.isoformat() + "Z",
|
||||
}
|
||||
},
|
||||
message="OIDC client created successfully",
|
||||
status=201,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/clients/<client_id>", methods=["DELETE"])
|
||||
@login_required
|
||||
@require_admin
|
||||
def delete_org_client(org_id, client_id):
|
||||
"""Deactivate an OIDC client.
|
||||
|
||||
Returns:
|
||||
200: Client deactivated
|
||||
403: Not an admin
|
||||
404: Client not found
|
||||
"""
|
||||
from gatehouse_app.models import OIDCClient
|
||||
from gatehouse_app.extensions import db
|
||||
|
||||
client = OIDCClient.query.filter_by(id=client_id, organization_id=org_id).first()
|
||||
if not client:
|
||||
return api_response(success=False, message="Client not found", status=404)
|
||||
|
||||
client.is_active = False
|
||||
db.session.commit()
|
||||
|
||||
return api_response(data={}, message="Client deactivated successfully")
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/members/<user_id>/send-mfa-reminder", methods=["POST"])
|
||||
@login_required
|
||||
@require_admin
|
||||
def send_mfa_reminder(org_id, user_id):
|
||||
"""Send an MFA reminder email to a specific member.
|
||||
|
||||
Returns:
|
||||
200: Reminder sent (or silently skipped if no deadline record)
|
||||
403: Not an admin
|
||||
404: Member not found
|
||||
"""
|
||||
from gatehouse_app.models import User, MfaPolicyCompliance, OrganizationSecurityPolicy
|
||||
from gatehouse_app.services.notification_service import NotificationService
|
||||
|
||||
user = User.query.filter_by(id=user_id, deleted_at=None).first()
|
||||
if not user:
|
||||
return api_response(success=False, message="User not found", status=404)
|
||||
|
||||
compliance = MfaPolicyCompliance.query.filter_by(
|
||||
user_id=user_id, organization_id=org_id
|
||||
).first()
|
||||
policy = OrganizationSecurityPolicy.query.filter_by(organization_id=org_id).first()
|
||||
|
||||
if compliance and policy and compliance.deadline_at:
|
||||
NotificationService.send_mfa_deadline_reminder(user, compliance, policy)
|
||||
else:
|
||||
# No compliance deadline — send a generic nudge
|
||||
NotificationService._send_email(
|
||||
to_address=user.email,
|
||||
subject="Reminder: Set up multi-factor authentication",
|
||||
body=(
|
||||
f"Hi {user.full_name or user.email},\n\n"
|
||||
"Your organization administrator has asked you to set up "
|
||||
"multi-factor authentication (MFA) on your Gatehouse account.\n\n"
|
||||
"Please log in and configure MFA as soon as possible.\n\n"
|
||||
"Gatehouse Security Team"
|
||||
),
|
||||
)
|
||||
|
||||
return api_response(data={}, message="Reminder sent successfully")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# System-wide Audit Log (admin view) + User self audit
|
||||
# =============================================================================
|
||||
|
||||
def _audit_log_to_dict(log):
|
||||
"""Serialize an AuditLog record to a dict."""
|
||||
return {
|
||||
"id": log.id,
|
||||
"action": log.action.value if log.action else None,
|
||||
"user_id": log.user_id,
|
||||
"user": (
|
||||
{"id": log.user.id, "email": log.user.email, "full_name": log.user.full_name}
|
||||
if log.user else None
|
||||
),
|
||||
"organization_id": log.organization_id,
|
||||
"resource_type": log.resource_type,
|
||||
"resource_id": log.resource_id,
|
||||
"ip_address": log.ip_address,
|
||||
"user_agent": log.user_agent,
|
||||
"request_id": log.request_id,
|
||||
"description": log.description,
|
||||
"success": log.success,
|
||||
"error_message": log.error_message,
|
||||
"metadata": log.extra_data,
|
||||
"created_at": log.created_at.isoformat() if log.created_at else None,
|
||||
"updated_at": log.updated_at.isoformat() if log.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
@api_v1_bp.route("/audit-logs", methods=["GET"])
|
||||
@login_required
|
||||
def get_system_audit_logs():
|
||||
"""
|
||||
Get all audit logs (system-wide). Any authenticated user can query
|
||||
their own logs; org owners/admins also see org-scoped logs; this
|
||||
endpoint returns ALL logs for users who own at least one org
|
||||
(acting as an admin view).
|
||||
|
||||
Query params:
|
||||
page – page number (default 1)
|
||||
per_page – results per page (default 50, max 200)
|
||||
action – filter by AuditAction value
|
||||
user_id – filter by user id
|
||||
resource_type – filter by resource type
|
||||
success – "true"/"false"
|
||||
q – free-text search on description
|
||||
"""
|
||||
from gatehouse_app.models.audit_log import AuditLog
|
||||
from gatehouse_app.models.organization_member import OrganizationMember
|
||||
|
||||
current_user = g.current_user
|
||||
page = max(1, int(request.args.get("page", 1)))
|
||||
per_page = min(int(request.args.get("per_page", 50)), 200)
|
||||
|
||||
# Check if the user is an owner of any org to grant admin-level access
|
||||
is_admin = OrganizationMember.query.filter_by(
|
||||
user_id=current_user.id, role="OWNER"
|
||||
).first() is not None
|
||||
|
||||
query = AuditLog.query
|
||||
|
||||
if not is_admin:
|
||||
# Non-admins can only see their own logs
|
||||
query = query.filter(AuditLog.user_id == current_user.id)
|
||||
|
||||
# Optional filters
|
||||
action_filter = request.args.get("action")
|
||||
if action_filter:
|
||||
query = query.filter(AuditLog.action == action_filter)
|
||||
|
||||
user_id_filter = request.args.get("user_id")
|
||||
if user_id_filter:
|
||||
query = query.filter(AuditLog.user_id == user_id_filter)
|
||||
|
||||
resource_type_filter = request.args.get("resource_type")
|
||||
if resource_type_filter:
|
||||
query = query.filter(AuditLog.resource_type == resource_type_filter)
|
||||
|
||||
success_filter = request.args.get("success")
|
||||
if success_filter is not None:
|
||||
query = query.filter(AuditLog.success == (success_filter.lower() == "true"))
|
||||
|
||||
q = request.args.get("q", "").strip()
|
||||
if q:
|
||||
query = query.filter(AuditLog.description.ilike(f"%{q}%"))
|
||||
|
||||
query = query.order_by(AuditLog.created_at.desc())
|
||||
total = query.count()
|
||||
logs = query.offset((page - 1) * per_page).limit(per_page).all()
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"audit_logs": [_audit_log_to_dict(log) for log in logs],
|
||||
"count": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": (total + per_page - 1) // per_page,
|
||||
"is_admin_view": is_admin,
|
||||
},
|
||||
message="Audit logs retrieved",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/audit-logs", methods=["GET"])
|
||||
@login_required
|
||||
def get_my_audit_logs():
|
||||
"""
|
||||
Get audit logs for the currently authenticated user only.
|
||||
|
||||
Query params:
|
||||
page – page number (default 1)
|
||||
per_page – results per page (default 50, max 200)
|
||||
action – filter by AuditAction value
|
||||
"""
|
||||
from gatehouse_app.models.audit_log import AuditLog
|
||||
|
||||
current_user = g.current_user
|
||||
page = max(1, int(request.args.get("page", 1)))
|
||||
per_page = min(int(request.args.get("per_page", 50)), 200)
|
||||
|
||||
query = AuditLog.query.filter(AuditLog.user_id == current_user.id)
|
||||
|
||||
action_filter = request.args.get("action")
|
||||
if action_filter:
|
||||
query = query.filter(AuditLog.action == action_filter)
|
||||
|
||||
query = query.order_by(AuditLog.created_at.desc())
|
||||
total = query.count()
|
||||
logs = query.offset((page - 1) * per_page).limit(per_page).all()
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"audit_logs": [_audit_log_to_dict(log) for log in logs],
|
||||
"count": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": (total + per_page - 1) // per_page,
|
||||
},
|
||||
message="Activity retrieved",
|
||||
)
|
||||
|
||||
@@ -0,0 +1,615 @@
|
||||
"""SSH Key and Certificate API routes."""
|
||||
from flask import Blueprint, request, jsonify, g
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from gatehouse_app.services.ssh_key_service import SSHKeyService
|
||||
from gatehouse_app.services.ssh_ca_signing_service import (
|
||||
SSHCASigningService,
|
||||
SSHCertificateSigningRequest,
|
||||
)
|
||||
from gatehouse_app.exceptions import (
|
||||
SSHKeyError,
|
||||
SSHKeyNotFoundError,
|
||||
SSHCertificateError,
|
||||
ValidationError,
|
||||
SSHKeyAlreadyExistsError,
|
||||
)
|
||||
from gatehouse_app.utils.constants import AuditAction
|
||||
from gatehouse_app.models import AuditLog
|
||||
from gatehouse_app.utils.decorators import login_required
|
||||
|
||||
ssh_bp = Blueprint('ssh', __name__, url_prefix='/ssh')
|
||||
ssh_key_service = SSHKeyService()
|
||||
ssh_ca_service = SSHCASigningService()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _get_org_ca_for_user(user):
|
||||
"""Return the active DB CA for the user's first org, or None."""
|
||||
try:
|
||||
from gatehouse_app.models.ca import CA
|
||||
org_ids = [m.organization_id for m in user.organization_memberships]
|
||||
if not org_ids:
|
||||
return None
|
||||
return CA.query.filter(
|
||||
CA.organization_id.in_(org_ids),
|
||||
CA.is_active == True, # noqa: E712
|
||||
).first()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _get_or_create_system_ca():
|
||||
"""
|
||||
Return a CA DB record representing the config-file CA.
|
||||
|
||||
This is used as the ``ca_id`` FK when persisting certificates that were
|
||||
signed by the globally-configured CA key (not an org-specific DB CA).
|
||||
The record is created on first use and has no ``organization_id``.
|
||||
"""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.ca import CA, KeyType
|
||||
from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config
|
||||
from gatehouse_app.utils.crypto import compute_ssh_fingerprint
|
||||
import os
|
||||
|
||||
try:
|
||||
existing = CA.query.filter_by(name="system-config-ca").first()
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
cfg = get_ssh_ca_config()
|
||||
key_path = cfg.get_str("ca_key_path", "").strip()
|
||||
pub_key_path = key_path + ".pub"
|
||||
|
||||
if not os.path.exists(pub_key_path):
|
||||
return None
|
||||
|
||||
with open(pub_key_path) as f:
|
||||
pub_key = f.read().strip()
|
||||
|
||||
# Load private key for the record (stored but not actually used for signing here)
|
||||
priv_key = ""
|
||||
if os.path.exists(key_path):
|
||||
with open(key_path) as f:
|
||||
priv_key = f.read()
|
||||
|
||||
fingerprint = compute_ssh_fingerprint(pub_key)
|
||||
|
||||
# Check by fingerprint in case it was created under a different name
|
||||
existing_by_fp = CA.query.filter_by(fingerprint=fingerprint).first()
|
||||
if existing_by_fp:
|
||||
return existing_by_fp
|
||||
|
||||
system_ca = CA(
|
||||
name="system-config-ca",
|
||||
description="Global CA loaded from etc/ssh_ca.conf (ca_key_path)",
|
||||
key_type=KeyType.ED25519,
|
||||
private_key=priv_key,
|
||||
public_key=pub_key,
|
||||
fingerprint=fingerprint,
|
||||
is_active=True,
|
||||
default_cert_validity_hours=24,
|
||||
max_cert_validity_hours=720,
|
||||
)
|
||||
# organization_id is nullable=False in schema — we need a dummy org or
|
||||
# need to allow NULL. Use None; the DB constraint will tell us quickly.
|
||||
# If the migration enforces NOT NULL we'll catch the error gracefully.
|
||||
db.session.add(system_ca)
|
||||
db.session.commit()
|
||||
return system_ca
|
||||
except Exception as exc:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(
|
||||
f"Could not upsert system-config-ca: {exc}"
|
||||
)
|
||||
try:
|
||||
db.session.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _persist_certificate(user_id, ssh_key_id, ca, signing_response, request_ip=None):
|
||||
"""Save a signed certificate to the ssh_certificates table.
|
||||
|
||||
Args:
|
||||
user_id: UUID of the user
|
||||
ssh_key_id: UUID of the SSH key that was signed
|
||||
ca: CA model instance (may be None — cert still returned but not persisted)
|
||||
signing_response: SSHCertificateSigningResponse
|
||||
request_ip: Client IP address
|
||||
|
||||
Returns:
|
||||
SSHCertificate instance or None if persistence failed
|
||||
"""
|
||||
if ca is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.ssh_certificate import SSHCertificate, CertificateStatus
|
||||
from gatehouse_app.models.ca import CertType
|
||||
|
||||
cert_record = SSHCertificate(
|
||||
ca_id=ca.id,
|
||||
user_id=user_id,
|
||||
ssh_key_id=ssh_key_id,
|
||||
certificate=signing_response.certificate,
|
||||
serial=signing_response.serial,
|
||||
key_id=str(ssh_key_id),
|
||||
cert_type=CertType.USER,
|
||||
principals=signing_response.principals,
|
||||
valid_after=signing_response.valid_after,
|
||||
valid_before=signing_response.valid_before,
|
||||
revoked=False,
|
||||
status=CertificateStatus.ISSUED,
|
||||
request_ip=request_ip,
|
||||
)
|
||||
db.session.add(cert_record)
|
||||
db.session.commit()
|
||||
return cert_record
|
||||
except Exception as exc:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(
|
||||
f"Failed to persist certificate to DB: {exc}"
|
||||
)
|
||||
try:
|
||||
from gatehouse_app.extensions import db as _db
|
||||
_db.session.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@ssh_bp.route('/keys', methods=['GET'])
|
||||
@login_required
|
||||
def list_ssh_keys():
|
||||
"""Get all SSH keys for current user."""
|
||||
user_id = g.current_user.id
|
||||
|
||||
keys = ssh_key_service.get_user_ssh_keys(user_id)
|
||||
return jsonify({
|
||||
'keys': [k.to_dict() for k in keys],
|
||||
'count': len(keys),
|
||||
}), 200
|
||||
|
||||
|
||||
@ssh_bp.route('/keys', methods=['POST'])
|
||||
@login_required
|
||||
def add_ssh_key():
|
||||
"""Add a new SSH public key for current user."""
|
||||
user_id = g.current_user.id
|
||||
|
||||
data = request.get_json()
|
||||
if not data:
|
||||
return jsonify({'error': 'No JSON data provided'}), 400
|
||||
|
||||
public_key = data.get('public_key') or data.get('key')
|
||||
description = data.get('description')
|
||||
|
||||
if not public_key:
|
||||
return jsonify({'error': 'public_key is required'}), 400
|
||||
|
||||
try:
|
||||
ssh_key = ssh_key_service.add_ssh_key(
|
||||
user_id=user_id,
|
||||
public_key=public_key,
|
||||
description=description,
|
||||
)
|
||||
|
||||
# Audit log
|
||||
AuditLog.log(
|
||||
action=AuditAction.SSH_KEY_ADDED,
|
||||
user_id=user_id,
|
||||
resource_type='SSHKey',
|
||||
resource_id=ssh_key.id,
|
||||
ip_address=request.remote_addr,
|
||||
)
|
||||
|
||||
return jsonify(ssh_key.to_dict()), 201
|
||||
|
||||
except SSHKeyAlreadyExistsError as e:
|
||||
return jsonify({'error': e.message, 'code': 'SSH_KEY_ALREADY_EXISTS'}), 409
|
||||
except IntegrityError:
|
||||
return jsonify({'error': 'SSH key already exists', 'code': 'SSH_KEY_ALREADY_EXISTS'}), 409
|
||||
except SSHKeyError as e:
|
||||
return jsonify({'error': str(e)}), 400
|
||||
except ValidationError as e:
|
||||
return jsonify({'error': str(e)}), 400
|
||||
|
||||
|
||||
@ssh_bp.route('/keys/<key_id>', methods=['GET'])
|
||||
@login_required
|
||||
def get_ssh_key(key_id):
|
||||
"""Get a specific SSH key."""
|
||||
user_id = g.current_user.id
|
||||
|
||||
try:
|
||||
ssh_key = ssh_key_service.get_ssh_key(key_id)
|
||||
|
||||
# Check ownership
|
||||
if ssh_key.user_id != user_id:
|
||||
return jsonify({'error': 'Forbidden'}), 403
|
||||
|
||||
return jsonify(ssh_key.to_dict()), 200
|
||||
|
||||
except SSHKeyNotFoundError:
|
||||
return jsonify({'error': 'SSH key not found'}), 404
|
||||
|
||||
|
||||
@ssh_bp.route('/keys/<key_id>', methods=['DELETE'])
|
||||
@login_required
|
||||
def delete_ssh_key(key_id):
|
||||
"""Delete an SSH key."""
|
||||
user_id = g.current_user.id
|
||||
|
||||
try:
|
||||
ssh_key = ssh_key_service.get_ssh_key(key_id)
|
||||
|
||||
# Check ownership
|
||||
if ssh_key.user_id != user_id:
|
||||
return jsonify({'error': 'Forbidden'}), 403
|
||||
|
||||
ssh_key_service.delete_ssh_key(key_id)
|
||||
|
||||
# Audit log
|
||||
AuditLog.log(
|
||||
action=AuditAction.SSH_KEY_DELETED,
|
||||
user_id=user_id,
|
||||
resource_type='SSHKey',
|
||||
resource_id=key_id,
|
||||
ip_address=request.remote_addr,
|
||||
)
|
||||
|
||||
return jsonify({'status': 'deleted'}), 200
|
||||
|
||||
except SSHKeyNotFoundError:
|
||||
return jsonify({'error': 'SSH key not found'}), 404
|
||||
|
||||
|
||||
@ssh_bp.route('/keys/<key_id>/verify', methods=['GET', 'POST'])
|
||||
@login_required
|
||||
def verify_ssh_key(key_id):
|
||||
"""Generate or verify SSH key ownership challenge."""
|
||||
user_id = g.current_user.id
|
||||
|
||||
try:
|
||||
ssh_key = ssh_key_service.get_ssh_key(key_id)
|
||||
|
||||
# Check ownership
|
||||
if ssh_key.user_id != user_id:
|
||||
return jsonify({'error': 'Forbidden'}), 403
|
||||
|
||||
# Handle GET request - return challenge
|
||||
if request.method == 'GET':
|
||||
challenge = ssh_key_service.generate_verification_challenge(key_id)
|
||||
return jsonify({
|
||||
'challenge_text': challenge,
|
||||
'validationText': challenge, # Backwards compatibility
|
||||
'key_id': key_id,
|
||||
}), 200
|
||||
|
||||
# Handle POST request - verify signature
|
||||
data = request.get_json() or {}
|
||||
action = data.get('action', 'verify_signature')
|
||||
|
||||
if action == 'verify_signature':
|
||||
# Verify signature
|
||||
signature = data.get('signature')
|
||||
if not signature:
|
||||
return jsonify({'error': 'signature is required'}), 400
|
||||
|
||||
try:
|
||||
verified = ssh_key_service.verify_ssh_key_ownership(key_id, signature)
|
||||
|
||||
# Audit log
|
||||
AuditLog.log(
|
||||
action=AuditAction.SSH_KEY_VERIFIED,
|
||||
user_id=user_id,
|
||||
resource_type='SSHKey',
|
||||
resource_id=key_id,
|
||||
ip_address=request.remote_addr,
|
||||
success=verified,
|
||||
)
|
||||
|
||||
return jsonify({'verified': verified}), 200
|
||||
|
||||
except Exception as e:
|
||||
AuditLog.log(
|
||||
action=AuditAction.SSH_KEY_VALIDATION_FAILED,
|
||||
user_id=user_id,
|
||||
resource_type='SSHKey',
|
||||
resource_id=key_id,
|
||||
ip_address=request.remote_addr,
|
||||
success=False,
|
||||
error_message=str(e),
|
||||
)
|
||||
return jsonify({'error': str(e)}), 400
|
||||
|
||||
else: # generate_challenge
|
||||
# Generate verification challenge
|
||||
challenge = ssh_key_service.generate_verification_challenge(key_id)
|
||||
return jsonify({
|
||||
'challenge_text': challenge,
|
||||
'challenge': challenge, # Both for compatibility
|
||||
}), 200
|
||||
|
||||
except SSHKeyNotFoundError:
|
||||
return jsonify({'error': 'SSH key not found'}), 404
|
||||
|
||||
|
||||
@ssh_bp.route('/keys/<key_id>/update-description', methods=['PATCH'])
|
||||
@login_required
|
||||
def update_ssh_key_description(key_id):
|
||||
"""Update SSH key description."""
|
||||
user_id = g.current_user.id
|
||||
|
||||
data = request.get_json()
|
||||
if not data or 'description' not in data:
|
||||
return jsonify({'error': 'description is required'}), 400
|
||||
|
||||
try:
|
||||
ssh_key = ssh_key_service.get_ssh_key(key_id)
|
||||
|
||||
# Check ownership
|
||||
if ssh_key.user_id != user_id:
|
||||
return jsonify({'error': 'Forbidden'}), 403
|
||||
|
||||
updated_key = ssh_key_service.update_ssh_key_description(
|
||||
key_id,
|
||||
data['description']
|
||||
)
|
||||
|
||||
return jsonify(updated_key.to_dict()), 200
|
||||
|
||||
except SSHKeyNotFoundError:
|
||||
return jsonify({'error': 'SSH key not found'}), 404
|
||||
|
||||
|
||||
@ssh_bp.route('/sign', methods=['POST'])
|
||||
@login_required
|
||||
def sign_certificate():
|
||||
"""Sign an SSH certificate for the current user."""
|
||||
user = g.current_user
|
||||
user_id = user.id
|
||||
|
||||
data = request.get_json()
|
||||
if not data:
|
||||
return jsonify({'error': 'No JSON data provided'}), 400
|
||||
|
||||
try:
|
||||
principals = data.get('principals', [])
|
||||
cert_type = data.get('cert_type', 'user')
|
||||
# Accept both 'key_id' and 'cert_id' (from CLI)
|
||||
key_id = data.get('key_id') or data.get('cert_id')
|
||||
expiry_hours = data.get('expiry_hours')
|
||||
|
||||
if not principals:
|
||||
return jsonify({'error': 'principals is required'}), 400
|
||||
|
||||
# If key_id not specified, use first verified key
|
||||
if not key_id:
|
||||
verified_keys = ssh_key_service.get_user_verified_ssh_keys(user_id)
|
||||
if not verified_keys:
|
||||
return jsonify({'error': 'No verified SSH keys found'}), 400
|
||||
key_id = verified_keys[0].id
|
||||
|
||||
# Get the SSH key
|
||||
ssh_key = ssh_key_service.get_ssh_key(key_id)
|
||||
if ssh_key.user_id != user_id:
|
||||
return jsonify({'error': 'Forbidden'}), 403
|
||||
|
||||
if not ssh_key.verified:
|
||||
return jsonify({'error': 'SSH key is not verified'}), 400
|
||||
|
||||
# Resolve which CA to use: org DB CA > config-file CA
|
||||
db_ca = _get_org_ca_for_user(user)
|
||||
ca_private_key = db_ca.private_key if db_ca else None # None → signing service uses config
|
||||
|
||||
# Create signing request
|
||||
signing_request = SSHCertificateSigningRequest(
|
||||
ssh_public_key=ssh_key.payload,
|
||||
principals=principals,
|
||||
cert_type=cert_type,
|
||||
key_id=key_id,
|
||||
expiry_hours=int(expiry_hours) if expiry_hours else None,
|
||||
)
|
||||
|
||||
# Validate request
|
||||
validation_errors = signing_request.validate()
|
||||
if validation_errors:
|
||||
return jsonify({'errors': validation_errors}), 400
|
||||
|
||||
# Sign the certificate (pass ca_private_key=None → service loads from config)
|
||||
response = ssh_ca_service.sign_certificate(signing_request, ca_private_key=ca_private_key)
|
||||
|
||||
# Persist certificate to DB
|
||||
# If user's org has no DB CA, use the system-config-ca record
|
||||
ca_for_db = db_ca or _get_or_create_system_ca()
|
||||
cert_record = _persist_certificate(
|
||||
user_id=user_id,
|
||||
ssh_key_id=key_id,
|
||||
ca=ca_for_db,
|
||||
signing_response=response,
|
||||
request_ip=request.remote_addr,
|
||||
)
|
||||
|
||||
# Audit log
|
||||
AuditLog.log(
|
||||
action=AuditAction.SSH_CERT_ISSUED,
|
||||
user_id=user_id,
|
||||
resource_type='SSHCertificate',
|
||||
resource_id=cert_record.id if cert_record else key_id,
|
||||
ip_address=request.remote_addr,
|
||||
description=f'Certificate issued for principals: {", ".join(principals)}',
|
||||
)
|
||||
|
||||
result = {
|
||||
'certificate': response.certificate,
|
||||
'serial': response.serial,
|
||||
'principals': response.principals,
|
||||
'valid_after': response.valid_after.isoformat() if response.valid_after else None,
|
||||
'valid_before': response.valid_before.isoformat() if response.valid_before else None,
|
||||
}
|
||||
if cert_record:
|
||||
result['cert_id'] = str(cert_record.id)
|
||||
|
||||
return jsonify(result), 201
|
||||
|
||||
except SSHKeyNotFoundError:
|
||||
return jsonify({'error': 'SSH key not found'}), 404
|
||||
except SSHCertificateError as e:
|
||||
AuditLog.log(
|
||||
action=AuditAction.SSH_CERT_FAILED,
|
||||
user_id=user_id,
|
||||
resource_type='SSHCertificate',
|
||||
ip_address=request.remote_addr,
|
||||
success=False,
|
||||
error_message=str(e),
|
||||
)
|
||||
return jsonify({'error': str(e)}), 400
|
||||
except Exception as e:
|
||||
AuditLog.log(
|
||||
action=AuditAction.SSH_CERT_FAILED,
|
||||
user_id=user_id,
|
||||
resource_type='SSHCertificate',
|
||||
ip_address=request.remote_addr,
|
||||
success=False,
|
||||
error_message=str(e),
|
||||
)
|
||||
return jsonify({'error': 'Certificate signing failed: ' + str(e)}), 500
|
||||
|
||||
|
||||
@ssh_bp.route('/certificates', methods=['GET'])
|
||||
@login_required
|
||||
def list_certificates():
|
||||
"""List all SSH certificates issued for the current user."""
|
||||
user_id = g.current_user.id
|
||||
|
||||
try:
|
||||
from gatehouse_app.models.ssh_certificate import SSHCertificate
|
||||
certs = (
|
||||
SSHCertificate.query
|
||||
.filter_by(user_id=user_id, deleted_at=None)
|
||||
.order_by(SSHCertificate.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
return jsonify({
|
||||
'certificates': [c.to_dict() for c in certs],
|
||||
'count': len(certs),
|
||||
}), 200
|
||||
except Exception as e:
|
||||
return jsonify({'error': str(e)}), 500
|
||||
|
||||
|
||||
@ssh_bp.route('/certificates/<cert_id>', methods=['GET'])
|
||||
@login_required
|
||||
def get_certificate(cert_id):
|
||||
"""Get a specific issued certificate (metadata only)."""
|
||||
user_id = g.current_user.id
|
||||
|
||||
try:
|
||||
from gatehouse_app.models.ssh_certificate import SSHCertificate
|
||||
cert = SSHCertificate.query.filter_by(id=cert_id, deleted_at=None).first()
|
||||
if not cert:
|
||||
return jsonify({'error': 'Certificate not found'}), 404
|
||||
if cert.user_id != user_id:
|
||||
return jsonify({'error': 'Forbidden'}), 403
|
||||
# Include full certificate text in single-fetch endpoint
|
||||
data = cert.to_dict()
|
||||
data['certificate'] = cert.certificate
|
||||
return jsonify(data), 200
|
||||
except Exception as e:
|
||||
return jsonify({'error': str(e)}), 500
|
||||
|
||||
|
||||
@ssh_bp.route('/certificates/<cert_id>/revoke', methods=['POST'])
|
||||
@login_required
|
||||
def revoke_certificate(cert_id):
|
||||
"""Revoke an issued certificate."""
|
||||
user_id = g.current_user.id
|
||||
|
||||
data = request.get_json() or {}
|
||||
reason = data.get('reason', 'User requested revocation')
|
||||
|
||||
try:
|
||||
from gatehouse_app.models.ssh_certificate import SSHCertificate
|
||||
cert = SSHCertificate.query.filter_by(id=cert_id, deleted_at=None).first()
|
||||
if not cert:
|
||||
return jsonify({'error': 'Certificate not found'}), 404
|
||||
if cert.user_id != user_id:
|
||||
return jsonify({'error': 'Forbidden'}), 403
|
||||
if cert.revoked:
|
||||
return jsonify({'error': 'Certificate is already revoked'}), 409
|
||||
|
||||
cert.revoke(reason=reason)
|
||||
|
||||
AuditLog.log(
|
||||
action=AuditAction.SSH_CERT_REVOKED,
|
||||
user_id=user_id,
|
||||
resource_type='SSHCertificate',
|
||||
resource_id=cert_id,
|
||||
ip_address=request.remote_addr,
|
||||
description=f'Revoked: {reason}',
|
||||
)
|
||||
|
||||
return jsonify({'status': 'revoked', 'cert_id': cert_id, 'reason': reason}), 200
|
||||
except Exception as e:
|
||||
return jsonify({'error': str(e)}), 500
|
||||
|
||||
|
||||
@ssh_bp.route('/ca/public-key', methods=['GET'])
|
||||
@login_required
|
||||
def get_ca_public_key():
|
||||
"""
|
||||
Return the CA public key for this user's organization.
|
||||
|
||||
Server admins should add this key to their host's ``TrustedUserCAKeys``
|
||||
directive so that certificates issued by gatehouse are trusted.
|
||||
|
||||
Query parameters:
|
||||
format: 'openssh' (default) or 'text' — affects Content-Type only
|
||||
|
||||
Returns:
|
||||
{ "public_key": "ssh-ed25519 AAAA...",
|
||||
"fingerprint": "SHA256:...",
|
||||
"ca_name": "..." }
|
||||
"""
|
||||
user = g.current_user
|
||||
|
||||
# Try org CA first
|
||||
db_ca = _get_org_ca_for_user(user)
|
||||
if db_ca:
|
||||
return jsonify({
|
||||
'public_key': db_ca.public_key,
|
||||
'fingerprint': db_ca.fingerprint,
|
||||
'ca_name': db_ca.name,
|
||||
'source': 'db',
|
||||
}), 200
|
||||
|
||||
# Fall back to config-file CA
|
||||
try:
|
||||
from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config
|
||||
import os
|
||||
cfg = get_ssh_ca_config()
|
||||
key_path = cfg.get_str('ca_key_path', '').strip() + '.pub'
|
||||
if os.path.exists(key_path):
|
||||
with open(key_path) as f:
|
||||
pub_key = f.read().strip()
|
||||
from gatehouse_app.utils.crypto import compute_ssh_fingerprint
|
||||
return jsonify({
|
||||
'public_key': pub_key,
|
||||
'fingerprint': compute_ssh_fingerprint(pub_key),
|
||||
'ca_name': 'system-config-ca',
|
||||
'source': 'config',
|
||||
}), 200
|
||||
except Exception as e:
|
||||
return jsonify({'error': f'Could not load CA public key: {e}'}), 500
|
||||
|
||||
return jsonify({'error': 'No CA configured for this organization'}), 404
|
||||
|
||||
|
||||
@@ -0,0 +1,271 @@
|
||||
"""SSH CA Configuration Manager.
|
||||
|
||||
Handles loading and managing SSH CA configuration from etc/ssh_ca.conf
|
||||
and environment variables.
|
||||
"""
|
||||
import os
|
||||
import configparser
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
class SSHCAConfig:
|
||||
"""Configuration manager for SSH CA settings.
|
||||
|
||||
Loads configuration from:
|
||||
1. etc/ssh_ca.conf file
|
||||
2. Environment variables (override config file)
|
||||
3. Application environment-specific defaults
|
||||
|
||||
Example:
|
||||
config = SSHCAConfig()
|
||||
cert_hours = config.get_int('cert_validity_hours')
|
||||
kms_key = config.get_str('aws_kms_key_id')
|
||||
"""
|
||||
|
||||
# Configuration file location (relative to project root)
|
||||
DEFAULT_CONFIG_FILE = "etc/ssh_ca.conf"
|
||||
|
||||
# Default values if config file is missing
|
||||
DEFAULTS = {
|
||||
'cert_validity_hours': '1',
|
||||
'max_cert_validity_hours': '24',
|
||||
'max_certs_per_user': '100',
|
||||
'crl_enabled': 'true',
|
||||
'crl_endpoint': 'https://ca.example.com/crl',
|
||||
'crl_refresh_hours': '24',
|
||||
'default_key_type': 'ed25519',
|
||||
'rsa_key_bits': '4096',
|
||||
'private_key_encryption': 'kms',
|
||||
'aws_kms_key_id': '',
|
||||
'extensions_enabled': 'true',
|
||||
'extensions': 'permit-X11-forwarding,permit-agent-forwarding,permit-pty,permit-port-forwarding,permit-user-rc',
|
||||
'critical_options_enabled': 'false',
|
||||
'max_principals_per_cert': '256',
|
||||
'max_key_id_length': '255',
|
||||
'log_level': 'INFO',
|
||||
'audit_enabled': 'true',
|
||||
'require_key_verification': 'true',
|
||||
'verification_challenge_max_age': '24',
|
||||
'rate_limit_certs_per_minute': '5',
|
||||
'request_timeout': '30',
|
||||
'auto_delete_unverified_days': '30',
|
||||
'archive_expired_days': '365',
|
||||
'oauth_token_endpoint': '/api/v1/oauth2/token',
|
||||
'oauth_userinfo_endpoint': '/api/v1/oauth2/userinfo',
|
||||
'ca_key_path': '',
|
||||
}
|
||||
|
||||
def __init__(self, config_file: Optional[str] = None, environment: Optional[str] = None):
|
||||
"""Initialize SSH CA configuration.
|
||||
|
||||
Args:
|
||||
config_file: Path to config file (default: etc/ssh_ca.conf)
|
||||
environment: Environment name (development, production, testing)
|
||||
Default: value of FLASK_ENV or 'development'
|
||||
"""
|
||||
self.config = configparser.ConfigParser()
|
||||
|
||||
# Determine environment
|
||||
if environment is None:
|
||||
environment = os.environ.get('FLASK_ENV', 'development')
|
||||
self.environment = environment
|
||||
|
||||
# Load config file
|
||||
if config_file is None:
|
||||
# Try to find config file relative to this module
|
||||
module_dir = Path(__file__).parent.parent.parent
|
||||
config_file = module_dir / self.DEFAULT_CONFIG_FILE
|
||||
|
||||
self.config_file = config_file
|
||||
self._load_config()
|
||||
|
||||
def _load_config(self):
|
||||
"""Load configuration from file and apply environment-specific overrides."""
|
||||
# Set defaults
|
||||
self.config['default'] = self.DEFAULTS.copy()
|
||||
|
||||
# Load config file if it exists
|
||||
if Path(self.config_file).exists():
|
||||
self.config.read(self.config_file)
|
||||
|
||||
# Apply environment-specific configuration
|
||||
if self.environment in self.config:
|
||||
for key, value in self.config[self.environment].items():
|
||||
self.config['default'][key] = value
|
||||
|
||||
def get_str(self, key: str, default: Optional[str] = None) -> str:
|
||||
"""Get a string configuration value.
|
||||
|
||||
First checks environment variables (SSH_CA_<KEY>), then config file.
|
||||
|
||||
Args:
|
||||
key: Configuration key
|
||||
default: Default value if not found
|
||||
|
||||
Returns:
|
||||
Configuration value as string
|
||||
"""
|
||||
env_key = f"SSH_CA_{key.upper()}"
|
||||
|
||||
# Check environment variable first
|
||||
if env_key in os.environ:
|
||||
return os.environ[env_key]
|
||||
|
||||
# Check config file
|
||||
if key in self.config['default']:
|
||||
value = self.config['default'][key]
|
||||
# Handle environment variable substitution
|
||||
return os.path.expandvars(value)
|
||||
|
||||
# Return default
|
||||
if default is not None:
|
||||
return default
|
||||
|
||||
return self.DEFAULTS.get(key, '')
|
||||
|
||||
def get_int(self, key: str, default: Optional[int] = None) -> int:
|
||||
"""Get an integer configuration value.
|
||||
|
||||
Args:
|
||||
key: Configuration key
|
||||
default: Default value if not found
|
||||
|
||||
Returns:
|
||||
Configuration value as integer
|
||||
|
||||
Raises:
|
||||
ValueError: If value cannot be converted to integer
|
||||
"""
|
||||
str_value = self.get_str(key)
|
||||
if not str_value:
|
||||
if default is not None:
|
||||
return default
|
||||
raise ValueError(f"No value found for {key}")
|
||||
|
||||
try:
|
||||
return int(str_value)
|
||||
except ValueError:
|
||||
if default is not None:
|
||||
return default
|
||||
raise ValueError(f"Configuration {key}={str_value} is not a valid integer")
|
||||
|
||||
def get_bool(self, key: str, default: Optional[bool] = None) -> bool:
|
||||
"""Get a boolean configuration value.
|
||||
|
||||
Args:
|
||||
key: Configuration key
|
||||
default: Default value if not found
|
||||
|
||||
Returns:
|
||||
Configuration value as boolean
|
||||
"""
|
||||
str_value = self.get_str(key)
|
||||
if not str_value:
|
||||
if default is not None:
|
||||
return default
|
||||
return False
|
||||
|
||||
return str_value.lower() in ('true', '1', 'yes', 'on')
|
||||
|
||||
def get_list(self, key: str, delimiter: str = ',', default: Optional[list] = None) -> list:
|
||||
"""Get a comma-separated list configuration value.
|
||||
|
||||
Args:
|
||||
key: Configuration key
|
||||
delimiter: Delimiter between items (default: comma)
|
||||
default: Default value if not found
|
||||
|
||||
Returns:
|
||||
Configuration value as list of strings
|
||||
"""
|
||||
str_value = self.get_str(key)
|
||||
if not str_value:
|
||||
if default is not None:
|
||||
return default
|
||||
return []
|
||||
|
||||
return [item.strip() for item in str_value.split(delimiter) if item.strip()]
|
||||
|
||||
def validate_config(self) -> list:
|
||||
"""Validate SSH CA configuration.
|
||||
|
||||
Returns:
|
||||
List of validation error messages (empty if valid)
|
||||
"""
|
||||
errors = []
|
||||
|
||||
# Check cert validity hours
|
||||
try:
|
||||
validity = self.get_int('cert_validity_hours')
|
||||
max_validity = self.get_int('max_cert_validity_hours')
|
||||
if validity > max_validity:
|
||||
errors.append(
|
||||
f"cert_validity_hours ({validity}) > max_cert_validity_hours ({max_validity})"
|
||||
)
|
||||
except ValueError as e:
|
||||
errors.append(f"Invalid cert validity hours: {e}")
|
||||
|
||||
# Check key type
|
||||
valid_key_types = ['ed25519', 'rsa', 'ecdsa']
|
||||
key_type = self.get_str('default_key_type', 'ed25519')
|
||||
if key_type not in valid_key_types:
|
||||
errors.append(f"Invalid key type: {key_type}. Must be one of {valid_key_types}")
|
||||
|
||||
# Check encryption method
|
||||
valid_methods = ['kms', 'local']
|
||||
encryption = self.get_str('private_key_encryption', 'kms')
|
||||
if encryption not in valid_methods:
|
||||
errors.append(f"Invalid private_key_encryption: {encryption}. Must be one of {valid_methods}")
|
||||
|
||||
# Warn if using local encryption in production
|
||||
if encryption == 'local' and self.environment == 'production':
|
||||
errors.append("WARNING: Using local key encryption in production! Use KMS instead.")
|
||||
|
||||
# Check KMS key ID if using KMS
|
||||
if encryption == 'kms':
|
||||
kms_key = self.get_str('aws_kms_key_id', '').strip()
|
||||
if not kms_key:
|
||||
errors.append("aws_kms_key_id not set but private_key_encryption=kms")
|
||||
|
||||
# Check principals limit
|
||||
max_principals = self.get_int('max_principals_per_cert')
|
||||
if max_principals > 256:
|
||||
errors.append(f"max_principals_per_cert ({max_principals}) exceeds SSH limit of 256")
|
||||
|
||||
return errors
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Export current configuration as dictionary.
|
||||
"""
|
||||
return dict(self.config['default'])
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of configuration."""
|
||||
return f"<SSHCAConfig environment={self.environment} file={self.config_file}>"
|
||||
|
||||
|
||||
# Global configuration instance
|
||||
_config_instance = None
|
||||
|
||||
|
||||
def get_ssh_ca_config() -> SSHCAConfig:
|
||||
"""Get the global SSH CA configuration instance.
|
||||
|
||||
This function uses a singleton pattern to ensure only one
|
||||
configuration instance is created and reused.
|
||||
|
||||
Returns:
|
||||
SSHCAConfig instance
|
||||
"""
|
||||
global _config_instance
|
||||
if _config_instance is None:
|
||||
_config_instance = SSHCAConfig()
|
||||
return _config_instance
|
||||
|
||||
|
||||
def reset_config_instance():
|
||||
"""Reset the global configuration instance.
|
||||
"""
|
||||
global _config_instance
|
||||
_config_instance = None
|
||||
@@ -19,6 +19,21 @@ from gatehouse_app.exceptions.validation_exceptions import (
|
||||
OrganizationNotFoundError,
|
||||
UserNotFoundError,
|
||||
)
|
||||
from gatehouse_app.exceptions.ssh_exceptions import (
|
||||
SSHCAError,
|
||||
SSHKeyError,
|
||||
SSHKeyNotFoundError,
|
||||
SSHKeyAlreadyExistsError,
|
||||
SSHKeyNotVerifiedError,
|
||||
SSHCertificateError,
|
||||
SSHCertificateNotFoundError,
|
||||
CAError,
|
||||
CANotFoundError,
|
||||
PrincipalError,
|
||||
PrincipalNotFoundError,
|
||||
DepartmentError,
|
||||
DepartmentNotFoundError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseAPIException",
|
||||
@@ -37,4 +52,18 @@ __all__ = [
|
||||
"EmailAlreadyExistsError",
|
||||
"OrganizationNotFoundError",
|
||||
"UserNotFoundError",
|
||||
"SSHCAError",
|
||||
"SSHKeyError",
|
||||
"SSHKeyNotFoundError",
|
||||
"SSHKeyAlreadyExistsError",
|
||||
"SSHKeyNotVerifiedError",
|
||||
"SSHCertificateError",
|
||||
"SSHCertificateNotFoundError",
|
||||
"CAError",
|
||||
"CANotFoundError",
|
||||
"PrincipalError",
|
||||
"PrincipalNotFoundError",
|
||||
"DepartmentError",
|
||||
"DepartmentNotFoundError",
|
||||
]
|
||||
|
||||
|
||||
@@ -16,9 +16,10 @@ class BaseAPIException(Exception):
|
||||
message: Custom error message
|
||||
error_details: Additional error details dictionary
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(self.message)
|
||||
if message:
|
||||
self.message = message
|
||||
super().__init__(message) # update args so str(e) works
|
||||
self.error_details = error_details or {}
|
||||
|
||||
def to_dict(self):
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
"""SSH-specific exceptions."""
|
||||
from gatehouse_app.exceptions.base import BaseAPIException
|
||||
|
||||
|
||||
class SSHCAError(BaseAPIException):
|
||||
"""Base exception for SSH CA operations."""
|
||||
|
||||
status_code = 500
|
||||
error_type = "SSH_CA_ERROR"
|
||||
|
||||
|
||||
class SSHKeyError(BaseAPIException):
|
||||
"""Exception for SSH key operations."""
|
||||
|
||||
status_code = 400
|
||||
error_type = "SSH_KEY_ERROR"
|
||||
|
||||
|
||||
class SSHKeyNotFoundError(BaseAPIException):
|
||||
"""SSH key not found."""
|
||||
|
||||
status_code = 404
|
||||
error_type = "SSH_KEY_NOT_FOUND"
|
||||
|
||||
|
||||
class SSHKeyAlreadyExistsError(BaseAPIException):
|
||||
"""SSH key already exists (duplicate fingerprint)."""
|
||||
|
||||
status_code = 409
|
||||
error_type = "SSH_KEY_ALREADY_EXISTS"
|
||||
|
||||
|
||||
class SSHKeyNotVerifiedError(BaseAPIException):
|
||||
"""SSH key has not been verified."""
|
||||
|
||||
status_code = 400
|
||||
error_type = "SSH_KEY_NOT_VERIFIED"
|
||||
|
||||
|
||||
class SSHCertificateError(BaseAPIException):
|
||||
"""Exception for SSH certificate operations."""
|
||||
|
||||
status_code = 400
|
||||
error_type = "SSH_CERT_ERROR"
|
||||
|
||||
|
||||
class SSHCertificateNotFoundError(BaseAPIException):
|
||||
"""SSH certificate not found."""
|
||||
|
||||
status_code = 404
|
||||
error_type = "SSH_CERT_NOT_FOUND"
|
||||
|
||||
|
||||
class CAError(BaseAPIException):
|
||||
"""Exception for Certificate Authority operations."""
|
||||
|
||||
status_code = 400
|
||||
error_type = "CA_ERROR"
|
||||
|
||||
|
||||
class CANotFoundError(BaseAPIException):
|
||||
"""Certificate Authority not found."""
|
||||
|
||||
status_code = 404
|
||||
error_type = "CA_NOT_FOUND"
|
||||
|
||||
|
||||
class PrincipalError(BaseAPIException):
|
||||
"""Exception for principal operations."""
|
||||
|
||||
status_code = 400
|
||||
error_type = "PRINCIPAL_ERROR"
|
||||
|
||||
|
||||
class PrincipalNotFoundError(BaseAPIException):
|
||||
"""Principal not found."""
|
||||
|
||||
status_code = 404
|
||||
error_type = "PRINCIPAL_NOT_FOUND"
|
||||
|
||||
|
||||
class DepartmentError(BaseAPIException):
|
||||
"""Exception for department operations."""
|
||||
|
||||
status_code = 400
|
||||
error_type = "DEPARTMENT_ERROR"
|
||||
|
||||
|
||||
class DepartmentNotFoundError(BaseAPIException):
|
||||
"""Department not found."""
|
||||
|
||||
status_code = 404
|
||||
error_type = "DEPARTMENT_NOT_FOUND"
|
||||
@@ -29,6 +29,13 @@ from gatehouse_app.models.principal import (
|
||||
Principal,
|
||||
PrincipalMembership,
|
||||
)
|
||||
from gatehouse_app.models.ssh_key import SSHKey
|
||||
from gatehouse_app.models.ca import CA, KeyType, CertType
|
||||
from gatehouse_app.models.ssh_certificate import SSHCertificate, CertificateStatus
|
||||
from gatehouse_app.models.certificate_audit_log import CertificateAuditLog
|
||||
from gatehouse_app.models.password_reset_token import PasswordResetToken
|
||||
from gatehouse_app.models.email_verification_token import EmailVerificationToken
|
||||
from gatehouse_app.models.org_invite_token import OrgInviteToken
|
||||
|
||||
__all__ = [
|
||||
"BaseModel",
|
||||
@@ -55,4 +62,14 @@ __all__ = [
|
||||
"DepartmentPrincipal",
|
||||
"Principal",
|
||||
"PrincipalMembership",
|
||||
"SSHKey",
|
||||
"CA",
|
||||
"KeyType",
|
||||
"CertType",
|
||||
"SSHCertificate",
|
||||
"CertificateStatus",
|
||||
"CertificateAuditLog",
|
||||
"PasswordResetToken",
|
||||
"EmailVerificationToken",
|
||||
"OrgInviteToken",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,155 @@
|
||||
"""Certificate Authority (CA) model."""
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class KeyType(str, Enum):
|
||||
"""SSH CA key types."""
|
||||
|
||||
ED25519 = "ed25519"
|
||||
RSA = "rsa"
|
||||
ECDSA = "ecdsa"
|
||||
|
||||
|
||||
class CertType(str, Enum):
|
||||
"""SSH certificate types."""
|
||||
|
||||
USER = "user"
|
||||
HOST = "host"
|
||||
|
||||
|
||||
class CA(BaseModel):
|
||||
"""Certificate Authority (CA) model for SSH certificate signing.
|
||||
|
||||
Each organization can have multiple CAs for different purposes
|
||||
(e.g., production vs. staging). Private keys are encrypted at rest
|
||||
and should be protected with KMS.
|
||||
"""
|
||||
|
||||
__tablename__ = "cas"
|
||||
|
||||
organization_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("organizations.id"),
|
||||
nullable=True, # NULL for the global system-config CA
|
||||
index=True,
|
||||
)
|
||||
|
||||
# CA name and description
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
description = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Key type (ED25519, RSA, ECDSA)
|
||||
key_type = db.Column(
|
||||
db.Enum(KeyType, values_callable=lambda x: [e.value for e in x]),
|
||||
default=KeyType.ED25519,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Private key (encrypted at rest by database/KMS)
|
||||
# Format: PEM-encoded private key
|
||||
private_key = db.Column(db.Text, nullable=False)
|
||||
|
||||
# Public key (PEM format)
|
||||
public_key = db.Column(db.Text, nullable=False)
|
||||
|
||||
# SHA256 fingerprint of the public key
|
||||
fingerprint = db.Column(db.String(255), nullable=False, unique=True)
|
||||
|
||||
# CRL (Certificate Revocation List) configuration
|
||||
crl_enabled = db.Column(db.Boolean, default=True, nullable=False)
|
||||
crl_endpoint = db.Column(db.String(512), nullable=True)
|
||||
|
||||
# Default certificate validity in hours
|
||||
# Can be overridden per certificate request
|
||||
default_cert_validity_hours = db.Column(
|
||||
db.Integer,
|
||||
default=1,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Maximum validity duration allowed
|
||||
max_cert_validity_hours = db.Column(
|
||||
db.Integer,
|
||||
default=24,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# CA status
|
||||
is_active = db.Column(db.Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
# Key rotation tracking
|
||||
rotated_at = db.Column(db.DateTime, nullable=True)
|
||||
rotation_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Relationships
|
||||
organization = db.relationship("Organization", back_populates="cas")
|
||||
certificates = db.relationship(
|
||||
"SSHCertificate",
|
||||
back_populates="ca",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint(
|
||||
"organization_id", "name", name="uix_org_ca_name"
|
||||
),
|
||||
db.Index("idx_ca_org_active", "organization_id", "is_active"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of CA."""
|
||||
return f"<CA {self.name} (org_id={self.organization_id}, type={self.key_type})>"
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert CA to dictionary."""
|
||||
exclude = exclude or []
|
||||
# Never expose private key in API responses
|
||||
exclude.extend(["private_key"])
|
||||
data = super().to_dict(exclude=exclude)
|
||||
|
||||
# Add computed fields
|
||||
data["total_certs"] = len([c for c in self.certificates if c.deleted_at is None])
|
||||
data["active_certs"] = len([
|
||||
c for c in self.certificates
|
||||
if c.deleted_at is None and not c.revoked
|
||||
])
|
||||
data["revoked_certs"] = len([
|
||||
c for c in self.certificates
|
||||
if c.deleted_at is None and c.revoked
|
||||
])
|
||||
|
||||
return data
|
||||
|
||||
def get_active_certificates(self):
|
||||
"""Get all active (non-revoked) certificates issued by this CA.
|
||||
|
||||
Returns:
|
||||
List of non-revoked SSHCertificate objects
|
||||
"""
|
||||
return [
|
||||
c for c in self.certificates
|
||||
if c.deleted_at is None and not c.revoked
|
||||
]
|
||||
|
||||
def rotate_key(self, new_private_key, new_public_key, new_fingerprint, reason=None):
|
||||
"""Rotate the CA's key pair.
|
||||
|
||||
This should only be done in carefully controlled circumstances.
|
||||
All existing certificates remain valid but no new certs can be
|
||||
signed with the old key.
|
||||
|
||||
Args:
|
||||
new_private_key: New PEM-encoded private key
|
||||
new_public_key: New PEM-encoded public key
|
||||
new_fingerprint: SHA256 fingerprint of new public key
|
||||
reason: Optional reason for rotation
|
||||
"""
|
||||
self.private_key = new_private_key
|
||||
self.public_key = new_public_key
|
||||
self.fingerprint = new_fingerprint
|
||||
self.rotated_at = datetime.utcnow()
|
||||
self.rotation_reason = reason
|
||||
self.save()
|
||||
@@ -0,0 +1,83 @@
|
||||
"""Certificate audit log model."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class CertificateAuditLog(BaseModel):
|
||||
"""Audit log for SSH certificate lifecycle events.
|
||||
|
||||
Tracks all operations on SSH certificates: signing, revocation,
|
||||
validation, etc. This is separate from the general AuditLog to
|
||||
provide detailed certificate operation tracking.
|
||||
"""
|
||||
|
||||
__tablename__ = "certificate_audit_logs"
|
||||
|
||||
# Reference to the certificate
|
||||
certificate_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("ssh_certificates.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# The user who performed the action (can be null for system actions)
|
||||
user_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("users.id"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Action type (e.g., "signed", "revoked", "validated", "requested")
|
||||
action = db.Column(db.String(50), nullable=False, index=True)
|
||||
|
||||
# Request details
|
||||
ip_address = db.Column(db.String(45), nullable=True)
|
||||
user_agent = db.Column(db.String(512), nullable=True)
|
||||
request_id = db.Column(db.String(36), nullable=True)
|
||||
|
||||
# Detailed message
|
||||
message = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Additional context
|
||||
extra_data = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# Success/failure
|
||||
success = db.Column(db.Boolean, default=True, nullable=False)
|
||||
error_message = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
certificate = db.relationship("SSHCertificate", back_populates="audit_logs")
|
||||
user = db.relationship("User")
|
||||
|
||||
__table_args__ = (
|
||||
db.Index("idx_cert_audit_cert_action", "certificate_id", "action"),
|
||||
db.Index("idx_cert_audit_user", "user_id", "created_at"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of CertificateAuditLog."""
|
||||
return f"<CertificateAuditLog cert_id={self.certificate_id} action={self.action}>"
|
||||
|
||||
@classmethod
|
||||
def log(cls, certificate_id, action, user_id=None, **kwargs):
|
||||
"""Create a certificate audit log entry.
|
||||
|
||||
Args:
|
||||
certificate_id: ID of the certificate
|
||||
action: Action type (e.g., "signed", "revoked")
|
||||
user_id: ID of the user performing the action (optional)
|
||||
**kwargs: Additional fields (ip_address, user_agent, message, etc.)
|
||||
|
||||
Returns:
|
||||
CertificateAuditLog instance
|
||||
"""
|
||||
log_entry = cls(
|
||||
certificate_id=certificate_id,
|
||||
action=action,
|
||||
user_id=user_id,
|
||||
**kwargs
|
||||
)
|
||||
log_entry.save()
|
||||
return log_entry
|
||||
@@ -40,6 +40,9 @@ class Organization(BaseModel):
|
||||
principals = db.relationship(
|
||||
"Principal", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
cas = db.relationship(
|
||||
"CA", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of Organization."""
|
||||
|
||||
@@ -0,0 +1,175 @@
|
||||
"""SSH Certificate model."""
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.models.ca import CertType
|
||||
|
||||
|
||||
class CertificateStatus(str, Enum):
|
||||
"""SSH certificate lifecycle status."""
|
||||
|
||||
REQUESTED = "requested" # Waiting for signing
|
||||
ISSUED = "issued" # Signed and valid
|
||||
REVOKED = "revoked" # Manually revoked
|
||||
EXPIRED = "expired" # Validity period ended
|
||||
SUPERSEDED = "superseded" # Replaced by newer cert
|
||||
|
||||
|
||||
class SSHCertificate(BaseModel):
|
||||
"""SSH Certificate model representing a signed SSH user/host certificate.
|
||||
|
||||
Certificates are issued by a CA and associated with an SSH public key.
|
||||
They include principals (access levels), validity periods, and other
|
||||
OpenSSH certificate metadata.
|
||||
"""
|
||||
|
||||
__tablename__ = "ssh_certificates"
|
||||
|
||||
# Certificate relationships
|
||||
ca_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("cas.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
user_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("users.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
ssh_key_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("ssh_keys.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Certificate content (full signed certificate in OpenSSH format)
|
||||
certificate = db.Column(db.Text, nullable=False)
|
||||
|
||||
# Certificate metadata
|
||||
serial = db.Column(db.String(255), nullable=False, unique=True, index=True)
|
||||
key_id = db.Column(db.String(255), nullable=False) # Usually user email
|
||||
cert_type = db.Column(
|
||||
db.Enum(CertType, values_callable=lambda x: [e.value for e in x]),
|
||||
default=CertType.USER,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Principals (JSON list) - e.g., ["prod-servers", "dev-servers"]
|
||||
principals = db.Column(db.JSON, nullable=False, default=list)
|
||||
|
||||
# Validity period
|
||||
valid_after = db.Column(db.DateTime, nullable=False)
|
||||
valid_before = db.Column(db.DateTime, nullable=False)
|
||||
|
||||
# Revocation status
|
||||
revoked = db.Column(db.Boolean, default=False, nullable=False, index=True)
|
||||
revoked_at = db.Column(db.DateTime, nullable=True)
|
||||
revoke_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Status tracking
|
||||
status = db.Column(
|
||||
db.Enum(CertificateStatus, values_callable=lambda x: [e.value for e in x]),
|
||||
default=CertificateStatus.ISSUED,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Request metadata
|
||||
request_ip = db.Column(db.String(45), nullable=True)
|
||||
request_user_agent = db.Column(db.String(512), nullable=True)
|
||||
|
||||
# Critical options (JSON) - OpenSSH critical options
|
||||
# See: https://man.openbsd.org/ssh-cert
|
||||
critical_options = db.Column(db.JSON, nullable=True, default=dict)
|
||||
|
||||
# Extensions (JSON) - OpenSSH extensions
|
||||
# Common ones: permit-X11-forwarding, permit-agent-forwarding, permit-pty, etc.
|
||||
extensions = db.Column(db.JSON, nullable=True, default=dict)
|
||||
|
||||
# Relationships
|
||||
ca = db.relationship("CA", back_populates="certificates")
|
||||
user = db.relationship("User", back_populates="ssh_certificates")
|
||||
ssh_key = db.relationship(
|
||||
"SSHKey",
|
||||
back_populates="certificates",
|
||||
foreign_keys="SSHCertificate.ssh_key_id",
|
||||
)
|
||||
audit_logs = db.relationship(
|
||||
"CertificateAuditLog",
|
||||
back_populates="certificate",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
db.Index("idx_cert_user_status", "user_id", "status"),
|
||||
db.Index("idx_cert_validity", "valid_after", "valid_before"),
|
||||
db.Index("idx_cert_revoked", "revoked", "revoked_at"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of SSHCertificate."""
|
||||
return f"<SSHCertificate serial={self.serial[:16]}... user_id={self.user_id}>"
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert certificate to dictionary."""
|
||||
exclude = exclude or []
|
||||
# Optionally exclude the certificate content (it's large)
|
||||
if "certificate" not in exclude:
|
||||
exclude.append("certificate")
|
||||
data = super().to_dict(exclude=exclude)
|
||||
|
||||
# Add computed fields
|
||||
data["is_valid"] = self.is_valid()
|
||||
data["days_until_expiry"] = self.days_until_expiry()
|
||||
|
||||
return data
|
||||
|
||||
def is_valid(self):
|
||||
"""Check if certificate is currently valid.
|
||||
|
||||
Returns:
|
||||
True if certificate is issued, not revoked, and within validity period
|
||||
"""
|
||||
if self.revoked or self.status == CertificateStatus.REVOKED:
|
||||
return False
|
||||
|
||||
now = datetime.utcnow()
|
||||
return self.valid_after <= now <= self.valid_before
|
||||
|
||||
def is_expired(self):
|
||||
"""Check if certificate has expired.
|
||||
|
||||
Returns:
|
||||
True if current time is past valid_before
|
||||
"""
|
||||
return datetime.utcnow() > self.valid_before
|
||||
|
||||
def days_until_expiry(self):
|
||||
"""Get number of days until certificate expires.
|
||||
|
||||
Returns:
|
||||
Number of days remaining (negative if already expired)
|
||||
"""
|
||||
delta = self.valid_before - datetime.utcnow()
|
||||
return delta.days + (1 if delta.seconds > 0 else 0)
|
||||
|
||||
def revoke(self, reason=None):
|
||||
"""Revoke this certificate.
|
||||
|
||||
Args:
|
||||
reason: Optional reason for revocation
|
||||
"""
|
||||
self.revoked = True
|
||||
self.revoked_at = datetime.utcnow()
|
||||
self.revoke_reason = reason
|
||||
self.status = CertificateStatus.REVOKED
|
||||
self.save()
|
||||
|
||||
def mark_expired(self):
|
||||
"""Mark certificate as expired when validity period ends."""
|
||||
self.status = CertificateStatus.EXPIRED
|
||||
self.save()
|
||||
@@ -0,0 +1,96 @@
|
||||
"""SSH Key model."""
|
||||
from datetime import datetime
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class SSHKey(BaseModel):
|
||||
"""SSH Key model representing a user's SSH public key.
|
||||
|
||||
This model stores SSH public keys that users register for certificate signing.
|
||||
Users must verify ownership of the key before it can be used for signing certificates.
|
||||
"""
|
||||
|
||||
__tablename__ = "ssh_keys"
|
||||
|
||||
user_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("users.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# SSH key payload in OpenSSH format (e.g., "ssh-rsa AAAAB3Nz...")
|
||||
payload = db.Column(db.Text, nullable=False, unique=True)
|
||||
|
||||
# SHA256 fingerprint for quick comparison
|
||||
fingerprint = db.Column(db.String(255), nullable=False, unique=True, index=True)
|
||||
|
||||
# Optional description for the key (e.g., "My laptop key")
|
||||
description = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Verification status
|
||||
verified = db.Column(db.Boolean, default=False, nullable=False, index=True)
|
||||
verified_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Verification challenge
|
||||
verify_text = db.Column(db.String(255), nullable=True)
|
||||
verify_text_created_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Key type extracted from the key (ssh-rsa, ssh-ed25519, etc.)
|
||||
key_type = db.Column(db.String(50), nullable=True)
|
||||
|
||||
# Key bits/length
|
||||
key_bits = db.Column(db.Integer, nullable=True)
|
||||
|
||||
# Comment from the key (usually email or key name)
|
||||
key_comment = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship("User", back_populates="ssh_keys")
|
||||
certificates = db.relationship(
|
||||
"SSHCertificate",
|
||||
back_populates="ssh_key",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="SSHCertificate.ssh_key_id",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
db.Index("idx_ssh_key_user_verified", "user_id", "verified"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of SSHKey."""
|
||||
return f"<SSHKey {self.fingerprint[:16]}... user_id={self.user_id}>"
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert SSH key to dictionary."""
|
||||
exclude = exclude or []
|
||||
exclude.extend(["payload", "verify_text"]) # Never expose these in API
|
||||
data = super().to_dict(exclude=exclude)
|
||||
|
||||
# Add computed fields
|
||||
data["cert_count"] = len([c for c in self.certificates if c.deleted_at is None])
|
||||
|
||||
return data
|
||||
|
||||
def mark_verified(self):
|
||||
"""Mark this SSH key as verified."""
|
||||
self.verified = True
|
||||
self.verified_at = datetime.utcnow()
|
||||
self.save()
|
||||
|
||||
def needs_verification_refresh(self, max_age_hours=24):
|
||||
"""Check if verification challenge needs to be refreshed.
|
||||
|
||||
Args:
|
||||
max_age_hours: Maximum age of verification challenge in hours
|
||||
|
||||
Returns:
|
||||
True if verification challenge is stale
|
||||
"""
|
||||
if not self.verify_text_created_at:
|
||||
return True
|
||||
|
||||
age = datetime.utcnow() - self.verify_text_created_at
|
||||
return age.total_seconds() > (max_age_hours * 3600)
|
||||
@@ -55,6 +55,18 @@ class User(BaseModel):
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="PrincipalMembership.user_id",
|
||||
)
|
||||
ssh_keys = db.relationship(
|
||||
"SSHKey",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="SSHKey.user_id",
|
||||
)
|
||||
ssh_certificates = db.relationship(
|
||||
"SSHCertificate",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="SSHCertificate.user_id",
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of User."""
|
||||
|
||||
@@ -12,7 +12,7 @@ from gatehouse_app.models import User, AuthenticationMethod
|
||||
from gatehouse_app.models.authentication_method import OAuthState
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.models.oidc_authorization_code import OIDCAuthCode
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
from gatehouse_app.utils.constants import AuthMethodType, AuditAction
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
from gatehouse_app.services.external_auth_service import (
|
||||
ExternalAuthService,
|
||||
@@ -139,7 +139,7 @@ class OAuthFlowService:
|
||||
except ExternalAuthError as e:
|
||||
# Log failed initiation
|
||||
AuditService.log_action(
|
||||
action="external_auth.login.initiated",
|
||||
action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED,
|
||||
organization_id=organization_id,
|
||||
metadata={
|
||||
"provider_type": provider_type_str,
|
||||
@@ -236,7 +236,7 @@ class OAuthFlowService:
|
||||
|
||||
except ExternalAuthError as e:
|
||||
AuditService.log_action(
|
||||
action="external_auth.register.initiated",
|
||||
action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED,
|
||||
organization_id=organization_id,
|
||||
metadata={
|
||||
"provider_type": provider_type_str,
|
||||
@@ -399,6 +399,27 @@ class OAuthFlowService:
|
||||
access_token=tokens["access_token"],
|
||||
)
|
||||
|
||||
if not user_info.get("provider_user_id"):
|
||||
raise OAuthFlowError(
|
||||
"Provider did not return a user identifier (sub claim). "
|
||||
"Cannot complete authentication.",
|
||||
"MISSING_PROVIDER_USER_ID",
|
||||
400,
|
||||
)
|
||||
|
||||
if not user_info.get("email"):
|
||||
raise OAuthFlowError(
|
||||
"Provider did not return an email address. "
|
||||
"Cannot complete authentication.",
|
||||
"MISSING_EMAIL",
|
||||
400,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Got user_info from provider: sub={user_info['provider_user_id']}, "
|
||||
f"email={user_info['email']}, email_verified={user_info.get('email_verified')}"
|
||||
)
|
||||
|
||||
# Look up user by provider_user_id
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
method_type=provider_type,
|
||||
|
||||
@@ -0,0 +1,333 @@
|
||||
"""SSH Certificate Authority signing service.
|
||||
|
||||
Handles SSH certificate signing operations, leveraging sshkey-tools library.
|
||||
This service is a Gatehouse-integrated version of the secuird/ssh_ca.py logic.
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from sshkey_tools.cert import SSHCertificate, CertificateFields
|
||||
from sshkey_tools.keys import PublicKey, PrivateKey
|
||||
|
||||
from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config
|
||||
from gatehouse_app.exceptions import SSHCAError, ValidationError
|
||||
from gatehouse_app.utils.crypto import compute_ssh_fingerprint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SSHCASigningError(Exception):
|
||||
"""SSH CA signing operation error."""
|
||||
pass
|
||||
|
||||
|
||||
class SSHCertificateSigningRequest:
|
||||
"""Represents an SSH certificate signing request."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ssh_public_key: str,
|
||||
principals: List[str],
|
||||
key_id: str,
|
||||
cert_type: str = "user",
|
||||
expiry_hours: Optional[int] = None,
|
||||
critical_options: Optional[Dict[str, str]] = None,
|
||||
extensions: Optional[List[str]] = None,
|
||||
):
|
||||
"""Initialize signing request.
|
||||
|
||||
Args:
|
||||
ssh_public_key: Public key in OpenSSH format (e.g., "ssh-ed25519 AAAA...")
|
||||
principals: List of principals (e.g., ["prod-servers", "staging"])
|
||||
key_id: Key identifier (usually user email)
|
||||
cert_type: Certificate type - "user" or "host" (default: user)
|
||||
expiry_hours: Certificate validity in hours
|
||||
critical_options: Critical options dict
|
||||
extensions: List of extensions (e.g., ["permit-pty", "permit-agent-forwarding"])
|
||||
"""
|
||||
self.ssh_public_key = ssh_public_key
|
||||
self.principals = principals or []
|
||||
self.key_id = key_id
|
||||
self.cert_type = cert_type
|
||||
self.expiry_hours = expiry_hours
|
||||
self.critical_options = critical_options or {}
|
||||
self.extensions = extensions or []
|
||||
|
||||
def validate(self) -> List[str]:
|
||||
"""Validate the signing request.
|
||||
|
||||
Returns:
|
||||
List of validation errors (empty if valid)
|
||||
"""
|
||||
errors = []
|
||||
config = get_ssh_ca_config()
|
||||
|
||||
# Validate cert type
|
||||
if self.cert_type not in ("user", "host"):
|
||||
errors.append(f"Invalid cert_type: {self.cert_type}. Must be 'user' or 'host'")
|
||||
|
||||
# Validate SSH public key
|
||||
if not self.ssh_public_key or len(self.ssh_public_key) < 16:
|
||||
errors.append("SSH public key is missing or invalid")
|
||||
else:
|
||||
try:
|
||||
PublicKey.from_string(self.ssh_public_key)
|
||||
except Exception as e:
|
||||
errors.append(f"SSH public key is not valid: {str(e)}")
|
||||
|
||||
# Validate principals
|
||||
if not self.principals or len(self.principals) == 0:
|
||||
errors.append("At least one principal is required")
|
||||
else:
|
||||
max_principals = config.get_int('max_principals_per_cert')
|
||||
if len(self.principals) > max_principals:
|
||||
errors.append(
|
||||
f"Too many principals ({len(self.principals)}). "
|
||||
f"Maximum is {max_principals}"
|
||||
)
|
||||
|
||||
# Validate key_id
|
||||
if not self.key_id or len(self.key_id) < 5:
|
||||
errors.append("key_id is missing or too short (minimum 5 characters)")
|
||||
else:
|
||||
max_id_len = config.get_int('max_key_id_length')
|
||||
if len(self.key_id) > max_id_len:
|
||||
errors.append(f"key_id exceeds maximum length of {max_id_len}")
|
||||
|
||||
# Validate expiry_hours
|
||||
if self.expiry_hours is not None:
|
||||
if not isinstance(self.expiry_hours, int) or self.expiry_hours <= 0:
|
||||
errors.append("expiry_hours must be a positive integer")
|
||||
else:
|
||||
max_validity = config.get_int('max_cert_validity_hours')
|
||||
if self.expiry_hours > max_validity:
|
||||
errors.append(
|
||||
f"Requested expiry ({self.expiry_hours}h) exceeds "
|
||||
f"maximum allowed ({max_validity}h)"
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
class SSHCertificateSigningResponse:
|
||||
"""Represents a signed SSH certificate response."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
certificate: str,
|
||||
serial: str,
|
||||
valid_after: datetime,
|
||||
valid_before: datetime,
|
||||
principals: Optional[List[str]] = None,
|
||||
):
|
||||
"""Initialize signing response.
|
||||
|
||||
Args:
|
||||
certificate: Full certificate in OpenSSH format
|
||||
serial: Certificate serial number
|
||||
valid_after: Validity start datetime
|
||||
valid_before: Validity end datetime
|
||||
principals: List of principals the cert was issued for
|
||||
"""
|
||||
self.certificate = certificate
|
||||
self.serial = serial
|
||||
self.valid_after = valid_after
|
||||
self.valid_before = valid_before
|
||||
self.principals = principals or []
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert response to dictionary."""
|
||||
return {
|
||||
'certificate': self.certificate,
|
||||
'serial': self.serial,
|
||||
'valid_after': self.valid_after.isoformat(),
|
||||
'valid_before': self.valid_before.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
class SSHCASigningService:
|
||||
"""Service for signing SSH certificates.
|
||||
|
||||
This service handles all SSH certificate signing operations.
|
||||
It uses configuration from ssh_ca_config to apply rules and limits.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the SSH CA signing service."""
|
||||
self.config = get_ssh_ca_config()
|
||||
self.logger = logger
|
||||
|
||||
def _load_ca_key_from_config(self) -> str:
|
||||
"""Load CA private key from config (local file or env var).
|
||||
|
||||
Returns:
|
||||
CA private key in PEM/OpenSSH format as string
|
||||
|
||||
Raises:
|
||||
SSHCASigningError: If key cannot be loaded
|
||||
"""
|
||||
# Check env var first
|
||||
key_content = os.environ.get('SSH_CA_PRIVATE_KEY')
|
||||
if key_content:
|
||||
return key_content
|
||||
|
||||
# Load from file path
|
||||
key_path = self.config.get_str('ca_key_path', '').strip()
|
||||
if not key_path:
|
||||
raise SSHCASigningError(
|
||||
"CA private key not configured. Set SSH_CA_PRIVATE_KEY env var "
|
||||
"or ca_key_path in etc/ssh_ca.conf"
|
||||
)
|
||||
|
||||
key_path = os.path.expandvars(os.path.expanduser(key_path))
|
||||
if not os.path.exists(key_path):
|
||||
raise SSHCASigningError(f"CA private key file not found: {key_path}")
|
||||
|
||||
with open(key_path, 'r') as f:
|
||||
return f.read()
|
||||
|
||||
def sign_certificate(
|
||||
self,
|
||||
signing_request: SSHCertificateSigningRequest,
|
||||
ca_private_key: Optional[str] = None,
|
||||
) -> SSHCertificateSigningResponse:
|
||||
"""Sign an SSH certificate.
|
||||
|
||||
Args:
|
||||
signing_request: SSHCertificateSigningRequest instance
|
||||
ca_private_key: CA private key in PEM format. If not provided,
|
||||
loaded from config (ca_key_path or SSH_CA_PRIVATE_KEY env var)
|
||||
|
||||
Returns:
|
||||
SSHCertificateSigningResponse with signed certificate
|
||||
|
||||
Raises:
|
||||
SSHCASigningError: If signing fails
|
||||
ValidationError: If request is invalid
|
||||
"""
|
||||
# Validate request
|
||||
errors = signing_request.validate()
|
||||
if errors:
|
||||
error_msg = "; ".join(errors)
|
||||
self.logger.error(f"Certificate signing validation failed: {error_msg}")
|
||||
raise ValidationError(f"Certificate signing validation failed: {error_msg}")
|
||||
|
||||
# Load CA key if not provided
|
||||
if ca_private_key is None:
|
||||
ca_private_key = self._load_ca_key_from_config()
|
||||
|
||||
try:
|
||||
# Parse CA private key
|
||||
try:
|
||||
ca_key = PrivateKey.from_string(ca_private_key)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to load CA private key: {str(e)}")
|
||||
raise SSHCASigningError(f"Invalid CA private key: {str(e)}")
|
||||
|
||||
# Parse user's public key
|
||||
try:
|
||||
user_pub_key = PublicKey.from_string(signing_request.ssh_public_key)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to parse user public key: {str(e)}")
|
||||
raise SSHCASigningError(f"Invalid user public key: {str(e)}")
|
||||
|
||||
# Create certificate
|
||||
certificate = SSHCertificate.create(
|
||||
subject_pubkey=user_pub_key,
|
||||
ca_privkey=ca_key,
|
||||
)
|
||||
|
||||
# Set validity period
|
||||
now = datetime.utcnow()
|
||||
expiry_hours = signing_request.expiry_hours or self.config.get_int('cert_validity_hours')
|
||||
valid_before = now + timedelta(hours=expiry_hours)
|
||||
|
||||
# Set certificate fields
|
||||
cert_type = 1 if signing_request.cert_type == "user" else 0
|
||||
|
||||
certificate.fields.cert_type = cert_type
|
||||
certificate.fields.key_id = signing_request.key_id
|
||||
certificate.fields.principals = signing_request.principals
|
||||
certificate.fields.valid_after = now
|
||||
certificate.fields.valid_before = valid_before
|
||||
|
||||
# Set extensions
|
||||
extensions = signing_request.extensions
|
||||
if not extensions and self.config.get_bool('extensions_enabled'):
|
||||
extensions = self.config.get_list('extensions')
|
||||
|
||||
certificate.fields.extensions = extensions or []
|
||||
certificate.fields.critical_options = signing_request.critical_options or {}
|
||||
|
||||
# Validate certificate before signing
|
||||
if not certificate.can_sign():
|
||||
raise SSHCASigningError("Certificate cannot be signed")
|
||||
|
||||
# Sign the certificate
|
||||
certificate.sign()
|
||||
|
||||
# Verify the certificate
|
||||
try:
|
||||
certificate.verify(ca_key.public_key, raise_on_error=True)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Certificate verification failed: {str(e)}")
|
||||
raise SSHCASigningError(f"Certificate verification failed: {str(e)}")
|
||||
|
||||
# Extract serial from certificate
|
||||
serial = str(certificate.fields.serial).split(":")[-1].strip() if hasattr(certificate.fields.serial, '__str__') else str(certificate.fields.serial)
|
||||
|
||||
# Build response
|
||||
cert_string = certificate.to_string()
|
||||
|
||||
self.logger.info(
|
||||
f"Successfully signed certificate: serial={serial}, "
|
||||
f"key_id={signing_request.key_id}, principals={signing_request.principals}"
|
||||
)
|
||||
|
||||
return SSHCertificateSigningResponse(
|
||||
certificate=cert_string,
|
||||
serial=serial,
|
||||
valid_after=now,
|
||||
valid_before=valid_before,
|
||||
principals=signing_request.principals,
|
||||
)
|
||||
|
||||
except (SSHCASigningError, ValidationError):
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.error(f"Unexpected error during certificate signing: {str(e)}", exc_info=True)
|
||||
raise SSHCASigningError(f"Error signing certificate: {str(e)}")
|
||||
|
||||
def verify_ca_key(self, ca_private_key: str) -> Dict[str, Any]:
|
||||
"""Verify a CA private key is valid and extract metadata.
|
||||
|
||||
Args:
|
||||
ca_private_key: CA private key in PEM format
|
||||
|
||||
Returns:
|
||||
Dictionary with key metadata (fingerprint, key_type, etc.)
|
||||
|
||||
Raises:
|
||||
SSHCASigningError: If key is invalid
|
||||
"""
|
||||
try:
|
||||
ca_key = PrivateKey.from_string(ca_private_key)
|
||||
pub_key = ca_key.public_key
|
||||
|
||||
# Compute fingerprint
|
||||
fingerprint = compute_ssh_fingerprint(pub_key.to_string())
|
||||
|
||||
# Get key type
|
||||
key_type = pub_key.keytype if hasattr(pub_key, 'keytype') else 'unknown'
|
||||
|
||||
return {
|
||||
'fingerprint': fingerprint,
|
||||
'key_type': key_type,
|
||||
'public_key': pub_key.to_string(),
|
||||
'valid': True,
|
||||
}
|
||||
except Exception as e:
|
||||
self.logger.error(f"CA key verification failed: {str(e)}")
|
||||
raise SSHCASigningError(f"Invalid CA key: {str(e)}")
|
||||
@@ -0,0 +1,373 @@
|
||||
"""SSH Key management service."""
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import subprocess
|
||||
import tempfile
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models import SSHKey, User
|
||||
from gatehouse_app.exceptions import (
|
||||
SSHKeyError,
|
||||
SSHKeyNotFoundError,
|
||||
SSHKeyAlreadyExistsError,
|
||||
SSHKeyNotVerifiedError,
|
||||
ValidationError,
|
||||
UserNotFoundError,
|
||||
)
|
||||
from gatehouse_app.utils.crypto import (
|
||||
compute_ssh_fingerprint,
|
||||
verify_ssh_key_format,
|
||||
extract_ssh_key_type,
|
||||
extract_ssh_key_comment,
|
||||
)
|
||||
from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SSHKeyService:
|
||||
"""Service for managing SSH keys."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize SSH key service."""
|
||||
self.config = get_ssh_ca_config()
|
||||
|
||||
def add_ssh_key(
|
||||
self,
|
||||
user_id: str,
|
||||
public_key: str,
|
||||
description: Optional[str] = None,
|
||||
) -> SSHKey:
|
||||
"""Add an SSH public key for a user.
|
||||
|
||||
Args:
|
||||
user_id: ID of the user
|
||||
public_key: SSH public key in OpenSSH format
|
||||
description: Optional description of the key
|
||||
|
||||
Returns:
|
||||
Created SSHKey instance
|
||||
|
||||
Raises:
|
||||
UserNotFoundError: If user doesn't exist
|
||||
SSHKeyError: If key format is invalid
|
||||
SSHKeyAlreadyExistsError: If key already exists
|
||||
"""
|
||||
# Verify user exists
|
||||
user = User.query.get(user_id)
|
||||
if not user:
|
||||
raise UserNotFoundError(f"User {user_id} not found")
|
||||
|
||||
# Validate key format
|
||||
if not verify_ssh_key_format(public_key):
|
||||
raise SSHKeyError("Invalid SSH public key format")
|
||||
|
||||
# Compute fingerprint
|
||||
try:
|
||||
fingerprint = compute_ssh_fingerprint(public_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to compute fingerprint: {str(e)}")
|
||||
raise SSHKeyError(f"Failed to compute key fingerprint: {str(e)}")
|
||||
|
||||
# Check for duplicate (including soft-deleted records — fingerprint is unique in DB)
|
||||
existing = SSHKey.query.filter_by(fingerprint=fingerprint).first()
|
||||
if existing:
|
||||
if existing.deleted_at is not None:
|
||||
# Restore the soft-deleted key: clear deleted_at and update fields
|
||||
existing.deleted_at = None
|
||||
existing.user_id = user_id
|
||||
existing.description = description or existing.description
|
||||
existing.verified = False
|
||||
existing.verified_at = None
|
||||
existing.verify_text = None
|
||||
existing.verify_text_created_at = None
|
||||
db.session.commit()
|
||||
logger.info(
|
||||
f"Restored soft-deleted SSH key for user {user_id}: "
|
||||
f"fingerprint={fingerprint}"
|
||||
)
|
||||
return existing
|
||||
raise SSHKeyAlreadyExistsError(
|
||||
f"SSH key with fingerprint {fingerprint} already exists"
|
||||
)
|
||||
|
||||
# Extract metadata
|
||||
key_type = extract_ssh_key_type(public_key)
|
||||
key_comment = extract_ssh_key_comment(public_key)
|
||||
|
||||
# Create SSH key record
|
||||
ssh_key = SSHKey(
|
||||
user_id=user_id,
|
||||
payload=public_key,
|
||||
fingerprint=fingerprint,
|
||||
description=description,
|
||||
key_type=key_type,
|
||||
key_comment=key_comment,
|
||||
verified=False,
|
||||
)
|
||||
|
||||
ssh_key.save()
|
||||
|
||||
logger.info(
|
||||
f"SSH key added for user {user_id}: "
|
||||
f"fingerprint={fingerprint}, type={key_type}"
|
||||
)
|
||||
|
||||
return ssh_key
|
||||
|
||||
def get_ssh_key(self, key_id: str) -> SSHKey:
|
||||
"""Get an SSH key by ID.
|
||||
|
||||
Args:
|
||||
key_id: SSH key ID
|
||||
|
||||
Returns:
|
||||
SSHKey instance
|
||||
|
||||
Raises:
|
||||
SSHKeyNotFoundError: If key not found
|
||||
"""
|
||||
key = SSHKey.query.filter_by(id=key_id, deleted_at=None).first()
|
||||
if not key:
|
||||
raise SSHKeyNotFoundError(f"SSH key {key_id} not found")
|
||||
return key
|
||||
|
||||
def get_user_ssh_keys(self, user_id: str) -> List[SSHKey]:
|
||||
"""Get all SSH keys for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
List of SSHKey instances
|
||||
"""
|
||||
return SSHKey.query.filter_by(user_id=user_id, deleted_at=None).all()
|
||||
|
||||
def get_user_verified_ssh_keys(self, user_id: str) -> List[SSHKey]:
|
||||
"""Get all verified SSH keys for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
List of verified SSHKey instances
|
||||
"""
|
||||
return SSHKey.query.filter_by(
|
||||
user_id=user_id,
|
||||
verified=True,
|
||||
deleted_at=None,
|
||||
).all()
|
||||
|
||||
def delete_ssh_key(self, key_id: str) -> None:
|
||||
"""Soft-delete an SSH key.
|
||||
|
||||
Args:
|
||||
key_id: SSH key ID
|
||||
|
||||
Raises:
|
||||
SSHKeyNotFoundError: If key not found
|
||||
"""
|
||||
key = self.get_ssh_key(key_id)
|
||||
key.delete()
|
||||
|
||||
logger.info(f"SSH key deleted: {key_id}")
|
||||
|
||||
def generate_verification_challenge(self, key_id: str) -> str:
|
||||
"""Generate a verification challenge for an SSH key.
|
||||
|
||||
The user must sign this challenge text with their private key
|
||||
to prove key ownership.
|
||||
|
||||
Args:
|
||||
key_id: SSH key ID
|
||||
|
||||
Returns:
|
||||
Verification challenge text
|
||||
|
||||
Raises:
|
||||
SSHKeyNotFoundError: If key not found
|
||||
"""
|
||||
key = self.get_ssh_key(key_id)
|
||||
|
||||
# Generate random challenge
|
||||
challenge = secrets.token_hex(32)
|
||||
challenge_text = f"Please sign this to verify SSH key ownership: {challenge}"
|
||||
|
||||
# Store challenge
|
||||
key.verify_text = challenge_text
|
||||
key.verify_text_created_at = datetime.utcnow()
|
||||
key.save()
|
||||
|
||||
logger.info(f"Generated verification challenge for SSH key {key_id}")
|
||||
|
||||
return challenge_text
|
||||
|
||||
def verify_ssh_key_ownership(
|
||||
self,
|
||||
key_id: str,
|
||||
signature: str,
|
||||
) -> bool:
|
||||
"""Verify SSH key ownership via signature.
|
||||
|
||||
The user must sign the verification challenge with their private key.
|
||||
We verify the signature using the public key.
|
||||
|
||||
Args:
|
||||
key_id: SSH key ID
|
||||
signature: Base64-encoded signature of the challenge
|
||||
|
||||
Returns:
|
||||
True if signature is valid
|
||||
|
||||
Raises:
|
||||
SSHKeyNotFoundError: If key not found
|
||||
SSHKeyNotVerifiedError: If challenge is stale or missing
|
||||
SSHKeyError: If verification fails
|
||||
"""
|
||||
key = self.get_ssh_key(key_id)
|
||||
|
||||
# Check if challenge exists and is not stale
|
||||
if not key.verify_text or not key.verify_text_created_at:
|
||||
raise SSHKeyNotVerifiedError("No verification challenge generated")
|
||||
|
||||
max_age = self.config.get_int('verification_challenge_max_age')
|
||||
age = datetime.utcnow() - key.verify_text_created_at
|
||||
if age.total_seconds() > (max_age * 3600):
|
||||
raise SSHKeyNotVerifiedError("Verification challenge has expired")
|
||||
|
||||
try:
|
||||
# Verify the SSH signature using ssh-keygen -Y verify.
|
||||
# The CLI signs the challenge with: ssh-keygen -Y sign -f <key> -n file <challenge>
|
||||
# We verify with: ssh-keygen -Y verify -f <allowed_signers> -I <identity> -n file -s <sig> < <message>
|
||||
#
|
||||
# allowed_signers format: "<identity> <keytype> <pubkey>"
|
||||
# We use the key fingerprint as the identity.
|
||||
|
||||
sig_bytes = base64.b64decode(signature)
|
||||
challenge_text = key.verify_text + "\n"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
allowed_signers_path = os.path.join(tmpdir, "allowed_signers")
|
||||
sig_path = os.path.join(tmpdir, "message.sig")
|
||||
message_path = os.path.join(tmpdir, "message.txt")
|
||||
|
||||
identity = key.fingerprint
|
||||
|
||||
# Write the allowed_signers file
|
||||
with open(allowed_signers_path, "w") as f:
|
||||
f.write(f"{identity} {key.payload}\n")
|
||||
|
||||
# Write the signature file
|
||||
with open(sig_path, "wb") as f:
|
||||
f.write(sig_bytes)
|
||||
|
||||
# Write the challenge message
|
||||
with open(message_path, "w") as f:
|
||||
f.write(challenge_text)
|
||||
|
||||
result = subprocess.run(
|
||||
[
|
||||
"ssh-keygen", "-Y", "verify",
|
||||
"-f", allowed_signers_path,
|
||||
"-I", identity,
|
||||
"-n", "file",
|
||||
"-s", sig_path,
|
||||
],
|
||||
stdin=open(message_path, "rb"),
|
||||
capture_output=True,
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
stderr = result.stderr.decode(errors="replace").strip()
|
||||
logger.warning(f"SSH signature verification failed for key {key_id}: {stderr}")
|
||||
raise SSHKeyError(f"Signature verification failed: {stderr}")
|
||||
|
||||
key.mark_verified()
|
||||
logger.info(f"SSH key verified: {key_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"SSH key verification failed: {str(e)}")
|
||||
raise SSHKeyError(f"Signature verification failed: {str(e)}")
|
||||
|
||||
def get_key_fingerprint(self, key_id: str) -> str:
|
||||
"""Get the fingerprint of an SSH key.
|
||||
|
||||
Args:
|
||||
key_id: SSH key ID
|
||||
|
||||
Returns:
|
||||
Fingerprint string
|
||||
|
||||
Raises:
|
||||
SSHKeyNotFoundError: If key not found
|
||||
"""
|
||||
key = self.get_ssh_key(key_id)
|
||||
return key.fingerprint
|
||||
|
||||
def update_ssh_key_description(self, key_id: str, description: str) -> SSHKey:
|
||||
"""Update the description of an SSH key.
|
||||
|
||||
Args:
|
||||
key_id: SSH key ID
|
||||
description: New description
|
||||
|
||||
Returns:
|
||||
Updated SSHKey instance
|
||||
|
||||
Raises:
|
||||
SSHKeyNotFoundError: If key not found
|
||||
"""
|
||||
key = self.get_ssh_key(key_id)
|
||||
key.description = description
|
||||
key.save()
|
||||
|
||||
return key
|
||||
|
||||
def cleanup_expired_challenges(self) -> int:
|
||||
"""Clean up expired verification challenges.
|
||||
|
||||
Returns:
|
||||
Number of challenges cleaned
|
||||
"""
|
||||
max_age = self.config.get_int('verification_challenge_max_age')
|
||||
threshold = datetime.utcnow() - timedelta(hours=max_age)
|
||||
|
||||
expired = SSHKey.query.filter(
|
||||
SSHKey.verify_text_created_at < threshold,
|
||||
SSHKey.verify_text_created_at.isnot(None),
|
||||
SSHKey.deleted_at.is_(None),
|
||||
).update({"verify_text": None, "verify_text_created_at": None})
|
||||
|
||||
db.session.commit()
|
||||
|
||||
logger.info(f"Cleaned up {expired} expired verification challenges")
|
||||
return expired
|
||||
|
||||
def cleanup_unverified_keys(self) -> int:
|
||||
"""Delete unverified SSH keys older than configured days.
|
||||
|
||||
Returns:
|
||||
Number of keys deleted
|
||||
"""
|
||||
days = self.config.get_int('auto_delete_unverified_days')
|
||||
threshold = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
old_unverified = SSHKey.query.filter(
|
||||
SSHKey.verified == False,
|
||||
SSHKey.created_at < threshold,
|
||||
SSHKey.deleted_at.is_(None),
|
||||
).all()
|
||||
|
||||
count = 0
|
||||
for key in old_unverified:
|
||||
key.delete()
|
||||
count += 1
|
||||
|
||||
logger.info(f"Deleted {count} unverified SSH keys older than {days} days")
|
||||
return count
|
||||
@@ -12,6 +12,16 @@ class UserStatus(str, Enum):
|
||||
COMPLIANCE_SUSPENDED = "compliance_suspended"
|
||||
|
||||
|
||||
class Role(str, Enum):
|
||||
"""Generic role definitions (hierarchy: Admin > Manager > Member > Viewer > Guest)."""
|
||||
|
||||
ADMIN = "admin"
|
||||
MANAGER = "manager"
|
||||
MEMBER = "member"
|
||||
VIEWER = "viewer"
|
||||
GUEST = "guest"
|
||||
|
||||
|
||||
class OrganizationRole(str, Enum):
|
||||
"""Organization member roles."""
|
||||
|
||||
@@ -105,6 +115,37 @@ class AuditAction(str, Enum):
|
||||
EXTERNAL_AUTH_CONFIG_UPDATE = "external_auth.config.update"
|
||||
EXTERNAL_AUTH_CONFIG_DELETE = "external_auth.config.delete"
|
||||
|
||||
# SSH Key and Certificate actions
|
||||
SSH_KEY_ADDED = "ssh.key.added"
|
||||
SSH_KEY_VERIFIED = "ssh.key.verified"
|
||||
SSH_KEY_DELETED = "ssh.key.deleted"
|
||||
SSH_KEY_VALIDATION_FAILED = "ssh.key.validation.failed"
|
||||
SSH_CERT_REQUESTED = "ssh.cert.requested"
|
||||
SSH_CERT_ISSUED = "ssh.cert.issued"
|
||||
SSH_CERT_FAILED = "ssh.cert.failed"
|
||||
SSH_CERT_REVOKED = "ssh.cert.revoked"
|
||||
SSH_CERT_EXPIRED = "ssh.cert.expired"
|
||||
|
||||
# CA actions
|
||||
CA_CREATED = "ca.created"
|
||||
CA_UPDATED = "ca.updated"
|
||||
CA_DELETED = "ca.deleted"
|
||||
CA_KEY_ROTATED = "ca.key.rotated"
|
||||
|
||||
# Principal actions
|
||||
PRINCIPAL_CREATED = "principal.created"
|
||||
PRINCIPAL_UPDATED = "principal.updated"
|
||||
PRINCIPAL_DELETED = "principal.deleted"
|
||||
PRINCIPAL_MEMBER_ADDED = "principal.member.added"
|
||||
PRINCIPAL_MEMBER_REMOVED = "principal.member.removed"
|
||||
|
||||
# Department actions
|
||||
DEPARTMENT_CREATED = "department.created"
|
||||
DEPARTMENT_UPDATED = "department.updated"
|
||||
DEPARTMENT_DELETED = "department.deleted"
|
||||
DEPARTMENT_MEMBER_ADDED = "department.member.added"
|
||||
DEPARTMENT_MEMBER_REMOVED = "department.member.removed"
|
||||
|
||||
|
||||
class OIDCGrantType(str, Enum):
|
||||
"""OIDC grant types."""
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
"""Cryptographic utilities for SSH operations."""
|
||||
import hashlib
|
||||
import base64
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def compute_ssh_fingerprint(public_key_str: str, hash_algorithm: str = "sha256") -> str:
|
||||
"""Compute the fingerprint of an SSH public key.
|
||||
|
||||
Args:
|
||||
public_key_str: SSH public key in OpenSSH format
|
||||
hash_algorithm: Hash algorithm to use (sha256, sha1, md5)
|
||||
|
||||
Returns:
|
||||
Fingerprint string in the format "algorithm:hex_digest"
|
||||
|
||||
Example:
|
||||
>>> key = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKp2..."
|
||||
>>> fp = compute_ssh_fingerprint(key)
|
||||
>>> print(fp)
|
||||
sha256:Kb+...
|
||||
"""
|
||||
if not public_key_str:
|
||||
raise ValueError("Public key string is empty")
|
||||
|
||||
# Parse OpenSSH format: "ssh-ed25519 <base64> [comment]"
|
||||
parts = public_key_str.strip().split()
|
||||
if len(parts) < 2:
|
||||
raise ValueError("Invalid OpenSSH public key format")
|
||||
|
||||
try:
|
||||
# The base64-encoded key is the second part
|
||||
key_bytes = base64.b64decode(parts[1])
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to decode public key: {str(e)}")
|
||||
|
||||
# Compute hash
|
||||
if hash_algorithm == "sha256":
|
||||
digest = hashlib.sha256(key_bytes).digest()
|
||||
# SSH format uses base64 encoding without padding
|
||||
fingerprint = base64.b64encode(digest).decode().rstrip('=')
|
||||
elif hash_algorithm == "sha1":
|
||||
digest = hashlib.sha1(key_bytes).hexdigest()
|
||||
fingerprint = digest
|
||||
elif hash_algorithm == "md5":
|
||||
digest = hashlib.md5(key_bytes).hexdigest()
|
||||
# Format as colons
|
||||
fingerprint = ':'.join(digest[i:i+2] for i in range(0, len(digest), 2))
|
||||
else:
|
||||
raise ValueError(f"Unsupported hash algorithm: {hash_algorithm}")
|
||||
|
||||
return f"{hash_algorithm}:{fingerprint}"
|
||||
|
||||
|
||||
def verify_ssh_key_format(public_key_str: str) -> bool:
|
||||
"""Verify that a string is in valid OpenSSH public key format.
|
||||
|
||||
Args:
|
||||
public_key_str: Potential SSH public key
|
||||
|
||||
Returns:
|
||||
True if valid OpenSSH format, False otherwise
|
||||
"""
|
||||
if not public_key_str or not isinstance(public_key_str, str):
|
||||
return False
|
||||
|
||||
parts = public_key_str.strip().split()
|
||||
|
||||
# Must have at least key type and key material
|
||||
if len(parts) < 2:
|
||||
return False
|
||||
|
||||
key_type = parts[0]
|
||||
|
||||
# Valid key types
|
||||
valid_types = [
|
||||
'ssh-rsa',
|
||||
'ssh-ed25519',
|
||||
'ecdsa-sha2-nistp256',
|
||||
'ecdsa-sha2-nistp384',
|
||||
'ecdsa-sha2-nistp521',
|
||||
'ssh-dss',
|
||||
]
|
||||
|
||||
if key_type not in valid_types:
|
||||
return False
|
||||
|
||||
# Try to decode base64
|
||||
try:
|
||||
base64.b64decode(parts[1])
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def extract_ssh_key_type(public_key_str: str) -> Optional[str]:
|
||||
"""Extract the key type from an OpenSSH public key.
|
||||
|
||||
Args:
|
||||
public_key_str: SSH public key in OpenSSH format
|
||||
|
||||
Returns:
|
||||
Key type (e.g., "ssh-ed25519") or None if invalid
|
||||
"""
|
||||
if not verify_ssh_key_format(public_key_str):
|
||||
return None
|
||||
|
||||
return public_key_str.strip().split()[0]
|
||||
|
||||
|
||||
def extract_ssh_key_comment(public_key_str: str) -> Optional[str]:
|
||||
"""Extract the comment from an OpenSSH public key.
|
||||
|
||||
Args:
|
||||
public_key_str: SSH public key in OpenSSH format
|
||||
|
||||
Returns:
|
||||
Comment string or None if not present
|
||||
"""
|
||||
if not verify_ssh_key_format(public_key_str):
|
||||
return None
|
||||
|
||||
parts = public_key_str.strip().split()
|
||||
if len(parts) >= 3:
|
||||
# Everything after the second part is the comment
|
||||
return ' '.join(parts[2:])
|
||||
|
||||
return None
|
||||
@@ -0,0 +1,173 @@
|
||||
"""Add SSH CA models: SSHKey, SSHCertificate, CA, CertificateAuditLog.
|
||||
|
||||
Revision ID: 007
|
||||
Revises: 006
|
||||
Create Date: 2026-02-27 11:00:00.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '007'
|
||||
down_revision = '006'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### CA table ###
|
||||
op.create_table('cas',
|
||||
sa.Column('organization_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('key_type', sa.Enum('ed25519', 'rsa', 'ecdsa', name='ca_key_type_enum'), nullable=False),
|
||||
sa.Column('private_key', sa.Text(), nullable=False),
|
||||
sa.Column('public_key', sa.Text(), nullable=False),
|
||||
sa.Column('fingerprint', sa.String(length=255), nullable=False),
|
||||
sa.Column('crl_enabled', sa.Boolean(), nullable=False),
|
||||
sa.Column('crl_endpoint', sa.String(length=512), nullable=True),
|
||||
sa.Column('default_cert_validity_hours', sa.Integer(), nullable=False),
|
||||
sa.Column('max_cert_validity_hours', sa.Integer(), nullable=False),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('rotated_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('rotation_reason', sa.String(length=255), nullable=True),
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('fingerprint'),
|
||||
sa.UniqueConstraint('organization_id', 'name', name='uix_org_ca_name')
|
||||
)
|
||||
op.create_index(op.f('ix_cas_organization_id'), 'cas', ['organization_id'], unique=False)
|
||||
op.create_index('idx_ca_org_active', 'cas', ['organization_id', 'is_active'], unique=False)
|
||||
|
||||
# ### SSHKey table ###
|
||||
op.create_table('ssh_keys',
|
||||
sa.Column('user_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('payload', sa.Text(), nullable=False),
|
||||
sa.Column('fingerprint', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.String(length=255), nullable=True),
|
||||
sa.Column('verified', sa.Boolean(), nullable=False),
|
||||
sa.Column('verified_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('verify_text', sa.String(length=255), nullable=True),
|
||||
sa.Column('verify_text_created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('key_type', sa.String(length=50), nullable=True),
|
||||
sa.Column('key_bits', sa.Integer(), nullable=True),
|
||||
sa.Column('key_comment', sa.String(length=255), nullable=True),
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('payload'),
|
||||
sa.UniqueConstraint('fingerprint')
|
||||
)
|
||||
op.create_index(op.f('ix_ssh_keys_user_id'), 'ssh_keys', ['user_id'], unique=False)
|
||||
op.create_index(op.f('ix_ssh_keys_fingerprint'), 'ssh_keys', ['fingerprint'], unique=False)
|
||||
op.create_index(op.f('ix_ssh_keys_verified'), 'ssh_keys', ['verified'], unique=False)
|
||||
op.create_index('idx_ssh_key_user_verified', 'ssh_keys', ['user_id', 'verified'], unique=False)
|
||||
|
||||
# ### SSHCertificate table ###
|
||||
op.create_table('ssh_certificates',
|
||||
sa.Column('ca_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('user_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('ssh_key_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('certificate', sa.Text(), nullable=False),
|
||||
sa.Column('serial', sa.String(length=255), nullable=False),
|
||||
sa.Column('key_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('cert_type', sa.Enum('user', 'host', name='ssh_cert_type_enum'), nullable=False),
|
||||
sa.Column('principals', sa.JSON(), nullable=False),
|
||||
sa.Column('valid_after', sa.DateTime(), nullable=False),
|
||||
sa.Column('valid_before', sa.DateTime(), nullable=False),
|
||||
sa.Column('revoked', sa.Boolean(), nullable=False),
|
||||
sa.Column('revoked_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('revoke_reason', sa.String(length=255), nullable=True),
|
||||
sa.Column('status', sa.Enum('requested', 'issued', 'revoked', 'expired', 'superseded', name='ssh_cert_status_enum'), nullable=False),
|
||||
sa.Column('request_ip', sa.String(length=45), nullable=True),
|
||||
sa.Column('request_user_agent', sa.String(length=512), nullable=True),
|
||||
sa.Column('critical_options', sa.JSON(), nullable=True),
|
||||
sa.Column('extensions', sa.JSON(), nullable=True),
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['ca_id'], ['cas.id'], ),
|
||||
sa.ForeignKeyConstraint(['ssh_key_id'], ['ssh_keys.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('serial')
|
||||
)
|
||||
op.create_index(op.f('ix_ssh_certificates_ca_id'), 'ssh_certificates', ['ca_id'], unique=False)
|
||||
op.create_index(op.f('ix_ssh_certificates_user_id'), 'ssh_certificates', ['user_id'], unique=False)
|
||||
op.create_index(op.f('ix_ssh_certificates_ssh_key_id'), 'ssh_certificates', ['ssh_key_id'], unique=False)
|
||||
op.create_index(op.f('ix_ssh_certificates_serial'), 'ssh_certificates', ['serial'], unique=False)
|
||||
op.create_index(op.f('ix_ssh_certificates_revoked'), 'ssh_certificates', ['revoked'], unique=False)
|
||||
op.create_index(op.f('ix_ssh_certificates_status'), 'ssh_certificates', ['status'], unique=False)
|
||||
op.create_index('idx_cert_user_status', 'ssh_certificates', ['user_id', 'status'], unique=False)
|
||||
op.create_index('idx_cert_validity', 'ssh_certificates', ['valid_after', 'valid_before'], unique=False)
|
||||
op.create_index('idx_cert_revoked', 'ssh_certificates', ['revoked', 'revoked_at'], unique=False)
|
||||
|
||||
# ### CertificateAuditLog table ###
|
||||
op.create_table('certificate_audit_logs',
|
||||
sa.Column('certificate_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('user_id', sa.String(length=36), nullable=True),
|
||||
sa.Column('action', sa.String(length=50), nullable=False),
|
||||
sa.Column('ip_address', sa.String(length=45), nullable=True),
|
||||
sa.Column('user_agent', sa.String(length=512), nullable=True),
|
||||
sa.Column('request_id', sa.String(length=36), nullable=True),
|
||||
sa.Column('message', sa.Text(), nullable=True),
|
||||
sa.Column('extra_data', sa.JSON(), nullable=True),
|
||||
sa.Column('success', sa.Boolean(), nullable=False),
|
||||
sa.Column('error_message', sa.Text(), nullable=True),
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['certificate_id'], ['ssh_certificates.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_certificate_audit_logs_certificate_id'), 'certificate_audit_logs', ['certificate_id'], unique=False)
|
||||
op.create_index(op.f('ix_certificate_audit_logs_user_id'), 'certificate_audit_logs', ['user_id'], unique=False)
|
||||
op.create_index(op.f('ix_certificate_audit_logs_action'), 'certificate_audit_logs', ['action'], unique=False)
|
||||
op.create_index('idx_cert_audit_cert_action', 'certificate_audit_logs', ['certificate_id', 'action'], unique=False)
|
||||
op.create_index('idx_cert_audit_user', 'certificate_audit_logs', ['user_id', 'created_at'], unique=False)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_index('idx_cert_audit_user', table_name='certificate_audit_logs')
|
||||
op.drop_index('idx_cert_audit_cert_action', table_name='certificate_audit_logs')
|
||||
op.drop_index(op.f('ix_certificate_audit_logs_action'), table_name='certificate_audit_logs')
|
||||
op.drop_index(op.f('ix_certificate_audit_logs_user_id'), table_name='certificate_audit_logs')
|
||||
op.drop_index(op.f('ix_certificate_audit_logs_certificate_id'), table_name='certificate_audit_logs')
|
||||
op.drop_table('certificate_audit_logs')
|
||||
|
||||
op.drop_index('idx_cert_revoked', table_name='ssh_certificates')
|
||||
op.drop_index('idx_cert_validity', table_name='ssh_certificates')
|
||||
op.drop_index('idx_cert_user_status', table_name='ssh_certificates')
|
||||
op.drop_index(op.f('ix_ssh_certificates_status'), table_name='ssh_certificates')
|
||||
op.drop_index(op.f('ix_ssh_certificates_revoked'), table_name='ssh_certificates')
|
||||
op.drop_index(op.f('ix_ssh_certificates_serial'), table_name='ssh_certificates')
|
||||
op.drop_index(op.f('ix_ssh_certificates_ssh_key_id'), table_name='ssh_certificates')
|
||||
op.drop_index(op.f('ix_ssh_certificates_user_id'), table_name='ssh_certificates')
|
||||
op.drop_index(op.f('ix_ssh_certificates_ca_id'), table_name='ssh_certificates')
|
||||
op.drop_table('ssh_certificates')
|
||||
|
||||
op.drop_index('idx_ssh_key_user_verified', table_name='ssh_keys')
|
||||
op.drop_index(op.f('ix_ssh_keys_verified'), table_name='ssh_keys')
|
||||
op.drop_index(op.f('ix_ssh_keys_fingerprint'), table_name='ssh_keys')
|
||||
op.drop_index(op.f('ix_ssh_keys_user_id'), table_name='ssh_keys')
|
||||
op.drop_table('ssh_keys')
|
||||
|
||||
op.drop_index('idx_ca_org_active', table_name='cas')
|
||||
op.drop_index(op.f('ix_cas_organization_id'), table_name='cas')
|
||||
op.drop_table('cas')
|
||||
@@ -0,0 +1,53 @@
|
||||
"""Add TOTP and WEBAUTHN to authmethodtype enum.
|
||||
|
||||
Revision ID: 008
|
||||
Revises: 007
|
||||
Create Date: 2026-02-27 15:00:00.000000
|
||||
|
||||
The original migration (001_base) created authmethodtype with only:
|
||||
PASSWORD, GOOGLE, GITHUB, MICROSOFT, SAML, OIDC
|
||||
|
||||
This migration adds the missing TOTP and WEBAUTHN values so
|
||||
has_totp_enabled() and has_webauthn_enabled() queries work correctly.
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision = '008'
|
||||
down_revision = '007'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Add TOTP to the enum (idempotent approach using DO block)
|
||||
op.execute("""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_enum
|
||||
WHERE enumlabel = 'TOTP'
|
||||
AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'authmethodtype')
|
||||
) THEN
|
||||
ALTER TYPE authmethodtype ADD VALUE 'TOTP';
|
||||
END IF;
|
||||
END$$;
|
||||
""")
|
||||
op.execute("""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_enum
|
||||
WHERE enumlabel = 'WEBAUTHN'
|
||||
AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'authmethodtype')
|
||||
) THEN
|
||||
ALTER TYPE authmethodtype ADD VALUE 'WEBAUTHN';
|
||||
END IF;
|
||||
END$$;
|
||||
""")
|
||||
|
||||
|
||||
def downgrade():
|
||||
# PostgreSQL does not support removing enum values; downgrade is a no-op.
|
||||
pass
|
||||
@@ -0,0 +1,61 @@
|
||||
"""Sync auditaction enum with all AuditAction Python enum values.
|
||||
|
||||
Revision ID: 009
|
||||
Revises: 008
|
||||
Create Date: 2026-02-27 15:20:00.000000
|
||||
|
||||
The auditaction DB enum was only created with the initial 17 values from 001_base.py.
|
||||
All TOTP, WebAuthn, OAuth, SSH, CA, Principal, and Department audit actions were added
|
||||
to the Python enum but never synced to the DB type.
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision = '009'
|
||||
down_revision = '008'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
MISSING_VALUES = [
|
||||
'TOTP_ENROLL_INITIATED', 'TOTP_ENROLL_COMPLETED', 'TOTP_VERIFY_SUCCESS',
|
||||
'TOTP_VERIFY_FAILED', 'TOTP_DISABLED', 'TOTP_BACKUP_CODE_USED',
|
||||
'TOTP_BACKUP_CODES_REGENERATED', 'WEBAUTHN_REGISTER_INITIATED',
|
||||
'WEBAUTHN_REGISTER_COMPLETED', 'WEBAUTHN_REGISTER_FAILED',
|
||||
'WEBAUTHN_LOGIN_INITIATED', 'WEBAUTHN_LOGIN_SUCCESS', 'WEBAUTHN_LOGIN_FAILED',
|
||||
'WEBAUTHN_CREDENTIAL_DELETED', 'WEBAUTHN_CREDENTIAL_RENAMED',
|
||||
'ORG_SECURITY_POLICY_UPDATE', 'USER_SECURITY_POLICY_OVERRIDE_UPDATE',
|
||||
'MFA_POLICY_USER_SUSPENDED', 'MFA_POLICY_USER_COMPLIANT',
|
||||
'EXTERNAL_AUTH_LINK_INITIATED', 'EXTERNAL_AUTH_LINK_COMPLETED',
|
||||
'EXTERNAL_AUTH_LINK_FAILED', 'EXTERNAL_AUTH_UNLINK', 'EXTERNAL_AUTH_LOGIN',
|
||||
'EXTERNAL_AUTH_LOGIN_FAILED', 'EXTERNAL_AUTH_TOKEN_REFRESH',
|
||||
'EXTERNAL_AUTH_CONFIG_CREATE', 'EXTERNAL_AUTH_CONFIG_UPDATE',
|
||||
'EXTERNAL_AUTH_CONFIG_DELETE', 'SSH_KEY_ADDED', 'SSH_KEY_VERIFIED',
|
||||
'SSH_KEY_DELETED', 'SSH_KEY_VALIDATION_FAILED', 'SSH_CERT_REQUESTED',
|
||||
'SSH_CERT_ISSUED', 'SSH_CERT_FAILED', 'SSH_CERT_REVOKED', 'SSH_CERT_EXPIRED',
|
||||
'CA_CREATED', 'CA_UPDATED', 'CA_DELETED', 'CA_KEY_ROTATED',
|
||||
'PRINCIPAL_CREATED', 'PRINCIPAL_UPDATED', 'PRINCIPAL_DELETED',
|
||||
'PRINCIPAL_MEMBER_ADDED', 'PRINCIPAL_MEMBER_REMOVED',
|
||||
'DEPARTMENT_CREATED', 'DEPARTMENT_UPDATED', 'DEPARTMENT_DELETED',
|
||||
'DEPARTMENT_MEMBER_ADDED', 'DEPARTMENT_MEMBER_REMOVED',
|
||||
]
|
||||
|
||||
|
||||
def upgrade():
|
||||
for val in MISSING_VALUES:
|
||||
op.execute(f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_enum
|
||||
WHERE enumlabel = '{val}'
|
||||
AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'auditaction')
|
||||
) THEN
|
||||
ALTER TYPE auditaction ADD VALUE '{val}';
|
||||
END IF;
|
||||
END$$;
|
||||
""")
|
||||
|
||||
|
||||
def downgrade():
|
||||
# PostgreSQL does not support removing enum values; downgrade is a no-op.
|
||||
pass
|
||||
@@ -0,0 +1,33 @@
|
||||
"""Make CA.organization_id nullable (system CA) and add cert_id to sign response
|
||||
|
||||
Revision ID: 012_ca_nullable_org_and_cert_serial
|
||||
Revises: 011_org_invite_tokens
|
||||
Create Date: 2025-01-01 00:00:00.000000
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision = '012_ca_nullable_org'
|
||||
down_revision = '011_org_invite_tokens'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Allow CA records without an org (e.g. the global system-config CA)
|
||||
with op.batch_alter_table('cas', schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
'organization_id',
|
||||
existing_type=sa.String(36),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
with op.batch_alter_table('cas', schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
'organization_id',
|
||||
existing_type=sa.String(36),
|
||||
nullable=False,
|
||||
)
|
||||
@@ -47,4 +47,7 @@ Flask-Limiter==3.5.0
|
||||
|
||||
# Logging
|
||||
python-json-logger==2.0.7
|
||||
qrcode[pil]
|
||||
qrcode[pil]
|
||||
|
||||
# SSH CA Certificate signing
|
||||
sshkey-tools==0.11.0
|
||||
|
||||
@@ -20,3 +20,25 @@ watchdog==3.0.0
|
||||
|
||||
# Documentation
|
||||
sphinx==7.2.6
|
||||
|
||||
# Web framework & Database
|
||||
Flask==3.0.0
|
||||
Flask-SQLAlchemy==3.1.1
|
||||
Flask-Migrate==4.0.5
|
||||
sqlalchemy-cockroachdb==2.0.3
|
||||
|
||||
# Utilities
|
||||
colorlog==6.8.0
|
||||
coloredlogs==15.0.1
|
||||
prettytable==3.10.2
|
||||
tabulate==0.9.0
|
||||
requests==2.31.0
|
||||
pytz==2023.3
|
||||
python-dotenv==1.0.0
|
||||
pydantic==2.5.0
|
||||
PyJWT==2.8.0
|
||||
cryptography==41.0.7
|
||||
pycryptodome==3.20.0
|
||||
psycopg2==2.9.9
|
||||
sshkey-tools==0.10.3
|
||||
sendgrid==6.11.0
|
||||
|
||||
Reference in New Issue
Block a user