RMSprop by default; More flexible training; device bugfix; more hyperparams in yml; renaming to weight_decay (#36)

* Add optimizer_of_name function for dynamic optimizer creation
* save more hparams in yml
* adding cli to netam for concat_csvs 😂
* device bugfix for SHM model
* RMSprop by default!
* using UNIX time for walltime, not hours
* changing l2_regularization_coeff to weight_decay
* making branch length optimization optional
matsen authored Jun 14, 2024
1 parent 29859d9 commit 9c1352c
Showing 6 changed files with 104 additions and 28 deletions.
2 changes: 1 addition & 1 deletion data/cnn_joi_sml-shmoof_small.yml
@@ -11,6 +11,6 @@ model_hyperparameters:
kmer_length: 3
serialization_version: 0
training_hyperparameters:
l2_regularization_coeff: 1.0e-06
weight_decay: 1.0e-06
learning_rate: 0.1
min_learning_rate: 1.0e-06
35 changes: 35 additions & 0 deletions netam/cli.py
@@ -0,0 +1,35 @@
import fire
import pandas as pd


def concatenate_csvs(
    input_csvs_str: str,
    output_csv: str,
    is_tsv: bool = False,
    record_path: bool = False,
):
"""
This function concatenates multiple CSV or TSV files into one CSV file.
Args:
input_csvs: A string of paths to the input CSV or TSV files separated by commas.
output_csv: Path to the output CSV file.
is_tsv: A boolean flag that determines whether the input files are TSV.
record_path: A boolean flag that adds a column recording the path of the input_csv.
"""
    input_csvs = input_csvs_str.split(",")
    dfs = []

    for csv in input_csvs:
        df = pd.read_csv(csv, delimiter="\t" if is_tsv else ",")
        if record_path:
            df["input_file_path"] = csv
        dfs.append(df)

    result_df = pd.concat(dfs, ignore_index=True)

    result_df.to_csv(output_csv, index=False)


def main():
    fire.Fire()
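
For reference, a minimal usage sketch of the new helper, called directly from Python rather than through the Fire command line (the file paths here are hypothetical):

from netam.cli import concatenate_csvs

# Concatenate two hypothetical CSV files into one, recording each row's source file.
concatenate_csvs(
    "data/run1.csv,data/run2.csv",
    "data/combined.csv",
    is_tsv=False,
    record_path=True,
)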
18 changes: 18 additions & 0 deletions netam/common.py
@@ -142,6 +142,24 @@ def stack_heterogeneous(tensors, pad_value=0.0):
    return torch.stack(padded_tensors)


def optimizer_of_name(optimizer_name, model_parameters, **kwargs):
    """
    Build a torch.optim optimizer from a string name and model parameters.
    Use an SGD optimizer with momentum of 0.9 if optimizer_name is "SGDMomentum".
    """
    if optimizer_name == "SGDMomentum":
        optimizer_name = "SGD"
        kwargs["momentum"] = 0.9
    try:
        optimizer_class = getattr(optim, optimizer_name)
        return optimizer_class(model_parameters, **kwargs)
    except AttributeError:
        raise ValueError(
            f"Optimizer '{optimizer_name}' is not recognized in torch.optim"
        )


def find_least_used_cuda_gpu():
    """
    Find the least used CUDA GPU on the system using nvidia-smi.
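
For reference, a minimal sketch of how the new optimizer_of_name helper can be used; the toy model is an illustrative assumption, and "SGDMomentum" maps to torch.optim.SGD with momentum=0.9:

import torch
from netam.common import optimizer_of_name

model = torch.nn.Linear(4, 1)  # toy model for illustration
# Any torch.optim class name works, e.g. the new RMSprop default.
opt = optimizer_of_name("RMSprop", model.parameters(), lr=0.1, weight_decay=1e-6)
sgd = optimizer_of_name("SGDMomentum", model.parameters(), lr=0.01)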
8 changes: 6 additions & 2 deletions netam/dnsm.py
@@ -420,7 +420,11 @@ def to_crepe(self):
training_hyperparameters = {
key: self.__dict__[key]
for key in [
"optimizer_name",
"batch_size",
"learning_rate",
"min_learning_rate",
"weight_decay",
]
}
encoder = framework.PlaceholderEncoder()
@@ -444,7 +448,7 @@ def burrito_of_model(
batch_size=1024,
learning_rate=0.1,
min_learning_rate=1e-4,
l2_regularization_coeff=1e-6,
weight_decay=1e-6,
):
model.to(device)
burrito = DNSMBurrito(
@@ -454,6 +458,6 @@
batch_size=batch_size,
learning_rate=learning_rate,
min_learning_rate=min_learning_rate,
l2_regularization_coeff=l2_regularization_coeff,
weight_decay=weight_decay,
)
return burrito
65 changes: 42 additions & 23 deletions netam/framework.py
@@ -19,6 +19,7 @@
generate_kmers,
kmer_to_index_of,
nt_mask_tensor_of,
optimizer_of_name,
BASES,
BASES_AND_N_TO_INDEX,
BIG,
@@ -322,7 +323,7 @@ def load_crepe(prefix, device=None):
model.eval()

crepe_instance = Crepe(encoder, model, config["training_hyperparameters"])
if device:
if device is not None:
crepe_instance.to(device)

return crepe_instance
@@ -371,10 +372,11 @@ def __init__(
train_dataset,
val_dataset,
model,
optimizer_name="RMSprop",
batch_size=1024,
learning_rate=0.1,
min_learning_rate=1e-4,
l2_regularization_coeff=1e-6,
weight_decay=1e-6,
name="",
):
"""
@@ -383,15 +385,16 @@
"""
self.train_dataset = train_dataset
self.val_dataset = val_dataset
self.batch_size = batch_size
if train_dataset is not None:
self.writer = SummaryWriter(log_dir=f"./_logs/{name}")
self.writer.add_text("model_name", model.__class__.__name__)
self.writer.add_text("model_hyperparameters", str(model.hyperparameters))
self.model = model
self.optimizer_name = optimizer_name
self.batch_size = batch_size
self.learning_rate = learning_rate
self.min_learning_rate = min_learning_rate
self.l2_regularization_coeff = l2_regularization_coeff
self.weight_decay = weight_decay
self.name = name
self.reset_optimization()
self.bce_loss = nn.BCELoss()
@@ -417,20 +420,22 @@ def reset_optimization(self, learning_rate=None):
"""Reset the optimizer and scheduler."""
if learning_rate is None:
learning_rate = self.learning_rate
self.optimizer = torch.optim.AdamW(

self.optimizer = optimizer_of_name(
self.optimizer_name,
self.model.parameters(),
lr=learning_rate,
weight_decay=self.l2_regularization_coeff,
weight_decay=self.weight_decay,
)
self.scheduler = ReduceLROnPlateau(
self.optimizer, mode="min", factor=0.5, patience=10
)

def execution_hours(self):
def execution_time(self):
"""
Return time in hours (rounded to 3 decimal places) since the Burrito was created.
Return the time in seconds since the Burrito was created.
"""
return round((time() - self.start_time) / 3600, 3)
return time() - self.start_time

def multi_train(self, epochs, max_tries=3):
"""
@@ -451,7 +456,7 @@ def multi_train(self, epochs, max_tries=3):
return train_history

def write_loss(self, loss_name, loss, step):
self.writer.add_scalar(loss_name, loss, step, walltime=self.execution_hours())
self.writer.add_scalar(loss_name, loss, step, walltime=self.execution_time())

def write_cuda_memory_info(self):
megabyte_scaling_factor = 1 / 1024**2
@@ -690,11 +695,16 @@ def mark_branch_lengths_optimized(self, cycle):
"branch length optimization",
cycle,
self.global_epoch,
walltime=self.execution_hours(),
walltime=self.execution_time(),
)

def joint_train(
self, epochs=100, cycle_count=2, training_method="full", out_prefix=None
self,
epochs=100,
cycle_count=2,
training_method="full",
out_prefix=None,
optimize_bl_first_cycle=True,
):
"""
Do joint optimization of model and branch lengths.
@@ -703,6 +713,10 @@
If training_method is "yun", then we use Yun's approximation to the branch lengths.
If training_method is "fixed", then we fix the branch lengths and only optimize the model.
We give an option to optimize the branch lengths in the first cycle (on by
default). Turning this off can be useful, e.g., if we have loaded in
preoptimized branch lengths.
We reset the optimization after each cycle, and we use a learning rate
schedule that uses a weighted geometric mean of the current learning
rate and the initial learning rate that progressively moves towards
@@ -717,10 +731,13 @@
else:
raise ValueError(f"Unknown training method {training_method}")
loss_history_l = []
optimize_branch_lengths()
if optimize_bl_first_cycle:
optimize_branch_lengths()
self.mark_branch_lengths_optimized(0)
for cycle in range(cycle_count):
print(f"### Beginning cycle {cycle + 1}/{cycle_count}")
print(
f"### Beginning cycle {cycle + 1}/{cycle_count} using optimizer {self.optimizer_name}"
)
self.mark_branch_lengths_optimized(cycle + 1)
current_lr = self.optimizer.param_groups[0]["lr"]
# set new_lr to be the geometric mean of current_lr and the
@@ -752,21 +769,23 @@ def __init__(
train_dataset,
val_dataset,
model,
optimizer_name="RMSprop",
batch_size=1024,
learning_rate=0.1,
min_learning_rate=1e-4,
l2_regularization_coeff=1e-6,
weight_decay=1e-6,
name="",
):
super().__init__(
train_dataset,
val_dataset,
model,
batch_size,
learning_rate,
min_learning_rate,
l2_regularization_coeff,
name,
optimizer_name=optimizer_name,
batch_size=batch_size,
learning_rate=learning_rate,
min_learning_rate=min_learning_rate,
weight_decay=weight_decay,
name=name,
)

def loss_of_batch(self, batch):
@@ -820,7 +839,7 @@ def to_crepe(self):
for key in [
"learning_rate",
"min_learning_rate",
"l2_regularization_coeff",
"weight_decay",
]
}
encoder = KmerSequenceEncoder(
@@ -958,10 +977,10 @@ def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
def write_loss(self, loss_name, loss, step):
rate_loss, csp_loss = loss.unbind()
self.writer.add_scalar(
"Rate " + loss_name, rate_loss.item(), step, walltime=self.execution_hours()
"Rate " + loss_name, rate_loss.item(), step, walltime=self.execution_time()
)
self.writer.add_scalar(
"CSP " + loss_name, csp_loss.item(), step, walltime=self.execution_hours()
"CSP " + loss_name, csp_loss.item(), step, walltime=self.execution_time()
)


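For intuition, a small sketch of the learning-rate blending that the joint_train docstring describes: a weighted geometric mean of the current and initial learning rates. The weight below is an illustrative assumption, not the value used in the commit (those lines are collapsed in this diff):

def blended_lr(current_lr, initial_lr, weight=0.5):
    # Weighted geometric mean: current_lr ** (1 - weight) * initial_lr ** weight.
    return current_lr ** (1 - weight) * initial_lr ** weight

# Example: a decayed rate of 0.0125 blended with an initial rate of 0.1 gives ~0.035.
new_lr = blended_lr(0.0125, 0.1)
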
4 changes: 2 additions & 2 deletions netam/hyper_burrito.py
@@ -181,7 +181,7 @@ def burrito_of_model(
batch_size=1024,
learning_rate=0.1,
min_learning_rate=1e-4,
l2_regularization_coeff=1e-6,
weight_decay=1e-6,
):
burrito = SHMBurrito(
self.train_dataset,
@@ -190,6 +190,6 @@
batch_size=batch_size,
learning_rate=learning_rate,
min_learning_rate=min_learning_rate,
l2_regularization_coeff=l2_regularization_coeff,
weight_decay=weight_decay,
)
return burrito
