Source code for meddlr.modeling.layers.gauss

from typing import Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from meddlr.modeling.layers.build import CUSTOM_LAYERS_REGISTRY

# Names exported by a star-import of this module.
__all__ = ["gaussian", "get_gaussian_kernel", "GaussianBlur"]


def gaussian(window_size, sigma, normalize=True):
    """Build a 1D Gaussian window with ``window_size`` samples.

    The window is centered on the middle sample for odd lengths and halfway
    between the two middle samples for even lengths.

    Args:
        window_size (int): Number of samples in the window.
        sigma (float): Standard deviation of the Gaussian, in samples.
        normalize (bool, optional): If ``True``, divide the window by its sum
            so that the entries add up to 1.

    Returns:
        Tensor: 1D tensor of shape ``(window_size,)``.
    """
    half = window_size // 2
    # Even-length windows have no middle sample; shift the center by half a step.
    midpoint = half if window_size % 2 == 1 else half - 0.5
    denom = float(2 * sigma**2)

    samples = []
    for idx in range(window_size):
        exponent = -((idx - midpoint) ** 2) / denom
        samples.append(torch.exp(torch.tensor(exponent)))

    window = torch.stack(samples)
    return window / window.sum() if normalize else window


def get_gaussian_kernel(
    kernel_size: Union[int, Sequence[int]],
    sigma: Union[float, Sequence[float]],
    normalize: bool = True,
) -> torch.Tensor:
    r"""Return the coefficients of an n-dimensional Gaussian filter.

    One 1D Gaussian window is built per dimension and the windows are combined
    with an outer product, so the resulting kernel is separable.

    Args:
        kernel_size (int(s)): Filter size per dimension. Each size must be
            positive. Any non-string sequence (e.g. tuple or list) is accepted.
        sigma (float(s)): Gaussian standard deviation per dimension. Must have
            as many elements as ``kernel_size``.
        normalize (bool, optional): If ``True``, every 1D window is normalized
            to sum to 1, so ``kernel.sum() == 1`` up to rounding.

    Returns:
        Tensor: nD tensor with gaussian filter coefficients.
        Shape :math:`(\text{kernel_size})`

    Raises:
        TypeError: If ``kernel_size`` is neither an int nor a sequence, or if
            any size is not positive.
        ValueError: If ``kernel_size`` and ``sigma`` differ in length, or if
            ``kernel_size`` is empty or has more than 26 elements.

    Examples::

        >>> get_gaussian_kernel(3, 2.5)
        tensor([0.3243, 0.3513, 0.3243])

        >>> get_gaussian_kernel((3, 3), (1.5, 1.5))
        tensor([[0.0947, 0.1183, 0.0947],
                [0.1183, 0.1478, 0.1183],
                [0.0947, 0.1183, 0.0947]])
    """
    # Strings are technically sequences but never a valid size/sigma spec,
    # so treat them as (invalid) scalars.
    kernel_is_seq = isinstance(kernel_size, Sequence) and not isinstance(
        kernel_size, (str, bytes)
    )
    sigma_is_seq = isinstance(sigma, Sequence) and not isinstance(sigma, (str, bytes))
    kernel_size_seq = tuple(kernel_size) if kernel_is_seq else (kernel_size,)
    sigma_seq = tuple(sigma) if sigma_is_seq else (sigma,)

    # NOTE: the message mentions "odd", but oddness is not enforced here.
    # The text is kept unchanged for callers that match on it.
    if not (kernel_is_seq or isinstance(kernel_size, int)) or any(
        k <= 0 for k in kernel_size_seq
    ):
        raise TypeError(
            "kernel_size must be a (sequence of) odd positive integer. "
            "Got {}".format(kernel_size)
        )
    if len(kernel_size_seq) != len(sigma_seq):
        raise ValueError(
            "kernel_size and sigma must have same number of elements. "
            "Got kernel_size={}, sigma={}".format(kernel_size, sigma)
        )
    if not kernel_size_seq:
        raise ValueError("kernel_size must contain at least one element. Got {}".format(kernel_size))
    # Each dimension gets one einsum subscript drawn from 'a'..'z'.
    if len(kernel_size_seq) > 26:
        raise ValueError(
            "kernel_size supports at most 26 dimensions. Got {}".format(len(kernel_size_seq))
        )

    kernels_1d = tuple(
        gaussian(ksize, sig, normalize) for ksize, sig in zip(kernel_size_seq, sigma_seq)
    )
    # Outer product of the 1D windows, e.g. "a,b,c->abc" for three dimensions.
    elems = tuple(chr(ord("a") + i) for i in range(len(kernels_1d)))
    equation = "{}->{}".format(",".join(elems), "".join(elems))
    return torch.einsum(equation, *kernels_1d)


