Source code for meddlr.metrics.collection

from typing import Any, Dict, List, Optional, Sequence, Set, Union

import numpy as np
import pandas as pd
import tabulate
from torchmetrics.collections import MetricCollection as _MetricCollection

from meddlr.metrics.metric import Metric

__all__ = ["MetricCollection"]


[docs]class MetricCollection(_MetricCollection): """The class that manages multiple metrics."""
[docs] def __init__( self, metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]], *additional_metrics: Metric, prefix: Optional[str] = None, postfix: Optional[str] = None ) -> None: super().__init__(metrics, *additional_metrics, prefix=prefix, postfix=postfix) self._is_data_stale = False
def scans(self) -> List[str]: return list(self._scan_data.keys()) def scan_summary(self, scan_id, delimiter: str = ", ") -> str: """Get summary of results for a scan. Args: scan_id: Scan id for which to summarize results. delimiter (`str`, optional): Delimiter between different metrics. Returns: str: A summary of metrics for the scan. Values are averaged across all categories. """ scan_data = self._scan_data[scan_id] avg_data = scan_data.mean(axis=1) strs = ["{}: {:0.3f}".format(n, avg_data[n]) for n in avg_data.index.tolist()] return delimiter.join(strs) def to_pandas(self, sync_dist: bool = True) -> pd.DataFrame: frames = [] metric: Metric for name, metric in self.items(): df: pd.DataFrame = metric.to_pandas(sync_dist=sync_dist) df["Metric"] = name frames.append(df) return pd.concat(frames, ignore_index=True) def to_dict(self, group_by="Metric", sync_dist: bool = True) -> Dict[str, Any]: df = self.to_pandas(sync_dist=sync_dist) df = df.melt(id_vars=["Metric", "id"], var_name="category", value_name="value") if len(np.unique(df["category"])) > 1: df["Metric"] = df["Metric"] + "/" + df["category"] df = df.drop(columns="category") values = df.groupby(by=group_by).mean(numeric_only=True) return values.to_dict()["value"] def summary(self, sync_dist: bool = True) -> str: """Get summary of results overall scans. Returns: str: Tabulated summary. Rows=metrics. Columns=classes. """ df = self.to_pandas(sync_dist=sync_dist) if "id" in df: df = df.drop(columns="id") df = df.groupby(by="Metric") mean = df.mean().applymap(lambda x: "{:0.3f}".format(x)) std = df.std().applymap(lambda x: "{:0.3f}".format(x)) df = mean + " (" + std + ")" return tabulate.tabulate(df, headers=df.columns) + "\n" def ids(self, sync_dist=True) -> Set[str]: _ids = set() metric: Metric for _, metric in self.items(): _ids |= set(metric.to_pandas(sync_dist=sync_dist)["id"].to_numpy()) return _ids