1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "mojo/system/raw_channel.h"
6
7#include <stdint.h>
8
9#include <vector>
10
11#include "base/bind.h"
12#include "base/location.h"
13#include "base/logging.h"
14#include "base/macros.h"
15#include "base/memory/scoped_ptr.h"
16#include "base/memory/scoped_vector.h"
17#include "base/rand_util.h"
18#include "base/synchronization/lock.h"
19#include "base/synchronization/waitable_event.h"
20#include "base/test/test_io_thread.h"
21#include "base/threading/platform_thread.h"  // For |Sleep()|.
22#include "base/threading/simple_thread.h"
23#include "base/time/time.h"
24#include "build/build_config.h"
25#include "mojo/common/test/test_utils.h"
26#include "mojo/embedder/platform_channel_pair.h"
27#include "mojo/embedder/platform_handle.h"
28#include "mojo/embedder/scoped_platform_handle.h"
29#include "mojo/system/message_in_transit.h"
30#include "mojo/system/test_utils.h"
31#include "testing/gtest/include/gtest/gtest.h"
32
33namespace mojo {
34namespace system {
35namespace {
36
37scoped_ptr<MessageInTransit> MakeTestMessage(uint32_t num_bytes) {
38  std::vector<unsigned char> bytes(num_bytes, 0);
39  for (size_t i = 0; i < num_bytes; i++)
40    bytes[i] = static_cast<unsigned char>(i + num_bytes);
41  return make_scoped_ptr(
42      new MessageInTransit(MessageInTransit::kTypeMessagePipeEndpoint,
43                           MessageInTransit::kSubtypeMessagePipeEndpointData,
44                           num_bytes,
45                           bytes.empty() ? nullptr : &bytes[0]));
46}
47
// Returns true iff the |num_bytes| bytes at |bytes| follow the payload pattern
// produced by |MakeTestMessage()|: byte |i| equals |i + num_bytes| truncated
// to an unsigned char. An empty range (|num_bytes| == 0) is always valid and
// |bytes| is not dereferenced in that case.
bool CheckMessageData(const void* bytes, uint32_t num_bytes) {
  const unsigned char* const data = static_cast<const unsigned char*>(bytes);
  for (uint32_t offset = 0; offset != num_bytes; ++offset) {
    const unsigned char expected =
        static_cast<unsigned char>(offset + num_bytes);
    if (data[offset] != expected)
      return false;
  }
  return true;
}
56
57void InitOnIOThread(RawChannel* raw_channel, RawChannel::Delegate* delegate) {
58  CHECK(raw_channel->Init(delegate));
59}
60
61bool WriteTestMessageToHandle(const embedder::PlatformHandle& handle,
62                              uint32_t num_bytes) {
63  scoped_ptr<MessageInTransit> message(MakeTestMessage(num_bytes));
64
65  size_t write_size = 0;
66  mojo::test::BlockingWrite(
67      handle, message->main_buffer(), message->main_buffer_size(), &write_size);
68  return write_size == message->main_buffer_size();
69}
70
71// -----------------------------------------------------------------------------
72
73class RawChannelTest : public testing::Test {
74 public:
75  RawChannelTest() : io_thread_(base::TestIOThread::kManualStart) {}
76  virtual ~RawChannelTest() {}
77
78  virtual void SetUp() OVERRIDE {
79    embedder::PlatformChannelPair channel_pair;
80    handles[0] = channel_pair.PassServerHandle();
81    handles[1] = channel_pair.PassClientHandle();
82    io_thread_.Start();
83  }
84
85  virtual void TearDown() OVERRIDE {
86    io_thread_.Stop();
87    handles[0].reset();
88    handles[1].reset();
89  }
90
91 protected:
92  base::TestIOThread* io_thread() { return &io_thread_; }
93
94  embedder::ScopedPlatformHandle handles[2];
95
96 private:
97  base::TestIOThread io_thread_;
98
99  DISALLOW_COPY_AND_ASSIGN(RawChannelTest);
100};
101
102// RawChannelTest.WriteMessage -------------------------------------------------
103
104class WriteOnlyRawChannelDelegate : public RawChannel::Delegate {
105 public:
106  WriteOnlyRawChannelDelegate() {}
107  virtual ~WriteOnlyRawChannelDelegate() {}
108
109  // |RawChannel::Delegate| implementation:
110  virtual void OnReadMessage(
111      const MessageInTransit::View& /*message_view*/,
112      embedder::ScopedPlatformHandleVectorPtr /*platform_handles*/) OVERRIDE {
113    CHECK(false);  // Should not get called.
114  }
115  virtual void OnError(Error error) OVERRIDE {
116    // We'll get a read (shutdown) error when the connection is closed.
117    CHECK_EQ(error, ERROR_READ_SHUTDOWN);
118  }
119
120 private:
121  DISALLOW_COPY_AND_ASSIGN(WriteOnlyRawChannelDelegate);
122};
123
// Polling parameters for |TestMessageReaderAndChecker|:
//   - kMessageReaderSleepMs: how long to sleep after a non-blocking read that
//     did not fill the read buffer.
//   - kMessageReaderMaxPollIterations: how many such short reads to tolerate
//     before giving up on a message.
// (These already have internal linkage inside the anonymous namespace, so
// |static| is unnecessary.)
const int64_t kMessageReaderSleepMs = 1;
const size_t kMessageReaderMaxPollIterations = 3000;
126
127class TestMessageReaderAndChecker {
128 public:
129  explicit TestMessageReaderAndChecker(embedder::PlatformHandle handle)
130      : handle_(handle) {}
131  ~TestMessageReaderAndChecker() { CHECK(bytes_.empty()); }
132
133  bool ReadAndCheckNextMessage(uint32_t expected_size) {
134    unsigned char buffer[4096];
135
136    for (size_t i = 0; i < kMessageReaderMaxPollIterations;) {
137      size_t read_size = 0;
138      CHECK(mojo::test::NonBlockingRead(
139          handle_, buffer, sizeof(buffer), &read_size));
140
141      // Append newly-read data to |bytes_|.
142      bytes_.insert(bytes_.end(), buffer, buffer + read_size);
143
144      // If we have the header....
145      size_t message_size;
146      if (MessageInTransit::GetNextMessageSize(
147              bytes_.empty() ? nullptr : &bytes_[0],
148              bytes_.size(),
149              &message_size)) {
150        // If we've read the whole message....
151        if (bytes_.size() >= message_size) {
152          bool rv = true;
153          MessageInTransit::View message_view(message_size, &bytes_[0]);
154          CHECK_EQ(message_view.main_buffer_size(), message_size);
155
156          if (message_view.num_bytes() != expected_size) {
157            LOG(ERROR) << "Wrong size: " << message_size << " instead of "
158                       << expected_size << " bytes.";
159            rv = false;
160          } else if (!CheckMessageData(message_view.bytes(),
161                                       message_view.num_bytes())) {
162            LOG(ERROR) << "Incorrect message bytes.";
163            rv = false;
164          }
165
166          // Erase message data.
167          bytes_.erase(bytes_.begin(),
168                       bytes_.begin() + message_view.main_buffer_size());
169          return rv;
170        }
171      }
172
173      if (static_cast<size_t>(read_size) < sizeof(buffer)) {
174        i++;
175        base::PlatformThread::Sleep(
176            base::TimeDelta::FromMilliseconds(kMessageReaderSleepMs));
177      }
178    }
179
180    LOG(ERROR) << "Too many iterations.";
181    return false;
182  }
183
184 private:
185  const embedder::PlatformHandle handle_;
186
187  // The start of the received data should always be on a message boundary.
188  std::vector<unsigned char> bytes_;
189
190  DISALLOW_COPY_AND_ASSIGN(TestMessageReaderAndChecker);
191};
192
// Tests writing (and verifies reading using our own custom reader).
// |handles[0]| is wrapped in a |RawChannel| that only writes. The peer handle
// |handles[1]| is read directly by a |TestMessageReaderAndChecker|.
TEST_F(RawChannelTest, WriteMessage) {
  WriteOnlyRawChannelDelegate delegate;
  scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
  TestMessageReaderAndChecker checker(handles[1].get());
  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));

  // Write and read, for a variety of sizes.
  // Sizes grow by roughly 1.5x per step, from 1 byte to just under 5,000,000
  // bytes. Each message is read back before the next one is written.
  for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) {
    EXPECT_TRUE(rc->WriteMessage(MakeTestMessage(size)));
    EXPECT_TRUE(checker.ReadAndCheckNextMessage(size)) << size;
  }

  // Write/queue and read afterwards, for a variety of sizes.
  for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
    EXPECT_TRUE(rc->WriteMessage(MakeTestMessage(size)));
  for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
    EXPECT_TRUE(checker.ReadAndCheckNextMessage(size)) << size;

  // Shut the channel down on the I/O thread before |rc| and |delegate| go
  // out of scope.
  io_thread()->PostTaskAndWait(
      FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get())));
}
217
218// RawChannelTest.OnReadMessage ------------------------------------------------
219
220class ReadCheckerRawChannelDelegate : public RawChannel::Delegate {
221 public:
222  ReadCheckerRawChannelDelegate() : done_event_(false, false), position_(0) {}
223  virtual ~ReadCheckerRawChannelDelegate() {}
224
225  // |RawChannel::Delegate| implementation (called on the I/O thread):
226  virtual void OnReadMessage(
227      const MessageInTransit::View& message_view,
228      embedder::ScopedPlatformHandleVectorPtr platform_handles) OVERRIDE {
229    EXPECT_FALSE(platform_handles);
230
231    size_t position;
232    size_t expected_size;
233    bool should_signal = false;
234    {
235      base::AutoLock locker(lock_);
236      CHECK_LT(position_, expected_sizes_.size());
237      position = position_;
238      expected_size = expected_sizes_[position];
239      position_++;
240      if (position_ >= expected_sizes_.size())
241        should_signal = true;
242    }
243
244    EXPECT_EQ(expected_size, message_view.num_bytes()) << position;
245    if (message_view.num_bytes() == expected_size) {
246      EXPECT_TRUE(
247          CheckMessageData(message_view.bytes(), message_view.num_bytes()))
248          << position;
249    }
250
251    if (should_signal)
252      done_event_.Signal();
253  }
254  virtual void OnError(Error error) OVERRIDE {
255    // We'll get a read (shutdown) error when the connection is closed.
256    CHECK_EQ(error, ERROR_READ_SHUTDOWN);
257  }
258
259  // Waits for all the messages (of sizes |expected_sizes_|) to be seen.
260  void Wait() { done_event_.Wait(); }
261
262  void SetExpectedSizes(const std::vector<uint32_t>& expected_sizes) {
263    base::AutoLock locker(lock_);
264    CHECK_EQ(position_, expected_sizes_.size());
265    expected_sizes_ = expected_sizes;
266    position_ = 0;
267  }
268
269 private:
270  base::WaitableEvent done_event_;
271
272  base::Lock lock_;  // Protects the following members.
273  std::vector<uint32_t> expected_sizes_;
274  size_t position_;
275
276  DISALLOW_COPY_AND_ASSIGN(ReadCheckerRawChannelDelegate);
277};
278
// Tests reading (writing using our own custom writer).
// Messages are written straight to the peer handle |handles[1]| with
// |WriteTestMessageToHandle()|. They must reach the delegate of the
// |RawChannel| wrapping |handles[0]| intact and in order.
TEST_F(RawChannelTest, OnReadMessage) {
  ReadCheckerRawChannelDelegate delegate;
  scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));

  // Write and read, for a variety of sizes.
  // Each iteration expects exactly one message and waits for it before the
  // next write.
  for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) {
    delegate.SetExpectedSizes(std::vector<uint32_t>(1, size));

    EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), size));

    delegate.Wait();
  }

  // Set up reader and write as fast as we can.
  // Write/queue and read afterwards, for a variety of sizes.
  std::vector<uint32_t> expected_sizes;
  for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
    expected_sizes.push_back(size);
  // Register the full sequence before any write, so that every arriving
  // message already has an expected size. The delegate CHECKs this.
  delegate.SetExpectedSizes(expected_sizes);
  for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
    EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), size));
  delegate.Wait();

  io_thread()->PostTaskAndWait(
      FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get())));
}
309
310// RawChannelTest.WriteMessageAndOnReadMessage ---------------------------------
311
312class RawChannelWriterThread : public base::SimpleThread {
313 public:
314  RawChannelWriterThread(RawChannel* raw_channel, size_t write_count)
315      : base::SimpleThread("raw_channel_writer_thread"),
316        raw_channel_(raw_channel),
317        left_to_write_(write_count) {}
318
319  virtual ~RawChannelWriterThread() { Join(); }
320
321 private:
322  virtual void Run() OVERRIDE {
323    static const int kMaxRandomMessageSize = 25000;
324
325    while (left_to_write_-- > 0) {
326      EXPECT_TRUE(raw_channel_->WriteMessage(MakeTestMessage(
327          static_cast<uint32_t>(base::RandInt(1, kMaxRandomMessageSize)))));
328    }
329  }
330
331  RawChannel* const raw_channel_;
332  size_t left_to_write_;
333
334  DISALLOW_COPY_AND_ASSIGN(RawChannelWriterThread);
335};
336
337class ReadCountdownRawChannelDelegate : public RawChannel::Delegate {
338 public:
339  explicit ReadCountdownRawChannelDelegate(size_t expected_count)
340      : done_event_(false, false), expected_count_(expected_count), count_(0) {}
341  virtual ~ReadCountdownRawChannelDelegate() {}
342
343  // |RawChannel::Delegate| implementation (called on the I/O thread):
344  virtual void OnReadMessage(
345      const MessageInTransit::View& message_view,
346      embedder::ScopedPlatformHandleVectorPtr platform_handles) OVERRIDE {
347    EXPECT_FALSE(platform_handles);
348
349    EXPECT_LT(count_, expected_count_);
350    count_++;
351
352    EXPECT_TRUE(
353        CheckMessageData(message_view.bytes(), message_view.num_bytes()));
354
355    if (count_ >= expected_count_)
356      done_event_.Signal();
357  }
358  virtual void OnError(Error error) OVERRIDE {
359    // We'll get a read (shutdown) error when the connection is closed.
360    CHECK_EQ(error, ERROR_READ_SHUTDOWN);
361  }
362
363  // Waits for all the messages to have been seen.
364  void Wait() { done_event_.Wait(); }
365
366 private:
367  base::WaitableEvent done_event_;
368  size_t expected_count_;
369  size_t count_;
370
371  DISALLOW_COPY_AND_ASSIGN(ReadCountdownRawChannelDelegate);
372};
373
// Stress test. Many threads write concurrently through one |RawChannel|, and
// a second |RawChannel| on the other end must receive exactly
// |kNumWriterThreads * kNumWriteMessagesPerThread| well-formed messages.
TEST_F(RawChannelTest, WriteMessageAndOnReadMessage) {
  static const size_t kNumWriterThreads = 10;
  static const size_t kNumWriteMessagesPerThread = 4000;

  WriteOnlyRawChannelDelegate writer_delegate;
  scoped_ptr<RawChannel> writer_rc(RawChannel::Create(handles[0].Pass()));
  io_thread()->PostTaskAndWait(FROM_HERE,
                               base::Bind(&InitOnIOThread,
                                          writer_rc.get(),
                                          base::Unretained(&writer_delegate)));

  ReadCountdownRawChannelDelegate reader_delegate(kNumWriterThreads *
                                                  kNumWriteMessagesPerThread);
  scoped_ptr<RawChannel> reader_rc(RawChannel::Create(handles[1].Pass()));
  io_thread()->PostTaskAndWait(FROM_HERE,
                               base::Bind(&InitOnIOThread,
                                          reader_rc.get(),
                                          base::Unretained(&reader_delegate)));

  {
    ScopedVector<RawChannelWriterThread> writer_threads;
    for (size_t i = 0; i < kNumWriterThreads; i++) {
      writer_threads.push_back(new RawChannelWriterThread(
          writer_rc.get(), kNumWriteMessagesPerThread));
    }
    for (size_t i = 0; i < writer_threads.size(); i++)
      writer_threads[i]->Start();
  }  // Joins all the writer threads.

  // Sleep a bit, to let any extraneous reads be processed. (There shouldn't be
  // any, but we want to know about them.)
  base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(100));

  // Wait for reading to finish.
  reader_delegate.Wait();

  // Shut down both channels on the I/O thread before the delegates are
  // destroyed.
  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&RawChannel::Shutdown, base::Unretained(reader_rc.get())));

  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&RawChannel::Shutdown, base::Unretained(writer_rc.get())));
}
418
419// RawChannelTest.OnError ------------------------------------------------------
420
421class ErrorRecordingRawChannelDelegate
422    : public ReadCountdownRawChannelDelegate {
423 public:
424  ErrorRecordingRawChannelDelegate(size_t expected_read_count,
425                                   bool expect_read_error,
426                                   bool expect_write_error)
427      : ReadCountdownRawChannelDelegate(expected_read_count),
428        got_read_error_event_(false, false),
429        got_write_error_event_(false, false),
430        expecting_read_error_(expect_read_error),
431        expecting_write_error_(expect_write_error) {}
432
433  virtual ~ErrorRecordingRawChannelDelegate() {}
434
435  virtual void OnError(Error error) OVERRIDE {
436    switch (error) {
437      case ERROR_READ_SHUTDOWN:
438        ASSERT_TRUE(expecting_read_error_);
439        expecting_read_error_ = false;
440        got_read_error_event_.Signal();
441        break;
442      case ERROR_READ_BROKEN:
443        // TODO(vtl): Test broken connections.
444        CHECK(false);
445        break;
446      case ERROR_READ_BAD_MESSAGE:
447        // TODO(vtl): Test reception/detection of bad messages.
448        CHECK(false);
449        break;
450      case ERROR_READ_UNKNOWN:
451        // TODO(vtl): Test however it is we might get here.
452        CHECK(false);
453        break;
454      case ERROR_WRITE:
455        ASSERT_TRUE(expecting_write_error_);
456        expecting_write_error_ = false;
457        got_write_error_event_.Signal();
458        break;
459    }
460  }
461
462  void WaitForReadError() { got_read_error_event_.Wait(); }
463  void WaitForWriteError() { got_write_error_event_.Wait(); }
464
465 private:
466  base::WaitableEvent got_read_error_event_;
467  base::WaitableEvent got_write_error_event_;
468
469  bool expecting_read_error_;
470  bool expecting_write_error_;
471
472  DISALLOW_COPY_AND_ASSIGN(ErrorRecordingRawChannelDelegate);
473};
474
// Tests (fatal) errors.
// Closing the peer handle must produce one write error, for the failed
// |WriteMessage()|, and one read (shutdown) error. Neither may be reported a
// second time, because the delegate asserts on repeats.
TEST_F(RawChannelTest, OnError) {
  ErrorRecordingRawChannelDelegate delegate(0, true, true);
  scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));

  // Close the handle of the other end, which should make writing fail.
  handles[1].reset();

  EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));

  // We should get a write error.
  delegate.WaitForWriteError();

  // We should also get a read error.
  delegate.WaitForReadError();

  // Writing after the errors must keep failing.
  EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(2)));

  // Sleep a bit, to make sure we don't get another |OnError()|
  // notification. (If we actually get another one, |OnError()| crashes.)
  base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(20));

  io_thread()->PostTaskAndWait(
      FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get())));
}
503
504// RawChannelTest.ReadUnaffectedByWriteError -----------------------------------
505
// Messages already buffered by the system before the peer closed must still
// be delivered, even though writing to the closed peer fails.
TEST_F(RawChannelTest, ReadUnaffectedByWriteError) {
  const size_t kMessageCount = 5;

  // Write a few messages into the other end.
  // Payload sizes grow as 1, 2, 4, 7, 11 (each step adds size / 2 + 1).
  uint32_t message_size = 1;
  for (size_t i = 0; i < kMessageCount;
       i++, message_size += message_size / 2 + 1)
    EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), message_size));

  // Close the other end, which should make writing fail.
  handles[1].reset();

  // Only start up reading here. The system buffer should still contain the
  // messages that were written.
  ErrorRecordingRawChannelDelegate delegate(kMessageCount, true, true);
  scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));

  EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));

  // We should definitely get a write error.
  delegate.WaitForWriteError();

  // Wait for reading to finish. A writing failure shouldn't affect reading.
  delegate.Wait();

  // And then we should get a read error.
  delegate.WaitForReadError();

  io_thread()->PostTaskAndWait(
      FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get())));
}
540
541// RawChannelTest.WriteMessageAfterShutdown ------------------------------------
542
// Makes sure that calling |WriteMessage()| after |Shutdown()| behaves
// correctly. The write must simply fail (return false) without crashing or
// calling into the (write-only) delegate.
TEST_F(RawChannelTest, WriteMessageAfterShutdown) {
  WriteOnlyRawChannelDelegate delegate;
  scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));
  io_thread()->PostTaskAndWait(
      FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get())));

  EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));
}
556
557// RawChannelTest.ShutdownOnReadMessage ----------------------------------------
558
559class ShutdownOnReadMessageRawChannelDelegate : public RawChannel::Delegate {
560 public:
561  explicit ShutdownOnReadMessageRawChannelDelegate(RawChannel* raw_channel)
562      : raw_channel_(raw_channel),
563        done_event_(false, false),
564        did_shutdown_(false) {}
565  virtual ~ShutdownOnReadMessageRawChannelDelegate() {}
566
567  // |RawChannel::Delegate| implementation (called on the I/O thread):
568  virtual void OnReadMessage(
569      const MessageInTransit::View& message_view,
570      embedder::ScopedPlatformHandleVectorPtr platform_handles) OVERRIDE {
571    EXPECT_FALSE(platform_handles);
572    EXPECT_FALSE(did_shutdown_);
573    EXPECT_TRUE(
574        CheckMessageData(message_view.bytes(), message_view.num_bytes()));
575    raw_channel_->Shutdown();
576    did_shutdown_ = true;
577    done_event_.Signal();
578  }
579  virtual void OnError(Error /*error*/) OVERRIDE {
580    CHECK(false);  // Should not get called.
581  }
582
583  // Waits for shutdown.
584  void Wait() {
585    done_event_.Wait();
586    EXPECT_TRUE(did_shutdown_);
587  }
588
589 private:
590  RawChannel* const raw_channel_;
591  base::WaitableEvent done_event_;
592  bool did_shutdown_;
593
594  DISALLOW_COPY_AND_ASSIGN(ShutdownOnReadMessageRawChannelDelegate);
595};
596
// Shutting a |RawChannel| down from inside its own |OnReadMessage()|
// callback must be safe. It must also stop delivery of the remaining queued
// messages.
TEST_F(RawChannelTest, ShutdownOnReadMessage) {
  // Write a few messages into the other end.
  for (size_t count = 0; count < 5; count++)
    EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), 10));

  scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
  ShutdownOnReadMessageRawChannelDelegate delegate(rc.get());
  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));

  // Wait for the delegate, which will shut the |RawChannel| down.
  delegate.Wait();
}
611
612// RawChannelTest.ShutdownOnError{Read, Write} ---------------------------------
613
614class ShutdownOnErrorRawChannelDelegate : public RawChannel::Delegate {
615 public:
616  ShutdownOnErrorRawChannelDelegate(RawChannel* raw_channel,
617                                    Error shutdown_on_error_type)
618      : raw_channel_(raw_channel),
619        shutdown_on_error_type_(shutdown_on_error_type),
620        done_event_(false, false),
621        did_shutdown_(false) {}
622  virtual ~ShutdownOnErrorRawChannelDelegate() {}
623
624  // |RawChannel::Delegate| implementation (called on the I/O thread):
625  virtual void OnReadMessage(
626      const MessageInTransit::View& /*message_view*/,
627      embedder::ScopedPlatformHandleVectorPtr /*platform_handles*/) OVERRIDE {
628    CHECK(false);  // Should not get called.
629  }
630  virtual void OnError(Error error) OVERRIDE {
631    EXPECT_FALSE(did_shutdown_);
632    if (error != shutdown_on_error_type_)
633      return;
634    raw_channel_->Shutdown();
635    did_shutdown_ = true;
636    done_event_.Signal();
637  }
638
639  // Waits for shutdown.
640  void Wait() {
641    done_event_.Wait();
642    EXPECT_TRUE(did_shutdown_);
643  }
644
645 private:
646  RawChannel* const raw_channel_;
647  const Error shutdown_on_error_type_;
648  base::WaitableEvent done_event_;
649  bool did_shutdown_;
650
651  DISALLOW_COPY_AND_ASSIGN(ShutdownOnErrorRawChannelDelegate);
652};
653
// Shutting a |RawChannel| down from inside |OnError()| for a read (shutdown)
// error must be safe.
TEST_F(RawChannelTest, ShutdownOnErrorRead) {
  scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
  ShutdownOnErrorRawChannelDelegate delegate(
      rc.get(), RawChannel::Delegate::ERROR_READ_SHUTDOWN);
  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));

  // Close the handle of the other end, which should make reading fail.
  handles[1].reset();

  // Wait for the delegate, which will shut the |RawChannel| down.
  delegate.Wait();
}
668
// Shutting a |RawChannel| down from inside |OnError()| for a write error must
// be safe.
TEST_F(RawChannelTest, ShutdownOnErrorWrite) {
  scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
  ShutdownOnErrorRawChannelDelegate delegate(rc.get(),
                                             RawChannel::Delegate::ERROR_WRITE);
  io_thread()->PostTaskAndWait(
      FROM_HERE,
      base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));

  // Close the handle of the other end, which should make writing fail.
  handles[1].reset();

  EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));

  // Wait for the delegate, which will shut the |RawChannel| down.
  delegate.Wait();
}
685
686}  // namespace
687}  // namespace system
688}  // namespace mojo
689