Source code for meddlr.modeling.layers.conv

from typing import Tuple, Union

import torch.nn as nn
import torch.nn.functional as F

from meddlr.modeling.layers.build import CUSTOM_LAYERS_REGISTRY

__all__ = ["ConvWS2d", "ConvWS3d"]


@CUSTOM_LAYERS_REGISTRY.register()
class ConvWS2d(nn.Conv2d):
    """2D convolution whose kernel is standardized on every forward pass.

    Weight Standardization: before convolving, each output filter
    (a slice along dim 0 of ``self.weight``) is shifted to zero mean and
    divided by its standard deviation plus ``eps``. The stored parameter
    itself is never modified; only the tensor handed to the convolution is.

    Adapted from
    https://github.com/joe-siyuan-qiao/pytorch-classification/blob/e6355f829e85ac05a71b8889f4fff77b9ab95d0b/models/layers.py

    References:
        S. Qiao, et al. Micro-Batch Training with Batch-Channel Normalization
        and Weight Standardization, ArXiv 2020.
        https://arxiv.org/abs/1903.10520
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        eps=1e-5,
    ):
        """Build the underlying ``nn.Conv2d`` and remember ``eps``.

        Args:
            in_channels: Forwarded to ``nn.Conv2d``.
            out_channels: Forwarded to ``nn.Conv2d``.
            kernel_size: Forwarded to ``nn.Conv2d``.
            stride: Forwarded to ``nn.Conv2d``.
            padding: Forwarded to ``nn.Conv2d``.
            dilation: Forwarded to ``nn.Conv2d``.
            groups: Forwarded to ``nn.Conv2d``.
            bias: Forwarded to ``nn.Conv2d``.
            eps: Constant added to each filter's standard deviation before
                the division in :meth:`forward`.
        """
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.eps = eps

    def forward(self, x):
        """Convolve ``x`` with the standardized version of ``self.weight``.

        Args:
            x: Input tensor handed to ``F.conv2d``.

        Returns:
            The result of ``F.conv2d`` using the standardized kernel and this
            layer's bias, stride, padding, dilation and groups.
        """
        num_filters = self.weight.size(0)

        # Per-filter mean over every non-output dimension, shaped to broadcast.
        filter_mean = self.weight.view(num_filters, -1).mean(dim=1)
        centered = self.weight - filter_mean.view(num_filters, 1, 1, 1)

        # eps is added to the standard deviation itself (not the variance);
        # ``std`` is called with its default settings.
        filter_std = centered.view(num_filters, -1).std(dim=1)
        denom = filter_std.view(num_filters, 1, 1, 1) + self.eps

        standardized = centered / denom.expand_as(centered)
        return F.conv2d(
            x,
            standardized,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
@CUSTOM_LAYERS_REGISTRY.register()
class ConvWS3d(nn.Conv3d):
    """Conv3d with Weight Standardization.

    On every forward pass each output filter (a slice along dim 0 of
    ``self.weight``) is centered to zero mean and divided by its standard
    deviation plus ``eps``; the convolution then uses that standardized
    kernel. The stored parameter is left untouched.

    Adapted from:
    https://github.com/joe-siyuan-qiao/pytorch-classification/blob/e6355f829e85ac05a71b8889f4fff77b9ab95d0b/models/layers.py

    References:
        S. Qiao, et al. Micro-Batch Training with Batch-Channel Normalization
        and Weight Standardization, ArXiv 2020.
        https://arxiv.org/abs/1903.10520
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        # A 3D kernel is an int or a 3-tuple (the 2-tuple hint was a
        # copy-paste from the 2D layer).
        kernel_size: Union[int, Tuple[int, int, int]],
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        eps=1e-5,
    ):
        """Build the underlying ``nn.Conv3d`` and store ``eps``.

        Args:
            in_channels: Forwarded to ``nn.Conv3d``.
            out_channels: Forwarded to ``nn.Conv3d``.
            kernel_size: Forwarded to ``nn.Conv3d``; a single int or one int
                per spatial dimension.
            stride: Forwarded to ``nn.Conv3d``.
            padding: Forwarded to ``nn.Conv3d``.
            dilation: Forwarded to ``nn.Conv3d``.
            groups: Forwarded to ``nn.Conv3d``.
            bias: Forwarded to ``nn.Conv3d``.
            eps: Constant added to each filter's standard deviation before
                the division in :meth:`forward`.
        """
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias
        )
        self.eps = eps

    def forward(self, x):
        """Convolve ``x`` with the standardized version of ``self.weight``.

        Args:
            x: Input tensor handed to ``F.conv3d``.

        Returns:
            The result of ``F.conv3d`` using the standardized kernel and this
            layer's bias, stride, padding, dilation and groups.
        """
        weight = self.weight
        # Flatten everything except the output-filter dim to get one mean per
        # filter, then reshape so it broadcasts against the 5-D kernel.
        weight_mean = weight.view(weight.size(0), -1).mean(dim=1).view(-1, 1, 1, 1, 1)
        weight = weight - weight_mean
        # eps is added to the standard deviation itself (not the variance);
        # ``std`` is called with its default settings.
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1, 1) + self.eps
        weight = weight / std.expand_as(weight)
        return F.conv3d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)