// raw_channel.cc revision 1320f92c476a1ad9d19dba2a48c72b75566198e9
// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#include "mojo/system/raw_channel.h"
6
7#include <string.h>
8
9#include <algorithm>
10
11#include "base/bind.h"
12#include "base/location.h"
13#include "base/logging.h"
14#include "base/message_loop/message_loop.h"
15#include "base/stl_util.h"
16#include "mojo/system/message_in_transit.h"
17#include "mojo/system/transport_data.h"
18
19namespace mojo {
20namespace system {
21
// Number of bytes offered to each read; also the initial size of
// |RawChannel::ReadBuffer::buffer_|.
const size_t kReadSize = 4 * 1024;
23
24// RawChannel::ReadBuffer ------------------------------------------------------
25
26RawChannel::ReadBuffer::ReadBuffer() : buffer_(kReadSize), num_valid_bytes_(0) {
27}
28
29RawChannel::ReadBuffer::~ReadBuffer() {
30}
31
32void RawChannel::ReadBuffer::GetBuffer(char** addr, size_t* size) {
33  DCHECK_GE(buffer_.size(), num_valid_bytes_ + kReadSize);
34  *addr = &buffer_[0] + num_valid_bytes_;
35  *size = kReadSize;
36}
37
38// RawChannel::WriteBuffer -----------------------------------------------------
39
40RawChannel::WriteBuffer::WriteBuffer(size_t serialized_platform_handle_size)
41    : serialized_platform_handle_size_(serialized_platform_handle_size),
42      platform_handles_offset_(0),
43      data_offset_(0) {
44}
45
46RawChannel::WriteBuffer::~WriteBuffer() {
47  STLDeleteElements(&message_queue_);
48}
49
50bool RawChannel::WriteBuffer::HavePlatformHandlesToSend() const {
51  if (message_queue_.empty())
52    return false;
53
54  const TransportData* transport_data =
55      message_queue_.front()->transport_data();
56  if (!transport_data)
57    return false;
58
59  const embedder::PlatformHandleVector* all_platform_handles =
60      transport_data->platform_handles();
61  if (!all_platform_handles) {
62    DCHECK_EQ(platform_handles_offset_, 0u);
63    return false;
64  }
65  if (platform_handles_offset_ >= all_platform_handles->size()) {
66    DCHECK_EQ(platform_handles_offset_, all_platform_handles->size());
67    return false;
68  }
69
70  return true;
71}
72
73void RawChannel::WriteBuffer::GetPlatformHandlesToSend(
74    size_t* num_platform_handles,
75    embedder::PlatformHandle** platform_handles,
76    void** serialization_data) {
77  DCHECK(HavePlatformHandlesToSend());
78
79  TransportData* transport_data = message_queue_.front()->transport_data();
80  embedder::PlatformHandleVector* all_platform_handles =
81      transport_data->platform_handles();
82  *num_platform_handles =
83      all_platform_handles->size() - platform_handles_offset_;
84  *platform_handles = &(*all_platform_handles)[platform_handles_offset_];
85  size_t serialization_data_offset =
86      transport_data->platform_handle_table_offset();
87  DCHECK_GT(serialization_data_offset, 0u);
88  serialization_data_offset +=
89      platform_handles_offset_ * serialized_platform_handle_size_;
90  *serialization_data =
91      static_cast<char*>(transport_data->buffer()) + serialization_data_offset;
92}
93
94void RawChannel::WriteBuffer::GetBuffers(std::vector<Buffer>* buffers) const {
95  buffers->clear();
96
97  if (message_queue_.empty())
98    return;
99
100  MessageInTransit* message = message_queue_.front();
101  DCHECK_LT(data_offset_, message->total_size());
102  size_t bytes_to_write = message->total_size() - data_offset_;
103
104  size_t transport_data_buffer_size =
105      message->transport_data() ? message->transport_data()->buffer_size() : 0;
106
107  if (!transport_data_buffer_size) {
108    // Only write from the main buffer.
109    DCHECK_LT(data_offset_, message->main_buffer_size());
110    DCHECK_LE(bytes_to_write, message->main_buffer_size());
111    Buffer buffer = {
112        static_cast<const char*>(message->main_buffer()) + data_offset_,
113        bytes_to_write};
114    buffers->push_back(buffer);
115    return;
116  }
117
118  if (data_offset_ >= message->main_buffer_size()) {
119    // Only write from the transport data buffer.
120    DCHECK_LT(data_offset_ - message->main_buffer_size(),
121              transport_data_buffer_size);
122    DCHECK_LE(bytes_to_write, transport_data_buffer_size);
123    Buffer buffer = {
124        static_cast<const char*>(message->transport_data()->buffer()) +
125            (data_offset_ - message->main_buffer_size()),
126        bytes_to_write};
127    buffers->push_back(buffer);
128    return;
129  }
130
131  // TODO(vtl): We could actually send out buffers from multiple messages, with
132  // the "stopping" condition being reaching a message with platform handles
133  // attached.
134
135  // Write from both buffers.
136  DCHECK_EQ(
137      bytes_to_write,
138      message->main_buffer_size() - data_offset_ + transport_data_buffer_size);
139  Buffer buffer1 = {
140      static_cast<const char*>(message->main_buffer()) + data_offset_,
141      message->main_buffer_size() - data_offset_};
142  buffers->push_back(buffer1);
143  Buffer buffer2 = {
144      static_cast<const char*>(message->transport_data()->buffer()),
145      transport_data_buffer_size};
146  buffers->push_back(buffer2);
147}
148
149// RawChannel ------------------------------------------------------------------
150
151RawChannel::RawChannel()
152    : message_loop_for_io_(nullptr),
153      delegate_(nullptr),
154      read_stopped_(false),
155      write_stopped_(false),
156      weak_ptr_factory_(this) {
157}
158
159RawChannel::~RawChannel() {
160  DCHECK(!read_buffer_);
161  DCHECK(!write_buffer_);
162
163  // No need to take the |write_lock_| here -- if there are still weak pointers
164  // outstanding, then we're hosed anyway (since we wouldn't be able to
165  // invalidate them cleanly, since we might not be on the I/O thread).
166  DCHECK(!weak_ptr_factory_.HasWeakPtrs());
167}
168
169bool RawChannel::Init(Delegate* delegate) {
170  DCHECK(delegate);
171
172  DCHECK(!delegate_);
173  delegate_ = delegate;
174
175  CHECK_EQ(base::MessageLoop::current()->type(), base::MessageLoop::TYPE_IO);
176  DCHECK(!message_loop_for_io_);
177  message_loop_for_io_ =
178      static_cast<base::MessageLoopForIO*>(base::MessageLoop::current());
179
180  // No need to take the lock. No one should be using us yet.
181  DCHECK(!read_buffer_);
182  read_buffer_.reset(new ReadBuffer);
183  DCHECK(!write_buffer_);
184  write_buffer_.reset(new WriteBuffer(GetSerializedPlatformHandleSize()));
185
186  if (!OnInit()) {
187    delegate_ = nullptr;
188    message_loop_for_io_ = nullptr;
189    read_buffer_.reset();
190    write_buffer_.reset();
191    return false;
192  }
193
194  IOResult io_result = ScheduleRead();
195  if (io_result != IO_PENDING) {
196    // This will notify the delegate about the read failure. Although we're on
197    // the I/O thread, don't call it in the nested context.
198    message_loop_for_io_->PostTask(FROM_HERE,
199                                   base::Bind(&RawChannel::OnReadCompleted,
200                                              weak_ptr_factory_.GetWeakPtr(),
201                                              io_result,
202                                              0));
203  }
204
205  // ScheduleRead() failure is treated as a read failure (by notifying the
206  // delegate), not as an init failure.
207  return true;
208}
209
210void RawChannel::Shutdown() {
211  DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_);
212
213  base::AutoLock locker(write_lock_);
214
215  LOG_IF(WARNING, !write_buffer_->message_queue_.empty())
216      << "Shutting down RawChannel with write buffer nonempty";
217
218  // Reset the delegate so that it won't receive further calls.
219  delegate_ = nullptr;
220  read_stopped_ = true;
221  write_stopped_ = true;
222  weak_ptr_factory_.InvalidateWeakPtrs();
223
224  OnShutdownNoLock(read_buffer_.Pass(), write_buffer_.Pass());
225}
226
227// Reminder: This must be thread-safe.
228bool RawChannel::WriteMessage(scoped_ptr<MessageInTransit> message) {
229  DCHECK(message);
230
231  base::AutoLock locker(write_lock_);
232  if (write_stopped_)
233    return false;
234
235  if (!write_buffer_->message_queue_.empty()) {
236    EnqueueMessageNoLock(message.Pass());
237    return true;
238  }
239
240  EnqueueMessageNoLock(message.Pass());
241  DCHECK_EQ(write_buffer_->data_offset_, 0u);
242
243  size_t platform_handles_written = 0;
244  size_t bytes_written = 0;
245  IOResult io_result = WriteNoLock(&platform_handles_written, &bytes_written);
246  if (io_result == IO_PENDING)
247    return true;
248
249  bool result = OnWriteCompletedNoLock(
250      io_result, platform_handles_written, bytes_written);
251  if (!result) {
252    // Even if we're on the I/O thread, don't call |OnError()| in the nested
253    // context.
254    message_loop_for_io_->PostTask(FROM_HERE,
255                                   base::Bind(&RawChannel::CallOnError,
256                                              weak_ptr_factory_.GetWeakPtr(),
257                                              Delegate::ERROR_WRITE));
258  }
259
260  return result;
261}
262
263// Reminder: This must be thread-safe.
264bool RawChannel::IsWriteBufferEmpty() {
265  base::AutoLock locker(write_lock_);
266  return write_buffer_->message_queue_.empty();
267}
268
269void RawChannel::OnReadCompleted(IOResult io_result, size_t bytes_read) {
270  DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_);
271
272  if (read_stopped_) {
273    NOTREACHED();
274    return;
275  }
276
277  // Keep reading data in a loop, and dispatch messages if enough data is
278  // received. Exit the loop if any of the following happens:
279  //   - one or more messages were dispatched;
280  //   - the last read failed, was a partial read or would block;
281  //   - |Shutdown()| was called.
282  do {
283    switch (io_result) {
284      case IO_SUCCEEDED:
285        break;
286      case IO_FAILED_SHUTDOWN:
287      case IO_FAILED_BROKEN:
288      case IO_FAILED_UNKNOWN:
289        read_stopped_ = true;
290        CallOnError(ReadIOResultToError(io_result));
291        return;
292      case IO_PENDING:
293        NOTREACHED();
294        return;
295    }
296
297    read_buffer_->num_valid_bytes_ += bytes_read;
298
299    // Dispatch all the messages that we can.
300    bool did_dispatch_message = false;
301    // Tracks the offset of the first undispatched message in |read_buffer_|.
302    // Currently, we copy data to ensure that this is zero at the beginning.
303    size_t read_buffer_start = 0;
304    size_t remaining_bytes = read_buffer_->num_valid_bytes_;
305    size_t message_size;
306    // Note that we rely on short-circuit evaluation here:
307    //   - |read_buffer_start| may be an invalid index into
308    //     |read_buffer_->buffer_| if |remaining_bytes| is zero.
309    //   - |message_size| is only valid if |GetNextMessageSize()| returns true.
310    // TODO(vtl): Use |message_size| more intelligently (e.g., to request the
311    // next read).
312    // TODO(vtl): Validate that |message_size| is sane.
313    while (remaining_bytes > 0 && MessageInTransit::GetNextMessageSize(
314                                      &read_buffer_->buffer_[read_buffer_start],
315                                      remaining_bytes,
316                                      &message_size) &&
317           remaining_bytes >= message_size) {
318      MessageInTransit::View message_view(
319          message_size, &read_buffer_->buffer_[read_buffer_start]);
320      DCHECK_EQ(message_view.total_size(), message_size);
321
322      const char* error_message = nullptr;
323      if (!message_view.IsValid(GetSerializedPlatformHandleSize(),
324                                &error_message)) {
325        DCHECK(error_message);
326        LOG(ERROR) << "Received invalid message: " << error_message;
327        read_stopped_ = true;
328        CallOnError(Delegate::ERROR_READ_BAD_MESSAGE);
329        return;
330      }
331
332      if (message_view.type() == MessageInTransit::kTypeRawChannel) {
333        if (!OnReadMessageForRawChannel(message_view)) {
334          read_stopped_ = true;
335          CallOnError(Delegate::ERROR_READ_BAD_MESSAGE);
336          return;
337        }
338      } else {
339        embedder::ScopedPlatformHandleVectorPtr platform_handles;
340        if (message_view.transport_data_buffer()) {
341          size_t num_platform_handles;
342          const void* platform_handle_table;
343          TransportData::GetPlatformHandleTable(
344              message_view.transport_data_buffer(),
345              &num_platform_handles,
346              &platform_handle_table);
347
348          if (num_platform_handles > 0) {
349            platform_handles =
350                GetReadPlatformHandles(num_platform_handles,
351                                       platform_handle_table).Pass();
352            if (!platform_handles) {
353              LOG(ERROR) << "Invalid number of platform handles received";
354              read_stopped_ = true;
355              CallOnError(Delegate::ERROR_READ_BAD_MESSAGE);
356              return;
357            }
358          }
359        }
360
361        // TODO(vtl): In the case that we aren't expecting any platform handles,
362        // for the POSIX implementation, we should confirm that none are stored.
363
364        // Dispatch the message.
365        DCHECK(delegate_);
366        delegate_->OnReadMessage(message_view, platform_handles.Pass());
367        if (read_stopped_) {
368          // |Shutdown()| was called in |OnReadMessage()|.
369          // TODO(vtl): Add test for this case.
370          return;
371        }
372      }
373
374      did_dispatch_message = true;
375
376      // Update our state.
377      read_buffer_start += message_size;
378      remaining_bytes -= message_size;
379    }
380
381    if (read_buffer_start > 0) {
382      // Move data back to start.
383      read_buffer_->num_valid_bytes_ = remaining_bytes;
384      if (read_buffer_->num_valid_bytes_ > 0) {
385        memmove(&read_buffer_->buffer_[0],
386                &read_buffer_->buffer_[read_buffer_start],
387                remaining_bytes);
388      }
389      read_buffer_start = 0;
390    }
391
392    if (read_buffer_->buffer_.size() - read_buffer_->num_valid_bytes_ <
393        kReadSize) {
394      // Use power-of-2 buffer sizes.
395      // TODO(vtl): Make sure the buffer doesn't get too large (and enforce the
396      // maximum message size to whatever extent necessary).
397      // TODO(vtl): We may often be able to peek at the header and get the real
398      // required extra space (which may be much bigger than |kReadSize|).
399      size_t new_size = std::max(read_buffer_->buffer_.size(), kReadSize);
400      while (new_size < read_buffer_->num_valid_bytes_ + kReadSize)
401        new_size *= 2;
402
403      // TODO(vtl): It's suboptimal to zero out the fresh memory.
404      read_buffer_->buffer_.resize(new_size, 0);
405    }
406
407    // (1) If we dispatched any messages, stop reading for now (and let the
408    // message loop do its thing for another round).
409    // TODO(vtl): Is this the behavior we want? (Alternatives: i. Dispatch only
410    // a single message. Risks: slower, more complex if we want to avoid lots of
411    // copying. ii. Keep reading until there's no more data and dispatch all the
412    // messages we can. Risks: starvation of other users of the message loop.)
413    // (2) If we didn't max out |kReadSize|, stop reading for now.
414    bool schedule_for_later = did_dispatch_message || bytes_read < kReadSize;
415    bytes_read = 0;
416    io_result = schedule_for_later ? ScheduleRead() : Read(&bytes_read);
417  } while (io_result != IO_PENDING);
418}
419
420void RawChannel::OnWriteCompleted(IOResult io_result,
421                                  size_t platform_handles_written,
422                                  size_t bytes_written) {
423  DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_);
424  DCHECK_NE(io_result, IO_PENDING);
425
426  bool did_fail = false;
427  {
428    base::AutoLock locker(write_lock_);
429    DCHECK_EQ(write_stopped_, write_buffer_->message_queue_.empty());
430
431    if (write_stopped_) {
432      NOTREACHED();
433      return;
434    }
435
436    did_fail = !OnWriteCompletedNoLock(
437                   io_result, platform_handles_written, bytes_written);
438  }
439
440  if (did_fail)
441    CallOnError(Delegate::ERROR_WRITE);
442}
443
444void RawChannel::EnqueueMessageNoLock(scoped_ptr<MessageInTransit> message) {
445  write_lock_.AssertAcquired();
446  write_buffer_->message_queue_.push_back(message.release());
447}
448
449bool RawChannel::OnReadMessageForRawChannel(
450    const MessageInTransit::View& message_view) {
451  // No non-implementation specific |RawChannel| control messages.
452  LOG(ERROR) << "Invalid control message (subtype " << message_view.subtype()
453             << ")";
454  return false;
455}
456
457// static
458RawChannel::Delegate::Error RawChannel::ReadIOResultToError(
459    IOResult io_result) {
460  switch (io_result) {
461    case IO_FAILED_SHUTDOWN:
462      return Delegate::ERROR_READ_SHUTDOWN;
463    case IO_FAILED_BROKEN:
464      return Delegate::ERROR_READ_BROKEN;
465    case IO_FAILED_UNKNOWN:
466      return Delegate::ERROR_READ_UNKNOWN;
467    case IO_SUCCEEDED:
468    case IO_PENDING:
469      NOTREACHED();
470      break;
471  }
472  return Delegate::ERROR_READ_UNKNOWN;
473}
474
475void RawChannel::CallOnError(Delegate::Error error) {
476  DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_);
477  // TODO(vtl): Add a "write_lock_.AssertNotAcquired()"?
478  if (delegate_)
479    delegate_->OnError(error);
480}
481
482bool RawChannel::OnWriteCompletedNoLock(IOResult io_result,
483                                        size_t platform_handles_written,
484                                        size_t bytes_written) {
485  write_lock_.AssertAcquired();
486
487  DCHECK(!write_stopped_);
488  DCHECK(!write_buffer_->message_queue_.empty());
489
490  if (io_result == IO_SUCCEEDED) {
491    write_buffer_->platform_handles_offset_ += platform_handles_written;
492    write_buffer_->data_offset_ += bytes_written;
493
494    MessageInTransit* message = write_buffer_->message_queue_.front();
495    if (write_buffer_->data_offset_ >= message->total_size()) {
496      // Complete write.
497      CHECK_EQ(write_buffer_->data_offset_, message->total_size());
498      write_buffer_->message_queue_.pop_front();
499      delete message;
500      write_buffer_->platform_handles_offset_ = 0;
501      write_buffer_->data_offset_ = 0;
502
503      if (write_buffer_->message_queue_.empty())
504        return true;
505    }
506
507    // Schedule the next write.
508    io_result = ScheduleWriteNoLock();
509    if (io_result == IO_PENDING)
510      return true;
511    DCHECK_NE(io_result, IO_SUCCEEDED);
512  }
513
514  write_stopped_ = true;
515  STLDeleteElements(&write_buffer_->message_queue_);
516  write_buffer_->platform_handles_offset_ = 0;
517  write_buffer_->data_offset_ = 0;
518  return false;
519}
520
521}  // namespace system
522}  // namespace mojo
523