# Source code for meddlr.modeling.meta_arch.cs_model

"""Compressed Sensing (2D).

This file implements the compressed sensing framework of Lustig et al.
using the Python package SigPy.
See the tutorial for details.

Tutorial:
    https://sigpy.readthedocs.io/en/latest/generated/sigpy.mri.app.L1WaveletRecon.html#sigpy.mri.app.L1WaveletRecon

Reference:
    Lustig, M., Donoho, D., & Pauly, J. M. (2007). Sparse MRI: The application of compressed
    sensing for rapid MR imaging. Magnetic Resonance in Medicine, 58(6), 1182-1195.
"""
import numpy as np
import sigpy as sp
import sigpy.mri as mr
import torch
from torch import nn

import meddlr.ops.complex as cplx
from meddlr.config.config import configurable
from meddlr.forward.mri import SenseModel
from meddlr.utils.general import move_to_device

from .build import META_ARCH_REGISTRY

try:
    import cupy as cp

    _CUPY_AVAILABLE = True
except ImportError:
    cp = None
    _CUPY_AVAILABLE = False

__all__ = ["CSModel"]


@META_ARCH_REGISTRY.register()
class CSModel(nn.Module):
    r"""Compressed sensing reconstruction with l1 wavelet regularization.

    This class is a PyTorch wrapper around SigPy's ``L1WaveletRecon`` app.
    On each forward pass, each example is reconstructed using
    :math:`\ell_1` wavelet-regularized compressed sensing.

    If the model should run on a GPU, `cupy` must be installed.

    Note:
        Gradients are not supported.

    Attributes:
        device (torch.device): Device to use for execution.
        l1_reg (float): :math:`\ell_1` regularization parameter.
        max_iter (int): Maximum number of iterations.
        num_emaps (int): Number of sensitivity maps.
    """

    @configurable
    def __init__(self, reg: float, max_iter: int, device="cpu", num_emaps: int = 1):
        """
        Args:
            reg (float): The regularization strength.
            max_iter (int): Maximum number of iterations.
            device (str | torch.device, optional): The device to execute on.
                Strings such as ``"cpu"`` or ``"cuda:0"`` are converted to
                :class:`torch.device`.
            num_emaps (int, optional): Number of estimated sensitivity maps.
                Currently only ``1`` is supported.

        Raises:
            ModuleNotFoundError: If a non-CPU device is requested but cupy is
                not installed.
            ValueError: If ``num_emaps != 1``.
        """
        super().__init__()
        # Normalize to torch.device so the CPU checks here and in forward() do
        # not depend on comparing a str against a torch.device (which does not
        # compare device names).
        resolved_device = torch.device(device)
        if resolved_device.type != "cpu" and not _CUPY_AVAILABLE:
            raise ModuleNotFoundError(
                f"Requested device {device}, but cupy not installed. "
                f"Install cupy>=9.0 following instructions at "
                f"https://docs.cupy.dev/en/stable/install.html"
            )
        self.device = resolved_device

        # Extract network parameters
        self.l1_reg = reg
        self.max_iter = max_iter

        # Data dimensions
        self.num_emaps = num_emaps
        if self.num_emaps != 1:
            raise ValueError("CSModel currently only supports one sensitivity map.")

    def forward(self, inputs, return_pp=False, vis_training=False):
        """Reconstruct one example with l1-wavelet compressed sensing.

        TODO: condense into list of dataset dicts.

        Args:
            inputs: Standard ss_recon module input dictionary
                * "kspace": Kspace. If fully sampled, and want to simulate
                  undersampled kspace, provide "mask" argument.
                * "maps": Sensitivity maps
                * "target" (optional): Target image (typically fully sampled).
                * "mask" (optional): Undersampling mask to apply. If absent,
                  the mask is derived from ``kspace`` via ``cplx.get_mask``.
                * "signal_model" (optional): The signal model. If provided,
                  "maps" will not be used to estimate the signal model.
                  Use with caution.
            return_pp (bool, optional): If `True`, return post-processing
                parameters "mean", "std", and "norm" if included in the input.
            vis_training (bool, optional): Accepted for interface
                compatibility; it is not used by this model.

        Returns:
            Dict: A standard ss_recon output dict
                * "pred": The reconstructed image, with a leading batch
                  dimension and a trailing maps dimension added.
                * "target": The target image, or ``None`` if not provided.
                * "mean"/"std"/"norm" (optional): Pre-processing parameters.
                  Added only when ``return_pp`` is set and the key is present
                  in the input.
                * "zf_image": The zero-filled image. Added when model is in
                  eval mode.

        Raises:
            ValueError: If the batch size is not 1, or if the number of
                sensitivity maps does not match ``num_emaps``.
        """
        if inputs["kspace"].shape[0] != 1:
            raise ValueError("Only batch size == 1 is supported in compressed sensing")

        # Resolve at call time; torch.device() also accepts an existing
        # torch.device, so a later reassignment of ``self.device`` is honored.
        device = torch.device(self.device)
        inputs = move_to_device(inputs, device)

        kspace = inputs["kspace"]
        target = inputs.get("target", None)
        mask = inputs.get("mask", None)
        A = inputs.get("signal_model", None)
        maps = inputs["maps"]
        num_maps_dim = -2 if cplx.is_complex_as_real(maps) else -1
        if self.num_emaps != maps.size()[num_maps_dim]:
            raise ValueError("Incorrect number of ESPIRiT maps! Re-prep data...")

        if mask is None:
            mask = cplx.get_mask(kspace)
        # Out-of-place: move_to_device may return the caller's own tensor, and
        # an in-place multiply would silently overwrite their kspace.
        kspace = kspace * mask

        # Declare signal model.
        if A is None:
            A = SenseModel(maps, weights=mask)
        zf_image = A(kspace, adjoint=True)

        # Channel-first - (#coils, ky, kz)
        # TODO: Generalize to 3D
        kspace = kspace[0].permute((2, 0, 1))
        maps = maps.squeeze(num_maps_dim)[0].permute((2, 0, 1))
        mask = mask[0].permute((2, 0, 1))

        # SigPy runs on numpy arrays for CPU and on cupy arrays for GPU.
        xp = np if device.type == "cpu" else cp
        kspace = xp.asarray(kspace)
        maps = xp.asarray(maps)
        mask = xp.asarray(mask)

        image = mr.app.L1WaveletRecon(
            kspace,
            maps,
            self.l1_reg,
            weights=mask,
            max_iter=self.max_iter,
            device=sp.get_device(kspace),
        ).run()

        # Back to torch; re-add the batch (leading) and maps (trailing) dims.
        image = torch.as_tensor(image, device=device)
        image = image.unsqueeze(0).unsqueeze(-1)

        output_dict = {"pred": image, "target": target}

        if return_pp:
            output_dict.update(
                {k: inputs[k] for k in ("mean", "std", "norm") if k in inputs}
            )

        if not self.training:
            output_dict["zf_image"] = zf_image

        return output_dict

    @classmethod
    def from_config(cls, cfg):
        """Map config fields to ``__init__`` keyword arguments."""
        return {
            "reg": cfg.MODEL.CS.REGULARIZATION,
            "max_iter": cfg.MODEL.CS.MAX_ITER,
            "device": cfg.MODEL.DEVICE,
            "num_emaps": cfg.MODEL.UNROLLED.NUM_EMAPS,
        }