# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for the quantize_graph graph rewriting API."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import quantize_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest


class QuantizeGraphTest(test_util.TensorFlowTestCase):
  # We have a lot of other tests that test the details of the rewrite, here we
  # just test the specific features of the quantize_graph API.

  def _RunTestOverRewrites(self, test_fn, rewrite_fns):
    """Calls test_fn once per rewrite function.

    Args:
      test_fn: Callable taking a single rewrite function as its argument.
      rewrite_fns: Iterable of quantize_graph rewrite functions.
    """
    for fn in rewrite_fns:
      test_fn(fn)

  def _RunTestOverAllRewrites(self, test_fn):
    """Runs test_fn against every training and eval rewrite function."""
    self._RunTestOverRewrites(test_fn, [
        quantize_graph.create_training_graph,
        quantize_graph.create_eval_graph,
        quantize_graph.experimental_create_training_graph,
        quantize_graph.experimental_create_eval_graph,
    ])

  def _RunTestOverTrainingRewrites(self, test_fn):
    """Runs test_fn against the training rewrite functions only."""
    self._RunTestOverRewrites(test_fn, [
        quantize_graph.create_training_graph,
        quantize_graph.experimental_create_training_graph,
    ])

  def _RunTestOverEvalRewrites(self, test_fn):
    """Runs test_fn against the eval rewrite functions only."""
    self._RunTestOverRewrites(test_fn, [
        quantize_graph.create_eval_graph,
        quantize_graph.experimental_create_eval_graph,
    ])

  def _RunTestOverExperimentalRewrites(self, test_fn):
    """Runs test_fn against the experimental_* rewrite functions only."""
    self._RunTestOverRewrites(test_fn, [
        quantize_graph.experimental_create_training_graph,
        quantize_graph.experimental_create_eval_graph,
    ])

  def testRewrite(self):
    self._RunTestOverAllRewrites(self._TestRewrite)

  def _TestRewrite(self, rewrite_fn):
    """Checks that rewrite_fn adds variables to an explicitly passed graph."""
    graph = ops.Graph()
    with graph.as_default():
      self._ConvLayer()

    orig_variable_names = {
        v.name for v in graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    }

    rewrite_fn(graph)

    q_variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    # Ensure that variables were added.
    self.assertLess(len(orig_variable_names), len(q_variables))

  def testDefaultGraph(self):
    # Previously this ran _TestRewrite, leaving _TestDefaultGraph unexercised.
    self._RunTestOverAllRewrites(self._TestDefaultGraph)

  def _TestDefaultGraph(self, rewrite_fn):
    """Checks that rewrite_fn rewrites the default graph when given no args."""
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      orig_variable_names = {
          v.name for v in g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      }
      # No graph argument: rewrite_fn must pick up `g` as the default graph.
      rewrite_fn()

    q_variables = g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    # Ensure that variables were added.
    self.assertLess(len(orig_variable_names), len(q_variables))

  def testQuantDelay(self):
    self._RunTestOverTrainingRewrites(self._TestQuantDelay)

  def _TestQuantDelay(self, rewrite_fn):
    """Checks that quant_delay shows up in the 'activate_quant' Const ops."""
    quant_delay = 100
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      rewrite_fn(quant_delay=quant_delay)

    quant_delay_found = False
    for op in g.get_operations():
      # Check to see if the quant_delay is correctly set.
      if 'activate_quant' in op.name and op.type == 'Const':
        quant_delay_found = True
        const_value = str(op.get_attr('value'))
        self.assertIn('int64_val: %i' % quant_delay, const_value)
    self.assertTrue(quant_delay_found)

  def testWeightBits(self):
    self._RunTestOverExperimentalRewrites(self._TestWeightBits)

  def _TestWeightBits(self, rewrite_fn):
    """Checks that weight FakeQuant ops use the requested num_bits."""
    weight_bits = 4
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      rewrite_fn(weight_bits=weight_bits)

    weights_quant_found = False
    for op in g.get_operations():
      # Check to see if FakeQuant operations for weights have the right bits
      # set.
      if 'weights_quant' in op.name and op.type == 'FakeQuantWithMinMaxVars':
        weights_quant_found = True
        self.assertEqual(op.get_attr('num_bits'), weight_bits)
    self.assertTrue(weights_quant_found)

  def testActivationBits(self):
    self._RunTestOverExperimentalRewrites(self._TestActivationBits)

  def _TestActivationBits(self, rewrite_fn):
    """Checks that activation FakeQuant ops use the requested num_bits."""
    activation_bits = 4
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      rewrite_fn(activation_bits=activation_bits)

    # Name fragments identifying FakeQuant ops inserted on activations.
    act_quant_names = ('act_quant', 'conv_quant', 'add_quant')
    act_quant_found = False
    for op in g.get_operations():
      # Check to see if FakeQuant operations for activations have the right bits
      # set.
      if (op.type == 'FakeQuantWithMinMaxVars' and
          any(s in op.name for s in act_quant_names)):
        act_quant_found = True
        self.assertEqual(op.get_attr('num_bits'), activation_bits)
    self.assertTrue(act_quant_found)

  def testTrainingQuantization(self):
    self._RunTestOverTrainingRewrites(self._TestTrainingQuantization)

  def _TestTrainingQuantization(self, rewrite_fn):
    """Checks that training rewrites add FakeQuant and min/max update ops."""
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      rewrite_fn()

    # Ensure that FakeQuant and variable update nodes were found.
    quant_found = False
    assign_min_last_found = False
    assign_min_ema_found = False
    assign_max_last_found = False
    assign_max_ema_found = False
    for op in g.get_operations():
      # Check that FakeQuant operations were added.
      if op.type == 'FakeQuantWithMinMaxVars':
        quant_found = True
      # Check that update operations for the added min max variables exist in
      # the graph.
      if 'AssignMinLast' in op.name:
        assign_min_last_found = True
      elif 'AssignMinEma' in op.name:
        assign_min_ema_found = True
      elif 'AssignMaxLast' in op.name:
        assign_max_last_found = True
      elif 'AssignMaxEma' in op.name:
        assign_max_ema_found = True
    self.assertTrue(assign_min_last_found)
    self.assertTrue(assign_min_ema_found)
    self.assertTrue(assign_max_last_found)
    self.assertTrue(assign_max_ema_found)
    self.assertTrue(quant_found)

  def testEvalQuantization(self):
    self._RunTestOverEvalRewrites(self._TestEvalQuantization)

  def _TestEvalQuantization(self, rewrite_fn):
    """Checks that eval rewrites add FakeQuant ops but no min/max updates."""
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      rewrite_fn()

    # Names of the min/max variable update ops that must NOT appear in an
    # eval graph.
    update_names = (
        'AssignMinLast', 'AssignMinEma', 'AssignMaxLast', 'AssignMaxEma'
    )
    # Ensure that FakeQuant nodes were added and no variable update nodes were.
    quant_found = False
    for op in g.get_operations():
      # Check that FakeQuant operations were added.
      if op.type == 'FakeQuantWithMinMaxVars':
        quant_found = True
      # Check that update operations for the added min max variables don't
      # exist in the graph.
      self.assertFalse(
          any(s in op.name for s in update_names),
          'Unexpected min/max update op in eval graph: %s' % op.name)
    self.assertTrue(quant_found)

  def _ConvLayer(self):
    """Add a basic convolution layer to the default graph."""
    batch_size, height, width, depth = 5, 128, 128, 3
    inputs = array_ops.zeros((batch_size, height, width, depth))
    weight_init = init_ops.truncated_normal_initializer
    conv = layers.conv2d(
        inputs,
        32, [5, 5],
        stride=2,
        padding='SAME',
        weights_initializer=weight_init(0.09),
        activation_fn=None,
        scope='test')
    _ = nn_ops.relu6(conv)


if __name__ == '__main__':
  googletest.main()