// Problem link: https://www.nowcoder.com/acm/contest/127/D
// Approach: extended KMP — it suffices to compute, for every suffix of string a, its longest common prefix with string b.
#include <iostream>
#include <cstdio>
#include <cstring>
#include <string>
#include <algorithm>
#include <map>
#include <cmath>
#include <set>
#include <stack>
#include <queue>
#include <vector>
#include <bitset>
#include <functional>
using namespace std;
// 64-bit integer type used for answers that can exceed the int range.
typedef long long LL;
// "Infinity" sentinel; not referenced anywhere in the visible code.
const int INF = 0x3f3f3f3f;
// Buffers hold strings of up to 2000008 characters plus the terminating '\0'.
// nt:     LCP of the pattern with each of its own suffixes (filled by GetNext).
int nt[2000009];
// extend: LCP of each suffix of the text with the pattern (filled by GetExtend).
int extend[2000009];
// s1 is the text (string a), s2 the pattern (string b), as read in main.
char s1[2000009];
char s2[2000009];
void GetNext(char T[], int nt[]) {
int t_len = strlen(T);
nt[0] = t_len;
int a, p;
for (int i = 1, j = -1; i < t_len; i++, j--) {
if (j < 0 || i + nt[i - a] >= p) {
if (j < 0)
p = i, j = 0;
while (p < t_len && T[p] == T[j])
p++, j++;
nt[i] = j;
a = i;
}
else nt[i] = nt[i - a];
}
}
// Extended KMP matching step. First fills nt via GetNext(T, nt). Then, for
// 0 <= i < strlen(S), sets extend[i] to the length of the longest common
// prefix of the suffix S[i..] and the whole pattern T.
// nt needs room for strlen(T) ints and extend for strlen(S) ints.
void GetExtend(char S[], char T[], int extend[], int nt[]) {
    GetNext(T, nt);
    const int s_len = std::strlen(S);
    const int t_len = std::strlen(T);
    // [left, right) is the rightmost window found so far with
    // S[left..right) == T[0..right-left).
    int left = 0;
    int right = 0;
    for (int i = 0; i < s_len; ++i) {
        int k = 0;
        if (i < right) {
            // S[i..right) equals T[i-left..right-left), so the pattern's own
            // LCP table gives a lower bound, capped at the window end.
            k = std::min(nt[i - left], right - i);
        }
        while (i + k < s_len && k < t_len && S[i + k] == T[k]) {
            ++k;
        }
        extend[i] = k;
        if (i + k > right) {
            left = i;
            right = i + k;
        }
    }
}
// Reads pairs of strings (a, b) until EOF. For each pair it prints how many
// substrings of a (counted by starting position and length) are
// lexicographically smaller than b.
//
// Let e = extend[i] be the LCP of the suffix a[i..] with b:
//   * substrings of length 1..min(e, |b|-1) are proper prefixes of b -> smaller;
//   * if e == |b|, every longer substring has b as a prefix -> not smaller;
//   * otherwise the mismatch a[i+e] vs b[e] decides every length > e at once.
int main() {
    // The field width keeps scanf inside the 2000009-byte buffers (room for '\0').
    // Requiring == 2 stops cleanly when only one string is left at EOF,
    // instead of looping with a stale s2.
    while (scanf("%2000008s %2000008s", s1, s2) == 2) {
        GetExtend(s1, s2, extend, nt);
        const int len1 = strlen(s1);
        const int len2 = strlen(s2);
        LL ans = 0;
        for (int i = 0; i < len1; i++) {
            const int e = extend[i];
            if (e == len2) {
                // b occurs at i: only its proper prefixes are smaller.
                ans += len2 - 1;
                continue;
            }
            // Lengths 1..e are proper prefixes of b (e < len2), hence smaller.
            ans += e;
            // Compare as unsigned char to match strcmp ordering. When
            // i + e == len1, s1[i + e] is '\0' and the added count is 0.
            if (static_cast<unsigned char>(s1[i + e]) <
                static_cast<unsigned char>(s2[e])) {
                ans += len1 - i - e;
            }
        }
        printf("%lld\n", ans);
    }
    return 0;
}