# Source code for meddlr.modeling.meta_arch.vortex

import logging

import torch
import torchvision.utils as tv_utils
from torch import nn

from meddlr.config.config import configurable
from meddlr.modeling.meta_arch.build import META_ARCH_REGISTRY, build_model
from meddlr.ops import complex as cplx
from meddlr.transforms.builtin.mri import MRIReconAugmentor
from meddlr.utils.events import get_event_storage
from meddlr.utils.general import move_to_device


@META_ARCH_REGISTRY.register()
class VortexModel(nn.Module):
    """VORTEX model.

    This is the generalized model implementation for augmentation-based consistency.
    It differs from :class:`N2RModel` and :class:`M2RModel` in some ways:

        1. **Generalizable augmentor**: :class:`MRIReconAugmentor` is used to perform
           augmentations.
        2. **Faster augmentations:** Augmentations are performed on the operating device
           (e.g. GPU) with large, but reproducible seeds.
        3. **Spatial augmentations**: Consistency with spatial augmentations is also
           supported. These augmentations are also used to transform the target image.

    Reference:
        A Desai, B Gunel, B Ozturkler, et al.
        VORTEX: Physics-Driven Data Augmentations Using Consistency Training
        for Robust Accelerated MRI Reconstruction.
        https://arxiv.org/abs/2111.02549.
    """

    _version = 1
    _aliases = ["A2RModel"]

    @configurable
    def __init__(
        self,
        model: nn.Module,
        augmentor: MRIReconAugmentor,
        use_supervised_consistency: bool = False,
        vis_period: int = -1,
    ):
        """
        Args:
            model (nn.Module): The base model.
            augmentor (MRIReconAugmentor): The augmentation module.
            use_supervised_consistency (bool, optional): If ``True``, use consistency
                with supervised examples too.
            vis_period (int, optional): The period over which to visualize images.
                If ``<=0``, it is ignored. Note if the ``model`` has a ``vis_period``
                attribute and ``vis_period > 0``, the model's attribute is set to
                ``-1`` so that this class handles visualization.
        """
        super().__init__()
        self.model = model
        self.augmentor = augmentor
        # Keep gradient for base images in transform.
        # NOTE(review): this flag is not read anywhere in this class.
        self.use_base_grad = False
        self.use_supervised_consistency = use_supervised_consistency

        # Visualization is done by this model, so disable it on the wrapped model.
        if hasattr(model, "vis_period") and vis_period > 0:
            self.model.vis_period = -1
        self.vis_period = vis_period

    def augment(self, inputs, pred_base):
        """Apply the consistency augmentations to a batch and its base prediction.

        Args:
            inputs (Dict[str, Any]): Model inputs. Must contain ``"kspace"`` and
                ``"maps"`` tensors.
            pred_base (torch.Tensor): Reconstruction of the un-augmented inputs.

        Returns:
            Tuple[Dict[str, Any], torch.Tensor]: The augmented inputs (``"kspace"``
            and ``"maps"`` replaced by the augmentor's outputs; other tensor entries
            cloned, non-tensor entries passed through) and the augmented base
            prediction (the augmentor's ``"target"`` output).
        """
        # Augment on the device the base prediction lives on (i.e. the model's
        # compute device) instead of a hard-coded "cuda", so CPU-only runs work.
        # Non-tensor predictions keep the legacy "cuda" default.
        device = pred_base.device if isinstance(pred_base, torch.Tensor) else "cuda"
        inputs = move_to_device(inputs, device=device)
        pred_base = move_to_device(pred_base, device=device)

        kspace, maps = inputs["kspace"].clone(), inputs["maps"].clone()
        out, _, _ = self.augmentor(kspace, maps, pred_base, mask=True)

        # Copy the remaining entries so later in-place ops cannot touch the caller's
        # tensors; kspace/maps are replaced by their augmented versions below.
        inputs = {
            k: v.clone() if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
            if k not in ("kspace", "maps")
        }
        inputs["kspace"] = out["kspace"]
        inputs["maps"] = out["maps"]
        aug_pred_base = out["target"]
        return inputs, aug_pred_base

    def visualize_aug_training(self, kspace, kspace_aug, preds, preds_base, target=None):
        """Visualize training of augmented data.

        Only the first example of the batch (and, for kspace, the first coil) is
        written to the event storage under the ``"train_aug/"`` prefix.

        Args:
            kspace: The base kspace.
            kspace_aug: The augmented kspace.
            preds: Reconstruction of augmented kspace. Shape: NxHxWx2.
            preds_base: Reconstruction of base kspace. Shape: NxHxWx2.
            target: Optional reference image; when given it is shown alongside
                the predictions and its error against ``preds`` is plotted.
        """
        storage = get_event_storage()

        with torch.no_grad():
            # calc mask for first coil only
            if cplx.is_complex(kspace):
                kspace = torch.view_as_real(kspace)
            kspace = kspace.cpu()[0, ..., 0, :].unsqueeze(0)
            if cplx.is_complex(kspace_aug):
                kspace_aug = torch.view_as_real(kspace_aug)
            kspace_aug = kspace_aug.cpu()[0, ..., 0, :].unsqueeze(0)
            preds = preds.cpu()[0, ...].unsqueeze(0)
            preds_base = preds_base.cpu()[0, ...].unsqueeze(0)

            all_images = [preds, preds_base]
            errors = [cplx.abs(preds_base - preds)]
            if target is not None:
                target = target.cpu()[0, ...].unsqueeze(0)
                all_images.append(target)
                errors.append(cplx.abs(target - preds))

            # Place images side by side along dim 2.
            all_images = torch.cat(all_images, dim=2)
            all_kspace = torch.cat([kspace, kspace_aug], dim=2)
            errors = torch.cat(errors, dim=2)

            imgs_to_write = {
                "phases": cplx.angle(all_images),
                "images": cplx.abs(all_images),
                "errors": errors,
                "masks": cplx.get_mask(kspace),
                "kspace": cplx.abs(all_kspace),
            }

            for name, data in imgs_to_write.items():
                data = data.squeeze(-1).unsqueeze(1)
                data = tv_utils.make_grid(
                    data, nrow=1, padding=1, normalize=True, scale_each=True
                )
                storage.put_image("train_aug/{}".format(name), data.numpy(), data_format="CHW")

    @staticmethod
    def _merge_consistency_inputs(inputs_list):
        """Combine one or more consistency input dicts into a single batch.

        A single dict is returned unchanged. Multiple dicts are concatenated along
        dim 0, key by key, over the keys present in *every* dict (ordered as in the
        first dict). Restricting to shared keys avoids a ``KeyError`` when, e.g.,
        the unsupervised dict carries a ``"target"`` that the stripped supervised
        dict does not.
        """
        if len(inputs_list) == 1:
            return inputs_list[0]
        others = inputs_list[1:]
        shared_keys = [k for k in inputs_list[0] if all(k in x for x in others)]
        return {k: torch.cat([x[k] for x in inputs_list], dim=0) for k in shared_keys}

    def forward(self, inputs):
        """Run supervised reconstruction and/or augmentation consistency.

        In eval mode, ``inputs["supervised"]`` (or ``inputs`` itself when that key
        is absent) is passed straight to the base model.

        In training mode, ``inputs`` may contain ``"supervised"`` and/or
        ``"unsupervised"`` dicts. The returned dict may contain:

            - ``"recon"``: base-model output on the supervised inputs.
            - ``"consistency"``: base-model output on the augmented inputs, with
              ``"target"`` set to the (detached) augmented base prediction.

        Raises:
            ValueError: If neither ``"supervised"`` nor ``"unsupervised"`` inputs
                are provided in training mode.
        """
        if not self.training:
            assert (
                "unsupervised" not in inputs
            ), "unsupervised inputs should not be provided in eval mode"
            inputs = inputs.get("supervised", inputs)
            return self.model(inputs)

        vis_training = False
        if self.training and self.vis_period > 0:
            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                vis_training = True

        inputs_supervised = inputs.get("supervised", None)
        inputs_unsupervised = inputs.get("unsupervised", None)
        if inputs_supervised is None and inputs_unsupervised is None:
            raise ValueError("Examples not formatted in the proper way")
        output_dict = {}

        # Recon
        if inputs_supervised is not None:
            output_dict["recon"] = self.model(
                inputs_supervised, return_pp=True, vis_training=vis_training
            )

        # Consistency.
        # kspace_aug = kspace + U \sigma \mathcal{N}
        # Loss = L(f(Ti(Te(kspace)), \theta), Te(f(kspace, \theta)))
        inputs_consistency = []
        if inputs_unsupervised is not None:
            inputs_consistency.append(inputs_unsupervised)
        if self.use_supervised_consistency and inputs_supervised is not None:
            # Drop the ground truth so the consistency branch never sees it.
            inputs_consistency.append(
                {k: v for k, v in inputs_supervised.items() if k != "target"}
            )

        if len(inputs_consistency) > 0:
            inputs_consistency = self._merge_consistency_inputs(inputs_consistency)

            with torch.no_grad():
                pred_base = self.model(inputs_consistency)
                # Target only used for visualization purposes not for loss.
                # With supervised-only consistency there are no unsupervised inputs.
                target = (
                    inputs_unsupervised.get("target", None)
                    if inputs_unsupervised is not None
                    else None
                )
                pred_base = pred_base["pred"]

            inputs_consistency_aug, pred_base = self.augment(inputs_consistency, pred_base)
            pred_aug = self.model(inputs_consistency_aug, return_pp=True)
            # The consistency target is the augmented base prediction, not any
            # target the model may have echoed back.
            pred_aug.pop("target", None)
            pred_aug["target"] = pred_base.detach()
            output_dict["consistency"] = pred_aug

            if vis_training:
                self.visualize_aug_training(
                    inputs_consistency["kspace"],
                    inputs_consistency_aug["kspace"],
                    pred_aug["pred"],
                    pred_base,
                    target=target,
                )

        return output_dict

    @classmethod
    def from_config(cls, cfg):
        """Build constructor kwargs from a config.

        The base model is built from a copy of ``cfg`` whose
        ``MODEL.META_ARCHITECTURE`` is replaced by ``MODEL.A2R.META_ARCHITECTURE``.
        The augmentor is built with ``aug_kind="consistency"`` on ``MODEL.DEVICE``
        seeded with ``SEED``.
        """
        _logger = logging.getLogger(__name__)

        model_cfg = cfg.clone()
        model_cfg.defrost()
        model_cfg.MODEL.META_ARCHITECTURE = cfg.MODEL.A2R.META_ARCHITECTURE
        model_cfg.freeze()
        model = build_model(model_cfg)

        augmentor = MRIReconAugmentor.from_cfg(
            cfg, aug_kind="consistency", device=cfg.MODEL.DEVICE, seed=cfg.SEED
        )
        _logger.info("Built augmentor:\n{}".format(str(augmentor.tfms_or_gens)))
        return {
            "model": model,
            "augmentor": augmentor,
            "use_supervised_consistency": cfg.MODEL.A2R.USE_SUPERVISED_CONSISTENCY,
            "vis_period": cfg.VIS_PERIOD,
        }