11ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins
31ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
41ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkinsyou may not use this file except in compliance with the License.
51ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter HawkinsYou may obtain a copy of the License at
61ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins
71ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
81ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins
91ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter HawkinsUnless required by applicable law or agreed to in writing, software
101ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
111ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter HawkinsSee the License for the specific language governing permissions and
131ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkinslimitations under the License.
141ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins==============================================================================*/
151ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins
161ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
171ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins
181ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins#include <memory>
191ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins#include <vector>
201ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins
211ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins#include "tensorflow/compiler/xla/shape_util.h"
221ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins#include "tensorflow/compiler/xla/status_macros.h"
231ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins#include "tensorflow/compiler/xla/statusor.h"
241ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins#include "tensorflow/core/lib/core/errors.h"
251ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins
261ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkinsnamespace tensorflow {
271ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins
281ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkinsxla::StatusOr<xla::ComputationDataHandle> BatchDot(
291ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins    xla::ComputationBuilder* builder, xla::ComputationDataHandle x,
30aca9f4257f064d381be171a7ff2ee114002a8fabA. Unique TensorFlower    xla::ComputationDataHandle y, bool transpose_x, bool transpose_y,
31aca9f4257f064d381be171a7ff2ee114002a8fabA. Unique TensorFlower    bool conjugate_x, bool conjugate_y) {
321ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> x_shape,
331ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins                      builder->GetShape(x));
341ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> y_shape,
351ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins                      builder->GetShape(y));
361ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins
371ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins  // Check that both tensors have the same number of dimensions. There must be
381ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins  // at least two (the batch dimensions can be empty).
391ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins  if (xla::ShapeUtil::Rank(*x_shape) != xla::ShapeUtil::Rank(*y_shape)) {
401ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins    return errors::InvalidArgument(
411ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins        "Arguments to BatchedDot have different ranks: ",
421ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins        xla::ShapeUtil::HumanString(*x_shape), " vs. ",
431ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins        xla::ShapeUtil::HumanString(*y_shape));
441ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins  }
451ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins  const int ndims = xla::ShapeUtil::Rank(*x_shape);
461ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins  if (ndims < 2) {
471ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins    return errors::InvalidArgument(
481ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins        "Arguments to BatchedDot must have rank >= 2: ", ndims);
491ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins  }
501ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins
511ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins  // The batch dimensions must be equal and the matrix dimensions must be
521ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins  // valid.
53dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower  std::vector<int64> batch_dimension_numbers;
541ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins  for (int i = 0; i < ndims - 2; ++i) {
55dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower    if (x_shape->dimensions(i) != y_shape->dimensions(i)) {
561ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins      return errors::InvalidArgument(
571ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins          "Dimension ", i, " of inputs to BatchedDot must be equal: ",
581ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins          xla::ShapeUtil::HumanString(*x_shape), " vs ",
591ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins          xla::ShapeUtil::HumanString(*y_shape));
601ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins    }
61dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower    batch_dimension_numbers.push_back(i);
621ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins  }
631ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins
641ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins  int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1);
651ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins  int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2);
66dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower  if (x_shape->dimensions(x_inner_dim) != y_shape->dimensions(y_inner_dim)) {
671ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins    return errors::InvalidArgument(
681ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins        "Dimensions ", x_inner_dim, " and ", y_inner_dim,
691ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins        " of arguments to BatchedDot must be equal: ",
701ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins        xla::ShapeUtil::HumanString(*x_shape), " transpose: ", transpose_x,
711ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins        " vs. ", xla::ShapeUtil::HumanString(*y_shape),
721ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins        " transpose: ", transpose_y);
731ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins  }
741ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins
75dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower  // Check for zero lhs/rhs dim size.
76dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower  if (xla::ShapeUtil::HasZeroElements(*x_shape) ||
77dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower      xla::ShapeUtil::HasZeroElements(*y_shape)) {
78dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower    std::vector<int64> dimensions(batch_dimension_numbers.size());
79dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower    for (int i = 0; i < batch_dimension_numbers.size(); ++i) {
80dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower      dimensions[i] = x_shape->dimensions(batch_dimension_numbers[i]);
81dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower    }
82dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower    int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2);
83dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower    int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1);
84dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower    dimensions.push_back(x_shape->dimensions(x_outer_dim));
85dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower    dimensions.push_back(y_shape->dimensions(y_outer_dim));
86dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower    return builder->Broadcast(
87dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower        builder->ConstantLiteral(xla::Literal::Zero(x_shape->element_type())),
88dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower        dimensions);
891ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins  }
901ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins
91aca9f4257f064d381be171a7ff2ee114002a8fabA. Unique TensorFlower  if (x_shape->element_type() == xla::C64 && conjugate_x) {
921ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins    x = builder->Conj(x);
931ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins  }
94aca9f4257f064d381be171a7ff2ee114002a8fabA. Unique TensorFlower  if (y_shape->element_type() == xla::C64 && conjugate_y) {
951ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins    y = builder->Conj(y);
961ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins  }
971ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins
98dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower  // If there are no batch dimensions, use a regular Dot.
99dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower  // TODO(b/69062148) Remove this code when Dot emitters can be passed
100dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower  // dimensions to transpose directly (i.e. without requiring a Transpose HLO).
101dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower  if (batch_dimension_numbers.empty()) {
102dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower    auto lhs = transpose_x ? builder->Transpose(x, {1, 0}) : x;
103dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower    auto rhs = transpose_y ? builder->Transpose(y, {1, 0}) : y;
104dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower    return builder->Dot(lhs, rhs);
1051ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins  }
1061ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins
107dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower  xla::DotDimensionNumbers dot_dnums;
108dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower  dot_dnums.add_lhs_contracting_dimensions(x_inner_dim);
109dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower  dot_dnums.add_rhs_contracting_dimensions(y_inner_dim);
110dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower  for (auto batch_dimension_number : batch_dimension_numbers) {
111dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower    dot_dnums.add_lhs_batch_dimensions(batch_dimension_number);
112dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower    dot_dnums.add_rhs_batch_dimensions(batch_dimension_number);
1131ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins  }
114dd77f385591c8b6ef7ab8dae7429c7eff7813a1eA. Unique TensorFlower  return builder->DotGeneral(x, y, dot_dnums);
1151ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins}
1161ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins
1171ddff94002e037fce0fd15c62f4c6090aeb5dce4Peter Hawkins}  // namespace tensorflow
118