1/* tcpsvd.c - TCP(UDP)/IP service daemon
2 *
3 * Copyright 2013 Ashwini Kumar <ak.ashwini@gmail.com>
4 * Copyright 2013 Sandeep Sharma <sandeep.jack2756@gmail.com>
5 * Copyright 2013 Kyungwan Han <asura321@gmail.com>
6 *
7 * No Standard.
8
9USE_TCPSVD(NEWTOY(tcpsvd, "^<3c#=30<1C:b#=20<0u:l:hEv", TOYFLAG_USR|TOYFLAG_BIN))
10USE_TCPSVD(OLDTOY(udpsvd, tcpsvd, TOYFLAG_USR|TOYFLAG_BIN))
11
12config TCPSVD
13  bool "tcpsvd"
14  default n
15  help
16    usage: tcpsvd [-hEv] [-c N] [-C N[:MSG]] [-b N] [-u User] [-l Name] IP Port Prog
17    usage: udpsvd [-hEv] [-c N] [-u User] [-l Name] IP Port Prog
18
    Create TCP/UDP socket, bind to IP:PORT and listen for incoming connections.
20    Run PROG for each connection.
21
22    IP            IP to listen on, 0 = all
23    PORT          Port to listen on
24    PROG ARGS     Program to run
25    -l NAME       Local hostname (else looks up local hostname in DNS)
26    -u USER[:GRP] Change to user/group after bind
27    -c N          Handle up to N (> 0) connections simultaneously
28    -b N          (TCP Only) Allow a backlog of approximately N TCP SYNs
29    -C N[:MSG]    (TCP Only) Allow only up to N (> 0) connections from the same IP
30                  New connections from this IP address are closed
31                  immediately. MSG is written to the peer before close
32    -h            Look up peer's hostname
33    -E            Don't set up environment variables
34    -v            Verbose
35*/
36
37#define FOR_tcpsvd
38#include "toys.h"
39
// Per-command state for tcpsvd/udpsvd.
// NOTE(review): in toybox the leading fields are filled from the NEWTOY()
// option string (here -l, -u, -b, -C, -c); keep their order in sync with
// it — TODO confirm against toybox's GLOBALS conventions.
GLOBALS(
  char *name;   // -l NAME: local hostname exported as PROTOLOCALHOST with -h
  char *user;   // -u USER[:GRP]: account to switch to after bind
  long bn;      // -b N: listen() backlog (option default 20)
  char *nmsg;   // -C N[:MSG]: raw argument, split into N and MSG in main
  long cn;      // -c N: max simultaneous connections (option default 30)

  int maxc;      // N parsed from -C: per-IP connection limit
  int count_all; // connections currently being served
  int udp;       // nonzero when invoked as udpsvd
)
51
// Pid list node: one entry per running child, remembering the numeric peer
// address string the child was started for.
struct list_pid {
  struct list_pid *next;
  char *ip;
  int pid;
};

// Hash bucket node: one entry per peer IP that has live connections.
struct list {
  struct list *next;
  char *d;      // numeric peer IP string (the hash key)
  int count;    // number of live connections from this IP
};

struct hashed {
  struct list *head;
};

