最短母串 https://loj.ac/problem/10061 题目描述 思路 代码

题目描述

  给出若干字符串,求一个最短的字符串(T)满足这些字符串都是(T)的子串。

思路

  这道题想到和(AC)自动机联系起来还是有一些难度,不过要维护这些字符串肯定要建一棵字典树,再考虑如何求最短母串。我们考虑最短母串至少会经过字典树上的一些节点,并且如果当前节点为(u),下一节点为(v),那么设(1sim v)的字符串为(T)(1sim u)的字符串(S),那么(v)能在(u)后面当且仅当(T[1...j]=S[i-j+1...i]),而这就是(AC)自动机中(next)数组的定义,因此我们可以先把(AC)自动机板子打上去。

  接下来我们考虑如何求答案。显然,在(Trie)图(就是(Trie)树和失配指针),我们要求出一种方案,使得它经过每一个字符串结尾,相当于就是在(Trie)图上找一条经过每个字符串结尾节点的最短路。而这个问题我们可以用状态压缩(dp)维护,由于(n)比较小,我们用(state)表示已经经过哪几个字符串尾部,转移方程就很显然了:(dp[v][state|s[u]]=dp[u][state]+1)。而我们在这里维护一下每个节点入队时其前缀节点的编号,这样就可以在搜到答案时往回寻找即可。

  不过还是吐槽一下这道题的空间限制,只有(32MB),本来不想用状压想用分层图,但发现空间开不下,而且没用(STL)中的队列和循环队列,直接开数组居然有一个点(MLE)了,有一些无语。

代码

#include<bits/stdc++.h>
using namespace std;

struct aa
{
    int state,m;
    aa(int m=0,int state=0):m(m),state(state) {}
};

int tot=1,ch[610][26],ed[610],nxt[610];
void insert(char *s,int k)        //字典树建立 
{
    int u=1;
    int len=strlen(s);
    for(int i=0;i<len;i++)
    {
        int c=s[i]-'A';
        if(!ch[u][c])ch[u][c]=++tot;
        u=ch[u][c];
    }
    ed[u]|=1<<k;        //字符串结束标记 
}
void getfail()
{
    int q[610];
    for(int i=0;i<26;i++)
        ch[0][i]=1;            //把零点和根节点连边,不要忘记 
    q[1]=1;nxt[1]=0;
    int head=1,tail=1;
    while(head<=tail)
    {
        int u=q[head];head++;
        for(int i=0;i<26;i++)
        {
            if(!ch[u][i])ch[u][i]=ch[nxt[u]][i];
            else
            {
                int v=nxt[u];
                q[++tail]=ch[u][i];
                if(v&&!ch[v][i])v=nxt[v];
                nxt[ch[u][i]]=ch[v][i];
            }
        }
    }
    for(int i=0;i<=tot;i++)
    {
        int v=nxt[i];
        while(v)ed[i]|=ed[v],v=nxt[v];
    }
}
int n,cnt;
int dp[610][4400],c[2400010],ans[2400010],fa[2400010];
queue<aa>q;
void bfs()
{
    memset(dp,0x3f,sizeof(dp));
    int head=1,tail=1,idx=0,k=0;
    q.push(aa(1,0));dp[1][0]=0;
    while(!q.empty())
    {
        aa u=q.front();q.pop();
        if(u.state==((1<<n)-1))
        {
            for(int i=k;i;i=fa[i])
                ans[++cnt]=c[i];
            return ;
        }
        for(int i=0;i<26;i++)
        {
            if(!ch[u.m][i])continue ;
            int v=ch[u.m][i];
            int s=u.state|ed[v];
            if(dp[u.m][u.state]+1<dp[v][s])
            {
                dp[v][s]=dp[u.m][u.state]+1;
                q.push(aa(v,s));
                fa[++idx]=k;            //记录前一个节点的编号 
                c[idx]=i;
            }
        }
        k++;
    }
}

char s[70];
int main()
{
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
    {
        scanf(" %s",s);
        insert(s,i-1);    //插入节点,为节省空间从1<<0开始 
    }
//    cout<<tot<<endl;
    getfail();
//    for(int i=1;i<=tot;i++)
//        cout<<i<<' '<<nxt[i]<<endl;
    bfs();
//    cout<<cnt<<endl;
    for(int i=cnt;i>=1;i--)
        putchar(ans[i]+'A');
}