1// Copyright (c) 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/socket/tcp_client_socket.h"
6
7#include "base/callback_helpers.h"
8#include "base/logging.h"
9#include "net/base/io_buffer.h"
10#include "net/base/ip_endpoint.h"
11#include "net/base/net_errors.h"
12#include "net/base/net_util.h"
13
14namespace net {
15
16TCPClientSocket::TCPClientSocket(const AddressList& addresses,
17                                 net::NetLog* net_log,
18                                 const net::NetLog::Source& source)
19    : socket_(new TCPSocket(net_log, source)),
20      addresses_(addresses),
21      current_address_index_(-1),
22      next_connect_state_(CONNECT_STATE_NONE),
23      previously_disconnected_(false) {
24}
25
26TCPClientSocket::TCPClientSocket(scoped_ptr<TCPSocket> connected_socket,
27                                 const IPEndPoint& peer_address)
28    : socket_(connected_socket.Pass()),
29      addresses_(AddressList(peer_address)),
30      current_address_index_(0),
31      next_connect_state_(CONNECT_STATE_NONE),
32      previously_disconnected_(false) {
33  DCHECK(socket_);
34
35  socket_->SetDefaultOptionsForClient();
36  use_history_.set_was_ever_connected();
37}
38
39TCPClientSocket::~TCPClientSocket() {
40}
41
42int TCPClientSocket::Bind(const IPEndPoint& address) {
43  if (current_address_index_ >= 0 || bind_address_) {
44    // Cannot bind the socket if we are already connected or connecting.
45    NOTREACHED();
46    return ERR_UNEXPECTED;
47  }
48
49  int result = OK;
50  if (!socket_->IsValid()) {
51    result = OpenSocket(address.GetFamily());
52    if (result != OK)
53      return result;
54  }
55
56  result = socket_->Bind(address);
57  if (result != OK)
58    return result;
59
60  bind_address_.reset(new IPEndPoint(address));
61  return OK;
62}
63
64int TCPClientSocket::Connect(const CompletionCallback& callback) {
65  DCHECK(!callback.is_null());
66
67  // If connecting or already connected, then just return OK.
68  if (socket_->IsValid() && current_address_index_ >= 0)
69    return OK;
70
71  socket_->StartLoggingMultipleConnectAttempts(addresses_);
72
73  // We will try to connect to each address in addresses_. Start with the
74  // first one in the list.
75  next_connect_state_ = CONNECT_STATE_CONNECT;
76  current_address_index_ = 0;
77
78  int rv = DoConnectLoop(OK);
79  if (rv == ERR_IO_PENDING) {
80    connect_callback_ = callback;
81  } else {
82    socket_->EndLoggingMultipleConnectAttempts(rv);
83  }
84
85  return rv;
86}
87
88int TCPClientSocket::DoConnectLoop(int result) {
89  DCHECK_NE(next_connect_state_, CONNECT_STATE_NONE);
90
91  int rv = result;
92  do {
93    ConnectState state = next_connect_state_;
94    next_connect_state_ = CONNECT_STATE_NONE;
95    switch (state) {
96      case CONNECT_STATE_CONNECT:
97        DCHECK_EQ(OK, rv);
98        rv = DoConnect();
99        break;
100      case CONNECT_STATE_CONNECT_COMPLETE:
101        rv = DoConnectComplete(rv);
102        break;
103      default:
104        NOTREACHED() << "bad state " << state;
105        rv = ERR_UNEXPECTED;
106        break;
107    }
108  } while (rv != ERR_IO_PENDING && next_connect_state_ != CONNECT_STATE_NONE);
109
110  return rv;
111}
112
113int TCPClientSocket::DoConnect() {
114  DCHECK_GE(current_address_index_, 0);
115  DCHECK_LT(current_address_index_, static_cast<int>(addresses_.size()));
116
117  const IPEndPoint& endpoint = addresses_[current_address_index_];
118
119  if (previously_disconnected_) {
120    use_history_.Reset();
121    previously_disconnected_ = false;
122  }
123
124  next_connect_state_ = CONNECT_STATE_CONNECT_COMPLETE;
125
126  if (socket_->IsValid()) {
127    DCHECK(bind_address_);
128  } else {
129    int result = OpenSocket(endpoint.GetFamily());
130    if (result != OK)
131      return result;
132
133    if (bind_address_) {
134      result = socket_->Bind(*bind_address_);
135      if (result != OK) {
136        socket_->Close();
137        return result;
138      }
139    }
140  }
141
142  // |socket_| is owned by this class and the callback won't be run once
143  // |socket_| is gone. Therefore, it is safe to use base::Unretained() here.
144  return socket_->Connect(endpoint,
145                          base::Bind(&TCPClientSocket::DidCompleteConnect,
146                                     base::Unretained(this)));
147}
148
149int TCPClientSocket::DoConnectComplete(int result) {
150  if (result == OK) {
151    use_history_.set_was_ever_connected();
152    return OK;  // Done!
153  }
154
155  // Close whatever partially connected socket we currently have.
156  DoDisconnect();
157
158  // Try to fall back to the next address in the list.
159  if (current_address_index_ + 1 < static_cast<int>(addresses_.size())) {
160    next_connect_state_ = CONNECT_STATE_CONNECT;
161    ++current_address_index_;
162    return OK;
163  }
164
165  // Otherwise there is nothing to fall back to, so give up.
166  return result;
167}
168
169void TCPClientSocket::Disconnect() {
170  DoDisconnect();
171  current_address_index_ = -1;
172  bind_address_.reset();
173}
174
175void TCPClientSocket::DoDisconnect() {
176  // If connecting or already connected, record that the socket has been
177  // disconnected.
178  previously_disconnected_ = socket_->IsValid() && current_address_index_ >= 0;
179  socket_->Close();
180}
181
182bool TCPClientSocket::IsConnected() const {
183  return socket_->IsConnected();
184}
185
186bool TCPClientSocket::IsConnectedAndIdle() const {
187  return socket_->IsConnectedAndIdle();
188}
189
190int TCPClientSocket::GetPeerAddress(IPEndPoint* address) const {
191  return socket_->GetPeerAddress(address);
192}
193
194int TCPClientSocket::GetLocalAddress(IPEndPoint* address) const {
195  DCHECK(address);
196
197  if (!socket_->IsValid()) {
198    if (bind_address_) {
199      *address = *bind_address_;
200      return OK;
201    }
202    return ERR_SOCKET_NOT_CONNECTED;
203  }
204
205  return socket_->GetLocalAddress(address);
206}
207
208const BoundNetLog& TCPClientSocket::NetLog() const {
209  return socket_->net_log();
210}
211
// Marks this socket's usage history as created for subresource speculation.
void TCPClientSocket::SetSubresourceSpeculation() {
  use_history_.set_subresource_speculation();
}
215
// Marks this socket's usage history as created for omnibox speculation.
void TCPClientSocket::SetOmniboxSpeculation() {
  use_history_.set_omnibox_speculation();
}
219
220bool TCPClientSocket::WasEverUsed() const {
221  return use_history_.was_used_to_convey_data();
222}
223
224bool TCPClientSocket::UsingTCPFastOpen() const {
225  return socket_->UsingTCPFastOpen();
226}
227
// Asks the underlying TCPSocket to enable TCP Fast Open. Per the method name
// this only takes effect where supported; that decision is made by TCPSocket.
void TCPClientSocket::EnableTCPFastOpenIfSupported() {
  socket_->EnableTCPFastOpenIfSupported();
}
231
232bool TCPClientSocket::WasNpnNegotiated() const {
233  return false;
234}
235
// A plain TCP client socket never negotiates a protocol, so this always
// reports kProtoUnknown.
NextProto TCPClientSocket::GetNegotiatedProtocol() const {
  return kProtoUnknown;
}
239
240bool TCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
241  return false;
242}
243
244int TCPClientSocket::Read(IOBuffer* buf,
245                          int buf_len,
246                          const CompletionCallback& callback) {
247  DCHECK(!callback.is_null());
248
249  // |socket_| is owned by this class and the callback won't be run once
250  // |socket_| is gone. Therefore, it is safe to use base::Unretained() here.
251  CompletionCallback read_callback = base::Bind(
252      &TCPClientSocket::DidCompleteReadWrite, base::Unretained(this), callback);
253  int result = socket_->Read(buf, buf_len, read_callback);
254  if (result > 0)
255    use_history_.set_was_used_to_convey_data();
256
257  return result;
258}
259
260int TCPClientSocket::Write(IOBuffer* buf,
261                           int buf_len,
262                           const CompletionCallback& callback) {
263  DCHECK(!callback.is_null());
264
265  // |socket_| is owned by this class and the callback won't be run once
266  // |socket_| is gone. Therefore, it is safe to use base::Unretained() here.
267  CompletionCallback write_callback = base::Bind(
268      &TCPClientSocket::DidCompleteReadWrite, base::Unretained(this), callback);
269  int result = socket_->Write(buf, buf_len, write_callback);
270  if (result > 0)
271    use_history_.set_was_used_to_convey_data();
272
273  return result;
274}
275
276int TCPClientSocket::SetReceiveBufferSize(int32 size) {
277  return socket_->SetReceiveBufferSize(size);
278}
279
280int TCPClientSocket::SetSendBufferSize(int32 size) {
281    return socket_->SetSendBufferSize(size);
282}
283
284bool TCPClientSocket::SetKeepAlive(bool enable, int delay) {
285  return socket_->SetKeepAlive(enable, delay);
286}
287
288bool TCPClientSocket::SetNoDelay(bool no_delay) {
289  return socket_->SetNoDelay(no_delay);
290}
291
292void TCPClientSocket::DidCompleteConnect(int result) {
293  DCHECK_EQ(next_connect_state_, CONNECT_STATE_CONNECT_COMPLETE);
294  DCHECK_NE(result, ERR_IO_PENDING);
295  DCHECK(!connect_callback_.is_null());
296
297  result = DoConnectLoop(result);
298  if (result != ERR_IO_PENDING) {
299    socket_->EndLoggingMultipleConnectAttempts(result);
300    base::ResetAndReturn(&connect_callback_).Run(result);
301  }
302}
303
304void TCPClientSocket::DidCompleteReadWrite(const CompletionCallback& callback,
305                                           int result) {
306  if (result > 0)
307    use_history_.set_was_used_to_convey_data();
308
309  callback.Run(result);
310}
311
312int TCPClientSocket::OpenSocket(AddressFamily family) {
313  DCHECK(!socket_->IsValid());
314
315  int result = socket_->Open(family);
316  if (result != OK)
317    return result;
318
319  socket_->SetDefaultOptionsForClient();
320
321  return OK;
322}
323
324}  // namespace net
325