1#include "uds/ipc_helper.h"
2
3#include <alloca.h>
4#include <errno.h>
5#include <log/log.h>
6#include <poll.h>
7#include <string.h>
8#include <sys/inotify.h>
9#include <sys/param.h>
10#include <sys/socket.h>
11
12#include <algorithm>
13
14#include <pdx/service.h>
15#include <pdx/utility.h>
16
17namespace android {
18namespace pdx {
19namespace uds {
20
21namespace {
22
23// Default implementations of Send/Receive interfaces to use standard socket
24// send/sendmsg/recv/recvmsg functions.
25class SocketSender : public SendInterface {
26 public:
27  ssize_t Send(int socket_fd, const void* data, size_t size,
28               int flags) override {
29    return send(socket_fd, data, size, flags);
30  }
31  ssize_t SendMessage(int socket_fd, const msghdr* msg, int flags) override {
32    return sendmsg(socket_fd, msg, flags);
33  }
34} g_socket_sender;
35
36class SocketReceiver : public RecvInterface {
37 public:
38  ssize_t Receive(int socket_fd, void* data, size_t size, int flags) override {
39    return recv(socket_fd, data, size, flags);
40  }
41  ssize_t ReceiveMessage(int socket_fd, msghdr* msg, int flags) override {
42    return recvmsg(socket_fd, msg, flags);
43  }
44} g_socket_receiver;
45
46}  // anonymous namespace
47
48// Helper wrappers around send()/sendmsg() which repeat send() calls on data
49// that was not sent with the initial call to send/sendmsg. This is important to
50// handle transmissions interrupted by signals.
51Status<void> SendAll(SendInterface* sender, const BorrowedHandle& socket_fd,
52                     const void* data, size_t size) {
53  Status<void> ret;
54  const uint8_t* ptr = static_cast<const uint8_t*>(data);
55  while (size > 0) {
56    ssize_t size_written =
57        RETRY_EINTR(sender->Send(socket_fd.Get(), ptr, size, MSG_NOSIGNAL));
58    if (size_written < 0) {
59      ret.SetError(errno);
60      ALOGE("SendAll: Failed to send data over socket: %s",
61            ret.GetErrorMessage().c_str());
62      break;
63    }
64    size -= size_written;
65    ptr += size_written;
66  }
67  return ret;
68}
69
70Status<void> SendMsgAll(SendInterface* sender, const BorrowedHandle& socket_fd,
71                        const msghdr* msg) {
72  Status<void> ret;
73  ssize_t sent_size =
74      RETRY_EINTR(sender->SendMessage(socket_fd.Get(), msg, MSG_NOSIGNAL));
75  if (sent_size < 0) {
76    ret.SetError(errno);
77    ALOGE("SendMsgAll: Failed to send data over socket: %s",
78          ret.GetErrorMessage().c_str());
79    return ret;
80  }
81
82  ssize_t chunk_start_offset = 0;
83  for (size_t i = 0; i < msg->msg_iovlen; i++) {
84    ssize_t chunk_end_offset = chunk_start_offset + msg->msg_iov[i].iov_len;
85    if (sent_size < chunk_end_offset) {
86      size_t offset_within_chunk = sent_size - chunk_start_offset;
87      size_t data_size = msg->msg_iov[i].iov_len - offset_within_chunk;
88      const uint8_t* chunk_base =
89          static_cast<const uint8_t*>(msg->msg_iov[i].iov_base);
90      ret = SendAll(sender, socket_fd, chunk_base + offset_within_chunk,
91                    data_size);
92      if (!ret)
93        break;
94      sent_size += data_size;
95    }
96    chunk_start_offset = chunk_end_offset;
97  }
98  return ret;
99}
100
101// Helper wrappers around recv()/recvmsg() which repeat recv() calls on data
102// that was not received with the initial call to recvmsg(). This is important
103// to handle transmissions interrupted by signals as well as the case when
104// initial data did not arrive in a single chunk over the socket (e.g. socket
105// buffer was full at the time of transmission, and only portion of initial
106// message was sent and the rest was blocked until the buffer was cleared by the
107// receiving side).
108Status<void> RecvMsgAll(RecvInterface* receiver,
109                        const BorrowedHandle& socket_fd, msghdr* msg) {
110  Status<void> ret;
111  ssize_t size_read = RETRY_EINTR(receiver->ReceiveMessage(
112      socket_fd.Get(), msg, MSG_WAITALL | MSG_CMSG_CLOEXEC));
113  if (size_read < 0) {
114    ret.SetError(errno);
115    ALOGE("RecvMsgAll: Failed to receive data from socket: %s",
116          ret.GetErrorMessage().c_str());
117    return ret;
118  } else if (size_read == 0) {
119    ret.SetError(ESHUTDOWN);
120    ALOGW("RecvMsgAll: Socket has been shut down");
121    return ret;
122  }
123
124  ssize_t chunk_start_offset = 0;
125  for (size_t i = 0; i < msg->msg_iovlen; i++) {
126    ssize_t chunk_end_offset = chunk_start_offset + msg->msg_iov[i].iov_len;
127    if (size_read < chunk_end_offset) {
128      size_t offset_within_chunk = size_read - chunk_start_offset;
129      size_t data_size = msg->msg_iov[i].iov_len - offset_within_chunk;
130      uint8_t* chunk_base = static_cast<uint8_t*>(msg->msg_iov[i].iov_base);
131      ret = RecvAll(receiver, socket_fd, chunk_base + offset_within_chunk,
132                    data_size);
133      if (!ret)
134        break;
135      size_read += data_size;
136    }
137    chunk_start_offset = chunk_end_offset;
138  }
139  return ret;
140}
141
142Status<void> RecvAll(RecvInterface* receiver, const BorrowedHandle& socket_fd,
143                     void* data, size_t size) {
144  Status<void> ret;
145  uint8_t* ptr = static_cast<uint8_t*>(data);
146  while (size > 0) {
147    ssize_t size_read = RETRY_EINTR(receiver->Receive(
148        socket_fd.Get(), ptr, size, MSG_WAITALL | MSG_CMSG_CLOEXEC));
149    if (size_read < 0) {
150      ret.SetError(errno);
151      ALOGE("RecvAll: Failed to receive data from socket: %s",
152            ret.GetErrorMessage().c_str());
153      break;
154    } else if (size_read == 0) {
155      ret.SetError(ESHUTDOWN);
156      ALOGW("RecvAll: Socket has been shut down");
157      break;
158    }
159    size -= size_read;
160    ptr += size_read;
161  }
162  return ret;
163}
164
// Marker stored in MessagePreamble::magic by the sender and verified by the
// receiver. Declared constexpr so that the value cannot be modified at run
// time.
constexpr uint32_t kMagicPreamble = 0x7564736d;  // 'udsm'.
166
// Header transferred ahead of each payload. All fields default to zero.
struct MessagePreamble {
  uint32_t magic = 0;      // Compared against kMagicPreamble on receipt.
  uint32_t data_size = 0;  // Size in bytes of the payload data buffer.
  uint32_t fd_count = 0;   // Number of file handles attached to the payload.
};
172
173Status<void> SendPayload::Send(const BorrowedHandle& socket_fd) {
174  return Send(socket_fd, nullptr);
175}
176
177Status<void> SendPayload::Send(const BorrowedHandle& socket_fd,
178                               const ucred* cred) {
179  SendInterface* sender = sender_ ? sender_ : &g_socket_sender;
180  MessagePreamble preamble;
181  preamble.magic = kMagicPreamble;
182  preamble.data_size = buffer_.size();
183  preamble.fd_count = file_handles_.size();
184  Status<void> ret = SendAll(sender, socket_fd, &preamble, sizeof(preamble));
185  if (!ret)
186    return ret;
187
188  msghdr msg = {};
189  iovec recv_vect = {buffer_.data(), buffer_.size()};
190  msg.msg_iov = &recv_vect;
191  msg.msg_iovlen = 1;
192
193  if (cred || !file_handles_.empty()) {
194    const size_t fd_bytes = file_handles_.size() * sizeof(int);
195    msg.msg_controllen = (cred ? CMSG_SPACE(sizeof(ucred)) : 0) +
196                         (fd_bytes == 0 ? 0 : CMSG_SPACE(fd_bytes));
197    msg.msg_control = alloca(msg.msg_controllen);
198
199    cmsghdr* control = CMSG_FIRSTHDR(&msg);
200    if (cred) {
201      control->cmsg_level = SOL_SOCKET;
202      control->cmsg_type = SCM_CREDENTIALS;
203      control->cmsg_len = CMSG_LEN(sizeof(ucred));
204      memcpy(CMSG_DATA(control), cred, sizeof(ucred));
205      control = CMSG_NXTHDR(&msg, control);
206    }
207
208    if (fd_bytes) {
209      control->cmsg_level = SOL_SOCKET;
210      control->cmsg_type = SCM_RIGHTS;
211      control->cmsg_len = CMSG_LEN(fd_bytes);
212      memcpy(CMSG_DATA(control), file_handles_.data(), fd_bytes);
213    }
214  }
215
216  return SendMsgAll(sender, socket_fd, &msg);
217}
218
219// MessageWriter
220void* SendPayload::GetNextWriteBufferSection(size_t size) {
221  return buffer_.grow_by(size);
222}
223
224OutputResourceMapper* SendPayload::GetOutputResourceMapper() { return this; }
225
226// OutputResourceMapper
227Status<FileReference> SendPayload::PushFileHandle(const LocalHandle& handle) {
228  if (handle) {
229    const int ref = file_handles_.size();
230    file_handles_.push_back(handle.Get());
231    return ref;
232  } else {
233    return handle.Get();
234  }
235}
236
237Status<FileReference> SendPayload::PushFileHandle(
238    const BorrowedHandle& handle) {
239  if (handle) {
240    const int ref = file_handles_.size();
241    file_handles_.push_back(handle.Get());
242    return ref;
243  } else {
244    return handle.Get();
245  }
246}
247
248Status<FileReference> SendPayload::PushFileHandle(const RemoteHandle& handle) {
249  return handle.Get();
250}
251
252Status<ChannelReference> SendPayload::PushChannelHandle(
253    const LocalChannelHandle& /*handle*/) {
254  return ErrorStatus{EOPNOTSUPP};
255}
256Status<ChannelReference> SendPayload::PushChannelHandle(
257    const BorrowedChannelHandle& /*handle*/) {
258  return ErrorStatus{EOPNOTSUPP};
259}
260Status<ChannelReference> SendPayload::PushChannelHandle(
261    const RemoteChannelHandle& /*handle*/) {
262  return ErrorStatus{EOPNOTSUPP};
263}
264
265Status<void> ReceivePayload::Receive(const BorrowedHandle& socket_fd) {
266  return Receive(socket_fd, nullptr);
267}
268
269Status<void> ReceivePayload::Receive(const BorrowedHandle& socket_fd,
270                                     ucred* cred) {
271  RecvInterface* receiver = receiver_ ? receiver_ : &g_socket_receiver;
272  MessagePreamble preamble;
273  Status<void> ret = RecvAll(receiver, socket_fd, &preamble, sizeof(preamble));
274  if (!ret)
275    return ret;
276
277  if (preamble.magic != kMagicPreamble) {
278    ALOGE("ReceivePayload::Receive: Message header is invalid");
279    ret.SetError(EIO);
280    return ret;
281  }
282
283  buffer_.resize(preamble.data_size);
284  file_handles_.clear();
285  read_pos_ = 0;
286
287  msghdr msg = {};
288  iovec recv_vect = {buffer_.data(), buffer_.size()};
289  msg.msg_iov = &recv_vect;
290  msg.msg_iovlen = 1;
291
292  if (cred || preamble.fd_count) {
293    const size_t receive_fd_bytes = preamble.fd_count * sizeof(int);
294    msg.msg_controllen =
295        (cred ? CMSG_SPACE(sizeof(ucred)) : 0) +
296        (receive_fd_bytes == 0 ? 0 : CMSG_SPACE(receive_fd_bytes));
297    msg.msg_control = alloca(msg.msg_controllen);
298  }
299
300  ret = RecvMsgAll(receiver, socket_fd, &msg);
301  if (!ret)
302    return ret;
303
304  bool cred_available = false;
305  file_handles_.reserve(preamble.fd_count);
306  cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
307  while (cmsg) {
308    if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_CREDENTIALS &&
309        cred && cmsg->cmsg_len == CMSG_LEN(sizeof(ucred))) {
310      cred_available = true;
311      memcpy(cred, CMSG_DATA(cmsg), sizeof(ucred));
312    } else if (cmsg->cmsg_level == SOL_SOCKET &&
313               cmsg->cmsg_type == SCM_RIGHTS) {
314      socklen_t payload_len = cmsg->cmsg_len - CMSG_LEN(0);
315      const int* fds = reinterpret_cast<const int*>(CMSG_DATA(cmsg));
316      size_t fd_count = payload_len / sizeof(int);
317      std::transform(fds, fds + fd_count, std::back_inserter(file_handles_),
318                     [](int fd) { return LocalHandle{fd}; });
319    }
320    cmsg = CMSG_NXTHDR(&msg, cmsg);
321  }
322
323  if (cred && !cred_available) {
324    ALOGE("ReceivePayload::Receive: Failed to obtain message credentials");
325    ret.SetError(EIO);
326  }
327
328  return ret;
329}
330
331// MessageReader
332MessageReader::BufferSection ReceivePayload::GetNextReadBufferSection() {
333  return {buffer_.data() + read_pos_, &*buffer_.end()};
334}
335
336void ReceivePayload::ConsumeReadBufferSectionData(const void* new_start) {
337  read_pos_ = PointerDistance(new_start, buffer_.data());
338}
339
340InputResourceMapper* ReceivePayload::GetInputResourceMapper() { return this; }
341
342// InputResourceMapper
343bool ReceivePayload::GetFileHandle(FileReference ref, LocalHandle* handle) {
344  if (ref < 0) {
345    *handle = LocalHandle{ref};
346    return true;
347  }
348  if (static_cast<size_t>(ref) > file_handles_.size())
349    return false;
350  *handle = std::move(file_handles_[ref]);
351  return true;
352}
353
354bool ReceivePayload::GetChannelHandle(ChannelReference /*ref*/,
355                                      LocalChannelHandle* /*handle*/) {
356  return false;
357}
358
359Status<void> SendData(const BorrowedHandle& socket_fd, const void* data,
360                      size_t size) {
361  return SendAll(&g_socket_sender, socket_fd, data, size);
362}
363
364Status<void> SendDataVector(const BorrowedHandle& socket_fd, const iovec* data,
365                            size_t count) {
366  msghdr msg = {};
367  msg.msg_iov = const_cast<iovec*>(data);
368  msg.msg_iovlen = count;
369  return SendMsgAll(&g_socket_sender, socket_fd, &msg);
370}
371
372Status<void> ReceiveData(const BorrowedHandle& socket_fd, void* data,
373                         size_t size) {
374  return RecvAll(&g_socket_receiver, socket_fd, data, size);
375}
376
377Status<void> ReceiveDataVector(const BorrowedHandle& socket_fd,
378                               const iovec* data, size_t count) {
379  msghdr msg = {};
380  msg.msg_iov = const_cast<iovec*>(data);
381  msg.msg_iovlen = count;
382  return RecvMsgAll(&g_socket_receiver, socket_fd, &msg);
383}
384
// Returns the sum of the iov_len fields of the |count| entries starting at
// |vector|.
size_t CountVectorSize(const iovec* vector, size_t count) {
  size_t total_bytes = 0;
  for (size_t index = 0; index < count; ++index)
    total_bytes += vector[index].iov_len;
  return total_bytes;
}
390
391void InitRequest(android::pdx::uds::RequestHeader<BorrowedHandle>* request,
392                 int opcode, uint32_t send_len, uint32_t max_recv_len,
393                 bool is_impulse) {
394  request->op = opcode;
395  request->cred.pid = getpid();
396  request->cred.uid = geteuid();
397  request->cred.gid = getegid();
398  request->send_len = send_len;
399  request->max_recv_len = max_recv_len;
400  request->is_impulse = is_impulse;
401}
402
403Status<void> WaitForEndpoint(const std::string& endpoint_path,
404                             int64_t timeout_ms) {
405  // Endpoint path must be absolute.
406  if (endpoint_path.empty() || endpoint_path.front() != '/')
407    return ErrorStatus(EINVAL);
408
409  // Create inotify fd.
410  LocalHandle fd{inotify_init()};
411  if (!fd)
412    return ErrorStatus(errno);
413
414  // Set the inotify fd to non-blocking.
415  int ret = fcntl(fd.Get(), F_GETFL);
416  fcntl(fd.Get(), F_SETFL, ret | O_NONBLOCK);
417
418  // Setup the pollfd.
419  pollfd pfd = {fd.Get(), POLLIN, 0};
420
421  // Find locations of each path separator.
422  std::vector<size_t> separators{0};  // The path is absolute, so '/' is at #0.
423  size_t pos = endpoint_path.find('/', 1);
424  while (pos != std::string::npos) {
425    separators.push_back(pos);
426    pos = endpoint_path.find('/', pos + 1);
427  }
428  separators.push_back(endpoint_path.size());
429
430  // Walk down the path, checking for existence and waiting if needed.
431  pos = 1;
432  size_t links = 0;
433  std::string current;
434  while (pos < separators.size() && links <= MAXSYMLINKS) {
435    std::string previous = current;
436    current = endpoint_path.substr(0, separators[pos]);
437
438    // Check for existence; proceed to setup a watch if not.
439    if (access(current.c_str(), F_OK) < 0) {
440      if (errno != ENOENT)
441        return ErrorStatus(errno);
442
443      // Extract the name of the path component to wait for.
444      std::string next = current.substr(
445          separators[pos - 1] + 1, separators[pos] - separators[pos - 1] - 1);
446
447      // Add a watch on the last existing directory we reach.
448      int wd = inotify_add_watch(
449          fd.Get(), previous.c_str(),
450          IN_CREATE | IN_DELETE_SELF | IN_MOVE_SELF | IN_MOVED_TO);
451      if (wd < 0) {
452        if (errno != ENOENT)
453          return ErrorStatus(errno);
454        // Restart at the beginning if previous was deleted.
455        links = 0;
456        current.clear();
457        pos = 1;
458        continue;
459      }
460
461      // Make sure current didn't get created before the watch was added.
462      ret = access(current.c_str(), F_OK);
463      if (ret < 0) {
464        if (errno != ENOENT)
465          return ErrorStatus(errno);
466
467        bool exit_poll = false;
468        while (!exit_poll) {
469          // Wait for an event or timeout.
470          ret = poll(&pfd, 1, timeout_ms);
471          if (ret <= 0)
472            return ErrorStatus(ret == 0 ? ETIMEDOUT : errno);
473
474          // Read events.
475          char buffer[sizeof(inotify_event) + NAME_MAX + 1];
476
477          ret = read(fd.Get(), buffer, sizeof(buffer));
478          if (ret < 0) {
479            if (errno == EAGAIN || errno == EWOULDBLOCK)
480              continue;
481            else
482              return ErrorStatus(errno);
483          } else if (static_cast<size_t>(ret) < sizeof(struct inotify_event)) {
484            return ErrorStatus(EIO);
485          }
486
487          auto* event = reinterpret_cast<const inotify_event*>(buffer);
488          auto* end = reinterpret_cast<const inotify_event*>(buffer + ret);
489          while (event < end) {
490            std::string event_for;
491            if (event->len > 0)
492              event_for = event->name;
493
494            if (event->mask & (IN_CREATE | IN_MOVED_TO)) {
495              // See if this is the droid we're looking for.
496              if (next == event_for) {
497                exit_poll = true;
498                break;
499              }
500            } else if (event->mask & (IN_DELETE_SELF | IN_MOVE_SELF)) {
501              // Restart at the beginning if our watch dir is deleted.
502              links = 0;
503              current.clear();
504              pos = 0;
505              exit_poll = true;
506              break;
507            }
508
509            event = reinterpret_cast<const inotify_event*>(AdvancePointer(
510                event, sizeof(struct inotify_event) + event->len));
511          }  // while (event < end)
512        }    // while (!exit_poll)
513      }      // Current dir doesn't exist.
514      ret = inotify_rm_watch(fd.Get(), wd);
515      if (ret < 0 && errno != EINVAL)
516        return ErrorStatus(errno);
517    }  // if (access(current.c_str(), F_OK) < 0)
518
519    // Check for symbolic link and update link count.
520    struct stat stat_buf;
521    ret = lstat(current.c_str(), &stat_buf);
522    if (ret < 0 && errno != ENOENT)
523      return ErrorStatus(errno);
524    else if (ret == 0 && S_ISLNK(stat_buf.st_mode))
525      links++;
526    pos++;
527  }  // while (pos < separators.size() && links <= MAXSYMLINKS)
528
529  return {};
530}
531
532}  // namespace uds
533}  // namespace pdx
534}  // namespace android
535