Extend URL filter's Term definition to support groups/subpatterns
[WebKit-https.git] / Source / WebCore / contentextensions / URLFilterParser.cpp
index 26c3f46..4c7d9ce 100644 (file)
@@ -31,6 +31,7 @@
 #include "NFA.h"
 #include <JavaScriptCore/YarrParser.h>
 #include <wtf/BitVector.h>
+#include <wtf/Deque.h>
 
 namespace WebCore {
 
@@ -66,13 +67,20 @@ public:
     }
 
     enum CharacterSetTermTag { CharacterSetTerm };
-    Term(CharacterSetTermTag, bool isInverted)
+    explicit Term(CharacterSetTermTag, bool isInverted)
         : m_termType(TermType::CharacterSet)
     {
         new (NotNull, &m_atomData.characterSet) CharacterSet();
         m_atomData.characterSet.inverted = isInverted;
     }
 
+    enum GroupTermTag { GroupTerm };
+    explicit Term(GroupTermTag)
+        : m_termType(TermType::Group)
+    {
+        new (NotNull, &m_atomData.group) Group();
+    }
+
     Term(const Term& other)
         : m_termType(other.m_termType)
         , m_quantifier(other.m_quantifier)
@@ -84,6 +92,9 @@ public:
         case TermType::CharacterSet:
             new (NotNull, &m_atomData.characterSet) CharacterSet(other.m_atomData.characterSet);
             break;
+        case TermType::Group:
+            new (NotNull, &m_atomData.group) Group(other.m_atomData.group);
+            break;
         }
     }
 
@@ -98,6 +109,9 @@ public:
         case TermType::CharacterSet:
             new (NotNull, &m_atomData.characterSet) CharacterSet(WTF::move(other.m_atomData.characterSet));
             break;
+        case TermType::Group:
+            new (NotNull, &m_atomData.group) Group(WTF::move(other.m_atomData.group));
+            break;
         }
         other.destroy();
     }
@@ -140,39 +154,61 @@ public:
         }
     }
 
+    void extendGroupSubpattern(const Term& term)
+    {
+        ASSERT_WITH_SECURITY_IMPLICATION(m_termType == TermType::Group);
+        if (m_termType != TermType::Group)
+            return;
+        m_atomData.group.terms.append(term);
+    }
+
     void quantify(const AtomQuantifier& quantifier)
     {
         ASSERT_WITH_MESSAGE(m_quantifier == AtomQuantifier::One, "Transition to quantified term should only happen once.");
         m_quantifier = quantifier;
     }
 
-    AtomQuantifier quantifier() const
+    unsigned generateGraph(NFA& nfa, uint64_t patternId, unsigned start) const
     {
-        return m_quantifier;
-    }
+        ASSERT(isValid());
 
-    bool isUniversalTransition() const
-    {
-        return m_termType == TermType::CharacterSet
-            && ((m_atomData.characterSet.inverted && !m_atomData.characterSet.characters.bitCount())
-                || (!m_atomData.characterSet.inverted && m_atomData.characterSet.characters.bitCount() == 128));
-    }
+        switch (m_quantifier) {
+        case AtomQuantifier::One: {
+            unsigned newEnd = generateSubgraphForAtom(nfa, patternId, start);
+            return newEnd;
+        }
+        case AtomQuantifier::ZeroOrOne: {
+            unsigned newEnd = generateSubgraphForAtom(nfa, patternId, start);
+            nfa.addEpsilonTransition(start, newEnd);
+            return newEnd;
+        }
+        case AtomQuantifier::ZeroOrMore: {
+            unsigned repeatStart = nfa.createNode();
+            nfa.addRuleId(repeatStart, patternId);
+            nfa.addEpsilonTransition(start, repeatStart);
 
-    void visitSimpleTransitions(std::function<void(char)> visitor) const
-    {
-        ASSERT_WITH_SECURITY_IMPLICATION(m_termType == TermType::CharacterSet);
-        if (m_termType != TermType::CharacterSet)
-            return;
+            unsigned repeatEnd = generateSubgraphForAtom(nfa, patternId, repeatStart);
+            nfa.addEpsilonTransition(repeatEnd, repeatStart);
 
-        if (!m_atomData.characterSet.inverted) {
-            for (const auto& characterIterator : m_atomData.characterSet.characters.setBits())
-                visitor(static_cast<char>(characterIterator));
-        } else {
-            for (unsigned i = 1; i < m_atomData.characterSet.characters.size(); ++i) {
-                if (m_atomData.characterSet.characters.get(i))
-                    continue;
-                visitor(static_cast<char>(i));
-            }
+            unsigned kleenEnd = nfa.createNode();
+            nfa.addRuleId(kleenEnd, patternId);
+            nfa.addEpsilonTransition(repeatEnd, kleenEnd);
+            nfa.addEpsilonTransition(start, kleenEnd);
+            return kleenEnd;
+        }
+        case AtomQuantifier::OneOrMore: {
+            unsigned repeatStart = nfa.createNode();
+            nfa.addRuleId(repeatStart, patternId);
+            nfa.addEpsilonTransition(start, repeatStart);
+
+            unsigned repeatEnd = generateSubgraphForAtom(nfa, patternId, repeatStart);
+            nfa.addEpsilonTransition(repeatEnd, repeatStart);
+
+            unsigned afterRepeat = nfa.createNode();
+            nfa.addRuleId(afterRepeat, patternId);
+            nfa.addEpsilonTransition(repeatEnd, afterRepeat);
+            return afterRepeat;
+        }
         }
     }
 
@@ -201,6 +237,8 @@ public:
             return true;
         case TermType::CharacterSet:
             return m_atomData.characterSet == other.m_atomData.characterSet;
+        case TermType::Group:
+            return m_atomData.group == other.m_atomData.group;
         }
         ASSERT_NOT_REACHED();
         return false;
