// quantize_training_test.cc revision 8ba27e3bd5203bd4cd9533b36c7fd02b85a0a42a
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/graph/quantize_training.h"

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
35a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen#include "tensorflow/core/lib/strings/strcat.h" 36a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen#include "tensorflow/core/platform/test.h" 378ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar#include "tensorflow/core/public/session.h" 38a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen#include "tensorflow/core/public/session_options.h" 39a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen 40a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chennamespace tensorflow { 41a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chennamespace { 42a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen 43a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chenclass QuantizeTrainingTest : public ::testing::Test { 44a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen protected: 45a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen QuantizeTrainingTest() { Reset(); } 46a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen void Reset() { g_.reset(new Graph(OpRegistry::Global())); } 47a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen 48a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen template <typename T> 49a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen Node* Constant(gtl::ArraySlice<T> values, TensorShape shape) { 50a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen return test::graph::Constant(g_.get(), test::AsTensor(values, shape)); 51a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen } 52a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen 538ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Status Placeholder(Graph* g, const string& name, TensorShape shape, 548ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Node** out) { 558ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar TF_RETURN_IF_ERROR(NodeBuilder(name, "Placeholder") 568ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar .Attr("dtype", DT_FLOAT) 578ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar 
.Attr("shape", shape) 588ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar .Finalize(g, out)); 598ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar return Status::OK(); 608ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar } 618ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar 628ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Status FindNode(Graph* g, const string& name, Node** out) { 638ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar for (Node* node : g->nodes()) { 648ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar if (node->name() == name) { 658ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar *out = node; 668ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar return Status::OK(); 678ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar } 688ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar } 698ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar return errors::Unimplemented("Node ", name, " not found."); 708ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar } 718ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar 72a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen std::unique_ptr<Graph> g_; 73a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen}; 74a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen 758ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh SivakumarTEST_F(QuantizeTrainingTest, SignedInput) { 768ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // Test that Quantization ops are created with the correct signed_input value. 
77a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen // Construct the following graph 78a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen /* 798ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar m1 808ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar / \ 818ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Relu Identity 82a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen | | 83a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen a b 84a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen */ 85a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen Reset(); 86a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen Graph* g = g_.get(); 87a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen Node* a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2}); 88a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen Node* b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2}); 89a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen g->AddControlEdge(g->source_node(), a); 90a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen g->AddControlEdge(g->source_node(), b); 91a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen Node* relu = test::graph::Relu(g, a); 92a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen Node* identity = test::graph::Identity(g, b); 93a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen Node* m1 = test::graph::Matmul(g, relu, identity, false, false); 94a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen g->AddControlEdge(m1, g->sink_node()); 95a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen 96d3a41971ccc7c03e8b476a48dc8aa7fbb983431cSuharsh Sivakumar /* 978ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar m1 988ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar / \ 998ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar EMA_Q EMA_Q <- these are subgraphs that estimate moving average. 
1008ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar | | 1018ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Relu Identity 102a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen | | 103a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen a b 104a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen */ 105d3a41971ccc7c03e8b476a48dc8aa7fbb983431cSuharsh Sivakumar const int num_bits = 8; 106a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen TF_ASSERT_OK(DoQuantizeTraining(num_bits, g)); 107a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen 1088ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar EXPECT_EQ(63, g->num_nodes()); 109a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen 110a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen // Quantize_and_dequantize node for identity should have signed_input==true. 1118ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Node* identity_q_node; 1128ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar TF_ASSERT_OK( 1138ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar FindNode(g, strings::StrCat(identity->name(), "/QuantizeAndDequantizeV2"), 1148ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar &identity_q_node)); 1158ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar NodeDef identity_q = identity_q_node->def(); 116a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen ASSERT_EQ("true", 1178ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar SummarizeAttrValue(identity_q.attr().find("signed_input")->second)); 118a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen // Quantize_and_dequantize node for relu should have signed_input==false. 
1198ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Node* relu_q_node; 1208ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar TF_ASSERT_OK( 1218ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar FindNode(g, strings::StrCat(relu->name(), "/QuantizeAndDequantizeV2"), 1228ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar &relu_q_node)); 1238ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar NodeDef relu_q = relu_q_node->def(); 124a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen ASSERT_EQ("false", 1258ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar SummarizeAttrValue(relu_q.attr().find("signed_input")->second)); 1268ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar} 1278ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar 1288ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh SivakumarTEST_F(QuantizeTrainingTest, RangeGivenTrue) { 1298ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // Test that Quantization ops are created with the correct range_given value. 
1308ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // Construct the following graph 1318ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar /* 1328ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar m1 1338ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar / \ 1348ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Relu Relu6 1358ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar | | 1368ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar a b 1378ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar */ 1388ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Reset(); 1398ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Graph* g = g_.get(); 1408ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Node* a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2}); 1418ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Node* b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2}); 1428ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar g->AddControlEdge(g->source_node(), a); 1438ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar g->AddControlEdge(g->source_node(), b); 1448ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Node* relu = test::graph::Relu(g, a); 1458ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Node* relu6 = test::graph::Relu6(g, b); 1468ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Node* m1 = test::graph::Matmul(g, relu, relu6, false, false); 1478ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar g->AddControlEdge(m1, g->sink_node()); 1488ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar 1498ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar /* 1508ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar m1 1518ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar / \ 1528ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar EMA_Q Q 1538ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar | | 
1548ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Relu Relu6 1558ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar | | 1568ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar a b 1578ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar */ 1588ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar const int num_bits = 8; 1598ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar TF_ASSERT_OK(DoQuantizeTraining(num_bits, g)); 1608ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar 1618ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar EXPECT_EQ(38, g->num_nodes()); 1628ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar 1638ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // Quantize_and_dequantize node for relu6 should have range_given==true. 1648ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Node* relu6_q_node; 1658ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar TF_ASSERT_OK( 1668ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar FindNode(g, strings::StrCat(relu6->name(), "/QuantizeAndDequantizeV2"), 1678ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar &relu6_q_node)); 1688ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar NodeDef identity_q = relu6_q_node->def(); 1698ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar ASSERT_EQ("true", 1708ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar SummarizeAttrValue(identity_q.attr().find("range_given")->second)); 1718ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // Quantize_and_dequantize node for relu should have range_given==true. 
1728ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Node* relu_q_node; 1738ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar TF_ASSERT_OK( 1748ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar FindNode(g, strings::StrCat(relu->name(), "/QuantizeAndDequantizeV2"), 1758ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar &relu_q_node)); 1768ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar NodeDef relu_q = relu_q_node->def(); 1778ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar ASSERT_EQ("true", 1788ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar SummarizeAttrValue(relu_q.attr().find("range_given")->second)); 179a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen} 180a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen 181a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin ChenTEST_F(QuantizeTrainingTest, WithBackwardNodes) { 1828ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // Construct a graph with an additional backward Matmul. 183a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen Reset(); 184a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen Graph* g = g_.get(); 185a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen Node* a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2}); 186a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen Node* b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2}); 187a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen Node* c = Constant<float>({0.0, 1.0, 1.0, 0.0}, {2, 2}); 1888ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // We will use node d as input to the backwards matmul to ensure that it 1898ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // isn't quantized. 
1908ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Node* d = Constant<float>({0.0, 1.0, 1.0, 0.0}, {2, 2}); 191a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen g->AddControlEdge(g->source_node(), a); 192a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen g->AddControlEdge(g->source_node(), b); 193a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen g->AddControlEdge(g->source_node(), c); 1948ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar g->AddControlEdge(g->source_node(), d); 195a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen Node* relu = test::graph::Relu(g, a); 196a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen Node* identity = test::graph::Identity(g, b); 197a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen Node* m1 = test::graph::Matmul(g, relu, identity, false, false); 198a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen Node* m2 = test::graph::Matmul(g, identity, c, false, false); 199a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen g->AddControlEdge(m1, g->sink_node()); 200a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen g->AddControlEdge(m2, g->sink_node()); 201a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen 2028ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // Add a Matmul node with name starting with "gradients". We will check that 2038ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // its input d was not quantized. 
204a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen Node* backward_m; 205a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen TF_ASSERT_OK(NodeBuilder(g->NewName("gradients/n"), "MatMul") 2068ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar .Input(d) 207a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen .Input(m2) 208a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen .Attr("transpose_a", true) 209a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen .Attr("transpose_b", false) 210a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen .Finalize(g, &backward_m)); 211a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen g->AddControlEdge(backward_m, g->sink_node()); 212a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen 213a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen int num_bits = 8; 214a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen TF_ASSERT_OK(DoQuantizeTraining(num_bits, g)); 215a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen 2168ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar EXPECT_EQ(95, g->num_nodes()); 2178ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar 2188ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // Ensure that we the backwards matmul input was not quantized. 2198ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Node* found_node; 2208ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Status s = FindNode(g, strings::StrCat(d->name(), "/QuantizeAndDequantizeV2"), 2218ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar &found_node); 2228ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar EXPECT_TRUE(StringPiece(s.ToString()).contains("not found")) << s; 2238ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar 2248ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // Ensure that m1 and m2's inputs were quantized. 
2258ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar TF_ASSERT_OK( 2268ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar FindNode(g, strings::StrCat(relu->name(), "/QuantizeAndDequantizeV2"), 2278ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar &found_node)); 2288ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar TF_ASSERT_OK( 2298ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar FindNode(g, strings::StrCat(identity->name(), "/QuantizeAndDequantizeV2"), 2308ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar &found_node)); 2318ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar TF_ASSERT_OK(FindNode( 2328ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar g, strings::StrCat(c->name(), "/QuantizeAndDequantizeV2"), &found_node)); 233a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen} 234a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen 2350e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin ChenTEST_F(QuantizeTrainingTest, QuantizeGraphDef) { 2360e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen // Construct a simple graph with 5 nodes. 
2370e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen Reset(); 2380e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen Graph* graph = g_.get(); 2390e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen Node* const_a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2}); 2400e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen Node* const_b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2}); 2410e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen graph->AddControlEdge(graph->source_node(), const_a); 2420e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen graph->AddControlEdge(graph->source_node(), const_b); 2430e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen Node* relu = test::graph::Relu(graph, const_a); 2440e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen Node* identity = test::graph::Identity(graph, const_b); 2450e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen Node* matmul = test::graph::Matmul(graph, relu, identity, false, false); 2460e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen graph->AddControlEdge(matmul, graph->sink_node()); 2470e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen 2480e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen int num_bits = 8; 2490e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen 2500e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen // Convert the graph to the graphdef string. 
2510e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen GraphDef input_graph; 2520e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen graph->ToGraphDef(&input_graph); 2530e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen string input_string; 2540e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen input_graph.SerializeToString(&input_string); 2550e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen 2560e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen string result_string; 2570e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen TF_ASSERT_OK(DoQuantizeTrainingOnSerializedGraphDef(input_string, num_bits, 2580e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen &result_string)); 2590e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen 2608ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar GraphDef result_graphdef; 2618ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar EXPECT_TRUE(ParseProtoUnlimited(&result_graphdef, result_string)); 2628ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar 2638ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // Ensure that quantizing the graph_def results in a graph with the same 2648ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // number of nodes. 
2658ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar GraphConstructorOptions opts; 2668ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Graph result_graph(OpRegistry::Global()); 2678ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar TF_ASSERT_OK(ConvertGraphDefToGraph(opts, result_graphdef, &result_graph)); 2688ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar TF_ASSERT_OK(DoQuantizeTraining(num_bits, graph)); 2698ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar EXPECT_EQ(graph->num_nodes(), result_graph.num_nodes()); 2708ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar} 2718ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar 2728ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh SivakumarTEST_F(QuantizeTrainingTest, FixedRangeAndEMARange) { 2738ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // Construct the following graph 2748ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // Relu has an unknown range, so we will check if the EMA correctly estimates 2758ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // the range. 
2768ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar /* 2778ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar m1 2788ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar / \ 2798ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Relu Relu6 2808ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar | | 2818ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar a c 2828ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar */ 2838ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Reset(); 2848ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Graph* g = g_.get(); 2858ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Node* a; 2868ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar TF_ASSERT_OK(Placeholder(g, "a", {2, 2}, &a)); 2878ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Node* c = Constant<float>({2.0, 3.0, 4.0, 5.0}, {2, 2}); 2888ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar g->AddControlEdge(g->source_node(), a); 2898ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar g->AddControlEdge(g->source_node(), c); 2908ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Node* relu = test::graph::Relu(g, a); 2918ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Node* relu6 = test::graph::Relu6(g, c); 2928ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Node* m1 = test::graph::Matmul(g, relu, relu6, false, false); 2938ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar g->AddControlEdge(m1, g->sink_node()); 2948ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar 2958ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // This is rewritten into the following subgraph, where Q_a and Q_c are 2968ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // quantize and dequantize subgraphs. 
2978ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // Since relu's range is unknown, we check that the exponential moving average 2988ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // works correctly. 2998ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar /* 3008ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar m1 3018ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar / \ 3028ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Q_a Q_c 3038ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar | | 3048ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Relu Relu6 3058ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar | | 3068ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar a c 3078ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar */ 3088ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar const int num_bits = 8; 3098ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar TF_ASSERT_OK(DoQuantizeTraining(num_bits, g)); 3108ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar 3118ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar SessionOptions options; 3128ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Session* sess; 3138ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar TF_ASSERT_OK(NewSession(options, &sess)); 3148ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar GraphDef gdef; 3158ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar g->ToGraphDef(&gdef); 3168ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar TF_ASSERT_OK(sess->Create(gdef)); 3178ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar 3188ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // The min and max values of the relu6 quantization should be constant values 3198ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // of 0 and 6. 
3208ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar string min_const_name = strings::StrCat(relu6->name(), "/InputMin"); 3218ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar string max_const_name = strings::StrCat(relu6->name(), "/InputMax"); 3228ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar std::vector<Tensor> outputs; 3238ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar TF_ASSERT_OK(sess->Run({}, {min_const_name, max_const_name}, {}, &outputs)); 3248ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar EXPECT_EQ(outputs[0].flat<float>()(0), 0.0); 3258ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar EXPECT_EQ(outputs[1].flat<float>()(0), 6.0); 3268ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar 3278ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Tensor a1(DT_FLOAT, TensorShape({2, 2})); 3288ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar test::FillValues<float>(&a1, {0.0, 1.0, 2.0, 3.0}); 3298ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar Tensor a2(DT_FLOAT, TensorShape({2, 2})); 3308ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar test::FillValues<float>(&a2, {1.0, 2.0, 3.0, 4.0}); 3318ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar 3328ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar TF_ASSERT_OK(sess->Run({{"a", a1}}, {m1->name()}, {}, &outputs)); 3338ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar 3348ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // The value of the min and max should be set to the min and max of a1 since 3358ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // this is the first run that initializes the EMA variables. 
3368ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar string min_var_name = strings::StrCat(relu->name(), "/Min/Variable"); 3378ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar string max_var_name = strings::StrCat(relu->name(), "/Max/Variable"); 3388ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar TF_ASSERT_OK(sess->Run({}, {min_var_name, max_var_name}, {}, &outputs)); 3398ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar EXPECT_EQ(outputs[0].flat<float>()(0), 0.0); 3408ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar EXPECT_EQ(outputs[1].flat<float>()(0), 3.0); 3418ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar 3428ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // The relu6 quantization range should remain unchanged. 3438ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar TF_ASSERT_OK(sess->Run({}, {min_const_name, max_const_name}, {}, &outputs)); 3448ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar EXPECT_EQ(outputs[0].flat<float>()(0), 0.0); 3458ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar EXPECT_EQ(outputs[1].flat<float>()(0), 6.0); 3468ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar 3478ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // Now when we run with new inputs, we should get a moving average for the min 3488ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // and max variables. 
They should be equal to: 3498ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // min_var = old_min_var * decay + min(a2) * (1 - decay) 3508ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // max_var = old_max_var * decay + max(a2) * (1 - decay) 3518ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar TF_ASSERT_OK(sess->Run({{"a", a2}}, {m1->name()}, {}, &outputs)); 3528ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar 3538ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar TF_ASSERT_OK(sess->Run({}, {min_var_name, max_var_name}, {}, &outputs)); 3548ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar const float decay = 0.999; 3558ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar const float expected_min = 0.0 * decay + 1.0 * (1.0 - decay); 3568ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar const float expected_max = 3.0 * decay + 4.0 * (1.0 - decay); 3578ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar EXPECT_NEAR(outputs[0].flat<float>()(0), expected_min, 1e-4); 3588ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar EXPECT_NEAR(outputs[1].flat<float>()(0), expected_max, 1e-4); 3590e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen 3608ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar // The relu6 quantization range should remain unchanged. 3618ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar TF_ASSERT_OK(sess->Run({}, {min_const_name, max_const_name}, {}, &outputs)); 3628ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar EXPECT_EQ(outputs[0].flat<float>()(0), 0.0); 3638ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar EXPECT_EQ(outputs[1].flat<float>()(0), 6.0); 3640e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen} 365a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen 366a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen} // namespace 367a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen} // namespace tensorflow 368