给小白看的KMP算法

文章目录

浅谈KMP算法:

(大部分人的KMP写法都是不一样的)

先给大家推荐一篇把 KMP 讲得比较好懂的博客:阮一峰的 KMP 讲解。

下面介绍一点相关概念:

栗子:

  • P串: ABCBD
  • 前缀:A,AB,ABC,ABCB,ABCBD
  • 真前缀:A,AB,ABC,ABCB
  • 后缀:D,BD,CBD,BCBD,ABCBD
  • 真后缀:D,BD,CBD,BCBD

KMP算法里的next数组的含义:

栗子:

  • P串: ABCDABD
  • next[] = {-1, 0, 0, 0, 0, 1, 2, 0, };
  • next[i] 的含义:P 串的前 i 个字符(即 P[0..i-1])构成的子串中,相等的真前缀与真后缀的最大长度;next[0] = -1 只是一个哨兵;
  • 如 i = 5时:
    •    真前缀:A,AB,ABC,ABCD
      
    •   真后缀:BCDA,CDA,DA,A
      

显而易见,真前缀与真后缀中相同的只有 A,其长度为 1,所以 next[5] = 1;

next数组求法:

// Build the KMP failure table for t[0..lent) into the global nex[]:
// nex[i] (1 <= i <= lent) is the length of the longest proper prefix of t[0..i)
// that is also a suffix of it; nex[0] = -1 is a sentinel meaning "restart".
// k is the length of the border currently being extended; i scans t.
void get_next(char *t,int lent){
    int i = 0, k = -1;
    nex[0] = -1;
    while (i < lent) {
        if (k != -1 && t[i] != t[k]) {
            // Mismatch: fall back to the next shorter border nex[k] instead of
            // rescanning from the start of t.
            k = nex[k];
        } else {
            nex[++i] = ++k;
        }
    }
}

简单KMP算法的实现:

// Search for the global pattern t (length lent) in the global text s (length lens).
// nex must already hold get_next's table for t.
// Returns the 1-based start position of the first occurrence, or -1 if none.
int kmp(){
    int i = 0, j = 0;
    while (i < lens && j < lent) {
        if (j != -1 && s[i] != t[j]) {
            // Mismatch: shift the pattern using the failure table.
            j = nex[j];
            continue;
        }
        ++i;
        ++j;
        if (j == lent) return i - j + 1;
    }
    return -1;
}

kmp模板:

// Template 1: std::string version (std::string is often more convenient,
// at the cost of being slightly slower).
// Builds the KMP failure table for t into the global nex[0..t.length()]:
// nex[i] is the longest proper border of t[0..i), nex[0] = -1 is a sentinel.
// The length comes from t itself rather than from the global lent, so the
// function is correct even when called on its own; the string is taken by
// const reference to avoid copying it.
void get_next(const string &t){
    const int m = t.length();
    nex[0] = -1;
    for(int i = 0, k = -1; i < m;){
        if(k == -1 || t[i] == t[k]){
            ++k; ++i;
            nex[i] = k;
        }else k = nex[k];   // fall back to the next shorter border
    }
}
// Returns 1 if t occurs in s, 0 otherwise.
// Side effects: sets the globals lens/lent and rebuilds nex[] for t.
// Both strings are taken by const reference: the originals were copied by
// value, which for long inputs (up to ~1e6 chars) is a needless O(n) copy.
bool kmp(const string &s, const string &t){
    lens = s.length();
    lent = t.length();
    if(lens < lent) return 0;
    get_next(t);
    int i = 0, j = 0;
    while(i < lens && j < lent) {
        if(j == -1 || s[i] == t[j]){
            i++; j++;
            if(j == lent){
                return 1;
            }
        }else j = nex[j];   // mismatch: shift the pattern by the failure table
    }
    return 0;
}
// Template 2: plain char-array version.
// Build the KMP failure table for t[0..lent) into the global nex[]:
// nex[i] is the longest proper border of t[0..i), nex[0] = -1 is a sentinel.
void get_next(char *t,int lent){
    nex[0] = -1;
    for(int i = 0,k = -1;i < lent;){
        if(k==-1||t[i] == t[k]){
            ++k;++i;
            nex[i]=k;
        }else k = nex[k];
    }
}
// Returns 1 if t[0..lent) occurs in s[0..lens), 0 otherwise.
// Builds nex[] for t itself, like template 1 does. Previously it never called
// get_next and silently used whatever nex[] happened to contain.
bool kmp(char *s,int lens,char *t,int lent){
    if(lens<lent)return 0;
    get_next(t, lent);
    int i = 0, j = 0;
    while(i < lens&&j<lent) {
        if(j==-1||s[i] == t[j]){
            i++;j++;
            if(j==lent){
                return 1;
            }
        }else j=nex[j];
    }
    return 0;
}

几道例题:

洛谷P3375:

