Source code for meddlr.modeling.layers.build

import copy
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

from fvcore.common.registry import Registry
from torch import nn

# Registry of project-defined layers. Lookups in this module consult it
# before falling back to the layers found in ``torch.nn``.
CUSTOM_LAYERS_REGISTRY = Registry("CUSTOM_LAYERS")
CUSTOM_LAYERS_REGISTRY.__doc__ = """
Registry for custom layers.

Use this registry to identify if the layer is not provided by default in torch.nn.
"""

# Lowercase alias -> lowercase layer name that the alias is rewritten to
# before any lookup is attempted.
_LAYER_SHORTCUTS = {
    "bn": "batchnorm",
    "dropout1d": "dropout",
}
# Lowercase class name -> class, for every ``nn.Module`` subclass present in
# the ``torch.nn`` namespace at import time. Enables case-insensitive lookup.
_PT_LAYERS_LOWERCASE = {
    name.lower(): layer
    for name, layer in nn.__dict__.items()
    if isinstance(layer, type) and issubclass(layer, nn.Module)
}

# Names exported by ``from <this module> import *``.
__all__ = ["get_layer_type", "get_layer_kind"]


def get_layer_type(name: str, dimension: int = None) -> type:
    """Resolve a layer name (and, where needed, a dimension) to a layer class.

    The lookup is case-insensitive. Each candidate name is checked against,
    in order: :obj:`CUSTOM_LAYERS_REGISTRY` by exact key,
    :obj:`CUSTOM_LAYERS_REGISTRY` by lowercased key, and the lowercased
    ``torch.nn`` layer table. The bare name is tried first; if that fails and
    a dimension was given, the name with ``"{dimension}d"`` appended is tried
    (e.g. ``"conv"`` with ``dimension=2`` is looked up as ``"conv2d"``).

    Args:
        name (str): Name of the layer. Aliases listed in ``_LAYER_SHORTCUTS``
            (e.g. ``"bn"``) are expanded before lookup.
        dimension (int, optional): Dimension of the layer. Not every layer
            needs it (e.g. ReLU); it is ignored when the bare name already
            resolves.

    Returns:
        type: The layer class.

    Raises:
        ValueError: If the bare name does not resolve, contains ``"norm"`` or
            ``"conv"``, and no dimension was supplied.
        ValueError: If nothing matches the name/dimension pair.
    """
    requested = name
    key = requested.lower()

    # Dropout gets its suffix before the general lookup: plain "dropout"
    # would otherwise match on the first pass and the dimension would never
    # be applied. The shortcut table then folds "dropout1d" back to "dropout".
    if dimension and key == "dropout":
        key = f"{key}{dimension}d"
    key = _LAYER_SHORTCUTS.get(key, key)

    # Lowercased registry key -> key as registered, for case-insensitive hits.
    lowercase_to_registered = {}
    for registered in CUSTOM_LAYERS_REGISTRY._obj_map:
        lowercase_to_registered[registered.lower()] = registered

    missing = object()

    def _resolve(candidate):
        """Return the layer class for ``candidate`` or the ``missing`` sentinel."""
        if candidate in CUSTOM_LAYERS_REGISTRY:
            return CUSTOM_LAYERS_REGISTRY.get(candidate)
        if candidate in lowercase_to_registered:
            return CUSTOM_LAYERS_REGISTRY.get(lowercase_to_registered[candidate])
        return _PT_LAYERS_LOWERCASE.get(candidate, missing)

    # First pass: the name exactly as given (after alias expansion).
    layer = _resolve(key)
    if layer is not missing:
        return layer

    # Unresolved norm/conv names can only be completed by a dimension suffix.
    if not dimension and ("norm" in key or "conv" in key):
        raise ValueError(f"{requested} requires dimension")
    if dimension:
        key = f"{key}{dimension}d"

    # Second pass: the name with the dimension suffix attached (if any).
    layer = _resolve(key)
    if layer is not missing:
        return layer

    raise ValueError(f"No layer found for '{requested}'")
