1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <math.h>
17#include <algorithm>
18#include <memory>
19#include <new>
20#include <random>
21#include <utility>
22
23#define EIGEN_USE_THREADS
24
25#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
26#include "tensorflow/compiler/xla/array2d.h"
27#include "tensorflow/compiler/xla/client/client_library.h"
28#include "tensorflow/compiler/xla/client/computation.h"
29#include "tensorflow/compiler/xla/client/computation_builder.h"
30#include "tensorflow/compiler/xla/literal_util.h"
31#include "tensorflow/compiler/xla/primitive_util.h"
32#include "tensorflow/compiler/xla/ptr_util.h"
33#include "tensorflow/compiler/xla/service/hlo_computation.h"
34#include "tensorflow/compiler/xla/service/hlo_instruction.h"
35#include "tensorflow/compiler/xla/service/hlo_module.h"
36#include "tensorflow/compiler/xla/service/hlo_opcode.h"
37#include "tensorflow/compiler/xla/service/platform_util.h"
38#include "tensorflow/compiler/xla/shape_util.h"
39#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
40#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
41#include "tensorflow/compiler/xla/tests/literal_test_util.h"
42#include "tensorflow/compiler/xla/tests/test_macros.h"
43#include "tensorflow/compiler/xla/xla_data.pb.h"
44#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
45#include "tensorflow/core/lib/gtl/array_slice.h"
46#include "tensorflow/core/platform/logging.h"
47#include "tensorflow/core/platform/protobuf.h"
48#include "tensorflow/core/platform/test_benchmark.h"
49#include "tensorflow/core/platform/types.h"
50
51using tensorflow::gtl::ArraySlice;
52
53namespace se = ::perftools::gputools;
54
55namespace xla {
56namespace {
57
// Dimensions of each 2D operand built by FusionTest::TestElementwise2D.
const int test_width = 2;
const int test_height = 3;

// Operand data for up to three operands, indexed as
// [operand][0..test_width)[0..test_height).
const float test_float_vals[3][test_width][test_height] = {
    {{-1.0f, -1.0f, 1.0f}, {-3.0f, 0.0f, -1.0f}},
    {{-3.0f, 2.0f, 1.0f}, {0.0f, -3.0f, 1.0f}},
    {{-3.0f, 0.0f, -3.0f}, {-1.0f, -2.0f, 1.0f}}};
64
65// Test whether fusion operations are emitted with no errors and compute
66// accurate outputs.
67class FusionTest : public HloTestBase {
68 protected:
69  template <typename T, int Arity>
70  void TestElementwise2D(HloOpcode opcode) {
71    Array2D<float> operand_data[Arity];
72    for (int i = 0; i < Arity; ++i) {
73      new (&operand_data[i]) Array2D<float>(test_width, test_height);
74    }
75    Array2D<T> answer_data(test_width, test_height);
76    for (int i = 0; i < test_width; ++i) {
77      for (int j = 0; j < test_height; ++j) {
78        float xs[Arity];
79        for (int k = 0; k < Arity; ++k) {
80          xs[k] = test_float_vals[k][i][j];
81          operand_data[k](i, j) = xs[k];
82        }
83        answer_data(i, j) = ComputeElementwiseAnswer<T>(opcode, xs);
84      }
85    }
86
87    auto builder = HloComputation::Builder(TestName());
88    auto hlo_module = CreateNewModule();
89
90    auto prim_type = primitive_util::NativeToPrimitiveType<T>();
91
92    HloInstruction* hlos[4];
93    for (int i = 0; i < Arity; ++i) {
94      hlos[i + 1] = builder.AddInstruction(HloInstruction::CreateConstant(
95          Literal::CreateR2FromArray2D(operand_data[i])));
96    }
97    auto answer_shape =
98        ShapeUtil::MakeShape(prim_type, {test_width, test_height});
99    std::unique_ptr<HloInstruction> root_hlo;
100    switch (Arity) {
101      case 1:
102        root_hlo = HloInstruction::CreateUnary(answer_shape, opcode, hlos[1]);
103        break;
104      case 2:
105        root_hlo = HloInstruction::CreateBinary(answer_shape, opcode, hlos[1],
106                                                hlos[2]);
107        break;
108      case 3:
109        root_hlo = HloInstruction::CreateTernary(answer_shape, opcode, hlos[1],
110                                                 hlos[2], hlos[3]);
111        break;
112      default:
113        LOG(FATAL) << "Bad arity: " << Arity;
114    }
115    hlos[0] = builder.AddInstruction(std::move(root_hlo));
116    hlo_module->AddEntryComputation(builder.Build())
117        ->CreateFusionInstruction(
118            ArraySlice<HloInstruction*>(hlos, 0, Arity + 1),
119            HloInstruction::FusionKind::kLoop);
120
121    auto expected = Literal::CreateR2FromArray2D(answer_data);
122    auto actual = ExecuteAndTransfer(std::move(hlo_module), {});
123    if (primitive_util::IsFloatingPointType(prim_type)) {
124      LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(1e-4));
125    } else {
126      LiteralTestUtil::ExpectEqual(*expected, *actual);
127    }
128  }
129
130 private:
131  template <typename T>
132  T ComputeElementwiseAnswer(HloOpcode opcode, ArraySlice<float> xs);
133};
134
135template <>
136float FusionTest::ComputeElementwiseAnswer<float>(HloOpcode opcode,
137                                                  ArraySlice<float> xs) {
138  switch (opcode) {
139    case HloOpcode::kAdd:
140      return xs[0] + xs[1];
141    case HloOpcode::kSubtract:
142      return xs[0] - xs[1];
143    case HloOpcode::kMultiply:
144      return xs[0] * xs[1];
145    case HloOpcode::kDivide:
146      return xs[0] / xs[1];
147    case HloOpcode::kPower:
148      return powf(xs[0], xs[1]);
149    case HloOpcode::kMinimum:
150      return std::min(xs[0], xs[1]);
151    case HloOpcode::kMaximum:
152      return std::max(xs[0], xs[1]);
153    case HloOpcode::kClamp:
154      return std::min(xs[2], std::max(xs[1], xs[0]));
155    default:
156      LOG(FATAL) << "No elementwise opcode: " << opcode;
157  }
158}
159
160template <>
161bool FusionTest::ComputeElementwiseAnswer<bool>(HloOpcode opcode,
162                                                ArraySlice<float> xs) {
163  switch (opcode) {
164    case HloOpcode::kEq:
165      return xs[0] == xs[1];
166    case HloOpcode::kNe:
167      return xs[0] != xs[1];
168    case HloOpcode::kGt:
169      return xs[0] > xs[1];
170    case HloOpcode::kLt:
171      return xs[0] < xs[1];
172    case HloOpcode::kGe:
173      return xs[0] >= xs[1];
174    case HloOpcode::kLe:
175      return xs[0] <= xs[1];
176    default:
177      LOG(FATAL) << "No comparatory opcode: " << opcode;
178  }
179}
180
181XLA_TEST_F(FusionTest, Test) {
182  // test expression:
183  // slice(select({{T, F, T}, {F, T, F}},
184  //              concat(transpose({{1.0}, {2.0}, {3.0}} +
185  //                               {{-1.0}, {-1.0}, {-1.0}}),
186  //                     {{1.62, 2.72, 3.14}}) +
187  //                     (-{{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}),
188  //              {{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})) = {{0.5}, {2.72}}
189  auto builder = HloComputation::Builder(TestName());
190  auto hlo_module = CreateNewModule();
191  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
192      Literal::CreateR2<float>({{1.0}, {2.0}, {3.0}})));
193  auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
194      Literal::CreateR2<float>({{-1.0}, {-1.0}, {-1.0}})));
195  auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
196      ShapeUtil::MakeShape(F32, {3, 1}), HloOpcode::kAdd, const0, const1));
197  auto reshape3 = builder.AddInstruction(HloInstruction::CreateTranspose(
198      ShapeUtil::MakeShape(F32, {1, 3}), add2, {1, 0}));
199  auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
200      Literal::CreateR2<float>({{1.62, 2.72, 3.14}})));
201  auto concat5 = builder.AddInstruction(HloInstruction::CreateConcatenate(
202      ShapeUtil::MakeShape(F32, {2, 3}), {reshape3, const4}, 0));
203  auto const6 = builder.AddInstruction(HloInstruction::CreateConstant(
204      Literal::CreateR2<float>({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}})));
205  auto negate7 = builder.AddInstruction(HloInstruction::CreateUnary(
206      ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kNegate, const6));
207  auto add8 = builder.AddInstruction(HloInstruction::CreateBinary(
208      ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kAdd, concat5, negate7));
209  auto const9 = builder.AddInstruction(HloInstruction::CreateConstant(
210      Literal::CreateR2<float>({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})));
211  auto const10 = builder.AddInstruction(HloInstruction::CreateConstant(
212      Literal::CreateR2<bool>({{true, false, true}, {false, true, false}})));
213  auto select11 = builder.AddInstruction(
214      HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}),
215                                    HloOpcode::kSelect, const10, add8, const9));
216  auto slice12 = builder.AddInstruction(HloInstruction::CreateSlice(
217      ShapeUtil::MakeShape(F32, {2, 1}), select11, {0, 1}, {2, 2}, {1, 1}));
218  // CreateFusionInstruction needs the `instructions_to_fuse` argument in
219  // reverse topological order, so the first element in `instructions_to_fuse`
220  // must be the root.
221  hlo_module->AddEntryComputation(builder.Build())
222      ->CreateFusionInstruction(
223          {slice12, select11, const10, const9, add8, negate7, const6, concat5,
224           const4, reshape3, add2, const1, const0},
225          HloInstruction::FusionKind::kLoop);
226
227  LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>({{0.5}, {2.72}}),
228                              *ExecuteAndTransfer(std::move(hlo_module), {}),
229                              ErrorSpec(1e-4));
230}
231
232// Test whether we emit appropriate code for parameters of fusion instructions.
233XLA_TEST_F(FusionTest, Parameter) {
234  // Build a computation and fuse part of it so the fusion instruction has an
235  // operand parameter.
236  auto builder = HloComputation::Builder(TestName());
237  auto hlo_module = CreateNewModule();
238  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
239      Literal::CreateR2<float>({{1.0, 2.0, 3.0}})));
240  auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary(
241      ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kCopy, const0));
242  auto const2 = builder.AddInstruction(HloInstruction::CreateConstant(
243      Literal::CreateR2<float>({{-2.0, -2.0, -2.0}})));
244  // add3 = copy1 + const2 = const0 + const2 = {1,2,3} + {-2,-2,-2} = {-1,0,+1}
245  auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
246      ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kAdd, copy1, const2));
247  // CreateFusionInstruction needs `instructions_to_fuse` in reverse topological
248  // order.
249  hlo_module->AddEntryComputation(builder.Build())
250      ->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2},
251                                HloInstruction::FusionKind::kLoop);
252
253  LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>({{-1.0, 0.0, 1.0}}),
254                              *ExecuteAndTransfer(std::move(hlo_module), {}),
255                              ErrorSpec(1e-4));
256}
257
258XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
259  // Tests parallel partitioning of a fusion instruction.
260  // Create shape with random outer dimension size to generate random parallel
261  // partition counts for each test run.
262  const int seed = tensorflow::testing::RandomSeed();
263  LOG(INFO) << "RandomizedParallelPartition seed: " << seed;
264  std::mt19937 generator(seed);
265  std::uniform_int_distribution<int> distribution(128, 1024);
266  const int64 rand_dim0_size = distribution(generator);
267  const int64 dim1_size = 1024;
268  Shape shape =
269      ShapeUtil::MakeShapeWithLayout(F32, {rand_dim0_size, dim1_size}, {1, 0});
270  // Build simple fusion computation: y = x^2 (elementwise).
271  auto builder = HloComputation::Builder(TestName());
272  auto hlo_module = CreateNewModule();
273
274  auto two = builder.AddInstruction(
275      HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
276  auto x =
277      builder.AddInstruction(HloInstruction::CreateBroadcast(shape, two, {}));
278  auto y = builder.AddInstruction(
279      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, x, x));
280
281  hlo_module->AddEntryComputation(builder.Build())
282      ->CreateFusionInstruction(/*instructions_to_fuse=*/{y, x, two},
283                                HloInstruction::FusionKind::kLoop);
284  // Compute result.
285  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
286  // Every element of result should be y = x^2 = 4.0.
287  for (int i = 0; i < rand_dim0_size; ++i) {
288    for (int j = 0; j < dim1_size; ++j) {
289      EXPECT_EQ(4.0, result->Get<float>({i, j}));
290    }
291  }
292}
293
294XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
295  auto builder = HloComputation::Builder(TestName());
296  auto hlo_module = CreateNewModule();
297  auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant(
298      Literal::CreateR1<float>({1.0, 2.0, 3.0})));
299  auto const_array = builder.AddInstruction(HloInstruction::CreateConstant(
300      Literal::CreateR2<float>({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}})));
301  auto broadcast = builder.AddInstruction(
302      HloInstruction::CreateBroadcast(const_array->shape(), const_vector, {1}));
303  // add2 = broadcast(const_vector) + const_array
304  //      = broadcast({1,2,3}) + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}
305  //      = {{1, 2, 3}, {1, 2, 3}} + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}
306  auto add2 = builder.AddInstruction(
307      HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {2, 3}),
308                                   HloOpcode::kAdd, broadcast, const_array));
309  hlo_module->AddEntryComputation(builder.Build())
310      ->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast},
311                                HloInstruction::FusionKind::kLoop);
312
313  LiteralTestUtil::ExpectNear(
314      *Literal::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
315      *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4));
316}
317
318XLA_TEST_F(FusionTest, ReshapeToScalar) {
319  auto builder = HloComputation::Builder(TestName());
320  auto hlo_module = CreateNewModule();
321  auto single_element_array = builder.AddInstruction(
322      HloInstruction::CreateConstant(Literal::CreateR2<int32>({{5}})));
323  auto reshape = builder.AddInstruction(HloInstruction::CreateReshape(
324      ShapeUtil::MakeShape(S32, {}), single_element_array));
325  hlo_module->AddEntryComputation(builder.Build())
326      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape},
327                                HloInstruction::FusionKind::kLoop);
328  LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(5),
329                               *ExecuteAndTransfer(std::move(hlo_module), {}));
330}
331
332XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
333  auto builder = HloComputation::Builder(TestName());
334  auto hlo_module = CreateNewModule();
335  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
336      Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}})));
337  auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
338      ShapeUtil::MakeShape(S32, {1, 2, 3}), const0));
339  hlo_module->AddEntryComputation(builder.Build())
340      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
341                                HloInstruction::FusionKind::kLoop);
342  LiteralTestUtil::ExpectEqual(
343      *Literal::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
344      *ExecuteAndTransfer(std::move(hlo_module), {}));
345}
346
347XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
348  auto builder = HloComputation::Builder(TestName());
349  auto hlo_module = CreateNewModule();
350  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
351      Literal::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}})));
352  auto reshape1 = builder.AddInstruction(
353      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 2}), const0));
354  hlo_module->AddEntryComputation(builder.Build())
355      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
356                                HloInstruction::FusionKind::kLoop);
357  LiteralTestUtil::ExpectEqual(
358      *Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
359      *ExecuteAndTransfer(std::move(hlo_module), {}));
360}
361
362XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
363  auto builder = HloComputation::Builder(TestName());
364  auto hlo_module = CreateNewModule();
365  auto const0 = builder.AddInstruction(
366      HloInstruction::CreateConstant(Literal::CreateR3<int32>({{{7}}})));
367  auto reshape1 = builder.AddInstruction(
368      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
369  hlo_module->AddEntryComputation(builder.Build())
370      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
371                                HloInstruction::FusionKind::kLoop);
372  LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(7),
373                               *ExecuteAndTransfer(std::move(hlo_module), {}));
374}
375
376XLA_TEST_F(FusionTest, Reshape__1by1by1) {
377  auto builder = HloComputation::Builder(TestName());
378  auto hlo_module = CreateNewModule();
379  auto const0 = builder.AddInstruction(
380      HloInstruction::CreateConstant(Literal::CreateR0<int32>(7)));
381  auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
382      ShapeUtil::MakeShape(S32, {1, 1, 1}), const0));
383  hlo_module->AddEntryComputation(builder.Build())
384      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
385                                HloInstruction::FusionKind::kLoop);
386  LiteralTestUtil::ExpectEqual(*Literal::CreateR3<int32>({{{7}}}),
387                               *ExecuteAndTransfer(std::move(hlo_module), {}));
388}
389
390XLA_TEST_F(FusionTest, Reshape__) {
391  auto builder = HloComputation::Builder(TestName());
392  auto hlo_module = CreateNewModule();
393  auto const0 = builder.AddInstruction(
394      HloInstruction::CreateConstant(Literal::CreateR0<int32>(7)));
395  auto reshape1 = builder.AddInstruction(
396      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
397  hlo_module->AddEntryComputation(builder.Build())
398      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
399                                HloInstruction::FusionKind::kLoop);
400  LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(7),
401                               *ExecuteAndTransfer(std::move(hlo_module), {}));
402}
403
404XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
405  auto builder = HloComputation::Builder(TestName());
406  auto hlo_module = CreateNewModule();
407  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
408      Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
409  auto reshape1 = builder.AddInstruction(
410      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 3}), const0));
411  hlo_module->AddEntryComputation(builder.Build())
412      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
413                                HloInstruction::FusionKind::kLoop);
414  LiteralTestUtil::ExpectEqual(
415      *Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
416      *ExecuteAndTransfer(std::move(hlo_module), {}));
417}
418
419XLA_TEST_F(FusionTest, Transpose_2by3) {
420  auto builder = HloComputation::Builder(TestName());
421  auto hlo_module = CreateNewModule();
422  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
423      Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}})));
424  auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
425      ShapeUtil::MakeShape(S32, {3, 2}), const0, {1, 0}));
426  hlo_module->AddEntryComputation(builder.Build())
427      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
428                                HloInstruction::FusionKind::kLoop);
429  LiteralTestUtil::ExpectEqual(
430      *Literal::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
431      *ExecuteAndTransfer(std::move(hlo_module), {}));
432}
433
434XLA_TEST_F(FusionTest, Transpose_3by3) {
435  auto builder = HloComputation::Builder(TestName());
436  auto hlo_module = CreateNewModule();
437  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
438      Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
439  auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
440      ShapeUtil::MakeShape(S32, {3, 3}), const0, {1, 0}));
441  hlo_module->AddEntryComputation(builder.Build())
442      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
443                                HloInstruction::FusionKind::kLoop);
444  LiteralTestUtil::ExpectEqual(
445      *Literal::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
446      *ExecuteAndTransfer(std::move(hlo_module), {}));
447}
448
449XLA_TEST_F(FusionTest, Reverse) {
450  auto builder = HloComputation::Builder(TestName());
451  auto hlo_module = CreateNewModule();
452  auto const0 = builder.AddInstruction(
453      HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3})));
454  auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
455      ShapeUtil::MakeShape(S32, {3}), const0, {0}));
456  hlo_module->AddEntryComputation(builder.Build())
457      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1},
458                                HloInstruction::FusionKind::kLoop);
459
460  LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({3, 2, 1}),
461                               *ExecuteAndTransfer(std::move(hlo_module), {}));
462}
463
464XLA_TEST_F(FusionTest, ReverseNegate) {
465  auto builder = HloComputation::Builder(TestName());
466  auto hlo_module = CreateNewModule();
467  auto const0 = builder.AddInstruction(
468      HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3})));
469  auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
470      ShapeUtil::MakeShape(S32, {3}), const0, {0}));
471  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
472      ShapeUtil::MakeShape(S32, {3}), HloOpcode::kNegate, reverse1));
473  hlo_module->AddEntryComputation(builder.Build())
474      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reverse1},
475                                HloInstruction::FusionKind::kLoop);
476
477  LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-3, -2, -1}),
478                               *ExecuteAndTransfer(std::move(hlo_module), {}));
479}
480
481XLA_TEST_F(FusionTest, BroadcastNegate) {
482  auto builder = HloComputation::Builder(TestName());
483  auto hlo_module = CreateNewModule();
484  auto const0 = builder.AddInstruction(
485      HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
486  auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast(
487      ShapeUtil::MakeShape(S32, {2}), const0, {}));
488  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
489      ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, broadcast1));
490  hlo_module->AddEntryComputation(builder.Build())
491      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, broadcast1},
492                                HloInstruction::FusionKind::kLoop);
493
494  LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-1, -1}),
495                               *ExecuteAndTransfer(std::move(hlo_module), {}));
496}
497
498XLA_TEST_F(FusionTest, SliceNegate) {
499  auto builder = HloComputation::Builder(TestName());
500  auto hlo_module = CreateNewModule();
501  auto const0 = builder.AddInstruction(
502      HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3, 4})));
503  auto slice1 = builder.AddInstruction(HloInstruction::CreateSlice(
504      ShapeUtil::MakeShape(S32, {2}), const0, {0}, {4}, {2}));
505  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
506      ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, slice1));
507  hlo_module->AddEntryComputation(builder.Build())
508      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, slice1},
509                                HloInstruction::FusionKind::kLoop);
510
511  LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-1, -3}),
512                               *ExecuteAndTransfer(std::move(hlo_module), {}));
513}
514
515XLA_TEST_F(FusionTest, DynamicSliceNegate) {
516  auto builder = HloComputation::Builder(TestName());
517  auto hlo_module = CreateNewModule();
518  auto const0 = builder.AddInstruction(
519      HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3, 4})));
520  auto const1 = builder.AddInstruction(
521      HloInstruction::CreateConstant(Literal::CreateR1<int32>({1})));
522  auto dynamic_slice2 =
523      builder.AddInstruction(HloInstruction::CreateDynamicSlice(
524          ShapeUtil::MakeShape(S32, {2}), const0, const1, {2}));
525  auto negate3 = builder.AddInstruction(HloInstruction::CreateUnary(
526      ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, dynamic_slice2));
527  hlo_module->AddEntryComputation(builder.Build())
528      ->CreateFusionInstruction(
529          /*instructions_to_fuse=*/{negate3, dynamic_slice2},
530          HloInstruction::FusionKind::kLoop);
531
532  LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-2, -3}),
533                               *ExecuteAndTransfer(std::move(hlo_module), {}));
534}
535
536XLA_TEST_F(FusionTest, ReshapeNegate) {
537  auto builder = HloComputation::Builder(TestName());
538  auto hlo_module = CreateNewModule();
539  auto const0 = builder.AddInstruction(
540      HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3, 4})));
541  auto reshape1 = builder.AddInstruction(
542      HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {2, 2}), const0));
543  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
544      ShapeUtil::MakeShape(S32, {2, 2}), HloOpcode::kNegate, reshape1));
545  hlo_module->AddEntryComputation(builder.Build())
546      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1},
547                                HloInstruction::FusionKind::kLoop);
548
549  LiteralTestUtil::ExpectEqual(*Literal::CreateR2<int32>({{-1, -2}, {-3, -4}}),
550                               *ExecuteAndTransfer(std::move(hlo_module), {}));
551}
552
553// TODO(b/64070202): Investigate failure.
554XLA_TEST_F(FusionTest, DISABLED_ON_GPU(TransposeNegate)) {
555  auto builder = HloComputation::Builder(TestName());
556  auto hlo_module = CreateNewModule();
557  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
558      Literal::CreateR2<int32>({{1, 2}, {3, 4}})));
559  auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose(
560      ShapeUtil::MakeShape(S32, {2, 2}), const0, {1, 0}));
561  auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
562      ShapeUtil::MakeShape(S32, {2, 2}), HloOpcode::kNegate, transpose1));
563  hlo_module->AddEntryComputation(builder.Build())
564      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1},
565                                HloInstruction::FusionKind::kLoop);
566
567  LiteralTestUtil::ExpectEqual(*Literal::CreateR2<int32>({{-1, -3}, {-2, -4}}),
568                               *ExecuteAndTransfer(std::move(hlo_module), {}));
569}
570
571std::unique_ptr<HloComputation> MakeReduceTestComputation() {
572  auto builder = HloComputation::Builder("add");
573  auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
574      /*parameter_number=*/0, ShapeUtil::MakeShape(S32, {}), "lhs"));
575  auto rhs = builder.AddInstruction(HloInstruction::CreateParameter(
576      /*parameter_number=*/1, ShapeUtil::MakeShape(S32, {}), "rhs"));
577  builder.AddInstruction(HloInstruction::CreateBinary(
578      ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, lhs, rhs));
579  return builder.Build();
580}
581
582XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
583  auto hlo_module = CreateNewModule();
584
585  auto builder = HloComputation::Builder(TestName());
586  auto const0 = builder.AddInstruction(
587      HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 4, 8})));
588  auto const1 = builder.AddInstruction(
589      HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
590  auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
591      ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
592      hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
593  hlo_module->AddEntryComputation(builder.Build())
594      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2},
595                                HloInstruction::FusionKind::kLoop);
596
597  LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(15),
598                               *ExecuteAndTransfer(std::move(hlo_module), {}));
599}
600
601XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
602  auto hlo_module = CreateNewModule();
603
604  auto builder = HloComputation::Builder(TestName());
605  auto const0 = builder.AddInstruction(
606      HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 4, 8})));
607  auto const1 = builder.AddInstruction(
608      HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
609  auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
610      ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
611      hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
612  auto negate3 = builder.AddInstruction(HloInstruction::CreateUnary(
613      ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, reduce2));
614  hlo_module->AddEntryComputation(builder.Build())
615      ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2},
616                                HloInstruction::FusionKind::kLoop);
617
618  LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(-15),
619                               *ExecuteAndTransfer(std::move(hlo_module), {}));
620}
621
622XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
623  auto builder = HloComputation::Builder(TestName());
624  auto hlo_module = CreateNewModule();
625  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
626      Literal::CreateR2<int32>({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}})));
627  auto const1 = builder.AddInstruction(
628      HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
629  Window window;
630  ASSERT_TRUE(
631      tensorflow::protobuf::TextFormat::ParseFromString("dimensions:{\n"
632                                                        "size:2\n"
633                                                        "stride:1\n"
634                                                        "padding_low:0\n"
635                                                        "padding_high:0\n"
636                                                        "window_dilation:1\n"
637                                                        "base_dilation:1\n"
638                                                        "}\n"
639                                                        "dimensions:{\n"
640                                                        "size:2\n"
641                                                        "stride:1\n"
642                                                        "padding_low:0\n"
643                                                        "padding_high:0\n"
644                                                        "window_dilation:1\n"
645                                                        "base_dilation:1\n"
646                                                        "}\n",
647                                                        &window));
648  auto nested_builder = HloComputation::Builder("mul");
649  {
650    auto x = nested_builder.AddInstruction(
651        HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "x"));
652    auto y = nested_builder.AddInstruction(
653        HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(S32, {}), "y"));
654    nested_builder.AddInstruction(HloInstruction::CreateBinary(
655        ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply, x, y));
656  }
657  auto nested_computation =
658      hlo_module->AddEmbeddedComputation(nested_builder.Build());
659  auto reduce_window2 =
660      builder.AddInstruction(HloInstruction::CreateReduceWindow(
661          ShapeUtil::MakeShape(S32, {2, 2}), const0, const1, window,
662          nested_computation));
663  hlo_module->AddEntryComputation(builder.Build())
664      ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2},
665                                HloInstruction::FusionKind::kLoop);
666
667  LiteralTestUtil::ExpectEqual(
668      *Literal::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
669      *ExecuteAndTransfer(std::move(hlo_module), {}));
670}
671
672// When a constant (or other op) which has multiple users is imported
673// into a fusion, it should remain shared, rather than being duplicated
674// within the fusion.
675XLA_TEST_F(FusionTest, SharedConstant) {
676  auto hlo_module = CreateNewModule();
677
678  auto builder = HloComputation::Builder(TestName());
679  auto const0 = builder.AddInstruction(
680          HloInstruction::CreateConstant(Literal::CreateR1<int32>({0})));
681  auto const1 = builder.AddInstruction(
682          HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
683  auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
684          ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0));
685  auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
686          ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add1));
687  auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
688          ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add2));
689  auto add4 = builder.AddInstruction(HloInstruction::CreateBinary(
690          ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add3));
691  hlo_module->AddEntryComputation(builder.Build())
692      ->CreateFusionInstruction(
693        {add4, add3, add2, add1, const1},
694        HloInstruction::FusionKind::kLoop);
695
696  HloComputation* entry_comp = hlo_module->entry_computation();
697
698  // entry computation contains the constant(0) and the fusion
699  EXPECT_EQ(entry_comp->instruction_count(), 2);
700
701  // fused instruction contains the constant(2), the parameter, and 4 adds
702  EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6);
703
704  LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({8}),
705          *ExecuteAndTransfer(std::move(hlo_module), {}));
706}
707
708XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); }
709
710XLA_TEST_F(FusionTest, Subtract2D) {
711  TestElementwise2D<float, 2>(HloOpcode::kSubtract);
712}
713
714XLA_TEST_F(FusionTest, Multiply2D) {
715  TestElementwise2D<float, 2>(HloOpcode::kMultiply);
716}
717
718XLA_TEST_F(FusionTest, Divide2D) {
719  TestElementwise2D<float, 2>(HloOpcode::kDivide);
720}
721
722XLA_TEST_F(FusionTest, Power2D) {
723  TestElementwise2D<float, 2>(HloOpcode::kPower);
724}
725
726XLA_TEST_F(FusionTest, Minimum2D) {
727  TestElementwise2D<float, 2>(HloOpcode::kMinimum);
728}
729
730XLA_TEST_F(FusionTest, Maximum2D) {
731  TestElementwise2D<float, 2>(HloOpcode::kMaximum);
732}
733
734XLA_TEST_F(FusionTest, Equal2D) { TestElementwise2D<bool, 2>(HloOpcode::kEq); }
735
736XLA_TEST_F(FusionTest, Inequal2D) {
737  TestElementwise2D<bool, 2>(HloOpcode::kNe);
738}
739
740XLA_TEST_F(FusionTest, Greater2D) {
741  TestElementwise2D<bool, 2>(HloOpcode::kGt);
742}
743
744XLA_TEST_F(FusionTest, Lesser2D) { TestElementwise2D<bool, 2>(HloOpcode::kLt); }
745
746XLA_TEST_F(FusionTest, GreaterOrEqual2D) {
747  TestElementwise2D<bool, 2>(HloOpcode::kGe);
748}
749
750XLA_TEST_F(FusionTest, LesserOrEqual2D) {
751  TestElementwise2D<bool, 2>(HloOpcode::kLe);
752}
753
754XLA_TEST_F(FusionTest, Clamp2D) {
755  TestElementwise2D<float, 3>(HloOpcode::kClamp);
756}
757
758void BM_ParallelFusion(int num_iters) {
759  // Simple element-wise computation to benchmark parallel task partitioning.
760  tensorflow::testing::StopTiming();
761
762  se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
763  auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
764  StreamExecutorMemoryAllocator allocator(platform, executors);
765
766  const int64 intra_op_parallelism_threads = 24;
767  xla::LocalClientOptions client_options;
768  client_options.set_platform(platform);
769  client_options.set_intra_op_parallelism_threads(intra_op_parallelism_threads);
770  auto client =
771      ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie();
772
773  int device_ordinal = client->default_device_ordinal();
774
775  // Computation shape parameters.
776  const int64 param0_dim0 = 1024;
777  const int64 param0_dim1 = 1024;
778  const int64 param1_dim0 = 1024;
779  const int64 param1_dim1 = 1024;
780  const int64 param2_dim0 = 1024;
781  const int64 param2_dim1 = 1024;
782
783  // Create computation.
784  ComputationBuilder builder(client, "ParallelFusion");
785  Shape shape0 = ShapeUtil::MakeShape(F32, {param0_dim0, param0_dim1});
786  auto param0 = builder.Parameter(0, shape0, "param0");
787  Shape shape1 = ShapeUtil::MakeShape(F32, {param1_dim0, param1_dim1});
788  auto param1 = builder.Parameter(1, shape1, "param1");
789  Shape shape2 = ShapeUtil::MakeShape(F32, {param2_dim0, param2_dim1});
790  auto param2 = builder.Parameter(2, shape2, "param2");
791
792  auto x = builder.Mul(param0, param1);
793  auto y = builder.Add(x, param2);
794  auto computation = builder.Build().ConsumeValueOrDie();
795
796  // Transfer literals to device.
797  auto param0_literal =
798      Literal::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1);
799  std::unique_ptr<ShapedBuffer> buffer0 =
800      client->LiteralToShapedBuffer(*param0_literal, device_ordinal)
801          .ConsumeValueOrDie();
802
803  auto param1_literal =
804      Literal::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1);
805  std::unique_ptr<ShapedBuffer> buffer1 =
806      client->LiteralToShapedBuffer(*param1_literal, device_ordinal)
807          .ConsumeValueOrDie();
808
809  auto param2_literal =
810      Literal::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1);
811  std::unique_ptr<ShapedBuffer> buffer2 =
812      client->LiteralToShapedBuffer(*param2_literal, device_ordinal)
813          .ConsumeValueOrDie();
814
815  // Build executable.
816  std::unique_ptr<LocalExecutable> executable =
817      client
818          ->Compile(computation,
819                    {&buffer0->on_host_shape(), &buffer1->on_host_shape(),
820                     &buffer2->on_host_shape()},
821                    ExecutableBuildOptions())
822          .ConsumeValueOrDie();
823
824  se::Stream stream(executors[device_ordinal]);
825  stream.Init();
826
827  // Initialize thread pool.
828  tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
829                                      intra_op_parallelism_threads);
830  tensorflow::EigenThreadPoolWrapper tp(&pool);
831  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
832
833  // Initialize ExecutableRunOptions.
834  ExecutableRunOptions options;
835  options.set_allocator(&allocator).set_stream(&stream);
836  options.set_intra_op_thread_pool(&device);
837
838  // Run some warm-up executions.
839  const int kWarmups = 2;
840  for (int i = 0; i < kWarmups; ++i) {
841    auto result =
842        executable->Run({buffer0.get(), buffer1.get(), buffer2.get()}, options);
843    ASSERT_TRUE(result.ok());
844  }
845
846  // Run benchmark.
847  const int64 total_bytes = param0_dim0 * param0_dim0 +
848                            param1_dim0 * param1_dim0 +
849                            param2_dim0 * param2_dim0;
850  tensorflow::testing::BytesProcessed(static_cast<int64>(num_iters) *
851                                      total_bytes * sizeof(float));
852  tensorflow::testing::UseRealTime();
853  tensorflow::testing::StartTiming();
854  for (int i = 0; i < num_iters; ++i) {
855    auto result =
856        executable->Run({buffer0.get(), buffer1.get(), buffer2.get()}, options);
857    ASSERT_TRUE(result.ok());
858  }
859}
860
861BENCHMARK(BM_ParallelFusion);
862
863}  // namespace
864}  // namespace xla
865