Advanced: Extended Structure


ZhiJian is a unifying and rapidly deployable toolbox for pre-trained model reuse.

Overview

In the following example, we show how ZhiJian lets you:

  • Customize your own pre-trained model to explore new structural ideas

  • Tailor and integrate add-in modules into a vast pre-trained model with lightning speed

../_images/tutorials_addin_overview.png

This chapter may involve more advanced configuration.

Introduce the Custom Model

Let’s begin with a three-layer Multilayer Perceptron (MLP).

  • Run the code block below to customize the model:

import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    """
    MLP Class
    ==============

    Multilayer Perceptron (MLP) model for image (224x224) classification tasks.

    Args:
        args (object): Custom arguments or configurations.
        num_classes (int): Number of output classes.
    """
    def __init__(self, args, num_classes):
        super(MLP, self).__init__()
        self.args = args
        self.image_size = 224
        self.fc1 = nn.Linear(self.image_size * self.image_size * 3, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x):
        """
        Forward pass of the model.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output logits.
        """
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
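A quick shape check of the custom model (assuming 10 output classes, purely for illustration):

import torch

model = MLP(args=None, num_classes=10)
out = model(torch.randn(2, 3, 224, 224))  # a batch of two 224x224 RGB images
print(out.shape)                          # torch.Size([2, 10])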
../_images/tutorials_mlp.png

Fig. 3 Custom Multilayer Perceptron (MLP) Architecture

Now you can expand the model the moment inspiration strikes; do as you please.

We will customize and modify the network structure through a few lines of ZhiJian code. These additional structures are likewise implemented on the PyTorch framework and inherit from the base class AddinBase, which provides some basic methods for data access (a minimal sketch of this base class follows the list below).

  • In the following paragraphs, we introduce the components of the extended structure:

    1. The main forward function

    2. The entry points that guide inputs

    3. The config syntax of the entry points
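
The AddinBase class ships with ZhiJian; as a rough mental model, here is a minimal sketch of its caching helper (an illustrative assumption, not the library's actual implementation — the get_pre name matches the hook that appears in the generated configuration later in this chapter):

import torch.nn as nn

class AddinBase(nn.Module):
    """Minimal sketch of the add-in base class (assumed interface)."""

    def get_pre(self, module, inputs):
        # Pre-hook: cache a module's input so that a later entry point
        # (e.g. adapt_across_output) can reuse it via self.inputs_cache.
        self.inputs_cache = inputs[0]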

Design Additional Add-in Modules

  • Run the code block below to customize add-in modules and entry points for the model.

class MLPAddin(AddinBase):
    """
    MLPAddin Class
    ==============

    Multilayer Perceptron (MLP) add-in.

    Args:
        config (object): Custom configuration or arguments.
        model_config (object): Configuration specific to the model.
    """
    def __init__(self, config, model_config):
        super(MLPAddin, self).__init__()

        self.config = config
        self.embed_dim = model_config.hidden_size

        self.reduction_dim = 16

        self.fc1 = nn.Linear(self.embed_dim, self.reduction_dim)
        if config.mlp_addin_output_size is not None:
            self.fc2 = nn.Linear(self.reduction_dim, config.mlp_addin_output_size)
        else:
            self.fc2 = nn.Linear(self.reduction_dim, self.embed_dim)

    def forward(self, x):
        """
        Forward pass of the MLP add-in.

        Args:
            x (tensor): Input tensor.

        Returns:
            tensor: Output tensor after passing through the MLP add-in.
        """
        identity = x
        out = self.fc1(identity)
        out = F.relu(out)
        out = self.fc2(out)

        return out

    def adapt_input(self, module, inputs):
        """
        Hook function to adapt the input data before it enters the module.

        Args:
            module (nn.Module): The module being hooked.
            inputs (tuple): (Inputs before the module,).

        Returns:
            tensor: Adapted input tensor after passing through the MLP add-in.
        """
        x = inputs[0]
        return self.forward(x)

    def adapt_output(self, module, inputs, outputs):
        """
        Hook function to adapt the output data after it leaves the module.

        Args:
            module (nn.Module): The module being hooked.
            inputs (tuple): (Inputs before the module,).
            outputs (tensor): Outputs after the module.

        Returns:
            tensor: Adapted output tensor after passing through the MLP add-in.
        """
        return self.forward(outputs)

    def adapt_across_input(self, module, inputs):
        """
        Hook function to adapt the data across the modules.

        Args:
            module (nn.Module): The module being hooked.
            inputs (tuple): (Inputs before the module,).

        Returns:
            tensor: Adapted input tensor after adding the MLP add-in output to the subsequent module.
        """
        x = inputs[0]
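        # self.inputs_cache was cached earlier at the {in} point by a
        # pre-hook (see `get_pre` in the generated configuration below)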
        x = x + self.forward(self.inputs_cache)
        return x

    def adapt_across_output(self, module, inputs, outputs):
        """
        Hook function to adapt the data across the modules.

        Args:
            module (nn.Module): The module being hooked.
            inputs (tuple): (Inputs before the module,).
            outputs (tensor): Outputs after the module.

        Returns:
            tensor: Adapted input tensor after adding the MLP add-in output to the previous module.
        """
        outputs = outputs + self.forward(self.inputs_cache)
        return outputs
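
Before wiring the add-in into a base model, we can smoke-test it on its own (using hypothetical SimpleNamespace stand-ins for the config objects, purely for illustration):

import torch
from types import SimpleNamespace

config = SimpleNamespace(mlp_addin_output_size=None)
model_config = SimpleNamespace(hidden_size=256)

addin = MLPAddin(config, model_config)
print(addin(torch.randn(2, 256)).shape)  # torch.Size([2, 256])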

Main forward function

In the extended auxiliary structure MLPAddin above, we add a low-rank bottleneck (two linear layers with a reduced dimension in between), inspired by parameter-efficient methods such as Adapter and LoRA.

We define this structure in the __init__ function and implement its computation in the forward function: data flowing through the add-in is processed by forward.
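
To see why this bottleneck is cheap, compare its parameter count against a single full linear layer (back-of-the-envelope, assuming embed_dim=768 as in ViT-Base and reduction_dim=16 as defined above):

embed_dim, reduction_dim = 768, 16
bottleneck = (embed_dim * reduction_dim + reduction_dim) \
           + (reduction_dim * embed_dim + embed_dim)      # fc1 + fc2, with biases
full = embed_dim * embed_dim + embed_dim                  # one d-by-d linear layer
print(f"bottleneck: {bottleneck:,} vs full: {full:,}")    # bottleneck: 25,360 vs full: 590,592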

../_images/tutorials_addin_structure.png

Fig. 4 Additional Auxiliary Structure Example

Entry points to guide inputs

As shown above, the hook methods starting with adapt_ are our entry points (functions) for guiding the data. They serve as hooks that attach the extended modules to the base model.

They fall roughly into two categories:

  • hooks that guide data input before a module

  • hooks that guide data output after a module

These are closely tied to the forward function: data enters the extended structure through these entry points. We explain their roles further in the Config Syntax of Entry Points section below (a plain-PyTorch sketch of the mechanism follows).
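
Under the hood, these entry points correspond to PyTorch's native forward hooks. The following plain-PyTorch sketch (not ZhiJian's internal API) shows the two categories:

import torch
import torch.nn as nn

layer = nn.Linear(8, 8)

def pre_hook(module, inputs):            # guides data input before the module
    return (inputs[0] * 2,)              # the returned tuple replaces the inputs

def post_hook(module, inputs, outputs):  # guides data output after the module
    return outputs + 1                   # the returned tensor replaces the output

h1 = layer.register_forward_pre_hook(pre_hook)
h2 = layer.register_forward_hook(post_hook)
print(layer(torch.randn(2, 8)).shape)    # effectively layer(2 * x) + 1
h1.remove(); h2.remove()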

Config Syntax of Entry Points

We aim to customize our model by inter-layer insertion and cross-layer concatenation of the auxiliary structures at different positions within the base model (such as the custom MLP mentioned earlier). When configuring the insertion or concatenation positions, ZhiJian provides a minimalistic one-line configuration syntax.

The syntax for configuring an add-in module into the base model is as follows. We start with a couple of examples and then unpack the meaning of each part of the configuration.

  • Inter-layer Insertion:

    >>> (MLPAddin.adapt_input): ...->{inout1}(fc2)->...
    
    ../_images/tutorials_mlp_addin_1.png

    Fig. 5 Additional Add-in Structure - Inter-layer Insertion 1

    >>> (MLPAddin.adapt_input): ...->(fc2){inout1}->...
    
    ../_images/tutorials_mlp_addin_2.png

    Fig. 6 Additional Add-in Structure - Inter-layer Insertion 2

  • Cross-layer Concatenation:

    >>> (MLPAddin.adapt_across_input): ...->(fc1){in1}->...->{out1}(fc3)->...
    
    ../_images/tutorials_mlp_addin_3.png

    Fig. 7 Additional Add-in Structure - Cross-layer Concatenation

Base Module: ->(fc1)

Consider a base model implemented with the PyTorch framework, where each layer and module has a straightforward named representation:

  • As shown in the figure, the print command outputs the defined names of the model structure (example output for our MLP appears after this list):

    print(model)
    
  • The structure of some classic backbones can be represented as follows:

    • MLP:

      >>> input->(fc1)->(fc2)->(fc3)->output
      
    • ViT block[i]:

      >>> input->...->(block[i].norm1)->
            (block[i].attn.qkv)->(block[i].attn.attn_drop)->(block[i].attn.proj)->(block[i].attn.proj_drop)->
              (block[i].ls1)->(block[i].drop_path1)->
                (block[i].norm2)->
                  (block[i].mlp.fc1)->(block[i].mlp.act)->(block[i].mlp.drop1)->(block[i].mlp.fc2)->(block[i].mlp.drop2)->
                    (block[i].ls2)->(block[i].drop_path2)->...->output
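
  • For the three-layer MLP defined earlier (assuming num_classes=10), print(model) yields:

    $ MLP(
        (fc1): Linear(in_features=150528, out_features=256, bias=True)
        (fc2): Linear(in_features=256, out_features=256, bias=True)
        (fc3): Linear(in_features=256, out_features=10, bias=True)
      )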
      

Default Module: ...

In the configuration syntax of ZhiJian, ... stands for the omitted (default) layers or modules.

  • For example, when we only focus on the (fc2) module in MLP and the (block[i].mlp.fc2) module in ViT:

    • MLP:

      >>> ...->(fc2)->...
      
    • ViT:

      >>> ...->(block[i].mlp.fc2)->...
      

Insertion & Concatenation Function: ():

Considering the custom auxiliary structure MLPAddin above, the functions starting with adapt_ serve as the processing units that are inserted into, or concatenated across, the base model.

  • There are primarily two hook signatures:

    def adapt_input(self, module, inputs):
        """
        Args:
            module (nn.Module): The module being hooked.
            inputs (tuple): (Inputs before the module,).
        """
        ...
    
    def adapt_output(self, module, inputs, outputs):
        """
        Args:
            module (nn.Module): The module being hooked.
            inputs (tuple): (Inputs before the module,).
            outputs (tensor): Outputs after the module.
        """
        ...
    

    where

    • adapt_input(self, module, inputs) is generally set before the module and is called before the data enters it, to process and, if needed, replace the module's input.

    • adapt_output(self, module, inputs, outputs) is generally set after the module and is called after the data leaves it, to process and, if needed, replace the module's output.

These functions are “hooked” into the base model at the configured positions, serving as the key connectors between the base model and the auxiliary structure.

Insertion & Concatenation Point: {}

Consider an independent extended auxiliary structure (such as the MLPAddin above): its insertion or concatenation points with the base network consist of a “Data Input” and a “Data Output”, where:

  • “Data Input” refers to the network features input into the extended auxiliary structure.

  • “Data Output” refers to the adapted features output from the auxiliary structure back to the base network.

Next, let’s use some configuration examples of MLP to illustrate the syntax and functionality of ZhiJian for module integration:

Inter-layer Insertion: inout

  • As shown in the above Fig. 5, the configuration expression is:

    >>> (MLPAddin.adapt_input): ...->{inout1}(fc2)->...
    

    where

    • {inout1} refers to the position that takes in the base model features (the output at any layer or module).

      It denotes both the “Data Input” and the “Data Output”. The configuration can be {inoutx}, where x is the index of the integration point; for example, {inout1} denotes the first integration point.

    • In the example above, this inter-layer insertion configuration intercepts the features entering the fc2 module, passes them through the add-in, and feeds the result to fc2 instead; the original features no longer enter fc2 directly (see the sketch below).
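
    In plain PyTorch terms, the effect is the following (a conceptual sketch, not ZhiJian internals; mlp_addin stands for an MLPAddin instance):

    def fc2_with_insertion(x, model, mlp_addin):
        # {inout1}(fc2) with adapt_input: the add-in's output replaces
        # the original input of fc2
        return model.fc2(mlp_addin(x))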

Cross-layer Concatenation: in, out

  • As shown in the above Fig. 7, the configuration expression is:

    >>> (MLPAddin.adapt_across_input): ...->(fc1){in1}->...->{out1}(fc3)->...
    

    where

    • {in1}: represents the integration point where the base network features (or output, at any layer or module) enter the additional add-in structure.

      It denotes the “Data Input”. The configuration can be {inx}, where x represents the xth integration point. For example, {in1} represents the first integration point.

    • {out1}: represents the integration point where the features processed by the additional add-in structure are returned to the base network.

      It denotes the “Data Output”. The configuration can be {outx}, where x represents the xth integration point. For example, {out1} represents the first integration point.

    • This cross-layer concatenation configuration extracts the output features of the fc1 module, passes them into the auxiliary structure, and returns the result to the base network just before the fc3 module in the form of residual addition, as sketched below.
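
      In plain PyTorch terms, the effect is the following (a conceptual sketch, not ZhiJian internals; F is torch.nn.functional):

      def forward_with_concatenation(x, model, mlp_addin):
          h1 = model.fc1(x)                     # {in1}: h1 also enters the add-in
          h2 = F.relu(model.fc2(F.relu(h1)))
          return model.fc3(h2 + mlp_addin(h1))  # {out1}: residual addition before fc3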

  • To drive these configurations interactively, let's first create a helper function that prompts the user for a selection:

    def select_from_input(prompt_for_select, valid_selections):
        selections2print = '\n\t'.join([f'[{idx + 1}] {i}' for idx, i in enumerate(valid_selections)])
        while True:
            selected = input(f"Please input a {prompt_for_select}, type 'help' to show the options: ")
    
            if selected == 'help':
                print(f"Available {prompt_for_select}(s):\n\t{selections2print}")
            elif selected.isdigit() and int(selected) >= 1 and int(selected) <= len(valid_selections):
                selected = valid_selections[int(selected) - 1]
                break
            elif selected in valid_selections:
                break
            else:
                print("Sorry, input not support.")
                print(f"Available {prompt_for_select}(s):\n\t{selections2print}")
    
        return selected
    
    available_example_config_blitzs = {
        'Insert between `fc1` and `fc2` layer (performed before `fc2`)': "(MLPAddin.adapt_input): ...->{inout1}(fc2)->...",
        'Insert between `fc1` and `fc2` layer (performed after `fc1`)': "(MLPAddin.adapt_output): ...->(fc1){inout1}->...",
        'Splice across `fc2` layer (performed before `fc2` and `fc3`)': "(MLPAddin.adapt_across_input): ...->{in1}(fc2)->{out1}(fc3)->...",
        'Splice across `fc2` layer (performed after `fc1` and before `fc3`)': "(MLPAddin.adapt_across_input): ...->(fc1){in1}->...->{in2}(fc3)->...",
        'Splice across `fc2` layer (performed before and after `fc2`)': "(MLPAddin.adapt_across_output): ...->{in1}(fc2){in2}->...",
        'Splice across `fc2` layer (performed after `fc1` and `fc2`)': "(MLPAddin.adapt_across_output): ...->(fc1){in1}->(fc2){in2}->...",
    }
    
    config_blitz = select_from_input('add-in structure', list(available_example_config_blitzs.keys())) # user picks an add-in structure
    
    $ Please input a add-in structure, type 'help' to show the options: help
      Available add-in structure(s):
          [1] Insert between `fc1` and `fc2` layer (performed before `fc2`)
          [2] Insert between `fc1` and `fc2` layer (performed after `fc1`)
          [3] Splice across `fc2` layer (performed before `fc2` and `fc3`)
          [4] Splice across `fc2` layer (performed after `fc1` and before `fc3`)
          [5] Splice across `fc2` layer (performed before and after `fc2`)
          [6] Splice across `fc2` layer (performed after `fc1` and `fc2`)
      Please input a add-in structure, type 'help' to show the options: 5
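
  • The blocks below also assume a few standard imports, alongside the ZhiJian helpers used throughout (get_args, prepare_addins, prepare_hook, prepare_gradient, prepare_cuda, prepare_vision_dataloader, prepare_trainer, ModelWrapper, dict2args, DATASET2NUM_CLASSES), which are imported from the zhijian package (exact module paths depend on the installed release):

    import os
    from pprint import pprint

    import torch
    import torch.optim as optim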
    
  • Next, we will configure the parameters and proceed with model training and testing:

    args = get_args(
        model='timm.vit_base_patch16_224_in21k',    # backbone network
        config_blitz=config_blitz,                  # addin blitz configuration
        dataset='VTAB.cifar',                       # dataset
        dataset_dir='your/dataset/directory',       # dataset directory
        training_mode='finetune',                   # training mode
        optimizer='adam',                           # optimizer
        lr=1e-2,                                    # learning rate
        wd=1e-5,                                    # weight decay
        verbose=True                                # control the verbosity of the output
    )
    pprint(vars(args))
    
    $ {'aa': None,
       'addins': [{'hook': [['get_pre', 'pre'], ['adapt_across_output', 'post']],
                   'location': [['fc2'], ['fc2']],
                   'name': 'MLPAddin'}],
       'amp': False,
       'amp_dtype': 'float16',
       'amp_impl': 'native',
       'aot_autograd': False,
       'aug_repeats': 0,
       'aug_splits': 0,
       'batch_size': 64,
       'bce_loss': False,
       ...
       'warmup_epochs': 5,
       'warmup_lr': 1e-05,
       'warmup_prefix': False,
       'wd': 5e-05,
       'weight_decay': 2e-05,
       'worker_seeding': 'all'}
    
  • Run the code block below to configure the GPU and the model (excluding additional auxiliary structures):

    assert torch.cuda.is_available()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    torch.cuda.set_device(int(args.gpu))
    
    model = MLP(args, DATASET2NUM_CLASSES[args.dataset.replace('VTAB.','')])
    model = ModelWrapper(model)
    model_args = dict2args({'hidden_size': 512})
    
  • Run the code block below to configure additional auxiliary structures:

    args.mlp_addin_output_size = 256
    addins, fixed_params = prepare_addins(args, model_args, addin_classes=[MLPAddin])
    
    prepare_hook(args.addins, addins, model, 'addin')
    prepare_gradient(args.reuse_keys, model)
    device = prepare_cuda(model)
    
  • Run the code block below to configure the dataset, optimizer, loss function, and other settings:

    train_loader, val_loader, num_classes = prepare_vision_dataloader(args, model_args)
    
    optimizer = optim.Adam(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.wd
    )
    lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        args.max_epoch,
        eta_min=args.eta_min
    )
    criterion = nn.CrossEntropyLoss()
    
  • Run the code block below to prepare the trainer object and start training and testing:

    trainer = prepare_trainer(
        args,
        model=model,
        model_args=model_args,
        device=device,
        train_loader=train_loader,
        val_loader=val_loader,
        num_classes=num_classes,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        criterion=criterion
    )
    
    trainer.fit()
    trainer.test()
    
    $ Log level set to: INFO
      Log files are recorded in: your/log/directory/0718-19-52-36-748
      Trainable/total parameters of the model: 0.03M / 38.64M (0.08843%)
    
            Epoch   GPU Mem.       Time       Loss         LR
              1/5     0.589G     0.1355      4.602      0.001: 100%|██████████| 16.0/16.0 [00:01<00:00, 12.9batch/s]
    
            Epoch   GPU Mem.       Time      Acc@1      Acc@5
              1/5     0.629G    0.03114      1.871      7.932: 100%|██████████| 157/157 [00:05<00:00, 30.9batch/s]
      ***   Best results: [Acc@1: 1.8710191082802548], [Acc@5: 7.931926751592357]
    
            Epoch   GPU Mem.       Time       Loss         LR
              2/5     0.784G     0.1016      4.538 0.00090451: 100%|██████████| 16.0/16.0 [00:00<00:00, 19.4batch/s]
    
            Epoch   GPU Mem.       Time      Acc@1      Acc@5
              2/5     0.784G    0.02669      2.498      9.504: 100%|██████████| 157/157 [00:04<00:00, 35.9batch/s]
      ***   Best results: [Acc@1: 2.4980095541401273], [Acc@5: 9.504378980891719]
    
            Epoch   GPU Mem.       Time       Loss         LR
              3/5     0.784G    0.09631      4.488 0.00065451: 100%|██████████| 16.0/16.0 [00:00<00:00, 20.6batch/s]
    
            Epoch   GPU Mem.       Time      Acc@1      Acc@5
              3/5     0.784G    0.02688      2.379      10.16: 100%|██████████| 157/157 [00:04<00:00, 36.0batch/s]
      ***   Best results: [Acc@1: 2.3785828025477707], [Acc@5: 10.161226114649681]
    
            Epoch   GPU Mem.       Time       Loss         LR
              4/5     0.784G    0.09126       4.45 0.00034549: 100%|██████████| 16.0/16.0 [00:00<00:00, 20.2batch/s]
    
            Epoch   GPU Mem.       Time      Acc@1      Acc@5
              4/5     0.784G    0.02644      2.468      10.29: 100%|██████████| 157/157 [00:04<00:00, 36.2batch/s]
      ***   Best results: [Acc@1: 2.468152866242038], [Acc@5: 10.290605095541402]
    
            Epoch   GPU Mem.       Time       Loss         LR
              5/5     0.784G     0.0936      4.431 9.5492e-05: 100%|██████████| 16.0/16.0 [00:00<00:00, 20.5batch/s]
    
            Epoch   GPU Mem.       Time      Acc@1      Acc@5
              5/5     0.784G    0.02706      2.558      10.43: 100%|██████████| 157/157 [00:04<00:00, 35.8batch/s]
      ***   Best results: [Acc@1: 2.557722929936306], [Acc@5: 10.429936305732484]
    
            Epoch   GPU Mem.       Time      Acc@1      Acc@5
              1/5     0.784G    0.02667      2.558      10.43: 100%|██████████| 157/157 [00:04<00:00, 36.0batch/s]
      ***   Best results: [Acc@1: 2.557722929936306], [Acc@5: 10.429936305732484]