# AmateursCTF 2024 Writeup
## crypto/unsuspicious-rsa
> I need help factoring this modulus, it looks suspicious, but I can't factor using any conventional methods.
>
> Files: `output.txt` and `unsuspicious-rsa.py`
>
> Solved by: Jackylkk2003
### Files
`unsuspicious-rsa.py`
```python=
from Crypto.Util.number import *
def nextPrime(p, n):
    p += (n - p) % n
    p += 1
    iters = 0
    while not isPrime(p):
        p += n
    return p
def factorial(n):
    if n == 0:
        return 1
    return factorial(n-1) * n
flag = bytes_to_long(open('flag.txt', 'rb').read().strip())
p = getPrime(512)
q = nextPrime(p, factorial(90))
N = p * q
e = 65537
c = pow(flag, e, N)
print(N, e, c)
```
`output.txt`
```!
172391551927761576067659307357620721422739678820495774305873584621252712399496576196263035396006999836369799931266873378023097609967946749267124740589901094349829053978388042817025552765214268699484300142561454883219890142913389461801693414623922253012031301348707811702687094437054617108593289186399175149061 65537 128185847052386409377183184214572579042527531775256727031562496105460578259228314918798269412725873626743107842431605023962700973103340370786679287012472752872015208333991822872782385473020628386447897357839507808287989016150724816091476582807745318701830009449343823207792128099226593723498556813015444306241
```
### Solution
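The `nextPrime` function returns the smallest prime $q$ such that $q > p$ and $q \equiv 1 \pmod{n}$. It rounds $p$ up to a multiple of $n$, adds $1$, and then steps by $n$ until it reaches a prime.
Since $p < q$ and $N = pq$, we have $p < \sqrt{N} < q$, so $p \le \lfloor\sqrt{N}\rfloor < q$. By construction, no prime $\equiv 1 \pmod{n}$ lies strictly between $p$ and $q$. Starting the search from $\lfloor\sqrt{N}\rfloor$ instead of $p$ therefore returns the same prime $q$, and then $p = N / q$. This works without knowing $p$, because $n = 90!$ is public.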
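Here is a small offline check of this argument. It is a sketch under some assumptions: toy sizes (a 128-bit $p$ and $20!$ instead of a 512-bit $p$ and $90!$), hypothetical helper names, and a hand-written Miller-Rabin test standing in for PyCryptodome's `isPrime`.

```python
import random
from math import factorial, isqrt


def is_probable_prime(n, rounds=20):
    """Miller-Rabin test, standing in for Crypto.Util.number.isPrime."""
    if n < 2:
        return False
    for sp in (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37):
        if n % sp == 0:
            return n == sp
    d, s = n - 1, 0
    while d % 2 == 0:
        d //= 2
        s += 1
    for _ in range(rounds):
        x = pow(random.randrange(2, n - 1), d, n)
        if x in (1, n - 1):
            continue
        for _ in range(s - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            return False  # found a witness: n is composite
    return True


def next_prime_1_mod(start, m):
    """Smallest prime q > start with q % m == 1 (same search as nextPrime)."""
    q = start + (m - start) % m + 1
    while not is_probable_prime(q):
        q += m
    return q


def random_prime(bits):
    """Random prime with exactly `bits` bits (toy stand-in for getPrime)."""
    while True:
        c = random.getrandbits(bits) | (1 << (bits - 1)) | 1
        if is_probable_prime(c):
            return c


def recover_factors(N, m):
    # p <= isqrt(N) < q, and no prime q' with q' % m == 1 lies strictly
    # between p and q, so searching upward from isqrt(N) lands exactly on q.
    q = next_prime_1_mod(isqrt(N), m)
    return N // q, q


m = factorial(20)
p = random_prime(128)
q = next_prime_1_mod(p, m)
assert recover_factors(p * q, m) == (p, q)
```

Note that `recover_factors` never uses $p$. It needs only $N$ and the public step $m$.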
### Solve Script
```python=
from Crypto.Util.number import *
from math import isqrt
def nextPrime(p, n):
    p += (n - p) % n
    p += 1
    iters = 0
    while not isPrime(p):
        p += n
    return p
def factorial(n):
    if n == 0:
        return 1
    return factorial(n-1) * n
N = 172391551927761576067659307357620721422739678820495774305873584621252712399496576196263035396006999836369799931266873378023097609967946749267124740589901094349829053978388042817025552765214268699484300142561454883219890142913389461801693414623922253012031301348707811702687094437054617108593289186399175149061
e = 65537
# isqrt is exact; int(N**0.5) goes through a float whose rounding error is comparable to 90!
q = nextPrime(isqrt(N), factorial(90))
p = N // q
phi = (p - 1) * (q - 1)
d = pow(e, -1, phi)
c = 128185847052386409377183184214572579042527531775256727031562496105460578259228314918798269412725873626743107842431605023962700973103340370786679287012472752872015208333991822872782385473020628386447897357839507808287989016150724816091476582807745318701830009449343823207792128099226593723498556813015444306241
flag = pow(c, d, N)
print(long_to_bytes(flag).decode())
```
The script prints the flag: `amateursCTF{here's_the_flag_you_requested.}`.
## crypto/decryption-as-a-service
> Hi, please do not decrypt the flag. Or anything else, for that matter...
>
> `nc chal.amt.rs 1417`
>
> Files: `decryption-as-a-service.py`
>
> Solved by: Jackylkk2003
### Files
`decryption-as-a-service.py`
```python=
#!/usr/local/bin/python3
from Crypto.Util.number import *
from math import isqrt
flag = bytes_to_long(open('flag.txt', 'rb').read())
p, q = getPrime(1024), getPrime(1024)
N = p * q
e = getPrime(64)
d = pow(e, -1, N - p - q + 1)
encrypted_flag = pow(flag, e, N)
print(f"{encrypted_flag = }")
try:
    for i in range(10):
        c = int(input("message? "))
        if isqrt(N) < c < N:
            if c == encrypted_flag or c == (N - encrypted_flag):
                print("sorry, that looks like the flag")
                continue
            print(hex(pow(c, d, N))[2:])
        else:
            print("please pick a number which I can (easily) check does not look like the flag.")
except:
    exit()
print("ok bye")
```
The server also adds a proof of work at the beginning (not part of the file above) to prevent DoS. It has to be handled as well, either manually or automatically as in the solve script.
### Observations
For this problem, we can ask the server to decrypt at most 10 ciphertexts $c$. Each one must satisfy $\lfloor\sqrt{N}\rfloor < c < N$ and must be neither $encrypted\_flag$ nor $N - encrypted\_flag$. Note that we are NOT given $N$ (nor $e$, which is a random 64-bit prime).
So it is difficult for us to even make valid queries.
### Solution
First, we need a ciphertext that is always in range even though $N$ is unknown. Since $N$ is the product of two 1024-bit primes, $2^{2046} \le N < 2^{2048}$ and hence $\sqrt{N} < 2^{1024}$. The number I chose, $2^{2000}$, is therefore always in range.
For the remaining queries, I sent $2^{2001}, 2^{2002}, \dots, 2^{2008}$, because doubling the input makes the relation between the outputs easy to describe. We intentionally hold back one query for later use.
Recall (or read from the code) that RSA decryption computes $c^d \bmod N$. Doubling the ciphertext therefore multiplies the plaintext by $2^d$, since $(2c)^d = 2^d \cdot c^d$.
That is, if we name the plaintexts returned for $2^{2000}, \dots, 2^{2008}$ as $m_0, m_1, \dots, m_8$, we have the following relation:
For $i > 0$, $$m_i \equiv 2^d \cdot m_{i-1} \pmod{N}$$
This recurrence has the same form as a [linear congruential generator (LCG)](https://en.wikipedia.org/wiki/Linear_congruential_generator), with multiplier $2^d$, increment $0$ and unknown modulus $N$. We can therefore use an existing tool, [LCG Hack](https://github.com/TomasGlgg/LCGHack), to recover the LCG parameters from the consecutive outputs $m_0, \dots, m_8$.
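This short derivation shows why the gcd step works here. It follows what `crack_unknown_modulus` and `crack_unknown_multiplier` in the solve script compute. Write $t_i = m_{i+1} - m_i$ for the differences. Any increment cancels in the differences, and $t_i \equiv 2^d\, t_{i-1} \pmod{N}$, so

$$t_{i+1}\,t_{i-1} - t_i^2 \equiv (2^d)^2\, t_{i-1}^2 - (2^d)^2\, t_{i-1}^2 \equiv 0 \pmod{N}.$$

Each of these (integer) values is therefore a multiple of $N$. Their gcd is usually exactly $N$; occasionally a small extra factor survives. The multiplier then follows as

$$2^d \equiv (m_2 - m_1)\,(m_1 - m_0)^{-1} \pmod{N}.$$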
After cracking the LCG, we know $2^d \bmod N$ and $N$. We cannot recover $d$ itself from this information, because that would require a discrete logarithm modulo $N$.
Instead, we use the last query to send $encrypted\_flag \times 2^k \bmod N$ for an integer $k > 0$ such that this ciphertext is greater than $\lfloor\sqrt{N}\rfloor$. For small $k$ it is also not $\pm encrypted\_flag \bmod N$, so it passes the server's check. The server will respond with $encrypted\_flag^d \times 2^{kd} \bmod N = flag \times 2^{kd} \bmod N$. Since we know $2^d$ and $N$, we can compute the inverse of $2^d \bmod N$, and hence the inverse of $2^{kd} \bmod N$. Multiplying this inverse by the last decrypted plaintext $m_9$ gives the $flag$.
We choose $k$ this way because the product is reduced modulo $N$. $encrypted\_flag \times 2$ may exceed $N$, and after reduction the result could in principle fall to $\sqrt{N}$ or below. So we keep increasing $k$ until the reduced value is greater than $\lfloor\sqrt{N}\rfloor$; in practice $k = 1$ essentially always suffices.
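The whole attack can be sketched offline against a local stand-in for the server. Everything here is an assumption for illustration:

- the toy primes $2^{89}-1$ and $2^{107}-1$;
- $e = 65537$ in place of the secret 64-bit prime;
- the hypothetical `ToyOracle` class;
- the base exponent $150$ instead of $2000$, to fit the smaller modulus.

It also recovers $N$ slightly differently from LCG Hack. Since the increment is $0$, it uses $m_i^2 - m_{i-1} m_{i+1} \equiv 0 \pmod{N}$ directly and strips small spurious factors from the gcd. It needs Python 3.8+ for `pow` with a negative exponent.

```python
from functools import reduce
from math import gcd, isqrt

# Toy parameters (assumption, local demo only): two Mersenne primes stand in
# for the server's random 1024-bit primes, and e = 65537 for its 64-bit prime.
P, Q = 2**89 - 1, 2**107 - 1
N_SECRET = P * Q
E_SECRET = 65537
D_SECRET = pow(E_SECRET, -1, (P - 1) * (Q - 1))
FLAG = int.from_bytes(b"amateursCTF{toy}", "big")
ENCRYPTED_FLAG = pow(FLAG, E_SECRET, N_SECRET)


class ToyOracle:
    """Hypothetical local stand-in for the server: same checks, 10-query budget."""

    def __init__(self):
        self.queries = 0

    def decrypt(self, c):
        self.queries += 1
        assert self.queries <= 10, "query budget exceeded"
        assert isqrt(N_SECRET) < c < N_SECRET, "out of range"
        assert c not in (ENCRYPTED_FLAG, N_SECRET - ENCRYPTED_FLAG), "looks like the flag"
        return pow(c, D_SECRET, N_SECRET)


def strip_small_factors(g, bound=2**16):
    """Remove small spurious cofactors from the gcd; the real N has none."""
    for f in range(2, bound):
        while g % f == 0:
            g //= f
    return g


def attack(oracle, encrypted_flag, base):
    # Queries 1-9: m_i = (2^(base+i))^d mod N, so m_i = 2^d * m_(i-1) mod N.
    m = [oracle.decrypt(2 ** (base + i)) for i in range(9)]
    # Increment 0 => m_i^2 - m_(i-1) * m_(i+1) is an integer multiple of N.
    multiples = [m[i] ** 2 - m[i - 1] * m[i + 1] for i in range(1, 8)]
    n = strip_small_factors(reduce(gcd, multiples))
    two_d = m[1] * pow(m[0], -1, n) % n  # 2^d mod N
    # Query 10: blind the flag ciphertext with 2^k so it passes the checks.
    k = 1
    while encrypted_flag * pow(2, k, n) % n <= isqrt(n):
        k += 1
    response = oracle.decrypt(encrypted_flag * pow(2, k, n) % n)  # flag * 2^(kd)
    return n, response * pow(two_d, -k, n) % n


n, flag = attack(ToyOracle(), ENCRYPTED_FLAG, base=150)
print(flag.to_bytes((flag.bit_length() + 7) // 8, "big"))  # b'amateursCTF{toy}'
```

The structure mirrors the solve script: nine queries to recover $N$ and $2^d$, then one blinded query for the flag.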
### Solve Script
```python=
# import argparse
from functools import reduce
from math import gcd, isqrt
from pwn import *
from Crypto.Util.number import long_to_bytes
import sys
import subprocess
sys.setrecursionlimit(10**6) # To prevent RecursionError since numbers are too large
# parser = argparse.ArgumentParser()
# requiredNamed = parser.add_argument_group('required arguments')
# requiredNamed.add_argument('-k', '--known-elements', metavar='ELEMENT', dest='known', nargs='+', type=int,
# help='Known values', required=True)
# parser.add_argument('-m', '--modulus', metavar='MODULUS', dest='modulus', type=int, help='LCG modulus')
# parser.add_argument('-a', '--multiplier', metavar='MULTIPLIER', dest='multi', type=int, help='LCG multiplier')
# parser.add_argument('-c', '--increment', metavar='INCREMENT', dest='inc', type=int, help='LCG increment')
# parser.add_argument('-n', '--next', metavar='COUNT', dest='next', type=int, help='Calculate next values')
# args = parser.parse_args()
def egcd(a, b):
    if a == 0:
        return b, 0, 1
    else:
        g, x, y = egcd(b % a, a)
        return g, y - (b // a) * x, x
def modinv(b, n):
    g, x, _ = egcd(b, n)
    if g == 1:
        return x % n
def crack_unknown_increment(states, modulus, multiplier):
    increment = (states[1] - states[0] * multiplier) % modulus
    return increment
def crack_unknown_multiplier(states, modulus):
    multiplier = (states[2] - states[1]) * modinv(states[1] - states[0], modulus) % modulus
    return multiplier
def crack_unknown_modulus(states):
    diffs = [s1 - s0 for s0, s1 in zip(states, states[1:])]
    zeroes = [t2 * t0 - t1 * t1 for t0, t1, t2 in zip(diffs, diffs[1:], diffs[2:])]
    modulus = abs(reduce(gcd, zeroes))
    return modulus
class LCG:
    # Xn = (a*Xn-1 + c) % n
    def __init__(self, seed, a, c, m):
        self.seed = seed
        self.a = a
        self.c = c
        self.m = m
    def next(self):
        self.seed = (self.a * self.seed + self.c) % self.m
        return self.seed
context.log_level = 'DEBUG'
io = remote('chal.amt.rs', 1417)
# Solve the proof of work by running the command the server prints
io.recvuntil(b'proof of work:\n')
pw = io.recvline().decode().strip()
pro = subprocess.run(pw, shell=True, stdout=subprocess.PIPE) # Could be vulnerable! Use with caution
print(io.recvuntil(b'solution: ').decode())
io.send(pro.stdout)
io.recvuntil(b'encrypted_flag = ')
known_elements = [] # args.known
encrypted_flag = int(io.recvline().decode().strip())
# Queries 1-9: decrypt 2^2000, ..., 2^2008 to get m_0, ..., m_8
for i in range(9):
    io.sendlineafter(b'message? ', str(2**(2000+i)).encode())
    known_elements.append(int(io.recvline().decode().strip(), 16))
modulus = None # args.modulus
multiplier = None # args.multi
increment = None # args.inc
if modulus is None:
    if len(known_elements) < 6:
        print('At least 6 known values are needed to calculate the modulus')
        exit()
    modulus = crack_unknown_modulus(known_elements)
if multiplier is None:
    if len(known_elements) < 3:
        print('At least 3 known values are needed to calculate the multiplier')
        exit()
    multiplier = crack_unknown_multiplier(known_elements, modulus)
if increment is None:
    if len(known_elements) < 2:
        print('At least 2 known values are needed to calculate the increment')
        exit()
    increment = crack_unknown_increment(known_elements, modulus, multiplier)
print('''Modulus: {}
Multiplier: {}
Increment: {}'''.format(modulus, multiplier, increment))
# multiplier = 2 ** d
# modulus = N
# increment = 0
# Query 10: blind the flag ciphertext with the smallest k >= 1 that keeps it above isqrt(N)
cnt = 1
while isqrt(modulus) >= encrypted_flag * (2 ** cnt) % modulus:
    cnt += 1
io.sendlineafter(b'message? ', str(encrypted_flag * (2 ** cnt) % modulus).encode())
# Undo the blinding: multiply by (2^d)^(-cnt) mod N
flag = int(io.recvline().decode().strip(), 16) * pow(multiplier, -cnt, modulus) % modulus
print(long_to_bytes(flag).decode())
# if args.next is not None:
# print('\nCalculating next values:')
# lcg = LCG(known_elements[-1], multiplier, increment, modulus)
# for _ in range(args.next):
# print(lcg.next())
```
The script prints the flag: `amateursCTF{wtf_why_is_this_rsa_but_you_dont_provide_public_key_this_isnt_how_rsa_works?!?!_0b8ee05d}`.
## algo/orz-larry
> I wrote code to try to solve the problem the omniscient god solved but it was too slow (it works tho!!). Can you help?
>
> `nc chal.amt.rs 1412`
>
> Don't forget to orz larry!
>
> You are supposed to optimize the provided solution given in `lib.rs` to complete the problem in the given memory and time constraints.
>
> Files: `Cargo.lock`, `Cargo.toml`, `main.rs`, `lib.rs`, `Dockerfile`
>
> Solved by: Jackylkk2003
### Important files
`lib.rs`
```rust=
pub const MOD: u32 = 1e9 as u32 + 9;
pub mod omniscient_god;
pub mod gen {
    use rand::prelude::*;
    const ALPHABET: &[u8] = b"orzlarryhowistheomniscientgodsoorzomgiwanttobelikehim:place_of_worship::star_struck:15969056191281289460760WOIGHOPWAPWQOPHAZQSWDFGBHJOP;L.HQWUORFHPztajawehuiprvgiugIP";
    pub fn rand_string(rng: &mut impl Rng, min_n: usize, max_n: usize) -> String {
        let n = rng.gen_range(min_n..=max_n);
        (0..n)
            .map(|_| ALPHABET[rng.gen_range(0..ALPHABET.len())] as char)
            .collect()
    }
}
pub mod my_code {
    use super::MOD;
    use std::collections::HashSet;
    fn on(mask: usize, bit: usize) -> bool {
        ((mask >> bit) & 1) == 1
    }
    pub fn solve(s: &str) -> u32 {
        let n = s.len();
        let arr = s.as_bytes();
        let mut ans = HashSet::new();
        for mask in 0..(1usize << n) {
            ans.insert(
                (0..n)
                    .filter(|&i| on(mask, i))
                    .map(|i| arr[i])
                    .collect::<Vec<_>>(),
            );
        }
        (ans.len() % (MOD as usize)) as u32
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// verify my brute force by comparing against the omniscient god
    #[test]
    fn stress_test() {
        let mut rng = rand::thread_rng();
        for _ in 0..1000 {
            let s = gen::rand_string(&mut rng, 1, 15);
            assert_eq!(my_code::solve(&s), omniscient_god::solve(&s));
        }
    }
}
```
The most important part is the `solve` function in `my_code`.
### Task
The `solve` function enumerates all [bitmasks](https://en.wikipedia.org/wiki/Mask_(computing)) whose length equals the string length. You can treat a bitmask as a binary string, possibly with leading zeros. For each mask, it keeps the characters at the positions whose bit is set. Each resulting string is inserted into a `HashSet`, and the size of the set is returned. This is exponential in both time and memory.
In other words, the task is to count the number of distinct subsequences of the string, including the empty subsequence from mask $0$, modulo $10^9+9$.
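To make the counting convention concrete, here is a minimal Python transcription of the Rust brute force. The function name is our own, and it works on `str` rather than bytes, which is equivalent for this ASCII alphabet. Note that mask $0$ contributes the empty string, so it is counted too.

```python
def count_distinct_subsequences_brute(s, mod=10**9 + 9):
    """Python transcription of my_code::solve: O(2^n * n), tiny inputs only."""
    n = len(s)
    seen = set()
    for mask in range(1 << n):
        # keep s[i] exactly when bit i of mask is set (mask 0 gives "")
        seen.add("".join(s[i] for i in range(n) if (mask >> i) & 1))
    return len(seen) % mod


print(count_distinct_subsequences_brute("aba"))  # 7: "", a, b, aa, ab, ba, aba
```

It is only usable for very short strings, which is exactly the problem the challenge asks us to fix.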
### Solution
We can simply adapt the standard $O(n)$ dynamic-programming [solution online](https://www.geeksforgeeks.org/count-distinct-subsequences/), using pwntools for the input and output.
```python=
from pwn import *
MAX_CHAR = 256
MOD = 1000000009
def solve(ss):
    # last[c] = index of the last occurrence of character c (-1 if not seen yet)
    last = [-1 for i in range(MAX_CHAR + 1)]
    # length of input string
    n = len(ss)
    # dp[i] stores the count of distinct
    # subsequences of the prefix ss[0...i-1]
    dp = [0 for i in range(n + 1)]
    # empty substring has only
    # one subsequence
    dp[0] = 1
    # Traverse through all lengths
    # from 1 to n
    for i in range(1, n + 1):
        # each subsequence of ss[0...i-2] can be
        # taken with or without ss[i-1]
        dp[i] = 2 * dp[i - 1] % MOD
        # if current character has appeared
        # before, subtract the subsequences already
        # counted through its previous occurrence
        if last[ord(ss[i - 1])] != -1:
            dp[i] = (dp[i] - dp[last[ord(ss[i - 1])]]) % MOD
        last[ord(ss[i - 1])] = i - 1
    return dp[n] % MOD
context.log_level = 'DEBUG'
r = remote('chal.amt.rs', 1412)
n = int(r.recvuntil(b'\n').decode().strip())
for _ in range(n):
    s = r.recvuntil(b'\n').decode().strip()
    r.sendline(str(solve(s)).encode())
r.interactive()
```
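The script implements the following recurrence, written with 1-indexed positions; the script's `last` array stores $j - 1$, the 0-based index of the previous occurrence. Let $dp_i$ be the number of distinct subsequences, including the empty one, of $s_1 \dots s_i$, with $dp_0 = 1$. Appending $s_i$ to every subsequence of $s_1 \dots s_{i-1}$ doubles the count. However, if $s_i$ last occurred at position $j < i$, then the $dp_{j-1}$ strings formed by a subsequence of $s_1 \dots s_{j-1}$ followed by $s_i$ are already subsequences of $s_1 \dots s_{i-1}$ ending at position $j$, so they must be subtracted:

$$dp_i = \begin{cases} 2\,dp_{i-1} & \text{if } s_i \text{ does not occur in } s_1 \dots s_{i-1},\\ 2\,dp_{i-1} - dp_{j-1} & \text{if } j < i \text{ is the last position with } s_j = s_i, \end{cases}$$

with all values taken modulo $10^9+9$. For example, for `aba`: $dp_0 = 1$, $dp_1 = 2$, $dp_2 = 4$, and $dp_3 = 2 \cdot 4 - dp_0 = 7$. This matches the seven subsequences `""`, `a`, `b`, `aa`, `ab`, `ba`, `aba`.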
Final output and flag: `Yay! Good job, here's your flag and remember to orz larry: amateursCTF{orz-larry-how-is-larry-so-orz-4efe27a2edde418184d668992819a62fa4b3a7e6ba5ac3a204be9a66ed7b7105}`.