196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet/*
296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * Copyright (C) 2017 The Android Open Source Project
396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet *
496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * Licensed under the Apache License, Version 2.0 (the "License");
596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * you may not use this file except in compliance with the License.
696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * You may obtain a copy of the License at
796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet *
896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet *      http://www.apache.org/licenses/LICENSE-2.0
996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet *
1096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * Unless required by applicable law or agreed to in writing, software
1196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * distributed under the License is distributed on an "AS IS" BASIS,
1296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * See the License for the specific language governing permissions and
1496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet * limitations under the License.
1596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet */
1696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
1796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet#define LOG_TAG "OperationsUtils"
1896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
1996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet#include "OperationsUtils.h"
2027e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang#include "Operations.h"
2196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet#include "Utils.h"
2296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
2327e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang#include <cmath>
2427e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang
2596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletnamespace android {
2696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletnamespace nn {
2796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
2896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletbool SameShape(const Shape& in1, const Shape& in2) {
29eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang    if (in1.type != in2.type || in1.dimensions.size() != in2.dimensions.size()) {
3096775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet        return false;
3196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    }
32eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang    for (size_t i = 0; i < in1.dimensions.size(); i++) {
3396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet        if (in1.dimensions[i] != in2.dimensions[i]) {
3496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet            return false;
3596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet        }
3696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    }
3796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    return true;
3896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet}
3996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
40eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wangbool SetShape(const Shape& in, Shape* out) {
41eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang    if (in.type != out->type || in.dimensions.size() != out->dimensions.size()) {
4296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet        return false;
4396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    }
44eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang    out->dimensions = in.dimensions;
4596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    return true;
4696775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet}
4796775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
4896775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouilletuint32_t getNumberOfElements(const Shape& shape) {
4996775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    uint32_t count = 1;
50eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang    for (size_t i = 0; i < shape.dimensions.size(); i++) {
5196775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet        count *= shape.dimensions[i];
5296775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    }
5396775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet    return count;
5496775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet}
5596775128e3bcfdc5be51b62edc50309c83861fe8Jean-Luc Brouillet
56707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouilletuint32_t getNumberOfDimensions(const Shape& shape) {
57eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang    return shape.dimensions.size();
58707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet}
59707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet
60707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouilletuint32_t getSizeOfDimension(const Shape& shape, uint32_t dimensionIdx) {
61eb1f88846f147d1d80ee0d688fe4635b89a40ffaMiao Wang    if (dimensionIdx >= shape.dimensions.size()) {
62707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet        // TODO, log the error
63707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet        return 0;
64707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet    }
65707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet    return shape.dimensions[dimensionIdx];
66707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet}
67707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet
68be2b22578baf949d7be42ba002cee94304daf53cMiao Wangbool QuantizeMultiplierSmallerThanOne(double double_multiplier,
6927e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang                                      int32_t* quantized_multiplier,
7027e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang                                      int32_t* right_shift) {
71be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(double_multiplier >= 0.);
72be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(double_multiplier < 1.);
7327e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    if (double_multiplier == 0.) {
7427e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang        *quantized_multiplier = 0;
7527e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang        *right_shift = 0;
76be2b22578baf949d7be42ba002cee94304daf53cMiao Wang        return true;
7727e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    }
78be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(double_multiplier > 0.);
7927e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    const double q = std::frexp(double_multiplier, right_shift);
8027e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    *right_shift *= -1;
8127e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    int64_t q_fixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
82be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(q_fixed <= (1ll << 31));
8327e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    if (q_fixed == (1ll << 31)) {
8427e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang        q_fixed /= 2;
8527e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang        --*right_shift;
8627e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    }
87be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(*right_shift >= 0);
88be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(q_fixed <= std::numeric_limits<int32_t>::max());
8927e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    *quantized_multiplier = static_cast<int32_t>(q_fixed);
90be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    return true;
9127e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang}
9227e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang
93be2b22578baf949d7be42ba002cee94304daf53cMiao Wangbool QuantizeMultiplierGreaterThanOne(double double_multiplier,
9427e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang                                      int32_t* quantized_multiplier,
9527e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang                                      int* left_shift) {
96be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(double_multiplier > 1.);
9727e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    const double q = std::frexp(double_multiplier, left_shift);
9827e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    int64_t q_fixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
99be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(q_fixed <= (1ll << 31));
10027e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    if (q_fixed == (1ll << 31)) {
10127e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang        q_fixed /= 2;
10227e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang        ++*left_shift;
10327e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    }
104be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(*left_shift >= 0);
105be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(q_fixed <= std::numeric_limits<int32_t>::max());
10627e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    *quantized_multiplier = static_cast<int32_t>(q_fixed);
107be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    return true;
10827e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang}
10927e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang
110be2b22578baf949d7be42ba002cee94304daf53cMiao Wangbool GetQuantizedConvolutionMultipler(const Shape& inputShape,
11127e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang                                      const Shape& filterShape,
11227e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang                                      const Shape& biasShape,
11327e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang                                      const Shape& outputShape,
11427e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang                                      float* multiplier) {
11527e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    const float input_product_scale = inputShape.scale * filterShape.scale;
11627e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    const float bias_scale = biasShape.scale;
11727e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    const float output_scale = outputShape.scale;
11827e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang
11927e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    // The following conditions must be guaranteed by the training pipeline.
120be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(std::abs(input_product_scale - bias_scale) <=
12127e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang              1e-6 * std::min(input_product_scale, bias_scale));
122be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(input_product_scale >= 0);
123be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(input_product_scale < output_scale);
12427e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    *multiplier = input_product_scale / output_scale;
125be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    return true;
12627e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang}
12727e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang
12827e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wangvoid CalculateActivationRangeUint8(int32_t activation,
12927e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang                                   const Shape& outputShape,
13027e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang                                   int32_t* act_min,
13127e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang                                   int32_t* act_max) {
13227e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    const int32_t qmin = std::numeric_limits<uint8_t>::min();
13327e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    const int32_t qmax = std::numeric_limits<uint8_t>::max();
13427e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang
13527e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    const auto scale = outputShape.scale;
13627e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    const auto zero_point = outputShape.offset;
13727e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang
13827e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    auto quantize = [scale, zero_point](float f) {
13927e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang        return zero_point + static_cast<int32_t>(std::round(f / scale));
14027e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    };
14127e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang
14227e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    if (activation == kActivationRelu) {
14327e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang        *act_min = std::max(qmin, quantize(0.0));
14427e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang        *act_max = qmax;
14527e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    } else if (activation == kActivationRelu6) {
14627e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang        *act_min = std::max(qmin, quantize(0.0));
14727e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang        *act_max = std::min(qmax, quantize(6.0));
14827e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    } else if (activation == kActivationRelu1) {
14927e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang        *act_min = std::max(qmin, quantize(-1.0));
15027e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang        *act_max = std::min(qmax, quantize(1.0));
15127e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    } else {
15227e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang        *act_min = qmin;
15327e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang        *act_max = qmax;
15427e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang    }
15527e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang}
15627e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang
// Returns floor(((2^input_integer_bits - 1) * 2^(31 - input_integer_bits)) /
// 2^input_left_shift) as a 32-bit integer.
int32_t CalculateInputRadius(int input_integer_bits, int input_left_shift) {
    const double maxIntegerPart =
            static_cast<double>((1 << input_integer_bits) - 1);
    const double fractionalScale =
            static_cast<double>(1ll << (31 - input_integer_bits));
    const double shiftDivisor = static_cast<double>(1ll << input_left_shift);
    const double maxInputRescaled =
            maxIntegerPart * fractionalScale / shiftDivisor;

    // Round down to tighten the bound: with the exact value, the scaled
    // difference would land on the maximum, so the returned radius must have
    // lower magnitude.
    return static_cast<int32_t>(std::floor(maxInputRescaled));
}
16627e9be3904b034e422ee9b6ab70b35ea994d2b39Miao Wang
1671b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wangbool addMulPrepare(const Shape& in1, const Shape& in2, Shape* out) {
168be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(getNumberOfDimensions(in1) <= 4 && getNumberOfDimensions(in2) <= 4);
169238a880d67b3aa5650ad52037fe1c25b1750eca9Miao Wang    NN_OPS_CHECK(in1.type == in2.type);
1701b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    if (SameShape(in1, in2)) {
1711b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang        return SetShape(in1, out);
1721b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    } else {
1731b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang        // BroadcastAdd needed
1741b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang        uint32_t numberOfDims1 = getNumberOfDimensions(in1);
1751b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang        uint32_t numberOfDims2 = getNumberOfDimensions(in2);
1761b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang        uint32_t maxDims = std::max(numberOfDims1, numberOfDims2);
1771b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang        out->dimensions = std::vector<uint32_t>(maxDims);
1781b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang        for (uint32_t i = 1; i <= maxDims; i++) {
1791b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang            uint32_t dim1 = 1;
1801b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang            if (i <= numberOfDims1) {
1811b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                dim1 = getSizeOfDimension(in1, numberOfDims1 - i);
1821b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang            }
1831b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang            uint32_t dim2 = 1;
1841b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang            if (i <= numberOfDims2) {
1851b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                dim2 = getSizeOfDimension(in2, numberOfDims2 - i);
1861b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang            }
1871b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang            if (dim1 != dim2 && dim1 != 1 && dim2 != 1) {
1881b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                LOG(ERROR) << "Dimensions mismatch for BroadcastAdd";
1891b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                return false;
1901b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang            }
1911b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang            out->dimensions[maxDims - i] = std::max(dim1, dim2);
1921b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang        }
1931b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    }
1941b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    return true;
1951b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang}
1961b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
1971b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wangbool floorPrepare(const Shape& input, Shape* output) {
1981b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    return SetShape(input, output);
1991b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang}
2001b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
2011b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wangbool dequantizePrepare(const Shape& input, Shape* output) {
2021b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    if (input.type != OperandType::TENSOR_QUANT8_ASYMM ||
2031b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang            output->type != OperandType::TENSOR_FLOAT32) {
2041b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang        LOG(ERROR) << "bad input / output operand type.";
2051b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang        return false;
2061b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    }
207feb29cb5007bd968e8850e51fab95ea24eccfd90Miao Wang    if (input.dimensions.size() != output->dimensions.size()) {
208feb29cb5007bd968e8850e51fab95ea24eccfd90Miao Wang        LOG(ERROR) << "input and output tensors don't have the same rank.";
209feb29cb5007bd968e8850e51fab95ea24eccfd90Miao Wang        return false;
210feb29cb5007bd968e8850e51fab95ea24eccfd90Miao Wang    }
211feb29cb5007bd968e8850e51fab95ea24eccfd90Miao Wang    output->dimensions = input.dimensions;
212feb29cb5007bd968e8850e51fab95ea24eccfd90Miao Wang    return true;
2131b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang}
2141b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
2151b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wangbool convPrepare(const Shape& input,
2161b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                 const Shape& filter,
2171b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                 const Shape& bias,
2181b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                 int32_t padding_left, int32_t padding_right,
2191b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                 int32_t padding_top, int32_t padding_bottom,
2201b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                 int32_t stride_width, int32_t stride_height,
2211b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                 Shape* output) {
222238a880d67b3aa5650ad52037fe1c25b1750eca9Miao Wang    NN_OPS_CHECK(input.type == filter.type);
223238a880d67b3aa5650ad52037fe1c25b1750eca9Miao Wang    if (input.type == OperandType::TENSOR_QUANT8_ASYMM) {
224238a880d67b3aa5650ad52037fe1c25b1750eca9Miao Wang        NN_OPS_CHECK(bias.type == OperandType::TENSOR_INT32);
225238a880d67b3aa5650ad52037fe1c25b1750eca9Miao Wang    } else {
226238a880d67b3aa5650ad52037fe1c25b1750eca9Miao Wang        NN_OPS_CHECK(input.type == bias.type);
227238a880d67b3aa5650ad52037fe1c25b1750eca9Miao Wang    }
228be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(getNumberOfDimensions(input) == 4);
229be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(getNumberOfDimensions(filter) == 4);
230be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(getNumberOfDimensions(bias) == 1);
2311b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
232be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(getSizeOfDimension(filter, 0) == getSizeOfDimension(bias, 0));
233be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(getSizeOfDimension(filter, 3) == getSizeOfDimension(input, 3));
2341b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
2351b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t channels_out = getSizeOfDimension(filter, 0);
2361b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t width        = getSizeOfDimension(input, 2);
2371b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t height       = getSizeOfDimension(input, 1);
2381b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t filterWidth  = getSizeOfDimension(filter, 2);
2391b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t filterHeight = getSizeOfDimension(filter, 1);
2401b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t batches      = getSizeOfDimension(input, 0);
2411b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
2421b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t outWidth = computeOutSize(width, filterWidth, stride_width,
2431b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                                       padding_left, padding_right);
2441b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t outHeight = computeOutSize(height, filterHeight, stride_height,
2451b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                                        padding_top, padding_bottom);
2461b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
2471b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->type = input.type;
2481b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->dimensions = {batches, outHeight, outWidth, channels_out};
2491b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    return true;
2501b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang}
2511b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
2521b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wangbool depthwiseConvPrepare(const Shape& input,
2531b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                          const Shape& filter,
2541b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                          const Shape& bias,
2551b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                          int32_t padding_left, int32_t padding_right,
2561b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                          int32_t padding_top, int32_t padding_bottom,
2571b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                          int32_t stride_width, int32_t stride_height,
2581b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                          Shape* output) {
259238a880d67b3aa5650ad52037fe1c25b1750eca9Miao Wang    NN_OPS_CHECK(input.type == filter.type);
260238a880d67b3aa5650ad52037fe1c25b1750eca9Miao Wang    if (input.type == OperandType::TENSOR_QUANT8_ASYMM) {
261238a880d67b3aa5650ad52037fe1c25b1750eca9Miao Wang        NN_OPS_CHECK(bias.type == OperandType::TENSOR_INT32);
262238a880d67b3aa5650ad52037fe1c25b1750eca9Miao Wang    } else {
263238a880d67b3aa5650ad52037fe1c25b1750eca9Miao Wang        NN_OPS_CHECK(input.type == bias.type);
264238a880d67b3aa5650ad52037fe1c25b1750eca9Miao Wang    }
265be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(getNumberOfDimensions(input) == 4);
266be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(getNumberOfDimensions(filter) == 4);
267be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(getNumberOfDimensions(bias) == 1);
2681b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
269be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(getSizeOfDimension(filter, 3) == getSizeOfDimension(bias, 0));
2701b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
2711b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t channels_out = getSizeOfDimension(filter, 3);
2721b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t width        = getSizeOfDimension(input, 2);
2731b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t height       = getSizeOfDimension(input, 1);
2741b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t filterWidth  = getSizeOfDimension(filter, 2);
2751b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t filterHeight = getSizeOfDimension(filter, 1);
2761b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t batches      = getSizeOfDimension(input, 0);
2771b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
2781b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t outWidth = computeOutSize(width, filterWidth, stride_width,
2791b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                                       padding_left, padding_right);
2801b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t outHeight = computeOutSize(height, filterHeight, stride_height,
2811b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                                        padding_top, padding_bottom);
2821b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
2831b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->type = input.type;
2841b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->dimensions = {batches, outHeight, outWidth, channels_out};
2851b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    return true;
2861b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang}
2871b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
2881b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
2891b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wangbool genericPoolingPrepare(const Shape& input,
2901b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                           int32_t padding_left, int32_t padding_right,
2911b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                           int32_t padding_top, int32_t padding_bottom,
2921b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                           int32_t stride_width, int32_t stride_height,
2931b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                           int32_t filter_width, int32_t filter_height,
2941b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                           Shape* output) {
295be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(getNumberOfDimensions(input) == 4);
2961b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
2971b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t batches      = getSizeOfDimension(input, 0);
2981b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t width        = getSizeOfDimension(input, 2);
2991b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t height       = getSizeOfDimension(input, 1);
3001b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t channels_out = getSizeOfDimension(input, 3);
3011b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
3021b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t outWidth = computeOutSize(width, filter_width, stride_width,
3031b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                                       padding_left, padding_right);
3041b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t outHeight = computeOutSize(height, filter_height, stride_height,
3051b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                                        padding_top, padding_bottom);
3061b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
3071b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->type = input.type;
3081b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->dimensions = {batches, outHeight, outWidth, channels_out};
3091b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    return true;
3101b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang}
3111b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
3121b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
3131b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wangbool genericActivationPrepare(const Shape& input,
3141b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                              Shape* output) {
315be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(getNumberOfDimensions(input) <= 4);
3161b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    return SetShape(input, output);
3171b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang}
3181b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
3191b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wangbool fullyConnectedPrepare(const Shape& input,
3201b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                           const Shape& weights,
3211b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                           const Shape& bias,
3221b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                           Shape* output) {
3231b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    // Check all the parameters of tensor match within themselves and match the
3241b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    // input configuration.
325238a880d67b3aa5650ad52037fe1c25b1750eca9Miao Wang    NN_OPS_CHECK(input.type == weights.type);
326238a880d67b3aa5650ad52037fe1c25b1750eca9Miao Wang    if (input.type == OperandType::TENSOR_QUANT8_ASYMM) {
327238a880d67b3aa5650ad52037fe1c25b1750eca9Miao Wang        NN_OPS_CHECK(bias.type == OperandType::TENSOR_INT32);
328238a880d67b3aa5650ad52037fe1c25b1750eca9Miao Wang    } else {
329238a880d67b3aa5650ad52037fe1c25b1750eca9Miao Wang        NN_OPS_CHECK(input.type == bias.type);
330238a880d67b3aa5650ad52037fe1c25b1750eca9Miao Wang    }
33135647da7af1f99b6abb48c4eeaec042ea9edfb4dMiao Wang    NN_OPS_CHECK(getNumberOfDimensions(input) >= 2);
3321b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t input_size = getNumberOfElements(input);
3331b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t num_units  = getSizeOfDimension(weights, 0);
3341b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t batch_size = input_size / getSizeOfDimension(weights, 1);
3351b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
336be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(getSizeOfDimension(bias, 0) == num_units);
337be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(getSizeOfDimension(weights, 1) * batch_size == input_size);
338be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(getNumberOfDimensions(weights) == 2);
3391b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
3401b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->type = input.type;
3411b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->dimensions = {batch_size, num_units};
3421b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
3431b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    return true;
3441b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang}
3451b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
3461b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wangbool concatenationPrepare(const std::vector<Shape>& inputShapes,
3471b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                          int32_t axis,
3481b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                          Shape* output) {
3491b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
3501b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    int num_inputs = inputShapes.size();
3511b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    OperandType input_type = inputShapes[0].type;
3521b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t num_dimensions = getNumberOfDimensions(inputShapes[0]);
3531b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
354be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(axis >= 0);
355be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(axis < (int32_t)num_dimensions);
3561b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
3571b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    int sum_axis = getSizeOfDimension(inputShapes[0], axis);
3581b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    for (int i = 1; i < num_inputs; ++i) {
359be2b22578baf949d7be42ba002cee94304daf53cMiao Wang        NN_OPS_CHECK(getNumberOfDimensions(inputShapes[i]) == num_dimensions);
360be2b22578baf949d7be42ba002cee94304daf53cMiao Wang        NN_OPS_CHECK(inputShapes[i].type == inputShapes[0].type);
3611b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang        if (input_type == OperandType::TENSOR_QUANT8_ASYMM) {
362be2b22578baf949d7be42ba002cee94304daf53cMiao Wang            NN_OPS_CHECK(inputShapes[0].offset == inputShapes[i].offset);
363be2b22578baf949d7be42ba002cee94304daf53cMiao Wang            NN_OPS_CHECK(inputShapes[0].scale == inputShapes[i].scale);
3641b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang        }
3651b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang        for (int d = 0; d < (int32_t)num_dimensions; ++d) {
3661b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang            if (d == axis) {
3671b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                sum_axis += getSizeOfDimension(inputShapes[i], axis);
3681b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang            } else {
369be2b22578baf949d7be42ba002cee94304daf53cMiao Wang                NN_OPS_CHECK(getSizeOfDimension(inputShapes[0], d) ==
3701b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                           getSizeOfDimension(inputShapes[i], d));
3711b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang            }
3721b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang        }
3731b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    }
3741b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
3751b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->type = input_type;
3761b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->dimensions = inputShapes[0].dimensions;
3771b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->dimensions[axis] = sum_axis;
3781b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
3791b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    if (input_type == OperandType::TENSOR_QUANT8_ASYMM) {
380be2b22578baf949d7be42ba002cee94304daf53cMiao Wang        NN_OPS_CHECK(inputShapes[0].offset == output->offset);
381be2b22578baf949d7be42ba002cee94304daf53cMiao Wang        NN_OPS_CHECK(inputShapes[0].scale == output->scale);
3821b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    }
3831b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
3841b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    return true;
3851b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang}
3861b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
3871b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
3881b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wangbool genericNormalizationPrepare(const Shape& input, Shape* output) {
389be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(getNumberOfDimensions(input) == 4);
3901b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    return SetShape(input, output);
3911b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang}
3921b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
3931b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wangbool reshapePrepare(const Shape& input,
3941b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                    const int32_t* targetDims,
3951b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                    const int32_t targetDimsSize,
3961b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                    Shape* output) {
3971b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    // Reshape allows one of the targetDims components to have the
3981b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    // special -1 value, meaning it will be calculated automatically based on the
3991b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    // input. Here we calculate what that dimension should be so that the number
4001b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    // of output elements in the same as the number of input elements.
4011b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    int32_t numInputElements = (int32_t) getNumberOfElements(input);
4021b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
4031b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    std::vector<uint32_t> outDims(targetDimsSize);
4041b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    int32_t numOutputElements = 1;
4051b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    int32_t strechDim = -1;
4061b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    for (int32_t i = 0; i < targetDimsSize; ++i) {
4071b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang        int32_t value = targetDims[i];
4081b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang        if (value == -1) {
409be2b22578baf949d7be42ba002cee94304daf53cMiao Wang            NN_OPS_CHECK(strechDim == -1);
4101b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang            strechDim = i;
4111b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang        } else {
4121b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang            numOutputElements *= value;
4131b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang            outDims[i] = (uint32_t)value;
4141b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang        }
4151b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    }
4161b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    if (strechDim != -1) {
4171b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang        int32_t strechValue = numInputElements / numOutputElements;
4181b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang        outDims[strechDim] = (uint32_t) strechValue;
4191b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang        numOutputElements *= strechValue;
4201b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    }
4211b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
422be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(numInputElements == numOutputElements);
4231b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
4241b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->type = input.type;
4251b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->dimensions = outDims;
4261b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->offset = input.offset;
4271b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->scale = input.scale;
4281b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
4291b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    return true;
4301b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang}
4311b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
4321b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wangbool resizeBilinearPrepare(const Shape& input,
4331b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                           int32_t width,
4341b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                           int32_t height,
4351b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                           Shape* output) {
436be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(getNumberOfDimensions(input) == 4);
4371b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t batches  = getSizeOfDimension(input, 0);
4381b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t channels = getSizeOfDimension(input, 3);
4391b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
4401b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->type = input.type;
4411b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->dimensions = {batches, (uint32_t)height, (uint32_t)width, channels};
4421b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
4431b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    return true;
4441b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang}
4451b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
4461b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wangbool depthToSpacePrepare(const Shape& input,
4471b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                         int32_t blockSize,
4481b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                         Shape* output) {
449be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(getNumberOfDimensions(input) == 4);
450be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(blockSize > 0);
4511b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
4521b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t batches  = getSizeOfDimension(input, 0);
4531b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t height   = getSizeOfDimension(input, 1);
4541b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t width    = getSizeOfDimension(input, 2);
4551b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t channels = getSizeOfDimension(input, 3);
4561b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
457be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(channels % (blockSize * blockSize) == 0);
4581b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->type = input.type;
4591b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->dimensions = {batches,
4601b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                          height * blockSize,
4611b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                          width * blockSize,
4621b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                          channels / (blockSize * blockSize)};
4631b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->offset = input.offset;
4641b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->scale = input.scale;
4651b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
4661b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    return true;
4671b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang}
4681b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
4691b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wangbool spaceToDepthPrepare(const Shape& input,
4701b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                         int32_t blockSize,
4711b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                         Shape* output) {
472be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(getNumberOfDimensions(input) == 4);
473be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(blockSize > 0);
4741b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
4751b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t batches  = getSizeOfDimension(input, 0);
4761b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t height   = getSizeOfDimension(input, 1);
4771b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t width    = getSizeOfDimension(input, 2);
4781b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    uint32_t channels = getSizeOfDimension(input, 3);
4791b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
480be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(height % blockSize == 0);
481be2b22578baf949d7be42ba002cee94304daf53cMiao Wang    NN_OPS_CHECK(width % blockSize == 0);
4821b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
4831b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->type = input.type;
4841b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->dimensions = {batches,
4851b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                          height / blockSize,
4861b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                          width / blockSize,
4871b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang                          channels * (blockSize * blockSize)};
4881b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->offset = input.offset;
4891b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    output->scale = input.scale;
4901b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
4911b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang    return true;
4921b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang}
4931b69ceeb5920503f18b6c6c1233b1fa481b6e634Miao Wang
4948c689bdfc391e47854ec27bad0f9d685135af253Yang Nibool embeddingLookupPrepare(const Shape &valueShape,
4958c689bdfc391e47854ec27bad0f9d685135af253Yang Ni                            const Shape &lookupShape,
4968c689bdfc391e47854ec27bad0f9d685135af253Yang Ni                            Shape *outputShape) {
4978c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    NN_OPS_CHECK(getNumberOfDimensions(valueShape) >= 2);
4988c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    NN_OPS_CHECK(getNumberOfDimensions(lookupShape) == 1);
4998c689bdfc391e47854ec27bad0f9d685135af253Yang Ni
5008c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    const uint32_t rows     = getSizeOfDimension(valueShape, 0);
5018c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    const uint32_t columns  = getSizeOfDimension(valueShape, 1);
5028c689bdfc391e47854ec27bad0f9d685135af253Yang Ni
5038c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    const uint32_t lookups  = getSizeOfDimension(lookupShape, 0);
5048c689bdfc391e47854ec27bad0f9d685135af253Yang Ni
5058c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    outputShape->type = valueShape.type;
5068c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    outputShape->dimensions = { lookups, columns };
5078c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    for (uint32_t i = 2; i < getNumberOfDimensions(valueShape); i++) {
508bee07f73a5f998a2dd6dc581e7776557c21f9684Miao Wang        outputShape->dimensions.push_back(getSizeOfDimension(valueShape, i));
5098c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    }
5108c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    outputShape->offset = valueShape.offset;
5118c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    outputShape->scale = valueShape.scale;
5128c689bdfc391e47854ec27bad0f9d685135af253Yang Ni
5138c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    return true;
5148c689bdfc391e47854ec27bad0f9d685135af253Yang Ni}
5158c689bdfc391e47854ec27bad0f9d685135af253Yang Ni
5168c689bdfc391e47854ec27bad0f9d685135af253Yang Nibool hashtableLookupPrepare(const Shape &lookupShape,
5178c689bdfc391e47854ec27bad0f9d685135af253Yang Ni                            const Shape &keyShape,
5188c689bdfc391e47854ec27bad0f9d685135af253Yang Ni                            const Shape &valueShape,
5198c689bdfc391e47854ec27bad0f9d685135af253Yang Ni                            Shape *outputShape,
5208c689bdfc391e47854ec27bad0f9d685135af253Yang Ni                            Shape *hitShape) {
5218c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    NN_OPS_CHECK(getNumberOfDimensions(lookupShape) == 1);
5228c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    NN_OPS_CHECK(getNumberOfDimensions(keyShape) == 1);
5238c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    NN_OPS_CHECK(getNumberOfDimensions(valueShape) >= 1);
5248c689bdfc391e47854ec27bad0f9d685135af253Yang Ni
5258c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    const uint32_t lookups  = getSizeOfDimension(lookupShape, 0);
5268c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    const uint32_t keys     = getSizeOfDimension(keyShape, 0);
5278c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    const uint32_t rows     = getSizeOfDimension(valueShape, 0);
5288c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    outputShape->type = valueShape.type;
5298c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    outputShape->dimensions = { lookups };
5308c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    for (uint32_t i = 1; i < getNumberOfDimensions(valueShape); i++) {
531bee07f73a5f998a2dd6dc581e7776557c21f9684Miao Wang        outputShape->dimensions.push_back(getSizeOfDimension(valueShape, i));
5328c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    }
5338c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    outputShape->offset = valueShape.offset;
5348c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    outputShape->scale = valueShape.scale;
5358c689bdfc391e47854ec27bad0f9d685135af253Yang Ni
5368c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    hitShape->type = OperandType::TENSOR_QUANT8_ASYMM;
5378c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    hitShape->dimensions = { lookups };
5388c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    hitShape->offset = 0;
5398c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    hitShape->scale = 1.f;
5408c689bdfc391e47854ec27bad0f9d685135af253Yang Ni
5418c689bdfc391e47854ec27bad0f9d685135af253Yang Ni    return true;
5428c689bdfc391e47854ec27bad0f9d685135af253Yang Ni}
5438c689bdfc391e47854ec27bad0f9d685135af253Yang Ni
544707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet} // namespace nn
545707dbd2d55f5dacf78ffb3ad7c8b3f37c2e9d758Jean-Luc Brouillet} // namespace android
546