1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.ops.reverse_sequence_op."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.compiler.tests.xla_test import XLATestCase
24from tensorflow.python.framework import dtypes
25from tensorflow.python.ops import array_ops
26from tensorflow.python.platform import test
27
28
class ReverseSequenceTest(XLATestCase):
  """Tests array_ops.reverse_sequence with the op built inside test_scope()."""

  def _testReverseSequence(self,
                           x,
                           batch_axis,
                           seq_axis,
                           seq_lengths,
                           truth,
                           expected_err_re=None):
    """Feeds `x` through reverse_sequence and checks the outcome.

    Args:
      x: numpy array fed to the input placeholder.
      batch_axis: axis of `x` that indexes independent sequences.
      seq_axis: axis of `x` along which a prefix is reversed.
      seq_lengths: numpy integer array giving, per batch entry, how many
        leading elements along `seq_axis` are reversed.
      truth: expected numpy result; only used when `expected_err_re` is None.
      expected_err_re: when not None, a regex the op error must match instead
        of the op producing a value.
    """
    with self.test_session():
      # Placeholders are created outside test_scope(); only the op under test
      # is placed inside it.
      input_ph = array_ops.placeholder(dtypes.as_dtype(x.dtype))
      lengths_ph = array_ops.placeholder(dtypes.as_dtype(seq_lengths.dtype))
      with self.test_scope():
        reversed_t = array_ops.reverse_sequence(
            input_ph,
            batch_axis=batch_axis,
            seq_axis=seq_axis,
            seq_lengths=lengths_ph)
      feeds = {input_ph: x, lengths_ph: seq_lengths}
      if expected_err_re is not None:
        with self.assertRaisesOpError(expected_err_re):
          reversed_t.eval(feed_dict=feeds)
        return
      self.assertAllClose(reversed_t.eval(feed_dict=feeds), truth, atol=1e-10)

  def testSimple(self):
    """Reverses prefixes of length 1, 3 and 2 in the rows of a 3x3 matrix."""
    inputs = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
    lengths = np.array([1, 3, 2], np.int32)
    expected = np.array([[1, 2, 3], [6, 5, 4], [8, 7, 9]], dtype=np.int32)
    self._testReverseSequence(
        inputs, batch_axis=0, seq_axis=1, seq_lengths=lengths, truth=expected)

  def _testBasic(self, dtype, len_dtype):
    """Reverses variable-length prefixes on a rank-5 tensor with permuted axes.

    The data is built with shape (3, 2, 4, 1, 1), where axis 0 holds the three
    batch entries and axis 2 the length-4 sequences. Axes 0 and 2 are then
    swapped, so the op runs with seq_axis=0 and batch_axis=2.

    Args:
      dtype: numpy element type for the input and expected tensors.
      len_dtype: numpy integer type for the sequence lengths.
    """
    layout = (3, 2, 4, 1, 1)
    source = np.asarray(
        [[[1, 2, 3, 4], [5, 6, 7, 8]],
         [[9, 10, 11, 12], [13, 14, 15, 16]],
         [[17, 18, 19, 20], [21, 22, 23, 24]]],
        dtype=dtype).reshape(layout)
    expected = np.asarray(
        [[[3, 2, 1, 4], [7, 6, 5, 8]],  # batch 0: reverse [0:3]
         [[9, 10, 11, 12], [13, 14, 15, 16]],  # batch 1: length 0, untouched
         [[20, 19, 18, 17], [24, 23, 22, 21]]],  # batch 2: reverse all 4
        dtype=dtype).reshape(layout)

    # One prefix length per batch entry (the size-3 axis).
    lengths = np.asarray([3, 0, 4], dtype=len_dtype)

    # Swap axes 0 <=> 2 on both tensors: the size-4 sequence axis moves to 0
    # and the size-3 batch axis moves to 2.
    self._testReverseSequence(
        np.swapaxes(source, 0, 2),
        batch_axis=2,
        seq_axis=0,
        seq_lengths=lengths,
        truth=np.swapaxes(expected, 0, 2))

  def testSeqLength(self):
    """Runs _testBasic for every element dtype paired with every length dtype."""
    dtype_pairs = [(elem_dtype, len_dtype)
                   for elem_dtype in self.all_types
                   for len_dtype in self.int_types]
    for elem_dtype, len_dtype in dtype_pairs:
      self._testBasic(elem_dtype, len_dtype)
90
91
if __name__ == "__main__":
  # When executed as a script, hand off to TensorFlow's test entry point
  # (tensorflow.python.platform.test.main, imported above as `test`).
  test.main()
94