/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
16b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower#include "tensorflow/compiler/tf2xla/lib/util.h"
17b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower#include "tensorflow/compiler/tf2xla/type_util.h"
181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/xla_helpers.h"
191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
20a8c325e57c1077f1e8df540a20bd8b36d3d1f968Peter Hawkins#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/util.h"
221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/framework/op_kernel.h"
231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace tensorflow {
251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace {
261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
27b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower// Create a diagonal / batch diagonal matrix with 'input' on the diagonal.
28b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlowerxla::StatusOr<xla::ComputationDataHandle> CreateDiagonal(
29b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower    const xla::ComputationDataHandle& input, int64 last_dim_size,
30b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower    tensorflow::gtl::ArraySlice<int64> other_dims, XlaOpKernelContext* ctx,
31b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower    xla::ComputationBuilder* builder) {
32b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  // Create two matrices that have the following forms, and compare them:
33b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  //
34b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  // [[0, 0, 0, 0]            [[0, 1, 2, 3]
35b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  //  [1, 1, 1, 1]             [0, 1, 2, 3]
36b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  //  [2, 2, 2, 2]             [0, 1, 2, 3]
37b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  //  [3, 3, 3, 3]]            [0, 1, 2, 3]]
38b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  //
39b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  // This produces a predicate matrix of the right size, with "true" on the
40b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  // diagonal.
41b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  xla::ComputationDataHandle iota;
42b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  TF_RETURN_IF_ERROR(
43b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower      XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota));
44b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  xla::ComputationDataHandle iota_broadcast =
45b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower      builder->Broadcast(iota, {last_dim_size});
46b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  xla::ComputationDataHandle mask = builder->Eq(iota_broadcast, iota, {0});
47b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower
48b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  // If this is a batched diagonal, broadcast the mask across the other
49b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  // dimensions.
50b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  if (!other_dims.empty()) {
51b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower    mask = builder->Broadcast(mask, other_dims);
52b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  }
53b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower
54b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  // Broadcast the input, and then use the mask computed above to select the
55b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  // diagonal:
56b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  // e.g, in 2D:
57b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  //         [[t, f, f]    [[1, 1, 1]    [[0, 0, 0]      [[1, 0, 0]
58b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  // select(  [f, t, f]  ,  [4, 4, 4]  ,  [0, 0, 0]  ) =  [0, 4, 0]
59b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  //          [f, f, t]]    [9, 9, 9]]    [0, 0, 0]]      [0, 0, 9]]
60b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  //
61b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  // Broadcasting the input is less-than-trivial, since we need to broadcast
62b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  // into a "middle" dimension. We can do this with a reshape + implicit
63b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  // broadcast.
64b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  // TODO(b/30112114): Replace with in-dim broadcast when those are supported.
65b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  std::vector<int64> broadcast_dims(other_dims.begin(), other_dims.end());
66b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  broadcast_dims.push_back(1LL);
67b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  broadcast_dims.push_back(last_dim_size);
68b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  xla::ComputationDataHandle input_broadcast =
69b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower      builder->Reshape(input, broadcast_dims);
70b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower
71b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  broadcast_dims[broadcast_dims.size() - 2] = last_dim_size;
72b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  xla::PrimitiveType element_type;
73b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  TF_RETURN_IF_ERROR(
74b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower      DataTypeToPrimitiveType(ctx->input_type(0), &element_type));
75b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  auto broadcast_shape =
76b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower      xla::ShapeUtil::MakeShape(element_type, broadcast_dims);
77b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  xla::ComputationDataHandle zeros = Zeros(builder, broadcast_shape);
78b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower
79b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  input_broadcast = builder->Add(input_broadcast, zeros);
80b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower  return builder->Select(mask, input_broadcast, zeros);
81b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower}
82b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower
831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass DiagOp : public XlaOpKernel {
841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public:
851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  void Compile(XlaOpKernelContext* ctx) override {
881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    xla::ComputationBuilder* builder = ctx->builder();
891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
90b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower    OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
91b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower                errors::InvalidArgument("Diag op must have at an input"));
921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const TensorShape input_shape = ctx->InputShape(0);
931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto dims = input_shape.dim_sizes();
951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES(ctx, !dims.empty(),
961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                errors::InvalidArgument("Expected 1 <= dims, got shape ",
971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                        input_shape.DebugString()));
981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
99b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower    xla::ComputationDataHandle input = ctx->Input(0);
1001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Picture:
1021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // tf.diag([1, 2, 3, 4]) ==> [[1, 0, 0, 0]
1031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    //                            [0, 2, 0, 0]
1041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    //                            [0, 0, 3, 0]
1051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    //                            [0, 0, 0, 4]]
1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Flattens the input to 1D.
1081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    int64 size = input_shape.num_elements();
109b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower    input = builder->Reshape(input, {size});
1101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
111b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower    // Create an R2 with the R1 diagonal.
112b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower    auto diag_or_status =
113b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower        CreateDiagonal(input, size, /*other_dims=*/{}, ctx, builder);
114b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower    OP_REQUIRES_OK(ctx, diag_or_status.status());
115b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower    xla::ComputationDataHandle diag = diag_or_status.ValueOrDie();
1161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Reshapes to the final shape.
1181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::vector<int64> new_dims(dims.size() * 2);
1191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::copy(dims.begin(), dims.end(), new_dims.begin());
1201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::copy(dims.begin(), dims.end(), new_dims.begin() + dims.size());
1211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    diag = builder->Reshape(diag, new_dims);
1221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    ctx->SetOutput(0, diag);
1241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
1261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
12793f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter HawkinsREGISTER_XLA_OP(Name("Diag"), DiagOp);
1281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass DiagPartOp : public XlaOpKernel {
1301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public:
1311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  explicit DiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
1321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  void Compile(XlaOpKernelContext* ctx) override {
1341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    xla::ComputationBuilder* builder = ctx->builder();
1351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const TensorShape input_shape = ctx->InputShape(0);
1371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto dims = input_shape.dim_sizes();
1381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    int num_dims = dims.size();
1401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const int out_dims = num_dims / 2;
1411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES(ctx, 2 <= num_dims,
1431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                errors::InvalidArgument("Expected 2 <= dims, got shape ",
1441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                        input_shape.DebugString()));
1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES(ctx, num_dims % 2 == 0,
1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                errors::InvalidArgument("The input tensor must have even rank; "
1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                        "got shape ",
1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                        input_shape.DebugString()));
1491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    int64 new_size = 1;
1501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::vector<int64> new_dims;
1511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    for (int i = 0; i < out_dims; i++) {
1521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      OP_REQUIRES(
1531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          ctx, dims[i] == dims[i + out_dims],
1541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          errors::InvalidArgument("Invalid shape ", input_shape.DebugString(),
1551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                  ": dimensions ", i, " and ", i + out_dims,
1561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                  " do not match."));
1571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      new_size *= dims[i];
1581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      new_dims.push_back(dims[i]);
1591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
1601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    xla::ComputationDataHandle diag = ctx->Input(0);
1621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // TODO(b/30878775): use Slice with strides when supported, in place of
1641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // the Pad -> Reshape -> Slice.
1651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Picture:
1671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // [[1, 0, 0, 0]  pad and reshape to [[1, 0, 0, 0, 0],
1681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    //  [0, 2, 0, 0]  =================>  [2, 0, 0, 0, 0],
1691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    //  [0, 0, 3, 0]                      [3, 0, 0, 0, 0],
1701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    //  [0, 0, 0, 4]]                     [4, 0, 0, 0, 0]]
1711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // and then slice out the first column.
1721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Flattens the input to 1D.
1741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    int64 size = input_shape.num_elements();
1751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    diag = builder->Reshape(diag, {size});
1761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Adds padding after the last element of 'new_size'.
1781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    xla::PaddingConfig config;
1791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto* dim = config.add_dimensions();
1801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    dim->set_edge_padding_high(new_size);
1811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto zero = XlaHelpers::Zero(builder, input_type(0));
1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    diag = builder->Pad(diag, zero, config);
1831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Reshapes so the diagonal is now in the first column.
1851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    diag = builder->Reshape(diag, {new_size, new_size + 1});
1861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Slices out the first column and reshapes to the final shape.
18850b999a8336d19400ab75aea66fe46eca2f5fe0bA. Unique TensorFlower    diag = builder->Slice(diag, {0, 0}, {new_size, 1}, {1, 1});
1891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    diag = builder->Reshape(diag, new_dims);
1901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    ctx->SetOutput(0, diag);
1921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
1941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
19593f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter HawkinsREGISTER_XLA_OP(Name("DiagPart"), DiagPartOp);
1961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass MatrixDiagOp : public XlaOpKernel {
1981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public:
1991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
2001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  void Compile(XlaOpKernelContext* ctx) override {
2021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    xla::ComputationBuilder* builder = ctx->builder();
2031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
204b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower    OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
205b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower                errors::InvalidArgument("MatrixDiag op must have at an input"));
2061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const TensorShape input_shape = ctx->InputShape(0);
2071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto dims = input_shape.dim_sizes();
2091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES(ctx, !dims.empty(),
2101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                errors::InvalidArgument("Expected 1 <= dims, got shape ",
2111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                        input_shape.DebugString()));
2121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    xla::ComputationDataHandle diag = ctx->Input(0);
2141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    int last_dim = dims.size() - 1;
2161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    int64 last_dim_size = input_shape.dim_size(last_dim);
217b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower    tensorflow::gtl::ArraySlice<int64> other_dims(dims);
218b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower    other_dims.pop_back();
2191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
220b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower    auto diag_or_status =
221b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower        CreateDiagonal(diag, last_dim_size, other_dims, ctx, builder);
222b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower    OP_REQUIRES_OK(ctx, diag_or_status.status());
223b7affdee5d3baa3c98084f254510d65c7f8a3860A. Unique TensorFlower    diag = diag_or_status.ValueOrDie();
2241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    ctx->SetOutput(0, diag);
2251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
2261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
2271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
22893f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter HawkinsREGISTER_XLA_OP(Name("MatrixDiag"), MatrixDiagOp);
2291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass MatrixDiagPartOp : public XlaOpKernel {
2311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public:
2321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  explicit MatrixDiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
2331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  void Compile(XlaOpKernelContext* ctx) override {
2351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    xla::ComputationBuilder* builder = ctx->builder();
2361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const TensorShape input_shape = ctx->InputShape(0);
2381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    auto dims = input_shape.dim_sizes();
2391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    OP_REQUIRES(ctx, 2 <= dims.size(),
2411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                errors::InvalidArgument("Expected 2 <= dims, got shape ",
2421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                        input_shape.DebugString()));
2431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    xla::ComputationDataHandle diag = ctx->Input(0);
2451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    int last_dim = dims.size() - 1;
2471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    int64 last_dim_size = dims[last_dim];
2481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // The smaller of the last two dimension sizes.
2501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    int64 smaller_dim_size = std::min(dims[last_dim - 1], dims[last_dim]);
2511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // TODO(b/30878775): use Slice with strides when supported, in place of
2531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // the Pad -> Reshape -> Slice.
2541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Picture: for each 2D matrix in the tensor's last two dimensions:
2561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // [[1, 0, 0, 0]  pad and reshape to [[1, 0, 0, 0, 0],
2571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    //  [0, 2, 0, 0]  =================>  [2, 0, 0, 0, 0],
2581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    //  [0, 0, 3, 0]]                     [3, 0, 0, 0, 0],
2591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // and then slice out the first column.
2601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    //
2611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Another example, with tall and narrow input.
2621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // [[1, 0]  pad and reshape to [[1, 0, 0],
2631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    //  [0, 2]  =================>  [2, 0, 0]]
2641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    //  [0, 0]
2651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    //  [0, 0]]
2661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Collapses the last two dimensions.
2681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::vector<int64> flattened_dims(dims.begin(), dims.end() - 1);
2691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    flattened_dims.back() *= dims.back();
2701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    diag = builder->Reshape(diag, flattened_dims);
2711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Slices or pads the last dimension to 'target_size'.
2731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    int64 actual_size = flattened_dims.back();
2741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    int64 target_size = smaller_dim_size * (last_dim_size + 1);
2751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    if (actual_size < target_size) {
2761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      xla::PaddingConfig config =
2771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins          xla::MakeNoPaddingConfig(flattened_dims.size());
2781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      auto* dim = config.mutable_dimensions(flattened_dims.size() - 1);
2791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      dim->set_edge_padding_high(target_size - actual_size);
2801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      auto zero = XlaHelpers::Zero(builder, input_type(0));
2811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      diag = builder->Pad(diag, zero, config);
2821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    } else if (actual_size > target_size) {
2831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      std::vector<int64> start(flattened_dims.size(), 0);
2841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      std::vector<int64> limits(flattened_dims.begin(), flattened_dims.end());
28550b999a8336d19400ab75aea66fe46eca2f5fe0bA. Unique TensorFlower      std::vector<int64> strides(flattened_dims.size(), 1);
2861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      limits[flattened_dims.size() - 1] = target_size;
28750b999a8336d19400ab75aea66fe46eca2f5fe0bA. Unique TensorFlower      diag = builder->Slice(diag, start, limits, strides);
2881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    }
2891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Reshape so the target values are in the first position of the last
2911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // dimension.
2921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::vector<int64> unflattened_dims(dims.begin(), dims.end());
2931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    dims[last_dim - 1] = smaller_dim_size;
2941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    dims[last_dim] = last_dim_size + 1;
2951e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    diag = builder->Reshape(diag, dims);
2961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Slices out the first column and reshapes to the final shape.
2981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::vector<int64> start(dims.size(), 0);
2991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::vector<int64> limits(dims.begin(), dims.end());
30050b999a8336d19400ab75aea66fe46eca2f5fe0bA. Unique TensorFlower    std::vector<int64> strides(dims.size(), 1);
3011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    limits[last_dim] = 1;
30250b999a8336d19400ab75aea66fe46eca2f5fe0bA. Unique TensorFlower    diag = builder->Slice(diag, start, limits, strides);
3031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Collapses away the last dimension.
3051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    dims.pop_back();
3061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    diag = builder->Reshape(diag, dims);
3071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    ctx->SetOutput(0, diag);
3091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
3101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
3111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
31293f9caba8e371bd2f55ec789ed2f8ece9b3d976dPeter HawkinsREGISTER_XLA_OP(Name("MatrixDiagPart"), MatrixDiagPartOp);
3131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
3141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace
3151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace tensorflow
316