diff --git a/docs/tutorials/text_generation.ipynb b/docs/tutorials/text_generation.ipynb index a427038df..fb2f6813f 100644 --- a/docs/tutorials/text_generation.ipynb +++ b/docs/tutorials/text_generation.ipynb @@ -160,7 +160,9 @@ "\n", "import numpy as np\n", "import os\n", - "import time" + "import time", + "!pip install tf-keras", + "import tf_keras as keras" ] }, { @@ -182,7 +184,7 @@ }, "outputs": [], "source": [ - "path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')" + "path_to_file = keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')" ] }, { @@ -288,7 +290,7 @@ }, "outputs": [], "source": [ - "ids_from_chars = tf.keras.layers.StringLookup(\n", + "ids_from_chars = keras.layers.StringLookup(\n", " vocabulary=list(vocab), mask_token=None)" ] }, @@ -339,7 +341,7 @@ }, "outputs": [], "source": [ - "chars_from_ids = tf.keras.layers.StringLookup(\n", + "chars_from_ids = keras.layers.StringLookup(\n", " vocabulary=ids_from_chars.get_vocabulary(), invert=True, mask_token=None)" ] }, @@ -671,14 +673,14 @@ }, "outputs": [], "source": [ - "class MyModel(tf.keras.Model):\n", + "class MyModel(keras.Model):\n", " def __init__(self, vocab_size, embedding_dim, rnn_units):\n", " super().__init__(self)\n", - " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", - " self.gru = tf.keras.layers.GRU(rnn_units,\n", + " self.embedding = keras.layers.Embedding(vocab_size, embedding_dim)\n", + " self.gru = keras.layers.GRU(rnn_units,\n", " return_sequences=True,\n", " return_state=True)\n", - " self.dense = tf.keras.layers.Dense(vocab_size)\n", + " self.dense = keras.layers.Dense(vocab_size)\n", "\n", " def call(self, inputs, states=None, return_state=False, training=False):\n", " x = inputs\n", @@ -974,7 +976,7 @@ "# Name of the checkpoint files\n", "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_{epoch}\")\n", "\n", - "checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(\n", + "checkpoint_callback = keras.callbacks.ModelCheckpoint(\n", " filepath=checkpoint_prefix,\n", " save_weights_only=True)" ] @@ -1058,7 +1060,7 @@ }, "outputs": [], "source": [ - "class OneStep(tf.keras.Model):\n", + "class OneStep(keras.Model):\n", " def __init__(self, model, chars_from_ids, ids_from_chars, temperature=1.0):\n", " super().__init__()\n", " self.temperature = temperature\n", @@ -1306,8 +1308,8 @@ }, "outputs": [], "source": [ - "model.compile(optimizer = tf.keras.optimizers.Adam(),\n", - " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))" + "model.compile(optimizer = keras.optimizers.Adam(),\n", + " loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True))" ] }, { @@ -1345,7 +1347,7 @@ "for epoch in range(EPOCHS):\n", " start = time.time()\n", "\n", - " mean.reset_states()\n", + " mean.reset_state()\n", " for (batch_n, (inp, target)) in enumerate(dataset):\n", " logs = model.train_step([inp, target])\n", " mean.update_state(logs['loss'])\n",