def get_layer_kind(layer_type: Union[type, str]) -> str:
    """Returns the layer kind based on the layer type.

    The layer kind is a string that describes the kind of layer:

        - "conv": Convolutional layer.
        - "norm": Normalization layer.
        - "act": Activation layer.
        - "dropout": Dropout layer.
        - "unknown": Unknown layer.

    This delineation is useful for building models that order layers based on
    their kind. For example, a model with layers ``conv->norm->act->dropout``
    would need to know the kind of the different types of the layers to
    organize them appropriately.

    Classification is driven by the class name: a name containing "norm" or
    "conv" (case-insensitive) wins first, then a name matching a class defined
    in ``torch.nn.modules.activation``, then a name containing "dropout".

    TODO: Add support for pooling layers.

    Args:
        layer_type (Union[type, str]): Layer type. A string is resolved with
            :func:`get_layer_type` using ``dimension=2``.

    Returns:
        str: Layer kind.

    Raises:
        AssertionError: If the resolved layer type is not an ``nn.Module`` subclass.
    """
    if isinstance(layer_type, str):
        layer_type = get_layer_type(layer_type, dimension=2)
    assert issubclass(layer_type, nn.Module)

    name = layer_type.__name__
    lower_name = name.lower()
    for kind in ("norm", "conv"):
        if kind in lower_name:
            return kind

    # Match by name against classes *defined in* the activation module only.
    # A bare ``hasattr`` would also accept names that module merely imports
    # (e.g. ``Module``, ``Tensor``, ``Parameter``) and misreport them as "act".
    activation_module = nn.modules.activation
    candidate = getattr(activation_module, name, None)
    if isinstance(candidate, type) and candidate.__module__ == activation_module.__name__:
        return "act"

    if "dropout" in lower_name:
        return "dropout"
    return "unknown"
# Accepted shapes for a layer's init kwargs in raw (config-friendly) form.
_LayerInfoInitKwargsDict = Dict[str, Any]
# Flat sequence of alternating keys and values: [key1, value1, key2, value2, ...].
_LayerInfoInitKwargsFlatSequence = Union[List[Any], Tuple[Any, ...]]
_NestedInnerArgs = Tuple[str, Any]
# Nested sequence of pairs: [(key1, value1), (key2, value2), ...].
_LayerInfoInitKwargsNestedSequence = Union[List[_NestedInnerArgs], Tuple[_NestedInnerArgs, ...]]
_LayerInfoInitKwargsType = Union[
    _LayerInfoInitKwargsDict, _LayerInfoInitKwargsFlatSequence, _LayerInfoInitKwargsNestedSequence
]

# LayerInfo type schema with raw Python types (str, int, float, tuple, list, etc.).
LayerInfoRawType = Union[
    str, Dict[str, _LayerInfoInitKwargsType], Tuple[str, _LayerInfoInitKwargsType]
]


