1#include "uds/channel_parcelable.h"
2#include "uds/client_channel.h"
3
4#include <errno.h>
5#include <log/log.h>
6#include <sys/epoll.h>
7#include <sys/socket.h>
8
9#include <pdx/client.h>
10#include <pdx/service_endpoint.h>
11#include <uds/ipc_helper.h>
12
13namespace android {
14namespace pdx {
15namespace uds {
16
17namespace {
18
19struct TransactionState {
20  bool GetLocalFileHandle(int index, LocalHandle* handle) {
21    if (index < 0) {
22      handle->Reset(index);
23    } else if (static_cast<size_t>(index) < response.file_descriptors.size()) {
24      *handle = std::move(response.file_descriptors[index]);
25    } else {
26      return false;
27    }
28    return true;
29  }
30
31  bool GetLocalChannelHandle(int index, LocalChannelHandle* handle) {
32    if (index < 0) {
33      *handle = LocalChannelHandle{nullptr, index};
34    } else if (static_cast<size_t>(index) < response.channels.size()) {
35      auto& channel_info = response.channels[index];
36      *handle = ChannelManager::Get().CreateHandle(
37          std::move(channel_info.data_fd),
38          std::move(channel_info.pollin_event_fd),
39          std::move(channel_info.pollhup_event_fd));
40    } else {
41      return false;
42    }
43    return true;
44  }
45
46  FileReference PushFileHandle(BorrowedHandle handle) {
47    if (!handle)
48      return handle.Get();
49    request.file_descriptors.push_back(std::move(handle));
50    return request.file_descriptors.size() - 1;
51  }
52
53  ChannelReference PushChannelHandle(BorrowedChannelHandle handle) {
54    if (!handle)
55      return handle.value();
56
57    if (auto* channel_data =
58            ChannelManager::Get().GetChannelData(handle.value())) {
59      ChannelInfo<BorrowedHandle> channel_info{
60          channel_data->data_fd(), channel_data->pollin_event_fd(),
61          channel_data->pollhup_event_fd()};
62      request.channels.push_back(std::move(channel_info));
63      return request.channels.size() - 1;
64    } else {
65      return -1;
66    }
67  }
68
69  RequestHeader<BorrowedHandle> request;
70  ResponseHeader<LocalHandle> response;
71};
72
73Status<void> ReadAndDiscardData(const BorrowedHandle& socket_fd, size_t size) {
74  while (size > 0) {
75    // If there is more data to read in the message than the buffers provided
76    // by the caller, read and discard the extra data from the socket.
77    char buffer[1024];
78    size_t size_to_read = std::min(sizeof(buffer), size);
79    auto status = ReceiveData(socket_fd, buffer, size_to_read);
80    if (!status)
81      return status;
82    size -= size_to_read;
83  }
84  // We still want to return EIO error to the caller in case we had unexpected
85  // data in the socket stream.
86  return ErrorStatus(EIO);
87}
88
89Status<void> SendRequest(const BorrowedHandle& socket_fd,
90                         TransactionState* transaction_state, int opcode,
91                         const iovec* send_vector, size_t send_count,
92                         size_t max_recv_len) {
93  size_t send_len = CountVectorSize(send_vector, send_count);
94  InitRequest(&transaction_state->request, opcode, send_len, max_recv_len,
95              false);
96  if (send_len == 0) {
97    send_vector = nullptr;
98    send_count = 0;
99  }
100  return SendData(socket_fd, transaction_state->request, send_vector,
101                  send_count);
102}
103
// Reads the response header into |transaction_state->response|, then reads
// the payload of response.recv_len bytes into |receive_vector|.
//
// |max_recv_len| is the total capacity of |receive_vector| (the caller in
// this file computes it with CountVectorSize()). When recv_len differs from
// it, the iovec list is rebuilt:
//  - Fewer bytes than capacity: the buffers are trimmed so that exactly
//    recv_len bytes are read.
//  - More bytes than capacity: every buffer is filled, and the excess is then
//    read and discarded by ReadAndDiscardData(), which makes this return EIO.
//
// Returns the first error from reading the header, the payload, or the
// discarded excess.
Status<void> ReceiveResponse(const BorrowedHandle& socket_fd,
                             TransactionState* transaction_state,
                             const iovec* receive_vector, size_t receive_count,
                             size_t max_recv_len) {
  auto status = ReceiveData(socket_fd, &transaction_state->response);
  if (!status)
    return status;

  if (transaction_state->response.recv_len > 0) {
    // Backing storage for the rebuilt iovec list; it must stay alive until
    // ReceiveDataVector() below returns.
    std::vector<iovec> read_buffers;
    // Bytes the service sent beyond the caller's buffers; non-zero only when
    // recv_len exceeds the total buffer capacity.
    size_t size_remaining = 0;
    if (transaction_state->response.recv_len != max_recv_len) {
      // If the receive buffer is not exactly the size of the data available,
      // rebuild the vector list so that it consumes exactly that data, since
      // ReceiveDataVector() validates that the number of bytes received
      // equals the number of bytes requested.
      size_remaining = transaction_state->response.recv_len;
      for (size_t i = 0; i < receive_count && size_remaining > 0; i++) {
        read_buffers.push_back(receive_vector[i]);
        iovec& last_vec = read_buffers.back();
        // Truncate the last buffer needed so we don't over-read.
        if (last_vec.iov_len > size_remaining)
          last_vec.iov_len = size_remaining;
        size_remaining -= last_vec.iov_len;
      }
      receive_vector = read_buffers.data();
      receive_count = read_buffers.size();
    }
    status = ReceiveDataVector(socket_fd, receive_vector, receive_count);
    // Drain any bytes that did not fit in the caller's buffers.
    if (status && size_remaining > 0)
      status = ReadAndDiscardData(socket_fd, size_remaining);
  }
  return status;
}
137
138}  // anonymous namespace
139
140ClientChannel::ClientChannel(LocalChannelHandle channel_handle)
141    : channel_handle_{std::move(channel_handle)} {
142  channel_data_ = ChannelManager::Get().GetChannelData(channel_handle_.value());
143}
144
145std::unique_ptr<pdx::ClientChannel> ClientChannel::Create(
146    LocalChannelHandle channel_handle) {
147  return std::unique_ptr<pdx::ClientChannel>{
148      new ClientChannel{std::move(channel_handle)}};
149}
150
151ClientChannel::~ClientChannel() {
152  if (channel_handle_)
153    shutdown(channel_handle_.value(), SHUT_WR);
154}
155
156void* ClientChannel::AllocateTransactionState() { return new TransactionState; }
157
158void ClientChannel::FreeTransactionState(void* state) {
159  delete static_cast<TransactionState*>(state);
160}
161
162Status<void> ClientChannel::SendImpulse(int opcode, const void* buffer,
163                                        size_t length) {
164  std::unique_lock<std::mutex> lock(socket_mutex_);
165  Status<void> status;
166  android::pdx::uds::RequestHeader<BorrowedHandle> request;
167  if (length > request.impulse_payload.size() ||
168      (buffer == nullptr && length != 0)) {
169    status.SetError(EINVAL);
170    return status;
171  }
172
173  InitRequest(&request, opcode, length, 0, true);
174  memcpy(request.impulse_payload.data(), buffer, length);
175  return SendData(BorrowedHandle{channel_handle_.value()}, request);
176}
177
178Status<int> ClientChannel::SendAndReceive(void* transaction_state, int opcode,
179                                          const iovec* send_vector,
180                                          size_t send_count,
181                                          const iovec* receive_vector,
182                                          size_t receive_count) {
183  std::unique_lock<std::mutex> lock(socket_mutex_);
184  Status<int> result;
185  if ((send_vector == nullptr && send_count != 0) ||
186      (receive_vector == nullptr && receive_count != 0)) {
187    result.SetError(EINVAL);
188    return result;
189  }
190
191  auto* state = static_cast<TransactionState*>(transaction_state);
192  size_t max_recv_len = CountVectorSize(receive_vector, receive_count);
193
194  auto status = SendRequest(BorrowedHandle{channel_handle_.value()}, state,
195                            opcode, send_vector, send_count, max_recv_len);
196  if (status) {
197    status = ReceiveResponse(BorrowedHandle{channel_handle_.value()}, state,
198                             receive_vector, receive_count, max_recv_len);
199  }
200  if (!result.PropagateError(status)) {
201    const int return_code = state->response.ret_code;
202    if (return_code >= 0)
203      result.SetValue(return_code);
204    else
205      result.SetError(-return_code);
206  }
207  return result;
208}
209
210Status<int> ClientChannel::SendWithInt(void* transaction_state, int opcode,
211                                       const iovec* send_vector,
212                                       size_t send_count,
213                                       const iovec* receive_vector,
214                                       size_t receive_count) {
215  return SendAndReceive(transaction_state, opcode, send_vector, send_count,
216                        receive_vector, receive_count);
217}
218
219Status<LocalHandle> ClientChannel::SendWithFileHandle(
220    void* transaction_state, int opcode, const iovec* send_vector,
221    size_t send_count, const iovec* receive_vector, size_t receive_count) {
222  Status<int> int_status =
223      SendAndReceive(transaction_state, opcode, send_vector, send_count,
224                     receive_vector, receive_count);
225  Status<LocalHandle> status;
226  if (status.PropagateError(int_status))
227    return status;
228
229  auto* state = static_cast<TransactionState*>(transaction_state);
230  LocalHandle handle;
231  if (state->GetLocalFileHandle(int_status.get(), &handle)) {
232    status.SetValue(std::move(handle));
233  } else {
234    status.SetError(EINVAL);
235  }
236  return status;
237}
238
239Status<LocalChannelHandle> ClientChannel::SendWithChannelHandle(
240    void* transaction_state, int opcode, const iovec* send_vector,
241    size_t send_count, const iovec* receive_vector, size_t receive_count) {
242  Status<int> int_status =
243      SendAndReceive(transaction_state, opcode, send_vector, send_count,
244                     receive_vector, receive_count);
245  Status<LocalChannelHandle> status;
246  if (status.PropagateError(int_status))
247    return status;
248
249  auto* state = static_cast<TransactionState*>(transaction_state);
250  LocalChannelHandle handle;
251  if (state->GetLocalChannelHandle(int_status.get(), &handle)) {
252    status.SetValue(std::move(handle));
253  } else {
254    status.SetError(EINVAL);
255  }
256  return status;
257}
258
259FileReference ClientChannel::PushFileHandle(void* transaction_state,
260                                            const LocalHandle& handle) {
261  auto* state = static_cast<TransactionState*>(transaction_state);
262  return state->PushFileHandle(handle.Borrow());
263}
264
265FileReference ClientChannel::PushFileHandle(void* transaction_state,
266                                            const BorrowedHandle& handle) {
267  auto* state = static_cast<TransactionState*>(transaction_state);
268  return state->PushFileHandle(handle.Duplicate());
269}
270
271ChannelReference ClientChannel::PushChannelHandle(
272    void* transaction_state, const LocalChannelHandle& handle) {
273  auto* state = static_cast<TransactionState*>(transaction_state);
274  return state->PushChannelHandle(handle.Borrow());
275}
276
277ChannelReference ClientChannel::PushChannelHandle(
278    void* transaction_state, const BorrowedChannelHandle& handle) {
279  auto* state = static_cast<TransactionState*>(transaction_state);
280  return state->PushChannelHandle(handle.Duplicate());
281}
282
283bool ClientChannel::GetFileHandle(void* transaction_state, FileReference ref,
284                                  LocalHandle* handle) const {
285  auto* state = static_cast<TransactionState*>(transaction_state);
286  return state->GetLocalFileHandle(ref, handle);
287}
288
289bool ClientChannel::GetChannelHandle(void* transaction_state,
290                                     ChannelReference ref,
291                                     LocalChannelHandle* handle) const {
292  auto* state = static_cast<TransactionState*>(transaction_state);
293  return state->GetLocalChannelHandle(ref, handle);
294}
295
296std::unique_ptr<pdx::ChannelParcelable> ClientChannel::TakeChannelParcelable()
297    {
298  if (!channel_handle_)
299    return nullptr;
300
301  if (auto* channel_data =
302          ChannelManager::Get().GetChannelData(channel_handle_.value())) {
303    auto fds = channel_data->TakeFds();
304    auto parcelable = std::make_unique<ChannelParcelable>(
305        std::move(std::get<0>(fds)), std::move(std::get<1>(fds)),
306        std::move(std::get<2>(fds)));
307
308    // Here we need to explicitly close the channel handle so that the channel
309    // won't get shutdown in the destructor, while the FDs in ChannelParcelable
310    // can keep the channel alive so that new client can be created from it
311    // later.
312    channel_handle_.Close();
313    return parcelable;
314  } else {
315    return nullptr;
316  }
317}
318
319}  // namespace uds
320}  // namespace pdx
321}  // namespace android
322