Skip to content

Commit

Permalink
Merge pull request #61 from ADACS-Australia/main
Browse files Browse the repository at this point in the history
ADACS delivery of MAP 2024A
  • Loading branch information
anaismoller authored Oct 28, 2024
2 parents 8bf5a64 + 756744f commit 822b3f2
Show file tree
Hide file tree
Showing 65 changed files with 5,337 additions and 3,522 deletions.
8 changes: 5 additions & 3 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ jobs:
# Uncomment if you need mpi
# - name: Set-up MPI
# uses: mpi4py/setup-mpi@v1

- name: Set-up Cuda Toolkit
run: sudo apt-get install nvidia-cuda-toolkit nvidia-cuda-toolkit-gcc

# - name: Set-up Cuda Toolkit
# run: |
# sudo apt-get update
# sudo apt-get install nvidia-cuda-toolkit nvidia-cuda-toolkit-gcc

- name: Set-up Poetry
uses: snok/install-poetry@v1
Expand Down
45 changes: 33 additions & 12 deletions .github/workflows/pull_request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ jobs:

run_tests:
runs-on: ubuntu-latest
strategy:
matrix:
include:
- conda_env: supernnova
conda_env_file: env/conda_env.yml
- conda_env: supernnova-cuda
conda_env_file: env/conda_gpu_env.yml

steps:

Expand All @@ -34,50 +41,64 @@ jobs:
# uses: mpi4py/setup-mpi@v1

- name: Set-up Cuda Toolkit
run: sudo apt-get install nvidia-cuda-toolkit nvidia-cuda-toolkit-gcc
run: |
sudo apt-get update
sudo apt-get install nvidia-cuda-toolkit nvidia-cuda-toolkit-gcc
- name: Set-up Poetry
uses: snok/install-poetry@v1
- name: Setup miniconda
uses: conda-incubator/setup-miniconda@v3
with:
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
activate-environment: ${{matrix.conda_env}}
environment-file: ${{matrix.conda_env_file}}

- name: Set-up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'
cache: 'poetry'
- name: Verify Environment
shell: bash -l {0}
run: |
conda info
conda list
- name: Verify cuda support
if: ${{ matrix.conda_env == 'supernnova-cuda' }}
shell: bash -l {0}
run: python env/verify_cuda_support.py

# Configure project
- name: Set project version
run: poetry version $(git describe --tags --abbrev=0)
shell: bash -l {0}
run: poetry version $(git describe --tags --match "v[0-9]*" --abbrev=0)

# Install the project (we need some of the tools installed here for linting etc)
- name: Install the project
shell: bash -l {0}
run: poetry install --no-interaction --extras "docs dev"

# Enforce code formatting standards
- name: Enforce linting
shell: bash -l {0}
run: poetry run ruff .

- name: Enforce formating
shell: bash -l {0}
run: poetry run black .

# Make sure the Poetry project is properly maintained
- name: Enforce maintainance of Poetry project
shell: bash -l {0}
run: |
poetry check
poetry lock --check
# Run tests
- name: Generate build
shell: bash -l {0}
run: poetry build

- name: Code tests
shell: bash -l {0}
run: poetry run pytest

- name: Documentation build test
shell: bash -l {0}
run: |
cd docs
poetry run make html
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ supernnova/**/.ipynb_checkpoints
science_modules/**/*.pyc
rfsn/**/*.pyc
charnokmoss/**/*.pyc
*.csv
# *.csv
*.h5
.cache
/data
Expand Down
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ repos:
hooks:
- id: check-merge-conflict
- id: check-yaml
exclude: archive/|tmp/
- id: check-added-large-files
- id: no-commit-to-branch
args: ['--branch', 'main'] # Commits to main only allowed VIA PR for this project
Expand Down
6 changes: 6 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
USERNAME = $(USER)
USER_ID = $(shell id -u)
USER_GID = $(shell id -g)

%:
@DOCKER_BUILDKIT=1 docker build -f env/Dockerfile --build-arg 'TARGET=$*' --build-arg 'USERNAME=$(USERNAME)' --build-arg 'USER_ID=${USER_ID}' --build-arg 'USER_GID=${USER_GID}' -t rnn-$* .
258 changes: 258 additions & 0 deletions cli/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
import sys
import re
import os.path as osp
import numpy as np
from pathlib import Path
from supernnova import conf
from supernnova.utils import logging_utils as lu

