1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Definition of XLA test case.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import contextlib 22import random 23import re 24 25import numpy as np 26 27from tensorflow.contrib.compiler import jit 28from tensorflow.core.framework import types_pb2 29from tensorflow.core.protobuf import config_pb2 30from tensorflow.python.client import session 31from tensorflow.python.framework import dtypes 32from tensorflow.python.framework import ops 33from tensorflow.python.framework import random_seed 34from tensorflow.python.ops import array_ops 35from tensorflow.python.ops import variables 36from tensorflow.python.platform import flags 37from tensorflow.python.platform import test 38from tensorflow.python.platform import tf_logging as logging 39 40FLAGS = flags.FLAGS 41 42flags.DEFINE_string('test_device', None, 43 'Tensorflow device on which to place operators under test') 44flags.DEFINE_string('types', None, 'Types to test. Comma-separated list.') 45flags.DEFINE_string('disabled_manifest', None, 46 'Path to a file with a list of tests that should not run.') 47 48 49class XLATestCase(test.TestCase): 50 """XLA test cases are parameterized test cases.""" 51 52 def __init__(self, method_name='runTest'): 53 super(XLATestCase, self).__init__(method_name) 54 self.device = FLAGS.test_device 55 self.has_custom_call = (self.device == 'XLA_CPU') 56 self._all_tf_types = set([ 57 dtypes.as_dtype(types_pb2.DataType.Value(name)) 58 for name in FLAGS.types.split(',') 59 ]) 60 self.int_tf_types = set([ 61 dtype for dtype in self._all_tf_types if dtype.is_integer 62 ]) 63 self._float_tf_types = set([ 64 dtype for dtype in self._all_tf_types if dtype.is_floating 65 ]) 66 self.complex_tf_types = set([ 67 dtype for dtype in self._all_tf_types if dtype.is_complex 68 ]) 69 self._numeric_tf_types = set( 70 self.int_tf_types | self._float_tf_types | self.complex_tf_types) 71 72 self._all_types = set( 73 [dtype.as_numpy_dtype for dtype in self._all_tf_types]) 74 self.int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types]) 75 self._float_types = set( 76 [dtype.as_numpy_dtype for dtype in self._float_tf_types]) 77 self.complex_types = set([ 78 dtype.as_numpy_dtype for dtype in self.complex_tf_types 79 ]) 80 self._numeric_types = set( 81 self.int_types | self._float_types | self.complex_types) 82 83 # Parse the manifest file, if any, into a regex identifying tests to 84 # disable 85 self.disabled_regex = None 86 self._method_types_filter = dict() 87 # TODO(xpan): Make it text proto if it doesn't scale. 88 # Each line of the manifest file specifies an entry. The entry can be 89 # 1) TestNameRegex // E.g. CumprodTest.* Or 90 # 2) TestName TypeName // E.g. AdamOptimizerTest.testSharing DT_BFLOAT16 91 # The 1) disables the entire test. While 2) only filter some numeric types 92 # so that they are not used in those tests. 93 94 if FLAGS.disabled_manifest is not None: 95 comments_re = re.compile('#.*$') 96 manifest_file = open(FLAGS.disabled_manifest, 'r') 97 disabled_tests = [] 98 disabled_method_types = [] 99 for l in manifest_file.read().splitlines(): 100 entry = comments_re.sub('', l).strip().split(' ') 101 if len(entry) == 1: 102 disabled_tests.append(entry[0]) 103 elif len(entry) == 2: 104 disabled_method_types.append( 105 (entry[0], entry[1].strip().split(','))) 106 else: 107 raise ValueError('Bad entry in manifest file.') 108 109 self.disabled_regex = re.compile('|'.join(disabled_tests)) 110 for method, types in disabled_method_types: 111 self._method_types_filter[method] = set([ 112 dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype 113 for name in types]) 114 manifest_file.close() 115 116 @property 117 def all_tf_types(self): 118 name = '{}.{}'.format(type(self).__name__, self._testMethodName) 119 tf_types = set([dtypes.as_dtype(t) 120 for t in self._method_types_filter.get(name, set())]) 121 return self._all_tf_types - tf_types 122 123 @property 124 def float_types(self): 125 name = '{}.{}'.format(type(self).__name__, self._testMethodName) 126 return self._float_types - self._method_types_filter.get(name, set()) 127 128 @property 129 def float_tf_types(self): 130 name = '{}.{}'.format(type(self).__name__, self._testMethodName) 131 return self._float_tf_types - self._method_types_filter.get(name, set()) 132 133 @property 134 def numeric_tf_types(self): 135 name = '{}.{}'.format(type(self).__name__, self._testMethodName) 136 tf_types = set([dtypes.as_dtype(t) 137 for t in self._method_types_filter.get(name, set())]) 138 return self._numeric_tf_types - tf_types 139 140 @property 141 def numeric_types(self): 142 name = '{}.{}'.format(type(self).__name__, self._testMethodName) 143 return self._numeric_types - self._method_types_filter.get(name, set()) 144 145 @property 146 def all_types(self): 147 name = '{}.{}'.format(type(self).__name__, self._testMethodName) 148 return self._all_types - self._method_types_filter.get(name, set()) 149 150 def setUp(self): 151 super(XLATestCase, self).setUp() 152 name = '{}.{}'.format(type(self).__name__, self._testMethodName) 153 if self.disabled_regex is not None and self.disabled_regex.match(name): 154 logging.info('Disabled test case: %s', name) 155 self.skipTest('{} is disabled by manifest.'.format(name)) 156 return 157 logging.info('Start test case: %s', name) 158 159 random.seed(random_seed.DEFAULT_GRAPH_SEED) 160 np.random.seed(random_seed.DEFAULT_GRAPH_SEED) 161 162 def tearDown(self): 163 super(XLATestCase, self).tearDown() 164 logging.info('End test case: %s', self._testMethodName) 165 166 @contextlib.contextmanager 167 def test_session(self): 168 """Custom implementation of test_session() for XLA tests. 169 170 We override the standard Tensorflow test_session() since it is too 171 specific to CPU and GPU tests. In particular, we want to disable soft 172 placement and explicitly assign ops to devices under test. 173 174 Yields: 175 A session to use when running a test case. 176 """ 177 graph = ops.Graph() 178 with session.Session(graph=graph) as sess, graph.as_default(): 179 yield sess 180 181 @contextlib.contextmanager 182 def test_scope(self): 183 """Test scope that runs tests on a Tensorflow/XLA device. 184 185 Uses a compilation_scope() to mark operators to compile. 186 187 Yields: 188 A scope to apply to the operators under test. 189 """ 190 with ops.device('device:{}:0'.format(self.device)): 191 yield 192 193 194def Benchmark(tf_bench, 195 builder_fn, 196 use_xla_jit, 197 device, 198 separate_compiled_gradients=False): 199 """Build a graph and run benchmarks against it, with or without XLA. 200 201 Args: 202 tf_bench: An instance of tf.test.Benchmark, used to run the benchmark. 203 builder_fn: A function that builds a graph when invoked, and returns 204 (name, fetches), where name is the name of the test, and fetches 205 is a list of tensors to fetch as output. 206 use_xla_jit: If true compile with the XLA JIT, otherwise use regular TF. 207 device: The tensorflow device to run on, e.g. "cpu", "gpu". 208 separate_compiled_gradients: If true put each gradient subgraph into a 209 separate compilation scope. This gives fine-grained control over which 210 portions of the graph will be compiled as a single unit. Compiling 211 gradients separately may yield better performance for some graphs. 212 The scope is named based on the scope of the forward computation as well 213 as the name of the gradients. As a result, the gradients will be compiled 214 in a scope that is separate from both the forward computation, and from 215 other gradients. 216 """ 217 218 with ops.Graph().as_default(): 219 name = None 220 targets = [] 221 with ops.device(device): 222 fetches = [] 223 jit_scope = jit.experimental_jit_scope 224 with jit_scope( 225 compile_ops=use_xla_jit, 226 separate_compiled_gradients=separate_compiled_gradients): 227 name, fetches = builder_fn() 228 229 # We only want to benchmark the operations themselves, and not the data 230 # transfer of the result(s). Non-compiled identity ops ensure XLA 231 # doesn't know we're dropping the results, otherwise it might compile 232 # away the entire computation. 233 for fetch in fetches: 234 targets.append(array_ops.identity(fetch).op) 235 236 config = config_pb2.ConfigProto(allow_soft_placement=True) 237 with session.Session(config=config) as sess: 238 sess.run(variables.global_variables_initializer()) 239 xla = 'xla_' if use_xla_jit else '' 240 tf_bench.run_op_benchmark( 241 sess, targets, name='%s_%s%s' % (name, xla, device)) 242