Skip to content

Commit

Permalink
Consolidate caching tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JDBetteridge committed Oct 7, 2024
1 parent 5706ddd commit 81d0f0a
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 203 deletions.
224 changes: 206 additions & 18 deletions test/unit/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,30 @@
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
# OF THE POSSIBILITY OF SUCH DAMAGE.


import ctypes
import os
import pytest
import tempfile
import numpy
from pyop2 import op2, mpi
from pyop2.caching import DEFAULT_CACHE, memory_and_disk_cache, clear_memory_cache
from itertools import chain
from textwrap import dedent
from pyop2 import op2
from pyop2.caching import (
DEFAULT_CACHE,
disk_only_cache,
memory_cache,
memory_and_disk_cache,
clear_memory_cache
)
from pyop2.compilation import load
from pyop2.mpi import (
MPI,
COMM_WORLD,
COMM_SELF,
comm_cache_keyval,
internal_comm,
temp_internal_comm
)


def _seed():
Expand Down Expand Up @@ -75,7 +92,7 @@ def dindset2(indset):

@pytest.fixture
def g():
return op2.Global(1, 0, numpy.uint32, "g", comm=mpi.COMM_WORLD)
return op2.Global(1, 0, numpy.uint32, "g", comm=COMM_WORLD)


@pytest.fixture
Expand Down Expand Up @@ -286,11 +303,11 @@ class TestGeneratedCodeCache:

@property
def cache(self):
int_comm = mpi.internal_comm(mpi.COMM_WORLD, self)
_cache_collection = int_comm.Get_attr(mpi.comm_cache_keyval)
int_comm = internal_comm(COMM_WORLD, self)
_cache_collection = int_comm.Get_attr(comm_cache_keyval)
if _cache_collection is None:
_cache_collection = {default_cache_name: DEFAULT_CACHE()}
int_comm.Set_attr(mpi.comm_cache_keyval, _cache_collection)
int_comm.Set_attr(comm_cache_keyval, _cache_collection)
return _cache_collection[default_cache_name]

@pytest.fixture
Expand Down Expand Up @@ -455,7 +472,7 @@ def test_change_dat_dtype_matters(self, iterset, diterset):
assert len(self.cache) == 2

def test_change_global_dtype_matters(self, iterset, diterset):
g = op2.Global(1, 0, dtype=numpy.uint32, comm=mpi.COMM_WORLD)
g = op2.Global(1, 0, dtype=numpy.uint32, comm=COMM_WORLD)
self.cache.clear()
assert len(self.cache) == 0

Expand All @@ -465,7 +482,7 @@ def test_change_global_dtype_matters(self, iterset, diterset):

assert len(self.cache) == 1

g = op2.Global(1, 0, dtype=numpy.float64, comm=mpi.COMM_WORLD)
g = op2.Global(1, 0, dtype=numpy.float64, comm=COMM_WORLD)
op2.par_loop(k, iterset, g(op2.INC))

assert len(self.cache) == 2
Expand Down Expand Up @@ -541,9 +558,9 @@ def myfunc(arg, comm):
def comm(self):
"""This fixture provides a temporary comm so that each test gets it's own
communicator and that caches are cleaned on free."""
temporary_comm = mpi.COMM_WORLD.Dup()
temporary_comm = COMM_WORLD.Dup()
temporary_comm.name = "pytest temp COMM_WORLD"
with mpi.temp_internal_comm(temporary_comm) as comm:
with temp_internal_comm(temporary_comm) as comm:
yield comm
temporary_comm.Free()

Expand All @@ -557,7 +574,7 @@ def test_decorator_in_memory_cache_reuses_results(self, cachedir, comm):
)(self.myfunc)

obj1 = decorated_func("input1", comm=comm)
mem_cache = comm.Get_attr(mpi.comm_cache_keyval)[default_cache_name]
mem_cache = comm.Get_attr(comm_cache_keyval)[default_cache_name]
assert len(mem_cache) == 1
assert len(os.listdir(cachedir.name)) == 1

