import logging
import sys
from typing import Optional

import capstone
import keystone
import unicorn
# Root logging goes to stdout at DEBUG level; the module's own records are
# emitted through the "ollvm" logger.
logging.basicConfig(
    format="%(asctime)s %(levelname)7s %(name)s | %(message)s",
    level=logging.DEBUG,
    stream=sys.stdout,
)
logger = logging.getLogger("ollvm")
class BlockItem:
    """One block of consecutive instructions cut out of a disassembled function.

    Instances are created empty and filled in attribute by attribute while
    the function is being split into blocks.
    """

    # Address of the first instruction of the block.
    start_addr: Optional[int] = None
    # Address of the terminating (last) instruction of the block.
    end_addr: Optional[int] = None
    # Immediate taken from the terminating instruction; None when it has none.
    jump_addr: Optional[int] = None
    # Decoded instructions of the block, in the order they were disassembled.
    ins_list: Optional[list] = None
    # Flags recording whether the block contains a bl / ret / csel instruction.
    has_bl_ins: bool = False
    has_ret_ins: bool = False
    has_csel_ins: bool = False
    # Cache for the ``sign`` property; None until it has been computed.
    _sign: Optional[str] = None

    @property
    def sign(self) -> str:
        """Return the block's signature string.

        The signature is, for every instruction in order, its mnemonic
        followed by ``str(op.type)`` of each operand (for example
        ``'movz12movk12b2'``).  The value is cached once it has been computed
        from a non-empty instruction list; a block that has no instructions
        yet yields ``''`` without caching, so a later query sees the real
        signature.
        """
        if self._sign is None:
            if not self.ins_list:
                return ''
            self._sign = ''.join(
                ins.mnemonic + ''.join(str(op.type) for op in ins.operands)
                for ins in self.ins_list
            )
        return self._sign

    @property
    def csel_ins(self) -> Optional["capstone.CsInsn"]:
        """Return the first ``csel`` instruction of the block, or None.

        None is returned without scanning when ``has_csel_ins`` is false.
        """
        if not self.has_csel_ins:
            return None
        return next(
            (ins for ins in self.ins_list
             if ins.id == capstone.arm64.ARM64_INS_CSEL),
            None,
        )

    @property
    def last_ins(self) -> "capstone.CsInsn":
        """Return the last instruction of the block.

        Raises:
            IndexError: if the instruction list is empty.
        """
        return self.ins_list[-1]
class FuncContext:
    """Mutable per-function state shared by the driver and the emulator hooks.

    All fields are plain attributes that the caller assigns after
    construction; the class itself carries only the defaults.
    """

    # Start addresses of the blocks considered to be real code.
    real_blocks: Optional[list] = None
    # Address range of the function being processed.
    func_start_addr: int = 0
    func_end_addr: int = 0
    # Real-block addresses visited during the current emulation run.
    trace_list: Optional[list] = None
    # Maps a block start address to its BlockItem.
    block_map: Optional[dict] = None
    # Set by the code hook once a successor block has been found.
    is_success: bool = False
    # Address the current emulation run was started from.
    start_addr: int = 0
    # Successor found by the current run; reset to None before every run.
    dist_addr: Optional[int] = 0
    # 1 makes an emulated csel pick its first source register, anything else
    # the second one.
    branch_control: int = 1
def UR(regname):
    """Translate a register name into the matching unicorn ARM64 register id.

    The name is upper-cased and looked up as ``UC_ARM64_REG_<NAME>`` in
    ``unicorn.arm64_const``.  Returns None when no such constant exists.
    """
    const_name = f'UC_ARM64_REG_{regname.upper()}'
    return getattr(unicorn.arm64_const, const_name, None)
