/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
141ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins==============================================================================*/ 151ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins 161ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" 171ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins 181ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins#include <memory> 191ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins#include <vector> 201ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins 211ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins#include "tensorflow/compiler/xla/shape_util.h" 221ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins#include "tensorflow/compiler/xla/status_macros.h" 231ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins#include "tensorflow/compiler/xla/statusor.h" 241ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins#include "tensorflow/core/lib/core/errors.h" 251ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins 261ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkinsnamespace tensorflow { 271ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins 281ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkinsxla::StatusOr<xla::ComputationDataHandle> BatchDot( 291ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins xla::ComputationBuilder* builder, xla::ComputationDataHandle x, 30aca9f4257f064d381be171a7ff2ee114002a8fabA. Unique TensorFlower xla::ComputationDataHandle y, bool transpose_x, bool transpose_y, 31aca9f4257f064d381be171a7ff2ee114002a8fabA. 
Unique TensorFlower bool conjugate_x, bool conjugate_y) { 321ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> x_shape, 331ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins builder->GetShape(x)); 341ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> y_shape, 351ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins builder->GetShape(y)); 361ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins 371ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins // Check that both tensors have the same number of dimensions. There must be 381ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins // at least two (the batch dimensions can be empty). 391ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins if (xla::ShapeUtil::Rank(*x_shape) != xla::ShapeUtil::Rank(*y_shape)) { 401ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins return errors::InvalidArgument( 411ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins "Arguments to BatchedDot have different ranks: ", 421ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins xla::ShapeUtil::HumanString(*x_shape), " vs. 
", 431ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins xla::ShapeUtil::HumanString(*y_shape)); 441ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins } 451ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins const int ndims = xla::ShapeUtil::Rank(*x_shape); 461ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins if (ndims < 2) { 471ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins return errors::InvalidArgument( 481ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins "Arguments to BatchedDot must have rank >= 2: ", ndims); 491ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins } 501ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins 511ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins // The batch dimensions must be equal and the matrix dimensions must be 521ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins // valid. 53dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower std::vector<int64> batch_dimension_numbers; 541ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins for (int i = 0; i < ndims - 2; ++i) { 55dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower if (x_shape->dimensions(i) != y_shape->dimensions(i)) { 561ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins return errors::InvalidArgument( 571ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins "Dimension ", i, " of inputs to BatchedDot must be equal: ", 581ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins xla::ShapeUtil::HumanString(*x_shape), " vs ", 591ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins xla::ShapeUtil::HumanString(*y_shape)); 601ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins } 61dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower batch_dimension_numbers.push_back(i); 621ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins } 631ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins 641ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins int x_inner_dim = transpose_x ? 
(ndims - 2) : (ndims - 1); 651ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2); 66dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower if (x_shape->dimensions(x_inner_dim) != y_shape->dimensions(y_inner_dim)) { 671ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins return errors::InvalidArgument( 681ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins "Dimensions ", x_inner_dim, " and ", y_inner_dim, 691ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins " of arguments to BatchedDot must be equal: ", 701ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins xla::ShapeUtil::HumanString(*x_shape), " transpose: ", transpose_x, 711ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins " vs. ", xla::ShapeUtil::HumanString(*y_shape), 721ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins " transpose: ", transpose_y); 731ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins } 741ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins 75dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower // Check for zero lhs/rhs dim size. 76dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower if (xla::ShapeUtil::HasZeroElements(*x_shape) || 77dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower xla::ShapeUtil::HasZeroElements(*y_shape)) { 78dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower std::vector<int64> dimensions(batch_dimension_numbers.size()); 79dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower for (int i = 0; i < batch_dimension_numbers.size(); ++i) { 80dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower dimensions[i] = x_shape->dimensions(batch_dimension_numbers[i]); 81dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower } 82dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); 83dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. 
Unique TensorFlower int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); 84dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower dimensions.push_back(x_shape->dimensions(x_outer_dim)); 85dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower dimensions.push_back(y_shape->dimensions(y_outer_dim)); 86dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower return builder->Broadcast( 87dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower builder->ConstantLiteral(xla::Literal::Zero(x_shape->element_type())), 88dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower dimensions); 891ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins } 901ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins 91aca9f4257f064d381be171a7ff2ee114002a8fabA. Unique TensorFlower if (x_shape->element_type() == xla::C64 && conjugate_x) { 921ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins x = builder->Conj(x); 931ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins } 94aca9f4257f064d381be171a7ff2ee114002a8fabA. Unique TensorFlower if (y_shape->element_type() == xla::C64 && conjugate_y) { 951ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins y = builder->Conj(y); 961ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins } 971ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins 98dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower // If there are no batch dimensions, use a regular Dot. 99dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower // TODO(b/69062148) Remove this code when Dot emitters can be passed 100dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower // dimensions to transpose directly (i.e. without requiring a Transpose HLO). 101dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower if (batch_dimension_numbers.empty()) { 102dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower auto lhs = transpose_x ? builder->Transpose(x, {1, 0}) : x; 103dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. 
Unique TensorFlower auto rhs = transpose_y ? builder->Transpose(y, {1, 0}) : y; 104dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower return builder->Dot(lhs, rhs); 1051ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins } 1061ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins 107dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower xla::DotDimensionNumbers dot_dnums; 108dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); 109dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); 110dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower for (auto batch_dimension_number : batch_dimension_numbers) { 111dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); 112dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); 1131ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins } 114dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower return builder->DotGeneral(x, y, dot_dnums); 1151ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins} 1161ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins 1171ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins} // namespace tensorflow 118