10cf9ed3a719c0782695154d5a0bca260001cec15A. Unique TensorFlower# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
29c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur#
39c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# Licensed under the Apache License, Version 2.0 (the "License");
49c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# you may not use this file except in compliance with the License.
59c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# You may obtain a copy of the License at
69c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur#
79c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur#     http://www.apache.org/licenses/LICENSE-2.0
89c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur#
99c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# Unless required by applicable law or agreed to in writing, software
109c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# distributed under the License is distributed on an "AS IS" BASIS,
119c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
129c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# See the License for the specific language governing permissions and
139c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# limitations under the License.
149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# ==============================================================================
15f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur"""Tests for tensorflow.ops.math_ops.matmul."""
165866e065bc95c1d7de8a27413b368016941889a6Justine Tunney
17f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevanfrom __future__ import absolute_import
18f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevanfrom __future__ import division
1961d3a958d6d83cb6037490d933b47621cc4009ccVijay Vasudevanfrom __future__ import print_function
20f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevan
210386a01ad3beb28364599d82199be1c0837b3fa9Dandelion Manéimport operator
2232c11fd917f82619f76273f6b83d7e21fb68c173Dandelion Manéimport numpy as np
23f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
245866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.framework import constant_op
258525dbdc02b0c45c38ba16cfba26856eece2b851Pete Wardenfrom tensorflow.python.framework import ops
264d55cb5c55fcc85009171b6a4657cbd966fd85e5A. Unique TensorFlowerfrom tensorflow.python.framework import test_util
275866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.ops import array_ops
285866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.ops import gradient_checker
295866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.ops import math_ops
305866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.ops import random_ops
315866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.ops import variables
32a0db2664a4051b98ab90db94bfe86b17379dbef4A. Unique TensorFlowerfrom tensorflow.python.platform import test as test_lib
338525dbdc02b0c45c38ba16cfba26856eece2b851Pete Warden
343e3306ef0009b5b21050139f9b8e5f4868c4c0c7Yangzihao Wang# TODO(yangzihao): Currently matmul autotuning is disabled by default. Use
353e3306ef0009b5b21050139f9b8e5f4868c4c0c7Yangzihao Wang# os.environ["TF_MATMUL_AUTOTUNE_ENABLE"] = "1" to enable it.
363e3306ef0009b5b21050139f9b8e5f4868c4c0c7Yangzihao Wang
37f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
def _AddTest(test, op_name, testcase_name, fn):
  """Attaches `fn` to class `test` as the method test_<op_name>_<testcase_name>.

  Args:
    test: the class that receives the new method.
    op_name: middle component of the generated method name.
    testcase_name: last component of the generated method name.
    fn: the function to install.

  Raises:
    RuntimeError: if `test` already has an attribute with the generated name.
  """
  method_name = "_".join(("test", op_name, testcase_name))
  already_defined = hasattr(test, method_name)
  if already_defined:
    raise RuntimeError("Test %s defined more than once" % method_name)
  setattr(test, method_name, fn)
43f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
444d55cb5c55fcc85009171b6a4657cbd966fd85e5A. Unique TensorFlower
def _GetTransposedMatrices(x, x_name, kwargs):
  """Returns the operand to feed to matmul so that matmul recovers `x`.

  If matmul is asked to transpose (or adjoint) operand `x_name`, the value fed
  to it is pre-transposed (or pre-conjugate-transposed). Matmul's own
  transformation then undoes that, so the product equals the one computed
  from the untransposed `x`.

  Args:
    x: numpy array, the untransposed operand.
    x_name: operand name ("a" or "b"); selects the `transpose_<x_name>` and
      `adjoint_<x_name>` entries of `kwargs`.
    kwargs: dict of matmul keyword arguments. Flags are evaluated by
      truthiness (so numpy booleans work), and missing flags count as False.

  Returns:
    x.T when transpose is requested, np.conj(x.T) when adjoint is requested,
    otherwise `x` itself. Transpose is checked first.
  """
  if kwargs.get("transpose_" + x_name, False):
    return x.T
  if kwargs.get("adjoint_" + x_name, False):
    return np.conj(x.T)
  return x