class FuckArm64Ollvm:
    """Undo control-flow flattening in one function of an ARM64 binary.

    The input file is treated as a flat image loaded at address 0, so file
    offsets double as addresses.  A function is split into blocks with
    capstone, the blocks that look like original code ("real" blocks) are
    selected, and each of them is emulated with unicorn until another real
    block is reached.  The successors found that way are written back as
    direct branches assembled with keystone, and every block that was never
    reached is overwritten with NOPs.
    """

    # Instruction groups that terminate a block.
    _END_GROUPS = [
        capstone.arm64.ARM64_GRP_RET,
        capstone.arm64.ARM64_GRP_BRANCH_RELATIVE,
        capstone.arm64.ARM64_GRP_JUMP,
    ]
    # Signatures (see BlockItem.sign) of blocks that are never kept as real
    # blocks even when they otherwise qualify.
    _FAKE_BLOCK_SINGS = [
        'movz12movk12b2',
        'cmp11mov11b.ne2',
        'movz12movk12cmp11b.eq2',
        'movz12movk12cmp11mov11b.ne2',
        'movz12movk12cmp11movz12movk12b.eq2',
    ]
    # An instruction whose mnemonic contains one of these strings is not
    # executed by the emulator; the program counter just steps over it.
    _MU_SKIP_INS = ['bl']

    # Emulator memory layout: a stack region plus the image mapped at 0.
    _STACK_BASE = 0x80000000
    _STACK_SIZE = 0x10000 * 8
    _STACK_INIT_OFFSET = 0x10000 * 6
    _IMAGE_MIN_SIZE = 4 * 1024 * 1024
    _PAGE_SIZE = 0x1000
    # "until" address handed to emu_start.
    _EMU_UNTIL = 0x10000
    # Upper bound on fault-and-resume attempts within one find_next_block call.
    _MAX_FAULT_RETRIES = 1000
    # Registers captured by _get_mu_context / restored by _set_mu_context.
    _CONTEXT_REG_NAMES = tuple('x{0}'.format(i) for i in range(31)) + ('sp',)

    def __init__(self, so_path: str):
        """Load the binary at ``so_path`` and create the capstone/keystone engines."""
        self.so_path = so_path
        self.so_bin = self._open_so(so_path)
        self.capstone = self._get_capstone()
        # Created lazily by get_real_block_flows.
        self.mu: Optional[unicorn.Uc] = None
        self.ks = self._get_ks()

    def _open_so(self, so_path: str) -> bytes:
        """Return the raw content of the file at ``so_path``."""
        with open(so_path, 'rb') as fp:
            return fp.read()

    def _mu_hook_code(self, mu: unicorn.Uc, address, size, user_data: FuncContext):
        """Per-instruction unicorn hook that drives the successor search.

        Stops the emulation when a successor was already found, when
        execution ran past the end of the function, when a real block is
        entered (recording it as the successor unless it is the start block),
        or on ``ret``.  Skips instructions listed in ``_MU_SKIP_INS`` and
        sp-relative memory accesses that fall outside the mapped stack, and
        resolves ``csel`` according to ``user_data.branch_control``.
        """
        if user_data.is_success:
            mu.emu_stop()
            return
        if address > user_data.func_end_addr:
            mu.emu_stop()
            return

        if address in user_data.real_blocks:
            if address in user_data.trace_list:
                # The same real block was entered twice in one run.
                logger.warning("This maybe a fake block. codesign:%s ",
                               user_data.block_map[address].sign)
                mu.emu_stop()
                return
            user_data.trace_list.append(address)
            if address != user_data.start_addr:
                user_data.is_success = True
                user_data.dist_addr = address
                logger.info("find dist addr: 0x%x", address)
                mu.emu_stop()
                return

        # The instruction is decoded from the file image, not from emulator
        # memory.
        code_bin = self.so_bin[address:address + size]
        try:
            ins = next(self.capstone.disasm(code_bin, address))
        except StopIteration:
            mu.emu_stop()
            return

        if ins.mnemonic == 'ret':
            # find_next_block does not retry when it sees pc == 0.
            mu.reg_write(unicorn.arm64_const.UC_ARM64_REG_PC, 0)
            mu.emu_stop()
            logger.info("ret ins..")
            return

        is_skip = any(name in ins.mnemonic for name in self._MU_SKIP_INS)
        if '[' in ins.op_str and 'sp' in ins.op_str:
            addr = self._get_ins_mem_addr(mu, ins)
            stack_end = self._STACK_BASE + self._STACK_SIZE
            # Outside the mapped stack means *either* below its base *or* at
            # or above its end; requiring both at once could never be true.
            if addr < self._STACK_BASE or addr >= stack_end:
                is_skip = True
        if is_skip:
            logger.info("will pass 0x%x:\t%s\t%s",
                        ins.address, ins.mnemonic, ins.op_str)
            mu.reg_write(unicorn.arm64_const.UC_ARM64_REG_PC, address + size)
            return

        if ins.mnemonic == 'csel':
            # Force the selection instead of evaluating the condition:
            # branch_control == 1 picks the first source register, anything
            # else the second one.
            reg_list = self._get_uc_reg_list(ins)
            reg1_val = mu.reg_read(reg_list[1])
            reg2_val = mu.reg_read(reg_list[2])
            chosen = reg1_val if user_data.branch_control == 1 else reg2_val
            mu.reg_write(reg_list[0], chosen)
            mu.reg_write(unicorn.arm64_const.UC_ARM64_REG_PC, address + size)

    def _get_uc_reg_list(self, ins: capstone.CsInsn):
        """Return the unicorn register ids of the register operands of ``ins``, in operand order."""
        return [
            UR(ins.reg_name(op.reg))
            for op in ins.operands
            if op.type == capstone.arm64.ARM64_OP_REG
        ]

    def _get_ins_mem_addr(self, mu, ins: capstone.CsInsn):
        """Return the address referenced by the memory operand(s) of ``ins``.

        For every memory operand the current values of the base and index
        registers and the displacement are added up.  Returns 0 when the
        instruction has no memory operand.
        """
        # NOTE(review): a shift/extend applied to the index register is not
        # taken into account - confirm whether that matters for callers.
        addr = 0
        for op in ins.operands:
            if op.type != capstone.arm64.ARM64_OP_MEM:
                continue
            mem = op.mem
            # The three components are independent; each one present
            # contributes to the address.
            if mem.base != 0:
                addr += mu.reg_read(UR(ins.reg_name(mem.base)))
            if mem.index != 0:
                addr += mu.reg_read(UR(ins.reg_name(mem.index)))
            addr += mem.disp
        return addr

    def _mu_hook_mem_unmapped(self, mu, type, address, size, value, user_data: FuncContext):
        """Unicorn hook for accesses to unmapped memory: log the access and return False."""
        pc = mu.reg_read(unicorn.arm64_const.UC_ARM64_REG_PC)
        logger.error('mem_unmapped pc:%x type:%d addr:%x size:%x',
                     pc, type, address, size)
        return False

    def _get_mu(self, user_data):
        """Create a fresh emulator with the image, a stack and both hooks installed.

        A previously created emulator (``self.mu``) is asked to stop first.
        The image is mapped at address 0; the mapping is at least
        ``_IMAGE_MIN_SIZE`` bytes and grows in whole pages when the file is
        larger than that.
        """
        if self.mu is not None:
            try:
                self.mu.emu_stop()
            except unicorn.UcError as e:
                logger.error(e)

        mu = unicorn.Uc(unicorn.UC_ARCH_ARM64, unicorn.UC_MODE_ARM)
        mu.mem_map(self._STACK_BASE, self._STACK_SIZE)

        pages = (len(self.so_bin) + self._PAGE_SIZE - 1) // self._PAGE_SIZE
        image_size = max(self._IMAGE_MIN_SIZE, pages * self._PAGE_SIZE)
        mu.mem_map(0, image_size)
        mu.mem_write(0, self.so_bin)

        mu.reg_write(unicorn.arm64_const.UC_ARM64_REG_SP,
                     self._STACK_BASE + self._STACK_INIT_OFFSET)
        mu.hook_add(unicorn.UC_HOOK_CODE, self._mu_hook_code,
                    user_data=user_data)
        mu.hook_add(unicorn.UC_HOOK_MEM_UNMAPPED, self._mu_hook_mem_unmapped,
                    user_data=user_data)
        return mu

    def _get_mu_context(self):
        """Return the current values of x0..x30 followed by sp (32 entries)."""
        return [self.mu.reg_read(UR(name)) for name in self._CONTEXT_REG_NAMES]

    def _set_mu_context(self, reg_vals):
        """Restore registers from a list produced by ``_get_mu_context``.

        Does nothing when ``reg_vals`` is None or does not have exactly one
        value per saved register.
        """
        if reg_vals is None or len(reg_vals) != len(self._CONTEXT_REG_NAMES):
            return
        for name, value in zip(self._CONTEXT_REG_NAMES, reg_vals):
            self.mu.reg_write(UR(name), value)

    def _get_capstone(self, detail=True):
        """Return an ARM64 capstone disassembler with ``detail`` set as requested."""
        md = capstone.Cs(capstone.CS_ARCH_ARM64, capstone.CS_MODE_ARM)
        md.detail = detail
        return md

    def _get_ks(self):
        """Return a little-endian ARM64 keystone assembler."""
        return keystone.Ks(keystone.KS_ARCH_ARM64, keystone.KS_MODE_LITTLE_ENDIAN)

    def _is_block_end(self, ins: capstone.CsInsn):
        """Return True when ``ins`` belongs to one of the ``_END_GROUPS``."""
        ins_groups = ins.groups
        if not ins_groups:
            return False
        return any(group in ins_groups for group in self._END_GROUPS)

    def _is_bl(self, ins: capstone.CsInsn):
        """Return True when ``ins`` is a ``bl`` instruction."""
        return ins.id == capstone.arm64.ARM64_INS_BL

    def _is_ret(self, ins: capstone.CsInsn):
        """Return True when ``ins`` is a ``ret`` instruction."""
        return ins.id == capstone.arm64.ARM64_INS_RET

    def _is_csel(self, ins: capstone.CsInsn):
        """Return True when ``ins`` is a ``csel`` instruction."""
        return ins.id == capstone.arm64.ARM64_INS_CSEL

    def _get_ins_imm(self, ins: capstone.CsInsn):
        """Return the first immediate operand of ``ins``, or None if it has none."""
        for opr in ins.operands:
            if opr.type == capstone.arm64.ARM64_OP_IMM:
                return opr.imm
        return None

    def _get_branch_target(self, ins: capstone.CsInsn):
        """Return the last immediate operand of ``ins``, or None if it has none.

        Used for block terminators: a test-bit branch (tbz/tbnz) carries the
        bit number as an earlier immediate and the destination as the last
        one, so the first immediate is not always the target.
        """
        target = None
        for opr in ins.operands:
            if opr.type == capstone.arm64.ARM64_OP_IMM:
                target = opr.imm
        return target

    def _is_fake_block(self, block_item: BlockItem):
        """Return True when the block's signature is listed in ``_FAKE_BLOCK_SINGS``."""
        return block_item.sign in self._FAKE_BLOCK_SINGS

    def get_func_blocks(self, func_start, func_end):
        """Split ``[func_start, func_end)`` into blocks.

        A block ends at every instruction for which ``_is_block_end`` is true
        and which is not a ``bl``.  Instructions after the last such
        terminator are not returned as a block.

        Returns:
            ``(block_map, processors)``: ``block_map`` maps a block start
            address to its BlockItem; ``processors`` lists every jump target
            that more than one block ends with, excluding targets of
            instructions that jump to themselves.
        """
        func_bin = self.so_bin[func_start:func_end]
        block_map = {}
        # NOTE(review): terminators without an immediate (e.g. ret) are all
        # counted under the key None - confirm this is intended.
        jump_counts = {}
        dead_loops = []
        block_item = None

        for ins in self.capstone.disasm(func_bin, func_start):
            if block_item is None:
                block_item = BlockItem()
                block_item.start_addr = ins.address
                block_item.ins_list = []
            block_item.ins_list.append(ins)

            is_bl = self._is_bl(ins)
            if is_bl:
                block_item.has_bl_ins = True
            if self._is_ret(ins):
                block_item.has_ret_ins = True
            if self._is_csel(ins):
                block_item.has_csel_ins = True

            if self._is_block_end(ins) and not is_bl:
                jump_addr = self._get_branch_target(ins)
                block_item.end_addr = ins.address
                block_item.jump_addr = jump_addr
                jump_counts[jump_addr] = jump_counts.get(jump_addr, 0) + 1
                if jump_addr == ins.address:
                    dead_loops.append(jump_addr)
                block_map[block_item.start_addr] = block_item
                block_item = None

        for addr in dead_loops:
            jump_counts.pop(addr, None)
        processors = [addr for addr, num in jump_counts.items() if num > 1]
        return block_map, processors

    def get_real_blocks(self, block_map: dict, processor: list):
        """Return the start addresses of the blocks treated as real code.

        A block qualifies when it contains a ``bl``, ends with a jump to one
        of the ``processor`` addresses, or contains a ``ret`` - unless its
        signature is one of the known fake ones.  The result follows the
        iteration order of ``block_map``.
        """
        dispatchers = set(processor)
        return [
            start_addr
            for start_addr, block_item in block_map.items()
            if (block_item.has_bl_ins
                or block_item.jump_addr in dispatchers
                or block_item.has_ret_ins)
            and not self._is_fake_block(block_item)
        ]

    def find_next_block(self, func_context: FuncContext, start_addr, branch_control=1):
        """Emulate from ``start_addr`` and return the next real block reached.

        ``branch_control`` selects which source register an emulated ``csel``
        takes (see ``_mu_hook_code``).  When the emulator raises, execution is
        resumed 4 bytes after the faulting pc with a freshly reset context,
        unless pc is 0, in which case the error is logged.  Resuming stops
        once the resume address lies beyond the end of the function (the code
        hook would stop there immediately) or after ``_MAX_FAULT_RETRIES``
        attempts.

        Returns:
            The address of the real block that was reached, or None.
        """
        for _ in range(self._MAX_FAULT_RETRIES + 1):
            func_context.branch_control = branch_control
            func_context.trace_list = []
            func_context.is_success = False
            func_context.dist_addr = None
            func_context.start_addr = start_addr
            try:
                self.mu.emu_start(start_addr, self._EMU_UNTIL)
            except unicorn.UcError as e:
                pc = self.mu.reg_read(unicorn.arm64_const.UC_ARM64_REG_PC)
                if pc == 0:
                    logger.error("find_next_block, pc: 0x%x, err: %s", pc, e)
                    return func_context.dist_addr
                start_addr = pc + 4
                if start_addr > func_context.func_end_addr:
                    # Nothing left to search: the hook stops every run that
                    # starts past the end of the function.
                    return None
                continue
            return func_context.dist_addr

        logger.error("find_next_block, too many emulation faults, giving up "
                     "at 0x%x", start_addr)
        return None

    def get_real_block_flows(self, func_context: FuncContext):
        """Discover the successors of every reachable real block.

        Starting at the function entry, each block is emulated with
        ``branch_control=0``; blocks containing a ``csel`` are emulated a
        second time from the same register state with ``branch_control=1``.
        Newly found successors are explored in turn, each from the register
        state captured when it was reached.

        Side effects: replaces ``self.mu`` with a new emulator and removes
        the function start address from ``func_context.real_blocks``.

        Returns:
            A dict mapping a block address to the list of its successors:
            ``[]``, one address, or ``[csel-second-source successor,
            csel-first-source successor]``.
        """
        self.mu = self._get_mu(func_context)
        if func_context.func_start_addr in func_context.real_blocks:
            func_context.real_blocks.remove(func_context.func_start_addr)

        queue = [(func_context.func_start_addr, None)]
        flows = {}
        while queue:
            pc, reg_vals = queue.pop()
            self._set_mu_context(reg_vals)
            if pc in flows:
                continue
            flows[pc] = []

            # Saved so the second csel outcome starts from the same registers.
            ctx = self._get_mu_context()
            pc1 = self.find_next_block(func_context, pc, 0)
            if pc1 is not None:
                queue.append((pc1, self._get_mu_context()))
                flows[pc].append(pc1)

            block_item = func_context.block_map.get(pc)
            if block_item is not None and block_item.has_csel_ins:
                self._set_mu_context(ctx)
                pc2 = self.find_next_block(func_context, pc, 1)
                if pc2 is not None and pc1 != pc2:
                    queue.append((pc2, self._get_mu_context()))
                    flows[pc].append(pc2)
        return flows

    def _generate_patch(self, origin_addr, branchs, condition=None) -> bytes:
        """Assemble the branch code to be placed at ``origin_addr``.

        Without ``condition`` a single ``b`` to ``branchs[0]`` is produced.
        With ``condition`` two instructions are produced: a conditional
        branch to ``branchs[0]`` followed by an unconditional branch to
        ``branchs[1]``.
        """
        if condition is not None:
            source = "b%s #0x%x;b #0x%x" % (condition, branchs[0], branchs[1])
        else:
            source = "b #0x%x" % branchs[0]
        codes, _ = self.ks.asm(source, origin_addr, as_bytes=True)
        return codes

    def _generate_nop_patch(self, num) -> bytes:
        """Return the encoding of ``num`` consecutive ``nop`` instructions."""
        codes, _ = self.ks.asm('nop', as_bytes=True)
        return codes * num

    def patch_func(self, func_context: FuncContext, flows):
        """Rewrite ``self.so_bin`` according to ``flows``.

        * A block with two successors gets a conditional branch pair written
          at the address of its ``csel``; the condition is the last two
          characters of the csel operand string.
        * A block with one successor gets its last instruction replaced by a
          direct branch.
        * Every block of ``func_context.block_map`` that is not a key of
          ``flows`` is overwritten with NOPs.

        Raises:
            IndexError: if a patch would reach past the end of the image.
        """
        patch_list = []
        for start_addr, branchs in flows.items():
            if not branchs:
                continue
            block_item = func_context.block_map.get(start_addr)
            if block_item is None:
                # No block starts here, so there is nothing to patch.
                continue

            if len(branchs) == 2:
                csel_ins = block_item.csel_ins
                if csel_ins is None:
                    continue
                patch_addr = csel_ins.address
                condition = csel_ins.op_str[-2:]
                # flows stores [successor for the second csel source,
                # successor for the first csel source].  csel selects its
                # first source when the condition holds, so the conditional
                # branch must target the second entry.
                false_target, true_target = branchs
                shellcode = self._generate_patch(
                    patch_addr, [true_target, false_target],
                    condition=condition)
            elif len(branchs) == 1:
                patch_addr = block_item.last_ins.address
                shellcode = self._generate_patch(patch_addr, branchs)
            else:
                continue

            if shellcode is not None:
                patch_list.append((patch_addr, shellcode))

        for start_addr, block_item in func_context.block_map.items():
            if start_addr in flows:
                continue
            # end_addr is the address of the last instruction, hence the +4.
            block_size = block_item.end_addr - block_item.start_addr + 4
            patch_list.append(
                (start_addr, self._generate_nop_patch(block_size // 4)))

        tmp_so_bin = bytearray(self.so_bin)
        for patch_addr, shellcode in patch_list:
            patch_end = patch_addr + len(shellcode)
            if patch_end > len(tmp_so_bin):
                # Slice assignment would silently grow the image.
                raise IndexError(
                    "patch at 0x%x exceeds the image size" % patch_addr)
            tmp_so_bin[patch_addr:patch_end] = shellcode
        self.so_bin = bytes(tmp_so_bin)

    def proccess_func(self, start_addr, end_addr):
        """Deobfuscate the function occupying ``[start_addr, end_addr)`` in place.

        Runs block discovery, real-block selection, flow discovery and
        patching; the result is left in ``self.so_bin``.
        """
        block_map, processors = self.get_func_blocks(start_addr, end_addr)

        func_context = FuncContext()
        func_context.func_start_addr = start_addr
        func_context.func_end_addr = end_addr
        func_context.block_map = block_map
        func_context.real_blocks = self.get_real_blocks(block_map, processors)
        func_context.trace_list = []

        flows = self.get_real_block_flows(func_context)
        self.patch_func(func_context, flows)

    def dump(self, save_path):
        """Write the (possibly patched) image to ``save_path``."""
        with open(save_path, 'wb') as fp:
            fp.write(self.so_bin)
if __name__ == '__main__':
    # Script entry point: load the library, deobfuscate the function in the
    # address range 0x70438..0x7170C, and save the patched image under a new
    # name.
    deobfuscator = FuckArm64Ollvm('libvdog.so')
    deobfuscator.proccess_func(0x70438, 0x7170C)
    deobfuscator.dump('libvdog_new.so')