
Commit 35fba41

cleaned up comments
won-bae committed Aug 9, 2024
1 parent c271c3f commit 35fba41
Showing 6 changed files with 0 additions and 31 deletions.
src/models/cls_module.py: 10 changes, 0 additions & 10 deletions

@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 
-
 import torch
 from lightning import LightningModule
 from torchmetrics import MinMetric, MaxMetric, MeanMetric
@@ -17,7 +16,6 @@
 from src.utils.metrics import MeanMetricWithCount
 
 
-
 class CLSLitModule(LightningModule):
     """Example of LightningModule for MNIST classification.
@@ -85,13 +83,6 @@ def __init__(
         self.dataset = None
         self.tuning = False
 
-        self.indices_to_check = [26291, 31336, 36921, 41564, 49741]
-        self.idx0_train_loss = MeanMetric()
-        self.idx1_train_loss = MeanMetric()
-        self.idx2_train_loss = MeanMetric()
-        self.idx3_train_loss = MeanMetric()
-        self.idx4_train_loss = MeanMetric()
-
     def on_train_start(self):
         # by default lightning executes validation step sanity checks before training starts,
         # so it's worth to make sure validation metrics don't store results from these checks
@@ -214,7 +205,6 @@ def configure_optimizers(self):
         https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
         """
         if self.tuning:
-            #assert self.dataset is not None
             train_parameters = []
             for name, params in self.net.named_parameters():
                 if name.startswith('fc.'):
src/models/components/criterion.py: 2 changes, 0 additions & 2 deletions

@@ -184,7 +184,6 @@ def forward(
         losses = self.common_step(output_dict, input_dict)
         if self.training:
             adjusted_loss = (torch.mean(losses) - self.flood_level).abs() + self.flood_level
-            #import IPython; IPython.embed()
         else:
             adjusted_loss = torch.mean(losses)
         return {constants.LOSS: adjusted_loss, constants.LOSSES: losses}
@@ -241,7 +240,6 @@ def forward(
 
         losses = self.common_step(output_dict, input_dict)
         if self.training:
-            #probs = torch.softmax(output_dict[constants.AUX_LOGITS], dim=1)
             aux_eval_losses = self.aux_step(
                 output_dict, input_dict, 'aux')
 
src/models/tpp/prob_dists.py: 5 changes, 0 additions & 5 deletions

@@ -74,11 +74,6 @@ def forward(self, histories, log_times, masks):
         masks_without_last[torch.arange(batch_size), last_event_idx] = 0
         event_ll = (log_probs * masks_without_last).sum((1, 2)) - log_times.sum((1,2)) # (B,)
 
-        #if constants.LOG_MEAN in shared_data and constants.LOG_STD in shared_data:
-        #    log_mean = shared_data[constants.LOG_MEAN]
-        #    log_std = shared_data[constants.LOG_STD]
-        #else:
-        #    log_mean, log_std = 0.0, 1.0
         log_mean, log_std = 0.0, 1.0
 
         # compute predictions
src/models/tpp/thp/models.py: 9 changes, 0 additions & 9 deletions

@@ -6,10 +6,8 @@
 
 
 import math
-import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from src import constants
 from src.models.tpp.thp.layers import EncoderLayer, CrossAttnLayer
@@ -153,13 +151,6 @@ def __init__(
         nn.init.xavier_uniform_(self.query_proj.weight)
         nn.init.xavier_uniform_(self.key_proj.weight)
 
-        #self.attention = ScaledDotProductAttention(
-        #    temperature=d_k ** 0.5, attn_dropout=dropout)
-
-        #self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
-        #self.dropout = nn.Dropout(dropout)
-
-
     def temporal_enc(self, time, non_pad_mask):
         """
         Input: batch*seq_len.
src/models/tpp/tpp_aux_network.py: 4 changes, 0 additions & 4 deletions

@@ -125,7 +125,6 @@ def __init__(self, name, activation, num_classes, d_model=256, d_inner=1024,
             num_latent=num_latent, vi_method=vi_method, num_z_samples=num_z_samples, compute_acc=compute_acc)
 
         aux_loss_path = aux_logit_path.replace('_logits', '_losses')
-        #aux_pred_path = aux_logit_path.replace('_logits', '_preds')
         aux_mu_path = aux_logit_path.replace('_logits', '_mus')
         aux_sigma_path = aux_logit_path.replace('_logits', '_sigmas')
         aux_log_weight_path = aux_logit_path.replace('_logits', '_log_weights')
@@ -141,9 +140,6 @@
         with open(aux_loss_path, "rb") as f:
             self.loss_dict = pickle.load(f)
 
-        #with open(aux_pred_path, "rb") as f:
-        #    self.pred_dict = pickle.load(f)
-
         with open(aux_mu_path, "rb") as f:
             self.mu_dict = pickle.load(f)
 
src/models/tpp_module.py: 1 change, 0 additions & 1 deletion

@@ -155,7 +155,6 @@ def on_validation_epoch_end(self):
         self.log("val/rmse_with_nll_best", self.val_rmse_with_nll_best.compute(), sync_dist=True, prog_bar=True)
         self.log("val/acc_with_nll_best", self.val_acc_with_nll_best.compute(), sync_dist=True, prog_bar=True)
 
-        #if val_rmse < prev_val_rmse_best:
         if val_rmse == val_rmse_best:
             self.val_nll_with_rmse_best.reset()
             self.val_nll_with_rmse_best(val_nll)
