最近为对某软件的指定模块打特征码补丁,研究学习了各种搜索算法(Sunday、Shift_And、BNDM等),上一篇《对SSE2模式匹配算法SSE2PatternFind的一点改造优化》中的算法,主要是利用SSE2指令找到特征码中不是通配符的第一个字节,再基于找到的第一个字节用常规指令搜索后面的字节序列,不足之处是除了第一个字节利用了SSE2大位宽的单指令多数据处理优势,而后面的字节搜索用不上SSE2指令集的优势,基于这个问题,这两天思考后又想到了一个更好的特征码匹配方式,尝试了一下,发现效果很好,写下来和大家分享!
上一篇文章《对SSE2模式匹配算法SSE2PatternFind的一点改造优化》链接: https://bbs.kanxue.com/thread-283252-1.htm
根据传入的std::string
类型的特征码字符串,检查myPattern
是否为空,并使用std::remove
和erase
去除特征码字符串中的所有空格字符;检查特征码中的每个字符是否是?
或十六进制字符,如果有非法的字符(既不是?
、也不是十六进制字符)则返回FALSE
;检查特征码字符串长度是否为偶数,如果不是则返回FALSE
;检查如果vecIdx
大小为0
,说明没有有效的特征码字节(如:所有的特征码字符都是通配符??
),返回FALSE
。
将传入的特征码字符串按每两个字符(一个字节)进行分割并依次处理:
① 判特征码字符是否含有?
或??
来处理传入的特征码字符串,即半字节中的?
通过 0xFF << 4
或0xFF >> 4
将左、右半字节的二进制位替换为1111
,双问号??
替换为0xFF
(即二进制位全为1),其它保持原特征码字节数据,得到vecPtn特征码字节序列;
② 根据特征码字符是否含有?
或??
,生成相应的二进制掩码vecMsk
,即半字节中的?
通过 0xFF << 4
或0xFF >> 4
将左、右半字节的二进制位替换为1111
,双问号??
替换为0xFF
(即二进制位全为1),其它全为0;
③ 记录不是??
(双问号)的特征码字节在原始特征码字节序列(传入的有??
的特征码)中的索引下标(vecIdx
)。
特征码字符串初始化代码实现如下:
① 首先,从要搜索的内存中取得第1组32字节(__m256i
)的数据curMemByte
,将取到的32字节数据curMemByte
中的每一个字节与上面初始化得到的不是??
的特征字节对应的二进制掩码vecMsk[0]
进行位或运算,以修正特征码字节序列中第1个特征字节vecPtn[0]
存在半字节时的情况(如果存在半字节,则对应的半字节?
处的二进制位替换为1),修正后得到curByteCorr
;将curByteCorr
中的每一个字节分别和vecPtn[0]
(即m256VecPtn.at(0)
中的一个字节)进行比较,即在curByteCorr
中查找所有可能的vecPtn[0]
,得到查找结果curCmp
;从curCmp
中提取出找到的vecPtn[0]
,得到得到vecPtn.at(0)
对应的curBit
;
② 再次,从内存取出第2组32字节(__m256i)的数据curMemByte
,步骤和前面第1组方式相同,这里需要说明:第2组32字节数据从内存地址0x03
处开始取,因为我们要查找的第2个特征字节vecPtn.at(1)
在原始特征码字符串中的索引下标是3,与vecPtn.at(0)
之间有一个??通配符位置需要留出来;通过和前面第1组数据的完全相同方式得到vecPtn.at(1)
对应的curBit
;
③ 将两次得到的curBit
进行位与(&)运算,找出vecPtn.at(0)
的curBit
(即prevCmpBit
)与vecPtn.at(1)
的curBit
相同索引位都为1的位置,这些位置就是内存中找的vecPtn前两个特征码字节。
这里的位与(&)运算结果curBit
不为0说明找到相应的特征码字节vecPtn.at(i)
,需要继续查找下一个vecPtn.at(i+1)
,直到vecPtn
的每个元素都遍历完时curBit
还不为0,说明找到了整个特征码序列,此时curBit
中一个为1的二进制位就是找到的一个特征码序列的位置标记(一个curBit
中可能找到多个特征码序列)。
当这里的任何一次位与(&)运算结果curBit
为0时,说明没有找到vecPtn
的对应元素,本次查找结束,当前的内存查找指针+32Byte,进入内存的下一个32字节搜索。
重复上述步骤,直到整个要搜索的内存区域遍历完成。
④ 上述①~③步遍历完成后,给定内存区域末尾还会剩下少于32字节的内存区域没有被搜索,需要判断要搜索的特征码字节序列长度是否小于内存区域末尾剩余的字节数,如果小于,则需要单独对末尾剩余的这部分内存进行特征码字节序列的搜索。
特征码匹配算法的具体流程如下:
基于AVX2指令集的特征码字节序列遍历匹配的算法代码实现如下:
对 VS Code 的主程序Code.exe的代码段".text"(大小131M)做如下搜索测试:
算法原理简单高效,代码易于实现、易于扩展;只搜索特征码中不是通配符的特征字节,优化搜索字节数,搜索速度快;算法主要利用位操作对特征码进行比对,充分利用了AVX2、SSE2指令集的大位宽、单指令多数据的优势;采用掩码的方式实现通配符(含半字节)特征码的搜索支持。
另:如果采用AVX512指令集(位宽512bit)一次可比对的字节数达到64Byte,且AVX2中的__mm256_cmpeq_epi8
和_mm256_movemask_epi8
指令在AVX512中可简化一条指令_mm512_cmpeq_epi8_mask
,理论上速度还会显著提升,机器支持AVX512指令集的可以修改代码后自行测试。
20240908修改: ---①修改文章说法和代码逻辑有差别的问题; ---②上传源码文件,源码中提供AVX2PatternFind256、SSE2PatternFind128、和判断当前机器支持指令集情况自动选择对应搜索函数; ---③代码增加进行内存搜索时不要搜到自己特征码字节序列的处理。 20240924修改: ---修复搜索的内存尺寸小于32Byte时奔溃的问题。
inline
BOOL
InitPattern(
const
std::string& myPattern, std::vector<
UCHAR
>& vecPtn, std::vector<
UCHAR
>& vecMsk, std::vector<
ULONG
>& vecIdx)
{
std::string patternText = myPattern;
if
(patternText.empty()) {
return
FALSE; }
patternText.erase(std::
remove
(patternText.begin(), patternText.end(),
' '
), patternText.end());
for
(
char
ch : patternText) {
if
(ch !=
'?'
&& !((ch >=
'0'
&& ch <=
'9'
) || (ch >=
'A'
&& ch <=
'F'
) || (ch >=
'a'
&& ch <=
'f'
))) {
return
FALSE; }
}
if
(patternText.length() % 2 != 0) {
return
FALSE; }
ULONG
len = patternText.length() / 2;
for
(
ULONG
i = 0; i < len; i++)
{
std::string tmpS = patternText.substr(i * 2, 2);
if
(
"??"
!= tmpS)
{
if
(
'?'
== tmpS.at(0))
{
tmpS.at(0) =
'F'
;
vecMsk.push_back(
UCHAR
(0xFF) << 4);
}
else
if
(
'?'
== tmpS.at(1))
{
tmpS.at(1) =
'F'
;
vecMsk.push_back(
UCHAR
(0xFF) >> 4);
}
else
{
vecMsk.push_back(
UCHAR
(0x00));
}
vecIdx.push_back(i);
}
if
(
"??"
== tmpS)
{
tmpS.at(0) =
'F'
;
tmpS.at(1) =
'F'
;
vecMsk.push_back(
UCHAR
(0xFF));
}
vecPtn.push_back(
strtoul
(tmpS.c_str(), nullptr, 16));
}
if
(0 == vecIdx.size()) {
return
FALSE; }
return
TRUE;
}
inline
BOOL
InitPattern(
const
std::string& myPattern, std::vector<
UCHAR
>& vecPtn, std::vector<
UCHAR
>& vecMsk, std::vector<
ULONG
>& vecIdx)
{
std::string patternText = myPattern;
if
(patternText.empty()) {
return
FALSE; }
patternText.erase(std::
remove
(patternText.begin(), patternText.end(),
' '
), patternText.end());
for
(
char
ch : patternText) {
if
(ch !=
'?'
&& !((ch >=
'0'
&& ch <=
'9'
) || (ch >=
'A'
&& ch <=
'F'
) || (ch >=
'a'
&& ch <=
'f'
))) {
return
FALSE; }
}
if
(patternText.length() % 2 != 0) {
return
FALSE; }
ULONG
len = patternText.length() / 2;
for
(
ULONG
i = 0; i < len; i++)
{
std::string tmpS = patternText.substr(i * 2, 2);
if
(
"??"
!= tmpS)
{
if
(
'?'
== tmpS.at(0))
{
tmpS.at(0) =
'F'
;
vecMsk.push_back(
UCHAR
(0xFF) << 4);
}
else
if
(
'?'
== tmpS.at(1))
{
tmpS.at(1) =
'F'
;
vecMsk.push_back(
UCHAR
(0xFF) >> 4);
}
else
{
vecMsk.push_back(
UCHAR
(0x00));
}
vecIdx.push_back(i);
}
if
(
"??"
== tmpS)
{
tmpS.at(0) =
'F'
;
tmpS.at(1) =
'F'
;
vecMsk.push_back(
UCHAR
(0xFF));
}
vecPtn.push_back(
strtoul
(tmpS.c_str(), nullptr, 16));
}
if
(0 == vecIdx.size()) {
return
FALSE; }
return
TRUE;
}
DLL_API
inline
ULONGLONG
BFPatternFind(
const
ULONGLONG
startAddr,
const
ULONGLONG
searchSize,
const
std::vector<
UCHAR
>& vecPtn,
const
std::vector<
UCHAR
>& vecMsk,
const
std::vector<
ULONG
>& vecIdx)
{
if
(searchSize < vecPtn.size()) {
return
0; }
PUCHAR
maxAddress = (
PUCHAR
)(startAddr + searchSize);
PUCHAR
currPattern = (
PUCHAR
)&vecPtn[0];
UCHAR
currEqual;
register
UCHAR
currPtnCh;
PUCHAR
currAddress = (
PUCHAR
)startAddr;
for
(
size_t
iCh = 0; iCh < vecIdx.size() && (
size_t
)currAddress <= (
size_t
)maxAddress; iCh++)
{
currPtnCh = currPattern[vecIdx[iCh]];
currPattern[vecIdx.at(iCh)] = currPtnCh + 0x1;
currEqual = ((currAddress[vecIdx[iCh]] | vecMsk.at(vecIdx[iCh])) ^ currPtnCh);
currPattern[vecIdx.at(iCh)] = currPtnCh;
if
(currEqual) {
return
0; }
if
(iCh + 1 == vecIdx.size())
{
return
(
ULONGLONG
)currAddress;
}
}
return
0;
}
DLL_API
BOOL
AVX2PatternFind256(std::vector<
ULONGLONG
>& retList,
const
ULONGLONG
searchStartAddr,
const
LONGLONG
searchSize,
const
std::string& myPattern,
const
LONGLONG
offsetSize,
const
ULONGLONG
searchNum)
{
if
(0 == searchStartAddr || 0 == searchSize) {
return
FALSE; }
ULONGLONG
realStartAddr = searchStartAddr;
if
((searchSize < 0) && (searchStartAddr > std::
abs
(searchSize)))
{
realStartAddr = searchStartAddr - std::
abs
(searchSize);
}
std::vector<
UCHAR
> vecPtn;
vecPtn.reserve(16);
std::vector<
UCHAR
> vecMsk;
vecMsk.reserve(16);
std::vector<
ULONG
> vecIdx;
vecIdx.reserve(8);
if
(!InitPattern(myPattern, vecPtn, vecMsk, vecIdx)) {
return
FALSE; }
std::vector<__m256i> m256VecPtn;
m256VecPtn.reserve(16);
std::vector<__m256i> m256VecMsk;
m256VecMsk.reserve(16);
for
(
size_t
k = 0; k < vecIdx.size(); k++)
{
m256VecPtn.push_back(_mm256_set1_epi8(vecPtn.at(vecIdx[k])));
m256VecMsk.push_back(_mm256_set1_epi8(vecMsk.at(vecIdx[k])));
}
UCHAR
bakVecPtnCh = vecPtn.at(vecIdx[0]);
vecPtn.at(vecIdx[0]) += 1;
retList.clear();
retList.reserve(16);
__m256i curMemByte, curCmp, curByteCorr;
register
size_t
curBit = 0;
PUCHAR
currMemAddr;
size_t
maxEndSize = min(std::
abs
(searchSize) - vecPtn.size(), std::
abs
(searchSize) - 32);
if
(std::
abs
(searchSize) < 32) {
goto
lessThan32Byte; }
for
(
size_t
i = vecIdx[0]; i <= maxEndSize; i += 32)
{
PUCHAR
baseMemAddr = (
PUCHAR
)(realStartAddr + i - vecIdx[0]);
size_t
prevCmpBit = 0xFFFFFFFF;
for
(
size_t
j = 0; j < vecIdx.size(); j++)
{
curMemByte = _mm256_loadu_si256((__m256i*)(baseMemAddr + vecIdx[j]));
curByteCorr = _mm256_or_si256(curMemByte, m256VecMsk.at(j));
curCmp = _mm256_cmpeq_epi8(m256VecPtn.at(j), curByteCorr);
curBit = _mm256_movemask_epi8(curCmp);
curBit = curBit & prevCmpBit;
if
(0 == curBit) {
break
; }
prevCmpBit = curBit;
if
(j + 1 == vecIdx.size())
{
ULONG
bitIdx = 0, n = 0;
while
(_BitScanForward(&bitIdx, curBit))
{
currMemAddr = baseMemAddr + n + bitIdx;
retList.push_back((
size_t
)(currMemAddr + offsetSize));
if
(searchNum != 0 && retList.size() >= searchNum) {
return
TRUE; }
++bitIdx;
curBit = curBit >> bitIdx;
n += bitIdx;
}
}
}
}
vecPtn.at(vecIdx[0]) = bakVecPtnCh;
lessThan32Byte:
if
(vecPtn.size() < 32)
{
ULONGLONG
tmpStarAddr = realStartAddr + maxEndSize + 1;
ULONGLONG
tmpSearchSize = std::
abs
(searchSize) - maxEndSize - 1;
for
(
int
i = 0; i <= tmpSearchSize - vecPtn.size(); i += vecPtn.size())
{
ULONGLONG
tailPtnAddr = BFPatternFind(tmpStarAddr + i, tmpSearchSize - i, vecPtn, vecMsk, vecIdx);
if
(tailPtnAddr)
{
retList.push_back(tailPtnAddr);
}
}
}
return
TRUE;
}
DLL_API
inline
ULONGLONG
BFPatternFind(
const
ULONGLONG
startAddr,
const
ULONGLONG
searchSize,
const
std::vector<
UCHAR
>& vecPtn,
const
std::vector<
UCHAR
>& vecMsk,
const
std::vector<
ULONG
>& vecIdx)
{
if
(searchSize < vecPtn.size()) {
return
0; }
PUCHAR
maxAddress = (
PUCHAR
)(startAddr + searchSize);
PUCHAR
currPattern = (
PUCHAR
)&vecPtn[0];
UCHAR
currEqual;
register
UCHAR
currPtnCh;
PUCHAR
currAddress = (
PUCHAR
)startAddr;
for
(
size_t
iCh = 0; iCh < vecIdx.size() && (
size_t
)currAddress <= (
size_t
)maxAddress; iCh++)
{
currPtnCh = currPattern[vecIdx[iCh]];
currPattern[vecIdx.at(iCh)] = currPtnCh + 0x1;
currEqual = ((currAddress[vecIdx[iCh]] | vecMsk.at(vecIdx[iCh])) ^ currPtnCh);
currPattern[vecIdx.at(iCh)] = currPtnCh;
if
(currEqual) {
return
0; }
if
(iCh + 1 == vecIdx.size())
{
return
(
ULONGLONG
)currAddress;
}
}
return
0;
}
[注意]传递专业知识、拓宽行业人脉——看雪讲师团队等你加入!
最后于 2024-9-24 16:53
被haogl编辑
,原因:
上传的附件: