quantize_training_test.cc revision 8ba27e3bd5203bd4cd9533b36c7fd02b85a0a42a
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen
16a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen#include <map>
17a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen#include <string>
18a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen#include <unordered_map>
19a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen#include <vector>
20a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen
21a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen#include "tensorflow/core/graph/quantize_training.h"
22a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen
23a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen#include "tensorflow/core/common_runtime/device_factory.h"
24a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen#include "tensorflow/core/common_runtime/device_mgr.h"
25a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen#include "tensorflow/core/framework/node_def_util.h"
26a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen#include "tensorflow/core/framework/tensor.h"
27a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen#include "tensorflow/core/framework/tensor_shape.h"
28a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen#include "tensorflow/core/framework/tensor_testutil.h"
29a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen#include "tensorflow/core/framework/types.h"
308ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar#include "tensorflow/core/graph/graph_constructor.h"
31a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen#include "tensorflow/core/graph/node_builder.h"
32a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen#include "tensorflow/core/graph/testlib.h"
33a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen#include "tensorflow/core/lib/core/status_test_util.h"
34a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen#include "tensorflow/core/lib/core/threadpool.h"
35a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen#include "tensorflow/core/lib/strings/strcat.h"
36a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen#include "tensorflow/core/platform/test.h"
378ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar#include "tensorflow/core/public/session.h"
38a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen#include "tensorflow/core/public/session_options.h"
39a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen
40a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chennamespace tensorflow {
41a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chennamespace {
42a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen
43a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chenclass QuantizeTrainingTest : public ::testing::Test {
44a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen protected:
45a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  QuantizeTrainingTest() { Reset(); }
46a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  void Reset() { g_.reset(new Graph(OpRegistry::Global())); }
47a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen
48a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  template <typename T>
49a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  Node* Constant(gtl::ArraySlice<T> values, TensorShape shape) {
50a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen    return test::graph::Constant(g_.get(), test::AsTensor(values, shape));
51a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  }
52a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen
538ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Status Placeholder(Graph* g, const string& name, TensorShape shape,
548ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar                     Node** out) {
558ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar    TF_RETURN_IF_ERROR(NodeBuilder(name, "Placeholder")
568ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar                           .Attr("dtype", DT_FLOAT)
578ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar                           .Attr("shape", shape)
588ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar                           .Finalize(g, out));
598ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar    return Status::OK();
608ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  }
618ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar
628ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Status FindNode(Graph* g, const string& name, Node** out) {
638ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar    for (Node* node : g->nodes()) {
648ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar      if (node->name() == name) {
658ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar        *out = node;
668ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar        return Status::OK();
678ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar      }
688ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar    }
698ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar    return errors::Unimplemented("Node ", name, " not found.");
708ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  }
718ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar
72a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  std::unique_ptr<Graph> g_;
73a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen};
74a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen
758ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh SivakumarTEST_F(QuantizeTrainingTest, SignedInput) {
768ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // Test that Quantization ops are created with the correct signed_input value.
77a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  // Construct the following graph
78a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  /*
798ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar           m1
808ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar        /      \
818ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar      Relu   Identity
82a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen        |       |
83a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen        a       b
84a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  */
85a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  Reset();
86a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  Graph* g = g_.get();
87a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  Node* a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
88a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  Node* b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
89a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  g->AddControlEdge(g->source_node(), a);
90a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  g->AddControlEdge(g->source_node(), b);
91a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  Node* relu = test::graph::Relu(g, a);
92a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  Node* identity = test::graph::Identity(g, b);
93a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  Node* m1 = test::graph::Matmul(g, relu, identity, false, false);
94a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  g->AddControlEdge(m1, g->sink_node());
95a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen
96d3a41971ccc7c03e8b476a48dc8aa7fbb983431cSuharsh Sivakumar  /*
978ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar         m1
988ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar      /      \
998ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar    EMA_Q   EMA_Q  <- these are subgraphs that estimate moving average.
1008ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar      |       |
1018ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar    Relu   Identity
102a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen      |       |
103a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen      a       b
104a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  */
105d3a41971ccc7c03e8b476a48dc8aa7fbb983431cSuharsh Sivakumar  const int num_bits = 8;
106a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  TF_ASSERT_OK(DoQuantizeTraining(num_bits, g));
107a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen
1088ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  EXPECT_EQ(63, g->num_nodes());
109a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen
110a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  // Quantize_and_dequantize node for identity should have signed_input==true.
1118ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Node* identity_q_node;
1128ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  TF_ASSERT_OK(
1138ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar      FindNode(g, strings::StrCat(identity->name(), "/QuantizeAndDequantizeV2"),
1148ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar               &identity_q_node));
1158ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  NodeDef identity_q = identity_q_node->def();
116a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  ASSERT_EQ("true",
1178ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar            SummarizeAttrValue(identity_q.attr().find("signed_input")->second));
118a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  // Quantize_and_dequantize node for relu should have signed_input==false.
1198ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Node* relu_q_node;
1208ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  TF_ASSERT_OK(
1218ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar      FindNode(g, strings::StrCat(relu->name(), "/QuantizeAndDequantizeV2"),
1228ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar               &relu_q_node));
1238ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  NodeDef relu_q = relu_q_node->def();
124a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  ASSERT_EQ("false",
1258ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar            SummarizeAttrValue(relu_q.attr().find("signed_input")->second));
1268ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar}
1278ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar
1288ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh SivakumarTEST_F(QuantizeTrainingTest, RangeGivenTrue) {
1298ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // Test that Quantization ops are created with the correct range_given value.
1308ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // Construct the following graph
1318ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  /*
1328ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar           m1
1338ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar        /      \
1348ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar      Relu   Relu6
1358ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar        |       |
1368ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar        a       b
1378ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  */
1388ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Reset();
1398ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Graph* g = g_.get();
1408ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Node* a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
1418ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Node* b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
1428ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  g->AddControlEdge(g->source_node(), a);
1438ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  g->AddControlEdge(g->source_node(), b);
1448ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Node* relu = test::graph::Relu(g, a);
1458ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Node* relu6 = test::graph::Relu6(g, b);
1468ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Node* m1 = test::graph::Matmul(g, relu, relu6, false, false);
1478ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  g->AddControlEdge(m1, g->sink_node());
1488ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar
1498ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  /*
1508ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar         m1
1518ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar      /      \
1528ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar    EMA_Q     Q
1538ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar      |       |
1548ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar    Relu   Relu6
1558ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar      |       |
1568ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar      a       b
1578ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  */
1588ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  const int num_bits = 8;
1598ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  TF_ASSERT_OK(DoQuantizeTraining(num_bits, g));
1608ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar
1618ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  EXPECT_EQ(38, g->num_nodes());
1628ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar
1638ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // Quantize_and_dequantize node for relu6 should have range_given==true.
1648ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Node* relu6_q_node;
1658ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  TF_ASSERT_OK(
1668ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar      FindNode(g, strings::StrCat(relu6->name(), "/QuantizeAndDequantizeV2"),
1678ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar               &relu6_q_node));
1688ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  NodeDef identity_q = relu6_q_node->def();
1698ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  ASSERT_EQ("true",
1708ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar            SummarizeAttrValue(identity_q.attr().find("range_given")->second));
1718ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // Quantize_and_dequantize node for relu should have range_given==true.
1728ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Node* relu_q_node;
1738ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  TF_ASSERT_OK(
1748ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar      FindNode(g, strings::StrCat(relu->name(), "/QuantizeAndDequantizeV2"),
1758ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar               &relu_q_node));
1768ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  NodeDef relu_q = relu_q_node->def();
1778ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  ASSERT_EQ("true",
1788ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar            SummarizeAttrValue(relu_q.attr().find("range_given")->second));
179a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen}
180a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen
181a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin ChenTEST_F(QuantizeTrainingTest, WithBackwardNodes) {
1828ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // Construct a graph with an additional backward Matmul.
183a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  Reset();
184a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  Graph* g = g_.get();
185a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  Node* a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
186a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  Node* b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
187a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  Node* c = Constant<float>({0.0, 1.0, 1.0, 0.0}, {2, 2});
1888ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // We will use node d as input to the backwards matmul to ensure that it
1898ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // isn't quantized.
1908ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Node* d = Constant<float>({0.0, 1.0, 1.0, 0.0}, {2, 2});
191a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  g->AddControlEdge(g->source_node(), a);
192a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  g->AddControlEdge(g->source_node(), b);
193a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  g->AddControlEdge(g->source_node(), c);
1948ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  g->AddControlEdge(g->source_node(), d);
195a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  Node* relu = test::graph::Relu(g, a);
196a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  Node* identity = test::graph::Identity(g, b);
197a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  Node* m1 = test::graph::Matmul(g, relu, identity, false, false);
198a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  Node* m2 = test::graph::Matmul(g, identity, c, false, false);
199a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  g->AddControlEdge(m1, g->sink_node());
200a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  g->AddControlEdge(m2, g->sink_node());
201a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen
2028ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // Add a Matmul node with name starting with "gradients". We will check that
2038ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // its input d was not quantized.
204a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  Node* backward_m;
205a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  TF_ASSERT_OK(NodeBuilder(g->NewName("gradients/n"), "MatMul")
2068ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar                   .Input(d)
207a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen                   .Input(m2)
208a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen                   .Attr("transpose_a", true)
209a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen                   .Attr("transpose_b", false)
210a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen                   .Finalize(g, &backward_m));
211a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  g->AddControlEdge(backward_m, g->sink_node());
212a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen
213a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  int num_bits = 8;
214a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen  TF_ASSERT_OK(DoQuantizeTraining(num_bits, g));
215a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen
2168ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  EXPECT_EQ(95, g->num_nodes());
2178ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar
2188ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // Ensure that we the backwards matmul input was not quantized.
2198ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Node* found_node;
2208ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Status s = FindNode(g, strings::StrCat(d->name(), "/QuantizeAndDequantizeV2"),
2218ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar                      &found_node);
2228ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  EXPECT_TRUE(StringPiece(s.ToString()).contains("not found")) << s;
2238ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar
2248ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // Ensure that m1 and m2's inputs were quantized.
2258ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  TF_ASSERT_OK(
2268ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar      FindNode(g, strings::StrCat(relu->name(), "/QuantizeAndDequantizeV2"),
2278ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar               &found_node));
2288ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  TF_ASSERT_OK(
2298ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar      FindNode(g, strings::StrCat(identity->name(), "/QuantizeAndDequantizeV2"),
2308ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar               &found_node));
2318ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  TF_ASSERT_OK(FindNode(
2328ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar      g, strings::StrCat(c->name(), "/QuantizeAndDequantizeV2"), &found_node));
233a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen}
234a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen
2350e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin ChenTEST_F(QuantizeTrainingTest, QuantizeGraphDef) {
2360e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen  // Construct a simple graph with 5 nodes.
2370e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen  Reset();
2380e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen  Graph* graph = g_.get();
2390e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen  Node* const_a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
2400e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen  Node* const_b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
2410e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen  graph->AddControlEdge(graph->source_node(), const_a);
2420e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen  graph->AddControlEdge(graph->source_node(), const_b);
2430e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen  Node* relu = test::graph::Relu(graph, const_a);
2440e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen  Node* identity = test::graph::Identity(graph, const_b);
2450e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen  Node* matmul = test::graph::Matmul(graph, relu, identity, false, false);
2460e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen  graph->AddControlEdge(matmul, graph->sink_node());
2470e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen
2480e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen  int num_bits = 8;
2490e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen
2500e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen  // Convert the graph to the graphdef string.
2510e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen  GraphDef input_graph;
2520e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen  graph->ToGraphDef(&input_graph);
2530e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen  string input_string;
2540e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen  input_graph.SerializeToString(&input_string);
2550e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen
2560e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen  string result_string;
2570e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen  TF_ASSERT_OK(DoQuantizeTrainingOnSerializedGraphDef(input_string, num_bits,
2580e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen                                                      &result_string));
2590e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen
2608ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  GraphDef result_graphdef;
2618ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  EXPECT_TRUE(ParseProtoUnlimited(&result_graphdef, result_string));
2628ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar
2638ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // Ensure that quantizing the graph_def results in a graph with the same
2648ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // number of nodes.
2658ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  GraphConstructorOptions opts;
2668ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Graph result_graph(OpRegistry::Global());
2678ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  TF_ASSERT_OK(ConvertGraphDefToGraph(opts, result_graphdef, &result_graph));
2688ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  TF_ASSERT_OK(DoQuantizeTraining(num_bits, graph));
2698ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  EXPECT_EQ(graph->num_nodes(), result_graph.num_nodes());
2708ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar}
2718ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar
2728ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh SivakumarTEST_F(QuantizeTrainingTest, FixedRangeAndEMARange) {
2738ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // Construct the following graph
2748ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // Relu has an unknown range, so we will check if the EMA correctly estimates
2758ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // the range.
2768ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  /*
2778ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar           m1
2788ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar        /      \
2798ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar      Relu    Relu6
2808ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar        |       |
2818ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar        a       c
2828ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  */
2838ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Reset();
2848ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Graph* g = g_.get();
2858ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Node* a;
2868ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  TF_ASSERT_OK(Placeholder(g, "a", {2, 2}, &a));
2878ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Node* c = Constant<float>({2.0, 3.0, 4.0, 5.0}, {2, 2});
2888ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  g->AddControlEdge(g->source_node(), a);
2898ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  g->AddControlEdge(g->source_node(), c);
2908ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Node* relu = test::graph::Relu(g, a);
2918ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Node* relu6 = test::graph::Relu6(g, c);
2928ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Node* m1 = test::graph::Matmul(g, relu, relu6, false, false);
2938ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  g->AddControlEdge(m1, g->sink_node());
2948ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar
2958ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // This is rewritten into the following subgraph, where Q_a and Q_c are
2968ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // quantize and dequantize subgraphs.
2978ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // Since relu's range is unknown, we check that the exponential moving average
2988ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // works correctly.
2998ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  /*
3008ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar         m1
3018ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar      /      \
3028ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar     Q_a     Q_c
3038ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar      |       |
3048ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar    Relu     Relu6
3058ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar      |       |
3068ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar      a       c
3078ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  */
3088ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  const int num_bits = 8;
3098ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  TF_ASSERT_OK(DoQuantizeTraining(num_bits, g));
3108ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar
3118ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  SessionOptions options;
3128ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Session* sess;
3138ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  TF_ASSERT_OK(NewSession(options, &sess));
3148ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  GraphDef gdef;
3158ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  g->ToGraphDef(&gdef);
3168ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  TF_ASSERT_OK(sess->Create(gdef));
3178ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar
3188ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // The min and max values of the relu6 quantization should be constant values
3198ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // of 0 and 6.
3208ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  string min_const_name = strings::StrCat(relu6->name(), "/InputMin");
3218ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  string max_const_name = strings::StrCat(relu6->name(), "/InputMax");
3228ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  std::vector<Tensor> outputs;
3238ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  TF_ASSERT_OK(sess->Run({}, {min_const_name, max_const_name}, {}, &outputs));
3248ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  EXPECT_EQ(outputs[0].flat<float>()(0), 0.0);
3258ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  EXPECT_EQ(outputs[1].flat<float>()(0), 6.0);
3268ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar
3278ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Tensor a1(DT_FLOAT, TensorShape({2, 2}));
3288ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  test::FillValues<float>(&a1, {0.0, 1.0, 2.0, 3.0});
3298ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  Tensor a2(DT_FLOAT, TensorShape({2, 2}));
3308ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  test::FillValues<float>(&a2, {1.0, 2.0, 3.0, 4.0});
3318ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar
3328ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  TF_ASSERT_OK(sess->Run({{"a", a1}}, {m1->name()}, {}, &outputs));
3338ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar
3348ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // The value of the min and max should be set to the min and max of a1 since
3358ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // this is the first run that initializes the EMA variables.
3368ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  string min_var_name = strings::StrCat(relu->name(), "/Min/Variable");
3378ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  string max_var_name = strings::StrCat(relu->name(), "/Max/Variable");
3388ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  TF_ASSERT_OK(sess->Run({}, {min_var_name, max_var_name}, {}, &outputs));
3398ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  EXPECT_EQ(outputs[0].flat<float>()(0), 0.0);
3408ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  EXPECT_EQ(outputs[1].flat<float>()(0), 3.0);
3418ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar
3428ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // The relu6 quantization range should remain unchanged.
3438ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  TF_ASSERT_OK(sess->Run({}, {min_const_name, max_const_name}, {}, &outputs));
3448ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  EXPECT_EQ(outputs[0].flat<float>()(0), 0.0);
3458ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  EXPECT_EQ(outputs[1].flat<float>()(0), 6.0);
3468ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar
3478ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // Now when we run with new inputs, we should get a moving average for the min
3488ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // and max variables. They should be equal to:
3498ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // min_var = old_min_var * decay + min(a2) * (1 - decay)
3508ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // max_var = old_max_var * decay + max(a2) * (1 - decay)
3518ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  TF_ASSERT_OK(sess->Run({{"a", a2}}, {m1->name()}, {}, &outputs));
3528ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar
3538ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  TF_ASSERT_OK(sess->Run({}, {min_var_name, max_var_name}, {}, &outputs));
3548ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  const float decay = 0.999;
3558ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  const float expected_min = 0.0 * decay + 1.0 * (1.0 - decay);
3568ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  const float expected_max = 3.0 * decay + 4.0 * (1.0 - decay);
3578ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  EXPECT_NEAR(outputs[0].flat<float>()(0), expected_min, 1e-4);
3588ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  EXPECT_NEAR(outputs[1].flat<float>()(0), expected_max, 1e-4);
3590e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen
3608ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  // The relu6 quantization range should remain unchanged.
3618ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  TF_ASSERT_OK(sess->Run({}, {min_const_name, max_const_name}, {}, &outputs));
3628ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  EXPECT_EQ(outputs[0].flat<float>()(0), 0.0);
3638ba27e3bd5203bd4cd9533b36c7fd02b85a0a42aSuharsh Sivakumar  EXPECT_EQ(outputs[1].flat<float>()(0), 6.0);
3640e96cccf3a0597e5b9d5a0971ec17dbe71659505Jianmin Chen}
365a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen
366a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen}  // namespace
367a8a25d85a5c57f8cbb4b22aa5bd7e9c86e5aedd8Jianmin Chen}  // namespace tensorflow
368