AC自动机 多模式字符串匹配(简单版)
题目
给定n个模式串和1个文本串,求有多少个模式串在文本串里出现过。
注意:是出现过,就是出现多次只算一次。
输入格式
第一行是一个整数,表示模式串的个数 n。
第 2 到第 (n+1) 行,每行一个字符串,第 (i+1) 行的字符串表示编号为 i 的模式串 si。
最后一行是一个字符串,表示文本串 t。
输出格式
输出一行一个整数表示答案。
题目灰常简单,AC主要作用就是优化的求这个问题的
组成部分
Trie 树节点结构:
struct K{int son[26],tag,fail;
}tr[maxn];
son[26]
:存储子节点指针(每个位置对应一个字母)。tag
:标记该节点是否为单词结尾,记录单词出现次数。fail
:失败指针,指向当前节点匹配失败后应跳转的节点。
1. 插入模式串 in(char* s)
- 功能:将模式串插入 Trie 树。
- 流程:
- 从根节点(初始为 1)开始遍历字符串。
- 若当前字符对应的子节点不存在,则创建新节点。
- 遍历结束后,标记单词结尾节点的
tag++
。
void in(char* s){int u=1,len=strlen(s); // 从根节点(1)开始for(int i=0;i<len;i++){int v=s[i]-'a'; // 字符转索引(0-25)if(!tr[u].son[v]) // 若子节点不存在tr[u].son[v]=++cnt; // 创建新节点u=tr[u].son[v]; // 移动到子节点}tr[u].tag++; // 标记单词结尾,记录出现次数
}
cnt初始为1,用来记录当前节点编号;
这个树顺便把同样的模式串合并了
2. 构建失败指针 getF()
- 功能:利用 BFS 构建每个节点的失败指针。
- 流程:
- 初始化根节点(1)的失败指针为 0(虚拟根),并将其所有子节点指向根。简化边界处理。
- BFS 遍历每个节点:
- 若子节点存在:其失败指针指向父节点失败指针的对应子节点。
- 若子节点不存在:直接将其指向父节点失败指针的对应子节点(路径压缩)。
void getF(){for(int i=0;i<26;i++) tr[0].son[i]=1; // 虚拟根节点的子节点全指向根(1)q.push(1); tr[1].fail=0; // 根节点的失败指针为虚拟根(0)while(!q.empty()){int u=q.front(); q.pop();for(int i=0;i<26;i++){ // 遍历所有可能的子节点int v=tr[u].son[i]; // 当前子节点int F=tr[u].fail; // 当前节点的失败指针if(!v){ // 若子节点不存在tr[u].son[i]=tr[F].son[i]; // 路径压缩,直接指向失败指针的对应子节点continue;}tr[v].fail=tr[F].son[i]; // 子节点的失败指针指向其父失败指针的对应子节点q.push(v); // 将子节点入队}}
}
失败指针的意义:当节点u
的子节点v
匹配失败时,跳转到tr[u].fail
的对应子节点tr[F].son[i]
,相当于在其他模式串中继续查找前缀匹配。
3. 查询匹配 qu(char* s)
- 功能:在文本串中查找所有模式串的出现次数。
- 流程:
- 从根节点开始遍历文本串。
- 沿当前节点的子节点或失败指针跳转,累加匹配到的单词次数(
tag
),并将tag
置为 - 1 避免重复计数。
int qu(char* s){int u=1,ans=0,len=strlen(s); // 从根节点(1)开始for(int i=0;i<len;i++){int v=s[i]-'a';int k=tr[u].son[v]; // 当前字符的子节点while(k>1 && tr[k].tag!=-1){ // 沿失败指针回溯,直到根或已访问过的节点ans+=tr[k].tag; // 累加匹配次数tr[k].tag=-1; // 标记为已访问,避免重复计数k=tr[k].fail; // 跳转至失败指针}u=tr[u].son[v]; // 移动到当前字符的子节点}return ans;
}
tag=-1
标记已访问的单词节点,确保每个模式串只计数一次(即使在文本中多次出现)
完整代码
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e6+1;
struct K{
int son[26],tag,fail;
}tr[maxn];
int n,cnt;char s[maxn];
queue<int>q;
void in(char* s){
int u=1,len=strlen(s);
for(int i=0;i<len;i++){
int v=s[i]-'a';
if(!tr[u].son[v])tr[u].son[v]=++cnt;
u=tr[u].son[v];
}
tr[u].tag++;
}
void getF(){
for(int i=0;i<26;i++)tr[0].son[i]=1;
q.push(1);tr[1].fail=0;
while(!q.empty()){
int u=q.front();q.pop();
for(int i=0;i<26;i++){
int v=tr[u].son[i];
int F=tr[u].fail;
if(!v){tr[u].son[i]=tr[F].son[i];continue;}
tr[v].fail=tr[F].son[i];
q.push(v);
}
}
}
int qu(char* s){
int u=1,ans=0,len=strlen(s);
for(int i=0;i<len;i++){
int v=s[i]-'a';
int k=tr[u].son[v];
while(k>1&&tr[k].tag!=-1){
ans+=tr[k].tag,tr[k].tag=-1;
k=tr[k].fail;
}
u=tr[u].son[v];
}
return ans;
}
int main(){
cnt=1;scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%s",s);
in(s);
}
getF();
scanf("%s",s);
printf("%d\n",qu(s));
return 0;
}
扩展难度加大(额外特性)题目
因为额外要求,我们需要满足:
- 记录每个模式串的原始内容,方便最后根据出现次数收集结果。
- 统计每个模式串出现次数时,不再将
tag
置为-1
避免重复统计(因为题目要统计所有出现次数,不是仅第一次匹配到就算),而是正常累加,后续通过标记模式串的索引来区分不同模式串。 - 最后遍历所有模式串,找出出现次数最多的那些,并按输入顺序输出。
以下是在上面原代码上修改出来的
#include<bits/stdc++.h>
using namespace std;
const int maxn = 1e6 + 1;
struct K {
int son[26], tag, fail; // tag 表示单词结尾标记(非0值)
int idx; // 新增:模式串的索引(从0开始)
} tr[maxn];
int n, cnt;
char s[maxn];
queue<int> q;
vector<string> patterns; // 存储所有模式串
int pattern_cnt[155]; // 每个模式串的出现次数
// 插入模式串,记录索引
void in(char* s, int idx) {
int u = 1, len = strlen(s);
for (int i = 0; i < len; i++) {
int v = s[i] - 'a';
if (!tr[u].son[v]) tr[u].son[v] = ++cnt;
u = tr[u].son[v];
}
tr[u].tag = 1; // 标记为单词结尾
tr[u].idx = idx; // 记录模式串索引
}
// 构建失败指针
void getF() {
for (int i = 0; i < 26; i++) tr[0].son[i] = 1;
q.push(1); tr[1].fail = 0;
while (!q.empty()) {
int u = q.front(); q.pop();
for (int i = 0; i < 26; i++) {
int v = tr[u].son[i];
int F = tr[u].fail;
if (!v) {
tr[u].son[i] = tr[F].son[i];
continue;
}
tr[v].fail = tr[F].son[i];
q.push(v);
}
}
}
// 查询并统计每个模式串的出现次数
void qu(char* s) {
int u = 1, len = strlen(s);
for (int i = 0; i < len; i++) {
int v = s[i] - 'a';
u = tr[u].son[v]; // 沿Trie树移动
int k = u;
// 回溯失败指针链,统计所有匹配的模式串
while (k > 1) {
if (tr[k].tag) { // 如果是单词结尾
pattern_cnt[tr[k].idx]++; // 对应模式串计数+1
}
k = tr[k].fail; // 继续回溯
}
}
}
int main() {
while (scanf("%d", &n), n != 0) { // 多组数据
// 初始化
cnt = 1;
memset(tr, 0, sizeof(tr));
patterns.clear();
memset(pattern_cnt, 0, sizeof(pattern_cnt));
// 读入模式串
for (int i = 0; i < n; i++) {
scanf("%s", s);
patterns.push_back(s);
in(s, i);
}
// 构建AC自动机
getF();
// 读入文本串
scanf("%s", s);
qu(s);
// 找出最大出现次数
int max_cnt = 0;
for (int i = 0; i < n; i++) {
max_cnt = max(max_cnt, pattern_cnt[i]);
}
// 输出结果
printf("%d\n", max_cnt);
for (int i = 0; i < n; i++) {
if (pattern_cnt[i] == max_cnt) {
printf("%s\n", patterns[i].c_str());
}
}
}
return 0;
}