# AIS3 Final Project
## reverse
### Prufer's backpack on the tree
***預期難度:4***
這題主要想考驗的是看懂算法邏輯、用 script 做靜態分析,然後最後藏了一個小梗。在 symbol 上幾乎沒有做混淆。
prufer 序列是一種長度為 n - 2 且可以和 n 個節點的無根樹一一對應的序列。
https://oi-wiki.org/graph/prufer/
本題會輸入一個序列,然後程式會將這個序列轉成一顆樹,通過驗證後會用這個序列當作 key 來解密 flag。每個節點會有一個權重,這些權重極大,我們將正確的那顆樹以某個點作為根節點去計算每個節點的子樹權重總和。驗證的方式是我們會以相同的方式去處理以使用者輸入所生出來的樹,並且比對每個子樹的權重總和是否相同,如果都相同,那就通過。
這題主要是要還原出這棵樹的結構,然後跟著題目提供的文章去計算這棵樹對應的 prufer 序列。
首先,點權重的部份可以直接從 ida 中獲得。

然後,我們會發現在比對總和的時候,它每此都會呼叫一個函式去獲得值,然後這些函數的指標儲存在一個陣列中,可以發現函數的名稱就是 func{i} , i 是他的編號。所以我們接下來可以寫個 script 把函數全部輸出,然後再把總和的陣列還原出來。我用的是 pyhidra ,他是可以拿來寫 ghidra script 的 library。
```python
import os
import sys
import pyhidra
pyhidra.start()
from ghidra.app.decompiler.flatapi import FlatDecompilerAPI # type: ignore
code = ""
with pyhidra.open_program("./problem") as f:
#program = f.getCurrentProgram()
decomp = FlatDecompilerAPI(f)
listing = f.currentProgram.getListing()
program = f.getCurrentProgram()
function_manager = program.getFunctionManager()
functions = function_manager.getFunctions(True)
for function in functions:
code += decomp.decompile(function) + "\n"
print(f"Function: {function.getName()} at {function.getEntryPoint()}")
with open("output.cpp", "w") as f:
f.write(code)
```
```python
import re
with open("output.cpp", "r") as f:
code = f.read()
s = []
funcs = re.findall("func[0-9]*\(void\)[^}]*}", code)
for func in funcs:
idx = int(re.findall("func[0-9]*", func)[0][4:])
num = int(re.findall("return .*;", func)[0][7:-1], 16)
s.append((idx, num))
s.sort(key=lambda x: x[0])
sum = [j for i, j in s]
print(sum)
```
接下來要來還原這棵樹,我們會發現對於每個節點來說,找出他的子節點們是一個背包問題。對於葉節點來說,他的子樹權重和等於他自己的權重。對於其他節點來說,他的子樹權重和減掉他自己的權重會等於他的兒子們的子樹權重和的總和,以下用圖舉個小例子:

