Skip to content

Commit

Permalink
new changes
Browse files Browse the repository at this point in the history
  • Loading branch information
sc0rp10n-py committed Jul 22, 2023
1 parent d83d78a commit 1432715
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
env/
Binary file added 3 Layered NN task by Pragyan Yadav.docx
Binary file not shown.
26 changes: 26 additions & 0 deletions graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import numpy as np
import matplotlib.pyplot as plt

# epoch
epoch = np.array([0, 2, 5, 7, 10])

# LR = 0.001
avg_accu1 = np.array([0.2682, 0.9494, 0.9514, 0.9624, 0.968])

# LR = 0.009
avg_accu2 = np.array([0.2812, 0.8972, 0.906, 0.8938, 0.9092])

# LR = 1
avg_accu3 = np.array([0.1008, 0.1038, 0.0958, 0.1068, 0.0976])

# No of Epoch vs Avg Accuracy
plt.plot(epoch, avg_accu1, marker = 'x', label="LR=0.001")
plt.plot(epoch, avg_accu2, marker = 'x', label="LR=0.009")
plt.plot(epoch, avg_accu3, marker = 'x', label="LR=1")

plt.legend()
plt.title("No of Epoch vs Avg Accuracy")
plt.xlabel("No of Epoch")
plt.ylabel("Avg Accuracy")

