Fine-tune a Custom Pre-Trained Model
ZhiJian is a unifying and rapidly deployable toolbox for pre-trained model reuse.
Overview
In the following example, we show how ZhiJian:
Construct a custom MLP
Tune with supervision on a custom dataset
Infer to evaluate the performance
The figure below shows the three stages of our example. To run the following code, please click [Open In Colab].
Construct Custom Model
We begin with a three-layer Multilayer Perceptron (MLP).
Although an MLP is not a strong image learner, it lets us get started quickly. Other custom networks can be adapted with similar designs and modifications by analogy.
Run the code block below to customize the model:
import torch.nn as nn

class MLP(nn.Module):
    """
    Multilayer Perceptron (MLP) model for image (224x224) classification tasks.

    Args:
        args (object): Custom arguments or configurations.
        num_classes (int): Number of output classes.
    """
    def __init__(self, args, num_classes):
        super(MLP, self).__init__()
        self.args = args
        self.image_size = 224
        self.fc1 = nn.Linear(self.image_size * self.image_size * 3, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x):
        """
        Forward pass of the model.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output logits.
        """
        x = x.view(x.size(0), -1)  # flatten to (batch, 224 * 224 * 3)
        x = self.fc1(x)
        x = nn.ReLU()(x)
        x = self.fc2(x)
        x = nn.ReLU()(x)
        x = self.fc3(x)
        return x
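To sanity-check the architecture, the same stack of layers can be expressed with nn.Sequential and run on a dummy batch. This is a sketch: the class count of 5 is an arbitrary choice, and nn.Flatten() stands in for the x.view call in the class above.

```python
import torch
import torch.nn as nn

# Same architecture as the MLP class above, written as a Sequential.
mlp = nn.Sequential(
    nn.Flatten(),                    # equivalent to x.view(x.size(0), -1)
    nn.Linear(224 * 224 * 3, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 5),               # num_classes = 5 in this sketch
)

dummy = torch.randn(2, 3, 224, 224)  # a batch of two 224x224 RGB images
logits = mlp(dummy)
print(logits.shape)                  # torch.Size([2, 5])
```

The forward pass confirms that the flattening and layer sizes line up: any 224x224 RGB batch maps to one logit per class.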
Next, run the code block below to configure the GPU and the model:
model = MLP(args, DATASET2NUM_CLASSES[args.dataset.replace('VTAB.', '')])
model = ModelWrapper(model)
model_args = dict2args({'hidden_size': 512})
Now, run the code block below to prepare the trainer, passing in the model:

trainer = prepare_trainer(
    args,
    model=model,
    model_args=model_args,
    device=device,
    ...
)
trainer.fit()
trainer.test()
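trainer.fit() and trainer.test() wrap a standard supervised loop. As a point of reference, one optimization step of such a loop looks as follows in plain PyTorch; this is a minimal sketch with a stand-in linear model and synthetic data, and none of these names come from ZhiJian's API.

```python
import torch
import torch.nn as nn

# One supervised optimization step, as trainer.fit() performs internally
# (stand-in model and synthetic batch for illustration only).
model = nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

inputs = torch.randn(8, 10)            # synthetic batch of 8 samples
targets = torch.randint(0, 2, (8,))    # synthetic binary labels

model.train()
optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
optimizer.step()
print(f"loss: {loss.item():.4f}")
```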
Prepare Custom Dataset
Run the configuration without a built-in dataset option and organize the custom dataset in the following structure:

within the your/dataset/dir directory, create a separate folder for each category
store all the data for each category within its respective folder
/your/dataset/directory
├── train
│   ├── class_1
│   │   ├── train_class_1_img_1.jpg
│   │   ├── train_class_1_img_2.jpg
│   │   ├── train_class_1_img_3.jpg
│   │   └── ...
│   ├── class_2
│   │   ├── train_class_2_img_1.jpg
│   │   └── ...
│   ├── class_3
│   │   └── ...
│   ├── class_4
│   │   └── ...
│   └── class_5
│       └── ...
└── test
    ├── class_1
    │   ├── test_class_1_img_1.jpg
    │   ├── test_class_1_img_2.jpg
    │   ├── test_class_1_img_3.jpg
    │   └── ...
    ├── class_2
    │   ├── test_class_2_img_1.jpg
    │   └── ...
    ├── class_3
    │   └── ...
    ├── class_4
    │   └── ...
    └── class_5
        └── ...
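The layout above can be scaffolded in a few lines. This sketch creates the empty split and class folders under an illustrative dataset_root directory; the root and class names are placeholders for your own.

```python
from pathlib import Path

# Create the expected train/test split folders, one subfolder per class
# ("dataset_root" and the class names are illustrative placeholders).
dataset_root = Path("dataset_root")
for split in ("train", "test"):
    for cls in ("class_1", "class_2", "class_3", "class_4", "class_5"):
        (dataset_root / split / cls).mkdir(parents=True, exist_ok=True)

print(sorted(p.name for p in (dataset_root / "train").iterdir()))
# ['class_1', 'class_2', 'class_3', 'class_4', 'class_5']
```

Image files then go into the matching class folder, exactly as in the tree above.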
Set up the custom dataset:
from torchvision import transforms
from torchvision.datasets import ImageFolder

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
train_dataset = ImageFolder(root='/your/dataset/directory/train', transform=train_transform)
val_dataset = ImageFolder(root='/your/dataset/directory/test', transform=val_transform)
Implement the corresponding loader:
import torch

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    pin_memory=True,
    shuffle=True
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    pin_memory=True,
    shuffle=False
)
num_classes = len(train_dataset.classes)
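Each iteration of these loaders yields a batch of image tensors and integer labels. The sketch below illustrates the same DataLoader semantics with a small synthetic TensorDataset, so it runs without a dataset directory.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for ImageFolder: 8 fake 224x224 RGB images with labels in [0, 5).
data = torch.randn(8, 3, 224, 224)
labels = torch.randint(0, 5, (8,))
loader = DataLoader(TensorDataset(data, labels), batch_size=4, shuffle=True)

for images, targets in loader:
    print(images.shape, targets.shape)  # torch.Size([4, 3, 224, 224]) torch.Size([4])
    break
```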