1// Copyright 2013 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "google_apis/gcm/engine/connection_handler_impl.h" 6 7#include "base/message_loop/message_loop.h" 8#include "google/protobuf/io/coded_stream.h" 9#include "google_apis/gcm/base/mcs_util.h" 10#include "google_apis/gcm/base/socket_stream.h" 11#include "google_apis/gcm/protocol/mcs.pb.h" 12#include "net/base/net_errors.h" 13#include "net/socket/stream_socket.h" 14 15using namespace google::protobuf::io; 16 17namespace gcm { 18 19namespace { 20 21// # of bytes a MCS version packet consumes. 22const int kVersionPacketLen = 1; 23// # of bytes a tag packet consumes. 24const int kTagPacketLen = 1; 25// Max # of bytes a length packet consumes. A Varint32 can consume up to 5 bytes 26// (the MSB in each byte is reserved for denoting whether more bytes follow). 27// But, the protocol only allows for 4KiB payloads, and the socket stream buffer 28// is only of size 8KiB. As such we should never need more than 2 bytes (max 29// value of 16KiB). Anything higher than that will result in an error, either 30// because the socket stream buffer overflowed or too many bytes were required 31// in the size packet. 32const int kSizePacketLenMin = 1; 33const int kSizePacketLenMax = 2; 34 35// The current MCS protocol version. 36const int kMCSVersion = 41; 37 38} // namespace 39 40ConnectionHandlerImpl::ConnectionHandlerImpl( 41 base::TimeDelta read_timeout, 42 const ProtoReceivedCallback& read_callback, 43 const ProtoSentCallback& write_callback, 44 const ConnectionChangedCallback& connection_callback) 45 : read_timeout_(read_timeout), 46 socket_(NULL), 47 handshake_complete_(false), 48 message_tag_(0), 49 message_size_(0), 50 read_callback_(read_callback), 51 write_callback_(write_callback), 52 connection_callback_(connection_callback), 53 weak_ptr_factory_(this) { 54} 55 56ConnectionHandlerImpl::~ConnectionHandlerImpl() { 57} 58 59void ConnectionHandlerImpl::Init( 60 const mcs_proto::LoginRequest& login_request, 61 net::StreamSocket* socket) { 62 DCHECK(!read_callback_.is_null()); 63 DCHECK(!write_callback_.is_null()); 64 DCHECK(!connection_callback_.is_null()); 65 66 // Invalidate any previously outstanding reads. 67 weak_ptr_factory_.InvalidateWeakPtrs(); 68 69 handshake_complete_ = false; 70 message_tag_ = 0; 71 message_size_ = 0; 72 socket_ = socket; 73 input_stream_.reset(new SocketInputStream(socket_)); 74 output_stream_.reset(new SocketOutputStream(socket_)); 75 76 Login(login_request); 77} 78 79void ConnectionHandlerImpl::Reset() { 80 CloseConnection(); 81} 82 83bool ConnectionHandlerImpl::CanSendMessage() const { 84 return handshake_complete_ && output_stream_.get() && 85 output_stream_->GetState() == SocketOutputStream::EMPTY; 86} 87 88void ConnectionHandlerImpl::SendMessage( 89 const google::protobuf::MessageLite& message) { 90 DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY); 91 DCHECK(handshake_complete_); 92 93 { 94 CodedOutputStream coded_output_stream(output_stream_.get()); 95 DVLOG(1) << "Writing proto of size " << message.ByteSize(); 96 int tag = GetMCSProtoTag(message); 97 DCHECK_NE(tag, -1); 98 coded_output_stream.WriteRaw(&tag, 1); 99 coded_output_stream.WriteVarint32(message.ByteSize()); 100 message.SerializeToCodedStream(&coded_output_stream); 101 } 102 103 if (output_stream_->Flush( 104 base::Bind(&ConnectionHandlerImpl::OnMessageSent, 105 weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) { 106 OnMessageSent(); 107 } 108} 109 110void ConnectionHandlerImpl::Login( 111 const google::protobuf::MessageLite& login_request) { 112 DCHECK_EQ(output_stream_->GetState(), SocketOutputStream::EMPTY); 113 114 const char version_byte[1] = {kMCSVersion}; 115 const char login_request_tag[1] = {kLoginRequestTag}; 116 { 117 CodedOutputStream coded_output_stream(output_stream_.get()); 118 coded_output_stream.WriteRaw(version_byte, 1); 119 coded_output_stream.WriteRaw(login_request_tag, 1); 120 coded_output_stream.WriteVarint32(login_request.ByteSize()); 121 login_request.SerializeToCodedStream(&coded_output_stream); 122 } 123 124 if (output_stream_->Flush( 125 base::Bind(&ConnectionHandlerImpl::OnMessageSent, 126 weak_ptr_factory_.GetWeakPtr())) != net::ERR_IO_PENDING) { 127 base::MessageLoop::current()->PostTask( 128 FROM_HERE, 129 base::Bind(&ConnectionHandlerImpl::OnMessageSent, 130 weak_ptr_factory_.GetWeakPtr())); 131 } 132 133 read_timeout_timer_.Start(FROM_HERE, 134 read_timeout_, 135 base::Bind(&ConnectionHandlerImpl::OnTimeout, 136 weak_ptr_factory_.GetWeakPtr())); 137 WaitForData(MCS_VERSION_TAG_AND_SIZE); 138} 139 140void ConnectionHandlerImpl::OnMessageSent() { 141 if (!output_stream_.get()) { 142 // The connection has already been closed. Just return. 143 DCHECK(!input_stream_.get()); 144 DCHECK(!read_timeout_timer_.IsRunning()); 145 return; 146 } 147 148 if (output_stream_->GetState() != SocketOutputStream::EMPTY) { 149 int last_error = output_stream_->last_error(); 150 CloseConnection(); 151 // If the socket stream had an error, plumb it up, else plumb up FAILED. 152 if (last_error == net::OK) 153 last_error = net::ERR_FAILED; 154 connection_callback_.Run(last_error); 155 return; 156 } 157 158 write_callback_.Run(); 159} 160 161void ConnectionHandlerImpl::GetNextMessage() { 162 DCHECK(SocketInputStream::EMPTY == input_stream_->GetState() || 163 SocketInputStream::READY == input_stream_->GetState()); 164 message_tag_ = 0; 165 message_size_ = 0; 166 167 WaitForData(MCS_TAG_AND_SIZE); 168} 169 170void ConnectionHandlerImpl::WaitForData(ProcessingState state) { 171 DVLOG(1) << "Waiting for MCS data: state == " << state; 172 173 if (!input_stream_) { 174 // The connection has already been closed. Just return. 175 DCHECK(!output_stream_.get()); 176 DCHECK(!read_timeout_timer_.IsRunning()); 177 return; 178 } 179 180 if (input_stream_->GetState() != SocketInputStream::EMPTY && 181 input_stream_->GetState() != SocketInputStream::READY) { 182 // An error occurred. 183 int last_error = output_stream_->last_error(); 184 CloseConnection(); 185 // If the socket stream had an error, plumb it up, else plumb up FAILED. 186 if (last_error == net::OK) 187 last_error = net::ERR_FAILED; 188 connection_callback_.Run(last_error); 189 return; 190 } 191 192 // Used to determine whether a Socket::Read is necessary. 193 int min_bytes_needed = 0; 194 // Used to limit the size of the Socket::Read. 195 int max_bytes_needed = 0; 196 197 switch(state) { 198 case MCS_VERSION_TAG_AND_SIZE: 199 min_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMin; 200 max_bytes_needed = kVersionPacketLen + kTagPacketLen + kSizePacketLenMax; 201 break; 202 case MCS_TAG_AND_SIZE: 203 min_bytes_needed = kTagPacketLen + kSizePacketLenMin; 204 max_bytes_needed = kTagPacketLen + kSizePacketLenMax; 205 break; 206 case MCS_FULL_SIZE: 207 // If in this state, the minimum size packet length must already have been 208 // insufficient, so set both to the max length. 209 min_bytes_needed = kSizePacketLenMax; 210 max_bytes_needed = kSizePacketLenMax; 211 break; 212 case MCS_PROTO_BYTES: 213 read_timeout_timer_.Reset(); 214 // No variability in the message size, set both to the same. 215 min_bytes_needed = message_size_; 216 max_bytes_needed = message_size_; 217 break; 218 default: 219 NOTREACHED(); 220 } 221 DCHECK_GE(max_bytes_needed, min_bytes_needed); 222 223 int unread_byte_count = input_stream_->UnreadByteCount(); 224 if (min_bytes_needed > unread_byte_count && 225 input_stream_->Refresh( 226 base::Bind(&ConnectionHandlerImpl::WaitForData, 227 weak_ptr_factory_.GetWeakPtr(), 228 state), 229 max_bytes_needed - unread_byte_count) == net::ERR_IO_PENDING) { 230 return; 231 } 232 233 // Check for refresh errors. 234 if (input_stream_->GetState() != SocketInputStream::READY) { 235 // An error occurred. 236 int last_error = input_stream_->last_error(); 237 CloseConnection(); 238 // If the socket stream had an error, plumb it up, else plumb up FAILED. 239 if (last_error == net::OK) 240 last_error = net::ERR_FAILED; 241 connection_callback_.Run(last_error); 242 return; 243 } 244 245 // Check whether read is complete, or needs to be continued ( 246 // SocketInputStream::Refresh can finish without reading all the data). 247 if (input_stream_->UnreadByteCount() < min_bytes_needed) { 248 DVLOG(1) << "Socket read finished prematurely. Waiting for " 249 << min_bytes_needed - input_stream_->UnreadByteCount() 250 << " more bytes."; 251 base::MessageLoop::current()->PostTask( 252 FROM_HERE, 253 base::Bind(&ConnectionHandlerImpl::WaitForData, 254 weak_ptr_factory_.GetWeakPtr(), 255 MCS_PROTO_BYTES)); 256 return; 257 } 258 259 // Received enough bytes, process them. 260 DVLOG(1) << "Processing MCS data: state == " << state; 261 switch(state) { 262 case MCS_VERSION_TAG_AND_SIZE: 263 OnGotVersion(); 264 break; 265 case MCS_TAG_AND_SIZE: 266 OnGotMessageTag(); 267 break; 268 case MCS_FULL_SIZE: 269 OnGotMessageSize(); 270 break; 271 case MCS_PROTO_BYTES: 272 OnGotMessageBytes(); 273 break; 274 default: 275 NOTREACHED(); 276 } 277} 278 279void ConnectionHandlerImpl::OnGotVersion() { 280 uint8 version = 0; 281 { 282 CodedInputStream coded_input_stream(input_stream_.get()); 283 coded_input_stream.ReadRaw(&version, 1); 284 } 285 // TODO(zea): remove this when the server is ready. 286 if (version < kMCSVersion && version != 38) { 287 LOG(ERROR) << "Invalid GCM version response: " << static_cast<int>(version); 288 connection_callback_.Run(net::ERR_FAILED); 289 return; 290 } 291 292 input_stream_->RebuildBuffer(); 293 294 // Process the LoginResponse message tag. 295 OnGotMessageTag(); 296} 297 298void ConnectionHandlerImpl::OnGotMessageTag() { 299 if (input_stream_->GetState() != SocketInputStream::READY) { 300 LOG(ERROR) << "Failed to receive protobuf tag."; 301 read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>()); 302 return; 303 } 304 305 { 306 CodedInputStream coded_input_stream(input_stream_.get()); 307 coded_input_stream.ReadRaw(&message_tag_, 1); 308 } 309 310 DVLOG(1) << "Received proto of type " 311 << static_cast<unsigned int>(message_tag_); 312 313 if (!read_timeout_timer_.IsRunning()) { 314 read_timeout_timer_.Start(FROM_HERE, 315 read_timeout_, 316 base::Bind(&ConnectionHandlerImpl::OnTimeout, 317 weak_ptr_factory_.GetWeakPtr())); 318 } 319 OnGotMessageSize(); 320} 321 322void ConnectionHandlerImpl::OnGotMessageSize() { 323 if (input_stream_->GetState() != SocketInputStream::READY) { 324 LOG(ERROR) << "Failed to receive message size."; 325 read_callback_.Run(scoped_ptr<google::protobuf::MessageLite>()); 326 return; 327 } 328 329 bool need_another_byte = false; 330 int prev_byte_count = input_stream_->UnreadByteCount(); 331 { 332 CodedInputStream coded_input_stream(input_stream_.get()); 333 if (!coded_input_stream.ReadVarint32(&message_size_)) 334 need_another_byte = true; 335 } 336 337 if (need_another_byte) { 338 DVLOG(1) << "Expecting another message size byte."; 339 if (prev_byte_count >= kSizePacketLenMax) { 340 // Already had enough bytes, something else went wrong. 341 LOG(ERROR) << "Failed to process message size, too many bytes needed."; 342 connection_callback_.Run(net::ERR_FILE_TOO_BIG); 343 return; 344 } 345 // Back up by the amount read (should always be 1 byte). 346 int bytes_read = prev_byte_count - input_stream_->UnreadByteCount(); 347 DCHECK_EQ(bytes_read, 1); 348 input_stream_->BackUp(bytes_read); 349 WaitForData(MCS_FULL_SIZE); 350 return; 351 } 352 353 DVLOG(1) << "Proto size: " << message_size_; 354 355 if (message_size_ > 0) 356 WaitForData(MCS_PROTO_BYTES); 357 else 358 OnGotMessageBytes(); 359} 360 361void ConnectionHandlerImpl::OnGotMessageBytes() { 362 read_timeout_timer_.Stop(); 363 scoped_ptr<google::protobuf::MessageLite> protobuf( 364 BuildProtobufFromTag(message_tag_)); 365 // Messages with no content are valid; just use the default protobuf for 366 // that tag. 367 if (protobuf.get() && message_size_ == 0) { 368 base::MessageLoop::current()->PostTask( 369 FROM_HERE, 370 base::Bind(&ConnectionHandlerImpl::GetNextMessage, 371 weak_ptr_factory_.GetWeakPtr())); 372 read_callback_.Run(protobuf.Pass()); 373 return; 374 } 375 376 if (input_stream_->GetState() != SocketInputStream::READY) { 377 LOG(ERROR) << "Failed to extract protobuf bytes of type " 378 << static_cast<unsigned int>(message_tag_); 379 // Reset the connection. 380 connection_callback_.Run(net::ERR_FAILED); 381 return; 382 } 383 384 if (!protobuf.get()) { 385 LOG(ERROR) << "Received message of invalid type " 386 << static_cast<unsigned int>(message_tag_); 387 connection_callback_.Run(net::ERR_INVALID_ARGUMENT); 388 return; 389 } 390 391 { 392 CodedInputStream coded_input_stream(input_stream_.get()); 393 if (!protobuf->ParsePartialFromCodedStream(&coded_input_stream)) { 394 LOG(ERROR) << "Unable to parse GCM message of type " 395 << static_cast<unsigned int>(message_tag_); 396 // Reset the connection. 397 connection_callback_.Run(net::ERR_FAILED); 398 return; 399 } 400 } 401 402 input_stream_->RebuildBuffer(); 403 base::MessageLoop::current()->PostTask( 404 FROM_HERE, 405 base::Bind(&ConnectionHandlerImpl::GetNextMessage, 406 weak_ptr_factory_.GetWeakPtr())); 407 if (message_tag_ == kLoginResponseTag) { 408 if (handshake_complete_) { 409 LOG(ERROR) << "Unexpected login response."; 410 } else { 411 handshake_complete_ = true; 412 DVLOG(1) << "GCM Handshake complete."; 413 connection_callback_.Run(net::OK); 414 } 415 } 416 read_callback_.Run(protobuf.Pass()); 417} 418 419void ConnectionHandlerImpl::OnTimeout() { 420 LOG(ERROR) << "Timed out waiting for GCM Protocol buffer."; 421 CloseConnection(); 422 connection_callback_.Run(net::ERR_TIMED_OUT); 423} 424 425void ConnectionHandlerImpl::CloseConnection() { 426 DVLOG(1) << "Closing connection."; 427 read_timeout_timer_.Stop(); 428 if (socket_) 429 socket_->Disconnect(); 430 socket_ = NULL; 431 handshake_complete_ = false; 432 message_tag_ = 0; 433 message_size_ = 0; 434 input_stream_.reset(); 435 output_stream_.reset(); 436 weak_ptr_factory_.InvalidateWeakPtrs(); 437} 438 439} // namespace gcm 440