1eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang/*
2eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang * Copyright (C) 2017 The Android Open Source Project
3eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang *
4eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang * Licensed under the Apache License, Version 2.0 (the "License");
5eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang * you may not use this file except in compliance with the License.
6eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang * You may obtain a copy of the License at
7eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang *
8eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang *      http://www.apache.org/licenses/LICENSE-2.0
9eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang *
10eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang * Unless required by applicable law or agreed to in writing, software
11eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang * distributed under the License is distributed on an "AS IS" BASIS,
12eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang * See the License for the specific language governing permissions and
14eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang * limitations under the License.
15eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang */
16eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang
17eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang#include "Operations.h"
18eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang#include "OperationsUtils.h"
19eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang
20eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang#include "internal/optimized/optimized_ops.h"
21eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang
22eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wangnamespace android {
23eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wangnamespace nn {
24eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang
25eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wangbool reluFloat32(const float* inputData, const Shape& inputShape,
26eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang                 float* outputData, const Shape& outputShape) {
27eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang    int numElements = getNumberOfElements(inputShape);
28eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang    for (int i=0; i<numElements; i++, inputData++, outputData++) {
29eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang        *outputData = std::max(0.f, *inputData);
30eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang    }
31eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang    return true;
32eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang}
33eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang
349f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wangbool relu1Float32(const float* inputData, const Shape& inputShape,
359f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang                  float* outputData, const Shape& outputShape) {
369f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    int numElements = getNumberOfElements(inputShape);
379f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    for (int i=0; i<numElements; i++, inputData++, outputData++) {
389f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang        *outputData = std::min(std::max(-1.f, *inputData), 1.f);
399f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    }
409f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    return true;
419f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang}
429f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang
43eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wangbool relu6Float32(const float* inputData, const Shape& inputShape,
44eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang                  float* outputData, const Shape& outputShape) {
45eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang    int numElements = getNumberOfElements(inputShape);
46eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang    for (int i=0; i<numElements; i++, inputData++, outputData++) {
47eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang        *outputData = std::min(std::max(0.f, *inputData), 6.f);
48eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang    }
49eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang    return true;
50eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang}
51eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang
52eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wangbool tanhFloat32(const float* inputData, const Shape& inputShape,
53eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang                 float* outputData, const Shape& outputShape) {
54eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang    int numElements = getNumberOfElements(inputShape);
55eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang    for (int i=0; i<numElements; i++, inputData++, outputData++) {
56eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang        *outputData = std::tanh(*inputData);
57eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang    }
58eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang    return true;
59eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang}
60eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang
61eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wangbool logisticFloat32(const float* inputData, const Shape& inputShape,
62eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang                     float* outputData, const Shape& outputShape) {
63eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang    int numElements = getNumberOfElements(inputShape);
64eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang    for (int i=0; i<numElements; i++, inputData++, outputData++) {
65eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang        *outputData = 1.f / (1.f + std::exp(-*inputData));
66eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang    }
67eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang    return true;
68eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang}
69eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang
709f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wangbool softmaxFloat32(const float* inputData, const Shape& inputShape,
719f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang                    const float beta,
729f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang                    float* outputData, const Shape& outputShape) {
739f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    Dims<4> dim;
749f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    if (getNumberOfDimensions(inputShape) == 2) {
759f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang        uint32_t batch_size = getSizeOfDimension(inputShape, 0);
769f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang        uint32_t input_size = getNumberOfElements(inputShape) / batch_size;
779f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang
789f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang        Shape shapeIn4D;
799f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang        shapeIn4D.dimensions = {batch_size, 1, 1, input_size};
809f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang        dim = convertShapeToDims(shapeIn4D);
819f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    } else if (getNumberOfDimensions(inputShape) == 4) {
829f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang        dim = convertShapeToDims(inputShape);
839f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    } else {
849f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang        LOG(ERROR) << "only 2D and 4D tensors supported";
859f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang        return false;
869f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    }
879f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang
889f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    optimized_ops::Softmax(inputData, dim, beta,
899f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang                           outputData, dim);
909f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    return true;
919f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang}
929f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang
939f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang#define ANDROID_NN_RELUX_QUANT8(activation)                             \
949f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    int numElements = getNumberOfElements(inputShape);                  \
959f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    int32_t output_activation_min = 0;                                  \
969f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    int32_t output_activation_max = 0;                                  \
979f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang                                                                        \
989f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    CalculateActivationRangeUint8(activation, inputShape,               \
999f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang                                  &output_activation_min,               \
1009f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang                                  &output_activation_max);              \
1019f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang                                                                        \
1029f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    for (int i=0; i<numElements; i++, inputData++, outputData++) {      \
1039f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang        *outputData = std::min((uint8_t)output_activation_max,          \
1049f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang                std::max((uint8_t)output_activation_min, *inputData));  \
1059f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    }
1069f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang
1079f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang
1089f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wangbool reluQuant8(const uint8_t* inputData, const Shape& inputShape,
1099f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang                uint8_t* outputData, const Shape& outputShape) {
1109f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    ANDROID_NN_RELUX_QUANT8(kActivationRelu)
1119f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    return true;
1129f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang}
1139f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang
1149f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wangbool relu1Quant8(const uint8_t* inputData, const Shape& inputShape,
1159f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang                 uint8_t* outputData, const Shape& outputShape) {
1169f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    ANDROID_NN_RELUX_QUANT8(kActivationRelu1)
1179f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    return true;
1189f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang}
1199f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang
1209f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wangbool relu6Quant8(const uint8_t* inputData, const Shape& inputShape,
1219f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang                 uint8_t* outputData, const Shape& outputShape) {
1229f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    ANDROID_NN_RELUX_QUANT8(kActivationRelu6)
1239f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    return true;
1249f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang}
1259f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang
1269f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang#undef ANDROID_NN_RELUX_QUANT8
1279f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang
12827e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wangbool logisticQuant8(const uint8_t* inputData, const Shape& inputShape,
12927e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang                    uint8_t* outputData, const Shape& outputShape) {
13045bf79e5b9fee354fde7c1f64417d9ca4a1da7daMiao Wang    if (outputShape.offset != 0 || outputShape.scale != 1.f / 256) {
131874d039215516aebdaba2e242609199897fe80c0Miao Wang        LOG(ERROR) << "incorrect scale / offset for output";
132874d039215516aebdaba2e242609199897fe80c0Miao Wang        return false;
133874d039215516aebdaba2e242609199897fe80c0Miao Wang    }
134874d039215516aebdaba2e242609199897fe80c0Miao Wang
13527e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    int numElements = getNumberOfElements(inputShape);
13627e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    static constexpr int kInputIntegerBits = 4;
13727e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang
13827e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    const double input_real_multiplier =
13927e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang            inputShape.scale *
14027e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang            static_cast<double>(1 << (31 - kInputIntegerBits));
14127e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang
14227e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    int32_t input_multiplier = 0;
14327e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    int32_t input_left_shift = 0;
144be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier,
145be2b22578baf949d7be42ba002cee94304daf53cMiao Wang                                          &input_multiplier,
146be2b22578baf949d7be42ba002cee94304daf53cMiao Wang                                          &input_left_shift)) {
147be2b22578baf949d7be42ba002cee94304daf53cMiao Wang        return false;
148be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    }
14927e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    int32_t input_range_radius =
15027e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang            CalculateInputRadius(kInputIntegerBits, input_left_shift);
15127e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang
15227e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    optimized_ops::Logistic(
15327e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang            inputData, convertShapeToDims(inputShape),
15427e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang            inputShape.offset, input_range_radius,
15527e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang            input_multiplier, input_left_shift,
15627e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang            outputData, convertShapeToDims(outputShape));
15727e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang
15827e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    return true;
15927e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang}
16027e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang
1619f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wangbool softmaxQuant8(const uint8_t* inputData, const Shape& inputShape,
1629f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang                   const float beta,
1639f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang                   uint8_t* outputData, const Shape& outputShape) {
1649f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    Dims<4> dim;
1659f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    if (getNumberOfDimensions(inputShape) == 2) {
1669f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang        uint32_t batch_size = getSizeOfDimension(inputShape, 0);
1679f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang        uint32_t input_size = getNumberOfElements(inputShape) / batch_size;
1689f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang
1699f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang        Shape shapeIn4D;
1709f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang        shapeIn4D.dimensions = {batch_size, 1, 1, input_size};
1719f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang        dim = convertShapeToDims(shapeIn4D);
1729f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    } else if (getNumberOfDimensions(inputShape) == 4) {
1739f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang        dim = convertShapeToDims(inputShape);
1749f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    } else {
1759f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang        LOG(ERROR) << "only 2D and 4D tensors supported";
1769f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang        return false;
1779f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    }
1789f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang
17945bf79e5b9fee354fde7c1f64417d9ca4a1da7daMiao Wang    if (outputShape.offset != 0 || outputShape.scale != 1.f / 256) {
180874d039215516aebdaba2e242609199897fe80c0Miao Wang        LOG(ERROR) << "incorrect scale / offset for output";
1819f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang        return false;
182bbfd239e43526ff969699d3fc6110395edd2108bMiao Wang    }
183874d039215516aebdaba2e242609199897fe80c0Miao Wang
1849f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    static const int32_t kScaledDiffIntegerBits = 5;
1859f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    const double input_beta_real_multiplier = std::min(
1869f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang            1.0 * beta * inputShape.scale * (1 << (31 - kScaledDiffIntegerBits)),
1879f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang            (1ll << 31) - 1.0);
1889f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang
1899f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    int32_t input_multiplier = 0;
1909f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    int32_t input_left_shift = 0;
191be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    if (!QuantizeMultiplierGreaterThanOne(input_beta_real_multiplier,
192be2b22578baf949d7be42ba002cee94304daf53cMiao Wang                                          &input_multiplier,
193be2b22578baf949d7be42ba002cee94304daf53cMiao Wang                                          &input_left_shift)) {
194be2b22578baf949d7be42ba002cee94304daf53cMiao Wang        return false;
195be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    }
1969f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    float diff_min = -1.0f * CalculateInputRadius(kScaledDiffIntegerBits,
1979f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang                                                  input_left_shift);
1989f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang
1999f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang    optimized_ops::Softmax(inputData, dim, input_multiplier,
2009f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang                           input_left_shift, diff_min,
2019f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang                           outputData, dim);
202bbfd239e43526ff969699d3fc6110395edd2108bMiao Wang    return true;
203bbfd239e43526ff969699d3fc6110395edd2108bMiao Wang}
204bbfd239e43526ff969699d3fc6110395edd2108bMiao Wang
2059f41362ea2d250b89e59c73cc0194c6a6720cc9dMiao Wang
206eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang}  // namespace nn
207eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang}  // namespace android
208