1// Copyright (c) 2012 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "remoting/protocol/message_reader.h" 6 7#include "base/bind.h" 8#include "base/callback.h" 9#include "base/compiler_specific.h" 10#include "base/location.h" 11#include "base/thread_task_runner_handle.h" 12#include "base/single_thread_task_runner.h" 13#include "net/base/io_buffer.h" 14#include "net/base/net_errors.h" 15#include "net/socket/socket.h" 16#include "remoting/base/compound_buffer.h" 17#include "remoting/proto/internal.pb.h" 18 19namespace remoting { 20namespace protocol { 21 22static const int kReadBufferSize = 4096; 23 24MessageReader::MessageReader() 25 : socket_(NULL), 26 read_pending_(false), 27 pending_messages_(0), 28 closed_(false), 29 weak_factory_(this) { 30} 31 32void MessageReader::Init(net::Socket* socket, 33 const MessageReceivedCallback& callback) { 34 DCHECK(CalledOnValidThread()); 35 message_received_callback_ = callback; 36 DCHECK(socket); 37 socket_ = socket; 38 DoRead(); 39} 40 41MessageReader::~MessageReader() { 42} 43 44void MessageReader::DoRead() { 45 DCHECK(CalledOnValidThread()); 46 // Don't try to read again if there is another read pending or we 47 // have messages that we haven't finished processing yet. 48 while (!closed_ && !read_pending_ && pending_messages_ == 0) { 49 read_buffer_ = new net::IOBuffer(kReadBufferSize); 50 int result = socket_->Read( 51 read_buffer_.get(), 52 kReadBufferSize, 53 base::Bind(&MessageReader::OnRead, weak_factory_.GetWeakPtr())); 54 HandleReadResult(result); 55 } 56} 57 58void MessageReader::OnRead(int result) { 59 DCHECK(CalledOnValidThread()); 60 DCHECK(read_pending_); 61 read_pending_ = false; 62 63 if (!closed_) { 64 HandleReadResult(result); 65 DoRead(); 66 } 67} 68 69void MessageReader::HandleReadResult(int result) { 70 DCHECK(CalledOnValidThread()); 71 if (closed_) 72 return; 73 74 if (result > 0) { 75 OnDataReceived(read_buffer_.get(), result); 76 } else if (result == net::ERR_IO_PENDING) { 77 read_pending_ = true; 78 } else { 79 if (result != net::ERR_CONNECTION_CLOSED) { 80 LOG(ERROR) << "Read() returned error " << result; 81 } 82 // Stop reading after any error. 83 closed_ = true; 84 } 85} 86 87void MessageReader::OnDataReceived(net::IOBuffer* data, int data_size) { 88 DCHECK(CalledOnValidThread()); 89 message_decoder_.AddData(data, data_size); 90 91 // Get list of all new messages first, and then call the callback 92 // for all of them. 93 while (true) { 94 CompoundBuffer* buffer = message_decoder_.GetNextMessage(); 95 if (!buffer) 96 break; 97 pending_messages_++; 98 base::ThreadTaskRunnerHandle::Get()->PostTask( 99 FROM_HERE, 100 base::Bind(&MessageReader::RunCallback, 101 weak_factory_.GetWeakPtr(), 102 base::Passed(scoped_ptr<CompoundBuffer>(buffer)))); 103 } 104} 105 106void MessageReader::RunCallback(scoped_ptr<CompoundBuffer> message) { 107 message_received_callback_.Run( 108 message.Pass(), base::Bind(&MessageReader::OnMessageDone, 109 weak_factory_.GetWeakPtr())); 110} 111 112void MessageReader::OnMessageDone() { 113 DCHECK(CalledOnValidThread()); 114 pending_messages_--; 115 DCHECK_GE(pending_messages_, 0); 116 117 // Start next read if necessary. 118 DoRead(); 119} 120 121} // namespace protocol 122} // namespace remoting 123