LIBINT 2.7.2
tensor.h
1/*
2 * Copyright (C) 2004-2021 Edward F. Valeev
3 *
4 * This file is part of Libint.
5 *
6 * Libint is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU Lesser General Public License as published by
8 * the Free Software Foundation, either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * Libint is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU Lesser General Public License for more details.
15 *
16 * You should have received a copy of the GNU Lesser General Public License
17 * along with Libint. If not, see <http://www.gnu.org/licenses/>.
18 *
19 */
20
21
22#ifndef _libint2_include_tensor_h_
23#define _libint2_include_tensor_h_
24
25#include <libint2/util/cxxstd.h>
26#if LIBINT2_CPLUSPLUS_STD < 2011
27# error "The simple Libint API requires C++11 support"
28#endif
29
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <initializer_list>
#include <numeric>
#include <vector>
34
35namespace libint2 {
template <typename T>
/// Dense, owning tensor of elements of type @p T with runtime rank and
/// extents. Elements are stored contiguously (in a std::vector) in row-major
/// order, i.e. the last index varies fastest.
struct Tensor {
 public:
  Tensor() = default;
  Tensor(const Tensor&) = default;
  Tensor(Tensor&&) = default;
  // Declared explicitly: the user-declared move constructor above would
  // otherwise delete the implicit copy assignment and suppress the implicit
  // move assignment, making `a = b;` ill-formed.
  Tensor& operator=(const Tensor&) = default;
  Tensor& operator=(Tensor&&) = default;
  ~Tensor() = default;

  /// Constructs a tensor with extents @p dims (each converted to std::size_t).
  /// Elements are value-initialized. With no extents the tensor holds no
  /// elements.
  template <class... Dims>
  Tensor(Dims... dims) : dims_{static_cast<std::size_t>(dims)...} {
    const std::size_t rank = dims_.size();
    strides_.resize(rank);
    // Row-major strides: strides_[i] = product of dims_[i+1 .. rank-1].
    // An explicit right-to-left loop is used instead of std::transform with a
    // stateful functor, because transform does not guarantee in-order
    // application of its operation.
    std::size_t stride = 1;
    for (std::size_t i = rank; i-- > 0;) {
      strides_[i] = stride;
      stride *= dims_[i];
    }
    // After the loop, `stride` is the product of all extents.
    data_.resize(rank == 0 ? 0 : stride);
  }

  /// @return pointer to the first element; not dereferenceable if the tensor
  ///         is empty
  T* data() { return data_.data(); }
  const T* data() const { return data_.data(); }

  /// @param index leading indices (at most rank of them)
  /// @return pointer to the element addressed by @p index; for a partial
  ///         index, the start of the contiguous trailing sub-block
  template <class... MultiIndex>
  T* data(MultiIndex... index) {
    assert(sizeof...(MultiIndex) <= dims_.size());
    return data_.data() + ordinal(index...);
  }
  template <class... MultiIndex>
  const T* data(MultiIndex... index) const {
    assert(sizeof...(MultiIndex) <= dims_.size());
    return data_.data() + ordinal(index...);
  }

  /// Element access; exactly one index per dimension is required.
  template <class... MultiIndex>
  T& operator()(MultiIndex... index) {
    assert(sizeof...(MultiIndex) == dims_.size());
    return data_[ordinal(index...)];
  }
  template <class... MultiIndex>
  const T& operator()(MultiIndex... index) const {
    assert(sizeof...(MultiIndex) == dims_.size());
    return data_[ordinal(index...)];
  }

 private:
  /// Row-major offset of the element addressed by the leading indices.
  /// Indices may be of mixed integer types; the dot product with the strides
  /// accumulates in std::size_t (not int) so large offsets are not truncated.
  template <class... MultiIndex>
  std::size_t ordinal(MultiIndex... index) const {
    const std::initializer_list<std::size_t> index_list = {
        static_cast<std::size_t>(index)...};
    return std::inner_product(index_list.begin(), index_list.end(),
                              strides_.begin(), std::size_t{0});
  }

  std::vector<std::size_t> dims_;
  std::vector<std::size_t> strides_;
  std::vector<T> data_;

};  // class definition
86} // namespace libint2
87
88#endif /* _libint2_include_tensor_h_*/
Default definitions for various parameters assumed by Libint.
Definition: algebra.cc:24
Definition: tensor.h:37