题解 洛谷 P3735 【[HAOI2017]字符串】

根据本题的定义,对于两个串,若 (lcp + lcs + k geqslant | s |),则这两个串形成了匹配。统计匹配次数考虑固定 (lcp) 长度,来找合法的 (lcs) 的个数。

对所有模式串 (p_i) 的正串和反串一起建出 (AC) 自动机,将询问挂在 (p_i) 的前缀在 (AC) 自动机所对应的节点上,对于一个串 (pos) 位置的前缀对应的节点,将 (pos+k+1) 位置的后缀对应的节点也挂上去,计算每个前缀的贡献即可。

对于文本串 (s),同样处理出其前缀在 (AC) 自动机上匹配到 (pos) 位置后,在当前所在的节点挂上其后缀在 (AC) 自动机上匹配到 (pos+k+1) 位置后到达的节点。

统计答案时,遍历一遍 (fail) 树,当到达点 (x) 后,将点 (x) 子树内的节点上挂着文本串对应的后缀节点都加上 (1),对于询问,查询模式串对应的后缀节点的子树和即可,这样就统计出了当前 (lcp) 所对应的合法 (lcs)。可以通过处理出 (dfs) 序后用树状数组维护,即到达一个节点后先统计一次信息,然后遍历该节点子树,将子树内的后缀节点的贡献都加上,然后再统计一遍信息,两次值的差即为当前子树的贡献。

但是直接这样做会算重,当 (lcp + lcs + k > | s |) 时,一个匹配的位置会计算多次,所以要减去算重的部分。每次在节点上再挂上 (pos+k) 位置的后缀对应的节点,减去其贡献即可,这样就能保证每个匹配的位置都只被前缀匹配最长的位置统计到了。

#include<bits/stdc++.h>
#define maxn 400010
#define lowbit(x) (x&(-x))
#define mk make_pair
using namespace std;
template<typename T> inline void read(T &x)
{
    x=0;char c=getchar();bool flag=false;
    while(!isdigit(c)){if(c=='-')flag=true;c=getchar();}
    while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
    if(flag)x=-x;
}
int k,n,m,root,tot,cnt;
int trie[maxn][94],fail[maxn],ans[maxn],bel[maxn],in[maxn],out[maxn];
char str[maxn],s[maxn];
struct node
{
    int id,p1,p2,v1,v2;
};
vector<node> v[maxn];
vector<pair<int,int> > ve[maxn];
struct edge
{
    int to,nxt;
}e[maxn];
int head[maxn],edge_cnt;
void add(int from,int to)
{
    e[++edge_cnt]={to,head[from]},head[from]=edge_cnt;
}
struct BIT
{
    int t[maxn];
    void update(int x)
    {
        while(x<=cnt) t[x]++,x+=lowbit(x);
    }
    int query(int x)
    {
        int sum=0;
        while(x) sum+=t[x],x-=lowbit(x);
        return sum;
    }
    int ask(int l,int r)
    {
        return query(r)-query(l-1);
    }
}T1,T2;
void insert(int id)
{
    int len=strlen(s+1),p=root;
    for(int i=1;i<=len;++i)
    {
        int ch=s[i]-33;
        if(!trie[p][ch]) trie[p][ch]=++tot;
        p=trie[p][ch];
    }
    p=root,bel[len+1]=0;
    for(int i=len;i;--i)
    {
        int ch=s[i]-33;
        if(!trie[p][ch]) trie[p][ch]=++tot;
        p=trie[p][ch],bel[i]=p;
    }
    p=root;
    for(int i=0;i<=len-k;++i)
    {
        node t={id,bel[i+k+1],bel[i+k],0,0};
        if(!i) t.p2=-1;
        v[p].push_back(t),p=trie[p][s[i+1]-33];
    }
}
void build()
{
    queue<int> q;
    for(int i=0;i<94;++i)
        if(trie[root][i])
            q.push(trie[root][i]);
    while(!q.empty())
    {
        int x=q.front();
        q.pop();
        for(int i=0;i<94;++i)
        {
            int y=trie[x][i];
            if(y) fail[y]=trie[fail[x]][i],q.push(y);
            else trie[x][i]=trie[fail[x]][i];
        }
    }
    for(int i=1;i<=tot;++i) add(fail[i],i);
    int p=root;
    for(int i=n;i;--i) bel[i]=p=trie[p][str[i]-33];
    p=root;
    for(int i=0;i<=n-k;++i)
        ve[p].push_back(mk(bel[i+k+1],bel[i+k])),p=trie[p][str[i+1]-33];
}
void dfs_dfn(int x)
{
    in[x]=++cnt;
    for(int i=head[x];i;i=e[i].nxt) dfs_dfn(e[i].to);
    out[x]=cnt;
}
void dfs_ans(int x)
{
    for(int i=0;i<v[x].size();++i)
    {
        node t=v[x][i];
        v[x][i].v1=T1.ask(in[t.p1],out[t.p1]);
        if(t.p2!=-1) v[x][i].v2=T2.ask(in[t.p2],out[t.p2]);
    }
    for(int i=0;i<ve[x].size();++i)
        T1.update(in[ve[x][i].first]),T2.update(in[ve[x][i].second]);
    for(int i=head[x];i;i=e[i].nxt) dfs_ans(e[i].to);
    for(int i=0;i<v[x].size();++i)
    {
        node t=v[x][i];
        ans[t.id]+=T1.ask(in[t.p1],out[t.p1])-t.v1;
        if(t.p2!=-1) ans[t.id]-=T2.ask(in[t.p2],out[t.p2])-t.v2;
    }
}
int main()
{
    read(k),scanf("%s",str+1),read(m),n=strlen(str+1);
    for(int i=1;i<=m;++i)
    {
        scanf("%s",s+1);
        int l=strlen(s+1);
        if(l>k) insert(i);
        else ans[i]=n-l+1;
    }
    build(),dfs_dfn(root),dfs_ans(root);
    for(int i=1;i<=m;++i) printf("%d
",ans[i]);
    return 0;
}