1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "remoting/protocol/buffered_socket_writer.h"
6
7#include "base/bind.h"
8#include "base/location.h"
9#include "base/single_thread_task_runner.h"
10#include "base/stl_util.h"
11#include "base/thread_task_runner_handle.h"
12#include "net/base/net_errors.h"
13
14namespace remoting {
15namespace protocol {
16
17struct BufferedSocketWriterBase::PendingPacket {
18  PendingPacket(scoped_refptr<net::IOBufferWithSize> data,
19                const base::Closure& done_task)
20      : data(data),
21        done_task(done_task) {
22  }
23
24  scoped_refptr<net::IOBufferWithSize> data;
25  base::Closure done_task;
26};
27
28BufferedSocketWriterBase::BufferedSocketWriterBase()
29    : buffer_size_(0),
30      socket_(NULL),
31      write_pending_(false),
32      closed_(false),
33      destroyed_flag_(NULL) {
34}
35
36void BufferedSocketWriterBase::Init(net::Socket* socket,
37                                    const WriteFailedCallback& callback) {
38  DCHECK(CalledOnValidThread());
39  DCHECK(socket);
40  socket_ = socket;
41  write_failed_callback_ = callback;
42}
43
44bool BufferedSocketWriterBase::Write(
45    scoped_refptr<net::IOBufferWithSize> data, const base::Closure& done_task) {
46  DCHECK(CalledOnValidThread());
47  DCHECK(socket_);
48  DCHECK(data.get());
49
50  // Don't write after Close().
51  if (closed_)
52    return false;
53
54  queue_.push_back(new PendingPacket(data, done_task));
55  buffer_size_ += data->size();
56
57  DoWrite();
58
59  // DoWrite() may trigger OnWriteError() to be called.
60  return !closed_;
61}
62
63void BufferedSocketWriterBase::DoWrite() {
64  DCHECK(CalledOnValidThread());
65  DCHECK(socket_);
66
67  // Don't try to write if there is another write pending.
68  if (write_pending_)
69    return;
70
71  // Don't write after Close().
72  if (closed_)
73    return;
74
75  while (true) {
76    net::IOBuffer* current_packet;
77    int current_packet_size;
78    GetNextPacket(&current_packet, &current_packet_size);
79
80    // Return if the queue is empty.
81    if (!current_packet)
82      return;
83
84    int result = socket_->Write(
85        current_packet, current_packet_size,
86        base::Bind(&BufferedSocketWriterBase::OnWritten,
87                   base::Unretained(this)));
88    bool write_again = false;
89    HandleWriteResult(result, &write_again);
90    if (!write_again)
91      return;
92  }
93}
94
95void BufferedSocketWriterBase::HandleWriteResult(int result,
96                                                 bool* write_again) {
97  *write_again = false;
98  if (result < 0) {
99    if (result == net::ERR_IO_PENDING) {
100      write_pending_ = true;
101    } else {
102      HandleError(result);
103      if (!write_failed_callback_.is_null())
104        write_failed_callback_.Run(result);
105    }
106    return;
107  }
108
109  base::Closure done_task = AdvanceBufferPosition(result);
110  if (!done_task.is_null()) {
111    bool destroyed = false;
112    destroyed_flag_ = &destroyed;
113    done_task.Run();
114    if (destroyed) {
115      // Stop doing anything if we've been destroyed by the callback.
116      return;
117    }
118    destroyed_flag_ = NULL;
119  }
120
121  *write_again = true;
122}
123
124void BufferedSocketWriterBase::OnWritten(int result) {
125  DCHECK(CalledOnValidThread());
126  DCHECK(write_pending_);
127  write_pending_ = false;
128
129  bool write_again;
130  HandleWriteResult(result, &write_again);
131  if (write_again)
132    DoWrite();
133}
134
135void BufferedSocketWriterBase::HandleError(int result) {
136  DCHECK(CalledOnValidThread());
137
138  closed_ = true;
139
140  STLDeleteElements(&queue_);
141
142  // Notify subclass that an error is received.
143  OnError(result);
144}
145
146int BufferedSocketWriterBase::GetBufferSize() {
147  return buffer_size_;
148}
149
150int BufferedSocketWriterBase::GetBufferChunks() {
151  return queue_.size();
152}
153
154void BufferedSocketWriterBase::Close() {
155  DCHECK(CalledOnValidThread());
156  closed_ = true;
157}
158
159BufferedSocketWriterBase::~BufferedSocketWriterBase() {
160  if (destroyed_flag_)
161    *destroyed_flag_ = true;
162
163  STLDeleteElements(&queue_);
164}
165
166base::Closure BufferedSocketWriterBase::PopQueue() {
167  base::Closure result = queue_.front()->done_task;
168  delete queue_.front();
169  queue_.pop_front();
170  return result;
171}
172
173BufferedSocketWriter::BufferedSocketWriter() {
174}
175
176void BufferedSocketWriter::GetNextPacket(
177    net::IOBuffer** buffer, int* size) {
178  if (!current_buf_.get()) {
179    if (queue_.empty()) {
180      *buffer = NULL;
181      return;  // Nothing to write.
182    }
183    current_buf_ = new net::DrainableIOBuffer(queue_.front()->data.get(),
184                                              queue_.front()->data->size());
185  }
186
187  *buffer = current_buf_.get();
188  *size = current_buf_->BytesRemaining();
189}
190
191base::Closure BufferedSocketWriter::AdvanceBufferPosition(int written) {
192  buffer_size_ -= written;
193  current_buf_->DidConsume(written);
194
195  if (current_buf_->BytesRemaining() == 0) {
196    current_buf_ = NULL;
197    return PopQueue();
198  }
199  return base::Closure();
200}
201
202void BufferedSocketWriter::OnError(int result) {
203  current_buf_ = NULL;
204}
205
206BufferedSocketWriter::~BufferedSocketWriter() {
207}
208
209BufferedDatagramWriter::BufferedDatagramWriter() {
210}
211
212void BufferedDatagramWriter::GetNextPacket(
213    net::IOBuffer** buffer, int* size) {
214  if (queue_.empty()) {
215    *buffer = NULL;
216    return;  // Nothing to write.
217  }
218  *buffer = queue_.front()->data.get();
219  *size = queue_.front()->data->size();
220}
221
222base::Closure BufferedDatagramWriter::AdvanceBufferPosition(int written) {
223  DCHECK_EQ(written, queue_.front()->data->size());
224  buffer_size_ -= queue_.front()->data->size();
225  return PopQueue();
226}
227
228void BufferedDatagramWriter::OnError(int result) {
229  // Nothing to do here.
230}
231
232BufferedDatagramWriter::~BufferedDatagramWriter() {
233}
234
235}  // namespace protocol
236}  // namespace remoting
237