/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <memory>
160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <string>
170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <unordered_map>
180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <vector>
190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/model.h"
220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/tooling_util.h"
230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/core/platform/logging.h"
240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace toco {
260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace {
280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellestd::vector<std::unique_ptr<Operator>>::iterator FindOperator(
300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    Model* model, const Operator* op) {
310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto it = model->operators.begin();
320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  for (; it != model->operators.end(); ++it) {
330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    if (it->get() == op) {
340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      break;
350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    }
360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  return it;
380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool CheckArrayIsScalarFloat(Model* model, const std::string& name, float val) {
410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto& op_array = model->GetArray(name);
420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (!op_array.buffer || op_array.buffer->type != ArrayDataType::kFloat ||
430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      RequiredBufferSizeForShape(op_array.shape()) != 1) {
440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto& op_data = op_array.GetBuffer<ArrayDataType::kFloat>().data;
470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  return op_data[0] == val;
480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// Returns index of scalar input when there is exactly one scalar, -1 otherwise
510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleint GetSingleScalarInputIndexOfBinaryOp(Model* model, const Operator* op,
520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle                                        float val) {
530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  bool input0_is_scalar = CheckArrayIsScalarFloat(model, op->inputs[0], val);
540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  bool input1_is_scalar = CheckArrayIsScalarFloat(model, op->inputs[1], val);
550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  return input0_is_scalar == input1_is_scalar ? -1 : input0_is_scalar ? 0 : 1;
560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}  // namespace
580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
60dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  // Follow sequences of min+max and max+min. First get the leading op.
61dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  const auto op_it = model->operators.begin() + op_index;
62dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  const auto* op_0 = op_it->get();
63dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  if (op_0->type != OperatorType::kTensorFlowMinimum &&
64dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower      op_0->type != OperatorType::kTensorFlowMaximum) {
650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
67dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower
68dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  // Get the paired op and ensure it's the counter to the first.
69dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  const auto* op_1 = GetOpWithInput(*model, op_0->outputs[0]);
70dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  if (!op_1 ||
71dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower      (op_1->type != OperatorType::kTensorFlowMinimum &&
72dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower       op_1->type != OperatorType::kTensorFlowMaximum) ||
73dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower      op_0->type == op_1->type) {
740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
76dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower
77dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  const auto* min_op =
78dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower      op_0->type == OperatorType::kTensorFlowMinimum ? op_0 : op_1;
79dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  const auto* max_op =
80dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower      op_0->type == OperatorType::kTensorFlowMaximum ? op_0 : op_1;
81dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower
82dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  CHECK_EQ(min_op->inputs.size(), 2);
83dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  CHECK_EQ(max_op->inputs.size(), 2);
84dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  if (min_op->outputs.size() != 1 || max_op->outputs.size() != 1) {
850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
87dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower
88dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  // Get the original input to the min+max pair.
89dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  int min_scalar_input_index =
90dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower      GetSingleScalarInputIndexOfBinaryOp(model, min_op, 1.0f);
91dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  int max_scalar_input_index =
92dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower      GetSingleScalarInputIndexOfBinaryOp(model, max_op, -1.0f);
93dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  if (min_scalar_input_index == -1 || max_scalar_input_index == -1) {
940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
96dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  int op_0_scalar_input_index =
97dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower      op_0 == min_op ? min_scalar_input_index : max_scalar_input_index;
980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
99dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  // Create and emplace Relu1 node.
1000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto* relu1_op = new Relu1Operator;
101dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  relu1_op->inputs = {op_0->inputs[!op_0_scalar_input_index]};
102dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  relu1_op->outputs = op_1->outputs;
103dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  model->operators.emplace(op_it, relu1_op);
1040b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1050b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  AddMessageF("Creating %s replacing equivalent subgraph", LogName(*relu1_op));
1060b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
107dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  // Erase op scalar inputs & operators. Note that we preserve the non-scalar
108dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  // input to the first op as that's been redirected to the relu1_op.
109dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  DeleteArrayIfUsedOnce(op_0->inputs[op_0_scalar_input_index], model);
110dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  DeleteArrayIfUsedOnce(op_1->inputs[0], model);
111dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  DeleteArrayIfUsedOnce(op_1->inputs[1], model);
112dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  model->operators.erase(FindOperator(model, op_0));
113dd5dab7c45e0089af81cfe2955bf3e24a7917771A. Unique TensorFlower  model->operators.erase(FindOperator(model, op_1));
1140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  return true;
1160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
1170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}  // namespace toco
119