
3-AES - zer0pts CTF 2021

tags: zer0pts CTF 2021 crypto

short answer

I found this paper after I created the challenge.

overview

The challenge is a 3-stage AES encryption: AES-ECB, then AES-CBC, then AES-CFB. We can encrypt arbitrary plaintexts and decrypt arbitrary ciphertexts as many times as we like. At the end we can retrieve the encrypted flag, after which the server exits.

```python
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from binascii import hexlify, unhexlify
from hashlib import md5
import os
import signal
from flag import flag

keys = [md5(os.urandom(3)).digest() for _ in range(3)]

def get_ciphers(iv1, iv2):
    return [
        AES.new(keys[0], mode=AES.MODE_ECB),
        AES.new(keys[1], mode=AES.MODE_CBC, iv=iv1),
        AES.new(keys[2], mode=AES.MODE_CFB, iv=iv2, segment_size=8*16),
    ]

def encrypt(m: bytes, iv1: bytes, iv2: bytes) -> bytes:
    assert len(m) % 16 == 0
    ciphers = get_ciphers(iv1, iv2)
    c = m
    for cipher in ciphers:
        c = cipher.encrypt(c)
    return c

def decrypt(c: bytes, iv1: bytes, iv2: bytes) -> bytes:
    assert len(c) % 16 == 0
    ciphers = get_ciphers(iv1, iv2)
    m = c
    for cipher in ciphers[::-1]:
        m = cipher.decrypt(m)
    return m

signal.alarm(3600)
while True:
    print("==== MENU ====")
    print("1. Encrypt your plaintext")
    print("2. Decrypt your ciphertext")
    print("3. Get encrypted flag")
    choice = int(input("> "))

    if choice == 1:
        plaintext = unhexlify(input("your plaintext(hex): "))
        iv1, iv2 = get_random_bytes(16), get_random_bytes(16)
        ciphertext = encrypt(plaintext, iv1, iv2)
        ciphertext = b":".join([hexlify(x) for x in [iv1, iv2, ciphertext]]).decode()
        print("here's the ciphertext: {}".format(ciphertext))

    elif choice == 2:
        ciphertext = input("your ciphertext: ")
        iv1, iv2, ciphertext = [unhexlify(x) for x in ciphertext.strip().split(":")]
        plaintext = decrypt(ciphertext, iv1, iv2)
        print("here's the plaintext(hex): {}".format(hexlify(plaintext).decode()))

    elif choice == 3:
        plaintext = flag
        iv1, iv2 = get_random_bytes(16), get_random_bytes(16)
        ciphertext = encrypt(plaintext, iv1, iv2)
        ciphertext = b":".join([hexlify(x) for x in [iv1, iv2, ciphertext]]).decode()
        print("here's the encrypted flag: {}".format(ciphertext))
        exit()

    else:
        exit()
```

observation

One weak point is the key length. Each stage's key is the `md5` of only 3 random bytes, so there are just 9 bytes of key material in total. We can't brute-force 9 bytes ($2^{72}$ candidates), but brute-forcing 3 bytes ($2^{24}$ candidates) is feasible. This reminds us of the meet-in-the-middle (MITM) attack on two-stage encryption. Can we somehow apply MITM here?
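For reference, the textbook meet-in-the-middle attack on two-stage encryption $c = E_b(E_a(m))$ can be sketched as below. This is a minimal sketch under loud assumptions: `E`/`D` are a hypothetical toy 16-byte Feistel cipher built from SHA-256, standing in for AES so that only the standard library is needed, and `derive_key` (also hypothetical) takes the `md5` of a 1-byte seed instead of 3 bytes so the search finishes instantly. The forward table plus the backward lookup cost about $2 \cdot 2^8$ cipher calls instead of $2^{16}$.

```python
import hashlib

def xor(a: bytes, b: bytes) -> bytes:
    return bytes(x ^ y for x, y in zip(a, b))

def _round(key: bytes, rnd: int, half: bytes) -> bytes:
    # Feistel round function: truncated SHA-256 of key || round || half
    return hashlib.sha256(key + bytes([rnd]) + half).digest()[:8]

def E(key: bytes, block: bytes) -> bytes:
    """Toy 16-byte block cipher (4-round Feistel), a stand-in for AES."""
    left, right = block[:8], block[8:]
    for rnd in range(4):
        left, right = right, xor(left, _round(key, rnd, right))
    return left + right

def D(key: bytes, block: bytes) -> bytes:
    """Inverse of E: undo the rounds in reverse order."""
    left, right = block[:8], block[8:]
    for rnd in reversed(range(4)):
        left, right = xor(right, _round(key, rnd, left)), left
    return left + right

def derive_key(seed: int) -> bytes:
    # Like md5(os.urandom(3)).digest() in the challenge, but with a 1-byte seed
    return hashlib.md5(bytes([seed])).digest()

def mitm_double(m: bytes, c: bytes):
    """Return all seed pairs (a, b) with E(key_b, E(key_a, m)) == c."""
    forward = {}
    for a in range(256):  # encrypt forward from the plaintext side
        forward.setdefault(E(derive_key(a), m), []).append(a)
    hits = []
    for b in range(256):  # decrypt backward from the ciphertext side
        middle = D(derive_key(b), c)
        for a in forward.get(middle, []):
            hits.append((a, b))
    return hits

m = bytes(range(16))
c = E(derive_key(199), E(derive_key(42), m))  # first stage seed 42, second 199
print(mitm_double(m, c))
```

With real AES and 3-byte seeds the same structure needs about $2 \cdot 2^{24}$ AES calls, which is why a 9-byte total key is still weak when the stages can be separated.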

Let's write the encryption as formulas. We denote by $E_1, E_2, E_3$ the AES block encryption under the key of each stage, by $IV_2, IV_3$ the initial vectors of the CBC and CFB stages (`iv1` and `iv2` in the server code), by $m_1, m_2, \ldots$ the plaintext blocks, and by $c_1, c_2, \ldots$ the corresponding ciphertext blocks.

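Then the first two ciphertext blocks are computed as follows (with `segment_size=8*16`, CFB works on whole 16-byte blocks).

$$c_1 = E_2(E_1(m_1) \oplus IV_2) \oplus E_3(IV_3)$$
$$c_2 = E_2(E_1(m_2) \oplus E_2(E_1(m_1) \oplus IV_2)) \oplus E_3(c_1)$$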
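These two formulas can be checked numerically. A minimal sketch, assuming a hypothetical toy Feistel cipher `E` as a stand-in for AES (the formulas only use the fact that each $E_i$ is a keyed block permutation). The hypothetical helper `three_stage_encrypt` works block by block; this matches the server's stage-by-stage `encrypt` because ECB, CBC and full-block CFB all process blocks left to right, each output block depending only on earlier blocks.

```python
import hashlib
import os

def xor(a: bytes, b: bytes) -> bytes:
    return bytes(x ^ y for x, y in zip(a, b))

def _round(key: bytes, rnd: int, half: bytes) -> bytes:
    # Feistel round function: truncated SHA-256 of key || round || half
    return hashlib.sha256(key + bytes([rnd]) + half).digest()[:8]

def E(key: bytes, block: bytes) -> bytes:
    """Toy 16-byte block cipher (4-round Feistel), a stand-in for AES."""
    left, right = block[:8], block[8:]
    for rnd in range(4):
        left, right = right, xor(left, _round(key, rnd, right))
    return left + right

def three_stage_encrypt(k1, k2, k3, iv2, iv3, blocks):
    """ECB(k1) -> CBC(k2, iv2) -> CFB-128(k3, iv3), one 16-byte block at a time."""
    out = []
    cbc_prev, cfb_prev = iv2, iv3
    for m in blocks:
        y = E(k2, xor(E(k1, m), cbc_prev))  # ECB stage, then CBC stage
        cbc_prev = y
        c = xor(y, E(k3, cfb_prev))  # CFB stage: XOR with E3(previous ciphertext)
        cfb_prev = c
        out.append(c)
    return out

k1, k2, k3 = [hashlib.md5(os.urandom(3)).digest() for _ in range(3)]
iv2, iv3, m1, m2 = [os.urandom(16) for _ in range(4)]
c1, c2 = three_stage_encrypt(k1, k2, k3, iv2, iv3, [m1, m2])

# c1 = E2(E1(m1) ^ IV2) ^ E3(IV3)
assert c1 == xor(E(k2, xor(E(k1, m1), iv2)), E(k3, iv3))
# c2 = E2(E1(m2) ^ E2(E1(m1) ^ IV2)) ^ E3(c1)
assert c2 == xor(E(k2, xor(E(k1, m2), E(k2, xor(E(k1, m1), iv2)))), E(k3, c1))
print("formulas hold")
```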

Rearranging these, we get several identities, for example

$$c_1 \oplus E_3(IV_3) = E_2(E_1(m_1) \oplus IV_2)$$
$$E_2(E_1(m_1) \oplus IV_2) = D_2(c_2 \oplus E_3(c_1)) \oplus E_1(m_2)$$

where $D_i$ is the decryption corresponding to $E_i$.

Then

$$c_1 \oplus E_3(IV_3) = E_2(E_1(m_1) \oplus IV_2) = D_2(c_2 \oplus E_3(c_1)) \oplus E_1(m_2)$$
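A quick numerical check of this chain of equalities, again a sketch with a hypothetical toy Feistel cipher `E`/`D` in place of AES and random keys, IVs and plaintext blocks. The CBC output of the first block, $E_2(E_1(m_1) \oplus IV_2)$, is recovered both from $c_1$ (via $E_3$) and from $c_2$ (via $D_2$).

```python
import hashlib
import os

def xor(a: bytes, b: bytes) -> bytes:
    return bytes(x ^ y for x, y in zip(a, b))

def _round(key: bytes, rnd: int, half: bytes) -> bytes:
    # Feistel round function: truncated SHA-256 of key || round || half
    return hashlib.sha256(key + bytes([rnd]) + half).digest()[:8]

def E(key: bytes, block: bytes) -> bytes:
    """Toy 16-byte block cipher (4-round Feistel), a stand-in for AES."""
    left, right = block[:8], block[8:]
    for rnd in range(4):
        left, right = right, xor(left, _round(key, rnd, right))
    return left + right

def D(key: bytes, block: bytes) -> bytes:
    """Inverse of E: undo the rounds in reverse order."""
    left, right = block[:8], block[8:]
    for rnd in reversed(range(4)):
        left, right = xor(right, _round(key, rnd, left)), left
    return left + right

k1, k2, k3 = [hashlib.md5(os.urandom(3)).digest() for _ in range(3)]
iv2, iv3, m1, m2 = [os.urandom(16) for _ in range(4)]

y1 = E(k2, xor(E(k1, m1), iv2))  # CBC output of block 1
c1 = xor(y1, E(k3, iv3))
y2 = E(k2, xor(E(k1, m2), y1))   # CBC output of block 2
c2 = xor(y2, E(k3, c1))

lhs = xor(c1, E(k3, iv3))                        # c1 ^ E3(IV3)
rhs = xor(D(k2, xor(c2, E(k3, c1))), E(k1, m2))  # D2(c2 ^ E3(c1)) ^ E1(m2)
print(lhs == y1 == rhs)  # True
```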

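Now suppose $c_1 = c_2 = IV_3$ and call this common value $X$. Then $c_2 \oplus E_3(c_1) = c_1 \oplus E_3(IV_3)$, so the right-hand side simplifies:

$$D_2(c_2 \oplus E_3(c_1)) \oplus E_1(m_2) = D_2(c_1 \oplus E_3(IV_3)) \oplus E_1(m_2) = D_2(E_2(E_1(m_1) \oplus IV_2)) \oplus E_1(m_2) = E_1(m_1) \oplus IV_2 \oplus E_1(m_2)$$

and therefore

$$c_1 \oplus E_3(IV_3) = X \oplus E_3(X) = E_1(m_1) \oplus IV_2 \oplus E_1(m_2)$$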

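This formula uses only $E_1$ and $E_3$, not $E_2$. Writing $k_i$ for the key of $E_i$, its left side depends only on $k_3$ and its right side only on $k_1$ (the decryption oracle gives us $m_1, m_2$ for our chosen $X$ and $IV_2$), so a MITM attack now looks feasible.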
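To see that $k_2$ really drops out, here is a sketch that emulates the decryption oracle, with a hypothetical toy Feistel cipher `E`/`D` standing in for AES and a hypothetical helper `three_stage_decrypt` playing the server. It decrypts $X \| X$ with $IV_3 = X$ (16 bytes of 0xAA, as in the exploit script) and checks $X \oplus E_3(X) = E_1(m_1) \oplus IV_2 \oplus E_1(m_2)$ for several unrelated values of $k_2$.

```python
import hashlib

def xor(a: bytes, b: bytes) -> bytes:
    return bytes(x ^ y for x, y in zip(a, b))

def _round(key: bytes, rnd: int, half: bytes) -> bytes:
    # Feistel round function: truncated SHA-256 of key || round || half
    return hashlib.sha256(key + bytes([rnd]) + half).digest()[:8]

def E(key: bytes, block: bytes) -> bytes:
    """Toy 16-byte block cipher (4-round Feistel), a stand-in for AES."""
    left, right = block[:8], block[8:]
    for rnd in range(4):
        left, right = right, xor(left, _round(key, rnd, right))
    return left + right

def D(key: bytes, block: bytes) -> bytes:
    """Inverse of E: undo the rounds in reverse order."""
    left, right = block[:8], block[8:]
    for rnd in reversed(range(4)):
        left, right = xor(right, _round(key, rnd, left)), left
    return left + right

def three_stage_decrypt(k1, k2, k3, iv2, iv3, blocks):
    """Inverse of ECB(k1) -> CBC(k2, iv2) -> CFB-128(k3, iv3), block by block."""
    out = []
    cbc_prev, cfb_prev = iv2, iv3
    for c in blocks:
        y = xor(c, E(k3, cfb_prev))  # undo CFB
        cfb_prev = c
        out.append(D(k1, xor(D(k2, y), cbc_prev)))  # undo CBC, then ECB
        cbc_prev = y
    return out

k1 = hashlib.md5(b"\x01\x02\x03").digest()
k3 = hashlib.md5(b"\x04\x05\x06").digest()
X = b"\xaa" * 16    # c1 = c2 = IV3 = X
iv2 = b"\xaa" * 16  # IV2 can be anything

results = []
for seed in range(5):
    k2 = hashlib.md5(bytes([seed, 0, 0])).digest()  # unrelated choices of k2
    m1, m2 = three_stage_decrypt(k1, k2, k3, iv2, X, [X, X])
    lhs = xor(X, E(k3, X))                        # depends on k3 only
    rhs = xor(xor(E(k1, m1), iv2), E(k1, m2))     # depends on k1 (and oracle output)
    results.append(lhs == rhs)
print(all(results))  # True
```

The plaintexts $m_1, m_2$ change with $k_2$, but the relation between them never involves $k_2$.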

solution

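As described above, we can mount a MITM attack on $E_1$ and $E_3$: for each of the $2^{24}$ candidates for $k_3$, store $X \oplus IV_2 \oplus E_3(X)$ in a table, then for each candidate $k_1$ look up $E_1(m_1) \oplus E_1(m_2)$. A match gives $k_1$ and $k_3$. Finally, we simply brute-force $k_2$ using $c_1 \oplus E_3(IV_3) = E_2(E_1(m_1) \oplus IV_2)$ to get all three keys.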

So what we have to do is simple.

  1. decrypt a ciphertext crafted to satisfy $c_1 = c_2 = IV_3$, and get the resulting plaintext-ciphertext pair
  2. find $k_1, k_3$ by MITM, then $k_2$ by brute force
  3. get the encrypted flag and decrypt it with $k_1, k_2, k_3$
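The three steps can be tried end to end on a scaled-down model. This is a sketch under stated assumptions, not the exploit itself: a hypothetical toy Feistel cipher `E`/`D` replaces AES, the hypothetical `derive_key` takes the `md5` of a 1-byte seed instead of 3 bytes, and the hypothetical helpers `three_stage_encrypt`/`three_stage_decrypt` play the server. `recover_keys` mirrors the structure of the C++ helpers: a table over $k_3$, a lookup over $k_1$, then a brute-force over $k_2$.

```python
import hashlib
import os

def xor(a: bytes, b: bytes) -> bytes:
    return bytes(x ^ y for x, y in zip(a, b))

def _round(key: bytes, rnd: int, half: bytes) -> bytes:
    # Feistel round function: truncated SHA-256 of key || round || half
    return hashlib.sha256(key + bytes([rnd]) + half).digest()[:8]

def E(key: bytes, block: bytes) -> bytes:
    """Toy 16-byte block cipher (4-round Feistel), a stand-in for AES."""
    left, right = block[:8], block[8:]
    for rnd in range(4):
        left, right = right, xor(left, _round(key, rnd, right))
    return left + right

def D(key: bytes, block: bytes) -> bytes:
    """Inverse of E: undo the rounds in reverse order."""
    left, right = block[:8], block[8:]
    for rnd in reversed(range(4)):
        left, right = xor(right, _round(key, rnd, left)), left
    return left + right

def derive_key(seed: int) -> bytes:
    # Scaled down from md5(os.urandom(3)).digest(): only 256 possible keys
    return hashlib.md5(bytes([seed])).digest()

def three_stage_encrypt(k1, k2, k3, iv2, iv3, blocks):
    out, cbc_prev, cfb_prev = [], iv2, iv3
    for m in blocks:
        cbc_prev = E(k2, xor(E(k1, m), cbc_prev))  # ECB, then CBC
        cfb_prev = xor(cbc_prev, E(k3, cfb_prev))  # CFB-128
        out.append(cfb_prev)
    return out

def three_stage_decrypt(k1, k2, k3, iv2, iv3, blocks):
    out, cbc_prev, cfb_prev = [], iv2, iv3
    for c in blocks:
        y = xor(c, E(k3, cfb_prev))  # undo CFB
        out.append(D(k1, xor(D(k2, y), cbc_prev)))  # undo CBC, then ECB
        cbc_prev, cfb_prev = y, c
    return out

def recover_keys(m1, m2, iv2, x):
    """Find (k1, k2, k3) from one oracle answer where c1 = c2 = IV3 = x."""
    # Table over k3: X ^ IV2 ^ E3(X) -> seed
    table = {xor(xor(x, iv2), E(derive_key(s), x)): s for s in range(256)}
    for s1 in range(256):
        k1 = derive_key(s1)
        s3 = table.get(xor(E(k1, m1), E(k1, m2)))  # lookup E1(m1) ^ E1(m2)
        if s3 is None:
            continue
        k3 = derive_key(s3)
        # Brute-force k2 with c1 ^ E3(IV3) = E2(E1(m1) ^ IV2)
        target, want = xor(E(k1, m1), iv2), xor(x, E(k3, x))
        for s2 in range(256):
            k2 = derive_key(s2)
            if E(k2, target) == want:
                return k1, k2, k3
    return None

secret = (derive_key(7), derive_key(123), derive_key(250))  # the "server" keys

# Step 1: decrypt X || X with IV3 = X (all 0xAA, as in the exploit script)
X = b"\xaa" * 16
iv2 = b"\xaa" * 16
m1, m2 = three_stage_decrypt(*secret, iv2, X, [X, X])

# Step 2: MITM for k1 and k3, then brute force for k2
keys = recover_keys(m1, m2, iv2, X)

# Step 3: decrypt a toy "encrypted flag" produced with fresh IVs
flag_blocks = [b"toy flag block 1", b"toy flag block 2"]
fiv2, fiv3 = os.urandom(16), os.urandom(16)
enc_flag = three_stage_encrypt(*secret, fiv2, fiv3, flag_blocks)
recovered = three_stage_decrypt(*keys, fiv2, fiv3, enc_flag)
print(keys == secret, b"".join(recovered))
```

Keying the table on the $k_3$ side keeps a single dictionary in memory; the $k_1$ side is streamed, and $k_2$ is only searched once a candidate pair is found.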

exploit

the main script

```python
from ptrlib import Socket
from subprocess import run, PIPE
from binascii import hexlify, unhexlify
from Crypto.Cipher import AES

sock = Socket("localhost", 9999)

# Ask the server to decrypt a crafted ciphertext with c1 = c2 = IV3.
# Every value is 16 bytes of 0xAA; IV2 (the CBC IV) can be anything.
sock.sendlineafter("> ", "2")
iv2 = "A" * 32
iv3 = "A" * 32
a = "A" * 32
b = "A" * 32
c = "A" * 32
sock.sendlineafter("ciphertext: ", "{}:{}:{}".format(iv2, iv3, a + b + c))
plaintext = unhexlify(sock.recvlineafter(": "))
A = hexlify(plaintext[:16]).decode()    # m1
B = hexlify(plaintext[16:32]).decode()  # m2

# Fetch the encrypted flag (the server exits afterwards).
sock.sendlineafter("> ", "3")
flag = sock.recvlineafter("flag: ")
sock.close()

# --- recover k1, k3 by MITM, then k2 by brute force (C++ helpers)
r = run(["./k1k3", A, B, iv2, iv3], stdout=PIPE)
k1, k3 = r.stdout.decode().strip().split("\n")
print("k1={}".format(k1))
print("k3={}".format(k3))
r = run(["./k2", A, B, iv2, iv3, k1, k3], stdout=PIPE)
k2 = r.stdout.decode().strip()
print("k2={}".format(k2))

# --- decrypt the flag locally with the same 3-stage construction
keys = [
    unhexlify(k1),
    unhexlify(k2),
    unhexlify(k3),
]

def get_ciphers(iv1, iv2):
    return [
        AES.new(keys[0], mode=AES.MODE_ECB),
        AES.new(keys[1], mode=AES.MODE_CBC, iv=iv1),
        AES.new(keys[2], mode=AES.MODE_CFB, iv=iv2, segment_size=8*16),
    ]

def decrypt(c: bytes, iv1: bytes, iv2: bytes) -> bytes:
    assert len(c) % 16 == 0
    ciphers = get_ciphers(iv1, iv2)
    m = c
    for cipher in ciphers[::-1]:
        m = cipher.decrypt(m)
    return m

iv1, iv2, ciphertext = [unhexlify(x) for x in flag.decode().strip().split(":")]
plaintext = decrypt(ciphertext, iv1, iv2)
print(plaintext)
```

To find the keys as fast as possible, I implemented the searches in C++ with OpenSSL. First, the MITM search for $k_1$ and $k_3$ (`k1k3`):

```cpp
// k1k3: meet-in-the-middle search for k1 and k3.
// Example build: g++ -O2 -o k1k3 k1k3.cpp -lcrypto
#include <openssl/evp.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <string>
#include <unordered_map>

// AES-128-ECB encryption of `len` bytes (a multiple of 16), padding disabled.
void encrypt(const unsigned char *key, const char *data, int len, unsigned char *dest) {
  EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
  int outl;
  memset(dest, 0, len);
  EVP_EncryptInit_ex(ctx, EVP_aes_128_ecb(), NULL, key, NULL);
  EVP_CIPHER_CTX_set_padding(ctx, 0);
  EVP_EncryptUpdate(ctx, dest, &outl, (const unsigned char *)data, len);
  EVP_CIPHER_CTX_free(ctx);
}

// The server's keys are md5(3 random bytes), so enumerate the 3-byte seeds.
void derive_key(int i, int j, int k, unsigned char *key) {
  const unsigned char seed[3] = {(unsigned char)i, (unsigned char)j, (unsigned char)k};
  EVP_Digest(seed, 3, key, NULL, EVP_md5(), NULL);
}

char fromHexChar(char c) {
  if ('0' <= c && c <= '9') return c - '0';
  if ('a' <= c && c <= 'f') return c - 'a' + 10;
  if ('A' <= c && c <= 'F') return c - 'A' + 10;
  exit(EXIT_FAILURE);
}

char toHexChar(unsigned char c) {
  static const char table[] = "0123456789abcdef";
  return table[c];
}

std::string fromhex(const char *source) {
  std::string s;
  while (*source) {
    int v = fromHexChar(*source++) << 4;
    v |= fromHexChar(*source++);
    s.push_back((char)v);
  }
  return s;
}

std::string tohex(const unsigned char *source, int len) {
  std::string s;
  for (int i = 0; i < len; i++) {
    s.push_back(toHexChar(source[i] >> 4));
    s.push_back(toHexChar(source[i] & 0xf));
  }
  return s;
}

// Bytewise XOR of two equal-length strings.
std::string x(const std::string &a, const std::string &b) {
  std::string c;
  for (size_t i = 0; i < a.size(); i++) c.push_back(a[i] ^ b[i]);
  return c;
}

int main(int argc, char **argv) {
  if (argc < 5) return EXIT_FAILURE;
  std::string a = fromhex(argv[1]);    // m1
  std::string b = fromhex(argv[2]);    // m2
  std::string iv2 = fromhex(argv[3]);  // IV2 (CBC)
  std::string iv3 = fromhex(argv[4]);  // IV3 (CFB) = c1 = c2 = X
  std::string left = x(iv2, iv3);      // X ^ IV2
  std::unordered_map<std::string, int> map;  // X ^ IV2 ^ E3(X) -> seed of k3
  map.reserve(1 << 24);
  unsigned char m1[16], m2[16], key[16];

  // Table side: every candidate k3.
  for (int i = 0; i < 256; i++)
    for (int j = 0; j < 256; j++)
      for (int k = 0; k < 256; k++) {
        derive_key(i, j, k, key);
        encrypt(key, iv3.data(), 16, m1);
        map[x(left, std::string((char *)m1, 16))] = (i << 16) | (j << 8) | k;
      }

  // Lookup side: E1(m1) ^ E1(m2) for every candidate k1.
  for (int i = 0; i < 256; i++)
    for (int j = 0; j < 256; j++)
      for (int k = 0; k < 256; k++) {
        derive_key(i, j, k, key);
        encrypt(key, a.data(), 16, m1);
        encrypt(key, b.data(), 16, m2);
        auto it = map.find(x(std::string((char *)m2, 16), std::string((char *)m1, 16)));
        if (it != map.end()) {
          unsigned char k3[16];
          int s = it->second;
          derive_key(s >> 16, (s >> 8) & 0xff, s & 0xff, k3);
          printf("%s\n", tohex(key, 16).c_str());  // k1
          printf("%s\n", tohex(k3, 16).c_str());   // k3
          return 0;
        }
      }
  return EXIT_FAILURE;
}
```

finding $k_2$

```cpp
// k2: brute-force k2 once k1 and k3 are known.
// Example build: g++ -O2 -o k2 k2.cpp -lcrypto
#include <openssl/evp.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <string>

// AES-128-ECB encryption of `len` bytes (a multiple of 16), padding disabled.
void encrypt(const unsigned char *key, const char *data, int len, unsigned char *dest) {
  EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
  int outl;
  memset(dest, 0, len);
  EVP_EncryptInit_ex(ctx, EVP_aes_128_ecb(), NULL, key, NULL);
  EVP_CIPHER_CTX_set_padding(ctx, 0);
  EVP_EncryptUpdate(ctx, dest, &outl, (const unsigned char *)data, len);
  EVP_CIPHER_CTX_free(ctx);
}

// The server's keys are md5(3 random bytes), so enumerate the 3-byte seeds.
void derive_key(int i, int j, int k, unsigned char *key) {
  const unsigned char seed[3] = {(unsigned char)i, (unsigned char)j, (unsigned char)k};
  EVP_Digest(seed, 3, key, NULL, EVP_md5(), NULL);
}

char fromHexChar(char c) {
  if ('0' <= c && c <= '9') return c - '0';
  if ('a' <= c && c <= 'f') return c - 'a' + 10;
  if ('A' <= c && c <= 'F') return c - 'A' + 10;
  exit(EXIT_FAILURE);
}

char toHexChar(unsigned char c) {
  static const char table[] = "0123456789abcdef";
  return table[c];
}

std::string fromhex(const char *source) {
  std::string s;
  while (*source) {
    int v = fromHexChar(*source++) << 4;
    v |= fromHexChar(*source++);
    s.push_back((char)v);
  }
  return s;
}

std::string tohex(const unsigned char *source, int len) {
  std::string s;
  for (int i = 0; i < len; i++) {
    s.push_back(toHexChar(source[i] >> 4));
    s.push_back(toHexChar(source[i] & 0xf));
  }
  return s;
}

// Bytewise XOR of two equal-length strings.
std::string x(const std::string &a, const std::string &b) {
  std::string c;
  for (size_t i = 0; i < a.size(); i++) c.push_back(a[i] ^ b[i]);
  return c;
}

int main(int argc, char **argv) {
  if (argc < 7) return EXIT_FAILURE;
  std::string a = fromhex(argv[1]);    // m1 (argv[2], m2, is not needed here)
  std::string iv2 = fromhex(argv[3]);  // IV2 (CBC)
  std::string iv3 = fromhex(argv[4]);  // IV3 (CFB) = c1 = X
  std::string k1 = fromhex(argv[5]);
  std::string k3 = fromhex(argv[6]);
  unsigned char buf[16], key[16];

  // target = E1(m1) ^ IV2, the CBC input of block 1
  encrypt((const unsigned char *)k1.data(), a.data(), 16, buf);
  std::string target = x(std::string((char *)buf, 16), iv2);
  // cmp_to = c1 ^ E3(IV3) = X ^ E3(X), the CBC output of block 1
  encrypt((const unsigned char *)k3.data(), iv3.data(), 16, buf);
  std::string cmp_to = x(std::string((char *)buf, 16), iv3);

  // Find k2 with E2(target) == cmp_to.
  for (int i = 0; i < 256; i++)
    for (int j = 0; j < 256; j++)
      for (int k = 0; k < 256; k++) {
        derive_key(i, j, k, key);
        encrypt(key, target.data(), 16, buf);
        if (std::string((char *)buf, 16) == cmp_to) {
          printf("%s\n", tohex(key, 16).c_str());
          return 0;
        }
      }
  return EXIT_FAILURE;
}
```