"""Unit tests for ca_key_encryption module.

WHAT: Tests for the Fernet-based CA private key encryption/decryption
utility functions.

WHY: CA private keys are the most sensitive data in the system; we need to
verify round-trip correctness, idempotency, and error handling.

EXPECTED: All encrypt/decrypt operations produce correct plaintext.
"""

import os
import threading
from unittest.mock import patch

import pytest

from gatehouse_app.utils.ca_key_encryption import (
    CAKeyEncryptionError,
    _FERNET_PREFIX,
    _get_fernet,  # noqa: F401 -- not referenced in this module; kept for compatibility
    decrypt_ca_key,
    encrypt_ca_key,
    is_encrypted,
    reencrypt_ca_key,
)

# ---------------------------------------------------------------------------
# Shared fixture
# ---------------------------------------------------------------------------

SAMPLE_PEM = (
    "-----BEGIN OPENSSH PRIVATE KEY-----\n"
    "b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtz\n"
    "c2gtZWQyNTUxOQAAACBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBAAA\n"
    "-----END OPENSSH PRIVATE KEY-----"
)


@pytest.fixture(autouse=True)
def _set_ca_encryption_key():
    """Ensure CA_ENCRYPTION_KEY is set for every test."""
    # patch.dict restores os.environ to its prior state when the block exits.
    with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": "test-secret-key-for-unit-tests"}):
        yield


# ---------------------------------------------------------------------------
# encrypt / decrypt round-trip
# ---------------------------------------------------------------------------


class TestEncryptDecryptRoundTrip:
    """Verify that encrypt -> decrypt returns the original plaintext."""

    def test_basic_round_trip(self):
        """TEST: ENC-RT-01 -- Encrypt then decrypt returns original PEM."""
        encrypted = encrypt_ca_key(SAMPLE_PEM)
        decrypted = decrypt_ca_key(encrypted)
        assert decrypted == SAMPLE_PEM

    def test_encrypted_value_has_prefix(self):
        """TEST: ENC-RT-02 -- Encrypted output carries the $fernet$ envelope."""
        encrypted = encrypt_ca_key(SAMPLE_PEM)
        assert encrypted.startswith(_FERNET_PREFIX)

    def test_different_ciphertext_each_time(self):
        """TEST: ENC-RT-03 -- Two encryptions of the same plaintext differ."""
        enc1 = encrypt_ca_key(SAMPLE_PEM)
        enc2 = encrypt_ca_key(SAMPLE_PEM)
        assert enc1 != enc2
        assert decrypt_ca_key(enc1) == SAMPLE_PEM
        assert decrypt_ca_key(enc2) == SAMPLE_PEM


# ---------------------------------------------------------------------------
# Idempotency
# ---------------------------------------------------------------------------


class TestIdempotency:
    """The module must not double-encrypt or double-decrypt."""

    def test_encrypt_idempotent(self):
        """TEST: ENC-IDEM-01 -- Encrypting an already-encrypted value is a no-op."""
        encrypted = encrypt_ca_key(SAMPLE_PEM)
        double = encrypt_ca_key(encrypted)
        assert double == encrypted

    def test_decrypt_plaintext_passthrough(self):
        """TEST: ENC-IDEM-02 -- Decrypting a plaintext (legacy) value returns it as-is."""
        result = decrypt_ca_key(SAMPLE_PEM)
        assert result == SAMPLE_PEM


# ---------------------------------------------------------------------------
# is_encrypted helper
# ---------------------------------------------------------------------------


class TestIsEncrypted:
    """Verify detection of the $fernet$ envelope."""

    def test_encrypted_value(self):
        """TEST: ENC-IE-01 -- is_encrypted returns True for $fernet$ values."""
        encrypted = encrypt_ca_key(SAMPLE_PEM)
        assert is_encrypted(encrypted) is True

    def test_plaintext_value(self):
        """TEST: ENC-IE-02 -- is_encrypted returns False for plain PEM."""
        assert is_encrypted(SAMPLE_PEM) is False

    def test_empty_string(self):
        """TEST: ENC-IE-03 -- is_encrypted returns False for empty string."""
        assert is_encrypted("") is False

    def test_none_value(self):
        """TEST: ENC-IE-04 -- is_encrypted returns False for None."""
        assert is_encrypted(None) is False


# ---------------------------------------------------------------------------
# Error handling
# ---------------------------------------------------------------------------


