/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_SHAPE_OPS_H_
#define TENSORFLOW_KERNELS_SHAPE_OPS_H_

#include <limits>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/bounds_check.h"

namespace tensorflow {

namespace shape_op_helpers {
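// Fetches the shape of the tensor at `input_index`. DT_VARIANT inputs must
// be scalar (unary) variants, whose shape comes from the registered
// unary-variant shape function; all other dtypes report the tensor's own
// shape.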
inline Status GetRegularOrVariantShape(OpKernelContext* ctx, int input_index,
                                       TensorShape* shape) {
  const Tensor& inp = ctx->input(input_index);
  if (ctx->input_dtype(input_index) == DT_VARIANT) {
    if (inp.dims() != 0) {
      return errors::InvalidArgument(
          "Shape of non-unary Variant not supported.");
    }
    TF_RETURN_IF_ERROR(GetUnaryVariantShape(inp, shape));
  } else {
    *shape = inp.shape();
  }
  return Status::OK();
}
}  // namespace shape_op_helpers

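// Emits the shape of its input as a 1-D vector of OutType (int32 or int64).
// For example, a float tensor of shape [2, 3, 5] produces the vector
// [2, 3, 5].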
template <typename OutType>
class ShapeOp : public OpKernel {
 public:
  explicit ShapeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx,
                   shape_op_helpers::GetRegularOrVariantShape(ctx, 0, &shape));
    const int rank = shape.dims();
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({rank}), &out));
    auto vec = out->vec<OutType>();
    for (int i = 0; i < rank; ++i) {
      int64 dim_size = shape.dim_size(i);
      if (out->dtype() == DT_INT32) {
        OP_REQUIRES(
            ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
            errors::InvalidArgument("Shape output type is 32-bit but dim ", i,
                                    " is ", dim_size));
      }
      vec(i) = static_cast<OutType>(dim_size);
    }
  }

  bool IsExpensive() override { return false; }
};

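// N-input variant of ShapeOp: output i holds the shape of input i, again as
// a 1-D vector of OutType.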
template <typename OutType>
class ShapeNOp : public OpKernel {
 public:
  explicit ShapeNOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      TensorShape shape;
      OP_REQUIRES_OK(
          ctx, shape_op_helpers::GetRegularOrVariantShape(ctx, i, &shape));
      const int dims = shape.dims();
      Tensor* out = nullptr;
      OP_REQUIRES_OK(ctx, ctx->allocate_output(i, {dims}, &out));
      auto vec = out->vec<OutType>();

      for (int j = 0; j < dims; ++j) {
        int64 dim_size = shape.dim_size(j);
        if (out->dtype() == DT_INT32) {
          OP_REQUIRES(
              ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
              errors::InvalidArgument("ShapeN output type is 32-bit but shape ",
                                      i, " dim ", j, " is ", dim_size));
        }
        vec(j) = static_cast<OutType>(dim_size);
      }
    }
  }

  bool IsExpensive() override { return false; }
};

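// Emits the rank (number of dimensions) of its input as a scalar int32; a
// [2, 3, 5] tensor has rank 3, a scalar has rank 0.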
class RankOp : public OpKernel {
 public:
  explicit RankOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx,
                   shape_op_helpers::GetRegularOrVariantShape(ctx, 0, &shape));
    const int rank = shape.dims();
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
    out->scalar<int32>()() = rank;
  }

  bool IsExpensive() override { return false; }
};

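// Emits the total number of elements in its input as a scalar OutType; a
// [2, 3, 5] tensor has size 2 * 3 * 5 = 30.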
template <typename OutType>
class SizeOp : public OpKernel {
 public:
  explicit SizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx,
                   shape_op_helpers::GetRegularOrVariantShape(ctx, 0, &shape));
    const int64 size = shape.num_elements();
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
    if (out->dtype() == DT_INT32) {
      OP_REQUIRES(
          ctx, FastBoundsCheck(size, std::numeric_limits<int32>::max()),
          errors::InvalidArgument("Number of elements was larger than "
                                  "representable by 32-bit output type"));
    }
    out->scalar<OutType>()() = static_cast<OutType>(size);
  }

  bool IsExpensive() override { return false; }
};

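// Inserts a size-1 dimension at position `dim` (the op's second input).
// Negative indices count from the back, numpy-style: expanding a [2, 3]
// tensor at dim 0 (or -3) yields [1, 2, 3]; at dim -1, [2, 3, 1].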
template <typename Tdim>
class ExpandDimsOp : public OpKernel {
 public:
  explicit ExpandDimsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT,
                errors::InvalidArgument("ExpandDims on Variant not supported"));

    Tdim dim = ctx->input(1).flat<Tdim>()(0);
    OP_REQUIRES(
        ctx, (dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims()),
        errors::InvalidArgument("Tried to expand dim index ", dim,
                                " for tensor with ", ctx->input(0).dims(),
                                " dimensions."));

    auto existing_dims = ctx->input(0).shape().dim_sizes();
    // Safe - # elements in tensor dims bounded.
    const int existing_dims_size = static_cast<int>(existing_dims.size());
    std::vector<int64> new_shape(existing_dims_size);
    for (size_t i = 0; i < new_shape.size(); ++i) {
      new_shape[i] = existing_dims[i];
    }

    // We emulate numpy's interpretation of the dim axis when
    // -1 - input.dims() <= dim <= input.dims().
    if (dim < 0) {
      dim += existing_dims.size() + 1;
    }

    // Clamp to the end if needed.
    dim = std::min<Tdim>(dim, existing_dims_size);
    new_shape.emplace(new_shape.begin() + dim, 1);
    const TensorShape output_shape(new_shape);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output));
    if (!output->CopyFrom(ctx->input(0), output_shape)) {
      // This should never happen, since the sizes of the input and output
      // should always be the same (we only expand the dimension with 1).
      ctx->SetStatus(
          errors::Internal("Could not expand dimension with input shape ",
                           ctx->input(0).shape().DebugString(),
                           " and output shape ", output_shape.DebugString()));
    }
  }

  bool IsExpensive() override { return false; }
};

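// Removes size-1 dimensions from the input's shape. With an empty
// `squeeze_dims` attr, all size-1 dimensions are dropped ([1, 2, 1, 3]
// becomes [2, 3]); otherwise only the listed dimensions are dropped, and
// each must actually have size 1.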
class SqueezeOp : public OpKernel {
 public:
  explicit SqueezeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    std::vector<int32> squeeze_dims;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims));
    squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end());
  }

  void Compute(OpKernelContext* ctx) override {
    OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT,
                errors::InvalidArgument("Squeeze on Variant not supported"));

    auto existing_dims = ctx->input(0).shape().dim_sizes();
    const int existing_dims_size = static_cast<int>(existing_dims.size());
    std::vector<int64> new_shape;

    std::unordered_set<int32> wrapped_squeeze_dims;
    wrapped_squeeze_dims.reserve(squeeze_dims_.size());
    // Validate squeeze dims against the input.
    for (int32 dim : squeeze_dims_) {
      OP_REQUIRES(
          ctx, (dim >= -ctx->input(0).dims() && dim < ctx->input(0).dims()),
          errors::InvalidArgument("Tried to squeeze dim index ", dim,
                                  " for tensor with ", ctx->input(0).dims(),
                                  " dimensions."));
      // If dim is < 0, we wrap around (-1 means the last element).
      if (dim < 0) {
        dim = existing_dims_size + dim;
      }

      wrapped_squeeze_dims.insert(dim);
    }

    for (int i = 0; i < existing_dims_size; ++i) {
      auto existing_dim = existing_dims[i];

      // If wrapped_squeeze_dims is non-empty, only squeeze those dimensions.
      if (!wrapped_squeeze_dims.empty()) {
        if (wrapped_squeeze_dims.count(i) > 0) {
          OP_REQUIRES(ctx, existing_dim == 1,
                      errors::InvalidArgument(
                          "Tried to explicitly squeeze dimension ", i,
                          " but dimension was not 1: ", existing_dim));
        } else {
          // This dimension is not being squeezed.
          new_shape.push_back(existing_dim);
        }
      } else {
        // Copy over all non-1-length dimensions.
        if (existing_dim != 1) {
          new_shape.push_back(existing_dim);
        }
      }
    }

    const TensorShape output_shape(new_shape);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output));
    if (!output->CopyFrom(ctx->input(0), output_shape)) {
      // This should never happen, since the sizes of the input and
      // output should always be the same.
      ctx->SetStatus(errors::Internal("Could not squeeze input with shape ",
                                      ctx->input(0).shape().DebugString(),
                                      " and output shape ",
                                      output_shape.DebugString()));
    }
  }

  bool IsExpensive() override { return false; }

 private:
  std::unordered_set<int32> squeeze_dims_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_KERNELS_SHAPE_OPS_H_