1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <string.h>
16#include <vector>
17#include "tensorflow/contrib/lite/builtin_op_data.h"
18#include "tensorflow/contrib/lite/context.h"
19#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
20#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
21#include "tensorflow/contrib/lite/kernels/kernel_util.h"
22#include "tensorflow/contrib/lite/kernels/op_macros.h"
23
24namespace tflite {
25namespace ops {
26namespace builtin {
27namespace transpose {
28
// Selects which Transpose implementation an Eval instantiation dispatches to.
// Only the reference implementation is defined in this file.
enum KernelType {
  kReference = 0,
};
33
34struct TransposeContext {
35  TransposeContext(TfLiteContext* context, TfLiteNode* node) {
36    input = GetInput(context, node, 0);
37    perm = GetInput(context, node, 1);
38    output = GetOutput(context, node, 0);
39  }
40  TfLiteTensor* input;
41  TfLiteTensor* perm;
42  TfLiteTensor* output;
43};
44
45TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
46                                TransposeContext* op_context) {
47  int dims = NumDimensions(op_context->input);
48  const int* perm_data = GetTensorData<int32_t>(op_context->perm);
49
50  // Ensure validity of the permutations tensor as a 1D tensor.
51  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->perm), 1);
52  TF_LITE_ENSURE_EQ(context, op_context->perm->dims->data[0], dims);
53  for (int idx = 0; idx < dims; ++idx) {
54    TF_LITE_ENSURE_MSG(context, (perm_data[idx] >= 0 && perm_data[idx] < dims),
55                       "Transpose op permutations array is out of bounds.");
56  }
57
58  // Determine size of output tensor.
59  TfLiteIntArray* input_size = op_context->input->dims;
60  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
61  for (int idx = 0; idx < dims; ++idx) {
62    output_size->data[idx] = input_size->data[perm_data[idx]];
63  }
64
65  return context->ResizeTensor(context, op_context->output, output_size);
66}
67
68TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
69  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
70  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
71
72  TransposeContext op_context(context, node);
73
74  // Ensure validity of input tensor.
75  TF_LITE_ENSURE_MSG(context, NumDimensions(op_context.input) <= 4,
76                     "Transpose op only supports 1D-4D input arrays.");
77  TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
78
79  if (!IsConstantTensor(op_context.perm)) {
80    SetTensorToDynamic(op_context.output);
81    return kTfLiteOk;
82  }
83  return ResizeOutputTensor(context, &op_context);
84}
85
86template <KernelType kernel_type>
87TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
88  TransposeContext op_context(context, node);
89
90  // Resize the output tensor if the output tensor is dynamic.
91  if (IsDynamicTensor(op_context.output)) {
92    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
93  }
94
95  // Reverse the permuted axes and convert to 4D due to the way Dims are
96  // constructed in GetTensorDims.
97  const int* perm_data = GetTensorData<int32_t>(op_context.perm);
98  const int size = op_context.perm->dims->data[0];
99  const int kOutputDimensionNum = 4;
100  int reversed_perm[kOutputDimensionNum];
101
102  for (int output_k = 0, input_k = size - 1; output_k < size;
103       ++output_k, --input_k) {
104    reversed_perm[output_k] = size - perm_data[input_k] - 1;
105  }
106  for (int k = size; k < kOutputDimensionNum; ++k) {
107    reversed_perm[k] = k;
108  }
109
110#define TF_LITE_TRANSPOSE(type, scalar)                     \
111  type::Transpose(GetTensorData<scalar>(op_context.input),  \
112                  GetTensorDims(op_context.input),          \
113                  GetTensorData<scalar>(op_context.output), \
114                  GetTensorDims(op_context.output), reversed_perm)
115
116  switch (op_context.input->type) {
117    case kTfLiteFloat32:
118      if (kernel_type == kReference) {
119        TF_LITE_TRANSPOSE(reference_ops, float);
120      }
121      break;
122    case kTfLiteUInt8:
123      if (kernel_type == kReference) {
124        TF_LITE_TRANSPOSE(reference_ops, uint8_t);
125      }
126      break;
127    case kTfLiteInt32:
128      if (kernel_type == kReference) {
129        TF_LITE_TRANSPOSE(reference_ops, int32_t);
130      }
131      break;
132    case kTfLiteInt64:
133      if (kernel_type == kReference) {
134        TF_LITE_TRANSPOSE(reference_ops, int64_t);
135      }
136      break;
137    default:
138      context->ReportError(context,
139                           "Type is currently not supported by Transpose.");
140      return kTfLiteError;
141  }
142#undef TF_LITE_TRANSPOSE
143
144  return kTfLiteOk;
145}
146
147}  // namespace transpose
148
149TfLiteRegistration* Register_TRANSPOSE_REF() {
150  static TfLiteRegistration r = {nullptr, nullptr, transpose::Prepare,
151                                 transpose::Eval<transpose::kReference>};
152  return &r;
153}
154
155TfLiteRegistration* Register_TRANSPOSE() { return Register_TRANSPOSE_REF(); }
156
157}  // namespace builtin
158}  // namespace ops
159}  // namespace tflite
160