diff --git a/playhouse/_sqlite_ext.pyx b/playhouse/_sqlite_ext.pyx index 7fa3949e0..3e6408016 100644 --- a/playhouse/_sqlite_ext.pyx +++ b/playhouse/_sqlite_ext.pyx @@ -1099,7 +1099,7 @@ seeds[:] = [0, 1337, 37, 0xabcd, 0xdead, 0xface, 97, 0xed11, 0xcad9, 0x827b] cdef bf_t *bf_create(size_t size): cdef bf_t *bf = calloc(1, sizeof(bf_t)) bf.size = size - bf.bits = malloc(size) + bf.bits = calloc(1, size) return bf @cython.cdivision(True) @@ -1152,6 +1152,9 @@ cdef class BloomFilter(object): if self.bf: bf_free(self.bf) + def __len__(self): + return self.bf.size + def add(self, *keys): cdef bytes bkey @@ -1171,6 +1174,19 @@ cdef class BloomFilter(object): # embedded NULL bytes. return buf + @classmethod + def from_buffer(cls, data): + cdef: + char *buf + Py_ssize_t buflen + BloomFilter bloom + + PyBytes_AsStringAndSize(data, &buf, &buflen) + + bloom = BloomFilter(buflen) + memcpy(bloom.bf.bits, buf, buflen) + return bloom + @classmethod def calculate_size(cls, double n, double p): cdef double m = ceil((n * log(p)) / log(1.0 / (pow(2.0, log(2.0))))) diff --git a/tests/cysqlite.py b/tests/cysqlite.py index 5c361a103..42d19280c 100644 --- a/tests/cysqlite.py +++ b/tests/cysqlite.py @@ -376,9 +376,11 @@ def test_bloomfilter(self): class TestBloomFilter(BaseTestCase): + n = 1024 + def setUp(self): super(TestBloomFilter, self).setUp() - self.bf = BloomFilter(1024) + self.bf = BloomFilter(self.n) def test_bloomfilter(self): keys = ('charlie', 'huey', 'mickey', 'zaizee', 'nuggie', 'foo', 'bar', @@ -392,6 +394,33 @@ def test_bloomfilter(self): self.assertFalse(key + '-y' in self.bf) self.assertFalse(key + ' ' in self.bf) + def test_bloomfilter_buffer(self): + self.assertEqual(len(self.bf), self.n) + + # Buffer is all zeroes when uninitialized. + buf = self.bf.to_buffer() + self.assertEqual(len(buf), self.n) + self.assertEqual(buf, b'\x00' * self.n) + + keys = ('alpha', 'beta', 'gamma', 'delta', 'epsilon', 'zeta') + self.bf.add(*keys) + + for key in keys: + self.assertTrue(key in self.bf) + self.assertFalse(key + '-x' in self.bf) + + # Convert to buffer and then populate a 2nd bloom-filter. + buf = self.bf.to_buffer() + new_bf = BloomFilter.from_buffer(buf) + for key in keys: + self.assertTrue(key in new_bf) + self.assertFalse(key + '-x' in new_bf) + + # Ensure that the two underlying bloom-filter buffers are equal. + self.assertEqual(len(new_bf), self.n) + new_buf = new_bf.to_buffer() + self.assertEqual(buf, new_buf) + class DataTypes(TableFunction): columns = ('key', 'value')