"""
Utilities for doing complex-valued operations.
"""
import numpy as np
import torch
from meddlr.utils.deprecated import deprecated
from meddlr.utils.env import supports_cplx_tensor
# Public names exported by a star-import of this module.
# NOTE(review): ``bdot``, ``polar``, ``root_sum_of_squares`` and the deprecated
# ``channel_first`` are defined in this module but are not listed here, while
# ``rss`` is listed but no definition of it is visible in this file — confirm.
__all__ = [
"is_complex",
"is_complex_as_real",
"conj",
"mul",
"abs",
"angle",
"real",
"imag",
"from_polar",
"channels_first",
"channels_last",
"get_mask",
"matmul",
"power_method",
"svd",
"to_numpy",
"to_tensor",
"rss",
"center_crop",
]
def is_complex(x):
    """Check whether ``x`` is a native complex-dtype tensor.

    ``torch.is_complex`` does not exist for PyTorch < 1.7, so the capability
    check from ``supports_cplx_tensor()`` runs first and short-circuits the
    call on older versions.

    Args:
        x (torch.Tensor): A tensor.

    Returns:
        bool: ``True`` only if complex tensors are supported and ``x`` is complex.
    """
    supported = supports_cplx_tensor()
    if not supported:
        # Mirror the short-circuit of ``a and b``: return the falsy value as-is.
        return supported
    return torch.is_complex(x)
def is_complex_as_real(x):
    """
    Returns ``True`` if the tensor follows the real-view
    convention for complex numbers.

    The real-view of a complex tensor has the shape ``[..., 2]``, where the
    last dimension holds the real and imaginary components.

    Args:
        x (torch.Tensor): A tensor.

    Returns:
        bool: ``True`` if ``x`` is not a native complex tensor, has at least
        one dimension, and its last dimension has size ``2``.

    Note:
        We recommend using complex tensors instead of the real-view
        convention. This function cannot tell whether a last dimension
        of size ``2`` is the real-imaginary channel or exists
        for some other reason.
    """
    # A 0-d tensor has no last dimension; ``x.size(-1)`` would raise IndexError,
    # so such tensors are simply reported as not being in the real-view layout.
    return not is_complex(x) and x.ndim > 0 and x.size(-1) == 2
def conj(x):
    r"""
    Return the complex conjugate of ``x``.

    ``conj(a + ib)`` = :math:`\bar{a + ib} = a - ib`

    Args:
        x (torch.Tensor): A complex tensor, or the real view of one
            (last dimension of size 2 holding real and imaginary parts).

    Returns:
        torch.Tensor: The conjugate, in the same representation as ``x``.
    """
    assert is_complex_as_real(x) or is_complex(x)
    if not is_complex(x):
        # Real-view layout: keep the real channel, negate the imaginary one.
        re_part, im_part = x[..., 0], x[..., 1]
        return torch.stack((re_part, -1.0 * im_part), dim=-1)
    return x.conj()
def mul(x, y):
    """
    Element-wise product of two complex-valued tensors ``x`` and ``y``.

    :math:`z = (a + ib) * (c + id) = (ac - bd) + i(ad + bc)`

    Args:
        x (torch.Tensor): A complex tensor or the real view of one (``[..., 2]``).
        y (torch.Tensor): A complex tensor or the real view of one (``[..., 2]``).

    Returns:
        torch.Tensor: The element-wise (Hadamard) product. It is complex when
        ``x`` is complex, otherwise it uses the real-view layout ``[..., 2]``.

    Note:
        The code path is chosen from ``x`` alone. NOTE(review): mixing a
        complex tensor with a real-view tensor passes the input checks but is
        not converted to a common representation — confirm callers avoid it.
    """
    assert is_complex_as_real(x) or is_complex(x)
    assert is_complex_as_real(y) or is_complex(y)
    if is_complex(x):
        return x * y

    # Real-view path. ``select`` returns views, so the inputs are not copied.
    a, b = x.select(-1, 0), x.select(-1, 1)
    c, d = y.select(-1, 0), y.select(-1, 1)
    # real = a*c - b*d (in-place on the freshly allocated product)
    real = a * c
    real -= b * d
    # imag = a*d + b*c
    imag = a * d
    imag += b * c
    return torch.stack((real, imag), dim=-1)