class TestErrorHandling:
    """Verify that invalid input and misconfiguration raise CAKeyEncryptionError."""

    def test_encrypt_empty_raises(self):
        """TEST: ENC-ERR-01 -- Encrypting empty string raises CAKeyEncryptionError."""
        with pytest.raises(CAKeyEncryptionError, match="empty"):
            encrypt_ca_key("")

    def test_decrypt_empty_raises(self):
        """TEST: ENC-ERR-02 -- Decrypting empty string raises CAKeyEncryptionError."""
        with pytest.raises(CAKeyEncryptionError, match="empty"):
            decrypt_ca_key("")

    def test_missing_key_raises(self):
        """TEST: ENC-ERR-03 -- Operations fail when CA_ENCRYPTION_KEY is unset."""
        # clear=True empties os.environ for the duration of the block, so
        # CA_ENCRYPTION_KEY is guaranteed absent without a separate pop().
        with patch.dict(os.environ, {}, clear=True):
            with pytest.raises(CAKeyEncryptionError, match="not set"):
                encrypt_ca_key(SAMPLE_PEM)

    def test_wrong_key_raises_on_decrypt(self):
        """TEST: ENC-ERR-04 -- Decrypting with the wrong key raises."""
        encrypted = encrypt_ca_key(SAMPLE_PEM)
        with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": "wrong-key"}):
            with pytest.raises(CAKeyEncryptionError, match="decryption failed"):
                decrypt_ca_key(encrypted)

    def test_corrupted_data_raises(self):
        """TEST: ENC-ERR-05 -- Decrypting corrupted ciphertext raises."""
        with pytest.raises(CAKeyEncryptionError):
            decrypt_ca_key("$fernet$not-a-real-token")


# ---------------------------------------------------------------------------
# reencrypt_ca_key -- key rotation
# ---------------------------------------------------------------------------


class TestReencrypt:
    """Verify key rotation via reencrypt_ca_key.

    decrypt_ca_key passes plaintext through unchanged (ENC-IDEM-02), so a
    successful decrypt alone does not prove the rotated value is actually
    encrypted; each test therefore also asserts is_encrypted on the result.
    """

    def test_reencrypt_round_trip(self):
        """TEST: ENC-RE-01 -- Re-encrypted value decrypts with the new key."""
        old_key = "old-secret-key"
        new_key = "new-secret-key"

        # Produce a genuine old-key ciphertext through the normal API.
        with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": old_key}):
            encrypted = encrypt_ca_key(SAMPLE_PEM)

        reencrypted = reencrypt_ca_key(encrypted, old_key, new_key)
        assert is_encrypted(reencrypted)
        assert reencrypted != encrypted

        # Verify it decrypts with the new key
        with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": new_key}):
            decrypted = decrypt_ca_key(reencrypted)
        assert decrypted == SAMPLE_PEM

        # After rotation the old key must no longer decrypt the value.
        with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": old_key}):
            with pytest.raises(CAKeyEncryptionError):
                decrypt_ca_key(reencrypted)

    def test_reencrypt_plaintext_key(self):
        """TEST: ENC-RE-02 -- Re-encrypting a legacy plaintext key works."""
        new_key = "brand-new-key"
        reencrypted = reencrypt_ca_key(SAMPLE_PEM, "any-old-key", new_key)
        assert is_encrypted(reencrypted)

        with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": new_key}):
            decrypted = decrypt_ca_key(reencrypted)
        assert decrypted == SAMPLE_PEM


# ---------------------------------------------------------------------------
# Thread safety
# ---------------------------------------------------------------------------


class TestThreadSafety:
    """Concurrent encrypt/decrypt calls must not corrupt state."""

    def test_concurrent_encrypt_decrypt(self):
        """TEST: ENC-TS-01 -- 50 threads encrypting/decrypting concurrently."""
        thread_count = 50
        errors = []
        results = []

        def worker(i):
            try:
                data = f"key-data-{i}"
                enc = encrypt_ca_key(data)
                dec = decrypt_ca_key(enc)
                results.append((i, dec))
            except Exception as exc:  # collected and reported by the main thread
                errors.append((i, exc))

        threads = [threading.Thread(target=worker, args=(i,)) for i in range(thread_count)]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=10)

        # join(timeout=...) returns silently on timeout; surface hangs explicitly
        # instead of letting them show up as a confusing result-count mismatch.
        hung = [t.name for t in threads if t.is_alive()]
        assert not hung, f"Threads did not finish within timeout: {hung}"

        assert not errors, f"Thread errors: {errors}"
        assert len(results) == thread_count
        for i, dec in results:
            assert dec == f"key-data-{i}", f"Thread {i}: expected 'key-data-{i}', got {dec!r}"