#define HASH_NR 256
// Keep file-scope state static. These generic names (h, pids) then have
// internal linkage and cannot collide with identically named globals in other
// translation units linked into the same binary.
static struct hashed h[HASH_NR];
static struct list_pid *pids = NULL;
71
72// convert IP address to string.
73static char *sock_to_address(struct sockaddr *sock, int flags)
74{
75  char hbuf[NI_MAXHOST] = {0,};
76  char sbuf[NI_MAXSERV] = {0,};
77  int status = 0;
78  socklen_t len = sizeof(struct sockaddr_in6);
79
80  if (!(status = getnameinfo(sock, len, hbuf, sizeof(hbuf), sbuf,
81          sizeof(sbuf), flags))) {
82    if (flags & NI_NUMERICSERV) return xmprintf("%s:%s",hbuf, sbuf);
83    return xmprintf("%s",hbuf);
84  }
85  error_exit("getnameinfo: %s", gai_strerror(status));
86}
87
88// Insert pid, ip and fd in the list.
89static void insert(struct list_pid **l, int pid, char *addr)
90{
91  struct list_pid *newnode = xmalloc(sizeof(struct list_pid));
92  newnode->pid = pid;
93  newnode->ip = addr;
94  newnode->next = NULL;
95  if (!*l) *l = newnode;
96  else {
97    newnode->next = (*l);
98   *l = newnode;
99  }
100}
101
102// Hashing of IP address.
103static int haship( char *addr)
104{
105  uint32_t ip[8] = {0,};
106  int count = 0, i = 0;
107
108  if (!addr) error_exit("NULL ip");
109  while (i < strlen(addr)) {
110    while (addr[i] && (addr[i] != ':') && (addr[i] != '.')) {
111      ip[count] = ip[count]*10 + (addr[i]-'0');
112      i++;
113    }
114    if (i >= strlen(addr)) break;
115    count++;
116    i++;
117  }
118  return (ip[0]^ip[1]^ip[2]^ip[3]^ip[4]^ip[5]^ip[6]^ip[7])%HASH_NR;
119}
120
121// Remove a node from the list.
122static char *delete(struct list_pid **pids, int pid)
123{
124  struct list_pid *prev, *free_node, *head = *pids;
125  char *ip = NULL;
126
127  if (!head) return NULL;
128  prev = free_node = NULL;
129  while (head) {
130    if (head->pid == pid) {
131      ip = head->ip;
132      free_node = head;
133      if (!prev) *pids = head->next;
134      else prev->next = head->next;
135      free(free_node);
136      return ip;
137    }
138    prev = head;
139    head = head->next;
140  }
141  return NULL;
142}
143
144// decrement the ref count fora connection, if count reches ZERO then remove the node
145static void remove_connection(char *ip)
146{
147  struct list *head, *prev = NULL, *free_node = NULL;
148  int hash = haship(ip);
149
150  head = h[hash].head;
151  while (head) {
152    if (!strcmp(ip, head->d)) {
153      head->count--;
154      free_node = head;
155      if (!head->count) {
156        if (!prev) h[hash].head = head->next;
157        else prev->next = head->next;
158        free(free_node);
159      }
160      break;
161    }
162    prev = head;
163    head = head->next;
164  }
165  free(ip);
166}
167
168// Handler function.
169static void handle_exit(int sig)
170{
171  int status;
172  pid_t pid_n = wait(&status);
173
174  if (pid_n <= 0) return;
175  char *ip = delete(&pids, pid_n);
176  if (!ip) return;
177  remove_connection(ip);
178  TT.count_all--;
179  if (toys.optflags & FLAG_v) {
180    if (WIFEXITED(status))
181      xprintf("%s: end %d exit %d\n",toys.which->name, pid_n, WEXITSTATUS(status));
182    else if (WIFSIGNALED(status))
183      xprintf("%s: end %d signaled %d\n",toys.which->name, pid_n, WTERMSIG(status));
184    if (TT.cn > 1) xprintf("%s: status %d/%d\n",toys.which->name, TT.count_all, TT.cn);
185  }
186}
187
// Resolve ug ("USER[:GROUP]") to *uid and *gid.
// A user or group name that is not found is retried as a numeric id.
// *gid defaults to the user's primary group unless GROUP is given.
// ug is modified in place (the ':' becomes a NUL). Exits on an unknown user
// or group.
static void get_uidgid(uid_t *uid, gid_t *gid, char *ug)
{
  struct passwd *pw;
  char *grname = strchr(ug, ':');

  if (grname) *grname++ = '\0';

  pw = getpwnam(ug);
  if (!pw) {
    pw = getpwuid(atolx_range(ug, 0, INT_MAX));
    if (!pw) perror_exit("Invalid user '%s'", ug);
  }
  *uid = pw->pw_uid;
  *gid = pw->pw_gid;

  if (grname) {
    struct group *gr = getgrnam(grname);

    if (!gr) {
      gr = getgrgid(atolx_range(grname, 0, INT_MAX));
      if (!gr) perror_exit("Invalid group '%s'", grname);
    }
    *gid = gr->gr_gid;
  }
}
217
// Resolve host and the port given in toys.optargs[1]. Create a socket of the
// type selected by TT.udp (SOCK_DGRAM or SOCK_STREAM) and bind it to the
// first IPv4/IPv6 result.
// If haddr is non-NULL the bound address is copied there; it must be large
// enough for an IPv4/IPv6 socket address.
// Returns the socket fd; exits on any error.
static int create_bind_sock(char *host, struct sockaddr *haddr)
{
  struct addrinfo hints, *res = NULL, *rp;
  int sockfd, ret, set = 1;
  char *ptr;
  unsigned long port;

  // A fully numeric port is range-checked and normalized. Anything else
  // (e.g. a service name) is passed to getaddrinfo() unchanged.
  errno = 0;
  port = strtoul(toys.optargs[1], &ptr, 10);
  if (errno || port > 65535)
    error_exit("Invalid port, Range is [0-65535]");
  if (*ptr) ptr = toys.optargs[1];
  else {
    sprintf(toybuf, "%lu", port);
    ptr = toybuf;
  }

  memset(&hints, 0, sizeof hints);
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = TT.udp ? SOCK_DGRAM : SOCK_STREAM;
  // getaddrinfo() reports errors through its return value, not errno (except
  // for EAI_SYSTEM). perror_exit() would append a stale, unrelated errno.
  if ((ret = getaddrinfo(host, ptr, &hints, &res))) {
    if (ret == EAI_SYSTEM) perror_exit("getaddrinfo");
    error_exit("%s", gai_strerror(ret));
  }

  // Use the first IPv4 or IPv6 address returned.
  for (rp = res; rp; rp = rp->ai_next)
    if (rp->ai_family == AF_INET || rp->ai_family == AF_INET6) break;
  if (!rp) error_exit("Invalid IP %s", host);

  sockfd = xsocket(rp->ai_family, TT.udp ? SOCK_DGRAM : SOCK_STREAM, 0);
  // UDP mode binds the same port again for every client, so allow reuse.
  setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &set, sizeof(set));
  if (TT.udp) setsockopt(sockfd, IPPROTO_IP, IP_PKTINFO, &set, sizeof(set));
  if (bind(sockfd, rp->ai_addr, rp->ai_addrlen) < 0) perror_exit("Bind failed");
  if (haddr) memcpy(haddr, rp->ai_addr, rp->ai_addrlen);
  freeaddrinfo(res);

  return sockfd;
}
255
256static void handle_signal(int sig)
257{
258  if (toys.optflags & FLAG_v) xprintf("got signal %d, exit\n", sig);
259  raise(sig);
260  _exit(sig + 128); //should not reach here
261}
262
263void tcpsvd_main(void)
264{
265  uid_t uid = 0;
266  gid_t gid = 0;
267  pid_t pid;
268  char haddr[sizeof(struct sockaddr_in6)];
269  struct list *head, *newnode;
270  int hash, fd, newfd, j;
271  char *ptr = NULL, *addr, *server, buf[sizeof(struct sockaddr_in6)];
272  socklen_t len = sizeof(buf);
273
274  TT.udp = (*toys.which->name == 'u');
275  if (TT.udp) toys.optflags &= ~FLAG_C;
276  memset(buf, 0, len);
277  if (toys.optflags & FLAG_C) {
278    if ((ptr = strchr(TT.nmsg, ':'))) {
279      *ptr = '\0';
280      ptr++;
281    }
282    TT.maxc = atolx_range(TT.nmsg, 1, INT_MAX);
283  }
284
285  fd = create_bind_sock(toys.optargs[0], (struct sockaddr*)&haddr);
286  if(toys.optflags & FLAG_u) {
287    get_uidgid(&uid, &gid, TT.user);
288    setuid(uid);
289    setgid(gid);
290  }
291
292  if (!TT.udp && (listen(fd, TT.bn) < 0)) perror_exit("Listen failed");
293  server = sock_to_address((struct sockaddr*)&haddr, NI_NUMERICHOST|NI_NUMERICSERV);
294  if (toys.optflags & FLAG_v) {
295    if (toys.optflags & FLAG_u)
296      xprintf("%s: listening on %s, starting, uid %u, gid %u\n"
297          ,toys.which->name, server, uid, gid);
298    else
299      xprintf("%s: listening on %s, starting\n", toys.which->name, server);
300  }
301  for (j = 0; j < HASH_NR; j++) h[j].head = NULL;
302  sigatexit(handle_signal);
303  signal(SIGCHLD, handle_exit);
304
305  while (1) {
306    if (TT.count_all  < TT.cn) {
307      if (TT.udp) {
308        if(recvfrom(fd, NULL, 0, MSG_PEEK, (struct sockaddr *)buf, &len) < 0)
309          perror_exit("recvfrom");
310        newfd = fd;
311      } else {
312        newfd = accept(fd, (struct sockaddr *)buf, &len);
313        if (newfd < 0) perror_exit("Error on accept");
314      }
315    } else {
316      sigset_t ss;
317      sigemptyset(&ss);
318      sigsuspend(&ss);
319      continue;
320    }
321    TT.count_all++;
322    addr = sock_to_address((struct sockaddr*)buf, NI_NUMERICHOST);
323
324    hash = haship(addr);
325    if (toys.optflags & FLAG_C) {
326      for (head = h[hash].head; head; head = head->next)
327        if (!strcmp(head->d, addr)) break;
328
329      if (head && head->count >= TT.maxc) {
330        if (ptr) write(newfd, ptr, strlen(ptr)+1);
331        close(newfd);
332        TT.count_all--;
333        continue;
334      }
335    }
336
337    newnode = (struct list*)xzalloc(sizeof(struct list));
338    newnode->d = addr;
339    for (head = h[hash].head; head; head = head->next) {
340      if (!strcmp(addr, head->d)) {
341        head->count++;
342        free(newnode);
343        break;
344      }
345    }
346
347    if (!head) {
348      newnode->next = h[hash].head;
349      h[hash].head = newnode;
350      h[hash].head->count++;
351    }
352
353    if (!(pid = xfork())) {
354      char *serv = NULL, *clie = NULL;
355      char *client = sock_to_address((struct sockaddr*)buf, NI_NUMERICHOST | NI_NUMERICSERV);
356      if (toys.optflags & FLAG_h) { //lookup name
357        if (toys.optflags & FLAG_l) serv = xstrdup(TT.name);
358        else serv = sock_to_address((struct sockaddr*)&haddr, 0);
359        clie = sock_to_address((struct sockaddr*)buf, 0);
360      }
361
362      if (!(toys.optflags & FLAG_E)) {
363        setenv("PROTO", TT.udp ?"UDP" :"TCP", 1);
364        setenv("PROTOLOCALADDR", server, 1);
365        setenv("PROTOREMOTEADDR", client, 1);
366        if (toys.optflags & FLAG_h) {
367          setenv("PROTOLOCALHOST", serv, 1);
368          setenv("PROTOREMOTEHOST", clie, 1);
369        }
370        if (!TT.udp) {
371          char max_c[32];
372          sprintf(max_c, "%d", TT.maxc);
373          setenv("TCPCONCURRENCY", max_c, 1); //Not valid for udp
374        }
375      }
376      if (toys.optflags & FLAG_v) {
377        xprintf("%s: start %d %s-%s",toys.which->name, getpid(), server, client);
378        if (toys.optflags & FLAG_h) xprintf(" (%s-%s)", serv, clie);
379        xputc('\n');
380        if (TT.cn > 1)
381          xprintf("%s: status %d/%d\n",toys.which->name, TT.count_all, TT.cn);
382      }
383      free(client);
384      if (toys.optflags & FLAG_h) {
385        free(serv);
386        free(clie);
387      }
388      if (TT.udp && (connect(newfd, (struct sockaddr *)buf, sizeof(buf)) < 0))
389          perror_exit("connect");
390
391      close(0);
392      close(1);
393      dup2(newfd, 0);
394      dup2(newfd, 1);
395      xexec(toys.optargs+2); //skip IP PORT
396    } else {
397      insert(&pids, pid, addr);
398      xclose(newfd); //close and reopen for next client.
399      if (TT.udp) fd = create_bind_sock(toys.optargs[0],
400          (struct sockaddr*)&haddr);
401    }
402  } //while(1)
403}
404