1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "extensions/browser/api/cast_channel/cast_socket.h"
6
7#include <stdlib.h>
8#include <string.h>
9
10#include "base/bind.h"
11#include "base/callback_helpers.h"
12#include "base/format_macros.h"
13#include "base/lazy_instance.h"
14#include "base/numerics/safe_conversions.h"
15#include "base/strings/string_number_conversions.h"
16#include "base/strings/stringprintf.h"
17#include "base/sys_byteorder.h"
18#include "extensions/browser/api/cast_channel/cast_auth_util.h"
19#include "extensions/browser/api/cast_channel/cast_framer.h"
20#include "extensions/browser/api/cast_channel/cast_message_util.h"
21#include "extensions/browser/api/cast_channel/logger.h"
22#include "extensions/browser/api/cast_channel/logger_util.h"
23#include "extensions/common/api/cast_channel/cast_channel.pb.h"
24#include "net/base/address_list.h"
25#include "net/base/host_port_pair.h"
26#include "net/base/net_errors.h"
27#include "net/base/net_util.h"
28#include "net/cert/cert_verifier.h"
29#include "net/cert/x509_certificate.h"
30#include "net/http/transport_security_state.h"
31#include "net/socket/client_socket_factory.h"
32#include "net/socket/client_socket_handle.h"
33#include "net/socket/ssl_client_socket.h"
34#include "net/socket/stream_socket.h"
35#include "net/socket/tcp_client_socket.h"
36#include "net/ssl/ssl_config_service.h"
37#include "net/ssl/ssl_info.h"
38
// Logging helper: expands to VLOG(level) with the message prefixed by
// "[<ip_endpoint_>, auth=<channel_auth_>] " so each line identifies the
// connection it belongs to.
// Assumes |ip_endpoint_| of type net::IPEndPoint and |channel_auth_| of enum
// type ChannelAuthType are available in the current scope (in practice: inside
// CastSocket member functions).  The macro is #undef'd at the end of this file.
#define VLOG_WITH_CONNECTION(level) VLOG(level) << "[" << \
    ip_endpoint_.ToString() << ", auth=" << channel_auth_ << "] "
