-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
31 changed files
with
325,075 additions
and
257 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,92 +1,45 @@ | ||
# Project Template | ||
|
||
This template combines three libraries to give you some basic training infrastructure: | ||
|
||
- [seml](https://github.com/TUM-DAML/seml/) to load configuration files and run jobs | ||
|
||
|
||
## Installation (Quick Guide) | ||
We highly recommend using [`uv`](https://docs.astral.sh/uv/) for reproducible project management: | ||
```bash | ||
curl -LsSf https://astral.sh/uv/install.sh | sh | ||
``` | ||
To setup the right environment and activate it use | ||
```sh | ||
uv sync | ||
source .venv/bin/activate | ||
``` | ||
*Optionally*: Install pre-commit hooks via | ||
```sh | ||
pre-commit install | ||
``` | ||
When executing commands with `seml` make sure to always first activate your virtual environment or use `uv run seml`. Do not use `uvx seml` as `uvx` will create a temporary virtual environments where your packages are not installed. | ||
|
||
## Developement | ||
|
||
**Project management** | ||
|
||
For project management, we recommend [`uv`](https://docs.astral.sh/uv/). Please read the docs carefully. Here are the most important commands | ||
* To add a package to your project use: `uv add <package>`, e.g., `uv add jax[cuda12]`. | ||
* To update your environment: `uv sync`. | ||
* To run a script without explicitly activating the environment, use `uv run main.py`. | ||
* Activate your environment: `source .venv/bin/activate` | ||
|
||
`uv` will create a lock file that exactly describes your current environment. Make sure to commit it. To recreate this environment, use `uv sync --locked`. This lock file enables the exact reproducibility of your current environment. | ||
|
||
**IDE** | ||
|
||
We recommend [VS Code](https://code.visualstudio.com) for development. Select the conda environment you created earlier as your default python interpreter. *Optionally*, use static typecheckers and linters like [ruff](https://github.com/astral-sh/ruff). | ||
|
||
**Sacred** | ||
|
||
`seml` is based on [Sacred](https://sacred.readthedocs.io/en/stable/index.html). Familiarize yourself with the rough concept behind this framework. Importantly, understand how [experiments](https://sacred.readthedocs.io/en/stable/experiment.html) work and how they can be [configured](https://sacred.readthedocs.io/en/stable/experiment.html#configuration) using config overrides and `named configs`. | ||
|
||
**MongoDB** | ||
|
||
`seml` will log your experiments on our local `MongoDB` server after you set it up according to the [installation guide]((https://github.com/TUM-DAML/seml/)). Familiarize yourself with the core functionality of `seml` experiments from the example configurations. | ||
|
||
|
||
**Pytest** | ||
|
||
During development you may want to test several functionalities. We recommend using [`pytest`](https://docs.pytest.org/en/8.0.x/) for this. To run your tests simply call | ||
```sh | ||
pytest | ||
``` | ||
|
||
|
||
## Running experiments locally | ||
|
||
To start a training locally, call `main.py` with the your settings, for example | ||
|
||
```sh | ||
./main.py with config/data/small.yaml config/model/big.yaml | ||
``` | ||
|
||
You can use this for debugging, e.g. in an interactive slurm session or on your own machine. | ||
|
||
## Running experiments on the cluster | ||
|
||
Use `seml` to run experiments on the cluster. Pick a collection name, e.g. `example_experiment`. Each experiment should be referred to with an configuration file in `experiments/`. Use the `seml.description` field to keep track of your experiments. Add experiments using: | ||
|
||
# Neural Pfaffians: Solving Many Many-Electron Schrödinger Equations | ||
|
||
data:image/s3,"s3://crabby-images/967c5/967c58fe159acad71e0238873ec57884d2b20c3a" alt="Title" | ||
|
||
Reference implementation of Neural Pfaffians from <be> | ||
|
||
<b>[Neural Pfaffians: Solving Many Many-Electron Schrödinger Equations](https://arxiv.org/abs/2405.14762)</b><br> | ||
by Nicholas Gao, Stephan Günnemann<br/> | ||
published as Oral at NeurIPS 2024. | ||
|
||
## Installation | ||
1. Install [`uv`](https://docs.astral.sh/uv/): | ||
```bash | ||
curl -LsSf https://astral.sh/uv/install.sh | sh | ||
``` | ||
2. Create a virtual environment and install dependencies | ||
```sh | ||
uv sync | ||
source .venv/bin/activate | ||
``` | ||
|
||
## Running the code | ||
We encourage the use of `seml` to manage all experiments, but we also supply commands to run the experiments directly. | ||
With `seml`: | ||
```bash | ||
seml {your-collection-name} add config/seml/grid.yaml | ||
seml n2_ablation add configs/seml/train_n2.yaml start | ||
``` | ||
|
||
Run them on the cluster using: | ||
|
||
Without `seml`: | ||
```bash | ||
seml {your-collection-name} start | ||
neural_pfaffian with configs/systems/n2.yaml | ||
``` | ||
|
||
You can monitor the experiment using: | ||
## Contact | ||
Please contact [[email protected]](mailto:[email protected]) if you have any questions. | ||
|
||
```bash | ||
seml {your-collection-name} status | ||
## Cite | ||
Please cite our paper if you use our method or code in your own works: | ||
``` | ||
|
||
More advanced usage of seml can be found in the [documentation](https://github.com/TUM-DAML/seml/tree/master/examples). | ||
|
||
|
||
## Analyzing results | ||
|
||
You can analyze the results by inspecting output files your code generates or values you log in the MongoDB. For reference, see `notebooks/visualize_results.ipynb`. | ||
@inproceedings{gao_pfaffian_2024, | ||
title = {Neural Pfaffians: Solving Many Many-Electron Schr\"odinger Equations}, | ||
author = {Gao, Nicholas and G{\"u}nnemann, Stephan}, | ||
booktitle = {Neural Information Processing Systems (NeurIPS)}, | ||
year = {2024} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
seml: | ||
executable: src/neural_pfaffian/main.py | ||
output_dir: ~/slurm-output | ||
project_root_dir: ../../ | ||
|
||
slurm: | ||
- sbatch_options: | ||
gres: gpu:1 | ||
cpus-per-task: 8 | ||
partition: gpu_h100 | ||
qos: interactive | ||
time: 0-12:00:00 | ||
|
||
fixed: | ||
+systems: config/systems/n2.yaml |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
systems: | ||
molecules: | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 1.60151 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 1.70828 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 1.81505 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 1.92181 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 2.02858 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 2.13535 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 2.24212 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 2.34889 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 2.45565 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 2.56242 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 2.66919 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 2.77595 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 2.88272 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 2.98949 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 3.09626 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 3.20302 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 3.30979 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 3.41656 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 3.52333 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 3.63009 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 3.73686 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 3.84363 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 3.95040 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 4.05716 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 4.16393 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 4.27070 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 4.37747 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 4.48423 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 4.59100 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 4.69777 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 4.80454 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 4.91130 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 5.01807 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 5.12484 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 5.23161 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 5.33837 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 5.87221 }] | ||
- ["diatomic", { charge1: 7, charge2: 7, distance: 6.40605 }] | ||
num_walker_per_mol: 110 |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
[project] | ||
name = "neural_pfaffian" | ||
version = "0.0.1" | ||
authors = [{ name = "gaoni", email = "[email protected]" }] | ||
authors = [{ name = "Nicholas Gao", email = "[email protected]" }] | ||
requires-python = ">= 3.11" | ||
dependencies = [ | ||
"einops>=0.8.0", | ||
|
@@ -31,6 +31,9 @@ dependencies = [ | |
] | ||
license = { text = "MIT" } | ||
|
||
[project.scripts] | ||
neural_pfaffian = "neural_pfaffian.__main__:cli_main" | ||
|
||
[build-system] | ||
requires = ["hatchling"] | ||
build-backend = "hatchling.build" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .main import cli_main | ||
|
||
if __name__ == '__main__': | ||
cli_main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import functools | ||
from typing import Protocol | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from flax.struct import PyTreeNode, field | ||
from jaxtyping import Array, Float | ||
|
||
from neural_pfaffian.utils import Modules | ||
from neural_pfaffian.utils.jax_utils import pgather, pmean | ||
|
||
LocalEnergies = Float[Array, ' batch_size n_mols'] | ||
LocalEnergiesPerMol = Float[Array, ' batch_size n_mols'] | ||
|
||
|
||
class Clipping(Protocol): | ||
def __call__(self, local_energies: LocalEnergiesPerMol) -> LocalEnergiesPerMol: ... | ||
|
||
|
||
class NoneClipping(Clipping, PyTreeNode): | ||
def __call__(self, local_energies: LocalEnergiesPerMol) -> LocalEnergiesPerMol: | ||
return local_energies | ||
|
||
|
||
class MeanClipping(Clipping, PyTreeNode): | ||
max_deviation: float = field(pytree_node=False) | ||
|
||
@functools.partial(jax.vmap, in_axes=-1, out_axes=-1) | ||
def __call__(self, local_energies: LocalEnergies) -> LocalEnergies: | ||
center = pmean(jnp.mean(local_energies)) | ||
dev = pmean(jnp.abs(local_energies - center).mean()) | ||
max_dev = self.max_deviation * dev | ||
return jnp.clip(local_energies, center - max_dev, center + max_dev) | ||
|
||
|
||
class MedianClipping(Clipping, PyTreeNode): | ||
max_deviation: float = field(pytree_node=False) | ||
|
||
@functools.partial(jax.vmap, in_axes=-1, out_axes=-1) | ||
def __call__(self, local_energies: LocalEnergies) -> LocalEnergies: | ||
full_e = pgather(local_energies, axis=0, tiled=True) | ||
center = jnp.median(full_e) | ||
dev = pmean(jnp.abs(local_energies - center).mean()) | ||
max_dev = self.max_deviation * dev | ||
return jnp.clip(local_energies, center - max_dev, center + max_dev) | ||
|
||
|
||
CLIPPINGS = Modules[Clipping]( | ||
{ | ||
cls.__name__.lower().replace('clipping', ''): cls | ||
for cls in [NoneClipping, MeanClipping, MedianClipping] | ||
} | ||
) |
Oops, something went wrong.