#include<cstdio>
#include<cstring>
using namespace std;
const int maxn = 1e6 + 10;
int nex[maxn];          // nex[i]: longest proper border of t[0..i); nex[0] = -1 sentinel
char s[maxn],t[maxn];   // s: text, t: pattern
int lens,lent;          // lengths of s and t
// Build nex[0..lent] for the global pattern t.
void get_next(){
    int i = 0, k = -1;
    nex[0] = -1;
    while (i < lent) {
        if (k != -1 && t[i] != t[k]) k = nex[k];   // mismatch: shorter border
        else nex[++i] = ++k;                        // extend the current border
    }
}
	
// Print the 1-based start position of every occurrence of t in s, one per
// line, including overlapping occurrences. Requires nex[] built by get_next().
void kmp(){
    int i = 0, j = 0;
    while (i < lens && j < lent) {
        if (j != -1 && s[i] != t[j]) {
            j = nex[j];
            continue;
        }
        ++i;
        ++j;
        if (j == lent) {
            printf("%d\n", i - j + 1);
            j = nex[j];   // keep scanning so overlapping matches are found
        }
    }
}
int main(){
	while(~scanf("%s %s",s,t)){
		lens=strlen(s);
		lent=strlen(t);
		get_next();
		kmp();
		for(int i=1;i<=lent;++i){
			printf("%d%c",nex[i],i==lent?'\n':' ');
		}
	}
	return 0;
}

毒瘤题hdu5510

here

最小循环节

// Minimal cycle section of a string of length m, given its failure table nxt[]
// (nxt[m] = longest proper border, same convention as get_next above).
// m - nxt[m] is the string's shortest period. When the longest border is shorter
// than half the string, this snippet takes the whole string as the cycle instead.
// NOTE(review): m % len == 0 is not checked, so len need not divide m -- confirm
// this matches the problem's definition of a cycle section.
int len=m-nxt[m];
if(nxt[m]*2<m) len=m;
// len: the minimal cycle section of the (short) string

再告诉你一个小秘密:

// Walk the border chain of a string of length len: len -> nex[len] ->
// nex[nex[len]] -> ... -> 0 (nex as built by get_next).
// Each step adds i - nex[i]; the sum telescopes, so ans always equals len.
// p[0..t) collects every proper border length (> 0), longest first. These are
// exactly the lengths k for which the length-k prefix is also a suffix.
// The previous version stored nex[] of the already-advanced i, which had two effects:
//   - it skipped the longest border nex[len];
//   - it appended the sentinel -1 (nex[0]).
int ans = 0, t = 0;
for(int i = len; i > 0; i = nex[i]) {
    ans += i - nex[i];
    if(nex[i] > 0) p[t++] = nex[i];
}
printf("ans = %d\n", ans); // ans always equals the string length len

exKmp

hdu2594模板题
解释:
https://segmentfault.com/a/1190000008663857
https://blog.csdn.net/dyx404514/article/details/41831947
https://blog.csdn.net/discreeter/article/details/78266221

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1000000 + 10, mod = 1e9 + 7;
// nex[i]:    length of the longest common prefix of pat and pat[i..lenp-1]
//            (the Z-array of pat; nex[0] = lenp).
// extend[i]: length of the longest common prefix of pat and ori[i..leno-1].
char ori[N], pat[N];
int nex[N], extend[N];
int num[N];   // unused in this listing
int cas = 0;  // unused in this listing
// Fill nex[] for pat in linear time (extended KMP / Z-algorithm).
void get_nex(char *pat) {
    int len = strlen(pat);
    nex[0] = len;
    int k = 0;
    while(k + 1 < len && pat[k] == pat[k + 1]) ++k;
    // Written even when len <= 1 (value 0); extkmp may read nex[1] in that case.
    nex[1] = k;
    // k: start index whose match window [k, k + nex[k]) reaches furthest right.
    k = 1;
    // Test i < len first. The old condition `pat[i] && i < len` read pat[2]
    // even when len < 2, i.e. past the terminator of a short buffer.
    // (i < len already implies pat[i] != '\0'.)
    for(int i = 2; i < len; i++) {
        if(i + nex[i - k] < k + nex[k]) nex[i] = nex[i - k];   // entirely inside the window
        else {
            // Start from what the window already guarantees, then extend.
            int j = max(k + nex[k] - i, 0);
            while(i + j < len && pat[j] == pat[i + j]) ++j;
            nex[i] = j;
            k = i;
        }
    }
}
// Fill extend[] for ori against pat (calls get_nex(pat) first).
void extkmp(char *ori, char *pat) {
    get_nex(pat);
    int leno = strlen(ori), lenp = strlen(pat);
    int k = 0;
    while(k < leno && k < lenp && ori[k] == pat[k]) ++k;
    extend[0] = k;
    k = 0;
    // Bounds first, as in get_nex (the old loop read ori[1] when leno == 0).
    for(int i = 1; i < leno; i++) {
        if(i + nex[i - k] < k + extend[k]) extend[i] = nex[i - k];
        else {
            int j = max(k + extend[k] - i, 0);
            while(i + j < leno && j < lenp && ori[i + j] == pat[j]) ++j;
            extend[i] = j;
            k = i;
        }
    }
}
// hdu2594: for each pair (pat, ori), find the longest prefix of pat that is
// also a suffix of ori. Print it followed by its length, or just "0".
int main(int argc, char const *argv[]) {
    while (scanf("%s%s", pat, ori) != EOF) {
        const int n = strlen(ori);
        extkmp(ori, pat);
        int best = 0;
        for (int i = 0; i < n; ++i) {
            // A match window that runs exactly to the end of ori is a prefix of
            // pat that is also a suffix of ori.
            if (i + extend[i] == n && extend[i] > best) best = extend[i];
        }
        for (int i = n - best; i < n; ++i) printf("%c", ori[i]);
        if (best) printf(" ");
        printf("%d\n", best);
    }
    return 0;
}

