# Codegate CTF 2020 Preliminary

Lattice-based attack on an NTRUEncrypt system. I couldn't solve it during the CTF because I didn't figure out why a 2 appears among the coefficients.

## Challenge detail

We are given the following Sage script:

```sage
#!/usr/bin/env sage
from sage.misc.banner import version_dict
from Crypto.Util.number import long_to_bytes as l2b
from Crypto.Util.number import bytes_to_long as b2l
from Crypto.Util.Padding import pad
from Crypto.Cipher import AES
from os import urandom

assert version_dict()["major"] >= 9


class Chall:
    # NTRU-style instance working in S = (Z/qZ)[x] / (x^N - 1).
    # Under the Sage preparser `^` is exponentiation, not XOR.
    def __init__(self, N, p, q):
        self.N, self.p, self.q = N, p, q
        self.R = PolynomialRing(Integers(q), "x")
        self.x = self.R.gen()
        self.S = self.R.quotient(self.x ^ N - 1, "x")
        self.h, self.f = None, None

    def random(self):
        # Element of S whose N coefficients are each drawn with randint(-1, 1).
        return self.S([randint(-1, 1) for _ in range(self.N)])

    def keygen(self):
        # Private key f = p*F + 1 with F ternary: every coefficient of f is a
        # multiple of p except the constant term, which receives the extra +1.
        # Any exception raised while inverting f triggers a retry with a new F.
        while True:
            self.F = self.random()
            self.f = self.p * self.F + 1
            try:
                self.z = self.f ^ -1
            except:
                continue
            break
        # g is another ternary polynomial that must also be invertible in S.
        while True:
            self.g = self.random()
            try:
                self.g ^ -1
            except:
                continue
            break
        # Public key h = p * f^-1 * g, so f*h = p*g (mod q).
        self.h = self.p * self.z * self.g

    def getPublicKey(self):
        # Coefficient list of h.
        return list(self.h)

    def getPrivateKey(self):
        # Coefficient list of f.
        return list(self.f)

    def encrypt(self, m):
        # m is a byte string: encode it as N ternary digits and compute
        # e = r*h + m for a fresh random ternary r.
        m_encoded = self.encode(b2l(m))
        e = self.random() * self.h + self.S(m_encoded)
        return list(e)

    def decrypt(self, e, privkey):
        # With privkey = f: f*e = p*r*g + f*m (mod q). Lift to the centred
        # range, reduce mod p (f = 1 mod p, so f*m = m mod p) and centre again.
        # Presumably correct only while p*r*g + f*m stays inside (-q/2, q/2].
        e, privkey = self.S(e), self.S(privkey)
        temp = map(Integer, list(privkey * e))
        temp = [t - self.q if t > self.q // 2 else t for t in temp]
        temp = [t % self.p for t in temp]
        pt_encoded = [t - self.p if t > self.p // 2 else t for t in temp]
        pt = l2b(self.decode(pt_encoded))
        return pt

    def encode(self, value):
        # N base-3 digits of value, least significant first, each shifted by
        # -1 into {-1, 0, 1}; value must therefore be below 3^N.
        assert 0 <= value < 3 ^ self.N
        out = []
        for _ in range(self.N):
            out.append(value % 3 - 1)
            value -= value % 3
            value /= 3
        return out

    def decode(self, value):
        # Inverse of encode: sum of (digit + 1) * 3^i.
        out = sum([(value[i] + 1) * 3 ^ i for i in range(len(value))])
        return out

    def count(self, row):
        # Returns (#coefficients == 1, #other coefficients, #coefficients == q-1),
        # q-1 being how -1 is represented mod q.
        p = sum([e == 1 for e in row])
        n = sum([e == self.q - 1 for e in row])
        return p, len(row) - p - n, n


def wrapper(N, p, q, pt):
    # Runs one instance and prints the public key h, the encryption of pt,
    # and count(F), i.e. how many coefficients of the ternary F are 1 / 0 / -1.
    chall = Chall(N, p, q)
    chall.keygen()
    print(chall.getPublicKey())
    print(chall.encrypt(pt))
    print(chall.count(list((chall.F))))


if __name__ == "__main__":
    # The flag is AES-ECB encrypted under a random 16-byte key; each 8-byte
    # half of that key is then encrypted with its own NTRU instance.
    key = urandom(16)
    cipher = AES.new(key, AES.MODE_ECB)
    flag = pad(open("flag.txt", "rb").read(), 16)
    enc_flag = b2l(cipher.encrypt(flag))
    print(enc_flag)

    key1, key2 = key[:8], key[8:]
    wrapper(55, 3, 4027, key1)
    wrapper(60, 3, 1499, key2)
```

and its output:

