#include <uds/client_channel.h>

#include <sys/socket.h>

#include <algorithm>
#include <cstdint>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <thread>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <pdx/client.h>
#include <pdx/rpc/remote_method.h>
#include <pdx/service.h>

#include <uds/client_channel_factory.h>
#include <uds/service_endpoint.h>
19
20using testing::Return;
21using testing::_;
22
23using android::pdx::ClientBase;
24using android::pdx::LocalChannelHandle;
25using android::pdx::LocalHandle;
26using android::pdx::Message;
27using android::pdx::ServiceBase;
28using android::pdx::ServiceDispatcher;
29using android::pdx::Status;
30using android::pdx::rpc::DispatchRemoteMethod;
31using android::pdx::uds::ClientChannel;
32using android::pdx::uds::ClientChannelFactory;
33using android::pdx::uds::Endpoint;
34
35namespace {
36
37struct TestProtocol {
38  using DataType = int8_t;
39  enum {
40    kOpSum = 0,
41  };
42  PDX_REMOTE_METHOD(Sum, kOpSum, int64_t(const std::vector<DataType>&));
43};
44
45class TestService : public ServiceBase<TestService> {
46 public:
47  TestService(std::unique_ptr<Endpoint> endpoint)
48      : ServiceBase{"TestService", std::move(endpoint)} {}
49
50  Status<void> HandleMessage(Message& message) override {
51    switch (message.GetOp()) {
52      case TestProtocol::kOpSum:
53        DispatchRemoteMethod<TestProtocol::Sum>(*this, &TestService::OnSum,
54                                                message);
55        return {};
56
57      default:
58        return Service::HandleMessage(message);
59    }
60  }
61
62  int64_t OnSum(Message& /*message*/,
63                const std::vector<TestProtocol::DataType>& data) {
64    return std::accumulate(data.begin(), data.end(), int64_t{0});
65  }
66};
67
68class TestClient : public ClientBase<TestClient> {
69 public:
70  using ClientBase::ClientBase;
71
72  int64_t Sum(const std::vector<TestProtocol::DataType>& data) {
73    auto status = InvokeRemoteMethod<TestProtocol::Sum>(data);
74    return status ? status.get() : -1;
75  }
76};
77
78class TestServiceRunner {
79 public:
80  TestServiceRunner(LocalHandle channel_socket) {
81    auto endpoint = Endpoint::CreateFromSocketFd(LocalHandle{});
82    endpoint->RegisterNewChannelForTests(std::move(channel_socket));
83    service_ = TestService::Create(std::move(endpoint));
84    dispatcher_ = android::pdx::uds::ServiceDispatcher::Create();
85    dispatcher_->AddService(service_);
86    dispatch_thread_ = std::thread(
87        std::bind(&ServiceDispatcher::EnterDispatchLoop, dispatcher_.get()));
88  }
89
90  ~TestServiceRunner() {
91    dispatcher_->SetCanceled(true);
92    dispatch_thread_.join();
93    dispatcher_->RemoveService(service_);
94  }
95
96 private:
97  std::shared_ptr<TestService> service_;
98  std::unique_ptr<ServiceDispatcher> dispatcher_;
99  std::thread dispatch_thread_;
100};
101
102class ClientChannelTest : public testing::Test {
103 public:
104  void SetUp() override {
105    int channel_sockets[2] = {};
106    ASSERT_EQ(
107        0, socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, channel_sockets));
108    LocalHandle service_channel{channel_sockets[0]};
109    LocalHandle client_channel{channel_sockets[1]};
110
111    service_runner_.reset(new TestServiceRunner{std::move(service_channel)});
112    auto factory = ClientChannelFactory::Create(std::move(client_channel));
113    auto status = factory->Connect(android::pdx::Client::kInfiniteTimeout);
114    ASSERT_TRUE(status);
115    client_ = TestClient::Create(status.take());
116  }
117
118  void TearDown() override {
119    service_runner_.reset();
120    client_.reset();
121  }
122
123 protected:
124  std::unique_ptr<TestServiceRunner> service_runner_;
125  std::shared_ptr<TestClient> client_;
126};
127
128TEST_F(ClientChannelTest, MultithreadedClient) {
129  constexpr int kNumTestThreads = 8;
130  constexpr size_t kDataSize = 1000;  // Try to keep RPC buffer size below 4K.
131
132  std::random_device rd;
133  std::mt19937 gen{rd()};
134  std::uniform_int_distribution<TestProtocol::DataType> dist{
135      std::numeric_limits<TestProtocol::DataType>::min(),
136      std::numeric_limits<TestProtocol::DataType>::max()};
137
138  auto worker = [](std::shared_ptr<TestClient> client,
139                   std::vector<TestProtocol::DataType> data) {
140    constexpr int kMaxIterations = 500;
141    int64_t expected = std::accumulate(data.begin(), data.end(), int64_t{0});
142    for (int i = 0; i < kMaxIterations; i++) {
143      ASSERT_EQ(expected, client->Sum(data));
144    }
145  };
146
147  // Start client threads.
148  std::vector<TestProtocol::DataType> data;
149  data.resize(kDataSize);
150  std::vector<std::thread> threads;
151  for (int i = 0; i < kNumTestThreads; i++) {
152    std::generate(data.begin(), data.end(),
153                  [&dist, &gen]() { return dist(gen); });
154    threads.emplace_back(worker, client_, data);
155  }
156
157  // Wait for threads to finish.
158  for (auto& thread : threads)
159    thread.join();
160}
161
162}  // namespace
163