10cf9ed3a719c0782695154d5a0bca260001cec15A. Unique TensorFlower# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 29c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# 39c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# Licensed under the Apache License, Version 2.0 (the "License"); 49c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# you may not use this file except in compliance with the License. 59c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# You may obtain a copy of the License at 69c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# 79c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# http://www.apache.org/licenses/LICENSE-2.0 89c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# 99c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# Unless required by applicable law or agreed to in writing, software 109c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# distributed under the License is distributed on an "AS IS" BASIS, 119c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 129c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# See the License for the specific language governing permissions and 139c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# limitations under the License. 149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# ============================================================================== 15f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur"""Tests for tensorflow.ops.argmax_op.""" 16f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevanfrom __future__ import absolute_import 17f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevanfrom __future__ import division 18f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevanfrom __future__ import print_function 19f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevan 20f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurimport numpy as np 21f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 22cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Fengfrom tensorflow.python.framework import dtypes 235866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.ops import math_ops 245866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.platform import test 255866e065bc95c1d7de8a27413b368016941889a6Justine Tunney 265866e065bc95c1d7de8a27413b368016941889a6Justine Tunney 275866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyclass ArgMaxTest(test.TestCase): 28f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 295866e065bc95c1d7de8a27413b368016941889a6Justine Tunney def _testArg(self, 305866e065bc95c1d7de8a27413b368016941889a6Justine Tunney method, 315866e065bc95c1d7de8a27413b368016941889a6Justine Tunney x, 32f6a2bcbeed43744aede1d312dacfa252ce63eaf2John Lawson axis, 335866e065bc95c1d7de8a27413b368016941889a6Justine Tunney expected_values, 345866e065bc95c1d7de8a27413b368016941889a6Justine Tunney use_gpu=False, 355866e065bc95c1d7de8a27413b368016941889a6Justine Tunney expected_err_re=None): 36f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur with self.test_session(use_gpu=use_gpu): 37f6a2bcbeed43744aede1d312dacfa252ce63eaf2John Lawson ans = method(x, axis=axis) 38f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur if expected_err_re is None: 39f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur tf_ans = ans.eval() 40cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng # Defaults to int64 output. 41cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng self.assertEqual(np.int64, tf_ans.dtype) 42f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur self.assertAllEqual(tf_ans, expected_values) 43f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur self.assertShapeEqual(expected_values, ans) 44f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur else: 45f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur with self.assertRaisesOpError(expected_err_re): 46f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur ans.eval() 47f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 485866e065bc95c1d7de8a27413b368016941889a6Justine Tunney def _testBothArg(self, 495866e065bc95c1d7de8a27413b368016941889a6Justine Tunney method, 505866e065bc95c1d7de8a27413b368016941889a6Justine Tunney x, 51f6a2bcbeed43744aede1d312dacfa252ce63eaf2John Lawson axis, 525866e065bc95c1d7de8a27413b368016941889a6Justine Tunney expected_values, 535866e065bc95c1d7de8a27413b368016941889a6Justine Tunney expected_err_re=None): 54f6a2bcbeed43744aede1d312dacfa252ce63eaf2John Lawson self._testArg(method, x, axis, expected_values, True, expected_err_re) 55f6a2bcbeed43744aede1d312dacfa252ce63eaf2John Lawson self._testArg(method, x, axis, expected_values, False, expected_err_re) 56f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 57f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur def _testBasic(self, dtype): 585866e065bc95c1d7de8a27413b368016941889a6Justine Tunney x = np.asarray(100 * np.random.randn(200), dtype=dtype) 59f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 60f6a2bcbeed43744aede1d312dacfa252ce63eaf2John Lawson # Check that argmin and argmax match numpy along the primary axis 615866e065bc95c1d7de8a27413b368016941889a6Justine Tunney self._testBothArg(math_ops.argmax, x, 0, x.argmax()) 625866e065bc95c1d7de8a27413b368016941889a6Justine Tunney self._testBothArg(math_ops.argmin, x, 0, x.argmin()) 63f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 64f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur def _testDim(self, dtype): 655866e065bc95c1d7de8a27413b368016941889a6Justine Tunney x = np.asarray(100 * np.random.randn(3, 2, 4, 5, 6), dtype=dtype) 66f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 67f6a2bcbeed43744aede1d312dacfa252ce63eaf2John Lawson # Check that argmin and argmax match numpy along all axes 68f6a2bcbeed43744aede1d312dacfa252ce63eaf2John Lawson for axis in range(-5, 5): 69f6a2bcbeed43744aede1d312dacfa252ce63eaf2John Lawson self._testBothArg(math_ops.argmax, x, axis, x.argmax(axis)) 70f6a2bcbeed43744aede1d312dacfa252ce63eaf2John Lawson self._testBothArg(math_ops.argmin, x, axis, x.argmin(axis)) 71f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 72f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur def testFloat(self): 73f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur self._testBasic(np.float32) 74f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur self._testDim(np.float32) 75f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 76cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng def testFloatInt32Output(self): 77cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng x = np.asarray(100 * np.random.randn(200), dtype=np.float32) 78cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng expected_values = x.argmax() 79cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng with self.test_session(use_gpu=True): 80f6a2bcbeed43744aede1d312dacfa252ce63eaf2John Lawson ans = math_ops.argmax(x, axis=0, output_type=dtypes.int32) 81cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng tf_ans = ans.eval() 82cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng self.assertEqual(np.int32, tf_ans.dtype) 83cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng # The values are equal when comparing int32 to int64 because 84cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng # the values don't have a range that exceeds 32-bit integers. 85cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng self.assertAllEqual(tf_ans, expected_values) 86cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng expected_values = x.argmin() 87cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng with self.test_session(use_gpu=True): 88f6a2bcbeed43744aede1d312dacfa252ce63eaf2John Lawson ans = math_ops.argmin(x, axis=0, output_type=dtypes.int32) 89cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng tf_ans = ans.eval() 90cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng self.assertEqual(np.int32, tf_ans.dtype) 91cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng self.assertAllEqual(tf_ans, expected_values) 92cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng 93f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur def testDouble(self): 94f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur self._testBasic(np.float64) 95f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur self._testDim(np.float64) 96f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 97f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur def testInt32(self): 98f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur self._testBasic(np.int32) 99f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur self._testDim(np.int32) 100f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 101f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur def testInt64(self): 102f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur self._testBasic(np.int64) 103f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur self._testDim(np.int64) 104f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 105cb78f99de569c13dcc0fc66c555cd77c8884f79dGeoffrey Irving def testEmpty(self): 106cb78f99de569c13dcc0fc66c555cd77c8884f79dGeoffrey Irving with self.test_session(): 1075866e065bc95c1d7de8a27413b368016941889a6Justine Tunney for op in math_ops.argmin, math_ops.argmax: 108cb78f99de569c13dcc0fc66c555cd77c8884f79dGeoffrey Irving with self.assertRaisesOpError( 109cb78f99de569c13dcc0fc66c555cd77c8884f79dGeoffrey Irving r"Reduction axis 0 is empty in shape \[0\]"): 110cb78f99de569c13dcc0fc66c555cd77c8884f79dGeoffrey Irving op([], 0).eval() 111cb78f99de569c13dcc0fc66c555cd77c8884f79dGeoffrey Irving 112df04c1ad4bbebbfdf0fbf9f8e9b95115341d7c8fManjunath Kudlur def testDefaultAxis(self): 113df04c1ad4bbebbfdf0fbf9f8e9b95115341d7c8fManjunath Kudlur with self.test_session(): 114df04c1ad4bbebbfdf0fbf9f8e9b95115341d7c8fManjunath Kudlur for op in math_ops.argmin, math_ops.argmax: 115df04c1ad4bbebbfdf0fbf9f8e9b95115341d7c8fManjunath Kudlur ans = op([1]).eval() 116df04c1ad4bbebbfdf0fbf9f8e9b95115341d7c8fManjunath Kudlur self.assertAllEqual(ans, 0) 117df04c1ad4bbebbfdf0fbf9f8e9b95115341d7c8fManjunath Kudlur 118cb78f99de569c13dcc0fc66c555cd77c8884f79dGeoffrey Irving 119f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurif __name__ == "__main__": 1205866e065bc95c1d7de8a27413b368016941889a6Justine Tunney test.main() 121