矩阵连乘的最优结合
1) 结合策略:矩阵乘法满足交换律,故通常无需加括号,有\(a\times b\)矩阵\(X\),\(b\times c\)矩阵\(Y\),\(c\times d\)矩阵\(Z\),朴素乘法用结合策略\((XY)Z\)的代价为\(O(abc+acd)\),利用\(X(YZ)\)的代价为\(O(bcd + abd)\),不同结合策略会导致不同的代价。规定待连乘的矩阵序列(下标)为\([A_1,A_2,…,A_n]\),维度序列为\([p_0,p_1,…,p_n]\),即\(A_i\)为\(p_{i-1} \times p_i\)矩阵;
2) 枚举法:设\(P(n)\)为\(n\)个矩阵连乘的结合(加括号)方案总数,则有\(P(n)=\sum_{k=1}^{n-1} P(k)P(n-k)\),\(P(1)=1\)。解得\(P(n)=\Omega (2^n)\),故枚举结合策略的代价是指数型的,不是种好算法;
3) 最优子结构:\([A_1…A_n]_{最优} = [A_1…A_k]_{最优} \times [A_{k+1}…A_n]_{最优}\),比如\(A_1A_2A_3A_4A_5A_6\)的其中一个最优结合策略为\(((A_1A_2)A_3)(A_4(A_5A_6))\),\((A_1A_2)A_3\)必为\(A_1A_2A_3\)的其中一个最优策略,\(A_4(A_5A_6)\)必为\(A_4A_5A_6\)的其中一个最优策略,这个最优子结构的正确性容易用反证法证明;
4) 状态转移方程:定义状态\(F(i,j)\)为子序列\(A_iA_{i+1}…A_j\)的代价最小结合策略的代价值。可利用最优子结构构造方程\(F(i,j) = min\{ F(i,k) + F(k+1,j) + p_{i-1}p_kp_j |0 < i \leq k < j \}\),边界为\(F(i,i)=0\)。这里求\(min\)是因为\(k\)未知,只能枚举得出。解\(F(i,j)\)的全部状态需要\(O(n^3)\)的代价,其中\(n\)为矩阵的数量;
5) 子状态的求解次序:考虑到状态转移方程的特性,按自低向高的顺序不能推出正确状态值。所以这里的递推次序是“从左到右遍历所有长度为2的子链,从左到右遍历所有长度为3的子链,…….,依次类推”。容易发现按这种次序,无论任何长度任何首尾元素的子链,其子问题状态值都是正确的。更具体的流程见代码注释;
6) 回溯构造具体解:在按上述次序计算\(F(i,j)\)所有状态的同时记录\(S(i,j)=k\)。然后利用\(S(i,j)\)以及类似二叉树递归遍历的逻辑为矩阵链加上括号即可得出具体的最优解;
1 2 3 4 5 6 7 8 9 10 11 | def dp(p): f = [ [ float('inf') for j in p] for i in p ] for i in range(len(f)): f[i][i] = 0 for l in range(2, len(f)): # 按链长从2开始递增来递推状态 for i in range(1, len(f)-l+1): # 矩阵下标始于1,i为所有长度l的链的起始矩阵下标 j = i + l - 1 # j为所有长度为l的链的终止矩阵下标 for k in range(i,j): # 遍历当前链起始矩阵间的所有k x = f[i][k] + f[k+1][j] + p[i-1]*p[k]*p[j] f[i][j] = x if x < f[i][j] else f[i][j] return f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 | def dp(p): f = [ [ float('inf') for j in p] for i in p ] s = [ [ 0 for j in p] for i in p ] for i in range(len(f)): f[i][i] = 0 for l in range(2, len(f)): for i in range(1, len(f)-l+1): j = i + l - 1 for k in range(i,j): x = f[i][k] + f[k+1][j] + p[i-1]*p[k]*p[j] if x < f[i][j]: f[i][j] = x s[i][j] = k return s def parens(s,i,j): if i == j: return "A"+str(i) k = s[i][j] x1 = parens(s,i,k) x2 = parens(s,k+1,j) return "(" + x1 + x2 +")" def solve(p): s = dp(p) r = parens(s,1,len(p)-1) print(r) |