meddlr.metrics.RMSE

class meddlr.metrics.RMSE(im_type: Optional[str] = None, channel_names: Optional[Sequence[str]] = None, reduction='none', compute_on_step: bool = False, dist_sync_on_step: bool = False, process_group: Optional[bool] = None, dist_sync_fn: Optional[bool] = None)

Root-mean-squared error with complex-valued support.

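\(RMSE = \sqrt{\frac{1}{N} \sum_{i=1}^{N} |x_{pred,i} - x_{gt,i}|^2} = \frac{||x_{pred} - x_{gt}||_2}{\sqrt{N}}\), where \(N\) is the number of elements compared and \(|\cdot|\) is the modulus for complex values.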
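As a quick check of the definition, the pure-Python sketch below computes RMSE over a flat list of possibly complex values and confirms that it equals the L2 norm of the error divided by \(\sqrt{N}\). It is a reference computation for illustration, not meddlr's implementation; the function name rmse_reference is hypothetical.

```python
import math

def rmse_reference(pred, gt):
    """Root-mean-squared error between two equal-length sequences.

    Works for real or complex entries: abs() returns the modulus of the
    complex difference, so abs(d) ** 2 == d.real ** 2 + d.imag ** 2.
    """
    if len(pred) != len(gt) or not pred:
        raise ValueError("inputs must be non-empty and the same length")
    squared_errors = [abs(p - g) ** 2 for p, g in zip(pred, gt)]
    return math.sqrt(sum(squared_errors) / len(squared_errors))

pred = [1 + 1j, 2 + 0j, 0 - 1j, 3 + 2j]
gt = [1 + 0j, 2 + 0j, 0 + 1j, 3 + 2j]

# Errors are 1j, 0, -2j, 0 -> squared moduli 1, 0, 4, 0 -> mean 1.25.
print(rmse_reference(pred, gt))  # sqrt(1.25), about 1.118

# The same value via the L2 norm of the error scaled by 1/sqrt(N).
l2 = math.sqrt(sum(abs(p - g) ** 2 for p, g in zip(pred, gt)))
print(l2 / math.sqrt(len(pred)))
```

Dividing by \(N\) inside the square root is what distinguishes RMSE from the plain L2 norm of the error; without it, the value would grow with image size.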

This implementation supports complex tensors. im_type controls how complex tensors are processed:

  • 'magnitude': \(x_{pred}\) and \(x_{gt}\) are converted to magnitude images.

  • 'phase': \(x_{pred}\) and \(x_{gt}\) are converted to phase images.

  • 'real': Real components of \(x_{pred}\) and \(x_{gt}\) are used.

  • 'imag': Imaginary components of \(x_{pred}\) and \(x_{gt}\) are used.
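The four im_type options above amount to an elementwise conversion applied to both inputs before the error is computed. The stdlib sketch below illustrates that idea on lists of Python complex numbers. The names IM_TYPE_TRANSFORMS and rmse_with_im_type are hypothetical, and treating im_type=None as "use the complex values unchanged" is an assumption made for this sketch, not documented meddlr behavior.

```python
import cmath
import math

# Hypothetical mapping from im_type to an elementwise conversion of a complex value.
IM_TYPE_TRANSFORMS = {
    "magnitude": abs,          # |z|
    "phase": cmath.phase,      # angle of z in radians, in [-pi, pi]
    "real": lambda z: z.real,  # Re(z)
    "imag": lambda z: z.imag,  # Im(z)
}

def rmse_with_im_type(pred, gt, im_type=None):
    """RMSE after converting both inputs according to im_type.

    Assumption: im_type=None keeps the complex values, so the modulus of the
    complex difference is used.
    """
    if im_type is not None:
        if im_type not in IM_TYPE_TRANSFORMS:
            raise ValueError(f"unknown im_type: {im_type!r}")
        convert = IM_TYPE_TRANSFORMS[im_type]
        pred = [convert(complex(p)) for p in pred]
        gt = [convert(complex(g)) for g in gt]
    squared = [abs(p - g) ** 2 for p, g in zip(pred, gt)]
    return math.sqrt(sum(squared) / len(squared))

pred = [3 + 4j, 1j]
gt = [3 + 0j, -1j]
for im_type in ("magnitude", "phase", "real", "imag"):
    print(im_type, rmse_with_im_type(pred, gt, im_type))
```

With these inputs, the real components agree exactly (RMSE 0.0), while the magnitude, phase, and imaginary views all report nonzero errors, which shows that the choice of im_type can change the result substantially.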

Variables
  • im_type (str) – The type of the complex image to compute the metric on. This only applies to complex tensors.

  • channel_names (Sequence[str]) – The names of the channels in the input.
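A minimal sketch of per-channel evaluation, assuming that the metric is computed independently for each input channel and that channel_names labels those per-channel values. That assumption, the helper name per_channel_rmse, the default "channel_i" labels, and the example names "echo1"/"echo2" are all hypothetical and not taken from meddlr.

```python
import math

def per_channel_rmse(pred, gt, channel_names=None):
    """Compute RMSE separately for each channel.

    pred and gt are lists of channels; each channel is a flat list of values.
    Returns a dict mapping a channel label to that channel's RMSE.
    """
    if len(pred) != len(gt):
        raise ValueError("pred and gt must have the same number of channels")
    if channel_names is None:
        # Hypothetical default labels when no names are given.
        channel_names = [f"channel_{i}" for i in range(len(pred))]
    if len(channel_names) != len(pred):
        raise ValueError("need exactly one name per channel")
    results = {}
    for name, p_ch, g_ch in zip(channel_names, pred, gt):
        squared = [abs(p - g) ** 2 for p, g in zip(p_ch, g_ch)]
        results[name] = math.sqrt(sum(squared) / len(squared))
    return results

# Two channels of a toy image, each flattened to 4 values.
pred = [[1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 0.0, 0.0]]
gt = [[1.0, 2.0, 3.0, 4.0], [1.0, 1.0, 1.0, 1.0]]
print(per_channel_rmse(pred, gt, channel_names=["echo1", "echo2"]))
# {'echo1': 0.0, 'echo2': 1.0}
```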

__init__(im_type: Optional[str] = None, channel_names: Optional[Sequence[str]] = None, reduction='none', compute_on_step: bool = False, dist_sync_on_step: bool = False, process_group: Optional[bool] = None, dist_sync_fn: Optional[bool] = None)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Methods

__init__([im_type, channel_names, ...])

Initializes internal Module state, shared by both nn.Module and ScriptModule.

add_module(name, module)

Adds a child module to the current module.

add_state(name, default[, dist_reduce_fx, ...])

Adds metric state variable.

apply(fn)

Applies fn recursively to every submodule (as returned by .children()) as well as self.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Returns an iterator over module buffers.

children()

Returns an iterator over immediate children modules.

clone()

Make a copy of the metric.

compute([reduction])

Override this method to compute the final metric value from state variables synchronized across the distributed backend.

cpu()

Moves all model parameters and buffers to the CPU.

cuda([device])

Moves all model parameters and buffers to the GPU.

display_name()

Name to use for pretty printing and display purposes.

double()

Overrides the default method to prevent dtype casting.

eval()

Sets the module in evaluation mode.

extra_repr()

Sets the extra representation of the module.

float()

Overrides the default method to prevent dtype casting.

forward(*args, **kwargs)

forward serves the dual purpose of computing the metric on the current batch of inputs and adding the batch statistics to the overall accumulating metric state.

func(preds, targets)

Computes metrics for each element in the batch.

get_buffer(target)

Returns the buffer given by target if it exists, otherwise throws an error.

get_extra_state()

Returns any extra state to include in the module's state_dict.

get_parameter(target)

Returns the parameter given by target if it exists, otherwise throws an error.

get_submodule(target)

Returns the submodule given by target if it exists, otherwise throws an error.

half()

Overrides the default method to prevent dtype casting.

ipu([device])

Moves all model parameters and buffers to the IPU.

load_state_dict(state_dict[, strict])

Copies parameters and buffers from state_dict into this module and its descendants.

modules()

Returns an iterator over all modules in the network.

name()

named_buffers([prefix, recurse, ...])

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Returns an iterator over module parameters.

persistent([mode])

Post-init method to change whether metric states are saved to the state_dict.

register_backward_hook(hook)

Registers a backward hook on the module.

register_buffer(name, tensor[, persistent])

Adds a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Registers a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Registers a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Registers a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Registers a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Registers a post hook to be run after module's load_state_dict is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Adds a parameter to the module.

register_state_dict_pre_hook(hook)

Registers a hook that is called with the arguments self, prefix, and keep_vars before state_dict is called on self.

register_update_aliases(**kwargs)

Register aliases for keyword arguments when calling update.

requires_grad_([requires_grad])

Changes whether autograd should record operations on parameters in this module.

reset()

Resets the metric state variables to their default values.

set_dtype(dst_type)

Special version of type for transferring all metric states to a specific dtype; dst_type (type or string) is the desired type.

set_extra_state(state)

This function is called from load_state_dict() to handle any extra state found within the state_dict.

share_memory()

See torch.Tensor.share_memory_()

state_dict([destination, prefix, keep_vars])

Returns a dictionary containing references to the whole state of the module.

sync([dist_sync_fn, process_group, ...])

Sync function for manually controlling when metric states should be synced across processes.

sync_context([dist_sync_fn, process_group, ...])

Context manager to synchronize the states between processes when running in a distributed setting and restore the local cache states after yielding.

to(*args, **kwargs)

Moves and/or casts the parameters and buffers.

to_dict([sync_dist, device])

to_empty(*, device)

Moves the parameters and buffers to the specified device without copying storage.

to_pandas([sync_dist])

train([mode])

Sets the module in training mode.

type(dst_type)

Overrides the default method to prevent dtype casting.

unsync([should_unsync])

Unsync function for manually controlling when metric states should be reverted to their local states.

update(preds, targets, *args[, ids])

Override this method to update the state variables of your metric class.

xpu([device])

Moves all model parameters and buffers to the XPU.

zero_grad([set_to_none])

Sets gradients of all model parameters to zero.
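The forward, update, compute, and reset entries in the table above describe a stateful pattern: forward returns the metric for the current batch and also folds that batch into the accumulated state, which compute later aggregates. The toy class below mimics that pattern in plain Python. ToyRMSEMetric and _rmse are hypothetical names, storing per-example values in a list is a simplification, and aggregating by the mean over all examples is an assumption chosen for illustration; it is not how meddlr or torchmetrics implement these methods.

```python
import math

def _rmse(pred, gt):
    # RMSE of one example: square root of the mean squared error modulus.
    squared = [abs(p - g) ** 2 for p, g in zip(pred, gt)]
    return math.sqrt(sum(squared) / len(squared))

class ToyRMSEMetric:
    """Hypothetical stand-in for the update/compute/forward/reset pattern.

    State is a list of per-example RMSE values accumulated across batches.
    """

    def __init__(self):
        self.values = []

    def update(self, preds, targets):
        # Add per-example values for this batch to the running state.
        self.values.extend(_rmse(p, t) for p, t in zip(preds, targets))

    def compute(self):
        # Aggregate the accumulated state (assumed: mean over all examples).
        return sum(self.values) / len(self.values)

    def reset(self):
        # Restore the state to its default (empty) value.
        self.values = []

    def forward(self, preds, targets):
        # Dual purpose: return the metric for this batch only, and also
        # add the batch to the accumulated state.
        batch_values = [_rmse(p, t) for p, t in zip(preds, targets)]
        self.values.extend(batch_values)
        return sum(batch_values) / len(batch_values)

metric = ToyRMSEMetric()
batch1 = metric.forward([[0.0, 0.0]], [[1.0, 1.0]])  # one example, RMSE 1.0
batch2 = metric.forward([[0.0, 0.0]], [[3.0, 3.0]])  # one example, RMSE 3.0
print(batch1, batch2, metric.compute())  # 1.0 3.0 2.0
```

Keeping per-example values rather than a running sum makes it possible to report per-example results later, at the cost of memory that grows with the number of examples seen.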

Attributes

T_destination

alias of TypeVar('T_destination', bound=Dict[str, Any])

call_super_init

device

Returns the device of the metric.

dump_patches

full_state_update

higher_is_better

is_differentiable