43
namespace {

// The default keepalive delay.  On Linux, keepalive probes will be sent after
// the socket is idle for this length of time, and the socket will be closed
// after 9 failed probes.  So the total idle time before close is 10 *
// kTcpKeepAliveDelaySecs.  Passed as the delay argument to
// TCPClientSocket::SetKeepAlive() once the TCP connect succeeds.
const int kTcpKeepAliveDelaySecs = 10;
}  // namespace
52
53namespace extensions {
54
// Lazily-initialized factory for the ApiResourceManager that tracks
// CastSocket resources.
static base::LazyInstance<BrowserContextKeyedAPIFactory<
    ApiResourceManager<core_api::cast_channel::CastSocket> > > g_factory =
    LAZY_INSTANCE_INITIALIZER;

// static
// Explicit specialization for CastSocket: returns the lazily-created
// |g_factory| instance.
template <>
BrowserContextKeyedAPIFactory<
    ApiResourceManager<core_api::cast_channel::CastSocket> >*
ApiResourceManager<core_api::cast_channel::CastSocket>::GetFactoryInstance() {
  return g_factory.Pointer();
}
66
67namespace core_api {
68namespace cast_channel {
69CastSocket::CastSocket(const std::string& owner_extension_id,
70                       const net::IPEndPoint& ip_endpoint,
71                       ChannelAuthType channel_auth,
72                       CastSocket::Delegate* delegate,
73                       net::NetLog* net_log,
74                       const base::TimeDelta& timeout,
75                       const scoped_refptr<Logger>& logger)
76    : ApiResource(owner_extension_id),
77      channel_id_(0),
78      ip_endpoint_(ip_endpoint),
79      channel_auth_(channel_auth),
80      delegate_(delegate),
81      net_log_(net_log),
82      logger_(logger),
83      connect_timeout_(timeout),
84      connect_timeout_timer_(new base::OneShotTimer<CastSocket>),
85      is_canceled_(false),
86      connect_state_(proto::CONN_STATE_NONE),
87      write_state_(proto::WRITE_STATE_NONE),
88      read_state_(proto::READ_STATE_NONE),
89      error_state_(CHANNEL_ERROR_NONE),
90      ready_state_(READY_STATE_NONE) {
91  DCHECK(net_log_);
92  DCHECK(channel_auth_ == CHANNEL_AUTH_TYPE_SSL ||
93         channel_auth_ == CHANNEL_AUTH_TYPE_SSL_VERIFIED);
94  net_log_source_.type = net::NetLog::SOURCE_SOCKET;
95  net_log_source_.id = net_log_->NextID();
96
97  // Buffer is reused across messages.
98  read_buffer_ = new net::GrowableIOBuffer();
99  read_buffer_->SetCapacity(MessageFramer::MessageHeader::max_message_size());
100  framer_.reset(new MessageFramer(read_buffer_));
101}
102
103CastSocket::~CastSocket() {
104  // Ensure that resources are freed but do not run pending callbacks to avoid
105  // any re-entrancy.
106  CloseInternal();
107}
108
109ReadyState CastSocket::ready_state() const {
110  return ready_state_;
111}
112
113ChannelError CastSocket::error_state() const {
114  return error_state_;
115}
116
117scoped_ptr<net::TCPClientSocket> CastSocket::CreateTcpSocket() {
118  net::AddressList addresses(ip_endpoint_);
119  return scoped_ptr<net::TCPClientSocket>(
120      new net::TCPClientSocket(addresses, net_log_, net_log_source_));
121  // Options cannot be set on the TCPClientSocket yet, because the
122  // underlying platform socket will not be created until Bind()
123  // or Connect() is called.
124}
125
126scoped_ptr<net::SSLClientSocket> CastSocket::CreateSslSocket(
127    scoped_ptr<net::StreamSocket> socket) {
128  net::SSLConfig ssl_config;
129  // If a peer cert was extracted in a previous attempt to connect, then
130  // whitelist that cert.
131  if (!peer_cert_.empty()) {
132    net::SSLConfig::CertAndStatus cert_and_status;
133    cert_and_status.cert_status = net::CERT_STATUS_AUTHORITY_INVALID;
134    cert_and_status.der_cert = peer_cert_;
135    ssl_config.allowed_bad_certs.push_back(cert_and_status);
136    logger_->LogSocketEvent(channel_id_, proto::SSL_CERT_WHITELISTED);
137  }
138
139  cert_verifier_.reset(net::CertVerifier::CreateDefault());
140  transport_security_state_.reset(new net::TransportSecurityState);
141  net::SSLClientSocketContext context;
142  // CertVerifier and TransportSecurityState are owned by us, not the
143  // context object.
144  context.cert_verifier = cert_verifier_.get();
145  context.transport_security_state = transport_security_state_.get();
146
147  scoped_ptr<net::ClientSocketHandle> connection(new net::ClientSocketHandle);
148  connection->SetSocket(socket.Pass());
149  net::HostPortPair host_and_port = net::HostPortPair::FromIPEndPoint(
150      ip_endpoint_);
151
152  return net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket(
153      connection.Pass(), host_and_port, ssl_config, context);
154}
155
156bool CastSocket::ExtractPeerCert(std::string* cert) {
157  DCHECK(cert);
158  DCHECK(peer_cert_.empty());
159  net::SSLInfo ssl_info;
160  if (!socket_->GetSSLInfo(&ssl_info) || !ssl_info.cert.get()) {
161    return false;
162  }
163
164  logger_->LogSocketEvent(channel_id_, proto::SSL_INFO_OBTAINED);
165
166  bool result = net::X509Certificate::GetDEREncoded(
167     ssl_info.cert->os_cert_handle(), cert);
168  if (result) {
169    VLOG_WITH_CONNECTION(1) << "Successfully extracted peer certificate: "
170                            << *cert;
171  }
172
173  logger_->LogSocketEventWithRv(
174      channel_id_, proto::DER_ENCODED_CERT_OBTAIN, result ? 1 : 0);
175  return result;
176}
177
178bool CastSocket::VerifyChallengeReply() {
179  AuthResult result = AuthenticateChallengeReply(*challenge_reply_, peer_cert_);
180  logger_->LogSocketChallengeReplyEvent(channel_id_, result);
181  return result.success();
182}
183
184void CastSocket::Connect(const net::CompletionCallback& callback) {
185  DCHECK(CalledOnValidThread());
186  VLOG_WITH_CONNECTION(1) << "Connect readyState = " << ready_state_;
187  if (ready_state_ != READY_STATE_NONE) {
188    logger_->LogSocketEventWithDetails(
189        channel_id_, proto::CONNECT_FAILED, "ReadyState not NONE");
190    callback.Run(net::ERR_CONNECTION_FAILED);
191    return;
192  }
193
194  connect_callback_ = callback;
195  SetReadyState(READY_STATE_CONNECTING);
196  SetConnectState(proto::CONN_STATE_TCP_CONNECT);
197
198  if (connect_timeout_.InMicroseconds() > 0) {
199    DCHECK(connect_timeout_callback_.IsCancelled());
200    connect_timeout_callback_.Reset(
201        base::Bind(&CastSocket::OnConnectTimeout, base::Unretained(this)));
202    GetTimer()->Start(FROM_HERE,
203                      connect_timeout_,
204                      connect_timeout_callback_.callback());
205  }
206  DoConnectLoop(net::OK);
207}
208
209void CastSocket::PostTaskToStartConnectLoop(int result) {
210  DCHECK(CalledOnValidThread());
211  DCHECK(connect_loop_callback_.IsCancelled());
212  connect_loop_callback_.Reset(
213      base::Bind(&CastSocket::DoConnectLoop, base::Unretained(this), result));
214  base::MessageLoop::current()->PostTask(FROM_HERE,
215                                         connect_loop_callback_.callback());
216}
217
218void CastSocket::OnConnectTimeout() {
219  DCHECK(CalledOnValidThread());
220  // Stop all pending connection setup tasks and report back to the client.
221  is_canceled_ = true;
222  logger_->LogSocketEvent(channel_id_, proto::CONNECT_TIMED_OUT);
223  VLOG_WITH_CONNECTION(1) << "Timeout while establishing a connection.";
224  DoConnectCallback(net::ERR_TIMED_OUT);
225}
226
// This method performs the state machine transitions for connection flow.
// There are two entry points to this method:
// 1. Connect method: this starts the flow
// 2. Callback from network operations that finish asynchronously
// It is also re-entered from tasks posted by PostTaskToStartConnectLoop().
// |result| is the outcome of the previous (possibly asynchronous) step and is
// handed to the *Complete handler for the current state.  When the loop ends
// with no operation pending, the connect timer is stopped and
// DoConnectCallback() receives the final result.
void CastSocket::DoConnectLoop(int result) {
  // Mark any posted connect-loop task as consumed so that a new one may be
  // posted (PostTaskToStartConnectLoop() DCHECKs that it is cancelled).
  connect_loop_callback_.Cancel();
  // Set by OnConnectTimeout(): the attempt has already been reported as
  // failed, so late completions are ignored.
  if (is_canceled_) {
    LOG(ERROR) << "CANCELLED - Aborting DoConnectLoop.";
    return;
  }
  // Network operations can either finish synchronously or asynchronously.
  // This method executes the state machine transitions in a loop so that
  // correct state transitions happen even when network operations finish
  // synchronously.
  int rv = result;
  do {
    proto::ConnectionState state = connect_state_;
    // Default to CONN_STATE_NONE, which breaks the processing loop if any
    // handler fails to transition to another state to continue processing.
    connect_state_ = proto::CONN_STATE_NONE;
    switch (state) {
      case proto::CONN_STATE_TCP_CONNECT:
        rv = DoTcpConnect();
        break;
      case proto::CONN_STATE_TCP_CONNECT_COMPLETE:
        rv = DoTcpConnectComplete(rv);
        break;
      case proto::CONN_STATE_SSL_CONNECT:
        // Only reached after a successful TCP connect.
        DCHECK_EQ(net::OK, rv);
        rv = DoSslConnect();
        break;
      case proto::CONN_STATE_SSL_CONNECT_COMPLETE:
        rv = DoSslConnectComplete(rv);
        break;
      case proto::CONN_STATE_AUTH_CHALLENGE_SEND:
        rv = DoAuthChallengeSend();
        break;
      case proto::CONN_STATE_AUTH_CHALLENGE_SEND_COMPLETE:
        rv = DoAuthChallengeSendComplete(rv);
        break;
      case proto::CONN_STATE_AUTH_CHALLENGE_REPLY_COMPLETE:
        rv = DoAuthChallengeReplyComplete(rv);
        break;
      default:
        NOTREACHED() << "BUG in connect flow. Unknown state: " << state;
        break;
    }
  } while (rv != net::ERR_IO_PENDING &&
           connect_state_ != proto::CONN_STATE_NONE);
  // Get out of the loop either when:
  // a. A network operation is pending, OR
  // b. The Do* method called did not change state

  // No state change occurred in do-while loop above. This means state has
  // transitioned to NONE.  The loop assigns |connect_state_| directly
  // (bypassing SetConnectState()), so log the final NONE state here.
  if (connect_state_ == proto::CONN_STATE_NONE) {
    logger_->LogSocketConnectState(channel_id_, connect_state_);
  }

  // Connect loop is finished: if there is no pending IO invoke the callback.
  if (rv != net::ERR_IO_PENDING) {
    GetTimer()->Stop();
    DoConnectCallback(rv);
  }
}
291
292int CastSocket::DoTcpConnect() {
293  DCHECK(connect_loop_callback_.IsCancelled());
294  VLOG_WITH_CONNECTION(1) << "DoTcpConnect";
295  SetConnectState(proto::CONN_STATE_TCP_CONNECT_COMPLETE);
296  tcp_socket_ = CreateTcpSocket();
297
298  int rv = tcp_socket_->Connect(
299      base::Bind(&CastSocket::DoConnectLoop, base::Unretained(this)));
300  logger_->LogSocketEventWithRv(channel_id_, proto::TCP_SOCKET_CONNECT, rv);
301  return rv;
302}
303
304int CastSocket::DoTcpConnectComplete(int result) {
305  VLOG_WITH_CONNECTION(1) << "DoTcpConnectComplete: " << result;
306  if (result == net::OK) {
307    // Enable TCP protocol-level keep-alive.
308    bool result = tcp_socket_->SetKeepAlive(true, kTcpKeepAliveDelaySecs);
309    LOG_IF(WARNING, !result) << "Failed to SetKeepAlive.";
310    logger_->LogSocketEventWithRv(
311        channel_id_, proto::TCP_SOCKET_SET_KEEP_ALIVE, result ? 1 : 0);
312    SetConnectState(proto::CONN_STATE_SSL_CONNECT);
313  }
314  return result;
315}
316
317int CastSocket::DoSslConnect() {
318  DCHECK(connect_loop_callback_.IsCancelled());
319  VLOG_WITH_CONNECTION(1) << "DoSslConnect";
320  SetConnectState(proto::CONN_STATE_SSL_CONNECT_COMPLETE);
321  socket_ = CreateSslSocket(tcp_socket_.PassAs<net::StreamSocket>());
322
323  int rv = socket_->Connect(
324      base::Bind(&CastSocket::DoConnectLoop, base::Unretained(this)));
325  logger_->LogSocketEventWithRv(channel_id_, proto::SSL_SOCKET_CONNECT, rv);
326  return rv;
327}
328
329int CastSocket::DoSslConnectComplete(int result) {
330  VLOG_WITH_CONNECTION(1) << "DoSslConnectComplete: " << result;
331  if (result == net::ERR_CERT_AUTHORITY_INVALID &&
332      peer_cert_.empty() && ExtractPeerCert(&peer_cert_)) {
333    SetConnectState(proto::CONN_STATE_TCP_CONNECT);
334  } else if (result == net::OK &&
335             channel_auth_ == CHANNEL_AUTH_TYPE_SSL_VERIFIED) {
336    SetConnectState(proto::CONN_STATE_AUTH_CHALLENGE_SEND);
337  }
338  return result;
339}
340
341int CastSocket::DoAuthChallengeSend() {
342  VLOG_WITH_CONNECTION(1) << "DoAuthChallengeSend";
343  SetConnectState(proto::CONN_STATE_AUTH_CHALLENGE_SEND_COMPLETE);
344
345  CastMessage challenge_message;
346  CreateAuthChallengeMessage(&challenge_message);
347  VLOG_WITH_CONNECTION(1) << "Sending challenge: "
348                          << CastMessageToString(challenge_message);
349  // Post a task to send auth challenge so that DoWriteLoop is not nested inside
350  // DoConnectLoop. This is not strictly necessary but keeps the write loop
351  // code decoupled from connect loop code.
352  DCHECK(send_auth_challenge_callback_.IsCancelled());
353  send_auth_challenge_callback_.Reset(
354      base::Bind(&CastSocket::SendCastMessageInternal,
355                 base::Unretained(this),
356                 challenge_message,
357                 base::Bind(&CastSocket::DoAuthChallengeSendWriteComplete,
358                            base::Unretained(this))));
359  base::MessageLoop::current()->PostTask(
360      FROM_HERE,
361      send_auth_challenge_callback_.callback());
362  // Always return IO_PENDING since the result is always asynchronous.
363  return net::ERR_IO_PENDING;
364}
365
366void CastSocket::DoAuthChallengeSendWriteComplete(int result) {
367  send_auth_challenge_callback_.Cancel();
368  VLOG_WITH_CONNECTION(2) << "DoAuthChallengeSendWriteComplete: " << result;
369  DCHECK_GT(result, 0);
370  DCHECK_EQ(write_queue_.size(), 1UL);
371  PostTaskToStartConnectLoop(result);
372}
373
374int CastSocket::DoAuthChallengeSendComplete(int result) {
375  VLOG_WITH_CONNECTION(1) << "DoAuthChallengeSendComplete: " << result;
376  if (result < 0) {
377    return result;
378  }
379  SetConnectState(proto::CONN_STATE_AUTH_CHALLENGE_REPLY_COMPLETE);
380
381  // Post a task to start read loop so that DoReadLoop is not nested inside
382  // DoConnectLoop. This is not strictly necessary but keeps the read loop
383  // code decoupled from connect loop code.
384  PostTaskToStartReadLoop();
385  // Always return IO_PENDING since the result is always asynchronous.
386  return net::ERR_IO_PENDING;
387}
388
389int CastSocket::DoAuthChallengeReplyComplete(int result) {
390  VLOG_WITH_CONNECTION(1) << "DoAuthChallengeReplyComplete: " << result;
391  if (result < 0) {
392    return result;
393  }
394  if (!VerifyChallengeReply()) {
395    return net::ERR_FAILED;
396  }
397  VLOG_WITH_CONNECTION(1) << "Auth challenge verification succeeded";
398  return net::OK;
399}
400
401void CastSocket::DoConnectCallback(int result) {
402  SetReadyState((result == net::OK) ? READY_STATE_OPEN : READY_STATE_CLOSED);
403  if (result == net::OK) {
404    SetErrorState(CHANNEL_ERROR_NONE);
405    PostTaskToStartReadLoop();
406    VLOG_WITH_CONNECTION(1) << "Calling Connect_Callback";
407    base::ResetAndReturn(&connect_callback_).Run(result);
408    return;
409  } else if (result == net::ERR_TIMED_OUT) {
410    SetErrorState(CHANNEL_ERROR_CONNECT_TIMEOUT);
411  } else {
412    SetErrorState(CHANNEL_ERROR_CONNECT_ERROR);
413  }
414  // Calls the connect callback.
415  CloseWithError();
416}
417
418void CastSocket::Close(const net::CompletionCallback& callback) {
419  CloseInternal();
420  RunPendingCallbacksOnClose();
421  // Run this callback last.  It may delete the socket.
422  callback.Run(net::OK);
423}
424
// Releases the transport and cancels the connect timer and every self-posted
// task, then marks the socket READY_STATE_CLOSED and logs SOCKET_CLOSED.
// Does nothing if the socket is already READY_STATE_CLOSED.  Runs no user
// callbacks; RunPendingCallbacksOnClose() does that.
void CastSocket::CloseInternal() {
  // TODO(mfoltz): Enforce this when CastChannelAPITest is rewritten to create
  // and free sockets on the same thread.  crbug.com/398242
  // DCHECK(CalledOnValidThread());
  if (ready_state_ == READY_STATE_CLOSED) {
    return;
  }

  VLOG_WITH_CONNECTION(1) << "Close ReadyState = " << ready_state_;
  // The SSL socket is destroyed before the verifier and the transport
  // security state that its context points to (see CreateSslSocket()).
  tcp_socket_.reset();
  socket_.reset();
  cert_verifier_.reset();
  transport_security_state_.reset();
  GetTimer()->Stop();

  // Cancel callbacks that we queued ourselves to re-enter the connect or read
  // loops.
  connect_loop_callback_.Cancel();
  send_auth_challenge_callback_.Cancel();
  read_loop_callback_.Cancel();
  connect_timeout_callback_.Cancel();
  SetReadyState(READY_STATE_CLOSED);
  logger_->LogSocketEvent(channel_id_, proto::SOCKET_CLOSED);
}
449
450void CastSocket::RunPendingCallbacksOnClose() {
451  DCHECK_EQ(ready_state_, READY_STATE_CLOSED);
452  if (!connect_callback_.is_null()) {
453    connect_callback_.Run(net::ERR_CONNECTION_FAILED);
454    connect_callback_.Reset();
455  }
456  for (; !write_queue_.empty(); write_queue_.pop()) {
457    net::CompletionCallback& callback = write_queue_.front().callback;
458    callback.Run(net::ERR_FAILED);
459    callback.Reset();
460  }
461}
462
463void CastSocket::SendMessage(const MessageInfo& message,
464                             const net::CompletionCallback& callback) {
465  DCHECK(CalledOnValidThread());
466  if (ready_state_ != READY_STATE_OPEN) {
467    logger_->LogSocketEventForMessage(channel_id_,
468                                      proto::SEND_MESSAGE_FAILED,
469                                      message.namespace_,
470                                      "Ready state not OPEN");
471    callback.Run(net::ERR_FAILED);
472    return;
473  }
474  CastMessage message_proto;
475  if (!MessageInfoToCastMessage(message, &message_proto)) {
476    logger_->LogSocketEventForMessage(channel_id_,
477                                      proto::SEND_MESSAGE_FAILED,
478                                      message.namespace_,
479                                      "Failed to convert to CastMessage");
480    callback.Run(net::ERR_FAILED);
481    return;
482  }
483  SendCastMessageInternal(message_proto, callback);
484}
485
486void CastSocket::SendCastMessageInternal(
487    const CastMessage& message,
488    const net::CompletionCallback& callback) {
489  WriteRequest write_request(callback);
490  if (!write_request.SetContent(message)) {
491    logger_->LogSocketEventForMessage(channel_id_,
492                                      proto::SEND_MESSAGE_FAILED,
493                                      message.namespace_(),
494                                      "SetContent failed");
495    callback.Run(net::ERR_FAILED);
496    return;
497  }
498
499  write_queue_.push(write_request);
500  logger_->LogSocketEventForMessage(
501      channel_id_,
502      proto::MESSAGE_ENQUEUED,
503      message.namespace_(),
504      base::StringPrintf("Queue size: %" PRIuS, write_queue_.size()));
505  if (write_state_ == proto::WRITE_STATE_NONE) {
506    SetWriteState(proto::WRITE_STATE_WRITE);
507    DoWriteLoop(net::OK);
508  }
509}
510
// Drives the write state machine for the request at the head of
// |write_queue_|: WRITE -> WRITE_COMPLETE -> (WRITE | DO_CALLBACK | ERROR).
// Entered from SendCastMessageInternal() with net::OK, and as the completion
// callback of socket_->Write() with the write result.  The loop stops when a
// write is pending, the queue is empty, or a handler leaves the state at
// WRITE_STATE_NONE.  A final result of net::ERR_FAILED closes the socket.
void CastSocket::DoWriteLoop(int result) {
  DCHECK(CalledOnValidThread());
  VLOG_WITH_CONNECTION(1) << "DoWriteLoop queue size: " << write_queue_.size();

  if (write_queue_.empty()) {
    SetWriteState(proto::WRITE_STATE_NONE);
    return;
  }

  // Network operations can either finish synchronously or asynchronously.
  // This method executes the state machine transitions in a loop so that
  // write state transitions happen even when network operations finish
  // synchronously.
  int rv = result;
  do {
    proto::WriteState state = write_state_;
    // Reset (without logging); each handler must set the next state.
    write_state_ = proto::WRITE_STATE_NONE;
    switch (state) {
      case proto::WRITE_STATE_WRITE:
        rv = DoWrite();
        break;
      case proto::WRITE_STATE_WRITE_COMPLETE:
        rv = DoWriteComplete(rv);
        break;
      case proto::WRITE_STATE_DO_CALLBACK:
        rv = DoWriteCallback();
        break;
      case proto::WRITE_STATE_ERROR:
        rv = DoWriteError(rv);
        break;
      default:
        NOTREACHED() << "BUG in write flow. Unknown state: " << state;
        break;
    }
  } while (!write_queue_.empty() && rv != net::ERR_IO_PENDING &&
           write_state_ != proto::WRITE_STATE_NONE);

  // No state change occurred in do-while loop above. This means state has
  // transitioned to NONE.
  if (write_state_ == proto::WRITE_STATE_NONE) {
    logger_->LogSocketWriteState(channel_id_, write_state_);
  }

  // If write loop is done because the queue is empty then set write
  // state to NONE
  if (write_queue_.empty()) {
    SetWriteState(proto::WRITE_STATE_NONE);
  }

  // Write loop is done - if the result is ERR_FAILED then close with error.
  // (DoWriteError() returns net::OK instead while the socket is connecting,
  // leaving the error to the connect loop.)
  if (rv == net::ERR_FAILED) {
    CloseWithError();
  }
}
565
566int CastSocket::DoWrite() {
567  DCHECK(!write_queue_.empty());
568  WriteRequest& request = write_queue_.front();
569
570  VLOG_WITH_CONNECTION(2) << "WriteData byte_count = "
571                          << request.io_buffer->size() << " bytes_written "
572                          << request.io_buffer->BytesConsumed();
573
574  SetWriteState(proto::WRITE_STATE_WRITE_COMPLETE);
575
576  int rv = socket_->Write(
577      request.io_buffer.get(),
578      request.io_buffer->BytesRemaining(),
579      base::Bind(&CastSocket::DoWriteLoop, base::Unretained(this)));
580  logger_->LogSocketEventWithRv(channel_id_, proto::SOCKET_WRITE, rv);
581
582  return rv;
583}
584
585int CastSocket::DoWriteComplete(int result) {
586  DCHECK(!write_queue_.empty());
587  if (result <= 0) {  // NOTE that 0 also indicates an error
588    SetErrorState(CHANNEL_ERROR_SOCKET_ERROR);
589    SetWriteState(proto::WRITE_STATE_ERROR);
590    return result == 0 ? net::ERR_FAILED : result;
591  }
592
593  // Some bytes were successfully written
594  WriteRequest& request = write_queue_.front();
595  scoped_refptr<net::DrainableIOBuffer> io_buffer = request.io_buffer;
596  io_buffer->DidConsume(result);
597  if (io_buffer->BytesRemaining() == 0) {  // Message fully sent
598    SetWriteState(proto::WRITE_STATE_DO_CALLBACK);
599  } else {
600    SetWriteState(proto::WRITE_STATE_WRITE);
601  }
602
603  return net::OK;
604}
605
606int CastSocket::DoWriteCallback() {
607  DCHECK(!write_queue_.empty());
608
609  SetWriteState(proto::WRITE_STATE_WRITE);
610
611  WriteRequest& request = write_queue_.front();
612  int bytes_consumed = request.io_buffer->BytesConsumed();
613  logger_->LogSocketEventForMessage(
614      channel_id_,
615      proto::MESSAGE_WRITTEN,
616      request.message_namespace,
617      base::StringPrintf("Bytes: %d", bytes_consumed));
618  request.callback.Run(bytes_consumed);
619  write_queue_.pop();
620  return net::OK;
621}
622
623int CastSocket::DoWriteError(int result) {
624  DCHECK(!write_queue_.empty());
625  DCHECK_LT(result, 0);
626
627  // If inside connection flow, then there should be exactly one item in
628  // the write queue.
629  if (ready_state_ == READY_STATE_CONNECTING) {
630    write_queue_.pop();
631    DCHECK(write_queue_.empty());
632    PostTaskToStartConnectLoop(result);
633    // Connect loop will handle the error. Return net::OK so that write flow
634    // does not try to report error also.
635    return net::OK;
636  }
637
638  while (!write_queue_.empty()) {
639    WriteRequest& request = write_queue_.front();
640    request.callback.Run(result);
641    write_queue_.pop();
642  }
643  return net::ERR_FAILED;
644}
645
646void CastSocket::PostTaskToStartReadLoop() {
647  DCHECK(CalledOnValidThread());
648  DCHECK(read_loop_callback_.IsCancelled());
649  read_loop_callback_.Reset(
650      base::Bind(&CastSocket::StartReadLoop, base::Unretained(this)));
651  base::MessageLoop::current()->PostTask(FROM_HERE,
652                                         read_loop_callback_.callback());
653}
654
655void CastSocket::StartReadLoop() {
656  read_loop_callback_.Cancel();
657  // Read loop would have already been started if read state is not NONE
658  if (read_state_ == proto::READ_STATE_NONE) {
659    SetReadState(proto::READ_STATE_READ);
660    DoReadLoop(net::OK);
661  }
662}
663
// Drives the read state machine: READ -> READ_COMPLETE -> (DO_CALLBACK | READ
// | ERROR).  Entered from StartReadLoop() with net::OK, and as the completion
// callback of socket_->Read() with the read result.  The loop runs until a
// read is pending or a handler leaves the state at READ_STATE_NONE.  A final
// result of net::ERR_FAILED goes to the connect loop while the socket is
// connecting; otherwise it closes the socket with an error.
void CastSocket::DoReadLoop(int result) {
  DCHECK(CalledOnValidThread());
  // Network operations can either finish synchronously or asynchronously.
  // This method executes the state machine transitions in a loop so that
  // read state transitions happen even when network operations finish
  // synchronously.
  int rv = result;
  do {
    proto::ReadState state = read_state_;
    // Reset (without logging); each handler must set the next state.
    read_state_ = proto::READ_STATE_NONE;

    switch (state) {
      case proto::READ_STATE_READ:
        rv = DoRead();
        break;
      case proto::READ_STATE_READ_COMPLETE:
        rv = DoReadComplete(rv);
        break;
      case proto::READ_STATE_DO_CALLBACK:
        rv = DoReadCallback();
        break;
      case proto::READ_STATE_ERROR:
        rv = DoReadError(rv);
        // Errors are terminal: reading stops after an error.
        DCHECK_EQ(read_state_, proto::READ_STATE_NONE);
        break;
      default:
        NOTREACHED() << "BUG in read flow. Unknown state: " << state;
        break;
    }
  } while (rv != net::ERR_IO_PENDING && read_state_ != proto::READ_STATE_NONE);

  // No state change occurred in do-while loop above. This means state has
  // transitioned to NONE.
  if (read_state_ == proto::READ_STATE_NONE) {
    logger_->LogSocketReadState(channel_id_, read_state_);
  }

  if (rv == net::ERR_FAILED) {
    if (ready_state_ == READY_STATE_CONNECTING) {
      // Read errors during the handshake should notify the caller via the
      // connect callback.  This will also send error status via the OnError
      // delegate.
      PostTaskToStartConnectLoop(net::ERR_FAILED);
    } else {
      // Connection is already established.  Close and send error status via the
      // OnError delegate.
      CloseWithError();
    }
  }
}
714
715int CastSocket::DoRead() {
716  SetReadState(proto::READ_STATE_READ_COMPLETE);
717
718  // Determine how many bytes need to be read.
719  size_t num_bytes_to_read = framer_->BytesRequested();
720
721  // Read up to num_bytes_to_read into |current_read_buffer_|.
722  int rv = socket_->Read(
723      read_buffer_.get(),
724      base::checked_cast<uint32>(num_bytes_to_read),
725      base::Bind(&CastSocket::DoReadLoop, base::Unretained(this)));
726  logger_->LogSocketEventWithRv(channel_id_, proto::SOCKET_READ, rv);
727
728  return rv;
729}
730
// Handles the result of a socket read.
//  - |result| > 0: that many new bytes in |read_buffer_| are passed to the
//    framer.  The next state is DO_CALLBACK when a full message has been
//    framed, ERROR on a framing error, and READ (for more bytes) otherwise.
//    Returns net::OK.
//  - |result| <= 0: EOF (0) or a negative net error moves the machine to
//    READ_STATE_ERROR.
int CastSocket::DoReadComplete(int result) {
  VLOG_WITH_CONNECTION(2) << "DoReadComplete result = " << result;

  if (result <= 0) {  // 0 means EOF: the peer closed the socket
    // NOTE(review): this message is also logged for negative (error) results,
    // not only for EOF.
    VLOG_WITH_CONNECTION(1) << "Read error, peer closed the socket";
    SetErrorState(CHANNEL_ERROR_SOCKET_ERROR);
    SetReadState(proto::READ_STATE_ERROR);
    // EOF is reported as net::ERR_FAILED; real errors pass through unchanged.
    return result == 0 ? net::ERR_FAILED : result;
  }

  size_t message_size;
  DCHECK(current_message_.get() == NULL);
  // Ingest() returns a complete message (and sets |message_size|) once one has
  // been assembled, and NULL otherwise.  On a framing error it writes
  // |error_state_| directly.
  // NOTE(review): that direct write bypasses SetErrorState(), so a framing
  // error is not reported through LogSocketErrorState() here.
  current_message_ = framer_->Ingest(result, &message_size, &error_state_);
  if (current_message_.get()) {
    DCHECK_EQ(error_state_, CHANNEL_ERROR_NONE);
    DCHECK_GT(message_size, static_cast<size_t>(0));
    logger_->LogSocketEventForMessage(
        channel_id_,
        proto::MESSAGE_READ,
        current_message_->namespace_(),
        base::StringPrintf("Message size: %u",
                           static_cast<uint32>(message_size)));
    SetReadState(proto::READ_STATE_DO_CALLBACK);
  } else if (error_state_ != CHANNEL_ERROR_NONE) {
    DCHECK(current_message_.get() == NULL);
    SetReadState(proto::READ_STATE_ERROR);
  } else {
    // Message incomplete: read more bytes.
    DCHECK(current_message_.get() == NULL);
    SetReadState(proto::READ_STATE_READ);
  }
  return net::OK;
}
763
764int CastSocket::DoReadCallback() {
765  SetReadState(proto::READ_STATE_READ);
766  const CastMessage& message = *current_message_;
767  if (ready_state_ == READY_STATE_CONNECTING) {
768    if (IsAuthMessage(message)) {
769      challenge_reply_.reset(new CastMessage(message));
770      logger_->LogSocketEvent(channel_id_, proto::RECEIVED_CHALLENGE_REPLY);
771      PostTaskToStartConnectLoop(net::OK);
772      current_message_.reset();
773      return net::OK;
774    } else {
775      SetReadState(proto::READ_STATE_ERROR);
776      SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE);
777      current_message_.reset();
778      return net::ERR_INVALID_RESPONSE;
779    }
780  }
781
782  MessageInfo message_info;
783  if (!CastMessageToMessageInfo(message, &message_info)) {
784    current_message_.reset();
785    SetReadState(proto::READ_STATE_ERROR);
786    SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE);
787    return net::ERR_INVALID_RESPONSE;
788  }
789
790  logger_->LogSocketEventForMessage(channel_id_,
791                                    proto::NOTIFY_ON_MESSAGE,
792                                    message.namespace_(),
793                                    std::string());
794  delegate_->OnMessage(this, message_info);
795  current_message_.reset();
796
797  return net::OK;
798}
799
800int CastSocket::DoReadError(int result) {
801  DCHECK_LE(result, 0);
802  return net::ERR_FAILED;
803}
804
805void CastSocket::CloseWithError() {
806  DCHECK(CalledOnValidThread());
807  CloseInternal();
808  RunPendingCallbacksOnClose();
809  if (delegate_) {
810    logger_->LogSocketEvent(channel_id_, proto::NOTIFY_ON_ERROR);
811    delegate_->OnError(this, error_state_, logger_->GetLastErrors(channel_id_));
812  }
813}
814
815std::string CastSocket::CastUrl() const {
816  return ((channel_auth_ == CHANNEL_AUTH_TYPE_SSL_VERIFIED) ?
817          "casts://" : "cast://") + ip_endpoint_.ToString();
818}
819
820bool CastSocket::CalledOnValidThread() const {
821  return thread_checker_.CalledOnValidThread();
822}
823
824base::Timer* CastSocket::GetTimer() {
825  return connect_timeout_timer_.get();
826}
827
828void CastSocket::SetConnectState(proto::ConnectionState connect_state) {
829  if (connect_state_ != connect_state) {
830    connect_state_ = connect_state;
831    logger_->LogSocketConnectState(channel_id_, connect_state_);
832  }
833}
834
835void CastSocket::SetReadyState(ReadyState ready_state) {
836  if (ready_state_ != ready_state) {
837    ready_state_ = ready_state;
838    logger_->LogSocketReadyState(channel_id_, ReadyStateToProto(ready_state_));
839  }
840}
841
842void CastSocket::SetErrorState(ChannelError error_state) {
843  if (error_state_ != error_state) {
844    error_state_ = error_state;
845    logger_->LogSocketErrorState(channel_id_, ErrorStateToProto(error_state_));
846  }
847}
848
849void CastSocket::SetReadState(proto::ReadState read_state) {
850  if (read_state_ != read_state) {
851    read_state_ = read_state;
852    logger_->LogSocketReadState(channel_id_, read_state_);
853  }
854}
855
856void CastSocket::SetWriteState(proto::WriteState write_state) {
857  if (write_state_ != write_state) {
858    write_state_ = write_state;
859    logger_->LogSocketWriteState(channel_id_, write_state_);
860  }
861}
862
863CastSocket::WriteRequest::WriteRequest(const net::CompletionCallback& callback)
864    : callback(callback) {
865}
866
867bool CastSocket::WriteRequest::SetContent(const CastMessage& message_proto) {
868  DCHECK(!io_buffer.get());
869  std::string message_data;
870  if (!MessageFramer::Serialize(message_proto, &message_data)) {
871    return false;
872  }
873  message_namespace = message_proto.namespace_();
874  io_buffer = new net::DrainableIOBuffer(new net::StringIOBuffer(message_data),
875                                         message_data.size());
876  return true;
877}
878
879CastSocket::WriteRequest::~WriteRequest() {
880}
881
882}  // namespace cast_channel
883}  // namespace core_api
884}  // namespace extensions
885
886#undef VLOG_WITH_CONNECTION
887