# NOTE: the relative order of the star imports is kept as in the original,
# because later star imports shadow same-named symbols from earlier ones.
from queue import *

import ida_bytes
import ida_funcs
from idc import *
import idc
from keystone import *
from capstone import *
# Module-wide 32-bit x86 assembler (Keystone) and disassembler (Capstone)
# instances; asm() and disasm() in this file go through these two objects.
asmer = Ks(KS_ARCH_X86, KS_MODE_32)
disasmer = Cs(CS_ARCH_X86, CS_MODE_32)
def disasm(machine_code, addr=0):
    """Disassemble ``machine_code`` into assembly text.

    Args:
        machine_code: Raw instruction bytes handed to the module-level
            Capstone instance ``disasmer``.
        addr: Address the first byte is assumed to live at (default 0).

    Returns:
        One line per decoded instruction, formatted as the mnemonic padded
        to 8 characters, a space, the operand string and a trailing ``;``.
        Lines are separated by ``\\n`` with no trailing newline. An empty
        string is returned when nothing is decoded.
    """
    rendered = [
        "{:8s} {};".format(insn.mnemonic, insn.op_str)
        for insn in disasmer.disasm(machine_code, addr)
    ]
    return "\n".join(rendered).strip('\n')
def asm(asm_code, addr=0):
    """Assemble ``asm_code`` and return the machine code as ``bytes``.

    Args:
        asm_code: Assembly source text passed to the module-level Keystone
            instance ``asmer``.
        addr: Address the code is assembled for (default 0).

    Returns:
        The encoded instruction bytes. Only the first element of the value
        returned by ``asmer.asm`` (the encoding) is used.
    """
    encoding = asmer.asm(asm_code, addr)[0]
    # The encoding is iterated value by value, exactly as before; a
    # non-iterable encoding still fails with TypeError.
    return bytes(value for value in encoding)
def print_asm(ea):
    """Print the disassembly of the single item located at ``ea``."""
    item_size = idc.get_item_size(ea)
    raw_bytes = idc.get_bytes(ea, item_size)
    print(disasm(raw_bytes, ea))
class RelocDSU:
    """Disjoint-set style map from addresses to their representative address.

    ``reloc`` maps an address to its parent; an address that maps to itself
    is a root. ``merge`` re-points one root at another, which is how the rest
    of this script records "code at A now lives at B".
    """

    def __init__(self):
        # address -> parent address (roots map to themselves)
        self.reloc = {}

    def get(self, ea):
        """Resolve ``ea`` to its root address.

        On first sight of an address whose instruction is a ``jmp`` with a
        non-register operand, the jump is followed and ``ea`` is aliased to
        the root of the jump target. If the operand value lies in the
        ``.got.plt`` segment instead, ``ea`` becomes its own root.

        Args:
            ea: Address to resolve.

        Returns:
            ``(root, need_handle)``. ``need_handle`` is ``False`` for the
            ``.got.plt`` case; otherwise it is ``True`` exactly when the root
            lies in a segment named ``.text``. On the first resolution of a
            followed jump, the flag is the one reported for the jump target.
        """
        if ea not in self.reloc:
            is_direct_jmp = (
                idc.print_insn_mnem(ea) == 'jmp'
                and idc.get_operand_type(ea, 0) != idc.o_reg
            )
            if is_direct_jmp:
                jmp_ea = idc.get_operand_value(ea, 0)
                if idc.get_segm_name(jmp_ea) == '.got.plt':
                    self.reloc[ea] = ea
                    return self.reloc[ea], False
                # Provisional self-mapping: without it a self-jump or a cycle
                # of jumps would recurse forever, because ``ea`` would never
                # be found in ``self.reloc`` by the nested calls. It is
                # overwritten with the real root right below.
                self.reloc[ea] = ea
                self.reloc[ea], need_handle = self.get(jmp_ea)
                return self.reloc[ea], need_handle
            self.reloc[ea] = ea
        if self.reloc[ea] != ea:
            # Path compression: point ``ea`` straight at its root.
            self.reloc[ea] = self.get(self.reloc[ea])[0]
        return self.reloc[ea], idc.get_segm_name(self.reloc[ea]) == '.text'

    def merge(self, ea, reloc_ea):
        """Make the root of ``ea`` point at the root of ``reloc_ea``."""
        # Resolve the destination first, then the source, matching the
        # evaluation order of the original single-statement form.
        destination_root = self.get(reloc_ea)[0]
        source_root = self.get(ea)[0]
        self.reloc[source_root] = destination_root
