60799bbc52
- Refactor CORS middleware to echo request origin when wildcard + credentials is configured (browsers reject Access-Control-Allow-Origin: * with Access-Control-Allow-Credentials: true) - Add _is_origin_allowed() and _cors_origin_header() helpers - Use CORS_SUPPORTS_CREDENTIALS config consistently - Ensure consistent Access-Control-Allow-Headers in all CORS paths - Fix redirect validation in get_token() to allow wildcard CORS origins - Add 46 unit tests covering encryption round-trips, idempotency, key derivation, thread safety, CORS origin matching, and preflight responses
165 lines
6.4 KiB
Python
165 lines
6.4 KiB
Python
"""Unit tests for encryption module (general-purpose Fernet encryption).
|
|
|
|
WHAT: Tests for the PBKDF2-based Fernet encryption/decryption used for
|
|
OAuth tokens and client secrets.
|
|
WHY: These utilities protect access tokens and client secrets; we need
|
|
to verify round-trip correctness and error handling.
|
|
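EXPECTED: Round-trips return the original plaintext; missing/wrong keys and corrupted ciphertext raise ValueError; key derivation is deterministic per salt; concurrent use returns correct results.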
|
|
"""
|
|
import threading
|
|
|
|
import pytest
|
|
|
|
from gatehouse_app.utils.encryption import (
|
|
SALT_LENGTH,
|
|
_get_fernet_key,
|
|
decrypt,
|
|
encrypt,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Shared fixture
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Secret passed explicitly (via ``secret_key=``) to every encrypt/decrypt call below.
SECRET_KEY: str = "test-encryption-secret-key"
# Token-shaped plaintext (JWT-style header segment) used as the default payload.
SAMPLE_DATA: str = (
    "access_token="
    "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.payload"
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# encrypt / decrypt round-trip
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestEncryptDecryptRoundTrip:
    """Verify that encrypt -> decrypt returns the original plaintext."""

    def test_basic_round_trip(self):
        """TEST: ENC-RT-01 -- Encrypt then decrypt returns original data."""
        encrypted = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
        decrypted = decrypt(encrypted, secret_key=SECRET_KEY)
        assert decrypted == SAMPLE_DATA

    def test_encrypted_is_base64(self):
        """TEST: ENC-RT-02 -- Encrypted output is valid base64.

        ``base64.urlsafe_b64decode`` silently discards characters outside the
        base64 alphabet, so on its own it accepts almost any string. Decode
        strictly instead: ``altchars`` maps the URL-safe ``-``/``_`` characters
        and ``validate=True`` makes any other stray character raise
        ``binascii.Error``.
        """
        import base64

        encrypted = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
        raw = base64.b64decode(encrypted, altchars=b"-_", validate=True)
        assert raw, "decoded ciphertext should not be empty"

    def test_different_ciphertext_each_time(self):
        """TEST: ENC-RT-03 -- Two encryptions of the same plaintext differ."""
        enc1 = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
        enc2 = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
        assert enc1 != enc2
        assert decrypt(enc1, secret_key=SECRET_KEY) == SAMPLE_DATA
        assert decrypt(enc2, secret_key=SECRET_KEY) == SAMPLE_DATA

    def test_round_trip_unicode(self):
        """TEST: ENC-RT-04 -- Unicode data round-trips correctly.

        Includes an astral-plane character (U+1F511) so 4-byte UTF-8
        sequences are exercised alongside the 2-byte Latin-1 ones.
        """
        data = "token=cafe\u00e9\u00f1\u00fc\U0001f511"
        encrypted = encrypt(data, secret_key=SECRET_KEY)
        assert decrypt(encrypted, secret_key=SECRET_KEY) == data

    def test_round_trip_long_data(self):
        """TEST: ENC-RT-05 -- Large data round-trips correctly."""
        data = "x" * 10000
        encrypted = encrypt(data, secret_key=SECRET_KEY)
        assert decrypt(encrypted, secret_key=SECRET_KEY) == data
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Empty / edge inputs
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestEdgeCases:
    """Empty inputs, missing or wrong keys, and malformed ciphertext."""

    def test_encrypt_empty_returns_empty(self):
        """TEST: ENC-EDGE-01 -- An empty plaintext encrypts to an empty string."""
        result = encrypt("", secret_key=SECRET_KEY)
        assert result == ""

    def test_decrypt_empty_returns_empty(self):
        """TEST: ENC-EDGE-02 -- An empty ciphertext decrypts to an empty string."""
        result = decrypt("", secret_key=SECRET_KEY)
        assert result == ""

    def test_missing_key_raises_on_encrypt(self):
        """TEST: ENC-EDGE-03 -- encrypt() with a blank key raises ValueError."""
        blank_key = ""
        with pytest.raises(ValueError, match="Encryption key not configured"):
            encrypt("data", secret_key=blank_key)

    def test_missing_key_raises_on_decrypt(self):
        """TEST: ENC-EDGE-04 -- decrypt() with a blank key raises ValueError."""
        blank_key = ""
        with pytest.raises(ValueError, match="Encryption key not configured"):
            decrypt("something", secret_key=blank_key)

    def test_wrong_key_raises_on_decrypt(self):
        """TEST: ENC-EDGE-05 -- Decrypting under a different key raises ValueError."""
        ciphertext = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
        with pytest.raises(ValueError, match="Failed to decrypt"):
            decrypt(ciphertext, secret_key="wrong-key")

    def test_corrupted_data_raises(self):
        """TEST: ENC-EDGE-06 -- Well-formed base64 that is not a token raises ValueError."""
        import base64

        garbage = base64.urlsafe_b64encode(b"not-valid-fernet-data").decode()
        with pytest.raises(ValueError):
            decrypt(garbage, secret_key=SECRET_KEY)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _get_fernet_key — PBKDF2 derivation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestKeyDerivation:
    """Key derivation via ``_get_fernet_key`` (PBKDF2 per the module docstring)."""

    def test_same_salt_same_key(self):
        """TEST: ENC-KD-01 -- Same salt produces the same derived key."""
        salt = b"\x00" * SALT_LENGTH
        key1 = _get_fernet_key(SECRET_KEY, salt=salt)
        key2 = _get_fernet_key(SECRET_KEY, salt=salt)
        assert key1 == key2

    def test_different_salt_different_key(self):
        """TEST: ENC-KD-02 -- Different salts produce different keys."""
        salt1 = b"\x00" * SALT_LENGTH
        salt2 = b"\xff" * SALT_LENGTH
        key1 = _get_fernet_key(SECRET_KEY, salt=salt1)
        key2 = _get_fernet_key(SECRET_KEY, salt=salt2)
        assert key1 != key2

    def test_auto_salt_length(self):
        """TEST: ENC-KD-03 -- Auto-generated salt is 16 bytes.

        The generated salt is not observable from this test, so the length
        is pinned through the module's ``SALT_LENGTH`` constant. The test also
        checks that derivation without an explicit salt yields a non-empty
        key. (Previously only the second check existed, so the documented
        16-byte length was never asserted.)
        """
        assert SALT_LENGTH == 16
        key = _get_fernet_key(SECRET_KEY)
        assert len(key) > 0

    def test_different_secret_different_key(self):
        """TEST: ENC-KD-04 -- Different secrets with the same salt produce different keys."""
        salt = b"\x00" * SALT_LENGTH
        key1 = _get_fernet_key(SECRET_KEY, salt=salt)
        key2 = _get_fernet_key("another-secret-key", salt=salt)
        assert key1 != key2
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Thread safety
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestThreadSafety:
    """Concurrent encrypt/decrypt calls must not corrupt state."""

    def test_concurrent_encrypt_decrypt(self):
        """TEST: ENC-TS-01 -- 50 threads encrypting/decrypting concurrently."""
        num_threads = 50
        # Release every worker at the same moment so the calls genuinely
        # overlap; otherwise early threads may finish before later ones start.
        start_gate = threading.Barrier(num_threads)
        errors = []
        results = []

        def worker(i):
            try:
                start_gate.wait(timeout=10)
                data = f"token-{i}-secret"
                enc = encrypt(data, secret_key=SECRET_KEY)
                dec = decrypt(enc, secret_key=SECRET_KEY)
                results.append((i, dec))
            except Exception as exc:  # collected here, asserted on in the main thread
                errors.append((i, exc))

        # daemon=True so a hung worker cannot keep the test process alive.
        threads = [
            threading.Thread(target=worker, args=(i,), daemon=True)
            for i in range(num_threads)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=10)

        # join(timeout=...) returns silently on timeout; surface hung workers
        # explicitly instead of via an opaque result-count mismatch.
        hung = [t.name for t in threads if t.is_alive()]
        assert not hung, f"Threads did not finish within timeout: {hung}"
        assert not errors, f"Thread errors: {errors}"
        assert sorted(i for i, _ in results) == list(range(num_threads))
        for i, dec in results:
            assert dec == f"token-{i}-secret", f"Thread {i}: mismatch"
|