# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# ============================================================================== 15f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur"""Tests for tensorflow.ops.math_ops.matmul.""" 165866e065bc95c1d7de8a27413b368016941889a6Justine Tunney 17f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevanfrom __future__ import absolute_import 18f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevanfrom __future__ import division 1961d3a958d6d83cb6037490d933b47621cc4009ccVijay Vasudevanfrom __future__ import print_function 20f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevan 210386a01ad3beb28364599d82199be1c0837b3fa9Dandelion Manéimport operator 2232c11fd917f82619f76273f6b83d7e21fb68c173Dandelion Manéimport numpy as np 23f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 245866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.framework import constant_op 258525dbdc02b0c45c38ba16cfba26856eece2b851Pete Wardenfrom tensorflow.python.framework import ops 264d55cb5c55fcc85009171b6a4657cbd966fd85e5A. Unique TensorFlowerfrom tensorflow.python.framework import test_util 275866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.ops import array_ops 285866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.ops import gradient_checker 295866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.ops import math_ops 305866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.ops import random_ops 315866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.ops import variables 32a0db2664a4051b98ab90db94bfe86b17379dbef4A. Unique TensorFlowerfrom tensorflow.python.platform import test as test_lib 338525dbdc02b0c45c38ba16cfba26856eece2b851Pete Warden 343e3306ef0009b5b21050139f9b8e5f4868c4c0c7Yangzihao Wang# TODO(yangzihao): Currently matmul autotuning is disabled by default. 
# Use os.environ["TF_MATMUL_AUTOTUNE_ENABLE"] = "1" to enable it.


def _AddTest(test, op_name, testcase_name, fn):
  """Attaches `fn` to the class `test` as a uniquely named test method.

  The generated method name is "test_<op_name>_<testcase_name>".

  Args:
    test: The test class that receives the new method.
    op_name: Name of the op under test; second component of the method name.
    testcase_name: Suffix identifying this particular test case.
    fn: The test function; it takes `self` as its only argument.

  Raises:
    RuntimeError: If `test` already has an attribute with the generated name.
  """
  test_name = "_".join(["test", op_name, testcase_name])
  if hasattr(test, test_name):
    raise RuntimeError("Test %s defined more than once" % test_name)
  setattr(test, test_name, fn)


def _GetTransposedMatrices(x, x_name, kwargs):
  """Returns `x` pre-transformed so that matmul's flags for `x` undo it.

  Feeding the returned array to math_ops.matmul together with `kwargs`
  recovers `x` inside the op, so the TF result can be compared against the
  plain numpy product of the untransformed operands.

  Args:
    x: A 2-D numpy array.
    x_name: "a" or "b"; selects the `transpose_<x_name>` and
      `adjoint_<x_name>` entries of `kwargs`.
    kwargs: The keyword arguments passed to math_ops.matmul. A missing flag is
      treated as False.

  Returns:
    `x.T` if transpose is requested, else the conjugate transpose of `x` if
    adjoint is requested, else `x` unchanged.
  """
  # Use truthiness rather than `is True`: an identity check silently ignores
  # truthy flags such as np.True_ or 1, which matmul would still honor.
  if kwargs.get("transpose_" + x_name, False):
    return x.T
  if kwargs.get("adjoint_" + x_name, False):
    return np.conj(x.T)
  return x


class MatMulTest(test_lib.TestCase):
  pass  # Filled in below: test methods are attached via _AddTest.
def _GetMatMulTest(a_np_, b_np_, use_static_shape_, **kwargs_):
  """Returns a test method comparing math_ops.matmul with numpy's product.

  Args:
    a_np_: Left operand as a numpy array; `np.matmul(a_np_, b_np_)` must be
      well defined (the __main__ block builds it with shape [m, k]).
    b_np_: Right operand as a numpy array ([k, n] in the __main__ block).
    use_static_shape_: If True, the operands enter the graph through
      constant_op.constant; otherwise through array_ops.placeholder created
      with only a dtype and fed at run time.
    **kwargs_: The transpose_a/adjoint_a/transpose_b/adjoint_b flags, passed
      both to _GetTransposedMatrices and to math_ops.matmul.

  Returns:
    A function of `self` intended to be attached to MatMulTest via _AddTest.
  """

  def Test(self):
    # Reference result. np.matmul replaces the discouraged np.matrix class.
    np_val = np.matmul(a_np_, b_np_)

    use_gpu = True
    # `dtype is np.float16` compares a dtype instance with a scalar type object
    # by identity and is always False; `==` is the intended dtype comparison.
    if a_np_.dtype == np.float16 and (
        not test_util.CudaSupportsHalfMatMulAndConv()):
      use_gpu = False
      print("Built without fp16 matmul support for Cuda, running test on CPU.")

    # Transpose and possibly conjugate a_np_ and b_np_ according to the
    # attributes such that tf.matmul(effective_a_np, effective_b_np, **kwargs)
    # results in a valid matrix multiplication and produces the same result as
    # np.matmul(a_np_, b_np_).
    effective_a_np = _GetTransposedMatrices(a_np_, "a", kwargs_)
    effective_b_np = _GetTransposedMatrices(b_np_, "b", kwargs_)
    with self.test_session(use_gpu=use_gpu) as sess:
      if use_static_shape_:
        a = constant_op.constant(effective_a_np)
        b = constant_op.constant(effective_b_np)
        res = math_ops.matmul(a, b, **kwargs_)
        tf_val = res.eval()
      else:
        a = array_ops.placeholder(a_np_.dtype)
        b = array_ops.placeholder(b_np_.dtype)
        res = math_ops.matmul(a, b, **kwargs_)
        tf_val = sess.run(res, feed_dict={a: effective_a_np, b: effective_b_np})

    self.assertAllCloseAccordingToType(
        tf_val,
        np_val,
        float_rtol=2e-5,
        float_atol=2e-5,
        half_rtol=0.2,
        half_atol=0.2)

  return Test
