1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Definition of XLA test case."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import contextlib
22import random
23import re
24
25import numpy as np
26
27from tensorflow.contrib.compiler import jit
28from tensorflow.core.framework import types_pb2
29from tensorflow.core.protobuf import config_pb2
30from tensorflow.python.client import session
31from tensorflow.python.framework import dtypes
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import random_seed
34from tensorflow.python.ops import array_ops
35from tensorflow.python.ops import variables
36from tensorflow.python.platform import flags
37from tensorflow.python.platform import test
38from tensorflow.python.platform import tf_logging as logging
39
40FLAGS = flags.FLAGS
41
42flags.DEFINE_string('test_device', None,
43                    'Tensorflow device on which to place operators under test')
44flags.DEFINE_string('types', None, 'Types to test. Comma-separated list.')
45flags.DEFINE_string('disabled_manifest', None,
46                    'Path to a file with a list of tests that should not run.')
47
48
49class XLATestCase(test.TestCase):
50  """XLA test cases are parameterized test cases."""
51
52  def __init__(self, method_name='runTest'):
53    super(XLATestCase, self).__init__(method_name)
54    self.device = FLAGS.test_device
55    self.has_custom_call = (self.device == 'XLA_CPU')
56    self._all_tf_types = set([
57        dtypes.as_dtype(types_pb2.DataType.Value(name))
58        for name in FLAGS.types.split(',')
59    ])
60    self.int_tf_types = set([
61        dtype for dtype in self._all_tf_types if dtype.is_integer
62    ])
63    self._float_tf_types = set([
64        dtype for dtype in self._all_tf_types if dtype.is_floating
65    ])
66    self.complex_tf_types = set([
67        dtype for dtype in self._all_tf_types if dtype.is_complex
68    ])
69    self._numeric_tf_types = set(
70        self.int_tf_types | self._float_tf_types | self.complex_tf_types)
71
72    self._all_types = set(
73        [dtype.as_numpy_dtype for dtype in self._all_tf_types])
74    self.int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types])
75    self._float_types = set(
76        [dtype.as_numpy_dtype for dtype in self._float_tf_types])
77    self.complex_types = set([
78        dtype.as_numpy_dtype for dtype in self.complex_tf_types
79    ])
80    self._numeric_types = set(
81        self.int_types | self._float_types | self.complex_types)
82
83    # Parse the manifest file, if any, into a regex identifying tests to
84    # disable
85    self.disabled_regex = None
86    self._method_types_filter = dict()
87    # TODO(xpan): Make it text proto if it doesn't scale.
88    # Each line of the manifest file specifies an entry. The entry can be
89    # 1) TestNameRegex  // E.g. CumprodTest.* Or
90    # 2) TestName TypeName  // E.g. AdamOptimizerTest.testSharing DT_BFLOAT16
91    # The 1) disables the entire test. While 2) only filter some numeric types
92    # so that they are not used in those tests.
93
94    if FLAGS.disabled_manifest is not None:
95      comments_re = re.compile('#.*$')
96      manifest_file = open(FLAGS.disabled_manifest, 'r')
97      disabled_tests = []
98      disabled_method_types = []
99      for l in manifest_file.read().splitlines():
100        entry = comments_re.sub('', l).strip().split(' ')
101        if len(entry) == 1:
102          disabled_tests.append(entry[0])
103        elif len(entry) == 2:
104          disabled_method_types.append(
105              (entry[0], entry[1].strip().split(',')))
106        else:
107          raise ValueError('Bad entry in manifest file.')
108
109      self.disabled_regex = re.compile('|'.join(disabled_tests))
110      for method, types in disabled_method_types:
111        self._method_types_filter[method] = set([
112            dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype
113            for name in types])
114      manifest_file.close()
115
116  @property
117  def all_tf_types(self):
118    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
119    tf_types = set([dtypes.as_dtype(t)
120                    for t in self._method_types_filter.get(name, set())])
121    return self._all_tf_types - tf_types
122
123  @property
124  def float_types(self):
125    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
126    return self._float_types - self._method_types_filter.get(name, set())
127
128  @property
129  def float_tf_types(self):
130    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
131    return self._float_tf_types - self._method_types_filter.get(name, set())
132
133  @property
134  def numeric_tf_types(self):
135    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
136    tf_types = set([dtypes.as_dtype(t)
137                    for t in self._method_types_filter.get(name, set())])
138    return self._numeric_tf_types - tf_types
139
140  @property
141  def numeric_types(self):
142    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
143    return self._numeric_types - self._method_types_filter.get(name, set())
144
145  @property
146  def all_types(self):
147    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
148    return self._all_types - self._method_types_filter.get(name, set())
149
150  def setUp(self):
151    super(XLATestCase, self).setUp()
152    name = '{}.{}'.format(type(self).__name__, self._testMethodName)
153    if self.disabled_regex is not None and self.disabled_regex.match(name):
154      logging.info('Disabled test case: %s', name)
155      self.skipTest('{} is disabled by manifest.'.format(name))
156      return
157    logging.info('Start test case: %s', name)
158
159    random.seed(random_seed.DEFAULT_GRAPH_SEED)
160    np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
161
162  def tearDown(self):
163    super(XLATestCase, self).tearDown()
164    logging.info('End test case: %s', self._testMethodName)
165
166  @contextlib.contextmanager
167  def test_session(self):
168    """Custom implementation of test_session() for XLA tests.
169
170    We override the standard Tensorflow test_session() since it is too
171    specific to CPU and GPU tests. In particular, we want to disable soft
172    placement and explicitly assign ops to devices under test.
173
174    Yields:
175      A session to use when running a test case.
176    """
177    graph = ops.Graph()
178    with session.Session(graph=graph) as sess, graph.as_default():
179      yield sess
180
181  @contextlib.contextmanager
182  def test_scope(self):
183    """Test scope that runs tests on a Tensorflow/XLA device.
184
185    Uses a compilation_scope() to mark operators to compile.
186
187    Yields:
188      A scope to apply to the operators under test.
189    """
190    with ops.device('device:{}:0'.format(self.device)):
191      yield
192
193
194def Benchmark(tf_bench,
195              builder_fn,
196              use_xla_jit,
197              device,
198              separate_compiled_gradients=False):
199  """Build a graph and run benchmarks against it, with or without XLA.
200
201  Args:
202    tf_bench: An instance of tf.test.Benchmark, used to run the benchmark.
203    builder_fn: A function that builds a graph when invoked, and returns
204        (name, fetches), where name is the name of the test, and fetches
205        is a list of tensors to fetch as output.
206    use_xla_jit: If true compile with the XLA JIT, otherwise use regular TF.
207    device: The tensorflow device to run on, e.g. "cpu", "gpu".
208    separate_compiled_gradients: If true put each gradient subgraph into a
209      separate compilation scope. This gives fine-grained control over which
210      portions of the graph will be compiled as a single unit. Compiling
211      gradients separately may yield better performance for some graphs.
212      The scope is named based on the scope of the forward computation as well
213      as the name of the gradients. As a result, the gradients will be compiled
214      in a scope that is separate from both the forward computation, and from
215      other gradients.
216  """
217
218  with ops.Graph().as_default():
219    name = None
220    targets = []
221    with ops.device(device):
222      fetches = []
223      jit_scope = jit.experimental_jit_scope
224      with jit_scope(
225          compile_ops=use_xla_jit,
226          separate_compiled_gradients=separate_compiled_gradients):
227        name, fetches = builder_fn()
228
229      # We only want to benchmark the operations themselves, and not the data
230      # transfer of the result(s).  Non-compiled identity ops ensure XLA
231      # doesn't know we're dropping the results, otherwise it might compile
232      # away the entire computation.
233      for fetch in fetches:
234        targets.append(array_ops.identity(fetch).op)
235
236    config = config_pb2.ConfigProto(allow_soft_placement=True)
237    with session.Session(config=config) as sess:
238      sess.run(variables.global_variables_initializer())
239      xla = 'xla_' if use_xla_jit else ''
240      tf_bench.run_op_benchmark(
241          sess, targets, name='%s_%s%s' % (name, xla, device))
242