子串

xiaoxiao2021-02-27  449

题目描述

SA题

朴素大概要个很高的复杂度。 想一个高端一点的暴力,可以只枚举两个后缀,对于这两个后缀任意前缀之间lcp可以列出数学式子,这个式子与这两个后缀的长度以及它们的lcp长度有关。 接下来我们知道lcp等于一段区间height的最小值。 因此写个sa,然后根据height建立笛卡尔树。 接着递归维护需要维护的信息,每次以一个点为lcp值统计答案。 式子因为忘了怎么推就不推啦!

#include<cstdio> #include<algorithm> #include<cstring> //#include<ctime> #define fo(i,a,b) for(i=a;i<=b;i++) #define fd(i,a,b) for(i=a;i>=b;i--) using namespace std; typedef long long ll; const int maxn=500000+10,mo=998244353; char s[maxn]; int rank[maxn*2],sa[maxn],height[maxn],a[maxn],b[maxn],c[maxn],d[maxn],len[maxn],le[maxn]; int fa[maxn],left[maxn],right[maxn],sta[maxn],size[maxn],sum[maxn],num[maxn]; int i,j,k,l,r,mid,t,n,m,ans,top,root; void getsa(){ fo(i,1,n) b[s[i]-'a']++; fo(i,1,26) b[i]+=b[i-1]; fd(i,n,1) c[b[s[i]-'a']--]=i; t=0; fo(i,1,n){ if (s[c[i]]!=s[c[i-1]]) t++; rank[c[i]]=t; } j=1; while (j<=n){ fo(i,0,n) b[i]=0; fo(i,1,n) b[rank[i+j]]++; fo(i,1,n) b[i]+=b[i-1]; fd(i,n,1) c[b[rank[i+j]]--]=i; fo(i,0,n) b[i]=0; fo(i,1,n) b[rank[c[i]]]++; fo(i,1,n) b[i]+=b[i-1]; fd(i,n,1) d[b[rank[c[i]]]--]=c[i]; t=0; fo(i,1,n){ if (rank[d[i]]!=rank[d[i-1]]||rank[d[i]+j]!=rank[d[i-1]+j]) t++; c[d[i]]=t; } fo(i,1,n) rank[i]=c[i]; if (t==n) break; j*=2; } fo(i,1,n) sa[rank[i]]=i; } void getheight(){ k=0; fo(i,1,n){ if (k) k--; j=sa[rank[i]-1]; while (i+k<=n&&j+k<=n&&s[i+k]==s[j+k]) k++; height[rank[i]]=k; } } void dfs(int x){ if (!x) return; dfs(left[x]); dfs(right[x]); size[x]=size[left[x]]+size[right[x]]+1; sum[x]=(sum[left[x]]+sum[right[x]])%mo; sum[x]=(sum[x]+len[x])%mo; num[x]=(num[left[x]]+num[right[x]])%mo; num[x]=(num[x]+le[x])%mo; int t=a[x]; j=((ll)t*(t+1)/2)%mo; k=(ll)(num[left[x]]+le[x])*j%mo*(size[right[x]]+1)%mo; (ans+=k)%=mo; k=(ll)(sum[right[x]]+len[x])*j%mo*(size[left[x]]+1)%mo; (ans+=k)%=mo; k=(ll)(size[left[x]]+1)*(size[right[x]]+1)%mo*j%mo; (ans+=k)%=mo; j=((ll)t*(t+1)*(2*t+1)/3)%mo; k=(ll)(size[left[x]]+1)*(size[right[x]]+1)%mo*j%mo; (ans-=k)%=mo; j=(ll)t*t%mo*t%mo; k=(ll)(size[left[x]]+1)*(size[right[x]]+1)%mo*j%mo; (ans+=k)%=mo; k=(ll)(num[left[x]]+le[x])*(sum[right[x]]+len[x])%mo*t%mo; (ans+=k)%=mo; k=(ll)(num[left[x]]+le[x])*(size[right[x]]+1)%mo*t%mo*t%mo; (ans-=k)%=mo; k=(ll)(sum[right[x]]+len[x])*(size[left[x]]+1)%mo*t%mo*t%mo; (ans-=k)%=mo; } int main(){ freopen("substring.in","r",stdin);freopen("substring.out","w",stdout); scanf("%s",s+1); n=strlen(s+1); getsa(); getheight(); fo(i,2,n) a[i-1]=height[i],len[i-1]=n-sa[i]+1,le[i-1]=n-sa[i-1]+1; fo(i,1,n-1){ while (top&&a[i]<a[sta[top]]){ right[fa[sta[top]]]=0; fa[sta[top]]=i; right[sta[top]]=left[i]; fa[left[i]]=sta[top]; left[i]=sta[top]; top--; } if (top){ fa[i]=sta[top]; right[sta[top]]=i; } sta[++top]=i; } fo(i,1,n) if (!fa[i]){ root=i; break; } dfs(root); ans=(ll)ans*2%mo; fo(i,1,n){ l=t=n-i+1; k=((ll)t*(t+1)/2)%mo; (ans+=(ll)(2*l+1)%mo*k%mo)%=mo; k=((ll)t*(t+1)*(2*t+1)/3)%mo; k=-k; (ans+=k)%=mo; } (ans+=mo)%=mo; printf("%d\n",ans); //printf("%d\n",clock()); fclose(stdin);fclose(stdout); }
转载请注明原文地址: https://www.6miu.com/read-4392.html

最新回复(0)