1// Copyright 2016 The Gemmlowp Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15#ifndef GEMMLOWP_META_MULTI_THREAD_GEMM_H_ 16#define GEMMLOWP_META_MULTI_THREAD_GEMM_H_ 17 18#include "multi_thread_common.h" 19#include "single_thread_gemm.h" 20 21namespace gemmlowp { 22namespace meta { 23namespace internal { 24 25const std::int32_t kMinGemmTaskSize = 16000; 26const std::int32_t kMinGemmTaskDimension = 4; 27 28template <typename Executor, typename Params> 29std::uint8_t* PrepareGemmTask(const Params& params, int kernel_m, int kernel_n, 30 int kernel_k, std::uint8_t* scratch, int m_start, 31 int m, int n_start, int n, 32 std::vector<Params>* tasks) { 33 tasks->push_back(params); 34 Params& task = tasks->back(); 35 task.scratch = scratch; 36 37 task.m = m; 38 task.lhs = 39 StreamUtil<typename Params::InType, typename Params::LeftStream>::Offset( 40 params.left_stream, params.lhs, m_start, 0); 41 42 task.n = n; 43 task.rhs = 44 StreamUtil<typename Params::InType, typename Params::RightStream>::Offset( 45 params.right_stream, params.rhs, n_start, 0); 46 47 task.result = 48 StreamUtil<typename Params::OutType, typename Params::OutputStream>:: 49 Offset(params.fused_kernel.output_stream, params.result, m_start, 50 n_start); 51 52 return scratch + Executor::template EstimateScratchSize<Params>( 53 task, kernel_m, kernel_n, kernel_k); 54} 55 56template <typename MultiThreadingContext, typename Executor, typename Params> 57bool PrepareGemmTasks(MultiThreadingContext* context, const Params& params, 58 int kernel_m, int kernel_n, int kernel_k, 59 std::vector<Params>* task_params) { 60 const int max_threads = ResolveMaxThreads(context->max_num_threads()); 61 const int max_tasks_by_size = 62 (params.m * params.n * params.k) / kMinGemmTaskSize; 63 const int max_tasks_m = params.m / kMinGemmTaskDimension; 64 const int max_tasks_n = params.n / kMinGemmTaskDimension; 65 const int max_tasks_dimension = std::max(max_tasks_m, max_tasks_n); 66 67 const int real_tasks = std::max( 68 1, 69 std::min(max_threads, std::min(max_tasks_by_size, max_tasks_dimension))); 70 71 if (real_tasks == 1) { 72 return false; 73 } 74 75 std::uint8_t* scratch = params.scratch; 76 77 if (max_tasks_m > max_tasks_n) { 78 const int m_chunk = params.m / real_tasks; 79 for (int i = 0; i < real_tasks - 1; ++i) { 80 scratch = PrepareGemmTask<Executor, Params>( 81 params, kernel_m, kernel_n, kernel_k, scratch, i * m_chunk, m_chunk, 82 0, params.n, task_params); 83 } 84 const int sum_m = (real_tasks - 1) * m_chunk; 85 PrepareGemmTask<Executor, Params>(params, kernel_m, kernel_n, kernel_k, 86 scratch, sum_m, params.m - sum_m, 0, 87 params.n, task_params); 88 } else { 89 const int n_chunk = params.n / real_tasks; 90 for (int i = 0; i < real_tasks - 1; ++i) { 91 scratch = PrepareGemmTask<Executor, Params>( 92 params, kernel_m, kernel_n, kernel_k, scratch, 0, params.m, 93 i * n_chunk, n_chunk, task_params); 94 } 95 int sum_n = (real_tasks - 1) * n_chunk; 96 PrepareGemmTask<Executor, Params>(params, kernel_m, kernel_n, kernel_k, 97 scratch, 0, params.m, sum_n, 98 params.n - sum_n, task_params); 99 } 100 101 return true; 102} 103 104template <typename Executor, typename Params, int kernel_m, int kernel_n, 105 int kernel_k> 106struct GemmTaskRunner : gemmlowp::Task { 107 GemmTaskRunner(const Params& params) : params(params) {} 108 109 void Run() override { 110 Gemm<Executor, Params, kernel_m, kernel_n, kernel_k>(params); 111 } 112 113 Params params; 114}; 115 116} // namespace internal 117 118template <typename MultiThreadingContext, typename Executor, typename Params, 119 int kernel_m, int kernel_n, int kernel_k> 120inline void MultiThreadGemm(MultiThreadingContext* context, 121 const Params& params) { 122 typedef internal::GemmTaskRunner<Executor, Params, kernel_m, kernel_n, 123 kernel_k> 124 TaskRunnerType; 125 126 std::vector<Params> task_params; 127 if (!internal::PrepareGemmTasks<MultiThreadingContext, Executor, Params>( 128 context, params, kernel_m, kernel_n, kernel_k, &task_params)) { 129 Gemm<Executor, Params, kernel_m, kernel_n, kernel_k>(params); 130 return; 131 } 132 133 auto workers_pool = context->workers_pool(); 134 std::vector<Task*> tasks; 135 std::for_each(task_params.begin(), task_params.end(), [tasks](Params* param) { 136 tasks.push_back(new TaskRunnerType(param)); 137 }); 138 workers_pool->Execute(tasks); 139} 140 141} // namespace meta 142} // namespace gemmlowp 143 144#endif // GEMMLOWP_META_MULTI_THREAD_GEMM_H_ 145