1/*
2 *  Copyright 2015 The WebRTC Project Authors. All rights reserved.
3 *
4 *  Use of this source code is governed by a BSD-style license
5 *  that can be found in the LICENSE file in the root of the source
6 *  tree. An additional intellectual property rights grant can be found
7 *  in the file PATENTS.  All contributing project authors may
8 *  be found in the AUTHORS file in the root of the source tree.
9 */
10
11#include <map>
12#include <set>
13#include <string>
14
15#include "webrtc/base/asyncpacketsocket.h"
16#include "webrtc/base/asyncresolverinterface.h"
17#include "webrtc/base/bind.h"
18#include "webrtc/base/checks.h"
19#include "webrtc/base/helpers.h"
20#include "webrtc/base/logging.h"
21#include "webrtc/base/timeutils.h"
22#include "webrtc/base/thread.h"
23#include "webrtc/p2p/base/packetsocketfactory.h"
24#include "webrtc/p2p/base/stun.h"
25#include "webrtc/p2p/stunprober/stunprober.h"
26
27namespace stunprober {
28
29namespace {
30
// Delay, in milliseconds, between scheduling passes when the requested send
// interval is at least this long (see get_wake_up_interval_ms()). Half of this
// value is also the slack should_send_next_request() tolerates in that case.
const int THREAD_WAKE_UP_INTERVAL_MS = 5;
32
// Adds one to the counter stored under |ip| in |counter_per_ip|. A key that is
// not present yet starts from zero, so its first increment leaves it at 1.
template <typename T>
void IncrementCounterByAddress(std::map<T, int>* counter_per_ip, const T& ip) {
  // operator[] value-initializes a missing entry to 0 before it is bumped.
  int& hits = (*counter_per_ip)[ip];
  ++hits;
}
37
38}  // namespace
39
40// A requester tracks the requests and responses from a single socket to many
41// STUN servers
42class StunProber::Requester : public sigslot::has_slots<> {
43 public:
44  // Each Request maps to a request and response.
45  struct Request {
46    // Actual time the STUN bind request was sent.
47    int64_t sent_time_ms = 0;
48    // Time the response was received.
49    int64_t received_time_ms = 0;
50
51    // Server reflexive address from STUN response for this given request.
52    rtc::SocketAddress srflx_addr;
53
54    rtc::IPAddress server_addr;
55
56    int64_t rtt() { return received_time_ms - sent_time_ms; }
57    void ProcessResponse(const char* buf, size_t buf_len);
58  };
59
60  // StunProber provides |server_ips| for Requester to probe. For shared
61  // socket mode, it'll be all the resolved IP addresses. For non-shared mode,
62  // it'll just be a single address.
63  Requester(StunProber* prober,
64            rtc::AsyncPacketSocket* socket,
65            const std::vector<rtc::SocketAddress>& server_ips);
66  virtual ~Requester();
67
68  // There is no callback for SendStunRequest as the underneath socket send is
69  // expected to be completed immediately. Otherwise, it'll skip this request
70  // and move to the next one.
71  void SendStunRequest();
72
73  void OnStunResponseReceived(rtc::AsyncPacketSocket* socket,
74                              const char* buf,
75                              size_t size,
76                              const rtc::SocketAddress& addr,
77                              const rtc::PacketTime& time);
78
79  const std::vector<Request*>& requests() { return requests_; }
80
81  // Whether this Requester has completed all requests.
82  bool Done() {
83    return static_cast<size_t>(num_request_sent_) == server_ips_.size();
84  }
85
86 private:
87  Request* GetRequestByAddress(const rtc::IPAddress& ip);
88
89  StunProber* prober_;
90
91  // The socket for this session.
92  rtc::scoped_ptr<rtc::AsyncPacketSocket> socket_;
93
94  // Temporary SocketAddress and buffer for RecvFrom.
95  rtc::SocketAddress addr_;
96  rtc::scoped_ptr<rtc::ByteBuffer> response_packet_;
97
98  std::vector<Request*> requests_;
99  std::vector<rtc::SocketAddress> server_ips_;
100  int16_t num_request_sent_ = 0;
101  int16_t num_response_received_ = 0;
102
103  rtc::ThreadChecker& thread_checker_;
104
105  RTC_DISALLOW_COPY_AND_ASSIGN(Requester);
106};
107
108StunProber::Requester::Requester(
109    StunProber* prober,
110    rtc::AsyncPacketSocket* socket,
111    const std::vector<rtc::SocketAddress>& server_ips)
112    : prober_(prober),
113      socket_(socket),
114      response_packet_(new rtc::ByteBuffer(nullptr, kMaxUdpBufferSize)),
115      server_ips_(server_ips),
116      thread_checker_(prober->thread_checker_) {
117  socket_->SignalReadPacket.connect(
118      this, &StunProber::Requester::OnStunResponseReceived);
119}
120
121StunProber::Requester::~Requester() {
122  if (socket_) {
123    socket_->Close();
124  }
125  for (auto req : requests_) {
126    if (req) {
127      delete req;
128    }
129  }
130}
131
132void StunProber::Requester::SendStunRequest() {
133  RTC_DCHECK(thread_checker_.CalledOnValidThread());
134  requests_.push_back(new Request());
135  Request& request = *(requests_.back());
136  cricket::StunMessage message;
137
138  // Random transaction ID, STUN_BINDING_REQUEST
139  message.SetTransactionID(
140      rtc::CreateRandomString(cricket::kStunTransactionIdLength));
141  message.SetType(cricket::STUN_BINDING_REQUEST);
142
143  rtc::scoped_ptr<rtc::ByteBuffer> request_packet(
144      new rtc::ByteBuffer(nullptr, kMaxUdpBufferSize));
145  if (!message.Write(request_packet.get())) {
146    prober_->ReportOnFinished(WRITE_FAILED);
147    return;
148  }
149
150  auto addr = server_ips_[num_request_sent_];
151  request.server_addr = addr.ipaddr();
152
153  // The write must succeed immediately. Otherwise, the calculating of the STUN
154  // request timing could become too complicated. Callback is ignored by passing
155  // empty AsyncCallback.
156  rtc::PacketOptions options;
157  int rv = socket_->SendTo(const_cast<char*>(request_packet->Data()),
158                           request_packet->Length(), addr, options);
159  if (rv < 0) {
160    prober_->ReportOnFinished(WRITE_FAILED);
161    return;
162  }
163
164  request.sent_time_ms = rtc::Time();
165
166  num_request_sent_++;
167  RTC_DCHECK(static_cast<size_t>(num_request_sent_) <= server_ips_.size());
168}
169
170void StunProber::Requester::Request::ProcessResponse(const char* buf,
171                                                     size_t buf_len) {
172  int64_t now = rtc::Time();
173  rtc::ByteBuffer message(buf, buf_len);
174  cricket::StunMessage stun_response;
175  if (!stun_response.Read(&message)) {
176    // Invalid or incomplete STUN packet.
177    received_time_ms = 0;
178    return;
179  }
180
181  // Get external address of the socket.
182  const cricket::StunAddressAttribute* addr_attr =
183      stun_response.GetAddress(cricket::STUN_ATTR_MAPPED_ADDRESS);
184  if (addr_attr == nullptr) {
185    // Addresses not available to detect whether or not behind a NAT.
186    return;
187  }
188
189  if (addr_attr->family() != cricket::STUN_ADDRESS_IPV4 &&
190      addr_attr->family() != cricket::STUN_ADDRESS_IPV6) {
191    return;
192  }
193
194  received_time_ms = now;
195
196  srflx_addr = addr_attr->GetAddress();
197}
198
199void StunProber::Requester::OnStunResponseReceived(
200    rtc::AsyncPacketSocket* socket,
201    const char* buf,
202    size_t size,
203    const rtc::SocketAddress& addr,
204    const rtc::PacketTime& time) {
205  RTC_DCHECK(thread_checker_.CalledOnValidThread());
206  RTC_DCHECK(socket_);
207  Request* request = GetRequestByAddress(addr.ipaddr());
208  if (!request) {
209    // Something is wrong, finish the test.
210    prober_->ReportOnFinished(GENERIC_FAILURE);
211    return;
212  }
213
214  num_response_received_++;
215  request->ProcessResponse(buf, size);
216}
217
218StunProber::Requester::Request* StunProber::Requester::GetRequestByAddress(
219    const rtc::IPAddress& ipaddr) {
220  RTC_DCHECK(thread_checker_.CalledOnValidThread());
221  for (auto request : requests_) {
222    if (request->server_addr == ipaddr) {
223      return request;
224    }
225  }
226
227  return nullptr;
228}
229
// Stores the collaborators supplied by the caller; nothing is resolved and no
// socket is created until Prepare() runs.
//   socket_factory: used later to create async resolvers and UDP sockets.
//   thread: the thread that delayed/asynchronous tasks are posted to.
//   networks: local networks whose best IP is compared against the observed
//             server reflexive address in GetStats().
StunProber::StunProber(rtc::PacketSocketFactory* socket_factory,
                       rtc::Thread* thread,
                       const rtc::NetworkManager::NetworkList& networks)
    : interval_ms_(0),
      socket_factory_(socket_factory),
      thread_(thread),
      networks_(networks) {
}
238
239StunProber::~StunProber() {
240  for (auto req : requesters_) {
241    if (req) {
242      delete req;
243    }
244  }
245  for (auto s : sockets_) {
246    if (s) {
247      delete s;
248    }
249  }
250}
251
252bool StunProber::Start(const std::vector<rtc::SocketAddress>& servers,
253                       bool shared_socket_mode,
254                       int interval_ms,
255                       int num_request_per_ip,
256                       int timeout_ms,
257                       const AsyncCallback callback) {
258  observer_adapter_.set_callback(callback);
259  return Prepare(servers, shared_socket_mode, interval_ms, num_request_per_ip,
260                 timeout_ms, &observer_adapter_);
261}
262
263bool StunProber::Prepare(const std::vector<rtc::SocketAddress>& servers,
264                         bool shared_socket_mode,
265                         int interval_ms,
266                         int num_request_per_ip,
267                         int timeout_ms,
268                         StunProber::Observer* observer) {
269  RTC_DCHECK(thread_checker_.CalledOnValidThread());
270  interval_ms_ = interval_ms;
271  shared_socket_mode_ = shared_socket_mode;
272
273  requests_per_ip_ = num_request_per_ip;
274  if (requests_per_ip_ == 0 || servers.size() == 0) {
275    return false;
276  }
277
278  timeout_ms_ = timeout_ms;
279  servers_ = servers;
280  observer_ = observer;
281  return ResolveServerName(servers_.back());
282}
283
284bool StunProber::Start(StunProber::Observer* observer) {
285  observer_ = observer;
286  if (total_ready_sockets_ != total_socket_required()) {
287    return false;
288  }
289  MaybeScheduleStunRequests();
290  return true;
291}
292
293bool StunProber::ResolveServerName(const rtc::SocketAddress& addr) {
294  rtc::AsyncResolverInterface* resolver =
295      socket_factory_->CreateAsyncResolver();
296  if (!resolver) {
297    return false;
298  }
299  resolver->SignalDone.connect(this, &StunProber::OnServerResolved);
300  resolver->Start(addr);
301  return true;
302}
303
304void StunProber::OnSocketReady(rtc::AsyncPacketSocket* socket,
305                               const rtc::SocketAddress& addr) {
306  total_ready_sockets_++;
307  if (total_ready_sockets_ == total_socket_required()) {
308    ReportOnPrepared(SUCCESS);
309  }
310}
311
312void StunProber::OnServerResolved(rtc::AsyncResolverInterface* resolver) {
313  RTC_DCHECK(thread_checker_.CalledOnValidThread());
314
315  if (resolver->GetError() == 0) {
316    rtc::SocketAddress addr(resolver->address().ipaddr(),
317                            resolver->address().port());
318    all_servers_addrs_.push_back(addr);
319  }
320
321  // Deletion of AsyncResolverInterface can't be done in OnResolveResult which
322  // handles SignalDone.
323  invoker_.AsyncInvoke<void>(
324      thread_,
325      rtc::Bind(&rtc::AsyncResolverInterface::Destroy, resolver, false));
326  servers_.pop_back();
327
328  if (servers_.size()) {
329    if (!ResolveServerName(servers_.back())) {
330      ReportOnPrepared(RESOLVE_FAILED);
331    }
332    return;
333  }
334
335  if (all_servers_addrs_.size() == 0) {
336    ReportOnPrepared(RESOLVE_FAILED);
337    return;
338  }
339
340  // Dedupe.
341  std::set<rtc::SocketAddress> addrs(all_servers_addrs_.begin(),
342                                     all_servers_addrs_.end());
343  all_servers_addrs_.assign(addrs.begin(), addrs.end());
344
345  // Prepare all the sockets beforehand. All of them will bind to "any" address.
346  while (sockets_.size() < total_socket_required()) {
347    rtc::scoped_ptr<rtc::AsyncPacketSocket> socket(
348        socket_factory_->CreateUdpSocket(rtc::SocketAddress(INADDR_ANY, 0), 0,
349                                         0));
350    if (!socket) {
351      ReportOnPrepared(GENERIC_FAILURE);
352      return;
353    }
354    // Chrome and WebRTC behave differently in terms of the state of a socket
355    // once returned from PacketSocketFactory::CreateUdpSocket.
356    if (socket->GetState() == rtc::AsyncPacketSocket::STATE_BINDING) {
357      socket->SignalAddressReady.connect(this, &StunProber::OnSocketReady);
358    } else {
359      OnSocketReady(socket.get(), rtc::SocketAddress(INADDR_ANY, 0));
360    }
361    sockets_.push_back(socket.release());
362  }
363}
364
365StunProber::Requester* StunProber::CreateRequester() {
366  RTC_DCHECK(thread_checker_.CalledOnValidThread());
367  if (!sockets_.size()) {
368    return nullptr;
369  }
370  StunProber::Requester* requester;
371  if (shared_socket_mode_) {
372    requester = new Requester(this, sockets_.back(), all_servers_addrs_);
373  } else {
374    std::vector<rtc::SocketAddress> server_ip;
375    server_ip.push_back(
376        all_servers_addrs_[(num_request_sent_ % all_servers_addrs_.size())]);
377    requester = new Requester(this, sockets_.back(), server_ip);
378  }
379
380  sockets_.pop_back();
381  return requester;
382}
383
384bool StunProber::SendNextRequest() {
385  if (!current_requester_ || current_requester_->Done()) {
386    current_requester_ = CreateRequester();
387    requesters_.push_back(current_requester_);
388  }
389  if (!current_requester_) {
390    return false;
391  }
392  current_requester_->SendStunRequest();
393  num_request_sent_++;
394  return true;
395}
396
397bool StunProber::should_send_next_request(uint32_t now) {
398  if (interval_ms_ < THREAD_WAKE_UP_INTERVAL_MS) {
399    return now >= next_request_time_ms_;
400  } else {
401    return (now + (THREAD_WAKE_UP_INTERVAL_MS / 2)) >= next_request_time_ms_;
402  }
403}
404
405int StunProber::get_wake_up_interval_ms() {
406  if (interval_ms_ < THREAD_WAKE_UP_INTERVAL_MS) {
407    return 1;
408  } else {
409    return THREAD_WAKE_UP_INTERVAL_MS;
410  }
411}
412
413void StunProber::MaybeScheduleStunRequests() {
414  RTC_DCHECK(thread_checker_.CalledOnValidThread());
415  uint32_t now = rtc::Time();
416
417  if (Done()) {
418    invoker_.AsyncInvokeDelayed<void>(
419        thread_, rtc::Bind(&StunProber::ReportOnFinished, this, SUCCESS),
420        timeout_ms_);
421    return;
422  }
423  if (should_send_next_request(now)) {
424    if (!SendNextRequest()) {
425      ReportOnFinished(GENERIC_FAILURE);
426      return;
427    }
428    next_request_time_ms_ = now + interval_ms_;
429  }
430  invoker_.AsyncInvokeDelayed<void>(
431      thread_, rtc::Bind(&StunProber::MaybeScheduleStunRequests, this),
432      get_wake_up_interval_ms());
433}
434
435bool StunProber::GetStats(StunProber::Stats* prob_stats) const {
436  // No need to be on the same thread.
437  if (!prob_stats) {
438    return false;
439  }
440
441  StunProber::Stats stats;
442
443  int rtt_sum = 0;
444  int64_t first_sent_time = 0;
445  int64_t last_sent_time = 0;
446  NatType nat_type = NATTYPE_INVALID;
447
448  // Track of how many srflx IP that we have seen.
449  std::set<rtc::IPAddress> srflx_ips;
450
451  // If we're not receiving any response on a given IP, all requests sent to
452  // that IP should be ignored as this could just be an DNS error.
453  std::map<rtc::IPAddress, int> num_response_per_server;
454  std::map<rtc::IPAddress, int> num_request_per_server;
455
456  for (auto* requester : requesters_) {
457    std::map<rtc::SocketAddress, int> num_response_per_srflx_addr;
458    for (auto request : requester->requests()) {
459      if (request->sent_time_ms <= 0) {
460        continue;
461      }
462
463      ++stats.raw_num_request_sent;
464      IncrementCounterByAddress(&num_request_per_server, request->server_addr);
465
466      if (!first_sent_time) {
467        first_sent_time = request->sent_time_ms;
468      }
469      last_sent_time = request->sent_time_ms;
470
471      if (request->received_time_ms < request->sent_time_ms) {
472        continue;
473      }
474
475      IncrementCounterByAddress(&num_response_per_server, request->server_addr);
476      IncrementCounterByAddress(&num_response_per_srflx_addr,
477                                request->srflx_addr);
478      rtt_sum += request->rtt();
479      stats.srflx_addrs.insert(request->srflx_addr.ToString());
480      srflx_ips.insert(request->srflx_addr.ipaddr());
481    }
482
483    // If we're using shared mode and seeing >1 srflx addresses for a single
484    // requester, it's symmetric NAT.
485    if (shared_socket_mode_ && num_response_per_srflx_addr.size() > 1) {
486      nat_type = NATTYPE_SYMMETRIC;
487    }
488  }
489
490  // We're probably not behind a regular NAT. We have more than 1 distinct
491  // server reflexive IPs.
492  if (srflx_ips.size() > 1) {
493    return false;
494  }
495
496  int num_sent = 0;
497  int num_received = 0;
498  int num_server_ip_with_response = 0;
499
500  for (const auto& kv : num_response_per_server) {
501    RTC_DCHECK_GT(kv.second, 0);
502    num_server_ip_with_response++;
503    num_received += kv.second;
504    num_sent += num_request_per_server[kv.first];
505  }
506
507  // Shared mode is only true if we use the shared socket and there are more
508  // than 1 responding servers.
509  stats.shared_socket_mode =
510      shared_socket_mode_ && (num_server_ip_with_response > 1);
511
512  if (stats.shared_socket_mode && nat_type == NATTYPE_INVALID) {
513    nat_type = NATTYPE_NON_SYMMETRIC;
514  }
515
516  // If we could find a local IP matching srflx, we're not behind a NAT.
517  rtc::SocketAddress srflx_addr;
518  if (stats.srflx_addrs.size() &&
519      !srflx_addr.FromString(*(stats.srflx_addrs.begin()))) {
520    return false;
521  }
522  for (const auto& net : networks_) {
523    if (srflx_addr.ipaddr() == net->GetBestIP()) {
524      nat_type = stunprober::NATTYPE_NONE;
525      stats.host_ip = net->GetBestIP().ToString();
526      break;
527    }
528  }
529
530  // Finally, we know we're behind a NAT but can't determine which type it is.
531  if (nat_type == NATTYPE_INVALID) {
532    nat_type = NATTYPE_UNKNOWN;
533  }
534
535  stats.nat_type = nat_type;
536  stats.num_request_sent = num_sent;
537  stats.num_response_received = num_received;
538  stats.target_request_interval_ns = interval_ms_ * 1000;
539
540  if (num_sent) {
541    stats.success_percent = static_cast<int>(100 * num_received / num_sent);
542  }
543
544  if (stats.raw_num_request_sent > 1) {
545    stats.actual_request_interval_ns =
546        (1000 * (last_sent_time - first_sent_time)) /
547        (stats.raw_num_request_sent - 1);
548  }
549
550  if (num_received) {
551    stats.average_rtt_ms = static_cast<int>((rtt_sum / num_received));
552  }
553
554  *prob_stats = stats;
555  return true;
556}
557
558void StunProber::ReportOnPrepared(StunProber::Status status) {
559  if (observer_) {
560    observer_->OnPrepared(this, status);
561  }
562}
563
564void StunProber::ReportOnFinished(StunProber::Status status) {
565  if (observer_) {
566    observer_->OnFinished(this, status);
567  }
568}
569
570}  // namespace stunprober
571