1/*
2 * Copyright (C) 2011 Google Inc.  All rights reserved.
3 * Copyright (C) Research In Motion Limited 2011. All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are
7 * met:
8 *
9 *     * Redistributions of source code must retain the above copyright
10 * notice, this list of conditions and the following disclaimer.
11 *     * Redistributions in binary form must reproduce the above
12 * copyright notice, this list of conditions and the following disclaimer
13 * in the documentation and/or other materials provided with the
14 * distribution.
15 *     * Neither the name of Google Inc. nor the names of its
16 * contributors may be used to endorse or promote products derived from
17 * this software without specific prior written permission.
18 *
19 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32#include "config.h"
33
34#include "modules/websockets/WebSocketHandshake.h"
35
36#include "core/dom/Document.h"
37#include "core/inspector/ScriptCallStack.h"
38#include "core/loader/CookieJar.h"
39#include "modules/websockets/DOMWebSocket.h"
40#include "platform/Cookie.h"
41#include "platform/Crypto.h"
42#include "platform/Logging.h"
43#include "platform/network/HTTPHeaderMap.h"
44#include "platform/network/HTTPParsers.h"
45#include "platform/weborigin/SecurityOrigin.h"
46#include "public/platform/Platform.h"
47#include "wtf/CryptographicallyRandomNumber.h"
48#include "wtf/StdLibExtras.h"
49#include "wtf/StringExtras.h"
50#include "wtf/Vector.h"
51#include "wtf/text/Base64.h"
52#include "wtf/text/CString.h"
53#include "wtf/text/StringBuilder.h"
54#include "wtf/unicode/CharacterNames.h"
55
56namespace blink {
57
58String formatHandshakeFailureReason(const String& detail)
59{
60    return "Error during WebSocket handshake: " + detail;
61}
62
63static String resourceName(const KURL& url)
64{
65    StringBuilder name;
66    name.append(url.path());
67    if (name.isEmpty())
68        name.append('/');
69    if (!url.query().isNull()) {
70        name.append('?');
71        name.append(url.query());
72    }
73    String result = name.toString();
74    ASSERT(!result.isEmpty());
75    ASSERT(!result.contains(' '));
76    return result;
77}
78
79static String hostName(const KURL& url, bool secure)
80{
81    ASSERT(url.protocolIs("wss") == secure);
82    StringBuilder builder;
83    builder.append(url.host().lower());
84    if (url.port() && ((!secure && url.port() != 80) || (secure && url.port() != 443))) {
85        builder.append(':');
86        builder.appendNumber(url.port());
87    }
88    return builder.toString();
89}
90
91static const size_t maxInputSampleSize = 128;
92static String trimInputSample(const char* p, size_t len)
93{
94    if (len > maxInputSampleSize)
95        return String(p, maxInputSampleSize) + horizontalEllipsis;
96    return String(p, len);
97}
98
99static String generateSecWebSocketKey()
100{
101    static const size_t nonceSize = 16;
102    unsigned char key[nonceSize];
103    cryptographicallyRandomValues(key, nonceSize);
104    return base64Encode(reinterpret_cast<char*>(key), nonceSize);
105}
106
107String WebSocketHandshake::getExpectedWebSocketAccept(const String& secWebSocketKey)
108{
109    static const char webSocketKeyGUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
110    CString keyData = secWebSocketKey.ascii();
111
112    StringBuilder digestable;
113    digestable.append(secWebSocketKey);
114    digestable.append(webSocketKeyGUID, strlen(webSocketKeyGUID));
115    CString digestableCString = digestable.toString().utf8();
116    DigestValue digest;
117    bool digestSuccess = computeDigest(HashAlgorithmSha1, digestableCString.data(), digestableCString.length(), digest);
118    RELEASE_ASSERT(digestSuccess);
119
120    return base64Encode(reinterpret_cast<const char*>(digest.data()), sha1HashSize);
121}
122
123WebSocketHandshake::WebSocketHandshake(const KURL& url, const String& protocol, Document* document)
124    : m_url(url)
125    , m_clientProtocol(protocol)
126    , m_secure(m_url.protocolIs("wss"))
127    , m_document(document)
128    , m_mode(Incomplete)
129{
130    m_secWebSocketKey = generateSecWebSocketKey();
131    m_expectedAccept = getExpectedWebSocketAccept(m_secWebSocketKey);
132}
133
134WebSocketHandshake::~WebSocketHandshake()
135{
136    Platform::current()->histogramEnumeration("WebCore.WebSocket.HandshakeResult", m_mode, WebSocketHandshake::ModeMax);
137}
138
139const KURL& WebSocketHandshake::url() const
140{
141    return m_url;
142}
143
144void WebSocketHandshake::setURL(const KURL& url)
145{
146    m_url = url.copy();
147}
148
149const String WebSocketHandshake::host() const
150{
151    return m_url.host().lower();
152}
153
154const String& WebSocketHandshake::clientProtocol() const
155{
156    return m_clientProtocol;
157}
158
159void WebSocketHandshake::setClientProtocol(const String& protocol)
160{
161    m_clientProtocol = protocol;
162}
163
164bool WebSocketHandshake::secure() const
165{
166    return m_secure;
167}
168
169String WebSocketHandshake::clientOrigin() const
170{
171    return m_document->securityOrigin()->toString();
172}
173
174String WebSocketHandshake::clientLocation() const
175{
176    StringBuilder builder;
177    if (m_secure)
178        builder.appendLiteral("wss");
179    else
180        builder.appendLiteral("ws");
181    builder.appendLiteral("://");
182    builder.append(hostName(m_url, m_secure));
183    builder.append(resourceName(m_url));
184    return builder.toString();
185}
186
187CString WebSocketHandshake::clientHandshakeMessage() const
188{
189    ASSERT(m_document);
190
191    // Keep the following consistent with clientHandshakeRequest().
192    StringBuilder builder;
193
194    builder.appendLiteral("GET ");
195    builder.append(resourceName(m_url));
196    builder.appendLiteral(" HTTP/1.1\r\n");
197
198    Vector<String> fields;
199    fields.append("Upgrade: websocket");
200    fields.append("Connection: Upgrade");
201    fields.append("Host: " + hostName(m_url, m_secure));
202    fields.append("Origin: " + clientOrigin());
203    if (!m_clientProtocol.isEmpty())
204        fields.append("Sec-WebSocket-Protocol: " + m_clientProtocol);
205
206    // Add no-cache headers to avoid compatibility issue.
207    // There are some proxies that rewrite "Connection: upgrade"
208    // to "Connection: close" in the response if a request doesn't contain
209    // these headers.
210    fields.append("Pragma: no-cache");
211    fields.append("Cache-Control: no-cache");
212
213    fields.append("Sec-WebSocket-Key: " + m_secWebSocketKey);
214    fields.append("Sec-WebSocket-Version: 13");
215    const String extensionValue = m_extensionDispatcher.createHeaderValue();
216    if (extensionValue.length())
217        fields.append("Sec-WebSocket-Extensions: " + extensionValue);
218
219    fields.append("User-Agent: " + m_document->userAgent(m_document->url()));
220
221    // Fields in the handshake are sent by the client in a random order; the
222    // order is not meaningful. Thus, it's ok to send the order we constructed
223    // the fields.
224
225    for (size_t i = 0; i < fields.size(); i++) {
226        builder.append(fields[i]);
227        builder.appendLiteral("\r\n");
228    }
229
230    builder.appendLiteral("\r\n");
231
232    return builder.toString().utf8();
233}
234
235PassRefPtr<WebSocketHandshakeRequest> WebSocketHandshake::clientHandshakeRequest() const
236{
237    ASSERT(m_document);
238
239    // Keep the following consistent with clientHandshakeMessage().
240    // FIXME: do we need to store m_secWebSocketKey1, m_secWebSocketKey2 and
241    // m_key3 in WebSocketHandshakeRequest?
242    RefPtr<WebSocketHandshakeRequest> request = WebSocketHandshakeRequest::create(m_url);
243    request->addHeaderField("Upgrade", "websocket");
244    request->addHeaderField("Connection", "Upgrade");
245    request->addHeaderField("Host", AtomicString(hostName(m_url, m_secure)));
246    request->addHeaderField("Origin", AtomicString(clientOrigin()));
247    if (!m_clientProtocol.isEmpty())
248        request->addHeaderField("Sec-WebSocket-Protocol", AtomicString(m_clientProtocol));
249
250    KURL url = httpURLForAuthenticationAndCookies();
251
252    String cookie = cookieRequestHeaderFieldValue(m_document, url);
253    if (!cookie.isEmpty())
254        request->addHeaderField("Cookie", AtomicString(cookie));
255    // Set "Cookie2: <cookie>" if cookies 2 exists for url?
256
257    request->addHeaderField("Pragma", "no-cache");
258    request->addHeaderField("Cache-Control", "no-cache");
259
260    request->addHeaderField("Sec-WebSocket-Key", AtomicString(m_secWebSocketKey));
261    request->addHeaderField("Sec-WebSocket-Version", "13");
262    const String extensionValue = m_extensionDispatcher.createHeaderValue();
263    if (extensionValue.length())
264        request->addHeaderField("Sec-WebSocket-Extensions", AtomicString(extensionValue));
265
266    request->addHeaderField("User-Agent", AtomicString(m_document->userAgent(m_document->url())));
267
268    return request.release();
269}
270
271void WebSocketHandshake::reset()
272{
273    m_mode = Incomplete;
274    m_extensionDispatcher.reset();
275}
276
277void WebSocketHandshake::clearDocument()
278{
279    m_document = nullptr;
280}
281
282int WebSocketHandshake::readServerHandshake(const char* header, size_t len)
283{
284    m_mode = Incomplete;
285    int statusCode;
286    String statusText;
287    int lineLength = readStatusLine(header, len, statusCode, statusText);
288    if (lineLength == -1)
289        return -1;
290    if (statusCode == -1) {
291        m_mode = Failed; // m_failureReason is set inside readStatusLine().
292        return len;
293    }
294    WTF_LOG(Network, "WebSocketHandshake %p readServerHandshake() Status code is %d", this, statusCode);
295    m_response.setStatusCode(statusCode);
296    m_response.setStatusText(statusText);
297    if (statusCode != 101) {
298        m_mode = Failed;
299        m_failureReason = formatHandshakeFailureReason("Unexpected response code: " + String::number(statusCode));
300        return len;
301    }
302    m_mode = Normal;
303    if (!strnstr(header, "\r\n\r\n", len)) {
304        // Just hasn't been received fully yet.
305        m_mode = Incomplete;
306        return -1;
307    }
308    const char* p = readHTTPHeaders(header + lineLength, header + len);
309    if (!p) {
310        WTF_LOG(Network, "WebSocketHandshake %p readServerHandshake() readHTTPHeaders() failed", this);
311        m_mode = Failed; // m_failureReason is set inside readHTTPHeaders().
312        return len;
313    }
314    if (!checkResponseHeaders()) {
315        WTF_LOG(Network, "WebSocketHandshake %p readServerHandshake() checkResponseHeaders() failed", this);
316        m_mode = Failed;
317        return p - header;
318    }
319
320    m_mode = Connected;
321    return p - header;
322}
323
324WebSocketHandshake::Mode WebSocketHandshake::mode() const
325{
326    return m_mode;
327}
328
329String WebSocketHandshake::failureReason() const
330{
331    return m_failureReason;
332}
333
334const AtomicString& WebSocketHandshake::serverWebSocketProtocol() const
335{
336    return m_response.headerFields().get("sec-websocket-protocol");
337}
338
339const AtomicString& WebSocketHandshake::serverUpgrade() const
340{
341    return m_response.headerFields().get("upgrade");
342}
343
344const AtomicString& WebSocketHandshake::serverConnection() const
345{
346    return m_response.headerFields().get("connection");
347}
348
349const AtomicString& WebSocketHandshake::serverWebSocketAccept() const
350{
351    return m_response.headerFields().get("sec-websocket-accept");
352}
353
354String WebSocketHandshake::acceptedExtensions() const
355{
356    return m_extensionDispatcher.acceptedExtensions();
357}
358
359const WebSocketHandshakeResponse& WebSocketHandshake::serverHandshakeResponse() const
360{
361    return m_response;
362}
363
364void WebSocketHandshake::addExtensionProcessor(PassOwnPtr<WebSocketExtensionProcessor> processor)
365{
366    m_extensionDispatcher.addProcessor(processor);
367}
368
369KURL WebSocketHandshake::httpURLForAuthenticationAndCookies() const
370{
371    KURL url = m_url.copy();
372    bool couldSetProtocol = url.setProtocol(m_secure ? "https" : "http");
373    ASSERT_UNUSED(couldSetProtocol, couldSetProtocol);
374    return url;
375}
376
377// Returns the header length (including "\r\n"), or -1 if we have not received enough data yet.
378// If the line is malformed or the status code is not a 3-digit number,
379// statusCode and statusText will be set to -1 and a null string, respectively.
380int WebSocketHandshake::readStatusLine(const char* header, size_t headerLength, int& statusCode, String& statusText)
381{
382    // Arbitrary size limit to prevent the server from sending an unbounded
383    // amount of data with no newlines and forcing us to buffer it all.
384    static const int maximumLength = 1024;
385
386    statusCode = -1;
387    statusText = String();
388
389    const char* space1 = 0;
390    const char* space2 = 0;
391    const char* p;
392    size_t consumedLength;
393
394    for (p = header, consumedLength = 0; consumedLength < headerLength; p++, consumedLength++) {
395        if (*p == ' ') {
396            if (!space1)
397                space1 = p;
398            else if (!space2)
399                space2 = p;
400        } else if (*p == '\0') {
401            // The caller isn't prepared to deal with null bytes in status
402            // line. WebSockets specification doesn't prohibit this, but HTTP
403            // does, so we'll just treat this as an error.
404            m_failureReason = formatHandshakeFailureReason("Status line contains embedded null");
405            return p + 1 - header;
406        } else if (*p == '\n') {
407            break;
408        }
409    }
410    if (consumedLength == headerLength)
411        return -1; // We have not received '\n' yet.
412
413    const char* end = p + 1;
414    int lineLength = end - header;
415    if (lineLength > maximumLength) {
416        m_failureReason = formatHandshakeFailureReason("Status line is too long");
417        return maximumLength;
418    }
419
420    // The line must end with "\r\n".
421    if (lineLength < 2 || *(end - 2) != '\r') {
422        m_failureReason = formatHandshakeFailureReason("Status line does not end with CRLF");
423        return lineLength;
424    }
425
426    if (!space1 || !space2) {
427        m_failureReason = formatHandshakeFailureReason("No response code found in status line: " + trimInputSample(header, lineLength - 2));
428        return lineLength;
429    }
430
431    String statusCodeString(space1 + 1, space2 - space1 - 1);
432    if (statusCodeString.length() != 3) // Status code must consist of three digits.
433        return lineLength;
434    for (int i = 0; i < 3; ++i) {
435        if (statusCodeString[i] < '0' || statusCodeString[i] > '9') {
436            m_failureReason = formatHandshakeFailureReason("Invalid status code: " + statusCodeString);
437            return lineLength;
438        }
439    }
440
441    bool ok = false;
442    statusCode = statusCodeString.toInt(&ok);
443    ASSERT(ok);
444
445    statusText = String(space2 + 1, end - space2 - 3); // Exclude "\r\n".
446    return lineLength;
447}
448
449const char* WebSocketHandshake::readHTTPHeaders(const char* start, const char* end)
450{
451    m_response.clearHeaderFields();
452
453    AtomicString name;
454    AtomicString value;
455    bool sawSecWebSocketAcceptHeaderField = false;
456    bool sawSecWebSocketProtocolHeaderField = false;
457    const char* p = start;
458    while (p < end) {
459        size_t consumedLength = parseHTTPHeader(p, end - p, m_failureReason, name, value);
460        if (!consumedLength)
461            return 0;
462        p += consumedLength;
463
464        // Stop once we consumed an empty line.
465        if (name.isEmpty())
466            break;
467
468        // Sec-WebSocket-Extensions may be split. We parse and check the
469        // header value every time the header appears.
470        if (equalIgnoringCase("Sec-WebSocket-Extensions", name)) {
471            if (!m_extensionDispatcher.processHeaderValue(value)) {
472                m_failureReason = formatHandshakeFailureReason(m_extensionDispatcher.failureReason());
473                return 0;
474            }
475        } else if (equalIgnoringCase("Sec-WebSocket-Accept", name)) {
476            if (sawSecWebSocketAcceptHeaderField) {
477                m_failureReason = formatHandshakeFailureReason("'Sec-WebSocket-Accept' header must not appear more than once in a response");
478                return 0;
479            }
480            m_response.addHeaderField(name, value);
481            sawSecWebSocketAcceptHeaderField = true;
482        } else if (equalIgnoringCase("Sec-WebSocket-Protocol", name)) {
483            if (sawSecWebSocketProtocolHeaderField) {
484                m_failureReason = formatHandshakeFailureReason("'Sec-WebSocket-Protocol' header must not appear more than once in a response");
485                return 0;
486            }
487            m_response.addHeaderField(name, value);
488            sawSecWebSocketProtocolHeaderField = true;
489        } else {
490            m_response.addHeaderField(name, value);
491        }
492    }
493
494    String extensions = m_extensionDispatcher.acceptedExtensions();
495    if (!extensions.isEmpty())
496        m_response.addHeaderField("Sec-WebSocket-Extensions", AtomicString(extensions));
497    return p;
498}
499
500bool WebSocketHandshake::checkResponseHeaders()
501{
502    const AtomicString& serverWebSocketProtocol = this->serverWebSocketProtocol();
503    const AtomicString& serverUpgrade = this->serverUpgrade();
504    const AtomicString& serverConnection = this->serverConnection();
505    const AtomicString& serverWebSocketAccept = this->serverWebSocketAccept();
506
507    if (serverUpgrade.isNull()) {
508        m_failureReason = formatHandshakeFailureReason("'Upgrade' header is missing");
509        return false;
510    }
511    if (serverConnection.isNull()) {
512        m_failureReason = formatHandshakeFailureReason("'Connection' header is missing");
513        return false;
514    }
515    if (serverWebSocketAccept.isNull()) {
516        m_failureReason = formatHandshakeFailureReason("'Sec-WebSocket-Accept' header is missing");
517        return false;
518    }
519
520    if (!equalIgnoringCase(serverUpgrade, "websocket")) {
521        m_failureReason = formatHandshakeFailureReason("'Upgrade' header value is not 'WebSocket': " + serverUpgrade);
522        return false;
523    }
524    if (!equalIgnoringCase(serverConnection, "upgrade")) {
525        m_failureReason = formatHandshakeFailureReason("'Connection' header value is not 'Upgrade': " + serverConnection);
526        return false;
527    }
528
529    if (serverWebSocketAccept != m_expectedAccept) {
530        m_failureReason = formatHandshakeFailureReason("Incorrect 'Sec-WebSocket-Accept' header value");
531        return false;
532    }
533    if (!serverWebSocketProtocol.isNull()) {
534        if (m_clientProtocol.isEmpty()) {
535            m_failureReason = formatHandshakeFailureReason("Response must not include 'Sec-WebSocket-Protocol' header if not present in request: " + serverWebSocketProtocol);
536            return false;
537        }
538        Vector<String> result;
539        m_clientProtocol.split(String(DOMWebSocket::subprotocolSeperator()), result);
540        if (!result.contains(serverWebSocketProtocol)) {
541            m_failureReason = formatHandshakeFailureReason("'Sec-WebSocket-Protocol' header value '" + serverWebSocketProtocol + "' in response does not match any of sent values");
542            return false;
543        }
544    } else if (!m_clientProtocol.isEmpty()) {
545        m_failureReason = formatHandshakeFailureReason("Sent non-empty 'Sec-WebSocket-Protocol' header but no response was received");
546        return false;
547    }
548    return true;
549}
550
551void WebSocketHandshake::trace(Visitor* visitor)
552{
553    visitor->trace(m_document);
554}
555
556} // namespace blink
557