1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "mojo/system/raw_channel.h"
6
7#include <errno.h>
8#include <sys/uio.h>
9#include <unistd.h>
10
11#include <algorithm>
12#include <deque>
13
14#include "base/bind.h"
15#include "base/location.h"
16#include "base/logging.h"
17#include "base/macros.h"
18#include "base/memory/scoped_ptr.h"
19#include "base/memory/weak_ptr.h"
20#include "base/message_loop/message_loop.h"
21#include "base/synchronization/lock.h"
22#include "mojo/embedder/platform_channel_utils_posix.h"
23#include "mojo/embedder/platform_handle.h"
24#include "mojo/embedder/platform_handle_vector.h"
25#include "mojo/system/transport_data.h"
26
27namespace mojo {
28namespace system {
29
30namespace {
31
32class RawChannelPosix : public RawChannel,
33                        public base::MessageLoopForIO::Watcher {
34 public:
35  explicit RawChannelPosix(embedder::ScopedPlatformHandle handle);
36  virtual ~RawChannelPosix();
37
38  // |RawChannel| public methods:
39  virtual size_t GetSerializedPlatformHandleSize() const OVERRIDE;
40
41 private:
42  // |RawChannel| protected methods:
43  // Actually override this so that we can send multiple messages with (only)
44  // FDs if necessary.
45  virtual void EnqueueMessageNoLock(
46      scoped_ptr<MessageInTransit> message) OVERRIDE;
47  // Override this to handle those extra FD-only messages.
48  virtual bool OnReadMessageForRawChannel(
49      const MessageInTransit::View& message_view) OVERRIDE;
50  virtual IOResult Read(size_t* bytes_read) OVERRIDE;
51  virtual IOResult ScheduleRead() OVERRIDE;
52  virtual embedder::ScopedPlatformHandleVectorPtr GetReadPlatformHandles(
53      size_t num_platform_handles,
54      const void* platform_handle_table) OVERRIDE;
55  virtual IOResult WriteNoLock(size_t* platform_handles_written,
56                               size_t* bytes_written) OVERRIDE;
57  virtual IOResult ScheduleWriteNoLock() OVERRIDE;
58  virtual bool OnInit() OVERRIDE;
59  virtual void OnShutdownNoLock(scoped_ptr<ReadBuffer> read_buffer,
60                                scoped_ptr<WriteBuffer> write_buffer) OVERRIDE;
61
62  // |base::MessageLoopForIO::Watcher| implementation:
63  virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE;
64  virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE;
65
66  // Implements most of |Read()| (except for a bit of clean-up):
67  IOResult ReadImpl(size_t* bytes_read);
68
69  // Watches for |fd_| to become writable. Must be called on the I/O thread.
70  void WaitToWrite();
71
72  embedder::ScopedPlatformHandle fd_;
73
74  // The following members are only used on the I/O thread:
75  scoped_ptr<base::MessageLoopForIO::FileDescriptorWatcher> read_watcher_;
76  scoped_ptr<base::MessageLoopForIO::FileDescriptorWatcher> write_watcher_;
77
78  bool pending_read_;
79
80  std::deque<embedder::PlatformHandle> read_platform_handles_;
81
82  // The following members are used on multiple threads and protected by
83  // |write_lock()|:
84  bool pending_write_;
85
86  // This is used for posting tasks from write threads to the I/O thread. It
87  // must only be accessed under |write_lock_|. The weak pointers it produces
88  // are only used/invalidated on the I/O thread.
89  base::WeakPtrFactory<RawChannelPosix> weak_ptr_factory_;
90
91  DISALLOW_COPY_AND_ASSIGN(RawChannelPosix);
92};
93
94RawChannelPosix::RawChannelPosix(embedder::ScopedPlatformHandle handle)
95    : fd_(handle.Pass()),
96      pending_read_(false),
97      pending_write_(false),
98      weak_ptr_factory_(this) {
99  DCHECK(fd_.is_valid());
100}
101
102RawChannelPosix::~RawChannelPosix() {
103  DCHECK(!pending_read_);
104  DCHECK(!pending_write_);
105
106  // No need to take the |write_lock()| here -- if there are still weak pointers
107  // outstanding, then we're hosed anyway (since we wouldn't be able to
108  // invalidate them cleanly, since we might not be on the I/O thread).
109  DCHECK(!weak_ptr_factory_.HasWeakPtrs());
110
111  // These must have been shut down/destroyed on the I/O thread.
112  DCHECK(!read_watcher_);
113  DCHECK(!write_watcher_);
114
115  embedder::CloseAllPlatformHandles(&read_platform_handles_);
116}
117
118size_t RawChannelPosix::GetSerializedPlatformHandleSize() const {
119  // We don't actually need any space on POSIX (since we just send FDs).
120  return 0;
121}
122
123void RawChannelPosix::EnqueueMessageNoLock(
124    scoped_ptr<MessageInTransit> message) {
125  if (message->transport_data()) {
126    embedder::PlatformHandleVector* const platform_handles =
127        message->transport_data()->platform_handles();
128    if (platform_handles &&
129        platform_handles->size() > embedder::kPlatformChannelMaxNumHandles) {
130      // We can't attach all the FDs to a single message, so we have to "split"
131      // the message. Send as many control messages as needed first with FDs
132      // attached (and no data).
133      size_t i = 0;
134      for (; platform_handles->size() - i >
135                 embedder::kPlatformChannelMaxNumHandles;
136           i += embedder::kPlatformChannelMaxNumHandles) {
137        scoped_ptr<MessageInTransit> fd_message(new MessageInTransit(
138            MessageInTransit::kTypeRawChannel,
139            MessageInTransit::kSubtypeRawChannelPosixExtraPlatformHandles,
140            0,
141            nullptr));
142        embedder::ScopedPlatformHandleVectorPtr fds(
143            new embedder::PlatformHandleVector(
144                platform_handles->begin() + i,
145                platform_handles->begin() + i +
146                    embedder::kPlatformChannelMaxNumHandles));
147        fd_message->SetTransportData(
148            make_scoped_ptr(new TransportData(fds.Pass())));
149        RawChannel::EnqueueMessageNoLock(fd_message.Pass());
150      }
151
152      // Remove the handles that we "moved" into the other messages.
153      platform_handles->erase(platform_handles->begin(),
154                              platform_handles->begin() + i);
155    }
156  }
157
158  RawChannel::EnqueueMessageNoLock(message.Pass());
159}
160
161bool RawChannelPosix::OnReadMessageForRawChannel(
162    const MessageInTransit::View& message_view) {
163  DCHECK_EQ(message_view.type(), MessageInTransit::kTypeRawChannel);
164
165  if (message_view.subtype() ==
166      MessageInTransit::kSubtypeRawChannelPosixExtraPlatformHandles) {
167    // We don't need to do anything. |RawChannel| won't extract the platform
168    // handles, and they'll be accumulated in |Read()|.
169    return true;
170  }
171
172  return RawChannel::OnReadMessageForRawChannel(message_view);
173}
174
175RawChannel::IOResult RawChannelPosix::Read(size_t* bytes_read) {
176  DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io());
177  DCHECK(!pending_read_);
178
179  IOResult rv = ReadImpl(bytes_read);
180  if (rv != IO_SUCCEEDED && rv != IO_PENDING) {
181    // Make sure that |OnFileCanReadWithoutBlocking()| won't be called again.
182    read_watcher_.reset();
183  }
184  return rv;
185}
186
187RawChannel::IOResult RawChannelPosix::ScheduleRead() {
188  DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io());
189  DCHECK(!pending_read_);
190
191  pending_read_ = true;
192
193  return IO_PENDING;
194}
195
196embedder::ScopedPlatformHandleVectorPtr RawChannelPosix::GetReadPlatformHandles(
197    size_t num_platform_handles,
198    const void* /*platform_handle_table*/) {
199  DCHECK_GT(num_platform_handles, 0u);
200
201  if (read_platform_handles_.size() < num_platform_handles) {
202    embedder::CloseAllPlatformHandles(&read_platform_handles_);
203    read_platform_handles_.clear();
204    return embedder::ScopedPlatformHandleVectorPtr();
205  }
206
207  embedder::ScopedPlatformHandleVectorPtr rv(
208      new embedder::PlatformHandleVector(num_platform_handles));
209  rv->assign(read_platform_handles_.begin(),
210             read_platform_handles_.begin() + num_platform_handles);
211  read_platform_handles_.erase(
212      read_platform_handles_.begin(),
213      read_platform_handles_.begin() + num_platform_handles);
214  return rv.Pass();
215}
216
217RawChannel::IOResult RawChannelPosix::WriteNoLock(
218    size_t* platform_handles_written,
219    size_t* bytes_written) {
220  write_lock().AssertAcquired();
221
222  DCHECK(!pending_write_);
223
224  size_t num_platform_handles = 0;
225  ssize_t write_result;
226  if (write_buffer_no_lock()->HavePlatformHandlesToSend()) {
227    embedder::PlatformHandle* platform_handles;
228    void* serialization_data;  // Actually unused.
229    write_buffer_no_lock()->GetPlatformHandlesToSend(
230        &num_platform_handles, &platform_handles, &serialization_data);
231    DCHECK_GT(num_platform_handles, 0u);
232    DCHECK_LE(num_platform_handles, embedder::kPlatformChannelMaxNumHandles);
233    DCHECK(platform_handles);
234
235    // TODO(vtl): Reduce code duplication. (This is duplicated from below.)
236    std::vector<WriteBuffer::Buffer> buffers;
237    write_buffer_no_lock()->GetBuffers(&buffers);
238    DCHECK(!buffers.empty());
239    const size_t kMaxBufferCount = 10;
240    iovec iov[kMaxBufferCount];
241    size_t buffer_count = std::min(buffers.size(), kMaxBufferCount);
242    for (size_t i = 0; i < buffer_count; ++i) {
243      iov[i].iov_base = const_cast<char*>(buffers[i].addr);
244      iov[i].iov_len = buffers[i].size;
245    }
246
247    write_result = embedder::PlatformChannelSendmsgWithHandles(
248        fd_.get(), iov, buffer_count, platform_handles, num_platform_handles);
249    for (size_t i = 0; i < num_platform_handles; i++)
250      platform_handles[i].CloseIfNecessary();
251  } else {
252    std::vector<WriteBuffer::Buffer> buffers;
253    write_buffer_no_lock()->GetBuffers(&buffers);
254    DCHECK(!buffers.empty());
255
256    if (buffers.size() == 1) {
257      write_result = embedder::PlatformChannelWrite(
258          fd_.get(), buffers[0].addr, buffers[0].size);
259    } else {
260      const size_t kMaxBufferCount = 10;
261      iovec iov[kMaxBufferCount];
262      size_t buffer_count = std::min(buffers.size(), kMaxBufferCount);
263      for (size_t i = 0; i < buffer_count; ++i) {
264        iov[i].iov_base = const_cast<char*>(buffers[i].addr);
265        iov[i].iov_len = buffers[i].size;
266      }
267
268      write_result =
269          embedder::PlatformChannelWritev(fd_.get(), iov, buffer_count);
270    }
271  }
272
273  if (write_result >= 0) {
274    *platform_handles_written = num_platform_handles;
275    *bytes_written = static_cast<size_t>(write_result);
276    return IO_SUCCEEDED;
277  }
278
279  if (errno == EPIPE)
280    return IO_FAILED_SHUTDOWN;
281
282  if (errno != EAGAIN && errno != EWOULDBLOCK) {
283    PLOG(WARNING) << "sendmsg/write/writev";
284    return IO_FAILED_UNKNOWN;
285  }
286
287  return ScheduleWriteNoLock();
288}
289
290RawChannel::IOResult RawChannelPosix::ScheduleWriteNoLock() {
291  write_lock().AssertAcquired();
292
293  DCHECK(!pending_write_);
294
295  // Set up to wait for the FD to become writable.
296  // If we're not on the I/O thread, we have to post a task to do this.
297  if (base::MessageLoop::current() != message_loop_for_io()) {
298    message_loop_for_io()->PostTask(FROM_HERE,
299                                    base::Bind(&RawChannelPosix::WaitToWrite,
300                                               weak_ptr_factory_.GetWeakPtr()));
301    pending_write_ = true;
302    return IO_PENDING;
303  }
304
305  if (message_loop_for_io()->WatchFileDescriptor(
306          fd_.get().fd,
307          false,
308          base::MessageLoopForIO::WATCH_WRITE,
309          write_watcher_.get(),
310          this)) {
311    pending_write_ = true;
312    return IO_PENDING;
313  }
314
315  return IO_FAILED_UNKNOWN;
316}
317
318bool RawChannelPosix::OnInit() {
319  DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io());
320
321  DCHECK(!read_watcher_);
322  read_watcher_.reset(new base::MessageLoopForIO::FileDescriptorWatcher());
323  DCHECK(!write_watcher_);
324  write_watcher_.reset(new base::MessageLoopForIO::FileDescriptorWatcher());
325
326  if (!message_loop_for_io()->WatchFileDescriptor(
327          fd_.get().fd,
328          true,
329          base::MessageLoopForIO::WATCH_READ,
330          read_watcher_.get(),
331          this)) {
332    // TODO(vtl): I'm not sure |WatchFileDescriptor()| actually fails cleanly
333    // (in the sense of returning the message loop's state to what it was before
334    // it was called).
335    read_watcher_.reset();
336    write_watcher_.reset();
337    return false;
338  }
339
340  return true;
341}
342
343void RawChannelPosix::OnShutdownNoLock(
344    scoped_ptr<ReadBuffer> /*read_buffer*/,
345    scoped_ptr<WriteBuffer> /*write_buffer*/) {
346  DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io());
347  write_lock().AssertAcquired();
348
349  read_watcher_.reset();   // This will stop watching (if necessary).
350  write_watcher_.reset();  // This will stop watching (if necessary).
351
352  pending_read_ = false;
353  pending_write_ = false;
354
355  DCHECK(fd_.is_valid());
356  fd_.reset();
357
358  weak_ptr_factory_.InvalidateWeakPtrs();
359}
360
361void RawChannelPosix::OnFileCanReadWithoutBlocking(int fd) {
362  DCHECK_EQ(fd, fd_.get().fd);
363  DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io());
364
365  if (!pending_read_) {
366    NOTREACHED();
367    return;
368  }
369
370  pending_read_ = false;
371  size_t bytes_read = 0;
372  IOResult io_result = Read(&bytes_read);
373  if (io_result != IO_PENDING)
374    OnReadCompleted(io_result, bytes_read);
375
376  // On failure, |read_watcher_| must have been reset; on success,
377  // we assume that |OnReadCompleted()| always schedules another read.
378  // Otherwise, we could end up spinning -- getting
379  // |OnFileCanReadWithoutBlocking()| again and again but not doing any actual
380  // read.
381  // TODO(yzshen): An alternative is to stop watching if RawChannel doesn't
382  // schedule a new read. But that code won't be reached under the current
383  // RawChannel implementation.
384  DCHECK(!read_watcher_ || pending_read_);
385}
386
387void RawChannelPosix::OnFileCanWriteWithoutBlocking(int fd) {
388  DCHECK_EQ(fd, fd_.get().fd);
389  DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io());
390
391  IOResult io_result;
392  size_t platform_handles_written = 0;
393  size_t bytes_written = 0;
394  {
395    base::AutoLock locker(write_lock());
396
397    DCHECK(pending_write_);
398
399    pending_write_ = false;
400    io_result = WriteNoLock(&platform_handles_written, &bytes_written);
401  }
402
403  if (io_result != IO_PENDING)
404    OnWriteCompleted(io_result, platform_handles_written, bytes_written);
405}
406
407RawChannel::IOResult RawChannelPosix::ReadImpl(size_t* bytes_read) {
408  char* buffer = nullptr;
409  size_t bytes_to_read = 0;
410  read_buffer()->GetBuffer(&buffer, &bytes_to_read);
411
412  size_t old_num_platform_handles = read_platform_handles_.size();
413  ssize_t read_result = embedder::PlatformChannelRecvmsg(
414      fd_.get(), buffer, bytes_to_read, &read_platform_handles_);
415  if (read_platform_handles_.size() > old_num_platform_handles) {
416    DCHECK_LE(read_platform_handles_.size() - old_num_platform_handles,
417              embedder::kPlatformChannelMaxNumHandles);
418
419    // We should never accumulate more than |TransportData::kMaxPlatformHandles
420    // + embedder::kPlatformChannelMaxNumHandles| handles. (The latter part is
421    // possible because we could have accumulated all the handles for a message,
422    // then received the message data plus the first set of handles for the next
423    // message in the subsequent |recvmsg()|.)
424    if (read_platform_handles_.size() >
425        (TransportData::kMaxPlatformHandles +
426         embedder::kPlatformChannelMaxNumHandles)) {
427      LOG(ERROR) << "Received too many platform handles";
428      embedder::CloseAllPlatformHandles(&read_platform_handles_);
429      read_platform_handles_.clear();
430      return IO_FAILED_UNKNOWN;
431    }
432  }
433
434  if (read_result > 0) {
435    *bytes_read = static_cast<size_t>(read_result);
436    return IO_SUCCEEDED;
437  }
438
439  // |read_result == 0| means "end of file".
440  if (read_result == 0)
441    return IO_FAILED_SHUTDOWN;
442
443  if (errno == EAGAIN || errno == EWOULDBLOCK)
444    return ScheduleRead();
445
446  if (errno == ECONNRESET)
447    return IO_FAILED_BROKEN;
448
449  PLOG(WARNING) << "recvmsg";
450  return IO_FAILED_UNKNOWN;
451}
452
453void RawChannelPosix::WaitToWrite() {
454  DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io());
455
456  DCHECK(write_watcher_);
457
458  if (!message_loop_for_io()->WatchFileDescriptor(
459          fd_.get().fd,
460          false,
461          base::MessageLoopForIO::WATCH_WRITE,
462          write_watcher_.get(),
463          this)) {
464    {
465      base::AutoLock locker(write_lock());
466
467      DCHECK(pending_write_);
468      pending_write_ = false;
469    }
470    OnWriteCompleted(IO_FAILED_UNKNOWN, 0, 0);
471  }
472}
473
474}  // namespace
475
476// -----------------------------------------------------------------------------
477
478// Static factory method declared in raw_channel.h.
479// static
480scoped_ptr<RawChannel> RawChannel::Create(
481    embedder::ScopedPlatformHandle handle) {
482  return scoped_ptr<RawChannel>(new RawChannelPosix(handle.Pass()));
483}
484
485}  // namespace system
486}  // namespace mojo
487