1// Copyright (c) 2011 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/websockets/websocket_throttle.h"
6
#include <algorithm>
#include <string>
8
9#include "base/hash_tables.h"
10#include "base/memory/ref_counted.h"
11#include "base/memory/singleton.h"
12#include "base/message_loop.h"
13#include "base/string_number_conversions.h"
14#include "base/string_util.h"
15#include "base/stringprintf.h"
16#include "net/base/io_buffer.h"
17#include "net/base/sys_addrinfo.h"
18#include "net/socket_stream/socket_stream.h"
19#include "net/websockets/websocket_job.h"
20
21namespace net {
22
23static std::string AddrinfoToHashkey(const struct addrinfo* addrinfo) {
24  switch (addrinfo->ai_family) {
25    case AF_INET: {
26      const struct sockaddr_in* const addr =
27          reinterpret_cast<const sockaddr_in*>(addrinfo->ai_addr);
28      return base::StringPrintf("%d:%s",
29                                addrinfo->ai_family,
30                                base::HexEncode(&addr->sin_addr, 4).c_str());
31      }
32    case AF_INET6: {
33      const struct sockaddr_in6* const addr6 =
34          reinterpret_cast<const sockaddr_in6*>(addrinfo->ai_addr);
35      return base::StringPrintf(
36          "%d:%s",
37          addrinfo->ai_family,
38          base::HexEncode(&addr6->sin6_addr,
39                          sizeof(addr6->sin6_addr)).c_str());
40      }
41    default:
42      return base::StringPrintf("%d:%s",
43                                addrinfo->ai_family,
44                                base::HexEncode(addrinfo->ai_addr,
45                                                addrinfo->ai_addrlen).c_str());
46  }
47}
48
49WebSocketThrottle::WebSocketThrottle() {
50}
51
52WebSocketThrottle::~WebSocketThrottle() {
53  DCHECK(queue_.empty());
54  DCHECK(addr_map_.empty());
55}
56
57// static
58WebSocketThrottle* WebSocketThrottle::GetInstance() {
59  return Singleton<WebSocketThrottle>::get();
60}
61
62void WebSocketThrottle::PutInQueue(WebSocketJob* job) {
63  queue_.push_back(job);
64  const AddressList& address_list = job->address_list();
65  base::hash_set<std::string> address_set;
66  for (const struct addrinfo* addrinfo = address_list.head();
67       addrinfo != NULL;
68       addrinfo = addrinfo->ai_next) {
69    std::string addrkey = AddrinfoToHashkey(addrinfo);
70
71    // If |addrkey| is already processed, don't do it again.
72    if (address_set.find(addrkey) != address_set.end())
73      continue;
74    address_set.insert(addrkey);
75
76    ConnectingAddressMap::iterator iter = addr_map_.find(addrkey);
77    if (iter == addr_map_.end()) {
78      ConnectingQueue* queue = new ConnectingQueue();
79      queue->push_back(job);
80      addr_map_[addrkey] = queue;
81    } else {
82      iter->second->push_back(job);
83      job->SetWaiting();
84      DVLOG(1) << "Waiting on " << addrkey;
85    }
86  }
87}
88
89void WebSocketThrottle::RemoveFromQueue(WebSocketJob* job) {
90  bool in_queue = false;
91  for (ConnectingQueue::iterator iter = queue_.begin();
92       iter != queue_.end();
93       ++iter) {
94    if (*iter == job) {
95      queue_.erase(iter);
96      in_queue = true;
97      break;
98    }
99  }
100  if (!in_queue)
101    return;
102  const AddressList& address_list = job->address_list();
103  base::hash_set<std::string> address_set;
104  for (const struct addrinfo* addrinfo = address_list.head();
105       addrinfo != NULL;
106       addrinfo = addrinfo->ai_next) {
107    std::string addrkey = AddrinfoToHashkey(addrinfo);
108    // If |addrkey| is already processed, don't do it again.
109    if (address_set.find(addrkey) != address_set.end())
110      continue;
111    address_set.insert(addrkey);
112
113    ConnectingAddressMap::iterator iter = addr_map_.find(addrkey);
114    DCHECK(iter != addr_map_.end());
115
116    ConnectingQueue* queue = iter->second;
117    // Job may not be front of queue when job is closed early while waiting.
118    for (ConnectingQueue::iterator iter = queue->begin();
119         iter != queue->end();
120         ++iter) {
121      if (*iter == job) {
122        queue->erase(iter);
123        break;
124      }
125    }
126    if (queue->empty()) {
127      delete queue;
128      addr_map_.erase(iter);
129    }
130  }
131}
132
133void WebSocketThrottle::WakeupSocketIfNecessary() {
134  for (ConnectingQueue::iterator iter = queue_.begin();
135       iter != queue_.end();
136       ++iter) {
137    WebSocketJob* job = *iter;
138    if (!job->IsWaiting())
139      continue;
140
141    bool should_wakeup = true;
142    const AddressList& address_list = job->address_list();
143    for (const struct addrinfo* addrinfo = address_list.head();
144         addrinfo != NULL;
145         addrinfo = addrinfo->ai_next) {
146      std::string addrkey = AddrinfoToHashkey(addrinfo);
147      ConnectingAddressMap::iterator iter = addr_map_.find(addrkey);
148      DCHECK(iter != addr_map_.end());
149      ConnectingQueue* queue = iter->second;
150      if (job != queue->front()) {
151        should_wakeup = false;
152        break;
153      }
154    }
155    if (should_wakeup)
156      job->Wakeup();
157  }
158}
159
160}  // namespace net
161