SSIBench: Benchmarking Self-Supervised Learning Methods for Accelerated MRI Reconstruction
SSIBench is a modular benchmark for learning to solve imaging inverse problems without ground truth, applied to accelerated MRI reconstruction.
Andrew Wang, Steven McDonagh, Mike Davies
Overview
We contribute:
- A comprehensive review of state-of-the-art self-supervised feedforward methods for inverse problems;
- Well-documented implementations of all benchmarked methods in the open-source DeepInverse library, and a modular benchmark site enabling ML researchers to evaluate new methods, or to run evaluations on custom setups and datasets;
- Benchmarking experiments on MRI using a standardised setup across multiple realistic, general scenarios;
- A new method, multi-operator equivariant imaging (MO-EI).
How to use the benchmark
First, set up your environment:
- Create a Python virtual environment and activate it:
```shell
python -m venv venv
source venv/bin/activate  # Linux/macOS; on Windows (Git Bash) use: source venv/Scripts/activate
```
- Clone the benchmark repo:
```shell
git clone https://github.com/Andrewwango/ssibench.git
cd ssibench
```
- Install DeepInverse:
```shell
pip install deepinv  # Stable
pip install git+https://github.com/deepinv/deepinv.git#egg=deepinv  # Nightly
```
- Prepare your fastMRI data using the dataset preparation instructions below.
Then run `train.py` for your chosen loss, where `--loss` is the loss function (`mc`, `ei`, etc.) and `--physics` is the physics (see `train.py` for options):
```shell
python train.py --loss ... --physics ...
```
To evaluate, use the same script `train.py` with 0 epochs and a loaded checkpoint. We provide one pretrained model for quick evaluation (download here):
```shell
python train.py --epochs 0 --ckpt "demo_mo-ei.pth.tar"
```
Notation: in our benchmark, we compare the loss functions \(\mathcal{L}(\ldots)\), while keeping the model \(f_\theta\), the forward operator (physics) \(A\) and the data \(y\) constant.
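Schematically, and assuming for illustration a training set of \(N\) measurements \(y_i\), each with its own forward operator \(A_i\) (e.g. a different undersampling mask), each benchmarked loss fits the same model as below. Measurement consistency (MC) is shown as one concrete choice of \(\mathcal{L}\); this is a sketch of the notation, not a formula quoted from the paper:

\[
\hat{\theta} = \arg\min_{\theta} \frac{1}{N} \sum_{i=1}^{N} \mathcal{L}\big(f_\theta(y_i, A_i),\, y_i,\, A_i\big),
\qquad
\mathcal{L}_{\mathrm{MC}}(\hat{x}, y, A) = \lVert A\hat{x} - y \rVert_2^2 .
\]

The ground truth \(x\) appears in neither term; the supervised reference would instead minimise \(\lVert f_\theta(y_i, A_i) - x_i \rVert_2^2\).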
How to contribute a method
- Add the code for your loss in the format:
```python
import torch
import deepinv


class YourOwnLoss(deepinv.loss.Loss):
    def forward(
        self,
        x_net: torch.Tensor,  # Reconstruction, i.e. model output
        y: torch.Tensor,  # Measurement data, e.g. k-space in MRI
        model: deepinv.models.Reconstructor,  # Reconstruction model $f_\theta$
        physics: deepinv.physics.Physics,  # Forward operator physics $A$
        x: torch.Tensor = None,  # Ground truth, must be unused!
        **kwargs,
    ):
        loss_calc = ...
        return loss_calc
```
- Add your loss function as an option in `train.py` (hint: search “Add your custom loss here!”).
- Benchmark your method by running `train.py` (hint: see How to use the benchmark).
- Submit your results by editing the live leaderboard.
- Open a GitHub pull request to contribute your loss! (hint: see example here; hint: how to open a PR in GitHub)
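As a concrete instance of the loss template, the sketch below implements measurement consistency \(\lVert A x_\text{net} - y \rVert^2\) with the same `forward` signature. To run without DeepInverse it subclasses `torch.nn.Module` and uses a hypothetical toy operator `MaskedFFT` (single-coil Cartesian MRI with complex tensors); in the benchmark you would subclass `deepinv.loss.Loss` and receive DeepInverse physics instead. DeepInverse already ships MC as `dinv.loss.MCLoss`, which the training script below uses.

```python
import torch


class MaskedFFT:
    """Hypothetical toy single-coil MRI operator: A(x) = mask * FFT(x)."""

    def __init__(self, mask: torch.Tensor):
        self.mask = mask

    def A(self, x: torch.Tensor) -> torch.Tensor:
        return self.mask * torch.fft.fft2(x, norm="ortho")


class MyMCLoss(torch.nn.Module):
    """Measurement consistency ||A x_net - y||^2 with the template's forward signature.

    In the benchmark this would subclass deepinv.loss.Loss rather than torch.nn.Module.
    """

    def forward(self, x_net, y, model, physics, x=None, **kwargs):
        # The ground truth x is accepted for API compatibility but never used
        residual = physics.A(x_net) - y
        return residual.abs().pow(2).mean()


torch.manual_seed(0)
mask = torch.zeros(1, 1, 1, 8)
mask[..., ::2] = 1.0  # sample every other k-space column
physics = MaskedFFT(mask)
x_true = torch.randn(1, 1, 8, 8, dtype=torch.complex64)
y = physics.A(x_true)

loss_fn = MyMCLoss()
loss_consistent = loss_fn(x_net=x_true, y=y, model=None, physics=physics)
loss_perturbed = loss_fn(x_net=x_true + 0.1 * torch.randn_like(x_true), y=y, model=None, physics=physics)
```

A reconstruction that reproduces the measurements gets zero loss, while a perturbed one does not.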
How to use a custom dataset
Our modular benchmark lets you easily train and evaluate the benchmarked methods on your own setup.
- The custom dataset should have the form (see DeepInverse docs for details):
```python
import torch


class YourOwnDataset(torch.utils.data.Dataset):
    def __len__(self) -> int:
        ...  # Number of samples

    def __getitem__(self, idx: int):
        ...
        # y = measurement data
        # params = dict of physics data-dependent parameters, e.g. acceleration mask in MRI
        return x, y, params  # If ground truth x is provided for evaluation
        # return torch.nan, y, params  # If ground truth does not exist
```
- Replace `dataset = ...` in `train.py` with your own, then train/evaluate using the script as in How to use the benchmark.
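To make the `(x, y, params)` format concrete, here is a minimal hypothetical dataset that simulates single-coil k-space measurements from complex images. The class name, its arguments, uniform random column masks and the params key `"mask"` are all assumptions for illustration; each sample's mask is fixed by seeding a generator with its index, so repeated access returns the same measurement.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyMaskedKspaceDataset(Dataset):
    """Hypothetical dataset returning (x, y, params) in the format described above."""

    def __init__(self, images: torch.Tensor, acceleration: int = 4,
                 has_ground_truth: bool = True, seed: int = 0):
        self.images = images  # (N, 1, H, W) complex ground-truth images
        self.acceleration = acceleration
        self.has_ground_truth = has_ground_truth
        self.seed = seed

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int):
        x = self.images[idx]
        width = x.shape[-1]
        # Fixed per-sample random mask keeping width // acceleration k-space columns
        gen = torch.Generator().manual_seed(self.seed + idx)
        cols = torch.randperm(width, generator=gen)[: width // self.acceleration]
        mask = torch.zeros(1, 1, width)
        mask[..., cols] = 1.0
        y = mask * torch.fft.fft2(x, norm="ortho")  # measurement data (masked k-space)
        params = {"mask": mask}  # physics data-dependent parameters
        if self.has_ground_truth:
            return x, y, params
        return torch.nan, y, params  # no ground truth available


images = torch.randn(3, 1, 16, 16, dtype=torch.complex64)
loader = DataLoader(ToyMaskedKspaceDataset(images), batch_size=2)
x, y, params = next(iter(loader))
x_missing, _, _ = ToyMaskedKspaceDataset(images, has_ground_truth=False)[0]
```

PyTorch's default collation stacks the tensors and the tensors inside the `params` dict, so each batch arrives as `(x, y, {"mask": ...})`.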
How to use a custom model
- The custom model should have the form (see DeepInverse guide for details):
```python
import torch
import deepinv


class YourOwnModel(deepinv.models.Reconstructor):
    def forward(
        self,
        y: torch.Tensor,  # Measurement data
        physics: deepinv.physics.Physics,  # Forward operator physics $A$
        **kwargs,
    ):
        x_net = ...
        return x_net
```
- Replace `model = ...` in `train.py` with your own, then train/evaluate using the script as in How to use the benchmark.
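A small sketch of a model with the `forward(y, physics)` signature: it computes the zero-filled image \(A^H y\) and adds a learned residual from a tiny CNN acting on stacked real/imaginary channels. The class `ZeroFilledResidualModel`, the toy operator `MaskedFFT` and the CNN size are hypothetical; it subclasses `torch.nn.Module` so it runs stand-alone, whereas in the benchmark you would subclass `deepinv.models.Reconstructor`.

```python
import torch


class MaskedFFT:
    """Hypothetical toy operator: A(x) = mask * FFT(x), A^H(y) = IFFT(mask * y)."""

    def __init__(self, mask):
        self.mask = mask

    def A(self, x):
        return self.mask * torch.fft.fft2(x, norm="ortho")

    def A_adjoint(self, y):
        return torch.fft.ifft2(self.mask * y, norm="ortho")


class ZeroFilledResidualModel(torch.nn.Module):
    """Hypothetical model: zero-filled image A^H y plus a learned CNN residual."""

    def __init__(self, hidden: int = 16):
        super().__init__()
        self.cnn = torch.nn.Sequential(
            torch.nn.Conv2d(2, hidden, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(hidden, 2, 3, padding=1),
        )

    def forward(self, y, physics, **kwargs):
        x0 = physics.A_adjoint(y)  # (B, 1, H, W) complex zero-filled image
        feats = torch.cat([x0.real, x0.imag], dim=1)  # (B, 2, H, W) real channels
        out = feats + self.cnn(feats)  # residual refinement
        return torch.complex(out[:, :1], out[:, 1:])


torch.manual_seed(0)
mask = torch.zeros(1, 1, 1, 8)
mask[..., ::2] = 1.0
physics = MaskedFFT(mask)
y = physics.A(torch.randn(2, 1, 8, 8, dtype=torch.complex64))
model = ZeroFilledResidualModel()
x_net = model(y, physics)

# With the last convolution zeroed, the model reduces to the zero-filled image A^H y
with torch.no_grad():
    model.cnn[-1].weight.zero_()
    model.cnn[-1].bias.zero_()
x_zero_filled = model(y, physics)
```

Starting from \(A^H y\) means the network only has to learn a correction, which is a common design choice for MRI reconstructors.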
How to use a custom forward operator/acquisition strategy
- To use an alternative physics, you can use a different off-the-shelf DeepInverse physics or a custom one of the form (see DeepInverse guide on creating custom physics):
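```python
import torch
import deepinv


class YourOwnPhysics(deepinv.physics.Physics):
    def A(self, x: torch.Tensor, **kwargs):
        y = ...  # Forward operator: image -> measurements
        return y

    def A_adjoint(self, y: torch.Tensor, **kwargs):
        x_hat = ...  # Adjoint operator: measurements -> image
        return x_hat
```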
- Replace `physics = ...` in `train.py` with your own, then train/evaluate using the script as in How to use the benchmark.
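When writing `A` and `A_adjoint` by hand, it is worth checking that they really are adjoint, i.e. \(\langle Ax, y\rangle = \langle x, A^H y\rangle\). The sketch below does this for a hypothetical single-coil Cartesian operator \(A = MF\) (binary mask \(M\), orthonormal FFT \(F\)). It uses complex tensors for brevity; as we understand it, DeepInverse's MRI physics instead works on real tensors with separate real and imaginary channels (the training script's `UNet(2, 2, ...)` takes two channels).

```python
import torch


class CartesianMRI:
    """Hypothetical single-coil Cartesian MRI operator (sketch, not DeepInverse's MRI class).

    A(x) = M F x and A^H(y) = F^H M y, with F the orthonormal 2D FFT and M a binary mask.
    """

    def __init__(self, mask: torch.Tensor):
        self.mask = mask  # broadcastable to the k-space shape, values in {0, 1}

    def A(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        return self.mask * torch.fft.fft2(x, norm="ortho")

    def A_adjoint(self, y: torch.Tensor, **kwargs) -> torch.Tensor:
        return torch.fft.ifft2(self.mask * y, norm="ortho")


def adjointness_gap(physics, x, y) -> float:
    """|<A x, y> - <x, A^H y>|, which should be ~0 for a correct adjoint."""
    lhs = torch.vdot(physics.A(x).flatten(), y.flatten())  # vdot conjugates its first argument
    rhs = torch.vdot(x.flatten(), physics.A_adjoint(y).flatten())
    return (lhs - rhs).abs().item()


torch.manual_seed(0)
mask = torch.zeros(1, 1, 1, 16)
mask[..., ::3] = 1.0
physics = CartesianMRI(mask)
x = torch.randn(1, 1, 16, 16, dtype=torch.complex64)
y = torch.randn(1, 1, 16, 16, dtype=torch.complex64)
gap = adjointness_gap(physics, x, y)
```

Because \(F\) is unitary and \(M^2 = M\), \(A^H A = F^H M F\) is a projection, which gives a second cheap sanity check: applying it twice must equal applying it once.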
How to use a custom metric
- The custom metric should have the form (see DeepInverse docs for details):
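```python
import torch
import deepinv


class YourOwnMetric(deepinv.loss.metric.Metric):
    def metric(
        self,
        x_net: torch.Tensor,  # Reconstruction, i.e. model output
        x: torch.Tensor,  # Ground truth for evaluation
        *args,  # Any further arguments passed by the base class
        **kwargs,
    ):
        return ...
```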
- Replace `metrics = ...` in `train.py` with your own, then train/evaluate using the script as in How to use the benchmark.
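The training script evaluates metrics with `complex_abs=True`, i.e. on magnitude images. The sketch below shows a hypothetical PSNR metric in that spirit: it converts two-channel (real, imaginary) tensors to magnitudes and returns one PSNR value in dB per batch element, taking the peak as the per-sample maximum of the ground-truth magnitude. The class and helper names and the peak convention are assumptions; in the benchmark you would subclass `deepinv.loss.metric.Metric` and implement only `metric()`.

```python
import torch


def complex_abs(x: torch.Tensor) -> torch.Tensor:
    """(B, 2, H, W) real/imaginary channels -> (B, 1, H, W) magnitude."""
    return torch.sqrt(x[:, :1] ** 2 + x[:, 1:2] ** 2)


class MagnitudePSNR(torch.nn.Module):
    """Hypothetical magnitude PSNR in dB, one value per batch element."""

    def metric(self, x_net, x, *args, **kwargs):
        a, b = complex_abs(x_net), complex_abs(x)
        mse = (a - b).pow(2).flatten(1).mean(dim=1)  # per-sample mean squared error
        peak = b.flatten(1).amax(dim=1)  # per-sample ground-truth maximum
        return 10 * torch.log10(peak**2 / mse)

    def forward(self, x_net, x, *args, **kwargs):
        return self.metric(x_net, x, *args, **kwargs)


x = torch.zeros(2, 2, 4, 4)
x[:, 0] = 1.0  # real part 1, imaginary part 0: magnitude 1 everywhere
x_net = 0.9 * x  # magnitude 0.9 everywhere, so the MSE is 0.01
psnr = MagnitudePSNR()(x_net, x)  # 10 * log10(1 / 0.01) = 20 dB per sample
```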
Live leaderboard
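We provide a live leaderboard for each experimental scenario described in the paper. Results are reported as PSNR (dB) and SSIM, with rows sorted by increasing PSNR; (Supervised) denotes the reference trained with ground truth. Got a new method? Contribute it to the leaderboard!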
Scenario 1 (single-coil)
# | Loss | PSNR | SSIM |
---|---|---|---|
1 | UAIR | 14.00 | .3715 |
2 | Adversarial | 18.52 | .4732 |
3 | MC | 27.66 | .7861 |
4 | Zero-filled | 27.67 | .7862 |
5 | VORTEX | 27.75 | .7898 |
6 | SSDU | 27.98 | .7485 |
7 | Noise2Inverse | 28.42 | .7853 |
8 | Weighted-SSDU | 29.93 | .8355 |
9 | EI | 30.26 | .8523 |
10 | MOI | 30.29 | .8651 |
11 | MOC-SSDU | 30.42 | .8198 |
12 | SSDU-Consistency | 30.81 | .8495 |
13 | MO-EI | 32.14 | .8846 |
14 | (Supervised) | 33.15 | .9032 |
Scenario 2 (noisy)
# | Loss | PSNR | SSIM |
---|---|---|---|
1 | Zero-filled | 24.34 | .4428 |
2 | (Non-robust) Weighted-SSDU | 25.91 | .5477 |
3 | (Non-robust) MO-EI | 26.12 | .6002 |
4 | ENSURE | 26.29 | .5856 |
5 | Robust-SSDU | 27.42 | .6159 |
6 | Noise2Recon-SSDU | 27.84 | .7661 |
7 | DDSSL | 28.25 | .7836 |
8 | Robust-EI | 29.07 | .8227 |
9 | Robust-MO-EI | 29.72 | .8409 |
10 | (Supervised) | 30.19 | .8411 |
Scenario 3 (single-operator)
# | Loss | PSNR | SSIM |
---|---|---|---|
1 | UAIR | 18.44 | .5388 |
2 | SSDU | 21.89 | .6288 |
3 | Noise2Inverse | 24.63 | .6559 |
4 | Adversarial | 26.53 | .7013 |
5 | MOC-SSDU | 27.85 | .7717 |
6 | Zero-filled | 28.02 | .7900 |
7 | MC | 28.02 | .7900 |
8 | VORTEX | 28.07 | .7916 |
9 | Weighted-SSDU | 30.14 | .8454 |
10 | SSDU-Consistency | 31.05 | .8614 |
11 | MO-EI | 31.11 | .8713 |
12 | MOI | 31.60 | .8789 |
13 | EI | 31.99 | .8806 |
14 | (Supervised) | 34.03 | .9040 |
Scenario 4 (multi-coil)
# | Loss | PSNR | SSIM |
---|---|---|---|
1 | UAIR | 15.26 | .3453 |
2 | Adversarial | 17.47 | .6464 |
3 | VORTEX | 23.59 | .5846 |
4 | Zero-filled | 27.82 | .7988 |
5 | MC | 28.96 | .8271 |
6 | Noise2Inverse | 30.93 | .8589 |
7 | MOI | 31.37 | .8810 |
8 | SSDU | 31.47 | .8705 |
9 | MO-EI | 31.56 | .8836 |
10 | EI | 31.66 | .8769 |
11 | MOC-SSDU | 31.80 | .8761 |
12 | SSDU-Consistency | 32.30 | .8949 |
13 | Weighted-SSDU | 33.03 | .8991 |
14 | (Supervised) | 33.89 | .9147 |
Training script step-by-step
The training script makes extensive use of the modular training framework provided by DeepInverse.
Define training parameters:
```python
import torch
import deepinv as dinv

device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
torch.cuda.manual_seed(0)
rng = torch.Generator(device=device).manual_seed(0)
rng_cpu = torch.Generator(device="cpu").manual_seed(0)
acceleration = 6
batch_size = 4
lr = 1e-3
img_size = (320, 320)


class args:  # Command line args from train.py
    physics = "mri"
    epochs = 0
    loss = "mc"
    ckpt = None
```
Define MRI physics \(A\) and mask generator \(M\) according to the scenario
```python
physics_generator = dinv.physics.generator.GaussianMaskGenerator(img_size=img_size, acceleration=acceleration, rng=rng, device=device)
physics = dinv.physics.MRI(img_size=img_size, device=device)

match args.physics:  # `match` requires Python >= 3.10
    case "noisy":
        sigma = 0.1
        physics.noise_model = dinv.physics.GaussianNoise(sigma, rng=rng)
    case "multicoil":
        physics = dinv.physics.MultiCoilMRI(img_size=img_size, coil_maps=4, device=device)
    case "single":
        # Single-operator scenario: one fixed mask shared by all samples
        physics.update(**physics_generator.step())
```
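To illustrate what a mask generator with `acceleration = 6` produces, here is a standalone sketch of a 1D variable-density Cartesian mask: it keeps a fully sampled low-frequency band and draws the remaining columns with Gaussian weights centred on the k-space centre, so that about `width // acceleration` columns are sampled. The function `gaussian_column_mask` and its `center_fraction` and `std_fraction` defaults are hypothetical; this is not DeepInverse's `GaussianMaskGenerator` implementation, whose exact sampling scheme may differ.

```python
import torch


def gaussian_column_mask(width: int, acceleration: int, center_fraction: float = 0.08,
                         std_fraction: float = 0.25, generator=None) -> torch.Tensor:
    """Hypothetical variable-density column mask in centred (fftshifted) order."""
    n_total = width // acceleration
    n_center = min(n_total, max(1, round(width * center_fraction)))
    mask = torch.zeros(width)
    start = (width - n_center) // 2
    mask[start:start + n_center] = 1.0  # fully sampled low-frequency band

    # Gaussian sampling density over the remaining columns, centred on k-space centre
    cols = torch.arange(width, dtype=torch.float32)
    weights = torch.exp(-0.5 * ((cols - width / 2) / (std_fraction * width)) ** 2)
    weights[mask.bool()] = 0.0  # never redraw already-sampled columns
    n_rest = n_total - n_center
    if n_rest > 0:
        picked = torch.multinomial(weights, n_rest, replacement=False, generator=generator)
        mask[picked] = 1.0
    return mask


mask = gaussian_column_mask(width=320, acceleration=6, generator=torch.Generator().manual_seed(0))
# The mask is centred; undo the centring before multiplying an unshifted FFT
mask_unshifted = torch.fft.ifftshift(mask)
```

With `width=320` and `acceleration=6`, 53 of 320 columns are kept, 26 of them in the central band; passing a seeded generator makes the mask reproducible, in the same spirit as the `rng` passed to the generator above.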
Define model \(f_\theta\)
```python
denoiser = dinv.models.UNet(2, 2, scales=4, batch_norm=False)
model = dinv.models.MoDL(denoiser=denoiser, num_iter=3).to(device)
```
Define dataset
```python
dataset = dinv.datasets.SimpleFastMRISliceDataset("data", file_name="fastmri_brain_singlecoil.pt")
train_dataset, test_dataset = torch.utils.data.random_split(dataset, (0.8, 0.2), generator=rng_cpu)
```
Simulate and save random measurements
```python
dataset_path = dinv.datasets.generate_dataset(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    physics=physics,
    physics_generator=physics_generator if args.physics != "single" else None,
    save_physics_generator_params=True,
    overwrite_existing=False,
    device=device,
    save_dir="data",
    batch_size=1,
    dataset_filename="dataset_" + args.physics,
)

train_dataset = dinv.datasets.HDF5Dataset(dataset_path, split="train", load_physics_generator_params=True)
test_dataset = dinv.datasets.HDF5Dataset(dataset_path, split="test", load_physics_generator_params=True)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=rng_cpu)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)
```
Define loss function (see `train.py` for all options)
```python
match args.loss:
    case "mc":
        loss = dinv.loss.MCLoss()
    case "...":
        # Add your custom loss here!
        pass
```
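For example, an equivariant imaging (EI) style loss could sit in its own `case` branch. The sketch below is a standalone torch illustration of the EI idea: measurement consistency plus the requirement that \(f(A\,T_g f(y)) \approx T_g f(y)\) for a random transform \(T_g\), here assumed to be a rotation by a multiple of 90° (square images). It is not the benchmark's or DeepInverse's EI implementation (DeepInverse provides `dinv.loss.EILoss`), and MO-EI is not shown; `MaskedFFT` and `ei_style_loss` are hypothetical names.

```python
import torch


class MaskedFFT:
    """Hypothetical toy operator: A(x) = mask * FFT(x), A^H(y) = IFFT(mask * y)."""

    def __init__(self, mask):
        self.mask = mask

    def A(self, x):
        return self.mask * torch.fft.fft2(x, norm="ortho")

    def A_adjoint(self, y):
        return torch.fft.ifft2(self.mask * y, norm="ortho")


def ei_style_loss(y, physics, model, alpha: float = 1.0, generator=None):
    """Sketch: measurement consistency plus equivariance to random 90-degree rotations."""
    x1 = model(y, physics)  # f(y)
    mc = (physics.A(x1) - y).abs().pow(2).mean()  # ||A f(y) - y||^2
    k = int(torch.randint(1, 4, (1,), generator=generator))  # rotate by k * 90 degrees
    x2 = torch.rot90(x1, k, dims=(-2, -1))  # T_g f(y)
    x3 = model(physics.A(x2), physics)  # f(A T_g f(y))
    ei = (x3 - x2).abs().pow(2).mean()
    return mc + alpha * ei


torch.manual_seed(0)
x_true = torch.randn(1, 1, 8, 8, dtype=torch.complex64)
zero_filled = lambda y, physics: physics.A_adjoint(y)  # hypothetical stand-in for f_theta

full = MaskedFFT(torch.ones(1, 1, 1, 8))
partial_mask = torch.zeros(1, 1, 1, 8)
partial_mask[..., [0, 1, 2, 5]] = 1.0  # asymmetric column mask
partial = MaskedFFT(partial_mask)

loss_full = ei_style_loss(full.A(x_true), full, zero_filled)
loss_partial = ei_style_loss(partial.A(x_true), partial, zero_filled)
```

The zero-filled model satisfies MC exactly in both cases, but only the fully sampled operator makes it equivariant; under undersampling the EI term is positive, which is the signal EI uses to learn beyond the null space of \(A\).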
Define metrics
```python
metrics = [
    dinv.metric.PSNR(complex_abs=True),
    dinv.metric.SSIM(complex_abs=True),
]
```
Define trainer
```python
trainer = dinv.Trainer(
    model=model,
    physics=physics,
    optimizer=torch.optim.Adam(model.parameters(), lr=lr),
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    epochs=args.epochs,
    losses=loss,
    metrics=metrics,
    device=device,
    ckpt_pretrained=args.ckpt,
)
```
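Conceptually, each self-supervised training step loads the batch's physics parameters (e.g. masks) into the operator, reconstructs from `y` alone and evaluates the loss without `x`. The sketch below is a simplified, hypothetical loop illustrating that flow with toy components; it is not DeepInverse's `Trainer` code, which additionally handles evaluation, metrics, checkpointing and logging. `UpdatableMaskedFFT`, `ScaledZeroFilled`, `mc_loss` and `train_one_epoch` are all illustrative names.

```python
import torch


class UpdatableMaskedFFT:
    """Hypothetical toy operator A(x) = mask * FFT(x) whose mask is set per batch."""

    def __init__(self):
        self.mask = None

    def update(self, mask=None, **kwargs):
        self.mask = mask

    def A(self, x):
        return self.mask * torch.fft.fft2(x, norm="ortho")

    def A_adjoint(self, y):
        return torch.fft.ifft2(self.mask * y, norm="ortho")


class ScaledZeroFilled(torch.nn.Module):
    """Hypothetical model: a learnable real gain w applied to the zero-filled image A^H y."""

    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.tensor(0.5))

    def forward(self, y, physics, **kwargs):
        return self.w * physics.A_adjoint(y)


def mc_loss(x_net, y, physics, **kwargs):
    return (physics.A(x_net) - y).abs().pow(2).mean()


def train_one_epoch(model, physics, loss_fn, batches, optimizer):
    """Simplified self-supervised epoch: the ground truth is never used."""
    losses = []
    for y, params in batches:
        physics.update(**params)  # load this batch's masks into the operator
        x_net = model(y, physics)  # reconstruct from measurements only
        loss = loss_fn(x_net=x_net, y=y, physics=physics, model=model)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


torch.manual_seed(0)
batches = []
for _ in range(20):
    x = torch.randn(2, 1, 8, 8, dtype=torch.complex64)  # only used to simulate y
    mask = (torch.rand(2, 1, 1, 8) < 0.5).float()  # per-sample random column masks
    batches.append((mask * torch.fft.fft2(x, norm="ortho"), {"mask": mask}))

model = ScaledZeroFilled()
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
losses = train_one_epoch(model, UpdatableMaskedFFT(), mc_loss, batches, optimizer)
```

In this toy problem the MC loss equals \((w-1)^2\) times the mean energy of \(y\), so training drives \(w\) to 1, i.e. exactly the zero-filled reconstruction: MC alone says nothing about unsampled k-space.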
Define additional adversarial trainer (if needed)
```python
if args.loss in ("uair", "adversarial"):
    # `discrim` (discriminator) and `loss_d` (discriminator loss) must already be
    # defined when the adversarial loss is selected (see train.py)
    trainer = dinv.training.AdversarialTrainer(
        model=model,
        physics=physics,
        optimizer=dinv.training.AdversarialOptimizer(
            torch.optim.Adam(model.parameters(), lr=lr),
            torch.optim.Adam(discrim.parameters(), lr=lr),
        ),
        train_dataloader=train_dataloader,
        eval_dataloader=test_dataloader,
        epochs=args.epochs,
        losses=loss,
        metrics=metrics,
        device=device,
        ckpt_pretrained=args.ckpt,
    )
    trainer.D = discrim
    trainer.losses_d = loss_d
```
Train or evaluate!
```python
trainer.train()
print(trainer.test(test_dataloader))
```
Dataset preparation instructions
To prepare the fastMRI dataset `fastmri_brain_singlecoil.pt` used in `train.py` for the benchmark experiments, we make use of the fastMRI wrapper in DeepInverse.
- Download fastMRI brain dataset batch 0: `brain_multicoil_train_batch_0` (~98.5 GB).
- Generate an efficient dataset of the middle slices (note that this is deterministic; the random masks and noise are simulated in `train.py`):
```python
import deepinv

dataset = deepinv.datasets.FastMRISliceDataset(
    "/path/to/fastmri/brain/multicoil_train",
    slice_index="middle",
)
dataset.save_simple_dataset("data/fastmri_brain_singlecoil.pt")
```