Skip to content

Commit

Permalink
add configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
n-gao committed Nov 7, 2024
1 parent 0f1ea6f commit 9f98a7a
Show file tree
Hide file tree
Showing 31 changed files with 325,075 additions and 257 deletions.
123 changes: 38 additions & 85 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,92 +1,45 @@
# Project Template

This template combines three libraries to give you some basic training infrastructure:

- [seml](https://github.com/TUM-DAML/seml/) to load configuration files and run jobs


## Installation (Quick Guide)
We highly recommend using [`uv`](https://docs.astral.sh/uv/) for reproducible project management:
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
```
To setup the right environment and activate it use
```sh
uv sync
source .venv/bin/activate
```
*Optionally*: Install pre-commit hooks via
```sh
pre-commit install
```
When executing commands with `seml` make sure to always first activate your virtual environment or use `uv run seml`. Do not use `uvx seml` as `uvx` will create a temporary virtual environments where your packages are not installed.

## Developement

**Project management**

For project management, we recommend [`uv`](https://docs.astral.sh/uv/). Please read the docs carefully. Here are the most important commands
* To add a package to your project use: `uv add <package>`, e.g., `uv add jax[cuda12]`.
* To update your environment: `uv sync`.
* To run a script without explicitly activating the environment, use `uv run main.py`.
* Activate your environment: `source .venv/bin/activate`

`uv` will create a lock file that exactly describes your current environment. Make sure to commit it. To recreate this environment, use `uv sync --locked`. This lock file enables the exact reproducibility of your current environment.

**IDE**

We recommend [VS Code](https://code.visualstudio.com) for development. Select the conda environment you created earlier as your default python interpreter. *Optionally*, use static typecheckers and linters like [ruff](https://github.com/astral-sh/ruff).

**Sacred**

`seml` is based on [Sacred](https://sacred.readthedocs.io/en/stable/index.html). Familiarize yourself with the rough concept behind this framework. Importantly, understand how [experiments](https://sacred.readthedocs.io/en/stable/experiment.html) work and how they can be [configured](https://sacred.readthedocs.io/en/stable/experiment.html#configuration) using config overrides and `named configs`.

**MongoDB**

`seml` will log your experiments on our local `MongoDB` server after you set it up according to the [installation guide]((https://github.com/TUM-DAML/seml/)). Familiarize yourself with the core functionality of `seml` experiments from the example configurations.


**Pytest**

During development you may want to test several functionalities. We recommend using [`pytest`](https://docs.pytest.org/en/8.0.x/) for this. To run your tests simply call
```sh
pytest
```


## Running experiments locally

To start a training locally, call `main.py` with the your settings, for example

```sh
./main.py with config/data/small.yaml config/model/big.yaml
```

You can use this for debugging, e.g. in an interactive slurm session or on your own machine.

## Running experiments on the cluster

Use `seml` to run experiments on the cluster. Pick a collection name, e.g. `example_experiment`. Each experiment should be referred to with an configuration file in `experiments/`. Use the `seml.description` field to keep track of your experiments. Add experiments using:

# Neural Pfaffians: Solving Many Many-Electron Schrödinger Equations

![Title](figures/title.png)

Reference implementation of Neural Pfaffians from <be>

<b>[Neural Pfaffians: Solving Many Many-Electron Schrödinger Equations](https://arxiv.org/abs/2405.14762)</b><br>
by Nicholas Gao, Stephan Günnemann<br/>
published as Oral at NeurIPS 2024.

## Installation
1. Install [`uv`](https://docs.astral.sh/uv/):
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
```
2. Create a virtual environment and install dependencies
```sh
uv sync
source .venv/bin/activate
```

## Running the code
We encourage the use of `seml` to manage all experiments, but we also supply commands to run the experiments directly.
With `seml`:
```bash
seml {your-collection-name} add config/seml/grid.yaml
seml n2_ablation add configs/seml/train_n2.yaml start
```

Run them on the cluster using:

Without `seml`:
```bash
seml {your-collection-name} start
neural_pfaffian with configs/systems/n2.yaml
```

You can monitor the experiment using:
## Contact
Please contact [[email protected]](mailto:[email protected]) if you have any questions.

```bash
seml {your-collection-name} status
## Cite
Please cite our paper if you use our method or code in your own works:
```

More advanced usage of seml can be found in the [documentation](https://github.com/TUM-DAML/seml/tree/master/examples).


## Analyzing results

You can analyze the results by inspecting output files your code generates or values you log in the MongoDB. For reference, see `notebooks/visualize_results.ipynb`.
@inproceedings{gao_pfaffian_2024,
title = {Neural Pfaffians: Solving Many Many-Electron Schr\"odinger Equations},
author = {Gao, Nicholas and G{\"u}nnemann, Stephan},
booktitle = {Neural Information Processing Systems (NeurIPS)},
year = {2024}
}
```
15 changes: 15 additions & 0 deletions config/seml/train_n2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
seml:
executable: src/neural_pfaffian/main.py
output_dir: ~/slurm-output
project_root_dir: ../../

slurm:
- sbatch_options:
gres: gpu:1
cpus-per-task: 8
partition: gpu_h100
qos: interactive
time: 0-12:00:00

fixed:
+systems: config/systems/n2.yaml
41 changes: 41 additions & 0 deletions config/systems/n2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
systems:
molecules:
- ["diatomic", { charge1: 7, charge2: 7, distance: 1.60151 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 1.70828 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 1.81505 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 1.92181 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 2.02858 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 2.13535 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 2.24212 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 2.34889 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 2.45565 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 2.56242 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 2.66919 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 2.77595 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 2.88272 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 2.98949 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 3.09626 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 3.20302 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 3.30979 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 3.41656 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 3.52333 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 3.63009 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 3.73686 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 3.84363 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 3.95040 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 4.05716 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 4.16393 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 4.27070 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 4.37747 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 4.48423 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 4.59100 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 4.69777 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 4.80454 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 4.91130 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 5.01807 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 5.12484 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 5.23161 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 5.33837 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 5.87221 }]
- ["diatomic", { charge1: 7, charge2: 7, distance: 6.40605 }]
num_walker_per_mol: 110
Binary file added figures/title.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
27 changes: 0 additions & 27 deletions main.py

This file was deleted.

5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "neural_pfaffian"
version = "0.0.1"
authors = [{ name = "gaoni", email = "[email protected]" }]
authors = [{ name = "Nicholas Gao", email = "[email protected]" }]
requires-python = ">= 3.11"
dependencies = [
"einops>=0.8.0",
Expand Down Expand Up @@ -31,6 +31,9 @@ dependencies = [
]
license = { text = "MIT" }

[project.scripts]
neural_pfaffian = "neural_pfaffian.__main__:cli_main"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
Expand Down
4 changes: 4 additions & 0 deletions src/neural_pfaffian/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .main import cli_main

if __name__ == '__main__':
cli_main()
53 changes: 53 additions & 0 deletions src/neural_pfaffian/clipping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import functools
from typing import Protocol

import jax
import jax.numpy as jnp
from flax.struct import PyTreeNode, field
from jaxtyping import Array, Float

from neural_pfaffian.utils import Modules
from neural_pfaffian.utils.jax_utils import pgather, pmean

LocalEnergies = Float[Array, ' batch_size n_mols']
LocalEnergiesPerMol = Float[Array, ' batch_size n_mols']


class Clipping(Protocol):
def __call__(self, local_energies: LocalEnergiesPerMol) -> LocalEnergiesPerMol: ...


class NoneClipping(Clipping, PyTreeNode):
def __call__(self, local_energies: LocalEnergiesPerMol) -> LocalEnergiesPerMol:
return local_energies


class MeanClipping(Clipping, PyTreeNode):
max_deviation: float = field(pytree_node=False)

@functools.partial(jax.vmap, in_axes=-1, out_axes=-1)
def __call__(self, local_energies: LocalEnergies) -> LocalEnergies:
center = pmean(jnp.mean(local_energies))
dev = pmean(jnp.abs(local_energies - center).mean())
max_dev = self.max_deviation * dev
return jnp.clip(local_energies, center - max_dev, center + max_dev)


class MedianClipping(Clipping, PyTreeNode):
max_deviation: float = field(pytree_node=False)

@functools.partial(jax.vmap, in_axes=-1, out_axes=-1)
def __call__(self, local_energies: LocalEnergies) -> LocalEnergies:
full_e = pgather(local_energies, axis=0, tiled=True)
center = jnp.median(full_e)
dev = pmean(jnp.abs(local_energies - center).mean())
max_dev = self.max_deviation * dev
return jnp.clip(local_energies, center - max_dev, center + max_dev)


CLIPPINGS = Modules[Clipping](
{
cls.__name__.lower().replace('clipping', ''): cls
for cls in [NoneClipping, MeanClipping, MedianClipping]
}
)
Loading

0 comments on commit 9f98a7a

Please sign in to comment.