10b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
20b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
30b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleLicensed under the Apache License, Version 2.0 (the "License");
40b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleyou may not use this file except in compliance with the License.
50b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleYou may obtain a copy of the License at
60b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
70b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    http://www.apache.org/licenses/LICENSE-2.0
80b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
90b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleUnless required by applicable law or agreed to in writing, software
100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selledistributed under the License is distributed on an "AS IS" BASIS,
110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleSee the License for the specific language governing permissions and
130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellelimitations under the License.
140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle==============================================================================*/
150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <algorithm>
160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <memory>
170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <string>
180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <unordered_map>
190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <vector>
200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/model.h"
230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/tooling_util.h"
240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/core/platform/logging.h"
250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace toco {
270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
28cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower// Reorder the elements of an input_array according to the input_axes_order and
29cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower// output_axes_order. Then adjust the shapes of the input and output arrays
30cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower// accordingly. Note that input_array must have a buffer (that is, it is a
31cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower// constant array).
32cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlowertemplate <typename T, ArrayDataType DataType>
33cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlowervoid ReorderAxes(AxesOrder input_axes_order, AxesOrder output_axes_order,
34cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower                 Array* input_array, Array* output_array) {
35cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower  CHECK(input_array->buffer->type == DataType);
36cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower  CHECK(!output_array->buffer);
37cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower  auto& input_data = input_array->GetMutableBuffer<DataType>().data;
38cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower  std::vector<T> reordered_data;
39cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower  reordered_data.resize(RequiredBufferSizeForShape(output_array->shape()));
40cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower  // TODO(b/62904716) Shapes should be used directly.
41cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower  Shape input_shape = input_array->shape();
42cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower  Shape output_shape = output_array->shape();
43cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower  if (AxesCount(input_axes_order) == 2) {
44cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower    UnextendShape(&input_shape, 2);
45cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower    UnextendShape(&output_shape, 2);
46cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower  }
47cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower  ShuffleArray(input_shape, input_axes_order, output_axes_order, output_shape,
48cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower               input_data.data(), reordered_data.data());
49cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower  input_data = reordered_data;
50cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower  input_array->copy_shape(output_array->shape());
51cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower}
52cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower
530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto reorder_it = model->operators.begin() + op_index;
550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto* reorder_op = static_cast<ReorderAxesOperator*>(reorder_it->get());
560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (reorder_op->type != OperatorType::kReorderAxes) {
570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto& input_array_name = reorder_op->inputs[0];
600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto& output_array_name = reorder_op->outputs[0];
610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto& input_array = model->GetArray(input_array_name);
620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto& output_array = model->GetArray(output_array_name);
630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (!input_array.buffer) {
640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Yield until output dims have been resolved.
670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (!output_array.has_shape()) {
680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Reorder the input array dims and buffer data
7160d5caeb2d506401d480503e21cc97c9a784c81bA. Unique TensorFlower  if (input_array.buffer->type == ArrayDataType::kFloat) {
7260d5caeb2d506401d480503e21cc97c9a784c81bA. Unique TensorFlower    ReorderAxes<float, ArrayDataType::kFloat>(reorder_op->input_axes_order,
7360d5caeb2d506401d480503e21cc97c9a784c81bA. Unique TensorFlower                                              reorder_op->output_axes_order,
7460d5caeb2d506401d480503e21cc97c9a784c81bA. Unique TensorFlower                                              &input_array, &output_array);
7560d5caeb2d506401d480503e21cc97c9a784c81bA. Unique TensorFlower  } else if (input_array.buffer->type == ArrayDataType::kInt32) {
7660d5caeb2d506401d480503e21cc97c9a784c81bA. Unique TensorFlower    ReorderAxes<uint8, ArrayDataType::kUint8>(reorder_op->input_axes_order,
7760d5caeb2d506401d480503e21cc97c9a784c81bA. Unique TensorFlower                                              reorder_op->output_axes_order,
7860d5caeb2d506401d480503e21cc97c9a784c81bA. Unique TensorFlower                                              &input_array, &output_array);
79cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower  } else {
80cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower    LOG(FATAL) << "Cannot ReorderAxes unless input buffer is float or uint8.";
810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
82cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower
830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  input_array.copy_shape(output_array.shape());
840b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Update the edges of the graph to point to the input array
860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  for (const auto& other_op : model->operators) {
870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    for (auto& input : other_op->inputs) {
880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      if (input == output_array_name) {
890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        input = input_array_name;
900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      }
910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    }
920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  AddMessageF("Reordered axes for array %s", input_array_name);
950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Remove the op and output array.
97ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower  model->EraseArray(output_array_name);
980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->operators.erase(reorder_it);
990b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  return true;
1000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
1010b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1020b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}  // namespace toco
103