"""Unit tests for the encryption module (general-purpose Fernet encryption).

WHAT: Tests for the PBKDF2-based Fernet encryption/decryption used for OAuth
tokens and client secrets.

WHY: These utilities protect access tokens and client secrets; we need to
verify round-trip correctness and error handling.

EXPECTED: All encrypt/decrypt operations produce the correct plaintext.
"""

import base64
import threading

import pytest

from gatehouse_app.utils.encryption import (
    SALT_LENGTH,
    _get_fernet_key,
    decrypt,
    encrypt,
)

# ---------------------------------------------------------------------------
# Shared test constants
# ---------------------------------------------------------------------------

SECRET_KEY = "test-encryption-secret-key"
SAMPLE_DATA = "access_token=eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.payload"


# ---------------------------------------------------------------------------
# encrypt / decrypt round-trip
# ---------------------------------------------------------------------------


class TestEncryptDecryptRoundTrip:
    """Verify that encrypt -> decrypt returns the original plaintext."""

    def test_basic_round_trip(self):
        """TEST: ENC-RT-01 -- Encrypt then decrypt returns original data."""
        encrypted = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
        decrypted = decrypt(encrypted, secret_key=SECRET_KEY)
        assert decrypted == SAMPLE_DATA

    def test_encrypted_is_base64(self):
        """TEST: ENC-RT-02 -- Encrypted output is valid URL-safe base64."""
        encrypted = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
        # urlsafe_b64decode() silently discards characters outside the base64
        # alphabet, so it accepts many strings that are not base64 at all.
        # validate=True makes any stray character raise binascii.Error.
        raw = base64.b64decode(encrypted, altchars=b"-_", validate=True)
        assert raw

    def test_different_ciphertext_each_time(self):
        """TEST: ENC-RT-03 -- Two encryptions of the same plaintext differ."""
        enc1 = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
        enc2 = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
        assert enc1 != enc2
        # Both distinct ciphertexts must still decrypt to the same plaintext.
        assert decrypt(enc1, secret_key=SECRET_KEY) == SAMPLE_DATA
        assert decrypt(enc2, secret_key=SECRET_KEY) == SAMPLE_DATA

    def test_round_trip_unicode(self):
        """TEST: ENC-RT-04 -- Unicode data round-trips correctly."""
        data = "token=cafe\u00e9\u00f1\u00fc"
        encrypted = encrypt(data, secret_key=SECRET_KEY)
        assert decrypt(encrypted, secret_key=SECRET_KEY) == data

    def test_round_trip_long_data(self):
        """TEST: ENC-RT-05 -- Large data round-trips correctly."""
        data = "x" * 10000
        encrypted = encrypt(data, secret_key=SECRET_KEY)
        assert decrypt(encrypted, secret_key=SECRET_KEY) == data


# ---------------------------------------------------------------------------
# Empty / edge inputs
# ---------------------------------------------------------------------------


class TestEdgeCases:
    """Empty inputs, missing or wrong keys, and malformed ciphertext."""

    def test_encrypt_empty_returns_empty(self):
        """TEST: ENC-EDGE-01 -- Encrypting empty string returns empty."""
        assert encrypt("", secret_key=SECRET_KEY) == ""

    def test_decrypt_empty_returns_empty(self):
        """TEST: ENC-EDGE-02 -- Decrypting empty string returns empty."""
        assert decrypt("", secret_key=SECRET_KEY) == ""

    def test_missing_key_raises_on_encrypt(self):
        """TEST: ENC-EDGE-03 -- Missing key raises ValueError on encrypt."""
        with pytest.raises(ValueError, match="Encryption key not configured"):
            encrypt("data", secret_key="")

    def test_missing_key_raises_on_decrypt(self):
        """TEST: ENC-EDGE-04 -- Missing key raises ValueError on decrypt."""
        with pytest.raises(ValueError, match="Encryption key not configured"):
            decrypt("something", secret_key="")

    def test_wrong_key_raises_on_decrypt(self):
        """TEST: ENC-EDGE-05 -- Wrong key raises ValueError on decrypt."""
        encrypted = encrypt(SAMPLE_DATA, secret_key=SECRET_KEY)
        with pytest.raises(ValueError, match="Failed to decrypt"):
            decrypt(encrypted, secret_key="wrong-key")

    def test_corrupted_data_raises(self):
        """TEST: ENC-EDGE-06 -- Corrupted ciphertext raises ValueError."""
        # Well-formed base64 whose payload is not a valid token.
        bad = base64.urlsafe_b64encode(b"not-valid-fernet-data").decode()
        with pytest.raises(ValueError):
            decrypt(bad, secret_key=SECRET_KEY)


# ---------------------------------------------------------------------------
# _get_fernet_key — PBKDF2 derivation
# ---------------------------------------------------------------------------


class TestKeyDerivation:
    """Deterministic behaviour of the PBKDF2-based key derivation helper."""

    def test_same_salt_same_key(self):
        """TEST: ENC-KD-01 -- Same salt produces the same derived key."""
        salt = b"\x00" * SALT_LENGTH
        key1 = _get_fernet_key(SECRET_KEY, salt=salt)
        key2 = _get_fernet_key(SECRET_KEY, salt=salt)
        assert key1 == key2

    def test_different_salt_different_key(self):
        """TEST: ENC-KD-02 -- Different salts produce different keys."""
        salt1 = b"\x00" * SALT_LENGTH
        salt2 = b"\xff" * SALT_LENGTH
        key1 = _get_fernet_key(SECRET_KEY, salt=salt1)
        key2 = _get_fernet_key(SECRET_KEY, salt=salt2)
        assert key1 != key2

    def test_auto_salt_length(self):
        """TEST: ENC-KD-03 -- Key derivation succeeds with an auto-generated salt."""
        # NOTE(review): this only checks that derivation without an explicit
        # salt succeeds and yields a non-empty result; it does not inspect the
        # generated salt's length (SALT_LENGTH). Strengthen once the salt is
        # observable from this call.
        key = _get_fernet_key(SECRET_KEY)
        assert len(key) > 0


# ---------------------------------------------------------------------------
# Thread safety
# ---------------------------------------------------------------------------


class TestThreadSafety:
    """Concurrent encrypt/decrypt calls must not corrupt state."""

    def test_concurrent_encrypt_decrypt(self):
        """TEST: ENC-TS-01 -- 50 threads encrypting/decrypting concurrently."""
        n_threads = 50
        errors = []
        results = []
        # Hold every worker at the barrier so the calls genuinely overlap
        # instead of possibly running one after another.
        start_gate = threading.Barrier(n_threads)

        def worker(i):
            try:
                start_gate.wait(timeout=10)
                data = f"token-{i}-secret"
                enc = encrypt(data, secret_key=SECRET_KEY)
                dec = decrypt(enc, secret_key=SECRET_KEY)
                results.append((i, dec))
            except Exception as exc:
                errors.append((i, exc))

        threads = [
            threading.Thread(target=worker, args=(i,)) for i in range(n_threads)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=10)

        # join() returns silently on timeout; surface hung workers explicitly
        # rather than failing later with a confusing result-count mismatch.
        hung = [t.name for t in threads if t.is_alive()]
        assert not hung, f"Threads did not finish in time: {hung}"
        assert not errors, f"Thread errors: {errors}"
        assert len(results) == n_threads
        for i, dec in results:
            assert dec == f"token-{i}-secret", f"Thread {i}: mismatch"