// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <errno.h>
#include <signal.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>

#include <cstdio>
#include <iostream>
#include <limits>
#include <string>
#include <utility>
#include <vector>

#include "base/at_exit.h"
#include "base/basictypes.h"
#include "base/bind.h"
#include "base/command_line.h"
#include "base/compiler_specific.h"
#include "base/containers/hash_tables.h"
#include "base/files/file_path.h"
#include "base/files/file_util.h"
#include "base/logging.h"
#include "base/memory/linked_ptr.h"
#include "base/memory/scoped_vector.h"
#include "base/memory/weak_ptr.h"
#include "base/pickle.h"
#include "base/safe_strerror_posix.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/string_piece.h"
#include "base/strings/string_split.h"
#include "base/strings/string_util.h"
#include "base/strings/stringprintf.h"
#include "base/task_runner.h"
#include "base/threading/thread.h"
#include "tools/android/forwarder2/common.h"
#include "tools/android/forwarder2/daemon.h"
#include "tools/android/forwarder2/host_controller.h"
#include "tools/android/forwarder2/pipe_notifier.h"
#include "tools/android/forwarder2/socket.h"
#include "tools/android/forwarder2/util.h"

namespace forwarder2 {
namespace {

// Log file path handed to the Daemon constructed in RunHostForwarder().
const char kLogFilePath[] = "/tmp/host_forwarder_log";
// Identifier handed to the Daemon constructed in RunHostForwarder();
// presumably used to locate the running daemon instance — confirm in daemon.h.
const char kDaemonIdentifier[] = "chrome_host_forwarder_daemon";

// Name of the command-line switch asking the daemon to exit (matches the
// "--kill-server" entry of the usage text).
const char kKillServerCommand[] = "kill-server";
// NOTE(review): "forward" does not appear as a switch in the usage text;
// this constant may be stale.
const char kForwardCommand[] = "forward";

// Size of the stack buffers used for signal-handler log messages and for the
// command/reply messages exchanged between the client and the daemon.
const int kBufSize = 256;

// Needs to be global to be able to be accessed from the signal handler.
// Created by ServerDelegate::Init() in the daemon process.
PipeNotifier* g_notifier = NULL;

// Lets the daemon fetch the exit notifier file descriptor.
61int GetExitNotifierFD() { 62 DCHECK(g_notifier); 63 return g_notifier->receiver_fd(); 64} 65 66void KillHandler(int signal_number) { 67 char buf[kBufSize]; 68 if (signal_number != SIGTERM && signal_number != SIGINT) { 69 snprintf(buf, sizeof(buf), "Ignoring unexpected signal %d.", signal_number); 70 SIGNAL_SAFE_LOG(WARNING, buf); 71 return; 72 } 73 snprintf(buf, sizeof(buf), "Received signal %d.", signal_number); 74 SIGNAL_SAFE_LOG(WARNING, buf); 75 static int s_kill_handler_count = 0; 76 CHECK(g_notifier); 77 // If for some reason the forwarder get stuck in any socket waiting forever, 78 // we can send a SIGKILL or SIGINT three times to force it die 79 // (non-nicely). This is useful when debugging. 80 ++s_kill_handler_count; 81 if (!g_notifier->Notify() || s_kill_handler_count > 2) 82 exit(1); 83} 84 85// Manages HostController instances. There is one HostController instance for 86// each connection being forwarded. Note that forwarding can happen with many 87// devices (identified with a serial id). 88class HostControllersManager { 89 public: 90 HostControllersManager() 91 : weak_ptr_factory_(this), 92 controllers_(new HostControllerMap()), 93 has_failed_(false) { 94 } 95 96 ~HostControllersManager() { 97 if (!thread_.get()) 98 return; 99 // Delete the controllers on the thread they were created on. 100 thread_->message_loop_proxy()->DeleteSoon( 101 FROM_HERE, controllers_.release()); 102 } 103 104 void HandleRequest(const std::string& device_serial, 105 int device_port, 106 int host_port, 107 scoped_ptr<Socket> client_socket) { 108 // Lazy initialize so that the CLI process doesn't get this thread created. 
109 InitOnce(); 110 thread_->message_loop_proxy()->PostTask( 111 FROM_HERE, 112 base::Bind( 113 &HostControllersManager::HandleRequestOnInternalThread, 114 base::Unretained(this), device_serial, device_port, host_port, 115 base::Passed(&client_socket))); 116 } 117 118 bool has_failed() const { return has_failed_; } 119 120 private: 121 typedef base::hash_map< 122 std::string, linked_ptr<HostController> > HostControllerMap; 123 124 static std::string MakeHostControllerMapKey(int adb_port, int device_port) { 125 return base::StringPrintf("%d:%d", adb_port, device_port); 126 } 127 128 void InitOnce() { 129 if (thread_.get()) 130 return; 131 at_exit_manager_.reset(new base::AtExitManager()); 132 thread_.reset(new base::Thread("HostControllersManagerThread")); 133 thread_->Start(); 134 } 135 136 // Invoked when a HostController instance reports an error (e.g. due to a 137 // device connectivity issue). Note that this could be called after the 138 // controller manager was destroyed which is why a weak pointer is used. 139 static void DeleteHostController( 140 const base::WeakPtr<HostControllersManager>& manager_ptr, 141 scoped_ptr<HostController> host_controller) { 142 HostController* const controller = host_controller.release(); 143 HostControllersManager* const manager = manager_ptr.get(); 144 if (!manager) { 145 // Note that |controller| is not leaked in this case since the host 146 // controllers manager owns the controllers. If the manager was deleted 147 // then all the controllers (including |controller|) were also deleted. 148 return; 149 } 150 DCHECK(manager->thread_->message_loop_proxy()->RunsTasksOnCurrentThread()); 151 // Note that this will delete |controller| which is owned by the map. 
152 DeleteRefCountedValueInMap( 153 MakeHostControllerMapKey( 154 controller->adb_port(), controller->device_port()), 155 manager->controllers_.get()); 156 } 157 158 void HandleRequestOnInternalThread(const std::string& device_serial, 159 int device_port, 160 int host_port, 161 scoped_ptr<Socket> client_socket) { 162 const int adb_port = GetAdbPortForDevice(device_serial); 163 if (adb_port < 0) { 164 SendMessage( 165 "ERROR: could not get adb port for device. You might need to add " 166 "'adb' to your PATH or provide the device serial id.", 167 client_socket.get()); 168 return; 169 } 170 if (device_port < 0) { 171 // Remove the previously created host controller. 172 const std::string controller_key = MakeHostControllerMapKey( 173 adb_port, -device_port); 174 const bool controller_did_exist = DeleteRefCountedValueInMap( 175 controller_key, controllers_.get()); 176 SendMessage( 177 !controller_did_exist ? "ERROR: could not unmap port" : "OK", 178 client_socket.get()); 179 180 RemoveAdbPortForDeviceIfNeeded(device_serial); 181 return; 182 } 183 if (host_port < 0) { 184 SendMessage("ERROR: missing host port", client_socket.get()); 185 return; 186 } 187 const bool use_dynamic_port_allocation = device_port == 0; 188 if (!use_dynamic_port_allocation) { 189 const std::string controller_key = MakeHostControllerMapKey( 190 adb_port, device_port); 191 if (controllers_->find(controller_key) != controllers_->end()) { 192 LOG(INFO) << "Already forwarding device port " << device_port 193 << " to host port " << host_port; 194 SendMessage(base::StringPrintf("%d:%d", device_port, host_port), 195 client_socket.get()); 196 return; 197 } 198 } 199 // Create a new host controller. 
200 scoped_ptr<HostController> host_controller( 201 HostController::Create( 202 device_port, host_port, adb_port, GetExitNotifierFD(), 203 base::Bind(&HostControllersManager::DeleteHostController, 204 weak_ptr_factory_.GetWeakPtr()))); 205 if (!host_controller.get()) { 206 has_failed_ = true; 207 SendMessage("ERROR: Connection to device failed.", client_socket.get()); 208 return; 209 } 210 // Get the current allocated port. 211 device_port = host_controller->device_port(); 212 LOG(INFO) << "Forwarding device port " << device_port << " to host port " 213 << host_port; 214 const std::string msg = base::StringPrintf("%d:%d", device_port, host_port); 215 if (!SendMessage(msg, client_socket.get())) 216 return; 217 host_controller->Start(); 218 controllers_->insert( 219 std::make_pair(MakeHostControllerMapKey(adb_port, device_port), 220 linked_ptr<HostController>(host_controller.release()))); 221 } 222 223 void RemoveAdbPortForDeviceIfNeeded(const std::string& device_serial) { 224 base::hash_map<std::string, int>::const_iterator it = 225 device_serial_to_adb_port_map_.find(device_serial); 226 if (it == device_serial_to_adb_port_map_.end()) 227 return; 228 229 int port = it->second; 230 const std::string prefix = base::StringPrintf("%d:", port); 231 for (HostControllerMap::const_iterator others = controllers_->begin(); 232 others != controllers_->end(); ++others) { 233 if (others->first.find(prefix) == 0U) 234 return; 235 } 236 // No other port is being forwarded to this device: 237 // - Remove it from our internal serial -> adb port map. 238 // - Remove from "adb forward" command. 239 LOG(INFO) << "Device " << device_serial << " has no more ports."; 240 device_serial_to_adb_port_map_.erase(device_serial); 241 const std::string serial_part = device_serial.empty() ? 
242 std::string() : std::string("-s ") + device_serial; 243 const std::string command = base::StringPrintf( 244 "adb %s forward --remove tcp:%d", 245 serial_part.c_str(), 246 port); 247 const int ret = system(command.c_str()); 248 LOG(INFO) << command << " ret: " << ret; 249 // Wait for the socket to be fully unmapped. 250 const std::string port_mapped_cmd = base::StringPrintf( 251 "lsof -nPi:%d", 252 port); 253 const int poll_interval_us = 500 * 1000; 254 int retries = 3; 255 while (retries) { 256 const int port_unmapped = system(port_mapped_cmd.c_str()); 257 LOG(INFO) << "Device " << device_serial << " port " << port << " unmap " 258 << port_unmapped; 259 if (port_unmapped) 260 break; 261 --retries; 262 usleep(poll_interval_us); 263 } 264 } 265 266 int GetAdbPortForDevice(const std::string& device_serial) { 267 base::hash_map<std::string, int>::const_iterator it = 268 device_serial_to_adb_port_map_.find(device_serial); 269 if (it != device_serial_to_adb_port_map_.end()) 270 return it->second; 271 Socket bind_socket; 272 CHECK(bind_socket.BindTcp("127.0.0.1", 0)); 273 const int port = bind_socket.GetPort(); 274 bind_socket.Close(); 275 const std::string serial_part = device_serial.empty() ? 
276 std::string() : std::string("-s ") + device_serial; 277 const std::string command = base::StringPrintf( 278 "adb %s forward tcp:%d localabstract:chrome_device_forwarder", 279 serial_part.c_str(), 280 port); 281 LOG(INFO) << command; 282 const int ret = system(command.c_str()); 283 if (ret < 0 || !WIFEXITED(ret) || WEXITSTATUS(ret) != 0) 284 return -1; 285 device_serial_to_adb_port_map_[device_serial] = port; 286 return port; 287 } 288 289 bool SendMessage(const std::string& msg, Socket* client_socket) { 290 bool result = client_socket->WriteString(msg); 291 DCHECK(result); 292 if (!result) 293 has_failed_ = true; 294 return result; 295 } 296 297 base::WeakPtrFactory<HostControllersManager> weak_ptr_factory_; 298 base::hash_map<std::string, int> device_serial_to_adb_port_map_; 299 scoped_ptr<HostControllerMap> controllers_; 300 bool has_failed_; 301 scoped_ptr<base::AtExitManager> at_exit_manager_; // Needed by base::Thread. 302 scoped_ptr<base::Thread> thread_; 303}; 304 305class ServerDelegate : public Daemon::ServerDelegate { 306 public: 307 ServerDelegate() : has_failed_(false) {} 308 309 bool has_failed() const { 310 return has_failed_ || controllers_manager_.has_failed(); 311 } 312 313 // Daemon::ServerDelegate: 314 virtual void Init() OVERRIDE { 315 LOG(INFO) << "Starting host process daemon (pid=" << getpid() << ")"; 316 DCHECK(!g_notifier); 317 g_notifier = new PipeNotifier(); 318 signal(SIGTERM, KillHandler); 319 signal(SIGINT, KillHandler); 320 } 321 322 virtual void OnClientConnected(scoped_ptr<Socket> client_socket) OVERRIDE { 323 char buf[kBufSize]; 324 const int bytes_read = client_socket->Read(buf, sizeof(buf)); 325 if (bytes_read <= 0) { 326 if (client_socket->DidReceiveEvent()) 327 return; 328 PError("Read()"); 329 has_failed_ = true; 330 return; 331 } 332 const Pickle command_pickle(buf, bytes_read); 333 PickleIterator pickle_it(command_pickle); 334 std::string device_serial; 335 CHECK(pickle_it.ReadString(&device_serial)); 336 int 
device_port; 337 if (!pickle_it.ReadInt(&device_port)) { 338 client_socket->WriteString("ERROR: missing device port"); 339 return; 340 } 341 int host_port; 342 if (!pickle_it.ReadInt(&host_port)) 343 host_port = -1; 344 controllers_manager_.HandleRequest( 345 device_serial, device_port, host_port, client_socket.Pass()); 346 } 347 348 private: 349 bool has_failed_; 350 HostControllersManager controllers_manager_; 351 352 DISALLOW_COPY_AND_ASSIGN(ServerDelegate); 353}; 354 355class ClientDelegate : public Daemon::ClientDelegate { 356 public: 357 ClientDelegate(const Pickle& command_pickle) 358 : command_pickle_(command_pickle), 359 has_failed_(false) { 360 } 361 362 bool has_failed() const { return has_failed_; } 363 364 // Daemon::ClientDelegate: 365 virtual void OnDaemonReady(Socket* daemon_socket) OVERRIDE { 366 // Send the forward command to the daemon. 367 CHECK_EQ(command_pickle_.size(), 368 daemon_socket->WriteNumBytes(command_pickle_.data(), 369 command_pickle_.size())); 370 char buf[kBufSize]; 371 const int bytes_read = daemon_socket->Read( 372 buf, sizeof(buf) - 1 /* leave space for null terminator */); 373 CHECK_GT(bytes_read, 0); 374 DCHECK(bytes_read < sizeof(buf)); 375 buf[bytes_read] = 0; 376 base::StringPiece msg(buf, bytes_read); 377 if (msg.starts_with("ERROR")) { 378 LOG(ERROR) << msg; 379 has_failed_ = true; 380 return; 381 } 382 printf("%s\n", buf); 383 } 384 385 private: 386 const Pickle command_pickle_; 387 bool has_failed_; 388}; 389 390void ExitWithUsage() { 391 std::cerr << "Usage: host_forwarder [options]\n\n" 392 "Options:\n" 393 " --serial-id=[0-9A-Z]{16}]\n" 394 " --map DEVICE_PORT HOST_PORT\n" 395 " --unmap DEVICE_PORT\n" 396 " --kill-server\n"; 397 exit(1); 398} 399 400int PortToInt(const std::string& s) { 401 int value; 402 // Note that 0 is a valid port (used for dynamic port allocation). 
403 if (!base::StringToInt(s, &value) || value < 0 || 404 value > std::numeric_limits<uint16>::max()) { 405 LOG(ERROR) << "Could not convert string " << s << " to port"; 406 ExitWithUsage(); 407 } 408 return value; 409} 410 411int RunHostForwarder(int argc, char** argv) { 412 CommandLine::Init(argc, argv); 413 const CommandLine& cmd_line = *CommandLine::ForCurrentProcess(); 414 bool kill_server = false; 415 416 Pickle pickle; 417 pickle.WriteString( 418 cmd_line.HasSwitch("serial-id") ? 419 cmd_line.GetSwitchValueASCII("serial-id") : std::string()); 420 421 const std::vector<std::string> args = cmd_line.GetArgs(); 422 if (cmd_line.HasSwitch("kill-server")) { 423 kill_server = true; 424 } else if (cmd_line.HasSwitch("unmap")) { 425 if (args.size() != 1) 426 ExitWithUsage(); 427 // Note the minus sign below. 428 pickle.WriteInt(-PortToInt(args[0])); 429 } else if (cmd_line.HasSwitch("map")) { 430 if (args.size() != 2) 431 ExitWithUsage(); 432 pickle.WriteInt(PortToInt(args[0])); 433 pickle.WriteInt(PortToInt(args[1])); 434 } else { 435 ExitWithUsage(); 436 } 437 438 if (kill_server && args.size() > 0) 439 ExitWithUsage(); 440 441 ClientDelegate client_delegate(pickle); 442 ServerDelegate daemon_delegate; 443 Daemon daemon( 444 kLogFilePath, kDaemonIdentifier, &client_delegate, &daemon_delegate, 445 &GetExitNotifierFD); 446 447 if (kill_server) 448 return !daemon.Kill(); 449 if (!daemon.SpawnIfNeeded()) 450 return 1; 451 452 return client_delegate.has_failed() || daemon_delegate.has_failed(); 453} 454 455} // namespace 456} // namespace forwarder2 457 458int main(int argc, char** argv) { 459 return forwarder2::RunHostForwarder(argc, argv); 460} 461