1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// This file defines helper routines for XLA compilation.
17
18#include "tensorflow/compiler/tf2xla/xla_helpers.h"
19#include "tensorflow/compiler/tf2xla/lib/util.h"
20
21#include "tensorflow/compiler/tf2xla/literal_util.h"
22#include "tensorflow/compiler/tf2xla/type_util.h"
23#include "tensorflow/compiler/tf2xla/xla_context.h"
24#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
25#include "tensorflow/compiler/xla/client/computation_builder.h"
26#include "tensorflow/compiler/xla/types.h"
27#include "tensorflow/core/framework/tensor.h"
28#include "tensorflow/core/lib/gtl/array_slice.h"
29
30namespace tensorflow {
31
32namespace {
33
34Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx,
35                 const xla::ComputationDataHandle& input,
36                 const TensorShape& input_shape, DataType input_type,
37                 DataType output_type, int axis, bool is_min,
38                 xla::ComputationDataHandle* argminmax) {
39  xla::ComputationDataHandle init_value;
40  const xla::Computation* reducer;
41  if (is_min) {
42    init_value = XlaHelpers::MaxValue(builder, input_type);
43    reducer = ctx->GetOrCreateMin(input_type);
44  } else {
45    init_value = XlaHelpers::MinValue(builder, input_type);
46    reducer = ctx->GetOrCreateMax(input_type);
47  }
48
49  xla::PrimitiveType xla_output_type;
50  TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(output_type, &xla_output_type));
51
52  xla::ComputationDataHandle input_max = builder->Reduce(
53      input, init_value, *reducer, /*dimensions_to_reduce=*/{axis});
54  std::vector<int64> broadcast_dims(input_shape.dims() - 1);
55  std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
56  std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
57  // Compute a mask that has 1s for elements equal to the maximum.
58  xla::ComputationDataHandle partial_mask = builder->ConvertElementType(
59      builder->Eq(input, input_max, broadcast_dims), xla_output_type);
60
61  // In order to make identity elements for a bitwise And, we:
62  //   Left shift the 1 to the leftmost bit, yielding 0x10...0
63  //   Arithmetic right shift the 1 back to the rightmost bit, yielding
64  //   0xFF...F
65  int32 bits_in_type =
66      xla::ShapeUtil::ByteSizeOfPrimitiveType(xla_output_type) * 8 - 1;
67  xla::ComputationDataHandle shift_amount =
68      XlaHelpers::IntegerLiteral(builder, output_type, bits_in_type);
69  xla::ComputationDataHandle full_mask = builder->ShiftRightArithmetic(
70      builder->ShiftLeft(partial_mask, shift_amount), shift_amount);
71
72  // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its
73  // index.
74  xla::ComputationDataHandle iota;
75
76  const int64 axis_size = input_shape.dim_size(axis);
77  TF_RETURN_IF_ERROR(XlaHelpers::Iota(builder, output_type, axis_size, &iota));
78  xla::ComputationDataHandle product =
79      builder->And(full_mask, iota, /*broadcast_dimensions=*/{axis});
80
81  // If there are multiple maximum elements, choose the one with the highest
82  // index.
83  xla::ComputationDataHandle output =
84      builder->Reduce(product, XlaHelpers::MinValue(builder, output_type),
85                      *ctx->GetOrCreateMax(output_type),
86                      /*dimensions_to_reduce=*/{axis});
87  *argminmax = output;
88  return Status::OK();
89}
90
91}  // namespace
92
93xla::ComputationDataHandle XlaHelpers::MinValue(xla::ComputationBuilder* b,
94                                                DataType data_type) {
95  xla::PrimitiveType type;
96  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
97  return b->ConstantLiteral(xla::Literal::MinValue(type));
98}
99
100xla::ComputationDataHandle XlaHelpers::MaxValue(xla::ComputationBuilder* b,
101                                                DataType data_type) {
102  xla::PrimitiveType type;
103  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
104  return b->ConstantLiteral(xla::Literal::MaxValue(type));
105}
106
107xla::ComputationDataHandle XlaHelpers::Zero(xla::ComputationBuilder* b,
108                                            DataType data_type) {
109  xla::PrimitiveType type;
110  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
111  return b->ConstantLiteral(xla::Literal::Zero(type));
112}
113
114xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b,
115                                           DataType data_type) {
116  xla::PrimitiveType type;
117  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
118  return b->ConstantLiteral(xla::Literal::One(type));
119}
120
121xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b,
122                                               DataType data_type) {
123  switch (data_type) {
124    case DT_BFLOAT16:
125      return b->ConstantR0<bfloat16>(bfloat16::epsilon());
126    case DT_FLOAT:
127      return b->ConstantR0<float>(std::numeric_limits<float>::epsilon());
128    case DT_DOUBLE:
129      return b->ConstantR0<double>(std::numeric_limits<double>::epsilon());
130    default:
131      LOG(FATAL) << "Unsupported type in XlaHelpers::Epsilon: "
132                 << DataTypeString(data_type);
133  }
134}
135
136xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
137    xla::ComputationBuilder* b, DataType data_type, int64 value) {
138  xla::PrimitiveType type;
139  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
140  return ::tensorflow::IntegerLiteral(b, type, value);
141}
142
143xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b,
144                                                    DataType data_type,
145                                                    double value) {
146  xla::PrimitiveType type;
147  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
148  return ::tensorflow::FloatLiteral(b, type, value);
149}
150
151/* static */ Status XlaHelpers::ReshapeLiteral(
152    const xla::Literal& input, gtl::ArraySlice<int64> dimensions,
153    xla::Literal* output) {
154  if (xla::ShapeUtil::IsTuple(input.shape())) {
155    return errors::InvalidArgument("ReshapeLiteral does not support tuples.");
156  }
157  xla::Shape shape =
158      xla::ShapeUtil::MakeShape(input.shape().element_type(), dimensions);
159  int64 elements_before = xla::ShapeUtil::ElementsIn(input.shape());
160  int64 elements_after = xla::ShapeUtil::ElementsIn(shape);
161  if (elements_before != elements_after) {
162    return errors::InvalidArgument(
163        "Shapes before and after ReshapeLiteral have different numbers of "
164        "elements.");
165  }
166
167  *output = input.Clone();
168  output->mutable_shape_do_not_use()->Swap(&shape);
169  return Status::OK();
170}
171
172template <typename T>
173static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) {
174  Tensor linspace(DataTypeToEnum<T>::v(), shape);
175  auto linspace_flat = linspace.flat<T>();
176  for (int64 i = 0; i < depth; ++i) {
177    linspace_flat(i) = i;
178  }
179  return linspace;
180}
181
182Status XlaHelpers::ArgMax(xla::ComputationBuilder* builder,
183                          XlaOpKernelContext* ctx,
184                          const xla::ComputationDataHandle& input,
185                          const TensorShape& input_shape, DataType input_type,
186                          DataType output_type, int axis,
187                          xla::ComputationDataHandle* argmax) {
188  return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type,
189                   axis, /*is_min=*/false, argmax);
190}
191
192Status XlaHelpers::ArgMin(xla::ComputationBuilder* builder,
193                          XlaOpKernelContext* ctx,
194                          const xla::ComputationDataHandle& input,
195                          const TensorShape& input_shape, DataType input_type,
196                          DataType output_type, int axis,
197                          xla::ComputationDataHandle* argmin) {
198  return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type,
199                   axis, /*is_min=*/true, argmin);
200}
201
202Status XlaHelpers::Iota(xla::ComputationBuilder* builder, DataType dtype,
203                        int64 size, xla::ComputationDataHandle* iota) {
204  TensorShape linspace_shape({size});
205  Tensor linspace;
206  switch (dtype) {
207    case DT_UINT8:
208      linspace = MakeLinspaceTensor<uint8>(linspace_shape, size);
209      break;
210    case DT_INT32:
211      linspace = MakeLinspaceTensor<int32>(linspace_shape, size);
212      break;
213    case DT_INT64:
214      linspace = MakeLinspaceTensor<int64>(linspace_shape, size);
215      break;
216    default:
217      return errors::InvalidArgument("Invalid argument type ",
218                                     DataTypeString(dtype));
219  }
220  xla::Literal linspace_literal;
221  TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal));
222  *iota = builder->ConstantLiteral(linspace_literal);
223  return Status::OK();
224}
225
226Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth,
227                          int axis, DataType index_type,
228                          const TensorShape& indices_shape,
229                          const xla::ComputationDataHandle& indices,
230                          const xla::ComputationDataHandle& on_value,
231                          const xla::ComputationDataHandle& off_value,
232                          xla::ComputationDataHandle* one_hot) {
233  const int indices_dims = indices_shape.dims();
234  const int output_dims = indices_dims + 1;
235
236  TensorShape output_shape = indices_shape;
237  output_shape.InsertDim(axis, depth);
238
239  // Build a Tensor populated with values 0, 1, 2, ... depth.
240  std::vector<int64> linspace_dims(output_dims, 1);
241  linspace_dims[axis] = depth;
242  TensorShape linspace_shape(linspace_dims);
243  Tensor linspace;
244  switch (index_type) {
245    case DT_UINT8:
246      linspace = MakeLinspaceTensor<uint8>(linspace_shape, depth);
247      break;
248    case DT_INT32:
249      linspace = MakeLinspaceTensor<int32>(linspace_shape, depth);
250      break;
251    case DT_INT64:
252      linspace = MakeLinspaceTensor<int64>(linspace_shape, depth);
253      break;
254    default:
255      return errors::InvalidArgument("Invalid argument type ",
256                                     DataTypeString(index_type));
257  }
258  xla::Literal linspace_literal;
259  TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal));
260
261  // Broadcast the linspace constant across the indices along the new axis,
262  // and test equality at each position.
263  std::vector<int64> broadcast_dims(indices_shape.dims());
264  std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
265  std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
266  xla::ComputationDataHandle one_hot_bool = builder->Eq(
267      indices, builder->ConstantLiteral(linspace_literal), broadcast_dims);
268
269  // Selects the user-provided off_value and on_value values.
270  *one_hot = builder->Select(
271      one_hot_bool, builder->Broadcast(on_value, output_shape.dim_sizes()),
272      builder->Broadcast(off_value, output_shape.dim_sizes()));
273  return Status::OK();
274}
275
276}  // end namespace tensorflow
277