Source code for meddlr.ops.categorical

import numpy as np
import torch
import torch.nn.functional as F

__all__ = [
    "one_hot_to_categorical",
    "categorical_to_one_hot",
    "logits_to_prob",
    "pred_to_categorical",
]


[docs]def one_hot_to_categorical(x, channel_dim: int = 1, background=False): """Converts one-hot encoded predictions to categorical predictions. Args: x (torch.Tensor | np.ndarray): One-hot encoded predictions. channel_dim (int, optional): Channel dimension. Defaults to ``1`` (i.e. ``(B,C,...)``). background (bool, optional): If ``True``, assumes index 0 in the channel dimension is the background. Returns: torch.Tensor | np.ndarray: Categorical array or tensor. If ``background=False``, the output will be 1-indexed such that ``0`` corresponds to the background. """ is_ndarray = isinstance(x, np.ndarray) if is_ndarray: x = torch.as_tensor(x) if background is not None and background is not False: out = torch.argmax(x, channel_dim) else: out = torch.argmax(x.type(torch.long), dim=channel_dim) + 1 out = torch.where(x.sum(channel_dim) == 0, torch.tensor([0], device=x.device), out) if is_ndarray: out = out.numpy() return out
[docs]def categorical_to_one_hot(x, channel_dim: int = 1, background=0, num_categories=None, dtype=None): """Converts categorical predictions to one-hot encoded predictions. Args: x (torch.Tensor | np.ndarray): Categorical array or tensor. channel_dim (int, optional): Channel dimension for output tensor. background (int | NoneType, optional): The numerical label of the background category. If ``None``, assumes that the background is a class that should be one-hot encoded. num_categories (int, optional): Number of categories (excluding background). Defaults to the ``max(x) + 1``. dtype (type, optional): Data type of the output. Defaults to boolean (``torch.bool`` or ``np.bool``). Returns: torch.Tensor | np.ndarray: One-hot encoded predictions. """ is_ndarray = isinstance(x, np.ndarray) if is_ndarray: x = torch.from_numpy(x) if num_categories is None: num_categories = torch.max(x).type(torch.long).cpu().item() num_categories += 1 shape = x.shape out_shape = (num_categories,) + shape if dtype is None: dtype = torch.bool default_value = True if dtype == torch.bool else 1 if x.dtype != torch.long: x = x.type(torch.long) out = torch.zeros(out_shape, dtype=dtype, device=x.device) out.scatter_(0, x.reshape((1,) + x.shape), default_value) if background is not None: out = torch.cat([out[0:background], out[background + 1 :]], dim=0) if channel_dim != 0: if channel_dim < 0: channel_dim = out.ndim + channel_dim order = (channel_dim,) + tuple(d for d in range(out.ndim) if d != channel_dim) out = out.permute(tuple(np.argsort(order))) out = out.contiguous() if is_ndarray: out = out.numpy() return out
[docs]def logits_to_prob(logits, activation, channel_dim: int = 1): """Converts logits to probabilities. Args: logits (torch.Tensor | np.ndarray): The logits. activation (str): Activation to use. One of ``'sigmoid'``, ``'softmax'``. channel_dim (int, optional): The channel dimension. Returns: torch.Tensor | np.ndarray: The probabilities. """ is_ndarray = isinstance(logits, np.ndarray) if is_ndarray: logits = torch.from_numpy(logits) if activation == "sigmoid": out = torch.sigmoid(logits) elif activation == "softmax": out = F.softmax(logits, dim=channel_dim) else: raise ValueError(f"activation '{activation}' not supported'") if is_ndarray: out = out.numpy() return out
[docs]def pred_to_categorical(pred_or_logits, activation, channel_dim: int = 1, threshold: float = 0.5): """Converts one-hot encoded predictions or logits to category. Args: pred_or_logits: One-hot encoded predictions or logits. Shape BxCx... activation (str): Activation to use. Either ``'sigmoid'`` or ``'softmax'`` if ``pred_or_logits`` are logits. If `None` or '', assumes that `pred` does not need to be passed through activation function. If 'softmax', should include a background class. include_background (bool): If `True`, the first slice of class dimension (``C``) will not be dropped. """ if activation not in [None, "", "sigmoid", "softmax"]: raise ValueError(f"activation '{activation}' not supported'") pred = pred_or_logits is_ndarray = isinstance(pred, np.ndarray) if is_ndarray: pred = torch.from_numpy(pred) if activation == "sigmoid": # TODO: Validate this case. out = one_hot_to_categorical(torch.sigmoid(pred) > threshold, channel_dim=channel_dim) elif activation == "softmax": out = torch.argmax(pred, dim=channel_dim) elif activation in (None, ""): # if not activation specified, assume it is one-hot encoded. out = one_hot_to_categorical(pred, channel_dim=channel_dim) else: raise ValueError(f"activation '{activation}' not supported'") if is_ndarray: out = out.numpy() return out