# Notice: This script serves as the entry point of the program. By design, only one action can be executed at a time.
# To optimize performance, heavy imports are deferred until their corresponding actions are invoked.

# All valid top-level commands for the snn CLI.
ALL_ACTIONS = ("make_data", "train_rnn", "validate_rnn", "show", "performance")

# Top-level help text printed by print_usage().
# Fix: "vitualize" -> "visualize" (user-facing typo).
help_msg = """
Available commands:
    make_data       create dataset for ML training
    train_rnn       train RNN model
    validate_rnn    validate RNN model
    show            visualize different types of plot
    performance     get method performance and paper plots

Type snn <command> --help for usage help on a specific command.
For example, snn make_data --help will list all data creation options.
"""


def print_usage():
    """Print the one-line usage summary followed by the list of commands."""
    prog = osp.basename(sys.argv[0])
    print(f"Usage: {prog} <command> <options> <arguments>")
    print(help_msg)


def get_action():
    """Pop and return the first CLI argument; exit with usage help if it is
    missing or not one of ALL_ACTIONS."""
    no_action_given = len(sys.argv) <= 1
    if no_action_given or sys.argv[1] not in ALL_ACTIONS:
        print_usage()
        sys.exit(1)
    return sys.argv.pop(1)


def get_plot_lcs(settings):
    """Plot early-prediction lightcurves, either for the current settings or
    once per model file in ``settings.model_files``."""
    from supernnova.visualization import early_prediction

    if settings.model_files is None:
        early_prediction.make_early_prediction(settings, nb_lcs=100, do_gifs=False)
        return

    # An empty model_files list deliberately plots nothing.
    for model_file in settings.model_files:
        per_model_settings = conf.get_settings_from_dump(settings, model_file)
        early_prediction.make_early_prediction(
            per_model_settings, nb_lcs=100, do_gifs=False
        )


def make_data_action(settings):
    """Build the HDF5 training database, then optionally visualize lightcurves."""
    from supernnova.data import make_dataset
    from supernnova.visualization import visualize

    # --explore_lightcurves is only meaningful when running in debug mode
    if settings.explore_lightcurves and not settings.debug:
        raise ValueError("--explore_lightcurves must be used with --debug")

    make_dataset.make_dataset(settings)
    lu.print_blue("Finished constructing dataset")

    if settings.explore_lightcurves:
        visualize.visualize(settings)


def train_rnn_action(settings):
    """Train an RNN model, then obtain predictions, compute metrics and plot
    a sample of lightcurves (plus calibration plots when requested)."""
    from supernnova.training import train_rnn
    from supernnova.validation import validate_rnn, metrics
    from supernnova.visualization import early_prediction
    from supernnova.paper import superNNova_plots as sp

    # Sanity-check SWAG-related options before any heavy work starts
    if settings.swag:
        if settings.cyclic:
            lu.print_yellow("SWAG is not available with training cyclic")
        if settings.swag_start_epoch + 2 >= settings.nb_epoch:
            raise ValueError("(swag_start_epoch +2) must be smaller than nb_epoch")

    # Train with the cyclic schedule or the plain trainer
    trainer = train_rnn.train_cyclic if settings.cyclic else train_rnn.train
    trainer(settings)

    # Obtain predictions on the validation set
    prediction_file_list = validate_rnn.get_predictions(settings)

    # Compute metrics for every prediction file produced above
    for prediction_file in prediction_file_list:
        metrics.get_metrics_singlemodel(
            settings, prediction_file=prediction_file, model_type="rnn"
        )

    # Plot a sample of lightcurves
    early_prediction.make_early_prediction(settings)

    lu.print_blue("Finished rnn training, validating, testing and plotting lcs")

    if settings.calibration:
        sp.plot_calibration(settings, prediction_files=prediction_file_list)


