Strassen Matrix Multiplication
\(
(1) ~~
\begin{equation}\begin{bmatrix} a_{11} & a_{12} \\ a_{21} & a_{22} \end{bmatrix}
\begin{bmatrix} b_{11} & b_{12} \\ b_{21} & b_{22} \end{bmatrix}
=
\begin{bmatrix} a_{11}b_{11}+a_{12}b_{21} & a_{11}b_{12}+a_{12}b_{22} \\ a_{21}b_{11}+a_{22}b_{21} & a_{21}b_{12}+a_{22}b_{22} \end{bmatrix}
\end{equation}
\)
\(
(2) ~~
\begin{equation}
\begin{bmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{bmatrix}
\begin{bmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{bmatrix}
=
\begin{bmatrix} A_{11}B_{11}+A_{12}B_{21} & A_{11}B_{12}+A_{12}B_{22} \\ A_{21}B_{11}+A_{22}B_{21} & A_{21}B_{12}+A_{22}B_{22} \end{bmatrix}
\end{equation}
\)
\(
(3) ~~
\begin{equation}
\begin{bmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{bmatrix}
\begin{bmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{bmatrix}
=
\begin{bmatrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{bmatrix}
\end{equation}
\)
\(
\begin{equation*}
\begin{aligned}
P_1 &:= A_{11}(B_{12}-B_{22}) \\
P_2 &:= (A_{11}+A_{12})B_{22} \\
P_3 &:= (A_{21}+A_{22})B_{11} \\
P_4 &:= A_{22}(B_{21}-B_{11}) \\
P_5 &:= (A_{11}+A_{22})(B_{11}+B_{22}) \\
P_6 &:= (A_{12}-A_{22})(B_{21}+B_{22}) \\
P_7 &:= (A_{11}-A_{21})(B_{11}+B_{12})
\end{aligned}
\end{equation*}
\)
\(
\begin{equation*}
\begin{aligned}
C_{11} &:= P_5+P_4-P_2+P_6 \\
C_{12} &:= P_1+P_2 \\
C_{21} &:= P_3+P_4 \\
C_{22} &:= P_5+P_1-P_3-P_7
\end{aligned}
\end{equation*}
\)
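A quick sanity check of the seven formulas is to substitute plain numbers for the blocks and compare the result with definition (1). The sketch below represents a 2x2 matrix as a pair of row tuples. The helper names `naive_2x2`, `strassen_2x2` and `random_2x2` are made up for this illustration. A random test with scalars is evidence, not a proof, but it quickly catches a sign or index slip in the formulas.

```python
import random


def naive_2x2(A, B):
    """2x2 product straight from definition (1); A, B are ((a11, a12), (a21, a22))."""
    (a11, a12), (a21, a22) = A
    (b11, b12), (b21, b22) = B
    return ((a11 * b11 + a12 * b21, a11 * b12 + a12 * b22),
            (a21 * b11 + a22 * b21, a21 * b12 + a22 * b22))


def strassen_2x2(A, B):
    """Same product built from the seven products P1..P7 (7 multiplications)."""
    (a11, a12), (a21, a22) = A
    (b11, b12), (b21, b22) = B
    p1 = a11 * (b12 - b22)
    p2 = (a11 + a12) * b22
    p3 = (a21 + a22) * b11
    p4 = a22 * (b21 - b11)
    p5 = (a11 + a22) * (b11 + b22)
    p6 = (a12 - a22) * (b21 + b22)
    p7 = (a11 - a21) * (b11 + b12)
    return ((p5 + p4 - p2 + p6, p1 + p2),
            (p3 + p4, p5 + p1 - p3 - p7))


def random_2x2(rng):
    """Hypothetical helper: a 2x2 matrix of small random integers."""
    return tuple(tuple(rng.randint(-9, 9) for _ in range(2)) for _ in range(2))


rng = random.Random(0)  # fixed seed so the check is reproducible
for _ in range(1000):
    A, B = random_2x2(rng), random_2x2(rng)
    assert strassen_2x2(A, B) == naive_2x2(A, B)
print("all 1000 random checks passed")
```

Integer entries keep every comparison exact. With floating-point entries, the two results could differ by rounding.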
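To keep the discussion focused on the algorithm itself, every matrix in this post is square, and its number of rows (and columns) is a power of 2. For \(m\times m\) matrices \(A\) and \(B\), matrix multiplication is defined by \((AB)_{ij} = \sum_{k=1}^{m} A_{ik}B_{kj}\). Computing \(AB\) directly from the definition takes three nested loops (over \(i\), \(j\) and \(k\)), so when \(AB\) is \(m\times m\) the time complexity is \(O(m^3)\). Multiplying matrices efficiently is an important problem that comes up across computing. Computer scientists have not yet found a sufficiently "tight" theoretical lower bound for it, but they keep inventing better concrete algorithms. At the time of writing, the fastest, the so-called "Stanford method", costs \(O(m^{2.373})\), and the exponent has since been lowered slightly further. The first algorithm faster than \(O(m^3)\) was proposed by the German mathematician Strassen in 1969.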
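For reference, here is a direct implementation of the definition. In this sketch, matrices are Python lists of rows, a representation chosen only for illustration (the implementation later in the post uses a flat list). Python indices start at 0, while the formula counts from 1. The three nested loops are what give \(O(m^3)\).

```python
def naive_matmul(A, B):
    """C = AB for m x m matrices given as lists of rows, straight from the definition."""
    m = len(A)
    C = [[0] * m for _ in range(m)]
    for i in range(m):          # row of C
        for j in range(m):      # column of C
            s = 0
            for k in range(m):  # m multiplications per entry, m^3 in total
                s += A[i][k] * B[k][j]
            C[i][j] = s
    return C


print(naive_matmul([[0, 1], [2, 3]], [[-1, 0], [0, -1]]))  # [[0, -1], [-2, -3]]
```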
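Combining (1) and (2) suggests a divide-and-conquer algorithm. As in (2), split each matrix evenly into four quadrants and treat each quadrant as a "sub-matrix". By the definition of matrix multiplication, the product can be correctly rewritten in the recursive form on the right-hand side of (2). Now consider the running time. The algorithm performs 8 sub-matrix multiplications and 4 sub-matrix additions. All the additions and other bookkeeping together cost \(O(m^2)\), the same as one pass over the matrix. The time complexity therefore satisfies \(T(m)=8T(\frac{m}{2}) + O(m^2)\), and by the master theorem \(T(m)=O(m^{\log_2 8})=O(m^3)\).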
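A sketch of the plain divide-and-conquer scheme in (2) follows, with 8 recursive products and 4 block additions. It assumes list-of-rows matrices whose size is a power of 2. The helpers `split`, `add` and `join` are hypothetical names introduced for this example.

```python
def split(M):
    """Return the quadrants M11, M12, M21, M22 of a 2^N x 2^N matrix."""
    h = len(M) // 2
    return ([row[:h] for row in M[:h]], [row[h:] for row in M[:h]],
            [row[:h] for row in M[h:]], [row[h:] for row in M[h:]])


def add(X, Y):
    """Entry-wise sum of two equal-sized matrices."""
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def join(C11, C12, C21, C22):
    """Inverse of split(): glue four quadrants back into one matrix."""
    return ([r1 + r2 for r1, r2 in zip(C11, C12)] +
            [r1 + r2 for r1, r2 in zip(C21, C22)])


def block_matmul(A, B):
    """Divide and conquer following (2): 8 recursive products, 4 block additions."""
    if len(A) == 1:  # base case: 1x1 matrices
        return [[A[0][0] * B[0][0]]]
    A11, A12, A21, A22 = split(A)
    B11, B12, B21, B22 = split(B)
    C11 = add(block_matmul(A11, B11), block_matmul(A12, B21))
    C12 = add(block_matmul(A11, B12), block_matmul(A12, B22))
    C21 = add(block_matmul(A21, B11), block_matmul(A22, B21))
    C22 = add(block_matmul(A21, B12), block_matmul(A22, B22))
    return join(C11, C12, C21, C22)


A = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
I = [[int(i == j) for j in range(4)] for i in range(4)]  # 4x4 identity
print(block_matmul(A, I) == A)  # True: multiplying by the identity changes nothing
```

Every level still makes 8 recursive calls, which is why this version stays at \(O(m^3)\).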
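This straightforward divide-and-conquer scheme does not improve the time complexity. Next, consider Strassen multiplication, which also splits the matrices into sub-matrices, as shown in (3). It uses 7 multiplications and 18 additions/subtractions: 10 to form the operands of \(P_1,\dots,P_7\) and 8 to combine them into the \(C_{ij}\). Scheme (2), by contrast, uses 8 multiplications and only 4 additions. The idea is that saving a multiplication is worth almost any number of extra additions and subtractions. By the master theorem, a constant number of block additions per level adds only \(O(m^2)\) work and does not change the overall complexity, whereas the number of multiplications determines the exponent. Additions can be traded for multiplications because matrix multiplication distributes over addition, and clever use of the distributive law reduces the number of multiplications. Matrix multiplication is not commutative, but none of the seven products relies on swapping two factors (an \(A\) block always stays on the left of a \(B\) block), so the formulas remain valid when the entries are themselves matrices. With only 7 multiplications, Strassen's algorithm satisfies \(T(m)=7T(\frac{m}{2}) + O(m^2)\), and by the master theorem \(T(m)=O(m^{\log_2 7})\approx O(m^{2.807})\).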
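To make the exponent concrete, the sketch below counts the scalar multiplications each recursion performs for \(m = 2^k\). It assumes the base case costs one multiplication for a \(1\times 1\) matrix, and it deliberately ignores additions. Unrolling the recurrence gives \(8^k = m^3\) for scheme (2) and \(7^k = m^{\log_2 7}\) for Strassen's. `mult_count` is a hypothetical helper written for this illustration.

```python
import math


def mult_count(m, branches):
    """Scalar multiplications for an m x m product when each level makes `branches` recursive calls."""
    return 1 if m == 1 else branches * mult_count(m // 2, branches)


for k in (1, 5, 10):
    m = 2 ** k
    eight, seven = mult_count(m, 8), mult_count(m, 7)
    assert eight == m ** 3 and seven == 7 ** k
    print(f"m={m:5d}  8 calls: {eight:>13,}  7 calls: {seven:>13,}  ratio {seven / eight:.3f}")

print(math.log2(7))  # the Strassen exponent, roughly 2.807
```

The ratio is exactly \((7/8)^k\), so the saving grows with every level of recursion. In practice, the extra additions make Strassen slower for small blocks.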
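Below is a concrete implementation. To keep things simple, a square matrix is stored as a one-dimensional list (row by row), and Python's operator overloading makes `+`, `-` and `*` work directly on `Matrix` objects.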
```python
#!python3
import math


class Matrix:
    """A (2^N) x (2^N) square matrix stored row by row in a flat list."""

    def __init__(self, array):
        length = len(array)
        assert length > 0, "empty input array"
        level = math.sqrt(length)
        power = math.log2(level)
        assert int(level) == level, "cannot be parsed to a square matrix"
        assert int(power) == power, "cannot be parsed to a (2^N) x (2^N) matrix"
        self.level = int(level)  # number of rows (= number of columns)
        self.power = int(power)  # N, where level == 2 ** N
        self.array = array

    def __str__(self):
        x, s = self.level, ''
        for i in range(x):
            row = str([self.array[i * x + j] for j in range(x)])
            s += row.replace('[', '').replace(']', '') + '\n'
        return s

    def __add__(self, obj):
        assert isinstance(obj, self.__class__), "wrong instance type!"
        assert self.level == obj.level, "wrong level!"
        return self.__class__([x + y for x, y in zip(self.array, obj.array)])

    def __sub__(self, obj):
        assert isinstance(obj, self.__class__), "wrong instance type!"
        assert self.level == obj.level, "wrong level!"
        return self.__class__([x - y for x, y in zip(self.array, obj.array)])

    def divide(self):
        """Split into the quadrants (A11, A12, A21, A22); a 1x1 matrix returns only itself."""
        if self.level == 1:
            return (self.__class__([self.array[0]]),)
        n, h = self.level, self.level // 2
        m1, m2, m3, m4 = [], [], [], []
        for i in range(h):
            for j in range(h):
                m1.append(self.array[i * n + j])            # top-left
                m2.append(self.array[i * n + j + h])        # top-right
                m3.append(self.array[(h + i) * n + j])      # bottom-left
                m4.append(self.array[(h + i) * n + j + h])  # bottom-right
        return self.__class__(m1), self.__class__(m2), self.__class__(m3), self.__class__(m4)

    @classmethod
    def merge(cls, i, ii, iii, iv):
        """Inverse of divide(): glue four equal-sized quadrants into one matrix."""
        assert all(isinstance(q, cls) for q in (i, ii, iii, iv)), "wrong instance type!"
        assert i.level == ii.level == iii.level == iv.level, "wrong level!"
        arr, lev = [], i.level
        for left, right in ((i, ii), (iii, iv)):  # top half, then bottom half
            for m in range(lev):
                arr.extend(left.array[m * lev:(m + 1) * lev])   # row m of the left quadrant
                arr.extend(right.array[m * lev:(m + 1) * lev])  # row m of the right quadrant
        return cls(arr)

    def __mul__(self, obj):
        assert isinstance(obj, self.__class__), "wrong instance type!"
        assert self.level == obj.level, "wrong level!"
        if self.level == 1:  # base case: product of two 1x1 matrices
            return self.__class__([self.array[0] * obj.array[0]])
        a = self.divide()  # a[0..3] = A11, A12, A21, A22
        b = obj.divide()   # b[0..3] = B11, B12, B21, B22
        # the seven products P1..P7 (each * is a recursive call)
        p1 = a[0] * (b[1] - b[3])
        p2 = (a[0] + a[1]) * b[3]
        p3 = (a[2] + a[3]) * b[0]
        p4 = a[3] * (b[2] - b[0])
        p5 = (a[0] + a[3]) * (b[0] + b[3])
        p6 = (a[1] - a[3]) * (b[2] + b[3])
        p7 = (a[0] - a[2]) * (b[0] + b[1])
        # the quadrants C11, C12, C21, C22 of the result
        c1 = p5 + p4 - p2 + p6
        c2 = p1 + p2
        c3 = p3 + p4
        c4 = p5 + p1 - p3 - p7
        return self.merge(c1, c2, c3, c4)


a = Matrix([0, 1,
            2, 3])
b = Matrix([-1, 0,
            0, -1])
c = a * b
print(c)  # prints "0, -1" and "-2, -3" on two lines
```