hdu 2243 AC+矩阵快速幂

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=2243 

做这道题前,得先去做这道题:http://poj.org/problem?id=2778

目大意给定m个词根,现在要用26个字母组成长度小等于n的字符串并且至少含一个词根的组合种数,n <  2^31,结果对2^64次方取余

 

这道题跟POJ那道差不多,不过这道更恶心,结果要对2^64取模(巨坑啊,害我卡了那么久),其实就相当于不用取模,因为unsigned __int64越界了就变成0,相当于已经取模了,把各种取模去了之后就AC了。

 

做法就是先求出所有的情况数,即26^1+26^2+…+26^k(k为最长字符串数,用快速幂就搞定了);再求出不包含词根的字符串数,做完POJ2778那道就没问题了,那道求的是k长度的,通过AC自动机构造出目标矩阵A,求A^k(用矩阵快速幂)就可以得到结果,但hdu这道长度可以是1、2…k,其实就是求A^1+A^2+…+A^k,但k可以很大,固然不可以一个一个求,具体做法参考这个http://blog.csdn.net/llc_9012/article/details/8911120。最后用总的减去不包含词根的,就得出结果啦。

 

学习AC自动机的娃,强烈推荐这位大神的博客http://www.notonlysuccess.com/index.php/aho-corasick-automaton/

 

在此贴出俺的AC代码:

#include <iostream>
#include <cstring>
#include <cstdio>
using namespace std;

#define UI unsigned __int64
#define SIZE 31

struct Matrix
{
    UI a[SIZE][SIZE];
    int len;
};

int num;
struct node
{
    int fail, end;
    int next[26];
}p[SIZE];

int get_node()
{
    num++;
    p[num].fail = p[num].end = 0;
    memset(p[num].next, 0, sizeof(p[num].next));
    return num;
}

void insert(char *st)
{
    int i=0, k, id=1;
    while(st[i])
    {
        k = st[i]-'a';
        if( !p[id].next[k] )
           p[id].next[k] = get_node();
        id = p[id].next[k];
        i++;
    }
    p[id].end = 1;
}

int q[SIZE];
void make_fail()
{
    int head, tail;
    head  = tail = 0;
    int id, temp;
    q[tail++] = 1;
    while(head < tail)
    {
        id = q[head++];
        for(int i=0; i<26; i++)
        {
            if( p[id].next[i] )
            {
                q[tail++] = p[id].next[i];
                if(id == 1)
                {
                    p[ p[id].next[i] ].fail = 1;
                }
                else
                {
                    temp = p[id].fail;
                    while( temp && !p[temp].next[i])
                    {
                        temp = p[temp].fail;
                    }
                    if( !temp )
                    {
                        p[ p[id].next[i] ].fail = 1;
                    }
                    else
                    {
                        p[ p[id].next[i] ].fail = p[temp].next[i];
                        if( p[ p[temp].next[i] ].end )
                           p[ p[id].next[i] ].end = 1;
                    }
                }
            }
        }
    }
}

Matrix A, E;
void get_matrix_E()
{
    memset(E.a, 0, sizeof(E.a));
    E.len = num;
    for(int i=1; i<=E.len; i++)
    {
        E.a[i][i] = 1;
    }
}
void get_matrix_A()
{
    A.len = num;
    memset(A.a, 0, sizeof(A.a));
    int id, i, temp;
    for(id=1; id<=A.len; id++)
    {
        if( p[id].end )
            continue;

        for(i=0; i<26; i++)
        {
            if( p[id].next[i] && !p[ p[id].next[i] ].end )
            {
                A.a[id][ p[id].next[i] ]++;
            }
            else if( !p[id].next[i])
                 {
                     temp = p[id].fail;
                     while( temp && !p[temp].next[i])
                     {
                         temp = p[temp].fail;
                     }
                     if( !temp )
                     {
                         A.a[id][1]++;
                     }
                     else
                     {
                         if( !p[ p[temp].next[i] ].end )
                         {
                             A.a[id][ p[temp].next[i] ]++;
                         }
                     }
                 }
        }
    }
}

Matrix mul(Matrix m, Matrix n)
{
    Matrix c;
    c.len = m.len;
    int i, j, k;
    UI sum;
    for(i=1; i<=c.len; i++)
       for(j=1; j<=c.len; j++)
       {
           sum = 0;
           for(k=1; k<=c.len; k++)
           {
               sum += m.a[i][k]*n.a[k][j];
           }
           c.a[i][j] = sum;
       }
    return c;
}

Matrix add(Matrix m, Matrix n)
{
    Matrix c;
    c.len = m.len;
    for(int i=1; i<=c.len; i++)
        for(int j=1; j<=c.len; j++)
           c.a[i][j] = (m.a[i][j] + n.a[i][j]) ;
    return c;
}

Matrix pow(Matrix x, int n)
{
    Matrix ans;
    ans = E;
    ans.len = x.len;
    while(n)
    {
        if(n%2 == 1)
            ans = mul(ans, x);
        x = mul(x, x);
        n /= 2;
    }
    return ans;
}

Matrix fun(Matrix x, int n)
{
    //cout << n <<"*"<<endl;
    if(n == 1)
        return x;
    if(n%2 == 1)
        return add(mul(fun(x, (n-1)/2), add(E, pow(x, (n-1)/2))), pow(x, n));
        //return add(fun(x, n-1), pow(x, n));
    else
        return mul(fun(x, n/2), add(E, pow(x, n/2)));
}

UI pow1(UI x, int n)
{
    UI y=1;
    while(n)
    {
        if(n%2 == 1)
           y = (y*x) ;
        x = (x*x) ;
        n /= 2;
    }
    return y;
}
UI f(UI x, int n)
{
    if(n == 1)
       return x;
    if(n%2 == 1)
        return (f(x, n-1)+pow1(x,n)) ;
    else
        return (f(x, n/2) * (1+pow1(x,n/2)));
}
/*
void print()
{
    for(int i=1; i<=A.len; i++)
        {
            for(int j=1; j<=A.len; j++)
               cout << A.a[i][j] << " ";
            cout << endl;
        }
}
*/
int main()
{
    /*
    MOD = 1;
    for(int i=0; i<64; i++)
       MOD *= 2;
    MOD -= 1;
    */
    int n, l;
    char s[10];
    UI sum, ans;
    while(scanf("%d%d", &n, &l) != EOF)
    {
        num = 0;
        get_node();
        while(n--)
        {
            scanf("%s", s);
            insert(s);
        }

        make_fail();
        get_matrix_E();
        get_matrix_A();

        A = fun(A, l);

        sum = 0;
        for(int i=1; i<=A.len; i++)
        {
            sum += A.a[1][i];
        }
        //cout<<sum<<endl;
        //cout<<f(26,l)<<endl;
        ans = f(26,l)-sum;
        //printf("%I64u\n", (ans+MOD)%MOD);
        printf("%I64u\n", ans);
    }
	return 0;
}

 

点赞