// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// multi_thread_gemm.h: Multi-threaded GEMM entry point.
// Readers note: To understand this file, it is useful to first
// read and understand the much simpler single_thread_gemm.h.

#ifndef GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_
#define GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_

#include <pthread.h>
#include <unistd.h>
#include <vector>

#include "single_thread_gemm.h"

namespace gemmlowp {

#ifdef GEMMLOWP_ALLOW_INLINE_ASM
// Where inline asm is allowed, we use some busy-waiting,
// preferably implemented using NOP instructions.
const int kMaxBusyWaitNOPs = 32 * 1000 * 1000;

#define GEMMLOWP_NOP "nop\n"

#define GEMMLOWP_STRING_CONCAT_4(X) X X X X
#define GEMMLOWP_NOP4 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP)
#define GEMMLOWP_NOP16 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP4)
#define GEMMLOWP_NOP64 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP16)
#define GEMMLOWP_NOP256 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP64)

inline int Do256NOPs() {
  asm volatile(GEMMLOWP_NOP256);
  return 256;
}
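
// Note: GEMMLOWP_STRING_CONCAT_4 applied four times yields 4^4 = 256 copies
// of "nop\n", so the single asm statement in Do256NOPs above issues 256 NOP
// instructions per call.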

#undef GEMMLOWP_STRING_CONCAT_4
#undef GEMMLOWP_NOP256
#undef GEMMLOWP_NOP64
#undef GEMMLOWP_NOP16
#undef GEMMLOWP_NOP4
#undef GEMMLOWP_NOP

#else  // not GEMMLOWP_ALLOW_INLINE_ASM

// It is nontrivial to implement good busy-waiting without
// using asm; NOP instructions have the least side effects
// and the lowest power usage. Since the whole busy-waiting
// story is only an optimization, it is not worth pursuing
// in places where we are already slow due to not being able
// to use our inline asm kernels.

const int kMaxBusyWaitNOPs = 0;
inline int Do256NOPs() { return 0; }

#endif  // not GEMMLOWP_ALLOW_INLINE_ASM

inline void WriteBarrier() {
#ifdef GEMMLOWP_ARM_32
  MemoryBarrier();
#elif defined(GEMMLOWP_ARM_64)
  asm volatile("dmb ishst" ::: "memory");
#elif defined(GEMMLOWP_X86)
  asm volatile("sfence" ::: "memory");
#elif defined(__mips__)
  MemoryBarrier();
#else
#error "Unsupported architecture for WriteBarrier."
#endif
}

inline void ReadBarrier() {
#ifdef GEMMLOWP_ARM_32
  MemoryBarrier();
#elif defined(GEMMLOWP_ARM_64)
  asm volatile("dmb ishld" ::: "memory");
#elif defined(GEMMLOWP_X86)
  asm volatile("lfence" ::: "memory");
#elif defined(__mips__)
  MemoryBarrier();
#else
#error "Unsupported architecture for ReadBarrier."
#endif
}
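
// Note on how these barriers are used in this file: the master thread
// publishes a Task pointer and then issues WriteBarrier() before signaling a
// worker (see Worker::StartWork below), while the worker issues ReadBarrier()
// before reading that pointer (see Worker::ThreadFunc), so the two act as a
// matched store/load barrier pair.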

// Waits until *var != initial_value.
//
// Returns the new value of *var. The guarantee here is that
// the return value is different from initial_value, and that that
// new value has been taken by *var at some point during the
// execution of this function. There is no guarantee that this is
// still the value of *var when this function returns, since *var is
// not assumed to be guarded by any lock.
//
// First does some busy-waiting for a fixed number of no-op cycles,
// then falls back to passive waiting for the given condvar, guarded
// by the given mutex.
//
// The idea of doing some initial busy-waiting is to help get
// better and more consistent multithreading benefits for small GEMM sizes.
// Busy-waiting helps ensure that if we need to wake up soon after having
// started waiting, then we can wake up quickly (as opposed to, say,
// having to wait to be scheduled again by the OS). On the other hand,
// we must still eventually revert to passive waiting for longer waits
// (e.g. worker threads having finished a GEMM and waiting until the next GEMM)
// so as to avoid permanently spinning.
//
template <typename T>
T WaitForVariableChange(volatile T* var, T initial_value, pthread_cond_t* cond,
                        pthread_mutex_t* mutex) {
  int nops = 0;
  // First, the trivial case where the variable already changed value.
  T new_value = *var;
  if (new_value != initial_value) {
    return new_value;
  }
  // Then try busy-waiting.
  while (nops < kMaxBusyWaitNOPs) {
    nops += Do256NOPs();
    new_value = *var;
    if (new_value != initial_value) {
      return new_value;
    }
  }
  // Finally, do real passive waiting.
  pthread_mutex_lock(mutex);
  new_value = *var;
  if (new_value == initial_value) {
    pthread_cond_wait(cond, mutex);
    new_value = *var;
    assert(new_value != initial_value);
  }
  pthread_mutex_unlock(mutex);
  return new_value;
}
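
// Illustrative note: for WaitForVariableChange() to wake up reliably, the
// writer side must update the variable and signal the condition variable
// while holding the same mutex, i.e. something like:
//
//   pthread_mutex_lock(&mutex);
//   *var = new_value;
//   pthread_cond_signal(&cond);
//   pthread_mutex_unlock(&mutex);
//
// This is the pattern followed by BlockingCounter::DecrementCount() and
// Worker::ChangeState() below.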

// A BlockingCounter lets one thread wait for N events to occur.
// This is how the master thread waits for all the worker threads
// to have finished working.
class BlockingCounter {
 public:
  BlockingCounter()
      : cond_(PTHREAD_COND_INITIALIZER),
        mutex_(PTHREAD_MUTEX_INITIALIZER),
        count_(0),
        initial_count_(0) {}

  // Sets/resets the counter; initial_count is the number of
  // decrementing events that the Wait() call will be waiting for.
  void Reset(std::size_t initial_count) {
    pthread_mutex_lock(&mutex_);
    assert(count_ == 0);
    initial_count_ = initial_count;
    count_ = initial_count_;
    pthread_mutex_unlock(&mutex_);
  }

  // Decrements the counter; if the counter hits zero, signals
  // the thread that was waiting for that, and returns true.
  // Otherwise (if the decremented count is still nonzero),
  // returns false.
  bool DecrementCount() {
    pthread_mutex_lock(&mutex_);
    assert(count_ > 0);
    count_--;
    if (count_ == 0) {
      pthread_cond_signal(&cond_);
    }
    bool retval = count_ == 0;
    pthread_mutex_unlock(&mutex_);
    return retval;
  }

  // Waits for the N other threads (N having been set by Reset())
  // to hit the BlockingCounter.
  void Wait() {
    ScopedProfilingLabel label("BlockingCounter::Wait");
    while (count_) {
      MemoryBarrier();
      const std::size_t count_value = count_;
      if (count_value) {
        WaitForVariableChange(&count_, count_value, &cond_, &mutex_);
      }
    }
  }

 private:
  pthread_cond_t cond_;
  pthread_mutex_t mutex_;
  std::size_t count_;
  std::size_t initial_count_;
};
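
// Typical use of BlockingCounter (a sketch of how the worker pool below
// uses it):
//
//   BlockingCounter counter;
//   counter.Reset(N);   // master: expect N completion events
//   ...                 // each of the N workers eventually calls
//                       // counter.DecrementCount() exactly once
//   counter.Wait();     // master: blocks until the count reaches zero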

// A workload for a worker.
struct Task {
  Task() : local_allocator(nullptr) {}
  virtual ~Task() {}
  virtual void Run() const = 0;
  Allocator* local_allocator;
};
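
// Illustrative sketch (hypothetical, not part of the library): a concrete
// Task only needs to implement Run(), and may use local_allocator, which the
// Worker fills in before running the task. For example:
//
//   struct ExampleTask : Task {
//     void Run() const override {
//       // Do some work here, possibly allocating scratch buffers
//       // through local_allocator.
//     }
//   };
//
// GemmWithPackedRhsTask further below is the one real Task used in this file.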

// A worker thread.
class Worker {
 public:
  enum class State {
    ThreadStartup,  // The initial state before the thread main loop runs.
    Ready,          // Is not working, has not yet received new work to do.
    HasWork,        // Has work to do.
    ExitAsSoonAsPossible  // Should exit at earliest convenience.
  };

  explicit Worker(BlockingCounter* counter_to_decrement_when_ready)
      : task_(nullptr),
        state_cond_(PTHREAD_COND_INITIALIZER),
        state_mutex_(PTHREAD_MUTEX_INITIALIZER),
        state_(State::ThreadStartup),
        counter_to_decrement_when_ready_(counter_to_decrement_when_ready) {
    pthread_create(&thread_, nullptr, ThreadFunc, this);
  }

  ~Worker() {
    ChangeState(State::ExitAsSoonAsPossible);
    pthread_join(thread_, nullptr);
  }

  // Changes State; may be called from either the worker thread
  // or the master thread. Not all state transitions are legal;
  // this is guarded by assertions.
  void ChangeState(State new_state) {
    ScopedProfilingLabel label("Worker::ChangeState");
    pthread_mutex_lock(&state_mutex_);
    assert(new_state != state_);
    switch (state_) {
      case State::ThreadStartup:
        assert(new_state == State::Ready);
        break;
      case State::Ready:
        assert(new_state == State::HasWork ||
               new_state == State::ExitAsSoonAsPossible);
        break;
      case State::HasWork:
        assert(new_state == State::Ready ||
               new_state == State::ExitAsSoonAsPossible);
        break;
      default:
        abort();
    }
    state_ = new_state;
    pthread_cond_signal(&state_cond_);
    if (state_ == State::Ready) {
      counter_to_decrement_when_ready_->DecrementCount();
    }
    pthread_mutex_unlock(&state_mutex_);
  }

  // Thread entry point.
  void ThreadFunc() {
    ScopedProfilingLabel label("Worker::ThreadFunc");
    RegisterCurrentThreadForProfiling();

    ChangeState(State::Ready);

    // Thread main loop
    while (true) {
      // Get a state to act on.
      // In the 'Ready' state, we have nothing to do but to wait until
      // we switch to another state.
      State state_to_act_upon = WaitForVariableChange(
          &state_, State::Ready, &state_cond_, &state_mutex_);

      // We now have a state to act on, so act.
      switch (state_to_act_upon) {
        case State::HasWork:
          // Got work to do! So do it, and then revert to the 'Ready' state.
          ReadBarrier();
          assert(task_);
          task_->Run();
          delete task_;
          task_ = nullptr;
          ChangeState(State::Ready);
          break;
        case State::ExitAsSoonAsPossible:
          return;
        default:
          abort();
      }
    }
  }

  static void* ThreadFunc(void* arg) {
    static_cast<Worker*>(arg)->ThreadFunc();
    return nullptr;
  }

  // Called by the master thread to give this worker work to do.
  // It is only legal to call this if the worker is currently
  // in the 'Ready' state.
  void StartWork(Task* task) {
    assert(!task_);
    task->local_allocator = &local_allocator_;
    task_ = task;
    WriteBarrier();
    assert(state_ == State::Ready);
    ChangeState(State::HasWork);
  }

 private:
  // The underlying thread.
  pthread_t thread_;

  // The task to be worked on.
  const Task* task_;

  // The condition variable and mutex guarding state changes.
  pthread_cond_t state_cond_;
  pthread_mutex_t state_mutex_;

  // The state enum tells if we're currently working, waiting for work, etc.
  State state_;

  // Each worker has its own local allocator so that workers can allocate
  // temporary buffers without blocking each other.
  Allocator local_allocator_;

  // Pointer to the master thread's BlockingCounter object, to notify the
  // master thread of when this worker switches to the 'Ready' state.
  BlockingCounter* const counter_to_decrement_when_ready_;
};
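
// Worker lifecycle summary (derived from the assertions in ChangeState):
//
//   ThreadStartup -> Ready              (set by the worker thread itself)
//   Ready -> HasWork                    (set by the master via StartWork)
//   HasWork -> Ready                    (set by the worker when the task is done)
//   Ready/HasWork -> ExitAsSoonAsPossible  (set by ~Worker)
//
// Every transition into 'Ready' decrements the master's BlockingCounter,
// which is how the master learns that a worker is available again.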

// A very simple pool of workers that only allows the very
// specific parallelization pattern that we use here:
// a fixed number of workers can be given work, and one then
// waits for all of them to finish.
class WorkersPool {
 public:
  WorkersPool() {}

  ~WorkersPool() {
    for (auto w : workers_) {
      delete w;
    }
  }

  BlockingCounter& counter_to_decrement_when_ready() {
    return counter_to_decrement_when_ready_;
  }

  // Give work to a specific worker.
  void StartWorker(int index, Task* task_) {
    assert(static_cast<std::size_t>(index) < workers_.size());
    workers_[index]->StartWork(task_);
  }

  // Ensures that the pool has at least the given count of workers.
  // If any new workers have to be created, this function waits for them
  // to be ready.
  void CreateWorkers(std::size_t workers_count) {
    if (workers_.size() >= workers_count) {
      return;
    }
    counter_to_decrement_when_ready_.Reset(workers_count - workers_.size());
    while (workers_.size() < workers_count) {
      workers_.push_back(new Worker(&counter_to_decrement_when_ready_));
    }
    counter_to_decrement_when_ready_.Wait();
  }

 private:
  // Copy construction disallowed.
  WorkersPool(const WorkersPool&) = delete;

  // The workers in this pool. They are owned by the pool:
  // the pool creates workers and destroys them in its destructor.
  std::vector<Worker*> workers_;

  // The BlockingCounter used to wait for the workers.
  BlockingCounter counter_to_decrement_when_ready_;
};
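
// Typical use of the pool (a sketch mirroring what MultiThreadGemm does
// below; "SomeTask" is a placeholder):
//
//   WorkersPool* pool = context->workers_pool();
//   pool->CreateWorkers(workers_count);
//   pool->counter_to_decrement_when_ready().Reset(workers_count);
//   for (int i = 0; i < workers_count; i++) {
//     pool->StartWorker(i, new SomeTask(...));
//   }
//   pool->counter_to_decrement_when_ready().Wait();
//
// Ownership of each task passes to the worker, which deletes it after
// running it.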

// The task we use to implement a multi-threaded Gemm: a block of the
// RHS has been packed by the master thread; each worker thread
// then has to pack a block of the LHS and accumulate the Gemm of these
// packed LHS and RHS blocks.
template <typename KernelFormat, typename InputScalar, typename OutputScalar,
          typename BitDepthParams, MapOrder LhsOrder, MapOrder RhsOrder,
          MapOrder ResultOrder, typename LhsOffset, typename RhsOffset,
          typename OutputPipelineType>
struct GemmWithPackedRhsTask : Task {
  typedef PackedSideBlock<typename KernelFormat::Lhs> PackedLhs;
  typedef PackedSideBlock<typename KernelFormat::Rhs> PackedRhs;
  GemmWithPackedRhsTask(const KernelBase& _kernel,
                        const MatrixMap<const InputScalar, LhsOrder>& _lhs,
                        const PackedRhs& _packed_rhs,
                        MatrixMap<OutputScalar, ResultOrder>* _result,
                        const LhsOffset& _lhs_offset,
                        const RhsOffset& _rhs_offset,
                        const OutputPipelineType& _output_pipeline)
      : kernel(_kernel),
        lhs(_lhs),
        packed_rhs(_packed_rhs),
        result(*_result),
        lhs_offset(_lhs_offset),
        rhs_offset(_rhs_offset),
        output_pipeline(_output_pipeline) {}

  void Run() const override {
    ScopedProfilingLabel label("GemmWithPackedRhsTask");

    const int rows = result.rows();
    const int cols = result.cols();
    const int depth = lhs.cols();

    BlockParams block_params;
    block_params.Init<KernelFormat>(rows, cols, depth, 1);

    PackedLhs packed_lhs(Side::Lhs, local_allocator, block_params);

    PackedResult packed_result(local_allocator, block_params);

    local_allocator->Commit();

    for (int c = 0; c < cols; c += block_params.l2_cols) {
      int cs = std::min(block_params.l2_cols, cols - c);

      for (int r = 0; r < rows; r += block_params.l2_rows) {
        int rs = std::min(block_params.l2_rows, rows - r);

        PackLhs<BitDepthParams>(&packed_lhs, lhs.block(r, 0, rs, depth));

        Compute(kernel, block_params, &packed_result, packed_lhs, packed_rhs);

        auto result_block = result.block(r, c, rs, cs);
        UnpackResult<BitDepthParams>(&result_block, packed_result, depth,
                                     packed_lhs.sums_of_each_slice(),
                                     packed_rhs.sums_of_each_slice(),
                                     lhs_offset, rhs_offset, output_pipeline);
      }
    }

    local_allocator->Decommit();
  }

  const KernelBase& kernel;
  const MatrixMap<const InputScalar, LhsOrder> lhs;
  const PackedRhs packed_rhs;
  MatrixMap<OutputScalar, ResultOrder> result;
  const LhsOffset& lhs_offset;
  const RhsOffset& rhs_offset;
  const OutputPipelineType& output_pipeline;
};

class MultiThreadGemmContext : public SingleThreadGemmContext {
 public:
  MultiThreadGemmContext() : max_num_threads_(0) {}

  void set_max_num_threads(int n) { max_num_threads_ = n; }

  int max_num_threads() const { return max_num_threads_; }

  WorkersPool* workers_pool() { return &workers_pool_; }

  Allocator* main_thread_task_allocator() {
    return &main_thread_task_allocator_;
  }

 protected:
  // The workers pool used by MultiThreadGemm. Making
  // this part of the context allows it to be persistent,
  // avoiding recreating threads on every Gemm.
  WorkersPool workers_pool_;

  // The maximum number of threads to use for a Gemm, counting the master
  // thread (which runs one of the workloads itself) as one of them.
  // The default value 0 means the default behavior of
  // detecting the number of hardware threads. Nonzero values mean
  // skipping and overriding hardware detection.
  int max_num_threads_;

  // For N-threaded operations, we will use only N-1 worker threads
  // while the last task will be run directly on the main thread.
  // It will then use this main_thread_task_allocator_; having a
  // dedicated allocator for that (separate from the base allocator_)
  // allows us to use the same code for all tasks regardless of which
  // thread they run on.
  Allocator main_thread_task_allocator_;
};
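
// Typical usage (an illustrative sketch; the public-facing gemmlowp entry
// points wrap this context, so exact call sites may differ):
//
//   MultiThreadGemmContext context;
//   context.set_max_num_threads(4);  // optional cap; 0 = auto-detect
//   // ... then pass &context to MultiThreadGemm(), reusing the same context
//   // across calls so that the worker pool persists between Gemms.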

// Determines how many threads should be used for a given Gemm
// operation.
template <int KernelRows>
inline int HowManyThreads(MultiThreadGemmContext* context, int rows, int cols,
                          int depth) {
  // First check if the user set an explicit maximum number of threads.
  int max_count = context->max_num_threads();
  if (!max_count) {
    // No user-set maximum number of threads, so we need to
    // do some hardware detection.
    // This is expensive to query, so we do it only once.
    // Too bad for dynamicness. Also, we don't use the C++11 standard getter
    // because Google's coding style currently bans #include <thread>.
    static const int hardware_threads_count =
        static_cast<int>(sysconf(_SC_NPROCESSORS_CONF));

    max_count = hardware_threads_count;
  }

  // Basic calculation: take into account max pool size, and
  // how many rows we have to feed our kernel.
  // The motivation for an absolute minimum number of rows per thread,
  // potentially higher than KernelRows, is that very thin per-thread
  // workloads currently defeat assumptions of the AddMod generator,
  // resulting in substantial bias in TestWithRealData on 24 threads.
  // Ideally, the AddMod generator should be aware of global (r,c) coordinates
  // so as to be independent of the number of threads.
  static const int AbsoluteMinRowsPerThread = 16;
  static const int MinRowsPerThread = KernelRows > AbsoluteMinRowsPerThread
                                          ? KernelRows
                                          : AbsoluteMinRowsPerThread;
  int thread_count = std::min(max_count, CeilQuotient(rows, MinRowsPerThread));

  // At this point, for small products, we already have thread_count==1, so
  // we can avoid doing more work; otherwise, we still want to check
  // that the cubic size (rows*cols*depth) is big enough to keep
  // workers busy.
  if (thread_count > 1) {
    // Empirically determined value.
    static const std::uint64_t min_cubic_size_per_thread = 64 * 1024;

    // Multiplying all three sizes in plain int could overflow,
    // so compute the product in 64 bits.
    const std::uint64_t cubic_size =
        std::uint64_t(rows) * std::uint64_t(cols) * std::uint64_t(depth);

    thread_count =
        std::min(thread_count, int(cubic_size / min_cubic_size_per_thread));

    if (thread_count < 1) {
      thread_count = 1;
    }
  }

  assert(thread_count > 0 && thread_count <= max_count);
  return thread_count;
}
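
// Worked example (illustrative numbers only): with KernelRows = 4,
// rows = cols = depth = 100 and 8 hardware threads, MinRowsPerThread = 16,
// so CeilQuotient(100, 16) = 7 and thread_count = min(8, 7) = 7.
// The cubic size is 1,000,000 and 1,000,000 / (64 * 1024) = 15, so the
// cubic-size check leaves thread_count at 7.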

// The main multi-threaded Gemm function.
// To understand it, first read the code of SingleThreadGemm().
// The parallelization scheme used here is to have this master function
// pack a block of RHS and then start worker threads to pack a block of LHS
// each, and accumulate the corresponding products.
template <typename KernelFormat, typename InputScalar, typename OutputScalar,
          typename BitDepthParams, MapOrder LhsOrder, MapOrder RhsOrder,
          MapOrder ResultOrder, typename LhsOffset, typename RhsOffset,
          typename OutputPipelineType>
void MultiThreadGemm(MultiThreadGemmContext* context, const KernelBase& kernel,
                     const MatrixMap<const InputScalar, LhsOrder>& lhs,
                     const MatrixMap<const InputScalar, RhsOrder>& rhs,
                     MatrixMap<OutputScalar, ResultOrder>* result,
                     const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
                     const OutputPipelineType& output_pipeline) {
  ScopedProfilingLabel label("gemmlowp::MultiThreadGemm");

  assert(lhs.cols() == rhs.rows());

  int rows = result->rows();
  int cols = result->cols();
  int depth = lhs.cols();

  assert(rows > 0);
  assert(cols > 0);
  assert(depth > 0);

  const int thread_count =
      HowManyThreads<KernelFormat::kRows>(context, rows, cols, depth);
  if (thread_count == 1) {
    return SingleThreadGemm<KernelFormat, InputScalar, OutputScalar,
                            BitDepthParams>(context, kernel, lhs, rhs, result,
                                            lhs_offset, rhs_offset,
                                            output_pipeline);
  }
  assert(thread_count > 1);

  // We choose to use a worker thread for all but one
  // of the thread workloads. The remaining thread workload will be
  // executed immediately on the current thread.
  // In this way, the total number of threads (1 master, N-1 workers)
  // equals the value returned by HowManyThreads. This simple
  // 1:1 mapping of threads to physical cores is very important
  // to getting good multithreaded performance, especially for
  // not-very-large GEMMs, and especially on Android.
  const int workers_count = thread_count - 1;

  Allocator* allocator = context->allocator();
  WorkersPool* workers_pool = context->workers_pool();

  workers_pool->CreateWorkers(workers_count);

  BlockParams block_params;
  block_params.Init<KernelFormat>(rows, cols, depth, workers_count);

  PackedSideBlock<typename KernelFormat::Rhs> packed_rhs(
      Side::Rhs, allocator, block_params);
  allocator->Commit();

  // We loop over large blocks of the RHS.
  for (int c = 0; c < cols; c += block_params.l2_cols) {
    int cs = std::min(block_params.l2_cols, cols - c);

    // Pack a large block of the RHS.
    PackRhs<BitDepthParams>(&packed_rhs, rhs.block(0, c, depth, cs));

    // Give work to each worker.
    int next_start_row = 0;
    workers_pool->counter_to_decrement_when_ready().Reset(workers_count);
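    // The rows are split into thread_count contiguous ranges whose
    // boundaries are rounded up to a multiple of KernelFormat::kRows.
    // For example (illustrative numbers only), with rows = 100,
    // thread_count = 3 and kRows = 4, the boundaries are 36, 68 and 100,
    // i.e. blocks of 36, 32 and 32 rows.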
    for (int thread = 0; thread < thread_count; thread++) {
      int start_row = next_start_row;
      next_start_row = std::min(rows, RoundUp<KernelFormat::kRows>(
                                          rows * (thread + 1) / thread_count));

      int block_rows = next_start_row - start_row;
      auto lhs_block = lhs.block(start_row, 0, block_rows, depth);
      auto result_block = result->block(start_row, c, block_rows, cs);
      typedef GemmWithPackedRhsTask<KernelFormat, InputScalar, OutputScalar,
                                    BitDepthParams, LhsOrder, RhsOrder,
                                    ResultOrder, LhsOffset, RhsOffset,
                                    OutputPipelineType>
          TaskType;
      auto task = new TaskType(kernel, lhs_block, packed_rhs, &result_block,
                               lhs_offset, rhs_offset, output_pipeline);
      if (thread < workers_count) {
        workers_pool->StartWorker(thread, task);
      } else {
        // Execute the remaining workload immediately on the current thread.
        task->local_allocator = context->main_thread_task_allocator();
        task->Run();
        delete task;
      }
    }
    // Wait for the workers.
    workers_pool->counter_to_decrement_when_ready().Wait();
  }

  allocator->Decommit();
}

}  // namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_