10cf9ed3a719c0782695154d5a0bca260001cec15A. Unique TensorFlower# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
29c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur#
39c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# Licensed under the Apache License, Version 2.0 (the "License");
49c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# you may not use this file except in compliance with the License.
59c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# You may obtain a copy of the License at
69c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur#
79c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur#     http://www.apache.org/licenses/LICENSE-2.0
89c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur#
99c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# Unless required by applicable law or agreed to in writing, software
109c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# distributed under the License is distributed on an "AS IS" BASIS,
119c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
129c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# See the License for the specific language governing permissions and
139c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# limitations under the License.
149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# ==============================================================================
15f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur"""Tests for tensorflow.ops.argmax_op."""
16f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevanfrom __future__ import absolute_import
17f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevanfrom __future__ import division
18f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevanfrom __future__ import print_function
19f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevan
20f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurimport numpy as np
21f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
22cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Fengfrom tensorflow.python.framework import dtypes
235866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.ops import math_ops
245866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.platform import test
255866e065bc95c1d7de8a27413b368016941889a6Justine Tunney
265866e065bc95c1d7de8a27413b368016941889a6Justine Tunney
class ArgMaxTest(test.TestCase):
  """Checks math_ops.argmax / math_ops.argmin against NumPy's argmax/argmin."""

  def _testArg(self,
               method,
               x,
               axis,
               expected_values,
               use_gpu=False,
               expected_err_re=None):
    """Evaluates `method(x, axis=axis)` in a test session and checks it.

    Args:
      method: callable such as `math_ops.argmax`; called as
        `method(x, axis=axis)`.
      x: input array.
      axis: reduction axis passed through to `method`.
      expected_values: expected result, compared by value and by shape when
        no error is expected.
      use_gpu: forwarded to `self.test_session`.
      expected_err_re: when not None, evaluation must raise an op error whose
        message matches this regex, and no value comparison is done.
    """
    with self.test_session(use_gpu=use_gpu):
      result_t = method(x, axis=axis)
      if expected_err_re is not None:
        # The error surfaces at evaluation time, not at graph construction.
        with self.assertRaisesOpError(expected_err_re):
          result_t.eval()
        return
      result = result_t.eval()
      # With no explicit output_type the index tensor is int64.
      self.assertEqual(np.int64, result.dtype)
      self.assertAllEqual(result, expected_values)
      self.assertShapeEqual(expected_values, result_t)

  def _testBothArg(self,
                   method,
                   x,
                   axis,
                   expected_values,
                   expected_err_re=None):
    """Runs `_testArg` first with use_gpu=True, then with use_gpu=False."""
    for use_gpu in (True, False):
      self._testArg(method, x, axis, expected_values,
                    use_gpu=use_gpu, expected_err_re=expected_err_re)

  def _testBasic(self, dtype):
    """Compares both ops with NumPy on a random 1-D array of `dtype`."""
    flat = np.asarray(100 * np.random.randn(200), dtype=dtype)
    op_pairs = ((math_ops.argmax, np.argmax), (math_ops.argmin, np.argmin))
    # A 1-D input has exactly one reduction axis: 0.
    for tf_op, np_op in op_pairs:
      self._testBothArg(tf_op, flat, 0, np_op(flat))

  def _testDim(self, dtype):
    """Compares both ops with NumPy along every axis of a random 5-D array."""
    values = np.asarray(100 * np.random.randn(3, 2, 4, 5, 6), dtype=dtype)
    op_pairs = ((math_ops.argmax, np.argmax), (math_ops.argmin, np.argmin))
    rank = values.ndim
    # Negative axes are exercised as well as non-negative ones.
    for axis in range(-rank, rank):
      for tf_op, np_op in op_pairs:
        self._testBothArg(tf_op, values, axis, np_op(values, axis=axis))

  def _testBasicAndDim(self, dtype):
    """Runs the 1-D check followed by the all-axes 5-D check for `dtype`."""
    self._testBasic(dtype)
    self._testDim(dtype)

  def testFloat(self):
    self._testBasicAndDim(np.float32)

  def testFloatInt32Output(self):
    data = np.asarray(100 * np.random.randn(200), dtype=np.float32)
    op_pairs = ((math_ops.argmax, np.argmax), (math_ops.argmin, np.argmin))
    for tf_op, np_op in op_pairs:
      want = np_op(data)
      with self.test_session(use_gpu=True):
        got = tf_op(data, axis=0, output_type=dtypes.int32).eval()
        self.assertEqual(np.int32, got.dtype)
        # Comparing the int32 result with NumPy's int64 index is exact: the
        # indices of a 200-element array fit comfortably in 32 bits.
        self.assertAllEqual(got, want)

  def testDouble(self):
    self._testBasicAndDim(np.float64)

  def testInt32(self):
    self._testBasicAndDim(np.int32)

  def testInt64(self):
    self._testBasicAndDim(np.int64)

  def testEmpty(self):
    """Reducing an empty vector must fail with a descriptive op error."""
    with self.test_session():
      for op in (math_ops.argmin, math_ops.argmax):
        with self.assertRaisesOpError(
            r"Reduction axis 0 is empty in shape \[0\]"):
          op([], 0).eval()

  def testDefaultAxis(self):
    """Calling the ops without an axis works on a one-element vector."""
    with self.test_session():
      for op in (math_ops.argmin, math_ops.argmax):
        self.assertAllEqual(op([1]).eval(), 0)
118cb78f99de569c13dcc0fc66c555cd77c8884f79dGeoffrey Irving
if __name__ == "__main__":
  # When run as a script, hand control to the test runner from
  # tensorflow.python.platform.test.
  test.main()
121