1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for tensorflow.contrib.quantize.python.quant_ops."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.contrib.quantize.python import quant_ops
22from tensorflow.python.client import session
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import variables
27from tensorflow.python.platform import googletest
28
# Collection name passed as `vars_collection` to the quantize ops under test;
# the tests read the min/max variables back out of this collection.
_MIN_MAX_VARS = 'min_max_vars'
30
31
class QuantOpsTest(googletest.TestCase):
  """Checks that training-mode quantize ops update their min/max variables."""

  def testLastValueQuantizeTrainingAssign(self):
    """One run on [-1.0, 1.0] must leave min == -1.0 and max == 1.0."""
    g = ops.Graph()
    with session.Session(graph=g) as sess:
      x = array_ops.placeholder(dtypes.float32, shape=[2])
      y = quant_ops.LastValueQuantize(
          x,
          init_min=0.0,
          init_max=0.0,
          is_training=True,
          vars_collection=_MIN_MAX_VARS)

      # Run the step.
      sess.run(variables.global_variables_initializer())
      sess.run(y, feed_dict={x: [-1.0, 1.0]})
      # Now check that the min_max_vars were, in fact, updated.
      min_value, max_value = self._GetMinMaxValues(sess)
      self.assertEqual(min_value, -1.0)
      self.assertEqual(max_value, 1.0)

  def testMovingAvgQuantizeTrainingAssign(self):
    """Two runs must move min into (-1.0, 0.0) and max into (0.0, 1.0)."""
    g = ops.Graph()
    with session.Session(graph=g) as sess:
      x = array_ops.placeholder(dtypes.float32, shape=[2])
      y = quant_ops.MovingAvgQuantize(
          x,
          init_min=0.0,
          init_max=0.0,
          is_training=True,
          vars_collection=_MIN_MAX_VARS)

      # Run the step.
      sess.run(variables.global_variables_initializer())
      # Do two runs to avoid zero debias.
      sess.run(y, feed_dict={x: [-1.0, 1.0]})
      sess.run(y, feed_dict={x: [0.0, 0.0]})
      # Now check that the min_max_vars were, in fact, updated. Both bounds are
      # expected to lie strictly between their initial value (0.0) and the
      # extreme that was fed in the first run.
      min_value, max_value = self._GetMinMaxValues(sess)
      self.assertGreater(min_value, -1.0)
      self.assertLess(min_value, 0.0)
      self.assertGreater(max_value, 0.0)
      self.assertLess(max_value, 1.0)

  def _GetMinMaxValues(self, sess):
    """Fetches the current values of the min and max variables.

    The two variables are looked up in the `_MIN_MAX_VARS` collection of the
    graph that `sess` runs, so the result does not depend on which graph
    happens to be the default one at call time.

    Args:
      sess: Session whose graph holds the variables and which evaluates them.

    Returns:
      A `(min_value, max_value)` tuple of the evaluated variables.
    """
    # Use the session's own graph rather than the implicit default graph.
    min_max_vars = sess.graph.get_collection(_MIN_MAX_VARS)
    self.assertEqual(len(min_max_vars), 2)
    # The collection order is not relied upon: the variable whose name contains
    # 'min' is treated as the minimum, the other one as the maximum.
    min_idx = 0 if 'min' in min_max_vars[0].name else 1
    max_idx = (min_idx + 1) % 2
    min_var, max_var = min_max_vars[min_idx], min_max_vars[max_idx]
    min_max_values = sess.run([min_var, max_var])
    return min_max_values[0], min_max_values[1]
84
85
if __name__ == '__main__':
  # Hand control to the googletest runner when executed as a script.
  googletest.main()
88