// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// multi_thread_gemv.h: Entry point to the multithreaded version of the
// generated (meta) gemv library.
#ifndef GEMMLOWP_META_MULTI_THREAD_GEMV_H_
#define GEMMLOWP_META_MULTI_THREAD_GEMV_H_

#ifdef GEMMLOWP_NEON_32

#include <cstdint>

#include "multi_thread_common.h"
#include "operations_common.h"
#include "single_thread_gemm.h"

namespace gemmlowp {
namespace meta {
namespace internal {

// Operation object that forwards to the gemv_q8 kernel. It provides the
// ExecuteMatrixMatrix / ScratchPerThread pair that the entry points below
// hand to MultiThreadedMatrixMatrix.
class GemvQuantized8BitOperation : public Quantized8BitOperation {
 public:
  // Passes all quantization parameters straight through to the base class.
  GemvQuantized8BitOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
                             std::int32_t sum_offset, std::int32_t multiplier,
                             std::int32_t shift)
      : Quantized8BitOperation(lhs_offset, rhs_offset, sum_offset, multiplier,
                               shift) {}

  // Calls gemv_q8 with the given buffers and sizes. `m` and `result_stride`
  // are accepted to match the matrix-matrix call shape but are not forwarded
  // to the kernel. The offsets, multiplier and shift used here are presumably
  // members inherited from Quantized8BitOperation (declared elsewhere).
  void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k, std::uint8_t* result,
                           std::int32_t result_stride) const {
    gemv_q8(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, sum_offset,
            multiplier, shift, result);
  }

  // Scratch size requested per thread; a fixed 128 * 1024 regardless of the
  // problem dimensions (presumably bytes, as scratch buffers are
  // std::uint8_t* -- confirm against the kernels).
  static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
                                       std::int32_t k) {
    return 128 * 1024;
  }
};

// Operation object that forwards to the gemv_f kernel (float result).
class GemvFloatOperation : public FloatOperation {
 public:
  // Passes the offsets straight through to the base class.
  GemvFloatOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
                     float result_offset)
      : FloatOperation(lhs_offset, rhs_offset, result_offset) {}

  // Calls gemv_f with the given buffers and sizes. `m` and `result_stride`
  // are accepted to match the matrix-matrix call shape but are not forwarded
  // to the kernel.
  void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k, float* result,
                           std::int32_t result_stride) const {
    gemv_f(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, result_offset,
           result);
  }

  // Scratch size requested per thread; a fixed 128 * 1024 regardless of the
  // problem dimensions.
  static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
                                       std::int32_t k) {
    return 128 * 1024;
  }
};

// Operation object that forwards to the gemv_i32 kernel (int32 result).
class GemvInt32Operation : public Int32Operation {
 public:
  // Passes the offsets straight through to the base class.
  GemvInt32Operation(std::int32_t lhs_offset, std::int32_t rhs_offset)
      : Int32Operation(lhs_offset, rhs_offset) {}

  // Calls gemv_i32 with the given buffers and sizes. `m` and `result_stride`
  // are accepted to match the matrix-matrix call shape but are not forwarded
  // to the kernel.
  void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, std::int32_t m,
                           std::int32_t n, std::int32_t k, std::int32_t* result,
                           std::int32_t result_stride) const {
    gemv_i32(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, result);
  }

  // Scratch size requested per thread; a fixed 128 * 1024 regardless of the
  // problem dimensions.
  static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
                                       std::int32_t k) {
    return 128 * 1024;
  }
};

}  // namespace internal

// NOTE: every function below is defined in this header, so each one is marked
// `inline`. Without it, including the header from more than one translation
// unit produces multiple definitions of the same symbol.

// Returns the total scratch size for multi_thread_gemv_q8: the resolved
// thread count multiplied by the per-thread scratch size.
inline std::int32_t gemv_q8_scratch(std::int32_t m, std::int32_t n,
                                    std::int32_t k, std::int32_t max_threads) {
  return internal::ResolveMaxThreads(max_threads) *
         internal::GemvQuantized8BitOperation::ScratchPerThread(m, n, k);
}

