1 | // This file is part of Eigen, a lightweight C++ template library
|
---|
2 | // for linear algebra.
|
---|
3 | //
|
---|
4 | // Copyright (C) 2011 Kolja Brix <brix@igpm.rwth-aachen.de>
|
---|
5 | // Copyright (C) 2011 Andreas Platen <andiplaten@gmx.de>
|
---|
6 | // Copyright (C) 2012 Chen-Pang He <jdh8@ms63.hinet.net>
|
---|
7 | //
|
---|
8 | // This Source Code Form is subject to the terms of the Mozilla
|
---|
9 | // Public License v. 2.0. If a copy of the MPL was not distributed
|
---|
10 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
---|
11 |
|
---|
12 | #ifndef KRONECKER_TENSOR_PRODUCT_H
|
---|
13 | #define KRONECKER_TENSOR_PRODUCT_H
|
---|
14 |
|
---|
15 | namespace Eigen {
|
---|
16 |
|
---|
17 | template<typename Scalar, int Options, typename Index> class SparseMatrix;
|
---|
18 |
|
---|
19 | /*!
|
---|
20 | * \brief Kronecker tensor product helper class for dense matrices
|
---|
21 | *
|
---|
22 | * This class is the return value of kroneckerProduct(MatrixBase,
|
---|
23 | * MatrixBase). Use the function rather than construct this class
|
---|
24 | * directly to avoid specifying template prarameters.
|
---|
25 | *
|
---|
26 | * \tparam Lhs Type of the left-hand side, a matrix expression.
|
---|
27 | * \tparam Rhs Type of the rignt-hand side, a matrix expression.
|
---|
28 | */
|
---|
29 | template<typename Lhs, typename Rhs>
|
---|
30 | class KroneckerProduct : public ReturnByValue<KroneckerProduct<Lhs,Rhs> >
|
---|
31 | {
|
---|
32 | private:
|
---|
33 | typedef ReturnByValue<KroneckerProduct> Base;
|
---|
34 | typedef typename Base::Scalar Scalar;
|
---|
35 | typedef typename Base::Index Index;
|
---|
36 |
|
---|
37 | public:
|
---|
38 | /*! \brief Constructor. */
|
---|
39 | KroneckerProduct(const Lhs& A, const Rhs& B)
|
---|
40 | : m_A(A), m_B(B)
|
---|
41 | {}
|
---|
42 |
|
---|
43 | /*! \brief Evaluate the Kronecker tensor product. */
|
---|
44 | template<typename Dest> void evalTo(Dest& dst) const;
|
---|
45 |
|
---|
46 | inline Index rows() const { return m_A.rows() * m_B.rows(); }
|
---|
47 | inline Index cols() const { return m_A.cols() * m_B.cols(); }
|
---|
48 |
|
---|
49 | Scalar coeff(Index row, Index col) const
|
---|
50 | {
|
---|
51 | return m_A.coeff(row / m_B.rows(), col / m_B.cols()) *
|
---|
52 | m_B.coeff(row % m_B.rows(), col % m_B.cols());
|
---|
53 | }
|
---|
54 |
|
---|
55 | Scalar coeff(Index i) const
|
---|
56 | {
|
---|
57 | EIGEN_STATIC_ASSERT_VECTOR_ONLY(KroneckerProduct);
|
---|
58 | return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size());
|
---|
59 | }
|
---|
60 |
|
---|
61 | private:
|
---|
62 | typename Lhs::Nested m_A;
|
---|
63 | typename Rhs::Nested m_B;
|
---|
64 | };
|
---|
65 |
|
---|
66 | /*!
|
---|
67 | * \brief Kronecker tensor product helper class for sparse matrices
|
---|
68 | *
|
---|
69 | * If at least one of the operands is a sparse matrix expression,
|
---|
70 | * then this class is returned and evaluates into a sparse matrix.
|
---|
71 | *
|
---|
72 | * This class is the return value of kroneckerProduct(EigenBase,
|
---|
73 | * EigenBase). Use the function rather than construct this class
|
---|
74 | * directly to avoid specifying template prarameters.
|
---|
75 | *
|
---|
76 | * \tparam Lhs Type of the left-hand side, a matrix expression.
|
---|
77 | * \tparam Rhs Type of the rignt-hand side, a matrix expression.
|
---|
78 | */
|
---|
79 | template<typename Lhs, typename Rhs>
|
---|
80 | class KroneckerProductSparse : public EigenBase<KroneckerProductSparse<Lhs,Rhs> >
|
---|
81 | {
|
---|
82 | private:
|
---|
83 | typedef typename internal::traits<KroneckerProductSparse>::Index Index;
|
---|
84 |
|
---|
85 | public:
|
---|
86 | /*! \brief Constructor. */
|
---|
87 | KroneckerProductSparse(const Lhs& A, const Rhs& B)
|
---|
88 | : m_A(A), m_B(B)
|
---|
89 | {}
|
---|
90 |
|
---|
91 | /*! \brief Evaluate the Kronecker tensor product. */
|
---|
92 | template<typename Dest> void evalTo(Dest& dst) const;
|
---|
93 |
|
---|
94 | inline Index rows() const { return m_A.rows() * m_B.rows(); }
|
---|
95 | inline Index cols() const { return m_A.cols() * m_B.cols(); }
|
---|
96 |
|
---|
97 | template<typename Scalar, int Options, typename Index>
|
---|
98 | operator SparseMatrix<Scalar, Options, Index>()
|
---|
99 | {
|
---|
100 | SparseMatrix<Scalar, Options, Index> result;
|
---|
101 | evalTo(result.derived());
|
---|
102 | return result;
|
---|
103 | }
|
---|
104 |
|
---|
105 | private:
|
---|
106 | typename Lhs::Nested m_A;
|
---|
107 | typename Rhs::Nested m_B;
|
---|
108 | };
|
---|
109 |
|
---|
110 | template<typename Lhs, typename Rhs>
|
---|
111 | template<typename Dest>
|
---|
112 | void KroneckerProduct<Lhs,Rhs>::evalTo(Dest& dst) const
|
---|
113 | {
|
---|
114 | const int BlockRows = Rhs::RowsAtCompileTime,
|
---|
115 | BlockCols = Rhs::ColsAtCompileTime;
|
---|
116 | const Index Br = m_B.rows(),
|
---|
117 | Bc = m_B.cols();
|
---|
118 | for (Index i=0; i < m_A.rows(); ++i)
|
---|
119 | for (Index j=0; j < m_A.cols(); ++j)
|
---|
120 | Block<Dest,BlockRows,BlockCols>(dst,i*Br,j*Bc,Br,Bc) = m_A.coeff(i,j) * m_B;
|
---|
121 | }
|
---|
122 |
|
---|
123 | template<typename Lhs, typename Rhs>
|
---|
124 | template<typename Dest>
|
---|
125 | void KroneckerProductSparse<Lhs,Rhs>::evalTo(Dest& dst) const
|
---|
126 | {
|
---|
127 | const Index Br = m_B.rows(),
|
---|
128 | Bc = m_B.cols();
|
---|
129 | dst.resize(rows(),cols());
|
---|
130 | dst.resizeNonZeros(0);
|
---|
131 | dst.reserve(m_A.nonZeros() * m_B.nonZeros());
|
---|
132 |
|
---|
133 | for (Index kA=0; kA < m_A.outerSize(); ++kA)
|
---|
134 | {
|
---|
135 | for (Index kB=0; kB < m_B.outerSize(); ++kB)
|
---|
136 | {
|
---|
137 | for (typename Lhs::InnerIterator itA(m_A,kA); itA; ++itA)
|
---|
138 | {
|
---|
139 | for (typename Rhs::InnerIterator itB(m_B,kB); itB; ++itB)
|
---|
140 | {
|
---|
141 | const Index i = itA.row() * Br + itB.row(),
|
---|
142 | j = itA.col() * Bc + itB.col();
|
---|
143 | dst.insert(i,j) = itA.value() * itB.value();
|
---|
144 | }
|
---|
145 | }
|
---|
146 | }
|
---|
147 | }
|
---|
148 | }
|
---|
149 |
|
---|
namespace internal {

/*
 * Traits of the dense Kronecker product expression.
 *
 * Compile-time extents are combined with size_at_compile_time, which is
 * used here to multiply the extent of Lhs with the extent of Rhs (matching
 * rows() = A.rows() * B.rows()). NOTE(review): presumably yields Dynamic
 * when either extent is Dynamic -- confirm against its definition.
 */
template<typename _Lhs, typename _Rhs>
struct traits<KroneckerProduct<_Lhs,_Rhs> >
{
  // Operand types with references/const qualifiers stripped.
  typedef typename remove_all<_Lhs>::type Lhs;
  typedef typename remove_all<_Rhs>::type Rhs;
  // Scalar type of a product of one Lhs and one Rhs coefficient.
  typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;

