Please, help us to better know about our user community by answering the following short survey: https://forms.gle/wpyrxWi18ox9Z5ae9
Eigen  3.4.0
SparseDenseProduct.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2008-2015 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_SPARSEDENSEPRODUCT_H
11 #define EIGEN_SPARSEDENSEPRODUCT_H
12 
13 namespace Eigen {
14 
15 namespace internal {
16 
17 template <> struct product_promote_storage_type<Sparse,Dense, OuterProduct> { typedef Sparse ret; };
18 template <> struct product_promote_storage_type<Dense,Sparse, OuterProduct> { typedef Sparse ret; };
19 
20 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,
21  typename AlphaType,
22  int LhsStorageOrder = ((SparseLhsType::Flags&RowMajorBit)==RowMajorBit) ? RowMajor : ColMajor,
23  bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1>
24 struct sparse_time_dense_product_impl;
25 
26 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
27 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, RowMajor, true>
28 {
29  typedef typename internal::remove_all<SparseLhsType>::type Lhs;
30  typedef typename internal::remove_all<DenseRhsType>::type Rhs;
31  typedef typename internal::remove_all<DenseResType>::type Res;
32  typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;
33  typedef evaluator<Lhs> LhsEval;
34  static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
35  {
36  LhsEval lhsEval(lhs);
37 
38  Index n = lhs.outerSize();
39 #ifdef EIGEN_HAS_OPENMP
40  Eigen::initParallel();
41  Index threads = Eigen::nbThreads();
42 #endif
43 
44  for(Index c=0; c<rhs.cols(); ++c)
45  {
46 #ifdef EIGEN_HAS_OPENMP
47  // This 20000 threshold has been found experimentally on 2D and 3D Poisson problems.
48  // It basically represents the minimal amount of work to be done to be worth it.
49  if(threads>1 && lhsEval.nonZerosEstimate() > 20000)
50  {
51  int sched = (n+threads*4-1)/(threads*4);
52  #pragma omp parallel for schedule(dynamic,sched) num_threads(threads)
53  for(Index i=0; i<n; ++i)
54  processRow(lhsEval,rhs,res,alpha,i,c);
55  }
56  else
57 #endif
58  {
59  for(Index i=0; i<n; ++i)
60  processRow(lhsEval,rhs,res,alpha,i,c);
61  }
62  }
63  }
64 
65  static void processRow(const LhsEval& lhsEval, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha, Index i, Index col)
66  {
67  typename Res::Scalar tmp(0);
68  for(LhsInnerIterator it(lhsEval,i); it ;++it)
69  tmp += it.value() * rhs.coeff(it.index(),col);
70  res.coeffRef(i,col) += alpha * tmp;
71  }
72 
73 };
74 
// FIXME: what is the purpose of the following specialization? Is it meant for the BlockedSparse format?
// -> It is disabled for now because it conflicts with the generic scalar*matrix and matrix*scalar operators.
77 // template<typename T1, typename T2/*, int _Options, typename _StrideType*/>
78 // struct ScalarBinaryOpTraits<T1, Ref<T2/*, _Options, _StrideType*/> >
79 // {
80 // enum {
81 // Defined = 1
82 // };
83 // typedef typename CwiseUnaryOp<scalar_multiple2_op<T1, typename T2::Scalar>, T2>::PlainObject ReturnType;
84 // };
85 
86 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType, typename AlphaType>
87 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType, ColMajor, true>
88 {
89  typedef typename internal::remove_all<SparseLhsType>::type Lhs;
90  typedef typename internal::remove_all<DenseRhsType>::type Rhs;
91  typedef typename internal::remove_all<DenseResType>::type Res;
92  typedef evaluator<Lhs> LhsEval;
93  typedef typename LhsEval::InnerIterator LhsInnerIterator;
94  static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
95  {
96  LhsEval lhsEval(lhs);
97  for(Index c=0; c<rhs.cols(); ++c)
98  {
99  for(Index j=0; j<lhs.outerSize(); ++j)
100  {
101 // typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c);
102  typename ScalarBinaryOpTraits<AlphaType, typename Rhs::Scalar>::ReturnType rhs_j(alpha * rhs.coeff(j,c));
103  for(LhsInnerIterator it(lhsEval,j); it ;++it)
104  res.coeffRef(it.index(),c) += it.value() * rhs_j;
105  }
106  }
107  }
108 };
109 
110 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
111 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, RowMajor, false>
112 {
113  typedef typename internal::remove_all<SparseLhsType>::type Lhs;
114  typedef typename internal::remove_all<DenseRhsType>::type Rhs;
115  typedef typename internal::remove_all<DenseResType>::type Res;
116  typedef evaluator<Lhs> LhsEval;
117  typedef typename LhsEval::InnerIterator LhsInnerIterator;
118  static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
119  {
120  Index n = lhs.rows();
121  LhsEval lhsEval(lhs);
122 
123 #ifdef EIGEN_HAS_OPENMP
124  Eigen::initParallel();
125  Index threads = Eigen::nbThreads();
126  // This 20000 threshold has been found experimentally on 2D and 3D Poisson problems.
127  // It basically represents the minimal amount of work to be done to be worth it.
128  if(threads>1 && lhsEval.nonZerosEstimate()*rhs.cols() > 20000)
129  {
130  int sched = (n+threads*4-1)/(threads*4);
131  #pragma omp parallel for schedule(dynamic,sched) num_threads(threads)
132  for(Index i=0; i<n; ++i)
133  processRow(lhsEval,rhs,res,alpha,i);
134  }
135  else
136 #endif
137  {
138  for(Index i=0; i<n; ++i)
139  processRow(lhsEval, rhs, res, alpha, i);
140  }
141  }
142 
143  static void processRow(const LhsEval& lhsEval, const DenseRhsType& rhs, Res& res, const typename Res::Scalar& alpha, Index i)
144  {
145  typename Res::RowXpr res_i(res.row(i));
146  for(LhsInnerIterator it(lhsEval,i); it ;++it)
147  res_i += (alpha*it.value()) * rhs.row(it.index());
148  }
149 };
150 
151 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
152 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar, ColMajor, false>
153 {
154  typedef typename internal::remove_all<SparseLhsType>::type Lhs;
155  typedef typename internal::remove_all<DenseRhsType>::type Rhs;
156  typedef typename internal::remove_all<DenseResType>::type Res;
157  typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;
158  static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const typename Res::Scalar& alpha)
159  {
160  evaluator<Lhs> lhsEval(lhs);
161  for(Index j=0; j<lhs.outerSize(); ++j)
162  {
163  typename Rhs::ConstRowXpr rhs_j(rhs.row(j));
164  for(LhsInnerIterator it(lhsEval,j); it ;++it)
165  res.row(it.index()) += (alpha*it.value()) * rhs_j;
166  }
167  }
168 };
169 
170 template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType>
171 inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
172 {
173  sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType>::run(lhs, rhs, res, alpha);
174 }
175 
176 } // end namespace internal
177 
178 namespace internal {
179 
180 template<typename Lhs, typename Rhs, int ProductType>
181 struct generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType>
182  : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,SparseShape,DenseShape,ProductType> >
183 {
184  typedef typename Product<Lhs,Rhs>::Scalar Scalar;
185 
186  template<typename Dest>
187  static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
188  {
189  typedef typename nested_eval<Lhs,((Rhs::Flags&RowMajorBit)==0) ? 1 : Rhs::ColsAtCompileTime>::type LhsNested;
190  typedef typename nested_eval<Rhs,((Lhs::Flags&RowMajorBit)==0) ? 1 : Dynamic>::type RhsNested;
191  LhsNested lhsNested(lhs);
192  RhsNested rhsNested(rhs);
193  internal::sparse_time_dense_product(lhsNested, rhsNested, dst, alpha);
194  }
195 };
196 
197 template<typename Lhs, typename Rhs, int ProductType>
198 struct generic_product_impl<Lhs, Rhs, SparseTriangularShape, DenseShape, ProductType>
199  : generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType>
200 {};
201 
202 template<typename Lhs, typename Rhs, int ProductType>
203 struct generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType>
204  : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,SparseShape,ProductType> >
205 {
206  typedef typename Product<Lhs,Rhs>::Scalar Scalar;
207 
208  template<typename Dst>
209  static void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
210  {
211  typedef typename nested_eval<Lhs,((Rhs::Flags&RowMajorBit)==0) ? Dynamic : 1>::type LhsNested;
212  typedef typename nested_eval<Rhs,((Lhs::Flags&RowMajorBit)==RowMajorBit) ? 1 : Lhs::RowsAtCompileTime>::type RhsNested;
213  LhsNested lhsNested(lhs);
214  RhsNested rhsNested(rhs);
215 
216  // transpose everything
217  Transpose<Dst> dstT(dst);
218  internal::sparse_time_dense_product(rhsNested.transpose(), lhsNested.transpose(), dstT, alpha);
219  }
220 };
221 
222 template<typename Lhs, typename Rhs, int ProductType>
223 struct generic_product_impl<Lhs, Rhs, DenseShape, SparseTriangularShape, ProductType>
224  : generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType>
225 {};
226 
/** \internal
  * Shared evaluator for outer products in which one operand is sparse and the other dense.
  *
  * One operand (Lhs1) is iterated; the other (ActualRhs) supplies one scaling factor
  * per outer index. When \a NeedToTranspose is true the roles of LhsT and RhsT are
  * swapped, the evaluator reports RowMajorBit in Flags, and row()/col() of the iterator
  * are swapped accordingly.
  *
  * Only inner iteration is provided (no coeff() member is defined here), and
  * CoeffReadCost is HugeCost.
  */
template<typename LhsT, typename RhsT, bool NeedToTranspose>
struct sparse_dense_outer_product_evaluator
{
protected:
  // Lhs1: the operand that is iterated; ActualRhs: the operand providing the per-outer factor.
  typedef typename conditional<NeedToTranspose,RhsT,LhsT>::type Lhs1;
  typedef typename conditional<NeedToTranspose,LhsT,RhsT>::type ActualRhs;
  typedef Product<LhsT,RhsT,DefaultProduct> ProdXprType;

  // if the actual left-hand side is a dense vector,
  // then build a sparse-view so that we can seamlessly iterate over it.
  typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value,
            Lhs1, SparseView<Lhs1> >::type ActualLhs;
  // How the iterated operand is stored in m_lhs: by const reference when already sparse,
  // otherwise as a SparseView held by value.
  typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value,
            Lhs1 const&, SparseView<Lhs1> >::type LhsArg;

  typedef evaluator<ActualLhs> LhsEval;
  typedef evaluator<ActualRhs> RhsEval;
  typedef typename evaluator<ActualLhs>::InnerIterator LhsIterator;
  typedef typename ProdXprType::Scalar Scalar;

public:
  enum {
    Flags = NeedToTranspose ? RowMajorBit : 0,
    CoeffReadCost = HugeCost
  };

  /** Iterates over the stored entries of the iterated operand, each value multiplied
    * by the factor taken from the other operand at index \a outer. */
  class InnerIterator : public LhsIterator
  {
  public:
    // The iterated operand is always traversed at its outer index 0.
    // NOTE(review): this assumes it is a vector stored as a single inner vector.
    InnerIterator(const sparse_dense_outer_product_evaluator &xprEval, Index outer)
      : LhsIterator(xprEval.m_lhsXprImpl, 0),
        m_outer(outer),
        m_empty(false),
        // Tag-dispatched on the factor operand's storage kind (Dense or Sparse).
        // The Sparse overload may set m_empty, which is why m_empty is declared,
        // and therefore initialized, before m_factor.
        m_factor(get(xprEval.m_rhsXprImpl, outer, typename internal::traits<ActualRhs>::StorageKind() ))
    {}

    EIGEN_STRONG_INLINE Index outer() const { return m_outer; }
    EIGEN_STRONG_INLINE Index row() const { return NeedToTranspose ? m_outer : LhsIterator::index(); }
    EIGEN_STRONG_INLINE Index col() const { return NeedToTranspose ? LhsIterator::index() : m_outer; }

    EIGEN_STRONG_INLINE Scalar value() const { return LhsIterator::value() * m_factor; }
    // An iterator whose factor is structurally zero (m_empty) yields no entries at all.
    EIGEN_STRONG_INLINE operator bool() const { return LhsIterator::operator bool() && (!m_empty); }

  protected:
    // Dense factor operand: the factor is its coefficient at linear index `outer`.
    Scalar get(const RhsEval &rhs, Index outer, Dense = Dense()) const
    {
      return rhs.coeff(outer);
    }

    // Sparse factor operand: the factor is the entry at inner index 0 of outer slot `outer`.
    // If that entry is absent or an explicit zero, the iterator is marked empty.
    Scalar get(const RhsEval &rhs, Index outer, Sparse = Sparse())
    {
      typename RhsEval::InnerIterator it(rhs, outer);
      if (it && it.index()==0 && it.value()!=Scalar(0))
        return it.value();
      m_empty = true;
      return Scalar(0);
    }

    // NOTE: declaration order matters; m_empty must precede m_factor (see constructor).
    Index m_outer;
    bool m_empty;
    Scalar m_factor;
  };

  sparse_dense_outer_product_evaluator(const Lhs1 &lhs, const ActualRhs &rhs)
     : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs)
  {
    EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
  }

  // transpose case
  // Receives the operands in product order (lhs, rhs) when NeedToTranspose swapped their roles.
  sparse_dense_outer_product_evaluator(const ActualRhs &rhs, const Lhs1 &lhs)
     : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs)
  {
    EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
  }

