1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "content/browser/devtools/tethering_handler.h"
6
7#include "base/bind.h"
8#include "base/callback.h"
9#include "base/stl_util.h"
10#include "base/values.h"
11#include "content/browser/devtools/devtools_http_handler_impl.h"
12#include "content/browser/devtools/devtools_protocol_constants.h"
13#include "content/public/browser/devtools_http_handler_delegate.h"
14#include "net/base/io_buffer.h"
15#include "net/base/ip_endpoint.h"
16#include "net/base/net_errors.h"
17#include "net/base/net_log.h"
18#include "net/socket/stream_listen_socket.h"
19#include "net/socket/stream_socket.h"
20#include "net/socket/tcp_server_socket.h"
21
22namespace content {
23
24namespace {
25
// Inclusive range of ports accepted by Tethering.bind / Tethering.unbind.
const int kMinTetheringPort = 1024;
const int kMaxTetheringPort = 32767;

// Tethered ports are bound on the IPv4 loopback address only.
const char kLocalhost[] = "127.0.0.1";

// Backlog handed to ServerSocket::Listen() for every bound port.
const int kListenBacklog = 5;

// Chunk size for client reads/writes in SocketPump, and the initial capacity
// of its outgoing queue.
const int kBufferSize = 16 * 1024;
33
34class SocketPump : public net::StreamListenSocket::Delegate {
35 public:
36  SocketPump(DevToolsHttpHandlerDelegate* delegate,
37             net::StreamSocket* client_socket)
38      : client_socket_(client_socket),
39        delegate_(delegate),
40        wire_buffer_size_(0),
41        pending_destruction_(false) {
42  }
43
44  std::string Init() {
45    std::string channel_name;
46    server_socket_ = delegate_->CreateSocketForTethering(this, &channel_name);
47    if (!server_socket_.get() || channel_name.empty())
48      SelfDestruct();
49    return channel_name;
50  }
51
52  virtual ~SocketPump() { }
53
54 private:
55  virtual void DidAccept(net::StreamListenSocket* server,
56                         scoped_ptr<net::StreamListenSocket> socket) OVERRIDE {
57    if (accepted_socket_.get())
58      return;
59
60    buffer_ = new net::IOBuffer(kBufferSize);
61    wire_buffer_ = new net::GrowableIOBuffer();
62    wire_buffer_->SetCapacity(kBufferSize);
63
64    accepted_socket_ = socket.Pass();
65    int result = client_socket_->Read(
66        buffer_.get(),
67        kBufferSize,
68        base::Bind(&SocketPump::OnClientRead, base::Unretained(this)));
69    if (result != net::ERR_IO_PENDING)
70      OnClientRead(result);
71  }
72
73  virtual void DidRead(net::StreamListenSocket* socket,
74                       const char* data,
75                       int len) OVERRIDE {
76    int old_size = wire_buffer_size_;
77    wire_buffer_size_ += len;
78    while (wire_buffer_->capacity() < wire_buffer_size_)
79      wire_buffer_->SetCapacity(wire_buffer_->capacity() * 2);
80    memcpy(wire_buffer_->StartOfBuffer() + old_size, data, len);
81    if (old_size != wire_buffer_->offset())
82      return;
83    OnClientWrite(0);
84  }
85
86  virtual void DidClose(net::StreamListenSocket* socket) OVERRIDE {
87    SelfDestruct();
88  }
89
90  void OnClientRead(int result) {
91    if (result <= 0) {
92      SelfDestruct();
93      return;
94    }
95
96    accepted_socket_->Send(buffer_->data(), result);
97    result = client_socket_->Read(
98        buffer_.get(),
99        kBufferSize,
100        base::Bind(&SocketPump::OnClientRead, base::Unretained(this)));
101    if (result != net::ERR_IO_PENDING)
102      OnClientRead(result);
103  }
104
105  void OnClientWrite(int result) {
106    if (result < 0) {
107      SelfDestruct();
108      return;
109    }
110
111    wire_buffer_->set_offset(wire_buffer_->offset() + result);
112
113    int remaining = wire_buffer_size_ - wire_buffer_->offset();
114    if (remaining == 0) {
115      if (pending_destruction_)
116        SelfDestruct();
117      return;
118    }
119
120
121    if (remaining > kBufferSize)
122      remaining = kBufferSize;
123
124    scoped_refptr<net::IOBuffer> buffer = new net::IOBuffer(remaining);
125    memcpy(buffer->data(), wire_buffer_->data(), remaining);
126    result = client_socket_->Write(
127        buffer.get(),
128        remaining,
129        base::Bind(&SocketPump::OnClientWrite, base::Unretained(this)));
130
131    // Shrink buffer
132    int offset = wire_buffer_->offset();
133    if (offset > kBufferSize) {
134      memcpy(wire_buffer_->StartOfBuffer(), wire_buffer_->data(),
135          wire_buffer_size_ - offset);
136      wire_buffer_size_ -= offset;
137      wire_buffer_->set_offset(0);
138    }
139
140    if (result != net::ERR_IO_PENDING)
141      OnClientWrite(result);
142    return;
143  }
144
145  void SelfDestruct() {
146    if (wire_buffer_.get() && wire_buffer_->offset() != wire_buffer_size_) {
147      pending_destruction_ = true;
148      return;
149    }
150    delete this;
151  }
152
153 private:
154  scoped_ptr<net::StreamSocket> client_socket_;
155  scoped_ptr<net::StreamListenSocket> server_socket_;
156  scoped_ptr<net::StreamListenSocket> accepted_socket_;
157  scoped_refptr<net::IOBuffer> buffer_;
158  scoped_refptr<net::GrowableIOBuffer> wire_buffer_;
159  DevToolsHttpHandlerDelegate* delegate_;
160  int wire_buffer_size_;
161  bool pending_destruction_;
162};
163
164}  // namespace
165
166class TetheringHandler::BoundSocket {
167 public:
168  BoundSocket(TetheringHandler* handler,
169              DevToolsHttpHandlerDelegate* delegate)
170      : handler_(handler),
171        delegate_(delegate),
172        socket_(new net::TCPServerSocket(NULL, net::NetLog::Source())),
173        port_(0) {
174  }
175
176  virtual ~BoundSocket() {
177  }
178
179  bool Listen(int port) {
180    port_ = port;
181    net::IPAddressNumber ip_number;
182    if (!net::ParseIPLiteralToNumber(kLocalhost, &ip_number))
183      return false;
184
185    net::IPEndPoint end_point(ip_number, port);
186    int result = socket_->Listen(end_point, kListenBacklog);
187    if (result < 0)
188      return false;
189
190    net::IPEndPoint local_address;
191    result = socket_->GetLocalAddress(&local_address);
192    if (result < 0)
193      return false;
194
195    DoAccept();
196    return true;
197  }
198
199 private:
200  typedef std::map<net::IPEndPoint, net::StreamSocket*> AcceptedSocketsMap;
201
202  void DoAccept() {
203    while (true) {
204      int result = socket_->Accept(
205          &accept_socket_,
206          base::Bind(&BoundSocket::OnAccepted, base::Unretained(this)));
207      if (result == net::ERR_IO_PENDING)
208        break;
209      else
210        HandleAcceptResult(result);
211    }
212  }
213
214  void OnAccepted(int result) {
215    HandleAcceptResult(result);
216    if (result == net::OK)
217      DoAccept();
218  }
219
220  void HandleAcceptResult(int result) {
221    if (result != net::OK)
222      return;
223
224    SocketPump* pump = new SocketPump(delegate_, accept_socket_.release());
225    std::string name = pump->Init();
226    if (!name.empty())
227      handler_->Accepted(port_, name);
228  }
229
230  TetheringHandler* handler_;
231  DevToolsHttpHandlerDelegate* delegate_;
232  scoped_ptr<net::ServerSocket> socket_;
233  scoped_ptr<net::StreamSocket> accept_socket_;
234  int port_;
235};
236
237TetheringHandler::TetheringHandler(DevToolsHttpHandlerDelegate* delegate)
238    : delegate_(delegate) {
239  RegisterCommandHandler(devtools::Tethering::bind::kName,
240                         base::Bind(&TetheringHandler::OnBind,
241                                    base::Unretained(this)));
242  RegisterCommandHandler(devtools::Tethering::unbind::kName,
243                         base::Bind(&TetheringHandler::OnUnbind,
244                                    base::Unretained(this)));
245}
246
247TetheringHandler::~TetheringHandler() {
248  STLDeleteContainerPairSecondPointers(bound_sockets_.begin(),
249                                       bound_sockets_.end());
250}
251
252void TetheringHandler::Accepted(int port, const std::string& name) {
253  base::DictionaryValue* params = new base::DictionaryValue();
254  params->SetInteger(devtools::Tethering::accepted::kParamPort, port);
255  params->SetString(devtools::Tethering::accepted::kParamConnectionId, name);
256  SendNotification(devtools::Tethering::accepted::kName, params);
257}
258
259static int GetPort(scoped_refptr<DevToolsProtocol::Command> command,
260                   const std::string& paramName) {
261  base::DictionaryValue* params = command->params();
262  int port = 0;
263  if (!params ||
264      !params->GetInteger(paramName, &port) ||
265      port < kMinTetheringPort || port > kMaxTetheringPort)
266    return 0;
267  return port;
268}
269
270scoped_refptr<DevToolsProtocol::Response>
271TetheringHandler::OnBind(scoped_refptr<DevToolsProtocol::Command> command) {
272  const std::string& portParamName = devtools::Tethering::bind::kParamPort;
273  int port = GetPort(command, portParamName);
274  if (port == 0)
275    return command->InvalidParamResponse(portParamName);
276
277  if (bound_sockets_.find(port) != bound_sockets_.end())
278    return command->InternalErrorResponse("Port already bound");
279
280  scoped_ptr<BoundSocket> bound_socket(new BoundSocket(this, delegate_));
281  if (!bound_socket->Listen(port))
282    return command->InternalErrorResponse("Could not bind port");
283
284  bound_sockets_[port] = bound_socket.release();
285  return command->SuccessResponse(NULL);
286}
287
288scoped_refptr<DevToolsProtocol::Response>
289TetheringHandler::OnUnbind(scoped_refptr<DevToolsProtocol::Command> command) {
290  const std::string& portParamName = devtools::Tethering::unbind::kParamPort;
291  int port = GetPort(command, portParamName);
292  if (port == 0)
293    return command->InvalidParamResponse(portParamName);
294
295  BoundSockets::iterator it = bound_sockets_.find(port);
296  if (it == bound_sockets_.end())
297    return command->InternalErrorResponse("Port is not bound");
298
299  delete it->second;
300  bound_sockets_.erase(it);
301  return command->SuccessResponse(NULL);
302}
303
304}  // namespace content
305