port_forwarding_controller.cc revision 6e8cce623b6e4fe0c9e4af605d675dd9d0338c38
1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "chrome/browser/devtools/device/port_forwarding_controller.h"
6
7#include <algorithm>
8#include <map>
9
10#include "base/bind.h"
11#include "base/compiler_specific.h"
12#include "base/memory/singleton.h"
13#include "base/message_loop/message_loop.h"
14#include "base/prefs/pref_service.h"
15#include "base/strings/string_number_conversions.h"
16#include "base/strings/string_util.h"
17#include "base/strings/stringprintf.h"
18#include "base/threading/non_thread_safe.h"
19#include "chrome/browser/devtools/devtools_protocol.h"
20#include "chrome/browser/profiles/profile.h"
21#include "chrome/common/pref_names.h"
22#include "components/keyed_service/content/browser_context_dependency_manager.h"
23#include "content/public/browser/browser_thread.h"
24#include "net/base/address_list.h"
25#include "net/base/io_buffer.h"
26#include "net/base/net_errors.h"
27#include "net/base/net_util.h"
28#include "net/dns/host_resolver.h"
29#include "net/socket/tcp_client_socket.h"
30
31using content::BrowserThread;
32
33namespace {
34
// Capacity of each buffer allocated for a tunnel read.
const int kBufferSize = 16 * 1024;

// Values stored per forwarded port in the port status map.
enum {
  kStatusError = -3,
  kStatusDisconnecting = -2,
  kStatusConnecting = -1,
  kStatusOK = 0,
  // Positive values are used to count open connections.
};

// Parameter names and method names of the tethering protocol messages
// exchanged over the browser web socket.
const char kPortAttribute[] = "port";
const char kConnectionIdAttribute[] = "connectionId";
const char kTetheringAccepted[] = "Tethering.accepted";
const char kTetheringBind[] = "Tethering.bind";
const char kTetheringUnbind[] = "Tethering.unbind";

// Web socket target passed to RemoteBrowser::CreateWebSocket.
const char kDevToolsRemoteBrowserTarget[] = "/devtools/browser";

// Lowest major browser version accepted by IsPortForwardingSupported.
const int kMinVersionPortForwarding = 28;
53
54class SocketTunnel : public base::NonThreadSafe {
55 public:
56  typedef base::Callback<void(int)> CounterCallback;
57
58  static void StartTunnel(const std::string& host,
59                          int port,
60                          const CounterCallback& callback,
61                          int result,
62                          scoped_ptr<net::StreamSocket> socket) {
63    if (result < 0)
64      return;
65    SocketTunnel* tunnel = new SocketTunnel(callback);
66    tunnel->Start(socket.Pass(), host, port);
67  }
68
69 private:
70  explicit SocketTunnel(const CounterCallback& callback)
71      : pending_writes_(0),
72        pending_destruction_(false),
73        callback_(callback),
74        about_to_destroy_(false) {
75    callback_.Run(1);
76  }
77
78  void Start(scoped_ptr<net::StreamSocket> socket,
79             const std::string& host, int port) {
80    remote_socket_.swap(socket);
81
82    host_resolver_ = net::HostResolver::CreateDefaultResolver(NULL);
83    net::HostResolver::RequestInfo request_info(net::HostPortPair(host, port));
84    int result = host_resolver_->Resolve(
85        request_info,
86        net::DEFAULT_PRIORITY,
87        &address_list_,
88        base::Bind(&SocketTunnel::OnResolved, base::Unretained(this)),
89        NULL,
90        net::BoundNetLog());
91    if (result != net::ERR_IO_PENDING)
92      OnResolved(result);
93  }
94
95  void OnResolved(int result) {
96    if (result < 0) {
97      SelfDestruct();
98      return;
99    }
100
101    host_socket_.reset(new net::TCPClientSocket(address_list_, NULL,
102                                                net::NetLog::Source()));
103    result = host_socket_->Connect(base::Bind(&SocketTunnel::OnConnected,
104                                              base::Unretained(this)));
105    if (result != net::ERR_IO_PENDING)
106      OnConnected(result);
107  }
108
109  ~SocketTunnel() {
110    about_to_destroy_ = true;
111    if (host_socket_)
112      host_socket_->Disconnect();
113    if (remote_socket_)
114      remote_socket_->Disconnect();
115    callback_.Run(-1);
116  }
117
118  void OnConnected(int result) {
119    if (result < 0) {
120      SelfDestruct();
121      return;
122    }
123
124    ++pending_writes_; // avoid SelfDestruct in first Pump
125    Pump(host_socket_.get(), remote_socket_.get());
126    --pending_writes_;
127    if (pending_destruction_) {
128      SelfDestruct();
129    } else {
130      Pump(remote_socket_.get(), host_socket_.get());
131    }
132  }
133
134  void Pump(net::StreamSocket* from, net::StreamSocket* to) {
135    scoped_refptr<net::IOBuffer> buffer = new net::IOBuffer(kBufferSize);
136    int result = from->Read(
137        buffer.get(),
138        kBufferSize,
139        base::Bind(
140            &SocketTunnel::OnRead, base::Unretained(this), from, to, buffer));
141    if (result != net::ERR_IO_PENDING)
142      OnRead(from, to, buffer, result);
143  }
144
145  void OnRead(net::StreamSocket* from,
146              net::StreamSocket* to,
147              scoped_refptr<net::IOBuffer> buffer,
148              int result) {
149    if (result <= 0) {
150      SelfDestruct();
151      return;
152    }
153
154    int total = result;
155    scoped_refptr<net::DrainableIOBuffer> drainable =
156        new net::DrainableIOBuffer(buffer.get(), total);
157
158    ++pending_writes_;
159    result = to->Write(drainable.get(),
160                       total,
161                       base::Bind(&SocketTunnel::OnWritten,
162                                  base::Unretained(this),
163                                  drainable,
164                                  from,
165                                  to));
166    if (result != net::ERR_IO_PENDING)
167      OnWritten(drainable, from, to, result);
168  }
169
170  void OnWritten(scoped_refptr<net::DrainableIOBuffer> drainable,
171                 net::StreamSocket* from,
172                 net::StreamSocket* to,
173                 int result) {
174    --pending_writes_;
175    if (result < 0) {
176      SelfDestruct();
177      return;
178    }
179
180    drainable->DidConsume(result);
181    if (drainable->BytesRemaining() > 0) {
182      ++pending_writes_;
183      result = to->Write(drainable.get(),
184                         drainable->BytesRemaining(),
185                         base::Bind(&SocketTunnel::OnWritten,
186                                    base::Unretained(this),
187                                    drainable,
188                                    from,
189                                    to));
190      if (result != net::ERR_IO_PENDING)
191        OnWritten(drainable, from, to, result);
192      return;
193    }
194
195    if (pending_destruction_) {
196      SelfDestruct();
197      return;
198    }
199    Pump(from, to);
200  }
201
202  void SelfDestruct() {
203    // In case one of the connections closes, we could get here
204    // from another one due to Disconnect firing back on all
205    // read callbacks.
206    if (about_to_destroy_)
207      return;
208    if (pending_writes_ > 0) {
209      pending_destruction_ = true;
210      return;
211    }
212    delete this;
213  }
214
215  scoped_ptr<net::StreamSocket> remote_socket_;
216  scoped_ptr<net::StreamSocket> host_socket_;
217  scoped_ptr<net::HostResolver> host_resolver_;
218  net::AddressList address_list_;
219  int pending_writes_;
220  bool pending_destruction_;
221  CounterCallback callback_;
222  bool about_to_destroy_;
223};
224
225typedef DevToolsAndroidBridge::RemoteBrowser::ParsedVersion ParsedVersion;
226
227static bool IsVersionLower(const ParsedVersion& left,
228                           const ParsedVersion& right) {
229  return std::lexicographical_compare(
230    left.begin(), left.end(), right.begin(), right.end());
231}
232
233static bool IsPortForwardingSupported(const ParsedVersion& version) {
234  return !version.empty() && version[0] >= kMinVersionPortForwarding;
235}
236
237static scoped_refptr<DevToolsAndroidBridge::RemoteBrowser>
238FindBestBrowserForTethering(
239    const DevToolsAndroidBridge::RemoteBrowsers browsers) {
240  scoped_refptr<DevToolsAndroidBridge::RemoteBrowser> best_browser;
241  ParsedVersion newest_version;
242  for (DevToolsAndroidBridge::RemoteBrowsers::const_iterator it =
243      browsers.begin(); it != browsers.end(); ++it) {
244    scoped_refptr<DevToolsAndroidBridge::RemoteBrowser> browser = *it;
245    ParsedVersion current_version = browser->GetParsedVersion();
246    if (IsPortForwardingSupported(current_version) &&
247        IsVersionLower(newest_version, current_version)) {
248      best_browser = browser;
249      newest_version = current_version;
250    }
251  }
252  return best_browser;
253}
254
255}  // namespace
256
257class PortForwardingController::Connection
258    : public DevToolsAndroidBridge::AndroidWebSocket::Delegate {
259 public:
260  Connection(Registry* registry,
261             scoped_refptr<DevToolsAndroidBridge::RemoteDevice> device,
262             scoped_refptr<DevToolsAndroidBridge::RemoteBrowser> browser,
263             const ForwardingMap& forwarding_map);
264  virtual ~Connection();
265
266  const PortStatusMap& GetPortStatusMap();
267
268  void UpdateForwardingMap(const ForwardingMap& new_forwarding_map);
269
270  void Shutdown();
271
272 private:
273  friend struct content::BrowserThread::DeleteOnThread<
274      content::BrowserThread::UI>;
275  friend class base::DeleteHelper<Connection>;
276
277
278  typedef std::map<int, std::string> ForwardingMap;
279
280  typedef base::Callback<void(PortStatus)> CommandCallback;
281  typedef std::map<int, CommandCallback> CommandCallbackMap;
282
283  void SerializeChanges(const std::string& method,
284                        const ForwardingMap& old_map,
285                        const ForwardingMap& new_map);
286
287  void SendCommand(const std::string& method, int port);
288  bool ProcessResponse(const std::string& json);
289
290  void ProcessBindResponse(int port, PortStatus status);
291  void ProcessUnbindResponse(int port, PortStatus status);
292
293  static void UpdateSocketCountOnHandlerThread(
294      base::WeakPtr<Connection> weak_connection, int port, int increment);
295  void UpdateSocketCount(int port, int increment);
296
297  // DevToolsAndroidBridge::AndroidWebSocket::Delegate implementation:
298  virtual void OnSocketOpened() OVERRIDE;
299  virtual void OnFrameRead(const std::string& message) OVERRIDE;
300  virtual void OnSocketClosed() OVERRIDE;
301
302  PortForwardingController::Registry* registry_;
303  scoped_refptr<DevToolsAndroidBridge::RemoteDevice> device_;
304  scoped_refptr<DevToolsAndroidBridge::RemoteBrowser> browser_;
305  scoped_ptr<DevToolsAndroidBridge::AndroidWebSocket> web_socket_;
306  int command_id_;
307  bool connected_;
308  ForwardingMap forwarding_map_;
309  CommandCallbackMap pending_responses_;
310  PortStatusMap port_status_;
311  base::WeakPtrFactory<Connection> weak_factory_;
312
313  DISALLOW_COPY_AND_ASSIGN(Connection);
314};
315
316PortForwardingController::Connection::Connection(
317    Registry* registry,
318    scoped_refptr<DevToolsAndroidBridge::RemoteDevice> device,
319    scoped_refptr<DevToolsAndroidBridge::RemoteBrowser> browser,
320    const ForwardingMap& forwarding_map)
321    : registry_(registry),
322      device_(device),
323      browser_(browser),
324      command_id_(0),
325      connected_(false),
326      forwarding_map_(forwarding_map),
327      weak_factory_(this) {
328  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
329  (*registry_)[device_->serial()] = this;
330  web_socket_.reset(
331      browser->CreateWebSocket(kDevToolsRemoteBrowserTarget, this));
332}
333
334PortForwardingController::Connection::~Connection() {
335  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
336  DCHECK(registry_->find(device_->serial()) != registry_->end());
337  registry_->erase(device_->serial());
338}
339
340void PortForwardingController::Connection::UpdateForwardingMap(
341    const ForwardingMap& new_forwarding_map) {
342  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
343  if (connected_) {
344    SerializeChanges(kTetheringUnbind, new_forwarding_map, forwarding_map_);
345    SerializeChanges(kTetheringBind, forwarding_map_, new_forwarding_map);
346  }
347  forwarding_map_ = new_forwarding_map;
348}
349
350void PortForwardingController::Connection::SerializeChanges(
351    const std::string& method,
352    const ForwardingMap& old_map,
353    const ForwardingMap& new_map) {
354  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
355  for (ForwardingMap::const_iterator new_it(new_map.begin());
356      new_it != new_map.end(); ++new_it) {
357    int port = new_it->first;
358    const std::string& location = new_it->second;
359    ForwardingMap::const_iterator old_it = old_map.find(port);
360    if (old_it != old_map.end() && old_it->second == location)
361      continue;  // The port points to the same location in both configs, skip.
362
363    SendCommand(method, port);
364  }
365}
366
367void PortForwardingController::Connection::SendCommand(
368    const std::string& method, int port) {
369  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
370  base::DictionaryValue params;
371  params.SetInteger(kPortAttribute, port);
372  DevToolsProtocol::Command command(++command_id_, method, &params);
373
374  if (method == kTetheringBind) {
375    pending_responses_[command.id()] =
376        base::Bind(&Connection::ProcessBindResponse,
377                   base::Unretained(this), port);
378#if defined(DEBUG_DEVTOOLS)
379    port_status_[port] = kStatusConnecting;
380#endif  // defined(DEBUG_DEVTOOLS)
381  } else {
382    DCHECK_EQ(kTetheringUnbind, method);
383
384    PortStatusMap::iterator it = port_status_.find(port);
385    if (it != port_status_.end() && it->second == kStatusError) {
386      // The bind command failed on this port, do not attempt unbind.
387      port_status_.erase(it);
388      return;
389    }
390
391    pending_responses_[command.id()] =
392        base::Bind(&Connection::ProcessUnbindResponse,
393                   base::Unretained(this), port);
394#if defined(DEBUG_DEVTOOLS)
395    port_status_[port] = kStatusDisconnecting;
396#endif  // defined(DEBUG_DEVTOOLS)
397  }
398
399  web_socket_->SendFrame(command.Serialize());
400}
401
402bool PortForwardingController::Connection::ProcessResponse(
403    const std::string& message) {
404  scoped_ptr<DevToolsProtocol::Response> response(
405      DevToolsProtocol::ParseResponse(message));
406  if (!response)
407    return false;
408
409  CommandCallbackMap::iterator it = pending_responses_.find(response->id());
410  if (it == pending_responses_.end())
411    return false;
412
413  it->second.Run(response->error_code() ? kStatusError : kStatusOK);
414  pending_responses_.erase(it);
415  return true;
416}
417
418void PortForwardingController::Connection::ProcessBindResponse(
419    int port, PortStatus status) {
420  port_status_[port] = status;
421}
422
423void PortForwardingController::Connection::ProcessUnbindResponse(
424    int port, PortStatus status) {
425  PortStatusMap::iterator it = port_status_.find(port);
426  if (it == port_status_.end())
427    return;
428  if (status == kStatusError)
429    it->second = status;
430  else
431    port_status_.erase(it);
432}
433
434// static
435void PortForwardingController::Connection::UpdateSocketCountOnHandlerThread(
436    base::WeakPtr<Connection> weak_connection, int port, int increment) {
437  BrowserThread::PostTask(BrowserThread::UI, FROM_HERE,
438     base::Bind(&Connection::UpdateSocketCount,
439                weak_connection, port, increment));
440}
441
442void PortForwardingController::Connection::UpdateSocketCount(
443    int port, int increment) {
444#if defined(DEBUG_DEVTOOLS)
445  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
446  PortStatusMap::iterator it = port_status_.find(port);
447  if (it == port_status_.end())
448    return;
449  if (it->second < 0 || (it->second == 0 && increment < 0))
450    return;
451  it->second += increment;
452#endif  // defined(DEBUG_DEVTOOLS)
453}
454
455const PortForwardingController::PortStatusMap&
456PortForwardingController::Connection::GetPortStatusMap() {
457  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
458  return port_status_;
459}
460
461void PortForwardingController::Connection::OnSocketOpened() {
462  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
463  connected_ = true;
464  SerializeChanges(kTetheringBind, ForwardingMap(), forwarding_map_);
465}
466
467void PortForwardingController::Connection::OnSocketClosed() {
468  delete this;
469}
470
471void PortForwardingController::Connection::OnFrameRead(
472    const std::string& message) {
473  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI));
474  if (ProcessResponse(message))
475    return;
476
477  scoped_ptr<DevToolsProtocol::Notification> notification(
478      DevToolsProtocol::ParseNotification(message));
479  if (!notification)
480    return;
481
482  if (notification->method() != kTetheringAccepted)
483    return;
484
485  base::DictionaryValue* params = notification->params();
486  if (!params)
487    return;
488
489  int port;
490  std::string connection_id;
491  if (!params->GetInteger(kPortAttribute, &port) ||
492      !params->GetString(kConnectionIdAttribute, &connection_id))
493    return;
494
495  std::map<int, std::string>::iterator it = forwarding_map_.find(port);
496  if (it == forwarding_map_.end())
497    return;
498
499  std::string location = it->second;
500  std::vector<std::string> tokens;
501  Tokenize(location, ":", &tokens);
502  int destination_port = 0;
503  if (tokens.size() != 2 || !base::StringToInt(tokens[1], &destination_port))
504    return;
505  std::string destination_host = tokens[0];
506
507  SocketTunnel::CounterCallback callback =
508      base::Bind(&Connection::UpdateSocketCountOnHandlerThread,
509                 weak_factory_.GetWeakPtr(), port);
510
511  device_->OpenSocket(
512      connection_id.c_str(),
513      base::Bind(&SocketTunnel::StartTunnel,
514                 destination_host,
515                 destination_port,
516                 callback));
517}
518
519PortForwardingController::PortForwardingController(Profile* profile)
520    : profile_(profile),
521      pref_service_(profile->GetPrefs()),
522      listening_(false) {
523  pref_change_registrar_.Init(pref_service_);
524  base::Closure callback = base::Bind(
525      &PortForwardingController::OnPrefsChange, base::Unretained(this));
526  pref_change_registrar_.Add(prefs::kDevToolsPortForwardingEnabled, callback);
527  pref_change_registrar_.Add(prefs::kDevToolsPortForwardingConfig, callback);
528  OnPrefsChange();
529}
530
531
532PortForwardingController::~PortForwardingController() {}
533
534void PortForwardingController::Shutdown() {
535  // Existing connection will not be shut down. This might be confusing for
536  // some users, but the opposite is more confusing.
537  StopListening();
538}
539
540void PortForwardingController::AddListener(Listener* listener) {
541  listeners_.push_back(listener);
542}
543
544void PortForwardingController::RemoveListener(Listener* listener) {
545  Listeners::iterator it =
546      std::find(listeners_.begin(), listeners_.end(), listener);
547  DCHECK(it != listeners_.end());
548  listeners_.erase(it);
549}
550
551void PortForwardingController::DeviceListChanged(
552    const DevToolsAndroidBridge::RemoteDevices& devices) {
553  DevicesStatus status;
554
555  for (DevToolsAndroidBridge::RemoteDevices::const_iterator it =
556      devices.begin(); it != devices.end(); ++it) {
557    scoped_refptr<DevToolsAndroidBridge::RemoteDevice> device = *it;
558    if (!device->is_connected())
559      continue;
560    Registry::iterator rit = registry_.find(device->serial());
561    if (rit == registry_.end()) {
562      scoped_refptr<DevToolsAndroidBridge::RemoteBrowser> browser =
563          FindBestBrowserForTethering(device->browsers());
564      if (browser) {
565        new Connection(&registry_, device, browser, forwarding_map_);
566      }
567    } else {
568      status[device->serial()] = (*rit).second->GetPortStatusMap();
569    }
570  }
571
572  NotifyListeners(status);
573}
574
575void PortForwardingController::OnPrefsChange() {
576  forwarding_map_.clear();
577
578  if (pref_service_->GetBoolean(prefs::kDevToolsPortForwardingEnabled)) {
579    const base::DictionaryValue* dict =
580        pref_service_->GetDictionary(prefs::kDevToolsPortForwardingConfig);
581    for (base::DictionaryValue::Iterator it(*dict);
582         !it.IsAtEnd(); it.Advance()) {
583      int port_num;
584      std::string location;
585      if (base::StringToInt(it.key(), &port_num) &&
586          dict->GetString(it.key(), &location))
587        forwarding_map_[port_num] = location;
588    }
589  }
590
591  if (!forwarding_map_.empty()) {
592    StartListening();
593    UpdateConnections();
594  } else {
595    StopListening();
596    STLDeleteValues(&registry_);
597    NotifyListeners(DevicesStatus());
598  }
599}
600
601void PortForwardingController::StartListening() {
602  if (listening_)
603    return;
604  listening_ = true;
605  DevToolsAndroidBridge* android_bridge =
606      DevToolsAndroidBridge::Factory::GetForProfile(profile_);
607  if (android_bridge)
608    android_bridge->AddDeviceListListener(this);
609
610}
611
612void PortForwardingController::StopListening() {
613  if (!listening_)
614    return;
615  listening_ = false;
616  DevToolsAndroidBridge* android_bridge =
617      DevToolsAndroidBridge::Factory::GetForProfile(profile_);
618  if (android_bridge)
619    android_bridge->RemoveDeviceListListener(this);
620}
621
622void PortForwardingController::UpdateConnections() {
623  for (Registry::iterator it = registry_.begin(); it != registry_.end(); ++it)
624    it->second->UpdateForwardingMap(forwarding_map_);
625}
626
627void PortForwardingController::NotifyListeners(
628    const DevicesStatus& status) const {
629  Listeners copy(listeners_);  // Iterate over copy.
630  for (Listeners::const_iterator it = copy.begin(); it != copy.end(); ++it)
631    (*it)->PortStatusChanged(status);
632}
633
634// static
635PortForwardingController::Factory*
636PortForwardingController::Factory::GetInstance() {
637  return Singleton<PortForwardingController::Factory>::get();
638}
639
640// static
641PortForwardingController* PortForwardingController::Factory::GetForProfile(
642    Profile* profile) {
643  return static_cast<PortForwardingController*>(GetInstance()->
644          GetServiceForBrowserContext(profile, true));
645}
646
647PortForwardingController::Factory::Factory()
648    : BrowserContextKeyedServiceFactory(
649          "PortForwardingController",
650          BrowserContextDependencyManager::GetInstance()) {}
651
652PortForwardingController::Factory::~Factory() {}
653
654KeyedService* PortForwardingController::Factory::BuildServiceInstanceFor(
655    content::BrowserContext* context) const {
656  Profile* profile = Profile::FromBrowserContext(context);
657  return new PortForwardingController(profile);
658}
659