def abs(x):
    """
    Return the magnitude ``|x|`` of a complex tensor or its real view.

    Args:
        x (torch.Tensor): A complex tensor, or a real tensor of shape ``[..., 2]``.

    Returns:
        torch.Tensor: The magnitude. For the real view the last dimension is
        reduced as ``sqrt(re**2 + im**2)``.
    """
    assert is_complex_as_real(x) or is_complex(x)
    if not is_complex(x):
        squared = x**2
        return squared.sum(dim=-1).sqrt()
    return x.abs()
def angle(x, eps=1e-11):
    """
    Computes the phase (argument) of a complex-valued input tensor ``x``.

    Args:
        x (torch.Tensor): A complex tensor or the real view of one (``[..., 2]``).
        eps (float): Small offset added to the real part of real-view inputs,
            so the ``atan2`` gradient stays finite when both components are
            zero. Ignored for complex tensors.

    Returns:
        torch.Tensor: The angle in radians, within ``[-pi, pi]``.
    """
    assert is_complex_as_real(x) or is_complex(x)
    if is_complex(x):
        return x.angle()
    # ``atan2`` resolves the quadrant from the signs of both components,
    # matching ``Tensor.angle()``. A plain ``atan(imag / real)`` is limited to
    # (-pi/2, pi/2) and maps quadrants II/III onto IV/I.
    return torch.atan2(x[..., 1], x[..., 0] + eps)
def real(x):
    """
    Return the real component of a complex tensor or its real view.

    For complex tensors this is ``x.real``. For the real-view layout
    ``[..., 2]`` it is entry 0 of the last dimension.
    """
    assert is_complex_as_real(x) or is_complex(x)
    return x.real if is_complex(x) else x.select(-1, 0)
def imag(x):
    """
    Return the imaginary component of a complex tensor or its real view.

    For complex tensors this is ``x.imag``. For the real-view layout
    ``[..., 2]`` it is entry 1 of the last dimension.
    """
    assert is_complex_as_real(x) or is_complex(x)
    return x.imag if is_complex(x) else x.select(-1, 1)
def from_polar(magnitude, phase, return_cplx: bool = None):
    """
    Build a complex value from its polar representation ``magnitude * exp(i*phase)``.

    Args:
        magnitude (torch.Tensor): The magnitude.
        phase (torch.Tensor): The phase, in radians.
        return_cplx (bool, optional): ``True`` forces a complex tensor and
            raises if unsupported. ``False`` returns the real view ``[..., 2]``.
            ``None`` returns a complex tensor when supported, else the real view.

    Returns:
        torch.Tensor: The complex tensor or its real view.

    Raises:
        RuntimeError: If ``return_cplx`` is truthy but this torch version has
            no complex tensor support.
    """
    has_cplx = supports_cplx_tensor()
    if return_cplx and not has_cplx:
        raise RuntimeError(f"torch {torch.__version__} does not support complex tensors")
    if not has_cplx:
        # Legacy torch: assemble the real-view [..., 2] tensor by hand.
        re_part = magnitude * torch.cos(phase)
        im_part = magnitude * torch.sin(phase)
        return torch.stack((re_part, im_part), dim=-1)
    cplx = torch.polar(magnitude, phase)
    return torch.view_as_real(cplx) if return_cplx is False else cplx


polar = from_polar
def channels_first(x: torch.Tensor):
    """Move the channel axis of complex-valued ``x`` right after the batch axis.

    Two layouts are handled:

    1. Native complex ``x``: ``(B, ..., C)`` -> ``(B, C, ...)``.
    2. Real view, with real/imaginary parts in a trailing axis:
       ``(B, ..., C, 2)`` -> ``(B, C, ..., 2)``.

    Args:
        x (torch.Tensor): A complex tensor of shape ``(B,...,C)``
            or a real tensor of shape ``(B,...,C,2)``.

    Returns:
        torch.Tensor: The permuted tensor, in the same representation
        (complex or real view) as ``x``.
    """
    assert is_complex_as_real(x) or is_complex(x)
    ndim = x.ndim
    if is_complex(x):
        middle = tuple(range(1, ndim - 1))
        return x.permute((0, ndim - 1) + middle)
    # The trailing real/imag axis (ndim - 1) stays last.
    middle = tuple(range(1, ndim - 2))
    return x.permute((0, ndim - 2) + middle + (ndim - 1,))
