From 52889cf6f7e2bb40547395c231ecd4d1fe048760 Mon Sep 17 00:00:00 2001 From: Matthew Schultz Date: Sun, 23 Jun 2024 21:30:23 -0400 Subject: [PATCH 01/11] Updating 8 year old tutorial to include DataLoader, splitting into train and test sets as well as simplifying content --- .../char_rnn_classification_tutorial.py | 593 ++++++++++-------- 1 file changed, 324 insertions(+), 269 deletions(-) diff --git a/intermediate_source/char_rnn_classification_tutorial.py b/intermediate_source/char_rnn_classification_tutorial.py index 8451f07b82..b20760af44 100644 --- a/intermediate_source/char_rnn_classification_tutorial.py +++ b/intermediate_source/char_rnn_classification_tutorial.py @@ -3,6 +3,7 @@ NLP From Scratch: Classifying Names with a Character-Level RNN ************************************************************** **Author**: `Sean Robertson `_ +**Updated**: `Matthew Schultz `_ We will be building and training a basic character-level Recurrent Neural Network (RNN) to classify words. This tutorial, along with two other @@ -70,61 +71,47 @@ line, mostly romanized (but we still need to convert from Unicode to ASCII). -We'll end up with a dictionary of lists of names per language, -``{language: [names ...]}``. The generic variables "category" and "line" -(for language and name in our case) are used for later extensibility. -""" -from io import open -import glob -import os - -def findFiles(path): return glob.glob(path) - -print(findFiles('data/names/*.txt')) - -import unicodedata -import string +The first thing we need to define is our data items. In this case, we will create a class called NameData +which will have an __init__ function to specify the input fields and some helper functions. Our first +helper function will be __str__ to convert objects to strings for easy printing -all_letters = string.ascii_letters + " .,;'" -n_letters = len(all_letters) -# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427 -def unicodeToAscii(s): - return ''.join( - c for c in unicodedata.normalize('NFD', s) - if unicodedata.category(c) != 'Mn' - and c in all_letters - ) +There are two key pieces of this that we will flesh out over the course of this tutorial. First is the basic data +object which a label and some text. In this instance, label = the country of origin and text = the name. -print(unicodeToAscii('Ślusàrski')) +However, our data has some issues that we will need to clean up. First off, we need to convert unicode to plain ASCII to +limit the RNN input layers. 
This is accomplished by converting unicode strings to ASCII and allowing a samll set of allowed characters (allowed_characters) +""" -# Build the category_lines dictionary, a list of names per language -category_lines = {} -all_categories = [] +import torch +import string +import unicodedata -# Read a file and split into lines -def readLines(filename): - lines = open(filename, encoding='utf-8').read().strip().split('\n') - return [unicodeToAscii(line) for line in lines] +class NameData: + allowed_characters = string.ascii_letters + " .,;'" + n_letters = len(allowed_characters) -for filename in findFiles('data/names/*.txt'): - category = os.path.splitext(os.path.basename(filename))[0] - all_categories.append(category) - lines = readLines(filename) - category_lines[category] = lines -n_categories = len(all_categories) + def __init__(self, label, text): + self.label = label + self.text = NameData.unicodeToAscii(text) + + def __str__(self): + return f"label={self.label}, text={self.text}" + # Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427 + def unicodeToAscii(s): + return ''.join( + c for c in unicodedata.normalize('NFD', s) + if unicodedata.category(c) != 'Mn' + and c in NameData.allowed_characters + ) -###################################################################### -# Now we have ``category_lines``, a dictionary mapping each category -# (language) to a list of lines (names). We also kept track of -# ``all_categories`` (just a list of languages) and ``n_categories`` for -# later reference. +######################### +#Now we can use that class to create a singe piece of data. # -print(category_lines['Italian'][:5]) - +print (f"{NameData(label='Polish', text='Ślusàrski')}") ###################################################################### # Turning Names into Tensors @@ -143,30 +130,122 @@ def readLines(filename): # That extra 1 dimension is because PyTorch assumes everything is in # batches - we're just using a batch size of 1 here. # +# For this, you'll need to add a couple of capabilities to our NameData object. + +import torch +import string +import unicodedata + +class NameData: + allowed_characters = string.ascii_letters + " .,;'" + n_letters = len(allowed_characters) + + + def __init__(self, label, text): + self.label = label + self.text = NameData.unicodeToAscii(text) + self.tensor = NameData.lineToTensor(self.text) + + def __str__(self): + return f"label={self.label}, text={self.text}\ntensor = {self.tensor}" + + # Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427 + def unicodeToAscii(s): + return ''.join( + c for c in unicodedata.normalize('NFD', s) + if unicodedata.category(c) != 'Mn' + and c in NameData.allowed_characters + ) + + # Find letter index from all_letters, e.g. "a" = 0 + def letterToIndex(letter): + return NameData.allowed_characters.find(letter) + + # Turn a line into a , + # or an array of one-hot letter vectors + def lineToTensor(line): + tensor = torch.zeros(len(line), 1, NameData.n_letters) + for li, letter in enumerate(line): + tensor[li][0][NameData.letterToIndex(letter)] = 1 + return tensor + +######################### +#Here are some examples of how to use the NameData object + +print (f"{NameData(label='none', text='a')}") +print (f"{NameData(label='Korean', text='Ahn')}") + +######################### +#Congratulations, you have built the foundational tensor objects for this learning task! You can use a similar approach +#for other RNN tasks with text. 
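#########################
# As a quick sanity check, the one-hot encoding can be inverted. The helper
# below is a minimal sketch (assuming the ``NameData`` class defined above,
# not something the tutorial itself defines): ``argmax`` on each step of the
# tensor recovers the letter index, which maps back into ``allowed_characters``.

def tensorToLine(tensor):
    # each step is a <1 x n_letters> one-hot row, so argmax gives the letter index
    return ''.join(NameData.allowed_characters[step[0].argmax().item()] for step in tensor)

print(tensorToLine(NameData(label='Korean', text='Ahn').tensor)) # should print 'Ahn'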
+# +#Next, we need to combine all our examples into a dataset so we can train, text and validate our models. For this, +#we will use the `Dataset and DataLoader ` classes +#to hold our dataset. Each Dataset needs to implement three functions: __init__, __len__, and __getitem__. + +from io import open +import glob +import os +import unicodedata +import string +import time import torch +from torch.utils.data import Dataset + +class NamesDataset(Dataset): + + def __init__(self, data_dir): + self.data_dir = data_dir #for provenance of the dataset + self.load_time = time.localtime #for provenance of the dataset + labels_set = set() #set of all classes -# Find letter index from all_letters, e.g. "a" = 0 -def letterToIndex(letter): - return all_letters.find(letter) + self.data = [] -# Just for demonstration, turn a letter into a <1 x n_letters> Tensor -def letterToTensor(letter): - tensor = torch.zeros(1, n_letters) - tensor[0][letterToIndex(letter)] = 1 - return tensor + #read all the txt files in the specified directory + text_files = glob.glob(os.path.join(data_dir, '*.txt')) + for filename in text_files: + label = os.path.splitext(os.path.basename(filename))[0] + labels_set.add(label) + lines = NamesDataset.readLines(filename) + for name in lines: + self.data.append(NameData(label=label, text=name)) -# Turn a line into a , -# or an array of one-hot letter vectors -def lineToTensor(line): - tensor = torch.zeros(len(line), 1, n_letters) - for li, letter in enumerate(line): - tensor[li][0][letterToIndex(letter)] = 1 - return tensor + self.labels = list(labels_set) -print(letterToTensor('J')) + def __len__(self): + return len(self.data) -print(lineToTensor('Jones').size()) + def __getitem__(self, idx): + data_item = self.data[idx] + label_tensor = torch.tensor([self.labels.index(data_item.label)], dtype=torch.long) + return label_tensor, data_item.tensor, data_item.label, data_item.text + + # Read a file and split into lines + def readLines(filename): + lines = open(filename, encoding='utf-8').read().strip().split('\n') + return lines + + +######################### +#Here are some examples of how to use the NamesDataset object + + +alldata = NamesDataset("data/names") +print(f"loaded {len(alldata)} items of data") +print(f"example = {alldata[0]}") + +######################### +#Using the dataset object allows us to easily split the data into train and test sets. Here we create na 80/20 +#split but the torch.utils.data has more useful utilities. + +train_set, test_set = torch.utils.data.random_split(alldata, [.8, .2]) + +print(f"train examples = {len(train_set)}, validation examples = {len(test_set)}") + +######################### +#Now we have a basic dataset containing 20074 examples where each example is a pairing of label and name. We have also +#split the datset into training and testing so we can validate the model that we build. ###################################################################### @@ -181,111 +260,112 @@ def lineToTensor(line): # # This RNN module implements a "vanilla RNN" an is just 3 linear layers # which operate on an input and hidden state, with a ``LogSoftmax`` layer -# after the output. 
+# after the output.s # import torch.nn as nn import torch.nn.functional as F class RNN(nn.Module): - def __init__(self, input_size, hidden_size, output_size): + def __init__(self, input_size, hidden_size, output_labels): super(RNN, self).__init__() self.hidden_size = hidden_size + self.output_labels = output_labels self.i2h = nn.Linear(input_size, hidden_size) self.h2h = nn.Linear(hidden_size, hidden_size) - self.h2o = nn.Linear(hidden_size, output_size) + self.h2o = nn.Linear(hidden_size, len(output_labels)) self.softmax = nn.LogSoftmax(dim=1) + def initHidden(self): + return torch.zeros(1, self.hidden_size) + def forward(self, input, hidden): hidden = F.tanh(self.i2h(input) + self.h2h(hidden)) output = self.h2o(hidden) output = self.softmax(output) return output, hidden - def initHidden(self): - return torch.zeros(1, self.hidden_size) +########################### +#We can then create a RNN with 128 hidden nodes and given our datasets -n_hidden = 128 -rnn = RNN(n_letters, n_hidden, n_categories) +n_hidden = 128 +rnn = RNN(NameData.n_letters, n_hidden, alldata.labels) +print(rnn) ###################################################################### -# To run a step of this network we need to pass an input (in our case, the -# Tensor for the current letter) and a previous hidden state (which we -# initialize as zeros at first). We'll get back the output (probability of -# each language) and a next hidden state (which we keep for the next -# step). -# +# To run a step of this network we need to pass a single character input +# and a hidden state (which we initialize as zeros at first). We'll get to +# multi-character names during training -input = letterToTensor('A') +input = NameData(label='none', text='A').tensor hidden = torch.zeros(1, n_hidden) - -output, next_hidden = rnn(input, hidden) - +output, next_hidden = rnn(input[0], hidden) +print(output) ###################################################################### -# For the sake of efficiency we don't want to be creating a new Tensor for -# every step, so we will use ``lineToTensor`` instead of -# ``letterToTensor`` and use slices. This could be further optimized by -# precomputing batches of Tensors. -# +# Scoring Multi-character names +# -------------------- +# Multi-character names require just a little bit more effort which is +# keeping track of the hidden output and passing it back into the RNN. +# You can see this defined in the function forward_multi() -input = lineToTensor('Albert') -hidden = torch.zeros(1, n_hidden) +import torch.nn as nn +import torch.nn.functional as F -output, next_hidden = rnn(input[0], hidden) -print(output) +class RNN(nn.Module): + def __init__(self, input_size, hidden_size, output_labels): + super(RNN, self).__init__() + self.hidden_size = hidden_size + self.output_labels = output_labels -###################################################################### -# As you can see the output is a ``<1 x n_categories>`` Tensor, where -# every item is the likelihood of that category (higher is more likely). -# + self.i2h = nn.Linear(input_size, hidden_size) + self.h2h = nn.Linear(hidden_size, hidden_size) + self.h2o = nn.Linear(hidden_size, len(output_labels)) + self.softmax = nn.LogSoftmax(dim=1) + def initHidden(self): + return torch.zeros(1, self.hidden_size) -###################################################################### -# -# Training -# ======== -# Preparing for Training -# ---------------------- -# -# Before going into training we should make a few helper functions. 
The -# first is to interpret the output of the network, which we know to be a -# likelihood of each category. We can use ``Tensor.topk`` to get the index -# of the greatest value: -# + def forward(self, input, hidden): + hidden = F.tanh(self.i2h(input) + self.h2h(hidden)) + output = self.h2o(hidden) + output = self.softmax(output) + return output, hidden + + def forward_multi(self, line_tensor): + hidden = rnn.initHidden() -def categoryFromOutput(output): - top_n, top_i = output.topk(1) - category_i = top_i[0].item() - return all_categories[category_i], category_i + for i in range(line_tensor.size()[0]): + output, hidden = rnn.forward(line_tensor[i], hidden) -print(categoryFromOutput(output)) + return output, hidden + def categoryFromOutput(self, output): + top_n, top_i = output.topk(1) + category_i = top_i[0].item() + return self.output_labels[category_i], category_i -###################################################################### -# We will also want a quick way to get a training example (a name and its -# language): -# +########################### +#Now we can score the output for names! -import random -def randomChoice(l): - return l[random.randint(0, len(l) - 1)] +n_hidden = 128 +hidden = torch.zeros(1, n_hidden) +rnn = RNN(NameData.n_letters, n_hidden, alldata.labels) -def randomTrainingExample(): - category = randomChoice(all_categories) - line = randomChoice(category_lines[category]) - category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.long) - line_tensor = lineToTensor(line) - return category, line, category_tensor, line_tensor +input = NameData(label='none', text='Albert').tensor +output, next_hidden = rnn.forward_multi(input) +print(output) +print(rnn.categoryFromOutput(output)) -for i in range(10): - category, line, category_tensor, line_tensor = randomTrainingExample() - print('category =', category, '/ line =', line) +###################################################################### +# +# Training +# ======== ###################################################################### @@ -294,16 +374,9 @@ def randomTrainingExample(): # # Now all it takes to train this network is show it a bunch of examples, # have it make guesses, and tell it if it's wrong. -# -# For the loss function ``nn.NLLLoss`` is appropriate, since the last -# layer of the RNN is ``nn.LogSoftmax``. -# - -criterion = nn.NLLLoss() - - -###################################################################### -# Each loop of training will: +# +# We start by defining a function learn_single() which learns from a single +# piece of input data. # # - Create input and target tensors # - Create a zeroed initial hidden state @@ -315,73 +388,115 @@ def randomTrainingExample(): # - Back-propagate # - Return the output and loss # +# We also define a learn_batch() function which trains on a given dataset -learning_rate = 0.005 # If you set this too high, it might explode. 
If too low, it might not learn -def train(category_tensor, line_tensor): - hidden = rnn.initHidden() +import torch.nn as nn +import torch.nn.functional as F +import random - rnn.zero_grad() +class RNN(nn.Module): + def __init__(self, input_size, hidden_size, output_labels, criterion = nn.NLLLoss()): + super(RNN, self).__init__() - for i in range(line_tensor.size()[0]): - output, hidden = rnn(line_tensor[i], hidden) + self.hidden_size = hidden_size + self.output_labels = output_labels - loss = criterion(output, category_tensor) - loss.backward() + self.i2h = nn.Linear(input_size, hidden_size) + self.h2h = nn.Linear(hidden_size, hidden_size) + self.h2o = nn.Linear(hidden_size, len(output_labels)) + self.softmax = nn.LogSoftmax(dim=1) - # Add parameters' gradients to their values, multiplied by learning rate - for p in rnn.parameters(): - p.data.add_(p.grad.data, alpha=-learning_rate) + self.criterion = criterion - return output, loss.item() + def initHidden(self): + return torch.zeros(1, self.hidden_size) + def forward(self, input, hidden): + hidden = F.tanh(self.i2h(input) + self.h2h(hidden)) + output = self.h2o(hidden) + output = self.softmax(output) + return output, hidden + + def forward_multi(self, line_tensor): + hidden = self.initHidden() -###################################################################### -# Now we just have to run that with a bunch of examples. Since the -# ``train`` function returns both the output and loss we can print its -# guesses and also keep track of loss for plotting. Since there are 1000s -# of examples we print only every ``print_every`` examples, and take an -# average of the loss. -# + for i in range(line_tensor.size()[0]): + output, hidden = self.forward(line_tensor[i], hidden) -import time -import math + return output, hidden -n_iters = 100000 -print_every = 5000 -plot_every = 1000 + def categoryFromOutput(self, output): + top_n, top_i = output.topk(1) + category_i = top_i[0].item() + return self.output_labels[category_i], category_i + + def learn_single(self, label_tensor, line_tensor, learning_rate = 0.005): + #Train the RNN for one example with a learning rate that defaults to 0.005. 
+ + + rnn.zero_grad() + output, hidden = self.forward_multi(line_tensor) + + loss = self.criterion(output, label_tensor) + loss.backward() + + # Add parameters' gradients to their values, multiplied by learning rate + for p in self.parameters(): + p.data.add_(p.grad.data, alpha=-learning_rate) + + return output, loss.item() + + def learn_batch(self, training_data, n_iters = 1000, report_every = 100): + """ + Learn on a batch of training_data for a specified number of iterations and reporting thresholds + """ + + # Keep track of losses for plotting + current_loss = 0 + all_losses = [] + start = time.time() + print(f"training data = {training_data}") + print(f"size = {len(training_data)}") + for iter in range(1, n_iters + 1): + rand_idx = random.randint(0,len(training_data)-1) + (label_tensor, text_tensor, label, text) = training_data[rand_idx] -# Keep track of losses for plotting -current_loss = 0 -all_losses = [] + output, loss = self.learn_single(label_tensor, text_tensor) + current_loss += loss -def timeSince(since): - now = time.time() - s = now - since - m = math.floor(s / 60) - s -= m * 60 - return '%dm %ds' % (m, s) + # Print ``iter`` number, loss, name and guess + if iter % report_every == 0: + all_losses.append(current_loss / report_every) + print(f"{iter} ({iter / n_iters:.0%}): \t iteration loss = {all_losses[-1]}") + current_loss = 0 + + return all_losses -start = time.time() +########################### +#We can test this with one of our examples and see the output vector, loss and guess of a class from a random network. +# +#Here is a single input example -for iter in range(1, n_iters + 1): - category, line, category_tensor, line_tensor = randomTrainingExample() - output, loss = train(category_tensor, line_tensor) - current_loss += loss +n_hidden = 128 +hidden = torch.zeros(1, n_hidden) +rnn = RNN(NameData.n_letters, n_hidden, alldata.labels) - # Print ``iter`` number, loss, name and guess - if iter % print_every == 0: - guess, guess_i = categoryFromOutput(output) - correct = '✓' if guess == category else '✗ (%s)' % category - print('%d %d%% (%s) %.4f %s / %s %s' % (iter, iter / n_iters * 100, timeSince(start), loss, line, guess, correct)) +(label_tensor, text_tensor, label, text) = train_set[0] +print(f"training on name = {text} with label = {label}") +(output, loss) = rnn.learn_single(label_tensor, text_tensor) + +print("LogSoftmax outputs (highest score is predicted class") +for i in range(len(output[0])): + print (f"\t{i}. {alldata.labels[i]} => {output[0][i]}") + +########################### +#We can also train on our training data set by randomly selecting examples - # Add current loss avg to list of losses - if iter % plot_every == 0: - all_losses.append(current_loss / plot_every) - current_loss = 0 +all_losses = rnn.learn_batch(train_set, n_iters=200000, report_every=10000) ###################################################################### # Plotting the Results @@ -396,7 +511,7 @@ def timeSince(since): plt.figure() plt.plot(all_losses) - +plt.show() ###################################################################### # Evaluating the Results @@ -409,47 +524,39 @@ def timeSince(since): # ``evaluate()``, which is the same as ``train()`` minus the backprop. 
# -# Keep track of correct guesses in a confusion matrix -confusion = torch.zeros(n_categories, n_categories) -n_confusion = 10000 - -# Just return an output given a line -def evaluate(line_tensor): - hidden = rnn.initHidden() - - for i in range(line_tensor.size()[0]): - output, hidden = rnn(line_tensor[i], hidden) +def evaluate(rnn, testing_data): + confusion = torch.zeros(len(rnn.output_labels), len(rnn.output_labels)) - return output + with torch.no_grad(): # do not record the gradiants during eval phase + for i in range(len(testing_data)): + (label_tensor, text_tensor, label, text) = testing_data[i] + (output, hidden) = rnn.forward_multi(text_tensor) + guess, guess_i = rnn.categoryFromOutput(output) + category_i = rnn.output_labels.index(label) + confusion[category_i][guess_i] += 1 -# Go through a bunch of examples and record which are correctly guessed -for i in range(n_confusion): - category, line, category_tensor, line_tensor = randomTrainingExample() - output = evaluate(line_tensor) - guess, guess_i = categoryFromOutput(output) - category_i = all_categories.index(category) - confusion[category_i][guess_i] += 1 + # Normalize by dividing every row by its sum + for i in range(len(rnn.output_labels)): + confusion[i] = confusion[i] / confusion[i].sum() -# Normalize by dividing every row by its sum -for i in range(n_categories): - confusion[i] = confusion[i] / confusion[i].sum() + # Set up plot + fig = plt.figure() + ax = fig.add_subplot(111) + cax = ax.matshow(confusion.numpy()) + fig.colorbar(cax) -# Set up plot -fig = plt.figure() -ax = fig.add_subplot(111) -cax = ax.matshow(confusion.numpy()) -fig.colorbar(cax) + # Set up axes + ax.set_xticklabels([''] + rnn.output_labels, rotation=90) + ax.set_yticklabels([''] + rnn.output_labels) -# Set up axes -ax.set_xticklabels([''] + all_categories, rotation=90) -ax.set_yticklabels([''] + all_categories) + # Force label at every tick + ax.xaxis.set_major_locator(ticker.MultipleLocator(1)) + ax.yaxis.set_major_locator(ticker.MultipleLocator(1)) -# Force label at every tick -ax.xaxis.set_major_locator(ticker.MultipleLocator(1)) -ax.yaxis.set_major_locator(ticker.MultipleLocator(1)) + # sphinx_gallery_thumbnail_number = 2 + plt.show() -# sphinx_gallery_thumbnail_number = 2 -plt.show() +evaluate(rnn, test_set) ###################################################################### @@ -460,58 +567,6 @@ def evaluate(line_tensor): # -###################################################################### -# Running on User Input -# --------------------- -# - -def predict(input_line, n_predictions=3): - print('\n> %s' % input_line) - with torch.no_grad(): - output = evaluate(lineToTensor(input_line)) - - # Get top N categories - topv, topi = output.topk(n_predictions, 1, True) - predictions = [] - - for i in range(n_predictions): - value = topv[0][i].item() - category_index = topi[0][i].item() - print('(%.2f) %s' % (value, all_categories[category_index])) - predictions.append([value, all_categories[category_index]]) - -predict('Dovesky') -predict('Jackson') -predict('Satoshi') - - -###################################################################### -# The final versions of the scripts `in the Practical PyTorch -# repo `__ -# split the above code into a few files: -# -# - ``data.py`` (loads files) -# - ``model.py`` (defines the RNN) -# - ``train.py`` (runs training) -# - ``predict.py`` (runs ``predict()`` with command line arguments) -# - ``server.py`` (serve prediction as a JSON API with ``bottle.py``) -# -# Run ``train.py`` to train and save the 
network. -# -# Run ``predict.py`` with a name to view predictions: -# -# .. code-block:: sh -# -# $ python predict.py Hazaki -# (-0.42) Japanese -# (-1.39) Polish -# (-3.51) Czech -# -# Run ``server.py`` and visit http://localhost:5533/Yourname to get JSON -# output of predictions. -# - - ###################################################################### # Exercises # ========= @@ -528,4 +583,4 @@ def predict(input_line, n_predictions=3): # - Add more linear layers # - Try the ``nn.LSTM`` and ``nn.GRU`` layers # - Combine multiple of these RNNs as a higher level network -# +# \ No newline at end of file From 126f7ad4cfb07958b5df0f1ef85ed732857a03de Mon Sep 17 00:00:00 2001 From: Matthew Schultz Date: Mon, 24 Jun 2024 11:50:33 -0400 Subject: [PATCH 02/11] use label instead of category for class and remove old cmd line code-block --- .../char_rnn_classification_tutorial.py | 35 ++++++------------- 1 file changed, 11 insertions(+), 24 deletions(-) diff --git a/intermediate_source/char_rnn_classification_tutorial.py b/intermediate_source/char_rnn_classification_tutorial.py index b20760af44..a928df7a05 100644 --- a/intermediate_source/char_rnn_classification_tutorial.py +++ b/intermediate_source/char_rnn_classification_tutorial.py @@ -23,19 +23,6 @@ of origin, and predict which language a name is from based on the spelling: -.. code-block:: sh - - $ python predict.py Hinton - (-0.47) Scottish - (-1.52) English - (-3.57) Irish - - $ python predict.py Schmidhuber - (-0.19) German - (-2.48) Czech - (-2.68) Dutch - - Recommended Preparation ======================= @@ -344,10 +331,10 @@ def forward_multi(self, line_tensor): return output, hidden - def categoryFromOutput(self, output): + def label_from_output(self, output): top_n, top_i = output.topk(1) - category_i = top_i[0].item() - return self.output_labels[category_i], category_i + label_i = top_i[0].item() + return self.output_labels[label_i], label_i ########################### #Now we can score the output for names! @@ -360,7 +347,7 @@ def categoryFromOutput(self, output): input = NameData(label='none', text='Albert').tensor output, next_hidden = rnn.forward_multi(input) print(output) -print(rnn.categoryFromOutput(output)) +print(rnn.label_from_output(output)) ###################################################################### # @@ -426,10 +413,10 @@ def forward_multi(self, line_tensor): return output, hidden - def categoryFromOutput(self, output): + def label_from_output(self, output): top_n, top_i = output.topk(1) - category_i = top_i[0].item() - return self.output_labels[category_i], category_i + label_i = top_i[0].item() + return self.output_labels[label_i], label_i def learn_single(self, label_tensor, line_tensor, learning_rate = 0.005): #Train the RNN for one example with a learning rate that defaults to 0.005. 
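# A standalone illustration of the ``Tensor.topk`` call that the renamed
# ``label_from_output`` relies on -- a small sketch with made-up scores:

import torch

log_probs = torch.tensor([[-0.3, -1.9, -2.4]]) # one row of LogSoftmax outputs
top_value, top_index = log_probs.topk(1)       # highest score and its class index
print(top_value.item(), top_index.item())      # roughly -0.3 and 0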
@@ -531,9 +518,9 @@ def evaluate(rnn, testing_data): for i in range(len(testing_data)): (label_tensor, text_tensor, label, text) = testing_data[i] (output, hidden) = rnn.forward_multi(text_tensor) - guess, guess_i = rnn.categoryFromOutput(output) - category_i = rnn.output_labels.index(label) - confusion[category_i][guess_i] += 1 + guess, guess_i = rnn.label_from_output(output) + label_i = rnn.output_labels.index(label) + confusion[label_i][guess_i] += 1 # Normalize by dividing every row by its sum for i in range(len(rnn.output_labels)): @@ -571,7 +558,7 @@ def evaluate(rnn, testing_data): # Exercises # ========= # -# - Try with a different dataset of line -> category, for example: +# - Try with a different dataset of line -> label, for example: # # - Any word -> language # - First name -> gender From b4db18d5d7a3719ceaebcd8e7fdbb7ac03b696d9 Mon Sep 17 00:00:00 2001 From: Matthew Schultz Date: Thu, 27 Jun 2024 21:04:09 -0400 Subject: [PATCH 03/11] Simplify RNN class (e.g. one forward function), adding minibatches + optimizer --- .../char_rnn_classification_tutorial.py | 169 +++++++----------- 1 file changed, 68 insertions(+), 101 deletions(-) diff --git a/intermediate_source/char_rnn_classification_tutorial.py b/intermediate_source/char_rnn_classification_tutorial.py index a928df7a05..37d6b33059 100644 --- a/intermediate_source/char_rnn_classification_tutorial.py +++ b/intermediate_source/char_rnn_classification_tutorial.py @@ -194,7 +194,7 @@ def __init__(self, data_dir): for filename in text_files: label = os.path.splitext(os.path.basename(filename))[0] labels_set.add(label) - lines = NamesDataset.readLines(filename) + lines = open(filename, encoding='utf-8').read().strip().split('\n') for name in lines: self.data.append(NameData(label=label, text=name)) @@ -208,11 +208,6 @@ def __getitem__(self, idx): label_tensor = torch.tensor([self.labels.index(data_item.label)], dtype=torch.long) return label_tensor, data_item.tensor, data_item.label, data_item.text - # Read a file and split into lines - def readLines(filename): - lines = open(filename, encoding='utf-8').read().strip().split('\n') - return lines - ######################### #Here are some examples of how to use the NamesDataset object @@ -265,9 +260,6 @@ def __init__(self, input_size, hidden_size, output_labels): self.h2o = nn.Linear(hidden_size, len(output_labels)) self.softmax = nn.LogSoftmax(dim=1) - def initHidden(self): - return torch.zeros(1, self.hidden_size) - def forward(self, input, hidden): hidden = F.tanh(self.i2h(input) + self.h2h(hidden)) output = self.h2o(hidden) @@ -285,11 +277,10 @@ def forward(self, input, hidden): ###################################################################### # To run a step of this network we need to pass a single character input # and a hidden state (which we initialize as zeros at first). We'll get to -# multi-character names during training +# multi-character names next input = NameData(label='none', text='A').tensor -hidden = torch.zeros(1, n_hidden) -output, next_hidden = rnn(input[0], hidden) +output, next_hidden = rnn(input[0], torch.zeros(1, n_hidden)) print(output) ###################################################################### @@ -297,7 +288,7 @@ def forward(self, input, hidden): # -------------------- # Multi-character names require just a little bit more effort which is # keeping track of the hidden output and passing it back into the RNN. 
-# You can see this defined in the function forward_multi() +# You can see this updated work defined in the function forward() import torch.nn as nn import torch.nn.functional as F @@ -313,39 +304,34 @@ def __init__(self, input_size, hidden_size, output_labels): self.h2h = nn.Linear(hidden_size, hidden_size) self.h2o = nn.Linear(hidden_size, len(output_labels)) self.softmax = nn.LogSoftmax(dim=1) - - def initHidden(self): - return torch.zeros(1, self.hidden_size) - - def forward(self, input, hidden): - hidden = F.tanh(self.i2h(input) + self.h2h(hidden)) - output = self.h2o(hidden) - output = self.softmax(output) - return output, hidden - def forward_multi(self, line_tensor): - hidden = rnn.initHidden() + def forward(self, line_tensor): + hidden = torch.zeros(1, rnn.hidden_size) + output = torch.zeros(1, len(self.output_labels)) for i in range(line_tensor.size()[0]): - output, hidden = rnn.forward(line_tensor[i], hidden) + input = line_tensor[i] + hidden = F.tanh(self.i2h(input) + self.h2h(hidden)) + output = self.h2o(hidden) + output = self.softmax(output) - return output, hidden + return output def label_from_output(self, output): top_n, top_i = output.topk(1) label_i = top_i[0].item() return self.output_labels[label_i], label_i + ########################### #Now we can score the output for names! n_hidden = 128 -hidden = torch.zeros(1, n_hidden) rnn = RNN(NameData.n_letters, n_hidden, alldata.labels) input = NameData(label='none', text='Albert').tensor -output, next_hidden = rnn.forward_multi(input) +output = rnn(input) #this is equivalent to output = rnn.forward(input) print(output) print(rnn.label_from_output(output)) @@ -375,15 +361,15 @@ def label_from_output(self, output): # - Back-propagate # - Return the output and loss # -# We also define a learn_batch() function which trains on a given dataset - +# We also define a learn() function which trains on a given dataset with minibatches import torch.nn as nn import torch.nn.functional as F import random +import numpy as np class RNN(nn.Module): - def __init__(self, input_size, hidden_size, output_labels, criterion = nn.NLLLoss()): + def __init__(self, input_size, hidden_size, output_labels): super(RNN, self).__init__() self.hidden_size = hidden_size @@ -393,97 +379,76 @@ def __init__(self, input_size, hidden_size, output_labels, criterion = nn.NLLLos self.h2h = nn.Linear(hidden_size, hidden_size) self.h2o = nn.Linear(hidden_size, len(output_labels)) self.softmax = nn.LogSoftmax(dim=1) - - self.criterion = criterion - - def initHidden(self): - return torch.zeros(1, self.hidden_size) - - def forward(self, input, hidden): - hidden = F.tanh(self.i2h(input) + self.h2h(hidden)) - output = self.h2o(hidden) - output = self.softmax(output) - return output, hidden - def forward_multi(self, line_tensor): - hidden = self.initHidden() + def forward(self, line_tensor): + hidden = torch.zeros(1, rnn.hidden_size) + output = torch.zeros(1, len(self.output_labels)) for i in range(line_tensor.size()[0]): - output, hidden = self.forward(line_tensor[i], hidden) + input = line_tensor[i] + hidden = F.tanh(self.i2h(input) + self.h2h(hidden)) + output = self.h2o(hidden) + output = self.softmax(output) - return output, hidden + return output def label_from_output(self, output): top_n, top_i = output.topk(1) label_i = top_i[0].item() - return self.output_labels[label_i], label_i + return self.output_labels[label_i], label_i - def learn_single(self, label_tensor, line_tensor, learning_rate = 0.005): - #Train the RNN for one example with a learning rate that 
defaults to 0.005. - - - rnn.zero_grad() - output, hidden = self.forward_multi(line_tensor) - - loss = self.criterion(output, label_tensor) - loss.backward() - - # Add parameters' gradients to their values, multiplied by learning rate - for p in self.parameters(): - p.data.add_(p.grad.data, alpha=-learning_rate) - - return output, loss.item() - - def learn_batch(self, training_data, n_iters = 1000, report_every = 100): + def learn(self, training_data, n_epoch = 1000, n_batch_size = 64, report_every = 50, learning_rate = 0.005, criterion = nn.NLLLoss()): """ Learn on a batch of training_data for a specified number of iterations and reporting thresholds """ - # Keep track of losses for plotting current_loss = 0 all_losses = [] + self.train() + optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate) start = time.time() - print(f"training data = {training_data}") - print(f"size = {len(training_data)}") - - for iter in range(1, n_iters + 1): - rand_idx = random.randint(0,len(training_data)-1) - (label_tensor, text_tensor, label, text) = training_data[rand_idx] - - output, loss = self.learn_single(label_tensor, text_tensor) - current_loss += loss - - # Print ``iter`` number, loss, name and guess + print(f"training on data set with n = {len(training_data)}") + + for iter in range(1, n_epoch + 1): + self.zero_grad() # clear the gradients + + # create some minibatches + # we cannot use dataloaders because each of our names is a different length + batches = list(range(len(training_data))) + random.shuffle(batches) + batches = np.array_split(batches, len(batches) //n_batch_size ) + + for idx, batch in enumerate(batches): + batch_loss = 0 + for i in batch: #for each example in this batch + (label_tensor, text_tensor, label, text) = training_data[i] + output = self.forward(text_tensor) + loss = criterion(output, label_tensor) + batch_loss += loss + + # optimize parameters + batch_loss.backward() + nn.utils.clip_grad_norm_(self.parameters(), 3) + optimizer.step() + optimizer.zero_grad() + + current_loss += batch_loss.item() / len(batch) + + all_losses.append(current_loss / len(batches) ) if iter % report_every == 0: - all_losses.append(current_loss / report_every) - print(f"{iter} ({iter / n_iters:.0%}): \t iteration loss = {all_losses[-1]}") - current_loss = 0 + print(f"{iter} ({iter / n_epoch:.0%}): \t average batch loss = {all_losses[-1]}") + current_loss = 0 return all_losses -########################### -#We can test this with one of our examples and see the output vector, loss and guess of a class from a random network. -# -#Here is a single input example +########################################################################## +# We can now train a dataset with mini batches for a specified number of epochs n_hidden = 128 hidden = torch.zeros(1, n_hidden) rnn = RNN(NameData.n_letters, n_hidden, alldata.labels) - -(label_tensor, text_tensor, label, text) = train_set[0] -print(f"training on name = {text} with label = {label}") -(output, loss) = rnn.learn_single(label_tensor, text_tensor) - -print("LogSoftmax outputs (highest score is predicted class") -for i in range(len(output[0])): - print (f"\t{i}. 
{alldata.labels[i]} => {output[0][i]}") - -########################### -#We can also train on our training data set by randomly selecting examples - - -all_losses = rnn.learn_batch(train_set, n_iters=200000, report_every=10000) +all_losses = rnn.learn(train_set) ###################################################################### # Plotting the Results @@ -513,11 +478,12 @@ def learn_batch(self, training_data, n_iters = 1000, report_every = 100): def evaluate(rnn, testing_data): confusion = torch.zeros(len(rnn.output_labels), len(rnn.output_labels)) - - with torch.no_grad(): # do not record the gradiants during eval phase + + rnn.eval() #set to eval mode + with torch.no_grad(): # do not record the gradiants during eval phase for i in range(len(testing_data)): (label_tensor, text_tensor, label, text) = testing_data[i] - (output, hidden) = rnn.forward_multi(text_tensor) + output = rnn.forward(text_tensor) guess, guess_i = rnn.label_from_output(output) label_i = rnn.output_labels.index(label) confusion[label_i][guess_i] += 1 @@ -541,7 +507,8 @@ def evaluate(rnn, testing_data): ax.yaxis.set_major_locator(ticker.MultipleLocator(1)) # sphinx_gallery_thumbnail_number = 2 - plt.show() + plt.show() + evaluate(rnn, test_set) From 6d08a08405e335488a5b2c4e3eebf485c8ca746c Mon Sep 17 00:00:00 2001 From: Matthew Schultz Date: Wed, 10 Jul 2024 11:57:06 -0400 Subject: [PATCH 04/11] fixing spelling errors, slight change to # of iterations to generate a better confusion matrix --- en-wordlist.txt | 6 ++++++ .../char_rnn_classification_tutorial.py | 12 ++++++------ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/en-wordlist.txt b/en-wordlist.txt index b52d8374d3..3b9ecbb44d 100644 --- a/en-wordlist.txt +++ b/en-wordlist.txt @@ -147,6 +147,8 @@ Minifier MobileNet ModelABC Mypy +NameData +NamesDataset NAS NCCL NCHW @@ -359,6 +361,7 @@ enum eq equalities et +eval evaluateInput extensibility fastai @@ -513,6 +516,7 @@ resnet restride rewinded rgb +rnn rollout rollouts romanized @@ -580,12 +584,14 @@ traceback tradeoff tradeoffs triton +txt uint umap uncomment uncommented underflowing unfused +unicode unimodal unnormalized unoptimized diff --git a/intermediate_source/char_rnn_classification_tutorial.py b/intermediate_source/char_rnn_classification_tutorial.py index 37d6b33059..3153ee88f2 100644 --- a/intermediate_source/char_rnn_classification_tutorial.py +++ b/intermediate_source/char_rnn_classification_tutorial.py @@ -66,8 +66,8 @@ There are two key pieces of this that we will flesh out over the course of this tutorial. First is the basic data object which a label and some text. In this instance, label = the country of origin and text = the name. -However, our data has some issues that we will need to clean up. First off, we need to convert unicode to plain ASCII to -limit the RNN input layers. This is accomplished by converting unicode strings to ASCII and allowing a samll set of allowed characters (allowed_characters) +However, our data has some issues that we will need to clean up. First off, we need to convert Unicode to plain ASCII to +limit the RNN input layers. This is accomplished by converting Unicode strings to ASCII and allowing a small set of allowed characters (allowed_characters) """ import torch @@ -218,7 +218,7 @@ def __getitem__(self, idx): print(f"example = {alldata[0]}") ######################### -#Using the dataset object allows us to easily split the data into train and test sets. 
Here we create na 80/20 +#Using the dataset object allows us to easily split the data into train and test sets. Here we create a 80/20 #split but the torch.utils.data has more useful utilities. train_set, test_set = torch.utils.data.random_split(alldata, [.8, .2]) @@ -227,7 +227,7 @@ def __getitem__(self, idx): ######################### #Now we have a basic dataset containing 20074 examples where each example is a pairing of label and name. We have also -#split the datset into training and testing so we can validate the model that we build. +#split the dataset into training and testing so we can validate the model that we build. ###################################################################### @@ -397,7 +397,7 @@ def label_from_output(self, output): label_i = top_i[0].item() return self.output_labels[label_i], label_i - def learn(self, training_data, n_epoch = 1000, n_batch_size = 64, report_every = 50, learning_rate = 0.005, criterion = nn.NLLLoss()): + def learn(self, training_data, n_epoch = 250, n_batch_size = 64, report_every = 50, learning_rate = 0.005, criterion = nn.NLLLoss()): """ Learn on a batch of training_data for a specified number of iterations and reporting thresholds """ @@ -480,7 +480,7 @@ def evaluate(rnn, testing_data): confusion = torch.zeros(len(rnn.output_labels), len(rnn.output_labels)) rnn.eval() #set to eval mode - with torch.no_grad(): # do not record the gradiants during eval phase + with torch.no_grad(): # do not record the gradients during eval phase for i in range(len(testing_data)): (label_tensor, text_tensor, label, text) = testing_data[i] output = rnn.forward(text_tensor) From 80804ae23810216949a4db2fc6ece0b966057d69 Mon Sep 17 00:00:00 2001 From: Matthew Schultz Date: Thu, 11 Jul 2024 18:11:06 -0400 Subject: [PATCH 05/11] decreasing training time by 97% (72s on CPU) by tuning hyper parameters, adding device config for CI steps, cleaning up documentatation --- en-wordlist.txt | 1 + .../char_rnn_classification_tutorial.py | 99 +++++++++++-------- 2 files changed, 59 insertions(+), 41 deletions(-) diff --git a/en-wordlist.txt b/en-wordlist.txt index 3b9ecbb44d..f4b5d6d3bc 100644 --- a/en-wordlist.txt +++ b/en-wordlist.txt @@ -146,6 +146,7 @@ MaskRCNN Minifier MobileNet ModelABC +MPS Mypy NameData NamesDataset diff --git a/intermediate_source/char_rnn_classification_tutorial.py b/intermediate_source/char_rnn_classification_tutorial.py index 3153ee88f2..f0a9af0530 100644 --- a/intermediate_source/char_rnn_classification_tutorial.py +++ b/intermediate_source/char_rnn_classification_tutorial.py @@ -44,33 +44,47 @@ Networks `__ is about LSTMs specifically but also informative about RNNs in general +""" -Preparing the Data -================== - -.. note:: - Download the data from - `here `_ - and extract it to the current directory. - -Included in the ``data/names`` directory are 18 text files named as -``[Language].txt``. Each file contains a bunch of names, one name per -line, mostly romanized (but we still need to convert from Unicode to -ASCII). +###################################################################### +# Preparing Torch +# ========================== +# +# Set up torch to default to the right device use GPU acceleration depending on your hardware (CPU or CUDA). +# -The first thing we need to define is our data items. In this case, we will create a class called NameData -which will have an __init__ function to specify the input fields and some helper functions. 
Our first
+# helper function will be __str__ to convert objects to strings for easy printing
+#
+# There are two key pieces of this that we will flesh out over the course of this tutorial. First is the basic data
+# object which a label and some text. In this instance, label = the country of origin and text = the name.
+#
+# However, our data has some issues that we will need to clean up. First off, we need to convert Unicode to plain ASCII to
+# limit the RNN input layers. This is accomplished by converting Unicode strings to ASCII and allowing a small set of allowed characters (allowed_characters)
 
 import string
 import unicodedata
 
@@ -102,7 +116,7 @@ def unicodeToAscii(s):
 
 ######################################################################
 # Turning Names into Tensors
-# --------------------------
+# ==========================
 #
 # Now that we have all the names organized, we need to turn them into
 # Tensors to make any use of them.
@@ -119,7 +133,6 @@ def unicodeToAscii(s):
 #
 # For this, you'll need to add a couple of capabilities to our NameData object.
 
-import torch
 import string
 import unicodedata
 
@@ -157,18 +170,18 @@ def lineToTensor(line):
     return tensor
 
 #########################
-#Here are some examples of how to use the NameData object
+# Here are some examples of how to use the NameData object
 
 print (f"{NameData(label='none', text='a')}")
 print (f"{NameData(label='Korean', text='Ahn')}")
 
 #########################
-#Congratulations, you have built the foundational tensor objects for this learning task! You can use a similar approach
-#for other RNN tasks with text.
+# Congratulations, you have built the foundational tensor objects for this learning task! You can use a similar approach
+# for other RNN tasks with text.
 #
-#Next, we need to combine all our examples into a dataset so we can train, text and validate our models. For this,
-#we will use the `Dataset and DataLoader ` classes
-#to hold our dataset. Each Dataset needs to implement three functions: __init__, __len__, and __getitem__.
+# Next, we need to combine all our examples into a dataset so we can train, test, and validate our models. For this,
+# we will use the `Dataset and DataLoader ` classes
+# to hold our dataset. Each Dataset needs to implement three functions: __init__, __len__, and __getitem__.
+# Next, we need to combine all our examples into a dataset so we can train, text and validate our models. For this, +# we will use the `Dataset and DataLoader ` classes +# to hold our dataset. Each Dataset needs to implement three functions: __init__, __len__, and __getitem__. from io import open import glob @@ -219,9 +232,10 @@ def __getitem__(self, idx): ######################### #Using the dataset object allows us to easily split the data into train and test sets. Here we create a 80/20 -#split but the torch.utils.data has more useful utilities. +#split but the torch.utils.data has more useful utilities. Here we specify a generator since we need to use the +#same device as torch defaults to above. -train_set, test_set = torch.utils.data.random_split(alldata, [.8, .2]) +train_set, test_set = torch.utils.data.random_split(alldata, [.8, .2], generator=torch.Generator(device=device).manual_seed(1)) print(f"train examples = {len(train_set)}, validation examples = {len(test_set)}") @@ -448,7 +462,10 @@ def learn(self, training_data, n_epoch = 250, n_batch_size = 64, report_every = n_hidden = 128 hidden = torch.zeros(1, n_hidden) rnn = RNN(NameData.n_letters, n_hidden, alldata.labels) -all_losses = rnn.learn(train_set) +start = time.time() +all_losses = rnn.learn(train_set, n_epoch=10, learning_rate=0.2, report_every=1) +end = time.time() +print(f"training took {end-start}s") ###################################################################### # Plotting the Results @@ -495,7 +512,7 @@ def evaluate(rnn, testing_data): # Set up plot fig = plt.figure() ax = fig.add_subplot(111) - cax = ax.matshow(confusion.numpy()) + cax = ax.matshow(confusion.cpu().numpy()) #numpy uses cpu here so we need to use a cpu version fig.colorbar(cax) # Set up axes @@ -525,16 +542,16 @@ def evaluate(rnn, testing_data): # Exercises # ========= # -# - Try with a different dataset of line -> label, for example: -# -# - Any word -> language -# - First name -> gender -# - Character name -> writer -# - Page title -> blog or subreddit -# # - Get better results with a bigger and/or better shaped network # +# - Vary the hyperparameters to improve performance (e.g. 250 epochs, batch size, learning rate ) # - Add more linear layers # - Try the ``nn.LSTM`` and ``nn.GRU`` layers # - Combine multiple of these RNNs as a higher level network -# \ No newline at end of file +# +# - Try with a different dataset of line -> label, for example: +# +# - Any word -> language +# - First name -> gender +# - Character name -> writer +# - Page title -> blog or subreddit \ No newline at end of file From 10cfcaace41a883ffe910700b838053864197c3f Mon Sep 17 00:00:00 2001 From: Matthew Schultz Date: Mon, 15 Jul 2024 14:10:38 -0400 Subject: [PATCH 06/11] removing updated by --- intermediate_source/char_rnn_classification_tutorial.py | 1 - 1 file changed, 1 deletion(-) diff --git a/intermediate_source/char_rnn_classification_tutorial.py b/intermediate_source/char_rnn_classification_tutorial.py index f0a9af0530..121fb9b603 100644 --- a/intermediate_source/char_rnn_classification_tutorial.py +++ b/intermediate_source/char_rnn_classification_tutorial.py @@ -3,7 +3,6 @@ NLP From Scratch: Classifying Names with a Character-Level RNN ************************************************************** **Author**: `Sean Robertson `_ -**Updated**: `Matthew Schultz `_ We will be building and training a basic character-level Recurrent Neural Network (RNN) to classify words. 
This tutorial, along with two other From cf3de687a83b8a1601e8cb3c5db530f746363e55 Mon Sep 17 00:00:00 2001 From: Matthew Schultz Date: Fri, 13 Sep 2024 14:39:57 -0400 Subject: [PATCH 07/11] based on Joel's review of Sept 9th: removing NameData object, combining all RNN definition into one, moving RNN.learn() to separate train() --- .../char_rnn_classification_tutorial.py | 338 ++++++------------ 1 file changed, 117 insertions(+), 221 deletions(-) diff --git a/intermediate_source/char_rnn_classification_tutorial.py b/intermediate_source/char_rnn_classification_tutorial.py index 121fb9b603..1fde54fb00 100644 --- a/intermediate_source/char_rnn_classification_tutorial.py +++ b/intermediate_source/char_rnn_classification_tutorial.py @@ -74,44 +74,28 @@ # line, mostly romanized (but we still need to convert from Unicode to # ASCII). # -# The first thing we need to define is our data items. In this case, we will create a class called NameData -# which will have an __init__ function to specify the input fields and some helper functions. Our first -# helper function will be __str__ to convert objects to strings for easy printing -# -# There are two key pieces of this that we will flesh out over the course of this tutorial. First is the basic data -# object which a label and some text. In this instance, label = the country of origin and text = the name. -# -# However, our data has some issues that we will need to clean up. First off, we need to convert Unicode to plain ASCII to +# The first thing we need to define and clean our data. First off, we need to convert Unicode to plain ASCII to # limit the RNN input layers. This is accomplished by converting Unicode strings to ASCII and allowing a small set of allowed characters (allowed_characters) import string import unicodedata -class NameData: - allowed_characters = string.ascii_letters + " .,;'" - n_letters = len(allowed_characters) - +allowed_characters = string.ascii_letters + " .,;'" +n_letters = len(allowed_characters) - def __init__(self, label, text): - self.label = label - self.text = NameData.unicodeToAscii(text) - - def __str__(self): - return f"label={self.label}, text={self.text}" - - # Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427 - def unicodeToAscii(s): - return ''.join( - c for c in unicodedata.normalize('NFD', s) - if unicodedata.category(c) != 'Mn' - and c in NameData.allowed_characters - ) +# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427 +def unicodeToAscii(s): + return ''.join( + c for c in unicodedata.normalize('NFD', s) + if unicodedata.category(c) != 'Mn' + and c in allowed_characters + ) ######################### -#Now we can use that class to create a singe piece of data. +# Here's an example of converting a unicode alphabet name to plain ASCII. This simplifies the input layer # -print (f"{NameData(label='Polish', text='Ślusàrski')}") +print (f"converting 'Ślusàrski' to {unicodeToAscii('Ślusàrski')}") ###################################################################### # Turning Names into Tensors @@ -129,50 +113,24 @@ def unicodeToAscii(s): # # That extra 1 dimension is because PyTorch assumes everything is in # batches - we're just using a batch size of 1 here. -# -# For this, you'll need to add a couple of capabilities to our NameData object. -import string -import unicodedata - -class NameData: - allowed_characters = string.ascii_letters + " .,;'" - n_letters = len(allowed_characters) +# Find letter index from all_letters, e.g. 
"a" = 0 +def letterToIndex(letter): + return allowed_characters.find(letter) - - def __init__(self, label, text): - self.label = label - self.text = NameData.unicodeToAscii(text) - self.tensor = NameData.lineToTensor(self.text) - - def __str__(self): - return f"label={self.label}, text={self.text}\ntensor = {self.tensor}" - - # Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427 - def unicodeToAscii(s): - return ''.join( - c for c in unicodedata.normalize('NFD', s) - if unicodedata.category(c) != 'Mn' - and c in NameData.allowed_characters - ) - - # Find letter index from all_letters, e.g. "a" = 0 - def letterToIndex(letter): - return NameData.allowed_characters.find(letter) - - # Turn a line into a , - # or an array of one-hot letter vectors - def lineToTensor(line): - tensor = torch.zeros(len(line), 1, NameData.n_letters) - for li, letter in enumerate(line): - tensor[li][0][NameData.letterToIndex(letter)] = 1 - return tensor +# Turn a line into a , +# or an array of one-hot letter vectors +def lineToTensor(line): + tensor = torch.zeros(len(line), 1, n_letters) + for li, letter in enumerate(line): + tensor[li][0][letterToIndex(letter)] = 1 + return tensor ######################### -# Here are some examples of how to use the NameData object +# Here are some examples of how to use lineToTensor() for a single and multiple character string. -print (f"{NameData(label='none', text='a')}") -print (f"{NameData(label='Korean', text='Ahn')}") +print (f"The letter 'a' becomes {lineToTensor('a')}") #notice that the first position in the tensor = 1 +print (f"The name 'Ahn' becomes {lineToTensor('Ahn')}") #notice 'A' sets the 27th index to 1 ######################### # Congratulations, you have built the foundational tensor objects for this learning task! You can use a similar approach @@ -181,12 +139,9 @@ def lineToTensor(line): # Next, we need to combine all our examples into a dataset so we can train, text and validate our models. For this, # we will use the `Dataset and DataLoader ` classes # to hold our dataset. Each Dataset needs to implement three functions: __init__, __len__, and __getitem__. 
- from io import open import glob import os -import unicodedata -import string import time import torch @@ -200,6 +155,9 @@ def __init__(self, data_dir): labels_set = set() #set of all classes self.data = [] + self.data_tensors = [] + self.labels = [] + self.labels_tensors = [] #read all the txt files in the specified directory text_files = glob.glob(os.path.join(data_dir, '*.txt')) @@ -208,22 +166,30 @@ def __init__(self, data_dir): labels_set.add(label) lines = open(filename, encoding='utf-8').read().strip().split('\n') for name in lines: - self.data.append(NameData(label=label, text=name)) + self.data.append(name) + self.data_tensors.append(lineToTensor(name)) + self.labels.append(label) - self.labels = list(labels_set) + #Cache the tensor representation of the labels + self.labels_uniq = list(labels_set) + for idx in range(len(self.labels)): + temp_tensor = torch.tensor([self.labels_uniq.index(self.labels[idx])], dtype=torch.long) + self.labels_tensors.append(temp_tensor) def __len__(self): return len(self.data) def __getitem__(self, idx): data_item = self.data[idx] - label_tensor = torch.tensor([self.labels.index(data_item.label)], dtype=torch.long) - return label_tensor, data_item.tensor, data_item.label, data_item.text - + data_label = self.labels[idx] + data_tensor = self.data_tensors[idx] + label_tensor = self.labels_tensors[idx] + + return label_tensor, data_tensor, data_label, data_item -######################### -#Here are some examples of how to use the NamesDataset object +######################### +#Here we can load our example data into the NamesDataset alldata = NamesDataset("data/names") print(f"loaded {len(alldata)} items of data") @@ -257,51 +223,8 @@ def __getitem__(self, idx): # which operate on an input and hidden state, with a ``LogSoftmax`` layer # after the output.s # - -import torch.nn as nn -import torch.nn.functional as F - -class RNN(nn.Module): - def __init__(self, input_size, hidden_size, output_labels): - super(RNN, self).__init__() - - self.hidden_size = hidden_size - self.output_labels = output_labels - - self.i2h = nn.Linear(input_size, hidden_size) - self.h2h = nn.Linear(hidden_size, hidden_size) - self.h2o = nn.Linear(hidden_size, len(output_labels)) - self.softmax = nn.LogSoftmax(dim=1) - - def forward(self, input, hidden): - hidden = F.tanh(self.i2h(input) + self.h2h(hidden)) - output = self.h2o(hidden) - output = self.softmax(output) - return output, hidden - -########################### -#We can then create a RNN with 128 hidden nodes and given our datasets - - -n_hidden = 128 -rnn = RNN(NameData.n_letters, n_hidden, alldata.labels) -print(rnn) - -###################################################################### -# To run a step of this network we need to pass a single character input -# and a hidden state (which we initialize as zeros at first). We'll get to -# multi-character names next - -input = NameData(label='none', text='A').tensor -output, next_hidden = rnn(input[0], torch.zeros(1, n_hidden)) -print(output) - -###################################################################### -# Scoring Multi-character names -# -------------------- -# Multi-character names require just a little bit more effort which is -# keeping track of the hidden output and passing it back into the RNN. -# You can see this updated work defined in the function forward() +# forward() loops through each of the characters in the given tensor, computes each +# layer and then passes the hidden layer onto to the next iteration. 
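+#
+# In equation form (our notation for the layers defined just below): for each character
+# one-hot vector x_t, the hidden state is updated as
+#
+#   h_t = tanh(i2h(x_t) + h2h(h_{t-1})),    output = LogSoftmax(h2o(h_T))
+#
+# where h_0 is all zeros and h_T is the hidden state after the last character.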
import torch.nn as nn import torch.nn.functional as F @@ -330,23 +253,27 @@ def forward(self, line_tensor): return output - def label_from_output(self, output): - top_n, top_i = output.topk(1) - label_i = top_i[0].item() - return self.output_labels[label_i], label_i - ########################### -#Now we can score the output for names! - +#We can then create a RNN with 128 hidden nodes and given our datasets n_hidden = 128 -rnn = RNN(NameData.n_letters, n_hidden, alldata.labels) +rnn = RNN(n_letters, n_hidden, alldata.labels_uniq) +print(rnn) + +###################################################################### +# We can then pass our Tensor to the runn to get a predicted output and +# use a helper function, label_from_output, to get a text label for the class. + +def label_from_output(output, output_labels): + top_n, top_i = output.topk(1) + label_i = top_i[0].item() + return output_labels[label_i], label_i -input = NameData(label='none', text='Albert').tensor +input = lineToTensor('Albert') output = rnn(input) #this is equivalent to output = rnn.forward(input) print(output) -print(rnn.label_from_output(output)) +print(label_from_output(output, alldata.labels_uniq)) ###################################################################### # @@ -374,95 +301,61 @@ def label_from_output(self, output): # - Back-propagate # - Return the output and loss # -# We also define a learn() function which trains on a given dataset with minibatches +# We do this by defining a learn() function which trains on a given dataset with minibatches -import torch.nn as nn -import torch.nn.functional as F import random import numpy as np -class RNN(nn.Module): - def __init__(self, input_size, hidden_size, output_labels): - super(RNN, self).__init__() - - self.hidden_size = hidden_size - self.output_labels = output_labels - - self.i2h = nn.Linear(input_size, hidden_size) - self.h2h = nn.Linear(hidden_size, hidden_size) - self.h2o = nn.Linear(hidden_size, len(output_labels)) - self.softmax = nn.LogSoftmax(dim=1) - - def forward(self, line_tensor): - hidden = torch.zeros(1, rnn.hidden_size) - output = torch.zeros(1, len(self.output_labels)) - - for i in range(line_tensor.size()[0]): - input = line_tensor[i] - hidden = F.tanh(self.i2h(input) + self.h2h(hidden)) - output = self.h2o(hidden) - output = self.softmax(output) - - return output - - def label_from_output(self, output): - top_n, top_i = output.topk(1) - label_i = top_i[0].item() - return self.output_labels[label_i], label_i - - def learn(self, training_data, n_epoch = 250, n_batch_size = 64, report_every = 50, learning_rate = 0.005, criterion = nn.NLLLoss()): - """ - Learn on a batch of training_data for a specified number of iterations and reporting thresholds - """ - # Keep track of losses for plotting - current_loss = 0 - all_losses = [] - self.train() - optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate) - - start = time.time() - print(f"training on data set with n = {len(training_data)}") - - for iter in range(1, n_epoch + 1): - self.zero_grad() # clear the gradients - - # create some minibatches - # we cannot use dataloaders because each of our names is a different length - batches = list(range(len(training_data))) - random.shuffle(batches) - batches = np.array_split(batches, len(batches) //n_batch_size ) - - for idx, batch in enumerate(batches): - batch_loss = 0 - for i in batch: #for each example in this batch - (label_tensor, text_tensor, label, text) = training_data[i] - output = self.forward(text_tensor) - loss = criterion(output, 
label_tensor) - batch_loss += loss - - # optimize parameters - batch_loss.backward() - nn.utils.clip_grad_norm_(self.parameters(), 3) - optimizer.step() - optimizer.zero_grad() - - current_loss += batch_loss.item() / len(batch) - - all_losses.append(current_loss / len(batches) ) - if iter % report_every == 0: - print(f"{iter} ({iter / n_epoch:.0%}): \t average batch loss = {all_losses[-1]}") - current_loss = 0 +def train(rnn, training_data, n_epoch = 250, n_batch_size = 64, report_every = 50, learning_rate = 0.005, criterion = nn.NLLLoss()): + """ + Learn on a batch of training_data for a specified number of iterations and reporting thresholds + """ + # Keep track of losses for plotting + current_loss = 0 + all_losses = [] + rnn.train() + optimizer = torch.optim.SGD(rnn.parameters(), lr=learning_rate) + + start = time.time() + print(f"training on data set with n = {len(training_data)}") + + for iter in range(1, n_epoch + 1): + rnn.zero_grad() # clear the gradients + + # create some minibatches + # we cannot use dataloaders because each of our names is a different length + batches = list(range(len(training_data))) + random.shuffle(batches) + batches = np.array_split(batches, len(batches) //n_batch_size ) + + for idx, batch in enumerate(batches): + batch_loss = 0 + for i in batch: #for each example in this batch + (label_tensor, text_tensor, label, text) = training_data[i] + output = rnn.forward(text_tensor) + loss = criterion(output, label_tensor) + batch_loss += loss + + # optimize parameters + batch_loss.backward() + nn.utils.clip_grad_norm_(rnn.parameters(), 3) + optimizer.step() + optimizer.zero_grad() + + current_loss += batch_loss.item() / len(batch) - return all_losses + all_losses.append(current_loss / len(batches) ) + if iter % report_every == 0: + print(f"{iter} ({iter / n_epoch:.0%}): \t average batch loss = {all_losses[-1]}") + current_loss = 0 + + return all_losses ########################################################################## # We can now train a dataset with mini batches for a specified number of epochs -n_hidden = 128 -hidden = torch.zeros(1, n_hidden) -rnn = RNN(NameData.n_letters, n_hidden, alldata.labels) start = time.time() -all_losses = rnn.learn(train_set, n_epoch=10, learning_rate=0.2, report_every=1) +all_losses = train(rnn, train_set, n_epoch=10, learning_rate=0.2, report_every=1) end = time.time() print(f"training took {end-start}s") @@ -492,21 +385,23 @@ def learn(self, training_data, n_epoch = 250, n_batch_size = 64, report_every = # ``evaluate()``, which is the same as ``train()`` minus the backprop. 
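+# (A note we add for clarity: because ``evaluate()`` row-normalizes the confusion matrix,
+# each cell approximates a conditional probability,
+#
+#   confusion[i][j] ≈ P(guess = class j | true label = class i),
+#
+# so a perfect classifier would light up only the diagonal.)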
#

-def evaluate(rnn, testing_data):
-    confusion = torch.zeros(len(rnn.output_labels), len(rnn.output_labels))
+def evaluate(rnn, testing_data, classes):
+    confusion = torch.zeros(len(classes), len(classes))

     rnn.eval() #set to eval mode
     with torch.no_grad(): # do not record the gradients during eval phase
         for i in range(len(testing_data)):
             (label_tensor, text_tensor, label, text) = testing_data[i]

-            output = rnn.forward(text_tensor)
-            guess, guess_i = rnn.label_from_output(output)
-            label_i = rnn.output_labels.index(label)
+            output = rnn(text_tensor)
+            guess, guess_i = label_from_output(output, classes)
+            label_i = classes.index(label)
             confusion[label_i][guess_i] += 1

     # Normalize by dividing every row by its sum
-    for i in range(len(rnn.output_labels)):
-        confusion[i] = confusion[i] / confusion[i].sum()
+    for i in range(len(classes)):
+        denom = confusion[i].sum()
+        if denom > 0:
+            confusion[i] = confusion[i] / denom

     # Set up plot
     fig = plt.figure()
@@ -515,8 +410,8 @@ def evaluate(rnn, testing_data):
     fig.colorbar(cax)

     # Set up axes
-    ax.set_xticklabels([''] + rnn.output_labels, rotation=90)
-    ax.set_yticklabels([''] + rnn.output_labels)
+    ax.set_xticklabels([''] + classes, rotation=90)
+    ax.set_yticklabels([''] + classes)

     # Force label at every tick
     ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
@@ -526,8 +421,9 @@ def evaluate(rnn, testing_data):

     plt.show()

-evaluate(rnn, test_set)
+evaluate(rnn, test_set, classes=alldata.labels_uniq)
+

 ######################################################################
 # You can pick out bright spots off the main axis that show which

From 3d2bb5721d28ea27738b4172431b468b14f80c1a Mon Sep 17 00:00:00 2001
From: Matthew Schultz
Date: Fri, 13 Sep 2024 15:04:30 -0400
Subject: [PATCH 08/11] fixing training description to focus on the single
 train() function rather than building it up

---
 .../char_rnn_classification_tutorial.py | 20 +++++--------------
 1 file changed, 5 insertions(+), 15 deletions(-)

diff --git a/intermediate_source/char_rnn_classification_tutorial.py b/intermediate_source/char_rnn_classification_tutorial.py
index 1fde54fb00..de267d2761 100644
--- a/intermediate_source/char_rnn_classification_tutorial.py
+++ b/intermediate_source/char_rnn_classification_tutorial.py
@@ -287,26 +287,16 @@ def label_from_output(output, output_labels):
 #
 # Now all it takes to train this network is show it a bunch of examples,
 # have it make guesses, and tell it if it's wrong.
-#
-# We start by defining a function learn_single() which learns from a single
-# piece of input data.
-#
-# - Create input and target tensors
-# - Create a zeroed initial hidden state
-# - Read each letter in and
-#
-#   - Keep hidden state for next letter
-#
-# - Compare final output to target
-# - Back-propagate
-# - Return the output and loss
 #
-# We do this by defining a learn() function which trains on a given dataset with minibatches
+# We do this by defining a train() function which trains on a given dataset with minibatches. RNNs
+# train similarly to other networks, so for completeness we include a batched training method here.
+# The loop (for i in batch) computes the losses for each of the items in the batch before adjusting the
+# weights. This is repeated until the number of epochs is reached. 
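+#
+# To make the batching concrete (an illustrative sketch of the ``np.array_split`` call used
+# below): splitting 10 shuffled indices into 10 // 3 = 3 batches yields index groups of
+# sizes 4, 3, and 3:
+#
+#   np.array_split(list(range(10)), 10 // 3)
+#   # -> [array([0, 1, 2, 3]), array([4, 5, 6]), array([7, 8, 9])]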
 import random
 import numpy as np

-def train(rnn, training_data, n_epoch = 250, n_batch_size = 64, report_every = 50, learning_rate = 0.005, criterion = nn.NLLLoss()):
+def train(rnn, training_data, n_epoch = 10, n_batch_size = 64, report_every = 50, learning_rate = 0.2, criterion = nn.NLLLoss()):
     """
     Learn on a batch of training_data for a specified number of iterations and reporting thresholds
     """

From e08b4e378022c3b7c3ad726e7f5bc1b46fd2125b Mon Sep 17 00:00:00 2001
From: Matthew Schultz
Date: Mon, 16 Sep 2024 21:33:37 -0400
Subject: [PATCH 09/11] updating tutorial to use nn.rnn in composition

---
 en-wordlist.txt | 4 ++
 .../char_rnn_classification_tutorial.py | 49 ++++++++-----------
 2 files changed, 24 insertions(+), 29 deletions(-)

diff --git a/en-wordlist.txt b/en-wordlist.txt
index cc687c3926..580439c06e 100644
--- a/en-wordlist.txt
+++ b/en-wordlist.txt
@@ -33,6 +33,7 @@ Captum
 Captum's
 CartPole
 Cayley
+CharRNN
 Chatbots
 Chen
 Colab
@@ -421,12 +422,14 @@ jpg
 json
 judgements
 jupyter
+kernels
 keypoint
 kwargs
 labelled
 latencies
 learnable
 learnings
+lineToTensor
 linearities
 loadFilename
 logits
@@ -460,6 +463,7 @@ namespace
 natively
 ndarrays
 nightlies
+nn
 num
 numericalize
 numpy
diff --git a/intermediate_source/char_rnn_classification_tutorial.py b/intermediate_source/char_rnn_classification_tutorial.py
index de267d2761..2796c8b0c2 100644
--- a/intermediate_source/char_rnn_classification_tutorial.py
+++ b/intermediate_source/char_rnn_classification_tutorial.py
@@ -218,51 +218,42 @@ def __getitem__(self, idx):
 # held hidden state and gradients which are now entirely handled by the
 # graph itself. This means you can implement a RNN in a very "pure" way,
 # as regular feed-forward layers.
+#
+# This CharRNN class implements an RNN with three components.
+# First, we use the `nn.RNN implemnentation `__
+# , next we define a layer that maps the RNN hidden layers to our output, and finally we apply a softmax. Using nn.RNN
+# leads to a significant improvement in performance (e.g. cuDNN-accelerated kernels) versus implementing
+# each layer as a nn.Linear. It also simplifies the implementation in forward().
 #
-# This RNN module implements a "vanilla RNN" an is just 3 linear layers
-# which operate on an input and hidden state, with a ``LogSoftmax`` layer
-# after the output.s
-#
-# forward() loops through each of the characters in the given tensor, computes each
-# layer and then passes the hidden layer onto to the next iteration. 
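+#
+# As a quick shape sketch (our addition for illustration): ``nn.RNN`` consumes the whole
+# one-hot sequence at once and returns both the per-step outputs and the final hidden
+# state, so no explicit Python loop over characters is needed:
+#
+#   rnn_layer = nn.RNN(57, 128)                     # input_size = n_letters, hidden_size = 128
+#   out, hidden = rnn_layer(torch.zeros(5, 1, 57))  # out: (5, 1, 128), hidden: (1, 1, 128)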
import torch.nn as nn import torch.nn.functional as F -class RNN(nn.Module): - def __init__(self, input_size, hidden_size, output_labels): - super(RNN, self).__init__() +class CharRNN(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super(CharRNN, self).__init__() - self.hidden_size = hidden_size - self.output_labels = output_labels - - self.i2h = nn.Linear(input_size, hidden_size) - self.h2h = nn.Linear(hidden_size, hidden_size) - self.h2o = nn.Linear(hidden_size, len(output_labels)) + self.rnn = nn.RNN(input_size, hidden_size) + self.h2o = nn.Linear(hidden_size, output_size) self.softmax = nn.LogSoftmax(dim=1) def forward(self, line_tensor): - hidden = torch.zeros(1, rnn.hidden_size) - output = torch.zeros(1, len(self.output_labels)) - - for i in range(line_tensor.size()[0]): - input = line_tensor[i] - hidden = F.tanh(self.i2h(input) + self.h2h(hidden)) - output = self.h2o(hidden) - output = self.softmax(output) + rnn_out, hidden = self.rnn(line_tensor) + output = self.h2o(hidden[0]) + output = self.softmax(output) return output ########################### -#We can then create a RNN with 128 hidden nodes and given our datasets +#We can then create a RNN with 57 input nodes, 128 hidden nodes and 18 outputs. n_hidden = 128 -rnn = RNN(n_letters, n_hidden, alldata.labels_uniq) +rnn = CharRNN(n_letters, n_hidden, len(alldata.labels_uniq)) print(rnn) ###################################################################### -# We can then pass our Tensor to the runn to get a predicted output and +# We can then pass our Tensor to the RNN to get a predicted output and # use a helper function, label_from_output, to get a text label for the class. def label_from_output(output, output_labels): @@ -345,7 +336,7 @@ def train(rnn, training_data, n_epoch = 10, n_batch_size = 64, report_every = 50 # We can now train a dataset with mini batches for a specified number of epochs start = time.time() -all_losses = train(rnn, train_set, n_epoch=10, learning_rate=0.2, report_every=1) +all_losses = train(rnn, train_set, n_epoch=13, learning_rate=0.2, report_every=1) end = time.time() print(f"training took {end-start}s") @@ -429,9 +420,9 @@ def evaluate(rnn, testing_data, classes): # # - Get better results with a bigger and/or better shaped network # -# - Vary the hyperparameters to improve performance (e.g. 250 epochs, batch size, learning rate ) -# - Add more linear layers +# - Vary the hyperparameters to improve performance (e.g. change epochs, batch size, learning rate ) # - Try the ``nn.LSTM`` and ``nn.GRU`` layers +# - Change the size of the layers (e.g. 
fewer or more hidden nodes, additional linear layers)
 # - Combine multiple of these RNNs as a higher level network
 #
 # - Try with a different dataset of line -> label, for example:

From 8258deb75c3dd26a9881be2e2acc7a28a4e2878d Mon Sep 17 00:00:00 2001
From: Matthew Schultz
Date: Mon, 23 Sep 2024 20:32:32 -0400
Subject: [PATCH 10/11] Update intermediate_source/char_rnn_classification_tutorial.py

Co-authored-by: Joel Schlosser <75754324+jbschlosser@users.noreply.github.com>
---
 intermediate_source/char_rnn_classification_tutorial.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/intermediate_source/char_rnn_classification_tutorial.py b/intermediate_source/char_rnn_classification_tutorial.py
index 2796c8b0c2..95c36baeeb 100644
--- a/intermediate_source/char_rnn_classification_tutorial.py
+++ b/intermediate_source/char_rnn_classification_tutorial.py
@@ -220,7 +220,7 @@ def __getitem__(self, idx):
 # as regular feed-forward layers.
 #
 # This CharRNN class implements an RNN with three components.
-# First, we use the `nn.RNN implemnentation `__
+# First, we use the `nn.RNN implementation <https://pytorch.org/docs/stable/generated/torch.nn.RNN.html>`__
 # , next we define a layer that maps the RNN hidden layers to our output, and finally we apply a softmax. Using nn.RNN
 # leads to a significant improvement in performance (e.g. cuDNN-accelerated kernels) versus implementing
 # each layer as a nn.Linear. It also simplifies the implementation in forward().

From 3f11bc128f253996bb6d4b051a5b84806afc9b2f Mon Sep 17 00:00:00 2001
From: Matthew Schultz
Date: Tue, 24 Sep 2024 21:38:21 -0400
Subject: [PATCH 11/11] tuning the results to show more of a diagonal on
 confusion matrix.. Changed epochs, training rate, more of split to training
 data

---
 intermediate_source/char_rnn_classification_tutorial.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/intermediate_source/char_rnn_classification_tutorial.py b/intermediate_source/char_rnn_classification_tutorial.py
index 95c36baeeb..f62124dbb0 100644
--- a/intermediate_source/char_rnn_classification_tutorial.py
+++ b/intermediate_source/char_rnn_classification_tutorial.py
@@ -200,7 +200,7 @@ def __getitem__(self, idx):
 #split but the torch.utils.data has more useful utilities. Here we specify a generator since we need to use the
 #same device as torch defaults to above.

-train_set, test_set = torch.utils.data.random_split(alldata, [.8, .2], generator=torch.Generator(device=device).manual_seed(1))
+train_set, test_set = torch.utils.data.random_split(alldata, [.85, .15], generator=torch.Generator(device=device).manual_seed(2024))

 print(f"train examples = {len(train_set)}, validation examples = {len(test_set)}")

@@ -336,7 +336,7 @@ def train(rnn, training_data, n_epoch = 10, n_batch_size = 64, report_every = 50
 # We can now train a dataset with mini batches for a specified number of epochs

 start = time.time()
-all_losses = train(rnn, train_set, n_epoch=13, learning_rate=0.2, report_every=1)
+all_losses = train(rnn, train_set, n_epoch=55, learning_rate=0.15, report_every=5)
 end = time.time()

 print(f"training took {end-start}s")
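+######################################################################
+# As a final sanity check (a sketch we suggest appending; ``rnn``, ``lineToTensor``,
+# ``unicodeToAscii``, ``label_from_output`` and ``alldata`` are the objects defined
+# earlier in this tutorial), the trained model can classify a raw name directly:
+
+rnn.eval()
+with torch.no_grad():
+    output = rnn(lineToTensor(unicodeToAscii('Ślusàrski')))
+    print(label_from_output(output, alldata.labels_uniq))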