async_connection.cc revision c0beca55d290fe0b1c96d78cbbcf94b05c23f5a5
1//
2// Copyright (C) 2012 The Android Open Source Project
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//      http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15//
16
17#include "shill/async_connection.h"
18
19#include <base/bind.h>
20#include <errno.h>
21#include <netinet/in.h>
22
23#include <string>
24
25#include "shill/event_dispatcher.h"
26#include "shill/net/ip_address.h"
27#include "shill/net/sockets.h"
28
29using base::Bind;
30using base::Callback;
31using base::Unretained;
32using std::string;
33
34namespace shill {
35
36AsyncConnection::AsyncConnection(const string& interface_name,
37                                 EventDispatcher* dispatcher,
38                                 Sockets* sockets,
39                                 const Callback<void(bool, int)>& callback)
40    : interface_name_(interface_name),
41      dispatcher_(dispatcher),
42      sockets_(sockets),
43      callback_(callback),
44      fd_(-1),
45      connect_completion_callback_(
46          Bind(&AsyncConnection::OnConnectCompletion, Unretained(this))) { }
47
48AsyncConnection::~AsyncConnection() {
49  Stop();
50}
51
52bool AsyncConnection::Start(const IPAddress& address, int port) {
53  DCHECK_LT(fd_, 0);
54
55  int family = PF_INET;
56  if (address.family() == IPAddress::kFamilyIPv6) {
57    family = PF_INET6;
58  }
59  fd_ = sockets_->Socket(family, SOCK_STREAM, 0);
60  if (fd_ < 0 ||
61      sockets_->SetNonBlocking(fd_) < 0) {
62    error_ = sockets_->ErrorString();
63    PLOG(ERROR) << "Async socket setup failed";
64    Stop();
65    return false;
66  }
67
68  if (!interface_name_.empty() &&
69      sockets_->BindToDevice(fd_, interface_name_) < 0) {
70    error_ = sockets_->ErrorString();
71    PLOG(ERROR) << "Async socket failed to bind to device";
72    Stop();
73    return false;
74  }
75
76  int ret = ConnectTo(address, port);
77  if (ret == 0) {
78    callback_.Run(true, fd_);  // Passes ownership
79    fd_ = -1;
80    return true;
81  }
82
83  if (sockets_->Error() != EINPROGRESS) {
84    error_ = sockets_->ErrorString();
85    PLOG(ERROR) << "Async socket connection failed";
86    Stop();
87    return false;
88  }
89
90  connect_completion_handler_.reset(
91      dispatcher_->CreateReadyHandler(fd_,
92                                      IOHandler::kModeOutput,
93                                      connect_completion_callback_));
94  error_ = string();
95
96  return true;
97}
98
99void AsyncConnection::Stop() {
100  connect_completion_handler_.reset();
101  if (fd_ >= 0) {
102    sockets_->Close(fd_);
103    fd_ = -1;
104  }
105}
106
107void AsyncConnection::OnConnectCompletion(int fd) {
108  CHECK_EQ(fd_, fd);
109  bool success = false;
110  int returned_fd = -1;
111
112  if (sockets_->GetSocketError(fd_) != 0) {
113    error_ = sockets_->ErrorString();
114    PLOG(ERROR) << "Async GetSocketError returns failure";
115  } else {
116    returned_fd = fd_;
117    fd_ = -1;
118    success = true;
119  }
120  Stop();
121
122  // Run the callback last, since it may end up freeing this instance.
123  callback_.Run(success, returned_fd);  // Passes ownership
124}
125
126int AsyncConnection::ConnectTo(const IPAddress& address, int port) {
127  struct sockaddr* sock_addr = nullptr;
128  socklen_t addr_len = 0;
129  struct sockaddr_in iaddr;
130  struct sockaddr_in6 iaddr6;
131  if (address.family() == IPAddress::kFamilyIPv4) {
132    CHECK_EQ(sizeof(iaddr.sin_addr.s_addr), address.GetLength());
133
134    memset(&iaddr, 0, sizeof(iaddr));
135    iaddr.sin_family = AF_INET;
136    memcpy(&iaddr.sin_addr.s_addr, address.address().GetConstData(),
137           sizeof(iaddr.sin_addr.s_addr));
138    iaddr.sin_port = htons(port);
139
140    sock_addr = reinterpret_cast<struct sockaddr*>(&iaddr);
141    addr_len = sizeof(iaddr);
142  } else if (address.family() == IPAddress::kFamilyIPv6) {
143    CHECK_EQ(sizeof(iaddr6.sin6_addr.s6_addr), address.GetLength());
144
145    memset(&iaddr6, 0, sizeof(iaddr6));
146    iaddr6.sin6_family = AF_INET6;
147    memcpy(&iaddr6.sin6_addr.s6_addr, address.address().GetConstData(),
148           sizeof(iaddr6.sin6_addr.s6_addr));
149    iaddr6.sin6_port = htons(port);
150
151    sock_addr = reinterpret_cast<struct sockaddr*>(&iaddr6);
152    addr_len = sizeof(iaddr6);
153  } else {
154    NOTREACHED();
155  }
156
157  return sockets_->Connect(fd_, sock_addr, addr_len);
158}
159
160}  // namespace shill
161