Fashion MNIST classification using custom PyTorch Convolution Neural Network (CNN)

6 minute read

Hi, in today’s post we are going to look at image classification using a simple PyTorch architecture. We’re going to use the Fashion-MNIST data, which is a famous benchmarking dataset. Below is a brief summary of the Fashion-MNIST.

Fashion-MNIST is a dataset of Zalando’s article images—consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes. Zalando intends Fashion-MNIST to serve as a direct drop-in replacement for the original MNIST dataset for benchmarking machine learning algorithms. It shares the same image size and structure of training and testing splits.

The original MNIST dataset contains a lot of handwritten digits. Members of the AI/ML/Data Science community love this dataset and use it as a benchmark to validate their algorithms. In fact, MNIST is often the first dataset researchers try. “If it doesn’t work on MNIST, it won’t work at all”, they said. “Well, if it does work on MNIST, it may still fail on others.”

The Fashion-MNIST can be cloned from the following GitHub repo and please don’t forget to check out the Fashion-MNIST paper for more detials.

Each of the labels in the data will correspond to either one of the following classes;

<!DOCTYPE html>

Fashion-MNIST labels

class label
T-shirt/top 0
Trouser 1
Pullover 2
Dress 3
Coat 4
Sandal 5
Shirt 6
Sneaker 7
Bag 8
Ankle boot 9
# import some dependencies
import torchvision
import torch
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

import torch.optim as optim
import time
import torch.nn as nn
import torch.nn.functional as F
torch.set_printoptions(linewidth=120)
# import data
train_set = torchvision.datasets.FashionMNIST(root="./", download=True, 
                                              train=True,
                                              transform=transforms.Compose([transforms.ToTensor()]))

test_set = torchvision.datasets.FashionMNIST(root="./", download=True, 
                                              train=False,
                                              transform=transforms.Compose([transforms.ToTensor()]))
data_loader = torch.utils.data.DataLoader(train_set, batch_size=10, shuffle=True)
sample = next(iter(data_loader))

imgs, lbls = sample

visualize some samples

# create a grid 
plt.figure(figsize=(15,10))
grid = torchvision.utils.make_grid(nrow=20, tensor=imgs)
print(f"image tensor: {imgs.shape}")
print(f"class labels: {lbls}")
plt.imshow(np.transpose(grid, axes=(1,2,0)), cmap='gray');
image tensor: torch.Size([10, 1, 28, 28])
class labels: tensor([8, 0, 9, 9, 7, 5, 9, 5, 2, 6])

png

As we can see from the above, the images are grayscale 28 by 28 images. Without further ado, lets define a simple network that will learn to map the inputs (images) to the correct class (label/target). To do this we will be using a classic way to program a neural network (i.e, using object oriented programming or OOP), train it over a few epochs (iterations) and inspect the results using a confusion matrix. Before we do this, lets define some functions we will be using..

# define some helper functions
def get_item(preds, labels):
    """function that returns the accuracy of our architecture"""
    return preds.argmax(dim=1).eq(labels).sum().item()

@torch.no_grad() # turn off gradients during inference for memory effieciency
def get_all_preds(network, dataloader):
    """function to return the number of correct predictions across data set"""
    all_preds = torch.tensor([])
    model = network
    for batch in dataloader:
        images, labels = batch
        preds = model(images) # get preds
        all_preds = torch.cat((all_preds, preds), dim=0) # join along existing axis
        
    return all_preds


def plot_confusion_matrix(cm,
                          target_names,
                          title='Confusion matrix',
                          cmap=None,
                          normalize=True):
    """
    given a sklearn confusion matrix (cm), make a nice plot

    Arguments
    ---------
    cm:           confusion matrix from sklearn.metrics.confusion_matrix

    target_names: given classification classes such as [0, 1, 2]
                  the class names, for example: ['high', 'medium', 'low']

    title:        the text to display at the top of the matrix

    cmap:         the gradient of the values displayed from matplotlib.pyplot.cm
                  see http://matplotlib.org/examples/color/colormaps_reference.html
                  plt.get_cmap('jet') or plt.cm.Blues

    normalize:    If False, plot the raw numbers
                  If True, plot the proportions

    Usage
    -----
    plot_confusion_matrix(cm           = cm,                  # confusion matrix created by
                                                              # sklearn.metrics.confusion_matrix
                          normalize    = True,                # show proportions
                          target_names = y_labels_vals,       # list of names of the classes
                          title        = best_estimator_name) # title of graph

    Citiation
    ---------
    http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

    """
    import matplotlib.pyplot as plt
    import numpy as np
    import itertools

    accuracy = np.trace(cm) / np.sum(cm).astype('float')
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    plt.figure(figsize=(15, 10))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=45)
        plt.yticks(tick_marks, target_names)

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]


    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        if normalize:
            plt.text(j, i, "{:0.4f}".format(cm[i, j]),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")
        else:
            plt.text(j, i, "{:,}".format(cm[i, j]),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")


    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
    plt.show()

# define network

class Network(nn.Module): # extend nn.Module class of nn
    def __init__(self):
        super().__init__() # super class constructor
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5,5))
        self.batchN1 = nn.BatchNorm2d(num_features=6)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=(5,5))
        self.fc1 = nn.Linear(in_features=12*4*4, out_features=120)
        self.batchN2 = nn.BatchNorm1d(num_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)
        
        
        
    def forward(self, t): # implements the forward method (flow of tensors)
        
        # hidden conv layer 
        t = self.conv1(t)
        t = F.max_pool2d(input=t, kernel_size=2, stride=2)
        t = F.relu(t)
        t = self.batchN1(t)
        
        # hidden conv layer
        t = self.conv2(t)
        t = F.max_pool2d(input=t, kernel_size=2, stride=2)
        t = F.relu(t)
        
        # flatten
        t = t.reshape(-1, 12*4*4)
        t = self.fc1(t)
        t = F.relu(t)
        t = self.batchN2(t)
        t = self.fc2(t)
        t = F.relu(t)
        
        # output
        t = self.out(t)
        
        return t        
