diff --git a/test/unit/test_caching.py b/test/unit/test_caching.py index e335ec680..1298991b3 100644 --- a/test/unit/test_caching.py +++ b/test/unit/test_caching.py @@ -31,13 +31,30 @@ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED # OF THE POSSIBILITY OF SUCH DAMAGE. - +import ctypes import os import pytest import tempfile import numpy -from pyop2 import op2, mpi -from pyop2.caching import DEFAULT_CACHE, memory_and_disk_cache, clear_memory_cache +from itertools import chain +from textwrap import dedent +from pyop2 import op2 +from pyop2.caching import ( + DEFAULT_CACHE, + disk_only_cache, + memory_cache, + memory_and_disk_cache, + clear_memory_cache +) +from pyop2.compilation import load +from pyop2.mpi import ( + MPI, + COMM_WORLD, + COMM_SELF, + comm_cache_keyval, + internal_comm, + temp_internal_comm +) def _seed(): @@ -75,7 +92,7 @@ def dindset2(indset): @pytest.fixture def g(): - return op2.Global(1, 0, numpy.uint32, "g", comm=mpi.COMM_WORLD) + return op2.Global(1, 0, numpy.uint32, "g", comm=COMM_WORLD) @pytest.fixture @@ -286,11 +303,11 @@ class TestGeneratedCodeCache: @property def cache(self): - int_comm = mpi.internal_comm(mpi.COMM_WORLD, self) - _cache_collection = int_comm.Get_attr(mpi.comm_cache_keyval) + int_comm = internal_comm(COMM_WORLD, self) + _cache_collection = int_comm.Get_attr(comm_cache_keyval) if _cache_collection is None: _cache_collection = {default_cache_name: DEFAULT_CACHE()} - int_comm.Set_attr(mpi.comm_cache_keyval, _cache_collection) + int_comm.Set_attr(comm_cache_keyval, _cache_collection) return _cache_collection[default_cache_name] @pytest.fixture @@ -455,7 +472,7 @@ def test_change_dat_dtype_matters(self, iterset, diterset): assert len(self.cache) == 2 def test_change_global_dtype_matters(self, iterset, diterset): - g = op2.Global(1, 0, dtype=numpy.uint32, comm=mpi.COMM_WORLD) + g = op2.Global(1, 0, dtype=numpy.uint32, comm=COMM_WORLD) self.cache.clear() assert len(self.cache) == 0 @@ -465,7 +482,7 @@ def test_change_global_dtype_matters(self, iterset, diterset): assert len(self.cache) == 1 - g = op2.Global(1, 0, dtype=numpy.float64, comm=mpi.COMM_WORLD) + g = op2.Global(1, 0, dtype=numpy.float64, comm=COMM_WORLD) op2.par_loop(k, iterset, g(op2.INC)) assert len(self.cache) == 2 @@ -541,9 +558,9 @@ def myfunc(arg, comm): def comm(self): """This fixture provides a temporary comm so that each test gets it's own communicator and that caches are cleaned on free.""" - temporary_comm = mpi.COMM_WORLD.Dup() + temporary_comm = COMM_WORLD.Dup() temporary_comm.name = "pytest temp COMM_WORLD" - with mpi.temp_internal_comm(temporary_comm) as comm: + with temp_internal_comm(temporary_comm) as comm: yield comm temporary_comm.Free() @@ -557,7 +574,7 @@ def test_decorator_in_memory_cache_reuses_results(self, cachedir, comm): )(self.myfunc) obj1 = decorated_func("input1", comm=comm) - mem_cache = comm.Get_attr(mpi.comm_cache_keyval)[default_cache_name] + mem_cache = comm.Get_attr(comm_cache_keyval)[default_cache_name] assert len(mem_cache) == 1 assert len(os.listdir(cachedir.name)) == 1 @@ -571,22 +588,22 @@ def test_decorator_uses_different_in_memory_caches_on_different_comms(self, cach cachedir=cachedir.name )(self.myfunc) - temporary_comm = mpi.COMM_SELF.Dup() + temporary_comm = COMM_SELF.Dup() temporary_comm.name = "pytest temp COMM_SELF" - with mpi.temp_internal_comm(temporary_comm) as comm_self: + with temp_internal_comm(temporary_comm) as comm_self: comm_self_func = memory_and_disk_cache( cachedir=cachedir.name )(self.myfunc) # obj1 should be cached on the COMM_WORLD cache obj1 = comm_world_func("input1", comm=comm) - comm_world_cache = comm.Get_attr(mpi.comm_cache_keyval)[default_cache_name] + comm_world_cache = comm.Get_attr(comm_cache_keyval)[default_cache_name] assert len(comm_world_cache) == 1 assert len(os.listdir(cachedir.name)) == 1 # obj2 should be cached on the COMM_SELF cache obj2 = comm_self_func("input1", comm=comm_self) - comm_self_cache = comm_self.Get_attr(mpi.comm_cache_keyval)[default_cache_name] + comm_self_cache = comm_self.Get_attr(comm_cache_keyval)[default_cache_name] assert obj1 == obj2 and obj1 is not obj2 assert len(comm_world_cache) == 1 assert len(comm_self_cache) == 1 @@ -600,7 +617,7 @@ def test_decorator_disk_cache_reuses_results(self, cachedir, comm): obj1 = decorated_func("input1", comm=comm) clear_memory_cache(comm) obj2 = decorated_func("input1", comm=comm) - mem_cache = comm.Get_attr(mpi.comm_cache_keyval)[default_cache_name] + mem_cache = comm.Get_attr(comm_cache_keyval)[default_cache_name] assert obj1 == obj2 and obj1 is not obj2 assert len(mem_cache) == 1 assert len(os.listdir(cachedir.name)) == 1 @@ -610,11 +627,182 @@ def test_decorator_cache_misses(self, cachedir, comm): obj1 = decorated_func("input1", comm=comm) obj2 = decorated_func("input2", comm=comm) - mem_cache = comm.Get_attr(mpi.comm_cache_keyval)[default_cache_name] + mem_cache = comm.Get_attr(comm_cache_keyval)[default_cache_name] assert obj1 != obj2 assert len(mem_cache) == 2 assert len(os.listdir(cachedir.name)) == 2 +# Test updated caching functionality +class StateIncrement: + """Simple class for keeping track of the number of times executed + """ + def __init__(self): + self._count = 0 + + def __call__(self): + self._count += 1 + return self._count + + @property + def value(self): + return self._count + + +def twople(x): + return (x, )*2 + + +def threeple(x): + return (x, )*3 + + +def n_comms(n): + return [MPI.COMM_WORLD]*n + + +def n_ops(n): + return [MPI.SUM]*n + + +# decorator = parallel_memory_only_cache, parallel_memory_only_cache_no_broadcast, disk_only_cached +def function_factory(state, decorator, f, **kwargs): + def custom_function(x, comm=COMM_WORLD): + state() + return f(x) + + return decorator(**kwargs)(custom_function) + + +@pytest.fixture +def state(): + return StateIncrement() + + +@pytest.mark.parametrize("decorator, uncached_function", [ + (memory_cache, twople), + (memory_cache, n_comms), + (memory_and_disk_cache, twople), + (disk_only_cache, twople) +]) +def test_function_args_twice_caches(request, state, decorator, uncached_function, tmpdir): + if request.node.callspec.params["decorator"] in {disk_only_cache, memory_and_disk_cache}: + kwargs = {"cachedir": tmpdir} + else: + kwargs = {} + + cached_function = function_factory(state, decorator, uncached_function, **kwargs) + assert state.value == 0 + first = cached_function(2, comm=COMM_WORLD) + assert first == uncached_function(2) + assert state.value == 1 + second = cached_function(2, comm=COMM_WORLD) + assert second == uncached_function(2) + if request.node.callspec.params["decorator"] is not disk_only_cache: + assert second is first + assert state.value == 1 + + clear_memory_cache(COMM_WORLD) + + +@pytest.mark.parametrize("decorator, uncached_function", [ + (memory_cache, twople), + (memory_cache, n_comms), + (memory_and_disk_cache, twople), + (disk_only_cache, twople) +]) +def test_function_args_different(request, state, decorator, uncached_function, tmpdir): + if request.node.callspec.params["decorator"] in {disk_only_cache, memory_and_disk_cache}: + kwargs = {"cachedir": tmpdir} + else: + kwargs = {} + + cached_function = function_factory(state, decorator, uncached_function, **kwargs) + assert state.value == 0 + first = cached_function(2, comm=COMM_WORLD) + assert first == uncached_function(2) + assert state.value == 1 + second = cached_function(3, comm=COMM_WORLD) + assert second == uncached_function(3) + assert state.value == 2 + + clear_memory_cache(COMM_WORLD) + + +@pytest.mark.parallel(nprocs=3) +@pytest.mark.parametrize("decorator, uncached_function", [ + (memory_cache, twople), + (memory_cache, n_comms), + (memory_and_disk_cache, twople), + (disk_only_cache, twople) +]) +def test_function_over_different_comms(request, state, decorator, uncached_function, tmpdir): + if request.node.callspec.params["decorator"] in {disk_only_cache, memory_and_disk_cache}: + # In parallel different ranks can get different tempdirs, we just want one + tmpdir = COMM_WORLD.bcast(tmpdir, root=0) + kwargs = {"cachedir": tmpdir} + else: + kwargs = {} + + cached_function = function_factory(state, decorator, uncached_function, **kwargs) + assert state.value == 0 + + for ii in range(10): + color = 0 if COMM_WORLD.rank < 2 else MPI.UNDEFINED + comm12 = COMM_WORLD.Split(color=color) + if COMM_WORLD.rank < 2: + _ = cached_function(2, comm=comm12) + comm12.Free() + + color = 0 if COMM_WORLD.rank > 0 else MPI.UNDEFINED + comm23 = COMM_WORLD.Split(color=color) + if COMM_WORLD.rank > 0: + _ = cached_function(2, comm=comm23) + comm23.Free() + + clear_memory_cache(COMM_WORLD) + + +# pyop2/compilation.py uses a custom cache which we test here +@pytest.mark.parallel(nprocs=2) +def test_writing_large_so(): + # This test exercises the compilation caching when handling larger files + if COMM_WORLD.rank == 0: + preamble = dedent("""\ + #include \n + void big(double *result){ + """) + variables = (f"v{next(tempfile._get_candidate_names())}" for _ in range(128*1024)) + lines = (f" double {v} = {hash(v)/1000000000};\n *result += {v};\n" for v in variables) + program = "\n".join(chain.from_iterable(((preamble, ), lines, ("}\n", )))) + with open("big.c", "w") as fh: + fh.write(program) + + COMM_WORLD.Barrier() + with open("big.c", "r") as fh: + program = fh.read() + + if COMM_WORLD.rank == 1: + os.remove("big.c") + + fn = load(program, "c", "big", argtypes=(ctypes.c_voidp,), comm=COMM_WORLD) + assert fn is not None + + +@pytest.mark.parallel(nprocs=2) +def test_two_comms_compile_the_same_code(): + new_comm = COMM_WORLD.Split(color=COMM_WORLD.rank) + new_comm.name = "test_two_comms" + code = dedent("""\ + #include \n + void noop(){ + printf("Do nothing!\\n"); + } + """) + + fn = load(code, "c", "noop", argtypes=(), comm=COMM_WORLD) + assert fn is not None + + if __name__ == '__main__': pytest.main(os.path.abspath(__file__)) diff --git a/test/unit/test_updated_caching.py b/test/unit/test_updated_caching.py deleted file mode 100644 index 1d9424b05..000000000 --- a/test/unit/test_updated_caching.py +++ /dev/null @@ -1,185 +0,0 @@ -import ctypes -import pytest -import os -import tempfile -from itertools import chain -from textwrap import dedent - -from pyop2.caching import ( - disk_only_cache, - memory_cache, - memory_and_disk_cache, - clear_memory_cache -) -from pyop2.compilation import load -from pyop2.mpi import MPI, COMM_WORLD - - -class StateIncrement: - """Simple class for keeping track of the number of times executed - """ - def __init__(self): - self._count = 0 - - def __call__(self): - self._count += 1 - return self._count - - @property - def value(self): - return self._count - - -def twople(x): - return (x, )*2 - - -def threeple(x): - return (x, )*3 - - -def n_comms(n): - return [MPI.COMM_WORLD]*n - - -def n_ops(n): - return [MPI.SUM]*n - - -# decorator = parallel_memory_only_cache, parallel_memory_only_cache_no_broadcast, disk_only_cached -def function_factory(state, decorator, f, **kwargs): - def custom_function(x, comm=COMM_WORLD): - state() - return f(x) - - return decorator(**kwargs)(custom_function) - - -@pytest.fixture -def state(): - return StateIncrement() - - -@pytest.mark.parametrize("decorator, uncached_function", [ - (memory_cache, twople), - (memory_cache, n_comms), - (memory_and_disk_cache, twople), - (disk_only_cache, twople) -]) -def test_function_args_twice_caches(request, state, decorator, uncached_function, tmpdir): - if request.node.callspec.params["decorator"] in {disk_only_cache, memory_and_disk_cache}: - kwargs = {"cachedir": tmpdir} - else: - kwargs = {} - - cached_function = function_factory(state, decorator, uncached_function, **kwargs) - assert state.value == 0 - first = cached_function(2, comm=COMM_WORLD) - assert first == uncached_function(2) - assert state.value == 1 - second = cached_function(2, comm=COMM_WORLD) - assert second == uncached_function(2) - if request.node.callspec.params["decorator"] is not disk_only_cache: - assert second is first - assert state.value == 1 - - clear_memory_cache(COMM_WORLD) - - -@pytest.mark.parametrize("decorator, uncached_function", [ - (memory_cache, twople), - (memory_cache, n_comms), - (memory_and_disk_cache, twople), - (disk_only_cache, twople) -]) -def test_function_args_different(request, state, decorator, uncached_function, tmpdir): - if request.node.callspec.params["decorator"] in {disk_only_cache, memory_and_disk_cache}: - kwargs = {"cachedir": tmpdir} - else: - kwargs = {} - - cached_function = function_factory(state, decorator, uncached_function, **kwargs) - assert state.value == 0 - first = cached_function(2, comm=COMM_WORLD) - assert first == uncached_function(2) - assert state.value == 1 - second = cached_function(3, comm=COMM_WORLD) - assert second == uncached_function(3) - assert state.value == 2 - - clear_memory_cache(COMM_WORLD) - - -@pytest.mark.parallel(nprocs=3) -@pytest.mark.parametrize("decorator, uncached_function", [ - (memory_cache, twople), - (memory_cache, n_comms), - (memory_and_disk_cache, twople), - (disk_only_cache, twople) -]) -def test_function_over_different_comms(request, state, decorator, uncached_function, tmpdir): - if request.node.callspec.params["decorator"] in {disk_only_cache, memory_and_disk_cache}: - # In parallel different ranks can get different tempdirs, we just want one - tmpdir = COMM_WORLD.bcast(tmpdir, root=0) - kwargs = {"cachedir": tmpdir} - else: - kwargs = {} - - cached_function = function_factory(state, decorator, uncached_function, **kwargs) - assert state.value == 0 - - for ii in range(10): - color = 0 if COMM_WORLD.rank < 2 else MPI.UNDEFINED - comm12 = COMM_WORLD.Split(color=color) - if COMM_WORLD.rank < 2: - _ = cached_function(2, comm=comm12) - comm12.Free() - - color = 0 if COMM_WORLD.rank > 0 else MPI.UNDEFINED - comm23 = COMM_WORLD.Split(color=color) - if COMM_WORLD.rank > 0: - _ = cached_function(2, comm=comm23) - comm23.Free() - - clear_memory_cache(COMM_WORLD) - - -# pyop2/compilation.py uses a custom cache which we test here -@pytest.mark.parallel(nprocs=2) -def test_writing_large_so(): - # This test exercises the compilation caching when handling larger files - if COMM_WORLD.rank == 0: - preamble = dedent("""\ - #include \n - void big(double *result){ - """) - variables = (f"v{next(tempfile._get_candidate_names())}" for _ in range(128*1024)) - lines = (f" double {v} = {hash(v)/1000000000};\n *result += {v};\n" for v in variables) - program = "\n".join(chain.from_iterable(((preamble, ), lines, ("}\n", )))) - with open("big.c", "w") as fh: - fh.write(program) - - COMM_WORLD.Barrier() - with open("big.c", "r") as fh: - program = fh.read() - - if COMM_WORLD.rank == 1: - os.remove("big.c") - - fn = load(program, "c", "big", argtypes=(ctypes.c_voidp,), comm=COMM_WORLD) - assert fn is not None - - -@pytest.mark.parallel(nprocs=2) -def test_two_comms_compile_the_same_code(): - new_comm = COMM_WORLD.Split(color=COMM_WORLD.rank) - new_comm.name = "test_two_comms" - code = dedent("""\ - #include \n - void noop(){ - printf("Do nothing!\\n"); - } - """) - - fn = load(code, "c", "noop", argtypes=(), comm=COMM_WORLD) - assert fn is not None