import inspect
import itertools
from typing import Any, Dict, Sequence
import pandas as pd
import torch
from torchmetrics.metric import Metric as _Metric
from torchmetrics.utilities import reduce
from meddlr.utils import comm
# Public names exported by a star-import of this module.
__all__ = ["Metric"]
class Metric(_Metric):
    """Interface for new metrics.

    A metric should be implemented as a callable with explicitly defined
    arguments. In other words, metrics should not have ``**kwargs`` or ``*args``
    options in the ``__call__`` method.

    While not explicitly constrained to the return type, metrics typically
    return float value(s). The number of values returned corresponds to the
    number of categories.

    This class is opinionated in that it computes metrics for each (example, channel)
    pair. This means that outputs of ``compute`` are not scalars, but rather tensors
    of shape ``(B, C)``. Note, this opinion may be relaxed in the future.

    * metrics should have different name() for different functionality.
    * `category_dim` duck type if metric can process multiple categories at
      once.

    To compute metrics:

    .. code-block:: python

        metric = Metric()
        results = metric(...)
    """

    def __init__(
        self,
        channel_names: Sequence[str] = None,
        units: str = None,
        reduction="none",
        compute_on_step: bool = False,
        dist_sync_on_step: bool = False,
        process_group: Any = None,
        dist_sync_fn: Any = None,
    ):
        """Set up the metric and register its accumulation states.

        Args:
            channel_names: Column names used by ``to_dict``/``to_pandas``.
                When falsy, ``channel_<idx>`` names are generated.
            units: Optional units string appended by ``display_name``.
            reduction: Default reduction passed to ``compute``.
            compute_on_step: Forwarded to the parent metric constructor.
            dist_sync_on_step: Forwarded to the parent metric constructor.
            process_group: Forwarded to the parent metric constructor.
            dist_sync_fn: Forwarded to the parent metric constructor.
        """
        self.units = units
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )
        self.reduction = reduction
        self.channel_names = channel_names
        self._update_kwargs_aliases = {}

        # Identifiers for the examples that are seen.
        self.add_state("ids", default=[], dist_reduce_fx=self._merge_ids)
        self.add_state("values", default=[], dist_reduce_fx="cat")

    @staticmethod
    def _merge_ids(gathered) -> list:
        """Flatten gathered ids into one reusable list.

        Nested iterables (e.g. one list of ids per process) are flattened by
        one level. Strings and non-iterable ids are kept whole, so a flat list
        of string ids is not split into characters. A list is returned (rather
        than a one-shot iterator) so the state can be read more than once.
        """
        merged = []
        for item in gathered:
            if isinstance(item, (str, bytes)) or not hasattr(item, "__iter__"):
                merged.append(item)
            else:
                merged.extend(item)
        return merged

    def func(self, preds, targets, *args, **kwargs) -> torch.Tensor:
        """Computes metrics for each element in the batch.

        Returns:
            torch.Tensor: A torch Tensor with first dimension being
                batch dimension (``Bx...``).
        """
        raise NotImplementedError

    def update(self, preds, targets, *args, ids=None, **kwargs):
        """Compute per-example values for a batch and record them.

        Args:
            preds: Predictions; must have the same shape as ``targets``.
            targets: Ground truth.
            *args: Forwarded to ``func``.
            ids: Optional identifiers for the examples in this batch. When
                ``None``, identifiers are generated automatically.
            **kwargs: Forwarded to ``func``.

        Raises:
            AssertionError: If ``preds`` and ``targets`` differ in shape.
        """
        # Kept as an assert so the exception type seen by callers is unchanged.
        assert preds.shape == targets.shape, (
            f"preds and targets must have the same shape, "
            f"got {tuple(preds.shape)} and {tuple(targets.shape)}"
        )

        values: torch.Tensor = self.func(preds, targets, *args, **kwargs)
        self.values.append(values)
        self._add_ids(ids=ids, num_samples=len(values))

    def _generate_ids(self, num_samples):
        """Create ``"<rank>-<index>"`` identifiers for ``num_samples`` examples."""
        # Offset by the ids already recorded. ``update`` appends the current
        # batch to ``self.values`` before ids are generated, so counting
        # ``self.values`` would include this batch and produce duplicate ids
        # whenever a later batch is smaller than an earlier one.
        id_start = len(self.ids)
        rank = comm.get_rank()
        return [f"{rank}-{id_start + idx}" for idx in range(num_samples)]

    def _add_ids(self, ids, num_samples):
        """Record ``ids`` for the current batch, generating them if ``None``."""
        if ids is None:
            ids = self._generate_ids(num_samples)
        self.ids.extend(ids)

    def compute(self, reduction=None):
        """Return the accumulated values after applying ``reduction``.

        Args:
            reduction: Reduction mode handed to ``reduce``. Defaults to
                ``self.reduction`` when ``None``.
        """
        if reduction is None:
            reduction = self.reduction
        # The state is a list of per-batch tensors locally, but may already be
        # a single tensor (handled the same way in ``_to_dict``).
        values = torch.cat(self.values) if isinstance(self.values, list) else self.values
        return reduce(values, reduction)

    def to_pandas(self, sync_dist: bool = True) -> pd.DataFrame:
        """Return the per-example values as a DataFrame (values moved to CPU)."""
        return pd.DataFrame.from_dict(self.to_dict(sync_dist=sync_dist, device="cpu"))

    def to_dict(self, sync_dist: bool = True, device=None):
        """Return ids and per-channel values as a dictionary.

        Args:
            sync_dist: If ``True``, build the dictionary inside
                ``self.sync_context()``.
            device: Optional device to move the values to.

        Returns:
            dict: ``"id"`` mapped to the ids plus one entry per channel.
        """
        if not sync_dist:
            return self._to_dict(device=device)
        with self.sync_context():
            return self._to_dict(device=device)

    def _to_dict(self, device=None):
        """Build the ``{"id": ..., <channel>: ...}`` dictionary from current state."""
        values = self.values
        if isinstance(values, list):
            if not values:
                return {"id": self.ids}
            values = torch.cat(values)
        elif values.numel() == 0:
            # Truth-testing a tensor is ambiguous, so test emptiness explicitly.
            return {"id": self.ids}

        if device is not None:
            values = values.to(device)

        channel_names = (
            self.channel_names
            if self.channel_names
            else [f"channel_{idx}" for idx in range(values.shape[1])]
        )
        data = {"id": self.ids}
        data.update({name: values[:, idx] for idx, name in enumerate(channel_names)})
        return data

    def _filter_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        """filter kwargs such that they match the update signature of the metric"""
        alias_map = self._update_kwargs_aliases
        if alias_map:
            filtered_kwargs = dict(kwargs)
            resolved = {}
            for alias, target in alias_map.items():
                # Skip aliases that were not passed, and targets that an
                # earlier alias already supplied (first alias wins).
                if alias not in kwargs or target in resolved:
                    continue
                resolved[target] = kwargs[alias]
            # Aliased values take precedence over same-named explicit kwargs.
            filtered_kwargs.update(resolved)
        else:
            filtered_kwargs = kwargs

        # Keep only kwargs named in the update signature; the module-level
        # helper passes everything along when nothing matches.
        filtered_kwargs = _filter_kwargs(self._update_signature, **filtered_kwargs)
        return filtered_kwargs

    def register_update_aliases(self, **kwargs):
        """Register aliases for keyword arguments when calling update.

        Args:
            **kwargs: ``alias=parameter_name`` pairs, where ``parameter_name``
                must be a named parameter of ``update``.

        Raises:
            ValueError: If a target is not a named parameter of ``update``.
        """
        # filter all parameters based on update signature except those of
        # type VAR_POSITIONAL (*args) and VAR_KEYWORD (**kwargs)
        _params = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        _sign_params = self._update_signature.parameters
        supported_kwargs = tuple(
            k for k, param in _sign_params.items() if param.kind not in _params
        )
        unsupported_kwargs = [v for v in kwargs.values() if v not in supported_kwargs]
        if unsupported_kwargs:
            raise ValueError(
                f"Found unsupported kwargs '{unsupported_kwargs}'. "
                f"Supported keyword arguments include:{supported_kwargs}"
            )

        self._update_kwargs_aliases.update(kwargs)

    def name(self):
        """Return the metric name (the class name)."""
        return type(self).__name__

    def display_name(self):
        """Name to use for pretty printing and display purposes."""
        name = self.name()
        return f"{name} ({self.units})" if self.units else name
def _filter_kwargs(sig, **kwargs: Any) -> Dict[str, Any]:
    """Keep only the keyword arguments that ``sig`` names explicitly.

    Args:
        sig: An ``inspect.Signature`` to filter against.
        **kwargs: Candidate keyword arguments.

    Returns:
        Dict[str, Any]: The entries of ``kwargs`` whose key is a parameter of
        ``sig`` other than ``*args``/``**kwargs``. If no entry matches, all of
        ``kwargs`` is returned unchanged.
    """
    # filter all parameters based on update signature except those of
    # type VAR_POSITIONAL (*args) and VAR_KEYWORD (**kwargs)
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    sign_params = sig.parameters

    filtered_kwargs = {
        key: value
        for key, value in kwargs.items()
        if key in sign_params and sign_params[key].kind not in variadic_kinds
    }

    # if no kwargs filtered, return all kwargs as default
    return filtered_kwargs if filtered_kwargs else kwargs