1/* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17// Contains the implementation of the operations. 18 19#define LOG_TAG "Operations" 20 21#include "Operations.h" 22#include "CpuOperationUtils.h" 23 24#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" 25 26namespace android { 27namespace nn { 28 29bool stridedSliceGeneric(const uint8_t* inputData, const Shape& inputShape, 30 const int32_t* beginData, const int32_t* endData, 31 const int32_t* stridesData, 32 int32_t beginMask, int32_t endMask, int32_t shrinkAxisMask, 33 uint8_t* outputData, const Shape& outputShape) { 34 // This Op only supports 1-4D cases and since we use the reference 4D 35 // implementation, the 1-3D tensors are mapped to 4D. 36 const int kMaxDim = 4; 37 38 std::vector<int> starts; 39 std::vector<int> stops; 40 std::vector<int> strides; 41 42 int32_t numInputDims = static_cast<int32_t>(getNumberOfDimensions(inputShape)); 43 for (int32_t idx = numInputDims - 1; idx >= 0; --idx) { 44 int32_t dim = static_cast<int32_t>(getSizeOfDimension(inputShape, idx)); 45 int32_t stride = stridesData[idx]; 46 // stride value has to be non-zero 47 NN_OPS_CHECK(stride != 0); 48 bool positiveStride = stride > 0; 49 50 int32_t begin = beginMask & (1 << idx) 51 ? positiveStride ? 0 : dim - 1 52 : ClampedIndex(beginData[idx], dim, positiveStride); 53 int32_t end = endMask & (1 << idx) 54 ? positiveStride ? dim : -1 55 : ClampedIndex(endData[idx], dim, positiveStride); 56 57 starts.emplace_back(begin); 58 stops.emplace_back(end); 59 strides.emplace_back(stride); 60 } 61 62 for (int i = numInputDims; i < kMaxDim; i++) { 63 starts.emplace_back(0); 64 stops.emplace_back(1); 65 strides.emplace_back(1); 66 } 67 68 beginMask = ReverseMaskBits(beginMask, numInputDims); 69 endMask = ReverseMaskBits(endMask, numInputDims); 70 shrinkAxisMask = ReverseMaskBits(shrinkAxisMask, numInputDims); 71 72 if (inputShape.type == OperandType::TENSOR_FLOAT32) { 73 tflite::reference_ops::StridedSlice( 74 reinterpret_cast<const float*>(inputData), 75 convertShapeToDims(inputShape), 76 beginMask, endMask, shrinkAxisMask, 77 starts, stops, strides, 78 reinterpret_cast<float*>(outputData), 79 convertShapeToDims(outputShape)); 80 } else if (inputShape.type == OperandType::TENSOR_QUANT8_ASYMM) { 81 tflite::reference_ops::StridedSlice( 82 reinterpret_cast<const uint8_t*>(inputData), 83 convertShapeToDims(inputShape), 84 beginMask, endMask, shrinkAxisMask, 85 starts, stops, strides, 86 reinterpret_cast<uint8_t*>(outputData), 87 convertShapeToDims(outputShape)); 88 } else { 89 LOG(ERROR) << "Unsupported data type"; 90 return false; 91 } 92 93 return true; 94} 95 96} // namespace nn 97} // namespace android 98