206 lines
7.7 KiB
Python
206 lines
7.7 KiB
Python
|
|
"""Unit tests for the ``ca_key_encryption`` module.

WHAT: Tests for the Fernet-based CA private key encryption/decryption
utility functions.

WHY: CA private keys are the most sensitive data in the system, so we
verify round-trip correctness, idempotency, key rotation, thread safety,
and error handling.

EXPECTED: Every encrypt/decrypt round trip returns the original plaintext,
and invalid input, a missing key, or a wrong key raises CAKeyEncryptionError.
"""
|
||
|
|
import os
|
||
|
|
import threading
|
||
|
|
from unittest.mock import patch
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from gatehouse_app.utils.ca_key_encryption import (
|
||
|
|
CAKeyEncryptionError,
|
||
|
|
_FERNET_PREFIX,
|
||
|
|
_get_fernet,
|
||
|
|
decrypt_ca_key,
|
||
|
|
encrypt_ca_key,
|
||
|
|
is_encrypted,
|
||
|
|
reencrypt_ca_key,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Shared fixture
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
# OpenSSH-style private key text used as the plaintext throughout these tests.
# The base64 body looks like placeholder filler rather than a usable key; only
# the exact string value matters here.  Lines are newline-separated with no
# trailing newline after the END marker.
SAMPLE_PEM = "\n".join(
    [
        "-----BEGIN OPENSSH PRIVATE KEY-----",
        "b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtz",
        "c2gtZWQyNTUxOQAAACBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBAAA",
        "-----END OPENSSH PRIVATE KEY-----",
    ]
)
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture(autouse=True)
def _set_ca_encryption_key():
    """Provide a known CA_ENCRYPTION_KEY for the duration of each test.

    Being ``autouse``, this runs around every test in the module.  The value
    is injected with ``patch.dict`` so ``os.environ`` is restored to its prior
    contents once the test finishes.
    """
    test_env = {"CA_ENCRYPTION_KEY": "test-secret-key-for-unit-tests"}
    with patch.dict(os.environ, test_env):
        yield
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# encrypt / decrypt round-trip
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
class TestEncryptDecryptRoundTrip:
    """Verify that encrypt -> decrypt returns the original plaintext."""

    def test_basic_round_trip(self):
        """TEST: ENC-RT-01 -- Encrypt then decrypt returns original PEM."""
        token = encrypt_ca_key(SAMPLE_PEM)
        assert decrypt_ca_key(token) == SAMPLE_PEM

    def test_encrypted_value_has_prefix(self):
        """TEST: ENC-RT-02 -- Encrypted output carries the $fernet$ envelope."""
        token = encrypt_ca_key(SAMPLE_PEM)
        assert token.startswith(_FERNET_PREFIX)

    def test_different_ciphertext_each_time(self):
        """TEST: ENC-RT-03 -- Two encryptions of the same plaintext differ."""
        first = encrypt_ca_key(SAMPLE_PEM)
        second = encrypt_ca_key(SAMPLE_PEM)
        # Encryption must be non-deterministic (presumably via Fernet's
        # per-token random IV), yet every token still decrypts to the input.
        assert first != second
        for token in (first, second):
            assert decrypt_ca_key(token) == SAMPLE_PEM
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Idempotency
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
class TestIdempotency:
    """The module must not double-encrypt or double-decrypt."""

    def test_encrypt_idempotent(self):
        """TEST: ENC-IDEM-01 -- Encrypting an already-encrypted value is a no-op."""
        once = encrypt_ca_key(SAMPLE_PEM)
        twice = encrypt_ca_key(once)
        # An already-enveloped value must come back unchanged, not re-wrapped.
        assert twice == once

    def test_decrypt_plaintext_passthrough(self):
        """TEST: ENC-IDEM-02 -- Decrypting a plaintext (legacy) value returns it as-is."""
        # Legacy rows stored before encryption was introduced have no envelope.
        assert decrypt_ca_key(SAMPLE_PEM) == SAMPLE_PEM
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# is_encrypted helper
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
class TestIsEncrypted:
    """``is_encrypted`` must recognise enveloped values and reject everything else."""

    def test_encrypted_value(self):
        """TEST: ENC-IE-01 -- is_encrypted returns True for $fernet$ values."""
        token = encrypt_ca_key(SAMPLE_PEM)
        assert is_encrypted(token) is True

    def test_plaintext_value(self):
        """TEST: ENC-IE-02 -- is_encrypted returns False for plain PEM."""
        verdict = is_encrypted(SAMPLE_PEM)
        assert verdict is False

    def test_empty_string(self):
        """TEST: ENC-IE-03 -- is_encrypted returns False for empty string."""
        verdict = is_encrypted("")
        assert verdict is False

    def test_none_value(self):
        """TEST: ENC-IE-04 -- is_encrypted returns False for None."""
        # None must be tolerated and reported as "not encrypted", not raise.
        verdict = is_encrypted(None)
        assert verdict is False
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Error handling
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
class TestErrorHandling:
    """Invalid input and misconfiguration must surface as CAKeyEncryptionError."""

    def test_encrypt_empty_raises(self):
        """TEST: ENC-ERR-01 -- Encrypting empty string raises CAKeyEncryptionError."""
        with pytest.raises(CAKeyEncryptionError, match="empty"):
            encrypt_ca_key("")

    def test_decrypt_empty_raises(self):
        """TEST: ENC-ERR-02 -- Decrypting empty string raises CAKeyEncryptionError."""
        with pytest.raises(CAKeyEncryptionError, match="empty"):
            decrypt_ca_key("")

    def test_missing_key_raises(self):
        """TEST: ENC-ERR-03 -- Operations fail when CA_ENCRYPTION_KEY is unset."""
        # Produce a real ciphertext while the key is still configured so the
        # decrypt path can be exercised below.
        encrypted = encrypt_ca_key(SAMPLE_PEM)

        # patch.dict with no values snapshots os.environ and restores it on
        # exit; removing only the variable under test keeps the rest of the
        # environment intact (clear=True would wipe unrelated variables too).
        with patch.dict(os.environ):
            os.environ.pop("CA_ENCRYPTION_KEY", None)
            with pytest.raises(CAKeyEncryptionError, match="not set"):
                encrypt_ca_key(SAMPLE_PEM)
            # NOTE(review): the exact message on the decrypt path is not
            # pinned here, only the exception type.
            with pytest.raises(CAKeyEncryptionError):
                decrypt_ca_key(encrypted)

    def test_wrong_key_raises_on_decrypt(self):
        """TEST: ENC-ERR-04 -- Decrypting with the wrong key raises."""
        encrypted = encrypt_ca_key(SAMPLE_PEM)
        with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": "wrong-key"}):
            with pytest.raises(CAKeyEncryptionError, match="decryption failed"):
                decrypt_ca_key(encrypted)

    def test_corrupted_data_raises(self):
        """TEST: ENC-ERR-05 -- Decrypting corrupted ciphertext raises."""
        # Carries the envelope prefix, so it is treated as ciphertext, but the
        # payload is not a valid token.
        with pytest.raises(CAKeyEncryptionError):
            decrypt_ca_key("$fernet$not-a-real-token")
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# reencrypt_ca_key -- key rotation
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
class TestReencrypt:
    """Key rotation: re-encrypted values must be readable only with the new key."""

    def test_reencrypt_round_trip(self):
        """TEST: ENC-RE-01 -- Re-encrypted value decrypts with the new key only."""
        old_key = "old-secret-key"
        new_key = "new-secret-key"

        # Seed: wrap the plaintext PEM under old_key.  The "from" key argument
        # is presumably ignored for plaintext input (cf. ENC-RE-02).
        encrypted = reencrypt_ca_key(SAMPLE_PEM, "any-old-key", old_key)
        with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": old_key}):
            assert decrypt_ca_key(encrypted) == SAMPLE_PEM

        reencrypted = reencrypt_ca_key(encrypted, old_key, new_key)
        assert is_encrypted(reencrypted) is True

        # The rotated value decrypts with the new key ...
        with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": new_key}):
            assert decrypt_ca_key(reencrypted) == SAMPLE_PEM

        # ... and the old key no longer works; otherwise a reencrypt that
        # returned its input unchanged would pass this test.
        with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": old_key}):
            with pytest.raises(CAKeyEncryptionError, match="decryption failed"):
                decrypt_ca_key(reencrypted)

    def test_reencrypt_plaintext_key(self):
        """TEST: ENC-RE-02 -- Re-encrypting a legacy plaintext key works."""
        new_key = "brand-new-key"
        reencrypted = reencrypt_ca_key(SAMPLE_PEM, "any-old-key", new_key)
        assert is_encrypted(reencrypted) is True
        with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": new_key}):
            decrypted = decrypt_ca_key(reencrypted)
        assert decrypted == SAMPLE_PEM
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Thread safety
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
class TestThreadSafety:
    """Concurrent encrypt/decrypt calls must not corrupt state."""

    def test_concurrent_encrypt_decrypt(self):
        """TEST: ENC-TS-01 -- 50 threads encrypting/decrypting concurrently."""
        num_threads = 50
        # Hold every worker until all have started so the encrypt/decrypt
        # calls genuinely overlap; short-lived threads started in a loop
        # otherwise tend to run one after another.
        start_barrier = threading.Barrier(num_threads)
        errors = []
        results = []

        def worker(i):
            try:
                start_barrier.wait(timeout=10)
                data = f"key-data-{i}"
                enc = encrypt_ca_key(data)
                dec = decrypt_ca_key(enc)
                results.append((i, dec))
            except Exception as exc:  # reported by the main thread below
                errors.append((i, exc))

        # daemon=True so a hung worker cannot block interpreter shutdown.
        threads = [
            threading.Thread(target=worker, args=(i,), daemon=True)
            for i in range(num_threads)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=10)

        # join(timeout=...) returns silently on timeout, so check explicitly
        # that every worker finished instead of hanging or deadlocking.
        hung = [t.name for t in threads if t.is_alive()]
        assert not hung, f"Threads did not finish: {hung}"
        assert not errors, f"Thread errors: {errors}"
        assert len(results) == num_threads
        assert sorted(i for i, _ in results) == list(range(num_threads))
        for i, dec in results:
            assert dec == f"key-data-{i}", f"Thread {i}: expected 'key-data-{i}', got {dec!r}"
|