# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.tf.BatchMatMul."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class BatchMatmulOpTest(test.TestCase):

  # Uses numpy to compute batch_matmul(x, y, adjoint_a, adjoint_b).
  def _npBatchMatmul(self, x, y, adjoint_a, adjoint_b):
    # The output's two innermost dimensions depend on adjoint_a and adjoint_b.
    d0 = x.shape[-2] if not adjoint_a else x.shape[-1]
    d2 = y.shape[-1] if not adjoint_b else y.shape[-2]
    batch_dims = x.shape[:-2]
    num = np.prod(batch_dims)
    z = np.empty(list(batch_dims) + [d0, d2], dtype=x.dtype)
    # Flatten the batch dimensions so each batch entry reduces to a single
    # 2-D matrix product.
    xr = x.reshape([num, x.shape[-2], x.shape[-1]])
    yr = y.reshape([num, y.shape[-2], y.shape[-1]])
    zr = z.reshape([num, z.shape[-2], z.shape[-1]])
    for i in range(num):
      # The adjoint is the conjugate transpose; np.conj is a no-op for
      # real dtypes. (Plain ndarrays with np.dot replace the deprecated
      # np.matrix, with identical behavior.)
      a = xr[i, :, :]
      if adjoint_a:
        a = a.conj().T
      b = yr[i, :, :]
      if adjoint_b:
        b = b.conj().T
      zr[i, :, :] = np.dot(a, b)
    return z
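
  # Illustrative cross-check (a sketch added for exposition, not part of the
  # original suite): with both adjoint flags off, the reference loop above is
  # equivalent to a single einsum over the batch dimension. The shapes used
  # here are arbitrary assumptions.
  def testNpVersionMatchesEinsum(self):
    np.random.seed(0)
    x = np.random.normal(size=[4, 2, 3])
    y = np.random.normal(size=[4, 3, 5])
    z0 = self._npBatchMatmul(x, y, False, False)
    z1 = np.einsum("bij,bjk->bik", x, y)
    self.assertAllClose(z0, z1)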

  # Tests that the numpy reference _npBatchMatmul is itself correct.
  def testNpVersion(self):
    x = np.array([0., 1., 2., 3.]).reshape([1, 2, 2])
    y = np.array([1., 2., 3., 4.]).reshape([1, 2, 2])
    z0 = self._npBatchMatmul(x, y, False, False)
    z1 = np.array([3., 4., 11., 16.]).reshape([1, 2, 2])
    self.assertTrue(np.array_equal(z0, z1))

    x = np.array([1., (1j), (-1.), (-1j)]).reshape([1, 2, 2])
    y = x * complex(1, 1)  # rotate x by 90 degrees
    z0 = self._npBatchMatmul(x, y, False, False)
    z1 = np.array([2., (2.j), -2., (-2.j)]).reshape([1, 2, 2])
    self.assertTrue(np.array_equal(z0, z1))

    z0 = self._npBatchMatmul(x, y, False, True)
    z1 = np.array([(2. - 2.j), (-2. + 2.j), (-2. + 2.j), (2. - 2.j)]).reshape(
        [1, 2, 2])
    self.assertTrue(np.array_equal(z0, z1))

    z0 = self._npBatchMatmul(x, y, True, False)
    z1 = np.array([(2. + 2.j), (-2. + 2.j), (2. - 2.j), (2. + 2.j)]).reshape(
        [1, 2, 2])
    self.assertTrue(np.array_equal(z0, z1))

  # Compares math_ops.matmul(x, y, adjoint_a, adjoint_b) against the NumPy
  # reference _npBatchMatmul, with either a static or a fed (placeholder)
  # shape.
  def _compare(self, x_in, y_in, adjoint_a, adjoint_b, static_shape=True):
    x_t_shape = x_in.shape[:-2] + (x_in.shape[-1], x_in.shape[-2])
    y_t_shape = y_in.shape[:-2] + (y_in.shape[-1], y_in.shape[-2])
    x = x_in if not adjoint_a else x_in.reshape(x_t_shape)
    y = y_in if not adjoint_b else y_in.reshape(y_t_shape)
    is_floating = x.dtype != np.int32
    tol = 100 * np.finfo(x.dtype).eps if is_floating else 0
    with self.test_session(use_gpu=is_floating) as sess:
      if static_shape:
        z0 = math_ops.matmul(x, y, adjoint_a=adjoint_a, adjoint_b=adjoint_b)
        z0_val = z0.eval()
      else:
        x_ph = array_ops.placeholder(x.dtype)
        y_ph = array_ops.placeholder(y.dtype)
        z0 = math_ops.matmul(
            x_ph, y_ph, adjoint_a=adjoint_a, adjoint_b=adjoint_b)
        z0_val = sess.run(z0, feed_dict={x_ph: x, y_ph: y})
      z1 = self._npBatchMatmul(x, y, adjoint_a, adjoint_b)
      self.assertAllClose(z0_val, z1, rtol=tol, atol=tol)

  def _rand(self, shape, dtype):
    vals = np.array(np.random.normal(-10, 10, np.prod(shape)), dtype=dtype)
    if dtype in (np.complex64, np.complex128):
      imag = np.array(np.random.normal(-10, 10, np.prod(shape)), dtype=dtype)
      vals += 1j * imag
    return vals.reshape(shape)

  def _testNonEmpty(self, dtype, adjoint_a, adjoint_b, use_static_shape):

    def compareNonEmpty(a_shape, b_shape):
      self._compare(
          self._rand(a_shape, dtype),
          self._rand(b_shape, dtype), adjoint_a, adjoint_b, use_static_shape)

    compareNonEmpty([1, 2, 3], [1, 3, 5])
    compareNonEmpty([1, 2, 3], [1, 3, 1])
    compareNonEmpty([1, 1, 3], [1, 3, 5])
    compareNonEmpty([7, 1, 3], [7, 3, 5])
    compareNonEmpty([7, 2, 3], [7, 3, 1])
    compareNonEmpty([7, 2, 3], [7, 3, 5])
    compareNonEmpty([10, 64, 75], [10, 75, 30])
    compareNonEmpty([5, 7, 2, 3], [5, 7, 3, 5])

  def _testEmpty(self, dtype, adjoint_a, adjoint_b, use_static_shape):

    def compareEmpty(a_shape, b_shape):
      self._compare(
          np.zeros(a_shape).astype(dtype),
          np.zeros(b_shape).astype(dtype), adjoint_a, adjoint_b,
          use_static_shape)

    compareEmpty([0, 3, 2], [0, 2, 4])
    compareEmpty([3, 0, 2], [3, 2, 5])
    compareEmpty([3, 3, 2], [3, 2, 0])


def _GetBatchMatmulOpTest(dtype, adjoint_a, adjoint_b, use_static_shape):

  def Test(self):
    np.random.seed(42)
    self._testNonEmpty(dtype, adjoint_a, adjoint_b, use_static_shape)
    self._testEmpty(dtype, adjoint_a, adjoint_b, use_static_shape)

  return Test
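
# A minimal sketch (hypothetical helper, not used by the tests above) of the
# adjoint semantics the reference implementation relies on: the adjoint is the
# conjugate transpose of the two innermost dimensions, so it agrees with a
# plain transpose only for real-valued inputs.
def _np_adjoint_sketch(a):
  # Swap the last two axes and conjugate; np.conj is a no-op for real dtypes.
  return np.conj(np.swapaxes(a, -2, -1))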

class BatchMatmulGradientTest(test.TestCase):

  # loss = sum(batch_matmul(x, y)). Verifies dl/dx and dl/dy via the
  # gradient checker.
  def _checkGrad(self, x_in, y_in, adjoint_a, adjoint_b):
    x_t_shape = x_in.shape[:-2] + (x_in.shape[-1], x_in.shape[-2])
    y_t_shape = y_in.shape[:-2] + (y_in.shape[-1], y_in.shape[-2])
    x = x_in if not adjoint_a else x_in.reshape(x_t_shape)
    y = y_in if not adjoint_b else y_in.reshape(y_t_shape)
    epsilon = np.finfo(x.dtype).eps
    # A step size on the order of eps**(1/3) roughly balances truncation and
    # round-off error in the numerical gradient estimate.
    delta = epsilon**(1.0 / 3.0)
    with self.test_session(use_gpu=True):
      inx = constant_op.constant(x)
      iny = constant_op.constant(y)
      # Pass the adjoint flags by keyword; positionally they would land in
      # matmul's transpose_a/transpose_b slots, which differ for complex
      # dtypes.
      z = math_ops.matmul(inx, iny, adjoint_a=adjoint_a, adjoint_b=adjoint_b)
      loss = math_ops.reduce_sum(z)
      ((x_jacob_t, x_jacob_n),
       (y_jacob_t, y_jacob_n)) = gradient_checker.compute_gradient(
           [inx, iny], [x.shape, y.shape],
           loss, [1],
           x_init_value=[x, y],
           delta=delta)
      tol = 20 * delta
      self.assertAllClose(x_jacob_t, x_jacob_n, rtol=tol, atol=tol)
      self.assertAllClose(y_jacob_t, y_jacob_n, rtol=tol, atol=tol)

  # Tests the gradients of a batched matmul of x and y: x is a 3-D tensor of
  # shape [b, n, k], y is a 3-D tensor of shape [b, k, m], and the batched
  # matmul computes z of shape [b, n, m], where
  # z[i, :, :] = x[i, :, :] matmul y[i, :, :].
  def _compare(self, b, n, k, m, dtype, adjoint_a, adjoint_b):
    np.random.seed(42)
    x = np.random.normal(0, 1, b * n * k).astype(dtype).reshape([b, n, k])
    if dtype in (np.complex64, np.complex128):
      x.imag = np.random.normal(0, 1,
                                b * n * k).astype(dtype).reshape([b, n, k])
    y = np.random.normal(0, 1, b * k * m).astype(dtype).reshape([b, k, m])
    if dtype in (np.complex64, np.complex128):
      y.imag = np.random.normal(0, 1,
                                b * k * m).astype(dtype).reshape([b, k, m])
    self._checkGrad(x, y, adjoint_a, adjoint_b)


def _GetBatchMatmulGradientTest(dtype, adjoint_a, adjoint_b):

  def Test(self):
    self._compare(1, 2, 3, 5, dtype, adjoint_a, adjoint_b)
    self._compare(3, 4, 7, 10, dtype, adjoint_a, adjoint_b)

  return Test


if __name__ == "__main__":
  for dtype_ in [
      np.float16, np.float32, np.float64, np.complex64, np.complex128, np.int32
  ]:
    for adjoint_a_ in False, True:
      for adjoint_b_ in False, True:
        name = "%s_%s_%s" % (dtype_.__name__, adjoint_a_, adjoint_b_)
        for use_static_shape in True, False:
          setattr(BatchMatmulOpTest,
                  "testBatchMatmulOp_" + name + ("_%s" % use_static_shape),
                  _GetBatchMatmulOpTest(dtype_, adjoint_a_, adjoint_b_,
                                        use_static_shape))
        if dtype_ is not np.int32:
          setattr(BatchMatmulGradientTest, "testBatchMatmulGradient_" + name,
                  _GetBatchMatmulGradientTest(dtype_, adjoint_a_, adjoint_b_))
  test.main()
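
# Note (added for exposition): the __main__ block above registers one test
# method per (dtype, adjoint_a, adjoint_b) combination, with names built from
# the format strings in the loop, e.g.
#   BatchMatmulOpTest.testBatchMatmulOp_float32_False_False_True
#   BatchMatmulGradientTest.testBatchMatmulGradient_complex64_True_False
# Gradient tests are skipped for np.int32, since integer matmul has no
# meaningful gradient.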