# Codegate CTF 2020 Preliminary
Lattice-based attack on an NTRUEncrypt system. I couldn't solve it during the contest because I didn't figure out why a 2 shows up among the coefficients.
## Challenge detail
We are given the following Sage script
```sage
#!/usr/bin/env sage
from sage.misc.banner import version_dict
from Crypto.Util.number import long_to_bytes as l2b
from Crypto.Util.number import bytes_to_long as b2l
from Crypto.Util.Padding import pad
from Crypto.Cipher import AES
from os import urandom
assert version_dict()["major"] >= 9
class Chall:
    """NTRU-style cryptosystem over S = (Z/qZ)[x] / (x^N - 1).

    This is the challenge handout; only the indentation (lost in the paste)
    and the documentation differ from the original statements.
    """

    def __init__(self, N, p, q):
        """Build the coefficient ring R = (Z/qZ)[x] and its quotient S by x^N - 1."""
        self.N, self.p, self.q = N, p, q
        self.R = PolynomialRing(Integers(q), "x")
        self.x = self.R.gen()
        self.S = self.R.quotient(self.x ^ N - 1, "x")
        self.h, self.f = None, None

    def random(self):
        """Return an element of S with N coefficients, each drawn by randint(-1, 1)."""
        return self.S([randint(-1, 1) for _ in range(self.N)])

    def keygen(self):
        """Generate the key pair: f = p*F + 1 (private) and h = p * f^-1 * g (public)."""
        # Retry until f = p*F + 1 can be inverted in S.
        while True:
            self.F = self.random()
            self.f = self.p * self.F + 1
            try:
                self.z = self.f ^ -1
            except:
                continue
            break
        # Retry until g can be inverted in S (the inverse itself is discarded).
        while True:
            self.g = self.random()
            try:
                self.g ^ -1
            except:
                continue
            break
        self.h = self.p * self.z * self.g

    def getPublicKey(self):
        """Return the coefficients of h as a list."""
        return list(self.h)

    def getPrivateKey(self):
        """Return the coefficients of f as a list."""
        return list(self.f)

    def encrypt(self, m):
        """Encrypt the byte string m and return the ciphertext coefficients.

        The message integer b2l(m) is encoded into N digits (see encode) and
        masked as random() * h + encoded_message in S.
        """
        m_encoded = self.encode(b2l(m))
        e = self.random() * self.h + self.S(m_encoded)
        return list(e)

    def decrypt(self, e, privkey):
        """Decrypt ciphertext coefficients e with private-key coefficients privkey.

        Returns the plaintext as produced by long_to_bytes.
        """
        e, privkey = self.S(e), self.S(privkey)
        temp = map(Integer, list(privkey * e))
        # Move each coefficient of f*e from [0, q) to the centred range around 0.
        temp = [t - self.q if t > self.q // 2 else t for t in temp]
        # Reduce mod p, then centre again to obtain the encoded digits.
        temp = [t % self.p for t in temp]
        pt_encoded = [t - self.p if t > self.p // 2 else t for t in temp]
        pt = l2b(self.decode(pt_encoded))
        return pt

    def encode(self, value):
        """Return the N base-3 digits of value, least significant first, each shifted by -1."""
        assert 0 <= value < 3 ^ self.N
        out = []
        for _ in range(self.N):
            out.append(value % 3 - 1)
            value -= value % 3
            value /= 3
        return out

    def decode(self, value):
        """Inverse of encode: sum (digit + 1) * 3^i over the digit list."""
        out = sum([(value[i] + 1) * 3 ^ i for i in range(len(value))])
        return out

    def count(self, row):
        """Return (#entries equal to 1, #remaining entries, #entries equal to q - 1)."""
        p = sum([e == 1 for e in row])
        n = sum([e == self.q - 1 for e in row])
        return p, len(row) - p - n, n
def wrapper(N, p, q, pt):
    """Run one challenge instance and print everything the player receives.

    Prints, in order: the public key coefficients, the ciphertext of pt, and
    the counts returned by count() for the coefficients of the secret F.
    """
    chall = Chall(N, p, q)
    chall.keygen()
    print(chall.getPublicKey())
    print(chall.encrypt(pt))
    print(chall.count(list((chall.F))))
if __name__ == "__main__":
    # The flag is AES-ECB encrypted under a random 16-byte key ...
    key = urandom(16)
    cipher = AES.new(key, AES.MODE_ECB)
    flag = pad(open("flag.txt", "rb").read(), 16)
    enc_flag = b2l(cipher.encrypt(flag))
    print(enc_flag)
    # ... and each 8-byte half of that key is encrypted by its own instance.
    key1, key2 = key[:8], key[8:]
    wrapper(55, 3, 4027, key1)
    wrapper(60, 3, 1499, key2)
```
and its ```output```:
```
2960408776014513590203667205130185225161547470030516261741102417822093600856513664346223496713014612247754765985505434047965417819771431223015026059243409921418043319365743779292681722097463141
[3627, 1889, 3460, 2627, 3545, 1478, 2307, 3378, 3350, 1272, 2445, 3881, 3110, 1628, 1798, 1826, 259, 1983, 453, 52, 2650, 834, 3307, 907, 2762, 3452, 1085, 3059, 3544, 1136, 3767, 2346, 1952, 699, 3023, 531, 1208, 1449, 3636, 1742, 2692, 1128, 1683, 1152, 2584, 637, 3053, 2072, 2687, 1811, 2981, 3288, 2324, 3632, 1813]
[426, 3379, 3985, 160, 2502, 3592, 55, 1753, 3599, 2656, 2380, 582, 1038, 1028, 791, 1695, 1783, 3814, 3687, 3742, 1892, 1053, 2728, 3946, 801, 238, 3766, 1355, 1219, 528, 3560, 9, 3737, 1975, 1469, 85, 1373, 3717, 195, 3252, 2020, 1087, 201, 2536, 1655, 3380, 2322, 2438, 803, 2838, 1034, 457, 3050, 4010, 231]
(22, 18, 15)
[314, 1325, 1386, 176, 369, 1029, 877, 1255, 111, 1226, 117, 0, 210, 761, 938, 273, 525, 751, 1085, 372, 1333, 898, 780, 44, 649, 1463, 326, 354, 116, 1080, 1065, 1109, 358, 275, 1209, 964, 101, 950, 415, 1492, 1197, 921, 1000, 1028, 1400, 43, 1003, 914, 447, 360, 1171, 1109, 223, 1134, 1157, 1383, 784, 189, 870, 565]
[378, 753, 466, 825, 320, 658, 630, 288, 16, 576, 134, 914, 549, 489, 197, 1392, 328, 361, 1241, 50, 710, 315, 526, 1250, 977, 453, 225, 433, 1342, 1005, 1432, 143, 1326, 1426, 1251, 1397, 237, 1202, 555, 83, 994, 446, 1406, 356, 1127, 1469, 485, 1034, 1224, 230, 1445, 825, 630, 1158, 815, 807, 837, 747, 423, 184]
(20, 20, 20)
```
## Background overview
NTRUEncrypt is a lattice-based alternative to RSA and ECC, and its security is based on the shortest vector problem (SVP) in a lattice. Compared to RSA's integer factorization, SVP is not known to be breakable using quantum computers. Its security relies on the difficulty of factoring a certain polynomial in a truncated polynomial ring into a quotient of two polynomials having very small coefficients. The choice of these coefficients and of the parameters is critical to the security of the system; with a poor choice it can easily be broken with an SVP attack.
## Analysis
The code follows the general NTRUEncrypt process
1. Create the polynomial ring ```R``` in ```x``` over $\mathbb{Z}/q\mathbb{Z}$ (```Integers(q)```) for the given ```q```.
2. ```S``` is the quotient of ```R``` by ```self.x ^ N - 1```.
3. ```F``` is a random polynomial with coefficients in $\{-1, 0, 1\}$, and the private key is ```f = p * F + 1```. This is the one unusual step that the textbook process doesn't do, and which I missed during the CTF: it multiplies ```F``` by ```p``` and adds 1 to the constant term. (This is the part that I didn't get during the CTF...)
4. With ```p = 3```, the constant term of ```f``` will be one of -2, 1, 4 due to this process, while every other coefficient is one of -3, 0, 3.
## Vulnerability
As I mentioned above, the choice of ```N, p, q``` is critical for a secure encryption process. There is some complex math behind how to choose ```N, p, q```, but we will skip it for now.
We know that the private key satisfies the following equation.
\begin{align}
{p}\boldsymbol{g'}=\boldsymbol{f'}*\boldsymbol{h}\bmod
q
\end{align}
This time, we will recover $f, g$ (up to rotation and sign) by solving SVP with the LLL algorithm.
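Create the $2N \times 2N$ matrix below, where $(I_1, \dots, I_n) = (1, 0, \dots, 0)$ and $H_1, \dots, H_n$ are the coefficients of $p^{-1}\boldsymbol{h} \bmod q$. Each row of a block is a cyclic shift of the row above it, so the $I$ blocks are just (row-permuted) identity matrices.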
\begin{align}\boldsymbol{L'} = \left[\begin{array}{cccccccc}
q I_{1} & q I_{2} & \cdots & q I_{n} & 0 & 0 & \cdots & 0\\
q I_{2} & q I_{3} & & q I_{1} & 0& 0 & & \vdots\\
\vdots & & \ddots & \vdots & \vdots & & \ddots & \vdots\\
q I_{n}& \cdots & \cdots & q I_{n-1} & 0 & \cdots & \cdots & 0\\
H_{1} & H_{2} & \cdots & H_{n}& \lambda I_{1} & \lambda I_{2} & \cdots & \lambda I_{n}\\
H_{2} & H_{3} & & H_{1}& \lambda I_{2} & \lambda I_{3} & & \lambda I_{1}\\
\vdots & & \ddots& \vdots & \vdots & & \ddots& \vdots \\
H_{n}& \cdots & \cdots & H_{n-1} & \lambda I_{n} & & \cdots & \lambda I_{n-1}\\
\end{array}\right]
\end{align}
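A vector in the row span of $\boldsymbol{L'}$ (the combination of the lower $n$ rows with the coefficients of $\boldsymbol{f'}$, reduced mod $q$ by the upper rows) will be of the form
\begin{align}
v'_{\boldsymbol{f'}, \boldsymbol{x}} = \left[\begin{array}{cc}
\boldsymbol{g'} & \lambda \boldsymbol{f'}
\end{array}
\right]
\end{align}
Setting $\lambda$ to 1, the second half of the output of ```M.LLL()[0]``` (i.e. ```M.LLL()[0][n:]```) will be our private key, up to rotation and sign.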
I successfully recovered $v'_{\boldsymbol{f'}, \boldsymbol{x}}$ during the CTF, but the one thing that I missed is step 4 of the Analysis. Every cyclic shift of $(\boldsymbol{f'}, \boldsymbol{g'})$ is also a short vector of the lattice, so the LLL output may be a circularly shifted (and possibly negated) copy of the key. At this point, we can find exactly one of -2, 1, 4 (or 2, -1, -4 if the sign is flipped) in the output array, while all the others are -3, 0, 3. Since the original private key has -2, 1, or 4 in its constant term, fixing the sign and shifting $\boldsymbol{f'}$ until its constant term is that number will fully recover $\boldsymbol{f}$.
## Exploit code
```sage
from Crypto.Util.number import long_to_bytes as l2b
#!/usr/bin/env sage
from sage.misc.banner import version_dict
from Crypto.Util.number import long_to_bytes as l2b
from Crypto.Util.number import bytes_to_long as b2l
from Crypto.Util.number import inverse
from Crypto.Cipher import AES
from os import urandom
#assert version_dict()["major"] >= 9
class Chall:
    """Copy of the challenge's NTRU-style cryptosystem over (Z/qZ)[x] / (x^N - 1).

    The exploit only needs __init__ and decrypt (which uses decode); the
    remaining methods are kept so the class matches the handout.
    """

    def __init__(self, N, p, q):
        """Build the coefficient ring R = (Z/qZ)[x] and its quotient S by x^N - 1."""
        self.N, self.p, self.q = N, p, q
        self.R = PolynomialRing(Integers(q), "x")
        self.x = self.R.gen()
        self.S = self.R.quotient(self.x ^ N - 1, "x")
        self.h, self.f = None, None

    def random(self):
        """Return an element of S with N coefficients, each drawn by randint(-1, 1)."""
        return self.S([randint(-1, 1) for _ in range(self.N)])

    def keygen(self):
        """Generate the key pair: f = p*F + 1 (private) and h = p * f^-1 * g (public)."""
        # Retry until f = p*F + 1 can be inverted in S.
        while True:
            self.F = self.random()
            self.f = self.p * self.F + 1
            try:
                self.z = self.f ^ -1
            except:
                continue
            break
        # Retry until g can be inverted in S (the inverse itself is discarded).
        while True:
            self.g = self.random()
            try:
                self.g ^ -1
            except:
                continue
            break
        self.h = self.p * self.z * self.g

    def getPublicKey(self):
        """Return the coefficients of h as a list."""
        return list(self.h)

    def getPrivateKey(self):
        """Return the coefficients of f as a list."""
        return list(self.f)

    def encrypt(self, m):
        """Encrypt the byte string m and return the ciphertext coefficients.

        The message integer b2l(m) is encoded into N digits (see encode) and
        masked as random() * h + encoded_message in S.
        """
        m_encoded = self.encode(b2l(m))
        e = self.random() * self.h + self.S(m_encoded)
        return list(e)

    def decrypt(self, e, privkey):
        """Decrypt ciphertext coefficients e with private-key coefficients privkey.

        Returns the plaintext as produced by long_to_bytes.
        """
        e, privkey = self.S(e), self.S(privkey)
        temp = map(Integer, list(privkey * e))
        # Move each coefficient of f*e from [0, q) to the centred range around 0.
        temp = [t - self.q if t > self.q // 2 else t for t in temp]
        # Reduce mod p, then centre again to obtain the encoded digits.
        temp = [t % self.p for t in temp]
        pt_encoded = [t - self.p if t > self.p // 2 else t for t in temp]
        pt = l2b(self.decode(pt_encoded))
        return pt

    def encode(self, value):
        """Return the N base-3 digits of value, least significant first, each shifted by -1."""
        assert 0 <= value < 3 ^ self.N
        out = []
        for _ in range(self.N):
            out.append(value % 3 - 1)
            value -= value % 3
            value /= 3
        return out

    def decode(self, value):
        """Inverse of encode: sum (digit + 1) * 3^i over the digit list."""
        out = sum([(value[i] + 1) * 3 ^ i for i in range(len(value))])
        return out

    def count(self, row):
        """Return (#entries equal to 1, #remaining entries, #entries equal to q - 1)."""
        p = sum([e == 1 for e in row])
        n = sum([e == self.q - 1 for e in row])
        return p, len(row) - p - n, n
def balancedmod(f,q):
    """Map the first n coefficients c of f to ((c + q//2) % q) - q//2, as an element of Zx.

    For integer coefficients this is the representative of c mod q in
    [-(q//2), q - q//2). Reads the module-level globals ``n`` (number of
    coefficients) and ``Zx`` (the integer polynomial ring), which must be set
    before the call.
    """
    g = [((f[i] + q//2) % q) - q//2 for i in range(n)]
    return Zx(g)
def convolution(f,g):
    """Return the product f * g reduced modulo x^n - 1 (cyclic convolution).

    Reads the module-level globals ``x`` and ``n``, which must be set before
    the call.
    """
    return (f * g) % (x^n-1)
def invertmodprime(f,p):
    """Return the inverse of f in (Z/pZ)[x] / (x^n - 1), lifted back to Zx.

    Reads the module-level globals ``Zx``, ``x`` and ``n``. Any error raised
    by the division (e.g. when f has no inverse) propagates to the caller.
    """
    T = Zx.change_ring(Integers(p)).quotient(x^n-1)
    return Zx(lift(1 / T(f)))
def normalize_private_key(coeffs, p):
    """Undo the sign and rotation ambiguity of an LLL-recovered private key.

    The real key is f = p*F + 1, so its constant term is 1 mod p and every
    other coefficient is 0 mod p. LLL returns x^k * f or -(x^k * f) for some
    unknown k, so the single coefficient that is not a multiple of p marks
    where the constant term ended up, and its residue (1 or p - 1) tells
    whether the sign was flipped.

    Returns the coefficient list with the sign fixed and rotated so that the
    marked coefficient is at index 0. Raises ValueError when coeffs does not
    have that shape.
    """
    marked = [i for i, c in enumerate(coeffs) if c % p != 0]
    if len(marked) != 1:
        raise ValueError("expected exactly one coefficient not divisible by p, got %d" % len(marked))
    k = marked[0]
    residue = coeffs[k] % p
    if residue != 1:
        if residue != p - 1:
            raise ValueError("marked coefficient is not +/-1 mod p")
        coeffs = [-c for c in coeffs]
    return coeffs[k:] + coeffs[:k]


def recover_private_key(h_coeffs):
    """Recover the private key f from public-key coefficients with LLL.

    Reads the module-level globals ``n``, ``p``, ``q``, ``Zx`` and ``x`` (the
    same globals balancedmod and convolution rely on); set them for the
    instance under attack before calling.

    Builds the 2n x 2n basis [[q*I, 0], [H, I]], where row i of H holds the
    coefficients of (p^-1 * h) * x^i, takes the second half of the first
    LLL-reduced row as a candidate for f, and normalizes its sign/rotation.
    """
    h = Zx(h_coeffs)
    hh = Integers(q)(1/p)
    h3 = balancedmod(hh*h,q)
    M = matrix(2*n)
    for i in range(n):
        M[i, i] = q
        M[i+n, i+n] = 1
        # One polynomial product per row instead of one per matrix entry.
        shifted = convolution(h3, x^i)
        for j in range(n):
            M[i+n, j] = shifted[j]
    return normalize_private_key(list(M.LLL()[0][n:]), p)


Zx.<x> = ZZ[]

# key1
h_l = [3627, 1889, 3460, 2627, 3545, 1478, 2307, 3378, 3350, 1272, 2445, 3881, 3110, 1628, 1798, 1826, 259, 1983, 453, 52, 2650, 834, 3307, 907, 2762, 3452, 1085, 3059, 3544, 1136, 3767, 2346, 1952, 699, 3023, 531, 1208, 1449, 3636, 1742, 2692, 1128, 1683, 1152, 2584, 637, 3053, 2072, 2687, 1811, 2981, 3288, 2324, 3632, 1813]
(n, p, q) = (55, 3, 4027)
e_l = [426, 3379, 3985, 160, 2502, 3592, 55, 1753, 3599, 2656, 2380, 582, 1038, 1028, 791, 1695, 1783, 3814, 3687, 3742, 1892, 1053, 2728, 3946, 801, 238, 3766, 1355, 1219, 528, 3560, 9, 3737, 1975, 1469, 85, 1373, 3717, 195, 3252, 2020, 1087, 201, 2536, 1655, 3380, 2322, 2438, 803, 2838, 1034, 457, 3050, 4010, 231]
f = recover_private_key(h_l)
chall1 = Chall(n, p, q)
print("Recovered poly f is ")
print(f)
print("Decrypting...")
key1 = chall1.decrypt(e_l, f)
# bytearray() yields the byte values on both Python 2 (str) and Python 3 (bytes).
print(list(bytearray(key1)))
print("key length is ", len(key1))
print("==================================\n\n")

# key2
h_2= [314, 1325, 1386, 176, 369, 1029, 877, 1255, 111, 1226, 117, 0, 210, 761, 938, 273, 525, 751, 1085, 372, 1333, 898, 780, 44, 649, 1463, 326, 354, 116, 1080, 1065, 1109, 358, 275, 1209, 964, 101, 950, 415, 1492, 1197, 921, 1000, 1028, 1400, 43, 1003, 914, 447, 360, 1171, 1109, 223, 1134, 1157, 1383, 784, 189, 870, 565]
e_2 = [378, 753, 466, 825, 320, 658, 630, 288, 16, 576, 134, 914, 549, 489, 197, 1392, 328, 361, 1241, 50, 710, 315, 526, 1250, 977, 453, 225, 433, 1342, 1005, 1432, 143, 1326, 1426, 1251, 1397, 237, 1202, 555, 83, 994, 446, 1406, 356, 1127, 1469, 485, 1034, 1224, 230, 1445, 825, 630, 1158, 815, 807, 837, 747, 423, 184]
(n, p, q) = (60, 3, 1499)
f = recover_private_key(h_2)
chall2 = Chall(n, p, q)
print("Recovered poly f is ")
print(f)
print("Decrypting...")
key2 = chall2.decrypt(e_2, f)
print(list(bytearray(key2)))
print("key length is ", len(key2))
print("==================================\n")

encflag = l2b(2960408776014513590203667205130185225161547470030516261741102417822093600856513664346223496713014612247754765985505434047965417819771431223015026059243409921418043319365743779292681722097463141)
# long_to_bytes drops leading zero bytes, but each half of the AES key was
# exactly 8 bytes (key[:8], key[8:]) in the challenge, so restore the padding.
key = key1.rjust(8, b"\x00") + key2.rjust(8, b"\x00")
cipher = AES.new(key, AES.MODE_ECB)
print(cipher.decrypt(encflag))
#CODEGATE2020{86f94100f760b45e9c0f6925f5b474b24387ff6be5732ab88d74b4bfbff35951}
```