Unreviewed, rolling out r234489.
[WebKit-https.git] / Source / WTF / wtf / BitVector.h
1 /*
2  * Copyright (C) 2011, 2014, 2016 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. ``AS IS'' AND ANY
14  * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL APPLE INC. OR
17  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
18  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
19  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
20  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
21  * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
24  */
25
26 #ifndef BitVector_h
27 #define BitVector_h
28
29 #include <stdio.h>
30 #include <wtf/Assertions.h>
31 #include <wtf/DataLog.h>
32 #include <wtf/HashFunctions.h>
33 #include <wtf/HashTraits.h>
34 #include <wtf/PrintStream.h>
35 #include <wtf/StdLibExtras.h>
36
37 namespace WTF {
38
39 // This is a space-efficient, resizeable bitvector class. In the common case it
40 // occupies one word, but if necessary, it will inflate this one word to point
41 // to a single chunk of out-of-line allocated storage to store an arbitrary number
42 // of bits.
43 //
44 // - The bitvector remembers the bound of how many bits can be stored, but this
45 //   may be slightly greater (by as much as some platform-specific constant)
46 //   than the last argument passed to ensureSize().
47 //
48 // - The bitvector can resize itself automatically (set, clear, get) or can be used
49 //   in a manual mode, which is faster (quickSet, quickClear, quickGet, ensureSize).
50 //
51 // - Accesses ASSERT that you are within bounds.
52 //
53 // - Bits are automatically initialized to zero.
54 //
55 // On the other hand, this BitVector class may not be the fastest around, since
56 // it does conditionals on every get/set/clear. But it is great if you need to
57 // juggle a lot of variable-length BitVectors and you're worried about wasting
58 // space.
59
60 class BitVector {
61 public: 
62     BitVector()
63         : m_bitsOrPointer(makeInlineBits(0))
64     {
65     }
66     
67     explicit BitVector(size_t numBits)
68         : m_bitsOrPointer(makeInlineBits(0))
69     {
70         ensureSize(numBits);
71     }
72     
73     BitVector(const BitVector& other)
74         : m_bitsOrPointer(makeInlineBits(0))
75     {
76         (*this) = other;
77     }
78
79     
80     ~BitVector()
81     {
82         if (isInline())
83             return;
84         OutOfLineBits::destroy(outOfLineBits());
85     }
86     
87     BitVector& operator=(const BitVector& other)
88     {
89         if (isInline() && other.isInline())
90             m_bitsOrPointer = other.m_bitsOrPointer;
91         else
92             setSlow(other);
93         return *this;
94     }
95
96     size_t size() const
97     {
98         if (isInline())
99             return maxInlineBits();
100         return outOfLineBits()->numBits();
101     }
102
103     void ensureSize(size_t numBits)
104     {
105         if (numBits <= size())
106             return;
107         resizeOutOfLine(numBits);
108     }
109     
110     // Like ensureSize(), but supports reducing the size of the bitvector.
111     WTF_EXPORT_PRIVATE void resize(size_t numBits);
112     
113     WTF_EXPORT_PRIVATE void clearAll();
114
115     bool quickGet(size_t bit) const
116     {
117         ASSERT_WITH_SECURITY_IMPLICATION(bit < size());
118         return !!(bits()[bit / bitsInPointer()] & (static_cast<uintptr_t>(1) << (bit & (bitsInPointer() - 1))));
119     }
120     
121     bool quickSet(size_t bit)
122     {
123         ASSERT_WITH_SECURITY_IMPLICATION(bit < size());
124         uintptr_t& word = bits()[bit / bitsInPointer()];
125         uintptr_t mask = static_cast<uintptr_t>(1) << (bit & (bitsInPointer() - 1));
126         bool result = !!(word & mask);
127         word |= mask;
128         return result;
129     }
130     
131     bool quickClear(size_t bit)
132     {
133         ASSERT_WITH_SECURITY_IMPLICATION(bit < size());
134         uintptr_t& word = bits()[bit / bitsInPointer()];
135         uintptr_t mask = static_cast<uintptr_t>(1) << (bit & (bitsInPointer() - 1));
136         bool result = !!(word & mask);
137         word &= ~mask;
138         return result;
139     }
140     
141     bool quickSet(size_t bit, bool value)
142     {
143         if (value)
144             return quickSet(bit);
145         return quickClear(bit);
146     }
147     
148     bool get(size_t bit) const
149     {
150         if (bit >= size())
151             return false;
152         return quickGet(bit);
153     }
154
155     bool contains(size_t bit) const
156     {
157         return get(bit);
158     }
159     
160     bool set(size_t bit)
161     {
162         ensureSize(bit + 1);
163         return quickSet(bit);
164     }
165
166     // This works like the add methods of sets. Instead of returning the previous value, like set(),
167     // it returns whether the bit transitioned from false to true.
168     bool add(size_t bit)
169     {
170         return !set(bit);
171     }
172
173     bool ensureSizeAndSet(size_t bit, size_t size)
174     {
175         ensureSize(size);
176         return quickSet(bit);
177     }
178
179     bool clear(size_t bit)
180     {
181         if (bit >= size())
182             return false;
183         return quickClear(bit);
184     }
185
186     bool remove(size_t bit)
187     {
188         return clear(bit);
189     }
190     
191     bool set(size_t bit, bool value)
192     {
193         if (value)
194             return set(bit);
195         return clear(bit);
196     }
197     
198     void merge(const BitVector& other)
199     {
200         if (!isInline() || !other.isInline()) {
201             mergeSlow(other);
202             return;
203         }
204         m_bitsOrPointer |= other.m_bitsOrPointer;
205         ASSERT(isInline());
206     }
207     
208     void filter(const BitVector& other)
209     {
210         if (!isInline() || !other.isInline()) {
211             filterSlow(other);
212             return;
213         }
214         m_bitsOrPointer &= other.m_bitsOrPointer;
215         ASSERT(isInline());
216     }
217     
218     void exclude(const BitVector& other)
219     {
220         if (!isInline() || !other.isInline()) {
221             excludeSlow(other);
222             return;
223         }
224         m_bitsOrPointer &= ~other.m_bitsOrPointer;
225         m_bitsOrPointer |= (static_cast<uintptr_t>(1) << maxInlineBits());
226         ASSERT(isInline());
227     }
228     
229     size_t bitCount() const
230     {
231         if (isInline())
232             return bitCount(cleanseInlineBits(m_bitsOrPointer));
233         return bitCountSlow();
234     }
235     
236     size_t findBit(size_t index, bool value) const
237     {
238         size_t result = findBitFast(index, value);
239         if (!ASSERT_DISABLED) {
240             size_t expectedResult = findBitSimple(index, value);
241             if (result != expectedResult) {
242                 dataLog("findBit(", index, ", ", value, ") on ", *this, " should have gotten ", expectedResult, " but got ", result, "\n");
243                 ASSERT_NOT_REACHED();
244             }
245         }
246         return result;
247     }
248     
249     WTF_EXPORT_PRIVATE void dump(PrintStream& out) const;
250     
251     enum EmptyValueTag { EmptyValue };
252     enum DeletedValueTag { DeletedValue };
253     
254     BitVector(EmptyValueTag)
255         : m_bitsOrPointer(0)
256     {
257     }
258     
259     BitVector(DeletedValueTag)
260         : m_bitsOrPointer(1)
261     {
262     }
263     
264     bool isEmptyValue() const { return !m_bitsOrPointer; }
265     bool isDeletedValue() const { return m_bitsOrPointer == 1; }
266     
267     bool isEmptyOrDeletedValue() const { return m_bitsOrPointer <= 1; }
268     
269     bool operator==(const BitVector& other) const
270     {
271         if (isInline() && other.isInline())
272             return m_bitsOrPointer == other.m_bitsOrPointer;
273         return equalsSlowCase(other);
274     }
275     
276     unsigned hash() const
277     {
278         // This is a very simple hash. Just xor together the words that hold the various
279         // bits and then compute the hash. This makes it very easy to deal with bitvectors
280         // that have a lot of trailing zero's.
281         uintptr_t value;
282         if (isInline())
283             value = cleanseInlineBits(m_bitsOrPointer);
284         else
285             value = hashSlowCase();
286         return IntHash<uintptr_t>::hash(value);
287     }
288     
289     class iterator {
290     public:
291         iterator()
292             : m_bitVector(nullptr)
293             , m_index(0)
294         {
295         }
296         
297         iterator(const BitVector& bitVector, size_t index)
298             : m_bitVector(&bitVector)
299             , m_index(index)
300         {
301         }
302         
303         size_t operator*() const { return m_index; }
304         
305         iterator& operator++()
306         {
307             m_index = m_bitVector->findBit(m_index + 1, true);
308             return *this;
309         }
310
311         iterator operator++(int)
312         {
313             iterator result = *this;
314             ++(*this);
315             return result;
316         }
317
318         bool isAtEnd() const
319         {
320             return m_index >= m_bitVector->size();
321         }
322         
323         bool operator==(const iterator& other) const
324         {
325             return m_index == other.m_index;
326         }
327         
328         bool operator!=(const iterator& other) const
329         {
330             return !(*this == other);
331         }
332     private:
333         const BitVector* m_bitVector;
334         size_t m_index;
335     };
336
337     // Use this to iterate over set bits.
338     iterator begin() const { return iterator(*this, findBit(0, true)); }
339     iterator end() const { return iterator(*this, size()); }
340         
341 private:
342     static unsigned bitsInPointer()
343     {
344         return sizeof(void*) << 3;
345     }
346
347     static unsigned maxInlineBits()
348     {
349         return bitsInPointer() - 1;
350     }
351
352     static size_t byteCount(size_t bitCount)
353     {
354         return (bitCount + 7) >> 3;
355     }
356
357     static uintptr_t makeInlineBits(uintptr_t bits)
358     {
359         ASSERT(!(bits & (static_cast<uintptr_t>(1) << maxInlineBits())));
360         return bits | (static_cast<uintptr_t>(1) << maxInlineBits());
361     }
362     
363     static uintptr_t cleanseInlineBits(uintptr_t bits)
364     {
365         return bits & ~(static_cast<uintptr_t>(1) << maxInlineBits());
366     }
367     
368     static size_t bitCount(uintptr_t bits)
369     {
370         if (sizeof(uintptr_t) == 4)
371             return WTF::bitCount(static_cast<unsigned>(bits));
372         return WTF::bitCount(static_cast<uint64_t>(bits));
373     }
374     
375     size_t findBitFast(size_t startIndex, bool value) const
376     {
377         if (isInline()) {
378             size_t index = startIndex;
379             findBitInWord(m_bitsOrPointer, index, maxInlineBits(), value);
380             return index;
381         }
382         
383         const OutOfLineBits* bits = outOfLineBits();
384         
385         // value = true: casts to 1, then xors to 0, then negates to 0.
386         // value = false: casts to 0, then xors to 1, then negates to -1 (i.e. all one bits).
387         uintptr_t skipValue = -(static_cast<uintptr_t>(value) ^ 1);
388         size_t numWords = bits->numWords();
389         
390         size_t wordIndex = startIndex / bitsInPointer();
391         size_t startIndexInWord = startIndex - wordIndex * bitsInPointer();
392         
393         while (wordIndex < numWords) {
394             uintptr_t word = bits->bits()[wordIndex];
395             if (word != skipValue) {
396                 size_t index = startIndexInWord;
397                 if (findBitInWord(word, index, bitsInPointer(), value))
398                     return wordIndex * bitsInPointer() + index;
399             }
400             
401             wordIndex++;
402             startIndexInWord = 0;
403         }
404         
405         return bits->numBits();
406     }
407     
408     size_t findBitSimple(size_t index, bool value) const
409     {
410         while (index < size()) {
411             if (get(index) == value)
412                 return index;
413             index++;
414         }
415         return size();
416     }
417     
418     class OutOfLineBits {
419     public:
420         size_t numBits() const { return m_numBits; }
421         size_t numWords() const { return (m_numBits + bitsInPointer() - 1) / bitsInPointer(); }
422         uintptr_t* bits() { return bitwise_cast<uintptr_t*>(this + 1); }
423         const uintptr_t* bits() const { return bitwise_cast<const uintptr_t*>(this + 1); }
424         
425         static WTF_EXPORT_PRIVATE OutOfLineBits* create(size_t numBits);
426         
427         static WTF_EXPORT_PRIVATE void destroy(OutOfLineBits*);
428
429     private:
430         OutOfLineBits(size_t numBits)
431             : m_numBits(numBits)
432         {
433         }
434         
435         size_t m_numBits;
436     };
437     
438     bool isInline() const { return m_bitsOrPointer >> maxInlineBits(); }
439     
440     const OutOfLineBits* outOfLineBits() const { return bitwise_cast<const OutOfLineBits*>(m_bitsOrPointer << 1); }
441     OutOfLineBits* outOfLineBits() { return bitwise_cast<OutOfLineBits*>(m_bitsOrPointer << 1); }
442     
443     WTF_EXPORT_PRIVATE void resizeOutOfLine(size_t numBits);
444     WTF_EXPORT_PRIVATE void setSlow(const BitVector& other);
445     
446     WTF_EXPORT_PRIVATE void mergeSlow(const BitVector& other);
447     WTF_EXPORT_PRIVATE void filterSlow(const BitVector& other);
448     WTF_EXPORT_PRIVATE void excludeSlow(const BitVector& other);
449     
450     WTF_EXPORT_PRIVATE size_t bitCountSlow() const;
451     
452     WTF_EXPORT_PRIVATE bool equalsSlowCase(const BitVector& other) const;
453     bool equalsSlowCaseFast(const BitVector& other) const;
454     bool equalsSlowCaseSimple(const BitVector& other) const;
455     WTF_EXPORT_PRIVATE uintptr_t hashSlowCase() const;
456     
457     uintptr_t* bits()
458     {
459         if (isInline())
460             return &m_bitsOrPointer;
461         return outOfLineBits()->bits();
462     }
463     
464     const uintptr_t* bits() const
465     {
466         if (isInline())
467             return &m_bitsOrPointer;
468         return outOfLineBits()->bits();
469     }
470     
471     uintptr_t m_bitsOrPointer;
472 };
473
474 struct BitVectorHash {
475     static unsigned hash(const BitVector& vector) { return vector.hash(); }
476     static bool equal(const BitVector& a, const BitVector& b) { return a == b; }
477     static const bool safeToCompareToEmptyOrDeleted = false;
478 };
479
480 template<typename T> struct DefaultHash;
481 template<> struct DefaultHash<BitVector> {
482     typedef BitVectorHash Hash;
483 };
484
485 template<typename T> struct HashTraits;
486 template<> struct HashTraits<BitVector> : public CustomHashTraits<BitVector> { };
487
488 } // namespace WTF
489
490 using WTF::BitVector;
491
492 #endif // BitVector_h