fix(cors): handle wildcard origin with credentials and add unit tests
- Refactor CORS middleware to echo request origin when wildcard + credentials is configured (browsers reject Access-Control-Allow-Origin: * with Access-Control-Allow-Credentials: true) - Add _is_origin_allowed() and _cors_origin_header() helpers - Use CORS_SUPPORTS_CREDENTIALS config consistently - Ensure consistent Access-Control-Allow-Headers in all CORS paths - Fix redirect validation in get_token() to allow wildcard CORS origins - Add 46 unit tests covering encryption round-trips, idempotency, key derivation, thread safety, CORS origin matching, and preflight responses
This commit is contained in:
@@ -0,0 +1,205 @@
|
||||
"""Unit tests for ca_key_encryption module.
|
||||
|
||||
WHAT: Tests for the Fernet-based CA private key encryption/decryption
|
||||
utility functions.
|
||||
WHY: CA private keys are the most sensitive data in the system; we need
|
||||
to verify round-trip correctness, idempotency, and error handling.
|
||||
EXPECTED: All encrypt/decrypt operations produce correct plaintext.
|
||||
"""
|
||||
import os
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gatehouse_app.utils.ca_key_encryption import (
|
||||
CAKeyEncryptionError,
|
||||
_FERNET_PREFIX,
|
||||
_get_fernet,
|
||||
decrypt_ca_key,
|
||||
encrypt_ca_key,
|
||||
is_encrypted,
|
||||
reencrypt_ca_key,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared fixture
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SAMPLE_PEM = (
|
||||
"-----BEGIN OPENSSH PRIVATE KEY-----\n"
|
||||
"b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtz\n"
|
||||
"c2gtZWQyNTUxOQAAACBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBAAA\n"
|
||||
"-----END OPENSSH PRIVATE KEY-----"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _set_ca_encryption_key():
|
||||
"""Ensure CA_ENCRYPTION_KEY is set for every test."""
|
||||
with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": "test-secret-key-for-unit-tests"}):
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# encrypt / decrypt round-trip
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEncryptDecryptRoundTrip:
|
||||
"""Verify that encrypt -> decrypt returns the original plaintext."""
|
||||
|
||||
def test_basic_round_trip(self):
|
||||
"""TEST: ENC-RT-01 -- Encrypt then decrypt returns original PEM."""
|
||||
encrypted = encrypt_ca_key(SAMPLE_PEM)
|
||||
decrypted = decrypt_ca_key(encrypted)
|
||||
assert decrypted == SAMPLE_PEM
|
||||
|
||||
def test_encrypted_value_has_prefix(self):
|
||||
"""TEST: ENC-RT-02 -- Encrypted output carries the $fernet$ envelope."""
|
||||
encrypted = encrypt_ca_key(SAMPLE_PEM)
|
||||
assert encrypted.startswith(_FERNET_PREFIX)
|
||||
|
||||
def test_different_ciphertext_each_time(self):
|
||||
"""TEST: ENC-RT-03 -- Two encryptions of the same plaintext differ."""
|
||||
enc1 = encrypt_ca_key(SAMPLE_PEM)
|
||||
enc2 = encrypt_ca_key(SAMPLE_PEM)
|
||||
assert enc1 != enc2
|
||||
assert decrypt_ca_key(enc1) == SAMPLE_PEM
|
||||
assert decrypt_ca_key(enc2) == SAMPLE_PEM
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Idempotency
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIdempotency:
|
||||
"""The module must not double-encrypt or double-decrypt."""
|
||||
|
||||
def test_encrypt_idempotent(self):
|
||||
"""TEST: ENC-IDEM-01 -- Encrypting an already-encrypted value is a no-op."""
|
||||
encrypted = encrypt_ca_key(SAMPLE_PEM)
|
||||
double = encrypt_ca_key(encrypted)
|
||||
assert double == encrypted
|
||||
|
||||
def test_decrypt_plaintext_passthrough(self):
|
||||
"""TEST: ENC-IDEM-02 -- Decrypting a plaintext (legacy) value returns it as-is."""
|
||||
result = decrypt_ca_key(SAMPLE_PEM)
|
||||
assert result == SAMPLE_PEM
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_encrypted helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsEncrypted:
|
||||
def test_encrypted_value(self):
|
||||
"""TEST: ENC-IE-01 -- is_encrypted returns True for $fernet$ values."""
|
||||
encrypted = encrypt_ca_key(SAMPLE_PEM)
|
||||
assert is_encrypted(encrypted) is True
|
||||
|
||||
def test_plaintext_value(self):
|
||||
"""TEST: ENC-IE-02 -- is_encrypted returns False for plain PEM."""
|
||||
assert is_encrypted(SAMPLE_PEM) is False
|
||||
|
||||
def test_empty_string(self):
|
||||
"""TEST: ENC-IE-03 -- is_encrypted returns False for empty string."""
|
||||
assert is_encrypted("") is False
|
||||
|
||||
def test_none_value(self):
|
||||
"""TEST: ENC-IE-04 -- is_encrypted returns False for None."""
|
||||
assert is_encrypted(None) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestErrorHandling:
|
||||
def test_encrypt_empty_raises(self):
|
||||
"""TEST: ENC-ERR-01 -- Encrypting empty string raises CAKeyEncryptionError."""
|
||||
with pytest.raises(CAKeyEncryptionError, match="empty"):
|
||||
encrypt_ca_key("")
|
||||
|
||||
def test_decrypt_empty_raises(self):
|
||||
"""TEST: ENC-ERR-02 -- Decrypting empty string raises CAKeyEncryptionError."""
|
||||
with pytest.raises(CAKeyEncryptionError, match="empty"):
|
||||
decrypt_ca_key("")
|
||||
|
||||
def test_missing_key_raises(self):
|
||||
"""TEST: ENC-ERR-03 -- Operations fail when CA_ENCRYPTION_KEY is unset."""
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
os.environ.pop("CA_ENCRYPTION_KEY", None)
|
||||
with pytest.raises(CAKeyEncryptionError, match="not set"):
|
||||
encrypt_ca_key(SAMPLE_PEM)
|
||||
|
||||
def test_wrong_key_raises_on_decrypt(self):
|
||||
"""TEST: ENC-ERR-04 -- Decrypting with the wrong key raises."""
|
||||
encrypted = encrypt_ca_key(SAMPLE_PEM)
|
||||
with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": "wrong-key"}):
|
||||
with pytest.raises(CAKeyEncryptionError, match="decryption failed"):
|
||||
decrypt_ca_key(encrypted)
|
||||
|
||||
def test_corrupted_data_raises(self):
|
||||
"""TEST: ENC-ERR-05 -- Decrypting corrupted ciphertext raises."""
|
||||
with pytest.raises(CAKeyEncryptionError):
|
||||
decrypt_ca_key("$fernet$not-a-real-token")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reencrypt_ca_key -- key rotation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReencrypt:
|
||||
def test_reencrypt_round_trip(self):
|
||||
"""TEST: ENC-RE-01 -- Re-encrypted value decrypts with the new key."""
|
||||
old_key = "old-secret-key"
|
||||
new_key = "new-secret-key"
|
||||
encrypted = reencrypt_ca_key(SAMPLE_PEM, "any-old-key", old_key)
|
||||
reencrypted = reencrypt_ca_key(encrypted, old_key, new_key)
|
||||
|
||||
# Verify it decrypts with the new key
|
||||
with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": new_key}):
|
||||
decrypted = decrypt_ca_key(reencrypted)
|
||||
assert decrypted == SAMPLE_PEM
|
||||
|
||||
def test_reencrypt_plaintext_key(self):
|
||||
"""TEST: ENC-RE-02 -- Re-encrypting a legacy plaintext key works."""
|
||||
new_key = "brand-new-key"
|
||||
reencrypted = reencrypt_ca_key(SAMPLE_PEM, "any-old-key", new_key)
|
||||
with patch.dict(os.environ, {"CA_ENCRYPTION_KEY": new_key}):
|
||||
decrypted = decrypt_ca_key(reencrypted)
|
||||
assert decrypted == SAMPLE_PEM
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thread safety
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestThreadSafety:
|
||||
"""Concurrent encrypt/decrypt calls must not corrupt state."""
|
||||
|
||||
def test_concurrent_encrypt_decrypt(self):
|
||||
"""TEST: ENC-TS-01 -- 50 threads encrypting/decrypting concurrently."""
|
||||
errors = []
|
||||
results = []
|
||||
|
||||
def worker(i):
|
||||
try:
|
||||
data = f"key-data-{i}"
|
||||
enc = encrypt_ca_key(data)
|
||||
dec = decrypt_ca_key(enc)
|
||||
results.append((i, dec))
|
||||
except Exception as exc:
|
||||
errors.append((i, exc))
|
||||
|
||||
threads = [threading.Thread(target=worker, args=(i,)) for i in range(50)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join(timeout=10)
|
||||
|
||||
assert not errors, f"Thread errors: {errors}"
|
||||
assert len(results) == 50
|
||||
for i, dec in results:
|
||||
assert dec == f"key-data-{i}", f"Thread {i}: expected 'key-data-{i}', got {dec!r}"
|
||||
Reference in New Issue
Block a user