From 1f53d0a63b45742c29d7ba132531afb1c313679d Mon Sep 17 00:00:00 2001 From: mahsanghani Date: Wed, 11 Sep 2024 13:42:43 -0400 Subject: [PATCH 1/4] save checkpoint as .weights.h5 --- site/en/tutorials/distribute/keras.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/site/en/tutorials/distribute/keras.ipynb b/site/en/tutorials/distribute/keras.ipynb index 1b6311a822c..2405a0fc0ef 100644 --- a/site/en/tutorials/distribute/keras.ipynb +++ b/site/en/tutorials/distribute/keras.ipynb @@ -363,7 +363,7 @@ "# Define the checkpoint directory to store the checkpoints.\n", "checkpoint_dir = './training_checkpoints'\n", "# Define the name of the checkpoint files.\n", - "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_{epoch}\")" + "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_{epoch}.weights.h5\")" ] }, { @@ -396,7 +396,7 @@ "# Define a callback for printing the learning rate at the end of each epoch.\n", "class PrintLR(tf.keras.callbacks.Callback):\n", " def on_epoch_end(self, epoch, logs=None):\n", - " print('\\nLearning rate for epoch {} is {}'.format( epoch + 1, model.optimizer.lr.numpy()))" + " print('\\nLearning rate for epoch {} is {}'.format(epoch + 1, model.optimizer.lr.numpy()))" ] }, { From f1ba57adda5d88704bae2e34263c1de54dfeca67 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Wed, 11 Sep 2024 15:20:29 -0700 Subject: [PATCH 2/4] Fix lr -> learning_rate --- site/en/tutorials/distribute/keras.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/site/en/tutorials/distribute/keras.ipynb b/site/en/tutorials/distribute/keras.ipynb index 2405a0fc0ef..7a598ed4a84 100644 --- a/site/en/tutorials/distribute/keras.ipynb +++ b/site/en/tutorials/distribute/keras.ipynb @@ -396,7 +396,7 @@ "# Define a callback for printing the learning rate at the end of each epoch.\n", "class PrintLR(tf.keras.callbacks.Callback):\n", " def on_epoch_end(self, epoch, logs=None):\n", - " print('\\nLearning rate for epoch {} is {}'.format(epoch + 1, model.optimizer.lr.numpy()))" + " print('\\nLearning rate for epoch {} is {}'.format(epoch + 1, model.optimizer.learning_rate.numpy()))" ] }, { From 774fbc855a7869877f915754e36e23546d70ab39 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Wed, 11 Sep 2024 15:23:12 -0700 Subject: [PATCH 3/4] Update site/en/tutorials/distribute/keras.ipynb Make names sortable. --- site/en/tutorials/distribute/keras.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/site/en/tutorials/distribute/keras.ipynb b/site/en/tutorials/distribute/keras.ipynb index 7a598ed4a84..2aa3f0a45fd 100644 --- a/site/en/tutorials/distribute/keras.ipynb +++ b/site/en/tutorials/distribute/keras.ipynb @@ -363,7 +363,7 @@ "# Define the checkpoint directory to store the checkpoints.\n", "checkpoint_dir = './training_checkpoints'\n", "# Define the name of the checkpoint files.\n", - "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_{epoch}.weights.h5\")" + "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_{epoch:04d}.weights.h5\")" ] }, { From 9555b66055f936ed5c2536353ebd70d6e39d86b5 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Wed, 11 Sep 2024 15:29:08 -0700 Subject: [PATCH 4/4] implement latest_checkpoint. --- site/en/tutorials/distribute/keras.ipynb | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/site/en/tutorials/distribute/keras.ipynb b/site/en/tutorials/distribute/keras.ipynb index 2aa3f0a45fd..b96656d4436 100644 --- a/site/en/tutorials/distribute/keras.ipynb +++ b/site/en/tutorials/distribute/keras.ipynb @@ -486,7 +486,10 @@ }, "outputs": [], "source": [ - "model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))\n", + "import pathlib\n", + "latest_checkpoint = sorted(pathlib.Path(checkpoint_dir).glob('*'))[-1]\n", + "\n", + "model.load_weights(latest_checkpoint)\n", "\n", "eval_loss, eval_acc = model.evaluate(eval_dataset)\n", "\n",