/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/ 151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/type_util.h" 171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/xla_helpers.h" 181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 19a8c325e57c1077f1e8df540a20bd8b36d3d1f968Peter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_registry.h" 201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/framework/kernel_def_builder.h" 21cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng#include "tensorflow/core/framework/node_def.pb.h" 221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace tensorflow { 241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace { 251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsconst char* const kGradientOp = "SymbolicGradient"; 271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// Implementations of _ListToArray and _ArrayToList for functions. 291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass PassOn : public XlaOpKernel { 301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public: 311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins explicit PassOn(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins OP_REQUIRES(ctx, ctx->num_inputs() == ctx->num_outputs(), 331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins errors::Internal("#inputs != #outputs : ", ctx->num_inputs(), 341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins " vs. 
", ctx->num_outputs())); 351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i = 0; i < ctx->num_inputs(); ++i) { 361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins OP_REQUIRES( 371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ctx, input_type(i) == output_type(i), 381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins errors::Internal("Input and output types for position ", i, 391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins " do not match: ", DataTypeString(input_type(i)), 401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins " vs. ", DataTypeString(output_type(i)))); 411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins void Compile(XlaOpKernelContext* ctx) override { 451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i = 0; i < ctx->num_inputs(); ++i) { 461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ctx->SetOutput(i, ctx->Input(i)); 471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}; 501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 5193f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter HawkinsREGISTER_XLA_OP(Name("_ListToArray"), PassOn); 5293f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter HawkinsREGISTER_XLA_OP(Name("_ArrayToList"), PassOn); 531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// TODO(phawkins): this is an almost exact copy of the SymbolicGradientOp 551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// implementation from regular Tensorflow. Once XLA has been open sourced 561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// merge the two implementations. 
(Note: this implementation propagates the 571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// step_resource_manager). 581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass SymbolicGradientOp : public AsyncOpKernel { 591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public: 601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins explicit SymbolicGradientOp(OpKernelConstruction* ctx) 611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins : AsyncOpKernel(ctx), handle_(-1) {} 621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ~SymbolicGradientOp() override {} 641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { 661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins FunctionLibraryRuntime* lib = ctx->function_library(); 671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins OP_REQUIRES_ASYNC(ctx, lib != nullptr, 681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins errors::Internal("No function library is provided."), 691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins done); 701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins OP_REQUIRES_OK_ASYNC( 7273882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving ctx, lib->Instantiate(kGradientOp, AttrSlice(&def().attr()), &handle_), 7373882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving done); 741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins FunctionLibraryRuntime::Options opts; 761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins opts.step_id = ctx->step_id(); 771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins opts.runner = ctx->runner(); 781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins opts.step_container = ctx->step_container(); 
791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<Tensor> args; 801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins args.reserve(ctx->num_inputs()); 811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (int i = 0; i < ctx->num_inputs(); ++i) { 821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins args.push_back(ctx->input(i)); 831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins std::vector<Tensor>* rets = new std::vector<Tensor>; 851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins lib->Run( 861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins opts, handle_, args, rets, [ctx, done, rets](const Status& status) { 871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (!status.ok()) { 881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ctx->SetStatus(status); 891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } else if (rets->size() != ctx->num_outputs()) { 901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ctx->SetStatus(errors::InvalidArgument( 911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "SymGrad expects to return ", ctx->num_outputs(), 921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins " tensor(s), but get ", rets->size(), " tensor(s) instead.")); 931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } else { 941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins for (size_t i = 0; i < rets->size(); ++i) { 951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins ctx->set_output(i, (*rets)[i]); 961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins delete rets; 991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins done(); 1001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins }); 1011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 
1031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins private: 1041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins FunctionLibraryRuntime::Handle handle_; 1051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientOp); 1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}; 1081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 10993f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter HawkinsREGISTER_XLA_OP(Name(kGradientOp), SymbolicGradientOp); 1101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace 1121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace tensorflow 113