# Square attack on AES 4 rounds
## 1. A persistent structure over 3 rounds
- Giả sử ta có một bộ 256-plaintexts với byte đầu tiên từ 0-255, các bytes còn lại = 0x00

- Ta gọi bộ này là delta-set, các giá trị từ 0-255 được gọi là **active value**, ta sẽ làm như vậy với các vị trí còn lại, tổng cộng ta sẽ có 16*256 cặp plain-ciphertext sau 3 rounds

- Dựa vào hình thì có thể thấy rằng với chỉ 3 rounds thì aes sẽ khuếch tán không đủ mạnh, giả sử với vị trí active value đầu tiên, ta sẽ xor 256 cặp plai-ciphertext lại với nhau và kết quả cho ra sẽ luôn bằng 0x00 với mọi roundkey(tương tự cho các vị trí còn lại)


## 2. Square attack on AES 4 rounds

- Ý tưởng của kĩ thuật này là ta sẽ lợi dụng điểm yếu ở trên để khôi phụ lại roundkey ở vòng thứ 4
- Giả sử ta muốn khôi phục byte đầu tiên, ta sẽ chọn bộ delta-set ở vị trí đầu tiên
- Tiếp theo ta brute-force 256 khả năng của byte thứ nhất(các bytes còn lại giá trị tùy ý) rồi xor lại ciphertext và decrypt round 4,
- Sau đó ta xor 256 ciphertext sau khi decrypt lại với nhau và kiểm tra xem giá trị nào của byte thứ nhất cho ra kết quả = 0x00
- Làm tương tự với các vị trí còn lại
```attack.py
from aes_3_rounds import AES, inv_s_box
from aeskeyschedule import *
from concurrent.futures import ProcessPoolExecutor
from functools import partial
key = b'\xaa' + b'\x00'*15
cipher = AES(key)
def matrix2bytes(matrix):
""" Converts a 4x4 matrix into a 16-byte array. """
return bytes(sum(matrix, []))
def get_delta_set(inactive_value: int) -> list:
delta_set = []
for k in range(256):
base_state = [inactive_value for i in range(16)]
base_state[0] = k
delta_set.append(base_state)
return bytes(delta_set)
# print(get_delta_set(7))
def gather_encrypted_delta_sets(inactive_value:int):
encrypted_ds = []
ds = get_delta_set(inactive_value)
for i in ds:
encrypted_ds.append(cipher.encrypt_block(i))
return encrypted_ds
def is_guess_correct(reversed_bytes) -> bool:
r = 0
for i in reversed_bytes:
r ^= i
return r == 0
def reverse_state(guess, position, encrypted_ds):
r = []
i, j = position
for s in encrypted_ds:
before_add_round_key = s[i][j] ^ guess
before_sub_byte = inv_s_box[before_add_round_key]
r.append(before_sub_byte)
return r
def guess_position(encrypted_data_sets, position)-> int:
position_in_state = (position % 4, position//4)
for encrypted_data_set in encrypted_data_sets:
correct_guess = []
for guess in range(256):
reversed_bytes = reverse_state(guess, position_in_state, encrypted_data_set)
if is_guess_correct(reversed_bytes):
correct_guess.append(guess)
if len(correct_guess)==1:
break
return correct_guess[0]
def crack_last_key():
last_bytes = [0]*16
positions = list(range(16))
encrypted_ds = gather_encrypted_delta_sets()
position_guesser = partial(guess_position, encrypted_ds)
with ProcessPoolExecutor() as executor:
for position, found_byte in zip(positions, executor.map(position_guesser, positions)):
last_bytes[position] = found_byte
return bytes(last_bytes)
key = reverse_key_schedule(bytes(round_key),4)
```
## 3. References
- https://www.davidwong.fr/blockbreakers/square_2_attack4rounds.html