1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include <initializer_list> 17#include <memory> 18#include <string> 19 20#include "tensorflow/compiler/xla/client/computation.h" 21#include "tensorflow/compiler/xla/client/computation_builder.h" 22#include "tensorflow/compiler/xla/client/global_data.h" 23#include "tensorflow/compiler/xla/client/local_client.h" 24#include "tensorflow/compiler/xla/literal_util.h" 25#include "tensorflow/compiler/xla/shape_util.h" 26#include "tensorflow/compiler/xla/statusor.h" 27#include "tensorflow/compiler/xla/tests/client_library_test_base.h" 28#include "tensorflow/compiler/xla/tests/literal_test_util.h" 29#include "tensorflow/compiler/xla/tests/test_macros.h" 30#include "tensorflow/compiler/xla/tests/test_utils.h" 31#include "tensorflow/compiler/xla/xla.pb.h" 32#include "tensorflow/compiler/xla/xla_data.pb.h" 33#include "tensorflow/core/lib/gtl/array_slice.h" 34#include "tensorflow/core/platform/test.h" 35 36namespace xla { 37namespace { 38 39class CompilationCacheTest : public ClientLibraryTestBase { 40 public: 41 void ExecuteComputationR0F32( 42 const Computation& computation, 43 tensorflow::gtl::ArraySlice<GlobalData*> arguments, float expected_result, 44 bool expect_cache_hit) { 45 ExecutionProfile execution_profile; 46 std::unique_ptr<Literal> result = 47 client_ 48 ->ExecuteAndTransfer(computation, arguments, 49 /*execution_options=*/&execution_options_, 50 &execution_profile) 51 .ConsumeValueOrDie(); 52 LiteralTestUtil::ExpectNear(*Literal::CreateR0<float>(expected_result), 53 *result, error_spec_); 54 EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); 55 } 56 57 void ExecuteComputationR2F32( 58 const Computation& computation, 59 tensorflow::gtl::ArraySlice<GlobalData*> arguments, 60 std::initializer_list<std::initializer_list<float>> expected_result, 61 bool expect_cache_hit) { 62 ExecutionProfile execution_profile; 63 auto data_handle = client_ 64 ->Execute(computation, arguments, 65 &execution_options_, &execution_profile) 66 .ConsumeValueOrDie(); 67 std::unique_ptr<Literal> result = 68 client_->Transfer(*data_handle).ConsumeValueOrDie(); 69 LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>(expected_result), 70 *result, error_spec_); 71 EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); 72 } 73 74 ErrorSpec error_spec_{0.0001}; 75}; 76 77XLA_TEST_F(CompilationCacheTest, ComputationCalledMultipleTimes) { 78 ComputationBuilder builder(client_, TestName()); 79 builder.Neg(builder.ConstantR0<float>(42.0)); 80 Computation computation = builder.Build().ConsumeValueOrDie(); 81 82 ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false); 83 ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true); 84 ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true); 85} 86 87XLA_TEST_F(CompilationCacheTest, ComputationCalledWithDifferentParameters) { 88 std::unique_ptr<GlobalData> data_42 = 89 client_->TransferToServer(*Literal::CreateR0<float>(42.0f)) 90 .ConsumeValueOrDie(); 91 std::unique_ptr<GlobalData> data_123 = 92 client_->TransferToServer(*Literal::CreateR0<float>(123.0f)) 93 .ConsumeValueOrDie(); 94 std::unique_ptr<GlobalData> data_456 = 95 client_->TransferToServer(*Literal::CreateR0<float>(456.0f)) 96 .ConsumeValueOrDie(); 97 98 ComputationBuilder builder(client_, TestName()); 99 builder.Neg(builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param")); 100 Computation computation = builder.Build().ConsumeValueOrDie(); 101 102 ExecuteComputationR0F32(computation, {data_42.get()}, -42.0, 103 /*expect_cache_hit=*/false); 104 ExecuteComputationR0F32(computation, {data_123.get()}, -123.0, 105 /*expect_cache_hit=*/true); 106 ExecuteComputationR0F32(computation, {data_456.get()}, -456.0, 107 /*expect_cache_hit=*/true); 108 ExecuteComputationR0F32(computation, {data_42.get()}, -42.0, 109 /*expect_cache_hit=*/true); 110} 111 112XLA_TEST_F(CompilationCacheTest, MultipleComputations) { 113 ComputationBuilder builder_neg(client_, TestName() + "_neg"); 114 builder_neg.Neg(builder_neg.ConstantR0<float>(42.0)); 115 Computation computation_neg = builder_neg.Build().ConsumeValueOrDie(); 116 117 ComputationBuilder builder_exp(client_, TestName() + "_exp"); 118 builder_exp.Exp(builder_exp.ConstantR0<float>(1.0)); 119 Computation computation_exp = builder_exp.Build().ConsumeValueOrDie(); 120 121 ComputationBuilder builder_add(client_, TestName() + "_add"); 122 builder_add.Add(builder_add.ConstantR0<float>(2.0), 123 builder_add.ConstantR0<float>(3.0)); 124 Computation computation_add = builder_add.Build().ConsumeValueOrDie(); 125 126 ExecuteComputationR0F32(computation_neg, {}, -42.0, 127 /*expect_cache_hit=*/false); 128 ExecuteComputationR0F32(computation_exp, {}, 2.7182817, 129 /*expect_cache_hit=*/false); 130 ExecuteComputationR0F32(computation_add, {}, 5.0, 131 /*expect_cache_hit=*/false); 132 ExecuteComputationR0F32(computation_neg, {}, -42.0, 133 /*expect_cache_hit=*/true); 134} 135 136XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) { 137 // Create two GlobalData arrays with the same shape but different 138 // layouts. Use these arrays as parameters to a simple computation. If the 139 // layout of the array changes then computation should be recompiled (cache 140 // miss). 141 auto rowmaj_array = Literal::CreateR2WithLayout( 142 {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0})); 143 auto rowmaj_handle = 144 client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie(); 145 146 auto colmaj_array = Literal::CreateR2WithLayout( 147 {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})); 148 auto colmaj_handle = 149 client_->TransferToServer(*colmaj_array).ConsumeValueOrDie(); 150 151 ComputationBuilder builder(client_, TestName()); 152 builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"); 153 Computation computation = builder.Build().ConsumeValueOrDie(); 154 155 ExecuteComputationR2F32(computation, {colmaj_handle.get()}, 156 {{1.0f, 2.0f}, {3.0f, 4.0f}}, 157 /*expect_cache_hit=*/false); 158 ExecuteComputationR2F32(computation, {colmaj_handle.get()}, 159 {{1.0f, 2.0f}, {3.0f, 4.0f}}, 160 /*expect_cache_hit=*/true); 161 ExecuteComputationR2F32(computation, {rowmaj_handle.get()}, 162 {{1.0f, 2.0f}, {3.0f, 4.0f}}, 163 /*expect_cache_hit=*/false); 164 ExecuteComputationR2F32(computation, {rowmaj_handle.get()}, 165 {{1.0f, 2.0f}, {3.0f, 4.0f}}, 166 /*expect_cache_hit=*/true); 167 ExecuteComputationR2F32(computation, {colmaj_handle.get()}, 168 {{1.0f, 2.0f}, {3.0f, 4.0f}}, 169 /*expect_cache_hit=*/true); 170} 171 172XLA_TEST_F(CompilationCacheTest, MutatedComputation) { 173 // Build a computation, execute it, then mutate it. The mutated computation 174 // should not be in the cache until it is run once. This must be done through 175 // the stub interface because Computations built from ComputationBuilder are 176 // immutable. 177 ComputationBuilder builder(client_, TestName()); 178 auto neg = builder.Neg(builder.ConstantR0<float>(42.0)); 179 Computation computation = builder.Build().ConsumeValueOrDie(); 180 181 ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false); 182 ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true); 183 184 BinaryOpRequest request; 185 request.set_binop(BINOP_ADD); 186 *request.mutable_lhs() = neg; 187 *request.mutable_rhs() = neg; 188 OpRequest op_request; 189 *op_request.mutable_computation() = computation.handle(); 190 *op_request.mutable_binary_op_request() = request; 191 OpResponse response; 192 tensorflow::Status s = client_->stub()->Op(&op_request, &response); 193 ASSERT_TRUE(s.ok()); 194 195 ExecuteComputationR0F32(computation, {}, -84.0, /*expect_cache_hit=*/false); 196 ExecuteComputationR0F32(computation, {}, -84.0, /*expect_cache_hit=*/true); 197} 198 199} // namespace 200} // namespace xla 201