# Single shared relocation map used by handle_one_branch() and solve().
reloc = RelocDSU()
class Block:
    """Plain record describing one parsed stub.

    Attributes:
        start_ea: Address of the first instruction of the stub.
        end_ea: Address stored as the end of the stub.
        imm: Immediate value associated with the stub.
        reg: Register operand text associated with the stub.
        call_target: Address the stub calls.
    """

    def __init__(self, start_ea, end_ea, imm, reg, call_target):
        self.start_ea = start_ea
        self.end_ea = end_ea
        self.imm = imm
        self.reg = reg
        self.call_target = call_target

    @staticmethod
    def _fmt(value):
        """Render integers in hex (addresses/immediates), anything else via repr."""
        return hex(value) if isinstance(value, int) else repr(value)

    def __repr__(self):
        # Debug aid only; not used by the relocation logic.
        fmt = self._fmt
        return "Block(start_ea=%s, end_ea=%s, imm=%s, reg=%s, call_target=%s)" % (
            fmt(self.start_ea),
            fmt(self.end_ea),
            fmt(self.imm),
            fmt(self.reg),
            fmt(self.call_target),
        )
def mov_code(ea, new_code_ea):
    """Re-assemble the item at ``ea`` as if it were located at ``new_code_ea``.

    The item's bytes are disassembled with ``ea`` as their address and the
    resulting text is assembled again for ``new_code_ea``.

    Returns:
        The re-assembled machine code as ``bytes``.
    """
    original_bytes = idc.get_bytes(ea, idc.get_item_size(ea))
    asm_text = disasm(original_bytes, ea)
    return asm(asm_text, new_code_ea)
def get_real_code(block, new_code_ea):
    """Extract the instructions selected by ``block`` inside its call target.

    Starting at ``block.call_target`` the code is scanned for a ``cmp`` whose
    first operand text equals ``block.reg`` and whose second operand value
    equals ``block.imm``. The three instructions after it must be ``jnz``,
    ``popa`` and ``popf``.

    Args:
        block: Parsed stub providing ``call_target``, ``reg`` and ``imm``.
        new_code_ea: Address the extracted code will be placed at; each
            instruction is re-assembled for its new position.

    Returns:
        ``(True, asm('ret'))`` when the instruction right after ``popf`` is
        ``pushf``. Otherwise ``(False, code)`` where ``code`` is the
        re-assembled instructions that follow, with ``jmp`` instructions
        followed rather than copied, up to a ``jmp`` whose target
        instruction is ``pushf``.

    Raises:
        AssertionError: If the ``jnz`` / ``popa`` / ``popf`` sequence is not
            found after the matching ``cmp``.
    """
    cursor = block.call_target

    # Find the matching `cmp` and step just past it.
    while True:
        is_selector = (
            idc.print_insn_mnem(cursor) == 'cmp'
            and idc.print_operand(cursor, 0) == block.reg
            and idc.get_operand_value(cursor, 1) == block.imm
        )
        cursor += idc.get_item_size(cursor)
        if is_selector:
            break

    # The selector must be followed by exactly this sequence.
    for expected_mnem in ('jnz', 'popa', 'popf'):
        assert idc.print_insn_mnem(cursor) == expected_mnem
        cursor += idc.get_item_size(cursor)

    if idc.print_insn_mnem(cursor) == 'pushf':
        return True, asm('ret')

    pieces = []
    while True:
        if idc.print_insn_mnem(cursor) != 'jmp':
            piece = mov_code(cursor, new_code_ea)
            pieces.append(piece)
            new_code_ea += len(piece)
            cursor += idc.get_item_size(cursor)
            continue
        jmp_target = idc.get_operand_value(cursor, 0)
        if idc.print_insn_mnem(jmp_target) == 'pushf':
            break
        cursor = jmp_target
    return False, b''.join(pieces)
def get_block(start_ea):
    """Parse the five-instruction stub that starts at ``start_ea``.

    The stub must consist of ``pushf``, ``pusha``, ``mov``, ``call`` and
    ``pop``, in that order. The ``mov`` supplies the register text
    (operand 0) and the immediate (operand 1); the ``call`` supplies the
    call target (operand 0).

    Args:
        start_ea: Address of the expected ``pushf``.

    Returns:
        A ``Block`` whose ``end_ea`` is the address right after the ``pop``.

    Raises:
        AssertionError: If the instruction sequence does not match. Callers
            in this file rely on this exception to detect "not a stub", so
            it is raised explicitly rather than via ``assert`` (which would
            disappear under ``python -O``).
    """
    expected_mnems = ('pushf', 'pusha', 'mov', 'call', 'pop')
    # Locals instead of module globals: nothing else reads these values, and
    # every one of them is assigned before the loop can finish successfully.
    imm = reg = call_target = None
    ea = start_ea
    for expected in expected_mnems:
        mnem = idc.print_insn_mnem(ea)
        if mnem != expected:
            raise AssertionError(
                'expected %r at 0x%x, found %r' % (expected, ea, mnem))
        if mnem == 'mov':
            imm = idc.get_operand_value(ea, 1)
            reg = idc.print_operand(ea, 0)
        elif mnem == 'call':
            call_target = idc.get_operand_value(ea, 0)
        ea += idc.get_item_size(ea)
    return Block(start_ea, ea, imm, reg, call_target)