@CUSTOM_LAYERS_REGISTRY.register()
class GaussianBlur(nn.Module):
    r"""Creates an operator that blurs a tensor using a Gaussian filter.

    The operator smooths the given tensor with a gaussian kernel by convolving
    it to each channel. It supports batched operation.

    Arguments:
        kernel_size (int or Tuple[int, ...]): the size of the kernel.
            Dimensions are `([D, H], W)`. Every size must be odd and positive.
        sigma (float or Tuple[float, ...]): the standard deviation of the kernel.
            Dimensions are `([D, H], W)`.

    Returns:
        Tensor: the blurred tensor. Shape :math:`(B, C, [..., H], W)`

    Shape:
        - Input: :math:`(B, C, [D, H], W)`
        - Output: :math:`(B, C, [D, H], W)`

    Examples:
        >>> input = torch.rand(2, 4, 5, 5)
        >>> gauss = GaussianBlur((3, 3), (1.5, 1.5))
        >>> output = gauss(input)  # 2x4x5x5
    """

    def __init__(
        self, kernel_size: Union[int, Tuple[int, ...]], sigma: Union[float, Tuple[float, ...]]
    ):
        """Precompute the Gaussian kernel and select the matching convolution.

        Raises:
            ValueError: If any kernel size is even or not positive, or if the
                number of kernel dimensions is not 1, 2 or 3.
        """
        super().__init__()
        kernel_size = (
            (kernel_size,) if not isinstance(kernel_size, Sequence) else tuple(kernel_size)
        )
        # Odd sizes keep the output the same size as the input with symmetric padding.
        if any(k <= 0 or k % 2 == 0 for k in kernel_size):
            raise ValueError("kernel_size must be odd and positive. Got {}".format(kernel_size))
        # Only conv1d/conv2d/conv3d are available below.
        if not 1 <= len(kernel_size) <= 3:
            raise ValueError(
                "GaussianBlur supports 1 to 3 spatial dimensions. "
                "Got kernel_size={}".format(kernel_size)
            )
        sigma = (sigma,) if not isinstance(sigma, Sequence) else tuple(sigma)

        self.kernel_size: Tuple[int, ...] = kernel_size
        self.sigma: Tuple[float, ...] = sigma
        self._padding: Tuple[int, ...] = self.compute_zero_padding(kernel_size)
        # Stored as a non-trainable parameter so it moves with the module
        # across devices/dtypes.
        self.kernel: torch.Tensor = nn.Parameter(
            get_gaussian_kernel(kernel_size, sigma, normalize=True),
            requires_grad=False,
        )
        self.spatial_dim = len(kernel_size)
        self.conv = [F.conv1d, F.conv2d, F.conv3d][self.spatial_dim - 1]

    @staticmethod
    def compute_zero_padding(kernel_size: Tuple[int, ...]) -> Tuple[int, ...]:
        """Computes zero padding tuple."""
        return tuple((k - 1) // 2 for k in kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Blur ``x`` by convolving every channel with the Gaussian kernel.

        Args:
            x (Tensor): Input whose second dimension is the channel dimension.

        Returns:
            Tensor: The blurred tensor.

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
        """
        if not torch.is_tensor(x):
            raise TypeError("Input x type is not a torch.Tensor. Got {}".format(type(x)))

        c = x.shape[1]
        spatial_dim = self.spatial_dim
        tmp_kernel: torch.Tensor = self.kernel
        # One copy of the kernel per channel: shape (C, 1, *kernel_size).
        kernel: torch.Tensor = tmp_kernel.repeat([c, 1] + [1] * spatial_dim)

        # TODO: explore solution when using jit.trace since it raises a warning
        # because the shape is converted to a tensor instead to a int.
        # convolve tensor with gaussian kernel; groups=c filters each channel
        # independently.
        return self.conv(x, kernel, padding=self._padding, stride=1, groups=c)