// Quantized 8-bit gemv entry point. When the resolved thread count is 1 the
// operation is executed by a direct call and `pool` is not used; otherwise
// the call is handed to internal::MultiThreadedMatrixMatrix. In both paths
// the problem is passed on with m = 1 and result_stride = n.
inline void multi_thread_gemv_q8(gemmlowp::WorkersPool* pool,
                                 std::int32_t max_threads,
                                 std::uint8_t* scratch, const std::uint8_t* lhs,
                                 const std::uint8_t* rhs, std::int32_t n,
                                 std::int32_t k, std::int32_t lhs_offset,
                                 std::int32_t rhs_offset,
                                 std::int32_t sum_offset,
                                 std::int32_t multiplier, std::int32_t shift,
                                 std::uint8_t* result) {
  max_threads = internal::ResolveMaxThreads(max_threads);
  internal::GemvQuantized8BitOperation operation(lhs_offset, rhs_offset,
                                                 sum_offset, multiplier, shift);
  if (max_threads == 1) {
    operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
  } else {
    internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
                                        n, k, result, n, operation);
  }
}

// Returns the total scratch size for multi_thread_gemv_f: the resolved
// thread count multiplied by the per-thread scratch size.
inline std::int32_t gemv_f_scratch(std::int32_t m, std::int32_t n,
                                   std::int32_t k, std::int32_t max_threads) {
  return internal::ResolveMaxThreads(max_threads) *
         internal::GemvFloatOperation::ScratchPerThread(m, n, k);
}

// Float-result gemv entry point. When the resolved thread count is 1 the
// operation is executed by a direct call and `pool` is not used; otherwise
// the call is handed to internal::MultiThreadedMatrixMatrix. In both paths
// the problem is passed on with m = 1 and result_stride = n.
inline void multi_thread_gemv_f(gemmlowp::WorkersPool* pool,
                                std::int32_t max_threads,
                                std::uint8_t* scratch, const std::uint8_t* lhs,
                                const std::uint8_t* rhs, std::int32_t n,
                                std::int32_t k, std::int32_t lhs_offset,
                                std::int32_t rhs_offset, float result_offset,
                                float* result) {
  max_threads = internal::ResolveMaxThreads(max_threads);
  internal::GemvFloatOperation operation(lhs_offset, rhs_offset, result_offset);
  if (max_threads == 1) {
    operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
  } else {
    internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
                                        n, k, result, n, operation);
  }
}

// Returns the total scratch size for multi_thread_gemv_i32: the resolved
// thread count multiplied by the per-thread scratch size.
inline std::int32_t gemv_i32_scratch(std::int32_t m, std::int32_t n,
                                     std::int32_t k, std::int32_t max_threads) {
  return internal::ResolveMaxThreads(max_threads) *
         internal::GemvInt32Operation::ScratchPerThread(m, n, k);
}

// Int32-result gemv entry point. When the resolved thread count is 1 the
// operation is executed by a direct call and `pool` is not used; otherwise
// the call is handed to internal::MultiThreadedMatrixMatrix. In both paths
// the problem is passed on with m = 1 and result_stride = n.
inline void multi_thread_gemv_i32(gemmlowp::WorkersPool* pool,
                                  std::int32_t max_threads,
                                  std::uint8_t* scratch,
                                  const std::uint8_t* lhs,
                                  const std::uint8_t* rhs, std::int32_t n,
                                  std::int32_t k, std::int32_t lhs_offset,
                                  std::int32_t rhs_offset,
                                  std::int32_t* result) {
  max_threads = internal::ResolveMaxThreads(max_threads);
  internal::GemvInt32Operation operation(lhs_offset, rhs_offset);
  if (max_threads == 1) {
    operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
  } else {
    internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
                                        n, k, result, n, operation);
  }
}

}  // namespace meta
}  // namespace gemmlowp

#else
#warning "Meta gemm fast-path requires GEMMLOWP_NEON_32!"
#endif

#endif  // GEMMLOWP_META_MULTI_THREAD_GEMV_H_