1955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumar# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumar#
3955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumar# Licensed under the Apache License, Version 2.0 (the "License");
4955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumar# you may not use this file except in compliance with the License.
5955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumar# You may obtain a copy of the License at
6955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumar#
7955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumar#     http://www.apache.org/licenses/LICENSE-2.0
8955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumar#
9955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumar# Unless required by applicable law or agreed to in writing, software
10955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumar# distributed under the License is distributed on an "AS IS" BASIS,
11955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumar# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumar# See the License for the specific language governing permissions and
13955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumar# limitations under the License.
14955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumar# ==============================================================================
15955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumar"""Unit tests for the quantize_graph graph rewriting API."""
16955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumar
17955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumarfrom __future__ import absolute_import
18955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumarfrom __future__ import division
19955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumarfrom __future__ import print_function
20955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumar
21df299e1a0c91f50acf4868c7bb3e0ea93b52db7bSuharsh Sivakumarfrom tensorflow.contrib.layers.python.layers import layers
22955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumarfrom tensorflow.contrib.quantize.python import quantize_graph
23955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumarfrom tensorflow.python.framework import ops
24955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumarfrom tensorflow.python.framework import test_util
25df299e1a0c91f50acf4868c7bb3e0ea93b52db7bSuharsh Sivakumarfrom tensorflow.python.ops import array_ops
26df299e1a0c91f50acf4868c7bb3e0ea93b52db7bSuharsh Sivakumarfrom tensorflow.python.ops import init_ops
27df299e1a0c91f50acf4868c7bb3e0ea93b52db7bSuharsh Sivakumarfrom tensorflow.python.ops import nn_ops
28955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumarfrom tensorflow.python.platform import googletest
29955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumar
30955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumar
class QuantizeGraphTest(test_util.TensorFlowTestCase):
  # We have a lot of other tests that test the details of the rewrite, here we
  # just test the specific features of the quantize_graph API.

  # Name fragments of the ops that update the min/max variables. The training
  # rewrites must create all of them; the eval rewrites must create none.
  _UPDATE_OP_NAMES = ('AssignMinLast', 'AssignMinEma', 'AssignMaxLast',
                      'AssignMaxEma')

  # Name fragments of the FakeQuant ops inserted for activations.
  _ACT_QUANT_NAMES = ('act_quant', 'conv_quant', 'add_quant')

  def _RunTestOverRewrites(self, rewrite_fns, test_fn):
    """Calls test_fn once per rewrite function in rewrite_fns."""
    for fn in rewrite_fns:
      test_fn(fn)

  def _RunTestOverAllRewrites(self, test_fn):
    """Runs test_fn against every training and eval rewrite function."""
    self._RunTestOverRewrites([
        quantize_graph.create_training_graph,
        quantize_graph.create_eval_graph,
        quantize_graph.experimental_create_training_graph,
        quantize_graph.experimental_create_eval_graph,
    ], test_fn)

  def _RunTestOverTrainingRewrites(self, test_fn):
    """Runs test_fn against the training rewrite functions only."""
    self._RunTestOverRewrites([
        quantize_graph.create_training_graph,
        quantize_graph.experimental_create_training_graph,
    ], test_fn)

  def _RunTestOverEvalRewrites(self, test_fn):
    """Runs test_fn against the eval rewrite functions only."""
    self._RunTestOverRewrites([
        quantize_graph.create_eval_graph,
        quantize_graph.experimental_create_eval_graph,
    ], test_fn)

  def _RunTestOverExperimentalRewrites(self, test_fn):
    """Runs test_fn against the experimental rewrite functions only."""
    self._RunTestOverRewrites([
        quantize_graph.experimental_create_training_graph,
        quantize_graph.experimental_create_eval_graph,
    ], test_fn)

  def testRewrite(self):
    self._RunTestOverAllRewrites(self._TestRewrite)

  def _TestRewrite(self, rewrite_fn):
    """Checks that rewrite_fn(graph) adds variables to an explicit graph."""
    graph = ops.Graph()
    with graph.as_default():
      self._ConvLayer()

    orig_variable_names = {
        v.name for v in graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    }

    rewrite_fn(graph)

    q_variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    # Ensure that variables were added.
    self.assertLess(len(orig_variable_names), len(q_variables))

  def testDefaultGraph(self):
    # Must exercise _TestDefaultGraph (the no-argument call path); running
    # _TestRewrite here would only re-test the explicit-graph path.
    self._RunTestOverAllRewrites(self._TestDefaultGraph)

  def _TestDefaultGraph(self, rewrite_fn):
    """Checks that rewrite_fn() with no args rewrites the default graph."""
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      orig_variable_names = {
          v.name for v in g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      }
      rewrite_fn()

      q_variables = g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      # Ensure that variables were added.
      self.assertLess(len(orig_variable_names), len(q_variables))

  def testQuantDelay(self):
    self._RunTestOverTrainingRewrites(self._TestQuantDelay)

  def _TestQuantDelay(self, rewrite_fn):
    """Checks that quant_delay is baked into the 'activate_quant' constants."""
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      quant_delay = 100
      rewrite_fn(quant_delay=quant_delay)

    quant_delay_found = False
    for op in g.get_operations():
      # Check to see if the quant_delay is correctly set.
      if 'activate_quant' in op.name and op.type == 'Const':
        quant_delay_found = True
        const_value = str(op.get_attr('value'))
        self.assertIn('int64_val: %i' % quant_delay, const_value)
    self.assertTrue(quant_delay_found)

  def testWeightBits(self):
    self._RunTestOverExperimentalRewrites(self._TestWeightBits)

  def _TestWeightBits(self, rewrite_fn):
    """Checks that weight FakeQuant ops use the requested bit width."""
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      weight_bits = 4
      rewrite_fn(weight_bits=weight_bits)

    weights_quant_found = False
    for op in g.get_operations():
      # Check to see if FakeQuant operations for weights have the right bits
      # set.
      if 'weights_quant' in op.name and op.type == 'FakeQuantWithMinMaxVars':
        weights_quant_found = True
        self.assertEqual(op.get_attr('num_bits'), weight_bits)
    self.assertTrue(weights_quant_found)

  def testActivationBits(self):
    self._RunTestOverExperimentalRewrites(self._TestActivationBits)

  def _TestActivationBits(self, rewrite_fn):
    """Checks that activation FakeQuant ops use the requested bit width."""
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      activation_bits = 4
      rewrite_fn(activation_bits=activation_bits)

    act_quant_found = False
    for op in g.get_operations():
      # Check to see if FakeQuant operations for activations have the right bits
      # set.
      if (op.type == 'FakeQuantWithMinMaxVars' and
          any(s in op.name for s in self._ACT_QUANT_NAMES)):
        act_quant_found = True
        self.assertEqual(op.get_attr('num_bits'), activation_bits)
    self.assertTrue(act_quant_found)

  def testTrainingQuantization(self):
    self._RunTestOverTrainingRewrites(self._TestTrainingQuantization)

  def _TestTrainingQuantization(self, rewrite_fn):
    """Checks that training rewrites add FakeQuant and min/max update ops."""
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      rewrite_fn()

    # Ensure that FakeQuant and variable update nodes were found.
    quant_found = False
    found_update_names = set()
    for op in g.get_operations():
      # Check that FakeQuant operations were added.
      if op.type == 'FakeQuantWithMinMaxVars':
        quant_found = True
      # Record which update operations for the added min max variables exist
      # in the graph.
      found_update_names.update(
          s for s in self._UPDATE_OP_NAMES if s in op.name)
    # Comparing sets reports exactly which update ops are missing on failure.
    self.assertEqual(set(self._UPDATE_OP_NAMES), found_update_names)
    self.assertTrue(quant_found)

  def testEvalQuantization(self):
    self._RunTestOverEvalRewrites(self._TestEvalQuantization)

  def _TestEvalQuantization(self, rewrite_fn):
    """Checks that eval rewrites add FakeQuant ops but no min/max updates."""
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      rewrite_fn()

    # Ensure that FakeQuant nodes were found and no variable update nodes were.
    quant_found = False
    for op in g.get_operations():
      # Check that FakeQuant operations were added.
      if op.type == 'FakeQuantWithMinMaxVars':
        quant_found = True
      # Check that update operations for the added min max variables don't
      # exist in the graph.
      self.assertFalse(
          any(s in op.name for s in self._UPDATE_OP_NAMES),
          msg='Unexpected update op in eval graph: %s' % op.name)
    self.assertTrue(quant_found)

  def _ConvLayer(self):
    """Add a basic convolution layer to the default graph."""
    batch_size, height, width, depth = 5, 128, 128, 3
    inputs = array_ops.zeros((batch_size, height, width, depth))
    weight_init = init_ops.truncated_normal_initializer
    conv = layers.conv2d(
        inputs,
        32, [5, 5],
        stride=2,
        padding='SAME',
        weights_initializer=weight_init(0.09),
        activation_fn=None,
        scope='test')
    _ = nn_ops.relu6(conv)
228955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumar
229955c525d416c163c9dd857e637b0476b112b0ea0Suharsh Sivakumar
if __name__ == '__main__':
  # Run the test cases defined in this module via TensorFlow's googletest
  # entry point.
  googletest.main()
232