FastBitVector should have efficient and easy-to-use vector-vector operations
[WebKit-https.git] / Source / WTF / wtf / FastBitVector.h
1 /*
2  * Copyright (C) 2012, 2013, 2016 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. ``AS IS'' AND ANY
14  * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL APPLE INC. OR
17  * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
18  * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
19  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
20  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
21  * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
24  */
25
26 #pragma once
27
28 #include <string.h>
29 #include <wtf/FastMalloc.h>
30 #include <wtf/PrintStream.h>
31 #include <wtf/StdLibExtras.h>
32
33 namespace WTF {
34
35 class PrintStream;
36
37 inline size_t fastBitVectorArrayLength(size_t numBits) { return (numBits + 31) / 32; }
38
39 class FastBitVectorWordView {
40 public:
41     typedef FastBitVectorWordView ViewType;
42     
43     FastBitVectorWordView() { }
44     
45     FastBitVectorWordView(const uint32_t* array, size_t numBits)
46         : m_words(array)
47         , m_numBits(numBits)
48     {
49     }
50     
51     size_t numBits() const
52     {
53         return m_numBits;
54     }
55     
56     uint32_t word(size_t index) const
57     {
58         ASSERT_WITH_SECURITY_IMPLICATION(index < fastBitVectorArrayLength(numBits()));
59         return m_words[index];
60     }
61     
62 private:
63     const uint32_t* m_words { nullptr };
64     size_t m_numBits { 0 };
65 };
66
67 class FastBitVectorWordOwner {
68 public:
69     typedef FastBitVectorWordView ViewType;
70     
71     FastBitVectorWordOwner() = default;
72     
73     FastBitVectorWordOwner(FastBitVectorWordOwner&& other)
74         : m_words(std::exchange(other.m_words, nullptr))
75         , m_numBits(std::exchange(other.m_numBits, 0))
76     {
77     }
78
79     FastBitVectorWordOwner(const FastBitVectorWordOwner& other)
80     {
81         *this = other;
82     }
83     
84     ~FastBitVectorWordOwner()
85     {
86         if (m_words)
87             fastFree(m_words);
88     }
89     
90     FastBitVectorWordView view() const { return FastBitVectorWordView(m_words, m_numBits); }
91     
92     FastBitVectorWordOwner& operator=(const FastBitVectorWordOwner& other)
93     {
94         size_t length = other.arrayLength();
95         if (length == arrayLength()) {
96             memcpy(m_words, other.m_words, length * sizeof(uint32_t));
97             return *this;
98         }
99         uint32_t* newArray = static_cast<uint32_t*>(fastCalloc(length, sizeof(uint32_t)));
100         memcpy(newArray, other.m_words, length * sizeof(uint32_t));
101         if (m_words)
102             fastFree(m_words);
103         m_words = newArray;
104         m_numBits = other.m_numBits;
105         return *this;
106     }
107     
108     FastBitVectorWordOwner& operator=(FastBitVectorWordOwner&& other)
109     {
110         std::swap(m_words, other.m_words);
111         std::swap(m_numBits, other.m_numBits);
112         return *this;
113     }
114     
115     void setAll()
116     {
117         memset(m_words, 255, arrayLength() * sizeof(uint32_t));
118     }
119     
120     void clearAll()
121     {
122         memset(m_words, 0, arrayLength() * sizeof(uint32_t));
123     }
124     
125     void set(const FastBitVectorWordOwner& other)
126     {
127         ASSERT_WITH_SECURITY_IMPLICATION(m_numBits == other.m_numBits);
128         memcpy(m_words, other.m_words, arrayLength() * sizeof(uint32_t));
129     }
130     
131     size_t numBits() const
132     {
133         return m_numBits;
134     }
135     
136     size_t arrayLength() const
137     {
138         return fastBitVectorArrayLength(numBits());
139     }
140     
141     void resize(size_t numBits)
142     {
143         if (numBits == m_numBits)
144             return;
145         
146         // Use fastCalloc instead of fastRealloc because we expect the common
147         // use case for this method to be initializing the size of the bitvector.
148         
149         size_t newLength = fastBitVectorArrayLength(numBits);
150         uint32_t* newArray = static_cast<uint32_t*>(fastCalloc(newLength, sizeof(uint32_t)));
151         memcpy(newArray, m_words, arrayLength() * sizeof(uint32_t));
152         if (m_words)
153             fastFree(m_words);
154         m_words = newArray;
155         m_numBits = numBits;
156     }
157     
158     uint32_t word(size_t index) const
159     {
160         ASSERT_WITH_SECURITY_IMPLICATION(index < arrayLength());
161         return m_words[index];
162     }
163     
164     uint32_t& word(size_t index)
165     {
166         ASSERT_WITH_SECURITY_IMPLICATION(index < arrayLength());
167         return m_words[index];
168     }
169     
170     const uint32_t* words() const { return m_words; }
171     uint32_t* words() { return m_words; }
172
173 private:
174     uint32_t* m_words { nullptr };
175     size_t m_numBits { 0 };
176 };
177
178 template<typename Left, typename Right>
179 class FastBitVectorAndWords {
180 public:
181     typedef FastBitVectorAndWords ViewType;
182     
183     FastBitVectorAndWords(const Left& left, const Right& right)
184         : m_left(left)
185         , m_right(right)
186     {
187         ASSERT_WITH_SECURITY_IMPLICATION(m_left.numBits() == m_right.numBits());
188     }
189     
190     FastBitVectorAndWords view() const { return *this; }
191     
192     size_t numBits() const
193     {
194         return m_left.numBits();
195     }
196     
197     uint32_t word(size_t index) const
198     {
199         return m_left.word(index) & m_right.word(index);
200     }
201     
202 private:
203     Left m_left;
204     Right m_right;
205 };
206     
207 template<typename Left, typename Right>
208 class FastBitVectorOrWords {
209 public:
210     typedef FastBitVectorOrWords ViewType;
211     
212     FastBitVectorOrWords(const Left& left, const Right& right)
213         : m_left(left)
214         , m_right(right)
215     {
216         ASSERT_WITH_SECURITY_IMPLICATION(m_left.numBits() == m_right.numBits());
217     }
218     
219     FastBitVectorOrWords view() const { return *this; }
220     
221     size_t numBits() const
222     {
223         return m_left.numBits();
224     }
225     
226     uint32_t word(size_t index) const
227     {
228         return m_left.word(index) | m_right.word(index);
229     }
230     
231 private:
232     Left m_left;
233     Right m_right;
234 };
235     
236 template<typename View>
237 class FastBitVectorNotWords {
238 public:
239     typedef FastBitVectorNotWords ViewType;
240     
241     FastBitVectorNotWords(const View& view)
242         : m_view(view)
243     {
244     }
245     
246     FastBitVectorNotWords view() const { return *this; }
247     
248     size_t numBits() const
249     {
250         return m_view.numBits();
251     }
252     
253     uint32_t word(size_t index) const
254     {
255         return ~m_view.word(index);
256     }
257     
258 private:
259     View m_view;
260 };
261     
262 class FastBitVector;
263
264 template<typename Words>
265 class FastBitVectorImpl {
266 public:
267     FastBitVectorImpl()
268         : m_words()
269     {
270     }
271     
272     FastBitVectorImpl(const Words& words)
273         : m_words(words)
274     {
275     }
276     
277     FastBitVectorImpl(Words&& words)
278         : m_words(WTFMove(words))
279     {
280     }
281
282     size_t numBits() const { return m_words.numBits(); }
283     size_t size() const { return numBits(); }
284     
285     size_t arrayLength() const { return fastBitVectorArrayLength(numBits()); }
286     
287     template<typename Other>
288     bool operator==(const Other& other) const
289     {
290         if (numBits() != other.numBits())
291             return false;
292         for (size_t i = arrayLength(); i--;) {
293             if (m_words.word(i) != other.m_words.word(i))
294                 return false;
295         }
296         return true;
297     }
298     
299     template<typename Other>
300     bool operator!=(const Other& other) const
301     {
302         return !(*this == other);
303     }
304     
305     bool at(size_t index) const
306     {
307         return atImpl(index);
308     }
309     
310     bool operator[](size_t index) const
311     {
312         return atImpl(index);
313     }
314     
315     size_t bitCount() const
316     {
317         size_t result = 0;
318         for (size_t index = arrayLength(); index--;)
319             result += WTF::bitCount(m_words.word(index));
320         return result;
321     }
322     
323     template<typename OtherWords>
324     FastBitVectorImpl<FastBitVectorAndWords<typename Words::ViewType, typename OtherWords::ViewType>> operator&(const FastBitVectorImpl<OtherWords>& other) const
325     {
326         return FastBitVectorImpl<FastBitVectorAndWords<typename Words::ViewType, typename OtherWords::ViewType>>(FastBitVectorAndWords<typename Words::ViewType, typename OtherWords::ViewType>(m_words.view(), other.m_words.view()));
327     }
328     
329     template<typename OtherWords>
330     FastBitVectorImpl<FastBitVectorOrWords<typename Words::ViewType, typename OtherWords::ViewType>> operator|(const FastBitVectorImpl<OtherWords>& other) const
331     {
332         return FastBitVectorImpl<FastBitVectorOrWords<typename Words::ViewType, typename OtherWords::ViewType>>(FastBitVectorOrWords<typename Words::ViewType, typename OtherWords::ViewType>(m_words.view(), other.m_words.view()));
333     }
334     
335     FastBitVectorImpl<FastBitVectorNotWords<typename Words::ViewType>> operator~() const
336     {
337         return FastBitVectorImpl<FastBitVectorNotWords<typename Words::ViewType>>(FastBitVectorNotWords<typename Words::ViewType>(m_words.view()));
338     }
339     
340     template<typename Func>
341     ALWAYS_INLINE void forEachSetBit(const Func& func) const
342     {
343         size_t n = m_words.arrayLength();
344         for (size_t i = 0; i < n; ++i) {
345             uint32_t word = m_words.word(i);
346             size_t j = i * 32;
347             while (word) {
348                 if (word & 1)
349                     func(j);
350                 word >>= 1;
351                 j++;
352             }
353         }
354     }
355     
356     template<typename Func>
357     ALWAYS_INLINE void forEachClearBit(const Func& func) const
358     {
359         (~*this).forEachSetBit(func);
360     }
361     
362     template<typename Func>
363     void forEachBit(bool value, const Func& func) const
364     {
365         if (value)
366             forEachSetBit(func);
367         else
368             forEachClearBit(func);
369     }
370     
371     // Starts looking for bits at the index you pass. If that index contains the value you want,
372     // then it will return that index. Returns numBits when we get to the end. For example, you
373     // can write a loop to iterate over all set bits like this:
374     //
375     // for (size_t i = 0; i < bits.numBits(); i = bits.findBit(i + 1, true))
376     //     ...
377     ALWAYS_INLINE size_t findBit(size_t startIndex, bool value) const
378     {
379         // If value is true, this produces 0. If value is false, this produces UINT_MAX. It's
380         // written this way so that it performs well regardless of whether value is a constant.
381         uint32_t skipValue = -(static_cast<uint32_t>(value) ^ 1);
382         
383         size_t numWords = m_words.arrayLength();
384         
385         size_t wordIndex = startIndex / 32;
386         size_t startIndexInWord = startIndex - wordIndex * 32;
387         
388         while (wordIndex < numWords) {
389             uint32_t word = m_words.word(wordIndex);
390             if (word != skipValue) {
391                 size_t index = startIndexInWord;
392                 if (findBitInWord(word, index, 32, value))
393                     return wordIndex * 32 + index;
394             }
395             
396             wordIndex++;
397             startIndexInWord = 0;
398         }
399         
400         return numBits();
401     }
402     
403     ALWAYS_INLINE size_t findSetBit(size_t index) const
404     {
405         return findBit(index, true);
406     }
407     
408     ALWAYS_INLINE size_t findClearBit(size_t index) const
409     {
410         return findBit(index, false);
411     }
412     
413     void dump(PrintStream& out) const
414     {
415         for (size_t i = 0; i < numBits(); ++i)
416             out.print((*this)[i] ? "1" : "-");
417     }
418     
419 private:
420     // You'd think that we could remove this friend if we used protected, but you'd be wrong,
421     // because templates.
422     friend class FastBitVector;
423     
424     bool atImpl(size_t index) const
425     {
426         ASSERT_WITH_SECURITY_IMPLICATION(index < numBits());
427         return !!(m_words.word(index >> 5) & (1 << (index & 31)));
428     }
429     
430     Words m_words;
431 };
432
433 class FastBitVector : public FastBitVectorImpl<FastBitVectorWordOwner> {
434 public:
435     FastBitVector() { }
436     
437     FastBitVector(const FastBitVector&) = default;
438     FastBitVector& operator=(const FastBitVector&) = default;
439     
440     template<typename OtherWords>
441     FastBitVector(const FastBitVectorImpl<OtherWords>& other)
442     {
443         *this = other;
444     }
445     
446     template<typename OtherWords>
447     FastBitVector& operator=(const FastBitVectorImpl<OtherWords>& other)
448     {
449         if (UNLIKELY(numBits() != other.numBits()))
450             resize(other.numBits());
451         
452         for (unsigned i = arrayLength(); i--;)
453             m_words.word(i) = other.m_words.word(i);
454         return *this;
455     }
456     
457     void resize(size_t numBits)
458     {
459         m_words.resize(numBits);
460     }
461     
462     void setAll()
463     {
464         m_words.setAll();
465     }
466     
467     void clearAll()
468     {
469         m_words.clearAll();
470     }
471
472     template<typename OtherWords>
473     bool setAndCheck(const FastBitVectorImpl<OtherWords>& other)
474     {
475         bool changed = false;
476         ASSERT_WITH_SECURITY_IMPLICATION(numBits() == other.numBits());
477         for (unsigned i = arrayLength(); i--;) {
478             changed |= m_words.word(i) != other.m_words.word(i);
479             m_words.word(i) = other.m_words.word(i);
480         }
481         return changed;
482     }
483     
484     template<typename OtherWords>
485     FastBitVector& operator|=(const FastBitVectorImpl<OtherWords>& other)
486     {
487         ASSERT_WITH_SECURITY_IMPLICATION(numBits() == other.numBits());
488         for (unsigned i = arrayLength(); i--;)
489             m_words.word(i) |= other.m_words.word(i);
490         return *this;
491     }
492     
493     template<typename OtherWords>
494     FastBitVector& operator&=(const FastBitVectorImpl<OtherWords>& other)
495     {
496         ASSERT_WITH_SECURITY_IMPLICATION(numBits() == other.numBits());
497         for (unsigned i = arrayLength(); i--;)
498             m_words.word(i) &= other.m_words.word(i);
499         return *this;
500     }
501     
502     bool at(size_t index) const
503     {
504         return atImpl(index);
505     }
506     
507     bool operator[](size_t index) const
508     {
509         return atImpl(index);
510     }
511     
512     class BitReference {
513     public:
514         BitReference() { }
515         
516         BitReference(uint32_t* word, uint32_t mask)
517             : m_word(word)
518             , m_mask(mask)
519         {
520         }
521         
522         explicit operator bool() const
523         {
524             return !!(*m_word & m_mask);
525         }
526         
527         BitReference& operator=(bool value)
528         {
529             if (value)
530                 *m_word |= m_mask;
531             else
532                 *m_word &= ~m_mask;
533             return *this;
534         }
535         
536     private:
537         uint32_t* m_word { nullptr };
538         uint32_t m_mask { 0 };
539     };
540     
541     BitReference at(size_t index)
542     {
543         ASSERT_WITH_SECURITY_IMPLICATION(index < numBits());
544         return BitReference(&m_words.word(index >> 5), 1 << (index & 31));
545     }
546     
547     BitReference operator[](size_t index)
548     {
549         return at(index);
550     }
551 };
552
553 } // namespace WTF
554
555 using WTF::FastBitVector;