一个做法好多的字符串题目。

题目

bzoj 3238: [Ahoi2013]差异

Time Limit:\(20\)Sec Memory Limit: \(512\)MB

Description

图片标题 ## Input

一行,一个字符串\(S\)

Output

一行,一个整数,表示所求值

解题报告

看着\(ISA\)大爷在学习\(SAM\),这样的话:我也要学~~ 所以我找了这个题,感觉后缀数组做非常简单:

  • 题目要求\(\sum_{1<=i<j<=n} \ (len(S_i)+len(S_j))+ 2* \sum_{1<=i<j<=n} \ lcp(S_i,S_j)\)
  • 前半部分可以通过直接算得到,后半部分不需要按照字符串原有顺序枚举,可以在\(sa\),数组上,枚举\(\sum_{i=1}^{n-1} (\sum_{j=i+1}^{n}lcp(sa[i],sa[j]))\) 括号内的部分是一个连续区间和\[LCP(sa[i],sa[j])= min height_k (k=i+1 \to j )\] 所以很显然,从右向左枚举 \(sa\) ,在求\(\sum_{j=i+1}^n lcp(sa[i],sa[j])\)时,当前 \(sa[i]\) 会影响后方一个区间 \([i,r]\) ,满足\(height_k (k \in [i+1,r]) > height_i\),所以用一个单调栈,找到这个区间,然后区间覆盖+区间求和就搞定了;
  • 黄学长也使用后缀数组,但是他用了两个单调栈直接求出了每个\(height\)贡献答案的区间,比我高到不知道哪里去了;

然后强行练后缀自动机的模板:

  • 因为没学过后缀树,所以想得有些烦;
  • 因为后缀自动机的\(parent\)树上存在\(right(fa) \supseteq right(x)\) ,并且\(max(fa)+1=min(x)\),思考含义,实际上是从\(right\)集合中的每个位置记录向前的子串,在到达\(max+1\)长度时出现差异就会在\(par\)树上分支,那么\(par\)\(max-min\)\(right\)集合元素的最长公共后缀的组成部分;
  • 所以反建后缀自动机,在每个节点统计答案:当前\(right\)集合任选两个$ (max-min)$

