[Content Filtering] Crash when allowing a 0-byte resource to load
[WebKit-https.git] / Source / WebCore / loader / ContentFilter.cpp
index 6d8e88e5fd765890aac364a2bb59c4b05be8b414..2f600cb816b4f46d662efa745bc37c83f9a1750b 100644 (file)
@@ -30,6 +30,7 @@
 
 #include "CachedRawResource.h"
 #include "ContentFilterUnblockHandler.h"
+#include "DocumentLoader.h"
 #include "Logging.h"
 #include "NetworkExtensionContentFilter.h"
 #include "ParentalControlsContentFilter.h"
@@ -47,7 +48,9 @@ Vector<ContentFilter::Type>& ContentFilter::types()
 {
     static NeverDestroyed<Vector<ContentFilter::Type>> types {
         Vector<ContentFilter::Type> {
+#if HAVE(PARENTAL_CONTROLS)
             type<ParentalControlsContentFilter>(),
+#endif
 #if HAVE(NETWORK_EXTENSION)
             type<NetworkExtensionContentFilter>()
 #endif
@@ -56,7 +59,7 @@ Vector<ContentFilter::Type>& ContentFilter::types()
     return types;
 }
 
-std::unique_ptr<ContentFilter> ContentFilter::createIfNeeded(DecisionFunction decisionFunction)
+std::unique_ptr<ContentFilter> ContentFilter::createIfEnabled(DocumentLoader& documentLoader)
 {
     Container filters;
     for (auto& type : types()) {
@@ -71,12 +74,12 @@ std::unique_ptr<ContentFilter> ContentFilter::createIfNeeded(DecisionFunction de
     if (filters.isEmpty())
         return nullptr;
 
-    return std::make_unique<ContentFilter>(WTF::move(filters), WTF::move(decisionFunction));
+    return std::make_unique<ContentFilter>(WTF::move(filters), documentLoader);
 }
 
-ContentFilter::ContentFilter(Container contentFilters, DecisionFunction decisionFunction)
+ContentFilter::ContentFilter(Container contentFilters, DocumentLoader& documentLoader)
     : m_contentFilters { WTF::move(contentFilters) }
-    , m_decisionFunction { WTF::move(decisionFunction) }
+    , m_documentLoader { documentLoader }
 {
     LOG(ContentFiltering, "Creating ContentFilter with %zu platform content filter(s).\n", m_contentFilters.size());
     ASSERT(!m_contentFilters.isEmpty());
@@ -145,39 +148,78 @@ String ContentFilter::unblockRequestDeniedScript() const
 
 void ContentFilter::responseReceived(CachedResource* resource, const ResourceResponse& response)
 {
+    ASSERT(resource);
+    ASSERT(resource == m_mainResource);
+    ASSERT(m_state != State::Initialized);
     LOG(ContentFiltering, "ContentFilter received response from <%s>.\n", response.url().string().ascii().data());
-    ASSERT(m_state == State::Filtering);
-    ASSERT_UNUSED(resource, resource == m_mainResource.get());
-    forEachContentFilterUntilBlocked([&response](PlatformContentFilter& contentFilter) {
-        contentFilter.responseReceived(response);
-    });
+
+    if (m_state == State::Filtering) {
+        forEachContentFilterUntilBlocked([&response](PlatformContentFilter& contentFilter) {
+            contentFilter.responseReceived(response);
+        });
+    }
+
+    if (m_state != State::Blocked)
+        m_documentLoader.responseReceived(resource, response);
 }
 
 void ContentFilter::dataReceived(CachedResource* resource, const char* data, int length)
 {
+    ASSERT(resource);
+    ASSERT(resource == m_mainResource);
+    ASSERT(m_state != State::Initialized);
     LOG(ContentFiltering, "ContentFilter received %d bytes of data from <%s>.\n", length, resource->url().string().ascii().data());
-    ASSERT(m_state == State::Filtering);
-    ASSERT_UNUSED(resource, resource == m_mainResource.get());
-    forEachContentFilterUntilBlocked([data, length](PlatformContentFilter& contentFilter) {
-        contentFilter.addData(data, length);
-    });
+
+    if (m_state == State::Filtering) {
+        forEachContentFilterUntilBlocked([data, length](PlatformContentFilter& contentFilter) {
+            contentFilter.addData(data, length);
+        });
+
+        if (m_state == State::Allowed)
+            deliverResourceData(*resource);
+        return;
+    }
+
+    if (m_state == State::Allowed)
+        m_documentLoader.dataReceived(resource, data, length);
 }
 
 void ContentFilter::redirectReceived(CachedResource* resource, ResourceRequest& request, const ResourceResponse& redirectResponse)
 {
-    ASSERT(m_state == State::Filtering);
-    ASSERT_UNUSED(resource, resource == m_mainResource.get());
-    willSendRequest(request, redirectResponse);
+    ASSERT(resource);
+    ASSERT(resource == m_mainResource);
+    ASSERT(m_state != State::Initialized);
+
+    if (m_state == State::Filtering)
+        willSendRequest(request, redirectResponse);
+
+    if (m_state != State::Blocked)
+        m_documentLoader.redirectReceived(resource, request, redirectResponse);
 }
 
 void ContentFilter::notifyFinished(CachedResource* resource)
 {
+    ASSERT(resource);
+    ASSERT(resource == m_mainResource);
+    ASSERT(m_state != State::Initialized);
     LOG(ContentFiltering, "ContentFilter will finish filtering main resource at <%s>.\n", resource->url().string().ascii().data());
-    ASSERT(m_state == State::Filtering);
-    ASSERT_UNUSED(resource, resource == m_mainResource.get());
-    forEachContentFilterUntilBlocked([](PlatformContentFilter& contentFilter) {
-        contentFilter.finishedAddingData();
-    });
+
+    if (resource->errorOccurred()) {
+        m_documentLoader.notifyFinished(resource);
+        return;
+    }
+
+    if (m_state == State::Filtering) {
+        forEachContentFilterUntilBlocked([](PlatformContentFilter& contentFilter) {
+            contentFilter.finishedAddingData();
+        });
+
+        if (m_state != State::Blocked)
+            deliverResourceData(*resource);
+    }
+
+    if (m_state != State::Blocked)
+        m_documentLoader.notifyFinished(resource);
 }
 
 void ContentFilter::forEachContentFilterUntilBlocked(std::function<void(PlatformContentFilter&)> function)
@@ -211,10 +253,14 @@ void ContentFilter::didDecide(State state)
     ASSERT(state == State::Allowed || state == State::Blocked);
     LOG(ContentFiltering, "ContentFilter decided load should be %s for main resource at <%s>.\n", state == State::Allowed ? "allowed" : "blocked", m_mainResource ? m_mainResource->url().string().ascii().data() : "");
     m_state = state;
+    m_documentLoader.contentFilterDidDecide();
+}
 
-    // Calling m_decisionFunction might delete |this|.
-    if (m_decisionFunction)
-        m_decisionFunction();
+void ContentFilter::deliverResourceData(CachedResource& resource)
+{
+    ASSERT(resource.dataBufferingPolicy() == BufferData);
+    if (auto* resourceBuffer = resource.resourceBuffer())
+        m_documentLoader.dataReceived(&resource, resourceBuffer->data(), resourceBuffer->size());
 }
 
 } // namespace WebCore