/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "Operations.h"
#include "OperationsUtils.h"

#include "internal/optimized/optimized_ops.h"

#include <algorithm>  // std::max, std::min
#include <cmath>      // std::exp, std::tanh

namespace android {
namespace nn {

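// Rectified linear unit: computes max(0, x) for each element.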
bool reluFloat32(const float* inputData, const Shape& inputShape,
                 float* outputData, const Shape& outputShape) {
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = std::max(0.f, *inputData);
    }
    return true;
}

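// ReLU1: clamps each element to the range [-1, 1].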
bool relu1Float32(const float* inputData, const Shape& inputShape,
                  float* outputData, const Shape& outputShape) {
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = std::min(std::max(-1.f, *inputData), 1.f);
    }
    return true;
}

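// ReLU6: clamps each element to the range [0, 6].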
bool relu6Float32(const float* inputData, const Shape& inputShape,
                  float* outputData, const Shape& outputShape) {
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = std::min(std::max(0.f, *inputData), 6.f);
    }
    return true;
}

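// Hyperbolic tangent, computed elementwise.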
bool tanhFloat32(const float* inputData, const Shape& inputShape,
                 float* outputData, const Shape& outputShape) {
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = std::tanh(*inputData);
    }
    return true;
}

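// Logistic (sigmoid) function: computes 1 / (1 + e^-x) for each element.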
bool logisticFloat32(const float* inputData, const Shape& inputShape,
                     float* outputData, const Shape& outputShape) {
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = 1.f / (1.f + std::exp(-*inputData));
    }
    return true;
}

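// Softmax with a scaling factor beta applied to the input. 2D inputs are
// viewed as a {batch, 1, 1, input_size} tensor so the optimized 4D kernel
// can be reused.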
bool softmaxFloat32(const float* inputData, const Shape& inputShape,
                    const float beta,
                    float* outputData, const Shape& outputShape) {
    Dims<4> dim;
    if (getNumberOfDimensions(inputShape) == 2) {
        uint32_t batch_size = getSizeOfDimension(inputShape, 0);
        uint32_t input_size = getNumberOfElements(inputShape) / batch_size;

        Shape shapeIn4D;
        shapeIn4D.dimensions = {batch_size, 1, 1, input_size};
        dim = convertShapeToDims(shapeIn4D);
    } else if (getNumberOfDimensions(inputShape) == 4) {
        dim = convertShapeToDims(inputShape);
    } else {
        LOG(ERROR) << "only 2D and 4D tensors supported";
        return false;
    }

    optimized_ops::Softmax(inputData, dim, beta,
                           outputData, dim);
    return true;
}

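// Shared body for the quantized ReLU variants: computes the [min, max]
// clamping bounds implied by the activation and the input's quantization
// parameters, then clamps each input element into the output.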
#define ANDROID_NN_RELUX_QUANT8(activation)                             \
    int numElements = getNumberOfElements(inputShape);                  \
    int32_t output_activation_min = 0;                                  \
    int32_t output_activation_max = 0;                                  \
                                                                        \
    CalculateActivationRangeUint8(activation, inputShape,               \
                                  &output_activation_min,               \
                                  &output_activation_max);              \
                                                                        \
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {  \
        *outputData = std::min((uint8_t)output_activation_max,          \
                std::max((uint8_t)output_activation_min, *inputData));  \
    }

bool reluQuant8(const uint8_t* inputData, const Shape& inputShape,
                uint8_t* outputData, const Shape& outputShape) {
    ANDROID_NN_RELUX_QUANT8(kActivationRelu)
    return true;
}

bool relu1Quant8(const uint8_t* inputData, const Shape& inputShape,
                 uint8_t* outputData, const Shape& outputShape) {
    ANDROID_NN_RELUX_QUANT8(kActivationRelu1)
    return true;
}

bool relu6Quant8(const uint8_t* inputData, const Shape& inputShape,
                 uint8_t* outputData, const Shape& outputShape) {
    ANDROID_NN_RELUX_QUANT8(kActivationRelu6)
    return true;
}

#undef ANDROID_NN_RELUX_QUANT8

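// Quantized logistic (sigmoid). The input is rescaled into a
// Q(31 - kInputIntegerBits) fixed-point representation before calling the
// optimized kernel; the output must be quantized with scale 1/256 and zero
// offset so the uint8 range covers [0, 1).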
bool logisticQuant8(const uint8_t* inputData, const Shape& inputShape,
                    uint8_t* outputData, const Shape& outputShape) {
    if (outputShape.offset != 0 || outputShape.scale != 1.f / 256) {
        LOG(ERROR) << "incorrect scale / offset for output";
        return false;
    }

    static constexpr int kInputIntegerBits = 4;

    const double input_real_multiplier =
            inputShape.scale *
            static_cast<double>(1 << (31 - kInputIntegerBits));

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier,
                                          &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    int32_t input_range_radius =
            CalculateInputRadius(kInputIntegerBits, input_left_shift);

    optimized_ops::Logistic(
            inputData, convertShapeToDims(inputShape),
            inputShape.offset, input_range_radius,
            input_multiplier, input_left_shift,
            outputData, convertShapeToDims(outputShape));

    return true;
}

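// Quantized softmax. beta is folded into the input's fixed-point multiplier
// (Q(31 - kScaledDiffIntegerBits)); as with logistic, the output must be
// quantized with scale 1/256 and zero offset.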
bool softmaxQuant8(const uint8_t* inputData, const Shape& inputShape,
                   const float beta,
                   uint8_t* outputData, const Shape& outputShape) {
    Dims<4> dim;
    if (getNumberOfDimensions(inputShape) == 2) {
        uint32_t batch_size = getSizeOfDimension(inputShape, 0);
        uint32_t input_size = getNumberOfElements(inputShape) / batch_size;

        Shape shapeIn4D;
        shapeIn4D.dimensions = {batch_size, 1, 1, input_size};
        dim = convertShapeToDims(shapeIn4D);
    } else if (getNumberOfDimensions(inputShape) == 4) {
        dim = convertShapeToDims(inputShape);
    } else {
        LOG(ERROR) << "only 2D and 4D tensors supported";
        return false;
    }

    if (outputShape.offset != 0 || outputShape.scale != 1.f / 256) {
        LOG(ERROR) << "incorrect scale / offset for output";
        return false;
    }

    static constexpr int kScaledDiffIntegerBits = 5;
    const double input_beta_real_multiplier = std::min(
            1.0 * beta * inputShape.scale * (1 << (31 - kScaledDiffIntegerBits)),
            (1ll << 31) - 1.0);

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_beta_real_multiplier,
                                          &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    int32_t diff_min = -CalculateInputRadius(kScaledDiffIntegerBits,
                                             input_left_shift);

    optimized_ops::Softmax(inputData, dim, input_multiplier,
                           input_left_shift, diff_min,
                           outputData, dim);
    return true;
}

}  // namespace nn
}  // namespace android