/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins
#include <array>
#include <utility>
#include <vector>

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
34b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins
35b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkinsnamespace tensorflow {
36b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkinsnamespace {
37b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins
386e16af86658cd27b466c7c3ba270338b8f95f184Peter Hawkins// TODO(phawkins): implement double-sized windowed reductions in XLA and remove
396e16af86658cd27b466c7c3ba270338b8f95f184Peter Hawkins// the type constraint.
406e16af86658cd27b466c7c3ba270338b8f95f184Peter Hawkinsconstexpr std::array<DataType, 3> kScanOpTypes = {
416e16af86658cd27b466c7c3ba270338b8f95f184Peter Hawkins    {DT_HALF, DT_BFLOAT16, DT_FLOAT}};
426e16af86658cd27b466c7c3ba270338b8f95f184Peter Hawkins
43b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkinsclass ScanOp : public XlaOpKernel {
44b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins public:
45b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins  ScanOp(OpKernelConstruction* ctx, bool sum) : XlaOpKernel(ctx), sum_(sum) {
46b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    OP_REQUIRES_OK(ctx, ctx->GetAttr("reverse", &reverse_));
47b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    OP_REQUIRES_OK(ctx, ctx->GetAttr("exclusive", &exclusive_));
48b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins  }
49b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins
50b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins  void Compile(XlaOpKernelContext* ctx) override {
51b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    const TensorShape input_shape = ctx->InputShape(0);
52b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    const TensorShape tensor_axis_shape = ctx->InputShape(1);
53b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins
54b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tensor_axis_shape),
55b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins                errors::InvalidArgument("ScanOp: axis must be a scalar, not ",
56b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins                                        tensor_axis_shape.DebugString()));
57b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins
58b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    int64 axis;
59b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &axis));
60b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    if (axis < 0) {
61b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins      axis += input_shape.dims();
62b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    }
63b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    OP_REQUIRES(
64b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins        ctx, FastBoundsCheck(axis, input_shape.dims()),
65b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins        errors::InvalidArgument("ScanOp: Expected scan axis in the range [",
66b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins                                -input_shape.dims(), ", ", input_shape.dims(),
67b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins                                "), but got ", axis));
68b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins
69b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    DataType dtype = ctx->input_type(0);
70b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins
71b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    if (input_shape.num_elements() == 0) {
72b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins      // Exit early if there is nothing to compute.
73b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins      ctx->SetOutput(0, ctx->Input(0));
74b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins      return;
75b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    }
76b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins
77b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    xla::ComputationBuilder* builder = ctx->builder();
78b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins
79b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    std::vector<int64> window_strides(input_shape.dims(), 1);
80b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    std::vector<int64> window_dims(input_shape.dims(), 1);
81b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    window_dims[axis] = input_shape.dim_size(axis);
82b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins
83b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    std::vector<std::pair<int64, int64>> padding(input_shape.dims(), {0, 0});
84b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    padding[axis].first = input_shape.dim_size(axis) - 1;
85b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    // In exclusive mode, add an extra padding element so there is a complete
86b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    // window of padding before the data starts.
87b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    if (exclusive_) {
88b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins      ++padding[axis].first;
89b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    }
90b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    if (reverse_) {
91b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins      std::swap(padding[axis].first, padding[axis].second);
92b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    }
93b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins
94b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    xla::ComputationDataHandle input = ctx->Input(0);
95b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    xla::ComputationDataHandle init;
96b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    const xla::Computation* reducer;
97b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    if (sum_) {
98b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins      init = XlaHelpers::Zero(builder, dtype);
99b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins      reducer = ctx->GetOrCreateAdd(dtype);
100b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    } else {
101b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins      init = XlaHelpers::One(builder, dtype);
102b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins      reducer = ctx->GetOrCreateMul(dtype);
103b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    }
104b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    auto output = builder->ReduceWindowWithGeneralPadding(
105b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins        ctx->Input(0), init, *reducer, window_dims, window_strides, padding);
106b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins
107b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    // In exclusive mode, we have computed an extra element containing the sum
108b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    // of all the input elements. Slice off this extra "last" element.
109b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    if (exclusive_) {
110b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins      if (reverse_) {
111b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins        output = builder->SliceInDim(output, 1, input_shape.dim_size(axis) + 1,
112b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins                                     1, axis);
113b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins
114b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins      } else {
115b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins        output =
116b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins            builder->SliceInDim(output, 0, input_shape.dim_size(axis), 1, axis);
117b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins      }
118b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    }
119b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins    ctx->SetOutput(0, output);
120b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins  }
121b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins
122b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins private:
123b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins  const bool sum_;  // True=cumulative sum. False=cumulative product.
124b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins  bool reverse_;
125b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins  bool exclusive_;
126b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins};
127b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins
128b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkinsclass CumsumOp : public ScanOp {
129b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins public:
130b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins  explicit CumsumOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/true) {}
131b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins};
132c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter HawkinsREGISTER_XLA_OP(Name("Cumsum")
133c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins                    .TypeConstraint("T", kScanOpTypes)
134c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins                    .CompileTimeConstInput("axis"),
135c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins                CumsumOp);
136b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins
137b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkinsclass CumprodOp : public ScanOp {
138b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins public:
139b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins  explicit CumprodOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/false) {}
140b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins};
141c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter HawkinsREGISTER_XLA_OP(Name("Cumprod")
142c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins                    .TypeConstraint("T", kScanOpTypes)
143c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins                    .CompileTimeConstInput("axis"),
144c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter Hawkins                CumprodOp);
145b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins
146b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins}  // anonymous namespace
147b89251c6300b9941d06071543e5c4974d0db1984Peter Hawkins}  // namespace tensorflow
148