f65341580eef31fc77699c2af7d824e46ecad5f3
[WebKit-https.git] / Source / WebCore / websockets / WebSocketHandshake.cpp
1 /*
2  * Copyright (C) 2009 Google Inc.  All rights reserved.
3  * Copyright (C) Research In Motion Limited 2011. All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are
7  * met:
8  *
9  *     * Redistributions of source code must retain the above copyright
10  * notice, this list of conditions and the following disclaimer.
11  *     * Redistributions in binary form must reproduce the above
12  * copyright notice, this list of conditions and the following disclaimer
13  * in the documentation and/or other materials provided with the
14  * distribution.
15  *     * Neither the name of Google Inc. nor the names of its
16  * contributors may be used to endorse or promote products derived from
17  * this software without specific prior written permission.
18  *
19  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30  */
31
32 #include "config.h"
33
34 #if ENABLE(WEB_SOCKETS)
35
36 #include "WebSocketHandshake.h"
37
38 #include "CharacterNames.h"
39 #include "Cookie.h"
40 #include "CookieJar.h"
41 #include "Document.h"
42 #include "HTTPHeaderMap.h"
43 #include "KURL.h"
44 #include "Logging.h"
45 #include "ScriptCallStack.h"
46 #include "ScriptExecutionContext.h"
47 #include "SecurityOrigin.h"
48
49 #include <wtf/MD5.h>
50 #include <wtf/RandomNumber.h>
51 #include <wtf/StdLibExtras.h>
52 #include <wtf/StringExtras.h>
53 #include <wtf/Vector.h>
54 #include <wtf/text/AtomicString.h>
55 #include <wtf/text/CString.h>
56 #include <wtf/text/StringBuilder.h>
57 #include <wtf/text/StringConcatenate.h>
58
59 namespace WebCore {
60
61 static const char randomCharacterInSecWebSocketKey[] = "!\"#$%&'()*+,-./:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~";
62
63 static String resourceName(const KURL& url)
64 {
65     String name = url.path();
66     if (name.isEmpty())
67         name = "/";
68     if (!url.query().isNull())
69         name += "?" + url.query();
70     ASSERT(!name.isEmpty());
71     ASSERT(!name.contains(' '));
72     return name;
73 }
74
75 static String hostName(const KURL& url, bool secure)
76 {
77     ASSERT(url.protocolIs("wss") == secure);
78     StringBuilder builder;
79     builder.append(url.host().lower());
80     if (url.port() && ((!secure && url.port() != 80) || (secure && url.port() != 443))) {
81         builder.append(':');
82         builder.append(String::number(url.port()));
83     }
84     return builder.toString();
85 }
86
87 static const size_t maxConsoleMessageSize = 128;
88 static String trimConsoleMessage(const char* p, size_t len)
89 {
90     String s = String(p, std::min<size_t>(len, maxConsoleMessageSize));
91     if (len > maxConsoleMessageSize)
92         s.append(horizontalEllipsis);
93     return s;
94 }
95
96 static void generateSecWebSocketKey(uint32_t& number, String& key)
97 {
98     uint32_t space = static_cast<uint32_t>(randomNumber() * 12) + 1;
99     uint32_t max = 4294967295U / space;
100     number = static_cast<uint32_t>(randomNumber() * max);
101     uint32_t product = number * space;
102
103     String s = String::number(product);
104     int n = static_cast<int>(randomNumber() * 12) + 1;
105     DEFINE_STATIC_LOCAL(String, randomChars, (randomCharacterInSecWebSocketKey));
106     for (int i = 0; i < n; i++) {
107         int pos = static_cast<int>(randomNumber() * (s.length() + 1));
108         int chpos = static_cast<int>(randomNumber() * randomChars.length());
109         s.insert(randomChars.substring(chpos, 1), pos);
110     }
111     DEFINE_STATIC_LOCAL(String, spaceChar, (" "));
112     for (uint32_t i = 0; i < space; i++) {
113         int pos = static_cast<int>(randomNumber() * (s.length() - 1)) + 1;
114         s.insert(spaceChar, pos);
115     }
116     ASSERT(s[0] != ' ');
117     ASSERT(s[s.length() - 1] != ' ');
118     key = s;
119 }
120
121 static void generateKey3(unsigned char key3[8])
122 {
123     for (int i = 0; i < 8; i++)
124         key3[i] = randomNumber() * 256;
125 }
126
127 static void setChallengeNumber(unsigned char* buf, uint32_t number)
128 {
129     unsigned char* p = buf + 3;
130     for (int i = 0; i < 4; i++) {
131         *p = number & 0xFF;
132         --p;
133         number >>= 8;
134     }
135 }
136
137 static void generateExpectedChallengeResponse(uint32_t number1, uint32_t number2, unsigned char key3[8], unsigned char expectedChallenge[16])
138 {
139     unsigned char challenge[16];
140     setChallengeNumber(&challenge[0], number1);
141     setChallengeNumber(&challenge[4], number2);
142     memcpy(&challenge[8], key3, 8);
143     MD5 md5;
144     md5.addBytes(challenge, sizeof(challenge));
145     Vector<uint8_t, 16> digest;
146     md5.checksum(digest);
147     memcpy(expectedChallenge, digest.data(), 16);
148 }
149
150 WebSocketHandshake::WebSocketHandshake(const KURL& url, const String& protocol, ScriptExecutionContext* context)
151     : m_url(url)
152     , m_clientProtocol(protocol)
153     , m_secure(m_url.protocolIs("wss"))
154     , m_context(context)
155     , m_mode(Incomplete)
156 {
157     uint32_t number1;
158     uint32_t number2;
159     generateSecWebSocketKey(number1, m_secWebSocketKey1);
160     generateSecWebSocketKey(number2, m_secWebSocketKey2);
161     generateKey3(m_key3);
162     generateExpectedChallengeResponse(number1, number2, m_key3, m_expectedChallengeResponse);
163 }
164
165 WebSocketHandshake::~WebSocketHandshake()
166 {
167 }
168
169 const KURL& WebSocketHandshake::url() const
170 {
171     return m_url;
172 }
173
174 void WebSocketHandshake::setURL(const KURL& url)
175 {
176     m_url = url.copy();
177 }
178
179 const String WebSocketHandshake::host() const
180 {
181     return m_url.host().lower();
182 }
183
184 const String& WebSocketHandshake::clientProtocol() const
185 {
186     return m_clientProtocol;
187 }
188
189 void WebSocketHandshake::setClientProtocol(const String& protocol)
190 {
191     m_clientProtocol = protocol;
192 }
193
194 bool WebSocketHandshake::secure() const
195 {
196     return m_secure;
197 }
198
199 String WebSocketHandshake::clientOrigin() const
200 {
201     return m_context->securityOrigin()->toString();
202 }
203
204 String WebSocketHandshake::clientLocation() const
205 {
206     StringBuilder builder;
207     builder.append(m_secure ? "wss" : "ws");
208     builder.append("://");
209     builder.append(hostName(m_url, m_secure));
210     builder.append(resourceName(m_url));
211     return builder.toString();
212 }
213
214 CString WebSocketHandshake::clientHandshakeMessage() const
215 {
216     // Keep the following consistent with clientHandshakeRequest().
217     StringBuilder builder;
218
219     builder.append("GET ");
220     builder.append(resourceName(m_url));
221     builder.append(" HTTP/1.1\r\n");
222
223     Vector<String> fields;
224     fields.append("Upgrade: WebSocket");
225     fields.append("Connection: Upgrade");
226     fields.append("Host: " + hostName(m_url, m_secure));
227     fields.append("Origin: " + clientOrigin());
228     if (!m_clientProtocol.isEmpty())
229         fields.append("Sec-WebSocket-Protocol: " + m_clientProtocol);
230
231     KURL url = httpURLForAuthenticationAndCookies();
232     if (m_context->isDocument()) {
233         Document* document = static_cast<Document*>(m_context);
234         String cookie = cookieRequestHeaderFieldValue(document, url);
235         if (!cookie.isEmpty())
236             fields.append("Cookie: " + cookie);
237         // Set "Cookie2: <cookie>" if cookies 2 exists for url?
238     }
239
240     fields.append("Sec-WebSocket-Key1: " + m_secWebSocketKey1);
241     fields.append("Sec-WebSocket-Key2: " + m_secWebSocketKey2);
242
243     // Fields in the handshake are sent by the client in a random order; the
244     // order is not meaningful.  Thus, it's ok to send the order we constructed
245     // the fields.
246
247     for (size_t i = 0; i < fields.size(); i++) {
248         builder.append(fields[i]);
249         builder.append("\r\n");
250     }
251
252     builder.append("\r\n");
253
254     CString handshakeHeader = builder.toString().utf8();
255     char* characterBuffer = 0;
256     CString msg = CString::newUninitialized(handshakeHeader.length() + sizeof(m_key3), characterBuffer);
257     memcpy(characterBuffer, handshakeHeader.data(), handshakeHeader.length());
258     memcpy(characterBuffer + handshakeHeader.length(), m_key3, sizeof(m_key3));
259     return msg;
260 }
261
262 WebSocketHandshakeRequest WebSocketHandshake::clientHandshakeRequest() const
263 {
264     // Keep the following consistent with clientHandshakeMessage().
265     // FIXME: do we need to store m_secWebSocketKey1, m_secWebSocketKey2 and
266     // m_key3 in WebSocketHandshakeRequest?
267     WebSocketHandshakeRequest request("GET", m_url);
268     request.addHeaderField("Upgrade", "WebSocket");
269     request.addHeaderField("Connection", "Upgrade");
270     request.addHeaderField("Host", hostName(m_url, m_secure));
271     request.addHeaderField("Origin", clientOrigin());
272     if (!m_clientProtocol.isEmpty())
273         request.addHeaderField("Sec-WebSocket-Protocol:", m_clientProtocol);
274
275     KURL url = httpURLForAuthenticationAndCookies();
276     if (m_context->isDocument()) {
277         Document* document = static_cast<Document*>(m_context);
278         String cookie = cookieRequestHeaderFieldValue(document, url);
279         if (!cookie.isEmpty())
280             request.addHeaderField("Cookie", cookie);
281         // Set "Cookie2: <cookie>" if cookies 2 exists for url?
282     }
283
284     request.addHeaderField("Sec-WebSocket-Key1", m_secWebSocketKey1);
285     request.addHeaderField("Sec-WebSocket-Key2", m_secWebSocketKey2);
286     request.setKey3(m_key3);
287
288     return request;
289 }
290
291 void WebSocketHandshake::reset()
292 {
293     m_mode = Incomplete;
294
295     m_wsOrigin = String();
296     m_wsLocation = String();
297     m_wsProtocol = String();
298     m_setCookie = String();
299     m_setCookie2 = String();
300 }
301
302 void WebSocketHandshake::clearScriptExecutionContext()
303 {
304     m_context = 0;
305 }
306
307 int WebSocketHandshake::readServerHandshake(const char* header, size_t len)
308 {
309     m_mode = Incomplete;
310     int statusCode;
311     String statusText;
312     int lineLength = readStatusLine(header, len, statusCode, statusText);
313     if (lineLength == -1)
314         return -1;
315     if (statusCode == -1) {
316         m_mode = Failed;
317         return len;
318     }
319     LOG(Network, "response code: %d", statusCode);
320     m_response.setStatusCode(statusCode);
321     m_response.setStatusText(statusText);
322     if (statusCode != 101) {
323         m_mode = Failed;
324         m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, makeString("Unexpected response code: ", String::number(statusCode)), 0, clientOrigin(), 0);
325         return len;
326     }
327     m_mode = Normal;
328     if (!strnstr(header, "\r\n\r\n", len)) {
329         // Just hasn't been received fully yet.
330         m_mode = Incomplete;
331         return -1;
332     }
333     const char* p = readHTTPHeaders(header + lineLength, header + len);
334     if (!p) {
335         LOG(Network, "readHTTPHeaders failed");
336         m_mode = Failed;
337         return len;
338     }
339     processHeaders();
340     if (!checkResponseHeaders()) {
341         LOG(Network, "header process failed");
342         m_mode = Failed;
343         return p - header;
344     }
345     if (len < static_cast<size_t>(p - header + sizeof(m_expectedChallengeResponse))) {
346         // Just hasn't been received /expected/ yet.
347         m_mode = Incomplete;
348         return -1;
349     }
350     m_response.setChallengeResponse(static_cast<const unsigned char*>(static_cast<const void*>(p)));
351     if (memcmp(p, m_expectedChallengeResponse, sizeof(m_expectedChallengeResponse))) {
352         m_mode = Failed;
353         return (p - header) + sizeof(m_expectedChallengeResponse);
354     }
355     m_mode = Connected;
356     return (p - header) + sizeof(m_expectedChallengeResponse);
357 }
358
359 WebSocketHandshake::Mode WebSocketHandshake::mode() const
360 {
361     return m_mode;
362 }
363
364 const String& WebSocketHandshake::serverWebSocketOrigin() const
365 {
366     return m_wsOrigin;
367 }
368
369 void WebSocketHandshake::setServerWebSocketOrigin(const String& webSocketOrigin)
370 {
371     m_wsOrigin = webSocketOrigin;
372 }
373
374 const String& WebSocketHandshake::serverWebSocketLocation() const
375 {
376     return m_wsLocation;
377 }
378
379 void WebSocketHandshake::setServerWebSocketLocation(const String& webSocketLocation)
380 {
381     m_wsLocation = webSocketLocation;
382 }
383
384 const String& WebSocketHandshake::serverWebSocketProtocol() const
385 {
386     return m_wsProtocol;
387 }
388
389 void WebSocketHandshake::setServerWebSocketProtocol(const String& webSocketProtocol)
390 {
391     m_wsProtocol = webSocketProtocol;
392 }
393
394 const String& WebSocketHandshake::serverSetCookie() const
395 {
396     return m_setCookie;
397 }
398
399 void WebSocketHandshake::setServerSetCookie(const String& setCookie)
400 {
401     m_setCookie = setCookie;
402 }
403
404 const String& WebSocketHandshake::serverSetCookie2() const
405 {
406     return m_setCookie2;
407 }
408
409 void WebSocketHandshake::setServerSetCookie2(const String& setCookie2)
410 {
411     m_setCookie2 = setCookie2;
412 }
413
414 const WebSocketHandshakeResponse& WebSocketHandshake::serverHandshakeResponse() const
415 {
416     return m_response;
417 }
418
419 KURL WebSocketHandshake::httpURLForAuthenticationAndCookies() const
420 {
421     KURL url = m_url.copy();
422     bool couldSetProtocol = url.setProtocol(m_secure ? "https" : "http");
423     ASSERT_UNUSED(couldSetProtocol, couldSetProtocol);
424     return url;
425 }
426
427 // Returns the header length (including "\r\n"), or -1 if we have not received enough data yet.
428 // If the line is malformed or the status code is not a 3-digit number,
429 // statusCode and statusText will be set to -1 and a null string, respectively.
430 int WebSocketHandshake::readStatusLine(const char* header, size_t headerLength, int& statusCode, String& statusText)
431 {
432     // Arbitrary size limit to prevent the server from sending an unbounded
433     // amount of data with no newlines and forcing us to buffer it all.
434     static const int maximumLength = 1024;
435
436     statusCode = -1;
437     statusText = String();
438
439     const char* space1 = 0;
440     const char* space2 = 0;
441     const char* p;
442     size_t consumedLength;
443
444     for (p = header, consumedLength = 0; consumedLength < headerLength; p++, consumedLength++) {
445         if (*p == ' ') {
446             if (!space1)
447                 space1 = p;
448             else if (!space2)
449                 space2 = p;
450         } else if (*p == '\0') {
451             // The caller isn't prepared to deal with null bytes in status
452             // line. WebSockets specification doesn't prohibit this, but HTTP
453             // does, so we'll just treat this as an error. 
454             m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Status line contains embedded null", 0, clientOrigin(), 0);
455             return p + 1 - header;
456         } else if (*p == '\n')
457             break;
458     }
459     if (consumedLength == headerLength)
460         return -1; // We have not received '\n' yet.
461
462     const char* end = p + 1;
463     if (end - header > maximumLength) {
464         m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Status line is too long", 0, clientOrigin(), 0);
465         return maximumLength;
466     }
467     int lineLength = end - header;
468
469     if (!space1 || !space2) {
470         m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "No response code found: " + trimConsoleMessage(header, lineLength - 1), 0, clientOrigin(), 0);
471         return lineLength;
472     }
473
474     // The line must end with "\r\n".
475     if (*(end - 2) != '\r') {
476         m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Status line does not end with CRLF", 0, clientOrigin(), 0);
477         return lineLength;
478     }
479
480     String statusCodeString(space1 + 1, space2 - space1 - 1);
481     if (statusCodeString.length() != 3) // Status code must consist of three digits.
482         return lineLength;
483     for (int i = 0; i < 3; ++i)
484         if (statusCodeString[i] < '0' || statusCodeString[i] > '9') {
485             m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Invalid status code: " + statusCodeString, 0, clientOrigin(), 0);
486             return lineLength;
487         }
488
489     bool ok = false;
490     statusCode = statusCodeString.toInt(&ok);
491     ASSERT(ok);
492
493     statusText = String(space2 + 1, end - space2 - 3); // Exclude "\r\n".
494     return lineLength;
495 }
496
497 const char* WebSocketHandshake::readHTTPHeaders(const char* start, const char* end)
498 {
499     m_response.clearHeaderFields();
500
501     Vector<char> name;
502     Vector<char> value;
503     for (const char* p = start; p < end; p++) {
504         name.clear();
505         value.clear();
506
507         for (; p < end; p++) {
508             switch (*p) {
509             case '\r':
510                 if (name.isEmpty()) {
511                     if (p + 1 < end && *(p + 1) == '\n')
512                         return p + 2;
513                     m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "CR doesn't follow LF at " + trimConsoleMessage(p, end - p), 0, clientOrigin(), 0);
514                     return 0;
515                 }
516                 m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected CR in name at " + trimConsoleMessage(name.data(), name.size()), 0, clientOrigin(), 0);
517                 return 0;
518             case '\n':
519                 m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected LF in name at " + trimConsoleMessage(name.data(), name.size()), 0, clientOrigin(), 0);
520                 return 0;
521             case ':':
522                 break;
523             default:
524                 name.append(*p);
525                 continue;
526             }
527             if (*p == ':') {
528                 ++p;
529                 break;
530             }
531         }
532
533         for (; p < end && *p == 0x20; p++) { }
534
535         for (; p < end; p++) {
536             switch (*p) {
537             case '\r':
538                 break;
539             case '\n':
540                 m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected LF in value at " + trimConsoleMessage(value.data(), value.size()), 0, clientOrigin(), 0);
541                 return 0;
542             default:
543                 value.append(*p);
544             }
545             if (*p == '\r') {
546                 ++p;
547                 break;
548             }
549         }
550         if (p >= end || *p != '\n') {
551             m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "CR doesn't follow LF after value at " + trimConsoleMessage(p, end - p), 0, clientOrigin(), 0);
552             return 0;
553         }
554         AtomicString nameStr(String::fromUTF8(name.data(), name.size()));
555         String valueStr = String::fromUTF8(value.data(), value.size());
556         if (nameStr.isNull()) {
557             m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "invalid UTF-8 sequence in header name", 0, clientOrigin(), 0);
558             return 0;
559         }
560         if (valueStr.isNull()) {
561             m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "invalid UTF-8 sequence in header value", 0, clientOrigin(), 0);
562             return 0;
563         }
564         LOG(Network, "name=%s value=%s", nameStr.string().utf8().data(), valueStr.utf8().data());
565         m_response.addHeaderField(nameStr, valueStr);
566     }
567     ASSERT_NOT_REACHED();
568     return 0;
569 }
570
571 void WebSocketHandshake::processHeaders()
572 {
573     ASSERT(m_mode == Normal);
574     const HTTPHeaderMap& headers = m_response.headerFields();
575     m_wsOrigin = headers.get("sec-websocket-origin");
576     m_wsLocation = headers.get("sec-websocket-location");
577     m_wsProtocol = headers.get("sec-websocket-protocol");
578     m_setCookie = headers.get("set-cookie");
579     m_setCookie2 = headers.get("set-cookie2");
580 }
581
582 bool WebSocketHandshake::checkResponseHeaders()
583 {
584     if (m_wsOrigin.isNull()) {
585         m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: 'sec-websocket-origin' header is missing", 0, clientOrigin(), 0);
586         return false;
587     }
588     if (m_wsLocation.isNull()) {
589         m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: 'sec-websocket-location' header is missing", 0, clientOrigin(), 0);
590         return false;
591     }
592
593     if (clientOrigin() != m_wsOrigin) {
594         m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: origin mismatch: " + clientOrigin() + " != " + m_wsOrigin, 0, clientOrigin(), 0);
595         return false;
596     }
597     if (clientLocation() != m_wsLocation) {
598         m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: location mismatch: " + clientLocation() + " != " + m_wsLocation, 0, clientOrigin(), 0);
599         return false;
600     }
601     if (!m_clientProtocol.isEmpty() && m_clientProtocol != m_wsProtocol) {
602         m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: protocol mismatch: " + m_clientProtocol + " != " + m_wsProtocol, 0, clientOrigin(), 0);
603         return false;
604     }
605     return true;
606 }
607
608 } // namespace WebCore
609
610 #endif // ENABLE(WEB_SOCKETS)