所以我們從葉節點開始往上解背包問題,然後就可以把整棵樹的結構解出來,最後轉成 prufer 序列。
```python
from sage.all import *
s = [827903036302539, 472506890491747, 748878441336022, 218289997734629, 158305400769101, 319811412484931, 916559854564748, 810184087203501, 1378970505927300, 5081228659264937, 49320173279561551, 611804439437144, 2254887631289413, 225368244774651, 496125585468781, 275952636545815, 908132089593552, 617046982684921, 147361159155314, 852695215362317, 856166548379750, 450776975258791, 824520084293885, 232266479339872, 264566941549413, 219446900243417, 654463779922936, 983710647021443, 628250169887674, 311111075214197, 976228475815369, 980212380102852, 983388806206647, 410454347210760, 3238997673430209, 710579998810410, 1341363969211692, 4579366976959670, 2046419656267150, 38892728232323038, 2696966885346286, 2248961027542928, 1255077556777535, 1566108176370958, 461638662476362, 1580421546484908, 1588084349040908, 1713631369566437, 8611585654392639, 639262370960124, 148418722850372, 21615812440617695, 9017644598167856, 1414502452176552, 20391042833543260, 51737354874956465, 5622146568732780, 6086519610209602, 1592255557457413, 23631050752643070, 956393029684018, 1116007352675956, 878330466285353, 7408652903370597, 3876098743739887, 2002507787223168, 473909748617744, 4858915182013873, 1873896584892917, 1605932082606902, 7128112219353120, 7641608934376784, 2332166837953268, 1575336715139383, 2398425776080207, 193951491521424, 2971272746596356, 3688476757225245, 2685220492643573, 1585175750964442, 873294262599488, 38701295049730901, 35124003417118009, 39684564358251645, 908408128837443, 879933949988774, 215060981748532, 609415787346519, 2902832330017448, 4523493340065506, 371967205023147, 1467988002728568, 949422451883961, 315729027727047, 569802212803973, 308844336846535, 57666678886609864, 6261227943751375, 943597270197070, 1062076113313947]
w = [827903036302539, 472506890491747, 748878441336022, 218289997734629, 158305400769101, 319811412484931, 916559854564748, 810184087203501, 123892949149765, 864062640230589, 617964323142050, 611804439437144, 380991046396496, 225368244774651, 496125585468781, 275952636545815, 908132089593552, 617046982684921, 147361159155314, 852695215362317, 856166548379750, 450776975258791, 824520084293885, 232266479339872, 264566941549413, 219446900243417, 654463779922936, 983710647021443, 628250169887674, 311111075214197, 976228475815369, 980212380102852, 983388806206647, 178187867870888, 990036645887281, 710579998810410, 361151589108840, 832697686829850, 440487573660248, 191433182592137, 137919523185460, 284865835514268, 506199115441513, 393601548523710, 461638662476362, 968617107047764, 472076996364952, 529270604183179, 833390109781416, 445310879438700, 148418722850372, 814315259863675, 406058943775217, 931645512892510, 122058049343973, 196675868049827, 182093374102526, 464373041476822, 217799505703279, 758928958380330, 956393029684018, 159614322991938, 878330466285353, 280540684017477, 656218821278059, 911549357870784, 473909748617744, 987620825145520, 924474133008956, 642623965837431, 866884275601745, 160357925270261, 352885894531503, 196366209212083, 972457014896484, 193951491521424, 924853090329206, 413677222184650, 877689243359248, 975759963617923, 873294262599488, 892071139969319, 612346056667946, 791836125928607, 693347147088911, 879933949988774, 215060981748532, 609415787346519, 875085491170128, 835016582840261, 371967205023147, 994078254110824, 949422451883961, 315729027727047, 569802212803973, 308844336846535, 664690610141337, 174708333541773, 943597270197070, 351496114503537]
n = 100
v = [[] for i in range(n)]
def backpack(arr, total):
l = len(arr)
mat = (identity_matrix(l) * 2).augment(matrix(arr).T)
mat = mat.stack(ones_matrix(1, l).augment(matrix([total])))
mat = mat.LLL()
k = 0
for i in mat:
if i[-1] == 0 and len(set(i[:-1])) <= 2:
k = list(i[:-1])
break
if k == 0:
return 0
k = list(map(lambda x: 0 if x == 1 else 1, k))
if sum(arr[idx] * k[idx] for idx in range(len(k))) == total:
return k
k = list(map(lambda x: (x + 1) % 2, k))
if sum(arr[idx] * k[idx] for idx in range(len(k))) == total:
return k
return 0
notused = set(range(n))
leaf = []
newleaf = []
for i in range(n):
if w[i] == s[i]:
leaf.append(i)
notused.remove(i)
p = [0 for i in range(n + 1)]
d = [0 for i in range(n + 1)]
newleaf = set(leaf)
while len(notused) > 0:
print(f"{notused=}")
print(f"{leaf=}")
for i in notused:
k = backpack([s[j] for j in leaf], s[i] - w[i])
if k != 0:
for idx in range(len(k)):
if k[idx] == 1:
p[leaf[idx] + 1] = i + 1
d[leaf[idx] + 1] += 1
d[i + 1] += 1
newleaf.remove(leaf[idx])
newleaf.add(i)
leaf = list(newleaf)
for i in newleaf:
if i in notused:
notused.remove(i)
seq = []
now = 1
while d[now] != 1:
now += 1
for i in range(n - 2):
seq.append(p[now])
d[p[now]] -= 1
d[now] -= 1
if d[p[now]] == 1 and p[now] < now:
now = p[now]
else:
now += 1
while d[now] != 1:
now += 1
print(" ".join(list(map(str, seq))))
```
然後把它們以空格隔開輸入到執行檔中就可以得到 flag : ||FLAG{Y0u_H3lP_Pruf3R_a_l0t!!!!!}||
## crypto
### zkp-revenge-revenge-revenge (Adapted from AIS3 pre-exam "zkp-revenge")
***預期難度:3***
這題修改自今年的 AIS3 pre-exam 中的 zkp-revenge。
```python
import random
from secret import flag
from Crypto.Util.number import bytes_to_long, getPrime, isPrime, getStrongPrime
def getStrongStrongPrime(n: int):
prime = getStrongPrime(n)
for i in range(50):
prime = max(prime, getStrongPrime(n))
return prime
def zkp_protocol(p, g, sk):
# y = pow(g, sk, p)
r = random.randrange(1, 1 << 1024)
a = pow(g, r, p)
print(f'a = {a}')
print('Give me the challenge')
try:
c = int(input('c = '))
if (p - (1<<600)) >= c >= (1 << 600):
w = (c * sk + r) % (p-1)
print(f'w = {w}')
# you can verify I know the flag with
# g^w (mod p) = (g^flag)^c * g^r (mod p) = y^c * a (mod p)
else:
print("Sorry, but your 'c' is too small.")
print("My 'x' is already too small, so I cannot let you choose such a small number!")
raise ValueError()
except:
print('Invalid input.')
if __name__ == "__main__":
n = 179637947143411747317382401131782052611269076387204917234735876474493709963329273116047020556411954632471433561628457015268318063906205354262895842115659827639177298461010941729174078197874742916562629141952049438894591570918472233933980559006374998517736641088753467961471766948061104063699338592342702310377
# generated by n = getStrongStrongPrime(1024)
e = 5
y = pow(e, bytes_to_long(flag), n)
print("""
******************************************************
Have you heard of Zero Knowledge Proof? I cannot give
you the flag, but I want to show you I know the flag.
So, let me show you with ZKP.
------------------------------------------------------
1) Printe public key.
2) Run ZKP protocol.
3) Bye~
******************************************************
""")
for _ in range(23):
try:
option = input("Option: ")
if int(option) == 1:
print('My public key:')
print(f'p = {n}')
print(f'g = {e}')
print(f'y = {y}')
elif int(option) == 2:
zkp_protocol(n, e, bytes_to_long(flag))
else:
print("Bye~~~~~")
break
except:
print("Something wrong?")
exit()
```
我改了兩個地方,第一個是我生成的質數變成生成 50 個質數後取最大的,第二個是隨機產生的 r 的隨機範圍從 [0, 1<<1000) 變成 [1, 1<<1024)。
從表面上來看它變困難了,原本能去拿到一些乾淨的 bits 或者是當作 HNP 去解,可是現在都不能做了。其實關鍵就在 [1, 1<<1024] 這個範圍。首先,如果 sk 是奇數(實際也是如此),那麼當 $c=\frac{p-1}{2}$ 的時候,$w=\frac{p-1}{2}+r\ (mod\ (p-1))$,所以 $r=w-\frac{p-1}{2}\ (mod(p-1))$。然後,然後因為 $p$ 和 $r$ 都是 1024 bits,然後 $p$ 又盡量取大,所以有非常大的機率是 $r<p-1$,所以我們能或的得 $r$ 確切的值。**
接下來考慮 [1, 1<<1024) 這個範圍,觀察一下 `random.randrange()` 的原始碼。
```python
def randrange(self, start, stop=None, step=_ONE):
...
width = istop - istart
...
# Fast path.
if istep == 1:
if width > 0:
return istart + self._randbelow(width)
...
```
它去呼叫了 `randbelow(width)` ,在這裡是 (1 << 1024) - 1。
接下來去看一下 `randbelow()`
```python
def _randbelow_with_getrandbits(self, n):
"Return a random int in the range [0,n). Returns 0 if n==0."
if not n:
return 0
getrandbits = self.getrandbits
k = n.bit_length() # don't use (n-1) here because n can be 1
r = getrandbits(k) # 0 <= r < 2**k
while r >= n:
r = getrandbits(k)
return r
```
以發現他是先去看他有多少個 bits ,接下來看他有沒有大於等於 (1 << 2024) - 1,如果大於等於的話就會把它丟棄掉再拿下一個,然而因為 (1 << 2024) - 1 全部都是 1 ,所以它絕對不會丟棄任何東西。
最後,他的 getrandbits(1024) 是把 32 個 getrandbits(32) 結合起來,而我們只需要 624 個 getrandbits(32) 就可以預測接下來所有內容了,以下直接寫 exploit。
```python
import random
from randcrack import RandCrack
from pwn import *
from Crypto.Util.number import *
r = remote("localhost", 12345)
r.sendlineafter(b"Option: ", b"1")
r.recvuntil(b"p = ")
p = int(r.recvline().decode().strip())
rc = RandCrack()
cnt = 0
for i in range(20):
r.sendlineafter(b"Option: ", b"2")
r.sendlineafter(b"c = ", str((p - 1) // 2).encode())
r.recvuntil(b"w = ")
w = int(r.recvline().decode().strip())
k = (w - (p - 1) // 2) % (p - 1) - 1
if cnt == 624:
print(k == rc.predict_getrandbits(1024))
break
for j in range(32):
if cnt == 624:
rc.predict_getrandbits(32)
continue
rc.submit(k % (2 ** 32))
k >>= 32
cnt += 1
r.sendlineafter(b"Option: ", b"2")
r.sendlineafter(b"c = ", str(1 + 2** 600).encode())
r.recvuntil(b"w = ")
w = int(r.recvline().decode().strip())
sc = (w - rc.predict_getrandbits(1024) - 1) * pow(1 + 2 ** 600, -1, p - 1) % (p - 1)
print(long_to_bytes(sc))
```
然後就可以得到 flag : ||FLAG{yeah_python_random_is_the_best}||