Skip to content

Commit

Permalink
Converted examples to functional. Made compute_backend name consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
mehdiataei committed Jan 16, 2025
1 parent d4a92bc commit 6a15554
Show file tree
Hide file tree
Showing 30 changed files with 732 additions and 589 deletions.
266 changes: 132 additions & 134 deletions examples/cfd/flow_past_sphere_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,95 +16,61 @@
import jax.numpy as jnp
import time

# -------------------------- Simulation Setup --------------------------

omega = 1.6
grid_shape = (512 // 2, 128 // 2, 128 // 2)
compute_backend = ComputeBackend.WARP
precision_policy = PrecisionPolicy.FP32FP32
velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, compute_backend=compute_backend)
u_max = 0.04
num_steps = 10000
post_process_interval = 1000

# Initialize XLB
xlb.init(
velocity_set=velocity_set,
default_backend=compute_backend,
default_precision_policy=precision_policy,
)

class FlowOverSphere:
def __init__(self, omega, grid_shape, velocity_set, backend, precision_policy):
# initialize backend
xlb.init(
velocity_set=velocity_set,
default_backend=backend,
default_precision_policy=precision_policy,
)

self.grid_shape = grid_shape
self.velocity_set = velocity_set
self.backend = backend
self.precision_policy = precision_policy
self.omega = omega

self.boundary_conditions = []
self.u_max = 0.04

# Create grid using factory
self.grid = grid_factory(grid_shape, compute_backend=backend)

# Setup the simulation BC and stepper
self._setup()

def _setup(self):
self.setup_boundary_conditions()
self.setup_stepper()

def define_boundary_indices(self):
box = self.grid.bounding_box_indices()
box_no_edge = self.grid.bounding_box_indices(remove_edges=True)
inlet = box_no_edge["left"]
outlet = box_no_edge["right"]
walls = [box["bottom"][i] + box["top"][i] + box["front"][i] + box["back"][i] for i in range(self.velocity_set.d)]
walls = np.unique(np.array(walls), axis=-1).tolist()