AC自动机模板

//基神板子
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>

using namespace std;
typedef long long LL;

// Maximum number of trie nodes (node 0 is never used; the root is node 1).
const int MX = 500000 + 5;
// Aho-Corasick automaton over 'a'..'z' (characters are indexed as c - 'a', so
// input is assumed to be lowercase). Build() fills in every missing transition,
// turning Next[][] into a complete goto function, so Query() never has to walk
// fail links on a mismatch.
struct AC_machine {
    // rear: last allocated node; root: the root node.
    // Next[u][c]: child of u on c (-1 = none, before Build());
    // Fail[u]: fail link; End[u]: number of inserted patterns ending at u.
    int rear, root; int Next[MX][26], Fail[MX], End[MX];
    // Reset the automaton to a single root node.
    void Init() { rear = 0; root = New(); }
    // Allocate a fresh node with no children and End = 0; returns its index.
    int New() {
        rear++; End[rear] = 0;
        for(int i = 0; i < 26; i++) { Next[rear][i] = -1; }
        return rear;
    }
    // Insert pattern A; inserting the same pattern twice increments End twice.
    void Add(char*A) {
        int n = strlen(A), now = root;
        for(int i = 0; i < n; i++) {
            int id = A[i] - 'a';
            if(Next[now][id] == -1) {
                Next[now][id] = New();
            }
            now = Next[now][id];
        }
        End[now]++;
    }
    // Build for string matching: BFS that sets Fail[] and replaces every missing
    // transition with the corresponding transition of the fail node.
    void Build() {
        queue<int>Q;
        Fail[root] = root;
        for(int i = 0; i < 26; i++) {
            if(Next[root][i] == -1) { Next[root][i] = root; }
            else { Fail[Next[root][i]] = root; Q.push(Next[root][i]); }
        }
        while(!Q.empty()) {
            int u = Q.front(); Q.pop();
            for(int i = 0; i < 26; i++) {
                if(Next[u][i] == -1) {
                    Next[u][i] = Next[Fail[u]][i];
                } else { Fail[Next[u][i]] = Next[Fail[u]][i]; Q.push(Next[u][i]); }
            }
        }
    }
    // Number of inserted patterns (duplicates counted separately) that occur in S.
    // Destructive: End[] is zeroed as patterns are counted, so each pattern
    // contributes at most once; re-run Init/Add/Build to count again.
    // Every step walks the whole fail chain back to the root, so the worst case
    // is O(|S| * trie depth).
    int Query(char *S) {
        int n = strlen(S), now = root, ret = 0;
        for(int i = 0; i < n; i++) {
            now = Next[now][S[i] - 'a'];
            int temp = now;
            while(temp != root) {
                ret += End[temp];
                End[temp] = 0; 
                temp = Fail[temp];
            }
        } return ret;
    }
    /*
    // Build variant for use as a state automaton (alphabet of size 4 here).
    void Build2() {
        queue<int>Q;
        Fail[root] = root;
        for(int i = 0; i < 4; i++) {
            if(Next[root][i] == -1) { Next[root][i] = root; }
            else { Fail[Next[root][i]] = root; Q.push(Next[root][i]); }
        }
        while(!Q.empty()) {
            int u = Q.front(); Q.pop();
            // Adapt this line to the problem at hand: here a node becomes
            // terminal when its fail node is terminal.
            if(End[Fail[u]]) End[u] = 1;
            for(int i = 0; i < 4; i++) {
                if(Next[u][i] == -1) { Next[u][i] = Next[Fail[u]][i]; }
                else { Fail[Next[u][i]] = Next[Fail[u]][i]; Q.push(Next[u][i]); }
            }
        }
    }
    // Occurrence count of every pattern (assumes End[] stores a pattern id here).
    void Query2(char *S) {
        int n = strlen(S), now = root;
        for(int i = 0; i < n; i++) {
            now = Next[now][S[i] - 'a'];
            int temp = now;
            while(temp != root) {
                if(End[temp]) ans[End[temp]]++;
                temp = Fail[temp];
            }
        }
    }
    // Occurrence count of every pattern, counting only non-overlapping occurrences.
    void Query3(char *S) {
        int n = strlen(S), now = root, ret = 0;
        for(int i = 0; i < n; i++) {
            now = Next[now][S[i] - 'a'];
            int temp = now;
            while(temp != root) {
                if(End[temp] && (last[End[temp]] == -1 || last[End[temp]] + len[temp] <= i)) {
                    ans[End[temp]]++;
                    last[End[temp]] = i;
                }
                temp = Fail[temp];
            }
        }
    }
    */
} aho;
char str[1000006];
// hdu2222 driver: T test cases, each with n patterns followed by one text.
// Prints how many of the patterns occur in the text.
int main(int argc, char const *argv[]){
    int tim;
    scanf("%d", &tim);
    while (tim--) {
        int n;
        scanf("%d", &n);
        aho.Init();
        for (int left = n; left != 0; --left) {
            scanf("%s", str);
            aho.Add(str);
        }
        aho.Build();
        scanf("%s", str);
        const int hits = aho.Query(str);
        printf("%d\n", hits);
    }
    return 0;
}
//我的板子
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>

