大意是给你两个序列A和B,要求出A中有多少段和B长度相等的连续序列G满足,G和B中位置一一对应的每个数分别在G和B中的排名都相同。
我们可以在普通的KMP上动点手脚(雾)。对于当前位置已经匹配了k个数字,如果k+1个数字的排名也相等就可以加进来,而且对之后的匹配的影响也是相同的,即无后效性,这样就保证了已经匹配的数排名都一直相等。在实际求排名的时候,我们可以求出之前的序列中有多少个数大于当前数,多少个等于当前数(对别的数有影响),注意考虑到数值相同的情况,所以这两个条件要同时相等,这个可以用树状数组维护。但是要注意的是更新树状数组的地方略有点多。
KMP的复杂度O(N),稍微带点常数,树状数组O(logN),总复杂度就是O(NlogN)。
另外这题样例有问题(大雾)
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cstdlib>
using namespace std;
#define lowbit(x) ((x)&(-x))
#define mp make_pair
#define read(x) x=readl()
const int N=5*1e5+5;
int a[N],b[N],fail[N],n,m,S;
int ans[N],tot;
int sum1[N],sum2[N];
typedef pair<int,int> pr;
int sum(int t[],int x){
int ret=0;
for (;x;x-=lowbit(x)) ret+=t[x];
return ret;
}
void add(int t[],int x,int k){
for (;x<=S;x+=lowbit(x)) t[x]+=k;
}
pr get(int t[],int p){
int x=sum(t,p-1);
int y=sum(t,p);
return mp(x,y);
}
void init(){
memset(sum1,0,sizeof sum1);
memset(sum2,0,sizeof sum2);
fail[1]=0;
for (int i=2,j=0;i<=m;i++){
while (j){
pr x=get(sum1,b[j+1]);
pr y=get(sum2,b[i]);
if (x==y) break;
for (int k=fail[j]+1;k<=j;k++) add(sum1,b[k],-1);
for (int k=i-j;k<i-fail[j];k++) add(sum2,b[k],-1);
j=fail[j];
}
pr x=get(sum1,b[j+1]);
pr y=get(sum2,b[i]);
if (x==y){
add(sum1,b[j+1],1);
add(sum2,b[i],1);
j++;
}
fail[i]=j;
}
}
void work(){
memset(sum1,0,sizeof sum1);
memset(sum2,0,sizeof sum2);
for (int i=1,j=0;i<=n;i++){
while (j){
pr x=get(sum1,b[j+1]);
pr y=get(sum2,a[i]);
if (x==y) break;
for (int k=fail[j]+1;k<=j;k++) add(sum1,b[k],-1);
for (int k=i-j;k<i-fail[j];k++) add(sum2,a[k],-1);
j=fail[j];
}
pr x=get(sum1,b[j+1]);
pr y=get(sum2,a[i]);
if (x==y){
add(sum1,b[j+1],1);
add(sum2,a[i],1);
j++;
}
if (j==m){
ans[++tot]=i-m+1;
for (int k=fail[j]+1;k<=j;k++)
add(sum1,b[k],-1);
for (int k=i-j+1;k<=i-fail[j];k++)
add(sum2,a[k],-1);
j=fail[j];
}
}
}
int readl(){
int p=0;char x=getchar();
while (x<'0' || x>'9') x=getchar();
while (x>='0' && x<='9') p=p*10+x-'0',x=getchar();
return p;
}
int main(){
read(n);read(m);read(S);
for (int i=1;i<=n;i++) read(a[i]);
for (int j=1;j<=m;j++) read(b[j]);
init();
work();
printf("%d\n",tot);
for (int i=1;i<=tot;i++) printf("%d\n",ans[i]);
return 0;
}