This project implements a PyTorch neural network to classify images from the FashionMNIST dataset.
The first step is to download the training and test sets of the dataset using the datasets module from torchvision.
import torch
import torch.nn as nn
import numpy as np
from torchvision import datasets , transforms
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
#transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x/255)])
train_set = datasets.FashionMNIST('./FashionMNIST', download=True, train=True, transform=transform)
test_set = datasets.FashionMNIST('./FashionMNIST', download=True, train=False, transform=transforms.ToTensor())
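As a sanity check on the transform: ToTensor scales raw pixel bytes from [0, 255] to floats in [0, 1], and Normalize with mean 0.5 and std 0.5 then maps each pixel p to (p - 0.5)/0.5, i.e. into [-1, 1]. A minimal pure-Python sketch of that arithmetic (the function name normalize_pixel is mine, not part of torchvision):

```python
def normalize_pixel(p, mean=0.5, std=0.5):
    """Same per-pixel arithmetic as transforms.Normalize: (p - mean) / std."""
    return (p - mean) / std

print(normalize_pixel(0.0))  # darkest pixel  -> -1.0
print(normalize_pixel(1.0))  # brightest pixel ->  1.0
print(normalize_pixel(0.5))  # mid-gray       ->  0.0
```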
Also, to monitor and avoid overfitting, I have split the training set into training and validation subsets as below:
from torch.utils.data import random_split
# calculate size of train and validation sets
train_size = 51000
valid_size = 9000
p_train_set, valid_set = random_split(train_set, [train_size, valid_size])
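Conceptually, random_split shuffles the sample indices and cuts them into consecutive chunks of the requested sizes. A stdlib sketch of that idea (split_indices is my own name, not torch's actual implementation):

```python
import random

def split_indices(n, sizes, seed=0):
    """Shuffle 0..n-1 and cut into consecutive chunks of the given sizes,
    mimicking what torch.utils.data.random_split does with indices."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    parts, start = [], 0
    for s in sizes:
        parts.append(idx[start:start + s])
        start += s
    return parts

train_idx, valid_idx = split_indices(60000, [51000, 9000])
print(len(train_idx), len(valid_idx))  # 51000 9000
# every sample lands in exactly one split
assert set(train_idx).isdisjoint(valid_idx)
```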
Alright, now I have three sets: training, validation, and test. During the training phase we will use the training and validation sets, and at the end we apply the resulting model to the test set to verify how well it classifies unseen data.
But before that, let's take a look at the kind of data we have. Using Python's built-in iter function I make the training set iterable and select its first element with the next function.
it = iter(train_set)
sample , label = next(it)
Note that since the elements of the training set are tuples of images and class labels, I have unpacked the extracted element into two variables, sample and label. The shape attribute gives the size of the sample: the first number is the number of channels of the image, and since we have a grayscale image here there is only one channel. The image itself is 28x28 (28 pixels per row and 28 per column).
sample.shape
torch.Size([1, 28, 28])
Using the imshow function of the pyplot module we can visualize the image. As you can see, it appears to be a kind of boot.
import matplotlib.pyplot as plt
plt.imshow(sample[0],cmap = 'gray')
<matplotlib.image.AxesImage at 0x13b00ad60>
To get an idea of the classes present in the dataset, I have hard-coded their names below.
class_names = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
"Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
In fact, the labels in the dataset are categorical values from 0 to 9, each indicating one of the classes: 0 refers to "T-shirt/top", 1 to "Trouser", and so on.
For instance, the label in the example above is 9, which refers to the last class, "Ankle boot".
label
9
class_names[label]
'Ankle boot'
To make the learning process faster, I am going to use mini-batch learning, so I use the DataLoader class from torch.utils.data to divide the data into batches of size 32. Other batch sizes are possible, but 32 works well in our case. Also, I set shuffle=True for the training data, so that the composition of the batches is different after each epoch of the training phase; this keeps the order of the samples from biasing the gradient updates.
from torch.utils.data import DataLoader
train_loader = DataLoader(p_train_set , batch_size=32 , shuffle = True)
valid_loader = DataLoader(valid_set , batch_size=32 , shuffle = False)
test_loader = DataLoader(test_set , batch_size=32 , shuffle = False)
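Since DataLoader keeps the final, smaller batch by default (drop_last=False), the number of batches per epoch is the ceiling of set_size / batch_size. A quick check of the step counts for the three loaders (steps_per_epoch is my own helper name):

```python
import math

def steps_per_epoch(n_samples, batch_size=32):
    # DataLoader with drop_last=False keeps the final partial batch,
    # so the batch count is the ceiling of n_samples / batch_size
    return math.ceil(n_samples / batch_size)

print(steps_per_epoch(51000))  # 1594 training batches
print(steps_per_epoch(9000))   # 282 validation batches (9000/32 = 281.25)
print(steps_per_epoch(10000))  # 313 test batches
```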
OK, now it is time to define the model. I am going to use a sequential neural network with three layers. Since the images are of size 28x28, the input size is 28*28 = 784. The first hidden layer has 300 neurons, the second has 100, and since we have 10 classes the output layer has 10 neurons (one per class, 0 through 9). These sizes, as well as the learning rate, are given below.
input_size = 28*28
hidden_layer_1 = 300
hidden_layer_2 = 100
output_size = 10
learning_rate = 0.01
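These sizes fix the number of trainable parameters: each Linear layer holds in*out weights plus one bias per output neuron. A worked count (linear_params is my own helper name):

```python
def linear_params(n_in, n_out):
    # weight matrix (n_out x n_in) plus one bias per output neuron
    return n_in * n_out + n_out

total = (linear_params(28 * 28, 300)   # l1: 235,500
         + linear_params(300, 100)     # l2:  30,100
         + linear_params(100, 10))     # l3:   1,010
print(total)  # 266610 trainable parameters
```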
The following snippet shows the structure of the model I construct. I have used the ReLU activation function for each hidden layer of the neural network.
class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_layer_1, hidden_layer_2, output_size):
        super().__init__()
        self.l1 = nn.Linear(input_size, hidden_layer_1)
        self.relu1 = nn.ReLU()
        self.l2 = nn.Linear(hidden_layer_1, hidden_layer_2)
        self.relu2 = nn.ReLU()
        self.l3 = nn.Linear(hidden_layer_2, output_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, X):
        out = self.l1(X)
        out = self.relu1(out)
        out = self.dropout(out)
        out = self.l2(out)
        out = self.dropout(out)
        out = self.relu2(out)
        out = self.l3(out)
        return out

model = NeuralNet(input_size, hidden_layer_1, hidden_layer_2, output_size)
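Note that nn.Dropout(0.5) only acts in training mode: it zeroes each activation with probability 0.5 and scales the survivors by 1/(1 - p) so the expected value is unchanged, while under model.eval() it is the identity. A torch-free sketch of that behaviour (the function name and seed are mine):

```python
import random

def dropout(xs, p=0.5, training=True, seed=0):
    """Inverted dropout: zero each value with prob p, scale the rest by 1/(1-p)."""
    if not training:
        return list(xs)  # eval mode: identity, like calling model.eval()
    rng = random.Random(seed)
    return [0.0 if rng.random() < p else x / (1 - p) for x in xs]

acts = [1.0, 2.0, 3.0, 4.0]
print(dropout(acts, training=False))  # [1.0, 2.0, 3.0, 4.0] -- unchanged at eval
print(dropout(acts, training=True))   # some entries zeroed, survivors doubled
```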
By typing the name of the model, we can see a summary of its structure.
model
NeuralNet(
  (l1): Linear(in_features=784, out_features=300, bias=True)
  (relu1): ReLU()
  (l2): Linear(in_features=300, out_features=100, bias=True)
  (relu2): ReLU()
  (l3): Linear(in_features=100, out_features=10, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)
The other things that should be specified before starting the training phase are the loss function and the optimizer. Since this is a multiclass classification problem, I have chosen the cross-entropy loss function, which is available in torch.nn. To minimize the loss function, I use the stochastic gradient descent algorithm as the optimizer.
from torch.optim import lr_scheduler
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)
#step_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size = 10 , gamma = 0.09)
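nn.CrossEntropyLoss takes raw logits (no softmax in the model, as above) and combines log-softmax with negative log-likelihood. For a single sample it reduces to -log(softmax(logits)[target]); a pure-Python sketch of that computation (the function name is mine):

```python
import math

def cross_entropy(logits, target):
    """-log softmax(logits)[target], as nn.CrossEntropyLoss computes per sample."""
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_z - logits[target]

# With uniform logits over 10 classes the loss is log(10) ~= 2.3026,
# roughly what an untrained classifier scores on this dataset.
print(cross_entropy([0.0] * 10, target=9))  # ~2.302585
```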
Now we have all the tools to start training. I train for 50 epochs; the number of steps per epoch equals the number of training batches, and likewise the total number of validation steps equals the number of validation batches. I also use a dictionary to record metrics such as loss and accuracy during the training phase; these metrics can later be used to visualize the performance of the algorithm.
from collections import defaultdict

num_epochs = 50
total_steps = len(train_loader)
total_steps_valid = len(valid_loader)
dic = defaultdict(list)

for epoch in range(num_epochs):
    n_correct_train = 0
    n_correct_valid = 0
    running_loss_train = 0.0
    running_loss_val = 0.0

    model.train()
    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, 28 * 28)  # flatten 1x28x28 images to 784-vectors
        outputs = model(images)
        loss_train = criterion(outputs, labels)
        optimizer.zero_grad()
        loss_train.backward()
        optimizer.step()
        # .item() detaches the scalar so the computation graph is not kept alive
        running_loss_train += loss_train.item()
        _, predictions = torch.max(outputs, 1)
        n_correct_train += (predictions == labels).sum().item()
    loss_train = running_loss_train / len(train_loader)
    print(f"epoch:{epoch}, training loss = {loss_train} , train_accuracy = {n_correct_train/train_size}")
    #step_lr_scheduler.step()

    model.eval()
    with torch.no_grad():
        for i, (images, labels) in enumerate(valid_loader):
            images = images.reshape(-1, 28 * 28)
            outputs = model(images)
            loss_val = criterion(outputs, labels)
            running_loss_val += loss_val.item()
            _, predictions = torch.max(outputs, 1)
            n_correct_valid += (predictions == labels).sum().item()
    loss_val = running_loss_val / len(valid_loader)
    print(f"epoch:{epoch}, validation loss = {loss_val}, val accuracy = {n_correct_valid/valid_size}")

    dic['train loss'].append(loss_train)
    dic['train accuracy'].append(n_correct_train / train_size)
    dic['val loss'].append(loss_val)
    dic['val accuracy'].append(n_correct_valid / valid_size)
epoch:0, training loss = 1.0660334825515747 , train_accuracy = 0.6154901960784314
epoch:0, validation loss = 0.6104346513748169, val accuracy = 0.7731111111111111
epoch:1, training loss = 0.6538720726966858 , train_accuracy = 0.7615098039215686
epoch:1, validation loss = 0.5079036355018616, val accuracy = 0.8117777777777778
epoch:2, training loss = 0.5738135576248169 , train_accuracy = 0.7968823529411765
epoch:2, validation loss = 0.4667207896709442, val accuracy = 0.8312222222222222
...
epoch:48, training loss = 0.2785106897354126 , train_accuracy = 0.899607843137255
epoch:48, validation loss = 0.28962841629981995, val accuracy = 0.896
epoch:49, training loss = 0.277021586894989 , train_accuracy = 0.9010196078431373
epoch:49, validation loss = 0.29809436202049255, val accuracy = 0.8936666666666667
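The accuracy bookkeeping in the loop above (torch.max over dimension 1 followed by a comparison with the labels) amounts to taking the argmax of each output row and counting matches. In plain Python (batch_accuracy_counts is my own name for the sketch):

```python
def batch_accuracy_counts(outputs, labels):
    """Mimics: _, predictions = torch.max(outputs, 1); (predictions == labels).sum()."""
    predictions = [max(range(len(row)), key=row.__getitem__) for row in outputs]
    return sum(p == y for p, y in zip(predictions, labels))

# two toy "logit" rows: the argmaxes are 2 and 0, the true labels 2 and 1
outputs = [[0.1, 0.2, 0.9], [0.8, 0.5, 0.1]]
print(batch_accuracy_counts(outputs, [2, 1]))  # 1 correct out of 2
```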
Using the information recorded in the dictionary, I can now examine the performance of the model in terms of loss and accuracy for both the training and validation sets.
import pandas as pd
pd.DataFrame(dic).plot(figsize=(8, 5))
plt.grid(True)
plt.gca().set_ylim(0, 1)  # set the vertical range to [0, 1]
plt.show()
Finally, I evaluate the trained model on the test set:
model.eval()  # make sure dropout is disabled during evaluation
with torch.no_grad():
    n_samples, n_correct = 0, 0
    for (images, labels) in test_loader:
        images = images.reshape(-1, 28 * 28)
        outputs = model(images)
        _, predictions = torch.max(outputs, 1)
        n_samples += labels.shape[0]
        n_correct += (labels == predictions).sum().item()
    print(n_correct / n_samples)
0.783
Note that the test accuracy (78.3%) is noticeably below the validation accuracy (about 89%). A likely cause is that the test set was loaded with only the ToTensor transform, without the Normalize step applied to the training data, so the test inputs lie in a different range than the one the network was trained on.