Source code for meddlr.ops.fft

from typing import Sequence

import numpy as np
import torch

from meddlr.ops import complex as cplx
from meddlr.ops.utils import roll
from meddlr.utils import env

if env.pt_version() >= [1, 6]:
    import torch.fft

# Public API of this module.
__all__ = [
    # n-dimensional centered transforms
    "fftnc", "ifftnc",
    # fixed-rank centered transforms (1D, 2D, 3D)
    "fftc", "ifftc", "fft2c", "ifft2c", "fft3c", "ifft3c",
    # tensor versions of np.fft.fftshift / np.fft.ifftshift
    "fftshift", "ifftshift",
]


def fftnc(input: torch.Tensor, dim=None, norm="ortho", is_real: bool = None) -> torch.Tensor:
    """Apply nD centered fast fourier transform.

    This function is a backwards-compatible wrapper for centered
    :meth:`torch.fft.fftn`. It supports backwards compatibility with
    ``torch.fft`` as implemented in torch<1.7.

    Args:
        input (torch.Tensor): A tensor (typically complex).
        dim (int | Tuple[int], optional): Dimension(s) to be transformed.
            If ``None``, defaults to all dimensions.
        norm (str | bool, optional): The normalization method. Defaults to ``'ortho'``.
            ``True`` is an alias for ``'ortho'``. For torch<1.7, only ``'ortho'``
            (or ``True``) and ``False`` are supported.
        is_real (bool, optional): If ``True``, ``input`` is a real-valued tensor.
            If ``None`` or ``False`` and ``input.shape[-1] == 2``, ``input`` is a
            real-view of a complex tensor.

    Returns:
        torch.Tensor: The centered nD FFT of ``input``.

    Note:
        Real-valued tensors are not supported with ``torch<1.7``.
    """
    return _fft_template(input, kind="fft", dim=dim, norm=norm, is_real=is_real)
def ifftnc(input: torch.Tensor, dim=None, norm="ortho", is_real: bool = None) -> torch.Tensor:
    """Centered n-dimensional inverse fast fourier transform.

    Backwards-compatible wrapper around a centered :meth:`torch.fft.ifftn` that
    also works with the ``torch.fft`` API of torch<1.7.

    Args:
        input (torch.Tensor): A tensor (typically complex).
        dim (int | Tuple[int], optional): Dimension(s) to be transformed.
            If ``None``, defaults to all dimensions.
        norm (str | bool, optional): Normalization mode. Defaults to ``'ortho'``.
            ``True`` is an alias for ``'ortho'``. For torch<1.7, only ``'ortho'``
            (or ``True``) and ``False`` are supported.
        is_real (bool, optional): If ``True``, ``input`` is a real-valued tensor.
            If ``None`` or ``False`` and ``input.shape[-1] == 2``, ``input`` is a
            real-view of a complex tensor.

    Returns:
        torch.Tensor: The centered nD inverse FFT of ``input``.

    Note:
        Real-valued tensors are not supported with ``torch<1.7``.
    """
    options = dict(dim=dim, norm=norm, is_real=is_real)
    return _fft_template(input, kind="ifft", **options)
def fftc(input: torch.Tensor, norm: str = "ortho", is_real: bool = None, channels_last=False):
    """Centered one-dimensional Fast Fourier Transform (FFT).

    Args:
        input (torch.Tensor): A tensor.
        norm (str | bool, optional): Normalization mode. Defaults to ``'ortho'``.
            For torch<1.7, only ``'ortho'`` (or ``True``) and ``False`` are supported.
        is_real (bool, optional): If ``True``, ``input`` is a real-valued tensor.
            If ``None`` or ``False`` and ``input.shape[-1] == 2``, ``input`` is a
            real-view of a complex tensor.
        channels_last (bool, optional): If ``True``, transform the first non-batch
            dimension. If ``False``, transform the last dimension (excluding the
            trailing real/imaginary channel of a real view).

    Returns:
        torch.Tensor: The 1D centered FFT of the input.
    """
    # NOTE: positional order (norm, is_real, channels_last) differs from the other
    # fixed-rank helpers; kept as-is so existing positional callers keep working.
    return fftnc(
        input,
        dim=_get_fft_dims(input, 1, is_real=is_real, channels_last=channels_last),
        norm=norm,
        is_real=is_real,
    )