Expand All @@ -571,22 +588,22 @@ def test_decorator_uses_different_in_memory_caches_on_different_comms(self, cach
cachedir=cachedir.name
)(self.myfunc)

temporary_comm = mpi.COMM_SELF.Dup()
temporary_comm = COMM_SELF.Dup()
temporary_comm.name = "pytest temp COMM_SELF"
with mpi.temp_internal_comm(temporary_comm) as comm_self:
with temp_internal_comm(temporary_comm) as comm_self:
comm_self_func = memory_and_disk_cache(
cachedir=cachedir.name
)(self.myfunc)

# obj1 should be cached on the COMM_WORLD cache
obj1 = comm_world_func("input1", comm=comm)
comm_world_cache = comm.Get_attr(mpi.comm_cache_keyval)[default_cache_name]
comm_world_cache = comm.Get_attr(comm_cache_keyval)[default_cache_name]
assert len(comm_world_cache) == 1
assert len(os.listdir(cachedir.name)) == 1

# obj2 should be cached on the COMM_SELF cache
obj2 = comm_self_func("input1", comm=comm_self)
comm_self_cache = comm_self.Get_attr(mpi.comm_cache_keyval)[default_cache_name]
comm_self_cache = comm_self.Get_attr(comm_cache_keyval)[default_cache_name]
assert obj1 == obj2 and obj1 is not obj2
assert len(comm_world_cache) == 1
assert len(comm_self_cache) == 1
Expand All @@ -600,7 +617,7 @@ def test_decorator_disk_cache_reuses_results(self, cachedir, comm):
obj1 = decorated_func("input1", comm=comm)
clear_memory_cache(comm)
obj2 = decorated_func("input1", comm=comm)
mem_cache = comm.Get_attr(mpi.comm_cache_keyval)[default_cache_name]
mem_cache = comm.Get_attr(comm_cache_keyval)[default_cache_name]
assert obj1 == obj2 and obj1 is not obj2
assert len(mem_cache) == 1
assert len(os.listdir(cachedir.name)) == 1
Expand All @@ -610,11 +627,182 @@ def test_decorator_cache_misses(self, cachedir, comm):

obj1 = decorated_func("input1", comm=comm)
obj2 = decorated_func("input2", comm=comm)
mem_cache = comm.Get_attr(mpi.comm_cache_keyval)[default_cache_name]
mem_cache = comm.Get_attr(comm_cache_keyval)[default_cache_name]
assert obj1 != obj2
assert len(mem_cache) == 2
assert len(os.listdir(cachedir.name)) == 2


# Test updated caching functionality
class StateIncrement:
"""Simple class for keeping track of the number of times executed
"""
def __init__(self):
self._count = 0

def __call__(self):
self._count += 1
return self._count

@property
def value(self):
return self._count


def twople(x):
return (x, )*2


def threeple(x):
return (x, )*3


def n_comms(n):
return [MPI.COMM_WORLD]*n


def n_ops(n):
return [MPI.SUM]*n


# decorator = parallel_memory_only_cache, parallel_memory_only_cache_no_broadcast, disk_only_cached
def function_factory(state, decorator, f, **kwargs):
def custom_function(x, comm=COMM_WORLD):
state()
return f(x)

return decorator(**kwargs)(custom_function)


@pytest.fixture
def state():
return StateIncrement()


@pytest.mark.parametrize("decorator, uncached_function", [
(memory_cache, twople),
(memory_cache, n_comms),
(memory_and_disk_cache, twople),
(disk_only_cache, twople)
])
def test_function_args_twice_caches(request, state, decorator, uncached_function, tmpdir):
if request.node.callspec.params["decorator"] in {disk_only_cache, memory_and_disk_cache}:
kwargs = {"cachedir": tmpdir}
else:
kwargs = {}

