/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle==============================================================================*/ 150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <memory> 160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <string> 170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <unordered_map> 180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include <vector> 190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" 210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/model.h" 220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/runtime/types.h" 230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/toco/tooling_util.h" 240b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/core/platform/logging.h" 250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace toco { 270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellebool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) { 290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle auto bn_it = model->operators.begin() + op_index; 300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle if (bn_it->get()->type != OperatorType::kBatchNormalization) { 310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return false; 320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const auto* bn_op = 340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle static_cast<const BatchNormalizationOperator*>(bn_it->get()); 350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew 
Selle const auto& mean_array = model->GetArray(bn_op->inputs[1]); 370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const auto& multiplier_array = model->GetArray(bn_op->inputs[2]); 380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const auto& offset_array = model->GetArray(bn_op->inputs[3]); 390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle CHECK(IsConstantParameterArray(*model, bn_op->inputs[1]) && 410b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle IsConstantParameterArray(*model, bn_op->inputs[2]) && 420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle IsConstantParameterArray(*model, bn_op->inputs[3])) 430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle << "Batch normalization resolution requires that mean, multiplier and " 440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle "offset arrays be constant."; 450b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 460b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // We should only have *float* BatchNormalizations... let's guard this 470b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // assumption by CHECK's. 
480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle CHECK(mean_array.data_type == ArrayDataType::kFloat); 490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle CHECK(multiplier_array.data_type == ArrayDataType::kFloat); 500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle CHECK(offset_array.data_type == ArrayDataType::kFloat); 510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Create the new Mul, Add operators 530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle auto* mul_op = new MulOperator; 540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle auto* add_op = new AddOperator; 550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const string mul_name = 560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle AvailableArrayName(*model, bn_op->outputs[0] + "_mul"); 570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const string add_name = 580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle AvailableArrayName(*model, bn_op->outputs[0] + "_add"); 590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const string mul_param_name = AvailableArrayName(*model, mul_name + "_param"); 600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const string add_param_name = AvailableArrayName(*model, add_name + "_param"); 610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle mul_op->inputs = {bn_op->inputs[0], mul_param_name}; 620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle mul_op->outputs = {mul_name}; 630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle add_op->inputs = {mul_name, add_param_name}; 640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle add_op->outputs = {bn_op->outputs[0]}; 650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle AddMessageF("Splitting %s into %s and %s", LogName(*bn_op), LogName(*mul_op), 660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle LogName(*add_op)); 670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 
680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Create the intermediate activation array (output of mul, input of add) 690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle auto& intermediate_array = model->GetOrCreateArray(mul_op->outputs[0]); 700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle intermediate_array.data_type = model->GetArray(bn_op->inputs[0]).data_type; 710b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 720b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Insert the new operators in the graph 730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle auto add_it = model->operators.emplace(bn_it, add_op); 740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle auto mul_it = model->operators.emplace(add_it, mul_op); 750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // update invalidated iterators. 760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle DCHECK_EQ(mul_it->get(), mul_op); 770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle add_it = mul_it + 1; 780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle DCHECK_EQ(add_it->get(), add_op); 790b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle bn_it = add_it + 1; 800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle DCHECK_EQ(bn_it->get(), bn_op); 810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 820b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Create the new param arrays 830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const auto& mean_shape = mean_array.shape(); 840b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const auto& multiplier_shape = multiplier_array.shape(); 850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const auto& offset_shape = offset_array.shape(); 860b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle CHECK(mean_shape.dims() == multiplier_shape.dims()); 870b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle CHECK(mean_shape.dims() == offset_shape.dims()); 880b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const auto& 
param_shape = mean_shape; 890b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const int buffer_size = RequiredBufferSizeForShape(param_shape); 900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle auto& mul_param_array = model->GetOrCreateArray(mul_param_name); 910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle auto& add_param_array = model->GetOrCreateArray(add_param_name); 920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle DropMinMax(model, mul_param_name); 930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle DropMinMax(model, add_param_name); 940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle mul_param_array.copy_shape(param_shape); 950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle add_param_array.copy_shape(param_shape); 960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle mul_param_array.data_type = ArrayDataType::kFloat; 970b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle add_param_array.data_type = ArrayDataType::kFloat; 980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle auto& mul_float_data = 990b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle mul_param_array.GetMutableBuffer<ArrayDataType::kFloat>().data; 1000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle auto& add_float_data = 1010b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle add_param_array.GetMutableBuffer<ArrayDataType::kFloat>().data; 1020b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle mul_float_data.resize(buffer_size); 1030b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle add_float_data.resize(buffer_size); 1040b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const auto& mean_float_data = 1050b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle mean_array.GetBuffer<ArrayDataType::kFloat>().data; 1060b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle const auto& multiplier_float_data = 1070b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle multiplier_array.GetBuffer<ArrayDataType::kFloat>().data; 1080b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew 
Selle const auto& offset_float_data = 1090b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle offset_array.GetBuffer<ArrayDataType::kFloat>().data; 1100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle CHECK(mul_float_data.size() == buffer_size); 1120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle CHECK(add_float_data.size() == buffer_size); 1130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle CHECK(mean_float_data.size() == buffer_size); 1140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle CHECK(multiplier_float_data.size() == buffer_size); 1150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle CHECK(offset_float_data.size() == buffer_size); 1160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle for (int i = 0; i < buffer_size; i++) { 1180b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle mul_float_data[i] = multiplier_float_data[i]; 1190b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle add_float_data[i] = 1200b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle offset_float_data[i] - mean_float_data[i] * multiplier_float_data[i]; 1210b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle } 1220b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1230b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Remove the old param arrays 124ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower model->EraseArray(bn_op->inputs[1]); 125ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. Unique TensorFlower model->EraseArray(bn_op->inputs[2]); 126ba4aec48268d02f111cd7e2c2666f4e7b077e68aA. 
Unique TensorFlower model->EraseArray(bn_op->inputs[3]); 1270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle // Remove the old operator 1290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle DCHECK_EQ(bn_it->get(), bn_op); 1300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle model->operators.erase(bn_it); 1310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle return true; 1330b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} 1340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle 1350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle} // namespace toco 136