cnn_model = Network() # init model
print(cnn_model) # print model structure
Network(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (batchN1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(6, 12, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=192, out_features=120, bias=True)
  (batchN2): BatchNorm1d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=120, out_features=60, bias=True)
  (out): Linear(in_features=60, out_features=10, bias=True)
)
# let's also normalize the data for faster convergence

# import data
mean = 0.2859;  std = 0.3530 # calculated using standization from the MNIST itself which we skip in this blog
train_set = torchvision.datasets.FashionMNIST(root="./", download=True,
                                              transform=transforms.Compose([transforms.ToTensor(),
                                                                            transforms.Normalize(mean, std)
                                                                           ]))
data_loader = torch.utils.data.DataLoader(train_set, batch_size=100, shuffle=True, num_workers=1)
optimizer = optim.Adam(lr=0.01, params=cnn_model.parameters())
# def train loop

for epoch in range(5):
    start_time = time.time()
    total_correct = 0
    total_loss = 0
    for batch in data_loader:
        imgs, lbls = batch
        preds = cnn_model(imgs) # get preds
        loss = F.cross_entropy(preds, lbls) # compute loss
        optimizer.zero_grad() # zero grads
        loss.backward() # calculates gradients 
        optimizer.step() # update the weights
        
        total_loss += loss.item()
        total_correct += get_item(preds, lbls)
        accuracy = total_correct/len(train_set)
    end_time = time.time() - start_time    
    print("Epoch no.",epoch+1 ,"|accuracy: ", round(accuracy, 3),"%", "|total_loss: ", total_loss, "| epoch_duration: ", round(end_time,2),"sec")
Epoch no. 1 |accuracy:  0.829 % |total_loss:  276.2480258792639 | epoch_duration:  80.73 sec
Epoch no. 2 |accuracy:  0.871 % |total_loss:  206.40078330039978 | epoch_duration:  70.3 sec
Epoch no. 3 |accuracy:  0.883 % |total_loss:  190.10711652040482 | epoch_duration:  73.01 sec
Epoch no. 4 |accuracy:  0.891 % |total_loss:  175.60668615996838 | epoch_duration:  83.73 sec
Epoch no. 5 |accuracy:  0.899 % |total_loss:  165.3020654693246 | epoch_duration:  111.66 sec
  • training is complete, it’s time to inspect how well our algorithm performed using the confusion matrix!

train confusion matrix

# get all preds
pred_data_loader = torch.utils.data.DataLoader(batch_size=10000, dataset=train_set, num_workers=1)
all_preds= get_all_preds(network=cnn_model, dataloader=pred_data_loader) 
plot_confusion_matrix(cm=confusion_matrix(y_true=train_set.targets, y_pred=all_preds.argmax(1)), target_names=train_set.classes, normalize=False)

png

test confusion matrix (out of sample performance)

# get all preds
test_pred_data_loader = torch.utils.data.DataLoader(batch_size=10000, dataset=test_set, num_workers=1)
all_preds_test = get_all_preds(network=cnn_model, dataloader=test_pred_data_loader) 
plot_confusion_matrix(cm=confusion_matrix(y_true=train_set.targets, y_pred=all_preds_test.argmax(1)), target_names=test_set.classes, normalize=False)

png

Conclusion

  • model was slightly overfit
  • train accuracy was about 90% while the test accuracy was 78%
  • We could train for more epochs, however, the BN-CNN performed well but results could be improved with transfer learning, which will be the future of this work.

Thanks to the following for providing insights!

  • support community at deeplizard
  • Zilando for making Fashion-MNIST open access for testing code such as this!