1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16// XLA-specific Shape Ops. 17 18#include "tensorflow/compiler/tf2xla/kernels/shape_util.h" 19#include "tensorflow/compiler/tf2xla/type_util.h" 20#include "tensorflow/compiler/tf2xla/xla_helpers.h" 21#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 22#include "tensorflow/compiler/tf2xla/xla_op_registry.h" 23#include "tensorflow/core/framework/kernel_def_builder.h" 24#include "tensorflow/core/kernels/bounds_check.h" 25 26namespace tensorflow { 27namespace { 28 29class ShapeOp : public XlaOpKernel { 30 public: 31 explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 32 OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); 33 } 34 35 void Compile(XlaOpKernelContext* ctx) override { 36 const TensorShape input_shape = ctx->InputShape(0); 37 Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()})); 38 OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant)); 39 ctx->SetConstantOutput(0, shape_constant); 40 } 41 42 private: 43 DataType out_dtype_; 44}; 45 46REGISTER_XLA_OP(Name("Shape"), ShapeOp); 47 48class ShapeNOp : public XlaOpKernel { 49 public: 50 explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 51 OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); 52 } 53 54 void Compile(XlaOpKernelContext* ctx) override { 55 for (int i = 0; i < ctx->num_inputs(); ++i) { 56 const TensorShape input_shape = ctx->InputShape(i); 57 Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()})); 58 OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant)); 59 ctx->SetConstantOutput(i, shape_constant); 60 } 61 } 62 63 bool IsExpensive() override { return false; } 64 65 private: 66 DataType out_dtype_; 67}; 68REGISTER_XLA_OP(Name("ShapeN"), ShapeNOp); 69 70class RankOp : public XlaOpKernel { 71 public: 72 explicit RankOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 73 74 void Compile(XlaOpKernelContext* ctx) override { 75 const TensorShape input_shape = ctx->InputShape(0); 76 const int rank = input_shape.dims(); 77 Tensor rank_constant(DT_INT32, TensorShape({})); 78 rank_constant.scalar<int32>()() = rank; 79 80 ctx->SetConstantOutput(0, rank_constant); 81 } 82}; 83 84REGISTER_XLA_OP(Name("Rank"), RankOp); 85 86class SizeOp : public XlaOpKernel { 87 public: 88 explicit SizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 89 90 void Compile(XlaOpKernelContext* ctx) override { 91 const TensorShape input_shape = ctx->InputShape(0); 92 const int64 size = input_shape.num_elements(); 93 OP_REQUIRES(ctx, FastBoundsCheck(size, std::numeric_limits<int32>::max()), 94 errors::InvalidArgument("Size does not work for tensors > " 95 "int32 max.")); 96 Tensor size_constant(DT_INT32, TensorShape({})); 97 size_constant.scalar<int32>()() = static_cast<int32>(size); 98 99 ctx->SetConstantOutput(0, size_constant); 100 } 101}; 102 103REGISTER_XLA_OP(Name("Size"), SizeOp); 104 105class ExpandDimsOp : public XlaOpKernel { 106 public: 107 explicit ExpandDimsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 108 109 void Compile(XlaOpKernelContext* ctx) override { 110 const TensorShape input_shape = ctx->InputShape(0); 111 const TensorShape dim_shape = ctx->InputShape(1); 112 113 // TODO(phawkins): the standard implementation of ExpandDimsOp seems to 114 // accept legacy scalars, even when they should be forbidden by the graphdef 115 // version. 116 OP_REQUIRES(ctx, dim_shape.num_elements() == 1, 117 errors::InvalidArgument(strings::StrCat( 118 "dim input to ExpandDims must be a scalar; got ", 119 dim_shape.DebugString()))); 120 121 xla::Literal literal; 122 OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {1}, &literal)); 123 124 int dim = literal.data<int32>()[0]; 125 126 OP_REQUIRES(ctx, 127 (dim >= -1 - input_shape.dims() && dim <= input_shape.dims()), 128 errors::InvalidArgument("Tried to expand dim index ", dim, 129 " for tensor with ", input_shape.dims(), 130 " dimensions.")); 131 132 auto existing_dims = input_shape.dim_sizes(); 133 // Safe - # elements in tensor dims bounded. 134 const int existing_dims_size = static_cast<int>(existing_dims.size()); 135 std::vector<int64> new_shape(existing_dims_size); 136 for (size_t i = 0; i < new_shape.size(); ++i) { 137 new_shape[i] = existing_dims[i]; 138 } 139 140 // We emulate numpy's interpretation of the dim axis when 141 // -input.dims() >= dim <= input.dims(). 142 if (dim < 0) { 143 dim += existing_dims.size() + 1; 144 } 145 146 // Clamp to the end if needed. 147 dim = std::min<int32>(dim, existing_dims_size); 148 new_shape.emplace(new_shape.begin() + dim, 1); 149 150 ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape)); 151 } 152}; 153REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstInput("dim"), ExpandDimsOp); 154 155class SqueezeOp : public XlaOpKernel { 156 public: 157 explicit SqueezeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 158 std::vector<int32> squeeze_dims; 159 OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims)); 160 squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end()); 161 } 162 163 void Compile(XlaOpKernelContext* ctx) override { 164 const TensorShape input_shape = ctx->InputShape(0); 165 auto existing_dims = input_shape.dim_sizes(); 166 int existing_dims_size = input_shape.dims(); 167 std::vector<int64> new_shape; 168 169 std::unordered_set<int32> wrapped_squeeze_dims; 170 wrapped_squeeze_dims.reserve(squeeze_dims_.size()); 171 // Validate squeeze dims against the input. 172 for (int32 dim : squeeze_dims_) { 173 OP_REQUIRES(ctx, (dim >= -input_shape.dims() && dim < input_shape.dims()), 174 errors::InvalidArgument("Tried to squeeze dim index ", dim, 175 " for tensor with ", 176 input_shape.dims(), " dimensions.")); 177 // If dim is < 0, we wrap around (-1 means the last element). 178 if (dim < 0) { 179 dim = existing_dims_size + dim; 180 } 181 182 wrapped_squeeze_dims.insert(dim); 183 } 184 185 for (int i = 0; i < existing_dims_size; ++i) { 186 auto existing_dim = existing_dims[i]; 187 188 // If squeeze_set is non-empty, only squeeze those dimensions. 189 if (!wrapped_squeeze_dims.empty()) { 190 if (wrapped_squeeze_dims.count(i) > 0) { 191 OP_REQUIRES(ctx, existing_dim == 1, 192 errors::InvalidArgument("Tried to explicitly squeeze " 193 "dimension ", 194 i, " but dimension was not 1: ", 195 existing_dim)); 196 } else { 197 // This dimension is not being squeezed. 198 new_shape.push_back(existing_dim); 199 } 200 } else { 201 // Copy over all non-1-length dimensions. 202 if (existing_dim != 1) { 203 new_shape.push_back(existing_dim); 204 } 205 } 206 } 207 208 ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape)); 209 } 210 211 private: 212 std::unordered_set<int32> squeeze_dims_; 213}; 214 215REGISTER_XLA_OP(Name("Squeeze"), SqueezeOp); 216 217class ZerosLikeOp : public XlaOpKernel { 218 public: 219 explicit ZerosLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 220 221 void Compile(XlaOpKernelContext* ctx) override { 222 const TensorShape input_shape = ctx->InputShape(0); 223 224 auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); 225 ctx->SetOutput(0, ctx->builder()->Broadcast(zero, input_shape.dim_sizes())); 226 } 227}; 228 229REGISTER_XLA_OP(Name("ZerosLike"), ZerosLikeOp); 230 231class OnesLikeOp : public XlaOpKernel { 232 public: 233 explicit OnesLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 234 235 void Compile(XlaOpKernelContext* ctx) override { 236 const TensorShape input_shape = ctx->InputShape(0); 237 238 auto one = XlaHelpers::One(ctx->builder(), input_type(0)); 239 ctx->SetOutput(0, ctx->builder()->Broadcast(one, input_shape.dim_sizes())); 240 } 241}; 242 243REGISTER_XLA_OP(Name("OnesLike"), OnesLikeOp); 244 245} // namespace 246} // namespace tensorflow 247