10b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
20b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
30b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleLicensed under the Apache License, Version 2.0 (the "License");
40b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleyou may not use this file except in compliance with the License.
50b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleYou may obtain a copy of the License at
60b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
70b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    http://www.apache.org/licenses/LICENSE-2.0
80b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
90b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleUnless required by applicable law or agreed to in writing, software
100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selledistributed under the License is distributed on an "AS IS" BASIS,
110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleSee the License for the specific language governing permissions and
130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellelimitations under the License.
140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle==============================================================================*/
150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <memory>
160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <string>
170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <unordered_map>
180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <vector>
190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/model.h"
220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/runtime/types.h"
230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/tooling_util.h"
240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/core/platform/logging.h"
250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace toco {
270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) {
290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto bn_it = model->operators.begin() + op_index;
300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (bn_it->get()->type != OperatorType::kBatchNormalization) {
310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return false;
320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto* bn_op =
340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      static_cast<const BatchNormalizationOperator*>(bn_it->get());
350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto& mean_array = model->GetArray(bn_op->inputs[1]);
370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto& multiplier_array = model->GetArray(bn_op->inputs[2]);
380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto& offset_array = model->GetArray(bn_op->inputs[3]);
390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  CHECK(IsConstantParameterArray(*model, bn_op->inputs[1]) &&
410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        IsConstantParameterArray(*model, bn_op->inputs[2]) &&
420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        IsConstantParameterArray(*model, bn_op->inputs[3]))
430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      << "Batch normalization resolution requires that mean, multiplier and "
440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle         "offset arrays be constant.";
450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // We should only have *float* BatchNormalizations... let's guard this
470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // assumption by CHECK's.
480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  CHECK(mean_array.data_type == ArrayDataType::kFloat);
490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  CHECK(multiplier_array.data_type == ArrayDataType::kFloat);
500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  CHECK(offset_array.data_type == ArrayDataType::kFloat);
510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Create the new Mul, Add operators
530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto* mul_op = new MulOperator;
540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto* add_op = new AddOperator;
550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const string mul_name =
560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      AvailableArrayName(*model, bn_op->outputs[0] + "_mul");
570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const string add_name =
580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      AvailableArrayName(*model, bn_op->outputs[0] + "_add");
590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const string mul_param_name = AvailableArrayName(*model, mul_name + "_param");
600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const string add_param_name = AvailableArrayName(*model, add_name + "_param");
610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  mul_op->inputs = {bn_op->inputs[0], mul_param_name};
620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  mul_op->outputs = {mul_name};
630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  add_op->inputs = {mul_name, add_param_name};
640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  add_op->outputs = {bn_op->outputs[0]};
650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  AddMessageF("Splitting %s into %s and %s", LogName(*bn_op), LogName(*mul_op),
660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle              LogName(*add_op));
670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Create the intermediate activation array (output of mul, input of add)
690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto& intermediate_array = model->GetOrCreateArray(mul_op->outputs[0]);
700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  intermediate_array.data_type = model->GetArray(bn_op->inputs[0]).data_type;
710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Insert the new operators in the graph
730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto add_it = model->operators.emplace(bn_it, add_op);
740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto mul_it = model->operators.emplace(add_it, mul_op);
750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // update invalidated iterators.
760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  DCHECK_EQ(mul_it->get(), mul_op);
770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  add_it = mul_it + 1;
780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  DCHECK_EQ(add_it->get(), add_op);
790b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  bn_it = add_it + 1;
800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  DCHECK_EQ(bn_it->get(), bn_op);
810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
820b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Create the new param arrays
830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto& mean_shape = mean_array.shape();
840b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto& multiplier_shape = multiplier_array.shape();
850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto& offset_shape = offset_array.shape();
860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  CHECK(mean_shape.dims() == multiplier_shape.dims());
870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  CHECK(mean_shape.dims() == offset_shape.dims());
880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto& param_shape = mean_shape;
890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const int buffer_size = RequiredBufferSizeForShape(param_shape);
900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto& mul_param_array = model->GetOrCreateArray(mul_param_name);
910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto& add_param_array = model->GetOrCreateArray(add_param_name);
920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  DropMinMax(model, mul_param_name);
930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  DropMinMax(model, add_param_name);
940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  mul_param_array.copy_shape(param_shape);
950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  add_param_array.copy_shape(param_shape);
960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  mul_param_array.data_type = ArrayDataType::kFloat;
970b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  add_param_array.data_type = ArrayDataType::kFloat;
980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto& mul_float_data =
990b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      mul_param_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
1000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  auto& add_float_data =
1010b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      add_param_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
1020b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  mul_float_data.resize(buffer_size);
1030b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  add_float_data.resize(buffer_size);
1040b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto& mean_float_data =
1050b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      mean_array.GetBuffer<ArrayDataType::kFloat>().data;
1060b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto& multiplier_float_data =
1070b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      multiplier_array.GetBuffer<ArrayDataType::kFloat>().data;
1080b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  const auto& offset_float_data =
1090b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      offset_array.GetBuffer<ArrayDataType::kFloat>().data;
1100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  CHECK(mul_float_data.size() == buffer_size);
1120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  CHECK(add_float_data.size() == buffer_size);
1130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  CHECK(mean_float_data.size() == buffer_size);
1140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  CHECK(multiplier_float_data.size() == buffer_size);
1150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  CHECK(offset_float_data.size() == buffer_size);
1160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  for (int i = 0; i < buffer_size; i++) {
1180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    mul_float_data[i] = multiplier_float_data[i];
1190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    add_float_data[i] =
1200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        offset_float_data[i] - mean_float_data[i] * multiplier_float_data[i];
1210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Remove the old param arrays
124ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower  model->EraseArray(bn_op->inputs[1]);
125ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower  model->EraseArray(bn_op->inputs[2]);
126ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower  model->EraseArray(bn_op->inputs[3]);
1270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Remove the old operator
1290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  DCHECK_EQ(bn_it->get(), bn_op);
1300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  model->operators.erase(bn_it);
1310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  return true;
1330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
1340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}  // namespace toco
136