# Source code for lagrangebench.evaluate.metrics

"""Metrics for evaluation and testing."""

import warnings
from collections import defaultdict
from functools import partial
from typing import Callable, Dict, List, Optional

import jax
import jax.numpy as jnp
import numpy as np
from ott.geometry.geometry import Geometry
from ott.tools.sinkhorn_divergence import sinkhorn_divergence

# Nested mapping of metric results. Annotates the return value of
# ``MetricsComputer.__call__`` and the input of ``averaged_metrics`` (which
# iterates ``eval_metrics.values()`` and then each value's ``.items()``).
# NOTE(review): ``MetricsComputer.__call__`` stores arrays (and a dict for
# "e_kin") as values, so in practice the value types are looser than annotated.
MetricsDict = Dict[str, Dict[str, jnp.ndarray]]


class MetricsComputer:
    """
    Metrics between predicted and target rollouts.

    Currently implemented:

    * MSE, mean squared error
    * MAE, mean absolute error
    * Sinkhorn distance, measures the similarity of two particle distributions
    * Kinetic energy, physical quantity of interest
    """

    METRICS = ["mse", "mae", "sinkhorn", "e_kin"]

    def __init__(
        self,
        active_metrics: Optional[List],
        dist_fn: Callable,
        metadata: Dict,
        input_seq_length: int,
        stride: int = 10,
        loss_ranges: Optional[List] = None,
        ot_backend: str = "ott",
    ):
        """Init the metric computer.

        Args:
            active_metrics: Names of the metrics to compute. Each name must be an
                attribute of this class (see ``METRICS``). ``None`` means no
                metrics.
            dist_fn: Function of two positions whose output is squared and summed
                to obtain squared distances (e.g. a displacement function).
            metadata: Metadata of the dataset. The "e_kin" metric reads its
                "dt", "write_every", "dx" and "dim" entries.
            input_seq_length: Length of the input sequence.
            stride: Rollout subsample frequency for e_kin and sinkhorn.
            loss_ranges: Horizon lengths for the truncated mse/mae entries.
                Defaults to ``[1, 5, 10, 20, 50, 100]``.
            ot_backend: Backend for sinkhorn computation, "ott" or "pot".
        """
        if active_metrics is None:
            active_metrics = []
        assert all(hasattr(self, metric) for metric in active_metrics)
        assert ot_backend in ["ott", "pot"]

        self._active_metrics = active_metrics
        self._dist_fn = dist_fn
        # vmapped over the particles of a frame, then additionally over time
        self._dist_vmap = jax.vmap(dist_fn, in_axes=(0, 0))
        self._dist_dvmap = jax.vmap(self._dist_vmap, in_axes=(0, 0))

        if loss_ranges is None:
            loss_ranges = [1, 5, 10, 20, 50, 100]
        self._loss_ranges = loss_ranges
        self._input_seq_length = input_seq_length
        self._stride = stride
        self._metadata = metadata
        self.ot_backend = ot_backend

    def __call__(
        self, pred_rollout: jnp.ndarray, target_rollout: jnp.ndarray
    ) -> MetricsDict:
        """Compute the metrics between two rollouts.

        Args:
            pred_rollout: Predicted rollout.
            target_rollout: Target rollout, cast to the dtype of ``pred_rollout``.

        Returns:
            Dictionary of metrics keyed by metric name:

            * "mse"/"mae": one value per time step, plus "mse{i}"/"mae{i}"
              holding the first ``i`` steps for every ``i`` in ``loss_ranges``
              that is shorter than the rollout.
            * "e_kin": dict with the "predicted" and "target" kinetic-energy
              series (subsampled by ``stride``) and their "mse".
            * "sinkhorn": one value per ``stride``-th frame.
        """
        # both rollouts of shape (traj_len - t_window, n_nodes, dim)
        target_rollout = jnp.asarray(target_rollout, dtype=pred_rollout.dtype)

        metrics = {}
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for metric_name in self._active_metrics:
                metric_fn = getattr(self, metric_name)
                if metric_name in ["mse", "mae"]:
                    # full rollout loss, one value per time step
                    metrics[metric_name] = jax.vmap(metric_fn)(
                        pred_rollout, target_rollout
                    )
                    # shorter horizon losses
                    for i in self._loss_ranges:
                        if i < metrics[metric_name].shape[0]:
                            metrics[f"{metric_name}{i}"] = metrics[metric_name][:i]
                elif metric_name in ["e_kin"]:
                    e_kin_pred = self._kinetic_energy(pred_rollout)
                    e_kin_target = self._kinetic_energy(target_rollout)
                    metrics[metric_name] = {
                        "predicted": e_kin_pred,
                        "target": e_kin_target,
                        "mse": ((e_kin_pred - e_kin_target) ** 2).mean(),
                    }
                elif metric_name == "sinkhorn":
                    # scan instead of vmap: vmapping over distance matrix blows up
                    metrics[metric_name] = jax.lax.scan(
                        lambda _, x: (None, self.sinkhorn(*x)),
                        None,
                        (
                            pred_rollout[0 :: self._stride],
                            target_rollout[0 :: self._stride],
                        ),
                    )[1]
        return metrics

    def _kinetic_energy(self, rollout: jnp.ndarray) -> jnp.ndarray:
        """Kinetic-energy series of a rollout, subsampled by ``stride``.

        Velocities are ``dist_fn`` differences between steps ``t + 1`` and ``t``
        (for ``t = 0, stride, 2 * stride, ...``) divided by ``dt * write_every``.
        Per-particle ``e_kin`` values are summed over particles and scaled by
        ``dx ** dim``.
        """
        dt = self._metadata["dt"] * self._metadata["write_every"]
        dx = self._metadata["dx"]
        dim = self._metadata["dim"]
        velocity_rollout = self._dist_dvmap(
            rollout[1 :: self._stride],
            rollout[0 : -1 : self._stride],
        )
        # over time steps, then over particles
        e_kin_fn = jax.vmap(jax.vmap(self.e_kin))
        e_kin = e_kin_fn(velocity_rollout / dt).sum(1)
        return e_kin * dx**dim

    @partial(jax.jit, static_argnums=(0,))
    def mse(self, pred: jnp.ndarray, target: jnp.ndarray) -> float:
        """Compute the mean squared error between two rollouts."""
        return (self._dist_vmap(pred, target) ** 2).mean()

    @partial(jax.jit, static_argnums=(0,))
    def mae(self, pred: jnp.ndarray, target: jnp.ndarray) -> float:
        """Compute the mean absolute error between two rollouts."""
        return (jnp.abs(self._dist_vmap(pred, target))).mean()

    @partial(jax.jit, static_argnums=(0,))
    def sinkhorn(self, pred: jnp.ndarray, target: jnp.ndarray) -> float:
        """Compute the sinkhorn distance between two rollouts.

        Dispatches on ``ot_backend``: "ott" uses OTT, anything else uses POT.
        """
        if self.ot_backend == "ott":
            return self._sinkhorn_ott(pred, target)
        else:
            return self._sinkhorn_pot(pred, target)

    @partial(jax.jit, static_argnums=(0,))
    def e_kin(self, frame: jnp.ndarray) -> float:
        """Compute the (unnormalized) kinetic energy of a frame.

        Returns the sum of squared entries of ``frame``. No 1/2 factor or mass is
        applied, and the ``dx ** dim`` volume scaling is done in
        ``_kinetic_energy``.
        """
        return jnp.sum(frame**2)

    def _sinkhorn_ott(self, pred: jnp.ndarray, target: jnp.ndarray) -> float:
        """Sinkhorn divergence via OTT with uniform weights.

        The cost matrices are the squared pairwise distances from
        ``_distance_matrix``.
        """
        # pairwise distances as cost
        loss_matrix_xy = self._distance_matrix(pred, target)
        loss_matrix_xx = self._distance_matrix(pred, pred)
        loss_matrix_yy = self._distance_matrix(target, target)
        return sinkhorn_divergence(
            Geometry,
            loss_matrix_xy,
            loss_matrix_xx,
            loss_matrix_yy,
            # uniform weights
            a=jnp.ones((pred.shape[0],)) / pred.shape[0],
            b=jnp.ones((target.shape[0],)) / target.shape[0],
            sinkhorn_kwargs={"threshold": 1e-4},
        ).divergence

    def _sinkhorn_pot(self, pred: jnp.ndarray, target: jnp.ndarray):
        """Jax-compatible POT implementation of Sinkhorn.

        Computes ``S(pred, target) - 0.5 * (S(pred, pred) + S(target, target))``,
        clipped at 0 and cast to float32.
        """
        # equivalent to empirical_sinkhorn_divergence with custom distance computation
        sinkhorn_ab = self._custom_empirical_sinkorn_pot(pred, target)
        sinkhorn_a = self._custom_empirical_sinkorn_pot(pred, pred)
        sinkhorn_b = self._custom_empirical_sinkorn_pot(target, target)
        return jnp.asarray(
            jnp.clip(sinkhorn_ab - 0.5 * (sinkhorn_a + sinkhorn_b), 0),
            dtype=jnp.float32,
        )

    def _custom_empirical_sinkorn_pot(self, pred: jnp.ndarray, target: jnp.ndarray):
        """Entropic OT cost between two uniformly weighted point clouds.

        Calls POT's ``sinkhorn2`` (reg=0.1, numItermax=500, stopThr=1e-05) on the
        squared distance matrix through ``jax.pure_callback``.
        """
        from ot.bregman import sinkhorn2

        # weights are uniform
        a, b = (
            jnp.ones((pred.shape[0],)) / pred.shape[0],
            jnp.ones((target.shape[0],)) / target.shape[0],
        )
        loss_matrix = self._distance_matrix(pred, target)
        shape = jax.ShapeDtypeStruct((), dtype=jnp.float32)

        # hack to avoid CpuCallback attribute error
        def sinkhorn2_(a, b, loss_matrix):
            return jnp.array(
                sinkhorn2(a, b, loss_matrix, reg=0.1, numItermax=500, stopThr=1e-05),
                dtype=jnp.float32,
            )

        return jax.pure_callback(
            sinkhorn2_,
            shape,
            a,
            b,
            loss_matrix,
        )

    def _distance_matrix(
        self, x: jnp.ndarray, y: jnp.ndarray, squared=True
    ) -> jnp.ndarray:
        """Euclidean distance matrix (pairwise).

        Entry ``[i, j]`` is ``sum(dist_fn(x[i], y[j]) ** 2)``, or its square root
        when ``squared`` is False. The result is cast to float32.
        """

        def squared_dist(a, b):
            return jnp.sum(self._dist_fn(a, b) ** 2)

        if squared:
            dist = squared_dist
        else:
            # Must not refer to ``dist`` itself: redefining ``dist`` in terms of
            # ``dist`` made the closure call itself (infinite recursion).
            def dist(a, b):
                return jnp.sqrt(squared_dist(a, b))

        return jnp.array(
            jax.vmap(lambda a: jax.vmap(lambda b: dist(a, b))(y))(x), dtype=jnp.float32
        )
