1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "google_apis/gcm/engine/connection_handler_impl.h"
6
7#include "base/message_loop/message_loop.h"
8#include "google/protobuf/io/coded_stream.h"
9#include "google_apis/gcm/base/mcs_util.h"
10#include "google_apis/gcm/base/socket_stream.h"
11#include "google_apis/gcm/protocol/mcs.pb.h"
12#include "net/base/net_errors.h"
13#include "net/socket/stream_socket.h"
14
15using namespace google::protobuf::io;
16
17namespace gcm {
18
namespace {

// Wire-format sizes, in bytes, of the fields that precede each MCS payload.

// The MCS protocol version is transmitted as a single byte.
const int kVersionPacketLen = 1;

// The message tag is transmitted as a single byte.
const int kTagPacketLen = 1;

// The payload size is encoded as a Varint32. In general that can take up to
// 5 bytes, because the MSB of every byte only signals whether another byte
// follows. However, the protocol caps payloads at 4KiB and the socket stream
// buffer is 8KiB, so two bytes (values up to 16KiB - 1) are always enough.
// Needing more than that means something went wrong: either the socket stream
// buffer overflowed or the size packet required too many bytes.
const int kSizePacketLenMin = 1;
const int kSizePacketLenMax = 2;

// MCS protocol version spoken by this client.
const int kMCSVersion = 41;

}  // namespace
39
// Stores the read timeout and the three callbacks. No socket is attached
// (|socket_| starts as NULL) until Init() is called.
//   |read_timeout|        delay used for |read_timeout_timer_| while waiting
//                         for incoming data.
//   |read_callback|       receives each parsed incoming protobuf, or an empty
//                         scoped_ptr when a tag/size could not be read.
//   |write_callback|      run after an outgoing message was fully flushed.
//   |connection_callback| receives net::OK when the login handshake
//                         completes, or a net error when the connection fails.
// NOTE(review): initializer order presumably mirrors the member declaration
// order in the header; keep them in sync.
ConnectionHandlerImpl::ConnectionHandlerImpl(
    base::TimeDelta read_timeout,
    const ProtoReceivedCallback& read_callback,
    const ProtoSentCallback& write_callback,
    const ConnectionChangedCallback& connection_callback)
    : read_timeout_(read_timeout),
      socket_(NULL),
      handshake_complete_(false),
      message_tag_(0),
      message_size_(0),
      read_callback_(read_callback),
      write_callback_(write_callback),
      connection_callback_(connection_callback),
      weak_ptr_factory_(this) {
}
55
56ConnectionHandlerImpl::~ConnectionHandlerImpl() {
57}
58
59void ConnectionHandlerImpl::Init(
60    const mcs_proto::LoginRequest& login_request,
61    net::StreamSocket* socket) {
62  DCHECK(!read_callback_.is_null());
63  DCHECK(!write_callback_.is_null());
64  DCHECK(!connection_callback_.is_null());
65
66  // Invalidate any previously outstanding reads.
67  weak_ptr_factory_.InvalidateWeakPtrs();
68
69  handshake_complete_ = false;
70  message_tag_ = 0;
71  message_size_ = 0;
72  socket_ = socket;
73  input_stream_.reset(new SocketInputStream(socket_));
74  output_stream_.reset(new SocketOutputStream(socket_));
75
76  Login(login_request);
77}
78
79void ConnectionHandlerImpl::Reset() {
80  CloseConnection();
81}
82
83bool ConnectionHandlerImpl::CanSendMessage() const {
84  return handshake_complete_ && output_stream_.get() &&
85      output_stream_->GetState() == SocketOutputStream::EMPTY;
86}
87
88void ConnectionHandlerImpl::SendMessage(
89    const google::protobuf::MessageLite& message) {
90  DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);
91  DCHECK(handshake_complete_);
92
93  {
94    CodedOutputStream coded_output_stream(output_stream_.get());
95    DVLOG(1) << "Writing proto of size " << message.ByteSize();
96    int tag = GetMCSProtoTag(message);
97    DCHECK_NE(tag, -1);
98    coded_output_stream.WriteRaw(&tag, 1);
99    coded_output_stream.WriteVarint32(message.ByteSize());
100    message.SerializeToCodedStream(&coded_output_stream);
101  }
102
103  if (output_stream_->Flush(
104          base::Bind(&ConnectionHandlerImpl::OnMessageSent,
105                     weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) {
106    OnMessageSent();
107  }
108}
109
110void ConnectionHandlerImpl::Login(
111    const google::protobuf::MessageLite& login_request) {
112  DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY);
113
114  const char version_byte[1] = {kMCSVersion};
115  const char login_request_tag[1] = {kLoginRequestTag};
116  {
117    CodedOutputStream coded_output_stream(output_stream_.get());
118    coded_output_stream.WriteRaw(version_byte, 1);
119    coded_output_stream.WriteRaw(login_request_tag, 1);
120    coded_output_stream.WriteVarint32(login_request.ByteSize());
121    login_request.SerializeToCodedStream(&coded_output_stream);
122  }
123
124  if (output_stream_->Flush(
125          base::Bind(&ConnectionHandlerImpl::OnMessageSent,
126                     weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) {
127    base::MessageLoop::current()->PostTask(
128        FROM_HERE,
129        base::Bind(&ConnectionHandlerImpl::OnMessageSent,
130                   weak_ptr_factory_.GetWeakPtr()));
131  }
132
133  read_timeout_timer_.Start(FROM_HERE,
134                            read_timeout_,
135                            base::Bind(&ConnectionHandlerImpl::OnTimeout,
136                                       weak_ptr_factory_.GetWeakPtr()));
137  WaitForData(MCS_VERSION_TAG_AND_SIZE);
138}
139
140void ConnectionHandlerImpl::OnMessageSent() {
141  if (!output_stream_.get()) {
142    // The connection has already been closed. Just return.
143    DCHECK(!input_stream_.get());
144    DCHECK(!read_timeout_timer_.IsRunning());
145    return;
146  }
147
148  if (output_stream_->GetState() != SocketOutputStream::EMPTY) {
149    int last_error = output_stream_->last_error();
150    CloseConnection();
151    // If the socket stream had an error, plumb it up, else plumb up FAILED.
152    if (last_error == net::OK)
153      last_error = net::ERR_FAILED;
154    connection_callback_.Run(last_error);
155    return;
156  }
157
158  write_callback_.Run();
159}
160
161void ConnectionHandlerImpl::GetNextMessage() {
162  DCHECK(SocketInputStream::EMPTY == input_stream_->GetState() ||
163         SocketInputStream::READY == input_stream_->GetState());
164  message_tag_ = 0;
165  message_size_ = 0;
166
167  WaitForData(MCS_TAG_AND_SIZE);
168}
169
170void ConnectionHandlerImpl::WaitForData(ProcessingState state) {
171  DVLOG(1) << "Waiting for MCS data: state == " << state;
172
173  if (!input_stream_) {
174    // The connection has already been closed. Just return.
175    DCHECK(!output_stream_.get());
176    DCHECK(!read_timeout_timer_.IsRunning());
177    return;
178  }
179
180  if (input_stream_->GetState() != SocketInputStream::EMPTY &&
181      input_stream_->GetState() != SocketInputStream::READY) {
182    // An error occurred.
183    int last_error = output_stream_->last_error();
184    CloseConnection();
185    // If the socket stream had an error, plumb it up, else plumb up FAILED.
186    if (last_error == net::OK)
187      last_error = net::ERR_FAILED;
188    connection_callback_.Run(last_error);
189    return;
190  }
191
192  // Used to determine whether a Socket::Read is necessary.
193  int min_bytes_needed = 0;
194  // Used to limit the size of the Socket::Read.
195  int max_bytes_needed = 0;
196
197  switch(state) {
198    case MCS_VERSION_TAG_AND_SIZE:
199      min_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMin;
200      max_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMax;
201      break;
202    case MCS_TAG_AND_SIZE:
203      min_bytes_needed = kTagPacketLen + kSizePacketLenMin;
204      max_bytes_needed = kTagPacketLen + kSizePacketLenMax;
205      break;
206    case MCS_FULL_SIZE:
207      // If in this state, the minimum size packet length must already have been
208      // insufficient, so set both to the max length.
209      min_bytes_needed = kSizePacketLenMax;
210      max_bytes_needed = kSizePacketLenMax;
211      break;
212    case MCS_PROTO_BYTES:
213      read_timeout_timer_.Reset();
214      // No variability in the message size, set both to the same.
215      min_bytes_needed = message_size_;
216      max_bytes_needed = message_size_;
217      break;
218    default:
219      NOTREACHED();
220  }
221  DCHECK_GE(max_bytes_needed, min_bytes_needed);
222
223  int unread_byte_count = input_stream_->UnreadByteCount();
224  if (min_bytes_needed > unread_byte_count &&
225      input_stream_->Refresh(
226          base::Bind(&ConnectionHandlerImpl::WaitForData,
227                     weak_ptr_factory_.GetWeakPtr(),
228                     state),
229          max_bytes_needed - unread_byte_count) == net::ERR_IO_PENDING) {
230    return;
231  }
232
233  // Check for refresh errors.
234  if (input_stream_->GetState() != SocketInputStream::READY) {
235    // An error occurred.
236    int last_error = input_stream_->last_error();
237    CloseConnection();
238    // If the socket stream had an error, plumb it up, else plumb up FAILED.
239    if (last_error == net::OK)
240      last_error = net::ERR_FAILED;
241    connection_callback_.Run(last_error);
242    return;
243  }
244
245  // Check whether read is complete, or needs to be continued (
246  // SocketInputStream::Refresh can finish without reading all the data).
247  if (input_stream_->UnreadByteCount() < min_bytes_needed) {
248    DVLOG(1) << "Socket read finished prematurely. Waiting for "
249             << min_bytes_needed - input_stream_->UnreadByteCount()
250             << " more bytes.";
251    base::MessageLoop::current()->PostTask(
252        FROM_HERE,
253        base::Bind(&ConnectionHandlerImpl::WaitForData,
254                   weak_ptr_factory_.GetWeakPtr(),
255                   MCS_PROTO_BYTES));
256    return;
257  }
258
259  // Received enough bytes, process them.
260  DVLOG(1) << "Processing MCS data: state == " << state;
261  switch(state) {
262    case MCS_VERSION_TAG_AND_SIZE:
263      OnGotVersion();
264      break;
265    case MCS_TAG_AND_SIZE:
266      OnGotMessageTag();
267      break;
268    case MCS_FULL_SIZE:
269      OnGotMessageSize();
270      break;
271    case MCS_PROTO_BYTES:
272      OnGotMessageBytes();
273      break;
274    default:
275      NOTREACHED();
276  }
277}
278
279void ConnectionHandlerImpl::OnGotVersion() {
280  uint8 version = 0;
281  {
282    CodedInputStream coded_input_stream(input_stream_.get());
283    coded_input_stream.ReadRaw(&version, 1);
284  }
285  // TODO(zea): remove this when the server is ready.
286  if (version < kMCSVersion && version != 38) {
287    LOG(ERROR) << "Invalid GCM version response: " << static_cast<int>(version);
288    connection_callback_.Run(net::ERR_FAILED);
289    return;
290  }
291
292  input_stream_->RebuildBuffer();
293
294  // Process the LoginResponse message tag.
295  OnGotMessageTag();
296}
297
298void ConnectionHandlerImpl::OnGotMessageTag() {
299  if (input_stream_->GetState() != SocketInputStream::READY) {
300    LOG(ERROR) << "Failed to receive protobuf tag.";
301    read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
302    return;
303  }
304
305  {
306    CodedInputStream coded_input_stream(input_stream_.get());
307    coded_input_stream.ReadRaw(&message_tag_, 1);
308  }
309
310  DVLOG(1) << "Received proto of type "
311           << static_cast<unsigned int>(message_tag_);
312
313  if (!read_timeout_timer_.IsRunning()) {
314    read_timeout_timer_.Start(FROM_HERE,
315                              read_timeout_,
316                              base::Bind(&ConnectionHandlerImpl::OnTimeout,
317                                         weak_ptr_factory_.GetWeakPtr()));
318  }
319  OnGotMessageSize();
320}
321
322void ConnectionHandlerImpl::OnGotMessageSize() {
323  if (input_stream_->GetState() != SocketInputStream::READY) {
324    LOG(ERROR) << "Failed to receive message size.";
325    read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>());
326    return;
327  }
328
329  bool need_another_byte = false;
330  int prev_byte_count = input_stream_->UnreadByteCount();
331  {
332    CodedInputStream coded_input_stream(input_stream_.get());
333    if (!coded_input_stream.ReadVarint32(&message_size_))
334      need_another_byte = true;
335  }
336
337  if (need_another_byte) {
338    DVLOG(1) << "Expecting another message size byte.";
339    if (prev_byte_count >= kSizePacketLenMax) {
340      // Already had enough bytes, something else went wrong.
341      LOG(ERROR) << "Failed to process message size, too many bytes needed.";
342      connection_callback_.Run(net::ERR_FILE_TOO_BIG);
343      return;
344    }
345    // Back up by the amount read (should always be 1 byte).
346    int bytes_read = prev_byte_count - input_stream_->UnreadByteCount();
347    DCHECK_EQ(bytes_read, 1);
348    input_stream_->BackUp(bytes_read);
349    WaitForData(MCS_FULL_SIZE);
350    return;
351  }
352
353  DVLOG(1) << "Proto size: " << message_size_;
354
355  if (message_size_ > 0)
356    WaitForData(MCS_PROTO_BYTES);
357  else
358    OnGotMessageBytes();
359}
360
361void ConnectionHandlerImpl::OnGotMessageBytes() {
362  read_timeout_timer_.Stop();
363  scoped_ptr<google::protobuf::MessageLite> protobuf(
364      BuildProtobufFromTag(message_tag_));
365  // Messages with no content are valid; just use the default protobuf for
366  // that tag.
367  if (protobuf.get() && message_size_ == 0) {
368    base::MessageLoop::current()->PostTask(
369        FROM_HERE,
370        base::Bind(&ConnectionHandlerImpl::GetNextMessage,
371                   weak_ptr_factory_.GetWeakPtr()));
372    read_callback_.Run(protobuf.Pass());
373    return;
374  }
375
376  if (input_stream_->GetState() != SocketInputStream::READY) {
377    LOG(ERROR) << "Failed to extract protobuf bytes of type "
378               << static_cast<unsigned int>(message_tag_);
379    // Reset the connection.
380    connection_callback_.Run(net::ERR_FAILED);
381    return;
382  }
383
384  if (!protobuf.get()) {
385     LOG(ERROR) << "Received message of invalid type "
386                << static_cast<unsigned int>(message_tag_);
387     connection_callback_.Run(net::ERR_INVALID_ARGUMENT);
388     return;
389   }
390
391  {
392    CodedInputStream coded_input_stream(input_stream_.get());
393    if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) {
394      LOG(ERROR) << "Unable to parse GCM message of type "
395                 << static_cast<unsigned int>(message_tag_);
396      // Reset the connection.
397      connection_callback_.Run(net::ERR_FAILED);
398      return;
399    }
400  }
401
402  input_stream_->RebuildBuffer();
403  base::MessageLoop::current()->PostTask(
404      FROM_HERE,
405      base::Bind(&ConnectionHandlerImpl::GetNextMessage,
406                 weak_ptr_factory_.GetWeakPtr()));
407  if (message_tag_ == kLoginResponseTag) {
408    if (handshake_complete_) {
409      LOG(ERROR) << "Unexpected login response.";
410    } else {
411      handshake_complete_ = true;
412      DVLOG(1) << "GCM Handshake complete.";
413      connection_callback_.Run(net::OK);
414    }
415  }
416  read_callback_.Run(protobuf.Pass());
417}
418
419void ConnectionHandlerImpl::OnTimeout() {
420  LOG(ERROR) << "Timed out waiting for GCM Protocol buffer.";
421  CloseConnection();
422  connection_callback_.Run(net::ERR_TIMED_OUT);
423}
424
425void ConnectionHandlerImpl::CloseConnection() {
426  DVLOG(1) << "Closing connection.";
427  read_timeout_timer_.Stop();
428  if (socket_)
429    socket_->Disconnect();
430  socket_ = NULL;
431  handshake_complete_ = false;
432  message_tag_ = 0;
433  message_size_ = 0;
434  input_stream_.reset();
435  output_stream_.reset();
436  weak_ptr_factory_.InvalidateWeakPtrs();
437}
438
439}  // namespace gcm
440