Offline Linear Eval

Now that you know how to pretrain a model, let’s go through the procedure to perform offline linear evaluation.

As with pretraining, we start by importing the required packages:

import torch
import torch.nn as nn
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import LearningRateMonitor
from torchvision.models import resnet18

from solo.methods.linear import LinearModel  # imports the linear eval class
from solo.utils.checkpointer import Checkpointer  # checkpoint callback used later on
from solo.utils.classification_dataloader import prepare_data

There are tons of parameters that need to be set and, fortunately, main_linear.py takes care of this for us. If we want to be able to specify the arguments from the command line, we can simply call the function parse_args_linear from solo.args.setup (a minimal sketch follows; it assumes parse_args_linear is called with no arguments, as in main_linear.py):
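from solo.args.setup import parse_args_linear

# parses all the linear evaluation arguments passed on the command line,
# e.g. python3 my_script.py --dataset cifar10 --optimizer sgd ...
args = parse_args_linear()

However, in this tutorial, we will simply define all the needed parameters to perform linear evaluation: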

# basic parameters for offline linear evaluation
# some parameters for extra functionality are missing, but don't mind this for now.
kwargs = {
    "num_classes": 10,
    "cifar": True,
    "max_epochs": 100,
    "optimizer": "sgd",
    "precision": 16,
    "lars": False,
    "lr": 0.1,
    "exclude_bias_n_norm_lars": False,
    "gpus": "0",
    "weight_decay": 0,
    "extra_optimizer_args": {"momentum": 0.9},
    "scheduler": "step",
    "lr_decay_steps": [60, 80],
    "batch_size": 128,
    "num_workers": 4,
    "pretrained_feature_extractor": "path/to/pretrained/feature/extractor"
}
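For intuition, scheduler="step" with lr_decay_steps=[60, 80] drops the learning rate at epochs 60 and 80, which corresponds to a MultiStepLR schedule in plain PyTorch (a minimal sketch with a throwaway parameter; the 0.1 decay factor is an assumption here, matching the usual choice for this setup):

# the schedule implied by scheduler="step" and lr_decay_steps=[60, 80]
params = [nn.Parameter(torch.zeros(1))]  # throwaway parameter, for illustration only
optimizer = torch.optim.SGD(params, lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 80], gamma=0.1)
for epoch in range(100):
    ...  # one training epoch would go here
    scheduler.step()  # lr: 0.1 -> 0.01 at epoch 60 -> 0.001 at epoch 80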

Apart from the hyperparameters, we also need to load the pretrained model:

# create the backbone network
# the first convolutional and maxpooling layers of the ResNet backbone
# are adjusted to handle lower resolution images (32x32 instead of 224x224).
backbone = resnet18()
backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False)
backbone.maxpool = nn.Identity()
# remove the original classification head; LinearModel attaches its own linear classifier
backbone.fc = nn.Identity()

# load pretrained feature extractor
state = torch.load(kwargs["pretrained_feature_extractor"], map_location="cpu")["state_dict"]
# the checkpoint contains the full pretraining method (backbone, projector, ...),
# so we keep only the backbone weights and strip the "backbone." prefix
for k in list(state.keys()):
    if "backbone" in k:
        state[k.replace("backbone.", "")] = state[k]
    del state[k]
backbone.load_state_dict(state, strict=False)

model = LinearModel(backbone, **kwargs)
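Since we loaded with strict=False above, mismatched keys are silently ignored. As an optional sanity check, load_state_dict returns the keys it could not match; after the renaming above, both lists should ideally be empty:

# optional sanity check: inspect unmatched keys instead of ignoring them silently
incompatible = backbone.load_state_dict(state, strict=False)
print(incompatible.missing_keys, incompatible.unexpected_keys)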

Now, let’s create the data loaders. Unlike in pretraining, this time we will not use multiple augmentations:

train_loader, val_loader = prepare_data(
    "cifar10",
    data_dir="./",
    train_dir=None,
    val_dir=None,
    batch_size=kwargs["batch_size"],
    num_workers=kwargs["num_workers"],
)
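To check that everything is wired correctly, we can peek at a single batch (purely optional); each batch is a tuple of images and labels:

# optional: with the settings above, this prints
# torch.Size([128, 3, 32, 32]) torch.Size([128])
images, targets = next(iter(train_loader))
print(images.shape, targets.shape)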

Lastly, we just need to define some extra utilities so that PyTorch Lightning can automatically log metrics for us, and then we can create our Lightning Trainer:

wandb_logger = WandbLogger(
    name="linear-cifar10",  # name of the experiment
    project="self-supervised",  # name of the wandb project
    entity=None,
    offline=False,
)
wandb_logger.watch(model, log="gradients", log_freq=100)

callbacks = []

# automatically log our learning rate
lr_monitor = LearningRateMonitor(logging_interval="epoch")
callbacks.append(lr_monitor)

# checkpointer can automatically log your parameters,
# but we need to wrap them in a Namespace object
from argparse import Namespace
args = Namespace(**kwargs)
# saves a checkpoint after every epoch
ckpt = Checkpointer(
    args,
    logdir="checkpoints/linear",
    frequency=1,
)
callbacks.append(ckpt)

# note: the exact Trainer flags depend on your Lightning version; the ones
# below target the lightning.pytorch API imported above (on multiple GPUs,
# also pass strategy="ddp_find_unused_parameters_true")
trainer = Trainer(
    max_epochs=kwargs["max_epochs"],
    accelerator="gpu",
    devices=[0],
    precision=kwargs["precision"],
    logger=wandb_logger,
    callbacks=callbacks,
    enable_checkpointing=False,  # our Checkpointer callback handles saving
    detect_anomaly=True,  # stop if NaNs/Infs show up during training
)

trainer.fit(model, train_loader, val_loader)
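Once training is done, we can also run a standalone validation pass to read off the final linear-probe accuracy (the exact metric names, e.g. val_acc1, depend on what LinearModel logs):

# optional: one final validation pass after training
trainer.validate(model, dataloaders=val_loader)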

And that’s it: we have basically replicated a small version of main_linear.py. Of course, we can accomplish the same thing by simply running the following script:

python3 ../../main_linear.py \
    --dataset cifar10 \
    --backbone resnet18 \
    --data_dir ./ \
    --max_epochs 100 \
    --gpus 0 \
    --sync_batchnorm \
    --precision 16 \
    --optimizer sgd \
    --scheduler step \
    --lr 0.1 \
    --lr_decay_steps 60 80 \
    --weight_decay 0 \
    --batch_size 128 \
    --num_workers 4 \
    --name general-linear-eval \
    --pretrained_feature_extractor path/to/pretrained/feature/extractor \
    --project self-supervised \
    --wandb

Now you are fully able to use solo-learn and turn your research ideas into reality!