/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific base classes for Unary and Binary Ops.
171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_ 191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_ 201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/client_library.h" 231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation_builder.h" 241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/framework/op_kernel.h" 251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/util/bcast.h" 261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace tensorflow { 281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// Coefficient-wise binary operations. Each binary Op expects two 301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// inputs that can be broadcast to the same shape. The base class 311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// contains pure virtual methods to override: description is a textual 321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// description of the operation; and Computation adds the 331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// implementation of the operation to a xla::ComputationBuilder. For most 341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// arithmetic Ops XLA handles the broadcasting automatically given the input 353e975ea978bac4d861bb09328b06f3c316212611Andrew Harp// tensors. 
361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass XlaBinaryOp : public XlaOpKernel { 371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public: 381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins explicit XlaBinaryOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const DataType lhs = BaseType(input_type(0)); 401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const DataType rhs = BaseType(input_type(1)); 411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins OP_REQUIRES(ctx, lhs == rhs, 421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins errors::InvalidArgument("Input types of binary op must match")); 431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ~XlaBinaryOp() override {} 451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Implement the (tensor,tensor)->tensor lambda that should be 471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // applied to the inputs. The desired computation should be added to 481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // 'tc->builder()' and '(lhs,rhs)' are the function's inputs and 491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // (lhs_shape,rhs_shape) are their respective 501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // shapes. 'broadcast_helper' contains metadata about the shapes of 511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // the inputs and the dimensions that need to be broadcast, which 521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // may be useful for Ops that can't use standard XLA automatic 531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // broadcasting. 
'extend_dimension' is non-empty if lhs and rhs have 541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // different ranks, and indicates which dimensions of the 551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // higher-rank input should be matched when broadcasting the 561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // lower-rank input. See comment below and the documentation on broadcasting 571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // in the XLA documentation. 581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins virtual xla::ComputationDataHandle Computation( 591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs, 601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const gtl::ArraySlice<int64>& lhs_shape, 611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const xla::ComputationDataHandle& rhs, 621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const gtl::ArraySlice<int64>& rhs_shape, const BCast& broadcast_helper, 631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const std::vector<int64>& extend_dimensions) = 0; 641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins void Compile(XlaOpKernelContext* ctx) override; 661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Helper function that performs the broadcasting described by 681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // 'broadcast_helper', yielding arguments 'lhs' and 'rhs' that have the same 691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // shape. 
701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins static std::pair<xla::ComputationDataHandle, xla::ComputationDataHandle> 711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Broadcast(xla::ComputationBuilder* builder, 721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const xla::ComputationDataHandle& lhs, 731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const xla::ComputationDataHandle& rhs, 741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const BCast& broadcast_helper); 751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}; 761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace tensorflow 781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_ 80