LIBINT 2.9.0
tensor.h
1/*
2 * Copyright (C) 2004-2024 Edward F. Valeev
3 *
4 * This file is part of Libint library.
5 *
6 * Libint library is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU Lesser General Public License as published by
8 * the Free Software Foundation, either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * Libint library is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU Lesser General Public License for more details.
15 *
16 * You should have received a copy of the GNU Lesser General Public License
17 * along with Libint library. If not, see <http://www.gnu.org/licenses/>.
18 *
19 */
20
21#ifndef _libint2_include_tensor_h_
22#define _libint2_include_tensor_h_
23
24#include <libint2/util/cxxstd.h>
25#if LIBINT2_CPLUSPLUS_STD < 2011
26#error "The simple Libint API requires C++11 support"
27#endif
28
29#include <algorithm>
30#include <cassert>
31#include <cstdlib>
32#include <numeric>
33#include <vector>
34
namespace libint2 {

/// A minimal dense N-dimensional array stored in row-major order
/// (the last index varies fastest) in a contiguous std::vector<T>.
///
/// The rank and extents are fixed at construction time.
/// @tparam T element type; elements are value-initialized on construction
///           (storage is sized with std::vector::resize)
template <typename T>
struct Tensor {
 public:
  /// Constructs an empty tensor with no dimensions and no storage.
  Tensor() = default;
  Tensor(const Tensor&) = default;
  Tensor(Tensor&&) = default;
  // Declared explicitly because the user-declared move constructor above
  // would otherwise make the implicit copy assignment deleted (and suppress
  // move assignment), leaving Tensor non-assignable.
  Tensor& operator=(const Tensor&) = default;
  Tensor& operator=(Tensor&&) = default;
  ~Tensor() = default;

  /// Constructs a tensor with the given extents; all elements are
  /// value-initialized.
  /// @param dims extent of each dimension; each is converted to std::size_t
  ///        (explicit cast avoids an ill-formed narrowing in the braced init)
  /// @note A tensor constructed with no extents holds zero elements.
  template <class... Dims>
  Tensor(Dims... dims) : dims_{static_cast<std::size_t>(dims)...} {
    strides_.resize(dims_.size());
    // Row-major: stride of dimension i is the product of extents i+1..rank-1.
    std::size_t stride = 1;
    for (std::size_t i = dims_.size(); i-- > 0;) {
      strides_[i] = stride;
      stride *= dims_[i];
    }
    // After the loop `stride` equals the product of all extents.
    data_.resize(dims_.empty() ? 0 : stride);
  }

  /// @return pointer to the contiguous element storage
  ///         (not dereferenceable when the tensor holds no elements)
  T* data() { return data_.data(); }
  /// @return const pointer to the contiguous element storage
  ///         (not dereferenceable when the tensor holds no elements)
  const T* data() const { return data_.data(); }

  /// Pointer to the element (or sub-block, for a partial index) addressed by
  /// the leading indices @p index.
  /// @param index up to rank() indices; omitted trailing indices are treated
  ///        as 0, i.e. the pointer addresses the start of that sub-block
  template <class... MultiIndex>
  T* data(MultiIndex... index) {
    assert(sizeof...(MultiIndex) <= dims_.size());
    return data_.data() + ordinal(index...);
  }
  /// Const overload of data(MultiIndex...).
  template <class... MultiIndex>
  const T* data(MultiIndex... index) const {
    assert(sizeof...(MultiIndex) <= dims_.size());
    return data_.data() + ordinal(index...);
  }

  /// Element access by a full multi-index (exactly one index per dimension).
  template <class... MultiIndex>
  T& operator()(MultiIndex... index) {
    assert(sizeof...(MultiIndex) == dims_.size());
    const std::size_t ord = ordinal(index...);
    assert(ord < data_.size());
    return data_[ord];
  }
  /// Const element access by a full multi-index.
  template <class... MultiIndex>
  const T& operator()(MultiIndex... index) const {
    assert(sizeof...(MultiIndex) == dims_.size());
    const std::size_t ord = ordinal(index...);
    assert(ord < data_.size());
    return data_[ord];
  }

 private:
  /// Row-major linear offset of a (possibly partial, leading) multi-index.
  /// Indices may be of mixed integral types; each is converted to std::size_t.
  template <class... MultiIndex>
  std::size_t ordinal(MultiIndex... index) const {
    const std::initializer_list<std::size_t> idx{
        static_cast<std::size_t>(index)...};
    // Accumulate in std::size_t: an `int` init value would truncate offsets.
    return std::inner_product(idx.begin(), idx.end(), strides_.begin(),
                              std::size_t{0});
  }

  std::vector<std::size_t> dims_;     ///< extent of each dimension
  std::vector<std::size_t> strides_;  ///< row-major stride of each dimension
  std::vector<T> data_;               ///< contiguous element storage
};  // class definition
}  // namespace libint2
94
95#endif /* _libint2_include_tensor_h_*/
Default definitions for various parameters assumed by Libint.
Definition algebra.cc:24
Definition tensor.h:37