// reduction_ops_common.cc revision 46737e4e81314f7482bfd6a710f126a27f5d7975
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific reduction Ops.
17a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard 18a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard#include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h" 19a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard#include "tensorflow/compiler/tf2xla/type_util.h" 20a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard#include "tensorflow/compiler/tf2xla/xla_helpers.h" 21dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 22dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta#include "tensorflow/compiler/xla/literal_util.h" 230c8d4fef28768233f1f46b4d085f904293dffd2cAchin Gupta#include "tensorflow/core/framework/kernel_def_builder.h" 240c8d4fef28768233f1f46b4d085f904293dffd2cAchin Gupta 250c8d4fef28768233f1f46b4d085f904293dffd2cAchin Guptanamespace tensorflow { 26dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta 27872be88a2916f45d3de38120ede8c8b199b7498fdp-armXlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 28872be88a2916f45d3de38120ede8c8b199b7498fdp-arm const DataType dt = BaseType(input_type(0)); 29872be88a2916f45d3de38120ede8c8b199b7498fdp-arm OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt})); 30a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard 31a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_)); 32a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard} 33872be88a2916f45d3de38120ede8c8b199b7498fdp-arm 34872be88a2916f45d3de38120ede8c8b199b7498fdp-arm// Return the base case for the reduction. Defaults to zero. 
35872be88a2916f45d3de38120ede8c8b199b7498fdp-armxla::ComputationDataHandle XlaReductionOp::InitialValue( 36872be88a2916f45d3de38120ede8c8b199b7498fdp-arm xla::ComputationBuilder* builder) { 37872be88a2916f45d3de38120ede8c8b199b7498fdp-arm return XlaHelpers::Zero(builder, input_type(0)); 38872be88a2916f45d3de38120ede8c8b199b7498fdp-arm} 39872be88a2916f45d3de38120ede8c8b199b7498fdp-arm 40872be88a2916f45d3de38120ede8c8b199b7498fdp-arm// Unless BuildFinalizer is overridden the reduction has no 41dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta// finalizer. 42dce74b891e0e6020d0a18384e32f280133631d9bAchin Guptaxla::ComputationDataHandle XlaReductionOp::BuildFinalizer( 43dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta xla::ComputationBuilder* builder, 44a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard const xla::ComputationDataHandle& reduce_output, 45dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta int64 num_elements_reduced) { 46dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta return reduce_output; 47dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta} 48dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta 49dce74b891e0e6020d0a18384e32f280133631d9bAchin Guptavoid XlaReductionOp::Compile(XlaOpKernelContext* ctx) { 50dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta const TensorShape data_shape = ctx->InputShape(0); 51a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard const TensorShape axes_tensor_shape = ctx->InputShape(1); 52a806dad58c4cf752238d7bbffbc9a1ce17f63ceaJeenu Viswambharan VLOG(1) << "ReductionOp: " << ctx->op_kernel().name(); 53dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta 54dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta if (axes_tensor_shape.num_elements() == 0) { 55dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta // The reduction axes is an empty vector, which means there are no 56a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard // axes to reduce so just pass the input directly through to the 
57a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard // output. 58a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard ctx->SetOutput(0, ctx->Input(0)); 59a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard return; 60dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta } 61dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta 620c8d4fef28768233f1f46b4d085f904293dffd2cAchin Gupta // Evaluate the constant, reshaping to a 1-vector if it is a scalar. 630c8d4fef28768233f1f46b4d085f904293dffd2cAchin Gupta xla::Literal axes_literal; 640c8d4fef28768233f1f46b4d085f904293dffd2cAchin Gupta OP_REQUIRES_OK(ctx, 65dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta ctx->ConstantInputReshaped( 66dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta 1, {axes_tensor_shape.num_elements()}, &axes_literal)); 67dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta 68a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard VLOG(1) << "data shape: " << data_shape.DebugString(); 695717aae1c34c8ad3b556d65179f1e197c45a41c3Achin Gupta VLOG(1) << "axes : " << axes_literal.ToString(); 705717aae1c34c8ad3b556d65179f1e197c45a41c3Achin Gupta 715717aae1c34c8ad3b556d65179f1e197c45a41c3Achin Gupta gtl::InlinedVector<bool, 4> bitmap(data_shape.dims(), false); 725717aae1c34c8ad3b556d65179f1e197c45a41c3Achin Gupta std::vector<int64> xla_axes; 73dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta int64 num_elements_reduced = 1LL; 74dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) { 75dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta int32 index = axes_literal.Get<int>({i}); 76dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta OP_REQUIRES(ctx, 77dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta !(index < -data_shape.dims() || index >= data_shape.dims()), 78dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta errors::InvalidArgument("Invalid reduction dimension (", index, 79dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta " for 
input with ", data_shape.dims(), 80a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard " dimension(s)")); 81a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard index = (index + data_shape.dims()) % data_shape.dims(); 82a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard bitmap[index] = true; 83dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta xla_axes.push_back(index); 849865ac15765f260069047c0e7c56623eb1a70b9aDan Handley num_elements_reduced *= data_shape.dim_size(index); 85dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta } 86dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta 87dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta std::vector<int64> final_shape; 88dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta for (int i = 0; i < data_shape.dims(); ++i) { 89a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard if (!bitmap[i]) { 90a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard // If we are not reducing along dimension i. 915717aae1c34c8ad3b556d65179f1e197c45a41c3Achin Gupta int64 dim = data_shape.dim_size(i); 92a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard final_shape.push_back(dim); 93a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard } else if (keep_dims_) { 945717aae1c34c8ad3b556d65179f1e197c45a41c3Achin Gupta // We are reducing along dimension i, but we want to keep the 95a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard // same number of dimensions, so we set the dimension of i to 96a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard // '1'. 
975717aae1c34c8ad3b556d65179f1e197c45a41c3Achin Gupta final_shape.push_back(1); 98a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard } 99a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard } 100a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard 101a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard string desc = ctx->op_kernel().name(); 102a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard 1035717aae1c34c8ad3b556d65179f1e197c45a41c3Achin Gupta // Call virtual method to get the initial value. 104a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard const xla::ComputationDataHandle initial = InitialValue(ctx->builder()); 105a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard // Construct the builder for the reduction lambda. 106a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard xla::ComputationBuilder r(ctx->builder()->client(), 107a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard strings::StrCat(desc, "-reduction")); 108dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta xla::PrimitiveType type; 109dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); 1105717aae1c34c8ad3b556d65179f1e197c45a41c3Achin Gupta // Make two scalar parameters of the desired type for the lambda. 111dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta xla::ComputationDataHandle rx = 112dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta r.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x"); 113dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta xla::ComputationDataHandle ry = 114dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta r.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y"); 115dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta 116dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta auto data = ctx->Input(0); 117dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta 118dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta // Call virtual method to build the reduction lambda. 
119dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta BuildReducer(&r, rx, ry); 120dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta xla::Computation reduction_computation = r.Build().ConsumeValueOrDie(); 121dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta xla::ComputationDataHandle reduce = 122a6ef4393b6f0a346d6629ae01bb34c3d44ae5a08Douglas Raillard ctx->builder()->Reduce(data, initial, reduction_computation, xla_axes); 123b460b8bf23633195535006b29e14c615f888fa24Soby Mathew 124b460b8bf23633195535006b29e14c615f888fa24Soby Mathew xla::ComputationDataHandle finalized = 125dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta BuildFinalizer(ctx->builder(), reduce, num_elements_reduced); 126dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta 127dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta xla::ComputationDataHandle result; 128dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta if (keep_dims_) { 129dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta result = ctx->builder()->Reshape(finalized, final_shape); 130dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta } else { 131dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta result = finalized; 132dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta } 133dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta ctx->SetOutput(0, result); 134dce74b891e0e6020d0a18384e32f280133631d9bAchin Gupta} 135c3260f9b82c5017ca078f090c03cd7135ee8f8c9Soby Mathew 136c3260f9b82c5017ca078f090c03cd7135ee8f8c9Soby Mathew} // namespace tensorflow 137c3260f9b82c5017ca078f090c03cd7135ee8f8c9Soby Mathew