1fe8406149feec453250905965a14285465cd2063Shanqing Cai# ============================================================================= 2fe8406149feec453250905965a14285465cd2063Shanqing Cai# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3fe8406149feec453250905965a14285465cd2063Shanqing Cai# 4fe8406149feec453250905965a14285465cd2063Shanqing Cai# Licensed under the Apache License, Version 2.0 (the "License"); 5fe8406149feec453250905965a14285465cd2063Shanqing Cai# you may not use this file except in compliance with the License. 6fe8406149feec453250905965a14285465cd2063Shanqing Cai# You may obtain a copy of the License at 7fe8406149feec453250905965a14285465cd2063Shanqing Cai# 8fe8406149feec453250905965a14285465cd2063Shanqing Cai# http://www.apache.org/licenses/LICENSE-2.0 9fe8406149feec453250905965a14285465cd2063Shanqing Cai# 10fe8406149feec453250905965a14285465cd2063Shanqing Cai# Unless required by applicable law or agreed to in writing, software 11fe8406149feec453250905965a14285465cd2063Shanqing Cai# distributed under the License is distributed on an "AS IS" BASIS, 12fe8406149feec453250905965a14285465cd2063Shanqing Cai# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13fe8406149feec453250905965a14285465cd2063Shanqing Cai# See the License for the specific language governing permissions and 14fe8406149feec453250905965a14285465cd2063Shanqing Cai# limitations under the License. 15fe8406149feec453250905965a14285465cd2063Shanqing Cai# ============================================================================= 16fe8406149feec453250905965a14285465cd2063Shanqing Cai 17fe8406149feec453250905965a14285465cd2063Shanqing Caifrom __future__ import absolute_import 18fe8406149feec453250905965a14285465cd2063Shanqing Caifrom __future__ import division 19fe8406149feec453250905965a14285465cd2063Shanqing Caifrom __future__ import print_function 20fe8406149feec453250905965a14285465cd2063Shanqing Cai 21fe8406149feec453250905965a14285465cd2063Shanqing Caiimport numpy 22d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie 23fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.contrib.periodic_resample import periodic_resample 24d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xiefrom tensorflow.python.framework import errors_impl 25fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.framework import test_util 26fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.ops import variables 27fe8406149feec453250905965a14285465cd2063Shanqing Caifrom tensorflow.python.platform import googletest 28fe8406149feec453250905965a14285465cd2063Shanqing Cai 29fe8406149feec453250905965a14285465cd2063Shanqing Cai 30fe8406149feec453250905965a14285465cd2063Shanqing Caiclass PeriodicResampleTest(test_util.TensorFlowTestCase): 31fe8406149feec453250905965a14285465cd2063Shanqing Cai 32fe8406149feec453250905965a14285465cd2063Shanqing Cai def testPeriodicResampleBasic2D(self): 33fe8406149feec453250905965a14285465cd2063Shanqing Cai 34fe8406149feec453250905965a14285465cd2063Shanqing Cai input_tensor = numpy.arange(12).reshape((3, 4)) 35fe8406149feec453250905965a14285465cd2063Shanqing Cai desired_shape = numpy.array([6, None]) 36fe8406149feec453250905965a14285465cd2063Shanqing Cai output_tensor = input_tensor.reshape((6, 2)) 37fe8406149feec453250905965a14285465cd2063Shanqing Cai 38fe8406149feec453250905965a14285465cd2063Shanqing Cai with self.test_session(): 39fe8406149feec453250905965a14285465cd2063Shanqing Cai variables.global_variables_initializer().run() 40fe8406149feec453250905965a14285465cd2063Shanqing Cai result = periodic_resample(input_tensor, desired_shape).eval() 41fe8406149feec453250905965a14285465cd2063Shanqing Cai self.assertAllEqual(result, output_tensor) 42fe8406149feec453250905965a14285465cd2063Shanqing Cai 43fe8406149feec453250905965a14285465cd2063Shanqing Cai def testPeriodicResampleTruncatedBasic2D(self): 44fe8406149feec453250905965a14285465cd2063Shanqing Cai 45fe8406149feec453250905965a14285465cd2063Shanqing Cai input_tensor = numpy.arange(12).reshape((3, 4)) 46fe8406149feec453250905965a14285465cd2063Shanqing Cai desired_shape = numpy.array([5, None]) 47fe8406149feec453250905965a14285465cd2063Shanqing Cai output_tensor = input_tensor.reshape((6, 2))[:-1] 48fe8406149feec453250905965a14285465cd2063Shanqing Cai 49fe8406149feec453250905965a14285465cd2063Shanqing Cai with self.test_session(): 50fe8406149feec453250905965a14285465cd2063Shanqing Cai variables.global_variables_initializer().run() 51fe8406149feec453250905965a14285465cd2063Shanqing Cai result = periodic_resample(input_tensor, desired_shape).eval() 52fe8406149feec453250905965a14285465cd2063Shanqing Cai self.assertAllEqual(result, output_tensor) 53fe8406149feec453250905965a14285465cd2063Shanqing Cai 54fe8406149feec453250905965a14285465cd2063Shanqing Cai def testPeriodicResampleBasic3D(self): 55fe8406149feec453250905965a14285465cd2063Shanqing Cai 56ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie input_tensor = numpy.arange(2 * 2 * 4).reshape((2, 2, 4)) 57fe8406149feec453250905965a14285465cd2063Shanqing Cai desired_shape = numpy.array([4, 4, None]) 58ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie output_tensor = numpy.array([[[0], [2], [4], [6]], [[1], [3], [5], [7]], 59ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie [[8], [10], [12], [14]], [[9], [11], [13], 60ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie [15]]]) 61fe8406149feec453250905965a14285465cd2063Shanqing Cai 62fe8406149feec453250905965a14285465cd2063Shanqing Cai # NOTE: output_tensor != input_tensor.reshape((4, 4, -1)) 63fe8406149feec453250905965a14285465cd2063Shanqing Cai with self.test_session(): 64fe8406149feec453250905965a14285465cd2063Shanqing Cai variables.global_variables_initializer().run() 65fe8406149feec453250905965a14285465cd2063Shanqing Cai result = periodic_resample(input_tensor, desired_shape).eval() 66fe8406149feec453250905965a14285465cd2063Shanqing Cai # input_tensor[0, 0, 0] == result[0, 0, 0] 67fe8406149feec453250905965a14285465cd2063Shanqing Cai # input_tensor[0, 0, 1] == result[1, 0, 0] 68fe8406149feec453250905965a14285465cd2063Shanqing Cai # input_tensor[0, 0, 2] == result[0, 1, 0] 69fe8406149feec453250905965a14285465cd2063Shanqing Cai # input_tensor[0, 0, 3] == result[1, 1, 0] 70fe8406149feec453250905965a14285465cd2063Shanqing Cai self.assertAllEqual(result, output_tensor) 71fe8406149feec453250905965a14285465cd2063Shanqing Cai 72fe8406149feec453250905965a14285465cd2063Shanqing Cai def testPeriodicResampleBasic4D(self): 73fe8406149feec453250905965a14285465cd2063Shanqing Cai 74ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie input_tensor = numpy.arange(2 * 2 * 2 * 8).reshape((2, 2, 2, 8)) 75fe8406149feec453250905965a14285465cd2063Shanqing Cai desired_shape = numpy.array([4, 4, 4, None]) 76ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie output_tensor = numpy.array( 77ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie [[[[0], [4], [8], [12]], [[2], [6], [10], [14]], 78ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie [[16], [20], [24], [28]], [[18], [22], [26], [30]]], 79ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie [[[1], [5], [9], [13]], [[3], [7], [11], [15]], [[17], [21], [25], 80ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie [29]], 81ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie [[19], [23], [27], 82ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie [31]]], [[[32], [36], [40], [44]], [[34], [38], [42], [46]], 83ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie [[48], [52], [56], [60]], [[50], [54], [58], [62]]], 84ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie [[[33], [37], [41], [45]], [[35], [39], [43], [47]], 85ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie [[49], [53], [57], [61]], [[51], [55], [59], [63]]]]) 86fe8406149feec453250905965a14285465cd2063Shanqing Cai 87fe8406149feec453250905965a14285465cd2063Shanqing Cai # NOTE: output_tensor != input_tensor.reshape((4, 4, 4, -1)) 88fe8406149feec453250905965a14285465cd2063Shanqing Cai with self.test_session(): 89fe8406149feec453250905965a14285465cd2063Shanqing Cai variables.global_variables_initializer().run() 90fe8406149feec453250905965a14285465cd2063Shanqing Cai result = periodic_resample(input_tensor, desired_shape).eval() 91fe8406149feec453250905965a14285465cd2063Shanqing Cai self.assertAllEqual(result, output_tensor) 92fe8406149feec453250905965a14285465cd2063Shanqing Cai 93d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie def testPeriodicResampleErrors(self): 94d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie input_tensor = numpy.zeros(shape=[1, 2, 2, 4]) 95d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie with self.test_session(): 96d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie variables.global_variables_initializer().run() 97d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie with self.assertRaisesWithPredicateMatch( 98d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie errors_impl.InvalidArgumentError, 99d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie 'Dimension 3 input tensor has size 4, desired shape has size 1'): 100d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie periodic_resample(input_tensor, [None, 4, 4, 1]).eval() 101d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie with self.assertRaisesWithPredicateMatch( 102d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie errors_impl.InvalidArgumentError, 103d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie '4, to be the same as the length of the desired shape, 3'): 104d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie periodic_resample(input_tensor, [None, 4, 4]).eval() 105d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie 106fe8406149feec453250905965a14285465cd2063Shanqing Cai 107ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xieif __name__ == '__main__': 108fe8406149feec453250905965a14285465cd2063Shanqing Cai googletest.main() 109