diff --git a/.gitignore b/.gitignore
index a4fdd01ebc..975a88cd25 100644
--- a/.gitignore
+++ b/.gitignore
@@ -43,3 +43,6 @@ pplot_out/
 
 # docker
 docker-bake.override.json
+
+# benchmark
+.benchmarks/
diff --git a/src/aiida/storage/sqlite_zip/orm.py b/src/aiida/storage/sqlite_zip/orm.py
index 0f51c12534..26902e1e5f 100644
--- a/src/aiida/storage/sqlite_zip/orm.py
+++ b/src/aiida/storage/sqlite_zip/orm.py
@@ -19,7 +19,7 @@
 
 from sqlalchemy import JSON, case, func, select
 from sqlalchemy.orm.util import AliasedClass
-from sqlalchemy.sql import ColumnElement
+from sqlalchemy.sql import ColumnElement, null
 
 from aiida.common.lang import type_check
 from aiida.storage.psql_dos.orm import authinfos, comments, computers, entities, groups, logs, nodes, users, utils
@@ -285,8 +285,18 @@ def _cast_json_type(comparator: JSON.Comparator, value: Any) -> Tuple[ColumnElem
             return case((type_filter, casted_entity.ilike(value, escape='\\')), else_=False)
 
         if operator == 'contains':
-            # to-do, see: https://github.com/sqlalchemy/sqlalchemy/discussions/7836
-            raise NotImplementedError('The operator `contains` is not implemented for SQLite-based storage plugins.')
+            # Mirror the PostgreSQL backend, where `contains` evaluates to NULL if `attr_key`
+            # does not exist, by returning NULL directly from an additional CASE branch.
+            #
+            # `database_entity` cannot be used for this check, as it would be interpreted as
+            # a 'null' string in SQL instead of a proper NULL. `json_type` is used rather than
+            # `json_extract` because it is NULL only for a path that does not exist, whereas
+            # `json_extract` is also NULL for a stored JSON `null`, which would wrongly be
+            # treated as a missing key.
+            return case(
+                (func.json_type(column, '$.' + '.'.join(attr_key)).is_(null()), null()),
+                else_=func.json_contains(database_entity, json.dumps(value)),
+            )
 
         if operator == 'has_key':
             return (
diff --git a/src/aiida/storage/sqlite_zip/utils.py b/src/aiida/storage/sqlite_zip/utils.py
index 2438c18fcb..a641a0195f 100644
--- a/src/aiida/storage/sqlite_zip/utils.py
+++ b/src/aiida/storage/sqlite_zip/utils.py
@@ -48,12 +48,58 @@ def sqlite_case_sensitive_like(dbapi_connection, _):
     cursor.close()
 
 
+def _contains(lhs: Union[dict, list], rhs: Union[dict, list]):
+    """Return whether the decoded JSON value ``lhs`` contains ``rhs``.
+
+    * Objects: every key of ``rhs`` must exist in ``lhs`` with a value containing the ``rhs`` value.
+    * Arrays: every element of ``rhs`` must be contained in at least one element of ``lhs``.
+    * Anything else: the two values must be equal.
+    """
+    if isinstance(lhs, dict) and isinstance(rhs, dict):
+        return all(key in lhs and _contains(lhs[key], value) for key, value in rhs.items())
+
+    if isinstance(lhs, list) and isinstance(rhs, list):
+        return all(any(_contains(element, item) for element in lhs) for item in rhs)
+
+    # `bool` is a subclass of `int`, so `True == 1` in Python, whereas the JSON values
+    # `true` and `1` are distinct: a boolean is only contained by the same boolean.
+    if isinstance(lhs, bool) or isinstance(rhs, bool):
+        return lhs is rhs
+
+    return lhs == rhs
+
+
+def _json_contains(lhs: Union[str, bytes, bytearray, dict, list], rhs: Union[str, bytes, bytearray, dict, list]):
+    """Implement the SQL function ``json_contains(lhs, rhs)``.
+
+    ``str``, ``bytes`` and ``bytearray`` arguments are decoded as JSON documents first.
+
+    :return: ``1`` if ``lhs`` contains ``rhs``, ``0`` if it does not or if an argument is not valid JSON.
+    """
+    try:
+        if isinstance(lhs, (str, bytes, bytearray)):
+            lhs = json.loads(lhs)
+        if isinstance(rhs, (str, bytes, bytearray)):
+            rhs = json.loads(rhs)
+    except ValueError:
+        # `ValueError` covers `json.JSONDecodeError` and the `UnicodeDecodeError` that
+        # `json.loads` raises for byte strings in an invalid encoding.
+        return 0
+    return int(_contains(lhs, rhs))
+
+
+def register_json_contains(dbapi_connection, _):
+    """Register ``json_contains`` as a two-argument SQL function on a new connection."""
+    dbapi_connection.create_function('json_contains', 2, _json_contains)
+
+
 def create_sqla_engine(path: Union[str, Path], *, enforce_foreign_keys: bool = True, **kwargs) -> Engine:
     """Create a new engine instance."""
     engine = create_engine(f'sqlite:///{path}', json_serializer=json.dumps, json_deserializer=json.loads, **kwargs)
     event.listen(engine, 'connect', sqlite_case_sensitive_like)
     if enforce_foreign_keys:
         event.listen(engine, 'connect', sqlite_enforce_foreign_keys)
+    event.listen(engine, 'connect', register_json_contains)
     return engine
 
 
diff --git a/tests/benchmark/test_json_contains.py b/tests/benchmark/test_json_contains.py
new file mode 100644
index 0000000000..3ec2393b17
--- /dev/null
+++ b/tests/benchmark/test_json_contains.py
@@ -0,0 +1,148 @@
+import random
+import string
+
+import pytest
+
+from aiida import orm
+from aiida.orm.querybuilder import QueryBuilder
+
+GROUP_NAME = 'json-contains'
+
+
+COMPLEX_JSON_DEPTH_RANGE = [2**i for i in range(4)]
+COMPLEX_JSON_BREADTH_RANGE = [2**i for i in range(4)]
+LARGE_TABLE_SIZE_RANGE = [2**i for i in range(1, 11)]
+
+
+def gen_json(depth: int, breadth: int):
+    """Return a random JSON document: ``depth`` levels of dicts/lists with ``breadth`` children each."""
+
+    def gen_str(n: int, with_digits: bool = True):
+        population = string.ascii_letters
+        if with_digits:
+            population += string.digits
+        return ''.join(random.choices(population, k=n))
+
+    if depth == 0:  # random primitive value
+        # real numbers are not included as their equivalence is tricky
+        return random.choice(
+            [
+                random.randint(-114, 514),  # integers
+                gen_str(6),  # strings
+                random.choice([True, False]),  # booleans
+                None,  # nulls
+            ]
+        )
+
+    else:
+        gen_dict = random.choice([True, False])
+        data = [gen_json(depth - 1, breadth) for _ in range(breadth)]
+        if gen_dict:
+            keys = set()
+            while len(keys) < breadth:
+                keys.add(gen_str(6, False))
+            # sort the keys: set iteration order changes between runs due to hash randomization
+            data = dict(zip(sorted(keys), data))
+        return data
+
+
+def extract_component(data, p: float = -1):
+    """Return one random nested path of ``data``, or ``data`` itself with probability ``p``."""
+    if random.random() < p:
+        return data
+
+    if isinstance(data, dict) and data:
+        key = random.choice(list(data.keys()))
+        return {key: extract_component(data[key])}
+    elif isinstance(data, list) and data:
+        element = random.choice(data)
+        return [extract_component(element)]
+    else:
+        return data
+
+
+@pytest.mark.benchmark(group=GROUP_NAME)
+@pytest.mark.parametrize('depth', [1, 2, 4, 8])
+@pytest.mark.parametrize('breadth', [1, 2, 4])
+@pytest.mark.usefixtures('aiida_profile_clean')
+def test_deep_json(benchmark, depth, breadth):
+    """Benchmark `contains` on a single node with a deeply nested JSON attribute."""
+    random.seed(f'deep-{depth}-{breadth}')  # benchmark the same documents on every run
+    lhs = gen_json(depth, breadth)
+    rhs = extract_component(lhs, p=1.0 / depth)
+    assert 0 == len(QueryBuilder().append(orm.Dict).all())
+
+    orm.Dict(
+        {
+            'id': f'{depth}-{breadth}',
+            'data': lhs,
+        }
+    ).store()
+    qb = QueryBuilder().append(
+        orm.Dict,
+        filters={
+            'attributes.data': {'contains': rhs},
+        },
+        project=['attributes.id'],
+    )
+    qb.all()
+    result = benchmark(qb.all)
+    assert len(result) == 1
+
+
+@pytest.mark.benchmark(group=GROUP_NAME)
+@pytest.mark.parametrize('depth', [2])
+@pytest.mark.parametrize('breadth', [1, 10, 100])
+@pytest.mark.usefixtures('aiida_profile_clean')
+def test_wide_json(benchmark, depth, breadth):
+    """Benchmark `contains` on a single node with a shallow but wide JSON attribute."""
+    random.seed(f'wide-{depth}-{breadth}')  # benchmark the same documents on every run
+    lhs = gen_json(depth, breadth)
+    rhs = extract_component(lhs, p=1.0 / depth)
+    assert 0 == len(QueryBuilder().append(orm.Dict).all())
+
+    orm.Dict(
+        {
+            'id': f'{depth}-{breadth}',
+            'data': lhs,
+        }
+    ).store()
+    qb = QueryBuilder().append(
+        orm.Dict,
+        filters={
+            'attributes.data': {'contains': rhs},
+        },
+        project=['attributes.id'],
+    )
+    qb.all()
+    result = benchmark(qb.all)
+    assert len(result) == 1
+
+
+@pytest.mark.benchmark(group=GROUP_NAME)
+@pytest.mark.parametrize('num_entries', LARGE_TABLE_SIZE_RANGE)
+@pytest.mark.usefixtures('aiida_profile_clean')
+def test_large_table(benchmark, num_entries):
+    """Benchmark `contains` over a table with many nodes storing the same JSON attribute."""
+    random.seed(f'table-{num_entries}')  # benchmark the same documents on every run
+    data = gen_json(2, 10)
+    rhs = extract_component(data)
+    assert 0 == len(QueryBuilder().append(orm.Dict).all())
+
+    for i in range(num_entries):
+        orm.Dict(
+            {
+                'id': f'N={num_entries}, i={i}',
+                'data': data,
+            }
+        ).store()
+    qb = QueryBuilder().append(
+        orm.Dict,
+        filters={
+            'attributes.data': {'contains': rhs},
+        },
+        project=['attributes.id'],
+    )
+    qb.all()
+    result = benchmark(qb.all)
+    assert len(result) == num_entries
diff --git a/tests/orm/test_querybuilder.py b/tests/orm/test_querybuilder.py
index 6d5d7cefc4..5dcf2c2e58 100644
--- a/tests/orm/test_querybuilder.py
+++ b/tests/orm/test_querybuilder.py
@@ -1711,6 +1711,14 @@ def test_statistics_default_class(self, aiida_localhost):
 
 
 class TestJsonFilters:
+    @staticmethod
+    def assert_match(data, filters, is_match):
+        """Store ``data`` as a ``Dict`` and assert whether ``filters`` matches it."""
+        orm.Dict(data).store()
+        count = orm.QueryBuilder().append(orm.Dict, filters=filters).count()
+        assert count in {0, 1}
+        assert (count == 1) == is_match
+
     @pytest.mark.parametrize(
         'data,filters,is_match',
         (
@@ -1740,26 +1748,43 @@ class
TestJsonFilters: ({'arr': [1, '2', None]}, {'attributes.arr': {'!contains': []}}, False), ({'arr': [1, '2', None]}, {'attributes.arr': {'!contains': [114514]}}, True), ({'arr': [1, '2', None]}, {'attributes.arr': {'!contains': [1, 114514]}}, True), - # TODO: these pass, but why? are these behaviors expected? - # non-exist `attr_key`s - ({'foo': []}, {'attributes.arr': {'contains': []}}, False), - ({'foo': []}, {'attributes.arr': {'!contains': []}}, False), + # when attr_key does not exist, `contains` returns `NULL` + ({'arr': [1, '2', None]}, {'attributes.x': {'!contains': []}}, False), + ({'arr': [1, '2', None]}, {'attributes.x': {'contains': []}}, False), ), ids=json.dumps, ) @pytest.mark.usefixtures('aiida_profile_clean') - @pytest.mark.requires_psql def test_json_filters_contains_arrays(self, data, filters, is_match): """Test QueryBuilder filter `contains` for JSON array fields""" - orm.Dict(data).store() - qb = orm.QueryBuilder().append(orm.Dict, filters=filters) - assert qb.count() in {0, 1} - found = qb.count() == 1 - assert found == is_match + self.assert_match(data, filters, is_match) @pytest.mark.parametrize( 'data,filters,is_match', ( + # when attr_key does not exist, `contains` returns `NULL` + ( + { + 'dict': { + 'k1': 1, + 'k2': '2', + 'k3': None, + } + }, + {'attributes.foobar': {'!contains': {}}}, + False, + ), + ( + { + 'dict': { + 'k1': 1, + 'k2': '2', + 'k3': None, + } + }, + {'attributes.foobar': {'contains': {}}}, + False, + ), # contains different types of values ( { @@ -1806,6 +1831,50 @@ def test_json_filters_contains_arrays(self, data, filters, is_match): {'attributes.dict': {'contains': {}}}, True, ), + # nested dicts + ( + {'dict': {'k1': {'k2': {'kx': 1, 'k3': 'secret'}, 'kxx': None}, 'kxxx': 'vxxx'}}, + {'attributes.dict': {'contains': {'k1': {'k2': {'k3': 'secret'}}}}}, + True, + ), + ( + { + 'dict': { + 'k1': [ + 0, + 1, + { + 'k2': [ + '0', + { + 'kkk': 'vvv', + 'k3': 'secret', + }, + '2', + ] + }, + 3, + ], + 'kkk': 'vvv', + } +
}, + { + 'attributes.dict': { + 'contains': { + 'k1': [ + { + 'k2': [ + { + 'k3': 'secret', + } + ] + } + ] + } + } + }, + True, + ), # doesn't contain non-exist entries ( { @@ -1852,19 +1921,194 @@ def test_json_filters_contains_arrays(self, data, filters, is_match): {'attributes.dict': {'!contains': {}}}, False, ), - # TODO: these pass, but why? are these behaviors expected? - # non-exist `attr_key`s - ({'map': {}}, {'attributes.dict': {'contains': {}}}, False), - ({'map': {}}, {'attributes.dict': {'!contains': {}}}, False), ), ids=json.dumps, ) @pytest.mark.usefixtures('aiida_profile_clean') - @pytest.mark.requires_psql def test_json_filters_contains_object(self, data, filters, is_match): """Test QueryBuilder filter `contains` for JSON object fields""" - orm.Dict(data).store() - qb = orm.QueryBuilder().append(orm.Dict, filters=filters) - assert qb.count() in {0, 1} - found = qb.count() == 1 - assert found == is_match + self.assert_match(data, filters, is_match) + + @pytest.mark.parametrize( + 'data,filters,is_match', + ( + ({'dict': {'k1': 1, 'k2': '2', 'k3': None}}, {'attributes.dict': {'has_key': 'k1'}}, True), + ({'dict': {'k1': 1, 'k2': '2', 'k3': None}}, {'attributes.dict': {'has_key': 'k2'}}, True), + ({'dict': {'k1': 1, 'k2': '2', 'k3': None}}, {'attributes.dict': {'has_key': 'k3'}}, True), + ({'dict': {'k1': 1, 'k2': '2', 'k3': None}}, {'attributes.dict': {'!has_key': 'k1'}}, False), + ({'dict': {'k1': 1, 'k2': '2', 'k3': None}}, {'attributes.dict': {'!has_key': 'k2'}}, False), + ({'dict': {'k1': 1, 'k2': '2', 'k3': None}}, {'attributes.dict': {'!has_key': 'k3'}}, False), + ({'dict': {'k1': 1, 'k2': '2', 'k3': None}}, {'attributes.dict': {'has_key': 'non-exist'}}, False), + ({'dict': {'k1': 1, 'k2': '2', 'k3': None}}, {'attributes.dict': {'!has_key': 'non-exist'}}, True), + ({'dict': 0xFA15ED1C7}, {'attributes.dict': {'has_key': 'dict'}}, False), + ({'dict': 0xFA15ED1C7}, {'attributes.dict': {'!has_key': 'dict'}}, True), + ), + ) + 
@pytest.mark.usefixtures('aiida_profile_clean') + def test_json_filters_has_key(self, data, filters, is_match): + self.assert_match(data, filters, is_match) + + @pytest.mark.parametrize( + 'filters,matches', + ( + # type match + ({'attributes.text': {'of_type': 'string'}}, 1), + ({'attributes.integer': {'of_type': 'number'}}, 1), + ({'attributes.float': {'of_type': 'number'}}, 1), + ({'attributes.true': {'of_type': 'boolean'}}, 1), + ({'attributes.false': {'of_type': 'boolean'}}, 1), + ({'attributes.null': {'of_type': 'null'}}, 2), + ({'attributes.list': {'of_type': 'array'}}, 1), + ({'attributes.dict': {'of_type': 'object'}}, 1), + # equality match + ({'attributes.text': {'==': 'abcXYZ'}}, 1), + ({'attributes.integer': {'==': 1}}, 1), + ({'attributes.float': {'==': 1.1}}, 1), + ({'attributes.true': {'==': True}}, 1), + ({'attributes.false': {'==': False}}, 1), + ({'attributes.list': {'==': [1, 2]}}, 1), + ({'attributes.list2': {'==': ['a', 'b']}}, 1), + ({'attributes.dict': {'==': {'key-1': 1, 'key-none': None}}}, 1), + # equality non-match + ({'attributes.text': {'==': 'lmn'}}, 0), + ({'attributes.integer': {'==': 2}}, 0), + ({'attributes.float': {'==': 2.2}}, 0), + ({'attributes.true': {'==': False}}, 0), + ({'attributes.false': {'==': True}}, 0), + ({'attributes.list': {'==': [1, 3]}}, 0), + # text regexes + ({'attributes.text': {'like': 'abcXYZ'}}, 1), + ({'attributes.text': {'like': 'abcxyz'}}, 0), + ({'attributes.text': {'ilike': 'abcxyz'}}, 1), + ({'attributes.text': {'like': 'abc%'}}, 1), + ({'attributes.text': {'like': 'abc_YZ'}}, 1), + ( + { + 'attributes.text2': { + 'like': 'abc\\_XYZ' # Literal match + } + }, + 1, + ), + ({'attributes.text2': {'like': 'abc_XYZ'}}, 2), + # integer comparisons + ({'attributes.float': {'<': 1}}, 0), + ({'attributes.float': {'<': 2}}, 1), + ({'attributes.float': {'>': 2}}, 0), + ({'attributes.float': {'>': 0}}, 1), + ({'attributes.integer': {'<': 1}}, 0), + ({'attributes.integer': {'<': 2}}, 1), + ({'attributes.integer': 
{'>': 2}}, 0), + ({'attributes.integer': {'>': 0}}, 1), + # float comparisons + ({'attributes.float': {'<': 0.99}}, 0), + ({'attributes.float': {'<': 2.01}}, 1), + ({'attributes.float': {'>': 2.01}}, 0), + ({'attributes.float': {'>': 0.01}}, 1), + ({'attributes.integer': {'<': 0.99}}, 0), + ({'attributes.integer': {'<': 2.01}}, 1), + ({'attributes.integer': {'>': 2.01}}, 0), + ({'attributes.integer': {'>': 0.01}}, 1), + # array operators + ({'attributes.list': {'of_length': 0}}, 0), + ({'attributes.list': {'of_length': 2}}, 1), + ({'attributes.list': {'longer': 3}}, 0), + ({'attributes.list': {'longer': 1}}, 1), + ({'attributes.list': {'shorter': 1}}, 0), + ({'attributes.list': {'shorter': 3}}, 1), + # in operator + ({'attributes.text': {'in': ['x', 'y', 'z']}}, 0), + ({'attributes.text': {'in': ['x', 'y', 'abcXYZ']}}, 1), + ({'attributes.integer': {'in': [5, 6, 7]}}, 0), + ({'attributes.integer': {'in': [1, 2, 3]}}, 1), + ), + ids=json.dumps, + ) + @pytest.mark.usefixtures('aiida_profile_clean') + def test_json_filters(self, filters, matches): + """Test QueryBuilder filtering for JSON fields.""" + orm.Dict( + { + 'text': 'abcXYZ', + 'text2': 'abc_XYZ', + 'integer': 1, + 'float': 1.1, + 'true': True, + 'false': False, + 'null': None, + 'list': [1, 2], + 'list2': ['a', 'b'], + 'dict': { + 'key-1': 1, + 'key-none': None, + }, + }, + ).store() + orm.Dict({'text2': 'abcxXYZ'}).store() + + qbuilder = orm.QueryBuilder() + qbuilder.append(orm.Dict, filters=filters) + assert qbuilder.count() == matches + + @pytest.mark.parametrize( + 'filters,matches', + ( + ({'label': {'like': 'abc_XYZ'}}, 2), + ({'label': {'like': 'abc\\_XYZ'}}, 1), + ({'label': {'like': 'abcxXYZ'}}, 1), + ({'label': {'like': 'abc%XYZ'}}, 2), + ), + ids=json.dumps, + ) + @pytest.mark.usefixtures('aiida_profile_clean') + def test_column_filters(self, filters, matches): + """Test querying directly those stored in the columns""" + dict1 = orm.Dict( + { + 'text2': 'abc_XYZ', + } + ).store() + dict2 = 
orm.Dict({'text2': 'abcxXYZ'}).store() + dict1.label = 'abc_XYZ' + dict2.label = 'abcxXYZ' + qbuilder = orm.QueryBuilder() + qbuilder.append(orm.Dict, filters=filters) + assert qbuilder.count() == matches + + @pytest.mark.parametrize( + 'key,cast_type', + ( + ('text', 't'), + ('integer', 'i'), + ('float', 'f'), + ), + ) + @pytest.mark.usefixtures('aiida_profile_clean') + def test_json_order_by(self, key, cast_type): + """Test QueryBuilder ordering by JSON field keys.""" + dict1 = orm.Dict( + { + 'text': 'b', + 'integer': 2, + 'float': 2.2, + } + ).store() + dict2 = orm.Dict( + { + 'text': 'a', + 'integer': 1, + 'float': 1.1, + } + ).store() + dict3 = orm.Dict( + { + 'text': 'c', + 'integer': 3, + 'float': 3.3, + } + ).store() + qbuilder = orm.QueryBuilder() + qbuilder.append(orm.Dict, tag='dict', project=['id']).order_by( + {'dict': {f'attributes.{key}': {'order': 'asc', 'cast': cast_type}}} + ) + assert qbuilder.all(flat=True) == [dict2.pk, dict1.pk, dict3.pk] diff --git a/tests/storage/sqlite/test_orm.py b/tests/storage/sqlite/test_orm.py deleted file mode 100644 index 8a03ce034f..0000000000 --- a/tests/storage/sqlite/test_orm.py +++ /dev/null @@ -1,201 +0,0 @@ -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Test for the ORM implementation.""" - -import json - -import pytest - -from aiida.orm import Dict, QueryBuilder -from aiida.storage.sqlite_temp import SqliteTempBackend - - -@pytest.mark.parametrize( - 'filters,matches', - ( - # type match - ({'attributes.text': {'of_type': 'string'}}, 1), - ({'attributes.integer': {'of_type': 'number'}}, 1), - ({'attributes.float': {'of_type': 'number'}}, 1), - ({'attributes.true': {'of_type': 'boolean'}}, 1), - ({'attributes.false': {'of_type': 'boolean'}}, 1), - ({'attributes.null': {'of_type': 'null'}}, 3), - ({'attributes.list': {'of_type': 'array'}}, 1), - ({'attributes.dict': {'of_type': 'object'}}, 1), - # equality match - ({'attributes.text': {'==': 'abcXYZ'}}, 1), - ({'attributes.integer': {'==': 1}}, 1), - ({'attributes.float': {'==': 1.1}}, 1), - ({'attributes.true': {'==': True}}, 1), - ({'attributes.false': {'==': False}}, 1), - ({'attributes.list': {'==': [1, 2]}}, 1), - ({'attributes.list2': {'==': ['a', 'b']}}, 1), - ({'attributes.dict': {'==': {'key-1': 1, 'key-none': None}}}, 1), - # equality non-match - ({'attributes.text': {'==': 'lmn'}}, 0), - ({'attributes.integer': {'==': 2}}, 0), - ({'attributes.float': {'==': 2.2}}, 0), - ({'attributes.true': {'==': False}}, 0), - ({'attributes.false': {'==': True}}, 0), - ({'attributes.list': {'==': [1, 3]}}, 0), - # text regexes - ({'attributes.text': {'like': 'abcXYZ'}}, 1), - ({'attributes.text': {'like': 'abcxyz'}}, 0), - ({'attributes.text': {'ilike': 'abcxyz'}}, 1), - ({'attributes.text': {'like': 'abc%'}}, 1), - ({'attributes.text': {'like': 'abc_YZ'}}, 1), - ( - { - 'attributes.text2': { - 'like': 'abc\\_XYZ' # Literal match - } - }, - 1, - ), - ({'attributes.text2': 
{'like': 'abc_XYZ'}}, 2), - # integer comparisons - ({'attributes.float': {'<': 1}}, 0), - ({'attributes.float': {'<': 2}}, 1), - ({'attributes.float': {'>': 2}}, 0), - ({'attributes.float': {'>': 0}}, 1), - ({'attributes.integer': {'<': 1}}, 0), - ({'attributes.integer': {'<': 2}}, 1), - ({'attributes.integer': {'>': 2}}, 0), - ({'attributes.integer': {'>': 0}}, 1), - # float comparisons - ({'attributes.float': {'<': 0.99}}, 0), - ({'attributes.float': {'<': 2.01}}, 1), - ({'attributes.float': {'>': 2.01}}, 0), - ({'attributes.float': {'>': 0.01}}, 1), - ({'attributes.integer': {'<': 0.99}}, 0), - ({'attributes.integer': {'<': 2.01}}, 1), - ({'attributes.integer': {'>': 2.01}}, 0), - ({'attributes.integer': {'>': 0.01}}, 1), - # array operators - ({'attributes.list': {'of_length': 0}}, 0), - ({'attributes.list': {'of_length': 2}}, 1), - ({'attributes.list': {'longer': 3}}, 0), - ({'attributes.list': {'longer': 1}}, 1), - ({'attributes.list': {'shorter': 1}}, 0), - ({'attributes.list': {'shorter': 3}}, 1), - # in operator - ({'attributes.text': {'in': ['x', 'y', 'z']}}, 0), - ({'attributes.text': {'in': ['x', 'y', 'abcXYZ']}}, 1), - ({'attributes.integer': {'in': [5, 6, 7]}}, 0), - ({'attributes.integer': {'in': [1, 2, 3]}}, 1), - # object operators - ({'attributes.dict': {'has_key': 'non-exist'}}, 0), - ({'attributes.dict': {'!has_key': 'non-exist'}}, 3), - ({'attributes.dict': {'has_key': 'key-1'}}, 1), - ({'attributes.dict': {'has_key': 'key-none'}}, 1), - ({'attributes.dict': {'!has_key': 'key-none'}}, 2), - ), - ids=json.dumps, -) -def test_qb_json_filters(filters, matches): - """Test QueryBuilder filtering for JSON fields.""" - profile = SqliteTempBackend.create_profile(debug=False) - backend = SqliteTempBackend(profile) - Dict( - { - 'text': 'abcXYZ', - 'text2': 'abc_XYZ', - 'integer': 1, - 'float': 1.1, - 'true': True, - 'false': False, - 'null': None, - 'list': [1, 2], - 'list2': ['a', 'b'], - 'dict': { - 'key-1': 1, - 'key-none': None, - }, - }, - 
backend=backend, - ).store() - Dict({'text2': 'abcxXYZ'}, backend=backend).store() - - # a false dict, added to test `has_key`'s behavior when key is not of json type - Dict({'dict': 0xFA15ED1C7}, backend=backend).store() - - qbuilder = QueryBuilder(backend=backend) - qbuilder.append(Dict, filters=filters) - assert qbuilder.count() == matches - - -@pytest.mark.parametrize( - 'filters,matches', - ( - ({'label': {'like': 'abc_XYZ'}}, 2), - ({'label': {'like': 'abc\\_XYZ'}}, 1), - ({'label': {'like': 'abcxXYZ'}}, 1), - ({'label': {'like': 'abc%XYZ'}}, 2), - ), - ids=json.dumps, -) -def test_qb_column_filters(filters, matches): - """Test querying directly those stored in the columns""" - profile = SqliteTempBackend.create_profile(debug=False) - backend = SqliteTempBackend(profile) - dict1 = Dict( - { - 'text2': 'abc_XYZ', - }, - backend=backend, - ).store() - dict2 = Dict({'text2': 'abcxXYZ'}, backend=backend).store() - dict1.label = 'abc_XYZ' - dict2.label = 'abcxXYZ' - qbuilder = QueryBuilder(backend=backend) - qbuilder.append(Dict, filters=filters) - assert qbuilder.count() == matches - - -@pytest.mark.parametrize( - 'key,cast_type', - ( - ('text', 't'), - ('integer', 'i'), - ('float', 'f'), - ), -) -def test_qb_json_order_by(key, cast_type): - """Test QueryBuilder ordering by JSON field keys.""" - profile = SqliteTempBackend.create_profile(debug=False) - backend = SqliteTempBackend(profile) - dict1 = Dict( - { - 'text': 'b', - 'integer': 2, - 'float': 2.2, - }, - backend=backend, - ).store() - dict2 = Dict( - { - 'text': 'a', - 'integer': 1, - 'float': 1.1, - }, - backend=backend, - ).store() - dict3 = Dict( - { - 'text': 'c', - 'integer': 3, - 'float': 3.3, - }, - backend=backend, - ).store() - qbuilder = QueryBuilder(backend=backend) - qbuilder.append(Dict, tag='dict', project=['id']).order_by( - {'dict': {f'attributes.{key}': {'order': 'asc', 'cast': cast_type}}} - ) - assert qbuilder.all(flat=True) == [dict2.pk, dict1.pk, dict3.pk] diff --git 
a/tests/storage/sqlite_zip/test_utils.py b/tests/storage/sqlite_zip/test_utils.py
new file mode 100644
index 0000000000..5069b4612a
--- /dev/null
+++ b/tests/storage/sqlite_zip/test_utils.py
@@ -0,0 +1,130 @@
+import json
+
+import pytest
+
+from aiida.storage.sqlite_zip.utils import _contains, _json_contains
+
+
+class TestCustomFunction:
+    @pytest.mark.parametrize(
+        'lhs,rhs,is_match',
+        (
+            # contains different types of element
+            ([1, '2', None], [1], True),
+            ([1, '2', None], ['2'], True),
+            ([1, '2', None], [None], True),
+            # contains multiple elements of various types
+            ([1, '2', None], [1, None], True),
+            # does not contain non-existent elements
+            ([1, '2', None], [114514], False),
+            # contains empty set
+            ([1, '2', None], [], True),
+            ([], [], True),
+            # nested arrays
+            ([[1, 0], [0, 2]], [[1, 0]], True),
+            ([[2, 3], [0, 1], []], [[1, 0]], True),
+            ([[2, 3], [1]], [[4]], False),
+            ([[1, 0], [0, 2]], [[3]], False),
+            ([[1, 0], [0, 2]], [3], False),
+            ([[1, 0], [0, 2]], [[2]], True),
+            ([[1, 0], [0, 2]], [2], False),
+            ([[1, 0], [0, 2], 3], [[3]], False),
+            ([[1, 0], [0, 2], 3], [3], True),
+            # contains different types of values
+            (
+                {
+                    'k1': 1,
+                    'k2': '2',
+                    'k3': None,
+                },
+                {'k1': 1},
+                True,
+            ),
+            (
+                {
+                    'k1': 1,
+                    'k2': '2',
+                    'k3': None,
+                },
+                {'k1': 1, 'k2': '2'},
+                True,
+            ),
+            (
+                {
+                    'k1': 1,
+                    'k2': '2',
+                    'k3': None,
+                },
+                {'k3': None},
+                True,
+            ),
+            # contains empty set
+            (
+                {
+                    'k1': 1,
+                    'k2': '2',
+                    'k3': None,
+                },
+                {},
+                True,
+            ),
+            # nested dicts
+            (
+                {'k1': {'k2': {'kx': 1, 'k3': 'secret'}, 'kxx': None}, 'kxxx': 'vxxx'},
+                {'k1': {'k2': {'k3': 'secret'}}},
+                True,
+            ),
+            (
+                {
+                    'k1': [
+                        0,
+                        1,
+                        {
+                            'k2': [
+                                '0',
+                                {
+                                    'kkk': 'vvv',
+                                    'k3': 'secret',
+                                },
+                                '2',
+                            ]
+                        },
+                        3,
+                    ],
+                    'kkk': 'vvv',
+                },
+                {
+                    'k1': [
+                        {
+                            'k2': [
+                                {
+                                    'k3': 'secret',
+                                }
+                            ]
+                        }
+                    ]
+                },
+                True,
+            ),
+            # does not contain non-existent entries
+            (
+                {
+                    'k1': 1,
+                    'k2': '2',
+                    'k3': None,
+                },
+                {'k1': 1, 'k': 'v'},
+                False,
+            ),
+        ),
+        ids=json.dumps,
+    )
+    def test_json_contains(self, lhs, rhs, is_match):
+        """Test the `_contains` and `_json_contains` helpers on decoded and JSON-encoded arguments."""
+        lhs_json = json.dumps(lhs)
+        rhs_json = json.dumps(rhs)
+        assert is_match == _contains(lhs, rhs)
+        assert is_match == _json_contains(lhs, rhs)
+        assert is_match == _json_contains(lhs_json, rhs)
+        assert is_match == _json_contains(lhs, rhs_json)
+        assert is_match == _json_contains(lhs_json, rhs_json)