MPQC 3.0.0-alpha
Loading...
Searching...
No Matches
base.hpp
1#ifndef MPQC_MATH_TENSOR_BASE_HPP
2#define MPQC_MATH_TENSOR_BASE_HPP
3
4#include <algorithm>
5#include <assert.h>
6
7#include <boost/array.hpp>
8#include <boost/tuple/tuple.hpp>
9
10#include <boost/preprocessor/repetition/enum_params.hpp>
11#include <boost/preprocessor/repetition/enum_binary_params.hpp>
12
13#include "mpqc/range.hpp"
14#include "mpqc/range/tie.hpp"
15#include "mpqc/utility/string.hpp"
16
17#include "mpqc/math/tensor/forward.hpp"
18#include "mpqc/math/tensor/functional.hpp"
19#include "mpqc/math/tensor/exception.hpp"
20
21namespace mpqc {
22
25
31 template<typename T, size_t N, class Order>
32 class TensorBase {
33
34 public:
35 static const size_t RANK = N;
36 typedef boost::array<size_t,N> Dims;
37 typedef boost::array<size_t,N> Strides;
38
39 protected:
40 T *data_;
41 Dims dims_;
42 Strides strides_;
43
44 public:
45
46 TensorBase(T *data,
47 const size_t *dims,
48 const size_t *ld = NULL)
49 {
50 this->data_ = data;
51 std::copy(dims, dims+N, this->dims_.begin());
52 strides_ = Order::template strides<N>(ld ? ld : dims);
53 }
54
55 size_t size() const {
56 size_t size = 1;
57 for (int i = 0; i < N; ++i)
58 size *= this->dims_[i];
59 return size;
60 }
61
62 const Dims& dims() const {
63 return dims_;
64 }
65
66 public:
67
68 template<typename U, class O>
69 void operator=(const TensorBase<const U,N,O> &u) {
70 detail::Tensor::apply(detail::Tensor::assign(), *this, u);
71 }
72
73 void operator=(const TensorBase& o) {
74 this->operator=<T>(o);
75 }
76
77 public:
78
79 // generate index operator of arity N
80 // CV may be empty or const
81#define MPQC_MATH_TENSOR_INDEX_OPERATOR(Z, N, CV) \
82 template< BOOST_PP_ENUM_PARAMS(N, class T) > \
83 typename boost::enable_if \
84 < detail::Tensor::is_integral_tuple \
85 < boost::tuple<BOOST_PP_ENUM_PARAMS(N,T)> >, \
86 CV T& >::type \
87 operator()(BOOST_PP_ENUM_BINARY_PARAMS(N, const T, &i)) CV { \
88 using detail::Tensor::tie; \
89 return this->operator()(tie(boost::tie(BOOST_PP_ENUM_PARAMS(N,i)))); \
90 } \
91
92 // generate range operator of arity N
93 // CV may be empty or const
94#define MPQC_MATH_TENSOR_RANGE_OPERATOR(Z, N, CV) \
95 template< BOOST_PP_ENUM_PARAMS(N, class T) > \
96 typename boost::disable_if \
97 < detail::Tensor::is_integral_tuple \
98 < boost::tuple<BOOST_PP_ENUM_PARAMS(N,T)> >, \
99 TensorBase<CV T, N, Order> >::type \
100 operator()(BOOST_PP_ENUM_BINARY_PARAMS(N, const T, &i)) CV { \
101 using detail::Tensor::tie; \
102 return this->operator()(tie(boost::tie(BOOST_PP_ENUM_PARAMS(N,i)))); \
103 } \
104
105 BOOST_PP_REPEAT_FROM_TO(1, 5, MPQC_MATH_TENSOR_INDEX_OPERATOR, )
106 BOOST_PP_REPEAT_FROM_TO(1, 5, MPQC_MATH_TENSOR_INDEX_OPERATOR, const)
107 BOOST_PP_REPEAT_FROM_TO(1, 5, MPQC_MATH_TENSOR_RANGE_OPERATOR, )
108 BOOST_PP_REPEAT_FROM_TO(1, 5, MPQC_MATH_TENSOR_RANGE_OPERATOR, const)
109
110 public:
111
113 template<class Seq>
114 T& operator()(const detail::Tensor::integral_tie<Seq> &idx) {
115 return this->data_[this->index(idx)];
116 }
117
119 template<class Seq>
121 return this->data_[this->index(idx)];
122 }
123
124 template<class Seq>
127 return block< TensorBase<T, N, Order> >(*this, tie);
128 }
129
130 template<class Seq>
131 TensorBase<const T, N, Order>
132 operator()(const detail::Tensor::range_tie<Seq> &tie) const {
133 return block< TensorBase<const T, N, Order> >(*this, tie);
134 }
135
136 private:
137
138 template<class Seq, int K>
139 void check_index(const detail::Tensor::integral_tie<Seq> &tie,
140 boost::mpl::int_<K>) const {
141 using boost::fusion::at_c;
142 if ((at_c<K>(tie) < 0) || (at_c<K>(tie) > this->dims_[K])) {
143 throw TensorIndexException(K, at_c<K>(tie), 0, this->dims_[K]);
144 }
145 check_index(tie, boost::mpl::int_<K+1>());
146 }
147
148 template<class Seq>
149 void check_index(const detail::Tensor::integral_tie<Seq> &tie,
150 boost::mpl::int_<N>) const {}
151
152 template<class Seq>
153 ptrdiff_t index(const detail::Tensor::integral_tie<Seq> &idx) const {
154 static_assert(boost::fusion::result_of::size<Seq>::value == N,
155 "Invalid TensorBase::operator() arity");
156// #ifndef NDEBUG
157// check_index(idx, boost::mpl::int_<0>());
158// #endif
159 ptrdiff_t index = Order::index((const Seq&)idx, this->strides_);
160 //std::cout << idx << ":" << index << "->" << data_+index << std::endl;
161 return index;
162 }
163
164 private:
165
166 template<class U, class This, class Tie>
167 static U block(This &t, detail::Tensor::range_tie<Tie> tie) {
168 static_assert(boost::tuples::length<Tie>::value == N,
169 "Invalid TensorBase::operator() arity");
170 boost::array<range,N> r = range::tie<Tie>(tie);
171 boost::array<ptrdiff_t,N> begin;
172 Dims dims;
173 for (int i = 0; i < N; ++i) {
174 begin[i] = *r[i].begin();
175 dims[i] = r[i].size();
176#ifndef NDEBUG
177 if ((*r[i].begin() < 0) || (*r[i].end() > t.dims_[i])) {
178 throw TensorRangeException(i, r[i], 0, t.dims_[i]);
179 }
180#endif
181 }
182 ptrdiff_t offset = Order::index(begin, t.strides_);
183 //std::cout << offset << std::endl;
184 return U(t.data_+offset, dims, t.strides_);
185 }
186
187 private:
188
189 friend class TensorBase< typename boost::remove_const<T>::type, N, Order>;
190 TensorBase(T *data, const Dims &dims, const Strides &strides)
191 : data_(data), dims_(dims), strides_(strides)
192 {
193 }
194
195 };
196
198
199}
200
201
202#endif /* MPQC_MATH_TENSOR_BASE_HPP */
Tensor base class.
Definition forward.hpp:16
T & operator()(const detail::Tensor::integral_tie< Seq > &idx)
element-access operator
Definition base.hpp:114
const T & operator()(const detail::Tensor::integral_tie< Seq > &idx) const
element-access operator
Definition base.hpp:120
Contains new MPQC code since version 3.
Definition integralenginepool.hpp:37
Tensor reference class.
Definition tensor.hpp:13
index tie wrapper
Definition forward.hpp:45
range tie wrapper
Definition forward.hpp:51

Generated at Wed Sep 25 2024 02:45:30 for MPQC 3.0.0-alpha using the documentation package Doxygen 1.12.0.