def averaged_metrics(eval_metrics: MetricsDict) -> Dict[str, float]:
    """Averages the metrics over the rollouts.

    Args:
        eval_metrics: Mapping from rollout identifier to that rollout's metrics
            dictionary (as produced by ``MetricsComputer.__call__``).

    Returns:
        Flat dictionary with ``"val/{name}"`` holding the mean across rollouts
        and ``"val/std{name}"`` holding the standard deviation across rollouts.
        Each rollout contributes the mean of its metric values. "e_kin" is
        reduced to its "mse" entry. "mse" is reported as "loss". "mae" is also
        reported as "loss" unless the rollout also contains "mse"; in that case
        it keeps its own name, so the two errors are never pooled together.
    """
    # per-metric list of per-rollout means
    trajectory_averages = defaultdict(list)
    for rollout in eval_metrics.values():
        for k, v in rollout.items():
            if k == "e_kin":
                v = v["mse"]
            if k == "mse" or (k == "mae" and "mse" not in rollout):
                k = "loss"
            trajectory_averages[k].append(jnp.mean(v).item())

    # mean and std values across rollouts
    small_metrics = {
        f"val/{k}": float(np.mean(v)) for k, v in trajectory_averages.items()
    }
    small_metrics.update(
        {f"val/std{k}": float(np.std(v)) for k, v in trajectory_averages.items()}
    )
    return small_metrics