// client_channel_tests.cpp (revision fd22b3e5ad1aae1fc3de54801f33466db3c9b3fe)
1#include <uds/client_channel.h> 2 3#include <sys/socket.h> 4 5#include <algorithm> 6#include <limits> 7#include <random> 8#include <thread> 9 10#include <gmock/gmock.h> 11#include <gtest/gtest.h> 12 13#include <pdx/client.h> 14#include <pdx/rpc/remote_method.h> 15#include <pdx/service.h> 16 17#include <uds/client_channel_factory.h> 18#include <uds/service_endpoint.h> 19 20using testing::Return; 21using testing::_; 22 23using android::pdx::ClientBase; 24using android::pdx::LocalChannelHandle; 25using android::pdx::LocalHandle; 26using android::pdx::Message; 27using android::pdx::ServiceBase; 28using android::pdx::ServiceDispatcher; 29using android::pdx::Status; 30using android::pdx::rpc::DispatchRemoteMethod; 31using android::pdx::uds::ClientChannel; 32using android::pdx::uds::ClientChannelFactory; 33using android::pdx::uds::Endpoint; 34 35namespace { 36 37struct TestProtocol { 38 using DataType = int8_t; 39 enum { 40 kOpSum = 0, 41 }; 42 PDX_REMOTE_METHOD(Sum, kOpSum, int64_t(const std::vector<DataType>&)); 43}; 44 45class TestService : public ServiceBase<TestService> { 46 public: 47 TestService(std::unique_ptr<Endpoint> endpoint) 48 : ServiceBase{"TestService", std::move(endpoint)} {} 49 50 Status<void> HandleMessage(Message& message) override { 51 switch (message.GetOp()) { 52 case TestProtocol::kOpSum: 53 DispatchRemoteMethod<TestProtocol::Sum>(*this, &TestService::OnSum, 54 message); 55 return {}; 56 57 default: 58 return Service::HandleMessage(message); 59 } 60 } 61 62 int64_t OnSum(Message& /*message*/, 63 const std::vector<TestProtocol::DataType>& data) { 64 return std::accumulate(data.begin(), data.end(), int64_t{0}); 65 } 66}; 67 68class TestClient : public ClientBase<TestClient> { 69 public: 70 using ClientBase::ClientBase; 71 72 int64_t Sum(const std::vector<TestProtocol::DataType>& data) { 73 auto status = InvokeRemoteMethod<TestProtocol::Sum>(data); 74 return status ? 
status.get() : -1; 75 } 76}; 77 78class TestServiceRunner { 79 public: 80 TestServiceRunner(LocalHandle channel_socket) { 81 auto endpoint = Endpoint::CreateFromSocketFd(LocalHandle{}); 82 endpoint->RegisterNewChannelForTests(std::move(channel_socket)); 83 service_ = TestService::Create(std::move(endpoint)); 84 dispatcher_ = android::pdx::uds::ServiceDispatcher::Create(); 85 dispatcher_->AddService(service_); 86 dispatch_thread_ = std::thread( 87 std::bind(&ServiceDispatcher::EnterDispatchLoop, dispatcher_.get())); 88 } 89 90 ~TestServiceRunner() { 91 dispatcher_->SetCanceled(true); 92 dispatch_thread_.join(); 93 dispatcher_->RemoveService(service_); 94 } 95 96 private: 97 std::shared_ptr<TestService> service_; 98 std::unique_ptr<ServiceDispatcher> dispatcher_; 99 std::thread dispatch_thread_; 100}; 101 102class ClientChannelTest : public testing::Test { 103 public: 104 void SetUp() override { 105 int channel_sockets[2] = {}; 106 ASSERT_EQ( 107 0, socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, channel_sockets)); 108 LocalHandle service_channel{channel_sockets[0]}; 109 LocalHandle client_channel{channel_sockets[1]}; 110 111 service_runner_.reset(new TestServiceRunner{std::move(service_channel)}); 112 auto factory = ClientChannelFactory::Create(std::move(client_channel)); 113 auto status = factory->Connect(android::pdx::Client::kInfiniteTimeout); 114 ASSERT_TRUE(status); 115 client_ = TestClient::Create(status.take()); 116 } 117 118 void TearDown() override { 119 service_runner_.reset(); 120 client_.reset(); 121 } 122 123 protected: 124 std::unique_ptr<TestServiceRunner> service_runner_; 125 std::shared_ptr<TestClient> client_; 126}; 127 128TEST_F(ClientChannelTest, MultithreadedClient) { 129 constexpr int kNumTestThreads = 8; 130 constexpr size_t kDataSize = 1000; // Try to keep RPC buffer size below 4K. 
131 132 std::random_device rd; 133 std::mt19937 gen{rd()}; 134 std::uniform_int_distribution<TestProtocol::DataType> dist{ 135 std::numeric_limits<TestProtocol::DataType>::min(), 136 std::numeric_limits<TestProtocol::DataType>::max()}; 137 138 auto worker = [](std::shared_ptr<TestClient> client, 139 std::vector<TestProtocol::DataType> data) { 140 constexpr int kMaxIterations = 500; 141 int64_t expected = std::accumulate(data.begin(), data.end(), int64_t{0}); 142 for (int i = 0; i < kMaxIterations; i++) { 143 ASSERT_EQ(expected, client->Sum(data)); 144 } 145 }; 146 147 // Start client threads. 148 std::vector<TestProtocol::DataType> data; 149 data.resize(kDataSize); 150 std::vector<std::thread> threads; 151 for (int i = 0; i < kNumTestThreads; i++) { 152 std::generate(data.begin(), data.end(), 153 [&dist, &gen]() { return dist(gen); }); 154 threads.emplace_back(worker, client_, data); 155 } 156 157 // Wait for threads to finish. 158 for (auto& thread : threads) 159 thread.join(); 160} 161 162} // namespace 163