Extend URL filter's Term definition to support groups/subpatterns
[WebKit-https.git] / Source / WebCore / contentextensions / URLFilterParser.cpp
index 38841be..4c7d9ce 100644 (file)
 
 #include "NFA.h"
 #include <JavaScriptCore/YarrParser.h>
+#include <wtf/BitVector.h>
+#include <wtf/Deque.h>
 
 namespace WebCore {
 
 namespace ContentExtensions {
 
-class GraphBuilder {
+enum class AtomQuantifier : uint8_t { // How many times the atom of a Term may be repeated.
+    One, // Exactly once; every Term starts in this state until quantify() is called.
+    ZeroOrOne, // The atom is optional.
+    ZeroOrMore, // The atom may be skipped or repeated any number of times.
+    OneOrMore // The atom appears at least once and may repeat.
+};
+
+// A Term is one element of a URL filter pattern: an atom (a set of ASCII characters, or a group
+// holding a sequence of sub-terms) together with a quantifier. Terms can be copied, compared and
+// hashed. The Empty and Deleted types carry no payload; they are exposed through
+// isEmptyValue()/isDeletedValue() so that a Term can mark unused hash table slots.
+class Term {
+public:
+    // Creates an Empty term, for which isValid() is false.
+    Term()
+    {
+    }
+
+    // Creates a character set matching |character|. When |isCaseSensitive| is false and the
+    // character is an ASCII letter, both its upper and lower case forms are included.
+    Term(char character, bool isCaseSensitive)
+        : m_termType(TermType::CharacterSet)
+    {
+        new (NotNull, &m_atomData.characterSet) CharacterSet();
+        addCharacter(character, isCaseSensitive);
+    }
+
+    // Creates a character set with all 128 ASCII code points set.
+    enum UniversalTransitionTag { UniversalTransition };
+    explicit Term(UniversalTransitionTag)
+        : m_termType(TermType::CharacterSet)
+    {
+        new (NotNull, &m_atomData.characterSet) CharacterSet();
+        for (unsigned i = 0; i < 128; ++i)
+            m_atomData.characterSet.characters.set(i);
+    }
+
+    // Creates an empty character set to be filled with addCharacter(). When |isInverted| is true
+    // the set matches the characters that are NOT added to it.
+    enum CharacterSetTermTag { CharacterSetTerm };
+    explicit Term(CharacterSetTermTag, bool isInverted)
+        : m_termType(TermType::CharacterSet)
+    {
+        new (NotNull, &m_atomData.characterSet) CharacterSet();
+        m_atomData.characterSet.inverted = isInverted;
+    }
+
+    // Creates an empty group to be filled with extendGroupSubpattern().
+    enum GroupTermTag { GroupTerm };
+    explicit Term(GroupTermTag)
+        : m_termType(TermType::Group)
+    {
+        new (NotNull, &m_atomData.group) Group();
+    }
+
+    // Copies the type, the quantifier and whichever union member is active in |other|.
+    Term(const Term& other)
+        : m_termType(other.m_termType)
+        , m_quantifier(other.m_quantifier)
+    {
+        switch (m_termType) {
+        case TermType::Empty:
+        case TermType::Deleted:
+            break;
+        case TermType::CharacterSet:
+            new (NotNull, &m_atomData.characterSet) CharacterSet(other.m_atomData.characterSet);
+            break;
+        case TermType::Group:
+            new (NotNull, &m_atomData.group) Group(other.m_atomData.group);
+            break;
+        }
+    }
+
+    // Moves the payload out of |other| and leaves |other| in the Deleted state.
+    Term(Term&& other)
+        : m_termType(WTF::move(other.m_termType))
+        , m_quantifier(WTF::move(other.m_quantifier))
+    {
+        switch (m_termType) {
+        case TermType::Empty:
+        case TermType::Deleted:
+            break;
+        case TermType::CharacterSet:
+            new (NotNull, &m_atomData.characterSet) CharacterSet(WTF::move(other.m_atomData.characterSet));
+            break;
+        case TermType::Group:
+            new (NotNull, &m_atomData.group) Group(WTF::move(other.m_atomData.group));
+            break;
+        }
+        other.destroy();
+    }
+
+    // Creates the payload-free term for which isEmptyValue() is true.
+    enum EmptyValueTag { EmptyValue };
+    Term(EmptyValueTag)
+        : m_termType(TermType::Empty)
+    {
+    }
+
+    // Creates the payload-free term for which isDeletedValue() is true.
+    enum DeletedValueTag { DeletedValue };
+    Term(DeletedValueTag)
+        : m_termType(TermType::Deleted)
+    {
+    }
+
+    ~Term()
+    {
+        destroy();
+    }
+
+    // A term is valid when it holds an atom, i.e. when it is neither Empty nor Deleted.
+    bool isValid() const
+    {
+        return m_termType != TermType::Empty && m_termType != TermType::Deleted;
+    }
+
+    // Adds |character| to a character set term; see Term(char, bool) for the handling of case.
+    // Calling this on any other kind of term asserts and is otherwise ignored.
+    void addCharacter(UChar character, bool isCaseSensitive)
+    {
+        ASSERT(isASCII(character));
+
+        ASSERT_WITH_SECURITY_IMPLICATION(m_termType == TermType::CharacterSet);
+        if (m_termType != TermType::CharacterSet)
+            return;
+
+        if (isCaseSensitive || !isASCIIAlpha(character))
+            m_atomData.characterSet.characters.set(character);
+        else {
+            m_atomData.characterSet.characters.set(toASCIIUpper(character));
+            m_atomData.characterSet.characters.set(toASCIILower(character));
+        }
+    }
+
+    // Appends a copy of |term| to the sequence matched by a group term.
+    // Calling this on any other kind of term asserts and is otherwise ignored.
+    void extendGroupSubpattern(const Term& term)
+    {
+        ASSERT_WITH_SECURITY_IMPLICATION(m_termType == TermType::Group);
+        if (m_termType != TermType::Group)
+            return;
+        m_atomData.group.terms.append(term);
+    }
+
+    // Sets the quantifier. A term is expected to be quantified at most once.
+    void quantify(const AtomQuantifier& quantifier)
+    {
+        ASSERT_WITH_MESSAGE(m_quantifier == AtomQuantifier::One, "Transition to quantified term should only happen once.");
+        m_quantifier = quantifier;
+    }
+
+    // Appends the NFA fragment for this term (atom plus quantifier) after node |start| and returns
+    // the node from which the next term should continue. Nodes created here are tagged with |patternId|.
+    unsigned generateGraph(NFA& nfa, uint64_t patternId, unsigned start) const
+    {
+        ASSERT(isValid());
+
+        switch (m_quantifier) {
+        case AtomQuantifier::One: {
+            unsigned newEnd = generateSubgraphForAtom(nfa, patternId, start);
+            return newEnd;
+        }
+        case AtomQuantifier::ZeroOrOne: {
+            unsigned newEnd = generateSubgraphForAtom(nfa, patternId, start);
+            // The epsilon transition lets the atom be skipped.
+            nfa.addEpsilonTransition(start, newEnd);
+            return newEnd;
+        }
+        case AtomQuantifier::ZeroOrMore: {
+            // The atom is built between two fresh nodes so that the loop back edge stays local to it.
+            unsigned repeatStart = nfa.createNode();
+            nfa.addRuleId(repeatStart, patternId);
+            nfa.addEpsilonTransition(start, repeatStart);
+
+            unsigned repeatEnd = generateSubgraphForAtom(nfa, patternId, repeatStart);
+            nfa.addEpsilonTransition(repeatEnd, repeatStart);
+
+            unsigned kleeneEnd = nfa.createNode();
+            nfa.addRuleId(kleeneEnd, patternId);
+            nfa.addEpsilonTransition(repeatEnd, kleeneEnd);
+            // Zero repetitions: go straight from |start| to the end node.
+            nfa.addEpsilonTransition(start, kleeneEnd);
+            return kleeneEnd;
+        }
+        case AtomQuantifier::OneOrMore: {
+            // Same shape as ZeroOrMore, minus the edge that skips the atom entirely.
+            unsigned repeatStart = nfa.createNode();
+            nfa.addRuleId(repeatStart, patternId);
+            nfa.addEpsilonTransition(start, repeatStart);
+
+            unsigned repeatEnd = generateSubgraphForAtom(nfa, patternId, repeatStart);
+            nfa.addEpsilonTransition(repeatEnd, repeatStart);
+
+            unsigned afterRepeat = nfa.createNode();
+            nfa.addRuleId(afterRepeat, patternId);
+            nfa.addEpsilonTransition(repeatEnd, afterRepeat);
+            return afterRepeat;
+        }
+        }
+
+        // Only reachable if m_quantifier holds a value outside of AtomQuantifier. Returning |start|
+        // keeps the function from running off its end without a return value.
+        ASSERT_NOT_REACHED();
+        return start;
+    }
+
+    Term& operator=(const Term& other)
+    {
+        // destroy() would tear down the very payload the copy constructor is about to read.
+        if (this == &other)
+            return *this;
+
+        destroy();
+        new (NotNull, this) Term(other);
+        return *this;
+    }
+
+    Term& operator=(Term&& other)
+    {
+        // Same hazard as the copy assignment: a self-move would leave the term Deleted.
+        if (this == &other)
+            return *this;
+
+        destroy();
+        new (NotNull, this) Term(WTF::move(other));
+        return *this;
+    }
+
+    // Two terms are equal when type, quantifier and payload all match. Empty and Deleted terms
+    // are compared without touching the union.
+    bool operator==(const Term& other) const
+    {
+        if (other.m_termType != m_termType || other.m_quantifier != m_quantifier)
+            return false;
+
+        switch (m_termType) {
+        case TermType::Empty:
+        case TermType::Deleted:
+            return true;
+        case TermType::CharacterSet:
+            return m_atomData.characterSet == other.m_atomData.characterSet;
+        case TermType::Group:
+            return m_atomData.group == other.m_atomData.group;
+        }
+        ASSERT_NOT_REACHED();
+        return false;
+    }
+
+    // Combines the type and quantifier with a hash of the payload (a fixed constant for the
+    // payload-free Empty and Deleted terms).
+    unsigned hash() const
+    {
+        unsigned primary = static_cast<unsigned>(m_termType) << 16 | static_cast<unsigned>(m_quantifier);
+        unsigned secondary = 0;
+        switch (m_termType) {
+        case TermType::Empty:
+            secondary = 52184393;
+            break;
+        case TermType::Deleted:
+            secondary = 40342988;
+            break;
+        case TermType::CharacterSet:
+            secondary = m_atomData.characterSet.hash();
+            break;
+        case TermType::Group:
+            secondary = m_atomData.group.hash();
+            break;
+        }
+        return WTF::pairIntHash(primary, secondary);
+    }
+
+    bool isEmptyValue() const
+    {
+        return m_termType == TermType::Empty;
+    }
+
+    bool isDeletedValue() const
+    {
+        return m_termType == TermType::Deleted;
+    }
+
 private:
-    struct BoundedSubGraph {
-        unsigned start;
-        unsigned end;
+    // True for a character set that covers everything: an inverted set with no bit set, or a
+    // non-inverted set with all 128 bits set.
+    bool isUniversalTransition() const
+    {
+        return m_termType == TermType::CharacterSet
+            && ((m_atomData.characterSet.inverted && !m_atomData.characterSet.characters.bitCount())
+                || (!m_atomData.characterSet.inverted && m_atomData.characterSet.characters.bitCount() == 128));
+    }
+
+    // Builds the NFA fragment for a single occurrence of the atom, ignoring the quantifier, and
+    // returns the node reached once the atom has been matched from |source|.
+    unsigned generateSubgraphForAtom(NFA& nfa, uint64_t patternId, unsigned source) const
+    {
+        switch (m_termType) {
+        case TermType::Empty:
+        case TermType::Deleted:
+            // Not an atom; handled by the assertion after the switch.
+            break;
+        case TermType::CharacterSet: {
+            unsigned target = nfa.createNode();
+            nfa.addRuleId(target, patternId);
+            if (isUniversalTransition())
+                nfa.addTransitionsOnAnyCharacter(source, target);
+            else {
+                if (!m_atomData.characterSet.inverted) {
+                    for (const auto& characterIterator : m_atomData.characterSet.characters.setBits())
+                        nfa.addTransition(source, target, static_cast<char>(characterIterator));
+                } else {
+                    // The loop starts at 1, so no transition is ever added for character 0.
+                    for (unsigned i = 1; i < m_atomData.characterSet.characters.size(); ++i) {
+                        if (m_atomData.characterSet.characters.get(i))
+                            continue;
+                        nfa.addTransition(source, target, static_cast<char>(i));
+                    }
+                }
+            }
+            return target;
+        }
+        case TermType::Group: {
+            // A group is the concatenation of its sub-terms: each one starts where the previous ended.
+            unsigned lastTarget = source;
+            for (const Term& term : m_atomData.group.terms)
+                lastTarget = term.generateGraph(nfa, patternId, lastTarget);
+            return lastTarget;
+        }
+        }
+
+        // Reached for Empty/Deleted terms and for any out-of-range m_termType. The value returned
+        // for Empty/Deleted is unchanged (-1 converted to unsigned); it is not a usable node index.
+        ASSERT_NOT_REACHED();
+        return -1;
+    }
+
+    // Runs the destructor of the active union member, then marks the term Deleted so that a
+    // later call finds nothing left to destroy.
+    void destroy()
+    {
+        switch (m_termType) {
+        case TermType::Empty:
+        case TermType::Deleted:
+            break;
+        case TermType::CharacterSet:
+            m_atomData.characterSet.~CharacterSet();
+            break;
+        case TermType::Group:
+            m_atomData.group.~Group();
+            break;
+        }
+        m_termType = TermType::Deleted;
+    }
+
+    // Selects the active member of m_atomData. Empty and Deleted have no payload.
+    enum class TermType : uint8_t {
+        Empty,
+        Deleted,
+        CharacterSet,
+        Group
+    };
+
+    TermType m_termType { TermType::Empty };
+    AtomQuantifier m_quantifier { AtomQuantifier::One };
+
+    // Payload of a CharacterSet term: one bit per ASCII code point, optionally inverted.
+    struct CharacterSet {
+        bool inverted { false };
+        BitVector characters { 128 };
+
+        bool operator==(const CharacterSet& other) const
+        {
+            return other.inverted == inverted && other.characters == characters;
+        }
+
+        unsigned hash() const
+        {
+            return WTF::pairIntHash(inverted, characters.hash());
+        }
     };
+
+    // Payload of a Group term: the sub-terms, in matching order.
+    struct Group {
+        Vector<Term> terms;
+
+        bool operator==(const Group& other) const
+        {
+            return other.terms == terms;
+        }
+
+        // Folds the hash of every sub-term, in order, into a running value.
+        unsigned hash() const
+        {
+            unsigned hash = 6421749;
+            for (const Term& term : terms) {
+                unsigned termHash = term.hash();
+                hash = (hash << 16) ^ ((termHash << 11) ^ hash);
+                hash += hash >> 11;
+            }
+            return hash;
+        }
+    };
+
+    // Storage for the payload. Only invalidTerm is initialized here; characterSet and group are
+    // constructed with placement new by Term's constructors and destroyed by Term::destroy(),
+    // which is why the union's own destructor does nothing.
+    union AtomData {
+        AtomData()
+            : invalidTerm(0)
+        {
+        }
+        ~AtomData()
+        {
+        }
+
+        char invalidTerm;
+        CharacterSet characterSet;
+        Group group;
+    } m_atomData;
+};
+
+struct TermHash { // Hash functions for tables keyed by Term; forwards to Term::hash() and Term::operator==.
+    static unsigned hash(const Term& term) { return term.hash(); }
+    static bool equal(const Term& a, const Term& b) { return a == b; }
+    static const bool safeToCompareToEmptyOrDeleted = true; // Term::operator== compares Empty/Deleted terms without reading any payload.
+};
+
+struct TermHashTraits : public WTF::CustomHashTraits<Term> { }; // NOTE(review): presumably relies on Term's EmptyValue/DeletedValue constructors and isEmptyValue()/isDeletedValue() - confirm against WTF::CustomHashTraits.
+
+struct PrefixTreeEntry { // One node of the tree of pattern prefixes that have already been turned into NFA nodes.
+    unsigned nfaNode; // NFA node reached once every term on the path to this entry has been matched.
+    HashMap<Term, std::unique_ptr<PrefixTreeEntry>, TermHash, TermHashTraits> nextPattern; // Child entries, keyed by the term that follows this prefix.
+};
+
+class GraphBuilder {
 public:
-    GraphBuilder(NFA& nfa, uint64_t patternId)
+    GraphBuilder(NFA& nfa, PrefixTreeEntry& prefixTreeRoot, bool patternIsCaseSensitive, uint64_t patternId) // Builds one pattern into |nfa|, reusing prefixes recorded under |prefixTreeRoot|.
         : m_nfa(nfa)
+        , m_patternIsCaseSensitive(patternIsCaseSensitive)
         , m_patternId(patternId)
-        , m_activeGroup({ nfa.root(), nfa.root() })
-        , m_lastAtom(m_activeGroup)
+        , m_subtreeStart(nfa.root()) // Start and end both begin at the NFA root; finalize() fails the pattern if they are still equal.
+        , m_subtreeEnd(nfa.root())
+        , m_lastPrefixTreeEntry(&prefixTreeRoot) // Cursor into the prefix tree; it advances each time a term is sunk.
     {
     }
 
@@ -54,8 +423,16 @@ public:
     {
         if (hasError())
             return;
-        if (m_activeGroup.start != m_activeGroup.end)
-            m_nfa.setFinal(m_activeGroup.end);
+
+        sinkFloatingTermIfNecessary();
+
+        if (!m_openGroups.isEmpty()) {
+            fail(ASCIILiteral("The expression has unclosed groups."));
+            return;
+        }
+
+        if (m_subtreeStart != m_subtreeEnd)
+            m_nfa.setFinal(m_subtreeEnd, m_patternId);
         else
             fail(ASCIILiteral("The pattern cannot match anything."));
     }
@@ -67,20 +444,19 @@ public:
 
     void atomPatternCharacter(UChar character)
     {
+        if (hasError())
+            return;
+
         if (!isASCII(character)) {
             fail(ASCIILiteral("Only ASCII characters are supported in pattern."));
             return;
         }
 
-        if (hasError())
-            return;
+        sinkFloatingTermIfNecessary();
+        ASSERT(!m_floatingTerm.isValid());
 
-        m_hasValidAtom = true;
-        unsigned newEnd = m_nfa.createNode(m_patternId);
-        m_nfa.addTransition(m_lastAtom.end, newEnd, static_cast<char>(character));
-        m_lastAtom.start = m_lastAtom.end;
-        m_lastAtom.end = newEnd;
-        m_activeGroup.end = m_lastAtom.end;
+        char asciiChararacter = static_cast<char>(character);
+        m_floatingTerm = Term(asciiChararacter, m_patternIsCaseSensitive);
     }
 
     void atomBuiltInCharacterClass(JSC::Yarr::BuiltInCharacterClassID builtInCharacterClassID, bool inverted)
@@ -88,16 +464,12 @@ public:
         if (hasError())
             return;
 
-        if (builtInCharacterClassID == JSC::Yarr::NewlineClassID && inverted) {
-            // FIXME: handle new line properly.
-            m_hasValidAtom = true;
-            unsigned newEnd = m_nfa.createNode(m_patternId);
-            for (unsigned i = 1; i < 128; ++i)
-                m_nfa.addTransition(m_lastAtom.end, newEnd, i);
-            m_lastAtom.start = m_lastAtom.end;
-            m_lastAtom.end = newEnd;
-            m_activeGroup.end = m_lastAtom.end;
-        } else
+        sinkFloatingTermIfNecessary();
+        ASSERT(!m_floatingTerm.isValid());
+
+        if (builtInCharacterClassID == JSC::Yarr::NewlineClassID && inverted)
+            m_floatingTerm = Term(Term::UniversalTransition);
+        else
             fail(ASCIILiteral("Character class is not supported."));
     }
 
@@ -106,32 +478,22 @@ public:
         if (hasError())
             return;
 
-        ASSERT(m_hasValidAtom);
-        if (!m_hasValidAtom) {
-            fail(ASCIILiteral("Quantifier without corresponding atom to quantify."));
-            return;
-        }
+        if (!m_floatingTerm.isValid()) {
+            // Stop after reporting the error: without the return, the code below goes on to
+            // quantify a term that holds no atom.
+            fail(ASCIILiteral("Quantifier without corresponding term to quantify."));
+            return;
+        }
 
         if (!minimum && maximum == 1)
-            m_nfa.addEpsilonTransition(m_lastAtom.start, m_lastAtom.end);
-        else if (!minimum && maximum == JSC::Yarr::quantifyInfinite) {
-            m_nfa.addEpsilonTransition(m_lastAtom.start, m_lastAtom.end);
-            m_nfa.addEpsilonTransition(m_lastAtom.end, m_lastAtom.start);
-        } else if (minimum == 1 && maximum == JSC::Yarr::quantifyInfinite)
-            m_nfa.addEpsilonTransition(m_lastAtom.end, m_lastAtom.start);
+            m_floatingTerm.quantify(AtomQuantifier::ZeroOrOne);
+        else if (!minimum && maximum == JSC::Yarr::quantifyInfinite)
+            m_floatingTerm.quantify(AtomQuantifier::ZeroOrMore);
+        else if (minimum == 1 && maximum == JSC::Yarr::quantifyInfinite)
+            m_floatingTerm.quantify(AtomQuantifier::OneOrMore);
         else
             fail(ASCIILiteral("Arbitrary atom repetitions are not supported."));
     }
 
-    NO_RETURN_DUE_TO_ASSERT void atomBackReference(unsigned)
+    void atomBackReference(unsigned) // Backreferences are rejected: the pattern fails with an error message instead of asserting.
     {
         fail(ASCIILiteral("Patterns cannot contain backreferences."));
-        ASSERT_NOT_REACHED();
-    }
-
-    void atomCharacterClassAtom(UChar)
-    {
-        fail(ASCIILiteral("Character class atoms are not supported yet."));
     }
 
     void assertionBOL()
@@ -149,29 +511,62 @@ public:
         fail(ASCIILiteral("Word boundaries assertions are not supported yet."));
     }
 
-    void atomCharacterClassBegin(bool = false)
+    void atomCharacterClassBegin(bool inverted = false) // Starts a new, initially empty character set as the floating term.
     {
-        fail(ASCIILiteral("Character class atoms are not supported yet."));
+        if (hasError())
+            return;
+
+        sinkFloatingTermIfNecessary(); // Flush the previous term before the floating slot is reused.
+        ASSERT(!m_floatingTerm.isValid());
+
+        m_floatingTerm = Term(Term::CharacterSetTerm, inverted);
     }
 
-    void atomCharacterClassRange(UChar, UChar)
+    void atomCharacterClassAtom(UChar character) // Adds one character to the floating character set; non-ASCII characters fail the pattern.
     {
-        fail(ASCIILiteral("Character class ranges are not supported yet."));
+        if (hasError())
+            return;
+
+        if (!isASCII(character)) {
+            fail(ASCIILiteral("Non ASCII Character in a character set."));
+            return;
+        }
+
+        m_floatingTerm.addCharacter(character, m_patternIsCaseSensitive); // addCharacter() asserts that the floating term is a character set.
     }
 
-    void atomCharacterClassBuiltIn(JSC::Yarr::BuiltInCharacterClassID, bool)
+    void atomCharacterClassRange(UChar a, UChar b) // Adds every character from |a| to |b| inclusive to the floating character set.
     {
-        fail(ASCIILiteral("Buildins character class atoms are not supported yet."));
+        if (hasError())
+            return;
+
+        if (!a || !b || !isASCII(a) || !isASCII(b)) { // A zero bound is rejected with the same message as a non-ASCII bound.
+            fail(ASCIILiteral("Non ASCII Character in a character range of a character set."));
+            return;
+        }
+
+        for (unsigned i = a; i <= b; ++i) // The counter is an unsigned, wider than UChar, and is narrowed back for each call.
+            m_floatingTerm.addCharacter(static_cast<UChar>(i), m_patternIsCaseSensitive);
     }
 
     void atomCharacterClassEnd()
     {
-        fail(ASCIILiteral("Character class are not supported yet."));
+        // Nothing to do here: the character set stays floating because a quantifier may still follow it.
+        // It is sunk lazily, by whichever callback next calls sinkFloatingTermIfNecessary().
+    }
+
+    void atomCharacterClassBuiltIn(JSC::Yarr::BuiltInCharacterClassID, bool) // Built-in classes inside a character set are not supported: always fails the pattern.
+    {
+        fail(ASCIILiteral("Builtins character class atoms are not supported yet."));
     }
 
     void atomParenthesesSubpatternBegin(bool = true)
     {
-        fail(ASCIILiteral("Groups are not supported yet."));
+        if (hasError())
+            return;
+
+        sinkFloatingTermIfNecessary(); // The pending term belongs before the group, so it is sunk before the group opens.
+
+        m_openGroups.append(Term(Term::GroupTerm)); // While the group is open, sunk terms are appended to it rather than to the prefix tree.
     }
 
     void atomParentheticalAssertionBegin(bool = false)
@@ -181,7 +576,13 @@ public:
 
     void atomParenthesesEnd()
     {
-        fail(ASCIILiteral("Groups are not supported yet."));
+        if (hasError())
+            return;
+
+        sinkFloatingTermIfNecessary();
+        ASSERT(!m_floatingTerm.isValid());
+
+        m_floatingTerm = m_openGroups.takeLast();
     }
 
     void disjunction()
@@ -189,7 +590,6 @@ public:
         fail(ASCIILiteral("Disjunctions are not supported yet."));
     }
 
-
 private:
     bool hasError() const
     {
@@ -200,43 +600,105 @@ private:
     {
         if (hasError())
             return;
+
+        if (m_newPrefixSubtreeRoot)
+            m_newPrefixSubtreeRoot->nextPattern.remove(m_newPrefixStaringPoint);
+
         m_errorMessage = errorMessage;
     }
 
+    // Commits the floating term, if there is one. Inside an open group the term is appended to
+    // that group. Otherwise it is looked up under the current prefix tree entry: a known term
+    // reuses the existing NFA node, an unknown one gets new NFA nodes and a new tree entry.
+    void sinkFloatingTermIfNecessary()
+    {
+        if (!m_floatingTerm.isValid())
+            return;
+
+        ASSERT(m_lastPrefixTreeEntry);
+
+        if (!m_openGroups.isEmpty()) {
+            m_openGroups.last().extendGroupSubpattern(m_floatingTerm);
+            m_floatingTerm = Term();
+            return;
+        }
+
+        auto& knownSuccessors = m_lastPrefixTreeEntry->nextPattern;
+        auto existingEntry = knownSuccessors.find(m_floatingTerm);
+        if (existingEntry == knownSuccessors.end()) {
+            // First pattern with this prefix: generate the graph, then record it in the tree.
+            auto freshEntry = std::make_unique<PrefixTreeEntry>();
+            freshEntry->nfaNode = m_floatingTerm.generateGraph(m_nfa, m_patternId, m_lastPrefixTreeEntry->nfaNode);
+
+            auto addResult = knownSuccessors.set(m_floatingTerm, WTF::move(freshEntry));
+            ASSERT(addResult.isNewEntry);
+
+            // Remember where this pattern first left the existing tree; fail() removes the
+            // subtree from that point.
+            if (!m_newPrefixSubtreeRoot) {
+                m_newPrefixSubtreeRoot = m_lastPrefixTreeEntry;
+                m_newPrefixStaringPoint = m_floatingTerm;
+            }
+
+            m_lastPrefixTreeEntry = addResult.iterator->value.get();
+        } else {
+            // The prefix already exists: share its node and tag it with this pattern as well.
+            m_lastPrefixTreeEntry = existingEntry->value.get();
+            m_nfa.addRuleId(m_lastPrefixTreeEntry->nfaNode, m_patternId);
+        }
+
+        m_subtreeEnd = m_lastPrefixTreeEntry->nfaNode;
+        m_floatingTerm = Term();
+        ASSERT(m_lastPrefixTreeEntry);
+    }
+
     NFA& m_nfa;
+    bool m_patternIsCaseSensitive;
     const uint64_t m_patternId;
 
-    BoundedSubGraph m_activeGroup;
+    unsigned m_subtreeStart { 0 };
+    unsigned m_subtreeEnd { 0 };
 
-    bool m_hasValidAtom = false;
-    BoundedSubGraph m_lastAtom;
+    PrefixTreeEntry* m_lastPrefixTreeEntry;
+    Deque<Term> m_openGroups;
+    Term m_floatingTerm;
+
+    PrefixTreeEntry* m_newPrefixSubtreeRoot = nullptr;
+    Term m_newPrefixStaringPoint;
 
     String m_errorMessage;
 };
 
-void URLFilterParser::parse(const String& pattern, uint64_t patternId, NFA& nfa)
+URLFilterParser::URLFilterParser(NFA& nfa) // Binds the parser to |nfa| and creates the root of the prefix tree handed to every GraphBuilder.
+    : m_nfa(nfa)
+    , m_prefixTreeRoot(std::make_unique<PrefixTreeEntry>())
+{
+    m_prefixTreeRoot->nfaNode = nfa.root(); // The empty prefix corresponds to the root node of the NFA.
+}
+
+URLFilterParser::~URLFilterParser() // NOTE(review): presumably defined out of line because PrefixTreeEntry is only complete in this file - confirm against the header.
+{
+}
+
+String URLFilterParser::addPattern(const String& pattern, bool patternIsCaseSensitive, uint64_t patternId) // Adds |pattern| to the NFA. Returns an error message on failure; on success the result is presumably a null String (depends on GraphBuilder::errorMessage()).
 {
     if (!pattern.containsOnlyASCII())
-        m_errorMessage = ASCIILiteral("URLFilterParser only supports ASCII patterns.");
+        return ASCIILiteral("URLFilterParser only supports ASCII patterns.");
     ASSERT(!pattern.isEmpty());
 
     if (pattern.isEmpty())
-        return;
+        return ASCIILiteral("Empty pattern.");
+
+    unsigned oldSize = m_nfa.graphSize(); // Remembered so that the NFA can be rolled back if the pattern is rejected.
 
-    unsigned oldSize = nfa.graphSize();
+    String error;
 
-    GraphBuilder graphBuilder(nfa, patternId);
-    const char* error = JSC::Yarr::parse(graphBuilder, pattern, 0);
-    if (error)
-        m_errorMessage = String(error);
-    else
+    GraphBuilder graphBuilder(m_nfa, *m_prefixTreeRoot, patternIsCaseSensitive, patternId);
+    error = String(JSC::Yarr::parse(graphBuilder, pattern, 0)); // A parser error takes precedence over any error recorded by the builder.
+    if (error.isNull())
         graphBuilder.finalize();
 
-    if (!error)
-        m_errorMessage = graphBuilder.errorMessage();
+    if (error.isNull())
+        error = graphBuilder.errorMessage();
+
+    if (!error.isNull())
+        m_nfa.restoreToGraphSize(oldSize); // NOTE(review): prefix tree entries are only removed by GraphBuilder::fail(); when the error comes from JSC::Yarr::parse() itself, entries added for this pattern appear to stay in the tree while their NFA nodes are dropped here - confirm.
 
-    if (hasError())
-        nfa.restoreToGraphSize(oldSize);
+    return error;
 }
 
 } // namespace ContentExtensions