using namespace std;

const int MXN = 1e6 + 6;   // text buffer size used by main
const int MXT = 5e5 + 5;   // maximum number of trie nodes

// Aho-Corasick automaton over 'a'..'z' (characters are indexed as c - 'a').
// Missing transitions stay -1; Query/Query2 walk fail links on a mismatch.
struct AHO {
    // Trie node: children (-1 = absent), fail link, number of patterns ending here.
    struct trie {
        int nex[26];
        int fail, cnt;
        void New() {
            memset(nex, -1, sizeof(nex));
            fail = cnt = 0;
        }
    }cw[MXT];
    int rt, tot;   // rt: scratch cursor; tot: index of the last allocated node (root = 0)
    // Reset to a lone root node.
    void init() {
        rt = tot = 0;
        cw[0].New();
    }
    // Insert pattern S; duplicates increment cnt of the same node.
    void add_str(char *S) {
        int len = strlen(S);
        rt = 0;
        for(int i = 0, now; i < len; ++i) {
            now = S[i] - 'a';
            if(cw[rt].nex[now] == -1) {
                cw[rt].nex[now] = ++tot;
                cw[tot].New();
            }
            rt = cw[rt].nex[now];
        }
        cw[rt].cnt++;
    }
    // BFS computing fail links; the root's fail is -1 (sentinel).
    void build_ac(){
        queue<int> Q;
        Q.push(0);
        cw[0].fail = -1;
        while(!Q.empty()){
            int u = Q.front(); Q.pop();
            for(int i = 0, pos; i < 26; ++i) {
                pos = cw[u].nex[i];
                if(~pos) {
                    if(u == 0) cw[pos].fail = 0;
                    else {
                        // Follow u's fail chain until some node has a child on i.
                        int v = cw[u].fail;
                        while(~v){
                            if(~cw[v].nex[i]) {
                                cw[pos].fail = cw[v].nex[i];
                                break;
                            }
                            v = cw[v].fail;
                        }
                        if(v == -1) cw[pos].fail = 0;
                    }
                    Q.push(pos);
                }
            }
        }
    }
    // Add up and clear cnt for u and every node on its fail chain (root excluded).
    // Clearing makes each pattern count at most once per query.
    int Get(int u) {
        int ans = 0;
        while(u) {
            ans += cw[u].cnt;
            cw[u].cnt = 0;
            u = cw[u].fail;
        }
        return ans;
    }
    // Number of inserted patterns (duplicates counted separately) occurring in S.
    // Destructive: consumes cnt, so rebuild before querying another text.
    // Worst case O(|S| * trie depth), since every step walks the full fail chain.
    int Query(char *S) {
        int len = strlen(S);
        int ans = 0;
        rt = 0;
        for(int i = 0, now, p; i < len; ++i) {
            now = S[i] - 'a';
            if(~cw[rt].nex[now]) {
                rt = cw[rt].nex[now];
            }else {
                p = cw[rt].fail;
                while(p != -1 && cw[p].nex[now] == -1) p = cw[p].fail;
                if(p == -1) rt = 0;
                else rt = cw[p].nex[now];
            }
            // Always harvest the fail chain: a pattern may end at a fail ancestor
            // even when rt itself is not a pattern end (e.g. "bc" while reading
            // "abc" with patterns {"abcd", "bc"}). The old `if(cw[rt].cnt)` guard
            // missed such matches.
            ans += Get(rt);
        }
        return ans;
    }
    // Same result as Query, using the "fall back until a child exists" loop.
    int Query2(char *S) {
        int len = strlen(S);
        int ans = 0;
        rt = 0;
        for(int i = 0, now; i < len; ++i) {
            now = S[i] - 'a';
            while(cw[rt].nex[now]==-1&&rt) rt = cw[rt].fail;
            rt = cw[rt].nex[now];
            if(rt == -1) rt = 0;
            // Walk the whole fail chain. The old loop stopped at the first node
            // with cnt == 0, which also skipped patterns further up the chain.
            ans += Get(rt);
        }
        return ans;
    }
}aho;
char s[MXN];
// hdu2222 driver for AHO: T cases, each with n patterns and one text.
// Prints how many of the patterns occur in the text.
int main(int argc, char const *argv[]) {
    int tim;
    scanf("%d", &tim);
    while (tim--) {
        int n;
        scanf("%d", &n);
        aho.init();
        for (int added = 0; added < n; added++) {
            scanf("%s", s);
            aho.add_str(s);
        }
        aho.build_ac();
        scanf("%s", s);
        const int found = aho.Query(s);
        printf("%d\n", found);
    }
    return 0;
}
const int MXN = 1e6 + 6;   // text buffer size
const int MXT = 5e5 + 5;   // maximum number of trie nodes

