WebSocket: Add useHixie76Protocol flag to WebSocketChannel and WebSocketHandshake
[WebKit-https.git] / Source / WebCore / websockets / WebSocketChannel.cpp
1 /*
2  * Copyright (C) 2011 Google Inc.  All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions are
6  * met:
7  *
8  *     * Redistributions of source code must retain the above copyright
9  * notice, this list of conditions and the following disclaimer.
10  *     * Redistributions in binary form must reproduce the above
11  * copyright notice, this list of conditions and the following disclaimer
12  * in the documentation and/or other materials provided with the
13  * distribution.
14  *     * Neither the name of Google Inc. nor the names of its
15  * contributors may be used to endorse or promote products derived from
16  * this software without specific prior written permission.
17  *
18  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29  */
30
31 #include "config.h"
32
33 #if ENABLE(WEB_SOCKETS)
34
35 #include "WebSocketChannel.h"
36
37 #include "CookieJar.h"
38 #include "Document.h"
39 #include "InspectorInstrumentation.h"
40 #include "Logging.h"
41 #include "Page.h"
42 #include "ProgressTracker.h"
43 #include "ScriptCallStack.h"
44 #include "ScriptExecutionContext.h"
45 #include "Settings.h"
46 #include "SocketStreamError.h"
47 #include "SocketStreamHandle.h"
48 #include "WebSocketChannelClient.h"
49 #include "WebSocketHandshake.h"
50
51 #include <wtf/text/CString.h>
52 #include <wtf/text/WTFString.h>
53 #include <wtf/text/StringHash.h>
54 #include <wtf/Deque.h>
55 #include <wtf/FastMalloc.h>
56 #include <wtf/HashMap.h>
57
58 namespace WebCore {
59
60 const double TCPMaximumSegmentLifetime = 2 * 60.0;
61
62 WebSocketChannel::WebSocketChannel(ScriptExecutionContext* context, WebSocketChannelClient* client, const KURL& url, const String& protocol)
63     : m_context(context)
64     , m_client(client)
65     , m_buffer(0)
66     , m_bufferSize(0)
67     , m_resumeTimer(this, &WebSocketChannel::resumeTimerFired)
68     , m_suspended(false)
69     , m_closing(false)
70     , m_receivedClosingHandshake(false)
71     , m_closingTimer(this, &WebSocketChannel::closingTimerFired)
72     , m_closed(false)
73     , m_shouldDiscardReceivedData(false)
74     , m_unhandledBufferedAmount(0)
75     , m_identifier(0)
76     , m_useHixie76Protocol(true)
77 {
78     ASSERT(m_context->isDocument());
79     Document* document = static_cast<Document*>(m_context);
80     if (Settings* settings = document->settings())
81         m_useHixie76Protocol = settings->useHixie76WebSocketProtocol();
82     m_handshake = adoptPtr(new WebSocketHandshake(url, protocol, context, m_useHixie76Protocol));
83
84     if (Page* page = document->page())
85         m_identifier = page->progress()->createUniqueIdentifier();
86     if (m_identifier)
87         InspectorInstrumentation::didCreateWebSocket(m_context, m_identifier, url, m_context->url());
88 }
89
90 WebSocketChannel::~WebSocketChannel()
91 {
92     fastFree(m_buffer);
93 }
94
95 void WebSocketChannel::connect()
96 {
97     LOG(Network, "WebSocketChannel %p connect", this);
98     ASSERT(!m_handle);
99     ASSERT(!m_suspended);
100     m_handshake->reset();
101     ref();
102     m_handle = SocketStreamHandle::create(m_handshake->url(), this);
103 }
104
105 bool WebSocketChannel::send(const String& msg)
106 {
107     LOG(Network, "WebSocketChannel %p send %s", this, msg.utf8().data());
108     ASSERT(m_handle);
109     ASSERT(!m_suspended);
110     Vector<char> buf;
111     buf.append('\0');  // frame type
112     CString utf8 = msg.utf8();
113     buf.append(utf8.data(), utf8.length());
114     buf.append('\xff');  // frame end
115     return m_handle->send(buf.data(), buf.size());
116 }
117
118 unsigned long WebSocketChannel::bufferedAmount() const
119 {
120     LOG(Network, "WebSocketChannel %p bufferedAmount", this);
121     ASSERT(m_handle);
122     ASSERT(!m_suspended);
123     return m_handle->bufferedAmount();
124 }
125
126 void WebSocketChannel::close()
127 {
128     LOG(Network, "WebSocketChannel %p close", this);
129     ASSERT(!m_suspended);
130     if (!m_handle)
131         return;
132     startClosingHandshake();
133     if (m_closing && !m_closingTimer.isActive())
134         m_closingTimer.startOneShot(2 * TCPMaximumSegmentLifetime);
135 }
136
137 void WebSocketChannel::fail(const String& reason)
138 {
139     LOG(Network, "WebSocketChannel %p fail: %s", this, reason.utf8().data());
140     ASSERT(!m_suspended);
141     if (m_context)
142         m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, reason, 0, m_handshake->clientOrigin(), 0);
143     if (m_handle && !m_closed)
144         m_handle->disconnect(); // Will call didClose().
145 }
146
147 void WebSocketChannel::disconnect()
148 {
149     LOG(Network, "WebSocketChannel %p disconnect", this);
150     if (m_identifier && m_context)
151         InspectorInstrumentation::didCloseWebSocket(m_context, m_identifier);
152     m_handshake->clearScriptExecutionContext();
153     m_client = 0;
154     m_context = 0;
155     if (m_handle)
156         m_handle->disconnect();
157 }
158
159 void WebSocketChannel::suspend()
160 {
161     m_suspended = true;
162 }
163
164 void WebSocketChannel::resume()
165 {
166     m_suspended = false;
167     if ((m_buffer || m_closed) && m_client && !m_resumeTimer.isActive())
168         m_resumeTimer.startOneShot(0);
169 }
170
171 void WebSocketChannel::didOpen(SocketStreamHandle* handle)
172 {
173     LOG(Network, "WebSocketChannel %p didOpen", this);
174     ASSERT(handle == m_handle);
175     if (!m_context)
176         return;
177     if (m_identifier)
178         InspectorInstrumentation::willSendWebSocketHandshakeRequest(m_context, m_identifier, m_handshake->clientHandshakeRequest());
179     CString handshakeMessage = m_handshake->clientHandshakeMessage();
180     if (!handle->send(handshakeMessage.data(), handshakeMessage.length()))
181         fail("Failed to send WebSocket handshake.");
182 }
183
184 void WebSocketChannel::didClose(SocketStreamHandle* handle)
185 {
186     LOG(Network, "WebSocketChannel %p didClose", this);
187     if (m_identifier && m_context)
188         InspectorInstrumentation::didCloseWebSocket(m_context, m_identifier);
189     ASSERT_UNUSED(handle, handle == m_handle || !m_handle);
190     m_closed = true;
191     if (m_closingTimer.isActive())
192         m_closingTimer.stop();
193     if (m_handle) {
194         m_unhandledBufferedAmount = m_handle->bufferedAmount();
195         if (m_suspended)
196             return;
197         WebSocketChannelClient* client = m_client;
198         m_client = 0;
199         m_context = 0;
200         m_handle = 0;
201         if (client)
202             client->didClose(m_unhandledBufferedAmount, m_receivedClosingHandshake ? WebSocketChannelClient::ClosingHandshakeComplete : WebSocketChannelClient::ClosingHandshakeIncomplete);
203     }
204     deref();
205 }
206
207 void WebSocketChannel::didReceiveData(SocketStreamHandle* handle, const char* data, int len)
208 {
209     LOG(Network, "WebSocketChannel %p didReceiveData %d", this, len);
210     RefPtr<WebSocketChannel> protect(this); // The client can close the channel, potentially removing the last reference.
211     ASSERT(handle == m_handle);
212     if (!m_context) {
213         return;
214     }
215     if (len <= 0) {
216         handle->disconnect();
217         return;
218     }
219     if (!m_client) {
220         m_shouldDiscardReceivedData = true;
221         handle->disconnect();
222         return;
223     }
224     if (m_shouldDiscardReceivedData)
225         return;
226     if (!appendToBuffer(data, len)) {
227         m_shouldDiscardReceivedData = true;
228         fail("Ran out of memory while receiving WebSocket data.");
229         return;
230     }
231     while (!m_suspended && m_client && m_buffer)
232         if (!processBuffer())
233             break;
234 }
235
236 void WebSocketChannel::didFail(SocketStreamHandle* handle, const SocketStreamError& error)
237 {
238     LOG(Network, "WebSocketChannel %p didFail", this);
239     ASSERT(handle == m_handle || !m_handle);
240     if (m_context) {
241         String message;
242         if (error.isNull())
243             message = "WebSocket network error";
244         else if (error.localizedDescription().isNull())
245             message = "WebSocket network error: error code " + String::number(error.errorCode());
246         else
247             message = "WebSocket network error: " + error.localizedDescription();
248         String failingURL = error.failingURL();
249         ASSERT(failingURL.isNull() || m_handshake->url().string() == failingURL);
250         if (failingURL.isNull())
251             failingURL = m_handshake->url().string();
252         m_context->addMessage(OtherMessageSource, NetworkErrorMessageType, ErrorMessageLevel, message, 0, failingURL, 0);
253     }
254     m_shouldDiscardReceivedData = true;
255     handle->disconnect();
256 }
257
258 void WebSocketChannel::didReceiveAuthenticationChallenge(SocketStreamHandle*, const AuthenticationChallenge&)
259 {
260 }
261
262 void WebSocketChannel::didCancelAuthenticationChallenge(SocketStreamHandle*, const AuthenticationChallenge&)
263 {
264 }
265
266 bool WebSocketChannel::appendToBuffer(const char* data, size_t len)
267 {
268     size_t newBufferSize = m_bufferSize + len;
269     if (newBufferSize < m_bufferSize) {
270         LOG(Network, "WebSocket buffer overflow (%lu+%lu)", static_cast<unsigned long>(m_bufferSize), static_cast<unsigned long>(len));
271         return false;
272     }
273     char* newBuffer = 0;
274     if (!tryFastMalloc(newBufferSize).getValue(newBuffer))
275         return false;
276
277     if (m_buffer)
278         memcpy(newBuffer, m_buffer, m_bufferSize);
279     memcpy(newBuffer + m_bufferSize, data, len);
280     fastFree(m_buffer);
281     m_buffer = newBuffer;
282     m_bufferSize = newBufferSize;
283     return true;
284 }
285
286 void WebSocketChannel::skipBuffer(size_t len)
287 {
288     ASSERT(len <= m_bufferSize);
289     m_bufferSize -= len;
290     if (!m_bufferSize) {
291         fastFree(m_buffer);
292         m_buffer = 0;
293         return;
294     }
295     memmove(m_buffer, m_buffer + len, m_bufferSize);
296 }
297
298 bool WebSocketChannel::processBuffer()
299 {
300     ASSERT(!m_suspended);
301     ASSERT(m_client);
302     ASSERT(m_buffer);
303     LOG(Network, "WebSocketChannel %p processBuffer %lu", this, static_cast<unsigned long>(m_bufferSize));
304
305     if (m_shouldDiscardReceivedData)
306         return false;
307
308     if (m_receivedClosingHandshake) {
309         skipBuffer(m_bufferSize);
310         return false;
311     }
312
313     RefPtr<WebSocketChannel> protect(this); // The client can close the channel, potentially removing the last reference.
314
315     if (m_handshake->mode() == WebSocketHandshake::Incomplete) {
316         int headerLength = m_handshake->readServerHandshake(m_buffer, m_bufferSize);
317         if (headerLength <= 0)
318             return false;
319         if (m_handshake->mode() == WebSocketHandshake::Connected) {
320             if (m_identifier)
321                 InspectorInstrumentation::didReceiveWebSocketHandshakeResponse(m_context, m_identifier, m_handshake->serverHandshakeResponse());
322             if (!m_handshake->serverSetCookie().isEmpty()) {
323                 if (m_context->isDocument()) {
324                     Document* document = static_cast<Document*>(m_context);
325                     if (cookiesEnabled(document)) {
326                         ExceptionCode ec; // Exception (for sandboxed documents) ignored.
327                         document->setCookie(m_handshake->serverSetCookie(), ec);
328                     }
329                 }
330             }
331             // FIXME: handle set-cookie2.
332             LOG(Network, "WebSocketChannel %p connected", this);
333             skipBuffer(headerLength);
334             m_client->didConnect();
335             LOG(Network, "remaining in read buf %lu", static_cast<unsigned long>(m_bufferSize));
336             return m_buffer;
337         }
338         ASSERT(m_handshake->mode() == WebSocketHandshake::Failed);
339         LOG(Network, "WebSocketChannel %p connection failed", this);
340         skipBuffer(headerLength);
341         m_shouldDiscardReceivedData = true;
342         fail(m_handshake->failureReason());
343         return false;
344     }
345     if (m_handshake->mode() != WebSocketHandshake::Connected)
346         return false;
347
348     const char* nextFrame = m_buffer;
349     const char* p = m_buffer;
350     const char* end = p + m_bufferSize;
351
352     unsigned char frameByte = static_cast<unsigned char>(*p++);
353     if ((frameByte & 0x80) == 0x80) {
354         size_t length = 0;
355         bool errorFrame = false;
356         while (p < end) {
357             if (length > std::numeric_limits<size_t>::max() / 128) {
358                 LOG(Network, "frame length overflow %lu", static_cast<unsigned long>(length));
359                 errorFrame = true;
360                 break;
361             }
362             size_t newLength = length * 128;
363             unsigned char msgByte = static_cast<unsigned char>(*p);
364             unsigned int lengthMsgByte = msgByte & 0x7f;
365             if (newLength > std::numeric_limits<size_t>::max() - lengthMsgByte) {
366                 LOG(Network, "frame length overflow %lu+%u", static_cast<unsigned long>(newLength), lengthMsgByte);
367                 errorFrame = true;
368                 break;
369             }
370             newLength += lengthMsgByte;
371             if (newLength < length) { // sanity check
372                 LOG(Network, "frame length integer wrap %lu->%lu", static_cast<unsigned long>(length), static_cast<unsigned long>(newLength));
373                 errorFrame = true;
374                 break;
375             }
376             length = newLength;
377             ++p;
378             if (!(msgByte & 0x80))
379                 break;
380         }
381         if (p + length < p) {
382             LOG(Network, "frame buffer pointer wrap %p+%lu->%p", p, static_cast<unsigned long>(length), p + length);
383             errorFrame = true;
384         }
385         if (errorFrame) {
386             skipBuffer(m_bufferSize); // Save memory.
387             m_shouldDiscardReceivedData = true;
388             m_client->didReceiveMessageError();
389             fail("WebSocket frame length too large");
390             return false;
391         }
392         ASSERT(p + length >= p);
393         if (p + length <= end) {
394             p += length;
395             nextFrame = p;
396             ASSERT(nextFrame > m_buffer);
397             skipBuffer(nextFrame - m_buffer);
398             if (frameByte == 0xff && !length) {
399                 m_receivedClosingHandshake = true;
400                 startClosingHandshake();
401                 if (m_closing)
402                     m_handle->close(); // close after sending FF 00.
403             } else
404                 m_client->didReceiveMessageError();
405             return m_buffer;
406         }
407         return false;
408     }
409
410     const char* msgStart = p;
411     while (p < end && *p != '\xff')
412         ++p;
413     if (p < end && *p == '\xff') {
414         int msgLength = p - msgStart;
415         ++p;
416         nextFrame = p;
417         if (frameByte == 0x00) {
418             String msg = String::fromUTF8(msgStart, msgLength);
419             skipBuffer(nextFrame - m_buffer);
420             m_client->didReceiveMessage(msg);
421         } else {
422             skipBuffer(nextFrame - m_buffer);
423             m_client->didReceiveMessageError();
424         }
425         return m_buffer;
426     }
427     return false;
428 }
429
430 void WebSocketChannel::resumeTimerFired(Timer<WebSocketChannel>* timer)
431 {
432     ASSERT_UNUSED(timer, timer == &m_resumeTimer);
433
434     RefPtr<WebSocketChannel> protect(this); // The client can close the channel, potentially removing the last reference.
435     while (!m_suspended && m_client && m_buffer)
436         if (!processBuffer())
437             break;
438     if (!m_suspended && m_client && m_closed && m_handle)
439         didClose(m_handle.get());
440 }
441
442 void WebSocketChannel::startClosingHandshake()
443 {
444     LOG(Network, "WebSocketChannel %p closing %d %d", this, m_closing, m_receivedClosingHandshake);
445     if (m_closing)
446         return;
447     ASSERT(m_handle);
448     Vector<char> buf;
449     buf.append('\xff');
450     buf.append('\0');
451     if (!m_handle->send(buf.data(), buf.size())) {
452         m_handle->disconnect();
453         return;
454     }
455     m_closing = true;
456     if (m_client)
457         m_client->didStartClosingHandshake();
458 }
459
460 void WebSocketChannel::closingTimerFired(Timer<WebSocketChannel>* timer)
461 {
462     LOG(Network, "WebSocketChannel %p closing timer", this);
463     ASSERT_UNUSED(timer, &m_closingTimer == timer);
464     if (m_handle)
465         m_handle->disconnect();
466 }
467
468 }  // namespace WebCore
469
470 #endif  // ENABLE(WEB_SOCKETS)