feat: allow admins to bypass approval flow when joining networks
This commit is contained in:
@@ -48,6 +48,21 @@ class AdminClient:
|
||||
data={"confirm": confirm},
|
||||
)
|
||||
|
||||
def get_user_ssh_certificates(self, user_id: str, **params) -> dict:
|
||||
"""List all SSH certificates for a user (admin view).
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
**params: Optional query parameters — status, active, cert_type, page, per_page
|
||||
"""
|
||||
path = f"/admin/users/{user_id}/ssh-certificates"
|
||||
if params:
|
||||
from urllib.parse import urlencode
|
||||
query = urlencode({k: v for k, v in params.items() if v is not None})
|
||||
if query:
|
||||
path = f"{path}?{query}"
|
||||
return self._client.get(path)
|
||||
|
||||
def list_audit_logs(self) -> dict:
|
||||
"""List system-wide audit logs."""
|
||||
return self._client.get("/audit-logs")
|
||||
|
||||
@@ -211,3 +211,309 @@ class TestAdminUserManagement:
|
||||
with pytest.raises(ApiError) as exc_info:
|
||||
integration_client.auth.login(email=victim["email"], password="VictimPass123!")
|
||||
assert exc_info.value.status_code in (400, 401)
|
||||
|
||||
|
||||
class TestAdminSSHCertificates:
|
||||
"""Test admin SSH certificate listing endpoints."""
|
||||
|
||||
def _create_test_cert(
|
||||
self, integration_app, user_id: str, ca_id: str, *, ssh_key_id=None,
|
||||
status="issued", revoked=False, valid_after=None, valid_before=None,
|
||||
cert_type="user", principals=None,
|
||||
):
|
||||
"""Create a test SSH certificate record."""
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus
|
||||
from gatehouse_app.models.ssh_ca.ca import CertType
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
valid_after = valid_after or (now - timedelta(hours=1))
|
||||
valid_before = valid_before or (now + timedelta(hours=23))
|
||||
principals = principals or ["prod-servers"]
|
||||
|
||||
with integration_app.app_context():
|
||||
cert = SSHCertificate(
|
||||
ca_id=ca_id,
|
||||
user_id=user_id,
|
||||
ssh_key_id=ssh_key_id,
|
||||
certificate=f"ssh-ed25519-cert-v01@openssh.com AAAA...test_serial_{uuid.uuid4().hex[:8]}",
|
||||
serial=str(uuid.uuid4().int)[:20],
|
||||
key_id=f"test@example.com-{uuid.uuid4().hex[:8]}",
|
||||
cert_type=CertType(cert_type),
|
||||
principals=principals,
|
||||
valid_after=valid_after,
|
||||
valid_before=valid_before,
|
||||
revoked=revoked,
|
||||
status=CertificateStatus(status),
|
||||
request_ip="192.168.1.100",
|
||||
request_user_agent="OpenSSH_9.0",
|
||||
)
|
||||
if revoked:
|
||||
cert.revoked_at = now
|
||||
cert.revoke_reason = "test revocation"
|
||||
db.session.add(cert)
|
||||
db.session.commit()
|
||||
return str(cert.id)
|
||||
|
||||
def _create_test_ssh_key(self, integration_app, user_id: str, fingerprint: str = None):
|
||||
"""Create a test SSH key record."""
|
||||
from gatehouse_app.models.ssh_ca.ssh_key import SSHKey
|
||||
|
||||
fingerprint = fingerprint or f"SHA256:{uuid.uuid4().hex[:43]}"
|
||||
with integration_app.app_context():
|
||||
key = SSHKey(
|
||||
user_id=user_id,
|
||||
payload=f"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAI...test",
|
||||
fingerprint=fingerprint,
|
||||
description="Test laptop key",
|
||||
verified=True,
|
||||
key_type="ssh-ed25519",
|
||||
key_bits=256,
|
||||
key_comment="test@laptop",
|
||||
)
|
||||
db.session.add(key)
|
||||
db.session.commit()
|
||||
return str(key.id)
|
||||
|
||||
def test_list_user_ssh_certs_positive(self, integration_app, integration_client, create_test_user, create_test_org, create_test_membership, create_test_ca):
|
||||
"""TEST: ADMIN-SSH-01 — List all SSH certificates for a user as admin.
|
||||
|
||||
WHAT: Create a user with two certs (one active, one expired),
|
||||
admin lists all certs via the new endpoint.
|
||||
WHY: Admin needs full visibility of user SSH certificate history.
|
||||
EXPECTED: 200 OK with certificates array containing both certs.
|
||||
"""
|
||||
admin = create_test_user(password="AdminPass123!")
|
||||
victim = create_test_user(password="VictimPass123!")
|
||||
org = create_test_org()
|
||||
create_test_membership(admin["id"], org["id"], OrganizationRole.OWNER)
|
||||
create_test_membership(victim["id"], org["id"], OrganizationRole.MEMBER)
|
||||
ca = create_test_ca(org_id=org["id"])
|
||||
|
||||
from datetime import datetime, timezone, timedelta
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Create an active cert
|
||||
self._create_test_cert(
|
||||
integration_app, victim["id"], ca["id"],
|
||||
status="issued", valid_after=now - timedelta(hours=1),
|
||||
valid_before=now + timedelta(hours=23),
|
||||
)
|
||||
# Create an expired cert
|
||||
self._create_test_cert(
|
||||
integration_app, victim["id"], ca["id"],
|
||||
status="expired", valid_after=now - timedelta(days=7),
|
||||
valid_before=now - timedelta(days=1),
|
||||
)
|
||||
|
||||
integration_client.auth.login(email=admin["email"], password="AdminPass123!")
|
||||
result = integration_client.admin.get_user_ssh_certificates(victim["id"])
|
||||
data = assert_success(result)
|
||||
assert "certificates" in data
|
||||
assert data["count"] == 2
|
||||
assert len(data["certificates"]) == 2
|
||||
|
||||
def test_list_user_ssh_certs_with_key_metadata(self, integration_app, integration_client, create_test_user, create_test_org, create_test_membership, create_test_ca):
|
||||
"""TEST: ADMIN-SSH-02 — Certificate includes SSH key metadata.
|
||||
|
||||
WHAT: Create a cert linked to an SSH key, verify key details
|
||||
appear in the response.
|
||||
WHY: Admin needs to see which key was used to request the cert.
|
||||
EXPECTED: ssh_key object with fingerprint, key_type, key_bits.
|
||||
"""
|
||||
admin = create_test_user(password="AdminPass123!")
|
||||
victim = create_test_user(password="VictimPass123!")
|
||||
org = create_test_org()
|
||||
create_test_membership(admin["id"], org["id"], OrganizationRole.OWNER)
|
||||
create_test_membership(victim["id"], org["id"], OrganizationRole.MEMBER)
|
||||
ca = create_test_ca(org_id=org["id"])
|
||||
|
||||
key_id = self._create_test_ssh_key(integration_app, victim["id"])
|
||||
self._create_test_cert(integration_app, victim["id"], ca["id"], ssh_key_id=key_id)
|
||||
|
||||
integration_client.auth.login(email=admin["email"], password="AdminPass123!")
|
||||
result = integration_client.admin.get_user_ssh_certificates(victim["id"])
|
||||
data = assert_success(result)
|
||||
|
||||
cert = data["certificates"][0]
|
||||
assert cert["ssh_key"] is not None
|
||||
assert cert["ssh_key"]["key_type"] == "ssh-ed25519"
|
||||
assert cert["ssh_key"]["fingerprint"] is not None
|
||||
assert cert["ssh_key"]["description"] == "Test laptop key"
|
||||
|
||||
def test_list_user_ssh_certs_non_admin_negative(self, integration_app, integration_client, create_test_user, create_test_org, create_test_membership, create_test_ca):
|
||||
"""TEST: ADMIN-SSH-03 — Non-admin cannot list another user's certs.
|
||||
|
||||
WHAT: Regular member tries to list admin's certs.
|
||||
WHY: Certificate data is sensitive and admin-only.
|
||||
EXPECTED: 403 Forbidden.
|
||||
"""
|
||||
member = create_test_user(password="MemberPass123!")
|
||||
admin_user = create_test_user(password="AdminPass123!")
|
||||
org = create_test_org()
|
||||
create_test_membership(member["id"], org["id"], OrganizationRole.MEMBER)
|
||||
create_test_membership(admin_user["id"], org["id"], OrganizationRole.OWNER)
|
||||
|
||||
integration_client.auth.login(email=member["email"], password="MemberPass123!")
|
||||
|
||||
with pytest.raises(ApiError) as exc_info:
|
||||
integration_client.admin.get_user_ssh_certificates(admin_user["id"])
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
def test_list_user_ssh_certs_filter_by_status(self, integration_app, integration_client, create_test_user, create_test_org, create_test_membership, create_test_ca):
|
||||
"""TEST: ADMIN-SSH-04 — Filter certificates by status.
|
||||
|
||||
WHAT: Create certs with different statuses, filter by status=revoked.
|
||||
WHY: Admin may want to see only revoked certs to audit access.
|
||||
EXPECTED: Only revoked certs returned.
|
||||
"""
|
||||
admin = create_test_user(password="AdminPass123!")
|
||||
victim = create_test_user(password="VictimPass123!")
|
||||
org = create_test_org()
|
||||
create_test_membership(admin["id"], org["id"], OrganizationRole.OWNER)
|
||||
create_test_membership(victim["id"], org["id"], OrganizationRole.MEMBER)
|
||||
ca = create_test_ca(org_id=org["id"])
|
||||
|
||||
from datetime import datetime, timezone, timedelta
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
self._create_test_cert(integration_app, victim["id"], ca["id"], status="issued")
|
||||
self._create_test_cert(integration_app, victim["id"], ca["id"], status="revoked", revoked=True)
|
||||
self._create_test_cert(integration_app, victim["id"], ca["id"], status="expired")
|
||||
|
||||
integration_client.auth.login(email=admin["email"], password="AdminPass123!")
|
||||
result = integration_client.admin.get_user_ssh_certificates(victim["id"], status="revoked")
|
||||
data = assert_success(result)
|
||||
|
||||
assert data["count"] == 1
|
||||
assert data["certificates"][0]["status"] == "revoked"
|
||||
assert data["certificates"][0]["revoked"] is True
|
||||
|
||||
def test_list_user_ssh_certs_filter_active_only(self, integration_app, integration_client, create_test_user, create_test_org, create_test_membership, create_test_ca):
|
||||
"""TEST: ADMIN-SSH-05 — Filter for only currently valid certificates.
|
||||
|
||||
WHAT: Create active and expired certs, filter by active=true.
|
||||
WHY: Admin needs quick view of currently active certs.
|
||||
EXPECTED: Only valid (non-revoked, non-expired) certs returned.
|
||||
"""
|
||||
admin = create_test_user(password="AdminPass123!")
|
||||
victim = create_test_user(password="VictimPass123!")
|
||||
org = create_test_org()
|
||||
create_test_membership(admin["id"], org["id"], OrganizationRole.OWNER)
|
||||
create_test_membership(victim["id"], org["id"], OrganizationRole.MEMBER)
|
||||
ca = create_test_ca(org_id=org["id"])
|
||||
|
||||
from datetime import datetime, timezone, timedelta
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
self._create_test_cert(
|
||||
integration_app, victim["id"], ca["id"], status="issued",
|
||||
valid_after=now - timedelta(hours=1), valid_before=now + timedelta(hours=23),
|
||||
)
|
||||
self._create_test_cert(
|
||||
integration_app, victim["id"], ca["id"], status="expired",
|
||||
valid_after=now - timedelta(days=7), valid_before=now - timedelta(days=1),
|
||||
)
|
||||
self._create_test_cert(
|
||||
integration_app, victim["id"], ca["id"], status="revoked", revoked=True,
|
||||
valid_after=now - timedelta(hours=1), valid_before=now + timedelta(hours=23),
|
||||
)
|
||||
|
||||
integration_client.auth.login(email=admin["email"], password="AdminPass123!")
|
||||
result = integration_client.admin.get_user_ssh_certificates(victim["id"], active="true")
|
||||
data = assert_success(result)
|
||||
|
||||
assert data["count"] == 1
|
||||
cert = data["certificates"][0]
|
||||
assert cert["is_valid"] is True
|
||||
assert cert["revoked"] is False
|
||||
|
||||
def test_list_user_ssh_certs_user_not_found(self, integration_app, integration_client, create_test_user, create_test_org, create_test_membership):
|
||||
"""TEST: ADMIN-SSH-06 — Return 404 for non-existent user.
|
||||
|
||||
WHAT: Admin requests certs for a user ID that doesn't exist.
|
||||
WHY: Clear error for missing resources.
|
||||
EXPECTED: 404 NOT_FOUND.
|
||||
"""
|
||||
admin = create_test_user(password="AdminPass123!")
|
||||
org = create_test_org()
|
||||
create_test_membership(admin["id"], org["id"], OrganizationRole.OWNER)
|
||||
|
||||
integration_client.auth.login(email=admin["email"], password="AdminPass123!")
|
||||
|
||||
with pytest.raises(ApiError) as exc_info:
|
||||
integration_client.admin.get_user_ssh_certificates("non-existent-user-id")
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
assert exc_info.value.error_type == "NOT_FOUND"
|
||||
|
||||
def test_list_user_ssh_certs_empty_result(self, integration_app, integration_client, create_test_user, create_test_org, create_test_membership):
|
||||
"""TEST: ADMIN-SSH-07 — Empty result when user has no certs.
|
||||
|
||||
WHAT: Admin lists certs for a user who has never requested one.
|
||||
WHY: Endpoint should handle gracefully, not error.
|
||||
EXPECTED: 200 OK with empty certificates array and count=0.
|
||||
"""
|
||||
admin = create_test_user(password="AdminPass123!")
|
||||
victim = create_test_user(password="VictimPass123!")
|
||||
org = create_test_org()
|
||||
create_test_membership(admin["id"], org["id"], OrganizationRole.OWNER)
|
||||
create_test_membership(victim["id"], org["id"], OrganizationRole.MEMBER)
|
||||
|
||||
integration_client.auth.login(email=admin["email"], password="AdminPass123!")
|
||||
result = integration_client.admin.get_user_ssh_certificates(victim["id"])
|
||||
data = assert_success(result)
|
||||
|
||||
assert data["certificates"] == []
|
||||
assert data["count"] == 0
|
||||
|
||||
def test_list_user_ssh_certs_revoked_cert_details(self, integration_app, integration_client, create_test_user, create_test_org, create_test_membership, create_test_ca):
|
||||
"""TEST: ADMIN-SSH-08 — Revoked certificate shows revocation details.
|
||||
|
||||
WHAT: Create a revoked cert, verify revoke metadata is present.
|
||||
WHY: Admin needs to know when and why a cert was revoked.
|
||||
EXPECTED: revoked=True, revoked_at populated, revoke_reason present.
|
||||
"""
|
||||
admin = create_test_user(password="AdminPass123!")
|
||||
victim = create_test_user(password="VictimPass123!")
|
||||
org = create_test_org()
|
||||
create_test_membership(admin["id"], org["id"], OrganizationRole.OWNER)
|
||||
create_test_membership(victim["id"], org["id"], OrganizationRole.MEMBER)
|
||||
ca = create_test_ca(org_id=org["id"])
|
||||
|
||||
self._create_test_cert(
|
||||
integration_app, victim["id"], ca["id"],
|
||||
status="revoked", revoked=True,
|
||||
)
|
||||
|
||||
integration_client.auth.login(email=admin["email"], password="AdminPass123!")
|
||||
result = integration_client.admin.get_user_ssh_certificates(victim["id"])
|
||||
data = assert_success(result)
|
||||
|
||||
cert = data["certificates"][0]
|
||||
assert cert["revoked"] is True
|
||||
assert cert["revoked_at"] is not None
|
||||
assert cert["revoke_reason"] == "test revocation"
|
||||
assert cert["status"] == "revoked"
|
||||
|
||||
def test_list_user_ssh_certs_invalid_status_filter(self, integration_app, integration_client, create_test_user, create_test_org, create_test_membership):
|
||||
"""TEST: ADMIN-SSH-09 — Invalid status filter returns 400.
|
||||
|
||||
WHAT: Admin passes an invalid status value.
|
||||
WHY: Input validation prevents confusing queries.
|
||||
EXPECTED: 400 VALIDATION_ERROR.
|
||||
"""
|
||||
admin = create_test_user(password="AdminPass123!")
|
||||
victim = create_test_user(password="VictimPass123!")
|
||||
org = create_test_org()
|
||||
create_test_membership(admin["id"], org["id"], OrganizationRole.OWNER)
|
||||
create_test_membership(victim["id"], org["id"], OrganizationRole.MEMBER)
|
||||
|
||||
integration_client.auth.login(email=admin["email"], password="AdminPass123!")
|
||||
|
||||
with pytest.raises(ApiError) as exc_info:
|
||||
integration_client.admin.get_user_ssh_certificates(victim["id"], status="bogus")
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert exc_info.value.error_type == "VALIDATION_ERROR"
|
||||
|
||||
@@ -201,3 +201,145 @@ class TestZeroTierMembership:
|
||||
except ApiError as exc:
|
||||
# Accept errors when no active memberships to kill
|
||||
assert exc.status_code in (400, 500)
|
||||
|
||||
|
||||
class TestAdminUserDevices:
|
||||
"""Test admin endpoint to list devices for a specific user."""
|
||||
|
||||
def test_list_user_devices_positive(
|
||||
self, integration_client, create_test_user, create_test_org, create_test_membership, integration_app
|
||||
):
|
||||
"""TEST: ZT-10 — Admin lists devices for a user with devices.
|
||||
|
||||
WHAT: Admin GET /organizations/<id>/users/<user_id>/devices.
|
||||
WHY: Admins need to see what devices a user has registered.
|
||||
EXPECTED: 200 OK with devices array.
|
||||
"""
|
||||
from gatehouse_app.models.zerotier.device import Device
|
||||
|
||||
admin = create_test_user(password="AdminPass123!")
|
||||
member = create_test_user(password="MemberPass123!")
|
||||
org = create_test_org()
|
||||
|
||||
create_test_membership(admin["id"], org["id"], OrganizationRole.ADMIN)
|
||||
create_test_membership(member["id"], org["id"], OrganizationRole.MEMBER)
|
||||
|
||||
# Create test devices for the member
|
||||
from gatehouse_app.extensions import db as _db
|
||||
with integration_app.app_context():
|
||||
device1 = Device(
|
||||
user_id=member["id"],
|
||||
organization_id=org["id"],
|
||||
node_id="1234567890",
|
||||
device_nickname="Member Laptop",
|
||||
hostname="member-laptop",
|
||||
)
|
||||
device2 = Device(
|
||||
user_id=member["id"],
|
||||
organization_id=org["id"],
|
||||
node_id="0987654321",
|
||||
device_nickname="Member Phone",
|
||||
hostname="member-phone",
|
||||
)
|
||||
_db.session.add_all([device1, device2])
|
||||
_db.session.commit()
|
||||
|
||||
integration_client.auth.login(email=admin["email"], password="AdminPass123!")
|
||||
result = integration_client.get(f"/organizations/{org['id']}/users/{member['id']}/devices")
|
||||
data = assert_success(result, "devices retrieved")
|
||||
|
||||
assert "devices" in data
|
||||
assert data["count"] == 2
|
||||
assert data["user_id"] == member["id"]
|
||||
assert data["organization_id"] == org["id"]
|
||||
device_node_ids = [d["node_id"] for d in data["devices"]]
|
||||
assert "1234567890" in device_node_ids
|
||||
assert "0987654321" in device_node_ids
|
||||
|
||||
def test_list_user_devices_no_devices(
|
||||
self, integration_client, create_test_user, create_test_org, create_test_membership
|
||||
):
|
||||
"""TEST: ZT-11 — Admin lists devices for a user with no devices.
|
||||
|
||||
WHAT: Admin GET /organizations/<id>/users/<user_id>/devices for user with no devices.
|
||||
WHY: Endpoint should return empty list, not error.
|
||||
EXPECTED: 200 OK with empty devices array.
|
||||
"""
|
||||
admin = create_test_user(password="AdminPass123!")
|
||||
member = create_test_user(password="MemberPass123!")
|
||||
org = create_test_org()
|
||||
|
||||
create_test_membership(admin["id"], org["id"], OrganizationRole.ADMIN)
|
||||
create_test_membership(member["id"], org["id"], OrganizationRole.MEMBER)
|
||||
|
||||
integration_client.auth.login(email=admin["email"], password="AdminPass123!")
|
||||
result = integration_client.get(f"/organizations/{org['id']}/users/{member['id']}/devices")
|
||||
data = assert_success(result)
|
||||
|
||||
assert data["count"] == 0
|
||||
assert data["devices"] == []
|
||||
|
||||
def test_list_user_devices_non_admin_negative(
|
||||
self, integration_client, create_test_user, create_test_org, create_test_membership
|
||||
):
|
||||
"""TEST: ZT-12 — Non-admin cannot list another user's devices.
|
||||
|
||||
WHAT: Member attempts GET /organizations/<id>/users/<user_id>/devices.
|
||||
WHY: This endpoint is admin-only.
|
||||
EXPECTED: 403 Forbidden.
|
||||
"""
|
||||
member1 = create_test_user(password="Member1Pass123!")
|
||||
member2 = create_test_user(password="Member2Pass123!")
|
||||
org = create_test_org()
|
||||
|
||||
create_test_membership(member1["id"], org["id"], OrganizationRole.MEMBER)
|
||||
create_test_membership(member2["id"], org["id"], OrganizationRole.MEMBER)
|
||||
|
||||
integration_client.auth.login(email=member1["email"], password="Member1Pass123!")
|
||||
with pytest.raises(ApiError) as exc_info:
|
||||
integration_client.get(f"/organizations/{org['id']}/users/{member2['id']}/devices")
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
def test_list_user_devices_user_not_in_org_negative(
|
||||
self, integration_client, create_test_user, create_test_org, create_test_membership
|
||||
):
|
||||
"""TEST: ZT-13 — Cannot list devices for user not in organization.
|
||||
|
||||
WHAT: Admin GET /organizations/<id>/users/<user_id>/devices for user not in org.
|
||||
WHY: User must be a member of the organization.
|
||||
EXPECTED: 404 Not Found.
|
||||
"""
|
||||
admin = create_test_user(password="AdminPass123!")
|
||||
outside_user = create_test_user(password="OutsidePass123!")
|
||||
org = create_test_org()
|
||||
|
||||
create_test_membership(admin["id"], org["id"], OrganizationRole.ADMIN)
|
||||
# outside_user is NOT added to the org
|
||||
|
||||
integration_client.auth.login(email=admin["email"], password="AdminPass123!")
|
||||
with pytest.raises(ApiError) as exc_info:
|
||||
integration_client.get(f"/organizations/{org['id']}/users/{outside_user['id']}/devices")
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
def test_list_user_devices_user_not_found_negative(
|
||||
self, integration_client, create_test_user, create_test_org, create_test_membership
|
||||
):
|
||||
"""TEST: ZT-14 — Cannot list devices for non-existent user.
|
||||
|
||||
WHAT: Admin GET /organizations/<id>/users/<non_existent_id>/devices.
|
||||
WHY: User must exist.
|
||||
EXPECTED: 404 Not Found.
|
||||
"""
|
||||
import uuid
|
||||
|
||||
admin = create_test_user(password="AdminPass123!")
|
||||
org = create_test_org()
|
||||
|
||||
create_test_membership(admin["id"], org["id"], OrganizationRole.ADMIN)
|
||||
|
||||
non_existent_id = str(uuid.uuid4())
|
||||
|
||||
integration_client.auth.login(email=admin["email"], password="AdminPass123!")
|
||||
with pytest.raises(ApiError) as exc_info:
|
||||
integration_client.get(f"/organizations/{org['id']}/users/{non_existent_id}/devices")
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
"""Verify the structure of the Alembic migration that merges
|
||||
user_network_approvals and device_network_memberships into network_access_requests.
|
||||
|
||||
These are STRUCTURAL tests only — no database connection is required.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
# ── helpers ────────────────────────────────────────────────────────────────
|
||||
|
||||
def _load_migration_module():
|
||||
"""Load the migration module by file path without executing Alembic."""
|
||||
migration_path = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'..', '..', 'migrations', 'versions',
|
||||
'merge_approval_membership_tables.py',
|
||||
)
|
||||
migration_path = os.path.abspath(migration_path)
|
||||
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
'merge_approval_membership_tables', migration_path,
|
||||
)
|
||||
assert spec is not None, f'Could not create module spec for {migration_path}'
|
||||
assert spec.loader is not None, f'Module spec has no loader for {migration_path}'
|
||||
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
return mod
|
||||
|
||||
|
||||
# ── structural tests ───────────────────────────────────────────────────────
|
||||
|
||||
def test_migration_file_can_be_imported():
|
||||
"""The migration module MUST import without raising any exception."""
|
||||
mod = _load_migration_module()
|
||||
assert mod is not None
|
||||
|
||||
|
||||
def test_upgrade_function_exists():
|
||||
"""upgrade() must be a callable in the module."""
|
||||
mod = _load_migration_module()
|
||||
assert hasattr(mod, 'upgrade'), 'module is missing upgrade()'
|
||||
assert callable(mod.upgrade), 'upgrade is not callable'
|
||||
|
||||
|
||||
def test_downgrade_function_exists():
|
||||
"""downgrade() must be a callable in the module."""
|
||||
mod = _load_migration_module()
|
||||
assert hasattr(mod, 'downgrade'), 'module is missing downgrade()'
|
||||
assert callable(mod.downgrade), 'downgrade is not callable'
|
||||
|
||||
|
||||
def test_revision_is_set_correctly():
|
||||
"""revision must equal the documented value 'c0a1b2c3d4e5'."""
|
||||
mod = _load_migration_module()
|
||||
assert hasattr(mod, 'revision'), 'module is missing revision'
|
||||
assert mod.revision == 'c0a1b2c3d4e5', (
|
||||
f"Expected revision 'c0a1b2c3d4e5', got '{mod.revision}'"
|
||||
)
|
||||
|
||||
|
||||
def test_down_revision_is_set_correctly():
|
||||
"""down_revision must equal the documented value 'a1b2c3d4e5f6'."""
|
||||
mod = _load_migration_module()
|
||||
assert hasattr(mod, 'down_revision'), 'module is missing down_revision'
|
||||
assert mod.down_revision == 'a1b2c3d4e5f6', (
|
||||
f"Expected down_revision 'a1b2c3d4e5f6', got '{mod.down_revision}'"
|
||||
)
|
||||
|
||||
|
||||
def test_branch_labels_is_none():
|
||||
"""branch_labels should be None for a standard linear migration."""
|
||||
mod = _load_migration_module()
|
||||
assert mod.branch_labels is None, (
|
||||
f"Expected branch_labels None, got {mod.branch_labels!r}"
|
||||
)
|
||||
|
||||
|
||||
def test_depends_on_is_none():
|
||||
"""depends_on should be None — this migration has no cross-dependencies."""
|
||||
mod = _load_migration_module()
|
||||
assert mod.depends_on is None, (
|
||||
f"Expected depends_on None, got {mod.depends_on!r}"
|
||||
)
|
||||
@@ -0,0 +1,340 @@
|
||||
"""Unit tests for NetworkAccessRequest model structure.
|
||||
|
||||
WHAT: Verifies the model class can be imported, has the expected columns,
|
||||
constraints, and enum types.
|
||||
WHY: Structural correctness of the model is a prerequisite for Phase 2+
|
||||
work; catching missing columns or constraints early prevents
|
||||
migration/runtime failures.
|
||||
|
||||
APPROACH: gatehouse_app/__init__.py calls create_app() at module level which
|
||||
requires psycopg2 (PostgreSQL driver). We prevent this by pre-loading
|
||||
gatehouse_app as a bare namespace package, then selectively providing
|
||||
the real submodules (utils.constants) and fakes (extensions, models.base).
|
||||
|
||||
We do NOT call db.create_all() — the table metadata is fully populated
|
||||
during class definition. FK target tables don't exist in our test
|
||||
metadata, so we check FK presence without table resolution.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import importlib.util
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# Step 1: Pre-load gatehouse_app as a bare namespace (prevents __init__.py)
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
_gatehouse = type(sys)("gatehouse_app")
|
||||
_gatehouse.__path__ = []
|
||||
sys.modules["gatehouse_app"] = _gatehouse
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# Step 2: Load the real gatehouse_app.utils.constants (self-contained, no deps)
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
_constants_spec = importlib.util.spec_from_file_location(
|
||||
"gatehouse_app.utils.constants",
|
||||
"/home/ubuntu/securid/gatehouse-api/gatehouse_app/utils/constants.py",
|
||||
submodule_search_locations=[],
|
||||
)
|
||||
_constants_mod = importlib.util.module_from_spec(_constants_spec)
|
||||
sys.modules["gatehouse_app.utils"] = type(sys)("gatehouse_app.utils")
|
||||
sys.modules["gatehouse_app.utils.constants"] = _constants_mod
|
||||
_constants_spec.loader.exec_module(_constants_mod)
|
||||
|
||||
ApprovalGrantType = _constants_mod.ApprovalGrantType
|
||||
ApprovalState = _constants_mod.ApprovalState
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# Step 3: Build fake extensions.db and models.base
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
_fake_db = SQLAlchemy()
|
||||
|
||||
|
||||
class FakeBaseModel(_fake_db.Model):
|
||||
"""Minimal BaseModel matching the real one's column definitions."""
|
||||
__abstract__ = True
|
||||
id = _fake_db.Column(_fake_db.String(36), primary_key=True, default=lambda: "test-uuid", nullable=False)
|
||||
created_at = _fake_db.Column(_fake_db.DateTime, nullable=False)
|
||||
updated_at = _fake_db.Column(_fake_db.DateTime, nullable=False)
|
||||
deleted_at = _fake_db.Column(_fake_db.DateTime, nullable=True)
|
||||
|
||||
def to_dict(self, exclude=None):
|
||||
"""Mimic the real BaseModel.to_dict — iterates __table__.columns."""
|
||||
from datetime import datetime, timezone
|
||||
exclude = exclude or []
|
||||
result = {}
|
||||
for column in self.__table__.columns:
|
||||
if column.name not in exclude:
|
||||
value = getattr(self, column.name)
|
||||
if isinstance(value, datetime):
|
||||
result[column.name] = value.isoformat()
|
||||
else:
|
||||
result[column.name] = value
|
||||
return result
|
||||
|
||||
|
||||
_fake_extensions = type(sys)("gatehouse_app.extensions")
|
||||
_fake_extensions.db = _fake_db
|
||||
|
||||
_fake_models_base = type(sys)("gatehouse_app.models.base")
|
||||
_fake_models_base.BaseModel = FakeBaseModel
|
||||
|
||||
sys.modules["gatehouse_app.extensions"] = _fake_extensions
|
||||
sys.modules["gatehouse_app.models"] = type(sys)("gatehouse_app.models")
|
||||
sys.modules["gatehouse_app.models.base"] = _fake_models_base
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# Step 3b: Create stub models for relationship targets so ORM mapper
|
||||
# can resolve 'Organization', 'User', 'Device', 'PortalNetwork'
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class Organization(_fake_db.Model):
|
||||
__tablename__ = "organizations"
|
||||
id = _fake_db.Column(_fake_db.String(36), primary_key=True)
|
||||
|
||||
|
||||
class User(_fake_db.Model):
|
||||
__tablename__ = "users"
|
||||
id = _fake_db.Column(_fake_db.String(36), primary_key=True)
|
||||
|
||||
|
||||
class Device(_fake_db.Model):
|
||||
__tablename__ = "devices"
|
||||
id = _fake_db.Column(_fake_db.String(36), primary_key=True)
|
||||
|
||||
|
||||
class PortalNetwork(_fake_db.Model):
|
||||
__tablename__ = "portal_networks"
|
||||
id = _fake_db.Column(_fake_db.String(36), primary_key=True)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# Step 4: Load the real network_access_request module from file
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
_model_spec = importlib.util.spec_from_file_location(
|
||||
"gatehouse_app.models.zerotier.network_access_request",
|
||||
"/home/ubuntu/securid/gatehouse-api/gatehouse_app/models/zerotier/network_access_request.py",
|
||||
submodule_search_locations=[],
|
||||
)
|
||||
_model_mod = importlib.util.module_from_spec(_model_spec)
|
||||
sys.modules["gatehouse_app.models.zerotier"] = type(sys)("gatehouse_app.models.zerotier")
|
||||
sys.modules["gatehouse_app.models.zerotier.network_access_request"] = _model_mod
|
||||
_model_spec.loader.exec_module(_model_mod)
|
||||
NetworkAccessRequest = _model_mod.NetworkAccessRequest
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# Fixture
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def model_class():
|
||||
"""Return the model class — table metadata is already built at definition time."""
|
||||
return NetworkAccessRequest
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def app():
|
||||
"""Minimal Flask app for to_dict (BaseModel.to_dict iterates __table__.columns)."""
|
||||
app = Flask(__name__)
|
||||
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
|
||||
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
|
||||
_fake_db.init_app(app)
|
||||
return app
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# Test data
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
EXPECTED_LOCAL_COLUMNS = {
|
||||
"organization_id", "user_id", "device_id", "portal_network_id",
|
||||
"granted_by_user_id", "grant_type", "status", "active",
|
||||
"justification", "join_seen",
|
||||
}
|
||||
|
||||
EXPECTED_INHERITED_COLUMNS = {"id", "created_at", "updated_at", "deleted_at"}
|
||||
ALL_EXPECTED = EXPECTED_LOCAL_COLUMNS | EXPECTED_INHERITED_COLUMNS
|
||||
|
||||
# FK columns that should have foreign keys (table name, FK target)
|
||||
EXPECTED_FKS = {
|
||||
"organization_id": "organizations.id",
|
||||
"user_id": "users.id",
|
||||
"device_id": "devices.id",
|
||||
"portal_network_id": "portal_networks.id",
|
||||
"granted_by_user_id": "users.id",
|
||||
}
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# Test: Module importability
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestImport:
|
||||
def test_model_importable(self, model_class):
|
||||
assert model_class is not None
|
||||
assert isinstance(model_class, type)
|
||||
|
||||
def test_model_tablename(self, model_class):
|
||||
assert model_class.__tablename__ == "network_access_requests"
|
||||
|
||||
def test_model_inherits_base(self, model_class):
|
||||
assert issubclass(model_class, FakeBaseModel)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# Test: Columns
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestColumns:
|
||||
def test_all_expected_columns_present(self, model_class):
|
||||
actual = {c.name for c in model_class.__table__.columns}
|
||||
missing = ALL_EXPECTED - actual
|
||||
assert missing == set(), f"Missing columns: {missing}"
|
||||
|
||||
def test_no_extra_columns(self, model_class):
|
||||
actual = {c.name for c in model_class.__table__.columns}
|
||||
extra = actual - ALL_EXPECTED
|
||||
assert extra == set(), f"Unexpected columns: {extra}"
|
||||
|
||||
def test_exact_column_count(self, model_class):
|
||||
assert len(model_class.__table__.columns) == 14, (
|
||||
f"Expected 14 columns, got {len(model_class.__table__.columns)}: "
|
||||
f"{sorted(c.name for c in model_class.__table__.columns)}"
|
||||
)
|
||||
|
||||
def test_organization_id_is_fk_string_not_null(self, model_class):
|
||||
col = model_class.__table__.columns["organization_id"]
|
||||
assert not col.nullable
|
||||
assert _has_foreign_key(col)
|
||||
|
||||
def test_user_id_is_fk_string_not_null(self, model_class):
|
||||
col = model_class.__table__.columns["user_id"]
|
||||
assert not col.nullable
|
||||
assert _has_foreign_key(col)
|
||||
|
||||
def test_device_id_is_fk_string_not_null(self, model_class):
|
||||
col = model_class.__table__.columns["device_id"]
|
||||
assert not col.nullable
|
||||
assert _has_foreign_key(col)
|
||||
|
||||
def test_portal_network_id_is_fk_string_not_null(self, model_class):
|
||||
col = model_class.__table__.columns["portal_network_id"]
|
||||
assert not col.nullable
|
||||
assert _has_foreign_key(col)
|
||||
|
||||
def test_granted_by_user_id_nullable_fk(self, model_class):
|
||||
col = model_class.__table__.columns["granted_by_user_id"]
|
||||
assert col.nullable
|
||||
assert _has_foreign_key(col)
|
||||
|
||||
def test_justification_is_text_nullable(self, model_class):
|
||||
col = model_class.__table__.columns["justification"]
|
||||
assert col.nullable
|
||||
assert "TEXT" in str(col.type).upper()
|
||||
|
||||
def test_active_is_boolean_not_null(self, model_class):
|
||||
col = model_class.__table__.columns["active"]
|
||||
assert str(col.type) in ("BOOLEAN", "INTEGER")
|
||||
assert not col.nullable
|
||||
|
||||
def test_join_seen_is_boolean_not_null(self, model_class):
|
||||
col = model_class.__table__.columns["join_seen"]
|
||||
assert str(col.type) in ("BOOLEAN", "INTEGER")
|
||||
assert not col.nullable
|
||||
|
||||
def test_fk_count(self, model_class):
|
||||
"""Verify exactly the expected FK columns have foreign keys."""
|
||||
fk_cols = {c.name for c in model_class.__table__.columns if _has_foreign_key(c)}
|
||||
assert fk_cols == set(EXPECTED_FKS.keys()), (
|
||||
f"FK columns {sorted(fk_cols)} != expected {sorted(EXPECTED_FKS.keys())}"
|
||||
)
|
||||
|
||||
|
||||
def _has_foreign_key(column):
|
||||
"""Check if column has at least one ForeignKey, without resolving target table."""
|
||||
return bool(column.foreign_keys)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# Test: UniqueConstraint
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestConstraints:
|
||||
def test_unique_constraint_exists(self, model_class):
|
||||
from sqlalchemy import UniqueConstraint
|
||||
ucs = [c for c in model_class.__table__.constraints if isinstance(c, UniqueConstraint)]
|
||||
assert len(ucs) >= 1, "No UniqueConstraint found"
|
||||
|
||||
def test_unique_constraint_columns(self, model_class):
|
||||
from sqlalchemy import UniqueConstraint
|
||||
ucs = [c for c in model_class.__table__.constraints if isinstance(c, UniqueConstraint)]
|
||||
assert len(ucs) == 1, f"Expected 1, found {len(ucs)}"
|
||||
cols = {col.name for col in ucs[0].columns}
|
||||
expected = {"user_id", "device_id", "portal_network_id", "deleted_at"}
|
||||
assert cols == expected, f"UniqueConstraint columns {cols} != {expected}"
|
||||
|
||||
def test_unique_constraint_name(self, model_class):
|
||||
from sqlalchemy import UniqueConstraint
|
||||
ucs = [c for c in model_class.__table__.constraints if isinstance(c, UniqueConstraint)]
|
||||
assert len(ucs) == 1
|
||||
assert ucs[0].name == "uix_user_device_network", (
|
||||
f"Expected 'uix_user_device_network', got '{ucs[0].name}'"
|
||||
)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# Test: Enum types
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestEnumTypes:
|
||||
def test_status_column_uses_approval_state_enum(self, model_class):
|
||||
col = model_class.__table__.columns["status"]
|
||||
assert hasattr(col.type, "enum_class"), (
|
||||
f"status column type {type(col.type)} has no enum_class"
|
||||
)
|
||||
assert col.type.enum_class is ApprovalState, (
|
||||
f"status enum is {col.type.enum_class}, expected ApprovalState"
|
||||
)
|
||||
|
||||
def test_grant_type_column_uses_approval_grant_type_enum(self, model_class):
|
||||
col = model_class.__table__.columns["grant_type"]
|
||||
assert hasattr(col.type, "enum_class"), (
|
||||
f"grant_type column type {type(col.type)} has no enum_class"
|
||||
)
|
||||
assert col.type.enum_class is ApprovalGrantType, (
|
||||
f"grant_type enum is {col.type.enum_class}, expected ApprovalGrantType"
|
||||
)
|
||||
|
||||
def test_status_column_not_nullable(self, model_class):
|
||||
assert not model_class.__table__.columns["status"].nullable
|
||||
|
||||
def test_grant_type_column_not_nullable(self, model_class):
|
||||
assert not model_class.__table__.columns["grant_type"].nullable
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# Test: Properties and methods
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestMethods:
|
||||
def test_repr_returns_string(self, model_class):
|
||||
instance = model_class()
|
||||
result = repr(instance)
|
||||
assert isinstance(result, str)
|
||||
assert "NetworkAccessRequest" in result
|
||||
|
||||
def test_active_session_property_returns_none(self, model_class):
|
||||
instance = model_class()
|
||||
assert instance.active_session is None
|
||||
|
||||
def test_to_dict_returns_dict(self, model_class, app):
|
||||
with app.app_context():
|
||||
instance = model_class()
|
||||
result = instance.to_dict()
|
||||
assert isinstance(result, dict)
|
||||
for col_name in EXPECTED_LOCAL_COLUMNS:
|
||||
assert col_name in result, f"Missing '{col_name}' in to_dict output"
|
||||
Reference in New Issue
Block a user