1f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins
3f424ca38712a87aeaf614af454d96b5d155592caPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
4f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkinsyou may not use this file except in compliance with the License.
5f424ca38712a87aeaf614af454d96b5d155592caPeter HawkinsYou may obtain a copy of the License at
6f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins
7f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
8f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins
9f424ca38712a87aeaf614af454d96b5d155592caPeter HawkinsUnless required by applicable law or agreed to in writing, software
10f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
11f424ca38712a87aeaf614af454d96b5d155592caPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12f424ca38712a87aeaf614af454d96b5d155592caPeter HawkinsSee the License for the specific language governing permissions and
13f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkinslimitations under the License.
14f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins==============================================================================*/
15f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins
16f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins// XLA implementation of OneHot operator.
17f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins
18f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins#include "tensorflow/compiler/tf2xla/literal_util.h"
19eaa668e7e5d28072964ce8b78c155720aed951d3Peter Hawkins#include "tensorflow/compiler/tf2xla/xla_helpers.h"
20f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
21f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
22f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins
23f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkinsnamespace tensorflow {
24f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkinsnamespace {
25f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins
26f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkinsclass OneHotOp : public XlaOpKernel {
27f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins public:
28f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins  explicit OneHotOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
29f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
30f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins  }
31f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins
32f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins  void Compile(XlaOpKernelContext* ctx) override {
33f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins    const TensorShape indices_shape = ctx->InputShape(0);
34f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins    const TensorShape depth_shape = ctx->InputShape(1);
35f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins    const TensorShape on_value_shape = ctx->InputShape(2);
36f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins    const TensorShape off_value_shape = ctx->InputShape(3);
37f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins
38f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins    const int indices_dims = indices_shape.dims();
39f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins    const int output_dims = indices_dims + 1;
40f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins
41f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins    // Preliminary validation of sizes.
42f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins    OP_REQUIRES(
43f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins        ctx, axis_ == -1 || (axis_ >= 0 && axis_ < output_dims),
44f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins        errors::InvalidArgument("Expected axis to be -1 or between [0, ",
45f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins                                output_dims, ").  But received: ", axis_));
46f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(depth_shape),
47f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins                errors::InvalidArgument("depth must be a scalar, but got: ",
48f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins                                        depth_shape.DebugString()));
49f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(on_value_shape),
50f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins                errors::InvalidArgument("on_value must be a scalar, but got: ",
51f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins                                        on_value_shape.DebugString()));
52f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(off_value_shape),
53f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins                errors::InvalidArgument("off_value must be a scalar, but got: ",
54f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins                                        off_value_shape.DebugString()));
55f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins
56f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins    const int axis = (axis_ == -1) ? indices_dims : axis_;
57f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins
58f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins    // The one-hot dimension.
59f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins    int64 depth;
60f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &depth));
61f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins    OP_REQUIRES(
62f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins        ctx, depth >= 0,
63f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins        errors::InvalidArgument("depth must be non-negative, got: ", depth));
64f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins
65eaa668e7e5d28072964ce8b78c155720aed951d3Peter Hawkins    xla::ComputationDataHandle one_hot;
66eaa668e7e5d28072964ce8b78c155720aed951d3Peter Hawkins    OP_REQUIRES_OK(
67eaa668e7e5d28072964ce8b78c155720aed951d3Peter Hawkins        ctx, XlaHelpers::OneHot(ctx->builder(), depth, axis, input_type(0),
68eaa668e7e5d28072964ce8b78c155720aed951d3Peter Hawkins                                indices_shape, ctx->Input(0), ctx->Input(2),
69eaa668e7e5d28072964ce8b78c155720aed951d3Peter Hawkins                                ctx->Input(3), &one_hot));
70eaa668e7e5d28072964ce8b78c155720aed951d3Peter Hawkins    ctx->SetOutput(0, one_hot);
71f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins  }
72f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins
73f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins private:
74f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins  int32 axis_;
75f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins
76f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins  TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp);
77f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins};
78f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins
79c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0Peter HawkinsREGISTER_XLA_OP(Name("OneHot").CompileTimeConstInput("depth"), OneHotOp);
80f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins
81f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins}  // namespace
82f424ca38712a87aeaf614af454d96b5d155592caPeter Hawkins}  // namespace tensorflow
83