slice_op_test.py revision 0cf9ed3a719c0782695154d5a0bca260001cec15
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Functional tests for slice op."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22from six.moves import xrange  # pylint: disable=redefined-builtin
23import tensorflow as tf
24
25
class SliceTest(tf.test.TestCase):
  """Functional tests for `tf.slice` and `Tensor.__getitem__` slicing.

  Most checks compare TensorFlow results against the equivalent NumPy slice
  and are run both with `use_gpu=False` and `use_gpu=True`.
  """

  def _testEmpty(self, use_gpu):
    """Checks that zero-length float32 slices `a[2, k:k]` match NumPy.

    Args:
      use_gpu: Passed through to `test_session`.
    """
    inp = np.random.rand(4, 4).astype("f")
    for k in xrange(4):
      with self.test_session(use_gpu=use_gpu):
        a = tf.constant(inp, shape=[4, 4], dtype=tf.float32)
        slice_t = a[2, k:k]
        slice_val = slice_t.eval()
      self.assertAllEqual(slice_val, inp[2, k:k])

  def testEmptyAll(self):
    self._testEmpty(use_gpu=False)
    self._testEmpty(use_gpu=True)

  def _testInt32(self, use_gpu):
    """Checks that zero-length int32 slices `a[2, k:k]` match NumPy.

    Args:
      use_gpu: Passed through to `test_session`.
    """
    # np.random.rand(...).astype("i") truncates every value in [0, 1) to 0;
    # draw integers directly so the int32 input is not all zeros.
    inp = np.random.randint(0, 100, size=(4, 4)).astype("i")
    for k in xrange(4):
      with self.test_session(use_gpu=use_gpu):
        a = tf.constant(inp, shape=[4, 4], dtype=tf.int32)
        slice_t = a[2, k:k]
        slice_val = slice_t.eval()
      self.assertAllEqual(slice_val, inp[2, k:k])

  def testInt32(self):
    # Run the int32 helper (not the float32 _testEmpty one).
    self._testInt32(use_gpu=False)
    self._testInt32(use_gpu=True)

  def _testSelectAll(self, use_gpu):
    """Checks that selecting every element of a 4-D tensor is the identity.

    Covers both explicit `tf.slice` with size -1 and implicit `[:, :, :, :]`
    indexing, comparing values and static shapes.

    Args:
      use_gpu: Passed through to `test_session`.
    """
    with self.test_session(use_gpu=use_gpu):
      inp = np.random.rand(4, 4, 4, 4).astype("f")
      a = tf.constant(inp, shape=[4, 4, 4, 4], dtype=tf.float32)

      slice_explicit_t = tf.slice(a, [0, 0, 0, 0], [-1, -1, -1, -1])
      slice_implicit_t = a[:, :, :, :]

      self.assertAllEqual(inp, slice_explicit_t.eval())
      self.assertAllEqual(inp, slice_implicit_t.eval())
      self.assertEqual(inp.shape, slice_explicit_t.get_shape())
      self.assertEqual(inp.shape, slice_implicit_t.get_shape())

  def testSelectAll(self):
    for _ in range(10):
      self._testSelectAll(use_gpu=False)
      self._testSelectAll(use_gpu=True)

  def _testSingleDimension(self, use_gpu):
    """Checks scalar indexing and a random `[lo:hi]` slice of a 1-D tensor.

    Args:
      use_gpu: Passed through to `test_session`.
    """
    with self.test_session(use_gpu=use_gpu):
      inp = np.random.rand(10).astype("f")
      a = tf.constant(inp, shape=[10], dtype=tf.float32)

      # randint's upper bound is exclusive; 10 lets the last index (9) be
      # exercised too.
      hi = np.random.randint(0, 10)
      scalar_t = a[hi]
      scalar_val = scalar_t.eval()
      self.assertAllEqual(scalar_val, inp[hi])

      lo = np.random.randint(0, hi) if hi > 0 else 0
      slice_t = a[lo:hi]
      slice_val = slice_t.eval()
      self.assertAllEqual(slice_val, inp[lo:hi])

  def testSingleDimension(self):
    for _ in range(10):
      self._testSingleDimension(use_gpu=False)
      self._testSingleDimension(use_gpu=True)

  def _testSliceMatrixDim0(self, x, begin, size, use_gpu):
    """Slices `size` rows of matrix `x` starting at row `begin`.

    Args:
      x: 2-D NumPy array to slice.
      begin: First row of the slice.
      size: Number of rows to take; all columns are kept.
      use_gpu: Passed through to `test_session`.
    """
    with self.test_session(use_gpu=use_gpu):
      tf_ans = tf.slice(x, [begin, 0], [size, x.shape[1]]).eval()
    np_ans = x[begin:begin + size, :]
    self.assertAllEqual(tf_ans, np_ans)

  def testSliceMatrixDim0(self):
    for use_gpu in [False, True]:
      x = np.random.rand(8, 4).astype("f")
      self._testSliceMatrixDim0(x, 1, 2, use_gpu)
      self._testSliceMatrixDim0(x, 3, 3, use_gpu)
      y = np.random.rand(8, 7).astype("f")    # 7 * sizeof(float) is not aligned
      self._testSliceMatrixDim0(y, 1, 2, use_gpu)
      self._testSliceMatrixDim0(y, 3, 3, use_gpu)

  def _testIndexAndSlice(self, use_gpu):
    """Checks mixed integer-index and range slicing `a[x, 0:y]` on 4x4.

    Args:
      use_gpu: Passed through to `test_session`.
    """
    with self.test_session(use_gpu=use_gpu):
      inp = np.random.rand(4, 4).astype("f")
      a = tf.constant(inp, shape=[4, 4], dtype=tf.float32)

      x, y = np.random.randint(0, 3, size=2).tolist()
      slice_t = a[x, 0:y]
      slice_val = slice_t.eval()
    self.assertAllEqual(slice_val, inp[x, 0:y])

  def testSingleElementAll(self):
    for _ in range(10):
      self._testIndexAndSlice(use_gpu=False)
      self._testIndexAndSlice(use_gpu=True)

  def _testSimple(self, use_gpu):
    """Checks `tf.slice` and `[:2, :2]` agree with NumPy on a 4x4 matrix.

    Args:
      use_gpu: Passed through to `test_session`.
    """
    with self.test_session(use_gpu=use_gpu) as sess:
      inp = np.random.rand(4, 4).astype("f")
      # Pass the array directly instead of a Python list of floats.
      a = tf.constant(inp, shape=[4, 4], dtype=tf.float32)
      slice_t = tf.slice(a, [0, 0], [2, 2])
      slice2_t = a[:2, :2]
      slice_val, slice2_val = sess.run([slice_t, slice2_t])
    self.assertAllEqual(slice_val, inp[:2, :2])
    self.assertAllEqual(slice2_val, inp[:2, :2])
    self.assertEqual(slice_val.shape, slice_t.get_shape())
    self.assertEqual(slice2_val.shape, slice2_t.get_shape())

  def testSimpleAll(self):
    self._testSimple(use_gpu=False)
    self._testSimple(use_gpu=True)

  def _testComplex(self, use_gpu):
    """Checks `a[:, x, y:z, :]` on a (4, 10, 10, 4) tensor against NumPy.

    Args:
      use_gpu: Passed through to `test_session`.
    """
    with self.test_session(use_gpu=use_gpu):
      inp = np.random.rand(4, 10, 10, 4).astype("f")
      a = tf.constant(inp, dtype=tf.float32)

      # Axes 1 and 2 have size 10 and randint's high is exclusive: x covers
      # every valid index 0..9, and z may reach 10 (slice to the end).
      x = np.random.randint(0, 10)
      z = np.random.randint(0, 11)
      y = np.random.randint(0, z) if z > 0 else 0
      slice_t = a[:, x, y:z, :]
      self.assertAllEqual(slice_t.eval(), inp[:, x, y:z, :])

  def testComplex(self):
    for _ in range(10):
      self._testComplex(use_gpu=False)
      self._testComplex(use_gpu=True)

  def _RunAndVerifyResult(self, use_gpu):
    """Slices a random rank-6 tensor with random begins/sizes.

    Both `tf.slice` and the equivalent `__getitem__` expression are compared
    against NumPy, including their static shapes.

    Args:
      use_gpu: Passed through to `test_session`.
    """
    # Random dims of rank 6 (each in [0, 20), so empty axes are possible).
    input_shape = np.random.randint(0, 20, size=6)
    inp = np.random.rand(*input_shape).astype("f")
    with self.test_session(use_gpu=use_gpu) as sess:
      # Passing the array directly avoids building a Python list that can
      # hold millions of floats.
      a = tf.constant(inp, shape=input_shape, dtype=tf.float32)
      indices = [0 if x == 0 else np.random.randint(x) for x in input_shape]
      sizes = [np.random.randint(0, input_shape[i] - indices[i] + 1)
               for i in range(6)]
      # One tuple of Python slices shared by the TF and NumPy expressions.
      slices = tuple(slice(b, b + s) for b, s in zip(indices, sizes))
      slice_t = tf.slice(a, indices, sizes)
      slice2_t = a[slices]

      slice_val, slice2_val = sess.run([slice_t, slice2_t])

    expected_val = inp[slices]
    self.assertAllEqual(slice_val, expected_val)
    self.assertAllEqual(slice2_val, expected_val)
    self.assertEqual(expected_val.shape, slice_t.get_shape())
    self.assertEqual(expected_val.shape, slice2_t.get_shape())

  def testRandom(self):
    for _ in range(10):
      self._RunAndVerifyResult(use_gpu=False)
      self._RunAndVerifyResult(use_gpu=True)

  def _testGradientSlice(self, input_shape, slice_begin, slice_size, use_gpu):
    """Checks the gradient of `tf.slice` scatters grads into a zero tensor.

    Args:
      input_shape: Shape of the random input tensor.
      slice_begin: Per-axis start of the slice.
      slice_size: Per-axis size of the slice (also the shape of the grads).
      use_gpu: Passed through to `test_session`.
    """
    with self.test_session(use_gpu=use_gpu):
      num_inputs = np.prod(input_shape)
      num_grads = np.prod(slice_size)
      inp = np.random.rand(num_inputs).astype("f").reshape(input_shape)
      a = tf.constant(inp, shape=input_shape, dtype=tf.float32)
      slice_t = tf.slice(a, slice_begin, slice_size)
      grads = np.random.rand(num_grads).astype("f").reshape(slice_size)
      grad_tensor = tf.constant(grads)
      grad = tf.gradients(slice_t, [a], grad_tensor)[0]
      result = grad.eval()

    # Create a zero tensor of the input shape and place
    # the grads into the right location to compare against TensorFlow.
    np_ans = np.zeros(input_shape)
    slices = []
    for i in xrange(len(input_shape)):
      slices.append(slice(slice_begin[i], slice_begin[i] + slice_size[i]))
    # Multidimensional NumPy indexing requires a tuple; a list of slices is
    # rejected by modern NumPy.
    np_ans[tuple(slices)] = grads

    self.assertAllClose(np_ans, result)

  def _testGradientVariableSize(self, use_gpu):
    """Checks the gradient of `tf.slice` when a size of -1 is used.

    Args:
      use_gpu: Passed through to `test_session`.
    """
    with self.test_session(use_gpu=use_gpu):
      inp = tf.constant([1.0, 2.0, 3.0], name="in")
      out = tf.slice(inp, [1], [-1])
      grad_actual = tf.gradients(out, inp)[0].eval()
    self.assertAllClose([0., 1., 1.], grad_actual)

  def _testGradientsSimple(self, use_gpu):
    """Runs the gradient checks over several fixed slice configurations."""
    # Slice the middle square out of a 4x4 input
    self._testGradientSlice([4, 4], [1, 1], [2, 2], use_gpu)

    # Slice the upper left square out of a 4x4 input
    self._testGradientSlice([4, 4], [0, 0], [2, 2], use_gpu)

    # Slice a non-square input starting from (2,1)
    self._testGradientSlice([4, 4], [2, 1], [1, 2], use_gpu)

    # Slice a 3D tensor
    self._testGradientSlice([3, 3, 3], [0, 1, 0], [2, 1, 1], use_gpu)

    # Use -1 as a slice dimension.
    self._testGradientVariableSize(use_gpu)

  def testGradientsAll(self):
    self._testGradientsSimple(use_gpu=False)
    self._testGradientsSimple(use_gpu=True)

  def testNotIterable(self):
    # NOTE(mrry): If we register __getitem__ as an overloaded
    # operator, Python will valiantly attempt to iterate over the
    # Tensor from 0 to infinity.  This test ensures that this
    # unintended behavior is prevented.
    c = tf.constant(5.0)
    with self.assertRaisesWithPredicateMatch(
        TypeError,
        lambda e: "'Tensor' object is not iterable" in str(e)):
      for _ in c:
        pass
259
if __name__ == "__main__":
  # Only when executed as a script (not on import): hand control to
  # TensorFlow's test runner.
  _test_main = tf.test.main
  _test_main()
262