Merge branch 'main' into v1.01/stable
This commit is contained in:
+144
@@ -0,0 +1,144 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
Pipfile.lock
|
||||
|
||||
# PEP 582
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# IDEs
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Project specific
|
||||
|
||||
*.db
|
||||
flask_session/
|
||||
|
||||
# Opencode files and folders
|
||||
.opencode/
|
||||
.swarm/
|
||||
SWARM_PLAN.*
|
||||
@@ -0,0 +1,40 @@
|
||||
FROM python:3.11-slim as builder
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
libpq-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN python -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
WORKDIR /app
|
||||
COPY requirements/base.txt requirements/base.txt
|
||||
COPY requirements/production.txt requirements/production.txt
|
||||
|
||||
RUN pip install --no-cache-dir --upgrade pip wheel && \
|
||||
pip install --no-cache-dir -r requirements/production.txt
|
||||
|
||||
FROM python:3.11-slim
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libpq5 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN groupadd --gid 1000 appgroup && \
|
||||
useradd --uid 1000 --gid appgroup --shell /bin/bash --create-home appuser
|
||||
|
||||
COPY --from=builder /opt/venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
WORKDIR /app
|
||||
COPY --chown=appuser:appgroup . .
|
||||
|
||||
RUN mkdir -p /app/logs && chown -R appuser:appgroup /app/logs
|
||||
|
||||
USER appuser
|
||||
|
||||
HEALTHCHECK --interval=60s --timeout=10s --start-period=10s --retries=3 \
|
||||
CMD pgrep -f "job_runner" || exit 1
|
||||
|
||||
CMD ["python", "scripts/job_runner.py"]
|
||||
@@ -154,6 +154,7 @@ Copy `.env.example` to `.env` and configure:
|
||||
- `POST /api/v1/auth/logout` - Logout
|
||||
- `GET /api/v1/auth/me` - Get current user
|
||||
- `GET /api/v1/auth/sessions` - Get user sessions
|
||||
- `POST /api/v1/auth/sessions/refresh` - Extend session idle window
|
||||
- `DELETE /api/v1/auth/sessions/:id` - Revoke session
|
||||
|
||||
### Users
|
||||
@@ -174,6 +175,9 @@ Copy `.env.example` to `.env` and configure:
|
||||
- `PATCH /api/v1/organizations/:id/members/:userId/role` - Update role
|
||||
|
||||
|
||||
### Contact (Public — No Auth Required)
|
||||
- `POST /api/v1/contact` - Submit a contact enquiry (demo request, sales enquiry, general, or support). Rate limited to 5 requests per IP per hour. Sends an email to info@secuird.tech.
|
||||
|
||||
### Health
|
||||
- `GET /api/health` - Health check
|
||||
|
||||
@@ -261,6 +265,52 @@ gunicorn -w 4 -b 0.0.0.0:8000 wsgi:app
|
||||
- Request ID tracking for audit trails
|
||||
|
||||
|
||||
## Session Management
|
||||
|
||||
Sessions are database-backed bearer tokens stored in PostgreSQL. Each session is created at login and validated on every authenticated request via the `login_required` decorator.
|
||||
|
||||
### Sliding Timeout
|
||||
|
||||
Sessions use a **sliding window** model with two independent limits:
|
||||
|
||||
| Timeout | Default | Env Var | Behaviour |
|
||||
|---------|---------|---------|-----------|
|
||||
| **Idle** | 15 min | `SESSION_IDLE_TIMEOUT` | Extends automatically on every request. If no request is made within this window the session expires. |
|
||||
| **Absolute** | 8 h | `SESSION_ABSOLUTE_TIMEOUT` | Hard cap measured from session creation. Activity cannot extend a session beyond this point. |
|
||||
|
||||
Every authenticated request resets the idle clock by calling `Session.refresh()`, which sets `expires_at = now + idle_timeout` — but never past `created_at + absolute_timeout`. This means:
|
||||
|
||||
- An active user stays logged in indefinitely **up to** the absolute cap.
|
||||
- An idle user is logged out after the idle timeout.
|
||||
- No session can survive longer than the absolute timeout regardless of activity.
|
||||
|
||||
### Configuration
|
||||
|
||||
Override defaults via environment variables:
|
||||
|
||||
```bash
|
||||
SESSION_IDLE_TIMEOUT=900 # seconds (15 min)
|
||||
SESSION_ABSOLUTE_TIMEOUT=28800 # seconds (8 h)
|
||||
```
|
||||
|
||||
### Cleanup
|
||||
|
||||
Expired sessions are soft-marked as `EXPIRED` by the `cleanup_sessions` job. Run it periodically via the job runner:
|
||||
|
||||
```bash
|
||||
python manage.py cleanup_sessions
|
||||
|
||||
# Or via the job runner (Docker):
|
||||
JOB_NAME=cleanup_sessions JOB_INTERVAL_SECONDS=300
|
||||
```
|
||||
|
||||
### Session Endpoints
|
||||
|
||||
- `GET /api/v1/auth/sessions` — List active sessions for the current user
|
||||
- `POST /api/v1/auth/sessions/refresh` — Extend the current session's idle window (returns new `expires_at`)
|
||||
- `DELETE /api/v1/auth/sessions/:id` — Revoke a specific session
|
||||
|
||||
|
||||
# Boostrap db
|
||||
python manage.py db upgrade
|
||||
|
||||
|
||||
+163
-7
@@ -49,10 +49,96 @@ class MyServer(BaseHTTPRequestHandler):
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(bytes("<html><head><title>OIDC Workflow Tool</title></head>", "utf-8"))
|
||||
self.wfile.write(bytes("<body><p>The token has been received</p>", "utf-8"))
|
||||
self.wfile.write(bytes("<p>You may now close this window.</p>", "utf-8"))
|
||||
self.wfile.write(bytes("</body></html>", "utf-8"))
|
||||
html_content = """<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Authentication Successful - Gatehouse</title>
|
||||
<!-- Best-effort CSS load from primary site -->
|
||||
<link rel="stylesheet" href="{SIGN_URL}/static/css/main.css">
|
||||
<style>
|
||||
* {{
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}}
|
||||
body {{
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
|
||||
background-color: #f0f4f8;
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}}
|
||||
.card {{
|
||||
background: white;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 4px 20px rgba(0, 0, 0, 0.08);
|
||||
padding: 48px 40px;
|
||||
text-align: center;
|
||||
max-width: 400px;
|
||||
width: 90%;
|
||||
}}
|
||||
.checkmark {{
|
||||
width: 64px;
|
||||
height: 64px;
|
||||
background: #10b981;
|
||||
border-radius: 50%;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
margin: 0 auto 24px;
|
||||
}}
|
||||
.checkmark svg {{
|
||||
width: 32px;
|
||||
height: 32px;
|
||||
stroke: white;
|
||||
stroke-width: 3;
|
||||
fill: none;
|
||||
}}
|
||||
h1 {{
|
||||
color: #1f2937;
|
||||
font-size: 24px;
|
||||
font-weight: 600;
|
||||
margin-bottom: 12px;
|
||||
}}
|
||||
p {{
|
||||
color: #6b7280;
|
||||
font-size: 16px;
|
||||
line-height: 1.5;
|
||||
}}
|
||||
.fallback {{
|
||||
margin-top: 24px;
|
||||
padding-top: 24px;
|
||||
border-top: 1px solid #e5e7eb;
|
||||
color: #9ca3af;
|
||||
font-size: 14px;
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="card">
|
||||
<div class="checkmark">
|
||||
<svg viewBox="0 0 24 24">
|
||||
<polyline points="20 6 9 17 4 12"></polyline>
|
||||
</svg>
|
||||
</div>
|
||||
<h1>Authentication Complete</h1>
|
||||
<p>You can now return to the terminal.</p>
|
||||
<p class="fallback">If this window doesn't close automatically, you can close it manually.</p>
|
||||
</div>
|
||||
<script>
|
||||
setTimeout(function() {{
|
||||
window.close();
|
||||
if (window.innerHeight > 0) {{
|
||||
document.querySelector('.fallback').textContent = 'Window refused to close. You may close this tab manually.';
|
||||
}}
|
||||
}}, 2000);
|
||||
</script>
|
||||
</body>
|
||||
</html>""".format(SIGN_URL=SIGN_URL)
|
||||
self.wfile.write(bytes(html_content, "utf-8"))
|
||||
|
||||
parsed_url = urlparse(self.path)
|
||||
query_data = dict(parse_qsl(parsed_url.query))
|
||||
@@ -253,7 +339,7 @@ def fetch_my_principals():
|
||||
return principal_names
|
||||
|
||||
|
||||
def request_certificate():
|
||||
def request_certificate(org_id=None):
|
||||
CERT_ID = os.getenv("CERT_ID") or get_activated_ssh_key()
|
||||
|
||||
principals = fetch_my_principals()
|
||||
@@ -272,9 +358,13 @@ def request_certificate():
|
||||
'principals': principals,
|
||||
}
|
||||
|
||||
# Add organization_id if specified
|
||||
if org_id:
|
||||
payload['organization_id'] = org_id
|
||||
|
||||
try:
|
||||
response = requests.post(f"{SIGN_URL}/api/v1/ssh/sign", json=payload, headers=headers)
|
||||
|
||||
|
||||
if response.status_code == 201:
|
||||
json_result = response.json().get('data', response.json())
|
||||
with open(CERT_FILE_PATH, 'w') as f:
|
||||
@@ -287,14 +377,41 @@ def request_certificate():
|
||||
|
||||
logger.info(f"Certificate signed successfully, located at {CERT_FILE_PATH}")
|
||||
logger.info(f"Valid for principals: {', '.join(json_result.get('principals', principals))}")
|
||||
|
||||
# Show which org issued the cert
|
||||
org_name = json_result.get('organization_name', 'Unknown')
|
||||
logger.info(f"Issued by organization: {org_name}")
|
||||
|
||||
logger.info("You can login to your destination server with the following command")
|
||||
logger.info(f"\tssh user@server -o CertificateFile={CERT_FILE_PATH}")
|
||||
|
||||
elif response.status_code == 400:
|
||||
error_data = response.json()
|
||||
if error_data.get('error', {}).get('type') == 'MULTIPLE_ORGS_AMBIGUOUS':
|
||||
logger.error("You are a member of multiple organizations. Please specify one with --org-id")
|
||||
logger.error("\nYour organizations:")
|
||||
for org in error_data.get('error', {}).get('details', {}).get('organizations', []):
|
||||
logger.error(f" - {org['name']} (ID: {org['id']}, Role: {org['role']})")
|
||||
logger.error("\nRun: secuird --list-orgs to see all your organizations")
|
||||
logger.error("Then run: secuird -r --org-id <organization_id>")
|
||||
else:
|
||||
logger.error(f"Error: {error_data.get('message', 'Unknown error')}")
|
||||
exit(1)
|
||||
|
||||
elif response.status_code == 403:
|
||||
error_data = response.json()
|
||||
logger.error(f"Permission denied: {error_data.get('message', 'Unknown error')}")
|
||||
exit(1)
|
||||
|
||||
else:
|
||||
logger.error("Error in response from server")
|
||||
logger.error(f"Status code: {response.status_code}")
|
||||
logger.error(f"Response text: {response.text}")
|
||||
exit(1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during certificate signing: {e}")
|
||||
exit(1)
|
||||
|
||||
def generate_and_sign_challenge(ssh_key_file, key_id):
|
||||
"""Fetch a challenge from the server, sign it with the SSH key, and submit the signature."""
|
||||
@@ -321,11 +438,13 @@ def generate_and_sign_challenge(ssh_key_file, key_id):
|
||||
|
||||
with open(CHALLENGE_FILE_PATH, 'w') as f:
|
||||
f.write(challenge_text)
|
||||
os.chmod(CHALLENGE_FILE_PATH, 0o600)
|
||||
|
||||
subprocess.run(
|
||||
["ssh-keygen", "-Y", "sign", "-f", ssh_key_file, "-n", "file", CHALLENGE_FILE_PATH],
|
||||
check=True,
|
||||
)
|
||||
os.chmod(CHALLENGE_SIG_FILE_PATH, 0o600)
|
||||
|
||||
with open(CHALLENGE_SIG_FILE_PATH, 'rb') as f:
|
||||
signature = base64.b64encode(f.read()).decode('utf-8')
|
||||
@@ -399,6 +518,38 @@ def remove_ssh_key(key_id=None):
|
||||
else:
|
||||
logger.error(f"Failed to remove key {k['id']}: {del_response.status_code} - {del_response.text}")
|
||||
|
||||
def list_organizations():
|
||||
"""List all organizations the user is a member of."""
|
||||
response = requests.get(
|
||||
f"{SIGN_URL}/api/v1/users/me/organizations/simple",
|
||||
headers=auth_headers()
|
||||
)
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to list organizations: {response.status_code} - {response.text}")
|
||||
exit(1)
|
||||
|
||||
data = response.json().get('data', {})
|
||||
orgs = data.get('organizations', [])
|
||||
|
||||
if not orgs:
|
||||
print("You are not a member of any organizations.")
|
||||
return
|
||||
|
||||
print("\nYour Organizations:")
|
||||
print("-" * 80)
|
||||
for org in orgs:
|
||||
ca_status = []
|
||||
if org.get('has_user_ca'):
|
||||
ca_status.append("User CA ✓")
|
||||
if org.get('has_host_ca'):
|
||||
ca_status.append("Host CA ✓")
|
||||
|
||||
ca_str = f" ({', '.join(ca_status)})" if ca_status else " (No CAs configured)"
|
||||
|
||||
print(f" ID: {org['id']}")
|
||||
print(f" Name: {org['name']}{ca_str}")
|
||||
print(f" Role: {org['role']}")
|
||||
print("-" * 80)
|
||||
|
||||
def add_ssh_key(ssh_key_file):
|
||||
"""Add an SSH key to the server and auto-verify it."""
|
||||
@@ -540,6 +691,11 @@ if __name__ == "__main__":
|
||||
remove_ssh_key(args.remove_key if args.remove_key else None)
|
||||
exit(0)
|
||||
|
||||
if args.list_orgs:
|
||||
request_token()
|
||||
list_organizations()
|
||||
exit(0)
|
||||
|
||||
if args.list_keys:
|
||||
request_token()
|
||||
response = requests.get(f"{SIGN_URL}/api/v1/ssh/keys", headers=auth_headers())
|
||||
@@ -580,5 +736,5 @@ if __name__ == "__main__":
|
||||
if args.force:
|
||||
logger.info("Forcing renewal of certificate")
|
||||
if args.force or checkCert() == 1:
|
||||
request_certificate()
|
||||
request_certificate(org_id=args.org_id)
|
||||
exit(0)
|
||||
|
||||
@@ -48,6 +48,10 @@ class BaseConfig:
|
||||
seconds=int(os.getenv("MAX_SESSION_DURATION", "86400"))
|
||||
)
|
||||
|
||||
# Session timeout policy (seconds)
|
||||
SESSION_IDLE_TIMEOUT = int(os.getenv("SESSION_IDLE_TIMEOUT", "900"))
|
||||
SESSION_ABSOLUTE_TIMEOUT = int(os.getenv("SESSION_ABSOLUTE_TIMEOUT", "28800"))
|
||||
|
||||
# CORS
|
||||
CORS_ORIGINS = os.getenv(
|
||||
"CORS_ORIGINS",
|
||||
@@ -83,6 +87,7 @@ class BaseConfig:
|
||||
RATELIMIT_AUTH_TOTP_VERIFY = os.getenv("RATELIMIT_AUTH_TOTP_VERIFY", "20 per minute; 100 per hour")
|
||||
RATELIMIT_AUTH_FORGOT_PASSWORD = os.getenv("RATELIMIT_AUTH_FORGOT_PASSWORD", "5 per minute; 20 per hour")
|
||||
RATELIMIT_AUTH_RESET_PASSWORD = os.getenv("RATELIMIT_AUTH_RESET_PASSWORD", "10 per minute; 30 per hour")
|
||||
RATELIMIT_AUTH_RESEND_VERIFICATION = os.getenv("RATELIMIT_AUTH_RESEND_VERIFICATION", "5 per minute; 20 per hour")
|
||||
|
||||
# Logging
|
||||
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
|
||||
|
||||
+7
-1
@@ -30,7 +30,13 @@ class TestingConfig(BaseConfig):
|
||||
|
||||
# Use different Redis DB for testing
|
||||
REDIS_URL = "redis://localhost:6379/15"
|
||||
|
||||
|
||||
# Use filesystem for sessions in testing
|
||||
SESSION_TYPE = "filesystem"
|
||||
SESSION_FILE_DIR = "/tmp/flask_session_test"
|
||||
|
||||
# Override cookie domain so test_client on localhost can send cookies
|
||||
SESSION_COOKIE_DOMAIN = None
|
||||
WEBAUTHN_RP_ID = "localhost"
|
||||
WEBAUTHN_ORIGIN = "http://localhost:8080"
|
||||
FRONTEND_URL = "http://localhost:8080"
|
||||
|
||||
@@ -78,6 +78,42 @@ services:
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
|
||||
zerotier-reconciler:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.job
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
- JOB_NAME=zerotier_reconciliation
|
||||
- JOB_INTERVAL_SECONDS=${ZEROTIER_RECONCILE_INTERVAL:-120}
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
networks:
|
||||
- authy2-network
|
||||
restart: unless-stopped
|
||||
|
||||
mfa-compliance:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.job
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
- JOB_NAME=mfa_compliance
|
||||
- JOB_INTERVAL_SECONDS=${MFA_COMPLIANCE_INTERVAL:-3600}
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
networks:
|
||||
- authy2-network
|
||||
restart: unless-stopped
|
||||
|
||||
networks:
|
||||
authy2-network:
|
||||
driver: bridge
|
||||
|
||||
@@ -5,7 +5,9 @@ from flask import Blueprint
|
||||
api_v1_bp = Blueprint("api_v1", __name__)
|
||||
|
||||
# Import route modules to register them
|
||||
from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth, departments, principals, ssh, zerotier, sudo, oidc
|
||||
from gatehouse_app.api.v1 import auth, users, organizations, policies, external_auth, departments, principals, ssh, zerotier, sudo, oidc, contact
|
||||
from gatehouse_app.api.v1 import superadmin
|
||||
|
||||
api_v1_bp.register_blueprint(ssh.ssh_bp)
|
||||
api_v1_bp.register_blueprint(superadmin.superadmin_bp)
|
||||
|
||||
|
||||
@@ -116,7 +116,7 @@ def login():
|
||||
|
||||
remember_me = data.get("remember_me", False)
|
||||
policy_result = MfaPolicyService.after_primary_auth_success(user, remember_me)
|
||||
duration = 2592000 if remember_me else 86400
|
||||
duration = current_app.config.get("SESSION_ABSOLUTE_TIMEOUT", 28800) if remember_me else None
|
||||
is_compliance_only = policy_result.create_compliance_only_session
|
||||
|
||||
user_session = AuthService.create_session(user, duration_seconds=duration, is_compliance_only=is_compliance_only)
|
||||
@@ -227,6 +227,32 @@ def revoke_session(session_id):
|
||||
return api_response(message="Session revoked successfully")
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/sessions/refresh", methods=["POST"])
|
||||
@login_required
|
||||
def refresh_session():
|
||||
"""Extend the current session's idle window.
|
||||
|
||||
The server already refreshes the session on every authenticated
|
||||
request, but this endpoint exists so the frontend can proactively
|
||||
keep a session alive (e.g. a heartbeat while the user is reading
|
||||
a long page with no API calls).
|
||||
|
||||
Returns the new ``expires_at`` so the frontend can display a
|
||||
countdown or warning before the absolute cap.
|
||||
"""
|
||||
session = g.current_session
|
||||
session.refresh()
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"expires_at": session.expires_at.isoformat() + "Z"
|
||||
if session.expires_at.isoformat()[-1] != "Z"
|
||||
else session.expires_at.isoformat(),
|
||||
},
|
||||
message="Session refreshed",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/token", methods=["GET"])
|
||||
@login_required
|
||||
def get_token():
|
||||
@@ -246,7 +272,8 @@ def get_token():
|
||||
parsed_redirect = urlparse(redirect_url)
|
||||
redirect_origin = f"{parsed_redirect.scheme}://{parsed_redirect.netloc}"
|
||||
|
||||
if redirect_origin not in allowed_origins:
|
||||
wildcard = "*" in allowed_origins
|
||||
if not wildcard and redirect_origin not in allowed_origins:
|
||||
return api_response(success=False, message="Redirect URL is not allowed.", status=400, error_type="INVALID_REDIRECT")
|
||||
|
||||
sep = "&" if "?" in redirect_url else "?"
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Password reset, email verification, and account activation endpoints."""
|
||||
import logging
|
||||
from flask import request, current_app
|
||||
from flask import request, current_app, g
|
||||
from gatehouse_app.api.v1 import api_v1_bp
|
||||
from gatehouse_app.extensions import limiter
|
||||
from gatehouse_app.utils.decorators import login_required
|
||||
from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
from gatehouse_app.services.notification_service import NotificationService
|
||||
@@ -151,6 +152,36 @@ def resend_verification():
|
||||
return api_response(data={}, message="If an account exists for this email and is not yet verified, you will receive a verification link shortly.")
|
||||
|
||||
|
||||
|
||||
@api_v1_bp.route("/auth/me/resend-verification", methods=["POST"])
|
||||
@limiter.limit(lambda: current_app.config["RATELIMIT_AUTH_RESEND_VERIFICATION"])
|
||||
@login_required
|
||||
def resend_verification_authenticated():
|
||||
from gatehouse_app.models import EmailVerificationToken
|
||||
|
||||
user = g.current_user
|
||||
if not user.email_verified:
|
||||
try:
|
||||
verify_token = EmailVerificationToken.generate(user_id=user.id)
|
||||
app_url = current_app.config.get("APP_URL", "http://localhost:8080")
|
||||
verify_link = f"{app_url}/verify-email?token={verify_token.token}"
|
||||
email_body = build_email_verification_html(
|
||||
user_name=user.full_name or user.email,
|
||||
verify_link=verify_link,
|
||||
expiry_hours=24,
|
||||
)
|
||||
NotificationService._send_email_async(
|
||||
to_address=user.email,
|
||||
subject="Verify your Secuird email address",
|
||||
body=f"Verify your Secuird email: {verify_link}",
|
||||
html_body=email_body,
|
||||
)
|
||||
_logger.info(f"Verification email sent for authenticated user {user.id}")
|
||||
except Exception as exc:
|
||||
_logger.exception(f"Error sending verification email: {exc}")
|
||||
|
||||
return api_response(data={}, message="If your email is not yet verified, you will receive a verification link shortly.")
|
||||
|
||||
@api_v1_bp.route("/auth/activate", methods=["POST"])
|
||||
def activate_account():
|
||||
import secrets
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
"""Contact form endpoint for website enquiries."""
|
||||
import logging
|
||||
|
||||
from flask import request, current_app
|
||||
from marshmallow import ValidationError
|
||||
|
||||
from gatehouse_app.api.v1 import api_v1_bp
|
||||
from gatehouse_app.extensions import limiter
|
||||
from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.schemas.contact_schema import ContactSchema
|
||||
from gatehouse_app.services.notification_service import NotificationService
|
||||
from gatehouse_app.services.email_templates import build_contact_enquiry_html
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Hardcoded destination for all contact submissions
|
||||
CONTACT_DESTINATION = "info@secuird.tech"
|
||||
|
||||
|
||||
@api_v1_bp.route("/contact", methods=["POST"])
|
||||
@limiter.limit("5 per hour")
|
||||
def contact():
|
||||
"""Handle contact form submissions from the marketing website.
|
||||
|
||||
Accepts: email, name, company, enquiry_type, message, interest_area, _hp.
|
||||
Sends an email to info@secuird.tech with the enquiry details.
|
||||
Silently discards submissions where the honeypot field (_hp) is filled.
|
||||
"""
|
||||
try:
|
||||
schema = ContactSchema()
|
||||
data = schema.load(request.get_json() or {})
|
||||
except ValidationError as err:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Invalid request data",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=err.messages,
|
||||
)
|
||||
|
||||
# Honeypot check — silently succeed without sending
|
||||
if data.get("_hp"):
|
||||
logger.info(f"[Contact] Honeypot triggered, ip={request.remote_addr}")
|
||||
return api_response(message="Thank you for your message!")
|
||||
|
||||
enquiry_type = data.get("enquiry_type") or "general"
|
||||
email = data.get("email") or ""
|
||||
|
||||
# Build and send email
|
||||
html_body = build_contact_enquiry_html(
|
||||
enquiry_type=enquiry_type,
|
||||
submitter_email=email,
|
||||
name=data.get("name"),
|
||||
company=data.get("company"),
|
||||
interest_area=data.get("interest_area"),
|
||||
message=data.get("message"),
|
||||
)
|
||||
|
||||
NotificationService._send_email_async(
|
||||
to_address=CONTACT_DESTINATION,
|
||||
subject=f"Secuird Website: {enquiry_type.replace('_', ' ').title()} from {email}",
|
||||
body=f"New contact enquiry ({enquiry_type}) from {email}",
|
||||
html_body=html_body,
|
||||
)
|
||||
|
||||
logger.info(f"[Contact] enquiry_type={enquiry_type} ip={request.remote_addr}")
|
||||
|
||||
return api_response(message="Thank you for your message!")
|
||||
@@ -21,7 +21,7 @@ def admin_list_app_providers():
|
||||
return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN")
|
||||
|
||||
PROVIDERS = [{"id": "google", "name": "Google"}, {"id": "github", "name": "GitHub"}, {"id": "microsoft", "name": "Microsoft"}]
|
||||
db_configs = {c.provider_type: c for c in ApplicationProviderConfig.query.all()}
|
||||
db_configs = {c.provider_type: c for c in ApplicationProviderConfig.query.filter_by(deleted_at=None).all()}
|
||||
|
||||
result = []
|
||||
for p in PROVIDERS:
|
||||
@@ -64,7 +64,7 @@ def admin_configure_app_provider(provider: str):
|
||||
if not client_id:
|
||||
return api_response(success=False, message="client_id is required", status=400, error_type="VALIDATION_ERROR")
|
||||
|
||||
cfg = ApplicationProviderConfig.query.filter_by(provider_type=provider).first()
|
||||
cfg = ApplicationProviderConfig.query.filter_by(provider_type=provider, deleted_at=None).first()
|
||||
if cfg:
|
||||
cfg.client_id = client_id
|
||||
if client_secret:
|
||||
@@ -90,7 +90,6 @@ def admin_delete_app_provider(provider: str):
|
||||
from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig
|
||||
from gatehouse_app.models import OrganizationMember
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
from gatehouse_app.extensions import db
|
||||
|
||||
admin_memberships = OrganizationMember.query.filter(
|
||||
OrganizationMember.user_id == g.current_user.id,
|
||||
@@ -100,10 +99,9 @@ def admin_delete_app_provider(provider: str):
|
||||
if not admin_memberships:
|
||||
return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN")
|
||||
|
||||
cfg = ApplicationProviderConfig.query.filter_by(provider_type=provider).first()
|
||||
cfg = ApplicationProviderConfig.query.filter_by(provider_type=provider, deleted_at=None).first()
|
||||
if not cfg:
|
||||
return api_response(success=False, message=f"Provider '{provider}' is not configured", status=404, error_type="NOT_FOUND")
|
||||
|
||||
db.session.delete(cfg)
|
||||
db.session.commit()
|
||||
cfg.delete()
|
||||
return api_response(message=f"{provider.capitalize()} OAuth provider configuration removed")
|
||||
|
||||
@@ -174,6 +174,7 @@ def select_organization():
|
||||
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
method_type=state_record.provider_type,
|
||||
deleted_at=None,
|
||||
).order_by(AuthenticationMethod.created_at.desc()).first()
|
||||
|
||||
if not auth_method:
|
||||
@@ -181,16 +182,16 @@ def select_organization():
|
||||
|
||||
user = auth_method.user
|
||||
|
||||
org = Organization.query.get(organization_id)
|
||||
org = Organization.query.filter_by(id=organization_id, deleted_at=None).first()
|
||||
if not org:
|
||||
return api_response(success=False, message="Organization not found", status=404, error_type="NOT_FOUND")
|
||||
|
||||
member = OrganizationMember.query.filter_by(user_id=user.id, organization_id=organization_id).first()
|
||||
member = OrganizationMember.query.filter_by(user_id=user.id, organization_id=organization_id, deleted_at=None).first()
|
||||
if not member:
|
||||
return api_response(success=False, message="You are not a member of this organization", status=403, error_type="FORBIDDEN")
|
||||
|
||||
from gatehouse_app.services.session_service import SessionService
|
||||
session = SessionService.create_session(user=user, organization_id=organization_id)
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
session = AuthService.create_session(user=user)
|
||||
state_record.mark_used()
|
||||
|
||||
provider_type_val = state_record.provider_type.value if isinstance(state_record.provider_type, _AuthMethodType) else state_record.provider_type
|
||||
|
||||
@@ -14,13 +14,13 @@ from gatehouse_app.api.v1.external_auth._helpers import get_provider_type, _get_
|
||||
def list_providers():
|
||||
from gatehouse_app.models.auth.authentication_method import ApplicationProviderConfig
|
||||
|
||||
app_configs = {c.provider_type.lower(): c for c in ApplicationProviderConfig.query.filter_by(is_enabled=True).all()}
|
||||
app_configs = {c.provider_type.lower(): c for c in ApplicationProviderConfig.query.filter_by(is_enabled=True, deleted_at=None).all()}
|
||||
|
||||
user_orgs = g.current_user.get_organizations()
|
||||
org_configs = {}
|
||||
if user_orgs:
|
||||
organization_id = user_orgs[0].id
|
||||
org_level = ExternalProviderConfig.query.filter_by(organization_id=organization_id).all()
|
||||
org_level = ExternalProviderConfig.query.filter_by(organization_id=organization_id, deleted_at=None).all()
|
||||
org_configs = {c.provider_type.lower(): c for c in org_level}
|
||||
|
||||
def provider_info(provider_id, name):
|
||||
@@ -50,11 +50,11 @@ def get_provider_config(provider: str):
|
||||
return api_response(success=False, message="No organizations found for user", status=400, error_type="BAD_REQUEST")
|
||||
|
||||
organization_id = user_orgs[0].id
|
||||
member = OrganizationMember.query.filter_by(user_id=g.current_user.id, organization_id=organization_id).first()
|
||||
member = OrganizationMember.query.filter_by(user_id=g.current_user.id, organization_id=organization_id, deleted_at=None).first()
|
||||
if not member or member.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]:
|
||||
return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN")
|
||||
|
||||
config = ExternalProviderConfig.query.filter_by(organization_id=organization_id, provider_type=provider_type.value).first()
|
||||
config = ExternalProviderConfig.query.filter_by(organization_id=organization_id, provider_type=provider_type.value, deleted_at=None).first()
|
||||
if not config:
|
||||
return api_response(success=False, message=f"{provider.title()} OAuth is not configured", status=404, error_type="NOT_FOUND")
|
||||
|
||||
@@ -74,7 +74,7 @@ def create_or_update_provider_config(provider: str):
|
||||
return api_response(success=False, message="No organizations found for user", status=400, error_type="BAD_REQUEST")
|
||||
|
||||
organization_id = user_orgs[0].id
|
||||
member = OrganizationMember.query.filter_by(user_id=g.current_user.id, organization_id=organization_id).first()
|
||||
member = OrganizationMember.query.filter_by(user_id=g.current_user.id, organization_id=organization_id, deleted_at=None).first()
|
||||
if not member or member.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]:
|
||||
return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN")
|
||||
|
||||
@@ -85,7 +85,7 @@ def create_or_update_provider_config(provider: str):
|
||||
if not client_id:
|
||||
return api_response(success=False, message="client_id is required", status=400, error_type="VALIDATION_ERROR")
|
||||
|
||||
config = ExternalProviderConfig.query.filter_by(organization_id=organization_id, provider_type=provider_type.value).first()
|
||||
config = ExternalProviderConfig.query.filter_by(organization_id=organization_id, provider_type=provider_type.value, deleted_at=None).first()
|
||||
is_new = config is None
|
||||
|
||||
if config:
|
||||
@@ -137,11 +137,11 @@ def delete_provider_config(provider: str):
|
||||
return api_response(success=False, message="No organizations found for user", status=400, error_type="BAD_REQUEST")
|
||||
|
||||
organization_id = user_orgs[0].id
|
||||
member = OrganizationMember.query.filter_by(user_id=g.current_user.id, organization_id=organization_id).first()
|
||||
member = OrganizationMember.query.filter_by(user_id=g.current_user.id, organization_id=organization_id, deleted_at=None).first()
|
||||
if not member or member.role not in [OrganizationRole.OWNER, OrganizationRole.ADMIN]:
|
||||
return api_response(success=False, message="Admin access required", status=403, error_type="FORBIDDEN")
|
||||
|
||||
config = ExternalProviderConfig.query.filter_by(organization_id=organization_id, provider_type=provider_type.value).first()
|
||||
config = ExternalProviderConfig.query.filter_by(organization_id=organization_id, provider_type=provider_type.value, deleted_at=None).first()
|
||||
if not config:
|
||||
return api_response(success=False, message=f"{provider.title()} OAuth is not configured", status=404, error_type="NOT_FOUND")
|
||||
|
||||
|
||||
@@ -819,9 +819,9 @@ def oidc_register():
|
||||
|
||||
org_id = data.get("organization_id")
|
||||
if org_id:
|
||||
organization = Organization.query.get(org_id)
|
||||
organization = Organization.query.filter_by(id=org_id, deleted_at=None).first()
|
||||
else:
|
||||
organization = Organization.query.filter_by(is_active=True).first()
|
||||
organization = Organization.query.filter_by(is_active=True, deleted_at=None).first()
|
||||
|
||||
if not organization:
|
||||
organization = Organization(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Organization core CRUD endpoints."""
|
||||
import logging
|
||||
from flask import g, request
|
||||
from marshmallow import ValidationError
|
||||
from gatehouse_app.api.v1 import api_v1_bp
|
||||
@@ -12,6 +13,15 @@ from gatehouse_app.services.organization_service import OrganizationService
|
||||
@login_required
|
||||
@full_access_required
|
||||
def create_organization():
|
||||
try:
|
||||
org_count = OrganizationService.get_user_org_count(g.current_user.id)
|
||||
if org_count is not None and org_count >= 10:
|
||||
return api_response(success=False, message="You cannot belong to more than 10 organizations", status=400, error_type="ORG_LIMIT_REACHED")
|
||||
except Exception as e:
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(f"[Org] Failed to check org count for user {g.current_user.id}: {e}")
|
||||
# Fail open to avoid blocking legitimate users when the count service is unavailable
|
||||
|
||||
try:
|
||||
schema = OrganizationCreateSchema()
|
||||
data = schema.load(request.json)
|
||||
|
||||
@@ -8,7 +8,8 @@ from gatehouse_app.services.notification_service import NotificationService
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
from gatehouse_app.services.organization_service import OrganizationService
|
||||
from gatehouse_app.services.email_templates import build_org_invite_html
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
from gatehouse_app.utils.constants import AuditAction, OrganizationRole
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
|
||||
|
||||
@api_v1_bp.route("/organizations/<org_id>/invites", methods=["POST"])
|
||||
@@ -56,6 +57,19 @@ def create_org_invite(org_id):
|
||||
logging.getLogger(__name__).info(f"[INVITE] Email queued for {email}")
|
||||
email_sent = True # async — assume queued successfully
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.ORG_INVITE_SENT,
|
||||
user_id=g.current_user.id,
|
||||
organization_id=org_id,
|
||||
resource_type="org_invite",
|
||||
resource_id=invite.id,
|
||||
metadata={
|
||||
"invited_email": email,
|
||||
"role": role,
|
||||
},
|
||||
description=f"Invitation sent to {email} with role {role}",
|
||||
)
|
||||
|
||||
response_data = {
|
||||
"invite": {
|
||||
"id": invite.id,
|
||||
@@ -173,7 +187,7 @@ def accept_invite(token):
|
||||
if session_user.email.lower() != invite.email.lower():
|
||||
return api_response(
|
||||
success=False,
|
||||
message="This invite was sent to a different email address.",
|
||||
message="You are already logged in and this invite was sent to a different email address.",
|
||||
status=403,
|
||||
error_type="EMAIL_MISMATCH",
|
||||
)
|
||||
|
||||
@@ -56,7 +56,10 @@ def add_organization_member(org_id):
|
||||
@full_access_required
|
||||
def remove_organization_member(org_id, user_id):
|
||||
org = OrganizationService.get_organization_by_id(org_id)
|
||||
OrganizationService.remove_member(org=org, user_id=user_id, remover_id=g.current_user.id)
|
||||
try:
|
||||
OrganizationService.remove_member(org=org, user_id=user_id, remover_id=g.current_user.id)
|
||||
except ValueError as e:
|
||||
return api_response(success=False, message=str(e), status=403, error_type="OWNER_PROTECTION")
|
||||
return api_response(message="Member removed successfully")
|
||||
|
||||
|
||||
@@ -155,7 +158,7 @@ def send_mfa_reminder(org_id, user_id):
|
||||
if not user:
|
||||
return api_response(success=False, message="User not found", status=404)
|
||||
|
||||
compliance = MfaPolicyCompliance.query.filter_by(user_id=user_id, organization_id=org_id).first()
|
||||
compliance = MfaPolicyCompliance.query.filter_by(user_id=user_id, organization_id=org_id, deleted_at=None).first()
|
||||
policy = OrganizationSecurityPolicy.query.filter_by(organization_id=org_id).first()
|
||||
|
||||
if compliance and policy and compliance.deadline_at:
|
||||
|
||||
@@ -11,16 +11,23 @@ ssh_ca_service = SSHCASigningService()
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_org_ca_for_user(user, ca_type: str = "user"):
|
||||
def _get_org_ca_for_user(user, ca_type: str = "user", organization_id=None):
|
||||
try:
|
||||
from gatehouse_app.models.ssh_ca.ca import CA, CaType
|
||||
org_ids = [m.organization_id for m in user.organization_memberships]
|
||||
|
||||
if organization_id:
|
||||
org_ids = [organization_id]
|
||||
else:
|
||||
org_ids = [m.organization_id for m in user.get_active_memberships()]
|
||||
|
||||
if not org_ids:
|
||||
return None
|
||||
|
||||
return CA.query.filter(
|
||||
CA.organization_id.in_(org_ids),
|
||||
CA.ca_type == CaType(ca_type),
|
||||
CA.is_active == True, # noqa: E712
|
||||
CA.is_active == True,
|
||||
CA.deleted_at.is_(None),
|
||||
).first()
|
||||
except Exception:
|
||||
return None
|
||||
@@ -34,7 +41,7 @@ def _get_or_create_system_ca():
|
||||
import os
|
||||
|
||||
try:
|
||||
existing = CA.query.filter_by(name="system-config-ca").first()
|
||||
existing = CA.query.filter_by(name="system-config-ca", deleted_at=None).first()
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
@@ -60,7 +67,7 @@ def _get_or_create_system_ca():
|
||||
|
||||
fingerprint = compute_ssh_fingerprint(pub_key)
|
||||
|
||||
existing_by_fp = CA.query.filter_by(fingerprint=fingerprint).first()
|
||||
existing_by_fp = CA.query.filter_by(fingerprint=fingerprint, deleted_at=None).first()
|
||||
if existing_by_fp:
|
||||
return existing_by_fp
|
||||
|
||||
|
||||
@@ -14,6 +14,12 @@ from gatehouse_app.utils.decorators import login_required
|
||||
from gatehouse_app.utils.response import api_response
|
||||
|
||||
|
||||
def _validate_uuid(uuid_str: str) -> bool:
|
||||
"""Validate UUID format."""
|
||||
import re
|
||||
return bool(re.match(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', uuid_str, re.I))
|
||||
|
||||
|
||||
@ssh_bp.route('/dept-cert-policy', methods=['GET'])
|
||||
@login_required
|
||||
def get_my_dept_cert_policy():
|
||||
@@ -60,6 +66,7 @@ def sign_certificate():
|
||||
cert_type = data.get('cert_type', 'user')
|
||||
key_id = data.get('key_id') or data.get('cert_id')
|
||||
expiry_hours = data.get('expiry_hours')
|
||||
requested_org_id = data.get('organization_id')
|
||||
|
||||
AuditLog.log(
|
||||
action=AuditAction.SSH_CERT_REQUESTED,
|
||||
@@ -67,22 +74,63 @@ def sign_certificate():
|
||||
description=(f'{user.email} requested a certificate' + (f' for principals: {", ".join(requested_principals)}' if requested_principals else '')),
|
||||
)
|
||||
|
||||
# Validate organization_id if provided
|
||||
if requested_org_id and not _validate_uuid(requested_org_id):
|
||||
return api_response(success=False, message="Invalid organization_id format. Must be a valid UUID.", status=400, error_type="INVALID_ORG_ID")
|
||||
|
||||
# Get user's active organization memberships
|
||||
memberships = OrganizationMember.query.filter_by(user_id=user_id, deleted_at=None).all()
|
||||
active_memberships = [om for om in memberships if om.organization and om.organization.deleted_at is None]
|
||||
|
||||
if not active_memberships:
|
||||
return api_response(success=False, message="You are not a member of any active organizations.", status=400, error_type="NO_ORG_MEMBERSHIPS")
|
||||
|
||||
# Select target organization
|
||||
target_org = None
|
||||
if requested_org_id:
|
||||
# Check if user is member of the requested organization
|
||||
target_membership = next((om for om in active_memberships if str(om.organization_id).lower() == requested_org_id.lower()), None)
|
||||
if not target_membership:
|
||||
return api_response(success=False, message="You are not a member of the specified organization.", status=403, error_type="NOT_ORG_MEMBER")
|
||||
|
||||
target_org = target_membership.organization
|
||||
if not target_org or target_org.deleted_at is not None:
|
||||
return api_response(success=False, message="The specified organization was not found or has been deleted.", status=404, error_type="ORG_NOT_FOUND")
|
||||
else:
|
||||
# No organization specified - use default logic for backward compatibility
|
||||
if len(active_memberships) > 1:
|
||||
org_names = [om.organization.name for om in active_memberships]
|
||||
orgs_data = [
|
||||
{
|
||||
"id": m.organization_id,
|
||||
"name": m.organization.name,
|
||||
"role": m.role.value if hasattr(m.role, "value") else str(m.role)
|
||||
}
|
||||
for m in active_memberships
|
||||
]
|
||||
return api_response(
|
||||
success=False,
|
||||
message="You are a member of multiple organizations. Please specify organization_id.",
|
||||
status=400,
|
||||
error_type="MULTIPLE_ORGS_AMBIGUOUS",
|
||||
error_details={"organizations": orgs_data}
|
||||
)
|
||||
target_org = active_memberships[0].organization
|
||||
|
||||
# Get allowed principals for the selected organization
|
||||
allowed_principal_names = set()
|
||||
memberships = OrganizationMember.query.filter_by(user_id=user_id).all()
|
||||
for om in memberships:
|
||||
org = om.organization
|
||||
if not org or org.deleted_at is not None:
|
||||
continue
|
||||
role = om.role
|
||||
target_membership = next((om for om in active_memberships if str(om.organization_id).lower() == str(target_org.id).lower()), None)
|
||||
if target_membership:
|
||||
role = target_membership.role
|
||||
if role in (OrganizationRole.ADMIN, OrganizationRole.OWNER):
|
||||
for p in Principal.query.filter_by(organization_id=org.id, deleted_at=None).all():
|
||||
for p in Principal.query.filter_by(organization_id=target_org.id, deleted_at=None).all():
|
||||
allowed_principal_names.add(p.name)
|
||||
else:
|
||||
for pm in PrincipalMembership.query.filter_by(user_id=user_id, deleted_at=None).all():
|
||||
if pm.principal and pm.principal.organization_id == org.id and pm.principal.deleted_at is None:
|
||||
if pm.principal and pm.principal.organization_id == target_org.id and pm.principal.deleted_at is None:
|
||||
allowed_principal_names.add(pm.principal.name)
|
||||
for dm in DepartmentMembership.query.filter_by(user_id=user_id, deleted_at=None).all():
|
||||
if dm.department and dm.department.organization_id == org.id and dm.department.deleted_at is None:
|
||||
if dm.department and dm.department.organization_id == target_org.id and dm.department.deleted_at is None:
|
||||
for dp in DepartmentPrincipal.query.filter_by(department_id=dm.department_id, deleted_at=None).all():
|
||||
if dp.principal and dp.principal.deleted_at is None:
|
||||
allowed_principal_names.add(dp.principal.name)
|
||||
@@ -114,7 +162,8 @@ def sign_certificate():
|
||||
if not ssh_key.verified:
|
||||
return api_response(success=False, message="SSH key is not verified. Verify it before requesting a certificate.", status=400, error_type="KEY_NOT_VERIFIED")
|
||||
|
||||
db_ca = _get_org_ca_for_user(user, ca_type=cert_type)
|
||||
# Use the selected organization's ID for CA selection
|
||||
db_ca = _get_org_ca_for_user(user, ca_type=cert_type, organization_id=target_org.id)
|
||||
if db_ca is None:
|
||||
return api_response(
|
||||
success=False,
|
||||
@@ -122,11 +171,7 @@ def sign_certificate():
|
||||
status=503, error_type="CA_NOT_CONFIGURED",
|
||||
)
|
||||
|
||||
is_org_admin = any(
|
||||
om.role in (OrganizationRole.ADMIN, OrganizationRole.OWNER)
|
||||
for om in memberships
|
||||
if om.organization and om.organization.deleted_at is None
|
||||
)
|
||||
is_org_admin = target_membership.role in (OrganizationRole.ADMIN, OrganizationRole.OWNER) if target_membership else False
|
||||
|
||||
dept_policy = _get_merged_dept_cert_policy(user_id)
|
||||
if dept_policy:
|
||||
@@ -146,11 +191,7 @@ def sign_certificate():
|
||||
else:
|
||||
policy_extensions = None
|
||||
|
||||
org_slugs = sorted({
|
||||
om.organization.slug for om in memberships
|
||||
if om.organization and om.organization.deleted_at is None and getattr(om.organization, 'slug', None)
|
||||
})
|
||||
org_slug = org_slugs[0] if org_slugs else "unknown"
|
||||
org_slug = getattr(target_org, 'slug', 'unknown')
|
||||
full_name = getattr(user, 'full_name', None) or getattr(user, 'name', None) or "unknown"
|
||||
cert_identity = f"{user.email} ({full_name}) [org:{org_slug}]"
|
||||
|
||||
@@ -185,12 +226,13 @@ def sign_certificate():
|
||||
resource_type='SSHCertificate', resource_id=cert_record.id if cert_record else key_id,
|
||||
ip_address=request.remote_addr,
|
||||
description=f'Certificate serial={response.serial} issued for {user.email}; principals: {", ".join(principals)}',
|
||||
extra_data={'serial': response.serial, 'key_id': cert_identity, 'principals': principals, 'ca_id': str(db_ca.id), 'ssh_key_id': str(key_id)},
|
||||
extra_data={'serial': response.serial, 'key_id': cert_identity, 'principals': principals, 'ca_id': str(db_ca.id), 'ssh_key_id': str(key_id), 'organization_id': str(target_org.id), 'organization_name': target_org.name},
|
||||
)
|
||||
|
||||
if cert_record:
|
||||
CertificateAuditLog.log(
|
||||
certificate_id=cert_record.id, action='issued', user_id=user_id,
|
||||
organization_id=str(target_org.id),
|
||||
ip_address=request.remote_addr, user_agent=request.headers.get('User-Agent'),
|
||||
message=f'Certificate serial={response.serial} issued for {user.email}; principals: {", ".join(principals)}',
|
||||
extra_data={
|
||||
@@ -198,6 +240,7 @@ def sign_certificate():
|
||||
'ca_id': str(db_ca.id), 'ssh_key_id': str(key_id),
|
||||
'valid_after': response.valid_after.isoformat() if response.valid_after else None,
|
||||
'valid_before': response.valid_before.isoformat() if response.valid_before else None,
|
||||
'organization_id': str(target_org.id),
|
||||
},
|
||||
success=True,
|
||||
)
|
||||
@@ -207,6 +250,8 @@ def sign_certificate():
|
||||
'principals': response.principals,
|
||||
'valid_after': response.valid_after.isoformat() if response.valid_after else None,
|
||||
'valid_before': response.valid_before.isoformat() if response.valid_before else None,
|
||||
'organization_id': str(target_org.id),
|
||||
'organization_name': target_org.name,
|
||||
}
|
||||
if cert_record:
|
||||
result['cert_id'] = str(cert_record.id)
|
||||
@@ -371,7 +416,13 @@ def revoke_certificate(cert_id):
|
||||
|
||||
cert.revoke(reason=reason)
|
||||
AuditLog.log(action=AuditAction.SSH_CERT_REVOKED, user_id=user_id, resource_type='SSHCertificate', resource_id=cert_id, ip_address=request.remote_addr, description=f'Revoked: {reason}')
|
||||
CertificateAuditLog.log(certificate_id=cert_id, action='revoked', user_id=user_id, ip_address=request.remote_addr, user_agent=request.headers.get('User-Agent'), message=f'Certificate revoked: {reason}', success=True)
|
||||
|
||||
# Get organization from certificate's CA for audit logging
|
||||
from gatehouse_app.models.ssh_ca.ca import CA
|
||||
ca = CA.query.get(cert.ca_id)
|
||||
org_id = ca.organization_id if ca else None
|
||||
|
||||
CertificateAuditLog.log(certificate_id=cert_id, action='revoked', user_id=user_id, organization_id=org_id, ip_address=request.remote_addr, user_agent=request.headers.get('User-Agent'), message=f'Certificate revoked: {reason}', success=True)
|
||||
|
||||
return api_response(success=True, message='Certificate revoked successfully', data={'status': 'revoked', 'cert_id': cert_id, 'reason': reason}, status=200)
|
||||
except Exception as e:
|
||||
|
||||
@@ -32,11 +32,12 @@ def add_ssh_key():
|
||||
return api_response(success=False, message='public_key is required', status=400, error_type='BAD_REQUEST')
|
||||
|
||||
try:
|
||||
ssh_key = ssh_key_service.add_ssh_key(user_id=user_id, public_key=public_key, description=description)
|
||||
AuditLog.log(action=AuditAction.SSH_KEY_ADDED, user_id=user_id, resource_type='SSHKey', resource_id=ssh_key.id, ip_address=request.remote_addr)
|
||||
return api_response(success=True, message='SSH key added', data=ssh_key.to_dict(), status=201)
|
||||
except SSHKeyAlreadyExistsError as e:
|
||||
return api_response(success=False, message=e.message, status=409, error_type='SSH_KEY_ALREADY_EXISTS')
|
||||
ssh_key, is_new = ssh_key_service.add_ssh_key(user_id=user_id, public_key=public_key, description=description)
|
||||
if is_new:
|
||||
AuditLog.log(action=AuditAction.SSH_KEY_ADDED, user_id=user_id, resource_type='SSHKey', resource_id=ssh_key.id, ip_address=request.remote_addr)
|
||||
return api_response(success=True, message='SSH key added', data=ssh_key.to_dict(), status=201)
|
||||
else:
|
||||
return api_response(success=True, message='SSH key already exists', data=ssh_key.to_dict(), status=200)
|
||||
except IntegrityError:
|
||||
return api_response(success=False, message='SSH key already exists', status=409, error_type='SSH_KEY_ALREADY_EXISTS')
|
||||
except SSHKeyError as e:
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
"""Superadmin API blueprint."""
|
||||
import logging
|
||||
from flask import Blueprint
|
||||
|
||||
from gatehouse_app.extensions import limiter
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create superadmin blueprint
|
||||
superadmin_bp = Blueprint("superadmin", __name__, url_prefix="/superadmin")
|
||||
|
||||
# Import route modules to register them
|
||||
from gatehouse_app.api.v1.superadmin import auth, organizations, organization_members, usage_analytics, users, billing, cas # noqa: F401
|
||||
@@ -0,0 +1,286 @@
|
||||
"""Superadmin authentication endpoints."""
|
||||
import logging
|
||||
from flask import request, g, current_app
|
||||
from marshmallow import ValidationError
|
||||
|
||||
from gatehouse_app.api.v1.superadmin import superadmin_bp
|
||||
from gatehouse_app.extensions import limiter
|
||||
from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.services.superadmin_auth_service import SuperadminAuthService
|
||||
from gatehouse_app.decorators.superadmin import superadmin_required, superadmin_audit_log
|
||||
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoginSchema:
|
||||
"""Schema for superadmin login."""
|
||||
|
||||
@staticmethod
|
||||
def load(data):
|
||||
"""Validate login data."""
|
||||
errors = {}
|
||||
|
||||
if not data.get('email'):
|
||||
errors['email'] = ['Email is required']
|
||||
elif '@' not in data['email']:
|
||||
errors['email'] = ['Invalid email format']
|
||||
|
||||
if not data.get('password'):
|
||||
errors['password'] = ['Password is required']
|
||||
|
||||
if errors:
|
||||
raise ValidationError(errors)
|
||||
|
||||
return {
|
||||
'email': data['email'].lower().strip(),
|
||||
'password': data['password'],
|
||||
}
|
||||
|
||||
|
||||
@superadmin_bp.route("/auth/login", methods=["POST"])
|
||||
@limiter.limit(lambda: current_app.config.get("RATELIMIT_AUTH_LOGIN", "100 per minute"))
|
||||
def login():
|
||||
"""Superadmin login endpoint.
|
||||
|
||||
Authenticates with email/password and returns a session token.
|
||||
"""
|
||||
try:
|
||||
schema = LoginSchema()
|
||||
data = schema.load(request.json)
|
||||
|
||||
# Authenticate
|
||||
superadmin = SuperadminAuthService.authenticate(
|
||||
email=data['email'],
|
||||
credentials=data['password']
|
||||
)
|
||||
|
||||
# Create session (default 8 hours)
|
||||
session = SuperadminAuthService.create_session(
|
||||
superadmin_id=superadmin.id,
|
||||
duration_seconds=28800 # 8 hours
|
||||
)
|
||||
|
||||
expires_str = session.expires_at.isoformat()
|
||||
if not expires_str.endswith('Z'):
|
||||
expires_str += 'Z'
|
||||
|
||||
logger.info(f"[SuperadminAuth] Login successful for: {superadmin.email}")
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"superadmin": superadmin.to_dict(),
|
||||
"token": session.token,
|
||||
"expires_at": expires_str,
|
||||
},
|
||||
message="Login successful",
|
||||
status=200
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages
|
||||
)
|
||||
except InvalidCredentialsError:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Invalid email or password",
|
||||
status=401,
|
||||
error_type="INVALID_CREDENTIALS"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminAuth] Login error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred during login",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR"
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/auth/logout", methods=["POST"])
|
||||
@superadmin_required
|
||||
def logout():
|
||||
"""Superadmin logout endpoint.
|
||||
|
||||
Invalidates the current session.
|
||||
"""
|
||||
try:
|
||||
session = g.superadmin_session
|
||||
if session:
|
||||
SuperadminAuthService.revoke_session(session.id, reason="Superadmin logout")
|
||||
|
||||
return api_response(
|
||||
message="Logout successful"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminAuth] Logout error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred during logout",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR"
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/auth/me", methods=["GET"])
|
||||
@superadmin_required
|
||||
def get_current_superadmin():
|
||||
"""Get current superadmin profile.
|
||||
|
||||
Returns the profile of the currently authenticated superadmin.
|
||||
"""
|
||||
try:
|
||||
superadmin = g.current_superadmin
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"superadmin": superadmin.to_dict(),
|
||||
},
|
||||
message="Superadmin retrieved successfully"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminAuth] Get me error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR"
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/auth/impersonate/<user_id>", methods=["POST"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="impersonate", resource_type="user")
|
||||
def impersonate_user(user_id):
|
||||
"""Create emergency access session by impersonating a user.
|
||||
|
||||
Creates a temporary session for the target user that allows
|
||||
the superadmin to access the platform as that user.
|
||||
|
||||
This action is fully audited.
|
||||
"""
|
||||
try:
|
||||
superadmin = g.current_superadmin
|
||||
data = request.json or {}
|
||||
reason = data.get('reason', 'Not specified')
|
||||
duration_minutes = data.get('duration_minutes', 15)
|
||||
|
||||
# Limit duration to max 60 minutes
|
||||
duration_minutes = min(duration_minutes, 60)
|
||||
|
||||
# Create emergency access
|
||||
result = SuperadminAuthService.create_emergency_access(
|
||||
superadmin_id=superadmin.id,
|
||||
target_user_id=user_id,
|
||||
reason=reason,
|
||||
duration_minutes=duration_minutes
|
||||
)
|
||||
|
||||
expires_str = result['expires_at'].isoformat()
|
||||
if not expires_str.endswith('Z'):
|
||||
expires_str += 'Z'
|
||||
|
||||
logger.warning(
|
||||
f"[SuperadminAuth] IMPERSONATION: superadmin={superadmin.email} "
|
||||
f"impersonated user_id={user_id} reason={reason}"
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"session_token": result['session'].token,
|
||||
"expires_at": expires_str,
|
||||
"target_user_id": user_id,
|
||||
"reason": reason,
|
||||
"duration_minutes": duration_minutes,
|
||||
},
|
||||
message="Emergency access session created",
|
||||
status=201
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=str(e),
|
||||
status=404,
|
||||
error_type="NOT_FOUND"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminAuth] Impersonate error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR"
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/auth/emergency/<user_id>", methods=["POST"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="emergency_access", resource_type="user")
|
||||
def grant_emergency_access(user_id):
|
||||
"""Grant temporary elevated access to a user.
|
||||
|
||||
Similar to impersonate but grants elevated permissions
|
||||
rather than creating a session as the user.
|
||||
|
||||
This action is fully audited.
|
||||
"""
|
||||
try:
|
||||
superadmin = g.current_superadmin
|
||||
data = request.json or {}
|
||||
reason = data.get('reason', 'Not specified')
|
||||
duration_minutes = data.get('duration_minutes', 15)
|
||||
|
||||
# Limit duration to max 60 minutes
|
||||
duration_minutes = min(duration_minutes, 60)
|
||||
|
||||
# Create emergency access
|
||||
result = SuperadminAuthService.create_emergency_access(
|
||||
superadmin_id=superadmin.id,
|
||||
target_user_id=user_id,
|
||||
reason=reason,
|
||||
duration_minutes=duration_minutes
|
||||
)
|
||||
|
||||
expires_str = result['expires_at'].isoformat()
|
||||
if not expires_str.endswith('Z'):
|
||||
expires_str += 'Z'
|
||||
|
||||
logger.warning(
|
||||
f"[SuperadminAuth] EMERGENCY ACCESS: superadmin={superadmin.email} "
|
||||
f"granted access to user_id={user_id} reason={reason}"
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"session_token": result['session'].token,
|
||||
"expires_at": expires_str,
|
||||
"target_user_id": user_id,
|
||||
"reason": reason,
|
||||
"duration_minutes": duration_minutes,
|
||||
},
|
||||
message="Emergency access granted",
|
||||
status=201
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=str(e),
|
||||
status=404,
|
||||
error_type="NOT_FOUND"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminAuth] Emergency access error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR"
|
||||
)
|
||||
@@ -0,0 +1,568 @@
|
||||
"""Superadmin billing endpoints for plans and subscriptions."""
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from flask import request
|
||||
from gatehouse_app.api.v1.superadmin import superadmin_bp
|
||||
from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.decorators.superadmin import superadmin_required, superadmin_audit_log
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
from gatehouse_app.extensions import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============ Plans Endpoints ============
|
||||
|
||||
@superadmin_bp.route("/billing/plans", methods=["GET"])
|
||||
@superadmin_required
|
||||
def list_plans():
|
||||
"""Get all available plans."""
|
||||
try:
|
||||
from gatehouse_app.models.billing.plan import Plan
|
||||
|
||||
plans = Plan.query.filter(Plan.is_active == True).order_by(Plan.price_monthly.asc()).all()
|
||||
|
||||
items = [{
|
||||
"id": p.id,
|
||||
"name": p.name,
|
||||
"slug": p.slug,
|
||||
"description": p.description,
|
||||
"price_monthly": p.price_monthly,
|
||||
"price_yearly": p.price_yearly,
|
||||
"included_users": p.included_users,
|
||||
"overage_rate_per_user": p.overage_rate_per_user,
|
||||
"features": p.features,
|
||||
"stripe_price_id_monthly": p.stripe_price_id_monthly,
|
||||
"stripe_price_id_yearly": p.stripe_price_id_yearly,
|
||||
"is_active": p.is_active,
|
||||
"created_at": p.created_at.isoformat() + "Z" if p.created_at else None,
|
||||
} for p in plans]
|
||||
|
||||
return api_response(data={"items": items}, message="Plans retrieved successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Billing] List plans error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/billing/plans", methods=["POST"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="plan.create", resource_type="plan")
|
||||
def create_plan():
|
||||
"""Create a new plan."""
|
||||
try:
|
||||
from gatehouse_app.models.billing.plan import Plan
|
||||
from marshmallow import Schema, fields, validate
|
||||
|
||||
class CreatePlanSchema(Schema):
|
||||
name = fields.Str(required=True, validate=validate.Length(min=1, max=100))
|
||||
slug = fields.Str(required=True, validate=validate.Length(min=1, max=50))
|
||||
description = fields.Str(allow_none=True)
|
||||
price_monthly = fields.Int(required=True, validate=validate.Range(min=0))
|
||||
price_yearly = fields.Int(required=True, validate=validate.Range(min=0))
|
||||
included_users = fields.Int(required=True, validate=validate.Range(min=0))
|
||||
overage_rate_per_user = fields.Int(required=True, validate=validate.Range(min=0))
|
||||
features = fields.Dict(allow_none=True)
|
||||
stripe_price_id_monthly = fields.Str(allow_none=True)
|
||||
stripe_price_id_yearly = fields.Str(allow_none=True)
|
||||
|
||||
data = request.json or {}
|
||||
schema = CreatePlanSchema()
|
||||
errors = schema.validate(data)
|
||||
|
||||
if errors:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation error",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=errors,
|
||||
)
|
||||
|
||||
# Check if slug already exists
|
||||
existing = Plan.query.filter_by(slug=data["slug"]).first()
|
||||
if existing:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Plan with this slug already exists",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
)
|
||||
|
||||
plan = Plan(
|
||||
name=data["name"],
|
||||
slug=data["slug"],
|
||||
description=data.get("description"),
|
||||
price_monthly=data["price_monthly"],
|
||||
price_yearly=data["price_yearly"],
|
||||
included_users=data["included_users"],
|
||||
overage_rate_per_user=data["overage_rate_per_user"],
|
||||
features=data.get("features"),
|
||||
stripe_price_id_monthly=data.get("stripe_price_id_monthly"),
|
||||
stripe_price_id_yearly=data.get("stripe_price_id_yearly"),
|
||||
)
|
||||
db.session.add(plan)
|
||||
db.session.commit()
|
||||
|
||||
return api_response(data={
|
||||
"id": plan.id,
|
||||
"name": plan.name,
|
||||
"slug": plan.slug,
|
||||
}, message="Plan created successfully", status=201)
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
logger.error(f"[Billing] Create plan error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/billing/plans/<plan_id>", methods=["GET"])
|
||||
@superadmin_required
|
||||
def get_plan(plan_id):
|
||||
"""Get a single plan by ID."""
|
||||
try:
|
||||
from gatehouse_app.models.billing.plan import Plan
|
||||
|
||||
plan = Plan.query.get(plan_id)
|
||||
if not plan:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Plan not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
return api_response(data={
|
||||
"id": plan.id,
|
||||
"name": plan.name,
|
||||
"slug": plan.slug,
|
||||
"description": plan.description,
|
||||
"price_monthly": plan.price_monthly,
|
||||
"price_yearly": plan.price_yearly,
|
||||
"included_users": plan.included_users,
|
||||
"overage_rate_per_user": plan.overage_rate_per_user,
|
||||
"features": plan.features,
|
||||
"stripe_price_id_monthly": plan.stripe_price_id_monthly,
|
||||
"stripe_price_id_yearly": plan.stripe_price_id_yearly,
|
||||
"is_active": plan.is_active,
|
||||
"created_at": plan.created_at.isoformat() + "Z" if plan.created_at else None,
|
||||
"updated_at": plan.updated_at.isoformat() + "Z" if plan.updated_at else None,
|
||||
}, message="Plan retrieved successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Billing] Get plan error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/billing/plans/<plan_id>", methods=["PATCH"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="plan.update", resource_type="plan")
|
||||
def update_plan(plan_id):
|
||||
"""Update a plan."""
|
||||
try:
|
||||
from gatehouse_app.models.billing.plan import Plan
|
||||
from marshmallow import Schema, fields, validate
|
||||
|
||||
class UpdatePlanSchema(Schema):
|
||||
name = fields.Str(validate=validate.Length(min=1, max=100))
|
||||
description = fields.Str(allow_none=True)
|
||||
price_monthly = fields.Int(validate=validate.Range(min=0))
|
||||
price_yearly = fields.Int(validate=validate.Range(min=0))
|
||||
included_users = fields.Int(validate=validate.Range(min=0))
|
||||
overage_rate_per_user = fields.Int(validate=validate.Range(min=0))
|
||||
features = fields.Dict(allow_none=True)
|
||||
stripe_price_id_monthly = fields.Str(allow_none=True)
|
||||
stripe_price_id_yearly = fields.Str(allow_none=True)
|
||||
is_active = fields.Bool()
|
||||
|
||||
plan = Plan.query.get(plan_id)
|
||||
if not plan:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Plan not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
data = request.json or {}
|
||||
schema = UpdatePlanSchema()
|
||||
errors = schema.validate(data)
|
||||
|
||||
if errors:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation error",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=errors,
|
||||
)
|
||||
|
||||
# Update fields
|
||||
for key, value in data.items():
|
||||
if hasattr(plan, key):
|
||||
setattr(plan, key, value)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return api_response(data={
|
||||
"id": plan.id,
|
||||
"name": plan.name,
|
||||
"slug": plan.slug,
|
||||
}, message="Plan updated successfully")
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
logger.error(f"[Billing] Update plan error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/billing/plans/<plan_id>", methods=["DELETE"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="plan.delete", resource_type="plan")
|
||||
def delete_plan(plan_id):
|
||||
"""Soft-delete a plan by setting is_active=False."""
|
||||
try:
|
||||
from gatehouse_app.models.billing.plan import Plan
|
||||
|
||||
plan = Plan.query.get(plan_id)
|
||||
if not plan:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Plan not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
plan.is_active = False
|
||||
db.session.commit()
|
||||
|
||||
return api_response(data={"id": plan.id}, message="Plan deleted successfully")
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
logger.error(f"[Billing] Delete plan error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
# ============ Subscriptions Endpoints ============
|
||||
|
||||
@superadmin_bp.route("/billing/subscriptions", methods=["GET"])
|
||||
@superadmin_required
|
||||
def list_subscriptions():
|
||||
"""Get all subscriptions with optional filters."""
|
||||
try:
|
||||
from gatehouse_app.models.billing.subscription import Subscription
|
||||
from gatehouse_app.models.billing.plan import Plan
|
||||
|
||||
page = max(1, int(request.args.get("page", 1)))
|
||||
per_page = min(100, max(1, int(request.args.get("per_page", 20))))
|
||||
plan_id = request.args.get("plan_id")
|
||||
status = request.args.get("status")
|
||||
|
||||
query = Subscription.query
|
||||
|
||||
if plan_id:
|
||||
query = query.filter(Subscription.plan_id == plan_id)
|
||||
|
||||
if status:
|
||||
query = query.filter(Subscription.status == status)
|
||||
|
||||
total = query.count()
|
||||
subs = query.order_by(Subscription.created_at.desc()).offset((page - 1) * per_page).limit(per_page).all()
|
||||
|
||||
items = []
|
||||
for sub in subs:
|
||||
org = Organization.query.get(sub.organization_id)
|
||||
plan = Plan.query.get(sub.plan_id) if sub.plan_id else None
|
||||
|
||||
# Calculate MRR
|
||||
if plan and sub.status == "active":
|
||||
mrr = plan.price_monthly if sub.billing_cycle == "monthly" else plan.price_yearly // 12
|
||||
else:
|
||||
mrr = 0
|
||||
|
||||
items.append({
|
||||
"id": sub.id,
|
||||
"organization_id": sub.organization_id,
|
||||
"org_name": org.name if org else "Unknown",
|
||||
"plan_id": sub.plan_id,
|
||||
"plan_name": plan.name if plan else "Unknown",
|
||||
"status": sub.status,
|
||||
"billing_cycle": sub.billing_cycle,
|
||||
"mrr": mrr,
|
||||
"current_period_start": sub.current_period_start.isoformat() + "Z" if sub.current_period_start else None,
|
||||
"current_period_end": sub.current_period_end.isoformat() + "Z" if sub.current_period_end else None,
|
||||
"trial_ends_at": sub.trial_ends_at.isoformat() + "Z" if sub.trial_ends_at else None,
|
||||
"cancel_at_period_end": sub.cancel_at_period_end,
|
||||
"created_at": sub.created_at.isoformat() + "Z" if sub.created_at else None,
|
||||
})
|
||||
|
||||
return api_response(data={
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": (total + per_page - 1) // per_page if per_page > 0 else 0,
|
||||
}, message="Subscriptions retrieved successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Billing] List subscriptions error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/billing/subscriptions/<sub_id>", methods=["GET"])
|
||||
@superadmin_required
|
||||
def get_subscription(sub_id):
|
||||
"""Get a single subscription."""
|
||||
try:
|
||||
from gatehouse_app.models.billing.subscription import Subscription
|
||||
from gatehouse_app.models.billing.plan import Plan
|
||||
|
||||
sub = Subscription.query.get(sub_id)
|
||||
if not sub:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Subscription not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
org = Organization.query.get(sub.organization_id)
|
||||
plan = Plan.query.get(sub.plan_id) if sub.plan_id else None
|
||||
|
||||
return api_response(data={
|
||||
"id": sub.id,
|
||||
"organization_id": sub.organization_id,
|
||||
"org_name": org.name if org else "Unknown",
|
||||
"plan_id": sub.plan_id,
|
||||
"plan_name": plan.name if plan else None,
|
||||
"status": sub.status,
|
||||
"billing_cycle": sub.billing_cycle,
|
||||
"current_period_start": sub.current_period_start.isoformat() + "Z" if sub.current_period_start else None,
|
||||
"current_period_end": sub.current_period_end.isoformat() + "Z" if sub.current_period_end else None,
|
||||
"trial_ends_at": sub.trial_ends_at.isoformat() + "Z" if sub.trial_ends_at else None,
|
||||
"stripe_subscription_id": sub.stripe_subscription_id,
|
||||
"overage_enabled": sub.overage_enabled,
|
||||
"cancelled_at": sub.cancelled_at.isoformat() + "Z" if sub.cancelled_at else None,
|
||||
"cancel_at_period_end": sub.cancel_at_period_end,
|
||||
}, message="Subscription retrieved successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Billing] Get subscription error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/billing/subscriptions/<org_id>", methods=["PATCH"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="subscription.update", resource_type="subscription")
|
||||
def update_subscription(org_id):
|
||||
"""Update subscription plan or billing cycle for an organization."""
|
||||
try:
|
||||
from gatehouse_app.models.billing.subscription import Subscription
|
||||
from gatehouse_app.models.billing.plan import Plan
|
||||
|
||||
org = Organization.query.get(org_id)
|
||||
if not org:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Organization not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
sub = Subscription.query.filter_by(organization_id=org_id).first()
|
||||
if not sub:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="No subscription found for this organization",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
data = request.json or {}
|
||||
|
||||
if "plan_id" in data:
|
||||
plan = Plan.query.get(data["plan_id"])
|
||||
if not plan:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Plan not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
sub.plan_id = data["plan_id"]
|
||||
|
||||
if "billing_cycle" in data:
|
||||
if data["billing_cycle"] not in ["monthly", "yearly"]:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Invalid billing cycle. Must be 'monthly' or 'yearly'",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
)
|
||||
sub.billing_cycle = data["billing_cycle"]
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return api_response(data={
|
||||
"id": sub.id,
|
||||
"organization_id": sub.organization_id,
|
||||
"plan_id": sub.plan_id,
|
||||
"billing_cycle": sub.billing_cycle,
|
||||
}, message="Subscription updated successfully")
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
logger.error(f"[Billing] Update subscription error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/billing/subscriptions/<org_id>/cancel", methods=["POST"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="subscription.cancel", resource_type="subscription")
|
||||
def cancel_subscription(org_id):
|
||||
"""Cancel subscription at period end."""
|
||||
try:
|
||||
from gatehouse_app.models.billing.subscription import Subscription
|
||||
|
||||
org = Organization.query.get(org_id)
|
||||
if not org:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Organization not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
sub = Subscription.query.filter_by(organization_id=org_id).first()
|
||||
if not sub:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="No subscription found for this organization",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
sub.cancel_at_period_end = True
|
||||
sub.status = "cancelled"
|
||||
db.session.commit()
|
||||
|
||||
return api_response(data={
|
||||
"id": sub.id,
|
||||
"cancel_at_period_end": True,
|
||||
"status": sub.status,
|
||||
}, message="Subscription will be cancelled at period end")
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
logger.error(f"[Billing] Cancel subscription error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/billing/subscriptions/<org_id>/trial", methods=["POST"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="subscription.extend_trial", resource_type="subscription")
|
||||
def extend_trial(org_id):
|
||||
"""Extend trial period for an organization."""
|
||||
try:
|
||||
from gatehouse_app.models.billing.subscription import Subscription
|
||||
from datetime import timedelta
|
||||
|
||||
data = request.json or {}
|
||||
days = data.get("days", 30)
|
||||
|
||||
if not isinstance(days, int) or days < 1:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Days must be a positive integer",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
)
|
||||
|
||||
org = Organization.query.get(org_id)
|
||||
if not org:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Organization not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
sub = Subscription.query.filter_by(organization_id=org_id).first()
|
||||
if not sub:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="No subscription found for this organization",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Extend trial
|
||||
if sub.trial_ends_at:
|
||||
sub.trial_ends_at = sub.trial_ends_at + timedelta(days=days)
|
||||
else:
|
||||
sub.trial_ends_at = datetime.now(timezone.utc) + timedelta(days=days)
|
||||
|
||||
sub.status = "trial"
|
||||
db.session.commit()
|
||||
|
||||
return api_response(data={
|
||||
"id": sub.id,
|
||||
"trial_ends_at": sub.trial_ends_at.isoformat() + "Z",
|
||||
"days_added": days,
|
||||
}, message=f"Trial extended by {days} days")
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
logger.error(f"[Billing] Extend trial error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
@@ -0,0 +1,56 @@
|
||||
"""Superadmin SSH CA management endpoints."""
|
||||
import logging
|
||||
from flask import request
|
||||
from gatehouse_app.api.v1.superadmin import superadmin_bp
|
||||
from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.decorators.superadmin import superadmin_required, superadmin_audit_log
|
||||
from gatehouse_app.extensions import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@superadmin_bp.route("/organizations/<org_id>/cas/<ca_id>", methods=["DELETE"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="ca.delete", resource_type="CA")
|
||||
def delete_org_ca(org_id, ca_id):
|
||||
"""Soft-delete an SSH CA for an organization.
|
||||
|
||||
Sets is_active=False and deleted_at=now().
|
||||
"""
|
||||
from gatehouse_app.models.ssh_ca.ca import CA
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
|
||||
org = Organization.query.filter_by(id=org_id, deleted_at=None).first()
|
||||
if not org:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Organization not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND"
|
||||
)
|
||||
|
||||
ca = CA.query.filter_by(id=ca_id, organization_id=org_id, deleted_at=None).first()
|
||||
if not ca:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="CA not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND"
|
||||
)
|
||||
|
||||
try:
|
||||
ca.is_active = False
|
||||
ca.delete(soft=True)
|
||||
db.session.commit()
|
||||
|
||||
return api_response(data={"ca_id": ca_id}, message="CA deleted successfully")
|
||||
|
||||
except Exception:
|
||||
db.session.rollback()
|
||||
logger.exception(f"Failed to delete CA {ca_id}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Failed to delete CA",
|
||||
status=500,
|
||||
error_type="SERVER_ERROR"
|
||||
)
|
||||
@@ -0,0 +1,456 @@
|
||||
"""Superadmin organization member management endpoints."""
|
||||
import logging
|
||||
from flask import request, g
|
||||
from marshmallow import ValidationError
|
||||
from gatehouse_app.api.v1.superadmin import superadmin_bp
|
||||
from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.decorators.superadmin import superadmin_required, superadmin_audit_log
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
from gatehouse_app.models.user.user import User
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ListMembersSchema:
|
||||
"""Schema for list members query params."""
|
||||
|
||||
@staticmethod
|
||||
def load(args):
|
||||
"""Parse and validate query parameters."""
|
||||
try:
|
||||
page = max(1, int(args.get("page", 1)))
|
||||
per_page = min(100, max(1, int(args.get("per_page", 20))))
|
||||
except (ValueError, TypeError):
|
||||
page = 1
|
||||
per_page = 20
|
||||
|
||||
search = args.get("search")
|
||||
role = args.get("role")
|
||||
|
||||
return {
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"search": search,
|
||||
"role": role,
|
||||
}
|
||||
|
||||
|
||||
class AddMemberSchema:
|
||||
"""Schema for adding a member."""
|
||||
|
||||
@staticmethod
|
||||
def load(data):
|
||||
"""Parse and validate add member data."""
|
||||
errors = {}
|
||||
|
||||
user_id = data.get("user_id")
|
||||
if not user_id:
|
||||
errors["user_id"] = ["User ID is required"]
|
||||
|
||||
role_str = data.get("role", "member")
|
||||
try:
|
||||
role = OrganizationRole(role_str)
|
||||
except ValueError:
|
||||
errors["role"] = [f"Invalid role. Must be one of: {', '.join(r.value for r in OrganizationRole)}"]
|
||||
|
||||
if errors:
|
||||
raise ValidationError(errors)
|
||||
|
||||
return {"user_id": user_id, "role": role}
|
||||
|
||||
|
||||
class UpdateMemberSchema:
|
||||
"""Schema for updating a member role."""
|
||||
|
||||
@staticmethod
|
||||
def load(data):
|
||||
"""Parse and validate update data."""
|
||||
errors = {}
|
||||
|
||||
role_str = data.get("role")
|
||||
if not role_str:
|
||||
errors["role"] = ["Role is required"]
|
||||
|
||||
try:
|
||||
role = OrganizationRole(role_str)
|
||||
except ValueError:
|
||||
errors["role"] = [f"Invalid role. Must be one of: {', '.join(r.value for r in OrganizationRole)}"]
|
||||
|
||||
if errors:
|
||||
raise ValidationError(errors)
|
||||
|
||||
return {"role": role}
|
||||
|
||||
|
||||
@superadmin_bp.route("/organizations/<org_id>/members", methods=["GET"])
|
||||
@superadmin_required
|
||||
def list_organization_members(org_id):
|
||||
"""List all members of an organization.
|
||||
|
||||
Query params:
|
||||
page: Page number (default 1)
|
||||
per_page: Items per page (default 20)
|
||||
search: Search by user email or name
|
||||
role: Filter by role
|
||||
"""
|
||||
try:
|
||||
# Verify org exists
|
||||
org = Organization.query.get(org_id)
|
||||
if not org:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Organization not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
schema = ListMembersSchema()
|
||||
params = schema.load(request.args)
|
||||
|
||||
query = OrganizationMember.query.filter_by(organization_id=org_id, deleted_at=None)
|
||||
|
||||
# Search by user email or name
|
||||
if params["search"]:
|
||||
search_term = f"%{params['search']}%"
|
||||
query = query.join(User).filter(
|
||||
db.or_(
|
||||
User.email.ilike(search_term),
|
||||
User.full_name.ilike(search_term),
|
||||
)
|
||||
)
|
||||
|
||||
# Filter by role
|
||||
if params["role"]:
|
||||
try:
|
||||
role = OrganizationRole(params["role"])
|
||||
query = query.filter(OrganizationMember.role == role)
|
||||
except ValueError:
|
||||
pass # Ignore invalid role filter
|
||||
|
||||
# Order by joined_at desc
|
||||
query = query.order_by(OrganizationMember.joined_at.desc())
|
||||
|
||||
# Paginate
|
||||
pagination = query.paginate(page=params["page"], per_page=params["per_page"], error_out=False)
|
||||
|
||||
# Build response
|
||||
items = []
|
||||
for member in pagination.items:
|
||||
user = User.query.get(member.user_id)
|
||||
item = {
|
||||
"user_id": member.user_id,
|
||||
"organization_id": member.organization_id,
|
||||
"role": member.role.value,
|
||||
"joined_at": member.joined_at.isoformat() + "Z" if member.joined_at else None,
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"is_active": user.is_active,
|
||||
} if user else {"id": member.user_id, "email": "[deleted]", "full_name": None, "is_active": False},
|
||||
}
|
||||
items.append(item)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"items": items,
|
||||
"total": pagination.total,
|
||||
"page": params["page"],
|
||||
"per_page": params["per_page"],
|
||||
"pages": pagination.pages,
|
||||
},
|
||||
message="Members retrieved successfully",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminOrg] List members error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/organizations/<org_id>/members", methods=["POST"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="add_member", resource_type="organization_member")
|
||||
def add_organization_member(org_id):
|
||||
"""Add a user to an organization.
|
||||
|
||||
Body:
|
||||
user_id: User UUID to add
|
||||
role: Role (owner, admin, member, guest)
|
||||
"""
|
||||
try:
|
||||
schema = AddMemberSchema()
|
||||
data = schema.load(request.json or {})
|
||||
|
||||
# Verify org exists
|
||||
org = Organization.query.get(org_id)
|
||||
if not org:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Organization not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Verify user exists
|
||||
user = User.query.get(data["user_id"])
|
||||
if not user:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="User not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Check if already a member
|
||||
existing = OrganizationMember.query.filter_by(
|
||||
user_id=data["user_id"],
|
||||
organization_id=org_id,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
if existing:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="User is already a member of this organization",
|
||||
status=400,
|
||||
error_type="ALREADY_EXISTS",
|
||||
)
|
||||
|
||||
# Create membership
|
||||
member = OrganizationMember(
|
||||
user_id=data["user_id"],
|
||||
organization_id=org_id,
|
||||
role=data["role"],
|
||||
invited_by_id=g.current_superadmin.id,
|
||||
invited_at=db.func.now(),
|
||||
joined_at=db.func.now(),
|
||||
)
|
||||
db.session.add(member)
|
||||
db.session.commit()
|
||||
|
||||
logger.info(f"[SuperadminOrg] Added user {data['user_id']} to org {org_id} as {data['role'].value}")
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"member": {
|
||||
"user_id": member.user_id,
|
||||
"organization_id": member.organization_id,
|
||||
"role": member.role.value,
|
||||
"joined_at": member.joined_at.isoformat() + "Z" if member.joined_at else None,
|
||||
}
|
||||
},
|
||||
message="Member added successfully",
|
||||
status=201,
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminOrg] Add member error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/organizations/<org_id>/members/<user_id>", methods=["PATCH"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="update_member_role", resource_type="organization_member")
|
||||
def update_organization_member(org_id, user_id):
|
||||
"""Update a member's role.
|
||||
|
||||
Body:
|
||||
role: New role (owner, admin, member, guest)
|
||||
"""
|
||||
try:
|
||||
schema = UpdateMemberSchema()
|
||||
data = schema.load(request.json or {})
|
||||
|
||||
# Find member
|
||||
member = OrganizationMember.query.filter_by(
|
||||
user_id=user_id,
|
||||
organization_id=org_id,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
if not member:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Member not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Update role
|
||||
old_role = member.role
|
||||
member.role = data["role"]
|
||||
db.session.commit()
|
||||
|
||||
logger.info(f"[SuperadminOrg] Updated member {user_id} role from {old_role.value} to {data['role'].value}")
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"member": {
|
||||
"user_id": member.user_id,
|
||||
"organization_id": member.organization_id,
|
||||
"role": member.role.value,
|
||||
}
|
||||
},
|
||||
message="Member role updated successfully",
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Validation failed",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
error_details=e.messages,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminOrg] Update member error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/organizations/<org_id>/members/<user_id>", methods=["DELETE"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="remove_member", resource_type="organization_member")
|
||||
def remove_organization_member(org_id, user_id):
|
||||
"""Remove a user from an organization."""
|
||||
try:
|
||||
# Find member
|
||||
member = OrganizationMember.query.filter_by(
|
||||
user_id=user_id,
|
||||
organization_id=org_id,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
if not member:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Member not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Prevent removing owner without transferring ownership
|
||||
if member.role == OrganizationRole.OWNER:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Cannot remove organization owner. Transfer ownership first.",
|
||||
status=400,
|
||||
error_type="AUTHORIZATION_ERROR",
|
||||
)
|
||||
|
||||
# Soft delete
|
||||
from datetime import datetime, timezone
|
||||
member.deleted_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
logger.info(f"[SuperadminOrg] Removed user {user_id} from org {org_id}")
|
||||
|
||||
return api_response(message="Member removed successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminOrg] Remove member error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/organizations/<org_id>/transfer-ownership/<user_id>", methods=["POST"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="transfer_ownership", resource_type="organization_member")
|
||||
def transfer_organization_ownership(org_id, user_id):
|
||||
"""Transfer organization ownership to another member.
|
||||
|
||||
The target user must already be a member of the organization.
|
||||
The current owner will be changed to admin role.
|
||||
"""
|
||||
try:
|
||||
# Verify org exists
|
||||
org = Organization.query.get(org_id)
|
||||
if not org:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Organization not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Find current owner
|
||||
current_owner = OrganizationMember.query.filter_by(
|
||||
organization_id=org_id,
|
||||
role=OrganizationRole.OWNER,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
if not current_owner:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Current owner not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Find target user as member
|
||||
target_member = OrganizationMember.query.filter_by(
|
||||
user_id=user_id,
|
||||
organization_id=org_id,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
if not target_member:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Target user is not a member of this organization",
|
||||
status=400,
|
||||
error_type="AUTHORIZATION_ERROR",
|
||||
)
|
||||
|
||||
# Transfer ownership
|
||||
current_owner.role = OrganizationRole.ADMIN
|
||||
target_member.role = OrganizationRole.OWNER
|
||||
db.session.commit()
|
||||
|
||||
logger.warning(
|
||||
f"[SuperadminOrg] TRANSFERRED OWNERSHIP: org={org_id} from user={current_owner.user_id} to user={user_id}"
|
||||
)
|
||||
|
||||
return api_response(
|
||||
data={
|
||||
"owner": {
|
||||
"user_id": target_member.user_id,
|
||||
"role": target_member.role.value,
|
||||
}
|
||||
},
|
||||
message="Ownership transferred successfully",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminOrg] Transfer ownership error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
@@ -0,0 +1,254 @@
|
||||
"""Superadmin organization management endpoints."""
|
||||
import logging
|
||||
from flask import request
|
||||
from gatehouse_app.api.v1.superadmin import superadmin_bp
|
||||
from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.services.superadmin_organization_service import SuperadminOrganizationService
|
||||
from gatehouse_app.decorators.superadmin import superadmin_required, superadmin_audit_log
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ListOrganizationsSchema:
|
||||
"""Schema for list organizations query params."""
|
||||
|
||||
@staticmethod
|
||||
def load(args):
|
||||
"""Parse and validate query parameters."""
|
||||
try:
|
||||
page = max(1, int(args.get("page", 1)))
|
||||
per_page = min(100, max(1, int(args.get("per_page", 20))))
|
||||
except (ValueError, TypeError):
|
||||
page = 1
|
||||
per_page = 20
|
||||
|
||||
search = args.get("search")
|
||||
status = args.get("status")
|
||||
plan_slug = args.get("plan_slug")
|
||||
|
||||
return {
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"search": search,
|
||||
"status": status,
|
||||
"plan_slug": plan_slug,
|
||||
}
|
||||
|
||||
|
||||
class UpdateOrganizationSchema:
|
||||
"""Schema for updating an organization."""
|
||||
|
||||
@staticmethod
|
||||
def load(data):
|
||||
"""Parse and validate update data."""
|
||||
result = {}
|
||||
|
||||
if "name" in data:
|
||||
result["name"] = data["name"]
|
||||
if "description" in data:
|
||||
result["description"] = data["description"]
|
||||
if "is_active" in data:
|
||||
result["is_active"] = bool(data["is_active"])
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@superadmin_bp.route("/organizations", methods=["GET"])
|
||||
@superadmin_required
|
||||
def list_organizations():
|
||||
"""List all organizations with pagination and filtering.
|
||||
|
||||
Query params:
|
||||
page: Page number (default 1)
|
||||
per_page: Items per page (default 20, max 100)
|
||||
search: Search by name or slug
|
||||
status: Filter by status (active, suspended)
|
||||
plan_slug: Filter by plan slug
|
||||
"""
|
||||
try:
|
||||
schema = ListOrganizationsSchema()
|
||||
params = schema.load(request.args)
|
||||
|
||||
result = SuperadminOrganizationService.list_organizations(
|
||||
page=params["page"],
|
||||
per_page=params["per_page"],
|
||||
search=params["search"],
|
||||
status=params["status"],
|
||||
plan_slug=params["plan_slug"],
|
||||
)
|
||||
|
||||
return api_response(data=result, message="Organizations retrieved successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminOrg] List organizations error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/organizations/<org_id>", methods=["GET"])
|
||||
@superadmin_required
|
||||
def get_organization(org_id):
|
||||
"""Get detailed organization information.
|
||||
|
||||
Returns org details including member count, owner info, and active sessions.
|
||||
"""
|
||||
try:
|
||||
result = SuperadminOrganizationService.get_organization_detail(org_id)
|
||||
return api_response(data={"organization": result}, message="Organization retrieved successfully")
|
||||
|
||||
except ValueError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=str(e),
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminOrg] Get organization error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/organizations/<org_id>", methods=["PATCH"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="update_organization", resource_type="organization")
|
||||
def update_organization(org_id):
|
||||
"""Update organization details.
|
||||
|
||||
Body:
|
||||
name: New name (optional)
|
||||
description: New description (optional)
|
||||
is_active: New active status (optional)
|
||||
"""
|
||||
try:
|
||||
schema = UpdateOrganizationSchema()
|
||||
data = schema.load(request.json or {})
|
||||
|
||||
if not data:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="No update data provided",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
)
|
||||
|
||||
org = SuperadminOrganizationService.update_organization(org_id, **data)
|
||||
|
||||
return api_response(
|
||||
data={"organization": org.to_dict()},
|
||||
message="Organization updated successfully",
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=str(e),
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminOrg] Update organization error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/organizations/<org_id>/suspend", methods=["POST"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="suspend_organization", resource_type="organization")
|
||||
def suspend_organization(org_id):
|
||||
"""Suspend an organization.
|
||||
|
||||
Sets is_active=False and invalidates all member sessions.
|
||||
"""
|
||||
try:
|
||||
org = SuperadminOrganizationService.suspend_organization(org_id)
|
||||
|
||||
return api_response(
|
||||
data={"organization": org.to_dict()},
|
||||
message="Organization suspended successfully",
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=str(e),
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminOrg] Suspend organization error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/organizations/<org_id>/unsuspend", methods=["POST"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="unsuspend_organization", resource_type="organization")
|
||||
def unsuspend_organization(org_id):
|
||||
"""Restore a suspended organization."""
|
||||
try:
|
||||
org = SuperadminOrganizationService.restore_organization(org_id)
|
||||
|
||||
return api_response(
|
||||
data={"organization": org.to_dict()},
|
||||
message="Organization restored successfully",
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=str(e),
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminOrg] Unsuspend organization error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/organizations/<org_id>", methods=["DELETE"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="delete_organization", resource_type="organization")
|
||||
def delete_organization(org_id):
|
||||
"""Soft-delete an organization."""
|
||||
try:
|
||||
SuperadminOrganizationService.soft_delete_organization(org_id)
|
||||
|
||||
return api_response(message="Organization deleted successfully")
|
||||
|
||||
except ValueError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=str(e),
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminOrg] Delete organization error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
@@ -0,0 +1,330 @@
|
||||
"""Superadmin usage and analytics endpoints."""
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from flask import request
|
||||
from gatehouse_app.api.v1.superadmin import superadmin_bp
|
||||
from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.services.superadmin_usage_service import SuperadminUsageService
|
||||
from gatehouse_app.services.superadmin_analytics_service import SuperadminAnalyticsService
|
||||
from gatehouse_app.decorators.superadmin import superadmin_required, superadmin_audit_log
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============ Analytics Endpoints ============
|
||||
|
||||
@superadmin_bp.route("/analytics/dashboard", methods=["GET"])
|
||||
@superadmin_required
|
||||
def get_dashboard_stats():
|
||||
"""Get dashboard statistics for the overview page.
|
||||
|
||||
Returns aggregated stats: org counts, user counts, sessions, recent signups.
|
||||
"""
|
||||
try:
|
||||
stats = SuperadminAnalyticsService.get_dashboard_stats()
|
||||
return api_response(data=stats, message="Dashboard stats retrieved successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminAnalytics] Dashboard stats error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/analytics/signup-trends", methods=["GET"])
|
||||
@superadmin_required
|
||||
def get_signup_trends():
|
||||
"""Get signup trends over time.
|
||||
|
||||
Query params:
|
||||
days: Number of days to analyze (default 30, max 365)
|
||||
"""
|
||||
try:
|
||||
days = min(365, max(1, int(request.args.get("days", 30))))
|
||||
trends = SuperadminAnalyticsService.get_signup_trends(days)
|
||||
return api_response(data=trends, message="Signup trends retrieved successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminAnalytics] Signup trends error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/analytics/org-distribution", methods=["GET"])
|
||||
@superadmin_required
|
||||
def get_org_distribution():
|
||||
"""Get organization distribution by size."""
|
||||
try:
|
||||
distribution = SuperadminAnalyticsService.get_org_distribution()
|
||||
return api_response(data=distribution, message="Organization distribution retrieved successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminAnalytics] Org distribution error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/analytics/recent-activity", methods=["GET"])
|
||||
@superadmin_required
|
||||
def get_recent_activity():
|
||||
"""Get recent superadmin actions.
|
||||
|
||||
Query params:
|
||||
limit: Maximum number of entries (default 20, max 100)
|
||||
"""
|
||||
try:
|
||||
limit = min(100, max(1, int(request.args.get("limit", 20))))
|
||||
activity = SuperadminAnalyticsService.get_recent_activity(limit)
|
||||
return api_response(data={"items": activity}, message="Recent activity retrieved successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminAnalytics] Recent activity error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
# ============ Usage Endpoints ============
|
||||
|
||||
@superadmin_bp.route("/usage/<org_id>", methods=["GET"])
|
||||
@superadmin_required
|
||||
def get_organization_usage(org_id):
|
||||
"""Get current usage for an organization.
|
||||
|
||||
Returns current period usage metrics: user count, active sessions.
|
||||
"""
|
||||
try:
|
||||
usage = SuperadminUsageService.get_current_usage(org_id)
|
||||
return api_response(data=usage, message="Usage retrieved successfully")
|
||||
|
||||
except ValueError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=str(e),
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminUsage] Get usage error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/usage/<org_id>/history", methods=["GET"])
|
||||
@superadmin_required
|
||||
def get_usage_history(org_id):
|
||||
"""Get usage history for an organization.
|
||||
|
||||
Query params:
|
||||
metric: Metric type (users, sessions) - default users
|
||||
days: Number of days of history (default 30, max 365)
|
||||
"""
|
||||
try:
|
||||
metric = request.args.get("metric", "users")
|
||||
days = min(365, max(1, int(request.args.get("days", 30))))
|
||||
|
||||
history = SuperadminUsageService.get_usage_history(org_id, metric, days)
|
||||
return api_response(data=history, message="Usage history retrieved successfully")
|
||||
|
||||
except ValueError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=str(e),
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminUsage] Get usage history error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/usage/<org_id>/seats", methods=["GET"])
|
||||
@superadmin_required
|
||||
def get_seat_count(org_id):
|
||||
"""Get maximum seat count for billing period.
|
||||
|
||||
Query params:
|
||||
year: Year (default current year)
|
||||
month: Month (default current month)
|
||||
"""
|
||||
try:
|
||||
year = int(request.args.get("year", datetime.now().year))
|
||||
month = int(request.args.get("month", datetime.now().month))
|
||||
|
||||
seats = SuperadminUsageService.get_seat_count_for_period(org_id, year, month)
|
||||
return api_response(data=seats, message="Seat count retrieved successfully")
|
||||
|
||||
except ValueError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=str(e),
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminUsage] Get seat count error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/usage/<org_id>/adjust", methods=["POST"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="usage_adjustment", resource_type="usage")
|
||||
def adjust_usage(org_id):
|
||||
"""Apply a manual usage adjustment.
|
||||
|
||||
Body:
|
||||
metric: Metric to adjust
|
||||
adjustment: Positive (credit) or negative (charge)
|
||||
reason: Reason for adjustment
|
||||
"""
|
||||
try:
|
||||
data = request.json or {}
|
||||
|
||||
metric = data.get("metric", "users")
|
||||
adjustment = data.get("adjustment", 0)
|
||||
reason = data.get("reason", "")
|
||||
|
||||
if not reason:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Reason is required for usage adjustment",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
)
|
||||
|
||||
result = SuperadminUsageService.adjust_usage(
|
||||
org_id=org_id,
|
||||
metric=metric,
|
||||
adjustment=adjustment,
|
||||
reason=reason,
|
||||
superadmin_id="", # Will be filled from decorator
|
||||
)
|
||||
|
||||
return api_response(data=result, message="Usage adjustment applied successfully")
|
||||
|
||||
except ValueError as e:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=str(e),
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminUsage] Adjust usage error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
# ============ Invoice Data Endpoint ============
|
||||
|
||||
@superadmin_bp.route("/invoice-data/<org_id>", methods=["GET"])
|
||||
@superadmin_required
|
||||
def get_invoice_data(org_id):
|
||||
"""Get all data needed to generate an invoice for an organization.
|
||||
|
||||
Returns organization info, plan details, usage, and subscription status.
|
||||
"""
|
||||
try:
|
||||
from datetime import datetime
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
|
||||
org = Organization.query.get(org_id)
|
||||
if not org:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Organization not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Get seat count for current month
|
||||
now = datetime.now(timezone.utc)
|
||||
seats = SuperadminUsageService.get_seat_count_for_period(org_id, now.year, now.month)
|
||||
|
||||
# Get owner
|
||||
owner = org.get_owner()
|
||||
|
||||
invoice_data = {
|
||||
"organization": {
|
||||
"id": org.id,
|
||||
"name": org.name,
|
||||
"slug": org.slug,
|
||||
"owner_email": owner.email if owner else None,
|
||||
"created_at": org.created_at.isoformat() + "Z" if org.created_at else None,
|
||||
},
|
||||
"billing_period": {
|
||||
"year": now.year,
|
||||
"month": now.month,
|
||||
"start": seats["period_start"],
|
||||
"end": seats["period_end"],
|
||||
},
|
||||
"usage": {
|
||||
"max_seats": seats["max_seats"],
|
||||
"current_seats": seats["current_seats"],
|
||||
"included_seats": 0, # Would come from plan
|
||||
"overage": max(0, seats["current_seats"]), # Simplified
|
||||
},
|
||||
"subscription": {
|
||||
"status": "active" if org.is_active else "suspended",
|
||||
"is_active": org.is_active,
|
||||
},
|
||||
"line_items": [
|
||||
{
|
||||
"description": "Base subscription",
|
||||
"quantity": 1,
|
||||
"unit_price": 0, # Would come from plan
|
||||
"total": 0,
|
||||
},
|
||||
{
|
||||
"description": f"User seats ({seats['current_seats']})",
|
||||
"quantity": seats["current_seats"],
|
||||
"unit_price": 0, # Would come from plan per-seat price
|
||||
"total": 0,
|
||||
},
|
||||
],
|
||||
"total": 0, # Would be calculated
|
||||
"generated_at": now.isoformat() + "Z",
|
||||
}
|
||||
|
||||
return api_response(data=invoice_data, message="Invoice data retrieved successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminAnalytics] Invoice data error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
@@ -0,0 +1,516 @@
|
||||
"""Superadmin user management endpoints."""
|
||||
import logging
|
||||
from flask import request, g
|
||||
from gatehouse_app.api.v1.superadmin import superadmin_bp
|
||||
from gatehouse_app.utils.response import api_response
|
||||
from gatehouse_app.decorators.superadmin import superadmin_required, superadmin_audit_log
|
||||
from gatehouse_app.models.user.user import User
|
||||
from gatehouse_app.models.user.session import Session
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.extensions import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@superadmin_bp.route("/users", methods=["GET"])
|
||||
@superadmin_required
|
||||
def list_users():
|
||||
"""Get paginated list of users with optional filters.
|
||||
|
||||
Query params:
|
||||
page: Page number (default 1)
|
||||
per_page: Items per page (default 20, max 100)
|
||||
organization_id: Filter by organization
|
||||
status: Filter by status (active/suspended)
|
||||
search: Search by email or name
|
||||
"""
|
||||
try:
|
||||
page = max(1, int(request.args.get("page", 1)))
|
||||
per_page = min(100, max(1, int(request.args.get("per_page", 20))))
|
||||
org_id = request.args.get("organization_id")
|
||||
status = request.args.get("status")
|
||||
search = request.args.get("search", "").strip()
|
||||
|
||||
# Base query
|
||||
query = User.query.filter(User.deleted_at.is_(None))
|
||||
|
||||
# Filter by organization
|
||||
if org_id:
|
||||
member_user_ids = db.session.query(OrganizationMember.user_id).filter(
|
||||
OrganizationMember.organization_id == org_id,
|
||||
OrganizationMember.deleted_at.is_(None),
|
||||
).all()
|
||||
user_ids = [m.user_id for m in member_user_ids]
|
||||
query = query.filter(User.id.in_(user_ids))
|
||||
|
||||
# Filter by status
|
||||
if status == "suspended":
|
||||
query = query.filter(User.status == "GLOBAL_SUSPENDED")
|
||||
elif status == "active":
|
||||
query = query.filter(User.status != "GLOBAL_SUSPENDED")
|
||||
|
||||
# Search by email or name
|
||||
if search:
|
||||
search_filter = f"%{search}%"
|
||||
query = query.filter(
|
||||
db.or_(
|
||||
User.email.ilike(search_filter),
|
||||
User.full_name.ilike(search_filter),
|
||||
)
|
||||
)
|
||||
|
||||
# Order by created_at desc
|
||||
query = query.order_by(User.created_at.desc())
|
||||
|
||||
# Paginate
|
||||
total = query.count()
|
||||
users = query.offset((page - 1) * per_page).limit(per_page).all()
|
||||
|
||||
# Get org memberships for each user
|
||||
items = []
|
||||
for user in users:
|
||||
# Get organization memberships
|
||||
memberships = db.session.query(OrganizationMember).filter(
|
||||
OrganizationMember.user_id == user.id,
|
||||
OrganizationMember.deleted_at.is_(None),
|
||||
).all()
|
||||
|
||||
orgs = []
|
||||
for m in memberships:
|
||||
org = Organization.query.get(m.organization_id)
|
||||
if org:
|
||||
orgs.append({
|
||||
"org_id": org.id,
|
||||
"org_name": org.name,
|
||||
"role": m.role,
|
||||
"joined_at": m.created_at.isoformat() + "Z" if m.created_at else None,
|
||||
})
|
||||
|
||||
# Get active session count
|
||||
active_sessions = Session.query.filter(
|
||||
Session.user_id == user.id,
|
||||
Session.deleted_at.is_(None),
|
||||
Session.status == "active",
|
||||
).count()
|
||||
|
||||
items.append({
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"status": user.status,
|
||||
"mfa_enabled": user.mfa_enabled if hasattr(user, 'mfa_enabled') else False,
|
||||
"org_count": len(orgs),
|
||||
"orgs": orgs,
|
||||
"active_sessions": active_sessions,
|
||||
"last_login_at": user.last_login_at.isoformat() + "Z" if user.last_login_at else None,
|
||||
"created_at": user.created_at.isoformat() + "Z" if user.created_at else None,
|
||||
})
|
||||
|
||||
return api_response(data={
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": (total + per_page - 1) // per_page if per_page > 0 else 0,
|
||||
}, message="Users retrieved successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminUsers] List users error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/users/<user_id>", methods=["GET"])
|
||||
@superadmin_required
|
||||
def get_user(user_id):
|
||||
"""Get detailed user information.
|
||||
|
||||
Returns user + all org memberships + active sessions + security methods.
|
||||
"""
|
||||
try:
|
||||
user = User.query.get(user_id)
|
||||
if not user or user.deleted_at is not None:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="User not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Get organization memberships
|
||||
memberships = db.session.query(OrganizationMember).filter(
|
||||
OrganizationMember.user_id == user_id,
|
||||
OrganizationMember.deleted_at.is_(None),
|
||||
).all()
|
||||
|
||||
orgs = []
|
||||
for m in memberships:
|
||||
org = Organization.query.get(m.organization_id)
|
||||
if org:
|
||||
orgs.append({
|
||||
"org_id": org.id,
|
||||
"org_name": org.name,
|
||||
"org_slug": org.slug,
|
||||
"role": m.role,
|
||||
"joined_at": m.created_at.isoformat() + "Z" if m.created_at else None,
|
||||
})
|
||||
|
||||
# Get active sessions
|
||||
sessions = Session.query.filter(
|
||||
Session.user_id == user_id,
|
||||
Session.deleted_at.is_(None),
|
||||
Session.status == "active",
|
||||
).all()
|
||||
|
||||
active_sessions = [{
|
||||
"id": s.id,
|
||||
"ip_address": s.ip_address,
|
||||
"user_agent": s.user_agent,
|
||||
"created_at": s.created_at.isoformat() + "Z" if s.created_at else None,
|
||||
"last_active_at": s.last_active_at.isoformat() + "Z" if hasattr(s, 'last_active_at') and s.last_active_at else None,
|
||||
} for s in sessions]
|
||||
|
||||
# Get security methods (simplified - would need UserSecurityMethod model)
|
||||
security_methods = []
|
||||
if hasattr(user, 'totp_enabled') and user.totp_enabled:
|
||||
security_methods.append({"type": "totp", "enabled": True})
|
||||
if hasattr(user, 'webauthn_enabled') and user.webauthn_enabled:
|
||||
security_methods.append({"type": "webauthn", "enabled": True})
|
||||
|
||||
return api_response(data={
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"status": user.status,
|
||||
"mfa_enabled": user.mfa_enabled if hasattr(user, 'mfa_enabled') else False,
|
||||
"last_login_at": user.last_login_at.isoformat() + "Z" if user.last_login_at else None,
|
||||
"created_at": user.created_at.isoformat() + "Z" if user.created_at else None,
|
||||
},
|
||||
"organizations": orgs,
|
||||
"active_sessions": active_sessions,
|
||||
"security_methods": security_methods,
|
||||
}, message="User retrieved successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminUsers] Get user error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/users/<user_id>/suspend", methods=["POST"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="user.suspend", resource_type="user")
|
||||
def suspend_user(user_id):
|
||||
"""Globally suspend a user (sets status=GLOBAL_SUSPENDED)."""
|
||||
try:
|
||||
user = User.query.get(user_id)
|
||||
if not user or user.deleted_at is not None:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="User not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
if user.status == "GLOBAL_SUSPENDED":
|
||||
return api_response(
|
||||
success=False,
|
||||
message="User is already suspended",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
)
|
||||
|
||||
user.status = "GLOBAL_SUSPENDED"
|
||||
db.session.commit()
|
||||
|
||||
# Revoke all sessions
|
||||
revoked_count = Session.query.filter(
|
||||
Session.user_id == user_id,
|
||||
Session.deleted_at.is_(None),
|
||||
).update({"status": "revoked", "deleted_at": db.func.now()})
|
||||
db.session.commit()
|
||||
|
||||
logger.warning(f"[SuperadminUsers] User {user_id} globally suspended by {getattr(g, 'current_superadmin', {}).get('id', 'unknown')}")
|
||||
|
||||
return api_response(data={
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"status": user.status,
|
||||
},
|
||||
"sessions_revoked": revoked_count,
|
||||
}, message="User suspended successfully")
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
logger.error(f"[SuperadminUsers] Suspend user error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/users/<user_id>/unsuspend", methods=["POST"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="user.unsuspend", resource_type="user")
|
||||
def unsuspend_user(user_id):
|
||||
"""Remove global suspension from a user."""
|
||||
try:
|
||||
user = User.query.get(user_id)
|
||||
if not user or user.deleted_at is not None:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="User not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
if user.status != "GLOBAL_SUSPENDED":
|
||||
return api_response(
|
||||
success=False,
|
||||
message="User is not suspended",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
)
|
||||
|
||||
user.status = "active"
|
||||
db.session.commit()
|
||||
|
||||
return api_response(data={
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"status": user.status,
|
||||
},
|
||||
}, message="User unsuspended successfully")
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
logger.error(f"[SuperadminUsers] Unsuspend user error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/users/<user_id>/reset-password", methods=["POST"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="user.reset_password", resource_type="user")
|
||||
def reset_user_password(user_id):
|
||||
"""Trigger password reset email flow for user."""
|
||||
try:
|
||||
user = User.query.get(user_id)
|
||||
if not user or user.deleted_at is not None:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="User not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# In production, this would call AuthService.send_password_reset_email(user.email)
|
||||
# For now, just log and return success
|
||||
logger.info(f"[SuperadminUsers] Password reset requested for {user.email} by superadmin")
|
||||
|
||||
return api_response(data={
|
||||
"email": user.email,
|
||||
}, message="Password reset email sent successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SuperadminUsers] Reset password error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/users/<user_id>/sessions", methods=["DELETE"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="user.revoke_sessions", resource_type="user")
|
||||
def revoke_user_sessions(user_id):
|
||||
"""Revoke all sessions for a user."""
|
||||
try:
|
||||
user = User.query.get(user_id)
|
||||
if not user or user.deleted_at is not None:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="User not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Revoke all sessions
|
||||
result = Session.query.filter(
|
||||
Session.user_id == user_id,
|
||||
Session.deleted_at.is_(None),
|
||||
).update({"status": "revoked", "deleted_at": db.func.now()})
|
||||
db.session.commit()
|
||||
|
||||
return api_response(data={
|
||||
"user_id": user_id,
|
||||
"count": result,
|
||||
}, message=f"All sessions revoked ({result} sessions)")
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
logger.error(f"[SuperadminUsers] Revoke sessions error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/users/<user_id>/add-to-org/<org_id>", methods=["POST"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="user.add_to_org", resource_type="user")
|
||||
def add_user_to_org(user_id, org_id):
|
||||
"""Add a user to an organization with specified role."""
|
||||
try:
|
||||
data = request.json or {}
|
||||
role = data.get("role", "member")
|
||||
|
||||
valid_roles = ["member", "admin", "owner"]
|
||||
if role not in valid_roles:
|
||||
return api_response(
|
||||
success=False,
|
||||
message=f"Invalid role. Must be one of: {', '.join(valid_roles)}",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
)
|
||||
|
||||
user = User.query.get(user_id)
|
||||
if not user or user.deleted_at is not None:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="User not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
org = Organization.query.get(org_id)
|
||||
if not org or org.deleted_at is not None:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Organization not found",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Check if already a member
|
||||
existing = OrganizationMember.query.filter(
|
||||
OrganizationMember.user_id == user_id,
|
||||
OrganizationMember.organization_id == org_id,
|
||||
OrganizationMember.deleted_at.is_(None),
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="User is already a member of this organization",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
)
|
||||
|
||||
# Create membership
|
||||
membership = OrganizationMember(
|
||||
user_id=user_id,
|
||||
organization_id=org_id,
|
||||
role=role,
|
||||
)
|
||||
db.session.add(membership)
|
||||
db.session.commit()
|
||||
|
||||
logger.info(f"[SuperadminUsers] User {user_id} added to org {org_id} as {role} by superadmin")
|
||||
|
||||
return api_response(data={
|
||||
"user_id": user_id,
|
||||
"organization_id": org_id,
|
||||
"role": role,
|
||||
"joined_at": membership.created_at.isoformat() + "Z" if membership.created_at else None,
|
||||
}, message="User added to organization successfully")
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
logger.error(f"[SuperadminUsers] Add to org error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
|
||||
|
||||
@superadmin_bp.route("/users/<user_id>/orgs/<org_id>", methods=["DELETE"])
|
||||
@superadmin_required
|
||||
@superadmin_audit_log(action="user.remove_from_org", resource_type="user")
|
||||
def remove_user_from_org(user_id, org_id):
|
||||
"""Remove a user from an organization."""
|
||||
try:
|
||||
membership = OrganizationMember.query.filter(
|
||||
OrganizationMember.user_id == user_id,
|
||||
OrganizationMember.organization_id == org_id,
|
||||
OrganizationMember.deleted_at.is_(None),
|
||||
).first()
|
||||
|
||||
if not membership:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="User is not a member of this organization",
|
||||
status=404,
|
||||
error_type="NOT_FOUND",
|
||||
)
|
||||
|
||||
# Check if user is the only owner
|
||||
if membership.role == "owner":
|
||||
owner_count = OrganizationMember.query.filter(
|
||||
OrganizationMember.organization_id == org_id,
|
||||
OrganizationMember.role == "owner",
|
||||
OrganizationMember.deleted_at.is_(None),
|
||||
).count()
|
||||
|
||||
if owner_count <= 1:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Cannot remove the only owner from an organization. Transfer ownership first.",
|
||||
status=400,
|
||||
error_type="VALIDATION_ERROR",
|
||||
)
|
||||
|
||||
# Soft delete membership
|
||||
membership.deleted_at = db.func.now()
|
||||
db.session.commit()
|
||||
|
||||
logger.info(f"[SuperadminUsers] User {user_id} removed from org {org_id} by superadmin")
|
||||
|
||||
return api_response(data={
|
||||
"user_id": user_id,
|
||||
"organization_id": org_id,
|
||||
}, message="User removed from organization successfully")
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
logger.error(f"[SuperadminUsers] Remove from org error: {e}")
|
||||
return api_response(
|
||||
success=False,
|
||||
message="An error occurred",
|
||||
status=500,
|
||||
error_type="INTERNAL_ERROR",
|
||||
)
|
||||
@@ -283,7 +283,8 @@ def admin_verify_user_email(user_id):
|
||||
if was_inactive:
|
||||
target.status = UserStatus.ACTIVE
|
||||
|
||||
EmailVerificationToken.query.filter_by(user_id=target.id, used_at=None).delete()
|
||||
now = datetime.now(timezone.utc)
|
||||
EmailVerificationToken.query.filter_by(user_id=target.id, used_at=None).filter(EmailVerificationToken.deleted_at == None).update({"deleted_at": now}, synchronize_session=False)
|
||||
_db.session.commit()
|
||||
|
||||
AuditService.log_action(
|
||||
@@ -300,14 +301,12 @@ def admin_verify_user_email(user_id):
|
||||
@api_v1_bp.route("/admin/users/<user_id>/delete", methods=["POST"])
|
||||
@login_required
|
||||
@full_access_required
|
||||
def admin_hard_delete_user(user_id):
|
||||
def admin_delete_user(user_id):
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.user.user import User as _User
|
||||
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey
|
||||
from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate
|
||||
from gatehouse_app.models.ssh_ca.certificate_audit_log import CertificateAuditLog
|
||||
from gatehouse_app.models.auth.authentication_method import OAuthState
|
||||
from gatehouse_app.models.security.organization_security_policy import OrganizationSecurityPolicy
|
||||
from gatehouse_app.extensions import db as _db
|
||||
from gatehouse_app.utils.constants import AuditAction, OrganizationRole
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
@@ -329,7 +328,7 @@ def admin_hard_delete_user(user_id):
|
||||
if target.id == caller.id:
|
||||
return api_response(success=False, message="Cannot delete your own account via this endpoint.", status=400, error_type="BAD_REQUEST")
|
||||
|
||||
target_org_ids = {m.organization_id for m in target.organization_memberships}
|
||||
target_org_ids = {m.organization_id for m in target.get_active_memberships()}
|
||||
admin_in_shared_org = OrganizationMember.query.filter(
|
||||
OrganizationMember.user_id == caller.id,
|
||||
OrganizationMember.organization_id.in_(target_org_ids),
|
||||
@@ -373,44 +372,29 @@ def admin_hard_delete_user(user_id):
|
||||
|
||||
target_email = target.email
|
||||
target_id_str = str(target.id)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
try:
|
||||
# NULL out FK references that don't cascade on delete so the
|
||||
# session.delete() below doesn't hit FK constraint violations.
|
||||
# Soft delete the user — set deleted_at timestamp.
|
||||
target.deleted_at = now
|
||||
|
||||
# org_invite_tokens.invited_by_id — SET NULL is already on the FK column,
|
||||
# but OrganizationMember.invited_by_id has no ondelete clause.
|
||||
_db.session.execute(
|
||||
_db.text("UPDATE organization_members SET invited_by_id = NULL WHERE invited_by_id = :uid"),
|
||||
{"uid": target_id_str},
|
||||
# Soft delete associated OAuthState records.
|
||||
OAuthState.query.filter_by(user_id=target_id_str).filter(OAuthState.deleted_at == None).update(
|
||||
{"deleted_at": now}, synchronize_session=False
|
||||
)
|
||||
|
||||
# certificate_audit_logs.user_id — nullable, no ondelete clause.
|
||||
CertificateAuditLog.query.filter_by(user_id=target_id_str).update(
|
||||
{"user_id": None}, synchronize_session=False
|
||||
)
|
||||
|
||||
# organization_security_policies.updated_by_user_id — nullable, no ondelete.
|
||||
OrganizationSecurityPolicy.query.filter_by(updated_by_user_id=target_id_str).update(
|
||||
{"updated_by_user_id": None}, synchronize_session=False
|
||||
)
|
||||
|
||||
# oauth_states.user_id — nullable, no ondelete.
|
||||
OAuthState.query.filter_by(user_id=target_id_str).delete(synchronize_session=False)
|
||||
|
||||
_db.session.delete(target)
|
||||
_db.session.flush()
|
||||
except Exception as exc:
|
||||
_db.session.rollback()
|
||||
_logger.error(f"Hard delete failed for {target_id_str}: {exc}")
|
||||
_logger.error(f"Soft delete failed for {target_id_str}: {exc}")
|
||||
return api_response(success=False, message="Failed to delete user account. Please try again.", status=500, error_type="SERVER_ERROR")
|
||||
|
||||
AuditService.log_action(
|
||||
action=AuditAction.USER_HARD_DELETE,
|
||||
action=AuditAction.USER_DELETE,
|
||||
user_id=caller.id,
|
||||
organization_id=admin_in_shared_org.organization_id,
|
||||
resource_type="user", resource_id=target_id_str,
|
||||
description=f"Admin permanently deleted user account: {target_email}",
|
||||
description=f"Admin deleted user account: {target_email}",
|
||||
metadata={
|
||||
"deleted_user_id": target_id_str, "deleted_user_email": target_email,
|
||||
"ssh_keys_deleted": ssh_key_count, "certs_revoked": active_cert_count,
|
||||
@@ -419,7 +403,7 @@ def admin_hard_delete_user(user_id):
|
||||
|
||||
_db.session.commit()
|
||||
return api_response(
|
||||
message=f"User account {target_email} has been permanently deleted.",
|
||||
message=f"User account {target_email} has been deleted.",
|
||||
data={"deleted_user_id": target_id_str, "deleted_user_email": target_email,
|
||||
"ssh_keys_deleted": ssh_key_count, "certs_revoked": active_cert_count},
|
||||
)
|
||||
@@ -512,7 +496,14 @@ def admin_get_user_mfa(user_id):
|
||||
user_id=user_id, method_type=AuthMethodType.WEBAUTHN, deleted_at=None,
|
||||
).first()
|
||||
if webauthn_method and webauthn_method.provider_data:
|
||||
for cred in webauthn_method.provider_data.get("credentials", []):
|
||||
# Handle both single credential (direct in provider_data) and multiple credentials (in credentials array)
|
||||
credentials = webauthn_method.provider_data.get("credentials", [])
|
||||
|
||||
# If no credentials array, check if provider_data itself is a single credential
|
||||
if not credentials and "credential_id" in webauthn_method.provider_data:
|
||||
credentials = [webauthn_method.provider_data]
|
||||
|
||||
for cred in credentials:
|
||||
if not cred.get("deleted_at"):
|
||||
mfa_methods.append({
|
||||
"id": cred.get("id") or cred.get("credential_id"),
|
||||
@@ -588,6 +579,8 @@ def admin_remove_user_mfa(user_id, method_type):
|
||||
credential_id = request.args.get("credential_id")
|
||||
if credential_id:
|
||||
credentials = (webauthn_method.provider_data or {}).get("credentials", [])
|
||||
if not credentials and "credential_id" in (webauthn_method.provider_data or {}):
|
||||
credentials = [webauthn_method.provider_data]
|
||||
found = False
|
||||
new_credentials = []
|
||||
for cred in credentials:
|
||||
|
||||
@@ -142,6 +142,55 @@ def get_my_organizations():
|
||||
return api_response(data={"organizations": orgs, "count": len(orgs)}, message="Organizations retrieved successfully")
|
||||
|
||||
|
||||
@api_v1_bp.route("/users/me/organizations/simple", methods=["GET"])
|
||||
@login_required
|
||||
def get_my_organizations_simple():
|
||||
"""Lightweight organization list for CLI tool.
|
||||
|
||||
Returns organizations with CA status indicators for CLI users.
|
||||
"""
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.ssh_ca.ca import CA, CaType
|
||||
|
||||
user = g.current_user
|
||||
memberships = OrganizationMember.query.filter_by(user_id=user.id, deleted_at=None).all()
|
||||
|
||||
orgs = []
|
||||
for membership in memberships:
|
||||
org = membership.organization
|
||||
if not org or org.deleted_at is not None:
|
||||
continue
|
||||
|
||||
# Check for active CAs
|
||||
user_ca = CA.query.filter_by(
|
||||
organization_id=org.id,
|
||||
ca_type=CaType.USER,
|
||||
is_active=True,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
host_ca = CA.query.filter_by(
|
||||
organization_id=org.id,
|
||||
ca_type=CaType.HOST,
|
||||
is_active=True,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
orgs.append({
|
||||
"id": str(org.id),
|
||||
"name": org.name,
|
||||
"slug": getattr(org, 'slug', None),
|
||||
"role": membership.role.value if hasattr(membership.role, "value") else str(membership.role),
|
||||
"has_user_ca": user_ca is not None,
|
||||
"has_host_ca": host_ca is not None,
|
||||
})
|
||||
|
||||
return api_response(
|
||||
data={"organizations": orgs, "count": len(orgs)},
|
||||
message="Organizations retrieved successfully",
|
||||
)
|
||||
|
||||
|
||||
@api_v1_bp.route("/users/me/principals", methods=["GET"])
|
||||
@login_required
|
||||
@full_access_required
|
||||
@@ -182,12 +231,11 @@ def get_my_principals():
|
||||
|
||||
my_principals = []
|
||||
if effective_principal_ids:
|
||||
for p in Principal.query.filter(
|
||||
Principal.id.in_(list(effective_principal_ids)),
|
||||
Principal.deleted_at == None,
|
||||
).all():
|
||||
for p in Principal.query.filter(Principal.id.in_(list(effective_principal_ids)), Principal.deleted_at == None).all():
|
||||
my_principals.append({
|
||||
"id": p.id, "name": p.name, "description": p.description,
|
||||
"id": p.id,
|
||||
"name": p.name,
|
||||
"description": p.description,
|
||||
"direct": p.id in direct_principal_ids,
|
||||
})
|
||||
|
||||
@@ -197,7 +245,8 @@ def get_my_principals():
|
||||
all_principals.append({"id": p.id, "name": p.name, "description": p.description})
|
||||
|
||||
orgs_result.append({
|
||||
"org_id": org.id, "org_name": org.name,
|
||||
"org_id": org.id,
|
||||
"org_name": org.name,
|
||||
"role": role.value if hasattr(role, "value") else role,
|
||||
"is_admin": is_admin,
|
||||
"my_principals": my_principals,
|
||||
@@ -241,6 +290,7 @@ def get_my_pending_invites():
|
||||
|
||||
@api_v1_bp.route("/users/me/memberships", methods=["GET"])
|
||||
@login_required
|
||||
@full_access_required
|
||||
def get_my_memberships():
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.organization.department import DepartmentMembership, DepartmentPrincipal, Department
|
||||
@@ -258,15 +308,15 @@ def get_my_memberships():
|
||||
|
||||
dept_memberships = DepartmentMembership.query.filter_by(user_id=user.id, deleted_at=None).all()
|
||||
user_depts = [
|
||||
dm.department for dm in dept_memberships
|
||||
if dm.department
|
||||
and dm.department.organization_id == org.id
|
||||
and dm.department.deleted_at is None
|
||||
dm.department
|
||||
for dm in dept_memberships
|
||||
if dm.department and dm.department.organization_id == org.id and dm.department.deleted_at is None
|
||||
]
|
||||
|
||||
direct_pm = PrincipalMembership.query.filter_by(user_id=user.id, deleted_at=None).all()
|
||||
direct_principal_ids = {
|
||||
pm.principal_id for pm in direct_pm
|
||||
pm.principal_id
|
||||
for pm in direct_pm
|
||||
if pm.principal and pm.principal.organization_id == org.id and pm.principal.deleted_at is None
|
||||
}
|
||||
|
||||
@@ -279,18 +329,18 @@ def get_my_memberships():
|
||||
all_principal_ids = direct_principal_ids | via_dept_principal_ids
|
||||
principals_list = []
|
||||
if all_principal_ids:
|
||||
for p in Principal.query.filter(
|
||||
Principal.id.in_(list(all_principal_ids)),
|
||||
Principal.deleted_at == None,
|
||||
).all():
|
||||
for p in Principal.query.filter(Principal.id.in_(list(all_principal_ids)), Principal.deleted_at == None).all():
|
||||
principals_list.append({
|
||||
"id": str(p.id), "name": p.name, "description": p.description,
|
||||
"id": str(p.id),
|
||||
"name": p.name,
|
||||
"description": p.description,
|
||||
"via_department": p.id not in direct_principal_ids,
|
||||
})
|
||||
|
||||
role = membership.role
|
||||
orgs_result.append({
|
||||
"org_id": str(org.id), "org_name": org.name,
|
||||
"org_id": str(org.id),
|
||||
"org_name": org.name,
|
||||
"role": role.value if hasattr(role, "value") else role,
|
||||
"departments": [{"id": str(d.id), "name": d.name, "description": d.description} for d in user_depts],
|
||||
"principals": principals_list,
|
||||
|
||||
@@ -0,0 +1,203 @@
|
||||
"""Superadmin authentication and audit decorators."""
|
||||
import logging
|
||||
from functools import wraps
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from flask import request, g
|
||||
|
||||
from gatehouse_app.utils.response import api_response
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def superadmin_required(f):
|
||||
"""Decorator to require superadmin Bearer token authentication.
|
||||
|
||||
Extracts token from Authorization: Bearer {token} header,
|
||||
validates the session against SuperadminSession table,
|
||||
and sets g.current_superadmin and g.superadmin_session.
|
||||
|
||||
Returns 401 if no valid session, 403 if not a superadmin.
|
||||
"""
|
||||
@wraps(f)
|
||||
def decorated_function(*args, **kwargs):
|
||||
# Extract token from Authorization header
|
||||
auth_header = request.headers.get('Authorization')
|
||||
|
||||
if not auth_header:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Authorization header is required",
|
||||
status=401,
|
||||
error_type="AUTH_REQUIRED"
|
||||
)
|
||||
|
||||
# Expect format: "Bearer {token}"
|
||||
parts = auth_header.split()
|
||||
if len(parts) != 2 or parts[0].lower() != 'bearer':
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Invalid authorization format. Use: Bearer {token}",
|
||||
status=401,
|
||||
error_type="INVALID_AUTH_FORMAT"
|
||||
)
|
||||
|
||||
token = parts[1]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from gatehouse_app.models.superadmin import SuperadminSession, Superadmin
|
||||
|
||||
# Get active session by token
|
||||
session = SuperadminSession.query.filter_by(token=token).first()
|
||||
|
||||
if not session:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Invalid or expired session",
|
||||
status=401,
|
||||
error_type="INVALID_TOKEN"
|
||||
)
|
||||
|
||||
# Check if session is active
|
||||
if not session.is_active():
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Session is no longer active",
|
||||
status=401,
|
||||
error_type="SESSION_INACTIVE"
|
||||
)
|
||||
|
||||
# Get the superadmin
|
||||
superadmin = session.superadmin
|
||||
if not superadmin:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Superadmin not found",
|
||||
status=401,
|
||||
error_type="INVALID_TOKEN"
|
||||
)
|
||||
|
||||
# Check if superadmin is active
|
||||
if not superadmin.is_active:
|
||||
return api_response(
|
||||
success=False,
|
||||
message="Superadmin account is disabled",
|
||||
status=403,
|
||||
error_type="ACCOUNT_DISABLED"
|
||||
)
|
||||
|
||||
# Update last_activity_at timestamp
|
||||
session.last_activity_at = datetime.now(timezone.utc)
|
||||
from gatehouse_app import db
|
||||
db.session.commit()
|
||||
|
||||
# Set context variables
|
||||
g.current_superadmin = superadmin
|
||||
g.superadmin_session = session
|
||||
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return decorated_function
|
||||
|
||||
|
||||
def superadmin_audit_log(action, resource_type):
|
||||
"""Decorator to log superadmin actions to SuperadminAuditLog.
|
||||
|
||||
Must be used AFTER @superadmin_required to have access to g.current_superadmin.
|
||||
|
||||
Args:
|
||||
action: The action being performed (e.g., 'update', 'delete', 'create')
|
||||
resource_type: The type of resource being acted on (e.g., 'organization', 'user')
|
||||
|
||||
Captures: superadmin_id, action, resource_type, resource_id, org_id, user_id,
|
||||
ip_address, user_agent, request_id, extra_data
|
||||
"""
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
def decorated_function(*args, **kwargs):
|
||||
# Get superadmin from context (set by @superadmin_required)
|
||||
superadmin = getattr(g, 'current_superadmin', None)
|
||||
session = getattr(g, 'superadmin_session', None)
|
||||
|
||||
if not superadmin:
|
||||
logger.warning(f"superadmin_audit_log used without @superadmin_required on {f.__name__}")
|
||||
return f(*args, **kwargs)
|
||||
|
||||
# Extract resource_id from kwargs if present
|
||||
resource_id = kwargs.get('resource_id') or kwargs.get(f'{resource_type}_id') or None
|
||||
|
||||
# Extract org_id and user_id from kwargs if present
|
||||
org_id = kwargs.get('org_id') or None
|
||||
user_id = kwargs.get('user_id') or None
|
||||
|
||||
# Get IP address and user agent
|
||||
ip_address = request.remote_addr or None
|
||||
user_agent = request.headers.get('User-Agent') or None
|
||||
request_id = request.headers.get('X-Request-ID') or None
|
||||
|
||||
# Get extra data from request body (for POST/PATCH requests)
|
||||
extra_data = None
|
||||
if request.is_json:
|
||||
try:
|
||||
# Exclude sensitive fields
|
||||
body = request.get_json(silent=True) or {}
|
||||
sensitive_fields = {'password', 'password_hash', 'token', 'secret', 'key'}
|
||||
extra_data = {k: v for k, v in body.items() if k not in sensitive_fields}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Get success status (default to True unless an error is raised)
|
||||
success = True
|
||||
error_message = None
|
||||
|
||||
try:
|
||||
result = f(*args, **kwargs)
|
||||
|
||||
# Check if the response indicates failure
|
||||
if hasattr(result, 'get_json'):
|
||||
result_data = result.get_json()
|
||||
if result_data and result_data.get('success') is False:
|
||||
success = False
|
||||
error_message = result_data.get('message')
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
success = False
|
||||
error_message = str(e)
|
||||
raise
|
||||
|
||||
finally:
|
||||
# Log the action
|
||||
try:
|
||||
from gatehouse_app.models.superadmin_audit_log import SuperadminAuditLog
|
||||
from gatehouse_app import db
|
||||
|
||||
audit_entry = SuperadminAuditLog(
|
||||
superadmin_id=superadmin.id,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
extra_data=extra_data,
|
||||
success=success,
|
||||
error_message=error_message
|
||||
)
|
||||
db.session.add(audit_entry)
|
||||
db.session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Superadmin audit: superadmin={superadmin.email} action={action} "
|
||||
f"resource_type={resource_type} resource_id={resource_id} success={success}"
|
||||
)
|
||||
except Exception as e:
|
||||
# Never let audit logging failures break the main operation
|
||||
logger.error(f"Failed to write superadmin audit log: {e}")
|
||||
|
||||
return decorated_function
|
||||
return decorator
|
||||
@@ -1,6 +1,44 @@
|
||||
"""CORS middleware configuration."""
|
||||
from flask import request, make_response
|
||||
|
||||
ALLOWED_METHODS = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
ALLOWED_HEADERS = (
|
||||
"Content-Type, Authorization, X-Requested-With, X-Request-ID, "
|
||||
"Cache-Control, Pragma, X-WebAuthn-Session-Token"
|
||||
)
|
||||
|
||||
|
||||
def _is_origin_allowed(origin, cors_origins):
|
||||
"""Return True if the origin is permitted by the CORS config.
|
||||
|
||||
Handles both wildcard ("*") and explicit origin lists.
|
||||
"""
|
||||
if not origin:
|
||||
return False
|
||||
if cors_origins == "*":
|
||||
return True
|
||||
if isinstance(cors_origins, list):
|
||||
if "*" in cors_origins:
|
||||
return True
|
||||
return origin in cors_origins
|
||||
return False
|
||||
|
||||
|
||||
def _cors_origin_header(cors_origins, request_origin):
|
||||
"""Return the value for Access-Control-Allow-Origin.
|
||||
|
||||
Per the CORS spec, browsers reject ``*`` when credentials are involved,
|
||||
so we echo the request origin when wildcard + credentials is configured.
|
||||
"""
|
||||
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
|
||||
if allow_all and request_origin:
|
||||
return request_origin
|
||||
if allow_all:
|
||||
return "*"
|
||||
if request_origin and request_origin in cors_origins:
|
||||
return request_origin
|
||||
return None
|
||||
|
||||
|
||||
def setup_cors(app):
|
||||
"""
|
||||
@@ -9,6 +47,7 @@ def setup_cors(app):
|
||||
Args:
|
||||
app: Flask application instance
|
||||
"""
|
||||
supports_credentials = app.config.get("CORS_SUPPORTS_CREDENTIALS", True)
|
||||
|
||||
@app.before_request
|
||||
def handle_preflight():
|
||||
@@ -16,49 +55,33 @@ def setup_cors(app):
|
||||
if request.method == "OPTIONS":
|
||||
origin = request.headers.get("Origin")
|
||||
cors_origins = app.config.get("CORS_ORIGINS", [])
|
||||
|
||||
# Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard)
|
||||
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
|
||||
|
||||
if allow_all:
|
||||
response = make_response("", 204)
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma"
|
||||
response.headers["Access-Control-Max-Age"] = "3600"
|
||||
response.headers["Cache-Control"] = "no-cache, no-store"
|
||||
return response
|
||||
elif origin and origin in cors_origins:
|
||||
response = make_response("", 204)
|
||||
response.headers["Access-Control-Allow-Origin"] = origin
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma, X-WebAuthn-Session-Token"
|
||||
|
||||
if not _is_origin_allowed(origin, cors_origins):
|
||||
return None
|
||||
|
||||
response = make_response("", 204)
|
||||
response.headers["Access-Control-Allow-Origin"] = _cors_origin_header(cors_origins, origin)
|
||||
response.headers["Access-Control-Allow-Methods"] = ALLOWED_METHODS
|
||||
response.headers["Access-Control-Allow-Headers"] = ALLOWED_HEADERS
|
||||
if supports_credentials:
|
||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
response.headers["Access-Control-Max-Age"] = "3600"
|
||||
response.headers["Cache-Control"] = "no-cache, no-store"
|
||||
return response
|
||||
response.headers["Access-Control-Max-Age"] = "3600"
|
||||
response.headers["Cache-Control"] = "no-cache, no-store"
|
||||
return response
|
||||
|
||||
@app.after_request
|
||||
def after_request_cors(response):
|
||||
"""Add additional CORS headers if needed."""
|
||||
"""Add CORS headers to non-preflight responses."""
|
||||
origin = request.headers.get("Origin")
|
||||
cors_origins = app.config.get("CORS_ORIGINS", [])
|
||||
|
||||
# Allow all origins if CORS_ORIGINS is "*" (string) or ["*"] (list with wildcard)
|
||||
allow_all = cors_origins == "*" or (isinstance(cors_origins, list) and "*" in cors_origins)
|
||||
|
||||
if allow_all:
|
||||
# When allowing all origins, set header to "*"
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma"
|
||||
response.headers["Access-Control-Max-Age"] = "3600"
|
||||
elif origin and origin in cors_origins:
|
||||
# When allowing specific origins, echo the request origin
|
||||
response.headers["Access-Control-Allow-Origin"] = origin
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization, X-Requested-With, X-Request-ID, Cache-Control, Pragma, X-WebAuthn-Session-Token"
|
||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
allow_origin = _cors_origin_header(cors_origins, origin)
|
||||
if allow_origin:
|
||||
response.headers["Access-Control-Allow-Origin"] = allow_origin
|
||||
response.headers["Access-Control-Allow-Methods"] = ALLOWED_METHODS
|
||||
response.headers["Access-Control-Allow-Headers"] = ALLOWED_HEADERS
|
||||
if supports_credentials:
|
||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
response.headers["Access-Control-Max-Age"] = "3600"
|
||||
|
||||
return response
|
||||
|
||||
@@ -113,6 +113,14 @@ from gatehouse_app.models.zerotier import ( # noqa: F401
|
||||
ZeroTierMembership,
|
||||
KillSwitchEvent,
|
||||
)
|
||||
|
||||
# ── Superadmin ─────────────────────────────────────────────────────────────────
|
||||
from gatehouse_app.models.superadmin import ( # noqa: F401
|
||||
Superadmin,
|
||||
SuperadminSession,
|
||||
SuperadminSessionStatus,
|
||||
)
|
||||
from gatehouse_app.models.superadmin_audit_log import SuperadminAuditLog # noqa: F401
|
||||
from gatehouse_app.models.security.user_security_policy import ( # noqa: F401
|
||||
UserSecurityPolicy,
|
||||
)
|
||||
@@ -175,4 +183,9 @@ __all__ = [
|
||||
"ActivationSession",
|
||||
"ZeroTierMembership",
|
||||
"KillSwitchEvent",
|
||||
# Superadmin
|
||||
"Superadmin",
|
||||
"SuperadminSession",
|
||||
"SuperadminSessionStatus",
|
||||
"SuperadminAuditLog",
|
||||
]
|
||||
|
||||
@@ -374,7 +374,7 @@ class OAuthState(BaseModel):
|
||||
def cleanup_expired(cls) -> None:
|
||||
"""Remove expired OAuth states."""
|
||||
now = datetime.now(timezone.utc)
|
||||
cls.query.filter(cls.expires_at < now).delete()
|
||||
cls.query.filter(cls.expires_at < now).filter(cls.deleted_at == None).update({"deleted_at": now}, synchronize_session=False)
|
||||
db.session.commit()
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
|
||||
@@ -32,7 +32,8 @@ class EmailVerificationToken(BaseModel):
|
||||
|
||||
Any existing unused tokens for this user are invalidated first.
|
||||
"""
|
||||
cls.query.filter_by(user_id=user_id, used_at=None).delete()
|
||||
now = datetime.now(timezone.utc)
|
||||
cls.query.filter_by(user_id=user_id, used_at=None).filter(cls.deleted_at == None).update({"deleted_at": now}, synchronize_session=False)
|
||||
db.session.flush()
|
||||
|
||||
token_value = secrets.token_urlsafe(48)
|
||||
|
||||
@@ -33,7 +33,8 @@ class PasswordResetToken(BaseModel):
|
||||
Any existing unused tokens for this user are invalidated first.
|
||||
"""
|
||||
# Invalidate any existing unused tokens for this user
|
||||
cls.query.filter_by(user_id=user_id, used_at=None).delete()
|
||||
now = datetime.now(timezone.utc)
|
||||
cls.query.filter_by(user_id=user_id, used_at=None).filter(cls.deleted_at == None).update({"deleted_at": now}, synchronize_session=False)
|
||||
db.session.flush()
|
||||
|
||||
token_value = secrets.token_urlsafe(48)
|
||||
|
||||
@@ -13,7 +13,6 @@ class BaseModel(db.Model):
|
||||
db.String(36),
|
||||
primary_key=True,
|
||||
default=lambda: str(uuid.uuid4()),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
)
|
||||
created_at = db.Column(db.DateTime, nullable=False, default=lambda: datetime.now(timezone.utc))
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Billing models package."""
|
||||
from gatehouse_app.models.billing.plan import Plan
|
||||
from gatehouse_app.models.billing.subscription import Subscription, SubscriptionStatus, BillingCycle
|
||||
|
||||
__all__ = ["Plan", "Subscription", "SubscriptionStatus", "BillingCycle"]
|
||||
@@ -0,0 +1,61 @@
|
||||
"""Plan model for subscription tiers."""
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy import Column, String, Integer, Boolean, DateTime, Text
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Plan(BaseModel):
|
||||
"""Subscription plan definition.
|
||||
|
||||
Represents different pricing tiers that organizations can subscribe to.
|
||||
"""
|
||||
__tablename__ = "plans"
|
||||
|
||||
name = Column(String(100), nullable=False)
|
||||
slug = Column(String(50), unique=True, nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
# Pricing in cents
|
||||
price_monthly = Column(Integer, nullable=False, default=0) # Price in cents
|
||||
price_yearly = Column(Integer, nullable=False, default=0) # Price in cents
|
||||
|
||||
# User limits
|
||||
included_users = Column(Integer, nullable=False, default=0) # 0 = unlimited
|
||||
|
||||
# Overage pricing (cents per user over limit)
|
||||
overage_rate_per_user = Column(Integer, nullable=False, default=0)
|
||||
|
||||
# Feature flags (JSON)
|
||||
features = Column(Text, nullable=True) # JSON string
|
||||
|
||||
# Stripe integration
|
||||
stripe_price_id_monthly = Column(String(100), nullable=True)
|
||||
stripe_price_id_yearly = Column(String(100), nullable=True)
|
||||
|
||||
# Active/inactive
|
||||
is_active = Column(Boolean, nullable=False, default=True)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Plan {self.slug}: ${self.price_monthly / 100}/mo>"
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert plan to dictionary."""
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"slug": self.slug,
|
||||
"description": self.description,
|
||||
"price_monthly": self.price_monthly,
|
||||
"price_yearly": self.price_yearly,
|
||||
"included_users": self.included_users,
|
||||
"overage_rate_per_user": self.overage_rate_per_user,
|
||||
"features": self.features,
|
||||
"stripe_price_id_monthly": self.stripe_price_id_monthly,
|
||||
"stripe_price_id_yearly": self.stripe_price_id_yearly,
|
||||
"is_active": self.is_active,
|
||||
"created_at": self.created_at.isoformat() + "Z" if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() + "Z" if self.updated_at else None,
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
"""Subscription model for organization billing."""
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy import Column, String, Integer, Boolean, DateTime, ForeignKey, Enum
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
import enum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SubscriptionStatus(enum.Enum):
|
||||
"""Subscription status values."""
|
||||
TRIAL = "trial"
|
||||
ACTIVE = "active"
|
||||
PAST_DUE = "past_due"
|
||||
CANCELLED = "cancelled"
|
||||
SUSPENDED = "suspended"
|
||||
|
||||
|
||||
class BillingCycle(enum.Enum):
|
||||
"""Billing cycle values."""
|
||||
MONTHLY = "monthly"
|
||||
YEARLY = "yearly"
|
||||
|
||||
|
||||
class Subscription(BaseModel):
|
||||
"""Organization subscription record.
|
||||
|
||||
Links an organization to a plan and tracks billing state.
|
||||
"""
|
||||
__tablename__ = "subscriptions"
|
||||
|
||||
# Organization relation
|
||||
organization_id = Column(
|
||||
String(36),
|
||||
ForeignKey("organizations.id", ondelete="CASCADE"),
|
||||
unique=True,
|
||||
nullable=False
|
||||
)
|
||||
|
||||
# Plan relation
|
||||
plan_id = Column(
|
||||
String(36),
|
||||
ForeignKey("plans.id", ondelete="SET NULL"),
|
||||
nullable=True
|
||||
)
|
||||
|
||||
# Status
|
||||
status = Column(
|
||||
Enum(SubscriptionStatus, name="subscription_status"),
|
||||
nullable=False,
|
||||
default=SubscriptionStatus.TRIAL
|
||||
)
|
||||
|
||||
# Billing
|
||||
billing_cycle = Column(
|
||||
Enum(BillingCycle, name="billing_cycle"),
|
||||
nullable=False,
|
||||
default=BillingCycle.MONTHLY
|
||||
)
|
||||
|
||||
# Period dates
|
||||
current_period_start = Column(DateTime, nullable=True)
|
||||
current_period_end = Column(DateTime, nullable=True)
|
||||
|
||||
# Trial
|
||||
trial_ends_at = Column(DateTime, nullable=True)
|
||||
|
||||
# Stripe
|
||||
stripe_subscription_id = Column(String(100), nullable=True)
|
||||
|
||||
# Overage
|
||||
overage_enabled = Column(Boolean, nullable=False, default=True)
|
||||
|
||||
# Cancellation
|
||||
cancelled_at = Column(DateTime, nullable=True)
|
||||
cancel_at_period_end = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Subscription org={self.organization_id} status={self.status.value}>"
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert subscription to dictionary."""
|
||||
return {
|
||||
"id": self.id,
|
||||
"organization_id": self.organization_id,
|
||||
"plan_id": self.plan_id,
|
||||
"status": self.status.value if self.status else None,
|
||||
"billing_cycle": self.billing_cycle.value if self.billing_cycle else None,
|
||||
"current_period_start": self.current_period_start.isoformat() + "Z" if self.current_period_start else None,
|
||||
"current_period_end": self.current_period_end.isoformat() + "Z" if self.current_period_end else None,
|
||||
"trial_ends_at": self.trial_ends_at.isoformat() + "Z" if self.trial_ends_at else None,
|
||||
"stripe_subscription_id": self.stripe_subscription_id,
|
||||
"overage_enabled": self.overage_enabled,
|
||||
"cancelled_at": self.cancelled_at.isoformat() + "Z" if self.cancelled_at else None,
|
||||
"cancel_at_period_end": self.cancel_at_period_end,
|
||||
"created_at": self.created_at.isoformat() + "Z" if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() + "Z" if self.updated_at else None,
|
||||
}
|
||||
@@ -78,3 +78,43 @@ class Organization(BaseModel):
|
||||
).first()
|
||||
is not None
|
||||
)
|
||||
def get_active_members(self):
|
||||
"""Get active (non-deleted) organization members.
|
||||
|
||||
Returns:
|
||||
List of OrganizationMember instances where deleted_at is None.
|
||||
"""
|
||||
return [m for m in self.members if m.deleted_at is None]
|
||||
|
||||
def get_active_departments(self):
|
||||
"""Get active (non-deleted) departments.
|
||||
|
||||
Returns:
|
||||
List of Department instances where deleted_at is None.
|
||||
"""
|
||||
return [d for d in self.departments if d.deleted_at is None]
|
||||
|
||||
def get_active_principals(self):
|
||||
"""Get active (non-deleted) principals.
|
||||
|
||||
Returns:
|
||||
List of Principal instances where deleted_at is None.
|
||||
"""
|
||||
return [p for p in self.principals if p.deleted_at is None]
|
||||
|
||||
def get_active_cas(self):
|
||||
"""Get active (non-deleted) certificate authorities.
|
||||
|
||||
Returns:
|
||||
List of CA instances where deleted_at is None.
|
||||
"""
|
||||
return [ca for ca in self.cas if ca.deleted_at is None]
|
||||
|
||||
def get_active_api_keys(self):
|
||||
"""Get active (non-deleted) API keys.
|
||||
|
||||
Returns:
|
||||
List of OrganizationApiKey instances where deleted_at is None.
|
||||
"""
|
||||
return [k for k in self.api_keys if k.deleted_at is None]
|
||||
|
||||
|
||||
@@ -29,6 +29,14 @@ class CertificateAuditLog(BaseModel):
|
||||
index=True,
|
||||
)
|
||||
|
||||
# The organization that owns the CA (null for system CAs)
|
||||
organization_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("organizations.id"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Action type (e.g., "signed", "revoked", "validated", "requested")
|
||||
action = db.Column(db.String(50), nullable=False, index=True)
|
||||
|
||||
@@ -50,6 +58,7 @@ class CertificateAuditLog(BaseModel):
|
||||
# Relationships
|
||||
certificate = db.relationship("SSHCertificate", back_populates="audit_logs")
|
||||
user = db.relationship("User")
|
||||
organization = db.relationship("Organization")
|
||||
|
||||
__table_args__ = (
|
||||
db.Index("idx_cert_audit_cert_action", "certificate_id", "action"),
|
||||
@@ -68,6 +77,7 @@ class CertificateAuditLog(BaseModel):
|
||||
certificate_id: str,
|
||||
action: str,
|
||||
user_id: str = None,
|
||||
organization_id: str = None,
|
||||
**kwargs,
|
||||
) -> "CertificateAuditLog":
|
||||
"""Create a certificate audit log entry.
|
||||
@@ -76,6 +86,7 @@ class CertificateAuditLog(BaseModel):
|
||||
certificate_id: ID of the certificate
|
||||
action: Action type (e.g., "signed", "revoked")
|
||||
user_id: ID of the user performing the action (optional)
|
||||
organization_id: ID of the organization that owns the CA (optional)
|
||||
**kwargs: Additional fields (ip_address, user_agent, message, etc.)
|
||||
|
||||
Returns:
|
||||
@@ -85,6 +96,7 @@ class CertificateAuditLog(BaseModel):
|
||||
certificate_id=certificate_id,
|
||||
action=action,
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
**kwargs,
|
||||
)
|
||||
log_entry.save()
|
||||
|
||||
@@ -21,10 +21,10 @@ class SSHKey(BaseModel):
|
||||
)
|
||||
|
||||
# SSH key payload in OpenSSH format (e.g., "ssh-ed25519 AAAAB3Nz...")
|
||||
payload = db.Column(db.Text, nullable=False, unique=True)
|
||||
payload = db.Column(db.Text, nullable=False)
|
||||
|
||||
# SHA256 fingerprint for quick comparison and deduplication
|
||||
fingerprint = db.Column(db.String(255), nullable=False, unique=True, index=True)
|
||||
fingerprint = db.Column(db.String(255), nullable=False, index=True)
|
||||
|
||||
# Optional human-readable description (e.g., "My laptop key")
|
||||
description = db.Column(db.String(255), nullable=True)
|
||||
@@ -53,6 +53,7 @@ class SSHKey(BaseModel):
|
||||
|
||||
__table_args__ = (
|
||||
db.Index("idx_ssh_key_user_verified", "user_id", "verified"),
|
||||
db.UniqueConstraint('user_id', 'fingerprint', name='uix_user_fingerprint'),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Superadmin models."""
|
||||
from gatehouse_app.models.superadmin.superadmin import Superadmin
|
||||
from gatehouse_app.models.superadmin.superadmin_session import SuperadminSession, SuperadminSessionStatus
|
||||
|
||||
__all__ = ["Superadmin", "SuperadminSession", "SuperadminSessionStatus"]
|
||||
@@ -0,0 +1,56 @@
|
||||
"""Superadmin model."""
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Superadmin(BaseModel):
|
||||
"""Superadmin model for SaaS platform operators.
|
||||
|
||||
Completely separate from User model - has its own email/password auth.
|
||||
"""
|
||||
|
||||
__tablename__ = "superadmins"
|
||||
|
||||
email = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
password_hash = db.Column(db.String(255), nullable=False)
|
||||
full_name = db.Column(db.String(255), nullable=True)
|
||||
is_active = db.Column(db.Boolean, default=True, nullable=False)
|
||||
last_login_at = db.Column(db.DateTime, nullable=True)
|
||||
|
||||
# Relationship to sessions
|
||||
sessions = db.relationship(
|
||||
"SuperadminSession",
|
||||
back_populates="superadmin",
|
||||
cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Relationship to audit logs
|
||||
audit_logs = db.relationship(
|
||||
"SuperadminAuditLog",
|
||||
back_populates="superadmin",
|
||||
cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Superadmin {self.email}>"
|
||||
|
||||
def has_password_auth(self):
|
||||
"""Check if superadmin has password authentication."""
|
||||
return bool(self.password_hash)
|
||||
|
||||
def has_totp_enabled(self):
|
||||
"""Check if superadmin has TOTP enabled."""
|
||||
# TODO: Implement TOTP for superadmin if needed
|
||||
return False
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
exclude.append("password_hash")
|
||||
return super().to_dict(exclude=exclude)
|
||||
@@ -0,0 +1,80 @@
|
||||
"""Superadmin session model."""
|
||||
import logging
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SuperadminSessionStatus:
|
||||
"""Session status constants."""
|
||||
ACTIVE = "active"
|
||||
REVOKED = "revoked"
|
||||
EXPIRED = "expired"
|
||||
|
||||
|
||||
class SuperadminSession(BaseModel):
|
||||
"""Session model for superadmin authentication."""
|
||||
|
||||
__tablename__ = "superadmin_sessions"
|
||||
|
||||
superadmin_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("superadmins.id"),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
token = db.Column(db.String(255), unique=True, nullable=False, index=True)
|
||||
expires_at = db.Column(db.DateTime, nullable=False)
|
||||
last_activity_at = db.Column(
|
||||
db.DateTime,
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
ip_address = db.Column(db.String(45), nullable=True)
|
||||
user_agent = db.Column(db.Text, nullable=True)
|
||||
revoked_at = db.Column(db.DateTime, nullable=True)
|
||||
revoked_reason = db.Column(db.String(255), nullable=True)
|
||||
|
||||
# Relationship
|
||||
superadmin = db.relationship("Superadmin", back_populates="sessions")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<SuperadminSession superadmin_id={self.superadmin_id}>"
|
||||
|
||||
def is_active(self):
|
||||
"""Check if session is currently active."""
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return (
|
||||
self.deleted_at is None
|
||||
and self.revoked_at is None
|
||||
and expires_at > now
|
||||
)
|
||||
|
||||
def is_expired(self):
|
||||
"""Check if session has expired."""
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return now > expires_at
|
||||
|
||||
def revoke(self, reason: str = None):
|
||||
"""Revoke the session."""
|
||||
self.revoked_at = datetime.now(timezone.utc)
|
||||
if reason:
|
||||
self.revoked_reason = reason
|
||||
from gatehouse_app import db
|
||||
db.session.commit()
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary, excluding sensitive fields."""
|
||||
exclude = exclude or []
|
||||
exclude.append("token")
|
||||
return super().to_dict(exclude=exclude)
|
||||
@@ -0,0 +1,49 @@
|
||||
"""Superadmin audit log model."""
|
||||
import logging
|
||||
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SuperadminAuditLog(BaseModel):
|
||||
"""Audit log for superadmin actions.
|
||||
|
||||
Records every action performed by superadmins for security and compliance.
|
||||
"""
|
||||
|
||||
__tablename__ = "superadmin_audit_logs"
|
||||
|
||||
superadmin_id = db.Column(
|
||||
db.String(36),
|
||||
db.ForeignKey("superadmins.id"),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
action = db.Column(db.String(100), nullable=False, index=True)
|
||||
resource_type = db.Column(db.String(50), nullable=False, index=True)
|
||||
resource_id = db.Column(db.String(36), nullable=True, index=True)
|
||||
org_id = db.Column(db.String(36), nullable=True, index=True)
|
||||
user_id = db.Column(db.String(36), nullable=True, index=True)
|
||||
ip_address = db.Column(db.String(45), nullable=True)
|
||||
user_agent = db.Column(db.Text, nullable=True)
|
||||
request_id = db.Column(db.String(100), nullable=True)
|
||||
extra_data = db.Column(db.JSON, nullable=True)
|
||||
success = db.Column(db.Boolean, default=True, nullable=False)
|
||||
error_message = db.Column(db.String(500), nullable=True)
|
||||
|
||||
# Relationship
|
||||
superadmin = db.relationship("Superadmin", back_populates="audit_logs")
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"<SuperadminAuditLog superadmin={self.superadmin_id} "
|
||||
f"action={self.action} resource={self.resource_type}/{self.resource_id}>"
|
||||
)
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Convert to dictionary."""
|
||||
exclude = exclude or []
|
||||
return super().to_dict(exclude=exclude)
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Session model."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from flask import current_app
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.base import BaseModel
|
||||
from gatehouse_app.utils.constants import SessionStatus
|
||||
@@ -38,33 +39,71 @@ class Session(BaseModel):
|
||||
return f"<Session user_id={self.user_id} status={self.status}>"
|
||||
|
||||
def is_active(self):
|
||||
"""Check if session is currently active."""
|
||||
"""Check if session is currently active.
|
||||
|
||||
Sessions are evaluated against two independent timeouts:
|
||||
- Idle timeout: expires if no request has been made within
|
||||
SESSION_IDLE_TIMEOUT seconds (default 15 min).
|
||||
- Absolute timeout: expires if SESSION_ABSOLUTE_TIMEOUT seconds
|
||||
have elapsed since the session was created (default 8 h),
|
||||
regardless of activity.
|
||||
|
||||
A session must satisfy *both* constraints to remain active.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
created_at = self.created_at
|
||||
last_activity_at = self.last_activity_at
|
||||
|
||||
if created_at.tzinfo is None:
|
||||
created_at = created_at.replace(tzinfo=timezone.utc)
|
||||
if last_activity_at.tzinfo is None:
|
||||
last_activity_at = last_activity_at.replace(tzinfo=timezone.utc)
|
||||
|
||||
idle_timeout = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
|
||||
absolute_timeout = current_app.config.get("SESSION_ABSOLUTE_TIMEOUT", 28800)
|
||||
|
||||
idle_expires_at = last_activity_at + timedelta(seconds=idle_timeout)
|
||||
absolute_expires_at = created_at + timedelta(seconds=absolute_timeout)
|
||||
|
||||
return (
|
||||
self.status == SessionStatus.ACTIVE
|
||||
and expires_at > now
|
||||
and now < idle_expires_at
|
||||
and now < absolute_expires_at
|
||||
and self.deleted_at is None
|
||||
)
|
||||
|
||||
def is_expired(self):
|
||||
"""Check if session has expired."""
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
return now > expires_at
|
||||
"""Check if session has expired (either idle or absolute)."""
|
||||
return not self.is_active() and self.status != SessionStatus.REVOKED
|
||||
|
||||
def refresh(self, duration_seconds: int = 86400):
|
||||
"""Refresh session expiration.
|
||||
def refresh(self, duration_seconds: int = None):
|
||||
"""Extend the session expiration using a sliding window.
|
||||
|
||||
The new ``expires_at`` is set to *now + idle timeout*, but is
|
||||
capped so that the session never exceeds the absolute lifetime
|
||||
(``created_at + absolute timeout``).
|
||||
|
||||
Args:
|
||||
duration_seconds: New session duration in seconds
|
||||
duration_seconds: Override for the idle timeout. When *None*
|
||||
(the common case), the value is read from
|
||||
``SESSION_IDLE_TIMEOUT`` in the Flask config.
|
||||
"""
|
||||
self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=duration_seconds)
|
||||
self.last_activity_at = datetime.now(timezone.utc)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
if duration_seconds is None:
|
||||
duration_seconds = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
|
||||
|
||||
absolute_timeout = current_app.config.get("SESSION_ABSOLUTE_TIMEOUT", 28800)
|
||||
|
||||
idle_expires_at = now + timedelta(seconds=duration_seconds)
|
||||
|
||||
created_at = self.created_at
|
||||
if created_at.tzinfo is None:
|
||||
created_at = created_at.replace(tzinfo=timezone.utc)
|
||||
absolute_expires_at = created_at + timedelta(seconds=absolute_timeout)
|
||||
|
||||
self.expires_at = min(idle_expires_at, absolute_expires_at)
|
||||
self.last_activity_at = now
|
||||
db.session.commit()
|
||||
|
||||
def revoke(self, reason: str = None):
|
||||
|
||||
@@ -116,9 +116,64 @@ class User(BaseModel):
|
||||
is not None
|
||||
)
|
||||
|
||||
def get_active_memberships(self):
|
||||
"""Get active (non-deleted) organization memberships with active organizations.
|
||||
|
||||
Returns:
|
||||
List of OrganizationMember instances where:
|
||||
- membership.deleted_at is None
|
||||
- organization exists and organization.deleted_at is None
|
||||
"""
|
||||
return [
|
||||
m for m in self.organization_memberships
|
||||
if m.deleted_at is None
|
||||
and m.organization
|
||||
and m.organization.deleted_at is None
|
||||
]
|
||||
|
||||
def get_organizations(self):
|
||||
"""Get all organizations the user is a member of."""
|
||||
return [membership.organization for membership in self.organization_memberships]
|
||||
"""Get all active organizations the user is a member of."""
|
||||
return [membership.organization for membership in self.get_active_memberships()]
|
||||
def get_active_ssh_keys(self):
|
||||
"""Get active (non-deleted) SSH keys.
|
||||
|
||||
Returns:
|
||||
List of SSHKey instances where deleted_at is None.
|
||||
"""
|
||||
return [k for k in self.ssh_keys if k.deleted_at is None]
|
||||
|
||||
def get_active_auth_methods(self):
|
||||
"""Get active (non-deleted) authentication methods.
|
||||
|
||||
Returns:
|
||||
List of AuthenticationMethod instances where deleted_at is None.
|
||||
"""
|
||||
return [m for m in self.authentication_methods if m.deleted_at is None]
|
||||
|
||||
def get_active_department_memberships(self):
|
||||
"""Get active (non-deleted) department memberships.
|
||||
|
||||
Returns:
|
||||
List of DepartmentMembership instances where deleted_at is None.
|
||||
"""
|
||||
return [m for m in self.department_memberships if m.deleted_at is None]
|
||||
|
||||
def get_active_principal_memberships(self):
|
||||
"""Get active (non-deleted) principal memberships.
|
||||
|
||||
Returns:
|
||||
List of PrincipalMembership instances where deleted_at is None.
|
||||
"""
|
||||
return [m for m in self.principal_memberships if m.deleted_at is None]
|
||||
|
||||
def get_active_ca_permissions(self):
|
||||
"""Get active (non-deleted) CA permissions.
|
||||
|
||||
Returns:
|
||||
List of CAPermission instances where deleted_at is None.
|
||||
"""
|
||||
return [p for p in self.ca_permissions if p.deleted_at is None]
|
||||
|
||||
|
||||
def has_totp_enabled(self) -> bool:
|
||||
"""Check if user has TOTP enabled and verified.
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
"""Contact form validation schemas."""
|
||||
import logging
|
||||
import re
|
||||
|
||||
from marshmallow import Schema, fields, validate, validates_schema, ValidationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContactSchema(Schema):
|
||||
"""Schema for contact form submissions."""
|
||||
|
||||
email = fields.Email(required=True)
|
||||
name = fields.Str(
|
||||
allow_none=True,
|
||||
load_default=None,
|
||||
validate=validate.Length(max=255),
|
||||
)
|
||||
company = fields.Str(
|
||||
allow_none=True,
|
||||
load_default=None,
|
||||
validate=validate.Length(max=255),
|
||||
)
|
||||
enquiry_type = fields.Str(
|
||||
required=True,
|
||||
validate=validate.OneOf(["demo_request", "sales_enquiry", "general", "support"]),
|
||||
)
|
||||
message = fields.Str(
|
||||
allow_none=True,
|
||||
load_default=None,
|
||||
validate=validate.Length(max=2000),
|
||||
)
|
||||
interest_area = fields.Str(
|
||||
allow_none=True,
|
||||
load_default=None,
|
||||
validate=validate.Length(max=100),
|
||||
)
|
||||
_hp = fields.Str(
|
||||
allow_none=True,
|
||||
load_default=None,
|
||||
load_from="_hp",
|
||||
)
|
||||
|
||||
@validates_schema
|
||||
def sanitize_html(self, data, **kwargs):
|
||||
"""Strip HTML tags from all text fields to prevent XSS."""
|
||||
text_fields = ["name", "company", "message", "interest_area"]
|
||||
for field in text_fields:
|
||||
value = data.get(field)
|
||||
if value and isinstance(value, str):
|
||||
data[field] = re.sub(r"<[^>]*>", "", value)
|
||||
@@ -9,6 +9,8 @@ from gatehouse_app.services.oidc_jwks_service import OIDCJWKSService
|
||||
from gatehouse_app.services.oidc_token_service import OIDCTokenService
|
||||
from gatehouse_app.services.oidc_session_service import OIDCSessionService
|
||||
from gatehouse_app.services.oidc_audit_service import OIDCAuditService
|
||||
from gatehouse_app.services.superadmin_auth_service import SuperadminAuthService
|
||||
from gatehouse_app.services.superadmin_organization_service import SuperadminOrganizationService
|
||||
|
||||
__all__ = [
|
||||
"AuthService",
|
||||
@@ -22,4 +24,6 @@ __all__ = [
|
||||
"OIDCTokenService",
|
||||
"OIDCSessionService",
|
||||
"OIDCAuditService",
|
||||
"SuperadminAuthService",
|
||||
"SuperadminOrganizationService",
|
||||
]
|
||||
|
||||
@@ -36,9 +36,9 @@ class AuthService:
|
||||
Raises:
|
||||
EmailAlreadyExistsError: If email is already registered
|
||||
"""
|
||||
# Check if email already exists
|
||||
existing_user = User.query.filter_by(email=email.lower()).first()
|
||||
if existing_user and existing_user.deleted_at is None:
|
||||
# Check if email already exists
|
||||
existing_user = User.query.filter_by(email=email.lower(), deleted_at=None).first()
|
||||
if existing_user:
|
||||
raise EmailAlreadyExistsError()
|
||||
|
||||
# Create user
|
||||
@@ -140,18 +140,26 @@ class AuthService:
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def create_session(user, duration_seconds=86400, is_compliance_only=False):
|
||||
def create_session(user, duration_seconds=None, is_compliance_only=False):
|
||||
"""
|
||||
Create a new session for the user.
|
||||
|
||||
Args:
|
||||
user: User instance
|
||||
duration_seconds: Session duration in seconds
|
||||
duration_seconds: Session idle timeout in seconds.
|
||||
When None, defaults to SESSION_IDLE_TIMEOUT from config.
|
||||
The absolute lifetime is always enforced by Session.is_active()
|
||||
regardless of this value.
|
||||
is_compliance_only: Whether this is a compliance-only session (limited access)
|
||||
|
||||
Returns:
|
||||
Session instance
|
||||
"""
|
||||
from flask import current_app
|
||||
|
||||
if duration_seconds is None:
|
||||
duration_seconds = current_app.config.get("SESSION_IDLE_TIMEOUT", 900)
|
||||
|
||||
# Generate session token
|
||||
token = secrets.token_urlsafe(32)
|
||||
|
||||
@@ -280,12 +288,11 @@ class AuthService:
|
||||
raise ConflictError("TOTP is already enabled for this account")
|
||||
|
||||
# Clean up any existing unverified TOTP enrollment attempts
|
||||
# Use hard delete for unverified methods since they're incomplete enrollment attempts
|
||||
# Soft delete for unverified methods since they're incomplete enrollment attempts
|
||||
existing_totp_method = user.get_totp_method()
|
||||
if existing_totp_method and not existing_totp_method.verified:
|
||||
logger.debug(f"Removing existing unverified TOTP method for user {user.id}")
|
||||
db.session.delete(existing_totp_method) # Hard delete - unverified methods are temporary
|
||||
db.session.commit() # Commit to ensure deletion before creating new record
|
||||
existing_totp_method.delete(soft=True) # Soft delete - unverified methods are temporary
|
||||
|
||||
# Generate TOTP secret
|
||||
secret = TOTPService.generate_secret()
|
||||
|
||||
@@ -0,0 +1,192 @@
|
||||
"""Billing service for superadmin operations."""
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.billing.plan import Plan
|
||||
from gatehouse_app.models.billing.subscription import Subscription, SubscriptionStatus, BillingCycle
|
||||
from gatehouse_app.extensions import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BillingService:
|
||||
"""Service for billing operations."""
|
||||
|
||||
@staticmethod
|
||||
def get_plan(plan_id: str) -> Plan:
|
||||
"""Get a plan by ID."""
|
||||
plan = Plan.query.get(plan_id)
|
||||
if not plan:
|
||||
raise ValueError("Plan not found")
|
||||
return plan
|
||||
|
||||
@staticmethod
|
||||
def list_plans() -> list:
|
||||
"""List all active plans."""
|
||||
return Plan.query.filter(Plan.is_active == True).order_by(Plan.price_monthly.asc()).all()
|
||||
|
||||
@staticmethod
|
||||
def create_subscription(
|
||||
organization_id: str,
|
||||
plan_id: str,
|
||||
billing_cycle: str = "monthly"
|
||||
) -> Subscription:
|
||||
"""Create a new subscription for an organization.
|
||||
|
||||
Args:
|
||||
organization_id: Organization UUID
|
||||
plan_id: Plan UUID
|
||||
billing_cycle: 'monthly' or 'yearly'
|
||||
|
||||
Returns:
|
||||
New subscription
|
||||
"""
|
||||
org = Organization.query.get(organization_id)
|
||||
if not org:
|
||||
raise ValueError("Organization not found")
|
||||
|
||||
plan = Plan.query.get(plan_id)
|
||||
if not plan:
|
||||
raise ValueError("Plan not found")
|
||||
|
||||
# Check if subscription already exists
|
||||
existing = Subscription.query.filter_by(organization_id=organization_id, deleted_at=None).first()
|
||||
if existing:
|
||||
raise ValueError("Organization already has a subscription")
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Calculate period
|
||||
if billing_cycle == "yearly":
|
||||
period_end = now + timedelta(days=365)
|
||||
else:
|
||||
period_end = now + timedelta(days=30)
|
||||
|
||||
subscription = Subscription(
|
||||
organization_id=organization_id,
|
||||
plan_id=plan_id,
|
||||
status=SubscriptionStatus.ACTIVE,
|
||||
billing_cycle=BillingCycle.MONTHLY if billing_cycle == "monthly" else BillingCycle.YEARLY,
|
||||
current_period_start=now,
|
||||
current_period_end=period_end,
|
||||
)
|
||||
|
||||
db.session.add(subscription)
|
||||
db.session.commit()
|
||||
|
||||
return subscription
|
||||
|
||||
@staticmethod
|
||||
def change_plan(organization_id: str, new_plan_id: str) -> Subscription:
|
||||
"""Change subscription plan.
|
||||
|
||||
Args:
|
||||
organization_id: Organization UUID
|
||||
new_plan_id: New plan UUID
|
||||
|
||||
Returns:
|
||||
Updated subscription
|
||||
"""
|
||||
subscription = Subscription.query.filter_by(organization_id=organization_id, deleted_at=None).first()
|
||||
if not subscription:
|
||||
raise ValueError("No subscription found for organization")
|
||||
|
||||
new_plan = Plan.query.get(new_plan_id)
|
||||
if not new_plan:
|
||||
raise ValueError("Plan not found")
|
||||
|
||||
subscription.plan_id = new_plan_id
|
||||
db.session.commit()
|
||||
|
||||
return subscription
|
||||
|
||||
@staticmethod
|
||||
def cancel_subscription(organization_id: str) -> Subscription:
|
||||
"""Cancel subscription at period end.
|
||||
|
||||
Args:
|
||||
organization_id: Organization UUID
|
||||
|
||||
Returns:
|
||||
Updated subscription
|
||||
"""
|
||||
subscription = Subscription.query.filter_by(organization_id=organization_id, deleted_at=None).first()
|
||||
if not subscription:
|
||||
raise ValueError("No subscription found for organization")
|
||||
|
||||
subscription.cancel_at_period_end = True
|
||||
subscription.status = SubscriptionStatus.CANCELLED
|
||||
db.session.commit()
|
||||
|
||||
return subscription
|
||||
|
||||
@staticmethod
|
||||
def extend_trial(organization_id: str, days: int) -> Subscription:
|
||||
"""Extend trial period.
|
||||
|
||||
Args:
|
||||
organization_id: Organization UUID
|
||||
days: Number of days to extend
|
||||
|
||||
Returns:
|
||||
Updated subscription
|
||||
"""
|
||||
subscription = Subscription.query.filter_by(organization_id=organization_id, deleted_at=None).first()
|
||||
if not subscription:
|
||||
raise ValueError("No subscription found for organization")
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
if subscription.trial_ends_at:
|
||||
subscription.trial_ends_at = subscription.trial_ends_at + timedelta(days=days)
|
||||
else:
|
||||
subscription.trial_ends_at = now + timedelta(days=days)
|
||||
|
||||
subscription.status = SubscriptionStatus.TRIAL
|
||||
db.session.commit()
|
||||
|
||||
return subscription
|
||||
|
||||
@staticmethod
|
||||
def calculate_overage(organization_id: str) -> dict:
|
||||
"""Calculate overage charges for an organization.
|
||||
|
||||
Args:
|
||||
organization_id: Organization UUID
|
||||
|
||||
Returns:
|
||||
Overage calculation with details
|
||||
"""
|
||||
subscription = Subscription.query.filter_by(organization_id=organization_id, deleted_at=None).first()
|
||||
if not subscription:
|
||||
return {"has_overage": False, "overage_cost": 0, "user_count": 0, "included_users": 0}
|
||||
|
||||
plan = Plan.query.get(subscription.plan_id) if subscription.plan_id else None
|
||||
if not plan:
|
||||
return {"has_overage": False, "overage_cost": 0, "user_count": 0, "included_users": 0}
|
||||
|
||||
# Count current users
|
||||
user_count = OrganizationMember.query.filter(
|
||||
OrganizationMember.organization_id == organization_id,
|
||||
OrganizationMember.deleted_at.is_(None),
|
||||
).count()
|
||||
|
||||
included_users = plan.included_users
|
||||
overage_users = max(0, user_count - included_users)
|
||||
|
||||
if overage_users > 0 and plan.overage_rate_per_user > 0:
|
||||
overage_cost = overage_users * plan.overage_rate_per_user
|
||||
has_overage = True
|
||||
else:
|
||||
overage_cost = 0
|
||||
has_overage = False
|
||||
|
||||
return {
|
||||
"has_overage": has_overage,
|
||||
"user_count": user_count,
|
||||
"included_users": included_users,
|
||||
"overage_users": overage_users,
|
||||
"overage_rate_per_user": plan.overage_rate_per_user,
|
||||
"overage_cost": overage_cost,
|
||||
}
|
||||
@@ -496,3 +496,69 @@ def build_email_verification_resend_html(
|
||||
{get_alert_box("If you didn't request this, you can safely ignore this email.", "info", "🔒")}
|
||||
'''
|
||||
return get_base_html(content, "Verify your Secuird email address", "Please verify your email address")
|
||||
|
||||
|
||||
def build_contact_enquiry_html(
|
||||
enquiry_type: str,
|
||||
submitter_email: str,
|
||||
name: Optional[str],
|
||||
company: Optional[str],
|
||||
interest_area: Optional[str],
|
||||
message: Optional[str],
|
||||
) -> str:
|
||||
"""Build a contact enquiry notification email.
|
||||
|
||||
Args:
|
||||
enquiry_type: One of demo_request, sales_enquiry, general, support
|
||||
submitter_email: Email address of the person submitting the enquiry
|
||||
name: Full name of the submitter (optional)
|
||||
company: Company name (optional)
|
||||
interest_area: Area of interest (optional)
|
||||
message: Free-text message (optional)
|
||||
|
||||
Returns:
|
||||
HTML email string
|
||||
"""
|
||||
# Map enquiry types to display labels and colors
|
||||
type_labels = {
|
||||
"demo_request": ("Demo Request", "info"),
|
||||
"sales_enquiry": ("Sales Enquiry", "success"),
|
||||
"general": ("General Enquiry", "info"),
|
||||
"support": ("Support Request", "warning"),
|
||||
}
|
||||
type_label, alert_type = type_labels.get(enquiry_type, ("Enquiry", "info"))
|
||||
|
||||
name_display = name if name else "Not provided"
|
||||
company_display = company if company else "Not provided"
|
||||
interest_display = interest_area if interest_area else "Not provided"
|
||||
message_display = message if message else "No message provided"
|
||||
|
||||
# Build details table
|
||||
details_rows = f"""
|
||||
{get_detail_row("Enquiry Type", type_label)}
|
||||
{get_detail_row("Submitter Email", submitter_email)}
|
||||
{get_detail_row("Name", name_display)}
|
||||
{get_detail_row("Company", company_display)}
|
||||
{get_detail_row("Interest Area", interest_display)}
|
||||
"""
|
||||
|
||||
content = f'''
|
||||
<h2 style="margin: 0 0 20px 0; color: {TEXT_COLOR}; font-size: 20px; font-weight: 600;">New {type_label}</h2>
|
||||
<p style="margin: 0 0 20px 0; color: {TEXT_COLOR}; font-size: 15px; line-height: 1.6;">
|
||||
A new {type_label.lower()} has been submitted through the Secuird website.
|
||||
</p>
|
||||
{get_alert_box(f"Enquiry type: <strong>{type_label}</strong>", alert_type, "📬")}
|
||||
<table role="presentation" width="100%" cellspacing="0" cellpadding="0" style="margin: 20px 0; background-color: {BACKGROUND_COLOR}; border-radius: 8px;">
|
||||
<tr>
|
||||
<td style="padding: 20px;">
|
||||
<h3 style="margin: 0 0 16px 0; color: {TEXT_COLOR}; font-size: 14px; font-weight: 600;">Enquiry Details</h3>
|
||||
<table role="presentation" width="100%" cellspacing="0" cellpadding="0">
|
||||
{details_rows}
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
<h3 style="margin: 20px 0 12px 0; color: {TEXT_COLOR}; font-size: 14px; font-weight: 600;">Message</h3>
|
||||
<p style="margin: 0; color: {TEXT_COLOR}; font-size: 14px; line-height: 1.6; white-space: pre-wrap;">{message_display}</p>
|
||||
'''
|
||||
return get_base_html(content, f"Secuird Website: {type_label}", f"New {type_label} from {submitter_email}")
|
||||
|
||||
@@ -42,7 +42,7 @@ class ExternalAuthService:
|
||||
provider_type_str = provider_type.value if isinstance(provider_type, AuthMethodType) else provider_type
|
||||
|
||||
app_config = ApplicationProviderConfig.query.filter_by(
|
||||
provider_type=provider_type_str
|
||||
provider_type=provider_type_str, deleted_at=None
|
||||
).first()
|
||||
|
||||
if not app_config:
|
||||
@@ -64,6 +64,7 @@ class ExternalAuthService:
|
||||
org_override_obj = OrganizationProviderOverride.query.filter_by(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type_str,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
if org_override_obj and not org_override_obj.is_enabled:
|
||||
|
||||
@@ -14,7 +14,7 @@ def create_app_provider_config(
|
||||
**kwargs,
|
||||
) -> ApplicationProviderConfig:
|
||||
existing = ApplicationProviderConfig.query.filter_by(
|
||||
provider_type=provider_type
|
||||
provider_type=provider_type, deleted_at=None
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
@@ -51,7 +51,7 @@ def update_app_provider_config(
|
||||
**updates,
|
||||
) -> ApplicationProviderConfig:
|
||||
config = ApplicationProviderConfig.query.filter_by(
|
||||
provider_type=provider_type
|
||||
provider_type=provider_type, deleted_at=None
|
||||
).first()
|
||||
|
||||
if not config:
|
||||
@@ -90,7 +90,7 @@ def update_app_provider_config(
|
||||
|
||||
def get_app_provider_config(provider_type: str) -> ApplicationProviderConfig:
|
||||
config = ApplicationProviderConfig.query.filter_by(
|
||||
provider_type=provider_type
|
||||
provider_type=provider_type, deleted_at=None
|
||||
).first()
|
||||
|
||||
if not config:
|
||||
@@ -104,13 +104,13 @@ def get_app_provider_config(provider_type: str) -> ApplicationProviderConfig:
|
||||
|
||||
|
||||
def list_app_provider_configs() -> list:
|
||||
configs = ApplicationProviderConfig.query.all()
|
||||
configs = ApplicationProviderConfig.query.filter_by(deleted_at=None).all()
|
||||
return [config.to_dict() for config in configs]
|
||||
|
||||
|
||||
def delete_app_provider_config(provider_type: str) -> bool:
|
||||
config = ApplicationProviderConfig.query.filter_by(
|
||||
provider_type=provider_type
|
||||
provider_type=provider_type, deleted_at=None
|
||||
).first()
|
||||
|
||||
if not config:
|
||||
|
||||
@@ -219,10 +219,11 @@ def authenticate_with_provider(
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
method_type=provider_type,
|
||||
provider_user_id=user_info["provider_user_id"],
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
if not auth_method:
|
||||
existing_user = User.query.filter_by(email=user_info["email"]).first()
|
||||
existing_user = User.query.filter_by(email=user_info["email"], deleted_at=None).first()
|
||||
|
||||
if existing_user:
|
||||
AuditService.log_external_auth_login_failed(
|
||||
@@ -262,7 +263,7 @@ def authenticate_with_provider(
|
||||
state_record.mark_used()
|
||||
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
session = AuthService.create_session(user=user, organization_id=organization_id)
|
||||
session = AuthService.create_session(user=user)
|
||||
|
||||
AuditService.log_external_auth_login(
|
||||
user_id=user.id,
|
||||
@@ -286,12 +287,13 @@ def unlink_provider(
|
||||
auth_method = AuthenticationMethod.query.filter_by(
|
||||
user_id=user_id,
|
||||
method_type=provider_type,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
if not auth_method:
|
||||
raise ExternalAuthError("Provider not linked", "PROVIDER_NOT_LINKED", 400)
|
||||
|
||||
other_methods = AuthenticationMethod.query.filter_by(user_id=user_id).count()
|
||||
other_methods = AuthenticationMethod.query.filter_by(user_id=user_id, deleted_at=None).count()
|
||||
if other_methods <= 1:
|
||||
raise ExternalAuthError(
|
||||
"Cannot unlink the last authentication method",
|
||||
|
||||
@@ -16,7 +16,7 @@ def create_org_provider_override(
|
||||
**kwargs,
|
||||
) -> OrganizationProviderOverride:
|
||||
app_config = ApplicationProviderConfig.query.filter_by(
|
||||
provider_type=provider_type
|
||||
provider_type=provider_type, deleted_at=None
|
||||
).first()
|
||||
|
||||
if not app_config:
|
||||
@@ -29,6 +29,7 @@ def create_org_provider_override(
|
||||
existing = OrganizationProviderOverride.query.filter_by(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
@@ -69,6 +70,7 @@ def update_org_provider_override(
|
||||
override = OrganizationProviderOverride.query.filter_by(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
if not override:
|
||||
@@ -110,6 +112,7 @@ def get_org_provider_override(
|
||||
override = OrganizationProviderOverride.query.filter_by(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
if not override:
|
||||
@@ -124,7 +127,7 @@ def get_org_provider_override(
|
||||
|
||||
def list_org_provider_overrides(organization_id: str) -> list:
|
||||
overrides = OrganizationProviderOverride.query.filter_by(
|
||||
organization_id=organization_id
|
||||
organization_id=organization_id, deleted_at=None
|
||||
).all()
|
||||
return [override.to_dict() for override in overrides]
|
||||
|
||||
@@ -133,6 +136,7 @@ def delete_org_provider_override(organization_id: str, provider_type: str) -> bo
|
||||
override = OrganizationProviderOverride.query.filter_by(
|
||||
organization_id=organization_id,
|
||||
provider_type=provider_type,
|
||||
deleted_at=None,
|
||||
).first()
|
||||
|
||||
if not override:
|
||||
|
||||
@@ -1024,17 +1024,17 @@ def revoke_membership_soft(
|
||||
|
||||
|
||||
def hard_delete_membership(membership_id: str) -> None:
|
||||
"""Hard delete a membership after ZeroTier has been cleaned up.
|
||||
"""Soft delete a membership after ZeroTier has been cleaned up.
|
||||
|
||||
Called by the reconciliation job after successfully removing the member
|
||||
from the ZeroTier controller. This is the final DB cleanup step.
|
||||
from the ZeroTier controller. This marks the membership as deleted.
|
||||
"""
|
||||
membership = DeviceNetworkMembership.query.filter(
|
||||
DeviceNetworkMembership.id == membership_id,
|
||||
).first()
|
||||
|
||||
if not membership:
|
||||
logger.warning(f"[hard_delete_membership] Membership {membership_id} not found, skipping.")
|
||||
logger.warning(f"[hard_delete_membership] Membership {membership_id} not found or already deleted, skipping.")
|
||||
return
|
||||
|
||||
device = Device.query.get(membership.device_id)
|
||||
@@ -1048,7 +1048,7 @@ def hard_delete_membership(membership_id: str) -> None:
|
||||
except Exception as exc:
|
||||
logger.warning(f"[hard_delete_membership] ZT delete failed for {device.node_id}: {exc}")
|
||||
|
||||
db.session.delete(membership)
|
||||
membership.delete(soft=True)
|
||||
db.session.commit()
|
||||
|
||||
AuditService.log_action(
|
||||
@@ -1061,6 +1061,6 @@ def hard_delete_membership(membership_id: str) -> None:
|
||||
"device_node_id": device.node_id if device else None,
|
||||
"network_id": network.zerotier_network_id if network else None,
|
||||
},
|
||||
description=f"Membership hard-deleted: device {device.node_id if device else 'unknown'} from network",
|
||||
description=f"Membership deleted: device {device.node_id if device else 'unknown'} from network",
|
||||
success=True,
|
||||
)
|
||||
|
||||
@@ -111,7 +111,7 @@ def handle_register_callback(
|
||||
access_token=tokens["access_token"],
|
||||
)
|
||||
|
||||
existing_user = User.query.filter_by(email=user_info["email"]).first()
|
||||
existing_user = User.query.filter_by(email=user_info["email"], deleted_at=None).first()
|
||||
if existing_user:
|
||||
raise OAuthFlowError(
|
||||
f"An account with email {user_info['email']} already exists. "
|
||||
|
||||
@@ -55,9 +55,9 @@ def get_userinfo(access_token: str, validate_access_token_fn) -> Dict:
|
||||
|
||||
def _get_user_roles(user: User) -> list:
|
||||
roles = []
|
||||
if not user or not user.organization_memberships:
|
||||
if not user:
|
||||
return roles
|
||||
for member in user.organization_memberships:
|
||||
for member in user.get_active_memberships():
|
||||
roles.append({
|
||||
"organization_id": str(member.organization_id),
|
||||
"role": member.role.value,
|
||||
|
||||
@@ -324,8 +324,8 @@ class OIDCTokenService:
|
||||
List of role objects with organization_id and role
|
||||
"""
|
||||
roles = []
|
||||
if user and user.organization_memberships:
|
||||
for member in user.organization_memberships:
|
||||
if user:
|
||||
for member in user.get_active_memberships():
|
||||
roles.append({
|
||||
"organization_id": str(member.organization_id),
|
||||
"role": member.role.value
|
||||
|
||||
@@ -70,6 +70,22 @@ class OrganizationService:
|
||||
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def get_user_org_count(user_id):
|
||||
"""
|
||||
Get the count of organizations a user belongs to.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Count of active memberships (deleted_at is NULL)
|
||||
"""
|
||||
return OrganizationMember.query.filter_by(
|
||||
user_id=user_id,
|
||||
deleted_at=None,
|
||||
).count()
|
||||
|
||||
@staticmethod
|
||||
def get_organization_by_id(org_id):
|
||||
"""
|
||||
@@ -286,7 +302,7 @@ class OrganizationService:
|
||||
Raises:
|
||||
ConflictError: If user is already a member
|
||||
"""
|
||||
# Check if already a member (active or soft-deleted — both blocked by DB unique constraint)
|
||||
# Check for any membership (active or soft-deleted) to enable reactivation
|
||||
existing = OrganizationMember.query.filter_by(
|
||||
user_id=user_id,
|
||||
organization_id=org.id,
|
||||
@@ -294,7 +310,7 @@ class OrganizationService:
|
||||
|
||||
# Development-only debug logging for membership validation
|
||||
if current_app.config.get('ENV') == 'development':
|
||||
logger.debug(f"[Org] Member check: org_id={org.id}, user_id={user_id}, already_member={existing is not None}")
|
||||
logger.debug(f"[Org] Member check: org_id={org.id}, user_id={user_id}, already_member={existing is not None}, soft_deleted={existing.deleted_at is not None if existing else False}")
|
||||
|
||||
if existing:
|
||||
if existing.deleted_at is not None:
|
||||
@@ -363,6 +379,15 @@ class OrganizationService:
|
||||
logger.debug(f"[Org] Member removal: org_id={org.id}, user_id={user_id}, found={member is not None}")
|
||||
|
||||
if member:
|
||||
if member.role == OrganizationRole.OWNER:
|
||||
owner_count = OrganizationMember.query.filter(
|
||||
OrganizationMember.organization_id == org.id,
|
||||
OrganizationMember.role == OrganizationRole.OWNER,
|
||||
OrganizationMember.deleted_at.is_(None),
|
||||
OrganizationMember.user_id != user_id,
|
||||
).count()
|
||||
if owner_count < 1:
|
||||
raise ValueError("Cannot remove the only owner from an organization. Transfer ownership first.")
|
||||
member.delete(soft=True)
|
||||
|
||||
# Log member removal
|
||||
|
||||
@@ -10,10 +10,10 @@ class SessionService:
|
||||
@staticmethod
|
||||
def get_active_session_by_token(token):
|
||||
"""Get active session by token.
|
||||
|
||||
|
||||
Args:
|
||||
token: The session token string
|
||||
|
||||
|
||||
Returns:
|
||||
Session object if found and active, None otherwise
|
||||
"""
|
||||
@@ -23,6 +23,8 @@ class SessionService:
|
||||
token=token,
|
||||
status=SessionStatus.ACTIVE,
|
||||
deleted_at=None
|
||||
).filter(
|
||||
Session.expires_at > datetime.now(timezone.utc)
|
||||
).first()
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -41,46 +41,47 @@ class SSHKeyService:
|
||||
user_id: str,
|
||||
public_key: str,
|
||||
description: Optional[str] = None,
|
||||
) -> SSHKey:
|
||||
) -> tuple[SSHKey, bool]:
|
||||
"""Add an SSH public key for a user.
|
||||
|
||||
|
||||
Args:
|
||||
user_id: ID of the user
|
||||
public_key: SSH public key in OpenSSH format
|
||||
description: Optional description of the key
|
||||
|
||||
|
||||
Returns:
|
||||
Created SSHKey instance
|
||||
|
||||
Tuple of (SSHKey instance, is_new) where is_new is True for
|
||||
newly created keys, False for existing keys (idempotent).
|
||||
|
||||
Raises:
|
||||
UserNotFoundError: If user doesn't exist
|
||||
SSHKeyError: If key format is invalid
|
||||
SSHKeyAlreadyExistsError: If key already exists
|
||||
"""
|
||||
# Verify user exists
|
||||
user = User.query.get(user_id)
|
||||
if not user:
|
||||
raise UserNotFoundError(f"User {user_id} not found")
|
||||
|
||||
|
||||
# Validate key format
|
||||
if not verify_ssh_key_format(public_key):
|
||||
raise SSHKeyError("Invalid SSH public key format")
|
||||
|
||||
|
||||
# Compute fingerprint
|
||||
try:
|
||||
fingerprint = compute_ssh_fingerprint(public_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to compute fingerprint: {str(e)}")
|
||||
raise SSHKeyError(f"Failed to compute key fingerprint: {str(e)}")
|
||||
|
||||
# Check for duplicate (including soft-deleted records — fingerprint is unique in DB)
|
||||
existing = SSHKey.query.filter_by(fingerprint=fingerprint).first()
|
||||
|
||||
# Check for duplicate per user (including soft-deleted records)
|
||||
existing = SSHKey.query.filter_by(
|
||||
user_id=user_id, fingerprint=fingerprint
|
||||
).first()
|
||||
if existing:
|
||||
if existing.deleted_at is not None:
|
||||
# Restore the soft-deleted key: clear deleted_at and update fields
|
||||
existing.deleted_at = None
|
||||
existing.user_id = user_id
|
||||
existing.description = description or existing.description
|
||||
existing.description = description if description is not None else existing.description
|
||||
existing.verified = False
|
||||
existing.verified_at = None
|
||||
existing.verify_text = None
|
||||
@@ -90,15 +91,18 @@ class SSHKeyService:
|
||||
f"Restored soft-deleted SSH key for user {user_id}: "
|
||||
f"fingerprint={fingerprint}"
|
||||
)
|
||||
return existing
|
||||
raise SSHKeyAlreadyExistsError(
|
||||
f"SSH key with fingerprint {fingerprint} already exists"
|
||||
return existing, False
|
||||
# Idempotent: return existing key without error
|
||||
logger.info(
|
||||
f"SSH key already exists for user {user_id}: "
|
||||
f"fingerprint={fingerprint}"
|
||||
)
|
||||
|
||||
return existing, False
|
||||
|
||||
# Extract metadata
|
||||
key_type = extract_ssh_key_type(public_key)
|
||||
key_comment = extract_ssh_key_comment(public_key)
|
||||
|
||||
|
||||
# Create SSH key record
|
||||
ssh_key = SSHKey(
|
||||
user_id=user_id,
|
||||
@@ -109,15 +113,15 @@ class SSHKeyService:
|
||||
key_comment=key_comment,
|
||||
verified=False,
|
||||
)
|
||||
|
||||
|
||||
ssh_key.save()
|
||||
|
||||
|
||||
logger.info(
|
||||
f"SSH key added for user {user_id}: "
|
||||
f"fingerprint={fingerprint}, type={key_type}"
|
||||
)
|
||||
|
||||
return ssh_key
|
||||
|
||||
return ssh_key, True
|
||||
|
||||
def get_ssh_key(self, key_id: str) -> SSHKey:
|
||||
"""Get an SSH key by ID.
|
||||
|
||||
@@ -0,0 +1,177 @@
|
||||
"""Analytics service for platform-wide statistics."""
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
from gatehouse_app.models.user.user import User
|
||||
from gatehouse_app.models.user.session import Session
|
||||
from gatehouse_app.models.superadmin_audit_log import SuperadminAuditLog
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SuperadminAnalyticsService:
|
||||
"""Service for platform-wide analytics and statistics."""
|
||||
|
||||
@staticmethod
|
||||
def get_dashboard_stats() -> dict:
|
||||
"""Get dashboard statistics for the overview page.
|
||||
|
||||
Returns:
|
||||
Dashboard stats including org count, user count, etc.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
thirty_days_ago = now - timedelta(days=30)
|
||||
|
||||
# Total organizations
|
||||
total_orgs = Organization.query.filter(Organization.deleted_at.is_(None)).count()
|
||||
active_orgs = Organization.query.filter(
|
||||
Organization.deleted_at.is_(None),
|
||||
Organization.is_active == True, # noqa: E712
|
||||
).count()
|
||||
|
||||
# Total users
|
||||
total_users = User.query.filter(User.deleted_at.is_(None)).count()
|
||||
|
||||
# Active sessions
|
||||
active_sessions = Session.query.filter(
|
||||
Session.deleted_at.is_(None),
|
||||
Session.status == "active",
|
||||
).count()
|
||||
|
||||
# New signups in last 30 days
|
||||
new_users_30d = User.query.filter(
|
||||
User.deleted_at.is_(None),
|
||||
User.created_at >= thirty_days_ago,
|
||||
).count()
|
||||
|
||||
# New organizations in last 30 days
|
||||
new_orgs_30d = Organization.query.filter(
|
||||
Organization.deleted_at.is_(None),
|
||||
Organization.created_at >= thirty_days_ago,
|
||||
).count()
|
||||
|
||||
# Suspended organizations
|
||||
suspended_orgs = Organization.query.filter(
|
||||
Organization.deleted_at.is_(None),
|
||||
Organization.is_active == False, # noqa: E712
|
||||
).count()
|
||||
|
||||
return {
|
||||
"total_organizations": total_orgs,
|
||||
"active_organizations": active_orgs,
|
||||
"suspended_organizations": suspended_orgs,
|
||||
"total_users": total_users,
|
||||
"active_sessions": active_sessions,
|
||||
"new_users_30d": new_users_30d,
|
||||
"new_orgs_30d": new_orgs_30d,
|
||||
"generated_at": now.isoformat() + "Z",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_signup_trends(days: int = 30) -> dict:
|
||||
"""Get signup trends over time.
|
||||
|
||||
Args:
|
||||
days: Number of days to analyze
|
||||
|
||||
Returns:
|
||||
Daily signup data
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
start_date = now - timedelta(days=days)
|
||||
|
||||
# Get all users created in period
|
||||
users = User.query.filter(
|
||||
User.deleted_at.is_(None),
|
||||
User.created_at >= start_date,
|
||||
).all()
|
||||
|
||||
# Group by day
|
||||
daily_signups = {}
|
||||
for i in range(days):
|
||||
date = (start_date + timedelta(days=i)).strftime("%Y-%m-%d")
|
||||
daily_signups[date] = 0
|
||||
|
||||
for user in users:
|
||||
date = user.created_at.strftime("%Y-%m-%d")
|
||||
if date in daily_signups:
|
||||
daily_signups[date] += 1
|
||||
|
||||
# Convert to list
|
||||
history = [
|
||||
{"date": date, "value": count}
|
||||
for date, count in sorted(daily_signups.items())
|
||||
]
|
||||
|
||||
return {
|
||||
"period_start": start_date.isoformat() + "Z",
|
||||
"period_end": now.isoformat() + "Z",
|
||||
"total": len(users),
|
||||
"history": history,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_org_distribution() -> dict:
|
||||
"""Get distribution of organizations by size.
|
||||
|
||||
Returns:
|
||||
Organization size distribution
|
||||
"""
|
||||
orgs = Organization.query.filter(Organization.deleted_at.is_(None)).all()
|
||||
|
||||
distribution = {
|
||||
"solo": 0, # 1 user
|
||||
"small": 0, # 2-10 users
|
||||
"medium": 0, # 11-50 users
|
||||
"large": 0, # 51-200 users
|
||||
"enterprise": 0, # 200+ users
|
||||
}
|
||||
|
||||
for org in orgs:
|
||||
count = org.get_member_count()
|
||||
if count == 1:
|
||||
distribution["solo"] += 1
|
||||
elif count <= 10:
|
||||
distribution["small"] += 1
|
||||
elif count <= 50:
|
||||
distribution["medium"] += 1
|
||||
elif count <= 200:
|
||||
distribution["large"] += 1
|
||||
else:
|
||||
distribution["enterprise"] += 1
|
||||
|
||||
return {
|
||||
"distribution": distribution,
|
||||
"total_orgs": len(orgs),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_recent_activity(limit: int = 20) -> list:
|
||||
"""Get recent superadmin actions.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of actions to return
|
||||
|
||||
Returns:
|
||||
List of recent audit log entries
|
||||
"""
|
||||
logs = SuperadminAuditLog.query.filter(
|
||||
SuperadminAuditLog.deleted_at.is_(None),
|
||||
).order_by(
|
||||
SuperadminAuditLog.created_at.desc()
|
||||
).limit(limit).all()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": log.id,
|
||||
"superadmin_id": log.superadmin_id,
|
||||
"action": log.action,
|
||||
"resource_type": log.resource_type,
|
||||
"resource_id": log.resource_id,
|
||||
"extra_data": log.extra_data,
|
||||
"ip_address": log.ip_address,
|
||||
"user_agent": log.user_agent,
|
||||
"created_at": log.created_at.isoformat() + "Z" if log.created_at else None,
|
||||
}
|
||||
for log in logs
|
||||
]
|
||||
@@ -0,0 +1,239 @@
|
||||
"""Superadmin authentication service."""
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
from flask import request, current_app
|
||||
from gatehouse_app.extensions import db, bcrypt
|
||||
from gatehouse_app.models.superadmin import Superadmin, SuperadminSession
|
||||
from gatehouse_app.exceptions.auth_exceptions import InvalidCredentialsError
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SuperadminAuthService:
|
||||
"""Service for superadmin authentication operations."""
|
||||
|
||||
@staticmethod
|
||||
def authenticate(email, credentials):
|
||||
"""Authenticate superadmin with email/password credentials.
|
||||
|
||||
Args:
|
||||
email: Superadmin email
|
||||
credentials: Plain text credential
|
||||
|
||||
Returns:
|
||||
Superadmin instance if authentication succeeds
|
||||
|
||||
Raises:
|
||||
InvalidCredentialsError: If credentials are invalid or account is disabled
|
||||
"""
|
||||
# Find superadmin by email
|
||||
superadmin = Superadmin.query.filter_by(email=email.lower()).first()
|
||||
|
||||
if not superadmin:
|
||||
logger.warning(f"[SuperadminAuth] Login attempt for non-existent email: {email}")
|
||||
raise InvalidCredentialsError()
|
||||
|
||||
# Check if account is active
|
||||
if not superadmin.is_active:
|
||||
logger.warning(f"[SuperadminAuth] Login attempt for disabled account: {email}")
|
||||
raise InvalidCredentialsError("Account is disabled")
|
||||
|
||||
# Check credential
|
||||
if not superadmin.password_hash:
|
||||
logger.warning(f"[SuperadminAuth] Login attempt for account with no credential set: {email}")
|
||||
raise InvalidCredentialsError()
|
||||
|
||||
# Verify credential
|
||||
password_valid = bcrypt.check_password_hash(superadmin.password_hash, credentials)
|
||||
|
||||
if not password_valid:
|
||||
logger.warning(f"[SuperadminAuth] Invalid password for: {email}")
|
||||
raise InvalidCredentialsError()
|
||||
|
||||
# Update last login
|
||||
superadmin.last_login_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
logger.info(f"[SuperadminAuth] Successful login for: {email}")
|
||||
return superadmin
|
||||
|
||||
@staticmethod
|
||||
def create_session(superadmin_id, duration_seconds=28800):
|
||||
"""Create a new session for superadmin.
|
||||
|
||||
Args:
|
||||
superadmin_id: Superadmin ID
|
||||
duration_seconds: Session duration in seconds (default 8 hours)
|
||||
|
||||
Returns:
|
||||
SuperadminSession instance
|
||||
"""
|
||||
# Generate secure token
|
||||
token = secrets.token_urlsafe(32)
|
||||
|
||||
# Create session
|
||||
session = SuperadminSession(
|
||||
superadmin_id=superadmin_id,
|
||||
token=token,
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(seconds=duration_seconds),
|
||||
last_activity_at=datetime.now(timezone.utc),
|
||||
ip_address=request.remote_addr,
|
||||
user_agent=request.headers.get("User-Agent"),
|
||||
)
|
||||
session.save()
|
||||
|
||||
logger.info(f"[SuperadminAuth] Session created for superadmin_id={superadmin_id}")
|
||||
return session
|
||||
|
||||
@staticmethod
|
||||
def revoke_session(session_id, reason=None):
|
||||
"""Revoke a superadmin session.
|
||||
|
||||
Args:
|
||||
session_id: Session ID to revoke
|
||||
reason: Optional revocation reason
|
||||
"""
|
||||
session = SuperadminSession.query.get(session_id)
|
||||
if session:
|
||||
session.revoke(reason=reason)
|
||||
logger.info(f"[SuperadminAuth] Session {session_id} revoked: {reason or 'No reason'}")
|
||||
|
||||
@staticmethod
|
||||
def revoke_all_sessions(superadmin_id, except_token=None, reason=None):
|
||||
"""Revoke all sessions for a superadmin.
|
||||
|
||||
Args:
|
||||
superadmin_id: Superadmin ID
|
||||
except_token: Optional token to keep (current session)
|
||||
reason: Optional revocation reason
|
||||
"""
|
||||
query = SuperadminSession.query.filter_by(superadmin_id=superadmin_id)
|
||||
if except_token:
|
||||
query = query.filter(SuperadminSession.token != except_token)
|
||||
|
||||
sessions = query.all()
|
||||
for session in sessions:
|
||||
session.revoke(reason=reason)
|
||||
|
||||
logger.info(f"[SuperadminAuth] Revoked {len(sessions)} sessions for superadmin_id={superadmin_id}")
|
||||
return len(sessions)
|
||||
|
||||
@staticmethod
|
||||
def create_emergency_access(superadmin_id, target_user_id, reason, duration_minutes=15):
|
||||
"""Create emergency access to a user's account.
|
||||
|
||||
This creates a special emergency session that grants temporary elevated access.
|
||||
|
||||
Args:
|
||||
superadmin_id: Superadmin ID initiating emergency access
|
||||
target_user_id: User ID to access
|
||||
reason: Reason for emergency access
|
||||
duration_minutes: Duration of emergency access in minutes
|
||||
|
||||
Returns:
|
||||
Dictionary with emergency session info
|
||||
"""
|
||||
from gatehouse_app.models.user.user import User
|
||||
from gatehouse_app.services.auth_service import AuthService
|
||||
from gatehouse_app.services.audit_service import AuditService
|
||||
|
||||
# Verify target user exists
|
||||
target_user = User.query.get(target_user_id)
|
||||
if not target_user:
|
||||
raise ValueError(f"Target user not found: {target_user_id}")
|
||||
|
||||
# Create emergency session for the target user
|
||||
emergency_session = AuthService.create_session(
|
||||
user=target_user,
|
||||
duration_seconds=duration_minutes * 60,
|
||||
is_compliance_only=False
|
||||
)
|
||||
|
||||
# Log the emergency access
|
||||
logger.warning(
|
||||
f"[SuperadminAuth] EMERGENCY ACCESS: superadmin_id={superadmin_id} "
|
||||
f"accessed user_id={target_user_id} reason={reason}"
|
||||
)
|
||||
|
||||
return {
|
||||
"session": emergency_session,
|
||||
"expires_at": emergency_session.expires_at,
|
||||
"reason": reason,
|
||||
"target_user_id": target_user_id,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def hash_password(plain_credential):
|
||||
"""Hash a credential for storage.
|
||||
|
||||
Args:
|
||||
plain_credential: Plain text credential
|
||||
|
||||
Returns:
|
||||
Hashed credential string
|
||||
"""
|
||||
return bcrypt.generate_password_hash(plain_credential).decode("utf-8")
|
||||
|
||||
@staticmethod
|
||||
def create_superadmin(email, credential, full_name=None):
|
||||
"""Create a new superadmin.
|
||||
|
||||
Args:
|
||||
email: Superadmin email
|
||||
credential: Plain text credential
|
||||
full_name: Optional full name
|
||||
|
||||
Returns:
|
||||
Superadmin instance
|
||||
"""
|
||||
# Check if email already exists
|
||||
existing = Superadmin.query.filter_by(email=email.lower()).first()
|
||||
if existing:
|
||||
raise ValueError(f"Superadmin with email {email} already exists")
|
||||
|
||||
# Hash credential
|
||||
password_hash = bcrypt.generate_password_hash(credential).decode("utf-8")
|
||||
|
||||
# Create superadmin
|
||||
superadmin = Superadmin(
|
||||
email=email.lower(),
|
||||
password_hash=password_hash,
|
||||
full_name=full_name,
|
||||
is_active=True,
|
||||
)
|
||||
superadmin.save()
|
||||
|
||||
logger.info(f"[SuperadminAuth] Created new superadmin: {email}")
|
||||
return superadmin
|
||||
|
||||
@staticmethod
|
||||
def update_superadmin(superadmin_id, **kwargs):
|
||||
"""Update superadmin details.
|
||||
|
||||
Args:
|
||||
superadmin_id: Superadmin ID
|
||||
**kwargs: Fields to update (email, full_name, is_active, credential)
|
||||
|
||||
Returns:
|
||||
Updated Superadmin instance
|
||||
"""
|
||||
superadmin = Superadmin.query.get(superadmin_id)
|
||||
if not superadmin:
|
||||
raise ValueError(f"Superadmin not found: {superadmin_id}")
|
||||
|
||||
# Handle credential update
|
||||
if 'password' in kwargs:
|
||||
kwargs['password_hash'] = bcrypt.generate_password_hash(kwargs.pop('password')).decode("utf-8")
|
||||
|
||||
# Update fields
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(superadmin, key):
|
||||
setattr(superadmin, key, value)
|
||||
|
||||
superadmin.save()
|
||||
logger.info(f"[SuperadminAuth] Updated superadmin_id={superadmin_id}")
|
||||
return superadmin
|
||||
@@ -0,0 +1,244 @@
|
||||
"""Superadmin organization management service."""
|
||||
import logging
|
||||
from typing import Optional
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
from gatehouse_app.models.user.session import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SuperadminOrganizationService:
|
||||
"""Service for superadmin organization management operations."""
|
||||
|
||||
@staticmethod
|
||||
def list_organizations(
|
||||
page: int = 1,
|
||||
per_page: int = 20,
|
||||
search: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
plan_slug: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""List organizations with pagination and filtering.
|
||||
|
||||
Args:
|
||||
page: Page number (1-indexed)
|
||||
per_page: Items per page
|
||||
search: Search by name or slug
|
||||
status: Filter by status (active, suspended)
|
||||
plan_slug: Filter by plan slug (not implemented yet - requires Subscription model)
|
||||
|
||||
Returns:
|
||||
Paginated response with organization summaries
|
||||
"""
|
||||
query = Organization.query
|
||||
|
||||
# Search filter
|
||||
if search:
|
||||
search_term = f"%{search}%"
|
||||
query = query.filter(
|
||||
db.or_(
|
||||
Organization.name.ilike(search_term),
|
||||
Organization.slug.ilike(search_term),
|
||||
)
|
||||
)
|
||||
|
||||
# Status filter
|
||||
if status == "active":
|
||||
query = query.filter(Organization.is_active.is_(True))
|
||||
elif status == "suspended":
|
||||
query = query.filter(Organization.is_active.is_(False))
|
||||
|
||||
# Note: plan_slug filtering requires Plan/Subscription models (Phase 4)
|
||||
# Currently ignored but parameter is accepted for API compatibility
|
||||
|
||||
# Order by created_at desc
|
||||
query = query.order_by(Organization.created_at.desc())
|
||||
|
||||
# Paginate
|
||||
pagination = query.paginate(page=page, per_page=per_page, error_out=False)
|
||||
|
||||
# Build response
|
||||
items = []
|
||||
for org in pagination.items:
|
||||
item = {
|
||||
"id": org.id,
|
||||
"name": org.name,
|
||||
"slug": org.slug,
|
||||
"description": org.description,
|
||||
"is_active": org.is_active,
|
||||
"member_count": org.get_member_count(),
|
||||
"created_at": org.created_at.isoformat() + "Z" if org.created_at else None,
|
||||
"updated_at": org.updated_at.isoformat() + "Z" if org.updated_at else None,
|
||||
}
|
||||
items.append(item)
|
||||
|
||||
return {
|
||||
"items": items,
|
||||
"total": pagination.total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": pagination.pages,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_organization_detail(org_id: str) -> dict:
|
||||
"""Get detailed organization information.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
|
||||
Returns:
|
||||
Organization detail with member_count, owner, stats
|
||||
|
||||
Raises:
|
||||
ValueError: If organization not found
|
||||
"""
|
||||
org = Organization.query.get(org_id)
|
||||
if not org:
|
||||
raise ValueError("Organization not found")
|
||||
|
||||
owner = org.get_owner()
|
||||
|
||||
# Count active sessions for org members
|
||||
member_user_ids = [m.user_id for m in org.members if m.deleted_at is None]
|
||||
active_sessions = Session.query.filter(
|
||||
Session.user_id.in_(member_user_ids),
|
||||
Session.deleted_at.is_(None),
|
||||
).count()
|
||||
|
||||
return {
|
||||
"id": org.id,
|
||||
"name": org.name,
|
||||
"slug": org.slug,
|
||||
"description": org.description,
|
||||
"is_active": org.is_active,
|
||||
"settings": org.settings or {},
|
||||
"member_count": org.get_member_count(),
|
||||
"owner": {
|
||||
"id": owner.id,
|
||||
"email": owner.email,
|
||||
"full_name": owner.full_name,
|
||||
} if owner else None,
|
||||
"active_sessions": active_sessions,
|
||||
"created_at": org.created_at.isoformat() + "Z" if org.created_at else None,
|
||||
"updated_at": org.updated_at.isoformat() + "Z" if org.updated_at else None,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def update_organization(
|
||||
org_id: str,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
) -> Organization:
|
||||
"""Update organization details.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
name: New name (optional)
|
||||
description: New description (optional)
|
||||
is_active: New active status (optional)
|
||||
|
||||
Returns:
|
||||
Updated organization
|
||||
|
||||
Raises:
|
||||
ValueError: If organization not found
|
||||
"""
|
||||
org = Organization.query.get(org_id)
|
||||
if not org:
|
||||
raise ValueError("Organization not found")
|
||||
|
||||
if name is not None:
|
||||
org.name = name
|
||||
if description is not None:
|
||||
org.description = description
|
||||
if is_active is not None:
|
||||
org.is_active = is_active
|
||||
|
||||
db.session.commit()
|
||||
logger.info(f"[SuperadminOrg] Updated organization {org_id}")
|
||||
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def suspend_organization(org_id: str) -> Organization:
|
||||
"""Suspend an organization.
|
||||
|
||||
Sets is_active=False and invalidates all member sessions.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
|
||||
Returns:
|
||||
Suspended organization
|
||||
|
||||
Raises:
|
||||
ValueError: If organization not found
|
||||
"""
|
||||
org = Organization.query.get(org_id)
|
||||
if not org:
|
||||
raise ValueError("Organization not found")
|
||||
|
||||
org.is_active = False
|
||||
|
||||
# Invalidate all member sessions
|
||||
member_user_ids = [m.user_id for m in org.members if m.deleted_at is None]
|
||||
Session.query.filter(
|
||||
Session.user_id.in_(member_user_ids),
|
||||
Session.deleted_at.is_(None),
|
||||
).update({"deleted_at": db.func.now()})
|
||||
|
||||
db.session.commit()
|
||||
logger.warning(f"[SuperadminOrg] Suspended organization {org_id}")
|
||||
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def restore_organization(org_id: str) -> Organization:
|
||||
"""Restore a suspended organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
|
||||
Returns:
|
||||
Restored organization
|
||||
|
||||
Raises:
|
||||
ValueError: If organization not found
|
||||
"""
|
||||
org = Organization.query.get(org_id)
|
||||
if not org:
|
||||
raise ValueError("Organization not found")
|
||||
|
||||
org.is_active = True
|
||||
db.session.commit()
|
||||
logger.info(f"[SuperadminOrg] Restored organization {org_id}")
|
||||
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def soft_delete_organization(org_id: str) -> Organization:
|
||||
"""Soft-delete an organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
|
||||
Returns:
|
||||
Soft-deleted organization
|
||||
|
||||
Raises:
|
||||
ValueError: If organization not found
|
||||
"""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
org = Organization.query.get(org_id)
|
||||
if not org:
|
||||
raise ValueError("Organization not found")
|
||||
|
||||
org.deleted_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
logger.warning(f"[SuperadminOrg] Soft-deleted organization {org_id}")
|
||||
|
||||
return org
|
||||
@@ -0,0 +1,199 @@
|
||||
"""Usage tracking service for superadmin operations."""
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
from gatehouse_app.models.user.session import Session
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UsageMetric:
|
||||
"""Usage metric types."""
|
||||
USERS = "users"
|
||||
SESSIONS = "sessions"
|
||||
ACTIVE_SESSIONS = "active_sessions"
|
||||
API_CALLS = "api_calls"
|
||||
|
||||
|
||||
class SuperadminUsageService:
|
||||
"""Service for tracking and retrieving usage metrics."""
|
||||
|
||||
@staticmethod
|
||||
def get_current_usage(org_id: str) -> dict:
|
||||
"""Get current period usage for an organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
|
||||
Returns:
|
||||
Current usage metrics including user count and active sessions
|
||||
"""
|
||||
org = Organization.query.get(org_id)
|
||||
if not org:
|
||||
raise ValueError("Organization not found")
|
||||
|
||||
# Get active member count
|
||||
member_count = org.get_member_count()
|
||||
|
||||
# Get active sessions count
|
||||
member_user_ids = [m.user_id for m in org.members if m.deleted_at is None]
|
||||
active_sessions = Session.query.filter(
|
||||
Session.user_id.in_(member_user_ids),
|
||||
Session.deleted_at.is_(None),
|
||||
Session.status == "active",
|
||||
).count()
|
||||
|
||||
# Get max concurrent sessions this month
|
||||
now = datetime.now(timezone.utc)
|
||||
period_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
# For simplicity, we'll track peak concurrent sessions
|
||||
# In production, you'd want a separate tracking table
|
||||
max_sessions_this_month = active_sessions # Placeholder
|
||||
|
||||
return {
|
||||
"organization_id": org_id,
|
||||
"period_start": period_start.isoformat() + "Z",
|
||||
"period_end": now.isoformat() + "Z",
|
||||
"metrics": {
|
||||
"users": {
|
||||
"current": member_count,
|
||||
"limit": None, # Will come from plan
|
||||
"description": "Total organization members",
|
||||
},
|
||||
"active_sessions": {
|
||||
"current": active_sessions,
|
||||
"max_this_month": max_sessions_this_month,
|
||||
"description": "Currently active user sessions",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_usage_history(
|
||||
org_id: str,
|
||||
metric: str,
|
||||
days: int = 30,
|
||||
) -> dict:
|
||||
"""Get usage history for a specific metric.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
metric: Metric type (users, sessions)
|
||||
days: Number of days of history
|
||||
|
||||
Returns:
|
||||
List of daily usage data points
|
||||
"""
|
||||
org = Organization.query.get(org_id)
|
||||
if not org:
|
||||
raise ValueError("Organization not found")
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
start_date = now - timedelta(days=days)
|
||||
|
||||
# Get member history (simplified - would need a history table in production)
|
||||
history = []
|
||||
current_count = org.get_member_count()
|
||||
|
||||
# Generate daily data points (placeholder - real implementation needs history table)
|
||||
for i in range(days):
|
||||
date = start_date + timedelta(days=i)
|
||||
history.append({
|
||||
"date": date.strftime("%Y-%m-%d"),
|
||||
"value": current_count, # Simplified
|
||||
})
|
||||
|
||||
return {
|
||||
"organization_id": org_id,
|
||||
"metric": metric,
|
||||
"period_start": start_date.isoformat() + "Z",
|
||||
"period_end": now.isoformat() + "Z",
|
||||
"history": history,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_seat_count_for_period(org_id: str, year: int, month: int) -> dict:
|
||||
"""Calculate maximum seat count used in a given month.
|
||||
|
||||
For billing purposes - tracks the peak number of users.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
year: Year
|
||||
month: Month
|
||||
|
||||
Returns:
|
||||
Seat count data for the period
|
||||
"""
|
||||
org = Organization.query.get(org_id)
|
||||
if not org:
|
||||
raise ValueError("Organization not found")
|
||||
|
||||
# Calculate first and last day of month
|
||||
first_day = datetime(year, month, 1, tzinfo=timezone.utc)
|
||||
if month == 12:
|
||||
last_day = datetime(year + 1, 1, 1, tzinfo=timezone.utc) - timedelta(seconds=1)
|
||||
else:
|
||||
last_day = datetime(year, month + 1, 1, tzinfo=timezone.utc) - timedelta(seconds=1)
|
||||
|
||||
# Get all members that existed during this period
|
||||
members = OrganizationMember.query.filter(
|
||||
OrganizationMember.organization_id == org_id,
|
||||
OrganizationMember.deleted_at.is_(None),
|
||||
).all()
|
||||
|
||||
# Count unique users who were members at any point during the month
|
||||
max_seats = len(members)
|
||||
|
||||
# Current count at end of month
|
||||
current_seats = len([m for m in members if m.deleted_at is None or m.deleted_at > last_day])
|
||||
|
||||
return {
|
||||
"organization_id": org_id,
|
||||
"period": f"{year}-{month:02d}",
|
||||
"max_seats": max_seats,
|
||||
"current_seats": current_seats,
|
||||
"period_start": first_day.isoformat() + "Z",
|
||||
"period_end": last_day.isoformat() + "Z",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def adjust_usage(
|
||||
org_id: str,
|
||||
metric: str,
|
||||
adjustment: int,
|
||||
reason: str,
|
||||
superadmin_id: str,
|
||||
) -> dict:
|
||||
"""Apply a manual usage adjustment (credit or charge).
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
metric: Metric to adjust
|
||||
adjustment: Positive (credit) or negative (charge)
|
||||
reason: Reason for adjustment
|
||||
superadmin_id: Superadmin making the adjustment
|
||||
|
||||
Returns:
|
||||
Adjustment confirmation
|
||||
"""
|
||||
org = Organization.query.get(org_id)
|
||||
if not org:
|
||||
raise ValueError("Organization not found")
|
||||
|
||||
# In production, you'd create a UsageAdjustment record
|
||||
# For now, just log and return
|
||||
logger.warning(
|
||||
f"[SuperadminUsage] Adjustment: org={org_id}, metric={metric}, "
|
||||
f"adjustment={adjustment}, reason={reason}, by={superadmin_id}"
|
||||
)
|
||||
|
||||
return {
|
||||
"organization_id": org_id,
|
||||
"metric": metric,
|
||||
"adjustment": adjustment,
|
||||
"reason": reason,
|
||||
"applied_at": datetime.now(timezone.utc).isoformat() + "Z",
|
||||
}
|
||||
@@ -0,0 +1,371 @@
|
||||
"""User management service for superadmin operations."""
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.models.user.user import User
|
||||
from gatehouse_app.models.user.session import Session
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.extensions import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SuperadminUserService:
|
||||
"""Service for managing users across the platform."""
|
||||
|
||||
@staticmethod
|
||||
def list_users(
|
||||
page: int = 1,
|
||||
per_page: int = 20,
|
||||
org_id: str = None,
|
||||
status: str = None,
|
||||
search: str = None,
|
||||
) -> dict:
|
||||
"""List users with filters and pagination.
|
||||
|
||||
Args:
|
||||
page: Page number
|
||||
per_page: Items per page
|
||||
org_id: Filter by organization
|
||||
status: Filter by status (active/suspended)
|
||||
search: Search by email or name
|
||||
|
||||
Returns:
|
||||
Paginated user list with metadata
|
||||
"""
|
||||
query = User.query.filter(User.deleted_at.is_(None))
|
||||
|
||||
# Filter by organization
|
||||
if org_id:
|
||||
member_user_ids = db.session.query(OrganizationMember.user_id).filter(
|
||||
OrganizationMember.organization_id == org_id,
|
||||
OrganizationMember.deleted_at.is_(None),
|
||||
).all()
|
||||
user_ids = [m.user_id for m in member_user_ids]
|
||||
query = query.filter(User.id.in_(user_ids))
|
||||
|
||||
# Filter by status
|
||||
if status == "suspended":
|
||||
query = query.filter(User.status == "GLOBAL_SUSPENDED")
|
||||
elif status == "active":
|
||||
query = query.filter(User.status != "GLOBAL_SUSPENDED")
|
||||
|
||||
# Search
|
||||
if search:
|
||||
search_filter = f"%{search}%"
|
||||
query = query.filter(
|
||||
db.or_(
|
||||
User.email.ilike(search_filter),
|
||||
User.full_name.ilike(search_filter),
|
||||
)
|
||||
)
|
||||
|
||||
query = query.order_by(User.created_at.desc())
|
||||
|
||||
total = query.count()
|
||||
users = query.offset((page - 1) * per_page).limit(per_page).all()
|
||||
|
||||
items = []
|
||||
for user in users:
|
||||
# Get org memberships
|
||||
memberships = OrganizationMember.query.filter(
|
||||
OrganizationMember.user_id == user.id,
|
||||
OrganizationMember.deleted_at.is_(None),
|
||||
).all()
|
||||
|
||||
orgs = []
|
||||
for m in memberships:
|
||||
org = Organization.query.get(m.organization_id)
|
||||
if org:
|
||||
orgs.append({
|
||||
"org_id": org.id,
|
||||
"org_name": org.name,
|
||||
"role": m.role,
|
||||
})
|
||||
|
||||
# Get active sessions count
|
||||
active_sessions = Session.query.filter(
|
||||
Session.user_id == user.id,
|
||||
Session.deleted_at.is_(None),
|
||||
Session.status == "active",
|
||||
).count()
|
||||
|
||||
items.append({
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"status": user.status,
|
||||
"org_count": len(orgs),
|
||||
"orgs": orgs,
|
||||
"active_sessions": active_sessions,
|
||||
"last_login_at": user.last_login_at.isoformat() + "Z" if user.last_login_at else None,
|
||||
"created_at": user.created_at.isoformat() + "Z" if user.created_at else None,
|
||||
})
|
||||
|
||||
return {
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"pages": (total + per_page - 1) // per_page if per_page > 0 else 0,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_user_detail(user_id: str) -> dict:
|
||||
"""Get detailed user information.
|
||||
|
||||
Args:
|
||||
user_id: User UUID
|
||||
|
||||
Returns:
|
||||
User detail with orgs, sessions, security methods
|
||||
"""
|
||||
user = User.query.get(user_id)
|
||||
if not user or user.deleted_at is not None:
|
||||
raise ValueError("User not found")
|
||||
|
||||
# Get org memberships
|
||||
memberships = OrganizationMember.query.filter(
|
||||
OrganizationMember.user_id == user_id,
|
||||
OrganizationMember.deleted_at.is_(None),
|
||||
).all()
|
||||
|
||||
orgs = []
|
||||
for m in memberships:
|
||||
org = Organization.query.get(m.organization_id)
|
||||
if org:
|
||||
orgs.append({
|
||||
"org_id": org.id,
|
||||
"org_name": org.name,
|
||||
"org_slug": org.slug,
|
||||
"role": m.role,
|
||||
"joined_at": m.created_at.isoformat() + "Z" if m.created_at else None,
|
||||
})
|
||||
|
||||
# Get active sessions
|
||||
sessions = Session.query.filter(
|
||||
Session.user_id == user_id,
|
||||
Session.deleted_at.is_(None),
|
||||
Session.status == "active",
|
||||
).all()
|
||||
|
||||
active_sessions = [{
|
||||
"id": s.id,
|
||||
"ip_address": s.ip_address,
|
||||
"user_agent": s.user_agent,
|
||||
"created_at": s.created_at.isoformat() + "Z" if s.created_at else None,
|
||||
} for s in sessions]
|
||||
|
||||
# Security methods
|
||||
security_methods = []
|
||||
if hasattr(user, 'totp_enabled') and user.totp_enabled:
|
||||
security_methods.append({"type": "totp", "enabled": True})
|
||||
if hasattr(user, 'webauthn_enabled') and user.webauthn_enabled:
|
||||
security_methods.append({"type": "webauthn", "enabled": True})
|
||||
|
||||
return {
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"full_name": user.full_name,
|
||||
"status": user.status,
|
||||
"mfa_enabled": user.mfa_enabled if hasattr(user, 'mfa_enabled') else False,
|
||||
"last_login_at": user.last_login_at.isoformat() + "Z" if user.last_login_at else None,
|
||||
"created_at": user.created_at.isoformat() + "Z" if user.created_at else None,
|
||||
},
|
||||
"organizations": orgs,
|
||||
"active_sessions": active_sessions,
|
||||
"security_methods": security_methods,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def suspend_user(user_id: str) -> dict:
|
||||
"""Globally suspend a user.
|
||||
|
||||
Args:
|
||||
user_id: User UUID
|
||||
|
||||
Returns:
|
||||
Updated user info and count of revoked sessions
|
||||
"""
|
||||
user = User.query.get(user_id)
|
||||
if not user or user.deleted_at is not None:
|
||||
raise ValueError("User not found")
|
||||
|
||||
if user.status == "GLOBAL_SUSPENDED":
|
||||
raise ValueError("User is already suspended")
|
||||
|
||||
user.status = "GLOBAL_SUSPENDED"
|
||||
db.session.commit()
|
||||
|
||||
# Revoke all sessions
|
||||
revoked_count = Session.query.filter(
|
||||
Session.user_id == user_id,
|
||||
Session.deleted_at.is_(None),
|
||||
).update({"status": "revoked", "deleted_at": db.func.now()})
|
||||
db.session.commit()
|
||||
|
||||
return {
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"status": user.status,
|
||||
},
|
||||
"sessions_revoked": revoked_count,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def unsuspend_user(user_id: str) -> dict:
|
||||
"""Remove global suspension from a user.
|
||||
|
||||
Args:
|
||||
user_id: User UUID
|
||||
|
||||
Returns:
|
||||
Updated user info
|
||||
"""
|
||||
user = User.query.get(user_id)
|
||||
if not user or user.deleted_at is not None:
|
||||
raise ValueError("User not found")
|
||||
|
||||
if user.status != "GLOBAL_SUSPENDED":
|
||||
raise ValueError("User is not suspended")
|
||||
|
||||
user.status = "active"
|
||||
db.session.commit()
|
||||
|
||||
return {
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"status": user.status,
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def reset_password(user_id: str) -> dict:
|
||||
"""Trigger password reset for user.
|
||||
|
||||
Args:
|
||||
user_id: User UUID
|
||||
|
||||
Returns:
|
||||
Email of user
|
||||
"""
|
||||
user = User.query.get(user_id)
|
||||
if not user or user.deleted_at is not None:
|
||||
raise ValueError("User not found")
|
||||
|
||||
# In production, this would call AuthService.send_password_reset_email
|
||||
logger.info(f"[SuperadminUserService] Password reset requested for {user.email}")
|
||||
|
||||
return {"email": user.email}
|
||||
|
||||
@staticmethod
|
||||
def revoke_all_sessions(user_id: str) -> dict:
|
||||
"""Revoke all sessions for a user.
|
||||
|
||||
Args:
|
||||
user_id: User UUID
|
||||
|
||||
Returns:
|
||||
Count of revoked sessions
|
||||
"""
|
||||
user = User.query.get(user_id)
|
||||
if not user or user.deleted_at is not None:
|
||||
raise ValueError("User not found")
|
||||
|
||||
result = Session.query.filter(
|
||||
Session.user_id == user_id,
|
||||
Session.deleted_at.is_(None),
|
||||
).update({"status": "revoked", "deleted_at": db.func.now()})
|
||||
db.session.commit()
|
||||
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"count": result,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def add_to_org(user_id: str, org_id: str, role: str = "member") -> dict:
|
||||
"""Add a user to an organization.
|
||||
|
||||
Args:
|
||||
user_id: User UUID
|
||||
org_id: Organization UUID
|
||||
role: Membership role
|
||||
|
||||
Returns:
|
||||
Membership details
|
||||
"""
|
||||
user = User.query.get(user_id)
|
||||
if not user or user.deleted_at is not None:
|
||||
raise ValueError("User not found")
|
||||
|
||||
org = Organization.query.get(org_id)
|
||||
if not org or org.deleted_at is not None:
|
||||
raise ValueError("Organization not found")
|
||||
|
||||
# Check if already a member
|
||||
existing = OrganizationMember.query.filter(
|
||||
OrganizationMember.user_id == user_id,
|
||||
OrganizationMember.organization_id == org_id,
|
||||
OrganizationMember.deleted_at.is_(None),
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
raise ValueError("User is already a member of this organization")
|
||||
|
||||
membership = OrganizationMember(
|
||||
user_id=user_id,
|
||||
organization_id=org_id,
|
||||
role=role,
|
||||
)
|
||||
db.session.add(membership)
|
||||
db.session.commit()
|
||||
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"organization_id": org_id,
|
||||
"role": role,
|
||||
"joined_at": membership.created_at.isoformat() + "Z" if membership.created_at else None,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def remove_from_org(user_id: str, org_id: str) -> dict:
|
||||
"""Remove a user from an organization.
|
||||
|
||||
Args:
|
||||
user_id: User UUID
|
||||
org_id: Organization UUID
|
||||
|
||||
Returns:
|
||||
Confirmation
|
||||
"""
|
||||
membership = OrganizationMember.query.filter(
|
||||
OrganizationMember.user_id == user_id,
|
||||
OrganizationMember.organization_id == org_id,
|
||||
OrganizationMember.deleted_at.is_(None),
|
||||
).first()
|
||||
|
||||
if not membership:
|
||||
raise ValueError("User is not a member of this organization")
|
||||
|
||||
# Check if user is the only owner
|
||||
if membership.role == "owner":
|
||||
owner_count = OrganizationMember.query.filter(
|
||||
OrganizationMember.organization_id == org_id,
|
||||
OrganizationMember.role == "owner",
|
||||
OrganizationMember.deleted_at.is_(None),
|
||||
).count()
|
||||
|
||||
if owner_count <= 1:
|
||||
raise ValueError("Cannot remove the only owner from an organization. Transfer ownership first.")
|
||||
|
||||
membership.deleted_at = db.func.now()
|
||||
db.session.commit()
|
||||
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"organization_id": org_id,
|
||||
}
|
||||
@@ -283,9 +283,9 @@ def reconcile_deleted_memberships() -> dict:
|
||||
if not device or not network:
|
||||
logger.warning(
|
||||
f"[Reconciliation] Membership {membership.id}: missing "
|
||||
f"{'device' if not device else 'network'} — hard-deleting record only."
|
||||
f"{'device' if not device else 'network'} — soft-deleting record only."
|
||||
)
|
||||
db.session.delete(membership)
|
||||
membership.delete(soft=True)
|
||||
db.session.commit()
|
||||
results["deleted"] += 1
|
||||
continue
|
||||
@@ -304,20 +304,20 @@ def reconcile_deleted_memberships() -> dict:
|
||||
except Exception as zt_exc:
|
||||
logger.warning(
|
||||
f"[Reconciliation] ZT delete failed for node {node_id} "
|
||||
f"on {network_label}: {zt_exc} — proceeding with DB hard-delete."
|
||||
f"on {network_label}: {zt_exc} — proceeding with DB soft-delete."
|
||||
)
|
||||
|
||||
db.session.delete(membership)
|
||||
membership.delete(soft=True)
|
||||
db.session.commit()
|
||||
results["deleted"] += 1
|
||||
logger.debug(
|
||||
f"[Reconciliation] Hard-deleted membership {membership.id} "
|
||||
f"[Reconciliation] Soft-deleted membership {membership.id} "
|
||||
f"(node={node_id}, network={network_label})."
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
f"[Reconciliation] Failed to hard-delete membership {membership.id}: {exc}",
|
||||
f"[Reconciliation] Failed to soft-delete membership {membership.id}: {exc}",
|
||||
exc_info=True,
|
||||
)
|
||||
results["errors"] += 1
|
||||
|
||||
@@ -75,6 +75,7 @@ class AuditAction(str, Enum):
|
||||
ORG_MEMBER_REMOVE = "org.member.remove"
|
||||
ORG_MEMBER_ROLE_CHANGE = "org.member.role_change"
|
||||
ORG_OWNERSHIP_TRANSFERRED = "org.ownership.transferred"
|
||||
ORG_INVITE_SENT = "org.invite.sent"
|
||||
|
||||
# Session actions
|
||||
SESSION_CREATE = "session.create"
|
||||
|
||||
@@ -59,11 +59,9 @@ def login_required(f):
|
||||
error_type="SESSION_INACTIVE"
|
||||
)
|
||||
|
||||
# Update last_activity_at timestamp
|
||||
from datetime import datetime, timezone
|
||||
session.last_activity_at = datetime.now(timezone.utc)
|
||||
from gatehouse_app import db
|
||||
db.session.commit()
|
||||
# Extend session via sliding window (updates last_activity_at
|
||||
# and recalculates expires_at within the idle / absolute caps).
|
||||
session.refresh()
|
||||
|
||||
# Set context variables
|
||||
g.current_user = session.user
|
||||
|
||||
@@ -153,6 +153,31 @@ def mfa_compliance_status():
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
@cli.command("cleanup_sessions")
|
||||
def cleanup_sessions():
|
||||
"""Clean up expired user sessions.
|
||||
|
||||
Marks sessions as EXPIRED when they have passed their expires_at
|
||||
timestamp. Safe to run frequently (e.g. every 5 minutes via job_runner).
|
||||
|
||||
Usage:
|
||||
python manage.py cleanup_sessions
|
||||
"""
|
||||
from gatehouse_app.services.session_service import SessionService
|
||||
|
||||
print("=" * 60)
|
||||
print("Session Cleanup Job")
|
||||
print("=" * 60)
|
||||
|
||||
from datetime import datetime, timezone
|
||||
print(f"Start time: {datetime.now(timezone.utc).isoformat()}")
|
||||
|
||||
count = SessionService.cleanup_expired_sessions()
|
||||
|
||||
print(f"Expired sessions marked: {count}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
@cli.command("configure_oauth")
|
||||
@click.argument("provider", required=False)
|
||||
@click.option("--client-id", default=None, help="OAuth client ID")
|
||||
|
||||
@@ -29,8 +29,7 @@ def upgrade():
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_application_provider_configs_provider_type'), 'application_provider_configs', ['provider_type'], unique=True)
|
||||
op.create_table('oidc_jwks_keys',
|
||||
@@ -63,8 +62,7 @@ def upgrade():
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_organizations_slug'), 'organizations', ['slug'], unique=True)
|
||||
op.create_table('users',
|
||||
@@ -81,8 +79,7 @@ def upgrade():
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_users_activation_key'), 'users', ['activation_key'], unique=True)
|
||||
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
|
||||
@@ -105,8 +102,7 @@ def upgrade():
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index('idx_audit_org', 'audit_logs', ['organization_id', 'created_at'], unique=False)
|
||||
op.create_index('idx_audit_resource', 'audit_logs', ['resource_type', 'resource_id'], unique=False)
|
||||
@@ -135,7 +131,6 @@ def upgrade():
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('user_id', 'method_type', 'provider_user_id', name='uix_user_method_provider')
|
||||
)
|
||||
op.create_index('idx_user_method', 'authentication_methods', ['user_id', 'method_type'], unique=False)
|
||||
@@ -165,7 +160,6 @@ def upgrade():
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('fingerprint'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('organization_id', 'name', name='uix_org_ca_name')
|
||||
)
|
||||
op.create_index('idx_ca_org_active', 'cas', ['organization_id', 'is_active'], unique=False)
|
||||
@@ -182,7 +176,6 @@ def upgrade():
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('organization_id', 'name', name='uix_org_dept_name')
|
||||
)
|
||||
op.create_index(op.f('ix_departments_name'), 'departments', ['name'], unique=False)
|
||||
@@ -202,8 +195,7 @@ def upgrade():
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_devices_node_id'), 'devices', ['node_id'], unique=False)
|
||||
op.create_index(op.f('ix_devices_organization_id'), 'devices', ['organization_id'], unique=False)
|
||||
@@ -218,8 +210,7 @@ def upgrade():
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_email_verification_tokens_token'), 'email_verification_tokens', ['token'], unique=True)
|
||||
op.create_index(op.f('ix_email_verification_tokens_user_id'), 'email_verification_tokens', ['user_id'], unique=False)
|
||||
@@ -242,7 +233,6 @@ def upgrade():
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('organization_id', 'provider_type', name='uix_org_provider_type')
|
||||
)
|
||||
op.create_index('idx_provider_config_org', 'external_provider_configs', ['organization_id', 'provider_type'], unique=False)
|
||||
@@ -262,8 +252,7 @@ def upgrade():
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.ForeignKeyConstraint(['target_user_id'], ['users.id'], ),
|
||||
sa.ForeignKeyConstraint(['triggered_by_user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_kill_switch_events_organization_id'), 'kill_switch_events', ['organization_id'], unique=False)
|
||||
op.create_index(op.f('ix_kill_switch_events_target_user_id'), 'kill_switch_events', ['target_user_id'], unique=False)
|
||||
@@ -285,7 +274,6 @@ def upgrade():
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('user_id', 'organization_id', name='uix_user_org_compliance')
|
||||
)
|
||||
op.create_index(op.f('ix_mfa_policy_compliance_organization_id'), 'mfa_policy_compliance', ['organization_id'], unique=False)
|
||||
@@ -310,8 +298,7 @@ def upgrade():
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_oauth_states_expires_at'), 'oauth_states', ['expires_at'], unique=False)
|
||||
op.create_index(op.f('ix_oauth_states_organization_id'), 'oauth_states', ['organization_id'], unique=False)
|
||||
@@ -340,8 +327,7 @@ def upgrade():
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_oidc_clients_client_id'), 'oidc_clients', ['client_id'], unique=True)
|
||||
op.create_index(op.f('ix_oidc_clients_organization_id'), 'oidc_clients', ['organization_id'], unique=False)
|
||||
@@ -359,8 +345,7 @@ def upgrade():
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['invited_by_id'], ['users.id'], ondelete='SET NULL'),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_org_invite_tokens_email'), 'org_invite_tokens', ['email'], unique=False)
|
||||
op.create_index(op.f('ix_org_invite_tokens_organization_id'), 'org_invite_tokens', ['organization_id'], unique=False)
|
||||
@@ -379,8 +364,7 @@ def upgrade():
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index('idx_api_key_last_used', 'organization_api_keys', ['last_used_at'], unique=False)
|
||||
op.create_index('idx_org_api_key_org_active', 'organization_api_keys', ['organization_id', 'is_revoked'], unique=False)
|
||||
@@ -402,7 +386,6 @@ def upgrade():
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('user_id', 'organization_id', name='uix_user_org')
|
||||
)
|
||||
op.create_index(op.f('ix_organization_members_organization_id'), 'organization_members', ['organization_id'], unique=False)
|
||||
@@ -421,7 +404,6 @@ def upgrade():
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('organization_id', 'provider_type', name='uix_org_provider_override_type')
|
||||
)
|
||||
op.create_index(op.f('ix_organization_provider_overrides_organization_id'), 'organization_provider_overrides', ['organization_id'], unique=False)
|
||||
@@ -439,8 +421,7 @@ def upgrade():
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.ForeignKeyConstraint(['updated_by_user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_organization_security_policies_organization_id'), 'organization_security_policies', ['organization_id'], unique=True)
|
||||
op.create_table('password_reset_tokens',
|
||||
@@ -453,8 +434,7 @@ def upgrade():
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_password_reset_tokens_token'), 'password_reset_tokens', ['token'], unique=True)
|
||||
op.create_index(op.f('ix_password_reset_tokens_user_id'), 'password_reset_tokens', ['user_id'], unique=False)
|
||||
@@ -476,7 +456,6 @@ def upgrade():
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.ForeignKeyConstraint(['owner_user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('organization_id', 'zerotier_network_id', name='uix_org_zt_network_id')
|
||||
)
|
||||
op.create_index(op.f('ix_portal_networks_organization_id'), 'portal_networks', ['organization_id'], unique=False)
|
||||
@@ -491,7 +470,6 @@ def upgrade():
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('organization_id', 'name', name='uix_org_principal_name')
|
||||
)
|
||||
op.create_index(op.f('ix_principals_name'), 'principals', ['name'], unique=False)
|
||||
@@ -513,8 +491,7 @@ def upgrade():
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_sessions_token'), 'sessions', ['token'], unique=True)
|
||||
op.create_index(op.f('ix_sessions_user_id'), 'sessions', ['user_id'], unique=False)
|
||||
@@ -536,7 +513,6 @@ def upgrade():
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('payload')
|
||||
)
|
||||
op.create_index('idx_ssh_key_user_verified', 'ssh_keys', ['user_id', 'verified'], unique=False)
|
||||
@@ -556,7 +532,6 @@ def upgrade():
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('user_id', 'organization_id', name='uix_user_org_policy')
|
||||
)
|
||||
op.create_index(op.f('ix_user_security_policies_organization_id'), 'user_security_policies', ['organization_id'], unique=False)
|
||||
@@ -572,8 +547,7 @@ def upgrade():
|
||||
sa.ForeignKeyConstraint(['ca_id'], ['cas.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('ca_id', 'user_id', name='uix_ca_permission'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.UniqueConstraint('ca_id', 'user_id', name='uix_ca_permission')
|
||||
)
|
||||
op.create_index(op.f('ix_ca_permissions_ca_id'), 'ca_permissions', ['ca_id'], unique=False)
|
||||
op.create_index(op.f('ix_ca_permissions_user_id'), 'ca_permissions', ['user_id'], unique=False)
|
||||
@@ -589,8 +563,7 @@ def upgrade():
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['department_id'], ['departments.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_department_cert_policies_department_id'), 'department_cert_policies', ['department_id'], unique=True)
|
||||
op.create_table('department_memberships',
|
||||
@@ -603,7 +576,6 @@ def upgrade():
|
||||
sa.ForeignKeyConstraint(['department_id'], ['departments.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('user_id', 'department_id', name='uix_user_dept')
|
||||
)
|
||||
op.create_index(op.f('ix_department_memberships_department_id'), 'department_memberships', ['department_id'], unique=False)
|
||||
@@ -618,8 +590,7 @@ def upgrade():
|
||||
sa.ForeignKeyConstraint(['department_id'], ['departments.id'], ),
|
||||
sa.ForeignKeyConstraint(['principal_id'], ['principals.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('department_id', 'principal_id', name='uix_dept_principal'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.UniqueConstraint('department_id', 'principal_id', name='uix_dept_principal')
|
||||
)
|
||||
op.create_index(op.f('ix_department_principals_department_id'), 'department_principals', ['department_id'], unique=False)
|
||||
op.create_index(op.f('ix_department_principals_principal_id'), 'department_principals', ['principal_id'], unique=False)
|
||||
@@ -640,8 +611,7 @@ def upgrade():
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['client_id'], ['oidc_clients.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_oidc_audit_logs_client_id'), 'oidc_audit_logs', ['client_id'], unique=False)
|
||||
op.create_index(op.f('ix_oidc_audit_logs_event_type'), 'oidc_audit_logs', ['event_type'], unique=False)
|
||||
@@ -668,8 +638,7 @@ def upgrade():
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['client_id'], ['oidc_clients.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_oidc_authorization_codes_client_id'), 'oidc_authorization_codes', ['client_id'], unique=False)
|
||||
op.create_index(op.f('ix_oidc_authorization_codes_expires_at'), 'oidc_authorization_codes', ['expires_at'], unique=False)
|
||||
@@ -693,8 +662,7 @@ def upgrade():
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['client_id'], ['oidc_clients.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_oidc_refresh_tokens_access_token_id'), 'oidc_refresh_tokens', ['access_token_id'], unique=False)
|
||||
op.create_index(op.f('ix_oidc_refresh_tokens_client_id'), 'oidc_refresh_tokens', ['client_id'], unique=False)
|
||||
@@ -718,8 +686,7 @@ def upgrade():
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['client_id'], ['oidc_clients.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_oidc_sessions_client_id'), 'oidc_sessions', ['client_id'], unique=False)
|
||||
op.create_index(op.f('ix_oidc_sessions_expires_at'), 'oidc_sessions', ['expires_at'], unique=False)
|
||||
@@ -755,7 +722,6 @@ def upgrade():
|
||||
sa.ForeignKeyConstraint(['principal_id'], ['principals.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('user_id', 'principal_id', name='uix_user_principal')
|
||||
)
|
||||
op.create_index(op.f('ix_principal_memberships_principal_id'), 'principal_memberships', ['principal_id'], unique=False)
|
||||
@@ -787,8 +753,7 @@ def upgrade():
|
||||
sa.ForeignKeyConstraint(['ssh_key_id'], ['ssh_keys.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('ca_id', 'serial', name='uq_ssh_certificates_ca_serial'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.UniqueConstraint('ca_id', 'serial', name='uq_ssh_certificates_ca_serial')
|
||||
)
|
||||
op.create_index('idx_cert_revoked', 'ssh_certificates', ['revoked', 'revoked_at'], unique=False)
|
||||
op.create_index('idx_cert_user_status', 'ssh_certificates', ['user_id', 'status'], unique=False)
|
||||
@@ -816,7 +781,6 @@ def upgrade():
|
||||
sa.ForeignKeyConstraint(['portal_network_id'], ['portal_networks.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('user_id', 'portal_network_id', 'deleted_at', name='uix_user_network_approval')
|
||||
)
|
||||
op.create_index(op.f('ix_user_network_approvals_organization_id'), 'user_network_approvals', ['organization_id'], unique=False)
|
||||
@@ -840,8 +804,7 @@ def upgrade():
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['certificate_id'], ['ssh_certificates.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index('idx_cert_audit_cert_action', 'certificate_audit_logs', ['certificate_id', 'action'], unique=False)
|
||||
op.create_index('idx_cert_audit_user', 'certificate_audit_logs', ['user_id', 'created_at'], unique=False)
|
||||
@@ -868,8 +831,7 @@ def upgrade():
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_network_approval_id'], ['user_network_approvals.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('device_id', 'portal_network_id', 'deleted_at', name='uix_device_network'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.UniqueConstraint('device_id', 'portal_network_id', 'deleted_at', name='uix_device_network')
|
||||
)
|
||||
op.create_index(op.f('ix_device_network_memberships_device_id'), 'device_network_memberships', ['device_id'], unique=False)
|
||||
op.create_index(op.f('ix_device_network_memberships_organization_id'), 'device_network_memberships', ['organization_id'], unique=False)
|
||||
@@ -894,8 +856,7 @@ def upgrade():
|
||||
sa.ForeignKeyConstraint(['device_network_membership_id'], ['device_network_memberships.id'], ),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id')
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_activation_sessions_device_network_membership_id'), 'activation_sessions', ['device_network_membership_id'], unique=False)
|
||||
op.create_index(op.f('ix_activation_sessions_organization_id'), 'activation_sessions', ['organization_id'], unique=False)
|
||||
@@ -917,7 +878,6 @@ def upgrade():
|
||||
sa.ForeignKeyConstraint(['device_network_membership_id'], ['device_network_memberships.id'], ),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('id'),
|
||||
sa.UniqueConstraint('zerotier_network_id', 'node_id', name='uix_zt_network_node')
|
||||
)
|
||||
op.create_index(op.f('ix_zerotier_memberships_device_network_membership_id'), 'zerotier_memberships', ['device_network_membership_id'], unique=False)
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
"""Add organization_id to certificate_audit_logs.
|
||||
|
||||
Revision ID: 8f2d9e4a7c1b
|
||||
Revises: b4cd6c6b3b1c
|
||||
Create Date: 2026-04-23 07:30:00.000000
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '8f2d9e4a7c1b'
|
||||
down_revision = 'b4cd6c6b3b1c'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Add organization_id column to certificate_audit_logs
|
||||
op.add_column(
|
||||
'certificate_audit_logs',
|
||||
sa.Column('organization_id', sa.String(length=36), nullable=True)
|
||||
)
|
||||
|
||||
# Create index on organization_id
|
||||
op.create_index(
|
||||
op.f('ix_certificate_audit_logs_organization_id'),
|
||||
'certificate_audit_logs',
|
||||
['organization_id']
|
||||
)
|
||||
|
||||
# Create foreign key constraint
|
||||
op.create_foreign_key(
|
||||
'fk_cert_audit_log_organization',
|
||||
'certificate_audit_logs',
|
||||
'organizations',
|
||||
['organization_id'],
|
||||
['id']
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
# Drop foreign key constraint
|
||||
op.drop_constraint('fk_cert_audit_log_organization', 'certificate_audit_logs', type_='foreignkey')
|
||||
|
||||
# Drop index
|
||||
op.drop_index(op.f('ix_certificate_audit_logs_organization_id'), 'certificate_audit_logs')
|
||||
|
||||
# Drop organization_id column
|
||||
op.drop_column('certificate_audit_logs', 'organization_id')
|
||||
@@ -0,0 +1,44 @@
|
||||
"""Per-user SSH key fingerprint uniqueness.
|
||||
|
||||
Revision ID: a1b2c3d4e5f6
|
||||
Revises: 8f2d9e4a7c1b
|
||||
Create Date: 2026-04-24 10:00:00.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'a1b2c3d4e5f6'
|
||||
down_revision = '8f2d9e4a7c1b'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Drop the global unique constraint on payload
|
||||
op.drop_constraint('ssh_keys_payload_key', 'ssh_keys', type_='unique')
|
||||
|
||||
# Drop the global unique index on fingerprint
|
||||
op.drop_index('ix_ssh_keys_fingerprint', table_name='ssh_keys')
|
||||
|
||||
# Create a non-unique index on fingerprint for query performance
|
||||
op.create_index(op.f('ix_ssh_keys_fingerprint'), 'ssh_keys', ['fingerprint'], unique=False)
|
||||
|
||||
# Add composite unique constraint for per-user fingerprint uniqueness
|
||||
op.create_unique_constraint('uix_user_fingerprint', 'ssh_keys', ['user_id', 'fingerprint'])
|
||||
|
||||
|
||||
def downgrade():
|
||||
# Drop the composite unique constraint
|
||||
op.drop_constraint('uix_user_fingerprint', 'ssh_keys', type_='unique')
|
||||
|
||||
# Drop the non-unique index
|
||||
op.drop_index(op.f('ix_ssh_keys_fingerprint'), table_name='ssh_keys')
|
||||
|
||||
# Recreate the global unique index
|
||||
op.create_index('ix_ssh_keys_fingerprint', 'ssh_keys', ['fingerprint'], unique=True)
|
||||
|
||||
# Recreate the global unique constraint on payload
|
||||
op.create_unique_constraint('ssh_keys_payload_key', 'ssh_keys', ['payload'])
|
||||
@@ -0,0 +1,100 @@
|
||||
"""Superadmin
|
||||
|
||||
Revision ID: b4cd6c6b3b1c
|
||||
Revises: 6a4c4ed4a5c6
|
||||
Create Date: 2026-04-08 16:55:52.646980
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'b4cd6c6b3b1c'
|
||||
down_revision = '6a4c4ed4a5c6'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# --- Create superadmin tables (not captured by auto-generation) ---
|
||||
op.create_table(
|
||||
'superadmins',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('email', sa.String(length=255), nullable=False),
|
||||
sa.Column('password_hash', sa.String(length=255), nullable=False),
|
||||
sa.Column('full_name', sa.String(length=255), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False, server_default=sa.text('true')),
|
||||
sa.Column('last_login_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
op.create_index(op.f('ix_superadmins_email'), 'superadmins', ['email'], unique=True)
|
||||
|
||||
op.create_table(
|
||||
'superadmin_sessions',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('superadmin_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('token', sa.String(length=255), nullable=False),
|
||||
sa.Column('expires_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('last_activity_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('ip_address', sa.String(length=45), nullable=True),
|
||||
sa.Column('user_agent', sa.Text(), nullable=True),
|
||||
sa.Column('revoked_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('revoked_reason', sa.String(length=255), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['superadmin_id'], ['superadmins.id']),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
op.create_index(op.f('ix_superadmin_sessions_superadmin_id'), 'superadmin_sessions', ['superadmin_id'])
|
||||
op.create_index(op.f('ix_superadmin_sessions_token'), 'superadmin_sessions', ['token'], unique=True)
|
||||
|
||||
op.create_table(
|
||||
'superadmin_audit_logs',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('superadmin_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('action', sa.String(length=100), nullable=False),
|
||||
sa.Column('resource_type', sa.String(length=50), nullable=False),
|
||||
sa.Column('resource_id', sa.String(length=36), nullable=True),
|
||||
sa.Column('org_id', sa.String(length=36), nullable=True),
|
||||
sa.Column('user_id', sa.String(length=36), nullable=True),
|
||||
sa.Column('ip_address', sa.String(length=45), nullable=True),
|
||||
sa.Column('user_agent', sa.Text(), nullable=True),
|
||||
sa.Column('request_id', sa.String(length=100), nullable=True),
|
||||
sa.Column('extra_data', sa.JSON(), nullable=True),
|
||||
sa.Column('success', sa.Boolean(), nullable=False, server_default=sa.text('true')),
|
||||
sa.Column('error_message', sa.String(length=500), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['superadmin_id'], ['superadmins.id']),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
op.create_index(op.f('ix_superadmin_audit_logs_superadmin_id'), 'superadmin_audit_logs', ['superadmin_id'])
|
||||
op.create_index(op.f('ix_superadmin_audit_logs_action'), 'superadmin_audit_logs', ['action'])
|
||||
op.create_index(op.f('ix_superadmin_audit_logs_resource_type'), 'superadmin_audit_logs', ['resource_type'])
|
||||
op.create_index(op.f('ix_superadmin_audit_logs_resource_id'), 'superadmin_audit_logs', ['resource_id'])
|
||||
op.create_index(op.f('ix_superadmin_audit_logs_org_id'), 'superadmin_audit_logs', ['org_id'])
|
||||
op.create_index(op.f('ix_superadmin_audit_logs_user_id'), 'superadmin_audit_logs', ['user_id'])
|
||||
|
||||
|
||||
def downgrade():
|
||||
# --- Drop superadmin tables (reverse order due to FK dependencies) ---
|
||||
op.drop_index(op.f('ix_superadmin_audit_logs_user_id'), table_name='superadmin_audit_logs')
|
||||
op.drop_index(op.f('ix_superadmin_audit_logs_org_id'), table_name='superadmin_audit_logs')
|
||||
op.drop_index(op.f('ix_superadmin_audit_logs_resource_id'), table_name='superadmin_audit_logs')
|
||||
op.drop_index(op.f('ix_superadmin_audit_logs_resource_type'), table_name='superadmin_audit_logs')
|
||||
op.drop_index(op.f('ix_superadmin_audit_logs_action'), table_name='superadmin_audit_logs')
|
||||
op.drop_index(op.f('ix_superadmin_audit_logs_superadmin_id'), table_name='superadmin_audit_logs')
|
||||
op.drop_table('superadmin_audit_logs')
|
||||
|
||||
op.drop_index(op.f('ix_superadmin_sessions_token'), table_name='superadmin_sessions')
|
||||
op.drop_index(op.f('ix_superadmin_sessions_superadmin_id'), table_name='superadmin_sessions')
|
||||
op.drop_table('superadmin_sessions')
|
||||
|
||||
op.drop_index(op.f('ix_superadmins_email'), table_name='superadmins')
|
||||
op.drop_table('superadmins')
|
||||
@@ -6,6 +6,9 @@ python_functions = test_*
|
||||
addopts =
|
||||
-v
|
||||
--strict-markers
|
||||
-p no:maas-django
|
||||
-p no:maas-perftest
|
||||
-p no:maas-seeds
|
||||
--cov=gatehouse_app
|
||||
--cov-report=term-missing
|
||||
--cov-report=html
|
||||
|
||||
Executable
+106
@@ -0,0 +1,106 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generic job runner for scheduled tasks in Docker containers.
|
||||
|
||||
Runs a Flask CLI command on a configurable interval with graceful shutdown support.
|
||||
|
||||
Environment Variables:
|
||||
JOB_NAME: Name of the job to run (zerotier_reconciliation, mfa_compliance)
|
||||
JOB_INTERVAL_SECONDS: Seconds between job runs (default: 300)
|
||||
|
||||
Usage:
|
||||
docker run -e JOB_NAME=zerotier_reconciliation -e JOB_INTERVAL_SECONDS=120 app
|
||||
"""
|
||||
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
JOB_COMMANDS = {
|
||||
"zerotier_reconciliation": "python manage.py run_zerotier_reconciliation",
|
||||
"mfa_compliance": "python manage.py run_mfa_compliance_job",
|
||||
"cleanup_sessions": "python manage.py cleanup_sessions",
|
||||
}
|
||||
|
||||
shutdown_requested = False
|
||||
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
global shutdown_requested
|
||||
logger.info(f"Received signal {signum}, initiating graceful shutdown...")
|
||||
shutdown_requested = True
|
||||
|
||||
|
||||
def run_job(job_name: str) -> bool:
|
||||
command = JOB_COMMANDS.get(job_name)
|
||||
if not command:
|
||||
logger.error(f"Unknown job: {job_name}. Valid jobs: {list(JOB_COMMANDS.keys())}")
|
||||
return False
|
||||
|
||||
logger.info(f"Running job: {job_name}")
|
||||
start_time = time.monotonic()
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
command,
|
||||
shell=True,
|
||||
cwd="/app",
|
||||
capture_output=False,
|
||||
)
|
||||
elapsed = time.monotonic() - start_time
|
||||
logger.info(f"Job {job_name} completed in {elapsed:.2f}s with exit code {result.returncode}")
|
||||
return result.returncode == 0
|
||||
except Exception as e:
|
||||
elapsed = time.monotonic() - start_time
|
||||
logger.error(f"Job {job_name} failed after {elapsed:.2f}s: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
job_name = os.getenv("JOB_NAME")
|
||||
interval = int(os.getenv("JOB_INTERVAL_SECONDS", "300"))
|
||||
|
||||
if not job_name:
|
||||
logger.error("JOB_NAME environment variable is required")
|
||||
sys.exit(1)
|
||||
|
||||
if job_name not in JOB_COMMANDS:
|
||||
logger.error(f"Unknown JOB_NAME: {job_name}. Valid: {list(JOB_COMMANDS.keys())}")
|
||||
sys.exit(1)
|
||||
|
||||
if interval < 10:
|
||||
logger.error(f"JOB_INTERVAL_SECONDS must be at least 10 seconds, got {interval}")
|
||||
sys.exit(1)
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
logger.info(f"Job runner started: {job_name}, interval={interval}s")
|
||||
logger.info(f"Valid jobs: {list(JOB_COMMANDS.keys())}")
|
||||
|
||||
while not shutdown_requested:
|
||||
run_job(job_name)
|
||||
|
||||
if shutdown_requested:
|
||||
break
|
||||
|
||||
logger.info(f"Sleeping for {interval}s until next run...")
|
||||
|
||||
sleep_start = time.monotonic()
|
||||
while time.monotonic() - sleep_start < interval and not shutdown_requested:
|
||||
time.sleep(1)
|
||||
|
||||
logger.info("Job runner stopped")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,89 @@
|
||||
"""Seed script for creating superadmin user.
|
||||
|
||||
This script reads SUPERADMIN_EMAIL and SUPERADMIN_SECRET environment variables.
|
||||
If both are present:
|
||||
- Creates a superadmin with the given email and hashed credential
|
||||
- Idempotent: updates existing superadmin if email already exists
|
||||
If either is absent:
|
||||
- No-op (logs that env vars are not set)
|
||||
|
||||
Usage:
|
||||
export SUPERADMIN_EMAIL="admin@example.com"
|
||||
export SUPERADMIN_SECRET="[SetYourSecureSecretHere]"
|
||||
python scripts/seed_superadmin.py
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables FIRST before any app imports
|
||||
load_dotenv()
|
||||
|
||||
from gatehouse_app import create_app
|
||||
from gatehouse_app.extensions import db, bcrypt
|
||||
from gatehouse_app.models.superadmin import Superadmin
|
||||
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
app = create_app()
|
||||
|
||||
with app.app_context():
|
||||
# Read environment variables
|
||||
email = os.environ.get("SUPERADMIN_EMAIL")
|
||||
credential = os.environ.get("SUPERADMIN_SECRET")
|
||||
|
||||
# Check if env vars are set
|
||||
if not email or not credential:
|
||||
logger.info("SUPERADMIN_EMAIL and/or SUPERADMIN_SECRET not set. No-op.")
|
||||
logger.info("To create a superadmin, set both environment variables:")
|
||||
logger.info(" export SUPERADMIN_EMAIL='admin@example.com'")
|
||||
logger.info(" export SUPERADMIN_SECRET='[SetYourSecureSecretHere]'")
|
||||
return
|
||||
|
||||
# Normalize email
|
||||
email = email.lower().strip()
|
||||
|
||||
logger.info(f"Processing superadmin: {email}")
|
||||
|
||||
# Check if superadmin already exists
|
||||
existing = Superadmin.query.filter_by(email=email).first()
|
||||
|
||||
if existing:
|
||||
# Update existing superadmin
|
||||
logger.info(" → Existing superadmin found, updating credential")
|
||||
existing.password_hash = bcrypt.generate_password_hash(credential).decode("utf-8")
|
||||
existing.is_active = True
|
||||
existing.full_name = existing.full_name or "Super Admin"
|
||||
db.session.commit()
|
||||
logger.info(f" → Updated superadmin: {email} (id={existing.id})")
|
||||
print(f"Updated existing superadmin: {email}")
|
||||
else:
|
||||
# Create new superadmin
|
||||
logger.info(" → Creating new superadmin")
|
||||
password_hash = bcrypt.generate_password_hash(credential).decode("utf-8")
|
||||
superadmin = Superadmin(
|
||||
email=email,
|
||||
password_hash=password_hash,
|
||||
full_name="Super Admin",
|
||||
is_active=True,
|
||||
)
|
||||
db.session.add(superadmin)
|
||||
db.session.commit()
|
||||
logger.info(f" → Created superadmin: {email} (id={superadmin.id})")
|
||||
print(f"Created new superadmin: {email}")
|
||||
|
||||
logger.info("Seed complete.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+228
@@ -0,0 +1,228 @@
|
||||
# Secuird Integration Test Suite
|
||||
|
||||
This directory contains the integration test suite for the Secuird IAM platform.
|
||||
|
||||
## Quick Start
|
||||
|
||||
Run all integration tests:
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
pytest tests/integration/
|
||||
```
|
||||
|
||||
Run a specific test file:
|
||||
|
||||
```bash
|
||||
pytest tests/integration/test_ssh_workflows.py -v
|
||||
```
|
||||
|
||||
Run without coverage (faster):
|
||||
|
||||
```bash
|
||||
pytest tests/integration/ --no-cov
|
||||
```
|
||||
|
||||
Fail fast (stop on first failure):
|
||||
|
||||
```bash
|
||||
pytest tests/integration/ -x
|
||||
```
|
||||
|
||||
Run previously failed tests first:
|
||||
|
||||
```bash
|
||||
pytest tests/integration/ --ff
|
||||
```
|
||||
|
||||
## Test Structure
|
||||
|
||||
```
|
||||
tests/
|
||||
├── conftest.py # Base pytest fixtures (app, client, test_user)
|
||||
├── integration/ # Integration tests
|
||||
│ ├── conftest.py # Integration-specific fixtures and factories
|
||||
│ ├── client/ # Reusable API client library
|
||||
│ │ ├── base.py # SecuirdClient with session management
|
||||
│ │ ├── auth.py # Authentication operations
|
||||
│ │ ├── users.py # User self-service operations
|
||||
│ │ ├── orgs.py # Organization operations
|
||||
│ │ ├── ssh.py # SSH key/cert operations
|
||||
│ │ ├── mfa.py # TOTP/WebAuthn operations
|
||||
│ │ ├── zerotier.py # ZeroTier network operations
|
||||
│ │ └── admin.py # Admin operations
|
||||
│ ├── fixtures/ # Test data and helpers
|
||||
│ │ ├── ssh_keys.py # Test SSH key pairs and helpers
|
||||
│ │ └── test_data.py # Common test data generators
|
||||
│ ├── test_auth_flows.py # Authentication flows (24 tests)
|
||||
│ ├── test_totp_workflows.py # TOTP MFA flows (15 tests)
|
||||
│ ├── test_ssh_workflows.py # SSH key/cert flows (34 tests)
|
||||
│ ├── test_org_workflows.py # Organization & invite flows (27 tests)
|
||||
│ ├── test_multi_org.py # Multi-organization access (4 tests)
|
||||
│ ├── test_self_service.py # User self-service features (9 tests)
|
||||
│ ├── test_admin_ops.py # Admin user management (9 tests)
|
||||
│ ├── test_authorization.py # RBAC & access control (8 tests)
|
||||
│ ├── test_security.py # Security & edge cases (5 tests)
|
||||
│ ├── test_dept_principal.py # Department & principal management (5 tests)
|
||||
│ ├── test_ca_management.py # Certificate authority management (4 tests)
|
||||
│ ├── test_policy_compliance.py # Security policy & compliance (4 tests)
|
||||
│ ├── test_webauthn_workflows.py# WebAuthn passkey flows (5 tests)
|
||||
│ └── test_zerotier.py # ZeroTier network access (8 tests)
|
||||
└── unit/ # Unit tests (existing)
|
||||
```
|
||||
|
||||
## Environment
|
||||
|
||||
- **Python**: 3.10+
|
||||
- **Database**: SQLite in-memory (`sqlite:///:memory:`)
|
||||
- **Rate Limiting**: Disabled in tests (`RATELIMIT_ENABLED = False`)
|
||||
- **CSRF**: Disabled (`WTF_CSRF_ENABLED = False`)
|
||||
- **Email**: Suppressed (`MAIL_SUPPRESS_SEND = True`)
|
||||
|
||||
## Configuration
|
||||
|
||||
The `pytest.ini` file configures:
|
||||
|
||||
- Verbose output (`-v`)
|
||||
- Coverage reporting (`--cov=gatehouse_app`)
|
||||
- Disabled maas plugins that cause import errors (see Known Issues below)
|
||||
- Custom markers for `unit`, `integration`, `slow`, etc.
|
||||
|
||||
## Coverage
|
||||
|
||||
Coverage reports are generated automatically:
|
||||
|
||||
- **Terminal**: printed after each run
|
||||
- **HTML**: `backend/htmlcov/index.html`
|
||||
|
||||
Target coverage: **85% minimum**.
|
||||
|
||||
```bash
|
||||
pytest tests/integration/ --cov=gatehouse_app --cov-fail-under=85
|
||||
```
|
||||
|
||||
## Known Issues
|
||||
|
||||
### maastesting Plugin Import Error
|
||||
|
||||
The `maas` system package installs pytest entry points that fail to load in this environment. The `pytest.ini` file disables them automatically with:
|
||||
|
||||
```ini
|
||||
-p no:maas-django
|
||||
-p no:maas-perftest
|
||||
-p no:maas-seeds
|
||||
```
|
||||
|
||||
If you see `ModuleNotFoundError: No module named 'maastesting'`, these flags are not being applied. Ensure you run pytest from the `backend/` directory.
|
||||
|
||||
### ssh-keygen Not Available
|
||||
|
||||
One test (`test_verify_key_positive` in `test_ssh_workflows.py`) requires `ssh-keygen` to generate real Ed25519 key pairs for signature verification. It is automatically skipped when `ssh-keygen` is not available:
|
||||
|
||||
```bash
|
||||
sudo apt-get install openssh-client # Debian/Ubuntu
|
||||
```
|
||||
|
||||
Other certificate signing tests use a DB helper (`_mark_key_verified`) to bypass the signature requirement in CI environments.
|
||||
|
||||
## Writing New Tests
|
||||
|
||||
### Pattern
|
||||
|
||||
Every test must include a verbose docstring with `WHAT`, `WHY`, and `EXPECTED`:
|
||||
|
||||
```python
|
||||
def test_add_key_positive(self, integration_client, create_test_user):
|
||||
"""TEST: SSH-KEY-01 — Add a new SSH public key.
|
||||
|
||||
WHAT: Authenticated user POSTs a valid public key with a description.
|
||||
WHY: Users must be able to register their SSH keys for later
|
||||
certificate signing and server access.
|
||||
EXPECTED: 201 Created, response contains key id and metadata.
|
||||
"""
|
||||
```
|
||||
|
||||
### Fixtures
|
||||
|
||||
| Fixture | Purpose |
|
||||
|---------|---------|
|
||||
| `integration_client` | Fresh `SecuirdClient` instance per test |
|
||||
| `create_test_user` | Factory returning `{"id", "email", "password", "full_name"}` |
|
||||
| `create_test_org` | Factory returning `{"id", "name", "slug"}` |
|
||||
| `create_test_membership` | Links user to org with a role |
|
||||
| `create_test_ca` | Creates a Certificate Authority for an org |
|
||||
|
||||
### Client Usage
|
||||
|
||||
```python
|
||||
# Authentication
|
||||
integration_client.auth.register(email, password, full_name)
|
||||
integration_client.auth.login(email, password)
|
||||
integration_client.auth.logout()
|
||||
|
||||
# SSH
|
||||
integration_client.ssh.add_key(public_key, description)
|
||||
integration_client.ssh.sign_certificate(key_id=key_id, principals=["deploy"])
|
||||
integration_client.ssh.revoke_certificate(cert_id)
|
||||
|
||||
# Organizations
|
||||
integration_client.orgs.create(name, slug)
|
||||
integration_client.orgs.create_principal(org_id, name)
|
||||
integration_client.orgs.create_ca(org_id, name, ca_type="user")
|
||||
```
|
||||
|
||||
### Assertions
|
||||
|
||||
Use the standard helpers:
|
||||
|
||||
```python
|
||||
def assert_success(response: dict, message_contains: str = "") -> dict:
|
||||
data = response.get("data", {})
|
||||
assert response.get("success") is not False
|
||||
if message_contains:
|
||||
assert message_contains.lower() in response.get("message", "").lower()
|
||||
return data
|
||||
|
||||
# Negative tests
|
||||
with pytest.raises(ApiError) as exc_info:
|
||||
integration_client.ssh.get_key(str(uuid.uuid4()))
|
||||
assert exc_info.value.status_code == 404
|
||||
assert exc_info.value.error_type == "NOT_FOUND"
|
||||
```
|
||||
|
||||
## Test Counts
|
||||
|
||||
| Module | Tests | Focus |
|
||||
|--------|-------|-------|
|
||||
| test_auth_flows.py | 24 | Registration, login, logout, sessions, password reset, email verification |
|
||||
| test_totp_workflows.py | 15 | TOTP enrollment, verification, backup codes, disable, regenerate |
|
||||
| test_ssh_workflows.py | 34 | Key CRUD, verification, certificate signing & management |
|
||||
| test_org_workflows.py | 27 | Org CRUD, members, roles, invites, ownership transfer |
|
||||
| test_multi_org.py | 4 | Cross-org isolation, role-based access |
|
||||
| test_self_service.py | 9 | Profile, password change, account deletion |
|
||||
| test_admin_ops.py | 9 | Suspend, unsuspend, verify email, set password, remove MFA, hard delete |
|
||||
| test_authorization.py | 8 | RBAC, cross-user isolation, soft-delete behavior |
|
||||
| test_security.py | 5 | SQL injection, XSS, oversized payload, malformed JSON, empty body |
|
||||
| test_dept_principal.py | 5 | Department/principal CRUD, membership, linking |
|
||||
| test_ca_management.py | 4 | CA creation, listing, rotation |
|
||||
| test_policy_compliance.py | 4 | Security policy, MFA compliance |
|
||||
| test_webauthn_workflows.py | 5 | WebAuthn registration/login (mocked) |
|
||||
| test_zerotier.py | 8 | Network CRUD, devices, approvals, memberships (mocked) |
|
||||
| **Total** | **162** | |
|
||||
|
||||
## Pre-Commit Checklist
|
||||
|
||||
Before committing backend changes:
|
||||
|
||||
1. Run the integration suite: `pytest tests/integration/ -x`
|
||||
2. Verify coverage hasn't decreased: `pytest tests/integration/ --cov=gatehouse_app --cov-fail-under=85`
|
||||
3. If tests fail, fix before committing
|
||||
|
||||
## CI/CD
|
||||
|
||||
Integration tests run automatically on:
|
||||
- Every pull request
|
||||
- Every push to main
|
||||
- Nightly builds
|
||||
|
||||
**Failure policy**: Integration test failures block merging.
|
||||
@@ -0,0 +1 @@
|
||||
# Tests package
|
||||
@@ -0,0 +1 @@
|
||||
# API tests package
|
||||
@@ -0,0 +1 @@
|
||||
# API v1 tests package
|
||||
@@ -0,0 +1,52 @@
|
||||
import pytest
|
||||
from gatehouse_app.utils.constants import AuthMethodType
|
||||
from gatehouse_app.services.external_auth.models import ExternalAuthError
|
||||
from gatehouse_app.api.v1.external_auth._helpers import (
|
||||
get_provider_type,
|
||||
_get_provider_endpoints,
|
||||
)
|
||||
|
||||
|
||||
class TestProviderType:
|
||||
def test_google(self):
|
||||
assert get_provider_type("google") == AuthMethodType.GOOGLE
|
||||
|
||||
def test_github(self):
|
||||
assert get_provider_type("github") == AuthMethodType.GITHUB
|
||||
|
||||
def test_microsoft(self):
|
||||
assert get_provider_type("microsoft") == AuthMethodType.MICROSOFT
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert get_provider_type("GitHub") == AuthMethodType.GITHUB
|
||||
|
||||
def test_unknown_provider_raises(self):
|
||||
with pytest.raises(ExternalAuthError) as exc_info:
|
||||
get_provider_type("facebook")
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "facebook" in exc_info.value.message.lower()
|
||||
|
||||
|
||||
class TestProviderEndpoints:
|
||||
def test_google_endpoints(self):
|
||||
auth, token, userinfo = _get_provider_endpoints(AuthMethodType.GOOGLE)
|
||||
assert "accounts.google.com" in auth
|
||||
assert "oauth2.googleapis.com" in token
|
||||
assert "googleapis.com" in userinfo
|
||||
|
||||
def test_github_endpoints(self):
|
||||
auth, token, userinfo = _get_provider_endpoints(AuthMethodType.GITHUB)
|
||||
assert "github.com/login" in auth
|
||||
assert "github.com/login/oauth/access_token" in token
|
||||
assert "api.github.com/user" in userinfo
|
||||
|
||||
def test_microsoft_endpoints(self):
|
||||
auth, token, userinfo = _get_provider_endpoints(AuthMethodType.MICROSOFT)
|
||||
assert "login.microsoftonline.com" in auth
|
||||
assert "login.microsoftonline.com" in token
|
||||
assert "graph.microsoft.com" in userinfo
|
||||
|
||||
def test_unknown_type_raises(self):
|
||||
with pytest.raises(ExternalAuthError) as exc_info:
|
||||
_get_provider_endpoints("nonexistent")
|
||||
assert exc_info.value.status_code == 400
|
||||
@@ -0,0 +1,59 @@
|
||||
import pytest
|
||||
from gatehouse_app.api.v1.organizations._helpers import _get_system_ca_dict
|
||||
from gatehouse_app.config.ssh_ca_config import SSHCAConfig, reset_config_instance
|
||||
|
||||
# Ed25519 key fixture data
|
||||
VALID_PRIVATE_KEY = (
|
||||
"-----BEGIN OPENSSH PRIVATE KEY-----\n"
|
||||
"b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\n"
|
||||
"QyNTUxOQAAACCi+2CgIPgoFL5P6DZlNXztuHy3+TuS2shh/xIDkW89OgAAAJhDQd+ZQ0Hf\n"
|
||||
"mQAAAAtzc2gtZWQyNTUxOQAAACCi+2CgIPgoFL5P6DZlNXztuHy3+TuS2shh/xIDkW89Og\n"
|
||||
"AAAECMbnF+1E22w9Z1AOTUbUGspL8Pb0UyP+p8lSLpAwZSpaL7YKAg+CgUvk/oNmU1fO24\n"
|
||||
"fLf5O5LayGH/EgORbz06AAAAD2NvcnlAbGFwdG9wLXZtMQECAwQFBg==\n"
|
||||
"-----END OPENSSH PRIVATE KEY-----"
|
||||
)
|
||||
|
||||
|
||||
class FakeEmptyConfig(SSHCAConfig):
|
||||
def get_str(self, key, default=""):
|
||||
if key == "ca_key_path":
|
||||
return ""
|
||||
return default
|
||||
|
||||
|
||||
class BadConfig(SSHCAConfig):
|
||||
def get_str(self, key, default=""):
|
||||
raise RuntimeError("config error")
|
||||
|
||||
|
||||
class TestSystemCADict:
|
||||
|
||||
def test_no_key_available_returns_none(self, monkeypatch):
|
||||
monkeypatch.delenv("SSH_CA_PRIVATE_KEY", raising=False)
|
||||
reset_config_instance()
|
||||
monkeypatch.setattr(
|
||||
"gatehouse_app.config.ssh_ca_config.get_ssh_ca_config",
|
||||
lambda: FakeEmptyConfig(),
|
||||
)
|
||||
result = _get_system_ca_dict()
|
||||
assert result is None
|
||||
|
||||
def test_env_var_returns_dict(self, monkeypatch):
|
||||
monkeypatch.setenv("SSH_CA_PRIVATE_KEY", VALID_PRIVATE_KEY)
|
||||
result = _get_system_ca_dict()
|
||||
assert result is not None
|
||||
assert result["ca_type"] == "user"
|
||||
assert result["is_system"] is True
|
||||
assert "fingerprint" in result
|
||||
assert result["public_key"]
|
||||
assert result["public_key"].startswith("ssh-")
|
||||
|
||||
def test_exception_gracefully_returns_none(self, monkeypatch):
|
||||
monkeypatch.delenv("SSH_CA_PRIVATE_KEY", raising=False)
|
||||
reset_config_instance()
|
||||
monkeypatch.setattr(
|
||||
"gatehouse_app.config.ssh_ca_config.get_ssh_ca_config",
|
||||
lambda: BadConfig(),
|
||||
)
|
||||
result = _get_system_ca_dict()
|
||||
assert result is None
|
||||
@@ -0,0 +1 @@
|
||||
# SSH tests package
|
||||
@@ -0,0 +1,79 @@
|
||||
"""Pytest fixtures for API tests."""
|
||||
import pytest
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from gatehouse_app import create_app, db
|
||||
from gatehouse_app.models.user.user import User
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.ssh_ca.ca import CA, CaType, KeyType
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""Create test Flask app with in-memory SQLite."""
|
||||
app = create_app(config_name="testing")
|
||||
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
|
||||
app.config["TESTING"] = True
|
||||
app.config["WTF_CSRF_ENABLED"] = False
|
||||
|
||||
with app.app_context():
|
||||
db.create_all()
|
||||
yield app
|
||||
db.session.remove()
|
||||
db.drop_all()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user(app):
|
||||
"""Create a test user."""
|
||||
with app.app_context():
|
||||
user = User(email="test_user@test.com", full_name="Test User")
|
||||
db.session.add(user)
|
||||
db.session.commit()
|
||||
return user.id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_org(app):
|
||||
"""Create a test organization."""
|
||||
with app.app_context():
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
db.session.add(org)
|
||||
db.session.commit()
|
||||
return org.id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_membership(app, test_user, test_org):
|
||||
"""Create a test membership."""
|
||||
with app.app_context():
|
||||
membership = OrganizationMember(
|
||||
user_id=test_user,
|
||||
organization_id=test_org,
|
||||
role=OrganizationRole.MEMBER,
|
||||
)
|
||||
db.session.add(membership)
|
||||
db.session.commit()
|
||||
return membership.id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_ca(app, test_org, test_membership):
|
||||
"""Create a test CA."""
|
||||
with app.app_context():
|
||||
ca = CA(
|
||||
organization_id=test_org,
|
||||
name="Test CA",
|
||||
ca_type=CaType.USER,
|
||||
key_type=KeyType.ED25519,
|
||||
private_key="encrypted_private_key_placeholder",
|
||||
public_key="ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAI...",
|
||||
fingerprint="sha256:TEST123...",
|
||||
is_active=True,
|
||||
)
|
||||
db.session.add(ca)
|
||||
db.session.commit()
|
||||
return ca.id
|
||||
@@ -0,0 +1,143 @@
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.user.user import User
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.ssh_ca.ca import CA, CaType, KeyType
|
||||
from gatehouse_app.api.v1.ssh._helpers import _get_org_ca_for_user
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
|
||||
|
||||
class TestCASoftDelete:
|
||||
"""Test CA soft delete handling."""
|
||||
|
||||
def test_active_ca_is_returned(self, app, test_user, test_org, test_ca, test_membership):
|
||||
"""Active CA should be returned."""
|
||||
with app.app_context():
|
||||
user = db.session.get(User, test_user)
|
||||
ca = _get_org_ca_for_user(user, ca_type='user')
|
||||
assert ca is not None
|
||||
assert ca.id == test_ca
|
||||
|
||||
def test_deleted_ca_is_not_returned(self, app, test_user, test_org, test_membership):
|
||||
"""Deleted CA should not be returned."""
|
||||
with app.app_context():
|
||||
ca = CA(
|
||||
organization_id=test_org,
|
||||
name='Deleted CA',
|
||||
ca_type=CaType.USER,
|
||||
key_type=KeyType.ED25519,
|
||||
private_key='key',
|
||||
public_key='pubkey',
|
||||
fingerprint='sha256:deleted123',
|
||||
is_active=True,
|
||||
deleted_at=datetime.now(timezone.utc)
|
||||
)
|
||||
db.session.add(ca)
|
||||
db.session.commit()
|
||||
|
||||
user = db.session.get(User, test_user)
|
||||
result = _get_org_ca_for_user(user, ca_type='user')
|
||||
assert result is None
|
||||
|
||||
def test_deleted_membership_no_access(self, app, test_org, test_ca):
|
||||
"""User with deleted membership should not access CA."""
|
||||
with app.app_context():
|
||||
user = User(email='deleted_member@test.com', full_name='Deleted Member')
|
||||
db.session.add(user)
|
||||
db.session.commit()
|
||||
|
||||
membership = OrganizationMember(
|
||||
user_id=user.id,
|
||||
organization_id=test_org,
|
||||
role=OrganizationRole.MEMBER,
|
||||
deleted_at=datetime.now(timezone.utc)
|
||||
)
|
||||
db.session.add(membership)
|
||||
db.session.commit()
|
||||
|
||||
result = _get_org_ca_for_user(user, ca_type='user')
|
||||
assert result is None
|
||||
|
||||
def test_deleted_org_no_access(self, app):
|
||||
"""User in deleted org should not access CA."""
|
||||
with app.app_context():
|
||||
org = Organization(
|
||||
name='Deleted Org',
|
||||
slug='deleted-org',
|
||||
deleted_at=datetime.now(timezone.utc)
|
||||
)
|
||||
db.session.add(org)
|
||||
db.session.commit()
|
||||
|
||||
user = User(email='user@deleted.org', full_name='User')
|
||||
db.session.add(user)
|
||||
db.session.commit()
|
||||
|
||||
membership = OrganizationMember(
|
||||
user_id=user.id,
|
||||
organization_id=org.id,
|
||||
role=OrganizationRole.MEMBER
|
||||
)
|
||||
db.session.add(membership)
|
||||
|
||||
ca = CA(
|
||||
organization_id=org.id,
|
||||
name='CA in Deleted Org',
|
||||
ca_type=CaType.USER,
|
||||
key_type=KeyType.ED25519,
|
||||
private_key='key',
|
||||
public_key='pubkey',
|
||||
fingerprint='sha256:deletedorg123',
|
||||
is_active=True
|
||||
)
|
||||
db.session.add(ca)
|
||||
db.session.commit()
|
||||
|
||||
result = _get_org_ca_for_user(user, ca_type='user')
|
||||
assert result is None
|
||||
|
||||
def test_get_active_memberships_excludes_deleted(self, app, test_user, test_org, test_membership):
|
||||
"""User.get_active_memberships() should exclude deleted memberships."""
|
||||
with app.app_context():
|
||||
user = db.session.get(User, test_user)
|
||||
|
||||
org2 = Organization(name='Org 2', slug='org-2')
|
||||
db.session.add(org2)
|
||||
db.session.commit()
|
||||
|
||||
membership2 = OrganizationMember(
|
||||
user_id=test_user,
|
||||
organization_id=org2.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
deleted_at=datetime.now(timezone.utc)
|
||||
)
|
||||
db.session.add(membership2)
|
||||
db.session.commit()
|
||||
|
||||
active = user.get_active_memberships()
|
||||
assert len(active) == 1
|
||||
assert active[0].organization_id == test_org
|
||||
|
||||
def test_get_organizations_excludes_deleted(self, app, test_user, test_org, test_membership):
|
||||
"""User.get_organizations() should exclude deleted memberships/orgs."""
|
||||
with app.app_context():
|
||||
user = db.session.get(User, test_user)
|
||||
|
||||
org2 = Organization(name='Deleted Org', slug='deleted-org-2')
|
||||
db.session.add(org2)
|
||||
db.session.commit()
|
||||
|
||||
membership2 = OrganizationMember(
|
||||
user_id=test_user,
|
||||
organization_id=org2.id,
|
||||
role=OrganizationRole.MEMBER,
|
||||
deleted_at=datetime.now(timezone.utc)
|
||||
)
|
||||
db.session.add(membership2)
|
||||
db.session.commit()
|
||||
|
||||
orgs = user.get_organizations()
|
||||
assert len(orgs) == 1
|
||||
assert orgs[0].id == test_org
|
||||
@@ -0,0 +1,92 @@
|
||||
import pytest
|
||||
from gatehouse_app.api.v1.ssh._helpers import _classify_ssh_key_material
|
||||
|
||||
|
||||
class TestClassifySSHKeyMaterial:
|
||||
def test_classifies_certificate(self):
|
||||
result = _classify_ssh_key_material("ssh-ed25519-cert-v01@openssh.com AAAA comment")
|
||||
assert result == "certificate"
|
||||
|
||||
def test_classifies_ed25519_public_key(self):
|
||||
result = _classify_ssh_key_material("ssh-ed25519 AAAAB3NzaC1lZDI1NTE5AAAAI... comment")
|
||||
assert result == "public_key"
|
||||
|
||||
def test_classifies_rsa_public_key(self):
|
||||
result = _classify_ssh_key_material("ssh-rsa AAAAB3NzaC1yc2E... comment")
|
||||
assert result == "public_key"
|
||||
|
||||
def test_classifies_dss_public_key(self):
|
||||
result = _classify_ssh_key_material("ssh-dss AAAAB3NzaC1kc3M... comment")
|
||||
assert result == "public_key"
|
||||
|
||||
def test_classifies_ecdsa_nistp256_public_key(self):
|
||||
result = _classify_ssh_key_material("ecdsa-sha2-nistp256 AAAAE2Vj... comment")
|
||||
assert result == "public_key"
|
||||
|
||||
def test_classifies_ecdsa_nistp384_public_key(self):
|
||||
result = _classify_ssh_key_material("ecdsa-sha2-nistp384 AAAAE2Vj... comment")
|
||||
assert result == "public_key"
|
||||
|
||||
def test_classifies_ecdsa_nistp521_public_key(self):
|
||||
result = _classify_ssh_key_material("ecdsa-sha2-nistp521 AAAAE2Vj... comment")
|
||||
assert result == "public_key"
|
||||
|
||||
def test_classifies_sk_ed25519_public_key(self):
|
||||
result = _classify_ssh_key_material(
|
||||
"sk-ssh-ed25519@openssh.com AAAAGnNrLXNzaC1lZDI1NTE5... comment"
|
||||
)
|
||||
assert result == "public_key"
|
||||
|
||||
def test_classifies_openssh_private_key(self):
|
||||
result = _classify_ssh_key_material(
|
||||
"-----BEGIN OPENSSH PRIVATE KEY-----\n"
|
||||
"base64data==\n"
|
||||
"-----END OPENSSH PRIVATE KEY-----"
|
||||
)
|
||||
assert result == "private_key"
|
||||
|
||||
def test_classifies_rsa_private_key(self):
|
||||
result = _classify_ssh_key_material(
|
||||
"-----BEGIN RSA PRIVATE KEY-----\n"
|
||||
"base64data==\n"
|
||||
"-----END RSA PRIVATE KEY-----"
|
||||
)
|
||||
assert result == "private_key"
|
||||
|
||||
def test_unknown_for_empty_string(self):
|
||||
result = _classify_ssh_key_material("")
|
||||
assert result == "unknown"
|
||||
|
||||
def test_unknown_for_whitespace_string(self):
|
||||
result = _classify_ssh_key_material(" \n ")
|
||||
assert result == "unknown"
|
||||
|
||||
def test_unknown_for_gibberish(self):
|
||||
result = _classify_ssh_key_material("not a valid ssh key")
|
||||
assert result == "unknown"
|
||||
|
||||
def test_unknown_for_unsupported_key_type(self):
|
||||
result = _classify_ssh_key_material("ssh-nonsense AAAABogus...")
|
||||
assert result == "unknown"
|
||||
|
||||
@pytest.mark.parametrize("raw,expected", [
|
||||
("ssh-rsa AAAAB3Nza... user@host", "public_key"),
|
||||
("ssh-ed25519 AAAAC3... john@laptop", "public_key"),
|
||||
("ecdsa-sha2-nistp256 AAAAE2Vj... me@box", "public_key"),
|
||||
("sk-ssh-ed25519@openssh.com AAAAGn...", "public_key"),
|
||||
("ssh-ed25519-cert-v01@openssh.com AAAAB3Nza cert for user", "certificate"),
|
||||
(
|
||||
"-----BEGIN OPENSSH PRIVATE KEY-----\n"
|
||||
"abcdefghijklmnopqrstuvwxyz\n"
|
||||
"-----END OPENSSH PRIVATE KEY-----",
|
||||
"private_key",
|
||||
),
|
||||
("", "unknown"),
|
||||
("totally random garbage here", "unknown"),
|
||||
])
|
||||
def test_parametrized_variants(self, raw, expected):
|
||||
assert _classify_ssh_key_material(raw) == expected
|
||||
|
||||
def test_certificate_with_leading_whitespace(self):
|
||||
raw = " ssh-ed25519-cert-v01@openssh.com AAAAB3Nza extra words"
|
||||
assert _classify_ssh_key_material(raw) == "certificate"
|
||||
@@ -0,0 +1,275 @@
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.organization.department import (
|
||||
Department,
|
||||
DepartmentMembership,
|
||||
)
|
||||
from gatehouse_app.models.organization.department_cert_policy import DepartmentCertPolicy
|
||||
from gatehouse_app.api.v1.ssh._helpers import _get_merged_dept_cert_policy
|
||||
|
||||
|
||||
class TestDeptCertPolicy:
|
||||
def test_no_departments_returns_none(self, app, test_user):
|
||||
with app.app_context():
|
||||
result = _get_merged_dept_cert_policy(test_user)
|
||||
assert result is None
|
||||
|
||||
def test_department_without_policy_returns_none(self, app, test_user, test_org):
|
||||
with app.app_context():
|
||||
dept = Department(
|
||||
organization_id=test_org,
|
||||
name="No Policy Dept",
|
||||
)
|
||||
db.session.add(dept)
|
||||
db.session.commit()
|
||||
|
||||
membership = DepartmentMembership(
|
||||
user_id=test_user,
|
||||
department_id=dept.id,
|
||||
)
|
||||
db.session.add(membership)
|
||||
db.session.commit()
|
||||
|
||||
result = _get_merged_dept_cert_policy(test_user)
|
||||
assert result is None
|
||||
|
||||
def test_single_department_policy(self, app, test_user, test_org):
|
||||
with app.app_context():
|
||||
dept = Department(
|
||||
organization_id=test_org,
|
||||
name="Engineering",
|
||||
)
|
||||
db.session.add(dept)
|
||||
db.session.commit()
|
||||
|
||||
membership = DepartmentMembership(
|
||||
user_id=test_user,
|
||||
department_id=dept.id,
|
||||
)
|
||||
db.session.add(membership)
|
||||
db.session.commit()
|
||||
|
||||
policy = DepartmentCertPolicy(
|
||||
department_id=dept.id,
|
||||
allow_user_expiry=True,
|
||||
default_expiry_hours=4,
|
||||
max_expiry_hours=48,
|
||||
allowed_extensions=["permit-pty", "permit-agent-forwarding"],
|
||||
)
|
||||
db.session.add(policy)
|
||||
db.session.commit()
|
||||
|
||||
result = _get_merged_dept_cert_policy(test_user)
|
||||
assert result is not None
|
||||
assert result["allow_user_expiry"] is True
|
||||
assert result["default_expiry_hours"] == 4
|
||||
assert result["max_expiry_hours"] == 48
|
||||
assert set(result["extensions"]) == {"permit-pty", "permit-agent-forwarding"}
|
||||
|
||||
def test_both_departments_same_policies(self, app, test_user, test_org):
|
||||
with app.app_context():
|
||||
dept1 = Department(
|
||||
organization_id=test_org,
|
||||
name="Engineering",
|
||||
)
|
||||
dept2 = Department(
|
||||
organization_id=test_org,
|
||||
name="SRE",
|
||||
)
|
||||
db.session.add_all([dept1, dept2])
|
||||
db.session.commit()
|
||||
|
||||
member1 = DepartmentMembership(user_id=test_user, department_id=dept1.id)
|
||||
member2 = DepartmentMembership(user_id=test_user, department_id=dept2.id)
|
||||
db.session.add_all([member1, member2])
|
||||
db.session.commit()
|
||||
|
||||
policy1 = DepartmentCertPolicy(
|
||||
department_id=dept1.id,
|
||||
allow_user_expiry=True,
|
||||
default_expiry_hours=4,
|
||||
max_expiry_hours=48,
|
||||
allowed_extensions=["permit-pty", "permit-agent-forwarding"],
|
||||
)
|
||||
policy2 = DepartmentCertPolicy(
|
||||
department_id=dept2.id,
|
||||
allow_user_expiry=True,
|
||||
default_expiry_hours=4,
|
||||
max_expiry_hours=48,
|
||||
allowed_extensions=["permit-pty", "permit-agent-forwarding"],
|
||||
)
|
||||
db.session.add_all([policy1, policy2])
|
||||
db.session.commit()
|
||||
|
||||
result = _get_merged_dept_cert_policy(test_user)
|
||||
assert result["allow_user_expiry"] is True
|
||||
assert result["default_expiry_hours"] == 4
|
||||
assert result["max_expiry_hours"] == 48
|
||||
|
||||
def test_merges_min_expiry_across_departments(self, app, test_user, test_org):
|
||||
with app.app_context():
|
||||
dept1 = Department(
|
||||
organization_id=test_org,
|
||||
name="Engineering",
|
||||
)
|
||||
dept2 = Department(
|
||||
organization_id=test_org,
|
||||
name="SRE",
|
||||
)
|
||||
db.session.add_all([dept1, dept2])
|
||||
db.session.commit()
|
||||
|
||||
member1 = DepartmentMembership(user_id=test_user, department_id=dept1.id)
|
||||
member2 = DepartmentMembership(user_id=test_user, department_id=dept2.id)
|
||||
db.session.add_all([member1, member2])
|
||||
db.session.commit()
|
||||
|
||||
policy1 = DepartmentCertPolicy(
|
||||
department_id=dept1.id,
|
||||
allow_user_expiry=True,
|
||||
default_expiry_hours=24,
|
||||
max_expiry_hours=720,
|
||||
)
|
||||
policy2 = DepartmentCertPolicy(
|
||||
department_id=dept2.id,
|
||||
allow_user_expiry=True,
|
||||
default_expiry_hours=1,
|
||||
max_expiry_hours=72,
|
||||
)
|
||||
db.session.add_all([policy1, policy2])
|
||||
db.session.commit()
|
||||
|
||||
result = _get_merged_dept_cert_policy(test_user)
|
||||
assert result["default_expiry_hours"] == 1
|
||||
assert result["max_expiry_hours"] == 72
|
||||
|
||||
def test_extends_intersection_across_departments(self, app, test_user, test_org):
|
||||
with app.app_context():
|
||||
dept1 = Department(
|
||||
organization_id=test_org,
|
||||
name="Engineering",
|
||||
)
|
||||
dept2 = Department(
|
||||
organization_id=test_org,
|
||||
name="SRE",
|
||||
)
|
||||
db.session.add_all([dept1, dept2])
|
||||
db.session.commit()
|
||||
|
||||
member1 = DepartmentMembership(user_id=test_user, department_id=dept1.id)
|
||||
member2 = DepartmentMembership(user_id=test_user, department_id=dept2.id)
|
||||
db.session.add_all([member1, member2])
|
||||
db.session.commit()
|
||||
|
||||
policy1 = DepartmentCertPolicy(
|
||||
department_id=dept1.id,
|
||||
allowed_extensions=["permit-pty", "permit-agent-forwarding"],
|
||||
)
|
||||
policy2 = DepartmentCertPolicy(
|
||||
department_id=dept2.id,
|
||||
allowed_extensions=["permit-pty", "permit-port-forwarding"],
|
||||
)
|
||||
db.session.add_all([policy1, policy2])
|
||||
db.session.commit()
|
||||
|
||||
result = _get_merged_dept_cert_policy(test_user)
|
||||
assert set(result["extensions"]) == {"permit-pty"}
|
||||
|
||||
def test_any_false_user_expiry_means_overall_false(
|
||||
self, app, test_user, test_org
|
||||
):
|
||||
with app.app_context():
|
||||
dept1 = Department(
|
||||
organization_id=test_org,
|
||||
name="Engineering",
|
||||
)
|
||||
dept2 = Department(
|
||||
organization_id=test_org,
|
||||
name="SRE",
|
||||
)
|
||||
db.session.add_all([dept1, dept2])
|
||||
db.session.commit()
|
||||
|
||||
member1 = DepartmentMembership(user_id=test_user, department_id=dept1.id)
|
||||
member2 = DepartmentMembership(user_id=test_user, department_id=dept2.id)
|
||||
db.session.add_all([member1, member2])
|
||||
db.session.commit()
|
||||
|
||||
policy1 = DepartmentCertPolicy(
|
||||
department_id=dept1.id,
|
||||
allow_user_expiry=True,
|
||||
)
|
||||
policy2 = DepartmentCertPolicy(
|
||||
department_id=dept2.id,
|
||||
allow_user_expiry=False,
|
||||
)
|
||||
db.session.add_all([policy1, policy2])
|
||||
db.session.commit()
|
||||
|
||||
result = _get_merged_dept_cert_policy(test_user)
|
||||
assert result["allow_user_expiry"] is False
|
||||
|
||||
def test_deleted_department_filtered(self, app, test_user, test_org):
|
||||
with app.app_context():
|
||||
active_dept = Department(
|
||||
organization_id=test_org,
|
||||
name="Active Dept",
|
||||
)
|
||||
deleted_dept = Department(
|
||||
organization_id=test_org,
|
||||
name="Deleted Dept",
|
||||
deleted_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db.session.add_all([active_dept, deleted_dept])
|
||||
db.session.commit()
|
||||
|
||||
active_member = DepartmentMembership(
|
||||
user_id=test_user, department_id=active_dept.id
|
||||
)
|
||||
deleted_member = DepartmentMembership(
|
||||
user_id=test_user,
|
||||
department_id=deleted_dept.id,
|
||||
deleted_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db.session.add_all([active_member, deleted_member])
|
||||
db.session.commit()
|
||||
|
||||
policy = DepartmentCertPolicy(
|
||||
department_id=active_dept.id,
|
||||
allow_user_expiry=True,
|
||||
default_expiry_hours=12,
|
||||
max_expiry_hours=96,
|
||||
)
|
||||
db.session.add(policy)
|
||||
db.session.commit()
|
||||
|
||||
result = _get_merged_dept_cert_policy(test_user)
|
||||
assert result is not None
|
||||
assert result["default_expiry_hours"] == 12
|
||||
|
||||
def test_single_department_no_extensions(self, app, test_user, test_org):
|
||||
with app.app_context():
|
||||
dept = Department(
|
||||
organization_id=test_org,
|
||||
name="Minimal Dept",
|
||||
)
|
||||
db.session.add(dept)
|
||||
db.session.commit()
|
||||
|
||||
membership = DepartmentMembership(
|
||||
user_id=test_user, department_id=dept.id
|
||||
)
|
||||
db.session.add(membership)
|
||||
db.session.commit()
|
||||
|
||||
policy = DepartmentCertPolicy(
|
||||
department_id=dept.id,
|
||||
allowed_extensions=[],
|
||||
)
|
||||
db.session.add(policy)
|
||||
db.session.commit()
|
||||
|
||||
result = _get_merged_dept_cert_policy(test_user)
|
||||
assert result is not None
|
||||
assert result["extensions"] == []
|
||||
@@ -0,0 +1,118 @@
|
||||
import pytest
|
||||
from uuid import uuid4
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.user.user import User
|
||||
from gatehouse_app.models.organization.organization import Organization
|
||||
from gatehouse_app.models.organization.organization_member import OrganizationMember
|
||||
from gatehouse_app.models.ssh_ca.ca import CA, CaType, KeyType
|
||||
from gatehouse_app.api.v1.ssh._helpers import _get_org_ca_for_user
|
||||
from gatehouse_app.utils.constants import OrganizationRole
|
||||
|
||||
|
||||
class TestOrgCAForUser:
|
||||
def test_organization_id_param_overrides_membership(self, app, test_user, test_org, test_ca, test_membership):
|
||||
with app.app_context():
|
||||
org2 = Organization(name="Org 2", slug="org-2")
|
||||
db.session.add(org2)
|
||||
db.session.commit()
|
||||
|
||||
ca2 = CA(
|
||||
organization_id=org2.id,
|
||||
name="Org 2 CA",
|
||||
ca_type=CaType.USER,
|
||||
key_type=KeyType.ED25519,
|
||||
private_key="key2",
|
||||
public_key="pubkey2",
|
||||
fingerprint="sha256:org2...",
|
||||
is_active=True,
|
||||
)
|
||||
db.session.add(ca2)
|
||||
db.session.commit()
|
||||
|
||||
user = db.session.get(User, test_user)
|
||||
result = _get_org_ca_for_user(user, ca_type="user", organization_id=test_org)
|
||||
assert result is not None
|
||||
assert result.organization_id == test_org
|
||||
|
||||
def test_multiple_orgs_returns_ca(self, app, test_user, test_org, test_ca, test_membership):
|
||||
with app.app_context():
|
||||
org2 = Organization(name="Org 2", slug="org-2")
|
||||
db.session.add(org2)
|
||||
db.session.commit()
|
||||
|
||||
user = db.session.get(User, test_user)
|
||||
member2 = OrganizationMember(
|
||||
user_id=test_user, organization_id=org2.id, role=OrganizationRole.MEMBER
|
||||
)
|
||||
db.session.add(member2)
|
||||
db.session.commit()
|
||||
|
||||
result = _get_org_ca_for_user(user, ca_type="user")
|
||||
assert result is not None
|
||||
|
||||
def test_user_with_no_memberships_returns_none(self, app):
|
||||
with app.app_context():
|
||||
user = User(email="lonely@test.com", full_name="Lonely User")
|
||||
db.session.add(user)
|
||||
db.session.commit()
|
||||
|
||||
result = _get_org_ca_for_user(user, ca_type="user")
|
||||
assert result is None
|
||||
|
||||
def test_inactive_ca_not_returned(self, app, test_user, test_org, test_membership):
|
||||
with app.app_context():
|
||||
ca = CA(
|
||||
organization_id=test_org,
|
||||
name="Inactive CA",
|
||||
ca_type=CaType.USER,
|
||||
key_type=KeyType.ED25519,
|
||||
private_key="key",
|
||||
public_key="pubkey",
|
||||
fingerprint="sha256:inactive123...",
|
||||
is_active=False,
|
||||
)
|
||||
db.session.add(ca)
|
||||
db.session.commit()
|
||||
|
||||
user = db.session.get(User, test_user)
|
||||
result = _get_org_ca_for_user(user, ca_type="user")
|
||||
assert result is None
|
||||
|
||||
def test_host_ca_not_returned_when_user_requested(self, app, test_user, test_org, test_membership):
|
||||
with app.app_context():
|
||||
ca = CA(
|
||||
organization_id=test_org,
|
||||
name="Host CA",
|
||||
ca_type=CaType.HOST,
|
||||
key_type=KeyType.ED25519,
|
||||
private_key="key",
|
||||
public_key="pubkey",
|
||||
fingerprint="sha256:host123...",
|
||||
is_active=True,
|
||||
)
|
||||
db.session.add(ca)
|
||||
db.session.commit()
|
||||
|
||||
user = db.session.get(User, test_user)
|
||||
result = _get_org_ca_for_user(user, ca_type="user")
|
||||
assert result is None
|
||||
|
||||
def test_user_ca_not_returned_when_host_requested(self, app, test_user, test_org, test_membership):
|
||||
with app.app_context():
|
||||
ca = CA(
|
||||
organization_id=test_org,
|
||||
name="User CA",
|
||||
ca_type=CaType.USER,
|
||||
key_type=KeyType.ED25519,
|
||||
private_key="key",
|
||||
public_key="pubkey",
|
||||
fingerprint="sha256:useronly...",
|
||||
is_active=True,
|
||||
)
|
||||
db.session.add(ca)
|
||||
db.session.commit()
|
||||
|
||||
user = db.session.get(User, test_user)
|
||||
result = _get_org_ca_for_user(user, ca_type="host")
|
||||
assert result is None
|
||||
@@ -0,0 +1,180 @@
|
||||
import pytest
|
||||
from uuid import uuid4
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.user.user import User
|
||||
from gatehouse_app.models.ssh_ca.ca import CA, CaType, KeyType, CertType
|
||||
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey
|
||||
from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus
|
||||
from gatehouse_app.services.ssh_ca_signing_service import SSHCertificateSigningResponse
|
||||
from gatehouse_app.api.v1.ssh._helpers import _persist_certificate
|
||||
|
||||
|
||||
class TestPersistCertificate:
|
||||
def test_persists_valid_certificate(self, app, test_user, test_org):
|
||||
with app.app_context():
|
||||
ca = CA(
|
||||
organization_id=test_org,
|
||||
name="Signing CA",
|
||||
ca_type=CaType.USER,
|
||||
key_type=KeyType.ED25519,
|
||||
private_key="enc_priv",
|
||||
public_key="ssh-ed25519 AAAAB3Nza...",
|
||||
fingerprint="sha256:abc123...",
|
||||
is_active=True,
|
||||
)
|
||||
db.session.add(ca)
|
||||
db.session.commit()
|
||||
|
||||
ssh_key = SSHKey(
|
||||
user_id=test_user,
|
||||
payload="ssh-ed25519 AAAAB3NzaC1lZDI1NTE5AAAAIKeyData comment",
|
||||
fingerprint="sha256:keyfp123...",
|
||||
)
|
||||
db.session.add(ssh_key)
|
||||
db.session.commit()
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
later = now + timedelta(hours=24)
|
||||
response = SSHCertificateSigningResponse(
|
||||
certificate="ssh-ed25519-cert-v01@openssh.com AAAACertData...",
|
||||
serial="123456",
|
||||
valid_after=now,
|
||||
valid_before=later,
|
||||
principals=["eng-prod"],
|
||||
)
|
||||
|
||||
result = _persist_certificate(
|
||||
user_id=test_user,
|
||||
ssh_key_id=ssh_key.id,
|
||||
ca=ca,
|
||||
signing_response=response,
|
||||
request_ip="10.0.0.1",
|
||||
cert_type_str="user",
|
||||
cert_identity="user@example.com",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.ca_id == ca.id
|
||||
assert result.user_id == test_user
|
||||
assert result.ssh_key_id == ssh_key.id
|
||||
assert result.cert_type == CertType.USER
|
||||
assert result.certificate == response.certificate
|
||||
assert result.serial == response.serial
|
||||
assert result.valid_after.replace(tzinfo=None) == now.replace(tzinfo=None)
|
||||
assert result.valid_before.replace(tzinfo=None) == later.replace(tzinfo=None)
|
||||
assert result.request_ip == "10.0.0.1"
|
||||
assert result.key_id == "user@example.com"
|
||||
assert sorted(result.principals) == ["eng-prod"]
|
||||
assert result.revoked is False
|
||||
assert result.status == CertificateStatus.ISSUED
|
||||
|
||||
def test_none_ca_returns_none(self, app, test_user):
|
||||
with app.app_context():
|
||||
now = datetime.now(timezone.utc)
|
||||
response = SSHCertificateSigningResponse(
|
||||
certificate="cert-data",
|
||||
serial="1",
|
||||
valid_after=now,
|
||||
valid_before=now + timedelta(hours=1),
|
||||
)
|
||||
result = _persist_certificate(test_user, "keyid", None, response)
|
||||
assert result is None
|
||||
|
||||
def test_invalid_cert_type_str_falls_back_to_user(self, app, test_user, test_org):
|
||||
with app.app_context():
|
||||
ca = CA(
|
||||
organization_id=test_org,
|
||||
name="Signing CA",
|
||||
ca_type=CaType.USER,
|
||||
key_type=KeyType.ED25519,
|
||||
private_key="enc_priv",
|
||||
public_key="ssh-ed25519 AAAAB3Nza...",
|
||||
fingerprint="sha256:fallback123...",
|
||||
is_active=True,
|
||||
)
|
||||
db.session.add(ca)
|
||||
db.session.commit()
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
response = SSHCertificateSigningResponse(
|
||||
certificate="cert-data",
|
||||
serial="1",
|
||||
valid_after=now,
|
||||
valid_before=now + timedelta(hours=1),
|
||||
)
|
||||
result = _persist_certificate(
|
||||
user_id=test_user,
|
||||
ssh_key_id=None,
|
||||
ca=ca,
|
||||
signing_response=response,
|
||||
cert_type_str="invalid_type",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.cert_type == CertType.USER
|
||||
|
||||
def test_none_ssh_key_id_defaults_to_host_cert_key_id(self, app, test_user, test_org):
|
||||
with app.app_context():
|
||||
ca = CA(
|
||||
organization_id=test_org,
|
||||
name="Host CA",
|
||||
ca_type=CaType.HOST,
|
||||
key_type=KeyType.ED25519,
|
||||
private_key="enc_priv",
|
||||
public_key="ssh-ed25519 AAAAB3Nza...",
|
||||
fingerprint="sha256:hostca123...",
|
||||
is_active=True,
|
||||
)
|
||||
db.session.add(ca)
|
||||
db.session.commit()
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
response = SSHCertificateSigningResponse(
|
||||
certificate="cert-data",
|
||||
serial="1",
|
||||
valid_after=now,
|
||||
valid_before=now + timedelta(hours=1),
|
||||
)
|
||||
result = _persist_certificate(
|
||||
user_id=test_user,
|
||||
ssh_key_id=None,
|
||||
ca=ca,
|
||||
signing_response=response,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.key_id == "host-cert"
|
||||
|
||||
def test_request_ip_stored(self, app, test_user, test_org):
|
||||
with app.app_context():
|
||||
ca = CA(
|
||||
organization_id=test_org,
|
||||
name="Signing CA",
|
||||
ca_type=CaType.USER,
|
||||
key_type=KeyType.ED25519,
|
||||
private_key="enc_priv",
|
||||
public_key="ssh-ed25519 AAAAB3Nza...",
|
||||
fingerprint="sha256:ip...",
|
||||
is_active=True,
|
||||
)
|
||||
db.session.add(ca)
|
||||
db.session.commit()
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
response = SSHCertificateSigningResponse(
|
||||
certificate="cert-data",
|
||||
serial="1",
|
||||
valid_after=now,
|
||||
valid_before=now + timedelta(hours=1),
|
||||
)
|
||||
result = _persist_certificate(
|
||||
user_id=test_user,
|
||||
ssh_key_id=None,
|
||||
ca=ca,
|
||||
signing_response=response,
|
||||
request_ip="192.168.1.100",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.request_ip == "192.168.1.100"
|
||||
@@ -0,0 +1,78 @@
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from gatehouse_app.extensions import db
|
||||
from gatehouse_app.models.user.user import User
|
||||
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey
|
||||
from gatehouse_app.services.ssh_key_service import SSHKeyService
|
||||
from gatehouse_app.exceptions import UserNotFoundError, SSHKeyError
|
||||
|
||||
|
||||
VALID_PUBLIC_KEY = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKL7YKAg+CgUvk/oNmU1fO24fLf5O5LayGH/EgORbz06"
|
||||
|
||||
|
||||
class TestSSHKeyServiceAdd:
|
||||
|
||||
def test_add_new_key_returns_true(self, app, test_user):
|
||||
with app.app_context():
|
||||
service = SSHKeyService()
|
||||
key, is_new = service.add_ssh_key(test_user, VALID_PUBLIC_KEY, "My laptop")
|
||||
assert is_new is True
|
||||
assert key.user_id == test_user
|
||||
assert key.payload == VALID_PUBLIC_KEY
|
||||
assert key.description == "My laptop"
|
||||
assert key.verified is False
|
||||
assert key.fingerprint is not None
|
||||
assert key.key_type is not None
|
||||
|
||||
def test_add_duplicate_returns_existing(self, app, test_user):
|
||||
with app.app_context():
|
||||
service = SSHKeyService()
|
||||
key1, _ = service.add_ssh_key(test_user, VALID_PUBLIC_KEY)
|
||||
key2, is_new = service.add_ssh_key(test_user, VALID_PUBLIC_KEY)
|
||||
assert is_new is False
|
||||
assert key2.id == key1.id
|
||||
|
||||
def test_add_restores_soft_deleted_key(self, app, test_user):
|
||||
with app.app_context():
|
||||
service = SSHKeyService()
|
||||
key1, _ = service.add_ssh_key(test_user, VALID_PUBLIC_KEY, "Original")
|
||||
|
||||
# Soft-delete the key
|
||||
key1.deleted_at = datetime.now(timezone.utc)
|
||||
db.session.commit()
|
||||
|
||||
# Re-add same key
|
||||
key2, is_new = service.add_ssh_key(test_user, VALID_PUBLIC_KEY, "Restored")
|
||||
assert is_new is False
|
||||
assert key2.id == key1.id
|
||||
assert key2.deleted_at is None
|
||||
assert key2.description == "Restored"
|
||||
assert key2.verified is False
|
||||
assert key2.verified_at is None
|
||||
|
||||
def test_add_with_description(self, app, test_user):
|
||||
with app.app_context():
|
||||
service = SSHKeyService()
|
||||
key, is_new = service.add_ssh_key(test_user, VALID_PUBLIC_KEY, "Work laptop")
|
||||
assert is_new is True
|
||||
assert key.description == "Work laptop"
|
||||
|
||||
def test_user_not_found_raises(self, app):
|
||||
with app.app_context():
|
||||
service = SSHKeyService()
|
||||
with pytest.raises(UserNotFoundError):
|
||||
service.add_ssh_key("nonexistent-user-id", VALID_PUBLIC_KEY)
|
||||
|
||||
def test_invalid_key_format_raises(self, app, test_user):
|
||||
with app.app_context():
|
||||
service = SSHKeyService()
|
||||
with pytest.raises(SSHKeyError):
|
||||
service.add_ssh_key(test_user, "not-a-valid-key")
|
||||
|
||||
def test_idempotent_second_call_no_error(self, app, test_user):
|
||||
with app.app_context():
|
||||
service = SSHKeyService()
|
||||
service.add_ssh_key(test_user, VALID_PUBLIC_KEY)
|
||||
key2, is_new = service.add_ssh_key(test_user, VALID_PUBLIC_KEY)
|
||||
assert is_new is False
|
||||
assert key2 is not None
|
||||
@@ -0,0 +1,148 @@
|
||||
import pytest
|
||||
from gatehouse_app.services.ssh_ca_signing_service import SSHCertificateSigningRequest
|
||||
|
||||
|
||||
VALID_PUBLIC_KEY = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKL7YKAg+CgUvk/oNmU1fO24fLf5O5LayGH/EgORbz06"
|
||||
|
||||
|
||||
class TestCertSigningRequestValidate:
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_config(self, monkeypatch):
|
||||
from gatehouse_app.config.ssh_ca_config import SSHCAConfig
|
||||
|
||||
class TestConfig(SSHCAConfig):
|
||||
def get_int(self, key, default=0):
|
||||
values = {
|
||||
"max_cert_validity_hours": 720,
|
||||
"max_principals_per_cert": 256,
|
||||
"max_key_id_length": 255,
|
||||
}
|
||||
return values.get(key, default)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"gatehouse_app.config.ssh_ca_config.get_ssh_ca_config",
|
||||
lambda: TestConfig(),
|
||||
)
|
||||
|
||||
def test_valid_request_no_errors(self):
|
||||
req = SSHCertificateSigningRequest(
|
||||
ssh_public_key=VALID_PUBLIC_KEY,
|
||||
principals=["eng-prod"],
|
||||
key_id="user@example.com",
|
||||
)
|
||||
errors = req.validate()
|
||||
assert errors == []
|
||||
|
||||
def test_valid_host_cert_no_errors(self):
|
||||
req = SSHCertificateSigningRequest(
|
||||
ssh_public_key=VALID_PUBLIC_KEY,
|
||||
principals=["host1.example.com"],
|
||||
key_id="host-identity",
|
||||
cert_type="host",
|
||||
)
|
||||
errors = req.validate()
|
||||
assert errors == []
|
||||
|
||||
def test_invalid_cert_type(self):
|
||||
req = SSHCertificateSigningRequest(
|
||||
ssh_public_key=VALID_PUBLIC_KEY,
|
||||
principals=["eng-prod"],
|
||||
key_id="user@example.com",
|
||||
cert_type="invalid",
|
||||
)
|
||||
errors = req.validate()
|
||||
assert any("cert_type" in e.lower() for e in errors)
|
||||
|
||||
def test_missing_public_key(self):
|
||||
req = SSHCertificateSigningRequest(
|
||||
ssh_public_key="",
|
||||
principals=["eng-prod"],
|
||||
key_id="user@example.com",
|
||||
)
|
||||
errors = req.validate()
|
||||
assert any("public key" in e.lower() for e in errors)
|
||||
|
||||
def test_malformed_public_key(self):
|
||||
req = SSHCertificateSigningRequest(
|
||||
ssh_public_key="not-a-key",
|
||||
principals=["eng-prod"],
|
||||
key_id="user@example.com",
|
||||
)
|
||||
errors = req.validate()
|
||||
assert any("public key" in e.lower() for e in errors)
|
||||
|
||||
def test_no_principals(self):
|
||||
req = SSHCertificateSigningRequest(
|
||||
ssh_public_key=VALID_PUBLIC_KEY,
|
||||
principals=[],
|
||||
key_id="user@example.com",
|
||||
)
|
||||
errors = req.validate()
|
||||
assert any("principal" in e.lower() for e in errors)
|
||||
|
||||
def test_too_many_principals(self):
|
||||
req = SSHCertificateSigningRequest(
|
||||
ssh_public_key=VALID_PUBLIC_KEY,
|
||||
principals=[f"p{i}" for i in range(300)],
|
||||
key_id="user@example.com",
|
||||
)
|
||||
errors = req.validate()
|
||||
assert any("too many" in e.lower() for e in errors)
|
||||
|
||||
def test_missing_key_id(self):
|
||||
req = SSHCertificateSigningRequest(
|
||||
ssh_public_key=VALID_PUBLIC_KEY,
|
||||
principals=["eng-prod"],
|
||||
key_id="",
|
||||
)
|
||||
errors = req.validate()
|
||||
assert any("key_id" in e.lower() for e in errors)
|
||||
|
||||
def test_key_id_too_short(self):
|
||||
req = SSHCertificateSigningRequest(
|
||||
ssh_public_key=VALID_PUBLIC_KEY,
|
||||
principals=["eng-prod"],
|
||||
key_id="ab",
|
||||
)
|
||||
errors = req.validate()
|
||||
assert any("key_id" in e.lower() for e in errors)
|
||||
|
||||
def test_key_id_exceeds_max_length(self):
|
||||
req = SSHCertificateSigningRequest(
|
||||
ssh_public_key=VALID_PUBLIC_KEY,
|
||||
principals=["eng-prod"],
|
||||
key_id="x" * 300,
|
||||
)
|
||||
errors = req.validate()
|
||||
assert any("key_id" in e.lower() for e in errors)
|
||||
|
||||
def test_non_positive_expiry(self):
|
||||
req = SSHCertificateSigningRequest(
|
||||
ssh_public_key=VALID_PUBLIC_KEY,
|
||||
principals=["eng-prod"],
|
||||
key_id="user@example.com",
|
||||
expiry_hours=0,
|
||||
)
|
||||
errors = req.validate()
|
||||
assert any("expiry" in e.lower() for e in errors)
|
||||
|
||||
def test_expiry_exceeds_max(self):
|
||||
req = SSHCertificateSigningRequest(
|
||||
ssh_public_key=VALID_PUBLIC_KEY,
|
||||
principals=["eng-prod"],
|
||||
key_id="user@example.com",
|
||||
expiry_hours=99999,
|
||||
)
|
||||
errors = req.validate()
|
||||
assert any("expiry" in e.lower() for e in errors)
|
||||
|
||||
def test_none_expiry_is_ok(self):
|
||||
req = SSHCertificateSigningRequest(
|
||||
ssh_public_key=VALID_PUBLIC_KEY,
|
||||
principals=["eng-prod"],
|
||||
key_id="user@example.com",
|
||||
expiry_hours=None,
|
||||
)
|
||||
errors = req.validate()
|
||||
assert errors == []
|
||||
@@ -0,0 +1,77 @@
|
||||
from gatehouse_app.api.v1.superadmin.organizations import (
|
||||
ListOrganizationsSchema,
|
||||
UpdateOrganizationSchema,
|
||||
)
|
||||
|
||||
|
||||
class TestListOrganizationsSchema:
|
||||
def test_defaults_when_empty(self):
|
||||
result = ListOrganizationsSchema.load({})
|
||||
assert result["page"] == 1
|
||||
assert result["per_page"] == 20
|
||||
assert result["search"] is None
|
||||
assert result["status"] is None
|
||||
assert result["plan_slug"] is None
|
||||
|
||||
def test_normal_pagination(self):
|
||||
result = ListOrganizationsSchema.load({"page": 3, "per_page": 10})
|
||||
assert result["page"] == 3
|
||||
assert result["per_page"] == 10
|
||||
|
||||
def test_page_zero_clamped_to_one(self):
|
||||
result = ListOrganizationsSchema.load({"page": 0})
|
||||
assert result["page"] == 1
|
||||
|
||||
def test_negative_per_page_clamped_to_one(self):
|
||||
result = ListOrganizationsSchema.load({"per_page": -5})
|
||||
assert result["per_page"] == 1
|
||||
|
||||
def test_per_page_exceeds_max_clamped_to_100(self):
|
||||
result = ListOrganizationsSchema.load({"per_page": 200})
|
||||
assert result["per_page"] == 100
|
||||
|
||||
def test_non_integer_values_fallback(self):
|
||||
result = ListOrganizationsSchema.load({"page": "abc", "per_page": "xyz"})
|
||||
assert result["page"] == 1
|
||||
assert result["per_page"] == 20
|
||||
|
||||
def test_search_passthrough(self):
|
||||
result = ListOrganizationsSchema.load({"search": "acme"})
|
||||
assert result["search"] == "acme"
|
||||
|
||||
def test_status_passthrough(self):
|
||||
result = ListOrganizationsSchema.load({"status": "active"})
|
||||
assert result["status"] == "active"
|
||||
|
||||
def test_plan_slug_passthrough(self):
|
||||
result = ListOrganizationsSchema.load({"plan_slug": "pro"})
|
||||
assert result["plan_slug"] == "pro"
|
||||
|
||||
|
||||
class TestUpdateOrganizationSchema:
|
||||
def test_all_fields(self):
|
||||
result = UpdateOrganizationSchema.load({
|
||||
"name": "New Name",
|
||||
"description": "New Description",
|
||||
"is_active": True,
|
||||
})
|
||||
assert result == {
|
||||
"name": "New Name",
|
||||
"description": "New Description",
|
||||
"is_active": True,
|
||||
}
|
||||
|
||||
def test_empty_dict(self):
|
||||
result = UpdateOrganizationSchema.load({})
|
||||
assert result == {}
|
||||
|
||||
def test_partial_data(self):
|
||||
result = UpdateOrganizationSchema.load({"name": "Renamed Only"})
|
||||
assert result == {"name": "Renamed Only"}
|
||||
|
||||
def test_is_active_coerced_to_bool(self):
|
||||
result = UpdateOrganizationSchema.load({"is_active": "truthy"})
|
||||
assert result["is_active"] is True
|
||||
|
||||
result = UpdateOrganizationSchema.load({"is_active": ""})
|
||||
assert result["is_active"] is False
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user