import numpy as np
from loguru import logger
from gmssl import sm3, func
from itertools import cycle
from Crypto.Util.Padding import pad
# Fixed sample inputs for this run: a 32-byte key and a 4-byte random value.
sign_key_bytes = bytes.fromhex("ac1adaae95a7af94a5114ab3b3a97dd80050aa0a39314c40528caec95256c28c")
rand_num = bytes.fromhex("7283514a")
# SM3 digest over key || random || key. It is decoded with bytes.fromhex()
# further down, so sm3_hash is expected to return a hex string here.
protobuf_mixed_bytes = sm3.sm3_hash(func.bytes_to_list(sign_key_bytes + rand_num + sign_key_bytes))
logger.debug(protobuf_mixed_bytes)
# Payload that is later passed to medusa_protobuf_mixed.
# NOTE(review): the name suggests a serialized device-info protobuf capture — confirm.
device_protobuf = bytes.fromhex("0a10366c956c35725a9be4d41d65c8b87926100418f4fdae9b0f220433303139320a313631313932313736343a0632352e302e3042147630342e30342e30352d6d6c2d616e64726f6964488094a04052080000000000000000609aefb5fa0c6a14a9918604d779bd1b908bea845e31346b136be5f07206241768832a1c7a0e080210bee15418bee15420bee154a201046e6f6e65a801e205ba010908e68de1f90c38ac71c2016a7b0a0922636d72223a0931363737373231362c0a0922636d7232223a0931363737373231362c0a0922756e5f68223a09302c0a09226b64223a093639343336372c0a0922666b64223a09313939383031383230342c0a09227064223a092d313034333039303038350a7d")
def c_bitwise_not(num):
    """Return the bitwise NOT of *num* as a signed 32-bit integer.

    The input is inverted, truncated to its low 32 bits, and then
    reinterpreted as two's complement, so the result lies in
    ``[-2**31, 2**31 - 1]``.
    """
    width = 32
    modulus = 1 << width
    sign_bit = 1 << (width - 1)
    inverted = ~num & (modulus - 1)
    # A set top bit means the 32-bit pattern encodes a negative number.
    return inverted - modulus if inverted & sign_bit else inverted
def bfxill(w21, w8):
    """Replace the low 5 bits of *w21* with bits 3..7 of *w8*.

    All higher bits of *w21* are returned unchanged.
    NOTE(review): the name presumably mirrors an ARM64 bitfield-insert
    instruction from the code this was ported from — confirm.
    """
    low5 = 0b11111
    field = (w8 >> 3) & low5
    return (w21 & ~low5) | field
def medusa_protobuf_mixed(protobuf_bytes: bytes, mix_param_bytes: bytes):
    """Scramble *protobuf_bytes* using key material from *mix_param_bytes*.

    Stage 1 turns every input byte into one output byte, combining it with
    two parameter bytes taken at offset ``(4 * i) % len(mix_param_bytes)``
    and the offset right after it.

    Stage 2 reverses that sequence and chains neighbouring elements
    together.  Element 0 of the reversed sequence only serves as chaining
    state and is dropped, so the result has ``len(protobuf_bytes) - 1``
    bytes.

    Args:
        protobuf_bytes: data to scramble; needs at least two bytes.
        mix_param_bytes: parameter bytes; must be non-empty and long
            enough that the "offset + 1" index exists for every offset.

    Returns:
        The scrambled data as ``bytes``.

    Raises:
        IndexError: if *protobuf_bytes* has fewer than two bytes or the
            parameter index runs past the end of *mix_param_bytes*.
        ZeroDivisionError: if *mix_param_bytes* is empty while
            *protobuf_bytes* is not.
    """
    result = []
    key_len = len(mix_param_bytes)
    for i, b in enumerate(protobuf_bytes):
        idx = (4 * i) % key_len
        key_a = mix_param_bytes[idx]
        key_b = mix_param_bytes[idx + 1]
        tmp = ((b >> 2) & 0xffffc03f) | (b << 6)
        tmp += key_a
        eon_val = (tmp ^ c_bitwise_not(key_b)) & 0xffffffff
        tmp = bfxill((32 * eon_val) & 0xffffffff, eon_val) + key_b
        tmp = (key_a ^ c_bitwise_not(tmp)) & 0xffffffff
        # Only the least significant byte of the 32-bit result is kept.
        result.append(tmp & 0xff)

    mixed_param = result[::-1]
    # The first two elements are seeded from the last two; they are kept
    # as 32-bit values here, exactly as before, and reduced on output.
    mixed_param[0] = (((c_bitwise_not(mixed_param[-2]) ^ mixed_param[-1]) & 0xffffffff) + mixed_param[0]) & 0xffffffff
    mixed_param[1] = ((mixed_param[0] ^ mixed_param[-1] ^ 0xfe) + mixed_param[1]) & 0xffffffff
    for i in range(2, len(mixed_param) - 1):
        prev = mixed_param[i - 1]
        # Low byte of this term equals the low byte of prev rotated left by one.
        carried = ((prev & 0x80) != 0) | (2 * prev)
        mixed_param[i] += mixed_param[i - 2] ^ carried ^ (c_bitwise_not(i) & 0xffffffff)
        mixed_param[i] &= 0xff
    mixed_param[-1] ^= mixed_param[-2]
    # Elements 1 and -1 can carry bits above the low byte (element 1 is
    # masked to 32 bits, not 8), which made bytes() raise ValueError for
    # some inputs.  Every step above only depends on low bytes, so masking
    # here leaves previously valid outputs unchanged.
    return bytes(value & 0xff for value in mixed_param[1:])
