client_channel_tests.cpp revision 5a244ed36c8e45fd95b89ff916caf083fb182ec1
1#include <uds/client_channel.h>
2
#include <sys/socket.h>

#include <algorithm>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <thread>
#include <vector>
9
10#include <gmock/gmock.h>
11#include <gtest/gtest.h>
12
13#include <pdx/client.h>
14#include <pdx/rpc/remote_method.h>
15#include <pdx/service.h>
16#include <pdx/service_dispatcher.h>
17
18#include <uds/client_channel_factory.h>
19#include <uds/service_endpoint.h>
20
21using testing::Return;
22using testing::_;
23
24using android::pdx::ClientBase;
25using android::pdx::LocalChannelHandle;
26using android::pdx::LocalHandle;
27using android::pdx::Message;
28using android::pdx::ServiceBase;
29using android::pdx::ServiceDispatcher;
30using android::pdx::Status;
31using android::pdx::rpc::DispatchRemoteMethod;
32using android::pdx::uds::ClientChannel;
33using android::pdx::uds::ClientChannelFactory;
34using android::pdx::uds::Endpoint;
35
36namespace {
37
38struct TestProtocol {
39  using DataType = int8_t;
40  enum {
41    kOpSum = 0,
42  };
43  PDX_REMOTE_METHOD(Sum, kOpSum, int64_t(const std::vector<DataType>&));
44};
45
46class TestService : public ServiceBase<TestService> {
47 public:
48  TestService(std::unique_ptr<Endpoint> endpoint)
49      : ServiceBase{"TestService", std::move(endpoint)} {}
50
51  Status<void> HandleMessage(Message& message) override {
52    switch (message.GetOp()) {
53      case TestProtocol::kOpSum:
54        DispatchRemoteMethod<TestProtocol::Sum>(*this, &TestService::OnSum,
55                                                message);
56        return {};
57
58      default:
59        return Service::HandleMessage(message);
60    }
61  }
62
63  int64_t OnSum(Message& /*message*/,
64                const std::vector<TestProtocol::DataType>& data) {
65    return std::accumulate(data.begin(), data.end(), int64_t{0});
66  }
67};
68
69class TestClient : public ClientBase<TestClient> {
70 public:
71  using ClientBase::ClientBase;
72
73  int64_t Sum(const std::vector<TestProtocol::DataType>& data) {
74    auto status = InvokeRemoteMethod<TestProtocol::Sum>(data);
75    return status ? status.get() : -1;
76  }
77};
78
79class TestServiceRunner {
80 public:
81  TestServiceRunner(LocalHandle channel_socket) {
82    auto endpoint = Endpoint::CreateFromSocketFd(LocalHandle{});
83    endpoint->RegisterNewChannelForTests(std::move(channel_socket));
84    service_ = TestService::Create(std::move(endpoint));
85    dispatcher_ = ServiceDispatcher::Create();
86    dispatcher_->AddService(service_);
87    dispatch_thread_ = std::thread(
88        std::bind(&ServiceDispatcher::EnterDispatchLoop, dispatcher_.get()));
89  }
90
91  ~TestServiceRunner() {
92    dispatcher_->SetCanceled(true);
93    dispatch_thread_.join();
94    dispatcher_->RemoveService(service_);
95  }
96
97 private:
98  std::shared_ptr<TestService> service_;
99  std::unique_ptr<ServiceDispatcher> dispatcher_;
100  std::thread dispatch_thread_;
101};
102
103class ClientChannelTest : public testing::Test {
104 public:
105  void SetUp() override {
106    int channel_sockets[2] = {};
107    ASSERT_EQ(
108        0, socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, channel_sockets));
109    LocalHandle service_channel{channel_sockets[0]};
110    LocalHandle client_channel{channel_sockets[1]};
111
112    service_runner_.reset(new TestServiceRunner{std::move(service_channel)});
113    auto factory = ClientChannelFactory::Create(std::move(client_channel));
114    auto status = factory->Connect(android::pdx::Client::kInfiniteTimeout);
115    ASSERT_TRUE(status);
116    client_ = TestClient::Create(status.take());
117  }
118
119  void TearDown() override {
120    service_runner_.reset();
121    client_.reset();
122  }
123
124 protected:
125  std::unique_ptr<TestServiceRunner> service_runner_;
126  std::shared_ptr<TestClient> client_;
127};
128
129TEST_F(ClientChannelTest, MultithreadedClient) {
130  constexpr int kNumTestThreads = 8;
131  constexpr size_t kDataSize = 1000;  // Try to keep RPC buffer size below 4K.
132
133  std::random_device rd;
134  std::mt19937 gen{rd()};
135  std::uniform_int_distribution<TestProtocol::DataType> dist{
136      std::numeric_limits<TestProtocol::DataType>::min(),
137      std::numeric_limits<TestProtocol::DataType>::max()};
138
139  auto worker = [](std::shared_ptr<TestClient> client,
140                   std::vector<TestProtocol::DataType> data) {
141    constexpr int kMaxIterations = 500;
142    int64_t expected = std::accumulate(data.begin(), data.end(), int64_t{0});
143    for (int i = 0; i < kMaxIterations; i++) {
144      ASSERT_EQ(expected, client->Sum(data));
145    }
146  };
147
148  // Start client threads.
149  std::vector<TestProtocol::DataType> data;
150  data.resize(kDataSize);
151  std::vector<std::thread> threads;
152  for (int i = 0; i < kNumTestThreads; i++) {
153    std::generate(data.begin(), data.end(),
154                  [&dist, &gen]() { return dist(gen); });
155    threads.emplace_back(worker, client_, data);
156  }
157
158  // Wait for threads to finish.
159  for (auto& thread : threads)
160    thread.join();
161}
162
163}  // namespace
164