1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <memory>
17
18#include "tensorflow/compiler/xla/client/computation.h"
19#include "tensorflow/compiler/xla/client/computation_builder.h"
20#include "tensorflow/compiler/xla/client/global_data.h"
21#include "tensorflow/compiler/xla/client/local_client.h"
22#include "tensorflow/compiler/xla/literal_util.h"
23#include "tensorflow/compiler/xla/protobuf_util.h"
24#include "tensorflow/compiler/xla/service/session.pb.h"
25#include "tensorflow/compiler/xla/shape_util.h"
26#include "tensorflow/compiler/xla/statusor.h"
27#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
28#include "tensorflow/compiler/xla/tests/literal_test_util.h"
29#include "tensorflow/compiler/xla/tests/test_macros.h"
30#include "tensorflow/compiler/xla/xla_data.pb.h"
31#include "tensorflow/core/platform/test.h"
32#include "tensorflow/core/platform/types.h"
33
34namespace xla {
35namespace {
36
37class ReplayTest : public ClientLibraryTestBase {};
38
39TEST_F(ReplayTest, TwoPlusTwoReplay) {
40  // Make 2+2 computation.
41  ComputationBuilder builder(client_, TestName());
42  auto two = builder.ConstantR0<int32>(2);
43  builder.Add(two, two);
44  Computation computation = builder.Build().ConsumeValueOrDie();
45
46  // Serialize it out.
47  std::unique_ptr<SessionModule> module =
48      computation.Snapshot().ConsumeValueOrDie();
49
50  // Replay it.
51  Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie();
52
53  // Check signature is the same.
54  std::unique_ptr<ProgramShape> original_shape =
55      client_->GetComputationShape(computation).ConsumeValueOrDie();
56  std::unique_ptr<ProgramShape> replayed_shape =
57      client_->GetComputationShape(replayed).ConsumeValueOrDie();
58  ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
59
60  // Run it.
61  std::unique_ptr<Literal> literal =
62      client_
63          ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_)
64          .ConsumeValueOrDie();
65
66  // Expect 4.
67  LiteralTestUtil::ExpectR0Equal<int32>(4, *literal);
68}
69
70XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
71  // Make computation.
72  ComputationBuilder builder(client_, TestName());
73  auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x");
74  auto y = builder.Parameter(1, ShapeUtil::MakeShape(S32, {}), "y");
75  builder.Add(x, y);
76  Computation computation = builder.Build().ConsumeValueOrDie();
77
78  // Serialize it out.
79  std::unique_ptr<SessionModule> module =
80      computation.Snapshot().ConsumeValueOrDie();
81
82  // Replay it.
83  Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie();
84
85  // Check signature is the same.
86  std::unique_ptr<ProgramShape> original_shape =
87      client_->GetComputationShape(computation).ConsumeValueOrDie();
88  std::unique_ptr<ProgramShape> replayed_shape =
89      client_->GetComputationShape(replayed).ConsumeValueOrDie();
90  ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
91
92  // Run it.
93  std::unique_ptr<GlobalData> x_data =
94      client_->TransferToServer(*Literal::CreateR0<int32>(2))
95          .ConsumeValueOrDie();
96  std::unique_ptr<GlobalData> y_data =
97      client_->TransferToServer(*Literal::CreateR0<int32>(3))
98          .ConsumeValueOrDie();
99  std::unique_ptr<Literal> literal =
100      client_
101          ->ExecuteAndTransfer(replayed,
102                               /*arguments=*/{x_data.get(), y_data.get()},
103                               &execution_options_)
104          .ConsumeValueOrDie();
105
106  // Expect 5.
107  LiteralTestUtil::ExpectR0Equal<int32>(5, *literal);
108}
109
110TEST_F(ReplayTest, MapPlusTwoOverR1) {
111  // As above, but with map(+2) over some constant array.
112  ComputationBuilder plus_two_builder(client_, "plus two");
113  auto input =
114      plus_two_builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "input");
115  plus_two_builder.Add(input, plus_two_builder.ConstantR0<int32>(2));
116  Computation plus_two = plus_two_builder.Build().ConsumeValueOrDie();
117
118  ComputationBuilder mapper_builder(client_, TestName());
119  auto original = mapper_builder.ConstantR1<int32>({1, 2, 3});
120  mapper_builder.Map({original}, plus_two, {0});
121
122  Computation computation = mapper_builder.Build().ConsumeValueOrDie();
123
124  // Serialize it out.
125  std::unique_ptr<SessionModule> module =
126      computation.Snapshot().ConsumeValueOrDie();
127
128  // Replay it.
129  Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie();
130
131  // Check signature is the same.
132  std::unique_ptr<ProgramShape> original_shape =
133      client_->GetComputationShape(computation).ConsumeValueOrDie();
134  std::unique_ptr<ProgramShape> replayed_shape =
135      client_->GetComputationShape(replayed).ConsumeValueOrDie();
136  ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
137
138  // Destroy the originals.
139  computation.Reset();
140  plus_two.Reset();
141
142  // Run it.
143  std::unique_ptr<Literal> literal =
144      client_
145          ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_)
146          .ConsumeValueOrDie();
147
148  // Expect result.
149  LiteralTestUtil::ExpectR1Equal<int32>({3, 4, 5}, *literal);
150}
151
152}  // namespace
153}  // namespace xla
154