// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
| 9 |
|
---|
| 10 | #ifndef EIGEN_SPARSEDENSEPRODUCT_H
|
---|
| 11 | #define EIGEN_SPARSEDENSEPRODUCT_H
|
---|
| 12 |
|
---|
| 13 | namespace Eigen {
|
---|
| 14 |
|
---|
| 15 | template<typename Lhs, typename Rhs, int InnerSize> struct SparseDenseProductReturnType
|
---|
| 16 | {
|
---|
| 17 | typedef SparseTimeDenseProduct<Lhs,Rhs> Type;
|
---|
| 18 | };
|
---|
| 19 |
|
---|
| 20 | template<typename Lhs, typename Rhs> struct SparseDenseProductReturnType<Lhs,Rhs,1>
|
---|
| 21 | {
|
---|
| 22 | typedef typename internal::conditional<
|
---|
| 23 | Lhs::IsRowMajor,
|
---|
| 24 | SparseDenseOuterProduct<Rhs,Lhs,true>,
|
---|
| 25 | SparseDenseOuterProduct<Lhs,Rhs,false> >::type Type;
|
---|
| 26 | };
|
---|
| 27 |
|
---|
| 28 | template<typename Lhs, typename Rhs, int InnerSize> struct DenseSparseProductReturnType
|
---|
| 29 | {
|
---|
| 30 | typedef DenseTimeSparseProduct<Lhs,Rhs> Type;
|
---|
| 31 | };
|
---|
| 32 |
|
---|
| 33 | template<typename Lhs, typename Rhs> struct DenseSparseProductReturnType<Lhs,Rhs,1>
|
---|
| 34 | {
|
---|
| 35 | typedef typename internal::conditional<
|
---|
| 36 | Rhs::IsRowMajor,
|
---|
| 37 | SparseDenseOuterProduct<Rhs,Lhs,true>,
|
---|
| 38 | SparseDenseOuterProduct<Lhs,Rhs,false> >::type Type;
|
---|
| 39 | };
|
---|
| 40 |
|
---|
| 41 | namespace internal {
|
---|
| 42 |
|
---|
| 43 | template<typename Lhs, typename Rhs, bool Tr>
|
---|
| 44 | struct traits<SparseDenseOuterProduct<Lhs,Rhs,Tr> >
|
---|
| 45 | {
|
---|
| 46 | typedef Sparse StorageKind;
|
---|
| 47 | typedef typename scalar_product_traits<typename traits<Lhs>::Scalar,
|
---|
| 48 | typename traits<Rhs>::Scalar>::ReturnType Scalar;
|
---|
| 49 | typedef typename Lhs::Index Index;
|
---|
| 50 | typedef typename Lhs::Nested LhsNested;
|
---|
| 51 | typedef typename Rhs::Nested RhsNested;
|
---|
| 52 | typedef typename remove_all<LhsNested>::type _LhsNested;
|
---|
| 53 | typedef typename remove_all<RhsNested>::type _RhsNested;
|
---|
| 54 |
|
---|
| 55 | enum {
|
---|
| 56 | LhsCoeffReadCost = traits<_LhsNested>::CoeffReadCost,
|
---|
| 57 | RhsCoeffReadCost = traits<_RhsNested>::CoeffReadCost,
|
---|
| 58 |
|
---|
| 59 | RowsAtCompileTime = Tr ? int(traits<Rhs>::RowsAtCompileTime) : int(traits<Lhs>::RowsAtCompileTime),
|
---|
| 60 | ColsAtCompileTime = Tr ? int(traits<Lhs>::ColsAtCompileTime) : int(traits<Rhs>::ColsAtCompileTime),
|
---|
| 61 | MaxRowsAtCompileTime = Tr ? int(traits<Rhs>::MaxRowsAtCompileTime) : int(traits<Lhs>::MaxRowsAtCompileTime),
|
---|
| 62 | MaxColsAtCompileTime = Tr ? int(traits<Lhs>::MaxColsAtCompileTime) : int(traits<Rhs>::MaxColsAtCompileTime),
|
---|
| 63 |
|
---|
| 64 | Flags = Tr ? RowMajorBit : 0,
|
---|
| 65 |
|
---|
| 66 | CoeffReadCost = LhsCoeffReadCost + RhsCoeffReadCost + NumTraits<Scalar>::MulCost
|
---|
| 67 | };
|
---|
| 68 | };
|
---|
| 69 |
|
---|
| 70 | } // end namespace internal
|
---|
| 71 |
|
---|
| 72 | template<typename Lhs, typename Rhs, bool Tr>
|
---|
| 73 | class SparseDenseOuterProduct
|
---|
| 74 | : public SparseMatrixBase<SparseDenseOuterProduct<Lhs,Rhs,Tr> >
|
---|
| 75 | {
|
---|
| 76 | public:
|
---|
| 77 |
|
---|
| 78 | typedef SparseMatrixBase<SparseDenseOuterProduct> Base;
|
---|
| 79 | EIGEN_DENSE_PUBLIC_INTERFACE(SparseDenseOuterProduct)
|
---|
| 80 | typedef internal::traits<SparseDenseOuterProduct> Traits;
|
---|
| 81 |
|
---|
| 82 | private:
|
---|
| 83 |
|
---|
| 84 | typedef typename Traits::LhsNested LhsNested;
|
---|
| 85 | typedef typename Traits::RhsNested RhsNested;
|
---|
| 86 | typedef typename Traits::_LhsNested _LhsNested;
|
---|
| 87 | typedef typename Traits::_RhsNested _RhsNested;
|
---|
| 88 |
|
---|
| 89 | public:
|
---|
| 90 |
|
---|
| 91 | class InnerIterator;
|
---|
| 92 |
|
---|
| 93 | EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Lhs& lhs, const Rhs& rhs)
|
---|
| 94 | : m_lhs(lhs), m_rhs(rhs)
|
---|
| 95 | {
|
---|
| 96 | EIGEN_STATIC_ASSERT(!Tr,YOU_MADE_A_PROGRAMMING_MISTAKE);
|
---|
| 97 | }
|
---|
| 98 |
|
---|
| 99 | EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Rhs& rhs, const Lhs& lhs)
|
---|
| 100 | : m_lhs(lhs), m_rhs(rhs)
|
---|
| 101 | {
|
---|
| 102 | EIGEN_STATIC_ASSERT(Tr,YOU_MADE_A_PROGRAMMING_MISTAKE);
|
---|
| 103 | }
|
---|
| 104 |
|
---|
| 105 | EIGEN_STRONG_INLINE Index rows() const { return Tr ? m_rhs.rows() : m_lhs.rows(); }
|
---|
| 106 | EIGEN_STRONG_INLINE Index cols() const { return Tr ? m_lhs.cols() : m_rhs.cols(); }
|
---|
| 107 |
|
---|
| 108 | EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; }
|
---|
| 109 | EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; }
|
---|
| 110 |
|
---|
| 111 | protected:
|
---|
| 112 | LhsNested m_lhs;
|
---|
| 113 | RhsNested m_rhs;
|
---|
| 114 | };
|
---|
| 115 |
|
---|
| 116 | template<typename Lhs, typename Rhs, bool Transpose>
|
---|
| 117 | class SparseDenseOuterProduct<Lhs,Rhs,Transpose>::InnerIterator : public _LhsNested::InnerIterator
|
---|
| 118 | {
|
---|
| 119 | typedef typename _LhsNested::InnerIterator Base;
|
---|
| 120 | typedef typename SparseDenseOuterProduct::Index Index;
|
---|
| 121 | public:
|
---|
| 122 | EIGEN_STRONG_INLINE InnerIterator(const SparseDenseOuterProduct& prod, Index outer)
|
---|
| 123 | : Base(prod.lhs(), 0), m_outer(outer), m_factor(get(prod.rhs(), outer, typename internal::traits<Rhs>::StorageKind() ))
|
---|
| 124 | { }
|
---|
| 125 |
|
---|
| 126 | inline Index outer() const { return m_outer; }
|
---|
| 127 | inline Index row() const { return Transpose ? m_outer : Base::index(); }
|
---|
| 128 | inline Index col() const { return Transpose ? Base::index() : m_outer; }
|
---|
| 129 |
|
---|
| 130 | inline Scalar value() const { return Base::value() * m_factor; }
|
---|
| 131 |
|
---|
| 132 | protected:
|
---|
| 133 | static Scalar get(const _RhsNested &rhs, Index outer, Dense = Dense())
|
---|
| 134 | {
|
---|
| 135 | return rhs.coeff(outer);
|
---|
| 136 | }
|
---|
| 137 |
|
---|
| 138 | static Scalar get(const _RhsNested &rhs, Index outer, Sparse = Sparse())
|
---|
| 139 | {
|
---|
| 140 | typename Traits::_RhsNested::InnerIterator it(rhs, outer);
|
---|
| 141 | if (it && it.index()==0)
|
---|
| 142 | return it.value();
|
---|
| 143 |
|
---|
| 144 | return Scalar(0);
|
---|
| 145 | }
|
---|
| 146 |
|
---|
| 147 | Index m_outer;
|
---|
| 148 | Scalar m_factor;
|
---|
| 149 | };
|
---|
| 150 |
|
---|
| 151 | namespace internal {
|
---|
| 152 | template<typename Lhs, typename Rhs>
|
---|
| 153 | struct traits<SparseTimeDenseProduct<Lhs,Rhs> >
|
---|
| 154 | : traits<ProductBase<SparseTimeDenseProduct<Lhs,Rhs>, Lhs, Rhs> >
|
---|
| 155 | {
|
---|
| 156 | typedef Dense StorageKind;
|
---|
| 157 | typedef MatrixXpr XprKind;
|
---|
| 158 | };
|
---|
| 159 |
|
---|
| 160 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,
|
---|
| 161 | int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor,
|
---|
| 162 | bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1>
|
---|
| 163 | struct sparse_time_dense_product_impl;
|
---|
| 164 |
|
---|
| 165 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
|
---|
| 166 | struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, true>
|
---|
| 167 | {
|
---|
| 168 | typedef typename internal::remove_all<SparseLhsType>::type Lhs;
|
---|
| 169 | typedef typename internal::remove_all<DenseRhsType>::type Rhs;
|
---|
| 170 | typedef typename internal::remove_all<DenseResType>::type Res;
|
---|
| 171 | typedef typename Lhs::Index Index;
|
---|
| 172 | typedef typename Lhs::InnerIterator LhsInnerIterator;
|
---|
| 173 | static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
|
---|
| 174 | {
|
---|
| 175 | for(Index c=0; c<rhs.cols(); ++c)
|
---|
| 176 | {
|
---|
| 177 | Index n = lhs.outerSize();
|
---|
| 178 | for(Index j=0; j<n; ++j)
|
---|
| 179 | {
|
---|
| 180 | typename Res::Scalar tmp(0);
|
---|
| 181 | for(LhsInnerIterator it(lhs,j); it ;++it)
|
---|
| 182 | tmp += it.value() * rhs.coeff(it.index(),c);
|
---|
| 183 | res.coeffRef(j,c) += alpha * tmp;
|
---|
| 184 | }
|
---|
| 185 | }
|
---|
| 186 | }
|
---|
| 187 | };
|
---|
| 188 |
|
---|
| 189 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
|
---|
| 190 | struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, true>
|
---|
| 191 | {
|
---|
| 192 | typedef typename internal::remove_all<SparseLhsType>::type Lhs;
|
---|
| 193 | typedef typename internal::remove_all<DenseRhsType>::type Rhs;
|
---|
| 194 | typedef typename internal::remove_all<DenseResType>::type Res;
|
---|
| 195 | typedef typename Lhs::InnerIterator LhsInnerIterator;
|
---|
| 196 | typedef typename Lhs::Index Index;
|
---|
| 197 | static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
|
---|
| 198 | {
|
---|
| 199 | for(Index c=0; c<rhs.cols(); ++c)
|
---|
| 200 | {
|
---|
| 201 | for(Index j=0; j<lhs.outerSize(); ++j)
|
---|
| 202 | {
|
---|
| 203 | typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c);
|
---|
| 204 | for(LhsInnerIterator it(lhs,j); it ;++it)
|
---|
| 205 | res.coeffRef(it.index(),c) += it.value() * rhs_j;
|
---|
| 206 | }
|
---|
| 207 | }
|
---|
| 208 | }
|
---|
| 209 | };
|
---|
| 210 |
|
---|
| 211 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
|
---|
| 212 | struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, false>
|
---|
| 213 | {
|
---|
| 214 | typedef typename internal::remove_all<SparseLhsType>::type Lhs;
|
---|
| 215 | typedef typename internal::remove_all<DenseRhsType>::type Rhs;
|
---|
| 216 | typedef typename internal::remove_all<DenseResType>::type Res;
|
---|
| 217 | typedef typename Lhs::InnerIterator LhsInnerIterator;
|
---|
| 218 | typedef typename Lhs::Index Index;
|
---|
| 219 | static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
|
---|
| 220 | {
|
---|
| 221 | for(Index j=0; j<lhs.outerSize(); ++j)
|
---|
| 222 | {
|
---|
| 223 | typename Res::RowXpr res_j(res.row(j));
|
---|
| 224 | for(LhsInnerIterator it(lhs,j); it ;++it)
|
---|
| 225 | res_j += (alpha*it.value()) * rhs.row(it.index());
|
---|
| 226 | }
|
---|
| 227 | }
|
---|
| 228 | };
|
---|
| 229 |
|
---|
| 230 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
|
---|
| 231 | struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, false>
|
---|
| 232 | {
|
---|
| 233 | typedef typename internal::remove_all<SparseLhsType>::type Lhs;
|
---|
| 234 | typedef typename internal::remove_all<DenseRhsType>::type Rhs;
|
---|
| 235 | typedef typename internal::remove_all<DenseResType>::type Res;
|
---|
| 236 | typedef typename Lhs::InnerIterator LhsInnerIterator;
|
---|
| 237 | typedef typename Lhs::Index Index;
|
---|
| 238 | static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
|
---|
| 239 | {
|
---|
| 240 | for(Index j=0; j<lhs.outerSize(); ++j)
|
---|
| 241 | {
|
---|
| 242 | typename Rhs::ConstRowXpr rhs_j(rhs.row(j));
|
---|
| 243 | for(LhsInnerIterator it(lhs,j); it ;++it)
|
---|
| 244 | res.row(it.index()) += (alpha*it.value()) * rhs_j;
|
---|
| 245 | }
|
---|
| 246 | }
|
---|
| 247 | };
|
---|
| 248 |
|
---|
| 249 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType>
|
---|
| 250 | inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
|
---|
| 251 | {
|
---|
| 252 | sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType>::run(lhs, rhs, res, alpha);
|
---|
| 253 | }
|
---|
| 254 |
|
---|
| 255 | } // end namespace internal
|
---|
| 256 |
|
---|
| 257 | template<typename Lhs, typename Rhs>
|
---|
| 258 | class SparseTimeDenseProduct
|
---|
| 259 | : public ProductBase<SparseTimeDenseProduct<Lhs,Rhs>, Lhs, Rhs>
|
---|
| 260 | {
|
---|
| 261 | public:
|
---|
| 262 | EIGEN_PRODUCT_PUBLIC_INTERFACE(SparseTimeDenseProduct)
|
---|
| 263 |
|
---|
| 264 | SparseTimeDenseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
|
---|
| 265 | {}
|
---|
| 266 |
|
---|
| 267 | template<typename Dest> void scaleAndAddTo(Dest& dest, const Scalar& alpha) const
|
---|
| 268 | {
|
---|
| 269 | internal::sparse_time_dense_product(m_lhs, m_rhs, dest, alpha);
|
---|
| 270 | }
|
---|
| 271 |
|
---|
| 272 | private:
|
---|
| 273 | SparseTimeDenseProduct& operator=(const SparseTimeDenseProduct&);
|
---|
| 274 | };
|
---|
| 275 |
|
---|
| 276 |
|
---|
| 277 | // dense = dense * sparse
|
---|
| 278 | namespace internal {
|
---|
| 279 | template<typename Lhs, typename Rhs>
|
---|
| 280 | struct traits<DenseTimeSparseProduct<Lhs,Rhs> >
|
---|
| 281 | : traits<ProductBase<DenseTimeSparseProduct<Lhs,Rhs>, Lhs, Rhs> >
|
---|
| 282 | {
|
---|
| 283 | typedef Dense StorageKind;
|
---|
| 284 | };
|
---|
| 285 | } // end namespace internal
|
---|
| 286 |
|
---|
| 287 | template<typename Lhs, typename Rhs>
|
---|
| 288 | class DenseTimeSparseProduct
|
---|
| 289 | : public ProductBase<DenseTimeSparseProduct<Lhs,Rhs>, Lhs, Rhs>
|
---|
| 290 | {
|
---|
| 291 | public:
|
---|
| 292 | EIGEN_PRODUCT_PUBLIC_INTERFACE(DenseTimeSparseProduct)
|
---|
| 293 |
|
---|
| 294 | DenseTimeSparseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
|
---|
| 295 | {}
|
---|
| 296 |
|
---|
| 297 | template<typename Dest> void scaleAndAddTo(Dest& dest, const Scalar& alpha) const
|
---|
| 298 | {
|
---|
| 299 | Transpose<const _LhsNested> lhs_t(m_lhs);
|
---|
| 300 | Transpose<const _RhsNested> rhs_t(m_rhs);
|
---|
| 301 | Transpose<Dest> dest_t(dest);
|
---|
| 302 | internal::sparse_time_dense_product(rhs_t, lhs_t, dest_t, alpha);
|
---|
| 303 | }
|
---|
| 304 |
|
---|
| 305 | private:
|
---|
| 306 | DenseTimeSparseProduct& operator=(const DenseTimeSparseProduct&);
|
---|
| 307 | };
|
---|
| 308 |
|
---|
| 309 | } // end namespace Eigen
|
---|
| 310 |
|
---|
| 311 | #endif // EIGEN_SPARSEDENSEPRODUCT_H
|
---|