{%hackmd @themes/orangeheart %}
# KMACTF 2024
## EzRSA
Source chall của bài đây.
```python=
import base64
import hashlib
from Crypto.PublicKey import RSA
from Crypto.Util.number import *
import random
FLAG = "KMACTF{s0m3_r3ad4ble_5tr1ng_like_7his}"  # local placeholder; the remote server's flag differs
def verifySignature(public_key, msg: bytes, sig_bytes: bytes, key_size: int, prefix: bytes, namePrefix):
    """Check an RSA PKCS#1 v1.5-style signature of ``msg`` (challenge code, quoted as-is).

    Recovers ``em = sig^e mod n``, left-padded with zero bytes to
    ``key_size // 8`` bytes, and parses it as
    ``00 01 || ff* || 00 || prefix || H(msg)``.

    Args:
        public_key: RSA key exposing ``.e`` and ``.n``.
        msg: message whose digest must appear in the signature.
        sig_bytes: big-endian signature bytes.
        key_size: bit length of the modulus; ``em`` is ``key_size // 8`` bytes.
        prefix: DER DigestInfo prefix for the chosen hash.
        namePrefix: hash name used to select the digest.

    Returns:
        1 if the signature is accepted, 0 otherwise.

    Weakness: any number of 0xff bytes (even none) is accepted, and nothing
    checks that the digest ends at the last byte of ``em`` -- trailing bytes
    are ignored. With e = 3 this allows a cube-root signature forgery.
    """
    # Every supported digest is computed up front; namePrefix selects one.
    HASH_FUNC = {
        'MD5': hashlib.md5(msg).digest(),
        'SHA-1': hashlib.sha1(msg).digest(),
        'SHA-256': hashlib.sha256(msg).digest(),
        'SHA-384': hashlib.sha384(msg).digest(),
        'SHA-512': hashlib.sha512(msg).digest()
    }
    message_sum = HASH_FUNC[namePrefix]
    # Textbook RSA verification: m = sig^e mod n.
    c = bytes_to_long(sig_bytes)
    m = long_to_bytes(pow(c, public_key.e, public_key.n))
    # long_to_bytes drops leading zeros; right-align m in a zeroed buffer.
    em = bytearray(key_size//8)
    em[key_size//8-len(m):] = m
    em = bytes(em)
    i = 0
    # Header must be 0x00 0x01.
    if em[i] != 0 or em[i+1] != 1:
        return 0
    i = i + 2
    # Skip a run of 0xff padding bytes of any length, including zero.
    while i < key_size//8 and em[i] == 0xff:
        i += 1
    # Padding must be terminated by a 0x00 separator.
    if em[i] != 0:
        return 0
    i += 1
    # DigestInfo prefix for the selected hash.
    if em[i:i+len(prefix)] != prefix:
        return 0
    i = i + len(prefix)
    # Digest of msg; bytes after it are never examined.
    if em[i:i+len(message_sum)] != message_sum:
        return 0
    return 1
def main():
    """Run one round of the EzRSA challenge (challenge code, quoted as-is).

    Picks a random hash, builds a fresh RSA public key with e = 3 from two
    1024-bit primes, prints the hash name and the modulus, then reads a
    message and a base64 signature and prints FLAG if verifySignature
    accepts them.
    """
    # DER DigestInfo prefixes expected between the 0x00 separator and the digest.
    HASH_ASN1 = {
        'MD5': b'\x30\x20\x30\x0c\x06\x08\x2a\x86\x48\x86\xf7\x0d\x02\x05\x05\x00\x04\x10',
        'SHA-1': b'\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14',
        'SHA-256': b'\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20',
        'SHA-384': b'\x30\x41\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x02\x05\x00\x04\x30',
        'SHA-512': b'\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40'
    }
    namePrefix = random.choice(list(HASH_ASN1.keys()))
    # The product of two 1024-bit primes has 2047 or 2048 bits; key_size is
    # its exact bit length and fixes len(em) = key_size // 8 in the verifier.
    p, q = getPrime(1024), getPrime(1024)
    modulus = p*q
    key_size = modulus.bit_length()
    # Public exponent 3: small enough for an integer cube-root forgery.
    public_key = RSA.construct((modulus, 3))
    print("Hash:", namePrefix)
    print("Modulus =", modulus)
    try:
        message = input("Enter the message you want to verify: ")
        signature = base64.b64decode(input("Enter its base64 signature: "))
    except:
        # Bare except: any read or base64 decoding error ends the session.
        print("Invalid input!")
        exit()
    err = verifySignature(public_key, message.encode(), signature, key_size, HASH_ASN1[namePrefix], namePrefix)
    # Despite its name, err is 1 when the signature is accepted.
    if err:
        print("Well done! This is your flag:", FLAG)
        exit()
    else:
        print("Not good enough! Try harder! -.-")
        exit()
if __name__ == "__main__":
    main()
```
Ta cùng phân tích hàm verify của bài.
```python=
# Excerpt of verifySignature: digest of msg for the server-chosen hash.
HASH_FUNC = {
    'MD5': hashlib.md5(msg).digest(),
    'SHA-1': hashlib.sha1(msg).digest(),
    'SHA-256': hashlib.sha256(msg).digest(),
    'SHA-384': hashlib.sha384(msg).digest(),
    'SHA-512': hashlib.sha512(msg).digest()
}
message_sum = HASH_FUNC[namePrefix]
# Textbook RSA verification: m = sig^e mod n.
c = bytes_to_long(sig_bytes)
m = long_to_bytes(pow(c, public_key.e, public_key.n))
# Right-align m in a key_size//8-byte zero buffer (restores leading zero bytes).
em = bytearray(key_size//8)
em[key_size//8-len(m):] = m
em = bytes(em)
```
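Ta thấy ``message_sum = hash(msg)``, ``c`` là signature mà mình nhập, và ``m = c^e mod n``. Thường thì key_size = 2048 (đôi khi là 2047, vì tích của hai số nguyên tố 1024 bit có thể chỉ có 2047 bit), thế nên ``em = b'\x00' * 256``. Sau đó ``em[key_size//8-len(m):] = m``, tức là ``m`` được đệm thêm các byte 0 ở đầu cho đủ 256 byte.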
Ta chuẩn bị 1 giá trị ``send = b''``
```python=
i = 0
# em must start with the two bytes 0x00 0x01.
if em[i] != 0 or em[i+1] != 1:
    return 0
```
Từ đó ta thấy byte đầu tiên và byte thứ hai của ``em`` phải lần lượt là 0 và 1, thế nên ``send = b'\x00\x01'``.
```python=
i = i + 2
# Skip a run of 0xff bytes of any length (zero is allowed) ...
while i < key_size//8 and em[i] == 0xff:
    i += 1
# ... which must be followed by a 0x00 separator.
if em[i] != 0:
    return 0
```
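Từ đó, mình sẽ thêm vào giá trị ``send = b'\x00\x01\xff\xff\x00'``. Thực ra số byte ``\xff`` là tuỳ ý, kể cả không có byte nào, vì vòng ``while`` chỉ bỏ qua các byte 0xff rồi yêu cầu byte tiếp theo là 0x00. Phần cố định ở đầu càng ngắn thì phép lấy căn bậc 3 phía sau càng dễ giữ nguyên nó (điều này quan trọng khi hash là SHA-512).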
Tiếp theo ta thấy rằng
```python=
# Next must come the DigestInfo prefix for the chosen hash ...
if em[i:i+len(prefix)] != prefix:
    return 0
i = i + len(prefix)
# ... then the digest itself; nothing after it is ever checked.
if em[i:i+len(message_sum)] != message_sum:
    return 0
```
Hai điều kiện sau là HASH_ASN1 và hash của message mình gửi vào, thế nên giá trị ``send = b'\x00\x01\xff\xff\x00' + ASN1 + message_hash``.
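Giờ mình sẽ pad thêm các byte ``\xff`` vào cuối ``send`` cho đủ độ dài của ``em``, rồi lấy căn bậc 3 (làm tròn xuống) của số này làm chữ ký. Vì e = 3 và số này nhỏ hơn n nên ``sig^3`` không bị rút gọn mod n. Hơn nữa ``sig^3`` chỉ nhỏ hơn ``send`` một lượng không quá khoảng ``3*sig^2``, nên chỉ làm thay đổi phần byte rác ở cuối (miễn là phần rác đủ dài), mà hàm verify lại không hề kiểm tra phần nằm sau hash. Vì vậy khi ^3 lên ở hàm verify, các byte đầu của ``send`` vẫn được giữ nguyên.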
```python3=
from gmpy2 import iroot
from Crypto.Util.number import*
import hashlib
import base64
# DER-encoded DigestInfo prefixes (AlgorithmIdentifier + OCTET STRING header)
# that precede the raw digest in a PKCS#1 v1.5 signature block. Copied from
# the challenge source so the forged block matches what verifySignature expects.
HASH_ASN1 = {
    'MD5': b'\x30\x20\x30\x0c\x06\x08\x2a\x86\x48\x86\xf7\x0d\x02\x05\x05\x00\x04\x10',
    'SHA-1': b'\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14',
    'SHA-256': b'\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20',
    'SHA-384': b'\x30\x41\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x02\x05\x00\x04\x30',
    'SHA-512': b'\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40'
}
def new_signature(message, n, e, h):
    """Forge a signature on ``message`` that the challenge's verifySignature accepts.

    Builds ``00 01 00 || DigestInfo(h) || H(message)`` (zero 0xff bytes, which
    the verifier's padding loop allows). The rest of the block, up to the
    verifier's ``em`` length, is filled with 0xff. The function returns the
    integer e-th root of that block, rounded down. Raising the root back to
    the e-th power only disturbs the low-order filler, which the verifier
    never inspects; this is checked before returning.

    Args:
        message: bytes to "sign".
        n: RSA modulus printed by the server.
        e: public exponent (3 in this challenge).
        h: hash name printed by the server; a key of HASH_ASN1.

    Returns:
        The forged signature as big-endian bytes (base64-encode before sending).

    Raises:
        KeyError: unknown hash name.
        ValueError: the modulus is too small for the fixed header to survive
            the rounding of the e-th root.
    """
    digests = {
        'MD5': hashlib.md5,
        'SHA-1': hashlib.sha1,
        'SHA-256': hashlib.sha256,
        'SHA-384': hashlib.sha384,
        'SHA-512': hashlib.sha512,
    }
    # Hash only with the algorithm the server chose.
    message_hash = digests[h](message).digest()
    # verifySignature sizes em as n.bit_length() // 8, i.e. rounded DOWN.
    # Rounding up breaks the forgery whenever n has 2047 bits: em is then
    # 255 bytes and a 256-byte block would start with 0x01 instead of 0x00.
    em_len = n.bit_length() // 8
    # No 0xff bytes in the header: each one costs a byte of filler. With a
    # b'\x00\x01\xff\xff\x00' header, SHA-512 leaves only 168 filler bytes
    # (1344 bits), less than the ~3*root**2 (about 2**1357) rounding error of
    # the cube root, so the digest itself would get corrupted.
    head = b'\x00\x01\x00' + HASH_ASN1[h] + message_hash
    filler_len = em_len - len(head)
    if filler_len <= 0:
        raise ValueError("modulus too small for this hash")
    # All-ones filler: root**e <= block, and subtracting less than
    # 2**(8*filler_len) from an all-ones tail never borrows into the header.
    block = head + filler_len * b'\xff'
    new_sign = int(iroot(bytes_to_long(block), e)[0])
    if (new_sign ** e) >> (8 * filler_len) != bytes_to_long(head):
        raise ValueError("e-th root forgery does not preserve the header; modulus too small")
    return long_to_bytes(new_sign)
# Session values: the modulus printed by the server and the hash it announced.
n = 18039602488649683626669737663806860962971652662290679291844248738107765586091088472139974176635665475859569301149035536461272371178315725766197531007359686884746487459310286890453402574981484886813503743557447149217612997372814452433061858599480584164081813477944519468342334704562569381023062471665238036972802001875590298208847646904455700395656757723848953392700473073304863464215448194064126393413594828574470176251015882263516118825722370206505755684875041229214564482263412342153753797258399997994855405011639212319296245796906277059585505616097014632091029166976556328158124015021263754429466589922984292303859
e, h = 3, "SHA-256"
inp = b'hello'
out = new_signature(inp, n, e, h)
# Paste this base64 string at the "Enter its base64 signature" prompt.
encoded_sig = base64.b64encode(out)
print(encoded_sig)
```

**Flag: KMACTF{W0w!!_Y0u're_s0_g00d_4t_Bl3ich3nb4ch3r}**
## Encrypt Message System
Source chall của bài này đây.
```python=
import secrets
import hashlib
from Crypto.Cipher import ChaCha20_Poly1305
from Crypto.Util.number import getPrime
import json
flag = b'KMACTF{******************************}'  # redacted in the handout
l = 16  # number of random coefficients per polynomial (degree l - 1)
key = secrets.token_bytes(32)  # a single ChaCha20-Poly1305 key shared by every enc/dec call
def enc(cmt, nonce):
    """Encrypt ``cmt`` with ChaCha20-Poly1305 and return ``nonce || ct || tag``.

    The 12-byte nonce is derived deterministically as ``sha256(nonce)[:12]``.
    The menu passes ``str(y).encode()`` for a polynomial value ``y``, so two
    encryptions that produce the same ``y`` reuse the same (key, nonce) pair.
    """
    nonce = hashlib.sha256(nonce).digest()[:12]
    cipher = ChaCha20_Poly1305.new(key=key, nonce=nonce)
    ct, tag = cipher.encrypt_and_digest(cmt)
    return nonce + ct + tag
def dec(cmt):
    """Decrypt ``nonce(12) || ct || tag(16)`` with the global key and return the plaintext.

    ``decrypt_and_verify`` raises when the tag does not authenticate; the
    menu's option 2 catches any exception and prints "you failed".
    """
    nonce = cmt[:12]
    ct = cmt[12:-16]
    tag = cmt[-16:]
    cipher = ChaCha20_Poly1305.new(key=key, nonce=nonce)
    pt = cipher.decrypt_and_verify(ct, tag)
    return pt
def polynomial_evaluation(coefficients, x):
    """Return ``sum(coefficients[i] * x**i) mod p``.

    ``p`` is the module-level prime assigned later in the script, before the
    menu loop runs. ``x`` is the integer typed by the user and is not reduced
    first. The result only feeds the nonce derivation in ``enc``.
    """
    at_x = 0
    for i in range(len(coefficients)):
        at_x += coefficients[i] * (x ** i)
    at_x = at_x % p
    return at_x
def verify(enc_message):
    """Return True iff ``enc_message`` decrypts and authenticates to b'give me the flag'.

    Exceptions from ``dec`` (e.g. a bad tag) propagate to the caller.
    """
    message = dec(enc_message)
    if message == b'give me the flag':
        return True
    return False
print("Welcome to the encrypt message system\n")
print("1. Encrypt message")
print("2. Get flag")
# Fresh 512-bit prime per run; polynomial_evaluation reduces modulo it.
p = getPrime(512)
print(f"p = {p}")
while True:
    try:
        option = input("Enter option: ")
        if option == "1":
            # A new random polynomial per request; its coefficients are printed,
            # so the client can pick x to hit any chosen value of f(x).
            coefficients = [secrets.randbelow(p) for _ in range(l)]
            print(f"coefficients = {coefficients}")
            message = input("Enter message: ")
            # Only the exact target plaintext is refused (and ends the session).
            if message == "give me the flag":
                print("Invalid message")
                break
            x = int(input("Enter x: "))
            y = polynomial_evaluation(coefficients, x)
            # The nonce depends only on y: equal y values mean a reused nonce.
            encrypted = enc(message.encode(), str(y).encode())
            print(encrypted.hex())
        if option == "2":
            encrypted = bytes.fromhex(input("Enter encrypted message: "))
            try:
                if verify(encrypted):
                    print(flag)
            except:
                # Raised when dec fails, e.g. on a tag mismatch.
                print("you failed")
                pass
    except:
        # Any other error (non-integer x, bad hex, ...) ends the session.
        print("Invalid input")
        break
```
Ta thấy rằng hàm dec và hàm enc hoàn toàn không có bug. Điểm yếu nằm ở hàm ``polynomial_evaluation``: giá trị y mà nó trả về được hash để tạo nonce, nên nếu hai lần mã hoá cho ra cùng một y thì nonce bị dùng lại với cùng key, và khi reuse key và nonce thì ta có cách tấn công. Trước hết, ta cùng nhìn hàm tạo nonce của bài.
```python=
def polynomial_evaluation(coefficients, x):
    """Return ``sum(coefficients[i] * x**i) mod p`` (quoted from the challenge).

    ``p`` is the challenge's module-level prime; ``x`` is chosen by the user.
    """
    at_x = 0
    for i in range(len(coefficients)):
        at_x += coefficients[i] * (x ** i)
    at_x = at_x % p
    return at_x
```
Ta thấy đây chính là tạo 1 poly f(x), và ta sẽ phải nhập x vào để thu được giá trị at_x.
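Giờ để reuse nonce, ta cần chọn $x_1, x_2$ sao cho $f(x_1) = g(x_2) = target$, với $f$, $g$ là hai đa thức ngẫu nhiên của hai lần mã hoá và $target$ là một hằng số tuỳ ý mình chọn. Nói cách khác, ta cần tìm nghiệm của $f(x) - target$ và $g(x) - target$ trên $GF(p)$.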
Code tìm nghiệm của đa thức để sao cho kết quả bằng target
```python=
def make_poly(a):
    """Render coefficients ``a`` (constant term first) as an expression string in ``x``.

    ``[1, 2, 3]`` becomes ``"1*x**0+2*x**1+3*x**2"``. Terms are joined with
    "+", so any number of coefficients works without a dangling or missing
    operator.
    """
    return "+".join(str(coef) + "*x**" + str(i) for i, coef in enumerate(a))
# Offline helper: given the prime p and two coefficient lists printed by the
# server, find x values for which both polynomials evaluate to `target`
# (equal values give equal nonces).
p = 10387061657389288245541908712832169493687848959161584257433977380101118700300321090649252887314758838439259285567095846220802337511455717550849895465079881
F = GF(p)
coefficients1 = [6654682310793149262150506873431961663742346170198295897986821650404858255245360588169636643111828505747449041145239535263182023857428570006867763747507337, 2367918927449961221869777558802540535845173404535413736391500754667350425069038259739127575040337360683592251449969113341346672572560616080045451587512553, 3926273836839220006748089884659366658306789030060334980375789751397177817011483513575246664432138637159467399052534785696988577783420649359291621990843104, 5124368819128639164151666376289616029288320009896800179083884351437914443654085616679182878204303407090501671673244421482517983346321095319219537910761933, 410846099728407022450763888901999910503856229385552066113830287255598112841706264812653171773326470778865943038515537706546914605445630853359520457700937, 4520600607856029716312428931359389267882833735547194327102514594389767785237592821524053392092738722769418284919654791902479760134586217864475571359006071, 6988355734455286553336263497212981441782783540566561504495513017917034445309008802046336031362469535491821449454969884574515763451299927662375744407284105, 2832885782488915275338242542270452650067997609346545153389324880570337033197683307512028460473920965854102972688278534322804366654719952211835063230730407, 3864978560967653765316809603277587938329227588551705970818990354174819194317234899083389941514922471362562680288288753153278490720538244995009139820910130, 5265220789417516732554676808534274989368344530152844732067877852421694910531108059192217355126487520096464003899608813499424710865092902281557476527964026, 4255800743938409760060221282948238081820602514163281983397717003497060844470197160489012520525609975130463238836945675547141734278435256244694171065496360, 4342907582207290202177221097475766722307464054060730241911671868267700102425384744538977201535170007588982445976762045029118127889980844923938342747726000,
197892564387870165570077685288731925326257941925460775279202298537108646723258217838517029568434908596965574609067145836372392950995472601013954507964270, 5686566032646754296714624556152641931785451362410647221448052806989054530130722370136620104747960195264748214321613997832679181785661535519529637342231194, 3747562287019834553463257214255383490681372613053700541821627623162096241622439803358428874374163345080430476967067362682374872353017522799651150563048053, 1919882435590715730979819065433034473858870135091446303496765073270207122401651173519308063211963356642819318494806640271128269177752261876145918206255915]
coefficients2 = [1806567027678622968840876715453065762276695867679327938322763656145112845885382539183790342961486746779659624125183350330095495239608348044378845212469993, 5900209277488951865293907321374356504714070524021780337342743106601942120295500298839347541096742440340465469345891832242544373197272303402785427872048010, 6229224578519764163341794687482713387718033321425564607227187870873786431013692773252528548807842620241533093927972617738219967361903986757027275777602307, 6637773630714954229409245351059610067729629844488451476296779456465336104326308752417374457369023597668325091786034022322013888607277695789117712037793896, 5180835468598786794506152596284851072887372019189086295450772490068965326864062675846865678525467182010333672577286298090069654347534197259600119358016008, 3816375300810249725631465752306167324125936996150358910901543147640067714032190859658213202036579180559503173029134362757965219599038226245168525453370868, 1957564144302841525112706777068006292322181703207961600497961245078138850769814585599191251594055519842262956628810545080280874870315617788353831174820981, 3326959815186505398718410132028826000205994934218422556961433245343852546768131873984151019104609984125582582706505161680466367093015269284075249481643646, 5777786584236261264114128896807554972554453712954420266767215735714858188837521430757127734343884789449231947289673670519106426031828844090044985357248902, 2449787134489387945928642811466692825632698483693550363202748708399589135627179284752781414442170449117130240387571203904924846334239304878692825973558874, 4858976628397792692875816446641977405441547264395899580345647059327045543183309323275500488128816859195881983921089754855451792514987719690818858480436941, 2600185321104962837265194294453537426753457453622464059948912076712038909101289311310052462179706344652227956413509730248979599950106094827786559909878587,
3423430297220618108589727870591445481042996738552267763150521698192144495861091171968285127931872038790778782890885839064003369787914450962358838841510679, 6905654673667466087516194334533207496269408282028075151599449638390621820110744850356015911905126009523112820409513791671960454908955270701189369570224066, 763743289514097983771735885593522172935616254165873100327860754594398535184358088581542600643329389893704920491366613360728960648664331632394845641857597, 2549323574830079743290529314712505088633034132772602268839925529781887929618672481146273237871950473421697144513876585797883752083988626587543850829810110]
target = 2975672310188785687385844603660016766206691140065536851146747876253244864625944443947863382270702107635053088907914576631657492317975064684414261425546899
R.<x> = PolynomialRing(F)
# R(list) builds sum(list[i] * x**i) -- the same ordering as the server's
# polynomial_evaluation -- without building and eval-ing a source string.
f = R(coefficients1)
g = R(coefficients2)
h = f - target
solutions = h.roots()
print("Solutions 1")
for sol, _ in solutions:
    print(sol)
h = g - target
solutions = h.roots()
print("Solutions 2")
for sol, _ in solutions:
    print(sol)
```
Giờ chạy file source, thay giá trị p và coeff1 vào script trên, rồi nhập hai message khác "give me the flag" (server từ chối đúng chuỗi này) nhưng có cùng độ dài: mình sẽ lấy "give me the fake" và "give me the girl".
Sau khi chạy, ta thu được 2 kết quả
```python=
# Hex output of two encryptions under the same nonce:
# nonce (12 bytes) | ciphertext (16 bytes) | tag (16 bytes).
out1 = "540fb957ff77acd995a0620656ca9bc9de2db329b0eeef78a0dc43f5c2d638eedc51bd49f06d7cde585694d3"
out2 = "540fb957ff77acd995a0620656ca9bc9de2db329b0eeef78a1d45afca385eccf8f778a9a5312124d404cfc5e"
```
Ta thấy hai kết quả có cùng 12 byte nonce đầu tiên (và cả 12 byte ciphertext tiếp theo, vì hai plaintext có chung tiền tố "give me the "), tức là nonce đã bị dùng lại. Giờ ta sẽ forge tag cho "give me the flag" để server dec ra được đúng chuỗi bài yêu cầu.
Có 2 cách, cách thứ nhất là dùng tool. Ta có tool như sau https://github.com/tl2cents/AEAD-Nonce-Reuse-Attacks/blob/main/chacha-poly1305/chacha_poly1305_forgery.py, và đây là code mình sửa để có thể dùng tool.
```python=
from chacha_poly1305_forgery import chachapoly1305_forgery_attack, chachapoly1305_forgery_attack_general
from chacha_poly1305_forgery import poly1305, sage_poly1305, recover_poly1305_key_from_nonce_reuse, chachapoly1305_nonce_reuse_attack, derive_poly1305_key
from sage.all import GF, ZZ, PolynomialRing
from chacha_poly1305_forgery import construct_poly1305_coeffs, forge_poly1305_tag, chachapoly1305_merger, poly1305
import secrets
from Crypto.Cipher import ChaCha20_Poly1305
def test_chachapoly1305_forgery_attack(general=True):
    """Forge ciphertext and tag for b"give me the flag" from two nonce-reusing encryptions.

    Captured outputs, split as nonce | ciphertext | tag:
    540fb957ff77acd995a06206 56ca9bc9de2db329b0eeef78a0dc43f5 c2d638eedc51bd49f06d7cde585694d3
    540fb957ff77acd995a06206 56ca9bc9de2db329b0eeef78a1d45afc a385eccf8f778a9a5312124d404cfc5e

    ``general`` is accepted but not used by this body. Prints the forged
    ciphertext and tag as hex.
    """
    known_pt = b'give me the fake'
    second_pt = b'give me the girl'  # kept for reference; the forgery only needs known_pt
    wanted_pt = b"give me the flag"
    empty_aad = b""
    ct_a = bytes.fromhex("56ca9bc9de2db329b0eeef78a0dc43f5")
    ct_b = bytes.fromhex("56ca9bc9de2db329b0eeef78a1d45afc")
    tag_a = bytes.fromhex("c2d638eedc51bd49f06d7cde585694d3")
    tag_b = bytes.fromhex("a385eccf8f778a9a5312124d404cfc5e")
    # Presumably the recovered Poly1305 key candidates; not used below.
    recovered_keys = chachapoly1305_nonce_reuse_attack(empty_aad, ct_a, tag_a, empty_aad, ct_b, tag_b)
    candidates = [*chachapoly1305_forgery_attack(empty_aad, ct_a, tag_a,
                                                  empty_aad, ct_b, tag_b,
                                                  known_pt,
                                                  wanted_pt, empty_aad)]
    first = candidates[0]
    print("Ciphertext =", first[0].hex())
    print("Tag =", first[1].hex())
test_chachapoly1305_forgery_attack(False)
```
Và giờ thì chỉ cần kết hợp nonce, ciphertext và tag là dec sẽ ra được "give me the flag".
Đó là cách dùng tool, giờ ta sẽ tới cách toán học he.
Mình tìm thấy một wu này: https://zenn.dev/kurenaif/articles/2a005936de308a#how-to-solve
Ta nhìn vào sơ đồ của loại mã hóa này:

Ta thấy Plaintext được XOR với Keystream để ra Ciphertext, sau đó Ciphertext được đưa vào hàm Poly1305 để tính tag. Công thức tính MAC như sau:
$$tag_1 = \left(\mathrm{Poly1305}_r(msg_1) + s\right) \bmod 2^{128} \tag{1}$$
$$tag_2 = \left(\mathrm{Poly1305}_r(msg_2) + s\right) \bmod 2^{128} \tag{2}$$
Trừ hai vế cho nhau
$$
\begin{aligned}
tag_1 - tag_2 \equiv\ & \left(\left(c_1^1 r^q + c_1^2 r^{q-1} + \cdots + c_1^q r\right) \bmod (2^{130} - 5)\right) \\
&- \left(\left(c_2^1 r^q + c_2^2 r^{q-1} + \cdots + c_2^q r\right) \bmod (2^{130} - 5)\right) \pmod{2^{128}}
\end{aligned}
$$
Ở đây $c_i^j$ là khối thứ $j$ (17 byte, little endian) của dữ liệu MAC thứ $i$, $q$ là số khối, và $s$ bị triệt tiêu khi trừ hai vế.
Từ đó ta có thể viết được thành
$$tag_1 - tag_2 + 2^{128}k \equiv (c_1^1 - c_2^1) r^q + (c_1^2 - c_2^2) r^{q-1} + \cdots + (c_1^q - c_2^q) r \pmod{2^{130} - 5}$$
Ở đây $k$ là một số nguyên nhỏ (trong code mình thử $k \in [-4, 4]$), nên với mỗi $k$ ta chỉ cần tìm nghiệm $r$ của đa thức này trên $GF(2^{130} - 5)$.
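Nhưng trước hết, ta phải dựng lại đúng dữ liệu đầu vào của Poly1305. Theo https://datatracker.ietf.org/doc/html/rfc7539#section-2.8, tag không được tính trên riêng ciphertext mà trên ``mac_data`` như dưới đây. Lưu ý: hai trường độ dài thực chất là số nguyên 8 byte little endian; RFC 8439 đã sửa pseudocode thành ``num_to_8_le_bytes``.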
```python=
chacha20_aead_encrypt(aad, key, iv, constant, plaintext):
    # "|" is byte-string concatenation in the RFC pseudocode.
    nonce = constant | iv
    # One-time Poly1305 key derived from (key, nonce): reusing the nonce reuses it.
    otk = poly1305_key_gen(key, nonce)
    ciphertext = chacha20_encrypt(key, 1, nonce, plaintext)
    # The MAC covers the zero-padded AAD, the zero-padded ciphertext and both lengths.
    mac_data = aad | pad16(aad)
    mac_data |= ciphertext | pad16(ciphertext)
    # NOTE: the lengths are 64-bit little-endian; RFC 8439 corrects these two
    # calls to num_to_8_le_bytes. The scripts in this writeup use 8-byte fields.
    mac_data |= num_to_4_le_bytes(aad.length)
    mac_data |= num_to_4_le_bytes(ciphertext.length)
    tag = poly1305_mac(mac_data, otk)
    return (ciphertext, tag)
```
Sau đó, ta chia ``mac_data`` thành các khối 16 byte, thêm byte ``\x01`` vào cuối mỗi khối, rồi đổi mỗi khối thành một số nguyên little endian 17 byte.
Giờ ta lập các poly rồi tính toán c3 và tag3 nhóe. Với các bước lấy ciphertext reuse nonce như trên, ta thu được output như trong phần comment ở đầu đoạn code dưới đây.
```python=
from Crypto.Util.number import *
from sage.all import *
from pwn import*
def convert_to_blocks(ciphertext):
    """Split ``ciphertext`` into consecutive 16-byte chunks; the last one may be shorter."""
    blocks = []
    for start in range(0, len(ciphertext), 16):
        blocks.append(ciphertext[start:start + 16])
    return blocks
def le_bytes_to_num(inp) -> int:
    """Interpret ``inp`` as an unsigned little-endian integer."""
    return int.from_bytes(inp, byteorder="little", signed=False)
# Two captured outputs that reused the same nonce,
# split as nonce(12) | ciphertext(16) | tag(16):
# 540fb957ff77acd995a06206 d1242f99e44a71e29644cb006a65515f c0fdc35e9489f9b7687ec12c7318441a
# 540fb957ff77acd995a06206 d1242f99e44a71e29644cb006b6d4856 2c6638c715135334c1df257868ecc538
p = 2**130 - 5  # Poly1305 prime
c1 = bytes.fromhex("d1242f99e44a71e29644cb006a65515f")
c2 = bytes.fromhex("d1242f99e44a71e29644cb006b6d4856")
# Same keystream for both messages, so c1 ^ pt1 ^ target_pt encrypts the target.
c3 = xor(xor(c1,b"give me the fake"),b"give me the flag")
t1 = bytes.fromhex("c0fdc35e9489f9b7687ec12c7318441a")
t2 = bytes.fromhex("2c6638c715135334c1df257868ecc538")
# Poly1305 input for a 16-byte ciphertext and empty AAD (no pad16 needed):
# ct | le64(len(aad) = 0) | le64(len(ct)). long_to_bytes(16) == b'\x10', so
# "1 length byte + 7 zero bytes" is the 8-byte LE length only while len(ct) < 256.
msg1 = c1 + b'\x00'*8 + long_to_bytes(len(c1)) + b'\x00'*7
msg2 = c2 + b'\x00'*8 + long_to_bytes(len(c2)) + b'\x00'*7
msg3 = c3 + b'\x00'*8 + long_to_bytes(len(c3)) + b'\x00'*7
q = len(msg1)//16  # number of 16-byte Poly1305 blocks
# Append the 0x01 marker byte to every 16-byte block (17 bytes each).
m1_chunks = [msg1[i*16:i*16+16] + b'\x01' for i in range(len(msg1)//16)]
m2_chunks = [msg2[i*16:i*16+16] + b'\x01' for i in range(len(msg2)//16)]
m3_chunks = [msg3[i*16:i*16+16] + b'\x01' for i in range(q)]
coeffs_1 = []
coeffs_2 = []
coeffs_3 = []
# Read each 17-byte chunk as a little-endian integer
# (equivalent to int.from_bytes(chunk, 'little')).
for i in range(len(msg1)//16):
    k = 0
    c_i_1 = 0
    c_i_2 = 0
    c_i_3 = 0
    for j in range(0, 128+1, 8):
        c_i_1 += m1_chunks[i][k] * 2**j
        c_i_2 += m2_chunks[i][k] * 2**j
        c_i_3 += m3_chunks[i][k] * 2**j
        k += 1
    coeffs_1.append(c_i_1)
    coeffs_2.append(c_i_2)
    coeffs_3.append(c_i_3)
# Tags as little-endian integers.
a1 = int.from_bytes(t1, 'little')
a2 = int.from_bytes(t2, 'little')
R.<r> = GF(p)[]
# Poly1305 before the final "+ s mod 2^128": sum(c_i * r^(q - i)).
poly1305_1 = sum([coeffs_1[i] * r**(q-i) for i in range(q)])
poly1305_2 = sum([coeffs_2[i] * r**(q-i) for i in range(q)])
valid_roots = []
# s cancels in the difference, but the "mod 2^128" reductions leave an unknown
# small multiple k of 2^128; collect candidate roots r for each k.
for k in (-4, -3, -2, -1, 0, 1, 2, 3, 4):
    f = poly1305_1 - poly1305_2 - (a1 - a2 + k*2**128)
    roots = f.roots()
    for root in roots:
        if root[0] <= 2**128:
            valid_roots.append(root[0])
print('valid_roots', valid_roots)
r_values = []
s_values = []
# A candidate r is kept only if both tags imply the same s.
for r in valid_roots:
    r = Integer(r)
    poly1305_1 = sum([coeffs_1[i] * r**(q-i) for i in range(q)]) % p
    poly1305_2 = sum([coeffs_2[i] * r**(q-i) for i in range(q)]) % p
    s1 = (a1 - poly1305_1) % int(2**128)
    s2 = (a2 - poly1305_2) % int(2**128)
    if s1 == s2:
        r_values.append(r)
        s_values.append(s1)
print('r_values', r_values)
print('s_values', s_values)
# For each consistent (r, s), compute the tag of c3 and print ciphertext | tag.
for i in range(len(r_values)):
    r = r_values[i]
    s = s_values[i]
    print(f'i = {i}')
    print(f'--> r = {r}')
    print(f'--> s = {s}')
    poly1305_3 = sum([coeffs_3[i] * r**(q-i) for i in range(q)]) % p
    a3 = (poly1305_3 + s) % 2**128
    tag3 = int(a3).to_bytes(16, byteorder='little')
    print(c3.hex()+tag3.hex())
    print()
```
Giờ kết hợp với phần trên để kết nối thẳng với server he.
solve.py
```python3=
from Crypto.Util.number import *
from sage.all import *
from pwn import*
def split_data(data):
    """Split a server blob ``nonce(12) || ct || tag(16)`` into ``(ct, tag, nonce)``."""
    head, tail = data[:12], data[-16:]
    return data[12:-16], tail, head
def solve(coeff,p):
    """Return an x with f(x) == target over GF(p), or 0 when no root exists.

    ``f`` is the server polynomial ``sum(coeff[i] * x**i)``. Forcing every
    encryption to evaluate to the same ``target`` makes the server derive the
    same nonce each time. The caller treats 0 as "no root, request another
    polynomial".
    """
    target = 2975672310188785687385844603660016766206691140065536851146747876253244864625944443947863382270702107635053088907914576631657492317975064684414261425546899
    # Use the p argument rather than a global field F.
    R = PolynomialRing(GF(p), 'x')
    # R(list) uses list[0] as the constant term, matching polynomial_evaluation;
    # this avoids building and eval-ing a source string.
    h = R(coeff) - target
    solutions = h.roots()
    if not solutions:
        return 0
    return solutions[0][0]
io = remote("157.15.86.73", 1305)
# io = process(["python3","server.py"])
# Read the per-connection prime printed by the server.
io.recvuntil(b'p = ')
p = int(io.recvuntil(b'\n',drop=True).decode())
F = GF(p)
# get ciphertext 1
# Request fresh polynomials until f(x) = target has a root, then encrypt
# "give me the fake" at that x so the nonce comes from the fixed target.
while True:
    io.recvuntil(b'Enter option: ')
    io.sendline(b'1')
    # Parse the printed coefficient list "[c0, c1, ...]".
    io.recvuntil(b'[')
    data = io.recvuntil(b']',drop=True).decode()
    coeff1 = [int(i) for i in data.split(',')]
    res = solve(coeff1,p)
    if res != 0:
        io.recvuntil(b'Enter message: ')
        io.sendline(b'give me the fake')
        io.recvuntil(b'Enter x: ')
        io.sendline(str(res).encode())
        out = bytes.fromhex(io.recvuntil(b'\n',drop=True).decode())
        c1,t1,nonce = split_data(out)
        break
    else:
        # No root: finish this menu round with a throwaway message and x.
        io.recvuntil(b'Enter message: ')
        io.sendline(b'1')
        io.recvuntil(b'Enter x: ')
        io.sendline(b'1')
# get ciphertext 2
# Same procedure for the second plaintext; it lands on the same nonce.
while True:
    io.recvuntil(b'Enter option: ')
    io.sendline(b'1')
    io.recvuntil(b'[')
    data = io.recvuntil(b']',drop=True).decode()
    coeff1 = [int(i) for i in data.split(',')]
    res = solve(coeff1,p)
    if res != 0:
        io.recvuntil(b'Enter message: ')
        io.sendline(b'give me the girl')
        io.recvuntil(b'Enter x: ')
        io.sendline(str(res).encode())
        out = bytes.fromhex(io.recvuntil(b'\n',drop=True).decode())
        c2,t2,nonce = split_data(out)
        break
    else:
        io.recvuntil(b'Enter message: ')
        io.sendline(b'1')
        io.recvuntil(b'Enter x: ')
        io.sendline(b'1')
p_poly = 2**130 - 5  # Poly1305 prime
# Same keystream for both messages, so c1 ^ pt1 ^ target_pt encrypts the target.
c3 = xor(xor(c1,b"give me the fake"),b"give me the flag")
# Poly1305 input for a 16-byte ciphertext and empty AAD:
# ct | le64(len(aad) = 0) | le64(len(ct)). long_to_bytes(16) == b'\x10', so this
# length encoding is only right while len(ct) < 256.
msg1 = c1 + b'\x00'*8 + long_to_bytes(len(c1)) + b'\x00'*7
msg2 = c2 + b'\x00'*8 + long_to_bytes(len(c2)) + b'\x00'*7
msg3 = c3 + b'\x00'*8 + long_to_bytes(len(c3)) + b'\x00'*7
q = len(msg1)//16  # number of 16-byte Poly1305 blocks
# Append the 0x01 marker byte to each 16-byte block.
m1_chunks = [msg1[i*16:i*16+16] + b'\x01' for i in range(len(msg1)//16)]
m2_chunks = [msg2[i*16:i*16+16] + b'\x01' for i in range(len(msg2)//16)]
m3_chunks = [msg3[i*16:i*16+16] + b'\x01' for i in range(q)]
coeffs_1 = []
coeffs_2 = []
coeffs_3 = []
# Read each 17-byte chunk as a little-endian integer.
for i in range(len(msg1)//16):
    k = 0
    c_i_1 = 0
    c_i_2 = 0
    c_i_3 = 0
    for j in range(0, 128+1, 8):
        c_i_1 += m1_chunks[i][k] * 2**j
        c_i_2 += m2_chunks[i][k] * 2**j
        c_i_3 += m3_chunks[i][k] * 2**j
        k += 1
    coeffs_1.append(c_i_1)
    coeffs_2.append(c_i_2)
    coeffs_3.append(c_i_3)
# Tags as little-endian integers.
a1 = int.from_bytes(t1, 'little')
a2 = int.from_bytes(t2, 'little')
R.<r> = GF(p_poly)[]
# Poly1305 before the final "+ s mod 2^128": sum(c_i * r^(q - i)).
poly1305_1 = sum([coeffs_1[i] * r**(q-i) for i in range(q)])
poly1305_2 = sum([coeffs_2[i] * r**(q-i) for i in range(q)])
valid_roots = []
# The tag difference fixes poly1305_1 - poly1305_2 up to a small multiple k of
# 2^128; collect candidate roots r for each k.
for k in (-4, -3, -2, -1, 0, 1, 2, 3, 4):
    f = poly1305_1 - poly1305_2 - (a1 - a2 + k*2**128)
    roots = f.roots()
    for root in roots:
        if root[0] <= 2**128:
            valid_roots.append(root[0])
r_values = []
s_values = []
# Keep (r, s) only if both tags imply the same s.
for r in valid_roots:
    r = Integer(r)
    poly1305_1 = sum([coeffs_1[i] * r**(q-i) for i in range(q)]) % p_poly
    poly1305_2 = sum([coeffs_2[i] * r**(q-i) for i in range(q)]) % p_poly
    s1 = (a1 - poly1305_1) % int(2**128)
    s2 = (a2 - poly1305_2) % int(2**128)
    if s1 == s2:
        r_values.append(r)
        s_values.append(s1)
# For each consistent (r, s): compute the tag of c3 and submit nonce | c3 | tag3.
for i in range(len(r_values)):
    r = r_values[i]
    s = s_values[i]
    poly1305_3 = sum([coeffs_3[i] * r**(q-i) for i in range(q)]) % p_poly
    a3 = (poly1305_3 + s) % 2**128
    tag3 = int(a3).to_bytes(16, byteorder='little')
    io.recvuntil(b'Enter option: ')
    io.sendline(b'2')
    io.recvuntil(b'Enter encrypted message: ')
    io.sendline((nonce.hex() + c3.hex()+tag3.hex()).encode())
    data = io.recvuntil(b'\n',drop=True)
    print(data)
```
**Flag: KMACTF{us3_poly0omia1_t0_gen3rate_n0nce_1s_the_bad_idea}**
Tham khảo: https://www.l3ak.team/2024/04/21/plaid24/#Solution
## Let them share
```python3=
import os
import random
from Crypto.Util.number import bytes_to_long, getPrime
DEGREE = 10  # random coefficients besides the secret, and also the number of shares printed
BITSIZE = 64  # bit size of the prime modulus
FLAG = "KMACTF{s0m3_r3ad4ble_5tr1ng_like_7his}"  # local placeholder
def get_coeff(p):
    """Return a coefficient below ``p`` whose bytes are uppercase hex digits.

    ``os.urandom(BITSIZE // 16)`` yields 4 random bytes. ``.hex().upper()``
    turns them into 8 ASCII characters from "0123456789ABCDEF", and
    ``bytes_to_long`` reads those 8 characters as a 64-bit big-endian integer.
    Each byte of the coefficient therefore takes one of only 16 values:
    32 bits of randomness spread over a 64-bit number. The ``BITSIZE - 1``
    bit-length check only passes when the first character is A-F
    (0x41-0x46), because a leading digit (0x30-0x39) gives just 62 bits.
    """
    # bigger coeff, safer sss :D
    while True:
        coeff = bytes_to_long(os.urandom(BITSIZE // 16).hex().upper().encode())
        if BITSIZE - 1 <= coeff.bit_length() and coeff < p:
            return coeff
def _eval_at(poly, x, prime):
    """Evaluates polynomial (coefficient tuple) at x, used to generate a
    shamir pool in make_random_shares below.

    Uses Horner's rule over ``reversed(poly)``, so ``poly[0]`` is the constant
    term and ``poly[-1]`` the highest-degree coefficient. Returns the pair
    ``(x, poly(x) mod prime)``.
    """
    accum = 0
    for coeff in reversed(poly):
        accum *= x
        accum += coeff
        accum %= prime
    return x, accum
def make_random_shares(secret, modulus):
    """Return DEGREE shares ``(x, f(x) mod modulus)`` of a degree-DEGREE polynomial.

    ``f`` has constant term ``secret`` plus DEGREE coefficients from get_coeff,
    i.e. DEGREE + 1 unknowns. Only DEGREE points are returned, so plain
    Lagrange interpolation cannot determine the secret. The x values come
    from ``random.randint`` over ``[0, modulus - 1]``.
    """
    coefficients = [secret] + [get_coeff(modulus) for _ in range(DEGREE)]
    return [_eval_at(coefficients, random.randint(0, modulus - 1), modulus) for _ in range(DEGREE)]
def main():
    """Run one round: print p and the shares, then check the player's guess of the secret."""
    p = getPrime(BITSIZE)
    # The secret is generated exactly like the other coefficients (hex-ASCII bytes).
    SECRET = get_coeff(p)
    points = make_random_shares(SECRET, p)
    print("Wait, there's something wrong, is our secret lost ?")
    print("p =", p)
    print("points =", points)
    number = int(input("What's our secret ? "))
    if number == SECRET:
        print("Cool!! You can deal with hard challenge, get your treasure here:", FLAG)
        exit()
    else:
        print("We lost it! :((")
        exit()


if __name__ == "__main__":
    main()
```
Ta thấy đa thức bí mật có dạng (hệ số đầu tiên chính là secret):
$$
\text{poly} = [\,\mathit{secret}, n_0, n_1, \dots, n_9\,]
$$
Mỗi point là một cặp $(x_i, y_i)$. Vì `_eval_at` dùng Horner trên `reversed(poly)`, hệ số `poly[k]` đi với $x^k$, nên secret là hệ số tự do:
$$
y_0 \equiv \mathit{secret} + n_0 x_0 + n_1 x_0^{2} + \dots + n_9 x_0^{10} \pmod{p}
$$
$$
y_1 \equiv \mathit{secret} + n_0 x_1 + n_1 x_1^{2} + \dots + n_9 x_1^{10} \pmod{p}
$$
$$
\vdots
$$
$$
y_9 \equiv \mathit{secret} + n_0 x_9 + n_1 x_9^{2} + \dots + n_9 x_9^{10} \pmod{p}
$$
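Sau khi suy nghĩ lại, mình thấy có thể dùng LLL. Nhưng đề bài lại bảo các số quá lớn để LLL: có 11 ẩn, mỗi ẩn khoảng 64 bit, trong khi chỉ có 10 phương trình modulo một số nguyên tố $p$ 64 bit, nên hệ bị thiếu ràng buộc. Vì vậy mình đã tìm cách khác.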
Sau khi tìm tòi, mình thấy challenge LCG của SamsungCTF 2018 có cách sinh số và cấu trúc hệ phương trình tương tự. Ta cùng tham khảo: https://github.com/SSTF-Office/SamsungCTF/blob/main/2018_SCTF/Final/crypto/LCG/deploy/src/LCG.py
```python=
import signal
import random
class LCG:
    """Second-order linear congruential generator that yields 32 values per iteration.

    Recurrence: s_{n+1} = (x * s_n + y * s_{n-1} + z) mod m, where
    x = (k1 - 0xdeadbeef) mod k3, y = (k1 * 0xdeadbeef) mod k3, z = k2, m = k3.
    Each call to iter() restarts from the seed pair *s*.
    """

    def __init__(self, s, k):
        self.init_state = s
        k1, k2, k3 = k
        self.m = k3
        self.x = (k1 - 0xdeadbeef) % k3
        self.y = (k1 * 0xdeadbeef) % k3
        self.z = k2

    def __iter__(self):
        # Reset the counter and the two-element state window from the seed.
        self.index, self.size = 0, 32
        self.s0, self.s1 = self.init_state
        return self

    def __next__(self):
        if not self.index < self.size:
            raise StopIteration
        self.index += 1
        older, newer = self.s0, self.s1
        produced = (self.x * newer + self.y * older + self.z) % self.m
        self.s0, self.s1 = newer, produced
        return produced
if __name__ == '__main__':
    # Kill the session after 60 seconds.
    signal.alarm(60)
    # Seed pair and key material are five independent 64-bit values.
    s0, s1, k1, k2, k3 = [random.getrandbits(64) for i in range(5)]
    s = (s0, s1)
    k = (k1, k2, k3)
    # Score: +1 per correct guess, -1 per wrong guess, starting at 16.
    cnt = 16
    for i, v in enumerate(LCG(s, k)):
        guess = int(input())
        if guess == v:
            cnt += 1
        else:
            cnt -= 1
        # Only the first 17 outputs (i = 0..16) are revealed, after the guess.
        if i <= 16:
            print(v)
    # NOTE(review): indentation reconstructed from a whitespace-stripped copy;
    # the flag check is assumed to run once, after all 32 rounds.
    if cnt >= 16:
        with open('flag.txt', 'r') as f:
            print(f.read())
```
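Bài này ẩn tất cả các tham số ($s_0, s_1, x, y, z, m$), và mình tìm được một writeup giải bài này như sau. Lưu ý: script dưới đây dùng luôn giá trị $m$ thật (comment `# cheating m`), nên nó chỉ minh họa cách giải $x, y, z$ khi đã biết $m$.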
```python=
# secret parameters: s0, s1, x, y, z, m
s0 = 3005423129600575593
s1 = 7396509365641243733
x = 8169846461236506548
y = 5748392989531061213
z = 15303690528977248313
m = 15674144604358019630
expsoln = {"x": x, "y": y, "z": z, "m": m}
# Regenerate 9 consecutive outputs with the known parameters.
ks = []
for i in range(9):
    s0, s1 = s1, (x * s1 + y * s0 + z) % m
    ks.append(s1)
# cheating m: the true modulus is reused below, so this only demonstrates
# recovering x, y, z once m is already known.
x, y, z = var("x,y,z")
eqns = []
bounds = {x: m, y: m, z: m}
# One linear relation per output: ks[i] = x*ks[i-1] + y*ks[i-2] + z (mod m).
for i in range(2, 7):
    eqns.append((ks[i] == x * ks[i - 1] + y * ks[i - 2] + z, m))
solution = solve_linear_mod(eqns, bounds)
print(solution)
assert solution[x] % m == expsoln["x"] % m
assert solution[y] % m == expsoln["y"] % m
assert solution[z] % m == expsoln["z"] % m
print("True")
```
Writeup đó dùng hàm ``solve_linear_mod``. Tìm kiếm thêm thì mình ra tool sau: https://github.com/protozeit/toolbase/blob/master/crypto/solvelinmod.py
```python=
from Crypto.Util.number import *
from sage.all import *
from collections.abc import Sequence
import math
import operator
from typing import List, Tuple
from pwn import *
def _process_linear_equations(equations, vars, guesses) -> List[Tuple[List[int], int, int]]:
    """Convert each (relation, modulus) pair into a (coeffs, const, m) row for the lattice.

    For every relation ``lhs == rhs`` the expression ``lhs - rhs`` is formed and expanded.
    The function checks that it only uses the bounded variables in *vars* and that it is
    linear with integer coefficients. It then emits:
      - coeffs: the integer coefficient of each variable, in *vars* order;
      - const: ``lhs - rhs`` evaluated with every variable set to its guess, i.e. the
        constant term after shifting each variable by its guess;
      - m: the modulus as given (it may be None/0 for a non-modular equation).
    When m is truthy, coeffs and const are reduced modulo m.

    Raises TypeError if a relation is not an equality, and ValueError if it uses an
    unbounded variable, is not linear, or has a non-integer coefficient or constant.
    """
    result = []
    for rel, m in equations:
        op = rel.operator()
        if op is not operator.eq:
            raise TypeError(f"relation {rel}: not an equality relation")
        expr = (rel - rel.rhs()).lhs().expand()
        for var in expr.variables():
            if var not in vars:
                raise ValueError(f"relation {rel}: variable {var} is not bounded")
        # Fill in eqns block of B
        coeffs = []
        for var in vars:
            if expr.degree(var) >= 2:
                raise ValueError(f"relation {rel}: equation is not linear in {var}")
            coeff = expr.coefficient(var)
            if not coeff.is_constant():
                raise ValueError(f"relation {rel}: coefficient of {var} is not constant (equation is not linear)")
            if not coeff.is_integer():
                raise ValueError(f"relation {rel}: coefficient of {var} is not an integer")
            coeff = int(coeff)
            if m:
                coeff %= m
            coeffs.append(coeff)
        # Shift variables towards their guesses to reduce the (expected) length of the solution vector
        const = expr.subs({var: guesses[var] for var in vars})
        if not const.is_constant():
            raise ValueError(f"relation {rel}: failed to extract constant")
        if not const.is_integer():
            raise ValueError(f"relation {rel}: constant is not integer")
        const = int(const)
        if m:
            const %= m
        result.append((coeffs, const, m))
    return result
def solve_linear_mod(equations, bounds, verbose=False, use_flatter=False, **lll_args):
    """Solve an arbitrary system of modular linear equations over different moduli.

    equations: A sequence of (lhs == rhs, M) pairs, where lhs and rhs are expressions and M is the modulus.
               M may be None to indicate that the equation is not modular.
    bounds: A dictionary of {var: B} entries, where var is a variable and B is the bounds on that variable.
            Bounds may be specified in one of three ways:
            - A single integer X: Variable is assumed to be uniformly distributed in [0, X] with an expected value of X/2.
            - A tuple of integers (X, Y): Variable is assumed to be uniformly distributed in [X, Y] with an expected value of (X + Y)/2.
            - A tuple of integers (X, E, Y): Variable is assumed to be bounded within [X, Y] with an expected value of E.
            All variables used in the equations must be bounded.
    verbose: set to True to enable additional output
    use_flatter: set to True to use [flatter](https://github.com/keeganryan/flatter), which is much faster
    lll_args: Additional arguments passed to LLL, for advanced usage (ignored when use_flatter is True).

    Returns a dict {var: value} built from the first reduced lattice row whose relation
    columns are all zero (and, for an inhomogeneous system, whose constant column is 1).
    If no row qualifies the function returns None, so callers should check for it.

    NOTE: Bounds are *soft*. This function may return solutions above the bounds. If this happens, and the result
    is incorrect, make some bounds tighter and try again.

    Tip: if you get an unwanted solution, try setting the expected values to that solution to force this function
    to produce a different solution.

    Tip: if your bounds are loose and you just want small solutions, set the expected values to zero for all
    loosely-bounded variables.

    >>> k = var('k')
    >>> # solve CRT
    >>> solve_linear_mod([(k == 2, 3), (k == 4, 5), (k == 3, 7)], {k: 3*5*7})
    {k: 59}

    >>> x,y = var('x,y')
    >>> solve_linear_mod([(2*x + 3*y == 7, 11), (3*x + 5*y == 3, 13), (2*x + 5*y == 6, 143)], {x: 143, y: 143})
    {x: 62, y: 5}

    >>> x,y = var('x,y')
    >>> # we can also solve homogenous equations, provided the guesses are zeroed.
    >>> # This copy has no `guesses=` keyword (it would land in **lll_args and be passed
    >>> # to LLL), so the zero guess is given as the middle element of a 3-tuple bound.
    >>> solve_linear_mod([(2*x + 5*y == 0, 1337)], {x: (-5, 0, 5), y: (-5, 0, 5)})
    {x: 5, y: -2}
    """
    # The general idea is to set up an integer matrix equation Ax=y by introducing extra variables for the quotients,
    # then use LLL to solve the equation. We introduce extra axes in the lattice to observe the actual solution x,
    # which works so long as the solutions are known to be bounded (which is of course the case for modular equations).
    # Scaling factors are configured to generally push the smallest vectors to have zeros for the relations, and to
    # scale disparate variables to approximately the same base.
    # NOTE: `vars` shadows the builtin of the same name inside this function.
    vars = list(bounds)
    guesses = {}
    var_scale = {}
    for var in vars:
        bound = bounds[var]
        if isinstance(bound, Sequence):
            if len(bound) == 2:
                xmin, xmax = map(int, bound)
                guess = (xmax - xmin) // 2 + xmin
            elif len(bound) == 3:
                xmin, guess, xmax = map(int, bound)
            else:
                raise TypeError("Bounds must be integers, 2-tuples or 3-tuples")
        else:
            xmin = 0
            xmax = int(bound)
            guess = xmax // 2
        if not xmin <= guess <= xmax:
            raise ValueError(f"Bound for variable {var} is invalid ({xmin=} {guess=} {xmax=})")
        # Largest possible distance of the variable from its guess (at least 1).
        var_scale[var] = max(xmax - guess, guess - xmin, 1)
        guesses[var] = guess
    # Rough information-content comparison: unknown bits vs. bits fixed by the moduli.
    var_bits = math.log2(int(prod(var_scale.values()))) + len(vars)
    mod_bits = math.log2(int(prod(m for rel, m in equations if m)))
    if verbose:
        print(f"verbose: variable entropy: {var_bits:.2f} bits")
        print(f"verbose: modulus entropy: {mod_bits:.2f} bits")
    # Extract coefficients from equations
    equation_coeffs = _process_linear_equations(equations, vars, guesses)
    is_inhom = any(const != 0 for coeffs, const, m in equation_coeffs)
    # Number of equations that actually carry a modulus (M may be None).
    mod_count = sum(1 for coeffs, const, m in equation_coeffs if m)
    NR = len(equation_coeffs)
    NV = len(vars)
    if is_inhom:
        # Add one dummy variable for the constant term.
        NV += 1
    B = matrix(ZZ, mod_count + NV, NR + NV)
    # B format (rows are the basis for the lattice):
    # [ mods:NRxNR 0
    #   eqns:NVxNR vars:NVxNV ]
    # eqns correspond to equation axes, fi(...) = yi mod mi
    # vars correspond to variable axes, which effectively "observe" elements of the solution vector (x in Ax=y)
    # mods and vars are diagonal, so this matrix is lower triangular.
    # Compute maximum scale factor over all variables
    S = max(var_scale.values())
    # Compute equation scale such that the bounded solution vector (equation columns all zero)
    # will be shorter than any vector that has a nonzero equation column
    eqS = S << (NR + NV + 1)
    # If the equation is underconstrained, add additional scaling to find a solution anyway
    if var_bits > mod_bits:
        eqS <<= int((var_bits - mod_bits) / NR) + 1
    col_scales = []
    mi = 0
    for ri, (coeffs, const, m) in enumerate(equation_coeffs):
        for vi, c in enumerate(coeffs):
            B[mod_count + vi, ri] = c
        if is_inhom:
            B[mod_count + NV - 1, ri] = const
        if m:
            # Modulus row: lets the lattice subtract any multiple of m in this column.
            B[mi, ri] = m
            mi += 1
        col_scales.append(eqS)
    # Compute per-variable scale such that the variable axes are scaled roughly equally
    for vi, var in enumerate(vars):
        col_scales.append(S // var_scale[var])
        # Fill in vars block of B
        B[mod_count + vi, NR + vi] = 1
    if is_inhom:
        # Const block: effectively, this is a bound of 1 on the constant term
        col_scales.append(S)
        B[mod_count + NV - 1, -1] = 1
    if verbose:
        print("verbose: scaling shifts:", [math.log2(int(s)) for s in col_scales])
        print("verbose: matrix dimensions:", B.dimensions())
        print("verbose: unscaled matrix before:")
        print(B.n())
    for i, s in enumerate(col_scales):
        B[:, i] *= s
    if use_flatter:
        from re import findall
        from subprocess import check_output
        # compile https://github.com/keeganryan/flatter and put it in $PATH
        z = "[[" + "]\n[".join(" ".join(map(str, row)) for row in B) + "]]"
        ret = check_output(["flatter"], input=z.encode())
        B = matrix(B.nrows(), B.ncols(), map(int, findall(b"-?\\d+", ret)))
    else:
        B = B.LLL(**lll_args)
    # Undo the column scaling so rows read as (relation residues, variable offsets, constant).
    for i, s in enumerate(col_scales):
        B[:, i] /= s
    # Negate rows for more readable output
    for i in range(B.nrows()):
        if sum(x < 0 for x in B[i, :]) > sum(x > 0 for x in B[i, :]):
            B[i, :] *= -1
        if is_inhom and B[i, -1] < 0:
            B[i, :] *= -1
    if verbose:
        print("verbose: unscaled matrix after:")
        print(B.n())
    for row in B:
        if any(x != 0 for x in row[:NR]):
            # invalid solution: some relations are nonzero
            continue
        if is_inhom:
            # Each row is a potential solution, but some rows may not carry a constant.
            if row[-1] != 1:
                if verbose:
                    print(
                        "verbose: zero solution",
                        {var: row[NR + vi] for vi, var in enumerate(vars) if row[NR + vi] != 0},
                    )
                continue
        # Variable columns hold offsets from the guesses; add the guesses back.
        res = {}
        for vi, var in enumerate(vars):
            res[var] = row[NR + vi] + guesses[var]
        return res
    # No qualifying row: control falls off the end and the function returns None.
```
Bài của chúng ta cũng quy về giải một hệ phương trình tuyến tính modulo giống bài LCG của Samsung, nên mình dùng luôn tool này.
Ban đầu mình lập hệ như sau:
```python=
io = process(['python3', 'server.py'])
# io = remote('157.15.86.73', 2004)
# NOTE(review): these two reads presumably target a locally modified server.py that
# prints the secret first (the challenge's main() starts with the "Wait, ..." line).
secret = int(io.recvline().strip().decode())
print(secret)
secret = (io.recvline().strip().decode())
print(secret)
io.recvuntil(b'p = ')
p = int(io.recvuntil(b'\n',drop=True).decode())
print(p)
io.recvuntil(b'points = ')
# eval() of received text: acceptable against one's own local server, unsafe against an untrusted peer.
points = eval(io.recvuntil(b'\n',drop=True).decode())
x = [i[0] for i in points]
ys = [i[1] for i in points]
# Unknowns: se = constant term (the secret), b_0..b_9 = coefficients of x**1..x**10.
bs = [var(f"b_{i}") for i in range(10)]
se = var("se")
print(bs)
# Presumably relies on Sage's var() injecting b_0..b_9 into the global namespace.
bounds = {b_0: p, b_1: p, b_2: p, b_3: p, b_4: p, b_5: p, b_6: p, b_7: p, b_8: p, b_9: p, se: p}
eqns = []
for i in range(10):
    # `^` is exponentiation only under the Sage preparser (.sage file / Sage shell); in plain Python it is XOR.
    eqns.append((b_9*x[i]^10 + b_8*x[i]^9 + b_7*x[i]^8 + b_6*x[i]^7 + b_5*x[i]^6 + b_4*x[i]^5 + b_3*x[i]^4 + b_2*x[i]^3 + b_1*x[i]^2 + b_0*x[i]^1 + se == ys[i], p))
# This attempt fails: 11 unknowns of up to ~64 bits against 10 equations modulo a 64-bit p
# leave the system underdetermined, so LLL returns a short but wrong solution.
solution = solve_linear_mod(eqns, bounds)
print(solution)
```
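Tuy nhiên kết quả bị sai, vì hệ có quá nhiều nghiệm: 11 ẩn khoảng 64 bit nhưng chỉ có 10 phương trình modulo $p$ 64 bit. Hint là các hệ số được sinh từ chuỗi hex (`os.urandom(BITSIZE // 16).hex().upper().encode()`). Nhờ đó, mỗi hệ số 8 byte thực chất là 8 ký tự ASCII trong `0-9A-F`. Vì vậy mình tách mỗi hệ số thành 8 ẩn (mỗi ẩn là một byte) theo công thức: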
$$
\mathit{secret} = 256^{7} \cdot b_7 + 256^{6} \cdot b_6 + \dots + 256 \cdot b_1 + b_0
$$
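Các hệ số $n_0, \dots, n_9$ cũng được tách tương tự, tổng cộng $11 \cdot 8 = 88$ ẩn byte. Sau khi giải ra các byte, ta ghép lại theo công thức trên để thu được secret.
Ta có tổng cộng 10 phương trình, mỗi point một phương trình. Trong phương trình thứ $i$, hệ số thứ $j$ được nhân với $x_i^{j} \bmod p$, tức là $y_i \equiv \sum_{j=0}^{10} c_j\, x_i^{j} \pmod{p}$.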
```python=
p = 11621801616120342377
points = [(15644190507625758959, 12161542648022838840), (4576473806494987889, 12556861567871161465), (10469720920797756484, 7271167879572642711), (15892777276575622987, 720739272163799431), (16683115943588000400, 687963153815711759), (6031950515621716130, 14954786879787403442), (798472918567478938, 15294132676987041310), (11527842205764419603, 14326321063176162855), (15926005302025399922, 6650974629402948954), (651456423855923190, 17065437574275207368)]
xs = [pt[0] for pt in points]
ys = [pt[1] for pt in points]
N_COEFFS = 11     # secret + DEGREE (10) random coefficients
COEFF_BYTES = 8   # each coefficient is 8 ASCII hex characters
# bs[COEFF_BYTES*i + j] is byte j (weight 256**j) of coefficient i; coefficient 0 is the secret.
bs = [var(f"b_{i}") for i in range(N_COEFFS * COEFF_BYTES)]
convert_num = []
for i in range(N_COEFFS):
    num = 0
    for j in range(COEFF_BYTES):
        num += bs[COEFF_BYTES * i + j] * 256**j
    convert_num.append(num)
# One equation per share: y_i == sum_j c_j * x_i**j (mod p), with the secret at x**0.
eqs = []
for x_i, y_i in zip(xs, ys):
    eq = 0
    for j, coeff in enumerate(convert_num):
        eq += pow(x_i, j, p) * coeff
    eqs.append((y_i == eq, p))
```
Tiếp theo, ta phải tìm bounds. Từ đoạn code ``coeff = bytes_to_long(os.urandom(BITSIZE // 16).hex().upper().encode())``, mỗi byte $b$ là một ký tự hex viết hoa `0`–`9` hoặc `A`–`F`. Do đó mỗi byte nằm trong đoạn từ `bytes_to_long(b'0')` = 48 tới `bytes_to_long(b'F')` = 70.
```python=
# Every unknown is one ASCII byte of an uppercase hex string: '0' (48) .. 'F' (70).
bounds = {b: (bytes_to_long(b'0'),bytes_to_long(b'F')) for b in bs}
```
Đoạn code đầy đủ của bài đây
```python=
from Crypto.Util.number import *
from sage.all import *
from collections.abc import Sequence
import math
import operator
from typing import List, Tuple
from sage.all import ZZ, gcd, matrix, prod, var
from pwn import *
def _process_linear_equations(equations, vars, guesses) -> List[Tuple[List[int], int, int]]:
    """Convert each (relation, modulus) pair into a (coeffs, const, m) row for the lattice.

    For every relation ``lhs == rhs`` the expression ``lhs - rhs`` is formed and expanded.
    The function checks that it only uses the bounded variables in *vars* and that it is
    linear with integer coefficients. It then emits:
      - coeffs: the integer coefficient of each variable, in *vars* order;
      - const: ``lhs - rhs`` evaluated with every variable set to its guess, i.e. the
        constant term after shifting each variable by its guess;
      - m: the modulus as given (it may be None/0 for a non-modular equation).
    When m is truthy, coeffs and const are reduced modulo m.

    Raises TypeError if a relation is not an equality, and ValueError if it uses an
    unbounded variable, is not linear, or has a non-integer coefficient or constant.
    """
    result = []
    for rel, m in equations:
        op = rel.operator()
        if op is not operator.eq:
            raise TypeError(f"relation {rel}: not an equality relation")
        expr = (rel - rel.rhs()).lhs().expand()
        for var in expr.variables():
            if var not in vars:
                raise ValueError(f"relation {rel}: variable {var} is not bounded")
        # Fill in eqns block of B
        coeffs = []
        for var in vars:
            if expr.degree(var) >= 2:
                raise ValueError(f"relation {rel}: equation is not linear in {var}")
            coeff = expr.coefficient(var)
            if not coeff.is_constant():
                raise ValueError(f"relation {rel}: coefficient of {var} is not constant (equation is not linear)")
            if not coeff.is_integer():
                raise ValueError(f"relation {rel}: coefficient of {var} is not an integer")
            coeff = int(coeff)
            if m:
                coeff %= m
            coeffs.append(coeff)
        # Shift variables towards their guesses to reduce the (expected) length of the solution vector
        const = expr.subs({var: guesses[var] for var in vars})
        if not const.is_constant():
            raise ValueError(f"relation {rel}: failed to extract constant")
        if not const.is_integer():
            raise ValueError(f"relation {rel}: constant is not integer")
        const = int(const)
        if m:
            const %= m
        result.append((coeffs, const, m))
    return result
def solve_linear_mod(equations, bounds, verbose=False, use_flatter=False, **lll_args):
    """Solve an arbitrary system of modular linear equations over different moduli.

    equations: A sequence of (lhs == rhs, M) pairs, where lhs and rhs are expressions and M is the modulus.
               M may be None to indicate that the equation is not modular.
    bounds: A dictionary of {var: B} entries, where var is a variable and B is the bounds on that variable.
            Bounds may be specified in one of three ways:
            - A single integer X: Variable is assumed to be uniformly distributed in [0, X] with an expected value of X/2.
            - A tuple of integers (X, Y): Variable is assumed to be uniformly distributed in [X, Y] with an expected value of (X + Y)/2.
            - A tuple of integers (X, E, Y): Variable is assumed to be bounded within [X, Y] with an expected value of E.
            All variables used in the equations must be bounded.
    verbose: set to True to enable additional output
    use_flatter: set to True to use [flatter](https://github.com/keeganryan/flatter), which is much faster
    lll_args: Additional arguments passed to LLL, for advanced usage (ignored when use_flatter is True).

    Returns a dict {var: value} built from the first reduced lattice row whose relation
    columns are all zero (and, for an inhomogeneous system, whose constant column is 1).
    If no row qualifies the function returns None, so callers should check for it.

    NOTE: Bounds are *soft*. This function may return solutions above the bounds. If this happens, and the result
    is incorrect, make some bounds tighter and try again.

    Tip: if you get an unwanted solution, try setting the expected values to that solution to force this function
    to produce a different solution.

    Tip: if your bounds are loose and you just want small solutions, set the expected values to zero for all
    loosely-bounded variables.

    >>> k = var('k')
    >>> # solve CRT
    >>> solve_linear_mod([(k == 2, 3), (k == 4, 5), (k == 3, 7)], {k: 3*5*7})
    {k: 59}

    >>> x,y = var('x,y')
    >>> solve_linear_mod([(2*x + 3*y == 7, 11), (3*x + 5*y == 3, 13), (2*x + 5*y == 6, 143)], {x: 143, y: 143})
    {x: 62, y: 5}

    >>> x,y = var('x,y')
    >>> # we can also solve homogenous equations, provided the guesses are zeroed.
    >>> # This copy has no `guesses=` keyword (it would land in **lll_args and be passed
    >>> # to LLL), so the zero guess is given as the middle element of a 3-tuple bound.
    >>> solve_linear_mod([(2*x + 5*y == 0, 1337)], {x: (-5, 0, 5), y: (-5, 0, 5)})
    {x: 5, y: -2}
    """
    # The general idea is to set up an integer matrix equation Ax=y by introducing extra variables for the quotients,
    # then use LLL to solve the equation. We introduce extra axes in the lattice to observe the actual solution x,
    # which works so long as the solutions are known to be bounded (which is of course the case for modular equations).
    # Scaling factors are configured to generally push the smallest vectors to have zeros for the relations, and to
    # scale disparate variables to approximately the same base.
    # NOTE: `vars` shadows the builtin of the same name inside this function.
    vars = list(bounds)
    guesses = {}
    var_scale = {}
    for var in vars:
        bound = bounds[var]
        if isinstance(bound, Sequence):
            if len(bound) == 2:
                xmin, xmax = map(int, bound)
                guess = (xmax - xmin) // 2 + xmin
            elif len(bound) == 3:
                xmin, guess, xmax = map(int, bound)
            else:
                raise TypeError("Bounds must be integers, 2-tuples or 3-tuples")
        else:
            xmin = 0
            xmax = int(bound)
            guess = xmax // 2
        if not xmin <= guess <= xmax:
            raise ValueError(f"Bound for variable {var} is invalid ({xmin=} {guess=} {xmax=})")
        # Largest possible distance of the variable from its guess (at least 1).
        var_scale[var] = max(xmax - guess, guess - xmin, 1)
        guesses[var] = guess
    # Rough information-content comparison: unknown bits vs. bits fixed by the moduli.
    var_bits = math.log2(int(prod(var_scale.values()))) + len(vars)
    mod_bits = math.log2(int(prod(m for rel, m in equations if m)))
    if verbose:
        print(f"verbose: variable entropy: {var_bits:.2f} bits")
        print(f"verbose: modulus entropy: {mod_bits:.2f} bits")
    # Extract coefficients from equations
    equation_coeffs = _process_linear_equations(equations, vars, guesses)
    is_inhom = any(const != 0 for coeffs, const, m in equation_coeffs)
    # Number of equations that actually carry a modulus (M may be None).
    mod_count = sum(1 for coeffs, const, m in equation_coeffs if m)
    NR = len(equation_coeffs)
    NV = len(vars)
    if is_inhom:
        # Add one dummy variable for the constant term.
        NV += 1
    B = matrix(ZZ, mod_count + NV, NR + NV)
    # B format (rows are the basis for the lattice):
    # [ mods:NRxNR 0
    #   eqns:NVxNR vars:NVxNV ]
    # eqns correspond to equation axes, fi(...) = yi mod mi
    # vars correspond to variable axes, which effectively "observe" elements of the solution vector (x in Ax=y)
    # mods and vars are diagonal, so this matrix is lower triangular.
    # Compute maximum scale factor over all variables
    S = max(var_scale.values())
    # Compute equation scale such that the bounded solution vector (equation columns all zero)
    # will be shorter than any vector that has a nonzero equation column
    eqS = S << (NR + NV + 1)
    # If the equation is underconstrained, add additional scaling to find a solution anyway
    if var_bits > mod_bits:
        eqS <<= int((var_bits - mod_bits) / NR) + 1
    col_scales = []
    mi = 0
    for ri, (coeffs, const, m) in enumerate(equation_coeffs):
        for vi, c in enumerate(coeffs):
            B[mod_count + vi, ri] = c
        if is_inhom:
            B[mod_count + NV - 1, ri] = const
        if m:
            # Modulus row: lets the lattice subtract any multiple of m in this column.
            B[mi, ri] = m
            mi += 1
        col_scales.append(eqS)
    # Compute per-variable scale such that the variable axes are scaled roughly equally
    for vi, var in enumerate(vars):
        col_scales.append(S // var_scale[var])
        # Fill in vars block of B
        B[mod_count + vi, NR + vi] = 1
    if is_inhom:
        # Const block: effectively, this is a bound of 1 on the constant term
        col_scales.append(S)
        B[mod_count + NV - 1, -1] = 1
    if verbose:
        print("verbose: scaling shifts:", [math.log2(int(s)) for s in col_scales])
        print("verbose: matrix dimensions:", B.dimensions())
        print("verbose: unscaled matrix before:")
        print(B.n())
    for i, s in enumerate(col_scales):
        B[:, i] *= s
    if use_flatter:
        from re import findall
        from subprocess import check_output
        # compile https://github.com/keeganryan/flatter and put it in $PATH
        z = "[[" + "]\n[".join(" ".join(map(str, row)) for row in B) + "]]"
        ret = check_output(["flatter"], input=z.encode())
        B = matrix(B.nrows(), B.ncols(), map(int, findall(b"-?\\d+", ret)))
    else:
        B = B.LLL(**lll_args)
    # Undo the column scaling so rows read as (relation residues, variable offsets, constant).
    for i, s in enumerate(col_scales):
        B[:, i] /= s
    # Negate rows for more readable output
    for i in range(B.nrows()):
        if sum(x < 0 for x in B[i, :]) > sum(x > 0 for x in B[i, :]):
            B[i, :] *= -1
        if is_inhom and B[i, -1] < 0:
            B[i, :] *= -1
    if verbose:
        print("verbose: unscaled matrix after:")
        print(B.n())
    for row in B:
        if any(x != 0 for x in row[:NR]):
            # invalid solution: some relations are nonzero
            continue
        if is_inhom:
            # Each row is a potential solution, but some rows may not carry a constant.
            if row[-1] != 1:
                if verbose:
                    print(
                        "verbose: zero solution",
                        {var: row[NR + vi] for vi, var in enumerate(vars) if row[NR + vi] != 0},
                    )
                continue
        # Variable columns hold offsets from the guesses; add the guesses back.
        res = {}
        for vi, var in enumerate(vars):
            res[var] = row[NR + vi] + guesses[var]
        return res
    # No qualifying row: control falls off the end and the function returns None.
import ast

# Shape of the challenge polynomial: the secret plus DEGREE (10) random coefficients,
# each made of BITSIZE // 16 = 4 random bytes hex-encoded to 8 ASCII characters.
N_COEFFS = 11
COEFF_BYTES = 8

# io = process(['python3', 'server.py'])  # local testing
io = remote('157.15.86.73', 2004)
io.recvuntil(b'p = ')
p = int(io.recvuntil(b'\n', drop=True).decode())
print(p)
io.recvuntil(b'points = ')
# The server prints a Python list of (x, y) tuples. Parse it as a literal rather than
# eval() so a garbled or malicious line cannot execute code.
points = ast.literal_eval(io.recvuntil(b'\n', drop=True).decode())
print(points)
xs = [pt[0] for pt in points]
ys = [pt[1] for pt in points]

# One unknown per ASCII byte: bs[COEFF_BYTES*i + j] is byte j (weight 256**j) of
# coefficient i. Coefficient 0 is the secret (the x**0 term in _eval_at).
bs = [var(f"b_{i}") for i in range(N_COEFFS * COEFF_BYTES)]
convert_num = []
for i in range(N_COEFFS):
    num = 0
    for j in range(COEFF_BYTES):
        num += bs[COEFF_BYTES * i + j] * 256**j
    convert_num.append(num)

# One equation per share: y_i == sum_j c_j * x_i**j (mod p).
eqs = []
for x_i, y_i in zip(xs, ys):
    eq = 0
    for j, coeff in enumerate(convert_num):
        eq += pow(x_i, j, p) * coeff
    eqs.append((y_i == eq, p))
print(len(eqs))

# Every byte is an uppercase hex digit: '0' (48) .. 'F' (70).
lo, hi = bytes_to_long(b'0'), bytes_to_long(b'F')
bounds = {b: (lo, hi) for b in bs}
sol = solve_linear_mod(eqs, bounds)
print(sol)
if sol is None:
    # solve_linear_mod returns None when no reduced lattice row satisfies every relation.
    print("LLL found no solution; reconnect and try again")
    io.close()
    exit(1)

secret_bytes = [int(sol[b]) for b in bs[:COEFF_BYTES]]
print(secret_bytes)
# Rebuild the secret exactly as it appears in the equations (sum of b_j * 256**j).
# Concatenating hex strings would silently break if a byte came back outside 16..255.
ans = 0
for j, byte in enumerate(secret_bytes):
    ans += byte * 256**j
print(ans)
io.sendline(str(ans).encode())
io.interactive()
```
**FLAG: KMACTF{Us1nG_LLL_1s_4n_4rt_f0rm__:)))}**