1fe8406149feec453250905965a14285465cd2063Shanqing Cai# =============================================================================
2fe8406149feec453250905965a14285465cd2063Shanqing Cai# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3fe8406149feec453250905965a14285465cd2063Shanqing Cai#
4fe8406149feec453250905965a14285465cd2063Shanqing Cai# Licensed under the Apache License, Version 2.0 (the "License");
5fe8406149feec453250905965a14285465cd2063Shanqing Cai# you may not use this file except in compliance with the License.
6fe8406149feec453250905965a14285465cd2063Shanqing Cai# You may obtain a copy of the License at
7fe8406149feec453250905965a14285465cd2063Shanqing Cai#
8fe8406149feec453250905965a14285465cd2063Shanqing Cai#     http://www.apache.org/licenses/LICENSE-2.0
9fe8406149feec453250905965a14285465cd2063Shanqing Cai#
10fe8406149feec453250905965a14285465cd2063Shanqing Cai# Unless required by applicable law or agreed to in writing, software
11fe8406149feec453250905965a14285465cd2063Shanqing Cai# distributed under the License is distributed on an "AS IS" BASIS,
12fe8406149feec453250905965a14285465cd2063Shanqing Cai# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13fe8406149feec453250905965a14285465cd2063Shanqing Cai# See the License for the specific language governing permissions and
14fe8406149feec453250905965a14285465cd2063Shanqing Cai# limitations under the License.
15fe8406149feec453250905965a14285465cd2063Shanqing Cai# =============================================================================
16fe8406149feec453250905965a14285465cd2063Shanqing Cai
17fe8406149feec453250905965a14285465cd2063Shanqing Caifrom __future__ import absolute_import
18fe8406149feec453250905965a14285465cd2063Shanqing Caifrom __future__ import division
19fe8406149feec453250905965a14285465cd2063Shanqing Caifrom __future__ import print_function
20fe8406149feec453250905965a14285465cd2063Shanqing Cai
21fe8406149feec453250905965a14285465cd2063Shanqing Caiimport numpy
22d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
23fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.contrib.periodic_resample import periodic_resample
24d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xiefrom tensorflow.python.framework import errors_impl
25fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.framework import test_util
26fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.ops import variables
27fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.platform import googletest
28fe8406149feec453250905965a14285465cd2063Shanqing Cai
29fe8406149feec453250905965a14285465cd2063Shanqing Cai
class PeriodicResampleTest(test_util.TensorFlowTestCase):
  """Tests for the `periodic_resample` op."""

  def _checkPeriodicResample(self, input_tensor, desired_shape, expected):
    """Evaluates `periodic_resample` and asserts the result equals `expected`.

    Args:
      input_tensor: numpy array passed to the op as its input.
      desired_shape: requested output shape. In every case in this class one
        entry is `None`, presumably the dimension the op infers — the 2D
        cases show [6, None] on a (3, 4) input yielding a (6, 2) result.
      expected: numpy array the evaluated result is compared against with
        `assertAllEqual`.
    """
    with self.test_session():
      variables.global_variables_initializer().run()
      actual = periodic_resample(input_tensor, desired_shape).eval()
      self.assertAllEqual(actual, expected)

  def testPeriodicResampleBasic2D(self):
    # For this 2D case the op output coincides with a plain reshape.
    source = numpy.arange(12).reshape((3, 4))
    self._checkPeriodicResample(
        source, numpy.array([6, None]), source.reshape((6, 2)))

  def testPeriodicResampleTruncatedBasic2D(self):
    # Requesting 5 rows instead of 6 keeps only the first five rows of the
    # (6, 2) reshape.
    source = numpy.arange(12).reshape((3, 4))
    self._checkPeriodicResample(
        source, numpy.array([5, None]), source.reshape((6, 2))[:-1])

  def testPeriodicResampleBasic3D(self):
    source = numpy.arange(2 * 2 * 4).reshape((2, 2, 4))
    # NOTE: the expected output differs from source.reshape((4, 4, -1)).
    # Consecutive entries along the last input axis are spread over the first
    # two output axes, e.g.:
    #   source[0, 0, 0] == result[0, 0, 0]
    #   source[0, 0, 1] == result[1, 0, 0]
    #   source[0, 0, 2] == result[0, 1, 0]
    #   source[0, 0, 3] == result[1, 1, 0]
    expected = numpy.array([
        [0, 2, 4, 6],
        [1, 3, 5, 7],
        [8, 10, 12, 14],
        [9, 11, 13, 15],
    ]).reshape((4, 4, 1))
    self._checkPeriodicResample(source, numpy.array([4, 4, None]), expected)

  def testPeriodicResampleBasic4D(self):
    source = numpy.arange(2 * 2 * 2 * 8).reshape((2, 2, 2, 8))
    # NOTE: the expected output differs from source.reshape((4, 4, 4, -1)).
    expected = numpy.array([
        [[0, 4, 8, 12], [2, 6, 10, 14],
         [16, 20, 24, 28], [18, 22, 26, 30]],
        [[1, 5, 9, 13], [3, 7, 11, 15],
         [17, 21, 25, 29], [19, 23, 27, 31]],
        [[32, 36, 40, 44], [34, 38, 42, 46],
         [48, 52, 56, 60], [50, 54, 58, 62]],
        [[33, 37, 41, 45], [35, 39, 43, 47],
         [49, 53, 57, 61], [51, 55, 59, 63]],
    ]).reshape((4, 4, 4, 1))
    self._checkPeriodicResample(
        source, numpy.array([4, 4, 4, None]), expected)

  def testPeriodicResampleErrors(self):
    input_tensor = numpy.zeros(shape=[1, 2, 2, 4])
    # Each entry pairs an invalid desired shape with text its error must
    # contain. The first mismatches the size of dimension 3; the second has
    # a different length (3) from the input's rank (4).
    invalid_cases = (
        ([None, 4, 4, 1],
         'Dimension 3 input tensor has size 4, desired shape has size 1'),
        ([None, 4, 4],
         '4, to be the same as the length of the desired shape, 3'),
    )
    with self.test_session():
      variables.global_variables_initializer().run()
      for bad_shape, expected_message in invalid_cases:
        with self.assertRaisesWithPredicateMatch(
            errors_impl.InvalidArgumentError, expected_message):
          periodic_resample(input_tensor, bad_shape).eval()
106fe8406149feec453250905965a14285465cd2063Shanqing Cai
if __name__ == '__main__':
  # Hand off to googletest.main() only when this file is executed directly,
  # so importing the module does not trigger a test run.
  googletest.main()
109