1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "extensions/browser/api/cast_channel/cast_transport.h"
6
7#include <string>
8
9#include "base/bind.h"
10#include "base/format_macros.h"
11#include "base/numerics/safe_conversions.h"
12#include "base/strings/stringprintf.h"
13#include "extensions/browser/api/cast_channel/cast_framer.h"
14#include "extensions/browser/api/cast_channel/cast_message_util.h"
15#include "extensions/browser/api/cast_channel/logger.h"
16#include "extensions/browser/api/cast_channel/logger_util.h"
17#include "extensions/common/api/cast_channel/cast_channel.pb.h"
18#include "net/base/net_errors.h"
19
// Emits VLOG(level) output prefixed with "[<peer endpoint>, auth=<auth>] ",
// taken from |socket_|->ip_endpoint() and |socket_|->channel_auth(). The macro
// refers to the unqualified name |socket_|, so it is meant for use inside
// CastTransport member functions.
#define VLOG_WITH_CONNECTION(level)                       \
  VLOG(level) << "[" << socket_->ip_endpoint().ToString() \
              << ", auth=" << socket_->channel_auth() << "] "
23
24namespace extensions {
25namespace core_api {
26namespace cast_channel {
27
28CastTransport::CastTransport(CastSocketInterface* socket,
29                             Delegate* read_delegate,
30                             scoped_refptr<Logger> logger)
31    : socket_(socket),
32      read_delegate_(read_delegate),
33      write_state_(WRITE_STATE_NONE),
34      read_state_(READ_STATE_NONE),
35      logger_(logger) {
36  DCHECK(socket);
37  DCHECK(read_delegate);
38
39  // Buffer is reused across messages to minimize unnecessary buffer
40  // [re]allocations.
41  read_buffer_ = new net::GrowableIOBuffer();
42  read_buffer_->SetCapacity(MessageFramer::MessageHeader::max_message_size());
43  framer_.reset(new MessageFramer(read_buffer_));
44}
45
46CastTransport::~CastTransport() {
47  DCHECK(thread_checker_.CalledOnValidThread());
48  FlushWriteQueue();
49}
50
51// static
52proto::ReadState CastTransport::ReadStateToProto(
53    CastTransport::ReadState state) {
54  switch (state) {
55    case CastTransport::READ_STATE_NONE:
56      return proto::READ_STATE_NONE;
57    case CastTransport::READ_STATE_READ:
58      return proto::READ_STATE_READ;
59    case CastTransport::READ_STATE_READ_COMPLETE:
60      return proto::READ_STATE_READ_COMPLETE;
61    case CastTransport::READ_STATE_DO_CALLBACK:
62      return proto::READ_STATE_DO_CALLBACK;
63    case CastTransport::READ_STATE_ERROR:
64      return proto::READ_STATE_ERROR;
65    default:
66      NOTREACHED();
67      return proto::READ_STATE_NONE;
68  }
69}
70
71// static
72proto::WriteState CastTransport::WriteStateToProto(
73    CastTransport::WriteState state) {
74  switch (state) {
75    case CastTransport::WRITE_STATE_NONE:
76      return proto::WRITE_STATE_NONE;
77    case CastTransport::WRITE_STATE_WRITE:
78      return proto::WRITE_STATE_WRITE;
79    case CastTransport::WRITE_STATE_WRITE_COMPLETE:
80      return proto::WRITE_STATE_WRITE_COMPLETE;
81    case CastTransport::WRITE_STATE_DO_CALLBACK:
82      return proto::WRITE_STATE_DO_CALLBACK;
83    case CastTransport::WRITE_STATE_ERROR:
84      return proto::WRITE_STATE_ERROR;
85    default:
86      NOTREACHED();
87      return proto::WRITE_STATE_NONE;
88  }
89}
90
91// static
92proto::ErrorState CastTransport::ErrorStateToProto(ChannelError state) {
93  switch (state) {
94    case CHANNEL_ERROR_NONE:
95      return proto::CHANNEL_ERROR_NONE;
96    case CHANNEL_ERROR_CHANNEL_NOT_OPEN:
97      return proto::CHANNEL_ERROR_CHANNEL_NOT_OPEN;
98    case CHANNEL_ERROR_AUTHENTICATION_ERROR:
99      return proto::CHANNEL_ERROR_AUTHENTICATION_ERROR;
100    case CHANNEL_ERROR_CONNECT_ERROR:
101      return proto::CHANNEL_ERROR_CONNECT_ERROR;
102    case CHANNEL_ERROR_SOCKET_ERROR:
103      return proto::CHANNEL_ERROR_SOCKET_ERROR;
104    case CHANNEL_ERROR_TRANSPORT_ERROR:
105      return proto::CHANNEL_ERROR_TRANSPORT_ERROR;
106    case CHANNEL_ERROR_INVALID_MESSAGE:
107      return proto::CHANNEL_ERROR_INVALID_MESSAGE;
108    case CHANNEL_ERROR_INVALID_CHANNEL_ID:
109      return proto::CHANNEL_ERROR_INVALID_CHANNEL_ID;
110    case CHANNEL_ERROR_CONNECT_TIMEOUT:
111      return proto::CHANNEL_ERROR_CONNECT_TIMEOUT;
112    case CHANNEL_ERROR_UNKNOWN:
113      return proto::CHANNEL_ERROR_UNKNOWN;
114    default:
115      NOTREACHED();
116      return proto::CHANNEL_ERROR_NONE;
117  }
118}
119
120void CastTransport::FlushWriteQueue() {
121  for (; !write_queue_.empty(); write_queue_.pop()) {
122    net::CompletionCallback& callback = write_queue_.front().callback;
123    callback.Run(net::ERR_FAILED);
124    callback.Reset();
125  }
126}
127
128void CastTransport::SendMessage(const CastMessage& message,
129                                const net::CompletionCallback& callback) {
130  DCHECK(thread_checker_.CalledOnValidThread());
131  std::string serialized_message;
132  if (!MessageFramer::Serialize(message, &serialized_message)) {
133    logger_->LogSocketEventForMessage(socket_->id(),
134                                      proto::SEND_MESSAGE_FAILED,
135                                      message.namespace_(),
136                                      "Error when serializing message.");
137    callback.Run(net::ERR_FAILED);
138    return;
139  }
140  WriteRequest write_request(
141      message.namespace_(), serialized_message, callback);
142
143  write_queue_.push(write_request);
144  logger_->LogSocketEventForMessage(
145      socket_->id(),
146      proto::MESSAGE_ENQUEUED,
147      message.namespace_(),
148      base::StringPrintf("Queue size: %" PRIuS, write_queue_.size()));
149  if (write_state_ == WRITE_STATE_NONE) {
150    SetWriteState(WRITE_STATE_WRITE);
151    OnWriteResult(net::OK);
152  }
153}
154
155CastTransport::WriteRequest::WriteRequest(
156    const std::string& namespace_,
157    const std::string& payload,
158    const net::CompletionCallback& callback)
159    : message_namespace(namespace_), callback(callback) {
160  VLOG(2) << "WriteRequest size: " << payload.size();
161  io_buffer = new net::DrainableIOBuffer(new net::StringIOBuffer(payload),
162                                         payload.size());
163}
164
165CastTransport::WriteRequest::~WriteRequest() {
166}
167
168void CastTransport::SetReadState(ReadState read_state) {
169  if (read_state_ != read_state) {
170    read_state_ = read_state;
171    logger_->LogSocketReadState(socket_->id(), ReadStateToProto(read_state_));
172  }
173}
174
175void CastTransport::SetWriteState(WriteState write_state) {
176  if (write_state_ != write_state) {
177    write_state_ = write_state;
178    logger_->LogSocketWriteState(socket_->id(),
179                                 WriteStateToProto(write_state_));
180  }
181}
182
183void CastTransport::SetErrorState(ChannelError error_state) {
184  if (error_state_ != error_state) {
185    error_state_ = error_state;
186    logger_->LogSocketErrorState(socket_->id(),
187                                 ErrorStateToProto(error_state_));
188  }
189}
190
// Drives the write state machine. It is called with net::OK by SendMessage()
// to start the loop, and as the completion callback of asynchronous socket
// writes issued by DoWrite(), in which case |result| is that write's result.
void CastTransport::OnWriteResult(int result) {
  DCHECK(thread_checker_.CalledOnValidThread());
  VLOG_WITH_CONNECTION(1) << "OnWriteResult queue size: "
                          << write_queue_.size();

  if (write_queue_.empty()) {
    SetWriteState(WRITE_STATE_NONE);
    return;
  }

  // Network operations can either finish synchronously or asynchronously.
  // This method executes the state machine transitions in a loop so that
  // write state transitions happen even when network operations finish
  // synchronously.
  int rv = result;
  do {
    // Each Do* handler must pick the next state explicitly; if it does not,
    // the state stays NONE and the loop terminates.
    WriteState state = write_state_;
    write_state_ = WRITE_STATE_NONE;
    switch (state) {
      case WRITE_STATE_WRITE:
        rv = DoWrite();
        break;
      case WRITE_STATE_WRITE_COMPLETE:
        rv = DoWriteComplete(rv);
        break;
      case WRITE_STATE_DO_CALLBACK:
        rv = DoWriteCallback();
        break;
      case WRITE_STATE_ERROR:
        rv = DoWriteError(rv);
        break;
      default:
        NOTREACHED() << "BUG in write flow. Unknown state: " << state;
        break;
    }
    // Stop when the queue has drained, a write is pending, or no next state
    // was chosen.
  } while (!write_queue_.empty() && rv != net::ERR_IO_PENDING &&
           write_state_ != WRITE_STATE_NONE);

  // The last handler did not choose a next state, so the loop ended in NONE.
  // write_state_ was assigned directly above (bypassing SetWriteState()), so
  // log that transition here.
  if (write_state_ == WRITE_STATE_NONE) {
    logger_->LogSocketWriteState(socket_->id(),
                                 WriteStateToProto(write_state_));
  }

  // If the loop ended because the last queued request was completed,
  // DoWriteCallback() left the state at WRITE_STATE_WRITE; reset it to NONE so
  // the next SendMessage() restarts the loop.
  if (write_queue_.empty()) {
    SetWriteState(WRITE_STATE_NONE);
  }

  // Write loop is done - if the result is ERR_FAILED (always the outcome of
  // DoWriteError()) then close with error and fail the remaining writes.
  if (rv == net::ERR_FAILED) {
    DCHECK_NE(CHANNEL_ERROR_NONE, error_state_);
    socket_->CloseWithError(error_state_);
    FlushWriteQueue();
  }
}
249
250int CastTransport::DoWrite() {
251  DCHECK(!write_queue_.empty());
252  WriteRequest& request = write_queue_.front();
253
254  VLOG_WITH_CONNECTION(2) << "WriteData byte_count = "
255                          << request.io_buffer->size() << " bytes_written "
256                          << request.io_buffer->BytesConsumed();
257
258  SetWriteState(WRITE_STATE_WRITE_COMPLETE);
259
260  int rv = socket_->Write(
261      request.io_buffer.get(),
262      request.io_buffer->BytesRemaining(),
263      base::Bind(&CastTransport::OnWriteResult, base::Unretained(this)));
264  logger_->LogSocketEventWithRv(socket_->id(), proto::SOCKET_WRITE, rv);
265
266  return rv;
267}
268
269int CastTransport::DoWriteComplete(int result) {
270  VLOG_WITH_CONNECTION(2) << "DoWriteComplete result=" << result;
271  DCHECK(!write_queue_.empty());
272  if (result <= 0) {  // NOTE that 0 also indicates an error
273    SetErrorState(CHANNEL_ERROR_TRANSPORT_ERROR);
274    SetWriteState(WRITE_STATE_ERROR);
275    return result == 0 ? net::ERR_FAILED : result;
276  }
277
278  // Some bytes were successfully written
279  WriteRequest& request = write_queue_.front();
280  scoped_refptr<net::DrainableIOBuffer> io_buffer = request.io_buffer;
281  io_buffer->DidConsume(result);
282  if (io_buffer->BytesRemaining() == 0) {  // Message fully sent
283    SetWriteState(WRITE_STATE_DO_CALLBACK);
284  } else {
285    SetWriteState(WRITE_STATE_WRITE);
286  }
287
288  return net::OK;
289}
290
291int CastTransport::DoWriteCallback() {
292  VLOG_WITH_CONNECTION(2) << "DoWriteCallback";
293  DCHECK(!write_queue_.empty());
294
295  SetWriteState(WRITE_STATE_WRITE);
296
297  WriteRequest& request = write_queue_.front();
298  int bytes_consumed = request.io_buffer->BytesConsumed();
299  logger_->LogSocketEventForMessage(
300      socket_->id(),
301      proto::MESSAGE_WRITTEN,
302      request.message_namespace,
303      base::StringPrintf("Bytes: %d", bytes_consumed));
304  request.callback.Run(net::OK);
305  write_queue_.pop();
306  return net::OK;
307}
308
309int CastTransport::DoWriteError(int result) {
310  VLOG_WITH_CONNECTION(2) << "DoWriteError result=" << result;
311  DCHECK_NE(CHANNEL_ERROR_NONE, error_state_);
312  DCHECK_LT(result, 0);
313  return net::ERR_FAILED;
314}
315
316void CastTransport::StartReadLoop() {
317  DCHECK(thread_checker_.CalledOnValidThread());
318  // Read loop would have already been started if read state is not NONE
319  if (read_state_ == READ_STATE_NONE) {
320    SetReadState(READ_STATE_READ);
321    OnReadResult(net::OK);
322  }
323}
324
// Drives the read state machine. It is called with net::OK by StartReadLoop()
// and as the completion callback of asynchronous socket reads issued by
// DoRead(), in which case |result| is that read's result.
void CastTransport::OnReadResult(int result) {
  DCHECK(thread_checker_.CalledOnValidThread());
  // Network operations can either finish synchronously or asynchronously.
  // This method executes the state machine transitions in a loop so that
  // read state transitions happen even when network operations finish
  // synchronously.
  int rv = result;
  do {
    // Each Do* handler must pick the next state explicitly; if it does not,
    // the state stays NONE and the loop terminates.
    ReadState state = read_state_;
    read_state_ = READ_STATE_NONE;

    switch (state) {
      case READ_STATE_READ:
        rv = DoRead();
        break;
      case READ_STATE_READ_COMPLETE:
        rv = DoReadComplete(rv);
        break;
      case READ_STATE_DO_CALLBACK:
        rv = DoReadCallback();
        break;
      case READ_STATE_ERROR:
        rv = DoReadError(rv);
        // The error handler is terminal: it must not schedule another state.
        DCHECK_EQ(read_state_, READ_STATE_NONE);
        break;
      default:
        NOTREACHED() << "BUG in read flow. Unknown state: " << state;
        break;
    }
  } while (rv != net::ERR_IO_PENDING && read_state_ != READ_STATE_NONE);

  // The last handler did not choose a next state, so the loop ended in NONE.
  // read_state_ was assigned directly above (bypassing SetReadState()), so log
  // that transition here.
  if (read_state_ == READ_STATE_NONE) {
    logger_->LogSocketReadState(socket_->id(), ReadStateToProto(read_state_));
  }

  // ERR_FAILED comes from DoReadError(): close the socket, fail all pending
  // writes and report the error to the delegate.
  if (rv == net::ERR_FAILED) {
    DCHECK_NE(CHANNEL_ERROR_NONE, error_state_);
    socket_->CloseWithError(error_state_);
    FlushWriteQueue();
    read_delegate_->OnError(
        socket_, error_state_, logger_->GetLastErrors(socket_->id()));
  }
}
370
371int CastTransport::DoRead() {
372  VLOG_WITH_CONNECTION(2) << "DoRead";
373  SetReadState(READ_STATE_READ_COMPLETE);
374
375  // Determine how many bytes need to be read.
376  size_t num_bytes_to_read = framer_->BytesRequested();
377
378  // Read up to num_bytes_to_read into |current_read_buffer_|.
379  int rv = socket_->Read(
380      read_buffer_.get(),
381      base::checked_cast<uint32>(num_bytes_to_read),
382      base::Bind(&CastTransport::OnReadResult, base::Unretained(this)));
383
384  return rv;
385}
386
387int CastTransport::DoReadComplete(int result) {
388  VLOG_WITH_CONNECTION(2) << "DoReadComplete result = " << result;
389
390  if (result <= 0) {
391    SetErrorState(CHANNEL_ERROR_TRANSPORT_ERROR);
392    SetReadState(READ_STATE_ERROR);
393    return result == 0 ? net::ERR_FAILED : result;
394  }
395
396  size_t message_size;
397  DCHECK(current_message_.get() == NULL);
398  current_message_ = framer_->Ingest(result, &message_size, &error_state_);
399  if (current_message_.get()) {
400    DCHECK_EQ(error_state_, CHANNEL_ERROR_NONE);
401    DCHECK_GT(message_size, static_cast<size_t>(0));
402    logger_->LogSocketEventForMessage(
403        socket_->id(),
404        proto::MESSAGE_READ,
405        current_message_->namespace_(),
406        base::StringPrintf("Message size: %u",
407                           static_cast<uint32>(message_size)));
408    SetReadState(READ_STATE_DO_CALLBACK);
409  } else if (error_state_ != CHANNEL_ERROR_NONE) {
410    DCHECK(current_message_.get() == NULL);
411    SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE);
412    SetReadState(READ_STATE_ERROR);
413  } else {
414    DCHECK(current_message_.get() == NULL);
415    SetReadState(READ_STATE_READ);
416  }
417  return net::OK;
418}
419
420int CastTransport::DoReadCallback() {
421  VLOG_WITH_CONNECTION(2) << "DoReadCallback";
422  SetReadState(READ_STATE_READ);
423  if (!IsCastMessageValid(*current_message_)) {
424    SetReadState(READ_STATE_ERROR);
425    SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE);
426    return net::ERR_INVALID_RESPONSE;
427  }
428  logger_->LogSocketEventForMessage(socket_->id(),
429                                    proto::NOTIFY_ON_MESSAGE,
430                                    current_message_->namespace_(),
431                                    std::string());
432  read_delegate_->OnMessage(socket_, *current_message_);
433  current_message_.reset();
434  return net::OK;
435}
436
437int CastTransport::DoReadError(int result) {
438  VLOG_WITH_CONNECTION(2) << "DoReadError";
439  DCHECK_NE(CHANNEL_ERROR_NONE, error_state_);
440  DCHECK_LE(result, 0);
441  return net::ERR_FAILED;
442}
443
444}  // namespace cast_channel
445}  // namespace core_api
446}  // namespace extensions
447