691fa1cdc4da9e6a3198fb587bc5005cfa36aa7d
[WebKit-https.git] / WebCore / websockets / WebSocketHandshake.cpp
1 /*
2  * Copyright (C) 2009 Google Inc.  All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions are
6  * met:
7  *
8  *     * Redistributions of source code must retain the above copyright
9  * notice, this list of conditions and the following disclaimer.
10  *     * Redistributions in binary form must reproduce the above
11  * copyright notice, this list of conditions and the following disclaimer
12  * in the documentation and/or other materials provided with the
13  * distribution.
14  *     * Neither the name of Google Inc. nor the names of its
15  * contributors may be used to endorse or promote products derived from
16  * this software without specific prior written permission.
17  *
18  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29  */
30
31 #include "config.h"
32
33 #if ENABLE(WEB_SOCKETS)
34
35 #include "WebSocketHandshake.h"
36
37 #include "AtomicString.h"
38 #include "CString.h"
39 #include "CookieJar.h"
40 #include "Document.h"
41 #include "HTTPHeaderMap.h"
42 #include "KURL.h"
43 #include "Logging.h"
44 #include "ScriptExecutionContext.h"
45 #include "SecurityOrigin.h"
46 #include "StringBuilder.h"
47 #include <wtf/StringExtras.h>
48 #include <wtf/Vector.h>
49
50 namespace WebCore {
51
52 const char webSocketServerHandshakeHeader[] = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n";
53 const char webSocketUpgradeHeader[] = "Upgrade: WebSocket\r\n";
54 const char webSocketConnectionHeader[] = "Connection: Upgrade\r\n";
55
56 static String extractResponseCode(const char* header, int len)
57 {
58     const char* space1 = 0;
59     const char* space2 = 0;
60     const char* p;
61     for (p = header; p - header < len; p++) {
62         if (*p == ' ') {
63             if (!space1)
64                 space1 = p;
65             else if (!space2)
66                 space2 = p;
67         } else if (*p == '\n')
68             break;
69     }
70     if (p - header == len)
71         return String();
72     if (!space1 || !space2)
73         return "";
74     return String(space1 + 1, space2 - space1 - 1);
75 }
76
77 WebSocketHandshake::WebSocketHandshake(const KURL& url, const String& protocol, ScriptExecutionContext* context)
78     : m_url(url)
79     , m_clientProtocol(protocol)
80     , m_secure(m_url.protocolIs("wss"))
81     , m_context(context)
82     , m_mode(Incomplete)
83 {
84 }
85
86 WebSocketHandshake::~WebSocketHandshake()
87 {
88 }
89
90 const KURL& WebSocketHandshake::url() const
91 {
92     return m_url;
93 }
94
95 void WebSocketHandshake::setURL(const KURL& url)
96 {
97     m_url = url.copy();
98 }
99
100 const String WebSocketHandshake::host() const
101 {
102     return m_url.host().lower();
103 }
104
105 const String& WebSocketHandshake::clientProtocol() const
106 {
107     return m_clientProtocol;
108 }
109
110 void WebSocketHandshake::setClientProtocol(const String& protocol)
111 {
112     m_clientProtocol = protocol;
113 }
114
115 bool WebSocketHandshake::secure() const
116 {
117     return m_secure;
118 }
119
120 void WebSocketHandshake::setSecure(bool secure)
121 {
122     m_secure = secure;
123 }
124
125 String WebSocketHandshake::clientOrigin() const
126 {
127     return m_context->securityOrigin()->toString();
128 }
129
130 String WebSocketHandshake::clientLocation() const
131 {
132     StringBuilder builder;
133     builder.append(m_secure ? "wss" : "ws");
134     builder.append("://");
135     builder.append(m_url.host().lower());
136     if (m_url.port()) {
137         if ((!m_secure && m_url.port() != 80) || (m_secure && m_url.port() != 443)) {
138             builder.append(":");
139             builder.append(String::number(m_url.port()));
140         }
141     }
142     builder.append(m_url.path());
143     return builder.toString();
144 }
145
146 CString WebSocketHandshake::clientHandshakeMessage() const
147 {
148     StringBuilder builder;
149
150     builder.append("GET ");
151     builder.append(m_url.path());
152     if (!m_url.query().isEmpty()) {
153         builder.append("?");
154         builder.append(m_url.query());
155     }
156     builder.append(" HTTP/1.1\r\n");
157     builder.append("Upgrade: WebSocket\r\n");
158     builder.append("Connection: Upgrade\r\n");
159     builder.append("Host: ");
160     builder.append(m_url.host().lower());
161     if (m_url.port()) {
162         if ((!m_secure && m_url.port() != 80) || (m_secure && m_url.port() != 443)) {
163             builder.append(":");
164             builder.append(String::number(m_url.port()));
165         }
166     }
167     builder.append("\r\n");
168     builder.append("Origin: ");
169     builder.append(clientOrigin());
170     builder.append("\r\n");
171     if (!m_clientProtocol.isEmpty()) {
172         builder.append("WebSocket-Protocol: ");
173         builder.append(m_clientProtocol);
174         builder.append("\r\n");
175     }
176     KURL url = httpURLForAuthenticationAndCookies();
177     // FIXME: set authentication information or cookies for url.
178     // Set "Authorization: <credentials>" if authentication information exists for url.
179     if (m_context->isDocument()) {
180         Document* document = static_cast<Document*>(m_context);
181         String cookie = cookies(document, url);
182         if (!cookie.isEmpty()) {
183             builder.append("Cookie: ");
184             builder.append(cookie);
185             builder.append("\r\n");
186         }
187         // Set "Cookie2: <cookie>" if cookies 2 exists for url?
188     }
189     builder.append("\r\n");
190     return builder.toString().utf8();
191 }
192
193 void WebSocketHandshake::reset()
194 {
195     m_mode = Incomplete;
196
197     m_wsOrigin = String();
198     m_wsLocation = String();
199     m_wsProtocol = String();
200     m_setCookie = String();
201     m_setCookie2 = String();
202 }
203
204 int WebSocketHandshake::readServerHandshake(const char* header, size_t len)
205 {
206     m_mode = Incomplete;
207     if (len < sizeof(webSocketServerHandshakeHeader) - 1) {
208         // Just hasn't been received fully yet.
209         return -1;
210     }
211     if (!memcmp(header, webSocketServerHandshakeHeader, sizeof(webSocketServerHandshakeHeader) - 1))
212         m_mode = Normal;
213     else {
214         const String& code = extractResponseCode(header, len);
215         if (code.isNull()) {
216             LOG(Network, "short server handshake: %s", header);
217             return -1;
218         }
219         if (code.isEmpty()) {
220             LOG(Network, "no response code found: %s", header);
221             return len;
222         }
223         LOG(Network, "response code: %s", code.utf8().data());
224         if (code == "401") {
225             LOG(Network, "Authentication required");
226             return len;
227         } else {
228             LOG(Network, "Mismatch server handshake: %s", header);
229             return len;
230         }
231     }
232     const char* p = header + sizeof(webSocketServerHandshakeHeader) - 1;
233     const char* end = header + len + 1;
234
235     if (m_mode == Normal) {
236         size_t headerSize = end - p;
237         if (headerSize < sizeof(webSocketUpgradeHeader) - 1)
238             return 0;
239         if (memcmp(p, webSocketUpgradeHeader, sizeof(webSocketUpgradeHeader) - 1)) {
240             LOG(Network, "Bad upgrade header: %s", p);
241             return p - header + sizeof(webSocketUpgradeHeader) - 1;
242         }
243         p += sizeof(webSocketUpgradeHeader) - 1;
244
245         headerSize = end - p;
246         if (headerSize < sizeof(webSocketConnectionHeader) - 1)
247             return -1;
248         if (memcmp(p, webSocketConnectionHeader, sizeof(webSocketConnectionHeader) - 1)) {
249             LOG(Network, "Bad connection header: %s", p);
250             return p - header + sizeof(webSocketConnectionHeader) - 1;
251         }
252         p += sizeof(webSocketConnectionHeader) - 1;
253     }
254
255     if (!strnstr(p, "\r\n\r\n", end - p)) {
256         // Just hasn't been received fully yet.
257         return -1;
258     }
259     HTTPHeaderMap headers;
260     p = readHTTPHeaders(p, end, &headers);
261     if (!p) {
262         LOG(Network, "readHTTPHeaders failed");
263         m_mode = Failed;
264         return len;
265     }
266     if (!processHeaders(headers)) {
267         LOG(Network, "header process failed");
268         m_mode = Failed;
269         return p - header;
270     }
271     switch (m_mode) {
272     case Normal:
273         checkResponseHeaders();
274         break;
275     default:
276         m_mode = Failed;
277         break;
278     }
279     return p - header;
280 }
281
282 WebSocketHandshake::Mode WebSocketHandshake::mode() const
283 {
284     return m_mode;
285 }
286
287 const String& WebSocketHandshake::serverWebSocketOrigin() const
288 {
289     return m_wsOrigin;
290 }
291
292 void WebSocketHandshake::setServerWebSocketOrigin(const String& webSocketOrigin)
293 {
294     m_wsOrigin = webSocketOrigin;
295 }
296
297 const String& WebSocketHandshake::serverWebSocketLocation() const
298 {
299     return m_wsLocation;
300 }
301
302 void WebSocketHandshake::setServerWebSocketLocation(const String& webSocketLocation)
303 {
304     m_wsLocation = webSocketLocation;
305 }
306
307 const String& WebSocketHandshake::serverWebSocketProtocol() const
308 {
309     return m_wsProtocol;
310 }
311
312 void WebSocketHandshake::setServerWebSocketProtocol(const String& webSocketProtocol)
313 {
314     m_wsProtocol = webSocketProtocol;
315 }
316
317 const String& WebSocketHandshake::serverSetCookie() const
318 {
319     return m_setCookie;
320 }
321
322 void WebSocketHandshake::setServerSetCookie(const String& setCookie)
323 {
324     m_setCookie = setCookie;
325 }
326
327 const String& WebSocketHandshake::serverSetCookie2() const
328 {
329     return m_setCookie2;
330 }
331
332 void WebSocketHandshake::setServerSetCookie2(const String& setCookie2)
333 {
334     m_setCookie2 = setCookie2;
335 }
336
337 KURL WebSocketHandshake::httpURLForAuthenticationAndCookies() const
338 {
339     KURL url = m_url.copy();
340     url.setProtocol(m_secure ? "https" : "http");
341     return url;
342 }
343
344 const char* WebSocketHandshake::readHTTPHeaders(const char* start, const char* end, HTTPHeaderMap* headers)
345 {
346     Vector<char> name;
347     Vector<char> value;
348     for (const char* p = start; p < end; p++) {
349         name.clear();
350         value.clear();
351
352         for (; p < end; p++) {
353             switch (*p) {
354             case '\r':
355                 if (name.isEmpty()) {
356                     if (p + 1 < end && *(p + 1) == '\n')
357                         return p + 2;
358                     LOG(Network, "CR doesn't follow LF p=%p end=%p", p, end);
359                     return 0;
360                 }
361                 LOG(Network, "Unexpected CR in name");
362                 return 0;
363             case '\n':
364                 LOG(Network, "Unexpected LF in name");
365                 return 0;
366             case ':':
367                 break;
368             default:
369                 if (*p >= 0x41 && *p <= 0x5a)
370                     name.append(*p + 0x20);
371                 else
372                     name.append(*p);
373                 continue;
374             }
375             if (*p == ':') {
376                 ++p;
377                 break;
378             }
379         }
380
381         for (; p < end && *p == 0x20; p++) { }
382
383         for (; p < end; p++) {
384             switch (*p) {
385             case '\r':
386                 break;
387             case '\n':
388                 LOG(Network, "Unexpected LF in value");
389                 return 0;
390             default:
391                 value.append(*p);
392             }
393             if (*p == '\r') {
394                 ++p;
395                 break;
396             }
397         }
398         if (p >= end || *p != '\n') {
399             LOG(Network, "CR doesn't follow LF after value p=%p end=%p", p, end);
400             return 0;
401         }
402         AtomicString nameStr(String::fromUTF8(name.data(), name.size()));
403         String valueStr = String::fromUTF8(value.data(), value.size());
404         LOG(Network, "name=%s value=%s", nameStr.string().utf8().data(), valueStr.utf8().data());
405         headers->add(nameStr, valueStr);
406     }
407     ASSERT_NOT_REACHED();
408     return 0;
409 }
410
411 bool WebSocketHandshake::processHeaders(const HTTPHeaderMap& headers)
412 {
413     for (HTTPHeaderMap::const_iterator it = headers.begin(); it != headers.end(); ++it) {
414         switch (m_mode) {
415         case Normal:
416             if (it->first == "websocket-origin")
417                 m_wsOrigin = it->second;
418             else if (it->first == "websocket-location")
419                 m_wsLocation = it->second;
420             else if (it->first == "websocket-protocol")
421                 m_wsProtocol = it->second;
422             else if (it->first == "set-cookie")
423                 m_setCookie = it->second;
424             else if (it->first == "set-cookie2")
425                 m_setCookie2 = it->second;
426             continue;
427         case Incomplete:
428         case Failed:
429         case Connected:
430             ASSERT_NOT_REACHED();
431         }
432         ASSERT_NOT_REACHED();
433     }
434     return true;
435 }
436
437 void WebSocketHandshake::checkResponseHeaders()
438 {
439     ASSERT(m_mode == Normal);
440     m_mode = Failed;
441     if (m_wsOrigin.isNull() || m_wsLocation.isNull())
442         return;
443
444     if (clientOrigin() != m_wsOrigin) {
445         LOG(Network, "Mismatch origin: %s != %s", clientOrigin().utf8().data(), m_wsOrigin.utf8().data());
446         return;
447     }
448     if (clientLocation() != m_wsLocation) {
449         LOG(Network, "Mismatch location: %s != %s", clientLocation().utf8().data(), m_wsLocation.utf8().data());
450         return;
451     }
452     if (!m_clientProtocol.isEmpty() && m_clientProtocol != m_wsProtocol) {
453         LOG(Network, "Mismatch protocol: %s != %s", m_clientProtocol.utf8().data(), m_wsProtocol.utf8().data());
454         return;
455     }
456     m_mode = Connected;
457     return;
458 }
459
460 }  // namespace WebCore
461
462 #endif  // ENABLE(WEB_SOCKETS)