diff --git a/playhouse/sqlite_ext.py b/playhouse/sqlite_ext.py index dbaf5ee2f..6ecabde49 100644 --- a/playhouse/sqlite_ext.py +++ b/playhouse/sqlite_ext.py @@ -90,7 +90,7 @@ def __getitem__(self, idx): item = '[%s]' % idx else: item = '.%s' % idx - return JSONPath(self._field, self._path + (item,)) + return type(self)(self._field, self._path + (item,)) def append(self, value, as_json=None): if as_json or isinstance(value, (list, dict)): @@ -133,10 +133,41 @@ def __sql__(self, ctx): return ctx.sql(fn.json_extract(self._field, self.path) if self._path else self._field) +class JSONBPath(JSONPath): + def append(self, value, as_json=None): + if as_json or isinstance(value, (list, dict)): + value = fn.jsonb(self._field._json_dumps(value)) + return fn.jsonb_set(self._field, self['#'].path, value) + + def _json_operation(self, func, value, as_json=None): + if as_json or isinstance(value, (list, dict)): + value = fn.jsonb(self._field._json_dumps(value)) + return func(self._field, self.path, value) + + def insert(self, value, as_json=None): + return self._json_operation(fn.jsonb_insert, value, as_json) + + def set(self, value, as_json=None): + return self._json_operation(fn.jsonb_set, value, as_json) + + def replace(self, value, as_json=None): + return self._json_operation(fn.jsonb_replace, value, as_json) + + def update(self, value): + return self.set(fn.jsonb_patch(self, self._field._json_dumps(value))) + + def remove(self): + return fn.jsonb_remove(self._field, self.path) + + def __sql__(self, ctx): + return ctx.sql(fn.jsonb_extract(self._field, self.path) + if self._path else self._field) + class JSONField(TextField): field_type = 'JSON' unpack = False + Path = JSONPath def __init__(self, json_dumps=None, json_loads=None, **kwargs): self._json_dumps = json_dumps or json.dumps @@ -171,7 +202,7 @@ def inner(self, rhs): __hash__ = Field.__hash__ def __getitem__(self, item): - return JSONPath(self)[item] + return self.Path(self)[item] def extract(self, *paths): paths = [Value(p, converter=False) for p in paths] @@ -182,23 +213,23 @@ def extract_text(self, path): return Expression(self, '->>', Value(path, converter=False)) def append(self, value, as_json=None): - return JSONPath(self).append(value, as_json) + return self.Path(self).append(value, as_json) def insert(self, value, as_json=None): - return JSONPath(self).insert(value, as_json) + return self.Path(self).insert(value, as_json) def set(self, value, as_json=None): - return JSONPath(self).set(value, as_json) + return self.Path(self).set(value, as_json) def replace(self, value, as_json=None): - return JSONPath(self).replace(value, as_json) + return self.Path(self).replace(value, as_json) def update(self, data): - return JSONPath(self).update(data) + return self.Path(self).update(data) def remove(self, *paths): if not paths: - return JSONPath(self).remove() + return self.Path(self).remove() return fn.json_remove(self, *paths) def json_type(self): @@ -229,6 +260,26 @@ def tree(self): return fn.json_tree(self) +class JSONBField(JSONField): + field_type = 'JSONB' + Path = JSONBPath + + def db_value(self, value): + if value is not None: + if not isinstance(value, Node): + value = fn.jsonb(self._json_dumps(value)) + return value + + def extract(self, *paths): + paths = [Value(p, converter=False) for p in paths] + return fn.jsonb_extract(self, *paths) + + def remove(self, *paths): + if not paths: + return self.Path(self).remove() + return fn.jsonb_remove(self, *paths) + + class SearchField(Field): def __init__(self, unindexed=False, column_name=None, **k): if k: diff --git a/tests/sqlite.py b/tests/sqlite.py index 6a5e5968b..191ec5375 100644 --- a/tests/sqlite.py +++ b/tests/sqlite.py @@ -25,6 +25,7 @@ from .sqlite_helpers import json_installed from .sqlite_helpers import json_patch_installed from .sqlite_helpers import json_text_installed +from .sqlite_helpers import jsonb_installed database = SqliteExtDatabase(':memory:', c_extensions=False, timeout=100) @@ -125,6 +126,10 @@ class KeyData(TestModel): key = TextField() data = JSONField() +class JBData(TestModel): + key = TextField() + data = JSONBField() + class Values(TestModel): klass = IntegerField() @@ -532,9 +537,11 @@ class TestJSONFieldFunctions(ModelTestCase): ('d', {'x1': {'y1': 'z1', 'y2': 'z2'}}), ('e', {'l1': [0, 1, 2], 'l2': [1, [3, 3], 7]}), ] + M = KeyData def setUp(self): super(TestJSONFieldFunctions, self).setUp() + KeyData = self.M with self.database.atomic(): for key, data in self.test_data: KeyData.create(key=key, data=data) @@ -545,9 +552,11 @@ def assertRows(self, where, expected): self.assertEqual([kd.key for kd in self.Q.where(where)], expected) def assertData(self, key, expected): + KeyData = self.M self.assertEqual(KeyData.get(KeyData.key == key).data, expected) def test_json_group_functions(self): + KeyData = self.M with self.database.atomic(): KeyData.delete().execute() for i in range(10): @@ -597,6 +606,7 @@ def test_json_group_functions(self): self.assertEqual(query.scalar(), {'k0': 0, 'k1': 1, 'k2': 2, 'k3': 3}) def test_extract(self): + KeyData = self.M self.assertRows((KeyData.data['k1'] == 'v1'), ['a', 'c']) self.assertRows((KeyData.data['k2'] == 'v2'), ['b', 'c']) self.assertRows((KeyData.data['x1']['y1'] == 'z1'), ['a', 'd']) @@ -605,6 +615,7 @@ def test_extract(self): @skip_unless(json_text_installed()) def test_extract_text_json(self): + KeyData = self.M D = KeyData.data self.assertRows((D.extract('$.k1') == 'v1'), ['a', 'c']) self.assertRows((D.extract_text('$.k1') == 'v1'), ['a', 'c']) @@ -618,6 +629,7 @@ def test_extract_text_json(self): self.assertRows((D.extract_json('x1') == '{"y1":"z1"}'), ['a']) def test_extract_multiple(self): + KeyData = self.M query = KeyData.select( KeyData.key, KeyData.data.extract('$.k1', '$.k2').alias('keys')) @@ -629,6 +641,7 @@ def test_extract_multiple(self): ('e', [None, None])]) def test_insert(self): + KeyData = self.M # Existing values are not overwritten. query = KeyData.update(data=KeyData.data['k1'].insert('v1-x')) self.assertEqual(query.execute(), 5) @@ -641,6 +654,7 @@ def test_insert(self): 'l2': [1, [3, 3], 7]}) def test_insert_json(self): + KeyData = self.M set_json = KeyData.data['k1'].insert([0]) query = KeyData.update(data=set_json) self.assertEqual(query.execute(), 5) @@ -653,6 +667,7 @@ def test_insert_json(self): 'l2': [1, [3, 3], 7]}) def test_replace(self): + KeyData = self.M # Only existing values are overwritten. query = KeyData.update(data=KeyData.data['k1'].replace('v1-x')) self.assertEqual(query.execute(), 5) @@ -664,6 +679,7 @@ def test_replace(self): self.assertData('e', {'l1': [0, 1, 2], 'l2': [1, [3, 3], 7]}) def test_replace_json(self): + KeyData = self.M set_json = KeyData.data['k1'].replace([0]) query = KeyData.update(data=set_json) self.assertEqual(query.execute(), 5) @@ -675,6 +691,7 @@ def test_replace_json(self): self.assertData('e', {'l1': [0, 1, 2], 'l2': [1, [3, 3], 7]}) def test_set(self): + KeyData = self.M query = (KeyData .update({KeyData.data: KeyData.data['k1'].set('v1-x')}) .where(KeyData.data['k1'] == 'v1')) @@ -684,6 +701,7 @@ def test_set(self): self.assertData('a', {'k1': 'v1-x', 'x1': {'y1': 'z1'}}) def test_set_json(self): + KeyData = self.M set_json = KeyData.data['x1'].set({'y1': 'z1-x', 'y3': 'z3'}) query = (KeyData .update({KeyData.data: set_json}) @@ -695,6 +713,7 @@ def test_set_json(self): self.assertData('d', {'x1': {'y1': 'z1-x', 'y3': 'z3'}}) def test_append(self): + KeyData = self.M for value in ('ix', [], ['c1'], ['c1', 'c2'], {}, {'k1': 'v1'}, {'k1': 'v1', 'k2': 'v2'}, None, 1): KeyData.delete().execute() @@ -710,7 +729,9 @@ def test_append(self): .where(KeyData.key.startswith('a'))) self.assertEqual(query.execute(), 3) - query = KeyData.select().where(KeyData.key.startswith('a')) + query = (KeyData + .select(KeyData.key, fn.json(KeyData.data)) + .where(KeyData.key.startswith('a'))) self.assertEqual(sorted((row.key, row.data) for row in query), [('a0', [value]), ('a1', ['i1', value]), ('a2', ['i1', 'i2', value])]) @@ -720,7 +741,9 @@ def test_append(self): .where(KeyData.key.startswith('n'))) self.assertEqual(query.execute(), 3) - query = KeyData.select().where(KeyData.key.startswith('n')) + query = (KeyData + .select(KeyData.key, fn.json(KeyData.data)) + .where(KeyData.key.startswith('n'))) self.assertEqual(sorted((row.key, row.data) for row in query), [('n0', {'arr': [value]}), ('n1', {'arr': ['i1', value]}), @@ -728,6 +751,7 @@ def test_append(self): @skip_unless(json_patch_installed()) def test_update(self): + KeyData = self.M merged = KeyData.data.update({'x1': {'y1': 'z1-x', 'y3': 'z3'}}) query = (KeyData .update({KeyData.data: merged}) @@ -740,6 +764,7 @@ def test_update(self): @skip_unless(json_patch_installed()) def test_update_with_removal(self): + KeyData = self.M m = KeyData.data.update({'k1': None, 'x1': {'y1': None, 'y3': 'z3'}}) query = KeyData.update(data=m).where(KeyData.data['x1']['y1'] == 'z1') self.assertEqual(query.execute(), 2) @@ -750,6 +775,7 @@ def test_update_with_removal(self): @skip_unless(json_patch_installed()) def test_update_nested(self): + KeyData = self.M merged = KeyData.data['x1'].update({'y1': 'z1-x', 'y3': 'z3'}) query = (KeyData .update(data=merged) @@ -762,6 +788,7 @@ def test_update_nested(self): @skip_unless(json_patch_installed()) def test_updated_nested_with_removal(self): + KeyData = self.M merged = KeyData.data['x1'].update({'o1': 'p1', 'y1': None}) nrows = (KeyData .update(data=merged) @@ -772,6 +799,7 @@ def test_updated_nested_with_removal(self): self.assertData('d', {'x1': {'o1': 'p1', 'y2': 'z2'}}) def test_remove(self): + KeyData = self.M query = (KeyData .update(data=KeyData.data['k1'].remove()) .where(KeyData.data['k1'] == 'v1')) @@ -787,14 +815,16 @@ def test_remove(self): self.assertData('e', {'l1': [0, 1, 2], 'l2': [1, [3], 7]}) def test_simple_update(self): + KeyData = self.M nrows = (KeyData .update(data={'foo': 'bar'}) .where(KeyData.key.in_(['a', 'b'])) .execute()) - for k in self.Q.where(KeyData.key.in_(['a', 'b'])): - self.assertEqual(k.data, {'foo': 'bar'}) + self.assertData('a', {'foo': 'bar'}) + self.assertData('b', {'foo': 'bar'}) def test_children(self): + KeyData = self.M children = KeyData.data.children().alias('children') query = (KeyData .select(KeyData.key, children.c.fullkey.alias('fullkey')) @@ -809,6 +839,7 @@ def test_children(self): ('e', '$.l1'), ('e', '$.l2')]) def test_tree(self): + KeyData = self.M tree = KeyData.data.tree().alias('tree') query = (KeyData .select(tree.c.fullkey.alias('fullkey')) @@ -823,6 +854,29 @@ def test_tree(self): '$.x1.y2']) +@skip_unless(jsonb_installed(), 'requires sqlite jsonb support') +class TestJSONBFieldFunctions(TestJSONFieldFunctions): + requires = [JBData] + M = JBData + + def assertData(self, key, expected): + q = JBData.select(fn.json(JBData.data)).where(JBData.key == key) + self.assertEqual(q.get().data, expected) + + def test_extract_multiple(self): + # We need to override this, otherwise we end up with jsonb returned. + expr = fn.json(JBData.data.extract('$.k1', '$.k2')) + query = JBData.select( + JBData.key, + expr.python_value(json.loads).alias('keys')) + self.assertEqual(sorted((k.key, k.keys) for k in query), [ + ('a', ['v1', None]), + ('b', [None, 'v2']), + ('c', ['v1', 'v2']), + ('d', [None, None]), + ('e', [None, None])]) + + class TestSqliteExtensions(BaseTestCase): def test_virtual_model(self): class Test(VirtualModel): diff --git a/tests/sqlite_helpers.py b/tests/sqlite_helpers.py index 659ce2eb6..92ea4c1d8 100644 --- a/tests/sqlite_helpers.py +++ b/tests/sqlite_helpers.py @@ -4,7 +4,6 @@ def json_installed(): if sqlite3.sqlite_version_info < (3, 9, 0): return False - # Test in-memory DB to determine if the FTS5 extension is installed. tmp_db = sqlite3.connect(':memory:') try: tmp_db.execute('select json(?)', (1337,)) @@ -22,6 +21,9 @@ def json_patch_installed(): def json_text_installed(): return sqlite3.sqlite_version_info >= (3, 38, 0) +def jsonb_installed(): + return sqlite3.sqlite_version_info >= (3, 45, 0) + def compile_option(p): if not hasattr(compile_option, '_pragma_cache'):