12e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
22e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins
32e307b457dfc4df1d82d568b71f9a796bd218084Peter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
42e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkinsyou may not use this file except in compliance with the License.
52e307b457dfc4df1d82d568b71f9a796bd218084Peter HawkinsYou may obtain a copy of the License at
62e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins
72e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
82e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins
92e307b457dfc4df1d82d568b71f9a796bd218084Peter HawkinsUnless required by applicable law or agreed to in writing, software
102e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
112e307b457dfc4df1d82d568b71f9a796bd218084Peter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122e307b457dfc4df1d82d568b71f9a796bd218084Peter HawkinsSee the License for the specific language governing permissions and
132e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkinslimitations under the License.
142e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins==============================================================================*/
152e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins
162e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins#include "tensorflow/compiler/tf2xla/type_util.h"
172e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins#include "tensorflow/compiler/tf2xla/xla_compiler.h"
182e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
192e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
202e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins#include "tensorflow/core/framework/kernel_def_builder.h"
21e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving#include "tensorflow/core/framework/tensor.pb.h"
222e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins
232e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkinsnamespace tensorflow {
242e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkinsnamespace {
252e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins
262e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkinsclass ConstOp : public XlaOpKernel {
272e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins public:
282e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins  explicit ConstOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
292e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins    const TensorProto* proto = nullptr;
302e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins    OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto));
312e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins    proto_ = *proto;
322e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins    OP_REQUIRES(
332e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins        ctx, ctx->output_type(0) == proto_.dtype(),
342e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins        errors::InvalidArgument("Type mismatch between value (",
352e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins                                DataTypeString(proto_.dtype()), ") and dtype (",
362e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins                                DataTypeString(ctx->output_type(0)), ")"));
372e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins    OP_REQUIRES_OK(ctx, TensorShape::IsValidShape(proto_.tensor_shape()));
382e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins  }
392e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins
402e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins  void Compile(XlaOpKernelContext* ctx) override {
412e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins    TensorShape shape(proto_.tensor_shape());
422e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins
4304a63c763e25c4f21f22d6d27757f4022d138b8dA. Unique TensorFlower    if (proto_.dtype() == DT_STRING) {
4404a63c763e25c4f21f22d6d27757f4022d138b8dA. Unique TensorFlower      LOG(WARNING) << "Not computing Const of type DT_STRING";
4504a63c763e25c4f21f22d6d27757f4022d138b8dA. Unique TensorFlower      ctx->SetInvalidOutput(0);
4604a63c763e25c4f21f22d6d27757f4022d138b8dA. Unique TensorFlower      return;
4704a63c763e25c4f21f22d6d27757f4022d138b8dA. Unique TensorFlower    }
482e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins    xla::ComputationBuilder* b = ctx->builder();
492e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins
502e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins    // To avoid blowups for large constants filled with the same value,
512e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins    // recognize that case and emit a scalar broadcast instead.
522e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins    if (shape.num_elements() > 1) {
532e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins      switch (proto_.dtype()) {
542e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins        case DT_BOOL:
552e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins          if (proto_.bool_val_size() == 1) {
562e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins            ctx->SetOutput(0,
572e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins                           b->Broadcast(b->ConstantR0<bool>(proto_.bool_val(0)),
582e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins                                        shape.dim_sizes()));
592e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins            return;
602e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins          }
612e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins          break;
622e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins        case DT_FLOAT:
632e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins          if (proto_.float_val_size() == 1) {
642e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins            ctx->SetOutput(
652e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins                0, b->Broadcast(b->ConstantR0<float>(proto_.float_val(0)),
662e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins                                shape.dim_sizes()));
672e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins            return;
682e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins          }
692e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins          break;
702e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins        case DT_DOUBLE:
712e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins          if (proto_.double_val_size() == 1) {
722e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins            ctx->SetOutput(
732e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins                0, b->Broadcast(b->ConstantR0<double>(proto_.double_val(0)),
742e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins                                shape.dim_sizes()));
752e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins            return;
762e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins          }
772e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins          break;
782e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins        case DT_INT32:
792e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins          if (proto_.int_val_size() == 1) {
802e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins            ctx->SetOutput(0,
812e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins                           b->Broadcast(b->ConstantR0<int32>(proto_.int_val(0)),
822e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins                                        shape.dim_sizes()));
832e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins            return;
842e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins          }
852e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins          break;
862e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins        case DT_INT64:
872e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins          if (proto_.int64_val_size() == 1) {
882e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins            ctx->SetOutput(
892e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins                0, b->Broadcast(b->ConstantR0<int64>(proto_.int64_val(0)),
902e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins                                shape.dim_sizes()));
912e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins            return;
922e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins          }
932e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins          break;
942e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins        default:
952e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins          break;
962e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins      }
972e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins    }
982e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins
992e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins    // General case
1002e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins    Tensor tensor(proto_.dtype());
1012e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins    OP_REQUIRES(ctx, tensor.FromProto(cpu_allocator(), proto_),
1022e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins                errors::InvalidArgument("Cannot parse tensor from proto: ",
1032e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins                                        proto_.DebugString()));
1042e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins    ctx->SetConstantOutput(0, tensor);
1052e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins  }
1062e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins
1072e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins private:
1082e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins  TensorProto proto_;
1092e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins  TF_DISALLOW_COPY_AND_ASSIGN(ConstOp);
1102e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins};
1112e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins
// Registers ConstOp as the XLA kernel for the "Const" op.
// XLA_* devices also register a "real" Const operator, so we suppress the
// dummy operator by marking this registration CompilationOnly().
REGISTER_XLA_OP(Name("Const").CompilationOnly(), ConstOp);
1152e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins
1162e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins}  // namespace
1172e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins}  // namespace tensorflow
118