19f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
29f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
39f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
49f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkinsyou may not use this file except in compliance with the License.
59f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter HawkinsYou may obtain a copy of the License at
69f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
79f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
89f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
99f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter HawkinsUnless required by applicable law or agreed to in writing, software
109f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
119f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
129f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter HawkinsSee the License for the specific language governing permissions and
139f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkinslimitations under the License.
149f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins==============================================================================*/
159f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
169f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins#include "tensorflow/compiler/tf2xla/type_util.h"
179f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins#include "tensorflow/compiler/tf2xla/xla_helpers.h"
189f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
199f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
209f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins#include "tensorflow/core/util/tensor_format.h"
219f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
229f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkinsnamespace tensorflow {
239f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
249f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkinsnamespace {
259f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
269f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkinsclass ExtractImagePatchesOp : public XlaOpKernel {
279f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins public:
289f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins  explicit ExtractImagePatchesOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
299f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    OP_REQUIRES_OK(ctx, ctx->GetAttr("ksizes", &ksizes_));
309f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
319f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    OP_REQUIRES_OK(ctx, ctx->GetAttr("rates", &dilations_));
329f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
339f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins  }
349f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
359f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins  void Compile(XlaOpKernelContext* ctx) override {
369f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    const TensorFormat data_format = FORMAT_NHWC;
379f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    const int num_dims = ksizes_.size();
389f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
399f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    OP_REQUIRES(
409f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins        ctx, num_dims >= 3,
419f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins        errors::InvalidArgument("Kernel size must have at least 3 dimensions"));
429f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    const int num_spatial_dims = num_dims - 2;
439f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
449f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    OP_REQUIRES(ctx, strides_.size() == num_dims,
459f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                errors::InvalidArgument("Sliding window strides field must "
469f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                                        "specify ",
479f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                                        num_dims, " dimensions"));
489f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    OP_REQUIRES(ctx, dilations_.size() == num_dims,
499f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                errors::InvalidArgument("Dilations field must "
509f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                                        "specify ",
519f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                                        num_dims, " dimensions"));
529f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
539f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    int batch_dim = GetTensorBatchDimIndex(num_dims, data_format);
549f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    int feature_dim = GetTensorFeatureDimIndex(num_dims, data_format);
559f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    OP_REQUIRES(
569f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins        ctx, ksizes_[batch_dim] == 1 && ksizes_[feature_dim] == 1,
579f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins        errors::Unimplemented("Current implementation does not yet support "
589f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                              "kernel sizes > 1 in the batch and depth "
599f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                              "dimensions."));
609f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    OP_REQUIRES(
619f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins        ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
629f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins        errors::Unimplemented("Current implementation does not yet support "
639f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                              "strides in the batch and depth dimensions."));
649f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    OP_REQUIRES(
659f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins        ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
669f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins        errors::Unimplemented("Current implementation does not support "
679f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                              "dilations in the batch and depth dimensions."));
689f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
699f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    for (int i = 0; i < num_spatial_dims; ++i) {
709f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins      int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
719f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins      OP_REQUIRES(
729f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins          ctx, ksizes_[input_dim] >= 0,
739f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins          errors::Unimplemented("Kernel size values must be non-negative; ", i,
749f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                                "th spatial dimension had dilation ",
759f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                                dilations_[input_dim]));
769f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins      OP_REQUIRES(ctx, strides_[input_dim] >= 1,
779f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                  errors::Unimplemented("Stride values must be positive; ", i,
789f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                                        "th spatial dimension had dilation ",
799f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                                        dilations_[input_dim]));
809f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins      OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
819f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                  errors::Unimplemented("Dilation values must be positive; ", i,
829f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                                        "th spatial dimension had dilation ",
839f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                                        dilations_[input_dim]));
849f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    }
859f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
869f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    xla::PrimitiveType type;
879f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(0), &type));
889f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
899f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    const TensorShape input_shape = ctx->InputShape(0);
909f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    OP_REQUIRES(
919f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins        ctx, input_shape.dims() == num_dims,
929f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins        errors::InvalidArgument("input must be ", num_dims, "-dimensional",
939f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                                input_shape.DebugString()));
949f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    const int64 depth = input_shape.dim_size(feature_dim);
959f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
969f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    xla::ComputationBuilder* builder = ctx->builder();
979f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
989f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    // The following code is equivalent to:
999f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    // eye = np.eye(kH * kW * D).reshape([kH, kW, D, kH * kW * kD])
1009f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    int64 kernel_size = 1;
1019f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    std::vector<int64> lhs_shape(num_dims, 1);
1029f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    for (int i = 0; i < num_spatial_dims; ++i) {
1039f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins      int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
1049f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins      lhs_shape[i] = ksizes_[input_dim];
1059f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins      kernel_size *= ksizes_[input_dim];
1069f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    }
1079f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    lhs_shape[num_spatial_dims] = depth;
1089f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    lhs_shape[num_spatial_dims + 1] = 1;
1099f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
1109f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    // Builds an identity matrix as a broadcast equality of iotas.
1119f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    // iota = np.arange(np.prod(ksize), depth)
1129f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    // filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32)
1139f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    xla::ComputationDataHandle iota;
1149f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32,
1159f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                                 kernel_size * depth, &iota));
1169f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
1179f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    auto lhs = builder->Reshape(iota, lhs_shape);
1189f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    auto filter = builder->ConvertElementType(
1199f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins        builder->Eq(lhs, iota, {num_spatial_dims + 1}), type);
1209f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
1219f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    xla::ConvolutionDimensionNumbers dims;
1229f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    std::vector<int64> window_strides(num_spatial_dims);
1239f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    std::vector<int64> lhs_dilation(num_spatial_dims, 1);
1249f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    std::vector<int64> rhs_dilation(num_spatial_dims);
1259f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    std::vector<std::pair<int64, int64>> padding(num_spatial_dims);
1269f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
1279f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    dims.set_input_batch_dimension(batch_dim);
1289f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    dims.set_output_batch_dimension(batch_dim);
1299f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    dims.set_input_feature_dimension(feature_dim);
1309f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    dims.set_output_feature_dimension(feature_dim);
1319f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    dims.set_kernel_input_feature_dimension(num_spatial_dims);
1329f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    dims.set_kernel_output_feature_dimension(num_spatial_dims + 1);
1339f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
1349f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    for (int i = 0; i < num_spatial_dims; ++i) {
1359f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins      const int64 dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
1369f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins      dims.add_input_spatial_dimensions(dim);
1379f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins      dims.add_kernel_spatial_dimensions(i);
1389f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins      dims.add_output_spatial_dimensions(dim);
1399f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins      window_strides[i] = strides_.at(dim);
1409f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins      rhs_dilation[i] = dilations_.at(dim);
1419f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
1429f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins      int64 unused_output_size;
1439f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins      OP_REQUIRES_OK(
1449f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins          ctx, GetWindowedOutputSizeVerboseV2(
1459f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                   input_shape.dim_size(dim), ksizes_[dim], rhs_dilation[i],
1469f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                   window_strides[i], padding_, &unused_output_size,
1479f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                   &padding[i].first, &padding[i].second));
1489f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    }
1499f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
1509f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    xla::ComputationDataHandle conv =
1519f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins        builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides,
1529f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins                                    padding, lhs_dilation, rhs_dilation, dims);
1539f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins    ctx->SetOutput(0, conv);
1549f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins  }
1559f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
1569f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins protected:
1579f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins  std::vector<int32> ksizes_;
1589f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins  std::vector<int32> dilations_;
1599f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins  std::vector<int32> strides_;
1609f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins  Padding padding_;
1619f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
1629f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins private:
1639f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins  TF_DISALLOW_COPY_AND_ASSIGN(ExtractImagePatchesOp);
1649f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins};
1659f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
1669f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter HawkinsREGISTER_XLA_OP(Name("ExtractImagePatches"), ExtractImagePatchesOp);
1679f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins
1689f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins}  // namespace
1699f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins}  // namespace tensorflow
170