Skip to content

Commit

Permalink
demo
Browse files Browse the repository at this point in the history
  • Loading branch information
Henry Isaacson committed Jul 29, 2024
1 parent aa1ccd1 commit 1d17283
Show file tree
Hide file tree
Showing 6 changed files with 1,143 additions and 5 deletions.
40 changes: 40 additions & 0 deletions src/beignet/_soft_sphere_potential.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import torch
from torch import Tensor


def soft_sphere_potential(input: Tensor,
sigma: Tensor = 1,
epsilon: Tensor = 1,
alpha: Tensor = 2,
**unused_kwargs) -> Tensor:
r"""
Finite ranged repulsive interaction between soft spheres.
Parameters
----------
input : Tensor
A tensor of shape `[n, m]` of pairwise distances between particles.
sigma : Tensor, optional
Particle diameter. Should either be a floating point scalar or a tensor
whose shape is `[n, m]`. Default is 1.
epsilon : Tensor, optional
Interaction energy scale. Should either be a floating point scalar or a tensor
whose shape is `[n, m]`. Default is 1.
alpha : Tensor, optional
Exponent specifying interaction stiffness. Should either be a floating point scalar
or a tensor whose shape is `[n, m]`. Default is 2.
unused_kwargs : dict, optional
Allows extra data (e.g. time) to be passed to the energy.
Returns
-------
Tensor
Matrix of energies whose shape is `[n, m]`.
"""
input = input / sigma
fn = lambda dr: epsilon / alpha * (1.0 - dr) ** alpha

if isinstance(alpha, int) or issubclass(type(alpha.dtype), torch.int):
return torch.where(input < 1.0, fn(input), torch.tensor(0.0, dtype=input.dtype))

return torch.where(input < 1.0, fn(input), torch.tensor(0.0, dtype=input.dtype))
Binary file added src/beignet/examples/models/sand_castle.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 0 additions & 3 deletions src/beignet/func/_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,6 @@ def neighbor_list_mask(neighbor: _NeighborList, mask_self: bool = False) -> Tens
"""
if is_neighbor_list_sparse(neighbor.format):
mask = neighbor.indexes[0] < len(neighbor.reference_positions)
torch.set_printoptions(profile="full")
if mask_self:
mask = mask & (neighbor.indexes[0] != neighbor.indexes[1])

Expand Down Expand Up @@ -1229,8 +1228,6 @@ def fn(
> buffer_size
)

print(f"cell_capacity: {buffer_size}")

return _CellList(
exceeded_maximum_size=exceeded_maximum_size,
indexes=indexes,
Expand Down
Loading

0 comments on commit 1d17283

Please sign in to comment.