/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/register_types_traits.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;

// dim_size - the size of each dimension
// dim_range - the number of indices in the flattened tensor you need to skip
//    to get from one side of a dimension to the other. Used to make the
//    shifts wrap around after a threshold.
// threshold - the index, for each dimension, at which the roll starts to wrap
//    back to the front
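//
// Example: for a tensor of shape [2, 3] rolled by 1 along axis 1,
//   dim_size = {2, 3},  dim_range = {6, 3},  threshold = {0, 2},
// so each row [a, b, c] is written out as [c, a, b]: elements advance by one
// flat position and wrap back by dim_range[1] = 3 once their index reaches
// threshold[1] = 2.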
template <typename T>
void DoRoll(OpKernelContext* context, const int64 num_elements,
            const int num_dims, const gtl::ArraySlice<int>& dim_size,
            const T* input, T* output, const gtl::ArraySlice<int>& threshold,
            const gtl::ArraySlice<int64>& dim_range) {
  auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range](
                  int64 start, int64 end) {
    // array of indices for each dimension
    gtl::InlinedVector<int, 4> indices(num_dims);
    int offset = 0;  // the shift along the flattened tensor for current element
    // initialize indices and offset
    for (int i = 0; i < num_dims; i++) {
      // stride is the number of indices over in the flattened tensor
      // you need to skip in order to make it over to an adjacent element
      // along a dimension. dim_size[i] != 0 because we set it to max(dim, 1)
      const int64 stride = dim_range[i] / dim_size[i];
      const int shift = dim_size[i] - threshold[i];
      const int indx = (start / stride) % dim_size[i];
      indices[i] = indx;
      // calculate dimension index after the shift
      const int shifted_indx = (indx + shift) % dim_size[i];
      offset += (shifted_indx - indx) * stride;
    }

    for (int64 i = start; i < end; i++) {
      output[i + offset] = input[i];
      // create next combination of indices
      // while at it adjust offset if needed
      for (int j = num_dims - 1; j >= 0; j--) {
        const int indx = (indices[j] + 1) % dim_size[j];
        indices[j] = indx;
        if (indx != 0) {
          if (indx == threshold[j]) {  // we've reached the threshold
            // dim_range[j] = threshold[j] + shift[j]
            // offset = shift[j] + ... other offsets
            // offset - dim_range[j] = -threshold[j] + ... other offsets
            // thus we undo our previous offset as well as add a new offset of
            // -threshold[j] in one operation
            offset -= dim_range[j];  // now wraps around
          }
          break;                         // indx != 0 don't need to carry
        } else if (threshold[j] != 0) {  // if threshold is 0 shift is 0
          offset += dim_range[j];        // indx became 0 so reverse wrap around
        }
      }
    }
  };
  // Shard
  auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
  // 15 - experimentally determined with float and bool types
  const int cost_per_element = 15 * sizeof(T);  // rough estimate
  Shard(worker_threads->num_threads, worker_threads->workers, num_elements,
        cost_per_element, std::move(work));
}

// Uses memcpy to copy memory in groups when the data type supports memcpy.
//
// dim_size - the size of each dimension
// dim_range - the number of indices in the flattened tensor you need to skip
//    to get from one side of a dimension to the other. Used to make the
//    shifts wrap around after a threshold.
// threshold - the index, for each dimension, at which the roll starts to wrap
//    back to the front
// isd - inner shift dimension
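//
// Example: rolling a 1-D tensor of 5 elements by 2 gives threshold = {3} and
// isd = 0, so the work splits into two contiguous groups per pass over the
// isd: input[0..2] is copied to output[2..4] and input[3..4] to output[0..1].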
template <typename T>
void DoRollWithMemcpy(OpKernelContext* context, const int64 num_elements,
                      const int num_dims, const gtl::ArraySlice<int>& dim_size,
                      const T* input, T* output,
                      const gtl::ArraySlice<int>& threshold,
                      const gtl::ArraySlice<int64>& dim_range,
                      const int64 isd) {
  auto work = [input, output, num_dims, &dim_size, &threshold, &dim_range, isd](
                  int64 start, int64 end) {
    // the number of indices over in the flattened tensor you need to skip in
    // order to make it over from one side of the isd to the other
    const int64 isd_range = std::max<int64>(dim_range[isd], 1);
    // the distance along the flattened tensor to the next element in the isd
    const int64 isd_stride = isd_range / std::max<int>(dim_size[isd], 1);

    // start and end currently represent the i-th group, so we convert them
    // into indices of the i-th elements.
    // there are 2 groups per isd: one for all elements before threshold[isd]
    // and another for all elements after threshold[isd].
    const int64 start_remainder = (start % 2) * threshold[isd] * isd_stride;
    const int64 end_remainder = (end % 2) * threshold[isd] * isd_stride;
    start = (start / 2) * isd_range + start_remainder;
    end = (end / 2) * isd_range + end_remainder;

    const T* in_ptr = &input[0];
    T* out_ptr = &output[0];
    in_ptr += start;
    out_ptr += start;

    // array of indices for each dimension
    // indices = [i, j, k, l, m, n]
    gtl::InlinedVector<int, 4> indices(num_dims);
    // the offset needed to make all inner non-shifting dimensions become 0
    int64 remainder_offset = 0;
    // initialize indices
    for (int i = 0; i < num_dims; i++) {
      // stride is the number of indices over in the flattened tensor
      // you need to skip in order to make it over to an adjacent element
      // along a dimension. dim_size[i] != 0 because we set it to max(dim, 1)
      const int64 stride = dim_range[i] / dim_size[i];
      const int shift = dim_size[i] - threshold[i];
      const int indx = (start / stride) % dim_size[i];
      indices[i] = indx;
      // calculate dimension index after the shift
      int out_indx = (indx + shift) % dim_size[i];
      if (i > isd) {
        // trailing zeroes for indices after the inner shifted dimension
        out_indx = 0;
        remainder_offset += (out_indx - indx) * stride;
      }
      out_ptr += (out_indx - indx) * stride;
    }
    // set trailing zeroes for indices after the inner shifted dimension
    for (int i = num_dims - 1; i > isd; i--) indices[i] = 0;

    // the number of indices in the isd dimension the next group will skip
    // to make it to the next threshold or end point
    int isd_indx_skip = 0;
    // the size of the next group
    int64 group_size = 0;
    // initialize isd_indx_skip and group_size
    if (indices[isd] < threshold[isd]) {
      isd_indx_skip = threshold[isd] - indices[isd];
      group_size = isd_indx_skip * isd_stride + remainder_offset;
    } else {
      isd_indx_skip = dim_size[isd] - indices[isd];
      group_size = isd_indx_skip * isd_stride + remainder_offset;
    }

    int64 i = start;
    while (i < end) {
      // copy group of elements
      memcpy(out_ptr, in_ptr, group_size * sizeof(T));

      // shift i and the pointers over to the next group position
      i += group_size;
      out_ptr += group_size;
      in_ptr += group_size;

      // produce next combination of indices and adjust the out_ptr position
      // to fix the offset if necessary
      // the isd (inner shift dim) should skip to next threshold or endpoint
      // all dimensions to the left increment by 1 when a digit is carried
      // all dimensions to the right remain set to 0
      //           +1 +1 +1 +isd_indx_skip
      // indices = [i, j, k, l, 0, 0]
      //                     ^isd
      for (int j = isd; j >= 0; j--) {
        int inc = 1;
        if (j == isd) inc = isd_indx_skip;
        const int indx = (indices[j] + inc) % dim_size[j];
        indices[j] = indx;
        if (indx != 0) {
          if (indx == threshold[j]) {
            out_ptr -= dim_range[j];  // now wraps around
          }
          break;                         // indx != 0 don't need to carry
        } else if (threshold[j] != 0) {  // if threshold is 0 shift is 0
          out_ptr += dim_range[j];       // indx became 0 so reverse wrap around
        }
      }

      // set isd_indx_skip and group_size for next iteration
      if (indices[isd] < threshold[isd]) {
        isd_indx_skip = threshold[isd] - indices[isd];
        group_size = isd_indx_skip * isd_stride;
      } else {
        isd_indx_skip = dim_size[isd] - indices[isd];
        group_size = isd_indx_skip * isd_stride;
      }
    }
  };
  // Shard
  auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
  const int64 ave_group_size = dim_range[isd] / 2;
  const int64 total_work =
      2 * num_elements / std::max<int64>(dim_range[isd], 1);
  // 25000 - experimentally determined with float and bool types
  const int64 cost_per_group = 25000 * sizeof(T) * ave_group_size;
  Shard(worker_threads->num_threads, worker_threads->workers, total_work,
        cost_per_group, std::move(work));
}

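// RollOp shifts the elements of a tensor along the requested axes, wrapping
// the elements that roll past the end back to the beginning, e.g.
//   roll([0, 1, 2, 3, 4], shift=2, axis=0)  -> [3, 4, 0, 1, 2]
//   roll([0, 1, 2, 3, 4], shift=-2, axis=0) -> [2, 3, 4, 0, 1]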
template <typename Device, typename T, typename Tshift, typename Taxis>
class RollOp : public OpKernel {
 public:
  explicit RollOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensors
    const Tensor& input = context->input(0);
    const Tensor& shift = context->input(1);
    const Tensor& axis = context->input(2);

    auto shift_flat = shift.flat<Tshift>();
    auto axis_flat = axis.flat<Taxis>();

    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(input.shape()),
                errors::InvalidArgument("input must be 1-D or higher"));
    OP_REQUIRES(context, shift.shape().dims() <= 1,
                errors::InvalidArgument(
                    "shift must be a scalar or a 1-D vector. Found: ",
                    shift.shape().DebugString()));
    OP_REQUIRES(context, axis.shape().dims() <= 1,
                errors::InvalidArgument(
                    "axis must be a scalar or a 1-D vector. Found: ",
                    axis.shape().DebugString()));
    OP_REQUIRES(
        context, shift.shape() == axis.shape(),
        errors::InvalidArgument("shift and axis must have the same size"));
    const int64 num_elements = input.NumElements();
    const int num_shifts = static_cast<int>(shift_flat.size());
    const int num_dims = input.dims();

    // if there are any duplicate axes, shift_mod_sum will have the
    // total modulo sum of shifts for each dimension
    gtl::InlinedVector<int, 4> shift_mod_sum(num_dims, 0);
    for (int i = 0; i < num_shifts; i++) {
      const int axis = axis_flat(i);
      OP_REQUIRES(context, 0 <= axis && axis < num_dims,
                  errors::InvalidArgument("axis ", axis, " is out of range"));
      const int ds = std::max<int>(static_cast<int>(input.dim_size(axis)), 1);
      const int sum = shift_mod_sum[axis] + static_cast<int>(shift_flat(i));
      // modulo that works with negatives: ((x % y) + y) % y
      shift_mod_sum[axis] = (sum % ds + ds) % ds;
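      // e.g. sum = -2, ds = 5: ((-2 % 5) + 5) % 5 = 3, so a backward roll of 2
      // becomes the equivalent forward roll of 3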
    }
    // the size of each dimension
    gtl::InlinedVector<int, 4> dim_size(num_dims);
    // threshold[i] is the index at which the roll starts to wrap back to the
    // front
    gtl::InlinedVector<int, 4> threshold(num_dims);
    // dim_range is the number of indices in the flattened tensor you need to
    // skip to get from one side of a dimension to the other. Used to make the
    // shifts wrap around after a threshold.
    gtl::InlinedVector<int64, 4> dim_range(num_dims);
    int64 dim_size_prod = 1;  // dimension size product
    // inner shift dimension (innermost shifted dimension)
    int64 isd = 0;
    for (int i = num_dims - 1; i >= 0; i--) {
      if (isd == 0 && shift_mod_sum[i] != 0) isd = i;
      const int ds = std::max<int>(static_cast<int>(input.dim_size(i)), 1);
      dim_size[i] = ds;
      threshold[i] = (ds - shift_mod_sum[i]) % ds;
      dim_size_prod *= static_cast<int64>(input.dim_size(i));
      dim_range[i] = dim_size_prod;
    }
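    // Example: for shape [4, 2, 3] shifted by 1 along axis 1 only, isd = 1,
    // dim_size = {4, 2, 3}, dim_range = {24, 6, 3}, threshold = {0, 1, 0},
    // and the memcpy path below copies contiguous groups of 3 elements.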

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));
    auto input_flat = input.flat<T>().data();
    auto output_flat = output->flat<T>().data();

    if (std::is_same<Device, CPUDevice>::value) {
      if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
        // V2 copies memory in groups instead of element by element
        DoRollWithMemcpy<T>(context, num_elements, num_dims, dim_size,
                            input_flat, output_flat, threshold, dim_range, isd);
      } else {
        // in case memcpy does not work for the current data type
        DoRoll<T>(context, num_elements, num_dims, dim_size, input_flat,
                  output_flat, threshold, dim_range);
      }
    }
  }
};

// Register the CPU kernels.
#define REGISTER_CPU(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("Roll")                           \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<int32>("Tshift")   \
                              .TypeConstraint<int32>("Taxis"),   \
                          RollOp<CPUDevice, type, int32, int32>) \
  REGISTER_KERNEL_BUILDER(Name("Roll")                           \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<int64>("Tshift")   \
                              .TypeConstraint<int32>("Taxis"),   \
                          RollOp<CPUDevice, type, int64, int32>) \
  REGISTER_KERNEL_BUILDER(Name("Roll")                           \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<int32>("Tshift")   \
                              .TypeConstraint<int64>("Taxis"),   \
                          RollOp<CPUDevice, type, int32, int64>) \
  REGISTER_KERNEL_BUILDER(Name("Roll")                           \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<type>("T")         \
                              .TypeConstraint<int64>("Tshift")   \
                              .TypeConstraint<int64>("Taxis"),   \
                          RollOp<CPUDevice, type, int64, int64>)

TF_CALL_ALL_TYPES(REGISTER_CPU);
#undef REGISTER_CPU
}  // namespace tensorflow