sphere_radius = self.grid_shape[1] // 12
x = np.arange(self.grid_shape[0])
y = np.arange(self.grid_shape[1])
z = np.arange(self.grid_shape[2])
X, Y, Z = np.meshgrid(x, y, z, indexing="ij")
indices = np.where(
(X - self.grid_shape[0] // 6) ** 2 + (Y - self.grid_shape[1] // 2) ** 2 + (Z - self.grid_shape[2] // 2) ** 2 < sphere_radius**2
)
sphere = [tuple(indices[i]) for i in range(self.velocity_set.d)]

return inlet, outlet, walls, sphere

def setup_boundary_conditions(self):
inlet, outlet, walls, sphere = self.define_boundary_indices()
bc_left = RegularizedBC("velocity", profile=self.bc_profile(), indices=inlet)
# bc_left = RegularizedBC("velocity", prescribed_value=(self.u_max, 0.0, 0.0), indices=inlet)
bc_walls = FullwayBounceBackBC(indices=walls)
bc_outlet = ExtrapolationOutflowBC(indices=outlet)
bc_sphere = HalfwayBounceBackBC(indices=sphere)
self.boundary_conditions = [bc_walls, bc_left, bc_outlet, bc_sphere]

def setup_stepper(self):
self.stepper = IncompressibleNavierStokesStepper(
grid=self.grid,
boundary_conditions=self.boundary_conditions,
collision_type="BGK",
)
self.f_0, self.f_1, self.bc_mask, self.missing_mask = self.stepper.prepare_fields()

def bc_profile(self):
u_max = self.u_max # u_max = 0.04
# Get the grid dimensions for the y and z directions
H_y = float(self.grid_shape[1] - 1) # Height in y direction
H_z = float(self.grid_shape[2] - 1) # Height in z direction
# Create Grid
grid = grid_factory(grid_shape, compute_backend=compute_backend)

@wp.func
def bc_profile_warp(index: wp.vec3i):
# Poiseuille flow profile: parabolic velocity distribution
y = wp.float32(index[1])
z = wp.float32(index[2])

# Calculate normalized distance from center
y_center = y - (H_y / 2.0)
z_center = z - (H_z / 2.0)
r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0
# Define Boundary Indices
def define_boundary_indices():
box = grid.bounding_box_indices()
box_no_edge = grid.bounding_box_indices(remove_edges=True)
inlet = box_no_edge["left"]
outlet = box_no_edge["right"]
walls = [box["bottom"][i] + box["top"][i] + box["front"][i] + box["back"][i] for i in range(velocity_set.d)]
walls = np.unique(np.array(walls), axis=-1).tolist()

# Parabolic profile: u = u_max * (1 - r²)
return wp.vec(u_max * wp.max(0.0, 1.0 - r_squared), length=1)
sphere_radius = grid_shape[1] // 12
x = np.arange(grid_shape[0])
y = np.arange(grid_shape[1])
z = np.arange(grid_shape[2])
X, Y, Z = np.meshgrid(x, y, z, indexing="ij")
indices = np.where((X - grid_shape[0] // 6) ** 2 + (Y - grid_shape[1] // 2) ** 2 + (Z - grid_shape[2] // 2) ** 2 < sphere_radius**2)
sphere = [tuple(indices[i]) for i in range(velocity_set.d)]

return inlet, outlet, walls, sphere


inlet, outlet, walls, sphere = define_boundary_indices()


# Define Boundary Conditions
def bc_profile():
H_y = float(grid_shape[1] - 1) # Height in y direction
H_z = float(grid_shape[2] - 1) # Height in z direction

if compute_backend == ComputeBackend.JAX:

def bc_profile_jax():
y = jnp.arange(self.grid_shape[1])
z = jnp.arange(self.grid_shape[2])
y = jnp.arange(grid_shape[1])
z = jnp.arange(grid_shape[2])
Y, Z = jnp.meshgrid(y, z, indexing="ij")

# Calculate normalized distance from center
Expand All @@ -119,56 +85,88 @@ def bc_profile_jax():

return jnp.stack([u_x, u_y, u_z])

if self.backend == ComputeBackend.JAX:
return bc_profile_jax
elif self.backend == ComputeBackend.WARP:
return bc_profile_warp
return bc_profile_jax

elif compute_backend == ComputeBackend.WARP:

@wp.func
def bc_profile_warp(index: wp.vec3i):
# Poiseuille flow profile: parabolic velocity distribution
y = wp.float32(index[1])
z = wp.float32(index[2])

# Calculate normalized distance from center
y_center = y - (H_y / 2.0)
z_center = z - (H_z / 2.0)
r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0

# Parabolic profile: u = u_max * (1 - r²)
return wp.vec(u_max * wp.max(0.0, 1.0 - r_squared), length=1)

return bc_profile_warp


# Initialize Boundary Conditions
bc_left = RegularizedBC("velocity", profile=bc_profile(), indices=inlet)
# Alternatively, use a prescribed velocity profile
# bc_left = RegularizedBC("velocity", prescribed_value=(u_max, 0.0, 0.0), indices=inlet)
bc_walls = FullwayBounceBackBC(indices=walls)
bc_outlet = ExtrapolationOutflowBC(indices=outlet)
bc_sphere = HalfwayBounceBackBC(indices=sphere)
boundary_conditions = [bc_walls, bc_left, bc_outlet, bc_sphere]

# Setup Stepper
stepper = IncompressibleNavierStokesStepper(
grid=grid,
boundary_conditions=boundary_conditions,
collision_type="BGK",
)
f_0, f_1, bc_mask, missing_mask = stepper.prepare_fields()

# Define Macroscopic Calculation
macro = Macroscopic(
compute_backend=ComputeBackend.JAX,
precision_policy=precision_policy,
velocity_set=xlb.velocity_set.D3Q19(precision_policy=precision_policy, compute_backend=ComputeBackend.JAX),
)


# Post-Processing Function
def post_process(step, f_current):
# Convert to JAX array if necessary
if not isinstance(f_current, jnp.ndarray):
f_current = wp.to_jax(f_current)

rho, u = macro(f_current)

# Remove boundary cells
u = u[:, 1:-1, 1:-1, 1:-1]
rho = rho[:, 1:-1, 1:-1, 1:-1]
u_magnitude = jnp.sqrt(u[0] ** 2 + u[1] ** 2 + u[2] ** 2)

fields = {
"u_magnitude": u_magnitude,
"u_x": u[0],
"u_y": u[1],
"u_z": u[2],
"rho": rho[0],
}

# Save the u_magnitude slice at the mid y-plane
save_image(fields["u_magnitude"][:, grid_shape[1] // 2, :], timestep=step)
print(f"Post-processed step {step}: Saved u_magnitude slice at y={grid_shape[1] // 2}")


# -------------------------- Simulation Loop --------------------------

start_time = time.time()
for step in range(num_steps):
f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, omega, step)
f_0, f_1 = f_1, f_0 # Swap the buffers

def run(self, num_steps, post_process_interval=100):
if step % post_process_interval == 0 or step == num_steps - 1:
post_process(step, f_0)
end_time = time.time()
elapsed = end_time - start_time
print(f"Completed step {step}. Time elapsed for {post_process_interval} steps: {elapsed:.6f} seconds.")
start_time = time.time()
for i in range(num_steps):
self.f_0, self.f_1 = self.stepper(self.f_0, self.f_1, self.bc_mask, self.missing_mask, self.omega, i)
self.f_0, self.f_1 = self.f_1, self.f_0

if i % post_process_interval == 0 or i == num_steps - 1:
self.post_process(i)
end_time = time.time()
print(f"Completing {i} iterations. Time elapsed for 1000 LBM steps in {end_time - start_time:.6f} seconds.")
start_time = time.time()

def post_process(self, i):
# Write the results. We'll use JAX backend for the post-processing
if not isinstance(self.f_0, jnp.ndarray):
f_0 = wp.to_jax(self.f_0)
else:
f_0 = self.f_0

macro = Macroscopic(
compute_backend=ComputeBackend.JAX,
precision_policy=self.precision_policy,
velocity_set=xlb.velocity_set.D3Q19(precision_policy=self.precision_policy, backend=ComputeBackend.JAX),
)
rho, u = macro(f_0)

# remove boundary cells
u = u[:, 1:-1, 1:-1, 1:-1]
rho = rho[:, 1:-1, 1:-1, 1:-1]
u_magnitude = (u[0] ** 2 + u[1] ** 2 + u[2] ** 2) ** 0.5

fields = {"u_magnitude": u_magnitude, "u_x": u[0], "u_y": u[1], "u_z": u[2], "rho": rho[0]}

# save_fields_vtk(fields, timestep=i)
save_image(fields["u_magnitude"][:, self.grid_shape[1] // 2, :], timestep=i)


if __name__ == "__main__":
# Running the simulation
grid_shape = (512 // 2, 128 // 2, 128 // 2)
backend = ComputeBackend.WARP
precision_policy = PrecisionPolicy.FP32FP32

velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend)
omega = 1.6

simulation = FlowOverSphere(omega, grid_shape, velocity_set, backend, precision_policy)
simulation.run(num_steps=10000, post_process_interval=1000)
22 changes: 11 additions & 11 deletions examples/cfd/lid_driven_cavity_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,24 @@


class LidDrivenCavity2D:
def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy):
# initialize backend
def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, compute_backend, precision_policy):
# initialize compute_backend
xlb.init(
velocity_set=velocity_set,
default_backend=backend,
default_backend=compute_backend,
default_precision_policy=precision_policy,
)

self.grid_shape = grid_shape
self.velocity_set = velocity_set
self.backend = backend
self.compute_backend = compute_backend
self.precision_policy = precision_policy
self.omega = omega
self.boundary_conditions = []
self.prescribed_vel = prescribed_vel

# Create grid using factory
self.grid = grid_factory(grid_shape, compute_backend=backend)
self.grid = grid_factory(grid_shape, compute_backend=compute_backend)

# Setup the simulation BC and stepper
self._setup()
Expand Down Expand Up @@ -71,17 +71,17 @@ def run(self, num_steps, post_process_interval=100):
self.post_process(i)

def post_process(self, i):
# Write the results. We'll use JAX backend for the post-processing
# Write the results. We'll use JAX compute_backend for the post-processing
if not isinstance(self.f_0, jnp.ndarray):
# If the backend is warp, we need to drop the last dimension added by warp for 2D simulations
# If the compute_backend is warp, we need to drop the last dimension added by warp for 2D simulations
f_0 = wp.to_jax(self.f_0)[..., 0]
else:
f_0 = self.f_0

macro = Macroscopic(
compute_backend=ComputeBackend.JAX,
precision_policy=self.precision_policy,
velocity_set=xlb.velocity_set.D2Q9(precision_policy=self.precision_policy, backend=ComputeBackend.JAX),
velocity_set=xlb.velocity_set.D2Q9(precision_policy=self.precision_policy, compute_backend=ComputeBackend.JAX),
)

rho, u = macro(f_0)
Expand All @@ -101,10 +101,10 @@ def post_process(self, i):
# Running the simulation
grid_size = 500
grid_shape = (grid_size, grid_size)
backend = ComputeBackend.WARP
compute_backend = ComputeBackend.WARP
precision_policy = PrecisionPolicy.FP32FP32

velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend)
velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, compute_backend=compute_backend)

# Setting fluid viscosity and relaxation parameter.
Re = 200.0
Expand All @@ -113,5 +113,5 @@ def post_process(self, i):
visc = prescribed_vel * clength / Re
omega = 1.0 / (3.0 * visc + 0.5)

simulation = LidDrivenCavity2D(omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy)
simulation = LidDrivenCavity2D(omega, prescribed_vel, grid_shape, velocity_set, compute_backend, precision_policy)
simulation.run(num_steps=50000, post_process_interval=1000)
12 changes: 7 additions & 5 deletions examples/cfd/lid_driven_cavity_2d_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@


class LidDrivenCavity2D_distributed(LidDrivenCavity2D):
def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy):
super().__init__(omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy)
def __init__(self, omega, prescribed_vel, grid_shape, velocity_set, compute_backend, precision_policy):
super().__init__(omega, prescribed_vel, grid_shape, velocity_set, compute_backend, precision_policy)

def setup_stepper(self):
# Create the base stepper
Expand All @@ -30,10 +30,12 @@ def setup_stepper(self):
# Running the simulation
grid_size = 512
grid_shape = (grid_size, grid_size)
backend = ComputeBackend.JAX # Must be JAX for distributed multi-GPU computations. Distributed computations on WARP are not supported yet!
compute_backend = (
ComputeBackend.JAX
) # Must be JAX for distributed multi-GPU computations. Distributed computations on WARP are not supported yet!
precision_policy = PrecisionPolicy.FP32FP32

velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend)
velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, compute_backend=compute_backend)

# Setting fluid viscosity and relaxation parameter.
Re = 200.0
Expand All @@ -42,5 +44,5 @@ def setup_stepper(self):
visc = prescribed_vel * clength / Re
omega = 1.0 / (3.0 * visc + 0.5)

simulation = LidDrivenCavity2D_distributed(omega, prescribed_vel, grid_shape, velocity_set, backend, precision_policy)
simulation = LidDrivenCavity2D_distributed(omega, prescribed_vel, grid_shape, velocity_set, compute_backend, precision_policy)
simulation.run(num_steps=50000, post_process_interval=1000)
Loading

0 comments on commit 6a15554

Please sign in to comment.