7f30738f8e18be560112dc74477b603fcc345140
[WebKit-https.git] / Source / WebCore / Modules / websockets / WebSocket.cpp
1 /*
2  * Copyright (C) 2011 Google Inc.  All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions are
6  * met:
7  *
8  *     * Redistributions of source code must retain the above copyright
9  * notice, this list of conditions and the following disclaimer.
10  *     * Redistributions in binary form must reproduce the above
11  * copyright notice, this list of conditions and the following disclaimer
12  * in the documentation and/or other materials provided with the
13  * distribution.
14  *     * Neither the name of Google Inc. nor the names of its
15  * contributors may be used to endorse or promote products derived from
16  * this software without specific prior written permission.
17  *
18  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29  */
30
31 #include "config.h"
32
33 #if ENABLE(WEB_SOCKETS)
34
35 #include "WebSocket.h"
36
37 #include "Blob.h"
38 #include "BlobData.h"
39 #include "CloseEvent.h"
40 #include "ContentSecurityPolicy.h"
41 #include "DOMWindow.h"
42 #include "Event.h"
43 #include "EventException.h"
44 #include "EventListener.h"
45 #include "EventNames.h"
46 #include "ExceptionCode.h"
47 #include "Logging.h"
48 #include "MessageEvent.h"
49 #include "ScriptCallStack.h"
50 #include "ScriptExecutionContext.h"
51 #include "SecurityOrigin.h"
52 #include "ThreadableWebSocketChannel.h"
53 #include "WebSocketChannel.h"
54 #include <wtf/ArrayBuffer.h>
55 #include <wtf/ArrayBufferView.h>
56 #include <wtf/HashSet.h>
57 #include <wtf/OwnPtr.h>
58 #include <wtf/PassOwnPtr.h>
59 #include <wtf/StdLibExtras.h>
60 #include <wtf/text/CString.h>
61 #include <wtf/text/StringBuilder.h>
62 #include <wtf/text/WTFString.h>
63
64 using namespace std;
65
66 namespace WebCore {
67
68 const size_t maxReasonSizeInBytes = 123;
69
70 static inline bool isValidProtocolCharacter(UChar character)
71 {
72     // Hybi-10 says "(Subprotocol string must consist of) characters in the range U+0021 to U+007E not including
73     // separator characters as defined in [RFC2616]."
74     const UChar minimumProtocolCharacter = '!'; // U+0021.
75     const UChar maximumProtocolCharacter = '~'; // U+007E.
76     return character >= minimumProtocolCharacter && character <= maximumProtocolCharacter
77         && character != '"' && character != '(' && character != ')' && character != ',' && character != '/'
78         && !(character >= ':' && character <= '@') // U+003A - U+0040 (':', ';', '<', '=', '>', '?', '@').
79         && !(character >= '[' && character <= ']') // U+005B - U+005D ('[', '\\', ']').
80         && character != '{' && character != '}';
81 }
82
83 static bool isValidProtocolString(const String& protocol)
84 {
85     if (protocol.isEmpty())
86         return false;
87     for (size_t i = 0; i < protocol.length(); ++i) {
88         if (!isValidProtocolCharacter(protocol[i]))
89             return false;
90     }
91     return true;
92 }
93
94 static String encodeProtocolString(const String& protocol)
95 {
96     StringBuilder builder;
97     for (size_t i = 0; i < protocol.length(); i++) {
98         if (protocol[i] < 0x20 || protocol[i] > 0x7E)
99             builder.append(String::format("\\u%04X", protocol[i]));
100         else if (protocol[i] == 0x5c)
101             builder.append("\\\\");
102         else
103             builder.append(protocol[i]);
104     }
105     return builder.toString();
106 }
107
108 static String joinStrings(const Vector<String>& strings, const char* separator)
109 {
110     StringBuilder builder;
111     for (size_t i = 0; i < strings.size(); ++i) {
112         if (i)
113             builder.append(separator);
114         builder.append(strings[i]);
115     }
116     return builder.toString();
117 }
118
119 static unsigned long saturateAdd(unsigned long a, unsigned long b)
120 {
121     if (numeric_limits<unsigned long>::max() - a < b)
122         return numeric_limits<unsigned long>::max();
123     return a + b;
124 }
125
126 static bool webSocketsAvailable = false;
127
128 void WebSocket::setIsAvailable(bool available)
129 {
130     webSocketsAvailable = available;
131 }
132
133 bool WebSocket::isAvailable()
134 {
135     return webSocketsAvailable;
136 }
137
138 const char* WebSocket::subProtocolSeperator()
139 {
140     return ", ";
141 }
142
143 WebSocket::WebSocket(ScriptExecutionContext* context)
144     : ActiveDOMObject(context, this)
145     , m_state(CONNECTING)
146     , m_bufferedAmount(0)
147     , m_bufferedAmountAfterClose(0)
148     , m_binaryType(BinaryTypeBlob)
149     , m_subprotocol("")
150     , m_extensions("")
151 {
152 }
153
154 WebSocket::~WebSocket()
155 {
156     if (m_channel)
157         m_channel->disconnect();
158 }
159
160 PassRefPtr<WebSocket> WebSocket::create(ScriptExecutionContext* context)
161 {
162     RefPtr<WebSocket> webSocket(adoptRef(new WebSocket(context)));
163     webSocket->suspendIfNeeded();
164     return webSocket.release();
165 }
166
167 PassRefPtr<WebSocket> WebSocket::create(ScriptExecutionContext* context, const String& url, ExceptionCode& ec)
168 {
169     Vector<String> protocols;
170     return WebSocket::create(context, url, protocols, ec);
171 }
172
173 PassRefPtr<WebSocket> WebSocket::create(ScriptExecutionContext* context, const String& url, const Vector<String>& protocols, ExceptionCode& ec)
174 {
175     if (url.isNull()) {
176         ec = SYNTAX_ERR;
177         return 0;
178     }
179
180     RefPtr<WebSocket> webSocket(adoptRef(new WebSocket(context)));
181     webSocket->suspendIfNeeded();
182
183     webSocket->connect(context->completeURL(url), protocols, ec);
184     if (ec)
185         return 0;
186
187     return webSocket.release();
188 }
189
190 PassRefPtr<WebSocket> WebSocket::create(ScriptExecutionContext* context, const String& url, const String& protocol, ExceptionCode& ec)
191 {
192     Vector<String> protocols;
193     protocols.append(protocol);
194     return WebSocket::create(context, url, protocols, ec);
195 }
196
197 void WebSocket::connect(const String& url, ExceptionCode& ec)
198 {
199     Vector<String> protocols;
200     connect(url, protocols, ec);
201 }
202
203 void WebSocket::connect(const String& url, const String& protocol, ExceptionCode& ec)
204 {
205     Vector<String> protocols;
206     protocols.append(protocol);
207     connect(url, protocols, ec);
208 }
209
210 void WebSocket::connect(const String& url, const Vector<String>& protocols, ExceptionCode& ec)
211 {
212     LOG(Network, "WebSocket %p connect to %s", this, url.utf8().data());
213     m_url = KURL(KURL(), url);
214
215     if (!m_url.isValid()) {
216         scriptExecutionContext()->addConsoleMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Invalid url for WebSocket " + m_url.string(), scriptExecutionContext()->securityOrigin()->toString());
217         m_state = CLOSED;
218         ec = SYNTAX_ERR;
219         return;
220     }
221
222     if (!m_url.protocolIs("ws") && !m_url.protocolIs("wss")) {
223         scriptExecutionContext()->addConsoleMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Wrong url scheme for WebSocket " + m_url.string(), scriptExecutionContext()->securityOrigin()->toString());
224         m_state = CLOSED;
225         ec = SYNTAX_ERR;
226         return;
227     }
228     if (m_url.hasFragmentIdentifier()) {
229         scriptExecutionContext()->addConsoleMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "URL has fragment component " + m_url.string(), scriptExecutionContext()->securityOrigin()->toString());
230         m_state = CLOSED;
231         ec = SYNTAX_ERR;
232         return;
233     }
234     if (!portAllowed(m_url)) {
235         scriptExecutionContext()->addConsoleMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "WebSocket port " + String::number(m_url.port()) + " blocked", scriptExecutionContext()->securityOrigin()->toString());
236         m_state = CLOSED;
237         ec = SECURITY_ERR;
238         return;
239     }
240
241     if (!scriptExecutionContext()->contentSecurityPolicy()->allowConnectToSource(m_url)) {
242         m_state = CLOSED;
243
244         // FIXME: Should this be throwing an exception?
245         ec = SECURITY_ERR;
246         return;
247     }
248
249     m_channel = ThreadableWebSocketChannel::create(scriptExecutionContext(), this);
250
251     // FIXME: There is a disagreement about restriction of subprotocols between WebSocket API and hybi-10 protocol
252     // draft. The former simply says "only characters in the range U+0021 to U+007E are allowed," while the latter
253     // imposes a stricter rule: "the elements MUST be non-empty strings with characters as defined in [RFC2616],
254     // and MUST all be unique strings."
255     //
256     // Here, we throw SYNTAX_ERR if the given protocols do not meet the latter criteria. This behavior does not
257     // comply with WebSocket API specification, but it seems to be the only reasonable way to handle this conflict.
258     for (size_t i = 0; i < protocols.size(); ++i) {
259         if (!isValidProtocolString(protocols[i])) {
260             scriptExecutionContext()->addConsoleMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Wrong protocol for WebSocket '" + encodeProtocolString(protocols[i]) + "'", scriptExecutionContext()->securityOrigin()->toString());
261             m_state = CLOSED;
262             ec = SYNTAX_ERR;
263             return;
264         }
265     }
266     HashSet<String> visited;
267     for (size_t i = 0; i < protocols.size(); ++i) {
268         if (visited.contains(protocols[i])) {
269             scriptExecutionContext()->addConsoleMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "WebSocket protocols contain duplicates: '" + encodeProtocolString(protocols[i]) + "'", scriptExecutionContext()->securityOrigin()->toString());
270             m_state = CLOSED;
271             ec = SYNTAX_ERR;
272             return;
273         }
274         visited.add(protocols[i]);
275     }
276
277     String protocolString;
278     if (!protocols.isEmpty())
279         protocolString = joinStrings(protocols, subProtocolSeperator());
280
281     m_channel->connect(m_url, protocolString);
282     ActiveDOMObject::setPendingActivity(this);
283 }
284
285 bool WebSocket::send(const String& message, ExceptionCode& ec)
286 {
287     LOG(Network, "WebSocket %p send %s", this, message.utf8().data());
288     if (m_state == CONNECTING) {
289         ec = INVALID_STATE_ERR;
290         return false;
291     }
292     // No exception is raised if the connection was once established but has subsequently been closed.
293     if (m_state == CLOSING || m_state == CLOSED) {
294         size_t payloadSize = message.utf8().length();
295         m_bufferedAmountAfterClose = saturateAdd(m_bufferedAmountAfterClose, payloadSize);
296         m_bufferedAmountAfterClose = saturateAdd(m_bufferedAmountAfterClose, getFramingOverhead(payloadSize));
297         return false;
298     }
299     ASSERT(m_channel);
300     ThreadableWebSocketChannel::SendResult result = m_channel->send(message);
301     if (result == ThreadableWebSocketChannel::InvalidMessage) {
302         scriptExecutionContext()->addConsoleMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Websocket message contains invalid character(s).");
303         ec = SYNTAX_ERR;
304         return false;
305     }
306     return result == ThreadableWebSocketChannel::SendSuccess;
307 }
308
309 bool WebSocket::send(ArrayBuffer* binaryData, ExceptionCode& ec)
310 {
311     LOG(Network, "WebSocket %p send arraybuffer %p", this, binaryData);
312     ASSERT(binaryData);
313     if (m_state == CONNECTING) {
314         ec = INVALID_STATE_ERR;
315         return false;
316     }
317     if (m_state == CLOSING || m_state == CLOSED) {
318         unsigned payloadSize = binaryData->byteLength();
319         m_bufferedAmountAfterClose = saturateAdd(m_bufferedAmountAfterClose, payloadSize);
320         m_bufferedAmountAfterClose = saturateAdd(m_bufferedAmountAfterClose, getFramingOverhead(payloadSize));
321         return false;
322     }
323     ASSERT(m_channel);
324     return m_channel->send(*binaryData, 0, binaryData->byteLength()) == ThreadableWebSocketChannel::SendSuccess;
325 }
326
327 bool WebSocket::send(ArrayBufferView* arrayBufferView, ExceptionCode& ec)
328 {
329     LOG(Network, "WebSocket %p send arraybufferview %p", this, arrayBufferView);
330     ASSERT(arrayBufferView);
331     if (m_state == CONNECTING) {
332         ec = INVALID_STATE_ERR;
333         return false;
334     }
335     if (m_state == CLOSING || m_state == CLOSED) {
336         unsigned payloadSize = arrayBufferView->byteLength();
337         m_bufferedAmountAfterClose = saturateAdd(m_bufferedAmountAfterClose, payloadSize);
338         m_bufferedAmountAfterClose = saturateAdd(m_bufferedAmountAfterClose, getFramingOverhead(payloadSize));
339         return false;
340     }
341     ASSERT(m_channel);
342     RefPtr<ArrayBuffer> arrayBuffer(arrayBufferView->buffer());
343     return m_channel->send(*arrayBuffer, arrayBufferView->byteOffset(), arrayBufferView->byteLength()) == ThreadableWebSocketChannel::SendSuccess;
344 }
345
346 bool WebSocket::send(Blob* binaryData, ExceptionCode& ec)
347 {
348     LOG(Network, "WebSocket %p send blob %s", this, binaryData->url().string().utf8().data());
349     ASSERT(binaryData);
350     if (m_state == CONNECTING) {
351         ec = INVALID_STATE_ERR;
352         return false;
353     }
354     if (m_state == CLOSING || m_state == CLOSED) {
355         unsigned long payloadSize = static_cast<unsigned long>(binaryData->size());
356         m_bufferedAmountAfterClose = saturateAdd(m_bufferedAmountAfterClose, payloadSize);
357         m_bufferedAmountAfterClose = saturateAdd(m_bufferedAmountAfterClose, getFramingOverhead(payloadSize));
358         return false;
359     }
360     ASSERT(m_channel);
361     return m_channel->send(*binaryData) == ThreadableWebSocketChannel::SendSuccess;
362 }
363
364 void WebSocket::close(int code, const String& reason, ExceptionCode& ec)
365 {
366     if (code == WebSocketChannel::CloseEventCodeNotSpecified)
367         LOG(Network, "WebSocket %p close without code and reason", this);
368     else {
369         LOG(Network, "WebSocket %p close with code = %d, reason = %s", this, code, reason.utf8().data());
370         if (!(code == WebSocketChannel::CloseEventCodeNormalClosure || (WebSocketChannel::CloseEventCodeMinimumUserDefined <= code && code <= WebSocketChannel::CloseEventCodeMaximumUserDefined))) {
371             ec = INVALID_ACCESS_ERR;
372             return;
373         }
374         CString utf8 = reason.utf8(String::StrictConversionReplacingUnpairedSurrogatesWithFFFD);
375         if (utf8.length() > maxReasonSizeInBytes) {
376             scriptExecutionContext()->addConsoleMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "WebSocket close message is too long.");
377             ec = SYNTAX_ERR;
378             return;
379         }
380     }
381
382     if (m_state == CLOSING || m_state == CLOSED)
383         return;
384     if (m_state == CONNECTING) {
385         m_state = CLOSING;
386         m_channel->fail("WebSocket is closed before the connection is established.");
387         return;
388     }
389     m_state = CLOSING;
390     if (m_channel)
391         m_channel->close(code, reason);
392 }
393
394 const KURL& WebSocket::url() const
395 {
396     return m_url;
397 }
398
399 WebSocket::State WebSocket::readyState() const
400 {
401     return m_state;
402 }
403
404 unsigned long WebSocket::bufferedAmount() const
405 {
406     return saturateAdd(m_bufferedAmount, m_bufferedAmountAfterClose);
407 }
408
409 String WebSocket::protocol() const
410 {
411     return m_subprotocol;
412 }
413
414 String WebSocket::extensions() const
415 {
416     return m_extensions;
417 }
418
419 String WebSocket::binaryType() const
420 {
421     switch (m_binaryType) {
422     case BinaryTypeBlob:
423         return "blob";
424     case BinaryTypeArrayBuffer:
425         return "arraybuffer";
426     }
427     ASSERT_NOT_REACHED();
428     return String();
429 }
430
431 void WebSocket::setBinaryType(const String& binaryType)
432 {
433     if (binaryType == "blob") {
434         m_binaryType = BinaryTypeBlob;
435         return;
436     }
437     if (binaryType == "arraybuffer") {
438         m_binaryType = BinaryTypeArrayBuffer;
439         return;
440     }
441     scriptExecutionContext()->addConsoleMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "'" + binaryType + "' is not a valid value for binaryType; binaryType remains unchanged.");
442 }
443
444 const AtomicString& WebSocket::interfaceName() const
445 {
446     return eventNames().interfaceForWebSocket;
447 }
448
449 ScriptExecutionContext* WebSocket::scriptExecutionContext() const
450 {
451     return ActiveDOMObject::scriptExecutionContext();
452 }
453
454 void WebSocket::contextDestroyed()
455 {
456     LOG(Network, "WebSocket %p scriptExecutionContext destroyed", this);
457     ASSERT(!m_channel);
458     ASSERT(m_state == CLOSED);
459     ActiveDOMObject::contextDestroyed();
460 }
461
462 bool WebSocket::canSuspend() const
463 {
464     return !m_channel;
465 }
466
467 void WebSocket::suspend(ReasonForSuspension)
468 {
469     if (m_channel)
470         m_channel->suspend();
471 }
472
473 void WebSocket::resume()
474 {
475     if (m_channel)
476         m_channel->resume();
477 }
478
479 void WebSocket::stop()
480 {
481     bool pending = hasPendingActivity();
482     if (m_channel)
483         m_channel->disconnect();
484     m_channel = 0;
485     m_state = CLOSED;
486     ActiveDOMObject::stop();
487     if (pending)
488         ActiveDOMObject::unsetPendingActivity(this);
489 }
490
491 void WebSocket::didConnect()
492 {
493     LOG(Network, "WebSocket %p didConnect", this);
494     if (m_state != CONNECTING) {
495         didClose(0, ClosingHandshakeIncomplete, WebSocketChannel::CloseEventCodeAbnormalClosure, "");
496         return;
497     }
498     ASSERT(scriptExecutionContext());
499     m_state = OPEN;
500     m_subprotocol = m_channel->subprotocol();
501     m_extensions = m_channel->extensions();
502     dispatchEvent(Event::create(eventNames().openEvent, false, false));
503 }
504
505 void WebSocket::didReceiveMessage(const String& msg)
506 {
507     LOG(Network, "WebSocket %p didReceiveMessage %s", this, msg.utf8().data());
508     if (m_state != OPEN && m_state != CLOSING)
509         return;
510     ASSERT(scriptExecutionContext());
511     dispatchEvent(MessageEvent::create(msg));
512 }
513
514 void WebSocket::didReceiveBinaryData(PassOwnPtr<Vector<char> > binaryData)
515 {
516     switch (m_binaryType) {
517     case BinaryTypeBlob: {
518         size_t size = binaryData->size();
519         RefPtr<RawData> rawData = RawData::create();
520         binaryData->swap(*rawData->mutableData());
521         OwnPtr<BlobData> blobData = BlobData::create();
522         blobData->appendData(rawData.release(), 0, BlobDataItem::toEndOfFile);
523         RefPtr<Blob> blob = Blob::create(blobData.release(), size);
524         dispatchEvent(MessageEvent::create(blob.release()));
525         break;
526     }
527
528     case BinaryTypeArrayBuffer:
529         dispatchEvent(MessageEvent::create(ArrayBuffer::create(binaryData->data(), binaryData->size())));
530         break;
531     }
532 }
533
534 void WebSocket::didReceiveMessageError()
535 {
536     LOG(Network, "WebSocket %p didReceiveErrorMessage", this);
537     ASSERT(scriptExecutionContext());
538     dispatchEvent(Event::create(eventNames().errorEvent, false, false));
539 }
540
541 void WebSocket::didUpdateBufferedAmount(unsigned long bufferedAmount)
542 {
543     LOG(Network, "WebSocket %p didUpdateBufferedAmount %lu", this, bufferedAmount);
544     if (m_state == CLOSED)
545         return;
546     m_bufferedAmount = bufferedAmount;
547 }
548
549 void WebSocket::didStartClosingHandshake()
550 {
551     LOG(Network, "WebSocket %p didStartClosingHandshake", this);
552     m_state = CLOSING;
553 }
554
555 void WebSocket::didClose(unsigned long unhandledBufferedAmount, ClosingHandshakeCompletionStatus closingHandshakeCompletion, unsigned short code, const String& reason)
556 {
557     LOG(Network, "WebSocket %p didClose", this);
558     if (!m_channel)
559         return;
560     bool wasClean = m_state == CLOSING && !unhandledBufferedAmount && closingHandshakeCompletion == ClosingHandshakeComplete && code != WebSocketChannel::CloseEventCodeAbnormalClosure;
561     m_state = CLOSED;
562     m_bufferedAmount = unhandledBufferedAmount;
563     ASSERT(scriptExecutionContext());
564     RefPtr<CloseEvent> event = CloseEvent::create(wasClean, code, reason);
565     dispatchEvent(event);
566     if (m_channel) {
567         m_channel->disconnect();
568         m_channel = 0;
569     }
570     if (hasPendingActivity())
571         ActiveDOMObject::unsetPendingActivity(this);
572 }
573
574 EventTargetData* WebSocket::eventTargetData()
575 {
576     return &m_eventTargetData;
577 }
578
579 EventTargetData* WebSocket::ensureEventTargetData()
580 {
581     return &m_eventTargetData;
582 }
583
584 size_t WebSocket::getFramingOverhead(size_t payloadSize)
585 {
586     static const size_t hybiBaseFramingOverhead = 2; // Every frame has at least two-byte header.
587     static const size_t hybiMaskingKeyLength = 4; // Every frame from client must have masking key.
588     static const size_t minimumPayloadSizeWithTwoByteExtendedPayloadLength = 126;
589     static const size_t minimumPayloadSizeWithEightByteExtendedPayloadLength = 0x10000;
590     size_t overhead = hybiBaseFramingOverhead + hybiMaskingKeyLength;
591     if (payloadSize >= minimumPayloadSizeWithEightByteExtendedPayloadLength)
592         overhead += 8;
593     else if (payloadSize >= minimumPayloadSizeWithTwoByteExtendedPayloadLength)
594         overhead += 2;
595     return overhead;
596 }
597
598 }  // namespace WebCore
599
600 #endif