# 3-AES - zer0pts CTF 2021
###### tags: `zer0pts CTF 2021` `crypto`
## short answer
I found [this paper](https://iacr.org/archive/asiacrypt2001/22480210.pdf) after I created the challenge.
## overview
The challenge encrypts with three chained AES stages: AES-ECB, then AES-CBC, then AES-CFB. We can encrypt arbitrary plaintexts and decrypt arbitrary ciphertexts as many times as we like. Finally, we can retrieve the encrypted flag, after which the server exits.
```python=
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from binascii import hexlify, unhexlify
from hashlib import md5
import os
import signal
from flag import flag
# Three independent AES-128 keys. Each one is the MD5 digest of only 3 random
# bytes, so every key has just 2**24 possible values.
keys = [md5(os.urandom(3)).digest() for _stage in range(3)]
def get_ciphers(iv1, iv2):
    """Create fresh cipher objects for the three stages, in encryption order.

    Stage 1: AES-ECB under ``keys[0]``.
    Stage 2: AES-CBC under ``keys[1]`` with IV ``iv1``.
    Stage 3: AES-CFB with 128-bit segments under ``keys[2]`` with IV ``iv2``.

    New objects are built on every call so each message starts from the
    supplied IVs.
    """
    ecb_stage = AES.new(keys[0], mode=AES.MODE_ECB)
    cbc_stage = AES.new(keys[1], mode=AES.MODE_CBC, iv=iv1)
    cfb_stage = AES.new(keys[2], mode=AES.MODE_CFB, iv=iv2, segment_size=128)
    return [ecb_stage, cbc_stage, cfb_stage]
def encrypt(m: bytes, iv1: bytes, iv2: bytes) -> bytes:
    """Encrypt ``m`` with ECB, then CBC (IV ``iv1``), then CFB (IV ``iv2``).

    ``m`` must be a whole number of 16-byte blocks.
    """
    assert len(m) % 16 == 0
    data = m
    for stage in get_ciphers(iv1, iv2):
        data = stage.encrypt(data)
    return data
def decrypt(c: bytes, iv1: bytes, iv2: bytes) -> bytes:
    """Invert ``encrypt``: undo CFB (IV ``iv2``), then CBC (IV ``iv1``), then ECB.

    ``c`` must be a whole number of 16-byte blocks.
    """
    assert len(c) % 16 == 0
    data = c
    for stage in reversed(get_ciphers(iv1, iv2)):
        data = stage.decrypt(data)
    return data
# Deliver SIGALRM after 3600 seconds to bound the session length.
signal.alarm(3600)

_MENU_LINES = (
    "==== MENU ====",
    "1. Encrypt your plaintext",
    "2. Decrypt your ciphertext",
    "3. Get encrypted flag",
)


def _encrypt_with_fresh_ivs(plaintext):
    """Encrypt under two new random IVs and return ``"iv1:iv2:ciphertext"`` in hex."""
    iv1, iv2 = get_random_bytes(16), get_random_bytes(16)
    ciphertext = encrypt(plaintext, iv1, iv2)
    return b":".join(hexlify(part) for part in (iv1, iv2, ciphertext)).decode()


while True:
    for line in _MENU_LINES:
        print(line)
    choice = int(input("> "))
    if choice == 1:
        plaintext = unhexlify(input("your plaintext(hex): "))
        print("here's the ciphertext: {}".format(_encrypt_with_fresh_ivs(plaintext)))
    elif choice == 2:
        fields = input("your ciphertext: ").strip().split(":")
        iv1, iv2, ciphertext = [unhexlify(field) for field in fields]
        plaintext = decrypt(ciphertext, iv1, iv2)
        print("here's the plaintext(hex): {}".format(hexlify(plaintext).decode()))
    elif choice == 3:
        # Only one encrypted flag per session: the server exits right after.
        print("here's the encrypted flag: {}".format(_encrypt_with_fresh_ivs(flag)))
        exit()
    else:
        exit()
```
## observation
One weak point is the key length: each stage's key is derived from only 3 random bytes (it is the MD5 digest of `os.urandom(3)`), so there are 9 bytes of key material in total. We cannot brute-force 9 bytes ($2^{72}$ candidates), but brute-forcing 3 bytes ($2^{24}$ candidates) is feasible. This reminds us of the meet-in-the-middle (MITM) attack on two-stage encryption. Can we somehow apply MITM here?
Let's write the encryption as formulas. The notation is:
- $E_1, E_2, E_3$ denote the AES encryption of each stage.
- $IV_2$ and $IV_3$ denote the initialization vectors of the CBC and CFB modes (`iv1` and `iv2` in the server code).
- $m_1, m_2, \dots$ denote the plaintext blocks, and $c_1, c_2, \dots$ the corresponding ciphertext blocks.
The first two ciphertext blocks are then:
$c_1 = E_2(E_1(m_1) \oplus IV_2) \oplus E_3(IV_3)$
$c_2 = E_2(E_1(m_2) \oplus E_2(E_1(m_1) \oplus IV_2)) \oplus E_3(c_1)$
Rearranging these, we obtain for example
$c_1 \oplus E_3(IV_3) = E_2(E_1(m_1) \oplus IV_2)$
$E_2(E_1(m_1) \oplus IV_2) = D_2(c_2 \oplus E_3(c_1)) \oplus E_1(m_2)$
where $D_i$ is the decryption corresponding to $E_i$.
Then
$c_1 \oplus E_3(IV_3) = E_2(E_1(m_1) \oplus IV_2) = D_2(c_2 \oplus E_3(c_1)) \oplus E_1(m_2)$
Now suppose we choose $c_1 = c_2 = IV_3$ and call this common value $X$. Then $c_2 \oplus E_3(c_1) = X \oplus E_3(X) = c_1 \oplus E_3(IV_3)$, so
$D_2(c_2 \oplus E_3(c_1)) \oplus E_1(m_2) = D_2(c_1 \oplus E_3(IV_3)) \oplus E_1(m_2)$
$= D_2(E_2(E_1(m_1) \oplus IV_2)) \oplus E_1(m_2) = E_1(m_1) \oplus IV_2 \oplus E_1(m_2)$,
and therefore
$c_1 \oplus E_3(IV_3) = E_1(m_1) \oplus IV_2 \oplus E_1(m_2) \qquad (\triangle)$
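Formula $(\triangle)$ involves only $E_1$ and $E_3$, not $E_2$. With $c_1 = IV_3 = X$ it becomes $X \oplus IV_2 \oplus E_3(X) = E_1(m_1) \oplus E_1(m_2)$. The left side depends only on $k_3$ and the right side only on $k_1$, so a MITM attack is now feasible.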
## solution
As described above, we can mount a MITM attack on $E_1$ and $E_3$ to recover $k_1$ and $k_3$. Finally, we brute-force $k_2$ using $c_1 \oplus E_3(IV_3) = E_2(E_1(m_1) \oplus IV_2)$.
So the plan is simple:
1. Ask the server to decrypt a ciphertext crafted so that $c_1 = c_2 = IV_3$, which gives us a plaintext–ciphertext pair.
2. Find $k_1$ and $k_3$ by MITM, then $k_2$ by brute force.
3. Get the encrypted flag and decrypt it with $k_1, k_2, k_3$.
## exploit
The main script:
```python=
from ptrlib import Socket, Process
from subprocess import run, PIPE
from binascii import hexlify, unhexlify
from Crypto.Cipher import AES
sock = Socket("localhost", 9999)

# Step 1: have the server decrypt a crafted ciphertext with c1 = c2 = IV_3.
# Every value below is 0xAA * 16, so IV_2 = IV_3 = c1 = c2 = c3 = X.
sock.sendlineafter("> ", "2")
iv2 = "A" * 32  # CBC IV (IV_2), hex
iv3 = "A" * 32  # CFB IV (IV_3), hex; must equal the first two ciphertext blocks
a = "A" * 32    # c1
b = "A" * 32    # c2
c = "A" * 32    # c3 (not needed by the attack)
sock.sendlineafter("ciphertext: ", "{}:{}:{}".format(iv2, iv3, a + b + c))
plaintext = unhexlify(sock.recvlineafter(": "))
A = hexlify(plaintext[:16]).decode()    # m1
B = hexlify(plaintext[16:32]).decode()  # m2

# Step 2: grab the encrypted flag (the server exits right after sending it).
sock.sendlineafter("> ", "3")
flag = sock.recvlineafter("flag: ")
sock.close()

# ---
# Step 3: recover the keys with the C++ helpers. check=True turns a failed
# search (non-zero exit) into an immediate CalledProcessError instead of a
# confusing unpacking error on empty output.
r = run(["./k1k3", A, B, iv2, iv3], stdout=PIPE, check=True)
k1, k3 = r.stdout.decode().split()
print("k1={}".format(k1))
print("k3={}".format(k3))
r = run(["./k2", A, B, iv2, iv3, k1, k3], stdout=PIPE, check=True)
k2 = r.stdout.decode().strip()
if not k2:
    # ./k2 may exit 0 without printing anything when no candidate matches.
    raise SystemExit("k2 search found no key")
print("k2={}".format(k2))
# ---
# Raw 16-byte AES keys for the ECB, CBC and CFB stages, in that order.
keys = [unhexlify(hex_key) for hex_key in (k1, k2, k3)]
def get_ciphers(iv1, iv2):
    """Rebuild the server's three stage ciphers (ECB, CBC, CFB-128) from the recovered keys.

    ``iv1`` is the CBC IV and ``iv2`` is the CFB IV.
    """
    first = AES.new(keys[0], mode=AES.MODE_ECB)
    second = AES.new(keys[1], mode=AES.MODE_CBC, iv=iv1)
    third = AES.new(keys[2], mode=AES.MODE_CFB, iv=iv2, segment_size=128)
    return [first, second, third]
def decrypt(c: bytes, iv1: bytes, iv2: bytes) -> bytes:
    """Undo the three stages in reverse order: CFB, then CBC, then ECB."""
    assert len(c) % 16 == 0
    result = c
    for stage in reversed(get_ciphers(iv1, iv2)):
        result = stage.decrypt(result)
    return result
# The flag line is "iv1:iv2:ciphertext" in hex; decrypt it locally.
fields = flag.decode().strip().split(":")
iv1, iv2, ciphertext = [unhexlify(field) for field in fields]
print(decrypt(ciphertext, iv1, iv2))
```
To find the keys as fast as possible, I implemented the searches in C++ with OpenSSL. First, `k1k3` finds $k_1$ and $k_3$ by MITM:
```clike=
#include <openssl/aes.h>
#include <openssl/evp.h>
#include <stdlib.h>
#include <string.h>
#include <string>
#include <unordered_map>
void encrypt(const char* key, const char *data, int len, unsigned char* dest) {
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
memset(dest, 0, len);
int x;
EVP_CIPHER_CTX_init(ctx);
EVP_EncryptInit_ex(ctx, EVP_aes_128_ecb(), NULL, (unsigned char*)key, NULL);
EVP_EncryptUpdate(ctx, dest, &x, (const unsigned char*)data, len);
EVP_CIPHER_CTX_free(ctx);
}
// Convert one hexadecimal digit ('0'-'9', 'a'-'f', 'A'-'F') to its value 0-15.
// Any other character terminates the program with EXIT_FAILURE.
char fromHexChar(char c) {
    if (c >= '0' && c <= '9') {
        return static_cast<char>(c - '0');
    }
    // Setting bit 0x20 folds 'A'-'F' onto 'a'-'f'; no other character lands
    // in that range.
    const char folded = static_cast<char>(c | 0x20);
    if (folded >= 'a' && folded <= 'f') {
        return static_cast<char>(folded - 'a' + 10);
    }
    exit(EXIT_FAILURE);
}
// Map a nibble to its lowercase hex digit. Only the low four bits of `c` are
// used, so an out-of-range argument can no longer index past the table.
char toHexChar(unsigned char c) {
    static const char digits[] = "0123456789abcdef";
    return digits[c & 0x0f];
}
// Decode a NUL-terminated hex string into raw bytes, two digits per byte.
// Invalid digits, including the missing second digit of an odd-length
// input, terminate the program via fromHexChar.
std::string fromhex(const char* source) {
    std::string bytes;
    for (const char* p = source; *p != '\0'; p += 2) {
        const int high = fromHexChar(p[0]);
        const int low = fromHexChar(p[1]);
        bytes.push_back(static_cast<char>((high << 4) | low));
    }
    return bytes;
}
// Encode `len` bytes starting at `src` as lowercase hex, high nibble first.
std::string tohex(const char *src, int len) {
    const unsigned char *bytes = reinterpret_cast<const unsigned char*>(src);
    std::string out;
    for (int i = 0; i < len; ++i) {
        const unsigned char byte = bytes[i];
        out += toHexChar(byte >> 4);
        out += toHexChar(byte & 0x0f);
    }
    return out;
}
// XOR two byte strings position by position. The result has the length of
// the shorter input, so mismatched lengths can never read out of bounds.
std::string x(const std::string& a, const std::string& b) {
    const std::string::size_type n = a.size() < b.size() ? a.size() : b.size();
    std::string c(n, '\0');
    for (std::string::size_type i = 0; i < n; i++) {
        c[i] = static_cast<char>(a[i] ^ b[i]);
    }
    return c;
}
int main(int argc, char **argv) {
std::string a = fromhex(argv[1]);
std::string b = fromhex(argv[2]);
std::string iv2 = fromhex(argv[3]);
std::string iv3 = fromhex(argv[4]);
std::unordered_map<std::string, std::string> map;
auto left = x(iv2, iv3);
unsigned char* m1 = (unsigned char*)malloc(sizeof(unsigned char) * a.size()) ;
unsigned char* m2 = (unsigned char*)malloc(sizeof(unsigned char) * a.size()) ;
char key[16] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
for (int i = 0; i < 256; i++) {
key[13] = i;
for (int j = 0; j < 256; j++) {
key[14] = j;
for (int k = 0; k < 256; k++) {
key[15] = k;
encrypt(key, iv3.c_str(), iv3.size(), m1);
auto t = x(left, std::string((char*) m1));
map[t] = tohex(key, 16);
}
}
}
std::string k1, k3;
for (int i = 0; i < 256; i++) {
key[13] = i;
for (int j = 0; j < 256; j++) {
key[14] = j;
for (int k = 0; k < 256; k++) {
key[15] = k;
encrypt(key, a.c_str(), a.size(), m1);
encrypt(key, b.c_str(), b.size(), m2);
auto t = x(std::string((char*) m2), std::string((char*) m1));
if (map.find(t) != map.end()) {
printf("%s\n", tohex(key, 16).c_str());
printf("%s\n", map.at(t).c_str());
k1 = std::string((char*)key);
k3 = fromhex(map.at(t).c_str());
return 0;
}
}
}
}
return 1;
}
```
Then `k2` finds $k_2$ by brute force:
```clike=
#include <openssl/aes.h>
#include <openssl/evp.h>
#include <stdlib.h>
#include <string.h>
#include <string>
#include <unordered_map>
void encrypt(const char* key, const char *data, int len, unsigned char* dest) {
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
memset(dest, 0, len);
int x;
EVP_CIPHER_CTX_init(ctx);
EVP_EncryptInit_ex(ctx, EVP_aes_128_ecb(), NULL, (unsigned char*)key, NULL);
EVP_EncryptUpdate(ctx, dest, &x, (const unsigned char*)data, len);
EVP_CIPHER_CTX_free(ctx);
}
// Convert one hexadecimal digit ('0'-'9', 'a'-'f', 'A'-'F') to its value 0-15.
// Any other character terminates the program with EXIT_FAILURE.
char fromHexChar(char c) {
    if (c >= '0' && c <= '9') {
        return static_cast<char>(c - '0');
    }
    // Setting bit 0x20 folds 'A'-'F' onto 'a'-'f'; no other character lands
    // in that range.
    const char folded = static_cast<char>(c | 0x20);
    if (folded >= 'a' && folded <= 'f') {
        return static_cast<char>(folded - 'a' + 10);
    }
    exit(EXIT_FAILURE);
}
// Map a nibble to its lowercase hex digit. Only the low four bits of `c` are
// used, so an out-of-range argument can no longer index past the table.
char toHexChar(unsigned char c) {
    static const char digits[] = "0123456789abcdef";
    return digits[c & 0x0f];
}
// Decode a NUL-terminated hex string into raw bytes, two digits per byte.
// Invalid digits, including the missing second digit of an odd-length
// input, terminate the program via fromHexChar.
std::string fromhex(const char* source) {
    std::string bytes;
    for (const char* p = source; *p != '\0'; p += 2) {
        const int high = fromHexChar(p[0]);
        const int low = fromHexChar(p[1]);
        bytes.push_back(static_cast<char>((high << 4) | low));
    }
    return bytes;
}
// Encode `len` bytes starting at `src` as lowercase hex, high nibble first.
std::string tohex(const char *src, int len) {
    const unsigned char *bytes = reinterpret_cast<const unsigned char*>(src);
    std::string out;
    for (int i = 0; i < len; ++i) {
        const unsigned char byte = bytes[i];
        out += toHexChar(byte >> 4);
        out += toHexChar(byte & 0x0f);
    }
    return out;
}
// XOR two byte strings position by position. The result has the length of
// the shorter input, so mismatched lengths can never read out of bounds.
std::string x(const std::string& a, const std::string& b) {
    const std::string::size_type n = a.size() < b.size() ? a.size() : b.size();
    std::string c(n, '\0');
    for (std::string::size_type i = 0; i < n; i++) {
        c[i] = static_cast<char>(a[i] ^ b[i]);
    }
    return c;
}
int main(int argc, char **argv) {
std::string a = fromhex(argv[1]);
std::string b = fromhex(argv[2]);
std::string iv2 = fromhex(argv[3]);
std::string iv3 = fromhex(argv[4]);
std::string k1 = fromhex(argv[5]);
std::string k3 = fromhex(argv[6]);
unsigned char* m1 = (unsigned char*)malloc(sizeof(unsigned char) * a.size()) ;
unsigned char* m2 = (unsigned char*)malloc(sizeof(unsigned char) * a.size()) ;
encrypt(k1.c_str(), a.c_str(), a.size(), m1);
std::string target = x(std::string((char*)m1), iv2);
encrypt(k3.c_str(), iv3.c_str(), iv3.size(), m1);
std::string cmp_to = x(std::string((char*)m1), iv3);
char key[16] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
for (int i = 0; i < 256; i++) {
key[13] = i;
for (int j = 0; j < 256; j++) {
key[14] = j;
for (int k = 0; k < 256; k++) {
key[15] = k;
encrypt(key, target.c_str(), target.size(), m1);
std::string result((char*) m1);
if (result == cmp_to) {
printf("%s\n", tohex(key, 16).c_str());
return 0;
}
}
}
}
}
```