MPQC 3.0.0-alpha
Loading...
Searching...
No Matches
contract.h
1
2/*
3 * Copyright 2009 Sandia Corporation. Under the terms of Contract
4 * DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government
5 * retains certain rights in this software.
6 *
7 * This file is a part of the MPQC LMP2 library.
8 *
9 * The MPQC LMP2 library is free software: you can redistribute it
10 * and/or modify it under the terms of the GNU Lesser General Public
11 * License as published by the Free Software Foundation, either
12 * version 3 of the License, or (at your option) any later version.
13 *
14 * This program is distributed in the hope that it will be useful, but
15 * WITHOUT ANY WARRANTY; without even the implied warranty of
16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 * Lesser General Public License for more details.
18 *
19 * You should have received a copy of the GNU Lesser General Public
20 * License along with this program. If not, see
21 * <http://www.gnu.org/licenses/>.
22 *
23 */
24
25#ifndef _chemistry_qc_lmp2_contract_h
26#define _chemistry_qc_lmp2_contract_h
27
#include <cfloat>
#include <cstring>

#include <algorithm>
#include <iostream>
#include <map>
#include <stdexcept>
#include <vector>

#include <chemistry/qc/lmp2/dgemminfo.h>
#include <math/scmat/blas.h>
33
34#define USE_BOUNDS_IN_CONTRACT_UNION 1
35
36namespace sc {
37
38namespace sma2 {
39
46template <int N>
47inline bool
48need_repack(const Array<N> *A, const IndexList &row_indices, const IndexList &col_indices,
49 const IndexList &fixed, const BlockInfo<N> *fixedvals)
50{
51 // All that matters are indices that have block sizes greater than 1.
52 // If those indices are ordered correctly, then a repack is not needed.
53
54 std::vector<int> indices;
55
56 for (int i=0; i<row_indices.n(); i++) {
57 int index = row_indices.i(i);
58 if (A->index(index).max_block_size() > 1) {
59 indices.push_back(index);
60 }
61 }
62
63 for (int i=0; i<col_indices.n(); i++) {
64 int index = col_indices.i(i);
65 if (A->index(index).max_block_size() > 1) {
66 indices.push_back(index);
67 }
68 }
69
70 for (int i=1; i<indices.size(); i++) {
71 if (indices[i] < indices[i-1]) return true;
72 }
73
74 return false;
75}
76
83template <int N>
84inline double
85repack_cost(const Array<N> *A,
86 const IndexList &row_indices, const IndexList &col_indices,
87 const IndexList &fixed, const BlockInfo<N> *fixedvals)
88{
89 // This gives a very rough estimate of the relative repack costs
90 // nindex cost
91 // 0 0
92 // 1 1
93 // 2 3
94 // 3 7
95 // 4 15
96 double cost = A->n_element_allocated();
97 for (int i=0; i<fixed.n(); i++) {
98 cost = cost/A->index(fixed.i(i)).nblock();
99 }
100 return cost;
101}
102
104template <int NC, int NA, int NB>
106 double cost_;
107 bool need_repack_A_, transpose_A_;
108 bool need_repack_B_, transpose_B_;
109 bool need_repack_C_, transpose_C_;
110 IndexList extA_, intA_, fixA_;
111 const BlockInfo<NA> *fixvalA_;
112 IndexList extB_, intB_, fixB_;
113 const BlockInfo<NB> *fixvalB_;
114 IndexList extCA_, extCB_, fixC_;
115 const BlockInfo<NC> *fixvalC_;
116 const Array<NA> *A_;
117 const Array<NB> *B_;
118 const Array<NC> *C_;
119 int n_C_repack_;
120
121 void init() {
122 cost_ = 0.0;
123 need_repack_A_ = false;
124 transpose_A_ = false;
125 if (need_repack(A_, extA_, intA_, fixA_, fixvalA_)) {
126 if (need_repack(A_, intA_, extA_, fixA_, fixvalA_)) {
127 need_repack_A_ = true;
128 cost_ += 2.0 * repack_cost(A_,extA_,intA_,fixA_,fixvalA_);
129 }
130 else {
131 transpose_A_ = true;
132 }
133 }
134
135 need_repack_B_ = false;
136 transpose_B_ = false;
137 if (need_repack(B_, intB_, extB_, fixB_, fixvalB_)) {
138 if (need_repack(B_, extB_, intB_, fixB_, fixvalB_)) {
139 need_repack_B_ = true;
140 cost_ += 2.0 * repack_cost(B_,intB_,extB_,fixB_,fixvalB_);
141 }
142 else {
143 transpose_B_ = true;
144 }
145 }
146
147 need_repack_C_ = false;
148 transpose_C_ = false;
149 if (need_repack(C_, extCA_, extCB_, fixC_, fixvalC_)) {
150 if (need_repack(C_, extCB_, extCA_, fixC_, fixvalC_)) {
151 need_repack_C_ = true;
152 cost_ += n_C_repack_
153 * repack_cost(C_,extCA_,extCB_,fixC_,fixvalC_);
154 }
155 else {
156 transpose_C_ = true;
157 }
158 }
159 }
160 void reorder_indices(IndexList &i1,IndexList &i2) {
161 std::map<int,int> index_map;
162 for (int i=0; i<i1.n(); i++) {
163 index_map[i1.i(i)] = i2.i(i);
164 }
165 int i=0;
166 for (std::map<int,int>::iterator
167 iter=index_map.begin();
168 iter!=index_map.end(); i++,iter++) {
169 i1.i(i) = iter->first;
170 i2.i(i) = iter->second;
171 }
172 }
173 public:
176 const IndexList &extA, const IndexList &intA,
177 const IndexList &fixA, const BlockInfo<NA> *fixvalA,
178 const Array<NB> *B,
179 const IndexList &intB, const IndexList &extB,
180 const IndexList &fixB, const BlockInfo<NB> *fixvalB,
181 const Array<NC> *C,
182 const IndexList &extCA, const IndexList &extCB,
183 const IndexList &fixC, const BlockInfo<NC> *fixvalC,
184 int n_C_repack
185 ):
186 cost_(0.0),
187 extA_(extA), intA_(intA), fixA_(fixA), fixvalA_(fixvalA),
188 extB_(extB), intB_(intB), fixB_(fixB), fixvalB_(fixvalB),
189 extCA_(extCA), extCB_(extCB), fixC_(fixC), fixvalC_(fixvalC),
190 A_(A), B_(B), C_(C),
191 n_C_repack_(n_C_repack) {
192 init();
193 }
196 void driver(int i) {
197 if (i==0) {
198 reorder_indices(extCA_,extA_);
199 reorder_indices(extCB_,extB_);
200 }
201 else if (i==1) {
202 reorder_indices(extA_,extCA_);
203 reorder_indices(intA_,intB_);
204 }
205 else {
206 reorder_indices(extB_,extCB_);
207 reorder_indices(intB_,intA_);
208 }
209 init();
210 }
212 bool need_repack_A() const { return need_repack_A_; }
214 bool transpose_A() const { return transpose_A_; }
216 bool need_repack_B() const { return need_repack_B_; }
218 bool transpose_B() const { return transpose_B_; }
220 bool need_repack_C() const { return need_repack_C_; }
222 bool transpose_C() const { return transpose_C_; }
224 double cost() const { return cost_; }
225
229 IndexList &intB, IndexList &extB,
230 IndexList &extCA,IndexList &extCB) {
231 extA = extA_;
232 intA = intA_;
233 intB = intB_;
234 extB = extB_;
235 extCA = extCA_;
236 extCB = extCB_;
237 }
238};
239
247template <int N>
248inline void
249repack(Array<N> &A, const IndexList &row_indices, const IndexList &col_indices,
250 const IndexList &fixed, const BlockInfo<N> &fixedvals,
251 bool reverse = false)
252{
253 double *tmp_data = 0;
254 int n_tmp_data = 0;
255
256 const typename Array<N>::blockmap_t &amap = A.blockmap();
257
258 typename Array<N>::blockmap_t::const_iterator begin, end;
259 if (fixed.n() == 0) {
260 begin = amap.begin();
261 end = amap.end();
262 }
263 else {
264 BlockInfo<N> sbi;
265
266 sbi.zero();
267 sbi.assign_blocks(fixed, fixedvals);
268#ifdef USE_BOUND
269 sbi.set_bound(DBL_MAX);
270#endif
271 begin = amap.lower_bound(sbi);
272
273 for (int i=0; i<N; i++) sbi.block(i) = A.index(i).nblock();
274 sbi.assign_blocks(fixed, fixedvals);
275#ifdef USE_BOUND
276 sbi.set_bound(0.0);
277#endif
278 end = amap.upper_bound(sbi);
279 }
280
281 for (typename Array<N>::blockmap_t::const_iterator aiter = begin;
282 aiter != end;
283 aiter++) {
284 const BlockInfo<N> &bi = aiter->first;
285 double *data = aiter->second;
286 int ndata = bi.size(A.indices());
287 if (n_tmp_data < ndata) {
288 delete[] tmp_data;
289 tmp_data = new double[ndata];
290 n_tmp_data = ndata;
291 }
292 int nrow = bi.subset_size(A.indices(), row_indices);
293 int ncol = bi.subset_size(A.indices(), col_indices);
294 if (ndata != nrow * ncol) {
295 throw std::length_error("sma::repack: ntotal != nrow * ncol");
296 }
297 memcpy(tmp_data,data,sizeof(double)*ndata);
298 BlockIter<N> a_iter(A.indices(), bi);
299 for (a_iter.start(); a_iter.ready(); a_iter++) {
300 int row_index = a_iter.subset_offset(row_indices);
301 int col_index = a_iter.subset_offset(col_indices);
302 int index = a_iter.offset();
303 int repacked_index = row_index*ncol + col_index;
304 if (!reverse) {
305 data[repacked_index] = tmp_data[index];
306 }
307 else {
308 data[index] = tmp_data[repacked_index];
309 }
310 }
311 }
312
313 delete[] tmp_data;
314}
315
344template <int NC, int NA, int NB>
345inline void
346contract(
347 Array<NC> &C, const IndexList &c_extCA, const IndexList &c_extCB,
348 const IndexList &fixextCA, const IndexList &fixextCB,
349 const IndexList &fixC, const BlockInfo<NC> &fixvalC,
350 Array<NA> &A, const IndexList &c_extA, const IndexList &fixextA,
351 const IndexList &c_intA,
352 const IndexList &fixA, const BlockInfo<NA> &fixvalA,
353 bool clear_A_after_use,
354 Array<NB> &B, const IndexList &c_extB, const IndexList &fixextB,
355 const IndexList &c_intB,
356 const IndexList &fixB, const BlockInfo<NB> &fixvalB,
357 bool clear_B_after_use,
358 double ABfactor,
359 bool C_is_zero_on_entry = false,
360 sc::Ref<sc::RegionTimer> timer = 0)
361{
362 // Some of the arguments are copied to local variables. This
363 // is to allow modification of those arguments locally to permit
364 // optimization of the cost of repacking the arrays.
365 IndexList extCA(c_extCA), extCB(c_extCB);
366 IndexList extA(c_extA), extB(c_extB);
367 IndexList intA(c_intA), intB(c_intB);
368
369 if (C.n_element_allocated() == 0
370 || A.n_element_allocated() == 0
371 || B.n_element_allocated() == 0) {
372 return;
373 }
374
375 for (int i=0; i<fixC.n(); i++)
376 if (fixC.i(i) >= fixC.n())
377 throw std::invalid_argument("contract: C's fixed indices must be first");
378
379 // Consistency checks
380 if (extA.n() + extB.n() != NC - fixC.n() + fixextCA.n() + fixextCB.n()) {
381 throw std::invalid_argument("contract: Number of externals on A + B != C");
382 }
383 if (extCA.n() + extCB.n() != NC - fixC.n() + fixextCA.n() + fixextCB.n()) {
384 throw std::invalid_argument("contract: Number of indices on C inconsistent");
385 }
386 if (intA.n() != intB.n()) {
387 throw std::invalid_argument(
388 "contract: Number of internals on A and B inconsistent");
389 }
390 if (intA.n() + extA.n() != NA - fixA.n() + fixextA.n()) {
391 throw std::invalid_argument("contract: Number of indices on A inconsistent");
392 }
393 if (intB.n() + extB.n() != NB - fixB.n() + fixextB.n()) {
394 throw std::invalid_argument("contract: Number of indices on B inconsistent");
395 }
396 for (int i=0; i<extA.n(); i++) {
397 if (A.index(extA.i(i)) != C.index(extCA.i(i))) {
398 throw std::invalid_argument("contract: Range conflict between A and C");
399 }
400 }
401 for (int i=0; i<extB.n(); i++) {
402 if (B.index(extB.i(i)) != C.index(extCB.i(i))) {
403 throw std::invalid_argument("contract: Range conflict between B and C");
404 }
405 }
406 for (int i=0; i<intA.n(); i++) {
407 if (A.index(intA.i(i)) != B.index(intB.i(i))) {
408 throw std::invalid_argument("contract: Range conflict between A and B");
409 }
410 }
411
412 RepackScheme<NC,NA,NB> repack_scheme(&A, extA, intA, fixA, &fixvalA,
413 &B, intB, extB, fixB, &fixvalB,
414 &C, extCA,extCB,fixC, &fixvalC,
415 C_is_zero_on_entry?1:2);
416
417// std::cout << "trying to repack: original cost = "
418// << repack_scheme.cost()
419// << ", rpk A = " << repack_scheme.need_repack_A()
420// << " (" << A.n_element_allocated() << ")"
421// << ", rpk B = " << repack_scheme.need_repack_B()
422// << " (" << B.n_element_allocated() << ")"
423// << ", rpk C = " << repack_scheme.need_repack_C()
424// << " (" << C.n_element_allocated() << ")"
425// << std::endl;
426
427 if (repack_scheme.cost() != 0.0) {
428 RepackScheme<NC,NA,NB> tmp_repack_scheme(repack_scheme);
429 for (int driver1 = 0; driver1 < 3; driver1++) {
430 for (int driver2 = 0; driver2 < 3; driver2++) {
431 if (driver1 == driver2) continue;
432 tmp_repack_scheme.driver(driver1);
433 tmp_repack_scheme.driver(driver2);
434
435// std::cout << " repack scheme: "
436// << " d1 = " << driver1
437// << " d2 = " << driver2
438// << ", cost = "
439// << tmp_repack_scheme.cost()
440// << ", rpk A = "
441// << tmp_repack_scheme.need_repack_A()
442// << " (" << A.n_element_allocated() << ")"
443// << ", rpk B = "
444// << tmp_repack_scheme.need_repack_B()
445// << " (" << B.n_element_allocated() << ")"
446// << ", rpk C = "
447// << tmp_repack_scheme.need_repack_C()
448// << " (" << C.n_element_allocated() << ")"
449// << std::endl;
450
451 if (tmp_repack_scheme.cost() < repack_scheme.cost()) {
452 repack_scheme = tmp_repack_scheme;
453 repack_scheme.assign_indices(extA, intA,
454 intB, extB,
455 extCA,extCB);
456 //std::cout << "found a more efficient scheme" << std::endl;
457 }
458 if (repack_scheme.cost() == 0.0) break;
459 }
460 if (repack_scheme.cost() == 0.0) break;
461 }
462 }
463
464 // Remap the blocks of A so that it is sorted by its external indices
465 // and any fixed indices.
466 if (timer) timer->enter("remap A");
467 IndexList cmpAlist(extA,fixA);
468 IndexListLess<NA> cmpA(cmpAlist);
469 typename Array<NA>::cached_blockmap_t remappedAbm_local(cmpA);
470 typename Array<NA>::cached_blockmap_t *remappedAbm_ptr;
471 if (fixA.n() == 0 && A.use_blockmap_cache()) {
472 remappedAbm_ptr = &A.blockmap_cache_entry(cmpAlist);
473 }
474 else {
475 remap(remappedAbm_local, A, fixA, fixvalA);
476 remappedAbm_ptr = &remappedAbm_local;
477 }
478 typename Array<NA>::cached_blockmap_t &remappedAbm=*remappedAbm_ptr;
479 if (timer) timer->exit();
480// std::cout << "A:" << std::endl << A;
481// std::cout << "remappedA:" << std::endl << remappedA;
482
483 // Repack the data of A, B, and C so DGEMM can be used
484 // note: if need_repack is true then the matrix has not been transposed
485 if (timer) timer->enter("repack1");
486 if (repack_scheme.need_repack_A()) repack(A, extA, intA, fixA, fixvalA);
487 if (repack_scheme.need_repack_B()) repack(B, intB, extB, fixB, fixvalB);
488 if (repack_scheme.need_repack_C() && !C_is_zero_on_entry) {
489 repack(C, extCA, extCB, fixC, fixvalC);
490 }
491 if (timer) timer->exit();
492
493// std::cout << "tA:" << transpose_A
494// << " tB:" << transpose_B
495// << " rA:" << need_repack_A
496// << " rB:" << need_repack_B
497// << " rC:" << need_repack_C
498// << std::endl;
499
500#if 0
501 // Fixed indices imply that a loop over those indices is been done
502 // external to this routine. If one array doesn't have all of the fixed
503 // indices that other arrays have, then repacking that array will result
504 // in extra work. The code below detects this case.
505 std::set<int> fixed_all, fixed_A, fixed_B, fixed_C;
506 for (int i=0; i<fixA.n(); i++) {
507 fixed_all.insert(fixvalA.block(i));
508 fixed_A.insert(fixvalA.block(i));
509 }
510 for (int i=0; i<fixB.n(); i++) {
511 fixed_all.insert(fixvalB.block(i));
512 fixed_B.insert(fixvalB.block(i));
513 }
514 for (int i=0; i<fixC.n(); i++) {
515 fixed_all.insert(fixvalC.block(i));
516 fixed_C.insert(fixvalC.block(i));
517 }
518 if (repack_scheme.need_repack_A()
519 && fixed_A.size() > 0
520 && fixed_A != fixed_all) {
521 std::cout << "PERFORMANCE WARNING: contract needed to repack A"
522 << " but B and/or C have different fixed indices"
523 << std::endl;
524 throw std::runtime_error("contract: performance exception");
525 }
526 if (repack_scheme.need_repack_B()
527 && fixed_B.size() > 0
528 && fixed_B != fixed_all) {
529 std::cout << "PERFORMANCE WARNING: contract needed to repack B"
530 << " but A and/or C have different fixed indices"
531 << std::endl;
532 throw std::runtime_error("contract: performance exception");
533 }
534 if (repack_scheme.need_repack_C()
535 && fixed_C.size() > 0
536 && fixed_C != fixed_all) {
537 std::cout << "PERFORMANCE WARNING: contract needed to repack C"
538 << " but A and/or B have different fixed indices"
539 << std::endl;
540 throw std::runtime_error("contract: performance exception");
541 }
542#endif
543
544 const typename Array<NB>::blockmap_t &
545 Bbm = B.blockmap();
546#ifdef USE_HASH
547 const typename Array<NB>::blockhash_t &
548 Bbh = B.blockhash();
549#endif
550 const typename Array<NC>::blockmap_t &
551 Cbm = C.blockmap();
552
553 BlockInfo<NA> Abi;
554 Abi.assign_blocks(fixA,fixvalA);
555
556 BlockInfo<NB> Bbi;
557 Bbi.zero();
558 Bbi.assign_blocks(fixB,fixvalB);
559
560#ifndef USE_HASH
561 typename Array<NB>::blockmap_t::const_iterator B_fixed_hint
562 = Bbm.lower_bound(Bbi);
563#endif
564
565 typename Array<NC>::blockmap_t::const_iterator C_begin, C_end;
566 BlockInfo<NC> Cbi_lb;
567 BlockInfo<NC> Cbi_ub;
568 for (int i=0; i<NC; i++) {
569 Cbi_lb.block(i) = 0;
570 Cbi_ub.block(i) = C.index(i).nblock();
571 }
572 Cbi_lb.assign_blocks(fixC,fixvalC);
573 Cbi_ub.assign_blocks(fixC,fixvalC);
574 C_begin = Cbm.lower_bound(Cbi_lb);
575 C_end = Cbm.upper_bound(Cbi_ub);
576
577 if (timer) timer->enter("C loop");
578 for (typename Array<NC>::blockmap_t::const_iterator
579 Citer = C_begin;
580 Citer != C_end;
581 Citer++) {
582 const BlockInfo<NC> &Cbi = Citer->first;
583 double *Cdata = Citer->second;
584 Abi.assign_blocks(extA, Cbi, extCA);
585 std::pair<
586 typename Array<NA>::cached_blockmap_t::const_iterator,
587 typename Array<NA>::cached_blockmap_t::const_iterator >
588 rangeA;
589#ifdef USE_BOUND
590 // cannot use equal range on remappedA because bound is used to sort
591 // rangeA = remappedAbm.equal_range(Abi);
592 Abi.set_bound(DBL_MAX);
593 rangeA.first = remappedAbm.lower_bound(Abi);
594 Abi.set_bound(0.0);
595 rangeA.second = remappedAbm.upper_bound(Abi);
596#else
597 rangeA = remappedAbm.equal_range(Abi);
598#endif
599 typename Array<NA>::cached_blockmap_t::const_iterator
600 firstA = rangeA.first,
601 fenceA = rangeA.second;
602 Bbi.assign_blocks(extB, Cbi, extCB);
603 blasint n_extB = Cbi.subset_size(C.indices(), extCB);
604 blasint n_extA = Cbi.subset_size(C.indices(), extCA);
605#ifdef USE_HASH
606 typename Array<NB>::blockhash_t::const_iterator Biter;
607#else
608 typename Array<NB>::blockmap_t::const_iterator Biter = Bbm.begin();
609#endif
610 if (timer) timer->enter("A loop");
611 for (typename Array<NA>::cached_blockmap_t::const_iterator
612 Aiter = firstA;
613 Aiter != fenceA;
614 Aiter++) {
615 const BlockInfo<NA> &Abi = Aiter->first;
616 double *Adata = Aiter->second;
617 Bbi.assign_blocks(intB, Abi, intA);
618#ifdef USE_HASH
619 Biter = Bbh.find(Bbi);
620 if (Biter == Bbh.end()) continue;
621#else
622 if (fixB.n() > 0) {
623#if USE_STL_MULTIMAP
624 Biter = Bbm.find(Bbi);
625#else
626 Biter = Bbm.find(B_fixed_hint, Bbi);
627#endif
628 }
629 else {
630 //blindly using a hint here makes this a bit slower
631 //Biter = Bbm.find(Biter, Bbi);
632 Biter = Bbm.find(Bbi);
633 }
634 if (Biter == Bbm.end()) continue;
635#endif
636 double *Bdata = Biter->second;
637 blasint n_int = Abi.subset_size(A.indices(), intA);
638
639 double one = 1.0;
640 if (timer) timer->enter("dgemm");
641
642 double t0 = cpu_walltime();
643
644 if (n_extA == 1 && n_int == 1) {
645 double tmp = ABfactor * Adata[0];
646 for (int i=0; i<n_extB; i++) {
647 Cdata[i] += tmp*Bdata[i];
648 }
649 }
650 else if (n_extA == 1 && n_extB == 1) {
651 double tmp = 0.0;
652 for (int i=0; i<n_int; i++) {
653 tmp += Adata[i]*Bdata[i];
654 }
655 Cdata[0] += ABfactor*tmp;
656 }
657 else if (n_int == 1 && n_extB == 1) {
658 double tmp = ABfactor*Bdata[0];
659 for (int i=0; i<n_extA; i++) {
660 Cdata[i] += Adata[i]*tmp;
661 }
662 }
663 else if (n_int == 1) {
664 if (repack_scheme.transpose_C()) {
665 for (int i=0,ij=0; i<n_extB; i++) {
666 for (int j=0; j<n_extA; j++,ij++) {
667 Cdata[ij] += ABfactor*Adata[j]*Bdata[i];
668 }
669 }
670 }
671 else {
672 for (int i=0,ij=0; i<n_extA; i++) {
673 for (int j=0; j<n_extB; j++,ij++) {
674 Cdata[ij] += ABfactor*Adata[i]*Bdata[j];
675 }
676 }
677 }
678 }
679 else if (n_extA == 1) {
680 if (repack_scheme.transpose_B()) {
681 for (int i=0,ij=0; i<n_extB; i++) {
682 double tmp = 0.0;
683 for (int j=0; j<n_int; j++,ij++) {
684 tmp += Adata[j]*Bdata[ij];
685 }
686 Cdata[i] += tmp * ABfactor;
687 }
688 }
689 else {
690 for (int i=0; i<n_extB; i++) {
691 double tmp = 0.0;
692 for (int j=0,ij=i; j<n_int; j++,ij+=n_extB) {
693 tmp += Adata[j]*Bdata[ij];
694 }
695 Cdata[i] += tmp * ABfactor;
696 }
697 }
698 }
699 else if (n_extB == 1) {
700 if (repack_scheme.transpose_A()) {
701 for (int i=0; i<n_extA; i++) {
702 double tmp = 0.0;
703 for (int j=0,ij=i; j<n_int; j++,ij+=n_extA) {
704 tmp += Bdata[j]*Adata[ij];
705 }
706 Cdata[i] += tmp * ABfactor;
707 }
708 }
709 else {
710 for (int i=0,ij=0; i<n_extA; i++) {
711 double tmp = 0.0;
712 for (int j=0; j<n_int; j++,ij++) {
713 tmp += Bdata[j]*Adata[ij];
714 }
715 Cdata[i] += tmp * ABfactor;
716 }
717 }
718 }
719 else if (repack_scheme.transpose_C()) {
720 const char *tA = "T";
721 blasint lda = n_int;
722 if (repack_scheme.transpose_A()) { tA = "N"; lda = n_extA; }
723
724 const char *tB = "T";
725 blasint ldb = n_extB;
726 if (repack_scheme.transpose_B()) { tB = "N"; ldb = n_int; }
727
728 blasint ldc = n_extA;
729
730// std::cout << " tA: " << tA
731// << " tB: " << tB
732// << " nr: " << n_extA
733// << " nc: " << n_extB
734// << " nl: " << n_int
735// << " lda: " << lda
736// << " ldb: " << ldb
737// << " ldc: " << ldc
738// << std::endl;
739
740 F77_DGEMM(tA, tB, &n_extA, &n_extB, &n_int,
741 &ABfactor,Adata,&lda,Bdata,&ldb,
742 &one,Cdata,&ldc);
743 }
744 else {
745 const char *tA = "N";
746 blasint lda = n_int;
747 if (repack_scheme.transpose_A()) { tA = "T"; lda = n_extA; }
748
749 const char *tB = "N";
750 blasint ldb = n_extB;
751 if (repack_scheme.transpose_B()) { tB = "T"; ldb = n_int; }
752
753 blasint ldc = n_extB;
754
755 F77_DGEMM(tB, tA, &n_extB, &n_extA, &n_int,
756 &ABfactor,Bdata,&ldb,Adata,&lda,
757 &one,Cdata,&ldc);
758 }
759#ifdef USE_COUNT_DGEMM
760 count_dgemm(n_extA, n_int, n_extB,
761 cpu_walltime()-t0);
762#endif
763 if (timer) timer->exit();
764 }
765 if (timer) timer->exit();
766 }
767 if (timer) timer->exit();
768
769 // Repack the data of A, B, and C to the orginal data layout
770 if (timer) timer->enter("repack2");
771 if (clear_A_after_use) A.clear();
772 else {
773 if (repack_scheme.need_repack_A()) {
774 repack(A, extA, intA, fixA, fixvalA, true);
775 }
776 }
777
778 if (clear_B_after_use) B.clear();
779 else {
780 if (repack_scheme.need_repack_B()) {
781 repack(B, intB, extB, fixB, fixvalB, true);
782 }
783 }
784
785 if (repack_scheme.need_repack_C()) {
786 repack(C, extCA, extCB, fixC, fixvalC, true);
787 }
788 if (timer) timer->exit();
789
790 if (timer) timer->enter("bounds");
791 C.compute_bounds();
792 if (timer) timer->exit();
793}
794
796template <int N>
797inline double
798scalar_contract(
799 Array<N> &c,
800 Array<N> &a, const IndexList &alist)
801{
802 // Consistency checks
803 if (alist.n() != N) {
804 throw std::invalid_argument(
805 "sma::scalar_contract: # of indices inconsistent");
806 }
807 for (int i=0; i<N; i++) {
808 if (c.index(i) != a.index(alist.i(i)))
809 throw std::invalid_argument(
810 "sma::scalar_contract: indices don't agree");
811 }
812
813 bool same_index_order = alist.is_identity();
814
815 double r = 0.0;
816 const typename Array<N>::blockmap_t &amap = a.blockmap();
817 const typename Array<N>::blockmap_t &cmap = c.blockmap();
818 IndexList clist = alist.reverse_mapping();
819 bool use_hint;
820 if (clist.i(0) == 0) use_hint = true;
821 else use_hint = false;
822 typename Array<N>::blockmap_t::const_iterator citer = cmap.begin();
823 for (typename Array<N>::blockmap_t::const_iterator aiter = amap.begin();
824 aiter != amap.end();
825 aiter++) {
826 BlockInfo<N> cbi(aiter->first,clist);
827#if USE_STL_MULTIMAP
828 citer = cmap.find(cbi);
829#else
830 if (use_hint) citer = cmap.find(citer,cbi);
831 else citer = cmap.find(cbi);
832#endif
833 if (citer == cmap.end()) continue;
834 double *cdata = citer->second;
835 double *adata = aiter->second;
836 if (same_index_order) {
837 int sz = c.block_size(cbi);
838 for (int i=0; i<sz; i++) r += cdata[i] * adata[i];
839 }
840 else {
841 BlockIter<N> cbiter(c.indices(),cbi);
842 int coff = 0;
843 for (cbiter.start(); cbiter.ready(); cbiter++,coff++) {
844 r += cdata[coff] * adata[cbiter.subset_offset(alist)];
845 }
846 }
847 }
848
849 return r;
850}
851
858template <int NC, int NA, int NB>
859inline void
860contract_union(
861 Array<NC> &C, const IndexList &extCA, const IndexList &extCB,
862 const IndexList &fixC, const BlockInfo<NC> &fixvalC,
863 Array<NA> &A, const IndexList &extA, const IndexList &intA,
864 const IndexList &fixA, const BlockInfo<NA> &fixvalA,
865 Array<NB> &B, const IndexList &extB, const IndexList &intB,
866 const IndexList &fixB, const BlockInfo<NB> &fixvalB)
867{
868 // Consistency checks
869 if (extA.n() + extB.n() != NC - fixC.n()) {
870 std::cerr << "NA = " << NA << std::endl;
871 std::cerr << "intA = " << intA << std::endl;
872 std::cerr << "extA = " << extA << std::endl;
873 std::cerr << "fixA = " << fixA << std::endl;
874 std::cerr << "NB = " << NB << std::endl;
875 std::cerr << "intB = " << intB << std::endl;
876 std::cerr << "extB = " << extB << std::endl;
877 std::cerr << "fixB = " << fixB << std::endl;
878 std::cerr << "NC = " << NC << std::endl;
879 std::cerr << "extCA = " << extCA << std::endl;
880 std::cerr << "extCB = " << extCB << std::endl;
881 std::cerr << "fixC = " << fixC << std::endl;
882 throw std::invalid_argument("contract_union: Number of externals on A + B != C");
883 }
884 if (extCA.n() + extCB.n() != NC - fixC.n()) {
885 throw std::invalid_argument("contract_union: Number of indices on C inconsistent");
886 }
887 if (intA.n() != intB.n()) {
888 throw std::invalid_argument(
889 "contract_union: Number of internals on A and B inconsistent");
890 }
891 if (intA.n() + extA.n() != NA - fixA.n()) {
892 std::cerr << "NA = " << NA << std::endl;
893 std::cerr << "extA.n() = " << extA.n() << std::endl;
894 std::cerr << "fixA.n() = " << fixA.n() << std::endl;
895 std::cerr << "intA.n() = " << intA.n() << " (";
896 for (int i=0; i<intA.n(); i++) {
897 std::cerr << " " << intA.i(i);
898 }
899 std::cerr << ")" << std::endl;
900 throw std::invalid_argument("contract_union: Number of indices on A inconsistent");
901 }
902 if (intB.n() + extB.n() != NB - fixB.n()) {
903 throw std::invalid_argument("contract_union: Number of indices on B inconsistent");
904 }
905 for (int i=0; i<extB.n(); i++) {
906 if (B.index(extB.i(i)) != C.index(extCB.i(i))) {
907 throw std::invalid_argument("contract_union: Range conflict between B and C");
908 }
909 }
910 for (int i=0; i<intA.n(); i++) {
911 if (A.index(intA.i(i)) != B.index(intB.i(i))) {
912 throw std::invalid_argument("contract_union: Range conflict between A and B");
913 }
914 }
915
916
917 // Remap the blocks of A so that it is sorted by its internal indices and
918 // any fixed indices. The fixed indices appear first (are most
919 // significant wrt the ordering) so we can get the iterator bounds for
920 // relevant internal indices more easily. Data is not moved.
921 IndexList cmpAlist(fixA, intA);
922 IndexListLess<NA> cmpA(cmpAlist);
923 typename Array<NA>::cached_blockmap_t remappedAbm_local(cmpA);
924 typename Array<NA>::cached_blockmap_t *remappedAbm_ptr;
925 if (fixA.n() == 0 && A.use_blockmap_cache()) {
926 remappedAbm_ptr = &A.blockmap_cache_entry(cmpAlist);
927 }
928 else {
929 remap(remappedAbm_local, A, fixA, fixvalA);
930 remappedAbm_ptr = &remappedAbm_local;
931 }
932 typename Array<NA>::cached_blockmap_t &remappedAbm=*remappedAbm_ptr;
933
934 // Remap the blocks of B so that it is sorted by its internal indices
935 // and any fixed indices. Data is not moved.
936 IndexList cmpBlist(intB, fixB);
937 IndexListLess<NB> cmpB(cmpBlist);
938 typename Array<NB>::cached_blockmap_t remappedBbm_local(cmpB);
939 typename Array<NB>::cached_blockmap_t *remappedBbm_ptr;
940 if (fixB.n() == 0 && B.use_blockmap_cache()) {
941 remappedBbm_ptr = &B.blockmap_cache_entry(cmpBlist);
942 }
943 else {
944 remap(remappedBbm_local, B, fixB, fixvalB);
945 remappedBbm_ptr = &remappedBbm_local;
946 }
947 typename Array<NB>::cached_blockmap_t &remappedBbm=*remappedBbm_ptr;
948
949
950// std::cout << "beginning loops" << std::endl;
951// std::cout << "extA = " << extA << std::endl;
952// std::cout << "extB = " << extB << std::endl;
953// std::cout << "extCA = " << extCA << std::endl;
954// std::cout << "extCB = " << extCB << std::endl;
955
956 BlockInfo<NA> ablockinfo;
957 for (int i=0; i<NA; i++) ablockinfo.block(i) = 0;
958#ifdef USE_BOUND
959 ablockinfo.set_bound(DBL_MAX);
960#endif
961 ablockinfo.assign_blocks(fixA, fixvalA);
962 typename Array<NA>::cached_blockmap_t::const_iterator abegin;
963 abegin = remappedAbm.lower_bound(ablockinfo);
964
965 BlockInfo<NB> bblockinfo;
966 bblockinfo.assign_blocks(fixB, fixvalB);
967
968 BlockInfo<NC> cblockinfo;
969 cblockinfo.assign_blocks(fixC, fixvalC);
970
971 while (abegin != remappedAbm.end()) {
972 ablockinfo = abegin->first;
973
974 // if there are fixed indices, then abegin might be beyond
975 // the fixed indices that we are interested in
976 if (!ablockinfo.equiv_blocks(fixA, fixvalA)) break;
977
978#ifdef USE_BOUND
979#if 0 && USE_BOUNDS_IN_CONTRACT_UNION
980 if (B.bound() < DBL_EPSILON) {
981 ablockinfo.set_bound(C.tolerance()/DBL_EPSILON);
982 }
983 else {
984 ablockinfo.set_bound(C.tolerance()/B.bound());
985 }
986#else
987 ablockinfo.set_bound(0.0);
988#endif
989#endif
990 typename Array<NA>::cached_blockmap_t::const_iterator
991 afence = remappedAbm.upper_bound(ablockinfo);
992 bblockinfo.assign_blocks(intB,ablockinfo,intA);
993 std::pair<typename Array<NB>::cached_blockmap_t::const_iterator,
994 typename Array<NB>::cached_blockmap_t::const_iterator>
995 brange;
996 // cannot use equal_range on remappedB since bounds are used to sort
997 // brange = remappedBbm.equal_range(bblockinfo);
998#ifdef USE_BOUND
999 bblockinfo.set_bound(DBL_MAX);
1000#endif
1001 brange.first = remappedBbm.lower_bound(bblockinfo);
1002#ifdef USE_BOUND
1003#if 0 && USE_BOUNDS_IN_CONTRACT_UNION
1004 if (A.bound() < DBL_EPSILON) {
1005 bblockinfo.set_bound(C.tolerance()/DBL_EPSILON);
1006 }
1007 else {
1008 bblockinfo.set_bound(C.tolerance()/A.bound());
1009 }
1010#else
1011 bblockinfo.set_bound(0.0);
1012#endif
1013#endif
1014 brange.second = remappedBbm.upper_bound(bblockinfo);
1015 typename Array<NB>::cached_blockmap_t::const_iterator
1016 bbegin = brange.first,
1017 bfence = brange.second;
1018// std::cout << " in internal loop" << std::endl;
1019 for (typename Array<NA>::cached_blockmap_t::const_iterator
1020 aiter = abegin;
1021 aiter != afence;
1022 aiter++) {
1023#ifdef USE_BOUND
1024 double a_block_bound = aiter->first.bound();
1025#endif
1026 cblockinfo.assign_blocks(extCA,aiter->first,extA);
1027// std::cout << " A blocks: " << aiter->first << std::endl;
1028 for (typename Array<NB>::cached_blockmap_t::const_iterator
1029 biter = bbegin;
1030 biter != bfence;
1031 biter++) {
1032#ifdef USE_BOUND
1033#if 0 && USE_BOUNDS_IN_CONTRACT_UNION
1034 if (a_block_bound * biter->first.bound() < C.tolerance()) {
1035 continue;
1036 }
1037#endif
1038#endif
1039 cblockinfo.assign_blocks(extCB,biter->first,extB);
1040// std::cout << " B blocks: " << biter->first
1041// << " adding " << cblockinfo << std::endl;
1042 C.add_unallocated_block(cblockinfo);
1043 }
1044 }
1045#ifdef USE_BOUND
1046 ablockinfo.set_bound(0.0);
1047#endif
1048 abegin = remappedAbm.upper_bound(ablockinfo);
1049 }
1050}
1051
1052}
1053
1054}
1055
1056#endif
A template class that maintains reference counts.
Definition ref.h:361
Implements a block sparse tensor.
Definition sma.h:1247
BlockInfo stores info about a block of data.
Definition sma.h:200
An IndexList is a vector of indices.
Definition sma.h:160
Determine the cost of repacking arrays for a contraction.
Definition contract.h:105
void assign_indices(IndexList &extA, IndexList &intA, IndexList &intB, IndexList &extB, IndexList &extCA, IndexList &extCB)
Assign the contraction indices.
Definition contract.h:228
bool need_repack_C() const
Returns true if C needs repacked in the current scheme.
Definition contract.h:220
bool transpose_C() const
Returns true if C needs transposed in the current scheme.
Definition contract.h:222
double cost() const
Returns the cost of the current scheme.
Definition contract.h:224
bool transpose_B() const
Returns true if B needs transposed in the current scheme.
Definition contract.h:218
RepackScheme(const Array< NA > *A, const IndexList &extA, const IndexList &intA, const IndexList &fixA, const BlockInfo< NA > *fixvalA, const Array< NB > *B, const IndexList &intB, const IndexList &extB, const IndexList &fixB, const BlockInfo< NB > *fixvalB, const Array< NC > *C, const IndexList &extCA, const IndexList &extCB, const IndexList &fixC, const BlockInfo< NC > *fixvalC, int n_C_repack)
Create the RepackScheme for a given contraction.
Definition contract.h:175
bool transpose_A() const
Returns true if A needs transposed in the current scheme.
Definition contract.h:214
void driver(int i)
Set the array that determines the index ordering to C (i==0), A (i==1), or B (i==2). This will update ...
Definition contract.h:196
bool need_repack_B() const
Returns true if B needs repacked in the current scheme.
Definition contract.h:216
bool need_repack_A() const
Returns true if A needs repacked in the current scheme.
Definition contract.h:212
Contains all MPQC code up to version 3.
Definition mpqcin.h:14
void count_dgemm(int n, int l, int m, double t)
Records information about the time taken to perform a DGEMM operation.

Generated at Wed Sep 25 2024 02:45:29 for MPQC 3.0.0-alpha using the documentation package Doxygen 1.12.0.