Commit
1 parent d83d78a · commit 1432715
Showing 6 changed files with 152 additions and 9 deletions.
@@ -0,0 +1 @@
env/
Binary file not shown.
@@ -0,0 +1,26 @@
import numpy as np
import matplotlib.pyplot as plt

# epoch
epoch = np.array([0, 2, 5, 7, 10])

# LR = 0.001
avg_accu1 = np.array([0.2682, 0.9494, 0.9514, 0.9624, 0.968])

# LR = 0.009
avg_accu2 = np.array([0.2812, 0.8972, 0.906, 0.8938, 0.9092])

# LR = 1
avg_accu3 = np.array([0.1008, 0.1038, 0.0958, 0.1068, 0.0976])

# No of Epoch vs Avg Accuracy
plt.plot(epoch, avg_accu1, marker='x', label="LR=0.001")
plt.plot(epoch, avg_accu2, marker='x', label="LR=0.009")
plt.plot(epoch, avg_accu3, marker='x', label="LR=1")

plt.legend()
plt.title("No of Epoch vs Avg Accuracy")
plt.xlabel("No of Epoch")
plt.ylabel("Avg Accuracy")

plt.show()
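As a side note, the same figure can be driven from a single mapping of run label to accuracies, so adding another learning-rate sweep only needs one more dictionary entry. The following is a minimal sketch, not part of the commit; it reuses the numbers from the diff above.

# Sketch only (not in the commit): same plot, driven by a label -> accuracies dict.
import numpy as np
import matplotlib.pyplot as plt

epoch = np.array([0, 2, 5, 7, 10])
runs = {
    "LR=0.001": [0.2682, 0.9494, 0.9514, 0.9624, 0.968],
    "LR=0.009": [0.2812, 0.8972, 0.906, 0.8938, 0.9092],
    "LR=1":     [0.1008, 0.1038, 0.0958, 0.1068, 0.0976],
}

for label, accuracy in runs.items():
    plt.plot(epoch, accuracy, marker='x', label=label)

plt.legend()
plt.title("No of Epoch vs Avg Accuracy")
plt.xlabel("No of Epoch")
plt.ylabel("Avg Accuracy")
plt.show()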
@@ -0,0 +1,76 @@
import os
import torch as t
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt

print(t.__version__)

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# the data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = t.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = t.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# len(mnist_trainset)
# len(mnist_testset)

# the model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.linear1 = nn.Linear(28*28, 100)  # input layer
        self.linear2 = nn.Linear(100, 50)     # hidden layer
        self.final = nn.Linear(50, 10)        # output layer
        self.relu = nn.ReLU()                 # piecewise linear activation

    # flatten + forward pass
    def forward(self, img):
        x = img.view(-1, 28*28)  # reshape the image batch for the model
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.final(x)
        return x

net = Net()

# loss function and optimiser
cross_en_loss = nn.CrossEntropyLoss()
optimiser = t.optim.Adam(net.parameters(), lr=1)  # learning rate for this run (0.001, 0.009 and 1 were compared)
epochs = 10

for epoch in range(epochs):
    net.train()

    for data in train_loader:
        x, y = data                       # x = features, y = targets
        optimiser.zero_grad()             # reset gradients before each loss calculation
        output = net(x.view(-1, 28*28))   # pass in the reshaped batch
        loss = cross_en_loss(output, y)   # compute the loss
        loss.backward()                   # backpropagate through the network's parameters
        optimiser.step()                  # update the weights

# evaluating on the test set
net.eval()
correct = 0
total = 0
with t.no_grad():
    for data in test_loader:
        x, y = data
        output = net(x.view(-1, 784))
        for idx, i in enumerate(output):
            if t.argmax(i) == y[idx]:
                correct += 1
            total += 1
print(f"accuracy: {round(correct / total, 3)}")

# visualization
plt.imshow(x[3].view(28, 28))
plt.show()
print(t.argmax(net(x[3].view(-1, 784))[0]))
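The accuracy arrays hard-coded in the plotting diff above look like test accuracies sampled at epochs 0, 2, 5, 7 and 10. Below is a minimal sketch of how such values could be logged in a single run; it is not part of the commit, it assumes a freshly initialised `net` plus the `optimiser`, `cross_en_loss` and data loaders defined as in the training script, and the helper name `evaluate` is illustrative.

# Sketch only (not in the commit): log test accuracy at the epochs used in the plot.
def evaluate(model, loader):
    model.eval()
    correct, total = 0, 0
    with t.no_grad():
        for x, y in loader:
            preds = t.argmax(model(x.view(-1, 784)), dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total

checkpoints = [0, 2, 5, 7, 10]
accuracy_log = {0: evaluate(net, test_loader)}   # accuracy of the untrained model
for ep in range(1, max(checkpoints) + 1):
    net.train()
    for x, y in train_loader:
        optimiser.zero_grad()
        loss = cross_en_loss(net(x.view(-1, 784)), y)
        loss.backward()
        optimiser.step()
    if ep in checkpoints:
        accuracy_log[ep] = evaluate(net, test_loader)

print(accuracy_log)   # these values could then be fed into the plotting script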
@@ -0,0 +1,28 @@
certifi==2022.12.7
charset-normalizer==2.1.1
contourpy==1.1.0
cycler==0.11.0
filelock==3.9.0
fonttools==4.41.1
idna==3.4
importlib-resources==6.0.0
Jinja2==3.1.2
kiwisolver==1.4.4
MarkupSafe==2.1.2
matplotlib==3.7.2
mpmath==1.2.1
networkx==3.0
numpy==1.24.1
packaging==23.1
Pillow==9.3.0
pyparsing==3.0.9
python-dateutil==2.8.2
requests==2.28.1
six==1.16.0
sympy==1.11.1
torch==2.0.1+cpu
torchaudio==2.0.2+cpu
torchvision==0.15.2+cpu
typing-extensions==4.4.0
urllib3==1.26.13
zipp==3.16.2
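One practical note, offered as an assumption rather than something stated in the commit: the `+cpu` pins for torch, torchaudio and torchvision usually resolve from PyTorch's CPU wheel index rather than from PyPI, so installing this file typically needs something like `pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cpu`.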