// Aho-Corasick automaton over 'a'..'z', goto-function variant: build_ac() fills
// every missing transition, so Query2 never walks fail links on a mismatch.
struct AHO {
    // Trie node: children (-1 = absent before build_ac), fail link, patterns ending here.
    struct trie {
        int nex[26];
        int fail, cnt;
        void New() {
            memset(nex, -1, sizeof(nex));
            fail = cnt = 0;
        }
    }cw[MXT];
    int rt, tot;   // rt: scratch cursor; tot: index of the last allocated node (root = 0)
    // Reset to a lone root node.
    void init() {
        rt = tot = 0;
        cw[0].New();
    }
    // Insert pattern S; duplicates increment cnt of the same node.
    void add_str(char *S) {
        int len = strlen(S);
        rt = 0;
        for(int i = 0, now; i < len; ++i) {
            now = S[i] - 'a';
            if(cw[rt].nex[now] == -1) {
                cw[rt].nex[now] = ++tot;
                cw[tot].New();
            }
            rt = cw[rt].nex[now];
        }
        cw[rt].cnt++;
    }
    // BFS setting fail links and completing missing transitions from the fail node.
    void build_ac(){
        queue<int> Q;
        cw[0].fail = 0;
        for(int i = 0; i < 26; ++i){
            if(cw[0].nex[i] == -1){
                cw[0].nex[i] = 0;
            }else {
                cw[cw[0].nex[i]].fail = 0;
                Q.push(cw[0].nex[i]);
            }
        }
        while(!Q.empty()){
            int u = Q.front(); Q.pop();
            for(int i = 0, pos; i < 26; ++i) {
                pos = cw[u].nex[i];
                if(pos == -1) {
                    cw[u].nex[i] = cw[cw[u].fail].nex[i];
                }else {
                    cw[pos].fail = cw[cw[u].fail].nex[i];
                    Q.push(pos);
                }
            }
        }
    }
    // Number of inserted patterns (duplicates counted separately) occurring in S.
    // Destructive: consumes cnt, so rebuild before querying another text.
    // Worst case O(|S| * trie depth), since every step walks the full fail chain.
    int Query2(char* S) {
        int n = strlen(S), now = 0, ans = 0;
        for(int i = 0; i < n; ++i) {
            now = cw[now].nex[S[i]-'a'];
            // Harvest every pattern on the fail chain. Do NOT stop at the first
            // node with cnt == 0: it may simply not be a pattern end while nodes
            // further up are (e.g. "bc" while reading "abc" with patterns
            // {"abcd", "bc"}). The old early `break` missed such matches.
            for(int tmp = now; tmp != 0; tmp = cw[tmp].fail) {
                ans += cw[tmp].cnt;
                cw[tmp].cnt = 0;
            }
        }
        return ans;
    }
}aho;

后缀数组模板

//#include<bits/stdc++.h>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
typedef long long LL;
/* Suffix array over the int-encoded string s[].
   sa, rnk and height are indexed 0..len (len = string length), i.e. n = len + 1
   entries. The extra entry is the empty suffix at position len (sentinel 0). */