def handle_one_branch(branch_address, new_code_ea):
    """Rebuild one straight-line branch starting at ``branch_address``.

    Each step first tries to parse a stub with ``get_block`` and replace it
    by the code ``get_real_code`` extracts for it. When that fails, the
    instruction at the current address is handled individually:

    * a ``call`` to one of the addresses in ``get_eip_func`` becomes
      ``mov <reg>, <ea + 5>`` (presumably these are get-EIP thunks and
      ``ea + 5`` is the return address of the 5-byte call; confirm against
      the target binary);
    * any other instruction is re-assembled for its new address.

    Every handled address is linked to its new location in the shared
    ``reloc`` map. A ``jmp`` with a non-register operand is followed instead
    of falling through.

    Args:
        branch_address: Address in the original code to start from.
        new_code_ea: Address at which the generated code will be placed.

    Returns:
        The generated machine code as ``bytes``. Generation stops after a
        ``ret`` is emitted, or at a followed ``jmp`` whose target does not
        need handling according to ``reloc.get``.
    """
    # Loop-invariant values, hoisted out of the loop.
    get_eip_func = {0x900: 'ebx', 0x435c: 'eax'}
    ret_code = asm('ret')

    new_code = b''
    ea = branch_address
    while True:
        try:
            block = get_block(ea)
            is_ret, real_code = get_real_code(block, new_code_ea)
            reloc.merge(ea, new_code_ea)
            ea = block.end_ea
            new_code_ea += len(real_code)
            new_code += real_code
            if is_ret:
                break
        except Exception:
            # Any failure above means "not a stub at ea". Catching Exception
            # rather than using a bare except keeps KeyboardInterrupt and
            # SystemExit working while the script runs.
            mnem = idc.print_insn_mnem(ea)
            is_direct_jmp = (
                mnem == 'jmp'
                and idc.get_operand_type(ea, 0) != idc.o_reg
            )
            if mnem == 'call' and idc.get_operand_value(ea, 0) in get_eip_func:
                reloc.merge(ea, new_code_ea)
                target_reg = get_eip_func[idc.get_operand_value(ea, 0)]
                real_code = asm(
                    'mov %s, 0x%x' % (target_reg, ea + 5), new_code_ea)
            else:
                if is_direct_jmp:
                    # For a followed jmp the NEW address is aliased onto the
                    # jump's resolved target, not the other way round.
                    reloc.merge(new_code_ea, ea)
                else:
                    reloc.merge(ea, new_code_ea)
                real_code = mov_code(ea, new_code_ea)
            new_code += real_code
            if real_code == ret_code:
                break
            new_code_ea += len(real_code)
            if is_direct_jmp:
                jmp_ea = idc.get_operand_value(ea, 0)
                resolved_ea, need_handle = reloc.get(jmp_ea)
                if not need_handle:
                    break
                ea = resolved_ea
            else:
                ea += idc.get_item_size(ea)
    return new_code
def _jump_table_target(jmp_table, index):
    """Decode entry ``index`` of a ``(entries_address, base)`` jump table.

    Each entry is a dword that is added to ``base`` and truncated to 32 bits.
    """
    entry = ida_bytes.get_dword(jmp_table[0] + index * 4)
    return (entry + jmp_table[1]) & 0xFFFFFFFF


