1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "Operations.h"
18#include "OperationsUtils.h"
19
20#include "internal/optimized/optimized_ops.h"
21
22namespace android {
23namespace nn {
24
// Computes a float32 fully-connected (dense) layer: output = activation(input * weights^T + bias).
// The heavy lifting is delegated to the optimized_ops::FullyConnected template, which is
// specialized on the fused activation function; the ANDROID_NN_MACRO_DISPATCH macro (defined in
// OperationsUtils.h) maps the runtime `activation` enum value onto the matching template
// instantiation.
//
// Parameters:
//   inputData/inputShape     - input activations buffer and its shape.
//   weightsData/weightsShape - weight matrix buffer and its shape.
//   biasData/biasShape       - bias vector buffer and its shape.
//   activation               - fused activation function enum (e.g. NONE/RELU/RELU1/RELU6).
//   outputData/outputShape   - destination buffer and its shape.
// Returns true unconditionally; dispatch failures are handled inside the dispatch macro.
bool fullyConnectedFloat32(const float* inputData, const Shape& inputShape,
                           const float* weightsData, const Shape& weightsShape,
                           const float* biasData, const Shape& biasShape,
                           int32_t activation,
                           float* outputData, const Shape& outputShape) {

    // Expands to one optimized_ops::FullyConnected call specialized on the given
    // fused activation; Shapes are converted to the Dims<4> layout optimized_ops expects.
    #define ANDROID_NN_FULLY_CONNECTED(activation)                              \
        optimized_ops::FullyConnected<FusedActivationFunctionType::activation>( \
            inputData, convertShapeToDims(inputShape),                          \
            weightsData, convertShapeToDims(weightsShape),                      \
            biasData, convertShapeToDims(biasShape),                            \
            outputData, convertShapeToDims(outputShape))

    ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_FULLY_CONNECTED)
    #undef ANDROID_NN_FULLY_CONNECTED
    return true;
}
42
// Computes a quantized (uint8 asymmetric) fully-connected layer via gemmlowp-backed
// optimized_ops::FullyConnected. The input/weights zero points are passed as negated
// offsets (the convention optimized_ops expects), and the requantization back to the
// output scale is expressed as a fixed-point multiplier + right shift.
//
// Parameters:
//   inputData/inputShape     - quantized input activations and shape (zero point in Shape.offset).
//   weightsData/weightsShape - quantized weights and shape.
//   biasData/biasShape       - int32 bias values and shape.
//   activation               - fused activation function enum, clamped via the uint8 range below.
//   outputData/outputShape   - quantized destination buffer and shape.
// Returns false if the combined input*weights scale cannot be represented as a
// multiplier smaller than one relative to the output scale; true otherwise.
bool fullyConnectedQuant8(const uint8_t* inputData, const Shape& inputShape,
                          const uint8_t* weightsData, const Shape& weightsShape,
                          const int32_t* biasData, const Shape& biasShape,
                          int32_t activation,
                          uint8_t* outputData, const Shape& outputShape) {
    // Zero points are negated here so the kernel can add them directly.
    int32_t inputOffset = -inputShape.offset;
    int32_t weightsOffset = -weightsShape.offset;
    int32_t outputOffset = outputShape.offset;

    float real_multiplier = 0.0;
    int32_t output_multiplier = 0;
    int32_t output_shift = 0;
    int32_t output_activation_min = 0;
    int32_t output_activation_max = 0;

    // Derive the requantization parameters: real_multiplier = inputScale * weightsScale
    // / outputScale, then convert it to a fixed-point multiplier and shift. Fails if the
    // multiplier is not representable (i.e. not strictly smaller than one).
    if (!GetQuantizedConvolutionMultipler(inputShape, weightsShape, biasShape,
                                          outputShape, &real_multiplier) ||
            !QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier,
                                              &output_shift)) {
        return false;
    }
    // Map the fused activation onto [min, max] clamp bounds in the quantized domain.
    CalculateActivationRangeUint8(activation, outputShape,
                                  &output_activation_min,
                                  &output_activation_max);

    // NOTE(review): a function-local static GemmContext is shared across all calls on all
    // threads; gemmlowp contexts are not safe for concurrent use from multiple threads,
    // so this looks racy if the NN runtime can run two quant8 FC ops in parallel — confirm
    // against the runtime's threading model before relying on it.
    static gemmlowp::GemmContext gemm_context;
    // Let gemmlowp automatically decide how many threads to use (0 = auto).
    gemm_context.set_max_num_threads(0);

    // Expands to one optimized_ops::FullyConnected call specialized on the given
    // fused activation, forwarding the quantization parameters computed above.
    #define ANDROID_NN_FULLY_CONNECTED(activation)                              \
        optimized_ops::FullyConnected<FusedActivationFunctionType::activation>( \
            inputData, convertShapeToDims(inputShape), inputOffset,             \
            weightsData, convertShapeToDims(weightsShape), weightsOffset,       \
            biasData, convertShapeToDims(biasShape),                            \
            outputOffset, output_multiplier, output_shift,                      \
            output_activation_min, output_activation_max,                       \
            outputData, convertShapeToDims(outputShape), &gemm_context)

    ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_FULLY_CONNECTED)
    #undef ANDROID_NN_FULLY_CONNECTED
    return true;
}
85}  // namespace nn
86}  // namespace android
87