SSIBench: Benchmarking Self-Supervised Learning Methods for Accelerated MRI Reconstruction

SSIBench is a modular benchmark for learning to solve imaging inverse problems without ground truth, applied to accelerated MRI reconstruction.

Andrew Wang, Steven McDonagh, Mike Davies



Skip to…

  1. Overview
  2. How to…
    1. …use the benchmark
    2. …contribute a method
    3. …use a custom dataset, model, forward operator/acquisition strategy, metric
  3. Live leaderboard
  4. Training script step-by-step
  5. Dataset preparation instructions

Overview

SSIBench is a modular benchmark for learning to solve imaging inverse problems without ground truth, applied to accelerated MRI reconstruction. We contribute:

  1. A comprehensive review of state-of-the-art self-supervised feedforward methods for inverse problems;
  2. Well-documented implementations of all benchmarked methods in the open-source DeepInverse library, and a modular benchmark site enabling ML researchers to evaluate new methods, or to evaluate existing methods on custom setups and datasets;
  3. Benchmarking experiments for MRI on a standardised setup across multiple realistic, general scenarios;
  4. A new method, multi-operator equivariant imaging (MO-EI).

How to…

How to use the benchmark

First, set up your environment:

  1. Create a Python environment:
python -m venv venv
source venv/bin/activate       # Linux/macOS
source venv/Scripts/activate   # Windows (Git Bash)
  2. Clone the benchmark repo:
git clone https://github.com/Andrewwango/ssibench.git
cd ssibench
  3. Install DeepInverse (stable release or nightly build):
pip install deepinv   # Stable
pip install git+https://github.com/deepinv/deepinv.git#egg=deepinv   # Nightly
  4. Prepare your fastMRI data following the dataset preparation instructions below.

Then run train.py with your chosen loss and scenario. Here --loss selects the loss function (e.g. mc, ei) and --physics selects the experimental setup (e.g. mri, noisy, multicoil, single; see train.py for all options):

python train.py --loss ... --physics ...

To evaluate, run the same train.py script with 0 epochs and a checkpoint to load. We provide one pretrained model for quick evaluation (download here):

python train.py --epochs 0 --ckpt "demo_mo-ei.pth.tar"

Notation: in our benchmark, we compare the loss functions \(\mathcal{L}(\ldots)\) while keeping the model \(f_\theta\), the forward operator (physics) \(A\) and the data \(y\) fixed.
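The measurement-consistency loss (--loss mc) is a concrete instance of this notation. The form below is a sketch that assumes a squared-error penalty averaged over \(N\) training measurements; the normalisation used in train.py may differ. The identity on the right additionally assumes a noiseless single-coil Cartesian operator \(A = MF\), with binary mask \(M\) and orthonormal Fourier transform \(F\):

\[
\mathcal{L}_{\mathrm{MC}}(\theta) = \frac{1}{N}\sum_{i=1}^{N} \big\| A f_\theta(y_i) - y_i \big\|_2^2, \qquad A A^{H} y_i = M F F^{H} M y_i = M y_i = y_i .
\]

The identity shows that the zero-filled estimate \(A^H y_i\) already attains zero MC loss. MC constrains only the sampled k-space locations.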

How to contribute a method

  1. Add the code for your loss in the format:
import torch
import deepinv

class YourOwnLoss(deepinv.loss.Loss):
    def forward(
        self,
        x_net: torch.Tensor,    # Reconstruction, i.e. model output
        y: torch.Tensor,        # Measurement data, e.g. k-space in MRI
        model: deepinv.models.Reconstructor,  # Reconstruction model f_theta
        physics: deepinv.physics.Physics,     # Forward operator physics A
        x: torch.Tensor = None, # Ground truth, must be unused!
        **kwargs
    ):
        loss_calc = ...  # Compute the loss from x_net, y, model and physics only
        return loss_calc
  2. Add your loss function as an option in train.py (hint: search for "Add your custom loss here!").
  3. Benchmark your method by running train.py (hint: see "How to use the benchmark").
  4. Submit your results by editing the live leaderboard.
  5. Open a GitHub pull request to contribute your loss! (hint: see example here; hint: how to open a PR in GitHub)
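The NumPy sketch below makes the loss template concrete on a toy problem. It contrasts measurement consistency (MC) with an SSDU-style k-space splitting loss.

It is a simplified illustration, not the DeepInverse implementations of MCLoss or SSDU. The operator is an assumed single-coil Cartesian \(A = MF\). The names forward_op, adjoint_op, zero_filled_model, mc_loss and splitting_loss are hypothetical. Inside YourOwnLoss.forward, the same computations would use physics.A, y and model.

```python
import numpy as np


def forward_op(x, mask):
    """Hypothetical single-coil Cartesian operator A = M F (orthonormal 2D FFT)."""
    return mask * np.fft.fft2(x, norm="ortho")


def adjoint_op(y, mask):
    """Adjoint A^H = F^H M."""
    return np.fft.ifft2(mask * y, norm="ortho")


def zero_filled_model(y, mask):
    """Stand-in for f_theta: the zero-filled reconstruction A^H y."""
    return adjoint_op(y, mask)


def mc_loss(model, y, mask):
    """Measurement consistency on all sampled k-space entries."""
    x_net = model(y, mask)
    return np.mean(np.abs(forward_op(x_net, mask) - y) ** 2)


def splitting_loss(model, y, mask, split_ratio=0.4, rng=None):
    """SSDU-style sketch: split the sampled k-space into disjoint input/loss
    subsets, reconstruct from the input subset only, score on the held-out one."""
    rng = np.random.default_rng() if rng is None else rng
    loss_mask = mask * (rng.random(mask.shape) < split_ratio)  # held-out entries
    input_mask = mask - loss_mask                               # entries fed to the model
    x_net = model(input_mask * y, input_mask)
    return np.mean(np.abs(forward_op(x_net, loss_mask) - loss_mask * y) ** 2)


rng = np.random.default_rng(0)
mask = np.zeros((1, 32))
mask[0, ::2] = 1.0  # keep every 2nd k-space column
y = forward_op(rng.standard_normal((32, 32)), mask)

print(mc_loss(zero_filled_model, y, mask))                  # ~0 (round-off level)
print(splitting_loss(zero_filled_model, y, mask, rng=rng))  # typically > 0
```

The zero-filled model scores perfectly under MC. It does not under the splitting loss, because the model never sees the held-out k-space entries.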

How to use a custom dataset

Our modular benchmark lets you easily train and evaluate the benchmarked methods on your own setup.

  1. The custom dataset should have the form (see DeepInverse docs for details):
import torch

class YourOwnDataset(torch.utils.data.Dataset):
    def __len__(self) -> int:
        ...  # Number of samples

    def __getitem__(self, idx: int):
        ...
        # x = ground truth (used for evaluation only), or None if unavailable
        # y = measurement data, e.g. k-space in MRI
        # params = dict of physics data-dependent parameters, e.g. acceleration mask in MRI
        if x is None:
            return torch.nan, y, params  # If ground truth does not exist
        return x, y, params              # If ground truth x provided for evaluation
  2. Replace dataset = ... in train.py with your own, then train/evaluate using the script as in How to use the benchmark.
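The following is a minimal sketch of the (x, y, params) convention, assuming a hypothetical toy dataset. It simulates single-coil Cartesian k-space from images with a random column mask. ToyMRIDataset and its arguments are illustrative names. NumPy stands in for torch so the sketch is self-contained.

```python
import numpy as np


class ToyMRIDataset:
    """Hypothetical toy dataset following the (x, y, params) convention.

    In practice this logic would live in a torch.utils.data.Dataset subclass;
    NumPy is used here only to keep the sketch self-contained.
    """

    def __init__(self, images, acceleration=4, has_ground_truth=True, seed=0):
        self.images = images                  # list of 2D real-valued images
        self.acceleration = acceleration      # keep ~1/acceleration of k-space columns
        self.has_ground_truth = has_ground_truth
        self.rng = np.random.default_rng(seed)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        x = self.images[idx]
        width = x.shape[-1]
        # Random Cartesian column mask, drawn afresh on every access
        mask = np.zeros((1, width))
        cols = self.rng.choice(width, size=width // self.acceleration, replace=False)
        mask[0, cols] = 1.0
        y = mask * np.fft.fft2(x, norm="ortho")  # undersampled k-space
        params = {"mask": mask}                  # data-dependent physics parameters
        if self.has_ground_truth:
            return x, y, params
        return float("nan"), y, params           # ground truth unavailable


images = [np.ones((16, 16)), np.zeros((16, 16))]
x, y, params = ToyMRIDataset(images)[0]
print(y.shape, int(params["mask"].sum()))  # (16, 16) 4
```

Drawing a new mask on every access is one design choice. The benchmark's training script instead simulates the measurements once and stores them, together with the mask parameters, via dinv.datasets.generate_dataset.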

How to use a custom model

  1. The custom model should have the form (see DeepInverse guide for details):
import torch
import deepinv

class YourOwnModel(deepinv.models.Reconstructor):
    def forward(
        self,
        y: torch.Tensor,                   # Measurement data
        physics: deepinv.physics.Physics,  # Forward operator physics A
        **kwargs
    ):
        x_net = ...  # Reconstruction f_theta(y), which may call the physics
        return x_net
  2. Replace model = ... in train.py with your own, then train/evaluate using the script as in How to use the benchmark.
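The template fixes only the interface: a reconstructor maps the measurements y and the physics to an image. The NumPy sketch below illustrates a physics-aware model. It is hypothetical, not one of the benchmarked architectures.

The model is an unrolled network that alternates gradient steps on \(\tfrac{1}{2}\|Ax - y\|^2\) with a placeholder learnable correction. ToyCartesianMRI and UnrolledGradientModel are illustrative names.

```python
import numpy as np


class ToyCartesianMRI:
    """Hypothetical stand-in for a physics object exposing A and A_adjoint."""

    def __init__(self, mask):
        self.mask = mask

    def A(self, x):
        return self.mask * np.fft.fft2(x, norm="ortho")

    def A_adjoint(self, y):
        return np.fft.ifft2(self.mask * y, norm="ortho")


class UnrolledGradientModel:
    """Hypothetical reconstructor mapping (y, physics) -> x_net: unrolled gradient
    steps on 0.5 * ||A x - y||^2, each followed by a placeholder correction."""

    def __init__(self, num_iter=3, step_size=1.0):
        self.num_iter = num_iter
        self.step_size = step_size

    def correction(self, x):
        return x  # placeholder for a learned network, e.g. a denoiser

    def forward(self, y, physics):
        x = physics.A_adjoint(y)  # zero-filled initialisation
        for _ in range(self.num_iter):
            x = x - self.step_size * physics.A_adjoint(physics.A(x) - y)
            x = self.correction(x)
        return x

    __call__ = forward


mask = np.zeros((1, 16))
mask[0, ::2] = 1.0  # keep every 2nd k-space column
physics = ToyCartesianMRI(mask)
y = physics.A(np.random.default_rng(0).standard_normal((16, 16)))
x_net = UnrolledGradientModel()(y, physics)
print(x_net.shape)  # (16, 16)
```

With the identity placeholder, the output stays at the zero-filled estimate \(A^H y\). A learned correction is needed to fill in the unsampled k-space.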

How to use a custom forward operator/acquisition strategy

  1. To use an alternative forward operator, either use a different off-the-shelf DeepInverse physics or write a custom linear one of the form (see the DeepInverse guide on creating custom physics):
import torch
import deepinv

class YourOwnPhysics(deepinv.physics.LinearPhysics):  # LinearPhysics provides the A_adjoint interface
    def A(self, x: torch.Tensor, **kwargs):
        y = ...      # Forward operator: image -> measurements
        return y

    def A_adjoint(self, y: torch.Tensor, **kwargs):
        x_hat = ...  # Adjoint operator: measurements -> image
        return x_hat
  2. Replace physics = ... in train.py with your own, then train/evaluate using the script as in How to use the benchmark.
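When writing a custom A_adjoint, check that it really is the adjoint of A, i.e. that \(\langle Ax, y\rangle = \langle x, A^H y\rangle\) for all \(x, y\). The NumPy sketch below performs this check for an assumed single-coil Cartesian operator \(A = MF\). SingleCoilCartesianMRI and adjointness_gap are hypothetical names, not DeepInverse's dinv.physics.MRI.

```python
import numpy as np


class SingleCoilCartesianMRI:
    """Hypothetical operator A = M F: binary k-space mask M, orthonormal 2D FFT F."""

    def __init__(self, mask):
        self.mask = mask

    def A(self, x):
        return self.mask * np.fft.fft2(x, norm="ortho")

    def A_adjoint(self, y):
        # F is unitary and M is real and diagonal, so A^H = F^H M
        return np.fft.ifft2(self.mask * y, norm="ortho")


def adjointness_gap(physics, shape, seed=0):
    """Relative gap |<A x, y> - <x, A^H y>| / |<A x, y>| for random complex x, y."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
    y = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
    lhs = np.vdot(physics.A(x), y)          # <A x, y> = sum(conj(A x) * y)
    rhs = np.vdot(x, physics.A_adjoint(y))  # <x, A^H y>
    return abs(lhs - rhs) / abs(lhs)


mask = np.zeros((1, 64))
mask[0, ::4] = 1.0  # keep every 4th column (4x acceleration)
physics = SingleCoilCartesianMRI(mask)
print(adjointness_gap(physics, (64, 64)))  # close to 0 (floating-point round-off)
```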

How to use a custom metric

  1. The custom metric should have the form (see DeepInverse docs for details):
import torch
import deepinv

class YourOwnMetric(deepinv.loss.metric.Metric):
    def metric(
        self,
        x_net: torch.Tensor, # Reconstruction, i.e. model output
        x: torch.Tensor,     # Ground truth for evaluation
        *args,
        **kwargs
    ):
        return ...  # Metric value computed from x_net and x
  2. Replace metrics = ... in train.py with your own, then train/evaluate using the script as in How to use the benchmark.
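The NumPy sketch below illustrates what a metric body computes: PSNR on magnitude images, in the spirit of the complex_abs=True option used in the training script. It assumes complex images are stored as two channels (real, imaginary), as with the 2-channel UNet in train.py. The function psnr_complex_abs and its max_pixel default are hypothetical.

```python
import numpy as np


def psnr_complex_abs(x_net, x, max_pixel=1.0):
    """Hypothetical PSNR (dB) on magnitude images.

    x_net, x: arrays of shape (2, H, W) holding real and imaginary channels;
    the magnitude is taken before computing the mean squared error.
    """
    mag_net = np.sqrt(x_net[0] ** 2 + x_net[1] ** 2)
    mag = np.sqrt(x[0] ** 2 + x[1] ** 2)
    mse = np.mean((mag_net - mag) ** 2)
    return 10.0 * np.log10(max_pixel ** 2 / mse)


x = np.zeros((2, 8, 8))      # ground truth: all-zero complex image
x_net = np.zeros((2, 8, 8))
x_net[0] = 0.1               # reconstruction: constant magnitude 0.1
print(psnr_complex_abs(x_net, x))  # approximately 20 dB (MSE = 0.01)
```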

Live leaderboard

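We provide a live leaderboard for each experimental scenario described in the paper. Entries are listed in order of increasing PSNR. (Supervised) is the reference model trained with ground truth. Got a new method? Contribute it to the leaderboard!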

Scenario 1 (single-coil)
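| # | Loss | PSNR (dB) | SSIM |
|---|------|-----------|------|
| 1 | UAIR | 14.00 | .3715 |
| 2 | Adversarial | 18.52 | .4732 |
| 3 | MC | 27.66 | .7861 |
| 4 | Zero-filled | 27.67 | .7862 |
| 5 | VORTEX | 27.75 | .7898 |
| 6 | SSDU | 27.98 | .7485 |
| 7 | Noise2Inverse | 28.42 | .7853 |
| 8 | Weighted-SSDU | 29.93 | .8355 |
| 9 | EI | 30.26 | .8523 |
| 10 | MOI | 30.29 | .8651 |
| 11 | MOC-SSDU | 30.42 | .8198 |
| 12 | SSDU-Consistency | 30.81 | .8495 |
| 13 | MO-EI | 32.14 | .8846 |
| 14 | (Supervised) | 33.15 | .9032 |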

Scenario 2 (noisy)
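| # | Loss | PSNR (dB) | SSIM |
|---|------|-----------|------|
| 1 | Zero-filled | 24.34 | .4428 |
| 2 | (Non-robust) Weighted-SSDU | 25.91 | .5477 |
| 3 | (Non-robust) MO-EI | 26.12 | .6002 |
| 4 | ENSURE | 26.29 | .5856 |
| 5 | Robust-SSDU | 27.42 | .6159 |
| 6 | Noise2Recon-SSDU | 27.84 | .7661 |
| 7 | DDSSL | 28.25 | .7836 |
| 8 | Robust-EI | 29.07 | .8227 |
| 9 | Robust-MO-EI | 29.72 | .8409 |
| 10 | (Supervised) | 30.19 | .8411 |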

Scenario 3 (single-operator)
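| # | Loss | PSNR (dB) | SSIM |
|---|------|-----------|------|
| 1 | UAIR | 18.44 | .5388 |
| 2 | SSDU | 21.89 | .6288 |
| 3 | Noise2Inverse | 24.63 | .6559 |
| 4 | Adversarial | 26.53 | .7013 |
| 5 | MOC-SSDU | 27.85 | .7717 |
| 6 | Zero-filled | 28.02 | .7900 |
| 7 | MC | 28.02 | .7900 |
| 8 | VORTEX | 28.07 | .7916 |
| 9 | Weighted-SSDU | 30.14 | .8454 |
| 10 | SSDU-Consistency | 31.05 | .8614 |
| 11 | MO-EI | 31.11 | .8713 |
| 12 | MOI | 31.60 | .8789 |
| 13 | EI | 31.99 | .8806 |
| 14 | (Supervised) | 34.03 | .9040 |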

Scenario 4 (multi-coil)
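| # | Loss | PSNR (dB) | SSIM |
|---|------|-----------|------|
| 1 | UAIR | 15.26 | .3453 |
| 2 | Adversarial | 17.47 | .6464 |
| 3 | VORTEX | 23.59 | .5846 |
| 4 | Zero-filled | 27.82 | .7988 |
| 5 | MC | 28.96 | .8271 |
| 6 | Noise2Inverse | 30.93 | .8589 |
| 7 | MOI | 31.37 | .8810 |
| 8 | SSDU | 31.47 | .8705 |
| 9 | MO-EI | 31.56 | .8836 |
| 10 | EI | 31.66 | .8769 |
| 11 | MOC-SSDU | 31.80 | .8761 |
| 12 | SSDU-Consistency | 32.30 | .8949 |
| 13 | Weighted-SSDU | 33.03 | .8991 |
| 14 | (Supervised) | 33.89 | .9147 |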


Training script step-by-step

The training script makes extensive use of the modular training framework provided by DeepInverse.

import deepinv as dinv
import torch

Define training parameters:

device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
torch.cuda.manual_seed(0)
rng = torch.Generator(device=device).manual_seed(0)
rng_cpu = torch.Generator(device="cpu").manual_seed(0)
acceleration = 6
batch_size = 4
lr = 1e-3
img_size = (320, 320)

class args: # Command line args from train.py
    physics = "mri"
    epochs = 0
    loss = "mc"
    ckpt = None

Define the MRI physics \(A\) and mask generator \(M\) according to the scenario

physics_generator = dinv.physics.generator.GaussianMaskGenerator(img_size=img_size, acceleration=acceleration, rng=rng, device=device)
physics = dinv.physics.MRI(img_size=img_size, device=device)

match args.physics:
    case "noisy":
        sigma = 0.1
        physics.noise_model = dinv.physics.GaussianNoise(sigma, rng=rng)
    case "multicoil":
        physics = dinv.physics.MultiCoilMRI(img_size=img_size, coil_maps=4, device=device)
    case "single":
        physics.update(**physics_generator.step())
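In every scenario except single, physics_generator is passed to generate_dataset, so a random mask is drawn for each simulated measurement. In single, one mask is drawn once and fixed via physics.update(...).

The NumPy sketch below shows what a variable-density Gaussian column mask at acceleration 6 can look like. It is not DeepInverse's GaussianMaskGenerator. The function gaussian_column_mask, the fully sampled centre band and the Gaussian width are hypothetical choices. The sketch also assumes a centred (fftshifted) k-space layout.

```python
import numpy as np


def gaussian_column_mask(width=320, acceleration=6, center_lines=16, seed=0):
    """Hypothetical 1D variable-density Gaussian column mask (illustration only).

    Keeps width // acceleration columns in total: a fully sampled band of
    center_lines low-frequency columns, plus extra columns drawn without
    replacement with probability decaying like a Gaussian away from the centre.
    """
    rng = np.random.default_rng(seed)
    n_lines = width // acceleration            # total columns to keep
    center = width // 2                        # k-space centre (fftshifted layout)
    mask = np.zeros(width)
    lo = center - center_lines // 2
    mask[lo:lo + center_lines] = 1.0           # fully sampled centre band
    candidates = np.flatnonzero(mask == 0)
    weights = np.exp(-0.5 * ((candidates - center) / (0.25 * width)) ** 2)
    extra = rng.choice(candidates, size=max(n_lines - center_lines, 0),
                       replace=False, p=weights / weights.sum())
    mask[extra] = 1.0
    return mask  # shape (width,), broadcast over image rows as a column mask


mask = gaussian_column_mask()
print(int(mask.sum()))  # 53 sampled columns out of 320 (about 6x acceleration)
```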

Define model \(f_\theta\)

denoiser = dinv.models.UNet(2, 2, scales=4, batch_norm=False)
model = dinv.models.MoDL(denoiser=denoiser, num_iter=3).to(device)
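MoDL alternates the denoiser (here the UNet) with a data-consistency step that solves a regularised least-squares problem. In general that step is solved iteratively, e.g. by conjugate gradient. For an assumed single-coil Cartesian operator \(A = MF\), it has a closed form that is elementwise in k-space.

The NumPy sketch below shows that closed form and checks its optimality condition. modl_data_consistency and the fixed weight lam are hypothetical. DeepInverse's internal MoDL implementation is not reproduced here.

```python
import numpy as np


def modl_data_consistency(z, y, mask, lam):
    """Hypothetical closed-form data-consistency step for A = M F (binary mask M,
    orthonormal 2D FFT F), illustration only:

        x = argmin_x ||A x - y||^2 + lam * ||x - z||^2
          = F^H [(M y + lam F z) / (M + lam)]     (elementwise in k-space)

    where z is the current denoiser output."""
    k = (mask * y + lam * np.fft.fft2(z, norm="ortho")) / (mask + lam)
    return np.fft.ifft2(k, norm="ortho")


def forward_op(x, mask):
    return mask * np.fft.fft2(x, norm="ortho")      # A x


def adjoint_op(y, mask):
    return np.fft.ifft2(mask * y, norm="ortho")     # A^H y


rng = np.random.default_rng(0)
mask = np.zeros((1, 16))
mask[0, ::3] = 1.0                                  # keep every 3rd column
y = forward_op(rng.standard_normal((16, 16)), mask)
z = rng.standard_normal((16, 16))                   # stand-in for a denoiser output
lam = 0.5
x = modl_data_consistency(z, y, mask, lam)

# Optimality check: A^H (A x - y) + lam * (x - z) should vanish at the minimiser
grad = adjoint_op(forward_op(x, mask) - y, mask) + lam * (x - z)
print(np.max(np.abs(grad)))                         # close to 0 (round-off level)
```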

Define dataset

dataset = dinv.datasets.SimpleFastMRISliceDataset("data", file_name="fastmri_brain_singlecoil.pt")
train_dataset, test_dataset = torch.utils.data.random_split(dataset, (0.8, 0.2), generator=rng_cpu)

Simulate and save random measurements

dataset_path = dinv.datasets.generate_dataset(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    physics=physics,
    physics_generator=physics_generator if args.physics != "single" else None,
    save_physics_generator_params=True,
    overwrite_existing=False,
    device=device,
    save_dir="data",
    batch_size=1,
    dataset_filename="dataset_" + args.physics
)

train_dataset = dinv.datasets.HDF5Dataset(dataset_path, split="train", load_physics_generator_params=True)
test_dataset  = dinv.datasets.HDF5Dataset(dataset_path, split="test",  load_physics_generator_params=True)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=rng_cpu)
test_dataloader  = torch.utils.data.DataLoader(test_dataset,  batch_size=batch_size)

Define loss function (see train.py for all options)

match args.loss:
    case "mc":
        loss = dinv.loss.MCLoss()

    case "...":
        # Add your custom loss here!
        pass

Define metrics

metrics = [
    dinv.metric.PSNR(complex_abs=True),
    dinv.metric.SSIM(complex_abs=True)
]

Define trainer

trainer = dinv.Trainer(
    model = model,
    physics = physics,
    optimizer = torch.optim.Adam(model.parameters(), lr=lr),
    train_dataloader = train_dataloader,
    eval_dataloader = test_dataloader,
    epochs = args.epochs,
    losses = loss,
    metrics = metrics,
    device = device,
    ckpt_pretrained=args.ckpt,
)
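
Define additional adversarial trainer (if needed)

# discrim (discriminator network) and loss_d (discriminator loss) are not shown here;
# they are assumed to be defined alongside the adversarial losses in train.py
if args.loss in ("uair", "adversarial"):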
    trainer = dinv.training.AdversarialTrainer(
        model = model,
        physics = physics,
        optimizer = dinv.training.AdversarialOptimizer(
            torch.optim.Adam(model.parameters(), lr=lr), 
            torch.optim.Adam(discrim.parameters(), lr=lr)
        ),
        train_dataloader = train_dataloader,
        eval_dataloader = test_dataloader,
        epochs = args.epochs,
        losses = loss,
        metrics = metrics,
        device = device,
        ckpt_pretrained=args.ckpt,
    )

    trainer.D = discrim
    trainer.losses_d = loss_d

Train or evaluate!

trainer.train()

print(trainer.test(test_dataloader))

Dataset preparation instructions

To prepare the fastMRI dataset fastmri_brain_singlecoil.pt used in train.py for the benchmark experiments, we make use of the fastMRI wrapper in DeepInverse.

  1. Download fastMRI brain dataset batch 0: brain_multicoil_train_batch_0 (~98.5 GB)
  2. Generate an efficient dataset of the middle slices (this step is deterministic; the random masks and noise are simulated later in train.py):
import deepinv

dataset = deepinv.datasets.FastMRISliceDataset(
    "/path/to/fastmri/brain/multicoil_train",
    slice_index="middle"
)

# Save to a single file, later loaded with SimpleFastMRISliceDataset in train.py
dataset.save_simple_dataset("data/fastmri_brain_singlecoil.pt")