1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// XLA-specific Shape Ops.
17
18#include "tensorflow/compiler/tf2xla/kernels/shape_util.h"
19#include "tensorflow/compiler/tf2xla/type_util.h"
20#include "tensorflow/compiler/tf2xla/xla_helpers.h"
21#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
22#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
23#include "tensorflow/core/framework/kernel_def_builder.h"
24#include "tensorflow/core/kernels/bounds_check.h"
25
26namespace tensorflow {
27namespace {
28
29class ShapeOp : public XlaOpKernel {
30 public:
31  explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
32    OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
33  }
34
35  void Compile(XlaOpKernelContext* ctx) override {
36    const TensorShape input_shape = ctx->InputShape(0);
37    Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()}));
38    OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant));
39    ctx->SetConstantOutput(0, shape_constant);
40  }
41
42 private:
43  DataType out_dtype_;
44};
45
46REGISTER_XLA_OP(Name("Shape"), ShapeOp);
47
48class ShapeNOp : public XlaOpKernel {
49 public:
50  explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
51    OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
52  }
53
54  void Compile(XlaOpKernelContext* ctx) override {
55    for (int i = 0; i < ctx->num_inputs(); ++i) {
56      const TensorShape input_shape = ctx->InputShape(i);
57      Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()}));
58      OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant));
59      ctx->SetConstantOutput(i, shape_constant);
60    }
61  }
62
63  bool IsExpensive() override { return false; }
64
65 private:
66  DataType out_dtype_;
67};
68REGISTER_XLA_OP(Name("ShapeN"), ShapeNOp);
69
70class RankOp : public XlaOpKernel {
71 public:
72  explicit RankOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
73
74  void Compile(XlaOpKernelContext* ctx) override {
75    const TensorShape input_shape = ctx->InputShape(0);
76    const int rank = input_shape.dims();
77    Tensor rank_constant(DT_INT32, TensorShape({}));
78    rank_constant.scalar<int32>()() = rank;
79
80    ctx->SetConstantOutput(0, rank_constant);
81  }
82};
83
84REGISTER_XLA_OP(Name("Rank"), RankOp);
85
86class SizeOp : public XlaOpKernel {
87 public:
88  explicit SizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
89
90  void Compile(XlaOpKernelContext* ctx) override {
91    const TensorShape input_shape = ctx->InputShape(0);
92    const int64 size = input_shape.num_elements();
93    OP_REQUIRES(ctx, FastBoundsCheck(size, std::numeric_limits<int32>::max()),
94                errors::InvalidArgument("Size does not work for tensors > "
95                                        "int32 max."));
96    Tensor size_constant(DT_INT32, TensorShape({}));
97    size_constant.scalar<int32>()() = static_cast<int32>(size);
98
99    ctx->SetConstantOutput(0, size_constant);
100  }
101};
102
103REGISTER_XLA_OP(Name("Size"), SizeOp);
104
105class ExpandDimsOp : public XlaOpKernel {
106 public:
107  explicit ExpandDimsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
108
109  void Compile(XlaOpKernelContext* ctx) override {
110    const TensorShape input_shape = ctx->InputShape(0);
111    const TensorShape dim_shape = ctx->InputShape(1);
112
113    // TODO(phawkins): the standard implementation of ExpandDimsOp seems to
114    // accept legacy scalars, even when they should be forbidden by the graphdef
115    // version.
116    OP_REQUIRES(ctx, dim_shape.num_elements() == 1,
117                errors::InvalidArgument(strings::StrCat(
118                    "dim input to ExpandDims must be a scalar; got ",
119                    dim_shape.DebugString())));
120
121    xla::Literal literal;
122    OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {1}, &literal));
123
124    int dim = literal.data<int32>()[0];
125
126    OP_REQUIRES(ctx,
127                (dim >= -1 - input_shape.dims() && dim <= input_shape.dims()),
128                errors::InvalidArgument("Tried to expand dim index ", dim,
129                                        " for tensor with ", input_shape.dims(),
130                                        " dimensions."));
131
132    auto existing_dims = input_shape.dim_sizes();
133    // Safe - # elements in tensor dims bounded.
134    const int existing_dims_size = static_cast<int>(existing_dims.size());
135    std::vector<int64> new_shape(existing_dims_size);
136    for (size_t i = 0; i < new_shape.size(); ++i) {
137      new_shape[i] = existing_dims[i];
138    }
139
140    // We emulate numpy's interpretation of the dim axis when
141    // -input.dims() >= dim <= input.dims().
142    if (dim < 0) {
143      dim += existing_dims.size() + 1;
144    }
145
146    // Clamp to the end if needed.
147    dim = std::min<int32>(dim, existing_dims_size);
148    new_shape.emplace(new_shape.begin() + dim, 1);
149
150    ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape));
151  }
152};
153REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstInput("dim"), ExpandDimsOp);
154
155class SqueezeOp : public XlaOpKernel {
156 public:
157  explicit SqueezeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
158    std::vector<int32> squeeze_dims;
159    OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims));
160    squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end());
161  }
162
163  void Compile(XlaOpKernelContext* ctx) override {
164    const TensorShape input_shape = ctx->InputShape(0);
165    auto existing_dims = input_shape.dim_sizes();
166    int existing_dims_size = input_shape.dims();
167    std::vector<int64> new_shape;
168
169    std::unordered_set<int32> wrapped_squeeze_dims;
170    wrapped_squeeze_dims.reserve(squeeze_dims_.size());
171    // Validate squeeze dims against the input.
172    for (int32 dim : squeeze_dims_) {
173      OP_REQUIRES(ctx, (dim >= -input_shape.dims() && dim < input_shape.dims()),
174                  errors::InvalidArgument("Tried to squeeze dim index ", dim,
175                                          " for tensor with ",
176                                          input_shape.dims(), " dimensions."));
177      // If dim is < 0, we wrap around (-1 means the last element).
178      if (dim < 0) {
179        dim = existing_dims_size + dim;
180      }
181
182      wrapped_squeeze_dims.insert(dim);
183    }
184
185    for (int i = 0; i < existing_dims_size; ++i) {
186      auto existing_dim = existing_dims[i];
187
188      // If squeeze_set is non-empty, only squeeze those dimensions.
189      if (!wrapped_squeeze_dims.empty()) {
190        if (wrapped_squeeze_dims.count(i) > 0) {
191          OP_REQUIRES(ctx, existing_dim == 1,
192                      errors::InvalidArgument("Tried to explicitly squeeze "
193                                              "dimension ",
194                                              i, " but dimension was not 1: ",
195                                              existing_dim));
196        } else {
197          // This dimension is not being squeezed.
198          new_shape.push_back(existing_dim);
199        }
200      } else {
201        // Copy over all non-1-length dimensions.
202        if (existing_dim != 1) {
203          new_shape.push_back(existing_dim);
204        }
205      }
206    }
207
208    ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape));
209  }
210
211 private:
212  std::unordered_set<int32> squeeze_dims_;
213};
214
215REGISTER_XLA_OP(Name("Squeeze"), SqueezeOp);
216
217class ZerosLikeOp : public XlaOpKernel {
218 public:
219  explicit ZerosLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
220
221  void Compile(XlaOpKernelContext* ctx) override {
222    const TensorShape input_shape = ctx->InputShape(0);
223
224    auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
225    ctx->SetOutput(0, ctx->builder()->Broadcast(zero, input_shape.dim_sizes()));
226  }
227};
228
229REGISTER_XLA_OP(Name("ZerosLike"), ZerosLikeOp);
230
231class OnesLikeOp : public XlaOpKernel {
232 public:
233  explicit OnesLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
234
235  void Compile(XlaOpKernelContext* ctx) override {
236    const TensorShape input_shape = ctx->InputShape(0);
237
238    auto one = XlaHelpers::One(ctx->builder(), input_type(0));
239    ctx->SetOutput(0, ctx->builder()->Broadcast(one, input_shape.dim_sizes()));
240  }
241};
242
243REGISTER_XLA_OP(Name("OnesLike"), OnesLikeOp);
244
245}  // namespace
246}  // namespace tensorflow
247