Fix accuracy metric for logit output.
Fixes: tensorflow/tensorflow#41413
PiperOrigin-RevId: 555773892
MarkDaoust authored and copybara-github committed Aug 11, 2023
1 parent 1637c45 commit 51a06aa
Showing 1 changed file with 41 additions and 20 deletions.
61 changes: 41 additions & 20 deletions site/en/tutorials/images/transfer_learning.ipynb
@@ -516,9 +516,9 @@
 "source": [
 "### Important note about BatchNormalization layers\n",
 "\n",
-"Many models contain `tf.keras.layers.BatchNormalization` layers. This layer is a special case and precautions should be taken in the context of fine-tuning, as shown later in this tutorial. \n",
+"Many models contain `tf.keras.layers.BatchNormalization` layers. This layer is a special case and precautions should be taken in the context of fine-tuning, as shown later in this tutorial.\n",
 "\n",
-"When you set `layer.trainable = False`, the `BatchNormalization` layer will run in inference mode, and will not update its mean and variance statistics. \n",
+"When you set `layer.trainable = False`, the `BatchNormalization` layer will run in inference mode, and will not update its mean and variance statistics.\n",
 "\n",
 "When you unfreeze a model that contains BatchNormalization layers in order to do fine-tuning, you should keep the BatchNormalization layers in inference mode by passing `training = False` when calling the base model. Otherwise, the updates applied to the non-trainable weights will destroy what the model has learned.\n",
 "\n",
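As an aside (not part of the commit), the BatchNormalization note above can be checked with a short standalone sketch: once `trainable = False` is set, the layer runs in inference mode even when called with `training=True`, so its moving statistics stay frozen.

```python
import numpy as np
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
x = tf.random.normal((8, 4))

# First call builds the layer and, since training=True and the layer is
# trainable, updates moving_mean / moving_variance.
bn(x, training=True)

bn.trainable = False
frozen_mean = bn.moving_mean.numpy().copy()

# With trainable=False, BatchNormalization runs in inference mode even
# though training=True is passed, so its statistics are not updated.
bn(x, training=True)

print(np.allclose(frozen_mean, bn.moving_mean.numpy()))  # True
```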
@@ -617,60 +617,71 @@
 "model = tf.keras.Model(inputs, outputs)"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "I8ARiyMFsgbH"
+},
+"outputs": [],
+"source": [
+"model.summary()"
+]
+},
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "g0ylJXE_kRLi"
+"id": "lxOcmVr0ydFZ"
 },
 "source": [
-"### Compile the model\n",
-"\n",
-"Compile the model before training it. Since there are two classes, use the `tf.keras.losses.BinaryCrossentropy` loss with `from_logits=True` since the model provides a linear output."
+"The 8+ million parameters in MobileNet are frozen, but there are 1.2 thousand _trainable_ parameters in the Dense layer. These are divided between two `tf.Variable` objects, the weights and biases."
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"id": "RpR8HdyMhukJ"
+"id": "krvBumovycVA"
 },
 "outputs": [],
 "source": [
-"base_learning_rate = 0.0001\n",
-"model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),\n",
-" loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n",
-" metrics=['accuracy'])"
+"len(model.trainable_variables)"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"id": "I8ARiyMFsgbH"
+"id": "jeGk93R2ahav"
 },
 "outputs": [],
 "source": [
-"model.summary()"
+"tf.keras.utils.plot_model(model, show_shapes=True)"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"id": "lxOcmVr0ydFZ"
+"id": "g0ylJXE_kRLi"
 },
 "source": [
-"The 2.5 million parameters in MobileNet are frozen, but there are 1.2 thousand _trainable_ parameters in the Dense layer. These are divided between two `tf.Variable` objects, the weights and biases."
+"### Compile the model\n",
+"\n",
+"Compile the model before training it. Since there are two classes, use the `tf.keras.losses.BinaryCrossentropy` loss with `from_logits=True` since the model provides a linear output."
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"id": "krvBumovycVA"
+"id": "RpR8HdyMhukJ"
 },
 "outputs": [],
 "source": [
-"len(model.trainable_variables)"
+"base_learning_rate = 0.0001\n",
+"model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),\n",
+" loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n",
+" metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0, name='accuracy')])"
 ]
 },
 {
@@ -681,7 +692,7 @@
 "source": [
 "### Train the model\n",
 "\n",
-"After training for 10 epochs, you should see ~94% accuracy on the validation set.\n"
+"After training for 10 epochs, you should see ~96% accuracy on the validation set.\n"
 ]
 },
 {
@@ -863,7 +874,7 @@
 "source": [
 "model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n",
 " optimizer = tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate/10),\n",
-" metrics=['accuracy'])"
+" metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0, name='accuracy')])"
 ]
 },
 {
@@ -1070,13 +1081,23 @@
 "\n",
 "To learn more, visit the [Transfer learning guide](https://www.tensorflow.org/guide/keras/transfer_learning).\n"
 ]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "uKIByL01da8c"
+},
+"outputs": [],
+"source": []
 }
 ],
 "metadata": {
 "accelerator": "GPU",
 "colab": {
 "collapsed_sections": [],
 "name": "transfer_learning.ipynb",
 "private_outputs": true,
 "provenance": [],
 "toc_visible": true
 },
 "kernelspec": {

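For context on the fix referenced by tensorflow/tensorflow#41413: the string alias `'accuracy'` used with a binary crossentropy loss resolves to `BinaryAccuracy` with its default threshold of 0.5, which assumes sigmoid probabilities. This model outputs raw logits, where the correct decision boundary is a logit of 0 (since sigmoid(0) = 0.5), so the old metric miscounted examples whose logits fall between 0 and 0.5. A minimal standalone sketch of the difference:

```python
import tensorflow as tf

# Labels and raw (pre-sigmoid) model outputs; all four logits are on the
# correct side of the logit-space decision boundary x = 0.
y_true = tf.constant([0.0, 0.0, 1.0, 1.0])
logits = tf.constant([-2.0, -0.5, 0.5, 2.0])

# Default threshold of 0.5 is meant for probabilities: the correct positive
# logit 0.5 is not > 0.5, so it is scored as class 0 and counted wrong.
prob_threshold_acc = tf.keras.metrics.BinaryAccuracy()
prob_threshold_acc.update_state(y_true, logits)
print(prob_threshold_acc.result().numpy())  # 0.75

# threshold=0 places the cut at logit 0, matching sigmoid's 0.5 boundary.
logit_threshold_acc = tf.keras.metrics.BinaryAccuracy(threshold=0)
logit_threshold_acc.update_state(y_true, logits)
print(logit_threshold_acc.result().numpy())  # 1.0
```

This is why the commit replaces `metrics=['accuracy']` with `metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0, name='accuracy')]` in both `compile` calls.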