1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TensorFlow Lite Python Interface: Sanity check."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.contrib.lite.python import lite
21from tensorflow.contrib.lite.python.op_hint import _tensor_name_base as _tensor_name_base
22from tensorflow.python.client import session
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import test_util
25from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
26from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.platform import test
30
31
class LiteTest(test_util.TensorFlowTestCase):
  """Sanity checks for the TOCO conversion entry point `lite.toco_convert`."""

  def testBasic(self):
    """Converts a trivial float add graph and checks a flatbuffer comes back."""
    in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3],
                                      dtype=dtypes.float32)
    out_tensor = in_tensor + in_tensor
    # Use a context manager so the session is closed (and its resources
    # released) even if the conversion call below raises.
    with session.Session() as sess:
      # Try running on valid graph
      result = lite.toco_convert(sess.graph_def, [in_tensor], [out_tensor])
      self.assertTrue(result)
    # TODO(aselle): remove tests that fail (we must get TOCO to not fatal
    # all the time).
    # Try running on identity graph (known fail)
    # with self.assertRaisesRegexp(RuntimeError, "!model->operators.empty()"):
    #   result = lite.toco_convert(sess.graph_def, [in_tensor], [in_tensor])

  def testQuantization(self):
    """Converts a fake-quantized graph with uint8 inference settings."""
    in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3],
                                      dtype=dtypes.float32)
    out_tensor = array_ops.fake_quant_with_min_max_args(in_tensor + in_tensor,
                                                        min=0., max=1.)
    # Close the session deterministically rather than leaking it.
    with session.Session() as sess:
      # quantized_input_stats supplies (mean, stddev) for the single input.
      result = lite.toco_convert(sess.graph_def, [in_tensor], [out_tensor],
                                 inference_type=lite.QUANTIZED_UINT8,
                                 quantized_input_stats=[(0., 1.)])
      self.assertTrue(result)
58
59
class LiteTestOpHint(test_util.TensorFlowTestCase):
  """Test the hint to stub functionality."""

  def _getGraphOpTypes(self, graphdef, output_nodes):
    """Returns used op types in `graphdef` reachable from `output_nodes`.

    This is used to check that after the stub transformation the expected
    nodes are there. Typically use this with self.assertCountEqual(...).

    NOTE: this is not a exact test that the graph is the correct output, but
      it balances compact expressibility of test with sanity checking.

    Args:
      graphdef: TensorFlow proto graphdef.
      output_nodes: A list of output node names that we need to reach.

    Returns:
      A set of node types reachable from `output_nodes`.
    """
    name_to_input_name, name_to_node, _ = _extract_graph_summary(graphdef)
    # Walk backwards from the outputs to collect every node they depend on.
    reachable = _bfs_for_reachable_nodes(output_nodes, name_to_input_name)
    return {name_to_node[node_name].op for node_name in reachable}

  def _countIdentities(self, nodes):
    """Count the number of "Identity" op types in the list of proto nodes.

    Args:
      nodes: NodeDefs of the graph.

    Returns:
      The number of nodes with op type "Identity" found.
    """
    return sum(1 for node in nodes if node.op == "Identity")

  def testSwishLiteHint(self):
    """Makes a custom op swish and makes sure it gets converted as a unit."""
    image = array_ops.constant([1., 2., 3., 4.])
    swish_scale = array_ops.constant(1.0)

    def _swish(input_tensor, scale):
      # Wrap the swish computation in an OpHint so it can be stubbed out.
      hint = lite.OpHint("cool_activation")
      input_tensor, scale = hint.add_inputs(input_tensor, scale)
      result = math_ops.sigmoid(input_tensor) * input_tensor * scale
      (result,) = hint.add_outputs(result)
      return result

    output = array_ops.identity(_swish(image, swish_scale), name="ModelOutput")

    with self.test_session() as sess:
      # check if identities have been put into the graph (2 input, 1 output,
      # and 1 final output).
      self.assertEqual(self._countIdentities(sess.graph_def.node), 4)

      stubbed_graphdef = lite.convert_op_hints_to_stubs(sess)

      self.assertCountEqual(
          self._getGraphOpTypes(
              stubbed_graphdef, output_nodes=[_tensor_name_base(output)]),
          ["cool_activation", "Const", "Identity"])

  def testScaleAndBiasAndIdentity(self):
    """This tests a scaled add which has 3 inputs and 2 outputs."""
    a = array_ops.constant(1.)
    x = array_ops.constant([2., 3.])
    b = array_ops.constant([4., 5.])

    def _scaled_and_bias_and_identity(a, x, b):
      # Two outputs: the scaled-add result and the passthrough of x.
      hint = lite.OpHint("scale_and_bias_and_identity")
      a, x, b = hint.add_inputs(a, x, b)
      return hint.add_outputs(a * x + b, x)

    output = array_ops.identity(_scaled_and_bias_and_identity(a, x, b),
                                name="ModelOutput")

    with self.test_session() as sess:
      # make sure one identity for each input (3) and output (2) => 3 + 2 = 5
      # +1 for the final output
      self.assertEqual(self._countIdentities(sess.graph_def.node), 6)

      stubbed_graphdef = lite.convert_op_hints_to_stubs(sess)

      self.assertCountEqual(
          self._getGraphOpTypes(
              stubbed_graphdef, output_nodes=[_tensor_name_base(output)]),
          ["scale_and_bias_and_identity", "Const", "Identity", "Pack"])

  def testTwoFunctions(self):
    """Tests if two functions are converted correctly."""
    a = array_ops.constant([1.])
    b = array_ops.constant([1.])

    def _double_values(x):
      # Each call creates its own hinted region named "add_test".
      hint = lite.OpHint("add_test")
      x = hint.add_inputs(x)
      result = math_ops.multiply(x, x)
      (result,) = hint.add_outputs(result)
      return result

    output = array_ops.identity(
        math_ops.add(_double_values(a), _double_values(b)), name="ModelOutput")

    with self.test_session() as sess:
      # make sure one identity for each input (2) and output (2) => 2 + 2
      # +1 for the final output
      self.assertEqual(self._countIdentities(sess.graph_def.node), 5)
      stubbed_graphdef = lite.convert_op_hints_to_stubs(sess)
      self.assertCountEqual(
          self._getGraphOpTypes(
              stubbed_graphdef, output_nodes=[_tensor_name_base(output)]),
          ["add_test", "Const", "Identity", "Add"])
168
169
# Script entry point: run all test cases in this module.
if __name__ == "__main__":
  test.main()
172