1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/tf2xla/type_util.h"
17#include "tensorflow/compiler/tf2xla/xla_helpers.h"
18#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
19#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
20#include "tensorflow/core/util/tensor_format.h"
21
22namespace tensorflow {
23
24namespace {
25
26class ExtractImagePatchesOp : public XlaOpKernel {
27 public:
28  explicit ExtractImagePatchesOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
29    OP_REQUIRES_OK(ctx, ctx->GetAttr("ksizes", &ksizes_));
30    OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
31    OP_REQUIRES_OK(ctx, ctx->GetAttr("rates", &dilations_));
32    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
33  }
34
35  void Compile(XlaOpKernelContext* ctx) override {
36    const TensorFormat data_format = FORMAT_NHWC;
37    const int num_dims = ksizes_.size();
38
39    OP_REQUIRES(
40        ctx, num_dims >= 3,
41        errors::InvalidArgument("Kernel size must have at least 3 dimensions"));
42    const int num_spatial_dims = num_dims - 2;
43
44    OP_REQUIRES(ctx, strides_.size() == num_dims,
45                errors::InvalidArgument("Sliding window strides field must "
46                                        "specify ",
47                                        num_dims, " dimensions"));
48    OP_REQUIRES(ctx, dilations_.size() == num_dims,
49                errors::InvalidArgument("Dilations field must "
50                                        "specify ",
51                                        num_dims, " dimensions"));
52
53    int batch_dim = GetTensorBatchDimIndex(num_dims, data_format);
54    int feature_dim = GetTensorFeatureDimIndex(num_dims, data_format);
55    OP_REQUIRES(
56        ctx, ksizes_[batch_dim] == 1 && ksizes_[feature_dim] == 1,
57        errors::Unimplemented("Current implementation does not yet support "
58                              "kernel sizes > 1 in the batch and depth "
59                              "dimensions."));
60    OP_REQUIRES(
61        ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
62        errors::Unimplemented("Current implementation does not yet support "
63                              "strides in the batch and depth dimensions."));
64    OP_REQUIRES(
65        ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
66        errors::Unimplemented("Current implementation does not support "
67                              "dilations in the batch and depth dimensions."));
68
69    for (int i = 0; i < num_spatial_dims; ++i) {
70      int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
71      OP_REQUIRES(
72          ctx, ksizes_[input_dim] >= 0,
73          errors::Unimplemented("Kernel size values must be non-negative; ", i,
74                                "th spatial dimension had dilation ",
75                                dilations_[input_dim]));
76      OP_REQUIRES(ctx, strides_[input_dim] >= 1,
77                  errors::Unimplemented("Stride values must be positive; ", i,
78                                        "th spatial dimension had dilation ",
79                                        dilations_[input_dim]));
80      OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
81                  errors::Unimplemented("Dilation values must be positive; ", i,
82                                        "th spatial dimension had dilation ",
83                                        dilations_[input_dim]));
84    }
85
86    xla::PrimitiveType type;
87    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(0), &type));
88
89    const TensorShape input_shape = ctx->InputShape(0);
90    OP_REQUIRES(
91        ctx, input_shape.dims() == num_dims,
92        errors::InvalidArgument("input must be ", num_dims, "-dimensional",
93                                input_shape.DebugString()));
94    const int64 depth = input_shape.dim_size(feature_dim);
95
96    xla::ComputationBuilder* builder = ctx->builder();
97
98    // The following code is equivalent to:
99    // eye = np.eye(kH * kW * D).reshape([kH, kW, D, kH * kW * kD])
100    int64 kernel_size = 1;
101    std::vector<int64> lhs_shape(num_dims, 1);
102    for (int i = 0; i < num_spatial_dims; ++i) {
103      int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
104      lhs_shape[i] = ksizes_[input_dim];
105      kernel_size *= ksizes_[input_dim];
106    }
107    lhs_shape[num_spatial_dims] = depth;
108    lhs_shape[num_spatial_dims + 1] = 1;
109
110    // Builds an identity matrix as a broadcast equality of iotas.
111    // iota = np.arange(np.prod(ksize), depth)
112    // filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32)
113    xla::ComputationDataHandle iota;
114    TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32,
115                                 kernel_size * depth, &iota));
116
117    auto lhs = builder->Reshape(iota, lhs_shape);
118    auto filter = builder->ConvertElementType(
119        builder->Eq(lhs, iota, {num_spatial_dims + 1}), type);
120
121    xla::ConvolutionDimensionNumbers dims;
122    std::vector<int64> window_strides(num_spatial_dims);
123    std::vector<int64> lhs_dilation(num_spatial_dims, 1);
124    std::vector<int64> rhs_dilation(num_spatial_dims);
125    std::vector<std::pair<int64, int64>> padding(num_spatial_dims);
126
127    dims.set_input_batch_dimension(batch_dim);
128    dims.set_output_batch_dimension(batch_dim);
129    dims.set_input_feature_dimension(feature_dim);
130    dims.set_output_feature_dimension(feature_dim);
131    dims.set_kernel_input_feature_dimension(num_spatial_dims);
132    dims.set_kernel_output_feature_dimension(num_spatial_dims + 1);
133
134    for (int i = 0; i < num_spatial_dims; ++i) {
135      const int64 dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
136      dims.add_input_spatial_dimensions(dim);
137      dims.add_kernel_spatial_dimensions(i);
138      dims.add_output_spatial_dimensions(dim);
139      window_strides[i] = strides_.at(dim);
140      rhs_dilation[i] = dilations_.at(dim);
141
142      int64 unused_output_size;
143      OP_REQUIRES_OK(
144          ctx, GetWindowedOutputSizeVerboseV2(
145                   input_shape.dim_size(dim), ksizes_[dim], rhs_dilation[i],
146                   window_strides[i], padding_, &unused_output_size,
147                   &padding[i].first, &padding[i].second));
148    }
149
150    xla::ComputationDataHandle conv =
151        builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides,
152                                    padding, lhs_dilation, rhs_dilation, dims);
153    ctx->SetOutput(0, conv);
154  }
155
156 protected:
157  std::vector<int32> ksizes_;
158  std::vector<int32> dilations_;
159  std::vector<int32> strides_;
160  Padding padding_;
161
162 private:
163  TF_DISALLOW_COPY_AND_ASSIGN(ExtractImagePatchesOp);
164};
165
166REGISTER_XLA_OP(Name("ExtractImagePatches"), ExtractImagePatchesOp);
167
168}  // namespace
169}  // namespace tensorflow
170