1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <errno.h>
6#include <fcntl.h>
7#include <netinet/in.h>
8#include <netinet/tcp.h>
9#include <pthread.h>
10#include <signal.h>
11#include <stdio.h>
12#include <stdlib.h>
13#include <string.h>
14#include <sys/select.h>
15#include <sys/socket.h>
16#include <sys/wait.h>
17#include <unistd.h>
18
19#include "base/command_line.h"
20#include "base/logging.h"
21#include "base/posix/eintr_wrapper.h"
22#include "tools/android/common/adb_connection.h"
23#include "tools/android/common/daemon.h"
24#include "tools/android/common/net.h"
25
26namespace {
27
// Value of Server::thread_ while no server thread has been (successfully)
// started; JoinThread() skips joining in that case.
const pthread_t kInvalidThread = static_cast<pthread_t>(-1);
// Set to non-zero by the SIGTERM handler to ask all loops to stop. An object
// written from a signal handler must be volatile sig_atomic_t for the write
// to be well-defined.
// NOTE(review): worker threads also poll this flag; sig_atomic_t does not by
// itself make that cross-thread access race-free.
volatile sig_atomic_t g_killed = 0;
30
// Closes |fd| when it is a valid (non-negative) descriptor. errno is saved
// and restored around close() so a caller can still report the error that
// led to closing the socket.
void CloseSocket(int fd) {
  if (fd < 0)
    return;
  const int saved_errno = errno;
  close(fd);
  errno = saved_errno;
}
38
39class Buffer {
40 public:
41  Buffer()
42      : bytes_read_(0),
43        write_offset_(0) {
44  }
45
46  bool CanRead() {
47    return bytes_read_ == 0;
48  }
49
50  bool CanWrite() {
51    return write_offset_ < bytes_read_;
52  }
53
54  int Read(int fd) {
55    int ret = -1;
56    if (CanRead()) {
57      ret = HANDLE_EINTR(read(fd, buffer_, kBufferSize));
58      if (ret > 0)
59        bytes_read_ = ret;
60    }
61    return ret;
62  }
63
64  int Write(int fd) {
65    int ret = -1;
66    if (CanWrite()) {
67      ret = HANDLE_EINTR(write(fd, buffer_ + write_offset_,
68                               bytes_read_ - write_offset_));
69      if (ret > 0) {
70        write_offset_ += ret;
71        if (write_offset_ == bytes_read_) {
72          write_offset_ = 0;
73          bytes_read_ = 0;
74        }
75      }
76    }
77    return ret;
78  }
79
80 private:
81  // A big buffer to let our file-over-http bridge work more like real file.
82  static const int kBufferSize = 1024 * 128;
83  int bytes_read_;
84  int write_offset_;
85  char buffer_[kBufferSize];
86
87  DISALLOW_COPY_AND_ASSIGN(Buffer);
88};
89
class Server;

// Start-up argument for ForwarderThread(). The server thread allocates it
// with new; the forwarder thread copies the fields out and deletes it.
struct ForwarderThreadInfo {
  ForwarderThreadInfo(Server* owner, int index)
      : server(owner), forwarder_index(index) {}

  Server* server;        // Server that accepted the connection.
  int forwarder_index;   // Slot in that server's ForwarderInfo table.
};

// Bookkeeping for one forwarded connection. A start_time of 0 marks an unused
// slot. socket1 is the accepted client socket and socket2 the connection to
// the host; each *_bytes field totals the bytes read from that socket and
// each *_last_byte_time holds the time() of the last read from it.
struct ForwarderInfo {
  time_t start_time;

  int socket1;
  time_t socket1_last_byte_time;
  size_t socket1_bytes;

  int socket2;
  time_t socket2_last_byte_time;
  size_t socket2_bytes;
};
110
111class Server {
112 public:
113  Server()
114      : thread_(kInvalidThread),
115        socket_(-1) {
116    memset(forward_to_, 0, sizeof(forward_to_));
117    memset(&forwarders_, 0, sizeof(forwarders_));
118  }
119
120  int GetFreeForwarderIndex() {
121    for (int i = 0; i < kMaxForwarders; i++) {
122      if (forwarders_[i].start_time == 0)
123        return i;
124    }
125    return -1;
126  }
127
128  void DisposeForwarderInfo(int index) {
129    forwarders_[index].start_time = 0;
130  }
131
132  ForwarderInfo* GetForwarderInfo(int index) {
133    return &forwarders_[index];
134  }
135
136  void DumpInformation() {
137    LOG(INFO) << "Server information: " << forward_to_;
138    LOG(INFO) << "No.: age up(bytes,idle) down(bytes,idle)";
139    int count = 0;
140    time_t now = time(NULL);
141    for (int i = 0; i < kMaxForwarders; i++) {
142      const ForwarderInfo& info = forwarders_[i];
143      if (info.start_time) {
144        count++;
145        LOG(INFO) << count << ": " << now - info.start_time << " up("
146                  << info.socket1_bytes << ","
147                  << now - info.socket1_last_byte_time << " down("
148                  << info.socket2_bytes << ","
149                  << now - info.socket2_last_byte_time << ")";
150      }
151    }
152  }
153
154  void Shutdown() {
155    if (socket_ >= 0)
156      shutdown(socket_, SHUT_RDWR);
157  }
158
159  bool InitSocket(const char* arg);
160
161  void StartThread() {
162    pthread_create(&thread_, NULL, ServerThread, this);
163  }
164
165  void JoinThread() {
166    if (thread_ != kInvalidThread)
167      pthread_join(thread_, NULL);
168  }
169
170 private:
171  static void* ServerThread(void* arg);
172
173  // There are 3 kinds of threads that will access the array:
174  // 1. Server thread will get a free ForwarderInfo and initialize it;
175  // 2. Forwarder threads will dispose the ForwarderInfo when it finishes;
176  // 3. Main thread will iterate and print the forwarders.
177  // Using an array is not optimal, but can avoid locks or other complex
178  // inter-thread communication.
179  static const int kMaxForwarders = 512;
180  ForwarderInfo forwarders_[kMaxForwarders];
181
182  pthread_t thread_;
183  int socket_;
184  char forward_to_[40];
185
186  DISALLOW_COPY_AND_ASSIGN(Server);
187};
188
189// Forwards all outputs from one socket to another socket.
190void* ForwarderThread(void* arg) {
191  ForwarderThreadInfo* thread_info =
192      reinterpret_cast<ForwarderThreadInfo*>(arg);
193  Server* server = thread_info->server;
194  int index = thread_info->forwarder_index;
195  delete thread_info;
196  ForwarderInfo* info = server->GetForwarderInfo(index);
197  int socket1 = info->socket1;
198  int socket2 = info->socket2;
199  int nfds = socket1 > socket2 ? socket1 + 1 : socket2 + 1;
200  fd_set read_fds;
201  fd_set write_fds;
202  Buffer buffer1;
203  Buffer buffer2;
204
205  while (!g_killed) {
206    FD_ZERO(&read_fds);
207    if (buffer1.CanRead())
208      FD_SET(socket1, &read_fds);
209    if (buffer2.CanRead())
210      FD_SET(socket2, &read_fds);
211
212    FD_ZERO(&write_fds);
213    if (buffer1.CanWrite())
214      FD_SET(socket2, &write_fds);
215    if (buffer2.CanWrite())
216      FD_SET(socket1, &write_fds);
217
218    if (HANDLE_EINTR(select(nfds, &read_fds, &write_fds, NULL, NULL)) <= 0) {
219      LOG(ERROR) << "Select error: " << strerror(errno);
220      break;
221    }
222
223    int now = time(NULL);
224    if (FD_ISSET(socket1, &read_fds)) {
225      info->socket1_last_byte_time = now;
226      int bytes = buffer1.Read(socket1);
227      if (bytes <= 0)
228        break;
229      info->socket1_bytes += bytes;
230    }
231    if (FD_ISSET(socket2, &read_fds)) {
232      info->socket2_last_byte_time = now;
233      int bytes = buffer2.Read(socket2);
234      if (bytes <= 0)
235        break;
236      info->socket2_bytes += bytes;
237    }
238    if (FD_ISSET(socket1, &write_fds)) {
239      if (buffer2.Write(socket1) <= 0)
240        break;
241    }
242    if (FD_ISSET(socket2, &write_fds)) {
243      if (buffer1.Write(socket2) <= 0)
244        break;
245    }
246  }
247
248  CloseSocket(socket1);
249  CloseSocket(socket2);
250  server->DisposeForwarderInfo(index);
251  return NULL;
252}
253
254// Listens to a server socket. On incoming request, forward it to the host.
255// static
256void* Server::ServerThread(void* arg) {
257  Server* server = reinterpret_cast<Server*>(arg);
258  while (!g_killed) {
259    int forwarder_index = server->GetFreeForwarderIndex();
260    if (forwarder_index < 0) {
261      LOG(ERROR) << "Too many forwarders";
262      continue;
263    }
264
265    struct sockaddr_in addr;
266    socklen_t addr_len = sizeof(addr);
267    int socket = HANDLE_EINTR(accept(server->socket_,
268                                     reinterpret_cast<sockaddr*>(&addr),
269                                     &addr_len));
270    if (socket < 0) {
271      LOG(ERROR) << "Failed to accept: " << strerror(errno);
272      break;
273    }
274    tools::DisableNagle(socket);
275
276    int host_socket = tools::ConnectAdbHostSocket(server->forward_to_);
277    if (host_socket >= 0) {
278      // Set NONBLOCK flag because we use select().
279      fcntl(socket, F_SETFL, fcntl(socket, F_GETFL) | O_NONBLOCK);
280      fcntl(host_socket, F_SETFL, fcntl(host_socket, F_GETFL) | O_NONBLOCK);
281
282      ForwarderInfo* forwarder_info = server->GetForwarderInfo(forwarder_index);
283      time_t now = time(NULL);
284      forwarder_info->start_time = now;
285      forwarder_info->socket1 = socket;
286      forwarder_info->socket1_last_byte_time = now;
287      forwarder_info->socket1_bytes = 0;
288      forwarder_info->socket2 = host_socket;
289      forwarder_info->socket2_last_byte_time = now;
290      forwarder_info->socket2_bytes = 0;
291
292      pthread_t thread;
293      pthread_create(&thread, NULL, ForwarderThread,
294                     new ForwarderThreadInfo(server, forwarder_index));
295    } else {
296      // Close the unused client socket which is failed to connect to host.
297      CloseSocket(socket);
298    }
299  }
300
301  CloseSocket(server->socket_);
302  server->socket_ = -1;
303  return NULL;
304}
305
306// Format of arg: <Device port>[:<Forward to port>:<Forward to address>]
307bool Server::InitSocket(const char* arg) {
308  char* endptr;
309  int local_port = static_cast<int>(strtol(arg, &endptr, 10));
310  if (local_port < 0)
311    return false;
312
313  if (*endptr != ':') {
314    snprintf(forward_to_, sizeof(forward_to_), "%d:127.0.0.1", local_port);
315  } else {
316    strncpy(forward_to_, endptr + 1, sizeof(forward_to_) - 1);
317  }
318
319  socket_ = socket(AF_INET, SOCK_STREAM, 0);
320  if (socket_ < 0) {
321    perror("server socket");
322    return false;
323  }
324  tools::DisableNagle(socket_);
325
326  sockaddr_in addr;
327  memset(&addr, 0, sizeof(addr));
328  addr.sin_family = AF_INET;
329  addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
330  addr.sin_port = htons(local_port);
331  int reuse_addr = 1;
332  setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR,
333             &reuse_addr, sizeof(reuse_addr));
334  tools::DeferAccept(socket_);
335  if (HANDLE_EINTR(bind(socket_, reinterpret_cast<sockaddr*>(&addr),
336                        sizeof(addr))) < 0 ||
337      HANDLE_EINTR(listen(socket_, 5)) < 0) {
338    perror("server bind");
339    CloseSocket(socket_);
340    socket_ = -1;
341    return false;
342  }
343
344  if (local_port == 0) {
345    socklen_t addrlen = sizeof(addr);
346    if (getsockname(socket_, reinterpret_cast<sockaddr*>(&addr), &addrlen)
347        != 0) {
348      perror("get listen address");
349      CloseSocket(socket_);
350      socket_ = -1;
351      return false;
352    }
353    local_port = ntohs(addr.sin_port);
354  }
355
356  printf("Forwarding device port %d to host %s\n", local_port, forward_to_);
357  return true;
358}
359
360int g_server_count = 0;
361Server* g_servers = NULL;
362
363void KillHandler(int unused) {
364  g_killed = true;
365  for (int i = 0; i < g_server_count; i++)
366    g_servers[i].Shutdown();
367}
368
369void DumpInformation(int unused) {
370  for (int i = 0; i < g_server_count; i++)
371    g_servers[i].DumpInformation();
372}
373
374}  // namespace
375
376int main(int argc, char** argv) {
377  printf("Android device to host TCP forwarder\n");
378  printf("Like 'adb forward' but in the reverse direction\n");
379
380  CommandLine command_line(argc, argv);
381  CommandLine::StringVector server_args = command_line.GetArgs();
382  if (tools::HasHelpSwitch(command_line) || server_args.empty()) {
383    tools::ShowHelp(
384        argv[0],
385        "<Device port>[:<Forward to port>:<Forward to address>] ...",
386        "  <Forward to port> default is <Device port>\n"
387        "  <Forward to address> default is 127.0.0.1\n"
388        "If <Device port> is 0, a port will by dynamically allocated.\n");
389    return 0;
390  }
391
392  g_servers = new Server[server_args.size()];
393  g_server_count = 0;
394  int failed_count = 0;
395  for (size_t i = 0; i < server_args.size(); i++) {
396    if (!g_servers[g_server_count].InitSocket(server_args[i].c_str())) {
397      printf("Couldn't start forwarder server for port spec: %s\n",
398             server_args[i].c_str());
399      ++failed_count;
400    } else {
401      ++g_server_count;
402    }
403  }
404
405  if (g_server_count == 0) {
406    printf("No forwarder servers could be started. Exiting.\n");
407    delete [] g_servers;
408    return failed_count;
409  }
410
411  if (!tools::HasNoSpawnDaemonSwitch(command_line))
412    tools::SpawnDaemon(failed_count);
413
414  signal(SIGTERM, KillHandler);
415  signal(SIGUSR2, DumpInformation);
416
417  for (int i = 0; i < g_server_count; i++)
418    g_servers[i].StartThread();
419  for (int i = 0; i < g_server_count; i++)
420    g_servers[i].JoinThread();
421  g_server_count = 0;
422  delete [] g_servers;
423
424  return 0;
425}
426
427