1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <initializer_list>
17#include <memory>
18#include <string>
19
20#include "tensorflow/compiler/xla/client/computation.h"
21#include "tensorflow/compiler/xla/client/computation_builder.h"
22#include "tensorflow/compiler/xla/client/global_data.h"
23#include "tensorflow/compiler/xla/client/local_client.h"
24#include "tensorflow/compiler/xla/literal_util.h"
25#include "tensorflow/compiler/xla/shape_util.h"
26#include "tensorflow/compiler/xla/statusor.h"
27#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
28#include "tensorflow/compiler/xla/tests/literal_test_util.h"
29#include "tensorflow/compiler/xla/tests/test_macros.h"
30#include "tensorflow/compiler/xla/tests/test_utils.h"
31#include "tensorflow/compiler/xla/xla.pb.h"
32#include "tensorflow/compiler/xla/xla_data.pb.h"
33#include "tensorflow/core/lib/gtl/array_slice.h"
34#include "tensorflow/core/platform/test.h"
35
36namespace xla {
37namespace {
38
39class CompilationCacheTest : public ClientLibraryTestBase {
40 public:
41  void ExecuteComputationR0F32(
42      const Computation& computation,
43      tensorflow::gtl::ArraySlice<GlobalData*> arguments, float expected_result,
44      bool expect_cache_hit) {
45    ExecutionProfile execution_profile;
46    std::unique_ptr<Literal> result =
47        client_
48            ->ExecuteAndTransfer(computation, arguments,
49                                 /*execution_options=*/&execution_options_,
50                                 &execution_profile)
51            .ConsumeValueOrDie();
52    LiteralTestUtil::ExpectNear(*Literal::CreateR0<float>(expected_result),
53                                *result, error_spec_);
54    EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
55  }
56
57  void ExecuteComputationR2F32(
58      const Computation& computation,
59      tensorflow::gtl::ArraySlice<GlobalData*> arguments,
60      std::initializer_list<std::initializer_list<float>> expected_result,
61      bool expect_cache_hit) {
62    ExecutionProfile execution_profile;
63    auto data_handle = client_
64                           ->Execute(computation, arguments,
65                                     &execution_options_, &execution_profile)
66                           .ConsumeValueOrDie();
67    std::unique_ptr<Literal> result =
68        client_->Transfer(*data_handle).ConsumeValueOrDie();
69    LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>(expected_result),
70                                *result, error_spec_);
71    EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
72  }
73
74  ErrorSpec error_spec_{0.0001};
75};
76
77XLA_TEST_F(CompilationCacheTest, ComputationCalledMultipleTimes) {
78  ComputationBuilder builder(client_, TestName());
79  builder.Neg(builder.ConstantR0<float>(42.0));
80  Computation computation = builder.Build().ConsumeValueOrDie();
81
82  ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false);
83  ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true);
84  ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true);
85}
86
87XLA_TEST_F(CompilationCacheTest, ComputationCalledWithDifferentParameters) {
88  std::unique_ptr<GlobalData> data_42 =
89      client_->TransferToServer(*Literal::CreateR0<float>(42.0f))
90          .ConsumeValueOrDie();
91  std::unique_ptr<GlobalData> data_123 =
92      client_->TransferToServer(*Literal::CreateR0<float>(123.0f))
93          .ConsumeValueOrDie();
94  std::unique_ptr<GlobalData> data_456 =
95      client_->TransferToServer(*Literal::CreateR0<float>(456.0f))
96          .ConsumeValueOrDie();
97
98  ComputationBuilder builder(client_, TestName());
99  builder.Neg(builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"));
100  Computation computation = builder.Build().ConsumeValueOrDie();
101
102  ExecuteComputationR0F32(computation, {data_42.get()}, -42.0,
103                          /*expect_cache_hit=*/false);
104  ExecuteComputationR0F32(computation, {data_123.get()}, -123.0,
105                          /*expect_cache_hit=*/true);
106  ExecuteComputationR0F32(computation, {data_456.get()}, -456.0,
107                          /*expect_cache_hit=*/true);
108  ExecuteComputationR0F32(computation, {data_42.get()}, -42.0,
109                          /*expect_cache_hit=*/true);
110}
111
112XLA_TEST_F(CompilationCacheTest, MultipleComputations) {
113  ComputationBuilder builder_neg(client_, TestName() + "_neg");
114  builder_neg.Neg(builder_neg.ConstantR0<float>(42.0));
115  Computation computation_neg = builder_neg.Build().ConsumeValueOrDie();
116
117  ComputationBuilder builder_exp(client_, TestName() + "_exp");
118  builder_exp.Exp(builder_exp.ConstantR0<float>(1.0));
119  Computation computation_exp = builder_exp.Build().ConsumeValueOrDie();
120
121  ComputationBuilder builder_add(client_, TestName() + "_add");
122  builder_add.Add(builder_add.ConstantR0<float>(2.0),
123                  builder_add.ConstantR0<float>(3.0));
124  Computation computation_add = builder_add.Build().ConsumeValueOrDie();
125
126  ExecuteComputationR0F32(computation_neg, {}, -42.0,
127                          /*expect_cache_hit=*/false);
128  ExecuteComputationR0F32(computation_exp, {}, 2.7182817,
129                          /*expect_cache_hit=*/false);
130  ExecuteComputationR0F32(computation_add, {}, 5.0,
131                          /*expect_cache_hit=*/false);
132  ExecuteComputationR0F32(computation_neg, {}, -42.0,
133                          /*expect_cache_hit=*/true);
134}
135
136XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) {
137  // Create two GlobalData arrays with the same shape but different
138  // layouts. Use these arrays as parameters to a simple computation. If the
139  // layout of the array changes then computation should be recompiled (cache
140  // miss).
141  auto rowmaj_array = Literal::CreateR2WithLayout(
142      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0}));
143  auto rowmaj_handle =
144      client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie();
145
146  auto colmaj_array = Literal::CreateR2WithLayout(
147      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}));
148  auto colmaj_handle =
149      client_->TransferToServer(*colmaj_array).ConsumeValueOrDie();
150
151  ComputationBuilder builder(client_, TestName());
152  builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "param0");
153  Computation computation = builder.Build().ConsumeValueOrDie();
154
155  ExecuteComputationR2F32(computation, {colmaj_handle.get()},
156                          {{1.0f, 2.0f}, {3.0f, 4.0f}},
157                          /*expect_cache_hit=*/false);
158  ExecuteComputationR2F32(computation, {colmaj_handle.get()},
159                          {{1.0f, 2.0f}, {3.0f, 4.0f}},
160                          /*expect_cache_hit=*/true);
161  ExecuteComputationR2F32(computation, {rowmaj_handle.get()},
162                          {{1.0f, 2.0f}, {3.0f, 4.0f}},
163                          /*expect_cache_hit=*/false);
164  ExecuteComputationR2F32(computation, {rowmaj_handle.get()},
165                          {{1.0f, 2.0f}, {3.0f, 4.0f}},
166                          /*expect_cache_hit=*/true);
167  ExecuteComputationR2F32(computation, {colmaj_handle.get()},
168                          {{1.0f, 2.0f}, {3.0f, 4.0f}},
169                          /*expect_cache_hit=*/true);
170}
171
172XLA_TEST_F(CompilationCacheTest, MutatedComputation) {
173  // Build a computation, execute it, then mutate it. The mutated computation
174  // should not be in the cache until it is run once. This must be done through
175  // the stub interface because Computations built from ComputationBuilder are
176  // immutable.
177  ComputationBuilder builder(client_, TestName());
178  auto neg = builder.Neg(builder.ConstantR0<float>(42.0));
179  Computation computation = builder.Build().ConsumeValueOrDie();
180
181  ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false);
182  ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true);
183
184  BinaryOpRequest request;
185  request.set_binop(BINOP_ADD);
186  *request.mutable_lhs() = neg;
187  *request.mutable_rhs() = neg;
188  OpRequest op_request;
189  *op_request.mutable_computation() = computation.handle();
190  *op_request.mutable_binary_op_request() = request;
191  OpResponse response;
192  tensorflow::Status s = client_->stub()->Op(&op_request, &response);
193  ASSERT_TRUE(s.ok());
194
195  ExecuteComputationR0F32(computation, {}, -84.0, /*expect_cache_hit=*/false);
196  ExecuteComputationR0F32(computation, {}, -84.0, /*expect_cache_hit=*/true);
197}
198
199}  // namespace
200}  // namespace xla
201