# SHA256-CTR

SDCTF 2023 - Crypto

Exploiting SHA length extension attacks in a counter scheme

Writeup by [Arnav Vora](https://github.com/AVDestroyer) for [PBR | UCLA](https://pbr.acmcyber.com/).

> SHA-256 seems to be a nice construction for a stream cipher

> Attachments: [sha256ctr.py](https://github.com/AVDestroyer/CTF-Writeups/blob/main/sdctf2023/sha256-ctr.py)

## The challenge

### Initial steps

The encryption scheme gives us 3 options: encrypt the flag, encrypt our own message, and simulate encrypting N blocks.

Encrypting the flag prints a long hex string. If we encrypt our own message, we can quickly see that the length of the hex output matches the number of bytes we encrypt. This makes sense, since CTR mode turns a block cipher into a stream cipher (more on what that means soon!). Remember that every byte is 2 hex characters, so based on the length of the encrypted flag, I knew the flag was around 63 characters long.

Now, we need to examine the source code.

### Examining source code

I first examined each of the 3 options. The first and second (encrypting the flag and encrypting our own input) appear pretty straightforward: the flag/text is turned into hex and fed into the `encrypt()` function.

The third option is the interesting one: it allows us to add an arbitrary amount to a `counter` variable introduced earlier in the code. There is also a comment saying that this behavior can be emulated by repeatedly calling the first and second options, so there shouldn't be a security issue. That comment immediately stuck out to me. Abusing this feature is probably going to be part of our solution.

The encryption function itself XORs the input with the output of another function, `next_key_block()`, using a helper called `xor()`. `xor()` is very straightforward: it computes a bitwise XOR of two byte strings, stopping when either of them ends.
`next_key_block()` is called once for every `SHA256_SIZE` (32) bytes of our input, so if our input is 64 bytes long, the first 32 bytes are XORed with one output of `next_key_block()` and the next 32 bytes are XORed with the following output. `next_key_block()` takes a global `counter` variable (an integer), converts it into a little-endian byte string, and computes the SHA256 hash of those bytes. A SHA256 digest is 32 bytes, which is why only 32 bytes can be XORed with each output of this function. Then, `counter` is incremented by 1.

Immediately, I saw the issue with the third option. All 3 options are rate limited, but the third option lets us add an arbitrary number n to the counter with a single call to `rate_limit()`, instead of the n calls it would take using the first 2 options. This allows us to add very large numbers that would be infeasible to reach with the other 2 options. Our solution likely needs to use this somehow.

### What is CTR?

The source code prints the following:

```py
print('Welcome to the demo playground for my unbreakable SHA256-CTR encryption scheme')
print('It is inspired by AES-CTR + SHA256, none of which has been shown to be breakable')
```

Looking up AES-CTR, it is a mode of operation that turns a block cipher (like AES) into a stream cipher. The keystream is generated by encrypting a "counter", which lines up with what the source code does. Instead of encrypting with AES or a similar block cipher, though, this scheme hashes the counter with SHA-256. At first glance, there shouldn't be a major issue with this, since SHA-256 is supposed to produce pseudorandom output, which is what we want in a keystream. I did notice that a typical CTR scheme encrypts a counter combined with an IV/nonce, while the source code hashes the counter directly. Perhaps that is important?

Remember, my goal is to recover the flag.
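
To make the construction concrete, here is a minimal Python model of the keystream as described above. It is a reconstruction from this writeup, not the challenge's actual source: the class name `Sha256Ctr`, the random 256-bit starting counter, and the minimal-length little-endian encoding are assumptions.

```python
import hashlib
import secrets

SHA256_SIZE = 32  # one SHA-256 digest = one keystream block


def xor(sa: bytes, sb: bytes) -> bytes:
    # Bytewise XOR that stops at the shorter input, like the challenge's xor()
    return bytes(a ^ b for a, b in zip(sa, sb))


class Sha256Ctr:
    """Hypothetical model of the SHA256-CTR scheme described above (not the real source)."""

    def __init__(self, counter=None):
        # Assumption: the counter starts as a random 256-bit integer
        self.counter = secrets.randbits(256) if counter is None else counter

    def next_key_block(self) -> bytes:
        # Assumption: minimal-length little-endian encoding of the counter
        raw = self.counter.to_bytes((self.counter.bit_length() + 7) // 8, 'little')
        self.counter += 1
        return hashlib.sha256(raw).digest()

    def encrypt(self, data: bytes) -> bytes:
        # One fresh keystream block per 32 bytes of input; no IV/nonce is mixed in
        return b''.join(xor(data[i:i + SHA256_SIZE], self.next_key_block())
                        for i in range(0, len(data), SHA256_SIZE))

    def skip(self, n: int) -> None:
        # Option 3: advance the counter by n in a single call
        self.counter += n


# Encrypting 32 zero bytes hands back the raw keystream block H(counter)
ctr = Sha256Ctr(counter=12345)
block = ctr.encrypt(bytes(32))
print(block == hashlib.sha256((12345).to_bytes(2, 'little')).digest())  # True
```

Nothing besides `counter` goes into the hash, which is exactly the missing IV/nonce noticed above.
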
Since the flag's bytes are XORed with a "keystream", if we can predict the keystream, we can recover the flag. The predicted keystream block can be arbitrarily far in the future thanks to option 3, which lets us skip forward to any future counter value. But SHA-256 is a good hash function and is supposed to behave pseudorandomly, so how can we predict it?

### How can we attack SHA-256?

I found this [post](https://crypto.stackexchange.com/questions/37566/can-we-use-a-sort-of-hash-function-in-ctr-mode-instead-of-a-block-cipher) pretty early on. My suspicions were confirmed: there is nothing inherently wrong with using a hash function instead of a cipher to process the counter, but the *construction* may be insecure. The post links to a Wikipedia article about hash length extension attacks, which can occur if the construction is insecure (such as hashing the counter directly without an IV/nonce).

#### Length extension

Certain hash algorithms, including SHA-256, are vulnerable to the [hash length extension attack](https://en.wikipedia.org/wiki/Length_extension_attack). This is because they use a [Merkle-Damgård construction](https://en.wikipedia.org/wiki/Merkle%E2%80%93Damg%C3%A5rd_construction). If we know Hash(message<sub>1</sub>) and the length of message<sub>1</sub>, but not necessarily message<sub>1</sub> itself, and we control a message<sub>2</sub>, we can compute Hash(message<sub>1</sub> \|\| message<sub>2</sub>), where the \|\| symbol indicates concatenation.

I immediately thought that this could work in our challenge: our message<sub>2</sub> could be the number we add to `counter` with option 3, and then we could predict Hash(message<sub>1</sub> \|\| message<sub>2</sub>) = Hash(`counter` + n).

There's just a small problem with this. Let's say `counter` started as 70, and we wanted to append another 70 to get 7070. What should we add to it? 7000? Appending digits on the right shifts the existing digits up, so the amount to add depends on `counter` itself. Keep in mind that we don't have access to the value of `counter`, so how would we know what to add to it?
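
For contrast, here is the decimal version of the problem as a tiny sketch (`decimal_append_delta` is a hypothetical helper, just for illustration). The amount to add works out to `counter * (10**len(suffix) - 1) + int(suffix)`, so it depends on the unknown `counter`.

```python
def decimal_append_delta(counter: int, suffix: str) -> int:
    """Amount to add so that the decimal digits of `counter` gain `suffix` on the right."""
    return int(str(counter) + suffix) - counter


print(decimal_append_delta(70, "70"))  # 7000
print(decimal_append_delta(71, "70"))  # 7099: a different counter needs a different amount
```
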
I put aside this length extension idea for a while because of this limitation and started to look for other attacks. However, aside from length extension, SHA256 acts as a *pretty good* pseudorandom function, and I didn't find anything. Everything I did find linked back to the fact that no IV/nonce is used when calculating the hash, which leads straight back to length extension.

But then I realized: when hashing `counter`, it is stored in *little endian*. The least significant byte is stored first! That means a counter value of `0x12345678` (hex) would be stored as `0x78 0x56 0x34 0x12` and then passed into the hash. This makes concatenation implementable with addition: if we wanted to append the byte `0x9a` to get `0x78 0x56 0x34 0x12 0x9a`, we would add `0x9a*(1 << (4*8))`, where `a << b == a*2^b`. We use 4*8 for the exponent since we want to move forward by 4 bytes, and each byte is 8 bits. Unlike the decimal case, this amount depends only on the *length* of `counter`, not on its value. Thus, we can append arbitrary bytes just by adding large numbers to the counter (exploiting the third option, which costs only 1 call to `rate_limit()`).

I found the following [Python code](https://github.com/stephenbradshaw/hlextend) to perform a length extension attack. Small side note: length extension attacks are usually used to inject payloads into requests, where the hash we want to compute is Hash(secret \|\| data \|\| malicious data), given Hash(secret \|\| data) and the length of secret, with both data and malicious data under our control. For our purposes, we can simply set data to the empty string.

Anyway, I tested out this library by running the source code locally while printing out the counter value each time I encrypted something. I first encrypted the flag, then added a certain value using option 3 to predict the next counter/hash, then encrypted the flag again.
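
Before running anything against the program, the little-endian trick itself is easy to check in isolation. In this sketch, `le_append_delta` is a hypothetical helper that assumes the counter's encoding is exactly `length` bytes long; the amount it returns depends only on that length, never on the counter's value.

```python
import secrets


def le_append_delta(length: int, suffix: bytes) -> int:
    """Amount to add to a `length`-byte little-endian counter so its encoding gains `suffix`."""
    # Shift the suffix past the existing `length` bytes (8 bits per byte)
    return int.from_bytes(suffix, 'little') << (8 * length)


counter = 0x12345678
delta = le_append_delta(4, b'\x9a')
print(hex(delta))                               # 0x9a00000000
print((counter + delta).to_bytes(5, 'little'))  # b'xV4\x12\x9a'

# The same delta works for every 4-byte counter, whatever its value
for _ in range(5):
    c = secrets.randbits(32)
    assert (c + delta).to_bytes(5, 'little') == c.to_bytes(4, 'little') + b'\x9a'
```
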
For this test, I created a dummy flag.txt file with the text `flag`.

```py
>>> num = (1 << (32 * 8)) - 1
>>> num
115792089237316195423570985008687907853269984665640564039457584007913129639935
```

This is the number to add. `counter` is chosen as a random 256-bit (32-byte) number, so I left-shift 1 by 256 bits (32 bytes). I subtract 1 from `num` because we want to compute `counter` \|\| `0x01`, but after we encrypt the flag once, `counter` increases by 1 (the 4-byte dummy flag uses a single keystream block), which we need to account for when calculating the concatenation.

Now, I run the program: `python3 sha256ctr.py`

Inputs:

```
1, 3, 115792089237316195423570985008687907853269984665640564039457584007913129639935, 1
```

Output:

```py
counter_initial = b'\xf3\x1b]\x12@\xf6\x9b\xce\xbc\xbe\xb3\xeb\x9dQ\x96@>$\x0b\xb6\x9b5\x97a\x94\xed\x94\x18\n\xa2\xa4\xc4'
counter_final = b'\xf3\x1b]\x12@\xf6\x9b\xce\xbc\xbe\xb3\xeb\x9dQ\x96@>$\x0b\xb6\x9b5\x97a\x94\xed\x94\x18\n\xa2\xa4\xc4\x01'
```

As we can see, I successfully appended the byte `0x01` to `counter` without needing to know its value! Now let's try a hash extension.
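
The `- 1` in `num` can also be sanity-checked offline. This sketch assumes a full 32-byte random counter (the top bit is forced on so the minimal little-endian encoding really is 32 bytes) and that the 4-byte dummy flag consumes exactly one keystream block; `encode` is a hypothetical helper mirroring that assumed encoding.

```python
import secrets


def encode(counter: int) -> bytes:
    # Assumption: minimal-length little-endian encoding, as the writeup describes
    return counter.to_bytes((counter.bit_length() + 7) // 8, 'little')


num = (1 << (32 * 8)) - 1
counter = secrets.randbits(256) | (1 << 255)  # force a 32-byte counter
counter_initial = encode(counter)

counter += 1    # option 1: the 4-byte dummy flag uses one keystream block
counter += num  # option 3
counter_final = encode(counter)

print(counter_final == counter_initial + b'\x01')  # True
```

Without the `- 1`, the encoding would still end in `\x01`, but its first (least significant) byte would be off by one, so it would no longer start with `counter_initial`.
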
I have everything: Hash(secret \|\| data) = `counter_hash`, data = `b''`, len(secret) = 32 (very likely, since `counter` is a random 32-byte integer and the probability that its top byte is 0 is only 1/256), and malicious data = `b'\x01'`:

```py
>>> import hlextend
>>> import hashlib
>>> sha = hlextend.new('sha256')
>>> counter_initial = b'\xf3\x1b]\x12@\xf6\x9b\xce\xbc\xbe\xb3\xeb\x9dQ\x96@>$\x0b\xb6\x9b5\x97a\x94\xed\x94\x18\n\xa2\xa4\xc4'
>>> counter_final = b'\xf3\x1b]\x12@\xf6\x9b\xce\xbc\xbe\xb3\xeb\x9dQ\x96@>$\x0b\xb6\x9b5\x97a\x94\xed\x94\x18\n\xa2\xa4\xc4\x01'
>>> counter_hash = hashlib.sha256(counter_initial).hexdigest()
>>> counter_hash
'af669ce186bba411111b6eec52148ddad25e69ce6fe3ff53d10b6af077b758cf'
>>> sha.extend(b'\x01', b'', 32, counter_hash)
b'\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x01'
>>> sha.hexdigest()
'86a7dae7c2278a48e96b382a9f1ecbaf05190f7a685d9f99c5e4179360742173'
>>> counter_final_hash = hashlib.sha256(counter_final).hexdigest()
>>> counter_final_hash
'9230b49d88dd31bcb260c4058aefaf5c9c2066cf3ed3fee14a234d6f7ff54738'
```

Wait, why are the hash from `hlextend` and the actual hash of the final counter different? Is something wrong with this library? I tried some other implementations of this attack and they all returned the same hash. What's going on here?

### Padding

After some frustrated googling, I figured out the issue. Look at the bytes `sha.extend()` returned: they are not just `b'\x01'`, but `b'\x80'`, a run of `\x00` bytes, `\x01\x00`, and only then our `b'\x01'`. A length extension attack resumes SHA-256 from the internal state left after hashing the *padded* secret, so what it actually computes is Hash(secret \|\| padding \|\| appended data). It only matches the real hash if that exact padding is part of the message. When computing a hash H(secret), the secret is padded until its length is 56 bytes modulo 64, and then 8 more bytes are added, so the padded message is a multiple of the 64-byte SHA-256 block size. [Source](https://seedsecuritylabs.org/Labs_16.04/PDF/Crypto_Hash_Length_Ext.pdf).
Here are the padding rules:

- The byte `\x80` is added after secret.
- Then, `\x00` bytes are added until the length of the string is 56 bytes (more generally, 56 modulo 64).
- Finally, the length of secret in bits (not counting the `\x80` or `\x00` bytes) is stored in the remaining 8 bytes. This is stored in **big endian**, so a length of 32 bytes/256 bits would be stored as `\x00 \x00 \x00 \x00 \x00 \x00 \x01 \x00`.

In order to perform a hash length extension attack, we must preserve this padding and add our own payload *after* it. So, I need to re-calculate `num` to account for this padding. Specifically, the 33rd byte (index 32) needs to be `\x80`, the 63rd byte (index 62) needs to be `\x01`, and the 65th byte (index 64) can be the actual payload (I'm going to use `\x41`, or `A`, because I feel like it). Here is what the actual number to add must be:

```python
>>> num = (65 << (8*64)) + (1 << (8*62)) + (128 << (8*32)) - 1
>>> num
871507720033181804981178500707736050011105791878633447243580836547625167963157650171703886058243710254943004764744569836035642451052111735108265899022352383
```

Note that 65 is `0x41` in hex, and 128 is `0x80`. I subtract 1 for the same reason as before. Let's see if hash length extension works now:

Inputs:

```
1, 3, 871507720033181804981178500707736050011105791878633447243580836547625167963157650171703886058243710254943004764744569836035642451052111735108265899022352383, 1
```

Output:

```py
counter_initial = b'@\xf9\x8e_\x97\xefoRw\x02\xf9\xb0\xb4\xf8\x171\x07&v<\xa1\xe7p\xdb\xddKKVx\x06>\xd9'
counter_final = b'@\xf9\x8e_\x97\xefoRw\x02\xf9\xb0\xb4\xf8\x171\x07&v<\xa1\xe7p\xdb\xddKKVx\x06>\xd9\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00A'
```

As we can see, I successfully appended the padding and payload to `counter` the way I need to.
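
The hand-written offsets can be derived mechanically from these padding rules. In this sketch, `sha256_glue` and `extension_offset` are hypothetical helpers (not part of `hlextend` or the challenge). The idea: write the target encoding as the secret's bytes set to zero, followed by the padding and the payload, read it as a little-endian integer, and subtract the counter increments that happen before the block we want to land on.

```python
def sha256_glue(msg_len: int) -> bytes:
    """SHA-256 padding for a msg_len-byte message: 0x80, zeros, 64-bit big-endian bit length."""
    zeros = (55 - msg_len) % 64  # makes msg_len + 1 + zeros == 56 (mod 64)
    return b'\x80' + b'\x00' * zeros + (8 * msg_len).to_bytes(8, 'big')


def extension_offset(secret_len: int, payload: bytes, adjust: int) -> int:
    """Number to add (via option 3) so a secret_len-byte counter becomes counter || glue || payload.

    `adjust` compensates for counter increments: 1 when one block was used right before
    option 3; 3 when two flag blocks were used and the target is the flag's *second* block.
    """
    target = bytes(secret_len) + sha256_glue(secret_len) + payload
    return int.from_bytes(target, 'little') - adjust


print(extension_offset(32, b'A', 1) == (65 << (8*64)) + (1 << (8*62)) + (128 << (8*32)) - 1)  # True
print(sha256_glue(32).hex())
```

The same helper also reproduces the second offset used later for the other half of the flag, `extension_offset(65, b'A', 3)`, since the once-extended counter is 65 bytes long.
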
Now, I do the hash extension:

```py
>>> import hlextend
>>> import hashlib
>>> sha = hlextend.new('sha256')
>>> counter_initial = b'@\xf9\x8e_\x97\xefoRw\x02\xf9\xb0\xb4\xf8\x171\x07&v<\xa1\xe7p\xdb\xddKKVx\x06>\xd9'
>>> counter_final = b'@\xf9\x8e_\x97\xefoRw\x02\xf9\xb0\xb4\xf8\x171\x07&v<\xa1\xe7p\xdb\xddKKVx\x06>\xd9\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00A'
>>> counter_hash = hashlib.sha256(counter_initial).hexdigest()
>>> counter_hash
'4ad090066717a2c822cd11de14caceb0209d14a5861229a022e995553761a6a4'
>>> sha.extend(b'A', b'', 32, counter_hash)
b'\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00A'
>>> sha.hexdigest()
'653f875bb4b5e868395f679c49adf65855cadf3ffc8dc493d6973fd333207405'
>>> counter_final_hash = hashlib.sha256(counter_final).hexdigest()
>>> counter_final_hash
'653f875bb4b5e868395f679c49adf65855cadf3ffc8dc493d6973fd333207405'
```

Amazing, the hashes match up!

Let's step back a second. What have we achieved? Using this library, we were able to predict the hash of `counter` after adding `num` to it, without needing to know the value of `counter` (I only used it here to check my work). That means we have found a *case* where SHA-256 is **not a cryptographically secure source of randomness** (in other words, used this way it does not behave like a pseudorandom function)! I was able to predict a future output based only on the current output of SHA-256, with no knowledge of its input. This will prove very helpful in the final solve: if we can predict the hash H used at any point, we can recover the corresponding plaintext (such as the flag) by computing ciphertext XOR H. Note that in *most* cases, SHA-256 is a good source of randomness; it failed here because it was used inappropriately, not because it is fundamentally flawed as a hashing algorithm.
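
`hlextend` does the heavy lifting above. To show what it is doing, here is a minimal from-scratch sketch (not `hlextend`'s code) that loads a known SHA-256 digest back into the eight 32-bit state words and keeps compressing from there. The function names are hypothetical; the round constants are computed from the first primes as specified in FIPS 180-4 rather than typed in, and `hashlib` is used only to check the results.

```python
import hashlib
import math
import os
import struct

MASK = 0xFFFFFFFF


def _first_primes(n):
    primes, k = [], 2
    while len(primes) < n:
        if all(k % p for p in primes):
            primes.append(k)
        k += 1
    return primes


def _icbrt(n):
    # Integer cube root: largest x with x**3 <= n (binary search)
    lo, hi = 0, 1 << ((n.bit_length() + 2) // 3 + 1)
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if mid ** 3 <= n:
            lo = mid
        else:
            hi = mid - 1
    return lo


# SHA-256 constants: fractional bits of square/cube roots of the first primes
_P = _first_primes(64)
H0 = [math.isqrt(p << 64) & MASK for p in _P[:8]]
K = [_icbrt(p << 96) & MASK for p in _P]


def _rotr(x, n):
    return ((x >> n) | (x << (32 - n))) & MASK


def _compress(state, block):
    """One SHA-256 compression of a 64-byte block into an 8-word state."""
    w = list(struct.unpack('>16I', block))
    for i in range(16, 64):
        s0 = _rotr(w[i - 15], 7) ^ _rotr(w[i - 15], 18) ^ (w[i - 15] >> 3)
        s1 = _rotr(w[i - 2], 17) ^ _rotr(w[i - 2], 19) ^ (w[i - 2] >> 10)
        w.append((w[i - 16] + s0 + w[i - 7] + s1) & MASK)
    a, b, c, d, e, f, g, h = state
    for i in range(64):
        t1 = (h + (_rotr(e, 6) ^ _rotr(e, 11) ^ _rotr(e, 25))
              + ((e & f) ^ (~e & g)) + K[i] + w[i]) & MASK
        t2 = ((_rotr(a, 2) ^ _rotr(a, 13) ^ _rotr(a, 22))
              + ((a & b) ^ (a & c) ^ (b & c))) & MASK
        a, b, c, d, e, f, g, h = (t1 + t2) & MASK, a, b, c, (d + t1) & MASK, e, f, g
    return [(x + y) & MASK for x, y in zip(state, (a, b, c, d, e, f, g, h))]


def sha256_glue(msg_len):
    """Padding SHA-256 appends to a msg_len-byte message."""
    return b'\x80' + b'\x00' * ((55 - msg_len) % 64) + (8 * msg_len).to_bytes(8, 'big')


def sha256_hex(data):
    """Plain SHA-256 built from the pieces above (used only as a self-check)."""
    msg = data + sha256_glue(len(data))
    state = H0
    for i in range(0, len(msg), 64):
        state = _compress(state, msg[i:i + 64])
    return struct.pack('>8I', *state).hex()


def sha256_extend(digest_hex, secret_len, append):
    """Given H(secret) and len(secret), return (glue, H(secret || glue || append))."""
    state = list(struct.unpack('>8I', bytes.fromhex(digest_hex)))  # resume from the digest
    glue = sha256_glue(secret_len)
    done = secret_len + len(glue)  # bytes already absorbed into `state`
    tail = append + sha256_glue(done + len(append))
    for i in range(0, len(tail), 64):
        state = _compress(state, tail[i:i + 64])
    return glue, struct.pack('>8I', *state).hex()


secret = os.urandom(32)  # stands in for the unknown 32-byte counter
glue, forged = sha256_extend(hashlib.sha256(secret).hexdigest(), 32, b'A')
print(forged == hashlib.sha256(secret + glue + b'A').hexdigest())  # True
```

The `glue` bytes are exactly the padding from the rules above, which is why the forged digest matches `hashlib.sha256(secret + glue + b'A')` and not `hashlib.sha256(secret + b'A')`.
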
### Putting it all together

The strategy is to:

1. Ask the program to encrypt our own message.
2. Compute message XOR ciphertext to recover the hash of `counter` that was used as the keystream for that encryption.
3. Ask the program to encrypt N blocks, where N is the offset we decided on above.
4. Ask the program to encrypt the flag.
5. Run the hash through the hash extension attack to determine the hash used when the flag was encrypted.
6. Compute ciphertext XOR hash<sub>new</sub> to recover the flag.

I want to recover the entire 32-byte SHA256 hash used to encrypt our message, so the first message I send should be 32 bytes long. Here is the script to compute these values:

```py
from pwn import *
import binascii
import hlextend  # https://github.com/stephenbradshaw/hlextend

SHA256_SIZE = 32
sha = hlextend.new('sha256')

def xor(sa: bytes, sb: bytes):
    return bytes([a ^ b for a, b in zip(sa, sb)])

# XOR c with p one 32-byte keystream block at a time
def decrypt(c: bytes, p: bytes) -> bytes:
    return b''.join(xor(c[i:i+SHA256_SIZE], p[i:i+SHA256_SIZE]) for i in range(0, len(c), SHA256_SIZE))

# counter || \x80 || \x00... || 256-bit length || 'A', minus 1 for the block our own message uses
offset = (65 << (8*64)) + (1 << (8*62)) + (128 << (8*32)) - 1

#conn = process(['python3', 'sha256ctr.py'])
conn = remote('shactr.sdc.tf', 1337)
print(conn.recvline().decode('utf-8'))

# Option 2: encrypt 32 zero bytes, so the ciphertext is the keystream block h1 itself
conn.recvuntil(b'> ')
conn.send(b'2\n')
conn.recvuntil(b': ')
p1 = b'\x00'*32
conn.send(binascii.hexlify(p1) + b'\n')
conn.recvline()
c1 = conn.recvline().decode('utf-8').strip().split()[-1]
h1 = decrypt(binascii.unhexlify(c1.encode()), p1)
print(c1)
print(p1)
print(h1)

# Option 3: jump the counter forward by offset
conn.recvuntil(b'> ')
conn.send(b'3\n')
conn.send(str(offset).encode() + b'\n')

# Option 1: encrypt the flag
conn.recvuntil(b'> ')
conn.send(b'1\n')
conn.recvline()
c2 = conn.recvline().decode('utf-8').strip().split()[-1]
print(c2)

# Predict H(counter || padding || 'A') from h1 and decrypt the first flag block with it
sha.extend(b'A', b'', 32, h1.hex())
h2_1 = sha.hexdigest()
p2 = decrypt(binascii.unhexlify(c2.encode()), binascii.unhexlify(h2_1.encode()))
print(p2)

print('\n\n')
print('c1: ' + c1)
print('p1: ' + p1.hex())
print('h1: ' + h1.hex())
print('c2: ' + c2)
print('h2_1: ' + h2_1)
print('p2: ' + p2.hex())
```
For p2 (the decrypted flag), I get `sdctf{l3ngth-ext3nsion-@tt4ck-br`. Huh. It looks like only half of the flag...

Recall that the original `encrypt()` function calls `next_key_block()` multiple times if the data it encrypts is longer than 32 (`SHA256_SIZE`) bytes, and that the flag is around 63 bytes long. Each time `next_key_block()` is called, `counter` is hashed and then incremented by 1. That means that for a roughly 63-byte flag, we'd need to know 2 consecutive hashes, right? How is that possible?

While I can't turn "add 1" into a concatenation (adding 1 changes the least significant byte, which comes *first* in little endian), what I can do is perform *another* hash extension, chained on top of the first one, and line up the flag's second block with that new, predictable counter value. I have to be careful about padding, though.

#### Padding again

After being padded once, the counter is 65 bytes long, and its most significant (last) byte is `\x41`. To pad it again, I need to add `\x80`, a run of `\x00` bytes, the length bytes (65 bytes = 520 bits = `0x208`), and the next payload `\x41`. Remember how big endian works: the length bytes are `\x00 \x00 \x00 \x00 \x00 \x00 \x02 \x08`. The offset for this padding is:

```python
>>> num = (65 << (8*128)) + (128 << (8*65)) + (2 << (8*126)) + (8 << (8*127)) - 3
>>> num
11690628653775566931140821074786127658018676615576823363144820077290791529673866268396895331048514613236261019035111887162359917111396393726135887697877459243662266292806243141031934463199638093998917113786520539425939052918774668156355246079619004749753897122338654057840534986808814722296515064770808741101565
```

I'll use the hash I calculated from the first extension as the starting hash for this extension. I need to subtract 3 because encrypting the flag calls `next_key_block()` twice, which leaves the counter at `counter+2`.
If I added `num` without the -3, the first part of the flag would be encrypted with `h(newcounter+2)` instead of the `h(newcounter)` computed by hash extension, where `newcounter` is the doubly padded counter value. Since the flag has 2 parts, I want the second part to be XORed with `h(newcounter)`, so the first part needs to be XORed with `h(newcounter-1)`. That means undoing the 2 increments from the previous flag encryption plus 1 more, which is why I subtract 3.

Small note: when I call `decrypt()` on the second part of the flag, the hash bytes need to be 64 bytes long to reach the latter 32 bytes of the ciphertext, so I fill the first 32 bytes of the hash with anything (zeros here) and then append the 32-byte hash from the length extension. After decrypting, I keep only the latter 32 bytes of the result, which are the latter half of the flag.

We can then put everything together in the final [solve script](#solve-script) and get the flag.

## Flag:

`sdctf{l3ngth-ext3nsion-@tt4ck-br3aks-ps3ud0R4nd0mn3ss-of-SHA2}`

## Solve script:

```python
from pwn import *
import binascii
import hlextend  # https://github.com/stephenbradshaw/hlextend

SHA256_SIZE = 32
sha = hlextend.new('sha256')

def xor(sa: bytes, sb: bytes):
    return bytes([a ^ b for a, b in zip(sa, sb)])

# XOR c with p one 32-byte keystream block at a time
def decrypt(c: bytes, p: bytes) -> bytes:
    return b''.join(xor(c[i:i+SHA256_SIZE], p[i:i+SHA256_SIZE]) for i in range(0, len(c), SHA256_SIZE))

# First extension: 32-byte counter -> 65 bytes; -1 for the block our own message uses
offset = (65 << (8*64)) + (1 << (8*62)) + (128 << (8*32)) - 1
# Second extension: 65-byte counter -> 129 bytes; -3 = 2 flag blocks already used + 1 so block 2 hits it
offset2 = (65 << (8*128)) + (128 << (8*65)) + (2 << (8*126)) + (8 << (8*127)) - 3

#conn = process(['python3', 'sha256ctr.py'])
conn = remote('shactr.sdc.tf', 1337)
print(conn.recvline().decode('utf-8'))

# Option 2: encrypt 32 zero bytes, so the ciphertext is the keystream block h1 itself
conn.recvuntil(b'> ')
conn.send(b'2\n')
conn.recvuntil(b': ')
p1 = b'\x00'*32
conn.send(binascii.hexlify(p1) + b'\n')
conn.recvline()
c1 = conn.recvline().decode('utf-8').strip().split()[-1]
h1 = decrypt(binascii.unhexlify(c1.encode()), p1)
print(c1)
print(p1)
print(h1)

# Option 3, then option 1: the flag's first block is hashed from counter || padding || 'A'
conn.recvuntil(b'> ')
conn.send(b'3\n')
conn.send(str(offset).encode() + b'\n')
conn.recvuntil(b'> ')
conn.send(b'1\n')
conn.recvline()
c2 = conn.recvline().decode('utf-8').strip().split()[-1]
print(c2)

sha.extend(b'A', b'', 32, h1.hex())
h2_1 = sha.hexdigest()
p2 = decrypt(binascii.unhexlify(c2.encode()), binascii.unhexlify(h2_1.encode()))

# Option 3 again, then option 1: now the flag's second block lands on the doubly extended counter
conn.recvuntil(b'> ')
conn.send(b'3\n')
conn.send(str(offset2).encode() + b'\n')
conn.recvuntil(b'> ')
conn.send(b'1\n')
conn.recvline()
c3 = conn.recvline().decode('utf-8').strip().split()[-1]
print(c3)

# Extend the 65-byte padded counter; pad the keystream with 32 dummy bytes and keep the second half
sha.extend(b'A', b'', 65, h2_1)
h2_2 = sha.hexdigest()
p3 = decrypt(binascii.unhexlify(c3.encode()), b'\x00'*32 + binascii.unhexlify(h2_2.encode()))[32:]
print(p2)

print('\n\n')
print('c1: ' + c1)
print('p1: ' + p1.hex())
print('h1: ' + h1.hex())
print('c2: ' + c2)
print('h2_1: ' + h2_1)
print('p2: ' + p2.hex())
print('c3: ' + c3)
print('h2_2: ' + h2_2)
print('p3: ' + p3.hex())
print('FLAG: ' + (p2 + p3).decode('utf-8'))
conn.interactive()
```
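
Because the -1/-3 bookkeeping is easy to get wrong, here is a minimal offline replay of the solve script's counter arithmetic, with no server and no hashing. It assumes, as the writeup does, a full 32-byte starting counter, a minimal-length little-endian encoding, and a flag that spans exactly two keystream blocks; `encode` is a hypothetical helper.

```python
import secrets


def encode(counter: int) -> bytes:
    # Assumption: minimal-length little-endian encoding of the counter
    return counter.to_bytes((counter.bit_length() + 7) // 8, 'little')


offset = (65 << (8*64)) + (1 << (8*62)) + (128 << (8*32)) - 1
offset2 = (65 << (8*128)) + (128 << (8*65)) + (2 << (8*126)) + (8 << (8*127)) - 3

counter = secrets.randbits(256) | (1 << 255)  # full 32-byte starting counter
h1_input = encode(counter)
counter += 1                                   # option 2: 32 zero bytes use one block

counter += offset                              # option 3
flag1_first = encode(counter)                  # input hashed for the flag's first block
counter += 2                                   # option 1: the flag uses two blocks

counter += offset2                             # option 3 again
flag2_second = encode(counter + 1)             # input hashed for the flag's second block
counter += 2                                   # option 1 again

glue1 = b'\x80' + b'\x00' * 23 + (32 * 8).to_bytes(8, 'big')
glue2 = b'\x80' + b'\x00' * 54 + (65 * 8).to_bytes(8, 'big')
print(flag1_first == h1_input + glue1 + b'A')      # True: what sha.extend(b'A', b'', 32, ...) assumes
print(flag2_second == flag1_first + glue2 + b'A')  # True: what sha.extend(b'A', b'', 65, ...) assumes
```
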