# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for third_party.tensorflow.contrib.quantize.python.quant_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.quantize.python import quant_ops
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest

# Collection name passed as `vars_collection` to the quantize ops below; the
# tests read the min/max variables back out of this collection.
_MIN_MAX_VARS = 'min_max_vars'


class QuantOpsTest(googletest.TestCase):

  def testLastValueQuantizeTrainingAssign(self):
    """After one training step the min/max vars equal the fed extremes."""
    g = ops.Graph()
    with session.Session(graph=g) as sess:
      x = array_ops.placeholder(dtypes.float32, shape=[2])
      y = quant_ops.LastValueQuantize(
          x,
          init_min=0.0,
          init_max=0.0,
          is_training=True,
          vars_collection=_MIN_MAX_VARS)

      # Run the step.
      sess.run(variables.global_variables_initializer())
      sess.run(y, feed_dict={x: [-1.0, 1.0]})
      # Now check that the min_max_vars were, in fact, updated.
      min_value, max_value = self._GetMinMaxValues(sess)
      self.assertEqual(min_value, -1.0)
      self.assertEqual(max_value, 1.0)

  def testMovingAvgQuantizeTrainingAssign(self):
    """After two training steps the min/max vars lie strictly between inputs.

    The first step feeds [-1, 1] and the second feeds [0, 0]. The test
    asserts that min ends up in (-1, 0) and max ends up in (0, 1).
    """
    g = ops.Graph()
    with session.Session(graph=g) as sess:
      x = array_ops.placeholder(dtypes.float32, shape=[2])
      y = quant_ops.MovingAvgQuantize(
          x,
          init_min=0.0,
          init_max=0.0,
          is_training=True,
          vars_collection=_MIN_MAX_VARS)

      # Run the step.
      sess.run(variables.global_variables_initializer())
      # Do two runs to avoid zero debias.
      sess.run(y, feed_dict={x: [-1.0, 1.0]})
      sess.run(y, feed_dict={x: [0.0, 0.0]})
      # Now check that the min_max_vars were, in fact, updated.
      min_value, max_value = self._GetMinMaxValues(sess)
      self.assertGreater(min_value, -1.0)
      self.assertLess(min_value, 0.0)
      self.assertGreater(max_value, 0.0)
      self.assertLess(max_value, 1.0)

  def _GetMinMaxValues(self, sess):
    """Evaluates the two variables registered in `_MIN_MAX_VARS`.

    Args:
      sess: Session whose graph holds the variables. The collection is read
        from `sess.graph` rather than the default graph, so the result does
        not depend on which graph is the default when this is called.

    Returns:
      A `(min_value, max_value)` pair as returned by `sess.run`.
    """
    min_max_vars = sess.graph.get_collection(_MIN_MAX_VARS)
    self.assertEqual(len(min_max_vars), 2)
    # Inspect only the last component of the scoped name, so an enclosing
    # scope whose name contains 'min' cannot make us pick the wrong variable.
    first_base_name = min_max_vars[0].name.split('/')[-1]
    min_idx = 0 if 'min' in first_base_name else 1
    max_idx = 1 - min_idx
    min_var, max_var = min_max_vars[min_idx], min_max_vars[max_idx]
    min_value, max_value = sess.run([min_var, max_var])
    return min_value, max_value


if __name__ == '__main__':
  googletest.main()