Source code for meddlr.metrics.sem_seg

from typing import Sequence

import torch

import meddlr.metrics.functional as mF
from meddlr.metrics.metric import Metric

__all__ = ["DSC", "CV", "VOE", "ASSD"]


[docs]class DSC(Metric): """Dice score coefficient. Attributes: channel_names (Sequence[str]): Category names corresponding to the channels. """ is_differentiable = True higher_is_better = True
[docs] def __init__( self, channel_names: Sequence[str] = None, reduction="none", compute_on_step: bool = False, dist_sync_on_step: bool = False, process_group: bool = None, dist_sync_fn: bool = None, ): super().__init__( channel_names=channel_names, units="", reduction=reduction, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, dist_sync_fn=dist_sync_fn, )
def func(self, preds, targets) -> torch.Tensor: return mF.dice_score(y_pred=preds, y_true=targets)
Dice = DSC
[docs]class CV(Metric): """Coefficient of variation. Attributes: channel_names (Sequence[str]): Category names corresponding to the channels. """ is_differentiable = True higher_is_better = False
[docs] def __init__( self, channel_names: Sequence[str] = None, reduction="none", compute_on_step: bool = False, dist_sync_on_step: bool = False, process_group: bool = None, dist_sync_fn: bool = None, ): super().__init__( channel_names=channel_names, units="", reduction=reduction, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, dist_sync_fn=dist_sync_fn, )
def func(self, preds, targets) -> torch.Tensor: return mF.coefficient_variation(y_pred=preds, y_true=targets)
[docs]class VOE(Metric): """Volumetric overlap error. Attributes: channel_names (Sequence[str]): Category names corresponding to the channels. """ is_differentiable = True higher_is_better = False
[docs] def __init__( self, channel_names: Sequence[str] = None, reduction="none", compute_on_step: bool = False, dist_sync_on_step: bool = False, process_group: bool = None, dist_sync_fn: bool = None, ): super().__init__( channel_names=channel_names, units="", reduction=reduction, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, dist_sync_fn=dist_sync_fn, )
def func(self, preds, targets) -> torch.Tensor: return mF.volumetric_overlap_error(y_pred=preds, y_true=targets)
[docs]class ASSD(Metric): """Average symmetric surface distance. Attributes: connectivity (int): The neighbourhood/connectivity considered when determining the surface of the binary objects. channel_names (Sequence[str]): Category names corresponding to the channels. Note: This metric is not differentiable. """ is_differentiable = False higher_is_better = False
[docs] def __init__( self, connectivity: int = 1, channel_names: Sequence[str] = None, reduction="none", compute_on_step: bool = False, dist_sync_on_step: bool = False, process_group: bool = None, dist_sync_fn: bool = None, ): """ Args: connectivity (int): The neighbourhood/connectivity considered when determining the surface of the binary objects. If in doubt, leave it as it is. channel_names (Sequence[str]): Category names corresponding to the channels. """ super().__init__( channel_names=channel_names, units="", reduction=reduction, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, dist_sync_fn=dist_sync_fn, ) self.connectivity = connectivity
def func(self, preds, targets, spacing=None) -> torch.Tensor: return mF.assd( y_pred=preds, y_true=targets, spacing=spacing, connectivity=self.connectivity ) def update(self, preds, targets, spacing=None, ids=None): return super().update(preds, targets, spacing=spacing, ids=ids)