POJ 3693 反复次数最多的连续重复子串 后缀数组

POJ 3693 重复次数最多的连续重复子串 后缀数组

题目大意就是求重复次数最多的连续重复子串。例如abababc 答案就是ababab  因为ab连续出现的次数最多

并且题目还要求输出字典序最小的

比如abababcdcdcd 

ababab和cdcdcd都符合要求

但是ababab字典序小


具体做法参见罗穗骞的论文

穷举子串的长度L,然后求长度为L的子串最多出现几次

首先连续出现一次是肯定的,所以只考虑出现两次及以上的情况

假设在字符串中出现了两次,记这个重复了两次L长度子串的子串为S。

那么S肯定包含了字符r[0], r[L], r[L*2], r[3 * L]....中的某相邻的两个。

所以就看r[L*i]和r[L*(i + 1)]往前往后分别匹配到多远,记这个长度为K(具体匹配方式看代码),那么就连续出现了K/L+1次,最后看最大值多少

注意每次求这个k要分为两种情况,一是公共前缀恰好模L等于0,另一种是模L不等于0

模L不等于0时还要计算一个值,假如公共前缀%L等于t,就求lcp(i - (L - t), i - (L - t) + L);

为什么呢,我们画一画就知道了,这样的做法,实际上两个公共前缀往前延伸了几个位置,使得前缀的长度加上延伸的长度是L的倍数

然后求lcp,会发现,他是有可能比原来的k大的,那么连续出现的次数也有可能改变。所以这种情况不能遗漏

这里用到了lcp,既求任意两个后缀的最长公共前缀,使用RMQ实现。

最后输出的时候,由于要按字典序输出,就枚举sa数组


#include <iostream>
#include <algorithm>
#include <cstring>
#include <string>
#include <cstdio>
#include <cmath>
#include <queue>
#include <map>
#include <set>
#define MAXN 111111
#define MAXM 200005
#define INF 1000000011
#define lch(x) x<<1
#define rch(x) x<<1|1
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define eps 1e-7
using namespace std;
int r[MAXN];
int wa[MAXN], wb[MAXN], wv[MAXN], tmp[MAXN];
int sa[MAXN]; //index range 1~n value range 0~n-1
int cmp(int *r, int a, int b, int l)
{
    return r[a] == r[b] && r[a + l] == r[b + l];
}
void da(int *r, int *sa, int n, int m)
{
    int i, j, p, *x = wa, *y = wb, *ws = tmp;
    for (i = 0; i < m; i++) ws[i] = 0;
    for (i = 0; i < n; i++) ws[x[i] = r[i]]++;
    for (i = 1; i < m; i++) ws[i] += ws[i - 1];
    for (i = n - 1; i >= 0; i--) sa[--ws[x[i]]] = i;
    for (j = 1, p = 1; p < n; j *= 2, m = p)
    {
        for (p = 0, i = n - j; i < n; i++) y[p++] = i;
        for (i = 0; i < n; i++)
            if (sa[i] >= j) y[p++] = sa[i] - j;
        for (i = 0; i < n; i++) wv[i] = x[y[i]];
        for (i = 0; i < m; i++) ws[i] = 0;
        for (i = 0; i < n; i++) ws[wv[i]]++;
        for (i = 1; i < m; i++) ws[i] += ws[i - 1];
        for (i = n - 1; i >= 0; i--) sa[--ws[wv[i]]] = y[i];
        for (swap(x, y), p = 1, x[sa[0]] = 0, i = 1; i < n; i++)
            x[sa[i]] = cmp(y, sa[i - 1], sa[i], j) ? p - 1 : p++;
    }
}
int rank[MAXN]; //index range 0~n-1 value range 1~n
int height[MAXN]; //index from 1   (height[1] = 0)
void calheight(int *r, int *sa, int n)
{
    int i, j, k = 0;
    for (i = 1; i <= n; ++i) rank[sa[i]] = i;
    for (i = 0; i < n; height[rank[i++]] = k)
        for (k ? k-- : 0, j = sa[rank[i] - 1]; r[i + k] == r[j + k]; ++k);
    return;
}
int Log[MAXN];
int mi[MAXN][20];
void rmqinit(int n)
{
    for(int i = 1; i <= n; i++) mi[i][0] = height[i];
    int m = Log[n];
    for(int i = 1; i <= m; i++)
        for(int j = 1; j <= n; j++)
        {
            mi[j][i] = mi[j][i - 1];
            if(j + (1 << (i - 1)) <= n) mi[j][i] = min(mi[j][i], mi[j + (1 << (i - 1))][i - 1]);
        }
}
int lcp(int a, int b)
{//询问a,b后缀的最长公共前缀
    a = rank[a];    b = rank[b];
    if(a > b) swap(a,b);
    a ++;
    int t = Log[b - a + 1];
    return min(mi[a][t] , mi[b - (1<<t) + 1][t]);
}
char s[MAXN];
int ans[MAXN];
char out[MAXN];
int main()
{
    Log[1] = 0;
    for(int i = 2; i < MAXN; i++) Log[i] = Log[i >> 1] + 1;
    int cas = 0;
    while(scanf("%s", s) != EOF && strcmp(s, "#") != 0)
    {
        int n = strlen(s), m = 0;
        for(int i = 0; i < n; i++)
        {
            m = max(m, (int)s[i]);
            r[i] = s[i];
        }
        r[n] = 0;
        da(r, sa, n + 1, m + 1);
        calheight(r, sa, n);
        rmqinit(n);
        int mx = -1;
        int cnt, l;
        for(l = 1; l < n; l++) //枚举子串的长度
        {
            for(int i = 0; i + l < n; i += l)
            {
                int k = lcp(i, i + l);//计算r[L*i]和r[L*(i+1)]的最长公共前缀
                int p = k / l + 1;
                int t = l - k % l;
                t = i - t;
                //printf("ss i :%d i + l :%d t :%d t + l:%d p: %d\n", i, i + l, t, t + l, p);
                if (t >= 0 && k % l != 0) 
                {
                    int tk = lcp(t, t + l);
                    if(tk / l + 1 > p) p = tk / l + 1;
                }
                if(p > mx)
                {
                    cnt = 0;
                    mx = p;
                    ans[cnt++] = l;
                }
                if(p == mx) ans[cnt++] = l;
            }
        }
        int pos = 0;
        int flag = 0;
        for(int i = 1; i <= n && !flag; i++)
        {
            for(int j = 0; j < cnt; j++)
            {
                int k = ans[j];
                if(lcp(sa[i], sa[i] + k) >= (mx - 1) * k)
                {
                    pos = sa[i];
                    l = mx * k;
                    flag = 1;
                    break;
                }
            }
        }
        printf("Case %d: ", ++cas);
        for(int i = 0; i < l; i++) printf("%c", s[pos + i]);
        printf("\n");
    }
    return 0;
}