523fe917bfb116ecebdd73de632c18cef6a1a04d9cA. Unique TensorFlower
533fe917bfb116ecebdd73de632c18cef6a1a04d9cA. Unique TensorFlower
class MatMulTest(test_lib.TestCase):
  """Forward matmul tests; test methods are attached at runtime via _AddTest."""
564d55cb5c55fcc85009171b6a4657cbd966fd85e5A. Unique TensorFlower
574d55cb5c55fcc85009171b6a4657cbd966fd85e5A. Unique TensorFlower
def _GetMatMulTest(a_np_, b_np_, use_static_shape_, **kwargs_):
  """Returns a test method that compares math_ops.matmul against numpy.

  Args:
    a_np_: 2-D numpy array, the logical left operand.
    b_np_: 2-D numpy array, the logical right operand.
    use_static_shape_: if True, the operands are embedded as constants;
      otherwise they are fed through placeholders created without a shape.
    **kwargs_: the transpose_a/adjoint_a/transpose_b/adjoint_b flags, forwarded
      both to _GetTransposedMatrices and to math_ops.matmul.

  Returns:
    A function `Test(self)` intended to be attached to a test case class.
  """

  def Test(self):
    # Reference result from the logical (untransposed) operands. np.dot is used
    # instead of the deprecated np.matrix product; for 2-D inputs they agree.
    np_val = np.dot(a_np_, b_np_)

    use_gpu = True
    # `a_np_.dtype` is a numpy dtype instance, so it must be compared with
    # `==`. The identity test `a_np_.dtype is np.float16` is never true, which
    # silently left fp16 tests on the GPU path.
    if a_np_.dtype == np.float16 and (
        not test_util.CudaSupportsHalfMatMulAndConv()):
      use_gpu = False
      print("Built without fp16 matmul support for Cuda, running test on CPU.")

    # Transpose and possibly conjugate a_np_ and b_np_ according to the
    # attributes such that tf.matmul(effective_a_np, effective_b_np, **kwargs)
    # results in a valid matrix multiplication and produces the same result as
    # np.dot(a_np_, b_np_).
    effective_a_np = _GetTransposedMatrices(a_np_, "a", kwargs_)
    effective_b_np = _GetTransposedMatrices(b_np_, "b", kwargs_)
    with self.test_session(use_gpu=use_gpu) as sess:
      if use_static_shape_:
        a = constant_op.constant(effective_a_np)
        b = constant_op.constant(effective_b_np)
        res = math_ops.matmul(a, b, **kwargs_)
        tf_val = res.eval()
      else:
        a = array_ops.placeholder(a_np_.dtype)
        b = array_ops.placeholder(b_np_.dtype)
        res = math_ops.matmul(a, b, **kwargs_)
        tf_val = sess.run(res, feed_dict={a: effective_a_np, b: effective_b_np})

    self.assertAllCloseAccordingToType(
        tf_val,
        np_val,
        float_rtol=2e-5,
        float_atol=2e-5,
        half_rtol=0.2,
        half_atol=0.2)

  return Test
96011402d8987e753acd54c6251a7edd2e2d8155baA. Unique TensorFlower
97f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
class MatMulGradientTest(test_lib.TestCase):
  """Matmul gradient tests; test methods are attached at runtime via _AddTest."""
