Files

206 lines
7.7 KiB
Python
Raw Permalink Normal View History

"""Unit tests for ca_key_encryption module.
WHAT: Tests for the Fernet-based CA private key encryption/decryption
utility functions.
WHY: CA private keys are the most sensitive data in the system; we need
to verify round-trip correctness, idempotency, and error handling.
EXPECTED: All encrypt/decrypt operations produce correct plaintext.
"""
import os
import threading
from unittest.mock import patch
import pytest
from gatehouse_app.utils.ca_key_encryption import (
CAKeyEncryptionError,
_FERNET_PREFIX,
_get_fernet,
decrypt_ca_key,
encrypt_ca_key,
is_encrypted,
reencrypt_ca_key,
)
# ---------------------------------------------------------------------------
# Shared fixture
# ---------------------------------------------------------------------------
# Synthetic OpenSSH private key used as the plaintext in most tests below.
# The body is padded with placeholder "B" characters, so it is not expected
# to be a usable key -- only its exact text matters for round-trip checks.
# Lines are newline-separated with no trailing newline.
SAMPLE_PEM = "\n".join(
    [
        "-----BEGIN OPENSSH PRIVATE KEY-----",
        "b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtz",
        "c2gtZWQyNTUxOQAAACBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBAAA",
        "-----END OPENSSH PRIVATE KEY-----",
    ]
)
@pytest.fixture(autouse=True)
def _set_ca_encryption_key():
    """Give every test in this module a known CA_ENCRYPTION_KEY.

    ``patch.dict`` puts the original environment back once the test
    finishes, so the test key never outlives the test that used it.
    """
    env_patch = patch.dict(os.environ, CA_ENCRYPTION_KEY="test-secret-key-for-unit-tests")
    with env_patch:
        yield
# ---------------------------------------------------------------------------
# encrypt / decrypt round-trip
# ---------------------------------------------------------------------------
class TestEncryptDecryptRoundTrip:
    """decrypt_ca_key(encrypt_ca_key(x)) must give back x unchanged."""

    def test_basic_round_trip(self):
        """TEST: ENC-RT-01 -- Encrypt then decrypt returns original PEM."""
        assert decrypt_ca_key(encrypt_ca_key(SAMPLE_PEM)) == SAMPLE_PEM

    def test_encrypted_value_has_prefix(self):
        """TEST: ENC-RT-02 -- Encrypted output carries the $fernet$ envelope."""
        ciphertext = encrypt_ca_key(SAMPLE_PEM)
        assert ciphertext.startswith(_FERNET_PREFIX)

    def test_different_ciphertext_each_time(self):
        """TEST: ENC-RT-03 -- Two encryptions of the same plaintext differ."""
        # Fernet tokens embed a random IV, so encryption is expected to be
        # non-deterministic -- yet every token must still decrypt correctly.
        first, second = (encrypt_ca_key(SAMPLE_PEM) for _ in range(2))
        assert first != second
        for ciphertext in (first, second):
            assert decrypt_ca_key(ciphertext) == SAMPLE_PEM
# ---------------------------------------------------------------------------
# Idempotency
# ---------------------------------------------------------------------------
class TestIdempotency:
    """Values must never be encrypted twice, and plaintext never 'decrypted'."""

    def test_encrypt_idempotent(self):
        """TEST: ENC-IDEM-01 -- Encrypting an already-encrypted value is a no-op."""
        once = encrypt_ca_key(SAMPLE_PEM)
        assert encrypt_ca_key(once) == once

    def test_decrypt_plaintext_passthrough(self):
        """TEST: ENC-IDEM-02 -- Decrypting a plaintext (legacy) value returns it as-is."""
        assert decrypt_ca_key(SAMPLE_PEM) == SAMPLE_PEM
# ---------------------------------------------------------------------------
# is_encrypted helper
# ---------------------------------------------------------------------------
class TestIsEncrypted:
    """is_encrypted() returns exactly True/False depending on the envelope."""

    def test_encrypted_value(self):
        """TEST: ENC-IE-01 -- is_encrypted returns True for $fernet$ values."""
        assert is_encrypted(encrypt_ca_key(SAMPLE_PEM)) is True

    def test_plaintext_value(self):
        """TEST: ENC-IE-02 -- is_encrypted returns False for plain PEM."""
        verdict = is_encrypted(SAMPLE_PEM)
        assert verdict is False

    def test_empty_string(self):
        """TEST: ENC-IE-03 -- is_encrypted returns False for empty string."""
        verdict = is_encrypted("")
        assert verdict is False

    def test_none_value(self):
        """TEST: ENC-IE-04 -- is_encrypted returns False for None."""
        verdict = is_encrypted(None)
        assert verdict is False
