/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

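// EIGEN_USE_THREADS (and EIGEN_USE_GPU when building with CUDA) must be
// defined before the Eigen Tensor header below is included, so that the
// ThreadPoolDevice and GpuDevice implementations are available.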
#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/reverse_sequence_op.h"

#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename Tlen>
void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) {
  const Tensor& input = context->input(0);
  const Tensor& seq_lens = context->input(1);

  auto seq_lens_t = seq_lens.vec<Tlen>();

  std::vector<Tlen> seq_lens_vec(seq_lens_t.size());

  // Copy the seq_lens data to the host for the per-element checks below.
  context->eigen_device<Device>().memcpyDeviceToHost(
      seq_lens_vec.data(), seq_lens_t.data(), sizeof(Tlen) * seq_lens_t.size());

  OP_REQUIRES(context, batch_dim != seq_dim,
              errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
  OP_REQUIRES(context, seq_dim < input.dims(),
              errors::InvalidArgument("seq_dim must be < input.dims()", " (",
                                      seq_dim, " vs. ", input.dims(), ")"));
  OP_REQUIRES(context, batch_dim < input.dims(),
              errors::InvalidArgument("batch_dim must be < input.dims()", " (",
                                      batch_dim, " vs. ", input.dims(), ")"));
  OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
              errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
                                      "), ", "(", seq_lens.NumElements(),
                                      " vs. ", input.dim_size(batch_dim), ")"));

  for (size_t d = 0; d < seq_lens_vec.size(); ++d) {
    OP_REQUIRES(context, seq_lens_vec[d] >= 0,
                errors::InvalidArgument("seq_lens(", d, ") < 0"));
    OP_REQUIRES(context, seq_lens_vec[d] <= input.dim_size(seq_dim),
                errors::InvalidArgument("seq_lens(", d, ") > input.dims(",
                                        seq_dim, ")"));
  }
}

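// Shape-only validation used by the GPU specializations of CheckErrors below.
// The per-element seq_lens checks are skipped on GPU, since they would
// require copying seq_lens back to the host and synchronizing with the
// device.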
void CheckErrorsGPU(OpKernelContext* context, int batch_dim, int seq_dim) {
  const Tensor& input = context->input(0);
  const Tensor& seq_lens = context->input(1);

  OP_REQUIRES(context, batch_dim != seq_dim,
              errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
  OP_REQUIRES(context, seq_dim < input.dims(),
              errors::InvalidArgument("seq_dim must be < input.dims()", " (",
                                      seq_dim, " vs. ", input.dims(), ")"));
  OP_REQUIRES(context, batch_dim < input.dims(),
              errors::InvalidArgument("batch_dim must be < input.dims()", " (",
                                      batch_dim, " vs. ", input.dims(), ")"));

  OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
              errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
                                      "), ", "(", seq_lens.NumElements(),
                                      " vs. ", input.dim_size(batch_dim), ")"));
}

template <>
void CheckErrors<GPUDevice, int32>(OpKernelContext* context, int batch_dim,
                                   int seq_dim) {
  CheckErrorsGPU(context, batch_dim, seq_dim);
}

template <>
void CheckErrors<GPUDevice, int64>(OpKernelContext* context, int batch_dim,
                                   int seq_dim) {
  CheckErrorsGPU(context, batch_dim, seq_dim);
}

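// Reverses variable-length slices of the input along seq_dim. For each index
// i along batch_dim, the first seq_lens(i) elements along seq_dim are
// reversed; the remaining elements are passed through unchanged. For example,
// with batch_dim = 0, seq_dim = 1, input = [[1, 2, 3, 4], [5, 6, 7, 8]], and
// seq_lens = [2, 3], the output is [[2, 1, 3, 4], [7, 6, 5, 8]].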
template <typename Device, typename T, typename Tlen>
class ReverseSequenceOp : public OpKernel {
 public:
  explicit ReverseSequenceOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("batch_dim", &batch_dim_));
    OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& seq_lens = context->input(1);

    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lens.shape()),
                errors::InvalidArgument("seq_lens input must be 1-dim, not ",
                                        seq_lens.dims()));

    auto seq_lens_t = seq_lens.vec<Tlen>();

    // OP_REQUIRES inside CheckErrors only returns from CheckErrors itself, so
    // re-check the context status here before touching the data.
    CheckErrors<Device, Tlen>(context, batch_dim_, seq_dim_);
    if (!context->status().ok()) return;

    const int input_dims = input.dims();

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

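// Dispatches to the ReverseSequence functor with a statically-known rank so
// the Eigen expression can be generated for that dimensionality.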
#define HANDLE_DIM(NDIM)                                                      \
  case NDIM:                                                                  \
    functor::ReverseSequence<Device, T, Tlen, NDIM>::Compute(                 \
        context->eigen_device<Device>(), input.tensor<T, NDIM>(), batch_dim_, \
        seq_dim_, seq_lens_t, output->tensor<T, NDIM>());                     \
    break;

    switch (input_dims) {
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);

      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "ReverseSequenceOp : Unhandled input dimensions: ",
                        input_dims));
    }
#undef HANDLE_DIM
  }

 private:
  int32 batch_dim_;
  int32 seq_dim_;

  TF_DISALLOW_COPY_AND_ASSIGN(ReverseSequenceOp);
};

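// Registration of the CPU implementations.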
#define REGISTER_REVERSE_SEQUENCE(type, len_type)                \
  REGISTER_KERNEL_BUILDER(Name("ReverseSequence")                \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<len_type>("Tlen"), \
                          ReverseSequenceOp<CPUDevice, type, len_type>);

#define REGISTER_REVERSE_SEQUENCE_LEN(type) \
  REGISTER_REVERSE_SEQUENCE(type, int32);   \
  REGISTER_REVERSE_SEQUENCE(type, int64);

TF_CALL_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE_LEN);
TF_CALL_bool(REGISTER_REVERSE_SEQUENCE_LEN);

#undef REGISTER_REVERSE_SEQUENCE_LEN
#undef REGISTER_REVERSE_SEQUENCE

#if GOOGLE_CUDA

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T, Tlen, Dims)                                \
  template <>                                                          \
  void ReverseSequence<GPUDevice, T, Tlen, Dims>::Compute(             \
      const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \
      int32 batch_dim, int32 seq_dim,                                  \
      typename TTypes<Tlen>::ConstVec seq_lens,                        \
      typename TTypes<T, Dims>::Tensor output);                        \
  extern template struct ReverseSequence<GPUDevice, T, Tlen, Dims>;
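// The extern template declarations keep this translation unit from
// instantiating ReverseSequence itself; the specializations are compiled in
// the separate GPU (CUDA) source file.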

#define DECLARE_GPU_SPEC_LEN(T, Dims) \
  DECLARE_GPU_SPEC(T, int32, Dims);   \
  DECLARE_GPU_SPEC(T, int64, Dims);

#define DECLARE_GPU_SPECS(T)  \
  DECLARE_GPU_SPEC_LEN(T, 2); \
  DECLARE_GPU_SPEC_LEN(T, 3); \
  DECLARE_GPU_SPEC_LEN(T, 4); \
  DECLARE_GPU_SPEC_LEN(T, 5);

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_bool(DECLARE_GPU_SPECS);

#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPEC_LEN
#undef DECLARE_GPU_SPEC

}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_REVERSE_SEQUENCE_GPU(type, len_type)            \
  REGISTER_KERNEL_BUILDER(Name("ReverseSequence")                \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<len_type>("Tlen"), \
                          ReverseSequenceOp<GPUDevice, type, len_type>);

#define REGISTER_REVERSE_SEQUENCE_GPU_LEN(type) \
  REGISTER_REVERSE_SEQUENCE_GPU(type, int32);   \
  REGISTER_REVERSE_SEQUENCE_GPU(type, int64);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE_GPU_LEN);
TF_CALL_bool(REGISTER_REVERSE_SEQUENCE_GPU_LEN);

#undef REGISTER_REVERSE_SEQUENCE_GPU_LEN
#undef REGISTER_REVERSE_SEQUENCE_GPU

#endif  // GOOGLE_CUDA

}  // namespace tensorflow