Adding New Methods (Momentum Version)

By now, you should be familiar with how to implement new methods in solo-learn. If not, please read this tutorial: Adding New Methods. This tutorial will help you create methods that use a momentum backbone. Let’s now suppose we wanted to implement NNBYOL (similar to NNSiam but with a momentum backbone). As always, the first thing to do is to navigate to solo/methods and create a python file for our new method (e.g.: nnbyol.py):

import argparse
from typing import Any, Dict, List, Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from solo.losses.byol import byol_loss_func
from solo.methods.base import BaseMomentumMethod
from solo.utils.momentum import initialize_momentum_params
from solo.utils.misc import gather


class NNBYOL(BaseMomentumMethod):
    def __init__(
        self,
        proj_output_dim: int,
        proj_hidden_dim: int,
        pred_hidden_dim: int,
        queue_size: int,
        **kwargs,
    ):
        """Implements NNBYOL (https://arxiv.org/abs/2104.14548).

        Builds an online projector, a momentum projector with the same
        architecture, an online predictor, and a queue of L2-normalized
        projections together with their labels.

        Args:
            proj_output_dim (int): number of dimensions of the projected features.
            proj_hidden_dim (int): number of neurons in the hidden layer of the projector.
            pred_hidden_dim (int): number of neurons in the hidden layer of the predictor.
            queue_size (int): number of samples to keep in the queue.
        """

        super().__init__(**kwargs)

        self.queue_size = queue_size

        def make_mlp(in_dim: int, hidden_dim: int, out_dim: int) -> nn.Sequential:
            # Linear -> BatchNorm -> ReLU -> Linear; shared by every head below.
            return nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, out_dim),
            )

        # Online projector and its momentum counterpart come from the same
        # builder, so both have identical layer layouts before being paired
        # by initialize_momentum_params (NOTE(review): identical layouts are
        # assumed to be required by that helper; confirm in solo.utils.momentum).
        self.projector = make_mlp(self.features_dim, proj_hidden_dim, proj_output_dim)
        self.momentum_projector = make_mlp(
            self.features_dim, proj_hidden_dim, proj_output_dim
        )
        initialize_momentum_params(self.projector, self.momentum_projector)

        # Predictor maps projections back to the projection space.
        self.predictor = make_mlp(proj_output_dim, pred_hidden_dim, proj_output_dim)

        # Queue starts as random unit-norm rows. Labels start at -1, presumably
        # meaning "no label yet". queue_ptr holds the next write position.
        initial_queue = F.normalize(torch.randn(self.queue_size, proj_output_dim), dim=1)
        self.register_buffer("queue", initial_queue)
        self.register_buffer("queue_y", -torch.ones(self.queue_size, dtype=torch.long))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @property
    def learnable_params(self) -> List[dict]:
        """Returns the parent's parameter groups plus the projector and predictor.

        The momentum projector is intentionally not listed here.

        Returns:
            List[dict]: list of learnable parameter groups.
        """

        parent_groups = super().learnable_params
        head_groups = [
            {"params": self.projector.parameters()},
            {"params": self.predictor.parameters()},
        ]
        return parent_groups + head_groups

Note that here we are inheriting from BaseMomentumMethod, which already implements most of the complexity of momentum-based models. Apart from this, and similarly to NNSiam, NNBYOL has a projector, a predictor and a queue. However, NNBYOL also has a momentum backbone and a momentum projector that need to be updated at every step. The library already implements this behavior for the momentum backbone. To achieve the same for the momentum projector, all you need to do is override the momentum_pairs property of the parent:

@property
def momentum_pairs(self) -> List[Tuple[Any, Any]]:
    """Extends the parent's momentum pairs with the projector pair.

    Appending ``(self.projector, self.momentum_projector)`` makes the
    momentum projector follow the online projector in the same way the
    library already handles the momentum backbone.

    Returns:
        List[Tuple[Any, Any]]: the parent's (online, momentum) pairs followed
        by the projector pair.
    """

    projector_pair = (self.projector, self.momentum_projector)
    inherited_pairs = super().momentum_pairs
    return inherited_pairs + [projector_pair]

You can then use the momentum backbone's features, which the parent's training step returns under `momentum_feats`, in your own training step:

def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
    """Training step for NNBYOL reusing BaseMomentumMethod's training step.

    The parent step provides the classification loss and the online and
    momentum backbone features of the two crops. This step:

    1. projects and predicts with the online heads;
    2. projects with the momentum projector;
    3. replaces each normalized momentum projection with its nearest
       neighbour from the queue;
    4. applies the BYOL loss between each view's prediction and the other
       view's nearest neighbour;
    5. updates the queue.

    Args:
        batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where
            [X] is a list of size num_crops containing batches of images.
        batch_idx (int): index of the batch.

    Returns:
        torch.Tensor: total loss, i.e. the NNBYOL negative cosine similarity
        loss plus the classification loss.
    """

    targets = batch[-1]

    out = super().training_step(batch, batch_idx)
    class_loss = out["loss"]
    feats1, feats2 = out["feats"]
    momentum_feats1, momentum_feats2 = out["momentum_feats"]

    # online branch: projector followed by predictor
    z1 = self.projector(feats1)
    z2 = self.projector(feats2)

    p1 = self.predictor(z1)
    p2 = self.predictor(z2)

    # forward the momentum projector (the momentum backbone already ran in the
    # parent step); no gradients are tracked for the momentum branch
    with torch.no_grad():
        z1_momentum = self.momentum_projector(momentum_feats1)
        z2_momentum = self.momentum_projector(momentum_feats2)

    # unit-normalize before the nearest-neighbour lookup and before enqueueing
    z1_momentum = F.normalize(z1_momentum, dim=-1)
    z2_momentum = F.normalize(z2_momentum, dim=-1)

    # nearest neighbours in the queue; idx1 is kept for the accuracy metric
    idx1, nn1_momentum = self.find_nn(z1_momentum)
    _, nn2_momentum = self.find_nn(z2_momentum)

    # ------- negative cosine similarity loss -------
    # cross-view targets: view-1 prediction vs. view-2 neighbour and vice versa
    neg_cos_sim = byol_loss_func(p1, nn2_momentum) + byol_loss_func(p2, nn1_momentum)

    # fraction of view-1 samples whose retrieved neighbour carries the same label
    b = targets.size(0)
    nn_acc = (targets == self.queue_y[idx1]).sum() / b

    # only the first view's momentum projections (and labels) enter the queue
    self.dequeue_and_enqueue(z1_momentum, targets)

    # average per-dimension std of the normalized online projections
    z1_std = F.normalize(z1, dim=-1).std(dim=0).mean()
    z2_std = F.normalize(z2, dim=-1).std(dim=0).mean()
    z_std = (z1_std + z2_std) / 2

    metrics = {
        "train_neg_cos_sim": neg_cos_sim,
        "train_z_std": z_std,
        "train_nn_acc": nn_acc,
    }
    self.log_dict(metrics, on_epoch=True, sync_dist=True)

    return neg_cos_sim + class_loss

For the rest of the code for NNBYOL, please check out the implementation in solo/methods/nnbyol.py.