[136] | 1 | // This file is part of Eigen, a lightweight C++ template library
|
---|
| 2 | // for linear algebra.
|
---|
| 3 | //
|
---|
| 4 | // Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
|
---|
| 5 | //
|
---|
| 6 | // This Source Code Form is subject to the terms of the Mozilla
|
---|
| 7 | // Public License v. 2.0. If a copy of the MPL was not distributed
|
---|
| 8 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
---|
| 9 |
|
---|
| 10 | #ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_H
|
---|
| 11 | #define EIGEN_TRIANGULAR_MATRIX_MATRIX_H
|
---|
| 12 |
|
---|
| 13 | namespace Eigen {
|
---|
| 14 |
|
---|
| 15 | namespace internal {
|
---|
| 16 |
|
---|
| 17 | // template<typename Scalar, int mr, int StorageOrder, bool Conjugate, int Mode>
|
---|
| 18 | // struct gemm_pack_lhs_triangular
|
---|
| 19 | // {
|
---|
| 20 | // Matrix<Scalar,mr,mr,
|
---|
| 21 | // void operator()(Scalar* blockA, const EIGEN_RESTRICT Scalar* _lhs, int lhsStride, int depth, int rows)
|
---|
| 22 | // {
|
---|
| 23 | // conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
|
---|
| 24 | // const_blas_data_mapper<Scalar, StorageOrder> lhs(_lhs,lhsStride);
|
---|
| 25 | // int count = 0;
|
---|
| 26 | // const int peeled_mc = (rows/mr)*mr;
|
---|
| 27 | // for(int i=0; i<peeled_mc; i+=mr)
|
---|
| 28 | // {
|
---|
| 29 | // for(int k=0; k<depth; k++)
|
---|
| 30 | // for(int w=0; w<mr; w++)
|
---|
| 31 | // blockA[count++] = cj(lhs(i+w, k));
|
---|
| 32 | // }
|
---|
| 33 | // for(int i=peeled_mc; i<rows; i++)
|
---|
| 34 | // {
|
---|
| 35 | // for(int k=0; k<depth; k++)
|
---|
| 36 | // blockA[count++] = cj(lhs(i, k));
|
---|
| 37 | // }
|
---|
| 38 | // }
|
---|
| 39 | // };
|
---|
| 40 |
|
---|
/* Optimized triangular matrix * matrix (_TRMM++) product built on top of
 * the general matrix matrix product.
 *
 * Primary template (declaration only); the partial specializations in this
 * file dispatch on:
 *   - Mode:             triangular flags; the specializations test the Lower,
 *                       Upper, UnitDiag and ZeroDiag bits.
 *   - LhsIsTriangular:  true when the triangular factor is the left operand,
 *                       false when it is the right operand.
 *   - Lhs/RhsStorageOrder, ConjugateLhs/Rhs: storage order and conjugation
 *                       flag of each operand.
 *   - ResStorageOrder:  storage order of the destination (RowMajor/ColMajor).
 *   - Version:          implementation selector, defaults to Specialized.
 */
template <typename Scalar, typename Index,
          int Mode, bool LhsIsTriangular,
          int LhsStorageOrder, bool ConjugateLhs,
          int RhsStorageOrder, bool ConjugateRhs,
          int ResStorageOrder, int Version = Specialized>
struct product_triangular_matrix_matrix;
| 50 |
|
---|
| 51 | template <typename Scalar, typename Index,
|
---|
| 52 | int Mode, bool LhsIsTriangular,
|
---|
| 53 | int LhsStorageOrder, bool ConjugateLhs,
|
---|
| 54 | int RhsStorageOrder, bool ConjugateRhs, int Version>
|
---|
| 55 | struct product_triangular_matrix_matrix<Scalar,Index,Mode,LhsIsTriangular,
|
---|
| 56 | LhsStorageOrder,ConjugateLhs,
|
---|
| 57 | RhsStorageOrder,ConjugateRhs,RowMajor,Version>
|
---|
| 58 | {
|
---|
| 59 | static EIGEN_STRONG_INLINE void run(
|
---|
| 60 | Index rows, Index cols, Index depth,
|
---|
| 61 | const Scalar* lhs, Index lhsStride,
|
---|
| 62 | const Scalar* rhs, Index rhsStride,
|
---|
| 63 | Scalar* res, Index resStride,
|
---|
| 64 | const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
|
---|
| 65 | {
|
---|
| 66 | product_triangular_matrix_matrix<Scalar, Index,
|
---|
| 67 | (Mode&(UnitDiag|ZeroDiag)) | ((Mode&Upper) ? Lower : Upper),
|
---|
| 68 | (!LhsIsTriangular),
|
---|
| 69 | RhsStorageOrder==RowMajor ? ColMajor : RowMajor,
|
---|
| 70 | ConjugateRhs,
|
---|
| 71 | LhsStorageOrder==RowMajor ? ColMajor : RowMajor,
|
---|
| 72 | ConjugateLhs,
|
---|
| 73 | ColMajor>
|
---|
| 74 | ::run(cols, rows, depth, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha, blocking);
|
---|
| 75 | }
|
---|
| 76 | };
|
---|
| 77 |
|
---|
| 78 | // implements col-major += alpha * op(triangular) * op(general)
|
---|
| 79 | template <typename Scalar, typename Index, int Mode,
|
---|
| 80 | int LhsStorageOrder, bool ConjugateLhs,
|
---|
| 81 | int RhsStorageOrder, bool ConjugateRhs, int Version>
|
---|
| 82 | struct product_triangular_matrix_matrix<Scalar,Index,Mode,true,
|
---|
| 83 | LhsStorageOrder,ConjugateLhs,
|
---|
| 84 | RhsStorageOrder,ConjugateRhs,ColMajor,Version>
|
---|
| 85 | {
|
---|
| 86 |
|
---|
| 87 | typedef gebp_traits<Scalar,Scalar> Traits;
|
---|
| 88 | enum {
|
---|
| 89 | SmallPanelWidth = 2 * EIGEN_PLAIN_ENUM_MAX(Traits::mr,Traits::nr),
|
---|
| 90 | IsLower = (Mode&Lower) == Lower,
|
---|
| 91 | SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1
|
---|
| 92 | };
|
---|
| 93 |
|
---|
| 94 | static EIGEN_DONT_INLINE void run(
|
---|
| 95 | Index _rows, Index _cols, Index _depth,
|
---|
| 96 | const Scalar* _lhs, Index lhsStride,
|
---|
| 97 | const Scalar* _rhs, Index rhsStride,
|
---|
| 98 | Scalar* res, Index resStride,
|
---|
| 99 | const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking);
|
---|
| 100 | };
|
---|
| 101 |
|
---|
| 102 | template <typename Scalar, typename Index, int Mode,
|
---|
| 103 | int LhsStorageOrder, bool ConjugateLhs,
|
---|
| 104 | int RhsStorageOrder, bool ConjugateRhs, int Version>
|
---|
| 105 | EIGEN_DONT_INLINE void product_triangular_matrix_matrix<Scalar,Index,Mode,true,
|
---|
| 106 | LhsStorageOrder,ConjugateLhs,
|
---|
| 107 | RhsStorageOrder,ConjugateRhs,ColMajor,Version>::run(
|
---|
| 108 | Index _rows, Index _cols, Index _depth,
|
---|
| 109 | const Scalar* _lhs, Index lhsStride,
|
---|
| 110 | const Scalar* _rhs, Index rhsStride,
|
---|
| 111 | Scalar* res, Index resStride,
|
---|
| 112 | const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
|
---|
| 113 | {
|
---|
| 114 | // strip zeros
|
---|
| 115 | Index diagSize = (std::min)(_rows,_depth);
|
---|
| 116 | Index rows = IsLower ? _rows : diagSize;
|
---|
| 117 | Index depth = IsLower ? diagSize : _depth;
|
---|
| 118 | Index cols = _cols;
|
---|
| 119 |
|
---|
| 120 | const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
|
---|
| 121 | const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);
|
---|
| 122 |
|
---|
| 123 | Index kc = blocking.kc(); // cache block size along the K direction
|
---|
| 124 | Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction
|
---|
| 125 |
|
---|
| 126 | std::size_t sizeA = kc*mc;
|
---|
| 127 | std::size_t sizeB = kc*cols;
|
---|
| 128 | std::size_t sizeW = kc*Traits::WorkSpaceFactor;
|
---|
| 129 |
|
---|
| 130 | ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
|
---|
| 131 | ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
|
---|
| 132 | ei_declare_aligned_stack_constructed_variable(Scalar, blockW, sizeW, blocking.blockW());
|
---|
| 133 |
|
---|
| 134 | Matrix<Scalar,SmallPanelWidth,SmallPanelWidth,LhsStorageOrder> triangularBuffer;
|
---|
| 135 | triangularBuffer.setZero();
|
---|
| 136 | if((Mode&ZeroDiag)==ZeroDiag)
|
---|
| 137 | triangularBuffer.diagonal().setZero();
|
---|
| 138 | else
|
---|
| 139 | triangularBuffer.diagonal().setOnes();
|
---|
| 140 |
|
---|
| 141 | gebp_kernel<Scalar, Scalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
|
---|
| 142 | gemm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
|
---|
| 143 | gemm_pack_rhs<Scalar, Index, Traits::nr,RhsStorageOrder> pack_rhs;
|
---|
| 144 |
|
---|
| 145 | for(Index k2=IsLower ? depth : 0;
|
---|
| 146 | IsLower ? k2>0 : k2<depth;
|
---|
| 147 | IsLower ? k2-=kc : k2+=kc)
|
---|
| 148 | {
|
---|
| 149 | Index actual_kc = (std::min)(IsLower ? k2 : depth-k2, kc);
|
---|
| 150 | Index actual_k2 = IsLower ? k2-actual_kc : k2;
|
---|
| 151 |
|
---|
| 152 | // align blocks with the end of the triangular part for trapezoidal lhs
|
---|
| 153 | if((!IsLower)&&(k2<rows)&&(k2+actual_kc>rows))
|
---|
| 154 | {
|
---|
| 155 | actual_kc = rows-k2;
|
---|
| 156 | k2 = k2+actual_kc-kc;
|
---|
| 157 | }
|
---|
| 158 |
|
---|
| 159 | pack_rhs(blockB, &rhs(actual_k2,0), rhsStride, actual_kc, cols);
|
---|
| 160 |
|
---|
| 161 | // the selected lhs's panel has to be split in three different parts:
|
---|
| 162 | // 1 - the part which is zero => skip it
|
---|
| 163 | // 2 - the diagonal block => special kernel
|
---|
| 164 | // 3 - the dense panel below (lower case) or above (upper case) the diagonal block => GEPP
|
---|
| 165 |
|
---|
| 166 | // the block diagonal, if any:
|
---|
| 167 | if(IsLower || actual_k2<rows)
|
---|
| 168 | {
|
---|
| 169 | // for each small vertical panels of lhs
|
---|
| 170 | for (Index k1=0; k1<actual_kc; k1+=SmallPanelWidth)
|
---|
| 171 | {
|
---|
| 172 | Index actualPanelWidth = std::min<Index>(actual_kc-k1, SmallPanelWidth);
|
---|
| 173 | Index lengthTarget = IsLower ? actual_kc-k1-actualPanelWidth : k1;
|
---|
| 174 | Index startBlock = actual_k2+k1;
|
---|
| 175 | Index blockBOffset = k1;
|
---|
| 176 |
|
---|
| 177 | // => GEBP with the micro triangular block
|
---|
| 178 | // The trick is to pack this micro block while filling the opposite triangular part with zeros.
|
---|
| 179 | // To this end we do an extra triangular copy to a small temporary buffer
|
---|
| 180 | for (Index k=0;k<actualPanelWidth;++k)
|
---|
| 181 | {
|
---|
| 182 | if (SetDiag)
|
---|
| 183 | triangularBuffer.coeffRef(k,k) = lhs(startBlock+k,startBlock+k);
|
---|
| 184 | for (Index i=IsLower ? k+1 : 0; IsLower ? i<actualPanelWidth : i<k; ++i)
|
---|
| 185 | triangularBuffer.coeffRef(i,k) = lhs(startBlock+i,startBlock+k);
|
---|
| 186 | }
|
---|
| 187 | pack_lhs(blockA, triangularBuffer.data(), triangularBuffer.outerStride(), actualPanelWidth, actualPanelWidth);
|
---|
| 188 |
|
---|
| 189 | gebp_kernel(res+startBlock, resStride, blockA, blockB, actualPanelWidth, actualPanelWidth, cols, alpha,
|
---|
| 190 | actualPanelWidth, actual_kc, 0, blockBOffset, blockW);
|
---|
| 191 |
|
---|
| 192 | // GEBP with remaining micro panel
|
---|
| 193 | if (lengthTarget>0)
|
---|
| 194 | {
|
---|
| 195 | Index startTarget = IsLower ? actual_k2+k1+actualPanelWidth : actual_k2;
|
---|
| 196 |
|
---|
| 197 | pack_lhs(blockA, &lhs(startTarget,startBlock), lhsStride, actualPanelWidth, lengthTarget);
|
---|
| 198 |
|
---|
| 199 | gebp_kernel(res+startTarget, resStride, blockA, blockB, lengthTarget, actualPanelWidth, cols, alpha,
|
---|
| 200 | actualPanelWidth, actual_kc, 0, blockBOffset, blockW);
|
---|
| 201 | }
|
---|
| 202 | }
|
---|
| 203 | }
|
---|
| 204 | // the part below (lower case) or above (upper case) the diagonal => GEPP
|
---|
| 205 | {
|
---|
| 206 | Index start = IsLower ? k2 : 0;
|
---|
| 207 | Index end = IsLower ? rows : (std::min)(actual_k2,rows);
|
---|
| 208 | for(Index i2=start; i2<end; i2+=mc)
|
---|
| 209 | {
|
---|
| 210 | const Index actual_mc = (std::min)(i2+mc,end)-i2;
|
---|
| 211 | gemm_pack_lhs<Scalar, Index, Traits::mr,Traits::LhsProgress, LhsStorageOrder,false>()
|
---|
| 212 | (blockA, &lhs(i2, actual_k2), lhsStride, actual_kc, actual_mc);
|
---|
| 213 |
|
---|
| 214 | gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha, -1, -1, 0, 0, blockW);
|
---|
| 215 | }
|
---|
| 216 | }
|
---|
| 217 | }
|
---|
| 218 | }
|
---|
| 219 |
|
---|
| 220 | // implements col-major += alpha * op(general) * op(triangular)
|
---|
| 221 | template <typename Scalar, typename Index, int Mode,
|
---|
| 222 | int LhsStorageOrder, bool ConjugateLhs,
|
---|
| 223 | int RhsStorageOrder, bool ConjugateRhs, int Version>
|
---|
| 224 | struct product_triangular_matrix_matrix<Scalar,Index,Mode,false,
|
---|
| 225 | LhsStorageOrder,ConjugateLhs,
|
---|
| 226 | RhsStorageOrder,ConjugateRhs,ColMajor,Version>
|
---|
| 227 | {
|
---|
| 228 | typedef gebp_traits<Scalar,Scalar> Traits;
|
---|
| 229 | enum {
|
---|
| 230 | SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Traits::mr,Traits::nr),
|
---|
| 231 | IsLower = (Mode&Lower) == Lower,
|
---|
| 232 | SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1
|
---|
| 233 | };
|
---|
| 234 |
|
---|
| 235 | static EIGEN_DONT_INLINE void run(
|
---|
| 236 | Index _rows, Index _cols, Index _depth,
|
---|
| 237 | const Scalar* _lhs, Index lhsStride,
|
---|
| 238 | const Scalar* _rhs, Index rhsStride,
|
---|
| 239 | Scalar* res, Index resStride,
|
---|
| 240 | const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking);
|
---|
| 241 | };
|
---|
| 242 |
|
---|
// Computes res += alpha * op(lhs) * op(rhs) where rhs is triangular (or
// trapezoidal) and res is column major.
//
//   _rows, _cols, _depth : sizes of the product (res is _rows x _cols)
//   _lhs, lhsStride      : general operand and its outer stride
//   _rhs, rhsStride      : triangular operand and its outer stride
//   res, resStride       : destination and its outer stride
//   alpha                : scaling factor applied to the product
//   blocking             : provides kc/mc and, optionally, the packing buffers
//
// The depth is traversed by slices of at most kc coefficients. For each slice
// the rhs is packed in two pieces stored in blockB: first the part crossing
// the diagonal (ts*ts coefficients, packed by micro panels), then the
// remaining dense columns (starting at 'geb').
template <typename Scalar, typename Index, int Mode,
          int LhsStorageOrder, bool ConjugateLhs,
          int RhsStorageOrder, bool ConjugateRhs, int Version>
EIGEN_DONT_INLINE void product_triangular_matrix_matrix<Scalar,Index,Mode,false,
                                                        LhsStorageOrder,ConjugateLhs,
                                                        RhsStorageOrder,ConjugateRhs,ColMajor,Version>::run(
    Index _rows, Index _cols, Index _depth,
    const Scalar* _lhs, Index lhsStride,
    const Scalar* _rhs, Index rhsStride,
    Scalar* res,        Index resStride,
    const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
{
  // strip zeros: only the depth (upper) or the columns (lower) covered by the
  // diagonal can contribute
  Index diagSize  = (std::min)(_cols,_depth);
  Index rows      = _rows;
  Index depth     = IsLower ? _depth : diagSize;
  Index cols      = IsLower ? diagSize : _cols;

  const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
  const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);

  Index kc = blocking.kc();                   // cache block size along the K direction
  Index mc = (std::min)(rows,blocking.mc());  // cache block size along the M direction

  std::size_t sizeA = kc*mc;
  std::size_t sizeB = kc*cols;
  std::size_t sizeW = kc*Traits::WorkSpaceFactor;

  ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
  ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
  ei_declare_aligned_stack_constructed_variable(Scalar, blockW, sizeW, blocking.blockW());

  // Temporary receiving one diagonal micro block. It is zeroed once; only the
  // relevant triangle (and the diagonal if SetDiag) is overwritten afterwards,
  // so the opposite triangle stays zero and the diagonal keeps its 0/1 preset.
  Matrix<Scalar,SmallPanelWidth,SmallPanelWidth,RhsStorageOrder> triangularBuffer;
  triangularBuffer.setZero();
  if((Mode&ZeroDiag)==ZeroDiag)
    triangularBuffer.diagonal().setZero();
  else
    triangularBuffer.diagonal().setOnes();

  gebp_kernel<Scalar, Scalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
  gemm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
  gemm_pack_rhs<Scalar, Index, Traits::nr,RhsStorageOrder> pack_rhs;
  // same packer instantiated with two extra trailing flags (false,true); it is
  // the one called with an explicit stride and offset below
  gemm_pack_rhs<Scalar, Index, Traits::nr,RhsStorageOrder,false,true> pack_rhs_panel;

  // Lower: walk the depth forward; Upper: walk it backward starting from the end.
  for(Index k2=IsLower ? 0 : depth;
      IsLower ? k2<depth  : k2>0;
      IsLower ? k2+=kc   : k2-=kc)
  {
    Index actual_kc = (std::min)(IsLower ? depth-k2 : k2, kc);
    Index actual_k2 = IsLower ? k2 : k2-actual_kc;

    // align blocks with the end of the triangular part for trapezoidal rhs
    if(IsLower && (k2<cols) && (actual_k2+actual_kc>cols))
    {
      actual_kc = cols-k2;
      // so that the loop increment (k2+=kc) makes the next slice start at 'cols'
      k2 = actual_k2 + actual_kc - kc;
    }

    // remaining size
    Index rs = IsLower ? (std::min)(cols,actual_k2) : cols - k2;
    // size of the triangular part
    Index ts = (IsLower && actual_k2>=cols) ? 0 : actual_kc;

    // the dense part is packed right after the ts*ts coefficients reserved
    // for the triangular part
    Scalar* geb = blockB+ts*ts;

    pack_rhs(geb, &rhs(actual_k2,IsLower ? 0 : k2), rhsStride, actual_kc, rs);

    // pack the triangular part of the rhs padding the unrolled blocks with zeros
    if(ts>0)
    {
      for (Index j2=0; j2<actual_kc; j2+=SmallPanelWidth)
      {
        Index actualPanelWidth = std::min<Index>(actual_kc-j2, SmallPanelWidth);
        Index actual_j2 = actual_k2 + j2;
        // rows of the current column panel lying strictly below (lower case)
        // or above (upper case) the diagonal micro block
        Index panelOffset = IsLower ? j2+actualPanelWidth : 0;
        Index panelLength = IsLower ? actual_kc-j2-actualPanelWidth : j2;
        // general part
        pack_rhs_panel(blockB+j2*actual_kc,
                       &rhs(actual_k2+panelOffset, actual_j2), rhsStride,
                       panelLength, actualPanelWidth,
                       actual_kc, panelOffset);

        // append the triangular part via a temporary buffer
        for (Index j=0;j<actualPanelWidth;++j)
        {
          if (SetDiag)
            triangularBuffer.coeffRef(j,j) = rhs(actual_j2+j,actual_j2+j);
          for (Index k=IsLower ? j+1 : 0; IsLower ? k<actualPanelWidth : k<j; ++k)
            triangularBuffer.coeffRef(k,j) = rhs(actual_j2+k,actual_j2+j);
        }

        // same destination panel as above, written at offset j2
        pack_rhs_panel(blockB+j2*actual_kc,
                       triangularBuffer.data(), triangularBuffer.outerStride(),
                       actualPanelWidth, actualPanelWidth,
                       actual_kc, j2);
      }
    }

    // for each horizontal panel of at most mc rows of the lhs
    for (Index i2=0; i2<rows; i2+=mc)
    {
      const Index actual_mc = (std::min)(mc,rows-i2);
      pack_lhs(blockA, &lhs(i2, actual_k2), lhsStride, actual_kc, actual_mc);

      // triangular kernel
      if(ts>0)
      {
        for (Index j2=0; j2<actual_kc; j2+=SmallPanelWidth)
        {
          Index actualPanelWidth = std::min<Index>(actual_kc-j2, SmallPanelWidth);
          // depth range of this column panel including its diagonal micro block
          Index panelLength = IsLower ? actual_kc-j2 : j2+actualPanelWidth;
          Index blockOffset = IsLower ? j2 : 0;

          gebp_kernel(res+i2+(actual_k2+j2)*resStride, resStride,
                      blockA, blockB+j2*actual_kc,
                      actual_mc, panelLength, actualPanelWidth,
                      alpha,
                      actual_kc, actual_kc,  // strides
                      blockOffset, blockOffset,// offsets
                      blockW); // workspace
        }
      }
      // dense part of the slice: columns [0,rs) in the lower case, columns
      // starting at k2 in the upper case
      gebp_kernel(res+i2+(IsLower ? 0 : k2)*resStride, resStride,
                  blockA, geb, actual_mc, actual_kc, rs,
                  alpha,
                  -1, -1, 0, 0, blockW);
    }
  }
}
| 371 |
|
---|
| 372 | /***************************************************************************
|
---|
| 373 | * Wrapper to product_triangular_matrix_matrix
|
---|
| 374 | ***************************************************************************/
|
---|
| 375 |
|
---|
// The traits of this TriangularProduct specialization are inherited unchanged
// from the traits of its ProductBase base class.
template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false> >
  : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>, Lhs, Rhs> >
{};
| 380 |
|
---|
| 381 | } // end namespace internal
|
---|
| 382 |
|
---|
| 383 | template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
|
---|
| 384 | struct TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>
|
---|
| 385 | : public ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>, Lhs, Rhs >
|
---|
| 386 | {
|
---|
| 387 | EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
|
---|
| 388 |
|
---|
| 389 | TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
|
---|
| 390 |
|
---|
| 391 | template<typename Dest> void scaleAndAddTo(Dest& dst, const Scalar& alpha) const
|
---|
| 392 | {
|
---|
| 393 | typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(m_lhs);
|
---|
| 394 | typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(m_rhs);
|
---|
| 395 |
|
---|
| 396 | Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
|
---|
| 397 | * RhsBlasTraits::extractScalarFactor(m_rhs);
|
---|
| 398 |
|
---|
| 399 | typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,Scalar,Scalar,
|
---|
| 400 | Lhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime, Lhs::MaxColsAtCompileTime,4> BlockingType;
|
---|
| 401 |
|
---|
| 402 | enum { IsLower = (Mode&Lower) == Lower };
|
---|
| 403 | Index stripedRows = ((!LhsIsTriangular) || (IsLower)) ? lhs.rows() : (std::min)(lhs.rows(),lhs.cols());
|
---|
| 404 | Index stripedCols = ((LhsIsTriangular) || (!IsLower)) ? rhs.cols() : (std::min)(rhs.cols(),rhs.rows());
|
---|
| 405 | Index stripedDepth = LhsIsTriangular ? ((!IsLower) ? lhs.cols() : (std::min)(lhs.cols(),lhs.rows()))
|
---|
| 406 | : ((IsLower) ? rhs.rows() : (std::min)(rhs.rows(),rhs.cols()));
|
---|
| 407 |
|
---|
| 408 | BlockingType blocking(stripedRows, stripedCols, stripedDepth);
|
---|
| 409 |
|
---|
| 410 | internal::product_triangular_matrix_matrix<Scalar, Index,
|
---|
| 411 | Mode, LhsIsTriangular,
|
---|
| 412 | (internal::traits<_ActualLhsType>::Flags&RowMajorBit) ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate,
|
---|
| 413 | (internal::traits<_ActualRhsType>::Flags&RowMajorBit) ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
|
---|
| 414 | (internal::traits<Dest >::Flags&RowMajorBit) ? RowMajor : ColMajor>
|
---|
| 415 | ::run(
|
---|
| 416 | stripedRows, stripedCols, stripedDepth, // sizes
|
---|
| 417 | &lhs.coeffRef(0,0), lhs.outerStride(), // lhs info
|
---|
| 418 | &rhs.coeffRef(0,0), rhs.outerStride(), // rhs info
|
---|
| 419 | &dst.coeffRef(0,0), dst.outerStride(), // result info
|
---|
| 420 | actualAlpha, blocking
|
---|
| 421 | );
|
---|
| 422 | }
|
---|
| 423 | };
|
---|
| 424 |
|
---|
| 425 | } // end namespace Eigen
|
---|
| 426 |
|
---|
| 427 | #endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_H
|
---|