Caching improvements #3989

Open
wants to merge 27 commits into base: master
Changes from 22 commits

Commits (27)
8d413cc
Caching improvements
connorjward Jan 23, 2025
5ed1a91
always spit out log files
connorjward Jan 23, 2025
8deaead
Use a timeout method that preserves more information
connorjward Jan 24, 2025
4d07082
more cleaning
connorjward Jan 24, 2025
8631b8c
fixupgs
connorjward Jan 24, 2025
d297a7e
fixup
connorjward Jan 24, 2025
13168c0
Try and avoid race conditions
connorjward Jan 28, 2025
fb57af4
Use strict SPMD behaviour to try and track these down. Also change be…
connorjward Jan 28, 2025
84a8ad7
Refactor parallel_cache decorator
connorjward Jan 28, 2025
b8235f8
Apply suggestions from code review
connorjward Jan 28, 2025
6a869ae
experimenting
connorjward Jan 28, 2025
37ae507
fixup
connorjward Jan 28, 2025
845b01c
debugging
connorjward Jan 29, 2025
ed4ef2b
-s to build
connorjward Jan 29, 2025
318cfd4
more print
connorjward Jan 29, 2025
4c5bbe4
improvements, hopefully fixed?
connorjward Jan 29, 2025
2dd2da6
avoid race conditions, does this fix things?
connorjward Jan 29, 2025
1aed273
Merge branch 'master' into connorjward/more-cache-fixes
connorjward Jan 29, 2025
ef5d03b
Merge remote-tracking branch 'origin/master' into connorjward/more-ca…
connorjward Jan 29, 2025
5c1b9f5
Add extra SPMD_STRICT check
connorjward Jan 30, 2025
1da4d5c
Is is not fixed?
connorjward Jan 30, 2025
81e8cb4
Fix bad hashing
connorjward Jan 31, 2025
22054ed
Point to FIAT branch
connorjward Jan 31, 2025
6592cdf
linting
connorjward Jan 31, 2025
2507977
Fix hashing for interface kwarg
connorjward Feb 6, 2025
e914120
Fix caching for patch
connorjward Feb 7, 2025
60c9c0e
Apply suggestions from code review
connorjward Feb 7, 2025
3 changes: 2 additions & 1 deletion .github/workflows/build.yml
@@ -48,6 +48,7 @@ jobs:
COMPLEX: ${{ matrix.complex }}
RDMAV_FORK_SAFE: 1
EXTRA_PYTEST_ARGS: --splitting-algorithm least_duration --timeout=1800 --timeout-method=thread -o faulthandler_timeout=1860 tests/firedrake
PYOP2_SPMD_STRICT: 1
steps:
- uses: actions/checkout@v4

@@ -95,7 +96,7 @@ jobs:
sudo apt update
sudo apt -y install parallel
. ../firedrake_venv/bin/activate
python "$(which firedrake-clean)"
firedrake-clean
python -m pip install pytest-timeout ipympl pytest-split pytest-xdist
python -m pip list

4 changes: 3 additions & 1 deletion .github/workflows/pip-mac.yml
@@ -24,6 +24,7 @@ jobs:
if: ${{ (github.ref == 'refs/heads/master') || contains(github.event.pull_request.labels.*.name, 'macOS') }}
env:
FIREDRAKE_CI_TESTS: 1
PYOP2_SPMD_STRICT: 1
OMP_NUM_THREADS: 1
OPENBLAS_NUM_THREADS: 1
steps:
@@ -98,8 +99,9 @@ jobs:
- name: Run Firedrake smoke tests
run: |
source pip_venv/bin/activate
firedrake-clean
cd pip_venv/src/firedrake
make check CHECK_PYTEST_ARGS="--timeout 60"
make check CHECK_PYTEST_ARGS="--timeout 60 --timeout-method=thread"
timeout-minutes: 10

- name: Cleanup (post)
4 changes: 3 additions & 1 deletion .github/workflows/pip.yml
@@ -38,6 +38,7 @@ jobs:
# PETSC_DIR, HDF5_DIR and MPICH_DIR are set inside the docker image
FIREDRAKE_CI_TESTS: 1
PYOP2_CI_TESTS: 1
PYOP2_SPMD_STRICT: 1
PETSC_ARCH: ${{ matrix.petsc_arch }}
OMP_NUM_THREADS: 1
OPENBLAS_NUM_THREADS: 1
@@ -82,8 +83,9 @@ jobs:
- name: Run Firedrake smoke tests
run: |
source pip_venv/bin/activate
firedrake-clean
cd pip_venv/src/firedrake
make check CHECK_PYTEST_ARGS="--timeout 60"
make check CHECK_PYTEST_ARGS="--timeout 60 --timeout-method=thread"
timeout-minutes: 10

- name: Publish Test Report
5 changes: 4 additions & 1 deletion firedrake/functionspaceimpl.py
@@ -876,7 +876,7 @@ class RestrictedFunctionSpace(FunctionSpace):
def __init__(self, function_space, boundary_set=frozenset(), name=None):
label = ""
boundary_set_ = []
for boundary_domain in boundary_set:
for boundary_domain in sorted(boundary_set, key=str):
if isinstance(boundary_domain, str):
boundary_set_.append(boundary_domain)
else:
@@ -885,6 +885,9 @@ def __init__(self, function_space, boundary_set=frozenset(), name=None):
bd, = as_tuple(boundary_domain)
boundary_set_.append(bd)
boundary_set = boundary_set_

# NOTE: boundary_set must be deterministically ordered here to ensure
# consistency between ranks
for boundary_domain in boundary_set:
label += str(boundary_domain)
label += "_"
2 changes: 1 addition & 1 deletion firedrake/slate/slac/compiler.py
@@ -90,7 +90,7 @@ def _compile_expression_comm(*args, **kwargs):

@memory_and_disk_cache(
hashkey=_compile_expression_hashkey,
comm_fetcher=_compile_expression_comm,
comm_getter=_compile_expression_comm,
cachedir=tsfc_interface._cachedir
)
@PETSc.Log.EventDecorator()
34 changes: 18 additions & 16 deletions firedrake/tsfc_interface.py
@@ -7,7 +7,7 @@
from os import path, environ, getuid, makedirs
import tempfile
import collections
import cachetools

Check failure on line 10 in firedrake/tsfc_interface.py (GitHub Actions / Run linter): firedrake/tsfc_interface.py:10:1: F401 'cachetools' imported but unused

import ufl
import finat.ufl
@@ -54,19 +54,23 @@


def tsfc_compile_form_hashkey(form, prefix, parameters, interface, diagonal):
# Drop prefix as it's only used for naming
return default_parallel_hashkey(form.signature(), prefix, parameters, interface, diagonal)
return default_parallel_hashkey(
form.signature(),
prefix,
utils.tuplify(parameters),
type(interface).__name__,
diagonal,
)
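
The new hashkey runs parameters through `utils.tuplify` so that dict-valued compilation parameters become hashable, and hashes the interface by class name rather than by object identity. A rough sketch of the kind of helper `tuplify` is assumed to be (the real `utils.tuplify` may differ):

    def tuplify_sketch(params):
        # Hypothetical stand-in for utils.tuplify: recursively convert
        # dicts and lists to tuples so the value can appear in a cache key.
        if isinstance(params, dict):
            return tuple((k, tuplify_sketch(v)) for k, v in sorted(params.items()))
        if isinstance(params, (list, tuple)):
            return tuple(tuplify_sketch(v) for v in params)
        return params

    tuplify_sketch({"mode": "spectral", "opts": {"unroll": True}})
    # -> (('mode', 'spectral'), ('opts', (('unroll', True),)))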


def _compile_form_comm(*args, **kwargs):
# args[0] is a form
return args[0].ufl_domains()[0].comm
def _compile_form_comm(form, *args, **kwargs):
return form.ufl_domains()[0].comm


# Decorate the original tsfc.compile_form with a cache
tsfc_compile_form = memory_and_disk_cache(
hashkey=tsfc_compile_form_hashkey,
comm_fetcher=_compile_form_comm,
comm_getter=_compile_form_comm,
cachedir=_cachedir
)(original_tsfc_compile_form)

@@ -133,23 +137,21 @@
SplitKernel = collections.namedtuple("SplitKernel", ["indices", "kinfo"])


def _compile_form_hashkey(*args, **kwargs):
    # form, name, parameters, split, diagonal
    parameters = kwargs.pop("parameters", None)
    key = cachetools.keys.hashkey(
        args[0].signature(),
        *args[1:],
        utils.tuplify(parameters),
        **kwargs
    )
    kwargs.setdefault("parameters", parameters)
    return key
def _compile_form_hashkey(form, name, parameters=None, split=True, interface=None, diagonal=False):
    return (
        form.signature(),
        name,
        utils.tuplify(parameters),
        split,
        type(interface).__name__,
        diagonal,
    )
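
The move from `*args, **kwargs` to an explicit signature matters because a generic hashkey treats positional and keyword spellings of the same call as different keys; this is plausibly the issue behind the "Fix hashing for interface kwarg" commit. A sketch of the failure mode, using hypothetical argument values:

    from cachetools.keys import hashkey

    # Old style: the key depends on how the caller spells the call.
    assert hashkey("sig", "name", True) != hashkey("sig", "name", split=True)

    # New style: binding to named parameters first normalises both spellings.
    def new_style_key(signature, name, parameters=None, split=True,
                      interface=None, diagonal=False):
        return (signature, name, split, type(interface).__name__, diagonal)

    assert new_style_key("sig", "name", None, False) == \
        new_style_key("sig", "name", split=False)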


@PETSc.Log.EventDecorator()
@memory_and_disk_cache(
hashkey=_compile_form_hashkey,
comm_fetcher=_compile_form_comm,
comm_getter=_compile_form_comm,
cachedir=_cachedir
)
@PETSc.Log.EventDecorator()
191 changes: 81 additions & 110 deletions pyop2/caching.py
@@ -52,6 +52,7 @@
from pyop2.mpi import (
MPI, COMM_WORLD, comm_cache_keyval, temp_internal_comm
)
import pytools
from petsc4py import PETSc


@@ -365,7 +366,7 @@ def write(self, filehandle, value):
pickle.dump(value, filehandle)


def default_comm_fetcher(*args, **kwargs):
def default_comm_getter(*args, **kwargs):
""" A sensible default comm fetcher for use with `parallel_cache`.
"""
comms = filter(
@@ -440,119 +441,89 @@ class DEFAULT_CACHE(dict):
DictLikeDiskAccess = instrument(DictLikeDiskAccess)


if configuration["spmd_strict"]:
def parallel_cache(
hashkey=default_parallel_hashkey,
comm_fetcher=default_comm_fetcher,
cache_factory=lambda: DEFAULT_CACHE(),
):
"""Parallel cache decorator (SPMD strict-enabled).
"""
def decorator(func):
@PETSc.Log.EventDecorator("PyOP2 Cache Wrapper")
@wraps(func)
def wrapper(*args, **kwargs):
""" Extract the key and then try the memory cache before falling back
on calling the function and populating the cache. SPMD strict ensures
that all ranks cache hit or miss to ensure that the function evaluation
always occurs in parallel.
"""
k = hashkey(*args, **kwargs)
key = _as_hexdigest(*k), func.__qualname__
# Create a PyOP2 comm associated with the key, so it is decrefed when the wrapper exits
with temp_internal_comm(comm_fetcher(*args, **kwargs)) as comm:
# Fetch the per-comm cache_collection or set it up if not present
# A collection is required since different types of cache can be set up on the same comm
cache_collection = comm.Get_attr(comm_cache_keyval)
if cache_collection is None:
cache_collection = {}
comm.Set_attr(comm_cache_keyval, cache_collection)
# If this kind of cache is already present on the
# cache_collection, get it, otherwise create it
local_cache = cache_collection.setdefault(
(cf := cache_factory()).__class__.__name__,
cf
)
local_cache = cache_collection[cf.__class__.__name__]

# If this is a new cache or function add it to the list of known caches
if (comm, comm.name, func, local_cache) not in [(c.comm, c.comm_name, c.func, c.cache()) for c in _KNOWN_CACHES]:
# When a comm is freed we do not hold a reference to the cache.
# We attach a finalizer that extracts the stats before the cache
# is deleted.
_KNOWN_CACHES.append(_CacheRecord(next(_CACHE_CIDX), comm, func, local_cache))

# Grab value from all ranks cache and broadcast cache hit/miss
value = local_cache.get(key, CACHE_MISS)
debug_string = f"{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: "
debug_string += f"key={k} in cache: {local_cache.__class__.__name__} cache "
if value is CACHE_MISS:
debug(debug_string + "miss")
cache_hit = False
else:
debug(debug_string + "hit")
cache_hit = True
all_present = comm.allgather(cache_hit)

# If not present in the cache of all ranks we force re-evaluation on all ranks
if not min(all_present):
value = CACHE_MISS
def parallel_cache(
hashkey=default_parallel_hashkey,
comm_getter=default_comm_getter,
cache_factory=lambda: DEFAULT_CACHE(),
bcast=False,
):
"""Parallel cache decorator.

Parameters
----------
hashkey :
Callable taking ``*args`` and ``**kwargs`` and returning a hash.
comm_getter :
Callable taking ``*args`` and ``**kwargs`` and returning the
appropriate communicator.
cache_factory :
Callable that will build a new cache (if one does not exist).
bcast :
If `True`, then generate the new cache value on one rank and broadcast
to the others. If `False` then values are generated on all ranks.
This option can only be `True` if the operation can be executed in
serial; else it will deadlock.

"""
def decorator(func):
@PETSc.Log.EventDecorator("pyop2: cache wrapper")
@wraps(func)
def wrapper(*args, **kwargs):
# Extract the key and then try the memory cache before falling back
# to calling the function and populating the cache.
k = hashkey(*args, **kwargs)
key = _as_hexdigest(*k), func.__qualname__

# Create a PyOP2 comm associated with the key, so it is decrefed when the wrapper exits
with temp_internal_comm(comm_getter(*args, **kwargs)) as comm:
if configuration["spmd_strict"] and not pytools.is_single_valued(comm.allgather(key)):
raise ValueError("Cache keys differ between ranks")

# Fetch the per-comm cache_collection or set it up if not present
# A collection is required since different types of cache can be set up on the same comm
cache_collection = comm.Get_attr(comm_cache_keyval)
if cache_collection is None:
cache_collection = {}
comm.Set_attr(comm_cache_keyval, cache_collection)
# If this kind of cache is already present on the
# cache_collection, get it, otherwise create it
local_cache = cache_collection.setdefault(
(cf := cache_factory()).__class__.__name__,
cf
)
local_cache = cache_collection[cf.__class__.__name__]

# If this is a new cache or function add it to the list of known caches
if (comm, comm.name, func, local_cache) not in [(c.comm, c.comm_name, c.func, c.cache()) for c in _KNOWN_CACHES]:
# When a comm is freed we do not hold a reference to the cache.
# We attach a finalizer that extracts the stats before the cache
# is deleted.
_KNOWN_CACHES.append(_CacheRecord(next(_CACHE_CIDX), comm, func, local_cache))

# Grab value from all ranks cache and broadcast cache hit/miss
value = local_cache.get(key, CACHE_MISS)
debug_string = f"{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: "
debug_string += f"key={k} in cache: {local_cache.__class__.__name__} cache "
if value is CACHE_MISS:
debug(debug_string + "miss")
cache_hit = False
else:
debug(debug_string + "hit")
cache_hit = True

if configuration["spmd_strict"] and not pytools.is_single_valued(comm.allgather(cache_hit)):
raise ValueError("Cache hit on some ranks but missed on others")

if value is CACHE_MISS:
if bcast:
value = func(*args, **kwargs) if comm.rank == 0 else None
value = comm.bcast(value, root=0)
else:
value = func(*args, **kwargs)
return local_cache.setdefault(key, value)

return wrapper
return decorator
else:
def parallel_cache(
hashkey=default_parallel_hashkey,
comm_fetcher=default_comm_fetcher,
cache_factory=lambda: DEFAULT_CACHE(),
):
"""Parallel cache decorator.
"""
def decorator(func):
@PETSc.Log.EventDecorator("PyOP2 Cache Wrapper")
@wraps(func)
def wrapper(*args, **kwargs):
""" Extract the key and then try the memory cache before falling back
on calling the function and populating the cache.
"""
k = hashkey(*args, **kwargs)
key = _as_hexdigest(*k), func.__qualname__
# Create a PyOP2 comm associated with the key, so it is decrefed when the wrapper exits
with temp_internal_comm(comm_fetcher(*args, **kwargs)) as comm:
# Fetch the per-comm cache_collection or set it up if not present
# A collection is required since different types of cache can be set up on the same comm
cache_collection = comm.Get_attr(comm_cache_keyval)
if cache_collection is None:
cache_collection = {}
comm.Set_attr(comm_cache_keyval, cache_collection)
# If this kind of cache is already present on the
# cache_collection, get it, otherwise create it
local_cache = cache_collection.setdefault(
(cf := cache_factory()).__class__.__name__,
cf
)
local_cache = cache_collection[cf.__class__.__name__]

# If this is a new cache or function add it to the list of known caches
if (comm, comm.name, func, local_cache) not in [(c.comm, c.comm_name, c.func, c.cache()) for c in _KNOWN_CACHES]:
# When a comm is freed we do not hold a reference to the cache.
# We attach a finalizer that extracts the stats before the cache
# is deleted.
_KNOWN_CACHES.append(_CacheRecord(next(_CACHE_CIDX), comm, func, local_cache))

value = local_cache.get(key, CACHE_MISS)

if value is CACHE_MISS:
with PETSc.Log.Event("pyop2: handle cache miss"):
value = func(*args, **kwargs)
return local_cache.setdefault(key, value)

return wrapper
return decorator
return local_cache.setdefault(key, value)
return wrapper
return decorator
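
For concreteness, a hedged usage sketch of the refactored decorator; `expensive_compile`, its arguments and the helper it calls are hypothetical, only the decorator parameters come from the definition above:

    @parallel_cache(
        hashkey=lambda source, comm: (source,),   # must be identical on every rank
        comm_getter=lambda source, comm: comm,    # comm on which the cache lives
        bcast=True,  # evaluate on rank 0 only, then broadcast the value
    )
    def expensive_compile(source, comm):
        # Safe with bcast=True only because nothing in here is collective.
        return compile_to_object_code(source)  # hypothetical helper

With `bcast=False` (the default) every rank calls the function itself, and with PYOP2_SPMD_STRICT=1 the wrapper additionally allgathers the key and the hit/miss flag to catch ranks that diverge.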


def clear_memory_cache(comm):