# Short write-up m0lecon

## Guess me

<details>
<summary>chall.py</summary>

```python
#!/usr/bin/env python3
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from hashlib import sha256
from hmac import compare_digest
from random import shuffle
import os

flag = os.environ.get("FLAG", "ptm{REDACTED}")

BLOCK_SIZE = 16
NUM_BITS = BLOCK_SIZE * 8

SBOX = (0xC, 0x5, 0x6, 0xB, 0x9, 0x0, 0xA, 0xD, 0x3, 0xE, 0xF, 0x8, 0x4, 0x7, 0x1, 0x2)
BIT_PERM = tuple((idx * 7) % NUM_BITS for idx in range(NUM_BITS))


def _pad_pkcs7(data, block_size = BLOCK_SIZE):
    # NOTE(write-up): not real PKCS#7 - when len(data) is a multiple of
    # block_size (including empty data) the pad length is 0, so no padding
    # block is appended at all.
    return data + bytes([(block_size - len(data)) % block_size]) * ((block_size - len(data)) % block_size)


def _xor_bytes(left, right):
    return bytes(a ^ b for a, b in zip(left, right))


def _unpad_pkcs7(data, block_size = BLOCK_SIZE):
    # NOTE(write-up): data[-1] raises IndexError when data is empty;
    # decrypt() turns any exception here into b"Invalid padding".
    d = data[-1]
    assert d <= len(data)
    assert all([x==d for x in data[-d:]])
    return data[:-d]


def _perm(data):
    # 10 rounds of: 4-bit S-box on both nibbles, then bit i -> BIT_PERM[i].
    state = data
    for _ in range(10):
        sbox_out = bytearray(len(state))
        for idx, value in enumerate(state):
            sbox_out[idx] = (SBOX[value >> 4] << 4) | SBOX[value & 0x0F]

        bits = []
        for value in sbox_out:
            for shift in range(8):
                bits.append((value >> (7 - shift)) & 0x01)

        permuted_bits = [0] * NUM_BITS
        for idx, bit in enumerate(bits):
            permuted_bits[BIT_PERM[idx]] = bit

        state_out = bytearray(len(state))
        for idx in range(len(state)):
            byte = 0
            for shift in range(8):
                byte = (byte << 1) | permuted_bits[idx * 8 + shift]
            state_out[idx] = byte
        state = bytes(state_out)
    return state


def _prf(block_index, key, data):
    # mask = AES_k(sha256(block_index)) is 32 bytes: first half is XORed in
    # before _perm, second half after it.
    cipher = AES.new(key, AES.MODE_ECB)
    mask = cipher.encrypt(sha256(block_index.to_bytes(4, 'big', signed=True)).digest())
    result = _xor_bytes(data, mask[:BLOCK_SIZE])
    result = _perm(result)
    result = _xor_bytes(result, mask[-BLOCK_SIZE:])
    return result


def enc_msg(key, nonce, message):
    padded = _pad_pkcs7(message)
    blocks = [padded[i : i + BLOCK_SIZE] for i in range(0, len(padded), BLOCK_SIZE)]
    ciphertext_blocks = []
    for idx, block in enumerate(blocks):
        keystream = _prf(idx, key, nonce)
        ciphertext_blocks.append(_xor_bytes(block, keystream))
    return b"".join(ciphertext_blocks)


def enc_tag(key, nonce, additional_data, ciphertext):
    ad_padded = _pad_pkcs7(nonce + additional_data)
    ct_padded = _pad_pkcs7(ciphertext)
    ad_blocks = [ad_padded[i : i + BLOCK_SIZE] for i in range(0, len(ad_padded), BLOCK_SIZE)]
    ct_blocks = [ct_padded[i : i + BLOCK_SIZE] for i in range(0, len(ct_padded), BLOCK_SIZE)]
    # The first block (the 16-byte nonce) is used as-is, without a PRF.
    tag = ad_blocks[0]
    for idx, block in enumerate(ad_blocks[1:], start=1):
        keystream = _prf(idx + 1337, key, block)
        tag = _xor_bytes(tag, keystream)
    for idx, block in enumerate(ct_blocks):
        keystream = _prf(idx + 31337, key, block)
        tag = _xor_bytes(tag, keystream)
    return _prf(-1, key, tag)


def encrypt(key, nonce, message, additional_data):
    ciphertext = enc_msg(key, nonce, message)
    tag = enc_tag(key, nonce, additional_data, ciphertext)
    return ciphertext, tag


def decrypt(key, nonce, ciphertext, additional_data, tag):
    assert len(key) == BLOCK_SIZE
    assert len(nonce) == BLOCK_SIZE
    assert len(ciphertext) % BLOCK_SIZE == 0
    assert len(tag) == BLOCK_SIZE

    expected_tag = enc_tag(key, nonce, additional_data, ciphertext)
    if not compare_digest(expected_tag, tag):
        return False

    blocks = [ciphertext[i : i + BLOCK_SIZE] for i in range(0, len(ciphertext), BLOCK_SIZE)]
    plaintext_blocks = []
    for idx, block in enumerate(blocks):
        keystream = _prf(idx, key, nonce)
        plaintext_blocks.append(_xor_bytes(block, keystream))
    plaintext_padded = b"".join(plaintext_blocks)
    try:
        plaintext = _unpad_pkcs7(plaintext_padded)
    except:
        # Reached with a valid tag only; the returned bytes are truthy.
        return b"Invalid padding"
    return plaintext


if __name__ == "__main__":
    for r in range(5):
        base = list("m0leCon")
        shuffle(base)
        key = bytes(sha256("".join(base).encode()).digest())[:BLOCK_SIZE]
        for _ in range(16):
            nonces = bytes.fromhex(input("Enter nonce (hex): ").strip())
            # Several 16-byte nonces may be sent on one line; each one is
            # checked against the same ad/ciphertext/tag.
            nonces = [nonces[i:i+BLOCK_SIZE] for i in range(0, len(nonces), BLOCK_SIZE)]
            additional_data = bytes.fromhex(input("Enter additional_data (hex): ").strip())
            ciphertext = bytes.fromhex(input("Enter ciphertext (hex): ").strip())
            tag = bytes.fromhex(input("Enter tag (hex): ").strip())
            decs = [decrypt(key, nonce, ciphertext, additional_data, tag) for nonce in nonces]
            # NOTE(write-up): authenticated if ANY nonce gives a truthy result,
            # which includes b"Invalid padding".
            auth = any(decs)
            if auth:
                if additional_data != b"pretty please":
                    print("Can you at least say 'please' next time?")
                    exit()
                else:
                    if all([dec == b"next round please" for dec in decs]):
                        print("There you go!")
                        break
                    else:
                        print("This message does not seem ok :(")
            else:
                print("Tag is invalid")
        else:
            print("Better luck next time!")
            exit()
    print(flag)
```

</details>

### Tóm tắt đề

- Server triển khai một custom scheme AEAD gồm hai phần: mã hóa thông điệp `enc_msg` và tạo thẻ xác thực `enc_tag`.
- Mỗi round (5 rounds) server chọn khóa 128-bit từ `sha256` của một hoán vị chuỗi "m0leCon", lấy 16 byte đầu (`7! = 5040` candidate key).
- Mỗi round bạn có tối đa 16 lần nhập: `nonce (hex)`, `additional_data (hex)`, `ciphertext (hex)`, `tag (hex)`.

### Bug

- Padding PKCS#7 sai: `_pad_pkcs7` trả về padding độ dài 0 khi độ dài dữ liệu chia hết cho block (kể cả rỗng). Điều này khiến `ciphertext` rỗng có số block bằng 0.
- Xử lý xác thực dùng `any(decs)`: Khi `tag` đúng nhưng padding sai, `decrypt` trả về `b"Invalid padding"` (bytes khác rỗng, truthy), làm `auth` trở thành true.
- Chấp nhận nhiều nonce trong một lần nhập: Ghép nhiều nonce 16 byte vào cùng một chuỗi hex để hỏi membership (thuộc/không thuộc) của khóa.

### Phân tích PRF/MAC

- **PRF**: với `X_i || Y_i = AES_k(sha256(i))` (32 byte), và `P` là hoán vị 10 vòng (S-box + hoán vị bit):
  - `PRF(i, k, m) = P(m XOR X_i) XOR Y_i`.
  - Đảo PRF khi biết khóa: `PRF^{-1}(i, k, v) = P^{-1}(v XOR Y_i) XOR X_i`.
- Với `ad = b"pretty please"` (13 byte) → sau pad: `A1 = ad || 0x03 0x03 0x03`. Khi `ciphertext = b""`, ta có:
  - `ad_blocks = [nonce, A1]`, `ct_blocks = []`.
  - `tag_base = nonce XOR PRF(1338, k, A1)` và `expected_tag = PRF(-1, k, tag_base)`.
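Hệ quả: với một `tag0` bất kỳ (ví dụ `tag0 = 0^{16}`, tức 16 byte `0x00`), cho MỖI khóa ứng viên `k` tồn tại đúng một `nonce_k` sao cho `enc_tag(k, nonce_k, ad, b"") = tag0`. Nonce này là duy nhất vì `PRF(i, k, ·)` là song ánh theo dữ liệu đầu vào: S-box và `BIT_PERM` đều là hoán vị. Cụ thể:

`nonce_k = PRF^{-1}(-1, k, tag0) XOR PRF(1338, k, A1)`.

### Ý tưởng

1) Tạo 5040 khóa từ mọi hoán vị của "m0leCon" rồi băm `sha256` lấy 16 byte đầu.
2) Tính “nonce membership” cho một `tag0` cố định (ví dụ `tag0 = 0^{16}`) khi `ciphertext = b""` và `ad = b"pretty please"`:
   - Với mỗi khóa ứng viên `k`, tính:
     - `B_k = PRF(1338, k, A1)` với `A1 = ad || 0x03*3`.
     - `inner = PRF^{-1}(-1, k, tag0)` bằng công thức đảo PRF ở trên.
     - Chọn `nonce_k = inner XOR B_k` để thỏa `enc_tag(k, nonce_k, ad, b"") == tag0`.
3) Binary search:
   - Ở mỗi truy vấn, ghép nối các `nonce_k` của một nửa tập ứng viên thành một chuỗi hex (nhiều block). Gửi kèm `additional_data = b"pretty please"` (dạng hex), `ciphertext` rỗng và `tag = tag0`.
   - Nếu server in “This message does not seem ok :(” → có ít nhất một `nonce_k` hợp lệ trong tập con, tức khóa thật nằm trong tập con. Diễn biến phía server như sau:
     - Với nonce đó, tag khớp.
     - Ciphertext rỗng nên plaintext rỗng, và `_unpad_pkcs7(b"")` ném `IndexError` tại `data[-1]`.
     - `decrypt` bắt lỗi và trả về `b"Invalid padding"` (truthy), nên `auth` = true.
     - Kết quả khác `b"next round please"`, nên server chỉ in thông báo này và round vẫn tiếp tục.
   - Nếu in “Tag is invalid” → khóa thật không nằm trong tập con (nó ở nửa còn lại).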
- Lặp lại tối đa ~log2(5040) ≈ 12–13 lần

<details>
<summary>solve.py</summary>

```python
import sys
from itertools import permutations
from hashlib import sha256
from Crypto.Cipher import AES
from pwn import remote, process

BLOCK_SIZE = 16
NUM_BITS = 128

SBOX = (
    0xC, 0x5, 0x6, 0xB, 0x9, 0x0, 0xA, 0xD,
    0x3, 0xE, 0xF, 0x8, 0x4, 0x7, 0x1, 0x2
)
# Inverse S-box, used by perm_inv.
INV_SBOX = [0]*16
for i, v in enumerate(SBOX):
    INV_SBOX[v] = i

BIT_PERM = tuple((i * 7) % NUM_BITS for i in range(NUM_BITS))


def bytes_to_bits(b):
    """Expand bytes into a list of bits, MSB first within each byte."""
    out = []
    for x in b:
        for s in range(8):
            out.append((x >> (7 - s)) & 1)
    return out


def bits_to_bytes(bits):
    """Pack an MSB-first bit list (length a multiple of 8) back into bytes."""
    out = bytearray(len(bits) // 8)
    for i in range(len(out)):
        x = 0
        for s in range(8):
            x = (x << 1) | bits[i * 8 + s]
        out[i] = x
    return bytes(out)


def xor_bytes(a, b):
    return bytes(x ^ y for x, y in zip(a, b))


def perm_fwd(s):
    """Same 10-round permutation as the server's _perm."""
    for _ in range(10):
        t = bytearray(len(s))
        for i, v in enumerate(s):
            t[i] = (SBOX[v >> 4] << 4) | SBOX[v & 0x0F]
        bits = bytes_to_bits(bytes(t))
        p = [0]*NUM_BITS
        for i, bit in enumerate(bits):
            p[BIT_PERM[i]] = bit
        s = bits_to_bytes(p)
    return s


def perm_inv(s):
    """Inverse of perm_fwd: each round undoes the bit permutation, then the S-box."""
    for _ in range(10):
        bits = bytes_to_bits(s)
        pre = [0]*NUM_BITS
        for i in range(NUM_BITS):
            pre[i] = bits[BIT_PERM[i]]
        t = bits_to_bytes(pre)
        o = bytearray(len(t))
        for i, v in enumerate(t):
            o[i] = (INV_SBOX[v >> 4] << 4) | INV_SBOX[v & 0x0F]
        s = bytes(o)
    return s


def prf(idx, key, data_block):
    """Local copy of the server's _prf: P(data XOR X_idx) XOR Y_idx."""
    aes = AES.new(key, AES.MODE_ECB)
    seed = sha256(idx.to_bytes(4, "big", signed=True)).digest()
    mask = aes.encrypt(seed)
    left = mask[:BLOCK_SIZE]
    right = mask[-BLOCK_SIZE:]
    r = xor_bytes(data_block, left)
    r = perm_fwd(r)
    return xor_bytes(r, right)


def pad_bad(data, bs=BLOCK_SIZE):
    """Reproduces the server's faulty padding (no block added when len % bs == 0)."""
    n = (bs - len(data)) % bs
    return data + bytes([n]) * n


def enc_msg(key, nonce, msg):
    padded = pad_bad(msg)
    blocks = [padded[i:i+BLOCK_SIZE] for i in range(0, len(padded), BLOCK_SIZE)]
    out = []
    for i, b in enumerate(blocks):
        ks = prf(i, key, nonce)
        out.append(xor_bytes(b, ks))
    return b"".join(out)


def enc_tag(key, nonce, ad, ct):
    ad_p = pad_bad(nonce + ad)
    ct_p = pad_bad(ct)
    ad_blocks = [ad_p[i:i+BLOCK_SIZE] for i in range(0, len(ad_p), BLOCK_SIZE)]
    ct_blocks = [ct_p[i:i+BLOCK_SIZE] for i in range(0, len(ct_p), BLOCK_SIZE)]
    tag = ad_blocks[0]
    for i, b in enumerate(ad_blocks[1:], start=1):
        tag = xor_bytes(tag, prf(i + 1337, key, b))
    for j, b in enumerate(ct_blocks):
        tag = xor_bytes(tag, prf(j + 31337, key, b))
    return prf(-1, key, tag)


def build_keys():
    """All 7! = 5040 candidate keys: sha256(permutation of "m0leCon")[:16]."""
    base = "m0leCon"
    out = []
    for p in permutations(base):
        out.append(sha256("".join(p).encode()).digest()[:BLOCK_SIZE])
    return out


def precompute_membership(keys, ad, tag0):
    """For each candidate key k, return the nonce with enc_tag(k, nonce, ad, b"") == tag0.

    Assumes len(ad) == 13, so that nonce || ad pads to exactly two blocks
    [nonce, A1] with A1 = ad || 0x03 0x03 0x03.
    """
    A1 = ad + bytes([3]) * 3
    idx_ad = 1338  # ad_blocks[1] is processed with PRF index 1 + 1337
    nonces = []
    for k in keys:
        aes = AES.new(k, AES.MODE_ECB)
        seed_m1 = sha256((-1).to_bytes(4, "big", signed=True)).digest()
        mask_m1 = aes.encrypt(seed_m1)
        x_m1 = mask_m1[:BLOCK_SIZE]
        y_m1 = mask_m1[-BLOCK_SIZE:]
        b_k = prf(idx_ad, k, A1)
        inv = perm_inv(xor_bytes(tag0, y_m1))
        # nonce = PRF^{-1}(-1, k, tag0) XOR PRF(1338, k, A1)
        n = xor_bytes(xor_bytes(b_k, x_m1), inv)
        nonces.append(n)
    return nonces


def precompute_success_triplet(keys, nonce_final, ad, msg):
    """For each candidate key, a valid (nonce, ciphertext, tag) encrypting msg."""
    out = []
    for k in keys:
        ct = enc_msg(k, nonce_final, msg)
        tg = enc_tag(k, nonce_final, ad, ct)
        out.append((nonce_final, ct, tg))
    return out


def interact_round(p, nonce_candidates, triples, ad, tag0):
    """Binary-search the round key with membership queries, then send the winning triple.

    ceil(log2(5040)) = 13 membership queries + 1 final query fit in the
    16 inputs the server allows per round.
    """
    candidates = list(range(len(nonce_candidates)))

    def ask(nonce_blob, ad_hex, ct_hex, tag_hex):
        p.recvuntil(b"Enter nonce (hex): ")
        p.sendline(nonce_blob.hex().encode())
        p.recvuntil(b"Enter additional_data (hex): ")
        p.sendline(ad_hex.encode())
        p.recvuntil(b"Enter ciphertext (hex): ")
        p.sendline(ct_hex.encode())
        p.recvuntil(b"Enter tag (hex): ")
        p.sendline(tag_hex.encode())
        return p.recvline(timeout=10).decode(errors="ignore")

    ad_hex = ad.hex()
    tag0_hex = tag0.hex()
    tests = 0
    while len(candidates) > 1 and tests < 15:
        tests += 1
        mid = len(candidates) // 2
        subset = candidates[:mid]
        nonce_blob = b"".join(nonce_candidates[i] for i in subset)
        # Empty ciphertext: a matching nonce leads to b"Invalid padding" (truthy).
        resp = ask(nonce_blob, ad_hex, "", tag0_hex)
        if "This message does not seem ok" in resp:
            candidates = subset  # some nonce in subset authenticated
        else:
            candidates = candidates[mid:]
    idx = candidates[0]
    n_final, ct_final, tg_final = triples[idx]
    resp = ask(n_final, ad_hex, ct_final.hex(), tg_final.hex())
    if "There you go!" not in resp:
        raise RuntimeError("Final attempt failed: " + resp)


ad = b"pretty please"
msg = b"next round please"
tag0 = b"\x00" * 16
nonce_final = b"N" * 16

keys = build_keys()
nonce_candidates = precompute_membership(keys, ad, tag0)
triples = precompute_success_triplet(keys, nonce_final, ad, msg)

# p = remote(host, port, timeout=10)
p = process(["python3", "guess_me.py"])
for _ in range(5):
    interact_round(p, nonce_candidates, triples, ad, tag0)
p.interactive()
```

</details>

## One more bit

<details>
<summary>fhe_server.py</summary>

```python!
from __future__ import annotations

import json
import sys
from typing import Any, Callable, Dict, List
import os

from ind_cpa_d_bitwise_game import (
    AddCircuit,
    BitwiseCKKSIndCpaDGame,
    HomomorphicCKKSFunction,
    MultiplyCircuit,
    SquareCircuit,
)

flag = os.environ.get("FLAG", "ptm{REDACTED}")

CIRCUIT_FACTORIES: Dict[str, Callable[[], HomomorphicCKKSFunction]] = {
    "square": SquareCircuit,
    "add": AddCircuit,
    "sum": AddCircuit,
    "mul": MultiplyCircuit,
    "multiply": MultiplyCircuit,
}


class INDCPADServer:
    """Simple line-oriented server exposing the IND-CPA-D oracles to a client."""

    def __init__(
        self,
        *,
        mult_depth: int = 2,
        scale_mod_size: int = 50,
        bit_length: int = 64,
    ) -> None:
        self.game = BitwiseCKKSIndCpaDGame(
            mult_depth=mult_depth,
            scale_mod_size=scale_mod_size,
            batch_size=1,
            bit_length=bit_length,
        )
        self._resolved = False

    # ------------------------------------------------------------------ #
    # Public API
    def encrypt(self, m0: float, m1: float) -> Dict[str, Any]:
        ciphertext = self.game.enc_oracle(m0, m1)
        index = len(self.game.state) - 1
        return {"status": "ok", "state_index": index}

    def eval(self, function: str, indices: List[int]) -> Dict[str, Any]:
        circuit = self._resolve_circuit(function)
        self.game.eval_oracle(tuple(indices), circuit)
        index = len(self.game.state) - 1
        return {"status": "ok", "state_index": index}

    def decrypt(self, index: int, position: int) -> Dict[str, Any]:
        bit = self.game.dec_oracle(index, position)
        # None means the m0/m1 bits differ at this position.
        if bit is None:
            return {"status": "denied", "reason": "bit-mismatch"}
        return {"status": "ok", "bit": int(bit)}

    def guess(self, bit: int) -> Dict[str, Any]:
        # Only one guess per round; main() exits on "LOSE".
        if self._resolved:
            return {"status": "error", "error": "game already resolved"}
        self._resolved = True
        if (bit & 1) == self.game.challenge_bit:
            return {"status": "ok", "result": "WIN"}
        return {"status": "ok", "result": "LOSE"}

    def handle_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
        command = request.get("command")
        if command is None:
            raise ValueError("missing command field")
        command = str(command).lower()
        if command == "encrypt":
            return self.encrypt(float(request["m0"]), float(request["m1"]))
        if command == "eval":
            function = str(request["function"])
            indices = [int(idx) for idx in request["indices"]]
            return self.eval(function, indices)
        if command == "decrypt":
            return self.decrypt(int(request["index"]), int(request["position"]))
        if command == "guess":
            return self.guess(int(request["bit"]))
        raise ValueError(f"unknown command '{command}'")

    # ------------------------------------------------------------------ #
    # Helpers
    @staticmethod
    def _resolve_circuit(descriptor: str) -> HomomorphicCKKSFunction:
        keyword = descriptor.lower()
        builder = CIRCUIT_FACTORIES.get(keyword)
        if builder is None:
            raise ValueError(f"unsupported circuit '{descriptor}'")
        return builder()


def main() -> None:
    rounds = 100
    for current_round in range(1, rounds + 1):
        server = INDCPADServer()
        print(json.dumps({"status": "new_round", "round": current_round}))
        sys.stdout.flush()
        while True:
            line = sys.stdin.readline()
            if not line:
                return
            line = line.strip()
            if not line:
                continue
            try:
                request = json.loads(line)
            except json.JSONDecodeError:
                print(json.dumps({"status": "error", "error": "invalid json"}))
                sys.stdout.flush()
                continue
            try:
                response = server.handle_request(request)
            except Exception as exc:
                response = {"status": "error", "error": str(exc)}
            print(json.dumps(response))
            sys.stdout.flush()
            result = response.get("result")
            if result == "WIN":
                break
            if result == "LOSE":
                return
    print(json.dumps({"status": "ok", "flag": flag}))
    sys.stdout.flush()


if __name__ == "__main__":
    main()
```

</details>

<details>
<summary>ind_cpa_d_bitwise_game.py</summary>

```python
from __future__ import annotations

import secrets
from dataclasses import dataclass
from typing import List, Protocol, Tuple

from openfhe import (
    CCParamsCKKSRNS,
    Ciphertext,
    CryptoContext,
    GenCryptoContext,
    PKESchemeFeature,
    Plaintext,
)

DEBUG_GAME = False


def float_to_bits(value: float, bit_length: int, scale_bits: int) -> Tuple[int, ...]:
    """
    Convert a CKKS plaintext slot to a fixed-length bit tuple.

    Values are first scaled by 2^scale_bits and rounded to the nearest integer.
    The integer is then represented in two's complement with bit_length bits.
    """
    scaled = int(round(value * (1 << scale_bits)))
    # NOTE(write-up): the sign is dropped here, so despite the docstring and
    # the variable name below the result is |round(value * 2^scale_bits)|
    # truncated to bit_length bits, not two's complement.
    scaled = abs(scaled)
    mask = (1 << bit_length) - 1
    twos_complement = scaled & mask
    # LSB first: element i is the bit of weight 2^i.
    return tuple((twos_complement >> i) & 1 for i in range(bit_length))


class HomomorphicCKKSFunction(Protocol):
    """Interface for describing circuits usable by the Eval oracle."""

    num_inputs: int

    def plaintext(self, values: Tuple[float, ...]) -> float:
        """Evaluate the circuit on cleartext data."""

    def ciphertext(
        self, cc: CryptoContext, ciphertexts: Tuple[Ciphertext, ...]
    ) -> Ciphertext:
        """Evaluate the same circuit homomorphically on ciphertext inputs."""


@dataclass
class OracleRow:
    """Single row stored by the challenger."""

    m0: float
    m1: float
    ciphertext: Ciphertext


class BitwiseCKKSIndCpaDGame:
    """
    CKKS IND-CPA-D challenger with bit-guarded decryption oracle.

    The challenger maintains a state of tuples (m0, m1, bits0, bits1, Enc(m_b)).
    The decryption oracle reveals Dec(Enc(m_b)) only if the requested bit index
    matches in both bitstrings.
    """

    def __init__(
        self,
        mult_depth: int = 2,
        scale_mod_size: int = 50,
        batch_size: int = 1,
        *,
        challenge_bit: int | None = None,
        bit_length: int = 64,
    ) -> None:
        params = CCParamsCKKSRNS()
        params.SetMultiplicativeDepth(mult_depth)
        params.SetScalingModSize(scale_mod_size)
        params.SetBatchSize(batch_size)

        self.cc: CryptoContext = GenCryptoContext(params)
        self.cc.Enable(PKESchemeFeature.PKE)
        self.cc.Enable(PKESchemeFeature.KEYSWITCH)
        self.cc.Enable(PKESchemeFeature.LEVELEDSHE)

        self.keys = self.cc.KeyGen()
        self.cc.EvalMultKeyGen(self.keys.secretKey)

        self.challenge_bit = (
            secrets.randbits(1) if challenge_bit is None else (challenge_bit & 1)
        )
        if DEBUG_GAME:
            print(f"[DEBUG] challenge_bit = {self.challenge_bit}")
        self.state: List[OracleRow] = []
        self.scale_bits = scale_mod_size
        self.bit_length = bit_length

    # ------------------------------------------------------------------ #
    # Helper utilities
    def _encode(self, value: float) -> Plaintext:
        return self.cc.MakeCKKSPackedPlaintext([value])

    def _to_bits(self, value: float) -> Tuple[int, ...]:
        return float_to_bits(value, self.bit_length, self.scale_bits)

    # ------------------------------------------------------------------ #
    # Oracles
    def enc_oracle(self, m0: float, m1: float) -> Ciphertext:
        pt0 = self._encode(m0)
        pt1 = self._encode(m1)
        chosen_pt = pt0 if self.challenge_bit == 0 else pt1
        ciphertext = self.cc.Encrypt(self.keys.publicKey, chosen_pt)
        row = OracleRow(m0=m0, m1=m1, ciphertext=ciphertext)
        self.state.append(row)
        if DEBUG_GAME:
            self._debug_log_bits(len(self.state) - 1, m0, m1)
        return ciphertext

    def eval_oracle(
        self, indices: Tuple[int, ...], circuit: HomomorphicCKKSFunction
    ) -> Ciphertext:
        if len(indices) != circuit.num_inputs:
            raise ValueError("indices count does not match circuit arity")
        rows: List[OracleRow] = []
        for index in indices:
            if index < 0 or index >= len(self.state):
                raise IndexError("state index out of range")
            rows.append(self.state[index])

        m0_inputs = tuple(row.m0 for row in rows)
        m1_inputs = tuple(row.m1 for row in rows)
        ct_inputs = tuple(row.ciphertext for row in rows)

        # The tracked m0/m1 are exact float results; only new_ct carries CKKS noise.
        new_m0 = circuit.plaintext(m0_inputs)
        new_m1 = circuit.plaintext(m1_inputs)
        new_ct = circuit.ciphertext(self.cc, ct_inputs)

        new_row = OracleRow(m0=new_m0, m1=new_m1, ciphertext=new_ct)
        self.state.append(new_row)
        if DEBUG_GAME:
            self._debug_log_bits(len(self.state) - 1, new_m0, new_m1)
        return new_ct

    def dec_oracle(self, index: int, bit_index: int) -> int | None:
        """
        Decryption oracle: reveals only the selected bit of Dec(Enc(m_b)).
        """
        if index < 0 or index >= len(self.state):
            raise IndexError("state index out of range")
        if bit_index < 0 or bit_index >= self.bit_length:
            raise IndexError("bit index out of range")

        row = self.state[index]
        bits0 = self._to_bits(row.m0)
        bits1 = self._to_bits(row.m1)
        # NOTE(write-up): the guard compares bits of the tracked cleartexts
        # m0/m1 only; the returned bit comes from the real (noisy) decryption.
        if bits0[bit_index] != bits1[bit_index]:
            return None

        plaintext = self.cc.Decrypt(row.ciphertext, self.keys.secretKey)
        plaintext.SetLength(1)
        value = plaintext.GetRealPackedValue()[0]
        bits = float_to_bits(value, self.bit_length, self.scale_bits)
        bits_str = "".join(str(bit) for bit in bits)
        # print(f"[DEBUG] dec_oracle[{index}] value={value} bits={bits_str}")
        return bits[bit_index]

    # ------------------------------------------------------------------ #
    # Debug helpers
    def _debug_log_bits(self, index: int, m0: float, m1: float) -> None:
        bits0 = "".join(str(bit) for bit in self._to_bits(m0))
        bits1 = "".join(str(bit) for bit in self._to_bits(m1))
        print(f"[DEBUG] state[{index}] m0={m0} bits={bits0}")
        print(f"[DEBUG] state[{index}] m1={m1} bits={bits1}")


class SquareCircuit(HomomorphicCKKSFunction):
    num_inputs = 1

    def plaintext(self, values: Tuple[float, ...]) -> float:
        value = values[0]
        return value * value

    def ciphertext(
        self, cc: CryptoContext, ciphertexts: Tuple[Ciphertext, ...]
    ) -> Ciphertext:
        ciphertext = ciphertexts[0]
        return cc.EvalMult(ciphertext, ciphertext)


class AddCircuit(HomomorphicCKKSFunction):
    num_inputs = 2

    def plaintext(self, values: Tuple[float, ...]) -> float:
        return values[0] + values[1]

    def ciphertext(
        self, cc: CryptoContext, ciphertexts: Tuple[Ciphertext, ...]
    ) -> Ciphertext:
        return cc.EvalAdd(ciphertexts[0], ciphertexts[1])


class MultiplyCircuit(HomomorphicCKKSFunction):
    num_inputs = 2

    def plaintext(self, values: Tuple[float, ...]) -> float:
        return values[0] * values[1]

    def ciphertext(
        self, cc: CryptoContext, ciphertexts: Tuple[Ciphertext, ...]
    ) -> Ciphertext:
        return cc.EvalMult(ciphertexts[0], ciphertexts[1])


def _format_plaintext(plaintext: Plaintext, precision: int = 6) -> str:
    return plaintext.GetFormattedValues(precision)


if __name__ == "__main__":
    game = BitwiseCKKSIndCpaDGame(mult_depth=2, scale_mod_size=50, batch_size=1)
    ct0 = game.enc_oracle(0.125, 1.75)
    print(f"state[0] ciphertext: {ct0}")
    ct1 = game.enc_oracle(-0.5, -0.5)
    print(f"state[1] ciphertext: {ct1}")

    square = SquareCircuit()
    ct_square = game.eval_oracle((0,), square)
    add = AddCircuit()
    ct_add = game.eval_oracle((0, 1), add)

    bit_idx = 10
    bit_value = game.dec_oracle(1, bit_idx)
    if bit_value is None:
        print(f"bit {bit_idx} differs, decryption denied.")
    else:
        print(f"bit {bit_idx} matches, decrypted bit: {bit_value}")
```

</details>

## **Tóm tắt:**

Server tương tác bằng IND-CPA-D trên CKKS (FHE xấp xỉ):

- Mỗi round sinh ra một bit bí mật `b ∈ {0,1}`.
- Oracle `encrypt(m0, m1)` lưu vào state một cặp `(m0, m1)` và trả về chỉ số của `Enc(m_b)`.
- Oracle `eval(function, indices)` áp dụng mạch đồng hình (square/add/mul) lên các ciphertext trong `indices`, đồng thời cập nhật cặp plaintext tương ứng bằng cách áp dụng cùng hàm trên `(m0, m1)` theo hai nhánh. Kết quả cũng được lưu vào state và trả về chỉ số mới.
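- Oracle `decrypt(index, position)` chỉ trả về bit tại vị trí `position` của bản giải mã nếu bit ở vị trí đó của `m0` và `m1` (cặp plaintext tại `state[index]`) giống nhau; nếu khác sẽ trả về `denied`.
- Mỗi round chỉ được đoán `b` một lần. Đoán đúng thì sang round mới; đoán sai thì server kết thúc ngay. Vì vậy phải đoán đúng liên tiếp cả 100 round mới nhận được flag.

### Bit-guard hoạt động thế nào?

Trong `ind_cpa_d_bitwise_game.py`, giá trị thực được đổi sang bit bởi `float_to_bits` qua các bước sau:
- nhân với `2^scale_bits` (ở đây `scale_bits = 50`) rồi làm tròn;
- lấy trị tuyệt đối;
- giữ 64 bit thấp, trong đó bit thứ `i` có trọng số `2^i`.

Docstring và tên biến nói "two's complement", nhưng dấu đã bị `abs()` bỏ trước khi mask:

```py
def float_to_bits(value: float, bit_length: int, scale_bits: int) -> Tuple[int, ...]:
    scaled = int(round(value * (1 << scale_bits)))
    scaled = abs(scaled)
    mask = (1 << bit_length) - 1
    twos_complement = scaled & mask
    return tuple((twos_complement >> i) & 1 for i in range(bit_length))
```

Oracle `decrypt` chỉ trả về bit nếu hai nhánh `(m0, m1)` có cùng bit tại vị trí yêu cầu:

```py
bits0 = self._to_bits(row.m0)
bits1 = self._to_bits(row.m1)
if bits0[bit_index] != bits1[bit_index]:
    return None  # denied
plaintext = self.cc.Decrypt(row.ciphertext, self.keys.secretKey)
value = plaintext.GetRealPackedValue()[0]
bits = float_to_bits(value, self.bit_length, self.scale_bits)
return bits[bit_index]
```

Điểm mấu chốt là guard và giá trị trả về nhìn vào hai thứ khác nhau:
- Guard chỉ so sánh bit của **plaintext lý tưởng** `m0`, `m1`, vốn được tính lại bằng số thực chính xác trong `eval_oracle`.
- Bit trả về lại lấy từ **bản giải mã thực tế**, vốn chứa noise CKKS phụ thuộc vào message thật sự được mã hóa (`m_b`).

Nếu `m0 = m1 = 0` thì mọi bit ở cả hai nhánh đều bằng 0, nên guard luôn cho qua và ta đọc tự do được các bit của noise.

### Ý tưởng

1) Xây một “amplifier”: một bản mã của 0 nhưng mang noise lớn. Cộng dồn nhiều bản `Enc(0)` (trong solver là `NZEROS = 120` bản, dùng các lũy thừa 2 rồi cộng theo cây), sau đó `square` để khuếch đại noise. Plaintext ở cả hai nhánh vẫn là 0 nên không chạm bit-guard.

```py
def build_amplifier(sock, buf):
    send({"command":"encrypt","m0":0.0,"m1":0.0})
    idx0 = recv()["state_index"]
    powers = [idx0]
    for _ in range(NZEROS.bit_length() - 1):
        send({"command":"eval","function":"add","indices":[powers[-1], powers[-1]]})
        powers.append(recv()["state_index"])
    cur = pick_subset_sum(powers, NZEROS)
    while len(cur) > 1:
        batch = pairwise_add(cur)   # gửi nhiều lệnh add theo cặp
        cur = recv_indices(batch)   # thu các chỉ số state mới
    send({"command":"eval","function":"square","indices":[cur[0]]})
    return recv()["state_index"]
```

2) So sánh hai nhánh bằng cặp đối xứng. Tạo hai bản mã đối xứng theo `b`: một nhánh là giá trị lớn `hval`, nhánh kia là giá trị nhỏ `SMALL`. Sau đó nhân mỗi bản với amplifier.
   - Plaintext của tích ở cả hai nhánh đều là 0.
   - Gọi bản mã là `m + e₁` và amplifier là `0 + e₂`. Bỏ qua noise phát sinh thêm từ phép nhân, phần lỗi của tích xấp xỉ `m·e₂ + e₁·e₂`.
   - Do đó bản thật sự chứa `hval` sẽ giải mã ra giá trị lệch khỏi 0 nhiều hơn hẳn bản chứa `SMALL`.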
```py
e1 = encrypt(hval, SMALL)   # if b=0 → the large value is what actually gets encrypted (left branch)
e2 = encrypt(SMALL, hval)   # if b=1 → the large value is what actually gets encrypted (right branch)
m1 = eval_mul(e1, amp_idx)  # plaintext = 0, but the noise depends on e1
m2 = eval_mul(e2, amp_idx)  # plaintext = 0, but the noise depends on e2
# Bit `position` is readable because the bit-guard passes (plaintext is 0 in both branches)
b1 = decrypt(m1, position)
b2 = decrypt(m2, position)
vote = sign(b1 - b2)        # b1 tends to be larger if b=0, and b2 if b=1
```

3) Bỏ phiếu với nhiều `hval` và `position`: chọn các `hval ∈ {1e5,1e6,1e7,1e8}` và `position ∈ {20,24}` để ổn định thống kê; quyết định `b` theo tổng phiếu.

```py
PROBES = [1e5, 1e6, 1e7, 1e8]
SMALL = 1e-3
TEST_POSITIONS = [20, 24]

best_guess = 0
for pos in TEST_POSITIONS:
    votes = 0
    for hval in PROBES:
        e1 = encrypt(hval, SMALL)
        e2 = encrypt(SMALL, hval)
        m1 = eval_mul(e1, amp_idx)
        m2 = eval_mul(e2, amp_idx)
        b1 = decrypt(m1, pos)
        b2 = decrypt(m2, pos)
        votes += 1 if b1 > b2 else (-1 if b1 < b2 else 0)
    if abs(votes) improves best:
        best_guess = 0 if votes >= 0 else 1
send({"command":"guess","bit":best_guess})
```

4) Lặp 100 round để đạt `WIN` và lấy flag

<details>
<summary>solve.py</summary>

```python
import socket, json

HOST = "one-more-bit.challs.m0lecon.it"
PORT = 24180
NZEROS = 120            # number of Enc(0) copies summed into the amplifier
PROBES = [1e5, 1e6, 1e7, 1e8]
SMALL = 0.001
TEST_POSITIONS = [20, 24]
READ_TIMEOUT = 8
RECV_CHUNK = 4096


def _recv_line(sock, buf):
    """Return (next newline-terminated line, stripped; buf), reading more as needed.

    `buf` is a bytearray that is consumed/extended in place.
    """
    while True:
        nl = buf.find(b"\n")
        if nl != -1:
            line = buf[:nl+1]
            del buf[:nl+1]
            return line.decode().strip(), buf
        try:
            chunk = sock.recv(RECV_CHUNK)
        except socket.timeout:
            raise TimeoutError("recv timeout")
        if not chunk:
            raise EOFError("connection closed")
        buf += chunk


def recv_json(sock, buf):
    """Read one line and parse it as JSON; non-JSON lines become a "bad_json" dict."""
    s, buf = _recv_line(sock, buf)
    try:
        return json.loads(s), buf
    except Exception:
        return {"status": "bad_json", "raw": s}, buf


def recv_n_json(sock, buf, n):
    """Read n JSON responses in order (used after pipelining n requests)."""
    out = []
    for _ in range(n):
        r, buf = recv_json(sock, buf)
        out.append(r)
    return out, buf


def build_amplifier(sock, buf):
    """Return (state index of (sum of NZEROS copies of Enc(0))^2, buf)."""
    sock.sendall((json.dumps({"command": "encrypt", "m0": 0.0, "m1": 0.0}, separators=(",", ":")) + "\n").encode())
    r0, buf = recv_json(sock, buf)
    idx0 = r0["state_index"]
    # powers[i] is the sum of 2^i copies of Enc(0), built by repeated doubling.
    powers = [idx0]
    for _ in range(NZEROS.bit_length() - 1):
        sock.sendall((json.dumps({"command": "eval", "function": "add", "indices": [powers[-1], powers[-1]]}, separators=(",", ":")) + "\n").encode())
        r, buf = recv_json(sock, buf)
        powers.append(r["state_index"])
    # Pick the powers whose sizes add up to NZEROS (binary decomposition).
    add_list, need = [], NZEROS
    for i in range(len(powers) - 1, -1, -1):
        v = 1 << i
        if v <= need:
            add_list.append(powers[i])
            need -= v
        if need == 0:
            break
    # Tree-reduce with pairwise adds: send one level of commands, then read
    # all answers in order; an odd leftover index is carried to the next level.
    cur = add_list[:]
    while len(cur) > 1:
        cmds, nxt = [], []
        it = iter(cur)
        for a in it:
            try:
                b = next(it)
                cmds.append({"command": "eval", "function": "add", "indices": [a, b]})
            except StopIteration:
                nxt.append(a)
        for c in cmds:
            sock.sendall((json.dumps(c, separators=(",", ":")) + "\n").encode())
        if cmds:
            rs, buf = recv_n_json(sock, buf, len(cmds))
            nxt.extend(r["state_index"] for r in rs)
        cur = nxt
    sum_idx = cur[0]
    sock.sendall((json.dumps({"command": "eval", "function": "square", "indices": [sum_idx]}, separators=(",", ":")) + "\n").encode())
    r, buf = recv_json(sock, buf)
    return r["state_index"], buf


def main():
    sock = socket.create_connection((HOST, PORT))
    sock.settimeout(READ_TIMEOUT)
    buf = bytearray()
    # Median probe first, then the rest in ascending order
    # (for the default PROBES: [1e7, 1e5, 1e6, 1e8]).
    probe_order = sorted(PROBES, key=lambda x: abs(x))
    mid = len(probe_order) // 2
    PROBES_ORDERED = [probe_order[mid]] + probe_order[:mid] + probe_order[mid+1:]
    POS_ORDERED = TEST_POSITIONS[:]
    while True:
        msg, buf = recv_json(sock, buf)
        if msg.get("status") == "new_round":
            amp_idx, buf = build_amplifier(sock, buf)
            best_pos = None  # NOTE(review): recorded but never used
            best_score = 0
            best_guess = 0
            for pos in POS_ORDERED:
                votes = 0
                for i, hval in enumerate(PROBES_ORDERED):
                    sock.sendall((json.dumps({"command": "encrypt", "m0": hval, "m1": SMALL}, separators=(",", ":")) + "\n").encode())
                    sock.sendall((json.dumps({"command": "encrypt", "m0": SMALL, "m1": hval}, separators=(",", ":")) + "\n").encode())
                    rs, buf = recv_n_json(sock, buf, 2)
                    e1, e2 = rs[0]["state_index"], rs[1]["state_index"]
                    sock.sendall((json.dumps({"command": "eval", "function": "mul", "indices": [e1, amp_idx]}, separators=(",", ":")) + "\n").encode())
                    sock.sendall((json.dumps({"command": "eval", "function": "mul", "indices": [e2, amp_idx]}, separators=(",", ":")) + "\n").encode())
                    rs, buf = recv_n_json(sock, buf, 2)
                    m1, m2 = rs[0]["state_index"], rs[1]["state_index"]
                    sock.sendall((json.dumps({"command": "decrypt", "index": m1, "position": pos}, separators=(",", ":")) + "\n").encode())
                    sock.sendall((json.dumps({"command": "decrypt", "index": m2, "position": pos}, separators=(",", ":")) + "\n").encode())
                    rs, buf = recv_n_json(sock, buf, 2)
                    # A "denied" answer is treated as a missing bit.
                    b1 = rs[0]["bit"] if rs[0].get("status") == "ok" else None
                    b2 = rs[1]["bit"] if rs[1].get("status") == "ok" else None
                    if b1 is None and b2 is None:
                        vote = 0
                    elif b1 is not None and b2 is None:
                        vote = 1
                    elif b1 is None and b2 is not None:
                        vote = -1
                    else:
                        vote = 1 if b1 > b2 else (-1 if b1 < b2 else 0)
                    votes += vote
                    # Stop early once the remaining probes can no longer flip the sign.
                    if abs(votes) > (len(PROBES_ORDERED) - (i + 1)):
                        break
                if abs(votes) > abs(best_score):
                    best_score = votes
                    best_guess = 0 if votes >= 0 else 1
                    best_pos = pos
                # NOTE(review): each probe votes -1/0/+1, so abs(best_score) can never
                # exceed len(PROBES_ORDERED); this break never fires and every
                # position in POS_ORDERED is always tried.
                if abs(best_score) > len(PROBES_ORDERED):
                    break
            sock.sendall((json.dumps({"command": "guess", "bit": best_guess}, separators=(",", ":")) + "\n").encode())
            resp, buf = recv_json(sock, buf)
            if resp.get("result") == "LOSE":
                break
            if "flag" in resp:
                print(resp["flag"])
                break
        elif "flag" in msg:
            # After the last WIN the server sends {"status": "ok", "flag": ...}.
            print(msg["flag"])
            break


if __name__ == "__main__":
    main()
```

</details>

Vì bài này có timeout mà trên máy tính lại không đủ nhanh để chạy nên chúng ta có thể đưa code lên kaggle để chạy.

![image](https://hackmd.io/_uploads/rkkCdB1lbl.png)