首页
社区
课程
招聘
KCTF2019Q4 第六题 三道八佛
发表于: 2019-12-13 17:41 4017

KCTF2019Q4 第六题 三道八佛

2019-12-13 17:41
4017

这题和Q3第10题的处理方式一样，之前写的脚本可以直接复用。

把最后一层解密得到的asm丢进IDA，即可直接得到化简后的函数。
函数比较简单，直接用z3求解即可。

1. 嵌套式smc
把每层smc的解密指令、加密数据的起始地址和数据大小都提取出来，
最后得到完全解密的exe。
# import StringIO
# import cProfile
# import pstats
import binascii
import os
import struct
import time

from capstone import *
from capstone.x86 import *


# Directory holding CrackMe.exe and receiving CM_fix.exe. Override with the
# KCTF_DATA_DIR environment variable; the default is the author's machine path.
# The value is used as a plain string prefix, so it must end with a separator.
DATA_DIR = os.environ.get('KCTF_DATA_DIR', 'D:\\work\\2019\\pediy_Q4\\6\\')


def hex2bin(s):
    """Decode a hex string such as 'E8000000' into raw bytes.

    binascii.unhexlify is what Python 2's s.decode('hex') uses internally, so
    the result is unchanged there (a byte str). Unlike decode('hex'), it also
    exists on Python 3, where it returns bytes.
    """
    return binascii.unhexlify(s)


def bin2hex(s):
    """Encode raw bytes as a lowercase hex string.

    binascii.hexlify is what Python 2's s.encode('hex') uses internally, so
    the result is unchanged there. It also works on Python 3, where it
    returns bytes.
    """
    return binascii.hexlify(s)


def load_file(filename):
    """Return the whole content of `filename`, read in binary mode.

    The `with` block closes the handle even if reading fails.
    """
    with open(filename, 'rb') as f:
        return f.read()


def save_file(filename, s):
    """Write the byte string `s` to `filename`, replacing existing content.

    The `with` block closes the handle even if writing fails.
    """
    with open(filename, 'wb') as f:
        f.write(s)
    return


def rol8(v, n):
    """Rotate the 8-bit value `v` left by `n` bits (count taken modulo 8).

    With an effective count of 0, `v` is returned as given (not masked).
    """
    shift = n & 7
    if not shift:
        return v
    high_part = v << shift
    wrapped = v >> (8 - shift)
    return (high_part | wrapped) & 0xFF


def ror8(v, n):
    """Rotate the 8-bit value `v` right by `n` bits (count taken modulo 8).

    With an effective count of 0, `v` is returned as given (not masked).
    """
    shift = n & 7
    if not shift:
        return v
    low_part = v >> shift
    wrapped = v << (8 - shift)
    return (low_part | wrapped) & 0xFF


def unwrap_u32(s, offset):
    """Read a little-endian unsigned 32-bit integer from s[offset:offset+4]."""
    chunk = s[offset:offset + 4]
    (value,) = struct.unpack('<I', chunk)
    return value


def unwrap_u8(s, offset):
    """Read the unsigned byte at s[offset] as an int."""
    chunk = s[offset:offset + 1]
    (value,) = struct.unpack('<B', chunk)
    return value


def va_to_offset(va):
    """Convert a virtual address to a file offset in the loaded image.

    Uses one fixed mapping: VA 0x401000 corresponds to file offset 0x400.
    """
    section_va = 0x401000
    section_raw = 0x400
    return va - (section_va - section_raw)


def offset_to_va(offset):
    """Convert a file offset back to its virtual address (inverse of va_to_offset)."""
    section_va = 0x401000
    section_raw = 0x400
    return offset + (section_va - section_raw)


class Pattern(object):
    """One piece of a compiled hex byte pattern.

    A piece is either literal bytes (`pattern` non-empty) or a wildcard
    placeholder (`pattern == ''`) that matches any `size` bytes.
    `offset` is the piece's byte position inside the whole pattern.
    """

    def __init__(self, offset, size, pattern=''):
        self.offset, self.size, self.pattern = offset, size, pattern

    def place_holder(self):
        """Return True when this piece is a wildcard rather than literal bytes."""
        return self.pattern == ''

    def __str__(self):
        if self.place_holder():
            return 'PatternInfo(offset:%d, size:%d)' % (self.offset, self.size)
        return 'PatternInfo(offset:%d, hex:%s)' % (self.offset, bin2hex(self.pattern))

    def __repr__(self):
        return self.__str__()


