11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License.
51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at
61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software
101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and
131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License.
141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/type_util.h"
171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/xla_helpers.h"
181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
19a8c325e57c1077f1e8df540a20bd8b36d3d1f968Peter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/framework/kernel_def_builder.h"
21cf7c008ab150ac8e5edb3ed053d38b2919699796Yifei Feng#include "tensorflow/core/framework/node_def.pb.h"
221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace tensorflow {
241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace {
251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
// Name of the TensorFlow op whose XLA kernel is defined below; also the
// function name passed to FunctionLibraryRuntime::Instantiate.
constexpr char kGradientOp[] = "SymbolicGradient";
271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// Implementations of _ListToArray and _ArrayToList for functions.
291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass PassOn : public XlaOpKernel {
301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public:
311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  explicit PassOn(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES(ctx, ctx->num_inputs() == ctx->num_outputs(),
331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                errors::Internal("#inputs != #outputs : ", ctx->num_inputs(),
341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                 " vs. ", ctx->num_outputs()));
351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int i = 0; i < ctx->num_inputs(); ++i) {
361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      OP_REQUIRES(
371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          ctx, input_type(i) == output_type(i),
381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          errors::Internal("Input and output types for position ", i,
391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                           " do not match: ", DataTypeString(input_type(i)),
401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                           " vs. ", DataTypeString(output_type(i))));
411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  void Compile(XlaOpKernelContext* ctx) override {
451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int i = 0; i < ctx->num_inputs(); ++i) {
461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ctx->SetOutput(i, ctx->Input(i));
471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
5193f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter HawkinsREGISTER_XLA_OP(Name("_ListToArray"), PassOn);
5293f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter HawkinsREGISTER_XLA_OP(Name("_ArrayToList"), PassOn);
531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// TODO(phawkins): this is an almost exact copy of the SymbolicGradientOp
551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// implementation from regular Tensorflow. Once XLA has been open sourced
561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// merge the two implementations. (Note: this implementation propagates the
571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// step_resource_manager).
581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass SymbolicGradientOp : public AsyncOpKernel {
591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public:
601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  explicit SymbolicGradientOp(OpKernelConstruction* ctx)
611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      : AsyncOpKernel(ctx), handle_(-1) {}
621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  ~SymbolicGradientOp() override {}
641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    FunctionLibraryRuntime* lib = ctx->function_library();
671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      errors::Internal("No function library is provided."),
691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      done);
701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES_OK_ASYNC(
7273882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving        ctx, lib->Instantiate(kGradientOp, AttrSlice(&def().attr()), &handle_),
7373882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving        done);
741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    FunctionLibraryRuntime::Options opts;
761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    opts.step_id = ctx->step_id();
771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    opts.runner = ctx->runner();
781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    opts.step_container = ctx->step_container();
791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::vector<Tensor> args;
801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    args.reserve(ctx->num_inputs());
811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int i = 0; i < ctx->num_inputs(); ++i) {
821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      args.push_back(ctx->input(i));
831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::vector<Tensor>* rets = new std::vector<Tensor>;
851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    lib->Run(
861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        opts, handle_, args, rets, [ctx, done, rets](const Status& status) {
871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          if (!status.ok()) {
881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            ctx->SetStatus(status);
891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          } else if (rets->size() != ctx->num_outputs()) {
901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            ctx->SetStatus(errors::InvalidArgument(
911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                "SymGrad expects to return ", ctx->num_outputs(),
921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                " tensor(s), but get ", rets->size(), " tensor(s) instead."));
931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          } else {
941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            for (size_t i = 0; i < rets->size(); ++i) {
951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins              ctx->set_output(i, (*rets)[i]);
961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            }
971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          }
981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          delete rets;
991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          done();
1001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        });
1011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins private:
1041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  FunctionLibraryRuntime::Handle handle_;
1051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientOp);
1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
1081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
10993f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter HawkinsREGISTER_XLA_OP(Name(kGradientOp), SymbolicGradientOp);
1101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace
1121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace tensorflow
113