1122cdce33e3e0a01a7f82645617317530aa571fbA. Unique TensorFlower# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden# 3ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden# Licensed under the Apache License, Version 2.0 (the "License"); 4ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden# you may not use this file except in compliance with the License. 5ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden# You may obtain a copy of the License at 6ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden# 7ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden# http://www.apache.org/licenses/LICENSE-2.0 8ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden# 9ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden# Unless required by applicable law or agreed to in writing, software 10ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden# distributed under the License is distributed on an "AS IS" BASIS, 11ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden# See the License for the specific language governing permissions and 13ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden# limitations under the License. 14ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden# ============================================================================== 15ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden"""Tests the graph quantization script. 16ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 17ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden""" 1836d561c253844323fb1de28303407b451b7886a7Geoffrey Irving 1936d561c253844323fb1de28303407b451b7886a7Geoffrey Irvingfrom __future__ import absolute_import 2036d561c253844323fb1de28303407b451b7886a7Geoffrey Irvingfrom __future__ import division 2136d561c253844323fb1de28303407b451b7886a7Geoffrey Irvingfrom __future__ import print_function 2236d561c253844323fb1de28303407b451b7886a7Geoffrey Irving 234b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlowerimport sys 2466024fd508748d706b72d0ae5e8b07f917e78458Andrew Harpimport numpy as np 2566024fd508748d706b72d0ae5e8b07f917e78458Andrew Harp 26e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.core.framework import graph_pb2 27e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.client import session 28e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.framework import dtypes 29a558c6e3b38846727873b5afbbc3ba309ae5dff5Olivia Nordquistfrom tensorflow.python.framework import graph_util 30e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.framework import importer 31e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.framework import ops as ops_lib 32e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.platform import flags as flags_lib 33e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.platform import test 34e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.platform import tf_logging 3566024fd508748d706b72d0ae5e8b07f917e78458Andrew Harpfrom tensorflow.tools.quantization import quantize_graph 36ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 37e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyflags = flags_lib 38ca4e053aa52ab9a42467d4df814ca9272487dbdfPete WardenFLAGS = flags.FLAGS 39ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 40ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 41ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardendef run_graph_def(graph_def, input_map, outputs): 42e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney graph = ops_lib.Graph() 43ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden with graph.as_default(): 44e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney importer.import_graph_def(graph_def, input_map={}, name="") 45e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney with session.Session(graph=graph) as sess: 46ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden results = sess.run(outputs, feed_dict=input_map) 47ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden return results 48ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 49ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 50ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardendef test_mat_mul(m, n, k, a, b): 51ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden """Tests a MatMul replacement.""" 52ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden a_constant_name = "a_constant" 53ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden b_constant_name = "b_constant" 54ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden mat_mul_name = "mat_mul" 55ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 56e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def = graph_pb2.GraphDef() 57e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_constant = quantize_graph.create_constant_node( 58e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_constant_name, value=a, dtype=dtypes.float32, shape=[m, k]) 59ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([a_constant]) 60e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney b_constant = quantize_graph.create_constant_node( 61e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney b_constant_name, value=b, dtype=dtypes.float32, shape=[k, n]) 62ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([b_constant]) 63ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden mat_mul_node = quantize_graph.create_node("MatMul", mat_mul_name, 64ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden [a_constant_name, b_constant_name]) 65e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(mat_mul_node, "T", dtypes.float32) 66ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden quantize_graph.set_attr_bool(mat_mul_node, "transpose_a", False) 67ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden quantize_graph.set_attr_bool(mat_mul_node, "transpose_b", False) 68ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([mat_mul_node]) 69ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 70ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden test_graph(float_graph_def, {}, [mat_mul_name]) 71ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 72ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 73ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardendef test_conv(depth, image_width, image_height, image_batch_count, filter_size, 74ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden filter_count, stride, padding, input_values, filter_values): 75ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden """Tests a Conv replacement.""" 76ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden input_constant_name = "input_constant" 77ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden filter_constant_name = "filter_constant" 78ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden conv_name = "conv" 79ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 80e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def = graph_pb2.GraphDef() 81ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden input_constant = quantize_graph.create_constant_node( 82ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden input_constant_name, 83ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden value=input_values, 84e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney dtype=dtypes.float32, 85e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape=[image_batch_count, image_height, image_width, depth]) 86ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([input_constant]) 87ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden filter_constant = quantize_graph.create_constant_node( 88ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden filter_constant_name, 89ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden value=filter_values, 90e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney dtype=dtypes.float32, 91e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape=[filter_size, filter_size, depth, filter_count]) 92ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([filter_constant]) 93e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney conv_node = quantize_graph.create_node( 94e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "Conv2D", conv_name, [input_constant_name, filter_constant_name]) 95e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(conv_node, "T", dtypes.float32) 96ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden quantize_graph.set_attr_int_list(conv_node, "strides", [1, stride, stride, 1]) 97ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden quantize_graph.set_attr_string(conv_node, "padding", padding) 98ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([conv_node]) 99ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 100ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden test_graph(float_graph_def, {}, [conv_name]) 101ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 102ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 103ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardendef are_tensors_near(a, b, tolerance): 104ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden """Tests whether two tensors are nearly identical. 105ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 106ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden This is a specialized comparison function designed to help debug problems with 107ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden quantization. It prints out information about the differences between tensors 108ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden on failure, paying special attention to possible biases by looking at the mean 109ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden and absolute average errors. 110ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 111ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden Args: 112ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden a: First comparison tensor. 113ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden b: Second comparison tensor. 114ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden tolerance: Float value indicating how large an error between values is ok. 115ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 116ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden Returns: 117ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden Boolean indicating whether the two inputs were close enough. 118ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden """ 119ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden flat_a = a.flatten() 120ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden flat_b = b.flatten() 121ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden if len(flat_a) != len(flat_b): 122e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney print("Tensors are different sizes: " + str(len(flat_a)) + " vs " + str( 123e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney len(flat_b))) 124ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden return False 125ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden value_count = len(flat_a) 126ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden how_many_different = 0 127ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden total_difference = 0 128ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden total_abs_difference = 0 129ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden for index in range(value_count): 130ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden a_value = flat_a[index] 131ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden b_value = flat_b[index] 132ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden difference = a_value - b_value 133ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden total_difference += difference 134ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden total_abs_difference += abs(difference) 135ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden if abs(difference) > tolerance: 136ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden how_many_different += 1 137ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden mean_difference = total_difference / value_count 138ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden mean_abs_difference = total_abs_difference / value_count 139ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden proportion_different = (how_many_different * 1.0) / value_count 140ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden if how_many_different == 0: 141ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden return True 142ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden else: 143ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden print("Tensors have {0} different values ({1}%), with mean difference" 14457915d504f1910afb43052dd337a35a1becffec7A. Unique TensorFlower " {2} and mean absolute difference {3}".format( 145ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden how_many_different, proportion_different * 100, mean_difference, 14657915d504f1910afb43052dd337a35a1becffec7A. Unique TensorFlower mean_abs_difference)) 147ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden return False 148ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 149ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 150ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardendef get_top_value(input_values): 151ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden max_value = None 152ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden max_index = None 153ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden for index, value in enumerate(input_values.flatten()): 154ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden if max_value is None or value > max: 155ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden max_value = value 156ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden max_index = index 157ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden return max_index, max_value 158ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 159ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 16095f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlowerdef test_graph(float_graph_def, input_map, output_names, log_graph=False): 161ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden """Runs the float graph through the rewriter and tests the results.""" 162e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_results = run_graph_def( 163e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def, input_map, 164e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [output_name + ":0" for output_name in output_names]) 165ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden # TODO(petewarden): round test is currently failing because there is no 166ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden # RoundToSteps op available. 167ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden # round_rewriter = quantize_graph.GraphRewriter(float_graph_def, "round") 168ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden # round_graph_def = round_rewriter.rewrite(output_name) 169ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden # round_results = run_graph_def(round_graph_def, input_map, 170ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden # [output_name + ":0"]) 171ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden # assert are_tensors_near(expected, round_results[0], 1.0) 172ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden # 173ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden # TODO(petewarden): Add test for "quantize" mode. 174ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 175e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney eightbit_rewriter = quantize_graph.GraphRewriter( 176e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def, "eightbit", quantized_input_range=None) 177ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden eightbit_graph_def = eightbit_rewriter.rewrite(output_names) 178e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney eightbit_results = run_graph_def( 179e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney eightbit_graph_def, input_map, 180e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [output_name + ":0" for output_name in output_names]) 181ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden for expected, result in zip(float_results, eightbit_results): 182ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden assert are_tensors_near(expected, result, 1.0) 183ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 18495f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower if log_graph: 185e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney tf_logging.info("8bit:\n%s", str(eightbit_graph_def)) 18695f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower 187ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden # Test the weights_rounded mode. This uses the default bit_depth. 188ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden weights_rounded_rewriter = quantize_graph.GraphRewriter( 1894b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower float_graph_def, "weights_rounded", quantized_input_range=None) 190ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden weights_rounded_graph_def = weights_rounded_rewriter.rewrite(output_names) 1914b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower weights_rounded_results = run_graph_def( 1924b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower weights_rounded_graph_def, input_map, 1934b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower [output_name + ":0" for output_name in output_names]) 194ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden for expected, result in zip(float_results, weights_rounded_results): 195ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden assert are_tensors_near(expected, result, 1.0) 196ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 197ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 198e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyclass QuantizeGraphTest(test.TestCase): 199ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 200e24388242026245244435235ea66fd3693942c67A. Unique TensorFlower def test_negative_const_problem(self): 201e24388242026245244435235ea66fd3693942c67A. Unique TensorFlower shape_constant_name = "shape_constant" 202e24388242026245244435235ea66fd3693942c67A. Unique TensorFlower shape_constant = quantize_graph.create_constant_node( 203e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape_constant_name, value=-0.8, dtype=dtypes.float32, shape=[1]) 204e24388242026245244435235ea66fd3693942c67A. Unique TensorFlower quantization_result = quantize_graph.quantize_weight_eightbit( 205e24388242026245244435235ea66fd3693942c67A. Unique TensorFlower shape_constant, b"MIN_COMBINED") 206e24388242026245244435235ea66fd3693942c67A. Unique TensorFlower self.assertEqual(4, len(quantization_result)) 207e24388242026245244435235ea66fd3693942c67A. Unique TensorFlower 208ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden def test_odd_padding_problem(self): 209ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden """Tests one error case we ran into in a real graph.""" 210ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden test_conv(1, 4, 4, 1, 3, 1, 2, b"SAME", 211ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], 212ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden [1, 2, 3, 4, 5, 6, 7, 8, 9]) 213ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 214ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden def test_mat_mul_tiny(self): 215ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden # These tests are added to test the generate case where 216ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden # min(matrix) == max(matrix), which used to cause problems. 217ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden test_mat_mul(1, 1, 1, [2], [3]) 218ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden test_mat_mul(1, 2, 1, [1], [2, 3]) 219ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden test_mat_mul(1, 1, 2, [1, 1], [1, 1]) 220ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden test_mat_mul(1, 1, 2, [0, 0], [1, 1]) 221ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden # The general case. 222ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden test_mat_mul(1, 1, 2, [1, 2], [1, 2]) 223ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 224ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden def test_mat_mul_small(self): 225ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden test_mat_mul(2, 4, 3, [1, 2, 3, 4, 5, 6], 226ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden [7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]) 227ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 228ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden def test_conv(self): 229ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden test_conv(1, 4, 3, 1, 3, 1, 1, b"SAME", 230ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 231ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden [1, 4, 7, 2, 5, 8, 3, 6, 9]) 232ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 233f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower def test_reshape(self): 234f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower """Tests that MatMul->Reshape->MatMul avoids extra quantize/dequantize.""" 235e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney 236f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower def make_matmul(name, a, b): 237f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower n = quantize_graph.create_node("MatMul", name, [a.name, b.name]) 238e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(n, "T", dtypes.float32) 239f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower quantize_graph.set_attr_bool(n, "transpose_a", False) 240f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower quantize_graph.set_attr_bool(n, "transpose_b", False) 241f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower return n 242f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower 243f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower # matmul_1 = input*weight_1 244f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower input_node = quantize_graph.create_constant_node( 245e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "input", value=[0, 1, 2, 3], dtype=dtypes.float32, shape=[4, 1]) 246f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower weight_1_node = quantize_graph.create_constant_node( 247e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "weight_1", 248e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney value=[.5, .6, .7, .8, .9], 249e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney dtype=dtypes.float32, 250e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape=[1, 5]) 251f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower matmul_1_node = make_matmul("matmul_1", input_node, weight_1_node) 252f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower 253f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower # Reshape 4x5 to 10x2. 254f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower new_shape_node = quantize_graph.create_constant_node( 255e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "new_shape_node", value=[10, 2], dtype=dtypes.int32, shape=[2]) 256f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower reshape_node = quantize_graph.create_node( 257f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower "Reshape", "reshape", [matmul_1_node.name, new_shape_node.name]) 258e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(reshape_node, "T", dtypes.float32) 259f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower 260f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower # matmul_2_node = reshape*weight_2 261f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower weight_2_node = quantize_graph.create_constant_node( 262e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "weight_2", value=[1.5, 2.5], dtype=dtypes.float32, shape=[2, 1]) 263f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower matmul_2_node = make_matmul("matmul_2", reshape_node, weight_2_node) 264f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower 265e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney g = graph_pb2.GraphDef() 266e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney g.node.extend([ 267e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney input_node, weight_1_node, matmul_1_node, new_shape_node, reshape_node, 268e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney weight_2_node, matmul_2_node 269e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney ]) 270f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower 271f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower # Test the graph 272f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower test_graph(g, {}, ["matmul_2"]) 273f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower 274f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower # Verify there is only one Quantize and one Requantize op. 275e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney eightbit_rewriter = quantize_graph.GraphRewriter( 276e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney g, "eightbit", quantized_input_range=None) 277f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower eightbit_graph_def = eightbit_rewriter.rewrite(["matmul_2"]) 278f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower 279f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower ops = [node.op for node in eightbit_graph_def.node] 280f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower # No quantize since all inputs are const and can be quantized up-front. 281f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize")) 282662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower self.assertEqual(1, ops.count("QuantizedReshape")) 283f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower 284f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower # One dequantize at the end. 285f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower self.assertEqual(1, ops.count("Dequantize")) 286f937100ce786aa5d4154eac7de3dd2db43f1e888A. Unique TensorFlower 287ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden def test_quantize_array(self): 288ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden # Test invalid parameters (empty array, or 0 buckets. 289e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney self.assertRaises(ValueError, quantize_graph.quantize_array, np.array([]), 290e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney 2) 291ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden self.assertRaises(ValueError, quantize_graph.quantize_array, 292ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden np.array([1, 2]), 0) 293ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden # Test input array of length 1. 294ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden arr = np.array([1]) 295ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden qarr = quantize_graph.quantize_array(arr, 1) 296ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden self.assertEqual(arr, qarr) 297ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden qarr = quantize_graph.quantize_array(arr, 2) 298ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden self.assertEqual(arr, qarr) 299ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden # Test input array with all elements equal. 300ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden arr = np.array([1, 1, 1]) 301ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden qarr = quantize_graph.quantize_array(arr, 10) 302ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden self.assertTrue((np.array([1, 1, 1]) == qarr).all()) 303ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden # Test "normal" input arrays. 304ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden arr = np.array([0, 0.3, 0.6, 1]) 305ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden qarr = quantize_graph.quantize_array(arr, 1) 306ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden self.assertTrue((np.array([0.5, 0.5, 0.5, 0.5]) == qarr).all()) 307ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden qarr = quantize_graph.quantize_array(arr, 2) 308ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden self.assertTrue((np.array([0.25, 0.25, 0.75, 0.75]) == qarr).all()) 309ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden qarr = quantize_graph.quantize_array(arr.reshape((2, 2)), 2) 310ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden self.assertTrue((np.array([[0.25, 0.25], [0.75, 0.75]]) == qarr).all()) 311ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 312662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower def test_non_float_concat(self): 313662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower concat_dim = quantize_graph.create_constant_node( 314e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "concat_dim", value=0, dtype=dtypes.int32, shape=[]) 315662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower a = quantize_graph.create_constant_node( 316e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "a", 317e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 318e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney dtype=dtypes.int32, 319e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape=[2, 2, 3]) 320662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower b = quantize_graph.create_constant_node( 321e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "b", 322e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney value=[13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], 323e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney dtype=dtypes.int32, 324e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape=[2, 2, 3]) 325e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney concat = quantize_graph.create_node("Concat", "concat", 326e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [concat_dim.name, a.name, b.name]) 327662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower quantize_graph.set_attr_int(concat, "N", 2) 328e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(concat, "T", dtypes.int32) 329662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower 330e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney g = graph_pb2.GraphDef() 331662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower g.node.extend([concat_dim, a, b, concat]) 332662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower test_graph(g, {}, [concat.name]) 333662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower 334662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower def test_non_float_reshape(self): 335662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower a = quantize_graph.create_constant_node( 336e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "a", 337e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 338e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney dtype=dtypes.int32, 339e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape=[2, 2, 3]) 340662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower shape = quantize_graph.create_constant_node( 341e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "shape", value=[12], dtype=dtypes.int32, shape=[1]) 342e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney reshape = quantize_graph.create_node("Reshape", "reshape", 343e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [a.name, shape.name]) 344e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(reshape, "T", dtypes.int32) 345662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower 346e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney g = graph_pb2.GraphDef() 347662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower g.node.extend([a, shape, reshape]) 348662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower test_graph(g, {}, [reshape.name]) 349662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower 350ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden def test_concat(self): 351ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden shape_constant_name = "shape_constant" 352ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden a_constant_name = "a_constant" 353ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden b_constant_name = "b_constant" 354ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden concat_name = "concat" 355ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 356e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def = graph_pb2.GraphDef() 357e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape_constant = quantize_graph.create_constant_node( 358e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape_constant_name, value=0, dtype=dtypes.int32, shape=[]) 359ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([shape_constant]) 360e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_constant = quantize_graph.create_constant_node( 361e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_constant_name, 362e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 363e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney dtype=dtypes.float32, 364e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape=[2, 2, 3]) 365ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([a_constant]) 366e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney b_constant = quantize_graph.create_constant_node( 367e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney b_constant_name, 368e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney value=[13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], 369e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney dtype=dtypes.float32, 370e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape=[2, 2, 3]) 371ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([b_constant]) 372e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney concat_node = quantize_graph.create_node( 373e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "Concat", concat_name, 374e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [shape_constant_name, a_constant_name, b_constant_name]) 375ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden quantize_graph.set_attr_int(concat_node, "N", 2) 376e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(concat_node, "T", dtypes.float32) 377ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([concat_node]) 378ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 379ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden test_graph(float_graph_def, {}, [concat_name]) 380ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 381662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower # Verify the concat is quantized. 382662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower eightbit_rewriter = quantize_graph.GraphRewriter( 383662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower float_graph_def, "eightbit", quantized_input_range=None) 384662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower eightbit_graph_def = eightbit_rewriter.rewrite([concat_name]) 385662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower 386662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower ops = [node.op for node in eightbit_graph_def.node] 387662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower self.assertEqual(1, ops.count("QuantizedConcat")) 388662533b85c66f198b779bea147397e1441f3e482A. Unique TensorFlower 389ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden def test_multiple_outputs(self): 390ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden input_constant_name = "input_constant" 391ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden split_constant_name = "split_constant" 392ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden split_name = "split" 393ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden concat_constant_name = "concat_constant" 394ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden concat_name = "concat" 395ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 396e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def = graph_pb2.GraphDef() 397e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney input_constant = quantize_graph.create_constant_node( 398e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney input_constant_name, 399e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 400e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney dtype=dtypes.float32, 401e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape=[2, 6]) 402ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([input_constant]) 403e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney split_constant = quantize_graph.create_constant_node( 404e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney split_constant_name, value=1, dtype=dtypes.int32, shape=[]) 405ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([split_constant]) 406e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney split_node = quantize_graph.create_node( 407e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "Split", split_name, [split_constant_name, input_constant_name]) 408ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden quantize_graph.set_attr_int(split_node, "num_split", 2) 409e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(split_node, "T", dtypes.float32) 410ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([split_node]) 411e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney concat_constant = quantize_graph.create_constant_node( 412e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney concat_constant_name, value=1, dtype=dtypes.int32, shape=[]) 413ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([concat_constant]) 414e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney concat_node = quantize_graph.create_node( 415e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "Concat", concat_name, 416e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [concat_constant_name, split_name + ":0", split_name + ":1"]) 417ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden quantize_graph.set_attr_int(concat_node, "N", 2) 418e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(concat_node, "T", dtypes.float32) 419ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([concat_node]) 420ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 421ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden test_graph(float_graph_def, {}, [concat_name]) 422ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 423ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden def test_node_name_from_input(self): 424ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden self.assertEqual("SomeName", 425ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden quantize_graph.node_name_from_input("^SomeName:2")) 426ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 427ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden def test_unique_node_name_from_input(self): 428ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden self.assertEqual("__hat__SomeName__port__2", 429ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden quantize_graph.unique_node_name_from_input("^SomeName:2")) 430ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 431ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden def test_identity(self): 432ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden input_constant_name = "input_constant" 433ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden identity_name = "identity" 434e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def = graph_pb2.GraphDef() 435e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney input_constant = quantize_graph.create_constant_node( 436e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney input_constant_name, 437e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 438e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney dtype=dtypes.float32, 439e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape=[2, 6]) 440ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([input_constant]) 441ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden identity_node = quantize_graph.create_node("Identity", identity_name, 442ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden [input_constant_name]) 443e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(identity_node, "T", dtypes.float32) 444ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([identity_node]) 4455328a426fe2d76dabd833e774711b2d56f13f9a8A. Unique TensorFlower 4465328a426fe2d76dabd833e774711b2d56f13f9a8A. Unique TensorFlower mul_name = "mul" 4475328a426fe2d76dabd833e774711b2d56f13f9a8A. Unique TensorFlower mul_node = quantize_graph.create_node("Mul", mul_name, 4485328a426fe2d76dabd833e774711b2d56f13f9a8A. Unique TensorFlower [identity_name, identity_name]) 449e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(mul_node, "T", dtypes.float32) 4505328a426fe2d76dabd833e774711b2d56f13f9a8A. Unique TensorFlower float_graph_def.node.extend([mul_node]) 4515328a426fe2d76dabd833e774711b2d56f13f9a8A. Unique TensorFlower 4525328a426fe2d76dabd833e774711b2d56f13f9a8A. Unique TensorFlower test_graph(float_graph_def, {}, [mul_name]) 453ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 454ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden def test_keep_control_edges(self): 455ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden no_op_name = "no_op" 456ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden a_constant_name = "a_constant" 457ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden b_constant_name = "b_constant" 458ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden a_check_name = "a_check" 459ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden b_check_name = "b_check" 460ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden a_identity_name = "a_identity" 461ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden b_identity_name = "b_identity" 462ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden add_name = "add" 463e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney graph_def = graph_pb2.GraphDef() 464ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden no_op = quantize_graph.create_node("NoOp", no_op_name, []) 465ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden graph_def.node.extend([no_op]) 466e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_constant = quantize_graph.create_constant_node( 467e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_constant_name, value=1, dtype=dtypes.float32, shape=[]) 468ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden graph_def.node.extend([a_constant]) 469ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden a_check_node = quantize_graph.create_node("CheckNumerics", a_check_name, 470ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden [a_constant_name]) 471ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden graph_def.node.extend([a_check_node]) 472e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_identity_node = quantize_graph.create_node( 473e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "Identity", a_identity_name, 474e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [a_constant_name, "^" + a_check_name, "^" + no_op_name]) 475ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden graph_def.node.extend([a_identity_node]) 476e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney b_constant = quantize_graph.create_constant_node( 477e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney b_constant_name, value=1, dtype=dtypes.float32, shape=[]) 478ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden graph_def.node.extend([b_constant]) 479ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden b_check_node = quantize_graph.create_node("CheckNumerics", b_check_name, 480ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden [b_constant_name]) 481ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden graph_def.node.extend([b_check_node]) 482e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney b_identity_node = quantize_graph.create_node( 483e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "Identity", b_identity_name, [b_constant_name, "^" + b_check_name]) 484ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden graph_def.node.extend([b_identity_node]) 485ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden add_node = quantize_graph.create_node("Add", add_name, 486e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [a_identity_name, b_identity_name]) 487e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(add_node, "T", dtypes.float32) 488ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden graph_def.node.extend([add_node]) 489ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 490e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney expected_output = graph_pb2.GraphDef() 491ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden no_op = quantize_graph.create_node("NoOp", no_op_name, []) 492ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden expected_output.node.extend([no_op]) 493e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_constant = quantize_graph.create_constant_node( 494e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_constant_name, value=1, dtype=dtypes.float32, shape=[]) 495ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden expected_output.node.extend([a_constant]) 496e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_identity_node = quantize_graph.create_node( 497e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "Identity", a_identity_name, [a_constant_name, "^" + no_op_name]) 498ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden expected_output.node.extend([a_identity_node]) 499e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney b_constant = quantize_graph.create_constant_node( 500e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney b_constant_name, value=1, dtype=dtypes.float32, shape=[]) 501ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden expected_output.node.extend([b_constant]) 502ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden add_node = quantize_graph.create_node("Add", add_name, 503e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [a_identity_name, b_constant_name]) 504e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(add_node, "T", dtypes.float32) 505ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden expected_output.node.extend([add_node]) 506fe4ec328ac6210381cc6a0c9d98b35d35f8d9209A. Unique TensorFlower expected_output.versions.CopyFrom(graph_def.versions) 507fe4ec328ac6210381cc6a0c9d98b35d35f8d9209A. Unique TensorFlower expected_output.library.CopyFrom(graph_def.library) 508ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 5097c7014fd41cdf4e24f923b9e79c249d717aa508fPete Warden output = graph_util.remove_training_nodes(graph_def) 510ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden stripped_output = graph_util.extract_sub_graph(output, [add_name]) 511ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden self.assertProtoEquals(expected_output, stripped_output) 512ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 513ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden def test_batch_norm(self): 514ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden input_constant_name = "input_constant" 515ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden mean_constant_name = "mean_constant" 516ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden variance_constant_name = "variance_constant" 517ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden beta_constant_name = "beta_constant" 518ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden gamma_constant_name = "gamma_constant" 519ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden batch_norm_name = "batch_norm" 520e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def = graph_pb2.GraphDef() 521e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney input_constant = quantize_graph.create_constant_node( 522e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney input_constant_name, 523e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney value=[1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6], 524e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney dtype=dtypes.float32, 525e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape=[1, 1, 6, 2]) 526ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([input_constant]) 527e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney mean_constant = quantize_graph.create_constant_node( 528e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney mean_constant_name, value=[10, 20], dtype=dtypes.float32, shape=[2]) 529ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([mean_constant]) 530ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden variance_constant = quantize_graph.create_constant_node( 531e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney variance_constant_name, 532e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney value=[0.25, 0.5], 533e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney dtype=dtypes.float32, 534e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape=[2]) 535ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([variance_constant]) 536e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney beta_constant = quantize_graph.create_constant_node( 537e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney beta_constant_name, value=[0.1, 0.6], dtype=dtypes.float32, shape=[2]) 538ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([beta_constant]) 539e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney gamma_constant = quantize_graph.create_constant_node( 540e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney gamma_constant_name, value=[0, 0], dtype=dtypes.float32, shape=[2]) 541ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([gamma_constant]) 542ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden batch_norm_node = quantize_graph.create_node( 543e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "BatchNormWithGlobalNormalization", batch_norm_name, [ 544e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney input_constant_name, mean_constant_name, variance_constant_name, 545e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney beta_constant_name, gamma_constant_name 546e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney ]) 547e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(batch_norm_node, "T", dtypes.float32) 548ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden quantize_graph.set_attr_bool(batch_norm_node, "scale_after_normalization", 549ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden False) 550ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden quantize_graph.set_attr_float(batch_norm_node, "variance_epsilon", 0.001) 551ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([batch_norm_node]) 552ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden test_graph(float_graph_def, {}, [batch_norm_name]) 553ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 554ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden def test_max_pool(self): 555ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden input_constant_name = "input_constant" 556ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden max_pool_name = "max_pool" 557e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def = graph_pb2.GraphDef() 558e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney input_constant = quantize_graph.create_constant_node( 559e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney input_constant_name, 560e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 561e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney dtype=dtypes.float32, 562e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape=[1, 2, 6, 1]) 563ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([input_constant]) 564ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden max_pool_node = quantize_graph.create_node("MaxPool", max_pool_name, 565ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden [input_constant_name]) 566ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden quantize_graph.set_attr_int_list(max_pool_node, "ksize", [1, 2, 2, 1]) 567ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden quantize_graph.set_attr_int_list(max_pool_node, "strides", [1, 1, 1, 1]) 568ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden quantize_graph.set_attr_string(max_pool_node, "padding", b"SAME") 569ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([max_pool_node]) 570ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden test_graph(float_graph_def, {}, [max_pool_name]) 571ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 572ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden def test_avg_pool(self): 573ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden input_constant_name = "input_constant" 574ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden avg_pool_name = "avg_pool" 575e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def = graph_pb2.GraphDef() 576e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney input_constant = quantize_graph.create_constant_node( 577e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney input_constant_name, 578e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 579e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney dtype=dtypes.float32, 580e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape=[1, 2, 6, 1]) 581ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([input_constant]) 582ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden avg_pool_node = quantize_graph.create_node("AvgPool", avg_pool_name, 583ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden [input_constant_name]) 584e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(avg_pool_node, "T", dtypes.float32) 585ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden quantize_graph.set_attr_int_list(avg_pool_node, "ksize", [1, 2, 2, 1]) 586ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden quantize_graph.set_attr_int_list(avg_pool_node, "strides", [1, 1, 1, 1]) 587ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden quantize_graph.set_attr_string(avg_pool_node, "padding", b"SAME") 588ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([avg_pool_node]) 589ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden test_graph(float_graph_def, {}, [avg_pool_name]) 590ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 591ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden def test_relu(self): 592ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden input_constant_name = "input_constant" 593ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden relu_name = "relu" 594e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def = graph_pb2.GraphDef() 595e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney input_constant = quantize_graph.create_constant_node( 596e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney input_constant_name, 597e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 598e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney dtype=dtypes.float32, 599e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape=[1, 2, 6, 1]) 600ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([input_constant]) 601ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden relu_node = quantize_graph.create_node("Relu", relu_name, 602ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden [input_constant_name]) 603e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(relu_node, "T", dtypes.float32) 604ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([relu_node]) 605ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden test_graph(float_graph_def, {}, [relu_name]) 606ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 60795f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower def test_relu_w_fake_quant_w_min_max_vars(self): 60895f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower input_node = quantize_graph.create_constant_node( 609e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "input", 610e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 611e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney dtype=dtypes.float32, 612e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape=[1, 2, 6, 1]) 613e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney relu_node = quantize_graph.create_node("Relu", "relu", [input_node.name]) 614e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(relu_node, "T", dtypes.float32) 61595f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower 61695f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower min_node = quantize_graph.create_constant_node( 617e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "min_bias_add", value=0, dtype=dtypes.float32, shape=[]) 61895f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower max_node = quantize_graph.create_constant_node( 619e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "max_bias_add", value=12, dtype=dtypes.float32, shape=[]) 62095f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower fake_quant_node = quantize_graph.create_node( 62195f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower "FakeQuantWithMinMaxVars", "fake_quant", 62295f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower [relu_node.name, min_node.name, max_node.name]) 62395f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower 624e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def = graph_pb2.GraphDef() 625e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def.node.extend( 626e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [input_node, relu_node, min_node, max_node, fake_quant_node]) 62795f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower test_graph(float_graph_def, {}, [fake_quant_node.name], log_graph=True) 62895f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower 62995f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower # Verify there is only one Quantize and one Requantize op. 630e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney eightbit_rewriter = quantize_graph.GraphRewriter( 631e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def, "eightbit", quantized_input_range=None) 63295f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name]) 63395f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower 63495f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower ops = [node.op for node in eightbit_graph_def.node] 63595f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower # No quantize since all inputs are const and can be quantized up-front. 63695f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize")) 63795f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower 63895f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower # One dequantize at the end. 63995f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower self.assertEqual(1, ops.count("Dequantize")) 64095f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower 641ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden def test_relu6(self): 642ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden input_constant_name = "input_constant" 643ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden relu6_name = "relu6" 644e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def = graph_pb2.GraphDef() 645e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney input_constant = quantize_graph.create_constant_node( 646e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney input_constant_name, 647e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 648e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney dtype=dtypes.float32, 649e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape=[1, 2, 6, 1]) 650ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([input_constant]) 651ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden relu6_node = quantize_graph.create_node("Relu6", relu6_name, 652ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden [input_constant_name]) 653e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(relu6_node, "T", dtypes.float32) 654ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([relu6_node]) 655ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden test_graph(float_graph_def, {}, [relu6_name]) 656ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 657ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden def test_bias_add(self): 658ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden input_constant_name = "input_constant" 659ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden offset_constant_name = "offset_constant" 660ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden bias_add_name = "bias_add" 661e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def = graph_pb2.GraphDef() 662e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney input_constant = quantize_graph.create_constant_node( 663e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney input_constant_name, 664e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 665e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney dtype=dtypes.float32, 666e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape=[1, 1, 2, 6]) 667ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([input_constant]) 668e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney offset_constant = quantize_graph.create_constant_node( 669e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney offset_constant_name, 670e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney value=[1, 2, 3, 4, 5, 6], 671e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney dtype=dtypes.float32, 672e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape=[6]) 673ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([offset_constant]) 674e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney bias_add_node = quantize_graph.create_node( 675e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "BiasAdd", bias_add_name, [input_constant_name, offset_constant_name]) 676e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(bias_add_node, "T", dtypes.float32) 677ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden float_graph_def.node.extend([bias_add_node]) 678ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden test_graph(float_graph_def, {}, [bias_add_name]) 679ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 6804b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower def test_quantized_input_range_errors(self): 6814b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower with self.assertRaises(ValueError): 6824b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower # Invalid mode. 683e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.GraphRewriter(graph_pb2.GraphDef(), "weights_rounded", 684e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [0, 1]) 6854b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower with self.assertRaises(ValueError): 6864b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower # Invalid range. 687e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.GraphRewriter(graph_pb2.GraphDef(), "eightbit", [0, -1]) 6884b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower 6894b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower def test_quantized_input_range_bias_add(self): 6904b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower input_shape = [1, 1, 2, 6] 69153cb26d05a5c2080d8022124178b1cc43a30ffe5A. Unique TensorFlower input_n = quantize_graph.create_node("Placeholder", "input", []) 692e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(input_n, "dtype", dtypes.float32) 6934b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower quantize_graph.set_attr_shape(input_n, "shape", input_shape) 694e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney offset_n = quantize_graph.create_constant_node( 695e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "offset", value=[1, 2, 3, 4, 5, 6], dtype=dtypes.float32, shape=[6]) 6964b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower bias_add_n = quantize_graph.create_node("BiasAdd", "bias_add", 6974b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower [input_n.name, offset_n.name]) 698e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(bias_add_n, "T", dtypes.float32) 6994b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower 700e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def = graph_pb2.GraphDef() 7014b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower float_graph_def.node.extend([input_n, offset_n, bias_add_n]) 7024b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower 703e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney input_map = { 704e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney input_n.name + ":0": 705e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney np.reshape([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], input_shape) 706e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney } 707e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney self._RunTestsForQuantizedInputRange(float_graph_def, input_map, 708e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [bias_add_n.name], [-1, 20.]) 709e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney self._RunTestsForQuantizedInputRange(float_graph_def, input_map, 710e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [bias_add_n.name], [0, 12.]) 7114b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower 7124b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower def test_quantized_input_range_mat_mul(self): 7134b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower shapes = [[3, 2], [2, 4]] 7144b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower inputs = [] 7154b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower for i, shape in enumerate(shapes): 71653cb26d05a5c2080d8022124178b1cc43a30ffe5A. Unique TensorFlower node = quantize_graph.create_node("Placeholder", "input_%s" % i, []) 717e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(node, "dtype", dtypes.float32) 7184b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower quantize_graph.set_attr_shape(node, "shape", shape) 7194b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower inputs.append(node) 7204b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower mat_mul_node = quantize_graph.create_node("MatMul", "mat_mul", 7214b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower [n.name for n in inputs]) 722e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(mat_mul_node, "T", dtypes.float32) 7234b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower 724e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def = graph_pb2.GraphDef() 7254b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower float_graph_def.node.extend(inputs + [mat_mul_node]) 7264b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower 727e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney input_map = { 728e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney inputs[0].name + ":0": 729e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney np.reshape([1, 2, 3, 4, 5, 6], shapes[0]), 730e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney inputs[1].name + ":0": 731e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney np.reshape([.8, .7, .6, .5, .4, .3, .2, .1], shapes[1]) 732e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney } 733e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney self._RunTestsForQuantizedInputRange(float_graph_def, input_map, 734e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [mat_mul_node.name], [-1, 20.]) 735e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney self._RunTestsForQuantizedInputRange(float_graph_def, input_map, 736e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [mat_mul_node.name], [0, 6.]) 7374b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower 7384b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower def _RunTestsForQuantizedInputRange(self, float_graph_def, input_map, 7394b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower output_names, input_range): 7404b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower if sys.version_info[0] == 3: 7414b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower # uint8->quint8 conversion for numpy is not working currently. 7424b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower return 7434b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower 7444b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower quantized_input_map = {} 7454b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower for k, v in input_map.items(): 7464b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower arr = [ 747e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney int( 748e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney round((n - input_range[0]) * 255 / (input_range[1] - input_range[ 749e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney 0]))) for n in v.flat 750e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney ] 7514b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower arr = np.array(arr, np.uint8) 7524b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower arr = arr.reshape(v.shape) 753e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney arr = arr.astype(dtypes.quint8.as_numpy_dtype) 7544b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower quantized_input_map[k] = arr 7554b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower output_tensors = [output_name + ":0" for output_name in output_names] 7564b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower float_results = run_graph_def(float_graph_def, input_map, output_tensors) 7574b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower 7584b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower # Quantize treating the input as quantized in range <input_range>. 7594b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower rewriter = quantize_graph.GraphRewriter(float_graph_def, "eightbit", 7604b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower input_range) 7614b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower graph_def = rewriter.rewrite(output_names) 7624b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower results = run_graph_def(graph_def, quantized_input_map, output_tensors) 7634b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower for expected, result in zip(float_results, results): 7644b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower assert are_tensors_near(expected, result, .5) 7654b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower ops = [node.op for node in graph_def.node] 7664b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize")) 7674b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower self.assertEqual(len(output_names), ops.count("Dequantize")) 7684b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower 7694b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower # Quantize without treating input as quantized. 770e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney rewriter = quantize_graph.GraphRewriter( 771e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def, "eightbit", quantized_input_range=None) 7724b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower graph_def = rewriter.rewrite(output_names) 7734b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower results = run_graph_def(graph_def, input_map, output_tensors) 7744b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower for expected, result in zip(float_results, results): 7754b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower assert are_tensors_near(expected, result, .5) 7764b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower ops = [node.op for node in graph_def.node] 777e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney self.assertEqual( 778e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney len(input_map), ops.count("QuantizeV2") + ops.count("Quantize")) 7794b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower self.assertEqual(len(output_names), ops.count("Dequantize")) 7804b352bd16489236c4df63da9ef4b794d77802f09A. Unique TensorFlower 78195f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower def test_bias_add_w_fake_quant_w_min_max_vars(self): 78295f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower input_node = quantize_graph.create_constant_node( 783e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "input", 784e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 785e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney dtype=dtypes.float32, 786e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape=[1, 1, 2, 5]) 78795f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower offset_node = quantize_graph.create_constant_node( 788e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "offset", value=[1, 2, 3, 4, 5], dtype=dtypes.float32, shape=[5]) 78995f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower bias_add_node = quantize_graph.create_node( 79095f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower "BiasAdd", "bias_add", [input_node.name, offset_node.name]) 791e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(bias_add_node, "T", dtypes.float32) 79295f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower 79395f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower min_node = quantize_graph.create_constant_node( 794e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "min_bias_add", value=-.5, dtype=dtypes.float32, shape=[]) 79595f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower max_node = quantize_graph.create_constant_node( 796e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "max_bias_add", value=15.5, dtype=dtypes.float32, shape=[]) 79795f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower fake_quant_node = quantize_graph.create_node( 79895f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower "FakeQuantWithMinMaxVars", "fake_quant", 79995f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower [bias_add_node.name, min_node.name, max_node.name]) 80095f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower 801e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def = graph_pb2.GraphDef() 802e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def.node.extend([ 803e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney input_node, offset_node, bias_add_node, min_node, max_node, 804e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney fake_quant_node 805e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney ]) 80695f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower test_graph(float_graph_def, {}, [fake_quant_node.name], log_graph=True) 80795f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower 80895f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower # Verify there is only one Quantize and one Requantize op. 80942e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower # Pass in fallback_quantization_range, although it will have no effect 81042e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower # because the FakeQuantWithMinMaxVars are used instead. 81142e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower eightbit_rewriter = quantize_graph.GraphRewriter( 812e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def, 813e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "eightbit", 814e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantized_input_range=None, 81542e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower fallback_quantization_range=[-100, 100]) 81695f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name]) 81795f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower 81895f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower ops = [node.op for node in eightbit_graph_def.node] 81942e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower node_names = [node.name for node in eightbit_graph_def.node] 82095f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower # No quantize since all inputs are const and can be quantized up-front. 82195f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize")) 82295f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower 82395f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower # One dequantize at the end. 82495f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower self.assertEqual(1, ops.count("Dequantize")) 82595f7166b8860f568f056a6c20ff626f6a7f069fcA. Unique TensorFlower 82642e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower # The fallback constants are not in the graph. 82742e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower self.assertEqual(0, node_names.count("fallback_quantization_min_value")) 82842e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower self.assertEqual(0, node_names.count("fallback_quantization_max_value")) 82942e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower 83042e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower def test_bias_add_w_fallback_min_max_vars(self): 83142e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower input_node = quantize_graph.create_constant_node( 832e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "input", 833e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 834e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney dtype=dtypes.float32, 835e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney shape=[1, 1, 2, 5]) 83642e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower offset_node = quantize_graph.create_constant_node( 837e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "offset", value=[1, 2, 3, 4, 5], dtype=dtypes.float32, shape=[5]) 83842e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower bias_add_node = quantize_graph.create_node( 83942e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower "BiasAdd", "bias_add", [input_node.name, offset_node.name]) 840e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(bias_add_node, "T", dtypes.float32) 84142e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower 842e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def = graph_pb2.GraphDef() 84342e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower float_graph_def.node.extend([input_node, offset_node, bias_add_node]) 84442e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower test_graph(float_graph_def, {}, [bias_add_node.name], log_graph=True) 84542e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower 84642e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower # Verify there is only one Quantize, one Requantize op, and no 84742e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower # RequantizationRange op. 84842e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower eightbit_rewriter = quantize_graph.GraphRewriter( 849e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney float_graph_def, 850e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "eightbit", 851e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantized_input_range=None, 85242e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower fallback_quantization_range=[-.5, 15.5]) 85342e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower eightbit_graph_def = eightbit_rewriter.rewrite([bias_add_node.name]) 85442e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower 85542e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower ops = [node.op for node in eightbit_graph_def.node] 85642e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower node_names = [node.name for node in eightbit_graph_def.node] 85742e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower # No quantize since all inputs are const and can be quantized up-front. 85842e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize")) 85942e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower 86042e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower # One dequantize at the end. 86142e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower self.assertEqual(1, ops.count("Dequantize")) 86242e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower 86342e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower # No RequantizationRange 86442e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower self.assertEqual(0, ops.count("RequantizationRange")) 86542e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower 86642e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower # The fallback constants are in the graph. 86742e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower self.assertEqual(1, node_names.count("fallback_quantization_min_value")) 86842e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower self.assertEqual(1, node_names.count("fallback_quantization_max_value")) 86942e9d54c833f6c16b9c864a0cdb2191fceb0e7ddA. Unique TensorFlower 870ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden def test_remove_redundant_quantization(self): 871ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden a_constant_name = "a_constant" 872ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden a_constant_min_name = "a_constant_min" 873ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden a_constant_max_name = "a_constant_max" 874ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden a_dequantize_name = "a_dequantize" 875ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden a_quantize_name = "a_quantize" 876ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden b_constant_name = "b_constant" 877ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden b_constant_min_name = "b_constant_min" 878ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden b_constant_max_name = "b_constant_max" 879ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden b_dequantize_name = "b_dequantize" 880ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden b_quantize_name = "b_quantize" 881ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden mat_mul_name = "mat_mul" 882e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney graph_def = graph_pb2.GraphDef() 883e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_constant = quantize_graph.create_constant_node( 884e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_constant_name, value=(0,), dtype=dtypes.quint8, shape=[]) 885ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden graph_def.node.extend([a_constant]) 886e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_constant_min = quantize_graph.create_constant_node( 887e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_constant_min_name, value=2, dtype=dtypes.float32, shape=[]) 888ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden graph_def.node.extend([a_constant_min]) 889e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_constant_max = quantize_graph.create_constant_node( 890e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_constant_max_name, value=2, dtype=dtypes.float32, shape=[]) 891ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden graph_def.node.extend([a_constant_max]) 892e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_dequantize_node = quantize_graph.create_node( 893e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "Dequantize", a_dequantize_name, 894e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [a_constant_name, a_constant_min_name, a_constant_max_name]) 895e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(a_dequantize_node, "T", dtypes.uint8) 896ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden graph_def.node.extend([a_dequantize_node]) 897e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_quantize_node = quantize_graph.create_node( 898e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "QuantizeV2", a_quantize_name, 899e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [a_dequantize_name, a_dequantize_name + ":1", a_dequantize_name + ":2"]) 900e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(a_quantize_node, "T", dtypes.uint8) 901ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden graph_def.node.extend([a_quantize_node]) 902e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney b_constant = quantize_graph.create_constant_node( 903e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney b_constant_name, value=(0,), dtype=dtypes.quint8, shape=[]) 904ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden graph_def.node.extend([b_constant]) 905e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney b_constant_min = quantize_graph.create_constant_node( 906e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney b_constant_min_name, value=3, dtype=dtypes.float32, shape=[]) 907ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden graph_def.node.extend([b_constant_min]) 908e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney b_constant_max = quantize_graph.create_constant_node( 909e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney b_constant_max_name, value=3, dtype=dtypes.float32, shape=[]) 910ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden graph_def.node.extend([b_constant_max]) 911e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney b_dequantize_node = quantize_graph.create_node( 912e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "Dequantize", b_dequantize_name, 913e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [b_constant_name, b_constant_min_name, b_constant_max_name]) 914e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(b_dequantize_node, "T", dtypes.uint8) 915ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden graph_def.node.extend([b_dequantize_node]) 916e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney b_quantize_node = quantize_graph.create_node( 917e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney "QuantizeV2", b_quantize_name, 918e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [b_dequantize_name, b_dequantize_name + ":1", b_dequantize_name + ":2"]) 919e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(b_quantize_node, "T", dtypes.uint8) 920ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden graph_def.node.extend([b_quantize_node]) 921e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney mat_mul_node = quantize_graph.create_node("QuantizedMatMul", mat_mul_name, [ 922e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_quantize_name, b_quantize_name, a_quantize_name + ":1", 923e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_quantize_name + ":2", b_quantize_name + ":1", b_quantize_name + ":2" 924e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney ]) 925e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(mat_mul_node, "T1", dtypes.uint8) 926e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(mat_mul_node, "T2", dtypes.int32) 927ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden graph_def.node.extend([mat_mul_node]) 928ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 929e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney expected_output = graph_pb2.GraphDef() 930e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_constant = quantize_graph.create_constant_node( 931e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_constant_name, value=(0,), dtype=dtypes.quint8, shape=[]) 932ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden expected_output.node.extend([a_constant]) 933e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_constant_min = quantize_graph.create_constant_node( 934e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_constant_min_name, value=2, dtype=dtypes.float32, shape=[]) 935ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden expected_output.node.extend([a_constant_min]) 936e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_constant_max = quantize_graph.create_constant_node( 937e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_constant_max_name, value=2, dtype=dtypes.float32, shape=[]) 938ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden expected_output.node.extend([a_constant_max]) 939e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney b_constant = quantize_graph.create_constant_node( 940e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney b_constant_name, value=(0,), dtype=dtypes.quint8, shape=[]) 941ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden expected_output.node.extend([b_constant]) 942e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney b_constant_min = quantize_graph.create_constant_node( 943e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney b_constant_min_name, value=3, dtype=dtypes.float32, shape=[]) 944ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden expected_output.node.extend([b_constant_min]) 945e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney b_constant_max = quantize_graph.create_constant_node( 946e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney b_constant_max_name, value=3, dtype=dtypes.float32, shape=[]) 947ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden expected_output.node.extend([b_constant_max]) 948e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney mat_mul_node = quantize_graph.create_node("QuantizedMatMul", mat_mul_name, [ 949e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_constant_name, b_constant_name, a_constant_min_name, 950e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney a_constant_max_name, b_constant_min_name, b_constant_max_name 951e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney ]) 952e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(mat_mul_node, "T1", dtypes.uint8) 953e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney quantize_graph.set_attr_dtype(mat_mul_node, "T2", dtypes.int32) 954ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden expected_output.node.extend([mat_mul_node]) 955fe4ec328ac6210381cc6a0c9d98b35d35f8d9209A. Unique TensorFlower expected_output.versions.CopyFrom(graph_def.versions) 956fe4ec328ac6210381cc6a0c9d98b35d35f8d9209A. Unique TensorFlower expected_output.library.CopyFrom(graph_def.library) 957ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 958e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney rewriter = quantize_graph.GraphRewriter( 959e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney graph_def, [mat_mul_name], quantized_input_range=None) 960ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden output = rewriter.remove_redundant_quantization(graph_def) 961ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden stripped_output = graph_util.extract_sub_graph(output, [mat_mul_name]) 962ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden self.assertProtoEquals(expected_output, stripped_output) 963ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 964ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Warden 965ca4e053aa52ab9a42467d4df814ca9272487dbdfPete Wardenif __name__ == "__main__": 966e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney test.main() 967