[136] | 1 | // This file is part of Eigen, a lightweight C++ template library
|
---|
| 2 | // for linear algebra.
|
---|
| 3 | //
|
---|
| 4 | // Copyright (C) 2008 Gael Guennebaud <gael.guennebaud@inria.fr>
|
---|
| 5 | // Copyright (C) 2007-2009 Benoit Jacob <jacob.benoit.1@gmail.com>
|
---|
| 6 | //
|
---|
| 7 | // This Source Code Form is subject to the terms of the Mozilla
|
---|
| 8 | // Public License v. 2.0. If a copy of the MPL was not distributed
|
---|
| 9 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
---|
| 10 |
|
---|
| 11 | #ifndef EIGEN_DIAGONALPRODUCT_H
|
---|
| 12 | #define EIGEN_DIAGONALPRODUCT_H
|
---|
| 13 |
|
---|
| 14 | namespace Eigen {
|
---|
| 15 |
|
---|
| 16 | namespace internal {
|
---|
| 17 | template<typename MatrixType, typename DiagonalType, int ProductOrder>
|
---|
| 18 | struct traits<DiagonalProduct<MatrixType, DiagonalType, ProductOrder> >
|
---|
| 19 | : traits<MatrixType>
|
---|
| 20 | {
|
---|
| 21 | typedef typename scalar_product_traits<typename MatrixType::Scalar, typename DiagonalType::Scalar>::ReturnType Scalar;
|
---|
| 22 | enum {
|
---|
| 23 | RowsAtCompileTime = MatrixType::RowsAtCompileTime,
|
---|
| 24 | ColsAtCompileTime = MatrixType::ColsAtCompileTime,
|
---|
| 25 | MaxRowsAtCompileTime = MatrixType::MaxRowsAtCompileTime,
|
---|
| 26 | MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime,
|
---|
| 27 |
|
---|
| 28 | _StorageOrder = MatrixType::Flags & RowMajorBit ? RowMajor : ColMajor,
|
---|
| 29 | _ScalarAccessOnDiag = !((int(_StorageOrder) == ColMajor && int(ProductOrder) == OnTheLeft)
|
---|
| 30 | ||(int(_StorageOrder) == RowMajor && int(ProductOrder) == OnTheRight)),
|
---|
| 31 | _SameTypes = is_same<typename MatrixType::Scalar, typename DiagonalType::Scalar>::value,
|
---|
| 32 | // FIXME currently we need same types, but in the future the next rule should be the one
|
---|
| 33 | //_Vectorizable = bool(int(MatrixType::Flags)&PacketAccessBit) && ((!_PacketOnDiag) || (_SameTypes && bool(int(DiagonalType::DiagonalVectorType::Flags)&PacketAccessBit))),
|
---|
| 34 | _Vectorizable = bool(int(MatrixType::Flags)&PacketAccessBit) && _SameTypes && (_ScalarAccessOnDiag || (bool(int(DiagonalType::DiagonalVectorType::Flags)&PacketAccessBit))),
|
---|
| 35 | _LinearAccessMask = (RowsAtCompileTime==1 || ColsAtCompileTime==1) ? LinearAccessBit : 0,
|
---|
| 36 |
|
---|
| 37 | Flags = ((HereditaryBits|_LinearAccessMask|AlignedBit) & (unsigned int)(MatrixType::Flags)) | (_Vectorizable ? PacketAccessBit : 0),//(int(MatrixType::Flags)&int(DiagonalType::DiagonalVectorType::Flags)&AlignedBit),
|
---|
| 38 | Cost0 = EIGEN_ADD_COST(NumTraits<Scalar>::MulCost, MatrixType::CoeffReadCost),
|
---|
| 39 | CoeffReadCost = EIGEN_ADD_COST(Cost0,DiagonalType::DiagonalVectorType::CoeffReadCost)
|
---|
| 40 | };
|
---|
| 41 | };
|
---|
| 42 | }
|
---|
| 43 |
|
---|
| 44 | template<typename MatrixType, typename DiagonalType, int ProductOrder>
|
---|
| 45 | class DiagonalProduct : internal::no_assignment_operator,
|
---|
| 46 | public MatrixBase<DiagonalProduct<MatrixType, DiagonalType, ProductOrder> >
|
---|
| 47 | {
|
---|
| 48 | public:
|
---|
| 49 |
|
---|
| 50 | typedef MatrixBase<DiagonalProduct> Base;
|
---|
| 51 | EIGEN_DENSE_PUBLIC_INTERFACE(DiagonalProduct)
|
---|
| 52 |
|
---|
| 53 | inline DiagonalProduct(const MatrixType& matrix, const DiagonalType& diagonal)
|
---|
| 54 | : m_matrix(matrix), m_diagonal(diagonal)
|
---|
| 55 | {
|
---|
| 56 | eigen_assert(diagonal.diagonal().size() == (ProductOrder == OnTheLeft ? matrix.rows() : matrix.cols()));
|
---|
| 57 | }
|
---|
| 58 |
|
---|
| 59 | EIGEN_STRONG_INLINE Index rows() const { return m_matrix.rows(); }
|
---|
| 60 | EIGEN_STRONG_INLINE Index cols() const { return m_matrix.cols(); }
|
---|
| 61 |
|
---|
| 62 | EIGEN_STRONG_INLINE const Scalar coeff(Index row, Index col) const
|
---|
| 63 | {
|
---|
| 64 | return m_diagonal.diagonal().coeff(ProductOrder == OnTheLeft ? row : col) * m_matrix.coeff(row, col);
|
---|
| 65 | }
|
---|
| 66 |
|
---|
| 67 | EIGEN_STRONG_INLINE const Scalar coeff(Index idx) const
|
---|
| 68 | {
|
---|
| 69 | enum {
|
---|
| 70 | StorageOrder = int(MatrixType::Flags) & RowMajorBit ? RowMajor : ColMajor
|
---|
| 71 | };
|
---|
| 72 | return coeff(int(StorageOrder)==ColMajor?idx:0,int(StorageOrder)==ColMajor?0:idx);
|
---|
| 73 | }
|
---|
| 74 |
|
---|
| 75 | template<int LoadMode>
|
---|
| 76 | EIGEN_STRONG_INLINE PacketScalar packet(Index row, Index col) const
|
---|
| 77 | {
|
---|
| 78 | enum {
|
---|
| 79 | StorageOrder = Flags & RowMajorBit ? RowMajor : ColMajor
|
---|
| 80 | };
|
---|
| 81 | const Index indexInDiagonalVector = ProductOrder == OnTheLeft ? row : col;
|
---|
| 82 | return packet_impl<LoadMode>(row,col,indexInDiagonalVector,typename internal::conditional<
|
---|
| 83 | ((int(StorageOrder) == RowMajor && int(ProductOrder) == OnTheLeft)
|
---|
| 84 | ||(int(StorageOrder) == ColMajor && int(ProductOrder) == OnTheRight)), internal::true_type, internal::false_type>::type());
|
---|
| 85 | }
|
---|
| 86 |
|
---|
| 87 | template<int LoadMode>
|
---|
| 88 | EIGEN_STRONG_INLINE PacketScalar packet(Index idx) const
|
---|
| 89 | {
|
---|
| 90 | enum {
|
---|
| 91 | StorageOrder = int(MatrixType::Flags) & RowMajorBit ? RowMajor : ColMajor
|
---|
| 92 | };
|
---|
| 93 | return packet<LoadMode>(int(StorageOrder)==ColMajor?idx:0,int(StorageOrder)==ColMajor?0:idx);
|
---|
| 94 | }
|
---|
| 95 |
|
---|
| 96 | protected:
|
---|
| 97 | template<int LoadMode>
|
---|
| 98 | EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::true_type) const
|
---|
| 99 | {
|
---|
| 100 | return internal::pmul(m_matrix.template packet<LoadMode>(row, col),
|
---|
| 101 | internal::pset1<PacketScalar>(m_diagonal.diagonal().coeff(id)));
|
---|
| 102 | }
|
---|
| 103 |
|
---|
| 104 | template<int LoadMode>
|
---|
| 105 | EIGEN_STRONG_INLINE PacketScalar packet_impl(Index row, Index col, Index id, internal::false_type) const
|
---|
| 106 | {
|
---|
| 107 | enum {
|
---|
| 108 | InnerSize = (MatrixType::Flags & RowMajorBit) ? MatrixType::ColsAtCompileTime : MatrixType::RowsAtCompileTime,
|
---|
| 109 | DiagonalVectorPacketLoadMode = (LoadMode == Aligned && (((InnerSize%16) == 0) || (int(DiagonalType::DiagonalVectorType::Flags)&AlignedBit)==AlignedBit) ? Aligned : Unaligned)
|
---|
| 110 | };
|
---|
| 111 | return internal::pmul(m_matrix.template packet<LoadMode>(row, col),
|
---|
| 112 | m_diagonal.diagonal().template packet<DiagonalVectorPacketLoadMode>(id));
|
---|
| 113 | }
|
---|
| 114 |
|
---|
| 115 | typename MatrixType::Nested m_matrix;
|
---|
| 116 | typename DiagonalType::Nested m_diagonal;
|
---|
| 117 | };
|
---|
| 118 |
|
---|
| 119 | /** \returns the diagonal matrix product of \c *this by the diagonal matrix \a diagonal.
|
---|
| 120 | */
|
---|
| 121 | template<typename Derived>
|
---|
| 122 | template<typename DiagonalDerived>
|
---|
| 123 | inline const DiagonalProduct<Derived, DiagonalDerived, OnTheRight>
|
---|
| 124 | MatrixBase<Derived>::operator*(const DiagonalBase<DiagonalDerived> &a_diagonal) const
|
---|
| 125 | {
|
---|
| 126 | return DiagonalProduct<Derived, DiagonalDerived, OnTheRight>(derived(), a_diagonal.derived());
|
---|
| 127 | }
|
---|
| 128 |
|
---|
| 129 | } // end namespace Eigen
|
---|
| 130 |
|
---|
| 131 | #endif // EIGEN_DIAGONALPRODUCT_H
|
---|