1// Copyright 2014 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "net/socket/unix_domain_listen_socket_posix.h" 6 7#include <errno.h> 8#include <fcntl.h> 9#include <poll.h> 10#include <sys/socket.h> 11#include <sys/stat.h> 12#include <sys/time.h> 13#include <sys/types.h> 14#include <sys/un.h> 15#include <unistd.h> 16 17#include <cstring> 18#include <queue> 19#include <string> 20 21#include "base/bind.h" 22#include "base/callback.h" 23#include "base/compiler_specific.h" 24#include "base/files/file_path.h" 25#include "base/files/file_util.h" 26#include "base/files/scoped_temp_dir.h" 27#include "base/memory/ref_counted.h" 28#include "base/memory/scoped_ptr.h" 29#include "base/message_loop/message_loop.h" 30#include "base/posix/eintr_wrapper.h" 31#include "base/synchronization/condition_variable.h" 32#include "base/synchronization/lock.h" 33#include "base/threading/platform_thread.h" 34#include "base/threading/thread.h" 35#include "net/socket/socket_descriptor.h" 36#include "testing/gtest/include/gtest/gtest.h" 37 38using std::queue; 39using std::string; 40 41namespace net { 42namespace deprecated { 43namespace { 44 45const char kSocketFilename[] = "socket_for_testing"; 46const char kInvalidSocketPath[] = "/invalid/path"; 47const char kMsg[] = "hello"; 48 49enum EventType { 50 EVENT_ACCEPT, 51 EVENT_AUTH_DENIED, 52 EVENT_AUTH_GRANTED, 53 EVENT_CLOSE, 54 EVENT_LISTEN, 55 EVENT_READ, 56}; 57 58class EventManager : public base::RefCounted<EventManager> { 59 public: 60 EventManager() : condition_(&mutex_) {} 61 62 bool HasPendingEvent() { 63 base::AutoLock lock(mutex_); 64 return !events_.empty(); 65 } 66 67 void Notify(EventType event) { 68 base::AutoLock lock(mutex_); 69 events_.push(event); 70 condition_.Broadcast(); 71 } 72 73 EventType WaitForEvent() { 74 base::AutoLock lock(mutex_); 75 while (events_.empty()) 76 condition_.Wait(); 77 EventType event = events_.front(); 78 events_.pop(); 79 return event; 80 } 81 82 private: 83 friend class base::RefCounted<EventManager>; 84 virtual ~EventManager() {} 85 86 queue<EventType> events_; 87 base::Lock mutex_; 88 base::ConditionVariable condition_; 89}; 90 91class TestListenSocketDelegate : public StreamListenSocket::Delegate { 92 public: 93 explicit TestListenSocketDelegate( 94 const scoped_refptr<EventManager>& event_manager) 95 : event_manager_(event_manager) {} 96 97 virtual void DidAccept(StreamListenSocket* server, 98 scoped_ptr<StreamListenSocket> connection) OVERRIDE { 99 LOG(ERROR) << __PRETTY_FUNCTION__; 100 connection_ = connection.Pass(); 101 Notify(EVENT_ACCEPT); 102 } 103 104 virtual void DidRead(StreamListenSocket* connection, 105 const char* data, 106 int len) OVERRIDE { 107 { 108 base::AutoLock lock(mutex_); 109 DCHECK(len); 110 data_.assign(data, len - 1); 111 } 112 Notify(EVENT_READ); 113 } 114 115 virtual void DidClose(StreamListenSocket* sock) OVERRIDE { 116 Notify(EVENT_CLOSE); 117 } 118 119 void OnListenCompleted() { 120 Notify(EVENT_LISTEN); 121 } 122 123 string ReceivedData() { 124 base::AutoLock lock(mutex_); 125 return data_; 126 } 127 128 private: 129 void Notify(EventType event) { 130 event_manager_->Notify(event); 131 } 132 133 const scoped_refptr<EventManager> event_manager_; 134 scoped_ptr<StreamListenSocket> connection_; 135 base::Lock mutex_; 136 string data_; 137}; 138 139bool UserCanConnectCallback( 140 bool allow_user, const scoped_refptr<EventManager>& event_manager, 141 const UnixDomainServerSocket::Credentials&) { 142 event_manager->Notify( 143 allow_user ? EVENT_AUTH_GRANTED : EVENT_AUTH_DENIED); 144 return allow_user; 145} 146 147class UnixDomainListenSocketTestHelper : public testing::Test { 148 public: 149 void CreateAndListen() { 150 socket_ = UnixDomainListenSocket::CreateAndListen( 151 file_path_.value(), socket_delegate_.get(), MakeAuthCallback()); 152 socket_delegate_->OnListenCompleted(); 153 } 154 155 protected: 156 UnixDomainListenSocketTestHelper(const string& path_str, bool allow_user) 157 : allow_user_(allow_user) { 158 file_path_ = base::FilePath(path_str); 159 if (!file_path_.IsAbsolute()) { 160 EXPECT_TRUE(temp_dir_.CreateUniqueTempDir()); 161 file_path_ = GetTempSocketPath(file_path_.value()); 162 } 163 // Beware that if path_str is an absolute path, this class doesn't delete 164 // the file. It must be an invalid path and cannot be created by unittests. 165 } 166 167 base::FilePath GetTempSocketPath(const std::string socket_name) { 168 DCHECK(temp_dir_.IsValid()); 169 return temp_dir_.path().Append(socket_name); 170 } 171 172 virtual void SetUp() OVERRIDE { 173 event_manager_ = new EventManager(); 174 socket_delegate_.reset(new TestListenSocketDelegate(event_manager_)); 175 } 176 177 virtual void TearDown() OVERRIDE { 178 socket_.reset(); 179 socket_delegate_.reset(); 180 event_manager_ = NULL; 181 } 182 183 UnixDomainListenSocket::AuthCallback MakeAuthCallback() { 184 return base::Bind(&UserCanConnectCallback, allow_user_, event_manager_); 185 } 186 187 SocketDescriptor CreateClientSocket() { 188 const SocketDescriptor sock = CreatePlatformSocket(PF_UNIX, SOCK_STREAM, 0); 189 if (sock < 0) { 190 LOG(ERROR) << "socket() error"; 191 return kInvalidSocket; 192 } 193 sockaddr_un addr; 194 memset(&addr, 0, sizeof(addr)); 195 addr.sun_family = AF_UNIX; 196 socklen_t addr_len; 197 strncpy(addr.sun_path, file_path_.value().c_str(), sizeof(addr.sun_path)); 198 addr_len = sizeof(sockaddr_un); 199 if (connect(sock, reinterpret_cast<sockaddr*>(&addr), addr_len) != 0) { 200 LOG(ERROR) << "connect() error: " << strerror(errno) 201 << ": path=" << file_path_.value(); 202 return kInvalidSocket; 203 } 204 return sock; 205 } 206 207 scoped_ptr<base::Thread> CreateAndRunServerThread() { 208 base::Thread::Options options; 209 options.message_loop_type = base::MessageLoop::TYPE_IO; 210 scoped_ptr<base::Thread> thread(new base::Thread("socketio_test")); 211 thread->StartWithOptions(options); 212 thread->message_loop()->PostTask( 213 FROM_HERE, 214 base::Bind(&UnixDomainListenSocketTestHelper::CreateAndListen, 215 base::Unretained(this))); 216 return thread.Pass(); 217 } 218 219 base::ScopedTempDir temp_dir_; 220 base::FilePath file_path_; 221 const bool allow_user_; 222 scoped_refptr<EventManager> event_manager_; 223 scoped_ptr<TestListenSocketDelegate> socket_delegate_; 224 scoped_ptr<UnixDomainListenSocket> socket_; 225}; 226 227class UnixDomainListenSocketTest : public UnixDomainListenSocketTestHelper { 228 protected: 229 UnixDomainListenSocketTest() 230 : UnixDomainListenSocketTestHelper(kSocketFilename, 231 true /* allow user */) {} 232}; 233 234class UnixDomainListenSocketTestWithInvalidPath 235 : public UnixDomainListenSocketTestHelper { 236 protected: 237 UnixDomainListenSocketTestWithInvalidPath() 238 : UnixDomainListenSocketTestHelper(kInvalidSocketPath, true) {} 239}; 240 241class UnixDomainListenSocketTestWithForbiddenUser 242 : public UnixDomainListenSocketTestHelper { 243 protected: 244 UnixDomainListenSocketTestWithForbiddenUser() 245 : UnixDomainListenSocketTestHelper(kSocketFilename, 246 false /* forbid user */) {} 247}; 248 249TEST_F(UnixDomainListenSocketTest, CreateAndListen) { 250 CreateAndListen(); 251 EXPECT_FALSE(socket_.get() == NULL); 252} 253 254TEST_F(UnixDomainListenSocketTestWithInvalidPath, 255 CreateAndListenWithInvalidPath) { 256 CreateAndListen(); 257 EXPECT_TRUE(socket_.get() == NULL); 258} 259 260#ifdef SOCKET_ABSTRACT_NAMESPACE_SUPPORTED 261// Test with an invalid path to make sure that the socket is not backed by a 262// file. 263TEST_F(UnixDomainListenSocketTestWithInvalidPath, 264 CreateAndListenWithAbstractNamespace) { 265 socket_ = UnixDomainListenSocket::CreateAndListenWithAbstractNamespace( 266 file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback()); 267 EXPECT_FALSE(socket_.get() == NULL); 268} 269 270TEST_F(UnixDomainListenSocketTest, TestFallbackName) { 271 scoped_ptr<UnixDomainListenSocket> existing_socket = 272 UnixDomainListenSocket::CreateAndListenWithAbstractNamespace( 273 file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback()); 274 EXPECT_FALSE(existing_socket.get() == NULL); 275 // First, try to bind socket with the same name with no fallback name. 276 socket_ = 277 UnixDomainListenSocket::CreateAndListenWithAbstractNamespace( 278 file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback()); 279 EXPECT_TRUE(socket_.get() == NULL); 280 // Now with a fallback name. 281 const char kFallbackSocketName[] = "socket_for_testing_2"; 282 socket_ = UnixDomainListenSocket::CreateAndListenWithAbstractNamespace( 283 file_path_.value(), 284 GetTempSocketPath(kFallbackSocketName).value(), 285 socket_delegate_.get(), 286 MakeAuthCallback()); 287 EXPECT_FALSE(socket_.get() == NULL); 288} 289#endif 290 291TEST_F(UnixDomainListenSocketTest, TestWithClient) { 292 const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread(); 293 EventType event = event_manager_->WaitForEvent(); 294 ASSERT_EQ(EVENT_LISTEN, event); 295 296 // Create the client socket. 297 const SocketDescriptor sock = CreateClientSocket(); 298 ASSERT_NE(kInvalidSocket, sock); 299 event = event_manager_->WaitForEvent(); 300 ASSERT_EQ(EVENT_AUTH_GRANTED, event); 301 event = event_manager_->WaitForEvent(); 302 ASSERT_EQ(EVENT_ACCEPT, event); 303 304 // Send a message from the client to the server. 305 ssize_t ret = HANDLE_EINTR(send(sock, kMsg, sizeof(kMsg), 0)); 306 ASSERT_NE(-1, ret); 307 ASSERT_EQ(sizeof(kMsg), static_cast<size_t>(ret)); 308 event = event_manager_->WaitForEvent(); 309 ASSERT_EQ(EVENT_READ, event); 310 ASSERT_EQ(kMsg, socket_delegate_->ReceivedData()); 311 312 // Close the client socket. 313 ret = IGNORE_EINTR(close(sock)); 314 event = event_manager_->WaitForEvent(); 315 ASSERT_EQ(EVENT_CLOSE, event); 316} 317 318TEST_F(UnixDomainListenSocketTestWithForbiddenUser, TestWithForbiddenUser) { 319 const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread(); 320 EventType event = event_manager_->WaitForEvent(); 321 ASSERT_EQ(EVENT_LISTEN, event); 322 const SocketDescriptor sock = CreateClientSocket(); 323 ASSERT_NE(kInvalidSocket, sock); 324 325 event = event_manager_->WaitForEvent(); 326 ASSERT_EQ(EVENT_AUTH_DENIED, event); 327 328 // Wait until the file descriptor is closed by the server. 329 struct pollfd poll_fd; 330 poll_fd.fd = sock; 331 poll_fd.events = POLLIN; 332 poll(&poll_fd, 1, -1 /* rely on GTest for timeout handling */); 333 334 // Send() must fail. 335 ssize_t ret = HANDLE_EINTR(send(sock, kMsg, sizeof(kMsg), 0)); 336 ASSERT_EQ(-1, ret); 337 ASSERT_EQ(EPIPE, errno); 338 ASSERT_FALSE(event_manager_->HasPendingEvent()); 339} 340 341} // namespace 342} // namespace deprecated 343} // namespace net 344