1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h"
17
18#include <memory>
19
20#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
21#include "tensorflow/compiler/xla/shape_util.h"
22#include "tensorflow/core/lib/core/status_test_util.h"
23#include "tensorflow/core/lib/core/threadpool.h"
24#include "tensorflow/core/platform/env.h"
25#include "tensorflow/core/platform/logging.h"
26#include "tensorflow/core/platform/test.h"
27
28namespace xla {
29namespace {
30
31class InfeedManagerTest : public ::testing::Test {};
32
33class TestInfeedBuffer : public cpu::runtime::XfeedBuffer {
34 public:
35  explicit TestInfeedBuffer(int32 length, bool expect_shape_match = true)
36      : shape_(ShapeUtil::MakeShape(U8, {length})),
37        done_called_(false),
38        length_(length),
39        expect_shape_match_(expect_shape_match) {}
40  ~TestInfeedBuffer() override { EXPECT_TRUE(done_called_); }
41
42  int32 length() override { return length_; }
43  void* data() override { return nullptr; }
44  void Done(StatusOr<Shape> shape) override {
45    CHECK(!done_called_);
46    done_called_ = true;
47    TF_ASSERT_OK(shape.status());
48    EXPECT_EQ(expect_shape_match_, ShapeUtil::Equal(shape_, shape.ValueOrDie()))
49        << "want " << ShapeUtil::HumanString(shape_) << " "
50        << (expect_shape_match_ ? "==" : "!=") << " "
51        << ShapeUtil::HumanString(shape.ValueOrDie());
52  }
53
54  const Shape& shape() const { return shape_; }
55
56 private:
57  Shape shape_;
58  bool done_called_;
59  int32 length_;
60  bool expect_shape_match_;
61};
62
63// Performs the acquire/release sequence on the infeed, as the generated CPU
64// code would in the process of executing the infeed operation.
65void ProcessNextBuffer(int32 length) {
66  auto shape = ShapeUtil::MakeShape(U8, {length});
67  string bytes = shape.SerializeAsString();
68  void* buffer = __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
69      length, bytes.data(), bytes.size());
70  __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(length, buffer,
71                                                    bytes.data(), bytes.size());
72}
73
74// Performs the acquire/release sequence on the outfeed, as the generated CPU
75// code would in the process of executing the outfeed operation.
76void ProcessNextOutfeedBuffer(int32 length, const Shape& shape) {
77  string bytes = shape.SerializeAsString();
78  void* buffer = __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
79      length, bytes.data(), bytes.size());
80  __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
81      length, buffer, bytes.data(), bytes.size());
82}
83
84TEST_F(InfeedManagerTest, SingleThreadedSequential) {
85  TestInfeedBuffer* a = new TestInfeedBuffer(64);
86  TestInfeedBuffer* b = new TestInfeedBuffer(32);
87
88  cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager();
89
90  xfeed->infeed()->EnqueueBuffersAtomically({a});
91  xfeed->infeed()->EnqueueBuffersAtomically({b});
92  ProcessNextBuffer(a->length());
93  ProcessNextBuffer(b->length());
94}
95
96TEST_F(InfeedManagerTest, SingleThreadedInterleaved) {
97  TestInfeedBuffer* a = new TestInfeedBuffer(64);
98  TestInfeedBuffer* b = new TestInfeedBuffer(32);
99
100  cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager();
101
102  xfeed->infeed()->EnqueueBuffersAtomically({a});
103  ProcessNextBuffer(a->length());
104  xfeed->infeed()->EnqueueBuffersAtomically({b});
105  ProcessNextBuffer(b->length());
106}
107
108TEST_F(InfeedManagerTest, MultiThreaded) {
109  tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "test", 2);
110
111  cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager();
112
113  const int32 length = 64;
114
115  pool.Schedule([xfeed]() {
116    // Spin for 100 milliseconds
117    int64 start_micros = tensorflow::Env::Default()->NowMicros();
118    while (true) {
119      int64 end_micros = tensorflow::Env::Default()->NowMicros();
120      if ((end_micros - start_micros) >= 100000) {  // 100 ms
121        break;
122      }
123    }
124    TestInfeedBuffer* a = new TestInfeedBuffer(length);
125    xfeed->infeed()->EnqueueBuffersAtomically({a});
126  });
127
128  ProcessNextBuffer(length);
129}
130
131TEST_F(InfeedManagerTest, OutfeedWrongShape) {
132  TestInfeedBuffer* b = new TestInfeedBuffer(32, /*expect_shape_match=*/false);
133  cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager();
134  xfeed->outfeed()->EnqueueBuffersAtomically({b});
135
136  ProcessNextOutfeedBuffer(32, ShapeUtil::MakeShape(U8, {33}));
137}
138
139}  // namespace
140}  // namespace xla
141