unix_domain_socket_util_unittest.cc revision 90dce4d38c5ff5333bea97d859d4e484e27edf0c
1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <sys/socket.h>
6
7#include "base/bind.h"
8#include "base/files/file_path.h"
9#include "base/path_service.h"
10#include "base/posix/eintr_wrapper.h"
11#include "base/synchronization/waitable_event.h"
12#include "base/threading/thread.h"
13#include "base/threading/thread_restrictions.h"
14#include "ipc/unix_domain_socket_util.h"
15#include "testing/gtest/include/gtest/gtest.h"
16
17namespace {
18
19class SocketAcceptor : public base::MessageLoopForIO::Watcher {
20 public:
21  SocketAcceptor(int fd, base::MessageLoopProxy* target_thread)
22      : server_fd_(-1),
23        target_thread_(target_thread),
24        started_watching_event_(false, false),
25        accepted_event_(false, false) {
26    target_thread->PostTask(FROM_HERE,
27        base::Bind(&SocketAcceptor::StartWatching, base::Unretained(this), fd));
28  }
29
30  virtual ~SocketAcceptor() {
31    Close();
32  }
33
34  int server_fd() const { return server_fd_; }
35
36  void WaitUntilReady() {
37    started_watching_event_.Wait();
38  }
39
40  void WaitForAccept() {
41    accepted_event_.Wait();
42  }
43
44  void Close() {
45    if (watcher_.get()) {
46      target_thread_->PostTask(FROM_HERE,
47          base::Bind(&SocketAcceptor::StopWatching, base::Unretained(this),
48              watcher_.release()));
49    }
50  }
51
52 private:
53  void StartWatching(int fd) {
54    watcher_.reset(new base::MessageLoopForIO::FileDescriptorWatcher);
55    base::MessageLoopForIO::current()->WatchFileDescriptor(
56        fd, true, base::MessageLoopForIO::WATCH_READ, watcher_.get(), this);
57    started_watching_event_.Signal();
58  }
59  void StopWatching(base::MessageLoopForIO::FileDescriptorWatcher* watcher) {
60    watcher->StopWatchingFileDescriptor();
61    delete watcher;
62  }
63  virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE {
64    ASSERT_EQ(-1, server_fd_);
65    IPC::ServerAcceptConnection(fd, &server_fd_);
66    watcher_->StopWatchingFileDescriptor();
67    accepted_event_.Signal();
68  }
69  virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE {}
70
71  int server_fd_;
72  base::MessageLoopProxy* target_thread_;
73  scoped_ptr<base::MessageLoopForIO::FileDescriptorWatcher> watcher_;
74  base::WaitableEvent started_watching_event_;
75  base::WaitableEvent accepted_event_;
76
77  DISALLOW_COPY_AND_ASSIGN(SocketAcceptor);
78};
79
80const base::FilePath GetChannelDir() {
81#if defined(OS_ANDROID)
82  base::FilePath tmp_dir;
83  PathService::Get(base::DIR_CACHE, &tmp_dir);
84  return tmp_dir;
85#else
86  return base::FilePath("/var/tmp");
87#endif
88}
89
90class TestUnixSocketConnection {
91 public:
92  TestUnixSocketConnection()
93      : worker_("WorkerThread"),
94        server_listen_fd_(-1),
95        server_fd_(-1),
96        client_fd_(-1) {
97    socket_name_ = GetChannelDir().Append("TestSocket");
98    base::Thread::Options options;
99    options.message_loop_type = base::MessageLoop::TYPE_IO;
100    worker_.StartWithOptions(options);
101  }
102
103  bool CreateServerSocket() {
104    IPC::CreateServerUnixDomainSocket(socket_name_, &server_listen_fd_);
105    if (server_listen_fd_ < 0)
106      return false;
107    struct stat socket_stat;
108    stat(socket_name_.value().c_str(), &socket_stat);
109    EXPECT_TRUE(S_ISSOCK(socket_stat.st_mode));
110    acceptor_.reset(new SocketAcceptor(server_listen_fd_,
111                                       worker_.message_loop_proxy()));
112    acceptor_->WaitUntilReady();
113    return true;
114  }
115
116  bool CreateClientSocket() {
117    DCHECK(server_listen_fd_ >= 0);
118    IPC::CreateClientUnixDomainSocket(socket_name_, &client_fd_);
119    if (client_fd_ < 0)
120      return false;
121    acceptor_->WaitForAccept();
122    server_fd_ = acceptor_->server_fd();
123    return server_fd_ >= 0;
124  }
125
126  virtual ~TestUnixSocketConnection() {
127    if (client_fd_ >= 0)
128      close(client_fd_);
129    if (server_fd_ >= 0)
130      close(server_fd_);
131    if (server_listen_fd_ >= 0) {
132      close(server_listen_fd_);
133      unlink(socket_name_.value().c_str());
134    }
135  }
136
137  int client_fd() const { return client_fd_; }
138  int server_fd() const { return server_fd_; }
139
140 private:
141  base::Thread worker_;
142  base::FilePath socket_name_;
143  int server_listen_fd_;
144  int server_fd_;
145  int client_fd_;
146  scoped_ptr<SocketAcceptor> acceptor_;
147};
148
149// Ensure that IPC::CreateServerUnixDomainSocket creates a socket that
150// IPC::CreateClientUnixDomainSocket can successfully connect to.
151TEST(UnixDomainSocketUtil, Connect) {
152  TestUnixSocketConnection connection;
153  ASSERT_TRUE(connection.CreateServerSocket());
154  ASSERT_TRUE(connection.CreateClientSocket());
155}
156
157// Ensure that messages can be sent across the resulting socket.
158TEST(UnixDomainSocketUtil, SendReceive) {
159  TestUnixSocketConnection connection;
160  ASSERT_TRUE(connection.CreateServerSocket());
161  ASSERT_TRUE(connection.CreateClientSocket());
162
163  const char buffer[] = "Hello, server!";
164  size_t buf_len = sizeof(buffer);
165  size_t sent_bytes =
166      HANDLE_EINTR(send(connection.client_fd(), buffer, buf_len, 0));
167  ASSERT_EQ(buf_len, sent_bytes);
168  char recv_buf[sizeof(buffer)];
169  size_t received_bytes =
170      HANDLE_EINTR(recv(connection.server_fd(), recv_buf, buf_len, 0));
171  ASSERT_EQ(buf_len, received_bytes);
172  ASSERT_EQ(0, memcmp(recv_buf, buffer, buf_len));
173}
174
175}  // namespace
176