1eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang/* 2eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang * Copyright (C) 2017 The Android Open Source Project 3eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang * 4eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang * Licensed under the Apache License, Version 2.0 (the "License"); 5eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang * you may not use this file except in compliance with the License. 6eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang * You may obtain a copy of the License at 7eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang * 8eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang * http://www.apache.org/licenses/LICENSE-2.0 9eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang * 10eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang * Unless required by applicable law or agreed to in writing, software 11eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang * distributed under the License is distributed on an "AS IS" BASIS, 12eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang * See the License for the specific language governing permissions and 14eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang * limitations under the License. 15eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang */ 16eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang 17eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang#include "Operations.h" 18eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang#include "OperationsUtils.h" 19eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang 20eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang#include "internal/optimized/optimized_ops.h" 21eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang 22eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wangnamespace android { 23eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wangnamespace nn { 24eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang 25eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wangbool reluFloat32(const float* inputData, const Shape& inputShape, 26eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang float* outputData, const Shape& outputShape) { 27eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang int numElements = getNumberOfElements(inputShape); 28eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang for (int i=0; i<numElements; i++, inputData++, outputData++) { 29eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang *outputData = std::max(0.f, *inputData); 30eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang } 31eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang return true; 32eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang} 33eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang 349f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wangbool relu1Float32(const float* inputData, const Shape& inputShape, 359f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang float* outputData, const Shape& outputShape) { 369f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang int numElements = getNumberOfElements(inputShape); 379f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang for (int i=0; i<numElements; i++, inputData++, outputData++) { 389f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang *outputData = std::min(std::max(-1.f, *inputData), 1.f); 399f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang } 409f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang return true; 419f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang} 429f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang 43eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wangbool relu6Float32(const float* inputData, const Shape& inputShape, 44eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang float* outputData, const Shape& outputShape) { 45eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang int numElements = getNumberOfElements(inputShape); 46eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang for (int i=0; i<numElements; i++, inputData++, outputData++) { 47eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang *outputData = std::min(std::max(0.f, *inputData), 6.f); 48eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang } 49eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang return true; 50eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang} 51eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang 52eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wangbool tanhFloat32(const float* inputData, const Shape& inputShape, 53eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang float* outputData, const Shape& outputShape) { 54eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang int numElements = getNumberOfElements(inputShape); 55eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang for (int i=0; i<numElements; i++, inputData++, outputData++) { 56eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang *outputData = std::tanh(*inputData); 57eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang } 58eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang return true; 59eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang} 60eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang 61eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wangbool logisticFloat32(const float* inputData, const Shape& inputShape, 62eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang float* outputData, const Shape& outputShape) { 63eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang int numElements = getNumberOfElements(inputShape); 64eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang for (int i=0; i<numElements; i++, inputData++, outputData++) { 65eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang *outputData = 1.f / (1.f + std::exp(-*inputData)); 66eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang } 67eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang return true; 68eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang} 69eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang 709f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wangbool softmaxFloat32(const float* inputData, const Shape& inputShape, 719f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang const float beta, 729f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang float* outputData, const Shape& outputShape) { 739f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang Dims<4> dim; 749f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang if (getNumberOfDimensions(inputShape) == 2) { 759f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang uint32_t batch_size = getSizeOfDimension(inputShape, 0); 769f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang uint32_t input_size = getNumberOfElements(inputShape) / batch_size; 779f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang 789f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang Shape shapeIn4D; 799f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang shapeIn4D.dimensions = {batch_size, 1, 1, input_size}; 809f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang dim = convertShapeToDims(shapeIn4D); 819f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang } else if (getNumberOfDimensions(inputShape) == 4) { 829f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang dim = convertShapeToDims(inputShape); 839f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang } else { 849f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang LOG(ERROR) << "only 2D and 4D tensors supported"; 859f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang return false; 869f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang } 879f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang 889f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang optimized_ops::Softmax(inputData, dim, beta, 899f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang outputData, dim); 909f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang return true; 919f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang} 929f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang 939f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang#define ANDROID_NN_RELUX_QUANT8(activation) \ 949f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang int numElements = getNumberOfElements(inputShape); \ 959f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang int32_t output_activation_min = 0; \ 969f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang int32_t output_activation_max = 0; \ 979f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang \ 989f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang CalculateActivationRangeUint8(activation, inputShape, \ 999f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang &output_activation_min, \ 1009f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang &output_activation_max); \ 1019f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang \ 1029f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang for (int i=0; i<numElements; i++, inputData++, outputData++) { \ 1039f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang *outputData = std::min((uint8_t)output_activation_max, \ 1049f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang std::max((uint8_t)output_activation_min, *inputData)); \ 1059f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang } 1069f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang 1079f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang 1089f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wangbool reluQuant8(const uint8_t* inputData, const Shape& inputShape, 1099f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang uint8_t* outputData, const Shape& outputShape) { 1109f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang ANDROID_NN_RELUX_QUANT8(kActivationRelu) 1119f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang return true; 1129f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang} 1139f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang 1149f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wangbool relu1Quant8(const uint8_t* inputData, const Shape& inputShape, 1159f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang uint8_t* outputData, const Shape& outputShape) { 1169f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang ANDROID_NN_RELUX_QUANT8(kActivationRelu1) 1179f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang return true; 1189f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang} 1199f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang 1209f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wangbool relu6Quant8(const uint8_t* inputData, const Shape& inputShape, 1219f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang uint8_t* outputData, const Shape& outputShape) { 1229f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang ANDROID_NN_RELUX_QUANT8(kActivationRelu6) 1239f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang return true; 1249f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang} 1259f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang 1269f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang#undef ANDROID_NN_RELUX_QUANT8 1279f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang 12827e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wangbool logisticQuant8(const uint8_t* inputData, const Shape& inputShape, 12927e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang uint8_t* outputData, const Shape& outputShape) { 13045bf79e5b9fee354fde7c1f64417d9ca4a1da7daMiao Wang if (outputShape.offset != 0 || outputShape.scale != 1.f / 256) { 131874d039215516aebdaba2e242609199897fe80c0Miao Wang LOG(ERROR) << "incorrect scale / offset for output"; 132874d039215516aebdaba2e242609199897fe80c0Miao Wang return false; 133874d039215516aebdaba2e242609199897fe80c0Miao Wang } 134874d039215516aebdaba2e242609199897fe80c0Miao Wang 13527e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang int numElements = getNumberOfElements(inputShape); 13627e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang static constexpr int kInputIntegerBits = 4; 13727e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang 13827e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang const double input_real_multiplier = 13927e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang inputShape.scale * 14027e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang static_cast<double>(1 << (31 - kInputIntegerBits)); 14127e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang 14227e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang int32_t input_multiplier = 0; 14327e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang int32_t input_left_shift = 0; 144be2b22578baf949d7be42ba002cee94304daf53cMiao Wang if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, 145be2b22578baf949d7be42ba002cee94304daf53cMiao Wang &input_multiplier, 146be2b22578baf949d7be42ba002cee94304daf53cMiao Wang &input_left_shift)) { 147be2b22578baf949d7be42ba002cee94304daf53cMiao Wang return false; 148be2b22578baf949d7be42ba002cee94304daf53cMiao Wang } 14927e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang int32_t input_range_radius = 15027e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang CalculateInputRadius(kInputIntegerBits, input_left_shift); 15127e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang 15227e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang optimized_ops::Logistic( 15327e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang inputData, convertShapeToDims(inputShape), 15427e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang inputShape.offset, input_range_radius, 15527e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang input_multiplier, input_left_shift, 15627e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang outputData, convertShapeToDims(outputShape)); 15727e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang 15827e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang return true; 15927e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang} 16027e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang 1619f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wangbool softmaxQuant8(const uint8_t* inputData, const Shape& inputShape, 1629f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang const float beta, 1639f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang uint8_t* outputData, const Shape& outputShape) { 1649f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang Dims<4> dim; 1659f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang if (getNumberOfDimensions(inputShape) == 2) { 1669f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang uint32_t batch_size = getSizeOfDimension(inputShape, 0); 1679f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang uint32_t input_size = getNumberOfElements(inputShape) / batch_size; 1689f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang 1699f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang Shape shapeIn4D; 1709f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang shapeIn4D.dimensions = {batch_size, 1, 1, input_size}; 1719f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang dim = convertShapeToDims(shapeIn4D); 1729f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang } else if (getNumberOfDimensions(inputShape) == 4) { 1739f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang dim = convertShapeToDims(inputShape); 1749f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang } else { 1759f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang LOG(ERROR) << "only 2D and 4D tensors supported"; 1769f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang return false; 1779f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang } 1789f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang 17945bf79e5b9fee354fde7c1f64417d9ca4a1da7daMiao Wang if (outputShape.offset != 0 || outputShape.scale != 1.f / 256) { 180874d039215516aebdaba2e242609199897fe80c0Miao Wang LOG(ERROR) << "incorrect scale / offset for output"; 1819f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang return false; 182bbfd239e43526ff969699d3fc6110395edd2108bMiao Wang } 183874d039215516aebdaba2e242609199897fe80c0Miao Wang 1849f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang static const int32_t kScaledDiffIntegerBits = 5; 1859f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang const double input_beta_real_multiplier = std::min( 1869f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang 1.0 * beta * inputShape.scale * (1 << (31 - kScaledDiffIntegerBits)), 1879f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang (1ll << 31) - 1.0); 1889f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang 1899f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang int32_t input_multiplier = 0; 1909f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang int32_t input_left_shift = 0; 191be2b22578baf949d7be42ba002cee94304daf53cMiao Wang if (!QuantizeMultiplierGreaterThanOne(input_beta_real_multiplier, 192be2b22578baf949d7be42ba002cee94304daf53cMiao Wang &input_multiplier, 193be2b22578baf949d7be42ba002cee94304daf53cMiao Wang &input_left_shift)) { 194be2b22578baf949d7be42ba002cee94304daf53cMiao Wang return false; 195be2b22578baf949d7be42ba002cee94304daf53cMiao Wang } 1969f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang float diff_min = -1.0f * CalculateInputRadius(kScaledDiffIntegerBits, 1979f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang input_left_shift); 1989f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang 1999f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang optimized_ops::Softmax(inputData, dim, input_multiplier, 2009f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang input_left_shift, diff_min, 2019f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang outputData, dim); 202bbfd239e43526ff969699d3fc6110395edd2108bMiao Wang return true; 203bbfd239e43526ff969699d3fc6110395edd2108bMiao Wang} 204bbfd239e43526ff969699d3fc6110395edd2108bMiao Wang 2059f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang 206eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang} // namespace nn 207eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang} // namespace android 208