后缀数组的代码交时忘关文件\(RE\)一波,\(SAM\)的反而先\(A\)了; 并且显然\(SAM\)并没有速度上的优势; ### 代码 \(SA\):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstdlib>
#include<cstring>
using namespace std;
const int N=500002;
char s[N];
int sa[N],wa[N],wb[N],ss[N],wv[N],n;
int height[N],rank[N],sta[N],top;
int mark[N<<2];long long sm[N<<2];
long long ans=0,sum;
inline bool cmp(int *x,int l,int r,int len) {
return (x[l]==x[r])&&(x[l+len]==x[r+len]);
}
void da(char *s,int *sa,int n,int m) {
int i,j,p,*x=wa,*y=wb,*t;
for (i=0;i<m;++i) ss[i]=0;
for (i=0;i<n;++i) ss[x[i]=s[i]]++;
for (i=1;i<m;++i) ss[i]+=ss[i-1];
for (i=n-1;i>=0;--i) sa[--ss[x[i]]]=i;
for (j=1,p=1;j<n&&p<n;m=p,j<<=1) {
for (p=0,i=n-j;i<n;++i) y[p++]=i;
for (i=0;i<n;++i) if (sa[i]>=j) y[p++]=sa[i]-j;
for (i=0;i<n;++i) wv[i]=x[y[i]];
for (i=0;i<m;++i) ss[i]=0;
for (i=0;i<n;++i) ss[wv[i]]++;
for (i=1;i<m;++i) ss[i]+=ss[i-1];
for (i=n-1;i>=0;--i) sa[--ss[wv[i]]]=y[i];
for (t=x,x=y,y=t,p=1,x[sa[0]]=0,i=1;i<n;++i)
x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p-1:p++;
}
}
void calheight(char *s,int *sa,int n) {
int k=0,i,j;
for (i=1;i<=n;++i) rank[sa[i]]=i;
for (i=0;i<n;height[rank[i]]=k,++i) {
k?k--:0; j=sa[rank[i]-1];// cout<<i<<' '<<j<<' ';
for (;s[j+k]==s[i+k]&&i+k<n&&j+k<n;)
k++;
}
}
inline void up(int x) {
sm[x]=sm[x<<1]+sm[x<<1|1];
}
inline void down(int x,int l,int r) {
int mid=(l+r)>>1;
if (mark[x]!=-1) {
sm[x<<1]=1LL*(mid-l+1)*mark[x];
sm[x<<1|1]=1LL*(r-mid)*mark[x];
mark[x<<1]=mark[x],mark[x<<1|1]=mark[x];
mark[x]=-1;
}
}
void build(int x,int l,int r) {
if (l==r) {sm[x]=height[l];return;}
int mid=(l+r)>>1;
build(x<<1,l,mid),build(x<<1|1,mid+1,r);
up(x);
}
void query(int x,int l,int r,int L,int R) {
if (L<=l&&r<=R) { sum+=sm[x]; return;}
int mid=(l+r)>>1; down(x,l,r);
if (L<=mid) query(x<<1,l,mid,L,R);
if (R>mid) query(x<<1|1,mid+1,r,L,R);
}
void change(int x,int l,int r,int L,int R,int v) {
if (L<=l&&r<=R) {
sm[x]=1LL*(r-l+1)*v,mark[x]=v;
return;
} int mid=(l+r)>>1; down(x,l,r);
if (L<=mid) change(x<<1,l,mid,L,R,v);
if (R>mid) change(x<<1|1,mid+1,r,L,R,v);
up(x);
}
int main() {
scanf("%s",s),n=strlen(s); s[n]=0;int i,j;
da(s,sa,n+1,128);
calheight(s,sa,n);
build(1,2,n),memset(mark,-1,sizeof(mark));
for (i=0;i<n;++i) ans+=1LL*(n-i)*(n-1);
for (i=n;i>=2;--i) {
while (top&&height[sta[top]]>height[i]) --top;
if (!top) j=n; else j=sta[top]-1; sta[++top]=i;
if (i<j) change(1,2,n,i+1,j,height[i]);
sum=0,query(1,2,n,i,n);
ans-=sum<<1;
}
printf("%lld\n",ans);
return 0;
}

\(SAM\):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
using namespace std;
const int N=500010;
char s[N];
int n,root=1,last=root,cnt=1,head[N<<1],tot,r[N<<1],have[N<<1];
long long ans=0;
struct edge {
int next,to;
edge (int next=0,int to=0)
:next(next),to(to) {}
} e[N<<1];
struct node {
int par,go[26],val;
node() :par(0),val(0) {
memset(go,0,sizeof(go));
}
} mem[N<<1];
int newnode(int val) {
++cnt,mem[cnt].val=val; return cnt;
}
void extend(int w) {
int p=last;
int np=newnode(mem[p].val+1);
while (p && mem[p].go[w]==0)
mem[p].go[w]=np,p=mem[p].par;
if (p==0) mem[np].par=root;
else {
int q=mem[p].go[w];
if (mem[p].val+1==mem[q].val)
mem[np].par=q;
else {
int nq=newnode(mem[p].val+1);
memcpy(mem[nq].go,mem[q].go,sizeof(mem[q].go));
mem[nq].par=mem[q].par;
mem[q].par=nq,mem[np].par=nq;
while (p&&mem[p].go[w]==q)
mem[p].go[w]=nq,p=mem[p].par;
}
}
last=np;
have[np]=1;
}
inline void add(int x,int y) {
e[++tot]=edge(head[x],y),head[x]=tot;
}
void dfs(int x) {
int i,y; r[x]=have[x];
for (i=head[x];i;i=e[i].next) {
y=e[i].to; dfs(y);
r[x]+=r[y];
}
ans-=1LL*r[x]*(r[x]-1)*(mem[x].val-mem[mem[x].par].val);
}
int main() {
scanf("%s",s); int i; n=strlen(s);
for (i=n-1;i>=0;--i) extend(s[i]-'a');
for (i=2;i<=cnt;++i) add(mem[i].par,i);
dfs(1);
for (i=0;i<n;++i) ans+=1LL*(n-i)*(n-1);
printf("%lld\n",ans);
return 0;
}

\(SAM\)短很多!

留言