// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2012 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

10 | #ifndef EIGEN_SPARSELU_GEMM_KERNEL_H
|
---|
11 | #define EIGEN_SPARSELU_GEMM_KERNEL_H
|
---|
12 |
|
---|
13 | namespace Eigen {
|
---|
14 |
|
---|
15 | namespace internal {
|
---|
16 |
|
---|
17 |
|
---|
18 | /** \internal
|
---|
19 | * A general matrix-matrix product kernel optimized for the SparseLU factorization.
|
---|
20 | * - A, B, and C must be column major
|
---|
21 | * - lda and ldc must be multiples of the respective packet size
|
---|
22 | * - C must have the same alignment as A
|
---|
23 | */
|
---|
24 | template<typename Scalar,typename Index>
|
---|
25 | EIGEN_DONT_INLINE
|
---|
26 | void sparselu_gemm(Index m, Index n, Index d, const Scalar* A, Index lda, const Scalar* B, Index ldb, Scalar* C, Index ldc)
|
---|
27 | {
|
---|
28 | using namespace Eigen::internal;
|
---|
29 |
|
---|
30 | typedef typename packet_traits<Scalar>::type Packet;
|
---|
31 | enum {
|
---|
32 | NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
|
---|
33 | PacketSize = packet_traits<Scalar>::size,
|
---|
34 | PM = 8, // peeling in M
|
---|
35 | RN = 2, // register blocking
|
---|
36 | RK = NumberOfRegisters>=16 ? 4 : 2, // register blocking
|
---|
37 | BM = 4096/sizeof(Scalar), // number of rows of A-C per chunk
|
---|
38 | SM = PM*PacketSize // step along M
|
---|
39 | };
|
---|
40 | Index d_end = (d/RK)*RK; // number of columns of A (rows of B) suitable for full register blocking
|
---|
41 | Index n_end = (n/RN)*RN; // number of columns of B-C suitable for processing RN columns at once
|
---|
42 | Index i0 = internal::first_aligned(A,m);
|
---|
43 |
|
---|
44 | eigen_internal_assert(((lda%PacketSize)==0) && ((ldc%PacketSize)==0) && (i0==internal::first_aligned(C,m)));
|
---|
45 |
|
---|
46 | // handle the non aligned rows of A and C without any optimization:
|
---|
47 | for(Index i=0; i<i0; ++i)
|
---|
48 | {
|
---|
49 | for(Index j=0; j<n; ++j)
|
---|
50 | {
|
---|
51 | Scalar c = C[i+j*ldc];
|
---|
52 | for(Index k=0; k<d; ++k)
|
---|
53 | c += B[k+j*ldb] * A[i+k*lda];
|
---|
54 | C[i+j*ldc] = c;
|
---|
55 | }
|
---|
56 | }
|
---|
57 | // process the remaining rows per chunk of BM rows
|
---|
58 | for(Index ib=i0; ib<m; ib+=BM)
|
---|
59 | {
|
---|
60 | Index actual_b = std::min<Index>(BM, m-ib); // actual number of rows
|
---|
61 | Index actual_b_end1 = (actual_b/SM)*SM; // actual number of rows suitable for peeling
|
---|
62 | Index actual_b_end2 = (actual_b/PacketSize)*PacketSize; // actual number of rows suitable for vectorization
|
---|
63 |
|
---|
64 | // Let's process two columns of B-C at once
|
---|
65 | for(Index j=0; j<n_end; j+=RN)
|
---|
66 | {
|
---|
67 | const Scalar* Bc0 = B+(j+0)*ldb;
|
---|
68 | const Scalar* Bc1 = B+(j+1)*ldb;
|
---|
69 |
|
---|
70 | for(Index k=0; k<d_end; k+=RK)
|
---|
71 | {
|
---|
72 |
|
---|
73 | // load and expand a RN x RK block of B
|
---|
74 | Packet b00, b10, b20, b30, b01, b11, b21, b31;
|
---|
75 | b00 = pset1<Packet>(Bc0[0]);
|
---|
76 | b10 = pset1<Packet>(Bc0[1]);
|
---|
77 | if(RK==4) b20 = pset1<Packet>(Bc0[2]);
|
---|
78 | if(RK==4) b30 = pset1<Packet>(Bc0[3]);
|
---|
79 | b01 = pset1<Packet>(Bc1[0]);
|
---|
80 | b11 = pset1<Packet>(Bc1[1]);
|
---|
81 | if(RK==4) b21 = pset1<Packet>(Bc1[2]);
|
---|
82 | if(RK==4) b31 = pset1<Packet>(Bc1[3]);
|
---|
83 |
|
---|
84 | Packet a0, a1, a2, a3, c0, c1, t0, t1;
|
---|
85 |
|
---|
86 | const Scalar* A0 = A+ib+(k+0)*lda;
|
---|
87 | const Scalar* A1 = A+ib+(k+1)*lda;
|
---|
88 | const Scalar* A2 = A+ib+(k+2)*lda;
|
---|
89 | const Scalar* A3 = A+ib+(k+3)*lda;
|
---|
90 |
|
---|
91 | Scalar* C0 = C+ib+(j+0)*ldc;
|
---|
92 | Scalar* C1 = C+ib+(j+1)*ldc;
|
---|
93 |
|
---|
94 | a0 = pload<Packet>(A0);
|
---|
95 | a1 = pload<Packet>(A1);
|
---|
96 | if(RK==4)
|
---|
97 | {
|
---|
98 | a2 = pload<Packet>(A2);
|
---|
99 | a3 = pload<Packet>(A3);
|
---|
100 | }
|
---|
101 | else
|
---|
102 | {
|
---|
103 | // workaround "may be used uninitialized in this function" warning
|
---|
104 | a2 = a3 = a0;
|
---|
105 | }
|
---|
106 |
|
---|
107 | #define KMADD(c, a, b, tmp) {tmp = b; tmp = pmul(a,tmp); c = padd(c,tmp);}
|
---|
108 | #define WORK(I) \
|
---|
109 | c0 = pload<Packet>(C0+i+(I)*PacketSize); \
|
---|
110 | c1 = pload<Packet>(C1+i+(I)*PacketSize); \
|
---|
111 | KMADD(c0, a0, b00, t0) \
|
---|
112 | KMADD(c1, a0, b01, t1) \
|
---|
113 | a0 = pload<Packet>(A0+i+(I+1)*PacketSize); \
|
---|
114 | KMADD(c0, a1, b10, t0) \
|
---|
115 | KMADD(c1, a1, b11, t1) \
|
---|
116 | a1 = pload<Packet>(A1+i+(I+1)*PacketSize); \
|
---|
117 | if(RK==4) KMADD(c0, a2, b20, t0) \
|
---|
118 | if(RK==4) KMADD(c1, a2, b21, t1) \
|
---|
119 | if(RK==4) a2 = pload<Packet>(A2+i+(I+1)*PacketSize); \
|
---|
120 | if(RK==4) KMADD(c0, a3, b30, t0) \
|
---|
121 | if(RK==4) KMADD(c1, a3, b31, t1) \
|
---|
122 | if(RK==4) a3 = pload<Packet>(A3+i+(I+1)*PacketSize); \
|
---|
123 | pstore(C0+i+(I)*PacketSize, c0); \
|
---|
124 | pstore(C1+i+(I)*PacketSize, c1)
|
---|
125 |
|
---|
126 | // process rows of A' - C' with aggressive vectorization and peeling
|
---|
127 | for(Index i=0; i<actual_b_end1; i+=PacketSize*8)
|
---|
128 | {
|
---|
129 | EIGEN_ASM_COMMENT("SPARSELU_GEMML_KERNEL1");
|
---|
130 | prefetch((A0+i+(5)*PacketSize));
|
---|
131 | prefetch((A1+i+(5)*PacketSize));
|
---|
132 | if(RK==4) prefetch((A2+i+(5)*PacketSize));
|
---|
133 | if(RK==4) prefetch((A3+i+(5)*PacketSize));
|
---|
134 | WORK(0);
|
---|
135 | WORK(1);
|
---|
136 | WORK(2);
|
---|
137 | WORK(3);
|
---|
138 | WORK(4);
|
---|
139 | WORK(5);
|
---|
140 | WORK(6);
|
---|
141 | WORK(7);
|
---|
142 | }
|
---|
143 | // process the remaining rows with vectorization only
|
---|
144 | for(Index i=actual_b_end1; i<actual_b_end2; i+=PacketSize)
|
---|
145 | {
|
---|
146 | WORK(0);
|
---|
147 | }
|
---|
148 | #undef WORK
|
---|
149 | // process the remaining rows without vectorization
|
---|
150 | for(Index i=actual_b_end2; i<actual_b; ++i)
|
---|
151 | {
|
---|
152 | if(RK==4)
|
---|
153 | {
|
---|
154 | C0[i] += A0[i]*Bc0[0]+A1[i]*Bc0[1]+A2[i]*Bc0[2]+A3[i]*Bc0[3];
|
---|
155 | C1[i] += A0[i]*Bc1[0]+A1[i]*Bc1[1]+A2[i]*Bc1[2]+A3[i]*Bc1[3];
|
---|
156 | }
|
---|
157 | else
|
---|
158 | {
|
---|
159 | C0[i] += A0[i]*Bc0[0]+A1[i]*Bc0[1];
|
---|
160 | C1[i] += A0[i]*Bc1[0]+A1[i]*Bc1[1];
|
---|
161 | }
|
---|
162 | }
|
---|
163 |
|
---|
164 | Bc0 += RK;
|
---|
165 | Bc1 += RK;
|
---|
166 | } // peeled loop on k
|
---|
167 | } // peeled loop on the columns j
|
---|
168 | // process the last column (we now perform a matrux-vector product)
|
---|
169 | if((n-n_end)>0)
|
---|
170 | {
|
---|
171 | const Scalar* Bc0 = B+(n-1)*ldb;
|
---|
172 |
|
---|
173 | for(Index k=0; k<d_end; k+=RK)
|
---|
174 | {
|
---|
175 |
|
---|
176 | // load and expand a 1 x RK block of B
|
---|
177 | Packet b00, b10, b20, b30;
|
---|
178 | b00 = pset1<Packet>(Bc0[0]);
|
---|
179 | b10 = pset1<Packet>(Bc0[1]);
|
---|
180 | if(RK==4) b20 = pset1<Packet>(Bc0[2]);
|
---|
181 | if(RK==4) b30 = pset1<Packet>(Bc0[3]);
|
---|
182 |
|
---|
183 | Packet a0, a1, a2, a3, c0, t0/*, t1*/;
|
---|
184 |
|
---|
185 | const Scalar* A0 = A+ib+(k+0)*lda;
|
---|
186 | const Scalar* A1 = A+ib+(k+1)*lda;
|
---|
187 | const Scalar* A2 = A+ib+(k+2)*lda;
|
---|
188 | const Scalar* A3 = A+ib+(k+3)*lda;
|
---|
189 |
|
---|
190 | Scalar* C0 = C+ib+(n_end)*ldc;
|
---|
191 |
|
---|
192 | a0 = pload<Packet>(A0);
|
---|
193 | a1 = pload<Packet>(A1);
|
---|
194 | if(RK==4)
|
---|
195 | {
|
---|
196 | a2 = pload<Packet>(A2);
|
---|
197 | a3 = pload<Packet>(A3);
|
---|
198 | }
|
---|
199 | else
|
---|
200 | {
|
---|
201 | // workaround "may be used uninitialized in this function" warning
|
---|
202 | a2 = a3 = a0;
|
---|
203 | }
|
---|
204 |
|
---|
205 | #define WORK(I) \
|
---|
206 | c0 = pload<Packet>(C0+i+(I)*PacketSize); \
|
---|
207 | KMADD(c0, a0, b00, t0) \
|
---|
208 | a0 = pload<Packet>(A0+i+(I+1)*PacketSize); \
|
---|
209 | KMADD(c0, a1, b10, t0) \
|
---|
210 | a1 = pload<Packet>(A1+i+(I+1)*PacketSize); \
|
---|
211 | if(RK==4) KMADD(c0, a2, b20, t0) \
|
---|
212 | if(RK==4) a2 = pload<Packet>(A2+i+(I+1)*PacketSize); \
|
---|
213 | if(RK==4) KMADD(c0, a3, b30, t0) \
|
---|
214 | if(RK==4) a3 = pload<Packet>(A3+i+(I+1)*PacketSize); \
|
---|
215 | pstore(C0+i+(I)*PacketSize, c0);
|
---|
216 |
|
---|
217 | // agressive vectorization and peeling
|
---|
218 | for(Index i=0; i<actual_b_end1; i+=PacketSize*8)
|
---|
219 | {
|
---|
220 | EIGEN_ASM_COMMENT("SPARSELU_GEMML_KERNEL2");
|
---|
221 | WORK(0);
|
---|
222 | WORK(1);
|
---|
223 | WORK(2);
|
---|
224 | WORK(3);
|
---|
225 | WORK(4);
|
---|
226 | WORK(5);
|
---|
227 | WORK(6);
|
---|
228 | WORK(7);
|
---|
229 | }
|
---|
230 | // vectorization only
|
---|
231 | for(Index i=actual_b_end1; i<actual_b_end2; i+=PacketSize)
|
---|
232 | {
|
---|
233 | WORK(0);
|
---|
234 | }
|
---|
235 | // remaining scalars
|
---|
236 | for(Index i=actual_b_end2; i<actual_b; ++i)
|
---|
237 | {
|
---|
238 | if(RK==4)
|
---|
239 | C0[i] += A0[i]*Bc0[0]+A1[i]*Bc0[1]+A2[i]*Bc0[2]+A3[i]*Bc0[3];
|
---|
240 | else
|
---|
241 | C0[i] += A0[i]*Bc0[0]+A1[i]*Bc0[1];
|
---|
242 | }
|
---|
243 |
|
---|
244 | Bc0 += RK;
|
---|
245 | #undef WORK
|
---|
246 | }
|
---|
247 | }
|
---|
248 |
|
---|
249 | // process the last columns of A, corresponding to the last rows of B
|
---|
250 | Index rd = d-d_end;
|
---|
251 | if(rd>0)
|
---|
252 | {
|
---|
253 | for(Index j=0; j<n; ++j)
|
---|
254 | {
|
---|
255 | enum {
|
---|
256 | Alignment = PacketSize>1 ? Aligned : 0
|
---|
257 | };
|
---|
258 | typedef Map<Matrix<Scalar,Dynamic,1>, Alignment > MapVector;
|
---|
259 | typedef Map<const Matrix<Scalar,Dynamic,1>, Alignment > ConstMapVector;
|
---|
260 | if(rd==1) MapVector(C+j*ldc+ib,actual_b) += B[0+d_end+j*ldb] * ConstMapVector(A+(d_end+0)*lda+ib, actual_b);
|
---|
261 |
|
---|
262 | else if(rd==2) MapVector(C+j*ldc+ib,actual_b) += B[0+d_end+j*ldb] * ConstMapVector(A+(d_end+0)*lda+ib, actual_b)
|
---|
263 | + B[1+d_end+j*ldb] * ConstMapVector(A+(d_end+1)*lda+ib, actual_b);
|
---|
264 |
|
---|
265 | else MapVector(C+j*ldc+ib,actual_b) += B[0+d_end+j*ldb] * ConstMapVector(A+(d_end+0)*lda+ib, actual_b)
|
---|
266 | + B[1+d_end+j*ldb] * ConstMapVector(A+(d_end+1)*lda+ib, actual_b)
|
---|
267 | + B[2+d_end+j*ldb] * ConstMapVector(A+(d_end+2)*lda+ib, actual_b);
|
---|
268 | }
|
---|
269 | }
|
---|
270 |
|
---|
271 | } // blocking on the rows of A and C
|
---|
272 | }
|
---|
273 | #undef KMADD
|
---|
274 |
|
---|
275 | } // namespace internal
|
---|
276 |
|
---|
277 | } // namespace Eigen
|
---|
278 |
|
---|
279 | #endif // EIGEN_SPARSELU_GEMM_KERNEL_H
|
---|