@@ -220,6 +258,9 @@ public:
         case TermType::CharacterSet:
             secondary = m_atomData.characterSet.hash();
             break;
+        case TermType::Group:
+            secondary = m_atomData.group.hash();
+            break;
         }
         return WTF::pairIntHash(primary, secondary);
     }
@@ -235,6 +276,48 @@ public:
     }
 
 private:
+    bool isUniversalTransition() const
+    {
+        return m_termType == TermType::CharacterSet
+            && ((m_atomData.characterSet.inverted && !m_atomData.characterSet.characters.bitCount())
+                || (!m_atomData.characterSet.inverted && m_atomData.characterSet.characters.bitCount() == 128));
+    }
+
+    unsigned generateSubgraphForAtom(NFA& nfa, uint64_t patternId, unsigned source) const
+    {
+        switch (m_termType) {
+        case TermType::Empty:
+        case TermType::Deleted:
+            ASSERT_NOT_REACHED();
+            return -1;
+        case TermType::CharacterSet: {
+            unsigned target = nfa.createNode();
+            nfa.addRuleId(target, patternId);
+            if (isUniversalTransition())
+                nfa.addTransitionsOnAnyCharacter(source, target);
+            else {
+                if (!m_atomData.characterSet.inverted) {
+                    for (const auto& characterIterator : m_atomData.characterSet.characters.setBits())
+                        nfa.addTransition(source, target, static_cast<char>(characterIterator));
+                } else {
+                    for (unsigned i = 1; i < m_atomData.characterSet.characters.size(); ++i) {
+                        if (m_atomData.characterSet.characters.get(i))
+                            continue;
+                        nfa.addTransition(source, target, static_cast<char>(i));
+                    }
+                }
+            }
+            return target;
+        }
+        case TermType::Group: {
+            unsigned lastTarget = source;
+            for (const Term& term : m_atomData.group.terms)
+                lastTarget = term.generateGraph(nfa, patternId, lastTarget);
+            return lastTarget;
+        }
+        }
+    }
+
     void destroy()
     {
         switch (m_termType) {
@@ -244,6 +327,9 @@ private:
         case TermType::CharacterSet:
             m_atomData.characterSet.~CharacterSet();
             break;
+        case TermType::Group:
+            m_atomData.group.~Group();
+            break;
         }
         m_termType = TermType::Deleted;
     }
@@ -251,7 +337,8 @@ private:
     enum class TermType : uint8_t {
         Empty,
         Deleted,
-        CharacterSet
+        CharacterSet,
+        Group
     };
 
     TermType m_termType { TermType::Empty };
@@ -272,6 +359,26 @@ private:
         }
     };
 
+    struct Group {
+        Vector<Term> terms;
+
+        bool operator==(const Group& other) const
+        {
+            return other.terms == terms;
+        }
+
+        unsigned hash() const
+        {
+            unsigned hash = 6421749;
+            for (const Term& term : terms) {
+                unsigned termHash = term.hash();
+                hash = (hash << 16) ^ ((termHash << 11) ^ hash);
+                hash += hash >> 11;
+            }
+            return hash;
+        }
+    };
+
     union AtomData {
         AtomData()
             : invalidTerm(0)
@@ -283,6 +390,7 @@ private:
 
         char invalidTerm;
         CharacterSet characterSet;
+        Group group;
     } m_atomData;
 };
 
@@ -318,6 +426,11 @@ public:
 
         sinkFloatingTermIfNecessary();
 
+        if (!m_openGroups.isEmpty()) {
+            fail(ASCIILiteral("The expression has unclosed groups."));
+            return;
+        }
+
         if (m_subtreeStart != m_subtreeEnd)
             m_nfa.setFinal(m_subtreeEnd, m_patternId);
         else
@@ -448,7 +561,12 @@ public:
 
     void atomParenthesesSubpatternBegin(bool = true)
     {
-        fail(ASCIILiteral("Groups are not supported yet."));
+        if (hasError())
+            return;
+
+        sinkFloatingTermIfNecessary();
+
+        m_openGroups.append(Term(Term::GroupTerm));
     }
 
     void atomParentheticalAssertionBegin(bool = false)
