755cd7a9aa576c6f7a4db14c0d9e302c78db5178
[WebKit-https.git] / Source / WTF / wtf / BitVector.h
1 /*
2  * Copyright (C) 2011, 2014, 2016 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. ``AS IS'' AND ANY
14  * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL APPLE INC. OR
17  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
18  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
19  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
20  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
21  * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
24  */
25
26 #pragma once
27
28 #include <stdio.h>
29 #include <wtf/Assertions.h>
30 #include <wtf/DataLog.h>
31 #include <wtf/HashFunctions.h>
32 #include <wtf/HashTraits.h>
33 #include <wtf/PrintStream.h>
34 #include <wtf/StdLibExtras.h>
35
36 namespace WTF {
37
38 // This is a space-efficient, resizeable bitvector class. In the common case it
39 // occupies one word, but if necessary, it will inflate this one word to point
40 // to a single chunk of out-of-line allocated storage to store an arbitrary number
41 // of bits.
42 //
43 // - The bitvector remembers the bound of how many bits can be stored, but this
44 //   may be slightly greater (by as much as some platform-specific constant)
45 //   than the last argument passed to ensureSize().
46 //
47 // - The bitvector can resize itself automatically (set, clear, get) or can be used
48 //   in a manual mode, which is faster (quickSet, quickClear, quickGet, ensureSize).
49 //
50 // - Accesses ASSERT that you are within bounds.
51 //
52 // - Bits are automatically initialized to zero.
53 //
54 // On the other hand, this BitVector class may not be the fastest around, since
55 // it does conditionals on every get/set/clear. But it is great if you need to
56 // juggle a lot of variable-length BitVectors and you're worried about wasting
57 // space.
58
59 class BitVector {
60 public: 
61     BitVector()
62         : m_bitsOrPointer(makeInlineBits(0))
63     {
64     }
65     
66     explicit BitVector(size_t numBits)
67         : m_bitsOrPointer(makeInlineBits(0))
68     {
69         ensureSize(numBits);
70     }
71     
72     BitVector(const BitVector& other)
73         : m_bitsOrPointer(makeInlineBits(0))
74     {
75         (*this) = other;
76     }
77
78     
79     ~BitVector()
80     {
81         if (isInline())
82             return;
83         OutOfLineBits::destroy(outOfLineBits());
84     }
85     
86     BitVector& operator=(const BitVector& other)
87     {
88         if (isInline() && other.isInline())
89             m_bitsOrPointer = other.m_bitsOrPointer;
90         else
91             setSlow(other);
92         return *this;
93     }
94
95     size_t size() const
96     {
97         if (isInline())
98             return maxInlineBits();
99         return outOfLineBits()->numBits();
100     }
101
102     void ensureSize(size_t numBits)
103     {
104         if (numBits <= size())
105             return;
106         resizeOutOfLine(numBits);
107     }
108     
109     // Like ensureSize(), but supports reducing the size of the bitvector.
110     WTF_EXPORT_PRIVATE void resize(size_t numBits);
111     
112     WTF_EXPORT_PRIVATE void clearAll();
113
114     bool quickGet(size_t bit) const
115     {
116         ASSERT_WITH_SECURITY_IMPLICATION(bit < size());
117         return !!(bits()[bit / bitsInPointer()] & (static_cast<uintptr_t>(1) << (bit & (bitsInPointer() - 1))));
118     }
119     
120     bool quickSet(size_t bit)
121     {
122         ASSERT_WITH_SECURITY_IMPLICATION(bit < size());
123         uintptr_t& word = bits()[bit / bitsInPointer()];
124         uintptr_t mask = static_cast<uintptr_t>(1) << (bit & (bitsInPointer() - 1));
125         bool result = !!(word & mask);
126         word |= mask;
127         return result;
128     }
129     
130     bool quickClear(size_t bit)
131     {
132         ASSERT_WITH_SECURITY_IMPLICATION(bit < size());
133         uintptr_t& word = bits()[bit / bitsInPointer()];
134         uintptr_t mask = static_cast<uintptr_t>(1) << (bit & (bitsInPointer() - 1));
135         bool result = !!(word & mask);
136         word &= ~mask;
137         return result;
138     }
139     
140     bool quickSet(size_t bit, bool value)
141     {
142         if (value)
143             return quickSet(bit);
144         return quickClear(bit);
145     }
146     
147     bool get(size_t bit) const
148     {
149         if (bit >= size())
150             return false;
151         return quickGet(bit);
152     }
153
154     bool contains(size_t bit) const
155     {
156         return get(bit);
157     }
158     
159     bool set(size_t bit)
160     {
161         ensureSize(bit + 1);
162         return quickSet(bit);
163     }
164
165     // This works like the add methods of sets. Instead of returning the previous value, like set(),
166     // it returns whether the bit transitioned from false to true.
167     bool add(size_t bit)
168     {
169         return !set(bit);
170     }
171
172     bool ensureSizeAndSet(size_t bit, size_t size)
173     {
174         ensureSize(size);
175         return quickSet(bit);
176     }
177
178     bool clear(size_t bit)
179     {
180         if (bit >= size())
181             return false;
182         return quickClear(bit);
183     }
184
185     bool remove(size_t bit)
186     {
187         return clear(bit);
188     }
189     
190     bool set(size_t bit, bool value)
191     {
192         if (value)
193             return set(bit);
194         return clear(bit);
195     }
196     
197     void merge(const BitVector& other)
198     {
199         if (!isInline() || !other.isInline()) {
200             mergeSlow(other);
201             return;
202         }
203         m_bitsOrPointer |= other.m_bitsOrPointer;
204         ASSERT(isInline());
205     }
206     
207     void filter(const BitVector& other)
208     {
209         if (!isInline() || !other.isInline()) {
210             filterSlow(other);
211             return;
212         }
213         m_bitsOrPointer &= other.m_bitsOrPointer;
214         ASSERT(isInline());
215     }
216     
217     void exclude(const BitVector& other)
218     {
219         if (!isInline() || !other.isInline()) {
220             excludeSlow(other);
221             return;
222         }
223         m_bitsOrPointer &= ~other.m_bitsOrPointer;
224         m_bitsOrPointer |= (static_cast<uintptr_t>(1) << maxInlineBits());
225         ASSERT(isInline());
226     }
227     
228     size_t bitCount() const
229     {
230         if (isInline())
231             return bitCount(cleanseInlineBits(m_bitsOrPointer));
232         return bitCountSlow();
233     }
234     
235     size_t findBit(size_t index, bool value) const
236     {
237         size_t result = findBitFast(index, value);
238         if (!ASSERT_DISABLED) {
239             size_t expectedResult = findBitSimple(index, value);
240             if (result != expectedResult) {
241                 dataLog("findBit(", index, ", ", value, ") on ", *this, " should have gotten ", expectedResult, " but got ", result, "\n");
242                 ASSERT_NOT_REACHED();
243             }
244         }
245         return result;
246     }
247     
248     WTF_EXPORT_PRIVATE void dump(PrintStream& out) const;
249     
250     enum EmptyValueTag { EmptyValue };
251     enum DeletedValueTag { DeletedValue };
252     
253     BitVector(EmptyValueTag)
254         : m_bitsOrPointer(0)
255     {
256     }
257     
258     BitVector(DeletedValueTag)
259         : m_bitsOrPointer(1)
260     {
261     }
262     
263     bool isEmptyValue() const { return !m_bitsOrPointer; }
264     bool isDeletedValue() const { return m_bitsOrPointer == 1; }
265     
266     bool isEmptyOrDeletedValue() const { return m_bitsOrPointer <= 1; }
267     
268     bool operator==(const BitVector& other) const
269     {
270         if (isInline() && other.isInline())
271             return m_bitsOrPointer == other.m_bitsOrPointer;
272         return equalsSlowCase(other);
273     }
274     
275     unsigned hash() const
276     {
277         // This is a very simple hash. Just xor together the words that hold the various
278         // bits and then compute the hash. This makes it very easy to deal with bitvectors
279         // that have a lot of trailing zero's.
280         uintptr_t value;
281         if (isInline())
282             value = cleanseInlineBits(m_bitsOrPointer);
283         else
284             value = hashSlowCase();
285         return IntHash<uintptr_t>::hash(value);
286     }
287     
288     class iterator {
289     public:
290         iterator()
291             : m_bitVector(nullptr)
292             , m_index(0)
293         {
294         }
295         
296         iterator(const BitVector& bitVector, size_t index)
297             : m_bitVector(&bitVector)
298             , m_index(index)
299         {
300         }
301         
302         size_t operator*() const { return m_index; }
303         
304         iterator& operator++()
305         {
306             m_index = m_bitVector->findBit(m_index + 1, true);
307             return *this;
308         }
309
310         iterator operator++(int)
311         {
312             iterator result = *this;
313             ++(*this);
314             return result;
315         }
316
317         bool isAtEnd() const
318         {
319             return m_index >= m_bitVector->size();
320         }
321         
322         bool operator==(const iterator& other) const
323         {
324             return m_index == other.m_index;
325         }
326         
327         bool operator!=(const iterator& other) const
328         {
329             return !(*this == other);
330         }
331     private:
332         const BitVector* m_bitVector;
333         size_t m_index;
334     };
335
336     // Use this to iterate over set bits.
337     iterator begin() const { return iterator(*this, findBit(0, true)); }
338     iterator end() const { return iterator(*this, size()); }
339         
340 private:
341     static unsigned bitsInPointer()
342     {
343         return sizeof(void*) << 3;
344     }
345
346     static unsigned maxInlineBits()
347     {
348         return bitsInPointer() - 1;
349     }
350
351     static size_t byteCount(size_t bitCount)
352     {
353         return (bitCount + 7) >> 3;
354     }
355
356     static uintptr_t makeInlineBits(uintptr_t bits)
357     {
358         ASSERT(!(bits & (static_cast<uintptr_t>(1) << maxInlineBits())));
359         return bits | (static_cast<uintptr_t>(1) << maxInlineBits());
360     }
361     
362     static uintptr_t cleanseInlineBits(uintptr_t bits)
363     {
364         return bits & ~(static_cast<uintptr_t>(1) << maxInlineBits());
365     }
366     
367     static size_t bitCount(uintptr_t bits)
368     {
369         if (sizeof(uintptr_t) == 4)
370             return WTF::bitCount(static_cast<unsigned>(bits));
371         return WTF::bitCount(static_cast<uint64_t>(bits));
372     }
373     
374     size_t findBitFast(size_t startIndex, bool value) const
375     {
376         if (isInline()) {
377             size_t index = startIndex;
378             findBitInWord(m_bitsOrPointer, index, maxInlineBits(), value);
379             return index;
380         }
381         
382         const OutOfLineBits* bits = outOfLineBits();
383         
384         // value = true: casts to 1, then xors to 0, then negates to 0.
385         // value = false: casts to 0, then xors to 1, then negates to -1 (i.e. all one bits).
386         uintptr_t skipValue = -(static_cast<uintptr_t>(value) ^ 1);
387         size_t numWords = bits->numWords();
388         
389         size_t wordIndex = startIndex / bitsInPointer();
390         size_t startIndexInWord = startIndex - wordIndex * bitsInPointer();
391         
392         while (wordIndex < numWords) {
393             uintptr_t word = bits->bits()[wordIndex];
394             if (word != skipValue) {
395                 size_t index = startIndexInWord;
396                 if (findBitInWord(word, index, bitsInPointer(), value))
397                     return wordIndex * bitsInPointer() + index;
398             }
399             
400             wordIndex++;
401             startIndexInWord = 0;
402         }
403         
404         return bits->numBits();
405     }
406     
407     size_t findBitSimple(size_t index, bool value) const
408     {
409         while (index < size()) {
410             if (get(index) == value)
411                 return index;
412             index++;
413         }
414         return size();
415     }
416     
417     class OutOfLineBits {
418     public:
419         size_t numBits() const { return m_numBits; }
420         size_t numWords() const { return (m_numBits + bitsInPointer() - 1) / bitsInPointer(); }
421         uintptr_t* bits() { return bitwise_cast<uintptr_t*>(this + 1); }
422         const uintptr_t* bits() const { return bitwise_cast<const uintptr_t*>(this + 1); }
423         
424         static WTF_EXPORT_PRIVATE OutOfLineBits* create(size_t numBits);
425         
426         static WTF_EXPORT_PRIVATE void destroy(OutOfLineBits*);
427
428     private:
429         OutOfLineBits(size_t numBits)
430             : m_numBits(numBits)
431         {
432         }
433         
434         size_t m_numBits;
435     };
436     
437     bool isInline() const { return m_bitsOrPointer >> maxInlineBits(); }
438     
439     const OutOfLineBits* outOfLineBits() const { return bitwise_cast<const OutOfLineBits*>(m_bitsOrPointer << 1); }
440     OutOfLineBits* outOfLineBits() { return bitwise_cast<OutOfLineBits*>(m_bitsOrPointer << 1); }
441     
442     WTF_EXPORT_PRIVATE void resizeOutOfLine(size_t numBits);
443     WTF_EXPORT_PRIVATE void setSlow(const BitVector& other);
444     
445     WTF_EXPORT_PRIVATE void mergeSlow(const BitVector& other);
446     WTF_EXPORT_PRIVATE void filterSlow(const BitVector& other);
447     WTF_EXPORT_PRIVATE void excludeSlow(const BitVector& other);
448     
449     WTF_EXPORT_PRIVATE size_t bitCountSlow() const;
450     
451     WTF_EXPORT_PRIVATE bool equalsSlowCase(const BitVector& other) const;
452     bool equalsSlowCaseFast(const BitVector& other) const;
453     bool equalsSlowCaseSimple(const BitVector& other) const;
454     WTF_EXPORT_PRIVATE uintptr_t hashSlowCase() const;
455     
456     uintptr_t* bits()
457     {
458         if (isInline())
459             return &m_bitsOrPointer;
460         return outOfLineBits()->bits();
461     }
462     
463     const uintptr_t* bits() const
464     {
465         if (isInline())
466             return &m_bitsOrPointer;
467         return outOfLineBits()->bits();
468     }
469     
470     uintptr_t m_bitsOrPointer;
471 };
472
473 struct BitVectorHash {
474     static unsigned hash(const BitVector& vector) { return vector.hash(); }
475     static bool equal(const BitVector& a, const BitVector& b) { return a == b; }
476     static const bool safeToCompareToEmptyOrDeleted = false;
477 };
478
479 template<typename T> struct DefaultHash;
480 template<> struct DefaultHash<BitVector> {
481     typedef BitVectorHash Hash;
482 };
483
484 template<> struct HashTraits<BitVector> : public CustomHashTraits<BitVector> { };
485
486 } // namespace WTF
487
488 using WTF::BitVector;