1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h" 17 18#include <memory> 19 20#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" 21#include "tensorflow/compiler/xla/shape_util.h" 22#include "tensorflow/core/lib/core/status_test_util.h" 23#include "tensorflow/core/lib/core/threadpool.h" 24#include "tensorflow/core/platform/env.h" 25#include "tensorflow/core/platform/logging.h" 26#include "tensorflow/core/platform/test.h" 27 28namespace xla { 29namespace { 30 31class InfeedManagerTest : public ::testing::Test {}; 32 33class TestInfeedBuffer : public cpu::runtime::XfeedBuffer { 34 public: 35 explicit TestInfeedBuffer(int32 length, bool expect_shape_match = true) 36 : shape_(ShapeUtil::MakeShape(U8, {length})), 37 done_called_(false), 38 length_(length), 39 expect_shape_match_(expect_shape_match) {} 40 ~TestInfeedBuffer() override { EXPECT_TRUE(done_called_); } 41 42 int32 length() override { return length_; } 43 void* data() override { return nullptr; } 44 void Done(StatusOr<Shape> shape) override { 45 CHECK(!done_called_); 46 done_called_ = true; 47 TF_ASSERT_OK(shape.status()); 48 EXPECT_EQ(expect_shape_match_, ShapeUtil::Equal(shape_, shape.ValueOrDie())) 49 << "want " << ShapeUtil::HumanString(shape_) << " " 50 << (expect_shape_match_ ? "==" : "!=") << " " 51 << ShapeUtil::HumanString(shape.ValueOrDie()); 52 } 53 54 const Shape& shape() const { return shape_; } 55 56 private: 57 Shape shape_; 58 bool done_called_; 59 int32 length_; 60 bool expect_shape_match_; 61}; 62 63// Performs the acquire/release sequence on the infeed, as the generated CPU 64// code would in the process of executing the infeed operation. 65void ProcessNextBuffer(int32 length) { 66 auto shape = ShapeUtil::MakeShape(U8, {length}); 67 string bytes = shape.SerializeAsString(); 68 void* buffer = __xla_cpu_runtime_AcquireInfeedBufferForDequeue( 69 length, bytes.data(), bytes.size()); 70 __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(length, buffer, 71 bytes.data(), bytes.size()); 72} 73 74// Performs the acquire/release sequence on the outfeed, as the generated CPU 75// code would in the process of executing the outfeed operation. 76void ProcessNextOutfeedBuffer(int32 length, const Shape& shape) { 77 string bytes = shape.SerializeAsString(); 78 void* buffer = __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( 79 length, bytes.data(), bytes.size()); 80 __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( 81 length, buffer, bytes.data(), bytes.size()); 82} 83 84TEST_F(InfeedManagerTest, SingleThreadedSequential) { 85 TestInfeedBuffer* a = new TestInfeedBuffer(64); 86 TestInfeedBuffer* b = new TestInfeedBuffer(32); 87 88 cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); 89 90 xfeed->infeed()->EnqueueBuffersAtomically({a}); 91 xfeed->infeed()->EnqueueBuffersAtomically({b}); 92 ProcessNextBuffer(a->length()); 93 ProcessNextBuffer(b->length()); 94} 95 96TEST_F(InfeedManagerTest, SingleThreadedInterleaved) { 97 TestInfeedBuffer* a = new TestInfeedBuffer(64); 98 TestInfeedBuffer* b = new TestInfeedBuffer(32); 99 100 cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); 101 102 xfeed->infeed()->EnqueueBuffersAtomically({a}); 103 ProcessNextBuffer(a->length()); 104 xfeed->infeed()->EnqueueBuffersAtomically({b}); 105 ProcessNextBuffer(b->length()); 106} 107 108TEST_F(InfeedManagerTest, MultiThreaded) { 109 tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "test", 2); 110 111 cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); 112 113 const int32 length = 64; 114 115 pool.Schedule([xfeed]() { 116 // Spin for 100 milliseconds 117 int64 start_micros = tensorflow::Env::Default()->NowMicros(); 118 while (true) { 119 int64 end_micros = tensorflow::Env::Default()->NowMicros(); 120 if ((end_micros - start_micros) >= 100000) { // 100 ms 121 break; 122 } 123 } 124 TestInfeedBuffer* a = new TestInfeedBuffer(length); 125 xfeed->infeed()->EnqueueBuffersAtomically({a}); 126 }); 127 128 ProcessNextBuffer(length); 129} 130 131TEST_F(InfeedManagerTest, OutfeedWrongShape) { 132 TestInfeedBuffer* b = new TestInfeedBuffer(32, /*expect_shape_match=*/false); 133 cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); 134 xfeed->outfeed()->EnqueueBuffersAtomically({b}); 135 136 ProcessNextOutfeedBuffer(32, ShapeUtil::MakeShape(U8, {33})); 137} 138 139} // namespace 140} // namespace xla 141