cached_function = function_factory(state, decorator, uncached_function, **kwargs)
assert state.value == 0
first = cached_function(2, comm=COMM_WORLD)
assert first == uncached_function(2)
assert state.value == 1
second = cached_function(2, comm=COMM_WORLD)
assert second == uncached_function(2)
if request.node.callspec.params["decorator"] is not disk_only_cache:
assert second is first
assert state.value == 1

clear_memory_cache(COMM_WORLD)


@pytest.mark.parametrize("decorator, uncached_function", [
(memory_cache, twople),
(memory_cache, n_comms),
(memory_and_disk_cache, twople),
(disk_only_cache, twople)
])
def test_function_args_different(request, state, decorator, uncached_function, tmpdir):
if request.node.callspec.params["decorator"] in {disk_only_cache, memory_and_disk_cache}:
kwargs = {"cachedir": tmpdir}
else:
kwargs = {}

cached_function = function_factory(state, decorator, uncached_function, **kwargs)
assert state.value == 0
first = cached_function(2, comm=COMM_WORLD)
assert first == uncached_function(2)
assert state.value == 1
second = cached_function(3, comm=COMM_WORLD)
assert second == uncached_function(3)
assert state.value == 2

clear_memory_cache(COMM_WORLD)


@pytest.mark.parallel(nprocs=3)
@pytest.mark.parametrize("decorator, uncached_function", [
(memory_cache, twople),
(memory_cache, n_comms),
(memory_and_disk_cache, twople),
(disk_only_cache, twople)
])
def test_function_over_different_comms(request, state, decorator, uncached_function, tmpdir):
if request.node.callspec.params["decorator"] in {disk_only_cache, memory_and_disk_cache}:
# In parallel different ranks can get different tempdirs, we just want one
tmpdir = COMM_WORLD.bcast(tmpdir, root=0)
kwargs = {"cachedir": tmpdir}
else:
kwargs = {}

cached_function = function_factory(state, decorator, uncached_function, **kwargs)
assert state.value == 0

for ii in range(10):
color = 0 if COMM_WORLD.rank < 2 else MPI.UNDEFINED
comm12 = COMM_WORLD.Split(color=color)
if COMM_WORLD.rank < 2:
_ = cached_function(2, comm=comm12)
comm12.Free()

color = 0 if COMM_WORLD.rank > 0 else MPI.UNDEFINED
comm23 = COMM_WORLD.Split(color=color)
if COMM_WORLD.rank > 0:
_ = cached_function(2, comm=comm23)
comm23.Free()

clear_memory_cache(COMM_WORLD)


# pyop2/compilation.py uses a custom cache which we test here
@pytest.mark.parallel(nprocs=2)
def test_writing_large_so():
# This test exercises the compilation caching when handling larger files
if COMM_WORLD.rank == 0:
preamble = dedent("""\
#include <stdio.h>\n
void big(double *result){
""")
variables = (f"v{next(tempfile._get_candidate_names())}" for _ in range(128*1024))
lines = (f" double {v} = {hash(v)/1000000000};\n *result += {v};\n" for v in variables)
program = "\n".join(chain.from_iterable(((preamble, ), lines, ("}\n", ))))
with open("big.c", "w") as fh:
fh.write(program)

COMM_WORLD.Barrier()
with open("big.c", "r") as fh:
program = fh.read()

if COMM_WORLD.rank == 1:
os.remove("big.c")

fn = load(program, "c", "big", argtypes=(ctypes.c_voidp,), comm=COMM_WORLD)
assert fn is not None


@pytest.mark.parallel(nprocs=2)
def test_two_comms_compile_the_same_code():
new_comm = COMM_WORLD.Split(color=COMM_WORLD.rank)
new_comm.name = "test_two_comms"
code = dedent("""\
#include <stdio.h>\n
void noop(){
printf("Do nothing!\\n");
}
""")

fn = load(code, "c", "noop", argtypes=(), comm=COMM_WORLD)
assert fn is not None


if __name__ == '__main__':
pytest.main(os.path.abspath(__file__))
Loading

0 comments on commit 81d0f0a

Please sign in to comment.