Merge pull request #2 from jamesii-b/gatehouse/secuird-CA-merge-v2.01
Gatehouse with secuird CA Merge (Gatehouse Isolated)
This commit is contained in:
+6
-2
@@ -40,5 +40,9 @@ LOG_TO_STDOUT=True
|
||||
RATELIMIT_ENABLED=True
|
||||
RATELIMIT_STORAGE_URL=redis://localhost:6379/1
|
||||
|
||||
# Testing
|
||||
TESTING=False
|
||||
# SSH CA
|
||||
# Path to CA private key file (alternative to SSH_CA_PRIVATE_KEY env var)
|
||||
SSH_CA_KEY_PATH=/path/to/ca-users
|
||||
# Or set the key content directly (takes priority over SSH_CA_KEY_PATH):
|
||||
# SSH_CA_PRIVATE_KEY=
|
||||
|
||||
|
||||
Executable
+528
@@ -0,0 +1,528 @@
|
||||
#!/usr/bin/python3
|
||||
import base64
|
||||
import os
|
||||
import sys
|
||||
import webbrowser
|
||||
import requests
|
||||
import argparse
|
||||
import jwt
|
||||
import json
|
||||
import datetime
|
||||
import pytz
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from urllib.parse import urlparse, parse_qsl
|
||||
from dotenv import load_dotenv
|
||||
from sshkey_tools.cert import SSHCertificate
|
||||
import logging
|
||||
import coloredlogs
|
||||
import subprocess
|
||||
|
||||
# Load environment variables from the .env file
|
||||
load_dotenv()
|
||||
|
||||
# Get the API_URL from the environment variables
|
||||
SIGN_URL = os.getenv("SIGN_URL", "http://localhost:5000")
|
||||
LISTENER_HOST_NAME = "127.0.0.1"
|
||||
LISTENER_SERVER_PORT = 8250
|
||||
CACHE_FILE = os.path.expanduser('~/.gatehouse/token_cache.json')
|
||||
os.makedirs(os.path.dirname(CACHE_FILE), exist_ok=True)
|
||||
CERT_FILE_PATH = "/tmp/ssh-cert"
|
||||
CHALLENGE_FILE_PATH = "/tmp/challenge.txt"
|
||||
CHALLENGE_SIG_FILE_PATH = "/tmp/challenge.txt.sig"
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
coloredlogs.install(level='DEBUG', logger=logger, fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
|
||||
token = ""
|
||||
|
||||
def auth_headers(content_type="application/json"):
|
||||
"""Return auth headers using the current cached token."""
|
||||
return {"Authorization": f"Bearer {token}", "Content-Type": content_type}
|
||||
|
||||
|
||||
class MyServer(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
"""Handle GET requests and process token reception."""
|
||||
global server_done, token
|
||||
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(bytes("<html><head><title>OIDC Workflow Tool</title></head>", "utf-8"))
|
||||
self.wfile.write(bytes("<body><p>The token has been received</p>", "utf-8"))
|
||||
self.wfile.write(bytes("<p>You may now close this window.</p>", "utf-8"))
|
||||
self.wfile.write(bytes("</body></html>", "utf-8"))
|
||||
|
||||
parsed_url = urlparse(self.path)
|
||||
query_data = dict(parse_qsl(parsed_url.query))
|
||||
received_token = query_data.get('token')
|
||||
|
||||
if received_token:
|
||||
token = received_token
|
||||
server_done = True
|
||||
logger.info("Token received")
|
||||
save_token_to_cache(token)
|
||||
|
||||
def log_message(self, format, *args):
|
||||
"""Log messages using the logger instead of stdout."""
|
||||
logger.info("%s - %s" % (self.client_address[0], format % args))
|
||||
|
||||
|
||||
def load_token_from_cache():
|
||||
"""Load the token from the cache file."""
|
||||
if os.path.exists(CACHE_FILE):
|
||||
with open(CACHE_FILE, 'r') as f:
|
||||
data = json.load(f)
|
||||
if 'token' in data:
|
||||
return data['token']
|
||||
return None
|
||||
|
||||
def save_token_to_cache(token):
|
||||
"""Save the token to the cache file."""
|
||||
with open(CACHE_FILE, 'w') as f:
|
||||
json.dump({'token': token}, f)
|
||||
|
||||
def clear_token_cache():
|
||||
"""Remove the cached token file."""
|
||||
if os.path.exists(CACHE_FILE):
|
||||
os.remove(CACHE_FILE)
|
||||
logger.info("Cached token removed.")
|
||||
else:
|
||||
logger.info("No cached token found.")
|
||||
|
||||
def decode_and_validate_token(token):
|
||||
"""Decode the JWT and validate its claims.
|
||||
|
||||
Returns True if the token is a valid, non-expired JWT.
|
||||
Returns False if the token is not a JWT (e.g. opaque session token)
|
||||
or if it has expired — callers should then fall back to /auth/me.
|
||||
"""
|
||||
try:
|
||||
decoded_token = jwt.decode(token, options={"verify_signature": False})
|
||||
except jwt.exceptions.DecodeError:
|
||||
# Not a JWT — likely an opaque session token; let /auth/me handle it.
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.debug(f"Unexpected JWT decode error: {e}")
|
||||
return False
|
||||
|
||||
iat = decoded_token.get('iat')
|
||||
exp = decoded_token.get('exp')
|
||||
|
||||
if iat is None or exp is None:
|
||||
logger.debug("JWT is missing 'iat' or 'exp' claims — treating as invalid.")
|
||||
return False
|
||||
|
||||
now = datetime.datetime.now(pytz.UTC)
|
||||
exp_dt = datetime.datetime.fromtimestamp(exp, pytz.UTC)
|
||||
iat_dt = datetime.datetime.fromtimestamp(iat, pytz.UTC)
|
||||
|
||||
logger.debug(f"JWT iat={iat_dt.isoformat()} exp={exp_dt.isoformat()}")
|
||||
|
||||
if exp_dt < now:
|
||||
logger.debug("JWT has expired.")
|
||||
return False
|
||||
|
||||
if iat_dt > now:
|
||||
logger.debug("JWT 'iat' is in the future — clock skew?")
|
||||
|
||||
return True
|
||||
|
||||
def request_token():
|
||||
global server_done, token
|
||||
server_done = False
|
||||
logger.info("Starting request_token process.")
|
||||
|
||||
# Attempt to load the token from the cache
|
||||
token = load_token_from_cache()
|
||||
logger.debug("Token loaded from cache: %s", token)
|
||||
|
||||
# Validate the cached token, if it exists
|
||||
if token:
|
||||
try:
|
||||
if decode_and_validate_token(token):
|
||||
logger.info("Cached token is valid. Using cached token.")
|
||||
return token
|
||||
except Exception:
|
||||
pass
|
||||
# Try opaque token via /auth/me
|
||||
try:
|
||||
r = requests.get(
|
||||
f"{SIGN_URL}/api/v1/auth/me",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
timeout=5,
|
||||
)
|
||||
if r.status_code == 200:
|
||||
logger.info("Cached session token is valid. Using cached token.")
|
||||
return token
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("Cached token is expired or invalid, requesting a new token.")
|
||||
token = ""
|
||||
|
||||
# Prepare the redirect URL for the token request
|
||||
redirect_url = f"http://{LISTENER_HOST_NAME}:{LISTENER_SERVER_PORT}/?token="
|
||||
logger.info("Redirect URL: %s", redirect_url)
|
||||
|
||||
# Construct the token request URL
|
||||
token_url = f"{SIGN_URL}/api/v1/token_please?redirect_url={redirect_url}"
|
||||
logger.info("Token request URL: %s", token_url)
|
||||
|
||||
# Start the web server to handle the token response
|
||||
logger.debug("Starting the HTTP server on %s:%d", LISTENER_HOST_NAME, LISTENER_SERVER_PORT)
|
||||
webServer = HTTPServer((LISTENER_HOST_NAME, LISTENER_SERVER_PORT), MyServer)
|
||||
|
||||
# Open the web browser to initiate the token request
|
||||
logger.info("Opening web browser to request token.")
|
||||
webbrowser.open(token_url, new=2)
|
||||
|
||||
# Wait for the server to handle the request and receive the token
|
||||
logger.debug("Waiting for the token response...")
|
||||
while not server_done:
|
||||
webServer.handle_request()
|
||||
logger.debug("Server handled a request, server_done status: %s", server_done)
|
||||
|
||||
logger.info("Token received: %s", token)
|
||||
return token
|
||||
|
||||
def get_activated_ssh_key():
|
||||
"""Retrieve the list of SSH keys and return the ID of a verified key."""
|
||||
try:
|
||||
response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys", headers=auth_headers())
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to retrieve SSH keys: {response.status_code} - {response.text}")
|
||||
exit(1)
|
||||
|
||||
keys = response.json().get('data', {}).get('keys', [])
|
||||
verified_keys = [k for k in keys if k['verified']]
|
||||
|
||||
if not verified_keys:
|
||||
logger.error("No verified SSH keys found for the user.")
|
||||
exit(1)
|
||||
|
||||
if len(verified_keys) > 1 and sys.stdout.isatty():
|
||||
print("\nMultiple verified SSH keys found. Please choose one:")
|
||||
for i, k in enumerate(verified_keys):
|
||||
print(f" [{i+1}] {k['id'][:8]}... fingerprint={k.get('fingerprint','?')} name={k.get('key_comment','?')}")
|
||||
try:
|
||||
choice = int(input("Enter number: ").strip()) - 1
|
||||
if 0 <= choice < len(verified_keys):
|
||||
return verified_keys[choice]['id']
|
||||
except (ValueError, EOFError):
|
||||
pass
|
||||
logger.info("Invalid choice; using the most recently added key.")
|
||||
|
||||
verified_keys.sort(key=lambda k: k.get('created_at', ''), reverse=True)
|
||||
return verified_keys[0]['id']
|
||||
|
||||
except SystemExit:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error while retrieving SSH keys: {e}")
|
||||
exit(1)
|
||||
|
||||
|
||||
def fetch_my_principals():
|
||||
"""Fetch all principal names the current user is entitled to from the API.
|
||||
For regular members: returns their assigned principals.
|
||||
For org admins/owners: returns all principals in the org (they can sign for any).
|
||||
"""
|
||||
global token
|
||||
response = requests.get(
|
||||
f"{SIGN_URL}/api/v1/users/me/principals",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
timeout=10,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to fetch principals from server: {response.status_code} - {response.text}")
|
||||
exit(1)
|
||||
|
||||
orgs = response.json().get("data", {}).get("orgs", [])
|
||||
principal_names = []
|
||||
for org in orgs:
|
||||
# Admins/owners get all principals; regular members get only their assigned ones
|
||||
if org.get("is_admin"):
|
||||
source = org.get("all_principals", [])
|
||||
else:
|
||||
source = org.get("my_principals", [])
|
||||
for p in source:
|
||||
if p["name"] not in principal_names:
|
||||
principal_names.append(p["name"])
|
||||
|
||||
return principal_names
|
||||
|
||||
|
||||
def request_certificate():
|
||||
CERT_ID = os.getenv("CERT_ID") or get_activated_ssh_key()
|
||||
|
||||
principals = fetch_my_principals()
|
||||
if not principals:
|
||||
logger.error("You have no principals assigned. Contact your org admin.")
|
||||
exit(1)
|
||||
logger.info(f"Requesting certificate for principals: {', '.join(principals)}")
|
||||
|
||||
headers = {
|
||||
'content-type': 'application/json',
|
||||
"Authorization": "bearer " + token
|
||||
}
|
||||
|
||||
payload = {
|
||||
'cert_id': CERT_ID,
|
||||
'principals': principals,
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(f"{SIGN_URL}/api/v1/ssh/sign", json=payload, headers=headers)
|
||||
|
||||
if response.status_code == 201:
|
||||
json_result = response.json().get('data', response.json())
|
||||
with open(CERT_FILE_PATH, 'w') as f:
|
||||
f.write(json_result['certificate'])
|
||||
logger.info(f"Certificate signed successfully, located at {CERT_FILE_PATH}")
|
||||
logger.info(f"Valid for principals: {', '.join(json_result.get('principals', principals))}")
|
||||
logger.info("You can login to your destination server with the following command")
|
||||
logger.info(f"\tssh user@server -o CertificateFile={CERT_FILE_PATH}")
|
||||
else:
|
||||
logger.error("Error in response from server")
|
||||
logger.error(f"Status code: {response.status_code}")
|
||||
logger.error(f"Response text: {response.text}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during certificate signing: {e}")
|
||||
|
||||
def generate_and_sign_challenge(ssh_key_file, key_id):
|
||||
"""Fetch a challenge from the server, sign it with the SSH key, and submit the signature."""
|
||||
logger.debug(f"generate_and_sign_challenge - {ssh_key_file} {key_id}")
|
||||
|
||||
# Fetch challenge text
|
||||
try:
|
||||
response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys/{key_id}/verify", headers=auth_headers())
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Server returned unexpected code {response.status_code}")
|
||||
return False
|
||||
resp_json = response.json()
|
||||
data = resp_json.get('data', resp_json)
|
||||
challenge_text = data.get('challenge_text', data.get('validationText', '')) + "\n"
|
||||
except Exception as e:
|
||||
logger.error(f"Unable to fetch SSH Key validation data: {e}")
|
||||
return False
|
||||
|
||||
# Sign the challenge
|
||||
try:
|
||||
for path in (CHALLENGE_FILE_PATH, CHALLENGE_SIG_FILE_PATH):
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
|
||||
with open(CHALLENGE_FILE_PATH, 'w') as f:
|
||||
f.write(challenge_text)
|
||||
|
||||
subprocess.run(
|
||||
["ssh-keygen", "-Y", "sign", "-f", ssh_key_file, "-n", "file", CHALLENGE_FILE_PATH],
|
||||
check=True,
|
||||
)
|
||||
|
||||
with open(CHALLENGE_SIG_FILE_PATH, 'rb') as f:
|
||||
signature = base64.b64encode(f.read()).decode('utf-8')
|
||||
except Exception as e:
|
||||
logger.error(f"Unable to sign the challenge response: {e}")
|
||||
return False
|
||||
|
||||
# Submit signature
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{SIGN_URL}/api/v1/ssh/keys/{key_id}/verify",
|
||||
headers=auth_headers(),
|
||||
json={"signature": signature},
|
||||
)
|
||||
if response.status_code == 200:
|
||||
logger.info("SSH key verified successfully.")
|
||||
else:
|
||||
logger.error(f"Verification failed: {response.status_code} - {response.text}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unable to submit the challenge response: {e}")
|
||||
|
||||
return signature
|
||||
|
||||
def remove_ssh_key(key_id=None):
|
||||
"""
|
||||
Remove an SSH key from the server. If key_id is None, list keys and prompt user to pick one.
|
||||
"""
|
||||
response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys", headers=auth_headers())
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to list SSH keys: {response.status_code} - {response.text}")
|
||||
exit(1)
|
||||
|
||||
keys = response.json().get('data', {}).get('keys', [])
|
||||
if not keys:
|
||||
logger.info("No SSH keys found for your user.")
|
||||
return
|
||||
|
||||
if key_id:
|
||||
target = next((k for k in keys if k['id'] == key_id), None)
|
||||
if not target:
|
||||
logger.error(f"Key ID {key_id} not found in your profile.")
|
||||
exit(1)
|
||||
keys_to_delete = [target]
|
||||
else:
|
||||
print("\nYour SSH keys:")
|
||||
for i, k in enumerate(keys):
|
||||
verified = "✓ verified" if k['verified'] else "✗ unverified"
|
||||
print(f" [{i+1}] {k['id']} {verified} {k.get('description', '')} (added {k['created_at'][:10]})")
|
||||
print(" [a] Delete ALL keys")
|
||||
print(" [q] Quit")
|
||||
choice = input("\nEnter number to delete (or 'a' for all, 'q' to quit): ").strip().lower()
|
||||
|
||||
if choice == 'q':
|
||||
return
|
||||
elif choice == 'a':
|
||||
keys_to_delete = keys
|
||||
else:
|
||||
try:
|
||||
idx = int(choice) - 1
|
||||
if idx < 0 or idx >= len(keys):
|
||||
raise ValueError()
|
||||
keys_to_delete = [keys[idx]]
|
||||
except ValueError:
|
||||
logger.error("Invalid selection.")
|
||||
exit(1)
|
||||
|
||||
for k in keys_to_delete:
|
||||
del_response = requests.delete(f"{SIGN_URL}/api/v1/ssh/keys/{k['id']}", headers=auth_headers())
|
||||
if del_response.status_code == 200:
|
||||
logger.info(f"Key {k['id']} removed successfully.")
|
||||
else:
|
||||
logger.error(f"Failed to remove key {k['id']}: {del_response.status_code} - {del_response.text}")
|
||||
|
||||
|
||||
def add_ssh_key(ssh_key_file):
|
||||
"""Add an SSH key to the server and auto-verify it."""
|
||||
if hasattr(ssh_key_file, 'read'):
|
||||
key_bytes = ssh_key_file.read()
|
||||
key_path = ssh_key_file.name
|
||||
elif isinstance(ssh_key_file, bytes):
|
||||
key_bytes = ssh_key_file
|
||||
key_path = None
|
||||
else:
|
||||
key_path = str(ssh_key_file)
|
||||
with open(key_path, 'rb') as f:
|
||||
key_bytes = f.read()
|
||||
|
||||
ssh_key = key_bytes.decode('utf-8').strip()
|
||||
payload = {
|
||||
'description': 'Added via gatehouse CLI tool',
|
||||
'key': ssh_key,
|
||||
}
|
||||
|
||||
response = requests.post(f"{SIGN_URL}/api/v1/ssh/keys", json=payload, headers=auth_headers())
|
||||
if response.status_code == 201:
|
||||
ssh_key_id = response.json().get('data', {}).get('id')
|
||||
logger.info(f"SSH key {ssh_key_id} added successfully")
|
||||
if key_path:
|
||||
private_key_path = key_path[:-4] if key_path.endswith('.pub') else key_path
|
||||
generate_and_sign_challenge(private_key_path, ssh_key_id)
|
||||
else:
|
||||
logger.warning("No key file path available — skipping auto-verification. "
|
||||
"Run with -k <path> to enable automatic key verification.")
|
||||
else:
|
||||
logger.error(f"Failed to add SSH key: {response.status_code} - {response.text}")
|
||||
|
||||
def checkCert():
|
||||
logger.info("Running cert check")
|
||||
if not os.path.isfile(CERT_FILE_PATH):
|
||||
logger.warning("Certificate does not exist, new certificate required")
|
||||
return 1
|
||||
|
||||
try:
|
||||
certificate = SSHCertificate.from_file(CERT_FILE_PATH)
|
||||
except Exception:
|
||||
logger.warning("Certificate file is invalid or corrupt, renewal required")
|
||||
return 1
|
||||
|
||||
# Get the current datetime
|
||||
now = datetime.datetime.now()
|
||||
logger.debug(certificate
|
||||
)
|
||||
|
||||
# Check if the date is in the past or future
|
||||
if certificate.get("valid_before") > now:
|
||||
# Expiry is in the future
|
||||
if args.force:
|
||||
return 0
|
||||
else:
|
||||
logger.info("You have a valid SSH Certificate with the principals {} expiring at {}, not renewing. Use -f to force renewal".format(certificate.get("principals"), certificate.get("valid_before")))
|
||||
return 0
|
||||
else:
|
||||
logger.warning("Certificate is not valid, renewal required")
|
||||
return 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='Sign an SSH key via a web service')
|
||||
parser.add_argument("-k", "--ssh-key", type=argparse.FileType('rb'), dest="sshkeyfile", help="Add an SSH Public Key to your user profile in gatehouse")
|
||||
parser.add_argument("-f", "--force", action='store_true', default=False, help="Force the certificate renewal")
|
||||
parser.add_argument("-a", "--add-key", action='store_true', default=False, help="Add SSH key to the server")
|
||||
parser.add_argument("-c", "--check-cert", action='store_true', default=False, help="Check the certificate, if it's valid exit 0, if it's invalid exit 1")
|
||||
parser.add_argument("-r", "--request-cert", action='store_true', default=False, help="Request that gatehouse sign a new certificate for you based on an SSH public key on file in your profile")
|
||||
parser.add_argument("--clear-cache", action='store_true', default=False, help="Remove the cached authentication token")
|
||||
parser.add_argument("--remove-key", nargs='?', const='', metavar='KEY_ID', help="Remove an SSH key from your profile. Omit KEY_ID to pick interactively.")
|
||||
parser.add_argument("--list-keys", action='store_true', default=False, help="List SSH keys in your profile")
|
||||
|
||||
args = parser.parse_args()
|
||||
if not (args.check_cert or args.request_cert or args.add_key or args.clear_cache
|
||||
or args.remove_key is not None or args.list_keys):
|
||||
parser.error("At least one of --check-cert, --request-cert, --add-key, --list-keys, --remove-key, or --clear-cache must be provided.")
|
||||
|
||||
|
||||
# Retrieve SSH key from environment variables if not provided via CLI
|
||||
ssh_key_file = args.sshkeyfile if args.sshkeyfile else os.getenv('SSH_KEY_FILE')
|
||||
|
||||
if args.check_cert:
|
||||
logger.info("Only checking certificate")
|
||||
exit(checkCert())
|
||||
|
||||
if args.clear_cache:
|
||||
clear_token_cache()
|
||||
exit(0)
|
||||
|
||||
if args.remove_key is not None:
|
||||
request_token()
|
||||
remove_ssh_key(args.remove_key if args.remove_key else None)
|
||||
exit(0)
|
||||
|
||||
if args.list_keys:
|
||||
request_token()
|
||||
response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys", headers=auth_headers())
|
||||
if response.status_code == 200:
|
||||
keys = response.json().get('data', {}).get('keys', [])
|
||||
if not keys:
|
||||
print("No SSH keys found in your profile.")
|
||||
else:
|
||||
for k in keys:
|
||||
verified = "✓ verified" if k.get('verified') else "✗ unverified"
|
||||
print(f" {k['id']} {verified} {k.get('description', '')} (added {k['created_at'][:10]})")
|
||||
else:
|
||||
logger.error(f"Failed to list SSH keys: {response.status_code} - {response.text}")
|
||||
exit(0)
|
||||
|
||||
if args.add_key:
|
||||
request_token()
|
||||
|
||||
if not ssh_key_file:
|
||||
logger.error("SSH key file is required to add SSH key")
|
||||
exit(1)
|
||||
|
||||
# If ssh_key_file is retrieved from the environment, it will be a string (file path), so open it
|
||||
if isinstance(ssh_key_file, str):
|
||||
with open(ssh_key_file, 'rb') as f:
|
||||
ssh_key_file = f.read()
|
||||
|
||||
add_ssh_key(ssh_key_file)
|
||||
exit(0)
|
||||
|
||||
|
||||
if args.request_cert:
|
||||
request_token()
|
||||
if args.force:
|
||||
logger.info("Forcing renewal of certificate")
|
||||
if args.force or checkCert() == 1:
|
||||
request_certificate()
|
||||
exit(0)
|
||||
@@ -28,6 +28,11 @@ class BaseConfig:
|
||||
|
||||
# Encryption key for sensitive data (client secrets, tokens, etc.)
|
||||
ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY", "dev-encryption-key-change-in-production")
|
||||
|
||||
# Encryption key for CA private keys stored in the database.
|
||||
# Must be set to a strong random secret in production.
|
||||
# Any string is accepted — it is SHA-256 derived to a 32-byte Fernet key internally.
|
||||
CA_ENCRYPTION_KEY = os.getenv("CA_ENCRYPTION_KEY", "dev-ca-encryption-key-change-in-production")
|
||||
|
||||
# Session configuration for WebAuthn cross-origin support
|
||||
SESSION_COOKIE_SECURE = os.getenv("SESSION_COOKIE_SECURE", "True").lower() == "true"
|
||||
@@ -72,6 +77,13 @@ class BaseConfig:
|
||||
RATELIMIT_STORAGE_URL = os.getenv("RATELIMIT_STORAGE_URL", "redis://localhost:6379/1")
|
||||
RATELIMIT_DEFAULT = "100/hour"
|
||||
|
||||
# Per-endpoint auth rate limits (override via env vars for each environment)
|
||||
RATELIMIT_AUTH_REGISTER = os.getenv("RATELIMIT_AUTH_REGISTER", "10 per minute; 50 per hour")
|
||||
RATELIMIT_AUTH_LOGIN = os.getenv("RATELIMIT_AUTH_LOGIN", "20 per minute; 100 per hour")
|
||||
RATELIMIT_AUTH_TOTP_VERIFY = os.getenv("RATELIMIT_AUTH_TOTP_VERIFY", "20 per minute; 100 per hour")
|
||||
RATELIMIT_AUTH_FORGOT_PASSWORD = os.getenv("RATELIMIT_AUTH_FORGOT_PASSWORD", "5 per minute; 20 per hour")
|
||||
RATELIMIT_AUTH_RESET_PASSWORD = os.getenv("RATELIMIT_AUTH_RESET_PASSWORD", "10 per minute; 30 per hour")
|
||||
|
||||
# Logging
|
||||
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
|
||||
LOG_TO_STDOUT = os.getenv("LOG_TO_STDOUT", "False").lower() == "true"
|
||||
@@ -116,3 +128,11 @@ class BaseConfig:
|
||||
|
||||
# Frontend URL (for OAuth callback redirects)
|
||||
FRONTEND_URL = os.getenv("FRONTEND_URL", "http://localhost:8080")
|
||||
|
||||
# Email / SMTP
|
||||
EMAIL_ENABLED = os.getenv("EMAIL_ENABLED", "False").lower() == "true"
|
||||
SMTP_HOST = os.getenv("SMTP_HOST", "smtp.gmail.com")
|
||||
SMTP_PORT = int(os.getenv("SMTP_PORT", "587"))
|
||||
SMTP_USERNAME = os.getenv("SMTP_USERNAME", "")
|
||||
SMTP_PASSWORD = os.getenv("SMTP_PASSWORD", "")
|
||||
FROM_ADDRESS = os.getenv("FROM_ADDRESS", "noreply@gatehouse.local")
|
||||
|
||||
@@ -25,3 +25,15 @@ class DevelopmentConfig(BaseConfig):
|
||||
"CORS_ORIGINS",
|
||||
"http://localhost:8080,http://localhost:3000,http://localhost:5173,https://ui.webauthn.local"
|
||||
).split(",")
|
||||
|
||||
# ── Email / SMTP ──────────────────────────────────────────────────────────
|
||||
# Read from .env so real SMTP credentials work in dev.
|
||||
# Set EMAIL_ENABLED=false in .env to disable; defaults to True if SMTP_HOST is set.
|
||||
EMAIL_ENABLED = os.getenv("EMAIL_ENABLED", "True").lower() == "true"
|
||||
SMTP_HOST = os.getenv("SMTP_HOST", "localhost")
|
||||
SMTP_PORT = int(os.getenv("SMTP_PORT", "1025"))
|
||||
SMTP_USERNAME = os.getenv("SMTP_USERNAME") or None
|
||||
SMTP_PASSWORD = os.getenv("SMTP_PASSWORD") or None
|
||||
SMTP_USE_TLS = os.getenv("SMTP_USE_TLS", "").lower() == "true" if os.getenv("SMTP_USE_TLS") else int(os.getenv("SMTP_PORT", "1025")) not in (25, 1025)
|
||||
FROM_ADDRESS = os.getenv("FROM_ADDRESS", "noreply@gatehouse.local")
|
||||
EMAIL_FROM = FROM_ADDRESS # alias
|
||||
|
||||
@@ -12,6 +12,9 @@ class TestingConfig(BaseConfig):
|
||||
# Explicitly set SECRET_KEY for testing
|
||||
SECRET_KEY = os.getenv("SECRET_KEY", "test-secret-key-for-testing")
|
||||
|
||||
# CA key encryption — use a fixed test key so tests are deterministic
|
||||
CA_ENCRYPTION_KEY = os.getenv("CA_ENCRYPTION_KEY", "test-ca-encryption-key-fixed-for-tests")
|
||||
|
||||
# Use in-memory SQLite for testing
|
||||
SQLALCHEMY_DATABASE_URI = "sqlite:///:memory:"
|
||||
SQLALCHEMY_ECHO = False
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
|
||||
[default]
|
||||
# Certificate validity period (in hours)
|
||||
cert_validity_hours=8
|
||||
|
||||
# Maximum certificate validity allowed (in hours)
|
||||
max_cert_validity_hours=720
|
||||
|
||||
# CA private key path (required for local encryption mode)
|
||||
ca_key_path=
|
||||
|
||||
# Certificate Field Limits
|
||||
max_principals_per_cert=256
|
||||
max_key_id_length=255
|
||||
|
||||
# Verification challenge max age (in hours)
|
||||
verification_challenge_max_age=24
|
||||
|
||||
# Cleanup: delete unverified SSH keys after this many days
|
||||
auto_delete_unverified_days=30
|
||||
|
||||
[development]
|
||||
ca_key_path=${SSH_CA_KEY_PATH}
|
||||
cert_validity_hours=24
|
||||
|
||||
[production]
|
||||
cert_validity_hours=8
|
||||
|
||||
[testing]
|
||||
cert_validity_hours=8
|
||||
@@ -2,6 +2,9 @@
|
||||
import os
|
||||
import logging
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(dotenv_path=os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), '.env'))
|
||||
|
||||
# Test debug logging - this should appear when running `flask run --debug`
|
||||
_root_logger = logging.getLogger(__name__)
|
||||
_root_logger.debug("[TEST] Debug logging is working!")
|
||||
@@ -239,3 +242,6 @@ def initialize_oidc_jwks(app):
|
||||
app.logger.info(f"[OIDC] Signing key initialized: kid={signing_key.kid}")
|
||||
except Exception as e:
|
||||
app.logger.error(f"[OIDC] Failed to initialize JWKS: {e}")
|
||||
|
||||
# Create default app instance for gunicorn/wsgi
|
||||
app = create_app()
|
||||
|
||||
@@ -21,8 +21,12 @@ from gatehouse_app.extensions import db
|
||||
from gatehouse_app.extensions import bcrypt as flask_bcrypt
|
||||
from gatehouse_app.extensions import redis_client as _redis_client_ref # may be None until app init
|
||||
from gatehouse_app.models import User, OIDCClient
|
||||
from gatehouse_app.models.organization import Organization
|
||||
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
from gatehouse_app.exceptions.auth_exceptions import (
|
||||
InvalidCredentialsError,
|
||||
AccountSuspendedError,
|
||||
AccountInactiveError,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers for Redis-backed OIDC pending state
|
||||
@@ -326,7 +330,7 @@ def oidc_complete():
|
||||
400: invalid request
|
||||
401: invalid token
|
||||
"""
|
||||
from gatehouse_app.models.session import Session as GHSession
|
||||
from gatehouse_app.models.user.session import Session as GHSession
|
||||
from gatehouse_app.utils.constants import SessionStatus
|
||||
|
||||
data = request.get_json(silent=True) or {}
|
||||
@@ -343,6 +347,20 @@ def oidc_complete():
|
||||
|
||||
user_id = str(gh_session.user_id)
|
||||
|
||||
# Check the user is still active (not suspended after session was issued)
|
||||
from gatehouse_app.models.user.user import User as _User
|
||||
from gatehouse_app.utils.constants import UserStatus
|
||||
_complete_user = _User.query.filter_by(id=user_id, deleted_at=None).first()
|
||||
if not _complete_user or _complete_user.status in (
|
||||
UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED, UserStatus.INACTIVE
|
||||
):
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Your account is not active or has been suspended.",
|
||||
status=403,
|
||||
error_type="ACCOUNT_SUSPENDED",
|
||||
)
|
||||
|
||||
# Retrieve stashed OIDC params (consume = True removes from Redis atomically)
|
||||
params = _fetch_oidc_params(oidc_session_id, consume=True)
|
||||
if not params:
|
||||
@@ -565,6 +583,28 @@ def oidc_authorize():
|
||||
session["oidc_user_id"] = user_id
|
||||
|
||||
logger.debug("[OIDC] User authentication successful: user_id=%s, email=%s", user_id, email)
|
||||
except AccountSuspendedError:
|
||||
logger.debug("[OIDC] User authentication failed: account suspended for email=%s", email)
|
||||
return _show_login_page(
|
||||
client_id=client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
response_type=response_type,
|
||||
error="Your account has been suspended. Please contact an administrator.",
|
||||
)
|
||||
except AccountInactiveError:
|
||||
logger.debug("[OIDC] User authentication failed: account inactive for email=%s", email)
|
||||
return _show_login_page(
|
||||
client_id=client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
response_type=response_type,
|
||||
error="Your account is not active. Please verify your email.",
|
||||
)
|
||||
except InvalidCredentialsError:
|
||||
logger.debug("[OIDC] User authentication failed: invalid credentials for email=%s", email)
|
||||
return _show_login_page(
|
||||
@@ -600,7 +640,34 @@ def oidc_authorize():
|
||||
if not user:
|
||||
logger.debug("[OIDC] Redirecting with error: server_error (user not found)")
|
||||
return _redirect_with_error(redirect_uri, "server_error", "User not found", state)
|
||||
|
||||
|
||||
# Check account is still active (user could have been suspended after session start)
|
||||
from gatehouse_app.utils.constants import UserStatus as _UserStatus
|
||||
if user.status in (_UserStatus.SUSPENDED, _UserStatus.COMPLIANCE_SUSPENDED):
|
||||
session.pop("oidc_user_id", None) # clear stale session
|
||||
logger.debug("[OIDC] User is suspended, clearing session and showing login error: user_id=%s", user_id)
|
||||
return _show_login_page(
|
||||
client_id=client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
response_type=response_type,
|
||||
error="Your account has been suspended. Please contact an administrator.",
|
||||
)
|
||||
if user.status == _UserStatus.INACTIVE:
|
||||
session.pop("oidc_user_id", None)
|
||||
logger.debug("[OIDC] User is inactive, clearing session and showing login error: user_id=%s", user_id)
|
||||
return _show_login_page(
|
||||
client_id=client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
response_type=response_type,
|
||||
error="Your account is not active. Please verify your email.",
|
||||
)
|
||||
|
||||
logger.debug("[OIDC] Generating authorization code...")
|
||||
logger.debug("[OIDC] Authorization code params: client_id=%s, user_id=%s, redirect_uri=%s", client_id, user_id, redirect_uri)
|
||||
logger.debug("[OIDC] Authorization code params: scopes=%s, state=%s, nonce=%s", valid_scopes, state, nonce)
|
||||
|
||||
@@ -5,4 +5,6 @@ from flask import Blueprint
|
||||
api_v1_bp = Blueprint("api_v1", __name__)
|
||||
|
||||
# Import route modules to register them
|
||||
from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth
|
||||
from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth, departments, principals, ssh
|
||||
|
||||
api_v1_bp.register_blueprint(ssh.ssh_bp)
|
||||
|
||||
+557
-10
@@ -1,8 +1,10 @@
|
||||
"""Authentication endpoints."""
|
||||
import json
|
||||
from flask import request, session, g, jsonify
|
||||
import logging
|
||||
from flask import request, session, g, jsonify, current_app
|
||||
from marshmallow import ValidationError
|
||||
from gatehouse_app.api.v1 import api_v1_bp
|
||||
from gatehouse_app.extensions import limiter
|
||||
from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.schemas.auth_schema import (
|
||||
RegisterSchema,
|
||||
@@ -23,6 +25,7 @@ from gatehouse_app.services.auth_service import AuthService
|
||||
from gatehouse_app.services.webauthn_service import WebAuthnService
|
||||
from gatehouse_app.services.user_service import UserService
|
||||
from gatehouse_app.services.mfa_policy_service import MfaPolicyService
|
||||
from gatehouse_app.services.notification_service import NotificationService
|
||||
from gatehouse_app.utils.decorators import login_required
|
||||
from gatehouse_app.utils.constants import AuditAction
|
||||
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
|
||||
@@ -30,6 +33,7 @@ from gatehouse_app.exceptions.validation_exceptions import ConflictError, NotFou
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/register", methods=["POST"])
|
||||
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_REGISTER"])
|
||||
def register():
|
||||
"""
|
||||
Register a new user.
|
||||
@@ -57,14 +61,66 @@ def register():
|
||||
full_name=data.get("full_name"),
|
||||
)
|
||||
|
||||
# Send verification email
|
||||
try:
|
||||
from gatehouse_app.models import EmailVerificationToken
|
||||
verify_token = EmailVerificationToken.generate(user_id=user.id)
|
||||
app_url = current_app.config.get("APP_URL", "http://localhost:8080")
|
||||
verify_link = f"{app_url}/verify-email?token={verify_token.token}"
|
||||
subject = "Verify your Gatehouse email address"
|
||||
body = (
|
||||
f"Hi {user.full_name or user.email},\n\n"
|
||||
f"Welcome to Gatehouse! Please verify your email address by clicking the link below (valid for 24 hours):\n"
|
||||
f"{verify_link}\n\n"
|
||||
f"Gatehouse Security Team"
|
||||
)
|
||||
NotificationService._send_email(to_address=user.email, subject=subject, body=body)
|
||||
except Exception as exc:
|
||||
logging.getLogger(__name__).warning(f"Failed to send verification email on register: {exc}")
|
||||
|
||||
# Create session
|
||||
user_session = AuthService.create_session(user)
|
||||
|
||||
# ── Post-registration hints ─────────────────────────────────────────
|
||||
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
|
||||
from gatehouse_app.models.user.user import User as _User
|
||||
from datetime import datetime, timezone as _tz
|
||||
|
||||
now = datetime.now(_tz.utc)
|
||||
pending_invites = OrgInviteToken.query.filter(
|
||||
OrgInviteToken.email == user.email,
|
||||
OrgInviteToken.accepted_at.is_(None),
|
||||
OrgInviteToken.expires_at > now,
|
||||
OrgInviteToken.deleted_at.is_(None),
|
||||
).all()
|
||||
|
||||
# Determine if this is the very first user ever registered on this
|
||||
# instance (exactly 1 active user means it must be this one).
|
||||
total_users = _User.query.filter(_User.deleted_at.is_(None)).count()
|
||||
is_first_user = total_users == 1
|
||||
|
||||
expires_str = user_session.expires_at.isoformat()
|
||||
if expires_str[-1] != "Z":
|
||||
expires_str += "Z"
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"user": user.to_dict(),
|
||||
"token": user_session.token,
|
||||
"expires_at": user_session.expires_at.isoformat() + "Z" if user_session.expires_at.isoformat()[-1] != "Z" else user_session.expires_at.isoformat(),
|
||||
"expires_at": expires_str,
|
||||
"is_first_user": is_first_user,
|
||||
"pending_invites": [
|
||||
{
|
||||
"token": inv.token,
|
||||
"organization": {
|
||||
"id": str(inv.organization_id),
|
||||
"name": inv.organization.name,
|
||||
},
|
||||
"role": inv.role,
|
||||
"expires_at": inv.expires_at.isoformat(),
|
||||
}
|
||||
for inv in pending_invites
|
||||
],
|
||||
},
|
||||
message="Registration successful",
|
||||
status=201,
|
||||
@@ -81,6 +137,7 @@ def register():
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/login", methods=["POST"])
|
||||
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_LOGIN"])
|
||||
def login():
|
||||
"""
|
||||
Login user.
|
||||
@@ -179,7 +236,9 @@ def login():
|
||||
"organization_id": org.organization_id,
|
||||
"organization_name": org.organization_name,
|
||||
"status": org.status,
|
||||
"effective_mode": org.effective_mode,
|
||||
"deadline_at": org.deadline_at,
|
||||
"applied_at": org.applied_at,
|
||||
}
|
||||
for org in policy_result.compliance_summary.orgs
|
||||
],
|
||||
@@ -189,6 +248,36 @@ def login():
|
||||
if is_compliance_only:
|
||||
response_data["requires_mfa_enrollment"] = True
|
||||
|
||||
# ── Org-setup hint for org-less users ────────────────────────────────
|
||||
# If the user has no organisation memberships, surface any pending
|
||||
# invitations so the UI can redirect straight to /org-setup instead of
|
||||
# showing an empty dashboard.
|
||||
user_orgs = user.get_organizations()
|
||||
if not user_orgs:
|
||||
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
|
||||
from datetime import datetime, timezone as _tz
|
||||
_now = datetime.now(_tz.utc)
|
||||
pending_invites = OrgInviteToken.query.filter(
|
||||
OrgInviteToken.email == user.email,
|
||||
OrgInviteToken.accepted_at.is_(None),
|
||||
OrgInviteToken.expires_at > _now,
|
||||
OrgInviteToken.deleted_at.is_(None),
|
||||
).all()
|
||||
response_data["pending_invites"] = [
|
||||
{
|
||||
"token": inv.token,
|
||||
"organization": {
|
||||
"id": str(inv.organization_id),
|
||||
"name": inv.organization.name,
|
||||
},
|
||||
"role": inv.role,
|
||||
"expires_at": inv.expires_at.isoformat(),
|
||||
}
|
||||
for inv in pending_invites
|
||||
]
|
||||
# Flag so the UI knows to send this user through org-setup
|
||||
response_data["requires_org_setup"] = True
|
||||
|
||||
return api_response(
|
||||
data=response_data,
|
||||
message="Login successful",
|
||||
@@ -239,8 +328,13 @@ def get_current_user():
|
||||
data={
|
||||
"user": user.to_dict(),
|
||||
"organizations": [
|
||||
{"id": org.id, "name": org.name, "slug": org.slug}
|
||||
for org in user.get_organizations()
|
||||
{
|
||||
"id": membership.organization.id,
|
||||
"name": membership.organization.name,
|
||||
"slug": membership.organization.slug,
|
||||
"role": membership.role,
|
||||
}
|
||||
for membership in user.organization_memberships
|
||||
],
|
||||
},
|
||||
message="User retrieved successfully",
|
||||
@@ -284,7 +378,7 @@ def revoke_session(session_id):
|
||||
401: Not authenticated
|
||||
404: Session not found
|
||||
"""
|
||||
from gatehouse_app.models.session import Session
|
||||
from gatehouse_app.models.user.session import Session
|
||||
|
||||
# Ensure session belongs to current user
|
||||
user_session = Session.query.filter_by(
|
||||
@@ -392,6 +486,7 @@ def verify_totp_enrollment():
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/totp/verify", methods=["POST"])
|
||||
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_TOTP_VERIFY"])
|
||||
def verify_totp():
|
||||
"""
|
||||
Verify TOTP code during login.
|
||||
@@ -424,7 +519,7 @@ def verify_totp():
|
||||
)
|
||||
|
||||
# Get user from database
|
||||
from gatehouse_app.models.user import User
|
||||
from gatehouse_app.models.user.user import User
|
||||
user = User.query.get(user_id)
|
||||
if not user:
|
||||
return api_response(
|
||||
@@ -434,6 +529,18 @@ def verify_totp():
|
||||
error_type="AUTHENTICATION_ERROR",
|
||||
)
|
||||
|
||||
# Check account suspension before completing TOTP verification
|
||||
from gatehouse_app.utils.constants import UserStatus
|
||||
if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
|
||||
session.pop("totp_pending_user_id", None)
|
||||
session.pop("webauthn_pending_user_id", None)
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Account is suspended. Contact an administrator.",
|
||||
status=403,
|
||||
error_type="ACCOUNT_SUSPENDED",
|
||||
)
|
||||
|
||||
# Verify TOTP code
|
||||
AuthService.authenticate_with_totp(
|
||||
user,
|
||||
@@ -475,7 +582,9 @@ def verify_totp():
|
||||
"organization_id": org.organization_id,
|
||||
"organization_name": org.organization_name,
|
||||
"status": org.status,
|
||||
"effective_mode": org.effective_mode,
|
||||
"deadline_at": org.deadline_at,
|
||||
"applied_at": org.applied_at,
|
||||
}
|
||||
for org in policy_result.compliance_summary.orgs
|
||||
],
|
||||
@@ -806,7 +915,7 @@ def begin_webauthn_login():
|
||||
data = schema.load(request.json)
|
||||
|
||||
# Find user by email
|
||||
from gatehouse_app.models.user import User
|
||||
from gatehouse_app.models.user.user import User
|
||||
user = User.query.filter_by(
|
||||
email=data["email"].lower(),
|
||||
deleted_at=None
|
||||
@@ -820,7 +929,18 @@ def begin_webauthn_login():
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
|
||||
# Check account suspension before proceeding
|
||||
from gatehouse_app.utils.constants import UserStatus
|
||||
if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
|
||||
logger.warning(f"WebAuthn login begin - suspended account attempt: {user.email}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Account is suspended. Contact an administrator.",
|
||||
status=403,
|
||||
error_type="ACCOUNT_SUSPENDED",
|
||||
)
|
||||
|
||||
# Check if user has any WebAuthn credentials
|
||||
if not user.has_webauthn_enabled():
|
||||
logger.warning(f"WebAuthn login begin - no credentials for user: {user.email}")
|
||||
@@ -893,7 +1013,7 @@ def complete_webauthn_login():
|
||||
data = schema.load(request.json)
|
||||
|
||||
# Get user from database
|
||||
from gatehouse_app.models.user import User
|
||||
from gatehouse_app.models.user.user import User
|
||||
user = User.query.get(user_id)
|
||||
if not user:
|
||||
logger.error(f"WebAuthn login complete - user not found: {user_id}")
|
||||
@@ -903,7 +1023,19 @@ def complete_webauthn_login():
|
||||
status=401,
|
||||
error_type="AUTHENTICATION_ERROR",
|
||||
)
|
||||
|
||||
|
||||
# Check account suspension before completing login
|
||||
from gatehouse_app.utils.constants import UserStatus
|
||||
if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
|
||||
session.pop("webauthn_pending_user_id", None)
|
||||
logger.warning(f"WebAuthn login complete - suspended account attempt: {user.email}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Account is suspended. Contact an administrator.",
|
||||
status=403,
|
||||
error_type="ACCOUNT_SUSPENDED",
|
||||
)
|
||||
|
||||
# Extract challenge from client data
|
||||
client_data = data.get("response", {}).get("clientDataJSON", "")
|
||||
|
||||
@@ -962,7 +1094,9 @@ def complete_webauthn_login():
|
||||
"organization_id": org.organization_id,
|
||||
"organization_name": org.organization_name,
|
||||
"status": org.status,
|
||||
"effective_mode": org.effective_mode,
|
||||
"deadline_at": org.deadline_at,
|
||||
"applied_at": org.applied_at,
|
||||
}
|
||||
for org in policy_result.compliance_summary.orgs
|
||||
],
|
||||
@@ -1039,6 +1173,19 @@ def delete_webauthn_credential(credential_id):
|
||||
"""
|
||||
user = g.current_user
|
||||
|
||||
# First check that the specific credential actually belongs to this user.
|
||||
# Only then check whether it is the last one — otherwise a user with zero
|
||||
# credentials gets a misleading "Cannot delete the last passkey" error
|
||||
# instead of a 404.
|
||||
credential_exists = WebAuthnService.credential_belongs_to_user(credential_id, user)
|
||||
if not credential_exists:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Credential not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Check if this is the last credential
|
||||
credential_count = user.get_webauthn_credential_count()
|
||||
if credential_count <= 1:
|
||||
@@ -1142,3 +1289,403 @@ def get_webauthn_status():
|
||||
},
|
||||
message="WebAuthn status retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
_pw_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/forgot-password", methods=["POST"])
|
||||
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_FORGOT_PASSWORD"])
|
||||
def forgot_password():
|
||||
"""Request a password reset email.
|
||||
|
||||
Always returns 200 to avoid leaking account existence.
|
||||
|
||||
Request body:
|
||||
email: User email address
|
||||
|
||||
Returns:
|
||||
200: Password reset email sent (or silently no-op if email not found)
|
||||
"""
|
||||
from gatehouse_app.models import User, PasswordResetToken
|
||||
|
||||
data = request.get_json() or {}
|
||||
email = (data.get("email") or "").strip().lower()
|
||||
|
||||
if not email:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Email is required",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
)
|
||||
|
||||
# Always return 200 — don't leak whether the email exists
|
||||
user = User.query.filter_by(email=email, deleted_at=None).first()
|
||||
if user:
|
||||
try:
|
||||
reset_token = PasswordResetToken.generate(user_id=user.id)
|
||||
app_url = current_app.config.get("APP_URL", "http://localhost:8080")
|
||||
reset_link = f"{app_url}/reset-password?token={reset_token.token}"
|
||||
subject = "Reset your Gatehouse password"
|
||||
body = (
|
||||
f"Hi {user.full_name or user.email},\n\n"
|
||||
f"You requested a password reset for your Gatehouse account.\n\n"
|
||||
f"Click the link below to reset your password (valid for 2 hours):\n"
|
||||
f"{reset_link}\n\n"
|
||||
f"If you did not request this, you can safely ignore this email.\n\n"
|
||||
f"Gatehouse Security Team"
|
||||
)
|
||||
NotificationService._send_email(
|
||||
to_address=user.email,
|
||||
subject=subject,
|
||||
body=body,
|
||||
)
|
||||
_pw_logger.info(f"Password reset token generated for user {user.id}")
|
||||
except Exception as exc:
|
||||
_pw_logger.exception(f"Error generating password reset token: {exc}")
|
||||
|
||||
return api_response(
|
||||
data={},
|
||||
message="If an account exists for this email, you will receive a password reset link shortly.",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/reset-password", methods=["POST"])
|
||||
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_RESET_PASSWORD"])
|
||||
def reset_password():
|
||||
"""Reset a user's password using a reset token.
|
||||
|
||||
Request body:
|
||||
token: Password reset token from email
|
||||
password: New password
|
||||
password_confirm: Password confirmation
|
||||
|
||||
Returns:
|
||||
200: Password reset successfully
|
||||
400: Invalid or expired token / validation error
|
||||
"""
|
||||
import bcrypt as _bcrypt
|
||||
from gatehouse_app.extensions import bcrypt
|
||||
from gatehouse_app.models import PasswordResetToken, AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
data = request.get_json() or {}
|
||||
token_value = (data.get("token") or "").strip()
|
||||
new_password = data.get("password") or ""
|
||||
password_confirm = data.get("password_confirm") or ""
|
||||
|
||||
if not token_value or not new_password:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Token and new password are required",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
)
|
||||
|
||||
if new_password != password_confirm:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Passwords do not match",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
)
|
||||
|
||||
if len(new_password) < 8:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Password must be at least 8 characters",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
)
|
||||
|
||||
reset_token = PasswordResetToken.query.filter_by(token=token_value).first()
|
||||
if not reset_token or not reset_token.is_valid:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="This password reset link is invalid or has expired.",
|
||||
status=400,
|
||||
error_type="INVALID_TOKEN",
|
||||
)
|
||||
|
||||
try:
|
||||
user = reset_token.user
|
||||
# Update the password hash on the authentication method
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
user_id=user.id,
|
||||
method_type=AuthMethodType.PASSWORD,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
if auth_method:
|
||||
auth_method.password_hash = bcrypt.generate_password_hash(new_password).decode("utf-8")
|
||||
from gatehouse_app.extensions import db
|
||||
db.session.add(auth_method)
|
||||
|
||||
reset_token.consume()
|
||||
_pw_logger.info(f"Password reset for user {user.id}")
|
||||
|
||||
return api_response(
|
||||
data={},
|
||||
message="Your password has been reset. You can now sign in with your new password.",
|
||||
)
|
||||
except Exception as exc:
|
||||
_pw_logger.exception(f"Error resetting password: {exc}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred while resetting your password.",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/verify-email", methods=["POST"])
|
||||
def verify_email():
|
||||
"""Verify a user's email address using a verification token.
|
||||
|
||||
Request body:
|
||||
token: Email verification token
|
||||
|
||||
Returns:
|
||||
200: Email verified successfully
|
||||
400: Invalid or expired token
|
||||
"""
|
||||
from gatehouse_app.models import EmailVerificationToken
|
||||
|
||||
data = request.get_json() or {}
|
||||
token_value = (data.get("token") or "").strip()
|
||||
|
||||
if not token_value:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Verification token is required",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
)
|
||||
|
||||
verify_token = EmailVerificationToken.query.filter_by(token=token_value).first()
|
||||
if not verify_token or not verify_token.is_valid:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="This verification link is invalid or has expired.",
|
||||
status=400,
|
||||
error_type="INVALID_TOKEN",
|
||||
)
|
||||
|
||||
try:
|
||||
user = verify_token.user
|
||||
user.email_verified = True
|
||||
from gatehouse_app.extensions import db
|
||||
db.session.add(user)
|
||||
verify_token.consume()
|
||||
_pw_logger.info(f"Email verified for user {user.id}")
|
||||
|
||||
return api_response(
|
||||
data={},
|
||||
message="Your email has been verified. You can now sign in.",
|
||||
)
|
||||
except Exception as exc:
|
||||
_pw_logger.exception(f"Error verifying email: {exc}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred while verifying your email.",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/resend-verification", methods=["POST"])
|
||||
def resend_verification():
|
||||
"""Resend email verification link.
|
||||
|
||||
Always returns 200 to avoid leaking account existence.
|
||||
|
||||
Request body:
|
||||
email: User email address
|
||||
|
||||
Returns:
|
||||
200: Verification email sent (or silently no-op)
|
||||
"""
|
||||
from gatehouse_app.models import User, EmailVerificationToken
|
||||
|
||||
data = request.get_json() or {}
|
||||
email = (data.get("email") or "").strip().lower()
|
||||
|
||||
if not email:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Email is required",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
)
|
||||
|
||||
user = User.query.filter_by(email=email, deleted_at=None).first()
|
||||
if user and not user.email_verified:
|
||||
try:
|
||||
verify_token = EmailVerificationToken.generate(user_id=user.id)
|
||||
app_url = current_app.config.get("APP_URL", "http://localhost:8080")
|
||||
verify_link = f"{app_url}/verify-email?token={verify_token.token}"
|
||||
subject = "Verify your Gatehouse email address"
|
||||
body = (
|
||||
f"Hi {user.full_name or user.email},\n\n"
|
||||
f"Please verify your email address by clicking the link below (valid for 24 hours):\n"
|
||||
f"{verify_link}\n\n"
|
||||
f"Gatehouse Security Team"
|
||||
)
|
||||
NotificationService._send_email(
|
||||
to_address=user.email,
|
||||
subject=subject,
|
||||
body=body,
|
||||
)
|
||||
_pw_logger.info(f"Verification email sent for user {user.id}")
|
||||
except Exception as exc:
|
||||
_pw_logger.exception(f"Error sending verification email: {exc}")
|
||||
|
||||
return api_response(
|
||||
data={},
|
||||
message="If an account exists for this email and is not yet verified, you will receive a verification link shortly.",
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Account Activation (separate from email-verification)
|
||||
# =============================================================================
|
||||
|
||||
@api_v1_bp.route("/auth/activate", methods=["POST"])
|
||||
def activate_account():
|
||||
"""Activate a user account via a one-time activation code.
|
||||
|
||||
Request body:
|
||||
code – the activation_key from the welcome email
|
||||
|
||||
Returns:
|
||||
200: Account activated, session token returned
|
||||
400: Missing code
|
||||
404: Invalid or already-used code
|
||||
"""
|
||||
import secrets
|
||||
from gatehouse_app.models.user.user import User
|
||||
from gatehouse_app.extensions import db
|
||||
|
||||
data = request.get_json() or {}
|
||||
code = (data.get("code") or "").strip()
|
||||
if not code:
|
||||
return api_response(success=False, message="Activation code is required", status=400, error_type="VALIDATION_ERROR")
|
||||
|
||||
user = User.query.filter_by(activation_key=code, deleted_at=None).first()
|
||||
if not user:
|
||||
return api_response(success=False, message="Invalid or expired activation code", status=404, error_type="NOT_FOUND")
|
||||
|
||||
user.activated = True
|
||||
user.activation_key = None # one-time use
|
||||
db.session.add(user)
|
||||
db.session.commit()
|
||||
|
||||
user_session = AuthService.create_session(user)
|
||||
_pw_logger.info(f"Account activated for user {user.id}")
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"user": user.to_dict(),
|
||||
"token": user_session.token,
|
||||
"expires_at": user_session.expires_at.isoformat() + "Z"
|
||||
if user_session.expires_at.isoformat()[-1] != "Z"
|
||||
else user_session.expires_at.isoformat(),
|
||||
},
|
||||
message="Account activated successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/resend-activation", methods=["POST"])
|
||||
def resend_activation():
|
||||
"""Re-send an account activation email.
|
||||
|
||||
Always returns 200 to avoid leaking whether an account exists.
|
||||
|
||||
Request body:
|
||||
email – user email address
|
||||
"""
|
||||
import secrets
|
||||
from gatehouse_app.models.user.user import User
|
||||
from gatehouse_app.extensions import db
|
||||
|
||||
data = request.get_json() or {}
|
||||
email = (data.get("email") or "").strip().lower()
|
||||
if not email:
|
||||
return api_response(success=False, message="Email is required", status=400, error_type="VALIDATION_ERROR")
|
||||
|
||||
user = User.query.filter_by(email=email, deleted_at=None).first()
|
||||
if user and not user.activated:
|
||||
try:
|
||||
code = secrets.token_urlsafe(32)
|
||||
user.activation_key = code
|
||||
db.session.add(user)
|
||||
db.session.commit()
|
||||
|
||||
app_url = current_app.config.get("APP_URL", current_app.config.get("FRONTEND_URL", "http://localhost:8080"))
|
||||
activate_link = f"{app_url}/activate?code={code}"
|
||||
subject = "Activate your Gatehouse account"
|
||||
body = (
|
||||
f"Hi {user.full_name or user.email},\n\n"
|
||||
f"Please activate your Gatehouse account by clicking the link below:\n"
|
||||
f"{activate_link}\n\n"
|
||||
f"If you did not create an account, you can safely ignore this email.\n\n"
|
||||
f"Gatehouse Security Team"
|
||||
)
|
||||
NotificationService._send_email(to_address=user.email, subject=subject, body=body)
|
||||
_pw_logger.info(f"Activation email re-sent to {user.id}")
|
||||
except Exception as exc:
|
||||
_pw_logger.exception(f"Error re-sending activation email: {exc}")
|
||||
|
||||
return api_response(
|
||||
data={},
|
||||
message="If an unactivated account exists for this email, you will receive a new activation link shortly.",
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Token retrieval / redirect (for CLI / external tools)
|
||||
# =============================================================================
|
||||
|
||||
@api_v1_bp.route("/auth/token", methods=["GET"])
|
||||
@login_required
|
||||
def get_token():
|
||||
"""Return the current session token, optionally redirecting to a URL.
|
||||
|
||||
Query parameters:
|
||||
redirect – optional URL to redirect to with the token appended as
|
||||
a query param: ``<redirect>?token=<token>``
|
||||
|
||||
Returns:
|
||||
200: JSON ``{"token": "<token>"}`` (no redirect given)
|
||||
302: Redirect to ``<redirect>?token=<token>``
|
||||
"""
|
||||
from flask import redirect as flask_redirect
|
||||
from urllib.parse import urlparse
|
||||
|
||||
token = g.current_session.token
|
||||
redirect_url = request.args.get("redirect", "").strip()
|
||||
|
||||
if redirect_url:
|
||||
# Validate redirect URL against allowed origins to prevent open-redirect
|
||||
# token exfiltration attacks (CWE-601).
|
||||
allowed_origins = set(current_app.config.get("CORS_ORIGINS", []))
|
||||
frontend_url = current_app.config.get("FRONTEND_URL", "")
|
||||
if frontend_url:
|
||||
parsed = urlparse(frontend_url)
|
||||
allowed_origins.add(f"{parsed.scheme}://{parsed.netloc}")
|
||||
|
||||
parsed_redirect = urlparse(redirect_url)
|
||||
redirect_origin = f"{parsed_redirect.scheme}://{parsed_redirect.netloc}"
|
||||
|
||||
if redirect_origin not in allowed_origins:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Redirect URL is not allowed.",
|
||||
status=400,
|
||||
error_type="INVALID_REDIRECT",
|
||||
)
|
||||
|
||||
sep = "&" if "?" in redirect_url else "?"
|
||||
return flask_redirect(f"{redirect_url}{sep}token={token}", code=302)
|
||||
|
||||
return api_response(data={"token": token}, message="Token retrieved")
|
||||
|
||||
@@ -0,0 +1,699 @@
|
||||
"""Department endpoints."""
|
||||
from flask import g, request
|
||||
from marshmallow import Schema, fields, validate, ValidationError
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from gatehouse_app.api.v1 import api_v1_bp
|
||||
from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required
|
||||
from gatehouse_app.models import Department, DepartmentMembership
|
||||
from gatehouse_app.services.organization_service import OrganizationService
|
||||
from gatehouse_app.services.user_service import UserService
|
||||
from gatehouse_app.extensions import db
|
||||
|
||||
|
||||
class DepartmentCreateSchema(Schema):
|
||||
"""Schema for creating a department."""
|
||||
name = fields.Str(required=True, validate=validate.Length(min=1, max=255))
|
||||
description = fields.Str(allow_none=True, validate=validate.Length(max=2000))
|
||||
|
||||
|
||||
class DepartmentUpdateSchema(Schema):
|
||||
"""Schema for updating a department."""
|
||||
name = fields.Str(validate=validate.Length(min=1, max=255))
|
||||
description = fields.Str(allow_none=True, validate=validate.Length(max=2000))
|
||||
|
||||
|
||||
class AddDepartmentMemberSchema(Schema):
|
||||
"""Schema for adding a member to a department."""
|
||||
email = fields.Email(required=True)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/departments", methods=["GET"])
|
||||
@login_required
|
||||
@full_access_required
|
||||
def list_departments(org_id):
|
||||
"""
|
||||
List all departments in an organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
|
||||
Returns:
|
||||
200: List of departments
|
||||
401: Not authenticated
|
||||
403: Not a member
|
||||
404: Organization not found
|
||||
"""
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
# Check if user is a member
|
||||
if not org.is_member(g.current_user.id):
|
||||
return api_response(
|
||||
success=False,
|
||||
message="You are not a member of this organization",
|
||||
status=403,
|
||||
error_type="AUTHORIZATION_ERROR",
|
||||
)
|
||||
|
||||
departments = Department.query.filter_by(
|
||||
organization_id=org_id,
|
||||
deleted_at=None
|
||||
).all()
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"departments": [d.to_dict() for d in departments],
|
||||
"count": len(departments),
|
||||
},
|
||||
message="Departments retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/departments", methods=["POST"])
|
||||
@login_required
|
||||
@require_admin
|
||||
@full_access_required
|
||||
def create_department(org_id):
|
||||
"""
|
||||
Create a new department.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
|
||||
Request body:
|
||||
name: Department name (required)
|
||||
description: Optional description
|
||||
|
||||
Returns:
|
||||
201: Department created successfully
|
||||
400: Validation error
|
||||
401: Not authenticated
|
||||
403: Not an admin
|
||||
404: Organization not found
|
||||
409: Department name already exists in org
|
||||
"""
|
||||
try:
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
schema = DepartmentCreateSchema()
|
||||
data = schema.load(request.json or {})
|
||||
|
||||
# Check if department name already exists
|
||||
existing = Department.query.filter_by(
|
||||
organization_id=org_id,
|
||||
name=data["name"],
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=f"Department '{data['name']}' already exists in this organization",
|
||||
status=409,
|
||||
error_type="CONFLICT",
|
||||
)
|
||||
|
||||
# Create department
|
||||
dept = Department(
|
||||
organization_id=org_id,
|
||||
name=data["name"],
|
||||
description=data.get("description"),
|
||||
)
|
||||
db.session.add(dept)
|
||||
db.session.commit()
|
||||
|
||||
return api_response(
|
||||
data={"department": dept.to_dict()},
|
||||
message="Department created successfully",
|
||||
status=201,
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/departments/<dept_id>", methods=["GET"])
|
||||
@login_required
|
||||
@full_access_required
|
||||
def get_department(org_id, dept_id):
|
||||
"""
|
||||
Get a specific department.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
dept_id: Department ID
|
||||
|
||||
Returns:
|
||||
200: Department data
|
||||
401: Not authenticated
|
||||
403: Not a member
|
||||
404: Organization or department not found
|
||||
"""
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
if not org.is_member(g.current_user.id):
|
||||
return api_response(
|
||||
success=False,
|
||||
message="You are not a member of this organization",
|
||||
status=403,
|
||||
error_type="AUTHORIZATION_ERROR",
|
||||
)
|
||||
|
||||
dept = Department.query.filter_by(
|
||||
id=dept_id,
|
||||
organization_id=org_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not dept:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Department not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data={"department": dept.to_dict()},
|
||||
message="Department retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/departments/<dept_id>", methods=["PATCH"])
|
||||
@login_required
|
||||
@require_admin
|
||||
@full_access_required
|
||||
def update_department(org_id, dept_id):
|
||||
"""
|
||||
Update a department.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
dept_id: Department ID
|
||||
|
||||
Request body:
|
||||
name: Optional new name
|
||||
description: Optional new description
|
||||
|
||||
Returns:
|
||||
200: Department updated successfully
|
||||
400: Validation error
|
||||
401: Not authenticated
|
||||
403: Not an admin
|
||||
404: Organization or department not found
|
||||
409: Name already exists
|
||||
"""
|
||||
try:
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
dept = Department.query.filter_by(
|
||||
id=dept_id,
|
||||
organization_id=org_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not dept:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Department not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
schema = DepartmentUpdateSchema()
|
||||
data = schema.load(request.json or {})
|
||||
|
||||
# Check if new name already exists
|
||||
if "name" in data and data["name"] != dept.name:
|
||||
existing = Department.query.filter_by(
|
||||
organization_id=org_id,
|
||||
name=data["name"],
|
||||
deleted_at=None
|
||||
).first()
|
||||
if existing:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=f"Department '{data['name']}' already exists",
|
||||
status=409,
|
||||
error_type="CONFLICT",
|
||||
)
|
||||
|
||||
# Update fields
|
||||
for key, value in data.items():
|
||||
setattr(dept, key, value)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return api_response(
|
||||
data={"department": dept.to_dict()},
|
||||
message="Department updated successfully",
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/departments/<dept_id>", methods=["DELETE"])
|
||||
@login_required
|
||||
@require_admin
|
||||
@full_access_required
|
||||
def delete_department(org_id, dept_id):
|
||||
"""
|
||||
Delete a department (soft delete).
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
dept_id: Department ID
|
||||
|
||||
Returns:
|
||||
200: Department deleted successfully
|
||||
401: Not authenticated
|
||||
403: Not an admin
|
||||
404: Organization or department not found
|
||||
"""
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
dept = Department.query.filter_by(
|
||||
id=dept_id,
|
||||
organization_id=org_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not dept:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Department not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Soft delete
|
||||
dept.deleted_at = db.func.now()
|
||||
db.session.commit()
|
||||
|
||||
return api_response(
|
||||
message="Department deleted successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/departments/<dept_id>/members", methods=["GET"])
|
||||
@login_required
|
||||
@full_access_required
|
||||
def get_department_members(org_id, dept_id):
|
||||
"""
|
||||
Get all members of a department.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
dept_id: Department ID
|
||||
|
||||
Returns:
|
||||
200: List of members
|
||||
401: Not authenticated
|
||||
403: Not a member
|
||||
404: Organization or department not found
|
||||
"""
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
if not org.is_member(g.current_user.id):
|
||||
return api_response(
|
||||
success=False,
|
||||
message="You are not a member of this organization",
|
||||
status=403,
|
||||
error_type="AUTHORIZATION_ERROR",
|
||||
)
|
||||
|
||||
dept = Department.query.filter_by(
|
||||
id=dept_id,
|
||||
organization_id=org_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not dept:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Department not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
members = DepartmentMembership.query.filter_by(
|
||||
department_id=dept_id,
|
||||
deleted_at=None
|
||||
).all()
|
||||
|
||||
members_data = []
|
||||
for member in members:
|
||||
member_dict = member.to_dict()
|
||||
member_dict["user"] = member.user.to_dict()
|
||||
members_data.append(member_dict)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"members": members_data,
|
||||
"count": len(members_data),
|
||||
},
|
||||
message="Members retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/departments/<dept_id>/members", methods=["POST"])
|
||||
@login_required
|
||||
@require_admin
|
||||
@full_access_required
|
||||
def add_department_member(org_id, dept_id):
|
||||
"""
|
||||
Add a member to a department.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
dept_id: Department ID
|
||||
|
||||
Request body:
|
||||
email: User email to add
|
||||
|
||||
Returns:
|
||||
201: Member added successfully
|
||||
400: Validation error
|
||||
401: Not authenticated
|
||||
403: Not an admin
|
||||
404: Organization, department, or user not found
|
||||
409: User already a member
|
||||
"""
|
||||
try:
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
dept = Department.query.filter_by(
|
||||
id=dept_id,
|
||||
organization_id=org_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not dept:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Department not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
schema = AddDepartmentMemberSchema()
|
||||
data = schema.load(request.json or {})
|
||||
|
||||
# Find user by email
|
||||
user = UserService.get_user_by_email(data["email"])
|
||||
if not user:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="User not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Check if already an active member
|
||||
existing = DepartmentMembership.query.filter_by(
|
||||
user_id=user.id,
|
||||
department_id=dept_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="User is already a member of this department",
|
||||
status=409,
|
||||
error_type="CONFLICT",
|
||||
)
|
||||
|
||||
# Check for a previously soft-deleted row and resurrect it instead of inserting
|
||||
soft_deleted = DepartmentMembership.query.filter(
|
||||
DepartmentMembership.user_id == user.id,
|
||||
DepartmentMembership.department_id == dept_id,
|
||||
DepartmentMembership.deleted_at.isnot(None)
|
||||
).first()
|
||||
|
||||
if soft_deleted:
|
||||
soft_deleted.deleted_at = None
|
||||
membership = soft_deleted
|
||||
else:
|
||||
membership = DepartmentMembership(
|
||||
user_id=user.id,
|
||||
department_id=dept_id,
|
||||
)
|
||||
db.session.add(membership)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
member_dict = membership.to_dict()
|
||||
member_dict["user"] = user.to_dict()
|
||||
|
||||
return api_response(
|
||||
data={"member": member_dict},
|
||||
message="Member added successfully",
|
||||
status=201,
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/departments/<dept_id>/members/<user_id>", methods=["DELETE"])
|
||||
@login_required
|
||||
@require_admin
|
||||
@full_access_required
|
||||
def remove_department_member(org_id, dept_id, user_id):
|
||||
"""
|
||||
Remove a member from a department.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
dept_id: Department ID
|
||||
user_id: User ID to remove
|
||||
|
||||
Returns:
|
||||
200: Member removed successfully
|
||||
401: Not authenticated
|
||||
403: Not an admin
|
||||
404: Organization, department, or member not found
|
||||
"""
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
dept = Department.query.filter_by(
|
||||
id=dept_id,
|
||||
organization_id=org_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not dept:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Department not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
membership = DepartmentMembership.query.filter_by(
|
||||
user_id=user_id,
|
||||
department_id=dept_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not membership:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="User is not a member of this department",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Soft delete
|
||||
membership.deleted_at = db.func.now()
|
||||
db.session.commit()
|
||||
|
||||
return api_response(
|
||||
message="Member removed successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/departments/<dept_id>/principals", methods=["GET"])
|
||||
@login_required
|
||||
@full_access_required
|
||||
def get_department_principals(org_id, dept_id):
|
||||
"""Get all principals linked to a department."""
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
if not org.is_member(g.current_user.id):
|
||||
return api_response(
|
||||
success=False,
|
||||
message="You are not a member of this organization",
|
||||
status=403,
|
||||
error_type="AUTHORIZATION_ERROR",
|
||||
)
|
||||
|
||||
dept = Department.query.filter_by(
|
||||
id=dept_id,
|
||||
organization_id=org_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not dept:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Department not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
principals = dept.get_principals(active_only=True)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"principals": [p.to_dict() for p in principals],
|
||||
"count": len(principals),
|
||||
},
|
||||
message="Principals retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Department Certificate Policy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/departments/<dept_id>/cert-policy", methods=["GET"])
|
||||
@login_required
|
||||
@require_admin
|
||||
@full_access_required
|
||||
def get_dept_cert_policy(org_id, dept_id):
|
||||
"""Get the certificate issuance policy for a department (admin only)."""
|
||||
from gatehouse_app.models.organization.department_cert_policy import DepartmentCertPolicy, STANDARD_EXTENSIONS
|
||||
|
||||
dept = Department.query.filter_by(
|
||||
id=dept_id, organization_id=org_id, deleted_at=None
|
||||
).first()
|
||||
if not dept:
|
||||
return api_response(success=False, message="Department not found", status=404, error_type="NOT_FOUND")
|
||||
|
||||
policy = DepartmentCertPolicy.query.filter(
|
||||
DepartmentCertPolicy.department_id == dept_id,
|
||||
DepartmentCertPolicy.deleted_at.is_(None),
|
||||
).first()
|
||||
|
||||
if policy:
|
||||
data = policy.to_dict()
|
||||
else:
|
||||
# Return default (all standard extensions, no user expiry choice)
|
||||
data = {
|
||||
"department_id": str(dept_id),
|
||||
"allow_user_expiry": False,
|
||||
"default_expiry_hours": 1,
|
||||
"max_expiry_hours": 24,
|
||||
"allowed_extensions": list(STANDARD_EXTENSIONS),
|
||||
"custom_extensions": [],
|
||||
"all_extensions": list(STANDARD_EXTENSIONS),
|
||||
"standard_extensions": list(STANDARD_EXTENSIONS),
|
||||
}
|
||||
|
||||
return api_response(data={"cert_policy": data}, message="Certificate policy retrieved")
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/departments/<dept_id>/cert-policy", methods=["PUT"])
|
||||
@login_required
|
||||
@require_admin
|
||||
@full_access_required
|
||||
def set_dept_cert_policy(org_id, dept_id):
|
||||
"""Create or update the certificate issuance policy for a department (admin only)."""
|
||||
from gatehouse_app.models.organization.department_cert_policy import DepartmentCertPolicy, STANDARD_EXTENSIONS
|
||||
|
||||
dept = Department.query.filter_by(
|
||||
id=dept_id, organization_id=org_id, deleted_at=None
|
||||
).first()
|
||||
if not dept:
|
||||
return api_response(success=False, message="Department not found", status=404, error_type="NOT_FOUND")
|
||||
|
||||
body = request.get_json() or {}
|
||||
|
||||
# Validate expiry values
|
||||
default_expiry = body.get("default_expiry_hours")
|
||||
max_expiry = body.get("max_expiry_hours")
|
||||
if default_expiry is not None:
|
||||
try:
|
||||
default_expiry = int(default_expiry)
|
||||
if default_expiry < 1:
|
||||
raise ValueError
|
||||
except (ValueError, TypeError):
|
||||
return api_response(success=False, message="default_expiry_hours must be a positive integer", status=400, error_type="VALIDATION_ERROR")
|
||||
if max_expiry is not None:
|
||||
try:
|
||||
max_expiry = int(max_expiry)
|
||||
if max_expiry < 1:
|
||||
raise ValueError
|
||||
except (ValueError, TypeError):
|
||||
return api_response(success=False, message="max_expiry_hours must be a positive integer", status=400, error_type="VALIDATION_ERROR")
|
||||
if default_expiry and max_expiry and default_expiry > max_expiry:
|
||||
return api_response(success=False, message="default_expiry_hours cannot exceed max_expiry_hours", status=400, error_type="VALIDATION_ERROR")
|
||||
|
||||
# Validate allowed_extensions — must be subset of STANDARD_EXTENSIONS
|
||||
allowed_extensions = body.get("allowed_extensions")
|
||||
if allowed_extensions is not None:
|
||||
if not isinstance(allowed_extensions, list):
|
||||
return api_response(success=False, message="allowed_extensions must be a list", status=400, error_type="VALIDATION_ERROR")
|
||||
invalid_ext = [e for e in allowed_extensions if e not in STANDARD_EXTENSIONS]
|
||||
if invalid_ext:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=f"Invalid standard extensions: {', '.join(invalid_ext)}. Valid: {', '.join(STANDARD_EXTENSIONS)}",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
)
|
||||
|
||||
# Validate custom_extensions — plain strings
|
||||
custom_extensions = body.get("custom_extensions")
|
||||
if custom_extensions is not None:
|
||||
if not isinstance(custom_extensions, list) or not all(isinstance(e, str) for e in custom_extensions):
|
||||
return api_response(success=False, message="custom_extensions must be a list of strings", status=400, error_type="VALIDATION_ERROR")
|
||||
|
||||
policy = DepartmentCertPolicy.query.filter(
|
||||
DepartmentCertPolicy.department_id == dept_id,
|
||||
DepartmentCertPolicy.deleted_at.is_(None),
|
||||
).first()
|
||||
|
||||
if policy is None:
|
||||
policy = DepartmentCertPolicy(department_id=dept_id)
|
||||
db.session.add(policy)
|
||||
|
||||
if "allow_user_expiry" in body:
|
||||
policy.allow_user_expiry = bool(body["allow_user_expiry"])
|
||||
if default_expiry is not None:
|
||||
policy.default_expiry_hours = default_expiry
|
||||
if max_expiry is not None:
|
||||
policy.max_expiry_hours = max_expiry
|
||||
if allowed_extensions is not None:
|
||||
policy.allowed_extensions = list(allowed_extensions)
|
||||
flag_modified(policy, "allowed_extensions")
|
||||
if custom_extensions is not None:
|
||||
policy.custom_extensions = list(custom_extensions)
|
||||
flag_modified(policy, "custom_extensions")
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return api_response(data={"cert_policy": policy.to_dict()}, message="Certificate policy saved")
|
||||
|
||||
@@ -46,6 +46,33 @@ def _pop_oidc_bridge(oauth_state: str) -> str | None:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _store_cli_redirect(oauth_state: str, redirect_url: str) -> None:
|
||||
"""Store CLI redirect_url keyed by OAuth state (for /token_please flow)."""
|
||||
try:
|
||||
import gatehouse_app.extensions as _ext
|
||||
rc = _ext.redis_client
|
||||
if rc is not None:
|
||||
rc.setex(f"oauth_cli_redirect:{oauth_state}", _OAUTH_BRIDGE_TTL, redirect_url)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _pop_cli_redirect(oauth_state: str) -> str | None:
|
||||
"""Retrieve and delete CLI redirect_url for the given OAuth state."""
|
||||
try:
|
||||
import gatehouse_app.extensions as _ext
|
||||
rc = _ext.redis_client
|
||||
if rc is not None:
|
||||
key = f"oauth_cli_redirect:{oauth_state}"
|
||||
val = rc.get(key)
|
||||
if val:
|
||||
rc.delete(key)
|
||||
return val.decode() if isinstance(val, bytes) else val
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -69,6 +96,137 @@ def get_provider_type(provider: str) -> AuthMethodType:
|
||||
return PROVIDER_TYPE_MAP[provider_lower]
|
||||
|
||||
|
||||
@api_v1_bp.route("/token_please", methods=["GET"])
|
||||
def token_please():
|
||||
"""
|
||||
CLI token acquisition endpoint.
|
||||
|
||||
Redirects the user's browser to the Gatehouse login page so they can
|
||||
authenticate using any method (password, OAuth, passkey, TOTP, etc.).
|
||||
On successful login the frontend delivers the session token directly to
|
||||
the CLI's local callback server.
|
||||
|
||||
This endpoint is designed for CLI clients that:
|
||||
1. Start a local HTTP server on LISTENER_SERVER_PORT (e.g. 8250)
|
||||
2. Open a browser to /api/v1/token_please?redirect_url=http://127.0.0.1:8250/?token=
|
||||
3. Wait for the browser to deliver the token to their local server
|
||||
|
||||
Query parameters:
|
||||
redirect_url: Local callback URL where the token will be appended
|
||||
"""
|
||||
import secrets
|
||||
from urllib.parse import urlencode, quote
|
||||
from flask import current_app, redirect as flask_redirect
|
||||
|
||||
redirect_url = request.args.get("redirect_url", "").strip()
|
||||
|
||||
if not redirect_url:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="redirect_url query parameter is required",
|
||||
status=400,
|
||||
error_type="MISSING_REDIRECT_URL",
|
||||
)
|
||||
|
||||
# Validate redirect_url is localhost/127.0.0.1 (security: prevent open redirect)
|
||||
from urllib.parse import urlparse as _urlparse
|
||||
parsed = _urlparse(redirect_url)
|
||||
if parsed.hostname not in ("localhost", "127.0.0.1"):
|
||||
return api_response(
|
||||
success=False,
|
||||
message="redirect_url must point to localhost",
|
||||
status=400,
|
||||
error_type="INVALID_REDIRECT_URL",
|
||||
)
|
||||
|
||||
# Store the CLI redirect URL in Redis keyed by a short-lived token so the
|
||||
# frontend can retrieve it after login without it being visible in the URL.
|
||||
cli_token = secrets.token_urlsafe(32)
|
||||
try:
|
||||
import gatehouse_app.extensions as _ext
|
||||
rc = _ext.redis_client
|
||||
if rc is not None:
|
||||
rc.setex(f"cli_redirect:{cli_token}", _OAUTH_BRIDGE_TTL, redirect_url)
|
||||
else:
|
||||
logger.warning("Redis not available; passing cli_redirect directly in URL")
|
||||
cli_token = None
|
||||
except Exception:
|
||||
cli_token = None
|
||||
|
||||
frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:8080")
|
||||
|
||||
if cli_token:
|
||||
# Pass an opaque token; the frontend exchanges it for the real URL via
|
||||
# GET /api/v1/cli/redirect-url?token=<cli_token>
|
||||
login_url = f"{frontend_url}/login?cli_token={cli_token}"
|
||||
else:
|
||||
# Fallback: put the redirect URL directly (still localhost-only, validated above)
|
||||
login_url = f"{frontend_url}/login?cli_redirect={quote(redirect_url, safe='')}"
|
||||
|
||||
logger.info(f"CLI token_please: redirecting browser to Gatehouse login page")
|
||||
return flask_redirect(login_url, code=302)
|
||||
|
||||
|
||||
@api_v1_bp.route("/cli/redirect-url", methods=["GET"])
|
||||
def cli_redirect_url_lookup():
|
||||
"""
|
||||
Exchange a short-lived cli_token for the CLI's local redirect URL.
|
||||
|
||||
Called by the frontend LoginPage after it detects the cli_token query
|
||||
param so it can obtain the actual CLI callback URL from Redis without
|
||||
exposing it in the browser URL bar.
|
||||
|
||||
Query parameters:
|
||||
token: The cli_token issued by /token_please
|
||||
|
||||
Returns:
|
||||
200: { "redirect_url": "http://127.0.0.1:8250/?token=" }
|
||||
400: Missing token
|
||||
404: Token not found or expired
|
||||
"""
|
||||
cli_token = request.args.get("token", "").strip()
|
||||
if not cli_token:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="token query parameter is required",
|
||||
status=400,
|
||||
error_type="MISSING_TOKEN",
|
||||
)
|
||||
|
||||
try:
|
||||
import gatehouse_app.extensions as _ext
|
||||
rc = _ext.redis_client
|
||||
if rc is not None:
|
||||
key = f"cli_redirect:{cli_token}"
|
||||
val = rc.get(key)
|
||||
if val is None:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="CLI token not found or expired",
|
||||
status=404,
|
||||
error_type="TOKEN_NOT_FOUND",
|
||||
)
|
||||
# Keep the key alive until the login actually completes (consume on use
|
||||
# would break multi-step auth like TOTP), so we leave it as-is.
|
||||
redirect_url = val.decode() if isinstance(val, bytes) else val
|
||||
return api_response(data={"redirect_url": redirect_url})
|
||||
except Exception as e:
|
||||
logger.error(f"cli_redirect_url_lookup error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Internal error looking up CLI token",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Redis not available",
|
||||
status=503,
|
||||
error_type="SERVICE_UNAVAILABLE",
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Provider Configuration Endpoints (Admin)
|
||||
# =============================================================================
|
||||
@@ -83,7 +241,7 @@ def list_providers():
|
||||
200: List of providers with their configuration status
|
||||
401: Not authenticated
|
||||
"""
|
||||
from gatehouse_app.models.authentication_method import ApplicationProviderConfig
|
||||
from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig
|
||||
from gatehouse_app.services.external_auth_service import ExternalProviderConfig
|
||||
|
||||
# Check app-level provider configs (ApplicationProviderConfig)
|
||||
@@ -575,8 +733,6 @@ def initiate_oauth_authorize(provider: str):
|
||||
"state": "state_token"
|
||||
}
|
||||
"""
|
||||
provider_type = get_provider_type(provider)
|
||||
|
||||
# Get query parameters - organization_id is now optional
|
||||
flow = request.args.get("flow", "login")
|
||||
redirect_uri = request.args.get("redirect_uri")
|
||||
@@ -592,7 +748,7 @@ def initiate_oauth_authorize(provider: str):
|
||||
)
|
||||
|
||||
try:
|
||||
# Initiate flow - organization_id is now optional
|
||||
provider_type = get_provider_type(provider)
|
||||
if flow == "login":
|
||||
auth_url, state = OAuthFlowService.initiate_login_flow(
|
||||
provider_type=provider_type,
|
||||
@@ -626,6 +782,13 @@ def initiate_oauth_authorize(provider: str):
|
||||
status=e.status_code,
|
||||
error_type=e.error_type,
|
||||
)
|
||||
except ExternalAuthError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=e.message,
|
||||
status=e.status_code,
|
||||
error_type=e.error_type,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/external/<provider>/callback", methods=["GET"])
|
||||
@@ -666,8 +829,19 @@ def handle_oauth_callback(provider: str):
|
||||
frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:8080")
|
||||
frontend_callback = f"{frontend_url}/oauth/callback"
|
||||
|
||||
# Check if this is a CLI /token_please flow — retrieve stored redirect_url
|
||||
cli_redirect_url = _pop_cli_redirect(state) if state else None
|
||||
|
||||
def redirect_error(message: str, error_type: str = "OAUTH_ERROR"):
|
||||
"""Redirect to frontend with error params."""
|
||||
"""Redirect to frontend (or CLI) with error params."""
|
||||
if cli_redirect_url:
|
||||
# CLI flow: return a plain error page instead of redirecting back
|
||||
from flask import make_response
|
||||
return make_response(
|
||||
f"<html><body><h2>Authentication Error</h2><p>{message}</p>"
|
||||
f"<p>You may close this window.</p></body></html>",
|
||||
400,
|
||||
)
|
||||
params = {"error": message, "error_type": error_type}
|
||||
if state:
|
||||
params["state"] = state
|
||||
@@ -706,8 +880,11 @@ def handle_oauth_callback(provider: str):
|
||||
# Recover oidc_session_id if this was triggered from an OIDC bridge flow
|
||||
oidc_session_id = _pop_oidc_bridge(state)
|
||||
|
||||
# Organization selection / creation flows are not supported in CLI mode
|
||||
# (fall through to token redirect with whatever session we have)
|
||||
|
||||
# Organization selection needed (user belongs to multiple orgs)
|
||||
if result.get("requires_org_selection"):
|
||||
if result.get("requires_org_selection") and not cli_redirect_url:
|
||||
import json
|
||||
orgs = json.dumps(result.get("available_organizations", []))
|
||||
params = {
|
||||
@@ -722,12 +899,20 @@ def handle_oauth_callback(provider: str):
|
||||
return flask_redirect(f"{frontend_callback}?{urlencode(params)}", code=302)
|
||||
|
||||
# Organization creation needed (new user via OAuth with no org)
|
||||
if result.get("requires_org_creation"):
|
||||
if result.get("requires_org_creation") and not cli_redirect_url:
|
||||
import json as _json
|
||||
session_data = result.get("session", {})
|
||||
token = session_data.get("token", "")
|
||||
expires_in = session_data.get("expires_in", 86400)
|
||||
pending_invites = result.get("pending_invites", [])
|
||||
params = {
|
||||
"requires_org_creation": "1",
|
||||
"state": result["state"],
|
||||
"provider": provider,
|
||||
"flow": flow_type,
|
||||
"token": token,
|
||||
"expires_in": str(expires_in),
|
||||
"pending_invites": _json.dumps(pending_invites),
|
||||
}
|
||||
if oidc_session_id:
|
||||
params["oidc_session_id"] = oidc_session_id
|
||||
@@ -751,6 +936,19 @@ def handle_oauth_callback(provider: str):
|
||||
user_info = result.get("user", {})
|
||||
if user_info.get("email"):
|
||||
params["email"] = user_info["email"]
|
||||
|
||||
# ── CLI /token_please flow: redirect to the CLI's local callback ─────
|
||||
if cli_redirect_url:
|
||||
# The CLI expects: http://127.0.0.1:8250/?token=<TOKEN>
|
||||
# cli_redirect_url already ends with "token=" so just append the value
|
||||
cli_final_url = cli_redirect_url + token
|
||||
logger.info(
|
||||
f"CLI token_please success: provider={provider}, user={user_info.get('email')}, "
|
||||
f"redirecting to CLI callback"
|
||||
)
|
||||
return flask_redirect(cli_final_url, code=302)
|
||||
|
||||
# ── Frontend flow ─────────────────────────────────────────────────────
|
||||
# Pass oidc_session_id through so the frontend can complete the OIDC flow
|
||||
if oidc_session_id:
|
||||
params["oidc_session_id"] = oidc_session_id
|
||||
@@ -1049,3 +1247,203 @@ def _get_provider_endpoints(provider_type: AuthMethodType):
|
||||
"UNSUPPORTED_PROVIDER",
|
||||
400,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Admin: Application-level OAuth Provider Management
|
||||
# =============================================================================
|
||||
|
||||
@api_v1_bp.route("/admin/oauth/providers", methods=["GET"])
|
||||
@login_required
|
||||
def admin_list_app_providers():
|
||||
"""List all application-level OAuth provider configurations (admin only).
|
||||
|
||||
Returns:
|
||||
200: List of providers with client_id and enabled status
|
||||
401: Not authenticated
|
||||
403: Not an admin
|
||||
"""
|
||||
from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig
|
||||
from gatehouse_app.models import OrganizationMember
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
|
||||
# Verify caller is admin in any org
|
||||
admin_memberships = OrganizationMember.query.filter(
|
||||
OrganizationMember.user_id == g.current_user.id,
|
||||
OrganizationMember.role.in_([OrganizationRole.OWNER, OrganizationRole.ADMIN]),
|
||||
).all()
|
||||
|
||||
if not admin_memberships:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Admin access required",
|
||||
status=403,
|
||||
error_type="FORBIDDEN",
|
||||
)
|
||||
|
||||
PROVIDERS = [
|
||||
{"id": "google", "name": "Google"},
|
||||
{"id": "github", "name": "GitHub"},
|
||||
{"id": "microsoft", "name": "Microsoft"},
|
||||
]
|
||||
|
||||
db_configs = {
|
||||
c.provider_type: c
|
||||
for c in ApplicationProviderConfig.query.all()
|
||||
}
|
||||
|
||||
result = []
|
||||
for p in PROVIDERS:
|
||||
cfg = db_configs.get(p["id"])
|
||||
result.append({
|
||||
"id": p["id"],
|
||||
"name": p["name"],
|
||||
"is_configured": cfg is not None,
|
||||
"is_enabled": cfg.is_enabled if cfg else False,
|
||||
"client_id": cfg.client_id if cfg else None,
|
||||
})
|
||||
|
||||
return api_response(
|
||||
data={"providers": result},
|
||||
message="OAuth providers retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/admin/oauth/providers/<provider>", methods=["PUT"])
|
||||
@login_required
|
||||
def admin_configure_app_provider(provider: str):
|
||||
"""Create or update an application-level OAuth provider config (admin only).
|
||||
|
||||
Args:
|
||||
provider: Provider type (google, github, microsoft)
|
||||
|
||||
Request body:
|
||||
client_id: OAuth client ID
|
||||
client_secret: OAuth client secret (optional — omit to keep existing)
|
||||
is_enabled: Whether the provider is enabled (default: true)
|
||||
|
||||
Returns:
|
||||
200: Provider configuration updated
|
||||
400: Validation error
|
||||
401: Not authenticated
|
||||
403: Not an admin
|
||||
"""
|
||||
from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig
|
||||
from gatehouse_app.models import OrganizationMember
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
from gatehouse_app.extensions import db
|
||||
|
||||
SUPPORTED = ["google", "github", "microsoft"]
|
||||
if provider not in SUPPORTED:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=f"Unsupported provider. Must be one of: {', '.join(SUPPORTED)}",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
)
|
||||
|
||||
# Verify caller is admin in any org
|
||||
admin_memberships = OrganizationMember.query.filter(
|
||||
OrganizationMember.user_id == g.current_user.id,
|
||||
OrganizationMember.role.in_([OrganizationRole.OWNER, OrganizationRole.ADMIN]),
|
||||
).all()
|
||||
|
||||
if not admin_memberships:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Admin access required",
|
||||
status=403,
|
||||
error_type="FORBIDDEN",
|
||||
)
|
||||
|
||||
data = request.json or {}
|
||||
client_id = (data.get("client_id") or "").strip()
|
||||
client_secret = (data.get("client_secret") or "").strip()
|
||||
is_enabled = data.get("is_enabled", True)
|
||||
|
||||
if not client_id:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="client_id is required",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
)
|
||||
|
||||
cfg = ApplicationProviderConfig.query.filter_by(provider_type=provider).first()
|
||||
if cfg:
|
||||
cfg.client_id = client_id
|
||||
if client_secret:
|
||||
cfg.set_client_secret(client_secret)
|
||||
cfg.is_enabled = bool(is_enabled)
|
||||
db.session.commit()
|
||||
else:
|
||||
cfg = ApplicationProviderConfig(
|
||||
provider_type=provider,
|
||||
client_id=client_id,
|
||||
is_enabled=bool(is_enabled),
|
||||
)
|
||||
if client_secret:
|
||||
cfg.set_client_secret(client_secret)
|
||||
db.session.add(cfg)
|
||||
db.session.commit()
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"provider": {
|
||||
"id": provider,
|
||||
"client_id": cfg.client_id,
|
||||
"is_enabled": cfg.is_enabled,
|
||||
}
|
||||
},
|
||||
message=f"{provider.capitalize()} OAuth provider configured successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/admin/oauth/providers/<provider>", methods=["DELETE"])
|
||||
@login_required
|
||||
def admin_delete_app_provider(provider: str):
|
||||
"""Delete an application-level OAuth provider config (admin only).
|
||||
|
||||
Args:
|
||||
provider: Provider type (google, github, microsoft)
|
||||
|
||||
Returns:
|
||||
200: Provider configuration deleted
|
||||
404: Provider not found
|
||||
401: Not authenticated
|
||||
403: Not an admin
|
||||
"""
|
||||
from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig
|
||||
from gatehouse_app.models import OrganizationMember
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
from gatehouse_app.extensions import db
|
||||
|
||||
# Verify caller is admin in any org
|
||||
admin_memberships = OrganizationMember.query.filter(
|
||||
OrganizationMember.user_id == g.current_user.id,
|
||||
OrganizationMember.role.in_([OrganizationRole.OWNER, OrganizationRole.ADMIN]),
|
||||
).all()
|
||||
|
||||
if not admin_memberships:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Admin access required",
|
||||
status=403,
|
||||
error_type="FORBIDDEN",
|
||||
)
|
||||
|
||||
cfg = ApplicationProviderConfig.query.filter_by(provider_type=provider).first()
|
||||
if not cfg:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=f"Provider '{provider}' is not configured",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
db.session.delete(cfg)
|
||||
db.session.commit()
|
||||
|
||||
return api_response(
|
||||
message=f"{provider.capitalize()} OAuth provider configuration removed",
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -195,20 +195,46 @@ def get_org_mfa_compliance(org_id):
|
||||
|
||||
limit = min(int(request.args.get("limit", 100)), 100)
|
||||
offset = int(request.args.get("offset", 0))
|
||||
page = int(request.args.get("page", 1))
|
||||
page_size = min(int(request.args.get("page_size", limit)), 100)
|
||||
|
||||
effective_offset = offset if request.args.get("offset") else (page - 1) * page_size
|
||||
|
||||
compliance_list = MfaPolicyService.get_org_compliance_list(
|
||||
organization_id=org_id,
|
||||
status=status,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
limit=page_size,
|
||||
offset=effective_offset,
|
||||
)
|
||||
|
||||
def format_member(c):
|
||||
"""Normalize compliance record to UI-expected shape."""
|
||||
if isinstance(c, dict):
|
||||
return {
|
||||
"user_id": c.get("user_id"),
|
||||
"user_email": c.get("email"),
|
||||
"user_name": c.get("full_name"),
|
||||
"status": c.get("status"),
|
||||
"deadline_at": c.get("deadline_at"),
|
||||
"compliant_at": c.get("compliant_at"),
|
||||
"last_notified_at": c.get("notified_at"),
|
||||
}
|
||||
return {
|
||||
"user_id": getattr(c, "user_id", None),
|
||||
"user_email": getattr(c, "email", None),
|
||||
"user_name": getattr(c, "full_name", None),
|
||||
"status": getattr(c, "status", None),
|
||||
"deadline_at": getattr(c, "deadline_at", None),
|
||||
"compliant_at": getattr(c, "compliant_at", None),
|
||||
"last_notified_at": getattr(c, "notified_at", None),
|
||||
}
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"compliance": compliance_list,
|
||||
"members": [format_member(c) for c in compliance_list],
|
||||
"count": len(compliance_list),
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
},
|
||||
message="Compliance records retrieved successfully",
|
||||
)
|
||||
@@ -325,12 +351,10 @@ def get_my_mfa_compliance():
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"mfa_compliance": {
|
||||
"overall_status": compliance_summary.overall_status,
|
||||
"missing_methods": compliance_summary.missing_methods,
|
||||
"deadline_at": compliance_summary.deadline_at,
|
||||
"orgs": orgs,
|
||||
}
|
||||
"overall_status": compliance_summary.overall_status,
|
||||
"missing_methods": compliance_summary.missing_methods,
|
||||
"deadline_at": compliance_summary.deadline_at,
|
||||
"orgs": orgs,
|
||||
},
|
||||
message="MFA compliance retrieved successfully",
|
||||
)
|
||||
@@ -0,0 +1,779 @@
|
||||
"""Principal endpoints."""
|
||||
from flask import g, request
|
||||
from marshmallow import Schema, fields, validate, ValidationError
|
||||
|
||||
from gatehouse_app.api.v1 import api_v1_bp
|
||||
from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.utils.decorators import login_required, require_admin, full_access_required
|
||||
from gatehouse_app.models import Principal, PrincipalMembership, Department, DepartmentPrincipal
|
||||
from gatehouse_app.services.organization_service import OrganizationService
|
||||
from gatehouse_app.services.user_service import UserService
|
||||
from gatehouse_app.exceptions import OrganizationNotFoundError
|
||||
from gatehouse_app.extensions import db
|
||||
|
||||
|
||||
class PrincipalCreateSchema(Schema):
|
||||
"""Schema for creating a principal."""
|
||||
name = fields.Str(required=True, validate=validate.Length(min=1, max=255))
|
||||
description = fields.Str(allow_none=True, validate=validate.Length(max=2000))
|
||||
|
||||
|
||||
class PrincipalUpdateSchema(Schema):
|
||||
"""Schema for updating a principal."""
|
||||
name = fields.Str(validate=validate.Length(min=1, max=255))
|
||||
description = fields.Str(allow_none=True, validate=validate.Length(max=2000))
|
||||
|
||||
|
||||
class AddPrincipalMemberSchema(Schema):
|
||||
"""Schema for adding a member to a principal."""
|
||||
email = fields.Email(required=True)
|
||||
|
||||
|
||||
class LinkPrincipalSchema(Schema):
|
||||
"""Schema for linking principal to department."""
|
||||
department_id = fields.Str(required=True)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/principals", methods=["GET"])
|
||||
@login_required
|
||||
@full_access_required
|
||||
def list_principals(org_id):
|
||||
"""
|
||||
List all principals in an organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
|
||||
Returns:
|
||||
200: List of principals
|
||||
401: Not authenticated
|
||||
403: Not a member
|
||||
404: Organization not found
|
||||
"""
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
if not org.is_member(g.current_user.id):
|
||||
return api_response(
|
||||
success=False,
|
||||
message="You are not a member of this organization",
|
||||
status=403,
|
||||
error_type="AUTHORIZATION_ERROR",
|
||||
)
|
||||
|
||||
principals = Principal.query.filter_by(
|
||||
organization_id=org_id,
|
||||
deleted_at=None
|
||||
).all()
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"principals": [p.to_dict() for p in principals],
|
||||
"count": len(principals),
|
||||
},
|
||||
message="Principals retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/principals", methods=["POST"])
|
||||
@login_required
|
||||
@require_admin
|
||||
@full_access_required
|
||||
def create_principal(org_id):
|
||||
"""
|
||||
Create a new principal.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
|
||||
Request body:
|
||||
name: Principal name (required)
|
||||
description: Optional description
|
||||
|
||||
Returns:
|
||||
201: Principal created successfully
|
||||
400: Validation error
|
||||
401: Not authenticated
|
||||
403: Not an admin
|
||||
404: Organization not found
|
||||
409: Principal name already exists
|
||||
"""
|
||||
try:
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
schema = PrincipalCreateSchema()
|
||||
data = schema.load(request.json or {})
|
||||
|
||||
# Check if principal name already exists
|
||||
existing = Principal.query.filter_by(
|
||||
organization_id=org_id,
|
||||
name=data["name"],
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=f"Principal '{data['name']}' already exists",
|
||||
status=409,
|
||||
error_type="CONFLICT",
|
||||
)
|
||||
|
||||
# Create principal
|
||||
principal = Principal(
|
||||
organization_id=org_id,
|
||||
name=data["name"],
|
||||
description=data.get("description"),
|
||||
)
|
||||
db.session.add(principal)
|
||||
db.session.commit()
|
||||
|
||||
return api_response(
|
||||
data={"principal": principal.to_dict()},
|
||||
message="Principal created successfully",
|
||||
status=201,
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/principals/<principal_id>", methods=["GET"])
|
||||
@login_required
|
||||
@full_access_required
|
||||
def get_principal(org_id, principal_id):
|
||||
"""
|
||||
Get a specific principal.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
principal_id: Principal ID
|
||||
|
||||
Returns:
|
||||
200: Principal data
|
||||
401: Not authenticated
|
||||
403: Not a member
|
||||
404: Organization or principal not found
|
||||
"""
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
if not org.is_member(g.current_user.id):
|
||||
return api_response(
|
||||
success=False,
|
||||
message="You are not a member of this organization",
|
||||
status=403,
|
||||
error_type="AUTHORIZATION_ERROR",
|
||||
)
|
||||
|
||||
principal = Principal.query.filter_by(
|
||||
id=principal_id,
|
||||
organization_id=org_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not principal:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Principal not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data={"principal": principal.to_dict()},
|
||||
message="Principal retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/principals/<principal_id>", methods=["PATCH"])
|
||||
@login_required
|
||||
@require_admin
|
||||
@full_access_required
|
||||
def update_principal(org_id, principal_id):
|
||||
"""
|
||||
Update a principal.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
principal_id: Principal ID
|
||||
|
||||
Request body:
|
||||
name: Optional new name
|
||||
description: Optional new description
|
||||
|
||||
Returns:
|
||||
200: Principal updated successfully
|
||||
400: Validation error
|
||||
401: Not authenticated
|
||||
403: Not an admin
|
||||
404: Organization or principal not found
|
||||
409: Name already exists
|
||||
"""
|
||||
try:
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
principal = Principal.query.filter_by(
|
||||
id=principal_id,
|
||||
organization_id=org_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not principal:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Principal not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
schema = PrincipalUpdateSchema()
|
||||
data = schema.load(request.json or {})
|
||||
|
||||
# Check if new name already exists
|
||||
if "name" in data and data["name"] != principal.name:
|
||||
existing = Principal.query.filter_by(
|
||||
organization_id=org_id,
|
||||
name=data["name"],
|
||||
deleted_at=None
|
||||
).first()
|
||||
if existing:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=f"Principal '{data['name']}' already exists",
|
||||
status=409,
|
||||
error_type="CONFLICT",
|
||||
)
|
||||
|
||||
# Update fields
|
||||
for key, value in data.items():
|
||||
setattr(principal, key, value)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return api_response(
|
||||
data={"principal": principal.to_dict()},
|
||||
message="Principal updated successfully",
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/principals/<principal_id>", methods=["DELETE"])
|
||||
@login_required
|
||||
@require_admin
|
||||
@full_access_required
|
||||
def delete_principal(org_id, principal_id):
|
||||
"""
|
||||
Delete a principal (soft delete).
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
principal_id: Principal ID
|
||||
|
||||
Returns:
|
||||
200: Principal deleted successfully
|
||||
401: Not authenticated
|
||||
403: Not an admin
|
||||
404: Organization or principal not found
|
||||
"""
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
principal = Principal.query.filter_by(
|
||||
id=principal_id,
|
||||
organization_id=org_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not principal:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Principal not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Soft delete
|
||||
principal.deleted_at = db.func.now()
|
||||
db.session.commit()
|
||||
|
||||
return api_response(
|
||||
message="Principal deleted successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/principals/<principal_id>/members", methods=["GET"])
|
||||
@login_required
|
||||
@full_access_required
|
||||
def get_principal_members(org_id, principal_id):
|
||||
"""
|
||||
Get all members (direct + via department) with access to a principal.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
principal_id: Principal ID
|
||||
|
||||
Returns:
|
||||
200: List of members
|
||||
401: Not authenticated
|
||||
403: Not a member
|
||||
404: Organization or principal not found
|
||||
"""
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
if not org.is_member(g.current_user.id):
|
||||
return api_response(
|
||||
success=False,
|
||||
message="You are not a member of this organization",
|
||||
status=403,
|
||||
error_type="AUTHORIZATION_ERROR",
|
||||
)
|
||||
|
||||
principal = Principal.query.filter_by(
|
||||
id=principal_id,
|
||||
organization_id=org_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not principal:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Principal not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Get direct members
|
||||
direct_members = PrincipalMembership.query.filter_by(
|
||||
principal_id=principal_id,
|
||||
deleted_at=None
|
||||
).all()
|
||||
|
||||
all_users = set()
|
||||
for membership in direct_members:
|
||||
if membership.user.deleted_at is None:
|
||||
all_users.add(membership.user)
|
||||
|
||||
# Get members via departments
|
||||
dept_links = DepartmentPrincipal.query.filter_by(
|
||||
principal_id=principal_id,
|
||||
deleted_at=None
|
||||
).all()
|
||||
|
||||
for link in dept_links:
|
||||
dept = link.department
|
||||
if dept.deleted_at is None:
|
||||
dept_members = dept.get_members(active_only=True)
|
||||
for dept_member in dept_members:
|
||||
if dept_member.user.deleted_at is None:
|
||||
all_users.add(dept_member.user)
|
||||
|
||||
users_data = [u.to_dict() for u in all_users]
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"members": users_data,
|
||||
"count": len(users_data),
|
||||
},
|
||||
message="Members retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/principals/<principal_id>/members", methods=["POST"])
|
||||
@login_required
|
||||
@require_admin
|
||||
@full_access_required
|
||||
def add_principal_member(org_id, principal_id):
|
||||
"""
|
||||
Add a direct member to a principal.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
principal_id: Principal ID
|
||||
|
||||
Request body:
|
||||
email: User email to add
|
||||
|
||||
Returns:
|
||||
201: Member added successfully
|
||||
400: Validation error
|
||||
401: Not authenticated
|
||||
403: Not an admin
|
||||
404: Organization, principal, or user not found
|
||||
409: User already a member
|
||||
"""
|
||||
try:
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
principal = Principal.query.filter_by(
|
||||
id=principal_id,
|
||||
organization_id=org_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not principal:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Principal not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
schema = AddPrincipalMemberSchema()
|
||||
data = schema.load(request.json or {})
|
||||
|
||||
# Find user by email
|
||||
user = UserService.get_user_by_email(data["email"])
|
||||
if not user:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="User not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Check if already a member
|
||||
existing = PrincipalMembership.query.filter_by(
|
||||
user_id=user.id,
|
||||
principal_id=principal_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="User is already a member of this principal",
|
||||
status=409,
|
||||
error_type="CONFLICT",
|
||||
)
|
||||
|
||||
soft_deleted = PrincipalMembership.query.filter(
|
||||
PrincipalMembership.user_id == user.id,
|
||||
PrincipalMembership.principal_id == principal_id,
|
||||
PrincipalMembership.deleted_at.isnot(None)
|
||||
).first()
|
||||
|
||||
if soft_deleted:
|
||||
soft_deleted.deleted_at = None
|
||||
membership = soft_deleted
|
||||
else:
|
||||
membership = PrincipalMembership(
|
||||
user_id=user.id,
|
||||
principal_id=principal_id,
|
||||
)
|
||||
db.session.add(membership)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
member_dict = membership.to_dict()
|
||||
member_dict["user"] = user.to_dict()
|
||||
|
||||
return api_response(
|
||||
data={"member": member_dict},
|
||||
message="Member added successfully",
|
||||
status=201,
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/principals/<principal_id>/members/<user_id>", methods=["DELETE"])
|
||||
@login_required
|
||||
@require_admin
|
||||
@full_access_required
|
||||
def remove_principal_member(org_id, principal_id, user_id):
|
||||
"""
|
||||
Remove a direct member from a principal.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
principal_id: Principal ID
|
||||
user_id: User ID to remove
|
||||
|
||||
Returns:
|
||||
200: Member removed successfully
|
||||
401: Not authenticated
|
||||
403: Not an admin
|
||||
404: Organization, principal, or member not found
|
||||
"""
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
principal = Principal.query.filter_by(
|
||||
id=principal_id,
|
||||
organization_id=org_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not principal:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Principal not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
membership = PrincipalMembership.query.filter_by(
|
||||
user_id=user_id,
|
||||
principal_id=principal_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not membership:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="User is not a member of this principal",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Soft delete
|
||||
membership.deleted_at = db.func.now()
|
||||
db.session.commit()
|
||||
|
||||
return api_response(
|
||||
message="Member removed successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/principals/<principal_id>/departments", methods=["GET"])
|
||||
@login_required
|
||||
@full_access_required
|
||||
def get_principal_departments(org_id, principal_id):
|
||||
"""
|
||||
Get all departments this principal is assigned to.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
principal_id: Principal ID
|
||||
|
||||
Returns:
|
||||
200: List of departments
|
||||
401: Not authenticated
|
||||
403: Not a member
|
||||
404: Organization or principal not found
|
||||
"""
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
if not org.is_member(g.current_user.id):
|
||||
return api_response(
|
||||
success=False,
|
||||
message="You are not a member of this organization",
|
||||
status=403,
|
||||
error_type="AUTHORIZATION_ERROR",
|
||||
)
|
||||
|
||||
principal = Principal.query.filter_by(
|
||||
id=principal_id,
|
||||
organization_id=org_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not principal:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Principal not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
depts = principal.get_departments(active_only=True)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"departments": [d.to_dict() for d in depts],
|
||||
"count": len(depts),
|
||||
},
|
||||
message="Departments retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/principals/<principal_id>/departments/<dept_id>", methods=["POST"])
|
||||
@login_required
|
||||
@require_admin
|
||||
@full_access_required
|
||||
def link_principal_to_department(org_id, principal_id, dept_id):
|
||||
"""
|
||||
Link a principal to a department.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
principal_id: Principal ID
|
||||
dept_id: Department ID
|
||||
|
||||
Returns:
|
||||
201: Principal linked successfully
|
||||
401: Not authenticated
|
||||
403: Not an admin
|
||||
404: Organization, principal, or department not found
|
||||
409: Already linked
|
||||
"""
|
||||
try:
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
except OrganizationNotFoundError:
|
||||
return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND")
|
||||
|
||||
principal = Principal.query.filter_by(
|
||||
id=principal_id,
|
||||
organization_id=org_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not principal:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Principal not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
dept = Department.query.filter_by(
|
||||
id=dept_id,
|
||||
organization_id=org_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not dept:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Department not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
existing = DepartmentPrincipal.query.filter_by(
|
||||
department_id=dept_id,
|
||||
principal_id=principal_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Principal is already linked to this department",
|
||||
status=409,
|
||||
error_type="CONFLICT",
|
||||
)
|
||||
|
||||
soft_deleted = DepartmentPrincipal.query.filter(
|
||||
DepartmentPrincipal.department_id == dept_id,
|
||||
DepartmentPrincipal.principal_id == principal_id,
|
||||
DepartmentPrincipal.deleted_at.isnot(None),
|
||||
).first()
|
||||
|
||||
try:
|
||||
if soft_deleted:
|
||||
soft_deleted.deleted_at = None
|
||||
else:
|
||||
link = DepartmentPrincipal(
|
||||
department_id=dept_id,
|
||||
principal_id=principal_id,
|
||||
)
|
||||
db.session.add(link)
|
||||
db.session.commit()
|
||||
except Exception:
|
||||
db.session.rollback()
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Failed to link principal to department",
|
||||
status=500,
|
||||
error_type="SERVER_ERROR",
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"principal": principal.to_dict(),
|
||||
"department": dept.to_dict(),
|
||||
},
|
||||
message="Principal linked to department successfully",
|
||||
status=201,
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/principals/<principal_id>/departments/<dept_id>", methods=["DELETE"])
|
||||
@login_required
|
||||
@require_admin
|
||||
@full_access_required
|
||||
def unlink_principal_from_department(org_id, principal_id, dept_id):
|
||||
"""
|
||||
Unlink a principal from a department.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
principal_id: Principal ID
|
||||
dept_id: Department ID
|
||||
|
||||
Returns:
|
||||
200: Principal unlinked successfully
|
||||
401: Not authenticated
|
||||
403: Not an admin
|
||||
404: Organization, principal, department, or link not found
|
||||
"""
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
|
||||
principal = Principal.query.filter_by(
|
||||
id=principal_id,
|
||||
organization_id=org_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not principal:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Principal not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
dept = Department.query.filter_by(
|
||||
id=dept_id,
|
||||
organization_id=org_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not dept:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Department not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
link = DepartmentPrincipal.query.filter_by(
|
||||
department_id=dept_id,
|
||||
principal_id=principal_id,
|
||||
deleted_at=None
|
||||
).first()
|
||||
|
||||
if not link:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Principal is not linked to this department",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Soft delete
|
||||
link.deleted_at = db.func.now()
|
||||
db.session.commit()
|
||||
|
||||
return api_response(
|
||||
message="Principal unlinked from department successfully",
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -73,11 +73,51 @@ def delete_me():
|
||||
"""
|
||||
Delete current user account (soft delete).
|
||||
|
||||
Blocked if the user is the sole owner of any organization that has other
|
||||
active members — they must transfer ownership or dissolve those organizations
|
||||
first.
|
||||
|
||||
Returns:
|
||||
200: Account deleted successfully
|
||||
401: Not authenticated
|
||||
409: User is sole owner of one or more organizations with other members
|
||||
"""
|
||||
UserService.delete_user(g.current_user, soft=True)
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
|
||||
user = g.current_user
|
||||
|
||||
# Find orgs where this user is the sole owner AND other members exist.
|
||||
owned_memberships = OrganizationMember.query.filter_by(
|
||||
user_id=user.id,
|
||||
role=OrganizationRole.OWNER,
|
||||
deleted_at=None,
|
||||
).all()
|
||||
|
||||
blocked_orgs = []
|
||||
for membership in owned_memberships:
|
||||
org = membership.organization
|
||||
if org.deleted_at is not None:
|
||||
continue
|
||||
member_count = org.get_member_count()
|
||||
if member_count > 1:
|
||||
blocked_orgs.append(org.name)
|
||||
|
||||
if blocked_orgs:
|
||||
names = ", ".join(f'"{n}"' for n in blocked_orgs)
|
||||
return api_response(
|
||||
success=False,
|
||||
message=(
|
||||
f"You are the sole owner of {len(blocked_orgs)} organization"
|
||||
f"{'s' if len(blocked_orgs) > 1 else ''}: {names}. "
|
||||
"Transfer ownership or delete those organizations before deleting your account."
|
||||
),
|
||||
status=409,
|
||||
error_type="USER_IS_SOLE_OWNER",
|
||||
error_details={"organizations": blocked_orgs},
|
||||
)
|
||||
|
||||
UserService.delete_user(user, soft=True)
|
||||
|
||||
return api_response(
|
||||
message="Account deleted successfully",
|
||||
@@ -142,18 +182,686 @@ def change_password():
|
||||
@full_access_required
|
||||
def get_my_organizations():
|
||||
"""
|
||||
Get all organizations current user is a member of.
|
||||
Get all organizations current user is a member of, including the user's role.
|
||||
|
||||
Returns:
|
||||
200: List of organizations
|
||||
200: List of organizations with role
|
||||
401: Not authenticated
|
||||
"""
|
||||
organizations = UserService.get_user_organizations(g.current_user)
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
|
||||
user = g.current_user
|
||||
memberships = OrganizationMember.query.filter_by(
|
||||
user_id=user.id,
|
||||
deleted_at=None,
|
||||
).all()
|
||||
|
||||
orgs = []
|
||||
for membership in memberships:
|
||||
org = membership.organization
|
||||
if not org or org.deleted_at is not None:
|
||||
continue
|
||||
org_dict = org.to_dict()
|
||||
org_dict["role"] = membership.role.value if hasattr(membership.role, "value") else str(membership.role)
|
||||
orgs.append(org_dict)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"organizations": [org.to_dict() for org in organizations],
|
||||
"count": len(organizations),
|
||||
"organizations": orgs,
|
||||
"count": len(orgs),
|
||||
},
|
||||
message="Organizations retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/users/me/principals", methods=["GET"])
|
||||
@login_required
|
||||
@full_access_required
|
||||
def get_my_principals():
|
||||
"""Return all principals the current user can sign certificates for.
|
||||
|
||||
For each organization the user belongs to, returns:
|
||||
- Their effective principals (direct membership + via department)
|
||||
- Their role in that org (so the frontend can offer admin-mode selection)
|
||||
- All principals in the org (admin/owner only — so they can pick any)
|
||||
|
||||
Returns:
|
||||
200: {
|
||||
orgs: [{
|
||||
org_id, org_name, role,
|
||||
my_principals: [{id, name, description}],
|
||||
all_principals: [{id, name, description}] # populated for admin/owner only
|
||||
}]
|
||||
}
|
||||
"""
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.organization.principal import Principal, PrincipalMembership
|
||||
from gatehouse_app.models.organization.department import DepartmentMembership, DepartmentPrincipal
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
|
||||
user = g.current_user
|
||||
user_id = user.id
|
||||
|
||||
# Get all org memberships
|
||||
memberships = OrganizationMember.query.filter_by(
|
||||
user_id=user_id,
|
||||
).all()
|
||||
|
||||
orgs_result = []
|
||||
for membership in memberships:
|
||||
org = membership.organization
|
||||
if not org or org.deleted_at is not None:
|
||||
continue
|
||||
|
||||
role = membership.role
|
||||
is_admin = role in (OrganizationRole.ADMIN, OrganizationRole.OWNER)
|
||||
|
||||
# Collect the user's effective principals for this org
|
||||
# Track direct vs via-department separately
|
||||
direct_principal_ids = set()
|
||||
via_dept_principal_ids = set()
|
||||
|
||||
# Direct memberships
|
||||
direct = PrincipalMembership.query.filter_by(
|
||||
user_id=user_id,
|
||||
deleted_at=None,
|
||||
).all()
|
||||
for pm in direct:
|
||||
if pm.principal and pm.principal.organization_id == org.id and pm.principal.deleted_at is None:
|
||||
direct_principal_ids.add(pm.principal_id)
|
||||
|
||||
# Via department
|
||||
dept_memberships = DepartmentMembership.query.filter_by(
|
||||
user_id=user_id,
|
||||
deleted_at=None,
|
||||
).all()
|
||||
for dm in dept_memberships:
|
||||
if dm.department and dm.department.organization_id == org.id and dm.department.deleted_at is None:
|
||||
dept_principals = DepartmentPrincipal.query.filter_by(
|
||||
department_id=dm.department_id,
|
||||
deleted_at=None,
|
||||
).all()
|
||||
for dp in dept_principals:
|
||||
if dp.principal and dp.principal.deleted_at is None:
|
||||
via_dept_principal_ids.add(dp.principal_id)
|
||||
|
||||
effective_principal_ids = direct_principal_ids | via_dept_principal_ids
|
||||
|
||||
# Fetch principal objects
|
||||
my_principals = []
|
||||
if effective_principal_ids:
|
||||
my_p = Principal.query.filter(
|
||||
Principal.id.in_(list(effective_principal_ids)),
|
||||
Principal.deleted_at == None,
|
||||
).all()
|
||||
my_principals = [
|
||||
{
|
||||
"id": p.id,
|
||||
"name": p.name,
|
||||
"description": p.description,
|
||||
# direct=True means removable via API; False=inherited via department
|
||||
"direct": p.id in direct_principal_ids,
|
||||
}
|
||||
for p in my_p
|
||||
]
|
||||
|
||||
# For admins/owners: also return all principals in the org
|
||||
all_principals = []
|
||||
if is_admin:
|
||||
all_p = Principal.query.filter_by(
|
||||
organization_id=org.id,
|
||||
deleted_at=None,
|
||||
).all()
|
||||
all_principals = [{"id": p.id, "name": p.name, "description": p.description} for p in all_p]
|
||||
|
||||
orgs_result.append({
|
||||
"org_id": org.id,
|
||||
"org_name": org.name,
|
||||
"role": role.value if hasattr(role, "value") else role,
|
||||
"is_admin": is_admin,
|
||||
"my_principals": my_principals,
|
||||
"all_principals": all_principals,
|
||||
})
|
||||
|
||||
return api_response(
|
||||
data={"orgs": orgs_result},
|
||||
message="Principals retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/admin/users", methods=["GET"])
|
||||
@login_required
|
||||
@full_access_required
|
||||
def admin_list_users():
|
||||
"""List all users the caller has admin rights to see.
|
||||
|
||||
The caller must be an OWNER or ADMIN of at least one organization.
|
||||
Returns users that share an organization with the caller and where the
|
||||
caller holds admin/owner role in that organization.
|
||||
|
||||
Query params:
|
||||
q – optional search string (matched against name/email)
|
||||
page – page number (default 1)
|
||||
per_page – page size (default 50, max 200)
|
||||
"""
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.user.user import User as _User
|
||||
from gatehouse_app.extensions import db as _db
|
||||
from sqlalchemy import or_
|
||||
|
||||
caller = g.current_user
|
||||
|
||||
# Find orgs where caller is admin/owner
|
||||
admin_memberships = OrganizationMember.query.filter(
|
||||
OrganizationMember.user_id == caller.id,
|
||||
OrganizationMember.role.in_(["OWNER", "ADMIN"]),
|
||||
OrganizationMember.deleted_at == None,
|
||||
).all()
|
||||
|
||||
if not admin_memberships:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Admin or owner role required",
|
||||
status=403,
|
||||
error_type="AUTHORIZATION_ERROR",
|
||||
)
|
||||
|
||||
admin_org_ids = [m.organization_id for m in admin_memberships]
|
||||
|
||||
# Collect user IDs in those orgs
|
||||
member_rows = OrganizationMember.query.filter(
|
||||
OrganizationMember.organization_id.in_(admin_org_ids),
|
||||
OrganizationMember.deleted_at == None,
|
||||
).all()
|
||||
visible_user_ids = list({row.user_id for row in member_rows})
|
||||
|
||||
# Optional search
|
||||
q = request.args.get("q", "").strip()
|
||||
try:
|
||||
page = max(1, int(request.args.get("page", 1)))
|
||||
per_page = min(200, max(1, int(request.args.get("per_page", 50))))
|
||||
except ValueError:
|
||||
page, per_page = 1, 50
|
||||
|
||||
query = _User.query.filter(
|
||||
_User.id.in_(visible_user_ids),
|
||||
_User.deleted_at == None,
|
||||
)
|
||||
if q:
|
||||
like = f"%{q}%"
|
||||
query = query.filter(or_(_User.email.ilike(like), _User.full_name.ilike(like)))
|
||||
|
||||
total = query.count()
|
||||
users = query.order_by(_User.email).offset((page - 1) * per_page).limit(per_page).all()
|
||||
|
||||
member_lookup: dict = {}
|
||||
for row in member_rows:
|
||||
if row.user_id not in member_lookup:
|
||||
member_lookup[row.user_id] = {
|
||||
"organization_id": row.organization_id,
|
||||
"role": row.role.value if hasattr(row.role, "value") else row.role,
|
||||
}
|
||||
|
||||
users_data = []
|
||||
for u in users:
|
||||
d = u.to_dict()
|
||||
m = member_lookup.get(u.id, {})
|
||||
d["org_role"] = m.get("role", "member")
|
||||
d["org_id"] = m.get("organization_id")
|
||||
users_data.append(d)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"users": users_data,
|
||||
"count": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": (total + per_page - 1) // per_page,
|
||||
},
|
||||
message="Users retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/admin/users/<user_id>", methods=["GET"])
|
||||
@login_required
|
||||
@full_access_required
|
||||
def admin_get_user(user_id):
|
||||
"""Get a single user's profile (admin view with SSH keys)."""
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.user.user import User as _User
|
||||
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey
|
||||
|
||||
caller = g.current_user
|
||||
|
||||
target = _User.query.filter_by(id=user_id, deleted_at=None).first()
|
||||
if not target:
|
||||
return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND")
|
||||
|
||||
# Verify caller has admin access to a shared org
|
||||
target_org_ids = {m.organization_id for m in target.organization_memberships if m.deleted_at is None}
|
||||
has_access = OrganizationMember.query.filter(
|
||||
OrganizationMember.user_id == caller.id,
|
||||
OrganizationMember.organization_id.in_(target_org_ids),
|
||||
OrganizationMember.role.in_(["OWNER", "ADMIN"]),
|
||||
OrganizationMember.deleted_at == None,
|
||||
).first() is not None
|
||||
|
||||
if not has_access:
|
||||
return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR")
|
||||
|
||||
ssh_keys = SSHKey.query.filter_by(user_id=user_id, deleted_at=None).all()
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"user": target.to_dict(),
|
||||
"ssh_keys": [k.to_dict() for k in ssh_keys],
|
||||
},
|
||||
message="User retrieved",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/admin/users/<user_id>/suspend", methods=["POST"])
|
||||
@login_required
|
||||
@full_access_required
|
||||
def admin_suspend_user(user_id):
|
||||
"""Suspend a user account (blocks CA issuance and login).
|
||||
|
||||
The caller must be an OWNER or ADMIN of an organization the target user belongs to.
|
||||
"""
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.user.user import User as _User
|
||||
from gatehouse_app.extensions import db as _db
|
||||
from gatehouse_app.utils.constants import UserStatus, AuditAction
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
|
||||
caller = g.current_user
|
||||
target = _User.query.filter_by(id=user_id, deleted_at=None).first()
|
||||
if not target:
|
||||
return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND")
|
||||
|
||||
if target.id == caller.id:
|
||||
return api_response(success=False, message="Cannot suspend yourself", status=400, error_type="BAD_REQUEST")
|
||||
|
||||
# Verify caller has admin access to a shared org
|
||||
target_org_ids = {m.organization_id for m in target.organization_memberships if m.deleted_at is None}
|
||||
admin_in_shared_org = OrganizationMember.query.filter(
|
||||
OrganizationMember.user_id == caller.id,
|
||||
OrganizationMember.organization_id.in_(target_org_ids),
|
||||
OrganizationMember.role.in_(["OWNER", "ADMIN"]),
|
||||
OrganizationMember.deleted_at == None,
|
||||
).first()
|
||||
|
||||
if not admin_in_shared_org:
|
||||
return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR")
|
||||
|
||||
# ── Owner protection ──────────────────────────────────────────────────────
|
||||
# An org owner cannot be suspended until they transfer ownership.
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
owner_memberships = OrganizationMember.query.filter(
|
||||
OrganizationMember.user_id == target.id,
|
||||
OrganizationMember.role == OrganizationRole.OWNER,
|
||||
OrganizationMember.deleted_at == None,
|
||||
).all()
|
||||
if owner_memberships:
|
||||
org_names = [
|
||||
m.organization.name
|
||||
for m in owner_memberships
|
||||
if m.organization and not m.organization.deleted_at
|
||||
]
|
||||
return api_response(
|
||||
success=False,
|
||||
message=(
|
||||
f"Cannot suspend an organization owner. "
|
||||
f"{target.email} is the owner of: {', '.join(org_names)}. "
|
||||
"Transfer ownership to another member first."
|
||||
),
|
||||
status=403,
|
||||
error_type="OWNER_PROTECTION",
|
||||
)
|
||||
|
||||
if target.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
|
||||
return api_response(success=False, message="User is already suspended", status=409, error_type="CONFLICT")
|
||||
|
||||
target.status = UserStatus.SUSPENDED
|
||||
_db.session.commit()
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.USER_SUSPEND,
|
||||
user_id=caller.id,
|
||||
organization_id=admin_in_shared_org.organization_id,
|
||||
resource_type="user",
|
||||
resource_id=str(target.id),
|
||||
description=f"Admin suspended user {target.email}",
|
||||
metadata={"target_user_id": str(target.id), "target_email": target.email},
|
||||
)
|
||||
|
||||
return api_response(data={"user": target.to_dict()}, message="User suspended successfully")
|
||||
|
||||
|
||||
@api_v1_bp.route("/admin/users/<user_id>/unsuspend", methods=["POST"])
|
||||
@login_required
|
||||
@full_access_required
|
||||
def admin_unsuspend_user(user_id):
|
||||
"""Restore a suspended user account to active status."""
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.user.user import User as _User
|
||||
from gatehouse_app.extensions import db as _db
|
||||
from gatehouse_app.utils.constants import UserStatus, AuditAction
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
|
||||
caller = g.current_user
|
||||
target = _User.query.filter_by(id=user_id, deleted_at=None).first()
|
||||
if not target:
|
||||
return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND")
|
||||
|
||||
# Verify caller has admin access to a shared org
|
||||
target_org_ids = {m.organization_id for m in target.organization_memberships if m.deleted_at is None}
|
||||
admin_in_shared_org = OrganizationMember.query.filter(
|
||||
OrganizationMember.user_id == caller.id,
|
||||
OrganizationMember.organization_id.in_(target_org_ids),
|
||||
OrganizationMember.role.in_(["OWNER", "ADMIN"]),
|
||||
OrganizationMember.deleted_at == None,
|
||||
).first()
|
||||
|
||||
if not admin_in_shared_org:
|
||||
return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR")
|
||||
|
||||
if target.status not in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
|
||||
return api_response(success=False, message="User is not suspended", status=409, error_type="CONFLICT")
|
||||
|
||||
target.status = UserStatus.ACTIVE
|
||||
_db.session.commit()
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.USER_UNSUSPEND,
|
||||
user_id=caller.id,
|
||||
organization_id=admin_in_shared_org.organization_id,
|
||||
resource_type="user",
|
||||
resource_id=str(target.id),
|
||||
description=f"Admin unsuspended user {target.email}",
|
||||
metadata={"target_user_id": str(target.id), "target_email": target.email},
|
||||
)
|
||||
|
||||
return api_response(data={"user": target.to_dict()}, message="User unsuspended successfully")
|
||||
|
||||
|
||||
@api_v1_bp.route("/users/me/invites", methods=["GET"])
|
||||
@login_required
|
||||
def get_my_pending_invites():
|
||||
"""Return pending (unaccepted, non-expired) invitations for the current user's email."""
|
||||
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
|
||||
from datetime import datetime, timezone
|
||||
|
||||
user = g.current_user
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
invites = OrgInviteToken.query.filter(
|
||||
OrgInviteToken.email == user.email,
|
||||
OrgInviteToken.accepted_at.is_(None),
|
||||
OrgInviteToken.expires_at > now,
|
||||
OrgInviteToken.deleted_at.is_(None),
|
||||
).all()
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"invites": [
|
||||
{
|
||||
"token": i.token,
|
||||
"organization": {"id": str(i.organization_id), "name": i.organization.name},
|
||||
"role": i.role,
|
||||
"expires_at": i.expires_at.isoformat(),
|
||||
}
|
||||
for i in invites
|
||||
]
|
||||
},
|
||||
message="Pending invitations retrieved",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/users/me/memberships", methods=["GET"])
|
||||
@login_required
|
||||
def get_my_memberships():
|
||||
"""Return the current user's department and principal memberships across all orgs.
|
||||
|
||||
Returns:
|
||||
200: {
|
||||
orgs: [{
|
||||
org_id, org_name, role,
|
||||
departments: [{id, name, description}],
|
||||
principals: [{id, name, description, via_department: bool}]
|
||||
}]
|
||||
}
|
||||
"""
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.organization.department import DepartmentMembership, DepartmentPrincipal, Department
|
||||
from gatehouse_app.models.organization.principal import Principal, PrincipalMembership
|
||||
|
||||
user = g.current_user
|
||||
|
||||
memberships = OrganizationMember.query.filter_by(
|
||||
user_id=user.id,
|
||||
deleted_at=None,
|
||||
).all()
|
||||
|
||||
orgs_result = []
|
||||
for membership in memberships:
|
||||
org = membership.organization
|
||||
if not org or org.deleted_at is not None:
|
||||
continue
|
||||
|
||||
# Departments the user belongs to
|
||||
dept_memberships = DepartmentMembership.query.filter_by(
|
||||
user_id=user.id,
|
||||
deleted_at=None,
|
||||
).all()
|
||||
user_depts = [
|
||||
dm.department for dm in dept_memberships
|
||||
if dm.department
|
||||
and dm.department.organization_id == org.id
|
||||
and dm.department.deleted_at is None
|
||||
]
|
||||
|
||||
# Principals: direct
|
||||
direct_pm = PrincipalMembership.query.filter_by(
|
||||
user_id=user.id,
|
||||
deleted_at=None,
|
||||
).all()
|
||||
direct_principal_ids = {
|
||||
pm.principal_id for pm in direct_pm
|
||||
if pm.principal
|
||||
and pm.principal.organization_id == org.id
|
||||
and pm.principal.deleted_at is None
|
||||
}
|
||||
|
||||
# Principals: via department
|
||||
via_dept_principal_ids = set()
|
||||
for dept in user_depts:
|
||||
for dp in DepartmentPrincipal.query.filter_by(department_id=dept.id, deleted_at=None).all():
|
||||
if dp.principal and dp.principal.deleted_at is None:
|
||||
via_dept_principal_ids.add(dp.principal_id)
|
||||
|
||||
all_principal_ids = direct_principal_ids | via_dept_principal_ids
|
||||
principals_list = []
|
||||
if all_principal_ids:
|
||||
for p in Principal.query.filter(
|
||||
Principal.id.in_(list(all_principal_ids)),
|
||||
Principal.deleted_at == None,
|
||||
).all():
|
||||
principals_list.append({
|
||||
"id": str(p.id),
|
||||
"name": p.name,
|
||||
"description": p.description,
|
||||
"via_department": p.id not in direct_principal_ids,
|
||||
})
|
||||
|
||||
role = membership.role
|
||||
orgs_result.append({
|
||||
"org_id": str(org.id),
|
||||
"org_name": org.name,
|
||||
"role": role.value if hasattr(role, "value") else role,
|
||||
"departments": [
|
||||
{"id": str(d.id), "name": d.name, "description": d.description}
|
||||
for d in user_depts
|
||||
],
|
||||
"principals": principals_list,
|
||||
})
|
||||
|
||||
return api_response(
|
||||
data={"orgs": orgs_result},
|
||||
message="Memberships retrieved",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/admin/users/<user_id>/delete", methods=["POST"])
|
||||
@login_required
|
||||
@full_access_required
|
||||
def admin_hard_delete_user(user_id):
|
||||
"""Permanently delete a user and ALL associated data (hard delete, irreversible).
|
||||
|
||||
Required body: {"confirm": true}
|
||||
|
||||
Pre-conditions:
|
||||
- Caller is OWNER or ADMIN of a shared org with the target.
|
||||
- Cannot delete yourself.
|
||||
- Target must not be the OWNER of any active organization (transfer first).
|
||||
|
||||
Side-effects:
|
||||
- All active SSH certificates are revoked before deletion.
|
||||
- The user row and all cascaded rows are hard-deleted from the database.
|
||||
- An audit log entry is written by the *caller* (so it is not lost with the user).
|
||||
"""
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.user.user import User as _User
|
||||
from gatehouse_app.extensions import db as _db
|
||||
from gatehouse_app.utils.constants import UserStatus, AuditAction, OrganizationRole
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
|
||||
caller = g.current_user
|
||||
data = request.get_json() or {}
|
||||
|
||||
if not data.get("confirm"):
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Deletion requires explicit confirmation. Send {\"confirm\": true} to proceed.",
|
||||
status=400,
|
||||
error_type="CONFIRMATION_REQUIRED",
|
||||
)
|
||||
|
||||
target = _User.query.filter_by(id=user_id).first()
|
||||
if not target:
|
||||
return api_response(success=False, message="User not found", status=404, error_type="NOT_FOUND")
|
||||
|
||||
if target.id == caller.id:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Cannot delete your own account via this endpoint.",
|
||||
status=400,
|
||||
error_type="BAD_REQUEST",
|
||||
)
|
||||
|
||||
# Caller must be OWNER/ADMIN of a shared org.
|
||||
# Include soft-deleted memberships so that already-soft-deleted users can
|
||||
# still be hard-deleted by an admin who shared an org with them.
|
||||
target_org_ids = {m.organization_id for m in target.organization_memberships}
|
||||
admin_in_shared_org = OrganizationMember.query.filter(
|
||||
OrganizationMember.user_id == caller.id,
|
||||
OrganizationMember.organization_id.in_(target_org_ids),
|
||||
OrganizationMember.role.in_(["OWNER", "ADMIN"]),
|
||||
OrganizationMember.deleted_at == None,
|
||||
).first()
|
||||
if not admin_in_shared_org:
|
||||
return api_response(success=False, message="Access denied", status=403, error_type="AUTHORIZATION_ERROR")
|
||||
|
||||
# Block deletion if target is an org owner — they must transfer first
|
||||
owner_memberships = OrganizationMember.query.filter(
|
||||
OrganizationMember.user_id == target.id,
|
||||
OrganizationMember.role == OrganizationRole.OWNER,
|
||||
OrganizationMember.deleted_at == None,
|
||||
).all()
|
||||
if owner_memberships:
|
||||
org_names = [
|
||||
m.organization.name
|
||||
for m in owner_memberships
|
||||
if m.organization and not m.organization.deleted_at
|
||||
]
|
||||
return api_response(
|
||||
success=False,
|
||||
message=(
|
||||
f"Cannot delete an organization owner. "
|
||||
f"{target.email} is the owner of: {', '.join(org_names)}. "
|
||||
"Transfer ownership to another member first."
|
||||
),
|
||||
status=403,
|
||||
error_type="OWNER_PROTECTION",
|
||||
)
|
||||
|
||||
# ── Collect counts for audit metadata ────────────────────────────────────
|
||||
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey
|
||||
from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus
|
||||
|
||||
ssh_key_count = SSHKey.query.filter_by(user_id=target.id, deleted_at=None).count()
|
||||
active_cert_count = SSHCertificate.query.filter_by(
|
||||
user_id=target.id, revoked=False
|
||||
).filter(SSHCertificate.deleted_at == None).count()
|
||||
|
||||
# ── Revoke all active SSH certificates before deletion ───────────────────
|
||||
active_certs = SSHCertificate.query.filter_by(
|
||||
user_id=target.id, revoked=False
|
||||
).filter(SSHCertificate.deleted_at == None).all()
|
||||
for cert in active_certs:
|
||||
try:
|
||||
cert.revoke("account_deleted")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if active_certs:
|
||||
try:
|
||||
_db.session.flush()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ── Hard delete ───────────────────────────────────────────────────────────
|
||||
target_email = target.email # capture before deletion
|
||||
target_id_str = str(target.id)
|
||||
|
||||
try:
|
||||
_db.session.delete(target) # cascades to all child tables
|
||||
_db.session.flush()
|
||||
except Exception as exc:
|
||||
_db.session.rollback()
|
||||
import logging
|
||||
logging.getLogger(__name__).error(f"Hard delete failed for {target_id_str}: {exc}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Failed to delete user account. Please try again.",
|
||||
status=500,
|
||||
error_type="SERVER_ERROR",
|
||||
)
|
||||
|
||||
# ── Audit log (written as the caller so it survives the deletion) ─────────
|
||||
AuditService.log_action(
|
||||
action=AuditAction.USER_HARD_DELETE,
|
||||
user_id=caller.id,
|
||||
organization_id=admin_in_shared_org.organization_id,
|
||||
resource_type="user",
|
||||
resource_id=target_id_str,
|
||||
description=f"Admin permanently deleted user account: {target_email}",
|
||||
metadata={
|
||||
"deleted_user_id": target_id_str,
|
||||
"deleted_user_email": target_email,
|
||||
"ssh_keys_deleted": ssh_key_count,
|
||||
"certs_revoked": active_cert_count,
|
||||
},
|
||||
)
|
||||
|
||||
_db.session.commit()
|
||||
|
||||
return api_response(
|
||||
message=f"User account {target_email} has been permanently deleted.",
|
||||
data={
|
||||
"deleted_user_id": target_id_str,
|
||||
"deleted_user_email": target_email,
|
||||
"ssh_keys_deleted": ssh_key_count,
|
||||
"certs_revoked": active_cert_count,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,234 @@
|
||||
"""SSH CA Configuration Manager.
|
||||
|
||||
Handles loading and managing SSH CA configuration from etc/ssh_ca.conf
|
||||
and environment variables.
|
||||
"""
|
||||
import os
|
||||
import configparser
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
class SSHCAConfig:
|
||||
"""Configuration manager for SSH CA settings.
|
||||
|
||||
Loads configuration from:
|
||||
1. etc/ssh_ca.conf file
|
||||
2. Environment variables (override config file)
|
||||
3. Application environment-specific defaults
|
||||
|
||||
Example:
|
||||
config = SSHCAConfig()
|
||||
cert_hours = config.get_int('cert_validity_hours')
|
||||
key_path = config.get_str('ca_key_path')
|
||||
"""
|
||||
|
||||
# Configuration file location (relative to project root)
|
||||
DEFAULT_CONFIG_FILE = "etc/ssh_ca.conf"
|
||||
|
||||
# Default values if config file is missing
|
||||
DEFAULTS = {
|
||||
'cert_validity_hours': '8',
|
||||
'max_cert_validity_hours': '720',
|
||||
'ca_key_path': '',
|
||||
'max_principals_per_cert': '256',
|
||||
'max_key_id_length': '255',
|
||||
'verification_challenge_max_age': '24',
|
||||
'auto_delete_unverified_days': '30',
|
||||
}
|
||||
|
||||
def __init__(self, config_file: Optional[str] = None, environment: Optional[str] = None):
|
||||
"""Initialize SSH CA configuration.
|
||||
|
||||
Args:
|
||||
config_file: Path to config file (default: etc/ssh_ca.conf)
|
||||
environment: Environment name (development, production, testing)
|
||||
Default: value of FLASK_ENV or 'development'
|
||||
"""
|
||||
self.config = configparser.ConfigParser()
|
||||
|
||||
# Determine environment
|
||||
if environment is None:
|
||||
environment = os.environ.get('FLASK_ENV', 'development')
|
||||
self.environment = environment
|
||||
|
||||
# Load config file
|
||||
if config_file is None:
|
||||
# Try to find config file relative to this module
|
||||
module_dir = Path(__file__).parent.parent.parent
|
||||
config_file = module_dir / self.DEFAULT_CONFIG_FILE
|
||||
|
||||
self.config_file = config_file
|
||||
self._load_config()
|
||||
|
||||
def _load_config(self):
|
||||
"""Load configuration from file and apply environment-specific overrides."""
|
||||
# Set defaults
|
||||
self.config['default'] = self.DEFAULTS.copy()
|
||||
|
||||
# Load config file if it exists
|
||||
if Path(self.config_file).exists():
|
||||
self.config.read(self.config_file)
|
||||
|
||||
# Apply environment-specific configuration
|
||||
if self.environment in self.config:
|
||||
for key, value in self.config[self.environment].items():
|
||||
self.config['default'][key] = value
|
||||
|
||||
def get_str(self, key: str, default: Optional[str] = None) -> str:
|
||||
"""Get a string configuration value.
|
||||
|
||||
First checks environment variables (SSH_CA_<KEY>), then config file.
|
||||
|
||||
Args:
|
||||
key: Configuration key
|
||||
default: Default value if not found
|
||||
|
||||
Returns:
|
||||
Configuration value as string
|
||||
"""
|
||||
env_key = f"SSH_CA_{key.upper()}"
|
||||
|
||||
# Check environment variable first
|
||||
if env_key in os.environ:
|
||||
return os.environ[env_key]
|
||||
|
||||
# Check config file
|
||||
if key in self.config['default']:
|
||||
value = self.config['default'][key]
|
||||
# Handle environment variable substitution
|
||||
return os.path.expandvars(value)
|
||||
|
||||
# Return default
|
||||
if default is not None:
|
||||
return default
|
||||
|
||||
return self.DEFAULTS.get(key, '')
|
||||
|
||||
def get_int(self, key: str, default: Optional[int] = None) -> int:
|
||||
"""Get an integer configuration value.
|
||||
|
||||
Args:
|
||||
key: Configuration key
|
||||
default: Default value if not found
|
||||
|
||||
Returns:
|
||||
Configuration value as integer
|
||||
|
||||
Raises:
|
||||
ValueError: If value cannot be converted to integer
|
||||
"""
|
||||
str_value = self.get_str(key)
|
||||
if not str_value:
|
||||
if default is not None:
|
||||
return default
|
||||
raise ValueError(f"No value found for {key}")
|
||||
|
||||
try:
|
||||
return int(str_value)
|
||||
except ValueError:
|
||||
if default is not None:
|
||||
return default
|
||||
raise ValueError(f"Configuration {key}={str_value} is not a valid integer")
|
||||
|
||||
def get_bool(self, key: str, default: Optional[bool] = None) -> bool:
|
||||
"""Get a boolean configuration value.
|
||||
|
||||
Args:
|
||||
key: Configuration key
|
||||
default: Default value if not found
|
||||
|
||||
Returns:
|
||||
Configuration value as boolean
|
||||
"""
|
||||
str_value = self.get_str(key)
|
||||
if not str_value:
|
||||
if default is not None:
|
||||
return default
|
||||
return False
|
||||
|
||||
return str_value.lower() in ('true', '1', 'yes', 'on')
|
||||
|
||||
def get_list(self, key: str, delimiter: str = ',', default: Optional[list] = None) -> list:
|
||||
"""Get a comma-separated list configuration value.
|
||||
|
||||
Args:
|
||||
key: Configuration key
|
||||
delimiter: Delimiter between items (default: comma)
|
||||
default: Default value if not found
|
||||
|
||||
Returns:
|
||||
Configuration value as list of strings
|
||||
"""
|
||||
str_value = self.get_str(key)
|
||||
if not str_value:
|
||||
if default is not None:
|
||||
return default
|
||||
return []
|
||||
|
||||
return [item.strip() for item in str_value.split(delimiter) if item.strip()]
|
||||
|
||||
def validate_config(self) -> list:
|
||||
"""Validate SSH CA configuration.
|
||||
|
||||
Returns:
|
||||
List of validation error messages (empty if valid)
|
||||
"""
|
||||
errors = []
|
||||
|
||||
# Check cert validity hours
|
||||
try:
|
||||
validity = self.get_int('cert_validity_hours')
|
||||
max_validity = self.get_int('max_cert_validity_hours')
|
||||
if validity > max_validity:
|
||||
errors.append(
|
||||
f"cert_validity_hours ({validity}) > max_cert_validity_hours ({max_validity})"
|
||||
)
|
||||
except ValueError as e:
|
||||
errors.append(f"Invalid cert validity hours: {e}")
|
||||
|
||||
# Check principals limit
|
||||
max_principals = self.get_int('max_principals_per_cert')
|
||||
if max_principals > 256:
|
||||
errors.append(f"max_principals_per_cert ({max_principals}) exceeds SSH limit of 256")
|
||||
|
||||
# Check ca_key_path is set
|
||||
if not self.get_str('ca_key_path', '').strip():
|
||||
errors.append("ca_key_path is not set")
|
||||
|
||||
return errors
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Export current configuration as dictionary.
|
||||
"""
|
||||
return dict(self.config['default'])
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of configuration."""
|
||||
return f"<SSHCAConfig environment={self.environment} file={self.config_file}>"
|
||||
|
||||
|
||||
# Global configuration instance
|
||||
_config_instance = None
|
||||
|
||||
|
||||
def get_ssh_ca_config() -> SSHCAConfig:
|
||||
"""Get the global SSH CA configuration instance.
|
||||
|
||||
This function uses a singleton pattern to ensure only one
|
||||
configuration instance is created and reused.
|
||||
|
||||
Returns:
|
||||
SSHCAConfig instance
|
||||
"""
|
||||
global _config_instance
|
||||
if _config_instance is None:
|
||||
_config_instance = SSHCAConfig()
|
||||
return _config_instance
|
||||
|
||||
|
||||
def reset_config_instance():
|
||||
"""Reset the global configuration instance.
|
||||
"""
|
||||
global _config_instance
|
||||
_config_instance = None
|
||||
@@ -19,6 +19,21 @@ from gatehouse_app.exceptions.validation_exceptions import (
|
||||
OrganizationNotFoundError,
|
||||
UserNotFoundError,
|
||||
)
|
||||
from gatehouse_app.exceptions.ssh_exceptions import (
|
||||
SSHCAError,
|
||||
SSHKeyError,
|
||||
SSHKeyNotFoundError,
|
||||
SSHKeyAlreadyExistsError,
|
||||
SSHKeyNotVerifiedError,
|
||||
SSHCertificateError,
|
||||
SSHCertificateNotFoundError,
|
||||
CAError,
|
||||
CANotFoundError,
|
||||
PrincipalError,
|
||||
PrincipalNotFoundError,
|
||||
DepartmentError,
|
||||
DepartmentNotFoundError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseAPIException",
|
||||
@@ -37,4 +52,18 @@ __all__ = [
|
||||
"EmailAlreadyExistsError",
|
||||
"OrganizationNotFoundError",
|
||||
"UserNotFoundError",
|
||||
"SSHCAError",
|
||||
"SSHKeyError",
|
||||
"SSHKeyNotFoundError",
|
||||
"SSHKeyAlreadyExistsError",
|
||||
"SSHKeyNotVerifiedError",
|
||||
"SSHCertificateError",
|
||||
"SSHCertificateNotFoundError",
|
||||
"CAError",
|
||||
"CANotFoundError",
|
||||
"PrincipalError",
|
||||
"PrincipalNotFoundError",
|
||||
"DepartmentError",
|
||||
"DepartmentNotFoundError",
|
||||
]
|
||||
|
||||
|
||||
@@ -16,9 +16,10 @@ class BaseAPIException(Exception):
|
||||
message: Custom error message
|
||||
error_details: Additional error details dictionary
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(self.message)
|
||||
if message:
|
||||
self.message = message
|
||||
super().__init__(message) # update args so str(e) works
|
||||
self.error_details = error_details or {}
|
||||
|
||||
def to_dict(self):
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
"""SSH-specific exceptions."""
|
||||
from gatehouse_app.exceptions.base import BaseAPIException
|
||||
|
||||
|
||||
class SSHCAError(BaseAPIException):
|
||||
"""Base exception for SSH CA operations."""
|
||||
|
||||
status_code = 500
|
||||
error_type = "SSH_CA_ERROR"
|
||||
|
||||
|
||||
class SSHKeyError(BaseAPIException):
|
||||
"""Exception for SSH key operations."""
|
||||
|
||||
status_code = 400
|
||||
error_type = "SSH_KEY_ERROR"
|
||||
|
||||
|
||||
class SSHKeyNotFoundError(BaseAPIException):
|
||||
"""SSH key not found."""
|
||||
|
||||
status_code = 404
|
||||
error_type = "SSH_KEY_NOT_FOUND"
|
||||
|
||||
|
||||
class SSHKeyAlreadyExistsError(BaseAPIException):
|
||||
"""SSH key already exists (duplicate fingerprint)."""
|
||||
|
||||
status_code = 409
|
||||
error_type = "SSH_KEY_ALREADY_EXISTS"
|
||||
|
||||
|
||||
class SSHKeyNotVerifiedError(BaseAPIException):
|
||||
"""SSH key has not been verified."""
|
||||
|
||||
status_code = 400
|
||||
error_type = "SSH_KEY_NOT_VERIFIED"
|
||||
|
||||
|
||||
class SSHCertificateError(BaseAPIException):
|
||||
"""Exception for SSH certificate operations."""
|
||||
|
||||
status_code = 400
|
||||
error_type = "SSH_CERT_ERROR"
|
||||
|
||||
|
||||
class SSHCertificateNotFoundError(BaseAPIException):
|
||||
"""SSH certificate not found."""
|
||||
|
||||
status_code = 404
|
||||
error_type = "SSH_CERT_NOT_FOUND"
|
||||
|
||||
|
||||
class CAError(BaseAPIException):
|
||||
"""Exception for Certificate Authority operations."""
|
||||
|
||||
status_code = 400
|
||||
error_type = "CA_ERROR"
|
||||
|
||||
|
||||
class CANotFoundError(BaseAPIException):
|
||||
"""Certificate Authority not found."""
|
||||
|
||||
status_code = 404
|
||||
error_type = "CA_NOT_FOUND"
|
||||
|
||||
|
||||
class PrincipalError(BaseAPIException):
|
||||
"""Exception for principal operations."""
|
||||
|
||||
status_code = 400
|
||||
error_type = "PRINCIPAL_ERROR"
|
||||
|
||||
|
||||
class PrincipalNotFoundError(BaseAPIException):
|
||||
"""Principal not found."""
|
||||
|
||||
status_code = 404
|
||||
error_type = "PRINCIPAL_NOT_FOUND"
|
||||
|
||||
|
||||
class DepartmentError(BaseAPIException):
|
||||
"""Exception for department operations."""
|
||||
|
||||
status_code = 400
|
||||
error_type = "DEPARTMENT_ERROR"
|
||||
|
||||
|
||||
class DepartmentNotFoundError(BaseAPIException):
|
||||
"""Department not found."""
|
||||
|
||||
status_code = 404
|
||||
error_type = "DEPARTMENT_NOT_FOUND"
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Security headers middleware."""
|
||||
from flask import request
|
||||
import os
|
||||
from flask import current_app, request
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware:
|
||||
@@ -34,13 +35,22 @@ class SecurityHeadersMiddleware:
|
||||
)
|
||||
|
||||
# Content Security Policy
|
||||
try:
|
||||
flask_env = current_app.config.get("ENV") or os.environ.get("FLASK_ENV", "production")
|
||||
if flask_env == "development":
|
||||
connect_src = "connect-src 'self' http://localhost:5000 http://127.0.0.1:5000"
|
||||
else:
|
||||
connect_src = "connect-src 'self'"
|
||||
except RuntimeError:
|
||||
connect_src = "connect-src 'self'"
|
||||
|
||||
response.headers["Content-Security-Policy"] = (
|
||||
"default-src 'self'; "
|
||||
"script-src 'self' 'unsafe-inline'; "
|
||||
"style-src 'self' 'unsafe-inline'; "
|
||||
"img-src 'self' data: https:; "
|
||||
"font-src 'self' data:; "
|
||||
"connect-src 'self'"
|
||||
+ connect_src
|
||||
)
|
||||
|
||||
# Referrer Policy
|
||||
|
||||
@@ -1,43 +1,149 @@
|
||||
"""Models package."""
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.models.user import User
|
||||
from gatehouse_app.models.organization import Organization
|
||||
from gatehouse_app.models.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.authentication_method import (
|
||||
"""Models package.
|
||||
|
||||
Sub-packages
|
||||
------------
|
||||
models.user — User, Session
|
||||
models.organization — Organization, OrganizationMember, Department,
|
||||
DepartmentMembership, DepartmentPrincipal,
|
||||
DepartmentCertPolicy, Principal, PrincipalMembership,
|
||||
OrgInviteToken
|
||||
models.auth — AuthenticationMethod, ApplicationProviderConfig,
|
||||
OrganizationProviderOverride, OAuthState,
|
||||
AuditLog, PasswordResetToken, EmailVerificationToken
|
||||
models.oidc — OIDCClient, OIDCAuthCode, OIDCRefreshToken, OIDCSession,
|
||||
OIDCTokenMetadata, OIDCAuditLog, OidcJwksKey
|
||||
models.ssh_ca — CA, KeyType, CertType, CaType, CAPermission,
|
||||
SSHKey, SSHCertificate, CertificateStatus,
|
||||
CertificateAuditLog
|
||||
models.security — OrganizationSecurityPolicy, UserSecurityPolicy,
|
||||
MfaPolicyCompliance
|
||||
|
||||
All names are re-exported here so that existing code using the flat import
|
||||
style (``from gatehouse_app.models import X``) or the old per-file style
|
||||
(``from gatehouse_app.models.user import User``) continue to work unchanged.
|
||||
"""
|
||||
|
||||
# ── Base ──────────────────────────────────────────────────────────────────────
|
||||
from gatehouse_app.models.base import BaseModel # noqa: F401
|
||||
|
||||
# ── User ──────────────────────────────────────────────────────────────────────
|
||||
from gatehouse_app.models.user.user import User # noqa: F401
|
||||
from gatehouse_app.models.user.session import Session # noqa: F401
|
||||
|
||||
# ── Organization ──────────────────────────────────────────────────────────────
|
||||
from gatehouse_app.models.organization.organization import Organization # noqa: F401
|
||||
from gatehouse_app.models.organization.organization_member import ( # noqa: F401
|
||||
OrganizationMember,
|
||||
)
|
||||
from gatehouse_app.models.organization.department import ( # noqa: F401
|
||||
Department,
|
||||
DepartmentMembership,
|
||||
DepartmentPrincipal,
|
||||
)
|
||||
from gatehouse_app.models.organization.department_cert_policy import ( # noqa: F401
|
||||
DepartmentCertPolicy,
|
||||
STANDARD_EXTENSIONS,
|
||||
)
|
||||
from gatehouse_app.models.organization.principal import ( # noqa: F401
|
||||
Principal,
|
||||
PrincipalMembership,
|
||||
)
|
||||
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken # noqa: F401
|
||||
|
||||
# ── Auth ──────────────────────────────────────────────────────────────────────
|
||||
from gatehouse_app.models.auth.authentication_method import ( # noqa: F401
|
||||
AuthenticationMethod,
|
||||
ApplicationProviderConfig,
|
||||
OrganizationProviderOverride,
|
||||
OAuthState,
|
||||
)
|
||||
from gatehouse_app.models.session import Session
|
||||
from gatehouse_app.models.audit_log import AuditLog
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
from gatehouse_app.models.oidc_authorization_code import OIDCAuthCode
|
||||
from gatehouse_app.models.oidc_refresh_token import OIDCRefreshToken
|
||||
from gatehouse_app.models.oidc_session import OIDCSession
|
||||
from gatehouse_app.models.oidc_token_metadata import OIDCTokenMetadata
|
||||
from gatehouse_app.models.oidc_audit_log import OIDCAuditLog
|
||||
from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
|
||||
from gatehouse_app.models.user_security_policy import UserSecurityPolicy
|
||||
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
|
||||
from gatehouse_app.models.auth.audit_log import AuditLog # noqa: F401
|
||||
from gatehouse_app.models.auth.password_reset_token import PasswordResetToken # noqa: F401
|
||||
from gatehouse_app.models.auth.email_verification_token import ( # noqa: F401
|
||||
EmailVerificationToken,
|
||||
)
|
||||
|
||||
# ── OIDC ──────────────────────────────────────────────────────────────────────
|
||||
from gatehouse_app.models.oidc.oidc_client import OIDCClient # noqa: F401
|
||||
from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode # noqa: F401
|
||||
from gatehouse_app.models.oidc.oidc_refresh_token import OIDCRefreshToken # noqa: F401
|
||||
from gatehouse_app.models.oidc.oidc_session import OIDCSession # noqa: F401
|
||||
from gatehouse_app.models.oidc.oidc_token_metadata import OIDCTokenMetadata # noqa: F401
|
||||
from gatehouse_app.models.oidc.oidc_audit_log import OIDCAuditLog # noqa: F401
|
||||
from gatehouse_app.models.oidc.oidc_jwks_key import OidcJwksKey # noqa: F401
|
||||
|
||||
# ── SSH / CA ──────────────────────────────────────────────────────────────────
|
||||
from gatehouse_app.models.ssh_ca.ca import ( # noqa: F401
|
||||
CA,
|
||||
KeyType,
|
||||
CertType,
|
||||
CaType,
|
||||
CAPermission,
|
||||
)
|
||||
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey # noqa: F401
|
||||
from gatehouse_app.models.ssh_ca.ssh_certificate import ( # noqa: F401
|
||||
SSHCertificate,
|
||||
CertificateStatus,
|
||||
)
|
||||
from gatehouse_app.models.ssh_ca.certificate_audit_log import ( # noqa: F401
|
||||
CertificateAuditLog,
|
||||
)
|
||||
|
||||
# ── Security ──────────────────────────────────────────────────────────────────
|
||||
from gatehouse_app.models.security.organization_security_policy import ( # noqa: F401
|
||||
OrganizationSecurityPolicy,
|
||||
)
|
||||
from gatehouse_app.models.security.user_security_policy import ( # noqa: F401
|
||||
UserSecurityPolicy,
|
||||
)
|
||||
from gatehouse_app.models.security.mfa_policy_compliance import ( # noqa: F401
|
||||
MfaPolicyCompliance,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Base
|
||||
"BaseModel",
|
||||
# User
|
||||
"User",
|
||||
"Session",
|
||||
# Organization
|
||||
"Organization",
|
||||
"OrganizationMember",
|
||||
"Department",
|
||||
"DepartmentMembership",
|
||||
"DepartmentPrincipal",
|
||||
"DepartmentCertPolicy",
|
||||
"STANDARD_EXTENSIONS",
|
||||
"Principal",
|
||||
"PrincipalMembership",
|
||||
"OrgInviteToken",
|
||||
# Auth
|
||||
"AuthenticationMethod",
|
||||
"ApplicationProviderConfig",
|
||||
"OrganizationProviderOverride",
|
||||
"OAuthState",
|
||||
"Session",
|
||||
"AuditLog",
|
||||
"PasswordResetToken",
|
||||
"EmailVerificationToken",
|
||||
# OIDC
|
||||
"OIDCClient",
|
||||
"OIDCAuthCode",
|
||||
"OIDCRefreshToken",
|
||||
"OIDCSession",
|
||||
"OIDCTokenMetadata",
|
||||
"OIDCAuditLog",
|
||||
"OidcJwksKey",
|
||||
# SSH / CA
|
||||
"CA",
|
||||
"KeyType",
|
||||
"CertType",
|
||||
"CaType",
|
||||
"CAPermission",
|
||||
"SSHKey",
|
||||
"SSHCertificate",
|
||||
"CertificateStatus",
|
||||
"CertificateAuditLog",
|
||||
# Security
|
||||
"OrganizationSecurityPolicy",
|
||||
"UserSecurityPolicy",
|
||||
"MfaPolicyCompliance",
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
"""Auth subpackage — authentication methods, tokens, and audit logs."""
|
||||
from gatehouse_app.models.auth.authentication_method import (
|
||||
AuthenticationMethod,
|
||||
ApplicationProviderConfig,
|
||||
OrganizationProviderOverride,
|
||||
OAuthState,
|
||||
)
|
||||
from gatehouse_app.models.auth.audit_log import AuditLog
|
||||
from gatehouse_app.models.auth.password_reset_token import PasswordResetToken
|
||||
from gatehouse_app.models.auth.email_verification_token import EmailVerificationToken
|
||||
|
||||
__all__ = [
|
||||
"AuthenticationMethod",
|
||||
"ApplicationProviderConfig",
|
||||
"OrganizationProviderOverride",
|
||||
"OAuthState",
|
||||
"AuditLog",
|
||||
"PasswordResetToken",
|
||||
"EmailVerificationToken",
|
||||
]
|
||||
@@ -26,14 +26,13 @@ class AuditLog(BaseModel):
|
||||
extra_data = db.Column(db.JSON, nullable=True)
|
||||
description = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Success/failure
|
||||
# Outcome
|
||||
success = db.Column(db.Boolean, default=True, nullable=False)
|
||||
error_message = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship("User", back_populates="audit_logs")
|
||||
|
||||
# Indexes for common queries
|
||||
__table_args__ = (
|
||||
db.Index("idx_audit_user_action", "user_id", "action"),
|
||||
db.Index("idx_audit_resource", "resource_type", "resource_id"),
|
||||
@@ -45,9 +44,8 @@ class AuditLog(BaseModel):
|
||||
return f"<AuditLog action={self.action} user_id={self.user_id}>"
|
||||
|
||||
@classmethod
|
||||
def log(cls, action, user_id=None, **kwargs):
|
||||
"""
|
||||
Create an audit log entry.
|
||||
def log(cls, action, user_id=None, **kwargs) -> "AuditLog":
|
||||
"""Create an audit log entry.
|
||||
|
||||
Args:
|
||||
action: AuditAction enum value
|
||||
+101
-95
@@ -1,4 +1,4 @@
|
||||
"""Authentication method model."""
|
||||
"""Authentication method model — user credentials and OAuth provider config."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import secrets
|
||||
from gatehouse_app.extensions import db
|
||||
@@ -35,7 +35,6 @@ class AuthenticationMethod(BaseModel):
|
||||
# Relationships
|
||||
user = db.relationship("User", back_populates="authentication_methods")
|
||||
|
||||
# Ensure unique provider combinations
|
||||
__table_args__ = (
|
||||
db.Index("idx_user_method", "user_id", "method_type"),
|
||||
db.UniqueConstraint(
|
||||
@@ -45,13 +44,15 @@ class AuthenticationMethod(BaseModel):
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of AuthenticationMethod."""
|
||||
return f"<AuthenticationMethod user_id={self.user_id} type={self.method_type}>"
|
||||
return (
|
||||
f"<AuthenticationMethod user_id={self.user_id} type={self.method_type}>"
|
||||
)
|
||||
|
||||
def is_password(self):
|
||||
def is_password(self) -> bool:
|
||||
"""Check if this is a password authentication method."""
|
||||
return self.method_type == AuthMethodType.PASSWORD
|
||||
|
||||
def is_oauth(self):
|
||||
def is_oauth(self) -> bool:
|
||||
"""Check if this is an OAuth authentication method."""
|
||||
return self.method_type in [
|
||||
AuthMethodType.GOOGLE,
|
||||
@@ -59,32 +60,32 @@ class AuthenticationMethod(BaseModel):
|
||||
AuthMethodType.MICROSOFT,
|
||||
]
|
||||
|
||||
def is_totp(self):
|
||||
def is_totp(self) -> bool:
|
||||
"""Check if this is a TOTP authentication method."""
|
||||
return self.method_type == AuthMethodType.TOTP
|
||||
|
||||
def is_webauthn(self):
|
||||
def is_webauthn(self) -> bool:
|
||||
"""Check if this is a WebAuthn authentication method."""
|
||||
return self.method_type == AuthMethodType.WEBAUTHN
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude password hash and TOTP secrets
|
||||
exclude.append("password_hash")
|
||||
exclude.append("totp_secret")
|
||||
exclude.append("totp_backup_codes")
|
||||
# Always exclude credential material
|
||||
for field in ("password_hash", "totp_secret", "totp_backup_codes"):
|
||||
if field not in exclude:
|
||||
exclude.append(field)
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
def to_webauthn_dict(self):
|
||||
"""Convert WebAuthn credential to public dictionary.
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary with safe-to-expose credential information.
|
||||
Dictionary with safe-to-expose credential information, or None.
|
||||
"""
|
||||
if not self.is_webauthn() or not self.provider_data:
|
||||
return None
|
||||
|
||||
|
||||
data = self.provider_data
|
||||
return {
|
||||
"id": data.get("credential_id"),
|
||||
@@ -98,26 +99,26 @@ class AuthenticationMethod(BaseModel):
|
||||
|
||||
class ApplicationProviderConfig(BaseModel):
|
||||
"""Application-wide OAuth provider configuration.
|
||||
|
||||
This model stores OAuth provider credentials at the application level,
|
||||
allowing users to authenticate without needing to specify an organization first.
|
||||
|
||||
Stores OAuth provider credentials at the application level, allowing users
|
||||
to authenticate without needing to specify an organization first.
|
||||
"""
|
||||
|
||||
__tablename__ = "application_provider_configs"
|
||||
|
||||
# Provider identification
|
||||
provider_type = db.Column(db.String(50), nullable=False, unique=True, index=True)
|
||||
|
||||
# OAuth credentials (encrypted)
|
||||
|
||||
# OAuth credentials (client_secret encrypted at rest)
|
||||
client_id = db.Column(db.String(255), nullable=False)
|
||||
client_secret_encrypted = db.Column(db.String(512), nullable=True)
|
||||
|
||||
|
||||
# Provider status
|
||||
is_enabled = db.Column(db.Boolean, default=True, nullable=False)
|
||||
|
||||
|
||||
# Default redirect URL
|
||||
default_redirect_url = db.Column(db.String(2048), nullable=True)
|
||||
|
||||
|
||||
# Provider-specific settings (JSON)
|
||||
additional_config = db.Column(db.JSON, nullable=True)
|
||||
|
||||
@@ -126,28 +127,34 @@ class ApplicationProviderConfig(BaseModel):
|
||||
"OrganizationProviderOverride",
|
||||
back_populates="application_config",
|
||||
foreign_keys="OrganizationProviderOverride.provider_type",
|
||||
primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type",
|
||||
cascade="all, delete-orphan"
|
||||
primaryjoin=(
|
||||
"ApplicationProviderConfig.provider_type"
|
||||
"==OrganizationProviderOverride.provider_type"
|
||||
),
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of ApplicationProviderConfig."""
|
||||
return f"<ApplicationProviderConfig provider={self.provider_type} enabled={self.is_enabled}>"
|
||||
return (
|
||||
f"<ApplicationProviderConfig provider={self.provider_type} "
|
||||
f"enabled={self.is_enabled}>"
|
||||
)
|
||||
|
||||
def set_client_secret(self, plaintext_secret: str):
|
||||
def set_client_secret(self, plaintext_secret: str) -> None:
|
||||
"""Encrypt and store client secret.
|
||||
|
||||
|
||||
Args:
|
||||
plaintext_secret: The plaintext OAuth client secret
|
||||
"""
|
||||
if plaintext_secret:
|
||||
self.client_secret_encrypted = encrypt(plaintext_secret)
|
||||
|
||||
def get_client_secret(self) -> str:
|
||||
def get_client_secret(self) -> str | None:
|
||||
"""Decrypt and return client secret.
|
||||
|
||||
|
||||
Returns:
|
||||
The plaintext OAuth client secret
|
||||
The plaintext OAuth client secret, or None if not set.
|
||||
"""
|
||||
if self.client_secret_encrypted:
|
||||
return decrypt(self.client_secret_encrypted)
|
||||
@@ -156,37 +163,38 @@ class ApplicationProviderConfig(BaseModel):
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude encrypted client secret
|
||||
exclude.append("client_secret_encrypted")
|
||||
if "client_secret_encrypted" not in exclude:
|
||||
exclude.append("client_secret_encrypted")
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
class OrganizationProviderOverride(BaseModel):
|
||||
"""Organization-specific OAuth configuration overrides.
|
||||
|
||||
This model allows organizations to override application-level OAuth settings
|
||||
for enterprise SSO scenarios or custom provider configurations.
|
||||
|
||||
Allows organizations to override application-level OAuth settings for
|
||||
enterprise SSO scenarios or custom provider configurations.
|
||||
"""
|
||||
|
||||
__tablename__ = "organization_provider_overrides"
|
||||
|
||||
# References
|
||||
organization_id = db.Column(
|
||||
db.String(36), db.ForeignKey("organizations.id"),
|
||||
nullable=False, index=True
|
||||
db.String(36),
|
||||
db.ForeignKey("organizations.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
provider_type = db.Column(db.String(50), nullable=False, index=True)
|
||||
|
||||
# Override OAuth credentials (encrypted, nullable - only if overriding)
|
||||
|
||||
# Override OAuth credentials (encrypted, nullable — only set when overriding)
|
||||
client_id = db.Column(db.String(255), nullable=True)
|
||||
client_secret_encrypted = db.Column(db.String(512), nullable=True)
|
||||
|
||||
|
||||
# Provider status
|
||||
is_enabled = db.Column(db.Boolean, default=True, nullable=False)
|
||||
|
||||
|
||||
# Redirect URL override
|
||||
redirect_url_override = db.Column(db.String(2048), nullable=True)
|
||||
|
||||
|
||||
# Provider-specific settings override (JSON)
|
||||
additional_config = db.Column(db.JSON, nullable=True)
|
||||
|
||||
@@ -196,37 +204,33 @@ class OrganizationProviderOverride(BaseModel):
|
||||
"ApplicationProviderConfig",
|
||||
back_populates="organization_overrides",
|
||||
foreign_keys=[provider_type],
|
||||
primaryjoin="ApplicationProviderConfig.provider_type==OrganizationProviderOverride.provider_type",
|
||||
viewonly=True
|
||||
primaryjoin=(
|
||||
"ApplicationProviderConfig.provider_type"
|
||||
"==OrganizationProviderOverride.provider_type"
|
||||
),
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
# Unique constraint on (organization_id, provider_type)
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint(
|
||||
"organization_id", "provider_type",
|
||||
name="uix_org_provider_type"
|
||||
"organization_id", "provider_type", name="uix_org_provider_type"
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OrganizationProviderOverride."""
|
||||
return f"<OrganizationProviderOverride org={self.organization_id} provider={self.provider_type}>"
|
||||
return (
|
||||
f"<OrganizationProviderOverride org={self.organization_id} "
|
||||
f"provider={self.provider_type}>"
|
||||
)
|
||||
|
||||
def set_client_secret(self, plaintext_secret: str):
|
||||
"""Encrypt and store client secret override.
|
||||
|
||||
Args:
|
||||
plaintext_secret: The plaintext OAuth client secret
|
||||
"""
|
||||
def set_client_secret(self, plaintext_secret: str) -> None:
|
||||
"""Encrypt and store client secret override."""
|
||||
if plaintext_secret:
|
||||
self.client_secret_encrypted = encrypt(plaintext_secret)
|
||||
|
||||
def get_client_secret(self) -> str:
|
||||
"""Decrypt and return client secret override.
|
||||
|
||||
Returns:
|
||||
The plaintext OAuth client secret
|
||||
"""
|
||||
def get_client_secret(self) -> str | None:
|
||||
"""Decrypt and return client secret override."""
|
||||
if self.client_secret_encrypted:
|
||||
return decrypt(self.client_secret_encrypted)
|
||||
return None
|
||||
@@ -234,53 +238,52 @@ class OrganizationProviderOverride(BaseModel):
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude encrypted client secret
|
||||
exclude.append("client_secret_encrypted")
|
||||
if "client_secret_encrypted" not in exclude:
|
||||
exclude.append("client_secret_encrypted")
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
class OAuthState(BaseModel):
|
||||
"""OAuth flow state tracking.
|
||||
|
||||
This model tracks OAuth authentication flow state, including PKCE parameters
|
||||
and organization context (which is now optional to support login flows where
|
||||
the organization isn't known until after authentication).
|
||||
|
||||
Tracks OAuth authentication flow state, including PKCE parameters and
|
||||
organization context (which is optional to support login flows where the
|
||||
organization isn't known until after authentication).
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_states"
|
||||
|
||||
# OAuth state parameter (unique, used for CSRF protection)
|
||||
state = db.Column(db.String(64), unique=True, nullable=False, index=True)
|
||||
|
||||
|
||||
# Flow type: "login", "register", "link"
|
||||
flow_type = db.Column(db.String(50), nullable=False)
|
||||
|
||||
|
||||
# Provider type
|
||||
provider_type = db.Column(db.String(50), nullable=False)
|
||||
|
||||
# User context (optional - not set for login/register flows)
|
||||
|
||||
# User context (optional — not set for login/register flows)
|
||||
user_id = db.Column(db.String(36), db.ForeignKey("users.id"), nullable=True)
|
||||
|
||||
# Organization context (NOW OPTIONAL - for SSO discovery or post-auth)
|
||||
|
||||
# Organization context (optional — for SSO discovery or post-auth)
|
||||
organization_id = db.Column(
|
||||
db.String(36), db.ForeignKey("organizations.id"),
|
||||
nullable=True, index=True
|
||||
db.String(36), db.ForeignKey("organizations.id"), nullable=True, index=True
|
||||
)
|
||||
|
||||
|
||||
# PKCE parameters
|
||||
nonce = db.Column(db.String(128), nullable=True)
|
||||
code_verifier = db.Column(db.String(128), nullable=True)
|
||||
code_challenge = db.Column(db.String(128), nullable=True)
|
||||
|
||||
|
||||
# OAuth parameters
|
||||
redirect_uri = db.Column(db.String(2048), nullable=True)
|
||||
|
||||
|
||||
# Post-auth redirect (for frontend routing)
|
||||
return_url = db.Column(db.String(2048), nullable=True)
|
||||
|
||||
|
||||
# Additional state data
|
||||
extra_data = db.Column(db.JSON, nullable=True)
|
||||
|
||||
|
||||
# Expiration and usage tracking
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
used = db.Column(db.Boolean, default=False, nullable=False)
|
||||
@@ -291,7 +294,10 @@ class OAuthState(BaseModel):
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OAuthState."""
|
||||
return f"<OAuthState state={self.state[:8]}... flow={self.flow_type} provider={self.provider_type}>"
|
||||
return (
|
||||
f"<OAuthState state={self.state[:8]}... "
|
||||
f"flow={self.flow_type} provider={self.provider_type}>"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_state(
|
||||
@@ -306,10 +312,10 @@ class OAuthState(BaseModel):
|
||||
code_challenge: str = None,
|
||||
nonce: str = None,
|
||||
extra_data: dict = None,
|
||||
lifetime_seconds: int = 600
|
||||
):
|
||||
"""Create a new OAuth state with auto-generated state parameter.
|
||||
|
||||
lifetime_seconds: int = 600,
|
||||
) -> "OAuthState":
|
||||
"""Create a new OAuth state with an auto-generated state parameter.
|
||||
|
||||
Args:
|
||||
flow_type: Type of flow ("login", "register", "link")
|
||||
provider_type: OAuth provider type
|
||||
@@ -322,13 +328,13 @@ class OAuthState(BaseModel):
|
||||
nonce: OpenID Connect nonce
|
||||
extra_data: Additional state data
|
||||
lifetime_seconds: How long the state is valid (default 10 minutes)
|
||||
|
||||
|
||||
Returns:
|
||||
New OAuthState instance
|
||||
"""
|
||||
state = secrets.token_urlsafe(32)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(seconds=lifetime_seconds)
|
||||
|
||||
|
||||
oauth_state = cls(
|
||||
state=state,
|
||||
flow_type=flow_type,
|
||||
@@ -342,31 +348,30 @@ class OAuthState(BaseModel):
|
||||
nonce=nonce,
|
||||
extra_data=extra_data,
|
||||
expires_at=expires_at,
|
||||
used=False
|
||||
used=False,
|
||||
)
|
||||
oauth_state.save()
|
||||
return oauth_state
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the OAuth state is still valid.
|
||||
|
||||
|
||||
Returns:
|
||||
True if state hasn't expired and hasn't been used
|
||||
True if state hasn't expired and hasn't been used.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
# Make expires_at timezone-aware if it's naive (database returns naive datetimes)
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return not self.used and expires_at > now
|
||||
|
||||
def mark_used(self):
|
||||
def mark_used(self) -> None:
|
||||
"""Mark the state as used to prevent replay attacks."""
|
||||
self.used = True
|
||||
self.save()
|
||||
|
||||
@classmethod
|
||||
def cleanup_expired(cls):
|
||||
def cleanup_expired(cls) -> None:
|
||||
"""Remove expired OAuth states."""
|
||||
now = datetime.now(timezone.utc)
|
||||
cls.query.filter(cls.expires_at < now).delete()
|
||||
@@ -375,6 +380,7 @@ class OAuthState(BaseModel):
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Exclude code_verifier as it's sensitive
|
||||
exclude.append("code_verifier")
|
||||
# code_verifier must never be exposed
|
||||
if "code_verifier" not in exclude:
|
||||
exclude.append("code_verifier")
|
||||
return super().to_dict(exclude=exclude)
|
||||
@@ -0,0 +1,68 @@
|
||||
"""Email verification token model."""
|
||||
import secrets
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class EmailVerificationToken(BaseModel):
|
||||
"""Single-use token for verifying a user's email address."""
|
||||
|
||||
__tablename__ = "email_verification_tokens"
|
||||
|
||||
user_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
token = db.Column(db.String(128), unique=True, nullable=False, index=True)
|
||||
expires_at = db.Column(db.DateTime, nullable=False)
|
||||
used_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
user = db.relationship(
|
||||
"User",
|
||||
backref=db.backref("email_verification_tokens", cascade="all, delete-orphan"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate(cls, user_id: str, ttl_hours: int = 24) -> "EmailVerificationToken":
|
||||
"""Create a new verification token for a user.
|
||||
|
||||
Any existing unused tokens for this user are invalidated first.
|
||||
"""
|
||||
cls.query.filter_by(user_id=user_id, used_at=None).delete()
|
||||
db.session.flush()
|
||||
|
||||
token_value = secrets.token_urlsafe(48)
|
||||
instance = cls(
|
||||
user_id=user_id,
|
||||
token=token_value,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(hours=ttl_hours),
|
||||
)
|
||||
db.session.add(instance)
|
||||
db.session.commit()
|
||||
return instance
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
"""Return True if the token has not been used and has not expired."""
|
||||
if self.used_at is not None:
|
||||
return False
|
||||
now = datetime.now(timezone.utc)
|
||||
expires = self.expires_at
|
||||
if expires.tzinfo is None:
|
||||
expires = expires.replace(tzinfo=timezone.utc)
|
||||
return now < expires
|
||||
|
||||
def consume(self) -> None:
|
||||
"""Mark the token as used."""
|
||||
self.used_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<EmailVerificationToken user_id={self.user_id} "
|
||||
f"used={self.used_at is not None}>"
|
||||
)
|
||||
@@ -0,0 +1,69 @@
|
||||
"""Password reset token model."""
|
||||
import secrets
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class PasswordResetToken(BaseModel):
|
||||
"""Single-use token for resetting a user's password."""
|
||||
|
||||
__tablename__ = "password_reset_tokens"
|
||||
|
||||
user_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
token = db.Column(db.String(128), unique=True, nullable=False, index=True)
|
||||
expires_at = db.Column(db.DateTime, nullable=False)
|
||||
used_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
user = db.relationship(
|
||||
"User",
|
||||
backref=db.backref("password_reset_tokens", cascade="all, delete-orphan"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate(cls, user_id: str, ttl_hours: int = 2) -> "PasswordResetToken":
|
||||
"""Create a new password reset token for a user.
|
||||
|
||||
Any existing unused tokens for this user are invalidated first.
|
||||
"""
|
||||
# Invalidate any existing unused tokens for this user
|
||||
cls.query.filter_by(user_id=user_id, used_at=None).delete()
|
||||
db.session.flush()
|
||||
|
||||
token_value = secrets.token_urlsafe(48)
|
||||
instance = cls(
|
||||
user_id=user_id,
|
||||
token=token_value,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(hours=ttl_hours),
|
||||
)
|
||||
db.session.add(instance)
|
||||
db.session.commit()
|
||||
return instance
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
"""Return True if the token has not been used and has not expired."""
|
||||
if self.used_at is not None:
|
||||
return False
|
||||
now = datetime.now(timezone.utc)
|
||||
expires = self.expires_at
|
||||
if expires.tzinfo is None:
|
||||
expires = expires.replace(tzinfo=timezone.utc)
|
||||
return now < expires
|
||||
|
||||
def consume(self) -> None:
|
||||
"""Mark the token as used."""
|
||||
self.used_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<PasswordResetToken user_id={self.user_id} "
|
||||
f"used={self.used_at is not None}>"
|
||||
)
|
||||
@@ -82,7 +82,10 @@ class BaseModel(db.Model):
|
||||
if column.name not in exclude:
|
||||
value = getattr(self, column.name)
|
||||
if isinstance(value, datetime):
|
||||
result[column.name] = value.isoformat()
|
||||
if value.tzinfo is None:
|
||||
result[column.name] = value.isoformat() + "Z"
|
||||
else:
|
||||
result[column.name] = value.astimezone(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
else:
|
||||
result[column.name] = value
|
||||
return result
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
"""OIDC subpackage — clients, tokens, sessions, and audit logs."""
|
||||
from gatehouse_app.models.oidc.oidc_client import OIDCClient
|
||||
from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode
|
||||
from gatehouse_app.models.oidc.oidc_refresh_token import OIDCRefreshToken
|
||||
from gatehouse_app.models.oidc.oidc_session import OIDCSession
|
||||
from gatehouse_app.models.oidc.oidc_token_metadata import OIDCTokenMetadata
|
||||
from gatehouse_app.models.oidc.oidc_audit_log import OIDCAuditLog
|
||||
from gatehouse_app.models.oidc.oidc_jwks_key import OidcJwksKey
|
||||
|
||||
__all__ = [
|
||||
"OIDCClient",
|
||||
"OIDCAuthCode",
|
||||
"OIDCRefreshToken",
|
||||
"OIDCSession",
|
||||
"OIDCTokenMetadata",
|
||||
"OIDCAuditLog",
|
||||
"OidcJwksKey",
|
||||
]
|
||||
+83
-50
@@ -1,5 +1,4 @@
|
||||
"""OIDC Audit Log model for comprehensive OIDC event tracking."""
|
||||
from datetime import datetime
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
@@ -7,8 +6,7 @@ from gatehouse_app.models.base import BaseModel
|
||||
class OIDCAuditLog(BaseModel):
|
||||
"""OIDC Audit Log model for comprehensive OIDC event tracking.
|
||||
|
||||
This model logs all OIDC-related events for security, compliance,
|
||||
and debugging purposes.
|
||||
Logs all OIDC-related events for security, compliance, and debugging.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_audit_logs"
|
||||
@@ -46,16 +44,29 @@ class OIDCAuditLog(BaseModel):
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCAuditLog."""
|
||||
status = "success" if self.success else "failed"
|
||||
return f"<OIDCAuditLog event={self.event_type} status={status} client={self.client_id}>"
|
||||
return (
|
||||
f"<OIDCAuditLog event={self.event_type} "
|
||||
f"status={status} client={self.client_id}>"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_event(cls, event_type, client_id=None, user_id=None, success=True,
|
||||
error_code=None, error_description=None, ip_address=None,
|
||||
user_agent=None, request_id=None, event_metadata=None):
|
||||
def log_event(
|
||||
cls,
|
||||
event_type: str,
|
||||
client_id: str = None,
|
||||
user_id: str = None,
|
||||
success: bool = True,
|
||||
error_code: str = None,
|
||||
error_description: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
request_id: str = None,
|
||||
event_metadata: dict = None,
|
||||
) -> "OIDCAuditLog":
|
||||
"""Log an OIDC event.
|
||||
|
||||
Args:
|
||||
event_type: Type of event (e.g., "authorization_request", "token_issue")
|
||||
event_type: Type of event (e.g., "authorization_request")
|
||||
client_id: The OIDC client ID
|
||||
user_id: The user ID
|
||||
success: Whether the event was successful
|
||||
@@ -86,9 +97,19 @@ class OIDCAuditLog(BaseModel):
|
||||
return log
|
||||
|
||||
@classmethod
|
||||
def log_authorization_request(cls, client_id, user_id, redirect_uri, scope,
|
||||
ip_address=None, user_agent=None, request_id=None,
|
||||
success=True, error_code=None, error_description=None):
|
||||
def log_authorization_request(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
redirect_uri: str,
|
||||
scope,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
request_id: str = None,
|
||||
success: bool = True,
|
||||
error_code: str = None,
|
||||
error_description: str = None,
|
||||
) -> "OIDCAuditLog":
|
||||
"""Log an authorization request event."""
|
||||
return cls.log_event(
|
||||
event_type="authorization_request",
|
||||
@@ -100,15 +121,19 @@ class OIDCAuditLog(BaseModel):
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
event_metadata={
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": scope,
|
||||
}
|
||||
event_metadata={"redirect_uri": redirect_uri, "scope": scope},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_token_issue(cls, client_id, user_id, token_type,
|
||||
ip_address=None, user_agent=None, request_id=None):
|
||||
def log_token_issue(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
token_type: str,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
request_id: str = None,
|
||||
) -> "OIDCAuditLog":
|
||||
"""Log a token issuance event."""
|
||||
return cls.log_event(
|
||||
event_type="token_issue",
|
||||
@@ -118,12 +143,20 @@ class OIDCAuditLog(BaseModel):
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
event_metadata={"token_type": token_type}
|
||||
event_metadata={"token_type": token_type},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_token_revocation(cls, client_id, user_id, token_type, reason=None,
|
||||
ip_address=None, user_agent=None, request_id=None):
|
||||
def log_token_revocation(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
token_type: str,
|
||||
reason: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
request_id: str = None,
|
||||
) -> "OIDCAuditLog":
|
||||
"""Log a token revocation event."""
|
||||
return cls.log_event(
|
||||
event_type="token_revocation",
|
||||
@@ -133,15 +166,19 @@ class OIDCAuditLog(BaseModel):
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
event_metadata={
|
||||
"token_type": token_type,
|
||||
"reason": reason,
|
||||
}
|
||||
event_metadata={"token_type": token_type, "reason": reason},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log_authentication_failure(cls, client_id, error_code, error_description,
|
||||
ip_address=None, user_agent=None, request_id=None):
|
||||
def log_authentication_failure(
|
||||
cls,
|
||||
client_id: str,
|
||||
error_code: str,
|
||||
error_description: str,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
request_id: str = None,
|
||||
) -> "OIDCAuditLog":
|
||||
"""Log an authentication failure event."""
|
||||
return cls.log_event(
|
||||
event_type="authentication_failure",
|
||||
@@ -155,7 +192,7 @@ class OIDCAuditLog(BaseModel):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_events_for_user(cls, user_id, limit=100):
|
||||
def get_events_for_user(cls, user_id: str, limit: int = 100) -> list:
|
||||
"""Get audit events for a user.
|
||||
|
||||
Args:
|
||||
@@ -165,13 +202,15 @@ class OIDCAuditLog(BaseModel):
|
||||
Returns:
|
||||
List of OIDCAuditLog instances
|
||||
"""
|
||||
return cls.query.filter_by(user_id=user_id, deleted_at=None)\
|
||||
.order_by(cls.created_at.desc())\
|
||||
.limit(limit)\
|
||||
return (
|
||||
cls.query.filter_by(user_id=user_id, deleted_at=None)
|
||||
.order_by(cls.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_events_for_client(cls, client_id, limit=100):
|
||||
def get_events_for_client(cls, client_id: str, limit: int = 100) -> list:
|
||||
"""Get audit events for a client.
|
||||
|
||||
Args:
|
||||
@@ -181,14 +220,22 @@ class OIDCAuditLog(BaseModel):
|
||||
Returns:
|
||||
List of OIDCAuditLog instances
|
||||
"""
|
||||
return cls.query.filter_by(client_id=client_id, deleted_at=None)\
|
||||
.order_by(cls.created_at.desc())\
|
||||
.limit(limit)\
|
||||
return (
|
||||
cls.query.filter_by(client_id=client_id, deleted_at=None)
|
||||
.order_by(cls.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_failed_events(cls, client_id=None, user_id=None, start_date=None,
|
||||
end_date=None, limit=100):
|
||||
def get_failed_events(
|
||||
cls,
|
||||
client_id: str = None,
|
||||
user_id: str = None,
|
||||
start_date=None,
|
||||
end_date=None,
|
||||
limit: int = 100,
|
||||
) -> list:
|
||||
"""Get failed audit events.
|
||||
|
||||
Args:
|
||||
@@ -210,22 +257,8 @@ class OIDCAuditLog(BaseModel):
|
||||
query = query.filter(cls.created_at >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(cls.created_at <= end_date)
|
||||
|
||||
return query.order_by(cls.created_at.desc()).limit(limit).all()
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from gatehouse_app.models.user import User
|
||||
User.oidc_audit_logs = db.relationship(
|
||||
"OIDCAuditLog", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
OIDCClient.audit_logs = db.relationship(
|
||||
"OIDCAuditLog", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
+32
-34
@@ -1,14 +1,14 @@
|
||||
"""OIDC Authorization Code model for auth code flow."""
|
||||
"""OIDC Authorization Code model for the authorization code grant flow."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OIDCAuthCode(BaseModel):
|
||||
"""OIDC Authorization Code model for authorization code flow.
|
||||
"""OIDC Authorization Code model for the authorization code grant flow.
|
||||
|
||||
Authorization codes are single-use, short-lived codes used in the
|
||||
authorization code grant flow. The code is hashed for security.
|
||||
Authorization codes are single-use, short-lived codes. The code itself is
|
||||
hashed before storage so that a database breach cannot replay codes.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_authorization_codes"
|
||||
@@ -26,9 +26,9 @@ class OIDCAuthCode(BaseModel):
|
||||
|
||||
# Request parameters
|
||||
redirect_uri = db.Column(db.String(512), nullable=False)
|
||||
scope = db.Column(db.JSON, nullable=True) # Requested scopes
|
||||
nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation
|
||||
code_verifier = db.Column(db.String(255), nullable=True) # For PKCE
|
||||
scope = db.Column(db.JSON, nullable=True)
|
||||
nonce = db.Column(db.String(255), nullable=True)
|
||||
code_verifier = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Status tracking
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
@@ -39,37 +39,48 @@ class OIDCAuthCode(BaseModel):
|
||||
ip_address = db.Column(db.String(45), nullable=True)
|
||||
user_agent = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
# Relationships — back_populates declared on User and OIDCClient
|
||||
client = db.relationship("OIDCClient", back_populates="authorization_codes")
|
||||
user = db.relationship("User", back_populates="oidc_auth_codes")
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCAuthCode."""
|
||||
return f"<OIDCAuthCode client_id={self.client_id} user_id={self.user_id} used={self.is_used}>"
|
||||
return (
|
||||
f"<OIDCAuthCode client_id={self.client_id} "
|
||||
f"user_id={self.user_id} used={self.is_used}>"
|
||||
)
|
||||
|
||||
def is_expired(self):
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the authorization code has expired."""
|
||||
# Handle both timezone-aware and timezone-naive expires_at values
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
# Make naive datetime timezone-aware (UTC)
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return datetime.now(timezone.utc) > expires_at
|
||||
|
||||
def is_valid(self):
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the authorization code is valid for use."""
|
||||
return not self.is_used and not self.is_expired() and self.deleted_at is None
|
||||
|
||||
def mark_as_used(self):
|
||||
def mark_as_used(self) -> None:
|
||||
"""Mark the authorization code as used."""
|
||||
self.is_used = True
|
||||
self.used_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def create_code(cls, client_id, user_id, code_hash, redirect_uri, scope=None,
|
||||
nonce=None, code_verifier=None, ip_address=None, user_agent=None,
|
||||
lifetime_seconds=600):
|
||||
def create_code(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
code_hash: str,
|
||||
redirect_uri: str,
|
||||
scope=None,
|
||||
nonce: str = None,
|
||||
code_verifier: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
lifetime_seconds: int = 600,
|
||||
) -> "OIDCAuthCode":
|
||||
"""Create a new authorization code.
|
||||
|
||||
Args:
|
||||
@@ -79,7 +90,7 @@ class OIDCAuthCode(BaseModel):
|
||||
redirect_uri: The redirect URI
|
||||
scope: Requested scopes
|
||||
nonce: OIDC nonce
|
||||
code_verifier: PKCE code verifier
|
||||
code_verifier: PKCE code verifier (stored hashed server-side)
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
lifetime_seconds: Code lifetime in seconds (default 10 minutes)
|
||||
@@ -106,20 +117,7 @@ class OIDCAuthCode(BaseModel):
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude code hash
|
||||
exclude.append("code_hash")
|
||||
exclude.append("code_verifier")
|
||||
for field in ("code_hash", "code_verifier"):
|
||||
if field not in exclude:
|
||||
exclude.append(field)
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from gatehouse_app.models.user import User
|
||||
User.oidc_auth_codes = db.relationship(
|
||||
"OIDCAuthCode", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
OIDCClient.authorization_codes = db.relationship(
|
||||
"OIDCAuthCode", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -17,10 +17,10 @@ class OIDCClient(BaseModel):
|
||||
client_secret_hash = db.Column(db.String(255), nullable=False)
|
||||
|
||||
# OAuth/OIDC configuration
|
||||
redirect_uris = db.Column(db.JSON, nullable=False) # List of allowed redirect URIs
|
||||
grant_types = db.Column(db.JSON, nullable=False) # List of allowed grant types
|
||||
response_types = db.Column(db.JSON, nullable=False) # List of allowed response types
|
||||
scopes = db.Column(db.JSON, nullable=False) # List of allowed scopes
|
||||
redirect_uris = db.Column(db.JSON, nullable=False) # Allowed redirect URIs
|
||||
grant_types = db.Column(db.JSON, nullable=False) # Allowed grant types
|
||||
response_types = db.Column(db.JSON, nullable=False) # Allowed response types
|
||||
scopes = db.Column(db.JSON, nullable=False) # Allowed scopes
|
||||
|
||||
# Client metadata
|
||||
logo_uri = db.Column(db.String(512), nullable=True)
|
||||
@@ -41,6 +41,23 @@ class OIDCClient(BaseModel):
|
||||
# Relationships
|
||||
organization = db.relationship("Organization", back_populates="oidc_clients")
|
||||
|
||||
# OIDC sub-resource relationships (declared here, not monkey-patched elsewhere)
|
||||
authorization_codes = db.relationship(
|
||||
"OIDCAuthCode", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
refresh_tokens = db.relationship(
|
||||
"OIDCRefreshToken", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
oidc_sessions = db.relationship(
|
||||
"OIDCSession", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
token_metadata = db.relationship(
|
||||
"OIDCTokenMetadata", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
audit_logs = db.relationship(
|
||||
"OIDCAuditLog", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCClient."""
|
||||
return f"<OIDCClient {self.name} client_id={self.client_id}>"
|
||||
@@ -48,22 +65,22 @@ class OIDCClient(BaseModel):
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude client secret
|
||||
exclude.append("client_secret_hash")
|
||||
if "client_secret_hash" not in exclude:
|
||||
exclude.append("client_secret_hash")
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
def has_grant_type(self, grant_type):
|
||||
def has_grant_type(self, grant_type) -> bool:
|
||||
"""Check if client supports a specific grant type."""
|
||||
return grant_type in self.grant_types
|
||||
|
||||
def has_response_type(self, response_type):
|
||||
def has_response_type(self, response_type) -> bool:
|
||||
"""Check if client supports a specific response type."""
|
||||
return response_type in self.response_types
|
||||
|
||||
def is_redirect_uri_allowed(self, redirect_uri):
|
||||
def is_redirect_uri_allowed(self, redirect_uri: str) -> bool:
|
||||
"""Check if a redirect URI is allowed for this client."""
|
||||
return redirect_uri in self.redirect_uris
|
||||
|
||||
def has_scope(self, scope):
|
||||
def has_scope(self, scope: str) -> bool:
|
||||
"""Check if client is allowed to request a specific scope."""
|
||||
return scope in self.scopes
|
||||
@@ -0,0 +1,76 @@
|
||||
"""OIDC JWKS Key model for persisting signing keys."""
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OidcJwksKey(BaseModel):
|
||||
"""OIDC JWKS Key model for persisting JSON Web Key Set signing keys.
|
||||
|
||||
Stores RSA/ECDSA key pairs used for signing OIDC tokens. Multiple keys can
|
||||
be stored to support key rotation scenarios.
|
||||
|
||||
Attributes:
|
||||
kid: Unique key ID used in JWT ``kid`` header
|
||||
key_type: Type of key (e.g., "RSA", "EC")
|
||||
private_key: PEM-encoded private key (never exposed in API responses)
|
||||
public_key: PEM-encoded public key
|
||||
algorithm: Signing algorithm (e.g., "RS256", "ES256")
|
||||
is_active: Whether this key is currently used for signing/verification
|
||||
is_primary: Whether this is the primary signing key
|
||||
expires_at: Optional expiry for key rotation enforcement
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_jwks_keys"
|
||||
|
||||
# Override the default UUID id with integer primary key for JWKS key sets
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
|
||||
expires_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Key identification and type
|
||||
kid = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
key_type = db.Column(db.String(50), nullable=False) # e.g., "RSA", "EC"
|
||||
algorithm = db.Column(db.String(50), nullable=False) # e.g., "RS256", "ES256"
|
||||
|
||||
# Key material (PEM-encoded) — private_key must never be returned by API
|
||||
private_key = db.Column(db.Text, nullable=False)
|
||||
public_key = db.Column(db.Text, nullable=False)
|
||||
|
||||
# Key status
|
||||
is_active = db.Column(db.Boolean, default=True, nullable=False)
|
||||
is_primary = db.Column(db.Boolean, default=False, nullable=False)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OidcJwksKey."""
|
||||
return (
|
||||
f"<OidcJwksKey kid={self.kid} "
|
||||
f"key_type={self.key_type} algorithm={self.algorithm}>"
|
||||
)
|
||||
|
||||
def to_dict(self, exclude_private_key: bool = True):
|
||||
"""Convert model to dictionary.
|
||||
|
||||
Args:
|
||||
exclude_private_key: If True (default), excludes the private key.
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the model
|
||||
"""
|
||||
exclude = ["private_key"] if exclude_private_key else []
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
@classmethod
|
||||
def get_active_keys(cls) -> list:
|
||||
"""Get all active keys for signing operations."""
|
||||
return cls.query.filter_by(is_active=True).all()
|
||||
|
||||
@classmethod
|
||||
def get_primary_key(cls) -> "OidcJwksKey | None":
|
||||
"""Get the primary signing key."""
|
||||
return cls.query.filter_by(is_primary=True).first()
|
||||
|
||||
@classmethod
|
||||
def get_key_by_kid(cls, kid: str) -> "OidcJwksKey | None":
|
||||
"""Get an active key by its key ID."""
|
||||
return cls.query.filter_by(kid=kid, is_active=True).first()
|
||||
+33
-41
@@ -1,5 +1,5 @@
|
||||
"""OIDC Refresh Token model for token rotation."""
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
@@ -8,7 +8,8 @@ class OIDCRefreshToken(BaseModel):
|
||||
"""OIDC Refresh Token model for token refresh and rotation.
|
||||
|
||||
Refresh tokens are long-lived credentials used to obtain new access tokens.
|
||||
They support token rotation for enhanced security.
|
||||
They support token rotation for enhanced security — each use invalidates
|
||||
the old token and issues a new one.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_refresh_tokens"
|
||||
@@ -21,16 +22,14 @@ class OIDCRefreshToken(BaseModel):
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# Token (hashed for security)
|
||||
# Token (hashed for security — never store plaintext refresh tokens)
|
||||
token_hash = db.Column(db.String(255), nullable=False, unique=True, index=True)
|
||||
|
||||
# Associated access token ID (stores JWT JTI string — no FK to sessions)
|
||||
access_token_id = db.Column(
|
||||
db.String(255), nullable=True, index=True
|
||||
)
|
||||
# Associated access token JTI (no FK — stored as string for lightweight lookup)
|
||||
access_token_id = db.Column(db.String(255), nullable=True, index=True)
|
||||
|
||||
# Token scope
|
||||
scope = db.Column(db.JSON, nullable=True) # Granted scopes
|
||||
scope = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# Timing
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
@@ -40,7 +39,7 @@ class OIDCRefreshToken(BaseModel):
|
||||
revoked_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Token rotation metadata
|
||||
previous_token_hash = db.Column(db.String(255), nullable=True) # For rotation
|
||||
previous_token_hash = db.Column(db.String(255), nullable=True)
|
||||
rotation_count = db.Column(db.Integer, default=0, nullable=False)
|
||||
|
||||
# Request metadata
|
||||
@@ -53,25 +52,27 @@ class OIDCRefreshToken(BaseModel):
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCRefreshToken."""
|
||||
return f"<OIDCRefreshToken client_id={self.client_id} user_id={self.user_id} revoked={self.is_revoked()}>"
|
||||
return (
|
||||
f"<OIDCRefreshToken client_id={self.client_id} "
|
||||
f"user_id={self.user_id} revoked={self.is_revoked()}>"
|
||||
)
|
||||
|
||||
def is_expired(self):
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the refresh token has expired."""
|
||||
# Handle both timezone-aware and timezone-naive expires_at values
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return datetime.now(timezone.utc) > expires_at
|
||||
|
||||
def is_revoked(self):
|
||||
def is_revoked(self) -> bool:
|
||||
"""Check if the refresh token has been revoked."""
|
||||
return self.revoked_at is not None
|
||||
|
||||
def is_valid(self):
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the refresh token is valid for use."""
|
||||
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
|
||||
|
||||
def revoke(self, reason=None):
|
||||
def revoke(self, reason: str = None) -> None:
|
||||
"""Revoke the refresh token.
|
||||
|
||||
Args:
|
||||
@@ -81,8 +82,8 @@ class OIDCRefreshToken(BaseModel):
|
||||
self.revoked_reason = reason
|
||||
db.session.commit()
|
||||
|
||||
def rotate(self, new_token_hash):
|
||||
"""Rotate the refresh token (invalidate old, create new).
|
||||
def rotate(self, new_token_hash: str) -> "OIDCRefreshToken":
|
||||
"""Rotate the refresh token — invalidate the old hash, store the new one.
|
||||
|
||||
Args:
|
||||
new_token_hash: Hash of the new refresh token
|
||||
@@ -90,20 +91,25 @@ class OIDCRefreshToken(BaseModel):
|
||||
Returns:
|
||||
self for chaining
|
||||
"""
|
||||
# Store reference to old token
|
||||
self.previous_token_hash = self.token_hash
|
||||
self.token_hash = new_token_hash
|
||||
self.rotation_count += 1
|
||||
# Extend expiration on rotation
|
||||
from datetime import timedelta
|
||||
self.expires_at = datetime.now(timezone.utc) + timedelta(days=30)
|
||||
db.session.commit()
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def create_token(cls, client_id, user_id, token_hash, scope=None,
|
||||
access_token_id=None, ip_address=None, user_agent=None,
|
||||
lifetime_seconds=2592000):
|
||||
def create_token(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
token_hash: str,
|
||||
scope=None,
|
||||
access_token_id: str = None,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
lifetime_seconds: int = 2592000,
|
||||
) -> "OIDCRefreshToken":
|
||||
"""Create a new refresh token.
|
||||
|
||||
Args:
|
||||
@@ -111,7 +117,7 @@ class OIDCRefreshToken(BaseModel):
|
||||
user_id: The user ID
|
||||
token_hash: Hashed refresh token
|
||||
scope: Granted scopes
|
||||
access_token_id: Associated access token ID
|
||||
access_token_id: Associated access token JTI
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
lifetime_seconds: Token lifetime in seconds (default 30 days)
|
||||
@@ -119,7 +125,6 @@ class OIDCRefreshToken(BaseModel):
|
||||
Returns:
|
||||
OIDCRefreshToken instance
|
||||
"""
|
||||
from datetime import timedelta
|
||||
token = cls(
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
@@ -137,20 +142,7 @@ class OIDCRefreshToken(BaseModel):
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Always exclude token hashes
|
||||
exclude.append("token_hash")
|
||||
exclude.append("previous_token_hash")
|
||||
for field in ("token_hash", "previous_token_hash"):
|
||||
if field not in exclude:
|
||||
exclude.append(field)
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from gatehouse_app.models.user import User
|
||||
User.oidc_refresh_tokens = db.relationship(
|
||||
"OIDCRefreshToken", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
OIDCClient.refresh_tokens = db.relationship(
|
||||
"OIDCRefreshToken", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -1,5 +1,7 @@
|
||||
"""OIDC Session model for OIDC session tracking."""
|
||||
from datetime import datetime, timezone
|
||||
import hashlib
|
||||
import base64
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
@@ -7,8 +9,8 @@ from gatehouse_app.models.base import BaseModel
|
||||
class OIDCSession(BaseModel):
|
||||
"""OIDC Session model for tracking OIDC authentication sessions.
|
||||
|
||||
This model tracks the state during the OIDC authentication flow,
|
||||
including PKCE parameters and nonce validation.
|
||||
Tracks the state during the OIDC authorization flow, including PKCE
|
||||
parameters and nonce validation.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_sessions"
|
||||
@@ -25,11 +27,11 @@ class OIDCSession(BaseModel):
|
||||
|
||||
# State management
|
||||
state = db.Column(db.String(255), nullable=False, index=True)
|
||||
nonce = db.Column(db.String(255), nullable=True) # For OIDC ID Token validation
|
||||
nonce = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Authorization request parameters
|
||||
redirect_uri = db.Column(db.String(512), nullable=False)
|
||||
scope = db.Column(db.JSON, nullable=True) # Requested scopes
|
||||
scope = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# PKCE parameters
|
||||
code_challenge = db.Column(db.String(255), nullable=True)
|
||||
@@ -45,50 +47,52 @@ class OIDCSession(BaseModel):
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCSession."""
|
||||
return f"<OIDCSession user_id={self.user_id} client_id={self.client_id} state={self.state[:8]}...>"
|
||||
return (
|
||||
f"<OIDCSession user_id={self.user_id} "
|
||||
f"client_id={self.client_id} state={self.state[:8]}...>"
|
||||
)
|
||||
|
||||
def is_expired(self):
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the OIDC session has expired."""
|
||||
return datetime.now(timezone.utc) > self.expires_at
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return datetime.now(timezone.utc) > expires_at
|
||||
|
||||
def is_authenticated(self):
|
||||
def is_authenticated(self) -> bool:
|
||||
"""Check if the user has been authenticated in this session."""
|
||||
return self.authenticated_at is not None
|
||||
|
||||
def mark_authenticated(self):
|
||||
def mark_authenticated(self) -> None:
|
||||
"""Mark the session as authenticated."""
|
||||
self.authenticated_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
def validate_nonce(self, expected_nonce):
|
||||
def validate_nonce(self, expected_nonce: str) -> bool:
|
||||
"""Validate the nonce matches the expected value.
|
||||
|
||||
Args:
|
||||
expected_nonce: The expected nonce value
|
||||
|
||||
Returns:
|
||||
bool: True if nonce matches
|
||||
True if nonce matches
|
||||
"""
|
||||
return self.nonce == expected_nonce
|
||||
|
||||
def validate_code_challenge(self, code_verifier):
|
||||
def validate_code_challenge(self, code_verifier: str) -> bool:
|
||||
"""Validate the code verifier against the stored code challenge.
|
||||
|
||||
Args:
|
||||
code_verifier: The PKCE code verifier
|
||||
|
||||
Returns:
|
||||
bool: True if code challenge is valid
|
||||
True if the challenge is satisfied
|
||||
"""
|
||||
if not self.code_challenge:
|
||||
return False
|
||||
|
||||
if self.code_challenge_method == "S256":
|
||||
import hashlib
|
||||
import base64
|
||||
# SHA256 hash of code_verifier
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
# Base64 URL encode without padding
|
||||
expected = base64.urlsafe_b64encode(digest).decode().rstrip("=")
|
||||
return self.code_challenge == expected
|
||||
elif self.code_challenge_method == "plain":
|
||||
@@ -97,9 +101,18 @@ class OIDCSession(BaseModel):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def create_session(cls, user_id, client_id, state, redirect_uri, scope=None,
|
||||
nonce=None, code_challenge=None, code_challenge_method=None,
|
||||
lifetime_seconds=600):
|
||||
def create_session(
|
||||
cls,
|
||||
user_id: str,
|
||||
client_id: str,
|
||||
state: str,
|
||||
redirect_uri: str,
|
||||
scope=None,
|
||||
nonce: str = None,
|
||||
code_challenge: str = None,
|
||||
code_challenge_method: str = None,
|
||||
lifetime_seconds: int = 600,
|
||||
) -> "OIDCSession":
|
||||
"""Create a new OIDC session.
|
||||
|
||||
Args:
|
||||
@@ -116,7 +129,6 @@ class OIDCSession(BaseModel):
|
||||
Returns:
|
||||
OIDCSession instance
|
||||
"""
|
||||
from datetime import timedelta
|
||||
session = cls(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
@@ -133,7 +145,7 @@ class OIDCSession(BaseModel):
|
||||
return session
|
||||
|
||||
@classmethod
|
||||
def get_by_state(cls, state):
|
||||
def get_by_state(cls, state: str) -> "OIDCSession | None":
|
||||
"""Get a session by state parameter.
|
||||
|
||||
Args:
|
||||
@@ -147,16 +159,3 @@ class OIDCSession(BaseModel):
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from gatehouse_app.models.user import User
|
||||
User.oidc_sessions = db.relationship(
|
||||
"OIDCSession", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
OIDCClient.oidc_sessions = db.relationship(
|
||||
"OIDCSession", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
+49
-45
@@ -8,13 +8,14 @@ from gatehouse_app.models.base import BaseModel
|
||||
class OIDCTokenMetadata(BaseModel):
|
||||
"""OIDC Token Metadata model for tracking issued tokens.
|
||||
|
||||
This model stores metadata about issued tokens (access tokens, refresh tokens, ID tokens)
|
||||
for the purpose of token revocation. The id field matches the JTI (JWT ID) claim.
|
||||
Stores metadata about issued tokens (access, refresh, ID) for revocation.
|
||||
The ``id`` field on this model intentionally overrides the BaseModel UUID
|
||||
to store the JWT JTI directly as the primary key for O(1) revocation checks.
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_token_metadata"
|
||||
|
||||
# Token identifier (matches JTI in JWT)
|
||||
# Primary key = JTI so revocation lookups are always a PK scan
|
||||
id = db.Column(
|
||||
db.String(36), primary_key=True, default=lambda: str(uuid.uuid4())
|
||||
)
|
||||
@@ -27,11 +28,11 @@ class OIDCTokenMetadata(BaseModel):
|
||||
db.String(36), db.ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# Token type
|
||||
token_type = db.Column(db.String(50), nullable=False) # "access_token", "refresh_token", "id_token"
|
||||
# Token type: "access_token", "refresh_token", or "id_token"
|
||||
token_type = db.Column(db.String(50), nullable=False)
|
||||
|
||||
# Token identifier for revocation lookup
|
||||
token_jti = db.Column(db.String(255), nullable=False, index=True) # JWT ID claim
|
||||
# JWT ID claim (indexed for fast lookup when id != jti)
|
||||
token_jti = db.Column(db.String(255), nullable=False, index=True)
|
||||
|
||||
# Timing
|
||||
expires_at = db.Column(db.DateTime, nullable=False, index=True)
|
||||
@@ -46,25 +47,27 @@ class OIDCTokenMetadata(BaseModel):
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OIDCTokenMetadata."""
|
||||
return f"<OIDCTokenMetadata jti={self.token_jti[:8]}... type={self.token_type} revoked={self.is_revoked()}>"
|
||||
return (
|
||||
f"<OIDCTokenMetadata jti={self.token_jti[:8]}... "
|
||||
f"type={self.token_type} revoked={self.is_revoked()}>"
|
||||
)
|
||||
|
||||
def is_expired(self):
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the token has expired."""
|
||||
# Handle both timezone-aware and timezone-naive expires_at values
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return datetime.now(timezone.utc) > expires_at
|
||||
|
||||
def is_revoked(self):
|
||||
def is_revoked(self) -> bool:
|
||||
"""Check if the token has been revoked."""
|
||||
return self.revoked_at is not None
|
||||
|
||||
def is_valid(self):
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the token is valid (not expired and not revoked)."""
|
||||
return not self.is_revoked() and not self.is_expired() and self.deleted_at is None
|
||||
|
||||
def revoke(self, reason=None):
|
||||
def revoke(self, reason: str = None) -> None:
|
||||
"""Revoke the token.
|
||||
|
||||
Args:
|
||||
@@ -75,8 +78,16 @@ class OIDCTokenMetadata(BaseModel):
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def create_metadata(cls, client_id, user_id, token_type, token_jti,
|
||||
expires_at, ip_address=None, user_agent=None):
|
||||
def create_metadata(
|
||||
cls,
|
||||
client_id: str,
|
||||
user_id: str,
|
||||
token_type: str,
|
||||
token_jti: str,
|
||||
expires_at,
|
||||
ip_address: str = None,
|
||||
user_agent: str = None,
|
||||
) -> "OIDCTokenMetadata":
|
||||
"""Create token metadata for tracking.
|
||||
|
||||
Args:
|
||||
@@ -85,8 +96,8 @@ class OIDCTokenMetadata(BaseModel):
|
||||
token_type: Type of token ("access_token", "refresh_token", "id_token")
|
||||
token_jti: JWT ID claim
|
||||
expires_at: Token expiration datetime
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
ip_address: Client IP address (unused column, kept for API compat)
|
||||
user_agent: Client user agent (unused column, kept for API compat)
|
||||
|
||||
Returns:
|
||||
OIDCTokenMetadata instance
|
||||
@@ -104,7 +115,7 @@ class OIDCTokenMetadata(BaseModel):
|
||||
return metadata
|
||||
|
||||
@classmethod
|
||||
def get_by_jti(cls, token_jti):
|
||||
def get_by_jti(cls, token_jti: str) -> "OIDCTokenMetadata | None":
|
||||
"""Get token metadata by JWT ID.
|
||||
|
||||
Args:
|
||||
@@ -116,7 +127,7 @@ class OIDCTokenMetadata(BaseModel):
|
||||
return cls.query.filter_by(token_jti=token_jti, deleted_at=None).first()
|
||||
|
||||
@classmethod
|
||||
def revoke_by_jti(cls, token_jti, reason=None):
|
||||
def revoke_by_jti(cls, token_jti: str, reason: str = None) -> bool:
|
||||
"""Revoke a token by its JWT ID.
|
||||
|
||||
Args:
|
||||
@@ -124,7 +135,7 @@ class OIDCTokenMetadata(BaseModel):
|
||||
reason: Optional revocation reason
|
||||
|
||||
Returns:
|
||||
bool: True if token was found and revoked
|
||||
True if token was found and revoked, False otherwise
|
||||
"""
|
||||
metadata = cls.get_by_jti(token_jti)
|
||||
if metadata:
|
||||
@@ -133,47 +144,53 @@ class OIDCTokenMetadata(BaseModel):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def revoke_all_for_user(cls, user_id, client_id=None, reason=None):
|
||||
def revoke_all_for_user(
|
||||
cls, user_id: str, client_id: str = None, reason: str = None
|
||||
) -> int:
|
||||
"""Revoke all tokens for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
client_id: Optional client ID to filter by
|
||||
client_id: Optional client ID filter
|
||||
reason: Optional revocation reason
|
||||
|
||||
Returns:
|
||||
int: Number of tokens revoked
|
||||
Number of tokens revoked
|
||||
"""
|
||||
query = cls.query.filter_by(user_id=user_id, deleted_at=None)
|
||||
query = cls.query.filter_by(user_id=user_id, deleted_at=None).filter(
|
||||
cls.revoked_at.is_(None)
|
||||
)
|
||||
if client_id:
|
||||
query = query.filter_by(client_id=client_id)
|
||||
|
||||
tokens = query.filter(cls.revoked_at == None).all()
|
||||
count = 0
|
||||
for token in tokens:
|
||||
for token in query.all():
|
||||
token.revoke(reason)
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@classmethod
|
||||
def revoke_all_for_client(cls, client_id, user_id=None, reason=None):
|
||||
def revoke_all_for_client(
|
||||
cls, client_id: str, user_id: str = None, reason: str = None
|
||||
) -> int:
|
||||
"""Revoke all tokens for a client.
|
||||
|
||||
Args:
|
||||
client_id: The client ID
|
||||
user_id: Optional user ID to filter by
|
||||
user_id: Optional user ID filter
|
||||
reason: Optional revocation reason
|
||||
|
||||
Returns:
|
||||
int: Number of tokens revoked
|
||||
Number of tokens revoked
|
||||
"""
|
||||
query = cls.query.filter_by(client_id=client_id, deleted_at=None)
|
||||
query = cls.query.filter_by(client_id=client_id, deleted_at=None).filter(
|
||||
cls.revoked_at.is_(None)
|
||||
)
|
||||
if user_id:
|
||||
query = query.filter_by(user_id=user_id)
|
||||
|
||||
tokens = query.filter(cls.revoked_at == None).all()
|
||||
count = 0
|
||||
for token in tokens:
|
||||
for token in query.all():
|
||||
token.revoke(reason)
|
||||
count += 1
|
||||
return count
|
||||
@@ -181,16 +198,3 @@ class OIDCTokenMetadata(BaseModel):
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
|
||||
# Add relationship back to User model
|
||||
from gatehouse_app.models.user import User
|
||||
User.oidc_token_metadata = db.relationship(
|
||||
"OIDCTokenMetadata", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Add relationship back to OIDCClient model
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
OIDCClient.token_metadata = db.relationship(
|
||||
"OIDCTokenMetadata", back_populates="client", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -1,77 +0,0 @@
|
||||
"""OIDC JWKS Key model for persisting signing keys."""
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OidcJwksKey(BaseModel):
|
||||
"""
|
||||
OIDC JWKS Key model for persisting JSON Web Key Set signing keys.
|
||||
|
||||
This model stores RSA/ECDSA key pairs used for signing OIDC tokens.
|
||||
Multiple keys can be stored to support key rotation scenarios.
|
||||
|
||||
Attributes:
|
||||
id: Integer primary key
|
||||
kid: Unique key ID used in JWT "kid" header
|
||||
key_type: Type of key (e.g., "RSA", "EC")
|
||||
private_key: PEM-encoded private key
|
||||
public_key: PEM-encoded public key
|
||||
algorithm: Signing algorithm (e.g., "RS256", "ES256")
|
||||
created_at: When the key was created
|
||||
is_active: Whether this key is currently active for signing
|
||||
is_primary: Whether this is the primary signing key
|
||||
expires_at: ...
|
||||
"""
|
||||
|
||||
__tablename__ = "oidc_jwks_keys"
|
||||
|
||||
# Override the default UUID id with integer primary key
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
|
||||
expires_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Key identification and type
|
||||
kid = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
key_type = db.Column(db.String(50), nullable=False) # e.g., "RSA", "EC"
|
||||
algorithm = db.Column(db.String(50), nullable=False) # e.g., "RS256", "ES256"
|
||||
|
||||
# Key material (PEM-encoded)
|
||||
private_key = db.Column(db.Text, nullable=False)
|
||||
public_key = db.Column(db.Text, nullable=False)
|
||||
|
||||
# Key status
|
||||
is_active = db.Column(db.Boolean, default=True, nullable=False)
|
||||
is_primary = db.Column(db.Boolean, default=False, nullable=False)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OidcJwksKey."""
|
||||
return f"<OidcJwksKey kid={self.kid} key_type={self.key_type} algorithm={self.algorithm}>"
|
||||
|
||||
def to_dict(self, exclude_private_key=True):
|
||||
"""
|
||||
Convert model to dictionary.
|
||||
|
||||
Args:
|
||||
exclude_private_key: If True, excludes the private key from output
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the model
|
||||
"""
|
||||
exclude = ["private_key"] if exclude_private_key else []
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
@classmethod
|
||||
def get_active_keys(cls):
|
||||
"""Get all active keys for signing operations."""
|
||||
return cls.query.filter(cls.is_active == True).all()
|
||||
|
||||
@classmethod
|
||||
def get_primary_key(cls):
|
||||
"""Get the primary signing key."""
|
||||
return cls.query.filter(cls.is_primary == True).first()
|
||||
|
||||
@classmethod
|
||||
def get_key_by_kid(cls, kid):
|
||||
"""Get a key by its key ID."""
|
||||
return cls.query.filter(cls.kid == kid, cls.is_active == True).first()
|
||||
@@ -0,0 +1,27 @@
|
||||
"""Organization subpackage."""
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.organization.department import (
|
||||
Department,
|
||||
DepartmentMembership,
|
||||
DepartmentPrincipal,
|
||||
)
|
||||
from gatehouse_app.models.organization.department_cert_policy import (
|
||||
DepartmentCertPolicy,
|
||||
STANDARD_EXTENSIONS,
|
||||
)
|
||||
from gatehouse_app.models.organization.principal import Principal, PrincipalMembership
|
||||
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
|
||||
|
||||
__all__ = [
|
||||
"Organization",
|
||||
"OrganizationMember",
|
||||
"Department",
|
||||
"DepartmentMembership",
|
||||
"DepartmentPrincipal",
|
||||
"DepartmentCertPolicy",
|
||||
"STANDARD_EXTENSIONS",
|
||||
"Principal",
|
||||
"PrincipalMembership",
|
||||
"OrgInviteToken",
|
||||
]
|
||||
@@ -0,0 +1,196 @@
|
||||
"""Department, DepartmentMembership, and DepartmentPrincipal models."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class Department(BaseModel):
|
||||
"""Department model representing an organizational unit for SSH access control.
|
||||
|
||||
Departments are used to group users and assign SSH principals (access levels)
|
||||
to them. A user can be a member of multiple departments, and each department
|
||||
can have multiple principals assigned.
|
||||
|
||||
Example:
|
||||
- Department: "Engineering"
|
||||
- Members: user1@example.com, user2@example.com
|
||||
- Principals: "eng-prod", "eng-staging"
|
||||
- Users get access based on their principal assignments
|
||||
"""
|
||||
|
||||
__tablename__ = "departments"
|
||||
|
||||
organization_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("organizations.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
name = db.Column(db.String(255), nullable=False, index=True)
|
||||
description = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
organization = db.relationship("Organization", back_populates="departments")
|
||||
memberships = db.relationship(
|
||||
"DepartmentMembership",
|
||||
back_populates="department",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
principal_links = db.relationship(
|
||||
"DepartmentPrincipal",
|
||||
back_populates="department",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
cert_policy = db.relationship(
|
||||
"DepartmentCertPolicy",
|
||||
back_populates="department",
|
||||
uselist=False,
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint("organization_id", "name", name="uix_org_dept_name"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of Department."""
|
||||
return f"<Department {self.name} (org_id={self.organization_id})>"
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert department to dictionary."""
|
||||
exclude = exclude or []
|
||||
data = super().to_dict(exclude=exclude)
|
||||
data["member_count"] = len([m for m in self.memberships if m.deleted_at is None])
|
||||
data["principal_count"] = len([p for p in self.principal_links if p.deleted_at is None])
|
||||
return data
|
||||
|
||||
def get_members(self, active_only: bool = True):
|
||||
"""Get all members of this department.
|
||||
|
||||
Args:
|
||||
active_only: If True, exclude soft-deleted members
|
||||
|
||||
Returns:
|
||||
List of DepartmentMembership objects
|
||||
"""
|
||||
if active_only:
|
||||
return [m for m in self.memberships if m.deleted_at is None]
|
||||
return list(self.memberships)
|
||||
|
||||
def get_principals(self, active_only: bool = True):
|
||||
"""Get all principals assigned to this department.
|
||||
|
||||
Args:
|
||||
active_only: If True, exclude soft-deleted principals
|
||||
|
||||
Returns:
|
||||
List of Principal objects via DepartmentPrincipal
|
||||
"""
|
||||
if active_only:
|
||||
return [
|
||||
p.principal
|
||||
for p in self.principal_links
|
||||
if p.deleted_at is None and p.principal.deleted_at is None
|
||||
]
|
||||
return [p.principal for p in self.principal_links]
|
||||
|
||||
def is_member(self, user_id: str) -> bool:
|
||||
"""Check if a user is a member of this department.
|
||||
|
||||
Args:
|
||||
user_id: ID of the user to check
|
||||
|
||||
Returns:
|
||||
True if user is an active member, False otherwise
|
||||
"""
|
||||
return (
|
||||
DepartmentMembership.query.filter_by(
|
||||
user_id=user_id,
|
||||
department_id=self.id,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def get_member_count(self) -> int:
|
||||
"""Get the count of active members in this department."""
|
||||
return len(self.get_members(active_only=True))
|
||||
|
||||
|
||||
class DepartmentMembership(BaseModel):
|
||||
"""Department membership model representing user membership in a department.
|
||||
|
||||
When a user is added to a department, they become eligible for SSH principals
|
||||
assigned to that department.
|
||||
"""
|
||||
|
||||
__tablename__ = "department_memberships"
|
||||
|
||||
user_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("users.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
department_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("departments.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship("User", back_populates="department_memberships")
|
||||
department = db.relationship("Department", back_populates="memberships")
|
||||
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint("user_id", "department_id", name="uix_user_dept"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of DepartmentMembership."""
|
||||
return (
|
||||
f"<DepartmentMembership user_id={self.user_id} dept_id={self.department_id}>"
|
||||
)
|
||||
|
||||
|
||||
class DepartmentPrincipal(BaseModel):
|
||||
"""Department principal assignment model.
|
||||
|
||||
Represents the assignment of principals to departments. All members of a
|
||||
department get access to its assigned principals (transitively).
|
||||
|
||||
Example:
|
||||
- Department: "Engineering"
|
||||
- Principal: "eng-prod-servers"
|
||||
- All engineering department members can SSH as "eng-prod-servers"
|
||||
"""
|
||||
|
||||
__tablename__ = "department_principals"
|
||||
|
||||
department_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("departments.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
principal_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("principals.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
department = db.relationship("Department", back_populates="principal_links")
|
||||
principal = db.relationship("Principal", back_populates="department_links")
|
||||
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint("department_id", "principal_id", name="uix_dept_principal"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of DepartmentPrincipal."""
|
||||
return (
|
||||
f"<DepartmentPrincipal dept_id={self.department_id} "
|
||||
f"principal_id={self.principal_id}>"
|
||||
)
|
||||
@@ -0,0 +1,76 @@
|
||||
"""DepartmentCertPolicy — per-department SSH certificate issuance rules."""
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
# Standard SSH certificate extensions
|
||||
STANDARD_EXTENSIONS = [
|
||||
"permit-X11-forwarding",
|
||||
"permit-agent-forwarding",
|
||||
"permit-pty",
|
||||
"permit-port-forwarding",
|
||||
"permit-user-rc",
|
||||
]
|
||||
|
||||
|
||||
class DepartmentCertPolicy(BaseModel):
|
||||
"""SSH certificate policy for a department.
|
||||
|
||||
Controls:
|
||||
- Whether members may choose their own expiry date (up to ``max_expiry_hours``)
|
||||
- Default expiry hours when the user doesn't (or can't) pick
|
||||
- Maximum expiry hours (hard ceiling, even for admins signing on behalf)
|
||||
- Which SSH certificate extensions are granted to members of this department
|
||||
- Any custom extensions the admin wants to add beyond the standard five
|
||||
|
||||
Inherits ``id``, ``created_at``, ``updated_at``, and ``deleted_at`` from
|
||||
:class:`BaseModel` so soft-delete and the standard timestamp behaviour are
|
||||
consistent with every other model in the application.
|
||||
"""
|
||||
|
||||
__tablename__ = "department_cert_policies"
|
||||
|
||||
department_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("departments.id"),
|
||||
nullable=False,
|
||||
unique=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Expiry control
|
||||
allow_user_expiry = db.Column(db.Boolean, nullable=False, default=False)
|
||||
default_expiry_hours = db.Column(db.Integer, nullable=False, default=1)
|
||||
max_expiry_hours = db.Column(db.Integer, nullable=False, default=24)
|
||||
|
||||
# Extensions — list of extension name strings
|
||||
allowed_extensions = db.Column(
|
||||
db.JSON,
|
||||
nullable=False,
|
||||
default=lambda: list(STANDARD_EXTENSIONS),
|
||||
)
|
||||
# Admin-defined extras beyond the standard five
|
||||
custom_extensions = db.Column(db.JSON, nullable=False, default=list)
|
||||
|
||||
# Relationship back to department
|
||||
department = db.relationship("Department", back_populates="cert_policy", uselist=False)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"<DepartmentCertPolicy dept={self.department_id} "
|
||||
f"allow_user_expiry={self.allow_user_expiry}>"
|
||||
)
|
||||
|
||||
def all_extensions(self) -> list:
|
||||
"""Return the full list of enabled extensions (allowed + custom)."""
|
||||
return list((self.allowed_extensions or []) + (self.custom_extensions or []))
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
exclude = exclude or []
|
||||
data = super().to_dict(exclude=exclude)
|
||||
# Augment with computed / convenience fields not in the base columns
|
||||
data["all_extensions"] = self.all_extensions()
|
||||
data["standard_extensions"] = STANDARD_EXTENSIONS
|
||||
return data
|
||||
@@ -0,0 +1,77 @@
|
||||
"""Organization invite token model."""
|
||||
import secrets
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class OrgInviteToken(BaseModel):
|
||||
"""Token-based invitation to join an organization."""
|
||||
|
||||
__tablename__ = "org_invite_tokens"
|
||||
|
||||
organization_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("organizations.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
invited_by_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
email = db.Column(db.String(255), nullable=False, index=True)
|
||||
role = db.Column(db.String(64), nullable=False, default="member")
|
||||
token = db.Column(db.String(128), unique=True, nullable=False, index=True)
|
||||
expires_at = db.Column(db.DateTime, nullable=False)
|
||||
accepted_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
organization = db.relationship(
|
||||
"Organization",
|
||||
backref=db.backref("invite_tokens", cascade="all, delete-orphan"),
|
||||
)
|
||||
invited_by = db.relationship("User", foreign_keys=[invited_by_id])
|
||||
|
||||
@classmethod
|
||||
def generate(
|
||||
cls,
|
||||
organization_id: str,
|
||||
email: str,
|
||||
role: str = "member",
|
||||
invited_by_id: str = None,
|
||||
ttl_days: int = 7,
|
||||
) -> "OrgInviteToken":
|
||||
"""Create a new invite token for an organization."""
|
||||
token_value = secrets.token_urlsafe(48)
|
||||
instance = cls(
|
||||
organization_id=organization_id,
|
||||
email=email.lower(),
|
||||
role=role,
|
||||
invited_by_id=invited_by_id,
|
||||
token=token_value,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(days=ttl_days),
|
||||
)
|
||||
db.session.add(instance)
|
||||
db.session.commit()
|
||||
return instance
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
"""Return True if the token is unused and not expired."""
|
||||
if self.accepted_at is not None:
|
||||
return False
|
||||
now = datetime.now(timezone.utc)
|
||||
expires = self.expires_at
|
||||
if expires.tzinfo is None:
|
||||
expires = expires.replace(tzinfo=timezone.utc)
|
||||
return now < expires
|
||||
|
||||
def accept(self) -> None:
|
||||
"""Mark the invite as accepted."""
|
||||
self.accepted_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<OrgInviteToken org={self.organization_id} email={self.email}>"
|
||||
+11
-2
@@ -34,6 +34,15 @@ class Organization(BaseModel):
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="OrganizationSecurityPolicy.organization_id",
|
||||
)
|
||||
departments = db.relationship(
|
||||
"Department", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
principals = db.relationship(
|
||||
"Principal", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
cas = db.relationship(
|
||||
"CA", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of Organization."""
|
||||
@@ -52,9 +61,9 @@ class Organization(BaseModel):
|
||||
return member.user
|
||||
return None
|
||||
|
||||
def is_member(self, user_id):
|
||||
def is_member(self, user_id: str) -> bool:
|
||||
"""Check if a user is a member of the organization."""
|
||||
from gatehouse_app.models.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
|
||||
return (
|
||||
OrganizationMember.query.filter_by(
|
||||
+12
-8
@@ -1,4 +1,4 @@
|
||||
"""Organization member model."""
|
||||
"""Organization member model."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
@@ -21,31 +21,35 @@ class OrganizationMember(BaseModel):
|
||||
joined_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship("User", foreign_keys=[user_id], back_populates="organization_memberships")
|
||||
user = db.relationship(
|
||||
"User", foreign_keys=[user_id], back_populates="organization_memberships"
|
||||
)
|
||||
organization = db.relationship("Organization", back_populates="members")
|
||||
invited_by = db.relationship("User", foreign_keys=[invited_by_id])
|
||||
|
||||
# Unique constraint to prevent duplicate memberships
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint("user_id", "organization_id", name="uix_user_org"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OrganizationMember."""
|
||||
return f"<OrganizationMember user_id={self.user_id} org_id={self.organization_id} role={self.role}>"
|
||||
return (
|
||||
f"<OrganizationMember user_id={self.user_id} "
|
||||
f"org_id={self.organization_id} role={self.role}>"
|
||||
)
|
||||
|
||||
def is_owner(self):
|
||||
def is_owner(self) -> bool:
|
||||
"""Check if member is an owner."""
|
||||
return self.role == OrganizationRole.OWNER
|
||||
|
||||
def is_admin(self):
|
||||
def is_admin(self) -> bool:
|
||||
"""Check if member is an admin or owner."""
|
||||
return self.role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
||||
|
||||
def can_manage_members(self):
|
||||
def can_manage_members(self) -> bool:
|
||||
"""Check if member can manage other members."""
|
||||
return self.is_admin()
|
||||
|
||||
def can_delete_organization(self):
|
||||
def can_delete_organization(self) -> bool:
|
||||
"""Check if member can delete the organization."""
|
||||
return self.is_owner()
|
||||
@@ -0,0 +1,215 @@
|
||||
"""Principal and PrincipalMembership models."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class Principal(BaseModel):
|
||||
"""Principal model representing an SSH principal (access level/role).
|
||||
|
||||
In SSH CA terminology, a principal is a string like "eng-prod-servers" or
|
||||
"devops-admins" that represents a set of machines or access level. Users
|
||||
can be granted access to principals, either directly or via department
|
||||
membership.
|
||||
|
||||
Example:
|
||||
- Principal: "eng-prod-servers"
|
||||
- Users with this principal can SSH to prod servers
|
||||
- Can be assigned to departments or directly to users
|
||||
"""
|
||||
|
||||
__tablename__ = "principals"
|
||||
|
||||
organization_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("organizations.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
name = db.Column(db.String(255), nullable=False, index=True)
|
||||
description = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
organization = db.relationship("Organization", back_populates="principals")
|
||||
memberships = db.relationship(
|
||||
"PrincipalMembership",
|
||||
back_populates="principal",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
department_links = db.relationship(
|
||||
"DepartmentPrincipal",
|
||||
back_populates="principal",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint("organization_id", "name", name="uix_org_principal_name"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of Principal."""
|
||||
return f"<Principal {self.name} (org_id={self.organization_id})>"
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert principal to dictionary."""
|
||||
exclude = exclude or []
|
||||
data = super().to_dict(exclude=exclude)
|
||||
data["direct_member_count"] = len(
|
||||
[m for m in self.memberships if m.deleted_at is None]
|
||||
)
|
||||
data["department_count"] = len(
|
||||
[d for d in self.department_links if d.deleted_at is None]
|
||||
)
|
||||
return data
|
||||
|
||||
def get_members(self, active_only: bool = True):
|
||||
"""Get all users who are directly assigned to this principal.
|
||||
|
||||
Does NOT include users who get access via department membership.
|
||||
|
||||
Args:
|
||||
active_only: If True, exclude soft-deleted members
|
||||
|
||||
Returns:
|
||||
List of PrincipalMembership objects
|
||||
"""
|
||||
if active_only:
|
||||
return [m for m in self.memberships if m.deleted_at is None]
|
||||
return list(self.memberships)
|
||||
|
||||
def get_all_members(self, active_only: bool = True):
|
||||
"""Get all users who have access to this principal.
|
||||
|
||||
Includes both direct members and users via department membership.
|
||||
|
||||
Args:
|
||||
active_only: If True, exclude soft-deleted members
|
||||
|
||||
Returns:
|
||||
Set of User objects with access to this principal
|
||||
"""
|
||||
all_users: set = set()
|
||||
|
||||
# Direct members
|
||||
for membership in self.get_members(active_only=active_only):
|
||||
if not active_only or membership.user.deleted_at is None:
|
||||
all_users.add(membership.user)
|
||||
|
||||
# Members via department assignment
|
||||
for dept_link in self.department_links:
|
||||
if dept_link.deleted_at is None or not active_only:
|
||||
for dept_member in dept_link.department.get_members(active_only=active_only):
|
||||
if not active_only or dept_member.user.deleted_at is None:
|
||||
all_users.add(dept_member.user)
|
||||
|
||||
return all_users
|
||||
|
||||
def get_departments(self, active_only: bool = True):
|
||||
"""Get all departments this principal is assigned to.
|
||||
|
||||
Args:
|
||||
active_only: If True, exclude soft-deleted departments
|
||||
|
||||
Returns:
|
||||
List of Department objects
|
||||
"""
|
||||
if active_only:
|
||||
return [
|
||||
d.department
|
||||
for d in self.department_links
|
||||
if d.deleted_at is None and d.department.deleted_at is None
|
||||
]
|
||||
return [d.department for d in self.department_links]
|
||||
|
||||
def is_member(self, user_id: str, include_via_department: bool = True) -> bool:
|
||||
"""Check if a user has access to this principal.
|
||||
|
||||
Args:
|
||||
user_id: ID of the user to check
|
||||
include_via_department: If True, check department memberships too
|
||||
|
||||
Returns:
|
||||
True if user has access to this principal
|
||||
"""
|
||||
# Check direct membership
|
||||
has_direct = (
|
||||
PrincipalMembership.query.filter_by(
|
||||
user_id=user_id,
|
||||
principal_id=self.id,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
if has_direct:
|
||||
return True
|
||||
|
||||
if not include_via_department:
|
||||
return False
|
||||
|
||||
# Check department membership
|
||||
dept_ids = [d.id for d in self.get_departments(active_only=True)]
|
||||
if not dept_ids:
|
||||
return False
|
||||
|
||||
from gatehouse_app.models.organization.department import DepartmentMembership
|
||||
|
||||
return (
|
||||
DepartmentMembership.query.filter(
|
||||
DepartmentMembership.user_id == user_id,
|
||||
DepartmentMembership.department_id.in_(dept_ids),
|
||||
DepartmentMembership.deleted_at.is_(None),
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def get_member_count(self, include_via_department: bool = True) -> int:
|
||||
"""Get the count of active members with access to this principal.
|
||||
|
||||
Args:
|
||||
include_via_department: If True, include members via department
|
||||
|
||||
Returns:
|
||||
Count of members
|
||||
"""
|
||||
if not include_via_department:
|
||||
return len(self.get_members(active_only=True))
|
||||
return len(self.get_all_members(active_only=True))
|
||||
|
||||
|
||||
class PrincipalMembership(BaseModel):
|
||||
"""Principal membership model representing direct user assignment to a principal.
|
||||
|
||||
When a user is assigned directly to a principal, they get access to that
|
||||
principal for SSH authentication. This is in addition to any principals
|
||||
they get via department membership.
|
||||
"""
|
||||
|
||||
__tablename__ = "principal_memberships"
|
||||
|
||||
user_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("users.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
principal_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("principals.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship("User", back_populates="principal_memberships")
|
||||
principal = db.relationship("Principal", back_populates="memberships")
|
||||
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint("user_id", "principal_id", name="uix_user_principal"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of PrincipalMembership."""
|
||||
return (
|
||||
f"<PrincipalMembership user_id={self.user_id} "
|
||||
f"principal_id={self.principal_id}>"
|
||||
)
|
||||
@@ -0,0 +1,12 @@
|
||||
"""Security subpackage — organization and user security policies, MFA compliance."""
|
||||
from gatehouse_app.models.security.organization_security_policy import (
|
||||
OrganizationSecurityPolicy,
|
||||
)
|
||||
from gatehouse_app.models.security.user_security_policy import UserSecurityPolicy
|
||||
from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance
|
||||
|
||||
__all__ = [
|
||||
"OrganizationSecurityPolicy",
|
||||
"UserSecurityPolicy",
|
||||
"MfaPolicyCompliance",
|
||||
]
|
||||
+11
-10
@@ -1,4 +1,4 @@
|
||||
"""MfaPolicyCompliance model."""
|
||||
"""MfaPolicyCompliance model — per-user per-organization MFA compliance tracking."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import MfaComplianceStatus
|
||||
@@ -7,7 +7,8 @@ from gatehouse_app.utils.constants import MfaComplianceStatus
|
||||
class MfaPolicyCompliance(BaseModel):
|
||||
"""MFA policy compliance tracking per user per organization.
|
||||
|
||||
Tracks each user's MFA compliance state separately for each organization membership.
|
||||
Tracks each user's MFA compliance state separately for each organization
|
||||
membership. One row per (user, org) pair.
|
||||
"""
|
||||
|
||||
__tablename__ = "mfa_policy_compliance"
|
||||
@@ -25,13 +26,13 @@ class MfaPolicyCompliance(BaseModel):
|
||||
default=MfaComplianceStatus.NOT_APPLICABLE,
|
||||
)
|
||||
|
||||
# Snapshot of org policy at the time this record became active
|
||||
# Snapshot of org policy version when this record became active
|
||||
policy_version = db.Column(db.Integer, nullable=False)
|
||||
|
||||
# When policy started applying to this user
|
||||
applied_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Final deadline for this user to comply (per user, not global)
|
||||
# Final deadline for this user to comply
|
||||
deadline_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# When they became compliant under this policy_version
|
||||
@@ -45,9 +46,7 @@ class MfaPolicyCompliance(BaseModel):
|
||||
notification_count = db.Column(db.Integer, nullable=False, default=0)
|
||||
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint(
|
||||
"user_id", "organization_id", name="uix_user_org_compliance"
|
||||
),
|
||||
db.UniqueConstraint("user_id", "organization_id", name="uix_user_org_compliance"),
|
||||
)
|
||||
|
||||
# Relationships
|
||||
@@ -58,9 +57,11 @@ class MfaPolicyCompliance(BaseModel):
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of MfaPolicyCompliance."""
|
||||
return f"<MfaPolicyCompliance user={self.user_id} org={self.organization_id} status={self.status}>"
|
||||
return (
|
||||
f"<MfaPolicyCompliance user={self.user_id} "
|
||||
f"org={self.organization_id} status={self.status}>"
|
||||
)
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
exclude = exclude or []
|
||||
return super().to_dict(exclude=exclude)
|
||||
return super().to_dict(exclude=exclude or [])
|
||||
+8
-4
@@ -39,15 +39,19 @@ class OrganizationSecurityPolicy(BaseModel):
|
||||
|
||||
# Relationships
|
||||
organization = db.relationship(
|
||||
"Organization", back_populates="security_policy", foreign_keys=[organization_id]
|
||||
"Organization",
|
||||
back_populates="security_policy",
|
||||
foreign_keys=[organization_id],
|
||||
)
|
||||
updated_by_user = db.relationship("User", foreign_keys=[updated_by_user_id])
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of OrganizationSecurityPolicy."""
|
||||
return f"<OrganizationSecurityPolicy org={self.organization_id} mode={self.mfa_policy_mode}>"
|
||||
return (
|
||||
f"<OrganizationSecurityPolicy "
|
||||
f"org={self.organization_id} mode={self.mfa_policy_mode}>"
|
||||
)
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
exclude = exclude or []
|
||||
return super().to_dict(exclude=exclude)
|
||||
return super().to_dict(exclude=exclude or [])
|
||||
+10
-12
@@ -1,4 +1,4 @@
|
||||
"""UserSecurityPolicy model."""
|
||||
"""UserSecurityPolicy model — per-user MFA overrides."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import MfaRequirementOverride
|
||||
@@ -7,7 +7,7 @@ from gatehouse_app.utils.constants import MfaRequirementOverride
|
||||
class UserSecurityPolicy(BaseModel):
|
||||
"""User security policy model for per-user MFA overrides.
|
||||
|
||||
Stores per user overrides of organization level MFA requirements.
|
||||
Stores per-user overrides of organization-level MFA requirements.
|
||||
"""
|
||||
|
||||
__tablename__ = "user_security_policies"
|
||||
@@ -25,29 +25,27 @@ class UserSecurityPolicy(BaseModel):
|
||||
default=MfaRequirementOverride.INHERIT,
|
||||
)
|
||||
|
||||
# If override is REQUIRED and you want to force a specific factor set
|
||||
# If override is REQUIRED, optionally force a specific factor set
|
||||
force_totp = db.Column(db.Boolean, nullable=False, default=False)
|
||||
force_webauthn = db.Column(db.Boolean, nullable=False, default=False)
|
||||
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint(
|
||||
"user_id", "organization_id", name="uix_user_org_policy"
|
||||
),
|
||||
db.UniqueConstraint("user_id", "organization_id", name="uix_user_org_policy"),
|
||||
)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship(
|
||||
"User", back_populates="security_policies", foreign_keys=[user_id]
|
||||
)
|
||||
organization = db.relationship(
|
||||
"Organization", foreign_keys=[organization_id]
|
||||
)
|
||||
organization = db.relationship("Organization", foreign_keys=[organization_id])
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of UserSecurityPolicy."""
|
||||
return f"<UserSecurityPolicy user={self.user_id} org={self.organization_id} mode={self.mfa_override_mode}>"
|
||||
return (
|
||||
f"<UserSecurityPolicy user={self.user_id} "
|
||||
f"org={self.organization_id} mode={self.mfa_override_mode}>"
|
||||
)
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
exclude = exclude or []
|
||||
return super().to_dict(exclude=exclude)
|
||||
return super().to_dict(exclude=exclude or [])
|
||||
@@ -0,0 +1,17 @@
|
||||
"""SSH/CA subpackage — certificate authorities, SSH keys, and certificates."""
|
||||
from gatehouse_app.models.ssh_ca.ca import CA, KeyType, CertType, CaType, CAPermission
|
||||
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey
|
||||
from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus
|
||||
from gatehouse_app.models.ssh_ca.certificate_audit_log import CertificateAuditLog
|
||||
|
||||
__all__ = [
|
||||
"CA",
|
||||
"KeyType",
|
||||
"CertType",
|
||||
"CaType",
|
||||
"CAPermission",
|
||||
"SSHKey",
|
||||
"SSHCertificate",
|
||||
"CertificateStatus",
|
||||
"CertificateAuditLog",
|
||||
]
|
||||
@@ -0,0 +1,238 @@
|
||||
"""Certificate Authority (CA) model."""
|
||||
from enum import Enum
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class KeyType(str, Enum):
|
||||
"""SSH CA key types."""
|
||||
|
||||
ED25519 = "ed25519"
|
||||
RSA = "rsa"
|
||||
ECDSA = "ecdsa"
|
||||
|
||||
|
||||
class CertType(str, Enum):
|
||||
"""SSH certificate types."""
|
||||
|
||||
USER = "user"
|
||||
HOST = "host"
|
||||
|
||||
|
||||
class CaType(str, Enum):
|
||||
"""CA signing type — whether this CA signs user or host certificates."""
|
||||
|
||||
USER = "user"
|
||||
HOST = "host"
|
||||
|
||||
|
||||
class CA(BaseModel):
|
||||
"""Certificate Authority (CA) model for SSH certificate signing.
|
||||
|
||||
Each organization can have multiple CAs for different purposes
|
||||
(e.g., production vs. staging). Private keys are encrypted at rest
|
||||
and should be protected with KMS.
|
||||
"""
|
||||
|
||||
__tablename__ = "cas"
|
||||
|
||||
organization_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("organizations.id"),
|
||||
nullable=True, # NULL for the global system-config CA
|
||||
index=True,
|
||||
)
|
||||
|
||||
# CA identity
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
description = db.Column(db.Text, nullable=True)
|
||||
|
||||
# CA signing type: 'user' signs user certificates, 'host' signs host certs
|
||||
ca_type = db.Column(
|
||||
db.Enum(CaType, values_callable=lambda x: [e.value for e in x]),
|
||||
default=CaType.USER,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Key type (ED25519, RSA, ECDSA)
|
||||
key_type = db.Column(
|
||||
db.Enum(KeyType, values_callable=lambda x: [e.value for e in x]),
|
||||
default=KeyType.ED25519,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Private key — PEM-encoded, encrypted at rest by database/KMS
|
||||
private_key = db.Column(db.Text, nullable=False)
|
||||
|
||||
# Public key — PEM format
|
||||
public_key = db.Column(db.Text, nullable=False)
|
||||
|
||||
# SHA256 fingerprint of the public key
|
||||
fingerprint = db.Column(db.String(255), nullable=False, unique=True)
|
||||
|
||||
# CRL (Certificate Revocation List) configuration
|
||||
crl_enabled = db.Column(db.Boolean, default=True, nullable=False)
|
||||
crl_endpoint = db.Column(db.String(512), nullable=True)
|
||||
|
||||
# Default certificate validity in hours (overridable per request)
|
||||
default_cert_validity_hours = db.Column(db.Integer, default=1, nullable=False)
|
||||
|
||||
# Maximum validity duration allowed
|
||||
max_cert_validity_hours = db.Column(db.Integer, default=24, nullable=False)
|
||||
|
||||
# CA status
|
||||
is_active = db.Column(db.Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
# Key rotation tracking
|
||||
rotated_at = db.Column(db.DateTime, nullable=True)
|
||||
rotation_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Monotonically-increasing serial counter. Every cert this CA issues
|
||||
# gets the next value so serials are unique, ordered, and auditable.
|
||||
# Protected by a row-level SELECT … FOR UPDATE in get_next_serial().
|
||||
next_serial_number = db.Column(db.BigInteger, default=1, nullable=False)
|
||||
|
||||
# Relationships
|
||||
organization = db.relationship("Organization", back_populates="cas")
|
||||
certificates = db.relationship(
|
||||
"SSHCertificate",
|
||||
back_populates="ca",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
permissions = db.relationship(
|
||||
"CAPermission",
|
||||
back_populates="ca",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
db.Index("idx_ca_org_active", "organization_id", "is_active"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of CA."""
|
||||
return (
|
||||
f"<CA {self.name} "
|
||||
f"(org_id={self.organization_id}, type={self.key_type})>"
|
||||
)
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert CA to dictionary, never exposing the private key."""
|
||||
exclude = exclude or []
|
||||
if "private_key" not in exclude:
|
||||
exclude.append("private_key")
|
||||
data = super().to_dict(exclude=exclude)
|
||||
|
||||
# Add computed fields
|
||||
data["total_certs"] = len([c for c in self.certificates if c.deleted_at is None])
|
||||
data["active_certs"] = len(
|
||||
[c for c in self.certificates if c.deleted_at is None and not c.revoked]
|
||||
)
|
||||
data["revoked_certs"] = len(
|
||||
[c for c in self.certificates if c.deleted_at is None and c.revoked]
|
||||
)
|
||||
return data
|
||||
|
||||
def get_active_certificates(self) -> list:
|
||||
"""Get all active (non-revoked) certificates issued by this CA."""
|
||||
return [
|
||||
c for c in self.certificates if c.deleted_at is None and not c.revoked
|
||||
]
|
||||
|
||||
def rotate_key(
|
||||
self,
|
||||
new_private_key: str,
|
||||
new_public_key: str,
|
||||
new_fingerprint: str,
|
||||
reason: str = None,
|
||||
) -> None:
|
||||
"""Rotate the CA's key pair.
|
||||
|
||||
This should only be done in carefully controlled circumstances.
|
||||
All existing certificates remain valid but no new certificates can be
|
||||
signed with the old key after rotation.
|
||||
|
||||
Args:
|
||||
new_private_key: New PEM-encoded private key
|
||||
new_public_key: New PEM-encoded public key
|
||||
new_fingerprint: SHA256 fingerprint of new public key
|
||||
reason: Optional reason for rotation
|
||||
"""
|
||||
self.private_key = new_private_key
|
||||
self.public_key = new_public_key
|
||||
self.fingerprint = new_fingerprint
|
||||
self.rotated_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow()
|
||||
self.rotation_reason = reason
|
||||
self.save()
|
||||
|
||||
def get_next_serial(self) -> int:
|
||||
"""Atomically increment and return the next certificate serial number.
|
||||
|
||||
Uses a SELECT … FOR UPDATE row lock so concurrent requests never
|
||||
receive the same serial. Must be called inside an active DB
|
||||
transaction (i.e. before the final session.commit()).
|
||||
|
||||
Returns:
|
||||
int: The serial number to embed in the next certificate.
|
||||
"""
|
||||
# Re-fetch this CA row with an exclusive row lock
|
||||
locked = (
|
||||
db.session.query(CA)
|
||||
.with_for_update()
|
||||
.filter_by(id=self.id)
|
||||
.one()
|
||||
)
|
||||
serial = locked.next_serial_number
|
||||
locked.next_serial_number = serial + 1
|
||||
db.session.flush() # write increment; commit happens in the caller
|
||||
return serial
|
||||
|
||||
|
||||
class CAPermission(BaseModel):
|
||||
"""Per-user CA permission model.
|
||||
|
||||
Controls which users are allowed to sign certificates against a specific CA.
|
||||
When a CA has any permission rows, the signing endpoint enforces the list;
|
||||
CAs with no rows are open to all org members (backwards-compatible default).
|
||||
|
||||
Permission values:
|
||||
sign – user may request certificate signing
|
||||
admin – user may sign AND manage the CA (rotate keys, delete, etc.)
|
||||
"""
|
||||
|
||||
__tablename__ = "ca_permissions"
|
||||
|
||||
ca_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("cas.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
user_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
permission = db.Column(db.String(50), nullable=False, default="sign")
|
||||
|
||||
# Relationships
|
||||
ca = db.relationship("CA", back_populates="permissions")
|
||||
user = db.relationship("User", back_populates="ca_permissions")
|
||||
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint("ca_id", "user_id", name="uix_ca_permission"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"<CAPermission ca_id={self.ca_id} "
|
||||
f"user_id={self.user_id} permission={self.permission}>"
|
||||
)
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
data = super().to_dict(exclude=exclude or [])
|
||||
data["permission"] = self.permission
|
||||
return data
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Certificate audit log model — tracks SSH certificate lifecycle events."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class CertificateAuditLog(BaseModel):
|
||||
"""Audit log for SSH certificate lifecycle events.
|
||||
|
||||
Tracks all operations on SSH certificates: signing, revocation, validation,
|
||||
etc. Kept separate from the general AuditLog to provide detailed certificate
|
||||
operation tracking without polluting the main audit stream.
|
||||
"""
|
||||
|
||||
__tablename__ = "certificate_audit_logs"
|
||||
|
||||
# Reference to the certificate
|
||||
certificate_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("ssh_certificates.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# The user who performed the action (null for system actions)
|
||||
user_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("users.id"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Action type (e.g., "signed", "revoked", "validated", "requested")
|
||||
action = db.Column(db.String(50), nullable=False, index=True)
|
||||
|
||||
# Request details
|
||||
ip_address = db.Column(db.String(45), nullable=True)
|
||||
user_agent = db.Column(db.String(512), nullable=True)
|
||||
request_id = db.Column(db.String(36), nullable=True)
|
||||
|
||||
# Detailed message
|
||||
message = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Additional context
|
||||
extra_data = db.Column(db.JSON, nullable=True)
|
||||
|
||||
# Outcome
|
||||
success = db.Column(db.Boolean, default=True, nullable=False)
|
||||
error_message = db.Column(db.Text, nullable=True)
|
||||
|
||||
# Relationships
|
||||
certificate = db.relationship("SSHCertificate", back_populates="audit_logs")
|
||||
user = db.relationship("User")
|
||||
|
||||
__table_args__ = (
|
||||
db.Index("idx_cert_audit_cert_action", "certificate_id", "action"),
|
||||
db.Index("idx_cert_audit_user", "user_id", "created_at"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of CertificateAuditLog."""
|
||||
return (
|
||||
f"<CertificateAuditLog cert_id={self.certificate_id} action={self.action}>"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def log(
|
||||
cls,
|
||||
certificate_id: str,
|
||||
action: str,
|
||||
user_id: str = None,
|
||||
**kwargs,
|
||||
) -> "CertificateAuditLog":
|
||||
"""Create a certificate audit log entry.
|
||||
|
||||
Args:
|
||||
certificate_id: ID of the certificate
|
||||
action: Action type (e.g., "signed", "revoked")
|
||||
user_id: ID of the user performing the action (optional)
|
||||
**kwargs: Additional fields (ip_address, user_agent, message, etc.)
|
||||
|
||||
Returns:
|
||||
CertificateAuditLog instance
|
||||
"""
|
||||
log_entry = cls(
|
||||
certificate_id=certificate_id,
|
||||
action=action,
|
||||
user_id=user_id,
|
||||
**kwargs,
|
||||
)
|
||||
log_entry.save()
|
||||
return log_entry
|
||||
@@ -0,0 +1,176 @@
|
||||
"""SSH Certificate model — signed SSH user/host certificates."""
|
||||
from enum import Enum
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.models.ssh_ca.ca import CertType
|
||||
|
||||
|
||||
class CertificateStatus(str, Enum):
|
||||
"""SSH certificate lifecycle status."""
|
||||
|
||||
REQUESTED = "requested" # Waiting for signing
|
||||
ISSUED = "issued" # Signed and valid
|
||||
REVOKED = "revoked" # Manually revoked
|
||||
EXPIRED = "expired" # Validity period ended
|
||||
SUPERSEDED = "superseded" # Replaced by newer certificate
|
||||
|
||||
|
||||
class SSHCertificate(BaseModel):
|
||||
"""SSH Certificate model representing a signed SSH user/host certificate.
|
||||
|
||||
Certificates are issued by a CA and associated with an SSH public key.
|
||||
They include principals (access levels), validity periods, and standard
|
||||
OpenSSH certificate metadata.
|
||||
"""
|
||||
|
||||
__tablename__ = "ssh_certificates"
|
||||
|
||||
# Certificate relationships
|
||||
ca_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("cas.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
user_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("users.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
ssh_key_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("ssh_keys.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Certificate content — full signed certificate in OpenSSH format
|
||||
certificate = db.Column(db.Text, nullable=False)
|
||||
|
||||
# Certificate metadata
|
||||
serial = db.Column(db.String(255), nullable=False, unique=True, index=True)
|
||||
key_id = db.Column(db.String(255), nullable=False) # Usually user email
|
||||
cert_type = db.Column(
|
||||
db.Enum(CertType, values_callable=lambda x: [e.value for e in x]),
|
||||
default=CertType.USER,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Principals — JSON list, e.g., ["prod-servers", "dev-servers"]
|
||||
principals = db.Column(db.JSON, nullable=False, default=list)
|
||||
|
||||
# Validity period
|
||||
valid_after = db.Column(db.DateTime, nullable=False)
|
||||
valid_before = db.Column(db.DateTime, nullable=False)
|
||||
|
||||
# Revocation status
|
||||
revoked = db.Column(db.Boolean, default=False, nullable=False, index=True)
|
||||
revoked_at = db.Column(db.DateTime, nullable=True)
|
||||
revoke_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Status tracking
|
||||
status = db.Column(
|
||||
db.Enum(CertificateStatus, values_callable=lambda x: [e.value for e in x]),
|
||||
default=CertificateStatus.ISSUED,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Request metadata
|
||||
request_ip = db.Column(db.String(45), nullable=True)
|
||||
request_user_agent = db.Column(db.String(512), nullable=True)
|
||||
|
||||
# Critical options — OpenSSH critical options (JSON)
|
||||
critical_options = db.Column(db.JSON, nullable=True, default=dict)
|
||||
|
||||
# Extensions — OpenSSH extensions (JSON)
|
||||
extensions = db.Column(db.JSON, nullable=True, default=dict)
|
||||
|
||||
# Relationships
|
||||
ca = db.relationship("CA", back_populates="certificates")
|
||||
user = db.relationship("User", back_populates="ssh_certificates")
|
||||
ssh_key = db.relationship(
|
||||
"SSHKey",
|
||||
back_populates="certificates",
|
||||
foreign_keys="SSHCertificate.ssh_key_id",
|
||||
)
|
||||
audit_logs = db.relationship(
|
||||
"CertificateAuditLog",
|
||||
back_populates="certificate",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
db.Index("idx_cert_user_status", "user_id", "status"),
|
||||
db.Index("idx_cert_validity", "valid_after", "valid_before"),
|
||||
db.Index("idx_cert_revoked", "revoked", "revoked_at"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of SSHCertificate."""
|
||||
return f"<SSHCertificate serial={self.serial[:16]}... user_id={self.user_id}>"
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert certificate to dictionary.
|
||||
|
||||
The raw ``certificate`` blob is excluded by default (it is large and
|
||||
callers can request it explicitly by removing it from the exclude list).
|
||||
"""
|
||||
exclude = exclude or []
|
||||
if "certificate" not in exclude:
|
||||
exclude.append("certificate")
|
||||
data = super().to_dict(exclude=exclude)
|
||||
data["is_valid"] = self.is_valid()
|
||||
data["days_until_expiry"] = self.days_until_expiry()
|
||||
return data
|
||||
|
||||
def _aware(self, dt: datetime) -> datetime:
|
||||
"""Return a timezone-aware UTC datetime."""
|
||||
return dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if certificate is currently valid.
|
||||
|
||||
Returns:
|
||||
True if certificate is issued, not revoked, and within validity period
|
||||
"""
|
||||
if self.revoked or self.status == CertificateStatus.REVOKED:
|
||||
return False
|
||||
now = datetime.now(timezone.utc)
|
||||
return self._aware(self.valid_after) <= now <= self._aware(self.valid_before)
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if certificate has expired.
|
||||
|
||||
Returns:
|
||||
True if current time is past valid_before
|
||||
"""
|
||||
return datetime.now(timezone.utc) > self._aware(self.valid_before)
|
||||
|
||||
def days_until_expiry(self) -> int:
|
||||
"""Get number of days until certificate expires.
|
||||
|
||||
Returns:
|
||||
Number of days remaining (negative if already expired)
|
||||
"""
|
||||
delta = self._aware(self.valid_before) - datetime.now(timezone.utc)
|
||||
return delta.days + (1 if delta.seconds > 0 else 0)
|
||||
|
||||
def revoke(self, reason: str = None) -> None:
|
||||
"""Revoke this certificate.
|
||||
|
||||
Args:
|
||||
reason: Optional reason for revocation
|
||||
"""
|
||||
self.revoked = True
|
||||
self.revoked_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow()
|
||||
self.revoke_reason = reason
|
||||
self.status = CertificateStatus.REVOKED
|
||||
self.save()
|
||||
|
||||
def mark_expired(self) -> None:
|
||||
"""Mark certificate as expired when validity period ends."""
|
||||
self.status = CertificateStatus.EXPIRED
|
||||
self.save()
|
||||
@@ -0,0 +1,98 @@
|
||||
"""SSH Key model — user SSH public keys registered for certificate signing."""
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
class SSHKey(BaseModel):
|
||||
"""SSH Key model representing a user's SSH public key.
|
||||
|
||||
Users register SSH public keys for certificate signing. Keys must be
|
||||
verified (owner proved possession) before they can be used.
|
||||
"""
|
||||
|
||||
__tablename__ = "ssh_keys"
|
||||
|
||||
user_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("users.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# SSH key payload in OpenSSH format (e.g., "ssh-ed25519 AAAAB3Nz...")
|
||||
payload = db.Column(db.Text, nullable=False, unique=True)
|
||||
|
||||
# SHA256 fingerprint for quick comparison and deduplication
|
||||
fingerprint = db.Column(db.String(255), nullable=False, unique=True, index=True)
|
||||
|
||||
# Optional human-readable description (e.g., "My laptop key")
|
||||
description = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Verification status
|
||||
verified = db.Column(db.Boolean, default=False, nullable=False, index=True)
|
||||
verified_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Verification challenge — shown to user once, cleared after verification
|
||||
verify_text = db.Column(db.String(255), nullable=True)
|
||||
verify_text_created_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Key metadata extracted from the key
|
||||
key_type = db.Column(db.String(50), nullable=True) # ssh-rsa, ssh-ed25519, etc.
|
||||
key_bits = db.Column(db.Integer, nullable=True) # key length
|
||||
key_comment = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Relationships
|
||||
user = db.relationship("User", back_populates="ssh_keys")
|
||||
certificates = db.relationship(
|
||||
"SSHCertificate",
|
||||
back_populates="ssh_key",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="SSHCertificate.ssh_key_id",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
db.Index("idx_ssh_key_user_verified", "user_id", "verified"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of SSHKey."""
|
||||
return f"<SSHKey {self.fingerprint[:16]}... user_id={self.user_id}>"
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert SSH key to dictionary.
|
||||
|
||||
``payload`` and ``verify_text`` are never exposed through the API.
|
||||
"""
|
||||
exclude = exclude or []
|
||||
for field in ("payload", "verify_text"):
|
||||
if field not in exclude:
|
||||
exclude.append(field)
|
||||
data = super().to_dict(exclude=exclude)
|
||||
data["cert_count"] = len([c for c in self.certificates if c.deleted_at is None])
|
||||
return data
|
||||
|
||||
def mark_verified(self) -> None:
|
||||
"""Mark this SSH key as verified and clear the challenge."""
|
||||
self.verified = True
|
||||
self.verified_at = datetime.now(timezone.utc) # Bug fix: was datetime.utcnow()
|
||||
self.verify_text = None
|
||||
self.save()
|
||||
|
||||
def needs_verification_refresh(self, max_age_hours: int = 24) -> bool:
|
||||
"""Check if verification challenge needs to be refreshed.
|
||||
|
||||
Args:
|
||||
max_age_hours: Maximum age of verification challenge in hours
|
||||
|
||||
Returns:
|
||||
True if verification challenge is stale or missing
|
||||
"""
|
||||
if not self.verify_text_created_at:
|
||||
return True
|
||||
age = datetime.now(timezone.utc) - self.verify_text_created_at.replace(
|
||||
tzinfo=timezone.utc
|
||||
) if self.verify_text_created_at.tzinfo is None else (
|
||||
datetime.now(timezone.utc) - self.verify_text_created_at
|
||||
)
|
||||
return age.total_seconds() > (max_age_hours * 3600)
|
||||
@@ -1,153 +1,4 @@
|
||||
"""User model."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import UserStatus
|
||||
"""Backward-compatibility shim — import from gatehouse_app.models.user.user instead."""
|
||||
from gatehouse_app.models.user.user import User # noqa: F401
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
"""User model representing a user account."""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
email = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
email_verified = db.Column(db.Boolean, default=False, nullable=False)
|
||||
full_name = db.Column(db.String(255), nullable=True)
|
||||
avatar_url = db.Column(db.String(512), nullable=True)
|
||||
status = db.Column(
|
||||
db.Enum(UserStatus), default=UserStatus.ACTIVE, nullable=False, index=True
|
||||
)
|
||||
last_login_at = db.Column(db.DateTime, nullable=True)
|
||||
last_login_ip = db.Column(db.String(45), nullable=True)
|
||||
|
||||
# Relationships
|
||||
authentication_methods = db.relationship(
|
||||
"AuthenticationMethod", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
sessions = db.relationship("Session", back_populates="user", cascade="all, delete-orphan")
|
||||
organization_memberships = db.relationship(
|
||||
"OrganizationMember",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="OrganizationMember.user_id",
|
||||
)
|
||||
audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan")
|
||||
security_policies = db.relationship(
|
||||
"UserSecurityPolicy",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="UserSecurityPolicy.user_id",
|
||||
)
|
||||
mfa_compliance = db.relationship(
|
||||
"MfaPolicyCompliance",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="MfaPolicyCompliance.user_id",
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of User."""
|
||||
return f"<User {self.email}>"
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert user to dictionary, excluding sensitive fields by default."""
|
||||
exclude = exclude or []
|
||||
# Always exclude password-related fields
|
||||
default_exclude = []
|
||||
all_exclude = list(set(default_exclude + exclude))
|
||||
return super().to_dict(exclude=all_exclude)
|
||||
|
||||
def has_password_auth(self):
|
||||
"""Check if user has password authentication enabled."""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return (
|
||||
AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.PASSWORD, deleted_at=None
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def get_organizations(self):
|
||||
"""Get all organizations the user is a member of."""
|
||||
return [membership.organization for membership in self.organization_memberships]
|
||||
|
||||
def has_totp_enabled(self) -> bool:
|
||||
"""Check if user has TOTP enabled and verified.
|
||||
|
||||
Returns:
|
||||
True if user has a verified TOTP authentication method, False otherwise.
|
||||
"""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return (
|
||||
AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id,
|
||||
method_type=AuthMethodType.TOTP,
|
||||
verified=True,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def get_totp_method(self):
|
||||
"""Get user's TOTP authentication method.
|
||||
|
||||
Returns:
|
||||
The AuthenticationMethod instance for TOTP or None if not found.
|
||||
|
||||
Note:
|
||||
Returns the most recently created TOTP method to handle cases where
|
||||
multiple enrollment attempts may exist.
|
||||
"""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None
|
||||
).order_by(AuthenticationMethod.created_at.desc()).first()
|
||||
|
||||
def has_webauthn_enabled(self) -> bool:
|
||||
"""Check if user has any WebAuthn passkey credentials.
|
||||
|
||||
Returns:
|
||||
True if user has at least one WebAuthn credential, False otherwise.
|
||||
"""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return (
|
||||
AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id,
|
||||
method_type=AuthMethodType.WEBAUTHN,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def get_webauthn_credentials(self):
|
||||
"""Get all WebAuthn credentials for the user.
|
||||
|
||||
Returns:
|
||||
List of AuthenticationMethod instances for WebAuthn, ordered by creation date.
|
||||
"""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
|
||||
).order_by(AuthenticationMethod.created_at.desc()).all()
|
||||
|
||||
def get_webauthn_credential_count(self) -> int:
|
||||
"""Get the count of WebAuthn credentials for the user.
|
||||
|
||||
Returns:
|
||||
Number of WebAuthn credentials.
|
||||
"""
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
|
||||
).count()
|
||||
__all__ = ["User"]
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
"""User subpackage."""
|
||||
from gatehouse_app.models.user.user import User
|
||||
from gatehouse_app.models.user.session import Session
|
||||
|
||||
__all__ = ["User", "Session"]
|
||||
@@ -21,7 +21,9 @@ class Session(BaseModel):
|
||||
|
||||
# Timing
|
||||
expires_at = db.Column(db.DateTime, nullable=False)
|
||||
last_activity_at = db.Column(db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc))
|
||||
last_activity_at = db.Column(
|
||||
db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
revoked_at = db.Column(db.DateTime, nullable=True)
|
||||
revoked_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
@@ -38,7 +40,6 @@ class Session(BaseModel):
|
||||
def is_active(self):
|
||||
"""Check if session is currently active."""
|
||||
now = datetime.now(timezone.utc)
|
||||
# Make expires_at timezone-aware if it's naive
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
@@ -51,15 +52,13 @@ class Session(BaseModel):
|
||||
def is_expired(self):
|
||||
"""Check if session has expired."""
|
||||
now = datetime.now(timezone.utc)
|
||||
# Make expires_at timezone-aware if it's naive
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return now > expires_at
|
||||
|
||||
def refresh(self, duration_seconds=86400):
|
||||
"""
|
||||
Refresh session expiration.
|
||||
def refresh(self, duration_seconds: int = 86400):
|
||||
"""Refresh session expiration.
|
||||
|
||||
Args:
|
||||
duration_seconds: New session duration in seconds
|
||||
@@ -68,9 +67,8 @@ class Session(BaseModel):
|
||||
self.last_activity_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
def revoke(self, reason=None):
|
||||
"""
|
||||
Revoke the session.
|
||||
def revoke(self, reason: str = None):
|
||||
"""Revoke the session.
|
||||
|
||||
Args:
|
||||
reason: Optional reason for revocation
|
||||
@@ -84,6 +82,5 @@ class Session(BaseModel):
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
# Exclude token from dict
|
||||
exclude.append("token")
|
||||
return super().to_dict(exclude=exclude)
|
||||
@@ -0,0 +1,209 @@
|
||||
"""User model."""
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import UserStatus
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
"""User model representing a user account."""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
email = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
email_verified = db.Column(db.Boolean, default=False, nullable=False)
|
||||
full_name = db.Column(db.String(255), nullable=True)
|
||||
avatar_url = db.Column(db.String(512), nullable=True)
|
||||
status = db.Column(
|
||||
db.Enum(UserStatus), default=UserStatus.ACTIVE, nullable=False, index=True
|
||||
)
|
||||
last_login_at = db.Column(db.DateTime, nullable=True)
|
||||
last_login_ip = db.Column(db.String(45), nullable=True)
|
||||
|
||||
# Account activation (email-link flow)
|
||||
activated = db.Column(db.Boolean, default=True, nullable=False)
|
||||
activation_key = db.Column(db.String(128), unique=True, nullable=True, index=True)
|
||||
|
||||
# Relationships – defined here only for models that don't circular-import
|
||||
authentication_methods = db.relationship(
|
||||
"AuthenticationMethod", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
sessions = db.relationship("Session", back_populates="user", cascade="all, delete-orphan")
|
||||
organization_memberships = db.relationship(
|
||||
"OrganizationMember",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="OrganizationMember.user_id",
|
||||
)
|
||||
audit_logs = db.relationship("AuditLog", back_populates="user", cascade="all, delete-orphan")
|
||||
security_policies = db.relationship(
|
||||
"UserSecurityPolicy",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="UserSecurityPolicy.user_id",
|
||||
)
|
||||
mfa_compliance = db.relationship(
|
||||
"MfaPolicyCompliance",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="MfaPolicyCompliance.user_id",
|
||||
)
|
||||
department_memberships = db.relationship(
|
||||
"DepartmentMembership",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="DepartmentMembership.user_id",
|
||||
)
|
||||
principal_memberships = db.relationship(
|
||||
"PrincipalMembership",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="PrincipalMembership.user_id",
|
||||
)
|
||||
ssh_keys = db.relationship(
|
||||
"SSHKey",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="SSHKey.user_id",
|
||||
)
|
||||
ssh_certificates = db.relationship(
|
||||
"SSHCertificate",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="SSHCertificate.user_id",
|
||||
)
|
||||
ca_permissions = db.relationship(
|
||||
"CAPermission",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="CAPermission.user_id",
|
||||
)
|
||||
|
||||
# OIDC relationships – registered here (no monkey-patching needed)
|
||||
oidc_auth_codes = db.relationship(
|
||||
"OIDCAuthCode", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
oidc_refresh_tokens = db.relationship(
|
||||
"OIDCRefreshToken", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
oidc_sessions = db.relationship(
|
||||
"OIDCSession", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
oidc_token_metadata = db.relationship(
|
||||
"OIDCTokenMetadata", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
oidc_audit_logs = db.relationship(
|
||||
"OIDCAuditLog", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""String representation of User."""
|
||||
return f"<User {self.email}>"
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert user to dictionary, excluding sensitive fields by default."""
|
||||
exclude = exclude or []
|
||||
return super().to_dict(exclude=exclude)
|
||||
|
||||
def has_password_auth(self):
|
||||
"""Check if user has password authentication enabled."""
|
||||
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return (
|
||||
AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.PASSWORD, deleted_at=None
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def get_organizations(self):
|
||||
"""Get all organizations the user is a member of."""
|
||||
return [membership.organization for membership in self.organization_memberships]
|
||||
|
||||
def has_totp_enabled(self) -> bool:
|
||||
"""Check if user has TOTP enabled and verified.
|
||||
|
||||
Returns:
|
||||
True if user has a verified TOTP authentication method, False otherwise.
|
||||
"""
|
||||
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return (
|
||||
AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id,
|
||||
method_type=AuthMethodType.TOTP,
|
||||
verified=True,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def get_totp_method(self):
|
||||
"""Get user's TOTP authentication method.
|
||||
|
||||
Returns:
|
||||
The AuthenticationMethod instance for TOTP or None if not found.
|
||||
|
||||
Note:
|
||||
Returns the most recently created TOTP method to handle cases where
|
||||
multiple enrollment attempts may exist.
|
||||
"""
|
||||
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return (
|
||||
AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.TOTP, deleted_at=None
|
||||
)
|
||||
.order_by(AuthenticationMethod.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
def has_webauthn_enabled(self) -> bool:
|
||||
"""Check if user has any WebAuthn passkey credentials.
|
||||
|
||||
Returns:
|
||||
True if user has at least one WebAuthn credential, False otherwise.
|
||||
"""
|
||||
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return (
|
||||
AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id,
|
||||
method_type=AuthMethodType.WEBAUTHN,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def get_webauthn_credentials(self):
|
||||
"""Get all WebAuthn credentials for the user.
|
||||
|
||||
Returns:
|
||||
List of AuthenticationMethod instances for WebAuthn, ordered by creation date.
|
||||
"""
|
||||
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return (
|
||||
AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
|
||||
)
|
||||
.order_by(AuthenticationMethod.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_webauthn_credential_count(self) -> int:
|
||||
"""Get the count of WebAuthn credentials for the user.
|
||||
|
||||
Returns:
|
||||
Number of WebAuthn credentials.
|
||||
"""
|
||||
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
|
||||
return AuthenticationMethod.query.filter_by(
|
||||
user_id=self.id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None
|
||||
).count()
|
||||
@@ -25,7 +25,7 @@ class LoginSchema(Schema):
|
||||
|
||||
email = fields.Email(required=True)
|
||||
password = fields.Str(required=True, validate=validate.Length(min=1))
|
||||
remember_me = fields.Bool(missing=False)
|
||||
remember_me = fields.Bool(load_default=False)
|
||||
|
||||
|
||||
class RefreshTokenSchema(Schema):
|
||||
@@ -77,14 +77,38 @@ class TOTPVerifyEnrollmentSchema(Schema):
|
||||
class TOTPVerifySchema(Schema):
|
||||
"""Schema for TOTP code verification during login."""
|
||||
|
||||
code = fields.Str(required=True)
|
||||
is_backup_code = fields.Bool(missing=False)
|
||||
code = fields.Str(
|
||||
required=True,
|
||||
validate=validate.Length(min=1),
|
||||
)
|
||||
is_backup_code = fields.Bool(load_default=False)
|
||||
client_timestamp = fields.Int(
|
||||
required=False,
|
||||
allow_none=True,
|
||||
metadata={"description": "Client UTC timestamp in seconds since epoch for TOTP verification"},
|
||||
)
|
||||
|
||||
@validates_schema
|
||||
def validate_code_format(self, data, **kwargs):
|
||||
"""Validate code format depending on whether it's a backup code."""
|
||||
code = data.get("code", "")
|
||||
is_backup_code = data.get("is_backup_code", False)
|
||||
if is_backup_code:
|
||||
# Backup codes are 16 uppercase hex characters
|
||||
if not code or len(code) != 16 or not all(c in "0123456789ABCDEFabcdef" for c in code):
|
||||
raise ValidationError(
|
||||
"Backup code must be a 16-character hexadecimal string.",
|
||||
field_name="code",
|
||||
)
|
||||
else:
|
||||
# Regular TOTP codes are exactly 6 digits
|
||||
import re
|
||||
if not re.match(r"^\d{6}$", code):
|
||||
raise ValidationError(
|
||||
"Code must be a 6-digit number.",
|
||||
field_name="code",
|
||||
)
|
||||
|
||||
|
||||
class TOTPDisableSchema(Schema):
|
||||
"""Schema for disabling TOTP."""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Audit service."""
|
||||
from flask import request, g
|
||||
from gatehouse_app.models.audit_log import AuditLog
|
||||
from gatehouse_app.models.auth.audit_log import AuditLog
|
||||
from gatehouse_app.utils.constants import AuditAction
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ class AuditService:
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
metadata=metadata,
|
||||
extra_data=metadata,
|
||||
description=description,
|
||||
success=success,
|
||||
error_message=error_message,
|
||||
|
||||
@@ -5,9 +5,9 @@ from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
from flask import request, g, current_app
|
||||
from gatehouse_app.extensions import db, bcrypt
|
||||
from gatehouse_app.models.user import User
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.models.session import Session
|
||||
from gatehouse_app.models.user.user import User
|
||||
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.models.user.session import Session
|
||||
from gatehouse_app.utils.constants import AuthMethodType, SessionStatus, UserStatus, AuditAction
|
||||
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError, AccountSuspendedError, AccountInactiveError
|
||||
from gatehouse_app.exceptions.validation_exceptions import EmailAlreadyExistsError
|
||||
@@ -102,7 +102,7 @@ class AuthService:
|
||||
if current_app.config.get('ENV') == 'development':
|
||||
logger.debug(f"[Auth] Account status: user_id={user.id}, status={user.status}")
|
||||
|
||||
if user.status == UserStatus.SUSPENDED:
|
||||
if user.status in (UserStatus.SUSPENDED, UserStatus.COMPLIANCE_SUSPENDED):
|
||||
raise AccountSuspendedError()
|
||||
if user.status == UserStatus.INACTIVE:
|
||||
raise AccountInactiveError()
|
||||
@@ -210,6 +210,22 @@ class AuthService:
|
||||
auth_method.password_hash = bcrypt.generate_password_hash(new_password).decode("utf-8")
|
||||
db.session.commit()
|
||||
|
||||
# Invalidate all other sessions so that if an attacker had a valid
|
||||
# session token, changing the password actually locks them out.
|
||||
# The current request's session (if any) is preserved so the user
|
||||
# doesn't have to log in again immediately.
|
||||
from flask import g as flask_g
|
||||
current_session_id = getattr(flask_g, "current_session", None)
|
||||
current_session_id = current_session_id.id if current_session_id else None
|
||||
sessions_to_revoke = Session.query.filter(
|
||||
Session.user_id == user.id,
|
||||
Session.revoked_at == None, # noqa: E711
|
||||
).all()
|
||||
for sess in sessions_to_revoke:
|
||||
if sess.id != current_session_id:
|
||||
sess.revoke(reason="Password changed")
|
||||
db.session.commit()
|
||||
|
||||
# Log password change
|
||||
AuditService.log_action(
|
||||
action=AuditAction.PASSWORD_CHANGE,
|
||||
@@ -482,9 +498,24 @@ class AuthService:
|
||||
if not secret:
|
||||
raise InvalidCredentialsError("TOTP secret not found")
|
||||
|
||||
# Replay-attack prevention: reject codes that have already been
|
||||
# accepted within the current validity window.
|
||||
if TOTPService.is_code_already_used(str(user.id), code):
|
||||
AuditService.log_action(
|
||||
action=AuditAction.TOTP_VERIFY_FAILED,
|
||||
user_id=user.id,
|
||||
resource_type="authentication_method",
|
||||
resource_id=auth_method.id,
|
||||
description="TOTP code replay attempt detected",
|
||||
)
|
||||
raise InvalidCredentialsError("Invalid TOTP code")
|
||||
|
||||
is_valid = TOTPService.verify_code(secret, code, client_utc_timestamp=client_utc_timestamp)
|
||||
|
||||
if is_valid:
|
||||
# Mark this code as used to prevent replay within the validity window
|
||||
TOTPService.mark_code_used(str(user.id), code)
|
||||
|
||||
auth_method.last_used_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from flask import current_app
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models import User, AuthenticationMethod
|
||||
from gatehouse_app.models.authentication_method import (
|
||||
from gatehouse_app.models.auth.authentication_method import (
|
||||
OAuthState,
|
||||
ApplicationProviderConfig,
|
||||
OrganizationProviderOverride
|
||||
@@ -736,9 +736,14 @@ class ExternalAuthService:
|
||||
400,
|
||||
)
|
||||
|
||||
# Generate PKCE
|
||||
code_verifier = secrets.token_urlsafe(32)
|
||||
code_challenge = cls._compute_s256_challenge(code_verifier)
|
||||
# Generate PKCE — skip for confidential clients (Google, Microsoft) that use a
|
||||
# client_secret. Sending code_challenge to Microsoft causes it to enforce PKCE on
|
||||
# the token exchange, which then fails. Matches the behaviour of initiate_login_flow.
|
||||
code_verifier = None
|
||||
code_challenge = None
|
||||
if provider_type_str not in ('google', 'microsoft'):
|
||||
code_verifier = secrets.token_urlsafe(32)
|
||||
code_challenge = cls._compute_s256_challenge(code_verifier)
|
||||
|
||||
# Create OAuth state
|
||||
state = OAuthState.create_state(
|
||||
@@ -1210,12 +1215,35 @@ class ExternalAuthService:
|
||||
else:
|
||||
email_verified = data.get("email_verified", False)
|
||||
|
||||
sub = data.get("sub")
|
||||
|
||||
# Derive email from sub when the provider omits the email claim.
|
||||
# This happens with some OIDC servers (including the nav-security mock)
|
||||
# that only return the minimal {sub, iss, iat, exp} set.
|
||||
# Rule: if sub looks like an email address, use it directly.
|
||||
# Otherwise, construct a deterministic fallback so we never get NULL.
|
||||
raw_email = data.get("email")
|
||||
if not raw_email and sub:
|
||||
import re as _re
|
||||
if _re.match(r"^[^@\s]+@[^@\s]+\.[^@\s]+$", sub):
|
||||
raw_email = sub
|
||||
email_verified = True # if sub IS the email it's already verified
|
||||
else:
|
||||
# e.g. "12345" → "12345@google.local" so we can store it
|
||||
raw_email = f"{sub}@{provider or 'oauth'}.local"
|
||||
email_verified = False
|
||||
|
||||
# Derive display name when omitted
|
||||
raw_name = data.get("name") or data.get("display_name")
|
||||
if not raw_name and raw_email:
|
||||
raw_name = raw_email.split("@")[0]
|
||||
|
||||
# Standardize user info
|
||||
return {
|
||||
"provider_user_id": data.get("sub"),
|
||||
"email": data.get("email"),
|
||||
"provider_user_id": sub,
|
||||
"email": raw_email,
|
||||
"email_verified": email_verified,
|
||||
"name": data.get("name"),
|
||||
"name": raw_name,
|
||||
"first_name": data.get("given_name"),
|
||||
"last_name": data.get("family_name"),
|
||||
"picture": data.get("picture"),
|
||||
|
||||
@@ -4,11 +4,11 @@ from datetime import datetime, timezone
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
|
||||
from gatehouse_app.models.user_security_policy import UserSecurityPolicy
|
||||
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
|
||||
from gatehouse_app.models.user import User
|
||||
from gatehouse_app.models.organization import Organization
|
||||
from gatehouse_app.models.security.organization_security_policy import OrganizationSecurityPolicy
|
||||
from gatehouse_app.models.security.user_security_policy import UserSecurityPolicy
|
||||
from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance
|
||||
from gatehouse_app.models.user.user import User
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
from gatehouse_app.utils.constants import (
|
||||
MfaPolicyMode,
|
||||
@@ -702,7 +702,7 @@ class MfaPolicyService:
|
||||
if now is None:
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
from gatehouse_app.models.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
|
||||
updated_count = 0
|
||||
|
||||
|
||||
@@ -19,9 +19,9 @@ import logging
|
||||
import json
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
|
||||
from gatehouse_app.models.organization_security_policy import OrganizationSecurityPolicy
|
||||
from gatehouse_app.models.user import User
|
||||
from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance
|
||||
from gatehouse_app.models.security.organization_security_policy import OrganizationSecurityPolicy
|
||||
from gatehouse_app.models.user.user import User
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
from gatehouse_app.utils.constants import AuditAction
|
||||
|
||||
@@ -37,6 +37,7 @@ class NotificationService:
|
||||
SMTP_PORT_KEY = "SMTP_PORT"
|
||||
SMTP_USERNAME_KEY = "SMTP_USERNAME"
|
||||
SMTP_PASSWORD_KEY = "SMTP_PASSWORD"
|
||||
SMTP_USE_TLS_KEY = "SMTP_USE_TLS"
|
||||
FROM_ADDRESS_KEY = "FROM_ADDRESS"
|
||||
|
||||
@staticmethod
|
||||
@@ -86,10 +87,9 @@ class NotificationService:
|
||||
if success:
|
||||
logger.info(
|
||||
f"Sent MFA deadline reminder to {user.email} "
|
||||
f"({days_until_deadline} days remaining # Audit log
|
||||
)"
|
||||
f"({days_until_deadline} days remaining)"
|
||||
)
|
||||
AuditService.log_action(
|
||||
AuditService.log_action(
|
||||
action=AuditAction.MFA_POLICY_USER_COMPLIANT,
|
||||
user_id=user.id,
|
||||
organization_id=compliance.organization_id,
|
||||
@@ -291,101 +291,62 @@ Gatehouse Security Team
|
||||
body: str,
|
||||
html_body: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Send an email notification.
|
||||
"""Send an email via SMTP.
|
||||
|
||||
This method attempts to send an email using configured SMTP settings.
|
||||
If email is not configured, it logs the notification instead.
|
||||
|
||||
Args:
|
||||
to_address: Recipient email address
|
||||
subject: Email subject
|
||||
body: Plain text email body
|
||||
html_body: Optional HTML email body
|
||||
|
||||
Returns:
|
||||
True if email was sent (or logged), False on error
|
||||
Returns True if the email was sent successfully, False otherwise.
|
||||
If EMAIL_ENABLED is False, logs the email body instead (simulation mode).
|
||||
"""
|
||||
import smtplib
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from flask import current_app
|
||||
|
||||
email_enabled = current_app.config.get(NotificationService.EMAIL_ENABLED_KEY, False)
|
||||
|
||||
if not email_enabled:
|
||||
logger.info(
|
||||
f"[EMAIL DISABLED] Would have sent to: {to_address} | Subject: {subject}\n"
|
||||
f"Body: {body[:500]}"
|
||||
)
|
||||
return False
|
||||
|
||||
smtp_host = current_app.config.get(NotificationService.SMTP_HOST_KEY, "localhost")
|
||||
smtp_port = int(current_app.config.get(NotificationService.SMTP_PORT_KEY, 587))
|
||||
smtp_username = current_app.config.get(NotificationService.SMTP_USERNAME_KEY)
|
||||
smtp_password = current_app.config.get(NotificationService.SMTP_PASSWORD_KEY)
|
||||
smtp_use_tls = current_app.config.get(
|
||||
NotificationService.SMTP_USE_TLS_KEY,
|
||||
smtp_port not in (25, 1025),
|
||||
)
|
||||
from_address = current_app.config.get(
|
||||
NotificationService.FROM_ADDRESS_KEY, "noreply@gatehouse.local"
|
||||
)
|
||||
|
||||
try:
|
||||
from flask import current_app
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["Subject"] = subject
|
||||
msg["From"] = from_address
|
||||
msg["To"] = to_address
|
||||
msg.attach(MIMEText(body, "plain"))
|
||||
if html_body:
|
||||
msg.attach(MIMEText(html_body, "html"))
|
||||
|
||||
# Check if email is configured
|
||||
email_enabled = current_app.config.get(
|
||||
NotificationService.EMAIL_ENABLED_KEY, False
|
||||
)
|
||||
|
||||
if not email_enabled:
|
||||
# Log the notification instead of sending
|
||||
logger.info(
|
||||
f"[EMAIL SIMULATION] To: {to_address}\n"
|
||||
f"Subject: {subject}\n"
|
||||
f"Body: {body[:200]}..." if len(body) > 200 else f"Body: {body}"
|
||||
)
|
||||
return True
|
||||
|
||||
# Get email configuration
|
||||
smtp_host = current_app.config.get(NotificationService.SMTP_HOST_KEY)
|
||||
smtp_port = current_app.config.get(NotificationService.SMTP_PORT_KEY, 587)
|
||||
smtp_username = current_app.config.get(NotificationService.SMTP_USERNAME_KEY)
|
||||
smtp_password = current_app.config.get(NotificationService.SMTP_PASSWORD_KEY)
|
||||
from_address = current_app.config.get(
|
||||
NotificationService.FROM_ADDRESS_KEY, "noreply@gatehouse.local"
|
||||
)
|
||||
|
||||
# Import send_email based on available mail library
|
||||
try:
|
||||
from flask_mail import Message
|
||||
|
||||
from gatehouse_app import mail
|
||||
|
||||
msg = Message(
|
||||
subject=subject,
|
||||
recipients=[to_address],
|
||||
body=body,
|
||||
html=html_body,
|
||||
sender=from_address,
|
||||
)
|
||||
mail.send(msg)
|
||||
logger.info(f"Email sent successfully to {to_address}")
|
||||
return True
|
||||
|
||||
except ImportError:
|
||||
# Flask-Mail not available, use SMTP directly
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["Subject"] = subject
|
||||
msg["From"] = from_address
|
||||
msg["To"] = to_address
|
||||
|
||||
# Attach plain text and HTML versions
|
||||
part1 = MIMEText(body, "plain")
|
||||
msg.attach(part1)
|
||||
|
||||
if html_body:
|
||||
part2 = MIMEText(html_body, "html")
|
||||
msg.attach(part2)
|
||||
|
||||
# Send via SMTP
|
||||
with smtplib.SMTP(smtp_host, smtp_port) as server:
|
||||
with smtplib.SMTP(smtp_host, smtp_port) as server:
|
||||
server.ehlo()
|
||||
if smtp_use_tls:
|
||||
server.starttls()
|
||||
if smtp_username and smtp_password:
|
||||
server.login(smtp_username, smtp_password)
|
||||
server.send_message(msg)
|
||||
server.ehlo()
|
||||
if smtp_username and smtp_password:
|
||||
server.login(smtp_username, smtp_password)
|
||||
server.send_message(msg)
|
||||
|
||||
logger.info(f"Email sent successfully to {to_address}")
|
||||
return True
|
||||
logger.info(f"[EMAIL] Sent to {to_address} | Subject: {subject}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to send email to {to_address}: {e}")
|
||||
# Log the notification as fallback
|
||||
logger.info(
|
||||
f"[EMAIL FALLBACK] To: {to_address}\n"
|
||||
f"Subject: {subject}\n"
|
||||
f"Body: {body[:500]}..." if len(body) > 500 else f"Body: {body}"
|
||||
)
|
||||
return True # Return True to continue processing
|
||||
logger.error(f"[EMAIL] Failed to send to {to_address}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_notification_stats(user_id: str) -> Dict[str, Any]:
|
||||
@@ -397,7 +358,7 @@ Gatehouse Security Team
|
||||
Returns:
|
||||
Dictionary with notification statistics
|
||||
"""
|
||||
from gatehouse_app.models.mfa_policy_compliance import MfaPolicyCompliance
|
||||
from gatehouse_app.models.security.mfa_policy_compliance import MfaPolicyCompliance
|
||||
|
||||
stats = {
|
||||
"total_notifications": 0,
|
||||
|
||||
@@ -9,10 +9,10 @@ from flask import current_app, request, g, redirect
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models import User, AuthenticationMethod
|
||||
from gatehouse_app.models.authentication_method import OAuthState
|
||||
from gatehouse_app.models.auth.authentication_method import OAuthState
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.models.oidc_authorization_code import OIDCAuthCode
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
from gatehouse_app.models.oidc.oidc_authorization_code import OIDCAuthCode
|
||||
from gatehouse_app.utils.constants import AuthMethodType, AuditAction
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
from gatehouse_app.services.external_auth_service import (
|
||||
ExternalAuthService,
|
||||
@@ -139,7 +139,7 @@ class OAuthFlowService:
|
||||
except ExternalAuthError as e:
|
||||
# Log failed initiation
|
||||
AuditService.log_action(
|
||||
action="external_auth.login.initiated",
|
||||
action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED,
|
||||
organization_id=organization_id,
|
||||
metadata={
|
||||
"provider_type": provider_type_str,
|
||||
@@ -236,7 +236,7 @@ class OAuthFlowService:
|
||||
|
||||
except ExternalAuthError as e:
|
||||
AuditService.log_action(
|
||||
action="external_auth.register.initiated",
|
||||
action=AuditAction.EXTERNAL_AUTH_LOGIN_FAILED,
|
||||
organization_id=organization_id,
|
||||
metadata={
|
||||
"provider_type": provider_type_str,
|
||||
@@ -399,6 +399,27 @@ class OAuthFlowService:
|
||||
access_token=tokens["access_token"],
|
||||
)
|
||||
|
||||
if not user_info.get("provider_user_id"):
|
||||
raise OAuthFlowError(
|
||||
"Provider did not return a user identifier (sub claim). "
|
||||
"Cannot complete authentication.",
|
||||
"MISSING_PROVIDER_USER_ID",
|
||||
400,
|
||||
)
|
||||
|
||||
if not user_info.get("email"):
|
||||
raise OAuthFlowError(
|
||||
"Provider did not return an email address. "
|
||||
"Cannot complete authentication.",
|
||||
"MISSING_EMAIL",
|
||||
400,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Got user_info from provider: sub={user_info['provider_user_id']}, "
|
||||
f"email={user_info['email']}, email_verified={user_info.get('email_verified')}"
|
||||
)
|
||||
|
||||
# Look up user by provider_user_id
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
method_type=provider_type,
|
||||
@@ -494,25 +515,52 @@ class OAuthFlowService:
|
||||
if not target_org and len(user_orgs) == 1:
|
||||
target_org = user_orgs[0]
|
||||
|
||||
# Priority 3: No orgs at all — auto-create a personal org and log in
|
||||
# Priority 3: No orgs at all — send to org-setup instead of auto-creating
|
||||
if not target_org and len(user_orgs) == 0:
|
||||
import re
|
||||
import uuid
|
||||
from gatehouse_app.services.organization_service import OrganizationService
|
||||
org_name = f"{user_info.get('name') or user.email.split('@')[0]}'s Workspace"
|
||||
# Build a URL-safe slug and ensure uniqueness with a short suffix
|
||||
base_slug = re.sub(r"[^a-z0-9]+", "-", org_name.lower()).strip("-")[:40]
|
||||
slug = f"{base_slug}-{uuid.uuid4().hex[:6]}"
|
||||
org = OrganizationService.create_organization(
|
||||
name=org_name,
|
||||
slug=slug,
|
||||
owner_user_id=user.id,
|
||||
)
|
||||
target_org = org
|
||||
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
|
||||
from gatehouse_app.services.auth_service import AuthService as _AS
|
||||
_now = datetime.now(timezone.utc)
|
||||
_session = _AS.create_session(user=user, is_compliance_only=False)
|
||||
_session_dict = _session.to_dict()
|
||||
_session_dict["token"] = _session.token
|
||||
_expires_at = _session.expires_at
|
||||
if _expires_at.tzinfo is None:
|
||||
_expires_at = _expires_at.replace(tzinfo=timezone.utc)
|
||||
_session_dict["expires_in"] = int((_expires_at - _now).total_seconds())
|
||||
|
||||
_pending = OrgInviteToken.query.filter(
|
||||
OrgInviteToken.email == user.email,
|
||||
OrgInviteToken.accepted_at.is_(None),
|
||||
OrgInviteToken.expires_at > _now,
|
||||
OrgInviteToken.deleted_at.is_(None),
|
||||
).all()
|
||||
_pending_list = [
|
||||
{
|
||||
"token": inv.token,
|
||||
"organization": {
|
||||
"id": str(inv.organization_id),
|
||||
"name": inv.organization.name,
|
||||
},
|
||||
"role": inv.role,
|
||||
"expires_at": inv.expires_at.isoformat(),
|
||||
}
|
||||
for inv in _pending
|
||||
]
|
||||
|
||||
state_record.mark_used()
|
||||
logger.info(
|
||||
f"OAuth login: auto-created org '{org.name}' (id={org.id}) "
|
||||
f"for new user {user.id}"
|
||||
f"OAuth login: user {user.id} has no org, redirecting to org-setup "
|
||||
f"(pending_invites={len(_pending_list)})"
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"flow_type": "login",
|
||||
"requires_org_creation": True,
|
||||
"user": {"id": user.id, "email": user.email, "full_name": user.full_name},
|
||||
"session": _session_dict,
|
||||
"pending_invites": _pending_list,
|
||||
"state": state_record.state,
|
||||
}
|
||||
|
||||
# Priority 4: Multiple orgs — need user to pick one
|
||||
if not target_org:
|
||||
@@ -755,7 +803,7 @@ class OAuthFlowService:
|
||||
|
||||
# If organization_id hint was provided and valid, create session for that org
|
||||
if state_record.organization_id:
|
||||
from gatehouse_app.models.organization import Organization
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
org = Organization.query.get(state_record.organization_id)
|
||||
if org:
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
@@ -784,7 +832,40 @@ class OAuthFlowService:
|
||||
"session": session_dict,
|
||||
}
|
||||
|
||||
# No organization hint or invalid - need to create/select org
|
||||
# No organization hint or invalid - need to create/select org.
|
||||
# Still create a session so the frontend can call /organizations
|
||||
# and /invites after redirecting to /org-setup.
|
||||
from gatehouse_app.services.auth_service import AuthService as _AS
|
||||
from gatehouse_app.models.organization.org_invite_token import OrgInviteToken
|
||||
_session = _AS.create_session(user=user, is_compliance_only=False)
|
||||
_session_dict = _session.to_dict()
|
||||
_session_dict["token"] = _session.token
|
||||
_expires_at = _session.expires_at
|
||||
if _expires_at.tzinfo is None:
|
||||
_expires_at = _expires_at.replace(tzinfo=timezone.utc)
|
||||
_now = datetime.now(timezone.utc)
|
||||
_session_dict["expires_in"] = int((_expires_at - _now).total_seconds())
|
||||
|
||||
# Surface pending invitations so the UI can offer "join vs create"
|
||||
_pending = OrgInviteToken.query.filter(
|
||||
OrgInviteToken.email == user.email,
|
||||
OrgInviteToken.accepted_at.is_(None),
|
||||
OrgInviteToken.expires_at > _now,
|
||||
OrgInviteToken.deleted_at.is_(None),
|
||||
).all()
|
||||
_pending_list = [
|
||||
{
|
||||
"token": inv.token,
|
||||
"organization": {
|
||||
"id": str(inv.organization_id),
|
||||
"name": inv.organization.name,
|
||||
},
|
||||
"role": inv.role,
|
||||
"expires_at": inv.expires_at.isoformat(),
|
||||
}
|
||||
for inv in _pending
|
||||
]
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"flow_type": "register",
|
||||
@@ -794,6 +875,8 @@ class OAuthFlowService:
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
},
|
||||
"session": _session_dict,
|
||||
"pending_invites": _pending_list,
|
||||
"state": state_record.state,
|
||||
}
|
||||
|
||||
@@ -966,8 +1049,8 @@ class OAuthFlowService:
|
||||
)
|
||||
|
||||
# Determine organization
|
||||
from gatehouse_app.models.organization import Organization
|
||||
from gatehouse_app.models.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
|
||||
# Get user's organizations
|
||||
user_orgs = user.get_organizations()
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
from flask import current_app
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.oidc_jwks_key import OidcJwksKey
|
||||
from gatehouse_app.models.oidc.oidc_jwks_key import OidcJwksKey
|
||||
|
||||
|
||||
class JWKSKey:
|
||||
|
||||
@@ -14,7 +14,7 @@ from gatehouse_app.models import (
|
||||
User, OIDCClient, OIDCAuthCode, OIDCRefreshToken,
|
||||
OIDCSession, OIDCTokenMetadata
|
||||
)
|
||||
from gatehouse_app.models.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.exceptions.validation_exceptions import (
|
||||
ValidationError, NotFoundError, BadRequestError
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ import jwt
|
||||
from flask import current_app, g
|
||||
|
||||
from gatehouse_app.models import User, OIDCClient
|
||||
from gatehouse_app.models.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.services.oidc_jwks_service import OIDCJWKSService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -3,8 +3,8 @@ import logging
|
||||
from datetime import datetime, timezone
|
||||
from flask import current_app
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.organization import Organization
|
||||
from gatehouse_app.models.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.exceptions.validation_exceptions import OrganizationNotFoundError, ConflictError
|
||||
from gatehouse_app.utils.constants import OrganizationRole, AuditAction
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
@@ -188,11 +188,10 @@ class OrganizationService:
|
||||
Raises:
|
||||
ConflictError: If user is already a member
|
||||
"""
|
||||
# Check if already a member
|
||||
# Check if already a member (active or soft-deleted — both blocked by DB unique constraint)
|
||||
existing = OrganizationMember.query.filter_by(
|
||||
user_id=user_id,
|
||||
organization_id=org.id,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
# Development-only debug logging for membership validation
|
||||
@@ -200,6 +199,25 @@ class OrganizationService:
|
||||
logger.debug(f"[Org] Member check: org_id={org.id}, user_id={user_id}, already_member={existing is not None}")
|
||||
|
||||
if existing:
|
||||
if existing.deleted_at is not None:
|
||||
# Reactivate the soft-deleted membership with the new role
|
||||
existing.deleted_at = None
|
||||
existing.role = role
|
||||
existing.invited_by_id = inviter_id
|
||||
existing.invited_at = datetime.now(timezone.utc)
|
||||
existing.joined_at = datetime.now(timezone.utc)
|
||||
existing.save()
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_MEMBER_ADD,
|
||||
user_id=inviter_id,
|
||||
organization_id=org.id,
|
||||
resource_type="organization_member",
|
||||
resource_id=existing.id,
|
||||
metadata={"added_user_id": user_id, "role": role.value},
|
||||
description=f"Member re-added to organization with role: {role.value}",
|
||||
)
|
||||
return existing
|
||||
raise ConflictError("User is already a member of this organization")
|
||||
|
||||
# Create membership
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Session service."""
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.models.session import Session
|
||||
from gatehouse_app.models.user.session import Session
|
||||
from gatehouse_app.utils.constants import SessionStatus
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ class SessionService:
|
||||
Returns:
|
||||
Session object if found and active, None otherwise
|
||||
"""
|
||||
from gatehouse_app.models.session import Session
|
||||
from gatehouse_app.models.user.session import Session
|
||||
from gatehouse_app.utils.constants import SessionStatus
|
||||
return Session.query.filter_by(
|
||||
token=token,
|
||||
|
||||
@@ -0,0 +1,359 @@
|
||||
"""SSH Certificate Authority signing service.
|
||||
|
||||
Handles SSH certificate signing operations, leveraging sshkey-tools library.
|
||||
This service is a Gatehouse-integrated version of the secuird/ssh_ca.py logic.
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from sshkey_tools.cert import SSHCertificate, CertificateFields
|
||||
from sshkey_tools.keys import PublicKey, PrivateKey
|
||||
|
||||
from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config
|
||||
from gatehouse_app.exceptions import SSHCAError, ValidationError
|
||||
from gatehouse_app.utils.crypto import compute_ssh_fingerprint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SSHCASigningError(Exception):
|
||||
"""SSH CA signing operation error."""
|
||||
pass
|
||||
|
||||
|
||||
class SSHCertificateSigningRequest:
|
||||
"""Represents an SSH certificate signing request."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ssh_public_key: str,
|
||||
principals: List[str],
|
||||
key_id: str,
|
||||
cert_type: str = "user",
|
||||
expiry_hours: Optional[int] = None,
|
||||
critical_options: Optional[Dict[str, str]] = None,
|
||||
extensions: Optional[List[str]] = None,
|
||||
):
|
||||
"""Initialize signing request.
|
||||
|
||||
Args:
|
||||
ssh_public_key: Public key in OpenSSH format (e.g., "ssh-ed25519 AAAA...")
|
||||
principals: List of principals (e.g., ["prod-servers", "staging"])
|
||||
key_id: Key identifier (usually user email)
|
||||
cert_type: Certificate type - "user" or "host" (default: user)
|
||||
expiry_hours: Certificate validity in hours
|
||||
critical_options: Critical options dict
|
||||
extensions: List of extensions (e.g., ["permit-pty", "permit-agent-forwarding"])
|
||||
"""
|
||||
self.ssh_public_key = ssh_public_key
|
||||
self.principals = principals or []
|
||||
self.key_id = key_id
|
||||
self.cert_type = cert_type
|
||||
self.expiry_hours = expiry_hours
|
||||
self.critical_options = critical_options or {}
|
||||
self.extensions = extensions or []
|
||||
|
||||
def validate(self) -> List[str]:
|
||||
"""Validate the signing request.
|
||||
|
||||
Returns:
|
||||
List of validation errors (empty if valid)
|
||||
"""
|
||||
errors = []
|
||||
config = get_ssh_ca_config()
|
||||
|
||||
# Validate cert type
|
||||
if self.cert_type not in ("user", "host"):
|
||||
errors.append(f"Invalid cert_type: {self.cert_type}. Must be 'user' or 'host'")
|
||||
|
||||
# Validate SSH public key
|
||||
if not self.ssh_public_key or len(self.ssh_public_key) < 16:
|
||||
errors.append("SSH public key is missing or invalid")
|
||||
else:
|
||||
try:
|
||||
PublicKey.from_string(self.ssh_public_key)
|
||||
except Exception as e:
|
||||
errors.append(f"SSH public key is not valid: {str(e)}")
|
||||
|
||||
# Validate principals
|
||||
if not self.principals or len(self.principals) == 0:
|
||||
errors.append("At least one principal is required")
|
||||
else:
|
||||
max_principals = config.get_int('max_principals_per_cert')
|
||||
if len(self.principals) > max_principals:
|
||||
errors.append(
|
||||
f"Too many principals ({len(self.principals)}). "
|
||||
f"Maximum is {max_principals}"
|
||||
)
|
||||
|
||||
# Validate key_id
|
||||
if not self.key_id or len(self.key_id) < 5:
|
||||
errors.append("key_id is missing or too short (minimum 5 characters)")
|
||||
else:
|
||||
max_id_len = config.get_int('max_key_id_length')
|
||||
if len(self.key_id) > max_id_len:
|
||||
errors.append(f"key_id exceeds maximum length of {max_id_len}")
|
||||
|
||||
# Validate expiry_hours
|
||||
if self.expiry_hours is not None:
|
||||
if not isinstance(self.expiry_hours, int) or self.expiry_hours <= 0:
|
||||
errors.append("expiry_hours must be a positive integer")
|
||||
else:
|
||||
max_validity = config.get_int('max_cert_validity_hours')
|
||||
if self.expiry_hours > max_validity:
|
||||
errors.append(
|
||||
f"Requested expiry ({self.expiry_hours}h) exceeds "
|
||||
f"maximum allowed ({max_validity}h)"
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
class SSHCertificateSigningResponse:
|
||||
"""Represents a signed SSH certificate response."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
certificate: str,
|
||||
serial: str,
|
||||
valid_after: datetime,
|
||||
valid_before: datetime,
|
||||
principals: Optional[List[str]] = None,
|
||||
):
|
||||
"""Initialize signing response.
|
||||
|
||||
Args:
|
||||
certificate: Full certificate in OpenSSH format
|
||||
serial: Certificate serial number
|
||||
valid_after: Validity start datetime
|
||||
valid_before: Validity end datetime
|
||||
principals: List of principals the cert was issued for
|
||||
"""
|
||||
self.certificate = certificate
|
||||
self.serial = serial
|
||||
self.valid_after = valid_after
|
||||
self.valid_before = valid_before
|
||||
self.principals = principals or []
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert response to dictionary."""
|
||||
return {
|
||||
'certificate': self.certificate,
|
||||
'serial': self.serial,
|
||||
'valid_after': self.valid_after.isoformat(),
|
||||
'valid_before': self.valid_before.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
class SSHCASigningService:
|
||||
"""Service for signing SSH certificates.
|
||||
|
||||
This service handles all SSH certificate signing operations.
|
||||
It uses configuration from ssh_ca_config to apply rules and limits.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the SSH CA signing service."""
|
||||
self.config = get_ssh_ca_config()
|
||||
self.logger = logger
|
||||
|
||||
def _load_ca_key_from_config(self) -> str:
|
||||
"""Load CA private key from config (local file or env var).
|
||||
|
||||
Returns:
|
||||
CA private key in PEM/OpenSSH format as string
|
||||
|
||||
Raises:
|
||||
SSHCASigningError: If key cannot be loaded
|
||||
"""
|
||||
# Check env var first
|
||||
key_content = os.environ.get('SSH_CA_PRIVATE_KEY')
|
||||
if key_content:
|
||||
return key_content
|
||||
|
||||
# Load from file path
|
||||
key_path = self.config.get_str('ca_key_path', '').strip()
|
||||
if not key_path:
|
||||
raise SSHCASigningError(
|
||||
"CA private key not configured. Set SSH_CA_PRIVATE_KEY env var "
|
||||
"or ca_key_path in etc/ssh_ca.conf"
|
||||
)
|
||||
|
||||
key_path = os.path.expandvars(os.path.expanduser(key_path))
|
||||
if not os.path.exists(key_path):
|
||||
raise SSHCASigningError(f"CA private key file not found: {key_path}")
|
||||
|
||||
with open(key_path, 'r') as f:
|
||||
return f.read()
|
||||
|
||||
def sign_certificate(
|
||||
self,
|
||||
signing_request: SSHCertificateSigningRequest,
|
||||
ca_private_key: Optional[str] = None,
|
||||
ca_obj=None,
|
||||
) -> "SSHCertificateSigningResponse":
|
||||
"""Sign an SSH certificate.
|
||||
|
||||
Args:
|
||||
signing_request: SSHCertificateSigningRequest instance
|
||||
ca_private_key: CA private key in PEM format. If not provided,
|
||||
loaded from config (ca_key_path or SSH_CA_PRIVATE_KEY env var)
|
||||
ca_obj: Optional CA model instance. When supplied its monotonic
|
||||
serial counter is incremented atomically (SELECT FOR UPDATE)
|
||||
and the resulting integer is embedded in the certificate's
|
||||
serial field. This ensures every issued cert has a unique,
|
||||
ordered, auditable serial number.
|
||||
|
||||
Returns:
|
||||
SSHCertificateSigningResponse with signed certificate
|
||||
|
||||
Raises:
|
||||
SSHCASigningError: If signing fails
|
||||
ValidationError: If request is invalid
|
||||
"""
|
||||
# Validate request
|
||||
errors = signing_request.validate()
|
||||
if errors:
|
||||
error_msg = "; ".join(errors)
|
||||
self.logger.error(f"Certificate signing validation failed: {error_msg}")
|
||||
raise ValidationError(f"Certificate signing validation failed: {error_msg}")
|
||||
|
||||
# Load CA key if not provided
|
||||
if ca_private_key is None:
|
||||
ca_private_key = self._load_ca_key_from_config()
|
||||
|
||||
try:
|
||||
# Parse CA private key
|
||||
try:
|
||||
ca_key = PrivateKey.from_string(ca_private_key)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to load CA private key: {str(e)}")
|
||||
raise SSHCASigningError(f"Invalid CA private key: {str(e)}")
|
||||
|
||||
# Parse user's public key
|
||||
try:
|
||||
user_pub_key = PublicKey.from_string(signing_request.ssh_public_key)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to parse user public key: {str(e)}")
|
||||
raise SSHCASigningError(f"Invalid user public key: {str(e)}")
|
||||
|
||||
# Create certificate
|
||||
certificate = SSHCertificate.create(
|
||||
subject_pubkey=user_pub_key,
|
||||
ca_privkey=ca_key,
|
||||
)
|
||||
|
||||
# Set validity period
|
||||
now = datetime.now(timezone.utc)
|
||||
expiry_hours = signing_request.expiry_hours or self.config.get_int('cert_validity_hours')
|
||||
valid_before = now + timedelta(hours=expiry_hours)
|
||||
|
||||
# Set certificate fields
|
||||
# sshkey-tools: user=1, host=2 (not 0)
|
||||
cert_type = 1 if signing_request.cert_type == "user" else 2
|
||||
|
||||
certificate.fields.cert_type = cert_type
|
||||
certificate.fields.key_id = signing_request.key_id
|
||||
certificate.fields.principals = signing_request.principals
|
||||
certificate.fields.valid_after = now
|
||||
certificate.fields.valid_before = valid_before
|
||||
|
||||
# ── Serial number ────────────────────────────────────────────────
|
||||
# If a CA object is provided, use its monotonic counter so every
|
||||
# certificate gets a unique, ordered, auditable serial. The
|
||||
# counter increment is flushed inside get_next_serial(); the
|
||||
# caller's commit() persists it atomically with the cert record.
|
||||
if ca_obj is not None:
|
||||
assigned_serial = ca_obj.get_next_serial()
|
||||
certificate.fields.serial = assigned_serial
|
||||
self.logger.debug(
|
||||
f"Assigned serial {assigned_serial} from CA {ca_obj.id}"
|
||||
)
|
||||
# ─────────────────────────────────────────────────────────────────
|
||||
|
||||
# Set extensions — prefer policy-provided list, fall back to standard set
|
||||
extensions = signing_request.extensions
|
||||
if not extensions:
|
||||
from gatehouse_app.models.organization.department_cert_policy import STANDARD_EXTENSIONS
|
||||
extensions = list(STANDARD_EXTENSIONS)
|
||||
|
||||
certificate.fields.extensions = extensions
|
||||
certificate.fields.critical_options = signing_request.critical_options or {}
|
||||
|
||||
# Validate certificate before signing
|
||||
if not certificate.can_sign():
|
||||
raise SSHCASigningError("Certificate cannot be signed")
|
||||
|
||||
# Sign the certificate
|
||||
certificate.sign()
|
||||
|
||||
# Verify the certificate
|
||||
try:
|
||||
certificate.verify(ca_key.public_key, raise_on_error=True)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Certificate verification failed: {str(e)}")
|
||||
raise SSHCASigningError(f"Certificate verification failed: {str(e)}")
|
||||
|
||||
# Extract serial from certificate — use the integer we assigned
|
||||
# when ca_obj was provided, otherwise fall back to whatever the
|
||||
# library generated.
|
||||
if ca_obj is not None:
|
||||
serial = str(assigned_serial)
|
||||
else:
|
||||
serial = str(certificate.fields.serial).split(":")[-1].strip() if hasattr(certificate.fields.serial, '__str__') else str(certificate.fields.serial)
|
||||
|
||||
# Build response
|
||||
cert_string = certificate.to_string()
|
||||
|
||||
self.logger.info(
|
||||
f"Successfully signed certificate: serial={serial}, "
|
||||
f"key_id={signing_request.key_id}, principals={signing_request.principals}"
|
||||
)
|
||||
|
||||
return SSHCertificateSigningResponse(
|
||||
certificate=cert_string,
|
||||
serial=serial,
|
||||
valid_after=now,
|
||||
valid_before=valid_before,
|
||||
principals=signing_request.principals,
|
||||
)
|
||||
|
||||
except (SSHCASigningError, ValidationError):
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.error(f"Unexpected error during certificate signing: {str(e)}", exc_info=True)
|
||||
raise SSHCASigningError(f"Error signing certificate: {str(e)}")
|
||||
|
||||
def verify_ca_key(self, ca_private_key: str) -> Dict[str, Any]:
|
||||
"""Verify a CA private key is valid and extract metadata.
|
||||
|
||||
Args:
|
||||
ca_private_key: CA private key in PEM format
|
||||
|
||||
Returns:
|
||||
Dictionary with key metadata (fingerprint, key_type, etc.)
|
||||
|
||||
Raises:
|
||||
SSHCASigningError: If key is invalid
|
||||
"""
|
||||
try:
|
||||
ca_key = PrivateKey.from_string(ca_private_key)
|
||||
pub_key = ca_key.public_key
|
||||
|
||||
# Compute fingerprint
|
||||
fingerprint = compute_ssh_fingerprint(pub_key.to_string())
|
||||
|
||||
# Get key type
|
||||
key_type = pub_key.keytype if hasattr(pub_key, 'keytype') else 'unknown'
|
||||
|
||||
return {
|
||||
'fingerprint': fingerprint,
|
||||
'key_type': key_type,
|
||||
'public_key': pub_key.to_string(),
|
||||
'valid': True,
|
||||
}
|
||||
except Exception as e:
|
||||
self.logger.error(f"CA key verification failed: {str(e)}")
|
||||
raise SSHCASigningError(f"Invalid CA key: {str(e)}")
|
||||
@@ -0,0 +1,375 @@
|
||||
"""SSH Key management service."""
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import subprocess
|
||||
import tempfile
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models import SSHKey, User
|
||||
from gatehouse_app.exceptions import (
|
||||
SSHKeyError,
|
||||
SSHKeyNotFoundError,
|
||||
SSHKeyAlreadyExistsError,
|
||||
SSHKeyNotVerifiedError,
|
||||
ValidationError,
|
||||
UserNotFoundError,
|
||||
)
|
||||
from gatehouse_app.utils.crypto import (
|
||||
compute_ssh_fingerprint,
|
||||
verify_ssh_key_format,
|
||||
extract_ssh_key_type,
|
||||
extract_ssh_key_comment,
|
||||
)
|
||||
from gatehouse_app.config.ssh_ca_config import get_ssh_ca_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SSHKeyService:
|
||||
"""Service for managing SSH keys."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize SSH key service."""
|
||||
self.config = get_ssh_ca_config()
|
||||
|
||||
def add_ssh_key(
|
||||
self,
|
||||
user_id: str,
|
||||
public_key: str,
|
||||
description: Optional[str] = None,
|
||||
) -> SSHKey:
|
||||
"""Add an SSH public key for a user.
|
||||
|
||||
Args:
|
||||
user_id: ID of the user
|
||||
public_key: SSH public key in OpenSSH format
|
||||
description: Optional description of the key
|
||||
|
||||
Returns:
|
||||
Created SSHKey instance
|
||||
|
||||
Raises:
|
||||
UserNotFoundError: If user doesn't exist
|
||||
SSHKeyError: If key format is invalid
|
||||
SSHKeyAlreadyExistsError: If key already exists
|
||||
"""
|
||||
# Verify user exists
|
||||
user = User.query.get(user_id)
|
||||
if not user:
|
||||
raise UserNotFoundError(f"User {user_id} not found")
|
||||
|
||||
# Validate key format
|
||||
if not verify_ssh_key_format(public_key):
|
||||
raise SSHKeyError("Invalid SSH public key format")
|
||||
|
||||
# Compute fingerprint
|
||||
try:
|
||||
fingerprint = compute_ssh_fingerprint(public_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to compute fingerprint: {str(e)}")
|
||||
raise SSHKeyError(f"Failed to compute key fingerprint: {str(e)}")
|
||||
|
||||
# Check for duplicate (including soft-deleted records — fingerprint is unique in DB)
|
||||
existing = SSHKey.query.filter_by(fingerprint=fingerprint).first()
|
||||
if existing:
|
||||
if existing.deleted_at is not None:
|
||||
# Restore the soft-deleted key: clear deleted_at and update fields
|
||||
existing.deleted_at = None
|
||||
existing.user_id = user_id
|
||||
existing.description = description or existing.description
|
||||
existing.verified = False
|
||||
existing.verified_at = None
|
||||
existing.verify_text = None
|
||||
existing.verify_text_created_at = None
|
||||
db.session.commit()
|
||||
logger.info(
|
||||
f"Restored soft-deleted SSH key for user {user_id}: "
|
||||
f"fingerprint={fingerprint}"
|
||||
)
|
||||
return existing
|
||||
raise SSHKeyAlreadyExistsError(
|
||||
f"SSH key with fingerprint {fingerprint} already exists"
|
||||
)
|
||||
|
||||
# Extract metadata
|
||||
key_type = extract_ssh_key_type(public_key)
|
||||
key_comment = extract_ssh_key_comment(public_key)
|
||||
|
||||
# Create SSH key record
|
||||
ssh_key = SSHKey(
|
||||
user_id=user_id,
|
||||
payload=public_key,
|
||||
fingerprint=fingerprint,
|
||||
description=description,
|
||||
key_type=key_type,
|
||||
key_comment=key_comment,
|
||||
verified=False,
|
||||
)
|
||||
|
||||
ssh_key.save()
|
||||
|
||||
logger.info(
|
||||
f"SSH key added for user {user_id}: "
|
||||
f"fingerprint={fingerprint}, type={key_type}"
|
||||
)
|
||||
|
||||
return ssh_key
|
||||
|
||||
def get_ssh_key(self, key_id: str) -> SSHKey:
|
||||
"""Get an SSH key by ID.
|
||||
|
||||
Args:
|
||||
key_id: SSH key ID
|
||||
|
||||
Returns:
|
||||
SSHKey instance
|
||||
|
||||
Raises:
|
||||
SSHKeyNotFoundError: If key not found
|
||||
"""
|
||||
key = SSHKey.query.filter_by(id=key_id, deleted_at=None).first()
|
||||
if not key:
|
||||
raise SSHKeyNotFoundError(f"SSH key {key_id} not found")
|
||||
return key
|
||||
|
||||
def get_user_ssh_keys(self, user_id: str) -> List[SSHKey]:
|
||||
"""Get all SSH keys for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
List of SSHKey instances
|
||||
"""
|
||||
return SSHKey.query.filter_by(user_id=user_id, deleted_at=None).all()
|
||||
|
||||
def get_user_verified_ssh_keys(self, user_id: str) -> List[SSHKey]:
|
||||
"""Get all verified SSH keys for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
List of verified SSHKey instances
|
||||
"""
|
||||
return SSHKey.query.filter_by(
|
||||
user_id=user_id,
|
||||
verified=True,
|
||||
deleted_at=None,
|
||||
).all()
|
||||
|
||||
def delete_ssh_key(self, key_id: str) -> None:
|
||||
"""Soft-delete an SSH key.
|
||||
|
||||
Args:
|
||||
key_id: SSH key ID
|
||||
|
||||
Raises:
|
||||
SSHKeyNotFoundError: If key not found
|
||||
"""
|
||||
key = self.get_ssh_key(key_id)
|
||||
key.delete()
|
||||
|
||||
logger.info(f"SSH key deleted: {key_id}")
|
||||
|
||||
def generate_verification_challenge(self, key_id: str) -> str:
|
||||
"""Generate a verification challenge for an SSH key.
|
||||
|
||||
The user must sign this challenge text with their private key
|
||||
to prove key ownership.
|
||||
|
||||
Args:
|
||||
key_id: SSH key ID
|
||||
|
||||
Returns:
|
||||
Verification challenge text
|
||||
|
||||
Raises:
|
||||
SSHKeyNotFoundError: If key not found
|
||||
"""
|
||||
key = self.get_ssh_key(key_id)
|
||||
|
||||
# Generate random challenge
|
||||
challenge = secrets.token_hex(32)
|
||||
challenge_text = f"Please sign this to verify SSH key ownership: {challenge}"
|
||||
|
||||
# Store challenge
|
||||
key.verify_text = challenge_text
|
||||
key.verify_text_created_at = datetime.utcnow()
|
||||
key.save()
|
||||
|
||||
logger.info(f"Generated verification challenge for SSH key {key_id}")
|
||||
|
||||
return challenge_text
|
||||
|
||||
def verify_ssh_key_ownership(
|
||||
self,
|
||||
key_id: str,
|
||||
signature: str,
|
||||
) -> bool:
|
||||
"""Verify SSH key ownership via signature.
|
||||
|
||||
The user must sign the verification challenge with their private key.
|
||||
We verify the signature using the public key.
|
||||
|
||||
Args:
|
||||
key_id: SSH key ID
|
||||
signature: Base64-encoded signature of the challenge
|
||||
|
||||
Returns:
|
||||
True if signature is valid
|
||||
|
||||
Raises:
|
||||
SSHKeyNotFoundError: If key not found
|
||||
SSHKeyNotVerifiedError: If challenge is stale or missing
|
||||
SSHKeyError: If verification fails
|
||||
"""
|
||||
key = self.get_ssh_key(key_id)
|
||||
|
||||
# Check if challenge exists and is not stale
|
||||
if not key.verify_text or not key.verify_text_created_at:
|
||||
raise SSHKeyNotVerifiedError("No verification challenge generated")
|
||||
|
||||
max_age = self.config.get_int('verification_challenge_max_age')
|
||||
age = datetime.utcnow() - key.verify_text_created_at
|
||||
if age.total_seconds() > (max_age * 3600):
|
||||
raise SSHKeyNotVerifiedError("Verification challenge has expired")
|
||||
|
||||
try:
|
||||
# Verify the SSH signature using ssh-keygen -Y verify.
|
||||
# The CLI signs the challenge with: ssh-keygen -Y sign -f <key> -n file <challenge>
|
||||
# We verify with: ssh-keygen -Y verify -f <allowed_signers> -I <identity> -n file -s <sig> < <message>
|
||||
#
|
||||
# allowed_signers format: "<identity> <keytype> <pubkey>"
|
||||
# We use the key fingerprint as the identity.
|
||||
|
||||
sig_bytes = base64.b64decode(signature)
|
||||
challenge_text = key.verify_text + "\n"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
allowed_signers_path = os.path.join(tmpdir, "allowed_signers")
|
||||
sig_path = os.path.join(tmpdir, "message.sig")
|
||||
message_path = os.path.join(tmpdir, "message.txt")
|
||||
|
||||
identity = key.fingerprint
|
||||
|
||||
# Write the allowed_signers file
|
||||
with open(allowed_signers_path, "w") as f:
|
||||
f.write(f"{identity} {key.payload}\n")
|
||||
|
||||
# Write the signature file
|
||||
with open(sig_path, "wb") as f:
|
||||
f.write(sig_bytes)
|
||||
|
||||
# Write the challenge message
|
||||
with open(message_path, "w") as f:
|
||||
f.write(challenge_text)
|
||||
|
||||
result = subprocess.run(
|
||||
[
|
||||
"ssh-keygen", "-Y", "verify",
|
||||
"-f", allowed_signers_path,
|
||||
"-I", identity,
|
||||
"-n", "file",
|
||||
"-s", sig_path,
|
||||
],
|
||||
stdin=open(message_path, "rb"),
|
||||
capture_output=True,
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
stderr = result.stderr.decode(errors="replace").strip()
|
||||
logger.warning(f"SSH signature verification failed for key {key_id}: {stderr}")
|
||||
raise SSHKeyError(f"Signature verification failed: {stderr}")
|
||||
|
||||
key.mark_verified()
|
||||
logger.info(f"SSH key verified: {key_id}")
|
||||
return True
|
||||
|
||||
except SSHKeyError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"SSH key verification failed: {str(e)}")
|
||||
raise SSHKeyError(f"Signature verification failed: {str(e)}")
|
||||
|
||||
def get_key_fingerprint(self, key_id: str) -> str:
|
||||
"""Get the fingerprint of an SSH key.
|
||||
|
||||
Args:
|
||||
key_id: SSH key ID
|
||||
|
||||
Returns:
|
||||
Fingerprint string
|
||||
|
||||
Raises:
|
||||
SSHKeyNotFoundError: If key not found
|
||||
"""
|
||||
key = self.get_ssh_key(key_id)
|
||||
return key.fingerprint
|
||||
|
||||
def update_ssh_key_description(self, key_id: str, description: str) -> SSHKey:
|
||||
"""Update the description of an SSH key.
|
||||
|
||||
Args:
|
||||
key_id: SSH key ID
|
||||
description: New description
|
||||
|
||||
Returns:
|
||||
Updated SSHKey instance
|
||||
|
||||
Raises:
|
||||
SSHKeyNotFoundError: If key not found
|
||||
"""
|
||||
key = self.get_ssh_key(key_id)
|
||||
key.description = description
|
||||
key.save()
|
||||
|
||||
return key
|
||||
|
||||
def cleanup_expired_challenges(self) -> int:
|
||||
"""Clean up expired verification challenges.
|
||||
|
||||
Returns:
|
||||
Number of challenges cleaned
|
||||
"""
|
||||
max_age = self.config.get_int('verification_challenge_max_age')
|
||||
threshold = datetime.utcnow() - timedelta(hours=max_age)
|
||||
|
||||
expired = SSHKey.query.filter(
|
||||
SSHKey.verify_text_created_at < threshold,
|
||||
SSHKey.verify_text_created_at.isnot(None),
|
||||
SSHKey.deleted_at.is_(None),
|
||||
).update({"verify_text": None, "verify_text_created_at": None})
|
||||
|
||||
db.session.commit()
|
||||
|
||||
logger.info(f"Cleaned up {expired} expired verification challenges")
|
||||
return expired
|
||||
|
||||
def cleanup_unverified_keys(self) -> int:
|
||||
"""Delete unverified SSH keys older than configured days.
|
||||
|
||||
Returns:
|
||||
Number of keys deleted
|
||||
"""
|
||||
days = self.config.get_int('auto_delete_unverified_days')
|
||||
threshold = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
old_unverified = SSHKey.query.filter(
|
||||
SSHKey.verified == False,
|
||||
SSHKey.created_at < threshold,
|
||||
SSHKey.deleted_at.is_(None),
|
||||
).all()
|
||||
|
||||
count = 0
|
||||
for key in old_unverified:
|
||||
key.delete()
|
||||
count += 1
|
||||
|
||||
logger.info(f"Deleted {count} unverified SSH keys older than {days} days")
|
||||
return count
|
||||
@@ -11,10 +11,51 @@ from gatehouse_app.extensions import bcrypt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# TOTP codes are valid for at most (2*window + 1) * 30s steps.
|
||||
# With window=1 that's 3 steps = 90 seconds. We use a slightly
|
||||
# generous TTL of 95 seconds to account for clock skew at boundaries.
|
||||
_TOTP_USED_CODE_TTL = 95
|
||||
|
||||
|
||||
class TOTPService:
|
||||
"""Service for TOTP operations."""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Replay-attack prevention helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _used_key(user_id: str, code: str) -> str:
|
||||
return f"totp:used:{user_id}:{code}"
|
||||
|
||||
@staticmethod
|
||||
def is_code_already_used(user_id: str, code: str) -> bool:
|
||||
"""Return True if *code* has already been accepted for *user_id*
|
||||
within the current validity window (prevents replay attacks)."""
|
||||
try:
|
||||
from gatehouse_app.extensions import redis_client
|
||||
if redis_client is None:
|
||||
return False
|
||||
return redis_client.exists(TOTPService._used_key(user_id, code)) == 1
|
||||
except Exception:
|
||||
logger.warning("Redis unavailable for TOTP replay check; allowing code")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def mark_code_used(user_id: str, code: str) -> None:
|
||||
"""Record *code* as consumed for *user_id* so it cannot be reused."""
|
||||
try:
|
||||
from gatehouse_app.extensions import redis_client
|
||||
if redis_client is None:
|
||||
return
|
||||
redis_client.setex(
|
||||
TOTPService._used_key(user_id, code),
|
||||
_TOTP_USED_CODE_TTL,
|
||||
"1",
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Redis unavailable; TOTP used-code not recorded")
|
||||
|
||||
@staticmethod
|
||||
def generate_secret() -> str:
|
||||
"""
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import logging
|
||||
from flask import current_app
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.user import User
|
||||
from gatehouse_app.models.user.user import User
|
||||
from gatehouse_app.exceptions.validation_exceptions import UserNotFoundError
|
||||
from gatehouse_app.utils.constants import AuditAction
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
|
||||
@@ -10,8 +10,8 @@ from flask import current_app
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from gatehouse_app.extensions import db, redis_client
|
||||
from gatehouse_app.models.user import User
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.models.user.user import User
|
||||
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.utils.constants import AuthMethodType, AuditAction
|
||||
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
@@ -641,6 +641,26 @@ class WebAuthnService:
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def credential_belongs_to_user(cls, credential_id: str, user: User) -> bool:
|
||||
"""Check whether *credential_id* exists and belongs to *user*.
|
||||
|
||||
Args:
|
||||
credential_id: The credential ID to look up
|
||||
user: User instance
|
||||
|
||||
Returns:
|
||||
True if the credential exists and belongs to this user, False otherwise.
|
||||
"""
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
user_id=user.id,
|
||||
method_type=AuthMethodType.WEBAUTHN,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
if not auth_method or not auth_method.provider_data:
|
||||
return False
|
||||
return auth_method.provider_data.get("credential_id") == credential_id
|
||||
|
||||
@classmethod
|
||||
def rename_credential(cls, credential_id: str, user: User, name: str) -> bool:
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
"""Encryption helpers for CA private keys stored in the database.
|
||||
|
||||
CA private keys are encrypted at rest using Fernet (AES-128-CBC + HMAC-SHA256)
|
||||
from the ``cryptography`` package. The encryption key is derived from the
|
||||
``CA_ENCRYPTION_KEY`` environment variable (or ``Flask.config["CA_ENCRYPTION_KEY"]``).
|
||||
|
||||
Key derivation
|
||||
--------------
|
||||
Fernet requires a URL-safe base64-encoded 32-byte key. We accept any string
|
||||
from the env and derive the actual Fernet key using SHA-256 so that operators
|
||||
can supply human-readable secrets without having to pre-encode them.
|
||||
|
||||
Envelope format
|
||||
---------------
|
||||
Encrypted values are stored as the string::
|
||||
|
||||
$fernet$<fernet_token>
|
||||
|
||||
The ``$fernet$`` prefix lets the code distinguish already-encrypted values from
|
||||
legacy plaintext PEM keys so that the migration path is safe and idempotent.
|
||||
|
||||
Usage
|
||||
-----
|
||||
Encrypt before storing::
|
||||
|
||||
from gatehouse_app.utils.ca_key_encryption import encrypt_ca_key
|
||||
ca.private_key = encrypt_ca_key(private_key_pem)
|
||||
|
||||
Decrypt before use::
|
||||
|
||||
from gatehouse_app.utils.ca_key_encryption import decrypt_ca_key
|
||||
plaintext_pem = decrypt_ca_key(ca.private_key)
|
||||
"""
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Prefix that marks a stored value as Fernet-encrypted
|
||||
_FERNET_PREFIX = "$fernet$"
|
||||
|
||||
|
||||
class CAKeyEncryptionError(Exception):
|
||||
"""Raised when CA key encryption or decryption fails."""
|
||||
|
||||
|
||||
def _get_fernet() -> Fernet:
|
||||
"""Build a Fernet instance from the configured encryption key.
|
||||
|
||||
Looks up ``CA_ENCRYPTION_KEY`` in the environment first, then falls back to
|
||||
the Flask app config (if a request context is active).
|
||||
|
||||
Raises:
|
||||
CAKeyEncryptionError: if no key is configured or it is the insecure
|
||||
placeholder value in a production-like environment.
|
||||
"""
|
||||
raw_key = os.environ.get("CA_ENCRYPTION_KEY")
|
||||
|
||||
if not raw_key:
|
||||
# Try Flask config if we're inside an app context
|
||||
try:
|
||||
from flask import current_app
|
||||
raw_key = current_app.config.get("CA_ENCRYPTION_KEY")
|
||||
except RuntimeError:
|
||||
pass # No app context
|
||||
|
||||
if not raw_key:
|
||||
raise CAKeyEncryptionError(
|
||||
"CA_ENCRYPTION_KEY is not set. "
|
||||
"Set this environment variable before starting the application."
|
||||
)
|
||||
|
||||
# Warn loudly when running with the placeholder in a non-test environment
|
||||
env_name = os.environ.get("FLASK_ENV", "").lower()
|
||||
if raw_key.startswith("dev-") and env_name not in ("development", "testing", "test"):
|
||||
logger.warning(
|
||||
"CA_ENCRYPTION_KEY appears to be a development placeholder. "
|
||||
"Set a strong random key for production environments."
|
||||
)
|
||||
|
||||
# Derive a 32-byte key from the raw secret via SHA-256, then URL-safe base64
|
||||
key_bytes = hashlib.sha256(raw_key.encode()).digest()
|
||||
fernet_key = base64.urlsafe_b64encode(key_bytes)
|
||||
return Fernet(fernet_key)
|
||||
|
||||
|
||||
def encrypt_ca_key(plaintext_pem: str) -> str:
|
||||
"""Encrypt a CA private key PEM string.
|
||||
|
||||
Idempotent: already-encrypted values are returned unchanged.
|
||||
|
||||
Args:
|
||||
plaintext_pem: CA private key in OpenSSH/PEM format.
|
||||
|
||||
Returns:
|
||||
Encrypted string with ``$fernet$`` prefix, safe for database storage.
|
||||
|
||||
Raises:
|
||||
CAKeyEncryptionError: if the key cannot be encrypted.
|
||||
"""
|
||||
if not plaintext_pem:
|
||||
raise CAKeyEncryptionError("Cannot encrypt an empty key")
|
||||
|
||||
# Already encrypted — do not double-encrypt
|
||||
if plaintext_pem.startswith(_FERNET_PREFIX):
|
||||
return plaintext_pem
|
||||
|
||||
try:
|
||||
fernet = _get_fernet()
|
||||
token = fernet.encrypt(plaintext_pem.encode()).decode()
|
||||
return f"{_FERNET_PREFIX}{token}"
|
||||
except CAKeyEncryptionError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise CAKeyEncryptionError(f"Failed to encrypt CA key: {exc}") from exc
|
||||
|
||||
|
||||
def decrypt_ca_key(stored_value: str) -> str:
|
||||
"""Decrypt a CA private key retrieved from the database.
|
||||
|
||||
Idempotent: plaintext (legacy) values are returned unchanged so that the
|
||||
system continues to work while a migration encrypts existing rows.
|
||||
|
||||
Args:
|
||||
stored_value: Value from ``CA.private_key`` column.
|
||||
|
||||
Returns:
|
||||
Plaintext PEM string ready for use with ``sshkey_tools``.
|
||||
|
||||
Raises:
|
||||
CAKeyEncryptionError: if decryption fails (wrong key, corrupted data).
|
||||
"""
|
||||
if not stored_value:
|
||||
raise CAKeyEncryptionError("Cannot decrypt an empty value")
|
||||
|
||||
# Legacy plaintext key — return as-is
|
||||
if not stored_value.startswith(_FERNET_PREFIX):
|
||||
logger.warning(
|
||||
"CA private key appears to be stored as plaintext. "
|
||||
"Run the migration to encrypt existing keys."
|
||||
)
|
||||
return stored_value
|
||||
|
||||
token = stored_value[len(_FERNET_PREFIX):]
|
||||
try:
|
||||
fernet = _get_fernet()
|
||||
return fernet.decrypt(token.encode()).decode()
|
||||
except InvalidToken as exc:
|
||||
raise CAKeyEncryptionError(
|
||||
"CA key decryption failed — the CA_ENCRYPTION_KEY may be incorrect "
|
||||
"or the stored key is corrupted."
|
||||
) from exc
|
||||
except CAKeyEncryptionError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise CAKeyEncryptionError(f"Unexpected decryption error: {exc}") from exc
|
||||
|
||||
|
||||
def is_encrypted(stored_value: str) -> bool:
|
||||
"""Return True if the stored value has the ``$fernet$`` envelope.
|
||||
|
||||
Args:
|
||||
stored_value: Value from ``CA.private_key`` column.
|
||||
"""
|
||||
return bool(stored_value and stored_value.startswith(_FERNET_PREFIX))
|
||||
|
||||
|
||||
def reencrypt_ca_key(stored_value: str, old_raw_key: str, new_raw_key: str) -> str:
|
||||
"""Re-encrypt a CA key with a new encryption key (for key rotation).
|
||||
|
||||
Args:
|
||||
stored_value: Current value from ``CA.private_key`` (may or may not be encrypted).
|
||||
old_raw_key: The current ``CA_ENCRYPTION_KEY`` value (raw secret string).
|
||||
new_raw_key: The new ``CA_ENCRYPTION_KEY`` value to encrypt with.
|
||||
|
||||
Returns:
|
||||
New encrypted envelope string.
|
||||
|
||||
Raises:
|
||||
CAKeyEncryptionError: if decryption or re-encryption fails.
|
||||
"""
|
||||
# Decrypt with old key
|
||||
if stored_value.startswith(_FERNET_PREFIX):
|
||||
token = stored_value[len(_FERNET_PREFIX):]
|
||||
old_key_bytes = base64.urlsafe_b64encode(hashlib.sha256(old_raw_key.encode()).digest())
|
||||
try:
|
||||
plaintext = Fernet(old_key_bytes).decrypt(token.encode()).decode()
|
||||
except InvalidToken as exc:
|
||||
raise CAKeyEncryptionError(
|
||||
"Re-encryption failed: could not decrypt with the old key."
|
||||
) from exc
|
||||
else:
|
||||
# Plaintext
|
||||
plaintext = stored_value
|
||||
|
||||
# Re-encrypt with new key
|
||||
new_key_bytes = base64.urlsafe_b64encode(hashlib.sha256(new_raw_key.encode()).digest())
|
||||
try:
|
||||
token = Fernet(new_key_bytes).encrypt(plaintext.encode()).decode()
|
||||
return f"{_FERNET_PREFIX}{token}"
|
||||
except Exception as exc:
|
||||
raise CAKeyEncryptionError(f"Re-encryption with new key failed: {exc}") from exc
|
||||
@@ -12,6 +12,16 @@ class UserStatus(str, Enum):
|
||||
COMPLIANCE_SUSPENDED = "compliance_suspended"
|
||||
|
||||
|
||||
class Role(str, Enum):
|
||||
"""Generic role definitions (hierarchy: Admin > Manager > Member > Viewer > Guest)."""
|
||||
|
||||
ADMIN = "admin"
|
||||
MANAGER = "manager"
|
||||
MEMBER = "member"
|
||||
VIEWER = "viewer"
|
||||
GUEST = "guest"
|
||||
|
||||
|
||||
class OrganizationRole(str, Enum):
|
||||
"""Organization member roles."""
|
||||
|
||||
@@ -51,6 +61,9 @@ class AuditAction(str, Enum):
|
||||
USER_REGISTER = "user.register"
|
||||
USER_UPDATE = "user.update"
|
||||
USER_DELETE = "user.delete"
|
||||
USER_HARD_DELETE = "user.hard_delete"
|
||||
USER_SUSPEND = "user.suspend"
|
||||
USER_UNSUSPEND = "user.unsuspend"
|
||||
PASSWORD_CHANGE = "user.password_change"
|
||||
PASSWORD_RESET = "user.password_reset"
|
||||
|
||||
@@ -61,6 +74,7 @@ class AuditAction(str, Enum):
|
||||
ORG_MEMBER_ADD = "org.member.add"
|
||||
ORG_MEMBER_REMOVE = "org.member.remove"
|
||||
ORG_MEMBER_ROLE_CHANGE = "org.member.role_change"
|
||||
ORG_OWNERSHIP_TRANSFERRED = "org.ownership.transferred"
|
||||
|
||||
# Session actions
|
||||
SESSION_CREATE = "session.create"
|
||||
@@ -105,6 +119,37 @@ class AuditAction(str, Enum):
|
||||
EXTERNAL_AUTH_CONFIG_UPDATE = "external_auth.config.update"
|
||||
EXTERNAL_AUTH_CONFIG_DELETE = "external_auth.config.delete"
|
||||
|
||||
# SSH Key and Certificate actions
|
||||
SSH_KEY_ADDED = "ssh.key.added"
|
||||
SSH_KEY_VERIFIED = "ssh.key.verified"
|
||||
SSH_KEY_DELETED = "ssh.key.deleted"
|
||||
SSH_KEY_VALIDATION_FAILED = "ssh.key.validation.failed"
|
||||
SSH_CERT_REQUESTED = "ssh.cert.requested"
|
||||
SSH_CERT_ISSUED = "ssh.cert.issued"
|
||||
SSH_CERT_FAILED = "ssh.cert.failed"
|
||||
SSH_CERT_REVOKED = "ssh.cert.revoked"
|
||||
SSH_CERT_EXPIRED = "ssh.cert.expired"
|
||||
|
||||
# CA actions
|
||||
CA_CREATED = "ca.created"
|
||||
CA_UPDATED = "ca.updated"
|
||||
CA_DELETED = "ca.deleted"
|
||||
CA_KEY_ROTATED = "ca.key.rotated"
|
||||
|
||||
# Principal actions
|
||||
PRINCIPAL_CREATED = "principal.created"
|
||||
PRINCIPAL_UPDATED = "principal.updated"
|
||||
PRINCIPAL_DELETED = "principal.deleted"
|
||||
PRINCIPAL_MEMBER_ADDED = "principal.member.added"
|
||||
PRINCIPAL_MEMBER_REMOVED = "principal.member.removed"
|
||||
|
||||
# Department actions
|
||||
DEPARTMENT_CREATED = "department.created"
|
||||
DEPARTMENT_UPDATED = "department.updated"
|
||||
DEPARTMENT_DELETED = "department.deleted"
|
||||
DEPARTMENT_MEMBER_ADDED = "department.member.added"
|
||||
DEPARTMENT_MEMBER_REMOVED = "department.member.removed"
|
||||
|
||||
|
||||
class OIDCGrantType(str, Enum):
|
||||
"""OIDC grant types."""
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
"""Cryptographic utilities for SSH operations."""
|
||||
import hashlib
|
||||
import base64
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def compute_ssh_fingerprint(public_key_str: str, hash_algorithm: str = "sha256") -> str:
|
||||
"""Compute the fingerprint of an SSH public key.
|
||||
|
||||
Args:
|
||||
public_key_str: SSH public key in OpenSSH format
|
||||
hash_algorithm: Hash algorithm to use (sha256, sha1, md5)
|
||||
|
||||
Returns:
|
||||
Fingerprint string in the format "algorithm:hex_digest"
|
||||
|
||||
Example:
|
||||
>>> key = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKp2..."
|
||||
>>> fp = compute_ssh_fingerprint(key)
|
||||
>>> print(fp)
|
||||
sha256:Kb+...
|
||||
"""
|
||||
if not public_key_str:
|
||||
raise ValueError("Public key string is empty")
|
||||
|
||||
# Parse OpenSSH format: "ssh-ed25519 <base64> [comment]"
|
||||
parts = public_key_str.strip().split()
|
||||
if len(parts) < 2:
|
||||
raise ValueError("Invalid OpenSSH public key format")
|
||||
|
||||
try:
|
||||
# The base64-encoded key is the second part
|
||||
key_bytes = base64.b64decode(parts[1])
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to decode public key: {str(e)}")
|
||||
|
||||
# Compute hash
|
||||
if hash_algorithm == "sha256":
|
||||
digest = hashlib.sha256(key_bytes).digest()
|
||||
# SSH format uses base64 encoding without padding
|
||||
fingerprint = base64.b64encode(digest).decode().rstrip('=')
|
||||
elif hash_algorithm == "sha1":
|
||||
digest = hashlib.sha1(key_bytes).hexdigest()
|
||||
fingerprint = digest
|
||||
elif hash_algorithm == "md5":
|
||||
digest = hashlib.md5(key_bytes).hexdigest()
|
||||
# Format as colons
|
||||
fingerprint = ':'.join(digest[i:i+2] for i in range(0, len(digest), 2))
|
||||
else:
|
||||
raise ValueError(f"Unsupported hash algorithm: {hash_algorithm}")
|
||||
|
||||
return f"{hash_algorithm}:{fingerprint}"
|
||||
|
||||
|
||||
def verify_ssh_key_format(public_key_str: str) -> bool:
|
||||
"""Verify that a string is in valid OpenSSH public key format.
|
||||
|
||||
Args:
|
||||
public_key_str: Potential SSH public key
|
||||
|
||||
Returns:
|
||||
True if valid OpenSSH format, False otherwise
|
||||
"""
|
||||
if not public_key_str or not isinstance(public_key_str, str):
|
||||
return False
|
||||
|
||||
parts = public_key_str.strip().split()
|
||||
|
||||
# Must have at least key type and key material
|
||||
if len(parts) < 2:
|
||||
return False
|
||||
|
||||
key_type = parts[0]
|
||||
|
||||
# Valid key types
|
||||
valid_types = [
|
||||
'ssh-rsa',
|
||||
'ssh-ed25519',
|
||||
'ecdsa-sha2-nistp256',
|
||||
'ecdsa-sha2-nistp384',
|
||||
'ecdsa-sha2-nistp521',
|
||||
'ssh-dss',
|
||||
]
|
||||
|
||||
if key_type not in valid_types:
|
||||
return False
|
||||
|
||||
# Try to decode base64
|
||||
try:
|
||||
base64.b64decode(parts[1])
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def extract_ssh_key_type(public_key_str: str) -> Optional[str]:
|
||||
"""Extract the key type from an OpenSSH public key.
|
||||
|
||||
Args:
|
||||
public_key_str: SSH public key in OpenSSH format
|
||||
|
||||
Returns:
|
||||
Key type (e.g., "ssh-ed25519") or None if invalid
|
||||
"""
|
||||
if not verify_ssh_key_format(public_key_str):
|
||||
return None
|
||||
|
||||
return public_key_str.strip().split()[0]
|
||||
|
||||
|
||||
def extract_ssh_key_comment(public_key_str: str) -> Optional[str]:
|
||||
"""Extract the comment from an OpenSSH public key.
|
||||
|
||||
Args:
|
||||
public_key_str: SSH public key in OpenSSH format
|
||||
|
||||
Returns:
|
||||
Comment string or None if not present
|
||||
"""
|
||||
if not verify_ssh_key_format(public_key_str):
|
||||
return None
|
||||
|
||||
parts = public_key_str.strip().split()
|
||||
if len(parts) >= 3:
|
||||
# Everything after the second part is the comment
|
||||
return ' '.join(parts[2:])
|
||||
|
||||
return None
|
||||
@@ -64,11 +64,47 @@ def login_required(f):
|
||||
session.last_activity_at = datetime.now(timezone.utc)
|
||||
from gatehouse_app import db
|
||||
db.session.commit()
|
||||
|
||||
|
||||
# Set context variables
|
||||
g.current_user = session.user
|
||||
g.current_session = session
|
||||
|
||||
|
||||
user = session.user
|
||||
token_groups: list = []
|
||||
try:
|
||||
if session.device_info:
|
||||
# device_info may carry OIDC claims stored at login time
|
||||
claims = session.device_info
|
||||
# Normalise: Gatehouse stores roles as [{"organization_id":…,"role":…}]
|
||||
roles_claim = claims.get("roles", [])
|
||||
if isinstance(roles_claim, list):
|
||||
for entry in roles_claim:
|
||||
if isinstance(entry, dict):
|
||||
role_val = entry.get("role")
|
||||
if role_val:
|
||||
token_groups.append(str(role_val))
|
||||
elif isinstance(entry, str):
|
||||
token_groups.append(entry)
|
||||
# Standard OIDC groups claim
|
||||
groups_claim = claims.get("groups", [])
|
||||
if isinstance(groups_claim, list):
|
||||
token_groups.extend(str(g_) for g_ in groups_claim if g_)
|
||||
except Exception:
|
||||
pass # Never block auth over token_groups enrichment failure
|
||||
user.token_groups = token_groups
|
||||
|
||||
# Activation check: if the user has an `activated` attribute and it is
|
||||
# explicitly False, block access. New accounts without the attribute are
|
||||
# treated as active to avoid breaking existing sessions.
|
||||
activated = getattr(user, "activated", None)
|
||||
if activated is False:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Account not yet activated. Please check your email for an activation link.",
|
||||
status=403,
|
||||
error_type="ACCOUNT_NOT_ACTIVATED",
|
||||
)
|
||||
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return decorated_function
|
||||
@@ -97,11 +133,12 @@ def require_role(*allowed_roles):
|
||||
raise ForbiddenError("Organization context required")
|
||||
|
||||
# Check user's role in the organization
|
||||
from gatehouse_app.models.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
|
||||
membership = OrganizationMember.query.filter_by(
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
if not membership:
|
||||
|
||||
@@ -111,5 +111,79 @@ def mfa_compliance_status():
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
@cli.command("configure_oauth")
|
||||
def configure_oauth():
|
||||
"""Interactively configure an OAuth provider at the application level.
|
||||
|
||||
Usage:
|
||||
python manage.py configure_oauth
|
||||
|
||||
Supported providers: google, github, microsoft
|
||||
"""
|
||||
import getpass
|
||||
from gatehouse_app.models.authentication_method import ApplicationProviderConfig
|
||||
from gatehouse_app.extensions import db
|
||||
|
||||
SUPPORTED = ["google", "github", "microsoft"]
|
||||
|
||||
print("=" * 60)
|
||||
print("OAuth Provider Configuration")
|
||||
print("=" * 60)
|
||||
print(f"Supported providers: {', '.join(SUPPORTED)}")
|
||||
|
||||
provider = input("Provider [google/github/microsoft]: ").strip().lower()
|
||||
if provider not in SUPPORTED:
|
||||
print(f"❌ Unknown provider: {provider}")
|
||||
return
|
||||
|
||||
client_id = input("Client ID: ").strip()
|
||||
if not client_id:
|
||||
print("❌ client_id is required")
|
||||
return
|
||||
|
||||
client_secret = getpass.getpass("Client Secret (leave blank to keep existing): ").strip()
|
||||
|
||||
with app.app_context():
|
||||
config = ApplicationProviderConfig.query.filter_by(provider_type=provider).first()
|
||||
if config:
|
||||
config.client_id = client_id
|
||||
if client_secret:
|
||||
config.set_client_secret(client_secret)
|
||||
config.is_enabled = True
|
||||
db.session.commit()
|
||||
print(f"✅ Updated {provider} provider config.")
|
||||
else:
|
||||
config = ApplicationProviderConfig(
|
||||
provider_type=provider,
|
||||
client_id=client_id,
|
||||
is_enabled=True,
|
||||
)
|
||||
if client_secret:
|
||||
config.set_client_secret(client_secret)
|
||||
db.session.add(config)
|
||||
db.session.commit()
|
||||
print(f"✅ Created {provider} provider config.")
|
||||
|
||||
|
||||
@cli.command("list_oauth")
|
||||
def list_oauth():
|
||||
"""List all configured OAuth providers.
|
||||
|
||||
Usage:
|
||||
python manage.py list_oauth
|
||||
"""
|
||||
from gatehouse_app.models.authentication_method import ApplicationProviderConfig
|
||||
|
||||
with app.app_context():
|
||||
configs = ApplicationProviderConfig.query.all()
|
||||
if not configs:
|
||||
print("No OAuth providers configured.")
|
||||
return
|
||||
print(f"{'Provider':<15} {'Client ID':<40} {'Enabled'}")
|
||||
print("-" * 65)
|
||||
for c in configs:
|
||||
print(f"{c.provider_type:<15} {c.client_id:<40} {c.is_enabled}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
|
||||
@@ -0,0 +1,127 @@
|
||||
"""Add Department and Principal models for SSH CA management.
|
||||
|
||||
Revision ID: 006
|
||||
Revises: 005
|
||||
Create Date: 2026-02-27 10:00:00.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '006'
|
||||
down_revision = '005'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### Department table ###
|
||||
op.create_table('departments',
|
||||
sa.Column('organization_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('organization_id', 'name', name='uix_org_dept_name')
|
||||
)
|
||||
op.create_index(op.f('ix_departments_organization_id'), 'departments', ['organization_id'], unique=False)
|
||||
op.create_index(op.f('ix_departments_name'), 'departments', ['name'], unique=False)
|
||||
|
||||
# ### DepartmentMembership table ###
|
||||
op.create_table('department_memberships',
|
||||
sa.Column('user_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('department_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['department_id'], ['departments.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('user_id', 'department_id', name='uix_user_dept')
|
||||
)
|
||||
op.create_index(op.f('ix_department_memberships_user_id'), 'department_memberships', ['user_id'], unique=False)
|
||||
op.create_index(op.f('ix_department_memberships_department_id'), 'department_memberships', ['department_id'], unique=False)
|
||||
|
||||
# ### Principal table ###
|
||||
op.create_table('principals',
|
||||
sa.Column('organization_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('organization_id', 'name', name='uix_org_principal_name')
|
||||
)
|
||||
op.create_index(op.f('ix_principals_organization_id'), 'principals', ['organization_id'], unique=False)
|
||||
op.create_index(op.f('ix_principals_name'), 'principals', ['name'], unique=False)
|
||||
|
||||
# ### PrincipalMembership table ###
|
||||
op.create_table('principal_memberships',
|
||||
sa.Column('user_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('principal_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['principal_id'], ['principals.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('user_id', 'principal_id', name='uix_user_principal')
|
||||
)
|
||||
op.create_index(op.f('ix_principal_memberships_user_id'), 'principal_memberships', ['user_id'], unique=False)
|
||||
op.create_index(op.f('ix_principal_memberships_principal_id'), 'principal_memberships', ['principal_id'], unique=False)
|
||||
|
||||
# ### DepartmentPrincipal table ###
|
||||
op.create_table('department_principals',
|
||||
sa.Column('department_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('principal_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['department_id'], ['departments.id'], ),
|
||||
sa.ForeignKeyConstraint(['principal_id'], ['principals.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('department_id', 'principal_id', name='uix_dept_principal')
|
||||
)
|
||||
op.create_index(op.f('ix_department_principals_department_id'), 'department_principals', ['department_id'], unique=False)
|
||||
op.create_index(op.f('ix_department_principals_principal_id'), 'department_principals', ['principal_id'], unique=False)
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('ix_department_principals_principal_id'), table_name='department_principals')
|
||||
op.drop_index(op.f('ix_department_principals_department_id'), table_name='department_principals')
|
||||
op.drop_table('department_principals')
|
||||
|
||||
op.drop_index(op.f('ix_principal_memberships_principal_id'), table_name='principal_memberships')
|
||||
op.drop_index(op.f('ix_principal_memberships_user_id'), table_name='principal_memberships')
|
||||
op.drop_table('principal_memberships')
|
||||
|
||||
op.drop_index(op.f('ix_principals_name'), table_name='principals')
|
||||
op.drop_index(op.f('ix_principals_organization_id'), table_name='principals')
|
||||
op.drop_table('principals')
|
||||
|
||||
op.drop_index(op.f('ix_department_memberships_department_id'), table_name='department_memberships')
|
||||
op.drop_index(op.f('ix_department_memberships_user_id'), table_name='department_memberships')
|
||||
op.drop_table('department_memberships')
|
||||
|
||||
op.drop_index(op.f('ix_departments_name'), table_name='departments')
|
||||
op.drop_index(op.f('ix_departments_organization_id'), table_name='departments')
|
||||
op.drop_table('departments')
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,173 @@
|
||||
"""Add SSH CA models: SSHKey, SSHCertificate, CA, CertificateAuditLog.
|
||||
|
||||
Revision ID: 007
|
||||
Revises: 006
|
||||
Create Date: 2026-02-27 11:00:00.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '007'
|
||||
down_revision = '006'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### CA table ###
|
||||
op.create_table('cas',
|
||||
sa.Column('organization_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('key_type', sa.Enum('ed25519', 'rsa', 'ecdsa', name='ca_key_type_enum'), nullable=False),
|
||||
sa.Column('private_key', sa.Text(), nullable=False),
|
||||
sa.Column('public_key', sa.Text(), nullable=False),
|
||||
sa.Column('fingerprint', sa.String(length=255), nullable=False),
|
||||
sa.Column('crl_enabled', sa.Boolean(), nullable=False),
|
||||
sa.Column('crl_endpoint', sa.String(length=512), nullable=True),
|
||||
sa.Column('default_cert_validity_hours', sa.Integer(), nullable=False),
|
||||
sa.Column('max_cert_validity_hours', sa.Integer(), nullable=False),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('rotated_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('rotation_reason', sa.String(length=255), nullable=True),
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('fingerprint'),
|
||||
sa.UniqueConstraint('organization_id', 'name', name='uix_org_ca_name')
|
||||
)
|
||||
op.create_index(op.f('ix_cas_organization_id'), 'cas', ['organization_id'], unique=False)
|
||||
op.create_index('idx_ca_org_active', 'cas', ['organization_id', 'is_active'], unique=False)
|
||||
|
||||
# ### SSHKey table ###
|
||||
op.create_table('ssh_keys',
|
||||
sa.Column('user_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('payload', sa.Text(), nullable=False),
|
||||
sa.Column('fingerprint', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.String(length=255), nullable=True),
|
||||
sa.Column('verified', sa.Boolean(), nullable=False),
|
||||
sa.Column('verified_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('verify_text', sa.String(length=255), nullable=True),
|
||||
sa.Column('verify_text_created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('key_type', sa.String(length=50), nullable=True),
|
||||
sa.Column('key_bits', sa.Integer(), nullable=True),
|
||||
sa.Column('key_comment', sa.String(length=255), nullable=True),
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('payload'),
|
||||
sa.UniqueConstraint('fingerprint')
|
||||
)
|
||||
op.create_index(op.f('ix_ssh_keys_user_id'), 'ssh_keys', ['user_id'], unique=False)
|
||||
op.create_index(op.f('ix_ssh_keys_fingerprint'), 'ssh_keys', ['fingerprint'], unique=False)
|
||||
op.create_index(op.f('ix_ssh_keys_verified'), 'ssh_keys', ['verified'], unique=False)
|
||||
op.create_index('idx_ssh_key_user_verified', 'ssh_keys', ['user_id', 'verified'], unique=False)
|
||||
|
||||
# ### SSHCertificate table ###
|
||||
op.create_table('ssh_certificates',
|
||||
sa.Column('ca_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('user_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('ssh_key_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('certificate', sa.Text(), nullable=False),
|
||||
sa.Column('serial', sa.String(length=255), nullable=False),
|
||||
sa.Column('key_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('cert_type', sa.Enum('user', 'host', name='ssh_cert_type_enum'), nullable=False),
|
||||
sa.Column('principals', sa.JSON(), nullable=False),
|
||||
sa.Column('valid_after', sa.DateTime(), nullable=False),
|
||||
sa.Column('valid_before', sa.DateTime(), nullable=False),
|
||||
sa.Column('revoked', sa.Boolean(), nullable=False),
|
||||
sa.Column('revoked_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('revoke_reason', sa.String(length=255), nullable=True),
|
||||
sa.Column('status', sa.Enum('requested', 'issued', 'revoked', 'expired', 'superseded', name='ssh_cert_status_enum'), nullable=False),
|
||||
sa.Column('request_ip', sa.String(length=45), nullable=True),
|
||||
sa.Column('request_user_agent', sa.String(length=512), nullable=True),
|
||||
sa.Column('critical_options', sa.JSON(), nullable=True),
|
||||
sa.Column('extensions', sa.JSON(), nullable=True),
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['ca_id'], ['cas.id'], ),
|
||||
sa.ForeignKeyConstraint(['ssh_key_id'], ['ssh_keys.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('serial')
|
||||
)
|
||||
op.create_index(op.f('ix_ssh_certificates_ca_id'), 'ssh_certificates', ['ca_id'], unique=False)
|
||||
op.create_index(op.f('ix_ssh_certificates_user_id'), 'ssh_certificates', ['user_id'], unique=False)
|
||||
op.create_index(op.f('ix_ssh_certificates_ssh_key_id'), 'ssh_certificates', ['ssh_key_id'], unique=False)
|
||||
op.create_index(op.f('ix_ssh_certificates_serial'), 'ssh_certificates', ['serial'], unique=False)
|
||||
op.create_index(op.f('ix_ssh_certificates_revoked'), 'ssh_certificates', ['revoked'], unique=False)
|
||||
op.create_index(op.f('ix_ssh_certificates_status'), 'ssh_certificates', ['status'], unique=False)
|
||||
op.create_index('idx_cert_user_status', 'ssh_certificates', ['user_id', 'status'], unique=False)
|
||||
op.create_index('idx_cert_validity', 'ssh_certificates', ['valid_after', 'valid_before'], unique=False)
|
||||
op.create_index('idx_cert_revoked', 'ssh_certificates', ['revoked', 'revoked_at'], unique=False)
|
||||
|
||||
# ### CertificateAuditLog table ###
|
||||
op.create_table('certificate_audit_logs',
|
||||
sa.Column('certificate_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('user_id', sa.String(length=36), nullable=True),
|
||||
sa.Column('action', sa.String(length=50), nullable=False),
|
||||
sa.Column('ip_address', sa.String(length=45), nullable=True),
|
||||
sa.Column('user_agent', sa.String(length=512), nullable=True),
|
||||
sa.Column('request_id', sa.String(length=36), nullable=True),
|
||||
sa.Column('message', sa.Text(), nullable=True),
|
||||
sa.Column('extra_data', sa.JSON(), nullable=True),
|
||||
sa.Column('success', sa.Boolean(), nullable=False),
|
||||
sa.Column('error_message', sa.Text(), nullable=True),
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['certificate_id'], ['ssh_certificates.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_certificate_audit_logs_certificate_id'), 'certificate_audit_logs', ['certificate_id'], unique=False)
|
||||
op.create_index(op.f('ix_certificate_audit_logs_user_id'), 'certificate_audit_logs', ['user_id'], unique=False)
|
||||
op.create_index(op.f('ix_certificate_audit_logs_action'), 'certificate_audit_logs', ['action'], unique=False)
|
||||
op.create_index('idx_cert_audit_cert_action', 'certificate_audit_logs', ['certificate_id', 'action'], unique=False)
|
||||
op.create_index('idx_cert_audit_user', 'certificate_audit_logs', ['user_id', 'created_at'], unique=False)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_index('idx_cert_audit_user', table_name='certificate_audit_logs')
|
||||
op.drop_index('idx_cert_audit_cert_action', table_name='certificate_audit_logs')
|
||||
op.drop_index(op.f('ix_certificate_audit_logs_action'), table_name='certificate_audit_logs')
|
||||
op.drop_index(op.f('ix_certificate_audit_logs_user_id'), table_name='certificate_audit_logs')
|
||||
op.drop_index(op.f('ix_certificate_audit_logs_certificate_id'), table_name='certificate_audit_logs')
|
||||
op.drop_table('certificate_audit_logs')
|
||||
|
||||
op.drop_index('idx_cert_revoked', table_name='ssh_certificates')
|
||||
op.drop_index('idx_cert_validity', table_name='ssh_certificates')
|
||||
op.drop_index('idx_cert_user_status', table_name='ssh_certificates')
|
||||
op.drop_index(op.f('ix_ssh_certificates_status'), table_name='ssh_certificates')
|
||||
op.drop_index(op.f('ix_ssh_certificates_revoked'), table_name='ssh_certificates')
|
||||
op.drop_index(op.f('ix_ssh_certificates_serial'), table_name='ssh_certificates')
|
||||
op.drop_index(op.f('ix_ssh_certificates_ssh_key_id'), table_name='ssh_certificates')
|
||||
op.drop_index(op.f('ix_ssh_certificates_user_id'), table_name='ssh_certificates')
|
||||
op.drop_index(op.f('ix_ssh_certificates_ca_id'), table_name='ssh_certificates')
|
||||
op.drop_table('ssh_certificates')
|
||||
|
||||
op.drop_index('idx_ssh_key_user_verified', table_name='ssh_keys')
|
||||
op.drop_index(op.f('ix_ssh_keys_verified'), table_name='ssh_keys')
|
||||
op.drop_index(op.f('ix_ssh_keys_fingerprint'), table_name='ssh_keys')
|
||||
op.drop_index(op.f('ix_ssh_keys_user_id'), table_name='ssh_keys')
|
||||
op.drop_table('ssh_keys')
|
||||
|
||||
op.drop_index('idx_ca_org_active', table_name='cas')
|
||||
op.drop_index(op.f('ix_cas_organization_id'), table_name='cas')
|
||||
op.drop_table('cas')
|
||||
@@ -0,0 +1,53 @@
|
||||
"""Add TOTP and WEBAUTHN to authmethodtype enum.
|
||||
|
||||
Revision ID: 008
|
||||
Revises: 007
|
||||
Create Date: 2026-02-27 15:00:00.000000
|
||||
|
||||
The original migration (001_base) created authmethodtype with only:
|
||||
PASSWORD, GOOGLE, GITHUB, MICROSOFT, SAML, OIDC
|
||||
|
||||
This migration adds the missing TOTP and WEBAUTHN values so
|
||||
has_totp_enabled() and has_webauthn_enabled() queries work correctly.
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision = '008'
|
||||
down_revision = '007'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Add TOTP to the enum (idempotent approach using DO block)
|
||||
op.execute("""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_enum
|
||||
WHERE enumlabel = 'TOTP'
|
||||
AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'authmethodtype')
|
||||
) THEN
|
||||
ALTER TYPE authmethodtype ADD VALUE 'TOTP';
|
||||
END IF;
|
||||
END$$;
|
||||
""")
|
||||
op.execute("""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_enum
|
||||
WHERE enumlabel = 'WEBAUTHN'
|
||||
AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'authmethodtype')
|
||||
) THEN
|
||||
ALTER TYPE authmethodtype ADD VALUE 'WEBAUTHN';
|
||||
END IF;
|
||||
END$$;
|
||||
""")
|
||||
|
||||
|
||||
def downgrade():
|
||||
# PostgreSQL does not support removing enum values; downgrade is a no-op.
|
||||
pass
|
||||
@@ -0,0 +1,61 @@
|
||||
"""Sync auditaction enum with all AuditAction Python enum values.
|
||||
|
||||
Revision ID: 009
|
||||
Revises: 008
|
||||
Create Date: 2026-02-27 15:20:00.000000
|
||||
|
||||
The auditaction DB enum was only created with the initial 17 values from 001_base.py.
|
||||
All TOTP, WebAuthn, OAuth, SSH, CA, Principal, and Department audit actions were added
|
||||
to the Python enum but never synced to the DB type.
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision = '009'
|
||||
down_revision = '008'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
MISSING_VALUES = [
|
||||
'TOTP_ENROLL_INITIATED', 'TOTP_ENROLL_COMPLETED', 'TOTP_VERIFY_SUCCESS',
|
||||
'TOTP_VERIFY_FAILED', 'TOTP_DISABLED', 'TOTP_BACKUP_CODE_USED',
|
||||
'TOTP_BACKUP_CODES_REGENERATED', 'WEBAUTHN_REGISTER_INITIATED',
|
||||
'WEBAUTHN_REGISTER_COMPLETED', 'WEBAUTHN_REGISTER_FAILED',
|
||||
'WEBAUTHN_LOGIN_INITIATED', 'WEBAUTHN_LOGIN_SUCCESS', 'WEBAUTHN_LOGIN_FAILED',
|
||||
'WEBAUTHN_CREDENTIAL_DELETED', 'WEBAUTHN_CREDENTIAL_RENAMED',
|
||||
'ORG_SECURITY_POLICY_UPDATE', 'USER_SECURITY_POLICY_OVERRIDE_UPDATE',
|
||||
'MFA_POLICY_USER_SUSPENDED', 'MFA_POLICY_USER_COMPLIANT',
|
||||
'EXTERNAL_AUTH_LINK_INITIATED', 'EXTERNAL_AUTH_LINK_COMPLETED',
|
||||
'EXTERNAL_AUTH_LINK_FAILED', 'EXTERNAL_AUTH_UNLINK', 'EXTERNAL_AUTH_LOGIN',
|
||||
'EXTERNAL_AUTH_LOGIN_FAILED', 'EXTERNAL_AUTH_TOKEN_REFRESH',
|
||||
'EXTERNAL_AUTH_CONFIG_CREATE', 'EXTERNAL_AUTH_CONFIG_UPDATE',
|
||||
'EXTERNAL_AUTH_CONFIG_DELETE', 'SSH_KEY_ADDED', 'SSH_KEY_VERIFIED',
|
||||
'SSH_KEY_DELETED', 'SSH_KEY_VALIDATION_FAILED', 'SSH_CERT_REQUESTED',
|
||||
'SSH_CERT_ISSUED', 'SSH_CERT_FAILED', 'SSH_CERT_REVOKED', 'SSH_CERT_EXPIRED',
|
||||
'CA_CREATED', 'CA_UPDATED', 'CA_DELETED', 'CA_KEY_ROTATED',
|
||||
'PRINCIPAL_CREATED', 'PRINCIPAL_UPDATED', 'PRINCIPAL_DELETED',
|
||||
'PRINCIPAL_MEMBER_ADDED', 'PRINCIPAL_MEMBER_REMOVED',
|
||||
'DEPARTMENT_CREATED', 'DEPARTMENT_UPDATED', 'DEPARTMENT_DELETED',
|
||||
'DEPARTMENT_MEMBER_ADDED', 'DEPARTMENT_MEMBER_REMOVED',
|
||||
]
|
||||
|
||||
|
||||
def upgrade():
|
||||
for val in MISSING_VALUES:
|
||||
op.execute(f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_enum
|
||||
WHERE enumlabel = '{val}'
|
||||
AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'auditaction')
|
||||
) THEN
|
||||
ALTER TYPE auditaction ADD VALUE '{val}';
|
||||
END IF;
|
||||
END$$;
|
||||
""")
|
||||
|
||||
|
||||
def downgrade():
|
||||
# PostgreSQL does not support removing enum values; downgrade is a no-op.
|
||||
pass
|
||||
@@ -0,0 +1,50 @@
|
||||
"""add password reset and email verification token tables
|
||||
|
||||
Revision ID: 010_password_reset_email_verify
|
||||
Revises: 009_sync_auditaction_enum
|
||||
Create Date: 2025-01-01 00:00:00.000000
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '010_password_reset_email_verify'
|
||||
down_revision = '009'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
'password_reset_tokens',
|
||||
sa.Column('id', sa.String(36), primary_key=True, nullable=False),
|
||||
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False),
|
||||
sa.Column('token', sa.String(128), nullable=False, unique=True),
|
||||
sa.Column('expires_at', sa.DateTime, nullable=False),
|
||||
sa.Column('used_at', sa.DateTime, nullable=True),
|
||||
sa.Column('created_at', sa.DateTime, nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime, nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime, nullable=True),
|
||||
)
|
||||
op.create_index('ix_password_reset_tokens_user_id', 'password_reset_tokens', ['user_id'])
|
||||
op.create_index('ix_password_reset_tokens_token', 'password_reset_tokens', ['token'])
|
||||
|
||||
op.create_table(
|
||||
'email_verification_tokens',
|
||||
sa.Column('id', sa.String(36), primary_key=True, nullable=False),
|
||||
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False),
|
||||
sa.Column('token', sa.String(128), nullable=False, unique=True),
|
||||
sa.Column('expires_at', sa.DateTime, nullable=False),
|
||||
sa.Column('used_at', sa.DateTime, nullable=True),
|
||||
sa.Column('created_at', sa.DateTime, nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime, nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime, nullable=True),
|
||||
)
|
||||
op.create_index('ix_email_verification_tokens_user_id', 'email_verification_tokens', ['user_id'])
|
||||
op.create_index('ix_email_verification_tokens_token', 'email_verification_tokens', ['token'])
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_table('email_verification_tokens')
|
||||
op.drop_table('password_reset_tokens')
|
||||
@@ -0,0 +1,38 @@
|
||||
"""add org_invite_tokens table
|
||||
|
||||
Revision ID: 011_org_invite_tokens
|
||||
Revises: 010_password_reset_email_verify
|
||||
Create Date: 2025-01-01 00:00:00.000000
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision = '011_org_invite_tokens'
|
||||
down_revision = '010_password_reset_email_verify'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
'org_invite_tokens',
|
||||
sa.Column('id', sa.String(36), primary_key=True, nullable=False),
|
||||
sa.Column('organization_id', sa.String(36), sa.ForeignKey('organizations.id', ondelete='CASCADE'), nullable=False),
|
||||
sa.Column('invited_by_id', sa.String(36), sa.ForeignKey('users.id', ondelete='SET NULL'), nullable=True),
|
||||
sa.Column('email', sa.String(255), nullable=False),
|
||||
sa.Column('role', sa.String(64), nullable=False, server_default='member'),
|
||||
sa.Column('token', sa.String(128), nullable=False, unique=True),
|
||||
sa.Column('expires_at', sa.DateTime, nullable=False),
|
||||
sa.Column('accepted_at', sa.DateTime, nullable=True),
|
||||
sa.Column('created_at', sa.DateTime, nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime, nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime, nullable=True),
|
||||
)
|
||||
op.create_index('ix_org_invite_tokens_organization_id', 'org_invite_tokens', ['organization_id'])
|
||||
op.create_index('ix_org_invite_tokens_email', 'org_invite_tokens', ['email'])
|
||||
op.create_index('ix_org_invite_tokens_token', 'org_invite_tokens', ['token'])
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_table('org_invite_tokens')
|
||||
@@ -0,0 +1,33 @@
|
||||
"""Make CA.organization_id nullable (system CA) and add cert_id to sign response
|
||||
|
||||
Revision ID: 012_ca_nullable_org_and_cert_serial
|
||||
Revises: 011_org_invite_tokens
|
||||
Create Date: 2025-01-01 00:00:00.000000
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision = '012_ca_nullable_org'
|
||||
down_revision = '011_org_invite_tokens'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Allow CA records without an org (e.g. the global system-config CA)
|
||||
with op.batch_alter_table('cas', schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
'organization_id',
|
||||
existing_type=sa.String(36),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
with op.batch_alter_table('cas', schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
'organization_id',
|
||||
existing_type=sa.String(36),
|
||||
nullable=False,
|
||||
)
|
||||
@@ -0,0 +1,42 @@
|
||||
"""Add ca_type column to cas table (user/host).
|
||||
|
||||
Revision ID: 013
|
||||
Revises: d34bfb72844e
|
||||
Create Date: 2026-02-28 23:00:00.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '013'
|
||||
down_revision = 'd34bfb72844e'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Create the enum type first (PostgreSQL requires this)
|
||||
ca_type_enum = sa.Enum('user', 'host', name='ca_type_enum')
|
||||
ca_type_enum.create(op.get_bind(), checkfirst=True)
|
||||
|
||||
# Add ca_type column with a default of 'user' so existing CAs stay valid
|
||||
op.add_column(
|
||||
'cas',
|
||||
sa.Column(
|
||||
'ca_type',
|
||||
ca_type_enum,
|
||||
nullable=False,
|
||||
server_default='user',
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_column('cas', 'ca_type')
|
||||
# Drop the enum type (PostgreSQL only; SQLite ignores)
|
||||
try:
|
||||
op.execute("DROP TYPE IF EXISTS ca_type_enum")
|
||||
except Exception:
|
||||
pass
|
||||
@@ -0,0 +1,44 @@
|
||||
"""add_department_cert_policies
|
||||
|
||||
Adds the department_cert_policies table which stores per-department
|
||||
SSH certificate issuance rules:
|
||||
- whether users may choose their own expiry
|
||||
- default and maximum expiry durations
|
||||
- allowed SSH certificate extensions
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "014_add_dept_cert_policy"
|
||||
down_revision = "013"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"department_cert_policies",
|
||||
sa.Column("id", sa.String(36), primary_key=True),
|
||||
sa.Column("department_id", sa.String(36), sa.ForeignKey("departments.id"), nullable=False, unique=True),
|
||||
# Whether users are allowed to specify their own expiry (up to max)
|
||||
sa.Column("allow_user_expiry", sa.Boolean(), nullable=False, server_default="0"),
|
||||
# Default validity in hours (used when user doesn't specify, or not allowed to)
|
||||
sa.Column("default_expiry_hours", sa.Integer(), nullable=False, server_default="1"),
|
||||
# Hard cap on validity; admin cannot be exceeded
|
||||
sa.Column("max_expiry_hours", sa.Integer(), nullable=False, server_default="24"),
|
||||
# JSON list of extension names that are enabled for this department
|
||||
# e.g. ["permit-pty", "permit-agent-forwarding"]
|
||||
sa.Column("allowed_extensions", sa.JSON(), nullable=False, server_default='["permit-pty","permit-agent-forwarding","permit-X11-forwarding","permit-port-forwarding","permit-user-rc"]'),
|
||||
# Admin-defined custom extension names beyond the standard five
|
||||
sa.Column("custom_extensions", sa.JSON(), nullable=False, server_default="[]"),
|
||||
sa.Column("created_at", sa.DateTime(), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(), nullable=True),
|
||||
sa.Column("deleted_at", sa.DateTime(), nullable=True),
|
||||
)
|
||||
op.create_index("idx_dept_cert_policy_dept", "department_cert_policies", ["department_id"])
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_index("idx_dept_cert_policy_dept", "department_cert_policies")
|
||||
op.drop_table("department_cert_policies")
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Add USER_SUSPEND and USER_UNSUSPEND to auditaction enum.
|
||||
|
||||
Revision ID: 015_add_user_suspend_audit_actions
|
||||
Revises: 014_add_dept_cert_policy
|
||||
Create Date: 2026-03-02
|
||||
|
||||
USER_SUSPEND and USER_UNSUSPEND were added to the Python AuditAction enum
|
||||
but were never synced to the PostgreSQL auditaction type, causing a
|
||||
DataError (invalid enum value) whenever an admin suspends or unsuspends a user.
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
revision = "015_user_suspend_audit"
|
||||
down_revision = "014_add_dept_cert_policy"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
for val in ("USER_SUSPEND", "USER_UNSUSPEND"):
|
||||
op.execute(f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_enum
|
||||
WHERE enumlabel = '{val}'
|
||||
AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'auditaction')
|
||||
) THEN
|
||||
ALTER TYPE auditaction ADD VALUE '{val}';
|
||||
END IF;
|
||||
END$$;
|
||||
""")
|
||||
|
||||
|
||||
def downgrade():
|
||||
# PostgreSQL does not support removing enum values; downgrade is a no-op.
|
||||
pass
|
||||
@@ -0,0 +1,168 @@
|
||||
"""Encrypt existing plaintext CA private keys at rest.
|
||||
|
||||
Revision ID: 016_encrypt_existing_ca_keys
|
||||
Revises: 015_add_user_suspend_audit_actions
|
||||
Create Date: 2026-03-02
|
||||
|
||||
All CA private keys created before this migration were stored as plaintext PEM
|
||||
strings in the ``cas.private_key`` column. This migration detects those rows
|
||||
(by checking for the absence of the ``$fernet$`` prefix that encrypted values
|
||||
carry) and re-encrypts them with the key derived from ``CA_ENCRYPTION_KEY``.
|
||||
|
||||
The migration is safe to re-run: already-encrypted rows are left untouched.
|
||||
|
||||
Prerequisites
|
||||
-------------
|
||||
``CA_ENCRYPTION_KEY`` must be set in the environment before running this
|
||||
migration. The same value must be configured for the running application.
|
||||
|
||||
To roll back to plaintext (downgrade):
|
||||
The ``downgrade()`` function decrypts all rows back to plaintext PEM. This is
|
||||
provided only for emergency rollback and should not be used in production once
|
||||
the system has been running with encrypted keys.
|
||||
"""
|
||||
import os
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Alembic revision identifiers
|
||||
revision = "016_encrypt_ca_keys"
|
||||
down_revision = "015_user_suspend_audit"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
_FERNET_PREFIX = "$fernet$"
|
||||
|
||||
|
||||
def _get_fernet():
|
||||
"""Build a Fernet instance from CA_ENCRYPTION_KEY env var."""
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
raw_key = os.environ.get("CA_ENCRYPTION_KEY")
|
||||
if not raw_key:
|
||||
raise RuntimeError(
|
||||
"CA_ENCRYPTION_KEY environment variable is not set. "
|
||||
"Set it before running this migration."
|
||||
)
|
||||
key_bytes = base64.urlsafe_b64encode(hashlib.sha256(raw_key.encode()).digest())
|
||||
return Fernet(key_bytes)
|
||||
|
||||
|
||||
def upgrade():
|
||||
"""Encrypt plaintext CA private keys."""
|
||||
bind = op.get_bind()
|
||||
session = Session(bind=bind)
|
||||
|
||||
try:
|
||||
fernet = _get_fernet()
|
||||
except RuntimeError as exc:
|
||||
raise RuntimeError(str(exc)) from exc
|
||||
|
||||
# Fetch all non-deleted CA rows
|
||||
rows = session.execute(
|
||||
sa.text("SELECT id, private_key FROM cas WHERE deleted_at IS NULL")
|
||||
).fetchall()
|
||||
|
||||
encrypted_count = 0
|
||||
skipped_count = 0
|
||||
|
||||
for row in rows:
|
||||
ca_id, private_key = row[0], row[1]
|
||||
|
||||
if not private_key:
|
||||
logger.warning(f"CA {ca_id} has empty private_key — skipping")
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
if private_key.startswith(_FERNET_PREFIX):
|
||||
# Already encrypted
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
# Encrypt
|
||||
try:
|
||||
token = fernet.encrypt(private_key.encode()).decode()
|
||||
encrypted_value = f"{_FERNET_PREFIX}{token}"
|
||||
session.execute(
|
||||
sa.text("UPDATE cas SET private_key = :pk WHERE id = :id"),
|
||||
{"pk": encrypted_value, "id": ca_id},
|
||||
)
|
||||
encrypted_count += 1
|
||||
logger.info(f"Encrypted private key for CA {ca_id}")
|
||||
except Exception as exc:
|
||||
session.rollback()
|
||||
raise RuntimeError(
|
||||
f"Failed to encrypt private key for CA {ca_id}: {exc}"
|
||||
) from exc
|
||||
|
||||
session.commit()
|
||||
logger.info(
|
||||
f"CA key encryption migration complete: "
|
||||
f"{encrypted_count} encrypted, {skipped_count} skipped"
|
||||
)
|
||||
print(
|
||||
f" [016_encrypt_ca_keys] {encrypted_count} CA private key(s) encrypted, "
|
||||
f"{skipped_count} already encrypted or empty."
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
"""Decrypt CA private keys back to plaintext (emergency rollback only)."""
|
||||
bind = op.get_bind()
|
||||
session = Session(bind=bind)
|
||||
|
||||
try:
|
||||
fernet = _get_fernet()
|
||||
except RuntimeError as exc:
|
||||
raise RuntimeError(str(exc)) from exc
|
||||
|
||||
rows = session.execute(
|
||||
sa.text("SELECT id, private_key FROM cas WHERE deleted_at IS NULL")
|
||||
).fetchall()
|
||||
|
||||
decrypted_count = 0
|
||||
skipped_count = 0
|
||||
|
||||
for row in rows:
|
||||
ca_id, private_key = row[0], row[1]
|
||||
|
||||
if not private_key or not private_key.startswith(_FERNET_PREFIX):
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
token = private_key[len(_FERNET_PREFIX):]
|
||||
try:
|
||||
from cryptography.fernet import InvalidToken
|
||||
try:
|
||||
plaintext = fernet.decrypt(token.encode()).decode()
|
||||
except InvalidToken as exc:
|
||||
raise RuntimeError(
|
||||
f"Downgrade failed: cannot decrypt CA {ca_id} — wrong key or corrupted data."
|
||||
) from exc
|
||||
|
||||
session.execute(
|
||||
sa.text("UPDATE cas SET private_key = :pk WHERE id = :id"),
|
||||
{"pk": plaintext, "id": ca_id},
|
||||
)
|
||||
decrypted_count += 1
|
||||
logger.warning(f"Decrypted (plaintext restore) private key for CA {ca_id}")
|
||||
except RuntimeError:
|
||||
session.rollback()
|
||||
raise
|
||||
|
||||
session.commit()
|
||||
logger.warning(
|
||||
f"CA key decryption (downgrade) complete: "
|
||||
f"{decrypted_count} decrypted, {skipped_count} skipped"
|
||||
)
|
||||
print(
|
||||
f" [016_encrypt_ca_keys] DOWNGRADE: {decrypted_count} CA private key(s) "
|
||||
f"decrypted to plaintext. WARNING: keys are now unencrypted at rest."
|
||||
)
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Add monotonic serial counter to CAs table.
|
||||
|
||||
Each CA now owns a `next_serial_number` (BigInteger) that is atomically
|
||||
incremented every time a certificate is signed. This guarantees:
|
||||
- Serials are unique per CA
|
||||
- Serials are monotonically increasing (auditable, no gaps by accident)
|
||||
- The value embedded in the OpenSSH certificate matches what is stored
|
||||
in the `ssh_certificates.serial` column
|
||||
|
||||
Revision ID: 017_add_ca_serial_counter
|
||||
Revises: 016_encrypt_ca_keys
|
||||
Create Date: 2026-03-02
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "017_add_ca_serial_counter"
|
||||
down_revision = "016_encrypt_ca_keys"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
with op.batch_alter_table("cas", schema=None) as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"next_serial_number",
|
||||
sa.BigInteger(),
|
||||
nullable=False,
|
||||
server_default="1",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
with op.batch_alter_table("cas", schema=None) as batch_op:
|
||||
batch_op.drop_column("next_serial_number")
|
||||
@@ -0,0 +1,52 @@
|
||||
"""Add ORG_OWNERSHIP_TRANSFERRED and USER_HARD_DELETE to auditaction enum.
|
||||
|
||||
Revision ID: 018_audit_enum_values
|
||||
Revises: 017_add_ca_serial_counter
|
||||
Create Date: 2026-03-02
|
||||
|
||||
ORG_OWNERSHIP_TRANSFERRED and USER_HARD_DELETE were added to the Python
|
||||
AuditAction enum but were never synced to the PostgreSQL auditaction type,
|
||||
causing a DataError (invalid enum value) when transferring org ownership
|
||||
or hard-deleting a user.
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
revision = "018_audit_enum_values"
|
||||
down_revision = "017_add_ca_serial_counter"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ALTER TYPE ... ADD VALUE cannot run inside a transaction block in PostgreSQL.
|
||||
# Alembic has already opened a transaction on the connection by the time our
|
||||
# upgrade() runs, so we must:
|
||||
# 1. Roll back that open transaction on the raw psycopg2 connection.
|
||||
# 2. Switch to autocommit so the ALTER TYPE runs outside any transaction.
|
||||
# 3. Restore the previous state afterwards.
|
||||
conn = op.get_bind()
|
||||
# SQLAlchemy 2.x: conn.connection is a _ConnectionFairy; .driver_connection is psycopg2
|
||||
fairy = conn.connection
|
||||
raw = getattr(fairy, "driver_connection", None) or getattr(fairy, "dbapi_connection", fairy)
|
||||
# Roll back the open transaction so psycopg2 allows us to change autocommit.
|
||||
raw.rollback()
|
||||
old_autocommit = raw.autocommit
|
||||
raw.autocommit = True
|
||||
try:
|
||||
with raw.cursor() as cur:
|
||||
for val in ("ORG_OWNERSHIP_TRANSFERRED", "USER_HARD_DELETE"):
|
||||
cur.execute(
|
||||
"SELECT 1 FROM pg_enum "
|
||||
"WHERE enumlabel = %s "
|
||||
"AND enumtypid = (SELECT oid FROM pg_type WHERE typname = 'auditaction')",
|
||||
(val,),
|
||||
)
|
||||
if not cur.fetchone():
|
||||
cur.execute(f"ALTER TYPE auditaction ADD VALUE '{val}'")
|
||||
finally:
|
||||
raw.autocommit = old_autocommit
|
||||
|
||||
|
||||
def downgrade():
|
||||
# PostgreSQL does not support removing enum values; downgrade is a no-op.
|
||||
pass
|
||||
@@ -0,0 +1,50 @@
|
||||
"""add_activation_fields_and_ca_permissions
|
||||
|
||||
Revision ID: d34bfb72844e
|
||||
Revises: 012_ca_nullable_org
|
||||
Create Date: 2026-02-28 18:06:47.328552
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'd34bfb72844e'
|
||||
down_revision = '012_ca_nullable_org'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Create ca_permissions table
|
||||
op.create_table(
|
||||
'ca_permissions',
|
||||
sa.Column('ca_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('user_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('permission', sa.String(length=50), nullable=False),
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['ca_id'], ['cas.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('ca_id', 'user_id', name='uix_ca_permission'),
|
||||
)
|
||||
op.create_index('ix_ca_permissions_ca_id', 'ca_permissions', ['ca_id'], unique=False)
|
||||
op.create_index('ix_ca_permissions_user_id', 'ca_permissions', ['user_id'], unique=False)
|
||||
|
||||
# Add activation columns to users
|
||||
op.add_column('users', sa.Column('activated', sa.Boolean(), nullable=False,
|
||||
server_default=sa.text('true')))
|
||||
op.add_column('users', sa.Column('activation_key', sa.String(length=128), nullable=True))
|
||||
op.create_index('ix_users_activation_key', 'users', ['activation_key'], unique=True)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_index('ix_users_activation_key', table_name='users')
|
||||
op.drop_column('users', 'activation_key')
|
||||
op.drop_column('users', 'activated')
|
||||
op.drop_index('ix_ca_permissions_user_id', table_name='ca_permissions')
|
||||
op.drop_index('ix_ca_permissions_ca_id', table_name='ca_permissions')
|
||||
op.drop_table('ca_permissions')
|
||||
@@ -14,7 +14,7 @@ Flask-Marshmallow==0.15.0
|
||||
marshmallow-sqlalchemy==0.29.0
|
||||
|
||||
# Security
|
||||
bcrypt==4.1.2
|
||||
bcrypt==4.2.0
|
||||
Flask-Bcrypt==1.0.1
|
||||
pyotp==2.9.0
|
||||
|
||||
@@ -24,7 +24,7 @@ cbor2==5.6.0
|
||||
|
||||
# JWT / OIDC
|
||||
PyJWT==2.8.0
|
||||
cryptography==41.0.7
|
||||
cryptography==42.0.7
|
||||
|
||||
# CORS
|
||||
Flask-CORS==4.0.0
|
||||
@@ -47,4 +47,7 @@ Flask-Limiter==3.5.0
|
||||
|
||||
# Logging
|
||||
python-json-logger==2.0.7
|
||||
qrcode[pil]
|
||||
qrcode[pil]
|
||||
|
||||
# SSH CA Certificate signing
|
||||
sshkey-tools==0.11.3
|
||||
|
||||
@@ -20,3 +20,25 @@ watchdog==3.0.0
|
||||
|
||||
# Documentation
|
||||
sphinx==7.2.6
|
||||
|
||||
# Web framework & Database
|
||||
Flask==3.0.0
|
||||
Flask-SQLAlchemy==3.1.1
|
||||
Flask-Migrate==4.0.5
|
||||
sqlalchemy-cockroachdb==2.0.3
|
||||
|
||||
# Utilities
|
||||
colorlog==6.8.0
|
||||
coloredlogs==15.0.1
|
||||
prettytable==3.10.2
|
||||
tabulate==0.9.0
|
||||
requests==2.31.0
|
||||
pytz==2023.3
|
||||
python-dotenv==1.0.0
|
||||
pydantic==2.5.0
|
||||
PyJWT==2.8.0
|
||||
cryptography==42.0.7
|
||||
pycryptodome==3.20.0
|
||||
psycopg2-binary==2.9.9
|
||||
sshkey-tools==0.11.3
|
||||
sendgrid==6.11.0
|
||||
|
||||
@@ -15,11 +15,11 @@ load_dotenv()
|
||||
|
||||
from gatehouse_app import create_app
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.user import User
|
||||
from gatehouse_app.models.organization import Organization
|
||||
from gatehouse_app.models.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.models.oidc_client import OIDCClient
|
||||
from gatehouse_app.models.user.user import User
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.auth.authentication_method import AuthenticationMethod
|
||||
from gatehouse_app.models.oidc.oidc_client import OIDCClient
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
from gatehouse_app.services.organization_service import OrganizationService
|
||||
from gatehouse_app.utils.constants import OrganizationRole, UserStatus, AuthMethodType
|
||||
|
||||
Reference in New Issue
Block a user