1/*
2 *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
3 *
4 *  Use of this source code is governed by a BSD-style license
5 *  that can be found in the LICENSE file in the root of the source
6 *  tree. An additional intellectual property rights grant can be found
7 *  in the file PATENTS.  All contributing project authors may
8 *  be found in the AUTHORS file in the root of the source tree.
9 */
10
11#include "webrtc/base/firewallsocketserver.h"
12
13#include <assert.h>
14
15#include <algorithm>
16
17#include "webrtc/base/asyncsocket.h"
18#include "webrtc/base/logging.h"
19
20namespace rtc {
21
22class FirewallSocket : public AsyncSocketAdapter {
23 public:
24  FirewallSocket(FirewallSocketServer* server, AsyncSocket* socket, int type)
25    : AsyncSocketAdapter(socket), server_(server), type_(type) {
26  }
27
28  virtual int Connect(const SocketAddress& addr) {
29    if (type_ == SOCK_STREAM) {
30      if (!server_->Check(FP_TCP, GetLocalAddress(), addr)) {
31        LOG(LS_VERBOSE) << "FirewallSocket outbound TCP connection from "
32                        << GetLocalAddress().ToSensitiveString() << " to "
33                        << addr.ToSensitiveString() << " denied";
34        // TODO: Handle this asynchronously.
35        SetError(EHOSTUNREACH);
36        return SOCKET_ERROR;
37      }
38    }
39    return AsyncSocketAdapter::Connect(addr);
40  }
41  virtual int Send(const void* pv, size_t cb) {
42    return SendTo(pv, cb, GetRemoteAddress());
43  }
44  virtual int SendTo(const void* pv, size_t cb, const SocketAddress& addr) {
45    if (type_ == SOCK_DGRAM) {
46      if (!server_->Check(FP_UDP, GetLocalAddress(), addr)) {
47        LOG(LS_VERBOSE) << "FirewallSocket outbound UDP packet from "
48                        << GetLocalAddress().ToSensitiveString() << " to "
49                        << addr.ToSensitiveString() << " dropped";
50        return static_cast<int>(cb);
51      }
52    }
53    return AsyncSocketAdapter::SendTo(pv, cb, addr);
54  }
55  virtual int Recv(void* pv, size_t cb) {
56    SocketAddress addr;
57    return RecvFrom(pv, cb, &addr);
58  }
59  virtual int RecvFrom(void* pv, size_t cb, SocketAddress* paddr) {
60    if (type_ == SOCK_DGRAM) {
61      while (true) {
62        int res = AsyncSocketAdapter::RecvFrom(pv, cb, paddr);
63        if (res <= 0)
64          return res;
65        if (server_->Check(FP_UDP, *paddr, GetLocalAddress()))
66          return res;
67        LOG(LS_VERBOSE) << "FirewallSocket inbound UDP packet from "
68                        << paddr->ToSensitiveString() << " to "
69                        << GetLocalAddress().ToSensitiveString() << " dropped";
70      }
71    }
72    return AsyncSocketAdapter::RecvFrom(pv, cb, paddr);
73  }
74
75  virtual int Listen(int backlog) {
76    if (!server_->tcp_listen_enabled()) {
77      LOG(LS_VERBOSE) << "FirewallSocket listen attempt denied";
78      return -1;
79    }
80
81    return AsyncSocketAdapter::Listen(backlog);
82  }
83  virtual AsyncSocket* Accept(SocketAddress* paddr) {
84    SocketAddress addr;
85    while (AsyncSocket* sock = AsyncSocketAdapter::Accept(&addr)) {
86      if (server_->Check(FP_TCP, addr, GetLocalAddress())) {
87        if (paddr)
88          *paddr = addr;
89        return sock;
90      }
91      sock->Close();
92      delete sock;
93      LOG(LS_VERBOSE) << "FirewallSocket inbound TCP connection from "
94                      << addr.ToSensitiveString() << " to "
95                      << GetLocalAddress().ToSensitiveString() << " denied";
96    }
97    return 0;
98  }
99
100 private:
101  FirewallSocketServer* server_;
102  int type_;
103};
104
105FirewallSocketServer::FirewallSocketServer(SocketServer* server,
106                                           FirewallManager* manager,
107                                           bool should_delete_server)
108    : server_(server), manager_(manager),
109      should_delete_server_(should_delete_server),
110      udp_sockets_enabled_(true), tcp_sockets_enabled_(true),
111      tcp_listen_enabled_(true) {
112  if (manager_)
113    manager_->AddServer(this);
114}
115
116FirewallSocketServer::~FirewallSocketServer() {
117  if (manager_)
118    manager_->RemoveServer(this);
119
120  if (server_ && should_delete_server_) {
121    delete server_;
122    server_ = NULL;
123  }
124}
125
126void FirewallSocketServer::AddRule(bool allow, FirewallProtocol p,
127                                   FirewallDirection d,
128                                   const SocketAddress& addr) {
129  SocketAddress src, dst;
130  if (d == FD_IN) {
131    dst = addr;
132  } else {
133    src = addr;
134  }
135  AddRule(allow, p, src, dst);
136}
137
138
139void FirewallSocketServer::AddRule(bool allow, FirewallProtocol p,
140                                   const SocketAddress& src,
141                                   const SocketAddress& dst) {
142  Rule r;
143  r.allow = allow;
144  r.p = p;
145  r.src = src;
146  r.dst = dst;
147  CritScope scope(&crit_);
148  rules_.push_back(r);
149}
150
151void FirewallSocketServer::ClearRules() {
152  CritScope scope(&crit_);
153  rules_.clear();
154}
155
156bool FirewallSocketServer::Check(FirewallProtocol p,
157                                 const SocketAddress& src,
158                                 const SocketAddress& dst) {
159  CritScope scope(&crit_);
160  for (size_t i = 0; i < rules_.size(); ++i) {
161    const Rule& r = rules_[i];
162    if ((r.p != p) && (r.p != FP_ANY))
163      continue;
164    if ((r.src.ipaddr() != src.ipaddr()) && !r.src.IsNil())
165      continue;
166    if ((r.src.port() != src.port()) && (r.src.port() != 0))
167      continue;
168    if ((r.dst.ipaddr() != dst.ipaddr()) && !r.dst.IsNil())
169      continue;
170    if ((r.dst.port() != dst.port()) && (r.dst.port() != 0))
171      continue;
172    return r.allow;
173  }
174  return true;
175}
176
177Socket* FirewallSocketServer::CreateSocket(int type) {
178  return CreateSocket(AF_INET, type);
179}
180
181Socket* FirewallSocketServer::CreateSocket(int family, int type) {
182  return WrapSocket(server_->CreateAsyncSocket(family, type), type);
183}
184
185AsyncSocket* FirewallSocketServer::CreateAsyncSocket(int type) {
186  return CreateAsyncSocket(AF_INET, type);
187}
188
189AsyncSocket* FirewallSocketServer::CreateAsyncSocket(int family, int type) {
190  return WrapSocket(server_->CreateAsyncSocket(family, type), type);
191}
192
193AsyncSocket* FirewallSocketServer::WrapSocket(AsyncSocket* sock, int type) {
194  if (!sock ||
195      (type == SOCK_STREAM && !tcp_sockets_enabled_) ||
196      (type == SOCK_DGRAM && !udp_sockets_enabled_)) {
197    LOG(LS_VERBOSE) << "FirewallSocketServer socket creation denied";
198    delete sock;
199    return NULL;
200  }
201  return new FirewallSocket(this, sock, type);
202}
203
204FirewallManager::FirewallManager() {
205}
206
207FirewallManager::~FirewallManager() {
208  assert(servers_.empty());
209}
210
211void FirewallManager::AddServer(FirewallSocketServer* server) {
212  CritScope scope(&crit_);
213  servers_.push_back(server);
214}
215
216void FirewallManager::RemoveServer(FirewallSocketServer* server) {
217  CritScope scope(&crit_);
218  servers_.erase(std::remove(servers_.begin(), servers_.end(), server),
219                 servers_.end());
220}
221
222void FirewallManager::AddRule(bool allow, FirewallProtocol p,
223                              FirewallDirection d, const SocketAddress& addr) {
224  CritScope scope(&crit_);
225  for (std::vector<FirewallSocketServer*>::const_iterator it =
226      servers_.begin(); it != servers_.end(); ++it) {
227    (*it)->AddRule(allow, p, d, addr);
228  }
229}
230
231void FirewallManager::ClearRules() {
232  CritScope scope(&crit_);
233  for (std::vector<FirewallSocketServer*>::const_iterator it =
234      servers_.begin(); it != servers_.end(); ++it) {
235    (*it)->ClearRules();
236  }
237}
238
239}  // namespace rtc
240