1// Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef GEMMLOWP_META_MULTI_THREAD_TRANSFORM_H_
16#define GEMMLOWP_META_MULTI_THREAD_TRANSFORM_H_
17
18#include "multi_thread_common.h"
19#include "single_thread_transform.h"
20
21namespace gemmlowp {
22namespace meta {
23namespace internal {
24
// Tuning constants for splitting a 1D transform across tasks. They are
// expressed in the same units as Util::EstimateComputeCost(): the overhead is
// subtracted from the estimated cost, and the remainder is divided by the
// minimum task size to bound how many tasks are worth creating.
constexpr int kTransformTaskOverhead = 128000;
constexpr int kMinTransformTaskSize = 32000;
27
28template <typename MultiThreadingContext, typename Params>
29inline bool PrepareTransform1DTasks(MultiThreadingContext* context,
30                                    const Params& params, int kernel_size,
31                                    std::vector<Params>* task_params) {
32  typedef Transform1DUtil<typename Params::InType, typename Params::OutType,
33                          typename Params::Kernel>
34      Util;
35
36  const int max_threads = ResolveMaxThreads(context->max_num_threads());
37  const int task_size = Util::EstimateComputeCost(params.kernel);
38  const int max_tasks_by_size =
39      (task_size - kTransformTaskOverhead) / kMinTransformTaskSize;
40
41  const int real_tasks = std::max(1, std::min(max_threads, max_tasks_by_size));
42
43  if (real_tasks == 1) {
44    return false;
45  }
46
47  const int chunk = params.kernel.count / real_tasks;
48  for (int i = 0; i < real_tasks - 1; ++i) {
49    task_params->push_back(params);
50    Params& task = task_params->back();
51    task.kernel.count = chunk;
52    task.input = Util::OffsetInput(params.kernel, params.input, i * chunk);
53    task.output = Util::OffsetOutput(params.kernel, params.output, i * chunk);
54  }
55  task_params->push_back(params);
56  Params& task = task_params->back();
57  const int sum_chunk = (real_tasks - 1) * chunk;
58  task.kernel.count = params.kernel.count - sum_chunk;
59  task.input = Util::OffsetInput(params.kernel, params.input, sum_chunk);
60  task.output = Util::OffsetOutput(params.kernel, params.output, sum_chunk);
61  return true;
62}
63
64template <typename Params, int kernel_size>
65struct Transform1DTaskRunner : gemmlowp::Task {
66  Transform1DTaskRunner(const Params& params) : params(params) {}
67
68  void Run() override { Transform1D<Params, kernel_size>(params); }
69
70  Params params;
71};
72
73}  // namespace internal
74
75template <typename MultiThreadingContext, typename Params, int kernel_size>
76inline void MultiThreadTransform1D(MultiThreadingContext* context,
77                                   const Params& params) {
78  typedef internal::Transform1DTaskRunner<Params, kernel_size> TaskRunnerType;
79
80  std::vector<Params> task_params;
81  if (!internal::PrepareTransform1DTasks<MultiThreadingContext, Params>(
82          context, params, kernel_size, &task_params)) {
83    Transform1D<Params, kernel_size>(params);
84    return;
85  }
86
87  auto workers_pool = context->workers_pool();
88  std::vector<Task*> tasks;
89  for (auto& task_param : task_params) {
90    tasks.push_back(new TaskRunnerType(task_param));
91  }
92  workers_pool->Execute(tasks);
93}
94
95}  // namespace meta
96}  // namespace gemmlowp
97
98#endif  // GEMMLOWP_META_MULTI_THREAD_TRANSFORM_H_
99