Source code for meddlr.modeling.meta_arch.ssdu

import warnings

import torch
import torchvision.utils as tv_utils
from torch import nn

import meddlr.ops as oF
from meddlr.config.config import configurable
from meddlr.forward.mri import SenseModel
from meddlr.modeling.meta_arch.build import META_ARCH_REGISTRY, build_model
from meddlr.ops import complex as cplx
from meddlr.transforms.base.mask import KspaceMaskTransform
from meddlr.transforms.builtin.mri import MRIReconAugmentor
from meddlr.transforms.gen.mask import RandomKspaceMask
from meddlr.utils.events import get_event_storage
from meddlr.utils.general import move_to_device, nested_apply


[docs]@META_ARCH_REGISTRY.register() class SSDUModel(nn.Module): """Self-supervised learning via data undersampling. This model is the relaxed form of the SSDU model that can be used to train with both supervised and unsupervised data. The mask used to acquire the data (:math:`\Omega`) is partitioned into train mask for the zero-filled image (:math:`\Theta`) and a mask for the loss (:math:`\Lambda`). Reference: B Yaman, SAH Hosseini, S Moeller, et al. Self-supervised learning of physics-guided reconstruction neural networks without fully sampled reference data. https://onlinelibrary-wiley-com.stanford.idm.oclc.org/doi/full/10.1002/mrm.28378 """ _version = 1
[docs] @configurable def __init__( self, model: nn.Module, masker: RandomKspaceMask, augmentor: MRIReconAugmentor = None, postprocessor: str = None, vis_period: int = None, ): """ Args: model (nn.Module): The base model. masker (NoiseModel): The masking model. augmentor: An augmentation model that can be used postprocessor: The postprocessing to perform on the image. """ super().__init__() self.model = model self.masker = masker self.augmentor = augmentor # Visualization done by this model if hasattr(self.model, "vis_period"): if vis_period is not None: self.model.vis_period = vis_period else: vis_period = self.model.vis_period self.model.vis_period = -1 self.vis_period = vis_period self.postprocessor = postprocessor
def augment(self, inputs): """Noise augmentation module for the consistency branch. Args: inputs (Dict[str, Any]): The input dictionary. It must contain a key ``'kspace'``, which traditionally corresponds to the undersampled kspace when performing augmentation for consistency. For supervised examples, this can correspond to either the retrospectively undersampled k-space or the fully-sampled kspace. Returns: Dict[str, Any]: The input dictionary with the kspace polluted with additive masked complex Gaussian noise. """ masker = self.masker kspace = inputs["kspace"].clone() mask = inputs.get("mask", None) if mask is None: mask = cplx.get_mask(kspace) edge_mask = inputs["edge_mask"] tfm: KspaceMaskTransform = masker.get_transform(kspace) train_mask = tfm.generate_mask(kspace, channels_last=True) loss_mask = mask - train_mask # The loss mask should be a subset of the original mask. # TODO (arjundd): See if we can remove this check for speed reasons. is_loss_mask_valid = torch.all(loss_mask >= 0) if not is_loss_mask_valid: idx = torch.where(loss_mask < 0) print("keys", inputs.keys()) raise ValueError( "Train mask is not a subset of the original mask.\n" f"Invalid indices: {idx}\n" f"Mask: {mask[idx]}\n" f"Mask (coils): {mask[idx[:-1]]}\n" f"Train mask: {train_mask[idx[:-1]]}\n" f"Loss mask: {loss_mask[idx]}\n" ) assert is_loss_mask_valid # Pad the train mask so that all unacquired kspace points # are included in the train_mask. train_mask = (train_mask.type(torch.bool) | edge_mask.type(torch.bool)).type(torch.float32) inputs = { k: nested_apply(v, lambda _v: _v.clone()) for k, v in inputs.items() if k != "kspace" } inputs["kspace"] = train_mask * kspace inputs["mask"] = train_mask if self.augmentor is not None: out, _, _ = self.augmentor(kspace=inputs["kspace"], maps=inputs["maps"]) inputs["kspace"] = out["kspace"] return inputs, mask[..., 0:1], train_mask, loss_mask[..., 0:1] @torch.no_grad() def visualize(self, images_dict): for name, images in images_dict.items(): storage = get_event_storage() if isinstance(images, (tuple, list)): images = torch.stack(images, dim=0) if cplx.is_complex_as_real(images) or cplx.is_complex(images): images = { f"{name}-phase": cplx.angle(images), f"{name}-mag": cplx.abs(images), } else: images = {name: images} for name, data in images.items(): if data.shape[-1] == 1: data = data.squeeze(-1) data = data.unsqueeze(1) data = tv_utils.make_grid( data, nrow=len(data), padding=1, normalize=True, scale_each=True ) storage.put_image("ssdu/{}".format(name), data.cpu().numpy(), data_format="CHW") def forward(self, inputs): if not self.training: assert ( "unsupervised" not in inputs ), "unsupervised inputs should not be provided in eval mode" inputs = inputs.get("supervised", inputs) mask = cplx.get_mask(inputs["kspace"]) # The mask should be the union of the edge mask and the sampled data mask. # https://github.com/byaman14/SSDU # If the edge mask is not passed in, we assume that we do not want to get # the edge mask. if "edge_mask" not in inputs: edge_mask = torch.tensor(0, device=mask.device, dtype=mask.dtype) warnings.warn("Edge mask not found in `inputs`. Assuming no edge mask.") else: edge_mask = inputs["edge_mask"] dc_mask = (mask + edge_mask).bool().to(mask.dtype) inputs["mask"] = dc_mask # inputs["postprocessing_mask"] = dc_mask - mask return self.model(inputs) storage = get_event_storage() vis_training = self.training and self.vis_period > 0 and storage.iter % self.vis_period == 0 # Put supervised and unsupervised scans in a single tensor. sup = inputs.get("supervised", {}) unsup = inputs.get("unsupervised", {}) # TODO: Make the cat operation recursive. sup = {k: v for k, v in sup.items() if k not in ["metrics", "_profiler", "stats"]} unsup = {k: v for k, v in unsup.items() if k not in ["metrics", "_profiler", "stats"]} if sup or unsup: inputs = { k: torch.cat([sup.get(k, torch.tensor([])), unsup.get(k, torch.tensor([]))]) for k in sup.keys() | unsup.keys() } assert all(k in inputs for k in ["kspace"]) device = next(self.model.parameters()).device inputs = move_to_device(inputs, device=device, non_blocking=True) kspace = inputs["kspace"] inputs_aug, orig_mask, train_mask, loss_mask = self.augment(inputs) outputs = self.model(inputs_aug, vis_training=vis_training and len(sup) > 0) # Get the signal model reconstructed images. # TODO: Make it possible to use these are the target instead of multi-coil images. pred_img = outputs["pred"] target_img, zf_image = outputs.get("target", None), outputs.get("zf_image", None) # Use signal model (SENSE) to get weighted kspace. A = SenseModel(maps=inputs_aug["maps"]) # no weights - we do not want to mask the data. loss_pred_kspace = loss_mask * A(outputs["pred"], adjoint=False) loss_kspace = loss_mask * kspace # TODO: Refactor post processing to be general to all reconstruction networks. postprocessing_mask = None if self.postprocessor == "hard_dc_edge": postprocessing_mask = inputs["edge_mask"] elif self.postprocessor == "hard_dc_all": postprocessing_mask = train_mask # Do this for differentiability reasons. if postprocessing_mask is not None: loss_pred_kspace = ( 1 - postprocessing_mask ) * loss_pred_kspace + postprocessing_mask * inputs["kspace"] # A hacky way to prepare the predictions and target for the loss. # This may result in inaccurate training metrics outside of the loss. # TODO (arjundd): Fix this. # Shape: B x H x W x #coils outputs["pred"] = oF.ifft2c(loss_pred_kspace, channels_last=True) outputs["target"] = oF.ifft2c(loss_kspace, channels_last=True) # Visualize. if self.training and self.vis_period > 0: with torch.no_grad(): storage = get_event_storage() if storage.iter % self.vis_period == 0: A = SenseModel(maps=inputs["maps"][0:1], weights=train_mask[0:1]) base_image = A(kspace[0:1], adjoint=True) self.visualize( { "masks": [orig_mask[0], train_mask[0], loss_mask[0]], "kspace": [ kspace[0, ..., 0:1], inputs_aug["kspace"][0, ..., 0:1], loss_pred_kspace[0, ..., 0:1], loss_kspace[0, ..., 0:1], ], "images": [ x[0] for x in [base_image, zf_image, pred_img, target_img] if x is not None ], } ) return outputs @classmethod def from_config(cls, cfg): model_cfg = cfg.clone() model_cfg.defrost() model_cfg.MODEL.META_ARCHITECTURE = cfg.MODEL.SSDU.META_ARCHITECTURE model_cfg.freeze() model = build_model(model_cfg) # Train/loss mask splitter. params = cfg.MODEL.SSDU.MASKER.PARAMS masker = RandomKspaceMask(**params) masker.to(cfg.MODEL.DEVICE) init_kwargs = {"model": model, "masker": masker} # Build augmentor. aug_cfg = cfg.MODEL.SSDU.AUGMENTOR if aug_cfg.TRANSFORMS: augmentor = MRIReconAugmentor.from_cfg(aug_cfg, aug_kind=None, seed=cfg.SEED) init_kwargs["augmentor"] = augmentor # Build postprocessor. postprocessor = cfg.MODEL.SSDU.POSTPROCESSOR.NAME or None init_kwargs["postprocessor"] = postprocessor return init_kwargs