websocket_throttle.cc revision c7f5f8508d98d5952d42ed7648c2a8f30a4da156
1// Copyright (c) 2009 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/websockets/websocket_throttle.h"
6
7#include <string>
8
9#include "base/message_loop.h"
10#include "base/ref_counted.h"
11#include "base/singleton.h"
12#include "base/string_util.h"
13#include "net/base/io_buffer.h"
14#include "net/base/sys_addrinfo.h"
15#include "net/socket_stream/socket_stream.h"
16
17namespace net {
18
19static std::string AddrinfoToHashkey(const struct addrinfo* addrinfo) {
20  switch (addrinfo->ai_family) {
21    case AF_INET: {
22      const struct sockaddr_in* const addr =
23          reinterpret_cast<const sockaddr_in*>(addrinfo->ai_addr);
24      return StringPrintf("%d:%s",
25                          addrinfo->ai_family,
26                          HexEncode(&addr->sin_addr, 4).c_str());
27      }
28    case AF_INET6: {
29      const struct sockaddr_in6* const addr6 =
30          reinterpret_cast<const sockaddr_in6*>(addrinfo->ai_addr);
31      return StringPrintf("%d:%s",
32                          addrinfo->ai_family,
33                          HexEncode(&addr6->sin6_addr,
34                                    sizeof(addr6->sin6_addr)).c_str());
35      }
36    default:
37      return StringPrintf("%d:%s",
38                          addrinfo->ai_family,
39                          HexEncode(addrinfo->ai_addr,
40                                    addrinfo->ai_addrlen).c_str());
41  }
42}
43
44// State for WebSocket protocol on each SocketStream.
45// This is owned in SocketStream as UserData keyed by WebSocketState::kKeyName.
46// This is alive between connection starts and handshake is finished.
47// In this class, it doesn't check actual handshake finishes, but only checks
48// end of header is found in read data.
49class WebSocketThrottle::WebSocketState : public SocketStream::UserData {
50 public:
51  explicit WebSocketState(const AddressList& addrs)
52      : address_list_(addrs),
53        callback_(NULL),
54        waiting_(false),
55        handshake_finished_(false),
56        buffer_(NULL) {
57  }
58  ~WebSocketState() {}
59
60  int OnStartOpenConnection(CompletionCallback* callback) {
61    DCHECK(!callback_);
62    if (!waiting_)
63      return OK;
64    callback_ = callback;
65    return ERR_IO_PENDING;
66  }
67
68  int OnRead(const char* data, int len, CompletionCallback* callback) {
69    DCHECK(!waiting_);
70    DCHECK(!callback_);
71    DCHECK(!handshake_finished_);
72    static const int kBufferSize = 8129;
73
74    if (!buffer_) {
75      // Fast path.
76      int eoh = HttpUtil::LocateEndOfHeaders(data, len, 0);
77      if (eoh > 0) {
78        handshake_finished_ = true;
79        return OK;
80      }
81      buffer_ = new GrowableIOBuffer();
82      buffer_->SetCapacity(kBufferSize);
83    } else if (buffer_->RemainingCapacity() < len) {
84      buffer_->SetCapacity(buffer_->capacity() + kBufferSize);
85    }
86    memcpy(buffer_->data(), data, len);
87    buffer_->set_offset(buffer_->offset() + len);
88
89    int eoh = HttpUtil::LocateEndOfHeaders(buffer_->StartOfBuffer(),
90                                           buffer_->offset(), 0);
91    handshake_finished_ = (eoh > 0);
92    return OK;
93  }
94
95  const AddressList& address_list() const { return address_list_; }
96  void SetWaiting() { waiting_ = true; }
97  bool IsWaiting() const { return waiting_; }
98  bool HandshakeFinished() const { return handshake_finished_; }
99  void Wakeup() {
100    waiting_ = false;
101    // We wrap |callback_| to keep this alive while this is released.
102    scoped_refptr<CompletionCallbackRunner> runner =
103        new CompletionCallbackRunner(callback_);
104    callback_ = NULL;
105    MessageLoopForIO::current()->PostTask(
106        FROM_HERE,
107        NewRunnableMethod(runner.get(),
108                          &CompletionCallbackRunner::Run));
109  }
110
111  static const char* kKeyName;
112
113 private:
114  class CompletionCallbackRunner
115      : public base::RefCountedThreadSafe<CompletionCallbackRunner> {
116   public:
117    explicit CompletionCallbackRunner(CompletionCallback* callback)
118        : callback_(callback) {
119      DCHECK(callback_);
120    }
121    void Run() {
122      callback_->Run(OK);
123    }
124   private:
125    friend class base::RefCountedThreadSafe<CompletionCallbackRunner>;
126
127    virtual ~CompletionCallbackRunner() {}
128
129    CompletionCallback* callback_;
130
131    DISALLOW_COPY_AND_ASSIGN(CompletionCallbackRunner);
132  };
133
134  const AddressList& address_list_;
135
136  CompletionCallback* callback_;
137  // True if waiting another websocket connection is established.
138  // False if the websocket is performing handshaking.
139  bool waiting_;
140
141  // True if the websocket handshake is completed.
142  // If true, it will be removed from queue and deleted from the SocketStream
143  // UserData soon.
144  bool handshake_finished_;
145
146  // Buffer for read data to check handshake response message.
147  scoped_refptr<GrowableIOBuffer> buffer_;
148
149  DISALLOW_COPY_AND_ASSIGN(WebSocketState);
150};
151
152const char* WebSocketThrottle::WebSocketState::kKeyName = "WebSocketState";
153
154WebSocketThrottle::WebSocketThrottle() {
155  SocketStreamThrottle::RegisterSocketStreamThrottle("ws", this);
156  SocketStreamThrottle::RegisterSocketStreamThrottle("wss", this);
157}
158
159WebSocketThrottle::~WebSocketThrottle() {
160  DCHECK(queue_.empty());
161  DCHECK(addr_map_.empty());
162}
163
164int WebSocketThrottle::OnStartOpenConnection(
165    SocketStream* socket, CompletionCallback* callback) {
166  WebSocketState* state = new WebSocketState(socket->address_list());
167  PutInQueue(socket, state);
168  return state->OnStartOpenConnection(callback);
169}
170
171int WebSocketThrottle::OnRead(SocketStream* socket,
172                              const char* data, int len,
173                              CompletionCallback* callback) {
174  WebSocketState* state = static_cast<WebSocketState*>(
175      socket->GetUserData(WebSocketState::kKeyName));
176  // If no state, handshake was already completed. Do nothing.
177  if (!state)
178    return OK;
179
180  int result = state->OnRead(data, len, callback);
181  if (state->HandshakeFinished()) {
182    RemoveFromQueue(socket, state);
183    WakeupSocketIfNecessary();
184  }
185  return result;
186}
187
188int WebSocketThrottle::OnWrite(SocketStream* socket,
189                               const char* data, int len,
190                               CompletionCallback* callback) {
191  // Do nothing.
192  return OK;
193}
194
195void WebSocketThrottle::OnClose(SocketStream* socket) {
196  WebSocketState* state = static_cast<WebSocketState*>(
197      socket->GetUserData(WebSocketState::kKeyName));
198  if (!state)
199    return;
200  RemoveFromQueue(socket, state);
201  WakeupSocketIfNecessary();
202}
203
204void WebSocketThrottle::PutInQueue(SocketStream* socket,
205                                   WebSocketState* state) {
206  socket->SetUserData(WebSocketState::kKeyName, state);
207  queue_.push_back(state);
208  const AddressList& address_list = socket->address_list();
209  for (const struct addrinfo* addrinfo = address_list.head();
210       addrinfo != NULL;
211       addrinfo = addrinfo->ai_next) {
212    std::string addrkey = AddrinfoToHashkey(addrinfo);
213    ConnectingAddressMap::iterator iter = addr_map_.find(addrkey);
214    if (iter == addr_map_.end()) {
215      ConnectingQueue* queue = new ConnectingQueue();
216      queue->push_back(state);
217      addr_map_[addrkey] = queue;
218    } else {
219      iter->second->push_back(state);
220      state->SetWaiting();
221    }
222  }
223}
224
225void WebSocketThrottle::RemoveFromQueue(SocketStream* socket,
226                                        WebSocketState* state) {
227  const AddressList& address_list = socket->address_list();
228  for (const struct addrinfo* addrinfo = address_list.head();
229       addrinfo != NULL;
230       addrinfo = addrinfo->ai_next) {
231    std::string addrkey = AddrinfoToHashkey(addrinfo);
232    ConnectingAddressMap::iterator iter = addr_map_.find(addrkey);
233    DCHECK(iter != addr_map_.end());
234    ConnectingQueue* queue = iter->second;
235    DCHECK(state == queue->front());
236    queue->pop_front();
237    if (queue->empty()) {
238      delete queue;
239      addr_map_.erase(iter);
240    }
241  }
242  for (ConnectingQueue::iterator iter = queue_.begin();
243       iter != queue_.end();
244       ++iter) {
245    if (*iter == state) {
246      queue_.erase(iter);
247      break;
248    }
249  }
250  socket->SetUserData(WebSocketState::kKeyName, NULL);
251}
252
253void WebSocketThrottle::WakeupSocketIfNecessary() {
254  for (ConnectingQueue::iterator iter = queue_.begin();
255       iter != queue_.end();
256       ++iter) {
257    WebSocketState* state = *iter;
258    if (!state->IsWaiting())
259      continue;
260
261    bool should_wakeup = true;
262    const AddressList& address_list = state->address_list();
263    for (const struct addrinfo* addrinfo = address_list.head();
264         addrinfo != NULL;
265         addrinfo = addrinfo->ai_next) {
266      std::string addrkey = AddrinfoToHashkey(addrinfo);
267      ConnectingAddressMap::iterator iter = addr_map_.find(addrkey);
268      DCHECK(iter != addr_map_.end());
269      ConnectingQueue* queue = iter->second;
270      if (state != queue->front()) {
271        should_wakeup = false;
272        break;
273      }
274    }
275    if (should_wakeup)
276      state->Wakeup();
277  }
278}
279
280/* static */
281void WebSocketThrottle::Init() {
282  Singleton<WebSocketThrottle>::get();
283}
284
285}  // namespace net
286