```
2960408776014513590203667205130185225161547470030516261741102417822093600856513664346223496713014612247754765985505434047965417819771431223015026059243409921418043319365743779292681722097463141
[3627, 1889, 3460, 2627, 3545, 1478, 2307, 3378, 3350, 1272, 2445, 3881, 3110, 1628, 1798, 1826, 259, 1983, 453, 52, 2650, 834, 3307, 907, 2762, 3452, 1085, 3059, 3544, 1136, 3767, 2346, 1952, 699, 3023, 531, 1208, 1449, 3636, 1742, 2692, 1128, 1683, 1152, 2584, 637, 3053, 2072, 2687, 1811, 2981, 3288, 2324, 3632, 1813]
[426, 3379, 3985, 160, 2502, 3592, 55, 1753, 3599, 2656, 2380, 582, 1038, 1028, 791, 1695, 1783, 3814, 3687, 3742, 1892, 1053, 2728, 3946, 801, 238, 3766, 1355, 1219, 528, 3560, 9, 3737, 1975, 1469, 85, 1373, 3717, 195, 3252, 2020, 1087, 201, 2536, 1655, 3380, 2322, 2438, 803, 2838, 1034, 457, 3050, 4010, 231]
(22, 18, 15)
[314, 1325, 1386, 176, 369, 1029, 877, 1255, 111, 1226, 117, 0, 210, 761, 938, 273, 525, 751, 1085, 372, 1333, 898, 780, 44, 649, 1463, 326, 354, 116, 1080, 1065, 1109, 358, 275, 1209, 964, 101, 950, 415, 1492, 1197, 921, 1000, 1028, 1400, 43, 1003, 914, 447, 360, 1171, 1109, 223, 1134, 1157, 1383, 784, 189, 870, 565]
[378, 753, 466, 825, 320, 658, 630, 288, 16, 576, 134, 914, 549, 489, 197, 1392, 328, 361, 1241, 50, 710, 315, 526, 1250, 977, 453, 225, 433, 1342, 1005, 1432, 143, 1326, 1426, 1251, 1397, 237, 1202, 555, 83, 994, 446, 1406, 356, 1127, 1469, 485, 1034, 1224, 230, 1445, 825, 630, 1158, 815, 807, 837, 747, 423, 184]
(20, 20, 20)
```

## Background overview

NTRUEncrypt is a lattice-based alternative to RSA and ECC, based on the shortest vector problem (SVP) in a lattice.
Unlike RSA, whose security rests on integer factorization, SVP is not known to be efficiently solvable even with quantum computers. NTRU's security relies on the difficulty of factoring a given polynomial in a truncated polynomial ring into a quotient of two polynomials with very small coefficients. The parameters that keep those coefficients small relative to the modulus are critical to the security of the system. If they are chosen poorly, the private key can be recovered with an SVP (lattice-reduction) attack.

## Analysis

The code follows the general NTRUEncrypt process:

1. Create the polynomial ring in ```x``` over $\mathbb{Z}/q\mathbb{Z}$ (```PolynomialRing(Integers(q), "x")```).
2. ```S``` is the quotient of ```R``` by ```x ^ N - 1``` (in Sage, ```^``` is exponentiation).
3. ```F``` is a random ternary polynomial (coefficients in $\{-1, 0, 1\}$) that determines the private key ```f = p * F + 1```. Key generation here does one unusual thing that the textbook NTRUEncrypt key generation does not, and it is the part I missed during the CTF. After multiplying ```F``` by ```p```, it adds 1 to the constant term of ```f```.
4. Because of this, with $p = 3$ the constant term of ```f``` is one of -2, 1, 4, while every other coefficient is one of -3, 0, 3.

## Vulnerability

As mentioned above, the choice of ```N, p, q``` is critical for secure encryption. There is complex math behind how to choose ```N, p, q```, which we skip for now. Since ```h = p * f^-1 * g```, the private key and the public key satisfy

\begin{align} {p}\boldsymbol{g'}=\boldsymbol{f'}*\boldsymbol{h}\bmod q \end{align}

This time, we recover $f$ and $g$ by (approximately) solving SVP with the LLL algorithm. Create the $2N \times 2N$ matrix
\begin{align}
\boldsymbol{L'} = \left[\begin{array}{cc} q I_N & 0 \\ H & \lambda I_N \end{array}\right],
\qquad
H = \left[\begin{array}{cccc}
h'_{0} & h'_{1} & \cdots & h'_{N-1}\\
h'_{N-1} & h'_{0} & \cdots & h'_{N-2}\\
\vdots & & \ddots & \vdots\\
h'_{1} & h'_{2} & \cdots & h'_{0}
\end{array}\right]
\end{align}

Here $I_N$ is the $N \times N$ identity matrix and $\boldsymbol{h'} = p^{-1}\boldsymbol{h} \bmod q$. Row $i$ of $H$ holds the coefficients of $x^i \boldsymbol{h'} \bmod (x^N - 1)$. Multiplying $\boldsymbol{L'}$ on the left by an integer row vector $[\boldsymbol{x}, \boldsymbol{f'}]$ gives $[\boldsymbol{f'} * \boldsymbol{h'} + q\boldsymbol{x},\ \lambda \boldsymbol{f'}]$. Since $\boldsymbol{f'} * \boldsymbol{h'} \equiv \boldsymbol{g'} \pmod q$, a suitable $\boldsymbol{x}$ yields the following short vector in the row span of $\boldsymbol{L'}$:

\begin{align} v'_{\boldsymbol{f'}, \boldsymbol{x}} = \left[\begin{array}{cc} \boldsymbol{g'} & \lambda \boldsymbol{f'}\end{array} \right] \end{align}