# Scramble the sample payload, using the SM3 digest (decoded from its hex
# form) as the parameter bytes, and log the result as hex.
protobuf_processed = medusa_protobuf_mixed(device_protobuf, bytes.fromhex(protobuf_mixed_bytes))
logger.debug(protobuf_processed.hex())
def get_xor_key(random_bytes: bytes):
    """Derive a 32-bit integer from the last two bytes of *random_bytes*.

    With ``a`` the second-to-last byte and ``b`` the last byte, the result
    is the 32-bit complement of ``a ^ (a >> 5) ^ ((a << 11) | b)``.

    Raises:
        IndexError: if *random_bytes* has fewer than two bytes.
    """
    second_last = random_bytes[-2]
    last = random_bytes[-1]
    combined = (second_last << 0xb) | last
    mixed = second_last ^ (second_last >> 0x5) ^ combined
    return ~mixed & 0xffffffff
# Log the value derived from rand_num for comparison with the literal below.
logger.debug(f"xor key calculated {hex(get_xor_key(rand_num))}")
# NOTE(review): for the current rand_num, get_xor_key(rand_num) evaluates to
# 0xfffd77e6, i.e. the same four bytes as this literal in big-endian order.
# Presumably the literal could be derived instead — confirm before changing.
xor_key = bytes.fromhex('fffd77e6')
pad_bytes = bytes.fromhex('00000000000000000d')
# XOR the byte-reversed (pad_bytes + scrambled payload) with the repeating key.
xor_result = bytearray(
    data_byte ^ key_byte
    for data_byte, key_byte in zip((pad_bytes + protobuf_processed)[::-1], cycle(xor_key))
)
logger.debug(f"xor result: {xor_result.hex()}")
# Fixed 9-byte prefix, assembled from three literal pieces.
prefix_bytes = bytes.fromhex('a6') + bytes.fromhex("859ef750") + bytes.fromhex("01290918")
# prefix || xor result || last two bytes of rand_num, padded to 16-byte blocks.
aes_lite_in_bytes = pad(prefix_bytes + xor_result + rand_num[2:], 16)
logger.debug(f"aes lite input: {aes_lite_in_bytes.hex()}")
# 48 bytes of key material; round_encrypt slices it into three 16-byte parts.
round_key = bytes.fromhex("ea2b045b11bf2364839e6ab27f95a9df84e705c7955826a316c64c116953e5ce62028a3df75aac9ee19ce08f88cf0541")
# Byte substitution table: round_encrypt replaces each state byte ``v`` with
# ``dfed0_table[v]``.
# NOTE(review): the list holds more than 256 entries, but the lookups in
# round_encrypt only use byte values (0-255). The trailing entries
# (0x8D, 1, 2, 4, ... 0x1B, 0x36, then zeros) therefore look unused here;
# they resemble AES round constants and were presumably dumped together
# with the table — confirm.
dfed0_table = [
0x2E, 0x5C, 0x55, 0xED, 0x1B, 0xDA, 0xA, 0x79, 0x28,
0x69, 0x57, 0xFE, 0x68, 0x3A, 0xDE, 0xAC, 0x90, 0xF9,
0xC1, 0xE1, 0xC3, 0x8B, 0x7F, 0x59, 0x26, 0xCA, 0x13,
0xBB, 0x11, 0x37, 0x39, 0x21, 0xEB, 0x9A, 0xFF, 0x5E,
0x42, 0x33, 0xBE, 0x51, 0x8D, 0x40, 0x1E, 0x91, 0xB3,
0x85, 0xB7, 0xCD, 0xDC, 0x27, 0x92, 0x83, 0x87, 0x3F,
0xE6, 0x4A, 0x64, 0x56, 0x8C, 0xA1, 0x76, 0xD2, 0xFD,
0xC0, 0x63, 0x18, 0x44, 0x1A, 0x9F, 0x61, 0xCB, 0x6E,
0x67, 0x29, 0xAF, 0xB8, 0x54, 0x60, 0xDB, 0x97, 0xE8,
0xA3, 0xC9, 0xE4, 0, 0xEC, 0x50, 0x17, 0xBD, 0x2A,
0xB6, 0x8E, 0x3B, 0x46, 0x65, 0xA6, 0x7A, 0x96, 0xD3,
0x72, 0x12, 0xBC, 0x20, 0x4D, 0x7C, 0xFA, 0x15, 0xC,
0x41, 0x9B, 0xAA, 9, 0xF8, 0xF0, 0x5D, 0x84, 0xFC,
0xE, 0xD6, 0xA0, 0xF2, 0xEF, 0x4E, 0x10, 0xBF, 0x89,
0x6D, 0x9C, 0x98, 6, 0xC2, 0xC7, 0x5A, 0xF1, 0xB1,
0xA5, 0xF4, 0xB9, 0xA2, 0xF5, 0x78, 0xAE, 0x3D, 0x24,
0xFB, 0x30, 0x9D, 0xD8, 0xA4, 0x6F, 0x1F, 0x49, 0xD0,
0x95, 0x3C, 0x99, 0xBA, 0x23, 0xEA, 0x53, 0x14, 0x2B,
0xE0, 0xD, 0x5B, 0x94, 0x38, 0x4B, 0x1C, 0xCC, 0x4C,
0x88, 0x2C, 0x81, 0xF3, 0x9E, 0x70, 0xF6, 0x58, 0x45,
0xB0, 0x35, 0x5F, 0x6A, 0x8A, 0x32, 0x19, 0x34, 0xDD,
0x4F, 0x7D, 0x36, 0xEE, 0xAB, 0x75, 0x71, 0xF, 0x25,
0xB5, 0xE9, 0x47, 0xF7, 0xCF, 0x43, 0x6C, 0xC6, 0x8F,
0x31, 0xB2, 0x2F, 0xD9, 0x1D, 0xC4, 0xA8, 0xD4, 0x93,
0x73, 0xA7, 0x82, 0x77, 0x66, 8, 0x6B, 1, 0xA9, 0xE3,
0xD5, 0xAD, 0xD7, 0xE5, 0x62, 0x86, 3, 0x22, 0xB4,
0x2D, 0xD1, 0xDF, 0x3E, 0x7B, 0x52, 0xE2, 0x7E, 0x48,
0xE7, 0xB, 4, 0xC8, 0x16, 0xC5, 2, 0xCE, 7, 0x74, 0x80,
5, 0x8D, 1, 2, 4, 8, 0x10, 0x20, 0x40, 0x80, 0x1B,
0x36, 0, 0, 0, 0, 0,
]
def shift_rows(s):
    """Rotate three of the four lanes of the 4x4 state *s* in place.

    A lane is the set ``s[0][c], s[1][c], s[2][c], s[3][c]`` for a fixed
    second index ``c``.  Lane 0 is left alone; lanes 1, 2 and 3 are rotated
    by 2, 3 and 1 positions respectively, so that afterwards
    ``s[r][c]`` holds the former ``s[(r + offset) % 4][c]``.

    Note that these offsets differ from the 1/2/3 of standard AES.
    Returns ``None``.
    """
    for lane, offset in ((1, 2), (2, 3), (3, 1)):
        # Snapshot the lane first so the writes below do not feed back.
        before = [s[r][lane] for r in range(4)]
        for r in range(4):
            s[r][lane] = before[(r + offset) % 4]