def ifftc(input: torch.Tensor, channels_last=False, norm: str = "ortho", is_real: bool = None):
    """Apply 1D centered Inverse Fast Fourier Transform (IFFT).

    Args:
        input (torch.Tensor): A tensor.
        channels_last (bool, optional): If ``True``, apply to the first non-batch
            dimension. If ``False``, apply to the last dimension (excluding the
            trailing real/imaginary channel of a real view).
        norm (str | bool, optional): The normalization method. Defaults to ``'ortho'``.
            For torch<1.7, only ``'ortho'`` (or ``True``) and ``False`` are supported.
        is_real (bool, optional): If ``True``, ``input`` is a real-valued tensor.
            If ``None`` or ``False`` and ``input.shape[-1] == 2``, ``input`` is a
            real-view of a complex tensor.

    Returns:
        torch.Tensor: The 1D centered IFFT of the input.
    """
    dim = _get_fft_dims(input, 1, is_real=is_real, channels_last=channels_last)
    return ifftnc(input, dim=dim, norm=norm, is_real=is_real)
def fft2c(input: torch.Tensor, channels_last=False, norm: str = "ortho", is_real: bool = None):
    """Apply centered 2-dimensional Fast Fourier Transform.

    Args:
        input (torch.Tensor): A tensor (typically complex, or a real-view of a
            complex tensor with a trailing real/imaginary channel of size 2).
        channels_last (bool, optional): If ``True``, apply to the first two
            non-batch dimensions. If ``False``, apply to the last two dimensions
            (excluding the trailing real/imaginary channel of a real view).
        norm (str | bool, optional): The normalization method. Defaults to ``'ortho'``.
            For torch<1.7, only ``'ortho'`` (or ``True``) and ``False`` are supported.
        is_real (bool, optional): If ``True``, ``input`` is a real-valued tensor.
            If ``None`` or ``False`` and ``input.shape[-1] == 2``, ``input`` is a
            real-view of a complex tensor.

    Returns:
        torch.Tensor: The 2D centered FFT of the input.
    """
    dim = _get_fft_dims(input, 2, is_real=is_real, channels_last=channels_last)
    return fftnc(input, dim=dim, norm=norm, is_real=is_real)
def ifft2c(input: torch.Tensor, channels_last=False, norm: str = "ortho", is_real: bool = None):
    """Apply centered 2-dimensional Inverse Fast Fourier Transform.

    Args:
        input (torch.Tensor): A tensor (typically complex, or a real-view of a
            complex tensor with a trailing real/imaginary channel of size 2).
        channels_last (bool, optional): If ``True``, apply to the first two
            non-batch dimensions. If ``False``, apply to the last two dimensions
            (excluding the trailing real/imaginary channel of a real view).
        norm (str | bool, optional): The normalization method. Defaults to ``'ortho'``.
            For torch<1.7, only ``'ortho'`` (or ``True``) and ``False`` are supported.
        is_real (bool, optional): If ``True``, ``input`` is a real-valued tensor.
            If ``None`` or ``False`` and ``input.shape[-1] == 2``, ``input`` is a
            real-view of a complex tensor.

    Returns:
        torch.Tensor: The 2D centered IFFT of the input.
    """
    dim = _get_fft_dims(input, 2, is_real=is_real, channels_last=channels_last)
    return ifftnc(input, dim=dim, norm=norm, is_real=is_real)
def fft3c(input: torch.Tensor, channels_last=False, norm: str = "ortho", is_real: bool = None):
    """Apply centered 3-dimensional Fast Fourier Transform.

    Args:
        input (torch.Tensor): A tensor (typically complex, or a real-view of a
            complex tensor with a trailing real/imaginary channel of size 2).
        channels_last (bool, optional): If ``True``, apply to the first three
            non-batch dimensions. If ``False``, apply to the last three dimensions
            (excluding the trailing real/imaginary channel of a real view).
        norm (str | bool, optional): The normalization method. Defaults to ``'ortho'``.
            For torch<1.7, only ``'ortho'`` (or ``True``) and ``False`` are supported.
        is_real (bool, optional): If ``True``, ``input`` is a real-valued tensor.
            If ``None`` or ``False`` and ``input.shape[-1] == 2``, ``input`` is a
            real-view of a complex tensor.

    Returns:
        torch.Tensor: The 3D centered FFT of the input.
    """
    dim = _get_fft_dims(input, 3, is_real=is_real, channels_last=channels_last)
    return fftnc(input, dim=dim, norm=norm, is_real=is_real)
