import inspect
from abc import ABC, abstractmethod
from typing import Tuple, Union
import torch
import torch.nn as nn
from meddlr.modeling.blocks.conv_blocks import SimpleConvBlockNd
__all__ = [
"ResBlockNd",
"ResBlock2d",
"ResBlock3d",
"ConcatBlockNd",
"ConcatBlock2d",
"ConcatBlock3d",
]
class _SimpleFuseBlockNd(nn.Module, ABC):
"""Series of :class:`SimpleConvBlockNd` with residual connection."""
# Assumes order is the last argument in SimpleConvBlockNd
_DEFAULT_CONV_BLOCK_ORDER = inspect.getfullargspec(SimpleConvBlockNd).defaults[-1]
def __init__(
self,
in_channels: int,
kernel_size: Union[int, Tuple[int, ...]],
n_blocks: int,
dimension: int,
connect_before: str = None,
**kwargs,
):
"""
Args:
in_channels (int): Number of channels in the input.
kernel_size (`int(s)`): Convolution kernel size.
n_blocks (int): Number of conv blocks.
dimension (int): Integer specifying the dimension of convolution.
connect_before (str, optional): Layer to add residual connection before in conv block.
For example, if `n_blocks=1`, conv block `order=("conv", "batchnorm", "relu")`,
and `connect_before="relu"`, residual block will look like below. If `None`,
residual connection will be made after full conv block.
x -> Conv -> BatchNorm -> + -> ReLU
| ^
| |
--------------------------
kwargs: `SimpleConvBlockNd` arguments. `in_channels` and `out_channels` required.
"""
super().__init__()
self.n_blocks = n_blocks
# Determine order for connecting before.
order = kwargs.pop("order", self._DEFAULT_CONV_BLOCK_ORDER)
if connect_before:
if connect_before not in order:
raise ValueError(f"Layer {connect_before} not in conv block `order` ({order})")
layer_idx = order.index(connect_before)
if layer_idx == 0:
raise ValueError(
f"Layer {connect_before} occurs first in conv block `order` ({order}). "
f"Reduce n_block by 1, set `connect_before=None`, and add `SimpleConvBlockNd` "
f"after this residual block."
)
standard_order = [order] * (self.n_blocks - 1)
split_order = [order[:layer_idx], order[layer_idx:]]
conv_orders = standard_order + split_order
else:
conv_orders = [order] * self.n_blocks
self.connect_before = connect_before
# Build conv blocks
self.blocks = nn.ModuleDict(
{
f"block_{i+1}": SimpleConvBlockNd(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
dimension=dimension,
order=order,
**kwargs,
)
for i, order in enumerate(conv_orders)
}
)
assert (self.connect_before and len(self.blocks) == self.n_blocks + 1) or (
not self.connect_before and len(self.blocks) == self.n_blocks
)
def forward(self, x):
out = x
for i in range(self.n_blocks):
out = self.blocks[f"block_{i+1}"](out)
out = self.fuse(x, out)
# Handle any remaining layers when connect_before is specified
if self.connect_before:
out = self.blocks[f"block_{self.n_blocks + 1}"](out)
return out
@abstractmethod
def fuse(self, x, y):
"""Fuse two tensors"""
pass
[docs]class ResBlockNd(_SimpleFuseBlockNd):
"""Residual block.
This block adds a residual connection to the the :cls:`SimpleConvBlockNd` block.
The order of the layers follows the same order used by :cls:`SimpleConvBlockNd`
and can be manually configured using the ``order`` argument.
Args:
in_channels (int): Number of channels in the input.
kernel_size (int(s)): Convolution kernel size.
n_blocks (int): Number of conv blocks.
dimension (int): Integer specifying the dimension of convolution.
connect_before (str, optional): Layer to add residual connection before in conv block.
For example, if `n_blocks=1`, conv block `order=("conv", "batchnorm", "relu")`,
and `connect_before="relu"`, residual block will look like below. If `None`,
residual connection will be made after full conv block.
x -> Conv -> BatchNorm -> + -> ReLU
| ^
| |
--------------------------
kwargs: `SimpleConvBlockNd` arguments. `in_channels` and `out_channels` required.
"""
[docs] def __init__(
self,
in_channels: int,
kernel_size: Union[int, Tuple[int, ...]],
n_blocks: int,
dimension: int,
connect_before: str = None,
**kwargs,
):
super().__init__(in_channels, kernel_size, n_blocks, dimension, connect_before, **kwargs)
def fuse(self, x, y):
return x + y
[docs]class ResBlock2d(ResBlockNd):
"""2D implementation of :cls:`ResBlockNd`."""
[docs] def __init__(
self,
in_channels: int,
kernel_size: Union[int, Tuple[int, ...]],
n_blocks: int,
connect_before: str = None,
**kwargs,
):
super().__init__(in_channels, kernel_size, n_blocks, 2, connect_before, **kwargs)
[docs]class ResBlock3d(ResBlockNd):
"""3D implementation of :class:`ResBlockNd`."""
[docs] def __init__(
self,
in_channels: int,
kernel_size: Union[int, Tuple[int, ...]],
n_blocks: int,
connect_before: str = None,
**kwargs,
):
super().__init__(in_channels, kernel_size, n_blocks, 3, connect_before, **kwargs)
[docs]class ConcatBlockNd(_SimpleFuseBlockNd):
"""
Args:
in_channels (int): Number of channels in the input.
kernel_size (`int(s)`): Convolution kernel size.
n_blocks (int): Number of conv blocks.
dimension (int): Integer specifying the dimension of convolution.
connect_before (str, optional): Layer to add residual connection before in conv block.
For example, if `n_blocks=1`, conv block `order=("conv", "batchnorm", "relu")`,
and `connect_before="relu"`, residual block will look like below. If `None`,
residual connection will be made after full conv block.
x -> Conv -> BatchNorm -> [concat] -> ReLU
| ^
| |
------------------------------
kwargs: `SimpleConvBlockNd` arguments. `in_channels` and `out_channels` required.
"""
[docs] def __init__(
self,
in_channels: int,
kernel_size: Union[int, Tuple[int, ...]],
n_blocks: int,
dimension: int,
connect_before: str = None,
**kwargs,
):
super().__init__(in_channels, kernel_size, n_blocks, dimension, connect_before, **kwargs)
def fuse(self, x, y):
return torch.cat([x, y], dim=1)
[docs]class ConcatBlock2d(ConcatBlockNd):
"""2D implementation of :class:`ConcatBlockNd`."""
[docs] def __init__(
self,
in_channels: int,
kernel_size: Union[int, Tuple[int, ...]],
n_blocks: int,
connect_before: str = None,
**kwargs,
):
super().__init__(in_channels, kernel_size, n_blocks, 2, connect_before, **kwargs)
[docs]class ConcatBlock3d(ConcatBlockNd):
"""3D implementation of :class:`ConcatBlockNd`."""
[docs] def __init__(
self,
in_channels: int,
kernel_size: Union[int, Tuple[int, ...]],
n_blocks: int,
connect_before: str = None,
**kwargs,
):
super().__init__(in_channels, kernel_size, n_blocks, 3, connect_before, **kwargs)