Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add PytatoKeyBuilder, persistent_dict test #459

Merged
merged 33 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
38e4332
add PytatoKeyBuilder
matthiasdiener Sep 25, 2023
4dd3250
mypy fixes
matthiasdiener Sep 25, 2023
970e7bb
support TaggableCLArray, Subscript
matthiasdiener Sep 28, 2023
95dec09
CL Array, function
matthiasdiener Sep 28, 2023
2ac10ee
add prim.Variable
matthiasdiener Feb 5, 2024
62a13ae
fixes to ndarray, pymb expressions
matthiasdiener Feb 5, 2024
b8e04bf
flake8
matthiasdiener Feb 5, 2024
ad9aa28
improve test
matthiasdiener Feb 5, 2024
60d8e41
add full invocation test
matthiasdiener Feb 5, 2024
9d45e65
lint fixes
matthiasdiener Feb 5, 2024
08be380
add missing pymbolic expressions
matthiasdiener Feb 5, 2024
058f6f9
flake8
matthiasdiener Feb 6, 2024
352bab6
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Jun 13, 2024
0360e21
remove update_for_function (now handled directly by pytools)
matthiasdiener Jun 13, 2024
8e3277c
Merge remote-tracking branch 'refs/remotes/origin/PytatoKeyBuilder' i…
matthiasdiener Jun 13, 2024
b1aaa97
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Jul 3, 2024
16516ec
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Jul 3, 2024
454f273
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Jul 25, 2024
993dfe4
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Sep 7, 2024
82a5f25
lint
matthiasdiener Sep 7, 2024
dc53746
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Sep 9, 2024
97df5d7
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Sep 19, 2024
93abdc1
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Sep 27, 2024
b35d841
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Oct 9, 2024
70a6887
add typecheck, remove pymbolic handling
matthiasdiener Oct 9, 2024
4f9f95f
lint
matthiasdiener Oct 9, 2024
cffda36
pylint
matthiasdiener Oct 9, 2024
47a31d6
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Nov 7, 2024
57908b1
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Nov 7, 2024
adfe6b2
add arraycontext package for mypy
matthiasdiener Nov 8, 2024
20ed0d2
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Nov 15, 2024
29c1833
Merge branch 'main' into PytatoKeyBuilder
matthiasdiener Nov 20, 2024
8e9fe8a
Merge branch 'main' into PytatoKeyBuilder
inducer Dec 4, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ jobs:
python-version: '3.x'
- name: "Main Script"
run: |
export EXTRA_INSTALL="git+https://github.com/inducer/arraycontext"
curl -L -O https://tiker.net/ci-support-v0
. ci-support-v0
build_py_project_in_conda_env
Expand All @@ -54,6 +55,7 @@ jobs:
python-version: '3.x'
- name: "Main Script"
run: |
export EXTRA_INSTALL="git+https://github.com/inducer/arraycontext"
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
build_py_project_in_conda_env
Expand Down
28 changes: 28 additions & 0 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any

from loopy.tools import LoopyKeyBuilder
from pymbolic.mapper.optimize import optimize_mapper
from pytools import memoize_method

Expand Down Expand Up @@ -565,4 +566,31 @@ def get_num_call_sites(outputs: Array | DictOfNamedArrays) -> int:

# }}}


# {{{ PytatoKeyBuilder

class PytatoKeyBuilder(LoopyKeyBuilder):
"""A custom :class:`pytools.persistent_dict.KeyBuilder` subclass
for objects within :mod:`pytato`.
"""
# The types below aren't immutable in general, but in the context of
# pytato, they are used as such.

def update_for_ndarray(self, key_hash: Any, key: Any) -> None:
import numpy as np
assert isinstance(key, np.ndarray)
self.rec(key_hash, key.data.tobytes())

def update_for_TaggableCLArray(self, key_hash: Any, key: Any) -> None:
from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
assert isinstance(key, TaggableCLArray)
self.rec(key_hash, key.get())

def update_for_Array(self, key_hash: Any, key: Any) -> None:
from pyopencl.array import Array
assert isinstance(key, Array)
self.rec(key_hash, key.get())

# }}}

# vim: fdm=marker
98 changes: 97 additions & 1 deletion test/test_pytato.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,99 @@ def test_dot_visualizers():
# }}}


# {{{ Test PytatoKeyBuilder

def run_test_with_new_python_invocation(f, *args, extra_env_vars=None) -> None:
import os
if extra_env_vars is None:
extra_env_vars = {}

from base64 import b64encode
from pickle import dumps
from subprocess import check_call

env_vars = {
"INVOCATION_INFO": b64encode(dumps((f, args))).decode(),
}
env_vars.update(extra_env_vars)

my_env = os.environ.copy()
my_env.update(env_vars)

check_call([sys.executable, __file__], env=my_env)


def run_test_with_new_python_invocation_inner() -> None:
import os
from base64 import b64decode
from pickle import loads

f, args = loads(b64decode(os.environ["INVOCATION_INFO"].encode()))

f(*args)


def test_persistent_hashing_and_persistent_dict() -> None:
import shutil
import tempfile

from pytools.persistent_dict import ReadOnlyEntryError, WriteOncePersistentDict

from pytato.analysis import PytatoKeyBuilder

try:
tmpdir = tempfile.mkdtemp()

pkb = PytatoKeyBuilder()

pd = WriteOncePersistentDict("test_persistent_dict",
key_builder=pkb,
container_dir=tmpdir,
safe_sync=False)

for i in range(100):
rdagc = RandomDAGContext(np.random.default_rng(seed=i),
axis_len=5, use_numpy=True)

dag = make_random_dag(rdagc)

# Make sure the PytatoKeyBuilder can handle 'dag'
pd[dag] = 42

# Make sure that the key stays the same within the same Python invocation
with pytest.raises(ReadOnlyEntryError):
pd[dag] = 42

# Make sure that the key stays the same across Python invocations
run_test_with_new_python_invocation(
_test_persistent_hashing_and_persistent_dict_stage2, tmpdir)
finally:
shutil.rmtree(tmpdir)


def _test_persistent_hashing_and_persistent_dict_stage2(tmpdir) -> None:
from pytools.persistent_dict import ReadOnlyEntryError, WriteOncePersistentDict

from pytato.analysis import PytatoKeyBuilder
pkb = PytatoKeyBuilder()

pd = WriteOncePersistentDict("test_persistent_dict",
key_builder=pkb,
container_dir=tmpdir,
safe_sync=False)

for i in range(100):
rdagc = RandomDAGContext(np.random.default_rng(seed=i),
axis_len=5, use_numpy=True)

dag = make_random_dag(rdagc)

with pytest.raises(ReadOnlyEntryError):
pd[dag] = 42

# }}}


def test_numpy_type_promotion_with_pytato_arrays():
class NotReallyAnArray:
@property
Expand Down Expand Up @@ -1427,7 +1520,10 @@ def test_pickling_hash():


if __name__ == "__main__":
if len(sys.argv) > 1:
import os
if "INVOCATION_INFO" in os.environ:
run_test_with_new_python_invocation_inner()
elif len(sys.argv) > 1:
exec(sys.argv[1])
else:
from pytest import main
Expand Down
Loading