SCAN: Let's Classify Images without Labels!
How can we classify images without labels?
Image classification is the task in which a machine maps an image to a specific label (class/target). This approach is called supervised learning: both the inputs (images) and the annotations (labels) are used to train the model.
But what can we do when labels are not available?
This is where unsupervised learning comes into play. As the name suggests, the models are trained without supervision: with no training labels, they try to group images based on their characteristics. For example, images containing a cat will have different features from those containing a dog (assuming each image contains a single object).
K-means and DBSCAN are some of the best-known algorithms we hear about when learning unsupervised learning. But they have a big problem: they work best on low-dimensional data (2-D, maybe 3-D) and degrade badly on high-dimensional inputs such as images. Even a small 32 x 32 RGB image already has 32 x 32 x 3 = 3,072 dimensions.
To combat this curse of dimensionality, we use feature-extraction methods such as PCA, autoencoders or the hidden states of a CNN to reduce each image to a lower-dimensional vector, which is then fed to the k-means algorithm.
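To make this two-stage idea concrete, here is a minimal sketch using scikit-learn. PCA stands in for the feature extractor; any of the methods above could take its place. The function name `reduce_then_cluster` and its defaults are illustrative choices and are not taken from the SCAN repo.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA


def reduce_then_cluster(X, n_components=50, n_clusters=10, seed=0):
    """Project high-dimensional rows of X with PCA, then cluster them with k-means.

    X: (N, D) array, e.g. flattened 32 x 32 x 3 images (D = 3072).
    Returns (cluster labels of shape (N,), reduced features of shape (N, n_components)).
    """
    # Stage 1: feature extraction (plain PCA as a stand-in for any extractor)
    reduced = PCA(n_components=n_components, random_state=seed).fit_transform(X)
    # Stage 2: cluster the low-dimensional vectors
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=seed)
    labels = kmeans.fit_predict(reduced)
    return labels, reduced
```

The quality of the final clusters depends entirely on how well stage 1 separates the classes. That is exactly the weakness SCAN targets.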
With no supervision, the clustering model depends entirely on these extracted features (embeddings). Most of these feature-extraction methods produce fair results, but they cannot compete with state-of-the-art supervised models.
So now what? Will unsupervised models always lose to supervised approaches? Will we keep spending hundreds of hours annotating data in the near future?
No! SCAN to the rescue!
OK, I might have overreacted a bit there, but the new research in this field is mind-boggling. We are not yet on par with supervised models, but unsupervised learning's ImageNet moment will come.
So what's SCAN?
SCAN stands for Semantic Clustering by Adopting Nearest Neighbors
If you are new to unsupervised learning or have only trained CNNs, this might be a bit different. First of all, SCAN is not an end-to-end approach. Now, I know end-to-end is usually the way to go, but believe me, we need to be patient here.
So now let's get our hands dirty.
The SCAN method is divided into 3 parts:
- The pretext task
- The SCAN model
- Self-labeling
The Pretext task.
The idea of a pretext task comes from representation learning, where a feature representation is learned solely from the images. This is done by minimizing the distance between an image and its augmentation. Because the neural net has limited capacity, it tries to remember only the important characteristics that help identify that image.
Let's say our task is just to classify the cat in an image. Take the image above as an example. We are not interested in the mug, the bed sheet or anything else, but the neural net doesn't know this, and we don't have labels to tell it. One way to solve this problem is to suppress the objects of less interest by augmenting the image. Random cropping is useful here, and blurring, color jitter and horizontal flips are all things we can try.
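Here is a minimal NumPy sketch of three of the augmentations just mentioned: random crop, horizontal flip and brightness jitter. It is illustrative only. The repo defines its own augmentation pipeline through its config files, and the function name and parameter values here are assumptions.

```python
import numpy as np


def augment(image, rng, pad=4, jitter=0.2):
    """Return a randomly augmented copy of an (H, W, C) float image in [0, 1].

    Hypothetical helper: random crop (via padding), horizontal flip and
    brightness jitter.
    """
    h, w, _ = image.shape
    # random crop: pad with reflected pixels, then cut an H x W window at a random offset
    padded = np.pad(image, ((pad, pad), (pad, pad), (0, 0)), mode="reflect")
    top = rng.integers(0, 2 * pad + 1)
    left = rng.integers(0, 2 * pad + 1)
    out = padded[top:top + h, left:left + w]
    # horizontal flip with probability 0.5
    if rng.random() < 0.5:
        out = out[:, ::-1]
    # brightness jitter: scale all pixels by a random factor in [1 - jitter, 1 + jitter]
    out = out * rng.uniform(1 - jitter, 1 + jitter)
    return np.clip(out, 0.0, 1.0)
```

Calling `augment(img, np.random.default_rng(0))` twice on the same image yields two different "views" of it. The pretext task asks the network to map both views to nearby embeddings.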
So now that we have a solution, let's implement it. In the SCAN paper, the authors use a pretext task that minimizes the distance between the real image and its augmented versions. Since the net has limited capacity, it can only remember the important features. Note that the right augmentations can be specific to certain kinds of images and need to be selected carefully.
The output of this stage is a vector of dimension D (default = 128), configured by the user.
If the pretext model is good, it will produce high similarity between images belonging to the same class. In that case we can use its embeddings as input to a clustering algorithm like k-means and get meaningful results. The paper also evaluates this approach (pretext + k-means), and it already gives reasonably good results.
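The repo's pretext script is named after SimCLR. SimCLR's contrastive (NT-Xent) loss is one common way to implement "pull an image and its augmentation together". It adds negatives: each view must be more similar to its partner view than to every other image in the batch, which stops the network from mapping everything to one point. Below is a minimal NumPy sketch of that loss. The temperature of 0.1 is an illustrative value, not read from the repo's config.

```python
import numpy as np


def nt_xent_loss(z1, z2, temperature=0.1):
    """SimCLR-style NT-Xent loss (sketch).

    z1, z2: (B, D) embeddings of two augmented views of the same B images.
    """
    z = np.concatenate([z1, z2], axis=0)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)   # cosine similarity via unit vectors
    sim = z @ z.T / temperature                         # (2B, 2B) scaled similarities
    np.fill_diagonal(sim, -np.inf)                      # never compare a view with itself
    b = z1.shape[0]
    # the positive for row i is its other view: i + B in the first half, i - B in the second
    pos_idx = np.concatenate([np.arange(b, 2 * b), np.arange(0, b)])
    # numerically stable log-softmax over each row
    row_max = sim.max(axis=1, keepdims=True)
    log_prob = sim - row_max - np.log(np.exp(sim - row_max).sum(axis=1, keepdims=True))
    return -log_prob[np.arange(2 * b), pos_idx].mean()
```

When the two views embed almost identically, the loss is close to zero. When they are unrelated, it approaches log(2B - 1).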
SCAN.
Most unsupervised approaches have two phases. The first reduces the high-dimensional inputs to low dimensions while keeping the important features. The second is a clustering algorithm that groups the embedded images into C clusters, where C is defined by the user.
With this setup we never get a probability distribution like a softmax layer gives; we only have the cluster centroids as labels. The SCAN model, in contrast, outputs a probability distribution like a supervised model. But how does it do this?
The SCAN model is built on top of the pretext model and inherits its weights. Besides the weights, the pretext step also outputs the K nearest neighbors of each image. It saves them in a .npy file (NumPy array) of shape N x (K+1), where N is the size of the dataset and each row holds the indices of the image itself plus its K nearest neighbors.
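The neighbor-mining step can be sketched with brute-force cosine similarity in NumPy. The repo itself uses faiss (installed in Step 2) for efficient search. The N x N similarity matrix below only fits in memory for small N, and the function name and default K = 20 are illustrative assumptions.

```python
import numpy as np


def mine_nearest_neighbors(features, k=20):
    """Return an (N, K + 1) array of neighbor indices for each row of `features`.

    Brute-force cosine similarity; column 0 is normally the sample itself.
    """
    normed = features / np.linalg.norm(features, axis=1, keepdims=True)
    sim = normed @ normed.T                      # (N, N) cosine similarities
    # sort each row from most to least similar and keep the top K + 1
    return np.argsort(-sim, axis=1)[:, : k + 1]
```

The result can be stored with `np.save("topk_neighbors.npy", indices)` (the file name is hypothetical) and loaded when training the SCAN model.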
Let's look at the loss function proposed by the authors.
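The original figure with the loss is not reproduced here. As we read it from the SCAN paper, the objective is the following. Φ_η(X) is the network's softmax output over the clusters for image X, N_X is the set of mined neighbors of X, D is the dataset, C is the set of clusters and λ weights the entropy term:

```latex
\Lambda = -\frac{1}{|\mathcal{D}|} \sum_{X \in \mathcal{D}} \sum_{k \in \mathcal{N}_X}
  \log \left\langle \Phi_\eta(X), \Phi_\eta(k) \right\rangle
  + \lambda \sum_{c \in \mathcal{C}} \Phi'^{c}_\eta \log \Phi'^{c}_\eta,
\qquad
\Phi'^{c}_\eta = \frac{1}{|\mathcal{D}|} \sum_{X \in \mathcal{D}} \Phi^{c}_\eta(X)
```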
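The first term maximizes the dot product between the predicted class distributions of an image and each of its K neighbors. This pushes neighbors into the same cluster with confident predictions. The second term is an entropy term over the average assignment to the C clusters (C is the number of classes). It spreads the predictions over all clusters so the model cannot collapse everything into a single cluster.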
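A minimal NumPy sketch of this loss for a batch of (anchor, neighbor) pairs follows. Each row of `neighbor_logits` is one sampled neighbor of the corresponding anchor, so the neighbor sum becomes a batch mean. The entropy weight of 5.0 is an illustrative value; check the repo's scan config for the one actually used. The repo works with PyTorch tensors, and this version is only meant for reading.

```python
import numpy as np


def softmax(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)


def scan_loss(anchor_logits, neighbor_logits, entropy_weight=5.0, eps=1e-8):
    """SCAN-style loss (sketch). Inputs are (B, C) cluster logits.

    Returns (total, consistency, entropy).
    """
    p_anchor = softmax(anchor_logits)
    p_neighbor = softmax(neighbor_logits)
    # consistency: -log of the dot product between anchor and neighbor distributions
    dot = np.sum(p_anchor * p_neighbor, axis=1)
    consistency = -np.mean(np.log(dot + eps))
    # entropy of the mean cluster assignment over the batch (high = clusters evenly used)
    mean_p = p_anchor.mean(axis=0)
    entropy = -np.sum(mean_p * np.log(mean_p + eps))
    # subtracting the entropy matches the "+ lambda * sum p log p" term of the loss
    total = consistency - entropy_weight * entropy
    return total, consistency, entropy
```

A confident, evenly spread assignment gets a much lower loss than one where every image lands in the same cluster, even if both are perfectly consistent with their neighbors.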
Self-Labeling:
This is a fine-tuning step in which the same SCAN model is used to produce pseudo-labels. If a prediction's confidence is above a certain threshold (set by the user beforehand), we are quite confident that the image falls in the correct cluster. (We don't know beforehand which cluster an image should belong to. But once, say, image one is confidently assigned to cluster 1, it must stay in that cluster for the rest of training.)
We then build a label array Y, inserting a sample only if it meets the threshold condition, and apply a cross-entropy loss on the obtained labels. Hence the term self-labeling.
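Here is a minimal NumPy sketch of the thresholded pseudo-labeling described above. The threshold of 0.99 is an illustrative value. The paper computes this cross-entropy on strongly augmented versions of the confident samples, and `logits_strong` stands in for those outputs here. The function name is hypothetical.

```python
import numpy as np


def self_label_loss(probs_weak, logits_strong, threshold=0.99):
    """Cross-entropy on confident pseudo-labels (sketch).

    probs_weak: (B, C) softmax outputs used to decide the pseudo-labels.
    logits_strong: (B, C) logits the loss is applied to (e.g. another augmented view).
    Returns (loss, mask) where mask marks the samples that entered Y.
    """
    confidence = probs_weak.max(axis=1)
    mask = confidence > threshold               # only confident samples get a label
    if not mask.any():
        return 0.0, mask
    targets = probs_weak.argmax(axis=1)[mask]   # pseudo-labels Y
    logits = logits_strong[mask]
    # numerically stable log-softmax
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    loss = -log_probs[np.arange(len(targets)), targets].mean()
    return float(loss), mask
```

As training makes more predictions confident, `mask` selects more samples, which is how Y grows over time.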
The model is then fine-tuned in a self-supervised manner. The idea is simple: the well-classified samples help their wrongly classified nearest neighbors gradually become more certain, which adds more samples to Y as training progresses.
So let's briefly recap the 3 stages before diving into the code.
- Pretext: learn important feature representations and ignore things that don't matter (remember the cat example).
- SCAN: refine the embedding by maximizing the dot product between neighbors, and output class probabilities.
- Self-labeling: generate pseudo-labels and fine-tune the SCAN model.
Here is the entire algorithm from the SCAN paper.
OK, now let's take a quick dive into the code; I didn't want this article to be just theory. In the SCAN paper, the authors use CIFAR-10, CIFAR-100-20, STL-10 and ImageNet.
For this tutorial I will be using the CIFAR-10 dataset. The goal of the article is to fact-check the SCAN paper and try to achieve similar accuracies.
The paper reports about 87% accuracy on CIFAR-10.
Step 1: Cloning the Repo.
Use the following command:
!git clone https://github.com/wvangansbeke/Unsupervised-Classification.git
Note that the repo is written in PyTorch.
Step 2: Installing dependencies:
!pip install faiss-gpu  # for efficient nearest neighbors search
!pip install pyyaml easydict  # for using config files
!pip install termcolor
Step 3: Exploring the repo in Colab.
Main files used for training:
- simclr.py: used to train the pretext model.
- scan.py: used to train the SCAN model.
- selflabel.py: used to train the final (self-labeling) model.
- eval.py: used to evaluate the SCAN and self-labeling models.
- moco.py: used for the ImageNet dataset.
Important folders and their contents:
- configs: config files for the pretext, SCAN and self-labeling models.
- data: the dataset .py files.
- losses: the losses for the pretext and SCAN models.
- models: different ResNet architectures; for CIFAR-10 we will be using the ResNet-18 model.
- utils: utility functions for getting datasets, dataloaders, transforms, etc.
Folders to modify:
If you want to train on a custom dataset, these are the places you will need to change:
- Add a custom dataset class in the data dir.
- Add config .yml files for the pretext, scan and selflabel stages under configs/.
- Finally, in the utils dir, modify the common_config.py file by adding the name of the new dataset.
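The custom dataset class could look roughly like the sketch below. The class name is hypothetical, and the returned dict simply mirrors the fields this article's code reads from the CIFAR-10 dataset (`data['image']`, `data['target']`, `data['meta']['class_name']`). Check the repo's own dataset files for the exact fields they return. Any object with `__len__` and `__getitem__` works as a map-style dataset for `torch.utils.data.DataLoader`, and subclassing `torch.utils.data.Dataset` is also fine.

```python
import numpy as np


class MyImageDataset:
    """Hypothetical custom dataset for the repo's data/ folder."""

    def __init__(self, images, targets, classes, transform=None):
        self.images = images          # e.g. a uint8 array of shape (N, H, W, 3)
        self.targets = targets        # integer class ids (only needed for evaluation)
        self.classes = classes        # list of class names indexed by class id
        self.transform = transform    # e.g. the result of get_val_transformations(p)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img, target = self.images[index], int(self.targets[index])
        if self.transform is not None:
            # torchvision transforms may expect a PIL image (PIL.Image.fromarray(img))
            img = self.transform(img)
        return {
            "image": img,
            "target": target,
            "meta": {"index": index, "class_name": self.classes[target]},
        }
```

Registering the dataset name in common_config.py, as described above, is what lets the training scripts construct this class from a config file.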
Step 4: Pretext Model training
!python simclr.py --config_env configs/env.yml --config_exp configs/pretext/simclr_cifar10.yml
Output:
The pretext model takes a long time to train: the default 500 epochs can take up to 2 days on Google Colab, so I reduced the number of epochs to 200.
After training for 200 epochs, this is the output:
Step 5: Pretext + K-means (optional)
This is another method discussed in the SCAN paper, and it gives good results.
The output vectors of the pretext model can be fed to a clustering model like k-means.
Let's load the pretext model from the last epoch.
# evaluating the model's performance
import torch
import numpy as np
from models.models import ContrastiveModel
from models.resnet_cifar import resnet18
pretrain_path = "cifar-10/pretext/checkpoint.pth.tar"
state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
backbone = resnet18()
model = ContrastiveModel(backbone, head='mlp', features_dim=128)
model.load_state_dict(state_dict['model'])
model.eval()
Let's load the CIFAR-10 dataset using get_train_dataset. Since we are validating our model, we will use the validation transforms from get_val_transformations.
But before we do this, we need the config parameters (e.g. the type of transforms and their values, such as crop_size). That's why we first load the config with the create_config method, which returns a dict. It takes two arguments. The first is env.yml, which contains the path to the root dir. The second is simclr_cifar10.yml, which contains all the parameters used to train the pretext model and can be found in the configs/pretext/ dir.
All these functions (get_train_dataset, create_config, etc.) can be found in the utils dir.
# getting the config
import torch
from torch.utils.data import DataLoader
from utils.common_config import get_train_dataset, get_val_transformations
from utils.config import create_config
p = create_config('configs/env.yml', 'configs/pretext/simclr_cifar10.yml')
# for evaluation we only use the validation transforms (no random augmentation)
val_transformations = get_val_transformations(p)
train_dataset = get_train_dataset(p, val_transformations, split='train')
# shuffle stays False, so the sample order later matches the order of kModel.labels_
train_dataloader = DataLoader(train_dataset, num_workers=4, batch_size=512)
Printing the Classes
The CIFAR-10 dataset has 10 classes that are quite distinct from each other, which makes it a good dataset for evaluating the model's generalization capability.
Next, we compute an embedding for every image and store them in an array for the k-means model. We also store the ground-truth labels in a list to validate the output of k-means.
# from the train set, let's get an embedding for each image
from tqdm import tqdm
model = model.cuda()
embeddings, true_labels = [], []
with torch.no_grad():  # inference only, no gradients needed
    for data in tqdm(train_dataloader):
        embeddings.append(model(data['image'].cuda()).cpu().numpy())
        true_labels.extend(data['target'].numpy())
images = np.concatenate(embeddings)  # shape (50000, 128) for the CIFAR-10 train split
Training the k-means model.
# train k-means
from sklearn.cluster import KMeans
# n_jobs is no longer accepted by KMeans in recent scikit-learn versions, so it is dropped
kModel = KMeans(n_clusters=10, random_state=101, max_iter=1000, init='random', n_init=100)
# training the model
kModel.fit(images)
Now there is one last step left: mapping the k-means cluster assignments to the ground-truth labels.
Why do we need to map the predicted labels?
K-means is an unsupervised method, so the model has no idea what the labels are. For example, the ground-truth label for the class cat is 3 in CIFAR-10, but k-means assigns arbitrary cluster ids and may call that cluster 5 instead. Hence the need for a mapping.
Mapping
Here n_samples can be thought of as a hyperparameter. To get the cluster for a particular class, we pass images from that class to the k-means model. The `n_samples` parameter controls how many images we pass per class.
I use 1000 images per class since the dataset is quite large; the more samples we use, the more reliable the mapping.
After running the cell below, each class will have a list of n_samples predicted cluster labels,
e.g. for class cat: [2,2,2,2..1,3,4,2,2,2,2…2]
Then we just take the mode: class cat is assigned its most frequent cluster, which is 2 in our example.
Then we create a mapping between the predicted clusters and the ground truths.
classes = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
print(f"classes : {classes}")
n_samples = 1000  # images per class used to build the mapping
kmeans_labels_dict = {}
remaining = list(classes)  # classes that still need more samples
with torch.no_grad():
    for data in train_dataset:
        if not remaining:
            print("embeddings for all classes found, now stopping")
            break
        class_name = data['meta']['class_name']
        if class_name in remaining:
            embedding = model(data['image'].unsqueeze(0).cuda()).cpu().numpy()
            kmeans_labels_dict.setdefault(class_name, []).append(kModel.predict(embedding)[0])
            # stop collecting for this class once n_samples labels are stored
            if len(kmeans_labels_dict[class_name]) == n_samples:
                print(f'For :: {class_name} {kmeans_labels_dict[class_name]}')
                remaining.remove(class_name)
Performing the map operation:
To determine the predicted label (assigned cluster) for each class, we take the mode of the 1000 samples stored in the kmeans_labels_dict dict.
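import statistics
mapping = {}
classes = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
for class_name in kmeans_labels_dict:
    cluster = statistics.mode(kmeans_labels_dict[class_name])  # most frequent cluster for this class
    print(class_name, cluster)
    # if two classes share a cluster, the later one overwrites the earlier one
    mapping[cluster] = classes.index(class_name)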
Note that 2 classes have been assigned the same cluster. This is not surprising, since we trained the pretext model for only 200 epochs.
A simple solution is to increase the number of samples or just assign the labels manually.
Manually assigning the labels:
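Instead of fixing collisions by hand, a standard alternative (not the approach used in this article) is the Hungarian algorithm. It finds the one-to-one cluster-to-class assignment that maximizes the number of correctly matched images, and it is available as `scipy.optimize.linear_sum_assignment`. The sketch below assumes integer cluster ids and labels in [0, n_classes); the function name is hypothetical.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def hungarian_mapping(cluster_ids, true_labels, n_classes):
    """One-to-one cluster -> class mapping that maximizes matched samples."""
    cluster_ids = np.asarray(cluster_ids)
    true_labels = np.asarray(true_labels)
    # counts[k, c] = how many samples of true class c landed in cluster k
    counts = np.zeros((n_classes, n_classes), dtype=np.int64)
    np.add.at(counts, (cluster_ids, true_labels), 1)
    # linear_sum_assignment minimizes cost, so negate the counts to maximize them
    rows, cols = linear_sum_assignment(-counts)
    return {int(k): int(c) for k, c in zip(rows, cols)}
```

For example, `hungarian_mapping(kModel.labels_, true_labels, 10)` would return a mapping where no two clusters share a class, so no manual step is needed.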
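from collections import Counter
for c_ in ['airplane', 'ship']:
    print(f'for {c_} most occurring top 3 clusters are : {Counter(kmeans_labels_dict[c_]).most_common(3)}')
# assign cluster 7 to 'airplane' by hand (the cluster id comes from this particular run's output)
mapping[7] = classes.index('airplane')
mapping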
Mapping and predicting:
# mapping
# predicted cluster of every training image (same order as true_labels, since the loader is not shuffled)
preds = kModel.labels_
print("mapped for preds")
mapped_pred_labels = [mapping[pred] for pred in preds]
from sklearn.metrics import accuracy_score
accuracy_score(true_labels, mapped_pred_labels)  # 52 %
We get only 52% accuracy, compared with the 65% reported in the SCAN paper. That gap is to be expected, since we did not train the pretext model for the full 500 epochs.
Step 6: Training the SCAN Model.
!python scan.py --config_env configs/env.yml --config_exp configs/scan/scan_cifar10.yml
Step 7: Validating the SCAN Model
a. Load the model from the cifar-10/scan dir:
from models.models import ClusteringModel
from models.resnet_cifar import resnet18
state_dict = torch.load('cifar-10/scan/model.pth.tar')
backbone = resnet18()
model = ClusteringModel(backbone, 10, 1)  # 10 clusters, 1 head
model.load_state_dict(state_dict['model'])
model = model.cuda()
# set eval mode
model.eval()
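b. Collecting the predictions in a Python dict and creating a cluster-to-class mapping.
import pandas as pd
from tqdm import tqdm
data_file = {'gt': [], 'pred': [], 'className': []}
# iterate through the dataset, get a prediction for each image and store the true labels
with torch.no_grad():
    for data in tqdm(train_dataloader):
        data_file['gt'].extend(data['target'].numpy())  # true label
        data_file['className'].extend(data['meta']['class_name'])  # class name
        # ClusteringModel returns one output per head; take head 0
        data_file['pred'].extend(torch.argmax(model(data['image'].cuda())[0], dim=1).cpu().numpy())
df = pd.DataFrame(data_file)
# SCAN cluster ids differ from the k-means ones, so build a new mapping: most frequent cluster per class
mapping = {df.loc[df['gt'] == c, 'pred'].mode()[0]: c for c in df['gt'].unique()}
# performing the mapping (clusters left unmapped get -1 and count as errors)
df['mapped_preds'] = df['pred'].map(mapping).fillna(-1).astype(int)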
c. Compute the accuracy:
accuracy_score(df['gt'].values,df['mapped_preds'].values) # 76 %
Wow! Evaluating the SCAN model, we get an accuracy of 76%, an increase of 24 percentage points.
Step 8: Training the Self-Labelling Model.
!python selflabel.py --config_env configs/env.yml --config_exp configs/selflabel/selflabel_cifar10.yml
This stage reaches an accuracy of 86%, which is quite good for a model trained without any labels.
Confusion Matrix:
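A confusion matrix like this one can be computed from the ground-truth labels and the mapped predicted labels. `sklearn.metrics.confusion_matrix(y_true, y_pred)` produces the same table; the NumPy sketch below just makes explicit what is being counted, assuming integer labels in [0, n_classes).

```python
import numpy as np


def confusion_matrix(true_labels, pred_labels, n_classes):
    """cm[i, j] = number of samples of true class i predicted as class j."""
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    np.add.at(cm, (np.asarray(true_labels), np.asarray(pred_labels)), 1)
    return cm
```

The diagonal holds the correct predictions, so `cm.trace() / cm.sum()` is the accuracy, and large off-diagonal entries show which classes the model confuses.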
Conclusion:
This is a huge accomplishment for the unsupervised domain, but we are not there yet. A supervised model can easily exceed 95% accuracy on datasets like CIFAR-10 and needs only a few epochs to do so. Still, this study shows there is plenty of potential for further improvement in this domain.
Here is an image from the SCAN paper comparing different algorithms.
Acknowledgement:
Authors: Wouter Van Gansbeke, Simon Vandenhende, Stamatios Georgoulis, Marc Proesmans and Luc Van Gool
Paper: SCAN: Learning to Classify Images without Labels
GitHub repo: https://github.com/wvangansbeke/Unsupervised-Classification