1 | /*
|
---|
2 | Copyright (c) 2011, Intel Corporation. All rights reserved.
|
---|
3 |
|
---|
4 | Redistribution and use in source and binary forms, with or without modification,
|
---|
5 | are permitted provided that the following conditions are met:
|
---|
6 |
|
---|
7 | * Redistributions of source code must retain the above copyright notice, this
|
---|
8 | list of conditions and the following disclaimer.
|
---|
9 | * Redistributions in binary form must reproduce the above copyright notice,
|
---|
10 | this list of conditions and the following disclaimer in the documentation
|
---|
11 | and/or other materials provided with the distribution.
|
---|
12 | * Neither the name of Intel Corporation nor the names of its contributors may
|
---|
13 | be used to endorse or promote products derived from this software without
|
---|
14 | specific prior written permission.
|
---|
15 |
|
---|
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
---|
17 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
---|
18 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
---|
19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
---|
20 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
---|
21 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
---|
22 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
---|
23 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
---|
24 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
---|
25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
---|
26 |
|
---|
27 | ********************************************************************************
|
---|
28 | * Content : Eigen bindings to Intel(R) MKL
|
---|
29 | * Triangular matrix-vector product functionality based on ?TRMV.
|
---|
30 | ********************************************************************************
|
---|
31 | */
|
---|
32 |
|
---|
33 | #ifndef EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H
|
---|
34 | #define EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H
|
---|
35 |
|
---|
36 | namespace Eigen {
|
---|
37 |
|
---|
38 | namespace internal {
|
---|
39 |
|
---|
40 | /**********************************************************************
|
---|
41 | * This file implements triangular matrix-vector multiplication using BLAS
|
---|
42 | **********************************************************************/
|
---|
43 |
|
---|
44 | // trmv/hemv specialization
|
---|
45 |
|
---|
46 | template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder>
|
---|
47 | struct triangular_matrix_vector_product_trmv :
|
---|
48 | triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,StorageOrder,BuiltIn> {};
|
---|
49 |
|
---|
/* Generates, for one scalar type, the `Specialized` variants of
 * triangular_matrix_vector_product for both ColMajor and RowMajor storage.
 * Each generated `run` forwards all of its arguments, unchanged and in the
 * same order, to the matching triangular_matrix_vector_product_trmv kernel.
 *
 *   Scalar - scalar type used for both the lhs and the rhs. */
#define EIGEN_MKL_TRMV_SPECIALIZE(Scalar) \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor,Specialized> { \
  static void run(Index _rows, Index _cols, const Scalar* _lhs, Index lhsStride, \
                  const Scalar* _rhs, Index rhsIncr, Scalar* _res, Index resIncr, Scalar alpha) \
  { \
    /* Column-major trmv kernel for this scalar type. */ \
    typedef triangular_matrix_vector_product_trmv<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor> ColMajorKernel; \
    ColMajorKernel::run(_rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
  } \
}; \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,RowMajor,Specialized> { \
  static void run(Index _rows, Index _cols, const Scalar* _lhs, Index lhsStride, \
                  const Scalar* _rhs, Index rhsIncr, Scalar* _res, Index resIncr, Scalar alpha) \
  { \
    /* Row-major trmv kernel for this scalar type. */ \
    typedef triangular_matrix_vector_product_trmv<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,RowMajor> RowMajorKernel; \
    RowMajorKernel::run(_rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
  } \
};
|
---|
67 |
|
---|
68 | EIGEN_MKL_TRMV_SPECIALIZE(double)
|
---|
69 | EIGEN_MKL_TRMV_SPECIALIZE(float)
|
---|
70 | EIGEN_MKL_TRMV_SPECIALIZE(dcomplex)
|
---|
71 | EIGEN_MKL_TRMV_SPECIALIZE(scomplex)
|
---|
72 |
|
---|
/* Column-major kernel: res += alpha * op(triangular) * vector, using MKL.
 *
 * Macro parameters:
 *   EIGTYPE   - Eigen-side scalar type of lhs, rhs, res and alpha.
 *   MKLTYPE   - type the data pointers and scalars are converted to for MKL calls.
 *   EIGPREFIX - suffix pasted onto `VectorX` to name the temporary vector type.
 *   MKLPREFIX - prefix pasted onto `trmv`, `axpy` and `gemv` to name the MKL routines.
 *
 * Generated run() parameters:
 *   _rows, _cols  - dimensions of the lhs; the triangular part is the leading
 *                   min(_rows,_cols) square block.
 *   _lhs          - lhs data, passed to MKL with leading dimension lhsStride.
 *   _rhs, rhsIncr - rhs data and the stride between its elements.
 *   _res, resIncr - accumulated result and the stride between its elements.
 *   alpha         - scaling factor applied to the product before accumulation.
 *
 * Algorithm:
 *   1. ConjLhs and ZeroDiag modes are delegated to the BuiltIn kernel.
 *   2. The rhs is copied (and conjugated if ConjRhs) into a contiguous temporary,
 *      because ?TRMV overwrites its vector argument in place.
 *   3. ?TRMV handles the square triangular block, ?AXPY accumulates it into res.
 *   4. If the lhs is not square, the remaining rectangular block is applied
 *      with ?GEMV (beta = 1, so it accumulates into res as well). */
#define EIGEN_MKL_TRMV_CM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor> { \
 enum { \
  IsLower = (Mode&Lower) == Lower, \
  SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
  IsUnitDiag  = (Mode&UnitDiag) ? 1 : 0, \
  IsZeroDiag  = (Mode&ZeroDiag) ? 1 : 0, \
  LowUp = IsLower ? Lower : Upper \
 }; \
 static void run(Index _rows, Index _cols, const EIGTYPE* _lhs, Index lhsStride, \
                 const EIGTYPE* _rhs, Index rhsIncr, EIGTYPE* _res, Index resIncr, EIGTYPE alpha) \
 { \
   /* Conjugated-lhs and zero-diagonal modes are not mapped onto ?TRMV here: use the BuiltIn kernel. */ \
   if (ConjLhs || IsZeroDiag) { \
     triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor,BuiltIn>::run( \
       _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
     return; \
   } \
   Index size = (std::min)(_rows,_cols); \
   Index rows = IsLower ? _rows : size; \
   Index cols = IsLower ? size : _cols; \
\
   typedef VectorX##EIGPREFIX VectorRhs; \
   EIGTYPE *x, *y; \
\
   /* x: contiguous (incx == 1), optionally conjugated copy of the rhs. */ \
   Map<const VectorRhs, 0, InnerStride<> > rhs(_rhs,cols,InnerStride<>(rhsIncr)); \
   VectorRhs x_tmp; \
   if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
   x = x_tmp.data(); \
\
   /* Square part handling */ \
\
   char trans, uplo, diag; \
   MKL_INT m, n, lda, incx, incy; \
   EIGTYPE const *a; \
   MKLTYPE alpha_, beta_; \
   assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
   assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
\
   /* Sizes and strides, converted explicitly from Index to MKL_INT. */ \
   n = static_cast<MKL_INT>(size); \
   lda = static_cast<MKL_INT>(lhsStride); \
   incx = 1; \
   incy = static_cast<MKL_INT>(resIncr); \
\
   /* Set uplo, trans and diag */ \
   trans = 'N'; \
   uplo = IsLower ? 'L' : 'U'; \
   diag = IsUnitDiag ? 'U' : 'N'; \
\
   /* x <- op(a_tr) * x (in place on the temporary). */ \
   MKLPREFIX##trmv(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
\
   /* res += alpha * x */ \
   MKLPREFIX##axpy(&n, &alpha_,(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
\
   /* Non-square case - the rectangular remainder does not fit ?TRMV, apply it with ?GEMV. */ \
   if (size<(std::max)(rows,cols)) { \
     /* ?TRMV overwrote the temporary: restore the original rhs values. */ \
     if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
     x = x_tmp.data(); \
     if (size<rows) { \
       /* Extra rows below the square block: they update res[size..rows). */ \
       y = _res + size*resIncr; \
       a = _lhs + size; \
       m = static_cast<MKL_INT>(rows-size); \
       n = static_cast<MKL_INT>(size); \
     } \
     else { \
       /* Extra columns right of the square block: they consume rhs[size..cols). */ \
       x += size; \
       y = _res; \
       a = _lhs + size*lhsStride; \
       m = static_cast<MKL_INT>(size); \
       n = static_cast<MKL_INT>(cols-size); \
     } \
     MKLPREFIX##gemv(&trans, &m, &n, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &beta_, (MKLTYPE*)y, &incy); \
   } \
 } \
};
|
---|
152 |
|
---|
153 | EIGEN_MKL_TRMV_CM(double, double, d, d)
|
---|
154 | EIGEN_MKL_TRMV_CM(dcomplex, MKL_Complex16, cd, z)
|
---|
155 | EIGEN_MKL_TRMV_CM(float, float, f, s)
|
---|
156 | EIGEN_MKL_TRMV_CM(scomplex, MKL_Complex8, cf, c)
|
---|
157 |
|
---|
/* Row-major kernel: res += alpha * op(triangular) * vector, using MKL.
 *
 * The lhs buffer is handed to MKL with trans = 'T' (or 'C' when ConjLhs is
 * set) and with the uplo flag flipped relative to IsLower.
 *
 * Macro parameters:
 *   EIGTYPE   - Eigen-side scalar type of lhs, rhs, res and alpha.
 *   MKLTYPE   - type the data pointers and scalars are converted to for MKL calls.
 *   EIGPREFIX - suffix pasted onto `VectorX` to name the temporary vector type.
 *   MKLPREFIX - prefix pasted onto `trmv`, `axpy` and `gemv` to name the MKL routines.
 *
 * Generated run() parameters:
 *   _rows, _cols  - dimensions of the lhs; the triangular part is the leading
 *                   min(_rows,_cols) square block.
 *   _lhs          - lhs data, passed to MKL with leading dimension lhsStride.
 *   _rhs, rhsIncr - rhs data and the stride between its elements.
 *   _res, resIncr - accumulated result and the stride between its elements.
 *   alpha         - scaling factor applied to the product before accumulation.
 *
 * Algorithm:
 *   1. ZeroDiag modes are delegated to the BuiltIn kernel.
 *   2. The rhs is copied (and conjugated if ConjRhs) into a contiguous temporary,
 *      because ?TRMV overwrites its vector argument in place.
 *   3. ?TRMV handles the square triangular block, ?AXPY accumulates it into res.
 *   4. If the lhs is not square, the remaining rectangular block is applied
 *      with ?GEMV (beta = 1, so it accumulates into res as well). */
#define EIGEN_MKL_TRMV_RM(EIGTYPE, MKLTYPE, EIGPREFIX, MKLPREFIX) \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor> { \
 enum { \
  IsLower = (Mode&Lower) == Lower, \
  SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
  IsUnitDiag  = (Mode&UnitDiag) ? 1 : 0, \
  IsZeroDiag  = (Mode&ZeroDiag) ? 1 : 0, \
  LowUp = IsLower ? Lower : Upper \
 }; \
 static void run(Index _rows, Index _cols, const EIGTYPE* _lhs, Index lhsStride, \
                 const EIGTYPE* _rhs, Index rhsIncr, EIGTYPE* _res, Index resIncr, EIGTYPE alpha) \
 { \
   /* Zero-diagonal modes are not mapped onto ?TRMV here: use the BuiltIn kernel. */ \
   if (IsZeroDiag) { \
     triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor,BuiltIn>::run( \
       _rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \
     return; \
   } \
   Index size = (std::min)(_rows,_cols); \
   Index rows = IsLower ? _rows : size; \
   Index cols = IsLower ? size : _cols; \
\
   typedef VectorX##EIGPREFIX VectorRhs; \
   EIGTYPE *x, *y; \
\
   /* x: contiguous (incx == 1), optionally conjugated copy of the rhs. */ \
   Map<const VectorRhs, 0, InnerStride<> > rhs(_rhs,cols,InnerStride<>(rhsIncr)); \
   VectorRhs x_tmp; \
   if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
   x = x_tmp.data(); \
\
   /* Square part handling */ \
\
   char trans, uplo, diag; \
   MKL_INT m, n, lda, incx, incy; \
   EIGTYPE const *a; \
   MKLTYPE alpha_, beta_; \
   assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(alpha_, alpha); \
   assign_scalar_eig2mkl<MKLTYPE, EIGTYPE>(beta_, EIGTYPE(1)); \
\
   /* Sizes and strides, converted explicitly from Index to MKL_INT. */ \
   n = static_cast<MKL_INT>(size); \
   lda = static_cast<MKL_INT>(lhsStride); \
   incx = 1; \
   incy = static_cast<MKL_INT>(resIncr); \
\
   /* Set uplo, trans and diag: transposed (conjugate-transposed if ConjLhs), uplo flipped. */ \
   trans = ConjLhs ? 'C' : 'T'; \
   uplo = IsLower ? 'U' : 'L'; \
   diag = IsUnitDiag ? 'U' : 'N'; \
\
   /* x <- op(a_tr) * x (in place on the temporary). */ \
   MKLPREFIX##trmv(&uplo, &trans, &diag, &n, (const MKLTYPE*)_lhs, &lda, (MKLTYPE*)x, &incx); \
\
   /* res += alpha * x */ \
   MKLPREFIX##axpy(&n, &alpha_,(const MKLTYPE*)x, &incx, (MKLTYPE*)_res, &incy); \
\
   /* Non-square case - the rectangular remainder does not fit ?TRMV, apply it with ?GEMV. */ \
   if (size<(std::max)(rows,cols)) { \
     /* ?TRMV overwrote the temporary: restore the original rhs values. */ \
     if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
     x = x_tmp.data(); \
     if (size<rows) { \
       /* Extra rows below the square block: they update res[size..rows). */ \
       y = _res + size*resIncr; \
       a = _lhs + size*lhsStride; \
       m = static_cast<MKL_INT>(rows-size); \
       n = static_cast<MKL_INT>(size); \
     } \
     else { \
       /* Extra columns right of the square block: they consume rhs[size..cols). */ \
       x += size; \
       y = _res; \
       a = _lhs + size; \
       m = static_cast<MKL_INT>(size); \
       n = static_cast<MKL_INT>(cols-size); \
     } \
     /* The block is passed transposed, hence the dimensions are given as (n, m). */ \
     MKLPREFIX##gemv(&trans, &n, &m, &alpha_, (const MKLTYPE*)a, &lda, (const MKLTYPE*)x, &incx, &beta_, (MKLTYPE*)y, &incy); \
   } \
 } \
};
|
---|
237 |
|
---|
238 | EIGEN_MKL_TRMV_RM(double, double, d, d)
|
---|
239 | EIGEN_MKL_TRMV_RM(dcomplex, MKL_Complex16, cd, z)
|
---|
240 | EIGEN_MKL_TRMV_RM(float, float, f, s)
|
---|
241 | EIGEN_MKL_TRMV_RM(scomplex, MKL_Complex8, cf, c)
|
---|
242 |
|
---|
} // end namespace internal
|
---|
244 |
|
---|
245 | } // end namespace Eigen
|
---|
246 |
|
---|
247 | #endif // EIGEN_TRIANGULAR_MATRIX_VECTOR_MKL_H
|
---|