100f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
1013fe917bfb116ecebdd73de632c18cef6a1a04d9cA. Unique TensorFlower
def _GetMatMulGradientTest(a_np_, b_np_, use_static_shape_, **kwargs_):
  """Returns a test method that checks matmul gradients numerically.

  Args:
    a_np_: numpy array, the logical left operand.
    b_np_: numpy array, the logical right operand.
    use_static_shape_: only the static-shape variant is checked; otherwise
      the generated test skips itself.
    **kwargs_: the transpose/adjoint flags forwarded to math_ops.matmul.

  Returns:
    A function `Test(self)` intended to be attached to a test case class.
  """

  def Test(self):
    # The check builds constants from the operands, so only the static-shape
    # variant runs; int32 and float16 operands are skipped as well.
    feasible = use_static_shape_ and a_np_.dtype not in (np.int32, np.float16)
    if not feasible:
      self.skipTest("Skipping infeasible gradient test.")

    # Feed operands pre-transposed/conjugated so that matmul with **kwargs_
    # reproduces the product of the untransposed a_np_ and b_np_.
    lhs = _GetTransposedMatrices(a_np_, "a", kwargs_)
    rhs = _GetTransposedMatrices(b_np_, "b", kwargs_)

    # Finite-difference step of eps**(1/3); tolerance scales with the step.
    step = np.finfo(a_np_.dtype).eps**(1.0 / 3.0)
    tolerance = 20 * step
    output_shape = [a_np_.shape[0], b_np_.shape[1]]
    with self.test_session(use_gpu=True):
      a = constant_op.constant(lhs)
      b = constant_op.constant(rhs)
      product = math_ops.matmul(a, b, **kwargs_)
      for operand, operand_value in ((a, lhs), (b, rhs)):
        theoretical, numerical = gradient_checker.compute_gradient(
            operand,
            operand_value.shape,
            product,
            output_shape,
            x_init_value=operand_value,
            delta=step)
        self.assertAllClose(
            theoretical, numerical, rtol=tolerance, atol=tolerance)

  return Test
132f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
133f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
class MatMulStatsTest(test_lib.TestCase):
  """Checks the "flops" statistic reported for the MatMul op."""

  def _MatMulFlops(self, a_shape, b_shape, **matmul_kwargs):
    """Builds a matmul of two random variables in a fresh graph.

    Args:
      a_shape: shape of the left operand variable.
      b_shape: shape of the right operand variable.
      **matmul_kwargs: forwarded to math_ops.matmul (e.g. transpose_a=True).

    Returns:
      The "flops" value from ops.get_stats_for_node_def for the op that
      produced the matmul result.
    """
    g = ops.Graph()
    with g.as_default():
      a = variables.Variable(random_ops.random_normal(a_shape))
      b = variables.Variable(random_ops.random_normal(b_shape))
      product = math_ops.matmul(a, b, **matmul_kwargs)
      # Inspect the producing op directly. Scanning for an op *named* "MatMul"
      # passed vacuously when no such op existed.
      self.assertEqual("MatMul", product.op.type)
      return ops.get_stats_for_node_def(g, product.op.node_def, "flops").value

  def testSimpleStatistics(self):
    # 7200 == 2 * 25 * 16 * 9.
    self.assertEqual(7200, self._MatMulFlops([25, 16], [16, 9]))

  def testTransposedStatistics(self):
    # Same logical product as above, with the left operand stored transposed.
    self.assertEqual(
        7200, self._MatMulFlops([16, 25], [16, 9], transpose_a=True))
1578525dbdc02b0c45c38ba16cfba26856eece2b851Pete Warden
1588525dbdc02b0c45c38ba16cfba26856eece2b851Pete Warden
# The @ operator (and operator.matmul) is supported since python 3.5.
infix_matmul = getattr(operator, "matmul", None)
if infix_matmul is None:

  # For earlier versions of python, emulate regular behavior.
  # Useful to build and test for 3.5+ on earlier versions.
  def infix_matmul(x, y):  # pylint: disable=invalid-name,function-redefined
    """Emulates `x @ y`: tries x's __matmul__, then y's __rmatmul__."""

    def _CallDunder(operand, method_name, other):
      # A missing method (or an AttributeError raised inside it) is treated as
      # NotImplemented.
      try:
        return getattr(type(operand), method_name)(operand, other)
      except AttributeError:
        return NotImplemented

    result = _CallDunder(x, "__matmul__", y)
    if result is NotImplemented and type(x) is not type(y):
      result = _CallDunder(y, "__rmatmul__", x)
    if result is NotImplemented:
      raise TypeError("unsupported operand type(s) for @: '{}' and '{}'"
                      .format(type(x).__name__, type(y).__name__))
    return result
