2011-01-29 Patrick Gansterer <paroga@webkit.org>
[WebKit-https.git] / Source / WebCore / websockets / WebSocketHandshake.cpp
1 /*
2  * Copyright (C) 2009 Google Inc.  All rights reserved.
3  * Copyright (C) Research In Motion Limited 2011. All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are
7  * met:
8  *
9  *     * Redistributions of source code must retain the above copyright
10  * notice, this list of conditions and the following disclaimer.
11  *     * Redistributions in binary form must reproduce the above
12  * copyright notice, this list of conditions and the following disclaimer
13  * in the documentation and/or other materials provided with the
14  * distribution.
15  *     * Neither the name of Google Inc. nor the names of its
16  * contributors may be used to endorse or promote products derived from
17  * this software without specific prior written permission.
18  *
19  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30  */
31
32 #include "config.h"
33
34 #if ENABLE(WEB_SOCKETS)
35
36 #include "WebSocketHandshake.h"
37
38 #include "Cookie.h"
39 #include "CookieJar.h"
40 #include "Document.h"
41 #include "HTTPHeaderMap.h"
42 #include "KURL.h"
43 #include "Logging.h"
44 #include "ScriptCallStack.h"
45 #include "ScriptExecutionContext.h"
46 #include "SecurityOrigin.h"
47 #include <wtf/MD5.h>
48 #include <wtf/RandomNumber.h>
49 #include <wtf/StdLibExtras.h>
50 #include <wtf/StringExtras.h>
51 #include <wtf/Vector.h>
52 #include <wtf/text/AtomicString.h>
53 #include <wtf/text/CString.h>
54 #include <wtf/text/StringBuilder.h>
55 #include <wtf/text/StringConcatenate.h>
56 #include <wtf/unicode/CharacterNames.h>
57
58 namespace WebCore {
59
60 static const char randomCharacterInSecWebSocketKey[] = "!\"#$%&'()*+,-./:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~";
61
62 static String resourceName(const KURL& url)
63 {
64     String name = url.path();
65     if (name.isEmpty())
66         name = "/";
67     if (!url.query().isNull())
68         name += "?" + url.query();
69     ASSERT(!name.isEmpty());
70     ASSERT(!name.contains(' '));
71     return name;
72 }
73
74 static String hostName(const KURL& url, bool secure)
75 {
76     ASSERT(url.protocolIs("wss") == secure);
77     StringBuilder builder;
78     builder.append(url.host().lower());
79     if (url.port() && ((!secure && url.port() != 80) || (secure && url.port() != 443))) {
80         builder.append(':');
81         builder.append(String::number(url.port()));
82     }
83     return builder.toString();
84 }
85
86 static const size_t maxConsoleMessageSize = 128;
87 static String trimConsoleMessage(const char* p, size_t len)
88 {
89     String s = String(p, std::min<size_t>(len, maxConsoleMessageSize));
90     if (len > maxConsoleMessageSize)
91         s.append(horizontalEllipsis);
92     return s;
93 }
94
95 static void generateSecWebSocketKey(uint32_t& number, String& key)
96 {
97     uint32_t space = static_cast<uint32_t>(randomNumber() * 12) + 1;
98     uint32_t max = 4294967295U / space;
99     number = static_cast<uint32_t>(randomNumber() * max);
100     uint32_t product = number * space;
101
102     String s = String::number(product);
103     int n = static_cast<int>(randomNumber() * 12) + 1;
104     DEFINE_STATIC_LOCAL(String, randomChars, (randomCharacterInSecWebSocketKey));
105     for (int i = 0; i < n; i++) {
106         int pos = static_cast<int>(randomNumber() * (s.length() + 1));
107         int chpos = static_cast<int>(randomNumber() * randomChars.length());
108         s.insert(randomChars.substring(chpos, 1), pos);
109     }
110     DEFINE_STATIC_LOCAL(String, spaceChar, (" "));
111     for (uint32_t i = 0; i < space; i++) {
112         int pos = static_cast<int>(randomNumber() * (s.length() - 1)) + 1;
113         s.insert(spaceChar, pos);
114     }
115     ASSERT(s[0] != ' ');
116     ASSERT(s[s.length() - 1] != ' ');
117     key = s;
118 }
119
120 static void generateKey3(unsigned char key3[8])
121 {
122     for (int i = 0; i < 8; i++)
123         key3[i] = randomNumber() * 256;
124 }
125
126 static void setChallengeNumber(unsigned char* buf, uint32_t number)
127 {
128     unsigned char* p = buf + 3;
129     for (int i = 0; i < 4; i++) {
130         *p = number & 0xFF;
131         --p;
132         number >>= 8;
133     }
134 }
135
136 static void generateExpectedChallengeResponse(uint32_t number1, uint32_t number2, unsigned char key3[8], unsigned char expectedChallenge[16])
137 {
138     unsigned char challenge[16];
139     setChallengeNumber(&challenge[0], number1);
140     setChallengeNumber(&challenge[4], number2);
141     memcpy(&challenge[8], key3, 8);
142     MD5 md5;
143     md5.addBytes(challenge, sizeof(challenge));
144     Vector<uint8_t, 16> digest;
145     md5.checksum(digest);
146     memcpy(expectedChallenge, digest.data(), 16);
147 }
148
149 WebSocketHandshake::WebSocketHandshake(const KURL& url, const String& protocol, ScriptExecutionContext* context)
150     : m_url(url)
151     , m_clientProtocol(protocol)
152     , m_secure(m_url.protocolIs("wss"))
153     , m_context(context)
154     , m_mode(Incomplete)
155 {
156     uint32_t number1;
157     uint32_t number2;
158     generateSecWebSocketKey(number1, m_secWebSocketKey1);
159     generateSecWebSocketKey(number2, m_secWebSocketKey2);
160     generateKey3(m_key3);
161     generateExpectedChallengeResponse(number1, number2, m_key3, m_expectedChallengeResponse);
162 }
163
164 WebSocketHandshake::~WebSocketHandshake()
165 {
166 }
167
168 const KURL& WebSocketHandshake::url() const
169 {
170     return m_url;
171 }
172
173 void WebSocketHandshake::setURL(const KURL& url)
174 {
175     m_url = url.copy();
176 }
177
178 const String WebSocketHandshake::host() const
179 {
180     return m_url.host().lower();
181 }
182
183 const String& WebSocketHandshake::clientProtocol() const
184 {
185     return m_clientProtocol;
186 }
187
188 void WebSocketHandshake::setClientProtocol(const String& protocol)
189 {
190     m_clientProtocol = protocol;
191 }
192
193 bool WebSocketHandshake::secure() const
194 {
195     return m_secure;
196 }
197
198 String WebSocketHandshake::clientOrigin() const
199 {
200     return m_context->securityOrigin()->toString();
201 }
202
203 String WebSocketHandshake::clientLocation() const
204 {
205     StringBuilder builder;
206     builder.append(m_secure ? "wss" : "ws");
207     builder.append("://");
208     builder.append(hostName(m_url, m_secure));
209     builder.append(resourceName(m_url));
210     return builder.toString();
211 }
212
213 CString WebSocketHandshake::clientHandshakeMessage() const
214 {
215     // Keep the following consistent with clientHandshakeRequest().
216     StringBuilder builder;
217
218     builder.append("GET ");
219     builder.append(resourceName(m_url));
220     builder.append(" HTTP/1.1\r\n");
221
222     Vector<String> fields;
223     fields.append("Upgrade: WebSocket");
224     fields.append("Connection: Upgrade");
225     fields.append("Host: " + hostName(m_url, m_secure));
226     fields.append("Origin: " + clientOrigin());
227     if (!m_clientProtocol.isEmpty())
228         fields.append("Sec-WebSocket-Protocol: " + m_clientProtocol);
229
230     KURL url = httpURLForAuthenticationAndCookies();
231     if (m_context->isDocument()) {
232         Document* document = static_cast<Document*>(m_context);
233         String cookie = cookieRequestHeaderFieldValue(document, url);
234         if (!cookie.isEmpty())
235             fields.append("Cookie: " + cookie);
236         // Set "Cookie2: <cookie>" if cookies 2 exists for url?
237     }
238
239     fields.append("Sec-WebSocket-Key1: " + m_secWebSocketKey1);
240     fields.append("Sec-WebSocket-Key2: " + m_secWebSocketKey2);
241
242     // Fields in the handshake are sent by the client in a random order; the
243     // order is not meaningful.  Thus, it's ok to send the order we constructed
244     // the fields.
245
246     for (size_t i = 0; i < fields.size(); i++) {
247         builder.append(fields[i]);
248         builder.append("\r\n");
249     }
250
251     builder.append("\r\n");
252
253     CString handshakeHeader = builder.toString().utf8();
254     char* characterBuffer = 0;
255     CString msg = CString::newUninitialized(handshakeHeader.length() + sizeof(m_key3), characterBuffer);
256     memcpy(characterBuffer, handshakeHeader.data(), handshakeHeader.length());
257     memcpy(characterBuffer + handshakeHeader.length(), m_key3, sizeof(m_key3));
258     return msg;
259 }
260
261 WebSocketHandshakeRequest WebSocketHandshake::clientHandshakeRequest() const
262 {
263     // Keep the following consistent with clientHandshakeMessage().
264     // FIXME: do we need to store m_secWebSocketKey1, m_secWebSocketKey2 and
265     // m_key3 in WebSocketHandshakeRequest?
266     WebSocketHandshakeRequest request("GET", m_url);
267     request.addHeaderField("Upgrade", "WebSocket");
268     request.addHeaderField("Connection", "Upgrade");
269     request.addHeaderField("Host", hostName(m_url, m_secure));
270     request.addHeaderField("Origin", clientOrigin());
271     if (!m_clientProtocol.isEmpty())
272         request.addHeaderField("Sec-WebSocket-Protocol:", m_clientProtocol);
273
274     KURL url = httpURLForAuthenticationAndCookies();
275     if (m_context->isDocument()) {
276         Document* document = static_cast<Document*>(m_context);
277         String cookie = cookieRequestHeaderFieldValue(document, url);
278         if (!cookie.isEmpty())
279             request.addHeaderField("Cookie", cookie);
280         // Set "Cookie2: <cookie>" if cookies 2 exists for url?
281     }
282
283     request.addHeaderField("Sec-WebSocket-Key1", m_secWebSocketKey1);
284     request.addHeaderField("Sec-WebSocket-Key2", m_secWebSocketKey2);
285     request.setKey3(m_key3);
286
287     return request;
288 }
289
290 void WebSocketHandshake::reset()
291 {
292     m_mode = Incomplete;
293
294     m_wsOrigin = String();
295     m_wsLocation = String();
296     m_wsProtocol = String();
297     m_setCookie = String();
298     m_setCookie2 = String();
299 }
300
301 void WebSocketHandshake::clearScriptExecutionContext()
302 {
303     m_context = 0;
304 }
305
306 int WebSocketHandshake::readServerHandshake(const char* header, size_t len)
307 {
308     m_mode = Incomplete;
309     int statusCode;
310     String statusText;
311     int lineLength = readStatusLine(header, len, statusCode, statusText);
312     if (lineLength == -1)
313         return -1;
314     if (statusCode == -1) {
315         m_mode = Failed;
316         return len;
317     }
318     LOG(Network, "response code: %d", statusCode);
319     m_response.setStatusCode(statusCode);
320     m_response.setStatusText(statusText);
321     if (statusCode != 101) {
322         m_mode = Failed;
323         m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, makeString("Unexpected response code: ", String::number(statusCode)), 0, clientOrigin(), 0);
324         return len;
325     }
326     m_mode = Normal;
327     if (!strnstr(header, "\r\n\r\n", len)) {
328         // Just hasn't been received fully yet.
329         m_mode = Incomplete;
330         return -1;
331     }
332     const char* p = readHTTPHeaders(header + lineLength, header + len);
333     if (!p) {
334         LOG(Network, "readHTTPHeaders failed");
335         m_mode = Failed;
336         return len;
337     }
338     processHeaders();
339     if (!checkResponseHeaders()) {
340         LOG(Network, "header process failed");
341         m_mode = Failed;
342         return p - header;
343     }
344     if (len < static_cast<size_t>(p - header + sizeof(m_expectedChallengeResponse))) {
345         // Just hasn't been received /expected/ yet.
346         m_mode = Incomplete;
347         return -1;
348     }
349     m_response.setChallengeResponse(static_cast<const unsigned char*>(static_cast<const void*>(p)));
350     if (memcmp(p, m_expectedChallengeResponse, sizeof(m_expectedChallengeResponse))) {
351         m_mode = Failed;
352         return (p - header) + sizeof(m_expectedChallengeResponse);
353     }
354     m_mode = Connected;
355     return (p - header) + sizeof(m_expectedChallengeResponse);
356 }
357
358 WebSocketHandshake::Mode WebSocketHandshake::mode() const
359 {
360     return m_mode;
361 }
362
363 const String& WebSocketHandshake::serverWebSocketOrigin() const
364 {
365     return m_wsOrigin;
366 }
367
368 void WebSocketHandshake::setServerWebSocketOrigin(const String& webSocketOrigin)
369 {
370     m_wsOrigin = webSocketOrigin;
371 }
372
373 const String& WebSocketHandshake::serverWebSocketLocation() const
374 {
375     return m_wsLocation;
376 }
377
378 void WebSocketHandshake::setServerWebSocketLocation(const String& webSocketLocation)
379 {
380     m_wsLocation = webSocketLocation;
381 }
382
383 const String& WebSocketHandshake::serverWebSocketProtocol() const
384 {
385     return m_wsProtocol;
386 }
387
388 void WebSocketHandshake::setServerWebSocketProtocol(const String& webSocketProtocol)
389 {
390     m_wsProtocol = webSocketProtocol;
391 }
392
393 const String& WebSocketHandshake::serverSetCookie() const
394 {
395     return m_setCookie;
396 }
397
398 void WebSocketHandshake::setServerSetCookie(const String& setCookie)
399 {
400     m_setCookie = setCookie;
401 }
402
403 const String& WebSocketHandshake::serverSetCookie2() const
404 {
405     return m_setCookie2;
406 }
407
408 void WebSocketHandshake::setServerSetCookie2(const String& setCookie2)
409 {
410     m_setCookie2 = setCookie2;
411 }
412
413 const WebSocketHandshakeResponse& WebSocketHandshake::serverHandshakeResponse() const
414 {
415     return m_response;
416 }
417
418 KURL WebSocketHandshake::httpURLForAuthenticationAndCookies() const
419 {
420     KURL url = m_url.copy();
421     bool couldSetProtocol = url.setProtocol(m_secure ? "https" : "http");
422     ASSERT_UNUSED(couldSetProtocol, couldSetProtocol);
423     return url;
424 }
425
426 // Returns the header length (including "\r\n"), or -1 if we have not received enough data yet.
427 // If the line is malformed or the status code is not a 3-digit number,
428 // statusCode and statusText will be set to -1 and a null string, respectively.
429 int WebSocketHandshake::readStatusLine(const char* header, size_t headerLength, int& statusCode, String& statusText)
430 {
431     // Arbitrary size limit to prevent the server from sending an unbounded
432     // amount of data with no newlines and forcing us to buffer it all.
433     static const int maximumLength = 1024;
434
435     statusCode = -1;
436     statusText = String();
437
438     const char* space1 = 0;
439     const char* space2 = 0;
440     const char* p;
441     size_t consumedLength;
442
443     for (p = header, consumedLength = 0; consumedLength < headerLength; p++, consumedLength++) {
444         if (*p == ' ') {
445             if (!space1)
446                 space1 = p;
447             else if (!space2)
448                 space2 = p;
449         } else if (*p == '\0') {
450             // The caller isn't prepared to deal with null bytes in status
451             // line. WebSockets specification doesn't prohibit this, but HTTP
452             // does, so we'll just treat this as an error. 
453             m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Status line contains embedded null", 0, clientOrigin(), 0);
454             return p + 1 - header;
455         } else if (*p == '\n')
456             break;
457     }
458     if (consumedLength == headerLength)
459         return -1; // We have not received '\n' yet.
460
461     const char* end = p + 1;
462     if (end - header > maximumLength) {
463         m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Status line is too long", 0, clientOrigin(), 0);
464         return maximumLength;
465     }
466     int lineLength = end - header;
467
468     if (!space1 || !space2) {
469         m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "No response code found: " + trimConsoleMessage(header, lineLength - 1), 0, clientOrigin(), 0);
470         return lineLength;
471     }
472
473     // The line must end with "\r\n".
474     if (*(end - 2) != '\r') {
475         m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Status line does not end with CRLF", 0, clientOrigin(), 0);
476         return lineLength;
477     }
478
479     String statusCodeString(space1 + 1, space2 - space1 - 1);
480     if (statusCodeString.length() != 3) // Status code must consist of three digits.
481         return lineLength;
482     for (int i = 0; i < 3; ++i)
483         if (statusCodeString[i] < '0' || statusCodeString[i] > '9') {
484             m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Invalid status code: " + statusCodeString, 0, clientOrigin(), 0);
485             return lineLength;
486         }
487
488     bool ok = false;
489     statusCode = statusCodeString.toInt(&ok);
490     ASSERT(ok);
491
492     statusText = String(space2 + 1, end - space2 - 3); // Exclude "\r\n".
493     return lineLength;
494 }
495
496 const char* WebSocketHandshake::readHTTPHeaders(const char* start, const char* end)
497 {
498     m_response.clearHeaderFields();
499
500     Vector<char> name;
501     Vector<char> value;
502     for (const char* p = start; p < end; p++) {
503         name.clear();
504         value.clear();
505
506         for (; p < end; p++) {
507             switch (*p) {
508             case '\r':
509                 if (name.isEmpty()) {
510                     if (p + 1 < end && *(p + 1) == '\n')
511                         return p + 2;
512                     m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "CR doesn't follow LF at " + trimConsoleMessage(p, end - p), 0, clientOrigin(), 0);
513                     return 0;
514                 }
515                 m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected CR in name at " + trimConsoleMessage(name.data(), name.size()), 0, clientOrigin(), 0);
516                 return 0;
517             case '\n':
518                 m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected LF in name at " + trimConsoleMessage(name.data(), name.size()), 0, clientOrigin(), 0);
519                 return 0;
520             case ':':
521                 break;
522             default:
523                 name.append(*p);
524                 continue;
525             }
526             if (*p == ':') {
527                 ++p;
528                 break;
529             }
530         }
531
532         for (; p < end && *p == 0x20; p++) { }
533
534         for (; p < end; p++) {
535             switch (*p) {
536             case '\r':
537                 break;
538             case '\n':
539                 m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected LF in value at " + trimConsoleMessage(value.data(), value.size()), 0, clientOrigin(), 0);
540                 return 0;
541             default:
542                 value.append(*p);
543             }
544             if (*p == '\r') {
545                 ++p;
546                 break;
547             }
548         }
549         if (p >= end || *p != '\n') {
550             m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "CR doesn't follow LF after value at " + trimConsoleMessage(p, end - p), 0, clientOrigin(), 0);
551             return 0;
552         }
553         AtomicString nameStr(String::fromUTF8(name.data(), name.size()));
554         String valueStr = String::fromUTF8(value.data(), value.size());
555         if (nameStr.isNull()) {
556             m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "invalid UTF-8 sequence in header name", 0, clientOrigin(), 0);
557             return 0;
558         }
559         if (valueStr.isNull()) {
560             m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "invalid UTF-8 sequence in header value", 0, clientOrigin(), 0);
561             return 0;
562         }
563         LOG(Network, "name=%s value=%s", nameStr.string().utf8().data(), valueStr.utf8().data());
564         m_response.addHeaderField(nameStr, valueStr);
565     }
566     ASSERT_NOT_REACHED();
567     return 0;
568 }
569
570 void WebSocketHandshake::processHeaders()
571 {
572     ASSERT(m_mode == Normal);
573     const HTTPHeaderMap& headers = m_response.headerFields();
574     m_wsOrigin = headers.get("sec-websocket-origin");
575     m_wsLocation = headers.get("sec-websocket-location");
576     m_wsProtocol = headers.get("sec-websocket-protocol");
577     m_setCookie = headers.get("set-cookie");
578     m_setCookie2 = headers.get("set-cookie2");
579 }
580
581 bool WebSocketHandshake::checkResponseHeaders()
582 {
583     if (m_wsOrigin.isNull()) {
584         m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: 'sec-websocket-origin' header is missing", 0, clientOrigin(), 0);
585         return false;
586     }
587     if (m_wsLocation.isNull()) {
588         m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: 'sec-websocket-location' header is missing", 0, clientOrigin(), 0);
589         return false;
590     }
591
592     if (clientOrigin() != m_wsOrigin) {
593         m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: origin mismatch: " + clientOrigin() + " != " + m_wsOrigin, 0, clientOrigin(), 0);
594         return false;
595     }
596     if (clientLocation() != m_wsLocation) {
597         m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: location mismatch: " + clientLocation() + " != " + m_wsLocation, 0, clientOrigin(), 0);
598         return false;
599     }
600     if (!m_clientProtocol.isEmpty() && m_clientProtocol != m_wsProtocol) {
601         m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: protocol mismatch: " + m_clientProtocol + " != " + m_wsProtocol, 0, clientOrigin(), 0);
602         return false;
603     }
604     return true;
605 }
606
607 } // namespace WebCore
608
609 #endif // ENABLE(WEB_SOCKETS)