1 | // This file is part of Eigen, a lightweight C++ template library
|
---|
2 | // for linear algebra.
|
---|
3 | //
|
---|
4 | // Copyright (C) 2008-2011 Gael Guennebaud <gael.guennebaud@inria.fr>
|
---|
5 | //
|
---|
6 | // This Source Code Form is subject to the terms of the Mozilla
|
---|
7 | // Public License v. 2.0. If a copy of the MPL was not distributed
|
---|
8 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
---|
9 |
|
---|
10 | #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
|
---|
11 | #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
|
---|
12 |
|
---|
13 | namespace Eigen {
|
---|
14 |
|
---|
15 | namespace internal {
|
---|
16 |
|
---|
17 |
|
---|
18 | // perform a pseudo in-place sparse * sparse product assuming all matrices are col major
|
---|
19 | template<typename Lhs, typename Rhs, typename ResultType>
|
---|
20 | static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, const typename ResultType::RealScalar& tolerance)
|
---|
21 | {
|
---|
22 | // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res);
|
---|
23 |
|
---|
24 | typedef typename remove_all<Lhs>::type::Scalar Scalar;
|
---|
25 | typedef typename remove_all<Lhs>::type::Index Index;
|
---|
26 |
|
---|
27 | // make sure to call innerSize/outerSize since we fake the storage order.
|
---|
28 | Index rows = lhs.innerSize();
|
---|
29 | Index cols = rhs.outerSize();
|
---|
30 | //Index size = lhs.outerSize();
|
---|
31 | eigen_assert(lhs.outerSize() == rhs.innerSize());
|
---|
32 |
|
---|
33 | // allocate a temporary buffer
|
---|
34 | AmbiVector<Scalar,Index> tempVector(rows);
|
---|
35 |
|
---|
36 | // estimate the number of non zero entries
|
---|
37 | // given a rhs column containing Y non zeros, we assume that the respective Y columns
|
---|
38 | // of the lhs differs in average of one non zeros, thus the number of non zeros for
|
---|
39 | // the product of a rhs column with the lhs is X+Y where X is the average number of non zero
|
---|
40 | // per column of the lhs.
|
---|
41 | // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs)
|
---|
42 | Index estimated_nnz_prod = lhs.nonZeros() + rhs.nonZeros();
|
---|
43 |
|
---|
44 | // mimics a resizeByInnerOuter:
|
---|
45 | if(ResultType::IsRowMajor)
|
---|
46 | res.resize(cols, rows);
|
---|
47 | else
|
---|
48 | res.resize(rows, cols);
|
---|
49 |
|
---|
50 | res.reserve(estimated_nnz_prod);
|
---|
51 | double ratioColRes = double(estimated_nnz_prod)/(double(lhs.rows())*double(rhs.cols()));
|
---|
52 | for (Index j=0; j<cols; ++j)
|
---|
53 | {
|
---|
54 | // FIXME:
|
---|
55 | //double ratioColRes = (double(rhs.innerVector(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows());
|
---|
56 | // let's do a more accurate determination of the nnz ratio for the current column j of res
|
---|
57 | tempVector.init(ratioColRes);
|
---|
58 | tempVector.setZero();
|
---|
59 | for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
|
---|
60 | {
|
---|
61 | // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
|
---|
62 | tempVector.restart();
|
---|
63 | Scalar x = rhsIt.value();
|
---|
64 | for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt)
|
---|
65 | {
|
---|
66 | tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
|
---|
67 | }
|
---|
68 | }
|
---|
69 | res.startVec(j);
|
---|
70 | for (typename AmbiVector<Scalar,Index>::Iterator it(tempVector,tolerance); it; ++it)
|
---|
71 | res.insertBackByOuterInner(j,it.index()) = it.value();
|
---|
72 | }
|
---|
73 | res.finalize();
|
---|
74 | }
|
---|
75 |
|
---|
76 | template<typename Lhs, typename Rhs, typename ResultType,
|
---|
77 | int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
|
---|
78 | int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
|
---|
79 | int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
|
---|
80 | struct sparse_sparse_product_with_pruning_selector;
|
---|
81 |
|
---|
82 | template<typename Lhs, typename Rhs, typename ResultType>
|
---|
83 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
|
---|
84 | {
|
---|
85 | typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
|
---|
86 | typedef typename ResultType::RealScalar RealScalar;
|
---|
87 |
|
---|
88 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
|
---|
89 | {
|
---|
90 | typename remove_all<ResultType>::type _res(res.rows(), res.cols());
|
---|
91 | internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance);
|
---|
92 | res.swap(_res);
|
---|
93 | }
|
---|
94 | };
|
---|
95 |
|
---|
96 | template<typename Lhs, typename Rhs, typename ResultType>
|
---|
97 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
|
---|
98 | {
|
---|
99 | typedef typename ResultType::RealScalar RealScalar;
|
---|
100 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
|
---|
101 | {
|
---|
102 | // we need a col-major matrix to hold the result
|
---|
103 | typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::Index> SparseTemporaryType;
|
---|
104 | SparseTemporaryType _res(res.rows(), res.cols());
|
---|
105 | internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance);
|
---|
106 | res = _res;
|
---|
107 | }
|
---|
108 | };
|
---|
109 |
|
---|
110 | template<typename Lhs, typename Rhs, typename ResultType>
|
---|
111 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
|
---|
112 | {
|
---|
113 | typedef typename ResultType::RealScalar RealScalar;
|
---|
114 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
|
---|
115 | {
|
---|
116 | // let's transpose the product to get a column x column product
|
---|
117 | typename remove_all<ResultType>::type _res(res.rows(), res.cols());
|
---|
118 | internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance);
|
---|
119 | res.swap(_res);
|
---|
120 | }
|
---|
121 | };
|
---|
122 |
|
---|
123 | template<typename Lhs, typename Rhs, typename ResultType>
|
---|
124 | struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
|
---|
125 | {
|
---|
126 | typedef typename ResultType::RealScalar RealScalar;
|
---|
127 | static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
|
---|
128 | {
|
---|
129 | typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixLhs;
|
---|
130 | typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixRhs;
|
---|
131 | ColMajorMatrixLhs colLhs(lhs);
|
---|
132 | ColMajorMatrixRhs colRhs(rhs);
|
---|
133 | internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,ColMajorMatrixRhs,ResultType>(colLhs, colRhs, res, tolerance);
|
---|
134 |
|
---|
135 | // let's transpose the product to get a column x column product
|
---|
136 | // typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
|
---|
137 | // SparseTemporaryType _res(res.cols(), res.rows());
|
---|
138 | // sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res);
|
---|
139 | // res = _res.transpose();
|
---|
140 | }
|
---|
141 | };
|
---|
142 |
|
---|
// NOTE the 2 other cases (col row *) must never occur since they are caught
// by ProductReturnType which transforms them to (col col *) by evaluating rhs.
|
---|
145 |
|
---|
146 | } // end namespace internal
|
---|
147 |
|
---|
148 | } // end namespace Eigen
|
---|
149 |
|
---|
150 | #endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
|
---|