1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "remoting/protocol/channel_multiplexer.h"
6
7#include <string.h>
8
9#include "base/bind.h"
10#include "base/callback.h"
11#include "base/location.h"
12#include "base/single_thread_task_runner.h"
13#include "base/stl_util.h"
14#include "base/thread_task_runner_handle.h"
15#include "net/base/net_errors.h"
16#include "net/socket/stream_socket.h"
17#include "remoting/protocol/message_serialization.h"
18
19namespace remoting {
20namespace protocol {
21
22namespace {
23const int kChannelIdUnknown = -1;
24const int kMaxPacketSize = 1024;
25
26class PendingPacket {
27 public:
28  PendingPacket(scoped_ptr<MultiplexPacket> packet,
29                const base::Closure& done_task)
30      : packet(packet.Pass()),
31        done_task(done_task),
32        pos(0U) {
33  }
34  ~PendingPacket() {
35    done_task.Run();
36  }
37
38  bool is_empty() { return pos >= packet->data().size(); }
39
40  int Read(char* buffer, size_t size) {
41    size = std::min(size, packet->data().size() - pos);
42    memcpy(buffer, packet->data().data() + pos, size);
43    pos += size;
44    return size;
45  }
46
47 private:
48  scoped_ptr<MultiplexPacket> packet;
49  base::Closure done_task;
50  size_t pos;
51
52  DISALLOW_COPY_AND_ASSIGN(PendingPacket);
53};
54
55}  // namespace
56
// Public name constant declared in the class header. It is not referenced in
// this file; presumably callers use it as the name of the multiplexed base
// channel — verify against callers.
const char ChannelMultiplexer::kMuxChannelName[] = "mux";
58
59struct ChannelMultiplexer::PendingChannel {
60  PendingChannel(const std::string& name,
61                 const ChannelCreatedCallback& callback)
62      : name(name), callback(callback) {
63  }
64  std::string name;
65  ChannelCreatedCallback callback;
66};
67
68class ChannelMultiplexer::MuxChannel {
69 public:
70  MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name,
71             int send_id);
72  ~MuxChannel();
73
74  const std::string& name() { return name_; }
75  int receive_id() { return receive_id_; }
76  void set_receive_id(int id) { receive_id_ = id; }
77
78  // Called by ChannelMultiplexer.
79  scoped_ptr<net::StreamSocket> CreateSocket();
80  void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
81                        const base::Closure& done_task);
82  void OnWriteFailed();
83
84  // Called by MuxSocket.
85  void OnSocketDestroyed();
86  bool DoWrite(scoped_ptr<MultiplexPacket> packet,
87               const base::Closure& done_task);
88  int DoRead(net::IOBuffer* buffer, int buffer_len);
89
90 private:
91  ChannelMultiplexer* multiplexer_;
92  std::string name_;
93  int send_id_;
94  bool id_sent_;
95  int receive_id_;
96  MuxSocket* socket_;
97  std::list<PendingPacket*> pending_packets_;
98
99  DISALLOW_COPY_AND_ASSIGN(MuxChannel);
100};
101
102class ChannelMultiplexer::MuxSocket : public net::StreamSocket,
103                                      public base::NonThreadSafe,
104                                      public base::SupportsWeakPtr<MuxSocket> {
105 public:
106  MuxSocket(MuxChannel* channel);
107  virtual ~MuxSocket();
108
109  void OnWriteComplete();
110  void OnWriteFailed();
111  void OnPacketReceived();
112
113  // net::StreamSocket interface.
114  virtual int Read(net::IOBuffer* buffer, int buffer_len,
115                   const net::CompletionCallback& callback) OVERRIDE;
116  virtual int Write(net::IOBuffer* buffer, int buffer_len,
117                    const net::CompletionCallback& callback) OVERRIDE;
118
119  virtual int SetReceiveBufferSize(int32 size) OVERRIDE {
120    NOTIMPLEMENTED();
121    return net::ERR_NOT_IMPLEMENTED;
122  }
123  virtual int SetSendBufferSize(int32 size) OVERRIDE {
124    NOTIMPLEMENTED();
125    return net::ERR_NOT_IMPLEMENTED;
126  }
127
128  virtual int Connect(const net::CompletionCallback& callback) OVERRIDE {
129    NOTIMPLEMENTED();
130    return net::ERR_NOT_IMPLEMENTED;
131  }
132  virtual void Disconnect() OVERRIDE {
133    NOTIMPLEMENTED();
134  }
135  virtual bool IsConnected() const OVERRIDE {
136    NOTIMPLEMENTED();
137    return true;
138  }
139  virtual bool IsConnectedAndIdle() const OVERRIDE {
140    NOTIMPLEMENTED();
141    return false;
142  }
143  virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE {
144    NOTIMPLEMENTED();
145    return net::ERR_NOT_IMPLEMENTED;
146  }
147  virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE {
148    NOTIMPLEMENTED();
149    return net::ERR_NOT_IMPLEMENTED;
150  }
151  virtual const net::BoundNetLog& NetLog() const OVERRIDE {
152    NOTIMPLEMENTED();
153    return net_log_;
154  }
155  virtual void SetSubresourceSpeculation() OVERRIDE {
156    NOTIMPLEMENTED();
157  }
158  virtual void SetOmniboxSpeculation() OVERRIDE {
159    NOTIMPLEMENTED();
160  }
161  virtual bool WasEverUsed() const OVERRIDE {
162    return true;
163  }
164  virtual bool UsingTCPFastOpen() const OVERRIDE {
165    return false;
166  }
167  virtual bool WasNpnNegotiated() const OVERRIDE {
168    return false;
169  }
170  virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE {
171    return net::kProtoUnknown;
172  }
173  virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE {
174    NOTIMPLEMENTED();
175    return false;
176  }
177
178 private:
179  MuxChannel* channel_;
180
181  net::CompletionCallback read_callback_;
182  scoped_refptr<net::IOBuffer> read_buffer_;
183  int read_buffer_size_;
184
185  bool write_pending_;
186  int write_result_;
187  net::CompletionCallback write_callback_;
188
189  net::BoundNetLog net_log_;
190
191  DISALLOW_COPY_AND_ASSIGN(MuxSocket);
192};
193
194
195ChannelMultiplexer::MuxChannel::MuxChannel(
196    ChannelMultiplexer* multiplexer,
197    const std::string& name,
198    int send_id)
199    : multiplexer_(multiplexer),
200      name_(name),
201      send_id_(send_id),
202      id_sent_(false),
203      receive_id_(kChannelIdUnknown),
204      socket_(NULL) {
205}
206
207ChannelMultiplexer::MuxChannel::~MuxChannel() {
208  // Socket must be destroyed before the channel.
209  DCHECK(!socket_);
210  STLDeleteElements(&pending_packets_);
211}
212
213scoped_ptr<net::StreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() {
214  DCHECK(!socket_);  // Can't create more than one socket per channel.
215  scoped_ptr<MuxSocket> result(new MuxSocket(this));
216  socket_ = result.get();
217  return result.PassAs<net::StreamSocket>();
218}
219
220void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
221    scoped_ptr<MultiplexPacket> packet,
222    const base::Closure& done_task) {
223  DCHECK_EQ(packet->channel_id(), receive_id_);
224  if (packet->data().size() > 0) {
225    pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task));
226    if (socket_) {
227      // Notify the socket that we have more data.
228      socket_->OnPacketReceived();
229    }
230  }
231}
232
233void ChannelMultiplexer::MuxChannel::OnWriteFailed() {
234  if (socket_)
235    socket_->OnWriteFailed();
236}
237
238void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
239  DCHECK(socket_);
240  socket_ = NULL;
241}
242
243bool ChannelMultiplexer::MuxChannel::DoWrite(
244    scoped_ptr<MultiplexPacket> packet,
245    const base::Closure& done_task) {
246  packet->set_channel_id(send_id_);
247  if (!id_sent_) {
248    packet->set_channel_name(name_);
249    id_sent_ = true;
250  }
251  return multiplexer_->DoWrite(packet.Pass(), done_task);
252}
253
254int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer* buffer,
255                                           int buffer_len) {
256  int pos = 0;
257  while (buffer_len > 0 && !pending_packets_.empty()) {
258    DCHECK(!pending_packets_.front()->is_empty());
259    int result = pending_packets_.front()->Read(
260        buffer->data() + pos, buffer_len);
261    DCHECK_LE(result, buffer_len);
262    pos += result;
263    buffer_len -= pos;
264    if (pending_packets_.front()->is_empty()) {
265      delete pending_packets_.front();
266      pending_packets_.erase(pending_packets_.begin());
267    }
268  }
269  return pos;
270}
271
272ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel)
273    : channel_(channel),
274      read_buffer_size_(0),
275      write_pending_(false),
276      write_result_(0) {
277}
278
279ChannelMultiplexer::MuxSocket::~MuxSocket() {
280  channel_->OnSocketDestroyed();
281}
282
283int ChannelMultiplexer::MuxSocket::Read(
284    net::IOBuffer* buffer, int buffer_len,
285    const net::CompletionCallback& callback) {
286  DCHECK(CalledOnValidThread());
287  DCHECK(read_callback_.is_null());
288
289  int result = channel_->DoRead(buffer, buffer_len);
290  if (result == 0) {
291    read_buffer_ = buffer;
292    read_buffer_size_ = buffer_len;
293    read_callback_ = callback;
294    return net::ERR_IO_PENDING;
295  }
296  return result;
297}
298
299int ChannelMultiplexer::MuxSocket::Write(
300    net::IOBuffer* buffer, int buffer_len,
301    const net::CompletionCallback& callback) {
302  DCHECK(CalledOnValidThread());
303
304  scoped_ptr<MultiplexPacket> packet(new MultiplexPacket());
305  size_t size = std::min(kMaxPacketSize, buffer_len);
306  packet->mutable_data()->assign(buffer->data(), size);
307
308  write_pending_ = true;
309  bool result = channel_->DoWrite(packet.Pass(), base::Bind(
310      &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr()));
311
312  if (!result) {
313    // Cannot complete the write, e.g. if the connection has been terminated.
314    return net::ERR_FAILED;
315  }
316
317  // OnWriteComplete() might be called above synchronously.
318  if (write_pending_) {
319    DCHECK(write_callback_.is_null());
320    write_callback_ = callback;
321    write_result_ = size;
322    return net::ERR_IO_PENDING;
323  }
324
325  return size;
326}
327
328void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
329  write_pending_ = false;
330  if (!write_callback_.is_null()) {
331    net::CompletionCallback cb;
332    std::swap(cb, write_callback_);
333    cb.Run(write_result_);
334  }
335}
336
337void ChannelMultiplexer::MuxSocket::OnWriteFailed() {
338  if (!write_callback_.is_null()) {
339    net::CompletionCallback cb;
340    std::swap(cb, write_callback_);
341    cb.Run(net::ERR_FAILED);
342  }
343}
344
345void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
346  if (!read_callback_.is_null()) {
347    int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_);
348    read_buffer_ = NULL;
349    DCHECK_GT(result, 0);
350    net::CompletionCallback cb;
351    std::swap(cb, read_callback_);
352    cb.Run(result);
353  }
354}
355
356ChannelMultiplexer::ChannelMultiplexer(StreamChannelFactory* factory,
357                                       const std::string& base_channel_name)
358    : base_channel_factory_(factory),
359      base_channel_name_(base_channel_name),
360      next_channel_id_(0),
361      weak_factory_(this) {
362}
363
364ChannelMultiplexer::~ChannelMultiplexer() {
365  DCHECK(pending_channels_.empty());
366  STLDeleteValues(&channels_);
367
368  // Cancel creation of the base channel if it hasn't finished.
369  if (base_channel_factory_)
370    base_channel_factory_->CancelChannelCreation(base_channel_name_);
371}
372
373void ChannelMultiplexer::CreateChannel(const std::string& name,
374                                       const ChannelCreatedCallback& callback) {
375  if (base_channel_.get()) {
376    // Already have |base_channel_|. Create new multiplexed channel
377    // synchronously.
378    callback.Run(GetOrCreateChannel(name)->CreateSocket());
379  } else if (!base_channel_.get() && !base_channel_factory_) {
380    // Fail synchronously if we failed to create |base_channel_|.
381    callback.Run(scoped_ptr<net::StreamSocket>());
382  } else {
383    // Still waiting for the |base_channel_|.
384    pending_channels_.push_back(PendingChannel(name, callback));
385
386    // If this is the first multiplexed channel then create the base channel.
387    if (pending_channels_.size() == 1U) {
388      base_channel_factory_->CreateChannel(
389          base_channel_name_,
390          base::Bind(&ChannelMultiplexer::OnBaseChannelReady,
391                     base::Unretained(this)));
392    }
393  }
394}
395
396void ChannelMultiplexer::CancelChannelCreation(const std::string& name) {
397  for (std::list<PendingChannel>::iterator it = pending_channels_.begin();
398       it != pending_channels_.end(); ++it) {
399    if (it->name == name) {
400      pending_channels_.erase(it);
401      return;
402    }
403  }
404}
405
406void ChannelMultiplexer::OnBaseChannelReady(
407    scoped_ptr<net::StreamSocket> socket) {
408  base_channel_factory_ = NULL;
409  base_channel_ = socket.Pass();
410
411  if (base_channel_.get()) {
412    // Initialize reader and writer.
413    reader_.Init(base_channel_.get(),
414                 base::Bind(&ChannelMultiplexer::OnIncomingPacket,
415                            base::Unretained(this)));
416    writer_.Init(base_channel_.get(),
417                 base::Bind(&ChannelMultiplexer::OnWriteFailed,
418                            base::Unretained(this)));
419  }
420
421  DoCreatePendingChannels();
422}
423
// Answers queued CreateChannel() requests one at a time after the base channel
// attempt has finished. Each call hands the oldest request a socket, or a NULL
// socket when |base_channel_| is NULL, and posts itself again to handle the
// remaining requests.
void ChannelMultiplexer::DoCreatePendingChannels() {
  if (pending_channels_.empty())
    return;

  // Every time this function is called it connects a single channel and posts a
  // separate task to connect other channels. This is necessary because the
  // callback may destroy the multiplexer or somehow else modify
  // |pending_channels_| list (e.g. call CancelChannelCreation()).
  // The task is bound to a weak pointer so that it is dropped if the
  // multiplexer has been destroyed by then.
  base::ThreadTaskRunnerHandle::Get()->PostTask(
      FROM_HERE, base::Bind(&ChannelMultiplexer::DoCreatePendingChannels,
                            weak_factory_.GetWeakPtr()));

  // Copy the request and remove it from the queue before running its
  // callback, so nothing here touches |this| after the callback returns.
  PendingChannel c = pending_channels_.front();
  pending_channels_.erase(pending_channels_.begin());
  scoped_ptr<net::StreamSocket> socket;
  if (base_channel_.get())
    socket = GetOrCreateChannel(c.name)->CreateSocket();
  c.callback.Run(socket.Pass());
}
443
444ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel(
445    const std::string& name) {
446  // Check if we already have a channel with the requested name.
447  std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
448  if (it != channels_.end())
449    return it->second;
450
451  // Create a new channel if we haven't found existing one.
452  MuxChannel* channel = new MuxChannel(this, name, next_channel_id_);
453  ++next_channel_id_;
454  channels_[channel->name()] = channel;
455  return channel;
456}
457
458
459void ChannelMultiplexer::OnWriteFailed(int error) {
460  for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin();
461       it != channels_.end(); ++it) {
462    base::ThreadTaskRunnerHandle::Get()->PostTask(
463        FROM_HERE, base::Bind(&ChannelMultiplexer::NotifyWriteFailed,
464                              weak_factory_.GetWeakPtr(), it->second->name()));
465  }
466}
467
468void ChannelMultiplexer::NotifyWriteFailed(const std::string& name) {
469  std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
470  if (it != channels_.end()) {
471    it->second->OnWriteFailed();
472  }
473}
474
475void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
476                                          const base::Closure& done_task) {
477  DCHECK(packet->has_channel_id());
478  if (!packet->has_channel_id()) {
479    LOG(ERROR) << "Received packet without channel_id.";
480    done_task.Run();
481    return;
482  }
483
484  int receive_id = packet->channel_id();
485  MuxChannel* channel = NULL;
486  std::map<int, MuxChannel*>::iterator it =
487      channels_by_receive_id_.find(receive_id);
488  if (it != channels_by_receive_id_.end()) {
489    channel = it->second;
490  } else {
491    // This is a new |channel_id| we haven't seen before. Look it up by name.
492    if (!packet->has_channel_name()) {
493      LOG(ERROR) << "Received packet with unknown channel_id and "
494          "without channel_name.";
495      done_task.Run();
496      return;
497    }
498    channel = GetOrCreateChannel(packet->channel_name());
499    channel->set_receive_id(receive_id);
500    channels_by_receive_id_[receive_id] = channel;
501  }
502
503  channel->OnIncomingPacket(packet.Pass(), done_task);
504}
505
506bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet,
507                                 const base::Closure& done_task) {
508  return writer_.Write(SerializeAndFrameMessage(*packet), done_task);
509}
510
511}  // namespace protocol
512}  // namespace remoting
513