1800386a01ad3beb28364599d82199be1c0837b3fa9Dandelion Mané
1810386a01ad3beb28364599d82199be1c0837b3fa9Dandelion Mané
class MatMulInfixOperatorTest(test_lib.TestCase):
  """Tests the infix matmul operator (`infix_matmul`) on tensors."""

  def _RowAndMatrix(self):
    """Returns a [1, 3] tensor and a compatible [3, 2] tensor."""
    row = ops.convert_to_tensor([[10.0, 20.0, 30.0]])
    matrix = ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0], [80.0, 90.0]])
    return row, matrix

  def testMismatchedShape(self):
    # A rank-1 left operand must be rejected.
    with self.assertRaisesWithPredicateMatch(ValueError,
                                             lambda e: "Shape must" in str(e)):
      vector = ops.convert_to_tensor([10.0, 20.0, 30.0])
      square = ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]])
      infix_matmul(vector, square)

  def testMismatchedDimensions(self):
    # Inner dimensions 3 and 2 do not agree.
    with self.assertRaisesWithPredicateMatch(
        ValueError, lambda e: "Dimensions must" in str(e)):
      row = ops.convert_to_tensor([[10.0, 20.0, 30.0]])
      square = ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]])
      infix_matmul(row, square)

  def testInfixMatmulIsTfMatmul(self):
    row, matrix = self._RowAndMatrix()
    self.assertEqual(infix_matmul(row, matrix).op.type, "MatMul")

  def testInfixMatmulDoesDotProduct(self):
    row, matrix = self._RowAndMatrix()
    via_operator = infix_matmul(row, matrix)
    via_function = math_ops.matmul(row, matrix)
    with self.test_session():
      self.assertAllEqual(via_operator.eval(), via_function.eval())
2110386a01ad3beb28364599d82199be1c0837b3fa9Dandelion Mané
2120386a01ad3beb28364599d82199be1c0837b3fa9Dandelion Mané
def _GetRandomMatrix(dtype, rows, cols):
  """Returns a random [rows, cols] matrix of numpy type `dtype`.

  Entries are drawn from np.random.normal(-5, 5, ...). For complex dtypes, the
  imaginary part is drawn independently from the same distribution.
  """
  matrix = np.random.normal(-5, 5, rows * cols).astype(dtype).reshape(
      [rows, cols])
  if dtype in (np.complex64, np.complex128):
    # Assign the real-valued draw directly. Casting it to a complex dtype first
    # would make numpy discard an (all-zero) imaginary part with a
    # ComplexWarning.
    matrix.imag = np.random.normal(-5, 5, rows * cols).reshape([rows, cols])
  return matrix


if __name__ == "__main__":
  sizes = [1, 3, 5]
  trans_options = [[False, False], [True, False], [False, True]]
  for use_static_shape in [False, True]:
    for dtype in (np.int32, np.float16, np.float32, np.float64, np.complex64,
                  np.complex128):
      if not use_static_shape and dtype == np.int32:
        # TODO(rmlarsen): Re-enable this test when we have fixed the underlying
        # bug in Windows (b/35935459).
        continue
      for m in sizes:
        for n in sizes:
          for k in sizes:
            # Construct compatible random matrices a_np of size [m, k] and b_np
            # of size [k, n].
            a_np = _GetRandomMatrix(dtype, m, k)
            b_np = _GetRandomMatrix(dtype, k, n)
            for adjoint_a, transpose_a in trans_options:
              for adjoint_b, transpose_b in trans_options:
                name = "%s_%s_%s_%s_%s_%s_%s_%s_%s" % (
                    use_static_shape, dtype.__name__, m, n, k, adjoint_a,
                    transpose_a, adjoint_b, transpose_b)
                matmul_kwargs = dict(
                    adjoint_a=adjoint_a,
                    transpose_a=transpose_a,
                    adjoint_b=adjoint_b,
                    transpose_b=transpose_b)
                _AddTest(MatMulTest, "MatMulTest", name,
                         _GetMatMulTest(a_np, b_np, use_static_shape,
                                        **matmul_kwargs))
                _AddTest(MatMulGradientTest, "MatMulGradientTest", name,
                         _GetMatMulGradientTest(a_np, b_np, use_static_shape,
                                                **matmul_kwargs))

  test_lib.main()
261