def solve():
    """Rebuild the code reachable from the entry point in a fresh region.

    Steps, all performed on the currently open IDA database:

    1. Undefine and NOP-fill 0x10000 bytes at ``new_code_start``.
    2. Breadth-first over functions (via ``call`` targets) and, per function,
       over branches (via jump targets): generate each branch with
       ``handle_one_branch`` and patch it into the new region.
    3. Re-assemble every ``call`` / non-register jump in the new region so it
       points at the relocated target, and NOP out 9 bytes at each remaining
       ``pushf``.
    4. Write a rebased copy of the 0x20-entry jump table and patch the code
       that loads its base.

    All addresses are hard-coded for one specific binary.
    """
    entry_point = 0x48F4
    new_code_start = 0x96150
    new_code_ea = new_code_start
    jmp_table = (0x892ac, 0x8c000)
    new_jmp_table = (0xA6000 - 0x2D54, 0xA6000)
    region_size = 0x10000
    jmp_table_len = 0x20

    for offset in range(region_size):
        idc.del_items(new_code_ea + offset)
    ida_bytes.patch_bytes(new_code_ea, region_size * b'\x90')

    func_queue = Queue()
    func_queue.put(entry_point)
    while not func_queue.empty():
        func_address = func_queue.get()
        if not reloc.get(func_address)[1]:
            continue  # already relocated, or not something to handle
        reloc.merge(func_address, new_code_ea)
        branch_queue = Queue()
        branch_queue.put(func_address)
        if func_address == 0x4148:
            # This function dispatches through the jump table, so its table
            # targets are seeded as branches. The fixed address is relied on
            # by ``need_patch_addr`` further down.
            assert new_code_ea == 0x963d0
            for index in range(jmp_table_len):
                jmp_target = _jump_table_target(jmp_table, index)
                if reloc.get(jmp_target)[1]:
                    branch_queue.put(jmp_target)
        while not branch_queue.empty():
            branch_address = branch_queue.get()
            new_code = handle_one_branch(branch_address, new_code_ea)
            ida_bytes.patch_bytes(new_code_ea, new_code)
            chunk_end = new_code_ea + len(new_code)
            ea = new_code_ea
            while ea < chunk_end:
                idc.create_insn(ea)
                mnem = idc.print_insn_mnem(ea)
                if mnem == 'call':
                    call_target, need_handle = reloc.get(
                        idc.get_operand_value(ea, 0))
                    if need_handle:
                        func_queue.put(call_target)
                # startswith() instead of mnem[0]: an empty mnemonic (no
                # instruction could be created) must not raise IndexError.
                elif (mnem.startswith('j')
                        and idc.get_operand_type(ea, 0) != idc.o_reg):
                    jcc_target, need_handle = reloc.get(
                        idc.get_operand_value(ea, 0))
                    if need_handle:
                        branch_queue.put(jcc_target)
                ea += idc.get_item_size(ea)
            new_code_ea = chunk_end

    # Second pass: retarget calls and jumps now that every destination has
    # been relocated.
    ea = new_code_start
    while ea < new_code_ea:
        idc.create_insn(ea)
        mnem = idc.print_insn_mnem(ea)
        if mnem == 'call':
            call_target, need_handle = reloc.get(idc.get_operand_value(ea, 0))
            assert not need_handle
            ida_bytes.patch_bytes(ea, asm('call 0x%x' % (call_target), ea))
        elif mnem.startswith('j') and idc.get_operand_type(ea, 0) != idc.o_reg:
            jcc_target, need_handle = reloc.get(idc.get_operand_value(ea, 0))
            assert not need_handle
            # The new encoding may be shorter; pad with NOPs to the old size.
            patched = asm('%s 0x%x' % (mnem, jcc_target), ea)
            ida_bytes.patch_bytes(
                ea, patched.ljust(idc.get_item_size(ea), b'\x90'))
        elif mnem == 'pushf':
            # NOTE(review): 9 bytes are wiped starting at the pushf;
            # presumably a leftover stub prologue. Confirm the length.
            ida_bytes.patch_bytes(ea, b'\x90' * 9)
            ea += 9
            continue
        ea += idc.get_item_size(ea)

    # Rebased copy of the jump table, pointing at the relocated targets.
    for index in range(jmp_table_len):
        jmp_target = _jump_table_target(jmp_table, index)
        new_jmp_target, need_handle = reloc.get(jmp_target)
        assert not need_handle
        ida_bytes.patch_dword(
            new_jmp_table[0] + index * 4,
            (new_jmp_target - new_jmp_table[1]) & 0xFFFFFFFF)

    need_patch_addr = 0x963D7
    ida_bytes.patch_bytes(
        need_patch_addr,
        asm('call 0x900;add ebx, 0x%x'
            % (new_jmp_table[1] - (need_patch_addr + 5)), need_patch_addr))
    ida_bytes.patch_bytes(
        new_jmp_table[1] - 0x2d7a,
        ida_bytes.get_bytes(jmp_table[1] - 0x2d7a, 0x26))

    # NOTE(review): this starts at the END of the generated code
    # (new_code_ea), not at new_code_start; kept as in the original.
    for offset in range(region_size):
        idc.del_items(new_code_ea + offset)
    idc.jumpto(new_code_start)
    ida_funcs.add_func(new_code_start)
    print("finish")
# Run the whole relocation pass as soon as the script is executed.
solve()