1//
2// Copyright (C) 2011 The Android Open Source Project
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//      http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15//
16
17#include "shill/async_connection.h"
18
19#include <netinet/in.h>
20
21#include <vector>
22
23#include <base/bind.h>
24#include <gtest/gtest.h>
25
26#include "shill/mock_event_dispatcher.h"
27#include "shill/net/ip_address.h"
28#include "shill/net/mock_sockets.h"
29
30using base::Bind;
31using base::Callback;
32using base::Unretained;
33using std::string;
34using ::testing::_;
35using ::testing::Return;
36using ::testing::ReturnNew;
37using ::testing::StrEq;
38using ::testing::StrictMock;
39using ::testing::Test;
40
41namespace shill {
42
namespace {

// Fixed values fed through the mocks below; they only need to be distinctive
// so that expectations on them cannot be satisfied by accident.
constexpr char kInterfaceName[] = "int0";
constexpr char kIPv4Address[] = "10.11.12.13";
// Address from the 2001:db8::/32 documentation prefix (RFC 3849).
constexpr char kIPv6Address[] = "2001:db8::1";
constexpr int kConnectPort = 10203;
constexpr int kErrorNumber = 30405;  // Fake errno returned by Sockets::Error().
constexpr int kSocketFD = 60708;     // Fake descriptor returned by Socket().

}  // namespace
51
52class AsyncConnectionTest : public Test {
53 public:
54  AsyncConnectionTest()
55      : async_connection_(
56            new AsyncConnection(kInterfaceName, &dispatcher_, &sockets_,
57                                callback_target_.callback())),
58        ipv4_address_(IPAddress::kFamilyIPv4),
59        ipv6_address_(IPAddress::kFamilyIPv6) { }
60
61  virtual void SetUp() {
62    EXPECT_TRUE(ipv4_address_.SetAddressFromString(kIPv4Address));
63    EXPECT_TRUE(ipv6_address_.SetAddressFromString(kIPv6Address));
64  }
65  virtual void TearDown() {
66    if (async_connection_.get() && async_connection_->fd_ >= 0) {
67      EXPECT_CALL(sockets(), Close(kSocketFD))
68          .WillOnce(Return(0));
69    }
70  }
71  void InvokeFreeConnection(bool /*success*/, int /*fd*/) {
72    async_connection_.reset();
73  }
74
75 protected:
76  class ConnectCallbackTarget {
77   public:
78    ConnectCallbackTarget()
79        : callback_(Bind(&ConnectCallbackTarget::CallTarget,
80                         Unretained(this))) {}
81
82    MOCK_METHOD2(CallTarget, void(bool success, int fd));
83    const Callback<void(bool, int)>& callback() { return callback_; }
84
85   private:
86    Callback<void(bool, int)> callback_;
87  };
88
89  void ExpectReset() {
90    EXPECT_STREQ(kInterfaceName, async_connection_->interface_name_.c_str());
91    EXPECT_EQ(&dispatcher_, async_connection_->dispatcher_);
92    EXPECT_EQ(&sockets_, async_connection_->sockets_);
93    EXPECT_TRUE(callback_target_.callback().
94                Equals(async_connection_->callback_));
95    EXPECT_EQ(-1, async_connection_->fd_);
96    EXPECT_FALSE(async_connection_->connect_completion_callback_.is_null());
97    EXPECT_FALSE(async_connection_->connect_completion_handler_.get());
98  }
99
100  void StartConnection() {
101    EXPECT_CALL(sockets_, Socket(_, _, _))
102        .WillOnce(Return(kSocketFD));
103    EXPECT_CALL(sockets_, SetNonBlocking(kSocketFD))
104        .WillOnce(Return(0));
105    EXPECT_CALL(sockets_, BindToDevice(kSocketFD, StrEq(kInterfaceName)))
106        .WillOnce(Return(0));
107    EXPECT_CALL(sockets(), Connect(kSocketFD, _, _))
108        .WillOnce(Return(-1));
109    EXPECT_CALL(sockets_, Error())
110        .WillOnce(Return(EINPROGRESS));
111    EXPECT_CALL(dispatcher(),
112                CreateReadyHandler(kSocketFD, IOHandler::kModeOutput, _))
113        .WillOnce(ReturnNew<IOHandler>());
114    EXPECT_TRUE(async_connection().Start(ipv4_address_, kConnectPort));
115  }
116
117  void OnConnectCompletion(int fd) {
118    async_connection_->OnConnectCompletion(fd);
119  }
120  AsyncConnection& async_connection() { return *async_connection_.get(); }
121  StrictMock<MockSockets>& sockets() { return sockets_; }
122  MockEventDispatcher& dispatcher() { return dispatcher_; }
123  const IPAddress& ipv4_address() { return ipv4_address_; }
124  const IPAddress& ipv6_address() { return ipv6_address_; }
125  int fd() { return async_connection_->fd_; }
126  void set_fd(int fd) { async_connection_->fd_ = fd; }
127  StrictMock<ConnectCallbackTarget>& callback_target() {
128    return callback_target_;
129  }
130
131 private:
132  MockEventDispatcher dispatcher_;
133  StrictMock<MockSockets> sockets_;
134  StrictMock<ConnectCallbackTarget> callback_target_;
135  std::unique_ptr<AsyncConnection> async_connection_;
136  IPAddress ipv4_address_;
137  IPAddress ipv6_address_;
138};
139
140TEST_F(AsyncConnectionTest, InitState) {
141  ExpectReset();
142  EXPECT_EQ(string(), async_connection().error());
143}
144
145TEST_F(AsyncConnectionTest, StartSocketFailure) {
146  EXPECT_CALL(sockets(), Socket(_, _, _))
147      .WillOnce(Return(-1));
148  EXPECT_CALL(sockets(), Error())
149      .WillOnce(Return(kErrorNumber));
150  EXPECT_FALSE(async_connection().Start(ipv4_address(), kConnectPort));
151  ExpectReset();
152  EXPECT_STREQ(strerror(kErrorNumber), async_connection().error().c_str());
153}
154
155TEST_F(AsyncConnectionTest, StartNonBlockingFailure) {
156  EXPECT_CALL(sockets(), Socket(_, _, _))
157      .WillOnce(Return(kSocketFD));
158  EXPECT_CALL(sockets(), SetNonBlocking(kSocketFD))
159      .WillOnce(Return(-1));
160  EXPECT_CALL(sockets(), Error())
161      .WillOnce(Return(kErrorNumber));
162  EXPECT_CALL(sockets(), Close(kSocketFD))
163      .WillOnce(Return(0));
164  EXPECT_FALSE(async_connection().Start(ipv4_address(), kConnectPort));
165  ExpectReset();
166  EXPECT_STREQ(strerror(kErrorNumber), async_connection().error().c_str());
167}
168
169TEST_F(AsyncConnectionTest, StartBindToDeviceFailure) {
170  EXPECT_CALL(sockets(), Socket(_, _, _))
171      .WillOnce(Return(kSocketFD));
172  EXPECT_CALL(sockets(), SetNonBlocking(kSocketFD))
173      .WillOnce(Return(0));
174  EXPECT_CALL(sockets(), BindToDevice(kSocketFD, StrEq(kInterfaceName)))
175      .WillOnce(Return(-1));
176  EXPECT_CALL(sockets(), Error())
177      .WillOnce(Return(kErrorNumber));
178  EXPECT_CALL(sockets(), Close(kSocketFD))
179      .WillOnce(Return(0));
180  EXPECT_FALSE(async_connection().Start(ipv4_address(), kConnectPort));
181  ExpectReset();
182  EXPECT_STREQ(strerror(kErrorNumber), async_connection().error().c_str());
183}
184
185TEST_F(AsyncConnectionTest, SynchronousFailure) {
186  EXPECT_CALL(sockets(), Socket(_, _, _))
187      .WillOnce(Return(kSocketFD));
188  EXPECT_CALL(sockets(), SetNonBlocking(kSocketFD))
189      .WillOnce(Return(0));
190  EXPECT_CALL(sockets(), BindToDevice(kSocketFD, StrEq(kInterfaceName)))
191      .WillOnce(Return(0));
192  EXPECT_CALL(sockets(), Connect(kSocketFD, _, _))
193      .WillOnce(Return(-1));
194  EXPECT_CALL(sockets(), Error())
195      .Times(2)
196      .WillRepeatedly(Return(0));
197  EXPECT_CALL(sockets(), Close(kSocketFD))
198      .WillOnce(Return(0));
199  EXPECT_FALSE(async_connection().Start(ipv4_address(), kConnectPort));
200  ExpectReset();
201}
202
203MATCHER_P2(IsSocketAddress, address, port, "") {
204  const struct sockaddr_in* arg_saddr =
205      reinterpret_cast<const struct sockaddr_in*>(arg);
206  IPAddress arg_addr(IPAddress::kFamilyIPv4,
207                     ByteString(reinterpret_cast<const unsigned char*>(
208                         &arg_saddr->sin_addr.s_addr),
209                                sizeof(arg_saddr->sin_addr.s_addr)));
210  return address.Equals(arg_addr) && arg_saddr->sin_port == htons(port);
211}
212
213MATCHER_P2(IsSocketIpv6Address, ipv6_address, port, "") {
214  const struct sockaddr_in6* arg_saddr =
215      reinterpret_cast<const struct sockaddr_in6*>(arg);
216  IPAddress arg_addr(IPAddress::kFamilyIPv6,
217                     ByteString(reinterpret_cast<const unsigned char*>(
218                         &arg_saddr->sin6_addr.s6_addr),
219                                sizeof(arg_saddr->sin6_addr.s6_addr)));
220  return ipv6_address.Equals(arg_addr) && arg_saddr->sin6_port == htons(port);
221}
222
223TEST_F(AsyncConnectionTest, SynchronousStart) {
224  EXPECT_CALL(sockets(), Socket(_, _, _))
225      .WillOnce(Return(kSocketFD));
226  EXPECT_CALL(sockets(), SetNonBlocking(kSocketFD))
227      .WillOnce(Return(0));
228  EXPECT_CALL(sockets(), BindToDevice(kSocketFD, StrEq(kInterfaceName)))
229      .WillOnce(Return(0));
230  EXPECT_CALL(sockets(), Connect(kSocketFD,
231                                  IsSocketAddress(ipv4_address(), kConnectPort),
232                                  sizeof(struct sockaddr_in)))
233      .WillOnce(Return(-1));
234  EXPECT_CALL(dispatcher(),
235              CreateReadyHandler(kSocketFD, IOHandler::kModeOutput, _))
236        .WillOnce(ReturnNew<IOHandler>());
237  EXPECT_CALL(sockets(), Error())
238      .WillOnce(Return(EINPROGRESS));
239  EXPECT_TRUE(async_connection().Start(ipv4_address(), kConnectPort));
240  EXPECT_EQ(kSocketFD, fd());
241}
242
243TEST_F(AsyncConnectionTest, SynchronousStartIpv6) {
244  EXPECT_CALL(sockets(), Socket(_, _, _))
245      .WillOnce(Return(kSocketFD));
246  EXPECT_CALL(sockets(), SetNonBlocking(kSocketFD))
247      .WillOnce(Return(0));
248  EXPECT_CALL(sockets(), BindToDevice(kSocketFD, StrEq(kInterfaceName)))
249      .WillOnce(Return(0));
250  EXPECT_CALL(sockets(), Connect(kSocketFD,
251                                  IsSocketIpv6Address(ipv6_address(),
252                                                      kConnectPort),
253                                  sizeof(struct sockaddr_in6)))
254      .WillOnce(Return(-1));
255  EXPECT_CALL(dispatcher(),
256              CreateReadyHandler(kSocketFD, IOHandler::kModeOutput, _))
257        .WillOnce(ReturnNew<IOHandler>());
258  EXPECT_CALL(sockets(), Error())
259      .WillOnce(Return(EINPROGRESS));
260  EXPECT_TRUE(async_connection().Start(ipv6_address(), kConnectPort));
261  EXPECT_EQ(kSocketFD, fd());
262}
263
264TEST_F(AsyncConnectionTest, AsynchronousFailure) {
265  StartConnection();
266  EXPECT_CALL(sockets(), GetSocketError(kSocketFD))
267      .WillOnce(Return(1));
268  EXPECT_CALL(sockets(), Error())
269      .WillOnce(Return(kErrorNumber));
270  EXPECT_CALL(callback_target(), CallTarget(false, -1));
271  EXPECT_CALL(sockets(), Close(kSocketFD))
272      .WillOnce(Return(0));
273  OnConnectCompletion(kSocketFD);
274  ExpectReset();
275  EXPECT_STREQ(strerror(kErrorNumber), async_connection().error().c_str());
276}
277
278TEST_F(AsyncConnectionTest, AsynchronousSuccess) {
279  StartConnection();
280  EXPECT_CALL(sockets(), GetSocketError(kSocketFD))
281      .WillOnce(Return(0));
282  EXPECT_CALL(callback_target(), CallTarget(true, kSocketFD));
283  OnConnectCompletion(kSocketFD);
284  ExpectReset();
285}
286
287TEST_F(AsyncConnectionTest, SynchronousSuccess) {
288  EXPECT_CALL(sockets(), Socket(_, _, _))
289      .WillOnce(Return(kSocketFD));
290  EXPECT_CALL(sockets(), SetNonBlocking(kSocketFD))
291      .WillOnce(Return(0));
292  EXPECT_CALL(sockets(), BindToDevice(kSocketFD, StrEq(kInterfaceName)))
293      .WillOnce(Return(0));
294  EXPECT_CALL(sockets(), Connect(kSocketFD,
295                                  IsSocketAddress(ipv4_address(), kConnectPort),
296                                  sizeof(struct sockaddr_in)))
297      .WillOnce(Return(0));
298  EXPECT_CALL(callback_target(), CallTarget(true, kSocketFD));
299  EXPECT_TRUE(async_connection().Start(ipv4_address(), kConnectPort));
300  ExpectReset();
301}
302
303TEST_F(AsyncConnectionTest, SynchronousSuccessIpv6) {
304  EXPECT_CALL(sockets(), Socket(_, _, _))
305      .WillOnce(Return(kSocketFD));
306  EXPECT_CALL(sockets(), SetNonBlocking(kSocketFD))
307      .WillOnce(Return(0));
308  EXPECT_CALL(sockets(), BindToDevice(kSocketFD, StrEq(kInterfaceName)))
309      .WillOnce(Return(0));
310  EXPECT_CALL(sockets(), Connect(kSocketFD,
311                                  IsSocketIpv6Address(ipv6_address(),
312                                                      kConnectPort),
313                                  sizeof(struct sockaddr_in6)))
314      .WillOnce(Return(0));
315  EXPECT_CALL(callback_target(), CallTarget(true, kSocketFD));
316  EXPECT_TRUE(async_connection().Start(ipv6_address(), kConnectPort));
317  ExpectReset();
318}
319
320TEST_F(AsyncConnectionTest, FreeOnSuccessCallback) {
321  StartConnection();
322  EXPECT_CALL(sockets(), GetSocketError(kSocketFD))
323      .WillOnce(Return(0));
324  EXPECT_CALL(callback_target(), CallTarget(true, kSocketFD))
325      .WillOnce(Invoke(this, &AsyncConnectionTest::InvokeFreeConnection));
326  OnConnectCompletion(kSocketFD);
327}
328
329TEST_F(AsyncConnectionTest, FreeOnFailureCallback) {
330  StartConnection();
331  EXPECT_CALL(sockets(), GetSocketError(kSocketFD))
332      .WillOnce(Return(1));
333  EXPECT_CALL(callback_target(), CallTarget(false, -1))
334      .WillOnce(Invoke(this, &AsyncConnectionTest::InvokeFreeConnection));
335  EXPECT_CALL(sockets(), Error())
336      .WillOnce(Return(kErrorNumber));
337  EXPECT_CALL(sockets(), Close(kSocketFD))
338      .WillOnce(Return(0));
339  OnConnectCompletion(kSocketFD);
340}
341
342}  // namespace shill
343