# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Lite Python Interface: Sanity check."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.lite.python import lite
from tensorflow.contrib.lite.python.op_hint import _tensor_name_base as _tensor_name_base
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class LiteTest(test_util.TensorFlowTestCase):
  """Sanity checks for `lite.toco_convert`."""

  def testBasic(self):
    """Converts a simple float graph (`x + x`) and expects a truthy result."""
    in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3],
                                      dtype=dtypes.float32)
    out_tensor = in_tensor + in_tensor
    # Use the session as a context manager so it is closed when the test ends
    # instead of being leaked.
    with session.Session() as sess:
      # Try running on valid graph
      result = lite.toco_convert(sess.graph_def, [in_tensor], [out_tensor])
      self.assertTrue(result)
      # TODO(aselle): remove tests that fail (we must get TOCO to not fatal
      # all the time).
      # Try running on identity graph (known fail)
      # with self.assertRaisesRegexp(RuntimeError, "!model->operators.empty()"):
      #   result = lite.toco_convert(sess.graph_def, [in_tensor], [in_tensor])

  def testQuantization(self):
    """Converts a fake-quantized graph with `inference_type=QUANTIZED_UINT8`."""
    in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3],
                                      dtype=dtypes.float32)
    out_tensor = array_ops.fake_quant_with_min_max_args(in_tensor + in_tensor,
                                                        min=0., max=1.)
    with session.Session() as sess:
      # One (mean, std) pair per input tensor.
      result = lite.toco_convert(sess.graph_def, [in_tensor], [out_tensor],
                                 inference_type=lite.QUANTIZED_UINT8,
                                 quantized_input_stats=[(0., 1.)])
      self.assertTrue(result)


class LiteTestOpHint(test_util.TensorFlowTestCase):
  """Test the hint to stub functionality."""

  def _getGraphOpTypes(self, graphdef, output_nodes):
    """Returns used op types in `graphdef` reachable from `output_nodes`.

    This is used to check that after the stub transformation the expected
    nodes are there. Typically use this with self.assertCountEqual(...).

    NOTE: this is not an exact test that the graph is the correct output, but
    it balances compact expressibility of test with sanity checking.

    Args:
      graphdef: TensorFlow proto graphdef.
      output_nodes: A list of output node names that we need to reach.

    Returns:
      A set of node types reachable from `output_nodes`.
    """
    name_to_input_name, name_to_node, _ = _extract_graph_summary(graphdef)
    # Find all nodes that are needed by the outputs
    used_node_names = _bfs_for_reachable_nodes(output_nodes, name_to_input_name)
    return {name_to_node[node_name].op for node_name in used_node_names}

  def _countIdentities(self, nodes):
    """Count the number of "Identity" op types in the list of proto nodes.

    Args:
      nodes: NodeDefs of the graph.

    Returns:
      The number of nodes with op type "Identity" found.
    """
    return sum(1 for node in nodes if node.op == "Identity")

  def testSwishLiteHint(self):
    """Makes a custom op swish and makes sure it gets converted as a unit."""
    image = array_ops.constant([1., 2., 3., 4.])
    swish_scale = array_ops.constant(1.0)

    def _swish(input_tensor, scale):
      custom = lite.OpHint("cool_activation")
      input_tensor, scale = custom.add_inputs(input_tensor, scale)
      output = math_ops.sigmoid(input_tensor) * input_tensor * scale
      output, = custom.add_outputs(output)
      return output
    output = array_ops.identity(_swish(image, swish_scale), name="ModelOutput")

    with self.test_session() as sess:
      # check if identities have been put into the graph (2 inputs, 1 output,
      # and 1 final output).
      self.assertEqual(self._countIdentities(sess.graph_def.node), 4)

      stubbed_graphdef = lite.convert_op_hints_to_stubs(sess)

      self.assertCountEqual(
          self._getGraphOpTypes(
              stubbed_graphdef, output_nodes=[_tensor_name_base(output)]),
          ["cool_activation", "Const", "Identity"])

  def testScaleAndBiasAndIdentity(self):
    """This tests a scaled add which has 3 inputs and 2 outputs."""
    a = array_ops.constant(1.)
    x = array_ops.constant([2., 3.])
    b = array_ops.constant([4., 5.])

    def _scaled_and_bias_and_identity(a, x, b):
      custom = lite.OpHint("scale_and_bias_and_identity")
      a, x, b = custom.add_inputs(a, x, b)
      return custom.add_outputs(a * x + b, x)
    # The hint returns a list of two outputs; `identity` packs them into one
    # tensor, which is why "Pack" appears among the expected op types below.
    output = array_ops.identity(_scaled_and_bias_and_identity(a, x, b),
                                name="ModelOutput")

    with self.test_session() as sess:
      # make sure one identity for each input (3) and output (2) => 3 + 2 = 5
      # +1 for the final output
      self.assertEqual(self._countIdentities(sess.graph_def.node), 6)

      stubbed_graphdef = lite.convert_op_hints_to_stubs(sess)

      self.assertCountEqual(
          self._getGraphOpTypes(
              stubbed_graphdef, output_nodes=[_tensor_name_base(output)]),
          ["scale_and_bias_and_identity", "Const", "Identity", "Pack"])

  def testTwoFunctions(self):
    """Tests if two functions are converted correctly."""
    a = array_ops.constant([1.])
    b = array_ops.constant([1.])

    def _square_values(x):
      custom = lite.OpHint("add_test")
      # `add_inputs` returns a list of wrapped tensors; unpack the single
      # element so the multiply operates on the hinted tensor itself rather
      # than on an auto-packed Python list.
      x, = custom.add_inputs(x)
      output = math_ops.multiply(x, x)
      output, = custom.add_outputs(output)
      return output
    output = array_ops.identity(
        math_ops.add(_square_values(a), _square_values(b)), name="ModelOutput")

    with self.test_session() as sess:
      # make sure one identity for each input (2) and output (2) => 2 + 2
      # +1 for the final output
      self.assertEqual(self._countIdentities(sess.graph_def.node), 5)
      stubbed_graphdef = lite.convert_op_hints_to_stubs(sess)
      self.assertCountEqual(
          self._getGraphOpTypes(
              stubbed_graphdef, output_nodes=[_tensor_name_base(output)]),
          ["add_test", "Const", "Identity", "Add"])


if __name__ == "__main__":
  test.main()