Skip to content

Commit

Permalink
implement checkpointing
Browse files Browse the repository at this point in the history
  • Loading branch information
brettinanl committed Apr 13, 2023
1 parent e790951 commit fa6641a
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions examples/ADRP/adrp_baseline_keras2.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,6 @@ def run(params):
seed = args.rng_seed
candle.set_seed(seed)

- # Construct extension to save model
- # ext = adrp.extension_from_parameters(params, ".keras")
- # params['save_path'] = './'+params['base_name']+'/'
- # candle.verify_path(params["save_path"])
-
- # prefix = "{}{}".format(params["save_path"], ext)
prefix = "{}".format(params["save_path"])
logfile = params["logfile"] if params["logfile"] else prefix + "TEST.log"
candle.set_up_logger(logfile, adrp.logger, params["verbose"])
Expand All @@ -259,7 +253,6 @@ def run(params):
# Get default parameters for initialization and optimizer functions
keras_defaults = candle.keras_default_config()

##
X_train, Y_train, X_test, Y_test, PS, count_array = adrp.load_data(params, seed)

print("X_train shape:", X_train.shape)
Expand Down Expand Up @@ -342,12 +335,20 @@ def run(params):

# set up a bunch of callbacks to do work during model training..

- checkpointer = ModelCheckpoint(
- filepath=params["save_path"] + "agg_adrp.autosave.model.h5",
- verbose=1,
- save_weights_only=False,
- save_best_only=True,
- )
+ #checkpointer = ModelCheckpoint(
+ # filepath=params["save_path"] + "agg_adrp.autosave.model.h5",
+ # verbose=1,
+ # save_weights_only=False,
+ # save_best_only=True,
+ #)
+ initial_epoch = 0
+ ckpt = candle.CandleCkptKeras(params, verbose=True)
+ ckpt.set_model(model)
+ J = ckpt.restart(model)
+ if J is not None:
+ initial_epoch = J["epoch"]
+ print("restarting from ckpt: initial_epoch: %i" % initial_epoch)

csv_logger = CSVLogger(params["save_path"] + "agg_adrp.training.log")

# min_lr = params['learning_rate']*params['reduce_ratio']
Expand Down Expand Up @@ -456,8 +457,9 @@ def run(params):
verbose=1,
sample_weight=train_weight,
validation_data=(X_test, Y_test, test_weight),
- callbacks=[checkpointer, timeout_monitor, csv_logger, reduce_lr, early_stop],
+ callbacks=[ckpt, timeout_monitor, csv_logger, reduce_lr, early_stop],
  )
+ ckpt.report_final()

print("Reloading saved best model")
model.load_weights(params["save_path"] + "agg_adrp.autosave.model.h5")
Expand Down

0 comments on commit fa6641a

Please sign in to comment.