def gf_multiply(a, b):
    """Multiply *a* by *b* in GF(2^8) modulo the polynomial 0x11B.

    Shift-and-add multiplication: for every set bit of *b* the current
    multiple of *a* is XORed into the product, and *a* is doubled (with
    reduction by 0x11B when it overflows 8 bits) after each bit.

    Both arguments are expected to be byte values (0-255); *b* must be
    non-negative, otherwise the loop does not terminate.
    """
    p = 0
    while b:
        if b & 1:
            p ^= a
        a <<= 1
        # Fold the overflow bit back in to stay within 8 bits.
        if a & 0x100:
            a ^= 0x11B
        b >>= 1
    return p
def mix_columns(state):
    """Return a new 4x4 list: the fixed coefficient matrix times *state*.

    Entry ``[row][col]`` of the result is the XOR over ``k`` of
    ``gf_multiply(coefficients[row][k], state[k][col])``.  *state* itself
    is not modified.
    """
    coefficients = (
        (0x02, 0x03, 0x01, 0x01),
        (0x01, 0x02, 0x03, 0x01),
        (0x01, 0x01, 0x02, 0x03),
        (0x03, 0x01, 0x01, 0x02),
    )
    mixed = []
    for coeff_row in coefficients:
        out_row = []
        for col in range(4):
            acc = 0
            for k, coeff in enumerate(coeff_row):
                acc ^= gf_multiply(coeff, state[k][col])
            out_row.append(acc)
        mixed.append(out_row)
    return mixed
