1/* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#include "Operations.h" 18#include "CpuOperationUtils.h" 19 20#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" 21 22namespace android { 23namespace nn { 24 25// If possible we will use this static buffer for the tensor. 26static constexpr size_t kStaticBufferSize = 1605632; 27static char static_scratch_buffer[kStaticBufferSize]; 28 29// executionMutex is used to protect concurrent access of the static_scratch_buffer 30// and other non-threadsafe resources like gemmlowp::GemmContext. 31// std::mutex is safe for pthreads on Android. 32static std::mutex executionMutex; 33 34#define ANDROID_NN_CONV_PARAMETERS(Type) \ 35 uint32_t height = getSizeOfDimension(inputShape, 1); \ 36 uint32_t width = getSizeOfDimension(inputShape, 2); \ 37 uint32_t filterHeight = getSizeOfDimension(filterShape, 1); \ 38 uint32_t filterWidth = getSizeOfDimension(filterShape, 2); \ 39 uint32_t outHeight = getSizeOfDimension(outputShape, 1); \ 40 uint32_t outWidth = getSizeOfDimension(outputShape, 2); \ 41 uint32_t inDepth = getSizeOfDimension(inputShape, 3); \ 42 \ 43 uint32_t paddingHeight = (uint32_t)padding_top; \ 44 uint32_t paddingWidth = (uint32_t)padding_left; \ 45 \ 46 tflite::Dims<4> im2colDim; \ 47 im2colDim.sizes[3] = (int)getSizeOfDimension(outputShape, 0); \ 48 im2colDim.sizes[2] = (int)getSizeOfDimension(outputShape, 1); \ 49 im2colDim.sizes[1] = (int)getSizeOfDimension(outputShape, 2); \ 50 im2colDim.sizes[0] = (int)inDepth * filterHeight * filterWidth; \ 51 \ 52 im2colDim.strides[0] = 1; \ 53 for (int i=1; i<4; i++) { \ 54 im2colDim.strides[i] = im2colDim.strides[i-1] * im2colDim.sizes[i-1]; \ 55 } \ 56 \ 57 Type* im2colData = nullptr; \ 58 uint64_t im2colByteSize = sizeof(Type); \ 59 std::unique_ptr<Type[]> im2colGuard; \ 60 for (int i=0; i<4; i++) { \ 61 im2colByteSize *= im2colDim.sizes[i]; \ 62 } \ 63 /* http://b/77982879, tflite::optimized_ops::Conv uses int for offsets */ \ 64 if (im2colByteSize >= 0x7fffffff) { \ 65 LOG(ERROR) << "Conv size is too large, not enough memory"; \ 66 return false; \ 67 } \ 68 if (im2colByteSize <= kStaticBufferSize) { \ 69 im2colData = reinterpret_cast<Type *>(static_scratch_buffer); \ 70 } else { \ 71 im2colData = new (std::nothrow) Type[im2colByteSize / sizeof(Type)]; \ 72 if (im2colData == nullptr) { \ 73 LOG(ERROR) << "Conv size is too large, not enough memory"; \ 74 return false; \ 75 } \ 76 im2colGuard.reset(im2colData); \ 77 } 78 79bool convFloat32(const float* inputData, const Shape& inputShape, 80 const float* filterData, const Shape& filterShape, 81 const float* biasData, const Shape& biasShape, 82 int32_t padding_left, int32_t padding_right, 83 int32_t padding_top, int32_t padding_bottom, 84 int32_t stride_width, int32_t stride_height, 85 int32_t activation, 86 float* outputData, const Shape& outputShape) { 87 88 ANDROID_NN_CONV_PARAMETERS(float) 89 90 float output_activation_min, output_activation_max; 91 CalculateActivationRangeFloat(activation, &output_activation_min, 92 &output_activation_max); 93 94 // Prevent concurrent executions that may access the scratch buffer. 95 std::unique_lock<std::mutex> lock(executionMutex); 96 tflite::optimized_ops::Conv( 97 inputData, convertShapeToDims(inputShape), 98 filterData, convertShapeToDims(filterShape), 99 biasData, convertShapeToDims(biasShape), 100 stride_width, stride_height, paddingWidth, paddingHeight, 101 output_activation_min, output_activation_max, 102 outputData, convertShapeToDims(outputShape), 103 im2colData, im2colDim); 104 return true; 105} 106 107bool convQuant8(const uint8_t* inputData, const Shape& inputShape, 108 const uint8_t* filterData, const Shape& filterShape, 109 const int32_t* biasData, const Shape& biasShape, 110 int32_t padding_left, int32_t padding_right, 111 int32_t padding_top, int32_t padding_bottom, 112 int32_t stride_width, int32_t stride_height, 113 int32_t activation, 114 uint8_t* outputData, const Shape& outputShape) { 115 116 ANDROID_NN_CONV_PARAMETERS(uint8_t) 117 118 int32_t inputOffset = -inputShape.offset; 119 int32_t filterOffset = -filterShape.offset; 120 int32_t outputOffset = outputShape.offset; 121 122 float real_multiplier = 0.0; 123 int32_t output_multiplier = 0; 124 int32_t output_shift = 0; 125 int32_t output_activation_min = 0; 126 int32_t output_activation_max = 0; 127 128 if (!GetQuantizedConvolutionMultipler(inputShape, filterShape, biasShape, 129 outputShape, &real_multiplier) || 130 !QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier, 131 &output_shift)){ 132 return false; 133 } 134 CalculateActivationRangeUint8(activation, outputShape, 135 &output_activation_min, 136 &output_activation_max); 137 138 static gemmlowp::GemmContext gemm_context; 139 140 // Prevent concurrent executions that may access the scratch buffer and 141 // gemm_context. 142 std::unique_lock<std::mutex> lock(executionMutex); 143 // Alow gemmlowp automatically decide how many threads to use. 144 gemm_context.set_max_num_threads(0); 145 tflite::optimized_ops::Conv( 146 inputData, convertShapeToDims(inputShape), inputOffset, 147 filterData, convertShapeToDims(filterShape), filterOffset, 148 biasData, convertShapeToDims(biasShape), 149 stride_width, stride_height, paddingWidth, paddingHeight, 150 outputOffset, output_multiplier, output_shift, 151 output_activation_min, output_activation_max, 152 outputData, convertShapeToDims(outputShape), 153 im2colData, im2colDim, &gemm_context); 154 return true; 155} 156 157#undef ANDROID_NN_CONV_PARAMETERS 158} // namespace nn 159} // namespace android 160