host_forwarder_main.cc revision a3f6a49ab37290eeeb8db0f41ec0f1cb74a68be7
1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <errno.h>
6#include <signal.h>
7#include <sys/types.h>
8#include <sys/wait.h>
9#include <unistd.h>
10
11#include <cstdio>
12#include <iostream>
13#include <limits>
14#include <string>
15#include <utility>
16#include <vector>
17
18#include "base/at_exit.h"
19#include "base/basictypes.h"
20#include "base/bind.h"
21#include "base/command_line.h"
22#include "base/compiler_specific.h"
23#include "base/containers/hash_tables.h"
24#include "base/file_util.h"
25#include "base/files/file_path.h"
26#include "base/logging.h"
27#include "base/memory/linked_ptr.h"
28#include "base/memory/scoped_vector.h"
29#include "base/memory/weak_ptr.h"
30#include "base/pickle.h"
31#include "base/safe_strerror_posix.h"
32#include "base/strings/string_number_conversions.h"
33#include "base/strings/string_piece.h"
34#include "base/strings/string_split.h"
35#include "base/strings/string_util.h"
36#include "base/strings/stringprintf.h"
37#include "base/task_runner.h"
38#include "base/threading/thread.h"
39#include "tools/android/forwarder2/common.h"
40#include "tools/android/forwarder2/daemon.h"
41#include "tools/android/forwarder2/host_controller.h"
42#include "tools/android/forwarder2/pipe_notifier.h"
43#include "tools/android/forwarder2/socket.h"
44#include "tools/android/forwarder2/util.h"
45
46namespace forwarder2 {
47namespace {
48
// Path of the log file handed to Daemon for the host daemon process.
const char kLogFilePath[] = "/tmp/host_forwarder_log";
// Identifier handed to Daemon to name this daemon (presumably used to locate
// the running daemon, e.g. its socket; verify in daemon.h).
const char kDaemonIdentifier[] = "chrome_host_forwarder_daemon";

// Size of the stack buffers used for signal-handler log messages and for the
// messages exchanged between the CLI client and the daemon.
const int kBufSize = 256;
56
57// Needs to be global to be able to be accessed from the signal handler.
58PipeNotifier* g_notifier = NULL;
59
60// Lets the daemon fetch the exit notifier file descriptor.
61int GetExitNotifierFD() {
62  DCHECK(g_notifier);
63  return g_notifier->receiver_fd();
64}
65
66void KillHandler(int signal_number) {
67  char buf[kBufSize];
68  if (signal_number != SIGTERM && signal_number != SIGINT) {
69    snprintf(buf, sizeof(buf), "Ignoring unexpected signal %d.", signal_number);
70    SIGNAL_SAFE_LOG(WARNING, buf);
71    return;
72  }
73  snprintf(buf, sizeof(buf), "Received signal %d.", signal_number);
74  SIGNAL_SAFE_LOG(WARNING, buf);
75  static int s_kill_handler_count = 0;
76  CHECK(g_notifier);
77  // If for some reason the forwarder get stuck in any socket waiting forever,
78  // we can send a SIGKILL or SIGINT three times to force it die
79  // (non-nicely). This is useful when debugging.
80  ++s_kill_handler_count;
81  if (!g_notifier->Notify() || s_kill_handler_count > 2)
82    exit(1);
83}
84
85// Manages HostController instances. There is one HostController instance for
86// each connection being forwarded. Note that forwarding can happen with many
87// devices (identified with a serial id).
88class HostControllersManager {
89 public:
90  HostControllersManager()
91      : weak_ptr_factory_(this),
92        controllers_(new HostControllerMap()),
93        has_failed_(false) {
94  }
95
96  ~HostControllersManager() {
97    if (!thread_.get())
98      return;
99    // Delete the controllers on the thread they were created on.
100    thread_->message_loop_proxy()->DeleteSoon(
101        FROM_HERE, controllers_.release());
102  }
103
104  void HandleRequest(const std::string& device_serial,
105                     int device_port,
106                     int host_port,
107                     scoped_ptr<Socket> client_socket) {
108    // Lazy initialize so that the CLI process doesn't get this thread created.
109    InitOnce();
110    thread_->message_loop_proxy()->PostTask(
111        FROM_HERE,
112        base::Bind(
113            &HostControllersManager::HandleRequestOnInternalThread,
114            base::Unretained(this), device_serial, device_port, host_port,
115            base::Passed(&client_socket)));
116  }
117
118  bool has_failed() const { return has_failed_; }
119
120 private:
121  typedef base::hash_map<
122      std::string, linked_ptr<HostController> > HostControllerMap;
123
124  static std::string MakeHostControllerMapKey(int adb_port, int device_port) {
125    return base::StringPrintf("%d:%d", adb_port, device_port);
126  }
127
128  void InitOnce() {
129    if (thread_.get())
130      return;
131    at_exit_manager_.reset(new base::AtExitManager());
132    thread_.reset(new base::Thread("HostControllersManagerThread"));
133    thread_->Start();
134  }
135
136  // Invoked when a HostController instance reports an error (e.g. due to a
137  // device connectivity issue). Note that this could be called after the
138  // controller manager was destroyed which is why a weak pointer is used.
139  static void DeleteHostController(
140      const base::WeakPtr<HostControllersManager>& manager_ptr,
141      scoped_ptr<HostController> host_controller) {
142    HostController* const controller = host_controller.release();
143    HostControllersManager* const manager = manager_ptr.get();
144    if (!manager) {
145      // Note that |controller| is not leaked in this case since the host
146      // controllers manager owns the controllers. If the manager was deleted
147      // then all the controllers (including |controller|) were also deleted.
148      return;
149    }
150    DCHECK(manager->thread_->message_loop_proxy()->RunsTasksOnCurrentThread());
151    // Note that this will delete |controller| which is owned by the map.
152    DeleteRefCountedValueInMap(
153        MakeHostControllerMapKey(
154            controller->adb_port(), controller->device_port()),
155        manager->controllers_.get());
156  }
157
158  void HandleRequestOnInternalThread(const std::string& device_serial,
159                                     int device_port,
160                                     int host_port,
161                                     scoped_ptr<Socket> client_socket) {
162    const int adb_port = GetAdbPortForDevice(device_serial);
163    if (adb_port < 0) {
164      SendMessage(
165          "ERROR: could not get adb port for device. You might need to add "
166          "'adb' to your PATH or provide the device serial id.",
167          client_socket.get());
168      return;
169    }
170    if (device_port < 0) {
171      // Remove the previously created host controller.
172      const std::string controller_key = MakeHostControllerMapKey(
173          adb_port, -device_port);
174      const bool controller_did_exist = DeleteRefCountedValueInMap(
175          controller_key, controllers_.get());
176      SendMessage(
177          !controller_did_exist ? "ERROR: could not unmap port" : "OK",
178          client_socket.get());
179
180      RemoveAdbPortForDeviceIfNeeded(device_serial);
181      return;
182    }
183    if (host_port < 0) {
184      SendMessage("ERROR: missing host port", client_socket.get());
185      return;
186    }
187    const bool use_dynamic_port_allocation = device_port == 0;
188    if (!use_dynamic_port_allocation) {
189      const std::string controller_key = MakeHostControllerMapKey(
190          adb_port, device_port);
191      if (controllers_->find(controller_key) != controllers_->end()) {
192        LOG(INFO) << "Already forwarding device port " << device_port
193                  << " to host port " << host_port;
194        SendMessage(base::StringPrintf("%d:%d", device_port, host_port),
195                    client_socket.get());
196        return;
197      }
198    }
199    // Create a new host controller.
200    scoped_ptr<HostController> host_controller(
201        HostController::Create(
202            device_port, host_port, adb_port, GetExitNotifierFD(),
203            base::Bind(&HostControllersManager::DeleteHostController,
204                       weak_ptr_factory_.GetWeakPtr())));
205    if (!host_controller.get()) {
206      has_failed_ = true;
207      SendMessage("ERROR: Connection to device failed.", client_socket.get());
208      return;
209    }
210    // Get the current allocated port.
211    device_port = host_controller->device_port();
212    LOG(INFO) << "Forwarding device port " << device_port << " to host port "
213              << host_port;
214    const std::string msg = base::StringPrintf("%d:%d", device_port, host_port);
215    if (!SendMessage(msg, client_socket.get()))
216      return;
217    host_controller->Start();
218    controllers_->insert(
219        std::make_pair(MakeHostControllerMapKey(adb_port, device_port),
220                       linked_ptr<HostController>(host_controller.release())));
221  }
222
223  void RemoveAdbPortForDeviceIfNeeded(const std::string& device_serial) {
224    base::hash_map<std::string, int>::const_iterator it =
225        device_serial_to_adb_port_map_.find(device_serial);
226    if (it == device_serial_to_adb_port_map_.end())
227      return;
228
229    int port = it->second;
230    const std::string prefix = base::StringPrintf("%d:", port);
231    for (HostControllerMap::const_iterator others = controllers_->begin();
232         others != controllers_->end(); ++others) {
233      if (others->first.find(prefix) == 0U)
234        return;
235    }
236    // No other port is being forwarded to this device:
237    // - Remove it from our internal serial -> adb port map.
238    // - Remove from "adb forward" command.
239    LOG(INFO) << "Device " << device_serial << " has no more ports.";
240    device_serial_to_adb_port_map_.erase(device_serial);
241    const std::string serial_part = device_serial.empty() ?
242        std::string() : std::string("-s ") + device_serial;
243    const std::string command = base::StringPrintf(
244        "adb %s forward --remove tcp:%d",
245        serial_part.c_str(),
246        port);
247    const int ret = system(command.c_str());
248    LOG(INFO) << command << " ret: " << ret;
249    // Wait for the socket to be fully unmapped.
250    const std::string port_mapped_cmd = base::StringPrintf(
251        "lsof -nPi:%d",
252        port);
253    const int poll_interval_us = 500 * 1000;
254    int retries = 3;
255    while (retries) {
256      const int port_unmapped = system(port_mapped_cmd.c_str());
257      LOG(INFO) << "Device " << device_serial << " port " << port << " unmap "
258                << port_unmapped;
259      if (port_unmapped)
260        break;
261      --retries;
262      usleep(poll_interval_us);
263    }
264  }
265
266  int GetAdbPortForDevice(const std::string& device_serial) {
267    base::hash_map<std::string, int>::const_iterator it =
268        device_serial_to_adb_port_map_.find(device_serial);
269    if (it != device_serial_to_adb_port_map_.end())
270      return it->second;
271    Socket bind_socket;
272    CHECK(bind_socket.BindTcp("127.0.0.1", 0));
273    const int port = bind_socket.GetPort();
274    bind_socket.Close();
275    const std::string serial_part = device_serial.empty() ?
276        std::string() : std::string("-s ") + device_serial;
277    const std::string command = base::StringPrintf(
278        "adb %s forward tcp:%d localabstract:chrome_device_forwarder",
279        serial_part.c_str(),
280        port);
281    LOG(INFO) << command;
282    const int ret = system(command.c_str());
283    if (ret < 0 || !WIFEXITED(ret) || WEXITSTATUS(ret) != 0)
284      return -1;
285    device_serial_to_adb_port_map_[device_serial] = port;
286    return port;
287  }
288
289  bool SendMessage(const std::string& msg, Socket* client_socket) {
290    bool result = client_socket->WriteString(msg);
291    DCHECK(result);
292    if (!result)
293      has_failed_ = true;
294    return result;
295  }
296
297  base::WeakPtrFactory<HostControllersManager> weak_ptr_factory_;
298  base::hash_map<std::string, int> device_serial_to_adb_port_map_;
299  scoped_ptr<HostControllerMap> controllers_;
300  bool has_failed_;
301  scoped_ptr<base::AtExitManager> at_exit_manager_;  // Needed by base::Thread.
302  scoped_ptr<base::Thread> thread_;
303};
304
305class ServerDelegate : public Daemon::ServerDelegate {
306 public:
307  ServerDelegate() : has_failed_(false) {}
308
309  bool has_failed() const {
310    return has_failed_ || controllers_manager_.has_failed();
311  }
312
313  // Daemon::ServerDelegate:
314  virtual void Init() OVERRIDE {
315    LOG(INFO) << "Starting host process daemon (pid=" << getpid() << ")";
316    DCHECK(!g_notifier);
317    g_notifier = new PipeNotifier();
318    signal(SIGTERM, KillHandler);
319    signal(SIGINT, KillHandler);
320  }
321
322  virtual void OnClientConnected(scoped_ptr<Socket> client_socket) OVERRIDE {
323    char buf[kBufSize];
324    const int bytes_read = client_socket->Read(buf, sizeof(buf));
325    if (bytes_read <= 0) {
326      if (client_socket->DidReceiveEvent())
327        return;
328      PError("Read()");
329      has_failed_ = true;
330      return;
331    }
332    const Pickle command_pickle(buf, bytes_read);
333    PickleIterator pickle_it(command_pickle);
334    std::string device_serial;
335    CHECK(pickle_it.ReadString(&device_serial));
336    int device_port;
337    if (!pickle_it.ReadInt(&device_port)) {
338      client_socket->WriteString("ERROR: missing device port");
339      return;
340    }
341    int host_port;
342    if (!pickle_it.ReadInt(&host_port))
343      host_port = -1;
344    controllers_manager_.HandleRequest(
345        device_serial, device_port, host_port, client_socket.Pass());
346  }
347
348 private:
349  bool has_failed_;
350  HostControllersManager controllers_manager_;
351
352  DISALLOW_COPY_AND_ASSIGN(ServerDelegate);
353};
354
355class ClientDelegate : public Daemon::ClientDelegate {
356 public:
357  ClientDelegate(const Pickle& command_pickle)
358      : command_pickle_(command_pickle),
359        has_failed_(false) {
360  }
361
362  bool has_failed() const { return has_failed_; }
363
364  // Daemon::ClientDelegate:
365  virtual void OnDaemonReady(Socket* daemon_socket) OVERRIDE {
366    // Send the forward command to the daemon.
367    CHECK_EQ(command_pickle_.size(),
368             daemon_socket->WriteNumBytes(command_pickle_.data(),
369                                          command_pickle_.size()));
370    char buf[kBufSize];
371    const int bytes_read = daemon_socket->Read(
372        buf, sizeof(buf) - 1 /* leave space for null terminator */);
373    CHECK_GT(bytes_read, 0);
374    DCHECK(bytes_read < sizeof(buf));
375    buf[bytes_read] = 0;
376    base::StringPiece msg(buf, bytes_read);
377    if (msg.starts_with("ERROR")) {
378      LOG(ERROR) << msg;
379      has_failed_ = true;
380      return;
381    }
382    printf("%s\n", buf);
383  }
384
385 private:
386  const Pickle command_pickle_;
387  bool has_failed_;
388};
389
// Prints the command-line usage to stderr and terminates the process with
// exit status 1. Never returns.
void ExitWithUsage() {
  std::cerr << "Usage: host_forwarder [options]\n\n"
               "Options:\n"
               "  --serial-id=[0-9A-Z]{16}\n"
               "  --map DEVICE_PORT HOST_PORT\n"
               "  --unmap DEVICE_PORT\n"
               "  --kill-server\n";
  exit(1);
}
398}
399
400int PortToInt(const std::string& s) {
401  int value;
402  // Note that 0 is a valid port (used for dynamic port allocation).
403  if (!base::StringToInt(s, &value) || value < 0 ||
404      value > std::numeric_limits<uint16>::max()) {
405    LOG(ERROR) << "Could not convert string " << s << " to port";
406    ExitWithUsage();
407  }
408  return value;
409}
410
411int RunHostForwarder(int argc, char** argv) {
412  CommandLine::Init(argc, argv);
413  const CommandLine& cmd_line = *CommandLine::ForCurrentProcess();
414  bool kill_server = false;
415
416  Pickle pickle;
417  pickle.WriteString(
418      cmd_line.HasSwitch("serial-id") ?
419          cmd_line.GetSwitchValueASCII("serial-id") : std::string());
420
421  const std::vector<std::string> args = cmd_line.GetArgs();
422  if (cmd_line.HasSwitch("kill-server")) {
423    kill_server = true;
424  } else if (cmd_line.HasSwitch("unmap")) {
425    if (args.size() != 1)
426      ExitWithUsage();
427    // Note the minus sign below.
428    pickle.WriteInt(-PortToInt(args[0]));
429  } else if (cmd_line.HasSwitch("map")) {
430    if (args.size() != 2)
431      ExitWithUsage();
432    pickle.WriteInt(PortToInt(args[0]));
433    pickle.WriteInt(PortToInt(args[1]));
434  } else {
435    ExitWithUsage();
436  }
437
438  if (kill_server && args.size() > 0)
439    ExitWithUsage();
440
441  ClientDelegate client_delegate(pickle);
442  ServerDelegate daemon_delegate;
443  Daemon daemon(
444      kLogFilePath, kDaemonIdentifier, &client_delegate, &daemon_delegate,
445      &GetExitNotifierFD);
446
447  if (kill_server)
448    return !daemon.Kill();
449  if (!daemon.SpawnIfNeeded())
450    return 1;
451
452  return client_delegate.has_failed() || daemon_delegate.has_failed();
453}
454
455}  // namespace
456}  // namespace forwarder2
457
458int main(int argc, char** argv) {
459  return forwarder2::RunHostForwarder(argc, argv);
460}
461