Evaluate
Rollout
Evaluation and inference functions for generating rollouts.
- lagrangebench.evaluate.rollout.eval_rollout(model_apply: Callable, case, params: Mapping[str, Mapping[str, Array]], state: Mapping[str, Mapping[str, Array]], loader_eval: Iterable, neighbors: NeighborList, metrics_computer: MetricsComputer, n_rollout_steps: int, n_trajs: int, rollout_dir: str, out_type: str = 'none', n_extrap_steps: int = 0) → Dict[str, Dict[str, Array]]
Compute the rollout and evaluate the metrics.
- Parameters:
model_apply – Model function.
case – CaseSetupFn class.
params – Haiku params.
state – Haiku state.
loader_eval – Evaluation data loader.
neighbors – Neighbor list.
metrics_computer – MetricsComputer with the desired metrics.
n_rollout_steps – Number of rollout steps.
n_trajs – Number of ground truth trajectories to evaluate.
rollout_dir – Parent directory in which to store the rollout and the metrics dict.
out_type – Output type. Either “none”, “vtk” or “pkl”.
n_extrap_steps – Number of extrapolation steps (beyond the ground truth rollout).
- Returns:
Metrics per trajectory.
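The rollout evaluated here is autoregressive: each predicted frame is fed back as input for the next step. The following is a minimal NumPy sketch of that loop. It assumes a hypothetical `step_fn` that maps a window of past positions to the next positions. This is not the library's `model_apply` signature, and the toy `constant_velocity` predictor is only a stand-in for a trained model.

```python
from typing import Callable

import numpy as np


def rollout(step_fn: Callable[[np.ndarray], np.ndarray],
            history: np.ndarray,
            n_rollout_steps: int) -> np.ndarray:
    """Autoregressively unroll a one-step predictor (illustrative sketch).

    history: past positions, shape (input_seq_length, n_particles, dim).
    step_fn: hypothetical predictor mapping the history window to the next
        positions, shape (n_particles, dim).
    Returns predicted positions, shape (n_rollout_steps, n_particles, dim).
    """
    window = history.copy()
    predictions = []
    for _ in range(n_rollout_steps):
        next_positions = step_fn(window)
        predictions.append(next_positions)
        # Slide the window: drop the oldest frame, append the new prediction.
        window = np.concatenate([window[1:], next_positions[None]], axis=0)
    return np.stack(predictions)


def constant_velocity(window: np.ndarray) -> np.ndarray:
    """Toy stand-in for a learned model: repeat the last displacement."""
    return window[-1] + (window[-1] - window[-2])


history = np.array([[[0.0]], [[1.0]]])  # one particle in 1D at x=0, then x=1
pred = rollout(constant_velocity, history, n_rollout_steps=3)
print(pred[:, 0, 0])  # [2. 3. 4.]
```

In this sketch, extrapolating beyond the ground truth (as `n_extrap_steps` suggests) would just mean running the loop longer. Errors compound because every step consumes earlier predictions.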
- lagrangebench.evaluate.rollout.infer(model: TransformedWithState, case, data_test: H5Dataset, params: Mapping[str, Mapping[str, Array]] | None = None, state: Mapping[str, Mapping[str, Array]] | None = None, load_ckp: str | None = None, cfg_eval_infer: Dict | DictConfig = {'n_trajs': -1, 'metrics_stride': 1, 'batch_size': 2, 'metrics': ['mse', 'e_kin', 'sinkhorn'], 'out_type': 'pkl', 'n_extrap_steps': 0}, rollout_dir: str | None = None, n_rollout_steps: int = 20, seed: int = 0)
Run inference on a dataset, compute the metrics, and optionally save the rollouts in the out_type format.
- Parameters:
model – (Transformed) Haiku model.
case – Case setup class.
data_test – Test dataset.
params – Haiku params.
state – Haiku state.
load_ckp – Path to checkpoint directory.
rollout_dir – Path to rollout directory.
cfg_eval_infer – Evaluation configuration for inference mode.
n_rollout_steps – Number of rollout steps.
seed – Seed.
- Returns:
eval_metrics – Metrics per trajectory.
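Both `eval_rollout` and `infer` return metrics per trajectory as a nested dictionary. The sketch below collapses such a result into one mean per metric. It assumes a flat layout `{trajectory_key: {metric_name: array}}` where every leaf has the same shape across trajectories. The keys `traj_0`/`traj_1` are hypothetical, and metrics that are themselves nested dicts would need a recursive version.

```python
import numpy as np


def average_over_trajectories(eval_metrics):
    """Average each metric over trajectories (illustrative sketch).

    Assumes {trajectory_key: {metric_name: array}} with equal leaf shapes;
    the key names below are hypothetical, not the library's exact names.
    """
    collected = {}
    for traj_metrics in eval_metrics.values():
        for name, value in traj_metrics.items():
            collected.setdefault(name, []).append(np.asarray(value))
    # Stack along a new trajectory axis and average it out.
    return {name: np.mean(np.stack(values), axis=0) for name, values in collected.items()}


eval_metrics = {
    "traj_0": {"mse": np.array([0.1, 0.2]), "e_kin": np.array([1.0, 1.0])},
    "traj_1": {"mse": np.array([0.3, 0.4]), "e_kin": np.array([3.0, 1.0])},
}
summary = average_over_trajectories(eval_metrics)
# summary["e_kin"] -> array([2., 1.])
```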
Metrics
Metrics for evaluation and testing.
- class lagrangebench.evaluate.metrics.MetricsComputer(active_metrics: List, dist_fn: Callable, metadata: Dict, input_seq_length: int, stride: int = 10, loss_ranges: List | None = None, ot_backend: str = 'ott')
Metrics between predicted and target rollouts.
Currently implemented:
* MSE, mean squared error
* MAE, mean absolute error
* Sinkhorn distance, measures the similarity of two particle distributions
* Kinetic energy, physical quantity of interest
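The MSE and kinetic-energy entries above can be illustrated with plain NumPy on rollouts of shape (T, N, D). The `dist_fn` argument suggests that distances may respect periodic boundaries. This sketch therefore substitutes a hypothetical minimum-image `displacement` helper rather than the library's `dist_fn`. It also estimates velocities by finite differences, assuming unit mass and unit time step.

```python
import numpy as np


def displacement(a, b, box_size=None):
    """a - b, optionally wrapped with the minimum-image convention (hypothetical helper)."""
    d = a - b
    if box_size is not None:
        d = d - box_size * np.round(d / box_size)
    return d


def mse(pred, target, box_size=None):
    """Mean squared error per time step for rollouts of shape (T, N, D)."""
    return np.mean(displacement(pred, target, box_size) ** 2, axis=(1, 2))


def kinetic_energy(positions, dt=1.0, mass=1.0):
    """Kinetic energy per step from finite-difference velocities.

    positions: (T, N, D); returns shape (T - 1,). Unit mass and time step
    are assumptions made for illustration.
    """
    velocities = np.diff(positions, axis=0) / dt
    return 0.5 * mass * np.sum(velocities ** 2, axis=(1, 2))
```

With `box_size=1.0`, particles at x=0.95 and x=0.05 are treated as 0.1 apart rather than 0.9, which is why periodic datasets need a boundary-aware distance.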
- __init__(active_metrics: List, dist_fn: Callable, metadata: Dict, input_seq_length: int, stride: int = 10, loss_ranges: List | None = None, ot_backend: str = 'ott')
Initialize the metrics computer.
- Parameters:
active_metrics – List of metrics to compute.
dist_fn – Distance function.
metadata – Metadata of the dataset.
loss_ranges – List of horizon lengths to compute the loss for.
input_seq_length – Length of the input sequence.
stride – Rollout subsampling stride (every stride-th step is used) for e_kin and sinkhorn.
ot_backend – Backend for sinkhorn computation. “ott” or “pot”.
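The actual Sinkhorn computation is delegated to the "ott" or "pot" backend, which may report a debiased Sinkhorn divergence or use different cost and regularization settings. For intuition, the following sketch shows textbook Sinkhorn scaling between two equally weighted particle clouds with a squared Euclidean cost. It is not the library's implementation.

```python
import numpy as np


def sinkhorn_cost(x, y, epsilon=0.05, n_iters=200):
    """Entropy-regularized optimal transport cost between two point clouds.

    x: (N, D), y: (M, D); every particle carries equal mass. Uses a squared
    Euclidean ground cost and plain (not log-domain) Sinkhorn scaling, which
    can underflow for very small epsilon.
    """
    a = np.full(x.shape[0], 1.0 / x.shape[0])
    b = np.full(y.shape[0], 1.0 / y.shape[0])
    cost = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    kernel = np.exp(-cost / epsilon)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iters):
        u = a / (kernel @ v)      # match row marginals
        v = b / (kernel.T @ u)    # match column marginals
    plan = u[:, None] * kernel * v[None, :]
    return float(np.sum(plan * cost))
```

Smaller `epsilon` brings the result closer to the unregularized transport cost, but it needs more iterations and, without log-domain updates, risks underflow.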
- __call__(pred_rollout: Array, target_rollout: Array) → Dict[str, Dict[str, Array]]
Compute the metrics between two rollouts.
- Parameters:
pred_rollout – Predicted rollout.
target_rollout – Target rollout.
- Returns:
Dictionary of metrics.
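The sketch below shows how `loss_ranges` and `stride` could shape a nested result of type Dict[str, Dict[str, Array]] for a single trajectory pair. MSE is averaged over the full rollout and over each horizon in `loss_ranges` that fits, and kinetic energy is kept every `stride`-th step. The function name and the keys ("full", "mse{h}", "predicted", "target") are hypothetical, not the library's output format.

```python
import numpy as np


def compute_metrics(pred_rollout, target_rollout, loss_ranges=(1, 5), stride=2):
    """Hypothetical per-trajectory metrics dict, shaped like Dict[str, Dict[str, Array]].

    Rollouts have shape (T, N, D). MSE is averaged over the full rollout and
    over each horizon in loss_ranges that fits; kinetic energy (unit mass,
    unit time step) is reported every `stride`-th step.
    """
    per_step_mse = np.mean((pred_rollout - target_rollout) ** 2, axis=(1, 2))
    mse = {"full": per_step_mse.mean()}
    for horizon in loss_ranges:
        if horizon <= per_step_mse.shape[0]:
            mse[f"mse{horizon}"] = per_step_mse[:horizon].mean()

    def e_kin(positions):
        # Finite-difference velocities, then 0.5 * |v|^2 summed over particles.
        velocities = np.diff(positions, axis=0)
        return 0.5 * np.sum(velocities ** 2, axis=(1, 2))

    return {
        "mse": mse,
        "e_kin": {
            "predicted": e_kin(pred_rollout)[::stride],
            "target": e_kin(target_rollout)[::stride],
        },
    }
```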
- mae(pred: Array, target: Array) → float
Compute the mean absolute error between two rollouts.
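MAE averages absolute errors, so it penalizes a large deviation linearly, whereas MSE penalizes it quadratically. The following minimal NumPy sketch, which is not the library's code, reduces over every axis to a single float and compares the two metrics on one outlier.

```python
import numpy as np


def mae(pred: np.ndarray, target: np.ndarray) -> float:
    """Mean absolute error over all time steps, particles and dimensions."""
    return float(np.mean(np.abs(pred - target)))


def mse(pred: np.ndarray, target: np.ndarray) -> float:
    """Mean squared error, for comparison."""
    return float(np.mean((pred - target) ** 2))


target = np.zeros(4)
pred = np.array([0.0, 0.0, 0.0, 4.0])  # a single large error
print(mae(pred, target), mse(pred, target))  # 1.0 4.0
```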
Utils
Utility functions for evaluation.
- lagrangebench.evaluate.utils.pkl2vtk(src_path, dst_path=None)
Convert a rollout pickle file to a set of vtk files.
- Parameters:
src_path (str) – Source path to .pkl file.
dst_path (str, optional) – Destination directory path. Defaults to None, in which case the vtk files are saved in the same directory as the pkl file.
Example
pkl2vtk("rollout/test/rollout_0.pkl", "rollout/test_vtk") creates the files rollout_0_0.vtk, rollout_0_1.vtk, etc. in the directory "rollout/test_vtk".
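The internals of `pkl2vtk` are not documented here. To illustrate the per-frame naming in the example, the following stdlib-only sketch writes each time step to a legacy ASCII VTK polydata file. It assumes the rollout has already been loaded as a sequence of frames of (x, y[, z]) positions. The pickle's actual layout and the library's writer are not assumed, and `write_vtk_points`/`rollout_to_vtk` are hypothetical names.

```python
from pathlib import Path


def write_vtk_points(path, points):
    """Write one frame of particle positions as a legacy ASCII VTK polydata file."""
    n = len(points)
    lines = [
        "# vtk DataFile Version 3.0",
        "particle positions",
        "ASCII",
        "DATASET POLYDATA",
        f"POINTS {n} float",
    ]
    for p in points:
        xyz = [float(c) for c in p] + [0.0] * (3 - len(p))  # pad 2D points with z = 0
        lines.append(" ".join(f"{c:g}" for c in xyz))
    # One vertex cell per particle so viewers such as ParaView render the points.
    lines.append(f"VERTICES {n} {2 * n}")
    lines.extend(f"1 {i}" for i in range(n))
    Path(path).write_text("\n".join(lines) + "\n")


def rollout_to_vtk(rollout, dst_dir, stem="rollout_0"):
    """Write one file per time step, named <stem>_<t>.vtk, and return the paths."""
    dst = Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    paths = []
    for t, frame in enumerate(rollout):
        path = dst / f"{stem}_{t}.vtk"
        write_vtk_points(path, frame)
        paths.append(path)
    return paths
```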