@@ -458,7 +576,13 @@ public:
 
     void atomParenthesesEnd()
     {
-        fail(ASCIILiteral("Groups are not supported yet."));
+        if (hasError())
+            return;
+
+        sinkFloatingTermIfNecessary();
+        ASSERT(!m_floatingTerm.isValid());
+
+        m_floatingTerm = m_openGroups.takeLast();
     }
 
     void disjunction()
@@ -483,68 +607,6 @@ private:
         m_errorMessage = errorMessage;
     }
 
-    void addTransitions(unsigned source, unsigned target)
-    {
-        auto visitor = [this, source, target](char character) {
-            if (m_floatingTerm.isUniversalTransition())
-                m_nfa.addTransitionsOnAnyCharacter(source, target);
-            else
-                m_nfa.addTransition(source, target, character);
-        };
-        m_floatingTerm.visitSimpleTransitions(visitor);
-    }
-
-    unsigned sinkFloatingTerm(unsigned start)
-    {
-        switch (m_floatingTerm.quantifier()) {
-        case AtomQuantifier::One: {
-            unsigned newEnd = m_nfa.createNode();
-            m_nfa.addRuleId(newEnd, m_patternId);
-            addTransitions(start, newEnd);
-            return newEnd;
-        }
-        case AtomQuantifier::ZeroOrOne: {
-            unsigned newEnd = m_nfa.createNode();
-            m_nfa.addRuleId(newEnd, m_patternId);
-            addTransitions(start, newEnd);
-            return newEnd;
-        }
-        case AtomQuantifier::ZeroOrMore: {
-            unsigned repeatStart = m_nfa.createNode();
-            m_nfa.addRuleId(repeatStart, m_patternId);
-            unsigned repeatEnd = m_nfa.createNode();
-            m_nfa.addRuleId(repeatEnd, m_patternId);
-
-            addTransitions(repeatStart, repeatEnd);
-            m_nfa.addEpsilonTransition(repeatEnd, repeatStart);
-
-            m_nfa.addEpsilonTransition(start, repeatStart);
-
-            unsigned kleenEnd = m_nfa.createNode();
-            m_nfa.addRuleId(kleenEnd, m_patternId);
-            m_nfa.addEpsilonTransition(repeatEnd, kleenEnd);
-            m_nfa.addEpsilonTransition(start, kleenEnd);
-            return kleenEnd;
-        }
-        case AtomQuantifier::OneOrMore: {
-            unsigned repeatStart = m_nfa.createNode();
-            m_nfa.addRuleId(repeatStart, m_patternId);
-            unsigned repeatEnd = m_nfa.createNode();
-            m_nfa.addRuleId(repeatEnd, m_patternId);
-
-            addTransitions(repeatStart, repeatEnd);
-            m_nfa.addEpsilonTransition(repeatEnd, repeatStart);
-
-            m_nfa.addEpsilonTransition(start, repeatStart);
-
-            unsigned afterRepeat = m_nfa.createNode();
-            m_nfa.addRuleId(afterRepeat, m_patternId);
-            m_nfa.addEpsilonTransition(repeatEnd, afterRepeat);
-            return afterRepeat;
-        }
-        }
-    }
-
     void sinkFloatingTermIfNecessary()
     {
         if (!m_floatingTerm.isValid())
@@ -552,6 +614,12 @@ private:
 
         ASSERT(m_lastPrefixTreeEntry);
 
+        if (!m_openGroups.isEmpty()) {
+            m_openGroups.last().extendGroupSubpattern(m_floatingTerm);
+            m_floatingTerm = Term();
+            return;
+        }
+
         auto nextEntry = m_lastPrefixTreeEntry->nextPattern.find(m_floatingTerm);
         if (nextEntry != m_lastPrefixTreeEntry->nextPattern.end()) {
             m_lastPrefixTreeEntry = nextEntry->value.get();
@@ -559,7 +627,7 @@ private:
         } else {
             std::unique_ptr<PrefixTreeEntry> nextPrefixTreeEntry = std::make_unique<PrefixTreeEntry>();
 
-            unsigned newEnd = sinkFloatingTerm(m_lastPrefixTreeEntry->nfaNode);
+            unsigned newEnd = m_floatingTerm.generateGraph(m_nfa, m_patternId, m_lastPrefixTreeEntry->nfaNode);
             nextPrefixTreeEntry->nfaNode = newEnd;
 
             auto addResult = m_lastPrefixTreeEntry->nextPattern.set(m_floatingTerm, WTF::move(nextPrefixTreeEntry));
@@ -586,6 +654,7 @@ private:
     unsigned m_subtreeEnd { 0 };
 
     PrefixTreeEntry* m_lastPrefixTreeEntry;
+    Deque<Term> m_openGroups;
     Term m_floatingTerm;
 
     PrefixTreeEntry* m_newPrefixSubtreeRoot = nullptr;