@deprecated(
    reason="Renamed to channels_first",
    vremove="v0.1.0",
    replacement="meddlr.ops.complex.channels_first",
)
def channel_first(x: torch.Tensor):
    """Backward-compatible alias that forwards to :func:`channels_first`.

    Deprecated: call :func:`channels_first` directly.
    """
    return channels_first(x=x)
def channels_last(x: torch.Tensor):
    """Permute complex-valued ``x`` to channels-last convention.

    This is the inverse of the channels-first layout. The channel axis (dim 1)
    moves to the end. For the real view, it moves to just before the trailing
    real/imaginary axis.

    Args:
        x (torch.Tensor): A complex tensor of shape ``[B,C,H,W,...]`` or the
            real view of one, of shape ``[B,C,H,W,...,2]``.

    Returns:
        torch.Tensor: ``[B,H,W,...,C]`` for complex input, or
        ``[B,H,W,...,C,2]`` for real-view input.
    """
    assert is_complex_as_real(x) or is_complex(x)
    if is_complex(x):
        order = (0,) + tuple(range(2, x.ndim)) + (1,)
    else:
        # Spatial dims are 2 .. ndim-2; the real/imag axis (ndim - 1) stays last.
        order = (0,) + tuple(range(2, x.ndim - 1)) + (1, x.ndim - 1)
    return x.permute(order)
def get_mask(x, eps=1e-11, coil_dim=None):
    """Return a binary mask marking where ``x`` is nonzero, with tolerance ``eps``.

    - 0 where the magnitude of ``x`` is at most ``eps``.
    - 1 where the magnitude of ``x`` exceeds ``eps``.

    Args:
        x (torch.Tensor): A complex tensor, or the real view of one (``[..., 2]``).
        eps (float): Tolerance for treating a value as zero.
        coil_dim (int, optional): The coil dimension. When given, a location
            is marked 1 for every coil if it is nonzero in any coil. This is
            useful when a coil ``i`` has zero signal but the location was
            actually acquired.

    Returns:
        torch.Tensor: A float mask. For complex input it has shape ``x.shape``.
        For real-view input it has shape ``x.shape[:-1] + (1,)``.
    """
    was_complex = is_complex(x)
    if was_complex:
        x = torch.view_as_real(x)
    assert x.size(-1) == 2

    magnitude = abs(x)
    nonzero = magnitude > eps  # the real/imag axis has been reduced away
    if coil_dim is not None:
        nonzero = nonzero.any(coil_dim, keepdims=True)
    # torch.where broadcasts a collapsed coil axis back to the full shape.
    mask = torch.where(nonzero, torch.ones_like(magnitude), torch.zeros_like(magnitude))
    return mask if was_complex else mask.unsqueeze(-1)
def matmul(X, Y):
    """
    Complex-valued matrix product ``X @ Y``.

    For real-view inputs (``[..., 2]``) with ``X = A + iB`` and ``Y = C + iD``,
    the product is ``(AC - BD) + i(AD + BC)`` and is returned as a real view.
    """
    assert is_complex_as_real(X) or is_complex(X)
    assert is_complex_as_real(Y) or is_complex(Y)
    if is_complex(X):
        return torch.matmul(X, Y)

    x_re, x_im = X[..., 0], X[..., 1]
    y_re, y_im = Y[..., 0], Y[..., 1]
    out_re = x_re @ y_re - x_im @ y_im
    out_im = x_re @ y_im + x_im @ y_re
    return torch.stack((out_re, out_im), dim=-1)
