host_forwarder_main.cc revision f2477e01787aa58f445919b809d89e252beef54f
1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <errno.h>
6#include <signal.h>
7#include <sys/types.h>
8#include <sys/wait.h>
9#include <unistd.h>
10
11#include <cstdio>
12#include <iostream>
13#include <limits>
14#include <string>
15#include <utility>
16#include <vector>
17
18#include "base/at_exit.h"
19#include "base/basictypes.h"
20#include "base/bind.h"
21#include "base/command_line.h"
22#include "base/compiler_specific.h"
23#include "base/containers/hash_tables.h"
24#include "base/file_util.h"
25#include "base/files/file_path.h"
26#include "base/logging.h"
27#include "base/memory/linked_ptr.h"
28#include "base/memory/scoped_vector.h"
29#include "base/memory/weak_ptr.h"
30#include "base/pickle.h"
31#include "base/posix/eintr_wrapper.h"
32#include "base/safe_strerror_posix.h"
33#include "base/strings/string_number_conversions.h"
34#include "base/strings/string_piece.h"
35#include "base/strings/string_split.h"
36#include "base/strings/string_util.h"
37#include "base/strings/stringprintf.h"
38#include "base/task_runner.h"
39#include "base/threading/thread.h"
40#include "tools/android/forwarder2/common.h"
41#include "tools/android/forwarder2/daemon.h"
42#include "tools/android/forwarder2/host_controller.h"
43#include "tools/android/forwarder2/pipe_notifier.h"
44#include "tools/android/forwarder2/socket.h"
45#include "tools/android/forwarder2/util.h"
46
47namespace forwarder2 {
48namespace {
49
50const char kLogFilePath[] = "/tmp/host_forwarder_log";
51const char kDaemonIdentifier[] = "chrome_host_forwarder_daemon";
52
53const char kKillServerCommand[] = "kill-server";
54const char kForwardCommand[] = "forward";
55
56const int kBufSize = 256;
57
58// Needs to be global to be able to be accessed from the signal handler.
59PipeNotifier* g_notifier = NULL;
60
61// Lets the daemon fetch the exit notifier file descriptor.
62int GetExitNotifierFD() {
63  DCHECK(g_notifier);
64  return g_notifier->receiver_fd();
65}
66
67void KillHandler(int signal_number) {
68  char buf[kBufSize];
69  if (signal_number != SIGTERM && signal_number != SIGINT) {
70    snprintf(buf, sizeof(buf), "Ignoring unexpected signal %d.", signal_number);
71    SIGNAL_SAFE_LOG(WARNING, buf);
72    return;
73  }
74  snprintf(buf, sizeof(buf), "Received signal %d.", signal_number);
75  SIGNAL_SAFE_LOG(WARNING, buf);
76  static int s_kill_handler_count = 0;
77  CHECK(g_notifier);
78  // If for some reason the forwarder get stuck in any socket waiting forever,
79  // we can send a SIGKILL or SIGINT three times to force it die
80  // (non-nicely). This is useful when debugging.
81  ++s_kill_handler_count;
82  if (!g_notifier->Notify() || s_kill_handler_count > 2)
83    exit(1);
84}
85
86// Manages HostController instances. There is one HostController instance for
87// each connection being forwarded. Note that forwarding can happen with many
88// devices (identified with a serial id).
89class HostControllersManager {
90 public:
91  HostControllersManager()
92      : weak_ptr_factory_(this),
93        controllers_(new HostControllerMap()),
94        has_failed_(false) {
95  }
96
97  ~HostControllersManager() {
98    if (!thread_.get())
99      return;
100    // Delete the controllers on the thread they were created on.
101    thread_->message_loop_proxy()->DeleteSoon(
102        FROM_HERE, controllers_.release());
103  }
104
105  void HandleRequest(const std::string& device_serial,
106                     int device_port,
107                     int host_port,
108                     scoped_ptr<Socket> client_socket) {
109    // Lazy initialize so that the CLI process doesn't get this thread created.
110    InitOnce();
111    thread_->message_loop_proxy()->PostTask(
112        FROM_HERE,
113        base::Bind(
114            &HostControllersManager::HandleRequestOnInternalThread,
115            base::Unretained(this), device_serial, device_port, host_port,
116            base::Passed(&client_socket)));
117  }
118
119  bool has_failed() const { return has_failed_; }
120
121 private:
122  typedef base::hash_map<
123      std::string, linked_ptr<HostController> > HostControllerMap;
124
125  static std::string MakeHostControllerMapKey(int adb_port, int device_port) {
126    return base::StringPrintf("%d:%d", adb_port, device_port);
127  }
128
129  void InitOnce() {
130    if (thread_.get())
131      return;
132    at_exit_manager_.reset(new base::AtExitManager());
133    thread_.reset(new base::Thread("HostControllersManagerThread"));
134    thread_->Start();
135  }
136
137  // Invoked when a HostController instance reports an error (e.g. due to a
138  // device connectivity issue). Note that this could be called after the
139  // controller manager was destroyed which is why a weak pointer is used.
140  static void DeleteHostController(
141      const base::WeakPtr<HostControllersManager>& manager_ptr,
142      scoped_ptr<HostController> host_controller) {
143    HostController* const controller = host_controller.release();
144    HostControllersManager* const manager = manager_ptr.get();
145    if (!manager) {
146      // Note that |controller| is not leaked in this case since the host
147      // controllers manager owns the controllers. If the manager was deleted
148      // then all the controllers (including |controller|) were also deleted.
149      return;
150    }
151    DCHECK(manager->thread_->message_loop_proxy()->RunsTasksOnCurrentThread());
152    // Note that this will delete |controller| which is owned by the map.
153    DeleteRefCountedValueInMap(
154        MakeHostControllerMapKey(
155            controller->adb_port(), controller->device_port()),
156        manager->controllers_.get());
157  }
158
159  void HandleRequestOnInternalThread(const std::string& device_serial,
160                                     int device_port,
161                                     int host_port,
162                                     scoped_ptr<Socket> client_socket) {
163    const int adb_port = GetAdbPortForDevice(device_serial);
164    if (adb_port < 0) {
165      SendMessage(
166          "ERROR: could not get adb port for device. You might need to add "
167          "'adb' to your PATH or provide the device serial id.",
168          client_socket.get());
169      return;
170    }
171    if (device_port < 0) {
172      // Remove the previously created host controller.
173      const std::string controller_key = MakeHostControllerMapKey(
174          adb_port, -device_port);
175      const bool controller_did_exist = DeleteRefCountedValueInMap(
176          controller_key, controllers_.get());
177      SendMessage(
178          !controller_did_exist ? "ERROR: could not unmap port" : "OK",
179          client_socket.get());
180
181      RemoveAdbPortForDeviceIfNeeded(device_serial);
182      return;
183    }
184    if (host_port < 0) {
185      SendMessage("ERROR: missing host port", client_socket.get());
186      return;
187    }
188    const bool use_dynamic_port_allocation = device_port == 0;
189    if (!use_dynamic_port_allocation) {
190      const std::string controller_key = MakeHostControllerMapKey(
191          adb_port, device_port);
192      if (controllers_->find(controller_key) != controllers_->end()) {
193        LOG(INFO) << "Already forwarding device port " << device_port
194                  << " to host port " << host_port;
195        SendMessage(base::StringPrintf("%d:%d", device_port, host_port),
196                    client_socket.get());
197        return;
198      }
199    }
200    // Create a new host controller.
201    scoped_ptr<HostController> host_controller(
202        HostController::Create(
203            device_port, host_port, adb_port, GetExitNotifierFD(),
204            base::Bind(&HostControllersManager::DeleteHostController,
205                       weak_ptr_factory_.GetWeakPtr())));
206    if (!host_controller.get()) {
207      has_failed_ = true;
208      SendMessage("ERROR: Connection to device failed.", client_socket.get());
209      return;
210    }
211    // Get the current allocated port.
212    device_port = host_controller->device_port();
213    LOG(INFO) << "Forwarding device port " << device_port << " to host port "
214              << host_port;
215    const std::string msg = base::StringPrintf("%d:%d", device_port, host_port);
216    if (!SendMessage(msg, client_socket.get()))
217      return;
218    host_controller->Start();
219    controllers_->insert(
220        std::make_pair(MakeHostControllerMapKey(adb_port, device_port),
221                       linked_ptr<HostController>(host_controller.release())));
222  }
223
224  void RemoveAdbPortForDeviceIfNeeded(const std::string& device_serial) {
225    base::hash_map<std::string, int>::const_iterator it =
226        device_serial_to_adb_port_map_.find(device_serial);
227    if (it == device_serial_to_adb_port_map_.end())
228      return;
229
230    int port = it->second;
231    const std::string prefix = base::StringPrintf("%d:", port);
232    for (HostControllerMap::const_iterator others = controllers_->begin();
233         others != controllers_->end(); ++others) {
234      if (others->first.find(prefix) == 0U)
235        return;
236    }
237    // No other port is being forwarded to this device:
238    // - Remove it from our internal serial -> adb port map.
239    // - Remove from "adb forward" command.
240    LOG(INFO) << "Device " << device_serial << " has no more ports.";
241    device_serial_to_adb_port_map_.erase(device_serial);
242    const std::string serial_part = device_serial.empty() ?
243        std::string() : std::string("-s ") + device_serial;
244    const std::string command = base::StringPrintf(
245        "adb %s forward --remove tcp:%d",
246        serial_part.c_str(),
247        port);
248    const int ret = system(command.c_str());
249    LOG(INFO) << command << " ret: " << ret;
250    // Wait for the socket to be fully unmapped.
251    const std::string port_mapped_cmd = base::StringPrintf(
252        "lsof -nPi:%d",
253        port);
254    const int poll_interval_us = 500 * 1000;
255    int retries = 3;
256    while (retries) {
257      const int port_unmapped = system(port_mapped_cmd.c_str());
258      LOG(INFO) << "Device " << device_serial << " port " << port << " unmap "
259                << port_unmapped;
260      if (port_unmapped)
261        break;
262      --retries;
263      usleep(poll_interval_us);
264    }
265  }
266
267  int GetAdbPortForDevice(const std::string& device_serial) {
268    base::hash_map<std::string, int>::const_iterator it =
269        device_serial_to_adb_port_map_.find(device_serial);
270    if (it != device_serial_to_adb_port_map_.end())
271      return it->second;
272    Socket bind_socket;
273    CHECK(bind_socket.BindTcp("127.0.0.1", 0));
274    const int port = bind_socket.GetPort();
275    bind_socket.Close();
276    const std::string serial_part = device_serial.empty() ?
277        std::string() : std::string("-s ") + device_serial;
278    const std::string command = base::StringPrintf(
279        "adb %s forward tcp:%d localabstract:chrome_device_forwarder",
280        serial_part.c_str(),
281        port);
282    LOG(INFO) << command;
283    const int ret = system(command.c_str());
284    if (ret < 0 || !WIFEXITED(ret) || WEXITSTATUS(ret) != 0)
285      return -1;
286    device_serial_to_adb_port_map_[device_serial] = port;
287    return port;
288  }
289
290  bool SendMessage(const std::string& msg, Socket* client_socket) {
291    bool result = client_socket->WriteString(msg);
292    DCHECK(result);
293    if (!result)
294      has_failed_ = true;
295    return result;
296  }
297
298  base::WeakPtrFactory<HostControllersManager> weak_ptr_factory_;
299  base::hash_map<std::string, int> device_serial_to_adb_port_map_;
300  scoped_ptr<HostControllerMap> controllers_;
301  bool has_failed_;
302  scoped_ptr<base::AtExitManager> at_exit_manager_;  // Needed by base::Thread.
303  scoped_ptr<base::Thread> thread_;
304};
305
306class ServerDelegate : public Daemon::ServerDelegate {
307 public:
308  ServerDelegate() : has_failed_(false) {}
309
310  bool has_failed() const {
311    return has_failed_ || controllers_manager_.has_failed();
312  }
313
314  // Daemon::ServerDelegate:
315  virtual void Init() OVERRIDE {
316    LOG(INFO) << "Starting host process daemon (pid=" << getpid() << ")";
317    DCHECK(!g_notifier);
318    g_notifier = new PipeNotifier();
319    signal(SIGTERM, KillHandler);
320    signal(SIGINT, KillHandler);
321  }
322
323  virtual void OnClientConnected(scoped_ptr<Socket> client_socket) OVERRIDE {
324    char buf[kBufSize];
325    const int bytes_read = client_socket->Read(buf, sizeof(buf));
326    if (bytes_read <= 0) {
327      if (client_socket->DidReceiveEvent())
328        return;
329      PError("Read()");
330      has_failed_ = true;
331      return;
332    }
333    const Pickle command_pickle(buf, bytes_read);
334    PickleIterator pickle_it(command_pickle);
335    std::string device_serial;
336    CHECK(pickle_it.ReadString(&device_serial));
337    int device_port;
338    if (!pickle_it.ReadInt(&device_port)) {
339      client_socket->WriteString("ERROR: missing device port");
340      return;
341    }
342    int host_port;
343    if (!pickle_it.ReadInt(&host_port))
344      host_port = -1;
345    controllers_manager_.HandleRequest(
346        device_serial, device_port, host_port, client_socket.Pass());
347  }
348
349 private:
350  bool has_failed_;
351  HostControllersManager controllers_manager_;
352
353  DISALLOW_COPY_AND_ASSIGN(ServerDelegate);
354};
355
356class ClientDelegate : public Daemon::ClientDelegate {
357 public:
358  ClientDelegate(const Pickle& command_pickle)
359      : command_pickle_(command_pickle),
360        has_failed_(false) {
361  }
362
363  bool has_failed() const { return has_failed_; }
364
365  // Daemon::ClientDelegate:
366  virtual void OnDaemonReady(Socket* daemon_socket) OVERRIDE {
367    // Send the forward command to the daemon.
368    CHECK_EQ(command_pickle_.size(),
369             daemon_socket->WriteNumBytes(command_pickle_.data(),
370                                          command_pickle_.size()));
371    char buf[kBufSize];
372    const int bytes_read = daemon_socket->Read(
373        buf, sizeof(buf) - 1 /* leave space for null terminator */);
374    CHECK_GT(bytes_read, 0);
375    DCHECK(bytes_read < sizeof(buf));
376    buf[bytes_read] = 0;
377    base::StringPiece msg(buf, bytes_read);
378    if (msg.starts_with("ERROR")) {
379      LOG(ERROR) << msg;
380      has_failed_ = true;
381      return;
382    }
383    printf("%s\n", buf);
384  }
385
386 private:
387  const Pickle command_pickle_;
388  bool has_failed_;
389};
390
// Prints the command line usage to stderr and terminates with exit status 1.
void ExitWithUsage() {
  std::cerr << "Usage: host_forwarder [options]\n\n"
               "Options:\n"
               "  --serial-id=[0-9A-Z]{16}\n"
               "  --map DEVICE_PORT HOST_PORT\n"
               "  --unmap DEVICE_PORT\n"
               "  --kill-server\n";
  exit(1);
}
400
401int PortToInt(const std::string& s) {
402  int value;
403  // Note that 0 is a valid port (used for dynamic port allocation).
404  if (!base::StringToInt(s, &value) || value < 0 ||
405      value > std::numeric_limits<uint16>::max()) {
406    LOG(ERROR) << "Could not convert string " << s << " to port";
407    ExitWithUsage();
408  }
409  return value;
410}
411
412int RunHostForwarder(int argc, char** argv) {
413  CommandLine::Init(argc, argv);
414  const CommandLine& cmd_line = *CommandLine::ForCurrentProcess();
415  bool kill_server = false;
416
417  Pickle pickle;
418  pickle.WriteString(
419      cmd_line.HasSwitch("serial-id") ?
420          cmd_line.GetSwitchValueASCII("serial-id") : std::string());
421
422  const std::vector<std::string> args = cmd_line.GetArgs();
423  if (cmd_line.HasSwitch("kill-server")) {
424    kill_server = true;
425  } else if (cmd_line.HasSwitch("unmap")) {
426    if (args.size() != 1)
427      ExitWithUsage();
428    // Note the minus sign below.
429    pickle.WriteInt(-PortToInt(args[0]));
430  } else if (cmd_line.HasSwitch("map")) {
431    if (args.size() != 2)
432      ExitWithUsage();
433    pickle.WriteInt(PortToInt(args[0]));
434    pickle.WriteInt(PortToInt(args[1]));
435  } else {
436    ExitWithUsage();
437  }
438
439  if (kill_server && args.size() > 0)
440    ExitWithUsage();
441
442  ClientDelegate client_delegate(pickle);
443  ServerDelegate daemon_delegate;
444  Daemon daemon(
445      kLogFilePath, kDaemonIdentifier, &client_delegate, &daemon_delegate,
446      &GetExitNotifierFD);
447
448  if (kill_server)
449    return !daemon.Kill();
450  if (!daemon.SpawnIfNeeded())
451    return 1;
452
453  return client_delegate.has_failed() || daemon_delegate.has_failed();
454}
455
456}  // namespace
457}  // namespace forwarder2
458
459int main(int argc, char** argv) {
460  return forwarder2::RunHostForwarder(argc, argv);
461}
462