class MatMulGradientTest(test_lib.TestCase):
  pass  # Will be filled in below.


def _GetMatMulGradientTest(a_np_, b_np_, use_static_shape_, **kwargs_):
  """Returns a test method checking matmul gradients against finite differences.

  Args:
    a_np_: Left operand as a numpy array (shape [m, k] in the __main__ block).
    b_np_: Right operand as a numpy array (shape [k, n] in the __main__ block).
    use_static_shape_: The generated test is skipped unless this is True.
    **kwargs_: The transpose_a/adjoint_a/transpose_b/adjoint_b flags, passed
      both to _GetTransposedMatrices and to math_ops.matmul.

  Returns:
    A function of `self` intended to be attached to MatMulGradientTest.
  """

  def Test(self):
    # Gradient checks only run for statically shaped, non-int32, non-float16
    # operands.
    skip = not use_static_shape_ or a_np_.dtype in (np.int32, np.float16)
    if skip:
      self.skipTest("Skipping infeasible gradient test.")

    # Pre-transform the operands so that the transpose/adjoint flags applied
    # inside matmul give back the plain product of a_np_ and b_np_.
    lhs_np = _GetTransposedMatrices(a_np_, "a", kwargs_)
    rhs_np = _GetTransposedMatrices(b_np_, "b", kwargs_)

    # Finite-difference step and tolerance, both derived from the machine
    # epsilon of the operand dtype.
    step = np.finfo(a_np_.dtype).eps**(1.0 / 3.0)
    tolerance = 20 * step
    with self.test_session(use_gpu=True):
      lhs = constant_op.constant(lhs_np)
      rhs = constant_op.constant(rhs_np)
      product = math_ops.matmul(lhs, rhs, **kwargs_)
      for tensor, init_value in ((lhs, lhs_np), (rhs, rhs_np)):
        theoretical, numerical = gradient_checker.compute_gradient(
            tensor,
            init_value.shape,
            product,
            [a_np_.shape[0], b_np_.shape[1]],
            x_init_value=init_value,
            delta=step)
        self.assertAllClose(
            theoretical, numerical, rtol=tolerance, atol=tolerance)

  return Test
class MatMulStatsTest(test_lib.TestCase):
  """Checks the "flops" statistic reported for MatMul nodes."""

  def testSimpleStatistics(self):
    g = ops.Graph()
    with g.as_default():
      a = variables.Variable(random_ops.random_normal([25, 16]))
      b = variables.Variable(random_ops.random_normal([16, 9]))
      # Inspect the op matmul actually created instead of scanning the graph
      # by name, so the assertion cannot be skipped vacuously.
      matmul_op = math_ops.matmul(a, b).op
      self.assertEqual("MatMul", matmul_op.type)
      flops = ops.get_stats_for_node_def(g, matmul_op.node_def, "flops").value
      # 7200 == 2 * 25 * 16 * 9 (2 * m * k * n).
      self.assertEqual(7200, flops)

  def testTransposedStatistics(self):
    g = ops.Graph()
    with g.as_default():
      a = variables.Variable(random_ops.random_normal([16, 25]))
      b = variables.Variable(random_ops.random_normal([16, 9]))
      # transpose_a turns the [16, 25] operand into [25, 16], so the flop count
      # must match the untransposed case.
      matmul_op = math_ops.matmul(a, b, transpose_a=True).op
      self.assertEqual("MatMul", matmul_op.type)
      flops = ops.get_stats_for_node_def(g, matmul_op.node_def, "flops").value
      self.assertEqual(7200, flops)


try:
  # @ operator supported since python 3.5.
  infix_matmul = operator.matmul
