1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "mojo/system/raw_channel.h"
6
7#include <errno.h>
8#include <sys/uio.h>
9#include <unistd.h>
10
11#include <algorithm>
12#include <deque>
13
14#include "base/basictypes.h"
15#include "base/bind.h"
16#include "base/compiler_specific.h"
17#include "base/location.h"
18#include "base/logging.h"
19#include "base/memory/scoped_ptr.h"
20#include "base/memory/weak_ptr.h"
21#include "base/message_loop/message_loop.h"
22#include "base/synchronization/lock.h"
23#include "mojo/embedder/platform_channel_utils_posix.h"
24#include "mojo/embedder/platform_handle.h"
25#include "mojo/embedder/platform_handle_vector.h"
26#include "mojo/system/transport_data.h"
27
28namespace mojo {
29namespace system {
30
31namespace {
32
// POSIX implementation of |RawChannel| on top of a single file descriptor
// (|fd_|), driven by |base::MessageLoopForIO| watchers on the I/O thread.
// Platform handles (FDs) are passed alongside the message data rather than
// being serialized into it. A message with more than
// |embedder::kPlatformChannelMaxNumHandles| handles is split so that the
// surplus handles go out first in data-less "extra platform handles" messages
// (see |EnqueueMessageNoLock()|).
class RawChannelPosix : public RawChannel,
                        public base::MessageLoopForIO::Watcher {
 public:
  // Takes ownership of |handle|, which must be valid.
  explicit RawChannelPosix(embedder::ScopedPlatformHandle handle);
  virtual ~RawChannelPosix();

  // |RawChannel| public methods:
  virtual size_t GetSerializedPlatformHandleSize() const OVERRIDE;

 private:
  // |RawChannel| protected methods:
  // Actually override this so that we can send multiple messages with (only)
  // FDs if necessary.
  virtual void EnqueueMessageNoLock(
      scoped_ptr<MessageInTransit> message) OVERRIDE;
  // Override this to handle those extra FD-only messages.
  virtual bool OnReadMessageForRawChannel(
      const MessageInTransit::View& message_view) OVERRIDE;
  virtual IOResult Read(size_t* bytes_read) OVERRIDE;
  virtual IOResult ScheduleRead() OVERRIDE;
  virtual embedder::ScopedPlatformHandleVectorPtr GetReadPlatformHandles(
      size_t num_platform_handles,
      const void* platform_handle_table) OVERRIDE;
  virtual IOResult WriteNoLock(size_t* platform_handles_written,
                               size_t* bytes_written) OVERRIDE;
  virtual IOResult ScheduleWriteNoLock() OVERRIDE;
  virtual bool OnInit() OVERRIDE;
  virtual void OnShutdownNoLock(
      scoped_ptr<ReadBuffer> read_buffer,
      scoped_ptr<WriteBuffer> write_buffer) OVERRIDE;

  // |base::MessageLoopForIO::Watcher| implementation:
  virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE;
  virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE;

  // Watches for |fd_| to become writable. Must be called on the I/O thread.
  // This is posted by |ScheduleWriteNoLock()| when that is called from a
  // thread other than the I/O thread.
  void WaitToWrite();

  // The channel's FD. It is valid from construction until
  // |OnShutdownNoLock()| resets it.
  embedder::ScopedPlatformHandle fd_;

  // The following members are only used on the I/O thread:
  // Persistent watch for readability; it is created in |OnInit()|.
  scoped_ptr<base::MessageLoopForIO::FileDescriptorWatcher> read_watcher_;
  // One-shot watch for writability; it is armed on demand.
  scoped_ptr<base::MessageLoopForIO::FileDescriptorWatcher> write_watcher_;

  // Set by |ScheduleRead()| and cleared when |OnFileCanReadWithoutBlocking()|
  // consumes it.
  bool pending_read_;

  // FDs received by |Read()| that have not yet been handed out by
  // |GetReadPlatformHandles()|. Any that remain are closed on destruction.
  std::deque<embedder::PlatformHandle> read_platform_handles_;

  // The following members are used on multiple threads and protected by
  // |write_lock()|:
  // Set while waiting for |fd_| to become writable (see
  // |ScheduleWriteNoLock()|).
  bool pending_write_;

  // This is used for posting tasks from write threads to the I/O thread. It
  // must only be accessed under |write_lock()|. The weak pointers it produces
  // are only used/invalidated on the I/O thread.
  base::WeakPtrFactory<RawChannelPosix> weak_ptr_factory_;

  DISALLOW_COPY_AND_ASSIGN(RawChannelPosix);
};
92
93RawChannelPosix::RawChannelPosix(embedder::ScopedPlatformHandle handle)
94    : fd_(handle.Pass()),
95      pending_read_(false),
96      pending_write_(false),
97      weak_ptr_factory_(this) {
98  DCHECK(fd_.is_valid());
99}
100
101RawChannelPosix::~RawChannelPosix() {
102  DCHECK(!pending_read_);
103  DCHECK(!pending_write_);
104
105  // No need to take the |write_lock()| here -- if there are still weak pointers
106  // outstanding, then we're hosed anyway (since we wouldn't be able to
107  // invalidate them cleanly, since we might not be on the I/O thread).
108  DCHECK(!weak_ptr_factory_.HasWeakPtrs());
109
110  // These must have been shut down/destroyed on the I/O thread.
111  DCHECK(!read_watcher_);
112  DCHECK(!write_watcher_);
113
114  embedder::CloseAllPlatformHandles(&read_platform_handles_);
115}
116
117size_t RawChannelPosix::GetSerializedPlatformHandleSize() const {
118  // We don't actually need any space on POSIX (since we just send FDs).
119  return 0;
120}
121
122void RawChannelPosix::EnqueueMessageNoLock(
123    scoped_ptr<MessageInTransit> message) {
124  if (message->transport_data()) {
125    embedder::PlatformHandleVector* const platform_handles =
126        message->transport_data()->platform_handles();
127    if (platform_handles &&
128        platform_handles->size() > embedder::kPlatformChannelMaxNumHandles) {
129      // We can't attach all the FDs to a single message, so we have to "split"
130      // the message. Send as many control messages as needed first with FDs
131      // attached (and no data).
132      size_t i = 0;
133      for (; platform_handles->size() - i >
134                 embedder::kPlatformChannelMaxNumHandles;
135           i += embedder::kPlatformChannelMaxNumHandles) {
136        scoped_ptr<MessageInTransit> fd_message(
137            new MessageInTransit(
138                MessageInTransit::kTypeRawChannel,
139                MessageInTransit::kSubtypeRawChannelPosixExtraPlatformHandles,
140                0,
141                NULL));
142        embedder::ScopedPlatformHandleVectorPtr fds(
143            new embedder::PlatformHandleVector(
144                platform_handles->begin() + i,
145                platform_handles->begin() + i +
146                    embedder::kPlatformChannelMaxNumHandles));
147        fd_message->SetTransportData(
148            make_scoped_ptr(new TransportData(fds.Pass())));
149        RawChannel::EnqueueMessageNoLock(fd_message.Pass());
150      }
151
152      // Remove the handles that we "moved" into the other messages.
153      platform_handles->erase(platform_handles->begin(),
154                              platform_handles->begin() + i);
155    }
156  }
157
158  RawChannel::EnqueueMessageNoLock(message.Pass());
159}
160
161bool RawChannelPosix::OnReadMessageForRawChannel(
162    const MessageInTransit::View& message_view) {
163  DCHECK_EQ(message_view.type(), MessageInTransit::kTypeRawChannel);
164
165  if (message_view.subtype() ==
166          MessageInTransit::kSubtypeRawChannelPosixExtraPlatformHandles) {
167    // We don't need to do anything. |RawChannel| won't extract the platform
168    // handles, and they'll be accumulated in |Read()|.
169    return true;
170  }
171
172  return RawChannel::OnReadMessageForRawChannel(message_view);
173}
174
175RawChannel::IOResult RawChannelPosix::Read(size_t* bytes_read) {
176  DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io());
177  DCHECK(!pending_read_);
178
179  char* buffer = NULL;
180  size_t bytes_to_read = 0;
181  read_buffer()->GetBuffer(&buffer, &bytes_to_read);
182
183  size_t old_num_platform_handles = read_platform_handles_.size();
184  ssize_t read_result =
185      embedder::PlatformChannelRecvmsg(fd_.get(),
186                                       buffer,
187                                       bytes_to_read,
188                                       &read_platform_handles_);
189  if (read_platform_handles_.size() > old_num_platform_handles) {
190    DCHECK_LE(read_platform_handles_.size() - old_num_platform_handles,
191              embedder::kPlatformChannelMaxNumHandles);
192
193    // We should never accumulate more than |TransportData::kMaxPlatformHandles
194    // + embedder::kPlatformChannelMaxNumHandles| handles. (The latter part is
195    // possible because we could have accumulated all the handles for a message,
196    // then received the message data plus the first set of handles for the next
197    // message in the subsequent |recvmsg()|.)
198    if (read_platform_handles_.size() > (TransportData::kMaxPlatformHandles +
199            embedder::kPlatformChannelMaxNumHandles)) {
200      LOG(WARNING) << "Received too many platform handles";
201      embedder::CloseAllPlatformHandles(&read_platform_handles_);
202      read_platform_handles_.clear();
203      return IO_FAILED;
204    }
205  }
206
207  if (read_result > 0) {
208    *bytes_read = static_cast<size_t>(read_result);
209    return IO_SUCCEEDED;
210  }
211
212  // |read_result == 0| means "end of file".
213  if (read_result == 0 || (errno != EAGAIN && errno != EWOULDBLOCK)) {
214    PLOG_IF(WARNING, read_result != 0) << "recvmsg";
215
216    // Make sure that |OnFileCanReadWithoutBlocking()| won't be called again.
217    read_watcher_.reset();
218
219    return IO_FAILED;
220  }
221
222  return ScheduleRead();
223}
224
225RawChannel::IOResult RawChannelPosix::ScheduleRead() {
226  DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io());
227  DCHECK(!pending_read_);
228
229  pending_read_ = true;
230
231  return IO_PENDING;
232}
233
234embedder::ScopedPlatformHandleVectorPtr RawChannelPosix::GetReadPlatformHandles(
235    size_t num_platform_handles,
236    const void* /*platform_handle_table*/) {
237  DCHECK_GT(num_platform_handles, 0u);
238
239  if (read_platform_handles_.size() < num_platform_handles) {
240    embedder::CloseAllPlatformHandles(&read_platform_handles_);
241    read_platform_handles_.clear();
242    return embedder::ScopedPlatformHandleVectorPtr();
243  }
244
245  embedder::ScopedPlatformHandleVectorPtr rv(
246      new embedder::PlatformHandleVector(num_platform_handles));
247  rv->assign(read_platform_handles_.begin(),
248             read_platform_handles_.begin() + num_platform_handles);
249  read_platform_handles_.erase(
250      read_platform_handles_.begin(),
251      read_platform_handles_.begin() + num_platform_handles);
252  return rv.Pass();
253}
254
255RawChannel::IOResult RawChannelPosix::WriteNoLock(
256    size_t* platform_handles_written,
257    size_t* bytes_written) {
258  write_lock().AssertAcquired();
259
260  DCHECK(!pending_write_);
261
262  size_t num_platform_handles = 0;
263  ssize_t write_result;
264  if (write_buffer_no_lock()->HavePlatformHandlesToSend()) {
265    embedder::PlatformHandle* platform_handles;
266    void* serialization_data;  // Actually unused.
267    write_buffer_no_lock()->GetPlatformHandlesToSend(&num_platform_handles,
268                                                     &platform_handles,
269                                                     &serialization_data);
270    DCHECK_GT(num_platform_handles, 0u);
271    DCHECK_LE(num_platform_handles, embedder::kPlatformChannelMaxNumHandles);
272    DCHECK(platform_handles);
273
274    // TODO(vtl): Reduce code duplication. (This is duplicated from below.)
275    std::vector<WriteBuffer::Buffer> buffers;
276    write_buffer_no_lock()->GetBuffers(&buffers);
277    DCHECK(!buffers.empty());
278    const size_t kMaxBufferCount = 10;
279    iovec iov[kMaxBufferCount];
280    size_t buffer_count = std::min(buffers.size(), kMaxBufferCount);
281    for (size_t i = 0; i < buffer_count; ++i) {
282      iov[i].iov_base = const_cast<char*>(buffers[i].addr);
283      iov[i].iov_len = buffers[i].size;
284    }
285
286    write_result = embedder::PlatformChannelSendmsgWithHandles(
287        fd_.get(), iov, buffer_count, platform_handles, num_platform_handles);
288    for (size_t i = 0; i < num_platform_handles; i++)
289      platform_handles[i].CloseIfNecessary();
290  } else {
291    std::vector<WriteBuffer::Buffer> buffers;
292    write_buffer_no_lock()->GetBuffers(&buffers);
293    DCHECK(!buffers.empty());
294
295    if (buffers.size() == 1) {
296      write_result = embedder::PlatformChannelWrite(fd_.get(), buffers[0].addr,
297                                                    buffers[0].size);
298    } else {
299      const size_t kMaxBufferCount = 10;
300      iovec iov[kMaxBufferCount];
301      size_t buffer_count = std::min(buffers.size(), kMaxBufferCount);
302      for (size_t i = 0; i < buffer_count; ++i) {
303        iov[i].iov_base = const_cast<char*>(buffers[i].addr);
304        iov[i].iov_len = buffers[i].size;
305      }
306
307      write_result = embedder::PlatformChannelWritev(fd_.get(), iov,
308                                                     buffer_count);
309    }
310  }
311
312  if (write_result >= 0) {
313    *platform_handles_written = num_platform_handles;
314    *bytes_written = static_cast<size_t>(write_result);
315    return IO_SUCCEEDED;
316  }
317
318  if (errno != EAGAIN && errno != EWOULDBLOCK) {
319    PLOG(ERROR) << "sendmsg/write/writev";
320    return IO_FAILED;
321  }
322
323  return ScheduleWriteNoLock();
324}
325
326RawChannel::IOResult RawChannelPosix::ScheduleWriteNoLock() {
327  write_lock().AssertAcquired();
328
329  DCHECK(!pending_write_);
330
331  // Set up to wait for the FD to become writable.
332  // If we're not on the I/O thread, we have to post a task to do this.
333  if (base::MessageLoop::current() != message_loop_for_io()) {
334    message_loop_for_io()->PostTask(
335        FROM_HERE,
336        base::Bind(&RawChannelPosix::WaitToWrite,
337                   weak_ptr_factory_.GetWeakPtr()));
338    pending_write_ = true;
339    return IO_PENDING;
340  }
341
342  if (message_loop_for_io()->WatchFileDescriptor(fd_.get().fd, false,
343          base::MessageLoopForIO::WATCH_WRITE, write_watcher_.get(), this)) {
344    pending_write_ = true;
345    return IO_PENDING;
346  }
347
348  return IO_FAILED;
349}
350
351bool RawChannelPosix::OnInit() {
352  DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io());
353
354  DCHECK(!read_watcher_);
355  read_watcher_.reset(new base::MessageLoopForIO::FileDescriptorWatcher());
356  DCHECK(!write_watcher_);
357  write_watcher_.reset(new base::MessageLoopForIO::FileDescriptorWatcher());
358
359  if (!message_loop_for_io()->WatchFileDescriptor(fd_.get().fd, true,
360          base::MessageLoopForIO::WATCH_READ, read_watcher_.get(), this)) {
361    // TODO(vtl): I'm not sure |WatchFileDescriptor()| actually fails cleanly
362    // (in the sense of returning the message loop's state to what it was before
363    // it was called).
364    read_watcher_.reset();
365    write_watcher_.reset();
366    return false;
367  }
368
369  return true;
370}
371
// Tears down all I/O-thread state. It is called on the I/O thread with
// |write_lock()| held. The read and write buffers are not needed on POSIX. They
// are taken by value, so they are destroyed when this returns. The statement
// order matters: the watchers are stopped before |fd_| is released, and the
// weak pointers are invalidated last.
void RawChannelPosix::OnShutdownNoLock(
    scoped_ptr<ReadBuffer> /*read_buffer*/,
    scoped_ptr<WriteBuffer> /*write_buffer*/) {
  DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io());
  write_lock().AssertAcquired();

  // Stop watching before |fd_| is released below.
  read_watcher_.reset();  // This will stop watching (if necessary).
  write_watcher_.reset();  // This will stop watching (if necessary).

  // Any outstanding read/write is abandoned.
  pending_read_ = false;
  pending_write_ = false;

  DCHECK(fd_.is_valid());
  fd_.reset();  // Releases the FD (presumably closing it; see ScopedPlatformHandle).

  // Drops any |WaitToWrite()| task posted by |ScheduleWriteNoLock()| that has
  // not run yet, since it was bound to a weak pointer.
  weak_ptr_factory_.InvalidateWeakPtrs();
}
389
390void RawChannelPosix::OnFileCanReadWithoutBlocking(int fd) {
391  DCHECK_EQ(fd, fd_.get().fd);
392  DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io());
393
394  if (!pending_read_) {
395    NOTREACHED();
396    return;
397  }
398
399  pending_read_ = false;
400  size_t bytes_read = 0;
401  IOResult result = Read(&bytes_read);
402  if (result != IO_PENDING)
403    OnReadCompleted(result == IO_SUCCEEDED, bytes_read);
404
405  // On failure, |read_watcher_| must have been reset; on success,
406  // we assume that |OnReadCompleted()| always schedules another read.
407  // Otherwise, we could end up spinning -- getting
408  // |OnFileCanReadWithoutBlocking()| again and again but not doing any actual
409  // read.
410  // TODO(yzshen): An alternative is to stop watching if RawChannel doesn't
411  // schedule a new read. But that code won't be reached under the current
412  // RawChannel implementation.
413  DCHECK(!read_watcher_ || pending_read_);
414}
415
416void RawChannelPosix::OnFileCanWriteWithoutBlocking(int fd) {
417  DCHECK_EQ(fd, fd_.get().fd);
418  DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io());
419
420  IOResult result = IO_FAILED;
421  size_t platform_handles_written = 0;
422  size_t bytes_written = 0;
423  {
424    base::AutoLock locker(write_lock());
425
426    DCHECK(pending_write_);
427
428    pending_write_ = false;
429    result = WriteNoLock(&platform_handles_written, &bytes_written);
430  }
431
432  if (result != IO_PENDING) {
433    OnWriteCompleted(result == IO_SUCCEEDED,
434                     platform_handles_written,
435                     bytes_written);
436  }
437}
438
439void RawChannelPosix::WaitToWrite() {
440  DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io());
441
442  DCHECK(write_watcher_);
443
444  if (!message_loop_for_io()->WatchFileDescriptor(
445          fd_.get().fd, false, base::MessageLoopForIO::WATCH_WRITE,
446          write_watcher_.get(), this)) {
447    {
448      base::AutoLock locker(write_lock());
449
450      DCHECK(pending_write_);
451      pending_write_ = false;
452    }
453    OnWriteCompleted(false, 0, 0);
454  }
455}
456
457}  // namespace
458
459// -----------------------------------------------------------------------------
460
461// Static factory method declared in raw_channel.h.
462// static
463scoped_ptr<RawChannel> RawChannel::Create(
464    embedder::ScopedPlatformHandle handle) {
465  return scoped_ptr<RawChannel>(new RawChannelPosix(handle.Pass()));
466}
467
468}  // namespace system
469}  // namespace mojo
470