/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA implementation of OneHot operator.

17f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins 18f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins#include "tensorflow/compiler/tf2xla/literal_util.h" 19eaa668e7e5d28072964ce8b78c155720aed951d3Peter Hawkins#include "tensorflow/compiler/tf2xla/xla_helpers.h" 20f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 21f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_registry.h" 22f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins 23f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkinsnamespace tensorflow { 24f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkinsnamespace { 25f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins 26f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkinsclass OneHotOp : public XlaOpKernel { 27f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins public: 28f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins explicit OneHotOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 29f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_)); 30f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins } 31f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins 32f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins void Compile(XlaOpKernelContext* ctx) override { 33f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins const TensorShape indices_shape = ctx->InputShape(0); 34f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins const TensorShape depth_shape = ctx->InputShape(1); 35f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins const TensorShape on_value_shape = ctx->InputShape(2); 36f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins const TensorShape off_value_shape = ctx->InputShape(3); 37f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins 38f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins const int indices_dims = indices_shape.dims(); 
39f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins const int output_dims = indices_dims + 1; 40f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins 41f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins // Preliminary validation of sizes. 42f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins OP_REQUIRES( 43f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins ctx, axis_ == -1 || (axis_ >= 0 && axis_ < output_dims), 44f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins errors::InvalidArgument("Expected axis to be -1 or between [0, ", 45f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins output_dims, "). But received: ", axis_)); 46f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(depth_shape), 47f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins errors::InvalidArgument("depth must be a scalar, but got: ", 48f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins depth_shape.DebugString())); 49f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(on_value_shape), 50f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins errors::InvalidArgument("on_value must be a scalar, but got: ", 51f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins on_value_shape.DebugString())); 52f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(off_value_shape), 53f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins errors::InvalidArgument("off_value must be a scalar, but got: ", 54f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins off_value_shape.DebugString())); 55f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins 56f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins const int axis = (axis_ == -1) ? indices_dims : axis_; 57f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins 58f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins // The one-hot dimension. 
59f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins int64 depth; 60f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &depth)); 61f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins OP_REQUIRES( 62f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins ctx, depth >= 0, 63f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins errors::InvalidArgument("depth must be non-negative, got: ", depth)); 64f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins 65eaa668e7e5d28072964ce8b78c155720aed951d3Peter Hawkins xla::ComputationDataHandle one_hot; 66eaa668e7e5d28072964ce8b78c155720aed951d3Peter Hawkins OP_REQUIRES_OK( 67eaa668e7e5d28072964ce8b78c155720aed951d3Peter Hawkins ctx, XlaHelpers::OneHot(ctx->builder(), depth, axis, input_type(0), 68eaa668e7e5d28072964ce8b78c155720aed951d3Peter Hawkins indices_shape, ctx->Input(0), ctx->Input(2), 69eaa668e7e5d28072964ce8b78c155720aed951d3Peter Hawkins ctx->Input(3), &one_hot)); 70eaa668e7e5d28072964ce8b78c155720aed951d3Peter Hawkins ctx->SetOutput(0, one_hot); 71f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins } 72f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins 73f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins private: 74f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins int32 axis_; 75f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins 76f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp); 77f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins}; 78f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins 79c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter HawkinsREGISTER_XLA_OP(Name("OneHot").CompileTimeConstInput("depth"), OneHotOp); 80f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins 81f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins} // namespace 82f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins} // namespace tensorflow 83