diff --git a/tests/api/v1/extauth/__init__.py b/tests/api/v1/extauth/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/api/v1/extauth/test_provider_helpers.py b/tests/api/v1/extauth/test_provider_helpers.py new file mode 100644 index 0000000..6596110 --- /dev/null +++ b/tests/api/v1/extauth/test_provider_helpers.py @@ -0,0 +1,52 @@ +import pytest +from gatehouse_app.utils.constants import AuthMethodType +from gatehouse_app.services.external_auth.models import ExternalAuthError +from gatehouse_app.api.v1.external_auth._helpers import ( + get_provider_type, + _get_provider_endpoints, +) + + +class TestProviderType: + def test_google(self): + assert get_provider_type("google") == AuthMethodType.GOOGLE + + def test_github(self): + assert get_provider_type("github") == AuthMethodType.GITHUB + + def test_microsoft(self): + assert get_provider_type("microsoft") == AuthMethodType.MICROSOFT + + def test_case_insensitive(self): + assert get_provider_type("GitHub") == AuthMethodType.GITHUB + + def test_unknown_provider_raises(self): + with pytest.raises(ExternalAuthError) as exc_info: + get_provider_type("facebook") + assert exc_info.value.status_code == 400 + assert "facebook" in exc_info.value.message.lower() + + +class TestProviderEndpoints: + def test_google_endpoints(self): + auth, token, userinfo = _get_provider_endpoints(AuthMethodType.GOOGLE) + assert "accounts.google.com" in auth + assert "oauth2.googleapis.com" in token + assert "googleapis.com" in userinfo + + def test_github_endpoints(self): + auth, token, userinfo = _get_provider_endpoints(AuthMethodType.GITHUB) + assert "github.com/login" in auth + assert "github.com/login/oauth/access_token" in token + assert "api.github.com/user" in userinfo + + def test_microsoft_endpoints(self): + auth, token, userinfo = _get_provider_endpoints(AuthMethodType.MICROSOFT) + assert "login.microsoftonline.com" in auth + assert "login.microsoftonline.com" in token + assert "graph.microsoft.com" in userinfo + + def test_unknown_type_raises(self): + with pytest.raises(ExternalAuthError) as exc_info: + _get_provider_endpoints("nonexistent") + assert exc_info.value.status_code == 400 diff --git a/tests/api/v1/organizations/__init__.py b/tests/api/v1/organizations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/api/v1/organizations/test_system_ca_dict.py b/tests/api/v1/organizations/test_system_ca_dict.py new file mode 100644 index 0000000..2780067 --- /dev/null +++ b/tests/api/v1/organizations/test_system_ca_dict.py @@ -0,0 +1,59 @@ +import pytest +from gatehouse_app.api.v1.organizations._helpers import _get_system_ca_dict +from gatehouse_app.config.ssh_ca_config import SSHCAConfig, reset_config_instance + +# Ed25519 key fixture data +VALID_PRIVATE_KEY = ( + "-----BEGIN OPENSSH PRIVATE KEY-----\n" + "b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\n" + "QyNTUxOQAAACCi+2CgIPgoFL5P6DZlNXztuHy3+TuS2shh/xIDkW89OgAAAJhDQd+ZQ0Hf\n" + "mQAAAAtzc2gtZWQyNTUxOQAAACCi+2CgIPgoFL5P6DZlNXztuHy3+TuS2shh/xIDkW89Og\n" + "AAAECMbnF+1E22w9Z1AOTUbUGspL8Pb0UyP+p8lSLpAwZSpaL7YKAg+CgUvk/oNmU1fO24\n" + "fLf5O5LayGH/EgORbz06AAAAD2NvcnlAbGFwdG9wLXZtMQECAwQFBg==\n" + "-----END OPENSSH PRIVATE KEY-----" +) + + +class FakeEmptyConfig(SSHCAConfig): + def get_str(self, key, default=""): + if key == "ca_key_path": + return "" + return default + + +class BadConfig(SSHCAConfig): + def get_str(self, key, default=""): + raise RuntimeError("config error") + + +class TestSystemCADict: + + def test_no_key_available_returns_none(self, monkeypatch): + monkeypatch.delenv("SSH_CA_PRIVATE_KEY", raising=False) + reset_config_instance() + monkeypatch.setattr( + "gatehouse_app.config.ssh_ca_config.get_ssh_ca_config", + lambda: FakeEmptyConfig(), + ) + result = _get_system_ca_dict() + assert result is None + + def test_env_var_returns_dict(self, monkeypatch): + monkeypatch.setenv("SSH_CA_PRIVATE_KEY", VALID_PRIVATE_KEY) + result = _get_system_ca_dict() + assert result is not None + assert result["ca_type"] == "user" + assert result["is_system"] is True + assert "fingerprint" in result + assert result["public_key"] + assert result["public_key"].startswith("ssh-") + + def test_exception_gracefully_returns_none(self, monkeypatch): + monkeypatch.delenv("SSH_CA_PRIVATE_KEY", raising=False) + reset_config_instance() + monkeypatch.setattr( + "gatehouse_app.config.ssh_ca_config.get_ssh_ca_config", + lambda: BadConfig(), + ) + result = _get_system_ca_dict() + assert result is None diff --git a/tests/api/v1/ssh/test_classify_key_material.py b/tests/api/v1/ssh/test_classify_key_material.py new file mode 100644 index 0000000..ea34252 --- /dev/null +++ b/tests/api/v1/ssh/test_classify_key_material.py @@ -0,0 +1,92 @@ +import pytest +from gatehouse_app.api.v1.ssh._helpers import _classify_ssh_key_material + + +class TestClassifySSHKeyMaterial: + def test_classifies_certificate(self): + result = _classify_ssh_key_material("ssh-ed25519-cert-v01@openssh.com AAAA comment") + assert result == "certificate" + + def test_classifies_ed25519_public_key(self): + result = _classify_ssh_key_material("ssh-ed25519 AAAAB3NzaC1lZDI1NTE5AAAAI... comment") + assert result == "public_key" + + def test_classifies_rsa_public_key(self): + result = _classify_ssh_key_material("ssh-rsa AAAAB3NzaC1yc2E... comment") + assert result == "public_key" + + def test_classifies_dss_public_key(self): + result = _classify_ssh_key_material("ssh-dss AAAAB3NzaC1kc3M... comment") + assert result == "public_key" + + def test_classifies_ecdsa_nistp256_public_key(self): + result = _classify_ssh_key_material("ecdsa-sha2-nistp256 AAAAE2Vj... comment") + assert result == "public_key" + + def test_classifies_ecdsa_nistp384_public_key(self): + result = _classify_ssh_key_material("ecdsa-sha2-nistp384 AAAAE2Vj... comment") + assert result == "public_key" + + def test_classifies_ecdsa_nistp521_public_key(self): + result = _classify_ssh_key_material("ecdsa-sha2-nistp521 AAAAE2Vj... comment") + assert result == "public_key" + + def test_classifies_sk_ed25519_public_key(self): + result = _classify_ssh_key_material( + "sk-ssh-ed25519@openssh.com AAAAGnNrLXNzaC1lZDI1NTE5... comment" + ) + assert result == "public_key" + + def test_classifies_openssh_private_key(self): + result = _classify_ssh_key_material( + "-----BEGIN OPENSSH PRIVATE KEY-----\n" + "base64data==\n" + "-----END OPENSSH PRIVATE KEY-----" + ) + assert result == "private_key" + + def test_classifies_rsa_private_key(self): + result = _classify_ssh_key_material( + "-----BEGIN RSA PRIVATE KEY-----\n" + "base64data==\n" + "-----END RSA PRIVATE KEY-----" + ) + assert result == "private_key" + + def test_unknown_for_empty_string(self): + result = _classify_ssh_key_material("") + assert result == "unknown" + + def test_unknown_for_whitespace_string(self): + result = _classify_ssh_key_material(" \n ") + assert result == "unknown" + + def test_unknown_for_gibberish(self): + result = _classify_ssh_key_material("not a valid ssh key") + assert result == "unknown" + + def test_unknown_for_unsupported_key_type(self): + result = _classify_ssh_key_material("ssh-nonsense AAAABogus...") + assert result == "unknown" + + @pytest.mark.parametrize("raw,expected", [ + ("ssh-rsa AAAAB3Nza... user@host", "public_key"), + ("ssh-ed25519 AAAAC3... john@laptop", "public_key"), + ("ecdsa-sha2-nistp256 AAAAE2Vj... me@box", "public_key"), + ("sk-ssh-ed25519@openssh.com AAAAGn...", "public_key"), + ("ssh-ed25519-cert-v01@openssh.com AAAAB3Nza cert for user", "certificate"), + ( + "-----BEGIN OPENSSH PRIVATE KEY-----\n" + "abcdefghijklmnopqrstuvwxyz\n" + "-----END OPENSSH PRIVATE KEY-----", + "private_key", + ), + ("", "unknown"), + ("totally random garbage here", "unknown"), + ]) + def test_parametrized_variants(self, raw, expected): + assert _classify_ssh_key_material(raw) == expected + + def test_certificate_with_leading_whitespace(self): + raw = " ssh-ed25519-cert-v01@openssh.com AAAAB3Nza extra words" + assert _classify_ssh_key_material(raw) == "certificate" \ No newline at end of file diff --git a/tests/api/v1/ssh/test_dept_cert_policy.py b/tests/api/v1/ssh/test_dept_cert_policy.py new file mode 100644 index 0000000..0b90805 --- /dev/null +++ b/tests/api/v1/ssh/test_dept_cert_policy.py @@ -0,0 +1,275 @@ +import pytest +from datetime import datetime, timezone +from gatehouse_app.extensions import db +from gatehouse_app.models.organization.department import ( + Department, + DepartmentMembership, +) +from gatehouse_app.models.organization.department_cert_policy import DepartmentCertPolicy +from gatehouse_app.api.v1.ssh._helpers import _get_merged_dept_cert_policy + + +class TestDeptCertPolicy: + def test_no_departments_returns_none(self, app, test_user): + with app.app_context(): + result = _get_merged_dept_cert_policy(test_user) + assert result is None + + def test_department_without_policy_returns_none(self, app, test_user, test_org): + with app.app_context(): + dept = Department( + organization_id=test_org, + name="No Policy Dept", + ) + db.session.add(dept) + db.session.commit() + + membership = DepartmentMembership( + user_id=test_user, + department_id=dept.id, + ) + db.session.add(membership) + db.session.commit() + + result = _get_merged_dept_cert_policy(test_user) + assert result is None + + def test_single_department_policy(self, app, test_user, test_org): + with app.app_context(): + dept = Department( + organization_id=test_org, + name="Engineering", + ) + db.session.add(dept) + db.session.commit() + + membership = DepartmentMembership( + user_id=test_user, + department_id=dept.id, + ) + db.session.add(membership) + db.session.commit() + + policy = DepartmentCertPolicy( + department_id=dept.id, + allow_user_expiry=True, + default_expiry_hours=4, + max_expiry_hours=48, + allowed_extensions=["permit-pty", "permit-agent-forwarding"], + ) + db.session.add(policy) + db.session.commit() + + result = _get_merged_dept_cert_policy(test_user) + assert result is not None + assert result["allow_user_expiry"] is True + assert result["default_expiry_hours"] == 4 + assert result["max_expiry_hours"] == 48 + assert set(result["extensions"]) == {"permit-pty", "permit-agent-forwarding"} + + def test_both_departments_same_policies(self, app, test_user, test_org): + with app.app_context(): + dept1 = Department( + organization_id=test_org, + name="Engineering", + ) + dept2 = Department( + organization_id=test_org, + name="SRE", + ) + db.session.add_all([dept1, dept2]) + db.session.commit() + + member1 = DepartmentMembership(user_id=test_user, department_id=dept1.id) + member2 = DepartmentMembership(user_id=test_user, department_id=dept2.id) + db.session.add_all([member1, member2]) + db.session.commit() + + policy1 = DepartmentCertPolicy( + department_id=dept1.id, + allow_user_expiry=True, + default_expiry_hours=4, + max_expiry_hours=48, + allowed_extensions=["permit-pty", "permit-agent-forwarding"], + ) + policy2 = DepartmentCertPolicy( + department_id=dept2.id, + allow_user_expiry=True, + default_expiry_hours=4, + max_expiry_hours=48, + allowed_extensions=["permit-pty", "permit-agent-forwarding"], + ) + db.session.add_all([policy1, policy2]) + db.session.commit() + + result = _get_merged_dept_cert_policy(test_user) + assert result["allow_user_expiry"] is True + assert result["default_expiry_hours"] == 4 + assert result["max_expiry_hours"] == 48 + + def test_merges_min_expiry_across_departments(self, app, test_user, test_org): + with app.app_context(): + dept1 = Department( + organization_id=test_org, + name="Engineering", + ) + dept2 = Department( + organization_id=test_org, + name="SRE", + ) + db.session.add_all([dept1, dept2]) + db.session.commit() + + member1 = DepartmentMembership(user_id=test_user, department_id=dept1.id) + member2 = DepartmentMembership(user_id=test_user, department_id=dept2.id) + db.session.add_all([member1, member2]) + db.session.commit() + + policy1 = DepartmentCertPolicy( + department_id=dept1.id, + allow_user_expiry=True, + default_expiry_hours=24, + max_expiry_hours=720, + ) + policy2 = DepartmentCertPolicy( + department_id=dept2.id, + allow_user_expiry=True, + default_expiry_hours=1, + max_expiry_hours=72, + ) + db.session.add_all([policy1, policy2]) + db.session.commit() + + result = _get_merged_dept_cert_policy(test_user) + assert result["default_expiry_hours"] == 1 + assert result["max_expiry_hours"] == 72 + + def test_extends_intersection_across_departments(self, app, test_user, test_org): + with app.app_context(): + dept1 = Department( + organization_id=test_org, + name="Engineering", + ) + dept2 = Department( + organization_id=test_org, + name="SRE", + ) + db.session.add_all([dept1, dept2]) + db.session.commit() + + member1 = DepartmentMembership(user_id=test_user, department_id=dept1.id) + member2 = DepartmentMembership(user_id=test_user, department_id=dept2.id) + db.session.add_all([member1, member2]) + db.session.commit() + + policy1 = DepartmentCertPolicy( + department_id=dept1.id, + allowed_extensions=["permit-pty", "permit-agent-forwarding"], + ) + policy2 = DepartmentCertPolicy( + department_id=dept2.id, + allowed_extensions=["permit-pty", "permit-port-forwarding"], + ) + db.session.add_all([policy1, policy2]) + db.session.commit() + + result = _get_merged_dept_cert_policy(test_user) + assert set(result["extensions"]) == {"permit-pty"} + + def test_any_false_user_expiry_means_overall_false( + self, app, test_user, test_org + ): + with app.app_context(): + dept1 = Department( + organization_id=test_org, + name="Engineering", + ) + dept2 = Department( + organization_id=test_org, + name="SRE", + ) + db.session.add_all([dept1, dept2]) + db.session.commit() + + member1 = DepartmentMembership(user_id=test_user, department_id=dept1.id) + member2 = DepartmentMembership(user_id=test_user, department_id=dept2.id) + db.session.add_all([member1, member2]) + db.session.commit() + + policy1 = DepartmentCertPolicy( + department_id=dept1.id, + allow_user_expiry=True, + ) + policy2 = DepartmentCertPolicy( + department_id=dept2.id, + allow_user_expiry=False, + ) + db.session.add_all([policy1, policy2]) + db.session.commit() + + result = _get_merged_dept_cert_policy(test_user) + assert result["allow_user_expiry"] is False + + def test_deleted_department_filtered(self, app, test_user, test_org): + with app.app_context(): + active_dept = Department( + organization_id=test_org, + name="Active Dept", + ) + deleted_dept = Department( + organization_id=test_org, + name="Deleted Dept", + deleted_at=datetime.now(timezone.utc), + ) + db.session.add_all([active_dept, deleted_dept]) + db.session.commit() + + active_member = DepartmentMembership( + user_id=test_user, department_id=active_dept.id + ) + deleted_member = DepartmentMembership( + user_id=test_user, + department_id=deleted_dept.id, + deleted_at=datetime.now(timezone.utc), + ) + db.session.add_all([active_member, deleted_member]) + db.session.commit() + + policy = DepartmentCertPolicy( + department_id=active_dept.id, + allow_user_expiry=True, + default_expiry_hours=12, + max_expiry_hours=96, + ) + db.session.add(policy) + db.session.commit() + + result = _get_merged_dept_cert_policy(test_user) + assert result is not None + assert result["default_expiry_hours"] == 12 + + def test_single_department_no_extensions(self, app, test_user, test_org): + with app.app_context(): + dept = Department( + organization_id=test_org, + name="Minimal Dept", + ) + db.session.add(dept) + db.session.commit() + + membership = DepartmentMembership( + user_id=test_user, department_id=dept.id + ) + db.session.add(membership) + db.session.commit() + + policy = DepartmentCertPolicy( + department_id=dept.id, + allowed_extensions=[], + ) + db.session.add(policy) + db.session.commit() + + result = _get_merged_dept_cert_policy(test_user) + assert result is not None + assert result["extensions"] == [] \ No newline at end of file diff --git a/tests/api/v1/ssh/test_org_ca_for_user.py b/tests/api/v1/ssh/test_org_ca_for_user.py new file mode 100644 index 0000000..2e9832e --- /dev/null +++ b/tests/api/v1/ssh/test_org_ca_for_user.py @@ -0,0 +1,118 @@ +import pytest +from uuid import uuid4 +from datetime import datetime, timezone +from gatehouse_app.extensions import db +from gatehouse_app.models.user.user import User +from gatehouse_app.models.organization.organization import Organization +from gatehouse_app.models.organization.organization_member import OrganizationMember +from gatehouse_app.models.ssh_ca.ca import CA, CaType, KeyType +from gatehouse_app.api.v1.ssh._helpers import _get_org_ca_for_user +from gatehouse_app.utils.constants import OrganizationRole + + +class TestOrgCAForUser: + def test_organization_id_param_overrides_membership(self, app, test_user, test_org, test_ca, test_membership): + with app.app_context(): + org2 = Organization(name="Org 2", slug="org-2") + db.session.add(org2) + db.session.commit() + + ca2 = CA( + organization_id=org2.id, + name="Org 2 CA", + ca_type=CaType.USER, + key_type=KeyType.ED25519, + private_key="key2", + public_key="pubkey2", + fingerprint="sha256:org2...", + is_active=True, + ) + db.session.add(ca2) + db.session.commit() + + user = db.session.get(User, test_user) + result = _get_org_ca_for_user(user, ca_type="user", organization_id=test_org) + assert result is not None + assert result.organization_id == test_org + + def test_multiple_orgs_returns_ca(self, app, test_user, test_org, test_ca, test_membership): + with app.app_context(): + org2 = Organization(name="Org 2", slug="org-2") + db.session.add(org2) + db.session.commit() + + user = db.session.get(User, test_user) + member2 = OrganizationMember( + user_id=test_user, organization_id=org2.id, role=OrganizationRole.MEMBER + ) + db.session.add(member2) + db.session.commit() + + result = _get_org_ca_for_user(user, ca_type="user") + assert result is not None + + def test_user_with_no_memberships_returns_none(self, app): + with app.app_context(): + user = User(email="lonely@test.com", full_name="Lonely User") + db.session.add(user) + db.session.commit() + + result = _get_org_ca_for_user(user, ca_type="user") + assert result is None + + def test_inactive_ca_not_returned(self, app, test_user, test_org, test_membership): + with app.app_context(): + ca = CA( + organization_id=test_org, + name="Inactive CA", + ca_type=CaType.USER, + key_type=KeyType.ED25519, + private_key="key", + public_key="pubkey", + fingerprint="sha256:inactive123...", + is_active=False, + ) + db.session.add(ca) + db.session.commit() + + user = db.session.get(User, test_user) + result = _get_org_ca_for_user(user, ca_type="user") + assert result is None + + def test_host_ca_not_returned_when_user_requested(self, app, test_user, test_org, test_membership): + with app.app_context(): + ca = CA( + organization_id=test_org, + name="Host CA", + ca_type=CaType.HOST, + key_type=KeyType.ED25519, + private_key="key", + public_key="pubkey", + fingerprint="sha256:host123...", + is_active=True, + ) + db.session.add(ca) + db.session.commit() + + user = db.session.get(User, test_user) + result = _get_org_ca_for_user(user, ca_type="user") + assert result is None + + def test_user_ca_not_returned_when_host_requested(self, app, test_user, test_org, test_membership): + with app.app_context(): + ca = CA( + organization_id=test_org, + name="User CA", + ca_type=CaType.USER, + key_type=KeyType.ED25519, + private_key="key", + public_key="pubkey", + fingerprint="sha256:useronly...", + is_active=True, + ) + db.session.add(ca) + db.session.commit() + + user = db.session.get(User, test_user) + result = _get_org_ca_for_user(user, ca_type="host") + assert result is None \ No newline at end of file diff --git a/tests/api/v1/ssh/test_persist_certificate.py b/tests/api/v1/ssh/test_persist_certificate.py new file mode 100644 index 0000000..7c9122e --- /dev/null +++ b/tests/api/v1/ssh/test_persist_certificate.py @@ -0,0 +1,180 @@ +import pytest +from uuid import uuid4 +from datetime import datetime, timezone, timedelta +from gatehouse_app.extensions import db +from gatehouse_app.models.user.user import User +from gatehouse_app.models.ssh_ca.ca import CA, CaType, KeyType, CertType +from gatehouse_app.models.ssh_ca.ssh_key import SSHKey +from gatehouse_app.models.ssh_ca.ssh_certificate import SSHCertificate, CertificateStatus +from gatehouse_app.services.ssh_ca_signing_service import SSHCertificateSigningResponse +from gatehouse_app.api.v1.ssh._helpers import _persist_certificate + + +class TestPersistCertificate: + def test_persists_valid_certificate(self, app, test_user, test_org): + with app.app_context(): + ca = CA( + organization_id=test_org, + name="Signing CA", + ca_type=CaType.USER, + key_type=KeyType.ED25519, + private_key="enc_priv", + public_key="ssh-ed25519 AAAAB3Nza...", + fingerprint="sha256:abc123...", + is_active=True, + ) + db.session.add(ca) + db.session.commit() + + ssh_key = SSHKey( + user_id=test_user, + payload="ssh-ed25519 AAAAB3NzaC1lZDI1NTE5AAAAIKeyData comment", + fingerprint="sha256:keyfp123...", + ) + db.session.add(ssh_key) + db.session.commit() + + now = datetime.now(timezone.utc) + later = now + timedelta(hours=24) + response = SSHCertificateSigningResponse( + certificate="ssh-ed25519-cert-v01@openssh.com AAAACertData...", + serial="123456", + valid_after=now, + valid_before=later, + principals=["eng-prod"], + ) + + result = _persist_certificate( + user_id=test_user, + ssh_key_id=ssh_key.id, + ca=ca, + signing_response=response, + request_ip="10.0.0.1", + cert_type_str="user", + cert_identity="user@example.com", + ) + + assert result is not None + assert result.ca_id == ca.id + assert result.user_id == test_user + assert result.ssh_key_id == ssh_key.id + assert result.cert_type == CertType.USER + assert result.certificate == response.certificate + assert result.serial == response.serial + assert result.valid_after.replace(tzinfo=None) == now.replace(tzinfo=None) + assert result.valid_before.replace(tzinfo=None) == later.replace(tzinfo=None) + assert result.request_ip == "10.0.0.1" + assert result.key_id == "user@example.com" + assert sorted(result.principals) == ["eng-prod"] + assert result.revoked is False + assert result.status == CertificateStatus.ISSUED + + def test_none_ca_returns_none(self, app, test_user): + with app.app_context(): + now = datetime.now(timezone.utc) + response = SSHCertificateSigningResponse( + certificate="cert-data", + serial="1", + valid_after=now, + valid_before=now + timedelta(hours=1), + ) + result = _persist_certificate(test_user, "keyid", None, response) + assert result is None + + def test_invalid_cert_type_str_falls_back_to_user(self, app, test_user, test_org): + with app.app_context(): + ca = CA( + organization_id=test_org, + name="Signing CA", + ca_type=CaType.USER, + key_type=KeyType.ED25519, + private_key="enc_priv", + public_key="ssh-ed25519 AAAAB3Nza...", + fingerprint="sha256:fallback123...", + is_active=True, + ) + db.session.add(ca) + db.session.commit() + + now = datetime.now(timezone.utc) + response = SSHCertificateSigningResponse( + certificate="cert-data", + serial="1", + valid_after=now, + valid_before=now + timedelta(hours=1), + ) + result = _persist_certificate( + user_id=test_user, + ssh_key_id=None, + ca=ca, + signing_response=response, + cert_type_str="invalid_type", + ) + + assert result is not None + assert result.cert_type == CertType.USER + + def test_none_ssh_key_id_defaults_to_host_cert_key_id(self, app, test_user, test_org): + with app.app_context(): + ca = CA( + organization_id=test_org, + name="Host CA", + ca_type=CaType.HOST, + key_type=KeyType.ED25519, + private_key="enc_priv", + public_key="ssh-ed25519 AAAAB3Nza...", + fingerprint="sha256:hostca123...", + is_active=True, + ) + db.session.add(ca) + db.session.commit() + + now = datetime.now(timezone.utc) + response = SSHCertificateSigningResponse( + certificate="cert-data", + serial="1", + valid_after=now, + valid_before=now + timedelta(hours=1), + ) + result = _persist_certificate( + user_id=test_user, + ssh_key_id=None, + ca=ca, + signing_response=response, + ) + + assert result is not None + assert result.key_id == "host-cert" + + def test_request_ip_stored(self, app, test_user, test_org): + with app.app_context(): + ca = CA( + organization_id=test_org, + name="Signing CA", + ca_type=CaType.USER, + key_type=KeyType.ED25519, + private_key="enc_priv", + public_key="ssh-ed25519 AAAAB3Nza...", + fingerprint="sha256:ip...", + is_active=True, + ) + db.session.add(ca) + db.session.commit() + + now = datetime.now(timezone.utc) + response = SSHCertificateSigningResponse( + certificate="cert-data", + serial="1", + valid_after=now, + valid_before=now + timedelta(hours=1), + ) + result = _persist_certificate( + user_id=test_user, + ssh_key_id=None, + ca=ca, + signing_response=response, + request_ip="192.168.1.100", + ) + + assert result is not None + assert result.request_ip == "192.168.1.100" \ No newline at end of file diff --git a/tests/api/v1/ssh/test_ssh_key_service.py b/tests/api/v1/ssh/test_ssh_key_service.py new file mode 100644 index 0000000..a9e41f2 --- /dev/null +++ b/tests/api/v1/ssh/test_ssh_key_service.py @@ -0,0 +1,78 @@ +import pytest +from datetime import datetime, timezone +from gatehouse_app.extensions import db +from gatehouse_app.models.user.user import User +from gatehouse_app.models.ssh_ca.ssh_key import SSHKey +from gatehouse_app.services.ssh_key_service import SSHKeyService +from gatehouse_app.exceptions import UserNotFoundError, SSHKeyError + + +VALID_PUBLIC_KEY = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKL7YKAg+CgUvk/oNmU1fO24fLf5O5LayGH/EgORbz06" + + +class TestSSHKeyServiceAdd: + + def test_add_new_key_returns_true(self, app, test_user): + with app.app_context(): + service = SSHKeyService() + key, is_new = service.add_ssh_key(test_user, VALID_PUBLIC_KEY, "My laptop") + assert is_new is True + assert key.user_id == test_user + assert key.payload == VALID_PUBLIC_KEY + assert key.description == "My laptop" + assert key.verified is False + assert key.fingerprint is not None + assert key.key_type is not None + + def test_add_duplicate_returns_existing(self, app, test_user): + with app.app_context(): + service = SSHKeyService() + key1, _ = service.add_ssh_key(test_user, VALID_PUBLIC_KEY) + key2, is_new = service.add_ssh_key(test_user, VALID_PUBLIC_KEY) + assert is_new is False + assert key2.id == key1.id + + def test_add_restores_soft_deleted_key(self, app, test_user): + with app.app_context(): + service = SSHKeyService() + key1, _ = service.add_ssh_key(test_user, VALID_PUBLIC_KEY, "Original") + + # Soft-delete the key + key1.deleted_at = datetime.now(timezone.utc) + db.session.commit() + + # Re-add same key + key2, is_new = service.add_ssh_key(test_user, VALID_PUBLIC_KEY, "Restored") + assert is_new is False + assert key2.id == key1.id + assert key2.deleted_at is None + assert key2.description == "Restored" + assert key2.verified is False + assert key2.verified_at is None + + def test_add_with_description(self, app, test_user): + with app.app_context(): + service = SSHKeyService() + key, is_new = service.add_ssh_key(test_user, VALID_PUBLIC_KEY, "Work laptop") + assert is_new is True + assert key.description == "Work laptop" + + def test_user_not_found_raises(self, app): + with app.app_context(): + service = SSHKeyService() + with pytest.raises(UserNotFoundError): + service.add_ssh_key("nonexistent-user-id", VALID_PUBLIC_KEY) + + def test_invalid_key_format_raises(self, app, test_user): + with app.app_context(): + service = SSHKeyService() + with pytest.raises(SSHKeyError): + service.add_ssh_key(test_user, "not-a-valid-key") + + def test_idempotent_second_call_no_error(self, app, test_user): + with app.app_context(): + service = SSHKeyService() + service.add_ssh_key(test_user, VALID_PUBLIC_KEY) + key2, is_new = service.add_ssh_key(test_user, VALID_PUBLIC_KEY) + assert is_new is False + assert key2 is not None diff --git a/tests/api/v1/test_cert_signing_request.py b/tests/api/v1/test_cert_signing_request.py new file mode 100644 index 0000000..ffee864 --- /dev/null +++ b/tests/api/v1/test_cert_signing_request.py @@ -0,0 +1,148 @@ +import pytest +from gatehouse_app.services.ssh_ca_signing_service import SSHCertificateSigningRequest + + +VALID_PUBLIC_KEY = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKL7YKAg+CgUvk/oNmU1fO24fLf5O5LayGH/EgORbz06" + + +class TestCertSigningRequestValidate: + + @pytest.fixture(autouse=True) + def patch_config(self, monkeypatch): + from gatehouse_app.config.ssh_ca_config import SSHCAConfig + + class TestConfig(SSHCAConfig): + def get_int(self, key, default=0): + values = { + "max_cert_validity_hours": 720, + "max_principals_per_cert": 256, + "max_key_id_length": 255, + } + return values.get(key, default) + + monkeypatch.setattr( + "gatehouse_app.config.ssh_ca_config.get_ssh_ca_config", + lambda: TestConfig(), + ) + + def test_valid_request_no_errors(self): + req = SSHCertificateSigningRequest( + ssh_public_key=VALID_PUBLIC_KEY, + principals=["eng-prod"], + key_id="user@example.com", + ) + errors = req.validate() + assert errors == [] + + def test_valid_host_cert_no_errors(self): + req = SSHCertificateSigningRequest( + ssh_public_key=VALID_PUBLIC_KEY, + principals=["host1.example.com"], + key_id="host-identity", + cert_type="host", + ) + errors = req.validate() + assert errors == [] + + def test_invalid_cert_type(self): + req = SSHCertificateSigningRequest( + ssh_public_key=VALID_PUBLIC_KEY, + principals=["eng-prod"], + key_id="user@example.com", + cert_type="invalid", + ) + errors = req.validate() + assert any("cert_type" in e.lower() for e in errors) + + def test_missing_public_key(self): + req = SSHCertificateSigningRequest( + ssh_public_key="", + principals=["eng-prod"], + key_id="user@example.com", + ) + errors = req.validate() + assert any("public key" in e.lower() for e in errors) + + def test_malformed_public_key(self): + req = SSHCertificateSigningRequest( + ssh_public_key="not-a-key", + principals=["eng-prod"], + key_id="user@example.com", + ) + errors = req.validate() + assert any("public key" in e.lower() for e in errors) + + def test_no_principals(self): + req = SSHCertificateSigningRequest( + ssh_public_key=VALID_PUBLIC_KEY, + principals=[], + key_id="user@example.com", + ) + errors = req.validate() + assert any("principal" in e.lower() for e in errors) + + def test_too_many_principals(self): + req = SSHCertificateSigningRequest( + ssh_public_key=VALID_PUBLIC_KEY, + principals=[f"p{i}" for i in range(300)], + key_id="user@example.com", + ) + errors = req.validate() + assert any("too many" in e.lower() for e in errors) + + def test_missing_key_id(self): + req = SSHCertificateSigningRequest( + ssh_public_key=VALID_PUBLIC_KEY, + principals=["eng-prod"], + key_id="", + ) + errors = req.validate() + assert any("key_id" in e.lower() for e in errors) + + def test_key_id_too_short(self): + req = SSHCertificateSigningRequest( + ssh_public_key=VALID_PUBLIC_KEY, + principals=["eng-prod"], + key_id="ab", + ) + errors = req.validate() + assert any("key_id" in e.lower() for e in errors) + + def test_key_id_exceeds_max_length(self): + req = SSHCertificateSigningRequest( + ssh_public_key=VALID_PUBLIC_KEY, + principals=["eng-prod"], + key_id="x" * 300, + ) + errors = req.validate() + assert any("key_id" in e.lower() for e in errors) + + def test_non_positive_expiry(self): + req = SSHCertificateSigningRequest( + ssh_public_key=VALID_PUBLIC_KEY, + principals=["eng-prod"], + key_id="user@example.com", + expiry_hours=0, + ) + errors = req.validate() + assert any("expiry" in e.lower() for e in errors) + + def test_expiry_exceeds_max(self): + req = SSHCertificateSigningRequest( + ssh_public_key=VALID_PUBLIC_KEY, + principals=["eng-prod"], + key_id="user@example.com", + expiry_hours=99999, + ) + errors = req.validate() + assert any("expiry" in e.lower() for e in errors) + + def test_none_expiry_is_ok(self): + req = SSHCertificateSigningRequest( + ssh_public_key=VALID_PUBLIC_KEY, + principals=["eng-prod"], + key_id="user@example.com", + expiry_hours=None, + ) + errors = req.validate() + assert errors == [] diff --git a/tests/api/v1/test_superadmin_schemas.py b/tests/api/v1/test_superadmin_schemas.py new file mode 100644 index 0000000..36b7f88 --- /dev/null +++ b/tests/api/v1/test_superadmin_schemas.py @@ -0,0 +1,77 @@ +from gatehouse_app.api.v1.superadmin.organizations import ( + ListOrganizationsSchema, + UpdateOrganizationSchema, +) + + +class TestListOrganizationsSchema: + def test_defaults_when_empty(self): + result = ListOrganizationsSchema.load({}) + assert result["page"] == 1 + assert result["per_page"] == 20 + assert result["search"] is None + assert result["status"] is None + assert result["plan_slug"] is None + + def test_normal_pagination(self): + result = ListOrganizationsSchema.load({"page": 3, "per_page": 10}) + assert result["page"] == 3 + assert result["per_page"] == 10 + + def test_page_zero_clamped_to_one(self): + result = ListOrganizationsSchema.load({"page": 0}) + assert result["page"] == 1 + + def test_negative_per_page_clamped_to_one(self): + result = ListOrganizationsSchema.load({"per_page": -5}) + assert result["per_page"] == 1 + + def test_per_page_exceeds_max_clamped_to_100(self): + result = ListOrganizationsSchema.load({"per_page": 200}) + assert result["per_page"] == 100 + + def test_non_integer_values_fallback(self): + result = ListOrganizationsSchema.load({"page": "abc", "per_page": "xyz"}) + assert result["page"] == 1 + assert result["per_page"] == 20 + + def test_search_passthrough(self): + result = ListOrganizationsSchema.load({"search": "acme"}) + assert result["search"] == "acme" + + def test_status_passthrough(self): + result = ListOrganizationsSchema.load({"status": "active"}) + assert result["status"] == "active" + + def test_plan_slug_passthrough(self): + result = ListOrganizationsSchema.load({"plan_slug": "pro"}) + assert result["plan_slug"] == "pro" + + +class TestUpdateOrganizationSchema: + def test_all_fields(self): + result = UpdateOrganizationSchema.load({ + "name": "New Name", + "description": "New Description", + "is_active": True, + }) + assert result == { + "name": "New Name", + "description": "New Description", + "is_active": True, + } + + def test_empty_dict(self): + result = UpdateOrganizationSchema.load({}) + assert result == {} + + def test_partial_data(self): + result = UpdateOrganizationSchema.load({"name": "Renamed Only"}) + assert result == {"name": "Renamed Only"} + + def test_is_active_coerced_to_bool(self): + result = UpdateOrganizationSchema.load({"is_active": "truthy"}) + assert result["is_active"] is True + + result = UpdateOrganizationSchema.load({"is_active": ""}) + assert result["is_active"] is False