reduction_ops_common.cc revision 46737e4e81314f7482bfd6a710f126a27f5d7975
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
154f6ad66ae9fcc8bcb3b0fcee10b7ab1ffcaf1a5Achin Gupta
// XLA-specific reduction Ops.
17a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard
18a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard#include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h"
19a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard#include "tensorflow/compiler/tf2xla/type_util.h"
20a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard#include "tensorflow/compiler/tf2xla/xla_helpers.h"
21dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
22dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta#include "tensorflow/compiler/xla/literal_util.h"
230c8d4fef28768233f1f46b4d085f904293dffd2cAchin Gupta#include "tensorflow/core/framework/kernel_def_builder.h"
240c8d4fef28768233f1f46b4d085f904293dffd2cAchin Gupta
250c8d4fef28768233f1f46b4d085f904293dffd2cAchin Guptanamespace tensorflow {
26dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta
27872be88a2916f45d3de38120ede8c8b199b7498fdp-armXlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
28872be88a2916f45d3de38120ede8c8b199b7498fdp-arm  const DataType dt = BaseType(input_type(0));
29872be88a2916f45d3de38120ede8c8b199b7498fdp-arm  OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt}));
30a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard
31a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard  OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
32a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard}
33872be88a2916f45d3de38120ede8c8b199b7498fdp-arm
34872be88a2916f45d3de38120ede8c8b199b7498fdp-arm// Return the base case for the reduction. Defaults to zero.
35872be88a2916f45d3de38120ede8c8b199b7498fdp-armxla::ComputationDataHandle XlaReductionOp::InitialValue(
36872be88a2916f45d3de38120ede8c8b199b7498fdp-arm    xla::ComputationBuilder* builder) {
37872be88a2916f45d3de38120ede8c8b199b7498fdp-arm  return XlaHelpers::Zero(builder, input_type(0));
38872be88a2916f45d3de38120ede8c8b199b7498fdp-arm}
39872be88a2916f45d3de38120ede8c8b199b7498fdp-arm
40872be88a2916f45d3de38120ede8c8b199b7498fdp-arm// Unless BuildFinalizer is overridden the reduction has no
41dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta// finalizer.
42dce74b891e0e6020d0a18384e32f280133631d9bAchin Guptaxla::ComputationDataHandle XlaReductionOp::BuildFinalizer(
43dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta    xla::ComputationBuilder* builder,
44a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard    const xla::ComputationDataHandle& reduce_output,
45dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta    int64 num_elements_reduced) {
46dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta  return reduce_output;
47dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta}
48dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta
49dce74b891e0e6020d0a18384e32f280133631d9bAchin Guptavoid XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
50dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta  const TensorShape data_shape = ctx->InputShape(0);
51a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard  const TensorShape axes_tensor_shape = ctx->InputShape(1);
52a806dad58c4cf752238d7bbffbc9a1ce17f63ceaJeenu Viswambharan  VLOG(1) << "ReductionOp: " << ctx->op_kernel().name();
53dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta
54dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta  if (axes_tensor_shape.num_elements() == 0) {
55dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta    // The reduction axes is an empty vector, which means there are no
56a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard    // axes to reduce so just pass the input directly through to the
57a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard    // output.
58a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard    ctx->SetOutput(0, ctx->Input(0));
59a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard    return;
60dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta  }
61dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta
620c8d4fef28768233f1f46b4d085f904293dffd2cAchin Gupta  // Evaluate the constant, reshaping to a 1-vector if it is a scalar.
630c8d4fef28768233f1f46b4d085f904293dffd2cAchin Gupta  xla::Literal axes_literal;
640c8d4fef28768233f1f46b4d085f904293dffd2cAchin Gupta  OP_REQUIRES_OK(ctx,
65dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta                 ctx->ConstantInputReshaped(
66dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta                     1, {axes_tensor_shape.num_elements()}, &axes_literal));
67dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta
68a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard  VLOG(1) << "data shape: " << data_shape.DebugString();
695717aae1c34c8ad3b556d65179f1e197c45a41c3Achin Gupta  VLOG(1) << "axes      : " << axes_literal.ToString();
705717aae1c34c8ad3b556d65179f1e197c45a41c3Achin Gupta
715717aae1c34c8ad3b556d65179f1e197c45a41c3Achin Gupta  gtl::InlinedVector<bool, 4> bitmap(data_shape.dims(), false);
725717aae1c34c8ad3b556d65179f1e197c45a41c3Achin Gupta  std::vector<int64> xla_axes;
73dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta  int64 num_elements_reduced = 1LL;
74dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta  for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) {
75dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta    int32 index = axes_literal.Get<int>({i});
76dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta    OP_REQUIRES(ctx,
77dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta                !(index < -data_shape.dims() || index >= data_shape.dims()),
78dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta                errors::InvalidArgument("Invalid reduction dimension (", index,
79dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta                                        " for input with ", data_shape.dims(),
80a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard                                        " dimension(s)"));
81a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard    index = (index + data_shape.dims()) % data_shape.dims();
82a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard    bitmap[index] = true;
83dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta    xla_axes.push_back(index);
849865ac15765f260069047c0e7c56623eb1a70b9aDan Handley    num_elements_reduced *= data_shape.dim_size(index);
85dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta  }
86dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta
87dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta  std::vector<int64> final_shape;
88dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta  for (int i = 0; i < data_shape.dims(); ++i) {
89a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard    if (!bitmap[i]) {
90a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard      // If we are not reducing along dimension i.
915717aae1c34c8ad3b556d65179f1e197c45a41c3Achin Gupta      int64 dim = data_shape.dim_size(i);
92a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard      final_shape.push_back(dim);
93a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard    } else if (keep_dims_) {
945717aae1c34c8ad3b556d65179f1e197c45a41c3Achin Gupta      // We are reducing along dimension i, but we want to keep the
95a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard      // same number of dimensions, so we set the dimension of i to
96a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard      // '1'.
975717aae1c34c8ad3b556d65179f1e197c45a41c3Achin Gupta      final_shape.push_back(1);
98a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard    }
99a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard  }
100a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard
101a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard  string desc = ctx->op_kernel().name();
102a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard
1035717aae1c34c8ad3b556d65179f1e197c45a41c3Achin Gupta  // Call virtual method to get the initial value.
104a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard  const xla::ComputationDataHandle initial = InitialValue(ctx->builder());
105a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard  // Construct the builder for the reduction lambda.
106a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard  xla::ComputationBuilder r(ctx->builder()->client(),
107a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard                            strings::StrCat(desc, "-reduction"));
108dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta  xla::PrimitiveType type;
109dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta  TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type));
1105717aae1c34c8ad3b556d65179f1e197c45a41c3Achin Gupta  // Make two scalar parameters of the desired type for the lambda.
111dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta  xla::ComputationDataHandle rx =
112dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta      r.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x");
113dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta  xla::ComputationDataHandle ry =
114dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta      r.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y");
115dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta
116dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta  auto data = ctx->Input(0);
117dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta
118dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta  // Call virtual method to build the reduction lambda.
119dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta  BuildReducer(&r, rx, ry);
120dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta  xla::Computation reduction_computation = r.Build().ConsumeValueOrDie();
121dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta  xla::ComputationDataHandle reduce =
122a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard      ctx->builder()->Reduce(data, initial, reduction_computation, xla_axes);
123b460b8bf23633195535006b29e14c615f888fa24Soby Mathew
124b460b8bf23633195535006b29e14c615f888fa24Soby Mathew  xla::ComputationDataHandle finalized =
125dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta      BuildFinalizer(ctx->builder(), reduce, num_elements_reduced);
126dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta
127dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta  xla::ComputationDataHandle result;
128dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta  if (keep_dims_) {
129dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta    result = ctx->builder()->Reshape(finalized, final_shape);
130dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta  } else {
131dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta    result = finalized;
132dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta  }
133dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta  ctx->SetOutput(0, result);
134dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta}
135c3260f9b82c5017ca078f090c03cd7135ee8f8c9Soby Mathew
136c3260f9b82c5017ca078f090c03cd7135ee8f8c9Soby Mathew}  // namespace tensorflow
137c3260f9b82c5017ca078f090c03cd7135ee8f8c9Soby Mathew