1// Copyright (c) 2010 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/websockets/websocket_handshake_handler.h"
6
7#include "base/md5.h"
8#include "base/string_piece.h"
9#include "base/string_util.h"
10#include "googleurl/src/gurl.h"
11#include "net/http/http_response_headers.h"
12#include "net/http/http_util.h"
13
14namespace {
15
// Number of /key3/ bytes that follow the header block of a handshake request.
const size_t kRequestKey3Size = 8;
// Number of key bytes that follow the header block of a handshake response.
const size_t kResponseKeySize = 16;
18
19void ParseHandshakeHeader(
20    const char* handshake_message, int len,
21    std::string* status_line,
22    std::string* headers) {
23  size_t i = base::StringPiece(handshake_message, len).find_first_of("\r\n");
24  if (i == base::StringPiece::npos) {
25    *status_line = std::string(handshake_message, len);
26    *headers = "";
27    return;
28  }
29  // |status_line| includes \r\n.
30  *status_line = std::string(handshake_message, i + 2);
31
32  int header_len = len - (i + 2) - 2;
33  if (header_len > 0) {
34    // |handshake_message| includes tailing \r\n\r\n.
35    // |headers| doesn't include 2nd \r\n.
36    *headers = std::string(handshake_message + i + 2, header_len);
37  } else {
38    *headers = "";
39  }
40}
41
42void FetchHeaders(const std::string& headers,
43                  const char* const headers_to_get[],
44                  size_t headers_to_get_len,
45                  std::vector<std::string>* values) {
46  net::HttpUtil::HeadersIterator iter(headers.begin(), headers.end(), "\r\n");
47  while (iter.GetNext()) {
48    for (size_t i = 0; i < headers_to_get_len; i++) {
49      if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
50                               headers_to_get[i])) {
51        values->push_back(iter.values());
52      }
53    }
54  }
55}
56
57bool GetHeaderName(std::string::const_iterator line_begin,
58                   std::string::const_iterator line_end,
59                   std::string::const_iterator* name_begin,
60                   std::string::const_iterator* name_end) {
61  std::string::const_iterator colon = std::find(line_begin, line_end, ':');
62  if (colon == line_end) {
63    return false;
64  }
65  *name_begin = line_begin;
66  *name_end = colon;
67  if (*name_begin == *name_end || net::HttpUtil::IsLWS(**name_begin))
68    return false;
69  net::HttpUtil::TrimLWS(name_begin, name_end);
70  return true;
71}
72
73// Similar to HttpUtil::StripHeaders, but it preserves malformed headers, that
74// is, lines that are not formatted as "<name>: <value>\r\n".
75std::string FilterHeaders(
76    const std::string& headers,
77    const char* const headers_to_remove[],
78    size_t headers_to_remove_len) {
79  std::string filtered_headers;
80
81  StringTokenizer lines(headers.begin(), headers.end(), "\r\n");
82  while (lines.GetNext()) {
83    std::string::const_iterator line_begin = lines.token_begin();
84    std::string::const_iterator line_end = lines.token_end();
85    std::string::const_iterator name_begin;
86    std::string::const_iterator name_end;
87    bool should_remove = false;
88    if (GetHeaderName(line_begin, line_end, &name_begin, &name_end)) {
89      for (size_t i = 0; i < headers_to_remove_len; ++i) {
90        if (LowerCaseEqualsASCII(name_begin, name_end, headers_to_remove[i])) {
91          should_remove = true;
92          break;
93        }
94      }
95    }
96    if (!should_remove) {
97      filtered_headers.append(line_begin, line_end);
98      filtered_headers.append("\r\n");
99    }
100  }
101  return filtered_headers;
102}
103
104// Gets a key number from |key| and appends the number to |challenge|.
105// The key number (/part_N/) is extracted as step 4.-8. in
106// 5.2. Sending the server's opening handshake of
107// http://www.ietf.org/id/draft-ietf-hybi-thewebsocketprotocol-00.txt
108void GetKeyNumber(const std::string& key, std::string* challenge) {
109  uint32 key_number = 0;
110  uint32 spaces = 0;
111  for (size_t i = 0; i < key.size(); ++i) {
112    if (isdigit(key[i])) {
113      // key_number should not overflow. (it comes from
114      // WebCore/websockets/WebSocketHandshake.cpp).
115      key_number = key_number * 10 + key[i] - '0';
116    } else if (key[i] == ' ') {
117      ++spaces;
118    }
119  }
120  // spaces should not be zero in valid handshake request.
121  if (spaces == 0)
122    return;
123  key_number /= spaces;
124
125  char part[4];
126  for (int i = 0; i < 4; i++) {
127    part[3 - i] = key_number & 0xFF;
128    key_number >>= 8;
129  }
130  challenge->append(part, 4);
131}
132
133}  // anonymous namespace
134
135namespace net {
136
137WebSocketHandshakeRequestHandler::WebSocketHandshakeRequestHandler()
138    : original_length_(0),
139      raw_length_(0) {}
140
141bool WebSocketHandshakeRequestHandler::ParseRequest(
142    const char* data, int length) {
143  DCHECK_GT(length, 0);
144  std::string input(data, length);
145  int input_header_length =
146      HttpUtil::LocateEndOfHeaders(input.data(), input.size(), 0);
147  if (input_header_length <= 0 ||
148      input_header_length + kRequestKey3Size > input.size())
149    return false;
150
151  ParseHandshakeHeader(input.data(),
152                       input_header_length,
153                       &status_line_,
154                       &headers_);
155
156  // draft-hixie-thewebsocketprotocol-76 or later will send /key3/
157  // after handshake request header.
158  // Assumes WebKit doesn't send any data after handshake request message
159  // until handshake is finished.
160  // Thus, |key3_| is part of handshake message, and not in part
161  // of WebSocket frame stream.
162  DCHECK_EQ(kRequestKey3Size,
163            input.size() -
164            input_header_length);
165  key3_ = std::string(input.data() + input_header_length,
166                      input.size() - input_header_length);
167  original_length_ = input.size();
168  return true;
169}
170
171size_t WebSocketHandshakeRequestHandler::original_length() const {
172  return original_length_;
173}
174
175void WebSocketHandshakeRequestHandler::AppendHeaderIfMissing(
176    const std::string& name, const std::string& value) {
177  DCHECK(!headers_.empty());
178  HttpUtil::AppendHeaderIfMissing(name.c_str(), value, &headers_);
179}
180
181void WebSocketHandshakeRequestHandler::RemoveHeaders(
182    const char* const headers_to_remove[],
183    size_t headers_to_remove_len) {
184  DCHECK(!headers_.empty());
185  headers_ = FilterHeaders(
186      headers_, headers_to_remove, headers_to_remove_len);
187}
188
189HttpRequestInfo WebSocketHandshakeRequestHandler::GetRequestInfo(
190    const GURL& url, std::string* challenge) {
191  HttpRequestInfo request_info;
192  request_info.url = url;
193  base::StringPiece method = status_line_.data();
194  size_t method_end = base::StringPiece(
195      status_line_.data(), status_line_.size()).find_first_of(" ");
196  if (method_end != base::StringPiece::npos)
197    request_info.method = std::string(status_line_.data(), method_end);
198
199  request_info.extra_headers.Clear();
200  request_info.extra_headers.AddHeadersFromString(headers_);
201
202  request_info.extra_headers.RemoveHeader("Upgrade");
203  request_info.extra_headers.RemoveHeader("Connection");
204
205  challenge->clear();
206  std::string key;
207  request_info.extra_headers.GetHeader("Sec-WebSocket-Key1", &key);
208  request_info.extra_headers.RemoveHeader("Sec-WebSocket-Key1");
209  GetKeyNumber(key, challenge);
210
211  request_info.extra_headers.GetHeader("Sec-WebSocket-Key2", &key);
212  request_info.extra_headers.RemoveHeader("Sec-WebSocket-Key2");
213  GetKeyNumber(key, challenge);
214
215  challenge->append(key3_);
216
217  return request_info;
218}
219
220bool WebSocketHandshakeRequestHandler::GetRequestHeaderBlock(
221    const GURL& url, spdy::SpdyHeaderBlock* headers, std::string* challenge) {
222  // We don't set "method" and "version".  These are fixed value in WebSocket
223  // protocol.
224  (*headers)["url"] = url.spec();
225
226  std::string key1;
227  std::string key2;
228  HttpUtil::HeadersIterator iter(headers_.begin(), headers_.end(), "\r\n");
229  while (iter.GetNext()) {
230    if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
231                             "connection")) {
232      // Ignore "Connection" header.
233      continue;
234    } else if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
235                                    "upgrade")) {
236      // Ignore "Upgrade" header.
237      continue;
238    } else if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
239                                    "sec-websocket-key1")) {
240      // Use only for generating challenge.
241      key1 = iter.values();
242      continue;
243    } else if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
244                                    "sec-websocket-key2")) {
245      // Use only for generating challenge.
246      key2 = iter.values();
247      continue;
248    }
249    // Others should be sent out to |headers|.
250    std::string name = StringToLowerASCII(iter.name());
251    spdy::SpdyHeaderBlock::iterator found = headers->find(name);
252    if (found == headers->end()) {
253      (*headers)[name] = iter.values();
254    } else {
255      // For now, websocket doesn't use multiple headers, but follows to http.
256      found->second.append(1, '\0');  // +=() doesn't append 0's
257      found->second.append(iter.values());
258    }
259  }
260
261  challenge->clear();
262  GetKeyNumber(key1, challenge);
263  GetKeyNumber(key2, challenge);
264  challenge->append(key3_);
265
266  return true;
267}
268
269std::string WebSocketHandshakeRequestHandler::GetRawRequest() {
270  DCHECK(!status_line_.empty());
271  DCHECK(!headers_.empty());
272  DCHECK_EQ(kRequestKey3Size, key3_.size());
273  std::string raw_request = status_line_ + headers_ + "\r\n" + key3_;
274  raw_length_ = raw_request.size();
275  return raw_request;
276}
277
278size_t WebSocketHandshakeRequestHandler::raw_length() const {
279  DCHECK_GT(raw_length_, 0);
280  return raw_length_;
281}
282
283WebSocketHandshakeResponseHandler::WebSocketHandshakeResponseHandler()
284    : original_header_length_(0) {
285}
286
287WebSocketHandshakeResponseHandler::~WebSocketHandshakeResponseHandler() {}
288
289size_t WebSocketHandshakeResponseHandler::ParseRawResponse(
290    const char* data, int length) {
291  DCHECK_GT(length, 0);
292  if (HasResponse()) {
293    DCHECK(!status_line_.empty());
294    DCHECK(!headers_.empty());
295    DCHECK_EQ(kResponseKeySize, key_.size());
296    return 0;
297  }
298
299  size_t old_original_length = original_.size();
300
301  original_.append(data, length);
302  // TODO(ukai): fail fast when response gives wrong status code.
303  original_header_length_ = HttpUtil::LocateEndOfHeaders(
304      original_.data(), original_.size(), 0);
305  if (!HasResponse())
306    return length;
307
308  ParseHandshakeHeader(original_.data(),
309                       original_header_length_,
310                       &status_line_,
311                       &headers_);
312  int header_size = status_line_.size() + headers_.size();
313  DCHECK_GE(original_header_length_, header_size);
314  header_separator_ = std::string(original_.data() + header_size,
315                                  original_header_length_ - header_size);
316  key_ = std::string(original_.data() + original_header_length_,
317                     kResponseKeySize);
318
319  return original_header_length_ + kResponseKeySize - old_original_length;
320}
321
322bool WebSocketHandshakeResponseHandler::HasResponse() const {
323  return original_header_length_ > 0 &&
324      original_header_length_ + kResponseKeySize <= original_.size();
325}
326
327bool WebSocketHandshakeResponseHandler::ParseResponseInfo(
328    const HttpResponseInfo& response_info,
329    const std::string& challenge) {
330  if (!response_info.headers.get())
331    return false;
332
333  std::string response_message;
334  response_message = response_info.headers->GetStatusLine();
335  response_message += "\r\n";
336  response_message += "Upgrade: WebSocket\r\n";
337  response_message += "Connection: Upgrade\r\n";
338  void* iter = NULL;
339  std::string name;
340  std::string value;
341  while (response_info.headers->EnumerateHeaderLines(&iter, &name, &value)) {
342    response_message += name + ": " + value + "\r\n";
343  }
344  response_message += "\r\n";
345
346  MD5Digest digest;
347  MD5Sum(challenge.data(), challenge.size(), &digest);
348
349  const char* digest_data = reinterpret_cast<char*>(digest.a);
350  response_message.append(digest_data, sizeof(digest.a));
351
352  return ParseRawResponse(response_message.data(),
353                          response_message.size()) == response_message.size();
354}
355
356bool WebSocketHandshakeResponseHandler::ParseResponseHeaderBlock(
357    const spdy::SpdyHeaderBlock& headers,
358    const std::string& challenge) {
359  std::string response_message;
360  response_message = "HTTP/1.1 101 WebSocket Protocol Handshake\r\n";
361  response_message += "Upgrade: WebSocket\r\n";
362  response_message += "Connection: Upgrade\r\n";
363  for (spdy::SpdyHeaderBlock::const_iterator iter = headers.begin();
364       iter != headers.end();
365       ++iter) {
366    // For each value, if the server sends a NUL-separated list of values,
367    // we separate that back out into individual headers for each value
368    // in the list.
369    const std::string& value = iter->second;
370    size_t start = 0;
371    size_t end = 0;
372    do {
373      end = value.find('\0', start);
374      std::string tval;
375      if (end != std::string::npos)
376        tval = value.substr(start, (end - start));
377      else
378        tval = value.substr(start);
379      response_message += iter->first + ": " + tval + "\r\n";
380      start = end + 1;
381    } while (end != std::string::npos);
382  }
383  response_message += "\r\n";
384
385  MD5Digest digest;
386  MD5Sum(challenge.data(), challenge.size(), &digest);
387
388  const char* digest_data = reinterpret_cast<char*>(digest.a);
389  response_message.append(digest_data, sizeof(digest.a));
390
391  return ParseRawResponse(response_message.data(),
392                          response_message.size()) == response_message.size();
393}
394
395void WebSocketHandshakeResponseHandler::GetHeaders(
396    const char* const headers_to_get[],
397    size_t headers_to_get_len,
398    std::vector<std::string>* values) {
399  DCHECK(HasResponse());
400  DCHECK(!status_line_.empty());
401  DCHECK(!headers_.empty());
402  DCHECK_EQ(kResponseKeySize, key_.size());
403
404  FetchHeaders(headers_, headers_to_get, headers_to_get_len, values);
405}
406
407void WebSocketHandshakeResponseHandler::RemoveHeaders(
408    const char* const headers_to_remove[],
409    size_t headers_to_remove_len) {
410  DCHECK(HasResponse());
411  DCHECK(!status_line_.empty());
412  DCHECK(!headers_.empty());
413  DCHECK_EQ(kResponseKeySize, key_.size());
414
415  headers_ = FilterHeaders(headers_, headers_to_remove, headers_to_remove_len);
416}
417
418std::string WebSocketHandshakeResponseHandler::GetRawResponse() const {
419  DCHECK(HasResponse());
420  return std::string(original_.data(),
421                     original_header_length_ + kResponseKeySize);
422}
423
424std::string WebSocketHandshakeResponseHandler::GetResponse() {
425  DCHECK(HasResponse());
426  DCHECK(!status_line_.empty());
427  // headers_ might be empty for wrong response from server.
428  DCHECK_EQ(kResponseKeySize, key_.size());
429
430  return status_line_ + headers_ + header_separator_ + key_;
431}
432
433}  // namespace net
434