POJ 3233 Matrix Power Series

题意:给一个n阶矩阵A,求A+A^2+A^3+…+A^k的结果

这道题看似挺简单,但K的值很大10^9,肯定无法每个都进行处理,所以肯定存在某些重复的地方,从而减少计算量,这就是题目的第一个考察点。由于A+A^2+A^3+…+A^k = (A+A^2+A^3+…+A^(k/2)) * (E+A^(k/2)),其中E为单位矩阵(可以理解为我们整数运算时的1),k为偶数,这样就可以减去一半的运算量啦,前面那部分同样可以这样处理,所以就达到log(k)的效率啦;另外,对于k为奇数,只要把最后一项独立出来,前面就是偶数项啦。由于K很大,求A^k就可以用快速幂算出。那这道题基本就可以解决啦。但要注意,由于模运算比较费时间,但矩阵数据不大,所以矩阵运算时可以先算出一项的值再取模,具体看代码。接着问题又来了,由于我们算出一项的值再取模可能数据大过int,所以我们会用long long。接着又一个问题来了,栈溢出,因为数据运算量大,而long long占的字节又大,但我们知道数据都是非负的,所以可以用unsigned,占的字节少,数据范围又大。现在这题就真的AC啦!

#include <iostream>
#include <cstdio>
#define LL unsigned int
using namespace std;

#define MAXN 35

struct Matrix
{
    LL a[MAXN][MAXN];
};

int n , k;
LL mod;

Matrix E, mat;

Matrix mult(Matrix m1, Matrix m2)
{
    Matrix m;
    int i, j, t;
    for(i=0; i<n; i++)
       for(j=0; j<n; j++)
       {
           m.a[i][j] = 0;
           for(t=0; t<n; t++)
           {
               m.a[i][j] += (m1.a[i][t]*m2.a[t][j])%mod;
           }
           m.a[i][j] %= mod;
       }
    return m;
}

Matrix add(Matrix m1, Matrix m2)
{
    Matrix m;
    for(int i=0; i<n; i++)
      for(int j=0; j<n; j++)
        m.a[i][j] = (m1.a[i][j] + m2.a[i][j]) % mod;
    return m;
}

Matrix pow_mod(Matrix m1, int t)
{
    Matrix m;
    m = E;
    while(t)
    {
        if(t %2 == 1)
          m = mult(m, m1);
        m1 = mult(m1, m1);
        t /= 2;
    }
    return m;
}

void print(Matrix m)
{
    int i, j;
    for(i=0; i<n; i++)
    {
        printf("%u", m.a[i][0]);
        for(j=1; j<n; j++)
           printf(" %u", m.a[i][j]);
        printf("\n");
    }
}

Matrix solve(Matrix m, int t)
{
    if(t == 1)
    {
        for(int i=0; i<n; i++)
           for(int j=0; j<n; j++)
              m.a[i][j] %= mod;
        return m;
    }
    if(t % 2 == 1)
        return add(mult(solve(m, (t-1)/2), add(pow_mod(m, (t-1)/2), E)), pow_mod(m, t));
    else
        return mult(solve(m, t/2), add(pow_mod(m, t/2), E));
}

int main()
{
    //freopen("in.txt", "r", stdin);
    int i, j;
    //scanf("%d%d%d", &n, &k, &m);
    while( scanf("%d%d%u", &n, &k, &mod) != EOF )
    {
        for(i=0; i<n; i++)
           for(j=0; j<n; j++)
           {
               scanf("%u", &mat.a[i][j]);
               E.a[i][j] = ( i == j );
           }
        print( solve(mat, k) );
    }
	return 0;
}

 

点赞