12e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 22e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins 32e307b457dfc4df1d82d568b71f9a796bd218084Peter HawkinsLicensed under the Apache License, Version 2.0 (the "License"); 42e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkinsyou may not use this file except in compliance with the License. 52e307b457dfc4df1d82d568b71f9a796bd218084Peter HawkinsYou may obtain a copy of the License at 62e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins 72e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins http://www.apache.org/licenses/LICENSE-2.0 82e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins 92e307b457dfc4df1d82d568b71f9a796bd218084Peter HawkinsUnless required by applicable law or agreed to in writing, software 102e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS, 112e307b457dfc4df1d82d568b71f9a796bd218084Peter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 122e307b457dfc4df1d82d568b71f9a796bd218084Peter HawkinsSee the License for the specific language governing permissions and 132e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkinslimitations under the License. 142e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins==============================================================================*/ 152e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins 162e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins#include "tensorflow/compiler/tf2xla/type_util.h" 172e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins#include "tensorflow/compiler/tf2xla/xla_compiler.h" 182e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 192e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_registry.h" 202e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins#include "tensorflow/core/framework/kernel_def_builder.h" 21e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving#include "tensorflow/core/framework/tensor.pb.h" 222e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins 232e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkinsnamespace tensorflow { 242e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkinsnamespace { 252e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins 262e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkinsclass ConstOp : public XlaOpKernel { 272e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins public: 282e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins explicit ConstOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 292e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins const TensorProto* proto = nullptr; 302e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto)); 312e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins proto_ = *proto; 322e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins OP_REQUIRES( 332e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins ctx, ctx->output_type(0) == proto_.dtype(), 342e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins errors::InvalidArgument("Type mismatch between value (", 352e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins DataTypeString(proto_.dtype()), ") and dtype (", 362e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins DataTypeString(ctx->output_type(0)), ")")); 372e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins OP_REQUIRES_OK(ctx, TensorShape::IsValidShape(proto_.tensor_shape())); 382e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins } 392e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins 402e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins void Compile(XlaOpKernelContext* ctx) override { 412e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins TensorShape shape(proto_.tensor_shape()); 422e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins 4304a63c763e25c4f21f22d6d27757f4022d138b8dA. Unique TensorFlower if (proto_.dtype() == DT_STRING) { 4404a63c763e25c4f21f22d6d27757f4022d138b8dA. Unique TensorFlower LOG(WARNING) << "Not computing Const of type DT_STRING"; 4504a63c763e25c4f21f22d6d27757f4022d138b8dA. Unique TensorFlower ctx->SetInvalidOutput(0); 4604a63c763e25c4f21f22d6d27757f4022d138b8dA. Unique TensorFlower return; 4704a63c763e25c4f21f22d6d27757f4022d138b8dA. Unique TensorFlower } 482e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins xla::ComputationBuilder* b = ctx->builder(); 492e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins 502e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins // To avoid blowups for large constants filled with the same value, 512e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins // recognize that case and emit a scalar broadcast instead. 522e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins if (shape.num_elements() > 1) { 532e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins switch (proto_.dtype()) { 542e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins case DT_BOOL: 552e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins if (proto_.bool_val_size() == 1) { 562e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins ctx->SetOutput(0, 572e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins b->Broadcast(b->ConstantR0<bool>(proto_.bool_val(0)), 582e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins shape.dim_sizes())); 592e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins return; 602e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins } 612e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins break; 622e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins case DT_FLOAT: 632e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins if (proto_.float_val_size() == 1) { 642e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins ctx->SetOutput( 652e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins 0, b->Broadcast(b->ConstantR0<float>(proto_.float_val(0)), 662e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins shape.dim_sizes())); 672e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins return; 682e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins } 692e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins break; 702e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins case DT_DOUBLE: 712e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins if (proto_.double_val_size() == 1) { 722e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins ctx->SetOutput( 732e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins 0, b->Broadcast(b->ConstantR0<double>(proto_.double_val(0)), 742e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins shape.dim_sizes())); 752e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins return; 762e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins } 772e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins break; 782e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins case DT_INT32: 792e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins if (proto_.int_val_size() == 1) { 802e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins ctx->SetOutput(0, 812e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins b->Broadcast(b->ConstantR0<int32>(proto_.int_val(0)), 822e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins shape.dim_sizes())); 832e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins return; 842e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins } 852e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins break; 862e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins case DT_INT64: 872e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins if (proto_.int64_val_size() == 1) { 882e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins ctx->SetOutput( 892e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins 0, b->Broadcast(b->ConstantR0<int64>(proto_.int64_val(0)), 902e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins shape.dim_sizes())); 912e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins return; 922e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins } 932e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins break; 942e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins default: 952e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins break; 962e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins } 972e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins } 982e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins 992e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins // General case 1002e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins Tensor tensor(proto_.dtype()); 1012e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins OP_REQUIRES(ctx, tensor.FromProto(cpu_allocator(), proto_), 1022e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins errors::InvalidArgument("Cannot parse tensor from proto: ", 1032e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins proto_.DebugString())); 1042e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins ctx->SetConstantOutput(0, tensor); 1052e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins } 1062e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins 1072e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins private: 1082e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins TensorProto proto_; 1092e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins TF_DISALLOW_COPY_AND_ASSIGN(ConstOp); 1102e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins}; 1112e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins 1122e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins// XLA_* devices also register a "real" Const operator so we suppress the 1132e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins// dummy operator using CompilationOnly(). 1142e307b457dfc4df1d82d568b71f9a796bd218084Peter HawkinsREGISTER_XLA_OP(Name("Const").CompilationOnly(), ConstOp); 1152e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins 1162e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins} // namespace 1172e307b457dfc4df1d82d568b71f9a796bd218084Peter Hawkins} // namespace tensorflow 118