  enum {
    Rows = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret,
    Cols = size_at_compile_time<traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime>::ret,
    MaxRows = size_at_compile_time<traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime>::ret,
    MaxCols = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret,
    // Reading one coefficient reads one coefficient of each operand and multiplies them.
    CoeffReadCost = Lhs::CoeffReadCost + Rhs::CoeffReadCost + NumTraits<Scalar>::MulCost
  };

  // Plain matrix type the expression evaluates into.
  // NOTE(review): presumably consumed by ReturnByValue to pick the result type -- confirm.
  typedef Matrix<Scalar,Rows,Cols> ReturnType;
};

/*
 * Traits of the sparse Kronecker product expression (at least one operand sparse).
 */
template<typename _Lhs, typename _Rhs>
struct traits<KroneckerProductSparse<_Lhs,_Rhs> >
{
  typedef MatrixXpr XprKind;
  // Operand types with references/const qualifiers stripped.
  typedef typename remove_all<_Lhs>::type Lhs;
  typedef typename remove_all<_Rhs>::type Rhs;
  typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
  // Storage kind and index type are promoted from those of the two operands.
  typedef typename promote_storage_type<typename traits<Lhs>::StorageKind, typename traits<Rhs>::StorageKind>::ret StorageKind;
  typedef typename promote_index_type<typename Lhs::Index, typename Rhs::Index>::type Index;

