/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_ML_NN_COMMON_OPERATIONS_UTILS_H
#define ANDROID_ML_NN_COMMON_OPERATIONS_UTILS_H

#include "Utils.h"

#include <cstdint>
#include <vector>

// Macro to check whether an operation's input parameters are valid.
// Logs the failing condition and makes the enclosing function return false.
#define NN_CHECK(v)                                      \
  do {                                                   \
    if (!(v)) {                                          \
      LOG(ERROR) << "NN_CHECK failed: " << #v;           \
      return false;                                      \
    }                                                    \
  } while (0)

#define NN_CHECK_EQ(actual, expected) \
  NN_CHECK((actual) == (expected))

#define NN_OPS_CHECK NN_CHECK
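
// Example usage (illustrative, not part of this header): a shape-validation
// helper might use these checks to bail out early on malformed inputs.
//
//   bool validateBinaryOp(const Shape& in1, const Shape& in2) {
//       NN_CHECK(getNumberOfDimensions(in1) <= 4);
//       NN_CHECK_EQ(getNumberOfElements(in1), getNumberOfElements(in2));
//       return true;
//   }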

namespace android {
namespace nn {

// Implicit padding schemes for convolution-like operations.
enum PaddingScheme {
    kPaddingUnknown = 0,
    kPaddingSame = 1,
    kPaddingValid = 2,
};

// The type and dimensions of an operand. For quantized types, 'scale' and
// 'offset' describe the quantization: realValue = scale * (quantized - offset).
struct Shape {
    OperandType type;
    std::vector<uint32_t> dimensions;
    float scale;
    int32_t offset;
};
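
// Example (illustrative): constructing a Shape for a 1x224x224x3 float
// tensor; scale and offset are unused for float types. OperandType's
// enumerators are assumed to come in via "Utils.h":
//
//   Shape imageShape = {OperandType::TENSOR_FLOAT32, {1, 224, 224, 3}, 0.0f, 0};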

// Verifies that the two shapes are the same.
bool SameShape(const Shape& in1, const Shape& in2);

// Sets out to the same shape as in.
bool SetShape(const Shape& in, Shape* out);

// Returns the total number of elements, i.e., all the dimensions multiplied
// together. For a scalar, returns one.
uint32_t getNumberOfElements(const Shape& shape);

// Returns the number of dimensions of the shape.
uint32_t getNumberOfDimensions(const Shape& shape);

// Returns the size of the given dimension.
uint32_t getSizeOfDimension(const Shape& shape, uint32_t dimensionIdx);

// Computes the output size along one spatial axis of a convolution or
// pooling, given the input size, filter size, stride, and per-side padding.
inline uint32_t computeOutSize(uint32_t imageSize, uint32_t filterSize, uint32_t stride,
                               uint32_t paddingHead, uint32_t paddingTail) {
    return (imageSize - filterSize + stride + paddingHead + paddingTail) / stride;
}
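
// Worked example: a 5-wide input, 3-wide filter, stride 2, and one pixel of
// padding on each side yields (5 - 3 + 2 + 1 + 1) / 2 = 3 output elements.
// This is the usual (imageSize + padding - filterSize) / stride + 1, folded
// into a single integer division.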

// Decomposes a real multiplier in (0, 1) into a 32-bit fixed-point multiplier
// and a right shift, such that
// double_multiplier ~= quantized_multiplier * 2^-31 * 2^-right_shift.
__wur
bool QuantizeMultiplierSmallerThanOne(double double_multiplier,
                                      int32_t* quantized_multiplier,
                                      int32_t* right_shift);

// Decomposes a real multiplier >= 1 into a 32-bit fixed-point multiplier and
// a left shift.
__wur
bool QuantizeMultiplierGreaterThanOne(double double_multiplier,
                                      int32_t* quantized_multiplier,
                                      int* left_shift);
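
// Worked example (illustrative of the fixed-point decomposition above): for a
// real multiplier of 0.3, frexp gives 0.3 = 0.6 * 2^-1, so
//   quantized_multiplier = round(0.6 * 2^31) = 1288490189, right_shift = 1,
// and 1288490189 * 2^-31 * 2^-1 ~= 0.3.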

// Computes the real multiplier of a quantized convolution,
// input_scale * filter_scale / output_scale, validating it against the bias
// and output quantization parameters.
__wur
bool GetQuantizedConvolutionMultipler(const Shape& inputShape,
                                      const Shape& filterShape,
                                      const Shape& biasShape,
                                      const Shape& outputShape,
                                      float* multiplier);

// Computes the clamping bounds [*act_min, *act_max], in the quantized domain
// of outputShape, for the given fused activation function.
void CalculateActivationRangeUint8(int32_t activation,
                                   const Shape& outputShape,
                                   int32_t* act_min,
                                   int32_t* act_max);
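
// Illustrative example (assuming the usual quantize(v) = offset + round(v / scale)
// mapping): for an output with scale 0.5 and offset 10, a fused RELU6 clamps
// quantized values to
//   act_min = max(0, 10 + round(0 / 0.5))   = 10
//   act_max = min(255, 10 + round(6 / 0.5)) = 22,
// i.e. the real range [0, 6] mapped into uint8 space.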

// Computes the maximum representable input magnitude (the "input radius") for
// fixed-point implementations (e.g. softmax) with the given number of integer
// bits and left shift.
int32_t CalculateInputRadius(int input_integer_bits, int input_left_shift);

// Computes the explicit per-side padding implied by an implicit padding
// scheme. For kPaddingSame, the output size is ceil(in_size / stride) and any
// odd total padding puts the extra element at the tail; for all other schemes
// the padding is zero.
inline void calculateExplicitPadding(int32_t in_size, int32_t stride,
                                     int32_t filter_size, int32_t padding_implicit,
                                     int32_t* padding_head, int32_t* padding_tail) {
    *padding_head = 0;
    *padding_tail = 0;

    if (padding_implicit == kPaddingSame) {
        int32_t out_size = (in_size + stride - 1) / stride;
        int32_t tmp = (out_size - 1) * stride + filter_size;
        if (tmp > in_size) {
            *padding_head = (tmp - in_size) / 2;
            *padding_tail = (tmp - in_size) - *padding_head;
        }
    }
}
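
// Worked example: in_size = 4, stride = 1, filter_size = 3, kPaddingSame:
// out_size = ceil(4 / 1) = 4, tmp = (4 - 1) * 1 + 3 = 6 > 4, so
// *padding_head = (6 - 4) / 2 = 1 and *padding_tail = 2 - 1 = 1.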

// Infers the implicit padding scheme corresponding to the given explicit
// paddings: all-zero padding is kPaddingValid; padding matching what
// calculateExplicitPadding would produce for kPaddingSame is kPaddingSame;
// anything else is kPaddingUnknown.
inline PaddingScheme getPaddingScheme(int32_t inWidth, int32_t inHeight,
                                      int32_t strideWidth, int32_t strideHeight,
                                      int32_t filterWidth, int32_t filterHeight,
                                      int32_t paddingLeft, int32_t paddingRight,
                                      int32_t paddingTop, int32_t paddingBottom) {
    if (paddingLeft == 0 && paddingRight == 0 && paddingTop == 0 && paddingBottom == 0) {
        return kPaddingValid;
    }

    int32_t expectedPaddingLeft, expectedPaddingRight;
    int32_t expectedPaddingTop, expectedPaddingBottom;

    calculateExplicitPadding(inWidth, strideWidth, filterWidth, kPaddingSame,
                             &expectedPaddingLeft, &expectedPaddingRight);
    calculateExplicitPadding(inHeight, strideHeight, filterHeight, kPaddingSame,
                             &expectedPaddingTop, &expectedPaddingBottom);
    if (expectedPaddingLeft == paddingLeft && expectedPaddingRight == paddingRight &&
        expectedPaddingTop == paddingTop && expectedPaddingBottom == paddingBottom) {
        return kPaddingSame;
    } else {
        return kPaddingUnknown;
    }
}
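
// Example (illustrative): for a 5x5 input, 3x3 filter, and stride 2, explicit
// paddings of (1, 1, 1, 1) reproduce calculateExplicitPadding's SAME result,
// so getPaddingScheme(5, 5, 2, 2, 3, 3, 1, 1, 1, 1) == kPaddingSame.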

// Preparation functions for the corresponding operations: each validates the
// input shapes and computes the output Shape.
bool addMulPrepare(const Shape& in1, const Shape& in2, Shape* out1);

bool floorPrepare(const Shape& input, Shape* output);

bool dequantizePrepare(const Shape& input, Shape* output);

bool depthwiseConvPrepare(const Shape& input,
                          const Shape& filter,
                          const Shape& bias,
                          int32_t padding_left, int32_t padding_right,
                          int32_t padding_top, int32_t padding_bottom,
                          int32_t stride_width, int32_t stride_height,
                          Shape* output);

bool convPrepare(const Shape& input,
                 const Shape& filter,
                 const Shape& bias,
                 int32_t padding_left, int32_t padding_right,
                 int32_t padding_top, int32_t padding_bottom,
                 int32_t stride_width, int32_t stride_height,
                 Shape* output);
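
// Example (illustrative): a typical caller prepares the output shape before
// running the kernel. 'inputShape', 'filterShape', and 'biasShape' are
// hypothetical locals:
//
//   Shape outputShape;
//   if (!convPrepare(inputShape, filterShape, biasShape,
//                    /*padding_left=*/1, /*padding_right=*/1,
//                    /*padding_top=*/1, /*padding_bottom=*/1,
//                    /*stride_width=*/1, /*stride_height=*/1, &outputShape)) {
//       return false;
//   }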

bool genericPoolingPrepare(const Shape& input,
                           int32_t padding_left, int32_t padding_right,
                           int32_t padding_top, int32_t padding_bottom,
                           int32_t stride_width, int32_t stride_height,
                           int32_t filter_width, int32_t filter_height,
                           Shape* output);

bool genericActivationPrepare(const Shape& input, Shape* output);

bool fullyConnectedPrepare(const Shape& input,
                           const Shape& weights,
                           const Shape& bias,
                           Shape* output);

bool concatenationPrepare(const std::vector<Shape>& inputShapes,
                          int32_t axis,
                          Shape* output);

bool genericNormalizationPrepare(const Shape& input, Shape* output);

bool reshapePrepare(const Shape& input,
                    const int32_t* targetDims,
                    const int32_t targetDimsSize,
                    Shape* output);

bool resizeBilinearPrepare(const Shape& input,
                           int32_t height,
                           int32_t width,
                           Shape* output);

bool depthToSpacePrepare(const Shape& input,
                         int32_t blockSize,
                         Shape* output);

bool spaceToDepthPrepare(const Shape& input,
                         int32_t blockSize,
                         Shape* output);

bool embeddingLookupPrepare(const Shape& valueShape,
                            const Shape& lookupShape,
                            Shape* outputShape);

bool hashtableLookupPrepare(const Shape& lookupShape,
                            const Shape& keyShape,
                            const Shape& valueShape,
                            Shape* outputShape,
                            Shape* hitShape);

// Dispatches 'macro' on the fused activation function selected by an int32_t
// 'activation' variable in the enclosing scope; makes the enclosing function
// return false on an unsupported value.
#define ANDROID_NN_MACRO_DISPATCH(macro)                                    \
    switch (activation) {                                                   \
        case (int32_t) FusedActivationFunc::NONE:                           \
            macro(kNone);                                                   \
            break;                                                          \
        case (int32_t) FusedActivationFunc::RELU:                           \
            macro(kRelu);                                                   \
            break;                                                          \
        case (int32_t) FusedActivationFunc::RELU1:                          \
            macro(kRelu1);                                                  \
            break;                                                          \
        case (int32_t) FusedActivationFunc::RELU6:                          \
            macro(kRelu6);                                                  \
            break;                                                          \
        default:                                                            \
            LOG(ERROR) << "Unsupported fused activation function type";     \
            return false;                                                   \
    }
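
// Example usage (illustrative): dispatching a templated kernel on the fused
// activation. runKernel, FusedActivation, and the surrounding locals are
// hypothetical, not part of this header:
//
//   #define ANDROID_NN_KERNEL(act)                                        \
//       runKernel<FusedActivation::act>(inputData, inputShape,            \
//                                       outputData, outputShape)
//   ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_KERNEL)
//   #undef ANDROID_NN_KERNEL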

} // namespace nn
} // namespace android

#endif // ANDROID_ML_NN_COMMON_OPERATIONS_UTILS_H