class PatternObject(object):
    """Compiled hex byte pattern with '*' wildcards.

    The source string is hex text, two characters per byte. A run of '*'
    characters stands for wildcard bytes (two '*' per byte). The pattern is
    split into alternating literal and placeholder Pattern pieces.
    """

    def __init__(self, pattern=''):
        self.patterns = []
        offset = 0
        while True:
            i = pattern.find('*', offset)
            if i == -1:
                # right: trailing literal bytes (possibly empty)
                right_pattern = hex2bin(pattern[offset:])
                self.patterns.append(Pattern(offset // 2, len(right_pattern), right_pattern))
                break
            # left: literal bytes before the wildcard run
            left_pattern = hex2bin(pattern[offset:i])
            self.patterns.append(Pattern(offset // 2, len(left_pattern), left_pattern))
            # middle: the wildcard run. The bound check keeps a pattern that
            # ends in '*' from raising IndexError.
            k = i + 1
            while k < len(pattern) and pattern[k] == '*':
                k += 1
            # floor division keeps offsets/sizes integral on Python 3 as well
            self.patterns.append(Pattern(i // 2, (k - i) // 2))
            offset = k

    def first_pattern(self):
        """Return the leading literal bytes ('' if the pattern starts with a wildcard)."""
        return self.patterns[0].pattern

    def match(self, buf, offset):
        """Return True if every literal piece matches `buf` starting at `offset`.

        Wildcard pieces just advance the position; their bytes are not checked.
        """
        i = offset
        for pat in self.patterns:
            if pat.place_holder():
                i += pat.size
            elif pat.pattern == buf[i:i+pat.size]:
                i += pat.size
            else:
                return False
        return True


def find_pattern(buf, pattern, offset=0):
    """Return the index of the first match of the wildcard hex `pattern` in `buf`.

    The search starts at `offset`. Returns -1 if there is no match.
    """
    pat = PatternObject(pattern)
    first = pat.first_pattern()
    size = len(buf)
    while offset < size:
        # find the next candidate for the leading literal bytes
        i = buf.find(first, offset)
        if i == -1:
            return -1
        if pat.match(buf, i):
            return i
        # Resume just past the rejected candidate. Stepping offset by 1 would
        # re-find the same candidate once per byte in front of it.
        offset = i + 1
    return -1


class Instruction(object):
    """Base class for one decoded decryption-stub instruction.

    Holds the instruction's address and immediate operand. Subclasses
    override `do` (apply the instruction to value `v` given the loop
    counter) and `gen_code` (emit equivalent Python source). The defaults
    return 0 and ''.
    """

    def __init__(self, address=0, imm=0):
        self.address, self.imm = address, imm

    def get_address(self):
        return self.address

    def get_imm(self):
        return self.imm

    def do(self, v, counter):  # type:(int, int) -> int
        return 0

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        return ''


class Instruction8(Instruction):
    """Instruction acting on the 8-bit AL register; the immediate is truncated to one byte."""

    def get_imm(self):
        return self.imm % 0x100


class InstructionMov8(Instruction8):
    """`mov al, imm8`: replaces the value with the byte immediate."""

    def __init__(self, address, imm):
        super(InstructionMov8, self).__init__(address=address, imm=imm)

    def __repr__(self):
        return '%08X: mov al,%02X' % (self.get_address(), self.get_imm())

    def do(self, v, counter):
        value = self.get_imm()
        return value

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        template = '%s = 0x%02X'
        return template % (var_name, self.get_imm())


class InstructionNeg8(Instruction8):
    """`neg al`: two's-complement negation of the 8-bit value."""

    def __init__(self, address):
        Instruction.__init__(self, address=address)

    def __repr__(self):
        return '%08X: neg al' % self.get_address()

    def do(self, v, counter):
        # mask to 8 bits so the result matches gen_code and what AL would hold
        return (0 - v) & 0xFF

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        return '%s = (0 - %s) & 0xFF' % (var_name, var_name)


class InstructionNot8(Instruction8):
    """`not al`: bitwise complement of the 8-bit value."""

    def __init__(self, address):
        Instruction.__init__(self, address=address)

    def __repr__(self):
        return '%08X: not al' % self.get_address()

    def do(self, v, counter):
        # Python's ~ yields a negative int; mask to 8 bits like gen_code does
        return (~v) & 0xFF

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        return '%s = (~%s) & 0xFF' % (var_name, var_name)


class InstructionSub8(Instruction8):
    """`sub al, imm8`: 8-bit wrap-around subtraction of the immediate."""

    def __init__(self, address, imm):
        Instruction.__init__(self, address=address, imm=imm)

    def __repr__(self):
        return '%08X: sub al,%02X' % (self.get_address(), self.get_imm())

    def do(self, v, counter):
        # wrap to 8 bits so the result matches gen_code
        return (v - self.get_imm()) & 0xFF

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        return '%s = (%s - 0x%02X) & 0xFF' % (var_name, var_name, self.get_imm())


class InstructionSub8Counter(Instruction8):
    """`sub al, cl`: 8-bit wrap-around subtraction of the loop counter."""

    def __init__(self, address):
        Instruction.__init__(self, address=address)

    def __repr__(self):
        return '%08X: sub al,cl' % self.get_address()

    def do(self, v, counter):
        # wrap to 8 bits so the result matches gen_code
        return (v - counter) & 0xFF

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        return '%s = (%s - %s) & 0xFF' % (var_name, var_name, counter_name)


class InstructionAdd8(Instruction8):
    """`add al, imm8`: 8-bit wrap-around addition of the immediate."""

    def __init__(self, address, imm):
        Instruction.__init__(self, address=address, imm=imm)

    def __repr__(self):
        return '%08X: add al,%02X' % (self.get_address(), self.get_imm())

    def do(self, v, counter):
        # wrap to 8 bits so the result matches gen_code
        return (v + self.get_imm()) & 0xFF

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        return '%s = (%s + 0x%02X) & 0xFF' % (var_name, var_name, self.get_imm())


class InstructionAdd8Counter(Instruction8):
    """`add al, cl`: 8-bit wrap-around addition of the loop counter."""

    def __init__(self, address):
        Instruction.__init__(self, address=address)

    def __repr__(self):
        return '%08X: add al,cl' % self.get_address()

    def do(self, v, counter):
        # wrap to 8 bits so the result matches gen_code
        return (v + counter) & 0xFF

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        return '%s = (%s + %s) & 0xFF' % (var_name, var_name, counter_name)


class InstructionXor8(Instruction8):
    """`xor al, imm8`: xor the value with the byte immediate."""

    def __init__(self, address, imm):
        super(InstructionXor8, self).__init__(address=address, imm=imm)

    def __repr__(self):
        return '%08X: xor al,%02X' % (self.get_address(), self.get_imm())

    def do(self, v, counter):
        key = self.get_imm()
        return v ^ key

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        template = '%s = (%s ^ 0x%02X) & 0xFF'
        return template % (var_name, var_name, self.get_imm())


class InstructionXor8Counter(Instruction8):
    """`xor al, cl`: xor the value with the loop counter."""

    def __init__(self, address):
        Instruction.__init__(self, address=address)

    def __repr__(self):
        return '%08X: xor al,cl' % self.get_address()

    def do(self, v, counter):
        # keep only the low byte, like gen_code and the 8-bit register
        return (v ^ counter) & 0xFF

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        return '%s = (%s ^ %s) & 0xFF' % (var_name, var_name, counter_name)


class InstructionMul8Counter(Instruction8):
    """`mul cl`: multiply by the loop counter, keeping the low byte (AL)."""

    def __init__(self, address):
        Instruction.__init__(self, address=address)

    def __repr__(self):
        return '%08X: mul cl' % self.get_address()

    def do(self, v, counter):
        # keep only the low byte of the product, matching gen_code
        return (v * counter) & 0xFF

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        return '%s = (%s * %s) & 0xFF' % (var_name, var_name, counter_name)


class InstructionXor8Expression(Instruction8):
    """`xor al, bl`, where BL is computed by a sub-sequence of instructions.

    `expr` lists the instructions that build the BL value, starting from 0.
    """

    def __init__(self, address, expr):
        Instruction.__init__(self, address=address)
        self.expr = expr

    def __repr__(self):
        s = ''
        s += '%08X: xor al,bl' % self.get_address()
        for ins in self.expr:
            s += '\n\t%s' % ins
        return s

    def do(self, v, counter):
        t = 0
        for ins in self.expr:  # type:Instruction
            t = ins.do(t, counter) & 0xFF
        return v ^ t

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        # Use an accumulator name derived from var_name. The old name `t` was
        # also the temp emitted by the rotate instructions' gen_code, which
        # overwrote the accumulator with the shift count.
        acc = '%s_bl' % var_name
        parts = ['%s = 0' % acc]
        for ins in self.expr:  # type:Instruction
            code = ins.gen_code(acc, counter_name)
            # a zero rotation emits '', and "a; ; b" is a Python syntax error
            if code:
                parts.append(code)
        # xor into the caller's variable instead of a hard-coded `v`
        parts.append('%s = %s ^ %s' % (var_name, var_name, acc))
        return '; '.join(parts)


class InstructionRor8(Instruction8):
    """`ror al, imm8`: rotate the 8-bit value right by the immediate (mod 8)."""

    def __init__(self, address, imm):
        Instruction.__init__(self, address=address, imm=imm)

    def __repr__(self):
        return '%08X: ror al,%02X' % (self.get_address(), self.get_imm())

    def do(self, v, counter):
        return ror8(v, self.get_imm())

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        t = self.get_imm() & 7
        if t == 0:
            return ''
        # Inline the constant shift counts. The previous version emitted a
        # temporary named `t`, which clobbered the value when var_name was
        # also 't' (as InstructionXor8Expression.gen_code passes).
        return '%s = ((%s >> %d) | (%s << %d)) & 0xFF' % (var_name, var_name, t, var_name, 8 - t)


class InstructionRor8Counter(Instruction8):
    """`ror al, cl`: rotate the 8-bit value right by the loop counter (mod 8)."""

    def __init__(self, address):
        Instruction.__init__(self, address=address)

    def __repr__(self):
        return '%08X: ror al,cl' % self.get_address()

    def do(self, v, counter):
        return ror8(v, counter)

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        # Derive the shift temp from var_name. A fixed `t` clobbered the value
        # when var_name was also 't'.
        shift = '%s_shift' % var_name
        s = '%s = %s & 7' % (shift, counter_name) + '; '
        s += '%s = ((%s >> %s) | (%s << (8 - %s))) & 0xFF' % (var_name, var_name, shift, var_name, shift)
        return s


class InstructionRol8(Instruction8):
    """`rol al, imm8`: rotate the 8-bit value left by the immediate (mod 8)."""

    def __init__(self, address, imm):
        Instruction.__init__(self, address=address, imm=imm)

    def __repr__(self):
        return '%08X: rol al,%02X' % (self.get_address(), self.get_imm())

    def do(self, v, counter):
        return rol8(v, self.get_imm())

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        t = self.get_imm() & 7
        if t == 0:
            return ''
        # Inline the constant shift counts. The previous version emitted a
        # temporary named `t`, which clobbered the value when var_name was
        # also 't' (as InstructionXor8Expression.gen_code passes).
        return '%s = ((%s << %d) | (%s >> %d)) & 0xFF' % (var_name, var_name, t, var_name, 8 - t)


class InstructionRol8Counter(Instruction8):
    """`rol al, cl`: rotate the 8-bit value left by the loop counter (mod 8)."""

    def __init__(self, address):
        Instruction.__init__(self, address=address)

    def __repr__(self):
        return '%08X: rol al,cl' % self.get_address()

    def do(self, v, counter):
        return rol8(v, counter)

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        # Derive the shift temp from var_name. A fixed `t` clobbered the value
        # when var_name was also 't'.
        shift = '%s_shift' % var_name
        s = '%s = %s & 7' % (shift, counter_name) + '; '
        s += '%s = ((%s << %s) | (%s >> (8 - %s))) & 0xFF' % (var_name, var_name, shift, var_name, shift)
        return s


def pe_get_code_partial(ary, va, size=0x800):
    """Return a copy of up to `size` bytes of `ary` starting at virtual address `va`."""
    start = va_to_offset(va)
    return bytes(ary[start:start + size])


def _merge_adjacent(prev, cur):
    """Return one instruction equivalent to `prev` then `cur`, or None if they don't fold."""
    addr = prev.get_address()
    if isinstance(prev, InstructionAdd8):
        if isinstance(cur, InstructionAdd8):
            return InstructionAdd8(addr, prev.get_imm() + cur.get_imm())
        if isinstance(cur, InstructionSub8):
            return InstructionAdd8(addr, prev.get_imm() - cur.get_imm())
    elif isinstance(prev, InstructionSub8):
        if isinstance(cur, InstructionAdd8):
            return InstructionSub8(addr, prev.get_imm() - cur.get_imm())
        if isinstance(cur, InstructionSub8):
            return InstructionSub8(addr, prev.get_imm() + cur.get_imm())
    elif isinstance(prev, InstructionXor8):
        if isinstance(cur, InstructionXor8):
            return InstructionXor8(addr, prev.get_imm() ^ cur.get_imm())
    elif isinstance(prev, InstructionRol8):
        if isinstance(cur, InstructionRol8):
            return InstructionRol8(addr, prev.get_imm() + cur.get_imm())
        if isinstance(cur, InstructionRor8):
            left, right = prev.get_imm(), cur.get_imm()
            if left > right:
                return InstructionRol8(addr, left - right)
            return InstructionRor8(addr, right - left)
    elif isinstance(prev, InstructionRor8):
        if isinstance(cur, InstructionRor8):
            return InstructionRor8(addr, prev.get_imm() + cur.get_imm())
        if isinstance(cur, InstructionRol8):
            right, left = prev.get_imm(), cur.get_imm()
            if right > left:
                return InstructionRor8(addr, right - left)
            return InstructionRol8(addr, left - right)
    return None


def simplify_inst(ins_ary):
    """Fold runs of adjacent constant instructions into single equivalent ones.

    Neighbouring add/sub immediates combine into one add or sub. Neighbouring
    xor immediates are xor-ed together. Neighbouring rol/ror immediates
    combine into one rotation. Everything else is kept as-is. An empty input
    is returned unchanged (the same object); otherwise a new list is built.
    """
    if not ins_ary:
        return ins_ary
    remaining = iter(ins_ary)
    folded = [next(remaining)]
    for cur in remaining:
        merged = _merge_adjacent(folded[-1], cur)
        if merged is None:
            folded.append(cur)
        else:
            folded[-1] = merged
    return folded


def pe_decrypt_code_partial(smc_index, ary, va, size, ins_ary, counter_type=0):
    """Decrypt `size` bytes of `ary` at virtual address `va`, in place.

    The instructions in `ins_ary` are folded with simplify_inst and compiled
    into Python source for a routine named 'smc<smc_index>'. That routine is
    exec'd into this module's globals (it stays defined afterwards) and run
    once over the region. For each byte it loads `v`, sets `counter` from
    `counter_type`, applies every instruction's generated code, and stores
    `v` back. Counter rules:

        0 -> (size - i) & 0xFF
        1 -> (i + 1) & 0xFF
        2 -> (size - 1 - i) & 0xFF
        other -> 0
    """
    simplified_ary = simplify_inst(ins_ary)
    offset = va_to_offset(va)
    routine_name = 'smc%d' % smc_index

    counter_stmt = {
        0: 'counter = (size - i) & 0xFF',
        1: 'counter = (i + 1) & 0xFF',
        2: 'counter = (size - 1 - i) & 0xFF',
    }.get(counter_type, 'counter = 0')

    # Per-byte loop body. Each instruction is inlined as source so the
    # generated loop makes no method calls.
    loop_body = ['v = ary[offset + i]', counter_stmt]
    for ins in simplified_ary:  # type: Instruction
        loop_body.append(ins.gen_code('v', 'counter'))
    loop_body.append('ary[offset + i] = v')

    source_lines = ['def %s(ary, offset, size):' % routine_name,
                    '    for i in range(size):']
    for stmt in loop_body:
        source_lines.append('        ' + stmt)
    smc_routine = '\n'.join(source_lines) + '\n'

    exec(smc_routine, globals())
    globals()[routine_name](ary, offset, size)


# enc_va -> number of stubs seen for that region whose last collected
# instruction was `ror al, 0xD8` (updated in pe_smc_decrypt).
patch_enc_va_info = {}


def pe_smc_decrypt(smc_index, ary, enc_va, enc_size, initial_va, counter_type):
    """Recover one SMC decryption stub by disassembly and apply it to `ary`.

    Disassembles from `initial_va`, moves past the first `lodsb`, then
    collects the operations applied to AL. These are immediate ops, ops with
    CL, neg/not, and an `xor al, bl` whose BL value is built between
    `push eax` and `pop eax`. Collection stops at a `mov [esi+disp], al`
    store, and the collected instructions are replayed over `enc_size`
    bytes at `enc_va` via pe_decrypt_code_partial.

    Args:
        smc_index: index used to name the generated decryption routine.
        ary: image buffer, modified in place (test_decrypt passes a bytearray).
        enc_va: virtual address of the encrypted region.
        enc_size: number of bytes to decrypt.
        initial_va: virtual address where disassembly of the stub starts.
        counter_type: counter derivation passed to pe_decrypt_code_partial.
            It is forced to 1 when the store uses displacement 1.

    Returns:
        The VA just past the store instruction on success. Returns -1 if no
        `lodsb` is found in the first 0x80 bytes, if the store is reached
        with no collected instructions, or if the 0x800-byte scan window is
        exhausted.
    """
    va = initial_va
    ins_ary = []
    md = Cs(CS_ARCH_X86, CS_MODE_32)
    md.detail = True

    # search for decrypt start
    # (the first `lodsb` within 0x80 bytes; the byte operations follow it)
    t = -1
    for i in md.disasm(pe_get_code_partial(ary, va, size=0x80), va):  # type:CsInsn
        if i.id == X86_INS_LODSB:
            # print('%08X: lodsb' % i.address)
            t = i.address + i.size
            break
    if t == -1:
        return -1
    va = t

    # begin(inclusive), end
    # (address after a `jmp imm`, jmp target): instructions decoded inside
    # these ranges are jumped over (presumably junk) and ignored
    jmp_ary = []

    # push ecx; mov cl,0xA7; xor al,cl; pop ecx
    # while ecx is saved, `op al, cl` uses the constant last moved into cl
    # (ecx_imm) instead of the loop counter
    ecx_pushed = False
    ecx_popped = False
    ecx_imm = 0  # TODO need inst_ary? like ebx?

    # push eax; op eax; ...; mov ebx,eax; pop eax
    # ops on al/eax between push eax and pop eax build the value later used
    # by `xor al, bl`; they go into ebx_inst_ary instead of ins_ary
    eax_pushed = False
    eax_popped = False
    ebx_inst_ary = []

    # NOTE(review): va only advances inside the for loop below; if capstone
    # cannot decode the instruction at va, this while loop never terminates.
    while va < (initial_va + 0x800):
        for i in md.disasm(pe_get_code_partial(ary, va, size=0x80), va):  # type:CsInsn
            va = i.address + i.size
            # skip code in jmp area
            in_jmp_area = False
            for begin_address, end_address in jmp_ary:
                if begin_address <= i.address < end_address:
                    in_jmp_area = True
                    break
            if in_jmp_area:
                continue

            # print("%x:\t%s\t%s" % (i.address, i.mnemonic, i.op_str))

            # unconditional `jmp imm`: record the skipped range and keep
            # disassembling linearly
            if len(i.operands) == 1:
                op0 = i.operands[0]  # type: X86Op
                if i.id == X86_INS_JMP and op0.type == X86_OP_IMM:
                    # print('jmp: %08X-%08X' % (i.address + i.size, op0.imm))
                    jmp_ary.append((i.address + i.size, op0.imm))
                    continue

            # pushad / popad are treated like push eax / pop eax
            if len(i.operands) == 0:
                if i.id == X86_INS_PUSHAL:
                    eax_pushed = True
                elif i.id == X86_INS_POPAL:
                    eax_popped = True

            if len(i.operands) == 1:
                op0 = i.operands[0]  # type: X86Op
                if i.id == X86_INS_PUSH and op0.type == X86_OP_REG:
                    if op0.reg == X86_REG_EAX:
                        eax_pushed = True
                    elif op0.reg == X86_REG_ECX:
                        ecx_pushed = True
                    continue
                if i.id == X86_INS_POP and op0.type == X86_OP_REG:
                    if op0.reg == X86_REG_EAX:
                        eax_popped = True
                    elif op0.reg == X86_REG_ECX:
                        ecx_popped = True
                    continue

            # inside push eax ... pop eax: collect the ops that compute BL
            if eax_pushed and (not eax_popped):
                if len(i.operands) == 2:
                    op0 = i.operands[0]  # type: X86Op
                    op1 = i.operands[1]  # type:X86Op
                    if op0.type == X86_OP_REG and (op0.reg in (X86_REG_AL, X86_REG_EAX)):
                        # OP EAX/AL, IMM/REG
                        if op1.type == X86_OP_IMM:
                            if i.id == X86_INS_MOV:
                                ebx_inst_ary.append(InstructionMov8(i.address, op1.imm))
                            elif i.id == X86_INS_SUB:
                                ebx_inst_ary.append(InstructionSub8(i.address, op1.imm))
                            elif i.id == X86_INS_ADD:
                                ebx_inst_ary.append(InstructionAdd8(i.address, op1.imm))
                            elif i.id == X86_INS_XOR:
                                ebx_inst_ary.append(InstructionXor8(i.address, op1.imm))
                            elif i.id == X86_INS_ROL:
                                ebx_inst_ary.append(InstructionRol8(i.address, op1.imm))
                            elif i.id == X86_INS_ROR:
                                ebx_inst_ary.append(InstructionRor8(i.address, op1.imm))
                        # OP EAX/AL, ECX/CL
                        elif op1.type == X86_OP_REG and (op1.reg in (X86_REG_CL, X86_REG_ECX)):
                            if i.id == X86_INS_SUB:
                                ebx_inst_ary.append(InstructionSub8Counter(i.address))
                            elif i.id == X86_INS_ADD:
                                ebx_inst_ary.append(InstructionAdd8Counter(i.address))
                            elif i.id == X86_INS_XOR:
                                ebx_inst_ary.append(InstructionXor8Counter(i.address))
                            elif i.id == X86_INS_ROL:
                                ebx_inst_ary.append(InstructionRol8Counter(i.address))
                            elif i.id == X86_INS_ROR:
                                ebx_inst_ary.append(InstructionRor8Counter(i.address))
                elif len(i.operands) == 1:
                    op0 = i.operands[0]  # type: X86Op
                    if op0.type == X86_OP_REG and (op0.reg in (X86_REG_AL, X86_REG_EAX)):
                        if i.id == X86_INS_NEG:
                            ebx_inst_ary.append(InstructionNeg8(i.address))
                        elif i.id == X86_INS_NOT:
                            ebx_inst_ary.append(InstructionNot8(i.address))
                    if i.id == X86_INS_MUL and op0.type == X86_OP_REG and (op0.reg in (X86_REG_CL, X86_REG_ECX)):
                        ebx_inst_ary.append(InstructionMul8Counter(i.address))
                continue  # until we meet pop eax

            if len(i.operands) == 2:
                op0 = i.operands[0]  # type: X86Op
                op1 = i.operands[1]  # type:X86Op
                # mov [esi-1],al (CLD)
                # mov [esi+1],al (STD)
                # storing AL back through esi ends the decoder body
                if i.id == X86_INS_MOV and op0.type == X86_OP_MEM and op0.mem.base == X86_REG_ESI and op1.type == X86_OP_REG and op1.reg == X86_REG_AL:
                    # collect decrypt instruction finished
                    if len(ins_ary) == 0:
                        return -1
                    if op0.mem.disp == 1:
                        counter_type = 1
                    # Count stubs for this enc_va whose last op is `ror al,0xD8`.
                    # NOTE(review): once the count reaches 3, only the first byte
                    # at enc_va is overwritten with 0xE8 and no decryption runs;
                    # the reason for this special case is not visible here.
                    tmp_ins = ins_ary[-1]
                    if isinstance(tmp_ins, InstructionRor8) and tmp_ins.get_imm() == 0xD8:
                        if enc_va in patch_enc_va_info:
                            patch_enc_va_info[enc_va] += 1
                        else:
                            patch_enc_va_info[enc_va] = 1
                    if enc_va in patch_enc_va_info and patch_enc_va_info[enc_va] == 3:
                        ary[va_to_offset(enc_va)] = 0xE8  # ugly hack
                        return va
                    pe_decrypt_code_partial(smc_index, ary, enc_va, enc_size, ins_ary, counter_type=counter_type)
                    return va
                elif i.id == X86_INS_MOV and op0.type == X86_OP_REG and op0.reg == X86_REG_CL:
                    # mov cl, A7
                    # mov cl, ah
                    # only tracked after push ecx; `mov cl, ah` is recorded as 0
                    if not ecx_pushed:
                        continue
                    if op1.type == X86_OP_IMM:
                        ecx_imm = op1.imm
                    elif op1.type == X86_OP_REG and op1.reg == X86_REG_AH:
                        ecx_imm = 0
                elif op0.type == X86_OP_REG and (op0.reg in (X86_REG_AL, X86_REG_EAX)):
                    # OP EAX/AL, IMM/REG
                    if op1.type == X86_OP_IMM:
                        if i.id == X86_INS_SUB:
                            ins_ary.append(InstructionSub8(i.address, op1.imm))
                        elif i.id == X86_INS_ADD:
                            ins_ary.append(InstructionAdd8(i.address, op1.imm))
                        elif i.id == X86_INS_XOR:
                            ins_ary.append(InstructionXor8(i.address, op1.imm))
                        elif i.id == X86_INS_ROL:
                            ins_ary.append(InstructionRol8(i.address, op1.imm))
                        elif i.id == X86_INS_ROR:
                            ins_ary.append(InstructionRor8(i.address, op1.imm))
                    # OP EAX/AL, EBX/BL
                    # (BL's value is the expression collected in ebx_inst_ary)
                    elif op1.type == X86_OP_REG and (op1.reg in (X86_REG_BL,)):
                        if i.id == X86_INS_XOR:
                            ins_ary.append(InstructionXor8Expression(i.address, ebx_inst_ary))
                    # OP EAX/AL, ECX/CL
                    elif op1.type == X86_OP_REG and (op1.reg in (X86_REG_CL, X86_REG_ECX)):
                        # while ecx is saved, cl holds the constant ecx_imm
                        if ecx_pushed and (not ecx_popped):
                            if i.id == X86_INS_SUB:
                                ins_ary.append(InstructionSub8(i.address, ecx_imm))
                            elif i.id == X86_INS_ADD:
                                ins_ary.append(InstructionAdd8(i.address, ecx_imm))
                            elif i.id == X86_INS_XOR:
                                ins_ary.append(InstructionXor8(i.address, ecx_imm))
                            elif i.id == X86_INS_ROL:
                                ins_ary.append(InstructionRol8(i.address, ecx_imm))
                            elif i.id == X86_INS_ROR:
                                ins_ary.append(InstructionRor8(i.address, ecx_imm))
                        # otherwise cl is the loop counter
                        else:
                            if i.id == X86_INS_SUB:
                                ins_ary.append(InstructionSub8Counter(i.address))
                            elif i.id == X86_INS_ADD:
                                ins_ary.append(InstructionAdd8Counter(i.address))
                            elif i.id == X86_INS_XOR:
                                ins_ary.append(InstructionXor8Counter(i.address))
                            elif i.id == X86_INS_ROL:
                                ins_ary.append(InstructionRol8Counter(i.address))
                            elif i.id == X86_INS_ROR:
                                ins_ary.append(InstructionRor8Counter(i.address))
            elif len(i.operands) == 1:
                op0 = i.operands[0]  # type: X86Op
                if op0.type == X86_OP_REG and op0.reg in (X86_REG_AL, X86_REG_EAX):
                    if i.id == X86_INS_NEG:
                        ins_ary.append(InstructionNeg8(i.address))
                    elif i.id == X86_INS_NOT:
                        ins_ary.append(InstructionNot8(i.address))
    # scan window exhausted without reaching the store instruction
    return -1


def get_current_time():
    """Return the current local time, truncated to whole seconds, as 'YYYY-MM-DD HH:MM:SS'."""
    whole_seconds = int(time.time())
    return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(whole_seconds))


def pe_smc_decrypt_repeated(ary, va_start, va_end=0x506000):  # type:(bytearray, int, int) -> None
    """Peel successive SMC layers from `va_start` until no stub is found or `va_end` is reached.

    For each layer it does the following:
      * Search the next 0x800 bytes for the position-independent stub header
        `call $+5; pop esi; sub esi,imm1; add esi,imm2`, optionally followed
        by a short jmp.
      * Require `mov ecx,imm3` next (otherwise skip ahead and retry).
      * Compute the encrypted region:
        enc_va = VA of `pop esi` + imm2 - imm1, enc_size = imm3.
      * Choose counter_type 2 if `sub ecx,1` follows (after any short jmps),
        else 0.
      * Decrypt via pe_smc_decrypt and continue from the VA it returns.

    `ary` is modified in place. The byte comparisons against hex2bin(...)
    index a byte string, so they rely on Python 2 str semantics.
    """
    smc_index = 0
    va = va_start
    while va < va_end:
        # it's repeated-smc, no point to get full code
        buf = pe_get_code_partial(ary, va)
        # call $+5; pop esi; sub esi,imm1; add esi,imm2; mov ecx,imm3
        # call $+5; pop esi; sub esi,imm1; add esi,imm2; jmp xx; mov ecx,imm3
        pat = 'E8000000005E81EE******0081C6******00'
        i = find_pattern(buf, pat)
        if i == -1:
            break
        # byte offset right after the matched header
        jmp_offset = i + len(pat)/2
        if buf[jmp_offset] == hex2bin('EB'):  # jmp $+XX
            mov_offset = jmp_offset + unwrap_u8(buf, jmp_offset+1) + 2
        else:
            mov_offset = jmp_offset
        # mov ecx,imm3
        # (not a stub we handle: move past this header and search again)
        if buf[mov_offset] != hex2bin('B9') or buf[mov_offset+4] != hex2bin('00'):
            va += jmp_offset
            continue
        # ip: offset of `pop esi`, i.e. the return address pushed by call $+5;
        # imm1 is at ip+3 and imm2 at ip+9
        ip = i + 5
        enc_va = (va + ip) + unwrap_u32(buf, ip + 9) - unwrap_u32(buf, ip + 3)
        enc_size = unwrap_u32(buf, mov_offset + 1)

        # sub ecx,1
        # (follow any chain of short jmps first)
        sub_ecx_offset = mov_offset + 5
        while buf[sub_ecx_offset] == hex2bin('EB'):
            sub_ecx_offset = sub_ecx_offset + unwrap_u8(buf, sub_ecx_offset + 1) + 2
        if buf[sub_ecx_offset:sub_ecx_offset+3] == hex2bin('83E901'):
            counter_type = 2
        else:
            counter_type = 0

        # position independent code
        last_pic_va = va + i
        print('[%s] smc_index: %08d, pic_va: %08x, enc_va: %08x, enc_size: %08x' % (get_current_time(), smc_index, last_pic_va, enc_va, enc_size))

        # pr = cProfile.Profile()
        # pr.enable()
        new_va = pe_smc_decrypt(smc_index, ary, enc_va, enc_size, last_pic_va, counter_type)
        # pr.disable()
        # s = StringIO.StringIO()
        # ps = pstats.Stats(pr, stream=s).sort_stats('cumulative')
        # ps.print_stats()
        # print s.getvalue()

        # note: printed before the -1 failure check below
        print('[%s] decrypt done, new_va: %08x' % (get_current_time(), new_va))
        if new_va == -1:
            break
        va = new_va
        smc_index += 1
    return


def test_decrypt():
    """Decrypt every nested SMC layer of CrackMe.exe and save the result as CM_fix.exe.

    Reads DATA_DIR + 'CrackMe.exe', peels the layers starting at VA
    0x00401629 (up to pe_smc_decrypt_repeated's default end VA), and writes
    the patched image to DATA_DIR + 'CM_fix.exe'.
    """
    image = bytearray(load_file(DATA_DIR + 'CrackMe.exe'))
    # pass va_end=0x004016ac to stop after the first few layers when debugging
    pe_smc_decrypt_repeated(image, 0x00401629)
    save_file(DATA_DIR + 'CM_fix.exe', bytes(image))


test_decrypt()

# import StringIO
# import cProfile
# import pstats
import struct
import time
from capstone import *
from capstone.x86 import *


# Directory holding CrackMe.exe and receiving CM_fix.exe. Override with the
# KCTF_DATA_DIR environment variable; the default is the author's machine path.
# The value is used as a plain string prefix, so it must end with a separator.
DATA_DIR = os.environ.get('KCTF_DATA_DIR', 'D:\\work\\2019\\pediy_Q4\\6\\')


def hex2bin(s):
    """Decode a hex string such as 'E8000000' into raw bytes.

    binascii.unhexlify is what Python 2's s.decode('hex') uses internally, so
    the result is unchanged there (a byte str). Unlike decode('hex'), it also
    exists on Python 3, where it returns bytes.
    """
    return binascii.unhexlify(s)


def bin2hex(s):
    """Encode raw bytes as a lowercase hex string.

    binascii.hexlify is what Python 2's s.encode('hex') uses internally, so
    the result is unchanged there. It also works on Python 3, where it
    returns bytes.
    """
    return binascii.hexlify(s)


def load_file(filename):
    """Return the whole content of `filename`, read in binary mode.

    The `with` block closes the handle even if reading fails.
    """
    with open(filename, 'rb') as f:
        return f.read()


def save_file(filename, s):
    """Write the byte string `s` to `filename`, replacing existing content.

    The `with` block closes the handle even if writing fails.
    """
    with open(filename, 'wb') as f:
        f.write(s)
    return


def rol8(v, n):
    """Rotate the 8-bit value `v` left by `n` bits (count taken modulo 8).

    With an effective count of 0, `v` is returned as given (not masked).
    """
    shift = n & 7
    if not shift:
        return v
    high_part = v << shift
    wrapped = v >> (8 - shift)
    return (high_part | wrapped) & 0xFF


def ror8(v, n):
    """Rotate the 8-bit value `v` right by `n` bits (count taken modulo 8).

    With an effective count of 0, `v` is returned as given (not masked).
    """
    shift = n & 7
    if not shift:
        return v
    low_part = v >> shift
    wrapped = v << (8 - shift)
    return (low_part | wrapped) & 0xFF


def unwrap_u32(s, offset):
    """Read a little-endian unsigned 32-bit integer from s[offset:offset+4]."""
    chunk = s[offset:offset + 4]
    (value,) = struct.unpack('<I', chunk)
    return value


def unwrap_u8(s, offset):
    """Read the unsigned byte at s[offset] as an int."""
    chunk = s[offset:offset + 1]
    (value,) = struct.unpack('<B', chunk)
    return value


def va_to_offset(va):
    """Convert a virtual address to a file offset in the loaded image.

    Uses one fixed mapping: VA 0x401000 corresponds to file offset 0x400.
    """
    section_va = 0x401000
    section_raw = 0x400
    return va - (section_va - section_raw)


def offset_to_va(offset):
    """Convert a file offset back to its virtual address (inverse of va_to_offset)."""
    section_va = 0x401000
    section_raw = 0x400
    return offset + (section_va - section_raw)


class Pattern(object):
    """One piece of a compiled hex byte pattern.

    A piece is either literal bytes (`pattern` non-empty) or a wildcard
    placeholder (`pattern == ''`) that matches any `size` bytes.
    `offset` is the piece's byte position inside the whole pattern.
    """

    def __init__(self, offset, size, pattern=''):
        self.offset, self.size, self.pattern = offset, size, pattern

    def place_holder(self):
        """Return True when this piece is a wildcard rather than literal bytes."""
        return self.pattern == ''

    def __str__(self):
        if self.place_holder():
            return 'PatternInfo(offset:%d, size:%d)' % (self.offset, self.size)
        return 'PatternInfo(offset:%d, hex:%s)' % (self.offset, bin2hex(self.pattern))

    def __repr__(self):
        return self.__str__()


class PatternObject(object):
    """Compiled hex byte pattern with '*' wildcards.

    The source string is hex text, two characters per byte. A run of '*'
    characters stands for wildcard bytes (two '*' per byte). The pattern is
    split into alternating literal and placeholder Pattern pieces.
    """

    def __init__(self, pattern=''):
        self.patterns = []
        offset = 0
        while True:
            i = pattern.find('*', offset)
            if i == -1:
                # right: trailing literal bytes (possibly empty)
                right_pattern = hex2bin(pattern[offset:])
                self.patterns.append(Pattern(offset // 2, len(right_pattern), right_pattern))
                break
            # left: literal bytes before the wildcard run
            left_pattern = hex2bin(pattern[offset:i])
            self.patterns.append(Pattern(offset // 2, len(left_pattern), left_pattern))
            # middle: the wildcard run. The bound check keeps a pattern that
            # ends in '*' from raising IndexError.
            k = i + 1
            while k < len(pattern) and pattern[k] == '*':
                k += 1
            # floor division keeps offsets/sizes integral on Python 3 as well
            self.patterns.append(Pattern(i // 2, (k - i) // 2))
            offset = k

    def first_pattern(self):
        """Return the leading literal bytes ('' if the pattern starts with a wildcard)."""
        return self.patterns[0].pattern

    def match(self, buf, offset):
        """Return True if every literal piece matches `buf` starting at `offset`.

        Wildcard pieces just advance the position; their bytes are not checked.
        """
        i = offset
        for pat in self.patterns:
            if pat.place_holder():
                i += pat.size
            elif pat.pattern == buf[i:i+pat.size]:
                i += pat.size
            else:
                return False
        return True


def find_pattern(buf, pattern, offset=0):
    """Return the index of the first match of the wildcard hex `pattern` in `buf`.

    The search starts at `offset`. Returns -1 if there is no match.
    """
    pat = PatternObject(pattern)
    first = pat.first_pattern()
    size = len(buf)
    while offset < size:
        # find the next candidate for the leading literal bytes
        i = buf.find(first, offset)
        if i == -1:
            return -1
        if pat.match(buf, i):
            return i
        # Resume just past the rejected candidate. Stepping offset by 1 would
        # re-find the same candidate once per byte in front of it.
        offset = i + 1
    return -1


class Instruction(object):
    """Base class for one decoded decryption-stub instruction.

    Holds the instruction's address and immediate operand. Subclasses
    override `do` (apply the instruction to value `v` given the loop
    counter) and `gen_code` (emit equivalent Python source). The defaults
    return 0 and ''.
    """

    def __init__(self, address=0, imm=0):
        self.address, self.imm = address, imm

    def get_address(self):
        return self.address

    def get_imm(self):
        return self.imm

    def do(self, v, counter):  # type:(int, int) -> int
        return 0

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        return ''


class Instruction8(Instruction):
    """Instruction acting on the 8-bit AL register; the immediate is truncated to one byte."""

    def get_imm(self):
        return self.imm % 0x100


class InstructionMov8(Instruction8):
    """`mov al, imm8`: replaces the value with the byte immediate."""

    def __init__(self, address, imm):
        super(InstructionMov8, self).__init__(address=address, imm=imm)

    def __repr__(self):
        return '%08X: mov al,%02X' % (self.get_address(), self.get_imm())

    def do(self, v, counter):
        value = self.get_imm()
        return value

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        template = '%s = 0x%02X'
        return template % (var_name, self.get_imm())


class InstructionNeg8(Instruction8):
    """`neg al`: two's-complement negation of the 8-bit value."""

    def __init__(self, address):
        Instruction.__init__(self, address=address)

    def __repr__(self):
        return '%08X: neg al' % self.get_address()

    def do(self, v, counter):
        # mask to 8 bits so the result matches gen_code and what AL would hold
        return (0 - v) & 0xFF

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        return '%s = (0 - %s) & 0xFF' % (var_name, var_name)


class InstructionNot8(Instruction8):
    """`not al`: bitwise complement of the 8-bit value."""

    def __init__(self, address):
        Instruction.__init__(self, address=address)

    def __repr__(self):
        return '%08X: not al' % self.get_address()

    def do(self, v, counter):
        # Python's ~ yields a negative int; mask to 8 bits like gen_code does
        return (~v) & 0xFF

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        return '%s = (~%s) & 0xFF' % (var_name, var_name)


class InstructionSub8(Instruction8):
    """`sub al,imm8`: subtract a constant byte."""

    def __init__(self, address, imm):
        super(InstructionSub8, self).__init__(address, imm)

    def __repr__(self):
        return '{0:08X}: sub al,{1:02X}'.format(self.get_address(), self.get_imm())

    def do(self, v, counter):
        operand = self.get_imm()
        return v - operand

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        return '{0} = ({0} - 0x{1:02X}) & 0xFF'.format(var_name, self.get_imm())


class InstructionSub8Counter(Instruction8):
    """`sub al,cl`: subtract the loop counter."""

    def __init__(self, address):
        super(InstructionSub8Counter, self).__init__(address)

    def __repr__(self):
        return '{0:08X}: sub al,cl'.format(self.get_address())

    def do(self, v, counter):
        result = v - counter
        return result

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        return '{0} = ({0} - {1}) & 0xFF'.format(var_name, counter_name)


class InstructionAdd8(Instruction8):
    """`add al,imm8`: add a constant byte."""

    def __init__(self, address, imm):
        super(InstructionAdd8, self).__init__(address, imm)

    def __repr__(self):
        return '{0:08X}: add al,{1:02X}'.format(self.get_address(), self.get_imm())

    def do(self, v, counter):
        operand = self.get_imm()
        return v + operand

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        return '{0} = ({0} + 0x{1:02X}) & 0xFF'.format(var_name, self.get_imm())


class InstructionAdd8Counter(Instruction8):
    """`add al,cl`: add the loop counter."""

    def __init__(self, address):
        super(InstructionAdd8Counter, self).__init__(address)

    def __repr__(self):
        return '{0:08X}: add al,cl'.format(self.get_address())

    def do(self, v, counter):
        result = v + counter
        return result

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        return '{0} = ({0} + {1}) & 0xFF'.format(var_name, counter_name)


class InstructionXor8(Instruction8):
    """`xor al,imm8`: XOR with a constant byte."""

    def __init__(self, address, imm):
        super(InstructionXor8, self).__init__(address, imm)

    def __repr__(self):
        return '{0:08X}: xor al,{1:02X}'.format(self.get_address(), self.get_imm())

    def do(self, v, counter):
        key = self.get_imm()
        return v ^ key

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        return '{0} = ({0} ^ 0x{1:02X}) & 0xFF'.format(var_name, self.get_imm())


class InstructionXor8Counter(Instruction8):
    """`xor al,cl`: XOR with the loop counter."""

    def __init__(self, address):
        super(InstructionXor8Counter, self).__init__(address)

    def __repr__(self):
        return '{0:08X}: xor al,cl'.format(self.get_address())

    def do(self, v, counter):
        result = v ^ counter
        return result

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        return '{0} = ({0} ^ {1}) & 0xFF'.format(var_name, counter_name)


class InstructionMul8Counter(Instruction8):
    """`mul cl`: multiply by the loop counter; generated code keeps the low byte."""

    def __init__(self, address):
        super(InstructionMul8Counter, self).__init__(address)

    def __repr__(self):
        return '{0:08X}: mul cl'.format(self.get_address())

    def do(self, v, counter):
        product = v * counter
        return product

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        return '{0} = ({0} * {1}) & 0xFF'.format(var_name, counter_name)


class InstructionXor8Expression(Instruction8):
    """`xor al,bl`, where bl is produced by a recorded instruction sequence.

    `expr` is the list of instructions that compute the key byte. Replaying
    them from an initial value of 0 yields the byte XORed into the value.
    """

    # Scratch variable used in generated code. It must not collide with the
    # generated function's own names (v, counter, i, ary, offset, size) nor
    # with temporaries other gen_code() implementations may emit (e.g. `t`).
    _ACC_NAME = 'xor_key'

    def __init__(self, address, expr):
        Instruction.__init__(self, address=address)
        self.expr = expr

    def __repr__(self):
        s = ''
        s += '%08X: xor al,bl' % self.get_address()
        for ins in self.expr:
            s += '\n\t%s' % ins
        return s

    def do(self, v, counter):
        t = 0
        for ins in self.expr:  # type:Instruction
            t = ins.do(t, counter) & 0xFF
        return v ^ t

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        """Emit one line computing the key byte and XORing it into `var_name`.

        Some instructions emit '' (e.g. a rotation by a multiple of 8). Empty
        fragments are dropped, because joining them would produce `; ;`,
        which is a SyntaxError.
        """
        acc = self._ACC_NAME
        parts = ['%s = 0' % acc]
        for ins in self.expr:  # type:Instruction
            parts.append(ins.gen_code(acc, counter_name))
        parts.append('%s = %s ^ %s' % (var_name, var_name, acc))
        return '; '.join([p for p in parts if p])


class InstructionRor8(Instruction8):
    """`ror al,imm8`: rotate the byte right by a constant amount."""

    def __init__(self, address, imm):
        Instruction.__init__(self, address=address, imm=imm)

    def __repr__(self):
        return '%08X: ror al,%02X' % (self.get_address(), self.get_imm())

    def do(self, v, counter):
        return ror8(v, self.get_imm())

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        """Emit one statement rotating `var_name` right by (imm & 7).

        Returns '' when the effective rotation is 0. The shift amounts are
        inlined as literals rather than staged in a temporary such as `t`.
        A caller may itself pass var_name='t' (for example as an
        accumulator), and a shared temporary would overwrite the operand
        before it is rotated.
        """
        n = self.get_imm() & 7
        if n == 0:
            return ''
        return '%s = ((%s >> %d) | (%s << %d)) & 0xFF' % (var_name, var_name, n, var_name, 8 - n)


class InstructionRor8Counter(Instruction8):
    """`ror al,cl`: rotate the byte right by the loop counter."""

    def __init__(self, address):
        Instruction.__init__(self, address=address)

    def __repr__(self):
        return '%08X: ror al,cl' % self.get_address()

    def do(self, v, counter):
        return ror8(v, counter)

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        """Emit one statement rotating `var_name` right by (counter & 7).

        The rotation amount is written inline instead of through a temporary
        such as `t`, so the code stays correct even when var_name is 't'.
        A rotation of 0 is the identity, because v must already be a byte.
        """
        amount = '(%s & 7)' % counter_name
        return '%s = ((%s >> %s) | (%s << (8 - %s))) & 0xFF' % (
            var_name, var_name, amount, var_name, amount)


class InstructionRol8(Instruction8):
    """`rol al,imm8`: rotate the byte left by a constant amount."""

    def __init__(self, address, imm):
        Instruction.__init__(self, address=address, imm=imm)

    def __repr__(self):
        return '%08X: rol al,%02X' % (self.get_address(), self.get_imm())

    def do(self, v, counter):
        return rol8(v, self.get_imm())

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        """Emit one statement rotating `var_name` left by (imm & 7).

        Returns '' when the effective rotation is 0. The shift amounts are
        inlined as literals rather than staged in a temporary such as `t`.
        A caller may itself pass var_name='t' (for example as an
        accumulator), and a shared temporary would overwrite the operand
        before it is rotated.
        """
        n = self.get_imm() & 7
        if n == 0:
            return ''
        return '%s = ((%s << %d) | (%s >> %d)) & 0xFF' % (var_name, var_name, n, var_name, 8 - n)


class InstructionRol8Counter(Instruction8):
    """`rol al,cl`: rotate the byte left by the loop counter."""

    def __init__(self, address):
        Instruction.__init__(self, address=address)

    def __repr__(self):
        return '%08X: rol al,cl' % self.get_address()

    def do(self, v, counter):
        return rol8(v, counter)

    def gen_code(self, var_name, counter_name):  # type:(str, str) -> str
        """Emit one statement rotating `var_name` left by (counter & 7).

        The rotation amount is written inline instead of through a temporary
        such as `t`, so the code stays correct even when var_name is 't'.
        A rotation of 0 is the identity, because v must already be a byte.
        """
        amount = '(%s & 7)' % counter_name
        return '%s = ((%s << %s) | (%s >> (8 - %s))) & 0xFF' % (
            var_name, var_name, amount, var_name, amount)


def pe_get_code_partial(ary, va, size=0x800):
    """Return an immutable copy of up to `size` bytes of `ary` starting at VA `va`."""
    start = va_to_offset(va)
    return bytes(ary[start:start + size])


def _merge_adjacent_inst(prev, cur):
    """Return one instruction equivalent to `prev` followed by `cur`, or None.

    Only constant-operand add/sub, xor and rol/ror pairs are folded. The
    merged instruction keeps `prev`'s address. Immediates are truncated to a
    byte by Instruction8.get_imm(), so any over/underflow wraps mod 256.
    """
    addr = prev.get_address()
    a = prev.get_imm()
    b = cur.get_imm()
    if isinstance(prev, InstructionAdd8):
        if isinstance(cur, InstructionAdd8):
            return InstructionAdd8(addr, a + b)
        if isinstance(cur, InstructionSub8):
            return InstructionAdd8(addr, a - b)
    if isinstance(prev, InstructionSub8):
        if isinstance(cur, InstructionAdd8):
            return InstructionSub8(addr, a - b)
        if isinstance(cur, InstructionSub8):
            return InstructionSub8(addr, a + b)
    if isinstance(prev, InstructionXor8) and isinstance(cur, InstructionXor8):
        return InstructionXor8(addr, a ^ b)
    if isinstance(prev, InstructionRol8):
        if isinstance(cur, InstructionRol8):
            return InstructionRol8(addr, a + b)
        if isinstance(cur, InstructionRor8):
            return InstructionRol8(addr, a - b) if a > b else InstructionRor8(addr, b - a)
    if isinstance(prev, InstructionRor8):
        if isinstance(cur, InstructionRor8):
            return InstructionRor8(addr, a + b)
        if isinstance(cur, InstructionRol8):
            return InstructionRor8(addr, a - b) if a > b else InstructionRol8(addr, b - a)
    return None


def simplify_inst(ins_ary):
    """Fold runs of adjacent constant-operand instructions into single ones.

    Returns a new list (or `ins_ary` itself when it is empty). The input list
    is not modified.
    """
    if not ins_ary:
        return ins_ary
    folded = []
    for cur in ins_ary:
        merged = _merge_adjacent_inst(folded[-1], cur) if folded else None
        if merged is None:
            folded.append(cur)
        else:
            folded[-1] = merged
    return folded


def pe_decrypt_code_partial(smc_index, ary, va, size, ins_ary, counter_type=0):
    """Decrypt `size` bytes of `ary` at VA `va` in place by replaying `ins_ary`.

    The (simplified) instruction list is compiled into a Python function
    named `smc<smc_index>`. The function is executed in the module globals,
    so it stays defined there, and is then called once over the byte range.
    Doing this avoids per-instruction method calls inside the byte loop.

    counter_type selects the per-byte counter value for byte index i:
        0 -> (size - i) & 0xFF
        1 -> (i + 1) & 0xFF
        2 -> (size - 1 - i) & 0xFF
        anything else -> 0
    """
    indent = '    '
    name = 'smc%d' % smc_index
    counter_lines = {
        0: 'counter = (size - i) & 0xFF',
        1: 'counter = (i + 1) & 0xFF',
        2: 'counter = (size - 1 - i) & 0xFF',
    }

    loop_body = ['v = ary[offset + i]', counter_lines.get(counter_type, 'counter = 0')]
    loop_body.extend([ins.gen_code('v', 'counter') for ins in simplify_inst(ins_ary)])
    loop_body.append('ary[offset + i] = v')

    lines = ['def %s(ary, offset, size):' % name, indent + 'for i in range(size):']
    lines.extend([indent * 2 + stmt for stmt in loop_body])
    source = ''.join([line + '\n' for line in lines])

    exec(source, globals())
    globals()[name](ary, va_to_offset(va), size)
    return


# enc_va -> number of decrypt loops seen whose last op was `ror al,0xD8` (used by pe_smc_decrypt).
patch_enc_va_info = {}


def pe_smc_decrypt(smc_index, ary, enc_va, enc_size, initial_va, counter_type):
    """Decode one SMC stub starting at `initial_va` and decrypt its payload in place.

    The stub is disassembled with capstone (32-bit x86). Scanning begins right
    after the first `lodsb` found in the first 0x80 bytes of the stub and
    continues linearly up to `initial_va + 0x800`. Operations applied to
    al/eax are collected into `ins_ary`. The scan ends at the store
    `mov [esi+disp],al`, after which the collected operations are replayed
    over `enc_size` bytes at `enc_va` by pe_decrypt_code_partial.

    Args:
        smc_index: sequence number of this layer; used to name the generated routine.
        ary: bytearray holding the PE file image; modified in place.
        enc_va: virtual address of the encrypted payload.
        enc_size: number of payload bytes to decrypt.
        initial_va: virtual address where the stub scan begins.
        counter_type: how the per-byte counter is derived (see
            pe_decrypt_code_partial); overridden to 1 when the store uses
            displacement +1.

    Returns:
        The address just past the store instruction. This is also returned
        when the 0xD8 patch hack below fires. Returns -1 if no `lodsb` is
        found, if nothing was collected before the store, or if no store is
        found within the scan window.
    """
    va = initial_va
    ins_ary = []
    md = Cs(CS_ARCH_X86, CS_MODE_32)
    md.detail = True

    # search for decrypt start
    # (the per-byte transform begins immediately after the first lodsb)
    t = -1
    for i in md.disasm(pe_get_code_partial(ary, va, size=0x80), va):  # type:CsInsn
        if i.id == X86_INS_LODSB:
            # print('%08X: lodsb' % i.address)
            t = i.address + i.size
            break
    if t == -1:
        return -1
    va = t

    # begin(inclusive), end
    # (start, target) of every `jmp imm`; instructions whose address falls in
    # one of these ranges are skipped during the scan below
    jmp_ary = []

    # push ecx; mov cl,0xA7; xor al,cl; pop ecx
    # while ecx is saved, cl is treated as the constant ecx_imm rather than
    # the loop counter
    ecx_pushed = False
    ecx_popped = False
    ecx_imm = 0  # TODO need inst_ary? like ebx?

    # push eax; op eax; ...; mov ebx,eax; pop eax
    # while eax is saved, al/eax operations are collected into ebx_inst_ary;
    # that list is used as the key expression of a later `xor al,bl`
    # NOTE(review): none of these flags is ever reset, so only the first
    # push/pop window of each register is tracked within one stub.
    eax_pushed = False
    eax_popped = False
    ebx_inst_ary = []

    # NOTE(review): if capstone cannot decode the bytes at `va`, the inner
    # loop yields nothing, `va` never advances and this loop does not
    # terminate.
    while va < (initial_va + 0x800):
        for i in md.disasm(pe_get_code_partial(ary, va, size=0x80), va):  # type:CsInsn
            va = i.address + i.size
            # skip code in jmp area
            in_jmp_area = False
            for begin_address, end_address in jmp_ary:
                if begin_address <= i.address < end_address:
                    in_jmp_area = True
                    break
            if in_jmp_area:
                continue

            # print("%x:\t%s\t%s" % (i.address, i.mnemonic, i.op_str))

            # record `jmp imm` so the bytes it jumps over are ignored
            if len(i.operands) == 1:
                op0 = i.operands[0]  # type: X86Op
                if i.id == X86_INS_JMP and op0.type == X86_OP_IMM:
                    # print('jmp: %08X-%08X' % (i.address + i.size, op0.imm))
                    jmp_ary.append((i.address + i.size, op0.imm))
                    continue

            # pushad/popad are treated like push eax/pop eax
            if len(i.operands) == 0:
                if i.id == X86_INS_PUSHAL:
                    eax_pushed = True
                elif i.id == X86_INS_POPAL:
                    eax_popped = True

            # register save/restore; any push/pop of a register is consumed here
            if len(i.operands) == 1:
                op0 = i.operands[0]  # type: X86Op
                if i.id == X86_INS_PUSH and op0.type == X86_OP_REG:
                    if op0.reg == X86_REG_EAX:
                        eax_pushed = True
                    elif op0.reg == X86_REG_ECX:
                        ecx_pushed = True
                    continue
                if i.id == X86_INS_POP and op0.type == X86_OP_REG:
                    if op0.reg == X86_REG_EAX:
                        eax_popped = True
                    elif op0.reg == X86_REG_ECX:
                        ecx_popped = True
                    continue

            # inside the push eax ... pop eax window: build the bl key expression
            if eax_pushed and (not eax_popped):
                if len(i.operands) == 2:
                    op0 = i.operands[0]  # type: X86Op
                    op1 = i.operands[1]  # type:X86Op
                    if op0.type == X86_OP_REG and (op0.reg in (X86_REG_AL, X86_REG_EAX)):
                        # OP EAX/AL, IMM/REG
                        if op1.type == X86_OP_IMM:
                            if i.id == X86_INS_MOV:
                                ebx_inst_ary.append(InstructionMov8(i.address, op1.imm))
                            elif i.id == X86_INS_SUB:
                                ebx_inst_ary.append(InstructionSub8(i.address, op1.imm))
                            elif i.id == X86_INS_ADD:
                                ebx_inst_ary.append(InstructionAdd8(i.address, op1.imm))
                            elif i.id == X86_INS_XOR:
                                ebx_inst_ary.append(InstructionXor8(i.address, op1.imm))
                            elif i.id == X86_INS_ROL:
                                ebx_inst_ary.append(InstructionRol8(i.address, op1.imm))
                            elif i.id == X86_INS_ROR:
                                ebx_inst_ary.append(InstructionRor8(i.address, op1.imm))
                        # OP EAX/AL, ECX/CL
                        elif op1.type == X86_OP_REG and (op1.reg in (X86_REG_CL, X86_REG_ECX)):
                            if i.id == X86_INS_SUB:
                                ebx_inst_ary.append(InstructionSub8Counter(i.address))
                            elif i.id == X86_INS_ADD:
                                ebx_inst_ary.append(InstructionAdd8Counter(i.address))
                            elif i.id == X86_INS_XOR:
                                ebx_inst_ary.append(InstructionXor8Counter(i.address))
                            elif i.id == X86_INS_ROL:
                                ebx_inst_ary.append(InstructionRol8Counter(i.address))
                            elif i.id == X86_INS_ROR:
                                ebx_inst_ary.append(InstructionRor8Counter(i.address))
                elif len(i.operands) == 1:
                    op0 = i.operands[0]  # type: X86Op
                    if op0.type == X86_OP_REG and (op0.reg in (X86_REG_AL, X86_REG_EAX)):
                        if i.id == X86_INS_NEG:
                            ebx_inst_ary.append(InstructionNeg8(i.address))
                        elif i.id == X86_INS_NOT:
                            ebx_inst_ary.append(InstructionNot8(i.address))
                    if i.id == X86_INS_MUL and op0.type == X86_OP_REG and (op0.reg in (X86_REG_CL, X86_REG_ECX)):
                        ebx_inst_ary.append(InstructionMul8Counter(i.address))
                continue  # until we meet pop eax

            if len(i.operands) == 2:
                op0 = i.operands[0]  # type: X86Op
                op1 = i.operands[1]  # type:X86Op
                # mov [esi-1],al (CLD)
                # mov [esi+1],al (STD)
                if i.id == X86_INS_MOV and op0.type == X86_OP_MEM and op0.mem.base == X86_REG_ESI and op1.type == X86_OP_REG and op1.reg == X86_REG_AL:
                    # collect decrypt instruction finished
                    if len(ins_ary) == 0:
                        return -1
                    if op0.mem.disp == 1:
                        counter_type = 1
                    # Count, per enc_va, loops whose last op is `ror al,0xD8`.
                    # While that count equals 3, the first payload byte is
                    # overwritten with 0xE8 and no decryption is performed.
                    tmp_ins = ins_ary[-1]
                    if isinstance(tmp_ins, InstructionRor8) and tmp_ins.get_imm() == 0xD8:
                        if enc_va in patch_enc_va_info:
                            patch_enc_va_info[enc_va] += 1
                        else:
                            patch_enc_va_info[enc_va] = 1
                    if enc_va in patch_enc_va_info and patch_enc_va_info[enc_va] == 3:
                        ary[va_to_offset(enc_va)] = 0xE8  # ugly hack
                        return va
                    pe_decrypt_code_partial(smc_index, ary, enc_va, enc_size, ins_ary, counter_type=counter_type)
                    return va
                elif i.id == X86_INS_MOV and op0.type == X86_OP_REG and op0.reg == X86_REG_CL:
                    # mov cl, A7
                    # mov cl, ah
                    # NOTE(review): `mov cl,ah` is recorded as 0; ah's actual
                    # value is not tracked here -- confirm this is intended.
                    if not ecx_pushed:
                        continue
                    if op1.type == X86_OP_IMM:
                        ecx_imm = op1.imm
                    elif op1.type == X86_OP_REG and op1.reg == X86_REG_AH:
                        ecx_imm = 0
                elif op0.type == X86_OP_REG and (op0.reg in (X86_REG_AL, X86_REG_EAX)):
                    # OP EAX/AL, IMM/REG
                    if op1.type == X86_OP_IMM:
                        if i.id == X86_INS_SUB:
                            ins_ary.append(InstructionSub8(i.address, op1.imm))
                        elif i.id == X86_INS_ADD:
                            ins_ary.append(InstructionAdd8(i.address, op1.imm))
                        elif i.id == X86_INS_XOR:
                            ins_ary.append(InstructionXor8(i.address, op1.imm))
                        elif i.id == X86_INS_ROL:
                            ins_ary.append(InstructionRol8(i.address, op1.imm))
                        elif i.id == X86_INS_ROR:
                            ins_ary.append(InstructionRor8(i.address, op1.imm))
                    # OP EAX/AL, EBX/BL
                    elif op1.type == X86_OP_REG and (op1.reg in (X86_REG_BL,)):
                        if i.id == X86_INS_XOR:
                            ins_ary.append(InstructionXor8Expression(i.address, ebx_inst_ary))
                    # OP EAX/AL, ECX/CL
                    elif op1.type == X86_OP_REG and (op1.reg in (X86_REG_CL, X86_REG_ECX)):
                        # cl is a constant while ecx is saved, otherwise the loop counter
                        if ecx_pushed and (not ecx_popped):
                            if i.id == X86_INS_SUB:
                                ins_ary.append(InstructionSub8(i.address, ecx_imm))
                            elif i.id == X86_INS_ADD:
                                ins_ary.append(InstructionAdd8(i.address, ecx_imm))
                            elif i.id == X86_INS_XOR:
                                ins_ary.append(InstructionXor8(i.address, ecx_imm))
                            elif i.id == X86_INS_ROL:
                                ins_ary.append(InstructionRol8(i.address, ecx_imm))
                            elif i.id == X86_INS_ROR:
                                ins_ary.append(InstructionRor8(i.address, ecx_imm))
                        else:
                            if i.id == X86_INS_SUB:
                                ins_ary.append(InstructionSub8Counter(i.address))
                            elif i.id == X86_INS_ADD:
                                ins_ary.append(InstructionAdd8Counter(i.address))
                            elif i.id == X86_INS_XOR:
                                ins_ary.append(InstructionXor8Counter(i.address))
                            elif i.id == X86_INS_ROL:
                                ins_ary.append(InstructionRol8Counter(i.address))
                            elif i.id == X86_INS_ROR:
                                ins_ary.append(InstructionRor8Counter(i.address))
            elif len(i.operands) == 1:
                op0 = i.operands[0]  # type: X86Op
                if op0.type == X86_OP_REG and op0.reg in (X86_REG_AL, X86_REG_EAX):
                    if i.id == X86_INS_NEG:
                        ins_ary.append(InstructionNeg8(i.address))
                    elif i.id == X86_INS_NOT:
                        ins_ary.append(InstructionNot8(i.address))
    return -1


def get_current_time():
    """Return the local time, truncated to whole seconds, as 'YYYY-MM-DD HH:MM:SS'."""
    return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(time.time())))


def _short_jmp_target(buf, offset):
    """Return the offset in `buf` reached by the 2-byte `EB rel8` jmp at `offset`.

    rel8 is a signed displacement counted from the end of the instruction, so
    bytes 0x80-0xFF encode backward jumps. Returns -1 (an out-of-range offset)
    when the displacement byte lies past the end of `buf`.
    """
    disp = buf[offset + 1:offset + 2]
    if len(disp) != 1:
        return -1
    return offset + 2 + struct.unpack('<b', disp)[0]


def _has_byte_at(buf, offset, expected):
    """Return True if `buf` holds the one-byte string `expected` at `offset`.

    Offsets outside `buf` yield False instead of raising. This includes
    negative offsets, which plain indexing would silently wrap around.
    """
    return 0 <= offset < len(buf) and buf[offset:offset + 1] == expected


def pe_smc_decrypt_repeated(ary, va_start, va_end=0x506000):  # type:(bytearray, int, int) -> None
    """Walk successive SMC layers from `va_start`, decrypting each one in place.

    Each layer starts with a position-independent stub. From that stub the
    payload address (esi) and size (ecx) are recovered, and pe_smc_decrypt
    then decrypts the payload. Scanning resumes at the address
    pe_smc_decrypt returns. The walk stops at `va_end`, when no stub is
    found, or when a layer fails to decode.
    """
    # call $+5; pop esi; sub esi,imm1; add esi,imm2; mov ecx,imm3
    # call $+5; pop esi; sub esi,imm1; add esi,imm2; jmp xx; mov ecx,imm3
    pat = 'E8000000005E81EE******0081C6******00'
    pat_size = len(pat) // 2  # number of bytes the hex pattern covers
    smc_index = 0
    va = va_start
    while va < va_end:
        # it's repeated-smc, no point to get full code
        buf = pe_get_code_partial(ary, va)
        i = find_pattern(buf, pat)
        if i == -1:
            break
        jmp_offset = i + pat_size
        if _has_byte_at(buf, jmp_offset, b'\xEB'):  # jmp short $+XX
            mov_offset = _short_jmp_target(buf, jmp_offset)
        else:
            mov_offset = jmp_offset
        # mov ecx,imm3 (the high byte of imm3 must be 00)
        if not (_has_byte_at(buf, mov_offset, b'\xB9') and _has_byte_at(buf, mov_offset + 4, b'\x00')):
            va += jmp_offset
            continue
        ip = i + 5  # offset of `pop esi`, i.e. the address pushed by `call $+5`
        # esi = that address - imm1 + imm2
        enc_va = (va + ip) + unwrap_u32(buf, ip + 9) - unwrap_u32(buf, ip + 3)
        enc_size = unwrap_u32(buf, mov_offset + 1)

        # sub ecx,1 (possibly reached through a chain of short jmps)
        sub_ecx_offset = mov_offset + 5
        visited = set()  # guards against jmp cycles such as `jmp $`
        while sub_ecx_offset not in visited and _has_byte_at(buf, sub_ecx_offset, b'\xEB'):
            visited.add(sub_ecx_offset)
            sub_ecx_offset = _short_jmp_target(buf, sub_ecx_offset)
        if sub_ecx_offset >= 0 and buf[sub_ecx_offset:sub_ecx_offset + 3] == b'\x83\xE9\x01':
            counter_type = 2
        else:
            counter_type = 0

        # position independent code
        last_pic_va = va + i
        print('[%s] smc_index: %08d, pic_va: %08x, enc_va: %08x, enc_size: %08x' % (get_current_time(), smc_index, last_pic_va, enc_va, enc_size))

        new_va = pe_smc_decrypt(smc_index, ary, enc_va, enc_size, last_pic_va, counter_type)

        print('[%s] decrypt done, new_va: %08x' % (get_current_time(), new_va))
        if new_va == -1:
            break
        va = new_va
        smc_index += 1
    return


def test_decrypt():
    """Unpack every SMC layer of CrackMe.exe and save the result as CM_fix.exe."""
    ary = bytearray(load_file(DATA_DIR + 'CrackMe.exe'))
    # To process only a prefix of the layers, pass e.g. va_end=0x004016ac.
    pe_smc_decrypt_repeated(ary, 0x00401629)
    save_file(DATA_DIR + 'CM_fix.exe', bytes(ary))
    return


# Script entry point: runs at import time and reads/writes files under DATA_DIR.
test_decrypt()

2. 提取函数（从 CM_fix.exe 中导出函数代码到 code.dat、数据到 data.dat）
# Directory holding the input executable and the dumped output files.
DATA_DIR = r'D:\work\2019\pediy_Q4\6' '\\'


def hex2bin(s):
    """Decode a hex string such as 'E8C3' into raw bytes.

    Uses binascii, which works on both Python 2 and 3. `str.decode('hex')`
    only exists on Python 2.
    """
    import binascii
    return binascii.unhexlify(s)


def bin2hex(s):
    """Encode raw bytes as a lowercase hex `str` (e.g. b'\\xe8\\xc3' -> 'e8c3').

    Uses binascii, which works on both Python 2 and 3. `str.encode('hex')`
    only exists on Python 2.
    """
    import binascii
    h = binascii.hexlify(s)
    # Python 3 returns bytes; on Python 2 the result is already a str.
    return h if isinstance(h, str) else h.decode('ascii')


def load_file(filename):
    """Return the entire contents of `filename` read in binary mode.

    The `with` block guarantees the handle is closed even if the read raises.
    """
    with open(filename, 'rb') as f:
        return f.read()


def save_file(filename, s):
    """Write the bytes `s` to `filename` in binary mode, replacing any existing file.

    The `with` block guarantees the handle is closed even if the write raises.
    """
    with open(filename, 'wb') as f:
        f.write(s)
    return


def va_to_offset(va):
    """Map a virtual address to a file offset (VA 0x401000 -> offset 0x400)."""
    delta = 0x401000 - 0x400
    return va - delta


def offset_to_va(offset):
    """Map a file offset to a virtual address (offset 0x400 -> VA 0x401000)."""
    delta = 0x401000 - 0x400
    return offset + delta


def test_extract():
    """Dump a code fragment and a data table from CM_fix.exe.

    Writes the bytes from VA 0x004FBB14 up to 0x004FEF67, followed by a `ret`
    (0xC3), to code.dat. Writes 0x1B0 bytes starting at VA 0x004FF4AC to
    data.dat. All files live under DATA_DIR.
    """
    buf = load_file(DATA_DIR + 'CM_fix.exe')

    code_start = va_to_offset(0x004FBB14)
    code_end = va_to_offset(0x004FEF67)
    # Build the result as bytes directly. Appending bytes to a '' str
    # accumulator fails on Python 3.
    code = buf[code_start:code_end] + b'\xC3'  # terminate the fragment with `ret`
    save_file(DATA_DIR + 'code.dat', code)

    data_start = va_to_offset(0x004FF4AC)
    data = buf[data_start:data_start + 0x1B0]
    save_file(DATA_DIR + 'data.dat', data)
    return


# Script entry point: at import time, dumps code.dat and data.dat from CM_fix.exe under DATA_DIR.
test_extract()

# Directory holding the input executable and the dumped output files.
DATA_DIR = r'D:\work\2019\pediy_Q4\6' '\\'


def hex2bin(s):
    """Decode a hex string such as 'E8C3' into raw bytes.

    Uses binascii, which works on both Python 2 and 3. `str.decode('hex')`
    only exists on Python 2.
    """
    import binascii
    return binascii.unhexlify(s)


def bin2hex(s):
    """Encode raw bytes as a lowercase hex `str` (e.g. b'\\xe8\\xc3' -> 'e8c3').

    Uses binascii, which works on both Python 2 and 3. `str.encode('hex')`
    only exists on Python 2.
    """
    import binascii
    h = binascii.hexlify(s)
    # Python 3 returns bytes; on Python 2 the result is already a str.
    return h if isinstance(h, str) else h.decode('ascii')


def load_file(filename):
    """Return the entire contents of `filename` read in binary mode.

    The `with` block guarantees the handle is closed even if the read raises.
    """
    with open(filename, 'rb') as f:
        return f.read()


def save_file(filename, s):
    """Write the bytes `s` to `filename` in binary mode, replacing any existing file.

    The `with` block guarantees the handle is closed even if the write raises.
    """
    with open(filename, 'wb') as f:
        f.write(s)
    return


def va_to_offset(va):
    """Map a virtual address to a file offset (VA 0x401000 -> offset 0x400)."""
    delta = 0x401000 - 0x400
    return va - delta


def offset_to_va(offset):
    """Map a file offset to a virtual address (offset 0x400 -> VA 0x401000)."""
    delta = 0x401000 - 0x400
    return offset + delta


def test_extract():
    """Dump a code fragment and a data table from CM_fix.exe.

    Writes the bytes from VA 0x004FBB14 up to 0x004FEF67, followed by a `ret`
    (0xC3), to code.dat. Writes 0x1B0 bytes starting at VA 0x004FF4AC to
    data.dat. All files live under DATA_DIR.
    """
    buf = load_file(DATA_DIR + 'CM_fix.exe')

    code_start = va_to_offset(0x004FBB14)
    code_end = va_to_offset(0x004FEF67)
    # Build the result as bytes directly. Appending bytes to a '' str
    # accumulator fails on Python 3.
    code = buf[code_start:code_end] + b'\xC3'  # terminate the fragment with `ret`
    save_file(DATA_DIR + 'code.dat', code)

    data_start = va_to_offset(0x004FF4AC)
    data = buf[data_start:data_start + 0x1B0]
    save_file(DATA_DIR + 'data.dat', data)
    return


# Script entry point: at import time, dumps code.dat and data.dat from CM_fix.exe under DATA_DIR.
test_extract()

[招生]系统0day安全班,企业级设备固件漏洞挖掘,Linux平台漏洞挖掘!

最后于 2019-12-17 16:47 被风间仁编辑 ,原因: 修改了下解密exe的脚本, 消除大循环中的call, 显著缩短总运行时间
收藏
免费 1
支持
分享
最新回复 (0)
游客
登录 | 注册 方可回帖
返回
// // 统计代码