1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/socket/unix_domain_listen_socket_posix.h"
6
7#include <errno.h>
8#include <fcntl.h>
9#include <poll.h>
10#include <sys/socket.h>
11#include <sys/stat.h>
12#include <sys/time.h>
13#include <sys/types.h>
14#include <sys/un.h>
15#include <unistd.h>
16
17#include <cstring>
18#include <queue>
19#include <string>
20
21#include "base/bind.h"
22#include "base/callback.h"
23#include "base/compiler_specific.h"
24#include "base/files/file_path.h"
25#include "base/files/file_util.h"
26#include "base/files/scoped_temp_dir.h"
27#include "base/memory/ref_counted.h"
28#include "base/memory/scoped_ptr.h"
29#include "base/message_loop/message_loop.h"
30#include "base/posix/eintr_wrapper.h"
31#include "base/synchronization/condition_variable.h"
32#include "base/synchronization/lock.h"
33#include "base/threading/platform_thread.h"
34#include "base/threading/thread.h"
35#include "net/socket/socket_descriptor.h"
36#include "testing/gtest/include/gtest/gtest.h"
37
38using std::queue;
39using std::string;
40
41namespace net {
42namespace deprecated {
43namespace {
44
// Payload the client sends to the server. The tests send it with
// sizeof(kMsg), so the terminating NUL byte is part of what goes on the wire.
const char kMsg[] = "hello";

// Relative socket name; the fixture resolves relative paths inside a unique
// temporary directory.
const char kSocketFilename[] = "socket_for_testing";

// Absolute path used by the fixture that expects file-backed socket creation
// to fail.
const char kInvalidSocketPath[] = "/invalid/path";
48
// Events reported to the test body through EventManager. The explicit values
// match the ones the compiler would assign implicitly.
enum EventType {
  // Posted by TestListenSocketDelegate::DidAccept().
  EVENT_ACCEPT = 0,
  // Posted by UserCanConnectCallback() when the user is rejected.
  EVENT_AUTH_DENIED = 1,
  // Posted by UserCanConnectCallback() when the user is allowed.
  EVENT_AUTH_GRANTED = 2,
  // Posted by TestListenSocketDelegate::DidClose().
  EVENT_CLOSE = 3,
  // Posted by TestListenSocketDelegate::OnListenCompleted().
  EVENT_LISTEN = 4,
  // Posted by TestListenSocketDelegate::DidRead().
  EVENT_READ = 5,
};
57
58class EventManager : public base::RefCounted<EventManager> {
59 public:
60  EventManager() : condition_(&mutex_) {}
61
62  bool HasPendingEvent() {
63    base::AutoLock lock(mutex_);
64    return !events_.empty();
65  }
66
67  void Notify(EventType event) {
68    base::AutoLock lock(mutex_);
69    events_.push(event);
70    condition_.Broadcast();
71  }
72
73  EventType WaitForEvent() {
74    base::AutoLock lock(mutex_);
75    while (events_.empty())
76      condition_.Wait();
77    EventType event = events_.front();
78    events_.pop();
79    return event;
80  }
81
82 private:
83  friend class base::RefCounted<EventManager>;
84  virtual ~EventManager() {}
85
86  queue<EventType> events_;
87  base::Lock mutex_;
88  base::ConditionVariable condition_;
89};
90
91class TestListenSocketDelegate : public StreamListenSocket::Delegate {
92 public:
93  explicit TestListenSocketDelegate(
94      const scoped_refptr<EventManager>& event_manager)
95      : event_manager_(event_manager) {}
96
97  virtual void DidAccept(StreamListenSocket* server,
98                         scoped_ptr<StreamListenSocket> connection) OVERRIDE {
99    LOG(ERROR) << __PRETTY_FUNCTION__;
100    connection_ = connection.Pass();
101    Notify(EVENT_ACCEPT);
102  }
103
104  virtual void DidRead(StreamListenSocket* connection,
105                       const char* data,
106                       int len) OVERRIDE {
107    {
108      base::AutoLock lock(mutex_);
109      DCHECK(len);
110      data_.assign(data, len - 1);
111    }
112    Notify(EVENT_READ);
113  }
114
115  virtual void DidClose(StreamListenSocket* sock) OVERRIDE {
116    Notify(EVENT_CLOSE);
117  }
118
119  void OnListenCompleted() {
120    Notify(EVENT_LISTEN);
121  }
122
123  string ReceivedData() {
124    base::AutoLock lock(mutex_);
125    return data_;
126  }
127
128 private:
129  void Notify(EventType event) {
130    event_manager_->Notify(event);
131  }
132
133  const scoped_refptr<EventManager> event_manager_;
134  scoped_ptr<StreamListenSocket> connection_;
135  base::Lock mutex_;
136  string data_;
137};
138
139bool UserCanConnectCallback(
140    bool allow_user, const scoped_refptr<EventManager>& event_manager,
141    const UnixDomainServerSocket::Credentials&) {
142  event_manager->Notify(
143      allow_user ? EVENT_AUTH_GRANTED : EVENT_AUTH_DENIED);
144  return allow_user;
145}
146
147class UnixDomainListenSocketTestHelper : public testing::Test {
148 public:
149  void CreateAndListen() {
150    socket_ = UnixDomainListenSocket::CreateAndListen(
151        file_path_.value(), socket_delegate_.get(), MakeAuthCallback());
152    socket_delegate_->OnListenCompleted();
153  }
154
155 protected:
156  UnixDomainListenSocketTestHelper(const string& path_str, bool allow_user)
157      : allow_user_(allow_user) {
158    file_path_ = base::FilePath(path_str);
159    if (!file_path_.IsAbsolute()) {
160      EXPECT_TRUE(temp_dir_.CreateUniqueTempDir());
161      file_path_ = GetTempSocketPath(file_path_.value());
162    }
163    // Beware that if path_str is an absolute path, this class doesn't delete
164    // the file. It must be an invalid path and cannot be created by unittests.
165  }
166
167  base::FilePath GetTempSocketPath(const std::string socket_name) {
168    DCHECK(temp_dir_.IsValid());
169    return temp_dir_.path().Append(socket_name);
170  }
171
172  virtual void SetUp() OVERRIDE {
173    event_manager_ = new EventManager();
174    socket_delegate_.reset(new TestListenSocketDelegate(event_manager_));
175  }
176
177  virtual void TearDown() OVERRIDE {
178    socket_.reset();
179    socket_delegate_.reset();
180    event_manager_ = NULL;
181  }
182
183  UnixDomainListenSocket::AuthCallback MakeAuthCallback() {
184    return base::Bind(&UserCanConnectCallback, allow_user_, event_manager_);
185  }
186
187  SocketDescriptor CreateClientSocket() {
188    const SocketDescriptor sock = CreatePlatformSocket(PF_UNIX, SOCK_STREAM, 0);
189    if (sock < 0) {
190      LOG(ERROR) << "socket() error";
191      return kInvalidSocket;
192    }
193    sockaddr_un addr;
194    memset(&addr, 0, sizeof(addr));
195    addr.sun_family = AF_UNIX;
196    socklen_t addr_len;
197    strncpy(addr.sun_path, file_path_.value().c_str(), sizeof(addr.sun_path));
198    addr_len = sizeof(sockaddr_un);
199    if (connect(sock, reinterpret_cast<sockaddr*>(&addr), addr_len) != 0) {
200      LOG(ERROR) << "connect() error: " << strerror(errno)
201                 << ": path=" << file_path_.value();
202      return kInvalidSocket;
203    }
204    return sock;
205  }
206
207  scoped_ptr<base::Thread> CreateAndRunServerThread() {
208    base::Thread::Options options;
209    options.message_loop_type = base::MessageLoop::TYPE_IO;
210    scoped_ptr<base::Thread> thread(new base::Thread("socketio_test"));
211    thread->StartWithOptions(options);
212    thread->message_loop()->PostTask(
213        FROM_HERE,
214        base::Bind(&UnixDomainListenSocketTestHelper::CreateAndListen,
215                   base::Unretained(this)));
216    return thread.Pass();
217  }
218
219  base::ScopedTempDir temp_dir_;
220  base::FilePath file_path_;
221  const bool allow_user_;
222  scoped_refptr<EventManager> event_manager_;
223  scoped_ptr<TestListenSocketDelegate> socket_delegate_;
224  scoped_ptr<UnixDomainListenSocket> socket_;
225};
226
227class UnixDomainListenSocketTest : public UnixDomainListenSocketTestHelper {
228 protected:
229  UnixDomainListenSocketTest()
230      : UnixDomainListenSocketTestHelper(kSocketFilename,
231                                         true /* allow user */) {}
232};
233
234class UnixDomainListenSocketTestWithInvalidPath
235    : public UnixDomainListenSocketTestHelper {
236 protected:
237  UnixDomainListenSocketTestWithInvalidPath()
238      : UnixDomainListenSocketTestHelper(kInvalidSocketPath, true) {}
239};
240
241class UnixDomainListenSocketTestWithForbiddenUser
242    : public UnixDomainListenSocketTestHelper {
243 protected:
244  UnixDomainListenSocketTestWithForbiddenUser()
245      : UnixDomainListenSocketTestHelper(kSocketFilename,
246                                         false /* forbid user */) {}
247};
248
249TEST_F(UnixDomainListenSocketTest, CreateAndListen) {
250  CreateAndListen();
251  EXPECT_FALSE(socket_.get() == NULL);
252}
253
254TEST_F(UnixDomainListenSocketTestWithInvalidPath,
255       CreateAndListenWithInvalidPath) {
256  CreateAndListen();
257  EXPECT_TRUE(socket_.get() == NULL);
258}
259
260#ifdef SOCKET_ABSTRACT_NAMESPACE_SUPPORTED
261// Test with an invalid path to make sure that the socket is not backed by a
262// file.
263TEST_F(UnixDomainListenSocketTestWithInvalidPath,
264       CreateAndListenWithAbstractNamespace) {
265  socket_ = UnixDomainListenSocket::CreateAndListenWithAbstractNamespace(
266      file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback());
267  EXPECT_FALSE(socket_.get() == NULL);
268}
269
270TEST_F(UnixDomainListenSocketTest, TestFallbackName) {
271  scoped_ptr<UnixDomainListenSocket> existing_socket =
272      UnixDomainListenSocket::CreateAndListenWithAbstractNamespace(
273          file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback());
274  EXPECT_FALSE(existing_socket.get() == NULL);
275  // First, try to bind socket with the same name with no fallback name.
276  socket_ =
277      UnixDomainListenSocket::CreateAndListenWithAbstractNamespace(
278          file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback());
279  EXPECT_TRUE(socket_.get() == NULL);
280  // Now with a fallback name.
281  const char kFallbackSocketName[] = "socket_for_testing_2";
282  socket_ = UnixDomainListenSocket::CreateAndListenWithAbstractNamespace(
283      file_path_.value(),
284      GetTempSocketPath(kFallbackSocketName).value(),
285      socket_delegate_.get(),
286      MakeAuthCallback());
287  EXPECT_FALSE(socket_.get() == NULL);
288}
289#endif
290
291TEST_F(UnixDomainListenSocketTest, TestWithClient) {
292  const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread();
293  EventType event = event_manager_->WaitForEvent();
294  ASSERT_EQ(EVENT_LISTEN, event);
295
296  // Create the client socket.
297  const SocketDescriptor sock = CreateClientSocket();
298  ASSERT_NE(kInvalidSocket, sock);
299  event = event_manager_->WaitForEvent();
300  ASSERT_EQ(EVENT_AUTH_GRANTED, event);
301  event = event_manager_->WaitForEvent();
302  ASSERT_EQ(EVENT_ACCEPT, event);
303
304  // Send a message from the client to the server.
305  ssize_t ret = HANDLE_EINTR(send(sock, kMsg, sizeof(kMsg), 0));
306  ASSERT_NE(-1, ret);
307  ASSERT_EQ(sizeof(kMsg), static_cast<size_t>(ret));
308  event = event_manager_->WaitForEvent();
309  ASSERT_EQ(EVENT_READ, event);
310  ASSERT_EQ(kMsg, socket_delegate_->ReceivedData());
311
312  // Close the client socket.
313  ret = IGNORE_EINTR(close(sock));
314  event = event_manager_->WaitForEvent();
315  ASSERT_EQ(EVENT_CLOSE, event);
316}
317
318TEST_F(UnixDomainListenSocketTestWithForbiddenUser, TestWithForbiddenUser) {
319  const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread();
320  EventType event = event_manager_->WaitForEvent();
321  ASSERT_EQ(EVENT_LISTEN, event);
322  const SocketDescriptor sock = CreateClientSocket();
323  ASSERT_NE(kInvalidSocket, sock);
324
325  event = event_manager_->WaitForEvent();
326  ASSERT_EQ(EVENT_AUTH_DENIED, event);
327
328  // Wait until the file descriptor is closed by the server.
329  struct pollfd poll_fd;
330  poll_fd.fd = sock;
331  poll_fd.events = POLLIN;
332  poll(&poll_fd, 1, -1 /* rely on GTest for timeout handling */);
333
334  // Send() must fail.
335  ssize_t ret = HANDLE_EINTR(send(sock, kMsg, sizeof(kMsg), 0));
336  ASSERT_EQ(-1, ret);
337  ASSERT_EQ(EPIPE, errno);
338  ASSERT_FALSE(event_manager_->HasPendingEvent());
339}
340
341}  // namespace
342}  // namespace deprecated
343}  // namespace net
344