1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include <array>
#include <vector>

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
34
35namespace tensorflow {
36namespace {
37
// The set of element types these kernels are registered for. The scan is
// implemented as a ReduceWindow below, which is O(n^2) unless XLA can use a
// wider accumulator; hence only these float types are allowed for now.
// TODO(phawkins): implement double-sized windowed reductions in XLA and remove
// the type constraint.
constexpr std::array<DataType, 3> kScanOpTypes = {
    {DT_HALF, DT_BFLOAT16, DT_FLOAT}};
42
43class ScanOp : public XlaOpKernel {
44 public:
45  ScanOp(OpKernelConstruction* ctx, bool sum) : XlaOpKernel(ctx), sum_(sum) {
46    OP_REQUIRES_OK(ctx, ctx->GetAttr("reverse", &reverse_));
47    OP_REQUIRES_OK(ctx, ctx->GetAttr("exclusive", &exclusive_));
48  }
49
50  void Compile(XlaOpKernelContext* ctx) override {
51    const TensorShape input_shape = ctx->InputShape(0);
52    const TensorShape tensor_axis_shape = ctx->InputShape(1);
53
54    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tensor_axis_shape),
55                errors::InvalidArgument("ScanOp: axis must be a scalar, not ",
56                                        tensor_axis_shape.DebugString()));
57
58    int64 axis;
59    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &axis));
60    if (axis < 0) {
61      axis += input_shape.dims();
62    }
63    OP_REQUIRES(
64        ctx, FastBoundsCheck(axis, input_shape.dims()),
65        errors::InvalidArgument("ScanOp: Expected scan axis in the range [",
66                                -input_shape.dims(), ", ", input_shape.dims(),
67                                "), but got ", axis));
68
69    DataType dtype = ctx->input_type(0);
70
71    if (input_shape.num_elements() == 0) {
72      // Exit early if there is nothing to compute.
73      ctx->SetOutput(0, ctx->Input(0));
74      return;
75    }
76
77    xla::ComputationBuilder* builder = ctx->builder();
78
79    std::vector<int64> window_strides(input_shape.dims(), 1);
80    std::vector<int64> window_dims(input_shape.dims(), 1);
81    window_dims[axis] = input_shape.dim_size(axis);
82
83    std::vector<std::pair<int64, int64>> padding(input_shape.dims(), {0, 0});
84    padding[axis].first = input_shape.dim_size(axis) - 1;
85    // In exclusive mode, add an extra padding element so there is a complete
86    // window of padding before the data starts.
87    if (exclusive_) {
88      ++padding[axis].first;
89    }
90    if (reverse_) {
91      std::swap(padding[axis].first, padding[axis].second);
92    }
93
94    xla::ComputationDataHandle input = ctx->Input(0);
95    xla::ComputationDataHandle init;
96    const xla::Computation* reducer;
97    if (sum_) {
98      init = XlaHelpers::Zero(builder, dtype);
99      reducer = ctx->GetOrCreateAdd(dtype);
100    } else {
101      init = XlaHelpers::One(builder, dtype);
102      reducer = ctx->GetOrCreateMul(dtype);
103    }
104    auto output = builder->ReduceWindowWithGeneralPadding(
105        ctx->Input(0), init, *reducer, window_dims, window_strides, padding);
106
107    // In exclusive mode, we have computed an extra element containing the sum
108    // of all the input elements. Slice off this extra "last" element.
109    if (exclusive_) {
110      if (reverse_) {
111        output = builder->SliceInDim(output, 1, input_shape.dim_size(axis) + 1,
112                                     1, axis);
113
114      } else {
115        output =
116            builder->SliceInDim(output, 0, input_shape.dim_size(axis), 1, axis);
117      }
118    }
119    ctx->SetOutput(0, output);
120  }
121
122 private:
123  const bool sum_;  // True=cumulative sum. False=cumulative product.
124  bool reverse_;
125  bool exclusive_;
126};
127
// XLA kernel for tf.cumsum: cumulative sum along a given axis.
class CumsumOp : public ScanOp {
 public:
  explicit CumsumOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/true) {}
};
// The axis input must be a compile-time constant so the reduction window and
// padding can be computed during XLA compilation.
REGISTER_XLA_OP(Name("Cumsum")
                    .TypeConstraint("T", kScanOpTypes)
                    .CompileTimeConstInput("axis"),
                CumsumOp);
136
// XLA kernel for tf.cumprod: cumulative product along a given axis.
class CumprodOp : public ScanOp {
 public:
  explicit CumprodOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/false) {}
};
// The axis input must be a compile-time constant so the reduction window and
// padding can be computed during XLA compilation.
REGISTER_XLA_OP(Name("Cumprod")
                    .TypeConstraint("T", kScanOpTypes)
                    .CompileTimeConstInput("axis"),
                CumprodOp);
145
146}  // anonymous namespace
147}  // namespace tensorflow
148