1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "jingle/glue/chrome_async_socket.h"
6
7#include <algorithm>
8#include <cstdlib>
9#include <cstring>
10
11#include "base/basictypes.h"
12#include "base/bind.h"
13#include "base/compiler_specific.h"
14#include "base/logging.h"
15#include "base/message_loop/message_loop.h"
16#include "jingle/glue/resolving_client_socket_factory.h"
17#include "net/base/address_list.h"
18#include "net/base/host_port_pair.h"
19#include "net/base/io_buffer.h"
20#include "net/base/net_util.h"
21#include "net/socket/client_socket_handle.h"
22#include "net/socket/ssl_client_socket.h"
23#include "net/socket/tcp_client_socket.h"
24#include "net/ssl/ssl_config_service.h"
25#include "third_party/libjingle/source/talk/base/socketaddress.h"
26
27namespace jingle_glue {
28
29ChromeAsyncSocket::ChromeAsyncSocket(
30    ResolvingClientSocketFactory* resolving_client_socket_factory,
31    size_t read_buf_size,
32    size_t write_buf_size)
33    : resolving_client_socket_factory_(resolving_client_socket_factory),
34      state_(STATE_CLOSED),
35      error_(ERROR_NONE),
36      net_error_(net::OK),
37      read_state_(IDLE),
38      read_buf_(new net::IOBufferWithSize(read_buf_size)),
39      read_start_(0U),
40      read_end_(0U),
41      write_state_(IDLE),
42      write_buf_(new net::IOBufferWithSize(write_buf_size)),
43      write_end_(0U),
44      weak_ptr_factory_(this) {
45  DCHECK(resolving_client_socket_factory_.get());
46  DCHECK_GT(read_buf_size, 0U);
47  DCHECK_GT(write_buf_size, 0U);
48}
49
50ChromeAsyncSocket::~ChromeAsyncSocket() {}
51
52ChromeAsyncSocket::State ChromeAsyncSocket::state() {
53  return state_;
54}
55
56ChromeAsyncSocket::Error ChromeAsyncSocket::error() {
57  return error_;
58}
59
60int ChromeAsyncSocket::GetError() {
61  return net_error_;
62}
63
64bool ChromeAsyncSocket::IsOpen() const {
65  return (state_ == STATE_OPEN) || (state_ == STATE_TLS_OPEN);
66}
67
68void ChromeAsyncSocket::DoNonNetError(Error error) {
69  DCHECK_NE(error, ERROR_NONE);
70  DCHECK_NE(error, ERROR_WINSOCK);
71  error_ = error;
72  net_error_ = net::OK;
73}
74
75void ChromeAsyncSocket::DoNetError(net::Error net_error) {
76  error_ = ERROR_WINSOCK;
77  net_error_ = net_error;
78}
79
80void ChromeAsyncSocket::DoNetErrorFromStatus(int status) {
81  DCHECK_LT(status, net::OK);
82  DoNetError(static_cast<net::Error>(status));
83}
84
85// STATE_CLOSED -> STATE_CONNECTING
86
87bool ChromeAsyncSocket::Connect(const talk_base::SocketAddress& address) {
88  if (state_ != STATE_CLOSED) {
89    LOG(DFATAL) << "Connect() called on non-closed socket";
90    DoNonNetError(ERROR_WRONGSTATE);
91    return false;
92  }
93  if (address.hostname().empty() || address.port() == 0) {
94    DoNonNetError(ERROR_DNS);
95    return false;
96  }
97
98  DCHECK_EQ(state_, buzz::AsyncSocket::STATE_CLOSED);
99  DCHECK_EQ(read_state_, IDLE);
100  DCHECK_EQ(write_state_, IDLE);
101
102  state_ = STATE_CONNECTING;
103
104  DCHECK(!weak_ptr_factory_.HasWeakPtrs());
105  weak_ptr_factory_.InvalidateWeakPtrs();
106
107  net::HostPortPair dest_host_port_pair(address.hostname(), address.port());
108
109  transport_socket_ =
110      resolving_client_socket_factory_->CreateTransportClientSocket(
111          dest_host_port_pair);
112  int status = transport_socket_->Connect(
113      base::Bind(&ChromeAsyncSocket::ProcessConnectDone,
114                 weak_ptr_factory_.GetWeakPtr()));
115  if (status != net::ERR_IO_PENDING) {
116    // We defer execution of ProcessConnectDone instead of calling it
117    // directly here as the caller may not expect an error/close to
118    // happen here.  This is okay, as from the caller's point of view,
119    // the connect always happens asynchronously.
120    base::MessageLoop* message_loop = base::MessageLoop::current();
121    CHECK(message_loop);
122    message_loop->PostTask(
123        FROM_HERE,
124        base::Bind(&ChromeAsyncSocket::ProcessConnectDone,
125                   weak_ptr_factory_.GetWeakPtr(), status));
126  }
127  return true;
128}
129
130// STATE_CONNECTING -> STATE_OPEN
131// read_state_ == IDLE -> read_state_ == POSTED (via PostDoRead())
132
133void ChromeAsyncSocket::ProcessConnectDone(int status) {
134  DCHECK_NE(status, net::ERR_IO_PENDING);
135  DCHECK_EQ(read_state_, IDLE);
136  DCHECK_EQ(write_state_, IDLE);
137  DCHECK_EQ(state_, STATE_CONNECTING);
138  if (status != net::OK) {
139    DoNetErrorFromStatus(status);
140    DoClose();
141    return;
142  }
143  state_ = STATE_OPEN;
144  PostDoRead();
145  // Write buffer should be empty.
146  DCHECK_EQ(write_end_, 0U);
147  SignalConnected();
148}
149
150// read_state_ == IDLE -> read_state_ == POSTED
151
152void ChromeAsyncSocket::PostDoRead() {
153  DCHECK(IsOpen());
154  DCHECK_EQ(read_state_, IDLE);
155  DCHECK_EQ(read_start_, 0U);
156  DCHECK_EQ(read_end_, 0U);
157  base::MessageLoop* message_loop = base::MessageLoop::current();
158  CHECK(message_loop);
159  message_loop->PostTask(
160      FROM_HERE,
161      base::Bind(&ChromeAsyncSocket::DoRead,
162                 weak_ptr_factory_.GetWeakPtr()));
163  read_state_ = POSTED;
164}
165
166// read_state_ == POSTED -> read_state_ == PENDING
167
168void ChromeAsyncSocket::DoRead() {
169  DCHECK(IsOpen());
170  DCHECK_EQ(read_state_, POSTED);
171  DCHECK_EQ(read_start_, 0U);
172  DCHECK_EQ(read_end_, 0U);
173  // Once we call Read(), we cannot call StartTls() until the read
174  // finishes.  This is okay, as StartTls() is called only from a read
175  // handler (i.e., after a read finishes and before another read is
176  // done).
177  int status =
178      transport_socket_->Read(
179          read_buf_.get(), read_buf_->size(),
180          base::Bind(&ChromeAsyncSocket::ProcessReadDone,
181                     weak_ptr_factory_.GetWeakPtr()));
182  read_state_ = PENDING;
183  if (status != net::ERR_IO_PENDING) {
184    ProcessReadDone(status);
185  }
186}
187
188// read_state_ == PENDING -> read_state_ == IDLE
189
190void ChromeAsyncSocket::ProcessReadDone(int status) {
191  DCHECK_NE(status, net::ERR_IO_PENDING);
192  DCHECK(IsOpen());
193  DCHECK_EQ(read_state_, PENDING);
194  DCHECK_EQ(read_start_, 0U);
195  DCHECK_EQ(read_end_, 0U);
196  read_state_ = IDLE;
197  if (status > 0) {
198    read_end_ = static_cast<size_t>(status);
199    SignalRead();
200  } else if (status == 0) {
201    // Other side closed the connection.
202    error_ = ERROR_NONE;
203    net_error_ = net::OK;
204    DoClose();
205  } else {  // status < 0
206    DoNetErrorFromStatus(status);
207    DoClose();
208  }
209}
210
211// (maybe) read_state_ == IDLE -> read_state_ == POSTED (via
212// PostDoRead())
213
214bool ChromeAsyncSocket::Read(char* data, size_t len, size_t* len_read) {
215  if (!IsOpen() && (state_ != STATE_TLS_CONNECTING)) {
216    // Read() may be called on a closed socket if a previous read
217    // causes a socket close (e.g., client sends wrong password and
218    // server terminates connection).
219    //
220    // TODO(akalin): Fix handling of this on the libjingle side.
221    if (state_ != STATE_CLOSED) {
222      LOG(DFATAL) << "Read() called on non-open non-tls-connecting socket";
223    }
224    DoNonNetError(ERROR_WRONGSTATE);
225    return false;
226  }
227  DCHECK_LE(read_start_, read_end_);
228  if ((state_ == STATE_TLS_CONNECTING) || read_end_ == 0U) {
229    if (state_ == STATE_TLS_CONNECTING) {
230      DCHECK_EQ(read_state_, IDLE);
231      DCHECK_EQ(read_end_, 0U);
232    } else {
233      DCHECK_NE(read_state_, IDLE);
234    }
235    *len_read = 0;
236    return true;
237  }
238  DCHECK_EQ(read_state_, IDLE);
239  *len_read = std::min(len, read_end_ - read_start_);
240  DCHECK_GT(*len_read, 0U);
241  std::memcpy(data, read_buf_->data() + read_start_, *len_read);
242  read_start_ += *len_read;
243  if (read_start_ == read_end_) {
244    read_start_ = 0U;
245    read_end_ = 0U;
246    // We defer execution of DoRead() here for similar reasons as
247    // ProcessConnectDone().
248    PostDoRead();
249  }
250  return true;
251}
252
253// (maybe) write_state_ == IDLE -> write_state_ == POSTED (via
254// PostDoWrite())
255
256bool ChromeAsyncSocket::Write(const char* data, size_t len) {
257  if (!IsOpen() && (state_ != STATE_TLS_CONNECTING)) {
258    LOG(DFATAL) << "Write() called on non-open non-tls-connecting socket";
259    DoNonNetError(ERROR_WRONGSTATE);
260    return false;
261  }
262  // TODO(akalin): Avoid this check by modifying the interface to have
263  // a "ready for writing" signal.
264  if ((static_cast<size_t>(write_buf_->size()) - write_end_) < len) {
265    LOG(DFATAL) << "queueing " << len << " bytes would exceed the "
266                << "max write buffer size = " << write_buf_->size()
267                << " by " << (len - write_buf_->size()) << " bytes";
268    DoNetError(net::ERR_INSUFFICIENT_RESOURCES);
269    return false;
270  }
271  std::memcpy(write_buf_->data() + write_end_, data, len);
272  write_end_ += len;
273  // If we're TLS-connecting, the write buffer will get flushed once
274  // the TLS-connect finishes.  Otherwise, start writing if we're not
275  // already writing and we have something to write.
276  if ((state_ != STATE_TLS_CONNECTING) &&
277      (write_state_ == IDLE) && (write_end_ > 0U)) {
278    // We defer execution of DoWrite() here for similar reasons as
279    // ProcessConnectDone().
280    PostDoWrite();
281  }
282  return true;
283}
284
285// write_state_ == IDLE -> write_state_ == POSTED
286
287void ChromeAsyncSocket::PostDoWrite() {
288  DCHECK(IsOpen());
289  DCHECK_EQ(write_state_, IDLE);
290  DCHECK_GT(write_end_, 0U);
291  base::MessageLoop* message_loop = base::MessageLoop::current();
292  CHECK(message_loop);
293  message_loop->PostTask(
294      FROM_HERE,
295      base::Bind(&ChromeAsyncSocket::DoWrite,
296                 weak_ptr_factory_.GetWeakPtr()));
297  write_state_ = POSTED;
298}
299
300// write_state_ == POSTED -> write_state_ == PENDING
301
302void ChromeAsyncSocket::DoWrite() {
303  DCHECK(IsOpen());
304  DCHECK_EQ(write_state_, POSTED);
305  DCHECK_GT(write_end_, 0U);
306  // Once we call Write(), we cannot call StartTls() until the write
307  // finishes.  This is okay, as StartTls() is called only after we
308  // have received a reply to a message we sent to the server and
309  // before we send the next message.
310  int status =
311      transport_socket_->Write(
312          write_buf_.get(), write_end_,
313          base::Bind(&ChromeAsyncSocket::ProcessWriteDone,
314                     weak_ptr_factory_.GetWeakPtr()));
315  write_state_ = PENDING;
316  if (status != net::ERR_IO_PENDING) {
317    ProcessWriteDone(status);
318  }
319}
320
321// write_state_ == PENDING -> write_state_ == IDLE or POSTED (the
322// latter via PostDoWrite())
323
324void ChromeAsyncSocket::ProcessWriteDone(int status) {
325  DCHECK_NE(status, net::ERR_IO_PENDING);
326  DCHECK(IsOpen());
327  DCHECK_EQ(write_state_, PENDING);
328  DCHECK_GT(write_end_, 0U);
329  write_state_ = IDLE;
330  if (status < net::OK) {
331    DoNetErrorFromStatus(status);
332    DoClose();
333    return;
334  }
335  size_t written = static_cast<size_t>(status);
336  if (written > write_end_) {
337    LOG(DFATAL) << "bytes written = " << written
338                << " exceeds bytes requested = " << write_end_;
339    DoNetError(net::ERR_UNEXPECTED);
340    DoClose();
341    return;
342  }
343  // TODO(akalin): Figure out a better way to do this; perhaps a queue
344  // of DrainableIOBuffers.  This'll also allow us to not have an
345  // artificial buffer size limit.
346  std::memmove(write_buf_->data(),
347               write_buf_->data() + written,
348               write_end_ - written);
349  write_end_ -= written;
350  if (write_end_ > 0U) {
351    PostDoWrite();
352  }
353}
354
355// * -> STATE_CLOSED
356
357bool ChromeAsyncSocket::Close() {
358  DoClose();
359  return true;
360}
361
362// (not STATE_CLOSED) -> STATE_CLOSED
363
364void ChromeAsyncSocket::DoClose() {
365  weak_ptr_factory_.InvalidateWeakPtrs();
366  if (transport_socket_.get()) {
367    transport_socket_->Disconnect();
368  }
369  transport_socket_.reset();
370  read_state_ = IDLE;
371  read_start_ = 0U;
372  read_end_ = 0U;
373  write_state_ = IDLE;
374  write_end_ = 0U;
375  if (state_ != STATE_CLOSED) {
376    state_ = STATE_CLOSED;
377    SignalClosed();
378  }
379  // Reset error variables after SignalClosed() so slots connected
380  // to it can read it.
381  error_ = ERROR_NONE;
382  net_error_ = net::OK;
383}
384
385// STATE_OPEN -> STATE_TLS_CONNECTING
386
387bool ChromeAsyncSocket::StartTls(const std::string& domain_name) {
388  if ((state_ != STATE_OPEN) || (read_state_ == PENDING) ||
389      (write_state_ != IDLE)) {
390    LOG(DFATAL) << "StartTls() called in wrong state";
391    DoNonNetError(ERROR_WRONGSTATE);
392    return false;
393  }
394
395  state_ = STATE_TLS_CONNECTING;
396  read_state_ = IDLE;
397  read_start_ = 0U;
398  read_end_ = 0U;
399  DCHECK_EQ(write_end_, 0U);
400
401  // Clear out any posted DoRead() tasks.
402  weak_ptr_factory_.InvalidateWeakPtrs();
403
404  DCHECK(transport_socket_.get());
405  scoped_ptr<net::ClientSocketHandle> socket_handle(
406      new net::ClientSocketHandle());
407  socket_handle->SetSocket(transport_socket_.Pass());
408  transport_socket_ =
409      resolving_client_socket_factory_->CreateSSLClientSocket(
410          socket_handle.Pass(), net::HostPortPair(domain_name, 443));
411  int status = transport_socket_->Connect(
412      base::Bind(&ChromeAsyncSocket::ProcessSSLConnectDone,
413                 weak_ptr_factory_.GetWeakPtr()));
414  if (status != net::ERR_IO_PENDING) {
415    base::MessageLoop* message_loop = base::MessageLoop::current();
416    CHECK(message_loop);
417    message_loop->PostTask(
418        FROM_HERE,
419        base::Bind(&ChromeAsyncSocket::ProcessSSLConnectDone,
420                   weak_ptr_factory_.GetWeakPtr(), status));
421  }
422  return true;
423}
424
425// STATE_TLS_CONNECTING -> STATE_TLS_OPEN
426// read_state_ == IDLE -> read_state_ == POSTED (via PostDoRead())
427// (maybe) write_state_ == IDLE -> write_state_ == POSTED (via
428// PostDoWrite())
429
430void ChromeAsyncSocket::ProcessSSLConnectDone(int status) {
431  DCHECK_NE(status, net::ERR_IO_PENDING);
432  DCHECK_EQ(state_, STATE_TLS_CONNECTING);
433  DCHECK_EQ(read_state_, IDLE);
434  DCHECK_EQ(read_start_, 0U);
435  DCHECK_EQ(read_end_, 0U);
436  DCHECK_EQ(write_state_, IDLE);
437  if (status != net::OK) {
438    DoNetErrorFromStatus(status);
439    DoClose();
440    return;
441  }
442  state_ = STATE_TLS_OPEN;
443  PostDoRead();
444  if (write_end_ > 0U) {
445    PostDoWrite();
446  }
447  SignalSSLConnected();
448}
449
450}  // namespace jingle_glue
451