1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/websockets/websocket_handshake_handler.h"
6
7#include <limits>
8
9#include "base/base64.h"
10#include "base/md5.h"
11#include "base/sha1.h"
12#include "base/strings/string_number_conversions.h"
13#include "base/strings/string_piece.h"
14#include "base/strings/string_tokenizer.h"
15#include "base/strings/string_util.h"
16#include "base/strings/stringprintf.h"
17#include "net/http/http_response_headers.h"
18#include "net/http/http_util.h"
19#include "url/gurl.h"
20
21namespace {
22
// Number of "key3" bytes that pre-hybi-04 clients send right after the
// request header block.
const size_t kRequestKey3Size = 8U;

// Number of response-key bytes that follow the response header block in
// pre-hybi-04 handshakes.
const size_t kResponseKeySize = 16U;

// Earliest draft (hybi-04) whose handshake no longer carries "key3" or a
// response key after the headers.
const int kMinVersionOfHybiNewHandshake = 4;

// GUID appended to the client key before hashing to produce
// Sec-WebSocket-Accept.
const char kWebSocketGuid[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
32
// Splits a handshake message into its first line and its header block.
//
// |handshake_message| points to |len| bytes holding a request/status line,
// the header lines, and the empty line that ends the header section. Callers
// pass the length returned by HttpUtil::LocateEndOfHeaders().
//
// On return:
//   |status_line| holds the first line INCLUDING its terminator ("\r\n", or
//                 "\n" for LF-only input).
//   |headers|     holds the header lines after it, each with its terminator,
//                 but WITHOUT the terminator of the final empty line.
// If the message contains no '\n', the whole message becomes |status_line|
// and |headers| is empty. A non-positive |len| yields two empty strings.
// Both outputs stay within [handshake_message, handshake_message + len).
void ParseHandshakeHeader(
    const char* handshake_message, int len,
    std::string* status_line,
    std::string* headers) {
  headers->clear();
  if (len <= 0) {
    status_line->clear();
    return;
  }
  const char* const begin = handshake_message;
  const char* const end = handshake_message + len;

  // The first line ends at the first '\n'. Searching for '\n' itself, rather
  // than for the first of '\r'/'\n' and then assuming a two-byte "\r\n",
  // keeps the split correct for LF-only input and never runs past |end|.
  const char* const eol = std::find(begin, end, '\n');
  if (eol == end) {
    status_line->assign(begin, end);
    return;
  }
  const char* const headers_begin = eol + 1;
  status_line->assign(begin, headers_begin);

  // Drop the terminator of the trailing empty line: "\r\n" or a bare "\n".
  const char* headers_end = end;
  if (headers_end - headers_begin >= 2 &&
      headers_end[-2] == '\r' && headers_end[-1] == '\n') {
    headers_end -= 2;
  } else if (headers_end > headers_begin && headers_end[-1] == '\n') {
    headers_end -= 1;
  }
  headers->assign(headers_begin, headers_end);
}
55
56void FetchHeaders(const std::string& headers,
57                  const char* const headers_to_get[],
58                  size_t headers_to_get_len,
59                  std::vector<std::string>* values) {
60  net::HttpUtil::HeadersIterator iter(headers.begin(), headers.end(), "\r\n");
61  while (iter.GetNext()) {
62    for (size_t i = 0; i < headers_to_get_len; i++) {
63      if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(),
64                               headers_to_get[i])) {
65        values->push_back(iter.values());
66      }
67    }
68  }
69}
70
71bool GetHeaderName(std::string::const_iterator line_begin,
72                   std::string::const_iterator line_end,
73                   std::string::const_iterator* name_begin,
74                   std::string::const_iterator* name_end) {
75  std::string::const_iterator colon = std::find(line_begin, line_end, ':');
76  if (colon == line_end) {
77    return false;
78  }
79  *name_begin = line_begin;
80  *name_end = colon;
81  if (*name_begin == *name_end || net::HttpUtil::IsLWS(**name_begin))
82    return false;
83  net::HttpUtil::TrimLWS(name_begin, name_end);
84  return true;
85}
86
87// Similar to HttpUtil::StripHeaders, but it preserves malformed headers, that
88// is, lines that are not formatted as "<name>: <value>\r\n".
89std::string FilterHeaders(
90    const std::string& headers,
91    const char* const headers_to_remove[],
92    size_t headers_to_remove_len) {
93  std::string filtered_headers;
94
95  base::StringTokenizer lines(headers.begin(), headers.end(), "\r\n");
96  while (lines.GetNext()) {
97    std::string::const_iterator line_begin = lines.token_begin();
98    std::string::const_iterator line_end = lines.token_end();
99    std::string::const_iterator name_begin;
100    std::string::const_iterator name_end;
101    bool should_remove = false;
102    if (GetHeaderName(line_begin, line_end, &name_begin, &name_end)) {
103      for (size_t i = 0; i < headers_to_remove_len; ++i) {
104        if (LowerCaseEqualsASCII(name_begin, name_end, headers_to_remove[i])) {
105          should_remove = true;
106          break;
107        }
108      }
109    }
110    if (!should_remove) {
111      filtered_headers.append(line_begin, line_end);
112      filtered_headers.append("\r\n");
113    }
114  }
115  return filtered_headers;
116}
117
118int GetVersionFromRequest(const std::string& request_headers) {
119  std::vector<std::string> values;
120  const char* const headers_to_get[2] = { "sec-websocket-version",
121                                          "sec-websocket-draft" };
122  FetchHeaders(request_headers, headers_to_get, 2, &values);
123  DCHECK_LE(values.size(), 1U);
124  if (values.empty())
125    return 0;
126  int version;
127  bool conversion_success = base::StringToInt(values[0], &version);
128  DCHECK(conversion_success);
129  DCHECK_GE(version, 1);
130  return version;
131}
132
133}  // namespace
134
135namespace net {
136
137namespace internal {
138
139void GetKeyNumber(const std::string& key, std::string* challenge) {
140  uint32 key_number = 0;
141  uint32 spaces = 0;
142  for (size_t i = 0; i < key.size(); ++i) {
143    if (isdigit(key[i])) {
144      // key_number should not overflow. (it comes from
145      // WebCore/websockets/WebSocketHandshake.cpp).
146      // Trust, but verify.
147      DCHECK_GE((std::numeric_limits<uint32>::max() - (key[i] - '0')) / 10,
148                key_number) << "Supplied key would overflow";
149      key_number = key_number * 10 + key[i] - '0';
150    } else if (key[i] == ' ') {
151      ++spaces;
152    }
153  }
154  DCHECK_NE(0u, spaces) << "Key must contain at least one space";
155  if (spaces == 0)
156    return;
157  DCHECK_EQ(0u, key_number % spaces) << "Key number must be an integral "
158                                     << "multiple of the number of spaces";
159  key_number /= spaces;
160
161  char part[4];
162  for (int i = 0; i < 4; i++) {
163    part[3 - i] = key_number & 0xFF;
164    key_number >>= 8;
165  }
166  challenge->append(part, 4);
167}
168
169}  // namespace internal
170
171WebSocketHandshakeRequestHandler::WebSocketHandshakeRequestHandler()
172    : original_length_(0),
173      raw_length_(0),
174      protocol_version_(-1) {}
175
176bool WebSocketHandshakeRequestHandler::ParseRequest(
177    const char* data, int length) {
178  DCHECK_GT(length, 0);
179  std::string input(data, length);
180  int input_header_length =
181      HttpUtil::LocateEndOfHeaders(input.data(), input.size(), 0);
182  if (input_header_length <= 0)
183    return false;
184
185  ParseHandshakeHeader(input.data(),
186                       input_header_length,
187                       &status_line_,
188                       &headers_);
189
190  // WebSocket protocol drafts hixie-76 (hybi-00), hybi-01, 02 and 03 require
191  // the clients to send key3 after the handshake request header fields.
192  // Hybi-04 and later drafts, on the other hand, no longer have key3
193  // in the handshake format.
194  protocol_version_ = GetVersionFromRequest(headers_);
195  DCHECK_GE(protocol_version_, 0);
196  if (protocol_version_ >= kMinVersionOfHybiNewHandshake) {
197    key3_ = "";
198    original_length_ = input_header_length;
199    return true;
200  }
201
202  if (input_header_length + kRequestKey3Size > input.size())
203    return false;
204
205  // Assumes WebKit doesn't send any data after handshake request message
206  // until handshake is finished.
207  // Thus, |key3_| is part of handshake message, and not in part
208  // of WebSocket frame stream.
209  DCHECK_EQ(kRequestKey3Size, input.size() - input_header_length);
210  key3_ = std::string(input.data() + input_header_length,
211                      input.size() - input_header_length);
212  original_length_ = input.size();
213  return true;
214}
215
216size_t WebSocketHandshakeRequestHandler::original_length() const {
217  return original_length_;
218}
219
220void WebSocketHandshakeRequestHandler::AppendHeaderIfMissing(
221    const std::string& name, const std::string& value) {
222  DCHECK(!headers_.empty());
223  HttpUtil::AppendHeaderIfMissing(name.c_str(), value, &headers_);
224}
225
226void WebSocketHandshakeRequestHandler::RemoveHeaders(
227    const char* const headers_to_remove[],
228    size_t headers_to_remove_len) {
229  DCHECK(!headers_.empty());
230  headers_ = FilterHeaders(
231      headers_, headers_to_remove, headers_to_remove_len);
232}
233
234HttpRequestInfo WebSocketHandshakeRequestHandler::GetRequestInfo(
235    const GURL& url, std::string* challenge) {
236  HttpRequestInfo request_info;
237  request_info.url = url;
238  size_t method_end = base::StringPiece(status_line_).find_first_of(" ");
239  if (method_end != base::StringPiece::npos)
240    request_info.method = std::string(status_line_.data(), method_end);
241
242  request_info.extra_headers.Clear();
243  request_info.extra_headers.AddHeadersFromString(headers_);
244
245  request_info.extra_headers.RemoveHeader("Upgrade");
246  request_info.extra_headers.RemoveHeader("Connection");
247
248  if (protocol_version_ >= kMinVersionOfHybiNewHandshake) {
249    std::string key;
250    bool header_present =
251        request_info.extra_headers.GetHeader("Sec-WebSocket-Key", &key);
252    DCHECK(header_present);
253    request_info.extra_headers.RemoveHeader("Sec-WebSocket-Key");
254    *challenge = key;
255  } else {
256    challenge->clear();
257    std::string key;
258    bool header_present =
259        request_info.extra_headers.GetHeader("Sec-WebSocket-Key1", &key);
260    DCHECK(header_present);
261    request_info.extra_headers.RemoveHeader("Sec-WebSocket-Key1");
262    internal::GetKeyNumber(key, challenge);
263
264    header_present =
265        request_info.extra_headers.GetHeader("Sec-WebSocket-Key2", &key);
266    DCHECK(header_present);
267    request_info.extra_headers.RemoveHeader("Sec-WebSocket-Key2");
268    internal::GetKeyNumber(key, challenge);
269
270    challenge->append(key3_);
271  }
272
273  return request_info;
274}
275
276bool WebSocketHandshakeRequestHandler::GetRequestHeaderBlock(
277    const GURL& url,
278    SpdyHeaderBlock* headers,
279    std::string* challenge,
280    int spdy_protocol_version) {
281  // Construct opening handshake request headers as a SPDY header block.
282  // For details, see WebSocket Layering over SPDY/3 Draft 8.
283  if (spdy_protocol_version <= 2) {
284    (*headers)["path"] = url.path();
285    (*headers)["version"] =
286      base::StringPrintf("%s%d", "WebSocket/", protocol_version_);
287    (*headers)["scheme"] = url.scheme();
288  } else {
289    (*headers)[":path"] = url.path();
290    (*headers)[":version"] =
291      base::StringPrintf("%s%d", "WebSocket/", protocol_version_);
292    (*headers)[":scheme"] = url.scheme();
293  }
294
295  HttpUtil::HeadersIterator iter(headers_.begin(), headers_.end(), "\r\n");
296  while (iter.GetNext()) {
297    if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(), "upgrade") ||
298        LowerCaseEqualsASCII(iter.name_begin(),
299                             iter.name_end(),
300                             "connection") ||
301        LowerCaseEqualsASCII(iter.name_begin(),
302                             iter.name_end(),
303                             "sec-websocket-version")) {
304      // These headers must be ignored.
305      continue;
306    } else if (LowerCaseEqualsASCII(iter.name_begin(),
307                                    iter.name_end(),
308                                    "sec-websocket-key")) {
309      *challenge = iter.values();
310      // Sec-WebSocket-Key is not sent to the server.
311      continue;
312    } else if (LowerCaseEqualsASCII(iter.name_begin(),
313                                    iter.name_end(),
314                                    "host") ||
315               LowerCaseEqualsASCII(iter.name_begin(),
316                                    iter.name_end(),
317                                    "origin") ||
318               LowerCaseEqualsASCII(iter.name_begin(),
319                                    iter.name_end(),
320                                    "sec-websocket-protocol") ||
321               LowerCaseEqualsASCII(iter.name_begin(),
322                                    iter.name_end(),
323                                    "sec-websocket-extensions")) {
324      // TODO(toyoshim): Some WebSocket extensions may not be compatible with
325      // SPDY. We should omit them from a Sec-WebSocket-Extension header.
326      std::string name;
327      if (spdy_protocol_version <= 2)
328        name = StringToLowerASCII(iter.name());
329      else
330        name = ":" + StringToLowerASCII(iter.name());
331      (*headers)[name] = iter.values();
332      continue;
333    }
334    // Others should be sent out to |headers|.
335    std::string name = StringToLowerASCII(iter.name());
336    SpdyHeaderBlock::iterator found = headers->find(name);
337    if (found == headers->end()) {
338      (*headers)[name] = iter.values();
339    } else {
340      // For now, websocket doesn't use multiple headers, but follows to http.
341      found->second.append(1, '\0');  // +=() doesn't append 0's
342      found->second.append(iter.values());
343    }
344  }
345
346  return true;
347}
348
349std::string WebSocketHandshakeRequestHandler::GetRawRequest() {
350  DCHECK(!status_line_.empty());
351  DCHECK(!headers_.empty());
352  // The following works on both hybi-04 and older handshake,
353  // because |key3_| is guaranteed to be empty if the handshake was hybi-04's.
354  std::string raw_request = status_line_ + headers_ + "\r\n" + key3_;
355  raw_length_ = raw_request.size();
356  return raw_request;
357}
358
359size_t WebSocketHandshakeRequestHandler::raw_length() const {
360  DCHECK_GT(raw_length_, 0);
361  return raw_length_;
362}
363
364int WebSocketHandshakeRequestHandler::protocol_version() const {
365  DCHECK_GE(protocol_version_, 0);
366  return protocol_version_;
367}
368
369WebSocketHandshakeResponseHandler::WebSocketHandshakeResponseHandler()
370    : original_header_length_(0),
371      protocol_version_(0) {}
372
373WebSocketHandshakeResponseHandler::~WebSocketHandshakeResponseHandler() {}
374
375int WebSocketHandshakeResponseHandler::protocol_version() const {
376  DCHECK_GE(protocol_version_, 0);
377  return protocol_version_;
378}
379
380void WebSocketHandshakeResponseHandler::set_protocol_version(
381    int protocol_version) {
382  DCHECK_GE(protocol_version, 0);
383  protocol_version_ = protocol_version;
384}
385
386size_t WebSocketHandshakeResponseHandler::ParseRawResponse(
387    const char* data, int length) {
388  DCHECK_GT(length, 0);
389  if (HasResponse()) {
390    DCHECK(!status_line_.empty());
391    // headers_ might be empty for wrong response from server.
392    return 0;
393  }
394
395  size_t old_original_length = original_.size();
396
397  original_.append(data, length);
398  // TODO(ukai): fail fast when response gives wrong status code.
399  original_header_length_ = HttpUtil::LocateEndOfHeaders(
400      original_.data(), original_.size(), 0);
401  if (!HasResponse())
402    return length;
403
404  ParseHandshakeHeader(original_.data(),
405                       original_header_length_,
406                       &status_line_,
407                       &headers_);
408  int header_size = status_line_.size() + headers_.size();
409  DCHECK_GE(original_header_length_, header_size);
410  header_separator_ = std::string(original_.data() + header_size,
411                                  original_header_length_ - header_size);
412  key_ = std::string(original_.data() + original_header_length_,
413                     GetResponseKeySize());
414  return original_header_length_ + GetResponseKeySize() - old_original_length;
415}
416
417bool WebSocketHandshakeResponseHandler::HasResponse() const {
418  return original_header_length_ > 0 &&
419      original_header_length_ + GetResponseKeySize() <= original_.size();
420}
421
422bool WebSocketHandshakeResponseHandler::ParseResponseInfo(
423    const HttpResponseInfo& response_info,
424    const std::string& challenge) {
425  if (!response_info.headers.get())
426    return false;
427
428  std::string response_message;
429  response_message = response_info.headers->GetStatusLine();
430  response_message += "\r\n";
431  if (protocol_version_ >= kMinVersionOfHybiNewHandshake)
432    response_message += "Upgrade: websocket\r\n";
433  else
434    response_message += "Upgrade: WebSocket\r\n";
435  response_message += "Connection: Upgrade\r\n";
436
437  if (protocol_version_ >= kMinVersionOfHybiNewHandshake) {
438    std::string hash = base::SHA1HashString(challenge + kWebSocketGuid);
439    std::string websocket_accept;
440    bool encode_success = base::Base64Encode(hash, &websocket_accept);
441    DCHECK(encode_success);
442    response_message += "Sec-WebSocket-Accept: " + websocket_accept + "\r\n";
443  }
444
445  void* iter = NULL;
446  std::string name;
447  std::string value;
448  while (response_info.headers->EnumerateHeaderLines(&iter, &name, &value)) {
449    response_message += name + ": " + value + "\r\n";
450  }
451  response_message += "\r\n";
452
453  if (protocol_version_ < kMinVersionOfHybiNewHandshake) {
454    base::MD5Digest digest;
455    base::MD5Sum(challenge.data(), challenge.size(), &digest);
456
457    const char* digest_data = reinterpret_cast<char*>(digest.a);
458    response_message.append(digest_data, sizeof(digest.a));
459  }
460
461  return ParseRawResponse(response_message.data(),
462                          response_message.size()) == response_message.size();
463}
464
465bool WebSocketHandshakeResponseHandler::ParseResponseHeaderBlock(
466    const SpdyHeaderBlock& headers,
467    const std::string& challenge,
468    int spdy_protocol_version) {
469  SpdyHeaderBlock::const_iterator status;
470  if (spdy_protocol_version <= 2)
471    status = headers.find("status");
472  else
473    status = headers.find(":status");
474  if (status == headers.end())
475    return false;
476  std::string response_message;
477  response_message =
478      base::StringPrintf("%s%s\r\n", "HTTP/1.1 ", status->second.c_str());
479  response_message += "Upgrade: websocket\r\n";
480  response_message += "Connection: Upgrade\r\n";
481
482  std::string hash = base::SHA1HashString(challenge + kWebSocketGuid);
483  std::string websocket_accept;
484  bool encode_success = base::Base64Encode(hash, &websocket_accept);
485  DCHECK(encode_success);
486  response_message += "Sec-WebSocket-Accept: " + websocket_accept + "\r\n";
487
488  for (SpdyHeaderBlock::const_iterator iter = headers.begin();
489       iter != headers.end();
490       ++iter) {
491    // For each value, if the server sends a NUL-separated list of values,
492    // we separate that back out into individual headers for each value
493    // in the list.
494    if ((spdy_protocol_version <= 2 &&
495         LowerCaseEqualsASCII(iter->first, "status")) ||
496        (spdy_protocol_version >= 3 &&
497         LowerCaseEqualsASCII(iter->first, ":status"))) {
498      // The status value is already handled as the first line of
499      // |response_message|. Just skip here.
500      continue;
501    }
502    const std::string& value = iter->second;
503    size_t start = 0;
504    size_t end = 0;
505    do {
506      end = value.find('\0', start);
507      std::string tval;
508      if (end != std::string::npos)
509        tval = value.substr(start, (end - start));
510      else
511        tval = value.substr(start);
512      if (spdy_protocol_version >= 3 &&
513          (LowerCaseEqualsASCII(iter->first, ":sec-websocket-protocol") ||
514           LowerCaseEqualsASCII(iter->first, ":sec-websocket-extensions")))
515        response_message += iter->first.substr(1) + ": " + tval + "\r\n";
516      else
517        response_message += iter->first + ": " + tval + "\r\n";
518      start = end + 1;
519    } while (end != std::string::npos);
520  }
521  response_message += "\r\n";
522
523  return ParseRawResponse(response_message.data(),
524                          response_message.size()) == response_message.size();
525}
526
527void WebSocketHandshakeResponseHandler::GetHeaders(
528    const char* const headers_to_get[],
529    size_t headers_to_get_len,
530    std::vector<std::string>* values) {
531  DCHECK(HasResponse());
532  DCHECK(!status_line_.empty());
533  // headers_ might be empty for wrong response from server.
534  if (headers_.empty())
535    return;
536
537  FetchHeaders(headers_, headers_to_get, headers_to_get_len, values);
538}
539
540void WebSocketHandshakeResponseHandler::RemoveHeaders(
541    const char* const headers_to_remove[],
542    size_t headers_to_remove_len) {
543  DCHECK(HasResponse());
544  DCHECK(!status_line_.empty());
545  // headers_ might be empty for wrong response from server.
546  if (headers_.empty())
547    return;
548
549  headers_ = FilterHeaders(headers_, headers_to_remove, headers_to_remove_len);
550}
551
552std::string WebSocketHandshakeResponseHandler::GetRawResponse() const {
553  DCHECK(HasResponse());
554  return std::string(original_.data(),
555                     original_header_length_ + GetResponseKeySize());
556}
557
558std::string WebSocketHandshakeResponseHandler::GetResponse() {
559  DCHECK(HasResponse());
560  DCHECK(!status_line_.empty());
561  // headers_ might be empty for wrong response from server.
562  DCHECK_EQ(GetResponseKeySize(), key_.size());
563
564  return status_line_ + headers_ + header_separator_ + key_;
565}
566
567size_t WebSocketHandshakeResponseHandler::GetResponseKeySize() const {
568  if (protocol_version_ >= kMinVersionOfHybiNewHandshake)
569    return 0;
570  return kResponseKeySize;
571}
572
573}  // namespace net
574