Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixed race conditioning in indices_boundary_masker #73

Merged
merged 5 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 30 additions & 15 deletions examples/cfd/flow_past_sphere_3d.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import xlb
from xlb.compute_backend import ComputeBackend
from xlb.precision_policy import PrecisionPolicy
from xlb.helper import create_nse_fields, initialize_eq
from xlb.helper import create_nse_fields, initialize_eq, check_bc_overlaps
from xlb.operator.stepper import IncompressibleNavierStokesStepper
from xlb.operator.boundary_condition import (
FullwayBounceBackBC,
Expand Down Expand Up @@ -48,15 +48,12 @@ def _setup(self, omega):
self.setup_stepper(omega)

def define_boundary_indices(self):
inlet = self.grid.boundingBoxIndices["left"]
outlet = self.grid.boundingBoxIndices["right"]
walls = [
self.grid.boundingBoxIndices["bottom"][i]
+ self.grid.boundingBoxIndices["top"][i]
+ self.grid.boundingBoxIndices["front"][i]
+ self.grid.boundingBoxIndices["back"][i]
for i in range(self.velocity_set.d)
]
box = self.grid.bounding_box_indices()
box_noedge = self.grid.bounding_box_indices(remove_edges=True)
inlet = box_noedge["left"]
outlet = box_noedge["right"]
walls = [box["bottom"][i] + box["top"][i] + box["front"][i] + box["back"][i] for i in range(self.velocity_set.d)]
walls = np.unique(np.array(walls), axis=-1).tolist()

sphere_radius = self.grid_shape[1] // 12
x = np.arange(self.grid_shape[0])
Expand All @@ -79,13 +76,12 @@ def setup_boundary_conditions(self):
# bc_outlet = DoNothingBC(indices=outlet)
bc_outlet = ExtrapolationOutflowBC(indices=outlet)
bc_sphere = HalfwayBounceBackBC(indices=sphere)

self.boundary_conditions = [bc_left, bc_outlet, bc_sphere, bc_walls]
# Note: it is important to add bc_walls to be after bc_outlet/bc_inlet because
# of the corner nodes. This way the corners are treated as wall and not inlet/outlet.
# TODO: how to ensure about this behind in the src code?
self.boundary_conditions = [bc_walls, bc_left, bc_outlet, bc_sphere]

def setup_boundary_masker(self):
# check boundary condition list for duplicate indices before creating bc mask
check_bc_overlaps(self.boundary_conditions, self.velocity_set.d, self.backend)

indices_boundary_masker = IndicesBoundaryMasker(
velocity_set=self.velocity_set,
precision_policy=self.precision_policy,
Expand All @@ -105,6 +101,8 @@ def run(self, num_steps, post_process_interval=100):
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, i)
self.f_0, self.f_1 = self.f_1, self.f_0

if i == 0:
self.check_boundary_mask()
hsalehipour marked this conversation as resolved.
Show resolved Hide resolved
if i % post_process_interval == 0 or i == num_steps - 1:
self.post_process(i)
end_time = time.time()
Expand Down Expand Up @@ -134,6 +132,23 @@ def post_process(self, i):

# save_fields_vtk(fields, timestep=i)
save_image(fields["u_magnitude"][:, self.grid_shape[1] // 2, :], timestep=i)
return

def check_boundary_mask(self):
# Write the results. We'll use JAX backend for the post-processing
if not isinstance(self.f_0, jnp.ndarray):
bmask = wp.to_jax(self.bc_mask)[0]
else:
bmask = self.bc_mask[0]

# save_fields_vtk(fields, timestep=i)
save_image(bmask[0, :, :], prefix="00_left")
hsalehipour marked this conversation as resolved.
Show resolved Hide resolved
save_image(bmask[self.grid_shape[0] - 1, :, :], prefix="00_right")
save_image(bmask[:, :, self.grid_shape[2] - 1], prefix="00_top")
save_image(bmask[:, :, 0], prefix="00_bottom")
hsalehipour marked this conversation as resolved.
Show resolved Hide resolved
save_image(bmask[:, 0, :], prefix="00_front")
save_image(bmask[:, self.grid_shape[1] - 1, :], prefix="00_back")
save_image(bmask[:, self.grid_shape[1] // 2, :], prefix="00_middle")


if __name__ == "__main__":
Expand Down
19 changes: 11 additions & 8 deletions examples/cfd/lid_driven_cavity_2d.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import xlb
from xlb.compute_backend import ComputeBackend
from xlb.precision_policy import PrecisionPolicy
from xlb.helper import create_nse_fields, initialize_eq
from xlb.helper import create_nse_fields, initialize_eq, check_bc_overlaps
from xlb.operator.boundary_masker import IndicesBoundaryMasker
from xlb.operator.stepper import IncompressibleNavierStokesStepper
from xlb.operator.boundary_condition import HalfwayBounceBackBC, EquilibriumBC
from xlb.operator.macroscopic import Macroscopic
from xlb.utils import save_fields_vtk, save_image
import xlb.velocity_set
import warp as wp
import jax.numpy as jnp
import xlb.velocity_set
import numpy as np


class LidDrivenCavity2D:
Expand Down Expand Up @@ -39,20 +40,22 @@ def _setup(self, omega):
self.setup_stepper(omega)

def define_boundary_indices(self):
lid = self.grid.boundingBoxIndices["top"]
walls = [
self.grid.boundingBoxIndices["bottom"][i] + self.grid.boundingBoxIndices["left"][i] + self.grid.boundingBoxIndices["right"][i]
for i in range(self.velocity_set.d)
]
box = self.grid.bounding_box_indices()
box_noedge = self.grid.bounding_box_indices(remove_edges=True)
lid = box_noedge["top"]
hsalehipour marked this conversation as resolved.
Show resolved Hide resolved
walls = [box["bottom"][i] + box["left"][i] + box["right"][i] for i in range(self.velocity_set.d)]
walls = np.unique(np.array(walls), axis=-1).tolist()
return lid, walls

def setup_boundary_conditions(self):
lid, walls = self.define_boundary_indices()
bc_top = EquilibriumBC(rho=1.0, u=(0.02, 0.0), indices=lid)
bc_walls = HalfwayBounceBackBC(indices=walls)
self.boundary_conditions = [bc_top, bc_walls]
self.boundary_conditions = [bc_walls, bc_top]

def setup_boundary_masker(self):
# check boundary condition list for duplicate indices before creating bc mask
check_bc_overlaps(self.boundary_conditions, self.velocity_set.d, self.backend)
indices_boundary_masker = IndicesBoundaryMasker(
velocity_set=self.velocity_set,
precision_policy=self.precision_policy,
Expand Down
4 changes: 2 additions & 2 deletions examples/cfd/turbulent_channel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def _setup(self):

def define_boundary_indices(self):
# top and bottom sides of the channel are no-slip and the other directions are periodic
boundingBoxIndices = self.grid.bounding_box_indices(remove_edges=True)
walls = [boundingBoxIndices["bottom"][i] + boundingBoxIndices["top"][i] for i in range(self.velocity_set.d)]
box = self.grid.bounding_box_indices(remove_edges=True)
walls = [box["bottom"][i] + box["top"][i] for i in range(self.velocity_set.d)]
return walls

def setup_boundary_conditions(self):
Expand Down
22 changes: 11 additions & 11 deletions examples/cfd/windtunnel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
from xlb.compute_backend import ComputeBackend
from xlb.precision_policy import PrecisionPolicy
from xlb.helper import create_nse_fields, initialize_eq
from xlb.helper import create_nse_fields, initialize_eq, check_bc_overlaps
from xlb.operator.stepper import IncompressibleNavierStokesStepper
from xlb.operator.boundary_condition import (
FullwayBounceBackBC,
Expand Down Expand Up @@ -67,15 +67,12 @@ def voxelize_stl(self, stl_filename, length_lbm_unit):
return mesh_matrix, pitch

def define_boundary_indices(self):
inlet = self.grid.boundingBoxIndices["left"]
outlet = self.grid.boundingBoxIndices["right"]
walls = [
self.grid.boundingBoxIndices["bottom"][i]
+ self.grid.boundingBoxIndices["top"][i]
+ self.grid.boundingBoxIndices["front"][i]
+ self.grid.boundingBoxIndices["back"][i]
for i in range(self.velocity_set.d)
]
box = self.grid.bounding_box_indices()
box_noedge = self.grid.bounding_box_indices(remove_edges=True)
inlet = box_noedge["left"]
outlet = box_noedge["right"]
walls = [box["bottom"][i] + box["top"][i] + box["front"][i] + box["back"][i] for i in range(self.velocity_set.d)]
walls = np.unique(np.array(walls), axis=-1).tolist()

# Load the mesh
stl_filename = "examples/cfd/stl-files/DrivAer-Notchback.stl"
Expand Down Expand Up @@ -104,9 +101,12 @@ def setup_boundary_conditions(self):
# bc_car = HalfwayBounceBackBC(mesh_vertices=car)
bc_car = GradsApproximationBC(mesh_vertices=car)
# bc_car = FullwayBounceBackBC(mesh_vertices=car)
self.boundary_conditions = [bc_left, bc_do_nothing, bc_walls, bc_car]
self.boundary_conditions = [bc_walls, bc_left, bc_do_nothing, bc_car]

def setup_boundary_masker(self):
# check boundary condition list for duplicate indices before creating bc mask
check_bc_overlaps(self.boundary_conditions, self.velocity_set.d, self.backend)
hsalehipour marked this conversation as resolved.
Show resolved Hide resolved

indices_boundary_masker = IndicesBoundaryMasker(
velocity_set=self.velocity_set,
precision_policy=self.precision_policy,
Expand Down
13 changes: 4 additions & 9 deletions examples/performance/mlups_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,10 @@ def create_grid_and_fields(cube_edge):


def define_boundary_indices(grid):
lid = grid.boundingBoxIndices["top"]
walls = [
grid.boundingBoxIndices["bottom"][i]
+ grid.boundingBoxIndices["left"][i]
+ grid.boundingBoxIndices["right"][i]
+ grid.boundingBoxIndices["front"][i]
+ grid.boundingBoxIndices["back"][i]
for i in range(len(grid.shape))
]
box = grid.bounding_box_indices()
box_noedge = grid.bounding_box_indices(remove_edges=True)
lid = box_noedge["top"]
hsalehipour marked this conversation as resolved.
Show resolved Hide resolved
walls = [box["bottom"][i] + box["left"][i] + box["right"][i] + box["front"][i] + box["back"][i] for i in range(len(grid.shape))]
return lid, walls


Expand Down
1 change: 0 additions & 1 deletion xlb/grid/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def __init__(self, shape: Tuple[int, ...], compute_backend: ComputeBackend):
self.shape = shape
self.dim = len(shape)
self.compute_backend = compute_backend
self.boundingBoxIndices = self.bounding_box_indices()
self._initialize_backend()

@abstractmethod
Expand Down
1 change: 1 addition & 0 deletions xlb/helper/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from xlb.helper.nse_solver import create_nse_fields as create_nse_fields
from xlb.helper.initializers import initialize_eq as initialize_eq
from xlb.helper.check_boundary_overlaps import check_bc_overlaps as check_bc_overlaps
hsalehipour marked this conversation as resolved.
Show resolved Hide resolved
22 changes: 22 additions & 0 deletions xlb/helper/check_boundary_overlaps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import numpy as np
from xlb.compute_backend import ComputeBackend


def check_bc_overlaps(bclist, dim, backend):
index_list = [[] for _ in range(dim)]
for bc in bclist:
if bc.indices is None:
continue
# Detect duplicates within bc.indices
index_arr = np.unique(bc.indices, axis=-1)
if index_arr.shape[-1] != len(bc.indices[0]):
if backend == ComputeBackend.WARP:
raise ValueError(f"Boundary condition {bc.__class__.__name__} has duplicate indices!")
for d in range(dim):
index_list[d] += bc.indices[d]

# Detect duplicates within bclist
index_arr = np.unique(index_list, axis=-1)
if index_arr.shape[-1] != len(index_list[0]):
if backend == ComputeBackend.WARP:
raise ValueError("Boundary condition list containes duplicate indices!")
1 change: 0 additions & 1 deletion xlb/operator/boundary_condition/bc_grads_approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def __init__(
indices=None,
mesh_vertices=None,
):

# TODO: the input velocity must be suitably stored elesewhere when mesh is moving.
self.u = (0, 0, 0)

Expand Down
1 change: 1 addition & 0 deletions xlb/operator/boundary_condition/boundary_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from xlb import DefaultConfig
from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry


# Enum for implementation step
class ImplementationStep(Enum):
COLLISION = auto()
Expand Down
37 changes: 11 additions & 26 deletions xlb/operator/boundary_masker/indices_boundary_masker.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def jax_implementation(self, bclist, bc_mask, missing_mask, start_index=None):
start_index = (0,) * dim

domain_shape = bc_mask[0].shape
for bc in bclist:
for bc in reversed(bclist):
hsalehipour marked this conversation as resolved.
Show resolved Hide resolved
assert bc.indices is not None, f"Please specify indices associated with the {bc.__class__.__name__} BC!"
assert bc.mesh_vertices is None, f"Please use MeshBoundaryMasker operator if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!"
id_number = bc.id
Expand Down Expand Up @@ -103,6 +103,11 @@ def _construct_warp(self):
_c = self.velocity_set.c
_q = wp.constant(self.velocity_set.q)

@wp.func
def check_index_bounds(index: wp.vec3i, shape: wp.vec3i):
is_in_bounds = index[0] >= 0 and index[0] < shape[0] and index[1] >= 0 and index[1] < shape[1] and index[2] >= 0 and index[2] < shape[2]
return is_in_bounds

# Construct the warp 2D kernel
@wp.kernel
def kernel2d(
Expand Down Expand Up @@ -173,14 +178,8 @@ def kernel3d(
index[2] = indices[2, ii] - start_index[2]

# Check if index is in bounds
if (
index[0] >= 0
and index[0] < missing_mask.shape[1]
and index[1] >= 0
and index[1] < missing_mask.shape[2]
and index[2] >= 0
and index[2] < missing_mask.shape[3]
):
shape = wp.vec3i(missing_mask.shape[1], missing_mask.shape[2], missing_mask.shape[3])
if check_index_bounds(index, shape):
# Stream indices
for l in range(_q):
# Get the index of the streaming direction
Expand All @@ -195,27 +194,12 @@ def kernel3d(

# check if pull index is out of bound
# These directions will have missing information after streaming
if (
pull_index[0] < 0
or pull_index[0] >= missing_mask.shape[1]
or pull_index[1] < 0
or pull_index[1] >= missing_mask.shape[2]
or pull_index[2] < 0
or pull_index[2] >= missing_mask.shape[3]
):
if not check_index_bounds(pull_index, shape):
# Set the missing mask
missing_mask[l, index[0], index[1], index[2]] = True

# handling geometries in the interior of the computational domain
elif (
is_interior[ii]
and push_index[0] >= 0
and push_index[0] < missing_mask.shape[1]
and push_index[1] >= 0
and push_index[1] < missing_mask.shape[2]
and push_index[2] >= 0
and push_index[2] < missing_mask.shape[3]
):
elif check_index_bounds(pull_index, shape) and is_interior[ii]:
# Set the missing mask
missing_mask[l, push_index[0], push_index[1], push_index[2]] = True
bc_mask[0, push_index[0], push_index[1], push_index[2]] = id_number[ii]
Expand All @@ -241,6 +225,7 @@ def warp_implementation(self, bclist, bc_mask, missing_mask, start_index=None):
# We are done with bc.indices. Remove them from BC objects
bc.__dict__.pop("indices", None)

# convert to warp arrays
indices = wp.array2d(index_list, dtype=wp.int32)
id_number = wp.array1d(id_list, dtype=wp.uint8)
is_interior = wp.array1d(is_interior, dtype=wp.bool)
Expand Down
4 changes: 2 additions & 2 deletions xlb/operator/collision/kbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def decompose_shear_d2q9_jax(self, fneq):

def _construct_warp(self):
# Raise error if velocity set is not supported
if not isinstance(self.velocity_set, D3Q27):
if not (isinstance(self.velocity_set, D3Q27) or isinstance(self.velocity_set, D2Q9)):
raise NotImplementedError("Velocity set not supported for warp backend: {}".format(type(self.velocity_set)))

# Set local constants TODO: This is a hack and should be fixed with warp update
Expand All @@ -192,7 +192,7 @@ def _construct_warp(self):
def decompose_shear_d2q9(fneq: Any):
pi = self.momentum_flux.warp_functional(fneq)
N = pi[0] - pi[1]
s = wp.vec9(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
s = _f_vec()
s[3] = N
s[6] = N
s[2] = -N
Expand Down
8 changes: 5 additions & 3 deletions xlb/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def downsample_field(field, factor, method="bicubic"):
return jnp.stack(downsampled_components, axis=-1)


def save_image(fld, timestep, prefix=None):
def save_image(fld, timestep=None, prefix=None, **kwargs):
"""
Save an image of a field at a given timestep.

Expand Down Expand Up @@ -74,15 +74,17 @@ def save_image(fld, timestep, prefix=None):
else:
fname = prefix

fname = fname + "_" + str(timestep).zfill(4)
if timestep is not None:
fname = fname + "_" + str(timestep).zfill(4)

if len(fld.shape) > 3:
raise ValueError("The input field should be 2D!")
if len(fld.shape) == 3:
fld = np.sqrt(fld[0, ...] ** 2 + fld[0, ...] ** 2)

plt.clf()
plt.imsave(fname + ".png", fld.T, cmap=cm.nipy_spectral, origin="lower")
kwargs.pop("cmap", None)
hsalehipour marked this conversation as resolved.
Show resolved Hide resolved
plt.imsave(fname + ".png", fld.T, cmap=cm.nipy_spectral, origin="lower", **kwargs)


def save_fields_vtk(fields, timestep, output_dir=".", prefix="fields"):
Expand Down
Loading