  enum {
    LhsFlags = Lhs::Flags,
    RhsFlags = Rhs::Flags,

    RowsAtCompileTime = size_at_compile_time<traits<Lhs>::RowsAtCompileTime, traits<Rhs>::RowsAtCompileTime>::ret,
    ColsAtCompileTime = size_at_compile_time<traits<Lhs>::ColsAtCompileTime, traits<Rhs>::ColsAtCompileTime>::ret,
    MaxRowsAtCompileTime = size_at_compile_time<traits<Lhs>::MaxRowsAtCompileTime, traits<Rhs>::MaxRowsAtCompileTime>::ret,
    MaxColsAtCompileTime = size_at_compile_time<traits<Lhs>::MaxColsAtCompileTime, traits<Rhs>::MaxColsAtCompileTime>::ret,

    // The result is row-major only when both operands are row-major.
    EvalToRowMajor = (LhsFlags & RhsFlags & RowMajorBit),
    // Mask that clears RowMajorBit unless EvalToRowMajor is set.
    RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit),

    // Hereditary bits of either operand (storage order fixed up above), and
    // the expression must be evaluated before nesting/assigning.
    Flags = ((LhsFlags | RhsFlags) & HereditaryBits & RemovedBits)
          | EvalBeforeNestingBit | EvalBeforeAssigningBit,
    CoeffReadCost = Dynamic
  };
};

} // end namespace internal
|
---|
199 |
|
---|
200 | /*!
|
---|
201 | * \ingroup KroneckerProduct_Module
|
---|
202 | *
|
---|
203 | * Computes Kronecker tensor product of two dense matrices
|
---|
204 | *
|
---|
205 | * \warning If you want to replace a matrix by its Kronecker product
|
---|
206 | * with some matrix, do \b NOT do this:
|
---|
207 | * \code
|
---|
208 | * A = kroneckerProduct(A,B); // bug!!! caused by aliasing effect
|
---|
209 | * \endcode
|
---|
210 | * instead, use eval() to work around this:
|
---|
211 | * \code
|
---|
212 | * A = kroneckerProduct(A,B).eval();
|
---|
213 | * \endcode
|
---|
214 | *
|
---|
215 | * \param a Dense matrix a
|
---|
216 | * \param b Dense matrix b
|
---|
217 | * \return Kronecker tensor product of a and b
|
---|
218 | */
|
---|
219 | template<typename A, typename B>
|
---|
220 | KroneckerProduct<A,B> kroneckerProduct(const MatrixBase<A>& a, const MatrixBase<B>& b)
|
---|
221 | {
|
---|
222 | return KroneckerProduct<A, B>(a.derived(), b.derived());
|
---|
223 | }
|
---|
224 |
|
---|
225 | /*!
|
---|
226 | * \ingroup KroneckerProduct_Module
|
---|
227 | *
|
---|
228 | * Computes Kronecker tensor product of two matrices, at least one of
|
---|
229 | * which is sparse
|
---|
230 | *
|
---|
231 | * \param a Dense/sparse matrix a
|
---|
232 | * \param b Dense/sparse matrix b
|
---|
233 | * \return Kronecker tensor product of a and b, stored in a sparse
|
---|
234 | * matrix
|
---|
235 | */
|
---|
236 | template<typename A, typename B>
|
---|
237 | KroneckerProductSparse<A,B> kroneckerProduct(const EigenBase<A>& a, const EigenBase<B>& b)
|
---|
238 | {
|
---|
239 | return KroneckerProductSparse<A,B>(a.derived(), b.derived());
|
---|
240 | }
|
---|
241 |
|
---|
242 | } // end namespace Eigen
|
---|
243 |
|
---|
244 | #endif // KRONECKER_TENSOR_PRODUCT_H
|
---|