# 3-AES - zer0pts CTF 2021

###### tags: `zer0pts CTF 2021` `crypto`

## short answer

I found [this paper](https://iacr.org/archive/asiacrypt2001/22480210.pdf) after I created the challenge.

## overview

There is a 3-stage AES encryption: AES_ECB + AES_CBC + AES_CFB. We can encrypt/decrypt any plaintext/ciphertext as many times as we like, and at the end we can retrieve the encrypted flag.

```python=
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from binascii import hexlify, unhexlify
from hashlib import md5
import os
import signal
from flag import flag

keys = [md5(os.urandom(3)).digest() for _ in range(3)]

def get_ciphers(iv1, iv2):
    return [
        AES.new(keys[0], mode=AES.MODE_ECB),
        AES.new(keys[1], mode=AES.MODE_CBC, iv=iv1),
        AES.new(keys[2], mode=AES.MODE_CFB, iv=iv2, segment_size=8*16),
    ]

def encrypt(m: bytes, iv1: bytes, iv2: bytes) -> bytes:
    assert len(m) % 16 == 0
    ciphers = get_ciphers(iv1, iv2)
    c = m
    for cipher in ciphers:
        c = cipher.encrypt(c)
    return c

def decrypt(c: bytes, iv1: bytes, iv2: bytes) -> bytes:
    assert len(c) % 16 == 0
    ciphers = get_ciphers(iv1, iv2)
    m = c
    for cipher in ciphers[::-1]:
        m = cipher.decrypt(m)
    return m

signal.alarm(3600)
while True:
    print("==== MENU ====")
    print("1. Encrypt your plaintext")
    print("2. Decrypt your ciphertext")
    print("3. Get encrypted flag")
    choice = int(input("> "))

    if choice == 1:
        plaintext = unhexlify(input("your plaintext(hex): "))
        iv1, iv2 = get_random_bytes(16), get_random_bytes(16)
        ciphertext = encrypt(plaintext, iv1, iv2)
        ciphertext = b":".join([hexlify(x) for x in [iv1, iv2, ciphertext]]).decode()
        print("here's the ciphertext: {}".format(ciphertext))

    elif choice == 2:
        ciphertext = input("your ciphertext: ")
        iv1, iv2, ciphertext = [unhexlify(x) for x in ciphertext.strip().split(":")]
        plaintext = decrypt(ciphertext, iv1, iv2)
        print("here's the plaintext(hex): {}".format(hexlify(plaintext).decode()))

    elif choice == 3:
        plaintext = flag
        iv1, iv2 = get_random_bytes(16), get_random_bytes(16)
        ciphertext = encrypt(plaintext, iv1, iv2)
        ciphertext = b":".join([hexlify(x) for x in [iv1, iv2, ciphertext]]).decode()
        print("here's the encrypted flag: {}".format(ciphertext))
        exit()

    else:
        exit()
```

## observation

One of the weak points is the key length. Each stage's key is derived from only 3 random bytes (`md5(os.urandom(3)).digest()`), so the three keys together come from just 9 bytes. We can't brute-force 9 bytes, but brute-forcing 3 bytes is feasible. This reminds us of the meet-in-the-middle (MITM) attack on 2-stage encryption. Can we somehow apply MITM?

Let's write the encryption as formulas. We write $E_1, E_2, E_3$ for the AES encryption of each stage, and $IV_2, IV_3$ for the initial vectors of the CBC and CFB modes. These are called `iv1` and `iv2` in the server code. We write $m_1, m_2, \dots$ for the plaintext blocks and $c_1, c_2, \dots$ for the corresponding ciphertext blocks. Then the encryption is:

$c_1 = E_2(E_1(m_1) \oplus IV_2) \oplus E_3(IV_3)$

$c_2 = E_2(E_1(m_2) \oplus E_2(E_1(m_1) \oplus IV_2)) \oplus E_3(c_1)$

Transforming these, we can derive several formulas. For example,

$c_1 \oplus E_3(IV_3) = E_2(E_1(m_1) \oplus IV_2)$

$E_2(E_1(m_1) \oplus IV_2) = D_2(c_2 \oplus E_3(c_1)) \oplus E_1(m_2)$

where $D_i$ is the decryption corresponding to $E_i$.
Then

$c_1 \oplus E_3(IV_3) = E_2(E_1(m_1) \oplus IV_2) = D_2(c_2 \oplus E_3(c_1)) \oplus E_1(m_2)$

Now suppose $c_1 = c_2 = IV_3$, and call this common value $X$. Then the following holds:

$D_2(c_2 \oplus E_3(c_1)) \oplus E_1(m_2) = D_2(c_1 \oplus E_3(IV_3)) \oplus E_1(m_2)$

$= D_2(E_2(E_1(m_1) \oplus IV_2)) \oplus E_1(m_2) = E_1(m_1) \oplus IV_2 \oplus E_1(m_2)$,

so

$c_1 \oplus E_3(IV_3) = E_1(m_1) \oplus IV_2 \oplus E_1(m_2)$ $\cdots (\triangle)$

The formula $(\triangle)$ uses only $E_1$ and $E_3$, not $E_2$. Substituting $c_1 = IV_3 = X$ and rearranging gives

$IV_2 \oplus X \oplus E_3(X) = E_1(m_1) \oplus E_1(m_2)$.

The left side depends only on $k_3$ and the right side only on $k_1$, so a MITM attack is now feasible.

## solution

As described above, we can do a MITM attack on $E_1$ and $E_3$, which gives us $k_1$ and $k_3$. Finally, we simply brute-force $k_2$ to get all the keys.

So what we have to do is simple:

1. decrypt a ciphertext crafted to satisfy $c_1 = c_2 = IV_3$, and get a plaintext-ciphertext pair
2. find $k_1, k_3$ by MITM, then $k_2$ by brute force
3. get the encrypted flag and decrypt it with $k_1, k_2, k_3$

## exploit

the main script

```python=
from ptrlib import Socket
from subprocess import run, PIPE
from binascii import hexlify, unhexlify
from Crypto.Cipher import AES

sock = Socket("localhost", 9999)

# Ask the server to decrypt a crafted ciphertext with c1 = c2 = IV3.
sock.sendlineafter("> ", "2")
iv2 = "A" * 32  # CBC IV (the server's iv1)
iv3 = "A" * 32  # CFB IV (the server's iv2)
a = "A" * 32    # c1 == IV3
b = "A" * 32    # c2 == IV3
c = "A" * 32    # extra block; the attack only uses m1 and m2
sock.sendlineafter("ciphertext: ", "{}:{}:{}".format(iv2, iv3, a+b+c))
plaintext = unhexlify(sock.recvlineafter(": "))
A = hexlify(plaintext[:16]).decode()    # m1
B = hexlify(plaintext[16:32]).decode()  # m2

sock.sendlineafter("> ", "3")
flag = sock.recvlineafter("flag: ")
sock.close()

# --- recover the keys
# check=True: if a search fails (non-zero exit), raise here instead of
# failing later on an empty stdout that cannot be unpacked.
r = run(["./k1k3", A, B, iv2, iv3], stdout=PIPE, check=True)
k1, k3 = r.stdout.decode().strip().split("\n")
print("k1={}".format(k1))
print("k3={}".format(k3))

r = run(["./k2", A, B, iv2, iv3, k1, k3], stdout=PIPE, check=True)
k2 = r.stdout.decode().strip()
print("k2={}".format(k2))

# --- decrypt the flag
keys = [
    unhexlify(k1),
    unhexlify(k2),
    unhexlify(k3),
]

def get_ciphers(iv1, iv2):
    return [
        AES.new(keys[0], mode=AES.MODE_ECB),
        AES.new(keys[1], mode=AES.MODE_CBC, iv=iv1),
        AES.new(keys[2], mode=AES.MODE_CFB, iv=iv2, segment_size=8*16),
    ]

def decrypt(c: bytes, iv1: bytes, iv2: bytes) -> bytes:
    assert len(c) % 16 == 0
    ciphers = get_ciphers(iv1, iv2)
    m = c
    for cipher in ciphers[::-1]:
        m = cipher.decrypt(m)
    return m

iv1, iv2, ciphertext = [unhexlify(x) for x in flag.decode().strip().split(":")]
plaintext = decrypt(ciphertext, iv1, iv2)
print(plaintext)
```

To find the keys as fast as possible, I implemented the search in C++. Each AES key is `md5(seed)` of a 3-byte seed, not the seed bytes themselves. The programs therefore enumerate all $2^{24}$ seeds, hash each one, and print the resulting 16-byte key in hex.

finding $k_1$ and $k_3$ (MITM)

```clike=
#include <openssl/evp.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <unordered_map>

// Number of possible 3-byte seeds fed to md5 by the server.
static const uint32_t kSeeds = 1u << 24;

// Rebuild one stage key exactly as the server does: md5(os.urandom(3)).digest().
static std::string derive_key(uint32_t seed) {
    const unsigned char s[3] = {
        static_cast<unsigned char>(seed >> 16),
        static_cast<unsigned char>(seed >> 8),
        static_cast<unsigned char>(seed),
    };
    unsigned char md[EVP_MAX_MD_SIZE];
    unsigned int md_len = 0;
    EVP_Digest(s, sizeof s, md, &md_len, EVP_md5(), NULL);
    return std::string(reinterpret_cast<const char *>(md), md_len);
}

// AES-128-ECB-encrypt `data` (whole 16-byte blocks) under the 16-byte `key`.
// The result carries an explicit length because ciphertext may contain NUL bytes.
static std::string encrypt(const std::string &key, const std::string &data) {
    std::string out(data.size(), '\0');
    int len = 0;
    EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
    EVP_EncryptInit_ex(ctx, EVP_aes_128_ecb(), NULL,
                       reinterpret_cast<const unsigned char *>(key.data()), NULL);
    EVP_CIPHER_CTX_set_padding(ctx, 0);  // inputs are whole blocks; never pad
    EVP_EncryptUpdate(ctx, reinterpret_cast<unsigned char *>(&out[0]), &len,
                      reinterpret_cast<const unsigned char *>(data.data()),
                      static_cast<int>(data.size()));
    EVP_CIPHER_CTX_free(ctx);
    return out;
}

// Value of one hex digit; exits on anything else (including a terminating NUL).
static int fromHexChar(char c) {
    if ('0' <= c && c <= '9') { return c - '0'; }
    if ('a' <= c && c <= 'f') { return c - 'a' + 10; }
    if ('A' <= c && c <= 'F') { return c - 'A' + 10; }
    exit(EXIT_FAILURE);
}

// Decode a hex string; an odd length exits via fromHexChar('\0').
static std::string fromhex(const char *source) {
    std::string s;
    while (*source) {
        int hi = fromHexChar(source[0]);
        int lo = fromHexChar(source[1]);
        s.push_back(static_cast<char>((hi << 4) | lo));
        source += 2;
    }
    return s;
}

// Encode bytes as lowercase hex.
static std::string tohex(const std::string &src) {
    static const char table[] = "0123456789abcdef";
    std::string s;
    for (unsigned char ch : src) {
        s.push_back(table[ch >> 4]);
        s.push_back(table[ch & 0xf]);
    }
    return s;
}

// XOR two byte strings of equal length.
static std::string x(const std::string &a, const std::string &b) {
    if (a.size() != b.size()) {
        fprintf(stderr, "xor length mismatch\n");
        exit(EXIT_FAILURE);
    }
    std::string c(a.size(), '\0');
    for (size_t i = 0; i < a.size(); i++) {
        c[i] = a[i] ^ b[i];
    }
    return c;
}

// usage: k1k3 <m1> <m2> <iv2> <iv3>   (all hex, 16 bytes each)
// Prints hex(k1) and hex(k3) on two lines; exits 1 if nothing matches.
int main(int argc, char **argv) {
    if (argc != 5) {
        fprintf(stderr, "usage: %s m1 m2 iv2 iv3\n", argv[0]);
        return EXIT_FAILURE;
    }
    const std::string a = fromhex(argv[1]);
    const std::string b = fromhex(argv[2]);
    const std::string iv2 = fromhex(argv[3]);
    const std::string iv3 = fromhex(argv[4]);

    // Table side: IV2 ^ IV3 ^ E3(IV3) -> k3 seed (storing the seed keeps the table small).
    std::unordered_map<std::string, uint32_t> table;
    table.reserve(kSeeds);
    const std::string left = x(iv2, iv3);
    for (uint32_t s = 0; s < kSeeds; s++) {
        table.emplace(x(left, encrypt(derive_key(s), iv3)), s);
    }

    // Search side: E1(m1) ^ E1(m2) must land in the table.
    for (uint32_t s = 0; s < kSeeds; s++) {
        const std::string k1 = derive_key(s);
        auto it = table.find(x(encrypt(k1, a), encrypt(k1, b)));
        if (it != table.end()) {
            printf("%s\n", tohex(k1).c_str());
            printf("%s\n", tohex(derive_key(it->second)).c_str());
            return 0;
        }
    }
    return 1;
}
```

finding $k_2$

```clike=
#include <openssl/evp.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <string>

// Number of possible 3-byte seeds fed to md5 by the server.
static const uint32_t kSeeds = 1u << 24;

// Rebuild one stage key exactly as the server does: md5(os.urandom(3)).digest().
static std::string derive_key(uint32_t seed) {
    const unsigned char s[3] = {
        static_cast<unsigned char>(seed >> 16),
        static_cast<unsigned char>(seed >> 8),
        static_cast<unsigned char>(seed),
    };
    unsigned char md[EVP_MAX_MD_SIZE];
    unsigned int md_len = 0;
    EVP_Digest(s, sizeof s, md, &md_len, EVP_md5(), NULL);
    return std::string(reinterpret_cast<const char *>(md), md_len);
}

// AES-128-ECB-encrypt `data` (whole 16-byte blocks) under the 16-byte `key`.
// The result carries an explicit length because ciphertext may contain NUL bytes.
static std::string encrypt(const std::string &key, const std::string &data) {
    std::string out(data.size(), '\0');
    int len = 0;
    EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
    EVP_EncryptInit_ex(ctx, EVP_aes_128_ecb(), NULL,
                       reinterpret_cast<const unsigned char *>(key.data()), NULL);
    EVP_CIPHER_CTX_set_padding(ctx, 0);  // inputs are whole blocks; never pad
    EVP_EncryptUpdate(ctx, reinterpret_cast<unsigned char *>(&out[0]), &len,
                      reinterpret_cast<const unsigned char *>(data.data()),
                      static_cast<int>(data.size()));
    EVP_CIPHER_CTX_free(ctx);
    return out;
}

// Value of one hex digit; exits on anything else (including a terminating NUL).
static int fromHexChar(char c) {
    if ('0' <= c && c <= '9') { return c - '0'; }
    if ('a' <= c && c <= 'f') { return c - 'a' + 10; }
    if ('A' <= c && c <= 'F') { return c - 'A' + 10; }
    exit(EXIT_FAILURE);
}

// Decode a hex string; an odd length exits via fromHexChar('\0').
static std::string fromhex(const char *source) {
    std::string s;
    while (*source) {
        int hi = fromHexChar(source[0]);
        int lo = fromHexChar(source[1]);
        s.push_back(static_cast<char>((hi << 4) | lo));
        source += 2;
    }
    return s;
}

// Encode bytes as lowercase hex.
static std::string tohex(const std::string &src) {
    static const char table[] = "0123456789abcdef";
    std::string s;
    for (unsigned char ch : src) {
        s.push_back(table[ch >> 4]);
        s.push_back(table[ch & 0xf]);
    }
    return s;
}

// XOR two byte strings of equal length.
static std::string x(const std::string &a, const std::string &b) {
    if (a.size() != b.size()) {
        fprintf(stderr, "xor length mismatch\n");
        exit(EXIT_FAILURE);
    }
    std::string c(a.size(), '\0');
    for (size_t i = 0; i < a.size(); i++) {
        c[i] = a[i] ^ b[i];
    }
    return c;
}

// usage: k2 <m1> <m2> <iv2> <iv3> <k1> <k3>   (all hex, 16 bytes each)
// Prints hex(k2); exits 1 if no seed matches.
int main(int argc, char **argv) {
    if (argc != 7) {
        fprintf(stderr, "usage: %s m1 m2 iv2 iv3 k1 k3\n", argv[0]);
        return EXIT_FAILURE;
    }
    const std::string a = fromhex(argv[1]);
    // argv[2] (m2) is accepted for a uniform command line but is not needed here.
    const std::string iv2 = fromhex(argv[3]);
    const std::string iv3 = fromhex(argv[4]);
    const std::string k1 = fromhex(argv[5]);
    const std::string k3 = fromhex(argv[6]);

    // Since c1 = IV3:  E2(E1(m1) ^ IV2) == IV3 ^ E3(IV3).
    const std::string target = x(encrypt(k1, a), iv2);
    const std::string cmp_to = x(encrypt(k3, iv3), iv3);

    for (uint32_t s = 0; s < kSeeds; s++) {
        const std::string k2 = derive_key(s);
        if (encrypt(k2, target) == cmp_to) {
            printf("%s\n", tohex(k2).c_str());
            return 0;
        }
    }
    return 1;
}
```