1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#ifndef TENSORFLOW_KERNELS_SHAPE_OPS_H_ 17#define TENSORFLOW_KERNELS_SHAPE_OPS_H_ 18 19#include <limits> 20#include <unordered_set> 21#include <vector> 22 23#include "tensorflow/core/framework/op_kernel.h" 24#include "tensorflow/core/framework/tensor.h" 25#include "tensorflow/core/framework/tensor_shape.h" 26#include "tensorflow/core/framework/variant_op_registry.h" 27#include "tensorflow/core/kernels/bounds_check.h" 28 29namespace tensorflow { 30 31namespace shape_op_helpers { 32inline Status GetRegularOrVariantShape(OpKernelContext* ctx, int input_index, 33 TensorShape* shape) { 34 const Tensor& inp = ctx->input(input_index); 35 if (ctx->input_dtype(0) == DT_VARIANT) { 36 if (inp.dims() != 0) { 37 return errors::InvalidArgument( 38 "Shape of non-unary Variant not supported."); 39 } 40 TF_RETURN_IF_ERROR(GetUnaryVariantShape(inp, shape)); 41 } else { 42 *shape = inp.shape(); 43 } 44 return Status::OK(); 45} 46} // namespace shape_op_helpers 47 48template <typename OutType> 49class ShapeOp : public OpKernel { 50 public: 51 explicit ShapeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 52 53 void Compute(OpKernelContext* ctx) override { 54 TensorShape shape; 55 OP_REQUIRES_OK(ctx, 56 shape_op_helpers::GetRegularOrVariantShape(ctx, 0, &shape)); 57 const int rank = shape.dims(); 58 Tensor* out = nullptr; 59 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({rank}), &out)); 60 auto vec = out->vec<OutType>(); 61 for (int i = 0; i < rank; ++i) { 62 int64 dim_size = shape.dim_size(i); 63 if (out->dtype() == DT_INT32) { 64 OP_REQUIRES( 65 ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()), 66 errors::InvalidArgument("Shape output type is 32-bit ", " but dim ", 67 i, " is ", dim_size)); 68 } 69 vec(i) = static_cast<OutType>(dim_size); 70 } 71 } 72 73 bool IsExpensive() override { return false; } 74}; 75 76template <typename OutType> 77class ShapeNOp : public OpKernel { 78 public: 79 explicit ShapeNOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 80 81 void Compute(OpKernelContext* ctx) override { 82 for (int i = 0; i < ctx->num_inputs(); ++i) { 83 TensorShape shape; 84 OP_REQUIRES_OK( 85 ctx, shape_op_helpers::GetRegularOrVariantShape(ctx, i, &shape)); 86 const int dims = shape.dims(); 87 Tensor* out = nullptr; 88 OP_REQUIRES_OK(ctx, ctx->allocate_output(i, {dims}, &out)); 89 auto vec = out->vec<OutType>(); 90 91 for (int j = 0; j < dims; ++j) { 92 int64 dim_size = shape.dim_size(j); 93 if (out->dtype() == DT_INT32) { 94 OP_REQUIRES( 95 ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()), 96 errors::InvalidArgument("ShapeN output type is 32-bit but shape ", 97 i, " dim ", j, " is ", dim_size)); 98 } 99 vec(j) = static_cast<OutType>(dim_size); 100 } 101 } 102 } 103 104 bool IsExpensive() override { return false; } 105}; 106 107class RankOp : public OpKernel { 108 public: 109 explicit RankOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 110 111 void Compute(OpKernelContext* ctx) override { 112 TensorShape shape; 113 OP_REQUIRES_OK(ctx, 114 shape_op_helpers::GetRegularOrVariantShape(ctx, 0, &shape)); 115 const int rank = shape.dims(); 116 Tensor* out = nullptr; 117 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out)); 118 out->scalar<int32>()() = rank; 119 } 120 121 bool IsExpensive() override { return false; } 122}; 123 124template <typename OutType> 125class SizeOp : public OpKernel { 126 public: 127 explicit SizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 128 129 void Compute(OpKernelContext* ctx) override { 130 TensorShape shape; 131 OP_REQUIRES_OK(ctx, 132 shape_op_helpers::GetRegularOrVariantShape(ctx, 0, &shape)); 133 const int64 size = shape.num_elements(); 134 Tensor* out = nullptr; 135 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out)); 136 if (out->dtype() == DT_INT32) { 137 OP_REQUIRES( 138 ctx, FastBoundsCheck(size, std::numeric_limits<int32>::max()), 139 errors::InvalidArgument("Number of elements was larger than " 140 "representable by 32-bit output type")); 141 } 142 out->scalar<OutType>()() = static_cast<OutType>(size); 143 } 144 145 bool IsExpensive() override { return false; } 146}; 147 148template <typename Tdim> 149class ExpandDimsOp : public OpKernel { 150 public: 151 explicit ExpandDimsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 152 153 void Compute(OpKernelContext* ctx) override { 154 OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT, 155 errors::InvalidArgument("ExpandDims on Variant not supported")); 156 157 Tdim dim = ctx->input(1).flat<Tdim>()(0); 158 OP_REQUIRES( 159 ctx, (dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims()), 160 errors::InvalidArgument("Tried to expand dim index ", dim, 161 " for tensor with ", ctx->input(0).dims(), 162 " dimensions.")); 163 164 auto existing_dims = ctx->input(0).shape().dim_sizes(); 165 // Safe - # elements in tensor dims bounded. 166 const int existing_dims_size = static_cast<int>(existing_dims.size()); 167 std::vector<int64> new_shape(existing_dims_size); 168 for (size_t i = 0; i < new_shape.size(); ++i) { 169 new_shape[i] = existing_dims[i]; 170 } 171 172 // We emulate numpy's interpretation of the dim axis when 173 // -input.dims() >= dim <= input.dims(). 174 if (dim < 0) { 175 dim += existing_dims.size() + 1; 176 } 177 178 // Clamp to the end if needed. 179 dim = std::min<Tdim>(dim, existing_dims_size); 180 new_shape.emplace(new_shape.begin() + dim, 1); 181 const TensorShape output_shape(new_shape); 182 183 Tensor* output = nullptr; 184 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output)); 185 if (!output->CopyFrom(ctx->input(0), output_shape)) { 186 // This should never happen, since the sizes of the input and output 187 // should always be the same (we only expand the dimension with 1). 188 ctx->SetStatus( 189 errors::Internal("Could not expand dimension with input shape ", 190 ctx->input(0).shape().DebugString(), 191 " and output shape ", output_shape.DebugString())); 192 } 193 } 194 195 bool IsExpensive() override { return false; } 196}; 197 198class SqueezeOp : public OpKernel { 199 public: 200 explicit SqueezeOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 201 std::vector<int32> squeeze_dims; 202 OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims)); 203 squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end()); 204 } 205 206 void Compute(OpKernelContext* ctx) override { 207 OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT, 208 errors::InvalidArgument("Squeeze on Variant not supported")); 209 210 auto existing_dims = ctx->input(0).shape().dim_sizes(); 211 const int existing_dims_size = static_cast<int>(existing_dims.size()); 212 std::vector<int64> new_shape; 213 214 std::unordered_set<int32> wrapped_squeeze_dims; 215 wrapped_squeeze_dims.reserve(squeeze_dims_.size()); 216 // Validate squeeze dims against the input. 217 for (int32 dim : squeeze_dims_) { 218 OP_REQUIRES( 219 ctx, (dim >= -ctx->input(0).dims() && dim < ctx->input(0).dims()), 220 errors::InvalidArgument("Tried to squeeze dim index ", dim, 221 " for tensor with ", ctx->input(0).dims(), 222 " dimensions.")); 223 // If dim is < 0, we wrap around (-1 means the last element). 224 if (dim < 0) { 225 dim = existing_dims_size + dim; 226 } 227 228 wrapped_squeeze_dims.insert(dim); 229 } 230 231 for (int i = 0; i < existing_dims_size; ++i) { 232 auto existing_dim = existing_dims[i]; 233 234 // If squeeze_set is non-empty, only squeeze those dimensions. 235 if (!wrapped_squeeze_dims.empty()) { 236 if (wrapped_squeeze_dims.count(i) > 0) { 237 OP_REQUIRES(ctx, existing_dim == 1, 238 errors::InvalidArgument( 239 "Tried to explicitly squeeze " 240 "dimension ", 241 i, " but dimension was not 1: ", existing_dim)); 242 } else { 243 // This dimension is not being squeezed. 244 new_shape.push_back(existing_dim); 245 } 246 } else { 247 // Copy over all non-1-length dimensions. 248 if (existing_dim != 1) { 249 new_shape.push_back(existing_dim); 250 } 251 } 252 } 253 254 const TensorShape output_shape(new_shape); 255 Tensor* output = nullptr; 256 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output)); 257 if (!output->CopyFrom(ctx->input(0), output_shape)) { 258 // This should never happen, since the sizes of the input and 259 // output should always be the same. 260 ctx->SetStatus(errors::Internal("Could not squeeze input with shape ", 261 ctx->input(0).shape().DebugString(), 262 " and output shape ", 263 output_shape.DebugString())); 264 } 265 } 266 267 bool IsExpensive() override { return false; } 268 269 private: 270 std::unordered_set<int32> squeeze_dims_; 271}; 272 273} // namespace tensorflow 274 275#endif // TENSORFLOW_KERNELS_SHAPE_OPS_H_ 276