plt.show()
30 changes: 21 additions & 9 deletions main.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"id": "ab9744fb",
"metadata": {},
"outputs": [],
Expand All @@ -17,15 +17,15 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"id": "74c30637",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.8.1+cu111\n"
"2.0.1+cpu\n"
]
}
],
Expand All @@ -35,7 +35,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "9349edc3",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -205,7 +205,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 5,
"id": "5175db9f",
"metadata": {},
"outputs": [],
Expand All @@ -232,10 +232,22 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 6,
"id": "c7060f5a",
"metadata": {},
"outputs": [],
"outputs": [
{
"ename": "NameError",
"evalue": "name 'train_loader' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[6], line 9\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[39mfor\u001b[39;00m epoch \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(epoch):\n\u001b[1;32m 7\u001b[0m net\u001b[39m.\u001b[39mtrain()\n\u001b[0;32m----> 9\u001b[0m \u001b[39mfor\u001b[39;00m data \u001b[39min\u001b[39;00m train_loader:\n\u001b[1;32m 10\u001b[0m x, y \u001b[39m=\u001b[39m data \u001b[39m# x=features, y=targets\u001b[39;00m\n\u001b[1;32m 11\u001b[0m optimiser\u001b[39m.\u001b[39mzero_grad() \u001b[39m# set gradient to 0 before each loss calc\u001b[39;00m\n",
"\u001b[0;31mNameError\u001b[0m: name 'train_loader' is not defined"
]
}
],
"source": [
"# loss function\n",
"cross_en_loss = nn.CrossEntropyLoss()\n",
Expand Down Expand Up @@ -291,7 +303,7 @@
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAANx0lEQVR4nO3df6zV9X3H8ddLflnxR6AWvAU2W8RO12R0u+qM22Jjx5CYgax2ZVnDFiPtWrM267KZLpluWRqytroua+1QGNS1dmatEyvrSkgz49o4rw4FhgpzgAgBLVmBTfn53h/3a3OL93zO5XzPL3g/H8nJOef7Pt/zfecLr/s953zO93wcEQJw9jun1w0A6A7CDiRB2IEkCDuQBGEHkhjfzY1N9KQ4V5O7uUkglTf0vzoaRzxarVbYbc+X9EVJ4yTdHxHLS48/V5N1jW+os0kABU/Ghoa1ll/G2x4n6UuSbpR0paQltq9s9fkAdFad9+xXS9oeES9FxFFJ35C0sD1tAWi3OmGfIenlEfd3V8t+gu1ltodsDx3TkRqbA1BHnbCP9iHAW757GxErImIwIgYnaFKNzQGoo07Yd0uaNeL+TEl76rUDoFPqhP0pSXNsv8v2REkflrS2PW0BaLeWh94i4rjt2yX9i4aH3lZFxJa2dQagrWqNs0fEOknr2tQLgA7i67JAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJNHVKZuBbtr559c2rD229HPFdW//9duK9ZPPbm2pp17iyA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTDOjrPWB258pmHtPJfXPXHexGK9yep9qVbYbe+QdEjSCUnHI2KwHU0BaL92HNnfHxGvteF5AHQQ79mBJOqGPSR91/bTtpeN9gDby2wP2R46piM1NwegVXVfxl8XEXtsT5O03vbzEfH4yAdExApJKyTpQk+NmtsD0KJaR/aI2FNd75f0sKSr29EUgPZrOey2J9u+4M3bkuZJ2tyuxgC0V52X8dMlPWz7zef5ekR8py1dAWPw+qLyC8nlA3/dsPb1Q5cX1/UPnm2pp37Wctgj4iVJP9fGXgB0EENvQBKEHUiCsANJEHYgCcIOJMEprjhjvTq3/N93kic0rP3DK+UTNMdrV0s99TOO7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOPs6FvnvPdnivUv/vZ9LT/3a4/OLNYvYZwdwJmKsANJEHYgCcIOJEHYgSQIO5AEYQeSYJz9DHD4lmuK9YvWP9+wduJ/ftTudrpm7/unFuvzzjtWrK8+ONCwNuPvG+8zaXha4rMNR3YgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIJx9j7w4t9eVaw/f9PfFOs/+6+3NazN/q2NrbTUFeOumFOsf/zj/1Ssn4iTxfo9X/lgw9olP/x+cd2zUdMju+1Vtvfb3jxi2VTb621vq66ndLZNAHWN5WX8aknzT1l2h6QNETFH0obqPoA+1jTsEfG4pAOnLF4oaU11e42kRe1tC0C7tfoB3fSI2CtJ1fW0Rg+0vcz2kO2hYzrS4uYA1NXxT+MjYkVEDEbE4ARN6vTmADTQatj32R6QpOp6f/taAtAJrYZ9raSl1e2lkh5pTzsAOqXpOLvtByVdL+li27sl3SlpuaSHbN8qaZekWzrZZL/z+PJufPGvfqFY337TvcX6/hPlzzre/s9vK9b71Y4PvqNY/90LXy7WVx+cUay/c+WmhrXyCP3ZqWnYI2JJg9INbe4FQAfxdVkgCcIOJEHYgSQIO5AEYQeS4BTXMSoNr714T3lo7YWbv9zs2YvVBcv/qFif9kD/nq559NcGG9YWLX6i1nP/xb/dVKxffmio1vOfbTiyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASjLOP0aHFjceLX1j8pVrPfdnajxXr71n5dLEetbbeWftue6Nh7c+m/Udx3YcON/y1M0nSFct/WKyfjdMu18GRHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYJy9cs7cK4v1r37u843X1XnFdTe8Xp4J5/Lf+/divZ/H0V9feHWx/uhVdzesNdtvn139m8X6zO39ex5/P+LIDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM5eeWN6ecz3p8Y3nhb5ZJOR8G1HLynWX/y78u/ON/POdY3/GSce7OxZ3Vf9afm32Uv77Te231hcd9bdZ+55/P2o6ZHd9irb+21vHrHsLtuv2N5YXRZ0tk0AdY3lZfxqSfNHWX5PRMytLuva2xaAdmsa9oh4XNKBLvQCoIPqfEB3u+3nqpf5Uxo9yPYy20O2h47pSI3NAaij1bDfK2m2pLmS9kr6QqMHRsSKiBiMiMEJKp8QAqBzWgp7ROyLiBMRcVLSfZLKpz4B6LmWwm57YMTdmyVtbvRYAP2h6Ti77QclXS/pYtu7Jd0p6XrbczU81LlD0kc712J3jP+/8nj04ZONP284/5zy25NlF+0o1j82775ivdk4vuaVy3Wc02Tu+Ga9/ehk49+N3/GPs4vrTj/C+ert1DTsEbFklMUrO9ALgA7i67JAEoQdSIKwA0kQdiAJwg4k4YjunSh4oafGNb6ha9trpzduavy9oV2LT3axk9N0tPz3fMrGccX6st9fW6zfetGuYn3+1psb1sZ/oLwuTt+TsUEH48Co46Uc2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCX5KeozO/XbjaZUv/3YXG+myd//h/lrrv/bozIa1S8Q4ezdxZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnT+6/P3ttsT7vvI3F+pajR4v1GY/tbVjr7GTSOBVHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH25C67dmexfiLKv4m/+IE/KNYv3f6D0+4JndH0yG57lu3v2d5qe4vtT1bLp9peb3tbdT2l8+0CaNVYXsYfl/TpiLhC0i9K+oTtKyXdIWlDRMyRtKG6D6BPNQ17ROyNiGeq24ckbZU0Q9JCSWuqh62RtKhDPQJog9P6gM72pZLeJ+lJSdMjYq80/AdB0rQG6yyzPWR76JiO1GwXQKvGHHbb50v6pqRPRcTBsa4XESsiYjAiBidoUis9AmiDMYXd9gQNB/1rEfGtavE+2wNVfUBSvZ8hBdBRTYfebFvSSklbI+LuEaW1kpZKWl5dP9KRDlHLuPdcVqx/ZfbqYv3OV8unwF52/+5i/Xixim4ayzj7dZI+ImmT7Y3Vss9oOOQP2b5V0i5Jt3SkQwBt0TTsEfGEpFEnd5d0Q3vbAdApfF0WSIKwA0kQdiAJwg4kQdiBJDjF9Sy340Ojfov5xwbGva1Yf+z+Xy7Wp+38/mn3hN7gyA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTDOfpa7YGfUWn/gO42nXJaYdvlMwpEdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JwRL1x2NNxoafGNeYHaYFOeTI26GAcGPXXoDmyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASTcNue5bt79neanuL7U9Wy++y/YrtjdVlQefbBdCqsfx4xXFJn46IZ2xfIOlp2+ur2j0R8fnOtQegXcYyP/teSXur24dsb5U0o9ONAWiv03rPbvtSSe+T9GS16Hbbz9leZXtKg3WW2R6yPXRMR+p1C6BlYw677fMlfVPSpyLioKR7Jc2WNFfDR/4vjLZeRKyIiMGIGJygSfU7BtCSMYXd9gQNB/1rEfEtSYqIfRFxIiJOSrpP0tWdaxNAXWP5NN6SVkraGhF3j1g+MOJhN0va3P72ALTLWD6Nv07SRyRtsr2xWvYZSUtsz5UUknZI+mgH+gPQJmP5NP4JSaOdH7uu/e0A6BS+QQckQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiiq1M2235V0s4Riy6W9FrXGjg9/dpbv/Yl0Vur2tnbT0fEO0YrdDXsb9m4PRQRgz1roKBfe+vXviR6a1W3euNlPJAEYQeS6HXYV/R4+yX92lu/9iXRW6u60ltP37MD6J5eH9kBdAlhB5LoSdhtz7f9gu3ttu/oRQ+N2N5he1M1DfVQj3tZZXu/7c0jlk21vd72tup61Dn2etRbX0zjXZhmvKf7rtfTn3f9PbvtcZJelPSrknZLekrSkoj4z6420oDtHZIGI6LnX8Cw/SuSDkv6akS8t1r2l5IORMTy6g/llIj44z7p7S5Jh3s9jXc1W9HAyGnGJS2S9Dvq4b4r9PUhdWG/9eLIfrWk7RHxUkQclfQNSQt70Effi4jHJR04ZfFCSWuq22s0/J+l6xr01hciYm9EPFPdPiTpzWnGe7rvCn11RS/CPkPSyyPu71Z/zfcekr5r+2nby3rdzCimR8Reafg/j6RpPe7nVE2n8e6mU6YZ75t918r053X1IuyjTSXVT+N/10XEz0u6UdInqperGJsxTePdLaNMM94XWp3+vK5ehH23pFkj7s+UtKcHfYwqIvZU1/slPaz+m4p635sz6FbX+3vcz4/10zTeo00zrj7Yd72c/rwXYX9K0hzb77I9UdKHJa3tQR9vYXty9cGJbE+WNE/9NxX1WklLq9tLJT3Sw15+Qr9M491omnH1eN/1fPrziOj6RdICDX8i/1+S/qQXPTTo692Snq0uW3rdm6QHNfyy7piGXxHdKuntkjZI2lZdT+2j3h6QtEnScxoO1kCPevslDb81fE7SxuqyoNf7rtBXV/YbX5cFkuAbdEAShB1IgrADSRB2IAnCDiRB2IEkCDuQxP8DmoYLeg2e4iYAAAAASUVORK5CYII=\n",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAANx0lEQVR4nO3df6zV9X3H8ddLflnxR6AWvAU2W8RO12R0u+qM22Jjx5CYgax2ZVnDFiPtWrM267KZLpluWRqytroua+1QGNS1dmatEyvrSkgz49o4rw4FhgpzgAgBLVmBTfn53h/3a3OL93zO5XzPL3g/H8nJOef7Pt/zfecLr/s953zO93wcEQJw9jun1w0A6A7CDiRB2IEkCDuQBGEHkhjfzY1N9KQ4V5O7uUkglTf0vzoaRzxarVbYbc+X9EVJ4yTdHxHLS48/V5N1jW+os0kABU/Ghoa1ll/G2x4n6UuSbpR0paQltq9s9fkAdFad9+xXS9oeES9FxFFJ35C0sD1tAWi3OmGfIenlEfd3V8t+gu1ltodsDx3TkRqbA1BHnbCP9iHAW757GxErImIwIgYnaFKNzQGoo07Yd0uaNeL+TEl76rUDoFPqhP0pSXNsv8v2REkflrS2PW0BaLeWh94i4rjt2yX9i4aH3lZFxJa2dQagrWqNs0fEOknr2tQLgA7i67JAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJNHVKZuBbtr559c2rD229HPFdW//9duK9ZPPbm2pp17iyA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTDOjrPWB258pmHtPJfXPXHexGK9yep9qVbYbe+QdEjSCUnHI2KwHU0BaL92HNnfHxGvteF5AHQQ79mBJOqGPSR91/bTtpeN9gDby2wP2R46piM1NwegVXVfxl8XEXtsT5O03vbzEfH4yAdExApJKyTpQk+NmtsD0KJaR/aI2FNd75f0sKSr29EUgPZrOey2J9u+4M3bkuZJ2tyuxgC0V52X8dMlPWz7zef5ekR8py1dAWPw+qLyC8nlA3/dsPb1Q5cX1/UPnm2pp37Wctgj4iVJP9fGXgB0EENvQBKEHUiCsANJEHYgCcIOJMEprjhjvTq3/N93kic0rP3DK+UTNMdrV0s99TOO7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOPs6FvnvPdnivUv/vZ9LT/3a4/OLNYvYZwdwJmKsANJEHYgCcIOJEHYgSQIO5AEYQeSYJz9DHD4lmuK9YvWP9+wduJ/ftTudrpm7/unFuvzzjtWrK8+ONCwNuPvG+8zaXha4rMNR3YgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIJx9j7w4t9eVaw/f9PfFOs/+6+3NazN/q2NrbTUFeOumFOsf/zj/1Ssn4iTxfo9X/lgw9olP/x+cd2zUdMju+1Vtvfb3jxi2VTb621vq66ndLZNAHWN5WX8aknzT1l2h6QNETFH0obqPoA+1jTsEfG4pAOnLF4oaU11e42kRe1tC0C7tfoB3fSI2CtJ1fW0Rg+0vcz2kO2hYzrS4uYA1NXxT+MjYkVEDEbE4ARN6vTmADTQatj32R6QpOp6f/taAtAJrYZ9raSl1e2lkh5pTzsAOqXpOLvtByVdL+li27sl3SlpuaSHbN8qaZekWzrZZL/z+PJufPGvfqFY337TvcX6/hPlzzre/s9vK9b71Y4PvqNY/90LXy7WVx+cUay/c+WmhrXyCP3ZqWnYI2JJg9INbe4FQAfxdVkgCcIOJEHYgSQIO5AEYQeS4BTXMSoNr714T3lo7YWbv9zs2YvVBcv/qFif9kD/nq559NcGG9YWLX6i1nP/xb/dVKxffmio1vOfbTiyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASjLOP0aHFjceLX1j8pVrPfdnajxXr71n5dLEetbbeWftue6Nh7c+m/Udx3YcON/y1M0nSFct/WKyfjdMu18GRHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYJy9cs7cK4v1r37u843X1XnFdTe8Xp4J5/Lf+/divZ/H0V9feHWx/uhVdzesNdtvn139m8X6zO39ex5/P+LIDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM5eeWN6ecz3p8Y3nhb5ZJOR8G1HLynWX/y78u/ON/POdY3/GSce7OxZ3Vf9afm32Uv77Te231hcd9bdZ+55/P2o6ZHd9irb+21vHrHsLtuv2N5YXRZ0tk0AdY3lZfxqSfNHWX5PRMytLuva2xaAdmsa9oh4XNKBLvQCoIPqfEB3u+3nqpf5Uxo9yPYy20O2h47pSI3NAaij1bDfK2m2pLmS9kr6QqMHRsSKiBiMiMEJKp8QAqBzWgp7ROyLiBMRcVLSfZLKpz4B6LmWwm57YMTdmyVtbvRYAP2h6Ti77QclXS/pYtu7Jd0p6XrbczU81LlD0kc712J3jP+/8nj04ZONP284/5zy25NlF+0o1j82775ivdk4vuaVy3Wc02Tu+Ga9/ehk49+N3/GPs4vrTj/C+ert1DTsEbFklMUrO9ALgA7i67JAEoQdSIKwA0kQdiAJwg4k4YjunSh4oafGNb6ha9trpzduavy9oV2LT3axk9N0tPz3fMrGccX6st9fW6zfetGuYn3+1psb1sZ/oLwuTt+TsUEH48Co46Uc2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCX5KeozO/XbjaZUv/3YXG+myd//h/lrrv/bozIa1S8Q4ezdxZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnT+6/P3ttsT7vvI3F+pajR4v1GY/tbVjr7GTSOBVHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH25C67dmexfiLKv4m/+IE/KNYv3f6D0+4JndH0yG57lu3v2d5qe4vtT1bLp9peb3tbdT2l8+0CaNVYXsYfl/TpiLhC0i9K+oTtKyXdIWlDRMyRtKG6D6BPNQ17ROyNiGeq24ckbZU0Q9JCSWuqh62RtKhDPQJog9P6gM72pZLeJ+lJSdMjYq80/AdB0rQG6yyzPWR76JiO1GwXQKvGHHbb50v6pqRPRcTBsa4XESsiYjAiBidoUis9AmiDMYXd9gQNB/1rEfGtavE+2wNVfUBSvZ8hBdBRTYfebFvSSklbI+LuEaW1kpZKWl5dP9KRDlHLuPdcVqx/ZfbqYv3OV8unwF52/+5i/Xixim4ayzj7dZI+ImmT7Y3Vss9oOOQP2b5V0i5Jt3SkQwBt0TTsEfGEpFEnd5d0Q3vbAdApfF0WSIKwA0kQdiAJwg4kQdiBJDjF9Sy340Ojfov5xwbGva1Yf+z+Xy7Wp+38/mn3hN7gyA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTDOfpa7YGfUWn/gO42nXJaYdvlMwpEdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JwRL1x2NNxoafGNeYHaYFOeTI26GAcGPXXoDmyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASTcNue5bt79neanuL7U9Wy++y/YrtjdVlQefbBdCqsfx4xXFJn46IZ2xfIOlp2+ur2j0R8fnOtQegXcYyP/teSXur24dsb5U0o9ONAWiv03rPbvtSSe+T9GS16Hbbz9leZXtKg3WW2R6yPXRMR+p1C6BlYw677fMlfVPSpyLioKR7Jc2WNFfDR/4vjLZeRKyIiMGIGJygSfU7BtCSMYXd9gQNB/1rEfEtSYqIfRFxIiJOSrpP0tWdaxNAXWP5NN6SVkraGhF3j1g+MOJhN0va3P72ALTLWD6Nv07SRyRtsr2xWvYZSUtsz5UUknZI+mgH+gPQJmP5NP4JSaOdH7uu/e0A6BS+QQckQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiiq1M2235V0s4Riy6W9FrXGjg9/dpbv/Yl0Vur2tnbT0fEO0YrdDXsb9m4PRQRgz1roKBfe+vXviR6a1W3euNlPJAEYQeS6HXYV/R4+yX92lu/9iXRW6u60ltP37MD6J5eH9kBdAlhB5LoSdhtz7f9gu3ttu/oRQ+N2N5he1M1DfVQj3tZZXu/7c0jlk21vd72tup61Dn2etRbX0zjXZhmvKf7rtfTn3f9PbvtcZJelPSrknZLekrSkoj4z6420oDtHZIGI6LnX8Cw/SuSDkv6akS8t1r2l5IORMTy6g/llIj44z7p7S5Jh3s9jXc1W9HAyGnGJS2S9Dvq4b4r9PUhdWG/9eLIfrWk7RHxUkQclfQNSQt70Effi4jHJR04ZfFCSWuq22s0/J+l6xr01hciYm9EPFPdPiTpzWnGe7rvCn11RS/CPkPSyyPu71Z/zfcekr5r+2nby3rdzCimR8Reafg/j6RpPe7nVE2n8e6mU6YZ75t918r053X1IuyjTSXVT+N/10XEz0u6UdInqperGJsxTePdLaNMM94XWp3+vK5ehH23pFkj7s+UtKcHfYwqIvZU1/slPaz+m4p635sz6FbX+3vcz4/10zTeo00zrj7Yd72c/rwXYX9K0hzb77I9UdKHJa3tQR9vYXty9cGJbE+WNE/9NxX1WklLq9tLJT3Sw15+Qr9M491omnH1eN/1fPrziOj6RdICDX8i/1+S/qQXPTTo692Snq0uW3rdm6QHNfyy7piGXxHdKuntkjZI2lZdT+2j3h6QtEnScxoO1kCPevslDb81fE7SxuqyoNf7rtBXV/YbX5cFkuAbdEAShB1IgrADSRB2IAnCDiRB2IEkCDuQxP8DmoYLeg2e4iYAAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
Expand Down Expand Up @@ -341,7 +353,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
"version": "3.8.10"
}
},
"nbformat": 4,
Expand Down
76 changes: 76 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import os
import torch as t
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt

