# Source code for meddlr.metrics.image

"""Metric utilities.

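The legacy ``compute_*`` functions take a reference scan (``ref``) and an
output/reconstructed scan (``x``). Both should be either complex tensors or
real tensors with the last dimension equal to 2 (real/imaginary channels).
The metric classes (e.g. ``PSNR``) instead take ``preds`` and ``targets``.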
"""
from typing import Sequence

import numpy as np
import scipy as scp
import torch
from skimage.metrics import structural_similarity

import meddlr.metrics.functional as mF
from meddlr.metrics.metric import Metric
from meddlr.ops import complex as cplx
from meddlr.utils.deprecated import deprecated

# Lookup table from an ``im_type`` string (aliases included) to the
# ``meddlr.ops.complex`` function that extracts that representation from a
# complex-valued tensor.
_IM_TYPES_TO_FUNCS = {
    **dict.fromkeys(("magnitude", "mag", "abs"), cplx.abs),
    **dict.fromkeys(("phase", "angle"), cplx.angle),
    "real": cplx.real,
    "imag": cplx.imag,
}

__all__ = ["PSNR", "MAE", "MSE", "NRMSE", "RMSE", "SSIM"]


class PSNR(Metric):
    """Peak signal-to-noise ratio (PSNR), with support for complex inputs.

    :math:`PSNR = 20 * log_{10}(\\frac{max(|x_{gt}|)}{||x_{pred} - x_{gt}||_2})`

    Complex tensors are handled according to ``im_type``:

        - ``'magnitude'``: :math:`x_{pred}` and :math:`x_{gt}` are converted to
          magnitude images.
        - ``'phase'``: :math:`x_{pred}` and :math:`x_{gt}` are converted to
          phase images.
        - ``'real'``: Real components of :math:`x_{pred}` and :math:`x_{gt}`
          are used.
        - ``'imag'``: Imaginary components of :math:`x_{pred}` and
          :math:`x_{gt}` are used.

    Attributes:
        im_type (str): Which representation of a complex image the metric is
            computed on. Ignored for real-valued tensors.
        channel_names (Sequence[str]): The names of the channels in the input.
    """

    is_differentiable = True
    higher_is_better = True

    def __init__(
        self,
        im_type: str = None,
        channel_names: Sequence[str] = None,
        reduction="none",
        compute_on_step: bool = False,
        dist_sync_on_step: bool = False,
        process_group: bool = None,
        dist_sync_fn: bool = None,
    ):
        # Values are reported in decibels.
        super().__init__(
            channel_names=channel_names,
            units="dB",
            reduction=reduction,
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )
        self.im_type = im_type

    def func(self, preds, targets) -> torch.Tensor:
        """Delegate to ``meddlr.metrics.functional.psnr`` with ``self.im_type``."""
        return mF.psnr(preds, targets, im_type=self.im_type)
class MAE(Metric):
    """Mean absolute error with complex-valued support.

    :math:`MAE = \\frac{1}{N} \\sum_{i=1}^{N} |x_{pred} - x_{gt}|`.

    This implementation supports complex tensors. ``im_type`` controls how the
    complex tensor should be processed:

        - ``'magnitude'``: :math:`x_{pred}` and :math:`x_{gt}` are converted to
          magnitude images.
        - ``'phase'``: :math:`x_{pred}` and :math:`x_{gt}` are converted to
          phase images.
        - ``'real'``: Real components of :math:`x_{pred}` and :math:`x_{gt}`
          are used.
        - ``'imag'``: Imaginary components of :math:`x_{pred}` and
          :math:`x_{gt}` are used.

    Attributes:
        im_type (str): The type of the complex image to compute the metric on.
            This only applies to complex tensors.
        channel_names (Sequence[str]): The names of the channels in the input.
    """

    # NOTE: ``\\sum`` above is escaped on purpose; a bare ``\s`` in a non-raw
    # string is an invalid escape sequence (SyntaxWarning on Python 3.12+).

    is_differentiable = True
    higher_is_better = False

    def __init__(
        self,
        im_type: str = None,
        channel_names: Sequence[str] = None,
        reduction="none",
        compute_on_step: bool = False,
        dist_sync_on_step: bool = False,
        process_group: bool = None,
        dist_sync_fn: bool = None,
    ):
        # MAE is unitless here (``units=""``).
        super().__init__(
            channel_names=channel_names,
            units="",
            reduction=reduction,
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )
        self.im_type = im_type

    def func(self, preds, targets) -> torch.Tensor:
        """Delegate to ``meddlr.metrics.functional.mae`` with ``self.im_type``."""
        return mF.mae(preds, targets, im_type=self.im_type)
class MSE(Metric):
    """Mean squared error (MSE), with support for complex inputs.

    :math:`MSE = ||x_{pred} - x_{gt}||_2^2`.

    Complex tensors are handled according to ``im_type``:

        - ``'magnitude'``: :math:`x_{pred}` and :math:`x_{gt}` are converted to
          magnitude images.
        - ``'phase'``: :math:`x_{pred}` and :math:`x_{gt}` are converted to
          phase images.
        - ``'real'``: Real components of :math:`x_{pred}` and :math:`x_{gt}`
          are used.
        - ``'imag'``: Imaginary components of :math:`x_{pred}` and
          :math:`x_{gt}` are used.

    Attributes:
        im_type (str): Which representation of a complex image the metric is
            computed on. Ignored for real-valued tensors.
        channel_names (Sequence[str]): The names of the channels in the input.
    """

    is_differentiable = True
    higher_is_better = False

    def __init__(
        self,
        im_type: str = None,
        channel_names: Sequence[str] = None,
        reduction="none",
        compute_on_step: bool = False,
        dist_sync_on_step: bool = False,
        process_group: bool = None,
        dist_sync_fn: bool = None,
    ):
        # No physical unit is attached to this metric.
        super().__init__(
            channel_names=channel_names,
            units="",
            reduction=reduction,
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )
        self.im_type = im_type

    def func(self, preds, targets) -> torch.Tensor:
        """Delegate to ``meddlr.metrics.functional.mse`` with ``self.im_type``."""
        return mF.mse(preds, targets, im_type=self.im_type)
class NRMSE(Metric):
    """Normalized root-mean-squared error (NRMSE), with support for complex inputs.

    :math:`NRMSE = \\frac{||x_{pred} - x_{gt}||_2}{||x_{gt}||_2}`.

    Complex tensors are handled according to ``im_type``:

        - ``'magnitude'``: :math:`x_{pred}` and :math:`x_{gt}` are converted to
          magnitude images.
        - ``'phase'``: :math:`x_{pred}` and :math:`x_{gt}` are converted to
          phase images.
        - ``'real'``: Real components of :math:`x_{pred}` and :math:`x_{gt}`
          are used.
        - ``'imag'``: Imaginary components of :math:`x_{pred}` and
          :math:`x_{gt}` are used.

    Attributes:
        im_type (str): Which representation of a complex image the metric is
            computed on. Ignored for real-valued tensors.
        channel_names (Sequence[str]): The names of the channels in the input.
    """

    is_differentiable = True
    higher_is_better = False

    def __init__(
        self,
        im_type: str = None,
        channel_names: Sequence[str] = None,
        reduction="none",
        compute_on_step: bool = False,
        dist_sync_on_step: bool = False,
        process_group: bool = None,
        dist_sync_fn: bool = None,
    ):
        # NRMSE is a ratio, so it carries no unit.
        super().__init__(
            channel_names=channel_names,
            units="",
            reduction=reduction,
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )
        self.im_type = im_type

    def func(self, preds, targets) -> torch.Tensor:
        """Delegate to ``meddlr.metrics.functional.nrmse`` with ``self.im_type``."""
        return mF.nrmse(preds, targets, im_type=self.im_type)


# Lower-case alias kept for backwards compatibility.
nRMSE = NRMSE
class RMSE(Metric):
    """Root-mean-squared error (RMSE), with support for complex inputs.

    :math:`RMSE = ||x_{pred} - x_{gt}||_2`.

    Complex tensors are handled according to ``im_type``:

        - ``'magnitude'``: :math:`x_{pred}` and :math:`x_{gt}` are converted to
          magnitude images.
        - ``'phase'``: :math:`x_{pred}` and :math:`x_{gt}` are converted to
          phase images.
        - ``'real'``: Real components of :math:`x_{pred}` and :math:`x_{gt}`
          are used.
        - ``'imag'``: Imaginary components of :math:`x_{pred}` and
          :math:`x_{gt}` are used.

    Attributes:
        im_type (str): Which representation of a complex image the metric is
            computed on. Ignored for real-valued tensors.
        channel_names (Sequence[str]): The names of the channels in the input.
    """

    is_differentiable = True
    higher_is_better = False

    def __init__(
        self,
        im_type: str = None,
        channel_names: Sequence[str] = None,
        reduction="none",
        compute_on_step: bool = False,
        dist_sync_on_step: bool = False,
        process_group: bool = None,
        dist_sync_fn: bool = None,
    ):
        # No physical unit is attached to this metric.
        super().__init__(
            channel_names=channel_names,
            units="",
            reduction=reduction,
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )
        self.im_type = im_type

    def func(self, preds, targets) -> torch.Tensor:
        """Delegate to ``meddlr.metrics.functional.rmse`` with ``self.im_type``."""
        return mF.rmse(preds, targets, im_type=self.im_type)
class SSIM(Metric):
    """Structural similarity index measure with complex-valued support.

    This implementation of SSIM supports complex tensors. ``im_type`` controls
    how the complex tensor should be processed:

        - ``'magnitude'``: :math:`x_{pred}` and :math:`x_{gt}` are converted to
          magnitude images.
        - ``'phase'``: :math:`x_{pred}` and :math:`x_{gt}` are converted to
          phase images.
        - ``'real'``: Real components of :math:`x_{pred}` and :math:`x_{gt}`
          are used.
        - ``'imag'``: Imaginary components of :math:`x_{pred}` and
          :math:`x_{gt}` are used.

    Attributes:
        method (str): The method to use for computing the SSIM.
            Defaults to ``'wang'``.
        im_type (str): The type of the complex image to compute the metric on.
            This only applies to complex tensors. Defaults to ``'magnitude'``.
        channel_names (Sequence[str]): The names of the channels in the input.
    """

    is_differentiable = True
    higher_is_better = True

    def __init__(
        self,
        method: str = "wang",
        im_type: str = "magnitude",
        channel_names: Sequence[str] = None,
        reduction="none",
        compute_on_step: bool = False,
        dist_sync_on_step: bool = False,
        process_group: bool = None,
        dist_sync_fn: bool = None,
    ):
        # SSIM is a unitless similarity index.
        super().__init__(
            channel_names=channel_names,
            units="",
            reduction=reduction,
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )
        self.method = method
        self.im_type = im_type

    def func(self, preds, targets) -> torch.Tensor:
        """Delegate to ``meddlr.metrics.functional.ssim`` with this metric's
        ``method`` and ``im_type``."""
        return mF.ssim(
            preds,
            targets,
            method=self.method,
            im_type=self.im_type,
        )
@deprecated(vremove="0.1.0", replacement="metrics.functional.mse")
def compute_mse(ref: torch.Tensor, x: torch.Tensor, is_batch=False, magnitude=False):
    """Compute the mean squared error between ``x`` and ``ref``.

    Deprecated: use ``meddlr.metrics.functional.mse`` instead.

    Args:
        ref (torch.Tensor): The target. Either a complex tensor or a real
            tensor of shape ``(..., 2)`` (real/imaginary channels).
        x (torch.Tensor): The prediction. Same shape/format as ``ref``.
        is_batch (bool, optional): If ``True``, the first dimension is treated
            as the batch dimension and one value is returned per example.
        magnitude (bool, optional): If ``True``, the error is computed between
            magnitude images (``|x| - |ref|``) rather than on the complex
            difference (``|x - ref|``).

    Returns:
        torch.Tensor: 0-dim tensor, or a tensor of shape ``(B,)`` if
        ``is_batch``.
    """
    if cplx.is_complex(ref):
        ref = torch.view_as_real(ref)
        x = torch.view_as_real(x)
    assert ref.shape[-1] == 2
    assert x.shape[-1] == 2

    if magnitude:
        squared_err = torch.abs(cplx.abs(x) - cplx.abs(ref)) ** 2
    else:
        squared_err = cplx.abs(x - ref) ** 2

    shape = (x.shape[0], -1) if is_batch else -1
    # ``reshape`` instead of ``view``: elementwise ops preserve input strides,
    # so ``squared_err`` can be non-contiguous (e.g. for permuted inputs),
    # which makes ``view`` raise.
    return torch.mean(squared_err.reshape(shape), dim=-1)


def compute_l2(ref: torch.Tensor, x: torch.Tensor, is_batch=False, magnitude=False):
    """Compute the root-mean-squared error, i.e. ``sqrt(compute_mse(...))``.

    Args:
        ref (torch.Tensor): The target. Either a complex tensor or a real
            tensor of shape ``(..., 2)``.
        x (torch.Tensor): The prediction. Same shape/format as ``ref``.
        is_batch (bool, optional): See :func:`compute_mse`.
        magnitude (bool, optional): See :func:`compute_mse`.

    Returns:
        torch.Tensor: 0-dim tensor, or shape ``(B,)`` if ``is_batch``.
    """
    if cplx.is_complex(ref):
        ref = torch.view_as_real(ref)
        x = torch.view_as_real(x)
    assert ref.shape[-1] == 2
    assert x.shape[-1] == 2

    return torch.sqrt(compute_mse(ref, x, is_batch=is_batch, magnitude=magnitude))


@deprecated(vremove="0.1.0", replacement="metrics.functional.psnr")
def compute_psnr(ref: torch.Tensor, x: torch.Tensor, is_batch=False, magnitude=False):
    """Compute peak signal-to-noise ratio.

    ``PSNR = 20 * log10(max(|ref|) / RMSE(ref, x))``.

    Deprecated: use ``meddlr.metrics.functional.psnr`` instead.

    Args:
        ref (torch.Tensor): The target. Either a complex tensor or a real
            tensor of shape ``(..., 2)``.
        x (torch.Tensor): The prediction. Same shape/format as ``ref``.
        is_batch (bool, optional): Must be ``False``; batched computation is
            not supported.
        magnitude (bool, optional): See :func:`compute_mse`.

    Returns:
        torch.Tensor: Scalar in dB.
    """
    if cplx.is_complex(ref):
        ref = torch.view_as_real(ref)
        x = torch.view_as_real(x)
    assert ref.shape[-1] == 2
    assert x.shape[-1] == 2
    assert not is_batch, "is_batch not supported"

    l2 = compute_l2(ref, x, magnitude=magnitude, is_batch=False)
    return 20 * torch.log10(cplx.abs(ref).max() / l2)


@deprecated(vremove="0.1.0", replacement="metrics.functional.nrmse")
def compute_nrmse(ref, x, is_batch=False, magnitude=False):
    """Compute normalized root mean square error.

    The RMS of the reference (``sqrt(mean(|ref|^2))``) is used to normalize
    the metric.

    Deprecated: use ``meddlr.metrics.functional.nrmse`` instead.

    Args:
        ref (torch.Tensor): The target. Either a complex tensor or a real
            tensor of shape ``(..., 2)``.
        x (torch.Tensor): The prediction. Same shape/format as ``ref``.
        is_batch (bool, optional): See :func:`compute_mse`.
        magnitude (bool, optional): See :func:`compute_mse`.

    Returns:
        torch.Tensor: 0-dim tensor, or shape ``(B,)`` if ``is_batch``.
    """
    if cplx.is_complex(ref):
        ref = torch.view_as_real(ref)
        x = torch.view_as_real(x)
    assert ref.shape[-1] == 2
    assert x.shape[-1] == 2

    rmse = compute_l2(ref, x, is_batch=is_batch, magnitude=magnitude)
    shape = (x.shape[0], -1) if is_batch else -1
    # ``reshape`` tolerates non-contiguous intermediates where ``view`` raises.
    norm = torch.sqrt(torch.mean((cplx.abs(ref) ** 2).reshape(shape), dim=-1))
    return rmse / norm


@deprecated(vremove="0.1.0", replacement="metrics.functional.ssim")
def compute_ssim(
    ref: torch.Tensor,
    x: torch.Tensor,
    multichannel: bool = False,
    data_range=None,
    **kwargs,
):
    """Compute structural similarity index metric.

    Does not preserve autograd. Based on implementation of Wang et. al. [1]_

    Unless ``multichannel`` is set, both images are first converted to
    magnitude images before the metric is computed.

    Deprecated: use ``meddlr.metrics.functional.ssim`` instead.

    Args:
        ref (torch.Tensor): The target. Either a complex tensor or a real
            tensor of shape ``(..., 2)``.
        x (torch.Tensor): The prediction. Same shape/format as ``ref``.
        multichannel (bool, optional): If `True`, computes ssim for real and
            imaginary channels separately and then averages the two.
        data_range (float | str, optional): The data range of the input image
            (distance between minimum and maximum possible values). May also be
            one of ``'range'``/``'ref-range'`` (``ref.max() - ref.min()``),
            ``'maxval'``/``'ref-maxval'`` (``ref.max()``), ``'x-range'``
            (``x.max() - x.min()``) or ``'x-maxval'`` (``x.max()``). If
            ``None``, it is left to ``structural_similarity`` to decide.
        **kwargs: Forwarded to ``skimage.metrics.structural_similarity``.
            ``gaussian_weights`` (default ``True``), ``sigma`` (default ``1.5``)
            and ``use_sample_covariance`` (default ``False``) are overridden
            here.

    References:
        .. [1] Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P.
           (2004). Image quality assessment: From error visibility to
           structural similarity. IEEE Transactions on Image Processing,
           13, 600-612.
           https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
           :DOI:`10.1109/TIP.2003.819861`
    """
    gaussian_weights = kwargs.pop("gaussian_weights", True)
    sigma = kwargs.pop("sigma", 1.5)
    use_sample_covariance = kwargs.pop("use_sample_covariance", False)

    if cplx.is_complex(ref):
        ref = torch.view_as_real(ref)
        x = torch.view_as_real(x)
    assert ref.shape[-1] == 2
    assert x.shape[-1] == 2

    if not multichannel:
        ref = cplx.abs(ref)
        x = cplx.abs(x)

    # ``Tensor.numpy()`` raises for tensors that require grad or are not on the
    # CPU, so detach and move first (autograd is not preserved anyway).
    x = x.detach().cpu().contiguous().squeeze().numpy()
    ref = ref.detach().cpu().contiguous().squeeze().numpy()

    if data_range in ("range", "ref-range"):
        data_range = ref.max() - ref.min()
    elif data_range in ("ref-maxval", "maxval"):
        data_range = ref.max()
    elif data_range == "x-range":
        data_range = x.max() - x.min()
    elif data_range == "x-maxval":
        data_range = x.max()

    # NOTE(review): newer scikit-image releases deprecate ``multichannel`` in
    # favour of ``channel_axis``; confirm against the pinned skimage version.
    return structural_similarity(
        ref,
        x,
        data_range=data_range,
        gaussian_weights=gaussian_weights,
        sigma=sigma,
        use_sample_covariance=use_sample_covariance,
        multichannel=multichannel,
        **kwargs,
    )


def compute_vifp_mscale(
    ref: torch.Tensor,
    x: torch.Tensor,
    sigma_nsq: float = 2.0,
    eps: float = 1e-10,
    im_type: str = None,
):  # pragma: no cover
    """Compute visual information fidelity (VIF) metric.

    This code is adapted from
    https://github.com/aizvorski/video-quality/blob/master/vifp.py

    Args:
        ref (torch.Tensor): The reference image. This can be complex.
        x (torch.Tensor): The image to evaluate (e.g. the reconstruction).
            This can be complex.
        sigma_nsq (float, optional): The visual noise parameter.
            This may need to be fine-tuned over the dataset of interest.
        eps (float, optional): The threshold below which data is considered
            to be 0.
        im_type (str, optional): The image type to compute the metric on for
            complex inputs. Any key of ``_IM_TYPES_TO_FUNCS`` (e.g.
            ``'magnitude'``, ``'phase'``, ``'real'``, ``'imag'``). ``None``
            (default) means ``'magnitude'``.

    Returns:
        float: The metric value.

    Note:
        ``im_type`` is only used if the input is complex. The input tensors
        are not modified.
    """
    from scipy.ndimage import gaussian_filter

    # ``None`` previously fell through to ``_IM_TYPES_TO_FUNCS[None]`` and
    # raised KeyError, despite magnitude being the documented default.
    if im_type is None:
        im_type = "magnitude"

    if cplx.is_complex(ref) or cplx.is_complex_as_real(ref):
        ref = _IM_TYPES_TO_FUNCS[im_type](ref)
    if cplx.is_complex(x) or cplx.is_complex_as_real(x):
        x = _IM_TYPES_TO_FUNCS[im_type](x)

    # ``Tensor.numpy()`` shares memory with the tensor. Copy so the in-place
    # scaling below cannot mutate the caller's (real-valued) input tensors.
    ref = np.squeeze(ref.detach().cpu().numpy()).copy()
    x = np.squeeze(x.detach().cpu().numpy()).copy()

    # Rescale both images so the reference peaks at 255.
    scale_val = 255.0 / ref.max()
    ref *= scale_val
    x *= scale_val

    num = 0.0
    den = 0.0
    for scale in range(1, 5):
        N = 2 ** (4 - scale + 1) + 1
        sd = N / 5.0

        if scale > 1:
            # Blur then downsample by 2 along every non-singleton dimension.
            sl = tuple(slice(None, None, 2) if dim > 1 else slice(None) for dim in ref.shape)
            ref = gaussian_filter(ref, sd)
            x = gaussian_filter(x, sd)
            ref = ref[sl]
            x = x[sl]

        mu1 = gaussian_filter(ref, sd)
        mu2 = gaussian_filter(x, sd)
        mu1_sq = mu1 * mu1
        mu2_sq = mu2 * mu2
        mu1_mu2 = mu1 * mu2
        sigma1_sq = gaussian_filter(ref * ref, sd) - mu1_sq
        sigma2_sq = gaussian_filter(x * x, sd) - mu2_sq
        sigma12 = gaussian_filter(ref * x, sd) - mu1_mu2

        sigma1_sq[sigma1_sq < 0] = 0
        sigma2_sq[sigma2_sq < 0] = 0

        g = sigma12 / (sigma1_sq + eps)
        sv_sq = sigma2_sq - g * sigma12

        g[sigma1_sq < eps] = 0
        sv_sq[sigma1_sq < eps] = sigma2_sq[sigma1_sq < eps]
        sigma1_sq[sigma1_sq < eps] = 0

        g[sigma2_sq < eps] = 0
        sv_sq[sigma2_sq < eps] = 0

        sv_sq[g < 0] = sigma2_sq[g < 0]
        g[g < 0] = 0
        sv_sq[sv_sq <= eps] = eps

        num += np.sum(np.log10(1 + g * g * sigma1_sq / (sv_sq + sigma_nsq)))
        den += np.sum(np.log10(1 + sigma1_sq / sigma_nsq))

    vifp = num / den
    return vifp


# Alias: ``compute_l2`` returns the root-mean-squared error.
compute_rmse = compute_l2