Chapter 4.1: The Vanilla LSTM
A simple LSTM configuration is the Vanilla LSTM. It is named Vanilla in this book to differentiate it from deeper LSTMs and the suite of more elaborate configurations. It is the LSTM architecture defined in the original 1997 LSTM paper and the architecture that will give good results on most small sequence prediction problems. The Vanilla LSTM is defined as:
- Input layer.
- Fully connected LSTM hidden layer.
- Fully connected output layer.
In Keras, a Vanilla LSTM is defined below, with ellipsis for the specific configuration of the number of neurons in each layer.
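A minimal sketch of such a definition might look like the following (an illustrative skeleton rather than runnable code), with the Python `...` placeholder standing in for the number of memory cells, the input shape, and the number of output neurons:

```python
from keras.models import Sequential
from keras.layers import LSTM, Dense

# skeleton of a Vanilla LSTM: one fully connected LSTM hidden layer
# feeding one fully connected (Dense) output layer
model = Sequential()
model.add(LSTM(..., input_shape=(...)))  # ... = number of memory cells and input dimensions
model.add(Dense(...))                    # ... = number of output neurons
```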
This is the default or standard LSTM referenced in much work and discussion on LSTMs in deep learning and a good starting point when getting started with LSTMs on your sequence prediction problem. The Vanilla LSTM has the following 5 attractive properties, most of which were demonstrated in the original paper:
- Sequence classification conditional on multiple distributed input time steps.
- Memory of precise input observations over thousands of time steps.
- Sequence prediction as a function of prior time steps.
- Robust to the insertion of random time steps on the input sequence.
- Robust to the placement of signal data on the input sequence.
Next, we will define a simple sequence prediction problem that we can later use to demonstrate the Vanilla LSTM.
The echo sequence prediction problem is a contrived problem for demonstrating the memory capability of the Vanilla LSTM. The task is, given a sequence of random integers as input, to output the value of the integer at a specific input time step, where that time step is not given to the model as an input.
For example, given the input sequence of random integers [5, 3, 2], if the chosen time step is the second value, then the expected output is 3. Technically, this is a sequence classification problem; it is formulated as a many-to-one prediction problem, where there are multiple input time steps and one output time step at the end of the sequence.
This problem was carefully chosen to demonstrate the memory capability of the Vanilla LSTM. Further, we will manually perform some of the elements of the model life-cycle such as fitting and evaluating the model to get a deeper feeling for what is happening under the covers. Next, we will develop code to generate examples of this problem. This involves the following steps:
- Generate Random Sequences.
- One Hot Encode Sequences.
- Worked Example.
- Reshape Sequences.
We can generate random integers in Python using the randint() function that takes two parameters indicating the range of integers from which to draw values. In this lesson, we will define the problem as having integer values between 0 and 99 with 100 unique values.
We can put this in a function called generate_sequence() that will generate a sequence of random integers of the desired length. This function is listed below.
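A sketch of this function might look like the following, assuming the default length of 25 and the 100 unique values described above:

```python
from random import randint

# generate a sequence of random integers in [0, n_features)
def generate_sequence(length=25, n_features=100):
    return [randint(0, n_features - 1) for _ in range(length)]
```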
Once we have generated sequences of random integers, we need to transform them into a format that is suitable for training an LSTM network. One option would be to rescale the integers to the range [0,1]. This would work, but it would require that the problem be phrased as regression, whereas we are interested in predicting exactly the right number, not a number close to the expected value.
This means we would prefer to frame the problem as classification rather than regression, where the expected output is a class and there are 100 possible class values. In this case, we can use a one hot encoding of the integer values where each value is represented by a 100-element binary vector that is all 0 values except the index of the integer, which is marked 1.
The function below, called one_hot_encode(), defines how to iterate over a sequence of integers, create a binary vector representation for each, and return the result as a 2-dimensional array.
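A sketch of this encoding function, assuming NumPy is available, might look like:

```python
from numpy import array

# one hot encode a sequence of integers as a 2D array of binary vectors
def one_hot_encode(sequence, n_features=100):
    encoding = list()
    for value in sequence:
        vector = [0 for _ in range(n_features)]
        vector[value] = 1
        encoding.append(vector)
    return array(encoding)
```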
We also need to decode the encoded values so that we can make use of the predictions; in this case, to just review them. The one hot encoding can be inverted by using the argmax() NumPy function that returns the index of the value in the vector with the largest value. The function below, named one_hot_decode(), will decode an encoded sequence and can be used to later decode predictions from our network.
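A sketch of the decoding function might look like:

```python
from numpy import argmax

# decode a one hot encoded sequence back to a list of integers
def one_hot_decode(encoded_seq):
    return [argmax(vector) for vector in encoded_seq]
```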
We can tie all of this together: generating a sequence of 25 random integers and encoding each integer as a binary vector. Running the example first prints the list of 25 random integers, followed by a truncated view of the binary representations of all integers in the sequence, one vector per line, then the decoded sequence again. You may get different results as different random integers are generated each time the code is run. The complete code listing is below.
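A minimal version of that listing, built on the generate_sequence(), one_hot_encode(), and one_hot_decode() sketches above, might look like:

```python
# generate a sequence of 25 random integers drawn from 100 unique values
sequence = generate_sequence(25, 100)
print(sequence)
# one hot encode the sequence and print the binary vectors
encoded = one_hot_encode(sequence, 100)
print(encoded)
# decode the encoded sequence back to integers
decoded = one_hot_decode(encoded)
print(decoded)
```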
The final step is to reshape the one hot encoded sequences into a format that can be used as input to the LSTM. This involves reshaping the encoded sequence to have n time steps and k features, where n is the number of integers in the generated sequence and k is the number of possible integers at each time step (e.g. 100).
A sequence can then be reshaped into a three-dimensional matrix of samples, time steps, and features; for a single sequence of 25 integers, this is [1, 25, 100]. The output for the sequence is simply the encoded integer at a specific predefined location. This location must remain consistent for all examples generated for one model so that the model can learn. For example, we can use the 2nd time step as the output of a sequence with 25 time steps by taking the encoded value directly from the encoded sequence, as follows.
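A sketch of this reshaping and output selection, continuing from the encoded sequence above, might look like (the variable names are illustrative):

```python
# reshape the encoded sequence to [samples, time steps, features]
X = encoded.reshape(1, 25, 100)
# take the encoded value at the 2nd time step (index 1) as the output
y = encoded[1].reshape(1, 100)
```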
We can put this and the above generation and encoding steps together into a new function called generate_example() that generates a sequence, encodes it, and returns the input (X) and output (y) components for training an LSTM.
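A sketch of this function, with an out_index parameter (an assumed name) for the chosen output time step, might look like:

```python
# generate one example for an LSTM: the encoded sequence as X and the
# encoded value at the chosen output time step as y
def generate_example(length, n_features, out_index):
    sequence = generate_sequence(length, n_features)
    encoded = one_hot_encode(sequence, n_features)
    X = encoded.reshape((1, length, n_features))
    y = encoded[out_index].reshape(1, n_features)
    return X, y
```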
We can put all of this together and test the generation of one example ready for fitting or evaluating an LSTM as follows:
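For example, using the generate_example() sketch above with a sequence length of 25, 100 features, and the 2nd time step as the output:

```python
# generate a single example and inspect the shapes of its components
X, y = generate_example(25, 100, 1)
print(X.shape)
print(y.shape)
```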
Running the code generates one encoded sequence and prints out the shape of the input and output components of the sequence for the LSTM.
Now that we know how to prepare and represent random sequences of integers, we can look at using LSTMs to learn them.
We will start off by defining and compiling the model. To keep the model small and ensure it can be fit in a reasonable time, we will greatly simplify the problem by reducing the sequence length to 5 integers and the number of features to 10 (e.g. 0-9). The model must specify the expected dimensionality of the input data, in this case in terms of time steps (5) and features (10). We will use a single hidden layer LSTM with 25 memory units, chosen with a little trial and error. The output layer is a fully connected layer (Dense) with 10 neurons for the 10 possible integers that may be output. A softmax activation function is used on the output layer to allow the network to learn and output the distribution over the possible output values.
The network will use the log loss function while training, suitable for multiclass classification problems, and the efficient Adam optimization algorithm. The accuracy metric will be reported each training epoch to give an idea of the skill of the model in addition to the loss.
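A sketch of this model definition and compilation might look like the following; the configuration values match the description above, and the variable names (length, n_features, out_index) are assumed:

```python
from keras.models import Sequential
from keras.layers import LSTM, Dense

# problem configuration: 5 time steps, 10 possible integer values (0-9)
length = 5
n_features = 10
out_index = 1  # predict the value at the 2nd time step

# define the model: a single LSTM hidden layer with 25 memory units
model = Sequential()
model.add(LSTM(25, input_shape=(length, n_features)))
model.add(Dense(n_features, activation='softmax'))
# log loss (categorical cross entropy), Adam optimizer, accuracy metric
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
print(model.summary())
```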
Running the example defines and compiles the model, then prints a summary of the model structure. Printing a summary of the model structure is a good practice in general to confirm the model was defined and compiled as you intended.
We can now fit the model on example sequences. The code we developed for the echo sequence prediction problem generates random sequences. We could generate a large number of example sequences and pass them to the model's fit() function. The dataset would be loaded into memory, training would be fast, and we could experiment with different numbers of epochs, dataset sizes, and batch sizes.
A simpler approach is to manage the training process manually, where one training sample is generated and used to update the model, and any internal state is cleared. The number of epochs becomes the number of iterations of generating samples, and the batch size is essentially 1 sample. Below is an example of fitting the model for 10,000 epochs, found with a little trial and error.
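A sketch of this manual training loop, using the model and generate_example() sketches above, might look like:

```python
# fit the model to one new randomly generated example per epoch
for i in range(10000):
    X, y = generate_example(length, n_features, out_index)
    model.fit(X, y, epochs=1, verbose=2)
```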
Fitting the model will report the log loss and accuracy for each pattern. Here, accuracy is either 0 or 1 (0% or 100%) because we are making a sequence classification prediction on one sample and reporting the result.
Once the model is fit, we can estimate the skill of the model when classifying new random sequences. We can do this by simply making predictions on 100 randomly generated sequences and counting the number of correct predictions made.
As with fitting the model, we could generate a large number of examples, concatenate them together, and use the evaluate() function to evaluate the model. In this case, we will make the predictions manually and count up the number of correct outcomes. We can do this in a loop that generates a sample, makes a prediction, and increments a counter if the prediction was correct.
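A sketch of this manual evaluation loop might look like:

```python
# evaluate the model on 100 new randomly generated sequences
correct = 0
for i in range(100):
    X, y = generate_example(length, n_features, out_index)
    yhat = model.predict(X)
    # count the prediction as correct if the decoded class matches
    if one_hot_decode(yhat) == one_hot_decode(y):
        correct += 1
print('Accuracy: %f' % ((correct / 100) * 100.0))
```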
Evaluating the model reports the estimated skill of the model as 100%.
Finally, we can use the fit model to make predictions on new randomly generated sequences. For this problem, this is much the same as the case of evaluating the model. Because this is more of a user-facing activity, we can decode the whole sequence, expected output, and prediction and print them on the screen.
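A sketch of making and decoding a single prediction might look like:

```python
# make a prediction on a new randomly generated sequence
X, y = generate_example(length, n_features, out_index)
yhat = model.predict(X)
# decode and print the input sequence, expected output, and prediction
print('Sequence:  %s' % [one_hot_decode(x) for x in X])
print('Expected:  %s' % one_hot_decode(y))
print('Predicted: %s' % one_hot_decode(yhat))
```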
Running the example will print the decoded randomly generated sequence, expected outcome, and (hopefully) a prediction that meets the expected value. Your specific results will vary.
Don't panic if the model gets it wrong. LSTMs are stochastic and it is possible that a single run of the model may converge on a solution that does not completely learn the problem. If this happens to you, try running the example a few more times.
This section lists the complete working example for your reference.
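The listing below consolidates the sketches from this lesson into one script. It reflects the configuration described above (sequence length 5, 10 features, 25 memory units, 10,000 training iterations) and is offered as an assumed reconstruction rather than the original listing:

```python
from random import randint
from numpy import array, argmax
from keras.models import Sequential
from keras.layers import LSTM, Dense

# generate a sequence of random integers in [0, n_features)
def generate_sequence(length, n_features):
    return [randint(0, n_features - 1) for _ in range(length)]

# one hot encode a sequence of integers as a 2D array of binary vectors
def one_hot_encode(sequence, n_features):
    encoding = list()
    for value in sequence:
        vector = [0 for _ in range(n_features)]
        vector[value] = 1
        encoding.append(vector)
    return array(encoding)

# decode a one hot encoded sequence back to a list of integers
def one_hot_decode(encoded_seq):
    return [argmax(vector) for vector in encoded_seq]

# generate one example: the encoded sequence as X and the encoded
# value at the chosen output time step as y
def generate_example(length, n_features, out_index):
    sequence = generate_sequence(length, n_features)
    encoded = one_hot_encode(sequence, n_features)
    X = encoded.reshape((1, length, n_features))
    y = encoded[out_index].reshape(1, n_features)
    return X, y

# problem configuration
length = 5
n_features = 10
out_index = 1  # predict the value at the 2nd time step

# define and compile the model
model = Sequential()
model.add(LSTM(25, input_shape=(length, n_features)))
model.add(Dense(n_features, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
print(model.summary())

# fit the model to one new randomly generated example per epoch
for i in range(10000):
    X, y = generate_example(length, n_features, out_index)
    model.fit(X, y, epochs=1, verbose=2)

# evaluate the model on 100 new randomly generated sequences
correct = 0
for i in range(100):
    X, y = generate_example(length, n_features, out_index)
    yhat = model.predict(X)
    if one_hot_decode(yhat) == one_hot_decode(y):
        correct += 1
print('Accuracy: %f' % ((correct / 100) * 100.0))

# make a prediction on a new randomly generated sequence
X, y = generate_example(length, n_features, out_index)
yhat = model.predict(X)
print('Sequence:  %s' % [one_hot_decode(x) for x in X])
print('Expected:  %s' % one_hot_decode(y))
print('Predicted: %s' % one_hot_decode(yhat))
```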
This section provides some resources for further reading.
- Long Short-Term Memory, 1997.
- Learning to Forget: Continual Prediction with LSTM, 1999.
Do you want to dive deeper into the Vanilla LSTM? This section lists some challenging extensions to this lesson.
- Update the example to use a longer sequence length and still achieve 100% accuracy.
- Update the example to use a larger number of features and still achieve 100% accuracy.
- Update the example to use the SGD optimization algorithm and tune the learning rate and momentum.
- Update the example to prepare a large dataset of examples to fit the model and explore different batch sizes.
- Vary the time step index of the sequence output and training epochs to see if there is a relationship between the index and how hard the problem is to learn.

Post your extensions online and share the link with me. I'd love to see what you come up with!
In this lesson, you discovered how to develop a Vanilla or standard LSTM. Specifically, you learned:
- The architecture of the Vanilla LSTM for sequence prediction and its general capabilities.
- How to define and implement the echo sequence prediction problem.
- How to develop a Vanilla LSTM to learn and make accurate predictions on the echo sequence prediction problem.