1#include "pdx/service.h"
2
3#include <fcntl.h>
4#include <log/log.h>
5#include <utils/misc.h>
6
7#include <algorithm>
8#include <cstdint>
9
10#include <pdx/trace.h>
11
12namespace android {
13namespace pdx {
14
15std::shared_ptr<Channel> Channel::GetFromMessageInfo(const MessageInfo& info) {
16  return info.channel ? info.channel->shared_from_this()
17                      : std::shared_ptr<Channel>();
18}
19
20Message::Message() : replied_(true) {}
21
22Message::Message(const MessageInfo& info)
23    : service_{Service::GetFromMessageInfo(info)},
24      channel_{Channel::GetFromMessageInfo(info)},
25      info_{info},
26      replied_{IsImpulse()} {
27  auto svc = service_.lock();
28  if (svc)
29    state_ = svc->endpoint()->AllocateMessageState();
30}
31
32// C++11 specifies the move semantics for shared_ptr but not weak_ptr. This
33// means we have to manually implement the desired move semantics for Message.
34Message::Message(Message&& other) { *this = std::move(other); }
35
36Message& Message::operator=(Message&& other) {
37  Destroy();
38  auto base = reinterpret_cast<std::uint8_t*>(&info_);
39  std::fill(&base[0], &base[sizeof(info_)], 0);
40  replied_ = true;
41  std::swap(service_, other.service_);
42  std::swap(channel_, other.channel_);
43  std::swap(info_, other.info_);
44  std::swap(state_, other.state_);
45  std::swap(replied_, other.replied_);
46  return *this;
47}
48
49Message::~Message() { Destroy(); }
50
51void Message::Destroy() {
52  auto svc = service_.lock();
53  if (svc) {
54    if (!replied_) {
55      ALOGE(
56          "ERROR: Service \"%s\" failed to reply to message: op=%d pid=%d "
57          "cid=%d\n",
58          svc->name_.c_str(), info_.op, info_.pid, info_.cid);
59      svc->DefaultHandleMessage(*this);
60    }
61    svc->endpoint()->FreeMessageState(state_);
62  }
63  state_ = nullptr;
64  service_.reset();
65  channel_.reset();
66}
67
68const std::uint8_t* Message::ImpulseBegin() const {
69  return reinterpret_cast<const std::uint8_t*>(info_.impulse);
70}
71
72const std::uint8_t* Message::ImpulseEnd() const {
73  return ImpulseBegin() + (IsImpulse() ? GetSendLength() : 0);
74}
75
76Status<size_t> Message::ReadVector(const struct iovec* vector,
77                                   size_t vector_length) {
78  PDX_TRACE_NAME("Message::ReadVector");
79  if (auto svc = service_.lock()) {
80    return svc->endpoint()->ReadMessageData(this, vector, vector_length);
81  } else {
82    return ErrorStatus{ESHUTDOWN};
83  }
84}
85
86Status<void> Message::ReadVectorAll(const struct iovec* vector,
87                                    size_t vector_length) {
88  PDX_TRACE_NAME("Message::ReadVectorAll");
89  if (auto svc = service_.lock()) {
90    const auto status =
91        svc->endpoint()->ReadMessageData(this, vector, vector_length);
92    if (!status)
93      return status.error_status();
94    size_t size_to_read = 0;
95    for (size_t i = 0; i < vector_length; i++)
96      size_to_read += vector[i].iov_len;
97    if (status.get() < size_to_read)
98      return ErrorStatus{EIO};
99    return {};
100  } else {
101    return ErrorStatus{ESHUTDOWN};
102  }
103}
104
105Status<size_t> Message::Read(void* buffer, size_t length) {
106  PDX_TRACE_NAME("Message::Read");
107  if (auto svc = service_.lock()) {
108    const struct iovec vector = {buffer, length};
109    return svc->endpoint()->ReadMessageData(this, &vector, 1);
110  } else {
111    return ErrorStatus{ESHUTDOWN};
112  }
113}
114
115Status<size_t> Message::WriteVector(const struct iovec* vector,
116                                    size_t vector_length) {
117  PDX_TRACE_NAME("Message::WriteVector");
118  if (auto svc = service_.lock()) {
119    return svc->endpoint()->WriteMessageData(this, vector, vector_length);
120  } else {
121    return ErrorStatus{ESHUTDOWN};
122  }
123}
124
125Status<void> Message::WriteVectorAll(const struct iovec* vector,
126                                     size_t vector_length) {
127  PDX_TRACE_NAME("Message::WriteVector");
128  if (auto svc = service_.lock()) {
129    const auto status =
130        svc->endpoint()->WriteMessageData(this, vector, vector_length);
131    if (!status)
132      return status.error_status();
133    size_t size_to_write = 0;
134    for (size_t i = 0; i < vector_length; i++)
135      size_to_write += vector[i].iov_len;
136    if (status.get() < size_to_write)
137      return ErrorStatus{EIO};
138    return {};
139  } else {
140    return ErrorStatus{ESHUTDOWN};
141  }
142}
143
144Status<size_t> Message::Write(const void* buffer, size_t length) {
145  PDX_TRACE_NAME("Message::Write");
146  if (auto svc = service_.lock()) {
147    const struct iovec vector = {const_cast<void*>(buffer), length};
148    return svc->endpoint()->WriteMessageData(this, &vector, 1);
149  } else {
150    return ErrorStatus{ESHUTDOWN};
151  }
152}
153
154Status<FileReference> Message::PushFileHandle(const LocalHandle& handle) {
155  PDX_TRACE_NAME("Message::PushFileHandle");
156  if (auto svc = service_.lock()) {
157    return svc->endpoint()->PushFileHandle(this, handle);
158  } else {
159    return ErrorStatus{ESHUTDOWN};
160  }
161}
162
163Status<FileReference> Message::PushFileHandle(const BorrowedHandle& handle) {
164  PDX_TRACE_NAME("Message::PushFileHandle");
165  if (auto svc = service_.lock()) {
166    return svc->endpoint()->PushFileHandle(this, handle);
167  } else {
168    return ErrorStatus{ESHUTDOWN};
169  }
170}
171
172Status<FileReference> Message::PushFileHandle(const RemoteHandle& handle) {
173  PDX_TRACE_NAME("Message::PushFileHandle");
174  if (auto svc = service_.lock()) {
175    return svc->endpoint()->PushFileHandle(this, handle);
176  } else {
177    return ErrorStatus{ESHUTDOWN};
178  }
179}
180
181Status<ChannelReference> Message::PushChannelHandle(
182    const LocalChannelHandle& handle) {
183  PDX_TRACE_NAME("Message::PushChannelHandle");
184  if (auto svc = service_.lock()) {
185    return svc->endpoint()->PushChannelHandle(this, handle);
186  } else {
187    return ErrorStatus{ESHUTDOWN};
188  }
189}
190
191Status<ChannelReference> Message::PushChannelHandle(
192    const BorrowedChannelHandle& handle) {
193  PDX_TRACE_NAME("Message::PushChannelHandle");
194  if (auto svc = service_.lock()) {
195    return svc->endpoint()->PushChannelHandle(this, handle);
196  } else {
197    return ErrorStatus{ESHUTDOWN};
198  }
199}
200
201Status<ChannelReference> Message::PushChannelHandle(
202    const RemoteChannelHandle& handle) {
203  PDX_TRACE_NAME("Message::PushChannelHandle");
204  if (auto svc = service_.lock()) {
205    return svc->endpoint()->PushChannelHandle(this, handle);
206  } else {
207    return ErrorStatus{ESHUTDOWN};
208  }
209}
210
211bool Message::GetFileHandle(FileReference ref, LocalHandle* handle) {
212  PDX_TRACE_NAME("Message::GetFileHandle");
213  auto svc = service_.lock();
214  if (!svc)
215    return false;
216
217  if (ref >= 0) {
218    *handle = svc->endpoint()->GetFileHandle(this, ref);
219    if (!handle->IsValid())
220      return false;
221  } else {
222    *handle = LocalHandle{ref};
223  }
224  return true;
225}
226
227bool Message::GetChannelHandle(ChannelReference ref,
228                               LocalChannelHandle* handle) {
229  PDX_TRACE_NAME("Message::GetChannelHandle");
230  auto svc = service_.lock();
231  if (!svc)
232    return false;
233
234  if (ref >= 0) {
235    *handle = svc->endpoint()->GetChannelHandle(this, ref);
236    if (!handle->valid())
237      return false;
238  } else {
239    *handle = LocalChannelHandle{nullptr, ref};
240  }
241  return true;
242}
243
244Status<void> Message::Reply(int return_code) {
245  PDX_TRACE_NAME("Message::Reply");
246  auto svc = service_.lock();
247  if (!replied_ && svc) {
248    const auto ret = svc->endpoint()->MessageReply(this, return_code);
249    replied_ = ret.ok();
250    return ret;
251  } else {
252    return ErrorStatus{EINVAL};
253  }
254}
255
256Status<void> Message::ReplyFileDescriptor(unsigned int fd) {
257  PDX_TRACE_NAME("Message::ReplyFileDescriptor");
258  auto svc = service_.lock();
259  if (!replied_ && svc) {
260    const auto ret = svc->endpoint()->MessageReplyFd(this, fd);
261    replied_ = ret.ok();
262    return ret;
263  } else {
264    return ErrorStatus{EINVAL};
265  }
266}
267
268Status<void> Message::ReplyError(unsigned int error) {
269  PDX_TRACE_NAME("Message::ReplyError");
270  auto svc = service_.lock();
271  if (!replied_ && svc) {
272    const auto ret =
273        svc->endpoint()->MessageReply(this, -static_cast<int>(error));
274    replied_ = ret.ok();
275    return ret;
276  } else {
277    return ErrorStatus{EINVAL};
278  }
279}
280
281Status<void> Message::Reply(const LocalHandle& handle) {
282  PDX_TRACE_NAME("Message::ReplyFileHandle");
283  auto svc = service_.lock();
284  if (!replied_ && svc) {
285    Status<void> ret;
286
287    if (handle)
288      ret = svc->endpoint()->MessageReplyFd(this, handle.Get());
289    else
290      ret = svc->endpoint()->MessageReply(this, handle.Get());
291
292    replied_ = ret.ok();
293    return ret;
294  } else {
295    return ErrorStatus{EINVAL};
296  }
297}
298
299Status<void> Message::Reply(const BorrowedHandle& handle) {
300  PDX_TRACE_NAME("Message::ReplyFileHandle");
301  auto svc = service_.lock();
302  if (!replied_ && svc) {
303    Status<void> ret;
304
305    if (handle)
306      ret = svc->endpoint()->MessageReplyFd(this, handle.Get());
307    else
308      ret = svc->endpoint()->MessageReply(this, handle.Get());
309
310    replied_ = ret.ok();
311    return ret;
312  } else {
313    return ErrorStatus{EINVAL};
314  }
315}
316
317Status<void> Message::Reply(const RemoteHandle& handle) {
318  PDX_TRACE_NAME("Message::ReplyFileHandle");
319  auto svc = service_.lock();
320  if (!replied_ && svc) {
321    Status<void> ret;
322
323    if (handle)
324      ret = svc->endpoint()->MessageReply(this, handle.Get());
325    else
326      ret = svc->endpoint()->MessageReply(this, handle.Get());
327
328    replied_ = ret.ok();
329    return ret;
330  } else {
331    return ErrorStatus{EINVAL};
332  }
333}
334
335Status<void> Message::Reply(const LocalChannelHandle& handle) {
336  auto svc = service_.lock();
337  if (!replied_ && svc) {
338    const auto ret = svc->endpoint()->MessageReplyChannelHandle(this, handle);
339    replied_ = ret.ok();
340    return ret;
341  } else {
342    return ErrorStatus{EINVAL};
343  }
344}
345
346Status<void> Message::Reply(const BorrowedChannelHandle& handle) {
347  auto svc = service_.lock();
348  if (!replied_ && svc) {
349    const auto ret = svc->endpoint()->MessageReplyChannelHandle(this, handle);
350    replied_ = ret.ok();
351    return ret;
352  } else {
353    return ErrorStatus{EINVAL};
354  }
355}
356
357Status<void> Message::Reply(const RemoteChannelHandle& handle) {
358  auto svc = service_.lock();
359  if (!replied_ && svc) {
360    const auto ret = svc->endpoint()->MessageReplyChannelHandle(this, handle);
361    replied_ = ret.ok();
362    return ret;
363  } else {
364    return ErrorStatus{EINVAL};
365  }
366}
367
368Status<void> Message::ModifyChannelEvents(int clear_mask, int set_mask) {
369  PDX_TRACE_NAME("Message::ModifyChannelEvents");
370  if (auto svc = service_.lock()) {
371    return svc->endpoint()->ModifyChannelEvents(info_.cid, clear_mask,
372                                                set_mask);
373  } else {
374    return ErrorStatus{ESHUTDOWN};
375  }
376}
377
378Status<RemoteChannelHandle> Message::PushChannel(
379    int flags, const std::shared_ptr<Channel>& channel, int* channel_id) {
380  PDX_TRACE_NAME("Message::PushChannel");
381  if (auto svc = service_.lock()) {
382    return svc->PushChannel(this, flags, channel, channel_id);
383  } else {
384    return ErrorStatus(ESHUTDOWN);
385  }
386}
387
388Status<RemoteChannelHandle> Message::PushChannel(
389    Service* service, int flags, const std::shared_ptr<Channel>& channel,
390    int* channel_id) {
391  PDX_TRACE_NAME("Message::PushChannel");
392  return service->PushChannel(this, flags, channel, channel_id);
393}
394
395Status<int> Message::CheckChannel(ChannelReference ref,
396                                  std::shared_ptr<Channel>* channel) const {
397  PDX_TRACE_NAME("Message::CheckChannel");
398  if (auto svc = service_.lock()) {
399    return svc->CheckChannel(this, ref, channel);
400  } else {
401    return ErrorStatus(ESHUTDOWN);
402  }
403}
404
405Status<int> Message::CheckChannel(const Service* service, ChannelReference ref,
406                                  std::shared_ptr<Channel>* channel) const {
407  PDX_TRACE_NAME("Message::CheckChannel");
408  return service->CheckChannel(this, ref, channel);
409}
410
411pid_t Message::GetProcessId() const { return info_.pid; }
412
413pid_t Message::GetThreadId() const { return info_.tid; }
414
415uid_t Message::GetEffectiveUserId() const { return info_.euid; }
416
417gid_t Message::GetEffectiveGroupId() const { return info_.egid; }
418
419int Message::GetChannelId() const { return info_.cid; }
420
421int Message::GetMessageId() const { return info_.mid; }
422
423int Message::GetOp() const { return info_.op; }
424
425int Message::GetFlags() const { return info_.flags; }
426
427size_t Message::GetSendLength() const { return info_.send_len; }
428
429size_t Message::GetReceiveLength() const { return info_.recv_len; }
430
431size_t Message::GetFileDescriptorCount() const { return info_.fd_count; }
432
433std::shared_ptr<Channel> Message::GetChannel() const { return channel_.lock(); }
434
435Status<void> Message::SetChannel(const std::shared_ptr<Channel>& chan) {
436  channel_ = chan;
437  Status<void> status;
438  if (auto svc = service_.lock())
439    status = svc->SetChannel(info_.cid, chan);
440  return status;
441}
442
443std::shared_ptr<Service> Message::GetService() const { return service_.lock(); }
444
445const MessageInfo& Message::GetInfo() const { return info_; }
446
447Service::Service(const std::string& name, std::unique_ptr<Endpoint> endpoint)
448    : name_(name), endpoint_{std::move(endpoint)} {
449  if (!endpoint_)
450    return;
451
452  const auto status = endpoint_->SetService(this);
453  ALOGE_IF(!status, "Failed to set service context because: %s",
454           status.GetErrorMessage().c_str());
455}
456
457Service::~Service() {
458  if (endpoint_) {
459    const auto status = endpoint_->SetService(nullptr);
460    ALOGE_IF(!status, "Failed to clear service context because: %s",
461             status.GetErrorMessage().c_str());
462  }
463}
464
465std::shared_ptr<Service> Service::GetFromMessageInfo(const MessageInfo& info) {
466  return info.service ? info.service->shared_from_this()
467                      : std::shared_ptr<Service>();
468}
469
470bool Service::IsInitialized() const { return endpoint_.get() != nullptr; }
471
472std::shared_ptr<Channel> Service::OnChannelOpen(Message& /*message*/) {
473  return nullptr;
474}
475
476void Service::OnChannelClose(Message& /*message*/,
477                             const std::shared_ptr<Channel>& /*channel*/) {}
478
479Status<void> Service::SetChannel(int channel_id,
480                                 const std::shared_ptr<Channel>& channel) {
481  PDX_TRACE_NAME("Service::SetChannel");
482  std::lock_guard<std::mutex> autolock(channels_mutex_);
483
484  const auto status = endpoint_->SetChannel(channel_id, channel.get());
485  if (!status) {
486    ALOGE("%s::SetChannel: Failed to set channel context: %s\n", name_.c_str(),
487          status.GetErrorMessage().c_str());
488
489    // It's possible someone mucked with things behind our back by calling the C
490    // API directly. Since we know the channel id isn't valid, make sure we
491    // don't have it in the channels map.
492    if (status.error() == ENOENT)
493      channels_.erase(channel_id);
494  } else {
495    if (channel != nullptr)
496      channels_[channel_id] = channel;
497    else
498      channels_.erase(channel_id);
499  }
500  return status;
501}
502
503std::shared_ptr<Channel> Service::GetChannel(int channel_id) const {
504  PDX_TRACE_NAME("Service::GetChannel");
505  std::lock_guard<std::mutex> autolock(channels_mutex_);
506
507  auto search = channels_.find(channel_id);
508  if (search != channels_.end())
509    return search->second;
510  else
511    return nullptr;
512}
513
514Status<void> Service::CloseChannel(int channel_id) {
515  PDX_TRACE_NAME("Service::CloseChannel");
516  std::lock_guard<std::mutex> autolock(channels_mutex_);
517
518  const auto status = endpoint_->CloseChannel(channel_id);
519
520  // Always erase the map entry, in case someone mucked with things behind our
521  // back using the C API directly.
522  channels_.erase(channel_id);
523
524  return status;
525}
526
527Status<void> Service::ModifyChannelEvents(int channel_id, int clear_mask,
528                                          int set_mask) {
529  PDX_TRACE_NAME("Service::ModifyChannelEvents");
530  return endpoint_->ModifyChannelEvents(channel_id, clear_mask, set_mask);
531}
532
533Status<RemoteChannelHandle> Service::PushChannel(
534    Message* message, int flags, const std::shared_ptr<Channel>& channel,
535    int* channel_id) {
536  PDX_TRACE_NAME("Service::PushChannel");
537
538  std::lock_guard<std::mutex> autolock(channels_mutex_);
539
540  int channel_id_temp = -1;
541  Status<RemoteChannelHandle> ret =
542      endpoint_->PushChannel(message, flags, channel.get(), &channel_id_temp);
543  ALOGE_IF(!ret.ok(), "%s::PushChannel: Failed to push channel: %s",
544           name_.c_str(), strerror(ret.error()));
545
546  if (channel && channel_id_temp != -1)
547    channels_[channel_id_temp] = channel;
548  if (channel_id)
549    *channel_id = channel_id_temp;
550
551  return ret;
552}
553
554Status<int> Service::CheckChannel(const Message* message, ChannelReference ref,
555                                  std::shared_ptr<Channel>* channel) const {
556  PDX_TRACE_NAME("Service::CheckChannel");
557
558  // Synchronization to maintain consistency between the kernel's channel
559  // context pointer and the userspace channels_ map. Other threads may attempt
560  // to modify the map at the same time, which could cause the channel context
561  // pointer returned by the kernel to be invalid.
562  std::lock_guard<std::mutex> autolock(channels_mutex_);
563
564  Channel* channel_context = nullptr;
565  Status<int> ret = endpoint_->CheckChannel(
566      message, ref, channel ? &channel_context : nullptr);
567  if (ret && channel) {
568    if (channel_context)
569      *channel = channel_context->shared_from_this();
570    else
571      *channel = nullptr;
572  }
573
574  return ret;
575}
576
577std::string Service::DumpState(size_t /*max_length*/) { return ""; }
578
579Status<void> Service::HandleMessage(Message& message) {
580  return DefaultHandleMessage(message);
581}
582
583void Service::HandleImpulse(Message& /*impulse*/) {}
584
585Status<void> Service::HandleSystemMessage(Message& message) {
586  const MessageInfo& info = message.GetInfo();
587
588  switch (info.op) {
589    case opcodes::CHANNEL_OPEN: {
590      ALOGD("%s::OnChannelOpen: pid=%d cid=%d\n", name_.c_str(), info.pid,
591            info.cid);
592      message.SetChannel(OnChannelOpen(message));
593      return message.Reply(0);
594    }
595
596    case opcodes::CHANNEL_CLOSE: {
597      ALOGD("%s::OnChannelClose: pid=%d cid=%d\n", name_.c_str(), info.pid,
598            info.cid);
599      OnChannelClose(message, Channel::GetFromMessageInfo(info));
600      message.SetChannel(nullptr);
601      return message.Reply(0);
602    }
603
604    case opcodes::REPORT_SYSPROP_CHANGE:
605      ALOGD("%s:REPORT_SYSPROP_CHANGE: pid=%d cid=%d\n", name_.c_str(),
606            info.pid, info.cid);
607      OnSysPropChange();
608      android::report_sysprop_change();
609      return message.Reply(0);
610
611    case opcodes::DUMP_STATE: {
612      ALOGD("%s:DUMP_STATE: pid=%d cid=%d\n", name_.c_str(), info.pid,
613            info.cid);
614      auto response = DumpState(message.GetReceiveLength());
615      const size_t response_size = response.size() < message.GetReceiveLength()
616                                       ? response.size()
617                                       : message.GetReceiveLength();
618      const Status<size_t> status =
619          message.Write(response.data(), response_size);
620      if (status && status.get() < response_size)
621        return message.ReplyError(EIO);
622      else
623        return message.Reply(status);
624    }
625
626    default:
627      return ErrorStatus{EOPNOTSUPP};
628  }
629}
630
631Status<void> Service::DefaultHandleMessage(Message& message) {
632  const MessageInfo& info = message.GetInfo();
633
634  ALOGD_IF(TRACE, "Service::DefaultHandleMessage: pid=%d cid=%d op=%d\n",
635           info.pid, info.cid, info.op);
636
637  switch (info.op) {
638    case opcodes::CHANNEL_OPEN:
639    case opcodes::CHANNEL_CLOSE:
640    case opcodes::REPORT_SYSPROP_CHANGE:
641    case opcodes::DUMP_STATE:
642      return HandleSystemMessage(message);
643
644    default:
645      return message.ReplyError(EOPNOTSUPP);
646  }
647}
648
649void Service::OnSysPropChange() {}
650
651Status<void> Service::ReceiveAndDispatch() {
652  Message message;
653  const auto status = endpoint_->MessageReceive(&message);
654  if (!status) {
655    ALOGE("Failed to receive message: %s\n", status.GetErrorMessage().c_str());
656    return status;
657  }
658
659  std::shared_ptr<Service> service = message.GetService();
660
661  if (!service) {
662    ALOGE("Service::ReceiveAndDispatch: service context is NULL!!!\n");
663    // Don't block the sender indefinitely in this error case.
664    endpoint_->MessageReply(&message, -EINVAL);
665    return ErrorStatus{EINVAL};
666  }
667
668  if (message.IsImpulse()) {
669    service->HandleImpulse(message);
670    return {};
671  } else if (service->HandleSystemMessage(message)) {
672    return {};
673  } else {
674    return service->HandleMessage(message);
675  }
676}
677
678Status<void> Service::Cancel() { return endpoint_->Cancel(); }
679
680}  // namespace pdx
681}  // namespace android
682