1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include <string.h> 16#include <vector> 17#include "tensorflow/contrib/lite/builtin_op_data.h" 18#include "tensorflow/contrib/lite/context.h" 19#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" 20#include "tensorflow/contrib/lite/kernels/internal/tensor.h" 21#include "tensorflow/contrib/lite/kernels/kernel_util.h" 22#include "tensorflow/contrib/lite/kernels/op_macros.h" 23 24namespace tflite { 25namespace ops { 26namespace builtin { 27namespace transpose { 28 29// This file has two implementations of Transpose. 30enum KernelType { 31 kReference, 32}; 33 34struct TransposeContext { 35 TransposeContext(TfLiteContext* context, TfLiteNode* node) { 36 input = GetInput(context, node, 0); 37 perm = GetInput(context, node, 1); 38 output = GetOutput(context, node, 0); 39 } 40 TfLiteTensor* input; 41 TfLiteTensor* perm; 42 TfLiteTensor* output; 43}; 44 45TfLiteStatus ResizeOutputTensor(TfLiteContext* context, 46 TransposeContext* op_context) { 47 int dims = NumDimensions(op_context->input); 48 const int* perm_data = GetTensorData<int32_t>(op_context->perm); 49 50 // Ensure validity of the permutations tensor as a 1D tensor. 51 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->perm), 1); 52 TF_LITE_ENSURE_EQ(context, op_context->perm->dims->data[0], dims); 53 for (int idx = 0; idx < dims; ++idx) { 54 TF_LITE_ENSURE_MSG(context, (perm_data[idx] >= 0 && perm_data[idx] < dims), 55 "Transpose op permutations array is out of bounds."); 56 } 57 58 // Determine size of output tensor. 59 TfLiteIntArray* input_size = op_context->input->dims; 60 TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size); 61 for (int idx = 0; idx < dims; ++idx) { 62 output_size->data[idx] = input_size->data[perm_data[idx]]; 63 } 64 65 return context->ResizeTensor(context, op_context->output, output_size); 66} 67 68TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 69 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); 70 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); 71 72 TransposeContext op_context(context, node); 73 74 // Ensure validity of input tensor. 75 TF_LITE_ENSURE_MSG(context, NumDimensions(op_context.input) <= 4, 76 "Transpose op only supports 1D-4D input arrays."); 77 TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); 78 79 if (!IsConstantTensor(op_context.perm)) { 80 SetTensorToDynamic(op_context.output); 81 return kTfLiteOk; 82 } 83 return ResizeOutputTensor(context, &op_context); 84} 85 86template <KernelType kernel_type> 87TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 88 TransposeContext op_context(context, node); 89 90 // Resize the output tensor if the output tensor is dynamic. 91 if (IsDynamicTensor(op_context.output)) { 92 TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); 93 } 94 95 // Reverse the permuted axes and convert to 4D due to the way Dims are 96 // constructed in GetTensorDims. 97 const int* perm_data = GetTensorData<int32_t>(op_context.perm); 98 const int size = op_context.perm->dims->data[0]; 99 const int kOutputDimensionNum = 4; 100 int reversed_perm[kOutputDimensionNum]; 101 102 for (int output_k = 0, input_k = size - 1; output_k < size; 103 ++output_k, --input_k) { 104 reversed_perm[output_k] = size - perm_data[input_k] - 1; 105 } 106 for (int k = size; k < kOutputDimensionNum; ++k) { 107 reversed_perm[k] = k; 108 } 109 110#define TF_LITE_TRANSPOSE(type, scalar) \ 111 type::Transpose(GetTensorData<scalar>(op_context.input), \ 112 GetTensorDims(op_context.input), \ 113 GetTensorData<scalar>(op_context.output), \ 114 GetTensorDims(op_context.output), reversed_perm) 115 116 switch (op_context.input->type) { 117 case kTfLiteFloat32: 118 if (kernel_type == kReference) { 119 TF_LITE_TRANSPOSE(reference_ops, float); 120 } 121 break; 122 case kTfLiteUInt8: 123 if (kernel_type == kReference) { 124 TF_LITE_TRANSPOSE(reference_ops, uint8_t); 125 } 126 break; 127 case kTfLiteInt32: 128 if (kernel_type == kReference) { 129 TF_LITE_TRANSPOSE(reference_ops, int32_t); 130 } 131 break; 132 case kTfLiteInt64: 133 if (kernel_type == kReference) { 134 TF_LITE_TRANSPOSE(reference_ops, int64_t); 135 } 136 break; 137 default: 138 context->ReportError(context, 139 "Type is currently not supported by Transpose."); 140 return kTfLiteError; 141 } 142#undef TF_LITE_TRANSPOSE 143 144 return kTfLiteOk; 145} 146 147} // namespace transpose 148 149TfLiteRegistration* Register_TRANSPOSE_REF() { 150 static TfLiteRegistration r = {nullptr, nullptr, transpose::Prepare, 151 transpose::Eval<transpose::kReference>}; 152 return &r; 153} 154 155TfLiteRegistration* Register_TRANSPOSE() { return Register_TRANSPOSE_REF(); } 156 157} // namespace builtin 158} // namespace ops 159} // namespace tflite 160