shape_ops.h revision b06256e58108dd6596f2410ccd359179bd98bd0a
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_SHAPE_OPS_H_
#define TENSORFLOW_KERNELS_SHAPE_OPS_H_

#include <limits>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/bounds_check.h"

namespace tensorflow {

template <typename OutType>
class ShapeOp : public OpKernel {
 public:
  explicit ShapeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    const int rank = inp.dims();
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({rank}), &out));
    auto vec = out->vec<OutType>();
    for (int i = 0; i < rank; ++i) {
      int64 dim_size = inp.dim_size(i);
      if (out->dtype() == DT_INT32) {
        OP_REQUIRES(
            ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
            errors::InvalidArgument("Shape output type is 32-bit but dim ", i,
                                    " is ", dim_size));
      }
      vec(i) = static_cast<OutType>(dim_size);
    }
  }

  bool IsExpensive() override { return false; }
};
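
// A minimal registration sketch for the ShapeOp kernel above (an assumption,
// not part of this header; the real registrations live in shape_ops.cc and
// also constrain the input type T). Binding ShapeOp<int32> as the CPU kernel
// for the "Shape" op would look roughly like:
//
//   REGISTER_KERNEL_BUILDER(Name("Shape")
//                               .Device(DEVICE_CPU)
//                               .HostMemory("output")
//                               .TypeConstraint<int32>("out_type"),
//                           ShapeOp<int32>);
//
// e.g. a float input of shape [2, 3, 5] produces the 1-D int32 output
// {2, 3, 5}.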

template <typename OutType>
class ShapeNOp : public OpKernel {
 public:
  explicit ShapeNOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      auto shape = ctx->input(i).shape();
      const int dims = shape.dims();
      Tensor* out = nullptr;
      OP_REQUIRES_OK(ctx, ctx->allocate_output(i, {dims}, &out));
      auto vec = out->vec<OutType>();

      for (int j = 0; j < dims; ++j) {
        int64 dim_size = shape.dim_size(j);
        if (out->dtype() == DT_INT32) {
          OP_REQUIRES(
              ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
              errors::InvalidArgument("ShapeN output type is 32-bit but shape ",
                                      i, " dim ", j, " is ", dim_size));
        }
        vec(j) = static_cast<OutType>(dim_size);
      }
    }
  }

  bool IsExpensive() override { return false; }
};
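
// ShapeNOp above is the multi-input variant: output i holds the shape of
// input i. A hedged sketch of a 64-bit output registration (again an
// assumption; see shape_ops.cc for the actual kernel registrations):
//
//   REGISTER_KERNEL_BUILDER(Name("ShapeN")
//                               .Device(DEVICE_CPU)
//                               .HostMemory("output")
//                               .TypeConstraint<int64>("out_type"),
//                           ShapeNOp<int64>);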

class RankOp : public OpKernel {
 public:
  explicit RankOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    const int rank = inp.dims();
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
    out->scalar<int32>()() = rank;
  }

  bool IsExpensive() override { return false; }
};
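
// RankOp above always writes a scalar int32, so there is no output-type
// template parameter. A registration sketch (assumed; the real registrations
// in shape_ops.cc also constrain T per dtype):
//
//   REGISTER_KERNEL_BUILDER(
//       Name("Rank").Device(DEVICE_CPU).HostMemory("output"), RankOp);
//
// e.g. an input of shape [2, 3, 5] yields the scalar 3.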

template <typename OutType>
class SizeOp : public OpKernel {
 public:
  explicit SizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    const int64 size = inp.NumElements();
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
    if (out->dtype() == DT_INT32) {
      OP_REQUIRES(
          ctx, FastBoundsCheck(size, std::numeric_limits<int32>::max()),
          errors::InvalidArgument("Number of elements was larger than "
                                  "representable by 32-bit output type"));
    }
    out->scalar<OutType>()() = static_cast<OutType>(size);
  }

  bool IsExpensive() override { return false; }
};
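
// SizeOp above returns the total element count, and the int32 variant rejects
// counts that do not fit in a 32-bit output, so very large tensors should
// request the int64 output type. A registration sketch for that variant (an
// assumption; see shape_ops.cc for the real ones):
//
//   REGISTER_KERNEL_BUILDER(Name("Size")
//                               .Device(DEVICE_CPU)
//                               .HostMemory("output")
//                               .TypeConstraint<int64>("out_type"),
//                           SizeOp<int64>);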

class ExpandDimsOp : public OpKernel {
 public:
  explicit ExpandDimsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    int32 dim = ctx->input(1).flat<int32>()(0);
    OP_REQUIRES(
        ctx, (dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims()),
        errors::InvalidArgument("Tried to expand dim index ", dim,
                                " for tensor with ", ctx->input(0).dims(),
                                " dimensions."));

    auto existing_dims = ctx->input(0).shape().dim_sizes();
    // Safe: the number of dimensions of a tensor is small, so it fits in int.
    const int existing_dims_size = static_cast<int>(existing_dims.size());
    std::vector<int64> new_shape(existing_dims_size);
    for (size_t i = 0; i < new_shape.size(); ++i) {
      new_shape[i] = existing_dims[i];
    }

    // We emulate numpy's interpretation of the dim axis when
    // -1 - input.dims() <= dim <= input.dims().
    if (dim < 0) {
      dim += existing_dims.size() + 1;
    }

    // Clamp to the end if needed.
    dim = std::min<int32>(dim, existing_dims_size);
    new_shape.emplace(new_shape.begin() + dim, 1);
    const TensorShape output_shape(new_shape);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output));
    if (!output->CopyFrom(ctx->input(0), output_shape)) {
      // This should never happen, since the sizes of the input and output
      // should always be the same (we only expand the dimension with 1).
      ctx->SetStatus(
          errors::Internal("Could not expand dimension with input shape ",
                           ctx->input(0).shape().DebugString(),
                           " and output shape ", output_shape.DebugString()));
    }
  }

  bool IsExpensive() override { return false; }
};
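
// A sketch of exercising the dim-wrapping logic of ExpandDimsOp above with the
// OpsTestBase fixture from tensorflow/core/kernels/ops_testutil.h (the test
// fixture and test name are hypothetical, not existing tests):
//
//   class ExpandDimsOpTest : public OpsTestBase {};
//
//   TEST_F(ExpandDimsOpTest, NegativeAxis) {
//     TF_ASSERT_OK(NodeDefBuilder("expand_dims", "ExpandDims")
//                      .Input(FakeInput(DT_FLOAT))
//                      .Input(FakeInput(DT_INT32))
//                      .Finalize(node_def()));
//     TF_ASSERT_OK(InitOp());
//     AddInputFromArray<float>(TensorShape({2, 3}), {1, 2, 3, 4, 5, 6});
//     AddInputFromArray<int32>(TensorShape({}), {-1});  // dim = -1
//     TF_ASSERT_OK(RunOpKernel());
//     // dim == -1 wraps to 2, so shape [2, 3] becomes [2, 3, 1].
//     EXPECT_EQ(TensorShape({2, 3, 1}), GetOutput(0)->shape());
//   }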

class SqueezeOp : public OpKernel {
 public:
  explicit SqueezeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    std::vector<int32> squeeze_dims;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims));
    squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end());
  }

  void Compute(OpKernelContext* ctx) override {
    auto existing_dims = ctx->input(0).shape().dim_sizes();
    const int existing_dims_size = static_cast<int>(existing_dims.size());
    std::vector<int64> new_shape;

    std::unordered_set<int32> wrapped_squeeze_dims;
    wrapped_squeeze_dims.reserve(squeeze_dims_.size());
    // Validate squeeze dims against the input.
    for (int32 dim : squeeze_dims_) {
      OP_REQUIRES(
          ctx, (dim >= -ctx->input(0).dims() && dim < ctx->input(0).dims()),
          errors::InvalidArgument("Tried to squeeze dim index ", dim,
                                  " for tensor with ", ctx->input(0).dims(),
                                  " dimensions."));
      // If dim is < 0, we wrap around (-1 means the last element).
      if (dim < 0) {
        dim = existing_dims_size + dim;
      }

      wrapped_squeeze_dims.insert(dim);
    }

    for (int i = 0; i < existing_dims_size; ++i) {
      auto existing_dim = existing_dims[i];

      // If wrapped_squeeze_dims is non-empty, only squeeze those dimensions.
      if (!wrapped_squeeze_dims.empty()) {
        if (wrapped_squeeze_dims.count(i) > 0) {
          OP_REQUIRES(ctx, existing_dim == 1,
                      errors::InvalidArgument(
                          "Tried to explicitly squeeze "
                          "dimension ",
                          i, " but dimension was not 1: ", existing_dim));
        } else {
          // This dimension is not being squeezed.
          new_shape.push_back(existing_dim);
        }
      } else {
        // Copy over all non-1-length dimensions.
        if (existing_dim != 1) {
          new_shape.push_back(existing_dim);
        }
      }
    }

    const TensorShape output_shape(new_shape);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output));
    if (!output->CopyFrom(ctx->input(0), output_shape)) {
      // This should never happen, since the sizes of the input and
      // output should always be the same.
      ctx->SetStatus(errors::Internal("Could not squeeze input with shape ",
                                      ctx->input(0).shape().DebugString(),
                                      " and output shape ",
                                      output_shape.DebugString()));
    }
  }

  bool IsExpensive() override { return false; }

 private:
  std::unordered_set<int32> squeeze_dims_;
};
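
// A similar OpsTestBase-style sketch (hypothetical) for SqueezeOp above with
// an explicit squeeze_dims attribute; only the listed dimensions are removed,
// and each must have size 1:
//
//   class SqueezeOpTest : public OpsTestBase {};
//
//   TEST_F(SqueezeOpTest, ExplicitDims) {
//     TF_ASSERT_OK(NodeDefBuilder("squeeze", "Squeeze")
//                      .Input(FakeInput(DT_FLOAT))
//                      .Attr("squeeze_dims", std::vector<int32>{0})
//                      .Finalize(node_def()));
//     TF_ASSERT_OK(InitOp());
//     AddInputFromArray<float>(TensorShape({1, 3, 1}), {1, 2, 3});
//     TF_ASSERT_OK(RunOpKernel());
//     // Only dim 0 was requested, so the trailing size-1 dim is kept.
//     EXPECT_EQ(TensorShape({3, 1}), GetOutput(0)->shape());
//   }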

}  // namespace tensorflow

#endif  // TENSORFLOW_KERNELS_SHAPE_OPS_H_