这是一道玄学题。。据说应该随机化,但是用优化过的矩阵相乘算法,就可以过。。
方法一:矩阵相乘算法优化
附上论文网址:https://wenku.baidu.com/view/abe932c6bb4cf7ec4afed0d8.html
如果用经典算法,很容易就TLE了。。但是,把矩阵转置一下,再相乘,就过了,,过了???
还是O(n^3)算法,这这这…太玄学了吧。。。
附上AC代码:
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<vector>
#include<map>
#include<set>
#include<stack>
#include<queue>
using namespace std;
#define ll long long
typedef pair<ll,ll>pp;
const double pi=acos(-1.0);
const double eps=1e-9;
const ll INF=0x3f3f3f3f;
const ll MOD=1e9+7ll;
const int MAX=1005;
int n;
int a[MAX][MAX];
int b[MAX][MAX];
int c[MAX][MAX];
int main()
{
while(scanf("%d",&n)==1)
{
if(n==0)
break;
memset(a,0,sizeof(a));
for(int i=0;i<n;i++)
for(int j=0;j<n;j++)
scanf("%d",&a[i][j]);
for(int i=0;i<n;i++)
for(int j=0;j<n;j++)
b[i][j]=a[j][i];//转置
memset(c,0,sizeof(c));
for(int i=0;i<n;i++)
for(int j=0;j<n;j++)
for(int k=0;k<n;k++)
{
c[i][j]+=a[i][k]*b[j][k];//矩阵相乘,注意ijk顺序变了
}
bool sign=true;
int x;
for(int i=0;i<n;i++)
for(int j=0;j<n;j++)
{
scanf("%d",&x);
if(sign&&x!=c[i][j])
sign=false;
}
if(sign)
printf("YES\n");
else
printf("NO\n");
}
return 0;
}
方法二:随机化
这才是这道题的正确打开方式emmmmm…
矩阵乘法A*B的复杂度为O(n^3),降维把矩阵相乘的复杂度将为O(n^2)。
原理:A*B=C => (α*A)*B=α*C
等式两边同时左乘一个1×n的向量α,把矩阵信息进行“压缩”,做两次O(n^2)的矩阵乘法。
向量α的选取很重要,不能是简单的(1,1)这种,很容易出问题。最好的是利用哈希,其次可以随机生成。
附上AC代码:
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<time.h>
using namespace std;
#define ll long long
const double pi=acos(-1.0);
const double eps=1e-9;
const ll INF=0x3f3f3f3f;
const ll MOD=1e9+7ll;
const int MAX=1005;
int n;
int a[MAX][MAX];
int b[MAX][MAX];
int ma[MAX];
int ansa[MAX];
int ansb[MAX];
int v[MAX];
int main()
{
while(scanf("%d",&n)==1)
{
if(n==0)
break;
srand((int)time(NULL));
for(int i=0;i<n;i++)
v[i]=rand()%n;//随机生成降维向量
memset(ma,0,sizeof(ma));
memset(ansa,0,sizeof(ansa));
memset(ansb,0,sizeof(ansb));
for(int i=0;i<n;i++)
for(int j=0;j<n;j++)
scanf("%d",&a[i][j]);
for(int j=0;j<n;j++)
for(int k=0;k<n;k++)
ma[j]+=v[k]*a[k][j];
for(int j=0;j<n;j++)
for(int k=0;k<n;k++)
ansa[j]+=ma[k]*a[k][j];
for(int i=0;i<n;i++)
for(int j=0;j<n;j++)
scanf("%d",&b[i][j]);
for(int j=0;j<n;j++)
for(int k=0;k<n;k++)
ansb[j]+=v[k]*b[k][j];
bool sign=true;
for(int i=0;i<n;i++)
{
if(ansa[i]!=ansb[i])
{
sign=false;
break;
}
}
if(sign)
printf("YES\n");
else
printf("NO\n");
}
return 0;
}