/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle==============================================================================*/ 150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <memory> 160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <string> 170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <unordered_map> 180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <vector> 190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" 210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/model.h" 220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/tooling_util.h" 230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/core/platform/logging.h" 240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace toco { 260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace { 280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellestd::vector<std::unique_ptr<Operator>>::iterator FindOperator( 300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle Model* model, const Operator* op) { 310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle auto it = model->operators.begin(); 320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle for (; it != model->operators.end(); ++it) { 330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (it->get() == op) { 340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle break; 350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return it; 
380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool CheckArrayIsScalarFloat(Model* model, const std::string& name, float val) { 410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const auto& op_array = model->GetArray(name); 420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (!op_array.buffer || op_array.buffer->type != ArrayDataType::kFloat || 430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle RequiredBufferSizeForShape(op_array.shape()) != 1) { 440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const auto& op_data = op_array.GetBuffer<ArrayDataType::kFloat>().data; 470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return op_data[0] == val; 480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// Returns index of scalar input when there is exactly one scalar, -1 otherwise 510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleint GetSingleScalarInputIndexOfBinaryOp(Model* model, const Operator* op, 520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle float val) { 530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle bool input0_is_scalar = CheckArrayIsScalarFloat(model, op->inputs[0], val); 540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle bool input1_is_scalar = CheckArrayIsScalarFloat(model, op->inputs[1], val); 550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return input0_is_scalar == input1_is_scalar ? -1 : input0_is_scalar ? 
0 : 1; 560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} // namespace 580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool IdentifyRelu1::Run(Model* model, std::size_t op_index) { 60dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower // Follow sequences of min+max and max+min. First get the leading op. 61dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower const auto op_it = model->operators.begin() + op_index; 62dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower const auto* op_0 = op_it->get(); 63dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower if (op_0->type != OperatorType::kTensorFlowMinimum && 64dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower op_0->type != OperatorType::kTensorFlowMaximum) { 650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 67dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower 68dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower // Get the paired op and ensure it's the counter to the first. 69dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower const auto* op_1 = GetOpWithInput(*model, op_0->outputs[0]); 70dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower if (!op_1 || 71dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower (op_1->type != OperatorType::kTensorFlowMinimum && 72dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower op_1->type != OperatorType::kTensorFlowMaximum) || 73dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower op_0->type == op_1->type) { 740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 76dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower 77dd5dab7c45e0089af81cfe2955bf3e24a7917771A. 
Unique TensorFlower const auto* min_op = 78dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower op_0->type == OperatorType::kTensorFlowMinimum ? op_0 : op_1; 79dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower const auto* max_op = 80dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower op_0->type == OperatorType::kTensorFlowMaximum ? op_0 : op_1; 81dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower 82dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower CHECK_EQ(min_op->inputs.size(), 2); 83dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower CHECK_EQ(max_op->inputs.size(), 2); 84dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower if (min_op->outputs.size() != 1 || max_op->outputs.size() != 1) { 850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 87dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower 88dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower // Get the original input to the min+max pair. 89dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower int min_scalar_input_index = 90dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower GetSingleScalarInputIndexOfBinaryOp(model, min_op, 1.0f); 91dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower int max_scalar_input_index = 92dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower GetSingleScalarInputIndexOfBinaryOp(model, max_op, -1.0f); 93dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower if (min_scalar_input_index == -1 || max_scalar_input_index == -1) { 940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 96dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower int op_0_scalar_input_index = 97dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower op_0 == min_op ? 
min_scalar_input_index : max_scalar_input_index; 980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 99dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower // Create and emplace Relu1 node. 1000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle auto* relu1_op = new Relu1Operator; 101dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower relu1_op->inputs = {op_0->inputs[!op_0_scalar_input_index]}; 102dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower relu1_op->outputs = op_1->outputs; 103dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower model->operators.emplace(op_it, relu1_op); 1040b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1050b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle AddMessageF("Creating %s replacing equivalent subgraph", LogName(*relu1_op)); 1060b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 107dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower // Erase op scalar inputs & operators. Note that we preserve the non-scalar 108dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower // input to the first op as that's been redirected to the relu1_op. 109dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower DeleteArrayIfUsedOnce(op_0->inputs[op_0_scalar_input_index], model); 110dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower DeleteArrayIfUsedOnce(op_1->inputs[0], model); 111dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower DeleteArrayIfUsedOnce(op_1->inputs[1], model); 112dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower model->operators.erase(FindOperator(model, op_0)); 113dd5dab7c45e0089af81cfe2955bf3e24a7917771A. 
Unique TensorFlower model->operators.erase(FindOperator(model, op_1)); 1140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return true; 1160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 1170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} // namespace toco 119