def power_method(X, num_iter=10, eps=1e-6):
    """
    Estimates the dominant eigenvalue of ``X^H X`` by power iteration.

    The estimate equals the square of the largest singular value of ``X``.

    Args:
        X (torch.Tensor): Batch of matrices, either complex with shape
            ``[batch, m, n]`` or the real view with shape ``[batch, m, n, 2]``.
        num_iter (int): Number of power iterations. Must be at least 1.
        eps (float): Added to the iterate's norm before normalizing, to avoid
            division by zero.

    Returns:
        torch.Tensor: Shape ``[batch]``, the estimated largest eigenvalue of
        ``X^H X`` for each batch element.

    Raises:
        ValueError: If ``num_iter < 1``.
    """
    if num_iter < 1:
        raise ValueError(f"num_iter must be >= 1, got {num_iter}")
    # Only native complex tensors need converting. A real-view input is already
    # in the [..., 2] layout, and ``view_as_real`` would reject it.
    if is_complex(X):
        X = torch.view_as_real(X)
    assert X.size(-1) == 2

    batch_size, _, n, _ = X.shape
    # Gram matrix X^H X in real-view layout: [batch, n, n, 2].
    XhX = matmul(conj(X).permute(0, 2, 1, 3), X)

    # Random uniform [0, 1) starting vector, on XhX's device and in its dtype
    # (the legacy FloatTensor constructors always produced float32).
    v = torch.rand(batch_size, n, 1, 2, dtype=XhX.dtype, device=XhX.device)
    for _ in range(num_iter):
        v = matmul(XhX, v)
        eigenvals = (abs(v) ** 2).sum(1).sqrt()
        v = v / (eigenvals.reshape(batch_size, 1, 1, 1) + eps)
    return eigenvals.reshape(batch_size)
def svd(X, compute_uv=True):
    """
    Computes singular value decomposition of batch of complex-valued matrices.

    The complex matrix ``X = A + iB`` is embedded in the real block matrix
    ``[[A, B], [-B, A]]``, which a real-valued SVD then decomposes. Each
    singular value of ``X`` appears twice in that block matrix, so every
    second one is kept.

    Args:
        X (torch.Tensor): Batch of complex matrices, either complex with shape
            ``[batch, m, n]`` or the real view with shape ``[batch, m, n, 2]``.
        compute_uv (bool): Forwarded to ``torch.svd``. When ``False``, the
            returned ``U`` and ``V`` are built from the zero-filled matrices
            ``torch.svd`` returns.

    Returns:
        tuple: ``(U, S, V)``. ``S`` has shape ``[batch, min(m, n)]``. With the
        default ``compute_uv=True``, ``U`` is ``[batch, m, min(m, n), 2]`` and
        ``V`` is ``[batch, n, min(m, n), 2]`` in real-view layout.
    """
    # Only native complex tensors need converting. A real-view input is already
    # in the [..., 2] layout, and ``view_as_real`` would reject it.
    if is_complex(X):
        X = torch.view_as_real(X)
    assert X.size(-1) == 2

    batch_size, m, n, _ = X.shape
    X_re, X_im = X[..., 0], X[..., 1]

    # Allocate the block matrix once, on X's device and in X's dtype.
    Xb = X.new_zeros((batch_size, 2 * m, 2 * n))
    Xb[:, :m, :n] = X_re
    Xb[:, :m, n:] = X_im
    Xb[:, m:, :n] = -X_im
    Xb[:, m:, n:] = X_re

    # Real-valued SVD of the block matrix.
    U, S, V = torch.svd(Xb, compute_uv=compute_uv)

    # Keep one of each duplicated singular value / vector pair.
    S = S[:, ::2]
    U = torch.stack((U[:, :m, ::2], -U[:, m:, ::2]), dim=3)
    V = torch.stack((V[:, :n, ::2], -V[:, n:, ::2]), dim=3)
    return U, S, V
