1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/tf2xla/type_util.h" 17#include "tensorflow/compiler/tf2xla/xla_helpers.h" 18#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 19#include "tensorflow/compiler/tf2xla/xla_op_registry.h" 20#include "tensorflow/core/util/tensor_format.h" 21 22namespace tensorflow { 23 24namespace { 25 26class ExtractImagePatchesOp : public XlaOpKernel { 27 public: 28 explicit ExtractImagePatchesOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 29 OP_REQUIRES_OK(ctx, ctx->GetAttr("ksizes", &ksizes_)); 30 OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); 31 OP_REQUIRES_OK(ctx, ctx->GetAttr("rates", &dilations_)); 32 OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); 33 } 34 35 void Compile(XlaOpKernelContext* ctx) override { 36 const TensorFormat data_format = FORMAT_NHWC; 37 const int num_dims = ksizes_.size(); 38 39 OP_REQUIRES( 40 ctx, num_dims >= 3, 41 errors::InvalidArgument("Kernel size must have at least 3 dimensions")); 42 const int num_spatial_dims = num_dims - 2; 43 44 OP_REQUIRES(ctx, strides_.size() == num_dims, 45 errors::InvalidArgument("Sliding window strides field must " 46 "specify ", 47 num_dims, " dimensions")); 48 OP_REQUIRES(ctx, dilations_.size() == num_dims, 49 errors::InvalidArgument("Dilations field must " 50 "specify ", 51 num_dims, " dimensions")); 52 53 int batch_dim = GetTensorBatchDimIndex(num_dims, data_format); 54 int feature_dim = GetTensorFeatureDimIndex(num_dims, data_format); 55 OP_REQUIRES( 56 ctx, ksizes_[batch_dim] == 1 && ksizes_[feature_dim] == 1, 57 errors::Unimplemented("Current implementation does not yet support " 58 "kernel sizes > 1 in the batch and depth " 59 "dimensions.")); 60 OP_REQUIRES( 61 ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1, 62 errors::Unimplemented("Current implementation does not yet support " 63 "strides in the batch and depth dimensions.")); 64 OP_REQUIRES( 65 ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, 66 errors::Unimplemented("Current implementation does not support " 67 "dilations in the batch and depth dimensions.")); 68 69 for (int i = 0; i < num_spatial_dims; ++i) { 70 int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i); 71 OP_REQUIRES( 72 ctx, ksizes_[input_dim] >= 0, 73 errors::Unimplemented("Kernel size values must be non-negative; ", i, 74 "th spatial dimension had dilation ", 75 dilations_[input_dim])); 76 OP_REQUIRES(ctx, strides_[input_dim] >= 1, 77 errors::Unimplemented("Stride values must be positive; ", i, 78 "th spatial dimension had dilation ", 79 dilations_[input_dim])); 80 OP_REQUIRES(ctx, dilations_[input_dim] >= 1, 81 errors::Unimplemented("Dilation values must be positive; ", i, 82 "th spatial dimension had dilation ", 83 dilations_[input_dim])); 84 } 85 86 xla::PrimitiveType type; 87 OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(0), &type)); 88 89 const TensorShape input_shape = ctx->InputShape(0); 90 OP_REQUIRES( 91 ctx, input_shape.dims() == num_dims, 92 errors::InvalidArgument("input must be ", num_dims, "-dimensional", 93 input_shape.DebugString())); 94 const int64 depth = input_shape.dim_size(feature_dim); 95 96 xla::ComputationBuilder* builder = ctx->builder(); 97 98 // The following code is equivalent to: 99 // eye = np.eye(kH * kW * D).reshape([kH, kW, D, kH * kW * kD]) 100 int64 kernel_size = 1; 101 std::vector<int64> lhs_shape(num_dims, 1); 102 for (int i = 0; i < num_spatial_dims; ++i) { 103 int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i); 104 lhs_shape[i] = ksizes_[input_dim]; 105 kernel_size *= ksizes_[input_dim]; 106 } 107 lhs_shape[num_spatial_dims] = depth; 108 lhs_shape[num_spatial_dims + 1] = 1; 109 110 // Builds an identity matrix as a broadcast equality of iotas. 111 // iota = np.arange(np.prod(ksize), depth) 112 // filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32) 113 xla::ComputationDataHandle iota; 114 TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, 115 kernel_size * depth, &iota)); 116 117 auto lhs = builder->Reshape(iota, lhs_shape); 118 auto filter = builder->ConvertElementType( 119 builder->Eq(lhs, iota, {num_spatial_dims + 1}), type); 120 121 xla::ConvolutionDimensionNumbers dims; 122 std::vector<int64> window_strides(num_spatial_dims); 123 std::vector<int64> lhs_dilation(num_spatial_dims, 1); 124 std::vector<int64> rhs_dilation(num_spatial_dims); 125 std::vector<std::pair<int64, int64>> padding(num_spatial_dims); 126 127 dims.set_input_batch_dimension(batch_dim); 128 dims.set_output_batch_dimension(batch_dim); 129 dims.set_input_feature_dimension(feature_dim); 130 dims.set_output_feature_dimension(feature_dim); 131 dims.set_kernel_input_feature_dimension(num_spatial_dims); 132 dims.set_kernel_output_feature_dimension(num_spatial_dims + 1); 133 134 for (int i = 0; i < num_spatial_dims; ++i) { 135 const int64 dim = GetTensorSpatialDimIndex(num_dims, data_format, i); 136 dims.add_input_spatial_dimensions(dim); 137 dims.add_kernel_spatial_dimensions(i); 138 dims.add_output_spatial_dimensions(dim); 139 window_strides[i] = strides_.at(dim); 140 rhs_dilation[i] = dilations_.at(dim); 141 142 int64 unused_output_size; 143 OP_REQUIRES_OK( 144 ctx, GetWindowedOutputSizeVerboseV2( 145 input_shape.dim_size(dim), ksizes_[dim], rhs_dilation[i], 146 window_strides[i], padding_, &unused_output_size, 147 &padding[i].first, &padding[i].second)); 148 } 149 150 xla::ComputationDataHandle conv = 151 builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides, 152 padding, lhs_dilation, rhs_dilation, dims); 153 ctx->SetOutput(0, conv); 154 } 155 156 protected: 157 std::vector<int32> ksizes_; 158 std::vector<int32> dilations_; 159 std::vector<int32> strides_; 160 Padding padding_; 161 162 private: 163 TF_DISALLOW_COPY_AND_ASSIGN(ExtractImagePatchesOp); 164}; 165 166REGISTER_XLA_OP(Name("ExtractImagePatches"), ExtractImagePatchesOp); 167 168} // namespace 169} // namespace tensorflow 170