struct Suffix_Array {
    static const int N = 3e5 + 7;          // capacity: string length + sentinel
    // n = len + 1; s: encoded string; M: comparison length used by cmp_suffix
    // (NOTE(review): never assigned in this listing -- caller must set it).
    int n, len, s[N], M;
    // sa[r]: start of the rank-r suffix; rnk: inverse of sa;
    // height[r]: LCP of the suffixes ranked r-1 and r.
    int sa[N], rnk[N], height[N];
    int tmp_one[N], tmp_two[N], c[N];      // scratch buffers for build_sa's radix sort
    int dp[N][33];                         // sparse table over height[] for LCP queries
    void init_str(char *str);
    void build_sa(int m = 128);
    void calc_height(int n);
    void Out(char *str);
    void RMQ_init(int n);
    int RMQ_query(int l, int r);
    int cmp_suffix(char* pattern, int p);
}SA;
// Compare the first M characters of `pattern` with the suffix of rank p,
// strncmp-style: negative, zero or positive.
// s[] stores ints (init_str maps 'a'..'z' to 1..26 and the terminator to 0), so
// the old strncmp(pattern, s + sa[p], M) passed an int* where a const char* is
// required and did not compile. The pattern is encoded on the fly with the same
// mapping instead. Comparison stops at the first 0 on either side, so it never
// reads past the sentinel.
// NOTE(review): M must be set (presumably to strlen(pattern)) before calling.
int Suffix_Array::cmp_suffix(char* pattern, int p){
    const int *suf = s + sa[p];
    for (int i = 0; i < M; ++i) {
        const int a = pattern[i] == '\0' ? 0 : pattern[i] - 'a' + 1;
        if (a != suf[i]) return a - suf[i];
        if (a == 0) return 0;   // both ended at the same place
    }
    return 0;
}
// Debug helper: print the suffixes of `str` in sorted order, one per line, for
// all n ranks (with init_str's encoding the first one is the empty suffix).
void Suffix_Array::Out(char *str) {
    puts ("/*Suffix*/");
    for (int rank = 0; rank < n; rank++)
        printf ("%s\n", &str[sa[rank]]);
}
// Stand-alone leftovers, independent of Suffix_Array: a plain char-based
// counterpart of cmp_suffix over the globals s/sa/M. It compares the first M
// chars of pattern with the suffix s + sa[p].
// NOTE(review): MXN must already be declared at this point; this program's own
// `const int MXN` only appears later, just before main. Nothing in this
// listing fills these globals.
char a[MXN], b[MXN], s[MXN];
int sa[MXN], M;
int cmp_suffix(char* pattern, int p){
    return strncmp(pattern, s + sa[p], M);
}
//LCP(suffix(i), suffix(j))=RMQ_query(rnk[i], rnk[j]);
int Suffix_Array::RMQ_query(int l, int r) {
    l = rnk[l]; r = rnk[r];
    if (l > r) swap(l, r);
    l++;
    int k = 0; while (1<<(k+1) <= r - l + 1) k++;
    return min(dp[l][k], dp[r-(1<<k)+1][k]);
}
// Build the sparse table dp[i][lg] = min(height[i .. i + 2^lg - 1]) over height[0..n-1].
void Suffix_Array::RMQ_init(int n) {
    for (int i = 0; i < n; ++i) {
        dp[i][0] = height[i];
    }
    for (int lg = 1; (1 << lg) <= n; ++lg) {
        const int half = 1 << (lg - 1);
        for (int i = 0; i + (1 << lg) <= n; ++i) {
            dp[i][lg] = std::min(dp[i][lg - 1], dp[i + half][lg - 1]);
        }
    }
}
// Encode str into s[]: 'a'..'z' -> 1..26, followed by a 0 sentinel at s[len].
// Sets len = strlen(str) and n = len + 1 (sentinel included).
// NOTE(review): characters outside 'a'..'z' map outside 1..26 (possibly
// negative), which build_sa's counting sort cannot index -- confirm inputs.
void Suffix_Array::init_str(char *str) {
    len = strlen(str);
    n = len + 1;
    int *dst = s;
    for (const char *src = str; src < str + len; ++src, ++dst) {
        *dst = *src - 'a' + 1;
    }
    s[len] = '\0';
}
// Fill rnk[] (inverse of sa[]) and height[] from sa[].
// The parameter n here is the string length: build_sa passes n - 1, and this
// parameter shadows the member n. Ranks therefore run 0..n, and rank 0 is the
// empty sentinel suffix.
// height[r] = LCP(suffix sa[r-1], suffix sa[r]); height[0] = 0.
void Suffix_Array::calc_height(int n) {
    for (int i=0; i<=n; ++i) rnk[sa[i]] = i;
    int k = height[0] = 0;
    for (int i=0; i<n; ++i) {
        // Kasai's trick: going from suffix i-1 to suffix i loses at most one
        // matched character, so k only drops by one per step.
        if (k) k--;
        // Suffix ranked just before suffix i (rnk[i] >= 1: the sentinel holds rank 0).
        int j = sa[rnk[i]-1];
        // Stops at the latest at the unique 0 sentinel s[len].
        while (s[i+k] == s[j+k]) k++;
        height[rnk[i]] = k;
    }
}
// Build sa[] for s[0..n) by prefix doubling with counting (radix) sort, then
// fill rnk[]/height[] (calc_height) and the LCP sparse table (RMQ_init).
// m must exceed max(s[i]); the default 128 is ample for character codes.
void Suffix_Array::build_sa(int m) {
    int i, j, p, *x = tmp_one, *y = tmp_two;
    // Round 0: counting sort by the first character; x[i] = s[i] is the rank key.
    for (i=0; i<m; ++i) c[i] = 0;
    for (i=0; i<n; ++i) c[x[i]=s[i]]++;
    for (i=1; i<m; ++i) c[i] += c[i-1];
    for (i=n-1; i>=0; --i) sa[--c[x[i]]] = i;
    for (j=1; j<=n; j<<=1) {
        // y: suffixes ordered by their second key (rank of position i + j).
        // Suffixes with i + j >= n have an empty second key and come first...
        for (p=0, i=n-j; i<n; ++i) y[p++] = i;
        // ...followed by the rest, in sa order shifted back by j.
        for (i=0; i<n; ++i) if (sa[i] >= j) y[p++] = sa[i] - j;
        // Stable counting sort of y by the first key x[].
        for (i=0; i<m; ++i) c[i] = 0;
        for (i=0; i<n; ++i) c[x[y[i]]]++;
        for (i=1; i<m; ++i) c[i] += c[i-1];
        for (i=n-1; i>=0; --i) sa[--c[x[y[i]]]] = y[i];
        // Re-rank: y now holds the old keys; equal (first, second) pairs share a rank.
        // y[sa[i]+j] is only read when the first halves are equal, which (given
        // the unique 0 sentinel at the end) implies sa[i] + j < n.
        std::swap (x, y);
        for (p=1, x[sa[0]]=0, i=1; i<n; ++i) {
            x[sa[i]] = (y[sa[i-1]] == y[sa[i]] && y[sa[i-1]+j] == y[sa[i]+j] ? p - 1 : p++);
        }
        // All ranks distinct: the order is final.
        if(p >= n) break;
        m=p;
    }
    // calc_height expects the string length (n - 1 excludes the sentinel).
    calc_height(n-1);
    RMQ_init(n);
}
const int MXN = 3e5 + 7;
char str[MXN];
// Driver for the "maximum repetition substring" problem (presumably POJ 3693).
// For each input string (until "#"), print the substring with the largest
// repetition count, ties broken by suffix rank. When no unit repeats at least
// twice, print the smallest character.
int main(int argc, char const *argv[]){
    int cas = 1;
    while(~scanf("%s", str) && str[0] != '#'){
        printf("Case %d: ", cas++);
        int len = strlen(str);
        SA.init_str(str);
        SA.build_sa();
        // ans: best repetition count so far (0 = nothing with count >= 2 yet);
        // [ansL, ansR]: the corresponding substring.
        int ans = 0, ansL = 0, ansR = 0;
        // d: length of the repeating unit; test anchor pairs x = j*d, y = x + d.
        for(int d = 1; d * 2 <= len; ++d){
            for(int j = 0; (j+1) * d < len; ++j){
                int x = j*d, y = (j+1)*d;
                if(str[x] != str[y]) continue;
                int z = SA.RMQ_query(x, y);       // LCP of the suffixes at x and y
                int st, ed = z + y - 1, tmp;      // the repeated run ends at ed
                // Slide the start left while the run can still be extended.
                for(int k = 0; k < d; ++k){
                    // Bounds check first. For j == 0, x - k goes negative, and the
                    // old order read str[-1] before testing it.
                    if(x - k < 0 || str[x-k] != str[y-k]) break;
                    st = x - k;
                    tmp = (ed-st+1)/d;            // repetitions of the unit starting at st
                    // A count of 1 is not a repetition. Accepting it made e.g.
                    // "abac" print "ab" instead of the smallest character "a".
                    if(tmp < 2) continue;
                    if(tmp>ans || (tmp==ans&&SA.rnk[st]<SA.rnk[ansL])){
                        ans = tmp;
                        ansL = st;
                        ansR = st+tmp*d-1;
                    }
                }
            }
        }
        if(ans == 0){
            // sa[0] is the empty sentinel suffix, so sa[1] starts with the smallest character.
            printf("%c\n", str[SA.sa[1]]);
        }else{
            for(int i = ansL; i <= ansR; ++i) putchar(str[i]);
            putchar(10);
        }
    }
    return 0;
}
//翔神板子
//#include<bits/stdc++.h>
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
#include <queue>

