Try   HackMD

VolgaCTF Qualifiers

Knock-Knock

Image Not Showing Possible Reasons
  • The image file may be corrupted
  • The server hosting the image is unavailable
  • The image path is incorrect
  • The image format is not supported
Learn More →

Above is the picture of the challenge on the site. Below are the files linked.

knock.pcap link

task.py

import os
import time


class mersenne_rng(object):
    """Mersenne Twister generator producing 32-bit outputs from a 624-word state.

    The state list and the read index are plain attributes, so callers can
    overwrite ``state`` / ``index`` directly to resume from a known state.
    """

    def __init__(self, seed=5489):
        # Initialisation multiplier, twist offset and the tempering
        # shift/mask parameters.
        self.f = 1812433253
        self.m = 397
        self.u = 11
        self.s = 7
        self.b = 0x9D2C5680
        self.t = 15
        self.c = 0xEFC60000
        self.l = 18
        self.lower_mask = (1 << 31) - 1
        self.upper_mask = 1 << 31
        # Starting at 624 forces a twist before the first number is produced.
        self.index = 624

        # Word 0 is the seed as given; every later word is derived from its
        # predecessor and truncated to 32 bits.
        words = [seed]
        for position in range(1, 624):
            previous = words[-1]
            words.append(self.int_32(self.f * (previous ^ (previous >> 30)) + position))
        self.state = words

    def twist(self):
        """Regenerate all 624 state words in place and reset the read index."""
        for position in range(624):
            top_bit = self.state[position] & self.upper_mask
            low_bits = self.state[(position + 1) % 624] & self.lower_mask
            combined = self.int_32(top_bit + low_bits)
            shifted = combined >> 1
            if combined & 1:
                shifted ^= 0x9908b0df
            self.state[position] = self.state[(position + self.m) % 624] ^ shifted
        self.index = 0

    def get_random_number(self):
        """Return the next tempered 32-bit output, twisting first if the state is used up."""
        if self.index >= 624:
            self.twist()
        value = self.state[self.index]
        value ^= value >> self.u
        value ^= (value << self.s) & self.b
        value ^= (value << self.t) & self.c
        value ^= value >> self.l
        self.index += 1
        return self.int_32(value)

    def int_32(self, number):
        """Return the low 32 bits of ``number`` as an int."""
        return int(number & 0xFFFFFFFF)


def main():
    """Rotate the knockd knock sequence 625 times, one rotation every 5 seconds.

    Each round draws one number from the generator, splits it into its high
    and low 16-bit halves (port1, port2), rewrites /etc/knockd.conf so that
    knocking port1 then port2 opens TCP port 2222 and the reverse order
    closes it, restarts knockd and checks that the service is running.

    Raises:
        AssertionError: if ``systemctl status knockd`` does not report the
            service as active after a restart.
    """
    rng = mersenne_rng(1000) #placeholder
    for i in range(625):
        number = rng.get_random_number()
        port1 = (number >> 16) & 0xFFFF
        port2 = number & 0xFFFF

        config_lines = [
            '[options]\n',
            '    UseSyslog\n',
            '    interface = enp0s3\n',
            '[openSSH]\n',
            '    sequence = {0}, {1}\n'.format(port1, port2),
            '    seq_timeout = 5\n',
            '    command = /sbin/iptables -A INPUT -s %IP% -p tcp --dport 2222 -j ACCEPT\n',
            '    tcpflags = syn\n',
            '[closeSSH]\n',
            '    sequence = {1}, {0}\n'.format(port1, port2),
            '    seq_timeout = 5\n',
            '    command = /sbin/iptables -D INPUT -s %IP% -p tcp --dport 2222 -j ACCEPT\n',
            '    tcpflags = syn\n',
        ]
        # The context manager closes the file even if a write fails.
        with open('/etc/knockd.conf', 'w') as fd:
            fd.writelines(config_lines)

        os.system('systemctl restart knockd')
        # Explicit raise instead of ``assert`` so the health check still runs
        # under ``python -O``; the exception type is unchanged.
        status = os.popen('systemctl status knockd').read()
        if 'Active: active (running)' not in status:
            raise AssertionError('knockd is not running after restart')

        time.sleep(5)

if __name__ == "__main__":
    main()

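We are given a network capture of a series of port-knocking attempts. Port knocking is a way of hiding a service such as SSH behind a firewall: a client must first send connection attempts to a secret sequence of ports before the firewall lets it reach the real port. Here the knock ports are generated by a Mersenne Twister, an RNG that is not cryptographically secure because its future outputs can be predicted once enough outputs have been observed. Our goal is: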

  1. To discover what port entries in the PCAP file are part of the RNG
  2. Predict the next sequence in the Mersenne RNG

For step one, I wanted to install knockd and try knocking on a port on localhost, but that fell through because of my own networking issues. So I read through the main function instead and noticed that we should get two ports per RNG output: port1 is the most significant 16 bits of the output and port2 is the least significant 16 bits. To open SSH access you must knock in the order port1, port2; to close it again, the knock order is reversed: port2, port1.

After manually reading through the Wireshark packet capture, I noticed the following pattern (source -> destination):

    random_port -> port1
    random_port -> port2
    fixed_random -> 2222 (alternating like a TCP connection)
    random_port -> port2
    random_port -> port1

So, in order to scrape all the port1 and port2 ports, I looped through all the destination ports in the pcap file; I iterated by pairs, and if a reversed pair was spotted fewer than 5 pairs later, then I would add it to a list.

Finally, I was able to solve this because I did Cryptopals about a year ago. I copied my code for cloning a Mersenne Twister from 624 consecutive outputs of the generator. The premise is that the tempering applied to each output consists only of invertible bitwise operations, so every output can be reversed ("untempered") into one element of the internal state; the state has exactly 624 elements, so 624 consecutive outputs reveal all of it.

My solution is below.

from scapy.all import *
from task import mersenne_rng

sessions = rdpcap("knockd.pcap").sessions()
dports = []

# Collect TCP destination ports in session order, collapsing immediate repeats.
for session in sessions:
    for packet in sessions[session]:
        if TCP not in packet:
            # Non-TCP traffic has no TCP dport; indexing it would raise.
            continue
        dport = packet[TCP].dport
        if not dports or dport != dports[-1]:
            dports.append(dport)

# Every adjacent pair of ports is a candidate (port1, port2) knock.
candidate_pairs = [dports[i:i+2] for i in range(len(dports)-1)]
window = 5
pairs = []

# Keep a pair when the same two ports show up reversed within the next
# window-1 pairs. The membership test records each pair at most once, even
# if several reversed copies fall inside the window.
for i in range(len(candidate_pairs) - window + 1):
    reversed_pair = candidate_pairs[i][::-1]
    if reversed_pair in candidate_pairs[i+1:i+window]:
        pairs.append(candidate_pairs[i])

# Reassemble each 32-bit generator output: port1 is the high half, port2 the low half.
states = [(p[0]<<16) + p[1] for p in pairs]

# Shift amount and AND-mask for each tempering step used by temper() and
# untemper(); the final step (shift l) has no mask.
u = 11
d = 0xffffffff
s = 7
b = 0x9d2c5680
t = 15
c = 0xefc60000
l = 18
# State size and 32-bit mask; neither is referenced by the visible code.
n = 624
mask = 0xffffffff

def temper(y):
    """Apply the four tempering steps to ``y`` using the module-level shifts and masks."""
    after_u = y ^ ((y >> u) & d)
    after_s = after_u ^ ((after_u << s) & b)
    after_t = after_s ^ ((after_s << t) & c)
    return after_t ^ (after_t >> l)

def numberBits(y):
    """Return the number of bits needed to represent the integer ``y``.

    Zero has length 0. The built-in ``int.bit_length`` replaces the manual
    shift-and-count loop; it gives the same answer for every non-negative
    integer and also terminates for negative input (it uses the magnitude),
    where the loop never reached zero.
    """
    return y.bit_length()

def t1(y, rs, ad):
    """Invert ``v ^= (v >> rs) & ad`` and return the original ``v``.

    General case of t4, solved separately. Bits are rebuilt from the most
    significant end of ``y`` (over ``numberBits(y)`` bits, not a fixed 32):
    the first ``rs`` bits were XORed with nothing, and every later bit was
    XORed with the already-recovered bit ``rs`` places above it, masked by
    the matching bit of ``ad``.
    """
    width = numberBits(y)
    recovered = 0
    for done, pos in enumerate(range(width - 1, -1, -1)):
        bit = (y >> pos) & 1
        if done >= rs:
            # The bit rs places above the one being rebuilt.
            source = ((recovered << 1) >> rs) & 1
            bit ^= (ad >> pos) & source
        recovered = (recovered << 1) | bit
    return recovered

def t3(y, ls, ad):
    """Invert ``v ^= (v << ls) & ad`` over 32 bits and return the original ``v``.

    ls is the left-shift amount and ad the AND-mask. Bits are rebuilt from the
    least significant end: the lowest ``ls`` bits were XORed with nothing, and
    each higher bit was XORed with the already-recovered bit ``ls`` places
    below it, masked by the matching bit of ``ad``.
    """
    recovered = 0xffffffff
    for pos in range(32):
        bit = (y >> pos) & 1
        if pos >= ls:
            source = (recovered >> (pos - ls)) & 1
            bit ^= (ad >> pos) & source
        # Overwrite bit ``pos`` of the result with the recovered value.
        recovered = (recovered & ~(1 << pos)) | (bit << pos)
    return recovered

def t4(y):
    """Invert ``v ^= v >> l`` (module-level ``l``) and return the original ``v``.

    Special case of t1, solved separately. The leading ``l`` bits of the
    answer were XORed with 0 by the right shift, so they are copied as-is;
    every following bit is XORed with the already-recovered bit ``l`` places
    above it. Works over ``numberBits(y)`` bits.
    """
    width = numberBits(y)
    recovered = 0
    for done, pos in enumerate(range(width - 1, -1, -1)):
        bit = (y >> pos) & 1
        if done >= l:
            bit ^= ((recovered << 1) >> l) & 1
        recovered = (recovered << 1) | bit
    return recovered

def untemper(y):
    """Reverse temper(): undo its four steps in the opposite order."""
    without_l = t4(y)
    without_t = t3(without_l, t, c)
    without_s = t3(without_t, s, b)
    return t1(without_s, u, d)

# Untemper every observed output to recover the generator's internal state.
# The loop variable is not called ``s`` so that it can never clobber the
# module-level shift constant ``s`` that untemper() reads.
reversed_seeds = [untemper(observed) for observed in states]

# The seed is irrelevant: the state is overwritten with the recovered words.
rng = mersenne_rng(69)
rng.state = reversed_seeds
rng.index = 0

# Replay the 624 observed outputs. The call is made outside of ``assert`` so
# the generator is still advanced when Python runs with -O; otherwise the
# "next" number below would just be the first observed output.
for i in range(624):
    replayed = rng.get_random_number()
    if replayed != states[i]:
        raise AssertionError("cloned generator diverges from capture at output {0}".format(i))

# The next output is the knock sequence that was not captured.
final = rng.get_random_number()
port1 = (final >> 16) & 0xFFFF
port2 = final & 0xFFFF
print("VolgaCTF{" + str(port1) + "," + str(port2) + "}")

QR Codebook

Below is qr.png

Below is qr.encrypted.png

Below is flag.encrypted.png

I tried observing the plaintext, ciphertext pair QR codes. I noticed that every 16 rows of pixels, the encrypted picture would repeat some sort of pattern. I noticed that if I squinted really hard, I could see the original QR image, but some data would be missing.

I loaded the pictures in matlab and noticed that while the plaintext picture array was all 0s and 1s, the ciphertext was filled with decimals. I formed a theory that I might be able to get a better version of the ciphertext if I:

  1. Mapped every plaintext bit to a ciphertext bit
  2. If a ciphertext float was always mapped to a white bit, make the float white.
  3. If a ciphertext float was always mapped to a black bit, make the float black.
  4. If a ciphertext float sometimes mapped to either, make the float white.

The resulting picture after I applied my theory looked like

After this, I thought of a way to "fill" the gaps, and found that if I took a point, and it had a black pixel within 10 pixels down, then I should make it black. Of course, the bottom 10 rows wouldn't have 10 pixels down; for those pixels, I would search 10 pixels up.

I then reused the same sets of always-black and always-white ciphertext values to decrypt the flag.

Here is my work below:

sol.m

% Known plaintext/ciphertext pair plus the encrypted flag, as grayscale doubles.
plaintext = imread('qr.png');
ciphertext = imread('qr.encrypted.png');
ciphertext = im2gray(im2double(ciphertext));
ciphertext_flag = imread('flag.encrypted.png');
ciphertext_flag = im2gray(im2double(ciphertext_flag));

% Ciphertext values seen under black pixels, under white pixels, and under both.
black = unique(ciphertext(plaintext == 0));
white = unique(ciphertext(plaintext == 1));
C = intersect(black, white);
pure_black = setdiff(black, C);
pure_white = setdiff(white, C);
decrypt(pure_black, pure_white, C, ciphertext, 'qr.decrypted.png', 10);

% Restrict the ambiguous values to those that actually occur in the flag image.
all_points = unique(ciphertext_flag(:));
C = intersect(all_points, C);
flag = decrypt(pure_black, pure_white, C, ciphertext_flag, 'flag.png', 10);

decrypt.m

function ciphertext = decrypt(pure_black, pure_white, C, ciphertext, name, max_diff)
% DECRYPT Map known ciphertext values to black/white, smooth, show and save.
%   Pixels whose value is in pure_black become 0; pixels whose value is in
%   pure_white or in C (the ambiguous values) become 1; any other value is
%   left untouched at this stage. The result is displayed, passed through
%   smooth(ciphertext, max_diff), displayed again and written to the file NAME.
    encrypted = ciphertext;  % membership is always tested on the unmodified input
    ciphertext(ismember(encrypted, pure_black)) = 0;
    white_mask = ismember(encrypted, pure_white) | ismember(encrypted, C);
    ciphertext(white_mask) = 1;

    figure(1); imshow(ciphertext);
    ciphertext = smooth(ciphertext, max_diff);
    figure(2); imshow(ciphertext);
    imwrite(ciphertext, name);
end

smooth.m

function ciphertext = smooth(ciphertext, max_diff)
% SMOOTH Fill vertical gaps next to black pixels and binarise the image.
%   A non-black pixel is turned black (0) when a black pixel lies within
%   max_diff-1 rows below it. The last max_diff rows cannot look that far
%   down, so they look up to max_diff-1 rows above instead; that upward pass
%   sees rows already filled earlier in the same call. Finally every
%   non-zero pixel is set to 1.
%
%   Columns are bounded by the image width (not its height), so non-square
%   images are processed completely, and the upward search is clamped to
%   row 1 so short images do not index out of range.
    [n_rows, n_cols] = size(ciphertext);
    reach = max_diff - 1;

    % Upper band: look downward. The rows read here have not been modified yet.
    for i = 1:n_rows-max_diff
        below = ciphertext(i+1:i+reach, 1:n_cols);
        fill = (ciphertext(i, :) ~= 0) & any(below == 0, 1);
        ciphertext(i, fill) = 0;
    end

    % Lower band: look upward, row by row, so fills cascade exactly as in a
    % per-pixel top-to-bottom scan.
    for i = max(1, n_rows-reach):n_rows
        above = ciphertext(max(1, i-reach):i-1, 1:n_cols);
        fill = (ciphertext(i, :) ~= 0) & any(above == 0, 1);
        ciphertext(i, fill) = 0;
    end

    ciphertext(ciphertext ~= 0) = 1;
end

carry

I solved this one after the competition with the help of this writeup. I'm only going to include my solution code with comments.

from Crypto.Util.number import *
from fcsr import FCSR
import random
from pwn import *

cycle = [0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1]

# Read both files through context managers so the handles are closed.
with open("encrypted_png", "rb") as encrypted_file:
    data = encrypted_file.read()
binary = bin(bytes_to_long(data))[2:]
with open("mario.png", "rb") as reference_file:
    png_header = reference_file.read(16)
# XOR of the first 16 bytes of mario.png with the first 16 ciphertext bytes.
# Only those 16 ciphertext bytes are passed to xor() instead of the whole file.
known_clocks = xor(png_header, data[:16])[:16]

def check_cycle(q, a):
    """Print ``q.bit_length()`` followed by 100 clocked outputs of ``FCSR(q, 0, a)``."""
    register = FCSR(q, 0, a)
    bits = []
    for _ in range(100):
        bits.append(str(register.clock()))
    print(q.bit_length(), "".join(bits))

# inspect the keystream for low values of q and a
# check_cycle(16, 9)
# check_cycle(32, 24)
# check_cycle(64, 57)
# check_cycle(128, 100)
# check_cycle(256, 242)
# from cycles, we find that there will be a repeating str of q.bit_length() - 1

# For every q = 2**i (i = 4..64) with a random a, record the first i-character
# block of the 200-bit keystream that is immediately followed by a copy of itself.
main = {}

for i in range(4, 65):
    q = 1 << i
    a = random.randint(0, q-1)
    f = FCSR(q, 0, a)
    t = "".join(str(f.clock()) for _ in range(200))

    for j in range(len(t) - 2*i):
        s = t[j:j+i]
        if t.startswith(s, j+i):
            main[i] = s
            break

print("All 64 bit qs have cyclic problem?", len(main) == 61)

# The 16 known keystream bytes as a 128-character bit string.
k = bin(int(known_clocks.hex(), 16))[2:].zfill(128)
stop = False

# discover repeating cycle based on key info from png header
# Search from the longest block length and the latest offset downwards for an
# i-bit block that is immediately followed by a copy of itself. Offsets are
# bounded by len(k), so both slices always lie fully inside the known bits
# and a truncated or empty slice can never produce a false match.
for i in range(65, 3, -1):
    for j in range(len(k) - 2*i, -1, -1):
        s = k[j:j+i]
        if s == k[j+i:j+2*i]:
            stop = True
            break
    if stop:
        break

if not stop:
    # Without a cycle the decryption that follows would use meaningless i and j.
    raise ValueError("no repeating cycle found in the known keystream bits")

# extend cycle to fit length of data
# One keystream bit is needed per ciphertext bit: keep the j bits before the
# cycle, then repeat the cycle just often enough to cover 8 * len(data) bits.
repetition = k[j:j+i]
needed_bits = 8 * len(data)
repeats = max(0, -(-(needed_bits - j) // len(repetition)))  # ceiling division
keystream = (k[:j] + repetition * repeats)[:needed_bits]

# copy over FCSR's encrypt function
# Each output byte is the ciphertext byte XORed with the next 8 keystream
# bits, read most significant bit first.
flag = bytes(
    int(keystream[8*position:8*position + 8], 2) ^ byte
    for position, byte in enumerate(data)
)

with open("flag.png", "wb") as output_file:
    output_file.write(flag)