def ifft3c(input: torch.Tensor, channels_last=False, norm: str = "ortho", is_real: bool = None):
    """Apply centered 3-dimensional Inverse Fast Fourier Transform.

    Args:
        input (torch.Tensor): A tensor (typically complex, or a real-view of a
            complex tensor with a trailing real/imaginary channel of size 2).
        channels_last (bool, optional): If ``True``, apply to the first three
            non-batch dimensions. If ``False``, apply to the last three dimensions
            (excluding the trailing real/imaginary channel of a real view).
        norm (str | bool, optional): The normalization method. Defaults to ``'ortho'``.
            For torch<1.7, only ``'ortho'`` (or ``True``) and ``False`` are supported.
        is_real (bool, optional): If ``True``, ``input`` is a real-valued tensor.
            If ``None`` or ``False`` and ``input.shape[-1] == 2``, ``input`` is a
            real-view of a complex tensor.

    Returns:
        torch.Tensor: The 3D centered IFFT of the input.
    """
    dim = _get_fft_dims(input, 3, is_real=is_real, channels_last=channels_last)
    return ifftnc(input, dim=dim, norm=norm, is_real=is_real)
def fftshift(x, dim=None):
    """Similar to ``np.fft.fftshift`` but applies to PyTorch tensors.

    Args:
        x (torch.Tensor): The tensor to shift.
        dim (int | Sequence[int], optional): Dimension(s) to shift.
            Defaults to all dimensions of ``x``.

    Returns:
        torch.Tensor: ``x`` rolled (via ``roll``) by ``size // 2`` along each
        dimension in ``dim``.
    """
    if isinstance(dim, int):
        return roll(x, x.shape[dim] // 2, dim)
    if dim is None:
        dim = tuple(range(x.dim()))
    shift = [x.shape[axis] // 2 for axis in dim]
    return roll(x, shift, dim)
def ifftshift(x, dim=None):
    """Similar to ``np.fft.ifftshift`` but applies to PyTorch tensors.

    Args:
        x (torch.Tensor): The tensor to shift.
        dim (int | Sequence[int], optional): Dimension(s) to shift.
            Defaults to all dimensions of ``x``.

    Returns:
        torch.Tensor: ``x`` rolled (via ``roll``) by ``(size + 1) // 2`` along
        each dimension in ``dim``; this undoes :func:`fftshift` for odd sizes too.
    """
    if isinstance(dim, int):
        return roll(x, (x.shape[dim] + 1) // 2, dim)
    if dim is None:
        dim = tuple(range(x.dim()))
    shift = [(x.shape[axis] + 1) // 2 for axis in dim]
    return roll(x, shift, dim)
def _fft_template(
    data: torch.Tensor, kind, dim=None, norm="ortho", is_real: bool = None, centered: bool = True
) -> torch.Tensor:
    """Template for fft operations.

    Args:
        data (torch.Tensor): A tensor.
        kind (str): Either ``'fft'`` or ``'ifft'``.
        dim (int(s), optional): The dimension(s) along which to apply the
            operation. Defaults to all dimensions (for a real view of a complex
            tensor, all dimensions except the trailing real/imaginary channel).
        norm (str | bool, optional): The normalization method. Defaults to ``'ortho'``.
            ``True`` is an alias for ``'ortho'`` and ``False`` an alias for
            ``'backward'``. For torch<1.7, only ``'ortho'`` and ``'backward'``
            (and their boolean aliases) are supported.
        is_real (bool, optional): If ``True``, ``input`` is treated like a
            real-valued tensor. If not specified, this is ``True`` only if
            ``data`` is not complex and data is not inferred to be a real view
            of a complex tensor (i.e. ``data.shape[-1] != 2``).
        centered (bool, optional): If ``True``, apply centered FFT.
            Defaults to ``True``.

    Returns:
        torch.Tensor: The FFT (or IFFT) of the input.

    Raises:
        ValueError: If ``kind`` is unknown, or (torch<1.7 only) if ``dim``
            contains the real/imaginary channel, repeats a dimension, or has
            more than 3 entries.
    """
    if isinstance(dim, int):
        dim = (dim,)
    # Boolean aliases. ``False`` maps to "backward" scaling, matching the legacy
    # ``normalized=False`` behavior; ``torch.fft.fftn`` does not accept booleans.
    if norm is True:
        norm = "ortho"
    elif norm is False:
        norm = "backward"
    assert kind in ("fft", "ifft")
    if is_real is None:
        is_real = not (cplx.is_complex_as_real(data) or cplx.is_complex(data))

    if not env.supports_cplx_tensor():
        return _legacy_fft_template(
            data, kind, dim=dim, norm=norm, is_real=is_real, centered=centered
        )

    is_real_view = not is_real and cplx.is_complex_as_real(data)
    if is_real_view:
        if dim is not None:
            # Make dimensions positive relative to input dimensions so negative
            # indices still refer to the same axes after the trailing
            # real/imaginary channel is folded into a complex dtype.
            dim = _to_positive_index(dim, ndim=data.ndim)
        data = torch.view_as_complex(data)

    if kind == "fft":
        transform = torch.fft.fftn
    elif kind == "ifft":
        transform = torch.fft.ifftn
    else:
        raise ValueError(f"Unknown `kind={kind}`")

    if centered:
        data = ifftshift(data, dim=dim)
    data = transform(data, dim=dim, norm=norm)
    if centered:
        data = fftshift(data, dim=dim)

    if is_real_view:
        data = torch.view_as_real(data)
    return data


def _legacy_fft_template(data, kind, dim, norm, is_real, centered):
    """Run the fft template with the pre-``torch.fft``-module API (torch<1.7).

    ``data`` must be a real view of a complex tensor (trailing dim of size 2).
    """
    assert norm in ("ortho", "backward")  # other norms not supported
    normalized = norm == "ortho"
    assert not is_real  # real tensors not supported
    assert cplx.is_complex_as_real(data)

    ndim = data.ndim
    if dim is None:
        # Every dimension except the trailing real/imaginary channel.
        dim = tuple(range(ndim - 1))
    dim = tuple(sorted(_to_positive_index(dim, ndim)))
    if ndim - 1 in dim:
        raise ValueError("Cannot take fft along the real/imaginary channel.")
    if len(set(dim)) != len(dim):
        raise ValueError(f"Expected unique dimensions, got {dim}.")
    signal_ndim = len(dim)
    if signal_ndim > 3:
        raise ValueError(f"Number of dimensions must be <=3, got {len(dim)}.")

    # The legacy API transforms the last `signal_ndim` dims before the
    # real/imaginary channel, so move the requested dims there (if necessary).
    order = _legacy_permutation(dim, ndim)
    if order is not None:
        data = data.permute(order)

    # Signal dims now sit immediately before the trailing real/imaginary channel.
    shift_dims = tuple(range(-1 - signal_ndim, -1))
    if kind == "fft":
        if centered:
            data = ifftshift(data, dim=shift_dims)
        # NOTE(review): presumably `torch.Tensor.fft` is used because
        # `import torch.fft` can shadow the legacy `torch.fft` function.
        data = torch.Tensor.fft(data, signal_ndim, normalized=normalized)
        if centered:
            data = fftshift(data, shift_dims)
    elif kind == "ifft":
        if centered:
            data = ifftshift(data, dim=shift_dims)
        data = torch.ifft(data, signal_ndim, normalized=normalized)
        if centered:
            data = fftshift(data, shift_dims)
    else:
        raise ValueError(f"Unknown `kind={kind}`")

    # Undo the reordering (if necessary).
    if order is not None:
        data = data.permute(tuple(int(i) for i in np.argsort(order)))
    return data


def _legacy_permutation(dim, ndim):
    """Return a permutation of ``range(ndim)`` moving ``dim`` then the last axis to the end.

    Returns ``None`` when the requested dims are already in that position.
    """
    last_dims = tuple(dim) + (ndim - 1,)
    if last_dims == tuple(range(ndim - len(last_dims), ndim)):
        return None
    leading = tuple(i for i in range(ndim) if i not in last_dims)
    return leading + last_dims


def _get_fft_dims(x, signal_ndim, is_real, channels_last):
    """Return the ``signal_ndim`` dimensions of ``x`` that a fixed-rank FFT acts on.

    With ``channels_last``, these are the first non-batch dims ``1..signal_ndim``.
    Otherwise they are the last ``signal_ndim`` dims, skipping the trailing
    real/imaginary channel when ``x`` is inferred to be a real view.
    """
    if channels_last:
        return tuple(range(1, 1 + signal_ndim))
    if not is_real and not cplx.is_complex(x) and cplx.is_complex_as_real(x):
        return tuple(range(-1 - signal_ndim, -1))
    else:
        return tuple(range(-signal_ndim, 0))


def _to_positive_index(idxs: Sequence[int], ndim: int):
    """Convert (possibly negative) dimension indices into non-negative ones for ``ndim`` dims."""
    return tuple(ndim + i if i < 0 else i for i in idxs)