C标准库中的短模式匹配算法

我要在大数组中匹配特定的连续数字序列,数组元素是16位整型数据,绝大部分序列的长度为2或3。很显然,这和字符串匹配算法很类似,只是元素类型不同。

字符串匹配算法有很多种,它们各有优势,适用于不同的场景,我需要匹配短模式(pattern)效率高的算法。我很好奇,通用的字符串匹配函数是如何实现的,于是阅读了C标准库字符串匹配函数strstr源代码。函数原型为:

1
char *STRSTR(const char *haystack, const char *needle)

待搜寻的字符串用haystack表示,意思是「干草堆」,模式字符串用needle表示,意思是「针」。所以字符串匹配是「在干草堆里找针」,和成语「大海捞针」的意思差不多。

代码注释详细解释了实现方案。匹配长度小于3的字符串,用专门编写的算法;长度大于3不超过256,用修改过的Horspool算法;长度大于256,则用双向(Two-Way)字符串匹配算法。由此可知,strstr函数为保证效率,对不同长度的模式使用不同的算法,这种做法在标准库中很常见。strstr函数的时间复杂度是线性的。

匹配长度为2和3的模式,分别调用strstr2strstr3函数,这正是我需要的。以strstr2为例,为了看清它的原理,我添加了相关代码,打印重要变量的内容。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include <stdio.h>
#include <string.h>
#include <stdint.h>

typedef unsigned char *byte_pointer;

void show_bytes(byte_pointer start, int len)
{
for (int i = 0; i < len; i++) {
printf(" %.2x", start[i]);
}
printf("\n");
}

char *strstr2(const unsigned char *hs, const unsigned char *ne)
{
printf("hs: ");
show_bytes((byte_pointer)hs, strlen(hs));
printf("ne: ");
show_bytes((byte_pointer)ne, strlen(ne));

uint32_t h1 = (ne[0] << 16) | ne[1];
printf("h1: ");
show_bytes((byte_pointer)&h1, sizeof(h1));
printf("\n");

uint32_t h2 = 0;
for (int c = hs[0]; h1 != h2 && c != 0; c = *++hs) {
printf("h2: ");
show_bytes((byte_pointer)&h2, sizeof(h2));
h2 = (h2 << 16) | c;
}
printf("h2: ");
show_bytes((byte_pointer)&h2, sizeof(h2));

return h1 == h2 ? (char *)hs - 2 : NULL;
}

int main()
{
char haystack[] = "0123456";
char needle[] = "23";
char *match = strstr2(haystack, needle);
//printf("%s\n", match);
}

运行结果如下:

1
2
3
4
5
6
7
8
9
10
hs:  30 31 32 33 34 35 36 37 38 39
ne: 33 34
h1: 34 00 33 00

h2: 00 00 00 00
h2: 30 00 00 00
h2: 31 00 30 00
h2: 32 00 31 00
h2: 33 00 32 00
h2: 34 00 33 00

通过移位和逻辑与操作,将ne中的2个字符分别放到h1的高16位和低16位。在for循环中,从前往后每次从hs中取出2个字符,分别放到h2中的高16位和低16位,然后与h1比较,若h1h2相等,则匹配成功。strstr3函数的实现方法也是类似的。由于ASCII字符只有8位,所以h1h2其实只需要16位就够了,两个字符分别放到高8位与低8位,所以要将代码中的左移16位改成8位。之所以用32位的uint32_t应该是考虑到4字节对齐的地址访问速度更快。

现在匹配的不是ASCII字符,而是16位整型数据,h1h2就必须是32位的,高16位和低16位各放一个数据。另外,待匹配的数据不是以\0结尾的字符串,而是存放在数组中,末尾没有特定的结束符,这时第28行for循环的终止条件c != 0就不适用了。对于strstr3函数,h1h2要存放3个16位的字符,uint32_t类型放不下,所以要修改成uint64_t类型。考虑上述情况,得到如下修改后的函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#include <stdio.h>
#include <stdint.h>

int NaiveSearch2(const uint16_t *hs, int hs_len, const uint16_t *ne, int ne_len)
{
uint32_t h1 = (ne[0] << 16) | ne[1];
uint32_t h2 = 0;
int idx;
for (idx = 0; h1 != h2 && idx < hs_len; idx++) {
h2 = (h2 << 16) | hs[idx];
}
return h1 == h2 ? idx - 2 : -1;
}

int NaiveSearch3(const uint16_t *hs, int hs_len, const uint16_t *ne, int ne_len)
{
uint64_t h1 = ((uint64_t)ne[0] << 48) | ((uint64_t)ne[1] << 32) | (ne[2] << 16);
uint64_t h2 = 0;
int idx;
for (idx = 0; h1 != h2 && idx < hs_len; idx++) {
h2 = (h2 | hs[idx]) << 16;
}
return h1 == h2 ? idx - 3 : -1;
}

int main()
{
uint16_t hs[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
uint16_t ne[] = {7, 8}; //长度为2的序列
int pos = NaiveSearch2(hs, 10, ne, 2);
printf("%d\n", pos);
uint16_t ne[] = {3, 4, 5}; //长度为3的序列
pos = NaiveSearch3(hs, 10, ne, 3);
printf("%d\n", pos);
}

大功告成!