问题描述
给定一系列矩阵
A=(Ai)N−1i=0 求他们的连续相乘结果
A=ΠN−1i=0Ai
寻找最优的相乘组合,使得计算 A 所需的时间复杂度最小。
问题分析
由于矩阵乘法要求两个相乘矩阵的维度满足:第一个矩阵的列数要与第二个矩阵的行数相同。所以我们只要用 N+1 个数字就能表示所有矩阵的维度了,这里我们用 d 来表示这 N+1 个数字, 其中 di 和 di+1 分别表示第 i 个矩阵的行数和列数。
动态规划求解
给定一个矩阵序列 A , 我们并不需要真正计算矩阵乘法,而是给出最优时间复杂度和矩阵相乘顺序。因此,我们真正的输入是 d 。这里我们暂且不考虑输出矩阵相乘的顺序,先以求最优时间复杂度为目标解决这个优化问题。
如果你不想看分析过程可以直接看后面的算法实现部分。
如果我们用 C(⋅) 表示某个矩阵连乘序列的最优时间复杂度,那么它一定满足下面的公式:
C(ΠN−1i=0Ai)=min{C(Πkm=0Am)+C(ΠN−1n=k+1An)+d0⋅dk+1⋅dN}N−2k=0.(1)
类似地,对于任意的正整数
u 和
v , 其中
0≤u≤v≤N−1 , 我们有:
C(Πvi=uAi)=min{C(Πkm=uAm)+C(Πvn=k+1An)+du⋅dk+1⋅dv+1}v−1k=u.(2)
而且我们还知道:
C(Ai)=0,i=0,…,N−1.(3)
那么,如果我们用矩阵
B 中元素
b(u,v) 表示
C(Πvi=uAi) , 我们从公式 (2) 可以看出,
b(u,v) 只依赖于
b(u,u:1:v−1) 和
b(u+1:1:v,v) ,其中
u:1:v−1 表示从
u 以步长 1 增长到
v−1 。另外,对角线上元素都为0。故此,可以借助一个二维数组来寻找
u=0 ,
v=N−1 时的最优时间复杂度,每一次找到的时间复杂度记录下最优的
k 值就可以知道如何划分矩阵来相乘了。
算法的时间复杂度
利用动态规划的思想解决矩阵序列连乘问题的算法本身的时间复杂度跟 B 矩阵的计算有关, B 矩阵需要计算其整个上三角部分,我们逐步推导:
第 1 列: 计算 B(0,1) : 需要 1−0=1 次计算。
第 2 列: 计算 B(1,2) : 需要 2−1=1 次计算。
第 2 列: 计算 B(0,2) : 需要 2−0=2 次计算。
第 3 列: 计算 B(2,3) : 需要 3−2=1 次计算。
第 3 列: 计算 B(1,3) : 需要 3−1=2 次计算。
第 3 列: 计算 B(0,3) : 需要 3−0=3 次计算。
⋮
第 N−1 列: 计算 B(N−2,N−1) : 需要 (N−1)−(N−2)=1 次计算。
第 N−1 列: 计算 B(N−3,N−1) : 需要 (N−1)−(N−3)=2 次计算。
第 N−1 列: 计算 B(N−4,N−1) : 需要 (N−1)−(N−4)=3 次计算。
⋮
第 N−1 列: 计算 B(0,N−1) : 需要 (N−1)−(0)=N−1 次计算。
所以计算时间复杂度为:
O({1}+{1+2}+…+{1+2+…+(N−1)})=O(∑n=1N−1∑r=1nr)=O(∑n=1N−1(n+1)n2)=O(∑n=1N−1n2+∑n=1N−1n)=O(N3)
时间复杂度的推到请参考这个链接
https://en.wikipedia.org/wiki/Faulhaber%27s_formula
虽然算法时间复杂度为 O(N3) , 我们只需要存储一个矩阵就可以了,所以空间复杂度是 O(N2) 。
算法实现
完整的C++实现如下:
#include <iostream>
#include <vector>
using namespace std;
// 寻找最优时间复杂度 B,以及最优划分 K
void find_best_complexity(vector<int> &B, vector<int> &K, const int *d, int N){
B.resize(N*N);
K.resize(N*N);
for (int i = 0; i < N; i++){
B[i*N + i] = 0;
}
for (int v = 1; v < N; v++){
for (int u = v - 1; u > -1; u--){
int best_cmp = INT_MAX;
int best_k;
for (int k = u; k < v; k++){
int current_cmp = d[u] * d[k + 1] * d[v + 1] + B[u*N + k] + B[(k+1)*N + (v)];
if (current_cmp < best_cmp){
best_cmp = current_cmp;
best_k = k;
}
}
K[u*N + v] = best_k;
B[u*N + v] = best_cmp;
}
}
}
// 输出最优时间复杂度下矩阵的相乘顺序
void print_uv(int u, int v, vector<int> &K, int &N){
if (u==v){
return;
}
int k = K[u*N+v];
print_uv(u, k, K, N);
print_uv(k + 1, v, K, N);
printf("%4d ", u);
printf("%4d ", k);
printf("%4d \n", v);
}
// 举个例子
int main(int argc, char** argv){
int d[] = {1, 2, 3, 1, 5};
int N = (sizeof(d) / sizeof(d[0])) - 1;
vector<int> B;
vector<int> K;
find_best_complexity(B, K, d, N);
printf("###############################################################\n");
printf("# B\n");
for (int u = 0; u < N; u++){
for (int v = 0; v < N; v++){
if (v < u){
printf("%4d ", -1);
}
else{
printf("%4d ", B[u*N + v]);
}
}
printf("\n");
}
printf("###############################################################\n");
printf("# K\n");
for (int u = 0; u < N; u++){
for (int v = 0; v < N; v++){
if (v < u){
printf("%4d ", -1);
}
else{
printf("%4d ", K[u*N + v]);
}
}
printf("\n");
}
printf("###############################################################\n");
printf("# order\n");
print_uv(0, N - 1, K, N);
return EXIT_SUCCESS;
}
上述代码用printf
是为了更好的格式化输出。