print(t.__version__)

os.environ['KMP_DUPLICATE_LIB_OK']='True'

# the data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)),])

mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = t.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = t.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# len(mnist_trainset)
# len(mnist_testset)

# the model
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.linear1 = nn.Linear(28*28, 100) # input layer
self.linear2 = nn.Linear(100, 50) # hidden layer
self.final = nn.Linear(50, 10) # output layer
self.relu = nn.ReLU() # piecewise linear function

# convert + flatten
def forward(self, img):
x = img.view(-1, 28*28) # reshape the image for the model
x = self.relu(self.linear1(x))
x = self.relu(self.linear2(x))
x = self.final(x)
return x

net = Net()

# loss function
cross_en_loss = nn.CrossEntropyLoss()
optimiser = t.optim.Adam(net.parameters(), lr=1) # e-1
epoch = 10

for epoch in range(epoch):
net.train()

for data in train_loader:
x, y = data # x=features, y=targets
optimiser.zero_grad() # set gradient to 0 before each loss calc
output = net(x.view(-1, 28*28)) # pass in reshaped batch
loss = cross_en_loss(output, y) # cal and grab the loss value
loss.backward() # apply loss back through the network's parameters
optimiser.step() # optimise weights to account for loss and gradients


# evaluating our dataset
correct = 0
total = 0
with t.no_grad():
for data in test_loader:
x, y = data
output = net(x.view(-1, 784))
for idx, i in enumerate(output):
if t.argmax(i) == y[idx]:
correct += 1
total += 1
print(f"accuracy: {round(correct/total, 3)}")

# visualization
plt.imshow(x[3].view(28, 28))
plt.show()
print(t.argmax(net(x[3].view(-1, 784))[0]))
28 changes: 28 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
certifi==2022.12.7
charset-normalizer==2.1.1
contourpy==1.1.0
cycler==0.11.0
filelock==3.9.0
fonttools==4.41.1
idna==3.4
importlib-resources==6.0.0
Jinja2==3.1.2
kiwisolver==1.4.4
MarkupSafe==2.1.2
matplotlib==3.7.2
mpmath==1.2.1
networkx==3.0
numpy==1.24.1
packaging==23.1
Pillow==9.3.0
pyparsing==3.0.9
python-dateutil==2.8.2
requests==2.28.1
six==1.16.0
sympy==1.11.1
torch==2.0.1+cpu
torchaudio==2.0.2+cpu
torchvision==0.15.2+cpu
typing-extensions==4.4.0
urllib3==1.26.13
zipp==3.16.2

0 comments on commit 1432715

Please sign in to comment.