1122cdce33e3e0a01a7f82645617317530aa571fbA. Unique TensorFlower# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden#
3ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden# Licensed under the Apache License, Version 2.0 (the "License");
4ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden# you may not use this file except in compliance with the License.
5ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden# You may obtain a copy of the License at
6ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden#
7ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden#     http://www.apache.org/licenses/LICENSE-2.0
8ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden#
9ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden# Unless required by applicable law or agreed to in writing, software
10ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden# distributed under the License is distributed on an "AS IS" BASIS,
11ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden# See the License for the specific language governing permissions and
13ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden# limitations under the License.
14ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden# ==============================================================================
15ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden"""Tests the graph quantization script.
16ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
17ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden"""
1836d561c253844323fb1de28303407b451b7886a7Geoffrey Irving
1936d561c253844323fb1de28303407b451b7886a7Geoffrey Irvingfrom __future__ import absolute_import
2036d561c253844323fb1de28303407b451b7886a7Geoffrey Irvingfrom __future__ import division
2136d561c253844323fb1de28303407b451b7886a7Geoffrey Irvingfrom __future__ import print_function
2236d561c253844323fb1de28303407b451b7886a7Geoffrey Irving
234b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlowerimport sys
2466024fd508748d706b72d0ae5e8b07f917e78458Andrew Harpimport numpy as np
2566024fd508748d706b72d0ae5e8b07f917e78458Andrew Harp
26e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.core.framework import graph_pb2
27e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.client import session
28e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.framework import dtypes
29a558c6e3b38846727873b5afbbc3ba309ae5dff5Olivia Nordquistfrom tensorflow.python.framework import graph_util
30e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.framework import importer
31e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.framework import ops as ops_lib
32e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.platform import flags as flags_lib
33e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.platform import test
34e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.platform import tf_logging
3566024fd508748d706b72d0ae5e8b07f917e78458Andrew Harpfrom tensorflow.tools.quantization import quantize_graph
36ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
37e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyflags = flags_lib
38ca4e053aa52ab9a42467d4df814ca9272487dbdfPete WardenFLAGS = flags.FLAGS
39ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
40ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
41ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardendef run_graph_def(graph_def, input_map, outputs):
42e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney  graph = ops_lib.Graph()
43ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  with graph.as_default():
44e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    importer.import_graph_def(graph_def, input_map={}, name="")
45e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney  with session.Session(graph=graph) as sess:
46ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    results = sess.run(outputs, feed_dict=input_map)
47ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  return results
48ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
49ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
50ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardendef test_mat_mul(m, n, k, a, b):
51ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  """Tests a MatMul replacement."""
52ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  a_constant_name = "a_constant"
53ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  b_constant_name = "b_constant"
54ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  mat_mul_name = "mat_mul"
55ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
56e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney  float_graph_def = graph_pb2.GraphDef()
57e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney  a_constant = quantize_graph.create_constant_node(
58e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney      a_constant_name, value=a, dtype=dtypes.float32, shape=[m, k])
59ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  float_graph_def.node.extend([a_constant])
60e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney  b_constant = quantize_graph.create_constant_node(
61e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney      b_constant_name, value=b, dtype=dtypes.float32, shape=[k, n])
62ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  float_graph_def.node.extend([b_constant])
63ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  mat_mul_node = quantize_graph.create_node("MatMul", mat_mul_name,
64ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                                            [a_constant_name, b_constant_name])
65e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney  quantize_graph.set_attr_dtype(mat_mul_node, "T", dtypes.float32)
66ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  quantize_graph.set_attr_bool(mat_mul_node, "transpose_a", False)
67ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  quantize_graph.set_attr_bool(mat_mul_node, "transpose_b", False)
68ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  float_graph_def.node.extend([mat_mul_node])
69ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
70ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  test_graph(float_graph_def, {}, [mat_mul_name])
71ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
72ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
73ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardendef test_conv(depth, image_width, image_height, image_batch_count, filter_size,
74ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden              filter_count, stride, padding, input_values, filter_values):
75ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  """Tests a Conv replacement."""
76ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  input_constant_name = "input_constant"
77ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  filter_constant_name = "filter_constant"
78ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  conv_name = "conv"
79ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
80e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney  float_graph_def = graph_pb2.GraphDef()
81ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  input_constant = quantize_graph.create_constant_node(
82ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      input_constant_name,
83ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      value=input_values,
84e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney      dtype=dtypes.float32,
85e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney      shape=[image_batch_count, image_height, image_width, depth])
86ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  float_graph_def.node.extend([input_constant])
87ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  filter_constant = quantize_graph.create_constant_node(
88ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      filter_constant_name,
89ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      value=filter_values,
90e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney      dtype=dtypes.float32,
91e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney      shape=[filter_size, filter_size, depth, filter_count])
92ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  float_graph_def.node.extend([filter_constant])
93e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney  conv_node = quantize_graph.create_node(
94e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney      "Conv2D", conv_name, [input_constant_name, filter_constant_name])
95e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney  quantize_graph.set_attr_dtype(conv_node, "T", dtypes.float32)
96ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  quantize_graph.set_attr_int_list(conv_node, "strides", [1, stride, stride, 1])
97ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  quantize_graph.set_attr_string(conv_node, "padding", padding)
98ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  float_graph_def.node.extend([conv_node])
99ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
100ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  test_graph(float_graph_def, {}, [conv_name])
101ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
102ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
103ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardendef are_tensors_near(a, b, tolerance):
104ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  """Tests whether two tensors are nearly identical.
105ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
106ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  This is a specialized comparison function designed to help debug problems with
107ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  quantization. It prints out information about the differences between tensors
108ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  on failure, paying special attention to possible biases by looking at the mean
109ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  and absolute average errors.
110ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
111ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  Args:
112ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    a: First comparison tensor.
113ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    b: Second comparison tensor.
114ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    tolerance: Float value indicating how large an error between values is ok.
115ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
116ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  Returns:
117ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    Boolean indicating whether the two inputs were close enough.
118ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  """
119ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  flat_a = a.flatten()
120ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  flat_b = b.flatten()
121ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  if len(flat_a) != len(flat_b):
122e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    print("Tensors are different sizes: " + str(len(flat_a)) + " vs " + str(
123e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        len(flat_b)))
124ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    return False
125ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  value_count = len(flat_a)
126ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  how_many_different = 0
127ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  total_difference = 0
128ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  total_abs_difference = 0
129ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  for index in range(value_count):
130ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    a_value = flat_a[index]
131ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    b_value = flat_b[index]
132ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    difference = a_value - b_value
133ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    total_difference += difference
134ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    total_abs_difference += abs(difference)
135ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    if abs(difference) > tolerance:
136ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      how_many_different += 1
137ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  mean_difference = total_difference / value_count
138ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  mean_abs_difference = total_abs_difference / value_count
139ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  proportion_different = (how_many_different * 1.0) / value_count
140ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  if how_many_different == 0:
141ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    return True
142ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  else:
143ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    print("Tensors have {0} different values ({1}%), with mean difference"
14457915d504f1910afb43052dd337a35a1becffec7A. Unique TensorFlower          " {2} and mean absolute difference {3}".format(
145ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden              how_many_different, proportion_different * 100, mean_difference,
14657915d504f1910afb43052dd337a35a1becffec7A. Unique TensorFlower              mean_abs_difference))
147ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    return False
148ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
149ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
150ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardendef get_top_value(input_values):
151ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  max_value = None
152ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  max_index = None
153ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  for index, value in enumerate(input_values.flatten()):
154ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    if max_value is None or value > max:
155ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      max_value = value
156ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden      max_index = index
157ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  return max_index, max_value
158ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
159ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
16095f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlowerdef test_graph(float_graph_def, input_map, output_names, log_graph=False):
161ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  """Runs the float graph through the rewriter and tests the results."""
162e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney  float_results = run_graph_def(
163e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney      float_graph_def, input_map,
164e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney      [output_name + ":0" for output_name in output_names])
165ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  # TODO(petewarden): round test is currently failing because there is no
166ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  # RoundToSteps op available.
167ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  # round_rewriter = quantize_graph.GraphRewriter(float_graph_def, "round")
168ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  # round_graph_def = round_rewriter.rewrite(output_name)
169ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  # round_results = run_graph_def(round_graph_def, input_map,
170ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  #                               [output_name + ":0"])
171ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  # assert are_tensors_near(expected, round_results[0], 1.0)
172ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  #
173ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  # TODO(petewarden): Add test for "quantize" mode.
174ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
175e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney  eightbit_rewriter = quantize_graph.GraphRewriter(
176e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney      float_graph_def, "eightbit", quantized_input_range=None)
177ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  eightbit_graph_def = eightbit_rewriter.rewrite(output_names)
178e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney  eightbit_results = run_graph_def(
179e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney      eightbit_graph_def, input_map,
180e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney      [output_name + ":0" for output_name in output_names])
181ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  for expected, result in zip(float_results, eightbit_results):
182ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    assert are_tensors_near(expected, result, 1.0)
183ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
18495f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower  if log_graph:
185e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    tf_logging.info("8bit:\n%s", str(eightbit_graph_def))
18695f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower
187ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  # Test the weights_rounded mode. This uses the default bit_depth.
188ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  weights_rounded_rewriter = quantize_graph.GraphRewriter(
1894b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower      float_graph_def, "weights_rounded", quantized_input_range=None)
190ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  weights_rounded_graph_def = weights_rounded_rewriter.rewrite(output_names)
1914b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower  weights_rounded_results = run_graph_def(
1924b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower      weights_rounded_graph_def, input_map,
1934b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower      [output_name + ":0" for output_name in output_names])
194ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  for expected, result in zip(float_results, weights_rounded_results):
195ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    assert are_tensors_near(expected, result, 1.0)
196ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
197ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
198e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyclass QuantizeGraphTest(test.TestCase):
199ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
200e24388242026245244435235ea66fd3693942c67A. Unique TensorFlower  def test_negative_const_problem(self):
201e24388242026245244435235ea66fd3693942c67A. Unique TensorFlower    shape_constant_name = "shape_constant"
202e24388242026245244435235ea66fd3693942c67A. Unique TensorFlower    shape_constant = quantize_graph.create_constant_node(
203e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        shape_constant_name, value=-0.8, dtype=dtypes.float32, shape=[1])
204e24388242026245244435235ea66fd3693942c67A. Unique TensorFlower    quantization_result = quantize_graph.quantize_weight_eightbit(
205e24388242026245244435235ea66fd3693942c67A. Unique TensorFlower        shape_constant, b"MIN_COMBINED")
206e24388242026245244435235ea66fd3693942c67A. Unique TensorFlower    self.assertEqual(4, len(quantization_result))
207e24388242026245244435235ea66fd3693942c67A. Unique TensorFlower
208ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  def test_odd_padding_problem(self):
209ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    """Tests one error case we ran into in a real graph."""
210ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    test_conv(1, 4, 4, 1, 3, 1, 2, b"SAME",
211ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden              [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
212ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden              [1, 2, 3, 4, 5, 6, 7, 8, 9])
213ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
214ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  def test_mat_mul_tiny(self):
215ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    # These tests are added to test the generate case where
216ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    # min(matrix) == max(matrix), which used to cause problems.
217ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    test_mat_mul(1, 1, 1, [2], [3])
218ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    test_mat_mul(1, 2, 1, [1], [2, 3])
219ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    test_mat_mul(1, 1, 2, [1, 1], [1, 1])
220ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    test_mat_mul(1, 1, 2, [0, 0], [1, 1])
221ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    # The general case.
222ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    test_mat_mul(1, 1, 2, [1, 2], [1, 2])
223ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
224ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  def test_mat_mul_small(self):
225ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    test_mat_mul(2, 4, 3, [1, 2, 3, 4, 5, 6],
226ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                 [7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18])
227ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
228ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  def test_conv(self):
229ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    test_conv(1, 4, 3, 1, 3, 1, 1, b"SAME",
230ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden              [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
231ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden              [1, 4, 7, 2, 5, 8, 3, 6, 9])
232ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
233f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower  def test_reshape(self):
234f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower    """Tests that MatMul->Reshape->MatMul avoids extra quantize/dequantize."""
235e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney
236f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower    def make_matmul(name, a, b):
237f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower      n = quantize_graph.create_node("MatMul", name, [a.name, b.name])
238e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney      quantize_graph.set_attr_dtype(n, "T", dtypes.float32)
239f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower      quantize_graph.set_attr_bool(n, "transpose_a", False)
240f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower      quantize_graph.set_attr_bool(n, "transpose_b", False)
241f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower      return n
242f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower
243f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower    # matmul_1 = input*weight_1
244f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower    input_node = quantize_graph.create_constant_node(
245e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "input", value=[0, 1, 2, 3], dtype=dtypes.float32, shape=[4, 1])
246f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower    weight_1_node = quantize_graph.create_constant_node(
247e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "weight_1",
248e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        value=[.5, .6, .7, .8, .9],
249e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        dtype=dtypes.float32,
250e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        shape=[1, 5])
251f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower    matmul_1_node = make_matmul("matmul_1", input_node, weight_1_node)
252f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower
253f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower    # Reshape 4x5 to 10x2.
254f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower    new_shape_node = quantize_graph.create_constant_node(
255e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "new_shape_node", value=[10, 2], dtype=dtypes.int32, shape=[2])
256f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower    reshape_node = quantize_graph.create_node(
257f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower        "Reshape", "reshape", [matmul_1_node.name, new_shape_node.name])
258e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(reshape_node, "T", dtypes.float32)
259f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower
260f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower    # matmul_2_node = reshape*weight_2
261f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower    weight_2_node = quantize_graph.create_constant_node(
262e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "weight_2", value=[1.5, 2.5], dtype=dtypes.float32, shape=[2, 1])
263f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower    matmul_2_node = make_matmul("matmul_2", reshape_node, weight_2_node)
264f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower
265e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    g = graph_pb2.GraphDef()
266e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    g.node.extend([
267e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        input_node, weight_1_node, matmul_1_node, new_shape_node, reshape_node,
268e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        weight_2_node, matmul_2_node
269e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    ])
270f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower
271f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower    # Test the graph
272f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower    test_graph(g, {}, ["matmul_2"])
273f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower
274f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower    # Verify there is only one Quantize and one Requantize op.
275e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    eightbit_rewriter = quantize_graph.GraphRewriter(
276e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        g, "eightbit", quantized_input_range=None)
277f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower    eightbit_graph_def = eightbit_rewriter.rewrite(["matmul_2"])
278f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower
279f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower    ops = [node.op for node in eightbit_graph_def.node]
280f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower    # No quantize since all inputs are const and can be quantized up-front.
281f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower    self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
282662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower    self.assertEqual(1, ops.count("QuantizedReshape"))
283f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower
284f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower    # One dequantize at the end.
285f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower    self.assertEqual(1, ops.count("Dequantize"))
286f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower
287ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  def test_quantize_array(self):
288ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    # Test invalid parameters (empty array, or 0 buckets.
289e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    self.assertRaises(ValueError, quantize_graph.quantize_array, np.array([]),
290e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney                      2)
291ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    self.assertRaises(ValueError, quantize_graph.quantize_array,
292ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                      np.array([1, 2]), 0)
293ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    # Test input array of length 1.
294ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    arr = np.array([1])
295ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    qarr = quantize_graph.quantize_array(arr, 1)
296ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    self.assertEqual(arr, qarr)
297ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    qarr = quantize_graph.quantize_array(arr, 2)
298ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    self.assertEqual(arr, qarr)
299ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    # Test input array with all elements equal.
300ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    arr = np.array([1, 1, 1])
301ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    qarr = quantize_graph.quantize_array(arr, 10)
302ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    self.assertTrue((np.array([1, 1, 1]) == qarr).all())
303ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    # Test "normal" input arrays.
304ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    arr = np.array([0, 0.3, 0.6, 1])
305ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    qarr = quantize_graph.quantize_array(arr, 1)
306ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    self.assertTrue((np.array([0.5, 0.5, 0.5, 0.5]) == qarr).all())
307ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    qarr = quantize_graph.quantize_array(arr, 2)
308ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    self.assertTrue((np.array([0.25, 0.25, 0.75, 0.75]) == qarr).all())
309ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    qarr = quantize_graph.quantize_array(arr.reshape((2, 2)), 2)
310ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    self.assertTrue((np.array([[0.25, 0.25], [0.75, 0.75]]) == qarr).all())
311ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
312662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower  def test_non_float_concat(self):
313662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower    concat_dim = quantize_graph.create_constant_node(
314e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "concat_dim", value=0, dtype=dtypes.int32, shape=[])
315662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower    a = quantize_graph.create_constant_node(
316e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "a",
317e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
318e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        dtype=dtypes.int32,
319e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        shape=[2, 2, 3])
320662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower    b = quantize_graph.create_constant_node(
321e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "b",
322e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        value=[13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
323e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        dtype=dtypes.int32,
324e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        shape=[2, 2, 3])
325e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    concat = quantize_graph.create_node("Concat", "concat",
326e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney                                        [concat_dim.name, a.name, b.name])
327662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower    quantize_graph.set_attr_int(concat, "N", 2)
328e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(concat, "T", dtypes.int32)
329662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower
330e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    g = graph_pb2.GraphDef()
331662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower    g.node.extend([concat_dim, a, b, concat])
332662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower    test_graph(g, {}, [concat.name])
333662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower
334662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower  def test_non_float_reshape(self):
335662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower    a = quantize_graph.create_constant_node(
336e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "a",
337e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
338e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        dtype=dtypes.int32,
339e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        shape=[2, 2, 3])
340662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower    shape = quantize_graph.create_constant_node(
341e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "shape", value=[12], dtype=dtypes.int32, shape=[1])
342e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    reshape = quantize_graph.create_node("Reshape", "reshape",
343e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney                                         [a.name, shape.name])
344e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(reshape, "T", dtypes.int32)
345662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower
346e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    g = graph_pb2.GraphDef()
347662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower    g.node.extend([a, shape, reshape])
348662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower    test_graph(g, {}, [reshape.name])
349662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower
350ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  def test_concat(self):
351ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    shape_constant_name = "shape_constant"
352ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    a_constant_name = "a_constant"
353ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    b_constant_name = "b_constant"
354ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    concat_name = "concat"
355ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
356e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    float_graph_def = graph_pb2.GraphDef()
357e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    shape_constant = quantize_graph.create_constant_node(
358e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        shape_constant_name, value=0, dtype=dtypes.int32, shape=[])
359ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([shape_constant])
360e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    a_constant = quantize_graph.create_constant_node(
361e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        a_constant_name,
362e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
363e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        dtype=dtypes.float32,
364e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        shape=[2, 2, 3])
365ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([a_constant])
366e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    b_constant = quantize_graph.create_constant_node(
367e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        b_constant_name,
368e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        value=[13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
369e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        dtype=dtypes.float32,
370e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        shape=[2, 2, 3])
371ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([b_constant])
372e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    concat_node = quantize_graph.create_node(
373e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "Concat", concat_name,
374e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        [shape_constant_name, a_constant_name, b_constant_name])
375ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    quantize_graph.set_attr_int(concat_node, "N", 2)
376e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(concat_node, "T", dtypes.float32)
377ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([concat_node])
378ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
379ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    test_graph(float_graph_def, {}, [concat_name])
380ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
381662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower    # Verify the concat is quantized.
382662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower    eightbit_rewriter = quantize_graph.GraphRewriter(
383662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower        float_graph_def, "eightbit", quantized_input_range=None)
384662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower    eightbit_graph_def = eightbit_rewriter.rewrite([concat_name])
385662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower
386662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower    ops = [node.op for node in eightbit_graph_def.node]
387662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower    self.assertEqual(1, ops.count("QuantizedConcat"))
388662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower
389ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  def test_multiple_outputs(self):
390ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    input_constant_name = "input_constant"
391ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    split_constant_name = "split_constant"
392ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    split_name = "split"
393ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    concat_constant_name = "concat_constant"
394ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    concat_name = "concat"
395ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
396e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    float_graph_def = graph_pb2.GraphDef()
397e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    input_constant = quantize_graph.create_constant_node(
398e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        input_constant_name,
399e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
400e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        dtype=dtypes.float32,
401e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        shape=[2, 6])
402ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([input_constant])
403e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    split_constant = quantize_graph.create_constant_node(
404e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        split_constant_name, value=1, dtype=dtypes.int32, shape=[])
405ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([split_constant])
406e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    split_node = quantize_graph.create_node(
407e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "Split", split_name, [split_constant_name, input_constant_name])
408ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    quantize_graph.set_attr_int(split_node, "num_split", 2)
409e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(split_node, "T", dtypes.float32)
410ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([split_node])
411e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    concat_constant = quantize_graph.create_constant_node(
412e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        concat_constant_name, value=1, dtype=dtypes.int32, shape=[])
413ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([concat_constant])
414e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    concat_node = quantize_graph.create_node(
415e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "Concat", concat_name,
416e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        [concat_constant_name, split_name + ":0", split_name + ":1"])
417ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    quantize_graph.set_attr_int(concat_node, "N", 2)
418e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(concat_node, "T", dtypes.float32)
419ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([concat_node])
420ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
421ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    test_graph(float_graph_def, {}, [concat_name])
422ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
423ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  def test_node_name_from_input(self):
424ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    self.assertEqual("SomeName",
425ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                     quantize_graph.node_name_from_input("^SomeName:2"))
426ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
427ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  def test_unique_node_name_from_input(self):
428ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    self.assertEqual("__hat__SomeName__port__2",
429ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                     quantize_graph.unique_node_name_from_input("^SomeName:2"))
430ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
431ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  def test_identity(self):
432ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    input_constant_name = "input_constant"
433ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    identity_name = "identity"
434e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    float_graph_def = graph_pb2.GraphDef()
435e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    input_constant = quantize_graph.create_constant_node(
436e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        input_constant_name,
437e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
438e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        dtype=dtypes.float32,
439e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        shape=[2, 6])
440ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([input_constant])
441ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    identity_node = quantize_graph.create_node("Identity", identity_name,
442ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                                               [input_constant_name])
443e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(identity_node, "T", dtypes.float32)
444ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([identity_node])
4455328a426fe2d76dabd833e774711b2d56f13f9a8A. Unique TensorFlower
4465328a426fe2d76dabd833e774711b2d56f13f9a8A. Unique TensorFlower    mul_name = "mul"
4475328a426fe2d76dabd833e774711b2d56f13f9a8A. Unique TensorFlower    mul_node = quantize_graph.create_node("Mul", mul_name,
4485328a426fe2d76dabd833e774711b2d56f13f9a8A. Unique TensorFlower                                          [identity_name, identity_name])
449e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(mul_node, "T", dtypes.float32)
4505328a426fe2d76dabd833e774711b2d56f13f9a8A. Unique TensorFlower    float_graph_def.node.extend([mul_node])
4515328a426fe2d76dabd833e774711b2d56f13f9a8A. Unique TensorFlower
4525328a426fe2d76dabd833e774711b2d56f13f9a8A. Unique TensorFlower    test_graph(float_graph_def, {}, [mul_name])
453ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
454ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  def test_keep_control_edges(self):
455ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    no_op_name = "no_op"
456ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    a_constant_name = "a_constant"
457ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    b_constant_name = "b_constant"
458ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    a_check_name = "a_check"
459ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    b_check_name = "b_check"
460ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    a_identity_name = "a_identity"
461ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    b_identity_name = "b_identity"
462ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    add_name = "add"
463e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    graph_def = graph_pb2.GraphDef()
464ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    no_op = quantize_graph.create_node("NoOp", no_op_name, [])
465ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    graph_def.node.extend([no_op])
466e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    a_constant = quantize_graph.create_constant_node(
467e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        a_constant_name, value=1, dtype=dtypes.float32, shape=[])
468ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    graph_def.node.extend([a_constant])
469ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    a_check_node = quantize_graph.create_node("CheckNumerics", a_check_name,
470ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                                              [a_constant_name])
471ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    graph_def.node.extend([a_check_node])
472e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    a_identity_node = quantize_graph.create_node(
473e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "Identity", a_identity_name,
474e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        [a_constant_name, "^" + a_check_name, "^" + no_op_name])
475ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    graph_def.node.extend([a_identity_node])
476e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    b_constant = quantize_graph.create_constant_node(
477e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        b_constant_name, value=1, dtype=dtypes.float32, shape=[])
478ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    graph_def.node.extend([b_constant])
479ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    b_check_node = quantize_graph.create_node("CheckNumerics", b_check_name,
480ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                                              [b_constant_name])
481ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    graph_def.node.extend([b_check_node])
482e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    b_identity_node = quantize_graph.create_node(
483e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "Identity", b_identity_name, [b_constant_name, "^" + b_check_name])
484ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    graph_def.node.extend([b_identity_node])
485ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    add_node = quantize_graph.create_node("Add", add_name,
486e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney                                          [a_identity_name, b_identity_name])
487e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(add_node, "T", dtypes.float32)
488ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    graph_def.node.extend([add_node])
489ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
490e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    expected_output = graph_pb2.GraphDef()
491ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    no_op = quantize_graph.create_node("NoOp", no_op_name, [])
492ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    expected_output.node.extend([no_op])
493e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    a_constant = quantize_graph.create_constant_node(
494e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        a_constant_name, value=1, dtype=dtypes.float32, shape=[])
495ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    expected_output.node.extend([a_constant])
496e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    a_identity_node = quantize_graph.create_node(
497e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "Identity", a_identity_name, [a_constant_name, "^" + no_op_name])
498ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    expected_output.node.extend([a_identity_node])
499e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    b_constant = quantize_graph.create_constant_node(
500e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        b_constant_name, value=1, dtype=dtypes.float32, shape=[])
501ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    expected_output.node.extend([b_constant])
502ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    add_node = quantize_graph.create_node("Add", add_name,
503e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney                                          [a_identity_name, b_constant_name])
504e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(add_node, "T", dtypes.float32)
505ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    expected_output.node.extend([add_node])
506fe4ec328ac6210381cc6a0c9d98b35d35f8d9209A. Unique TensorFlower    expected_output.versions.CopyFrom(graph_def.versions)
507fe4ec328ac6210381cc6a0c9d98b35d35f8d9209A. Unique TensorFlower    expected_output.library.CopyFrom(graph_def.library)
508ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
5097c7014fd41cdf4e24f923b9e79c249d717aa508fPete Warden    output = graph_util.remove_training_nodes(graph_def)
510ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    stripped_output = graph_util.extract_sub_graph(output, [add_name])
511ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    self.assertProtoEquals(expected_output, stripped_output)
512ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
513ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  def test_batch_norm(self):
514ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    input_constant_name = "input_constant"
515ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    mean_constant_name = "mean_constant"
516ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    variance_constant_name = "variance_constant"
517ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    beta_constant_name = "beta_constant"
518ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    gamma_constant_name = "gamma_constant"
519ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    batch_norm_name = "batch_norm"
520e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    float_graph_def = graph_pb2.GraphDef()
521e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    input_constant = quantize_graph.create_constant_node(
522e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        input_constant_name,
523e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        value=[1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6],
524e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        dtype=dtypes.float32,
525e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        shape=[1, 1, 6, 2])
526ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([input_constant])
527e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    mean_constant = quantize_graph.create_constant_node(
528e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        mean_constant_name, value=[10, 20], dtype=dtypes.float32, shape=[2])
529ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([mean_constant])
530ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    variance_constant = quantize_graph.create_constant_node(
531e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        variance_constant_name,
532e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        value=[0.25, 0.5],
533e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        dtype=dtypes.float32,
534e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        shape=[2])
535ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([variance_constant])
536e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    beta_constant = quantize_graph.create_constant_node(
537e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        beta_constant_name, value=[0.1, 0.6], dtype=dtypes.float32, shape=[2])
538ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([beta_constant])
539e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    gamma_constant = quantize_graph.create_constant_node(
540e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        gamma_constant_name, value=[0, 0], dtype=dtypes.float32, shape=[2])
541ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([gamma_constant])
542ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    batch_norm_node = quantize_graph.create_node(
543e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "BatchNormWithGlobalNormalization", batch_norm_name, [
544e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney            input_constant_name, mean_constant_name, variance_constant_name,
545e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney            beta_constant_name, gamma_constant_name
546e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        ])
547e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(batch_norm_node, "T", dtypes.float32)
548ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    quantize_graph.set_attr_bool(batch_norm_node, "scale_after_normalization",
549ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                                 False)
550ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    quantize_graph.set_attr_float(batch_norm_node, "variance_epsilon", 0.001)
551ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([batch_norm_node])
552ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    test_graph(float_graph_def, {}, [batch_norm_name])
553ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
554ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  def test_max_pool(self):
555ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    input_constant_name = "input_constant"
556ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    max_pool_name = "max_pool"
557e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    float_graph_def = graph_pb2.GraphDef()
558e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    input_constant = quantize_graph.create_constant_node(
559e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        input_constant_name,
560e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
561e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        dtype=dtypes.float32,
562e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        shape=[1, 2, 6, 1])
563ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([input_constant])
564ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    max_pool_node = quantize_graph.create_node("MaxPool", max_pool_name,
565ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                                               [input_constant_name])
566ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    quantize_graph.set_attr_int_list(max_pool_node, "ksize", [1, 2, 2, 1])
567ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    quantize_graph.set_attr_int_list(max_pool_node, "strides", [1, 1, 1, 1])
568ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    quantize_graph.set_attr_string(max_pool_node, "padding", b"SAME")
569ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([max_pool_node])
570ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    test_graph(float_graph_def, {}, [max_pool_name])
571ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
572ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  def test_avg_pool(self):
573ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    input_constant_name = "input_constant"
574ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    avg_pool_name = "avg_pool"
575e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    float_graph_def = graph_pb2.GraphDef()
576e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    input_constant = quantize_graph.create_constant_node(
577e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        input_constant_name,
578e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
579e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        dtype=dtypes.float32,
580e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        shape=[1, 2, 6, 1])
581ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([input_constant])
582ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    avg_pool_node = quantize_graph.create_node("AvgPool", avg_pool_name,
583ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                                               [input_constant_name])
584e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(avg_pool_node, "T", dtypes.float32)
585ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    quantize_graph.set_attr_int_list(avg_pool_node, "ksize", [1, 2, 2, 1])
586ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    quantize_graph.set_attr_int_list(avg_pool_node, "strides", [1, 1, 1, 1])
587ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    quantize_graph.set_attr_string(avg_pool_node, "padding", b"SAME")
588ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([avg_pool_node])
589ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    test_graph(float_graph_def, {}, [avg_pool_name])
590ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
591ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  def test_relu(self):
592ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    input_constant_name = "input_constant"
593ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    relu_name = "relu"
594e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    float_graph_def = graph_pb2.GraphDef()
595e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    input_constant = quantize_graph.create_constant_node(
596e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        input_constant_name,
597e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
598e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        dtype=dtypes.float32,
599e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        shape=[1, 2, 6, 1])
600ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([input_constant])
601ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    relu_node = quantize_graph.create_node("Relu", relu_name,
602ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                                           [input_constant_name])
603e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(relu_node, "T", dtypes.float32)
604ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([relu_node])
605ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    test_graph(float_graph_def, {}, [relu_name])
606ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
60795f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower  def test_relu_w_fake_quant_w_min_max_vars(self):
60895f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    input_node = quantize_graph.create_constant_node(
609e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "input",
610e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
611e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        dtype=dtypes.float32,
612e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        shape=[1, 2, 6, 1])
613e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    relu_node = quantize_graph.create_node("Relu", "relu", [input_node.name])
614e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(relu_node, "T", dtypes.float32)
61595f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower
61695f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    min_node = quantize_graph.create_constant_node(
617e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "min_bias_add", value=0, dtype=dtypes.float32, shape=[])
61895f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    max_node = quantize_graph.create_constant_node(
619e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "max_bias_add", value=12, dtype=dtypes.float32, shape=[])
62095f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    fake_quant_node = quantize_graph.create_node(
62195f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower        "FakeQuantWithMinMaxVars", "fake_quant",
62295f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower        [relu_node.name, min_node.name, max_node.name])
62395f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower
624e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    float_graph_def = graph_pb2.GraphDef()
625e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    float_graph_def.node.extend(
626e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        [input_node, relu_node, min_node, max_node, fake_quant_node])
62795f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    test_graph(float_graph_def, {}, [fake_quant_node.name], log_graph=True)
62895f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower
62995f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    # Verify there is only one Quantize and one Requantize op.
630e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    eightbit_rewriter = quantize_graph.GraphRewriter(
631e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        float_graph_def, "eightbit", quantized_input_range=None)
63295f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name])
63395f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower
63495f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    ops = [node.op for node in eightbit_graph_def.node]
63595f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    # No quantize since all inputs are const and can be quantized up-front.
63695f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
63795f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower
63895f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    # One dequantize at the end.
63995f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    self.assertEqual(1, ops.count("Dequantize"))
64095f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower
641ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  def test_relu6(self):
642ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    input_constant_name = "input_constant"
643ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    relu6_name = "relu6"
644e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    float_graph_def = graph_pb2.GraphDef()
645e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    input_constant = quantize_graph.create_constant_node(
646e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        input_constant_name,
647e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
648e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        dtype=dtypes.float32,
649e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        shape=[1, 2, 6, 1])
650ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([input_constant])
651ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    relu6_node = quantize_graph.create_node("Relu6", relu6_name,
652ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden                                            [input_constant_name])
653e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(relu6_node, "T", dtypes.float32)
654ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([relu6_node])
655ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    test_graph(float_graph_def, {}, [relu6_name])
656ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
657ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  def test_bias_add(self):
658ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    input_constant_name = "input_constant"
659ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    offset_constant_name = "offset_constant"
660ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    bias_add_name = "bias_add"
661e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    float_graph_def = graph_pb2.GraphDef()
662e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    input_constant = quantize_graph.create_constant_node(
663e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        input_constant_name,
664e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
665e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        dtype=dtypes.float32,
666e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        shape=[1, 1, 2, 6])
667ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([input_constant])
668e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    offset_constant = quantize_graph.create_constant_node(
669e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        offset_constant_name,
670e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        value=[1, 2, 3, 4, 5, 6],
671e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        dtype=dtypes.float32,
672e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        shape=[6])
673ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([offset_constant])
674e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    bias_add_node = quantize_graph.create_node(
675e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "BiasAdd", bias_add_name, [input_constant_name, offset_constant_name])
676e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(bias_add_node, "T", dtypes.float32)
677ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    float_graph_def.node.extend([bias_add_node])
678ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    test_graph(float_graph_def, {}, [bias_add_name])
679ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
6804b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower  def test_quantized_input_range_errors(self):
6814b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    with self.assertRaises(ValueError):
6824b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower      # Invalid mode.
683e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney      quantize_graph.GraphRewriter(graph_pb2.GraphDef(), "weights_rounded",
684e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney                                   [0, 1])
6854b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    with self.assertRaises(ValueError):
6864b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower      # Invalid range.
687e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney      quantize_graph.GraphRewriter(graph_pb2.GraphDef(), "eightbit", [0, -1])
6884b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower
6894b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower  def test_quantized_input_range_bias_add(self):
6904b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    input_shape = [1, 1, 2, 6]
69153cb26d05a5c2080d8022124178b1cc43a30ffe5A. Unique TensorFlower    input_n = quantize_graph.create_node("Placeholder", "input", [])
692e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(input_n, "dtype", dtypes.float32)
6934b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    quantize_graph.set_attr_shape(input_n, "shape", input_shape)
694e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    offset_n = quantize_graph.create_constant_node(
695e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "offset", value=[1, 2, 3, 4, 5, 6], dtype=dtypes.float32, shape=[6])
6964b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    bias_add_n = quantize_graph.create_node("BiasAdd", "bias_add",
6974b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower                                            [input_n.name, offset_n.name])
698e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(bias_add_n, "T", dtypes.float32)
6994b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower
700e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    float_graph_def = graph_pb2.GraphDef()
7014b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    float_graph_def.node.extend([input_n, offset_n, bias_add_n])
7024b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower
703e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    input_map = {
704e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        input_n.name + ":0":
705e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney            np.reshape([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], input_shape)
706e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    }
707e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    self._RunTestsForQuantizedInputRange(float_graph_def, input_map,
708e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney                                         [bias_add_n.name], [-1, 20.])
709e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    self._RunTestsForQuantizedInputRange(float_graph_def, input_map,
710e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney                                         [bias_add_n.name], [0, 12.])
7114b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower
7124b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower  def test_quantized_input_range_mat_mul(self):
7134b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    shapes = [[3, 2], [2, 4]]
7144b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    inputs = []
7154b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    for i, shape in enumerate(shapes):
71653cb26d05a5c2080d8022124178b1cc43a30ffe5A. Unique TensorFlower      node = quantize_graph.create_node("Placeholder", "input_%s" % i, [])
717e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney      quantize_graph.set_attr_dtype(node, "dtype", dtypes.float32)
7184b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower      quantize_graph.set_attr_shape(node, "shape", shape)
7194b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower      inputs.append(node)
7204b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    mat_mul_node = quantize_graph.create_node("MatMul", "mat_mul",
7214b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower                                              [n.name for n in inputs])
722e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(mat_mul_node, "T", dtypes.float32)
7234b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower
724e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    float_graph_def = graph_pb2.GraphDef()
7254b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    float_graph_def.node.extend(inputs + [mat_mul_node])
7264b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower
727e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    input_map = {
728e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        inputs[0].name + ":0":
729e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney            np.reshape([1, 2, 3, 4, 5, 6], shapes[0]),
730e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        inputs[1].name + ":0":
731e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney            np.reshape([.8, .7, .6, .5, .4, .3, .2, .1], shapes[1])
732e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    }
733e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    self._RunTestsForQuantizedInputRange(float_graph_def, input_map,
734e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney                                         [mat_mul_node.name], [-1, 20.])
735e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    self._RunTestsForQuantizedInputRange(float_graph_def, input_map,
736e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney                                         [mat_mul_node.name], [0, 6.])
7374b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower
7384b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower  def _RunTestsForQuantizedInputRange(self, float_graph_def, input_map,
7394b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower                                      output_names, input_range):
7404b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    if sys.version_info[0] == 3:
7414b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower      # uint8->quint8 conversion for numpy is not working currently.
7424b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower      return
7434b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower
7444b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    quantized_input_map = {}
7454b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    for k, v in input_map.items():
7464b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower      arr = [
747e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney          int(
748e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney              round((n - input_range[0]) * 255 / (input_range[1] - input_range[
749e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney                  0]))) for n in v.flat
750e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney      ]
7514b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower      arr = np.array(arr, np.uint8)
7524b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower      arr = arr.reshape(v.shape)
753e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney      arr = arr.astype(dtypes.quint8.as_numpy_dtype)
7544b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower      quantized_input_map[k] = arr
7554b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    output_tensors = [output_name + ":0" for output_name in output_names]
7564b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    float_results = run_graph_def(float_graph_def, input_map, output_tensors)
7574b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower
7584b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    # Quantize treating the input as quantized in range <input_range>.
7594b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    rewriter = quantize_graph.GraphRewriter(float_graph_def, "eightbit",
7604b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower                                            input_range)
7614b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    graph_def = rewriter.rewrite(output_names)
7624b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    results = run_graph_def(graph_def, quantized_input_map, output_tensors)
7634b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    for expected, result in zip(float_results, results):
7644b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower      assert are_tensors_near(expected, result, .5)
7654b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    ops = [node.op for node in graph_def.node]
7664b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
7674b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    self.assertEqual(len(output_names), ops.count("Dequantize"))
7684b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower
7694b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    # Quantize without treating input as quantized.
770e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    rewriter = quantize_graph.GraphRewriter(
771e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        float_graph_def, "eightbit", quantized_input_range=None)
7724b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    graph_def = rewriter.rewrite(output_names)
7734b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    results = run_graph_def(graph_def, input_map, output_tensors)
7744b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    for expected, result in zip(float_results, results):
7754b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower      assert are_tensors_near(expected, result, .5)
7764b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    ops = [node.op for node in graph_def.node]
777e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    self.assertEqual(
778e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        len(input_map), ops.count("QuantizeV2") + ops.count("Quantize"))
7794b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower    self.assertEqual(len(output_names), ops.count("Dequantize"))
7804b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower
78195f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower  def test_bias_add_w_fake_quant_w_min_max_vars(self):
78295f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    input_node = quantize_graph.create_constant_node(
783e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "input",
784e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
785e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        dtype=dtypes.float32,
786e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        shape=[1, 1, 2, 5])
78795f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    offset_node = quantize_graph.create_constant_node(
788e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "offset", value=[1, 2, 3, 4, 5], dtype=dtypes.float32, shape=[5])
78995f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    bias_add_node = quantize_graph.create_node(
79095f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower        "BiasAdd", "bias_add", [input_node.name, offset_node.name])
791e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(bias_add_node, "T", dtypes.float32)
79295f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower
79395f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    min_node = quantize_graph.create_constant_node(
794e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "min_bias_add", value=-.5, dtype=dtypes.float32, shape=[])
79595f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    max_node = quantize_graph.create_constant_node(
796e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "max_bias_add", value=15.5, dtype=dtypes.float32, shape=[])
79795f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    fake_quant_node = quantize_graph.create_node(
79895f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower        "FakeQuantWithMinMaxVars", "fake_quant",
79995f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower        [bias_add_node.name, min_node.name, max_node.name])
80095f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower
801e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    float_graph_def = graph_pb2.GraphDef()
802e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    float_graph_def.node.extend([
803e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        input_node, offset_node, bias_add_node, min_node, max_node,
804e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        fake_quant_node
805e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    ])
80695f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    test_graph(float_graph_def, {}, [fake_quant_node.name], log_graph=True)
80795f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower
80895f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    # Verify there is only one Quantize and one Requantize op.
80942e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    # Pass in fallback_quantization_range, although it will have no effect
81042e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    # because the FakeQuantWithMinMaxVars are used instead.
81142e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    eightbit_rewriter = quantize_graph.GraphRewriter(
812e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        float_graph_def,
813e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "eightbit",
814e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        quantized_input_range=None,
81542e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower        fallback_quantization_range=[-100, 100])
81695f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name])
81795f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower
81895f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    ops = [node.op for node in eightbit_graph_def.node]
81942e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    node_names = [node.name for node in eightbit_graph_def.node]
82095f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    # No quantize since all inputs are const and can be quantized up-front.
82195f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
82295f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower
82395f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    # One dequantize at the end.
82495f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower    self.assertEqual(1, ops.count("Dequantize"))
82595f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower
82642e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    # The fallback constants are not in the graph.
82742e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    self.assertEqual(0, node_names.count("fallback_quantization_min_value"))
82842e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    self.assertEqual(0, node_names.count("fallback_quantization_max_value"))
82942e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower
83042e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower  def test_bias_add_w_fallback_min_max_vars(self):
83142e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    input_node = quantize_graph.create_constant_node(
832e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "input",
833e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
834e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        dtype=dtypes.float32,
835e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        shape=[1, 1, 2, 5])
83642e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    offset_node = quantize_graph.create_constant_node(
837e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "offset", value=[1, 2, 3, 4, 5], dtype=dtypes.float32, shape=[5])
83842e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    bias_add_node = quantize_graph.create_node(
83942e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower        "BiasAdd", "bias_add", [input_node.name, offset_node.name])
840e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(bias_add_node, "T", dtypes.float32)
84142e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower
842e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    float_graph_def = graph_pb2.GraphDef()
84342e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    float_graph_def.node.extend([input_node, offset_node, bias_add_node])
84442e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    test_graph(float_graph_def, {}, [bias_add_node.name], log_graph=True)
84542e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower
84642e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    # Verify there is only one Quantize, one Requantize op, and no
84742e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    # RequantizationRange op.
84842e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    eightbit_rewriter = quantize_graph.GraphRewriter(
849e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        float_graph_def,
850e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "eightbit",
851e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        quantized_input_range=None,
85242e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower        fallback_quantization_range=[-.5, 15.5])
85342e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    eightbit_graph_def = eightbit_rewriter.rewrite([bias_add_node.name])
85442e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower
85542e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    ops = [node.op for node in eightbit_graph_def.node]
85642e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    node_names = [node.name for node in eightbit_graph_def.node]
85742e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    # No quantize since all inputs are const and can be quantized up-front.
85842e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
85942e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower
86042e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    # One dequantize at the end.
86142e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    self.assertEqual(1, ops.count("Dequantize"))
86242e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower
86342e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    # No RequantizationRange
86442e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    self.assertEqual(0, ops.count("RequantizationRange"))
86542e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower
86642e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    # The fallback constants are in the graph.
86742e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    self.assertEqual(1, node_names.count("fallback_quantization_min_value"))
86842e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower    self.assertEqual(1, node_names.count("fallback_quantization_max_value"))
86942e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower
870ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden  def test_remove_redundant_quantization(self):
871ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    a_constant_name = "a_constant"
872ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    a_constant_min_name = "a_constant_min"
873ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    a_constant_max_name = "a_constant_max"
874ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    a_dequantize_name = "a_dequantize"
875ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    a_quantize_name = "a_quantize"
876ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    b_constant_name = "b_constant"
877ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    b_constant_min_name = "b_constant_min"
878ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    b_constant_max_name = "b_constant_max"
879ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    b_dequantize_name = "b_dequantize"
880ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    b_quantize_name = "b_quantize"
881ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    mat_mul_name = "mat_mul"
882e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    graph_def = graph_pb2.GraphDef()
883e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    a_constant = quantize_graph.create_constant_node(
884e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        a_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
885ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    graph_def.node.extend([a_constant])
886e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    a_constant_min = quantize_graph.create_constant_node(
887e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        a_constant_min_name, value=2, dtype=dtypes.float32, shape=[])
888ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    graph_def.node.extend([a_constant_min])
889e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    a_constant_max = quantize_graph.create_constant_node(
890e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        a_constant_max_name, value=2, dtype=dtypes.float32, shape=[])
891ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    graph_def.node.extend([a_constant_max])
892e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    a_dequantize_node = quantize_graph.create_node(
893e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "Dequantize", a_dequantize_name,
894e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        [a_constant_name, a_constant_min_name, a_constant_max_name])
895e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(a_dequantize_node, "T", dtypes.uint8)
896ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    graph_def.node.extend([a_dequantize_node])
897e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    a_quantize_node = quantize_graph.create_node(
898e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "QuantizeV2", a_quantize_name,
899e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        [a_dequantize_name, a_dequantize_name + ":1", a_dequantize_name + ":2"])
900e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(a_quantize_node, "T", dtypes.uint8)
901ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    graph_def.node.extend([a_quantize_node])
902e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    b_constant = quantize_graph.create_constant_node(
903e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        b_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
904ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    graph_def.node.extend([b_constant])
905e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    b_constant_min = quantize_graph.create_constant_node(
906e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        b_constant_min_name, value=3, dtype=dtypes.float32, shape=[])
907ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    graph_def.node.extend([b_constant_min])
908e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    b_constant_max = quantize_graph.create_constant_node(
909e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        b_constant_max_name, value=3, dtype=dtypes.float32, shape=[])
910ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    graph_def.node.extend([b_constant_max])
911e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    b_dequantize_node = quantize_graph.create_node(
912e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "Dequantize", b_dequantize_name,
913e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        [b_constant_name, b_constant_min_name, b_constant_max_name])
914e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(b_dequantize_node, "T", dtypes.uint8)
915ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    graph_def.node.extend([b_dequantize_node])
916e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    b_quantize_node = quantize_graph.create_node(
917e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        "QuantizeV2", b_quantize_name,
918e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        [b_dequantize_name, b_dequantize_name + ":1", b_dequantize_name + ":2"])
919e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(b_quantize_node, "T", dtypes.uint8)
920ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    graph_def.node.extend([b_quantize_node])
921e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    mat_mul_node = quantize_graph.create_node("QuantizedMatMul", mat_mul_name, [
922e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        a_quantize_name, b_quantize_name, a_quantize_name + ":1",
923e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        a_quantize_name + ":2", b_quantize_name + ":1", b_quantize_name + ":2"
924e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    ])
925e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(mat_mul_node, "T1", dtypes.uint8)
926e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(mat_mul_node, "T2", dtypes.int32)
927ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    graph_def.node.extend([mat_mul_node])
928ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
929e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    expected_output = graph_pb2.GraphDef()
930e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    a_constant = quantize_graph.create_constant_node(
931e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        a_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
932ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    expected_output.node.extend([a_constant])
933e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    a_constant_min = quantize_graph.create_constant_node(
934e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        a_constant_min_name, value=2, dtype=dtypes.float32, shape=[])
935ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    expected_output.node.extend([a_constant_min])
936e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    a_constant_max = quantize_graph.create_constant_node(
937e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        a_constant_max_name, value=2, dtype=dtypes.float32, shape=[])
938ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    expected_output.node.extend([a_constant_max])
939e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    b_constant = quantize_graph.create_constant_node(
940e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        b_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
941ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    expected_output.node.extend([b_constant])
942e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    b_constant_min = quantize_graph.create_constant_node(
943e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        b_constant_min_name, value=3, dtype=dtypes.float32, shape=[])
944ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    expected_output.node.extend([b_constant_min])
945e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    b_constant_max = quantize_graph.create_constant_node(
946e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        b_constant_max_name, value=3, dtype=dtypes.float32, shape=[])
947ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    expected_output.node.extend([b_constant_max])
948e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    mat_mul_node = quantize_graph.create_node("QuantizedMatMul", mat_mul_name, [
949e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        a_constant_name, b_constant_name, a_constant_min_name,
950e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        a_constant_max_name, b_constant_min_name, b_constant_max_name
951e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    ])
952e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(mat_mul_node, "T1", dtypes.uint8)
953e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    quantize_graph.set_attr_dtype(mat_mul_node, "T2", dtypes.int32)
954ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    expected_output.node.extend([mat_mul_node])
955fe4ec328ac6210381cc6a0c9d98b35d35f8d9209A. Unique TensorFlower    expected_output.versions.CopyFrom(graph_def.versions)
956fe4ec328ac6210381cc6a0c9d98b35d35f8d9209A. Unique TensorFlower    expected_output.library.CopyFrom(graph_def.library)
957ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
958e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney    rewriter = quantize_graph.GraphRewriter(
959e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney        graph_def, [mat_mul_name], quantized_input_range=None)
960ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    output = rewriter.remove_redundant_quantization(graph_def)
961ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    stripped_output = graph_util.extract_sub_graph(output, [mat_mul_name])
962ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden    self.assertProtoEquals(expected_output, stripped_output)
963ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
964ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden
965ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardenif __name__ == "__main__":
966e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney  test.main()
967