10b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 20b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 30b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleLicensed under the Apache License, Version 2.0 (the "License"); 40b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleyou may not use this file except in compliance with the License. 50b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleYou may obtain a copy of the License at 60b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 70b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle http://www.apache.org/licenses/LICENSE-2.0 80b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 90b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleUnless required by applicable law or agreed to in writing, software 100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selledistributed under the License is distributed on an "AS IS" BASIS, 110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleSee the License for the specific language governing permissions and 130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellelimitations under the License. 140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle==============================================================================*/ 150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <algorithm> 160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <memory> 170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <string> 180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <unordered_map> 190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <vector> 200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" 220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/model.h" 230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/tooling_util.h" 240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/core/platform/logging.h" 250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace toco { 270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 28cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower// Reorder the elements of an input_array according to the input_axes_order and 29cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower// output_axes_order. Then adjust the shapes of the input and output arrays 30cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower// accordingly. Note that input_array must have a buffer (that is, it is a 31cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower// constant array). 32cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlowertemplate <typename T, ArrayDataType DataType> 33cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlowervoid ReorderAxes(AxesOrder input_axes_order, AxesOrder output_axes_order, 34cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower Array* input_array, Array* output_array) { 35cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower CHECK(input_array->buffer->type == DataType); 36cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower CHECK(!output_array->buffer); 37cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower auto& input_data = input_array->GetMutableBuffer<DataType>().data; 38cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower std::vector<T> reordered_data; 39cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower reordered_data.resize(RequiredBufferSizeForShape(output_array->shape())); 40cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower // TODO(b/62904716) Shapes should be used directly. 41cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower Shape input_shape = input_array->shape(); 42cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower Shape output_shape = output_array->shape(); 43cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower if (AxesCount(input_axes_order) == 2) { 44cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower UnextendShape(&input_shape, 2); 45cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower UnextendShape(&output_shape, 2); 46cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower } 47cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower ShuffleArray(input_shape, input_axes_order, output_axes_order, output_shape, 48cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower input_data.data(), reordered_data.data()); 49cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower input_data = reordered_data; 50cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower input_array->copy_shape(output_array->shape()); 51cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower} 52cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower 530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool ResolveReorderAxes::Run(Model* model, std::size_t op_index) { 540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle auto reorder_it = model->operators.begin() + op_index; 550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle auto* reorder_op = static_cast<ReorderAxesOperator*>(reorder_it->get()); 560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (reorder_op->type != OperatorType::kReorderAxes) { 570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const auto& input_array_name = reorder_op->inputs[0]; 600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const auto& output_array_name = reorder_op->outputs[0]; 610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle auto& input_array = model->GetArray(input_array_name); 620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle auto& output_array = model->GetArray(output_array_name); 630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (!input_array.buffer) { 640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Yield until output dims have been resolved. 670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (!output_array.has_shape()) { 680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Reorder the input array dims and buffer data 7160d5caeb2d506401d480503e21cc97c9a784c81bA. Unique TensorFlower if (input_array.buffer->type == ArrayDataType::kFloat) { 7260d5caeb2d506401d480503e21cc97c9a784c81bA. Unique TensorFlower ReorderAxes<float, ArrayDataType::kFloat>(reorder_op->input_axes_order, 7360d5caeb2d506401d480503e21cc97c9a784c81bA. Unique TensorFlower reorder_op->output_axes_order, 7460d5caeb2d506401d480503e21cc97c9a784c81bA. Unique TensorFlower &input_array, &output_array); 7560d5caeb2d506401d480503e21cc97c9a784c81bA. Unique TensorFlower } else if (input_array.buffer->type == ArrayDataType::kInt32) { 7660d5caeb2d506401d480503e21cc97c9a784c81bA. Unique TensorFlower ReorderAxes<uint8, ArrayDataType::kUint8>(reorder_op->input_axes_order, 7760d5caeb2d506401d480503e21cc97c9a784c81bA. Unique TensorFlower reorder_op->output_axes_order, 7860d5caeb2d506401d480503e21cc97c9a784c81bA. Unique TensorFlower &input_array, &output_array); 79cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower } else { 80cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower LOG(FATAL) << "Cannot ReorderAxes unless input buffer is float or uint8."; 810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 82cf327e8560fc044ab37e6a766c852e7b6546f228A. Unique TensorFlower 830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle input_array.copy_shape(output_array.shape()); 840b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Update the edges of the graph to point to the input array 860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle for (const auto& other_op : model->operators) { 870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle for (auto& input : other_op->inputs) { 880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (input == output_array_name) { 890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle input = input_array_name; 900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle AddMessageF("Reordered axes for array %s", input_array_name); 950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Remove the op and output array. 97ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower model->EraseArray(output_array_name); 980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->operators.erase(reorder_it); 990b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return true; 1000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 1010b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1020b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} // namespace toco 103