class instruction2vec(nn.Module):
    """Embed assembly instruction strings with a pretrained Word2Vec vocabulary.

    Every instruction becomes one 1-D tensor of ``3 * token_size`` elements,
    the concatenation of three ``token_size`` slots:

    1. the embedding of the opcode (first token),
    2. the mean embedding of the first operand's tokens,
    3. the mean embedding of the second operand's tokens.

    A missing operand is encoded as a zero vector.  When an instruction has
    three operands, the second and third operand means are averaged into the
    third slot.
    """

    # A token is either a run of word characters or one punctuation character
    # out of ``+ - * : [ ] ,``.  Raw string: the escapes belong to the regex,
    # not to the Python string literal.
    _TOKEN_PATTERN = re.compile(r"\w+|[\+\-\*\:\[\]\,]")

    def __init__(self, word2vec_model_path: str):
        """Load the Word2Vec model stored at ``word2vec_model_path``.

        Args:
            word2vec_model_path: path handed to ``Word2Vec.load``.
        """
        super().__init__()
        word2vec = Word2Vec.load(word2vec_model_path)
        # ``from_pretrained`` keeps the vectors frozen unless told otherwise.
        self.embedding = nn.Embedding.from_pretrained(
            torch.from_numpy(word2vec.wv.vectors)
        )
        self.token_size = word2vec.wv.vector_size
        # Copies, so the module does not keep the loaded model's own
        # containers alive or share them with other users of the model.
        self.key_to_index = word2vec.wv.key_to_index.copy()
        self.index_to_key = word2vec.wv.index_to_key.copy()

    def keylist_to_tensor(self, keyList: list):
        """Look up the embedding of every token in ``keyList``.

        Args:
            keyList: vocabulary tokens (strings).

        Returns:
            A 2-D tensor with one embedding row per token, in input order.
            An empty ``keyList`` yields a tensor with zero rows.

        Raises:
            KeyError: if a token is not part of the vocabulary.
        """
        try:
            indexList = [self.key_to_index[token] for token in keyList]
        except KeyError as err:
            raise KeyError(
                f"token {err.args[0]!r} is not in the Word2Vec vocabulary"
            ) from None
        # Build the indices on the same device as the weights so the lookup
        # keeps working after the module has been moved (e.g. ``.to(device)``).
        indices = torch.tensor(
            indexList, dtype=torch.long, device=self.embedding.weight.device
        )
        return self.embedding(indices)

    def _operand_vector(self, tokens: list, zero: torch.Tensor) -> torch.Tensor:
        """Return the mean embedding of ``tokens``, or ``zero`` if there are none."""
        if not tokens:
            # The mean over zero rows would be NaN; an empty operand is
            # encoded the same way as a missing one.
            return zero
        return self.keylist_to_tensor(tokens).mean(dim=0)

    def InsnStr2Tensor(self, insnStr: str) -> torch.Tensor:
        """Encode one instruction string as a ``3 * token_size`` vector.

        The string is first passed through ``RefineAsmCode`` (defined
        elsewhere in the project; presumably it normalises the text), then
        tokenised.  The first token is the opcode; the remaining tokens are
        split into operands at each ``,`` token.

        Args:
            insnStr: textual form of a single instruction.

        Returns:
            1-D tensor laid out as ``[opcode | operand 1 | operand 2]``.

        Raises:
            IndexError: if the refined string contains no tokens.
            KeyError: if a token is not part of the vocabulary.
            RuntimeError: if the instruction has more than three operands, or
                the assembled vector does not have ``3 * token_size`` elements.
        """
        insnStr = RefineAsmCode(insnStr)
        tokenList = self._TOKEN_PATTERN.findall(insnStr)
        if not tokenList:
            raise IndexError(f"no tokens found in instruction {insnStr!r}")

        opcode_tensor = self.keylist_to_tensor(tokenList[:1])[0]

        # Split the operand tokens at commas: n commas give n + 1 operands.
        operands = [[]]
        for token in tokenList[1:]:
            if token == ",":
                operands.append([])
            else:
                operands[-1].append(token)
        if len(operands) > 3:
            raise RuntimeError(
                f"unsupported instruction {insnStr!r}: "
                f"{len(operands)} operands found, at most 3 are supported"
            )

        # Zero filler with the dtype and device of the embeddings.
        zero = opcode_tensor.new_zeros(self.token_size)
        vectors = [self._operand_vector(tokens, zero) for tokens in operands]

        op1_tensor = vectors[0]
        if len(vectors) == 1:
            op2_tensor = zero
        elif len(vectors) == 2:
            op2_tensor = vectors[1]
        else:
            # Three operands: fold the third into the second slot.
            op2_tensor = (vectors[1] + vectors[2]) / 2

        insn_tensor = torch.cat((opcode_tensor, op1_tensor, op2_tensor), dim=0)

        # Sanity check on the fixed output layout.
        insn_size = insn_tensor.shape[0]
        expected_size = self.token_size * 3
        if expected_size != insn_size:
            raise RuntimeError(
                f"instruction vector has {insn_size} elements, "
                f"expected {expected_size} (3 * token_size)"
            )
        return insn_tensor

    def forward(self, insnStrList: list) -> torch.Tensor:
        """Encode a list of instruction strings.

        Args:
            insnStrList: instruction strings, one instruction each.

        Returns:
            Tensor with one ``3 * token_size`` row per instruction.  An empty
            input gives a tensor with zero rows.
        """
        insnTensorList = [self.InsnStr2Tensor(insnStr) for insnStr in insnStrList]
        if not insnTensorList:
            # ``torch.stack`` rejects an empty list; return an empty batch
            # with the regular row width instead.
            return self.embedding.weight.new_zeros((0, 3 * self.token_size))
        return torch.stack(insnTensorList)