11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License.
51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at
61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software
101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and
131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License.
141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// XLA-specific base classes for Unary and Binary Ops.
171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_
191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_
201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/client_library.h"
231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/client/computation_builder.h"
241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/framework/op_kernel.h"
251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/util/bcast.h"
261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace tensorflow {
281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// Coefficient-wise binary operations. Each binary Op expects two
301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// inputs that can be broadcast to the same shape. The base class
311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// contains pure virtual methods to override: description is a textual
321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// description of the operation; and Computation adds the
331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// implementation of the operation to a xla::ComputationBuilder. For most
341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// arithmetic Ops XLA handles the broadcasting automatically given the input
353e975ea978bac4d861bb09328b06f3c316212611Andrew Harp// tensors.
361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass XlaBinaryOp : public XlaOpKernel {
371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public:
381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  explicit XlaBinaryOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const DataType lhs = BaseType(input_type(0));
401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const DataType rhs = BaseType(input_type(1));
411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES(ctx, lhs == rhs,
421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                errors::InvalidArgument("Input types of binary op must match"));
431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ~XlaBinaryOp() override {}
451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Implement the (tensor,tensor)->tensor lambda that should be
471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // applied to the inputs. The desired computation should be added to
481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // 'tc->builder()' and '(lhs,rhs)' are the function's inputs and
491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // (lhs_shape,rhs_shape) are their respective
501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // shapes. 'broadcast_helper' contains metadata about the shapes of
511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // the inputs and the dimensions that need to be broadcast, which
521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // may be useful for Ops that can't use standard XLA automatic
531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // broadcasting. 'extend_dimension' is non-empty if lhs and rhs have
541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // different ranks, and indicates which dimensions of the
551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // higher-rank input should be matched when broadcasting the
561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // lower-rank input. See comment below and the documentation on broadcasting
571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // in the XLA documentation.
581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  virtual xla::ComputationDataHandle Computation(
591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs,
601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      const gtl::ArraySlice<int64>& lhs_shape,
611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      const xla::ComputationDataHandle& rhs,
621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      const gtl::ArraySlice<int64>& rhs_shape, const BCast& broadcast_helper,
631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      const std::vector<int64>& extend_dimensions) = 0;
641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  void Compile(XlaOpKernelContext* ctx) override;
661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Helper function that performs the broadcasting described by
681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // 'broadcast_helper', yielding arguments 'lhs' and 'rhs' that have the same
691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // shape.
701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  static std::pair<xla::ComputationDataHandle, xla::ComputationDataHandle>
711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  Broadcast(xla::ComputationBuilder* builder,
721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            const xla::ComputationDataHandle& lhs,
731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            const xla::ComputationDataHandle& rhs,
741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            const BCast& broadcast_helper);
751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace tensorflow
781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#endif  // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_
80