using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
// Unused by the code below (leftovers from the author's contest header).
const int MOD = 1e9 + 7;
const int INF = 0x3f3f3f3f;


const int sigmaSize = 26;          // alphabet size; characters map to c - 'a'
// Maximum number of trie nodes.
// NOTE(review): 100000 * 50 = 5,000,000 nodes, so ch[][] alone takes about
// 5e6 * 26 * 4 bytes ~= 520 MB -- confirm this fits the judge's memory limit.
const int MXT = 100000 * 50;

// Aho-Corasick automaton. getFail() completes missing transitions (ch[][] becomes
// a full goto function) and fills last[], so matching never walks fail links on
// a mismatch.
struct ACautomata {
    int ch[MXT][sigmaSize];  // transitions; 0 = absent before getFail() (node 0 is the root)
    int f[MXT];    // fail function
    int val[MXT];  // number of patterns ending at this node (non-zero exactly at pattern ends)
    int last[MXT]; // next node on the output chain: nearest fail ancestor with val != 0
    int sz;        // number of allocated nodes
    int d[MXT];    // cleared in init() but not otherwise used in this listing
    // Reset to a lone root node.
    // NOTE(review): memset of d[] touches MXT ints (~20 MB) on every call.
    void init() {
        sz = 1;
        memset (ch[0], 0, sizeof (ch[0]) );
        memset(d, 0, sizeof(d));
    }

