Source code for meddlr.modeling.meta_arch.denoising

import itertools

import torch
import torchvision.utils as tv_utils
from torch import nn

import meddlr.ops.complex as cplx
from meddlr.config.config import configurable
from meddlr.data.transforms.noise import NoiseModel
from meddlr.forward.mri import SenseModel
from meddlr.modeling.meta_arch.build import META_ARCH_REGISTRY, build_model
from meddlr.utils.events import get_event_storage
from meddlr.utils.general import move_to_device

__all__ = ["DenoisingModel"]


@META_ARCH_REGISTRY.register()
class DenoisingModel(nn.Module):
    """A denoising trainer.

    Wraps a base model: k-space is corrupted by ``noiser`` and the wrapped
    model is run on the corrupted input. The target is either the provided
    (fully sampled) image or the adjoint reconstruction of the original,
    uncorrupted k-space, depending on ``use_fully_sampled_target`` /
    ``use_fully_sampled_target_eval``.
    """

    @configurable
    def __init__(
        self,
        model,
        noiser: NoiseModel,
        use_fully_sampled_target=False,
        use_fully_sampled_target_eval: bool = None,
        vis_period: int = -1,
    ):
        """
        Args:
            model (nn.Module): The base model.
            noiser (NoiseModel): The additive noise model.
            use_fully_sampled_target (bool, optional): If ``True``,
                use fully sampled images as the target for denoising
                during training.
            use_fully_sampled_target_eval (bool, optional): If ``True``,
                use fully sampled images as the target for denoising
                during evaluation (includes validation). If ``None``,
                defaults to ``use_fully_sampled_target``.
            vis_period (int, optional): If ``> 0``, ``forward`` requests a
                training visualization from the wrapped model whenever the
                event-storage iteration is a multiple of this value.
        """
        super().__init__()
        self.model = model
        self.noiser = noiser

        # The eval-time setting falls back to the training-time setting.
        if use_fully_sampled_target_eval is None:
            use_fully_sampled_target_eval = use_fully_sampled_target
        self.use_fully_sampled_target = use_fully_sampled_target
        self.use_fully_sampled_target_eval = use_fully_sampled_target_eval

        # When this wrapper drives the visualization schedule, switch off the
        # wrapped model's own schedule so images are not triggered twice.
        if hasattr(self.model, "vis_period") and vis_period > 0:
            self.model.vis_period = -1
        self.vis_period = vis_period

    def augment(self, kspace):
        """Noise augmentation module.

        The input is detached and cloned before being handed to the noiser
        (called with ``clone=False``), so the caller's tensor is not the one
        passed to the noise model.

        TODO: Perform the augmentation here.
        """
        kspace = kspace.detach().clone()
        return self.noiser(kspace, clone=False)

    def visualize_training(self, kspace, zfs, targets, preds):
        """A function used to visualize reconstructions.

        Only the first example of the batch is written to the event storage,
        under ``train/phases``, ``train/images``, ``train/errors`` and
        ``train/masks``.

        Args:
            kspace: K-space tensor; converted with ``torch.view_as_real`` if
                complex. Only used to compute the mask.
            zfs: Zero-filled images (presumably same layout as ``targets``;
                they are concatenated with them along dim 2).
            targets: NxHxWx2 tensors of target images.
            preds: NxHxWx2 tensors of predictions.
        """
        storage = get_event_storage()

        with torch.no_grad():
            if cplx.is_complex(kspace):
                kspace = torch.view_as_real(kspace)
            kspace = kspace[0, ..., 0, :].unsqueeze(0).cpu()  # calc mask for first coil only
            targets = targets[0, ...].unsqueeze(0).cpu()
            preds = preds[0, ...].unsqueeze(0).cpu()
            zfs = zfs[0, ...].unsqueeze(0).cpu()

            # Stack zero-filled / prediction / target along dim 2 into one panel.
            all_images = torch.cat([zfs, preds, targets], dim=2)

            imgs_to_write = {
                "phases": cplx.angle(all_images),
                "images": cplx.abs(all_images),
                "errors": cplx.abs(preds - targets),
                "masks": cplx.get_mask(kspace),
            }

            for name, data in imgs_to_write.items():
                data = data.squeeze(-1).unsqueeze(1)
                data = tv_utils.make_grid(
                    data, nrow=1, padding=1, normalize=True, scale_each=True
                )
                storage.put_image("train/{}".format(name), data.numpy(), data_format="CHW")

    def _kspace_from_target(self, group):
        """Simulate k-space from ``group["target"]`` and store a noisy copy in place."""
        A = SenseModel(group["maps"])
        group["kspace"] = self.augment(A(group["target"], adjoint=False))

    def _target_from_kspace(self, group):
        """Use the adjoint of the original k-space as target, then noise the k-space in place."""
        kspace = group["kspace"]
        A = SenseModel(group["maps"], weights=cplx.get_mask(kspace))
        # The target must be computed before the k-space entry is overwritten.
        group["target"] = A(kspace, adjoint=True).detach()
        group["kspace"] = self.augment(kspace)

    @staticmethod
    def _merge_groups(groups):
        """Merge per-group input dicts into one dict over their shared keys.

        Tensors are concatenated along dim 0. Other values are flattened into
        a single list: list/tuple values contribute their elements, anything
        else contributes itself as one item.
        """
        shared_keys = set.intersection(*(set(group.keys()) for group in groups.values()))
        merged = {}
        for key in shared_keys:
            values = [group[key] for group in groups.values()]
            if isinstance(values[0], torch.Tensor):
                merged[key] = torch.cat(values, dim=0)
            else:
                # ``itertools.chain(values)`` would only re-iterate ``values``
                # without flattening; ``chain.from_iterable`` actually joins
                # the per-group sequences, mirroring ``torch.cat`` above.
                merged[key] = list(
                    itertools.chain.from_iterable(
                        value if isinstance(value, (list, tuple)) else (value,)
                        for value in values
                    )
                )
        return merged

    def forward(self, inputs, return_pp=False, vis_training=False):
        """Corrupt the inputs with noise and run the wrapped model.

        TODO: condense into list of dataset dicts.

        Args:
            inputs: Standard meddlr module input dictionary, or a dictionary
                with ``"supervised"`` and/or ``"unsupervised"`` entries that
                each hold such a dictionary. This wrapper itself reads only
                ``"kspace"``, ``"maps"`` and ``"target"``; other entries are
                merged and passed on to the wrapped model.
                * "kspace": Kspace. If fully sampled, and want to simulate
                    undersampled kspace, provide "mask" argument.
                * "maps": Sensitivity maps
                * "target" (optional): Target image (typically fully sampled).
                * "mask" (optional): Undersampling mask to apply.
                * "signal_model" (optional): The signal model. If provided,
                    "maps" will not be used to estimate the signal model.
                    Use with caution.
            return_pp (bool, optional): If `True`, return post-processing
                parameters "mean", "std", and "norm" if included in the input.
                NOTE(review): this argument is currently not forwarded; the
                wrapped model is always called with ``return_pp=True``.
            vis_training (bool, optional): If `True`, force visualize training
                on this pass. Can only be `True` if model is in training mode.

        Returns:
            Dict: The output of the wrapped model; presumably a standard
            meddlr output dict
                * "pred": The reconstructed image
                * "target" (optional): The target image.
                    Added if provided in the input.
                * "mean"/"std"/"norm" (optional): Pre-processing parameters.
                    Added if provided in the input.
                * "zf_image": The zero-filled image.
                    Added when model is in eval mode.

        Raises:
            ValueError: If ``vis_training`` is set while not in training mode.
        """
        if vis_training and not self.training:
            raise ValueError("vis_training is only applicable in training mode.")

        device = next(self.parameters()).device
        inputs = move_to_device(inputs, device)

        # Periodic visualization is only scheduled while training.
        if self.training and self.vis_period > 0:
            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                vis_training = True

        # A flat input dict is treated as a single supervised group.
        if not any(k in inputs for k in ["supervised", "unsupervised"]):
            inputs = {"supervised": inputs}

        if self.training:
            use_fully_sampled_target = self.use_fully_sampled_target
        else:
            use_fully_sampled_target = self.use_fully_sampled_target_eval

        if "supervised" in inputs:
            sup_inputs = inputs["supervised"]
            if use_fully_sampled_target:
                self._kspace_from_target(sup_inputs)
            else:
                self._target_from_kspace(sup_inputs)

        # Unsupervised examples always use their own k-space as the target source.
        if "unsupervised" in inputs:
            self._target_from_kspace(inputs["unsupervised"])

        inputs = self._merge_groups(inputs)

        output_dict = self.model(inputs, return_pp=True, vis_training=vis_training)
        return output_dict

    @classmethod
    def from_config(cls, cfg):
        """Build the constructor keyword arguments from a config.

        The wrapped model is built from a copy of ``cfg`` whose
        ``MODEL.META_ARCHITECTURE`` is replaced by
        ``MODEL.DENOISING.META_ARCHITECTURE``. The noise model is built from a
        copy whose ``MODEL.CONSISTENCY.AUG.NOISE.STD_DEV`` is replaced by
        ``MODEL.DENOISING.NOISE.STD_DEV``.

        NOTE(review): ``vis_period`` is not populated here, so it keeps its
        constructor default unless supplied another way.

        Args:
            cfg: The experiment config.

        Returns:
            Dict: Keyword arguments for ``__init__``.
        """
        device = torch.device(cfg.MODEL.DEVICE)

        model_cfg = cfg.clone()
        model_cfg.defrost()
        model_cfg.MODEL.META_ARCHITECTURE = cfg.MODEL.DENOISING.META_ARCHITECTURE
        model_cfg.freeze()
        model = build_model(model_cfg)

        noise_cfg = cfg.clone()
        noise_cfg.defrost()
        noise_cfg.MODEL.CONSISTENCY.AUG.NOISE.STD_DEV = cfg.MODEL.DENOISING.NOISE.STD_DEV
        noise_cfg.freeze()
        noiser = NoiseModel.from_cfg(noise_cfg, device=device)

        use_fully_sampled_target = cfg.MODEL.DENOISING.NOISE.USE_FULLY_SAMPLED_TARGET
        use_fully_sampled_target_eval = cfg.MODEL.DENOISING.NOISE.USE_FULLY_SAMPLED_TARGET_EVAL

        return {
            "model": model,
            "noiser": noiser,
            "use_fully_sampled_target": use_fully_sampled_target,
            "use_fully_sampled_target_eval": use_fully_sampled_target_eval,
        }