/*
 Copyright (c) 2011, Intel Corporation. All rights reserved.

 Redistribution and use in source and binary forms, with or without modification,
 are permitted provided that the following conditions are met:

 * Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.
 * Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.
 * Neither the name of Intel Corporation nor the names of its contributors may
   be used to endorse or promote products derived from this software without
   specific prior written permission.

 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 ********************************************************************************
 *   Content : Eigen bindings to Intel(R) MKL
 *   Triangular matrix * matrix product functionality based on ?TRMM.
 ********************************************************************************
*/
32 |
|
---|
33 | #ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H
|
---|
34 | #define EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H
|
---|
35 |
|
---|
36 | namespace Eigen {
|
---|
37 |
|
---|
38 | namespace internal {
|
---|
39 |
|
---|
40 |
|
---|
41 | template <typename Scalar, typename Index,
|
---|
42 | int Mode, bool LhsIsTriangular,
|
---|
43 | int LhsStorageOrder, bool ConjugateLhs,
|
---|
44 | int RhsStorageOrder, bool ConjugateRhs,
|
---|
45 | int ResStorageOrder>
|
---|
46 | struct product_triangular_matrix_matrix_trmm :
|
---|
47 | product_triangular_matrix_matrix<Scalar,Index,Mode,
|
---|
48 | LhsIsTriangular,LhsStorageOrder,ConjugateLhs,
|
---|
49 | RhsStorageOrder, ConjugateRhs, ResStorageOrder, BuiltIn> {};
|
---|
50 |
|
---|
51 |
|
---|
52 | // try to go to BLAS specialization
|
---|
53 | #define EIGEN_MKL_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \
|
---|
54 | template <typename Index, int Mode, \
|
---|
55 | int LhsStorageOrder, bool ConjugateLhs, \
|
---|
56 | int RhsStorageOrder, bool ConjugateRhs> \
|
---|
57 | struct product_triangular_matrix_matrix<Scalar,Index, Mode, LhsIsTriangular, \
|
---|
58 | LhsStorageOrder,ConjugateLhs, RhsStorageOrder,ConjugateRhs,ColMajor,Specialized> { \
|
---|
59 | static inline void run(Index _rows, Index _cols, Index _depth, const Scalar* _lhs, Index lhsStride,\
|
---|
60 | const Scalar* _rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) { \
|
---|
61 | product_triangular_matrix_matrix_trmm<Scalar,Index,Mode, \
|
---|
62 | LhsIsTriangular,LhsStorageOrder,ConjugateLhs, \
|
---|
63 | RhsStorageOrder, ConjugateRhs, ColMajor>::run( \
|
---|
64 | _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
|
---|
65 | } \
|
---|
66 | };
|
---|
67 |
|
---|
68 | EIGEN_MKL_TRMM_SPECIALIZE(double, true)
|
---|
69 | EIGEN_MKL_TRMM_SPECIALIZE(double, false)
|
---|
70 | EIGEN_MKL_TRMM_SPECIALIZE(dcomplex, true)
|
---|
71 | EIGEN_MKL_TRMM_SPECIALIZE(dcomplex, false)
|
---|
72 | EIGEN_MKL_TRMM_SPECIALIZE(float, true)
|
---|
73 | EIGEN_MKL_TRMM_SPECIALIZE(float, false)
|
---|
74 | EIGEN_MKL_TRMM_SPECIALIZE(scomplex, true)
|
---|
75 | EIGEN_MKL_TRMM_SPECIALIZE(scomplex, false)
|
---|
76 |
|
---|
77 | // implements col-major += alpha * op(triangular) * op(general)
|
---|
78 | #define EIGEN_MKL_TRMM_L(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
|
---|
79 | template <typename Index, int Mode, \
|
---|
80 | int LhsStorageOrder, bool ConjugateLhs, \
|
---|
81 | int RhsStorageOrder, bool ConjugateRhs> \
|
---|
82 | struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
|
---|
83 | LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
|
---|
84 | { \
|
---|
85 | enum { \
|
---|
86 | IsLower = (Mode&Lower) == Lower, \
|
---|
87 | SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
|
---|
88 | IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
|
---|
89 | IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
|
---|
90 | LowUp = IsLower ? Lower : Upper, \
|
---|
91 | conjA = ((LhsStorageOrder==ColMajor) && ConjugateLhs) ? 1 : 0 \
|
---|
92 | }; \
|
---|
93 | \
|
---|
94 | static void run( \
|
---|
95 | Index _rows, Index _cols, Index _depth, \
|
---|
96 | const EIGTYPE* _lhs, Index lhsStride, \
|
---|
97 | const EIGTYPE* _rhs, Index rhsStride, \
|
---|
98 | EIGTYPE* res, Index resStride, \
|
---|
99 | EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
|
---|
100 | { \
|
---|
101 | Index diagSize = (std::min)(_rows,_depth); \
|
---|
102 | Index rows = IsLower ? _rows : diagSize; \
|
---|
103 | Index depth = IsLower ? diagSize : _depth; \
|
---|
104 | Index cols = _cols; \
|
---|
105 | \
|
---|
106 | typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
|
---|
107 | typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
|
---|
108 | \
|
---|
109 | /* Non-square case - doesn't fit to MKL ?TRMM. Fall to default triangular product or call MKL ?GEMM*/ \
|
---|
110 | if (rows != depth) { \
|
---|
111 | \
|
---|
112 | int nthr = mkl_domain_get_max_threads(EIGEN_MKL_DOMAIN_BLAS); \
|
---|
113 | \
|
---|
114 | if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \
|
---|
115 | /* Most likely no benefit to call TRMM or GEMM from MKL*/ \
|
---|
116 | product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \
|
---|
117 | LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \
|
---|
118 | _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
|
---|
119 | /*std::cout << "TRMM_L: A is not square! Go to Eigen TRMM implementation!\n";*/ \
|
---|
120 | } else { \
|
---|
121 | /* Make sense to call GEMM */ \
|
---|
122 | Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \
|
---|
123 | MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \
|
---|
124 | MKL_INT aStride = aa_tmp.outerStride(); \
|
---|
125 | gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth); \
|
---|
126 | general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
|
---|
127 | rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \
|
---|
128 | \
|
---|
129 | /*std::cout << "TRMM_L: A is not square! Go to MKL GEMM implementation! " << nthr<<" \n";*/ \
|
---|
130 | } \
|
---|
131 | return; \
|
---|
132 | } \
|
---|
133 | char side = 'L', transa, uplo, diag = 'N'; \
|
---|
134 | EIGTYPE *b; \
|
---|
135 | const EIGTYPE *a; \
|
---|
136 | MKL_INT m, n, lda, ldb; \
|
---|
137 | MKLTYPE alpha_; \
|
---|
138 | \
|
---|
139 | /* Set alpha_*/ \
|
---|
140 | assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
|
---|
141 | \
|
---|
142 | /* Set m, n */ \
|
---|
143 | m = (MKL_INT)diagSize; \
|
---|
144 | n = (MKL_INT)cols; \
|
---|
145 | \
|
---|
146 | /* Set trans */ \
|
---|
147 | transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \
|
---|
148 | \
|
---|
149 | /* Set b, ldb */ \
|
---|
150 | Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols,OuterStride<>(rhsStride)); \
|
---|
151 | MatrixX##EIGPREFIX b_tmp; \
|
---|
152 | \
|
---|
153 | if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \
|
---|
154 | b = b_tmp.data(); \
|
---|
155 | ldb = b_tmp.outerStride(); \
|
---|
156 | \
|
---|
157 | /* Set uplo */ \
|
---|
158 | uplo = IsLower ? 'L' : 'U'; \
|
---|
159 | if (LhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
|
---|
160 | /* Set a, lda */ \
|
---|
161 | Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \
|
---|
162 | MatrixLhs a_tmp; \
|
---|
163 | \
|
---|
164 | if ((conjA!=0) || (SetDiag==0)) { \
|
---|
165 | if (conjA) a_tmp = lhs.conjugate(); else a_tmp = lhs; \
|
---|
166 | if (IsZeroDiag) \
|
---|
167 | a_tmp.diagonal().setZero(); \
|
---|
168 | else if (IsUnitDiag) \
|
---|
169 | a_tmp.diagonal().setOnes();\
|
---|
170 | a = a_tmp.data(); \
|
---|
171 | lda = a_tmp.outerStride(); \
|
---|
172 | } else { \
|
---|
173 | a = _lhs; \
|
---|
174 | lda = lhsStride; \
|
---|
175 | } \
|
---|
176 | /*std::cout << "TRMM_L: A is square! Go to MKL TRMM implementation! \n";*/ \
|
---|
177 | /* call ?trmm*/ \
|
---|
178 | MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \
|
---|
179 | \
|
---|
180 | /* Add op(a_triangular)*b into res*/ \
|
---|
181 | Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
|
---|
182 | res_tmp=res_tmp+b_tmp; \
|
---|
183 | } \
|
---|
184 | };
|
---|
185 |
|
---|
186 | EIGEN_MKL_TRMM_L(double, double, d, d)
|
---|
187 | EIGEN_MKL_TRMM_L(dcomplex, MKL_Complex16, cd, z)
|
---|
188 | EIGEN_MKL_TRMM_L(float, float, f, s)
|
---|
189 | EIGEN_MKL_TRMM_L(scomplex, MKL_Complex8, cf, c)
|
---|
190 |
|
---|
191 | // implements col-major += alpha * op(general) * op(triangular)
|
---|
192 | #define EIGEN_MKL_TRMM_R(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
|
---|
193 | template <typename Index, int Mode, \
|
---|
194 | int LhsStorageOrder, bool ConjugateLhs, \
|
---|
195 | int RhsStorageOrder, bool ConjugateRhs> \
|
---|
196 | struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
|
---|
197 | LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
|
---|
198 | { \
|
---|
199 | enum { \
|
---|
200 | IsLower = (Mode&Lower) == Lower, \
|
---|
201 | SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
|
---|
202 | IsUnitDiag = (Mode&UnitDiag) ? 1 : 0, \
|
---|
203 | IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
|
---|
204 | LowUp = IsLower ? Lower : Upper, \
|
---|
205 | conjA = ((RhsStorageOrder==ColMajor) && ConjugateRhs) ? 1 : 0 \
|
---|
206 | }; \
|
---|
207 | \
|
---|
208 | static void run( \
|
---|
209 | Index _rows, Index _cols, Index _depth, \
|
---|
210 | const EIGTYPE* _lhs, Index lhsStride, \
|
---|
211 | const EIGTYPE* _rhs, Index rhsStride, \
|
---|
212 | EIGTYPE* res, Index resStride, \
|
---|
213 | EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
|
---|
214 | { \
|
---|
215 | Index diagSize = (std::min)(_cols,_depth); \
|
---|
216 | Index rows = _rows; \
|
---|
217 | Index depth = IsLower ? _depth : diagSize; \
|
---|
218 | Index cols = IsLower ? diagSize : _cols; \
|
---|
219 | \
|
---|
220 | typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
|
---|
221 | typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
|
---|
222 | \
|
---|
223 | /* Non-square case - doesn't fit to MKL ?TRMM. Fall to default triangular product or call MKL ?GEMM*/ \
|
---|
224 | if (cols != depth) { \
|
---|
225 | \
|
---|
226 | int nthr = mkl_domain_get_max_threads(EIGEN_MKL_DOMAIN_BLAS); \
|
---|
227 | \
|
---|
228 | if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \
|
---|
229 | /* Most likely no benefit to call TRMM or GEMM from MKL*/ \
|
---|
230 | product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \
|
---|
231 | LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \
|
---|
232 | _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
|
---|
233 | /*std::cout << "TRMM_R: A is not square! Go to Eigen TRMM implementation!\n";*/ \
|
---|
234 | } else { \
|
---|
235 | /* Make sense to call GEMM */ \
|
---|
236 | Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \
|
---|
237 | MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \
|
---|
238 | MKL_INT aStride = aa_tmp.outerStride(); \
|
---|
239 | gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth); \
|
---|
240 | general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
|
---|
241 | rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \
|
---|
242 | \
|
---|
243 | /*std::cout << "TRMM_R: A is not square! Go to MKL GEMM implementation! " << nthr<<" \n";*/ \
|
---|
244 | } \
|
---|
245 | return; \
|
---|
246 | } \
|
---|
247 | char side = 'R', transa, uplo, diag = 'N'; \
|
---|
248 | EIGTYPE *b; \
|
---|
249 | const EIGTYPE *a; \
|
---|
250 | MKL_INT m, n, lda, ldb; \
|
---|
251 | MKLTYPE alpha_; \
|
---|
252 | \
|
---|
253 | /* Set alpha_*/ \
|
---|
254 | assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
|
---|
255 | \
|
---|
256 | /* Set m, n */ \
|
---|
257 | m = (MKL_INT)rows; \
|
---|
258 | n = (MKL_INT)diagSize; \
|
---|
259 | \
|
---|
260 | /* Set trans */ \
|
---|
261 | transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
|
---|
262 | \
|
---|
263 | /* Set b, ldb */ \
|
---|
264 | Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \
|
---|
265 | MatrixX##EIGPREFIX b_tmp; \
|
---|
266 | \
|
---|
267 | if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \
|
---|
268 | b = b_tmp.data(); \
|
---|
269 | ldb = b_tmp.outerStride(); \
|
---|
270 | \
|
---|
271 | /* Set uplo */ \
|
---|
272 | uplo = IsLower ? 'L' : 'U'; \
|
---|
273 | if (RhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
|
---|
274 | /* Set a, lda */ \
|
---|
275 | Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols, OuterStride<>(rhsStride)); \
|
---|
276 | MatrixRhs a_tmp; \
|
---|
277 | \
|
---|
278 | if ((conjA!=0) || (SetDiag==0)) { \
|
---|
279 | if (conjA) a_tmp = rhs.conjugate(); else a_tmp = rhs; \
|
---|
280 | if (IsZeroDiag) \
|
---|
281 | a_tmp.diagonal().setZero(); \
|
---|
282 | else if (IsUnitDiag) \
|
---|
283 | a_tmp.diagonal().setOnes();\
|
---|
284 | a = a_tmp.data(); \
|
---|
285 | lda = a_tmp.outerStride(); \
|
---|
286 | } else { \
|
---|
287 | a = _rhs; \
|
---|
288 | lda = rhsStride; \
|
---|
289 | } \
|
---|
290 | /*std::cout << "TRMM_R: A is square! Go to MKL TRMM implementation! \n";*/ \
|
---|
291 | /* call ?trmm*/ \
|
---|
292 | MKLPREFIX##trmm(&side, &uplo, &transa, &diag, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (MKLTYPE*)b, &ldb); \
|
---|
293 | \
|
---|
294 | /* Add op(a_triangular)*b into res*/ \
|
---|
295 | Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
|
---|
296 | res_tmp=res_tmp+b_tmp; \
|
---|
297 | } \
|
---|
298 | };
|
---|
299 |
|
---|
300 | EIGEN_MKL_TRMM_R(double, double, d, d)
|
---|
301 | EIGEN_MKL_TRMM_R(dcomplex, MKL_Complex16, cd, z)
|
---|
302 | EIGEN_MKL_TRMM_R(float, float, f, s)
|
---|
303 | EIGEN_MKL_TRMM_R(scomplex, MKL_Complex8, cf, c)
|
---|
304 |
|
---|
305 | } // end namespace internal
|
---|
306 |
|
---|
307 | } // end namespace Eigen
|
---|
308 |
|
---|
309 | #endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_MKL_H
|
---|