def add_round_key(matrix, round_key):
    """XOR the first 16 bytes of *round_key* into the 4x4 *matrix* in place.

    Key byte ``4 * i + j`` is applied to ``matrix[i][j]``.  Returns ``None``.
    """
    for position in range(16):
        i, j = divmod(position, 4)
        matrix[i][j] ^= round_key[position]
def round_encrypt(block_bytes: bytes):
    """Transform one 16-byte block with the module-level ``round_key``.

    Steps, in order:
      1. XOR with ``round_key[:16]`` and substitute through ``dfed0_table``.
      2. ``shift_rows``, ``mix_columns``, XOR with ``round_key[16:32]``.
      3. Substitute again, ``shift_rows``, then XOR with ``round_key[32:]``
         followed by ``round_key[16:32]``.

    Returns the resulting 16 bytes.
    """
    first_key = round_key[:16]
    middle_key = round_key[16:32]
    last_key = round_key[32:]

    # Step 1: key whitening followed by table substitution.
    substituted = [dfed0_table[m ^ k] for m, k in zip(block_bytes, first_key)]
    grid = np.asarray(substituted).reshape((4, 4))

    # Step 2: the one full round.
    shift_rows(grid)
    grid = mix_columns(grid)
    add_round_key(grid, middle_key)

    # Step 3: final substitution and shift, then two key additions.
    resubstituted = [dfed0_table[value] for row in grid for value in row]
    grid = np.asarray(resubstituted).reshape((4, 4))
    shift_rows(grid)
    add_round_key(grid, last_key)
    add_round_key(grid, middle_key)
    return grid.astype(np.uint8).tobytes()
def aes_encrypt(message: bytes):
    """Encrypt *message* block by block with chaining.

    Each 16-byte block is XORed with the previous output block (a fixed
    16-byte starting value for the first block) and then passed through
    ``round_encrypt``.  The outputs are concatenated.

    Args:
        message: input whose length is a multiple of 16.

    Returns:
        The concatenated output blocks; ``b''`` for an empty message.

    Raises:
        AssertionError: if ``len(message)`` is not a multiple of 16.
    """
    assert len(message) % 16 == 0, 'Message must be padded for AES block size!'
    # The chaining value is always set, so no "no IV" path is needed.
    chain = bytes.fromhex("ea180a0336ed352fcd24e4d50018ae54")
    blocks = []
    for offset in range(0, len(message), 16):
        block = bytes(m ^ c for m, c in zip(message[offset:offset + 16], chain))
        chain = round_encrypt(block)
        blocks.append(chain)
    # Join once instead of growing a bytes object inside the loop.
    return b''.join(blocks)
# Run the padded buffer through the chained block transform and log it as hex.
aes_result = aes_encrypt(aes_lite_in_bytes)
logger.debug(f"aes lite result: {aes_result.hex()}")