1/*
2 * libjingle
3 * Copyright 2004--2005, Google Inc.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 *
8 *  1. Redistributions of source code must retain the above copyright notice,
9 *     this list of conditions and the following disclaimer.
10 *  2. Redistributions in binary form must reproduce the above copyright notice,
11 *     this list of conditions and the following disclaimer in the documentation
12 *     and/or other materials provided with the distribution.
13 *  3. The name of the author may not be used to endorse or promote products
14 *     derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
17 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
18 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
19 * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
20 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
22 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
23 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
24 * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
25 * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28#include "talk/base/firewallsocketserver.h"
29
30#include <cassert>
31#include <algorithm>
32
33#include "talk/base/asyncsocket.h"
34#include "talk/base/logging.h"
35
36namespace talk_base {
37
38class FirewallSocket : public AsyncSocketAdapter {
39 public:
40  FirewallSocket(FirewallSocketServer* server, AsyncSocket* socket, int type)
41    : AsyncSocketAdapter(socket), server_(server), type_(type) {
42  }
43
44  virtual int Connect(const SocketAddress& addr) {
45    if (type_ == SOCK_STREAM) {
46      if (!server_->Check(FP_TCP, GetLocalAddress(), addr)) {
47        LOG(LS_VERBOSE) << "FirewallSocket outbound TCP connection from "
48                        << GetLocalAddress().ToString() << " to "
49                        << addr.ToString() << " denied";
50        // TODO: Handle this asynchronously.
51        SetError(EHOSTUNREACH);
52        return SOCKET_ERROR;
53      }
54    }
55    return AsyncSocketAdapter::Connect(addr);
56  }
57  virtual int Send(const void* pv, size_t cb) {
58    return SendTo(pv, cb, GetRemoteAddress());
59  }
60  virtual int SendTo(const void* pv, size_t cb, const SocketAddress& addr) {
61    if (type_ == SOCK_DGRAM) {
62      if (!server_->Check(FP_UDP, GetLocalAddress(), addr)) {
63        LOG(LS_VERBOSE) << "FirewallSocket outbound UDP packet from "
64                        << GetLocalAddress().ToString() << " to "
65                        << addr.ToString() << " dropped";
66        return static_cast<int>(cb);
67      }
68    }
69    return AsyncSocketAdapter::SendTo(pv, cb, addr);
70  }
71  virtual int Recv(void* pv, size_t cb) {
72    SocketAddress addr;
73    return RecvFrom(pv, cb, &addr);
74  }
75  virtual int RecvFrom(void* pv, size_t cb, SocketAddress* paddr) {
76    if (type_ == SOCK_DGRAM) {
77      while (true) {
78        int res = AsyncSocketAdapter::RecvFrom(pv, cb, paddr);
79        if (res <= 0)
80          return res;
81        if (server_->Check(FP_UDP, *paddr, GetLocalAddress()))
82          return res;
83        LOG(LS_VERBOSE) << "FirewallSocket inbound UDP packet from "
84                        << paddr->ToString() << " to "
85                        << GetLocalAddress().ToString() << " dropped";
86      }
87    }
88    return AsyncSocketAdapter::RecvFrom(pv, cb, paddr);
89  }
90
91  virtual int Listen(int backlog) {
92    if (!server_->tcp_listen_enabled()) {
93      LOG(LS_VERBOSE) << "FirewallSocket listen attempt denied";
94      return -1;
95    }
96
97    return AsyncSocketAdapter::Listen(backlog);
98  }
99  virtual AsyncSocket* Accept(SocketAddress* paddr) {
100    SocketAddress addr;
101    while (AsyncSocket* sock = AsyncSocketAdapter::Accept(&addr)) {
102      if (server_->Check(FP_TCP, addr, GetLocalAddress())) {
103        if (paddr)
104          *paddr = addr;
105        return sock;
106      }
107      sock->Close();
108      delete sock;
109      LOG(LS_VERBOSE) << "FirewallSocket inbound TCP connection from "
110                      << addr.ToString() << " to "
111                      << GetLocalAddress().ToString() << " denied";
112    }
113    return 0;
114  }
115
116 private:
117  FirewallSocketServer* server_;
118  int type_;
119};
120
121FirewallSocketServer::FirewallSocketServer(SocketServer* server,
122                                           FirewallManager* manager,
123                                           bool should_delete_server)
124    : server_(server), manager_(manager),
125      should_delete_server_(should_delete_server),
126      udp_sockets_enabled_(true), tcp_sockets_enabled_(true),
127      tcp_listen_enabled_(true) {
128  if (manager_)
129    manager_->AddServer(this);
130}
131
132FirewallSocketServer::~FirewallSocketServer() {
133  if (manager_)
134    manager_->RemoveServer(this);
135
136  if (server_ && should_delete_server_) {
137    delete server_;
138    server_ = NULL;
139  }
140}
141
142void FirewallSocketServer::AddRule(bool allow, FirewallProtocol p,
143                                   FirewallDirection d,
144                                   const SocketAddress& addr) {
145  SocketAddress src, dst;
146  if (d == FD_IN) {
147    dst = addr;
148  } else {
149    src = addr;
150  }
151  AddRule(allow, p, src, dst);
152}
153
154
155void FirewallSocketServer::AddRule(bool allow, FirewallProtocol p,
156                                   const SocketAddress& src,
157                                   const SocketAddress& dst) {
158  Rule r;
159  r.allow = allow;
160  r.p = p;
161  r.src = src;
162  r.dst = dst;
163  CritScope scope(&crit_);
164  rules_.push_back(r);
165}
166
167void FirewallSocketServer::ClearRules() {
168  CritScope scope(&crit_);
169  rules_.clear();
170}
171
172bool FirewallSocketServer::Check(FirewallProtocol p,
173                                 const SocketAddress& src,
174                                 const SocketAddress& dst) {
175  CritScope scope(&crit_);
176  for (size_t i = 0; i < rules_.size(); ++i) {
177    const Rule& r = rules_[i];
178    if ((r.p != p) && (r.p != FP_ANY))
179      continue;
180    if ((r.src.ip() != src.ip()) && !r.src.IsAny())
181      continue;
182    if ((r.src.port() != src.port()) && (r.src.port() != 0))
183      continue;
184    if ((r.dst.ip() != dst.ip()) && !r.dst.IsAny())
185      continue;
186    if ((r.dst.port() != dst.port()) && (r.dst.port() != 0))
187      continue;
188    return r.allow;
189  }
190  return true;
191}
192
193Socket* FirewallSocketServer::CreateSocket(int type) {
194  return WrapSocket(server_->CreateAsyncSocket(type), type);
195}
196
197AsyncSocket* FirewallSocketServer::CreateAsyncSocket(int type) {
198  return WrapSocket(server_->CreateAsyncSocket(type), type);
199}
200
201AsyncSocket* FirewallSocketServer::WrapSocket(AsyncSocket* sock, int type) {
202  if (!sock ||
203      (type == SOCK_STREAM && !tcp_sockets_enabled_) ||
204      (type == SOCK_DGRAM && !udp_sockets_enabled_)) {
205    LOG(LS_VERBOSE) << "FirewallSocketServer socket creation denied";
206    return NULL;
207  }
208  return new FirewallSocket(this, sock, type);
209}
210
211FirewallManager::FirewallManager() {
212}
213
214FirewallManager::~FirewallManager() {
215  assert(servers_.empty());
216}
217
218void FirewallManager::AddServer(FirewallSocketServer* server) {
219  CritScope scope(&crit_);
220  servers_.push_back(server);
221}
222
223void FirewallManager::RemoveServer(FirewallSocketServer* server) {
224  CritScope scope(&crit_);
225  servers_.erase(std::remove(servers_.begin(), servers_.end(), server),
226                 servers_.end());
227}
228
229void FirewallManager::AddRule(bool allow, FirewallProtocol p,
230                              FirewallDirection d, const SocketAddress& addr) {
231  CritScope scope(&crit_);
232  for (std::vector<FirewallSocketServer*>::const_iterator it =
233      servers_.begin(); it != servers_.end(); ++it) {
234    (*it)->AddRule(allow, p, d, addr);
235  }
236}
237
238void FirewallManager::ClearRules() {
239  CritScope scope(&crit_);
240  for (std::vector<FirewallSocketServer*>::const_iterator it =
241      servers_.begin(); it != servers_.end(); ++it) {
242    (*it)->ClearRules();
243  }
244}
245
246}  // namespace talk_base
247