def validate_rnn_action(settings):
    """Run validation: generate predictions (for current settings or for each
    saved model file), compute metrics, then render any requested plots."""
    from supernnova.validation import validate_rnn, metrics
    from supernnova.visualization import prediction_distribution
    from supernnova.paper import superNNova_plots as sp

    def _compute_metrics(files, used_settings):
        # One metrics pass per prediction file, always as an RNN model.
        for prediction_file in files:
            metrics.get_metrics_singlemodel(
                used_settings, prediction_file=prediction_file, model_type="rnn"
            )

    if settings.model_files is None:
        pf = validate_rnn.get_predictions(settings)
        _compute_metrics(pf, settings)
    else:
        pf = []
        for model_file in settings.model_files:
            # Restore the settings the model was trained with
            model_settings = conf.get_settings_from_dump(settings, model_file)
            # Honour a num_inference_samples override from the command line
            if settings.num_inference_samples != model_settings.num_inference_samples:
                model_settings.num_inference_samples = settings.num_inference_samples
            predictions = validate_rnn.get_predictions(
                model_settings, model_file=model_file
            )
            pf.extend(predictions)
            _compute_metrics(predictions, model_settings)

    if settings.plot_lcs:
        get_plot_lcs(settings)

    if settings.calibration:
        sp.plot_calibration(settings, prediction_files=pf)

    if settings.plot_prediction_distribution:
        prediction_distribution.plot_prediction_distribution(settings)

    if settings.speed:
        validate_rnn.get_predictions_for_speed_benchmark(settings)


def show_action(settings):
    """Render the requested plots: lightcurves, prediction distributions
    and/or calibration curves."""
    from supernnova.visualization import prediction_distribution
    from supernnova.paper import superNNova_plots as sp

    # Calibration plots are built from previously saved prediction files
    if settings.calibration and not settings.prediction_files:
        raise ValueError("--calibration must be used with --prediction_files")

    if settings.plot_lcs:
        get_plot_lcs(settings)

    if settings.plot_prediction_distribution:
        prediction_distribution.plot_prediction_distribution(settings)

    if settings.calibration:
        sp.plot_calibration(settings)


def performance_action(settings):
    """Compute metrics from saved prediction files, optionally run the speed
    benchmark, and write a done-file marker on completion."""
    from supernnova.validation import validate_rnn, metrics

    # Metrics can only be computed from previously saved prediction files
    if settings.metrics and not settings.prediction_files:
        raise ValueError("--metrics should be used with --prediction_files")

    if settings.metrics:
        for prediction_file in settings.prediction_files:
            # TODO: need to make sure only rnn model file is allowed in this step
            model_type = "rf" if "randomforest" in prediction_file else "rnn"
            metrics.get_metrics_singlemodel(
                conf.get_settings_from_dump(settings, prediction_file),
                prediction_file=prediction_file,
                model_type=model_type,
            )
        lu.print_blue("Finished computing metrics")

    # NOTE(review): the aggregate-performance / paper-plot path is disabled:
    # if settings.performance:
    #     metrics.aggregate_metrics(settings)
    #     lu.print_blue("Finished aggregating performance")
    #     st.SuperNNova_stats_and_plots(settings)
    #     lu.print_blue("Finished assembling paper performance")

    # Speed benchmarks
    if settings.speed:
        validate_rnn.get_predictions_for_speed_benchmark(settings)

    # Signal completion to external pipelines watching for this file
    if settings.done_file:
        Path(settings.done_file).write_text("SUCCESS\n")


def main():
    """CLI entry point: parse the action, load settings, seed RNGs, dispatch."""
    if len(sys.argv) == 2 and sys.argv[1] == "--help":
        print_usage()
        return

    # Workaround for optparse limitation: insert "--" before the first
    # argument that looks like a negative number so it is not parsed as a flag.
    negative_number = re.compile("-[0-9]+")
    for position, token in enumerate(sys.argv):
        if negative_number.match(token):
            sys.argv.insert(position, "--")
            break

    dispatch = {
        "make_data": make_data_action,
        "train_rnn": train_rnn_action,
        "validate_rnn": validate_rnn_action,
        "show": show_action,
        "performance": performance_action,
    }

    action = get_action()

    # Load config parameters for the chosen action
    settings = conf.get_settings(action)

    # Seed all RNGs for reproducibility
    np.random.seed(settings.seed)

    import torch  # deferred: torch is slow to import

    torch.manual_seed(settings.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(settings.seed)

    dispatch[action](settings)


if __name__ == "__main__":
    main()
Loading

0 comments on commit 822b3f2

Please sign in to comment.