44 template <
class... Dims>
45 Tensor(Dims... dims) : dims_{std::forward<Dims>(dims)...} {
46 strides_.resize(
sizeof...(dims));
50 stride() : value(1) {}
51 size_t operator()(
size_t dim) {
52 auto old_value = value;
58 std::transform(dims_.rbegin(), dims_.rend(), strides_.rbegin(), stride());
59 size_t size = strides_.size() == 0 ? 0 : strides_[0] * dims_[0];
63 T* data() {
return &data_[0]; }
64 const T* data()
const {
return static_cast<const T*
>(this->data()); }
65 template <
class... MultiIndex>
66 T* data(MultiIndex... index) {
67 assert(
sizeof...(MultiIndex) <= dims_.size());
68 auto index_list = {std::forward<MultiIndex>(index)...};
69 size_t ordinal = std::inner_product(index_list.begin(), index_list.end(),
71 return &data_[ordinal];
73 template <
class... MultiIndex>
74 const T* data(MultiIndex... index)
const {
75 return static_cast<const T*
>(this->data());
78 template <
class... MultiIndex>
79 const T& operator()(MultiIndex... index) {
80 assert(
sizeof...(MultiIndex) == dims_.size());
81 auto index_list = {std::forward<MultiIndex>(index)...};
82 size_t ordinal = std::inner_product(index_list.begin(), index_list.end(),
84 return data_[ordinal];
// Extents supplied to the constructor, one entry per dimension, in argument
// order.
88 std::vector<size_t> dims_;
// One entry per dimension: resized to sizeof...(dims) in the constructor and
// then written back-to-front by the constructor's std::transform over dims_.
// NOTE(review): presumably the element stride of each dimension, as the name
// suggests — the functor body that computes the values is not visible here.
89 std::vector<size_t> strides_;