1a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang//
3a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang// Licensed under the Apache License, Version 2.0 (the "License");
4a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang// you may not use this file except in compliance with the License.
5a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang// You may obtain a copy of the License at
6a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang//
7a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang//     http://www.apache.org/licenses/LICENSE-2.0
8a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang//
9a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang// Unless required by applicable law or agreed to in writing, software
10a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang// distributed under the License is distributed on an "AS IS" BASIS,
11a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang// See the License for the specific language governing permissions and
13a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang// limitations under the License.
14a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
15a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang// multi_thread_common.h: Multithreading code shared by different meta gemm
16a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang// versions.
17a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
18a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang#ifndef GEMMLOWP_META_MULTI_THREAD_COMMON_H_
19a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang#define GEMMLOWP_META_MULTI_THREAD_COMMON_H_
20a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
21a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang#include "../internal/multi_thread_gemm.h"
22a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
23a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wangnamespace gemmlowp {
24a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wangnamespace meta {
25a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wangnamespace internal {
26a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
// Minimum amount of work (counted as m * n * k) a single task should get.
// PrepareTasks caps the number of tasks at m * n * k / kMinTaskSize, while
// always producing at least one task.
constexpr std::int32_t kMinTaskSize = 16000;

// Minimum extent a task should have along the dimension being split.
// PrepareTasks caps the number of tasks at
// max(m / kMinTaskDimension, n / kMinTaskDimension).
constexpr std::int32_t kMinTaskDimension = 4;
29a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
// A rectangular tile of the m x n result matrix: `m` rows starting at row
// `m_offset`, and `n` columns starting at column `n_offset`.
struct TaskRect {
  std::int32_t m_offset;  // Index of the tile's first row.
  std::int32_t m;         // Number of rows in the tile.
  std::int32_t n_offset;  // Index of the tile's first column.
  std::int32_t n;         // Number of columns in the tile.

  TaskRect(std::int32_t m_offset, std::int32_t m, std::int32_t n_offset,
           std::int32_t n)
      : m_offset{m_offset}, m{m}, n_offset{n_offset}, n{n} {}
};
40a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
41a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wangtemplate <typename IN_TYPE, typename OUT_TYPE, typename F>
42a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wangstruct MetaTask : gemmlowp::Task {
43a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  std::uint8_t* scratch;
44a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  const IN_TYPE* lhs;
45a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  const IN_TYPE* rhs;
46a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  TaskRect task_rect;
47a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  std::int32_t k;
48a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  OUT_TYPE* result;
49a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  std::int32_t result_stride;
50a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  const F& operation;
51a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
52a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  MetaTask(std::uint8_t* scratch, const IN_TYPE* lhs, const IN_TYPE* rhs,
53a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang           const TaskRect& task_rect, std::int32_t k, OUT_TYPE* result,
54a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang           std::int32_t result_stride, const F& operation)
55a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang      : scratch(scratch),
56a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang        lhs(lhs),
57a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang        rhs(rhs),
58a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang        task_rect(task_rect),
59a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang        k(k),
60a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang        result(result),
61a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang        result_stride(result_stride),
62a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang        operation(operation) {}
63a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
64a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  void Run() override {
65a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    const IN_TYPE* task_lhs = lhs + task_rect.m_offset * k;
66a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    const IN_TYPE* task_rhs = rhs + task_rect.n_offset * k;
67a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    OUT_TYPE* task_result =
68a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang        result + task_rect.m_offset * result_stride + task_rect.n_offset;
69a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    operation.ExecuteMatrixMatrix(scratch, task_lhs, task_rhs, task_rect.m,
70a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                                  task_rect.n, k, task_result, result_stride);
71a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  }
72a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang};
73a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
// Returns the number of threads to use.
//
// A non-zero `max_threads` is returned unchanged. Zero means "use the
// hardware": the count of configured processors reported by
// sysconf(_SC_NPROCESSORS_CONF), queried once and cached. If sysconf fails
// (it returns -1), this falls back to a single thread instead of returning
// a negative count.
//
// Declared `inline` because this non-template function is defined in a
// header; without it, including the header from more than one translation
// unit produces multiple definitions (an ODR violation / link error).
inline std::int32_t ResolveMaxThreads(std::int32_t max_threads) {
  if (max_threads != 0) {
    return max_threads;
  }
  static const std::int32_t hardware_threads_count = [] {
    const long count = sysconf(_SC_NPROCESSORS_CONF);
    return count > 0 ? static_cast<std::int32_t>(count) : std::int32_t{1};
  }();
  return hardware_threads_count;
}
82a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
83a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wangvoid PrepareTasks(std::int32_t max_tasks, std::int32_t m, std::int32_t n,
84a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                  std::int32_t k, std::vector<internal::TaskRect>* tasks) {
85a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  const std::int32_t max_tasks_by_size = (m * n * k) / kMinTaskSize;
86a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  const std::int32_t max_tasks_m = m / kMinTaskDimension;
87a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  const std::int32_t max_tasks_n = n / kMinTaskDimension;
88a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  const std::int32_t max_tasks_dimension = std::max(max_tasks_m, max_tasks_n);
89a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
90a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  std::int32_t real_tasks = std::max(
91a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang      1, std::min(max_tasks, std::min(max_tasks_by_size, max_tasks_dimension)));
92a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
93a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  if (real_tasks == 1) {
94a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    tasks->push_back(TaskRect(0, m, 0, n));
95a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    return;
96a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  }
97a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
98a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  if (max_tasks_m > max_tasks_n) {
99a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    const std::int32_t m_chunk = m / real_tasks;
100a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    for (int i = 0; i < real_tasks - 1; ++i) {
101a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang      tasks->push_back(TaskRect(i * m_chunk, m_chunk, 0, n));
102a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    }
103a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    const std::int32_t last_m_offset = (real_tasks - 1) * m_chunk;
104a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    tasks->push_back(TaskRect(last_m_offset, m - last_m_offset, 0, n));
105a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  } else {
106a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    const std::int32_t n_chunk = n / real_tasks;
107a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    for (int i = 0; i < real_tasks - 1; ++i) {
108a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang      tasks->push_back(TaskRect(0, m, i * n_chunk, n_chunk));
109a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    }
110a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    const std::int32_t last_n_offset = (real_tasks - 1) * n_chunk;
111a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    tasks->push_back(TaskRect(0, m, last_n_offset, n - last_n_offset));
112a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  }
113a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang}
114a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
115a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wangtemplate <typename IN_TYPE, typename OUT_TYPE, typename F>
116a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wangvoid MultiThreadedMatrixMatrix(gemmlowp::WorkersPool* pool,
117a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                               std::int32_t max_threads, std::uint8_t* scratch,
118a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                               const IN_TYPE* lhs, const IN_TYPE* rhs,
119a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                               std::int32_t m, std::int32_t n, std::int32_t k,
120a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                               OUT_TYPE* result, std::int32_t result_stride,
121a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                               const F& operation) {
122a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  max_threads = internal::ResolveMaxThreads(max_threads);
123a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
124a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  std::vector<internal::TaskRect> task_rects;
125a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  internal::PrepareTasks(max_threads, m, n, k, &task_rects);
126a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
127a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  if (task_rects.size() == 1) {
128a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    operation.ExecuteMatrixMatrix(scratch, lhs, rhs, m, n, k, result,
129a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                                  result_stride);
130a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    return;
131a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  }
132a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
133a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  std::uint8_t* task_scratch = scratch;
134a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  std::int32_t scratch_per_thread = operation.ScratchPerThread(m, n, k);
135a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  std::vector<Task*> tasks;
136a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  std::for_each(
137a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang      task_rects.begin(), task_rects.end(),
138a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang      [&tasks, &task_scratch, lhs, rhs, k, result, result_stride, operation,
139a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang       scratch_per_thread](internal::TaskRect& rect) {
140a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang        tasks.push_back(new internal::MetaTask<IN_TYPE, OUT_TYPE, F>(
141a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang            task_scratch, lhs, rhs, rect, k, result, result_stride, operation));
142a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang        task_scratch += scratch_per_thread;
143a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang      });
144a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  pool->Execute(tasks);
145a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang}
146a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
147a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang}  // namespace internal
148a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang}  // namespace meta
149a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang}  // namespace gemmlowp
150a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
151a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang#endif  // GEMMLOWP_META_MULTI_THREAD_COMMON_H_
152