except AttributeError:

  # For earlier versions of python, emulate regular behavior.
  # Useful to build and test for 3.5+ on earlier versions.
  def infix_matmul(x, y):  # pylint: disable=invalid-name
    """Emulates `x @ y` where the operator module has no `matmul`.

    Tries `type(x).__matmul__` first and, if that is unavailable or returns
    NotImplemented and the operand types differ, `type(y).__rmatmul__`.

    Raises:
      TypeError: If neither operand implements the operation.
    """
    try:
      r = type(x).__matmul__(x, y)
    except AttributeError:
      r = NotImplemented
    if r is NotImplemented and type(x) is not type(y):
      try:
        r = type(y).__rmatmul__(y, x)
      except AttributeError:
        r = NotImplemented
    if r is NotImplemented:
      raise TypeError("unsupported operand type(s) for @: '{}' and '{}'"
                      .format(type(x).__name__, type(y).__name__))
    return r
class MatMulInfixOperatorTest(test_lib.TestCase):
  """Exercises `infix_matmul` (the `@` operator) on graph tensors."""

  def _RowAndMatrix(self):
    """Returns a [1, 3] tensor and a compatible [3, 2] tensor."""
    row = ops.convert_to_tensor([[10.0, 20.0, 30.0]])
    matrix = ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0], [80.0, 90.0]])
    return row, matrix

  def testMismatchedShape(self):

    def _MentionsShape(e):
      return "Shape must" in str(e)

    # A rank-1 left operand is expected to be rejected.
    with self.assertRaisesWithPredicateMatch(ValueError, _MentionsShape):
      vector = ops.convert_to_tensor([10.0, 20.0, 30.0])
      square = ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]])
      infix_matmul(vector, square)

  def testMismatchedDimensions(self):

    def _MentionsDimensions(e):
      return "Dimensions must" in str(e)

    # Inner dimensions 3 and 2 do not agree.
    with self.assertRaisesWithPredicateMatch(ValueError, _MentionsDimensions):
      row = ops.convert_to_tensor([[10.0, 20.0, 30.0]])
      square = ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]])
      infix_matmul(row, square)

  def testInfixMatmulIsTfMatmul(self):
    row, matrix = self._RowAndMatrix()
    self.assertEqual(infix_matmul(row, matrix).op.type, "MatMul")

  def testInfixMatmulDoesDotProduct(self):
    row, matrix = self._RowAndMatrix()
    via_operator = infix_matmul(row, matrix)
    via_function = math_ops.matmul(row, matrix)
    with self.test_session():
      self.assertAllEqual(via_operator.eval(), via_function.eval())


if __name__ == "__main__":
  sizes = [1, 3, 5]
  # Each entry is an [adjoint, transpose] pair applied to one operand.
  trans_options = [[False, False], [True, False], [False, True]]
  for use_static_shape in [False, True]:
    for dtype in (np.int32, np.float16, np.float32, np.float64, np.complex64,
                  np.complex128):
      # TODO(rmlarsen): Re-enable this test when we have fixed the underlying
      # bug in Windows (b/35935459).
      if dtype == np.int32 and not use_static_shape:
        continue
      is_complex = dtype in (np.complex64, np.complex128)
      for m in sizes:
        for n in sizes:
          for k in sizes:
            # Construct compatible random matrices a_np of size [m, k] and
            # b_np of size [k, n]; complex dtypes also get a random
            # imaginary part.
            a_np = np.random.normal(-5, 5, m * k).astype(dtype).reshape([m, k])
            if is_complex:
              a_np.imag = np.random.normal(
                  -5, 5, m * k).astype(dtype).reshape([m, k])
            b_np = np.random.normal(-5, 5, k * n).astype(dtype).reshape([k, n])
            if is_complex:
              b_np.imag = np.random.normal(
                  -5, 5, k * n).astype(dtype).reshape([k, n])
            for adjoint_a, transpose_a in trans_options:
              for adjoint_b, transpose_b in trans_options:
                name = "%s_%s_%s_%s_%s_%s_%s_%s_%s" % (
                    use_static_shape, dtype.__name__, m, n, k, adjoint_a,
                    transpose_a, adjoint_b, transpose_b)
                matmul_kwargs = dict(
                    adjoint_a=adjoint_a,
                    transpose_a=transpose_a,
                    adjoint_b=adjoint_b,
                    transpose_b=transpose_b)
                _AddTest(MatMulTest, "MatMulTest", name,
                         _GetMatMulTest(a_np, b_np, use_static_shape,
                                        **matmul_kwargs))
                _AddTest(MatMulGradientTest, "MatMulGradientTest", name,
                         _GetMatMulGradientTest(a_np, b_np, use_static_shape,
                                                **matmul_kwargs))

  test_lib.main()