1#include "uds/ipc_helper.h"
2
#include <alloca.h>
#include <errno.h>
#include <fcntl.h>
#include <log/log.h>
#include <poll.h>
#include <string.h>
#include <sys/inotify.h>
#include <sys/param.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <unistd.h>

#include <algorithm>
#include <chrono>
#include <iterator>
#include <limits>
#include <numeric>

#include <pdx/service.h>
#include <pdx/utility.h>
16
17namespace android {
18namespace pdx {
19namespace uds {
20
21namespace {
22
23constexpr size_t kMaxFdCount =
24    256;  // Total of 1KiB of data to transfer these FDs.
25
26// Default implementations of Send/Receive interfaces to use standard socket
27// send/sendmsg/recv/recvmsg functions.
28class SocketSender : public SendInterface {
29 public:
30  ssize_t Send(int socket_fd, const void* data, size_t size,
31               int flags) override {
32    return send(socket_fd, data, size, flags);
33  }
34  ssize_t SendMessage(int socket_fd, const msghdr* msg, int flags) override {
35    return sendmsg(socket_fd, msg, flags);
36  }
37} g_socket_sender;
38
39class SocketReceiver : public RecvInterface {
40 public:
41  ssize_t Receive(int socket_fd, void* data, size_t size, int flags) override {
42    return recv(socket_fd, data, size, flags);
43  }
44  ssize_t ReceiveMessage(int socket_fd, msghdr* msg, int flags) override {
45    return recvmsg(socket_fd, msg, flags);
46  }
47} g_socket_receiver;
48
49}  // anonymous namespace
50
51// Helper wrappers around send()/sendmsg() which repeat send() calls on data
52// that was not sent with the initial call to send/sendmsg. This is important to
53// handle transmissions interrupted by signals.
54Status<void> SendAll(SendInterface* sender, const BorrowedHandle& socket_fd,
55                     const void* data, size_t size) {
56  Status<void> ret;
57  const uint8_t* ptr = static_cast<const uint8_t*>(data);
58  while (size > 0) {
59    ssize_t size_written =
60        RETRY_EINTR(sender->Send(socket_fd.Get(), ptr, size, MSG_NOSIGNAL));
61    if (size_written < 0) {
62      ret.SetError(errno);
63      ALOGE("SendAll: Failed to send data over socket: %s",
64            ret.GetErrorMessage().c_str());
65      break;
66    }
67    size -= size_written;
68    ptr += size_written;
69  }
70  return ret;
71}
72
73Status<void> SendMsgAll(SendInterface* sender, const BorrowedHandle& socket_fd,
74                        const msghdr* msg) {
75  Status<void> ret;
76  ssize_t sent_size =
77      RETRY_EINTR(sender->SendMessage(socket_fd.Get(), msg, MSG_NOSIGNAL));
78  if (sent_size < 0) {
79    ret.SetError(errno);
80    ALOGE("SendMsgAll: Failed to send data over socket: %s",
81          ret.GetErrorMessage().c_str());
82    return ret;
83  }
84
85  ssize_t chunk_start_offset = 0;
86  for (size_t i = 0; i < msg->msg_iovlen; i++) {
87    ssize_t chunk_end_offset = chunk_start_offset + msg->msg_iov[i].iov_len;
88    if (sent_size < chunk_end_offset) {
89      size_t offset_within_chunk = sent_size - chunk_start_offset;
90      size_t data_size = msg->msg_iov[i].iov_len - offset_within_chunk;
91      const uint8_t* chunk_base =
92          static_cast<const uint8_t*>(msg->msg_iov[i].iov_base);
93      ret = SendAll(sender, socket_fd, chunk_base + offset_within_chunk,
94                    data_size);
95      if (!ret)
96        break;
97      sent_size += data_size;
98    }
99    chunk_start_offset = chunk_end_offset;
100  }
101  return ret;
102}
103
// Helper wrappers around recv()/recvmsg() which repeat recv() calls on data
// that was not received with the initial call to recvmsg(). This is important
// to handle transmissions interrupted by signals as well as the case when the
// initial data did not arrive in a single chunk over the socket (e.g. the
// socket buffer was full at the time of transmission, so only a portion of the
// initial message was sent and the rest was blocked until the buffer was
// drained by the receiving side).
111Status<void> RecvMsgAll(RecvInterface* receiver,
112                        const BorrowedHandle& socket_fd, msghdr* msg) {
113  Status<void> ret;
114  ssize_t size_read = RETRY_EINTR(receiver->ReceiveMessage(
115      socket_fd.Get(), msg, MSG_WAITALL | MSG_CMSG_CLOEXEC));
116  if (size_read < 0) {
117    ret.SetError(errno);
118    ALOGE("RecvMsgAll: Failed to receive data from socket: %s",
119          ret.GetErrorMessage().c_str());
120    return ret;
121  } else if (size_read == 0) {
122    ret.SetError(ESHUTDOWN);
123    ALOGW("RecvMsgAll: Socket has been shut down");
124    return ret;
125  }
126
127  ssize_t chunk_start_offset = 0;
128  for (size_t i = 0; i < msg->msg_iovlen; i++) {
129    ssize_t chunk_end_offset = chunk_start_offset + msg->msg_iov[i].iov_len;
130    if (size_read < chunk_end_offset) {
131      size_t offset_within_chunk = size_read - chunk_start_offset;
132      size_t data_size = msg->msg_iov[i].iov_len - offset_within_chunk;
133      uint8_t* chunk_base = static_cast<uint8_t*>(msg->msg_iov[i].iov_base);
134      ret = RecvAll(receiver, socket_fd, chunk_base + offset_within_chunk,
135                    data_size);
136      if (!ret)
137        break;
138      size_read += data_size;
139    }
140    chunk_start_offset = chunk_end_offset;
141  }
142  return ret;
143}
144
145Status<void> RecvAll(RecvInterface* receiver, const BorrowedHandle& socket_fd,
146                     void* data, size_t size) {
147  Status<void> ret;
148  uint8_t* ptr = static_cast<uint8_t*>(data);
149  while (size > 0) {
150    ssize_t size_read = RETRY_EINTR(receiver->Receive(
151        socket_fd.Get(), ptr, size, MSG_WAITALL | MSG_CMSG_CLOEXEC));
152    if (size_read < 0) {
153      ret.SetError(errno);
154      ALOGE("RecvAll: Failed to receive data from socket: %s",
155            ret.GetErrorMessage().c_str());
156      break;
157    } else if (size_read == 0) {
158      ret.SetError(ESHUTDOWN);
159      ALOGW("RecvAll: Socket has been shut down");
160      break;
161    }
162    size -= size_read;
163    ptr += size_read;
164  }
165  return ret;
166}
167
// Magic value stored in MessagePreamble::magic; the bytes spell 'udsm'.
// constexpr: it is never modified, and a mutable global with external linkage
// would needlessly pollute the namespace.
constexpr uint32_t kMagicPreamble = 0x7564736d;  // 'udsm'.
169
// Fixed-size header sent in front of every SendPayload message and read back
// by ReceivePayload::Receive.
struct MessagePreamble {
  uint32_t magic = 0;      // Must equal kMagicPreamble for a valid message.
  uint32_t data_size = 0;  // Byte length of the serialized payload buffer.
  uint32_t fd_count = 0;   // Number of descriptors sent via SCM_RIGHTS.
};
175
176Status<void> SendPayload::Send(const BorrowedHandle& socket_fd) {
177  return Send(socket_fd, nullptr);
178}
179
180Status<void> SendPayload::Send(const BorrowedHandle& socket_fd,
181                               const ucred* cred, const iovec* data_vec,
182                               size_t vec_count) {
183  if (file_handles_.size() > kMaxFdCount) {
184    ALOGE(
185        "SendPayload::Send: Trying to send too many file descriptors (%zu), "
186        "max allowed = %zu",
187        file_handles_.size(), kMaxFdCount);
188    return ErrorStatus{EINVAL};
189  }
190
191  SendInterface* sender = sender_ ? sender_ : &g_socket_sender;
192  MessagePreamble preamble;
193  preamble.magic = kMagicPreamble;
194  preamble.data_size = buffer_.size();
195  preamble.fd_count = file_handles_.size();
196
197  msghdr msg = {};
198  msg.msg_iovlen = 2 + vec_count;
199  msg.msg_iov = static_cast<iovec*>(alloca(sizeof(iovec) * msg.msg_iovlen));
200  msg.msg_iov[0].iov_base = &preamble;
201  msg.msg_iov[0].iov_len = sizeof(preamble);
202  msg.msg_iov[1].iov_base = buffer_.data();
203  msg.msg_iov[1].iov_len = buffer_.size();
204  for (size_t i = 0; i < vec_count; i++)
205    msg.msg_iov[i + 2] = data_vec[i];
206
207  if (cred || !file_handles_.empty()) {
208    const size_t fd_bytes = file_handles_.size() * sizeof(int);
209    msg.msg_controllen = (cred ? CMSG_SPACE(sizeof(ucred)) : 0) +
210                         (fd_bytes == 0 ? 0 : CMSG_SPACE(fd_bytes));
211    msg.msg_control = alloca(msg.msg_controllen);
212
213    cmsghdr* control = CMSG_FIRSTHDR(&msg);
214    if (cred) {
215      control->cmsg_level = SOL_SOCKET;
216      control->cmsg_type = SCM_CREDENTIALS;
217      control->cmsg_len = CMSG_LEN(sizeof(ucred));
218      memcpy(CMSG_DATA(control), cred, sizeof(ucred));
219      control = CMSG_NXTHDR(&msg, control);
220    }
221
222    if (fd_bytes) {
223      control->cmsg_level = SOL_SOCKET;
224      control->cmsg_type = SCM_RIGHTS;
225      control->cmsg_len = CMSG_LEN(fd_bytes);
226      memcpy(CMSG_DATA(control), file_handles_.data(), fd_bytes);
227    }
228  }
229
230  return SendMsgAll(sender, socket_fd, &msg);
231}
232
233// MessageWriter
234void* SendPayload::GetNextWriteBufferSection(size_t size) {
235  return buffer_.grow_by(size);
236}
237
238OutputResourceMapper* SendPayload::GetOutputResourceMapper() { return this; }
239
240// OutputResourceMapper
241Status<FileReference> SendPayload::PushFileHandle(const LocalHandle& handle) {
242  if (handle) {
243    const int ref = file_handles_.size();
244    file_handles_.push_back(handle.Get());
245    return ref;
246  } else {
247    return handle.Get();
248  }
249}
250
251Status<FileReference> SendPayload::PushFileHandle(
252    const BorrowedHandle& handle) {
253  if (handle) {
254    const int ref = file_handles_.size();
255    file_handles_.push_back(handle.Get());
256    return ref;
257  } else {
258    return handle.Get();
259  }
260}
261
262Status<FileReference> SendPayload::PushFileHandle(const RemoteHandle& handle) {
263  return handle.Get();
264}
265
266Status<ChannelReference> SendPayload::PushChannelHandle(
267    const LocalChannelHandle& /*handle*/) {
268  return ErrorStatus{EOPNOTSUPP};
269}
270Status<ChannelReference> SendPayload::PushChannelHandle(
271    const BorrowedChannelHandle& /*handle*/) {
272  return ErrorStatus{EOPNOTSUPP};
273}
274Status<ChannelReference> SendPayload::PushChannelHandle(
275    const RemoteChannelHandle& /*handle*/) {
276  return ErrorStatus{EOPNOTSUPP};
277}
278
279Status<void> ReceivePayload::Receive(const BorrowedHandle& socket_fd) {
280  return Receive(socket_fd, nullptr);
281}
282
283Status<void> ReceivePayload::Receive(const BorrowedHandle& socket_fd,
284                                     ucred* cred) {
285  RecvInterface* receiver = receiver_ ? receiver_ : &g_socket_receiver;
286  MessagePreamble preamble;
287  msghdr msg = {};
288  iovec recv_vect = {&preamble, sizeof(preamble)};
289  msg.msg_iov = &recv_vect;
290  msg.msg_iovlen = 1;
291  const size_t receive_fd_bytes = kMaxFdCount * sizeof(int);
292  msg.msg_controllen = CMSG_SPACE(sizeof(ucred)) + CMSG_SPACE(receive_fd_bytes);
293  msg.msg_control = alloca(msg.msg_controllen);
294
295  Status<void> ret = RecvMsgAll(receiver, socket_fd, &msg);
296  if (!ret)
297    return ret;
298
299  if (preamble.magic != kMagicPreamble) {
300    ALOGE("ReceivePayload::Receive: Message header is invalid");
301    ret.SetError(EIO);
302    return ret;
303  }
304
305  buffer_.resize(preamble.data_size);
306  file_handles_.clear();
307  read_pos_ = 0;
308
309  bool cred_available = false;
310  file_handles_.reserve(preamble.fd_count);
311  cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
312  while (cmsg) {
313    if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_CREDENTIALS &&
314        cred && cmsg->cmsg_len == CMSG_LEN(sizeof(ucred))) {
315      cred_available = true;
316      memcpy(cred, CMSG_DATA(cmsg), sizeof(ucred));
317    } else if (cmsg->cmsg_level == SOL_SOCKET &&
318               cmsg->cmsg_type == SCM_RIGHTS) {
319      socklen_t payload_len = cmsg->cmsg_len - CMSG_LEN(0);
320      const int* fds = reinterpret_cast<const int*>(CMSG_DATA(cmsg));
321      size_t fd_count = payload_len / sizeof(int);
322      std::transform(fds, fds + fd_count, std::back_inserter(file_handles_),
323                     [](int fd) { return LocalHandle{fd}; });
324    }
325    cmsg = CMSG_NXTHDR(&msg, cmsg);
326  }
327
328  ret = RecvAll(receiver, socket_fd, buffer_.data(), buffer_.size());
329  if (!ret)
330    return ret;
331
332  if (cred && !cred_available) {
333    ALOGE("ReceivePayload::Receive: Failed to obtain message credentials");
334    ret.SetError(EIO);
335  }
336
337  return ret;
338}
339
340// MessageReader
341MessageReader::BufferSection ReceivePayload::GetNextReadBufferSection() {
342  return {buffer_.data() + read_pos_, &*buffer_.end()};
343}
344
345void ReceivePayload::ConsumeReadBufferSectionData(const void* new_start) {
346  read_pos_ = PointerDistance(new_start, buffer_.data());
347}
348
349InputResourceMapper* ReceivePayload::GetInputResourceMapper() { return this; }
350
351// InputResourceMapper
352bool ReceivePayload::GetFileHandle(FileReference ref, LocalHandle* handle) {
353  if (ref < 0) {
354    *handle = LocalHandle{ref};
355    return true;
356  }
357  if (static_cast<size_t>(ref) > file_handles_.size())
358    return false;
359  *handle = std::move(file_handles_[ref]);
360  return true;
361}
362
363bool ReceivePayload::GetChannelHandle(ChannelReference /*ref*/,
364                                      LocalChannelHandle* /*handle*/) {
365  return false;
366}
367
368Status<void> SendData(const BorrowedHandle& socket_fd, const void* data,
369                      size_t size) {
370  return SendAll(&g_socket_sender, socket_fd, data, size);
371}
372
373Status<void> SendDataVector(const BorrowedHandle& socket_fd, const iovec* data,
374                            size_t count) {
375  msghdr msg = {};
376  msg.msg_iov = const_cast<iovec*>(data);
377  msg.msg_iovlen = count;
378  return SendMsgAll(&g_socket_sender, socket_fd, &msg);
379}
380
381Status<void> ReceiveData(const BorrowedHandle& socket_fd, void* data,
382                         size_t size) {
383  return RecvAll(&g_socket_receiver, socket_fd, data, size);
384}
385
386Status<void> ReceiveDataVector(const BorrowedHandle& socket_fd,
387                               const iovec* data, size_t count) {
388  msghdr msg = {};
389  msg.msg_iov = const_cast<iovec*>(data);
390  msg.msg_iovlen = count;
391  return RecvMsgAll(&g_socket_receiver, socket_fd, &msg);
392}
393
// Returns the sum of iov_len over the first |count| entries of |vector|.
size_t CountVectorSize(const iovec* vector, size_t count) {
  size_t total = 0;
  for (const iovec* entry = vector; entry != vector + count; ++entry)
    total += entry->iov_len;
  return total;
}
399
400void InitRequest(android::pdx::uds::RequestHeader<BorrowedHandle>* request,
401                 int opcode, uint32_t send_len, uint32_t max_recv_len,
402                 bool is_impulse) {
403  request->op = opcode;
404  request->cred.pid = getpid();
405  request->cred.uid = geteuid();
406  request->cred.gid = getegid();
407  request->send_len = send_len;
408  request->max_recv_len = max_recv_len;
409  request->is_impulse = is_impulse;
410}
411
412Status<void> WaitForEndpoint(const std::string& endpoint_path,
413                             int64_t timeout_ms) {
414  // Endpoint path must be absolute.
415  if (endpoint_path.empty() || endpoint_path.front() != '/')
416    return ErrorStatus(EINVAL);
417
418  // Create inotify fd.
419  LocalHandle fd{inotify_init()};
420  if (!fd)
421    return ErrorStatus(errno);
422
423  // Set the inotify fd to non-blocking.
424  int ret = fcntl(fd.Get(), F_GETFL);
425  fcntl(fd.Get(), F_SETFL, ret | O_NONBLOCK);
426
427  // Setup the pollfd.
428  pollfd pfd = {fd.Get(), POLLIN, 0};
429
430  // Find locations of each path separator.
431  std::vector<size_t> separators{0};  // The path is absolute, so '/' is at #0.
432  size_t pos = endpoint_path.find('/', 1);
433  while (pos != std::string::npos) {
434    separators.push_back(pos);
435    pos = endpoint_path.find('/', pos + 1);
436  }
437  separators.push_back(endpoint_path.size());
438
439  // Walk down the path, checking for existence and waiting if needed.
440  pos = 1;
441  size_t links = 0;
442  std::string current;
443  while (pos < separators.size() && links <= MAXSYMLINKS) {
444    std::string previous = current;
445    current = endpoint_path.substr(0, separators[pos]);
446
447    // Check for existence; proceed to setup a watch if not.
448    if (access(current.c_str(), F_OK) < 0) {
449      if (errno != ENOENT)
450        return ErrorStatus(errno);
451
452      // Extract the name of the path component to wait for.
453      std::string next = current.substr(
454          separators[pos - 1] + 1, separators[pos] - separators[pos - 1] - 1);
455
456      // Add a watch on the last existing directory we reach.
457      int wd = inotify_add_watch(
458          fd.Get(), previous.c_str(),
459          IN_CREATE | IN_DELETE_SELF | IN_MOVE_SELF | IN_MOVED_TO);
460      if (wd < 0) {
461        if (errno != ENOENT)
462          return ErrorStatus(errno);
463        // Restart at the beginning if previous was deleted.
464        links = 0;
465        current.clear();
466        pos = 1;
467        continue;
468      }
469
470      // Make sure current didn't get created before the watch was added.
471      ret = access(current.c_str(), F_OK);
472      if (ret < 0) {
473        if (errno != ENOENT)
474          return ErrorStatus(errno);
475
476        bool exit_poll = false;
477        while (!exit_poll) {
478          // Wait for an event or timeout.
479          ret = poll(&pfd, 1, timeout_ms);
480          if (ret <= 0)
481            return ErrorStatus(ret == 0 ? ETIMEDOUT : errno);
482
483          // Read events.
484          char buffer[sizeof(inotify_event) + NAME_MAX + 1];
485
486          ret = read(fd.Get(), buffer, sizeof(buffer));
487          if (ret < 0) {
488            if (errno == EAGAIN || errno == EWOULDBLOCK)
489              continue;
490            else
491              return ErrorStatus(errno);
492          } else if (static_cast<size_t>(ret) < sizeof(struct inotify_event)) {
493            return ErrorStatus(EIO);
494          }
495
496          auto* event = reinterpret_cast<const inotify_event*>(buffer);
497          auto* end = reinterpret_cast<const inotify_event*>(buffer + ret);
498          while (event < end) {
499            std::string event_for;
500            if (event->len > 0)
501              event_for = event->name;
502
503            if (event->mask & (IN_CREATE | IN_MOVED_TO)) {
504              // See if this is the droid we're looking for.
505              if (next == event_for) {
506                exit_poll = true;
507                break;
508              }
509            } else if (event->mask & (IN_DELETE_SELF | IN_MOVE_SELF)) {
510              // Restart at the beginning if our watch dir is deleted.
511              links = 0;
512              current.clear();
513              pos = 0;
514              exit_poll = true;
515              break;
516            }
517
518            event = reinterpret_cast<const inotify_event*>(AdvancePointer(
519                event, sizeof(struct inotify_event) + event->len));
520          }  // while (event < end)
521        }    // while (!exit_poll)
522      }      // Current dir doesn't exist.
523      ret = inotify_rm_watch(fd.Get(), wd);
524      if (ret < 0 && errno != EINVAL)
525        return ErrorStatus(errno);
526    }  // if (access(current.c_str(), F_OK) < 0)
527
528    // Check for symbolic link and update link count.
529    struct stat stat_buf;
530    ret = lstat(current.c_str(), &stat_buf);
531    if (ret < 0 && errno != ENOENT)
532      return ErrorStatus(errno);
533    else if (ret == 0 && S_ISLNK(stat_buf.st_mode))
534      links++;
535    pos++;
536  }  // while (pos < separators.size() && links <= MAXSYMLINKS)
537
538  return {};
539}
540
541}  // namespace uds
542}  // namespace pdx
543}  // namespace android
544