MPQC 3.0.0-alpha
Loading...
Searching...
No Matches
thread.hpp
1#ifndef MPQC_ARRAY_THREAD_HPP
2#define MPQC_ARRAY_THREAD_HPP
3
4#include "mpqc/mpi.hpp"
5#include "mpqc/utility/timer.hpp"
6#include "mpqc/array/forward.hpp"
7#include "mpqc/array/socket.hpp"
8
9#include <memory>
10
11#include <boost/thread/thread.hpp>
12#include "boost/thread/tss.hpp"
13
14namespace mpqc {
15namespace detail {
16namespace ArrayServer {
17
18 struct array_proxy {
19
20 static const size_t BUFFER = (32<<20);
21
22 struct Descriptor {
23 ArrayBase *object;
24 size_t rank;
25 size_t count;
26 };
27
28 struct Segment {
29 size_t size;
30 std::vector<range> extents;
31 };
32
33 array_proxy(ArrayBase *object,
34 const std::vector<range> &r) {
35 this->object = object;
36 this->rank = r.size();
37 this->count = 0;
38
39 size_t block = 1;
40 size_t N = BUFFER/sizeof(double);
41
42 for (int i = 0; i < rank-1; ++i)
43 block *= r[i].size();
44 MPQC_ASSERT(block < N);
45
46 BOOST_FOREACH (range rj, split(r.back(), N/block)) {
47 for (int i = 0; i < rank-1; ++i) {
48 data.push_back(*r[i].begin());
49 data.push_back(*r[i].end());
50 }
51 data.push_back(*rj.begin());
52 data.push_back(*rj.end());
53 ++count;
54 }
55 }
56
57 array_proxy(Descriptor ds) {
58 this->object = ds.object;
59 this->rank = ds.rank;
60 this->count = ds.count;
61 this->data.resize(2*rank*count);
62 }
63
64 Descriptor descriptor() const {
65 Descriptor ds;
66 ds.object = this->object;
67 ds.rank = this->rank;
68 ds.count = this->count;
69 return ds;
70 }
71
72 std::vector<Segment> segments() const {
73 std::vector<Segment> segments;
74 auto data = this->data.begin();
75 for (int j = 0; j < count; ++j) {
76 std::vector<range> r;
77 size_t size = 1;
78 for (int i = 0; i < rank; ++i) {
79 r.push_back(range(data[0], data[1]));
80 size *= r.back().size();
81 data += 2;
82 }
83 segments.push_back(Segment());
84 segments.back().size = size;
85 segments.back().extents = r;
86 }
87 return segments;
88 }
89
90 public:
91
92 ArrayBase *object;
93 int rank, count;
94 std::vector<int> data;
95
96 };
97
98
99 struct Message {
100
101 enum Request { INVALID = 0x0,
102 JOIN = 0x100,
103 SYNC = 0x101,
104 WRITE = 0x200,
105 READ = 0x201
106 };
107
108 Request request;
109 array_proxy::Descriptor dataspace;
110 int src, tag;
111
112 explicit Message(int tag = 0, Request r = INVALID)
113 : tag(tag), request(r) {}
114
115 Message(int tag, Request op, array_proxy::Descriptor ds) {
116 this->request = op;
117 this->dataspace = ds;
118 this->tag = tag;
119 }
120
121 };
122
123
124 struct Thread : boost::noncopyable {
125
126 private:
127 explicit Thread(MPI::Comm comm)
128 : comm_(comm), tag_(1<<20)
129 {
130 buffer_ = malloc(array_proxy::BUFFER);
131 this->socket_.start();
132 this->servers_ = comm.allgather(this->socket_.address());
133 this->thread_ = new boost::thread(&Thread::run, this);
134 }
135
136 public:
137
138 enum {
139 RECV_MASK = 1<<22,
140 SEND_MASK = 1<<23
141 };
142
143 ~Thread() {
144 //comm_.printf("~Thread\n");
145 sync();
146 join();
147 delete this->thread_;
148 free(this->buffer_);
149 }
150
151 void join() {
152 int tag = next();
153 send(Message(tag, Message::JOIN), comm_.rank());
154 this->thread_->join();
155 //comm_.printf("thread joined\n");
156 }
157
158 void sync() const {
159 int tag = next();
160 send(Message(tag, Message::SYNC), comm_.rank());
161 comm_.recv<Message>(comm_.rank(), tag | SEND_MASK);
162 }
163
164 void send(Message msg, int proc) const {
165 msg.src = comm_.rank();
166 ArraySocket::send(&msg, this->servers_.at(proc));
167 //send(&msg, sizeof(Message), MPI_BYTE, proc, this->tag_);
168 }
169
171 void send(const void *data,
172 size_t count, MPI_Datatype type,
173 int proc, int tag) const {
174 MPQC_ASSERT(!(tag & SEND_MASK));
175 MPQC_ASSERT(!(tag & RECV_MASK));
176 comm_.send(data, count, type, proc, tag | RECV_MASK);
177 }
178
180 void recv(void *data,
181 size_t count, MPI_Datatype type,
182 int proc, int tag) const {
183 MPQC_ASSERT(!(tag & SEND_MASK));
184 MPQC_ASSERT(!(tag & RECV_MASK));
185 comm_.recv(data, count, type, proc, tag | SEND_MASK);
186 }
187
188 static std::shared_ptr<Thread>& instance() {
189 static std::shared_ptr<Thread> thread;
190 if (!thread.get()) {
191 MPI::initialize(MPI_THREAD_MULTIPLE);
192 thread.reset(new Thread(MPI::Comm(MPI_COMM_WORLD)));
193 }
194 return thread;
195 }
196
197 static void run(Thread *thread) {
198 //mutex_.unlock();
199 thread->loop();
200 //mutex_.lock();
201 }
202
203 public:
204
205 int next() const {
206 const unsigned int N = 1 << 21;
207 boost::mutex::scoped_lock lock(mutex_);
208 return int(N + (next_++ % N));
209 }
210
211 int translate(MPI::Comm comm1, int rank1) const {
212 int rank2;;
213 MPI_Group group1, group2;
214 MPI_Comm_group(comm1, &group1);
215 MPI_Comm_group(this->comm_, &group2);
216 MPI_Group_translate_ranks(group1, 1, &rank1, group2, &rank2);
217 MPQC_ASSERT(rank2 != MPI_UNDEFINED);
218 return rank2;
219 }
220
221 private:
222
223 void loop() {
224
225 // std::cout << "thread/rank: "
226 // << boost::this_thread::get_id() << "/" << comm_.rank()
227 // << std::endl;
228
229 while (1) {
230
231 Message msg;
232
233 this->socket_.wait(&msg);
234
235 //std::cout << MPI::get_processor_name() << std::endl;
236
237 // MPI_Status status;
238 // status = comm_.recv(&msg, sizeof(Message), MPI_BYTE,
239 // MPI_ANY_SOURCE, this->tag_ | RECV_MASK);
240
241 // MPI_Request request =
242 // comm_.irecv(&msg, sizeof(Message), MPI_BYTE,
243 // MPI_ANY_SOURCE, this->tag_);
244 // status = MPI::wait(request, 10);
245
246 //comm_.printf("Message received %i\n", msg.request);
247
248 if (msg.request == Message::READ) {
249 //printf("Message::READ\n");
250 read(msg, msg.dataspace);
251 continue;
252 }
253 if (msg.request == Message::WRITE) {
254 //printf("Message::WRITE\n");
255 write(msg, msg.dataspace);
256 continue;
257 }
258 if (msg.request == Message::SYNC) {
259 sync(msg);
260 continue;
261 }
262 if (msg.request == Message::JOIN) {
263 //comm_.printf("Message::JOIN\n");
264 //comm_.ssend(msg, status.MPI_SOURCE, msg.tag);
265 break;
266 }
267 printf("invalid message request %i\n", msg.request);
268 throw std::runtime_error("invalid message");
269 }
270 }
271
272 private:
273
274 MPI::Comm comm_;
275 void *buffer_;
276 boost::thread *thread_;
277
278 ArraySocket socket_;
279 std::vector<ArraySocket::Address> servers_;
280
281 int tag_;
282 mutable unsigned int next_;
283 mutable boost::mutex mutex_;
284
285 void sync(Message msg) {
286 //comm_.printf("thread message/sync dst=%i tag=%i\n", status.MPI_SOURCE, tag);
287 comm_.send(Message(Message::SYNC), msg.src, msg.tag | SEND_MASK);
288 }
289
290 void read(Message msg, array_proxy::Descriptor ds) {
291 io<Message::READ>(array_proxy(ds), msg.src, msg.tag);
292 }
293
294 void write(Message msg, array_proxy::Descriptor ds) {
295 io<Message::WRITE>(array_proxy(ds), msg.src, msg.tag);
296 }
297
298 private:
299
300 template<Message::Request OP>
301 void io(array_proxy ds, int proc, int tag) {
302 // comm_.printf("thread recv descriptor bytes=%lu src=%i tag=%i\n",
303 // ds.data.size()*sizeof(int), proc, tag);
304 comm_.recv(&ds.data[0], ds.data.size(), MPI_INT, proc, tag | RECV_MASK);
305 const auto &segments = ds.segments();
306 double* buffer = static_cast<double*>(this->buffer_);
307 for (int i = 0; i < segments.size(); ++i) {
308 const auto &extents = segments[i].extents;
309 if (OP == Message::WRITE) {
310 // comm_.printf("thread recv/write bytes=%lu, dst=%i tag=%i\n",
311 // segments[i].size*8, proc, tag);
312 comm_.recv(buffer, segments[i].size, MPI_DOUBLE,
313 proc, tag | RECV_MASK);
314 //std::cout << "write " << o.extents() << std::endl;
315 ds.object->put(extents, buffer);
316 }
317 if (OP == Message::READ) {
318 // comm_.printf("thread read/send bytes=%lu, dst=%i tag=%i\n",
319 // segments[i].size*8, proc, tag);
320 //std::cout << "read " << o.extents() << std::endl;
321 ds.object->get(extents, buffer);
322 comm_.send(buffer, segments[i].size, MPI_DOUBLE,
323 proc, tag | SEND_MASK);
324 }
325 }
326 }
327
328 };
329
330
331}
332} // namespace detail
333} // namespace mpqc
334
335namespace mpqc {
336namespace detail {
337
339
343
    /// Construct the per-communicator array I/O interface.
    /// Grabs (and keeps alive via shared_ptr) the process-wide server
    /// Thread, then performs an initial collective sync so no rank
    /// proceeds until every rank's server is up.  Collective: must be
    /// called by all ranks of comm.
    explicit array_thread_comm(MPI_Comm comm)
        : comm_(comm)
    {
        thread_ = Thread::instance();
        sync();
        //printf("mpqc::Comm initialized\n");
    }
351
352 Thread& thread() {
353 return *this->thread_;
354 }
355
    /// Collective synchronization: barrier, drain the local server
    /// thread's request queue, barrier again — so on return every
    /// rank's server has finished all previously queued requests.
    /// The surrounding barriers make the drain collective; do not reorder.
    void sync() const {
        //comm_.printf("sync\n");
        comm_.barrier();
        thread_->sync();
        comm_.barrier();
    }
362
363 template<typename T>
364 void write(const T *data, ArrayBase *object,
365 const std::vector<range> &r, int rank) const {
366 //printf("Comm::write\n");
367 io<Message::WRITE>((T*)data, object, r, rank);
368 }
369
370 template<typename T>
371 void read(T *data, ArrayBase *object,
372 const std::vector<range> &r, int rank) const {
373 //printf("Comm::read\n");
374 io<Message::READ>(data, object, r, rank);
375 }
376
377 MPI::Comm comm() const {
378 return comm_;
379 }
380
    private:

        MPI::Comm comm_;                  // communicator for this interface's transfers
        std::shared_ptr<Thread> thread_;  // keeps the process-wide server Thread alive
385
    /// Message tag for the calling thread, drawn once from the server's
    /// tag counter and cached in thread-local storage, so concurrent
    /// client threads never share a tag.
    /// NOTE(review): the TLS static is shared by all array_thread_comm
    /// instances in the process — confirm tags need not be per-instance.
    int tag() const {
        static boost::thread_specific_ptr<int> tag;
        if (!tag.get()) tag.reset(new int(thread_->next()));
        return *tag;
    }
391
392 template<Message::Request OP, typename T>
393 void io(T* buffer, ArrayBase *object,
394 const std::vector<range> &r,
395 int proc) const
396 {
397
398 static_assert(OP == Message::WRITE ||
399 OP == Message::READ,
400 "invalid OP");
401
402 //proc = thread_->translate(this->comm_, proc);
403 int tag = this->tag();
404 array_proxy ds(object, r);
405
406 {
407 //MPQC_PROFILE_LINE;
408 //printf("message dst=%i tag=%i\n", proc, thread_->tag());
409 thread_->send(Message(tag, OP, ds.descriptor()), proc);
410 }
411
412 {
413 //MPQC_PROFILE_LINE;
414 //printf("descriptor dst=%i tag=%i\n", proc, tag);
415 thread_->send(&ds.data[0], ds.data.size(), MPI_INT, proc, tag);
416 }
417
418 auto segments = ds.segments();
419 std::vector<MPI_Request> requests(ds.count);
420
421 mpqc::timer t;
422 size_t total = 0;
423 for (int i = 0; i < segments.size(); ++i) {
424 size_t size = segments[i].size;
425 if (OP == Message::READ) {
426 //MPQC_PROFILE_LINE;
427 // printf("recv segment %i bytes=%lu proc=%i tag=%i\n",
428 // i, size*sizeof(T), proc, tag);
429 /*requests[i] = i*/
430 thread_->recv(buffer, size*sizeof(T), MPI_BYTE, proc, tag);
431 }
432 if (OP == Message::WRITE) {
433 //MPQC_PROFILE_LINE;
434 // printf("send segment %i bytes=%lu proc=%i tag=%i\n",
435 // i, size*sizeof(T), proc, tag);
436 /*requests[i] = i*/
437 thread_->send(buffer, size*sizeof(T), MPI_BYTE, proc, tag);
438 }
439 buffer += size;
440 total += size;
441 }
442 //MPI_Waitall(n, &requests[0], MPI_STATUSES_IGNORE);
443 // double mb = (total*sizeof(double))/1e6;
444 // printf("I/O %s: %f Mbytes, %f Mbytes/s\n",
445 // (OP == Message::WRITE ? "WRITE" : "READ"), mb, mb/t);
446 }
447
448 };
449
450
451}
452}
453
454#endif // MPQC_ARRAY_THREAD_HPP
std::vector< range > split(range r, size_t N)
Split range into blocks of size N.
Definition range.hpp:94
Contains new MPQC code since version 3.
Definition integralenginepool.hpp:37
MPI_Comm object wrapper/stub.
Definition comm.hpp:14
Definition forward.hpp:23
Definition thread.hpp:99
Definition thread.hpp:124
void recv(void *data, size_t count, MPI_Datatype type, int proc, int tag) const
recv from server thread
Definition thread.hpp:180
void send(const void *data, size_t count, MPI_Datatype type, int proc, int tag) const
send to server thread
Definition thread.hpp:171
Definition thread.hpp:338
Definition range.hpp:25
Definition timer.hpp:9

Generated at Wed Sep 25 2024 02:45:30 for MPQC 3.0.0-alpha using the documentation package Doxygen 1.12.0.