protected:
  // m_lhs must be declared before m_lhsXprImpl, which is constructed from it.
  const LhsArg m_lhs;
  evaluator<ActualLhs> m_lhsXprImpl;
  evaluator<ActualRhs> m_rhsXprImpl;
};
308 
309 // sparse * dense outer product
310 template<typename Lhs, typename Rhs>
311 struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, SparseShape, DenseShape>
312  : sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor>
313 {
314  typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor> Base;
315 
316  typedef Product<Lhs, Rhs> XprType;
317  typedef typename XprType::PlainObject PlainObject;
318 
319  explicit product_evaluator(const XprType& xpr)
320  : Base(xpr.lhs(), xpr.rhs())
321  {}
322 
323 };
324 
325 template<typename Lhs, typename Rhs>
326 struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, DenseShape, SparseShape>
327  : sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor>
328 {
329  typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor> Base;
330 
331  typedef Product<Lhs, Rhs> XprType;
332  typedef typename XprType::PlainObject PlainObject;
333 
334  explicit product_evaluator(const XprType& xpr)
335  : Base(xpr.lhs(), xpr.rhs())
336  {}
337 
338 };
339 
340 } // end namespace internal
341 
342 } // end namespace Eigen
343 
344 #endif // EIGEN_SPARSEDENSEPRODUCT_H
Definition: Constants.h:319
const int HugeCost
Definition: Constants.h:44
Namespace containing all symbols from the Eigen library.
Definition: Core:141
const unsigned int RowMajorBit
Definition: Constants.h:66
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:74
Definition: Eigen_Colamd.h:50
Definition: Constants.h:321