[136] | 1 | // This file is part of Eigen, a lightweight C++ template library
|
---|
| 2 | // for linear algebra.
|
---|
| 3 | //
|
---|
| 4 | // Copyright (C) 2008-2011 Gael Guennebaud <gael.guennebaud@inria.fr>
|
---|
| 5 | //
|
---|
| 6 | // This Source Code Form is subject to the terms of the Mozilla
|
---|
| 7 | // Public License v. 2.0. If a copy of the MPL was not distributed
|
---|
| 8 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
---|
| 9 |
|
---|
| 10 | #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
|
---|
| 11 | #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
|
---|
| 12 |
|
---|
| 13 | namespace Eigen {
|
---|
| 14 |
|
---|
| 15 | namespace internal {
|
---|
| 16 |
|
---|
| 17 |
|
---|
| 18 | // perform a pseudo in-place sparse * sparse product assuming all matrices are col major
|
---|
| 19 | template<typename Lhs, typename Rhs, typename ResultType>
|
---|
| 20 | static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, const typename ResultType::RealScalar& tolerance)
|
---|
| 21 | {
|
---|
| 22 | // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res);
|
---|
| 23 |
|
---|
| 24 | typedef typename remove_all<Lhs>::type::Scalar Scalar;
|
---|
| 25 | typedef typename remove_all<Lhs>::type::Index Index;
|
---|
| 26 |
|
---|
| 27 | // make sure to call innerSize/outerSize since we fake the storage order.
|
---|
| 28 | Index rows = lhs.innerSize();
|
---|
| 29 | Index cols = rhs.outerSize();
|
---|
| 30 | //Index size = lhs.outerSize();
|
---|
| 31 | eigen_assert(lhs.outerSize() == rhs.innerSize());
|
---|
| 32 |
|
---|
| 33 | // allocate a temporary buffer
|
---|
| 34 | AmbiVector<Scalar,Index> tempVector(rows);
|
---|
| 35 |
|
---|
| 36 | // estimate the number of non zero entries
|
---|
| 37 | // given a rhs column containing Y non zeros, we assume that the respective Y columns
|
---|
| 38 | // of the lhs differs in average of one non zeros, thus the number of non zeros for
|
---|
| 39 | // the product of a rhs column with the lhs is X+Y where X is the average number of non zero
|
---|
| 40 | // per column of the lhs.
|
---|
| 41 | // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs)
|
---|
| 42 | Index estimated_nnz_prod = lhs.nonZeros() + rhs.nonZeros();
|
---|
| 43 |
|
---|
| 44 | // mimics a resizeByInnerOuter:
|
---|
| 45 | if(ResultType::IsRowMajor)
|
---|
| 46 | res.resize(cols, rows);
|
---|
| 47 | else
|
---|
| 48 | res.resize(rows, cols);
|
---|
| 49 |
|
---|
| 50 | res.reserve(estimated_nnz_prod);
|
---|
| 51 | double ratioColRes = double(estimated_nnz_prod)/(double(lhs.rows())*double(rhs.cols()));
|
---|
| 52 | for (Index j=0; j<cols; ++j)
|
---|
| 53 | {
|
---|
| 54 | // FIXME:
|
---|
| 55 | //double ratioColRes = (double(rhs.innerVector(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows());
|
---|
| 56 | // let's do a more accurate determination of the nnz ratio for the current column j of res
|
---|
| 57 | tempVector.init(ratioColRes);
|
---|
| 58 | tempVector.setZero();
|
---|
| 59 | for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
|
---|
| 60 | {
|
---|
| 61 | // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
|
---|
| 62 | tempVector.restart();
|
---|
| 63 | Scalar x = rhsIt.value();
|
---|
| 64 | for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt)
|
---|
| 65 | {
|
---|
| 66 | tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
|
---|
| 67 | }
|
---|
| 68 | }
|
---|
| 69 | res.startVec(j);
|
---|
| 70 | for (typename AmbiVector<Scalar,Index>::Iterator it(tempVector,tolerance); it; ++it)
|
---|
| 71 | res.insertBackByOuterInner(j,it.index()) = it.value();
|
---|
| 72 | }
|
---|
| 73 | res.finalize();
|
---|
| 74 | }
|
---|
| 75 |
|
---|
| 76 | template<typename Lhs, typename Rhs, typename ResultType,
|
---|
| 77 | int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
|
---|
| 78 | int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
|
---|
| 79 | int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
|
---|
| 80 | struct sparse_sparse_product_with_pruning_selector;
|
---|
| 81 |
|
---|
| 82 | template<typename Lhs, typename Rhs, typename ResultType>
|
---|
| 83 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
|
---|
| 84 | {
|
---|
| 85 | typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
|
---|
| 86 | typedef typename ResultType::RealScalar RealScalar;
|
---|
| 87 |
|
---|
| 88 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
|
---|
| 89 | {
|
---|
| 90 | typename remove_all<ResultType>::type _res(res.rows(), res.cols());
|
---|
| 91 | internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance);
|
---|
| 92 | res.swap(_res);
|
---|
| 93 | }
|
---|
| 94 | };
|
---|
| 95 |
|
---|
| 96 | template<typename Lhs, typename Rhs, typename ResultType>
|
---|
| 97 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
|
---|
| 98 | {
|
---|
| 99 | typedef typename ResultType::RealScalar RealScalar;
|
---|
| 100 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
|
---|
| 101 | {
|
---|
| 102 | // we need a col-major matrix to hold the result
|
---|
| 103 | typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::Index> SparseTemporaryType;
|
---|
| 104 | SparseTemporaryType _res(res.rows(), res.cols());
|
---|
| 105 | internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance);
|
---|
| 106 | res = _res;
|
---|
| 107 | }
|
---|
| 108 | };
|
---|
| 109 |
|
---|
| 110 | template<typename Lhs, typename Rhs, typename ResultType>
|
---|
| 111 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
|
---|
| 112 | {
|
---|
| 113 | typedef typename ResultType::RealScalar RealScalar;
|
---|
| 114 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
|
---|
| 115 | {
|
---|
| 116 | // let's transpose the product to get a column x column product
|
---|
| 117 | typename remove_all<ResultType>::type _res(res.rows(), res.cols());
|
---|
| 118 | internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance);
|
---|
| 119 | res.swap(_res);
|
---|
| 120 | }
|
---|
| 121 | };
|
---|
| 122 |
|
---|
| 123 | template<typename Lhs, typename Rhs, typename ResultType>
|
---|
| 124 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
|
---|
| 125 | {
|
---|
| 126 | typedef typename ResultType::RealScalar RealScalar;
|
---|
| 127 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
|
---|
| 128 | {
|
---|
| 129 | typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixLhs;
|
---|
| 130 | typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixRhs;
|
---|
| 131 | ColMajorMatrixLhs colLhs(lhs);
|
---|
| 132 | ColMajorMatrixRhs colRhs(rhs);
|
---|
| 133 | internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,ColMajorMatrixRhs,ResultType>(colLhs, colRhs, res, tolerance);
|
---|
| 134 |
|
---|
| 135 | // let's transpose the product to get a column x column product
|
---|
| 136 | // typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
|
---|
| 137 | // SparseTemporaryType _res(res.cols(), res.rows());
|
---|
| 138 | // sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res);
|
---|
| 139 | // res = _res.transpose();
|
---|
| 140 | }
|
---|
| 141 | };
|
---|
| 142 |
|
---|
| 143 | // NOTE: the two other cases (col row *) must never occur since they are caught
|
---|
| 144 | // by ProductReturnType which transforms it to (col col *) by evaluating rhs.
|
---|
| 145 |
|
---|
| 146 | } // end namespace internal
|
---|
| 147 |
|
---|
| 148 | } // end namespace Eigen
|
---|
| 149 |
|
---|
| 150 | #endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
|
---|