# ---------------------------------------------------------------------------
# Error handling
# ---------------------------------------------------------------------------
class TestErrorHandling:
    """Bad input and misconfiguration must raise CAKeyEncryptionError."""

    def test_encrypt_empty_raises(self):
        """TEST: ENC-ERR-01 -- Encrypting empty string raises CAKeyEncryptionError."""
        with pytest.raises(CAKeyEncryptionError, match="empty"):
            encrypt_ca_key("")

    def test_decrypt_empty_raises(self):
        """TEST: ENC-ERR-02 -- Decrypting empty string raises CAKeyEncryptionError."""
        with pytest.raises(CAKeyEncryptionError, match="empty"):
            decrypt_ca_key("")

    def test_missing_key_raises(self):
        """TEST: ENC-ERR-03 -- Operations fail when CA_ENCRYPTION_KEY is unset.

        Both directions are checked: encrypting plaintext, and decrypting a
        real ciphertext. A plaintext input would simply be returned as-is by
        decrypt_ca_key (ENC-IDEM-02), so it could not exercise this path.
        """
        # Produce the ciphertext while the key is still configured.
        encrypted = encrypt_ca_key(SAMPLE_PEM)
        # clear=True empties os.environ for the duration of the block (and
        # restores it on exit), so CA_ENCRYPTION_KEY is guaranteed absent.
        with patch.dict(os.environ, clear=True):
            with pytest.raises(CAKeyEncryptionError, match="not set"):
                encrypt_ca_key(SAMPLE_PEM)
            with pytest.raises(CAKeyEncryptionError):
                decrypt_ca_key(encrypted)

    def test_wrong_key_raises_on_decrypt(self):
        """TEST: ENC-ERR-04 -- Decrypting with the wrong key raises."""
        encrypted = encrypt_ca_key(SAMPLE_PEM)
        with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": "wrong-key"}):
            with pytest.raises(CAKeyEncryptionError, match="decryption failed"):
                decrypt_ca_key(encrypted)

    def test_corrupted_data_raises(self):
        """TEST: ENC-ERR-05 -- Decrypting corrupted ciphertext raises."""
        # Use the module's own prefix so the bogus value is routed to the
        # decrypt path rather than passed through as legacy plaintext.
        with pytest.raises(CAKeyEncryptionError):
            decrypt_ca_key(f"{_FERNET_PREFIX}not-a-real-token")
# ---------------------------------------------------------------------------
# reencrypt_ca_key -- key rotation
# ---------------------------------------------------------------------------
class TestReencrypt:
    """reencrypt_ca_key moves a stored CA key from one secret to another.

    decrypt_ca_key passes plaintext through unchanged (ENC-IDEM-02), so a
    successful decrypt alone cannot prove the output was really encrypted;
    each test therefore also asserts is_encrypted() on the result.
    """

    def test_reencrypt_round_trip(self):
        """TEST: ENC-RE-01 -- Re-encrypted value decrypts with the new key."""
        old_key = "old-secret-key"
        new_key = "new-secret-key"
        # Create the starting ciphertext through the normal encrypt path under
        # the old key, so reencrypt_ca_key is only exercised by the call below.
        with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": old_key}):
            encrypted = encrypt_ca_key(SAMPLE_PEM)

        reencrypted = reencrypt_ca_key(encrypted, old_key, new_key)
        assert is_encrypted(reencrypted) is True
        assert reencrypted != encrypted

        # The new key must be able to read the rotated value...
        with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": new_key}):
            assert decrypt_ca_key(reencrypted) == SAMPLE_PEM
        # ...and the old key must no longer be able to.
        with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": old_key}):
            with pytest.raises(CAKeyEncryptionError, match="decryption failed"):
                decrypt_ca_key(reencrypted)

    def test_reencrypt_plaintext_key(self):
        """TEST: ENC-RE-02 -- Re-encrypting a legacy plaintext key works."""
        new_key = "brand-new-key"
        reencrypted = reencrypt_ca_key(SAMPLE_PEM, "any-old-key", new_key)
        # Without this check a reencrypt that returned the plaintext unchanged
        # would still pass the decrypt assertion below.
        assert is_encrypted(reencrypted) is True
        with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": new_key}):
            assert decrypt_ca_key(reencrypted) == SAMPLE_PEM
# ---------------------------------------------------------------------------
# Thread safety
# ---------------------------------------------------------------------------
class TestThreadSafety:
    """Concurrent encrypt/decrypt calls must not corrupt state."""

    def test_concurrent_encrypt_decrypt(self):
        """TEST: ENC-TS-01 -- 50 threads encrypting/decrypting concurrently."""
        thread_count = 50
        errors = []
        results = []
        # Release all workers at once so their calls genuinely overlap,
        # instead of each short-lived thread finishing before the next starts.
        start_gate = threading.Barrier(thread_count)

        def worker(i):
            try:
                start_gate.wait(timeout=10)
                data = f"key-data-{i}"
                enc = encrypt_ca_key(data)
                dec = decrypt_ca_key(enc)
                results.append((i, dec))
            except Exception as exc:  # reported by the main thread below
                errors.append((i, exc))

        threads = [
            threading.Thread(target=worker, args=(i,)) for i in range(thread_count)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=10)

        # join(timeout=...) returns silently on timeout; report hung workers
        # explicitly rather than as an opaque result-count mismatch.
        hung = [t.name for t in threads if t.is_alive()]
        assert not hung, f"Threads did not finish: {hung}"
        assert not errors, f"Thread errors: {errors}"
        assert len(results) == thread_count
        for i, dec in results:
            assert dec == f"key-data-{i}", f"Thread {i}: expected 'key-data-{i}', got {dec!r}"