# Short write-up m0lecon
## Guess me
<details>
<summary>chall.py</summary>
```python
#!/usr/bin/env python3
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from hashlib import sha256
from hmac import compare_digest
from random import shuffle
import os
# Flag printed once all 5 rounds have been passed (see the end of __main__).
flag = os.environ.get("FLAG", "ptm{REDACTED}")
BLOCK_SIZE = 16  # block size in bytes
NUM_BITS = BLOCK_SIZE * 8  # 128-bit state handled by _perm
# 4-bit S-box applied to both nibbles of every byte in _perm (a bijection on 0..15).
SBOX = (0xC, 0x5, 0x6, 0xB, 0x9, 0x0, 0xA, 0xD, 0x3, 0xE, 0xF, 0x8, 0x4, 0x7, 0x1, 0x2)
# Bit idx moves to position 7*idx mod 128; 7 is odd, so this is a permutation of 0..127.
BIT_PERM = tuple((idx * 7) % NUM_BITS for idx in range(NUM_BITS))
def _pad_pkcs7(data, block_size = BLOCK_SIZE):
    # NOTE: not real PKCS#7. The pad length is (block_size - len(data)) % block_size,
    # so input whose length is already a multiple of block_size (including b"")
    # receives no padding at all. This is the flaw exploited in the write-up.
    return data + bytes([(block_size - len(data)) % block_size]) * ((block_size - len(data)) % block_size)
def _xor_bytes(left, right):
    # Byte-wise XOR; zip() truncates the result to the shorter input.
    return bytes(a ^ b for a, b in zip(left, right))
def _unpad_pkcs7(data, block_size = BLOCK_SIZE):
    # Strip padding whose length is given by the last byte (block_size is unused).
    # Raises IndexError on empty data and AssertionError on inconsistent padding.
    # If the last byte is 0, data[-0:] is the whole input and data[:-0] is empty.
    d = data[-1]
    assert d <= len(data)
    assert all([x==d for x in data[-d:]])
    return data[:-d]
def _perm(data):
    # Keyless 10-round substitution/permutation network. Each round applies SBOX
    # to both nibbles of every byte, then moves bit idx (MSB-first order) to
    # position BIT_PERM[idx]. Both layers are bijective, so the whole map is
    # invertible (the solver's perm_inv does exactly that).
    # Expects a 16-byte block: the bit buffer has a fixed size of NUM_BITS.
    state = data
    for _ in range(10):
        # Substitution layer.
        sbox_out = bytearray(len(state))
        for idx, value in enumerate(state):
            sbox_out[idx] = (SBOX[value >> 4] << 4) | SBOX[value & 0x0F]
        # Unpack to bits, most-significant bit of each byte first.
        bits = []
        for value in sbox_out:
            for shift in range(8):
                bits.append((value >> (7 - shift)) & 0x01)
        # Bit-permutation layer.
        permuted_bits = [0] * NUM_BITS
        for idx, bit in enumerate(bits):
            permuted_bits[BIT_PERM[idx]] = bit
        # Repack bits into bytes (MSB first).
        state_out = bytearray(len(state))
        for idx in range(len(state)):
            byte = 0
            for shift in range(8):
                byte = (byte << 1) | permuted_bits[idx * 8 + shift]
            state_out[idx] = byte
        state = bytes(state_out)
    return state
def _prf(block_index, key, data):
    # Keyed function: let X || Y = AES-ECB_key(sha256(block_index as a 4-byte
    # signed big-endian int)), i.e. 32 bytes. The result is _perm(data XOR X) XOR Y.
    # Whoever knows the key can invert it, because _perm is a bijection.
    cipher = AES.new(key, AES.MODE_ECB)
    mask = cipher.encrypt(sha256(block_index.to_bytes(4, 'big', signed=True)).digest())
    result = _xor_bytes(data, mask[:BLOCK_SIZE])
    result = _perm(result)
    result = _xor_bytes(result, mask[-BLOCK_SIZE:])
    return result
def enc_msg(key, nonce, message):
    # Counter-style encryption: block idx of the padded message is XORed with
    # _prf(idx, key, nonce). Because of _pad_pkcs7, an empty message produces
    # an empty ciphertext.
    padded = _pad_pkcs7(message)
    blocks = [padded[i : i + BLOCK_SIZE] for i in range(0, len(padded), BLOCK_SIZE)]
    ciphertext_blocks = []
    for idx, block in enumerate(blocks):
        keystream = _prf(idx, key, nonce)
        ciphertext_blocks.append(_xor_bytes(block, keystream))
    return b"".join(ciphertext_blocks)
def enc_tag(key, nonce, additional_data, ciphertext):
    # Tag = _prf(-1, key, S), where S is computed as follows:
    #   - start from the first block of pad(nonce || ad), which is not passed
    #     through the PRF; decrypt() enforces a 16-byte nonce, so it is the nonce itself;
    #   - XOR in _prf(idx + 1337, key, block) for each further AD block (idx >= 1);
    #   - XOR in _prf(idx + 31337, key, block) for each ciphertext block (idx >= 0).
    ad_padded = _pad_pkcs7(nonce + additional_data)
    ct_padded = _pad_pkcs7(ciphertext)
    ad_blocks = [ad_padded[i : i + BLOCK_SIZE] for i in range(0, len(ad_padded), BLOCK_SIZE)]
    # An empty ciphertext yields no blocks at all (no padding block is added).
    ct_blocks = [ct_padded[i : i + BLOCK_SIZE] for i in range(0, len(ct_padded), BLOCK_SIZE)]
    tag = ad_blocks[0]
    for idx, block in enumerate(ad_blocks[1:], start=1):
        keystream = _prf(idx + 1337, key, block)
        tag = _xor_bytes(tag, keystream)
    for idx, block in enumerate(ct_blocks):
        keystream = _prf(idx + 31337, key, block)
        tag = _xor_bytes(tag, keystream)
    return _prf(-1, key, tag)
def encrypt(key, nonce, message, additional_data):
    # Encrypt-then-MAC: returns (ciphertext, tag). Not called by the server loop below.
    ciphertext = enc_msg(key, nonce, message)
    tag = enc_tag(key, nonce, additional_data, ciphertext)
    return ciphertext, tag
def decrypt(key, nonce, ciphertext, additional_data, tag):
    # Verify the tag, then decrypt and unpad. Return values:
    #   False               - tag mismatch;
    #   b"Invalid padding"  - tag OK but unpadding failed (non-empty, hence truthy!);
    #   plaintext bytes     - otherwise.
    # The length checks use assert; the main loop does not catch AssertionError.
    assert len(key) == BLOCK_SIZE
    assert len(nonce) == BLOCK_SIZE
    assert len(ciphertext) % BLOCK_SIZE == 0  # an empty ciphertext passes this check
    assert len(tag) == BLOCK_SIZE
    expected_tag = enc_tag(key, nonce, additional_data, ciphertext)
    if not compare_digest(expected_tag, tag):
        return False
    blocks = [ciphertext[i : i + BLOCK_SIZE] for i in range(0, len(ciphertext), BLOCK_SIZE)]
    plaintext_blocks = []
    for idx, block in enumerate(blocks):
        keystream = _prf(idx, key, nonce)
        plaintext_blocks.append(_xor_bytes(block, keystream))
    plaintext_padded = b"".join(plaintext_blocks)
    try:
        plaintext = _unpad_pkcs7(plaintext_padded)
    except:
        # Bare except: also catches the IndexError from unpadding an empty plaintext.
        return b"Invalid padding"
    return plaintext
if __name__ == "__main__":
    # 5 rounds. Each key is the first 16 bytes of sha256 of a random permutation
    # of "m0leCon" (random.shuffle, not a cryptographic RNG).
    for r in range(5):
        base = list("m0leCon")
        shuffle(base)
        key = bytes(sha256("".join(base).encode()).digest())[:BLOCK_SIZE]
        # Up to 16 attempts per round. for/else: running out of attempts
        # without `break` ends the game.
        for _ in range(16):
            nonces = bytes.fromhex(input("Enter nonce (hex): ").strip())
            # Several nonces may be concatenated; each is tried with the same ad/ct/tag.
            nonces = [nonces[i:i+BLOCK_SIZE] for i in range(0, len(nonces), BLOCK_SIZE)]
            additional_data = bytes.fromhex(input("Enter additional_data (hex): ").strip())
            ciphertext = bytes.fromhex(input("Enter ciphertext (hex): ").strip())
            tag = bytes.fromhex(input("Enter tag (hex): ").strip())
            decs = [decrypt(key, nonce, ciphertext, additional_data, tag) for nonce in nonces]
            # any(): one nonce with a valid tag suffices, and b"Invalid padding"
            # also counts as truthy, so this acts as a key-membership oracle.
            auth = any(decs)
            if auth:
                if additional_data != b"pretty please":
                    print("Can you at least say 'please' next time?")
                    exit()
                else:
                    # Every supplied nonce must decrypt to the expected message.
                    if all([dec == b"next round please" for dec in decs]):
                        print("There you go!")
                        break
                    else:
                        print("This message does not seem ok :(")
            else:
                print("Tag is invalid")
        else:
            print("Better luck next time!")
            exit()
    print(flag)
```
</details>
### Tóm tắt đề
- Server triển khai một custom scheme AEAD gồm hai phần: mã hóa thông điệp `enc_msg` và tạo thẻ xác thực `enc_tag`.
- Mỗi round (5 round) server chọn khóa 128-bit bằng cách lấy 16 byte đầu của `sha256` một hoán vị ngẫu nhiên của chuỗi "m0leCon". Vì 7 ký tự của "m0leCon" đều khác nhau nên có đúng `7! = 5040` khóa ứng viên.
- Mỗi round bạn có tối đa 16 lần nhập: `nonce (hex)`, `additional_data (hex)`, `ciphertext (hex)`, `tag (hex)`.
### Bug
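- Padding PKCS#7 sai: `_pad_pkcs7` không thêm byte padding nào khi độ dài dữ liệu chia hết cho block size (kể cả dữ liệu rỗng). Vì vậy `ciphertext` rỗng không tạo ra block nào trong `enc_tag`. Khi giải mã, `_unpad_pkcs7(b"")` ném `IndexError` (do `data[-1]`), và lỗi này bị `except` trần trong `decrypt` bắt lại.
- Xử lý xác thực dùng `any(decs)`: khi `tag` đúng nhưng unpad thất bại, `decrypt` trả về `b"Invalid padding"`. Đây là bytes khác rỗng nên truthy, khiến `auth` thành `True`. Do đó với `ciphertext` rỗng, chỉ cần tag đúng là server in "This message does not seem ok :(" thay vì "Tag is invalid".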
- Chấp nhận nhiều nonce trong một lần nhập: Ghép nhiều nonce 16 byte vào cùng một chuỗi hex để hỏi membership (thuộc/không thuộc) của khóa.
### Phân tích PRF/MAC
- **PRF**: với `X_i || Y_i = AES_k(sha256(i))` (32 byte), và `P` là hoán vị 10 vòng (S-box + hoán vị bit):
- `PRF(i, k, m) = P(m XOR X_i) XOR Y_i`.
- Đảo PRF khi biết khóa: `PRF^{-1}(i, k, v) = P^{-1}(v XOR Y_i) XOR X_i`.
- Với `ad = b"pretty please"` (13 byte) → sau pad: `A1 = ad || 0x03 0x03 0x03`. Khi `ciphertext = b""`, ta có:
- `ad_blocks = [nonce, A1]`, `ct_blocks = []`.
- `tag_base = nonce XOR PRF(1338, k, A1)` và `expected_tag = PRF(-1, k, tag_base)`.
Hệ quả: với một `tag0` bất kỳ (ví dụ `0^{16}`), với MỖI khóa ứng viên `k` tồn tại đúng một `nonce_k` sao cho `enc_tag(k, nonce_k, ad, b"") = tag0`. Vì `PRF(-1, k, ·)` là song ánh, ta có `nonce_k = PRF^{-1}(-1, k, tag0) XOR PRF(1338, k, A1)`.
### Ý tưởng
1) Tạo 5040 khóa từ mọi hoán vị của "m0leCon" rồi băm `sha256` lấy 16 byte đầu.
2) Tính “nonce membership” cho một `tag0` cố định (ví dụ `tag0 = 0^{16}`) khi `ciphertext = b""` và `ad = b"pretty please"`:
- Với mỗi khóa ứng viên `k`, tính:
- `B_k = prf(1338, k, A1)` với `A1 = ad || 0x03*3`.
- `inner = prf^{-1}(-1, k, tag0)` bằng công thức đảo PRF ở trên.
- Chọn `nonce_k = inner XOR B_k` để thỏa `enc_tag(k, nonce_k, ad, b"") == tag0`.
3) Binary search:
- Ở mỗi truy vấn, ghép nối các `nonce_k` của một nửa tập ứng viên thành một chuỗi hex (nhiều block) và gửi kèm `ad = pretty please`, `ciphertext = ""`, `tag = tag0`.
- Nếu server in “This message does not seem ok :(” → nonce ứng với khóa thật nằm trong tập con (tag đúng → `decrypt` rơi vào nhánh `b"Invalid padding"`, làm `auth` = true).
- Nếu in “Tag is invalid” → khóa thật nằm ở nửa còn lại.
- Cần ⌈log2(5040)⌉ = 13 truy vấn, cộng 1 lần gửi bản mã cuối cùng là 14, vẫn ≤ 16 lượt nhập cho phép mỗi round.
<details>
<summary>solve.py</summary>
```python
import sys
from itertools import permutations
from hashlib import sha256
from Crypto.Cipher import AES
from pwn import remote, process
BLOCK_SIZE = 16
NUM_BITS = BLOCK_SIZE * 8

# Copy of the challenge's 4-bit S-box.
SBOX = (0xC, 0x5, 0x6, 0xB, 0x9, 0x0, 0xA, 0xD,
        0x3, 0xE, 0xF, 0x8, 0x4, 0x7, 0x1, 0x2)

# Inverse S-box: INV_SBOX[SBOX[x]] == x for every nibble x.
INV_SBOX = [SBOX.index(nibble) for nibble in range(len(SBOX))]

# Forward bit permutation: bit pos moves to 7 * pos mod NUM_BITS.
BIT_PERM = tuple(7 * pos % NUM_BITS for pos in range(NUM_BITS))
def bytes_to_bits(b):
    """Expand *b* into a flat list of 0/1 ints, most-significant bit of each byte first."""
    return [(byte >> shift) & 1 for byte in b for shift in range(7, -1, -1)]
def bits_to_bytes(bits):
    """Pack MSB-first bits back into bytes; a trailing partial byte is dropped."""
    whole_bytes = len(bits) // 8
    packed = bytearray()
    for start in range(0, whole_bytes * 8, 8):
        acc = 0
        for bit in bits[start:start + 8]:
            acc = (acc << 1) | bit
        packed.append(acc)
    return bytes(packed)
def xor_bytes(a, b):
    """XOR two byte strings element-wise, truncating to the shorter one."""
    return bytes(map(int.__xor__, a, b))
def perm_fwd(s):
    """Re-implementation of the challenge's 10-round permutation _perm.

    Each round substitutes both nibbles of every byte through SBOX, then sends
    bit number src (MSB-first) to position BIT_PERM[src].
    """
    state = s
    for _round in range(10):
        substituted = bytes((SBOX[byte >> 4] << 4) | SBOX[byte & 0x0F] for byte in state)
        moved = [0] * NUM_BITS
        for src, bit in enumerate(bytes_to_bits(substituted)):
            moved[BIT_PERM[src]] = bit
        state = bits_to_bytes(moved)
    return state
def perm_inv(s):
    """Inverse of perm_fwd: per round, undo the bit permutation, then the S-box."""
    state = s
    for _round in range(10):
        scattered = bytes_to_bits(state)
        # Bit src was moved to BIT_PERM[src]; read it back from there.
        gathered = [scattered[BIT_PERM[src]] for src in range(NUM_BITS)]
        unpermuted = bits_to_bytes(gathered)
        state = bytes(
            (INV_SBOX[byte >> 4] << 4) | INV_SBOX[byte & 0x0F] for byte in unpermuted
        )
    return state
def prf(idx, key, data_block):
    """Challenge PRF: perm_fwd(data XOR X) XOR Y, where X || Y = AES_key(sha256(idx))."""
    cipher = AES.new(key, AES.MODE_ECB)
    whitening = cipher.encrypt(sha256(idx.to_bytes(4, "big", signed=True)).digest())
    inner = perm_fwd(xor_bytes(data_block, whitening[:BLOCK_SIZE]))
    return xor_bytes(inner, whitening[-BLOCK_SIZE:])
def pad_bad(data, bs=BLOCK_SIZE):
    """Replicate the challenge's faulty padding: nothing is added when len(data) % bs == 0."""
    fill = -len(data) % bs
    return data + bytes([fill] * fill)
def enc_msg(key, nonce, msg):
    """Encrypt *msg* like the challenge: block i is XORed with prf(i, key, nonce)."""
    padded = pad_bad(msg)
    return b"".join(
        xor_bytes(padded[off:off + BLOCK_SIZE], prf(off // BLOCK_SIZE, key, nonce))
        for off in range(0, len(padded), BLOCK_SIZE)
    )
def enc_tag(key, nonce, ad, ct):
    """Recompute the challenge tag over nonce || ad and the ciphertext *ct*."""
    def split(data):
        padded = pad_bad(data)
        return [padded[off:off + BLOCK_SIZE] for off in range(0, len(padded), BLOCK_SIZE)]

    header_blocks = split(nonce + ad)
    body_blocks = split(ct)
    acc = header_blocks[0]  # the first (nonce) block is not passed through the PRF
    for counter, block in enumerate(header_blocks[1:], start=1338):
        acc = xor_bytes(acc, prf(counter, key, block))
    for counter, block in enumerate(body_blocks, start=31337):
        acc = xor_bytes(acc, prf(counter, key, block))
    return prf(-1, key, acc)
def build_keys():
    """Return all 7! candidate keys: sha256(permutation of "m0leCon")[:BLOCK_SIZE]."""
    return [
        sha256("".join(order).encode()).digest()[:BLOCK_SIZE]
        for order in permutations("m0leCon")
    ]
def precompute_membership(keys, ad, tag0):
    """For every candidate key, compute the nonce that makes (ad, b"", tag0) verify.

    With an empty ciphertext the challenge's enc_tag reduces to
        tag = PRF(-1, k, nonce XOR B_k),  B_k = XOR_i PRF(1338 + i, k, AD_i)
    where AD_0, AD_1, ... are the blocks of pad_bad(ad). Inverting the outer
    PRF gives
        nonce_k = perm_inv(tag0 XOR Y_{-1}) XOR X_{-1} XOR B_k.

    Args:
        keys: candidate 16-byte AES keys.
        ad: additional data that will be sent to the server (any length).
        tag0: the fixed 16-byte tag that will be submitted.

    Returns:
        List of 16-byte nonces; nonces[i] belongs to keys[i].
    """
    # The seed for PRF index -1 does not depend on the key: hash it once.
    seed_m1 = sha256((-1).to_bytes(4, "big", signed=True)).digest()
    # The nonce fills exactly one block, so the AD blocks of pad_bad(nonce + ad)
    # are exactly the blocks of pad_bad(ad). The previous version hard-coded
    # ad || 03 03 03, which is only correct for a 13-byte ad.
    ad_padded = pad_bad(ad)
    ad_blocks = [ad_padded[i:i + BLOCK_SIZE] for i in range(0, len(ad_padded), BLOCK_SIZE)]
    nonces = []
    for k in keys:
        mask_m1 = AES.new(k, AES.MODE_ECB).encrypt(seed_m1)
        x_m1 = mask_m1[:BLOCK_SIZE]
        y_m1 = mask_m1[-BLOCK_SIZE:]
        # B_k: combined PRF contribution of the AD blocks (zero for an empty ad).
        b_k = bytes(BLOCK_SIZE)
        for idx, block in enumerate(ad_blocks, start=1338):
            b_k = xor_bytes(b_k, prf(idx, k, block))
        inv = perm_inv(xor_bytes(tag0, y_m1))
        nonces.append(xor_bytes(xor_bytes(b_k, x_m1), inv))
    return nonces
def precompute_success_triplet(keys, nonce_final, ad, msg):
    """For each candidate key, forge the (nonce, ciphertext, tag) that decrypts to *msg*."""
    def forge(key):
        ciphertext = enc_msg(key, nonce_final, msg)
        return nonce_final, ciphertext, enc_tag(key, nonce_final, ad, ciphertext)

    return [forge(key) for key in keys]
def interact_round(p, nonce_candidates, triples, ad, tag0, max_queries=15):
    """Find the round key by binary search, then submit the winning triple.

    Each membership query sends the nonces of the first half of the remaining
    candidates, together with *ad*, an empty ciphertext and *tag0*. The server
    answers "This message does not seem ok :(" iff the real key's nonce is in
    that half: the tag verifies, the empty plaintext fails to unpad, and the
    truthy b"Invalid padding" makes any() succeed. Otherwise it answers
    "Tag is invalid".

    Args:
        p: pwntools tube (anything with recvuntil/sendline/recvline).
        nonce_candidates: per-key nonces from precompute_membership.
        triples: per-key (nonce, ciphertext, tag) from precompute_success_triplet.
        ad: additional data (the server only answers for b"pretty please").
        tag0: the tag that nonce_candidates were computed for.
        max_queries: cap on membership queries. The server allows 16 inputs
            per round and the final submission needs one of them.

    Raises:
        ValueError: if there are no candidates.
        RuntimeError: on an unrecognised server reply (timeout, disconnect, ...)
            or if the final submission is rejected.
    """
    if not nonce_candidates:
        raise ValueError("no candidate keys to search")

    def ask(nonce_blob, ad_hex, ct_hex, tag_hex):
        # One server attempt: answer the four prompts and return the verdict line.
        p.recvuntil(b"Enter nonce (hex): ")
        p.sendline(nonce_blob.hex().encode())
        p.recvuntil(b"Enter additional_data (hex): ")
        p.sendline(ad_hex.encode())
        p.recvuntil(b"Enter ciphertext (hex): ")
        p.sendline(ct_hex.encode())
        p.recvuntil(b"Enter tag (hex): ")
        p.sendline(tag_hex.encode())
        return p.recvline(timeout=10).decode(errors="ignore")

    ad_hex = ad.hex()
    tag0_hex = tag0.hex()
    candidates = list(range(len(nonce_candidates)))
    queries = 0
    while len(candidates) > 1 and queries < max_queries:
        queries += 1
        mid = len(candidates) // 2
        subset = candidates[:mid]
        nonce_blob = b"".join(nonce_candidates[i] for i in subset)
        resp = ask(nonce_blob, ad_hex, "", tag0_hex)
        if "This message does not seem ok" in resp:
            candidates = subset
        elif "Tag is invalid" in resp:
            candidates = candidates[mid:]
        else:
            # Do not silently treat an empty/unknown reply as "not in subset";
            # that would send the search down the wrong half.
            raise RuntimeError("Unexpected membership response: " + repr(resp))

    n_final, ct_final, tg_final = triples[candidates[0]]
    resp = ask(n_final, ad_hex, ct_final.hex(), tg_final.hex())
    if "There you go!" not in resp:
        raise RuntimeError("Final attempt failed: " + resp)
if __name__ == "__main__":
    # Guarded so that importing this file does not spawn the challenge process.
    ad = b"pretty please"
    msg = b"next round please"
    tag0 = b"\x00" * 16
    nonce_final = b"N" * 16
    keys = build_keys()
    nonce_candidates = precompute_membership(keys, ad, tag0)
    triples = precompute_success_triplet(keys, nonce_final, ad, msg)
    # Usage: python3 solve.py [HOST PORT]  (no arguments: run the challenge locally)
    if len(sys.argv) == 3:
        p = remote(sys.argv[1], int(sys.argv[2]), timeout=10)
    else:
        p = process(["python3", "guess_me.py"])
    for _ in range(5):
        interact_round(p, nonce_candidates, triples, ad, tag0)
    p.interactive()
```
</details>
## One more bit
<details>
<summary>fhe_server.py</summary>
```python!
from __future__ import annotations
import json
import sys
from typing import Any, Callable, Dict, List
import os
from ind_cpa_d_bitwise_game import (
AddCircuit,
BitwiseCKKSIndCpaDGame,
HomomorphicCKKSFunction,
MultiplyCircuit,
SquareCircuit,
)
# Sent to the client only after all rounds are won (see main()).
flag = os.environ.get("FLAG", "ptm{REDACTED}")
# Maps the request's "function" field (lower-cased in _resolve_circuit) to a
# circuit class; "sum" and "multiply" are aliases of "add" and "mul".
CIRCUIT_FACTORIES: Dict[str, Callable[[], HomomorphicCKKSFunction]] = {
    "square": SquareCircuit,
    "add": AddCircuit,
    "sum": AddCircuit,
    "mul": MultiplyCircuit,
    "multiply": MultiplyCircuit,
}
class INDCPADServer:
    """Simple line-oriented server exposing the IND-CPA-D oracles to a client.

    Each instance wraps one BitwiseCKKSIndCpaDGame and accepts a single guess.
    """
    def __init__(
        self,
        *,
        mult_depth: int = 2,
        scale_mod_size: int = 50,
        bit_length: int = 64,
    ) -> None:
        # batch_size=1: a single plaintext slot per ciphertext.
        self.game = BitwiseCKKSIndCpaDGame(
            mult_depth=mult_depth,
            scale_mod_size=scale_mod_size,
            batch_size=1,
            bit_length=bit_length,
        )
        # Set by the first guess(); later guesses are rejected.
        self._resolved = False
    # ------------------------------------------------------------------ #
    # Public API
    def encrypt(self, m0: float, m1: float) -> Dict[str, Any]:
        """Store Enc(m_b) for the pair (m0, m1) and return its state index.

        The ciphertext itself is never sent to the client.
        """
        ciphertext = self.game.enc_oracle(m0, m1)  # unused; enc_oracle appends the row to game.state
        index = len(self.game.state) - 1
        return {"status": "ok", "state_index": index}
    def eval(self, function: str, indices: List[int]) -> Dict[str, Any]:
        """Apply the named circuit to the rows at *indices*; return the new row's index.

        Raises ValueError for an unknown circuit or an arity mismatch, and
        IndexError for an out-of-range index.
        """
        circuit = self._resolve_circuit(function)
        self.game.eval_oracle(tuple(indices), circuit)
        index = len(self.game.state) - 1
        return {"status": "ok", "state_index": index}
    def decrypt(self, index: int, position: int) -> Dict[str, Any]:
        """Return one bit of Dec(state[index]), or "denied" if m0 and m1 differ at *position*."""
        bit = self.game.dec_oracle(index, position)
        if bit is None:
            return {"status": "denied", "reason": "bit-mismatch"}
        return {"status": "ok", "bit": int(bit)}
    def guess(self, bit: int) -> Dict[str, Any]:
        """Compare the low bit of *bit* with the challenge bit; only the first call counts."""
        if self._resolved:
            return {"status": "error", "error": "game already resolved"}
        self._resolved = True
        if (bit & 1) == self.game.challenge_bit:
            return {"status": "ok", "result": "WIN"}
        return {"status": "ok", "result": "LOSE"}
    def handle_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Dispatch one decoded JSON request to the matching oracle.

        Raises ValueError for a missing or unknown command. Missing or malformed
        fields raise errors (e.g. KeyError, ValueError) that propagate to the caller.
        """
        command = request.get("command")
        if command is None:
            raise ValueError("missing command field")
        command = str(command).lower()
        if command == "encrypt":
            return self.encrypt(float(request["m0"]), float(request["m1"]))
        if command == "eval":
            function = str(request["function"])
            indices = [int(idx) for idx in request["indices"]]
            return self.eval(function, indices)
        if command == "decrypt":
            return self.decrypt(int(request["index"]), int(request["position"]))
        if command == "guess":
            return self.guess(int(request["bit"]))
        raise ValueError(f"unknown command '{command}'")
    # ------------------------------------------------------------------ #
    # Helpers
    @staticmethod
    def _resolve_circuit(descriptor: str) -> HomomorphicCKKSFunction:
        """Look up a circuit class case-insensitively and return a fresh instance."""
        keyword = descriptor.lower()
        builder = CIRCUIT_FACTORIES.get(keyword)
        if builder is None:
            raise ValueError(f"unsupported circuit '{descriptor}'")
        return builder()
def main() -> None:
    # Play `rounds` rounds, each with a fresh server (fresh keys and challenge bit).
    # A WIN moves on to the next round; a LOSE or EOF ends the process without the flag.
    rounds = 100
    for current_round in range(1, rounds + 1):
        server = INDCPADServer()
        print(json.dumps({"status": "new_round", "round": current_round}))
        sys.stdout.flush()
        # One JSON request per input line, one JSON response per request.
        while True:
            line = sys.stdin.readline()
            if not line:
                return  # EOF
            line = line.strip()
            if not line:
                continue
            try:
                request = json.loads(line)
            except json.JSONDecodeError:
                print(json.dumps({"status": "error", "error": "invalid json"}))
                sys.stdout.flush()
                continue
            try:
                response = server.handle_request(request)
            except Exception as exc:
                # Any oracle failure is reported to the client instead of crashing.
                response = {"status": "error", "error": str(exc)}
            print(json.dumps(response))
            sys.stdout.flush()
            result = response.get("result")
            if result == "WIN":
                break
            if result == "LOSE":
                return
    print(json.dumps({"status": "ok", "flag": flag}))
    sys.stdout.flush()


if __name__ == "__main__":
    main()
```
</details>
<details>
<summary>ind_cpa_d_bitwise_game.py</summary>
```python
from __future__ import annotations
import secrets
from dataclasses import dataclass
from typing import List, Protocol, Tuple
from openfhe import (
CCParamsCKKSRNS,
Ciphertext,
CryptoContext,
GenCryptoContext,
PKESchemeFeature,
Plaintext,
)
# When True, prints the challenge bit and per-row bit strings (reveals the secret).
DEBUG_GAME = False
def float_to_bits(value: float, bit_length: int, scale_bits: int) -> Tuple[int, ...]:
    """
    Convert a CKKS plaintext slot to a fixed-length bit tuple.
    The value is multiplied by 2^scale_bits and rounded to the nearest integer.
    The absolute value of that integer is taken, so the sign is discarded
    (v and -v give the same bits), and only its low bit_length bits are kept.
    Bits are returned least-significant first: element i is bit i.
    """
    scaled = int(round(value * (1 << scale_bits)))
    scaled = abs(scaled)
    mask = (1 << bit_length) - 1
    # Despite the name, this is the low bit_length bits of |scaled| (abs was applied above).
    twos_complement = scaled & mask
    return tuple((twos_complement >> i) & 1 for i in range(bit_length))
class HomomorphicCKKSFunction(Protocol):
    """Interface for describing circuits usable by the Eval oracle."""
    num_inputs: int  # arity; eval_oracle checks it against len(indices)
    def plaintext(self, values: Tuple[float, ...]) -> float:
        """Evaluate the circuit on cleartext data."""
    def ciphertext(
        self, cc: CryptoContext, ciphertexts: Tuple[Ciphertext, ...]
    ) -> Ciphertext:
        """Evaluate the same circuit homomorphically on ciphertext inputs."""
@dataclass
class OracleRow:
    """Single row stored by the challenger."""
    m0: float  # left-branch plaintext, tracked in the clear
    m1: float  # right-branch plaintext, tracked in the clear
    ciphertext: Ciphertext  # encryption (or circuit result) of the branch chosen by challenge_bit
class BitwiseCKKSIndCpaDGame:
    """
    CKKS IND-CPA-D challenger with bit-guarded decryption oracle.
    The challenger keeps a list of rows (m0, m1, Enc(m_b)). The bit strings of
    m0 and m1 are recomputed with float_to_bits on every decryption request.
    The decryption oracle reveals a single bit of Dec(Enc(m_b)), and only when
    that bit index is equal in the bit strings of m0 and m1.
    """
    def __init__(
        self,
        mult_depth: int = 2,
        scale_mod_size: int = 50,
        batch_size: int = 1,
        *,
        challenge_bit: int | None = None,
        bit_length: int = 64,
    ) -> None:
        params = CCParamsCKKSRNS()
        params.SetMultiplicativeDepth(mult_depth)
        params.SetScalingModSize(scale_mod_size)
        params.SetBatchSize(batch_size)
        self.cc: CryptoContext = GenCryptoContext(params)
        self.cc.Enable(PKESchemeFeature.PKE)
        self.cc.Enable(PKESchemeFeature.KEYSWITCH)
        self.cc.Enable(PKESchemeFeature.LEVELEDSHE)
        self.keys = self.cc.KeyGen()
        # Evaluation keys for EvalMult (used by the square / mul circuits).
        self.cc.EvalMultKeyGen(self.keys.secretKey)
        # Secret bit b: random unless fixed by the caller.
        self.challenge_bit = (
            secrets.randbits(1) if challenge_bit is None else (challenge_bit & 1)
        )
        if DEBUG_GAME:
            print(f"[DEBUG] challenge_bit = {self.challenge_bit}")
        self.state: List[OracleRow] = []
        # The bit view of a value uses 2^scale_mod_size as its scaling factor.
        self.scale_bits = scale_mod_size
        self.bit_length = bit_length
    # ------------------------------------------------------------------ #
    # Helper utilities
    def _encode(self, value: float) -> Plaintext:
        # Pack the single value into a CKKS plaintext.
        return self.cc.MakeCKKSPackedPlaintext([value])
    def _to_bits(self, value: float) -> Tuple[int, ...]:
        # Bit view used by the decryption guard.
        return float_to_bits(value, self.bit_length, self.scale_bits)
    # ------------------------------------------------------------------ #
    # Oracles
    def enc_oracle(self, m0: float, m1: float) -> Ciphertext:
        """Encrypt m_b (b = challenge_bit), append the row (m0, m1, ct) and return ct."""
        pt0 = self._encode(m0)
        pt1 = self._encode(m1)
        chosen_pt = pt0 if self.challenge_bit == 0 else pt1
        ciphertext = self.cc.Encrypt(self.keys.publicKey, chosen_pt)
        row = OracleRow(m0=m0, m1=m1, ciphertext=ciphertext)
        self.state.append(row)
        if DEBUG_GAME:
            self._debug_log_bits(len(self.state) - 1, m0, m1)
        return ciphertext
    def eval_oracle(
        self, indices: Tuple[int, ...], circuit: HomomorphicCKKSFunction
    ) -> Ciphertext:
        """Apply *circuit* to the rows at *indices* and append the result row.

        The cleartext function is applied to the m0 inputs and to the m1 inputs
        separately; the homomorphic version is applied to the ciphertexts.
        Raises ValueError on an arity mismatch and IndexError on a bad index.
        """
        if len(indices) != circuit.num_inputs:
            raise ValueError("indices count does not match circuit arity")
        rows: List[OracleRow] = []
        for index in indices:
            if index < 0 or index >= len(self.state):
                raise IndexError("state index out of range")
            rows.append(self.state[index])
        m0_inputs = tuple(row.m0 for row in rows)
        m1_inputs = tuple(row.m1 for row in rows)
        ct_inputs = tuple(row.ciphertext for row in rows)
        new_m0 = circuit.plaintext(m0_inputs)
        new_m1 = circuit.plaintext(m1_inputs)
        new_ct = circuit.ciphertext(self.cc, ct_inputs)
        new_row = OracleRow(m0=new_m0, m1=new_m1, ciphertext=new_ct)
        self.state.append(new_row)
        if DEBUG_GAME:
            self._debug_log_bits(len(self.state) - 1, new_m0, new_m1)
        return new_ct
    def dec_oracle(self, index: int, bit_index: int) -> int | None:
        """
        Decryption oracle: reveals only the selected bit of Dec(Enc(m_b)).
        Returns None (denied) when the tracked m0 and m1 differ at bit_index.
        The guard compares only the tracked cleartexts, not the actual
        decryption, so approximation noise in the decrypted value is not guarded.
        Raises IndexError for an out-of-range index or bit_index.
        """
        if index < 0 or index >= len(self.state):
            raise IndexError("state index out of range")
        if bit_index < 0 or bit_index >= self.bit_length:
            raise IndexError("bit index out of range")
        row = self.state[index]
        bits0 = self._to_bits(row.m0)
        bits1 = self._to_bits(row.m1)
        if bits0[bit_index] != bits1[bit_index]:
            return None
        plaintext = self.cc.Decrypt(row.ciphertext, self.keys.secretKey)
        plaintext.SetLength(1)
        value = plaintext.GetRealPackedValue()[0]
        bits = float_to_bits(value, self.bit_length, self.scale_bits)
        # bits_str is only consumed by the commented-out debug print below.
        bits_str = "".join(str(bit) for bit in bits)
        # print(f"[DEBUG] dec_oracle[{index}] value={value} bits={bits_str}")
        return bits[bit_index]
    # ------------------------------------------------------------------ #
    # Debug helpers
    def _debug_log_bits(self, index: int, m0: float, m1: float) -> None:
        # Print both branches' bit strings (only called when DEBUG_GAME is set).
        bits0 = "".join(str(bit) for bit in self._to_bits(m0))
        bits1 = "".join(str(bit) for bit in self._to_bits(m1))
        print(f"[DEBUG] state[{index}] m0={m0} bits={bits0}")
        print(f"[DEBUG] state[{index}] m1={m1} bits={bits1}")
class SquareCircuit(HomomorphicCKKSFunction):
    """x -> x * x (one input; one EvalMult on the ciphertext side)."""
    num_inputs = 1
    def plaintext(self, values: Tuple[float, ...]) -> float:
        value = values[0]
        return value * value
    def ciphertext(
        self, cc: CryptoContext, ciphertexts: Tuple[Ciphertext, ...]
    ) -> Ciphertext:
        ciphertext = ciphertexts[0]
        return cc.EvalMult(ciphertext, ciphertext)
class AddCircuit(HomomorphicCKKSFunction):
    """(x, y) -> x + y (EvalAdd on the ciphertext side)."""
    num_inputs = 2
    def plaintext(self, values: Tuple[float, ...]) -> float:
        return values[0] + values[1]
    def ciphertext(
        self, cc: CryptoContext, ciphertexts: Tuple[Ciphertext, ...]
    ) -> Ciphertext:
        return cc.EvalAdd(ciphertexts[0], ciphertexts[1])
class MultiplyCircuit(HomomorphicCKKSFunction):
    """(x, y) -> x * y (EvalMult on the ciphertext side)."""
    num_inputs = 2
    def plaintext(self, values: Tuple[float, ...]) -> float:
        return values[0] * values[1]
    def ciphertext(
        self, cc: CryptoContext, ciphertexts: Tuple[Ciphertext, ...]
    ) -> Ciphertext:
        return cc.EvalMult(ciphertexts[0], ciphertexts[1])
def _format_plaintext(plaintext: Plaintext, precision: int = 6) -> str:
    # Pretty-print helper; not called anywhere else in this file.
    return plaintext.GetFormattedValues(precision)
if __name__ == "__main__":
    # Small smoke test of the game with a fresh key pair and random challenge bit.
    game = BitwiseCKKSIndCpaDGame(mult_depth=2, scale_mod_size=50, batch_size=1)
    ct0 = game.enc_oracle(0.125, 1.75)
    print(f"state[0] ciphertext: {ct0}")
    ct1 = game.enc_oracle(-0.5, -0.5)
    print(f"state[1] ciphertext: {ct1}")
    square = SquareCircuit()
    ct_square = game.eval_oracle((0,), square)  # appended as state[2]
    add = AddCircuit()
    ct_add = game.eval_oracle((0, 1), add)  # appended as state[3]
    bit_idx = 10
    # state[1] has m0 == m1, so the bit-guard never denies this request.
    bit_value = game.dec_oracle(1, bit_idx)
    if bit_value is None:
        print(f"bit {bit_idx} differs, decryption denied.")
    else:
        print(f"bit {bit_idx} matches, decrypted bit: {bit_value}")
```
</details>
## **Tóm tắt:**
Server cài đặt trò chơi IND-CPA-D trên CKKS (FHE xấp xỉ):
- Mỗi round sinh ra một bit bí mật `b ∈ {0,1}`.
- Oracle `encrypt(m0, m1)` lưu vào state một cặp `(m0, m1)` và trả về chỉ số của `Enc(m_b)`.
- Oracle `eval(function, indices)` áp dụng mạch đồng hình (square/add/mul) lên các ciphertext trong `indices`, đồng thời cập nhật cặp plaintext tương ứng bằng cách áp dụng cùng hàm trên `(m0, m1)` theo hai nhánh. Kết quả cũng được lưu vào state và trả về chỉ số mới.
- Oracle `decrypt(index, position)` chỉ trả về bit tại vị trí `position` của bản giải mã nếu bit ở vị trí đó của `m0` và `m1` (cặp plaintext tại `state[index]`) là giống nhau; nếu khác sẽ trả về `denied`.
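- Cuối mỗi round ta phải đoán `b` (mỗi round chỉ được đoán một lần). Phải đoán đúng liên tiếp cả 100 round mới nhận được flag; chỉ cần sai một lần là server kết thúc.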
### Bit-guard hoạt động thế nào?
Trong `ind_cpa_d_bitwise_game.py`, giá trị thực được đổi sang bit bởi `float_to_bits`: nhân với `2^scale_bits`, làm tròn, lấy trị tuyệt đối (nên dấu bị mất, dù docstring gọi là two's complement) rồi giữ lại 64 bit thấp:
```py
# Excerpt of float_to_bits from ind_cpa_d_bitwise_game.py (docstring omitted).
def float_to_bits(value: float, bit_length: int, scale_bits: int) -> Tuple[int, ...]:
    scaled = int(round(value * (1 << scale_bits)))
    scaled = abs(scaled)  # sign is discarded: v and -v map to the same bits
    mask = (1 << bit_length) - 1
    twos_complement = scaled & mask  # low bit_length bits of |scaled|
    return tuple((twos_complement >> i) & 1 for i in range(bit_length))
```
Oracle `decrypt` chỉ trả về bit nếu hai nhánh `(m0, m1)` có cùng bit tại vị trí yêu cầu:
```py
# Excerpt from the body of dec_oracle (plaintext.SetLength(1) omitted).
bits0 = self._to_bits(row.m0)
bits1 = self._to_bits(row.m1)
if bits0[bit_index] != bits1[bit_index]:
    return None # denied
plaintext = self.cc.Decrypt(row.ciphertext, self.keys.secretKey)
value = plaintext.GetRealPackedValue()[0]
bits = float_to_bits(value, self.bit_length, self.scale_bits)
return bits[bit_index]
```
### Ý tưởng
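1) Xây “amplifier”: một ciphertext có plaintext bằng 0 nhưng noise lớn. Cộng nhiều bản `Enc(0)` (120 bản trong solve.py) rồi `square` để khuếch đại noise. Vì plaintext được theo dõi ở cả hai nhánh vẫn đúng bằng 0 nên bit-guard không bao giờ chặn.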
```py
def build_amplifier(sock, buf):
    # Pseudocode: send()/recv(), pick_subset_sum(), pairwise_add() and
    # recv_indices() stand for the socket and bookkeeping code in solve.py
    # below (the real function also returns the updated buffer).
    send({"command":"encrypt","m0":0.0,"m1":0.0})
    idx0 = recv()["state_index"]
    powers = [idx0]  # powers[i] holds the sum of 2^i copies of Enc(0)
    for _ in range(NZEROS.bit_length() - 1):
        send({"command":"eval","function":"add","indices":[powers[-1], powers[-1]]})
        powers.append(recv()["state_index"])
    cur = pick_subset_sum(powers, NZEROS)  # powers whose sizes add up to NZEROS
    while len(cur) > 1:
        batch = pairwise_add(cur) # send one add command per pair (pipelined)
        cur = recv_indices(batch) # collect the new state indices
    send({"command":"eval","function":"square","indices":[cur[0]]})
    return recv()["state_index"]
```
2) So sánh hai nhánh bằng cặp đối xứng: tạo hai encrypt đối xứng theo `b` và nhân với amplifier zero.
```py
e1 = encrypt(hval, SMALL) # if b=0 this encrypts the large value hval
e2 = encrypt(SMALL, hval) # if b=1 this encrypts the large value hval
m1 = eval_mul(e1, amp_idx) # tracked plaintext is 0, but the decrypted noise depends on e1's hidden value
m2 = eval_mul(e2, amp_idx) # tracked plaintext is 0, but the decrypted noise depends on e2's hidden value
# Reading bit `position` passes the bit-guard (tracked plaintext is 0 in both branches)
b1 = decrypt(m1, position)
b2 = decrypt(m2, position)
vote = sign(b1 - b2) # b1 tends to be the larger bit if b=0, b2 if b=1
```
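3) Bỏ phiếu với nhiều `hval` và `position`: dùng các `hval ∈ {1e5,1e6,1e7,1e8}` và `position ∈ {20,24}` để ổn định thống kê. Với mỗi position ta cộng dồn phiếu, rồi chọn `b` theo position có tổng phiếu lệch nhiều nhất. solve.py còn dừng sớm khi các probe còn lại không thể đảo dấu tổng phiếu.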
```py
# Pseudocode: encrypt/eval_mul/decrypt/send wrap the JSON commands of solve.py,
# and amp_idx comes from build_amplifier.
PROBES = [1e5, 1e6, 1e7, 1e8]
SMALL = 1e-3
TEST_POSITIONS = [20, 24]
best_guess = 0
best_score = 0  # signed vote total of the most one-sided position so far
for pos in TEST_POSITIONS:
    votes = 0
    for hval in PROBES:
        e1 = encrypt(hval, SMALL)
        e2 = encrypt(SMALL, hval)
        m1 = eval_mul(e1, amp_idx)
        m2 = eval_mul(e2, amp_idx)
        b1 = decrypt(m1, pos)
        b2 = decrypt(m2, pos)
        # +1 is evidence for b = 0, -1 for b = 1, 0 carries no information.
        votes += 1 if b1 > b2 else (-1 if b1 < b2 else 0)
    # Keep the position whose total is the most one-sided (same rule as solve.py).
    if abs(votes) > abs(best_score):
        best_score = votes
        best_guess = 0 if votes >= 0 else 1
send({"command":"guess","bit":best_guess})
```
4) Lặp 100 round để đạt `WIN` và lấy flag
<details>
<summary>solve.py</summary>
```python
import socket, json
HOST, PORT = "one-more-bit.challs.m0lecon.it", 24180

NZEROS = 120  # number of Enc(0) ciphertexts summed into the amplifier
PROBES = [float(10 ** exponent) for exponent in range(5, 9)]  # 1e5 .. 1e8
SMALL = 1e-3  # the "small" partner value in each symmetric probe
TEST_POSITIONS = [20, 24]  # decrypted bit indices used for voting

READ_TIMEOUT = 8  # seconds per socket read
RECV_CHUNK = 1 << 12  # bytes per recv() call
def _recv_line(sock, buf):
while True:
nl = buf.find(b"\n")
if nl != -1:
line = buf[:nl+1]
del buf[:nl+1]
return line.decode().strip(), buf
try:
chunk = sock.recv(RECV_CHUNK)
except socket.timeout:
raise TimeoutError("recv timeout")
if not chunk:
raise EOFError("connection closed")
buf += chunk
def recv_json(sock, buf):
    """Read one line and parse it as JSON.

    Returns (obj, buf). A line that is not valid JSON is returned as
    {"status": "bad_json", "raw": line} instead of raising. Transport errors
    from _recv_line (TimeoutError, EOFError) propagate to the caller.
    """
    s, buf = _recv_line(sock, buf)
    try:
        return json.loads(s), buf
    except json.JSONDecodeError:
        # Only malformed JSON becomes "bad_json"; other errors are real bugs and propagate.
        return {"status": "bad_json", "raw": s}, buf
def recv_n_json(sock, buf, n):
    """Read exactly *n* JSON lines (pipelined replies); return (list_of_objects, buf)."""
    replies = []
    while len(replies) < n:
        reply, buf = recv_json(sock, buf)
        replies.append(reply)
    return replies, buf
def build_amplifier(sock, buf):
    """Build a ciphertext whose tracked plaintext is 0 in both branches but whose noise is large.

    Encrypts (0, 0) once and doubles it with repeated "add" commands to get the
    sums of 2^i copies of Enc(0). It then adds together the powers that make up
    NZEROS (its set bits), pipelining each layer of pairwise additions, and
    finally squares the sum. Every tracked plaintext stays exactly 0.0, so
    later decryptions of products with this ciphertext pass the bit-guard.

    Returns:
        (state index of the squared sum, buf)

    Raises:
        ValueError: if NZEROS < 1.
        RuntimeError: if a reply carries no "state_index" (e.g. an error reply).
    """
    def send(obj):
        sock.sendall((json.dumps(obj, separators=(",", ":")) + "\n").encode())

    def state_index(reply):
        # Fail with the server's message rather than a bare KeyError.
        if "state_index" not in reply:
            raise RuntimeError(f"server did not return a state index: {reply!r}")
        return reply["state_index"]

    if NZEROS < 1:
        raise ValueError("NZEROS must be at least 1")

    send({"command": "encrypt", "m0": 0.0, "m1": 0.0})
    reply, buf = recv_json(sock, buf)
    powers = [state_index(reply)]  # powers[i]: sum of 2^i copies of Enc(0)
    for _ in range(NZEROS.bit_length() - 1):
        send({"command": "eval", "function": "add", "indices": [powers[-1], powers[-1]]})
        reply, buf = recv_json(sock, buf)
        powers.append(state_index(reply))

    # The powers of two whose sum is NZEROS (its set bits), largest first.
    cur = [powers[i] for i in reversed(range(len(powers))) if NZEROS >> i & 1]

    # Pairwise tree reduction. Each layer's adds are sent back-to-back, then all
    # replies are read. An odd leftover is carried to the front of the next layer.
    while len(cur) > 1:
        pairs = [cur[k:k + 2] for k in range(0, len(cur) - 1, 2)]
        nxt = cur[-1:] if len(cur) % 2 else []
        for pair in pairs:
            send({"command": "eval", "function": "add", "indices": pair})
        replies, buf = recv_n_json(sock, buf, len(pairs))
        nxt.extend(state_index(r) for r in replies)
        cur = nxt

    send({"command": "eval", "function": "square", "indices": [cur[0]]})
    reply, buf = recv_json(sock, buf)
    return state_index(reply), buf
def _send_json(sock, obj):
    """Send *obj* to the server as one compact JSON line."""
    sock.sendall((json.dumps(obj, separators=(",", ":")) + "\n").encode())


def _probe_vote(sock, buf, amp_idx, pos, hval):
    """Run one symmetric probe at bit *pos*; return (vote, buf).

    Encrypts (hval, SMALL) and (SMALL, hval), multiplies both by the amplifier
    and reads bit *pos* of each product. The vote is +1 (evidence for b = 0)
    if the first bit is larger or only the second read failed, -1 for the
    mirror case, and 0 otherwise.
    """
    _send_json(sock, {"command": "encrypt", "m0": hval, "m1": SMALL})
    _send_json(sock, {"command": "encrypt", "m0": SMALL, "m1": hval})
    rs, buf = recv_n_json(sock, buf, 2)
    e1, e2 = rs[0]["state_index"], rs[1]["state_index"]
    _send_json(sock, {"command": "eval", "function": "mul", "indices": [e1, amp_idx]})
    _send_json(sock, {"command": "eval", "function": "mul", "indices": [e2, amp_idx]})
    rs, buf = recv_n_json(sock, buf, 2)
    m1, m2 = rs[0]["state_index"], rs[1]["state_index"]
    _send_json(sock, {"command": "decrypt", "index": m1, "position": pos})
    _send_json(sock, {"command": "decrypt", "index": m2, "position": pos})
    rs, buf = recv_n_json(sock, buf, 2)
    b1 = rs[0]["bit"] if rs[0].get("status") == "ok" else None
    b2 = rs[1]["bit"] if rs[1].get("status") == "ok" else None
    if b1 is None and b2 is None:
        return 0, buf
    if b2 is None:
        return 1, buf
    if b1 is None:
        return -1, buf
    return (b1 > b2) - (b1 < b2), buf


def _decide_round(sock, buf, amp_idx, probes, positions):
    """Vote over *positions* x *probes*; return (guess_bit, buf).

    The position whose vote total has the largest magnitude decides (ties keep
    the earlier position). Within a position, probing stops once the remaining
    probes can no longer flip or zero the sign.
    """
    best_score = 0
    best_guess = 0
    for pos in positions:
        votes = 0
        used = 0
        for hval in probes:
            vote, buf = _probe_vote(sock, buf, amp_idx, pos, hval)
            votes += vote
            used += 1
            if abs(votes) > len(probes) - used:
                break
        if abs(votes) > abs(best_score):
            best_score = votes
            best_guess = 0 if votes >= 0 else 1
        # A unanimous early stop already reaches the largest |votes| any
        # position can reach under this stopping rule. A later position cannot
        # win the strict ">" comparison above, so skipping it leaves the guess
        # unchanged. (The old check `abs(best_score) > len(PROBES_ORDERED)`
        # could never be true.)
        if votes and abs(votes) == used:
            break
    return best_guess, buf


def main():
    """Play every round against the server and print the flag.

    Per round: build a zero-plaintext, high-noise amplifier, then run symmetric
    probes at each bit position and guess b from the votes.
    """
    # Try the middle-magnitude probe first, then the smaller ones, then the larger ones.
    probe_order = sorted(PROBES, key=abs)
    mid = len(probe_order) // 2
    probes = [probe_order[mid]] + probe_order[:mid] + probe_order[mid + 1:]
    positions = list(TEST_POSITIONS)

    # `with` closes the socket on every exit path (flag, LOSE, exception).
    with socket.create_connection((HOST, PORT)) as sock:
        sock.settimeout(READ_TIMEOUT)
        buf = bytearray()
        while True:
            msg, buf = recv_json(sock, buf)
            if msg.get("status") == "new_round":
                amp_idx, buf = build_amplifier(sock, buf)
                guess, buf = _decide_round(sock, buf, amp_idx, probes, positions)
                _send_json(sock, {"command": "guess", "bit": guess})
                resp, buf = recv_json(sock, buf)
                if resp.get("result") == "LOSE":
                    # The server ends the game after a wrong guess.
                    print(f"Wrong guess in round {msg.get('round')}")
                    break
                if "flag" in resp:
                    print(resp["flag"])
                    break
            elif "flag" in msg:
                print(msg["flag"])
                break


if __name__ == "__main__":
    main()
```
</details>
Vì server có giới hạn thời gian (timeout) mà chạy script trên máy cá nhân không kịp, nên có thể đưa code lên Kaggle để chạy.
