Fine-tune a Custom Pre-Trained ModelΒΆ

πŸ•ΆοΈ

ZhiJian is an unifying and rapidly deployable toolbox for pre-trained model reuse.

OverviewΒΆ

In the following example, we show how to use ZhiJian to:

  • Construct a custom MLP

  • Tune it with supervision on a custom dataset

  • Run inference to evaluate the performance

The figure below shows the three stages of our example. To run the following code, please click [Open In Colab].

../_images/tutorials_get_started_mlp.png

Construct Custom ModelΒΆ

We begin with a three-layer Multilayer Perceptron (MLP).

../_images/tutorials_mlp.png

Fig. 2 Custom Multilayer Perceptron (MLP) ArchitectureΒΆ

Although a multilayer perceptron is not a strong image learner, it lets us get started quickly. Other custom networks can be adapted in the same way, with analogous designs and modifications.

  • Run the code block below to customize the model:

import torch.nn as nn

class MLP(nn.Module):
    """
    MLP Class
    ==============

    Multilayer Perceptron (MLP) model for image (224x224) classification tasks.

    Args:
        args (object): Custom arguments or configurations.
        num_classes (int): Number of output classes.
    """
    def __init__(self, args, num_classes):
        super(MLP, self).__init__()
        self.args = args
        self.image_size = 224
        self.fc1 = nn.Linear(self.image_size * self.image_size * 3, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x):
        """
        Forward pass of the model.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output logits.
        """
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = nn.ReLU()(x)
        x = self.fc2(x)
        x = nn.ReLU()(x)
        x = self.fc3(x)
        return x
  • Next, run the code block below to configure the GPU and the model (a quick sanity check of device placement and the forward pass is sketched after this list):

    model = MLP(args, DATASET2NUM_CLASSES[args.dataset.replace('VTAB.','')])
    model = ModelWrapper(model)
    model_args = dict2args({'hidden_size': 512})
    
  • Now, run the code block below to prepare the trainer, passing in the model:

    trainer = prepare_trainer(
        args,
        model=model,
        model_args=model_args,
        device=device,
        ...
    )
    
    trainer.fit()
    trainer.test()
    
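A quick sanity check of device placement and the forward pass can catch shape mistakes before training. The snippet below is a minimal sketch, assuming the MLP class above and an existing args object are in scope; the batch size and class count are placeholders.

    import torch

    # Sanity check (assumes `args` and the MLP class from above are in scope;
    # the class count of 10 and the batch of 4 are placeholders).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    mlp = MLP(args, num_classes=10).to(device)
    dummy_batch = torch.randn(4, 3, 224, 224, device=device)

    with torch.no_grad():
        logits = mlp(dummy_batch)

    print(logits.shape)  # expected: torch.Size([4, 10])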

Prepare Custom DatasetΒΆ

  • Configure ZhiJian without a dataset-specific configuration and organize the custom dataset in the following structure:

    • within the /your/dataset/directory root

    • create a separate folder for each category

    • store all the data corresponding to each category within its respective folder

      /your/dataset/directory
      β”œβ”€β”€ train
      β”‚   β”œβ”€β”€ class_1
      β”‚   β”‚   β”œβ”€β”€ train_class_1_img_1.jpg
      β”‚   β”‚   β”œβ”€β”€ train_class_1_img_2.jpg
      β”‚   β”‚   β”œβ”€β”€ train_class_1_img_3.jpg
      β”‚   β”‚   └── ...
      β”‚   β”œβ”€β”€ class_2
      β”‚   β”‚   β”œβ”€β”€ train_class_2_img_1.jpg
      β”‚   β”‚   └── ...
      β”‚   β”œβ”€β”€ class_3
      β”‚   β”‚   └── ...
      β”‚   β”œβ”€β”€ class_4
      β”‚   β”‚   └── ...
      β”‚   β”œβ”€β”€ class_5
      β”‚   β”‚   └── ...
      └── test
          β”œβ”€β”€ class_1
          β”‚   β”œβ”€β”€ test_class_1_img_1.jpg
          β”‚   β”œβ”€β”€ test_class_1_img_2.jpg
          β”‚   β”œβ”€β”€ test_class_1_img_3.jpg
          β”‚   └── ...
          β”œβ”€β”€ class_2
          β”‚   β”œβ”€β”€ test_class_2_img_1.jpg
          β”‚   └── ...
          β”œβ”€β”€ class_3
          β”‚   └── ...
          β”œβ”€β”€ class_4
          β”‚   └── ...
          └── class_5
              └── ...
      
  • Set up the custom dataset (a quick check of the resulting classes and split sizes is sketched after this list):

    from torchvision import transforms
    from torchvision.datasets import ImageFolder

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
    ])
    
    train_dataset = ImageFolder(root='/your/dataset/directory/train', transform=train_transform)
    val_dataset = ImageFolder(root='/your/dataset/directory/test', transform=val_transform)
    
  • Implement the corresponding loaders (a minimal training and evaluation loop over them is sketched below):

    import torch

    train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=True,
            shuffle=True
        )
    val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=True,
            shuffle=False
        )
    num_classes = len(train_dataset.classes)
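To confirm that ImageFolder picked up the intended layout, the short check below (a sketch using only the objects defined above) prints the discovered classes, the split sizes, and the shape of one batch.

    # Quick check of the discovered classes and split sizes.
    print('classes:', train_dataset.classes)           # folder names, e.g. ['class_1', 'class_2', ...]
    print('num_classes:', num_classes)
    print('train/test sizes:', len(train_dataset), len(val_dataset))

    images, labels = next(iter(train_loader))
    print('batch shape:', images.shape)                # [batch_size, 3, 224, 224]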
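To verify the whole pipeline end to end without the ZhiJian trainer, a plain-PyTorch loop over these loaders looks like the sketch below. It is a minimal, hypothetical example: the optimizer, learning rate, and epoch count are placeholders, and it tunes the unwrapped MLP directly rather than the ModelWrapper passed to prepare_trainer.

    import torch
    import torch.nn as nn

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MLP(args, num_classes).to(device)          # unwrapped MLP; placeholder setup
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

    for epoch in range(5):                             # placeholder epoch count
        # supervised tuning on the custom training split
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()

        # inference on the held-out split to evaluate performance
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                preds = model(images).argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        print(f'epoch {epoch}: val acc = {correct / total:.4f}')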