/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
149f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins==============================================================================*/ 159f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 169f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins#include "tensorflow/compiler/tf2xla/type_util.h" 179f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins#include "tensorflow/compiler/tf2xla/xla_helpers.h" 189f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 199f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_registry.h" 209f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins#include "tensorflow/core/util/tensor_format.h" 219f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 229f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkinsnamespace tensorflow { 239f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 249f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkinsnamespace { 259f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 269f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkinsclass ExtractImagePatchesOp : public XlaOpKernel { 279f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins public: 289f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins explicit ExtractImagePatchesOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 299f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins OP_REQUIRES_OK(ctx, ctx->GetAttr("ksizes", &ksizes_)); 309f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); 319f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins OP_REQUIRES_OK(ctx, ctx->GetAttr("rates", &dilations_)); 329f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); 339f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins } 349f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 359f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins void Compile(XlaOpKernelContext* 
ctx) override { 369f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins const TensorFormat data_format = FORMAT_NHWC; 379f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins const int num_dims = ksizes_.size(); 389f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 399f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins OP_REQUIRES( 409f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins ctx, num_dims >= 3, 419f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins errors::InvalidArgument("Kernel size must have at least 3 dimensions")); 429f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins const int num_spatial_dims = num_dims - 2; 439f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 449f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins OP_REQUIRES(ctx, strides_.size() == num_dims, 459f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins errors::InvalidArgument("Sliding window strides field must " 469f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins "specify ", 479f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins num_dims, " dimensions")); 489f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins OP_REQUIRES(ctx, dilations_.size() == num_dims, 499f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins errors::InvalidArgument("Dilations field must " 509f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins "specify ", 519f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins num_dims, " dimensions")); 529f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 539f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins int batch_dim = GetTensorBatchDimIndex(num_dims, data_format); 549f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins int feature_dim = GetTensorFeatureDimIndex(num_dims, data_format); 559f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins OP_REQUIRES( 569f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins ctx, ksizes_[batch_dim] == 1 && ksizes_[feature_dim] == 1, 579f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins errors::Unimplemented("Current 
implementation does not yet support " 589f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins "kernel sizes > 1 in the batch and depth " 599f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins "dimensions.")); 609f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins OP_REQUIRES( 619f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1, 629f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins errors::Unimplemented("Current implementation does not yet support " 639f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins "strides in the batch and depth dimensions.")); 649f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins OP_REQUIRES( 659f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, 669f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins errors::Unimplemented("Current implementation does not support " 679f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins "dilations in the batch and depth dimensions.")); 689f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 699f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins for (int i = 0; i < num_spatial_dims; ++i) { 709f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i); 719f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins OP_REQUIRES( 729f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins ctx, ksizes_[input_dim] >= 0, 739f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins errors::Unimplemented("Kernel size values must be non-negative; ", i, 749f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins "th spatial dimension had dilation ", 759f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins dilations_[input_dim])); 769f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins OP_REQUIRES(ctx, strides_[input_dim] >= 1, 779f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins errors::Unimplemented("Stride values must be positive; ", i, 
789f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins "th spatial dimension had dilation ", 799f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins dilations_[input_dim])); 809f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins OP_REQUIRES(ctx, dilations_[input_dim] >= 1, 819f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins errors::Unimplemented("Dilation values must be positive; ", i, 829f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins "th spatial dimension had dilation ", 839f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins dilations_[input_dim])); 849f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins } 859f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 869f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins xla::PrimitiveType type; 879f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(0), &type)); 889f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 899f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins const TensorShape input_shape = ctx->InputShape(0); 909f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins OP_REQUIRES( 919f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins ctx, input_shape.dims() == num_dims, 929f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins errors::InvalidArgument("input must be ", num_dims, "-dimensional", 939f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins input_shape.DebugString())); 949f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins const int64 depth = input_shape.dim_size(feature_dim); 959f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 969f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins xla::ComputationBuilder* builder = ctx->builder(); 979f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 989f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins // The following code is equivalent to: 999f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins // eye = np.eye(kH * kW * D).reshape([kH, kW, D, kH * kW * kD]) 
1009f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins int64 kernel_size = 1; 1019f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins std::vector<int64> lhs_shape(num_dims, 1); 1029f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins for (int i = 0; i < num_spatial_dims; ++i) { 1039f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i); 1049f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins lhs_shape[i] = ksizes_[input_dim]; 1059f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins kernel_size *= ksizes_[input_dim]; 1069f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins } 1079f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins lhs_shape[num_spatial_dims] = depth; 1089f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins lhs_shape[num_spatial_dims + 1] = 1; 1099f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 1109f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins // Builds an identity matrix as a broadcast equality of iotas. 
1119f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins // iota = np.arange(np.prod(ksize), depth) 1129f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins // filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32) 1139f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins xla::ComputationDataHandle iota; 1149f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, 1159f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins kernel_size * depth, &iota)); 1169f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 1179f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins auto lhs = builder->Reshape(iota, lhs_shape); 1189f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins auto filter = builder->ConvertElementType( 1199f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins builder->Eq(lhs, iota, {num_spatial_dims + 1}), type); 1209f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 1219f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins xla::ConvolutionDimensionNumbers dims; 1229f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins std::vector<int64> window_strides(num_spatial_dims); 1239f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins std::vector<int64> lhs_dilation(num_spatial_dims, 1); 1249f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins std::vector<int64> rhs_dilation(num_spatial_dims); 1259f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins std::vector<std::pair<int64, int64>> padding(num_spatial_dims); 1269f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 1279f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins dims.set_input_batch_dimension(batch_dim); 1289f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins dims.set_output_batch_dimension(batch_dim); 1299f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins dims.set_input_feature_dimension(feature_dim); 1309f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins dims.set_output_feature_dimension(feature_dim); 
1319f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins dims.set_kernel_input_feature_dimension(num_spatial_dims); 1329f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins dims.set_kernel_output_feature_dimension(num_spatial_dims + 1); 1339f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 1349f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins for (int i = 0; i < num_spatial_dims; ++i) { 1359f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins const int64 dim = GetTensorSpatialDimIndex(num_dims, data_format, i); 1369f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins dims.add_input_spatial_dimensions(dim); 1379f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins dims.add_kernel_spatial_dimensions(i); 1389f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins dims.add_output_spatial_dimensions(dim); 1399f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins window_strides[i] = strides_.at(dim); 1409f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins rhs_dilation[i] = dilations_.at(dim); 1419f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 1429f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins int64 unused_output_size; 1439f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins OP_REQUIRES_OK( 1449f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins ctx, GetWindowedOutputSizeVerboseV2( 1459f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins input_shape.dim_size(dim), ksizes_[dim], rhs_dilation[i], 1469f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins window_strides[i], padding_, &unused_output_size, 1479f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins &padding[i].first, &padding[i].second)); 1489f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins } 1499f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 1509f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins xla::ComputationDataHandle conv = 1519f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides, 
1529f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins padding, lhs_dilation, rhs_dilation, dims); 1539f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins ctx->SetOutput(0, conv); 1549f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins } 1559f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 1569f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins protected: 1579f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins std::vector<int32> ksizes_; 1589f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins std::vector<int32> dilations_; 1599f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins std::vector<int32> strides_; 1609f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins Padding padding_; 1619f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 1629f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins private: 1639f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins TF_DISALLOW_COPY_AND_ASSIGN(ExtractImagePatchesOp); 1649f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins}; 1659f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 1669f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter HawkinsREGISTER_XLA_OP(Name("ExtractImagePatches"), ExtractImagePatchesOp); 1679f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins 1689f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins} // namespace 1699f75f8e6d55fb9ad605bce80b656e3e19781ee43Peter Hawkins} // namespace tensorflow 170