def to_numpy(x: torch.Tensor):
    """
    Convert a complex PyTorch tensor, or its real view, to a complex numpy array.

    Args:
        x (torch.Tensor): A complex tensor, or a real tensor of shape ``[..., 2]``
            holding real and imaginary parts. It may require grad or live on
            a non-CPU device.

    Returns:
        np.ndarray: A complex array that does not share memory with ``x``.
    """
    assert is_complex_as_real(x) or is_complex(x)
    # ``Tensor.numpy()`` rejects tensors that require grad or are not on the
    # CPU, so detach from autograd and move to host memory first.
    x = x.detach().cpu()
    if is_complex(x):
        # Copy so the returned array does not alias the tensor's storage.
        return x.clone().numpy()
    x = x.numpy()
    # Arithmetic allocates a new array, so no aliasing here either.
    return x[..., 0] + 1j * x[..., 1]
def to_tensor(x: np.ndarray):
    """
    Convert a complex numpy array to a PyTorch tensor.

    When complex tensors are supported, this returns the complex tensor built
    by ``torch.from_numpy``. Otherwise the real and imaginary parts are
    stacked into a trailing axis, giving the real view ``[..., 2]``.
    """
    arr = x if supports_cplx_tensor() else np.stack((x.real, x.imag), axis=-1)
    return torch.from_numpy(arr)
# Long-form alias for ``rss``.
# NOTE(review): no definition of ``rss`` is visible in this file (it is only
# named in ``__all__``); if it is not defined elsewhere in the module, this
# line raises NameError at import time — confirm.
root_sum_of_squares = rss
def center_crop(x: torch.Tensor, shape, channels_last: bool = False):
    """
    Apply a center crop to the input image or batch of complex images.

    Args:
        x (torch.Tensor): The complex input tensor (or its real view) to crop.
        shape (tuple[int, ...]): The output size of each cropped dimension.
            Each entry must be positive and no larger than the corresponding
            dimension of ``x``.
        channels_last (bool, optional): If ``True``, crop dimensions
            ``range(1, 1 + len(shape))``. If ``False``, crop the trailing
            dimensions, skipping the real/imaginary axis of a real-view tensor.

    Returns:
        torch.Tensor: The center-cropped tensor (a view of ``x``).
    """
    if channels_last:
        dims = range(1, 1 + len(shape))
    elif is_complex_as_real(x):
        # Skip the trailing real/imaginary axis.
        dims = range(-1 - len(shape), -1)
    else:
        dims = range(-len(shape), 0)

    x_shape = tuple(x.shape[d] for d in dims)
    assert all(0 < out_sz <= in_sz for out_sz, in_sz in zip(shape, x_shape))

    slices = [slice(None)] * x.ndim
    for d, out_sz, in_sz in zip(dims, shape, x_shape):
        start = (in_sz - out_sz) // 2
        slices[d] = slice(start, start + out_sz)
    # Index with a tuple: NumPy treats a *list* of slices as advanced
    # indexing (an error since 1.23), and tuple indexing is unambiguous.
    return x[tuple(slices)]
def bdot(x: torch.Tensor, y: torch.Tensor, keepdim: bool = False) -> torch.Tensor:
    """Batch dot product (inner product) of two complex-valued tensors.

    The first dimension is the batch dimension. All remaining dimensions of
    ``x`` are reduced.

    Args:
        x: The first input tensor. It is conjugated.
        y: The second input tensor, broadcastable against ``x``.
        keepdim: If ``True``, reduced dimensions are kept with size 1.

    Returns:
        torch.Tensor: The batch inner product :math:`<x, y>_i = sum(conj(x_i) * y_i)`.
        For 1-D inputs there is nothing to reduce, and the elementwise product
        ``conj(x_i) * y_i`` is returned.

    Note:
        To avoid ambiguity, use torch.complex tensors to represent complex values.
        ``conj`` is a no-op on real tensors, so the real-view ``[..., 2]`` layout
        is not conjugated here.
    """
    prod = x.conj() * y
    dims = tuple(range(1, x.ndim))
    if not dims:
        # ``torch.sum(..., dim=())`` can reduce over *all* dimensions, which
        # would collapse the batch. With no non-batch dims, each item's inner
        # product is just its elementwise term.
        return prod
    return torch.sum(prod, dim=dims, keepdim=keepdim)