Setting $\lambda = 1$, the second half of the first row of ```M.LLL()``` (i.e. ```M.LLL()[0][n:]```) is our private key, up to sign and a cyclic shift.

I successfully recovered $v'_{\boldsymbol{f'}, \boldsymbol{x}}$ during the CTF, but missed step 4 of the Analysis. Every rotation $x^k \boldsymbol{f'}$, and its negation, gives an equally short vector in the same lattice, so LLL may return the private key negated and/or cyclically shifted. In the output, exactly one coefficient is one of -2, 1, 4 (or, if negated, 2, -1, -4), while all the others are -3, 0 or 3. The original private key has that special coefficient as its constant term.
Negating $\boldsymbol{f'}$ if necessary (so that the special coefficient is one of -2, 1, 4) and cyclically shifting it until that coefficient becomes the constant term fully recovers $\boldsymbol{f}$.

## Exploit code

```sage
#!/usr/bin/env sage
from sage.misc.banner import version_dict
from Crypto.Util.number import long_to_bytes as l2b
from Crypto.Util.number import bytes_to_long as b2l
from Crypto.Util.number import inverse
from Crypto.Cipher import AES
from os import urandom
#assert version_dict()["major"] >= 9


class Chall:
    # Copy of the challenge's class; the exploit only needs decrypt()/decode().
    def __init__(self, N, p, q):
        self.N, self.p, self.q = N, p, q
        self.R = PolynomialRing(Integers(q), "x")
        self.x = self.R.gen()
        self.S = self.R.quotient(self.x ^ N - 1, "x")
        self.h, self.f = None, None

    def random(self):
        return self.S([randint(-1, 1) for _ in range(self.N)])

    def keygen(self):
        while True:
            self.F = self.random()
            self.f = self.p * self.F + 1
            try:
                self.z = self.f ^ -1
            except:
                continue
            break
        while True:
            self.g = self.random()
            try:
                self.g ^ -1
            except:
                continue
            break
        self.h = self.p * self.z * self.g

    def getPublicKey(self):
        return list(self.h)

    def getPrivateKey(self):
        return list(self.f)

    def encrypt(self, m):
        m_encoded = self.encode(b2l(m))
        e = self.random() * self.h + self.S(m_encoded)
        return list(e)

    def decrypt(self, e, privkey):
        e, privkey = self.S(e), self.S(privkey)
        temp = map(Integer, list(privkey * e))
        temp = [t - self.q if t > self.q // 2 else t for t in temp]
        temp = [t % self.p for t in temp]
        pt_encoded = [t - self.p if t > self.p // 2 else t for t in temp]
        pt = l2b(self.decode(pt_encoded))
        return pt

    def encode(self, value):
        assert 0 <= value < 3 ^ self.N
        out = []
        for _ in range(self.N):
            out.append(value % 3 - 1)
            value -= value % 3
            value /= 3
        return out

    def decode(self, value):
        out = sum([(value[i] + 1) * 3 ^ i for i in range(len(value))])
        return out

    def count(self, row):
        p = sum([e == 1 for e in row])
        n = sum([e == self.q - 1 for e in row])
        return p, len(row) - p - n, n


def recover_private_key(h_list, n, p, q):
    """Return the f-half of LLL's shortest row for the NTRU lattice of h.

    The lattice is [[q*I_n, 0], [H, I_n]], where row i of H holds the
    coefficients of x^i * h' mod (x^n - 1) and h' = p^-1 * h mod q.
    Since f*h' = g (mod q), the row span contains the short vector (g, f),
    so the result is the private key f up to sign and a cyclic shift.

    Raises ValueError if h_list does not have exactly n coefficients.
    """
    if len(h_list) != n:
        raise ValueError("expected %d coefficients of h, got %d" % (n, len(h_list)))
    p_inv = inverse(p, q)
    # h' = p^-1 * h mod q, lifted into the centred range around 0.
    h3 = [((c * p_inv + q // 2) % q) - q // 2 for c in h_list]
    M = matrix(ZZ, 2 * n, 2 * n)
    for i in range(n):
        M[i, i] = q
        M[n + i, n + i] = 1
        for j in range(n):
            # Coefficient j of x^i * h' mod (x^n - 1) is h'[(j - i) mod n].
            M[n + i, j] = h3[(j - i) % n]
    return list(M.LLL()[0][n:])


def normalize_private_key(f, p=3):
    """Undo the sign flip and cyclic shift LLL may apply to f = p*F + 1.

    Exactly one coefficient of the true key (its constant term) is non-zero
    mod p, and it is 1 mod p (for p = 3: one of -2, 1, 4). Negate the vector
    if that coefficient is not 1 mod p, then rotate it to index 0.

    Raises ValueError if f is not a signed rotation of some p*F + 1.
    """
    special = [i for i, c in enumerate(f) if c % p != 0]
    if len(special) != 1:
        raise ValueError("LLL output is not a rotation of +-(p*F + 1): %r" % (f,))
    k = special[0]
    if f[k] % p != 1:
        f = [-c for c in f]
    if f[k] % p != 1:
        raise ValueError("LLL output is not a rotation of +-(p*F + 1): %r" % (f,))
    return f[k:] + f[:k]


def recover_key_half(h_list, e_list, n, p, q, key_len=8):
    """Recover and return one NTRU-encrypted half of the AES key, printing progress."""
    f = normalize_private_key(recover_private_key(h_list, n, p, q), p)
    print("Recovered poly f is ")
    print(f)
    print("Decrypting...")
    key_half = Chall(n, p, q).decrypt(e_list, f)
    # l2b() drops leading zero bytes; restore them so AES gets a full-size key.
    key_half = key_half.rjust(key_len, b"\x00")
    # bytearray iterates as ints on both Python 2 and Python 3 Sage.
    print(list(bytearray(key_half)))
    print("key length is ", len(key_half))
    return key_half


# key1: first 8 bytes of the AES key
h_l = [3627, 1889, 3460, 2627, 3545, 1478, 2307, 3378, 3350, 1272, 2445, 3881, 3110, 1628, 1798, 1826, 259, 1983, 453, 52, 2650, 834, 3307, 907, 2762, 3452, 1085, 3059, 3544, 1136, 3767, 2346, 1952, 699, 3023, 531, 1208, 1449, 3636, 1742, 2692, 1128, 1683, 1152, 2584, 637, 3053, 2072, 2687, 1811, 2981, 3288, 2324, 3632, 1813]
e_l = [426, 3379, 3985, 160, 2502, 3592, 55, 1753, 3599, 2656, 2380, 582, 1038, 1028, 791, 1695, 1783, 3814, 3687, 3742, 1892, 1053, 2728, 3946, 801, 238, 3766, 1355, 1219, 528, 3560, 9, 3737, 1975, 1469, 85, 1373, 3717, 195, 3252, 2020, 1087, 201, 2536, 1655, 3380, 2322, 2438, 803, 2838, 1034, 457, 3050, 4010, 231]
key1 = recover_key_half(h_l, e_l, 55, 3, 4027)
print("==================================\n\n")

# key2: last 8 bytes of the AES key
h_2 = [314, 1325, 1386, 176, 369, 1029, 877, 1255, 111, 1226, 117, 0, 210, 761, 938, 273, 525, 751, 1085, 372, 1333, 898, 780, 44, 649, 1463, 326, 354, 116, 1080, 1065, 1109, 358, 275, 1209, 964, 101, 950, 415, 1492, 1197, 921, 1000, 1028, 1400, 43, 1003, 914, 447, 360, 1171, 1109, 223, 1134, 1157, 1383, 784, 189, 870, 565]
e_2 = [378, 753, 466, 825, 320, 658, 630, 288, 16, 576, 134, 914, 549, 489, 197, 1392, 328, 361, 1241, 50, 710, 315, 526, 1250, 977, 453, 225, 433, 1342, 1005, 1432, 143, 1326, 1426, 1251, 1397, 237, 1202, 555, 83, 994, 446, 1406, 356, 1127, 1469, 485, 1034, 1224, 230, 1445, 825, 630, 1158, 815, 807, 837, 747, 423, 184]
key2 = recover_key_half(h_2, e_2, 60, 3, 1499)
print("==================================\n")

# blocksize=16 left-pads the ciphertext to whole AES blocks in case the
# integer had leading zero bytes.
encflag = l2b(2960408776014513590203667205130185225161547470030516261741102417822093600856513664346223496713014612247754765985505434047965417819771431223015026059243409921418043319365743779292681722097463141, 16)
key = key1 + key2
cipher = AES.new(key, AES.MODE_ECB)
print(cipher.decrypt(encflag))
#CODEGATE2020{86f94100f760b45e9c0f6925f5b474b24387ff6be5732ab88d74b4bfbff35951}
```