/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <map>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/graph/quantize_training.h"

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace {

class QuantizeTrainingTest : public ::testing::Test {
 protected:
  QuantizeTrainingTest() { Reset(); }
  void Reset() { g_.reset(new Graph(OpRegistry::Global())); }

  template <typename T>
  Node* Constant(gtl::ArraySlice<T> values, TensorShape shape) {
    return test::graph::Constant(g_.get(), test::AsTensor(values, shape));
  }

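  // Builds a DT_FLOAT Placeholder node so tests can feed input values at
  // Session::Run() time.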
  Status Placeholder(Graph* g, const string& name, TensorShape shape,
                     Node** out) {
    TF_RETURN_IF_ERROR(NodeBuilder(name, "Placeholder")
                           .Attr("dtype", DT_FLOAT)
                           .Attr("shape", shape)
                           .Finalize(g, out));
    return Status::OK();
  }

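  // Looks up a node by name, returning an error if no such node exists.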
  Status FindNode(Graph* g, const string& name, Node** out) {
    for (Node* node : g->nodes()) {
      if (node->name() == name) {
        *out = node;
        return Status::OK();
      }
    }
    return errors::Unimplemented("Node ", name, " not found.");
  }

  std::unique_ptr<Graph> g_;
};

TEST_F(QuantizeTrainingTest, SignedInput) {
  // Test that quantization ops are created with the correct signed_input value.
  // Construct the following graph
  /*
           m1
        /      \
      Relu   Identity
        |       |
        a       b
  */
  Reset();
  Graph* g = g_.get();
  Node* a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
  Node* b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
  g->AddControlEdge(g->source_node(), a);
  g->AddControlEdge(g->source_node(), b);
  Node* relu = test::graph::Relu(g, a);
  Node* identity = test::graph::Identity(g, b);
  Node* m1 = test::graph::Matmul(g, relu, identity, false, false);
  g->AddControlEdge(m1, g->sink_node());

  /*
         m1
      /      \
    EMA_Q   EMA_Q  <- these are subgraphs that estimate moving average.
      |       |
    Relu   Identity
      |       |
      a       b
  */
  const int num_bits = 8;
  TF_ASSERT_OK(DoQuantizeTraining(num_bits, "QuantizeAndDequantizeV2", g));

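  // The original graph has 7 nodes (source, sink, a, b, relu, identity, m1);
  // the remaining nodes are added by the rewrite, chiefly the two EMA_Q
  // subgraphs.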
  EXPECT_EQ(63, g->num_nodes());

  // Quantize_and_dequantize node for identity should have signed_input==true.
  Node* identity_q_node;
  TF_ASSERT_OK(
      FindNode(g, strings::StrCat(identity->name(), "/QuantizeAndDequantizeV2"),
               &identity_q_node));
  ASSERT_EQ("true",
            SummarizeAttrValue(*identity_q_node->attrs().Find("signed_input")));
  // Quantize_and_dequantize node for relu should have signed_input==false.
  Node* relu_q_node;
  TF_ASSERT_OK(
      FindNode(g, strings::StrCat(relu->name(), "/QuantizeAndDequantizeV2"),
               &relu_q_node));
  ASSERT_EQ("false",
            SummarizeAttrValue(*relu_q_node->attrs().Find("signed_input")));
}

TEST_F(QuantizeTrainingTest, RangeGivenTrue) {
  // Test that quantization ops are created with the correct range_given value.
  // Construct the following graph
  /*
           m1
        /      \
      Relu   Relu6
        |       |
        a       b
  */
  Reset();
  Graph* g = g_.get();
  Node* a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
  Node* b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
  g->AddControlEdge(g->source_node(), a);
  g->AddControlEdge(g->source_node(), b);
  Node* relu = test::graph::Relu(g, a);
  Node* relu6 = test::graph::Relu6(g, b);
  Node* m1 = test::graph::Matmul(g, relu, relu6, false, false);
  g->AddControlEdge(m1, g->sink_node());

  /*
         m1
      /      \
    EMA_Q     Q
      |       |
    Relu   Relu6
      |       |
      a       b
  */
  const int num_bits = 8;
  TF_ASSERT_OK(DoQuantizeTraining(num_bits, "QuantizeAndDequantizeV2", g));

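  // Relu6 has a known output range of [0, 6], so it gets the lighter
  // fixed-range Q subgraph instead of a full EMA_Q subgraph; fewer nodes are
  // added here than in the SignedInput test above.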
  EXPECT_EQ(38, g->num_nodes());

  // Quantize_and_dequantize node for relu6 should have range_given==true.
  Node* relu6_q_node;
  TF_ASSERT_OK(
      FindNode(g, strings::StrCat(relu6->name(), "/QuantizeAndDequantizeV2"),
               &relu6_q_node));
  ASSERT_EQ("true",
            SummarizeAttrValue(*relu6_q_node->attrs().Find("range_given")));
  // Quantize_and_dequantize node for relu should have range_given==true.
  Node* relu_q_node;
  TF_ASSERT_OK(
      FindNode(g, strings::StrCat(relu->name(), "/QuantizeAndDequantizeV2"),
               &relu_q_node));
  ASSERT_EQ("true",
            SummarizeAttrValue(*relu_q_node->attrs().Find("range_given")));
}

TEST_F(QuantizeTrainingTest, WithBackwardNodes_QuantizeAndDequantize) {
  // Construct a graph with an additional backward Matmul.
  Reset();
  Graph* g = g_.get();
  Node* a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
  Node* b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
  Node* c = Constant<float>({0.0, 1.0, 1.0, 0.0}, {2, 2});
  // We will use node d as input to the backwards matmul to ensure that it
  // isn't quantized.
  Node* d = Constant<float>({0.0, 1.0, 1.0, 0.0}, {2, 2});
  g->AddControlEdge(g->source_node(), a);
  g->AddControlEdge(g->source_node(), b);
  g->AddControlEdge(g->source_node(), c);
  g->AddControlEdge(g->source_node(), d);
  Node* relu = test::graph::Relu(g, a);
  Node* identity = test::graph::Identity(g, b);
  Node* m1 = test::graph::Matmul(g, relu, identity, false, false);
  Node* m2 = test::graph::Matmul(g, identity, c, false, false);
  g->AddControlEdge(m1, g->sink_node());
  g->AddControlEdge(m2, g->sink_node());

  // Add a Matmul node with name starting with "gradients". We will check that
  // its input d was not quantized.
  Node* backward_m;
  TF_ASSERT_OK(NodeBuilder(g->NewName("gradients/n"), "MatMul")
                   .Input(d)
                   .Input(m2)
                   .Attr("transpose_a", true)
                   .Attr("transpose_b", false)
                   .Finalize(g, &backward_m));
  g->AddControlEdge(backward_m, g->sink_node());

  int num_bits = 8;
  TF_ASSERT_OK(DoQuantizeTraining(num_bits, "QuantizeAndDequantizeV2", g));

  EXPECT_EQ(95, g->num_nodes());

  // Ensure that the backwards matmul input was not quantized.
  Node* found_node;
  Status s = FindNode(g, strings::StrCat(d->name(), "/QuantizeAndDequantizeV2"),
                      &found_node);
  EXPECT_TRUE(StringPiece(s.ToString()).contains("not found")) << s;

  // Ensure that m1 and m2's inputs were quantized.
  TF_ASSERT_OK(
      FindNode(g, strings::StrCat(relu->name(), "/QuantizeAndDequantizeV2"),
               &found_node));
  TF_ASSERT_OK(
      FindNode(g, strings::StrCat(identity->name(), "/QuantizeAndDequantizeV2"),
               &found_node));
  TF_ASSERT_OK(FindNode(
      g, strings::StrCat(c->name(), "/QuantizeAndDequantizeV2"), &found_node));
}

TEST_F(QuantizeTrainingTest, WithBackwardNodes_FakeQuant) {
  // Construct a graph with an additional backward Matmul.
  Reset();
  Graph* g = g_.get();
  Node* a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
  Node* b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
  Node* c = Constant<float>({0.0, 1.0, 1.0, 0.0}, {2, 2});
  // We will use node d as input to the backwards matmul to ensure that it
  // isn't quantized.
  Node* d = Constant<float>({0.0, 1.0, 1.0, 0.0}, {2, 2});
  g->AddControlEdge(g->source_node(), a);
  g->AddControlEdge(g->source_node(), b);
  g->AddControlEdge(g->source_node(), c);
  g->AddControlEdge(g->source_node(), d);
  Node* relu = test::graph::Relu(g, a);
  Node* identity = test::graph::Identity(g, b);
  Node* m1 = test::graph::Matmul(g, relu, identity, false, false);
  Node* m2 = test::graph::Matmul(g, identity, c, false, false);
  g->AddControlEdge(m1, g->sink_node());
  g->AddControlEdge(m2, g->sink_node());

  // Add a Matmul node with name starting with "gradients". We will check that
  // its input d was not quantized.
  Node* backward_m;
  TF_ASSERT_OK(NodeBuilder(g->NewName("gradients/n"), "MatMul")
                   .Input(d)
                   .Input(m2)
                   .Attr("transpose_a", true)
                   .Attr("transpose_b", false)
                   .Finalize(g, &backward_m));
  g->AddControlEdge(backward_m, g->sink_node());

  int num_bits = 8;
  TF_ASSERT_OK(DoQuantizeTraining(num_bits, "FakeQuantWithMinMaxVars", g));

  EXPECT_EQ(95, g->num_nodes());

  // Ensure that the backwards matmul input was not quantized.
  Node* found_node;
  Status s = FindNode(g, strings::StrCat(d->name(), "/FakeQuantWithMinMaxVars"),
                      &found_node);
  EXPECT_TRUE(StringPiece(s.ToString()).contains("not found")) << s;

  // Ensure that m1 and m2's inputs were quantized.
  TF_ASSERT_OK(
      FindNode(g, strings::StrCat(relu->name(), "/FakeQuantWithMinMaxVars"),
               &found_node));
  TF_ASSERT_OK(
      FindNode(g, strings::StrCat(identity->name(), "/FakeQuantWithMinMaxVars"),
               &found_node));
  TF_ASSERT_OK(FindNode(
      g, strings::StrCat(c->name(), "/FakeQuantWithMinMaxVars"), &found_node));
}

TEST_F(QuantizeTrainingTest, QuantizeSerializedGraphDef) {
  // Construct a simple graph with 5 nodes.
  Reset();
  Graph* graph = g_.get();
  Node* const_a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
  Node* const_b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
  graph->AddControlEdge(graph->source_node(), const_a);
  graph->AddControlEdge(graph->source_node(), const_b);
  Node* relu = test::graph::Relu(graph, const_a);
  Node* identity = test::graph::Identity(graph, const_b);
  Node* matmul = test::graph::Matmul(graph, relu, identity, false, false);
  graph->AddControlEdge(matmul, graph->sink_node());

  int num_bits = 8;

  // Convert the graph to the graphdef string.
  GraphDef input_graph;
  graph->ToGraphDef(&input_graph);
  string input_string;
  input_graph.SerializeToString(&input_string);

  string result_string;
  TF_ASSERT_OK(DoQuantizeTrainingOnSerializedGraphDef(
      input_string, num_bits, "QuantizeAndDequantizeV2", &result_string));

  GraphDef result_graphdef;
  EXPECT_TRUE(ParseProtoUnlimited(&result_graphdef, result_string));

  // Ensure that quantizing the serialized graph_def results in a graph with the
  // same number of nodes as quantizing the graph.
  GraphConstructorOptions opts;
  Graph result_graph(OpRegistry::Global());
  TF_ASSERT_OK(ConvertGraphDefToGraph(opts, result_graphdef, &result_graph));
  TF_ASSERT_OK(DoQuantizeTraining(num_bits, "QuantizeAndDequantizeV2", graph));
  EXPECT_EQ(graph->num_nodes(), result_graph.num_nodes());
}

TEST_F(QuantizeTrainingTest, QuantizeGraphDef) {
  // Construct a simple graph with 5 nodes.
  Reset();
  Graph* graph = g_.get();
  Node* const_a = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
  Node* const_b = Constant<float>({1.0, 2.0, 3.0, 4.0}, {2, 2});
  graph->AddControlEdge(graph->source_node(), const_a);
  graph->AddControlEdge(graph->source_node(), const_b);
  Node* relu = test::graph::Relu(graph, const_a);
  Node* identity = test::graph::Identity(graph, const_b);
  Node* matmul = test::graph::Matmul(graph, relu, identity, false, false);
  graph->AddControlEdge(matmul, graph->sink_node());

  int num_bits = 8;

  // Convert the graph to a GraphDef.
  GraphDef input_graphdef;
  graph->ToGraphDef(&input_graphdef);

  GraphDef result_graphdef;
  TF_ASSERT_OK(DoQuantizeTrainingOnGraphDef(
      input_graphdef, num_bits, "QuantizeAndDequantizeV2", &result_graphdef));

  // Ensure that quantizing the graph_def results in a graph with the same
  // number of nodes as quantizing the graph directly.
  GraphConstructorOptions opts;
  Graph result_graph(OpRegistry::Global());
  TF_ASSERT_OK(ConvertGraphDefToGraph(opts, result_graphdef, &result_graph));
  TF_ASSERT_OK(DoQuantizeTraining(num_bits, "QuantizeAndDequantizeV2", graph));
  EXPECT_EQ(graph->num_nodes(), result_graph.num_nodes());
}

TEST_F(QuantizeTrainingTest, FixedRangeAndEMARange_QuantizeAndDequantize) {
  // Construct the following graph
  // Relu has an unknown range, so we will check if the EMA correctly estimates
  // the range.
  /*
           m1
        /      \
      Relu    Relu6
        |       |
        a       c
  */
  Reset();
  Graph* g = g_.get();
  Node* a;
  TF_ASSERT_OK(Placeholder(g, "a", {2, 2}, &a));
  Node* c = Constant<float>({2.0, 3.0, 4.0, 5.0}, {2, 2});
  g->AddControlEdge(g->source_node(), a);
  g->AddControlEdge(g->source_node(), c);
  Node* relu = test::graph::Relu(g, a);
  Node* relu6 = test::graph::Relu6(g, c);
  Node* m1 = test::graph::Matmul(g, relu, relu6, false, false);
  g->AddControlEdge(m1, g->sink_node());

  // This is rewritten into the following subgraph, where Q_a and Q_c are
  // quantize and dequantize subgraphs.
  // Since relu's range is unknown, we check that the exponential moving average
  // works correctly.
  /*
         m1
      /      \
     Q_a     Q_c
      |       |
    Relu     Relu6
      |       |
      a       c
  */
  const int num_bits = 8;
  TF_ASSERT_OK(DoQuantizeTraining(num_bits, "QuantizeAndDequantizeV2", g));

  SessionOptions options;
  Session* sess;
  TF_ASSERT_OK(NewSession(options, &sess));
  GraphDef gdef;
  g->ToGraphDef(&gdef);
  TF_ASSERT_OK(sess->Create(gdef));

  // The min and max values of the relu6 quantization should be constant values
  // of 0 and 6.
  string min_const_name = strings::StrCat(relu6->name(), "/InputMin");
  string max_const_name = strings::StrCat(relu6->name(), "/InputMax");
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(sess->Run({}, {min_const_name, max_const_name}, {}, &outputs));
  EXPECT_EQ(outputs[0].flat<float>()(0), 0.0);
  EXPECT_EQ(outputs[1].flat<float>()(0), 6.0);

  Tensor a1(DT_FLOAT, TensorShape({2, 2}));
  test::FillValues<float>(&a1, {0.0, 1.0, 2.0, 3.0});
  Tensor a2(DT_FLOAT, TensorShape({2, 2}));
  test::FillValues<float>(&a2, {1.0, 2.0, 3.0, 4.0});

  TF_ASSERT_OK(sess->Run({{"a", a1}}, {m1->name()}, {}, &outputs));

  // The value of the min and max should be set to the min and max of a1 since
  // this is the first run that initializes the EMA variables.
  string min_var_name = strings::StrCat(relu->name(), "/Min/Variable");
  string max_var_name = strings::StrCat(relu->name(), "/Max/Variable");
  TF_ASSERT_OK(sess->Run({}, {min_var_name, max_var_name}, {}, &outputs));
  EXPECT_EQ(outputs[0].flat<float>()(0), 0.0);
  EXPECT_EQ(outputs[1].flat<float>()(0), 3.0);

  // The relu6 quantization range should remain unchanged.
  TF_ASSERT_OK(sess->Run({}, {min_const_name, max_const_name}, {}, &outputs));
  EXPECT_EQ(outputs[0].flat<float>()(0), 0.0);
  EXPECT_EQ(outputs[1].flat<float>()(0), 6.0);

  // Now when we run with new inputs, we should get a moving average for the min
  // and max variables. They should be equal to:
  // min_var = old_min_var * decay + min(a2) * (1 - decay)
  // max_var = old_max_var * decay + max(a2) * (1 - decay)
  TF_ASSERT_OK(sess->Run({{"a", a2}}, {m1->name()}, {}, &outputs));

  TF_ASSERT_OK(sess->Run({}, {min_var_name, max_var_name}, {}, &outputs));
  const float decay = 0.999;
  const float expected_min = 0.0 * decay + 1.0 * (1.0 - decay);
  const float expected_max = 3.0 * decay + 4.0 * (1.0 - decay);
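  // With decay = 0.999 this works out to expected_min == 0.001 and
  // expected_max == 3.001.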
  EXPECT_NEAR(outputs[0].flat<float>()(0), expected_min, 1e-4);
  EXPECT_NEAR(outputs[1].flat<float>()(0), expected_max, 1e-4);

  // The relu6 quantization range should remain unchanged.
  TF_ASSERT_OK(sess->Run({}, {min_const_name, max_const_name}, {}, &outputs));
  EXPECT_EQ(outputs[0].flat<float>()(0), 0.0);
  EXPECT_EQ(outputs[1].flat<float>()(0), 6.0);
}

TEST_F(QuantizeTrainingTest, FixedRangeAndEMARange_FakeQuant) {
  // Construct the following graph
  // Relu has an unknown range, so we will check if the EMA correctly estimates
  // the range.
  /*
           m1
        /      \
      Relu    Relu6
        |       |
        a       c
  */
  Reset();
  Graph* g = g_.get();
  Node* a;
  TF_ASSERT_OK(Placeholder(g, "a", {2, 2}, &a));
  Node* c = Constant<float>({2.0, 3.0, 4.0, 5.0}, {2, 2});
  g->AddControlEdge(g->source_node(), a);
  g->AddControlEdge(g->source_node(), c);
  Node* relu = test::graph::Relu(g, a);
  Node* relu6 = test::graph::Relu6(g, c);
  Node* m1 = test::graph::Matmul(g, relu, relu6, false, false);
  g->AddControlEdge(m1, g->sink_node());

  // This is rewritten into the following subgraph, where Q_a and Q_c are
  // quantize and dequantize subgraphs.
  // Since relu's range is unknown, we check that the exponential moving average
  // works correctly.
  /*
         m1
      /      \
     Q_a     Q_c
      |       |
    Relu     Relu6
      |       |
      a       c
  */
  const int num_bits = 8;
  TF_ASSERT_OK(DoQuantizeTraining(num_bits, "FakeQuantWithMinMaxVars", g));

  SessionOptions options;
  Session* sess;
  TF_ASSERT_OK(NewSession(options, &sess));
  GraphDef gdef;
  g->ToGraphDef(&gdef);
  TF_ASSERT_OK(sess->Create(gdef));

  // The min and max values of the relu6 quantization should be constant values
  // of 0 and 6.
  string min_const_name = strings::StrCat(relu6->name(), "/InputMin");
  string max_const_name = strings::StrCat(relu6->name(), "/InputMax");
  std::vector<Tensor> outputs;
  TF_ASSERT_OK(sess->Run({}, {min_const_name, max_const_name}, {}, &outputs));
  EXPECT_EQ(outputs[0].flat<float>()(0), 0.0);
  EXPECT_EQ(outputs[1].flat<float>()(0), 6.0);

  Tensor a1(DT_FLOAT, TensorShape({2, 2}));
  test::FillValues<float>(&a1, {0.0, 1.0, 2.0, 3.0});
  Tensor a2(DT_FLOAT, TensorShape({2, 2}));
  test::FillValues<float>(&a2, {1.0, 2.0, 3.0, 4.0});

  TF_ASSERT_OK(sess->Run({{"a", a1}}, {m1->name()}, {}, &outputs));

  // The value of the min and max should be set to the min and max of a1 since
  // this is the first run that initializes the EMA variables.
  string min_var_name = strings::StrCat(relu->name(), "/Min/Variable");
  string max_var_name = strings::StrCat(relu->name(), "/Max/Variable");
  TF_ASSERT_OK(sess->Run({}, {min_var_name, max_var_name}, {}, &outputs));
  EXPECT_EQ(outputs[0].flat<float>()(0), 0.0);
  EXPECT_EQ(outputs[1].flat<float>()(0), 3.0);

  // The relu6 quantization range should remain unchanged.
  TF_ASSERT_OK(sess->Run({}, {min_const_name, max_const_name}, {}, &outputs));
  EXPECT_EQ(outputs[0].flat<float>()(0), 0.0);
  EXPECT_EQ(outputs[1].flat<float>()(0), 6.0);

  // Now when we run with new inputs, we should get a moving average for the min
  // and max variables. They should be equal to:
  // min_var = old_min_var * decay + min(a2) * (1 - decay)
  // max_var = old_max_var * decay + max(a2) * (1 - decay)
  TF_ASSERT_OK(sess->Run({{"a", a2}}, {m1->name()}, {}, &outputs));

  TF_ASSERT_OK(sess->Run({}, {min_var_name, max_var_name}, {}, &outputs));
  const float decay = 0.999;
  const float expected_min = 0.0 * decay + 1.0 * (1.0 - decay);
  const float expected_max = 3.0 * decay + 4.0 * (1.0 - decay);
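  // As in the QuantizeAndDequantize variant above, this works out to
  // expected_min == 0.001 and expected_max == 3.001.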
  EXPECT_NEAR(outputs[0].flat<float>()(0), expected_min, 1e-4);
  EXPECT_NEAR(outputs[1].flat<float>()(0), expected_max, 1e-4);

  // The relu6 quantization range should remain unchanged.
  TF_ASSERT_OK(sess->Run({}, {min_const_name, max_const_name}, {}, &outputs));
  EXPECT_EQ(outputs[0].flat<float>()(0), 0.0);
  EXPECT_EQ(outputs[1].flat<float>()(0), 6.0);
}

}  // namespace
}  // namespace tensorflow