client_channel.cpp revision e4eec20f6263f4a42ae462456f60ea6c4518bb0a
1#include "uds/client_channel.h"
2
3#include <errno.h>
4#include <log/log.h>
5#include <sys/epoll.h>
6#include <sys/socket.h>
7
8#include <pdx/client.h>
9#include <pdx/service_endpoint.h>
10#include <uds/ipc_helper.h>
11
12namespace android {
13namespace pdx {
14namespace uds {
15
16namespace {
17
18struct TransactionState {
19  bool GetLocalFileHandle(int index, LocalHandle* handle) {
20    if (index < 0) {
21      handle->Reset(index);
22    } else if (static_cast<size_t>(index) < response.file_descriptors.size()) {
23      *handle = std::move(response.file_descriptors[index]);
24    } else {
25      return false;
26    }
27    return true;
28  }
29
30  bool GetLocalChannelHandle(int index, LocalChannelHandle* handle) {
31    if (index < 0) {
32      *handle = LocalChannelHandle{nullptr, index};
33    } else if (static_cast<size_t>(index) < response.channels.size()) {
34      auto& channel_info = response.channels[index];
35      *handle = ChannelManager::Get().CreateHandle(
36          std::move(channel_info.data_fd), std::move(channel_info.event_fd));
37    } else {
38      return false;
39    }
40    return true;
41  }
42
43  FileReference PushFileHandle(BorrowedHandle handle) {
44    if (!handle)
45      return handle.Get();
46    request.file_descriptors.push_back(std::move(handle));
47    return request.file_descriptors.size() - 1;
48  }
49
50  ChannelReference PushChannelHandle(BorrowedChannelHandle handle) {
51    if (!handle)
52      return handle.value();
53    ChannelInfo<BorrowedHandle> channel_info;
54    channel_info.data_fd.Reset(handle.value());
55    channel_info.event_fd.Reset(
56        ChannelManager::Get().GetEventFd(handle.value()));
57    request.channels.push_back(std::move(channel_info));
58    return request.channels.size() - 1;
59  }
60
61  RequestHeader<BorrowedHandle> request;
62  ResponseHeader<LocalHandle> response;
63};
64
65Status<void> ReadAndDiscardData(int socket_fd, size_t size) {
66  while (size > 0) {
67    // If there is more data to read in the message than the buffers provided
68    // by the caller, read and discard the extra data from the socket.
69    char buffer[1024];
70    size_t size_to_read = std::min(sizeof(buffer), size);
71    auto status = ReceiveData(socket_fd, buffer, size_to_read);
72    if (!status)
73      return status;
74    size -= size_to_read;
75  }
76  // We still want to return EIO error to the caller in case we had unexpected
77  // data in the socket stream.
78  return ErrorStatus(EIO);
79}
80
81Status<void> SendRequest(int socket_fd, TransactionState* transaction_state,
82                         int opcode, const iovec* send_vector,
83                         size_t send_count, size_t max_recv_len) {
84  size_t send_len = CountVectorSize(send_vector, send_count);
85  InitRequest(&transaction_state->request, opcode, send_len, max_recv_len,
86              false);
87  auto status = SendData(socket_fd, transaction_state->request);
88  if (status && send_len > 0)
89    status = SendDataVector(socket_fd, send_vector, send_count);
90  return status;
91}
92
93Status<void> ReceiveResponse(int socket_fd, TransactionState* transaction_state,
94                             const iovec* receive_vector, size_t receive_count,
95                             size_t max_recv_len) {
96  auto status = ReceiveData(socket_fd, &transaction_state->response);
97  if (!status)
98    return status;
99
100  if (transaction_state->response.recv_len > 0) {
101    std::vector<iovec> read_buffers;
102    size_t size_remaining = 0;
103    if (transaction_state->response.recv_len != max_recv_len) {
104      // If the receive buffer not exactly the size of data available, recreate
105      // the vector list to consume the data exactly since ReceiveDataVector()
106      // validates that the number of bytes received equals the number of bytes
107      // requested.
108      size_remaining = transaction_state->response.recv_len;
109      for (size_t i = 0; i < receive_count && size_remaining > 0; i++) {
110        read_buffers.push_back(receive_vector[i]);
111        iovec& last_vec = read_buffers.back();
112        if (last_vec.iov_len > size_remaining)
113          last_vec.iov_len = size_remaining;
114        size_remaining -= last_vec.iov_len;
115      }
116      receive_vector = read_buffers.data();
117      receive_count = read_buffers.size();
118    }
119    status = ReceiveDataVector(socket_fd, receive_vector, receive_count);
120    if (status && size_remaining > 0)
121      status = ReadAndDiscardData(socket_fd, size_remaining);
122  }
123  return status;
124}
125
126}  // anonymous namespace
127
128ClientChannel::ClientChannel(LocalChannelHandle channel_handle)
129    : channel_handle_{std::move(channel_handle)} {
130  int data_fd = channel_handle_.value();
131  int event_fd =
132      channel_handle_ ? ChannelManager::Get().GetEventFd(data_fd) : -1;
133
134  if (event_fd >= 0) {
135    epoll_fd_.Reset(epoll_create(1));
136    if (epoll_fd_) {
137      epoll_event data_ev;
138      data_ev.events = EPOLLHUP;
139      data_ev.data.fd = data_fd;
140
141      epoll_event event_ev;
142      event_ev.events = EPOLLIN;
143      event_ev.data.fd = event_fd;
144      if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, data_fd, &data_ev) < 0 ||
145          epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, event_fd, &event_ev) < 0) {
146        ALOGE("Failed to add fd to epoll fd because: %s\n", strerror(errno));
147        epoll_fd_.Close();
148      }
149    }
150  }
151}
152
153std::unique_ptr<pdx::ClientChannel> ClientChannel::Create(
154    LocalChannelHandle channel_handle) {
155  return std::unique_ptr<pdx::ClientChannel>{
156      new ClientChannel{std::move(channel_handle)}};
157}
158
159ClientChannel::~ClientChannel() {
160  if (channel_handle_)
161    shutdown(channel_handle_.value(), SHUT_WR);
162}
163
164void* ClientChannel::AllocateTransactionState() { return new TransactionState; }
165
166void ClientChannel::FreeTransactionState(void* state) {
167  delete static_cast<TransactionState*>(state);
168}
169
170Status<void> ClientChannel::SendImpulse(int opcode, const void* buffer,
171                                        size_t length) {
172  Status<void> status;
173  android::pdx::uds::RequestHeader<BorrowedHandle> request;
174  if (length > request.impulse_payload.size() ||
175      (buffer == nullptr && length != 0)) {
176    status.SetError(EINVAL);
177    return status;
178  }
179
180  InitRequest(&request, opcode, length, 0, true);
181  memcpy(request.impulse_payload.data(), buffer, length);
182  return SendData(channel_handle_.value(), request);
183}
184
185Status<int> ClientChannel::SendAndReceive(void* transaction_state, int opcode,
186                                          const iovec* send_vector,
187                                          size_t send_count,
188                                          const iovec* receive_vector,
189                                          size_t receive_count) {
190  Status<int> result;
191  if ((send_vector == nullptr && send_count != 0) ||
192      (receive_vector == nullptr && receive_count != 0)) {
193    result.SetError(EINVAL);
194    return result;
195  }
196
197  auto* state = static_cast<TransactionState*>(transaction_state);
198  size_t max_recv_len = CountVectorSize(receive_vector, receive_count);
199
200  auto status = SendRequest(channel_handle_.value(), state, opcode, send_vector,
201                            send_count, max_recv_len);
202  if (status) {
203    status = ReceiveResponse(channel_handle_.value(), state, receive_vector,
204                             receive_count, max_recv_len);
205  }
206  if (!result.PropagateError(status)) {
207    const int return_code = state->response.ret_code;
208    if (return_code >= 0)
209      result.SetValue(return_code);
210    else
211      result.SetError(-return_code);
212  }
213  return result;
214}
215
216Status<int> ClientChannel::SendWithInt(void* transaction_state, int opcode,
217                                       const iovec* send_vector,
218                                       size_t send_count,
219                                       const iovec* receive_vector,
220                                       size_t receive_count) {
221  return SendAndReceive(transaction_state, opcode, send_vector, send_count,
222                        receive_vector, receive_count);
223}
224
225Status<LocalHandle> ClientChannel::SendWithFileHandle(
226    void* transaction_state, int opcode, const iovec* send_vector,
227    size_t send_count, const iovec* receive_vector, size_t receive_count) {
228  Status<int> int_status =
229      SendAndReceive(transaction_state, opcode, send_vector, send_count,
230                     receive_vector, receive_count);
231  Status<LocalHandle> status;
232  if (status.PropagateError(int_status))
233    return status;
234
235  auto* state = static_cast<TransactionState*>(transaction_state);
236  LocalHandle handle;
237  if (state->GetLocalFileHandle(int_status.get(), &handle)) {
238    status.SetValue(std::move(handle));
239  } else {
240    status.SetError(EINVAL);
241  }
242  return status;
243}
244
245Status<LocalChannelHandle> ClientChannel::SendWithChannelHandle(
246    void* transaction_state, int opcode, const iovec* send_vector,
247    size_t send_count, const iovec* receive_vector, size_t receive_count) {
248  Status<int> int_status =
249      SendAndReceive(transaction_state, opcode, send_vector, send_count,
250                     receive_vector, receive_count);
251  Status<LocalChannelHandle> status;
252  if (status.PropagateError(int_status))
253    return status;
254
255  auto* state = static_cast<TransactionState*>(transaction_state);
256  LocalChannelHandle handle;
257  if (state->GetLocalChannelHandle(int_status.get(), &handle)) {
258    status.SetValue(std::move(handle));
259  } else {
260    status.SetError(EINVAL);
261  }
262  return status;
263}
264
265FileReference ClientChannel::PushFileHandle(void* transaction_state,
266                                            const LocalHandle& handle) {
267  auto* state = static_cast<TransactionState*>(transaction_state);
268  return state->PushFileHandle(handle.Borrow());
269}
270
271FileReference ClientChannel::PushFileHandle(void* transaction_state,
272                                            const BorrowedHandle& handle) {
273  auto* state = static_cast<TransactionState*>(transaction_state);
274  return state->PushFileHandle(handle.Duplicate());
275}
276
277ChannelReference ClientChannel::PushChannelHandle(
278    void* transaction_state, const LocalChannelHandle& handle) {
279  auto* state = static_cast<TransactionState*>(transaction_state);
280  return state->PushChannelHandle(handle.Borrow());
281}
282
283ChannelReference ClientChannel::PushChannelHandle(
284    void* transaction_state, const BorrowedChannelHandle& handle) {
285  auto* state = static_cast<TransactionState*>(transaction_state);
286  return state->PushChannelHandle(handle.Duplicate());
287}
288
289bool ClientChannel::GetFileHandle(void* transaction_state, FileReference ref,
290                                  LocalHandle* handle) const {
291  auto* state = static_cast<TransactionState*>(transaction_state);
292  return state->GetLocalFileHandle(ref, handle);
293}
294
295bool ClientChannel::GetChannelHandle(void* transaction_state,
296                                     ChannelReference ref,
297                                     LocalChannelHandle* handle) const {
298  auto* state = static_cast<TransactionState*>(transaction_state);
299  return state->GetLocalChannelHandle(ref, handle);
300}
301
302}  // namespace uds
303}  // namespace pdx
304}  // namespace android
305