1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
17
18#include <memory>
19#include <utility>
20
21#include "tensorflow/compiler/xla/layout_util.h"
22#include "tensorflow/compiler/xla/literal_util.h"
23#include "tensorflow/compiler/xla/ptr_util.h"
24#include "tensorflow/compiler/xla/service/hlo_computation.h"
25#include "tensorflow/compiler/xla/service/hlo_instruction.h"
26#include "tensorflow/compiler/xla/service/hlo_matchers.h"
27#include "tensorflow/compiler/xla/service/hlo_opcode.h"
28#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
29#include "tensorflow/compiler/xla/shape_util.h"
30#include "tensorflow/compiler/xla/test.h"
31#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
32#include "tensorflow/compiler/xla/types.h"
33#include "tensorflow/compiler/xla/window_util.h"
34#include "tensorflow/compiler/xla/xla_data.pb.h"
35#include "tensorflow/core/lib/core/status_test_util.h"
36#include "tensorflow/core/lib/strings/str_util.h"
37
38namespace xla {
39namespace {
40
41namespace op = xla::testing::opcode_matchers;
42
43AlgebraicSimplifier::ValidBitcastCallback bitcasting_callback() {
44  return [](const Shape&, const Shape&) { return true; };
45}
46
47AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() {
48  return [](const Shape&, const Shape&) { return false; };
49}
50
51class AlgebraicSimplifierTest : public HloVerifiedTestBase {};
52
53// Test that A + 0 is simplified to A
54TEST_F(AlgebraicSimplifierTest, AddZero) {
55  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
56  HloComputation::Builder builder(TestName());
57  HloInstruction* param0 = builder.AddInstruction(
58      HloInstruction::CreateParameter(0, r0f32, "param0"));
59  HloInstruction* zero = builder.AddInstruction(
60      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
61  builder.AddInstruction(
62      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero));
63
64  auto computation = module().AddEntryComputation(builder.Build());
65  HloInstruction* root = computation->root_instruction();
66  EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
67  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
68                                 non_bitcasting_callback());
69  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
70  root = computation->root_instruction();
71  EXPECT_EQ(root, param0);
72}
73
74// Test that Const + A is canonicalized to A + Const.
75TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) {
76  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
77  HloComputation::Builder builder(TestName());
78  HloInstruction* param0 = builder.AddInstruction(
79      HloInstruction::CreateParameter(0, r0f32, "param0"));
80  HloInstruction* constant = builder.AddInstruction(
81      HloInstruction::CreateConstant(Literal::CreateR0(42.0f)));
82  builder.AddInstruction(
83      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0));
84
85  auto computation = module().AddEntryComputation(builder.Build());
86  HloInstruction* root = computation->root_instruction();
87  EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
88  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
89                                 non_bitcasting_callback());
90  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
91  root = computation->root_instruction();
92  EXPECT_THAT(root, op::Add(param0, op::Constant()));
93}
94
95// Test that [(A + C1) + C2] => [A + (C1 + C2)] for constants C1 and C2.
96TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) {
97  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
98  HloComputation::Builder builder(TestName());
99  HloInstruction* param0 = builder.AddInstruction(
100      HloInstruction::CreateParameter(0, r0f32, "param0"));
101  HloInstruction* constant1 = builder.AddInstruction(
102      HloInstruction::CreateConstant(Literal::CreateR0(42.0f)));
103  HloInstruction* constant2 = builder.AddInstruction(
104      HloInstruction::CreateConstant(Literal::CreateR0(3.14159f)));
105
106  HloInstruction* add1 = builder.AddInstruction(
107      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, constant1));
108  builder.AddInstruction(
109      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, add1, constant2));
110
111  auto computation = module().AddEntryComputation(builder.Build());
112  HloInstruction* root = computation->root_instruction();
113  EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
114  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
115                                 non_bitcasting_callback());
116  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
117  root = computation->root_instruction();
118  EXPECT_THAT(root, op::Add(param0, op::Add(constant1, constant2)));
119}
120
121TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) {
122  Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2});
123  HloComputation::Builder builder(TestName());
124  HloInstruction* param0 = builder.AddInstruction(
125      HloInstruction::CreateParameter(0, r2f32, "param0"));
126  HloInstruction* zero = builder.AddInstruction(
127      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
128  HloInstruction* bcast = builder.AddInstruction(
129      HloInstruction::CreateBroadcast(r2f32, zero, {0, 1}));
130  builder.AddInstruction(
131      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));
132
133  auto computation = module().AddEntryComputation(builder.Build());
134  HloInstruction* root = computation->root_instruction();
135  EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
136  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
137                                 non_bitcasting_callback());
138  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
139  root = computation->root_instruction();
140  EXPECT_EQ(root, param0);
141}
142
143TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
144  Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2});
145  HloComputation::Builder builder(TestName());
146  HloInstruction* param0 = builder.AddInstruction(
147      HloInstruction::CreateParameter(0, r2f32, "param0"));
148  HloInstruction* zero = builder.AddInstruction(
149      HloInstruction::CreateConstant(Literal::CreateR1<float>({0, 0, 0})));
150  HloInstruction* bcast =
151      builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, zero, {1}));
152  builder.AddInstruction(
153      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0));
154
155  auto computation = module().AddEntryComputation(builder.Build());
156  HloInstruction* root = computation->root_instruction();
157  EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
158  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
159                                 non_bitcasting_callback());
160  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
161  root = computation->root_instruction();
162  EXPECT_EQ(root, param0);
163}
164
165// Test that A - 0 is simplified to A
166TEST_F(AlgebraicSimplifierTest, SubZero) {
167  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
168  HloComputation::Builder builder(TestName());
169  HloInstruction* param0 = builder.AddInstruction(
170      HloInstruction::CreateParameter(0, r0f32, "param0"));
171  HloInstruction* zero = builder.AddInstruction(
172      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
173  builder.AddInstruction(
174      HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero));
175
176  auto computation = module().AddEntryComputation(builder.Build());
177  HloInstruction* root = computation->root_instruction();
178  EXPECT_EQ(root->opcode(), HloOpcode::kSubtract);
179  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
180                                 non_bitcasting_callback());
181  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
182  root = computation->root_instruction();
183  EXPECT_EQ(root, param0);
184}
185
186// Test that A - Const is canonicalized to A + (-Const).
187TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) {
188  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
189  HloComputation::Builder builder(TestName());
190  HloInstruction* param0 = builder.AddInstruction(
191      HloInstruction::CreateParameter(0, r0f32, "param0"));
192  HloInstruction* constant = builder.AddInstruction(
193      HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
194  builder.AddInstruction(HloInstruction::CreateBinary(
195      r0f32, HloOpcode::kSubtract, param0, constant));
196
197  auto computation = module().AddEntryComputation(builder.Build());
198  HloInstruction* root = computation->root_instruction();
199  EXPECT_EQ(root->opcode(), HloOpcode::kSubtract);
200  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
201                                 non_bitcasting_callback());
202  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
203  root = computation->root_instruction();
204  EXPECT_THAT(root, op::Add(param0, op::Negate(constant)));
205}
206
207// Test that (A/B)/C is simplified to A/(B*C).
208TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) {
209  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
210  HloComputation::Builder builder(TestName());
211  HloInstruction* param0 = builder.AddInstruction(
212      HloInstruction::CreateParameter(0, r0f32, "param0"));
213  HloInstruction* param1 = builder.AddInstruction(
214      HloInstruction::CreateParameter(1, r0f32, "param1"));
215  HloInstruction* param2 = builder.AddInstruction(
216      HloInstruction::CreateParameter(2, r0f32, "param2"));
217  HloInstruction* div = builder.AddInstruction(
218      HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, param1));
219  builder.AddInstruction(
220      HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, div, param2));
221
222  auto computation = module().AddEntryComputation(builder.Build());
223
224  EXPECT_THAT(computation->root_instruction(),
225              op::Divide(op::Divide(param0, param1), param2));
226
227  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
228                                 non_bitcasting_callback());
229  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
230
231  EXPECT_THAT(computation->root_instruction(),
232              op::Divide(param0, op::Multiply(param1, param2)));
233}
234
235// Test that A/(B/C) is simplified to (A*C)/B.
236TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) {
237  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
238  HloComputation::Builder builder(TestName());
239  HloInstruction* param0 = builder.AddInstruction(
240      HloInstruction::CreateParameter(0, r0f32, "param0"));
241  HloInstruction* param1 = builder.AddInstruction(
242      HloInstruction::CreateParameter(1, r0f32, "param1"));
243  HloInstruction* param2 = builder.AddInstruction(
244      HloInstruction::CreateParameter(2, r0f32, "param2"));
245  HloInstruction* div = builder.AddInstruction(
246      HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param1, param2));
247  builder.AddInstruction(
248      HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, div));
249
250  auto computation = module().AddEntryComputation(builder.Build());
251
252  EXPECT_THAT(computation->root_instruction(),
253              op::Divide(param0, op::Divide(param1, param2)));
254
255  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
256                                 non_bitcasting_callback());
257  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
258
259  EXPECT_THAT(computation->root_instruction(),
260              op::Divide(op::Multiply(param0, param2), param1));
261}
262
263// Test that (A/B)/(C/D) is simplified to (A*D)/(B*C).
264TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) {
265  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
266  Shape r2f32 = ShapeUtil::MakeShape(F32, {42, 123});
267  HloComputation::Builder builder(TestName());
268  HloInstruction* param0 = builder.AddInstruction(
269      HloInstruction::CreateParameter(0, r0f32, "param0"));
270  HloInstruction* param1 = builder.AddInstruction(
271      HloInstruction::CreateParameter(1, r2f32, "param1"));
272  HloInstruction* param2 = builder.AddInstruction(
273      HloInstruction::CreateParameter(2, r2f32, "param2"));
274  HloInstruction* param3 = builder.AddInstruction(
275      HloInstruction::CreateParameter(3, r0f32, "param3"));
276  HloInstruction* div0 = builder.AddInstruction(
277      HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, param1));
278  HloInstruction* div1 = builder.AddInstruction(
279      HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param2, param3));
280  builder.AddInstruction(
281      HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, div0, div1));
282
283  auto computation = module().AddEntryComputation(builder.Build());
284
285  EXPECT_THAT(
286      computation->root_instruction(),
287      op::Divide(op::Divide(param0, param1), op::Divide(param2, param3)));
288
289  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
290                                 non_bitcasting_callback());
291  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
292
293  EXPECT_THAT(
294      computation->root_instruction(),
295      op::Divide(op::Multiply(param0, param3), op::Multiply(param1, param2)));
296  EXPECT_TRUE(
297      ShapeUtil::Compatible(computation->root_instruction()->shape(), r2f32));
298}
299
300// Test that A/exp(B) is simplified to A*exp(-B).
301TEST_F(AlgebraicSimplifierTest, DivOfExp) {
302  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
303  HloComputation::Builder builder(TestName());
304  HloInstruction* param0 = builder.AddInstruction(
305      HloInstruction::CreateParameter(0, r0f32, "param0"));
306  HloInstruction* param1 = builder.AddInstruction(
307      HloInstruction::CreateParameter(1, r0f32, "param1"));
308  HloInstruction* exp = builder.AddInstruction(
309      HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1));
310  builder.AddInstruction(
311      HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, exp));
312
313  auto computation = module().AddEntryComputation(builder.Build());
314
315  EXPECT_THAT(computation->root_instruction(),
316              op::Divide(param0, op::Exp(param1)));
317
318  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
319                                 non_bitcasting_callback());
320  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
321
322  EXPECT_THAT(computation->root_instruction(),
323              op::Multiply(param0, op::Exp(op::Negate(param1))));
324}
325
326// Test that A/pow(B,C) is simplified to A*pow(B,-C).
327TEST_F(AlgebraicSimplifierTest, DivOfPower) {
328  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
329  HloComputation::Builder builder(TestName());
330  HloInstruction* param0 = builder.AddInstruction(
331      HloInstruction::CreateParameter(0, r0f32, "param0"));
332  HloInstruction* param1 = builder.AddInstruction(
333      HloInstruction::CreateParameter(1, r0f32, "param1"));
334  HloInstruction* param2 = builder.AddInstruction(
335      HloInstruction::CreateParameter(2, r0f32, "param2"));
336  HloInstruction* power = builder.AddInstruction(
337      HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param1, param2));
338  builder.AddInstruction(
339      HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, power));
340
341  auto computation = module().AddEntryComputation(builder.Build());
342
343  EXPECT_THAT(computation->root_instruction(),
344              op::Divide(param0, op::Power(param1, param2)));
345
346  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
347                                 non_bitcasting_callback());
348  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
349
350  EXPECT_THAT(computation->root_instruction(),
351              op::Multiply(param0, op::Power(param1, op::Negate(param2))));
352}
353
354// Test that broadcasting is done on the right step when simplifying A/pow(B,C)
355// to A*pow(B,-C).
356TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) {
357  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
358  Shape r1f32 = ShapeUtil::MakeShape(F32, {7});
359  HloComputation::Builder builder(TestName());
360  HloInstruction* param0 = builder.AddInstruction(
361      HloInstruction::CreateParameter(0, r1f32, "param0"));
362  HloInstruction* param1 = builder.AddInstruction(
363      HloInstruction::CreateParameter(1, r1f32, "param1"));
364  HloInstruction* param2 = builder.AddInstruction(
365      HloInstruction::CreateParameter(2, r0f32, "param2"));
366  HloInstruction* power = builder.AddInstruction(
367      HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param1, param2));
368  builder.AddInstruction(
369      HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, power));
370
371  auto computation = module().AddEntryComputation(builder.Build());
372
373  EXPECT_THAT(computation->root_instruction(),
374              op::Divide(param0, op::Power(param1, param2)));
375
376  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
377                                 non_bitcasting_callback());
378  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
379
380  ASSERT_THAT(computation->root_instruction(),
381              op::Multiply(param0, op::Power(param1, op::Negate(param2))));
382
383  const HloInstruction* negate =
384      computation->root_instruction()->operand(1)->operand(1);
385  const Shape& negate_shape = negate->shape();
386  EXPECT_EQ(0, negate_shape.dimensions_size());
387}
388
389// A / Const => A * (1 / Const)
390TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
391  Shape r1f32 = ShapeUtil::MakeShape(F32, {3});
392  HloComputation::Builder builder(TestName());
393  HloInstruction* param0 = builder.AddInstruction(
394      HloInstruction::CreateParameter(0, r1f32, "param0"));
395  HloInstruction* constant =
396      builder.AddInstruction(HloInstruction::CreateConstant(
397          Literal::CreateR1<float>({0.f, 1.f, 2.f})));
398  builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide,
399                                                      param0, constant));
400
401  auto computation = module().AddEntryComputation(builder.Build());
402
403  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
404                                 non_bitcasting_callback());
405  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
406
407  EXPECT_THAT(computation->root_instruction(),
408              op::Multiply(param0, op::Divide(op::Constant(), constant)));
409}
410
411// pow(pow(A, X), Y) => pow(A, X*Y)
412TEST_F(AlgebraicSimplifierTest, PowerOfPower) {
413  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
414  Shape r1f32 = ShapeUtil::MakeShape(F32, {7});
415  HloComputation::Builder builder(TestName());
416  HloInstruction* base = builder.AddInstruction(
417      HloInstruction::CreateParameter(0, r1f32, "param0"));
418  HloInstruction* exp1 = builder.AddInstruction(
419      HloInstruction::CreateParameter(1, r0f32, "param1"));
420  HloInstruction* exp2 = builder.AddInstruction(
421      HloInstruction::CreateParameter(2, r0f32, "param2"));
422  HloInstruction* inner_power = builder.AddInstruction(
423      HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, base, exp1));
424  builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower,
425                                                      inner_power, exp2));
426
427  auto computation = module().AddEntryComputation(builder.Build());
428  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
429                                 non_bitcasting_callback());
430  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
431  EXPECT_THAT(computation->root_instruction(),
432              op::Power(base, op::Multiply(exp1, exp2)));
433}
434
435// Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex
436// numbers.
437TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) {
438  Shape r0c64 = ShapeUtil::MakeShape(C64, {});
439  Shape r1c64 = ShapeUtil::MakeShape(C64, {7});
440  HloComputation::Builder builder(TestName());
441  HloInstruction* base = builder.AddInstruction(
442      HloInstruction::CreateParameter(0, r1c64, "param0"));
443  HloInstruction* exp1 = builder.AddInstruction(
444      HloInstruction::CreateParameter(1, r0c64, "param1"));
445  HloInstruction* exp2 = builder.AddInstruction(
446      HloInstruction::CreateParameter(2, r0c64, "param2"));
447  HloInstruction* inner_power = builder.AddInstruction(
448      HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, base, exp1));
449  builder.AddInstruction(HloInstruction::CreateBinary(r1c64, HloOpcode::kPower,
450                                                      inner_power, exp2));
451
452  module().AddEntryComputation(builder.Build());
453  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
454                                 non_bitcasting_callback());
455  ASSERT_FALSE(simplifier.Run(&module()).ValueOrDie());
456}
457
458// Test that A/1 is simplified to A for a scalar.
459TEST_F(AlgebraicSimplifierTest, DivOneScalar) {
460  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
461  HloComputation::Builder builder(TestName());
462  HloInstruction* param0 = builder.AddInstruction(
463      HloInstruction::CreateParameter(0, r0f32, "param0"));
464  HloInstruction* one = builder.AddInstruction(
465      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
466  HloInstruction* div = builder.AddInstruction(
467      HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one));
468
469  auto computation = module().AddEntryComputation(builder.Build());
470  HloInstruction* root = computation->root_instruction();
471  EXPECT_EQ(root, div);
472  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
473                                 non_bitcasting_callback());
474  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
475  root = computation->root_instruction();
476  EXPECT_EQ(root, param0);
477}
478
479// Test that A/1 is simplified to A for an array.
480TEST_F(AlgebraicSimplifierTest, DivOneArray) {
481  Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
482  HloComputation::Builder builder(TestName());
483  HloInstruction* param0 = builder.AddInstruction(
484      HloInstruction::CreateParameter(0, r2f32, "param0"));
485  HloInstruction* one = builder.AddInstruction(HloInstruction::CreateConstant(
486      Literal::CreateR2<float>({{1.0, 1.0}, {1.0, 1.0}})));
487  HloInstruction* div = builder.AddInstruction(
488      HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one));
489
490  auto computation = module().AddEntryComputation(builder.Build());
491  HloInstruction* root = computation->root_instruction();
492  EXPECT_EQ(root, div);
493  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
494                                 non_bitcasting_callback());
495  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
496  root = computation->root_instruction();
497  EXPECT_EQ(root, param0);
498}
499
500// Test that complex(real(c), imag(c)) is simplified to c.
501TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) {
502  Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
503  Shape r2c64 = ShapeUtil::MakeShape(C64, {2, 2});
504  HloComputation::Builder builder(TestName());
505  HloInstruction* param0 = builder.AddInstruction(
506      HloInstruction::CreateParameter(0, r2c64, "param0"));
507  HloInstruction* real = builder.AddInstruction(
508      HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, param0));
509  HloInstruction* imag = builder.AddInstruction(
510      HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, param0));
511  HloInstruction* cplx = builder.AddInstruction(
512      HloInstruction::CreateBinary(r2c64, HloOpcode::kComplex, real, imag));
513
514  auto computation = module().AddEntryComputation(builder.Build());
515  HloInstruction* root = computation->root_instruction();
516  EXPECT_EQ(root, cplx);
517  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
518                                 non_bitcasting_callback());
519  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
520  root = computation->root_instruction();
521  EXPECT_EQ(root, param0);
522}
523
524// Test that real(complex(r,i)) is simplified to r.
525TEST_F(AlgebraicSimplifierTest, RealOfComplex) {
526  Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
527  HloComputation::Builder builder(TestName());
528  HloInstruction* param0 = builder.AddInstruction(
529      HloInstruction::CreateParameter(0, r2f32, "param0"));
530  HloInstruction* param1 = builder.AddInstruction(
531      HloInstruction::CreateParameter(1, r2f32, "param1"));
532  HloInstruction* cplx = builder.AddInstruction(
533      HloInstruction::CreateBinary(ShapeUtil::ChangeElementType(r2f32, C64),
534                                   HloOpcode::kComplex, param0, param1));
535  HloInstruction* real = builder.AddInstruction(
536      HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, cplx));
537
538  auto computation = module().AddEntryComputation(builder.Build());
539  HloInstruction* root = computation->root_instruction();
540  EXPECT_EQ(root, real);
541  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
542                                 non_bitcasting_callback());
543  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
544  root = computation->root_instruction();
545  EXPECT_EQ(root, param0);
546}
547
548// Test that imag(complex(r,i)) is simplified to i.
549TEST_F(AlgebraicSimplifierTest, ImagOfComplex) {
550  Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
551  HloComputation::Builder builder(TestName());
552  HloInstruction* param0 = builder.AddInstruction(
553      HloInstruction::CreateParameter(0, r2f32, "param0"));
554  HloInstruction* param1 = builder.AddInstruction(
555      HloInstruction::CreateParameter(1, r2f32, "param1"));
556  HloInstruction* cplx = builder.AddInstruction(
557      HloInstruction::CreateBinary(ShapeUtil::ChangeElementType(r2f32, C64),
558                                   HloOpcode::kComplex, param0, param1));
559  HloInstruction* imag = builder.AddInstruction(
560      HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, cplx));
561
562  auto computation = module().AddEntryComputation(builder.Build());
563  HloInstruction* root = computation->root_instruction();
564  EXPECT_EQ(root, imag);
565  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
566                                 non_bitcasting_callback());
567  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
568  root = computation->root_instruction();
569  EXPECT_EQ(root, param1);
570}
571
572// Test that get_element(make_tuple({A,B}),1) is simplified to B
573TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) {
574  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
575  HloComputation::Builder builder(TestName());
576  HloInstruction* param0 = builder.AddInstruction(
577      HloInstruction::CreateParameter(0, r0f32, "param0"));
578  HloInstruction* param1 = builder.AddInstruction(
579      HloInstruction::CreateParameter(1, r0f32, "param1"));
580  HloInstruction* param2 = builder.AddInstruction(
581      HloInstruction::CreateParameter(2, r0f32, "param2"));
582  HloInstruction* tuple =
583      builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));
584  HloInstruction* get = builder.AddInstruction(
585      HloInstruction::CreateGetTupleElement(r0f32, tuple, 1));
586  HloInstruction* add = builder.AddInstruction(
587      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, get, param2));
588
589  auto computation = module().AddEntryComputation(builder.Build());
590  HloInstruction* root = computation->root_instruction();
591  EXPECT_EQ(root, add);
592  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
593                                 non_bitcasting_callback());
594  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
595  root = computation->root_instruction();
596  EXPECT_THAT(root, op::Add(param1, param2));
597}
598
599// Test that exp(A)/exp(B) is simplified to exp(A-B)
600TEST_F(AlgebraicSimplifierTest, ExpDiv) {
601  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
602  HloComputation::Builder builder(TestName());
603  HloInstruction* param0 = builder.AddInstruction(
604      HloInstruction::CreateParameter(0, r0f32, "param0"));
605  HloInstruction* param1 = builder.AddInstruction(
606      HloInstruction::CreateParameter(1, r0f32, "param1"));
607  HloInstruction* exp0 = builder.AddInstruction(
608      HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
609  HloInstruction* exp1 = builder.AddInstruction(
610      HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1));
611  builder.AddInstruction(
612      HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1));
613
614  auto computation = module().AddEntryComputation(builder.Build());
615
616  EXPECT_THAT(computation->root_instruction(),
617              op::Divide(op::Exp(param0), op::Exp(param1)));
618
619  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
620                                 non_bitcasting_callback());
621  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
622
623  EXPECT_THAT(computation->root_instruction(),
624              op::Exp(op::Subtract(param0, param1)));
625}
626
627// Test that exp(A)*exp(B) is simplified to exp(A+B)
628TEST_F(AlgebraicSimplifierTest, ExpMul) {
629  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
630  HloComputation::Builder builder(TestName());
631  HloInstruction* param0 = builder.AddInstruction(
632      HloInstruction::CreateParameter(0, r0f32, "param0"));
633  HloInstruction* param1 = builder.AddInstruction(
634      HloInstruction::CreateParameter(1, r0f32, "param1"));
635  HloInstruction* exp0 = builder.AddInstruction(
636      HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
637  HloInstruction* exp1 = builder.AddInstruction(
638      HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1));
639  builder.AddInstruction(
640      HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, exp0, exp1));
641
642  auto computation = module().AddEntryComputation(builder.Build());
643
644  EXPECT_THAT(computation->root_instruction(),
645              op::Multiply(op::Exp(param0), op::Exp(param1)));
646
647  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
648                                 non_bitcasting_callback());
649  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
650
651  EXPECT_THAT(computation->root_instruction(),
652              op::Exp(op::Add(param0, param1)));
653}
654
655// Test that pow(exp(A), B) is simplified to exp(A*B)
656TEST_F(AlgebraicSimplifierTest, PowExp) {
657  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
658  HloComputation::Builder builder(TestName());
659  HloInstruction* param0 = builder.AddInstruction(
660      HloInstruction::CreateParameter(0, r0f32, "param0"));
661  HloInstruction* param1 = builder.AddInstruction(
662      HloInstruction::CreateParameter(1, r0f32, "param1"));
663  HloInstruction* exp0 = builder.AddInstruction(
664      HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
665  builder.AddInstruction(
666      HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, exp0, param1));
667
668  auto computation = module().AddEntryComputation(builder.Build());
669
670  EXPECT_THAT(computation->root_instruction(),
671              op::Power(op::Exp(param0), param1));
672
673  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
674                                 non_bitcasting_callback());
675  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
676
677  EXPECT_THAT(computation->root_instruction(),
678              op::Exp(op::Multiply(param0, param1)));
679}
680
681// Test that ln(pow(A, B)) is simplified to ln(A)*B
682TEST_F(AlgebraicSimplifierTest, LnPow) {
683  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
684  HloComputation::Builder builder(TestName());
685  HloInstruction* param0 = builder.AddInstruction(
686      HloInstruction::CreateParameter(0, r0f32, "param0"));
687  HloInstruction* param1 = builder.AddInstruction(
688      HloInstruction::CreateParameter(1, r0f32, "param1"));
689  HloInstruction* pow = builder.AddInstruction(
690      HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, param1));
691  builder.AddInstruction(
692      HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, pow));
693
694  auto computation = module().AddEntryComputation(builder.Build());
695
696  EXPECT_THAT(computation->root_instruction(),
697              op::Log(op::Power(param0, param1)));
698
699  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
700                                 non_bitcasting_callback());
701  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
702
703  EXPECT_THAT(computation->root_instruction(),
704              op::Multiply(op::Log(param0), param1));
705}
706
707// Test that ln(exp(A)) is simplified to A
708TEST_F(AlgebraicSimplifierTest, LnExp) {
709  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
710  HloComputation::Builder builder(TestName());
711  HloInstruction* param0 = builder.AddInstruction(
712      HloInstruction::CreateParameter(0, r0f32, "param0"));
713  HloInstruction* exp0 = builder.AddInstruction(
714      HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
715  builder.AddInstruction(
716      HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, exp0));
717
718  auto computation = module().AddEntryComputation(builder.Build());
719
720  EXPECT_THAT(computation->root_instruction(), op::Log(op::Exp(param0)));
721
722  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
723                                 non_bitcasting_callback());
724  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
725
726  EXPECT_EQ(computation->root_instruction(), param0);
727}
728
729// Test that ln(exp(A)/exp(B)) is simplified to A-B
730TEST_F(AlgebraicSimplifierTest, LnExpDiv) {
731  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
732  HloComputation::Builder builder(TestName());
733  HloInstruction* param0 = builder.AddInstruction(
734      HloInstruction::CreateParameter(0, r0f32, "param0"));
735  HloInstruction* param1 = builder.AddInstruction(
736      HloInstruction::CreateParameter(1, r0f32, "param1"));
737  HloInstruction* exp0 = builder.AddInstruction(
738      HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
739  HloInstruction* exp1 = builder.AddInstruction(
740      HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1));
741  HloInstruction* div = builder.AddInstruction(
742      HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1));
743  builder.AddInstruction(
744      HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, div));
745
746  auto computation = module().AddEntryComputation(builder.Build());
747
748  EXPECT_THAT(computation->root_instruction(),
749              op::Log(op::Divide(op::Exp(param0), op::Exp(param1))));
750
751  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
752                                 non_bitcasting_callback());
753  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
754
755  EXPECT_THAT(computation->root_instruction(), op::Subtract(param0, param1));
756}
757
758// Test that pow(A, 0) where A is a scalar is simplified to the scalar
759// constant 1.
760TEST_F(AlgebraicSimplifierTest, Pow0Scalar) {
761  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
762  HloComputation::Builder builder(TestName());
763  HloInstruction* param0 = builder.AddInstruction(
764      HloInstruction::CreateParameter(0, r0f32, "param0"));
765  HloInstruction* zero = builder.AddInstruction(
766      HloInstruction::CreateConstant(Literal::CreateR0<float>(0)));
767  builder.AddInstruction(
768      HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero));
769
770  auto computation = module().AddEntryComputation(builder.Build());
771
772  EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero));
773
774  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
775                                 non_bitcasting_callback());
776  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
777
778  HloInstruction* root = computation->root_instruction();
779  EXPECT_THAT(root, op::Constant());
780  EXPECT_EQ(root->literal().GetFirstElement<float>(), 1);
781}
782
783// Test that pow(A, 0) where A is not a scalar is simplified to broadcast(1).
784TEST_F(AlgebraicSimplifierTest, Pow0Vector) {
785  Shape r1f32 = ShapeUtil::MakeShape(F32, {42});
786  HloComputation::Builder builder(TestName());
787  HloInstruction* param0 = builder.AddInstruction(
788      HloInstruction::CreateParameter(0, r1f32, "param0"));
789  HloInstruction* zero = builder.AddInstruction(
790      HloInstruction::CreateConstant(Literal::CreateR0<float>(0)));
791  builder.AddInstruction(
792      HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero));
793
794  auto computation = module().AddEntryComputation(builder.Build());
795
796  EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero));
797
798  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
799                                 non_bitcasting_callback());
800  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
801
802  HloInstruction* root = computation->root_instruction();
803  EXPECT_THAT(root, op::Broadcast());
804  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), r1f32))
805      << ShapeUtil::HumanString(root->shape());
806  EXPECT_EQ(root->dimensions().size(), 0);
807  EXPECT_TRUE(ShapeUtil::IsScalar(root->operand(0)->shape()));
808  EXPECT_EQ(root->operand(0)->literal().GetFirstElement<float>(), 1);
809}
810
811// Test that pow(A, 1) is simplified to A.
812TEST_F(AlgebraicSimplifierTest, Pow1) {
813  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
814  HloComputation::Builder builder(TestName());
815  HloInstruction* param0 = builder.AddInstruction(
816      HloInstruction::CreateParameter(0, r0f32, "param0"));
817  HloInstruction* one = builder.AddInstruction(
818      HloInstruction::CreateConstant(Literal::CreateR0<float>(1)));
819  builder.AddInstruction(
820      HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one));
821
822  auto computation = module().AddEntryComputation(builder.Build());
823
824  EXPECT_THAT(computation->root_instruction(), op::Power(param0, one));
825
826  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
827                                 non_bitcasting_callback());
828  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
829
830  EXPECT_EQ(computation->root_instruction(), param0);
831}
832
833// Test that pow(A, 2) is simplified to A*A.
834TEST_F(AlgebraicSimplifierTest, Pow2) {
835  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
836  HloComputation::Builder builder(TestName());
837  HloInstruction* param0 = builder.AddInstruction(
838      HloInstruction::CreateParameter(0, r0f32, "param0"));
839  HloInstruction* two = builder.AddInstruction(
840      HloInstruction::CreateConstant(Literal::CreateR0<float>(2)));
841  builder.AddInstruction(
842      HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two));
843
844  auto computation = module().AddEntryComputation(builder.Build());
845
846  EXPECT_THAT(computation->root_instruction(), op::Power(param0, two));
847
848  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
849                                 non_bitcasting_callback());
850  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
851
852  EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, param0));
853}
854
855// Test that pow(A, -1) is simplified to 1/A.
856TEST_F(AlgebraicSimplifierTest, PowNegative1) {
857  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
858  HloComputation::Builder builder(TestName());
859  HloInstruction* param0 = builder.AddInstruction(
860      HloInstruction::CreateParameter(0, r0f32, "param0"));
861  HloInstruction* negative_one = builder.AddInstruction(
862      HloInstruction::CreateConstant(Literal::CreateR0<float>(-1)));
863  builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower,
864                                                      param0, negative_one));
865
866  auto computation = module().AddEntryComputation(builder.Build());
867
868  EXPECT_THAT(computation->root_instruction(), op::Power(param0, negative_one));
869
870  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
871                                 non_bitcasting_callback());
872  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
873
874  HloInstruction* root = computation->root_instruction();
875  EXPECT_THAT(root, op::Divide(op::Broadcast(), param0));
876  EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kBroadcast);
877  EXPECT_EQ(root->operand(0)->operand(0)->literal().GetFirstElement<float>(),
878            1);
879}
880
881TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) {
882  auto builder = HloComputation::Builder(TestName());
883  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
884      0, ShapeUtil::MakeShape(F32, {3, 3, 0}), "lhs"));
885
886  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateParameter(
887      1, ShapeUtil::MakeShape(F32, {3, 0, 3}), "rhs"));
888
889  ConvolutionDimensionNumbers dnums;
890  dnums.set_input_batch_dimension(0);
891  dnums.add_input_spatial_dimensions(1);
892  dnums.set_input_feature_dimension(2);
893
894  dnums.set_output_batch_dimension(0);
895  dnums.add_output_spatial_dimensions(1);
896  dnums.set_output_feature_dimension(2);
897
898  dnums.add_kernel_spatial_dimensions(0);
899  dnums.set_kernel_input_feature_dimension(1);
900  dnums.set_kernel_output_feature_dimension(2);
901  Window window;
902  WindowDimension* dim = window.add_dimensions();
903  dim->set_size(3);
904  dim->set_padding_low(0);
905  dim->set_padding_high(0);
906  dim->set_stride(1);
907  dim->set_window_dilation(1);
908  dim->set_base_dilation(1);
909  dim->set_window_reversal(false);
910  // Create add computation.
911  builder.AddInstruction(HloInstruction::CreateConvolve(
912      ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, window, dnums));
913  module().AddEntryComputation(builder.Build());
914  HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
915                                             non_bitcasting_callback());
916  EXPECT_THAT(module().entry_computation()->root_instruction(),
917              op::Convolution(lhs, rhs));
918  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
919  EXPECT_THAT(module().entry_computation()->root_instruction(),
920              op::Broadcast(op::Constant()));
921}
922
923TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) {
924  auto builder = HloComputation::Builder(TestName());
925  HloInstruction* param =
926      builder.AddInstruction(HloInstruction::CreateParameter(
927          0, ShapeUtil::MakeShape(F32, {3, 0}), "op"));
928  Window window;
929  for (int64 i = 0; i < 2; ++i) {
930    WindowDimension* dim = window.add_dimensions();
931    dim->set_size(1);
932    dim->set_padding_low(1);
933    dim->set_padding_high(1);
934    dim->set_window_dilation(1);
935    dim->set_base_dilation(1);
936  }
937  // Create add computation.
938  HloComputation* add_computation = nullptr;
939  {
940    HloComputation::Builder builder(TestName() + ".add");
941    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
942    HloInstruction* p0 = builder.AddInstruction(
943        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
944    HloInstruction* p1 = builder.AddInstruction(
945        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
946    builder.AddInstruction(
947        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
948    add_computation = module().AddEmbeddedComputation(builder.Build());
949  }
950  builder.AddInstruction(HloInstruction::CreateReduceWindow(
951      ShapeUtil::MakeShape(F32, {5, 2}), param,
952      builder.AddInstruction(
953          HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))),
954      window, add_computation));
955  module().AddEntryComputation(builder.Build());
956  HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
957                                             non_bitcasting_callback());
958  EXPECT_THAT(module().entry_computation()->root_instruction(),
959              op::ReduceWindow(param, op::Constant()));
960  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
961  EXPECT_THAT(module().entry_computation()->root_instruction(),
962              op::Broadcast(op::Constant()));
963}
964
965TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) {
966  auto builder = HloComputation::Builder(TestName());
967  HloInstruction* param =
968      builder.AddInstruction(HloInstruction::CreateParameter(
969          0, ShapeUtil::MakeShape(F32, {3, 0}), "op"));
970  PaddingConfig padding;
971  for (int i = 0; i < 2; ++i) {
972    PaddingConfig::PaddingConfigDimension* dimension = padding.add_dimensions();
973    dimension->set_edge_padding_low(1);
974    dimension->set_edge_padding_high(1);
975    dimension->set_interior_padding(0);
976  }
977  builder.AddInstruction(HloInstruction::CreatePad(
978      ShapeUtil::MakeShape(F32, {5, 2}), param,
979      builder.AddInstruction(
980          HloInstruction::CreateConstant(Literal::CreateR0(0.0f))),
981      padding));
982  module().AddEntryComputation(builder.Build());
983  EXPECT_THAT(module().entry_computation()->root_instruction(),
984              op::Pad(param, op::Constant()));
985  HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
986                                             non_bitcasting_callback());
987  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
988  EXPECT_THAT(module().entry_computation()->root_instruction(),
989              op::Broadcast(op::Constant()));
990}
991
992TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
993  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
994
995  auto builder = HloComputation::Builder(TestName());
996  auto op = builder.AddInstruction(HloInstruction::CreateParameter(
997      0, ShapeUtil::MakeShape(F32, {3, 2}), "op"));
998  auto reshape1 = builder.AddInstruction(
999      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), op));
1000  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
1001      ShapeUtil::MakeShape(F32, {1, 6}), reshape1, {1}));
1002  builder.AddInstruction(HloInstruction::CreateReshape(
1003      ShapeUtil::MakeShape(F32, {3, 2}), broadcast));
1004
1005  auto computation = builder.Build();
1006  module().AddEntryComputation(std::move(computation));
1007
1008  EXPECT_THAT(module().entry_computation()->root_instruction(),
1009              op::Reshape(op::Broadcast(op::Reshape(op))));
1010
1011  HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
1012                                             non_bitcasting_callback());
1013  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
1014
1015  EXPECT_THAT(module().entry_computation()->root_instruction(), op);
1016}
1017
1018// Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE.
1019TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) {
1020  HloComputation::Builder builder(TestName());
1021  HloInstruction* input = builder.AddInstruction(
1022      HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
1023  builder.AddInstruction(
1024      HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
1025
1026  auto computation = module().AddEntryComputation(builder.Build());
1027
1028  EXPECT_THAT(computation->root_instruction(), op::Convert(input));
1029
1030  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
1031                                 non_bitcasting_callback());
1032  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
1033
1034  EXPECT_THAT(computation->root_instruction(), input);
1035}
1036
1037// Test that copies are removed.
1038TEST_F(AlgebraicSimplifierTest, RemoveCopy) {
1039  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1040  HloComputation::Builder builder(TestName());
1041  HloInstruction* param0 = builder.AddInstruction(
1042      HloInstruction::CreateParameter(0, r0f32, "param0"));
1043  builder.AddInstruction(
1044      HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
1045
1046  auto computation = module().AddEntryComputation(builder.Build());
1047
1048  EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
1049
1050  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
1051                                 non_bitcasting_callback());
1052  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
1053
1054  EXPECT_THAT(computation->root_instruction(), param0);
1055}
1056
1057// Test that unary concatenates are removed.
1058TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) {
1059  Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
1060  HloComputation::Builder builder(TestName());
1061  HloInstruction* param0 = builder.AddInstruction(
1062      HloInstruction::CreateParameter(0, r1f32, "param0"));
1063  builder.AddInstruction(
1064      HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0));
1065
1066  auto computation = module().AddEntryComputation(builder.Build());
1067
1068  EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0));
1069
1070  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
1071                                 non_bitcasting_callback());
1072  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
1073
1074  EXPECT_THAT(computation->root_instruction(), param0);
1075}
1076
1077// Test that empty operands of concatenates are removed.
1078TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
1079  const int kParamLength = 100;
1080  Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength});
1081  HloComputation::Builder builder(TestName());
1082  HloInstruction* param0 = builder.AddInstruction(
1083      HloInstruction::CreateParameter(0, r1f32, "param0"));
1084  HloInstruction* param1 = builder.AddInstruction(
1085      HloInstruction::CreateParameter(1, r1f32, "param1"));
1086  HloInstruction* empty_literal = builder.AddInstruction(
1087      HloInstruction::CreateConstant(Literal::CreateR1<float>({})));
1088  HloInstruction* empty_slice =
1089      builder.AddInstruction(HloInstruction::CreateSlice(
1090          ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1}));
1091  Shape result_shape = ShapeUtil::MakeShape(F32, {3 * kParamLength});
1092  builder.AddInstruction(HloInstruction::CreateConcatenate(
1093      result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0));
1094
1095  auto computation = module().AddEntryComputation(builder.Build());
1096
1097  EXPECT_THAT(
1098      computation->root_instruction(),
1099      op::Concatenate(empty_literal, param0, param0, empty_slice, param1));
1100
1101  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
1102                                 non_bitcasting_callback());
1103  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
1104
1105  EXPECT_THAT(computation->root_instruction(),
1106              op::Concatenate(param0, param0, param1));
1107}
1108
1109// Test a concatenate with only empty operands is removed.
1110TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
1111  const int kParamLength = 100;
1112  Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength});
1113  HloComputation::Builder builder(TestName());
1114  HloInstruction* param0 = builder.AddInstruction(
1115      HloInstruction::CreateParameter(0, r1f32, "param0"));
1116  HloInstruction* empty_literal = builder.AddInstruction(
1117      HloInstruction::CreateConstant(Literal::CreateR1<float>({})));
1118  HloInstruction* empty_slice =
1119      builder.AddInstruction(HloInstruction::CreateSlice(
1120          ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1}));
1121  Shape result_shape = ShapeUtil::MakeShape(F32, {0});
1122  builder.AddInstruction(HloInstruction::CreateConcatenate(
1123      result_shape, {empty_literal, empty_slice}, 0));
1124
1125  auto computation = module().AddEntryComputation(builder.Build());
1126
1127  EXPECT_THAT(computation->root_instruction(),
1128              op::Concatenate(empty_literal, empty_slice));
1129
1130  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
1131                                 non_bitcasting_callback());
1132  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
1133
1134  EXPECT_EQ(computation->root_instruction(), empty_literal);
1135}
1136
1137// Test that concat with a scalar broadcast becomes a pad.
1138TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) {
1139  Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
1140  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
1141  HloComputation::Builder builder(TestName());
1142  HloInstruction* param0 = builder.AddInstruction(
1143      HloInstruction::CreateParameter(0, r1f32, "param0"));
1144  HloInstruction* param1 = builder.AddInstruction(
1145      HloInstruction::CreateParameter(1, r0f32, "param1"));
1146  HloInstruction* broadcast = builder.AddInstruction(
1147      HloInstruction::CreateBroadcast(r1f32, param1, {}));
1148  builder.AddInstruction(HloInstruction::CreateConcatenate(
1149      ShapeUtil::MakeShape(F32, {200}), {broadcast, param0}, 0));
1150
1151  auto computation = module().AddEntryComputation(builder.Build());
1152
1153  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
1154                                 non_bitcasting_callback());
1155  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
1156  EXPECT_THAT(computation->root_instruction(), op::Pad(param0, param1));
1157}
1158
1159// Test that a simplification which changes layouts is not performed if layout
1160// sensitive is true.
1161TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) {
1162  HloComputation::Builder builder(TestName());
1163  HloInstruction* param0 =
1164      builder.AddInstruction(HloInstruction::CreateParameter(
1165          0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
1166  HloInstruction* copy = builder.AddInstruction(
1167      HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
1168
1169  auto computation = module().AddEntryComputation(builder.Build());
1170
1171  // Set to different layouts.
1172  *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
1173  *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});
1174
1175  EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
1176
1177  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
1178                                 non_bitcasting_callback());
1179  EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
1180
1181  // Copy has not been removed.
1182  EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
1183}
1184
1185// Test that a simplification which preserves layouts is performed if layout
1186// sensitive is true.
1187TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) {
1188  HloComputation::Builder builder(TestName());
1189  HloInstruction* param0 =
1190      builder.AddInstruction(HloInstruction::CreateParameter(
1191          0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
1192  HloInstruction* copy = builder.AddInstruction(
1193      HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
1194
1195  auto computation = module().AddEntryComputation(builder.Build());
1196
1197  // Set to same layouts.
1198  *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
1199  *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
1200
1201  EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
1202
1203  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
1204                                 non_bitcasting_callback());
1205  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
1206
1207  // Copy has been removed.
1208  EXPECT_THAT(computation->root_instruction(), param0);
1209}
1210
// Test that a reshape which could be replaced with a bitcast is not replaced
// when the valid-bitcast callback rejects every bitcast
// (non_bitcasting_callback).
1213TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) {
1214  HloComputation::Builder builder(TestName());
1215  HloInstruction* param0 =
1216      builder.AddInstruction(HloInstruction::CreateParameter(
1217          0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
1218  HloInstruction* reshape =
1219      builder.AddInstruction(HloInstruction::CreateReshape(
1220          ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
1221
1222  *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
1223  *reshape->mutable_shape()->mutable_layout() =
1224      LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
1225
1226  auto computation = module().AddEntryComputation(builder.Build());
1227
1228  EXPECT_THAT(computation->root_instruction(), op::Reshape(param0));
1229
1230  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
1231                                 non_bitcasting_callback());
1232  EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
1233
1234  // Reshape is not replaced with a bitcast.
1235  EXPECT_THAT(computation->root_instruction(), op::Reshape(param0));
1236}
1237
1238// Test transforming reshapes to bitcasts under various conditions.
1239TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) {
1240  HloComputation::Builder builder(TestName());
1241  HloInstruction* param0 =
1242      builder.AddInstruction(HloInstruction::CreateParameter(
1243          0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
1244  *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
1245
1246  // Reshape which can be transformed into a bitcast.
1247  HloInstruction* transformable_reshape =
1248      builder.AddInstruction(HloInstruction::CreateReshape(
1249          ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
1250  *transformable_reshape->mutable_shape()->mutable_layout() =
1251      LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
1252
1253  // Reshape does not just add degenerate dimensions.
1254  HloInstruction* dimensions_wrong_reshape =
1255      builder.AddInstruction(HloInstruction::CreateReshape(
1256          ShapeUtil::MakeShape(F32, {1, 4, 1, 1, 1, 1}), param0));
1257  *dimensions_wrong_reshape->mutable_shape()->mutable_layout() =
1258      LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
1259
1260  // Reshape has wrong layout.
1261  HloInstruction* layout_wrong_reshape =
1262      builder.AddInstruction(HloInstruction::CreateReshape(
1263          ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
1264  *layout_wrong_reshape->mutable_shape()->mutable_layout() =
1265      LayoutUtil::MakeLayout({5, 4, 3, 2, 1, 0});
1266
1267  // Collect all the reshapes into a tuple so they are not dead.
1268  builder.AddInstruction(HloInstruction::CreateTuple(
1269      {transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape}));
1270
1271  auto computation = module().AddEntryComputation(builder.Build());
1272
1273  EXPECT_THAT(computation->root_instruction(),
1274              op::Tuple(transformable_reshape, dimensions_wrong_reshape,
1275                        layout_wrong_reshape));
1276
1277  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
1278                                 bitcasting_callback());
1279  simplifier.Run(&module()).ValueOrDie();
1280
1281  // Verify that only the first reshape is replaced.
1282  EXPECT_THAT(
1283      computation->root_instruction(),
1284      op::Tuple(op::Bitcast(), dimensions_wrong_reshape, layout_wrong_reshape));
1285}
1286
1287TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) {
1288  HloComputation::Builder builder(TestName());
1289  HloInstruction* param =
1290      builder.AddInstruction(HloInstruction::CreateParameter(
1291          0, ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), "param"));
1292  HloInstruction* movable_reshape =
1293      builder.AddInstruction(HloInstruction::CreateReshape(
1294          ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), param));
1295  HloInstruction* zero = builder.AddInstruction(
1296      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
1297  builder.AddInstruction(
1298      HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}),
1299                                   HloOpcode::kMaximum, movable_reshape, zero));
1300  auto computation = module().AddEntryComputation(builder.Build());
1301
1302  EXPECT_THAT(computation->root_instruction(),
1303              op::Maximum(op::Reshape(param), zero));
1304
1305  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
1306                                 bitcasting_callback());
1307
1308  simplifier.Run(&module()).ValueOrDie();
1309  EXPECT_THAT(computation->root_instruction(),
1310              op::Reshape(op::Maximum(param, zero)));
1311}
1312
1313// Regression test for a bug in the reshape sinking transformation, where
1314// moving a reshape to a scalar led to a crash.
1315TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) {
1316  HloComputation::Builder builder(TestName());
1317  HloInstruction* param =
1318      builder.AddInstruction(HloInstruction::CreateParameter(
1319          0, ShapeUtil::MakeShape(F32, {1, 1}), "param"));
1320  HloInstruction* reshape = builder.AddInstruction(
1321      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {}), param));
1322  HloInstruction* zero = builder.AddInstruction(
1323      HloInstruction::CreateConstant(Literal::CreateR1<float>({1., 2., 3.})));
1324  builder.AddInstruction(HloInstruction::CreateBinary(
1325      ShapeUtil::MakeShape(F32, {3}), HloOpcode::kMaximum, reshape, zero));
1326  auto computation = module().AddEntryComputation(builder.Build());
1327
1328  EXPECT_THAT(computation->root_instruction(),
1329              op::Maximum(op::Reshape(param), zero));
1330
1331  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
1332                                 bitcasting_callback());
1333
1334  simplifier.Run(&module()).ValueOrDie();
1335
1336  EXPECT_THAT(computation->root_instruction(),
1337              op::Maximum(op::Reshape(param), zero));
1338}
1339
1340// Regression test for a bug where if we failed to sink a reshape, we'd set the
1341// 'changed' bit in AlgebraicSimplifier to false.
1342TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) {
1343  HloComputation::Builder builder(TestName());
1344
1345  // This add (param0 + 0) can be simplified.
1346  Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
1347  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
1348      shape, HloOpcode::kAdd,
1349      builder.AddInstruction(
1350          HloInstruction::CreateParameter(0, shape, "param0")),
1351      builder.AddInstruction(HloInstruction::CreateConstant(
1352          Literal::CreateR2<float>({{0, 0}, {0, 0}})))));
1353
1354  builder.AddInstruction(
1355      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add));
1356
1357  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
1358                                 bitcasting_callback());
1359  module().AddEntryComputation(builder.Build());
1360  EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie());
1361}
1362
// Regression test for a bug where if we failed to sink a broadcast, we'd set
// the 'changed' bit in AlgebraicSimplifier to false.
1365TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) {
1366  HloComputation::Builder builder(TestName());
1367
1368  // This add (param0 + 0) can be simplified.
1369  Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
1370  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
1371      shape, HloOpcode::kAdd,
1372      builder.AddInstruction(
1373          HloInstruction::CreateParameter(0, shape, "param0")),
1374      builder.AddInstruction(HloInstruction::CreateConstant(
1375          Literal::CreateR2<float>({{0, 0}, {0, 0}})))));
1376
1377  builder.AddInstruction(
1378      HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add,
1379                                      /*broadcast_dimensions=*/{0, 1}));
1380
1381  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
1382                                 bitcasting_callback());
1383  module().AddEntryComputation(builder.Build());
1384  EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie());
1385}
1386
1387TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) {
1388  HloComputation::Builder builder(TestName());
1389  HloInstruction* param =
1390      builder.AddInstruction(HloInstruction::CreateParameter(
1391          0, ShapeUtil::MakeShape(F32, {50, 14, 14, 64}), "param"));
1392  *param->mutable_shape()->mutable_layout() =
1393      LayoutUtil::MakeLayout({1, 2, 0, 3});
1394
1395  HloInstruction* transpose =
1396      builder.AddInstruction(HloInstruction::CreateTranspose(
1397          ShapeUtil::MakeShape(F32, {14, 14, 50, 64}), param, {1, 2, 0, 3}));
1398  *transpose->mutable_shape()->mutable_layout() =
1399      LayoutUtil::MakeLayout({0, 1, 2, 3});
1400
1401  auto computation = module().AddEntryComputation(builder.Build());
1402
1403  EXPECT_THAT(computation->root_instruction(), op::Transpose(param));
1404
1405  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
1406                                 bitcasting_callback());
1407  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
1408
1409  // Verify that the reshape is replaced.
1410  EXPECT_THAT(computation->root_instruction(), op::Bitcast(param));
1411}
1412
1413TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) {
1414  HloComputation::Builder builder(TestName());
1415  HloInstruction* param =
1416      builder.AddInstruction(HloInstruction::CreateParameter(
1417          0, ShapeUtil::MakeShape(F32, {5, 2, 3, 4}), "param"));
1418  *param->mutable_shape()->mutable_layout() =
1419      LayoutUtil::MakeLayout({1, 2, 3, 0});
1420
1421  HloInstruction* transpose =
1422      builder.AddInstruction(HloInstruction::CreateTranspose(
1423          ShapeUtil::MakeShape(F32, {5, 3, 4, 2}), param, {0, 2, 3, 1}));
1424  *transpose->mutable_shape()->mutable_layout() =
1425      LayoutUtil::MakeLayout({3, 1, 2, 0});
1426
1427  auto computation = module().AddEntryComputation(builder.Build());
1428
1429  EXPECT_THAT(computation->root_instruction(), op::Transpose(param));
1430
1431  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
1432                                 bitcasting_callback());
1433  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
1434
1435  // Verify that the reshape is replaced.
1436  EXPECT_THAT(computation->root_instruction(), op::Bitcast(param));
1437}
1438
1439TEST_F(AlgebraicSimplifierTest, ReshapesMerged) {
1440  HloComputation::Builder builder(TestName());
1441  HloInstruction* param0 =
1442      builder.AddInstruction(HloInstruction::CreateParameter(
1443          0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
1444
1445  HloInstruction* reshape1 =
1446      builder.AddInstruction(HloInstruction::CreateReshape(
1447          ShapeUtil::MakeShape(F32, {2, 1, 2}), param0));
1448
1449  builder.AddInstruction(HloInstruction::CreateReshape(
1450      ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), reshape1));
1451
1452  auto computation = module().AddEntryComputation(builder.Build());
1453
1454  EXPECT_THAT(computation->root_instruction(),
1455              op::Reshape(op::Reshape(param0)));
1456
1457  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
1458                                 non_bitcasting_callback());
1459  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
1460
1461  EXPECT_THAT(computation->root_instruction(), op::Reshape(param0));
1462}
1463
1464TEST_F(AlgebraicSimplifierTest, CopiesMerged) {
1465  HloComputation::Builder builder(TestName());
1466  HloInstruction* param0 =
1467      builder.AddInstruction(HloInstruction::CreateParameter(
1468          0, ShapeUtil::MakeShapeWithDescendingLayout(F32, {2, 2, 2}),
1469          "param0"));
1470
1471  HloInstruction* copy1 = builder.AddInstruction(HloInstruction::CreateUnary(
1472      ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}),
1473      HloOpcode::kCopy, param0));
1474
1475  builder.AddInstruction(HloInstruction::CreateUnary(
1476      ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 2, 1}),
1477      HloOpcode::kCopy, copy1));
1478
1479  auto computation = module().AddEntryComputation(builder.Build());
1480
1481  EXPECT_THAT(computation->root_instruction(), op::Copy(op::Copy(param0)));
1482
1483  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
1484                                 non_bitcasting_callback());
1485  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
1486
1487  EXPECT_THAT(computation->root_instruction(), op::Copy(param0));
1488}
1489
1490TEST_F(AlgebraicSimplifierTest, TransposesMerged) {
1491  HloComputation::Builder builder(TestName());
1492  HloInstruction* param0 =
1493      builder.AddInstruction(HloInstruction::CreateParameter(
1494          0, ShapeUtil::MakeShape(F32, {2, 3, 4}), "param0"));
1495
1496  HloInstruction* transpose1 =
1497      builder.AddInstruction(HloInstruction::CreateTranspose(
1498          ShapeUtil::MakeShape(F32, {3, 4, 2}), param0, {1, 2, 0}));
1499
1500  builder.AddInstruction(HloInstruction::CreateTranspose(
1501      ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2}));
1502
1503  auto computation = module().AddEntryComputation(builder.Build());
1504
1505  EXPECT_THAT(computation->root_instruction(), op::Transpose(transpose1));
1506
1507  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
1508                                 non_bitcasting_callback());
1509  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
1510
1511  EXPECT_THAT(computation->root_instruction(), op::Transpose(param0));
1512  EXPECT_EQ(std::vector<int64>({2, 1, 0}),
1513            computation->root_instruction()->dimensions());
1514}
1515
1516// Test merging reshape and broadcast.
1517TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) {
1518  HloComputation::Builder builder(TestName());
1519  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
1520      0, ShapeUtil::MakeShape(F32, {5}), "param0"));
1521  auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
1522      ShapeUtil::MakeShape(F32, {1, 5, 1}), param0));
1523  builder.AddInstruction(HloInstruction::CreateBroadcast(
1524      ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 3, 2}));
1525
1526  auto computation = module().AddEntryComputation(builder.Build());
1527
1528  EXPECT_THAT(computation->root_instruction(),
1529              op::Broadcast(op::Reshape(param0)));
1530
1531  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
1532                                 non_bitcasting_callback());
1533  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
1534
1535  EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0));
1536}
1537
1538// Test merging broadcast and reshape.
1539TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) {
1540  HloComputation::Builder builder(TestName());
1541  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
1542      0, ShapeUtil::MakeShape(F32, {2, 3}), "param0"));
1543  auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast(
1544      ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1}), param0, {1, 2}));
1545  builder.AddInstruction(HloInstruction::CreateReshape(
1546      ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1));
1547
1548  auto computation = module().AddEntryComputation(builder.Build());
1549
1550  EXPECT_THAT(computation->root_instruction(),
1551              op::Reshape(op::Broadcast(param0)));
1552
1553  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
1554                                 non_bitcasting_callback());
1555  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
1556
1557  EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0));
1558}
1559
1560TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) {
1561  HloComputation::Builder builder(TestName());
1562  auto param = builder.AddInstruction(HloInstruction::CreateParameter(
1563      0, ShapeUtil::MakeShape(F32, {1}), "param"));
1564  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
1565      ShapeUtil::MakeShape(F32, {3, 1}), param, {1}));
1566  builder.AddInstruction(
1567      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast));
1568
1569  auto computation = module().AddEntryComputation(builder.Build());
1570
1571  EXPECT_THAT(computation->root_instruction(),
1572              op::Reshape(op::Broadcast(param)));
1573
1574  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
1575                                 non_bitcasting_callback());
1576  EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
1577
1578  EXPECT_THAT(computation->root_instruction(),
1579              op::Reshape(op::Broadcast(param)));
1580}
1581
1582TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) {
1583  HloComputation::Builder builder(TestName());
1584  auto param = builder.AddInstruction(HloInstruction::CreateParameter(
1585      0, ShapeUtil::MakeShape(F32, {4}), "param"));
1586  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
1587      ShapeUtil::MakeShape(F32, {3, 2, 4}), param, {2}));
1588  builder.AddInstruction(HloInstruction::CreateReshape(
1589      ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast));
1590
1591  HloComputation* computation = module().AddEntryComputation(builder.Build());
1592
1593  EXPECT_THAT(computation->root_instruction(),
1594              op::Reshape(op::Broadcast(param)));
1595
1596  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
1597                                 non_bitcasting_callback());
1598  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
1599
1600  EXPECT_THAT(computation->root_instruction(), op::Broadcast(param));
1601  EXPECT_THAT(computation->root_instruction()->dimensions(),
1602              ::testing::ElementsAre(3));
1603}
1604
1605TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) {
1606  HloComputation::Builder builder(TestName());
1607  auto param = builder.AddInstruction(HloInstruction::CreateParameter(
1608      0, ShapeUtil::MakeShape(F32, {1}), "param"));
1609  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
1610      ShapeUtil::MakeShape(F32, {3, 2, 1}), param, {2}));
1611  builder.AddInstruction(HloInstruction::CreateReshape(
1612      ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast));
1613
1614  HloComputation* computation = module().AddEntryComputation(builder.Build());
1615
1616  EXPECT_THAT(computation->root_instruction(),
1617              op::Reshape(op::Broadcast(param)));
1618
1619  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
1620                                 non_bitcasting_callback());
1621  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
1622
1623  EXPECT_THAT(computation->root_instruction(), op::Broadcast(param));
1624  const std::vector<int64> broadcast_dims =
1625      computation->root_instruction()->dimensions();
1626  EXPECT_EQ(1, broadcast_dims.size());
1627  EXPECT_THAT(broadcast_dims[0], ::testing::AnyOf(1, 2, 3));
1628}
1629
1630TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) {
1631  HloComputation::Builder builder(TestName());
1632  auto param = builder.AddInstruction(HloInstruction::CreateParameter(
1633      0, ShapeUtil::MakeShape(F32, {4}), "param"));
1634  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
1635      ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), param, {2}));
1636  builder.AddInstruction(HloInstruction::CreateReshape(
1637      ShapeUtil::MakeShape(F32, {6, 8}), broadcast));
1638
1639  HloComputation* computation = module().AddEntryComputation(builder.Build());
1640
1641  EXPECT_THAT(computation->root_instruction(),
1642              op::Reshape(op::Broadcast(param)));
1643
1644  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
1645                                 non_bitcasting_callback());
1646  EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie());
1647
1648  EXPECT_THAT(computation->root_instruction(),
1649              op::Reshape(op::Broadcast(param)));
1650}
1651
1652TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
1653  HloComputation::Builder builder(TestName());
1654  HloInstruction* param =
1655      builder.AddInstruction(HloInstruction::CreateParameter(
1656          0, ShapeUtil::MakeShape(F32, {2, 2}), "param"));
1657  HloInstruction* zero = builder.AddInstruction(
1658      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
1659  PaddingConfig no_padding;
1660  for (int i = 0; i < 2; ++i) {
1661    auto dimension = no_padding.add_dimensions();
1662    dimension->set_edge_padding_low(0);
1663    dimension->set_edge_padding_high(0);
1664    dimension->set_interior_padding(0);
1665  }
1666  builder.AddInstruction(HloInstruction::CreatePad(
1667      ShapeUtil::MakeShape(F32, {2, 2}), param, zero, no_padding));
1668
1669  HloModule module(TestName());
1670  HloComputation* computation = module.AddEntryComputation(builder.Build());
1671
1672  EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero));
1673
1674  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
1675                                 non_bitcasting_callback());
1676  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
1677
1678  EXPECT_THAT(computation->root_instruction(), param);
1679}
1680
1681TEST_F(AlgebraicSimplifierTest, NegativePadding) {
1682  // Verify that a pad instruction with negative padding is replaced with a
1683  // pad with non-negative padding followed by a slice.
1684  HloComputation::Builder builder(TestName());
1685  HloInstruction* param =
1686      builder.AddInstruction(HloInstruction::CreateParameter(
1687          0, ShapeUtil::MakeShape(F32, {10, 10}), "param"));
1688  HloInstruction* zero = builder.AddInstruction(
1689      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
1690  PaddingConfig padding;
1691  int64 low_padding[2] = {-1, -2};
1692  int64 high_padding[2] = {2, -3};
1693  for (int i = 0; i < 2; ++i) {
1694    auto dimension = padding.add_dimensions();
1695    dimension->set_edge_padding_low(low_padding[i]);
1696    dimension->set_edge_padding_high(high_padding[i]);
1697    dimension->set_interior_padding(0);
1698  }
1699  HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
1700      ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding));
1701
1702  HloModule module(TestName());
1703  HloComputation* computation = module.AddEntryComputation(builder.Build());
1704
1705  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
1706                                 non_bitcasting_callback());
1707
1708  auto has_negative_padding = [](const HloInstruction* pad) {
1709    for (auto& padding_dimension : pad->padding_config().dimensions()) {
1710      if (padding_dimension.edge_padding_low() < 0 ||
1711          padding_dimension.edge_padding_high() < 0) {
1712        return true;
1713      }
1714    }
1715    return false;
1716  };
1717
1718  EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero));
1719  EXPECT_TRUE(has_negative_padding(pad));
1720
1721  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
1722
1723  EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero)));
1724  EXPECT_FALSE(
1725      has_negative_padding(computation->root_instruction()->operand(0)));
1726}
1727
1728TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) {
1729  HloComputation::Builder builder(TestName());
1730  HloInstruction* param =
1731      builder.AddInstruction(HloInstruction::CreateParameter(
1732          0, ShapeUtil::MakeShape(F32, {2, 3}), "param"));
1733  builder.AddInstruction(
1734      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {2, 3}), param));
1735
1736  HloModule module(TestName());
1737  HloComputation* computation = module.AddEntryComputation(builder.Build());
1738
1739  EXPECT_THAT(computation->root_instruction(), op::Reshape(param));
1740
1741  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
1742                                 non_bitcasting_callback());
1743  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
1744
1745  EXPECT_THAT(computation->root_instruction(), param);
1746}
1747
1748TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
1749  HloComputation::Builder builder(TestName());
1750  const int64 dim0 = 2;
1751  const int64 dim1 = 3;
1752  HloInstruction* param =
1753      builder.AddInstruction(HloInstruction::CreateParameter(
1754          0, ShapeUtil::MakeShape(F32, {dim0, dim1}), "param"));
1755  builder.AddInstruction(HloInstruction::CreateSlice(
1756      ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0},
1757      /*limit_indices=*/{dim0, dim1}, /*strides=*/{1, 1}));
1758
1759  HloModule module(TestName());
1760  HloComputation* computation = module.AddEntryComputation(builder.Build());
1761
1762  EXPECT_THAT(computation->root_instruction(), op::Slice(param));
1763
1764  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
1765                                 non_bitcasting_callback());
1766  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
1767
1768  EXPECT_THAT(computation->root_instruction(), param);
1769}
1770
1771TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
1772  struct ConvTestOptions {
1773    int in_batch = 10;
1774    int in_height = 2;
1775    int in_width = 2;
1776    int in_channels = 3;
1777    int f_width = 1;
1778    int f_height = 1;
1779    int f_output_channels = 10;
1780    int row_stride = 1;
1781    int row_padding = 0;
1782    int col_stride = 1;
1783    int col_padding = 0;
1784    bool input_minor_to_major_layout = false;
1785    bool filter_minor_to_major_layout = false;
1786    bool output_minor_to_major_layout = false;
1787
1788    const char* dim_order = "NHWC";         // can use chars NHWC in any order.
1789    const char* kernel_dim_order = "HWIO";  // can use chars HWIO in any order.
1790
1791    ConvTestOptions& Reset() {
1792      *this = ConvTestOptions();
1793      return *this;
1794    }
1795  };
1796
1797  ConvTestOptions options;
1798
1799  // Builds a convolution from <options> and runs algebraic simplification on
1800  // the computation. Returns a string description of the result of
1801  // simplification.
1802  auto build_and_simplify = [&options, this]() -> string {
1803    HloComputation::Builder b(TestName());
1804
1805    Window window;
1806    auto* f_dim_1 = window.add_dimensions();
1807    f_dim_1->set_size(options.f_height);
1808    f_dim_1->set_stride(options.row_stride);
1809    f_dim_1->set_padding_low(options.row_padding);
1810    f_dim_1->set_padding_high(options.row_padding);
1811    f_dim_1->set_window_dilation(1);
1812    f_dim_1->set_base_dilation(1);
1813    auto* f_dim_2 = window.add_dimensions();
1814    f_dim_2->set_size(options.f_width);
1815    f_dim_2->set_stride(options.col_stride);
1816    f_dim_2->set_padding_low(options.col_padding);
1817    f_dim_2->set_padding_high(options.col_padding);
1818    f_dim_2->set_window_dilation(1);
1819    f_dim_2->set_base_dilation(1);
1820
1821    ConvolutionDimensionNumbers dnums;
1822    std::vector<int64> in_dims;
1823    int in_channel_idx = -1;
1824    // filled in later
1825    dnums.add_input_spatial_dimensions(-1);
1826    dnums.add_output_spatial_dimensions(-1);
1827    dnums.add_input_spatial_dimensions(-1);
1828    dnums.add_output_spatial_dimensions(-1);
1829    for (int i = 0; i < strlen(options.dim_order); ++i) {
1830      char ch = options.dim_order[i];
1831      if (ch == 'N') {
1832        dnums.set_input_batch_dimension(i);
1833        dnums.set_output_batch_dimension(i);
1834        in_dims.push_back(options.in_batch);
1835      } else if (ch == 'H') {
1836        dnums.set_input_spatial_dimensions(0, i);
1837        dnums.set_output_spatial_dimensions(0, i);
1838        in_dims.push_back(options.in_height);
1839      } else if (ch == 'W') {
1840        dnums.set_input_spatial_dimensions(1, i);
1841        dnums.set_output_spatial_dimensions(1, i);
1842        in_dims.push_back(options.in_width);
1843      } else if (ch == 'C') {
1844        dnums.set_input_feature_dimension(i);
1845        dnums.set_output_feature_dimension(i);
1846        in_dims.push_back(options.in_channels);
1847        in_channel_idx = i;
1848      }
1849    }
1850
1851    std::vector<int64> f_dims;
1852    dnums.add_kernel_spatial_dimensions(-1);  // filled in later
1853    dnums.add_kernel_spatial_dimensions(-1);  // filled in later
1854    for (int i = 0; i < strlen(options.kernel_dim_order); ++i) {
1855      char ch = options.kernel_dim_order[i];
1856      if (ch == 'H') {
1857        dnums.set_kernel_spatial_dimensions(0, i);
1858        f_dims.push_back(options.f_height);
1859      } else if (ch == 'W') {
1860        dnums.set_kernel_spatial_dimensions(1, i);
1861        f_dims.push_back(options.f_width);
1862      } else if (ch == 'I') {
1863        dnums.set_kernel_input_feature_dimension(i);
1864        f_dims.push_back(options.in_channels);
1865      } else if (ch == 'O') {
1866        dnums.set_kernel_output_feature_dimension(i);
1867        f_dims.push_back(options.f_output_channels);
1868      }
1869    }
1870
1871    auto out_dims = in_dims;
1872    out_dims[in_channel_idx] = options.f_output_channels;
1873
1874    auto make_shape = [](tensorflow::gtl::ArraySlice<int64> dims,
1875                         bool minor_to_major_layout) {
1876      if (minor_to_major_layout) {
1877        return ShapeUtil::MakeShapeWithLayout(F32, dims, {0, 1, 2, 3});
1878      } else {
1879        return ShapeUtil::MakeShape(F32, dims);
1880      }
1881    };
1882    auto in_shape = make_shape(in_dims, options.input_minor_to_major_layout);
1883    auto f_shape = make_shape(f_dims, options.filter_minor_to_major_layout);
1884    auto out_shape = make_shape(out_dims, options.output_minor_to_major_layout);
1885
1886    HloInstruction* input =
1887        b.AddInstruction(HloInstruction::CreateParameter(0, in_shape, "input"));
1888    HloInstruction* filter =
1889        b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter"));
1890
1891    b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter,
1892                                                    window, dnums));
1893
1894    HloModule module(TestName());
1895    auto* computation = module.AddEntryComputation(b.Build());
1896
1897    AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
1898                                   bitcasting_callback());
1899    if (!simplifier.Run(&module).ValueOrDie()) {
1900      return "NO_CHANGE";
1901    }
1902    auto* root = computation->root_instruction();
1903    if (root->opcode() == HloOpcode::kBitcast &&
1904        root->operand(0)->opcode() == HloOpcode::kDot) {
1905      auto lhs_shape = root->operand(0)->operand(0)->shape();
1906      auto rhs_shape = root->operand(0)->operand(1)->shape();
1907      return tensorflow::strings::StrCat(
1908          tensorflow::str_util::Join(lhs_shape.dimensions(), "x"), " DOT ",
1909          tensorflow::str_util::Join(rhs_shape.dimensions(), "x"));
1910    }
1911    return "UNEXPECTED CHANGE";
1912  };
1913
1914  // Default options are the simplest case and succeed.
1915  options.Reset();
1916  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
1917
1918  // Swapping dim spatial and batch order works.
1919  options.Reset().dim_order = "NWHC";
1920  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
1921  options.Reset().dim_order = "WHNC";
1922  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
1923  // Channel dimension earlier fails.
1924  options.Reset().dim_order = "HWCN";
1925  EXPECT_EQ("NO_CHANGE", build_and_simplify());
1926  options.Reset().dim_order = "CHWN";
1927  EXPECT_EQ("NO_CHANGE", build_and_simplify());
1928
1929  // Filtering dims spatial dims can be anywhere, since they are 1x1.
1930  options.Reset().kernel_dim_order = "WHIO";
1931  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
1932  options.Reset().kernel_dim_order = "IWOH";
1933  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
1934  options.Reset().kernel_dim_order = "IWHO";
1935  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
1936  // But moving output channel before input channel fails.
1937  options.Reset().kernel_dim_order = "HWOI";
1938  EXPECT_EQ("NO_CHANGE", build_and_simplify());
1939  options.Reset().kernel_dim_order = "WHOI";
1940  EXPECT_EQ("NO_CHANGE", build_and_simplify());
1941  options.Reset().kernel_dim_order = "OWIH";
1942  EXPECT_EQ("NO_CHANGE", build_and_simplify());
1943  options.Reset().kernel_dim_order = "OWHI";
1944  EXPECT_EQ("NO_CHANGE", build_and_simplify());
1945
1946  // Combine different dim and kernel dim orders.
1947  options.Reset().kernel_dim_order = "IWHO";
1948  options.dim_order = "WHNC";
1949  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
1950
1951  // Test invalid cases from wrong filter size, strides, or padding.
1952  options.Reset().f_width = 2;
1953  EXPECT_EQ("NO_CHANGE", build_and_simplify());
1954  options.Reset().f_height = 2;
1955  EXPECT_EQ("NO_CHANGE", build_and_simplify());
1956  options.Reset().row_stride = 2;
1957  EXPECT_EQ("NO_CHANGE", build_and_simplify());
1958  options.Reset().col_stride = 2;
1959  EXPECT_EQ("NO_CHANGE", build_and_simplify());
1960  options.Reset().col_padding = 1;
1961  EXPECT_EQ("NO_CHANGE", build_and_simplify());
1962  options.Reset().row_padding = 1;
1963  EXPECT_EQ("NO_CHANGE", build_and_simplify());
1964
1965  // The default dim_order is "NHWC". Col-major layout makes C the most major.
1966  options.Reset().input_minor_to_major_layout = true;
1967  options.output_minor_to_major_layout = true;
1968  EXPECT_EQ("NO_CHANGE", build_and_simplify());
1969
1970  // The input and output have different layouts.
1971  options.Reset().input_minor_to_major_layout = true;
1972  EXPECT_EQ("NO_CHANGE", build_and_simplify());
1973
1974  // C is most minor, and I is more major than O.
1975  options.Reset().input_minor_to_major_layout = true;
1976  options.filter_minor_to_major_layout = true;
1977  options.output_minor_to_major_layout = true;
1978  options.dim_order = "CHWN";
1979  options.kernel_dim_order = "OIHW";
1980  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
1981
1982  // C is not the most minor dimension.
1983  options.Reset().input_minor_to_major_layout = true;
1984  options.filter_minor_to_major_layout = true;
1985  options.output_minor_to_major_layout = true;
1986  options.dim_order = "HWNC";
1987  options.kernel_dim_order = "OIHW";
1988  EXPECT_EQ("NO_CHANGE", build_and_simplify());
1989
1990  // I is more minor than O.
1991  options.Reset().input_minor_to_major_layout = true;
1992  options.filter_minor_to_major_layout = true;
1993  options.output_minor_to_major_layout = true;
1994  options.dim_order = "CHWN";
1995  options.kernel_dim_order = "IOHW";
1996  EXPECT_EQ("NO_CHANGE", build_and_simplify());
1997}
1998
1999// Test that max(min(A, x), y) is transformed to clamp(y, A, x)
2000TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) {
2001  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
2002  HloComputation::Builder builder(TestName());
2003  HloInstruction* param0 = builder.AddInstruction(
2004      HloInstruction::CreateParameter(0, r0f32, "param0"));
2005  HloInstruction* min_value = builder.AddInstruction(
2006      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
2007  HloInstruction* max_value = builder.AddInstruction(
2008      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
2009  HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary(
2010      r0f32, HloOpcode::kMinimum, param0, min_value));
2011  builder.AddInstruction(
2012      HloInstruction::CreateBinary(r0f32, HloOpcode::kMaximum, min, max_value));
2013
2014  HloModule module(TestName());
2015  auto computation = module.AddEntryComputation(builder.Build());
2016
2017  EXPECT_THAT(computation->root_instruction(),
2018              op::Maximum(op::Minimum(param0, min_value), max_value));
2019
2020  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
2021                                 non_bitcasting_callback());
2022  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
2023
2024  EXPECT_THAT(computation->root_instruction(),
2025              op::Clamp(max_value, param0, min_value));
2026}
2027
2028// Test that min(max(A, x), y) is transformed to clamp(x, A, y) for scalar
2029// values.
2030TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) {
2031  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
2032  HloComputation::Builder builder(TestName());
2033  HloInstruction* param0 = builder.AddInstruction(
2034      HloInstruction::CreateParameter(0, r0f32, "param0"));
2035  HloInstruction* min_value = builder.AddInstruction(
2036      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
2037  HloInstruction* max_value = builder.AddInstruction(
2038      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
2039  HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
2040      r0f32, HloOpcode::kMaximum, param0, max_value));
2041  builder.AddInstruction(
2042      HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value));
2043
2044  HloModule module(TestName());
2045  auto computation = module.AddEntryComputation(builder.Build());
2046
2047  EXPECT_THAT(computation->root_instruction(),
2048              op::Minimum(op::Maximum(param0, max_value), min_value));
2049
2050  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
2051                                 non_bitcasting_callback());
2052  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
2053
2054  EXPECT_THAT(computation->root_instruction(),
2055              op::Clamp(max_value, param0, min_value));
2056}
2057
2058// Test that min(max(A, x), y) is transformed to clamp(x, A, y) for
2059// broadcasted scalar values.
2060TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) {
2061  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
2062  Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
2063  HloComputation::Builder builder(TestName());
2064  HloInstruction* param0 = builder.AddInstruction(
2065      HloInstruction::CreateParameter(0, r1f32, "param0"));
2066  HloInstruction* min_value = builder.AddInstruction(
2067      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
2068  HloInstruction* max_value = builder.AddInstruction(
2069      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
2070  HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
2071      r1f32, HloOpcode::kMaximum, param0, max_value));
2072  builder.AddInstruction(
2073      HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, max, min_value));
2074
2075  HloModule module(TestName());
2076  auto computation = module.AddEntryComputation(builder.Build());
2077
2078  EXPECT_THAT(computation->root_instruction(),
2079              op::Minimum(op::Maximum(param0, max_value), min_value));
2080
2081  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
2082                                 non_bitcasting_callback());
2083  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
2084
2085  EXPECT_THAT(computation->root_instruction(),
2086              op::Clamp(max_value, param0, min_value));
2087}
2088
2089// Test that min(max(A, non-constant1), non-constant2) is not canonicalized to
2090// clamp(non-constant1, A, non-constant2)
2091TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) {
2092  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
2093  HloComputation::Builder builder(TestName());
2094  HloInstruction* param0 = builder.AddInstruction(
2095      HloInstruction::CreateParameter(0, r0f32, "param0"));
2096  HloInstruction* min_value = builder.AddInstruction(
2097      HloInstruction::CreateParameter(1, r0f32, "param1"));
2098  HloInstruction* max_value = builder.AddInstruction(
2099      HloInstruction::CreateParameter(2, r0f32, "param2"));
2100  HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
2101      r0f32, HloOpcode::kMaximum, param0, max_value));
2102  builder.AddInstruction(
2103      HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value));
2104
2105  HloModule module(TestName());
2106  auto computation = module.AddEntryComputation(builder.Build());
2107
2108  EXPECT_THAT(computation->root_instruction(),
2109              op::Minimum(op::Maximum(param0, max_value), min_value));
2110
2111  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
2112                                 non_bitcasting_callback());
2113  EXPECT_FALSE(simplifier.Run(&module).ValueOrDie());
2114
2115  EXPECT_THAT(computation->root_instruction(),
2116              op::Minimum(op::Maximum(param0, max_value), min_value));
2117}
2118
2119// Test that min(f(max(A, constant1)), constant2) is not transformed to
2120// clamp(constant1, A, constant2)
2121TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) {
2122  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
2123  HloComputation::Builder builder(TestName());
2124  HloInstruction* param0 = builder.AddInstruction(
2125      HloInstruction::CreateParameter(0, r0f32, "param0"));
2126  HloInstruction* min_value = builder.AddInstruction(
2127      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
2128  HloInstruction* max_value = builder.AddInstruction(
2129      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
2130  HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
2131      r0f32, HloOpcode::kMaximum, param0, max_value));
2132  HloInstruction* fmax = builder.AddInstruction(
2133      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, max, max_value));
2134  builder.AddInstruction(HloInstruction::CreateBinary(
2135      r0f32, HloOpcode::kMinimum, fmax, min_value));
2136
2137  HloModule module(TestName());
2138  auto computation = module.AddEntryComputation(builder.Build());
2139
2140  EXPECT_THAT(computation->root_instruction(),
2141              op::Minimum(op::Add(op::Maximum(param0, max_value), max_value),
2142                          min_value));
2143
2144  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
2145                                 non_bitcasting_callback());
2146  EXPECT_FALSE(simplifier.Run(&module).ValueOrDie());
2147
2148  EXPECT_THAT(computation->root_instruction(),
2149              op::Minimum(op::Add(op::Maximum(param0, max_value), max_value),
2150                          min_value));
2151}
2152
2153// Test that slice(broadcast(/*scalar value*/)) simplifies to a single
2154// broadcast.
2155TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
2156  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
2157  HloComputation::Builder builder(TestName());
2158  HloInstruction* scalar_param = builder.AddInstruction(
2159      HloInstruction::CreateParameter(0, r0f32, "scalar_param"));
2160
2161  Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
2162  HloInstruction* broadcast =
2163      builder.AddInstruction(HloInstruction::CreateBroadcast(
2164          broadcast_shape, scalar_param,
2165          AsInt64Slice(broadcast_shape.dimensions())));
2166
2167  Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3});
2168  HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
2169      slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1}));
2170
2171  HloModule module(TestName());
2172  auto computation = module.AddEntryComputation(builder.Build());
2173
2174  HloInstruction* root = computation->root_instruction();
2175  EXPECT_EQ(root, slice);
2176  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape));
2177
2178  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
2179                                 non_bitcasting_callback());
2180
2181  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
2182
2183  // Running simplification again should not result in any further changes.
2184  ASSERT_FALSE(simplifier.Run(&module).ValueOrDie());
2185
2186  root = computation->root_instruction();
2187  EXPECT_THAT(root, op::Broadcast(scalar_param));
2188  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape));
2189}
2190
2191// Test that reshape(transpose(broadcast(/*scalar value*/))) simplifies to a
2192// single broadcast.
2193TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
2194  HloComputation::Builder builder(TestName());
2195  HloInstruction* forty_two = builder.AddInstruction(
2196      HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
2197
2198  Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6});
2199  HloInstruction* broadcast =
2200      builder.AddInstruction(HloInstruction::CreateBroadcast(
2201          broadcast_shape, forty_two,
2202          AsInt64Slice(broadcast_shape.dimensions())));
2203
2204  HloInstruction* transpose =
2205      builder.AddInstruction(HloInstruction::CreateTranspose(
2206          ShapeUtil::MakeShape(F32, {6, 5, 4}), broadcast, {2, 1, 0}));
2207
2208  Shape reshape_shape = ShapeUtil::MakeShape(F32, {30, 1, 4});
2209  HloInstruction* reshape = builder.AddInstruction(
2210      HloInstruction::CreateReshape(reshape_shape, transpose));
2211
2212  HloModule module(TestName());
2213  auto computation = module.AddEntryComputation(builder.Build());
2214
2215  HloInstruction* root = computation->root_instruction();
2216  EXPECT_EQ(root, reshape);
2217  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape));
2218
2219  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
2220                                 non_bitcasting_callback());
2221  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
2222
2223  root = computation->root_instruction();
2224  EXPECT_THAT(root, op::Broadcast(forty_two));
2225  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape));
2226}
2227
2228// Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x).
2229TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
2230  HloModule module(TestName());
2231  HloComputation::Builder builder(TestName());
2232
2233  // Create operand to the pad.
2234  HloInstruction* operand =
2235      builder.AddInstruction(HloInstruction::CreateParameter(
2236          0, ShapeUtil::MakeShape(F32, {1, 2, 3, 4}), "p0"));
2237
2238  // Create the pad.
2239  PaddingConfig padding = MakeNoPaddingConfig(4);
2240  padding.mutable_dimensions(1)->set_edge_padding_low(1);
2241  padding.mutable_dimensions(3)->set_edge_padding_high(2);
2242
2243  HloInstruction* pad_value = builder.AddInstruction(
2244      HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
2245  HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
2246      ShapeUtil::MakeShape(F32, {1, 3, 3, 5}), operand, pad_value, padding));
2247
2248  // Create add computation.
2249  HloComputation* add_computation = nullptr;
2250  {
2251    HloComputation::Builder builder(TestName() + ".add");
2252    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
2253    HloInstruction* p0 = builder.AddInstruction(
2254        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
2255    HloInstruction* p1 = builder.AddInstruction(
2256        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
2257    builder.AddInstruction(
2258        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
2259    add_computation = module.AddEmbeddedComputation(builder.Build());
2260  }
2261
2262  // Create the reduce-window.
2263  Window window;
2264  for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) {
2265    auto* dim = window.add_dimensions();
2266    dim->set_size(1);
2267    dim->set_padding_low(10);
2268    dim->set_padding_high(100);
2269    dim->set_window_dilation(1);
2270    dim->set_base_dilation(1);
2271  }
2272  const Shape reduce_window_shape =
2273      ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
2274  HloInstruction* reduce_init_value = builder.AddInstruction(
2275      HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
2276  HloInstruction* reduce_window =
2277      builder.AddInstruction(HloInstruction::CreateReduceWindow(
2278          reduce_window_shape, pad, reduce_init_value, window,
2279          add_computation));
2280
2281  // Build the computation and run the simplifier.
2282  auto computation = module.AddEntryComputation(builder.Build());
2283  HloInstruction* root = computation->root_instruction();
2284  EXPECT_EQ(root, reduce_window);
2285  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
2286                                 non_bitcasting_callback());
2287  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
2288
2289  // Running simplification again should not result in any further changes.
2290  ASSERT_FALSE(simplifier.Run(&module).ValueOrDie());
2291
2292  // Verify the result
2293  root = computation->root_instruction();
2294  EXPECT_THAT(root, op::ReduceWindow(operand, op::Constant()));
2295  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape))
2296      << ShapeUtil::HumanString(root->shape()) << " vs "
2297      << ShapeUtil::HumanString(reduce_window_shape);
2298  EXPECT_EQ(root->window().dimensions(0).padding_low(), 10);
2299  EXPECT_EQ(root->window().dimensions(1).padding_low(), 11);
2300  EXPECT_EQ(root->window().dimensions(2).padding_low(), 10);
2301  EXPECT_EQ(root->window().dimensions(3).padding_low(), 10);
2302  EXPECT_EQ(root->window().dimensions(0).padding_high(), 100);
2303  EXPECT_EQ(root->window().dimensions(1).padding_high(), 100);
2304  EXPECT_EQ(root->window().dimensions(2).padding_high(), 100);
2305  EXPECT_EQ(root->window().dimensions(3).padding_high(), 102);
2306}
2307
2308TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) {
2309  HloComputation::Builder builder(TestName());
2310  const Shape shape = ShapeUtil::MakeShape(F32, {448, 2048, 1, 1});
2311  HloInstruction* a =
2312      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
2313  builder.AddInstruction(
2314      HloInstruction::CreateReverse(shape, a, /*dimensions=*/{2, 3}));
2315
2316  HloModule module(TestName());
2317  auto computation = module.AddEntryComputation(builder.Build());
2318
2319  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
2320                                 non_bitcasting_callback());
2321  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
2322
2323  HloInstruction* root = computation->root_instruction();
2324  EXPECT_EQ(a, root);
2325  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
2326}
2327
// Builds an entry computation that kCalls a separate embedded computation
// containing a dot, then runs the simplifier over the whole module and expects
// it to report a change.
TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
  // Dots add computations to the parent module. Test that, when the HloModule's
  // computations are updated, then iterator invalidation doesn't occur
  // when running on subsequent computations.
  Shape r1f32 = ShapeUtil::MakeShape(F32, {1});
  HloComputation::Builder builder(TestName() + ".Dot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y"));
  // NOTE(review): x has rank 1 (shape {1}), so lhs contracting dimension 1 is
  // not a valid dimension index for it. The declared dot shape r1f32 also
  // differs from the rank-0 result that contracting two vectors would give.
  // The test only asserts that Run() reports a change; confirm whether rank-2
  // {1, 1} operands were intended.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums));
  std::unique_ptr<HloComputation> dot_computation(builder.Build());

  // Entry computation: call(dot_computation, {0.0}, {1.0}).
  HloComputation::Builder call_builder(TestName() + ".Call");
  HloInstruction* zero = call_builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR1<float>({0.0f})));
  HloInstruction* one = call_builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR1<float>({1.0f})));
  call_builder.AddInstruction(
      HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get()));

  // The dot computation is registered first, so the entry computation is
  // added to the module after it.
  module().AddEmbeddedComputation(std::move(dot_computation));
  module().AddEntryComputation(call_builder.Build());
  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                 non_bitcasting_callback());
  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
}
2358
2359// Test that a constant with tuple shape becomes a tuple of constants.
2360TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) {
2361  HloComputation::Builder builder(TestName());
2362  const float constant_scalar = 7.3f;
2363  std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
2364  std::unique_ptr<Literal> value =
2365      Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar).get(),
2366                          Literal::CreateR1<float>(constant_vector).get()});
2367  builder.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
2368
2369  auto computation = module().AddEntryComputation(builder.Build());
2370
2371  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
2372                                 non_bitcasting_callback());
2373  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
2374  EXPECT_THAT(computation->root_instruction(),
2375              op::Tuple(op::Constant(), op::Constant()));
2376}
2377
2378// A dynamic-slice is trivial if its start indices are all zeroes and the size
2379// of its input equals the size of its output.  In this case, the dynamic slice
2380// is equal to its input.
2381TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) {
2382  HloComputation::Builder builder(TestName());
2383
2384  Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
2385  builder.AddInstruction(HloInstruction::CreateDynamicSlice(
2386      shape,
2387      builder.AddInstruction(
2388          HloInstruction::CreateParameter(0, shape, "slice_from")),
2389      builder.AddInstruction(
2390          HloInstruction::CreateConstant(Literal::CreateR1<int>({0, 0, 0}))),
2391      /*slice_sizes=*/{10, 100, 1000}));
2392
2393  auto computation = module().AddEntryComputation(builder.Build());
2394  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
2395                                 non_bitcasting_callback());
2396  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
2397  EXPECT_THAT(computation->root_instruction(), op::Parameter());
2398}
2399
2400// A dynamic-update-slice is trivial if its start indices are all zeroes and the
2401// size of its "update" equals the size of its output.  In this case, the
2402// dynamic-update-slice is equal to its update.
2403TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) {
2404  HloComputation::Builder builder(TestName());
2405
2406  Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
2407  Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000});
2408
2409  HloInstruction* slice =
2410      builder.AddInstruction(HloInstruction::CreateDynamicSlice(
2411          slice_shape,
2412          builder.AddInstruction(
2413              HloInstruction::CreateParameter(0, full_shape, "slice_from")),
2414          builder.AddInstruction(HloInstruction::CreateParameter(
2415              1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")),
2416          /*slice_sizes=*/{10, 1, 1000}));
2417
2418  builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
2419      slice_shape,
2420      builder.AddInstruction(
2421          HloInstruction::CreateParameter(2, slice_shape, "to_update")),
2422      slice,
2423      builder.AddInstruction(
2424          HloInstruction::CreateConstant(Literal::CreateR1<int>({0, 0, 0})))));
2425
2426  auto computation = module().AddEntryComputation(builder.Build());
2427  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
2428                                 non_bitcasting_callback());
2429  ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
2430  EXPECT_THAT(computation->root_instruction(),
2431              op::DynamicSlice(op::Parameter(), op::Parameter()));
2432}
2433
2434struct PadReduceWindowEffectiveBroadcastCase {
2435  std::vector<int64> input_spatials;
2436  std::vector<int64> symmetric_pad_spatials;
2437  std::vector<int64> reduce_window_spatials;
2438  // Whether to use `B F S0 S1` form vs `B S0 S1 F` form.
2439  //
2440  // This doesn't test any different functionality but is useful for making sure
2441  // kBroadcast nodes are well formed.
2442  bool prepend_a;
2443  bool should_become_broadcast;
2444
2445  string ToTestCaseName() const {
2446    return tensorflow::strings::StrCat(
2447        tensorflow::str_util::Join(input_spatials, ","), ";",
2448        tensorflow::str_util::Join(symmetric_pad_spatials, ","), ";",
2449        tensorflow::str_util::Join(reduce_window_spatials, ","), ";", prepend_a,
2450        ";", should_become_broadcast);
2451  }
2452};
2453
2454void PrintTo(const PadReduceWindowEffectiveBroadcastCase& c, std::ostream* os) {
2455  *os << c.ToTestCaseName();
2456}
2457
2458class PadReduceWindowEffectiveBroadcastTest
2459    : public AlgebraicSimplifierTest,
2460      public ::testing::WithParamInterface<
2461          PadReduceWindowEffectiveBroadcastCase> {};
2462
2463TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) {
2464  const auto& param = GetParam();
2465
2466  // a and b are parallel bounds we can either turn into a B F S0 S1 or
2467  // `B S0 S1 F` kind of pattern.
2468  auto decorate_spatials = [&param](tensorflow::gtl::ArraySlice<int64> spatials,
2469                                    int64 a, int64 b) {
2470    std::vector<int64> result;
2471    if (param.prepend_a) {
2472      result.push_back(a);
2473    }
2474    for (int64 s : spatials) {
2475      result.push_back(s);
2476    }
2477    if (!param.prepend_a) {
2478      result.push_back(a);
2479    }
2480    result.push_back(b);
2481    return result;
2482  };
2483
2484  HloComputation::Builder builder(TestName());
2485  const Shape input_shape = ShapeUtil::MakeShape(
2486      F32, decorate_spatials(param.input_spatials, 128, 2048));
2487  HloInstruction* input = builder.AddInstruction(
2488      HloInstruction::CreateParameter(0, input_shape, "input"));
2489
2490  PaddingConfig padding = window_util::MakeSymmetricPadding(
2491      decorate_spatials(param.symmetric_pad_spatials, 0, 0));
2492  TF_ASSERT_OK_AND_ASSIGN(
2493      const Shape pad_shape,
2494      ShapeInference::InferPadShape(input->shape(),
2495                                    ShapeUtil::MakeShape(F32, {}), padding));
2496  HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
2497      pad_shape, input,
2498      builder.AddInstruction(
2499          HloInstruction::CreateConstant(Literal::CreateR0(0.0f))),
2500      padding));
2501
2502  HloComputation* add_computation = nullptr;
2503  {
2504    HloComputation::Builder builder(TestName() + ".add");
2505    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
2506    HloInstruction* p0 = builder.AddInstruction(
2507        HloInstruction::CreateParameter(0, scalar_shape, "p0"));
2508    HloInstruction* p1 = builder.AddInstruction(
2509        HloInstruction::CreateParameter(1, scalar_shape, "p1"));
2510    builder.AddInstruction(
2511        HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
2512    add_computation = module().AddEmbeddedComputation(builder.Build());
2513  }
2514
2515  Window window = window_util::MakeWindow(
2516      decorate_spatials(param.reduce_window_spatials, 1, 1));
2517  auto zero = builder.AddInstruction(
2518      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
2519  TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape,
2520                          ShapeInference::InferReduceWindowShape(
2521                              pad->shape(), zero->shape(), window,
2522                              add_computation->ComputeProgramShape()));
2523  builder.AddInstruction(HloInstruction::CreateReduceWindow(
2524      output_shape, pad, zero, window, add_computation));
2525
2526  auto computation = module().AddEntryComputation(builder.Build());
2527  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
2528                                 non_bitcasting_callback());
2529  TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module()));
2530  ASSERT_TRUE(run_successful);
2531
2532  EXPECT_TRUE(
2533      ShapeUtil::Equal(computation->root_instruction()->shape(), output_shape));
2534
2535  if (param.should_become_broadcast) {
2536    EXPECT_THAT(computation->root_instruction(), op::Broadcast(::testing::_));
2537  } else {
2538    EXPECT_THAT(computation->root_instruction(),
2539                op::ReduceWindow(::testing::_, zero));
2540  }
2541}
2542
2543const std::vector<PadReduceWindowEffectiveBroadcastCase>&
2544PadReduceWindowEffectiveBroadcastCases() {
2545  static auto* cases = new std::vector<PadReduceWindowEffectiveBroadcastCase>{
2546      {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6},
2547       /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true,
2548       /*should_become_broadcast=*/true},  //
2549      {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6},
2550       /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/false,
2551       /*should_become_broadcast=*/true},  //
2552      {/*input_spatials=*/{2, 2}, /*symmetric_pad_amount=*/{6, 6},
2553       /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true,
2554       /*should_become_broadcast=*/false},  //
2555      {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{2, 2},
2556       /*reduce_window_spatials=*/{5, 5}, /*prepend_a=*/true,
2557       /*should_become_broadcast=*/true},  //
2558      {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{2, 2},
2559       /*reduce_window_spatials=*/{1, 1}, /*prepend_a=*/true,
2560       /*should_become_broadcast=*/false},  //
2561      {/*input_spatials=*/{5, 1}, /*symmetric_pad_amount=*/{0, 2},
2562       /*reduce_window_spatials=*/{2, 5}, /*prepend_a=*/true,
2563       /*should_become_broadcast=*/false},  //
2564  };
2565  return *cases;
2566}
2567
2568INSTANTIATE_TEST_CASE_P(
2569    PadReduceWindowEffectiveBroadcastInstantiation,
2570    PadReduceWindowEffectiveBroadcastTest,
2571    ::testing::ValuesIn(PadReduceWindowEffectiveBroadcastCases()));
2572
2573class DotStrengthReductionTest
2574    : public AlgebraicSimplifierTest,
2575      public ::testing::WithParamInterface<
2576          ::testing::tuple<int, int, int, bool, bool>> {};
2577TEST_P(DotStrengthReductionTest, DotStrengthReduction) {
2578  int m, k, n;
2579  bool transpose_lhs, transpose_rhs;
2580  std::tie(m, k, n, transpose_lhs, transpose_rhs) = GetParam();
2581
2582  Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
2583  Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
2584  Shape transposed_lhs_shape = ShapeUtil::MakeShape(F32, {k, m});
2585  Shape rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
2586  Shape transposed_rhs_shape = ShapeUtil::MakeShape(F32, {n, k});
2587  HloComputation::Builder builder(TestName());
2588
2589  auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
2590      0, transpose_lhs ? transposed_lhs_shape : lhs_shape, "lhs"));
2591  if (transpose_lhs) {
2592    lhs = builder.AddInstruction(
2593        HloInstruction::CreateTranspose(lhs_shape, lhs, {1, 0}));
2594  }
2595  auto rhs = builder.AddInstruction(HloInstruction::CreateParameter(
2596      1, transpose_rhs ? transposed_rhs_shape : rhs_shape, "rhs"));
2597  if (transpose_rhs) {
2598    rhs = builder.AddInstruction(
2599        HloInstruction::CreateTranspose(rhs_shape, rhs, {1, 0}));
2600  }
2601  DotDimensionNumbers dot_dnums;
2602  dot_dnums.add_lhs_contracting_dimensions(1);
2603  dot_dnums.add_rhs_contracting_dimensions(0);
2604  builder.AddInstruction(
2605      HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
2606  auto computation = module().AddEntryComputation(builder.Build());
2607  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
2608                                 non_bitcasting_callback());
2609  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(&module()));
2610  const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1;
2611  const bool computation_should_be_modified =
2612      dot_should_be_transformed || (transpose_lhs && transpose_rhs);
2613  EXPECT_EQ(changed, computation_should_be_modified);
2614  bool has_no_dot = true;
2615  for (const auto& hlo : computation->instructions()) {
2616    if (hlo->opcode() == HloOpcode::kDot) {
2617      has_no_dot = false;
2618      break;
2619    }
2620  }
2621  EXPECT_EQ(has_no_dot, dot_should_be_transformed);
2622}
2623
2624INSTANTIATE_TEST_CASE_P(
2625    DotStrengthReductionTestInstantiation, DotStrengthReductionTest,
2626    ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2),
2627                       ::testing::Values(1, 2), ::testing::Bool(),
2628                       ::testing::Bool()));
2629
2630struct DotOfConcatTestSpec {
2631  int64 m;
2632  int64 k;
2633  int64 n;
2634};
2635
2636class DotOfConcatSimplificationTest
2637    : public HloVerifiedTestBase,
2638      public ::testing::WithParamInterface<DotOfConcatTestSpec> {};
2639
2640// Test that we transform
2641//  dot(const, concat(A, B, C))
2642// to
2643//  add(dot(const_0, A), dot(const_1, B),  dot(const_2, C))
2644TEST_P(DotOfConcatSimplificationTest, ConstantLHS) {
2645  HloComputation::Builder builder(TestName());
2646
2647  DotOfConcatTestSpec spec = GetParam();
2648
2649  ASSERT_GE(spec.k, 3);
2650
2651  int64 k0 = spec.k / 3;
2652  int64 k1 = spec.k / 3;
2653  int64 k2 = spec.k - k0 - k1;
2654
2655  Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k});
2656  auto* lhs = builder.AddInstruction(
2657      HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
2658          /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.m, /*cols=*/spec.k)));
2659
2660  Shape rhs0_shape = ShapeUtil::MakeShape(F32, {k0, spec.n});
2661  Shape rhs1_shape = ShapeUtil::MakeShape(F32, {k1, spec.n});
2662  Shape rhs2_shape = ShapeUtil::MakeShape(F32, {k2, spec.n});
2663
2664  HloInstruction* rhs0 = builder.AddInstruction(
2665      HloInstruction::CreateParameter(0, rhs0_shape, "rhs0"));
2666  HloInstruction* rhs1 = builder.AddInstruction(
2667      HloInstruction::CreateParameter(1, rhs1_shape, "rhs1"));
2668  HloInstruction* rhs2 = builder.AddInstruction(
2669      HloInstruction::CreateParameter(2, rhs2_shape, "rhs2"));
2670
2671  Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n});
2672  HloInstruction* rhs = builder.AddInstruction(
2673      HloInstruction::CreateConcatenate(rhs_shape, {rhs0, rhs1, rhs2}, 0));
2674
2675  DotDimensionNumbers dot_dnums;
2676  dot_dnums.add_lhs_contracting_dimensions(1);
2677  dot_dnums.add_rhs_contracting_dimensions(0);
2678
2679  Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
2680  builder.AddInstruction(
2681      HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
2682
2683  auto computation = module().AddEntryComputation(builder.Build());
2684  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
2685                                 non_bitcasting_callback());
2686  TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module()));
2687  ASSERT_TRUE(run_successful);
2688
2689  EXPECT_TRUE(
2690      ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
2691
2692  auto match_dot_0 = op::Dot(op::Slice(op::Constant()), op::Parameter(0));
2693  auto match_dot_1 = op::Dot(op::Slice(op::Constant()), op::Parameter(1));
2694  auto match_dot_2 = op::Dot(op::Slice(op::Constant()), op::Parameter(2));
2695  EXPECT_THAT(computation->root_instruction(),
2696              op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2));
2697}
2698
2699// Test that we transform
2700//  dot(concat(A, B, C), const)
2701// to
2702//  add(dot(A, const_0), dot(B, const_1),  dot(C, const_2))
2703TEST_P(DotOfConcatSimplificationTest, ConstantRHS) {
2704  HloComputation::Builder builder(TestName());
2705
2706  DotOfConcatTestSpec spec = GetParam();
2707
2708  ASSERT_GE(spec.k, 4);
2709
2710  int64 k0 = spec.k / 4;
2711  int64 k1 = spec.k / 4;
2712  int64 k2 = spec.k / 4;
2713  int64 k3 = spec.k - k0 - k1 - k2;
2714
2715  Shape lhs0_shape = ShapeUtil::MakeShape(F32, {spec.m, k0});
2716  Shape lhs1_shape = ShapeUtil::MakeShape(F32, {spec.m, k1});
2717  Shape lhs2_shape = ShapeUtil::MakeShape(F32, {spec.m, k2});
2718  Shape lhs3_shape = ShapeUtil::MakeShape(F32, {spec.m, k3});
2719
2720  HloInstruction* lhs0 = builder.AddInstruction(
2721      HloInstruction::CreateParameter(0, lhs0_shape, "lhs0"));
2722  HloInstruction* lhs1 = builder.AddInstruction(
2723      HloInstruction::CreateParameter(1, lhs1_shape, "lhs1"));
2724  HloInstruction* lhs2 = builder.AddInstruction(
2725      HloInstruction::CreateParameter(2, lhs2_shape, "lhs2"));
2726  HloInstruction* lhs3 = builder.AddInstruction(
2727      HloInstruction::CreateParameter(3, lhs3_shape, "lhs3"));
2728
2729  Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k});
2730  HloInstruction* lhs =
2731      builder.AddInstruction(HloInstruction::CreateConcatenate(
2732          lhs_shape, {lhs0, lhs1, lhs2, lhs3}, 1));
2733
2734  Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n});
2735  auto* rhs = builder.AddInstruction(
2736      HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
2737          /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.n)));
2738
2739  DotDimensionNumbers dot_dnums;
2740  dot_dnums.add_lhs_contracting_dimensions(1);
2741  dot_dnums.add_rhs_contracting_dimensions(0);
2742
2743  Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
2744  builder.AddInstruction(
2745      HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
2746
2747  auto computation = module().AddEntryComputation(builder.Build());
2748  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
2749                                 non_bitcasting_callback());
2750  TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module()));
2751  ASSERT_TRUE(run_successful);
2752  EXPECT_TRUE(
2753      ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape));
2754
2755  auto match_dot_0 = op::Dot(op::Parameter(0), op::Slice(op::Constant()));
2756  auto match_dot_1 = op::Dot(op::Parameter(1), op::Slice(op::Constant()));
2757  auto match_dot_2 = op::Dot(op::Parameter(2), op::Slice(op::Constant()));
2758  auto match_dot_3 = op::Dot(op::Parameter(3), op::Slice(op::Constant()));
2759  EXPECT_THAT(computation->root_instruction(),
2760              op::Add(op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2),
2761                      match_dot_3));
2762}
2763
2764DotOfConcatTestSpec kDotOfConcatTestSpecs[] = {
2765    {/*m=*/3, /*k=*/9, /*n=*/3},    //
2766    {/*m=*/3, /*k=*/20, /*n=*/3},   //
2767    {/*m=*/1, /*k=*/18, /*n=*/5},   //
2768    {/*m=*/20, /*k=*/20, /*n=*/1},  //
2769    {/*m=*/1, /*k=*/16, /*n=*/1},   //
2770};
2771
2772INSTANTIATE_TEST_CASE_P(DotOfConcatSimplificationTestInstantiation,
2773                        DotOfConcatSimplificationTest,
2774                        ::testing::ValuesIn(kDotOfConcatTestSpecs));
2775}  // namespace
2776}  // namespace xla
2777