1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <cmath>
17#include <limits>
18#include <memory>
19#include <numeric>
20#include <vector>
21
22#include "tensorflow/compiler/xla/array2d.h"
23#include "tensorflow/compiler/xla/array3d.h"
24#include "tensorflow/compiler/xla/array4d.h"
25#include "tensorflow/compiler/xla/client/computation_builder.h"
26#include "tensorflow/compiler/xla/client/global_data.h"
27#include "tensorflow/compiler/xla/client/local_client.h"
28#include "tensorflow/compiler/xla/layout_util.h"
29#include "tensorflow/compiler/xla/literal_util.h"
30#include "tensorflow/compiler/xla/statusor.h"
31#include "tensorflow/compiler/xla/test.h"
32#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
33#include "tensorflow/compiler/xla/tests/literal_test_util.h"
34#include "tensorflow/compiler/xla/tests/test_macros.h"
35#include "tensorflow/compiler/xla/types.h"
36#include "tensorflow/compiler/xla/xla_data.pb.h"
37#include "tensorflow/core/lib/core/casts.h"
38#include "tensorflow/core/platform/types.h"
39
40namespace xla {
41namespace {
42
43class ArrayElementwiseOpTest : public ClientLibraryTestBase {
44 public:
45  ErrorSpec error_spec_{0.0001, 0.0001};
46};
47
48class ArrayElementwiseOpTestParamCount
49    : public ArrayElementwiseOpTest,
50      public ::testing::WithParamInterface<int> {};
51
52XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) {
53  ComputationBuilder builder(client_, TestName());
54  auto a = builder.ConstantR1<float>({});
55  auto result = builder.Neg(a);
56
57  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
58}
59
60XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) {
61  ComputationBuilder builder(client_, TestName());
62  auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
63  auto result = builder.Neg(a);
64
65  ComputeAndCompareR1<float>(&builder, {2.5f, -3.14f, -2.25f, 10.0f, -6.0f}, {},
66                             error_spec_);
67}
68
69XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) {
70  ComputationBuilder builder(client_, TestName());
71  auto a = builder.ConstantR1<int32>({-1, 0, 1, 324,
72                                      std::numeric_limits<int32>::min(),
73                                      std::numeric_limits<int32>::max()});
74  auto result = builder.Neg(a);
75
76  // -min == min for int32 due to an overflow. In C++ it is undefined behavior
77  // to do this calculation. For XLA we have not specified that, so it
78  // ought to work.
79  ComputeAndCompareR1<int32>(&builder,
80                             {1, 0, -1, -324, std::numeric_limits<int32>::min(),
81                              -std::numeric_limits<int32>::max()},
82                             {});
83}
84
85XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) {
86  ComputationBuilder builder(client_, TestName());
87  auto a = builder.ConstantR1<complex64>({});
88  auto result = builder.Neg(a);
89
90  ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
91}
92
93XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) {
94  ComputationBuilder builder(client_, TestName());
95  auto a = builder.ConstantR1<complex64>(
96      {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}});
97  auto result = builder.Neg(a);
98
99  ComputeAndCompareR1<complex64>(
100      &builder, {{2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}},
101      {}, error_spec_);
102}
103
104XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) {
105  ComputationBuilder builder(client_, TestName());
106  auto a = builder.ConstantR1<float>({});
107  auto result = builder.IsFinite(a);
108
109  ComputeAndCompareR1<bool>(&builder, {}, {});
110}
111
112// A non-canonical quiet NaN value.
113static const float kNonCanonicalNaN = tensorflow::bit_cast<float>(0x7FD01234);
114
115XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) {
116  ComputationBuilder builder(client_, TestName());
117  auto result = builder.IsFinite(builder.ConstantR0<float>(NAN));
118  ComputeAndCompareR0<bool>(&builder, false, {});
119
120  EXPECT_TRUE(std::isnan(kNonCanonicalNaN));
121  auto result_non_canonical =
122      builder.IsFinite(builder.ConstantR0<float>(kNonCanonicalNaN));
123  ComputeAndCompareR0<bool>(&builder, false, {});
124
125  const float inf = std::numeric_limits<float>::infinity();
126  auto result_inf = builder.IsFinite(builder.ConstantR0<float>(inf));
127  ComputeAndCompareR0<bool>(&builder, false, {});
128
129  auto result_neg_inf = builder.IsFinite(builder.ConstantR0<float>(-inf));
130  ComputeAndCompareR0<bool>(&builder, false, {});
131
132  auto result_zero = builder.IsFinite(builder.ConstantR0<float>(0.0f));
133  ComputeAndCompareR0<bool>(&builder, true, {});
134}
135
136XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) {
137  ComputationBuilder builder(client_, TestName());
138  const float inf = std::numeric_limits<float>::infinity();
139  EXPECT_TRUE(std::isnan(kNonCanonicalNaN));
140  auto a = builder.ConstantR1<float>(
141      {{NAN, 7.0f, kNonCanonicalNaN, -1.0f, inf, -inf}});
142  auto result = builder.IsFinite(a);
143
144  ComputeAndCompareR1<bool>(&builder, {false, true, false, true, false, false},
145                            {});
146}
147
148XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) {
149  ComputationBuilder builder(client_, TestName());
150  auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
151  auto b = builder.ConstantR1<float>({100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
152  auto add = builder.Add(a, b);
153
154  ComputeAndCompareR1<float>(&builder, {97.5f, 6.27f, 5.0f, 0.5f, -993.0f}, {},
155                             error_spec_);
156}
157
158XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) {
159  ComputationBuilder builder(client_, TestName());
160  auto a = builder.ConstantR1<float>({});
161  auto b = builder.ConstantR1<float>({});
162  auto add = builder.Add(a, b);
163
164  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
165}
166
167XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) {
168  ComputationBuilder builder(client_, TestName());
169  auto a = builder.ConstantR1<complex64>(
170      {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}});
171  auto b = builder.ConstantR1<complex64>(
172      {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}});
173  auto add = builder.Add(a, b);
174
175  ComputeAndCompareR1<complex64>(
176      &builder, {97.5f, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}}, {},
177      error_spec_);
178}
179
180XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) {
181  ComputationBuilder builder(client_, TestName());
182  auto a = builder.ConstantR1<complex64>({});
183  auto b = builder.ConstantR1<complex64>({});
184  auto add = builder.Add(a, b);
185
186  ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
187}
188
189TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
190  const int count = GetParam();
191  ComputationBuilder builder(client_, TestName());
192  std::vector<float> a_values;
193  std::vector<float> b_values;
194  for (int i = 0; i < count; ++i) {
195    a_values.push_back(i / static_cast<float>(count));
196    b_values.push_back(2 * i / static_cast<float>(count + 2));
197  }
198
199  std::unique_ptr<Literal> a_literal = Literal::CreateR1<float>({a_values});
200  std::unique_ptr<GlobalData> a_data =
201      client_->TransferToServer(*a_literal).ConsumeValueOrDie();
202  auto a_constant = builder.ConstantR1<float>(a_values);
203  auto a_param = builder.Parameter(0, a_literal->shape(), "a_param");
204
205  std::unique_ptr<Literal> b_literal = Literal::CreateR1<float>({b_values});
206  std::unique_ptr<GlobalData> b_data =
207      client_->TransferToServer(*b_literal).ConsumeValueOrDie();
208  auto b_constant = builder.Parameter(1, a_literal->shape(), "b_param");
209  auto b_param = builder.ConstantR1<float>(b_values);
210
211  auto sum1 = builder.Add(a_constant, b_constant);
212  auto sum2 = builder.Add(a_constant, b_param);
213  auto sum3 = builder.Add(a_param, b_constant);
214  auto sum4 = builder.Add(a_param, b_param);
215
216  auto sum = builder.Add(sum1, sum2);
217  sum = builder.Add(sum, sum3);
218  sum = builder.Add(sum, sum4);
219
220  std::vector<float> expected;
221  for (int64 i = 0; i < count; ++i) {
222    expected.push_back(4 * (a_values[i] + b_values[i]));
223  }
224
225  ComputeAndCompareR1<float>(&builder, expected, {a_data.get(), b_data.get()},
226                             error_spec_);
227}
228
229XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) {
230  ComputationBuilder builder(client_, TestName());
231  auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
232  auto b = builder.ConstantR1<float>({100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
233  auto add = builder.Sub(a, b);
234
235  ComputeAndCompareR1<float>(&builder, {-102.5f, 0.01f, -0.5f, -20.5f, 1005.0f},
236                             {}, error_spec_);
237}
238
239XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) {
240  ComputationBuilder builder(client_, TestName());
241  auto a = builder.ConstantR1<float>({});
242  auto b = builder.ConstantR1<float>({});
243  auto add = builder.Sub(a, b);
244
245  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
246}
247
248XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) {
249  ComputationBuilder builder(client_, TestName());
250  auto a = builder.ConstantR1<int32>({-1, 0, 2, 1000000000});
251  auto b = builder.ConstantR1<int32>({-1, 2, 1, -1});
252  auto add = builder.Sub(a, b);
253
254  ComputeAndCompareR1<int32>(&builder, {0, -2, 1, 1000000001}, {});
255}
256
257XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) {
258  ComputationBuilder builder(client_, TestName());
259  auto a = builder.ConstantR1<int32>({});
260  auto b = builder.ConstantR1<int32>({});
261  auto add = builder.Sub(a, b);
262
263  ComputeAndCompareR1<int32>(&builder, {}, {});
264}
265
266XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) {
267  ComputationBuilder builder(client_, TestName());
268  auto a = builder.ConstantR1<complex64>(
269      {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}});
270  auto b = builder.ConstantR1<complex64>(
271      {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}});
272  auto add = builder.Sub(a, b);
273
274  ComputeAndCompareR1<complex64>(
275      &builder, {{-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}}, {},
276      error_spec_);
277}
278
279XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) {
280  ComputationBuilder builder(client_, TestName());
281  auto a = builder.ConstantR1<complex64>({});
282  auto b = builder.ConstantR1<complex64>({});
283  auto add = builder.Sub(a, b);
284
285  ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
286}
287
288XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) {
289  ComputationBuilder builder(client_, TestName());
290  auto a = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
291  auto b = builder.ConstantR1<float>({10.0f, 5.1f, 1.0f, 10.0f, -6.0f});
292  auto add = builder.Div(a, b);
293
294  ComputeAndCompareR1<float>(&builder, {-0.25f, 5.0f, 2.25f, -1.0f, -1.0f}, {},
295                             error_spec_);
296}
297
298XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) {
299  ComputationBuilder builder(client_, TestName());
300  auto a = builder.ConstantR1<float>({});
301  auto b = builder.ConstantR1<float>({});
302  auto add = builder.Div(a, b);
303
304  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
305}
306
307XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) {
308  // clang-format off
309  // Some interesting values to test.
310  std::vector<int32> vals = {
311    INT32_MIN, INT32_MIN + 1, INT32_MIN + 2, -0x40000000, -0x3fffffff,
312    -271181, -1309, -17, -10, -5, -3, -2, -1, 0, 1, 2, 3, 5, 10, 17, 26, 101,
313    7919, 0x40000000, INT32_MAX - 2, INT32_MAX - 1, INT32_MAX};
314  // clang-format on
315
316  std::vector<int32> dividends, divisors, quotients, remainders;
317  for (int32 divisor : vals) {
318    if (divisor != 0) {
319      for (int32 dividend : vals) {
320        // Avoid integer overflow.
321        if (dividend != INT32_MIN || divisor != -1) {
322          dividends.push_back(dividend);
323          divisors.push_back(divisor);
324          quotients.push_back(dividend / divisor);
325          remainders.push_back(dividend % divisor);
326        }
327      }
328    }
329  }
330
331  {
332    ComputationBuilder builder(client_, TestName());
333    ComputationDataHandle dividend;
334    ComputationDataHandle divisor;
335    auto dividend_data =
336        CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, &dividend);
337    auto divisor_data =
338        CreateR1Parameter<int32>(divisors, 1, "divisor", &builder, &divisor);
339    builder.Div(dividend, divisor);
340
341    ComputeAndCompareR1<int32>(&builder, quotients,
342                               {dividend_data.get(), divisor_data.get()});
343  }
344
345  // Test with a compile-time constant divisor.
346  {
347    ComputationBuilder builder(client_, TestName());
348    ComputationDataHandle dividend;
349    auto dividend_data =
350        CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, &dividend);
351    builder.Div(dividend, builder.ConstantR1<int32>(divisors));
352
353    ComputeAndCompareR1<int32>(&builder, quotients, {dividend_data.get()});
354  }
355
356  {
357    ComputationBuilder builder(client_, TestName());
358    ComputationDataHandle dividend;
359    ComputationDataHandle divisor;
360    auto dividend_data =
361        CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, &dividend);
362    auto divisor_data =
363        CreateR1Parameter<int32>(divisors, 1, "divisor", &builder, &divisor);
364    builder.Rem(dividend, divisor);
365
366    ComputeAndCompareR1<int32>(&builder, remainders,
367                               {dividend_data.get(), divisor_data.get()});
368  }
369
370  // Test with a compile-time constant divisor.
371  {
372    ComputationBuilder builder(client_, TestName());
373    ComputationDataHandle dividend;
374    auto dividend_data =
375        CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, &dividend);
376    builder.Rem(dividend, builder.ConstantR1<int32>(divisors));
377
378    ComputeAndCompareR1<int32>(&builder, remainders, {dividend_data.get()});
379  }
380}
381
382XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) {
383  // clang-format off
384  // Some interesting values to test.
385  std::vector<uint32> vals = {
386    0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0xABCDEF12, 0xCAFEBEEF, 0x80000000,
387    0x80000001, UINT32_MAX - 2, UINT32_MAX - 1, UINT32_MAX};
388  // clang-format on
389
390  std::vector<uint32> dividends, divisors, quotients, remainders;
391  for (uint32 divisor : vals) {
392    if (divisor != 0) {
393      for (uint32 dividend : vals) {
394        dividends.push_back(dividend);
395        divisors.push_back(divisor);
396        quotients.push_back(dividend / divisor);
397        remainders.push_back(dividend % divisor);
398      }
399    }
400  }
401
402  {
403    ComputationBuilder builder(client_, TestName());
404    ComputationDataHandle dividend;
405    ComputationDataHandle divisor;
406    auto dividend_data = CreateR1Parameter<uint32>(dividends, 0, "dividend",
407                                                   &builder, &dividend);
408    auto divisor_data =
409        CreateR1Parameter<uint32>(divisors, 1, "divisor", &builder, &divisor);
410    builder.Div(dividend, divisor);
411
412    ComputeAndCompareR1<uint32>(&builder, quotients,
413                                {dividend_data.get(), divisor_data.get()});
414  }
415
416  {
417    ComputationBuilder builder(client_, TestName());
418    ComputationDataHandle dividend;
419    auto dividend_data = CreateR1Parameter<uint32>(dividends, 0, "dividend",
420                                                   &builder, &dividend);
421    builder.Div(dividend, builder.ConstantR1<uint32>(divisors));
422
423    ComputeAndCompareR1<uint32>(&builder, quotients, {dividend_data.get()});
424  }
425
426  {
427    ComputationBuilder builder(client_, TestName());
428    ComputationDataHandle dividend;
429    ComputationDataHandle divisor;
430    auto dividend_data = CreateR1Parameter<uint32>(dividends, 0, "dividend",
431                                                   &builder, &dividend);
432    auto divisor_data =
433        CreateR1Parameter<uint32>(divisors, 1, "divisor", &builder, &divisor);
434    builder.Rem(dividend, divisor);
435
436    ComputeAndCompareR1<uint32>(&builder, remainders,
437                                {dividend_data.get(), divisor_data.get()});
438  }
439
440  {
441    ComputationBuilder builder(client_, TestName());
442    ComputationDataHandle dividend;
443    auto dividend_data = CreateR1Parameter<uint32>(dividends, 0, "dividend",
444                                                   &builder, &dividend);
445    builder.Rem(dividend, builder.ConstantR1<uint32>(divisors));
446
447    ComputeAndCompareR1<uint32>(&builder, remainders, {dividend_data.get()});
448  }
449}
450
451XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) {
452  ComputationBuilder builder(client_, TestName());
453  auto a = builder.ConstantR1<complex64>(
454      {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}});
455  auto b = builder.ConstantR1<complex64>(
456      {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}});
457  auto div = builder.Div(a, b);
458
459  ComputeAndCompareR1<complex64>(
460      &builder, {{-0.25f, 0.1f}, {0.0f, 25.5f}, {1.0f, 0.0f}}, {}, error_spec_);
461}
462
463XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) {
464  ComputationBuilder builder(client_, TestName());
465  auto a = builder.ConstantR1<complex64>({});
466  auto b = builder.ConstantR1<complex64>({});
467  auto div = builder.Div(a, b);
468
469  ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
470}
471
472XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) {
473  ComputationBuilder builder(client_, TestName());
474  auto a = builder.ConstantR1<float>(
475      {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f});
476  auto b = builder.ConstantR1<float>(
477      {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f});
478  auto add = builder.Rem(a, b);
479
480  ComputeAndCompareR1<float>(
481      &builder, {-2.5f, 0.0f, 0.25f, 0.0f, -0.0f, 1.0f, 1.0f, -1.0f, -0.0f}, {},
482      error_spec_);
483}
484
485XLA_TEST_F(ArrayElementwiseOpTest, RemZeroElementF32s) {
486  ComputationBuilder builder(client_, TestName());
487  auto a = builder.ConstantR1<float>({});
488  auto b = builder.ConstantR1<float>({});
489  auto add = builder.Rem(a, b);
490
491  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
492}
493
494XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) {
495  ComputationBuilder builder(client_, TestName());
496  auto a = builder.ConstantR1<double>(
497      {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0});
498  auto b = builder.ConstantR1<double>(
499      {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0});
500  auto add = builder.Rem(a, b);
501
502  ComputeAndCompareR1<double>(
503      &builder, {-2.5, 0.0, 0.25, 0.0, -0.0, 1.0, 1.0, -1.0, -0.0}, {},
504      error_spec_);
505}
506
507XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) {
508  ComputationBuilder builder(client_, TestName());
509  auto a = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
510  auto b = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
511  auto add = builder.Mul(a, b);
512
513  ComputeAndCompareR1<float>(&builder, {-25.0f, 127.5f, 2.25f, -100.0f, -36.0f},
514                             {}, error_spec_);
515}
516
517XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) {
518  ComputationBuilder builder(client_, TestName());
519  auto a = builder.ConstantR1<float>({});
520  auto b = builder.ConstantR1<float>({});
521  auto add = builder.Mul(a, b);
522
523  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
524}
525
526XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) {
527  std::vector<int32> data = {0,
528                             1,
529                             -1,
530                             1234,
531                             0x1a243514,
532                             std::numeric_limits<int32>::max(),
533                             std::numeric_limits<int32>::min()};
534  // Form the test data set using all products of 'data' with itself.
535  std::vector<int32> a_data, b_data, expected;
536  for (int32 a : data) {
537    for (int32 b : data) {
538      a_data.push_back(a);
539      b_data.push_back(b);
540      expected.push_back(static_cast<uint32>(a) * static_cast<uint32>(b));
541    }
542  }
543
544  ComputationBuilder builder(client_, TestName());
545  auto a = builder.ConstantR1<int32>(a_data);
546  auto b = builder.ConstantR1<int32>(b_data);
547  auto add = builder.Mul(a, b);
548
549  ComputeAndCompareR1<int32>(&builder, expected, {});
550}
551
552XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) {
553  ComputationBuilder builder(client_, TestName());
554  auto a = builder.ConstantR1<int32>({});
555  auto b = builder.ConstantR1<int32>({});
556  auto add = builder.Mul(a, b);
557
558  ComputeAndCompareR1<int32>(&builder, {}, {});
559}
560
561XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) {
562  std::vector<uint32> data = {0,          1,          0xDEADBEEF, 1234,
563                              0x1a243514, 0xFFFFFFFF, 0x80808080};
564
565  // Form the test data set using all products of 'data' with itself.
566  std::vector<uint32> a_data, b_data, expected;
567  for (uint32 a : data) {
568    for (uint32 b : data) {
569      a_data.push_back(a);
570      b_data.push_back(b);
571      expected.push_back(a * b);
572    }
573  }
574
575  ComputationBuilder builder(client_, TestName());
576  auto a = builder.ConstantR1<uint32>(a_data);
577  auto b = builder.ConstantR1<uint32>(b_data);
578  auto add = builder.Mul(a, b);
579
580  ComputeAndCompareR1<uint32>(&builder, expected, {});
581}
582
583XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) {
584  ComputationBuilder builder(client_, TestName());
585  auto a = builder.ConstantR1<complex64>(
586      {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}});
587  auto b = builder.ConstantR1<complex64>(
588      {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}});
589  auto add = builder.Mul(a, b);
590
591  ComputeAndCompareR1<complex64>(
592      &builder, {{0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0}}, {},
593      error_spec_);
594}
595
596XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) {
597  ComputationBuilder builder(client_, TestName());
598  auto a = builder.ConstantR1<complex64>({});
599  auto b = builder.ConstantR1<complex64>({});
600  auto add = builder.Mul(a, b);
601
602  ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
603}
604
605XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) {
606  ComputationBuilder builder(client_, TestName());
607  auto a = builder.ConstantR1<bool>({false, false, true, true});
608  auto b = builder.ConstantR1<bool>({false, true, false, true});
609  auto out = builder.And(a, b);
610
611  ComputeAndCompareR1<bool>(&builder, {false, false, false, true}, {});
612}
613
614XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) {
615  ComputationBuilder builder(client_, TestName());
616  auto a = builder.ConstantR2<bool>({{false, false}, {true, true}});
617  auto b = builder.ConstantR2<bool>({{false, true}, {false, true}});
618  auto out = builder.And(a, b);
619
620  Array2D<bool> expected_array({{false, false}, {false, true}});
621  ComputeAndCompareR2<bool>(&builder, expected_array, {});
622}
623
624XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) {
625  ComputationBuilder builder(client_, TestName());
626  auto a = builder.ConstantR1<bool>({});
627  auto b = builder.ConstantR1<bool>({});
628  auto out = builder.And(a, b);
629
630  ComputeAndCompareR1<bool>(&builder, {}, {});
631}
632
633XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) {
634  ComputationBuilder builder(client_, TestName());
635  auto a = builder.ConstantR1<int32>({0, -1, -8});
636  auto b = builder.ConstantR1<int32>({5, -7, 12});
637  auto out = builder.And(a, b);
638
639  ComputeAndCompareR1<int32>(&builder, {0, -7, 8}, {});
640}
641
642XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) {
643  ComputationBuilder builder(client_, TestName());
644  auto a = builder.ConstantR2<int32>({{0, -5}, {-1, 5}});
645  auto b = builder.ConstantR2<int32>({{1, -6}, {4, 5}});
646  auto out = builder.And(a, b);
647
648  Array2D<int32> expected_array({{0, -6}, {4, 5}});
649  ComputeAndCompareR2<int32>(&builder, expected_array, {});
650}
651
652XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) {
653  ComputationBuilder builder(client_, TestName());
654  auto a = builder.ConstantR1<int32>({});
655  auto b = builder.ConstantR1<int32>({});
656  auto out = builder.And(a, b);
657
658  ComputeAndCompareR1<int32>(&builder, {}, {});
659}
660
661XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) {
662  ComputationBuilder builder(client_, TestName());
663  auto a = builder.ConstantR1<int32>({0, 1, 8});
664  auto b = builder.ConstantR1<int32>({5, 7, 12});
665  auto out = builder.And(a, b);
666
667  ComputeAndCompareR1<int32>(&builder, {0, 1, 8}, {});
668}
669
670XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) {
671  ComputationBuilder builder(client_, TestName());
672  auto a = builder.ConstantR2<uint32>({{0, 1}, {3, 8}});
673  auto b = builder.ConstantR2<uint32>({{1, 0}, {7, 6}});
674  auto out = builder.And(a, b);
675
676  Array2D<uint32> expected_array({{0, 0}, {3, 0}});
677  ComputeAndCompareR2<uint32>(&builder, expected_array, {});
678}
679
680XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) {
681  ComputationBuilder builder(client_, TestName());
682  auto a = builder.ConstantR1<uint32>({});
683  auto b = builder.ConstantR1<uint32>({});
684  auto out = builder.And(a, b);
685
686  ComputeAndCompareR1<uint32>(&builder, {}, {});
687}
688
689XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) {
690  ComputationBuilder builder(client_, TestName());
691  auto a = builder.ConstantR1<bool>({false, false, true, true});
692  auto b = builder.ConstantR1<bool>({false, true, false, true});
693  auto out = builder.Or(a, b);
694
695  ComputeAndCompareR1<bool>(&builder, {false, true, true, true}, {});
696}
697
698XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) {
699  ComputationBuilder builder(client_, TestName());
700  auto a = builder.ConstantR2<bool>({{false, false}, {true, true}});
701  auto b = builder.ConstantR2<bool>({{false, true}, {false, true}});
702  auto out = builder.Or(a, b);
703
704  Array2D<bool> expected_array({{false, true}, {true, true}});
705  ComputeAndCompareR2<bool>(&builder, expected_array, {});
706}
707
708XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) {
709  ComputationBuilder builder(client_, TestName());
710  auto a = builder.ConstantR1<bool>({});
711  auto b = builder.ConstantR1<bool>({});
712  auto out = builder.Or(a, b);
713
714  ComputeAndCompareR1<bool>(&builder, {}, {});
715}
716
717XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) {
718  ComputationBuilder builder(client_, TestName());
719  auto a = builder.ConstantR1<int32>({0, -1, 8});
720  auto b = builder.ConstantR1<int32>({5, -7, 4});
721  auto out = builder.Or(a, b);
722
723  ComputeAndCompareR1<int32>(&builder, {5, -1, 12}, {});
724}
725
726XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) {
727  ComputationBuilder builder(client_, TestName());
728  auto a = builder.ConstantR2<int32>({{0, -1}, {8, 8}});
729  auto b = builder.ConstantR2<int32>({{5, -7}, {4, 1}});
730  auto out = builder.Or(a, b);
731
732  Array2D<int32> expected_array({{5, -1}, {12, 9}});
733  ComputeAndCompareR2<int32>(&builder, expected_array, {});
734}
735
736XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) {
737  ComputationBuilder builder(client_, TestName());
738  auto a = builder.ConstantR1<int32>({});
739  auto b = builder.ConstantR1<int32>({});
740  auto out = builder.Or(a, b);
741
742  ComputeAndCompareR1<int32>(&builder, {}, {});
743}
744
745XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) {
746  ComputationBuilder builder(client_, TestName());
747  auto a = builder.ConstantR1<uint32>({0, 1, 8});
748  auto b = builder.ConstantR1<uint32>({5, 7, 4});
749  auto out = builder.Or(a, b);
750
751  ComputeAndCompareR1<uint32>(&builder, {5, 7, 12}, {});
752}
753
754XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) {
755  ComputationBuilder builder(client_, TestName());
756  auto a = builder.ConstantR2<uint32>({{0, 1}, {8, 8}});
757  auto b = builder.ConstantR2<uint32>({{5, 7}, {4, 1}});
758  auto out = builder.Or(a, b);
759
760  Array2D<uint32> expected_array({{5, 7}, {12, 9}});
761  ComputeAndCompareR2<uint32>(&builder, expected_array, {});
762}
763
764XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) {
765  ComputationBuilder builder(client_, TestName());
766  auto a = builder.ConstantR1<uint32>({});
767  auto b = builder.ConstantR1<uint32>({});
768  auto out = builder.Or(a, b);
769
770  ComputeAndCompareR1<uint32>(&builder, {}, {});
771}
772
773XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) {
774  ComputationBuilder builder(client_, TestName());
775  auto a = builder.ConstantR1<bool>({false, true, true, false});
776  auto out = builder.Not(a);
777
778  ComputeAndCompareR1<bool>(&builder, {true, false, false, true}, {});
779}
780
781XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) {
782  ComputationBuilder builder(client_, TestName());
783  auto a = builder.ConstantR2<bool>({{false, true}, {true, false}});
784  auto out = builder.Not(a);
785
786  Array2D<bool> expected_array({{true, false}, {false, true}});
787  ComputeAndCompareR2<bool>(&builder, expected_array, {});
788}
789
790XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) {
791  ComputationBuilder builder(client_, TestName());
792  auto a = builder.ConstantR1<bool>({});
793  auto out = builder.Not(a);
794
795  ComputeAndCompareR1<bool>(&builder, {}, {});
796}
797
798XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) {
799  ComputationBuilder builder(client_, TestName());
800  auto a = builder.ConstantR1<int32>({-1, 0, 1});
801  auto out = builder.Not(a);
802
803  ComputeAndCompareR1<int32>(&builder, {0, -1, -2}, {});
804}
805
806XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) {
807  ComputationBuilder builder(client_, TestName());
808  auto a = builder.ConstantR2<int32>({{-1, 0}, {1, 8}});
809  auto out = builder.Not(a);
810
811  Array2D<int32> expected_array({{0, -1}, {-2, -9}});
812  ComputeAndCompareR2<int32>(&builder, expected_array, {});
813}
814
815XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) {
816  ComputationBuilder builder(client_, TestName());
817  auto a = builder.ConstantR1<int32>({});
818  auto out = builder.Not(a);
819
820  ComputeAndCompareR1<int32>(&builder, {}, {});
821}
822
823XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) {
824  ComputationBuilder builder(client_, TestName());
825  auto a = builder.ConstantR1<uint32>({0, 4294967295});
826  auto out = builder.Not(a);
827
828  ComputeAndCompareR1<uint32>(&builder, {4294967295, 0}, {});
829}
830
831XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) {
832  ComputationBuilder builder(client_, TestName());
833  auto a = builder.ConstantR2<uint32>({{0, 4294967295}, {1, 4294967294}});
834  auto out = builder.Not(a);
835
836  Array2D<uint32> expected_array({{4294967295, 0}, {4294967294, 1}});
837  ComputeAndCompareR2<uint32>(&builder, expected_array, {});
838}
839
840XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) {
841  ComputationBuilder builder(client_, TestName());
842  auto a = builder.ConstantR1<uint32>({});
843  auto out = builder.Not(a);
844
845  ComputeAndCompareR1<uint32>(&builder, {}, {});
846}
847
848XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) {
849  ComputationBuilder builder(client_, TestName());
850  auto a =
851      builder.ConstantR1<int32>({static_cast<int32>(0x12345678),
852                                 static_cast<int32>(0xF0001000), 1, 3, 77});
853  auto b = builder.ConstantR1<int32>({4, 8, 2, 7, 15});
854  auto out = builder.ShiftLeft(a, b);
855
856  ComputeAndCompareR1<int32>(
857      &builder,
858      {static_cast<int32>(0x23456780), 0x00100000, 0x4, 0x180, 2523136}, {});
859}
860
861XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) {
862  ComputationBuilder builder(client_, TestName());
863  auto a =
864      builder.ConstantR1<int32>({static_cast<int32>(0x92345678),
865                                 static_cast<int32>(0x10001000), 1, 3, 77});
866  auto b = builder.ConstantR1<int32>({4, 8, 2, 7, 2});
867  auto out = builder.ShiftRightArithmetic(a, b);
868
869  ComputeAndCompareR1<int32>(&builder,
870                             {static_cast<int32>(0xF9234567),
871                              static_cast<int32>(0x00100010), 0, 0, 19},
872                             {});
873}
874
875XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) {
876  ComputationBuilder builder(client_, TestName());
877  auto a =
878      builder.ConstantR1<int32>({static_cast<int32>(0x92345678),
879                                 static_cast<int32>(0x10001000), 1, 3, 77});
880  auto b = builder.ConstantR1<int32>({4, 8, 2, 7, 5});
881  auto out = builder.ShiftRightLogical(a, b);
882
883  ComputeAndCompareR1<int32>(&builder, {0x09234567, 0x00100010, 0, 0, 2}, {});
884}
885
886XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) {
887  ComputationBuilder builder(client_, TestName());
888  auto a = builder.ConstantR1<uint32>({0x12345678, 0xF0001000, 1, 3, 77});
889  auto b = builder.ConstantR1<uint32>({4, 8, 2, 7, 15});
890  auto out = builder.ShiftLeft(a, b);
891
892  ComputeAndCompareR1<uint32>(
893      &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136}, {});
894}
895
896XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) {
897  ComputationBuilder builder(client_, TestName());
898  auto a = builder.ConstantR1<uint32>({0x92345678, 0x10001000, 1, 3, 77});
899  auto b = builder.ConstantR1<uint32>({4, 8, 2, 7, 2});
900  auto out = builder.ShiftRightArithmetic(a, b);
901
902  ComputeAndCompareR1<uint32>(&builder, {0xF9234567, 0x00100010, 0, 0, 19}, {});
903}
904
905XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) {
906  ComputationBuilder builder(client_, TestName());
907  auto a = builder.ConstantR1<uint32>({0x92345678, 0x10001000, 1, 3, 77});
908  auto b = builder.ConstantR1<uint32>({4, 8, 2, 7, 5});
909  auto out = builder.ShiftRightLogical(a, b);
910
911  ComputeAndCompareR1<uint32>(&builder, {0x09234567, 0x00100010, 0, 0, 2}, {});
912}
913
914XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
915  SetFastMathDisabled(true);
916  ComputationBuilder builder(client_, TestName());
917  auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
918  auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 2.25f, 10.0f, NAN});
919  auto compare = builder.Eq(lhs, rhs);
920
921  ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
922}
923
924XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) {
925  ComputationBuilder builder(client_, TestName());
926  auto lhs = builder.ConstantR1<float>({});
927  auto rhs = builder.ConstantR1<float>({});
928  auto compare = builder.Eq(lhs, rhs);
929
930  ComputeAndCompareR1<bool>(&builder, {}, {});
931}
932
933XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) {
934  SetFastMathDisabled(true);
935  ComputationBuilder builder(client_, TestName());
936  auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
937  auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
938  auto compare = builder.Ge(lhs, rhs);
939
940  ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
941}
942
943XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
944  SetFastMathDisabled(true);
945  ComputationBuilder builder(client_, TestName());
946  auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
947  auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
948  auto compare = builder.Gt(lhs, rhs);
949
950  ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
951}
952
953XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) {
954  SetFastMathDisabled(true);
955  ComputationBuilder builder(client_, TestName());
956  auto lhs = builder.ConstantR1<float>({-2.5f, 5.0f, 2.25f, NAN, 6.0f});
957  auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
958  auto compare = builder.Le(lhs, rhs);
959
960  ComputeAndCompareR1<bool>(&builder, {true, true, false, false, false}, {});
961}
962
963XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) {
964  SetFastMathDisabled(true);
965  ComputationBuilder builder(client_, TestName());
966  auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
967  auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
968  auto compare = builder.Lt(lhs, rhs);
969
970  ComputeAndCompareR1<bool>(&builder, {true, false, false, false, false}, {});
971}
972
973XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) {
974  const int32 min = std::numeric_limits<int32>::min();
975  const int32 max = std::numeric_limits<int32>::max();
976  ComputationBuilder builder(client_, TestName());
977  auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
978  auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
979  auto compare = builder.Eq(lhs, rhs);
980
981  ComputeAndCompareR1<bool>(
982      &builder, {true, false, false, false, true, false, false, false, true},
983      {});
984}
985
986XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) {
987  ComputationBuilder builder(client_, TestName());
988  auto lhs = builder.ConstantR1<int32>({});
989  auto rhs = builder.ConstantR1<int32>({});
990  auto compare = builder.Eq(lhs, rhs);
991
992  ComputeAndCompareR1<bool>(&builder, {}, {});
993}
994
995XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) {
996  SetFastMathDisabled(true);
997  ComputationBuilder builder(client_, TestName());
998  auto lhs = builder.ConstantR1<complex64>({{-2.5f, 10.0f},
999                                            {1.0f, 25.5f},
1000                                            {2.25f, -3.0f},
1001                                            {NAN, 0.0f},
1002                                            {1.0f, 6.0f}});
1003  auto rhs = builder.ConstantR1<complex64>({{0.0f, 10.0f},
1004                                            {1.0f, 5.0f},
1005                                            {2.25f, -3.0f},
1006                                            {10.0f, 0.0f},
1007                                            {1.0f, NAN}});
1008  auto compare = builder.Eq(lhs, rhs);
1009
1010  ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
1011}
1012
1013XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) {
1014  ComputationBuilder builder(client_, TestName());
1015  auto lhs = builder.ConstantR1<complex64>({});
1016  auto rhs = builder.ConstantR1<complex64>({});
1017  auto compare = builder.Eq(lhs, rhs);
1018
1019  ComputeAndCompareR1<bool>(&builder, {}, {});
1020}
1021
1022XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) {
1023  // Disable fast-math because we're operating on NaNs.
1024  SetFastMathDisabled(true);
1025
1026  ComputationBuilder builder(client_, TestName());
1027  auto lhs = builder.ConstantR1<complex64>({{-2.5f, 10.0f},
1028                                            {1.0f, 25.5f},
1029                                            {2.25f, -3.0f},
1030                                            {NAN, 0.0f},
1031                                            {1.0f, 6.0f}});
1032  auto rhs = builder.ConstantR1<complex64>({{0.0f, 10.0f},
1033                                            {1.0f, 5.0f},
1034                                            {2.25f, -3.0f},
1035                                            {10.0f, 0.0f},
1036                                            {1.0f, NAN}});
1037  auto compare = builder.Ne(lhs, rhs);
1038
1039  ComputeAndCompareR1<bool>(&builder, {true, true, false, true, true}, {});
1040}
1041
1042XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) {
1043  // Disable fast-math because we're operating on NaNs.
1044  SetFastMathDisabled(true);
1045
1046  ComputationBuilder builder(client_, TestName());
1047  auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
1048  auto rhs = builder.ConstantR1<float>({10.0f, 25.5f, 1.0f, 10.0f, NAN});
1049  auto compare = builder.Ne(lhs, rhs);
1050
1051  ComputeAndCompareR1<bool>(&builder, {true, false, true, true, true}, {});
1052}
1053
1054XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) {
1055  const int32 min = std::numeric_limits<int32>::min();
1056  const int32 max = std::numeric_limits<int32>::max();
1057  ComputationBuilder builder(client_, TestName());
1058  auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
1059  auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
1060  auto compare = builder.Ne(lhs, rhs);
1061
1062  ComputeAndCompareR1<bool>(
1063      &builder, {false, true, true, true, false, true, true, true, false}, {});
1064}
1065
1066XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) {
1067  const int32 min = std::numeric_limits<int32>::min();
1068  const int32 max = std::numeric_limits<int32>::max();
1069  ComputationBuilder builder(client_, TestName());
1070  auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
1071  auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
1072  auto compare = builder.Ge(lhs, rhs);
1073
1074  ComputeAndCompareR1<bool>(
1075      &builder, {true, false, false, true, true, false, true, true, true}, {});
1076}
1077
1078XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) {
1079  const int32 min = std::numeric_limits<int32>::min();
1080  const int32 max = std::numeric_limits<int32>::max();
1081  ComputationBuilder builder(client_, TestName());
1082  auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
1083  auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
1084  auto compare = builder.Gt(lhs, rhs);
1085
1086  ComputeAndCompareR1<bool>(
1087      &builder, {false, false, false, true, false, false, true, true, false},
1088      {});
1089}
1090
1091XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) {
1092  const int32 min = std::numeric_limits<int32>::min();
1093  const int32 max = std::numeric_limits<int32>::max();
1094  ComputationBuilder builder(client_, TestName());
1095  auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
1096  auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
1097  auto compare = builder.Le(lhs, rhs);
1098
1099  ComputeAndCompareR1<bool>(
1100      &builder, {true, true, true, false, true, true, false, false, true}, {});
1101}
1102
1103XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) {
1104  const int32 min = std::numeric_limits<int32>::min();
1105  const int32 max = std::numeric_limits<int32>::max();
1106  ComputationBuilder builder(client_, TestName());
1107  auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
1108  auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
1109  auto compare = builder.Lt(lhs, rhs);
1110
1111  ComputeAndCompareR1<bool>(
1112      &builder, {false, true, true, false, false, true, false, false, false},
1113      {});
1114}
1115
1116XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) {
1117  const uint32 max = std::numeric_limits<uint32>::max();
1118  ComputationBuilder builder(client_, TestName());
1119  auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
1120  auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
1121  auto compare = builder.Eq(lhs, rhs);
1122
1123  ComputeAndCompareR1<bool>(
1124      &builder, {true, false, false, false, true, false, false, false, true},
1125      {});
1126}
1127
1128XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) {
1129  const uint32 max = std::numeric_limits<uint32>::max();
1130  ComputationBuilder builder(client_, TestName());
1131  auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
1132  auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
1133  auto compare = builder.Ne(lhs, rhs);
1134
1135  ComputeAndCompareR1<bool>(
1136      &builder, {false, true, true, true, false, true, true, true, false}, {});
1137}
1138
1139XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) {
1140  const uint32 max = std::numeric_limits<uint32>::max();
1141  ComputationBuilder builder(client_, TestName());
1142  auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
1143  auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
1144  auto compare = builder.Ge(lhs, rhs);
1145
1146  ComputeAndCompareR1<bool>(
1147      &builder, {true, false, false, true, true, false, true, true, true}, {});
1148}
1149
1150XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) {
1151  const uint32 max = std::numeric_limits<uint32>::max();
1152  ComputationBuilder builder(client_, TestName());
1153  auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
1154  auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
1155  auto compare = builder.Gt(lhs, rhs);
1156
1157  ComputeAndCompareR1<bool>(
1158      &builder, {false, false, false, true, false, false, true, true, false},
1159      {});
1160}
1161
1162XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) {
1163  const uint32 max = std::numeric_limits<uint32>::max();
1164  ComputationBuilder builder(client_, TestName());
1165  auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
1166  auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
1167  auto compare = builder.Le(lhs, rhs);
1168
1169  ComputeAndCompareR1<bool>(
1170      &builder, {true, true, true, false, true, true, false, false, true}, {});
1171}
1172
1173XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) {
1174  const uint32 max = std::numeric_limits<uint32>::max();
1175  ComputationBuilder builder(client_, TestName());
1176  auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
1177  auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
1178  auto compare = builder.Lt(lhs, rhs);
1179
1180  ComputeAndCompareR1<bool>(
1181      &builder, {false, true, true, false, false, true, false, false, false},
1182      {});
1183}
1184
1185XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) {
1186  SetFastMathDisabled(true);
1187  ComputationBuilder builder(client_, TestName());
1188  auto lhs =
1189      builder.ConstantR1<float>({4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f});
1190  auto rhs =
1191      builder.ConstantR1<float>({2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f});
1192  auto minimum = builder.Pow(lhs, rhs);
1193
1194  ComputeAndCompareR1<float>(
1195      &builder, {16.0f, 0.25f, 8.0f, NAN, NAN, -8.0f, 16.0f}, {}, error_spec_);
1196}
1197
1198XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) {
1199  SetFastMathDisabled(true);
1200  ComputationBuilder builder(client_, TestName());
1201  auto lhs = builder.ConstantR1<float>({-2.0f, -0.6f, -0.6f, 0.0f});
1202  auto rhs = builder.ConstantR1<float>({0.5f, 0.6f, -0.6f, -0.6f});
1203  auto minimum = builder.Pow(lhs, rhs);
1204
1205  ComputeAndCompareR1<float>(&builder, {NAN, NAN, NAN, INFINITY}, {},
1206                             error_spec_);
1207}
1208
1209XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) {
1210  ComputationBuilder builder(client_, TestName());
1211  auto lhs = builder.ConstantR1<float>({});
1212  auto rhs = builder.ConstantR1<float>({});
1213  auto minimum = builder.Pow(lhs, rhs);
1214
1215  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
1216}
1217
1218// Some Pow cases that can be implemented more efficiently.
1219XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
1220  ComputationBuilder b(client_, TestName());
1221
1222  std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f};
1223  std::vector<float> exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
1224
1225  std::unique_ptr<Literal> param_literal = Literal::CreateR1<float>(values);
1226  std::unique_ptr<GlobalData> param_data =
1227      client_->TransferToServer(*param_literal).ConsumeValueOrDie();
1228
1229  auto sum = b.ConstantR0<float>(0.0f);
1230  auto param = b.Parameter(0, param_literal->shape(), "param");
1231  for (float exponent : exponents) {
1232    sum = b.Add(sum, b.Pow(param, b.ConstantR0<float>(exponent)));
1233  }
1234
1235  std::vector<float> expected;
1236  for (auto value : values) {
1237    float sum = 0.0f;
1238    for (float exponent : exponents) {
1239      sum += std::pow(value, exponent);
1240    }
1241    expected.push_back(sum);
1242  }
1243
1244  ComputeAndCompareR1<float>(&b, expected, {param_data.get()}, error_spec_);
1245}
1246
1247XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
1248  ComputationBuilder b(client_, TestName());
1249
1250  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
1251  std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
1252
1253  std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
1254  std::unique_ptr<GlobalData> data0 =
1255      client_->TransferToServer(*literal0).ConsumeValueOrDie();
1256  std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
1257  std::unique_ptr<GlobalData> data1 =
1258      client_->TransferToServer(*literal1).ConsumeValueOrDie();
1259  auto param0 = b.Parameter(0, literal0->shape(), "param0");
1260  auto param1 = b.Parameter(1, literal1->shape(), "param1");
1261  b.Pow(b.Exp(param0), param1);
1262
1263  std::vector<float> expected(values0.size());
1264  for (int64 i = 0; i < values0.size(); ++i) {
1265    expected[i] = std::pow(std::exp(values0[i]), values1[i]);
1266  }
1267
1268  ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
1269                             error_spec_);
1270}
1271
1272XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
1273  ComputationBuilder b(client_, TestName());
1274
1275  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f};
1276  std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
1277
1278  std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
1279  std::unique_ptr<GlobalData> data0 =
1280      client_->TransferToServer(*literal0).ConsumeValueOrDie();
1281  std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
1282  std::unique_ptr<GlobalData> data1 =
1283      client_->TransferToServer(*literal1).ConsumeValueOrDie();
1284  auto param0 = b.Parameter(0, literal0->shape(), "param0");
1285  auto param1 = b.Parameter(1, literal1->shape(), "param1");
1286  b.Log(b.Pow(param0, param1));
1287
1288  std::vector<float> expected(values0.size());
1289  for (int64 i = 0; i < values0.size(); ++i) {
1290    expected[i] = std::log(std::pow(values0[i], values1[i]));
1291  }
1292
1293  ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
1294                             error_spec_);
1295}
1296
1297XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
1298  ComputationBuilder b(client_, TestName());
1299
1300  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
1301  std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
1302
1303  std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
1304  std::unique_ptr<GlobalData> data0 =
1305      client_->TransferToServer(*literal0).ConsumeValueOrDie();
1306  std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
1307  std::unique_ptr<GlobalData> data1 =
1308      client_->TransferToServer(*literal1).ConsumeValueOrDie();
1309  auto param0 = b.Parameter(0, literal0->shape(), "param0");
1310  auto param1 = b.Parameter(1, literal1->shape(), "param1");
1311  b.Mul(b.Exp(param0), b.Exp(param1));
1312
1313  std::vector<float> expected(values0.size());
1314  for (int64 i = 0; i < values0.size(); ++i) {
1315    expected[i] = std::exp(values0[i]) * std::exp(values1[i]);
1316  }
1317
1318  ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
1319                             error_spec_);
1320}
1321
1322XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
1323  ComputationBuilder b(client_, TestName());
1324
1325  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
1326  std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
1327
1328  std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
1329  std::unique_ptr<GlobalData> data0 =
1330      client_->TransferToServer(*literal0).ConsumeValueOrDie();
1331  std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
1332  std::unique_ptr<GlobalData> data1 =
1333      client_->TransferToServer(*literal1).ConsumeValueOrDie();
1334  auto param0 = b.Parameter(0, literal0->shape(), "param0");
1335  auto param1 = b.Parameter(1, literal1->shape(), "param1");
1336  b.Div(param0, b.Exp(param1));
1337
1338  std::vector<float> expected(values0.size());
1339  for (int64 i = 0; i < values0.size(); ++i) {
1340    expected[i] = values0[i] / std::exp(values1[i]);
1341  }
1342
1343  ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
1344                             error_spec_);
1345}
1346
1347XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
1348  ComputationBuilder b(client_, TestName());
1349
1350  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
1351  std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
1352  std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
1353
1354  std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
1355  std::unique_ptr<GlobalData> data0 =
1356      client_->TransferToServer(*literal0).ConsumeValueOrDie();
1357
1358  std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
1359  std::unique_ptr<GlobalData> data1 =
1360      client_->TransferToServer(*literal1).ConsumeValueOrDie();
1361
1362  std::unique_ptr<Literal> literal2 = Literal::CreateR1<float>(values2);
1363  std::unique_ptr<GlobalData> data2 =
1364      client_->TransferToServer(*literal2).ConsumeValueOrDie();
1365  auto param0 = b.Parameter(0, literal0->shape(), "param0");
1366  auto param1 = b.Parameter(1, literal1->shape(), "param1");
1367  auto param2 = b.Parameter(2, literal2->shape(), "param2");
1368  b.Div(b.Div(param0, param1), param2);
1369
1370  std::vector<float> expected(values0.size());
1371  for (int64 i = 0; i < values0.size(); ++i) {
1372    expected[i] = (values0[i] / values1[i]) / values2[i];
1373  }
1374
1375  ComputeAndCompareR1<float>(
1376      &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_);
1377}
1378
1379XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
1380  ComputationBuilder b(client_, TestName());
1381
1382  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
1383  std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
1384  std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
1385
1386  std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
1387  std::unique_ptr<GlobalData> data0 =
1388      client_->TransferToServer(*literal0).ConsumeValueOrDie();
1389
1390  std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
1391  std::unique_ptr<GlobalData> data1 =
1392      client_->TransferToServer(*literal1).ConsumeValueOrDie();
1393
1394  std::unique_ptr<Literal> literal2 = Literal::CreateR1<float>(values2);
1395  std::unique_ptr<GlobalData> data2 =
1396      client_->TransferToServer(*literal2).ConsumeValueOrDie();
1397
1398  auto param0 = b.Parameter(0, literal0->shape(), "param0");
1399  auto param1 = b.Parameter(1, literal1->shape(), "param1");
1400  auto param2 = b.Parameter(2, literal2->shape(), "param2");
1401  b.Div(param0, b.Div(param1, param2));
1402
1403  std::vector<float> expected(values0.size());
1404  for (int64 i = 0; i < values0.size(); ++i) {
1405    expected[i] = values0[i] / (values1[i] / values2[i]);
1406  }
1407
1408  ComputeAndCompareR1<float>(
1409      &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_);
1410}
1411
1412XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
1413  ComputationBuilder b(client_, TestName());
1414
1415  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
1416  std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f};
1417  std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f};
1418
1419  std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
1420  std::unique_ptr<GlobalData> data0 =
1421      client_->TransferToServer(*literal0).ConsumeValueOrDie();
1422
1423  std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
1424  std::unique_ptr<GlobalData> data1 =
1425      client_->TransferToServer(*literal1).ConsumeValueOrDie();
1426
1427  std::unique_ptr<Literal> literal2 = Literal::CreateR1<float>(values2);
1428  std::unique_ptr<GlobalData> data2 =
1429      client_->TransferToServer(*literal2).ConsumeValueOrDie();
1430
1431  auto param0 = b.Parameter(0, literal0->shape(), "param0");
1432  auto param1 = b.Parameter(1, literal1->shape(), "param1");
1433  auto param2 = b.Parameter(2, literal2->shape(), "param2");
1434  b.Div(param0, b.Pow(param1, param2));
1435
1436  std::vector<float> expected(values0.size());
1437  for (int64 i = 0; i < values0.size(); ++i) {
1438    expected[i] = values0[i] / std::pow(values1[i], values2[i]);
1439  }
1440
1441  ComputeAndCompareR1<float>(
1442      &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_);
1443}
1444
1445XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) {
1446  ComputationBuilder b(client_, TestName());
1447
1448  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
1449  std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
1450  std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
1451  std::vector<float> values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f};
1452
1453  std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
1454  std::unique_ptr<GlobalData> data0 =
1455      client_->TransferToServer(*literal0).ConsumeValueOrDie();
1456
1457  std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
1458  std::unique_ptr<GlobalData> data1 =
1459      client_->TransferToServer(*literal1).ConsumeValueOrDie();
1460
1461  std::unique_ptr<Literal> literal2 = Literal::CreateR1<float>(values2);
1462  std::unique_ptr<GlobalData> data2 =
1463      client_->TransferToServer(*literal2).ConsumeValueOrDie();
1464
1465  std::unique_ptr<Literal> literal3 = Literal::CreateR1<float>(values3);
1466  std::unique_ptr<GlobalData> data3 =
1467      client_->TransferToServer(*literal3).ConsumeValueOrDie();
1468
1469  auto param0 = b.Parameter(0, literal0->shape(), "param0");
1470  auto param1 = b.Parameter(1, literal1->shape(), "param1");
1471  auto param2 = b.Parameter(2, literal2->shape(), "param2");
1472  auto param3 = b.Parameter(3, literal3->shape(), "param2");
1473  b.Div(b.Div(param0, param1), b.Div(param2, param3));
1474
1475  std::vector<float> expected(values0.size());
1476  for (int64 i = 0; i < values0.size(); ++i) {
1477    expected[i] = (values0[i] / values1[i]) / (values2[i] / values3[i]);
1478  }
1479
1480  ComputeAndCompareR1<float>(
1481      &b, expected, {data0.get(), data1.get(), data2.get(), data3.get()},
1482      error_spec_);
1483}
1484
1485TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
1486  const int count = GetParam();
1487  ComputationBuilder builder(client_, TestName());
1488  std::vector<float> values;
1489  values.reserve(count);
1490  for (int i = 0; i < count; ++i) {
1491    values.push_back(i / static_cast<float>(count));
1492  }
1493  auto x = builder.ConstantR1<float>(values);
1494  auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f));
1495
1496  std::vector<float> expected;
1497  expected.reserve(values.size());
1498  for (float value : values) {
1499    expected.push_back(value * value);
1500  }
1501
1502  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
1503}
1504
1505XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) {
1506  ComputationBuilder builder(client_, TestName());
1507  Array4D<float> values(2, 2, 2, 2);
1508
1509  std::vector<float> values_vector;
1510  std::vector<float> expected_vector;
1511  for (int i = 0; i < values.num_elements(); ++i) {
1512    values_vector.push_back(static_cast<float>(i) / values.num_elements());
1513    expected_vector.push_back(values_vector.back() * values_vector.back());
1514  }
1515  values.SetValues(values_vector);
1516
1517  Array4D<float> expected(2, 2, 2, 2, expected_vector);
1518
1519  auto x = builder.ConstantR4FromArray4D<float>(values);
1520  auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f));
1521
1522  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
1523}
1524
1525XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) {
1526  ComputationBuilder builder(client_, TestName());
1527  Array4D<float> values(2, 2, 0, 2);
1528  Array4D<float> expected(2, 2, 0, 2);
1529
1530  auto x = builder.ConstantR4FromArray4D<float>(values);
1531  auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f));
1532
1533  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
1534}
1535
// The GPU backend emits NVVM intrinsics for fmin and fmax whose semantics do
// NOT guarantee
// * fmin(NaN, x) = x
// * fmax(NaN, x) = x
// so NaN inputs are only tested on CPU.
//
// TODO(b/28180546): Make this compile in a way that is consistent
// among backends.
1544XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) {
1545  ComputationBuilder builder(client_, TestName());
1546#if !defined(XLA_TEST_BACKEND_CPU)
1547  auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f});
1548  auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f});
1549#else
1550  SetFastMathDisabled(true);
1551  auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f, NAN, 6.0f});
1552  auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f, 10.0f, NAN});
1553#endif
1554  auto minimum = builder.Min(lhs, rhs);
1555
1556  ComputeAndCompareR1<float>(&builder,
1557#if !defined(XLA_TEST_BACKEND_CPU)
1558                             {1.0f, -5.0f, 1.0f},
1559#else
1560                             {1.0f, -5.0f, 1.0f, 10.0f, 6.0f},
1561#endif
1562                             {}, error_spec_);
1563}
1564
1565XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) {
1566  ComputationBuilder builder(client_, TestName());
1567  auto lhs = builder.ConstantR1<float>({});
1568  auto rhs = builder.ConstantR1<float>({});
1569  auto minimum = builder.Min(lhs, rhs);
1570  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
1571}
1572
1573// TODO(b/28180546): Make this compile in a way that is consistent
1574// among backends. See comment on MinF32s test above.
1575XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) {
1576  ComputationBuilder builder(client_, TestName());
1577#if !defined(XLA_TEST_BACKEND_CPU)
1578  auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25});
1579  auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0});
1580#else
1581  SetFastMathDisabled(true);
1582  auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25, NAN, 6.0});
1583  auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0, 10.0, NAN});
1584#endif
1585  auto minimum = builder.Min(lhs, rhs);
1586
1587  ComputeAndCompareR1<double>(&builder,
1588#if !defined(XLA_TEST_BACKEND_CPU)
1589                              {1.0, -5.0, 1.0},
1590#else
1591                              {1.0, -5.0, 1.0, 10.0, 6.0},
1592#endif
1593                              {}, error_spec_);
1594}
1595
1596// TODO(b/28180546): Make this compile in a way that is consistent
1597// among backends. See comment on MinF32s test above.
1598XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) {
1599  ComputationBuilder builder(client_, TestName());
1600#if !defined(XLA_TEST_BACKEND_CPU)
1601  auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f});
1602  auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f});
1603#else
1604  SetFastMathDisabled(true);
1605  auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f, NAN, 6.0f});
1606  auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f, 10.0f, NAN});
1607#endif
1608  auto maximum = builder.Max(lhs, rhs);
1609
1610  ComputeAndCompareR1<float>(&builder,
1611#if !defined(XLA_TEST_BACKEND_CPU)
1612                             {2.0f, 1.0f, 2.25f},
1613#else
1614                             {2.0f, 1.0f, 2.25f, 10.0f, 6.0f},
1615#endif
1616                             {}, error_spec_);
1617}
1618
1619XLA_TEST_F(ArrayElementwiseOpTest, MaxZeroElementF32s) {
1620  ComputationBuilder builder(client_, TestName());
1621  auto lhs = builder.ConstantR1<float>({});
1622  auto rhs = builder.ConstantR1<float>({});
1623  auto minimum = builder.Max(lhs, rhs);
1624  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
1625}
1626
1627// TODO(b/28180546): Make this compile in a way that is consistent
1628// among backends. See comment on MinF32s test above.
1629XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) {
1630  ComputationBuilder builder(client_, TestName());
1631#if !defined(XLA_TEST_BACKEND_CPU)
1632  auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25});
1633  auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0});
1634#else
1635  SetFastMathDisabled(true);
1636  auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25, NAN, 6.0});
1637  auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0, 10.0, NAN});
1638#endif
1639  auto maximum = builder.Max(lhs, rhs);
1640
1641  ComputeAndCompareR1<double>(&builder,
1642#if !defined(XLA_TEST_BACKEND_CPU)
1643                              {2.0, 1.0, 2.25},
1644#else
1645                              {2.0, 1.0, 2.25, 10.0, 6.0},
1646#endif
1647                              {}, error_spec_);
1648}
1649
1650XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) {
1651  const int32 min = std::numeric_limits<int32>::min();
1652  const int32 max = std::numeric_limits<int32>::max();
1653  ComputationBuilder builder(client_, TestName());
1654  auto x = builder.ConstantR1<int32>(
1655      {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max});
1656  auto y = builder.ConstantR1<int32>(
1657      {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min});
1658  builder.Max(x, y);
1659
1660  std::vector<int32> expected = {min, max, 0,  -1,  0,   0,  0,
1661                                 1,   1,   10, max, max, max};
1662  ComputeAndCompareR1<int32>(&builder, expected, {});
1663}
1664
1665XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) {
1666  const int32 min = std::numeric_limits<int32>::min();
1667  const int32 max = std::numeric_limits<int32>::max();
1668  ComputationBuilder builder(client_, TestName());
1669  auto x = builder.ConstantR1<int32>(
1670      {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max});
1671  auto y = builder.ConstantR1<int32>(
1672      {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min});
1673  builder.Min(x, y);
1674
1675  std::vector<int32> expected = {min, min, min, -10, -1,  -1, 0,
1676                                 0,   0,   1,   0,   max, min};
1677  ComputeAndCompareR1<int32>(&builder, expected, {});
1678}
1679
1680XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) {
1681  const uint32 max = std::numeric_limits<uint32>::max();
1682  ComputationBuilder builder(client_, TestName());
1683  auto x = builder.ConstantR1<uint32>({0, 0, 1, 1, 1, max, max, max});
1684  auto y = builder.ConstantR1<uint32>({0, 1, 0, 1, 10, 0, 234234, max});
1685  builder.Max(x, y);
1686
1687  std::vector<uint32> expected = {0, 1, 1, 1, 10, max, max, max};
1688  ComputeAndCompareR1<uint32>(&builder, expected, {});
1689}
1690
1691XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) {
1692  const uint32 max = std::numeric_limits<uint32>::max();
1693  ComputationBuilder builder(client_, TestName());
1694  auto x = builder.ConstantR1<uint32>({0, 0, 1, 1, 1, max, max, max});
1695  auto y = builder.ConstantR1<uint32>({0, 1, 0, 1, 10, 0, 234234, max});
1696  builder.Min(x, y);
1697
1698  std::vector<uint32> expected = {0, 0, 0, 1, 1, 0, 234234, max};
1699  ComputeAndCompareR1<uint32>(&builder, expected, {});
1700}
1701
1702XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) {
1703  ComputationBuilder builder(client_, TestName());
1704  auto x = builder.ConstantR1<float>(
1705      {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0});
1706  auto y = builder.ConstantR1<float>(
1707      {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0});
1708  builder.Max(x, y);
1709
1710  std::vector<float> expected = {-0.0, 1.0, 2.0, 3.0, 4.0,
1711                                 5.0,  6.0, 7.0, 8.0, 9.0};
1712  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
1713}
1714
1715XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) {
1716  ComputationBuilder builder(client_, TestName());
1717  auto u = builder.ConstantR1<float>({3.5});
1718  auto v = builder.ConstantR1<float>({});
1719  builder.Max(u, v);
1720
1721  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
1722}
1723
1724XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) {
1725  for (int broadcast_dim : {0, 1}) {
1726    ComputationBuilder builder(client_, TestName());
1727    auto u = builder.ConstantR1<float>({3.5});
1728    auto v = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
1729    builder.Max(u, v, /*broadcast_dimensions=*/{broadcast_dim});
1730
1731    ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 2), {}, error_spec_);
1732  }
1733}
1734
1735XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) {
1736  ComputationBuilder builder(client_, TestName());
1737  auto v = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
1738  auto m =
1739      builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
1740  builder.Max(v, m, /*broadcast_dimensions=*/{1});
1741
1742  Array2D<float> expected({{2.0f, 3.14f, 4.0f}, {2.25f, 3.0f, 4.0f}});
1743  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
1744}
1745
1746XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) {
1747  ComputationBuilder builder(client_, TestName());
1748  auto v = builder.ConstantR1<float>({});
1749  auto m = builder.ConstantR2<float>({{}, {}});
1750  builder.Max(v, m, /*broadcast_dimensions=*/{1});
1751
1752  Array2D<float> expected({{}, {}});
1753  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
1754}
1755
1756XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) {
1757  ComputationBuilder builder(client_, TestName());
1758  auto scalar = builder.ConstantR0<int32>(2);
1759  Array3D<int32> a_3d({{{3, 9, -1}, {2, -10, 3}}, {{-2, 2, 8}, {12, 10, 4}}});
1760  auto array = builder.ConstantR3FromArray3D<int32>(a_3d);
1761  builder.Max(array, scalar, /*broadcast_dimensions=*/{});
1762
1763  Array3D<int32> expected({{{3, 9, 2}, {2, 2, 3}}, {{2, 2, 8}, {12, 10, 4}}});
1764  ComputeAndCompareR3<int32>(&builder, expected, {});
1765}
1766
1767XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) {
1768  ComputationBuilder builder(client_, TestName());
1769  auto scalar = builder.ConstantR0<int32>(2);
1770  Array3D<int32> a_3d(2, 0, 3);
1771  auto array = builder.ConstantR3FromArray3D<int32>(a_3d);
1772  builder.Max(array, scalar, /*broadcast_dimensions=*/{});
1773
1774  Array3D<int32> expected(2, 0, 3);
1775  ComputeAndCompareR3<int32>(&builder, expected, {});
1776}
1777
1778XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) {
1779  ComputationBuilder builder(client_, TestName());
1780  auto m =
1781      builder.ConstantR2<float>({{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}});
1782  auto v = builder.ConstantR1<float>({-10.2f, 16.4f});
1783  builder.Min(m, v, /*broadcast_dimensions=*/{0});
1784
1785  Array2D<float> expected({{-10.4f, -10.2f, -10.2f}, {0.1f, 16.4f, 16.1f}});
1786  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
1787}
1788
1789XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) {
1790  ComputationBuilder builder(client_, TestName());
1791  auto m = builder.ConstantR2<float>({{}, {}});
1792  auto v = builder.ConstantR1<float>({-10.2f, 16.4f});
1793  builder.Min(m, v, /*broadcast_dimensions=*/{0});
1794
1795  Array2D<float> expected({{}, {}});
1796  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
1797}
1798
1799XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) {
1800  ComputationBuilder builder(client_, TestName());
1801  auto array2d =
1802      builder.ConstantR2<float>({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
1803  auto array4d = builder.ConstantR4FromArray4D<float>(
1804      {{{{-12.1f, 32.3f, 6.2f}}, {{0.0f, 32.5f, 3.0f}}},
1805       {{{-2.5f, 64.29f, 6.5f}}, {{-0.01f, 32.25f, 2.6f}}}});
1806  builder.Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3});
1807
1808  Array4D<float> expected(
1809      {{{{-12.2f, 32.3f, 6.1f}}, {{0.0f, 32.2f, 2.5f}}},
1810       {{{-12.2f, 64.29f, 6.1f}}, {{-0.01f, 32.2f, 2.5f}}}});
1811  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
1812}
1813
1814XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) {
1815  ComputationBuilder builder(client_, TestName());
1816  auto array2d =
1817      builder.ConstantR2<float>({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
1818  Array4D<float> arg(2, 2, 0, 3);
1819  auto array4d = builder.ConstantR4FromArray4D<float>(arg);
1820  builder.Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3});
1821
1822  Array4D<float> expected(2, 2, 0, 3);
1823  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
1824}
1825
1826XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) {
1827  ComputationBuilder builder(client_, TestName());
1828  auto x = builder.ConstantR1<int32>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
1829  auto y = builder.ConstantR1<int32>({9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
1830  builder.Min(x, y);
1831
1832  std::vector<int32> expected = {0, 1, 2, 3, 4, 4, 3, 2, 1, 0};
1833  ComputeAndCompareR1<int32>(&builder, expected, {});
1834}
1835
1836XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) {
1837  ComputationBuilder builder(client_, TestName());
1838  auto x = builder.ConstantR1<int32>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
1839  auto y = builder.ConstantR1<int32>({9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
1840  builder.Max(x, y);
1841
1842  std::vector<int32> expected = {9, 8, 7, 6, 5, 5, 6, 7, 8, 9};
1843  ComputeAndCompareR1<int32>(&builder, expected, {});
1844}
1845
1846XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) {
1847  ComputationBuilder builder(client_, TestName());
1848  auto a = builder.ConstantR1<int32>({-3, 26, 2, -1, 1});
1849  auto b = builder.ConstantR1<int32>({10, 5, 1, 10, -10});
1850  auto add = builder.Rem(a, b);
1851
1852  ComputeAndCompareR1<int32>(&builder, {-3, 1, 0, -1, 1}, {});
1853}
1854
1855XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) {
1856  ComputationBuilder builder(client_, TestName());
1857  auto minimum = builder.ConstantR1<float>({1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
1858  auto argument = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 10.0f});
1859  auto maximum = builder.ConstantR1<float>({3.0f, 0.5f, 25.5f, 5.0f, 123.0});
1860  auto clamp = builder.Clamp(minimum, argument, maximum);
1861
1862  ComputeAndCompareR1<float>(&builder, {2.0f, 0.5f, 1.0f, 2.25f, 10.0f}, {},
1863                             error_spec_);
1864}
1865
1866XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) {
1867  ComputationBuilder builder(client_, TestName());
1868  auto minimum = builder.ConstantR0<float>(0.0f);
1869  auto argument = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
1870  auto maximum = builder.ConstantR0<float>(5.0f);
1871  auto clamp = builder.Clamp(minimum, argument, maximum);
1872
1873  ComputeAndCompareR1<float>(&builder, {2.0f, 5.0f, 0.0f, 1.0f, 4.0f}, {},
1874                             error_spec_);
1875}
1876
1877XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) {
1878  ComputationBuilder builder(client_, TestName());
1879  auto min_scalar = builder.ConstantR0<float>(0.0f);
1880  auto min_vector = builder.ConstantR1<float>({1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
1881  auto arg_vector = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
1882  auto max_scalar = builder.ConstantR0<float>(3.0f);
1883  auto max_vector = builder.ConstantR1<float>({3.0f, 0.5f, 25.5f, 5.0f, 123.0});
1884  // Perform clamp with broadcasted scalar and vector.
1885  auto clamp = builder.Add(
1886      builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar),
1887                  builder.Clamp(min_scalar, arg_vector, max_vector)),
1888      builder.Add(builder.Clamp(min_vector, arg_vector, max_vector),
1889                  builder.Clamp(min_scalar, arg_vector, max_scalar)));
1890
1891  ComputeAndCompareR1<float>(&builder, {8.0f, 7.0f, 2.0f, 6.5f, 14.0f}, {},
1892                             error_spec_);
1893}
1894
1895XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) {
1896  ComputationBuilder builder(client_, TestName());
1897  auto min_vector = builder.ConstantR1<int32>({1, -6, 1, 2, 0, -5});
1898  auto arg_vector = builder.ConstantR1<int32>({2, 10, -5, 1, 4, 10});
1899  auto max_vector = builder.ConstantR1<int32>({3, 0, 25, 5, 123, -1});
1900  auto clamp = builder.Clamp(min_vector, arg_vector, max_vector);
1901
1902  ComputeAndCompareR1<int32>(&builder, {2, 0, 1, 2, 4, -1}, {});
1903}
1904
1905XLA_TEST_F(ArrayElementwiseOpTest, ClampS32ScalarVector) {
1906  ComputationBuilder builder(client_, TestName());
1907  auto min_scalar = builder.ConstantR0<int32>(0);
1908  auto min_vector = builder.ConstantR1<int32>({1, -6, 1, 2, 0});
1909  auto arg_vector = builder.ConstantR1<int32>({2, 10, -5, 1, 4});
1910  auto max_scalar = builder.ConstantR0<int32>(3);
1911  auto max_vector = builder.ConstantR1<int32>({3, 1, 25, 5, 123});
1912  // Perform clamp with broadcasted scalar and vector.
1913  auto clamp = builder.Add(
1914      builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar),
1915                  builder.Clamp(min_scalar, arg_vector, max_vector)),
1916      builder.Add(builder.Clamp(min_vector, arg_vector, max_vector),
1917                  builder.Clamp(min_scalar, arg_vector, max_scalar)));
1918
1919  ComputeAndCompareR1<int32>(&builder, {8, 8, 2, 6, 14}, {});
1920}
1921
1922XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) {
1923  ComputationBuilder builder(client_, TestName());
1924  auto min_vector = builder.ConstantR1<uint32>({1, 2, 1, 2, 0, ~0u - 4});
1925  auto arg_vector = builder.ConstantR1<uint32>({2, 10, 5, 1, 4, 10});
1926  auto max_vector = builder.ConstantR1<uint32>({3, 5, 25, 5, 123, ~0u});
1927  auto clamp = builder.Clamp(min_vector, arg_vector, max_vector);
1928
1929  ComputeAndCompareR1<uint32>(&builder, {2, 5, 5, 2, 4, ~0u - 4}, {});
1930}
1931
1932XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) {
1933  ComputationBuilder builder(client_, TestName());
1934  auto min_scalar = builder.ConstantR0<uint32>(0);
1935  auto min_vector = builder.ConstantR1<uint32>({1, 0, 1, 2, 0});
1936  auto arg_vector = builder.ConstantR1<uint32>({2, 10, 0, 1, 4});
1937  auto max_scalar = builder.ConstantR0<uint32>(3);
1938  auto max_vector = builder.ConstantR1<uint32>({3, 1, 25, 5, 123});
1939  // Perform clamp with broadcasted scalar and vector.
1940  auto clamp = builder.Add(
1941      builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar),
1942                  builder.Clamp(min_scalar, arg_vector, max_vector)),
1943      builder.Add(builder.Clamp(min_vector, arg_vector, max_vector),
1944                  builder.Clamp(min_scalar, arg_vector, max_scalar)));
1945
1946  ComputeAndCompareR1<uint32>(&builder, {8, 8, 2, 6, 14}, {});
1947}
1948
1949XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
1950  ComputationBuilder builder(client_, TestName());
1951
1952  std::unique_ptr<Literal> param0_literal =
1953      Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
1954  std::unique_ptr<GlobalData> param0_data =
1955      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
1956
1957  std::unique_ptr<Literal> param1_literal =
1958      Literal::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
1959  std::unique_ptr<GlobalData> param1_data =
1960      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
1961
1962  auto p0 = builder.Parameter(0, param0_literal->shape(), "param0");
1963  auto p1 = builder.Parameter(1, param1_literal->shape(), "param1");
1964  auto add = builder.Add(p0, p1);
1965
1966  ComputeAndCompareR1<float>(&builder, {8.3f, 4.5f, 6.7f, 11.1f},
1967                             {param0_data.get(), param1_data.get()},
1968                             error_spec_);
1969}
1970
1971XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
1972  ComputationBuilder builder(client_, TestName());
1973
1974  std::unique_ptr<Literal> param0_literal =
1975      Literal::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
1976  std::unique_ptr<GlobalData> param0_data =
1977      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
1978
1979  std::unique_ptr<Literal> param1_literal =
1980      Literal::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
1981  std::unique_ptr<GlobalData> param1_data =
1982      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
1983
1984  auto p0 = builder.Parameter(0, param0_literal->shape(), "param0");
1985  auto p1 = builder.Parameter(1, param1_literal->shape(), "param1");
1986  auto add = builder.Add(p0, p1);
1987
1988  Array3D<float> expected(0, 7, 0);
1989  ComputeAndCompareR3<float>(
1990      &builder, expected, {param0_data.get(), param1_data.get()}, error_spec_);
1991}
1992
1993XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
1994  ComputationBuilder builder(client_, TestName());
1995
1996  std::unique_ptr<Literal> param0_literal =
1997      Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
1998  std::unique_ptr<GlobalData> param0_data =
1999      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
2000
2001  auto a = builder.ConstantR1<float>({1.1f, 2.2f, 3.3f, 4.4f});
2002  auto p = builder.Parameter(0, param0_literal->shape(), "param0");
2003  auto add = builder.Add(a, p);
2004
2005  ComputeAndCompareR1<float>(&builder, {2.2f, 4.4f, 6.6f, 9.9f},
2006                             {param0_data.get()}, error_spec_);
2007}
2008
2009XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) {
2010  ComputationBuilder builder(client_, TestName());
2011  auto a = builder.ConstantR1<float>({3.14159f, 0.0f, 1.570796f, -0.78539f});
2012  auto result = builder.Cos(a);
2013
2014  ComputeAndCompareR1<float>(&builder, {-1.0f, 1.0f, 0.0f, 0.707107f}, {},
2015                             error_spec_);
2016}
2017
2018XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) {
2019  ComputationBuilder builder(client_, TestName());
2020  auto a = builder.ConstantR1<float>({3.14159f, 0.0f, 1.570796f, -0.78539f});
2021  auto result = builder.Sin(a);
2022
2023  ComputeAndCompareR1<float>(&builder, {0.0f, 0.0f, 1.0f, -0.707107f}, {},
2024                             error_spec_);
2025}
2026
2027XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) {
2028  ComputationBuilder builder(client_, TestName());
2029  auto a = builder.ConstantR1<float>({0.0f, 5.0f, 0.0f, -3.0f, 2.0f, -8.0f});
2030  auto b = builder.ConstantR1<float>({6.0f, 0.0f, -4.0f, 0.0f, 2.0f, 8.0f});
2031  auto atan = builder.Atan2(a, b);
2032
2033  ComputeAndCompareR1<float>(
2034      &builder,
2035      {0.0f, 1.57079633f, 3.14159265f, -1.57079633f, 0.78539816f, -0.78539816f},
2036      {}, error_spec_);
2037}
2038
2039XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) {
2040  ComputationBuilder builder(client_, TestName());
2041  auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f});
2042  auto result = builder.Tanh(a);
2043
2044  ComputeAndCompareR1<float>(&builder, {-0.986614f, 0.996260f, 0.978026}, {},
2045                             error_spec_);
2046}
2047
2048XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
2049  // This is like the test ArrayElementwiseOpTest.TanhF32s above, except that
2050  // the input tensor is large enough to exercise the vectorized tanh
2051  // implementation on XLA CPU.
2052  ComputationBuilder builder(client_, TestName());
2053  auto input_literal = Literal::CreateR1<float>(
2054      {1.02,  -0.32, 0.85,  0.90,  1.23,  -0.91, -0.49, 0.80,  -0.67, 0.16,
2055       -0.07, 0.39,  -0.41, 0.04,  1.36,  1.25,  0.41,  0.65,  -1.08, 0.32,
2056       -1.45, -0.77, -1.09, 0.91,  -1.03, -0.30, -1.11, -1.17, 1.50,  -0.85,
2057       0.04,  1.02,  0.34,  -0.61, 0.41,  0.07,  -0.02, 1.42,  -0.62, 0.81,
2058       0.08,  0.81,  -0.30, 1.17,  -0.65, -0.44, 0.92,  1.26,  -1.29, 1.35,
2059       0.08,  -1.24, -0.92, 0.49,  1.17,  -0.45, -1.31, -1.44, -0.13, -1.31,
2060       -0.79, 1.41,  1.21,  1.05});
2061  TF_ASSERT_OK_AND_ASSIGN(auto input_data,
2062                          client_->TransferToServer(*input_literal));
2063
2064  auto input = builder.Parameter(0, input_literal->shape(), "input");
2065  builder.Tanh(input);
2066
2067  ComputeAndCompareR1<float>(
2068      &builder,
2069      {0.77009583,  -0.30665702, 0.69070244,  0.71401149,  0.84400684,
2070       -0.71985596, -0.45764771, 0.66664988,  -0.58278900, 0.16050975,
2071       -0.06770509, 0.36843640,  -0.38476998, 0.04018109,  0.87562293,
2072       0.84788644,  0.38603750,  0.57294142,  -0.79140943, 0.31032649,
2073       -0.89590985, -0.64770776, -0.79625875, 0.72234446,  -0.77389336,
2074       -0.28871772, -0.80428445, -0.82541436, 0.90456349,  -0.68856895,
2075       0.03877772,  0.76877952,  0.32561871,  -0.54546672, 0.39072621,
2076       0.07273290,  -0.01924866, 0.88924897,  -0.55283129, 0.67183107,
2077       0.08006320,  0.66944766,  -0.29068485, 0.82573754,  -0.57170743,
2078       -0.41581789, 0.72739530,  0.85025692,  -0.85931867, 0.87357593,
2079       0.07782833,  -0.84597743, -0.72748238, 0.45396307,  0.82449573,
2080       -0.42462519, -0.86363792, -0.89368379, -0.12621804, -0.86445558,
2081       -0.65565848, 0.88789743,  0.83566397,  0.78287679},
2082      {input_data.get()},
2083      // The error spec is unusually high here to account for the fact that we
2084      // use a rational interpolant to approximate tanh.
2085      ErrorSpec(0.004, 0.004));
2086}
2087
2088XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
2089  // The input tensor is large enough to exercise the vectorized exp
2090  // implementation on XLA CPU.
2091  ComputationBuilder builder(client_, TestName());
2092
2093  // Just to help make sense of the scales here -- exp(89) saturates float32 and
2094  // exp(-10) is smaller than our error spec.
2095  std::unique_ptr<Literal> input_literal = Literal::CreateR1<float>(
2096      {1.02,   -0.32,  0.85,   0.9,    1.23,   -0.91,  -0.49, 0.8,    -1.31,
2097       -1.44,  -0.13,  -1.31,  -0.79,  1.41,   1.21,   1.05,  -195.6, -194.5,
2098       -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5,  -17.4,
2099       -16.3,  -15.2,  -14.1,  -13.0,  -11.9,  -10.8,  -9.7,  -8.6,   -7.5,
2100       -6.4,   -5.3,   -4.2,   -3.1,   -2.0,   -0.9,   0.2,   1.3,    2.4,
2101       3.5,    4.6,    5.7,    6.8,    7.9,    9.0,    10.1,  11.2,   12.3,
2102       13.4,   14.5,   15.6,   16.7,   17.8,   18.9,   20.0,  21.1,   22.2,
2103       23.3,   24.4,   25.5,   26.6,   27.7,   28.8,   29.9,  31.0,   32.1,
2104       68.4,   69.5,   70.6,   71.7,   72.8,   73.9,   75.0,  76.1,   77.2,
2105       78.3,   79.4,   80.5,   81.6,   82.7,   83.8,   84.9,  85.2,   86.3,
2106       86.4,   86.5,   87.6,   87.7,   87.8,   87.9});
2107  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
2108                          client_->TransferToServer(*input_literal));
2109
2110  auto input = builder.Parameter(0, input_literal->shape(), "input");
2111  builder.Exp(input);
2112
2113  std::vector<float> expected_result;
2114  int64 input_size = input_literal->shape().dimensions(0);
2115  expected_result.reserve(input_size);
2116  for (int64 i = 0; i < input_size; i++) {
2117    expected_result.push_back(std::exp(input_literal->Get<float>({i})));
2118  }
2119
2120  ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
2121                             error_spec_);
2122}
2123
2124XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
2125  // The input tensor is large enough to exercise the vectorized exp
2126  // implementation on XLA CPU.
2127  ComputationBuilder builder(client_, TestName());
2128
2129  std::unique_ptr<Literal> input_literal = Literal::CreateR1<float>(
2130      {-1.29,    -1.41,    -1.25,    -13.5,    -11.7,    -17.9,    -198,
2131       -167,     1.29,     1.41,     1.25,     13.5,     11.7,     17.9,
2132       198,      167,      1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04,  1.84e+04,
2133       1.74e+04, 1.89e+05, 1.9e+05,  1.93e+06, 1.98e+06, 1.65e+06, 1.97e+07,
2134       1.66e+07, 1e+07,    1.98e+08, 1.96e+08, 1.64e+09, 1.58e+09, 1.64e+09,
2135       1.44e+10, 1.5e+10,  1.99e+10, 1.17e+11, 1.08e+11, 1.08e+12, 1.38e+12,
2136       1.4e+12,  1.03e+13, 1.6e+13,  1.99e+13, 1.26e+14, 1.51e+14, 1.33e+15,
2137       1.41e+15, 1.63e+15, 1.39e+16, 1.21e+16, 1.27e+16, 1.28e+17, 1.62e+17,
2138       2e+18,    1.96e+18, 1.81e+18, 1.99e+19, 1.86e+19, 1.61e+19, 1.71e+20,
2139       1.47e+20, 1.83e+21, 1.33e+21, 1.3e+21,  1.35e+22, 1.84e+22, 1.02e+22,
2140       1.81e+23, 1.02e+23, 1.89e+24, 1.49e+24, 1.08e+24, 1.95e+25, 1.1e+25,
2141       1.62e+25, 1.2e+26,  1.41e+26, 1.93e+27, 1.66e+27, 1.62e+27, 1.05e+28,
2142       1.5e+28,  1.79e+28, 1.36e+29, 1.95e+29, 1.5e+30,  1.81e+30, 1.34e+30,
2143       1.7e+31,  1.44e+31, 1.1e+31,  1.4e+32,  1.67e+32, 1.96e+33, 1.11e+33,
2144       1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35});
2145  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
2146                          client_->TransferToServer(*input_literal));
2147
2148  auto input = builder.Parameter(0, input_literal->shape(), "input");
2149  builder.Log(input);
2150
2151  std::vector<float> expected_result;
2152  int64 input_size = input_literal->shape().dimensions(0);
2153  expected_result.reserve(input_size);
2154  for (int64 i = 0; i < input_size; i++) {
2155    expected_result.push_back(std::log(input_literal->Get<float>({i})));
2156  }
2157
2158  ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
2159                             error_spec_);
2160}
2161
2162XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) {
2163  // a ------ (add) --------- (add)
2164  //         /               /
2165  // b -----/               /
2166  // c---------------------/
2167  ComputationBuilder builder(client_, TestName());
2168
2169  auto a = builder.ConstantR1<float>({1.1f, 2.2f, 3.3f, 4.4f});
2170  auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});
2171  auto c = builder.ConstantR1<float>({-3.3f, -15.5f, -7.7f, -29.9f});
2172
2173  auto add = builder.Add(a, b);
2174  auto add2 = builder.Add(add, c);
2175
2176  ComputeAndCompareR1<float>(&builder, {-0.1f, -10.1f, -0.1f, -20.1f}, {},
2177                             error_spec_);
2178}
2179
2180XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) {
2181  // b ------ (add) --------- (add)
2182  //         /               /
2183  // c -----/               /
2184  // a---------------------/
2185  ComputationBuilder builder(client_, TestName());
2186
2187  auto a = builder.ConstantR1<float>({91.1f, 2.2f, 3.3f, 4.4f});
2188  auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});
2189  auto c = builder.ConstantR1<float>({-3.3f, -15.5f, -7.7f, -29.9f});
2190
2191  auto add = builder.Add(b, c);
2192  auto add2 = builder.Add(a, add);
2193
2194  ComputeAndCompareR1<float>(&builder, {89.9f, -10.1f, -0.1f, -20.1f}, {},
2195                             error_spec_);
2196}
2197
2198XLA_TEST_F(ArrayElementwiseOpTest, AddWithNeg) {
2199  // a ----- (neg) ----- (add)
2200  //                    /
2201  // b ----- (neg) ----/
2202  ComputationBuilder builder(client_, TestName());
2203
2204  auto a = builder.ConstantR1<float>({91.1f, 2.2f, 3.3f, 4.4f});
2205  auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});
2206
2207  auto neg_a = builder.Neg(a);
2208  auto neg_b = builder.Neg(b);
2209  auto result = builder.Add(neg_a, neg_b);
2210
2211  ComputeAndCompareR1<float>(&builder, {-93.2f, -5.4f, -7.6f, -9.8f}, {},
2212                             error_spec_);
2213}
2214
2215XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) {
2216  // a ------ (add) ------------\
2217  //         /                   \
2218  // b -----/                    (add)
2219  //                             /
2220  // c ------ (add) ------------/
2221  //         /
2222  // d -----/
2223  ComputationBuilder builder(client_, TestName());
2224
2225  auto a = builder.ConstantR1<float>({91.1f, 2.2f, 3.3f, 4.4f});
2226  auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});
2227  auto c = builder.ConstantR1<float>({-3.3f, -15.5f, -7.7f, -29.9f});
2228  auto d = builder.ConstantR1<float>({-19.0f, 10.0f, -40.0f, 20.2f});
2229
2230  auto add_ab = builder.Add(a, b);
2231  auto add_cd = builder.Add(c, d);
2232  auto add_all = builder.Add(add_ab, add_cd);
2233
2234  ComputeAndCompareR1<float>(&builder, {70.9f, -0.1f, -40.1f, 0.1f}, {},
2235                             error_spec_);
2236}
2237
2238XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) {
2239  ComputationBuilder builder(client_, TestName());
2240  auto a =
2241      builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
2242  auto b =
2243      builder.ConstantR2<float>({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}});
2244  auto add = builder.Add(a, b);
2245
2246  Array2D<float> expected_array(
2247      {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}});
2248  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
2249}
2250
2251XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) {
2252  // Add a scalar + matrix.
2253  ComputationBuilder builder(client_, TestName());
2254  auto a =
2255      builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
2256  auto scalar = builder.ConstantR0<float>(3.0f);
2257  auto add = builder.Add(scalar, a);
2258
2259  Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
2260  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
2261}
2262
2263XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) {
2264  // Add a matrix + scalar.
2265  ComputationBuilder builder(client_, TestName());
2266  auto a =
2267      builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
2268  auto scalar = builder.ConstantR0<float>(3.0f);
2269  auto add = builder.Add(a, scalar);
2270
2271  Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
2272  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
2273}
2274
2275XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) {
2276  // Test simple broadcasting of a R1F32 over R2F32. The vector's size matches
2277  // only dim 0 of the matrix.
2278  ComputationBuilder builder(client_, TestName());
2279  auto v = builder.ConstantR1<float>({20.0f, 40.0f, 60.0f});
2280  // clang-format off
2281  auto m = builder.ConstantR2<float>({
2282    {-2.5f, 3.14f, 1.0f},
2283    {2.25f, -10.0f, 3.33f}});
2284  // clang-format on
2285  auto add = builder.Add(v, m, /*broadcast_dimensions=*/{1});
2286  Array2D<float> expected_array(
2287      {{17.5f, 43.14f, 61.0f}, {22.25f, 30.0f, 63.33f}});
2288  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
2289}
2290
2291XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
2292  // Test broadcasting in Eq comparison.
2293  ComputationBuilder builder(client_, TestName());
2294  auto v = builder.ConstantR1<int32>({42, 73});
2295  auto m = builder.ConstantR2<int32>({{42, 73}, {42, 52}});
2296
2297  // This test exercises both possible broadcast dimensions for a vector/matrix
2298  // comparison.
2299  auto cmp_dim_0 = builder.Eq(v, m, /*broadcast_dimensions=*/{1});
2300  auto cmp_dim_1 = builder.Eq(v, m, /*broadcast_dimensions=*/{0});
2301  auto result = builder.Tuple({cmp_dim_0, cmp_dim_1});
2302
2303  auto expected = Literal::MakeTuple(
2304      {Literal::CreateR2<bool>({{true, true}, {true, false}}).get(),
2305       Literal::CreateR2<bool>({{true, false}, {false, false}}).get()});
2306  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
2307}
2308
2309XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
2310  // Test broadcasting in Ne comparison.
2311  ComputationBuilder builder(client_, TestName());
2312  auto v = builder.ConstantR1<int32>({42, 73});
2313  auto m = builder.ConstantR2<int32>({{42, 73}, {42, 52}});
2314  auto cmp = builder.Ne(v, m, /*broadcast_dimensions=*/{1});
2315
2316  const string expected = R"(pred[2,2] {
2317  { 00 },
2318  { 01 }
2319})";
2320  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
2321}
2322
2323XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) {
2324  // Test broadcasting in Ge comparison.
2325  ComputationBuilder builder(client_, TestName());
2326  auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
2327  auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
2328  auto cmp = builder.Ge(v, m, /*broadcast_dimensions=*/{1});
2329
2330  const string expected = R"(pred[2,4] {
2331  { 1100 },
2332  { 0001 }
2333})";
2334  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
2335}
2336
2337XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) {
2338  // Test broadcasting in Gt comparison.
2339  ComputationBuilder builder(client_, TestName());
2340  auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
2341  auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
2342  auto cmp = builder.Gt(v, m, /*broadcast_dimensions=*/{1});
2343
2344  const string expected = R"(pred[2,4] {
2345  { 0100 },
2346  { 0000 }
2347})";
2348  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
2349}
2350
2351XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) {
2352  // Test broadcasting in Le comparison.
2353  ComputationBuilder builder(client_, TestName());
2354  auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
2355  auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
2356  auto cmp = builder.Le(v, m, /*broadcast_dimensions=*/{1});
2357
2358  const string expected = R"(pred[2,4] {
2359  { 1011 },
2360  { 1111 }
2361})";
2362  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
2363}
2364
2365XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) {
2366  // Test broadcasting in Lt comparison.
2367  ComputationBuilder builder(client_, TestName());
2368  auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
2369  auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
2370  auto cmp = builder.Lt(v, m, /*broadcast_dimensions=*/{1});
2371
2372  const string expected = R"(pred[2,4] {
2373  { 0011 },
2374  { 1110 }
2375})";
2376  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
2377}
2378
2379XLA_TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) {
2380  // Test simple broadcasting of a R1F32 over R2F32 when the order of binary op
2381  // arguments is reversed.
2382  ComputationBuilder builder(client_, TestName());
2383  auto m = builder.ConstantR2<float>({{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}});
2384  auto v = builder.ConstantR1<float>({2.0f, 4.0f, 6.0f});
2385  auto add = builder.Mul(m, v, /*broadcast_dimensions=*/{1});
2386  Array2D<float> expected_array({{3.0f, 10.0f, 21.0f}, {9.0f, 22.0f, 39.0f}});
2387  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
2388}
2389
2390XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) {
2391  // Tests broadcasting for arrays with degenerate (size == 1) dimensions.
2392  ComputationBuilder builder(client_, TestName());
2393  // m's shape in XLA notation is {3, 2}
2394  // md's shape in XLA notation is {3, 1}
2395  // The result has shape {3, 2}, where md is broadcast over m
2396  auto m =
2397      builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
2398  auto md = builder.ConstantR2<float>({{10.0f, 20.0f, 30.0f}});
2399  auto add = builder.Add(m, md);
2400  Array2D<float> expected_array(
2401      {{7.5f, 23.14f, 31.0f}, {12.25f, 10.0f, 33.33f}});
2402  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
2403}
2404
2405XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim0) {
2406  // Tests broadcasting for arrays with degenerate (size == 1) dimensions.
2407  ComputationBuilder builder(client_, TestName());
2408  // m's shape in XLA notation is {3, 2}
2409  // md's shape in XLA notation is {1, 2}
2410  // The result has shape {3, 2}, where md is broadcast over m
2411  auto m =
2412      builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
2413  auto md = builder.ConstantR2<float>({{10.0f}, {20.0f}});
2414  auto add = builder.Add(m, md);
2415  Array2D<float> expected_array(
2416      {{7.5f, 13.14f, 11.0f}, {22.25f, 10.0f, 23.33f}});
2417  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
2418}
2419
2420XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) {
2421  // Tests broadcasting for two degenerate arrays. This kind of broadcasting
2422  // effectively creates an "outer product" operation.
2423  // This is taken from the Numpy docs example at:
2424  // http://docs.scipy.org/doc/numpy-1.10.1/user/basics.broadcasting.html
2425  ComputationBuilder builder(client_, TestName());
2426  // a's shape in XLA notation is {1, 4}
2427  // b's shape in XLA notation is {3, 1}
2428  // The result has shape {3, 4}.
2429  auto a = builder.ConstantR2<float>({{0.0f}, {10.0f}, {20.0f}, {30.0f}});
2430  auto b = builder.ConstantR2<float>({{1.0f, 2.0f, 3.0f}});
2431  auto add = builder.Add(a, b);
2432  Array2D<float> expected_array({{1.0f, 2.0f, 3.0f},
2433                                 {11.0f, 12.0f, 13.0f},
2434                                 {21.0f, 22.0f, 23.0f},
2435                                 {31.0f, 32.0f, 33.0f}});
2436  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
2437}
2438
2439XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) {
2440  // Add together a (2,2) array and a (2) array, using dimension 0 for
2441  // broadcasting (though there are two ways to broadcast these shapes).
2442  ComputationBuilder builder(client_, TestName());
2443  auto v = builder.ConstantR1<float>({20.0f, 40.0f});
2444  auto m = builder.ConstantR2<float>({{10.0f, 50.0f}, {77.0f, 88.0f}});
2445  auto add = builder.Add(v, m, /*broadcast_dimensions=*/{1});
2446  Array2D<float> expected_array({{30.0f, 90.0f}, {97.0f, 128.0f}});
2447  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
2448}
2449
2450XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) {
2451  // Add together a (2,2) array and a (2) array, using dimension 1 for
2452  // broadcasting (though there are two ways to broadcast these shapes).
2453  ComputationBuilder builder(client_, TestName());
2454  auto v = builder.ConstantR1<float>({20.0f, 40.0f});
2455  auto m = builder.ConstantR2<float>({{10.0f, 50.0f}, {77.0f, 88.0f}});
2456  auto add = builder.Add(v, m, /*broadcast_dimensions=*/{0});
2457  Array2D<float> expected_array({{30.0f, 70.0f}, {117.0f, 128.0f}});
2458  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
2459}
2460
2461XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) {
2462  // Binary add of two R3s together
2463  ComputationBuilder builder(client_, TestName());
2464  Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
2465                       {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
2466  auto a = builder.ConstantR3FromArray3D<float>(a_3d);
2467
2468  Array3D<float> b_3d({{{2.0f, 4.0f}, {6.0f, 8.0f}, {10.0f, 12.0f}},
2469                       {{14.0f, 16.0f}, {18.0f, 20.0f}, {22.0f, 24.0f}}});
2470  auto b = builder.ConstantR3FromArray3D<float>(b_3d);
2471  auto add = builder.Add(a, b);
2472
2473  Array3D<float> expected_3d(
2474      {{{3.0f, 6.0f}, {9.0f, 12.0f}, {15.0f, 18.0f}},
2475       {{21.0f, 24.0f}, {27.0f, 30.0f}, {33.0f, 36.0f}}});
2476  ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
2477}
2478
2479XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) {
2480  // Add together a (2, 3, 2) array with a (2) array, using dimension 0 for
2481  // broadcasting (though there are two ways to broadcast these shapes).
2482  ComputationBuilder builder(client_, TestName());
2483  // clang-format off
2484  Array3D<float> a_3d({
2485    {{1.0f, 2.0f},
2486     {3.0f, 4.0f},
2487     {5.0f, 6.0f}},
2488    {{7.0f, 8.0f},
2489     {9.0f, 10.0f},
2490     {11.0f, 12.0f}},
2491  });
2492  // clang-format on
2493  auto a = builder.ConstantR3FromArray3D<float>(a_3d);
2494  auto v = builder.ConstantR1<float>({10.0f, 20.0f});
2495  auto add = builder.Add(a, v, /*broadcast_dimensions=*/{2});
2496
2497  Array3D<float> expected_3d(
2498      {{{11.0f, 22.0f}, {13.0f, 24.0f}, {15.0f, 26.0f}},
2499       {{17.0f, 28.0f}, {19.0f, 30.0f}, {21.0f, 32.0f}}});
2500  ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
2501}
2502
2503XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) {
2504  // Add together a (2, 3, 2) array with a (2) array, using dimension 2 for
2505  // broadcasting (though there are two ways to broadcast these shapes).
2506  ComputationBuilder builder(client_, TestName());
2507  // clang-format off
2508  Array3D<float> a_3d({
2509    {{1.0f, 2.0f},
2510     {3.0f, 4.0f},
2511     {5.0f, 6.0f}},
2512    {{7.0f, 8.0f},
2513     {9.0f, 10.0f},
2514     {11.0f, 12.0f}},
2515  });
2516  // clang-format on
2517  auto a = builder.ConstantR3FromArray3D<float>(a_3d);
2518  auto v = builder.ConstantR1<float>({10.0f, 20.0f});
2519  auto add = builder.Add(a, v, /*broadcast_dimensions=*/{0});
2520
2521  // clang-format off
2522  Array3D<float> expected_3d({
2523    {{11.0f, 12.0f},
2524     {13.0f, 14.0f},
2525     {15.0f, 16.0f}},
2526    {{27.0f, 28.0f},
2527     {29.0f, 30.0f},
2528     {31.0f, 32.0f}},
2529  });
2530  // clang-format on
2531  ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
2532}
2533
2534XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) {
2535  // Add together a (2, 3, 2) array with a (3, 2) array, using dimensions {1,2}
2536  // for broadcasting.
2537  ComputationBuilder builder(client_, TestName());
2538  // clang-format off
2539  Array3D<float> a_3d({
2540    {{1.0f, 2.0f},
2541     {3.0f, 4.0f},
2542     {5.0f, 6.0f}},
2543    {{7.0f, 8.0f},
2544     {9.0f, 10.0f},
2545     {11.0f, 12.0f}},
2546  });
2547  auto a = builder.ConstantR3FromArray3D<float>(a_3d);
2548  auto m = builder.ConstantR2<float>({
2549    {10.0f, 20.0f, 30.0f},
2550    {40.0f, 50.0f, 60.0f},
2551  });
2552  auto add = builder.Add(a, m, /*broadcast_dimensions=*/{0, 1});
2553
2554  Array3D<float> expected_3d({
2555    {{11.0f, 12.0f},
2556     {23.0f, 24.0f},
2557     {35.0f, 36.0f}},
2558    {{47.0f, 48.0f},
2559     {59.0f, 60.0f},
2560     {71.0f, 72.0f}},
2561  });
2562  // clang-format on
2563  ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
2564}
2565
2566XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) {
2567  // Comparison between two 3D arrays of compatible shapes:
2568  // (2, 3, 2) and (2, 3, 1): expected to produce a (2, 3, 2) shape of PREDs.
2569  ComputationBuilder builder(client_, TestName());
2570  Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
2571                       {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
2572  auto a = builder.ConstantR3FromArray3D<float>(a_3d);
2573
2574  Array3D<float> b_3d({{{7.0f, 1.0f}, {3.0f, 10.0f}, {15.0f, 6.0f}}});
2575  auto b = builder.ConstantR3FromArray3D<float>(b_3d);
2576
2577  auto compare = builder.Gt(a, b);
2578
2579  Array3D<int> expected_3d(
2580      {{{0, 1}, {0, 0}, {0, 0}}, {{0, 1}, {1, 0}, {0, 1}}});
2581  const string expected = R"(pred[2,3,2] {
2582{ { 01 },
2583  { 00 },
2584  { 00 } },
2585{ { 01 },
2586  { 10 },
2587  { 01 } }
2588})";
2589  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
2590}
2591
2592XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) {
2593  ComputationBuilder builder(client_, TestName());
2594
2595  std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5));
2596  std::unique_ptr<Array4D<float>> operand_b_4d(new Array4D<float>(2, 3, 4, 5));
2597  std::unique_ptr<Array4D<float>> expected_4d(new Array4D<float>(2, 3, 4, 5));
2598  float value = 0.0;
2599  for (int64 p = 0; p < 2; ++p) {
2600    for (int64 z = 0; z < 3; ++z) {
2601      for (int64 y = 0; y < 4; ++y) {
2602        for (int64 x = 0; x < 5; ++x) {
2603          (*operand_a_4d)(p, z, y, x) = value;
2604          (*operand_b_4d)(p, z, y, x) = 2.0 * value;
2605          (*expected_4d)(p, z, y, x) = 3.0 * value;
2606          value += 0.1;
2607        }
2608      }
2609    }
2610  }
2611
2612  auto a = builder.ConstantR4FromArray4D<float>(*operand_a_4d);
2613  auto b = builder.ConstantR4FromArray4D<float>(*operand_b_4d);
2614  auto add = builder.Add(a, b);
2615
2616  ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_);
2617}
2618
2619XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) {
2620  ComputationBuilder builder(client_, TestName());
2621
2622  std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5));
2623  std::unique_ptr<Array4D<float>> expected_4d(new Array4D<float>(2, 3, 4, 5));
2624  std::vector<float> operand_b_1d(3);
2625  std::iota(operand_b_1d.begin(), operand_b_1d.end(), 1.0);
2626
2627  float value = 0.0;
2628  for (int64 p = 0; p < 2; ++p) {
2629    for (int64 z = 0; z < 3; ++z) {
2630      for (int64 y = 0; y < 4; ++y) {
2631        for (int64 x = 0; x < 5; ++x) {
2632          (*operand_a_4d)(p, z, y, x) = value;
2633          (*expected_4d)(p, z, y, x) = value + operand_b_1d[z];
2634          value += 0.1;
2635        }
2636      }
2637    }
2638  }
2639
2640  auto a = builder.ConstantR4FromArray4D<float>(*operand_a_4d);
2641  auto b = builder.ConstantR1<float>(operand_b_1d);
2642  auto add = builder.Add(a, b, {1});
2643
2644  ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_);
2645}
2646
2647XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
2648  constexpr int d0 = 16;
2649  constexpr int d1 = 16;
2650  constexpr int d2 = 2;
2651  constexpr int d3 = 2;
2652  Array4D<float> r4(d0, d1, d2, d3);
2653  r4.Fill(1.0);
2654  std::vector<float> r1(d1);
2655  std::iota(r1.begin(), r1.end(), 1.0);
2656
2657  ComputationBuilder builder(client_, TestName());
2658  std::unique_ptr<Literal> a_literal = Literal::CreateR4FromArray4DWithLayout(
2659      r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
2660  auto a = builder.ConstantLiteral(*a_literal);
2661  auto b = builder.ConstantR1<float>(r1);
2662  builder.Add(a, b, {1});
2663
2664  for (int i0 = 0; i0 < d0; ++i0) {
2665    for (int i1 = 0; i1 < d1; ++i1) {
2666      for (int i2 = 0; i2 < d2; ++i2) {
2667        for (int i3 = 0; i3 < d3; ++i3) {
2668          r4(i0, i1, i2, i3) += r1[i1];
2669        }
2670      }
2671    }
2672  }
2673  ComputeAndCompareR4<float>(&builder, r4, {}, error_spec_);
2674}
2675
2676// Show that we can't add two opaques.
2677XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) {
2678  ComputationBuilder builder(client_, TestName());
2679  auto shape = ShapeUtil::MakeOpaqueShape();
2680  auto x = builder.Parameter(0, shape, "x");
2681  auto concatenated = builder.Add(x, x);
2682  StatusOr<Computation> computation_status = builder.Build();
2683  ASSERT_FALSE(computation_status.ok());
2684  EXPECT_THAT(computation_status.status().ToString(),
2685              ::testing::ContainsRegex(
2686                  "Expected non-opaque argument for lhs of binary operation"));
2687}
2688
2689XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) {
2690  ComputationBuilder builder(client_, TestName());
2691  auto a =
2692      builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
2693  auto b =
2694      builder.ConstantR2<float>({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}});
2695  auto add = builder.Add(a, b, /*broadcast_dimensions=*/{0, 1});
2696
2697  Array2D<float> expected_array(
2698      {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}});
2699  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
2700}
2701
2702XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) {
2703  ComputationBuilder builder(client_, TestName());
2704  auto a =
2705      builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
2706  auto b =
2707      builder.ConstantR2<float>({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}});
2708  auto add = builder.Add(a, b, /*broadcast_dimensions=*/{1, 0});
2709
2710  StatusOr<Computation> computation_status = builder.Build();
2711  ASSERT_FALSE(computation_status.ok());
2712  EXPECT_THAT(computation_status.status().error_message(),
2713              ::testing::ContainsRegex("must.*be the identity"));
2714}
2715
2716// Regression test for b/31927799. "slice - y" is fused and requires implicit
2717// broadcast.
2718XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) {
2719  ComputationBuilder builder(client_, TestName());
2720  auto x_literal = Literal::CreateR1<float>({1, 2, 3});
2721  auto y_literal = Literal::CreateR1<float>({4, 5});
2722  auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
2723  auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
2724
2725  auto x = builder.Parameter(0, x_literal->shape(), "x");
2726  auto y = builder.Parameter(1, y_literal->shape(), "y");
2727  auto slice = builder.Slice(x, {1}, {2}, {1});
2728  builder.Sub(slice, y);
2729
2730  ComputeAndCompareR1<float>(&builder, {-2, -3}, {x_data.get(), y_data.get()},
2731                             error_spec_);
2732}
2733
2734INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount,
2735                        ArrayElementwiseOpTestParamCount,
2736                        ::testing::Values(127, 128, 129, 17 * 4096));
2737
2738}  // namespace
2739}  // namespace xla
2740