1#include "uds/client_channel.h"
2
3#include <errno.h>
4#include <log/log.h>
5#include <sys/epoll.h>
6#include <sys/socket.h>
7
8#include <pdx/client.h>
9#include <pdx/service_endpoint.h>
10#include <uds/ipc_helper.h>
11
12namespace android {
13namespace pdx {
14namespace uds {
15
16namespace {
17
18struct TransactionState {
19  bool GetLocalFileHandle(int index, LocalHandle* handle) {
20    if (index < 0) {
21      handle->Reset(index);
22    } else if (static_cast<size_t>(index) < response.file_descriptors.size()) {
23      *handle = std::move(response.file_descriptors[index]);
24    } else {
25      return false;
26    }
27    return true;
28  }
29
30  bool GetLocalChannelHandle(int index, LocalChannelHandle* handle) {
31    if (index < 0) {
32      *handle = LocalChannelHandle{nullptr, index};
33    } else if (static_cast<size_t>(index) < response.channels.size()) {
34      auto& channel_info = response.channels[index];
35      *handle = ChannelManager::Get().CreateHandle(
36          std::move(channel_info.data_fd),
37          std::move(channel_info.pollin_event_fd),
38          std::move(channel_info.pollhup_event_fd));
39    } else {
40      return false;
41    }
42    return true;
43  }
44
45  FileReference PushFileHandle(BorrowedHandle handle) {
46    if (!handle)
47      return handle.Get();
48    request.file_descriptors.push_back(std::move(handle));
49    return request.file_descriptors.size() - 1;
50  }
51
52  ChannelReference PushChannelHandle(BorrowedChannelHandle handle) {
53    if (!handle)
54      return handle.value();
55
56    if (auto* channel_data =
57            ChannelManager::Get().GetChannelData(handle.value())) {
58      ChannelInfo<BorrowedHandle> channel_info{
59          channel_data->data_fd(), channel_data->pollin_event_fd(),
60          channel_data->pollhup_event_fd()};
61      request.channels.push_back(std::move(channel_info));
62      return request.channels.size() - 1;
63    } else {
64      return -1;
65    }
66  }
67
68  RequestHeader<BorrowedHandle> request;
69  ResponseHeader<LocalHandle> response;
70};
71
72Status<void> ReadAndDiscardData(const BorrowedHandle& socket_fd, size_t size) {
73  while (size > 0) {
74    // If there is more data to read in the message than the buffers provided
75    // by the caller, read and discard the extra data from the socket.
76    char buffer[1024];
77    size_t size_to_read = std::min(sizeof(buffer), size);
78    auto status = ReceiveData(socket_fd, buffer, size_to_read);
79    if (!status)
80      return status;
81    size -= size_to_read;
82  }
83  // We still want to return EIO error to the caller in case we had unexpected
84  // data in the socket stream.
85  return ErrorStatus(EIO);
86}
87
88Status<void> SendRequest(const BorrowedHandle& socket_fd,
89                         TransactionState* transaction_state, int opcode,
90                         const iovec* send_vector, size_t send_count,
91                         size_t max_recv_len) {
92  size_t send_len = CountVectorSize(send_vector, send_count);
93  InitRequest(&transaction_state->request, opcode, send_len, max_recv_len,
94              false);
95  if (send_len == 0) {
96    send_vector = nullptr;
97    send_count = 0;
98  }
99  return SendData(socket_fd, transaction_state->request, send_vector,
100                  send_count);
101}
102
103Status<void> ReceiveResponse(const BorrowedHandle& socket_fd,
104                             TransactionState* transaction_state,
105                             const iovec* receive_vector, size_t receive_count,
106                             size_t max_recv_len) {
107  auto status = ReceiveData(socket_fd, &transaction_state->response);
108  if (!status)
109    return status;
110
111  if (transaction_state->response.recv_len > 0) {
112    std::vector<iovec> read_buffers;
113    size_t size_remaining = 0;
114    if (transaction_state->response.recv_len != max_recv_len) {
115      // If the receive buffer not exactly the size of data available, recreate
116      // the vector list to consume the data exactly since ReceiveDataVector()
117      // validates that the number of bytes received equals the number of bytes
118      // requested.
119      size_remaining = transaction_state->response.recv_len;
120      for (size_t i = 0; i < receive_count && size_remaining > 0; i++) {
121        read_buffers.push_back(receive_vector[i]);
122        iovec& last_vec = read_buffers.back();
123        if (last_vec.iov_len > size_remaining)
124          last_vec.iov_len = size_remaining;
125        size_remaining -= last_vec.iov_len;
126      }
127      receive_vector = read_buffers.data();
128      receive_count = read_buffers.size();
129    }
130    status = ReceiveDataVector(socket_fd, receive_vector, receive_count);
131    if (status && size_remaining > 0)
132      status = ReadAndDiscardData(socket_fd, size_remaining);
133  }
134  return status;
135}
136
137}  // anonymous namespace
138
139ClientChannel::ClientChannel(LocalChannelHandle channel_handle)
140    : channel_handle_{std::move(channel_handle)} {
141  channel_data_ = ChannelManager::Get().GetChannelData(channel_handle_.value());
142}
143
144std::unique_ptr<pdx::ClientChannel> ClientChannel::Create(
145    LocalChannelHandle channel_handle) {
146  return std::unique_ptr<pdx::ClientChannel>{
147      new ClientChannel{std::move(channel_handle)}};
148}
149
150ClientChannel::~ClientChannel() {
151  if (channel_handle_)
152    shutdown(channel_handle_.value(), SHUT_WR);
153}
154
155void* ClientChannel::AllocateTransactionState() { return new TransactionState; }
156
157void ClientChannel::FreeTransactionState(void* state) {
158  delete static_cast<TransactionState*>(state);
159}
160
161Status<void> ClientChannel::SendImpulse(int opcode, const void* buffer,
162                                        size_t length) {
163  std::unique_lock<std::mutex> lock(socket_mutex_);
164  Status<void> status;
165  android::pdx::uds::RequestHeader<BorrowedHandle> request;
166  if (length > request.impulse_payload.size() ||
167      (buffer == nullptr && length != 0)) {
168    status.SetError(EINVAL);
169    return status;
170  }
171
172  InitRequest(&request, opcode, length, 0, true);
173  memcpy(request.impulse_payload.data(), buffer, length);
174  return SendData(BorrowedHandle{channel_handle_.value()}, request);
175}
176
177Status<int> ClientChannel::SendAndReceive(void* transaction_state, int opcode,
178                                          const iovec* send_vector,
179                                          size_t send_count,
180                                          const iovec* receive_vector,
181                                          size_t receive_count) {
182  std::unique_lock<std::mutex> lock(socket_mutex_);
183  Status<int> result;
184  if ((send_vector == nullptr && send_count != 0) ||
185      (receive_vector == nullptr && receive_count != 0)) {
186    result.SetError(EINVAL);
187    return result;
188  }
189
190  auto* state = static_cast<TransactionState*>(transaction_state);
191  size_t max_recv_len = CountVectorSize(receive_vector, receive_count);
192
193  auto status = SendRequest(BorrowedHandle{channel_handle_.value()}, state,
194                            opcode, send_vector, send_count, max_recv_len);
195  if (status) {
196    status = ReceiveResponse(BorrowedHandle{channel_handle_.value()}, state,
197                             receive_vector, receive_count, max_recv_len);
198  }
199  if (!result.PropagateError(status)) {
200    const int return_code = state->response.ret_code;
201    if (return_code >= 0)
202      result.SetValue(return_code);
203    else
204      result.SetError(-return_code);
205  }
206  return result;
207}
208
209Status<int> ClientChannel::SendWithInt(void* transaction_state, int opcode,
210                                       const iovec* send_vector,
211                                       size_t send_count,
212                                       const iovec* receive_vector,
213                                       size_t receive_count) {
214  return SendAndReceive(transaction_state, opcode, send_vector, send_count,
215                        receive_vector, receive_count);
216}
217
218Status<LocalHandle> ClientChannel::SendWithFileHandle(
219    void* transaction_state, int opcode, const iovec* send_vector,
220    size_t send_count, const iovec* receive_vector, size_t receive_count) {
221  Status<int> int_status =
222      SendAndReceive(transaction_state, opcode, send_vector, send_count,
223                     receive_vector, receive_count);
224  Status<LocalHandle> status;
225  if (status.PropagateError(int_status))
226    return status;
227
228  auto* state = static_cast<TransactionState*>(transaction_state);
229  LocalHandle handle;
230  if (state->GetLocalFileHandle(int_status.get(), &handle)) {
231    status.SetValue(std::move(handle));
232  } else {
233    status.SetError(EINVAL);
234  }
235  return status;
236}
237
238Status<LocalChannelHandle> ClientChannel::SendWithChannelHandle(
239    void* transaction_state, int opcode, const iovec* send_vector,
240    size_t send_count, const iovec* receive_vector, size_t receive_count) {
241  Status<int> int_status =
242      SendAndReceive(transaction_state, opcode, send_vector, send_count,
243                     receive_vector, receive_count);
244  Status<LocalChannelHandle> status;
245  if (status.PropagateError(int_status))
246    return status;
247
248  auto* state = static_cast<TransactionState*>(transaction_state);
249  LocalChannelHandle handle;
250  if (state->GetLocalChannelHandle(int_status.get(), &handle)) {
251    status.SetValue(std::move(handle));
252  } else {
253    status.SetError(EINVAL);
254  }
255  return status;
256}
257
258FileReference ClientChannel::PushFileHandle(void* transaction_state,
259                                            const LocalHandle& handle) {
260  auto* state = static_cast<TransactionState*>(transaction_state);
261  return state->PushFileHandle(handle.Borrow());
262}
263
264FileReference ClientChannel::PushFileHandle(void* transaction_state,
265                                            const BorrowedHandle& handle) {
266  auto* state = static_cast<TransactionState*>(transaction_state);
267  return state->PushFileHandle(handle.Duplicate());
268}
269
270ChannelReference ClientChannel::PushChannelHandle(
271    void* transaction_state, const LocalChannelHandle& handle) {
272  auto* state = static_cast<TransactionState*>(transaction_state);
273  return state->PushChannelHandle(handle.Borrow());
274}
275
276ChannelReference ClientChannel::PushChannelHandle(
277    void* transaction_state, const BorrowedChannelHandle& handle) {
278  auto* state = static_cast<TransactionState*>(transaction_state);
279  return state->PushChannelHandle(handle.Duplicate());
280}
281
282bool ClientChannel::GetFileHandle(void* transaction_state, FileReference ref,
283                                  LocalHandle* handle) const {
284  auto* state = static_cast<TransactionState*>(transaction_state);
285  return state->GetLocalFileHandle(ref, handle);
286}
287
288bool ClientChannel::GetChannelHandle(void* transaction_state,
289                                     ChannelReference ref,
290                                     LocalChannelHandle* handle) const {
291  auto* state = static_cast<TransactionState*>(transaction_state);
292  return state->GetLocalChannelHandle(ref, handle);
293}
294
295}  // namespace uds
296}  // namespace pdx
297}  // namespace android
298