1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
19
20#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
21#include "tensorflow/contrib/lite/toco/model.h"
22#include "tensorflow/contrib/lite/toco/runtime/types.h"
23#include "tensorflow/contrib/lite/toco/tooling_util.h"
24#include "tensorflow/core/platform/logging.h"
25
26namespace toco {
27
28bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) {
29  auto bn_it = model->operators.begin() + op_index;
30  if (bn_it->get()->type != OperatorType::kBatchNormalization) {
31    return false;
32  }
33  const auto* bn_op =
34      static_cast<const BatchNormalizationOperator*>(bn_it->get());
35
36  const auto& mean_array = model->GetArray(bn_op->inputs[1]);
37  const auto& multiplier_array = model->GetArray(bn_op->inputs[2]);
38  const auto& offset_array = model->GetArray(bn_op->inputs[3]);
39
40  CHECK(IsConstantParameterArray(*model, bn_op->inputs[1]) &&
41        IsConstantParameterArray(*model, bn_op->inputs[2]) &&
42        IsConstantParameterArray(*model, bn_op->inputs[3]))
43      << "Batch normalization resolution requires that mean, multiplier and "
44         "offset arrays be constant.";
45
46  // We should only have *float* BatchNormalizations... let's guard this
47  // assumption by CHECK's.
48  CHECK(mean_array.data_type == ArrayDataType::kFloat);
49  CHECK(multiplier_array.data_type == ArrayDataType::kFloat);
50  CHECK(offset_array.data_type == ArrayDataType::kFloat);
51
52  // Create the new Mul, Add operators
53  auto* mul_op = new MulOperator;
54  auto* add_op = new AddOperator;
55  const string mul_name =
56      AvailableArrayName(*model, bn_op->outputs[0] + "_mul");
57  const string add_name =
58      AvailableArrayName(*model, bn_op->outputs[0] + "_add");
59  const string mul_param_name = AvailableArrayName(*model, mul_name + "_param");
60  const string add_param_name = AvailableArrayName(*model, add_name + "_param");
61  mul_op->inputs = {bn_op->inputs[0], mul_param_name};
62  mul_op->outputs = {mul_name};
63  add_op->inputs = {mul_name, add_param_name};
64  add_op->outputs = {bn_op->outputs[0]};
65  AddMessageF("Splitting %s into %s and %s", LogName(*bn_op), LogName(*mul_op),
66              LogName(*add_op));
67
68  // Create the intermediate activation array (output of mul, input of add)
69  auto& intermediate_array = model->GetOrCreateArray(mul_op->outputs[0]);
70  intermediate_array.data_type = model->GetArray(bn_op->inputs[0]).data_type;
71
72  // Insert the new operators in the graph
73  auto add_it = model->operators.emplace(bn_it, add_op);
74  auto mul_it = model->operators.emplace(add_it, mul_op);
75  // update invalidated iterators.
76  DCHECK_EQ(mul_it->get(), mul_op);
77  add_it = mul_it + 1;
78  DCHECK_EQ(add_it->get(), add_op);
79  bn_it = add_it + 1;
80  DCHECK_EQ(bn_it->get(), bn_op);
81
82  // Create the new param arrays
83  const auto& mean_shape = mean_array.shape();
84  const auto& multiplier_shape = multiplier_array.shape();
85  const auto& offset_shape = offset_array.shape();
86  CHECK(mean_shape.dims() == multiplier_shape.dims());
87  CHECK(mean_shape.dims() == offset_shape.dims());
88  const auto& param_shape = mean_shape;
89  const int buffer_size = RequiredBufferSizeForShape(param_shape);
90  auto& mul_param_array = model->GetOrCreateArray(mul_param_name);
91  auto& add_param_array = model->GetOrCreateArray(add_param_name);
92  DropMinMax(model, mul_param_name);
93  DropMinMax(model, add_param_name);
94  mul_param_array.copy_shape(param_shape);
95  add_param_array.copy_shape(param_shape);
96  mul_param_array.data_type = ArrayDataType::kFloat;
97  add_param_array.data_type = ArrayDataType::kFloat;
98  auto& mul_float_data =
99      mul_param_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
100  auto& add_float_data =
101      add_param_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
102  mul_float_data.resize(buffer_size);
103  add_float_data.resize(buffer_size);
104  const auto& mean_float_data =
105      mean_array.GetBuffer<ArrayDataType::kFloat>().data;
106  const auto& multiplier_float_data =
107      multiplier_array.GetBuffer<ArrayDataType::kFloat>().data;
108  const auto& offset_float_data =
109      offset_array.GetBuffer<ArrayDataType::kFloat>().data;
110
111  CHECK(mul_float_data.size() == buffer_size);
112  CHECK(add_float_data.size() == buffer_size);
113  CHECK(mean_float_data.size() == buffer_size);
114  CHECK(multiplier_float_data.size() == buffer_size);
115  CHECK(offset_float_data.size() == buffer_size);
116
117  for (int i = 0; i < buffer_size; i++) {
118    mul_float_data[i] = multiplier_float_data[i];
119    add_float_data[i] =
120        offset_float_data[i] - mean_float_data[i] * multiplier_float_data[i];
121  }
122
123  // Remove the old param arrays
124  model->EraseArray(bn_op->inputs[1]);
125  model->EraseArray(bn_op->inputs[2]);
126  model->EraseArray(bn_op->inputs[3]);
127
128  // Remove the old operator
129  DCHECK_EQ(bn_it->get(), bn_op);
130  model->operators.erase(bn_it);
131
132  return true;
133}
134
135}  // namespace toco
136