1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// Miscellaneous tests with the PRED type that don't fit anywhere else.
17#include <memory>
18
19#include "tensorflow/compiler/xla/array2d.h"
20#include "tensorflow/compiler/xla/client/computation_builder.h"
21#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
22#include "tensorflow/compiler/xla/client/local_client.h"
23#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
24#include "tensorflow/core/lib/core/status_test_util.h"
25#include "tensorflow/core/platform/test.h"
26
27namespace xla {
28namespace {
29
30class PredTest : public ClientLibraryTestBase {
31 protected:
32  void TestCompare(bool lhs, bool rhs, bool expected,
33                   ComputationDataHandle (ComputationBuilder::*op)(
34                       const ComputationDataHandle&,
35                       const ComputationDataHandle&,
36                       tensorflow::gtl::ArraySlice<int64>)) {
37    ComputationBuilder builder(client_, TestName());
38    ComputationDataHandle lhs_op = builder.ConstantR0<bool>(lhs);
39    ComputationDataHandle rhs_op = builder.ConstantR0<bool>(rhs);
40    ComputationDataHandle result = (builder.*op)(lhs_op, rhs_op, {});
41    ComputeAndCompareR0<bool>(&builder, expected, {});
42  }
43};
44
45TEST_F(PredTest, ConstantR0PredTrue) {
46  ComputationBuilder builder(client_, TestName());
47  auto a = builder.ConstantR0<bool>(true);
48  ComputeAndCompareR0<bool>(&builder, true, {});
49}
50
51TEST_F(PredTest, ConstantR0PredFalse) {
52  ComputationBuilder builder(client_, TestName());
53  auto a = builder.ConstantR0<bool>(false);
54  ComputeAndCompareR0<bool>(&builder, false, {});
55}
56
57TEST_F(PredTest, ConstantR0PredCompareEq) {
58  TestCompare(true, false, false, &ComputationBuilder::Eq);
59}
60
61TEST_F(PredTest, ConstantR0PredCompareNe) {
62  TestCompare(true, false, true, &ComputationBuilder::Ne);
63}
64
65TEST_F(PredTest, ConstantR0PredCompareLe) {
66  TestCompare(true, false, false, &ComputationBuilder::Le);
67}
68
69TEST_F(PredTest, ConstantR0PredCompareLt) {
70  TestCompare(true, false, false, &ComputationBuilder::Lt);
71}
72
73TEST_F(PredTest, ConstantR0PredCompareGe) {
74  TestCompare(true, false, true, &ComputationBuilder::Ge);
75}
76
77TEST_F(PredTest, ConstantR0PredCompareGt) {
78  TestCompare(true, false, true, &ComputationBuilder::Gt);
79}
80
81TEST_F(PredTest, ConstantR1Pred) {
82  ComputationBuilder builder(client_, TestName());
83  auto a = builder.ConstantR1<bool>({true, false, false, true});
84  ComputeAndCompareR1<bool>(&builder, {true, false, false, true}, {});
85}
86
87TEST_F(PredTest, ConstantR2Pred) {
88  ComputationBuilder builder(client_, TestName());
89  auto a =
90      builder.ConstantR2<bool>({{false, true, true}, {true, false, false}});
91  const string expected = R"(pred[2,3] {
92  { 011 },
93  { 100 }
94})";
95  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
96}
97
98TEST_F(PredTest, AnyR1True) {
99  ComputationBuilder builder(client_, TestName());
100  auto a = builder.ConstantR1<bool>({true, false});
101  TF_ASSERT_OK(Any(a, &builder).status());
102  ComputeAndCompareR0<bool>(&builder, true, {});
103}
104
105TEST_F(PredTest, AnyR1False) {
106  ComputationBuilder builder(client_, TestName());
107  auto a = builder.ConstantR1<bool>({false, false});
108  TF_ASSERT_OK(Any(a, &builder).status());
109  ComputeAndCompareR0<bool>(&builder, false, {});
110}
111
112TEST_F(PredTest, AnyR1VacuouslyFalse) {
113  ComputationBuilder builder(client_, TestName());
114  auto a = builder.ConstantR1<bool>({});
115  TF_ASSERT_OK(Any(a, &builder).status());
116  ComputeAndCompareR0<bool>(&builder, false, {});
117}
118
119TEST_F(PredTest, AnyR2True) {
120  ComputationBuilder builder(client_, TestName());
121  auto a = builder.ConstantR2<bool>({
122      {false, false, false},
123      {false, false, false},
124      {false, false, true},
125  });
126  TF_ASSERT_OK(Any(a, &builder).status());
127  ComputeAndCompareR0<bool>(&builder, true, {});
128}
129
130TEST_F(PredTest, AnyR2False) {
131  ComputationBuilder builder(client_, TestName());
132  auto a = builder.ConstantR2<bool>({
133      {false, false, false},
134      {false, false, false},
135      {false, false, false},
136  });
137  TF_ASSERT_OK(Any(a, &builder).status());
138  ComputeAndCompareR0<bool>(&builder, false, {});
139}
140
141}  // namespace
142}  // namespace xla
143