1 | // This file is part of Eigen, a lightweight C++ template library
|
---|
2 | // for linear algebra.
|
---|
3 | //
|
---|
4 | // Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
|
---|
5 | //
|
---|
6 | // This Source Code Form is subject to the terms of the Mozilla
|
---|
7 | // Public License v. 2.0. If a copy of the MPL was not distributed
|
---|
8 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
---|
9 |
|
---|
10 | #ifndef EIGEN_SPARSEDENSEPRODUCT_H
|
---|
11 | #define EIGEN_SPARSEDENSEPRODUCT_H
|
---|
12 |
|
---|
13 | namespace Eigen {
|
---|
14 |
|
---|
15 | template<typename Lhs, typename Rhs, int InnerSize> struct SparseDenseProductReturnType
|
---|
16 | {
|
---|
17 | typedef SparseTimeDenseProduct<Lhs,Rhs> Type;
|
---|
18 | };
|
---|
19 |
|
---|
20 | template<typename Lhs, typename Rhs> struct SparseDenseProductReturnType<Lhs,Rhs,1>
|
---|
21 | {
|
---|
22 | typedef typename internal::conditional<
|
---|
23 | Lhs::IsRowMajor,
|
---|
24 | SparseDenseOuterProduct<Rhs,Lhs,true>,
|
---|
25 | SparseDenseOuterProduct<Lhs,Rhs,false> >::type Type;
|
---|
26 | };
|
---|
27 |
|
---|
28 | template<typename Lhs, typename Rhs, int InnerSize> struct DenseSparseProductReturnType
|
---|
29 | {
|
---|
30 | typedef DenseTimeSparseProduct<Lhs,Rhs> Type;
|
---|
31 | };
|
---|
32 |
|
---|
33 | template<typename Lhs, typename Rhs> struct DenseSparseProductReturnType<Lhs,Rhs,1>
|
---|
34 | {
|
---|
35 | typedef typename internal::conditional<
|
---|
36 | Rhs::IsRowMajor,
|
---|
37 | SparseDenseOuterProduct<Rhs,Lhs,true>,
|
---|
38 | SparseDenseOuterProduct<Lhs,Rhs,false> >::type Type;
|
---|
39 | };
|
---|
40 |
|
---|
41 | namespace internal {
|
---|
42 |
|
---|
43 | template<typename Lhs, typename Rhs, bool Tr>
|
---|
44 | struct traits<SparseDenseOuterProduct<Lhs,Rhs,Tr> >
|
---|
45 | {
|
---|
46 | typedef Sparse StorageKind;
|
---|
47 | typedef typename scalar_product_traits<typename traits<Lhs>::Scalar,
|
---|
48 | typename traits<Rhs>::Scalar>::ReturnType Scalar;
|
---|
49 | typedef typename Lhs::Index Index;
|
---|
50 | typedef typename Lhs::Nested LhsNested;
|
---|
51 | typedef typename Rhs::Nested RhsNested;
|
---|
52 | typedef typename remove_all<LhsNested>::type _LhsNested;
|
---|
53 | typedef typename remove_all<RhsNested>::type _RhsNested;
|
---|
54 |
|
---|
55 | enum {
|
---|
56 | LhsCoeffReadCost = traits<_LhsNested>::CoeffReadCost,
|
---|
57 | RhsCoeffReadCost = traits<_RhsNested>::CoeffReadCost,
|
---|
58 |
|
---|
59 | RowsAtCompileTime = Tr ? int(traits<Rhs>::RowsAtCompileTime) : int(traits<Lhs>::RowsAtCompileTime),
|
---|
60 | ColsAtCompileTime = Tr ? int(traits<Lhs>::ColsAtCompileTime) : int(traits<Rhs>::ColsAtCompileTime),
|
---|
61 | MaxRowsAtCompileTime = Tr ? int(traits<Rhs>::MaxRowsAtCompileTime) : int(traits<Lhs>::MaxRowsAtCompileTime),
|
---|
62 | MaxColsAtCompileTime = Tr ? int(traits<Lhs>::MaxColsAtCompileTime) : int(traits<Rhs>::MaxColsAtCompileTime),
|
---|
63 |
|
---|
64 | Flags = Tr ? RowMajorBit : 0,
|
---|
65 |
|
---|
66 | CoeffReadCost = LhsCoeffReadCost + RhsCoeffReadCost + NumTraits<Scalar>::MulCost
|
---|
67 | };
|
---|
68 | };
|
---|
69 |
|
---|
70 | } // end namespace internal
|
---|
71 |
|
---|
72 | template<typename Lhs, typename Rhs, bool Tr>
|
---|
73 | class SparseDenseOuterProduct
|
---|
74 | : public SparseMatrixBase<SparseDenseOuterProduct<Lhs,Rhs,Tr> >
|
---|
75 | {
|
---|
76 | public:
|
---|
77 |
|
---|
78 | typedef SparseMatrixBase<SparseDenseOuterProduct> Base;
|
---|
79 | EIGEN_DENSE_PUBLIC_INTERFACE(SparseDenseOuterProduct)
|
---|
80 | typedef internal::traits<SparseDenseOuterProduct> Traits;
|
---|
81 |
|
---|
82 | private:
|
---|
83 |
|
---|
84 | typedef typename Traits::LhsNested LhsNested;
|
---|
85 | typedef typename Traits::RhsNested RhsNested;
|
---|
86 | typedef typename Traits::_LhsNested _LhsNested;
|
---|
87 | typedef typename Traits::_RhsNested _RhsNested;
|
---|
88 |
|
---|
89 | public:
|
---|
90 |
|
---|
91 | class InnerIterator;
|
---|
92 |
|
---|
93 | EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Lhs& lhs, const Rhs& rhs)
|
---|
94 | : m_lhs(lhs), m_rhs(rhs)
|
---|
95 | {
|
---|
96 | EIGEN_STATIC_ASSERT(!Tr,YOU_MADE_A_PROGRAMMING_MISTAKE);
|
---|
97 | }
|
---|
98 |
|
---|
99 | EIGEN_STRONG_INLINE SparseDenseOuterProduct(const Rhs& rhs, const Lhs& lhs)
|
---|
100 | : m_lhs(lhs), m_rhs(rhs)
|
---|
101 | {
|
---|
102 | EIGEN_STATIC_ASSERT(Tr,YOU_MADE_A_PROGRAMMING_MISTAKE);
|
---|
103 | }
|
---|
104 |
|
---|
105 | EIGEN_STRONG_INLINE Index rows() const { return Tr ? m_rhs.rows() : m_lhs.rows(); }
|
---|
106 | EIGEN_STRONG_INLINE Index cols() const { return Tr ? m_lhs.cols() : m_rhs.cols(); }
|
---|
107 |
|
---|
108 | EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; }
|
---|
109 | EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; }
|
---|
110 |
|
---|
111 | protected:
|
---|
112 | LhsNested m_lhs;
|
---|
113 | RhsNested m_rhs;
|
---|
114 | };
|
---|
115 |
|
---|
116 | template<typename Lhs, typename Rhs, bool Transpose>
|
---|
117 | class SparseDenseOuterProduct<Lhs,Rhs,Transpose>::InnerIterator : public _LhsNested::InnerIterator
|
---|
118 | {
|
---|
119 | typedef typename _LhsNested::InnerIterator Base;
|
---|
120 | typedef typename SparseDenseOuterProduct::Index Index;
|
---|
121 | public:
|
---|
122 | EIGEN_STRONG_INLINE InnerIterator(const SparseDenseOuterProduct& prod, Index outer)
|
---|
123 | : Base(prod.lhs(), 0), m_outer(outer), m_factor(get(prod.rhs(), outer, typename internal::traits<Rhs>::StorageKind() ))
|
---|
124 | { }
|
---|
125 |
|
---|
126 | inline Index outer() const { return m_outer; }
|
---|
127 | inline Index row() const { return Transpose ? m_outer : Base::index(); }
|
---|
128 | inline Index col() const { return Transpose ? Base::index() : m_outer; }
|
---|
129 |
|
---|
130 | inline Scalar value() const { return Base::value() * m_factor; }
|
---|
131 |
|
---|
132 | protected:
|
---|
133 | static Scalar get(const _RhsNested &rhs, Index outer, Dense = Dense())
|
---|
134 | {
|
---|
135 | return rhs.coeff(outer);
|
---|
136 | }
|
---|
137 |
|
---|
138 | static Scalar get(const _RhsNested &rhs, Index outer, Sparse = Sparse())
|
---|
139 | {
|
---|
140 | typename Traits::_RhsNested::InnerIterator it(rhs, outer);
|
---|
141 | if (it && it.index()==0)
|
---|
142 | return it.value();
|
---|
143 |
|
---|
144 | return Scalar(0);
|
---|
145 | }
|
---|
146 |
|
---|
147 | Index m_outer;
|
---|
148 | Scalar m_factor;
|
---|
149 | };
|
---|
150 |
|
---|
151 | namespace internal {
|
---|
152 | template<typename Lhs, typename Rhs>
|
---|
153 | struct traits<SparseTimeDenseProduct<Lhs,Rhs> >
|
---|
154 | : traits<ProductBase<SparseTimeDenseProduct<Lhs,Rhs>, Lhs, Rhs> >
|
---|
155 | {
|
---|
156 | typedef Dense StorageKind;
|
---|
157 | typedef MatrixXpr XprKind;
|
---|
158 | };
|
---|
159 |
|
---|
160 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,
|
---|
161 | int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor,
|
---|
162 | bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1>
|
---|
163 | struct sparse_time_dense_product_impl;
|
---|
164 |
|
---|
165 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
|
---|
166 | struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, true>
|
---|
167 | {
|
---|
168 | typedef typename internal::remove_all<SparseLhsType>::type Lhs;
|
---|
169 | typedef typename internal::remove_all<DenseRhsType>::type Rhs;
|
---|
170 | typedef typename internal::remove_all<DenseResType>::type Res;
|
---|
171 | typedef typename Lhs::Index Index;
|
---|
172 | typedef typename Lhs::InnerIterator LhsInnerIterator;
|
---|
173 | static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
|
---|
174 | {
|
---|
175 | for(Index c=0; c<rhs.cols(); ++c)
|
---|
176 | {
|
---|
177 | Index n = lhs.outerSize();
|
---|
178 | for(Index j=0; j<n; ++j)
|
---|
179 | {
|
---|
180 | typename Res::Scalar tmp(0);
|
---|
181 | for(LhsInnerIterator it(lhs,j); it ;++it)
|
---|
182 | tmp += it.value() * rhs.coeff(it.index(),c);
|
---|
183 | res.coeffRef(j,c) += alpha * tmp;
|
---|
184 | }
|
---|
185 | }
|
---|
186 | }
|
---|
187 | };
|
---|
188 |
|
---|
189 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
|
---|
190 | struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, true>
|
---|
191 | {
|
---|
192 | typedef typename internal::remove_all<SparseLhsType>::type Lhs;
|
---|
193 | typedef typename internal::remove_all<DenseRhsType>::type Rhs;
|
---|
194 | typedef typename internal::remove_all<DenseResType>::type Res;
|
---|
195 | typedef typename Lhs::InnerIterator LhsInnerIterator;
|
---|
196 | typedef typename Lhs::Index Index;
|
---|
197 | static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
|
---|
198 | {
|
---|
199 | for(Index c=0; c<rhs.cols(); ++c)
|
---|
200 | {
|
---|
201 | for(Index j=0; j<lhs.outerSize(); ++j)
|
---|
202 | {
|
---|
203 | typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c);
|
---|
204 | for(LhsInnerIterator it(lhs,j); it ;++it)
|
---|
205 | res.coeffRef(it.index(),c) += it.value() * rhs_j;
|
---|
206 | }
|
---|
207 | }
|
---|
208 | }
|
---|
209 | };
|
---|
210 |
|
---|
211 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
|
---|
212 | struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, false>
|
---|
213 | {
|
---|
214 | typedef typename internal::remove_all<SparseLhsType>::type Lhs;
|
---|
215 | typedef typename internal::remove_all<DenseRhsType>::type Rhs;
|
---|
216 | typedef typename internal::remove_all<DenseResType>::type Res;
|
---|
217 | typedef typename Lhs::InnerIterator LhsInnerIterator;
|
---|
218 | typedef typename Lhs::Index Index;
|
---|
219 | static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
|
---|
220 | {
|
---|
221 | for(Index j=0; j<lhs.outerSize(); ++j)
|
---|
222 | {
|
---|
223 | typename Res::RowXpr res_j(res.row(j));
|
---|
224 | for(LhsInnerIterator it(lhs,j); it ;++it)
|
---|
225 | res_j += (alpha*it.value()) * rhs.row(it.index());
|
---|
226 | }
|
---|
227 | }
|
---|
228 | };
|
---|
229 |
|
---|
230 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
|
---|
231 | struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, false>
|
---|
232 | {
|
---|
233 | typedef typename internal::remove_all<SparseLhsType>::type Lhs;
|
---|
234 | typedef typename internal::remove_all<DenseRhsType>::type Rhs;
|
---|
235 | typedef typename internal::remove_all<DenseResType>::type Res;
|
---|
236 | typedef typename Lhs::InnerIterator LhsInnerIterator;
|
---|
237 | typedef typename Lhs::Index Index;
|
---|
238 | static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
|
---|
239 | {
|
---|
240 | for(Index j=0; j<lhs.outerSize(); ++j)
|
---|
241 | {
|
---|
242 | typename Rhs::ConstRowXpr rhs_j(rhs.row(j));
|
---|
243 | for(LhsInnerIterator it(lhs,j); it ;++it)
|
---|
244 | res.row(it.index()) += (alpha*it.value()) * rhs_j;
|
---|
245 | }
|
---|
246 | }
|
---|
247 | };
|
---|
248 |
|
---|
249 | template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType>
|
---|
250 | inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
|
---|
251 | {
|
---|
252 | sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType>::run(lhs, rhs, res, alpha);
|
---|
253 | }
|
---|
254 |
|
---|
255 | } // end namespace internal
|
---|
256 |
|
---|
257 | template<typename Lhs, typename Rhs>
|
---|
258 | class SparseTimeDenseProduct
|
---|
259 | : public ProductBase<SparseTimeDenseProduct<Lhs,Rhs>, Lhs, Rhs>
|
---|
260 | {
|
---|
261 | public:
|
---|
262 | EIGEN_PRODUCT_PUBLIC_INTERFACE(SparseTimeDenseProduct)
|
---|
263 |
|
---|
264 | SparseTimeDenseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
|
---|
265 | {}
|
---|
266 |
|
---|
267 | template<typename Dest> void scaleAndAddTo(Dest& dest, const Scalar& alpha) const
|
---|
268 | {
|
---|
269 | internal::sparse_time_dense_product(m_lhs, m_rhs, dest, alpha);
|
---|
270 | }
|
---|
271 |
|
---|
272 | private:
|
---|
273 | SparseTimeDenseProduct& operator=(const SparseTimeDenseProduct&);
|
---|
274 | };
|
---|
275 |
|
---|
276 |
|
---|
277 | // dense = dense * sparse
|
---|
278 | namespace internal {
|
---|
279 | template<typename Lhs, typename Rhs>
|
---|
280 | struct traits<DenseTimeSparseProduct<Lhs,Rhs> >
|
---|
281 | : traits<ProductBase<DenseTimeSparseProduct<Lhs,Rhs>, Lhs, Rhs> >
|
---|
282 | {
|
---|
283 | typedef Dense StorageKind;
|
---|
284 | };
|
---|
285 | } // end namespace internal
|
---|
286 |
|
---|
287 | template<typename Lhs, typename Rhs>
|
---|
288 | class DenseTimeSparseProduct
|
---|
289 | : public ProductBase<DenseTimeSparseProduct<Lhs,Rhs>, Lhs, Rhs>
|
---|
290 | {
|
---|
291 | public:
|
---|
292 | EIGEN_PRODUCT_PUBLIC_INTERFACE(DenseTimeSparseProduct)
|
---|
293 |
|
---|
294 | DenseTimeSparseProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs)
|
---|
295 | {}
|
---|
296 |
|
---|
297 | template<typename Dest> void scaleAndAddTo(Dest& dest, const Scalar& alpha) const
|
---|
298 | {
|
---|
299 | Transpose<const _LhsNested> lhs_t(m_lhs);
|
---|
300 | Transpose<const _RhsNested> rhs_t(m_rhs);
|
---|
301 | Transpose<Dest> dest_t(dest);
|
---|
302 | internal::sparse_time_dense_product(rhs_t, lhs_t, dest_t, alpha);
|
---|
303 | }
|
---|
304 |
|
---|
305 | private:
|
---|
306 | DenseTimeSparseProduct& operator=(const DenseTimeSparseProduct&);
|
---|
307 | };
|
---|
308 |
|
---|
309 | } // end namespace Eigen
|
---|
310 |
|
---|
311 | #endif // EIGEN_SPARSEDENSEPRODUCT_H
|
---|