    // Index of character c (assumes 'a'..'z').
    inline int idx (char c) {
        return c - 'a';
    }

    // Insert pattern s; inserting a duplicate increases val of the same end node.
    void insert (char *s) {
        int u = 0, n = strlen (s);
        for (int i = 0; i < n; i++) {
            int c = idx (s[i]);
             //printf("%c", s[i]);
            if (!ch[u][c]) {
                memset (ch[sz], 0, sizeof (ch[sz]) );
                val[sz] = 0;
                ch[u][c] = sz++;
            }
            u = ch[u][c];
            //printf("%d", u);
            //puts("");
        }
        val[u] += 1;
    }

    // Intended to report every pattern that ends at text position i, i.e. whose
    // end node is j or lies on j's output chain.
    // NOTE(review): the printing was stripped from this copy; the function only
    // recurses along last[] and outputs nothing.
    void print (int i, int j) {
        if (j) {
            print (i, last[j]);
        }
    }

    // Scan text T and call print() at every position where some pattern ends.
    void find (char* T) {
        int n = strlen (T);
        int j = 0; // current node, starting at the root
        for (int i = 0; i < n; i++) { // current position in the text
            int c = idx (T[i]);
            j = ch[j][c];
            if (val[j]) print (i, j);
            else if (last[j]) print (i, last[j]); // match found via the output chain
        }
    }

    // Compute the fail function by BFS; also completes ch[][] and sets last[].
    void getFail() {
        queue<int> q;
        f[0] = 0;
        // Seed the queue with the root's children (their fail link is the root).
        for (int c = 0; c < sigmaSize; c++) {
            int u = ch[0][c];
            if (u) {
                f[u] = 0;
                q.push (u);
                last[u] = 0;
            }
        }
        // Process nodes in BFS order so a shallower node's data is final before
        // deeper nodes use it.
        while (!q.empty() ) {
            int r = q.front();
            q.pop();
            for (int c = 0; c < sigmaSize; c++) {
                int u = ch[r][c];
                if (!u) {
                    // Missing transition: borrow the fail node's transition.
                    ch[r][c] = ch[f[r]][c];
                    continue;
                }
                q.push (u);
                int v = f[r];
                // NOTE(review): shallower nodes' transitions are already completed,
                // so this loop looks redundant; kept as-is.
                while (v && !ch[v][c]) v = f[v];
                f[u] = ch[v][c];
                // Output link: f[u] itself if a pattern ends there, otherwise f[u]'s output link.
                last[u] = val[f[u]] ? f[u] : last[f[u]];
            }
        }
    }
    // Number of inserted patterns (duplicates counted separately) occurring in s.
    // Destructive: val[] is zeroed as patterns are counted.
    int query(char *s) {
        int len =strlen(s);
        int now = 0;
        int ans =0;
        for(int i=0; i<len; i++) {
            now = ch[now][s[i]-'a'];
            int tmp = now;
            // Visit now and its output chain, counting each pattern end once.
            while(now) {
                if(val[now]) {
                    ans+=val[now];
                    val[now]=0;
                }
                now = last[now];
            }
            now =tmp;  // restore the automaton state
        }
        return ans;
    }
} ac;

char s[1000006];

// hdu2222 driver for ACautomata: T cases, each with n patterns and one text.
// Prints how many of the patterns occur in the text.
int main() {
    int T;
    scanf("%d", &T);
    while (T--) {
        int n;
        scanf("%d", &n);
        ac.init();
        for (int k = 0; k < n; ++k) {
            scanf ("%s", s);
            ac.insert (s);
        }
        ac.getFail();
        scanf ("%s", s );
        // ac.find(s) would walk the matches instead (its printing is stubbed out).
        const int total = ac.query(s);
        printf("%d\n", total);
    }
    return 0;
}
    原文作者:KMP算法
    原文地址: https://blog.csdn.net/qq_39599067/article/details/80054337
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