1b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
3b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License");
4b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFloweryou may not use this file except in compliance with the License.
5b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerYou may obtain a copy of the License at
6b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
7b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    http://www.apache.org/licenses/LICENSE-2.0
8b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
9b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software
10b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS,
11b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerSee the License for the specific language governing permissions and
13b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerlimitations under the License.
14b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower==============================================================================*/
15b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
16b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower#ifndef TENSORFLOW_KERNELS_SHAPE_OPS_H_
17b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower#define TENSORFLOW_KERNELS_SHAPE_OPS_H_
18b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
19b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower#include <limits>
20b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower#include <unordered_set>
21b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower#include <vector>
22b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
23b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower#include "tensorflow/core/framework/op_kernel.h"
24b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower#include "tensorflow/core/framework/tensor.h"
25b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower#include "tensorflow/core/framework/tensor_shape.h"
2681ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo#include "tensorflow/core/framework/variant_op_registry.h"
27b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower#include "tensorflow/core/kernels/bounds_check.h"
28b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
29b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowernamespace tensorflow {
30b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
3181ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdonamespace shape_op_helpers {
3281ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdoinline Status GetRegularOrVariantShape(OpKernelContext* ctx, int input_index,
3381ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo                                       TensorShape* shape) {
3481ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo  const Tensor& inp = ctx->input(input_index);
3581ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo  if (ctx->input_dtype(0) == DT_VARIANT) {
3681ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo    if (inp.dims() != 0) {
3781ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo      return errors::InvalidArgument(
380acf5bb38a8f208c6d9f048579a076d5bc6ff0beEugene Brevdo          "Shape of non-unary Variant not supported.");
3981ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo    }
4081ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo    TF_RETURN_IF_ERROR(GetUnaryVariantShape(inp, shape));
4181ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo  } else {
4281ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo    *shape = inp.shape();
4381ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo  }
4481ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo  return Status::OK();
4581ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo}
4681ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo}  // namespace shape_op_helpers
4781ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo
48b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowertemplate <typename OutType>
49b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerclass ShapeOp : public OpKernel {
50b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower public:
51b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  explicit ShapeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
52b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
53b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  void Compute(OpKernelContext* ctx) override {
5481ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo    TensorShape shape;
5581ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo    OP_REQUIRES_OK(ctx,
5681ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo                   shape_op_helpers::GetRegularOrVariantShape(ctx, 0, &shape));
5781ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo    const int rank = shape.dims();
58b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    Tensor* out = nullptr;
59b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({rank}), &out));
60b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    auto vec = out->vec<OutType>();
61b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    for (int i = 0; i < rank; ++i) {
6281ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo      int64 dim_size = shape.dim_size(i);
63b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      if (out->dtype() == DT_INT32) {
64b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower        OP_REQUIRES(
65b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower            ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
66b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower            errors::InvalidArgument("Shape output type is 32-bit ", " but dim ",
67b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower                                    i, " is ", dim_size));
68b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      }
69b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      vec(i) = static_cast<OutType>(dim_size);
70b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    }
71b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  }
72b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
73b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  bool IsExpensive() override { return false; }
74b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower};
75b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
76b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowertemplate <typename OutType>
77b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerclass ShapeNOp : public OpKernel {
78b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower public:
79b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  explicit ShapeNOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
80b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
81b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  void Compute(OpKernelContext* ctx) override {
82b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    for (int i = 0; i < ctx->num_inputs(); ++i) {
8381ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo      TensorShape shape;
8481ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo      OP_REQUIRES_OK(
8581ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo          ctx, shape_op_helpers::GetRegularOrVariantShape(ctx, i, &shape));
86b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      const int dims = shape.dims();
87b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      Tensor* out = nullptr;
88b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      OP_REQUIRES_OK(ctx, ctx->allocate_output(i, {dims}, &out));
89b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      auto vec = out->vec<OutType>();
90b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
91b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      for (int j = 0; j < dims; ++j) {
92b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower        int64 dim_size = shape.dim_size(j);
93b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower        if (out->dtype() == DT_INT32) {
94b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower          OP_REQUIRES(
95b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower              ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
96b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower              errors::InvalidArgument("ShapeN output type is 32-bit but shape ",
97b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower                                      i, " dim ", j, " is ", dim_size));
98b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower        }
99b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower        vec(j) = static_cast<OutType>(dim_size);
100b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      }
101b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    }
102b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  }
103b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
104b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  bool IsExpensive() override { return false; }
105b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower};
106b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
107b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerclass RankOp : public OpKernel {
108b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower public:
109b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  explicit RankOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
110b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
111b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  void Compute(OpKernelContext* ctx) override {
11281ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo    TensorShape shape;
11381ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo    OP_REQUIRES_OK(ctx,
11481ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo                   shape_op_helpers::GetRegularOrVariantShape(ctx, 0, &shape));
11581ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo    const int rank = shape.dims();
116b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    Tensor* out = nullptr;
117b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
118b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    out->scalar<int32>()() = rank;
119b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  }
120b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
121b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  bool IsExpensive() override { return false; }
122b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower};
123b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
124b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowertemplate <typename OutType>
125b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerclass SizeOp : public OpKernel {
126b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower public:
127b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  explicit SizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
128b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
129b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  void Compute(OpKernelContext* ctx) override {
13081ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo    TensorShape shape;
13181ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo    OP_REQUIRES_OK(ctx,
13281ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo                   shape_op_helpers::GetRegularOrVariantShape(ctx, 0, &shape));
13381ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo    const int64 size = shape.num_elements();
134b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    Tensor* out = nullptr;
135b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
136b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    if (out->dtype() == DT_INT32) {
137b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      OP_REQUIRES(
138b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower          ctx, FastBoundsCheck(size, std::numeric_limits<int32>::max()),
139b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower          errors::InvalidArgument("Number of elements was larger than "
140b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower                                  "representable by 32-bit output type"));
141b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    }
142b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    out->scalar<OutType>()() = static_cast<OutType>(size);
143b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  }
144b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
145b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  bool IsExpensive() override { return false; }
146b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower};
147b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
148d0a5d885d61b837018cb931a4d577289acc826fcMartin Wicketemplate <typename Tdim>
149b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerclass ExpandDimsOp : public OpKernel {
150b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower public:
151b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  explicit ExpandDimsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
152b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
153b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  void Compute(OpKernelContext* ctx) override {
15481ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo    OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT,
15581ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo                errors::InvalidArgument("ExpandDims on Variant not supported"));
15681ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo
157d0a5d885d61b837018cb931a4d577289acc826fcMartin Wicke    Tdim dim = ctx->input(1).flat<Tdim>()(0);
158b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    OP_REQUIRES(
159b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower        ctx, (dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims()),
160b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower        errors::InvalidArgument("Tried to expand dim index ", dim,
161b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower                                " for tensor with ", ctx->input(0).dims(),
162b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower                                " dimensions."));
163b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
164b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    auto existing_dims = ctx->input(0).shape().dim_sizes();
165b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    // Safe - # elements in tensor dims bounded.
166b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    const int existing_dims_size = static_cast<int>(existing_dims.size());
167b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    std::vector<int64> new_shape(existing_dims_size);
168b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    for (size_t i = 0; i < new_shape.size(); ++i) {
169b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      new_shape[i] = existing_dims[i];
170b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    }
171b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
172b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    // We emulate numpy's interpretation of the dim axis when
173b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    // -input.dims() >= dim <= input.dims().
174b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    if (dim < 0) {
175b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      dim += existing_dims.size() + 1;
176b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    }
177b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
178b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    // Clamp to the end if needed.
179d0a5d885d61b837018cb931a4d577289acc826fcMartin Wicke    dim = std::min<Tdim>(dim, existing_dims_size);
180b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    new_shape.emplace(new_shape.begin() + dim, 1);
181b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    const TensorShape output_shape(new_shape);
182b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
183b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    Tensor* output = nullptr;
184b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output));
185b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    if (!output->CopyFrom(ctx->input(0), output_shape)) {
186b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      // This should never happen, since the sizes of the input and output
187b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      // should always be the same (we only expand the dimension with 1).
188b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      ctx->SetStatus(
189b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower          errors::Internal("Could not expand dimension with input shape ",
190b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower                           ctx->input(0).shape().DebugString(),
191b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower                           " and output shape ", output_shape.DebugString()));
192b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    }
193b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  }
194b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
195b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  bool IsExpensive() override { return false; }
196b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower};
197b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
198b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlowerclass SqueezeOp : public OpKernel {
199b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower public:
200b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  explicit SqueezeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
201b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    std::vector<int32> squeeze_dims;
202b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims));
203b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end());
204b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  }
205b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
206b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  void Compute(OpKernelContext* ctx) override {
20781ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo    OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT,
20881ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo                errors::InvalidArgument("Squeeze on Variant not supported"));
20981ae1c68a46dfe350342c940d88d59eae1e80eeeEugene Brevdo
210b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    auto existing_dims = ctx->input(0).shape().dim_sizes();
211b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    const int existing_dims_size = static_cast<int>(existing_dims.size());
212b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    std::vector<int64> new_shape;
213b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
214b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    std::unordered_set<int32> wrapped_squeeze_dims;
215b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    wrapped_squeeze_dims.reserve(squeeze_dims_.size());
216b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    // Validate squeeze dims against the input.
217b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    for (int32 dim : squeeze_dims_) {
218b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      OP_REQUIRES(
219b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower          ctx, (dim >= -ctx->input(0).dims() && dim < ctx->input(0).dims()),
220b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower          errors::InvalidArgument("Tried to squeeze dim index ", dim,
221b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower                                  " for tensor with ", ctx->input(0).dims(),
222b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower                                  " dimensions."));
223b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      // If dim is < 0, we wrap around (-1 means the last element).
224b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      if (dim < 0) {
225b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower        dim = existing_dims_size + dim;
226b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      }
227b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
228b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      wrapped_squeeze_dims.insert(dim);
229b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    }
230b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
231b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    for (int i = 0; i < existing_dims_size; ++i) {
232b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      auto existing_dim = existing_dims[i];
233b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
234b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      // If squeeze_set is non-empty, only squeeze those dimensions.
235b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      if (!wrapped_squeeze_dims.empty()) {
236b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower        if (wrapped_squeeze_dims.count(i) > 0) {
237b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower          OP_REQUIRES(ctx, existing_dim == 1,
238982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                      errors::InvalidArgument(
239982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                          "Tried to explicitly squeeze "
240982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                          "dimension ",
241982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                          i, " but dimension was not 1: ", existing_dim));
242b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower        } else {
243b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower          // This dimension is not being squeezed.
244b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower          new_shape.push_back(existing_dim);
245b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower        }
246b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      } else {
247b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower        // Copy over all non-1-length dimensions.
248b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower        if (existing_dim != 1) {
249b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower          new_shape.push_back(existing_dim);
250b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower        }
251b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      }
252b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    }
253b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
254b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    const TensorShape output_shape(new_shape);
255b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    Tensor* output = nullptr;
256b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output));
257b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    if (!output->CopyFrom(ctx->input(0), output_shape)) {
258b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      // This should never happen, since the sizes of the input and
259b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      // output should always be the same.
260b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower      ctx->SetStatus(errors::Internal("Could not squeeze input with shape ",
261b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower                                      ctx->input(0).shape().DebugString(),
262b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower                                      " and output shape ",
263b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower                                      output_shape.DebugString()));
264b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower    }
265b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  }
266b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
267b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  bool IsExpensive() override { return false; }
268b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
269b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower private:
270b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower  std::unordered_set<int32> squeeze_dims_;
271b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower};
272b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
273b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower}  // namespace tensorflow
274b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower
275b06256e58108dd6596f2410ccd359179bd98bd0aA. Unique TensorFlower#endif  // TENSORFLOW_KERNELS_SHAPE_OPS_H_
276