shape_ops.h revision b06256e58108dd6596f2410ccd359179bd98bd0a
1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#ifndef TENSORFLOW_KERNELS_SHAPE_OPS_H_ 17#define TENSORFLOW_KERNELS_SHAPE_OPS_H_ 18 19#include <limits> 20#include <unordered_set> 21#include <vector> 22 23#include "tensorflow/core/framework/op_kernel.h" 24#include "tensorflow/core/framework/tensor.h" 25#include "tensorflow/core/framework/tensor_shape.h" 26#include "tensorflow/core/kernels/bounds_check.h" 27 28namespace tensorflow { 29 30template <typename OutType> 31class ShapeOp : public OpKernel { 32 public: 33 explicit ShapeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 34 35 void Compute(OpKernelContext* ctx) override { 36 const Tensor& inp = ctx->input(0); 37 const int rank = inp.dims(); 38 Tensor* out = nullptr; 39 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({rank}), &out)); 40 auto vec = out->vec<OutType>(); 41 for (int i = 0; i < rank; ++i) { 42 int64 dim_size = inp.dim_size(i); 43 if (out->dtype() == DT_INT32) { 44 OP_REQUIRES( 45 ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()), 46 errors::InvalidArgument("Shape output type is 32-bit ", " but dim ", 47 i, " is ", dim_size)); 48 } 49 vec(i) = static_cast<OutType>(dim_size); 50 } 51 } 52 53 bool IsExpensive() override { return false; } 54}; 55 56template <typename OutType> 57class ShapeNOp : public OpKernel { 58 public: 59 explicit ShapeNOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 60 61 void Compute(OpKernelContext* ctx) override { 62 for (int i = 0; i < ctx->num_inputs(); ++i) { 63 auto shape = ctx->input(i).shape(); 64 const int dims = shape.dims(); 65 Tensor* out = nullptr; 66 OP_REQUIRES_OK(ctx, ctx->allocate_output(i, {dims}, &out)); 67 auto vec = out->vec<OutType>(); 68 69 for (int j = 0; j < dims; ++j) { 70 int64 dim_size = shape.dim_size(j); 71 if (out->dtype() == DT_INT32) { 72 OP_REQUIRES( 73 ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()), 74 errors::InvalidArgument("ShapeN output type is 32-bit but shape ", 75 i, " dim ", j, " is ", dim_size)); 76 } 77 vec(j) = static_cast<OutType>(dim_size); 78 } 79 } 80 } 81 82 bool IsExpensive() override { return false; } 83}; 84 85class RankOp : public OpKernel { 86 public: 87 explicit RankOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 88 89 void Compute(OpKernelContext* ctx) override { 90 const Tensor& inp = ctx->input(0); 91 const int rank = inp.dims(); 92 Tensor* out = nullptr; 93 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out)); 94 out->scalar<int32>()() = rank; 95 } 96 97 bool IsExpensive() override { return false; } 98}; 99 100template <typename OutType> 101class SizeOp : public OpKernel { 102 public: 103 explicit SizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 104 105 void Compute(OpKernelContext* ctx) override { 106 const Tensor& inp = ctx->input(0); 107 const int64 size = inp.NumElements(); 108 Tensor* out = nullptr; 109 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out)); 110 if (out->dtype() == DT_INT32) { 111 OP_REQUIRES( 112 ctx, FastBoundsCheck(size, std::numeric_limits<int32>::max()), 113 errors::InvalidArgument("Number of elements was larger than " 114 "representable by 32-bit output type")); 115 } 116 out->scalar<OutType>()() = static_cast<OutType>(size); 117 } 118 119 bool IsExpensive() override { return false; } 120}; 121 122class ExpandDimsOp : public OpKernel { 123 public: 124 explicit ExpandDimsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 125 126 void Compute(OpKernelContext* ctx) override { 127 int32 dim = ctx->input(1).flat<int32>()(0); 128 OP_REQUIRES( 129 ctx, (dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims()), 130 errors::InvalidArgument("Tried to expand dim index ", dim, 131 " for tensor with ", ctx->input(0).dims(), 132 " dimensions.")); 133 134 auto existing_dims = ctx->input(0).shape().dim_sizes(); 135 // Safe - # elements in tensor dims bounded. 136 const int existing_dims_size = static_cast<int>(existing_dims.size()); 137 std::vector<int64> new_shape(existing_dims_size); 138 for (size_t i = 0; i < new_shape.size(); ++i) { 139 new_shape[i] = existing_dims[i]; 140 } 141 142 // We emulate numpy's interpretation of the dim axis when 143 // -input.dims() >= dim <= input.dims(). 144 if (dim < 0) { 145 dim += existing_dims.size() + 1; 146 } 147 148 // Clamp to the end if needed. 149 dim = std::min<int32>(dim, existing_dims_size); 150 new_shape.emplace(new_shape.begin() + dim, 1); 151 const TensorShape output_shape(new_shape); 152 153 Tensor* output = nullptr; 154 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output)); 155 if (!output->CopyFrom(ctx->input(0), output_shape)) { 156 // This should never happen, since the sizes of the input and output 157 // should always be the same (we only expand the dimension with 1). 158 ctx->SetStatus( 159 errors::Internal("Could not expand dimension with input shape ", 160 ctx->input(0).shape().DebugString(), 161 " and output shape ", output_shape.DebugString())); 162 } 163 } 164 165 bool IsExpensive() override { return false; } 166}; 167 168class SqueezeOp : public OpKernel { 169 public: 170 explicit SqueezeOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 171 std::vector<int32> squeeze_dims; 172 OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims)); 173 squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end()); 174 } 175 176 void Compute(OpKernelContext* ctx) override { 177 auto existing_dims = ctx->input(0).shape().dim_sizes(); 178 const int existing_dims_size = static_cast<int>(existing_dims.size()); 179 std::vector<int64> new_shape; 180 181 std::unordered_set<int32> wrapped_squeeze_dims; 182 wrapped_squeeze_dims.reserve(squeeze_dims_.size()); 183 // Validate squeeze dims against the input. 184 for (int32 dim : squeeze_dims_) { 185 OP_REQUIRES( 186 ctx, (dim >= -ctx->input(0).dims() && dim < ctx->input(0).dims()), 187 errors::InvalidArgument("Tried to squeeze dim index ", dim, 188 " for tensor with ", ctx->input(0).dims(), 189 " dimensions.")); 190 // If dim is < 0, we wrap around (-1 means the last element). 191 if (dim < 0) { 192 dim = existing_dims_size + dim; 193 } 194 195 wrapped_squeeze_dims.insert(dim); 196 } 197 198 for (int i = 0; i < existing_dims_size; ++i) { 199 auto existing_dim = existing_dims[i]; 200 201 // If squeeze_set is non-empty, only squeeze those dimensions. 202 if (!wrapped_squeeze_dims.empty()) { 203 if (wrapped_squeeze_dims.count(i) > 0) { 204 OP_REQUIRES(ctx, existing_dim == 1, 205 errors::InvalidArgument( 206 "Tried to explicitly squeeze " 207 "dimension ", 208 i, " but dimension was not 1: ", existing_dim)); 209 } else { 210 // This dimension is not being squeezed. 211 new_shape.push_back(existing_dim); 212 } 213 } else { 214 // Copy over all non-1-length dimensions. 215 if (existing_dim != 1) { 216 new_shape.push_back(existing_dim); 217 } 218 } 219 } 220 221 const TensorShape output_shape(new_shape); 222 Tensor* output = nullptr; 223 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output)); 224 if (!output->CopyFrom(ctx->input(0), output_shape)) { 225 // This should never happen, since the sizes of the input and 226 // output should always be the same. 227 ctx->SetStatus(errors::Internal("Could not squeeze input with shape ", 228 ctx->input(0).shape().DebugString(), 229 " and output shape ", 230 output_shape.DebugString())); 231 } 232 } 233 234 bool IsExpensive() override { return false; } 235 236 private: 237 std::unordered_set<int32> squeeze_dims_; 238}; 239 240} // namespace tensorflow 241 242#endif // TENSORFLOW_KERNELS_SHAPE_OPS_H_ 243