Source code for meddlr.modeling.meta_arch.n2r

from typing import Dict, Optional

import torch
import torchvision.utils as tv_utils
from torch import nn

from meddlr.config.config import configurable
from meddlr.data.transforms.noise import NoiseModel
from meddlr.modeling.meta_arch.build import META_ARCH_REGISTRY, build_model
from meddlr.modeling.meta_arch.ssdu import SSDUModel
from meddlr.ops import complex as cplx
from meddlr.utils.events import get_event_storage
from meddlr.utils.general import nested_apply


[docs]@META_ARCH_REGISTRY.register() class N2RModel(nn.Module): """Noise2Recon model. Reference: AD Desai, BM Ozturkler, CM Sandino, et al. Noise2Recon: A Semi-Supervised Framework for Joint MRI Reconstruction and Denoising. ArXiv 2021. https://arxiv.org/abs/2110.00075 """ _version = 3
[docs] @configurable def __init__( self, model: nn.Module, noiser: NoiseModel, use_supervised_consistency: bool = False, vis_period: int = -1, ): """ Args: model: The base model. noiser: The additive noise module. use_supervised_consistency: Whether to apply noise-based consistency to supervised examples. vis_period (int, optional): The period over which to visualize images. If ``<=0``, it is ignored. Note if the ``model`` has a ``vis_period`` attribute, it will be overridden so that this class handles visualization. """ super().__init__() self.model = model # Visualization done by this model. # If sub-model is SSDU, we allow SSDU to log images # for the recon pipeline. if ( not isinstance(self.model, SSDUModel) and hasattr(self.model, "vis_period") and vis_period > 0 ): self.model.vis_period = -1 self.vis_period = vis_period # Whether to keep gradient for base images in transform. self.use_base_grad = False # Use supervised examples for consistency self.use_supervised_consistency = use_supervised_consistency self.noiser = noiser
def augment(self, inputs): """Noise augmentation module for the consistency branch. Args: inputs (Dict[str, Any]): The input dictionary. It must contain a key ``'kspace'``, which traditionally corresponds to the undersampled kspace when performing augmentation for consistency. Returns: Dict[str, Any]: The input dictionary with the kspace polluted with additive masked complex Gaussian noise. """ kspace = inputs["kspace"].clone() aug_kspace = self.noiser(kspace, clone=False) inputs = { k: nested_apply(v, lambda _v: _v.clone()) for k, v in inputs.items() if k != "kspace" } inputs["kspace"] = aug_kspace return inputs @torch.no_grad() def visualize_aug_training(self, kspace, kspace_aug, preds, preds_base, target=None): """Visualize training of augmented data. Args: kspace: The base kspace. kspace_aug: The augmented kspace. preds: Reconstruction of augmented kspace. Shape: NxHxWx2. preds_base: Reconstruction of base kspace. Shape: NxHxWx2. """ storage = get_event_storage() # calc mask for first coil only if cplx.is_complex(kspace): kspace = torch.view_as_real(kspace) kspace = kspace.cpu()[0, ..., 0, :].unsqueeze(0) if cplx.is_complex(kspace_aug): kspace_aug = torch.view_as_real(kspace_aug) kspace_aug = kspace_aug.cpu()[0, ..., 0, :].unsqueeze(0) preds = preds.cpu()[0, ...].unsqueeze(0) preds_base = preds_base.cpu()[0, ...].unsqueeze(0) all_images = [preds, preds_base] errors = [cplx.abs(preds_base - preds)] if target is not None: target = target.cpu()[0, ...].unsqueeze(0) all_images.append(target) errors.append(cplx.abs(target - preds)) all_images = torch.cat(all_images, dim=2) all_kspace = torch.cat([kspace, kspace_aug], dim=2) errors = torch.cat(errors, dim=2) imgs_to_write = { "phases": cplx.angle(all_images), "images": cplx.abs(all_images), "errors": errors, "masks": cplx.get_mask(kspace), "kspace": cplx.abs(all_kspace), } for name, data in imgs_to_write.items(): data = data.squeeze(-1).unsqueeze(1) data = tv_utils.make_grid(data, nrow=1, padding=1, normalize=True, scale_each=True) storage.put_image("train_aug/{}".format(name), data.numpy(), data_format="CHW") def _format_consistency_inputs( self, inputs_supervised: Optional[Dict[str, torch.Tensor]] = None, inputs_unsupervised: Optional[Dict[str, torch.Tensor]] = None, ) -> Dict[str, torch.Tensor]: """Generate base and augmented inputs to be used for consistency training. Args: inputs_supervised: A dict of inputs, their metadata, and their ground truth references. inputs_unsupervised: A dict of inputs and their metadata. Returns: Dict[str, Dict[str, Tensor]]: A dictionary of base inputs and augmented inputs: - 'base': Inputs to be used to generate the pseudo-label (i.e. target) for consistency optimization. - 'aug': Noise augmented inputs to use for consistency training. """ inputs_consistency = [] if inputs_unsupervised is not None: inputs_consistency.append(inputs_unsupervised) if self.use_supervised_consistency and inputs_supervised is not None: inputs_consistency.append({k: v for k, v in inputs_supervised.items() if k != "target"}) if len(inputs_consistency) == 0: return {} # No consistency training. if len(inputs_consistency) > 1: inputs_consistency = { k: torch.cat([x[k] for x in inputs_consistency], dim=0) for k in inputs_consistency[0].keys() } else: inputs_consistency = inputs_consistency[0] # Augment the inputs. inputs_consistency_aug = self.augment(inputs_consistency) return {"base": inputs_consistency, "aug": inputs_consistency_aug} def forward(self, inputs): if not self.training: assert ( "unsupervised" not in inputs ), "unsupervised inputs should not be provided in eval mode" inputs = inputs.get("supervised", inputs) return self.model(inputs) storage = get_event_storage() vis_training = self.training and self.vis_period > 0 and storage.iter % self.vis_period == 0 inputs_supervised = inputs.get("supervised", None) inputs_unsupervised = inputs.get("unsupervised", None) if inputs_supervised is None and inputs_unsupervised is None: raise ValueError("Examples not formatted in the proper way") # Whether to use self-supervised via data undersampling (SSDU) for reconstruction. is_ssdu_enabled = isinstance(self.model, SSDUModel) output_dict = {} # Reconstruction (supervised). if is_ssdu_enabled: output_dict["recon"] = self.model(inputs) elif inputs_supervised is not None: output_dict["recon"] = self.model( inputs_supervised, return_pp=True, vis_training=vis_training ) # Consistency (unsupervised). # kspace_aug = kspace + U \sigma \mathcal{N} # Loss = L(f(kspace_aug, \theta), f(kspace, \theta)) # If the model is an SSDU model, unpack it to use the internal model for consistency. model = self.model.model if is_ssdu_enabled else self.model consistency_inputs = self._format_consistency_inputs(inputs_supervised, inputs_unsupervised) if len(consistency_inputs) > 0: inputs_consistency = consistency_inputs["base"] inputs_consistency_aug = consistency_inputs["aug"] with torch.no_grad(): pred_base = model(inputs_consistency) # Target only used for visualization purposes not for loss. target = inputs_consistency.get("target", None) pred_base = pred_base["pred"] pred_aug: Dict[str, torch.Tensor] = model(inputs_consistency_aug, return_pp=True) pred_aug.pop("target", None) pred_aug["target"] = pred_base.detach() output_dict["consistency"] = pred_aug if vis_training: self.visualize_aug_training( inputs_consistency["kspace"], inputs_consistency_aug["kspace"], pred_aug["pred"], pred_base, target=target, ) return output_dict def load_state_dict(self, state_dict, strict=True): # pragma: no cover # TODO: Configure backwards compatibility if any(x.startswith("unrolled") for x in state_dict.keys()): raise ValueError( "`self.unrolled` was renamed to `self.model`. " "Backwards compatibility has not been configured." ) return super().load_state_dict(state_dict, strict) @classmethod def from_config(cls, cfg): model_cfg = cfg.clone() model_cfg.defrost() model_cfg.MODEL.META_ARCHITECTURE = cfg.MODEL.N2R.META_ARCHITECTURE model_cfg.freeze() model = build_model(model_cfg) noiser = NoiseModel.from_cfg(cfg) return { "model": model, "noiser": noiser, "use_supervised_consistency": cfg.MODEL.N2R.USE_SUPERVISED_CONSISTENCY, "vis_period": cfg.VIS_PERIOD, }