GlacierCTF 2024 - AES Overdrive

I had the pleasure of providing a challenge in the crypto category for the third iteration of GlacierCTF (GlacierCTF 2024). It is a symmetric encryption challenge split into an easier first stage and a more challenging second stage. The second stage revolves around symmetric cryptanalysis to recover a round key of AES.

Challenge overview

Challenge

The challenge provides two files:

  1. challenge.py (where the main challenge is located)
  2. aes.py (a Python implementation of the individual AES operations, taken from boppreh's AES implementation)

All files can be downloaded here: aes_overdrive.tar.gz

Guided Overview

You are given a custom AES encryption in which the individual operations (SubBytes, ShiftRows, etc.) are unchanged:

def encrypt_block(pt: bytes, key: bytes, rounds: int) -> bytes:
    if len(pt) != 16 or len(key) != 16:
        raise ValueError("Invalid input length")

    subkeys = expand_key(key, rounds)
    assert len(subkeys) == rounds + 1

    block = aes.bytes2matrix(pt)
    aes.add_round_key(block, subkeys[0])

    for i in range(1, rounds+1):
        aes.sub_bytes(block)
        aes.shift_rows(block)
        aes.mix_columns(block)
        aes.add_round_key(block, subkeys[i])

    return aes.matrix2bytes(block)

There are two critical differences from standard AES-128:

  1. The number of rounds is either 22 for normal users or 24 for premium users (explained below), instead of the usual 10.
  2. The last round of AES also performs the MixColumns operation, which means that all rounds of AES are identical.

The mode of operation is ECB, with at most three blocks per encryption call. Before encryption, the message is padded to a multiple of 16 bytes, and after decryption, the message is unpadded.
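As a rough illustration of this framing, here is a minimal sketch of padding and splitting a message before ECB encryption. It assumes PKCS#7 padding (the default style of Crypto.Util.Padding.pad, which the exploit script uses) and that the three-block limit applies to the padded message. The helpers pkcs7_pad, pkcs7_unpad and to_blocks are hypothetical names, not the challenge's code.

```python
BLOCK_SIZE = 16
MAX_BLOCKS = 3  # at most three AES blocks per encryption call


def pkcs7_pad(data: bytes, block_size: int = BLOCK_SIZE) -> bytes:
    """Append n copies of the byte n (1 <= n <= block_size)."""
    n = block_size - len(data) % block_size
    return data + bytes([n]) * n


def pkcs7_unpad(data: bytes, block_size: int = BLOCK_SIZE) -> bytes:
    """Remove PKCS#7 padding, raising ValueError if it is malformed."""
    if not data or len(data) % block_size:
        raise ValueError("invalid length")
    n = data[-1]
    if not 1 <= n <= block_size or data[-n:] != bytes([n]) * n:
        raise ValueError("invalid padding")
    return data[:-n]


def to_blocks(msg: bytes) -> list[bytes]:
    """Pad msg and split it into 16-byte blocks; ECB encrypts each one independently."""
    padded = pkcs7_pad(msg)
    blocks = [padded[i:i + BLOCK_SIZE] for i in range(0, len(padded), BLOCK_SIZE)]
    if len(blocks) > MAX_BLOCKS:  # assumption: the limit is checked after padding
        raise ValueError("message too long")
    return blocks


# 33 bytes are padded to 48 bytes, i.e. exactly three blocks.
msg = b"a" * 16 + b"~" * 16 + b"a"
blocks = to_blocks(msg)
print(len(blocks))  # 3
print(pkcs7_unpad(b"".join(blocks)) == msg)  # True
```

The example message has the same shape as the one the exploit script sends: two chosen blocks plus one extra byte.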

Inside challenge.py, a key is securely generated

key = os.urandom(16)

and you are allowed to perform three different operations (all using the same key):

  1. Encryption: You can encrypt a chosen message m using AES with 22 rounds. The only restrictions are that m != "premium" and that its length |m| can't exceed 3*16 bytes (i.e. three blocks of encryption). $$ Enc22_{k}(m) $$
if pt == PREMIUM_USER.decode():
    print("Not so fast my friend.")
    return
  2. Premium encryption: You can encrypt a chosen message m using AES with 24 rounds. However, to access this feature, you need to provide a ciphertext that decrypts to "premium" under the normal 22-round encryption, i.e. the value Enc22_k(premium). The same length restriction of 3*16 bytes applies. $$ Enc24_{k}(m) $$
if decrypt_msg(ct, key, False) != PREMIUM_USER.decode():
    print("Sorry, you are not a premium user.")
    return
  3. Guess Key: To win the challenge and get the flag, you must correctly guess the key generated at the start.
if key_guess == key:
	print("Correct key!")
	with open('/app/flag.txt', 'r') as flag:
		print(f"Flag: {flag.read().strip()}")
	return

Importantly, you are only allowed to perform four actions in total, and one of them must be the key guess.
If you have not solved the challenge yet, now is a great time to think about how you could approach this problem. All the building blocks you need have now been discussed.

Solution

In the following section, we will look at the intended solution. Judging from the discussion on Discord and other online forums, there were no unintended ways of solving this challenge.

To solve the challenge, we need to recover the key used for the AES encryption. Here, it is essential to know that the AES-128 key schedule is invertible: if we can recover any round key, we can calculate the master key that was used. We will exploit the fact that we have access to both 22- and 24-round AES encryption under the same key. The first 22 rounds of the 24-round encryption, including their round keys, are exactly the 22-round encryption. Importantly, the last round is an ordinary round like all the others. With some cryptanalysis, this lets us recover the secret key.
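The invertibility claim can be made concrete with a self-contained sketch in plain Python, independent of the challenge's aes.py. reverse_step undoes one step of the standard AES-128 key schedule. The hypothetical master_from_round_key follows the same idea as get_master_key in the exploit script. It assumes, like the challenge's expand_key, that the schedule restarts from the last generated round key every 10 rounds. The exploit itself uses reverse_key_schedule from the aeskeyschedule package for this.

```python
import os


def xtime(a: int) -> int:
    """Multiply by 2 in GF(2^8) modulo the AES polynomial x^8 + x^4 + x^3 + x + 1."""
    a <<= 1
    return a ^ 0x11B if a & 0x100 else a


def build_sbox() -> list[int]:
    """AES S-box: multiplicative inverse in GF(2^8), then the affine map."""
    exp, log = [0] * 255, [0] * 256
    x = 1
    for i in range(255):  # 3 generates the multiplicative group
        exp[i], log[x] = x, i
        x ^= xtime(x)  # x <- 3 * x
    sbox = []
    for a in range(256):
        inv = 0 if a == 0 else exp[-log[a] % 255]
        s = inv
        for r in range(1, 5):
            s ^= ((inv << r) | (inv >> (8 - r))) & 0xFF
        sbox.append(s ^ 0x63)
    return sbox


SBOX = build_sbox()
RCON = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36]


def core(word: list[int], r: int) -> list[int]:
    """RotWord, SubWord and the round constant used for round key r."""
    t = [SBOX[b] for b in word[1:] + word[:1]]
    t[0] ^= RCON[r - 1]
    return t


def expand_128(key: bytes) -> list[bytes]:
    """Standard AES-128 key expansion: 11 round keys of 16 bytes each."""
    w = [list(key[i:i + 4]) for i in range(0, 16, 4)]
    for i in range(4, 44):
        t = core(w[i - 1], i // 4) if i % 4 == 0 else w[i - 1]
        w.append([a ^ b for a, b in zip(w[i - 4], t)])
    return [bytes(sum(w[4 * r:4 * r + 4], [])) for r in range(11)]


def schedule(key: bytes, rounds: int) -> list[bytes]:
    """Round keys 0..rounds, restarting every 10 rounds (assumed to match expand_key)."""
    keys = [key]
    while len(keys) < rounds + 1:
        keys += expand_128(keys[-1])[1:]
    return keys[:rounds + 1]


def reverse_step(rk: bytes, r: int) -> bytes:
    """Given round key r (1..10) of one AES-128 schedule, return round key r - 1."""
    w = [list(rk[i:i + 4]) for i in range(0, 16, 4)]
    prev = [None, None, None, None]
    for j in (1, 2, 3):  # w_prev[j] = w[j] ^ w[j - 1]
        prev[j] = [a ^ b for a, b in zip(w[j], w[j - 1])]
    prev[0] = [a ^ b for a, b in zip(w[0], core(prev[3], r))]
    return bytes(sum(prev, []))


def master_from_round_key(rk: bytes, index: int) -> bytes:
    """Walk back from round key `index` to the master key (same idea as get_master_key)."""
    k = rk
    for r in range(index % 10, 0, -1):  # back to the start of the current restart
        k = reverse_step(k, r)
    for _ in range(index // 10):  # then through each full 10-round schedule
        for r in range(10, 0, -1):
            k = reverse_step(k, r)
    return k


fips_key = bytes.fromhex("2b7e151628aed2a6abf7158809cf4f3c")  # FIPS-197 example key
print(master_from_round_key(expand_128(fips_key)[10], 10) == fips_key)  # True

key = os.urandom(16)
print(master_from_round_key(schedule(key, 24)[23], 23) == key)  # True
```

Each backward step only needs the round key itself and its round number. That is why a single recovered round key is as good as the master key.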

Getting the premium encryption

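Unlocking the premium option is trivial with AES in ECB mode, because the check is brittle. The normal encryption only rejects the exact string "premium", while the premium check compares against the unpadded decryption. The easiest way to circumvent it is to provide a padded version of the premium user string.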

from Crypto.Util.Padding import pad
pad(b"premium", 16)  # b'premium\t\t\t\t\t\t\t\t\t'

We use this string as the input to the normal encryption. It passes the check because it does not exactly match b"premium". It will still authenticate us as a premium user, because the padding is removed during decryption. The first ciphertext block t is our premium code (the nine padding bytes are 0x09, i.e. \t): $$ Enc22_{k}(\texttt{premium} \,\|\, \texttt{0x09}^{9}) = t $$
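Because the trick only concerns the plaintext side, it can be checked without any AES at all. Below is a minimal sketch that assumes PKCS#7 padding on both ends. pkcs7_pad and pkcs7_unpad are hypothetical stand-ins for the library functions, and passes_normal_check mirrors the option 1 check shown earlier.

```python
PREMIUM_USER = b"premium"


def pkcs7_pad(data: bytes, block_size: int = 16) -> bytes:
    n = block_size - len(data) % block_size
    return data + bytes([n]) * n


def pkcs7_unpad(data: bytes) -> bytes:
    n = data[-1]
    if not 1 <= n <= 16 or data[-n:] != bytes([n]) * n:
        raise ValueError("invalid padding")
    return data[:-n]


def passes_normal_check(pt: bytes) -> bool:
    # Option 1 only refuses the exact string "premium".
    return pt != PREMIUM_USER


submitted = pkcs7_pad(PREMIUM_USER)  # b"premium" followed by nine b"\t" (0x09)
print(passes_normal_check(submitted))  # True

# The server pads the 16-byte input again, adding a full block of 0x10 bytes.
# Under ECB, the first ciphertext block t encrypts exactly `submitted`.
first_block = pkcs7_pad(submitted)[:16]

# The premium check decrypts t with 22 rounds (giving first_block) and unpads it:
print(pkcs7_unpad(first_block) == PREMIUM_USER)  # True
```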

Recovering a round key

Now that we have unlocked the premium feature, we can use the 24-round AES encryption. This is where we exploit the fact that the last round of this AES implementation is no different from all the other rounds.
The key insight is that the output of the 22-round encryption becomes the intermediate value before the last two rounds of the 24-round encryption.
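This can be checked with a toy reimplementation of the modified cipher in plain Python. It is a sketch, not the challenge code. The helper names are illustrative, and the state is a flat 16-byte string in column-major order. The key schedule is assumed to restart every 10 rounds, exactly like the expand_key function in the exploit script. The two final checks print True: the 22-round keys are a prefix of the 24-round keys, and the 24-round ciphertext is the 22-round ciphertext pushed through two more ordinary rounds.

```python
import os


def xtime(a: int) -> int:
    """Multiply by 2 in GF(2^8) modulo the AES polynomial."""
    a <<= 1
    return a ^ 0x11B if a & 0x100 else a


def build_sbox() -> list[int]:
    """AES S-box: multiplicative inverse in GF(2^8), then the affine map."""
    exp, log = [0] * 255, [0] * 256
    x = 1
    for i in range(255):  # 3 generates the multiplicative group
        exp[i], log[x] = x, i
        x ^= xtime(x)  # x <- 3 * x
    sbox = []
    for a in range(256):
        inv = 0 if a == 0 else exp[-log[a] % 255]
        s = inv
        for r in range(1, 5):
            s ^= ((inv << r) | (inv >> (8 - r))) & 0xFF
        sbox.append(s ^ 0x63)
    return sbox


SBOX = build_sbox()
RCON = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36]


def expand_128(key: bytes) -> list[bytes]:
    """Standard AES-128 key expansion: 11 round keys."""
    w = [list(key[i:i + 4]) for i in range(0, 16, 4)]
    for i in range(4, 44):
        t = list(w[i - 1])
        if i % 4 == 0:
            t = [SBOX[b] for b in t[1:] + t[:1]]  # RotWord + SubWord
            t[0] ^= RCON[i // 4 - 1]
        w.append([a ^ b for a, b in zip(w[i - 4], t)])
    return [bytes(sum(w[4 * r:4 * r + 4], [])) for r in range(11)]


def schedule(key: bytes, rounds: int) -> list[bytes]:
    """Round keys 0..rounds; restarts every 10 rounds like the exploit's expand_key."""
    keys = [key]
    while len(keys) < rounds + 1:
        keys += expand_128(keys[-1])[1:]
    return keys[:rounds + 1]


def shift_rows(s: bytes) -> bytes:
    # Column-major state: byte (row r, col c) is s[4*c + r]; row r rotates left by r.
    return bytes(s[4 * ((c + r) % 4) + r] for c in range(4) for r in range(4))


def mix_columns(s: bytes) -> bytes:
    out = bytearray()
    for c in range(4):
        a = s[4 * c:4 * c + 4]
        d = [xtime(v) for v in a]  # 2*a[i]; 3*a[i] = d[i] ^ a[i]
        out += bytes([d[0] ^ d[1] ^ a[1] ^ a[2] ^ a[3],
                      a[0] ^ d[1] ^ d[2] ^ a[2] ^ a[3],
                      a[0] ^ a[1] ^ d[2] ^ d[3] ^ a[3],
                      d[0] ^ a[0] ^ a[1] ^ a[2] ^ d[3]])
    return bytes(out)


def aes_round(s: bytes, rk: bytes) -> bytes:
    """SubBytes, ShiftRows, MixColumns, AddRoundKey: the same for every round here."""
    s = mix_columns(shift_rows(bytes(SBOX[b] for b in s)))
    return bytes(a ^ k for a, k in zip(s, rk))


def encrypt_block(pt: bytes, key: bytes, rounds: int) -> bytes:
    rks = schedule(key, rounds)
    s = bytes(p ^ k for p, k in zip(pt, rks[0]))
    for i in range(1, rounds + 1):
        s = aes_round(s, rks[i])
    return s


key, pt = os.urandom(16), os.urandom(16)
rks24 = schedule(key, 24)
ct22 = encrypt_block(pt, key, 22)
ct24 = encrypt_block(pt, key, 24)
print(schedule(key, 22) == rks24[:23])  # True: the 22-round keys are a prefix
print(ct24 == aes_round(aes_round(ct22, rks24[23]), rks24[24]))  # True
```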

Exploit overview

Given the 22-round output $$Enc22_{k}(pt)=ct,$$ we can compute the keyless operations of the next AES round ourselves by applying:

  1. SubBytes
  2. ShiftRows
  3. MixColumns

This leaves us with only one keyed round of AES: the addition of round key 23, followed by round 24.
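Written out, the remaining problem is small. In the following derivation, $k_i$ denotes the round key added in round $i$ (subkeys[i] in encrypt_block), and SB, SR and MC are SubBytes, ShiftRows and MixColumns. The names $x_i$, $y_i$ and $\tilde{k}$ are introduced here for illustration:

$$
\begin{aligned}
x_i &= MC(SR(SB(ct_i))), \\
ct'_i &= MC\big(SR\big(SB(x_i \oplus k_{23})\big)\big) \oplus k_{24}, \\
y_i &:= SR^{-1}\big(MC^{-1}(ct'_i)\big) = SB(x_i \oplus k_{23}) \oplus \tilde{k}, \qquad \tilde{k} = SR^{-1}\big(MC^{-1}(k_{24})\big).
\end{aligned}
$$

The last line uses the fact that MixColumns and ShiftRows are linear over GF(2), so $k_{24}$ passes through their inverses as a fixed mask $\tilde{k}$. Since SubBytes acts on each byte separately, every byte position $j$ gives an equation. When two pairs are combined, $\tilde{k}$ cancels:

$$ y_1[j] \oplus S(x_1[j] \oplus k_{23}[j]) = \tilde{k}[j] = y_2[j] \oplus S(x_2[j] \oplus k_{23}[j]) $$

In the exploit script, $x_i$ corresponds to new_pt1/new_pt2, $y_i$ to ct1_prime/ct2_prime, and the guessed byte of $k_{23}$ to sub_k_24_guess.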

Now we are left with the easier problem of recovering the key for one full round of AES, which we can do with a known-plaintext attack on 1-round AES. This attack needs two known plaintexts. Since each call can encrypt up to three blocks, one call per cipher is enough. In practice, this means that we use two more of our four operations to:

  1. Encrypt two chosen blocks pt = [pt_1, pt_2] with the 22-round function: $$Enc22_{k}(pt)=[ct_1, ct_2]$$

  2. Encrypt the same two blocks with the 24-round function: $$Enc24_{k}(pt)=[ct'_1, ct'_2]$$

The attack requires two known plaintext/ciphertext pairs and has a time complexity of 2^16 operations. It is taken from "Low Data Complexity Attacks on AES" by Bouillaguet et al. We can recover the key with the following steps:

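  1. We work backward from the 24-round ciphertexts by applying inverse MixColumns and then inverse ShiftRows, SR^(-1)(MC^(-1)(ct'_i)), until we reach the output of the SubBytes operation.
  2. For each byte position, we guess the round-key byte added before the last round. We then check whether this guess is even possible given the S-box operation: both pairs must yield the same value for the corresponding byte of the (transformed) last round key: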
for byte in range(key_size):
    for sub_k_24_guess in range(256):
        sub_k_25_guess_1 = ct1_prime[byte] ^ aes.s_box[new_pt1[byte] ^ sub_k_24_guess]
        sub_k_25_guess_2 = ct2_prime[byte] ^ aes.s_box[new_pt2[byte] ^ sub_k_24_guess]
        if sub_k_25_guess_1 == sub_k_25_guess_2:
            possible_key24[byte].append(sub_k_24_guess)
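The filter can be sanity-checked offline. The sketch below is a simulation, not the exploit. It draws random stand-ins for $x_i$ (the transformed 22-round ciphertexts, new_pt1/new_pt2 above), for the round key $k_{23}$ and for the transformed last round key $\tilde{k} = SR^{-1}(MC^{-1}(k_{24}))$. It then builds $y_i$ (ct1_prime/ct2_prime) from the bytewise relation $y_i[j] = S(x_i[j] \oplus k_{23}[j]) \oplus \tilde{k}[j]$ that remains after the inverse linear layers, and runs the same test. All names here are illustrative.

```python
import math
import random


def xtime(a: int) -> int:
    """Multiply by 2 in GF(2^8) modulo the AES polynomial."""
    a <<= 1
    return a ^ 0x11B if a & 0x100 else a


def build_sbox() -> list[int]:
    """AES S-box: multiplicative inverse in GF(2^8), then the affine map."""
    exp, log = [0] * 255, [0] * 256
    x = 1
    for i in range(255):  # 3 generates the multiplicative group
        exp[i], log[x] = x, i
        x ^= xtime(x)  # x <- 3 * x
    sbox = []
    for a in range(256):
        inv = 0 if a == 0 else exp[-log[a] % 255]
        s = inv
        for r in range(1, 5):
            s ^= ((inv << r) | (inv >> (8 - r))) & 0xFF
        sbox.append(s ^ 0x63)
    return sbox


SBOX = build_sbox()


def filter_key_bytes(x1: bytes, x2: bytes, y1: bytes, y2: bytes) -> list[list[int]]:
    """Per byte j, keep every guess g with y1[j] ^ S(x1[j] ^ g) == y2[j] ^ S(x2[j] ^ g)."""
    return [[g for g in range(256) if y1[j] ^ SBOX[x1[j] ^ g] == y2[j] ^ SBOX[x2[j] ^ g]]
            for j in range(16)]


# Simulated values only: random stand-ins for x_i, k_23 and the masked last round key.
rng = random.Random(2024)
x1, x2, k23, k_tilde = [bytes(rng.randrange(256) for _ in range(16)) for _ in range(4)]


def peeled_output(x: bytes) -> bytes:
    """y[j] = S(x[j] ^ k23[j]) ^ k_tilde[j]: the 24-round output after MC^-1 and SR^-1."""
    return bytes(SBOX[a ^ k] ^ t for a, k, t in zip(x, k23, k_tilde))


y1, y2 = peeled_output(x1), peeled_output(x2)
candidates = filter_key_bytes(x1, x2, y1, y2)
print([len(c) for c in candidates])  # per-byte candidate counts
print(math.prod(len(c) for c in candidates))  # number of full round-key candidates
```

Feeding in the real new_pt1/new_pt2 and ct1_prime/ct2_prime values from the server instead of the simulated ones performs the same computation as the loop above.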

Once we have all candidate values for each byte of the round key, we combine them into full round-key candidates and reverse the key schedule to obtain a master-key candidate. We then verify each guess by trial encryption:

$$ Enc22_{k_{\text{guess}}}(pt) \neq [ct_1, ct_2] \;\Longrightarrow\; k_{\text{guess}} \text{ is incorrect} $$

The correct key produces the same ciphertexts as the server. This approach leaves roughly 2^16 candidate keys to test. There is also a more efficient version of this attack with a complexity of 2^12.
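The 2^16 figure follows from a symmetry of the per-byte test. The following names are introduced here:

- $x_i[j]$: the $j$-th byte of the 22-round ciphertext after SubBytes, ShiftRows and MixColumns.
- $y_i[j]$: the $j$-th byte of the 24-round ciphertext after inverse MixColumns and inverse ShiftRows.
- $\Delta_j = x_1[j] \oplus x_2[j]$.

A guess $g$ survives exactly when the left equation holds. Replacing $g$ by $g \oplus \Delta_j$ only swaps the two S-box inputs, since $x_1[j] \oplus g \oplus \Delta_j = x_2[j] \oplus g$ and vice versa:

$$
S(x_1[j] \oplus g) \oplus S(x_2[j] \oplus g) = y_1[j] \oplus y_2[j]
\;\Longleftrightarrow\;
S\big(x_1[j] \oplus g \oplus \Delta_j\big) \oplus S\big(x_2[j] \oplus g \oplus \Delta_j\big) = y_1[j] \oplus y_2[j]
$$

So whenever $\Delta_j \neq 0$, the correct byte always survives together with a twin. Since the AES S-box has differential uniformity 4, each position keeps $n_j \in \{2, 4\}$ candidates, usually 2:

$$ N = \prod_{j=0}^{15} n_j, \qquad N = 2^{16} \text{ if every } n_j = 2. $$

If some $\Delta_j = 0$, that byte position carries no information and all 256 guesses survive.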

Only the correct key (with overwhelming probability) reproduces the server's ciphertexts, so submitting it as our key guess solves the challenge.

Here is my exploit script:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# This exploit template was generated via:
# $ pwn template --host localhost --port 1337
from pwn import *
import re
import aes
from aeskeyschedule import reverse_key_schedule, key_schedule
import itertools

# Set up pwntools for the correct architecture
# Just set TERM_PROGRAM in your ~/.profile!
# context.update(terminal='CHANGEME')
exe = (args.EXE or 'challenge')
host = args.HOST or 'localhost'
port = int(args.PORT or 1337)

# Find flag by exact match or format
# log.success(find_flag(io.recvall()))
def find_flag(output):
    if not isinstance(output, str):
        output = output.decode(errors="ignore")
    # Match real flag
    if real_flag in output:
        return real_flag
    # Match fake flag
    if fake_flag in output:
        return fake_flag
    # Match possible local flag
    with open("/flag.txt", "r") as local:
        locl_flag = local.readline().strip()
        if locl_flag in output:
            return locl_flag
    # Match regexp flag
    r = find_flag_fmt(output)
    if r is not None:
        return r
    # Definitely no flag found
    return None


# Find flag by format
# log.success(find_flag_fmt(io.recvall()))
ffmt = re.compile(r"gctf{.*}")
def find_flag_fmt(output):
    if not isinstance(output, str):
        output = output.decode(errors="ignore")
    m = ffmt.search(output)
    if m is None:
        return None
    return m.group(0)

def start_local(argv=[], *a, **kw):
    '''Execute the target binary locally'''
    return process([exe] + argv, *a, **kw)

def start_remote(argv=[], *a, **kw):
    '''Connect to the process on the remote host'''
    io = connect(host, port)
    return io

def start(argv=[], *a, **kw):
    '''Start the exploit against the target.'''
    if args.LOCAL:
        return start_local(argv, *a, **kw)
    else:
        return start_remote(argv, *a, **kw)


def expand_key(key: bytes, rounds: int) -> list[list[int]]:
    round_keys = [[key[i : i + 4] for i in range(0, len(key), 4)]]
    while len(round_keys) < rounds + 1:
        base_key = b"".join(round_keys[-1])
        round_keys += aes.expand_key(base_key, 10)[1:]
    round_keys = [b"".join(k) for k in round_keys]
    round_keys = [aes.bytes2matrix(k) for k in round_keys]
    return round_keys[:rounds + 1]


def encrypt_block(pt: bytes, key: bytes, rounds: int) -> bytes:
    if len(pt) != 16 or len(key) != 16:
        exit(1)
    subkeys = expand_key(key, rounds)
    block = aes.bytes2matrix(pt)
    aes.add_round_key(block, subkeys[0])
    for i in range(1, rounds+1):
        aes.sub_bytes(block)
        aes.shift_rows(block)
        aes.mix_columns(block)
        aes.add_round_key(block, subkeys[i])

    return aes.matrix2bytes(block)

PREMIUM_USER = b"premium"
#===========================================================
#                	EXPLOIT GOES HERE
#===========================================================
from Crypto.Util.Padding import pad

def send_encrypted(io, option, pt, code:bytes=None):
    io.sendlineafter(b"Enter option (1: Encrypt, 2: Premium Encrypt, 3: Guess Key):", option)
    if option == "2":
        print(code)
        io.sendlineafter(b"Enter ciphertext:", code)
    io.sendlineafter(b"Enter plaintext: ", pt)
    ct = bytes.fromhex(io.recvline().strip().decode()[11:])
    ct1, ct2 = ct[:16], ct[16:32]
    return ct1, ct2

print("Exploiting...")
io = start()

NORMAL_ROUND = 22
PREMIUM_ROUNDS = 24
key_size = 16

# Get the premium code: the first ciphertext block of the padded "premium" string
code, _ = send_encrypted(io, "1", pad(PREMIUM_USER, 16)+b"a")
print(code)
assert len(code) == 16
code = code.hex()

pt1 = b"a"*16
pt2 = b"~"*16
new_pt1, new_pt2 = send_encrypted(io, "1", pt1+pt2+b"a")
new_ct1, new_ct2 = send_encrypted(io, "2", pt1+pt2+b"a", code)


# Solve:
new_pt1 = aes.bytes2matrix(new_pt1)
aes.sub_bytes(new_pt1)
aes.shift_rows(new_pt1)
aes.mix_columns(new_pt1)
new_pt1 = aes.matrix2bytes(new_pt1)

new_pt2 = aes.bytes2matrix(new_pt2)
aes.sub_bytes(new_pt2)
aes.shift_rows(new_pt2)
aes.mix_columns(new_pt2)
new_pt2 = aes.matrix2bytes(new_pt2)

ct1_prime = aes.bytes2matrix(new_ct1)
aes.inv_mix_columns(ct1_prime)
aes.inv_shift_rows(ct1_prime)
ct1_prime = aes.matrix2bytes(ct1_prime)

ct2_prime = aes.bytes2matrix(new_ct2)
aes.inv_mix_columns(ct2_prime)
aes.inv_shift_rows(ct2_prime)
ct2_prime = aes.matrix2bytes(ct2_prime)

possible_key24: list[list[int]] = [[] for _ in range(key_size)]

for byte in range(key_size):
    for sub_k_24_guess in range(256):
        sub_k_25_guess_1 = ct1_prime[byte] ^ aes.s_box[new_pt1[byte] ^ sub_k_24_guess]
        sub_k_25_guess_2 = ct2_prime[byte] ^ aes.s_box[new_pt2[byte] ^ sub_k_24_guess]
        if sub_k_25_guess_1 == sub_k_25_guess_2:
            possible_key24[byte].append(sub_k_24_guess)

print(possible_key24)



def test_key(pt1, pt2, ct1, ct2, key) -> bool:
    ct_test = encrypt_block(pt1, key, PREMIUM_ROUNDS)
    if ct_test == ct1:
        return True
    ct_test = encrypt_block(pt2, key, PREMIUM_ROUNDS)
    if ct_test == ct2:
        return True
    return False


all_possible_keys = list(itertools.product(*possible_key24))
all_possible_keys = [bytes(k) for k in all_possible_keys]
print(len(all_possible_keys))


def get_master_key(key, rounds):
    offset = rounds % 10
    k = reverse_key_schedule(key, offset)
    key_rev = rounds // 10
    for _ in range(key_rev):
        k = reverse_key_schedule(k, 10)
    return k

for k in all_possible_keys:
    # Reverse key schedule
    master_key = get_master_key(k, NORMAL_ROUND+1)
    # k = reverse_key_schedule(k, 10)
    if test_key(pt1, pt2, new_ct1, new_ct2, master_key):
        print(master_key)
        io.sendlineafter(b"Enter option", "3")
        io.sendlineafter(b"Enter key (hex): ", master_key.hex())
        data1 = io.recvall()
        print(data1)
        log.success(find_flag_fmt(data1))
        exit(0)
assert False  # Key not found