2009-12-16 Fumitoshi Ukai <ukai@chromium.org>
[WebKit-https.git] / WebCore / websockets / WebSocketHandshake.cpp
1 /*
2  * Copyright (C) 2009 Google Inc.  All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions are
6  * met:
7  *
8  *     * Redistributions of source code must retain the above copyright
9  * notice, this list of conditions and the following disclaimer.
10  *     * Redistributions in binary form must reproduce the above
11  * copyright notice, this list of conditions and the following disclaimer
12  * in the documentation and/or other materials provided with the
13  * distribution.
14  *     * Neither the name of Google Inc. nor the names of its
15  * contributors may be used to endorse or promote products derived from
16  * this software without specific prior written permission.
17  *
18  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29  */
30
31 #include "config.h"
32
33 #if ENABLE(WEB_SOCKETS)
34
35 #include "WebSocketHandshake.h"
36
37 #include "AtomicString.h"
38 #include "CString.h"
39 #include "CookieJar.h"
40 #include "Document.h"
41 #include "HTTPHeaderMap.h"
42 #include "KURL.h"
43 #include "Logging.h"
44 #include "ScriptExecutionContext.h"
45 #include "SecurityOrigin.h"
46 #include "StringBuilder.h"
47 #include <wtf/StringExtras.h>
48 #include <wtf/Vector.h>
49
50 namespace WebCore {
51
52 const char webSocketServerHandshakeHeader[] = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n";
53 const char webSocketUpgradeHeader[] = "Upgrade: WebSocket\r\n";
54 const char webSocketConnectionHeader[] = "Connection: Upgrade\r\n";
55
56 static String extractResponseCode(const char* header, int len)
57 {
58     const char* space1 = 0;
59     const char* space2 = 0;
60     const char* p;
61     for (p = header; p - header < len; p++) {
62         if (*p == ' ') {
63             if (!space1)
64                 space1 = p;
65             else if (!space2)
66                 space2 = p;
67         } else if (*p == '\n')
68             break;
69     }
70     if (p - header == len)
71         return String();
72     if (!space1 || !space2)
73         return "";
74     return String(space1 + 1, space2 - space1 - 1);
75 }
76
77 static String resourceName(const KURL& url)
78 {
79     if (url.query().isNull())
80         return url.path();
81     return url.path() + "?" + url.query();
82 }
83
84 WebSocketHandshake::WebSocketHandshake(const KURL& url, const String& protocol, ScriptExecutionContext* context)
85     : m_url(url)
86     , m_clientProtocol(protocol)
87     , m_secure(m_url.protocolIs("wss"))
88     , m_context(context)
89     , m_mode(Incomplete)
90 {
91 }
92
93 WebSocketHandshake::~WebSocketHandshake()
94 {
95 }
96
97 const KURL& WebSocketHandshake::url() const
98 {
99     return m_url;
100 }
101
102 void WebSocketHandshake::setURL(const KURL& url)
103 {
104     m_url = url.copy();
105 }
106
107 const String WebSocketHandshake::host() const
108 {
109     return m_url.host().lower();
110 }
111
112 const String& WebSocketHandshake::clientProtocol() const
113 {
114     return m_clientProtocol;
115 }
116
117 void WebSocketHandshake::setClientProtocol(const String& protocol)
118 {
119     m_clientProtocol = protocol;
120 }
121
122 bool WebSocketHandshake::secure() const
123 {
124     return m_secure;
125 }
126
127 void WebSocketHandshake::setSecure(bool secure)
128 {
129     m_secure = secure;
130 }
131
132 String WebSocketHandshake::clientOrigin() const
133 {
134     return m_context->securityOrigin()->toString();
135 }
136
137 String WebSocketHandshake::clientLocation() const
138 {
139     StringBuilder builder;
140     builder.append(m_secure ? "wss" : "ws");
141     builder.append("://");
142     builder.append(m_url.host().lower());
143     if (m_url.port()) {
144         if ((!m_secure && m_url.port() != 80) || (m_secure && m_url.port() != 443)) {
145             builder.append(":");
146             builder.append(String::number(m_url.port()));
147         }
148     }
149     builder.append(resourceName(m_url));
150     return builder.toString();
151 }
152
153 CString WebSocketHandshake::clientHandshakeMessage() const
154 {
155     StringBuilder builder;
156
157     builder.append("GET ");
158     builder.append(resourceName(m_url));
159     builder.append(" HTTP/1.1\r\n");
160     builder.append("Upgrade: WebSocket\r\n");
161     builder.append("Connection: Upgrade\r\n");
162     builder.append("Host: ");
163     builder.append(m_url.host().lower());
164     if (m_url.port()) {
165         if ((!m_secure && m_url.port() != 80) || (m_secure && m_url.port() != 443)) {
166             builder.append(":");
167             builder.append(String::number(m_url.port()));
168         }
169     }
170     builder.append("\r\n");
171     builder.append("Origin: ");
172     builder.append(clientOrigin());
173     builder.append("\r\n");
174     if (!m_clientProtocol.isEmpty()) {
175         builder.append("WebSocket-Protocol: ");
176         builder.append(m_clientProtocol);
177         builder.append("\r\n");
178     }
179     KURL url = httpURLForAuthenticationAndCookies();
180     // FIXME: set authentication information or cookies for url.
181     // Set "Authorization: <credentials>" if authentication information exists for url.
182     if (m_context->isDocument()) {
183         Document* document = static_cast<Document*>(m_context);
184         String cookie = cookies(document, url);
185         if (!cookie.isEmpty()) {
186             builder.append("Cookie: ");
187             builder.append(cookie);
188             builder.append("\r\n");
189         }
190         // Set "Cookie2: <cookie>" if cookies 2 exists for url?
191     }
192     builder.append("\r\n");
193     return builder.toString().utf8();
194 }
195
196 void WebSocketHandshake::reset()
197 {
198     m_mode = Incomplete;
199
200     m_wsOrigin = String();
201     m_wsLocation = String();
202     m_wsProtocol = String();
203     m_setCookie = String();
204     m_setCookie2 = String();
205 }
206
207 int WebSocketHandshake::readServerHandshake(const char* header, size_t len)
208 {
209     m_mode = Incomplete;
210     if (len < sizeof(webSocketServerHandshakeHeader) - 1) {
211         // Just hasn't been received fully yet.
212         return -1;
213     }
214     if (!memcmp(header, webSocketServerHandshakeHeader, sizeof(webSocketServerHandshakeHeader) - 1))
215         m_mode = Normal;
216     else {
217         const String& code = extractResponseCode(header, len);
218         if (code.isNull()) {
219             m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Short server handshake: " + String(header, len), 0, clientOrigin());
220             return -1;
221         }
222         if (code.isEmpty()) {
223             m_mode = Failed;
224             m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "No response code found: " + String(header, len), 0, clientOrigin());
225             return len;
226         }
227         LOG(Network, "response code: %s", code.utf8().data());
228         if (code == "401") {
229             m_mode = Failed;
230             m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Authentication required, but not implemented yet.", 0, clientOrigin());
231             return len;
232         } else {
233             m_mode = Failed;
234             m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected response code:" + code, 0, clientOrigin());
235             return len;
236         }
237     }
238     const char* p = header + sizeof(webSocketServerHandshakeHeader) - 1;
239     const char* end = header + len + 1;
240
241     if (m_mode == Normal) {
242         size_t headerSize = end - p;
243         if (headerSize < sizeof(webSocketUpgradeHeader) - 1) {
244             m_mode = Incomplete;
245             return 0;
246         }
247         if (memcmp(p, webSocketUpgradeHeader, sizeof(webSocketUpgradeHeader) - 1)) {
248             m_mode = Failed;
249             m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Bad Upgrade header: " + String(p, end - p), 0, clientOrigin());
250             return p - header + sizeof(webSocketUpgradeHeader) - 1;
251         }
252         p += sizeof(webSocketUpgradeHeader) - 1;
253
254         headerSize = end - p;
255         if (headerSize < sizeof(webSocketConnectionHeader) - 1) {
256             m_mode = Incomplete;
257             return -1;
258         }
259         if (memcmp(p, webSocketConnectionHeader, sizeof(webSocketConnectionHeader) - 1)) {
260             m_mode = Failed;
261             m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Bad Connection header: " + String(p, end - p), 0, clientOrigin());
262             return p - header + sizeof(webSocketConnectionHeader) - 1;
263         }
264         p += sizeof(webSocketConnectionHeader) - 1;
265     }
266
267     if (!strnstr(p, "\r\n\r\n", end - p)) {
268         // Just hasn't been received fully yet.
269         m_mode = Incomplete;
270         return -1;
271     }
272     HTTPHeaderMap headers;
273     p = readHTTPHeaders(p, end, &headers);
274     if (!p) {
275         LOG(Network, "readHTTPHeaders failed");
276         m_mode = Failed;
277         return len;
278     }
279     if (!processHeaders(headers)) {
280         LOG(Network, "header process failed");
281         m_mode = Failed;
282         return p - header;
283     }
284     switch (m_mode) {
285     case Normal:
286         checkResponseHeaders();
287         break;
288     default:
289         m_mode = Failed;
290         break;
291     }
292     return p - header;
293 }
294
295 WebSocketHandshake::Mode WebSocketHandshake::mode() const
296 {
297     return m_mode;
298 }
299
300 const String& WebSocketHandshake::serverWebSocketOrigin() const
301 {
302     return m_wsOrigin;
303 }
304
305 void WebSocketHandshake::setServerWebSocketOrigin(const String& webSocketOrigin)
306 {
307     m_wsOrigin = webSocketOrigin;
308 }
309
310 const String& WebSocketHandshake::serverWebSocketLocation() const
311 {
312     return m_wsLocation;
313 }
314
315 void WebSocketHandshake::setServerWebSocketLocation(const String& webSocketLocation)
316 {
317     m_wsLocation = webSocketLocation;
318 }
319
320 const String& WebSocketHandshake::serverWebSocketProtocol() const
321 {
322     return m_wsProtocol;
323 }
324
325 void WebSocketHandshake::setServerWebSocketProtocol(const String& webSocketProtocol)
326 {
327     m_wsProtocol = webSocketProtocol;
328 }
329
330 const String& WebSocketHandshake::serverSetCookie() const
331 {
332     return m_setCookie;
333 }
334
335 void WebSocketHandshake::setServerSetCookie(const String& setCookie)
336 {
337     m_setCookie = setCookie;
338 }
339
340 const String& WebSocketHandshake::serverSetCookie2() const
341 {
342     return m_setCookie2;
343 }
344
345 void WebSocketHandshake::setServerSetCookie2(const String& setCookie2)
346 {
347     m_setCookie2 = setCookie2;
348 }
349
350 KURL WebSocketHandshake::httpURLForAuthenticationAndCookies() const
351 {
352     KURL url = m_url.copy();
353     url.setProtocol(m_secure ? "https" : "http");
354     return url;
355 }
356
357 const char* WebSocketHandshake::readHTTPHeaders(const char* start, const char* end, HTTPHeaderMap* headers)
358 {
359     Vector<char> name;
360     Vector<char> value;
361     for (const char* p = start; p < end; p++) {
362         name.clear();
363         value.clear();
364
365         for (; p < end; p++) {
366             switch (*p) {
367             case '\r':
368                 if (name.isEmpty()) {
369                     if (p + 1 < end && *(p + 1) == '\n')
370                         return p + 2;
371                     m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "CR doesn't follow LF at " + String(p, end - p), 0, clientOrigin());
372                     return 0;
373                 }
374                 m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected CR in name at " + String(p, end - p), 0, clientOrigin());
375                 return 0;
376             case '\n':
377                 m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected LF in name at " + String(p, end - p), 0, clientOrigin());
378                 return 0;
379             case ':':
380                 break;
381             default:
382                 if (*p >= 0x41 && *p <= 0x5a)
383                     name.append(*p + 0x20);
384                 else
385                     name.append(*p);
386                 continue;
387             }
388             if (*p == ':') {
389                 ++p;
390                 break;
391             }
392         }
393
394         for (; p < end && *p == 0x20; p++) { }
395
396         for (; p < end; p++) {
397             switch (*p) {
398             case '\r':
399                 break;
400             case '\n':
401                 m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected LF in value at " + String(p, end - p), 0, clientOrigin());
402                 return 0;
403             default:
404                 value.append(*p);
405             }
406             if (*p == '\r') {
407                 ++p;
408                 break;
409             }
410         }
411         if (p >= end || *p != '\n') {
412             m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "CR doesn't follow LF after value at " + String(p, end - p), 0, clientOrigin());
413             return 0;
414         }
415         AtomicString nameStr(String::fromUTF8(name.data(), name.size()));
416         String valueStr = String::fromUTF8(value.data(), value.size());
417         LOG(Network, "name=%s value=%s", nameStr.string().utf8().data(), valueStr.utf8().data());
418         headers->add(nameStr, valueStr);
419     }
420     ASSERT_NOT_REACHED();
421     return 0;
422 }
423
424 bool WebSocketHandshake::processHeaders(const HTTPHeaderMap& headers)
425 {
426     for (HTTPHeaderMap::const_iterator it = headers.begin(); it != headers.end(); ++it) {
427         switch (m_mode) {
428         case Normal:
429             if (it->first == "websocket-origin")
430                 m_wsOrigin = it->second;
431             else if (it->first == "websocket-location")
432                 m_wsLocation = it->second;
433             else if (it->first == "websocket-protocol")
434                 m_wsProtocol = it->second;
435             else if (it->first == "set-cookie")
436                 m_setCookie = it->second;
437             else if (it->first == "set-cookie2")
438                 m_setCookie2 = it->second;
439             continue;
440         case Incomplete:
441         case Failed:
442         case Connected:
443             ASSERT_NOT_REACHED();
444         }
445         ASSERT_NOT_REACHED();
446     }
447     return true;
448 }
449
450 void WebSocketHandshake::checkResponseHeaders()
451 {
452     ASSERT(m_mode == Normal);
453     m_mode = Failed;
454     if (m_wsOrigin.isNull()) {
455         m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: 'websocket-origin' header is missing", 0, clientOrigin());
456         return;
457     }
458     if (m_wsLocation.isNull()) {
459         m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: 'websocket-location' header is missing", 0, clientOrigin());
460         return;
461     }
462
463     if (clientOrigin() != m_wsOrigin) {
464         m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: origin mismatch: " + clientOrigin() + " != " + m_wsOrigin, 0, clientOrigin());
465         return;
466     }
467     if (clientLocation() != m_wsLocation) {
468         m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: location mismatch: " + clientLocation() + " != " + m_wsLocation, 0, clientOrigin());
469         return;
470     }
471     if (!m_clientProtocol.isEmpty() && m_clientProtocol != m_wsProtocol) {
472         m_context->addMessage(ConsoleDestination, JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: protocol mismatch: " + m_clientProtocol + " != " + m_wsProtocol, 0, clientOrigin());
473         return;
474     }
475     m_mode = Connected;
476     return;
477 }
478
479 }  // namespace WebCore
480
481 #endif  // ENABLE(WEB_SOCKETS)