# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.tf.BatchMatMul."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class BatchMatmulOpTest(test.TestCase):

  # Uses numpy to compute batch_matmul(x, y, adjoint_a, adjoint_b).
  def _npBatchMatmul(self, x, y, adjoint_a, adjoint_b):
    # output's shape depends on adj[0] and adj[1]
    d0 = x.shape[-2] if not adjoint_a else x.shape[-1]
    d2 = y.shape[-1] if not adjoint_b else y.shape[-2]
    batch_dims = x.shape[:-2]
    num = np.prod(batch_dims)
    z = np.empty(list(batch_dims) + [d0, d2], dtype=x.dtype)
    xr = x.reshape([num, x.shape[-2], x.shape[-1]])
    yr = y.reshape([num, y.shape[-2], y.shape[-1]])
    zr = z.reshape([num, z.shape[-2], z.shape[-1]])
    for i in range(num):
      a = np.matrix(xr[i, :, :])
      if adjoint_a:
        a = a.transpose().conj()
      b = np.matrix(yr[i, :, :])
      if adjoint_b:
        b = b.transpose().conj()
      zr[i, :, :] = a * b
    return z
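
  # Note: the reshape-and-loop above flattens all leading batch dimensions
  # into one axis so only rank-3 slices have to be handled. For the shapes
  # used in these tests, a vectorized NumPy equivalent (a sketch, not used
  # here so the reference stays independent of np.matmul's own batching)
  # would be:
  #   a = np.conj(np.swapaxes(x, -2, -1)) if adjoint_a else x
  #   b = np.conj(np.swapaxes(y, -2, -1)) if adjoint_b else y
  #   z = np.matmul(a, b)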

  # Test that _npBatchMatmul works.
  def testNpVersion(self):
    x = np.array([0., 1., 2., 3.]).reshape([1, 2, 2])
    y = np.array([1., 2., 3., 4.]).reshape([1, 2, 2])
    z0 = self._npBatchMatmul(x, y, False, False)
    z1 = np.array([3., 4., 11., 16.]).reshape([1, 2, 2])
    self.assertTrue(np.array_equal(z0, z1))

    x = np.array([1., (1j), (-1.), (-1j)]).reshape([1, 2, 2])
    y = x * complex(1, 1)  # rotate x by 45 degrees and scale by sqrt(2)
    z0 = self._npBatchMatmul(x, y, False, False)
    z1 = np.array([2., (2.j), -2., (-2.j)]).reshape([1, 2, 2])
    self.assertTrue(np.array_equal(z0, z1))

    z0 = self._npBatchMatmul(x, y, False, True)
    z1 = np.array([(2. - 2.j), (-2. + 2.j), (-2. + 2.j), (2. - 2.j)]).reshape(
        [1, 2, 2])
    self.assertTrue(np.array_equal(z0, z1))

    z0 = self._npBatchMatmul(x, y, True, False)
    z1 = np.array([(2. + 2.j), (-2. + 2.j), (2. - 2.j), (2. + 2.j)]).reshape(
        [1, 2, 2])
    self.assertTrue(np.array_equal(z0, z1))

  # Compares the TensorFlow matmul(x, y, adjoint_a, adjoint_b) against
  # _npBatchMatmul(x, y, adjoint_a, adjoint_b).
  def _compare(self, x_in, y_in, adjoint_a, adjoint_b, static_shape=True):
    x_t_shape = x_in.shape[:-2] + (x_in.shape[-1], x_in.shape[-2])
    y_t_shape = y_in.shape[:-2] + (y_in.shape[-1], y_in.shape[-2])
    x = x_in if not adjoint_a else x_in.reshape(x_t_shape)
    y = y_in if not adjoint_b else y_in.reshape(y_t_shape)
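    # int32 matmul is exact, so it is compared with zero tolerance; floating
    # and complex results get a tolerance proportional to the dtype's machine
    # epsilon to absorb rounding differences in the accumulation.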
    is_floating = x.dtype != np.int32
    tol = 100 * np.finfo(x.dtype).eps if is_floating else 0
    with self.test_session(use_gpu=is_floating) as sess:
      if static_shape:
        z0 = math_ops.matmul(x, y, adjoint_a=adjoint_a, adjoint_b=adjoint_b)
        z0_val = z0.eval()
      else:
        x_ph = array_ops.placeholder(x.dtype)
        y_ph = array_ops.placeholder(y.dtype)
        z0 = math_ops.matmul(
            x_ph, y_ph, adjoint_a=adjoint_a, adjoint_b=adjoint_b)
        z0_val = sess.run(z0, feed_dict={x_ph: x, y_ph: y})
      z1 = self._npBatchMatmul(x, y, adjoint_a, adjoint_b)
      self.assertAllClose(z0_val, z1, rtol=tol, atol=tol)

  def _rand(self, shape, dtype):
    vals = np.array(np.random.normal(-10, 10, np.prod(shape)), dtype=dtype)
    if dtype in (np.complex64, np.complex128):
      imag = np.array(np.random.normal(-10, 10, np.prod(shape)), dtype=dtype)
      vals += 1j * imag
    return vals.reshape(shape)

  def _testNonEmpty(self, dtype, adjoint_a, adjoint_b, use_static_shape):

    def compareNonEmpty(self, a_shape, b_shape):
      self._compare(
          self._rand(a_shape, dtype),
          self._rand(b_shape, dtype), adjoint_a, adjoint_b, use_static_shape)

    compareNonEmpty(self, [1, 2, 3], [1, 3, 5])
    compareNonEmpty(self, [1, 2, 3], [1, 3, 1])
    compareNonEmpty(self, [1, 1, 3], [1, 3, 5])
    compareNonEmpty(self, [1, 2, 3], [1, 3, 5])
    compareNonEmpty(self, [7, 1, 3], [7, 3, 5])
    compareNonEmpty(self, [7, 2, 3], [7, 3, 1])
    compareNonEmpty(self, [7, 2, 3], [7, 3, 5])
    compareNonEmpty(self, [10, 64, 75], [10, 75, 30])
    compareNonEmpty(self, [5, 7, 2, 3], [5, 7, 3, 5])

  def _testEmpty(self, dtype, adjoint_a, adjoint_b, use_static_shape):

    def compareEmpty(self, a_shape, b_shape):
      self._compare(
          np.zeros(a_shape).astype(dtype),
          np.zeros(b_shape).astype(dtype), adjoint_a, adjoint_b,
          use_static_shape)

    compareEmpty(self, [0, 3, 2], [0, 2, 4])
    compareEmpty(self, [3, 0, 2], [3, 2, 5])
    compareEmpty(self, [3, 3, 2], [3, 2, 0])


def _GetBatchMatmulOpTest(dtype, adjoint_a, adjoint_b, use_static_shape):

  def Test(self):
    np.random.seed(42)
    self._testNonEmpty(dtype, adjoint_a, adjoint_b, use_static_shape)
    self._testEmpty(dtype, adjoint_a, adjoint_b, use_static_shape)

  return Test


class BatchMatmulGradientTest(test.TestCase):

  # loss = sum(batch_matmul(x, y)). Verify dl/dx and dl/dy via the
  # gradient checker.
  def _checkGrad(self, x_in, y_in, adjoint_a, adjoint_b):
    x_t_shape = x_in.shape[:-2] + (x_in.shape[-1], x_in.shape[-2])
    y_t_shape = y_in.shape[:-2] + (y_in.shape[-1], y_in.shape[-2])
    x = x_in if not adjoint_a else x_in.reshape(x_t_shape)
    y = y_in if not adjoint_b else y_in.reshape(y_t_shape)
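    # A finite-difference step of roughly eps**(1/3) is a common heuristic for
    # numeric gradient checks: it balances truncation error against round-off
    # error for the dtype under test.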
    epsilon = np.finfo(x.dtype).eps
    delta = epsilon**(1.0 / 3.0)
    with self.test_session(use_gpu=True):
      inx = constant_op.constant(x)
      iny = constant_op.constant(y)
      z = math_ops.matmul(inx, iny, adjoint_a, adjoint_b)
      loss = math_ops.reduce_sum(z)
      ((x_jacob_t, x_jacob_n),
       (y_jacob_t, y_jacob_n)) = gradient_checker.compute_gradient(
           [inx, iny], [x.shape, y.shape],
           loss, [1],
           x_init_value=[x, y],
           delta=delta)
      tol = 20 * delta
      self.assertAllClose(x_jacob_t, x_jacob_n, rtol=tol, atol=tol)
      self.assertAllClose(y_jacob_t, y_jacob_n, rtol=tol, atol=tol)

  # Tests the batched matmul of x and y: x is a 3-D tensor of shape
  # [b, n, k] and y is a 3-D tensor of shape [b, k, m]. The batched matmul
  # computes z of shape [b, n, m], where z[i, :, :] = x[i, :, :] matmul
  # y[i, :, :].
  def _compare(self, b, n, k, m, dtype, adjoint_a, adjoint_b):
    np.random.seed(42)
    x = np.random.normal(0, 1, b * n * k).astype(dtype).reshape([b, n, k])
    if dtype in (np.complex64, np.complex128):
      x.imag = np.random.normal(0, 1,
                                b * n * k).astype(dtype).reshape([b, n, k])
    y = np.random.normal(0, 1, b * k * m).astype(dtype).reshape([b, k, m])
    if dtype in (np.complex64, np.complex128):
      y.imag = np.random.normal(0, 1,
                                b * k * m).astype(dtype).reshape([b, k, m])
    self._checkGrad(x, y, adjoint_a, adjoint_b)


def _GetBatchMatmulGradientTest(dtype, adjoint_a, adjoint_b):

  def Test(self):
    self._compare(1, 2, 3, 5, dtype, adjoint_a, adjoint_b)
    self._compare(3, 4, 7, 10, dtype, adjoint_a, adjoint_b)

  return Test


if __name__ == "__main__":
  for dtype_ in [
      np.float16, np.float32, np.float64, np.complex64, np.complex128, np.int32
  ]:
    for adjoint_a_ in False, True:
      for adjoint_b_ in False, True:
        name = "%s_%s_%s" % (dtype_.__name__, adjoint_a_, adjoint_b_)
        for use_static_shape in True, False:
          setattr(BatchMatmulOpTest,
                  "testBatchMatmulOp_" + name + ("_%s" % use_static_shape),
                  _GetBatchMatmulOpTest(dtype_, adjoint_a_, adjoint_b_,
                                        use_static_shape))
        if dtype_ is not np.int32:
          setattr(BatchMatmulGradientTest, "testBatchMatmulGradient_" + name,
                  _GetBatchMatmulGradientTest(dtype_, adjoint_a_, adjoint_b_))
  test.main()