1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "chrome/browser/devtools/device/port_forwarding_controller.h"
6
7#include <algorithm>
8#include <map>
9
10#include "base/bind.h"
11#include "base/compiler_specific.h"
12#include "base/memory/singleton.h"
13#include "base/message_loop/message_loop.h"
14#include "base/prefs/pref_service.h"
15#include "base/strings/string_number_conversions.h"
16#include "base/strings/string_util.h"
17#include "base/strings/stringprintf.h"
18#include "base/threading/non_thread_safe.h"
19#include "chrome/browser/devtools/devtools_protocol.h"
20#include "chrome/browser/devtools/devtools_protocol_constants.h"
21#include "chrome/browser/profiles/profile.h"
22#include "chrome/common/pref_names.h"
23#include "components/keyed_service/content/browser_context_dependency_manager.h"
24#include "content/public/browser/browser_thread.h"
25#include "net/base/address_list.h"
26#include "net/base/io_buffer.h"
27#include "net/base/net_errors.h"
28#include "net/base/net_util.h"
29#include "net/dns/host_resolver.h"
30#include "net/socket/tcp_client_socket.h"
31
32using content::BrowserThread;
33
34namespace {
35
36const int kBufferSize = 16 * 1024;
37
38enum {
39  kStatusError = -3,
40  kStatusDisconnecting = -2,
41  kStatusConnecting = -1,
42  kStatusOK = 0,
43  // Positive values are used to count open connections.
44};
45
46namespace tethering = ::chrome::devtools::Tethering;
47
48static const char kDevToolsRemoteBrowserTarget[] = "/devtools/browser";
49const int kMinVersionPortForwarding = 28;
50
51class SocketTunnel : public base::NonThreadSafe {
52 public:
53  typedef base::Callback<void(int)> CounterCallback;
54
55  static void StartTunnel(const std::string& host,
56                          int port,
57                          const CounterCallback& callback,
58                          int result,
59                          scoped_ptr<net::StreamSocket> socket) {
60    if (result < 0)
61      return;
62    SocketTunnel* tunnel = new SocketTunnel(callback);
63    tunnel->Start(socket.Pass(), host, port);
64  }
65
66 private:
67  explicit SocketTunnel(const CounterCallback& callback)
68      : pending_writes_(0),
69        pending_destruction_(false),
70        callback_(callback),
71        about_to_destroy_(false) {
72    callback_.Run(1);
73  }
74
75  void Start(scoped_ptr<net::StreamSocket> socket,
76             const std::string& host, int port) {
77    remote_socket_.swap(socket);
78
79    host_resolver_ = net::HostResolver::CreateDefaultResolver(NULL);
80    net::HostResolver::RequestInfo request_info(net::HostPortPair(host, port));
81    int result = host_resolver_->Resolve(
82        request_info,
83        net::DEFAULT_PRIORITY,
84        &address_list_,
85        base::Bind(&SocketTunnel::OnResolved, base::Unretained(this)),
86        NULL,
87        net::BoundNetLog());
88    if (result != net::ERR_IO_PENDING)
89      OnResolved(result);
90  }
91
92  void OnResolved(int result) {
93    if (result < 0) {
94      SelfDestruct();
95      return;
96    }
97
98    host_socket_.reset(new net::TCPClientSocket(address_list_, NULL,
99                                                net::NetLog::Source()));
100    result = host_socket_->Connect(base::Bind(&SocketTunnel::OnConnected,
101                                              base::Unretained(this)));
102    if (result != net::ERR_IO_PENDING)
103      OnConnected(result);
104  }
105
106  ~SocketTunnel() {
107    about_to_destroy_ = true;
108    if (host_socket_)
109      host_socket_->Disconnect();
110    if (remote_socket_)
111      remote_socket_->Disconnect();
112    callback_.Run(-1);
113  }
114
115  void OnConnected(int result) {
116    if (result < 0) {
117      SelfDestruct();
118      return;
119    }
120
121    ++pending_writes_; // avoid SelfDestruct in first Pump
122    Pump(host_socket_.get(), remote_socket_.get());
123    --pending_writes_;
124    if (pending_destruction_) {
125      SelfDestruct();
126    } else {
127      Pump(remote_socket_.get(), host_socket_.get());
128    }
129  }
130
131  void Pump(net::StreamSocket* from, net::StreamSocket* to) {
132    scoped_refptr<net::IOBuffer> buffer = new net::IOBuffer(kBufferSize);
133    int result = from->Read(
134        buffer.get(),
135        kBufferSize,
136        base::Bind(
137            &SocketTunnel::OnRead, base::Unretained(this), from, to, buffer));
138    if (result != net::ERR_IO_PENDING)
139      OnRead(from, to, buffer, result);
140  }
141
142  void OnRead(net::StreamSocket* from,
143              net::StreamSocket* to,
144              scoped_refptr<net::IOBuffer> buffer,
145              int result) {
146    if (result <= 0) {
147      SelfDestruct();
148      return;
149    }
150
151    int total = result;
152    scoped_refptr<net::DrainableIOBuffer> drainable =
153        new net::DrainableIOBuffer(buffer.get(), total);
154
155    ++pending_writes_;
156    result = to->Write(drainable.get(),
157                       total,
158                       base::Bind(&SocketTunnel::OnWritten,
159                                  base::Unretained(this),
160                                  drainable,
161                                  from,
162                                  to));
163    if (result != net::ERR_IO_PENDING)
164      OnWritten(drainable, from, to, result);
165  }
166
167  void OnWritten(scoped_refptr<net::DrainableIOBuffer> drainable,
168                 net::StreamSocket* from,
169                 net::StreamSocket* to,
170                 int result) {
171    --pending_writes_;
172    if (result < 0) {
173      SelfDestruct();
174      return;
175    }
176
177    drainable->DidConsume(result);
178    if (drainable->BytesRemaining() > 0) {
179      ++pending_writes_;
180      result = to->Write(drainable.get(),
181                         drainable->BytesRemaining(),
182                         base::Bind(&SocketTunnel::OnWritten,
183                                    base::Unretained(this),
184                                    drainable,
185                                    from,
186                                    to));
187      if (result != net::ERR_IO_PENDING)
188        OnWritten(drainable, from, to, result);
189      return;
190    }
191
192    if (pending_destruction_) {
193      SelfDestruct();
194      return;
195    }
196    Pump(from, to);
197  }
198
199  void SelfDestruct() {
200    // In case one of the connections closes, we could get here
201    // from another one due to Disconnect firing back on all
202    // read callbacks.
203    if (about_to_destroy_)
204      return;
205    if (pending_writes_ > 0) {
206      pending_destruction_ = true;
207      return;
208    }
209    delete this;
210  }
211
212  scoped_ptr<net::StreamSocket> remote_socket_;
213  scoped_ptr<net::StreamSocket> host_socket_;
214  scoped_ptr<net::HostResolver> host_resolver_;
215  net::AddressList address_list_;
216  int pending_writes_;
217  bool pending_destruction_;
218  CounterCallback callback_;
219  bool about_to_destroy_;
220};
221
222typedef DevToolsAndroidBridge::RemoteBrowser::ParsedVersion ParsedVersion;
223
224static bool IsVersionLower(const ParsedVersion& left,
225                           const ParsedVersion& right) {
226  return std::lexicographical_compare(
227    left.begin(), left.end(), right.begin(), right.end());
228}
229
230static bool IsPortForwardingSupported(const ParsedVersion& version) {
231  return !version.empty() && version[0] >= kMinVersionPortForwarding;
232}
233
234static scoped_refptr<DevToolsAndroidBridge::RemoteBrowser>
235FindBestBrowserForTethering(
236    const DevToolsAndroidBridge::RemoteBrowsers browsers) {
237  scoped_refptr<DevToolsAndroidBridge::RemoteBrowser> best_browser;
238  ParsedVersion newest_version;
239  for (DevToolsAndroidBridge::RemoteBrowsers::const_iterator it =
240      browsers.begin(); it != browsers.end(); ++it) {
241    scoped_refptr<DevToolsAndroidBridge::RemoteBrowser> browser = *it;
242    ParsedVersion current_version = browser->GetParsedVersion();
243    if (IsPortForwardingSupported(current_version) &&
244        IsVersionLower(newest_version, current_version)) {
245      best_browser = browser;
246      newest_version = current_version;
247    }
248  }
249  return best_browser;
250}
251
252}  // namespace
253
254class PortForwardingController::Connection
255    : public DevToolsAndroidBridge::AndroidWebSocket::Delegate {
256 public:
257  Connection(Registry* registry,
258             scoped_refptr<DevToolsAndroidBridge::RemoteDevice> device,
259             scoped_refptr<DevToolsAndroidBridge::RemoteBrowser> browser,
260             const ForwardingMap& forwarding_map);
261  virtual ~Connection();
262
263  const PortStatusMap& GetPortStatusMap();
264
265  void UpdateForwardingMap(const ForwardingMap& new_forwarding_map);
266
267 private:
268  friend struct content::BrowserThread::DeleteOnThread<
269      content::BrowserThread::UI>;
270  friend class base::DeleteHelper<Connection>;
271
272
273  typedef std::map<int, std::string> ForwardingMap;
274
275  typedef base::Callback<void(PortStatus)> CommandCallback;
276  typedef std::map<int, CommandCallback> CommandCallbackMap;
277
278  void SerializeChanges(const std::string& method,
279                        const ForwardingMap& old_map,
280                        const ForwardingMap& new_map);
281
282  void SendCommand(const std::string& method, int port);
283  bool ProcessResponse(const std::string& json);
284
285  void ProcessBindResponse(int port, PortStatus status);
286  void ProcessUnbindResponse(int port, PortStatus status);
287
288  static void UpdateSocketCountOnHandlerThread(
289      base::WeakPtr<Connection> weak_connection, int port, int increment);
290  void UpdateSocketCount(int port, int increment);
291
292  // DevToolsAndroidBridge::AndroidWebSocket::Delegate implementation:
293  virtual void OnSocketOpened() OVERRIDE;
294  virtual void OnFrameRead(const std::string& message) OVERRIDE;
295  virtual void OnSocketClosed() OVERRIDE;
296
297  PortForwardingController::Registry* registry_;
298  scoped_refptr<DevToolsAndroidBridge::RemoteDevice> device_;
299  scoped_refptr<DevToolsAndroidBridge::RemoteBrowser> browser_;
300  scoped_ptr<DevToolsAndroidBridge::AndroidWebSocket> web_socket_;
301  int command_id_;
302  bool connected_;
303  ForwardingMap forwarding_map_;
304  CommandCallbackMap pending_responses_;
305  PortStatusMap port_status_;
306  base::WeakPtrFactory<Connection> weak_factory_;
307
308  DISALLOW_COPY_AND_ASSIGN(Connection);
309};
310
311PortForwardingController::Connection::Connection(
312    Registry* registry,
313    scoped_refptr<DevToolsAndroidBridge::RemoteDevice> device,
314    scoped_refptr<DevToolsAndroidBridge::RemoteBrowser> browser,
315    const ForwardingMap& forwarding_map)
316    : registry_(registry),
317      device_(device),
318      browser_(browser),
319      command_id_(0),
320      connected_(false),
321      forwarding_map_(forwarding_map),
322      weak_factory_(this) {
323  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
324  (*registry_)[device_->serial()] = this;
325  web_socket_.reset(
326      browser->CreateWebSocket(kDevToolsRemoteBrowserTarget, this));
327}
328
329PortForwardingController::Connection::~Connection() {
330  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
331  DCHECK(registry_->find(device_->serial()) != registry_->end());
332  registry_->erase(device_->serial());
333}
334
335void PortForwardingController::Connection::UpdateForwardingMap(
336    const ForwardingMap& new_forwarding_map) {
337  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
338  if (connected_) {
339    SerializeChanges(tethering::unbind::kName,
340        new_forwarding_map, forwarding_map_);
341    SerializeChanges(tethering::bind::kName,
342        forwarding_map_, new_forwarding_map);
343  }
344  forwarding_map_ = new_forwarding_map;
345}
346
347void PortForwardingController::Connection::SerializeChanges(
348    const std::string& method,
349    const ForwardingMap& old_map,
350    const ForwardingMap& new_map) {
351  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
352  for (ForwardingMap::const_iterator new_it(new_map.begin());
353      new_it != new_map.end(); ++new_it) {
354    int port = new_it->first;
355    const std::string& location = new_it->second;
356    ForwardingMap::const_iterator old_it = old_map.find(port);
357    if (old_it != old_map.end() && old_it->second == location)
358      continue;  // The port points to the same location in both configs, skip.
359
360    SendCommand(method, port);
361  }
362}
363
364void PortForwardingController::Connection::SendCommand(
365    const std::string& method, int port) {
366  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
367  base::DictionaryValue params;
368  if (method == tethering::bind::kName) {
369    params.SetInteger(tethering::bind::kParamPort, port);
370  } else {
371    DCHECK_EQ(tethering::unbind::kName, method);
372    params.SetInteger(tethering::unbind::kParamPort, port);
373  }
374  DevToolsProtocol::Command command(++command_id_, method, &params);
375
376  if (method == tethering::bind::kName) {
377    pending_responses_[command.id()] =
378        base::Bind(&Connection::ProcessBindResponse,
379                   base::Unretained(this), port);
380#if defined(DEBUG_DEVTOOLS)
381    port_status_[port] = kStatusConnecting;
382#endif  // defined(DEBUG_DEVTOOLS)
383  } else {
384    PortStatusMap::iterator it = port_status_.find(port);
385    if (it != port_status_.end() && it->second == kStatusError) {
386      // The bind command failed on this port, do not attempt unbind.
387      port_status_.erase(it);
388      return;
389    }
390
391    pending_responses_[command.id()] =
392        base::Bind(&Connection::ProcessUnbindResponse,
393                   base::Unretained(this), port);
394#if defined(DEBUG_DEVTOOLS)
395    port_status_[port] = kStatusDisconnecting;
396#endif  // defined(DEBUG_DEVTOOLS)
397  }
398
399  web_socket_->SendFrame(command.Serialize());
400}
401
402bool PortForwardingController::Connection::ProcessResponse(
403    const std::string& message) {
404  scoped_ptr<DevToolsProtocol::Response> response(
405      DevToolsProtocol::ParseResponse(message));
406  if (!response)
407    return false;
408
409  CommandCallbackMap::iterator it = pending_responses_.find(response->id());
410  if (it == pending_responses_.end())
411    return false;
412
413  it->second.Run(response->error_code() ? kStatusError : kStatusOK);
414  pending_responses_.erase(it);
415  return true;
416}
417
418void PortForwardingController::Connection::ProcessBindResponse(
419    int port, PortStatus status) {
420  port_status_[port] = status;
421}
422
423void PortForwardingController::Connection::ProcessUnbindResponse(
424    int port, PortStatus status) {
425  PortStatusMap::iterator it = port_status_.find(port);
426  if (it == port_status_.end())
427    return;
428  if (status == kStatusError)
429    it->second = status;
430  else
431    port_status_.erase(it);
432}
433
434// static
435void PortForwardingController::Connection::UpdateSocketCountOnHandlerThread(
436    base::WeakPtr<Connection> weak_connection, int port, int increment) {
437  BrowserThread::PostTask(BrowserThread::UI, FROM_HERE,
438     base::Bind(&Connection::UpdateSocketCount,
439                weak_connection, port, increment));
440}
441
442void PortForwardingController::Connection::UpdateSocketCount(
443    int port, int increment) {
444#if defined(DEBUG_DEVTOOLS)
445  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
446  PortStatusMap::iterator it = port_status_.find(port);
447  if (it == port_status_.end())
448    return;
449  if (it->second < 0 || (it->second == 0 && increment < 0))
450    return;
451  it->second += increment;
452#endif  // defined(DEBUG_DEVTOOLS)
453}
454
455const PortForwardingController::PortStatusMap&
456PortForwardingController::Connection::GetPortStatusMap() {
457  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
458  return port_status_;
459}
460
461void PortForwardingController::Connection::OnSocketOpened() {
462  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
463  connected_ = true;
464  SerializeChanges(tethering::bind::kName, ForwardingMap(), forwarding_map_);
465}
466
467void PortForwardingController::Connection::OnSocketClosed() {
468  delete this;
469}
470
471void PortForwardingController::Connection::OnFrameRead(
472    const std::string& message) {
473  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
474  if (ProcessResponse(message))
475    return;
476
477  scoped_ptr<DevToolsProtocol::Notification> notification(
478      DevToolsProtocol::ParseNotification(message));
479  if (!notification)
480    return;
481
482  if (notification->method() != tethering::accepted::kName)
483    return;
484
485  base::DictionaryValue* params = notification->params();
486  if (!params)
487    return;
488
489  int port;
490  std::string connection_id;
491  if (!params->GetInteger(tethering::accepted::kParamPort, &port) ||
492      !params->GetString(tethering::accepted::kParamConnectionId,
493                         &connection_id))
494    return;
495
496  std::map<int, std::string>::iterator it = forwarding_map_.find(port);
497  if (it == forwarding_map_.end())
498    return;
499
500  std::string location = it->second;
501  std::vector<std::string> tokens;
502  Tokenize(location, ":", &tokens);
503  int destination_port = 0;
504  if (tokens.size() != 2 || !base::StringToInt(tokens[1], &destination_port))
505    return;
506  std::string destination_host = tokens[0];
507
508  SocketTunnel::CounterCallback callback =
509      base::Bind(&Connection::UpdateSocketCountOnHandlerThread,
510                 weak_factory_.GetWeakPtr(), port);
511
512  device_->OpenSocket(
513      connection_id.c_str(),
514      base::Bind(&SocketTunnel::StartTunnel,
515                 destination_host,
516                 destination_port,
517                 callback));
518}
519
520PortForwardingController::PortForwardingController(Profile* profile)
521    : profile_(profile),
522      pref_service_(profile->GetPrefs()) {
523  pref_change_registrar_.Init(pref_service_);
524  base::Closure callback = base::Bind(
525      &PortForwardingController::OnPrefsChange, base::Unretained(this));
526  pref_change_registrar_.Add(prefs::kDevToolsPortForwardingEnabled, callback);
527  pref_change_registrar_.Add(prefs::kDevToolsPortForwardingConfig, callback);
528  OnPrefsChange();
529}
530
531PortForwardingController::~PortForwardingController() {}
532
533PortForwardingController::DevicesStatus
534PortForwardingController::DeviceListChanged(
535    const DevToolsAndroidBridge::RemoteDevices& devices) {
536  DevicesStatus status;
537  if (forwarding_map_.empty())
538    return status;
539
540  for (DevToolsAndroidBridge::RemoteDevices::const_iterator it =
541      devices.begin(); it != devices.end(); ++it) {
542    scoped_refptr<DevToolsAndroidBridge::RemoteDevice> device = *it;
543    if (!device->is_connected())
544      continue;
545    Registry::iterator rit = registry_.find(device->serial());
546    if (rit == registry_.end()) {
547      scoped_refptr<DevToolsAndroidBridge::RemoteBrowser> browser =
548          FindBestBrowserForTethering(device->browsers());
549      if (browser.get()) {
550        new Connection(&registry_, device, browser, forwarding_map_);
551      }
552    } else {
553      status[device->serial()] = (*rit).second->GetPortStatusMap();
554    }
555  }
556
557  return status;
558}
559
560void PortForwardingController::OnPrefsChange() {
561  forwarding_map_.clear();
562
563  if (pref_service_->GetBoolean(prefs::kDevToolsPortForwardingEnabled)) {
564    const base::DictionaryValue* dict =
565        pref_service_->GetDictionary(prefs::kDevToolsPortForwardingConfig);
566    for (base::DictionaryValue::Iterator it(*dict);
567         !it.IsAtEnd(); it.Advance()) {
568      int port_num;
569      std::string location;
570      if (base::StringToInt(it.key(), &port_num) &&
571          dict->GetString(it.key(), &location))
572        forwarding_map_[port_num] = location;
573    }
574  }
575
576  if (!forwarding_map_.empty()) {
577    UpdateConnections();
578  } else {
579    std::vector<Connection*> registry_copy;
580    for (Registry::iterator it = registry_.begin();
581        it != registry_.end(); ++it) {
582      registry_copy.push_back(it->second);
583    }
584    STLDeleteElements(&registry_copy);
585  }
586}
587
588void PortForwardingController::UpdateConnections() {
589  for (Registry::iterator it = registry_.begin(); it != registry_.end(); ++it)
590    it->second->UpdateForwardingMap(forwarding_map_);
591}
592