@dataclass
class LayerInfo:
    """Dataclass for managing layer information."""

    name: str  # name of the layer
    dimension: Optional[int] = None  # The dimension of the layer.
    init_kwargs: Dict[str, Any] = field(default_factory=dict)  # keyword args to initialize layer

    @property
    def kind(self) -> str:
        """The layer kind, i.e. ``get_layer_kind`` applied to :attr:`ltype`."""
        return get_layer_kind(self.ltype)

    @property
    def ltype(self) -> type:
        """The layer type, resolved from :attr:`name` and :attr:`dimension`."""
        return get_layer_type(self.name, dimension=self.dimension)

    @classmethod
    def format(cls, layer_info: LayerInfoRawType) -> "LayerInfo":
        """Formats layer information from Python raw types to LayerInfo object.

        Supported formats for ``layer_info``:

            - ``str``: the layer name, with no init kwargs.
            - single-item ``dict``: ``{name: init_kwargs}``.
            - two-element ``tuple``/``list``: ``(name, init_kwargs)``.

        ``init_kwargs`` itself may be a dict, a flat sequence
        ``[key1, value1, ...]`` or a nested sequence ``[(key1, value1), ...]``.
        A sequence whose elements are all two-element tuples/lists is read as
        the nested form.

        Args:
            layer_info: Raw description of the layer.

        Returns:
            LayerInfo: An instance of ``cls``; its ``dimension`` is left unset.

        Raises:
            ValueError: If ``layer_info`` or its init kwargs do not follow one
                of the supported formats.
        """
        if isinstance(layer_info, str):
            return cls(name=layer_info, init_kwargs={})

        if isinstance(layer_info, dict):
            if len(layer_info) != 1:
                raise ValueError(
                    "Dictionary format for LayerInfo can only have one key-value pair - "
                    "e.g. {'dropout': {'p': 0.5}}. "
                    f"Got {layer_info}"
                )
            name, init_kwargs = next(iter(layer_info.items()))
        elif isinstance(layer_info, (tuple, list)):
            if len(layer_info) != 2:
                raise ValueError(
                    "Sequence format for LayerInfo should be formatted as (name, init_kwargs) - "
                    "e.g. ('dropout', {'p': 0.5}), ('dropout', ('p', 0.5)). "
                    f"Got {layer_info}"
                )
            name, init_kwargs = layer_info[0], layer_info[1]
        else:
            raise ValueError(f"Unsupported layer info format: {type(layer_info)}")

        # Format init kwargs.
        init_kwargs_err_message = (
            "Unknown init_kwargs format. "
            "init_kwargs must follow one of these formats: "
            "\n\t- dict: {key1: value1, key2: value2}"
            "\n\t- flat sequence: [key1, value1, key2, value2, ...]"
            "\n\t- nested sequence: [(key1, value1), (key2, value2), ...]\n"
            f"Got {init_kwargs}"
        )
        if isinstance(init_kwargs, (tuple, list)):
            if all(isinstance(x, (tuple, list)) and len(x) == 2 for x in init_kwargs):
                # Nested sequence - i.e. [(key1, value1), (key2, value2), ...]
                init_kwargs = {key: value for key, value in init_kwargs}
            elif len(init_kwargs) % 2 != 0:
                # A flat sequence must pair every key with a value.
                raise ValueError(init_kwargs_err_message)
            else:
                # Flat sequence - i.e. [key1, value1, key2, value2, ...]
                init_kwargs = dict(zip(init_kwargs[::2], init_kwargs[1::2]))
        if not isinstance(init_kwargs, dict):
            raise ValueError(init_kwargs_err_message)

        # Build through ``cls`` so subclasses get instances of their own type.
        return cls(name=name, init_kwargs=init_kwargs)

    def build(self, *args, **kwargs) -> nn.Module:
        """Builds the layer.

        Args:
            *args: Positional arguments for building the module.
            **kwargs: Keyword arguments to pass to the layer's ``__init__``.
                Note, these will override the init_kwargs of the layer
                if there are conflicts.

        Returns:
            nn.Module: The layer built from :attr:`ltype`.
        """
        return self.ltype(*args, **{**self.init_kwargs, **kwargs})


def build_layer_info_from_seq(
    layers_info: Sequence[Union[LayerInfoRawType, LayerInfo]], dimension: Optional[int] = None
) -> List[LayerInfo]:
    """Builds a sequence of layer info objects from a sequence of layer info types.

    Args:
        layers_info: Sequence of layer info types. Elements that are already
            :class:`LayerInfo` objects are deep-copied, so the inputs are
            never mutated; all other elements go through :meth:`LayerInfo.format`.
        dimension: If not ``None``, assigned as the dimension of every
            returned object, replacing any dimension it already had.

    Returns:
        List[LayerInfo]: list of LayerInfo objects corresponding to each element in layers_info.
    """
    out = [
        copy.deepcopy(x) if isinstance(x, LayerInfo) else LayerInfo.format(x) for x in layers_info
    ]
    if dimension is not None:
        # TODO (arjundd): We probably want some extra logic here where we don't
        # set the dimension if it already exists in the name (e.g. conv1d).
        # Currently, if the name contains the dimension and a dimension is passed
        # (e.g. conv1d, 2), then the dimension value is ignored.
        for x in out:
            x.dimension = dimension
    return out