Adding example CNN to the set of models #4

Merged (1 commit, Sep 17, 2024)
72 changes: 72 additions & 0 deletions example_config.toml
@@ -0,0 +1,72 @@
[general]
use_gpu = true

# Destination of log messages
# 'stderr' and 'stdout' specify the console.
log_destination = "stderr"
# A path name specifies a file, e.g.
# log_destination = "fibad_log.txt"

# Lowest log level to emit.
# Each level further down the list makes fibad more verbose in the log.
#
# log_level = "critical" # Only emit the most severe of errors
# log_level = "error" # Emit all errors
# log_level = "warning" # Emit warnings and all errors
log_level = "info" # Emit informational messages, warnings and all errors
# log_level = "debug" # Very verbose, emit all log messages.

[download]
sw = "22asec"
sh = "22asec"
filter = ["HSC-G", "HSC-R", "HSC-I", "HSC-Z", "HSC-Y"]
type = "coadd"
rerun = "pdr3_wide"
username = "mtauraso@local"
password = "cCw+nX53lmNLHMy+JbizpH/dl4t7sxljiNm6a7k1"
max_connections = 2
fits_file = "../hscplay/temp.fits"
cutout_dir = "../hscplay/cutouts/"
offset = 0
num_sources = 500

# These control the downloader's HTTP requests and retries
# (a sketch of the assumed retry behavior appears after this file).
# `retry_wait` Seconds to wait before retrying a failed HTTP request. Default 30s
retry_wait = 30
# `retries` How many times to retry a failed HTTP request before moving on to the next one. Default 3 times
retries = 3
# `timeout` How long to wait for a full HTTP response from the server, in seconds. Default 3600s (1hr)
timeout = 3600
# `chunksize` How many sky-location rectangles to request in a single HTTP request. Default 990
chunksize = 990

[model]
# name = "ExampleCNN"
# name = "ExampleAutoencoder"

# An example of requesting an external model class by its dotted import path
# (a sketch of how such a path can be resolved appears after this file), e.g.
# external_cls = "user_package.submodule.ExternalModel"
external_cls = "kbmod_ml.models.cnn.CNN"

weights_filepath = "example_model.pth"
epochs = 10

[data_loader]
# Name of data loader to use
name = "CifarDataLoader"
# name = "HSCDataLoader"

# An example of requesting an external data loader class, e.g.
# external_cls = "user_package.submodule.ExternalDataLoader"

# Directory path where the data is stored
path = "/Users/drew/code/fibad/data/cifar-10-batches-py"
# path = "/Users/drew/code/fibad/data/hsc-samples"

# Standard PyTorch DataLoader parameters (see the DataLoader sketch after this file)
batch_size = 10
shuffle = true
num_workers = 10

[predict]
batch_size = 32
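The retry settings in the [download] section above map naturally onto a retry loop. Below is a minimal sketch of that behavior, assuming a requests-style HTTP client; fetch_with_retries and the POST call are illustrative, not fibad's actual downloader code.

import time

import requests  # assumed HTTP client; the real downloader may use something else


def fetch_with_retries(url, payload, retries=3, retry_wait=30, timeout=3600):
    """Hypothetical retry loop mirroring the [download] settings above."""
    last_error = None
    for attempt in range(1 + retries):  # one initial attempt plus `retries` retries
        try:
            response = requests.post(url, json=payload, timeout=timeout)
            response.raise_for_status()
            return response
        except requests.RequestException as err:
            last_error = err
            if attempt < retries:
                time.sleep(retry_wait)  # `retry_wait` seconds between attempts
    raise last_error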
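The external_cls keys above name a class by its dotted import path. A plausible way to resolve such a path, and presumably roughly what fibad's registry does internally, is standard importlib machinery; the helper name below is hypothetical.

import importlib


def resolve_external_cls(dotted_path):
    """Resolve a dotted "module.ClassName" path to the class object."""
    module_path, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


# e.g. resolve_external_cls("kbmod_ml.models.cnn.CNN") returns the CNN class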
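Likewise, the [data_loader] parameters correspond one-to-one with arguments of torch.utils.data.DataLoader. A sketch of that wiring, assuming a torchvision CIFAR-10 dataset (torchvision is not listed in the dependencies below, so this is purely illustrative):

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# torchvision expects `root` to be the directory *containing*
# cifar-10-batches-py, i.e. the parent of `path` in the config.
dataset = torchvision.datasets.CIFAR10(
    root="/Users/drew/code/fibad/data",
    train=True,
    download=False,
    transform=transforms.ToTensor(),
)
loader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=10)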
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -16,6 +16,8 @@ classifiers = [
dynamic = ["version"]
requires-python = ">=3.9"
dependencies = [
"torch", # PyTorch
# "fibad", when it is available on PyPI
]

[project.urls]
74 changes: 74 additions & 0 deletions src/kbmod_ml/models/cnn.py
@@ -0,0 +1,74 @@
# ruff: noqa: D101, D102

# This example model is taken from the PyTorch CIFAR10 tutorial:
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#define-a-convolutional-neural-network
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812
import torch.optim as optim
from fibad.models.model_registry import fibad_model

logger = logging.getLogger(__name__)


@fibad_model
class CNN(nn.Module):
    def __init__(self, model_config, shape):
        logger.info("This is an external model, not in FIBAD!!!")
        super().__init__()
        # Layer sizes assume 3-channel 32x32 inputs (CIFAR-10), which pool
        # down to 16 feature maps of 5x5 before the fully connected layers.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

        self.config = model_config

        # Optimizer and criterion could be set directly, i.e. `self.optimizer = optim.SGD(...)`,
        # but we define them as methods as a way to allow for more flexibility in the future.
        self.optimizer = self._optimizer()
        self.criterion = self._criterion()

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def train_step(self, batch):
        """Run a single training step, i.e. the contents of the inner loop
        of an ML training process.

        Parameters
        ----------
        batch : tuple
            A tuple containing the inputs and labels for the current batch.

        Returns
        -------
        dict
            A dictionary with the loss value for the current batch under
            the key "loss".
        """
        inputs, labels = batch

        self.optimizer.zero_grad()
        outputs = self(inputs)
        loss = self.criterion(outputs, labels)
        loss.backward()
        self.optimizer.step()
        return {"loss": loss.item()}

    def _criterion(self):
        return nn.CrossEntropyLoss()

    def _optimizer(self):
        return optim.SGD(self.parameters(), lr=0.001, momentum=0.9)

    def save(self):
        torch.save(self.state_dict(), self.config.get("weights_filepath", "example_cnn.pth"))
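For reference, a hypothetical smoke test of the model outside fibad's training loop; it assumes the fibad_model decorator leaves the constructor signature unchanged and that batches have CIFAR-10 shapes (3x32x32 images, 10 classes):

import torch

from kbmod_ml.models.cnn import CNN

model = CNN(model_config={"weights_filepath": "example_model.pth"}, shape=(3, 32, 32))
images = torch.randn(4, 3, 32, 32)         # fake batch of four 32x32 RGB images
labels = torch.randint(0, 10, (4,))        # fake class labels
print(model.train_step((images, labels)))  # e.g. {'loss': 2.3}
model.save()                               # writes example_model.pth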