Evaluate

Rollout

Evaluation and inference functions for generating rollouts.

lagrangebench.evaluate.rollout.eval_rollout(model_apply: Callable, case, params: Mapping[str, Mapping[str, Array]], state: Mapping[str, Mapping[str, Array]], loader_eval: Iterable, neighbors: NeighborList, metrics_computer: MetricsComputer, n_rollout_steps: int, n_trajs: int, rollout_dir: str, out_type: str = 'none', n_extrap_steps: int = 0) -> Dict[str, Dict[str, Array]]

Compute the rollout and evaluate the metrics.

Parameters:
  • model_apply – Model function.

  • case – CaseSetupFn class.

  • params – Haiku params.

  • state – Haiku state.

  • loader_eval – Evaluation data loader.

  • neighbors – Neighbor list.

  • metrics_computer – MetricsComputer with the desired metrics.

  • n_rollout_steps – Number of rollout steps.

  • n_trajs – Number of ground truth trajectories to evaluate.

  • rollout_dir – Parent directory path where to store the rollout and metrics dict.

  • out_type – Output type. Either “none”, “vtk” or “pkl”.

  • n_extrap_steps – Number of extrapolation steps (beyond the ground truth rollout).

Returns:

Metrics per trajectory.
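As a rough illustration of how n_rollout_steps and n_extrap_steps relate, the sketch below runs a generic autoregressive loop in NumPy. It is not the library's implementation: rollout_sketch, step_fn and drift are hypothetical names, and the neighbor list, case setup and Haiku params/state are left out. It also assumes that the extrapolation steps are simply appended after the regular rollout steps.

```python
import numpy as np


def rollout_sketch(step_fn, initial_positions, n_rollout_steps, n_extrap_steps=0):
    """Apply step_fn autoregressively and stack the predicted frames.

    step_fn is a hypothetical stand-in for a learned one-step model:
    it maps the current positions to the next positions.
    """
    frames = []
    current = np.asarray(initial_positions, dtype=float)
    # Assumption: extrapolation steps simply continue the same loop.
    for _ in range(n_rollout_steps + n_extrap_steps):
        current = step_fn(current)  # feed the prediction back in
        frames.append(current)
    return np.stack(frames)


def drift(positions):
    """Toy dynamics: every particle moves by 0.1 per step."""
    return positions + 0.1


pred = rollout_sketch(drift, np.zeros((4, 2)), n_rollout_steps=20, n_extrap_steps=5)
print(pred.shape)  # (steps, particles, dims)
```

Because the extrapolation steps lie beyond the ground truth rollout, only the first n_rollout_steps frames of such a prediction could be compared against a target.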

lagrangebench.evaluate.rollout.infer(model: TransformedWithState, case, data_test: H5Dataset, params: Mapping[str, Mapping[str, Array]] | None = None, state: Mapping[str, Mapping[str, Array]] | None = None, load_ckp: str | None = None, cfg_eval_infer: Dict | DictConfig = {'n_trajs': -1, 'metrics_stride': 1, 'batch_size': 2, 'metrics': ['mse', 'e_kin', 'sinkhorn'], 'out_type': 'pkl', 'n_extrap_steps': 0}, rollout_dir: str | None = None, n_rollout_steps: int = 20, seed: int = 0)

Infer on a dataset, compute metrics and optionally save rollout in out_type format.

Parameters:
  • model – (Transformed) Haiku model.

  • case – Case setup class.

  • data_test – Test dataset.

  • params – Haiku params.

  • state – Haiku state.

  • load_ckp – Path to checkpoint directory.

  • rollout_dir – Path to rollout directory.

  • cfg_eval_infer – Evaluation configuration for inference mode.

  • n_rollout_steps – Number of rollout steps.

  • seed – Seed.

Returns:

Metrics per trajectory.

Return type:

eval_metrics
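One way to prepare a cfg_eval_infer dictionary is to start from the defaults listed in the signature and override individual keys. The helper below is a minimal sketch: make_cfg and DEFAULT_CFG_EVAL_INFER are hypothetical names, not part of the library. It always returns a complete dictionary, since it is not documented here whether infer accepts a partial one.

```python
# Defaults copied from the infer signature above.
DEFAULT_CFG_EVAL_INFER = {
    "n_trajs": -1,
    "metrics_stride": 1,
    "batch_size": 2,
    "metrics": ["mse", "e_kin", "sinkhorn"],
    "out_type": "pkl",
    "n_extrap_steps": 0,
}


def make_cfg(**overrides):
    """Return a full config dict with selected keys overridden (hypothetical helper)."""
    unknown = set(overrides) - set(DEFAULT_CFG_EVAL_INFER)
    if unknown:
        # Catch typos early instead of silently passing unused keys along.
        raise KeyError(f"unknown config keys: {sorted(unknown)}")
    return {**DEFAULT_CFG_EVAL_INFER, **overrides}


cfg = make_cfg(out_type="vtk", metrics=["mse"])
# The result would then be passed as infer(..., cfg_eval_infer=cfg).
```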

Metrics

Metrics for evaluation and testing.

class lagrangebench.evaluate.metrics.MetricsComputer(active_metrics: List, dist_fn: Callable, metadata: Dict, input_seq_length: int, stride: int = 10, loss_ranges: List | None = None, ot_backend: str = 'ott')

Metrics between predicted and target rollouts.

Currently implemented:
  • MSE, mean squared error

  • MAE, mean absolute error

  • Sinkhorn distance, measures the similarity of two particle distributions

  • Kinetic energy, physical quantity of interest

__init__(active_metrics: List, dist_fn: Callable, metadata: Dict, input_seq_length: int, stride: int = 10, loss_ranges: List | None = None, ot_backend: str = 'ott')

Init the metric computer.

Parameters:
  • active_metrics – List of metrics to compute.

  • dist_fn – Distance function.

  • metadata – Metadata of the dataset.

  • loss_ranges – List of horizon lengths to compute the loss for.

  • input_seq_length – Length of the input sequence.

  • stride – Rollout subsample frequency for e_kin and sinkhorn.

  • ot_backend – Backend for sinkhorn computation. “ott” or “pot”.
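The stride and loss_ranges parameters can be pictured with plain array slicing. The sketch below is an interpretation, not the library's code: it assumes that stride keeps every stride-th frame of a rollout and that each entry k of loss_ranges means "the first k rollout steps". The function names and the "mse{k}" key format are hypothetical.

```python
import numpy as np


def subsample(rollout, stride):
    """Keep every stride-th frame along the time axis (axis 0)."""
    return rollout[::stride]


def mse_per_horizon(pred, target, loss_ranges):
    """MSE restricted to the first k steps, for each horizon length k."""
    return {
        f"mse{k}": float(np.mean((pred[:k] - target[:k]) ** 2))
        for k in loss_ranges
    }


rng = np.random.default_rng(0)
target = rng.normal(size=(20, 8, 2))  # (steps, particles, dims)
pred = target + 0.1                   # constant offset in every coordinate

sub = subsample(pred, stride=10)
horizon_mse = mse_per_horizon(pred, target, loss_ranges=[5, 10, 20])
print(sub.shape, horizon_mse)
```

Subsampling matters mostly for the more expensive metrics, which is consistent with stride being documented for e_kin and sinkhorn only.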

__call__(pred_rollout: Array, target_rollout: Array) -> Dict[str, Dict[str, Array]]

Compute the metrics between two rollouts.

Parameters:
  • pred_rollout – Predicted rollout.

  • target_rollout – Target rollout.

Returns:

Dictionary of metrics.

mse(pred: Array, target: Array) -> float

Compute the mean squared error between two rollouts.

mae(pred: Array, target: Array) -> float

Compute the mean absolute error between two rollouts.
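For reference, the two error metrics can be written in a few lines of NumPy. This is a minimal sketch using the textbook definitions and plain subtraction; mse_sketch and mae_sketch are hypothetical names, and the library's versions may instead route the difference through dist_fn (for example to handle periodic boundaries).

```python
import numpy as np


def mse_sketch(pred, target):
    """Mean of the squared element-wise differences."""
    return float(np.mean((pred - target) ** 2))


def mae_sketch(pred, target):
    """Mean of the absolute element-wise differences."""
    return float(np.mean(np.abs(pred - target)))


target = np.zeros((3, 4, 2))     # (steps, particles, dims)
pred = np.full((3, 4, 2), 2.0)   # every coordinate is off by 2

print(mse_sketch(pred, target))  # 4.0
print(mae_sketch(pred, target))  # 2.0
```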

sinkhorn(pred: Array, target: Array) -> float

Compute the sinkhorn distance between two rollouts.
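To give an idea of what a Sinkhorn-type distance measures, the sketch below computes an entropy-regularized optimal transport cost between two small point clouds with plain NumPy. It is an illustration only: the library delegates this computation to the "ott" or "pot" backend, and the function name, the squared Euclidean cost, the uniform particle weights and the eps and n_iters values are all assumptions made here.

```python
import numpy as np


def entropic_ot_cost(x, y, eps=0.1, n_iters=200):
    """Transport cost of the entropy-regularized plan between clouds x and y."""
    n, m = len(x), len(y)
    a = np.full(n, 1.0 / n)  # uniform weight per particle in x
    b = np.full(m, 1.0 / m)  # uniform weight per particle in y
    # Pairwise squared Euclidean distances, shape (n, m).
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    kernel = np.exp(-cost / eps)
    u = np.ones(n)
    v = np.ones(m)
    for _ in range(n_iters):
        # Alternately rescale rows and columns to match the marginals.
        u = a / (kernel @ v)
        v = b / (kernel.T @ u)
    plan = u[:, None] * kernel * v[None, :]
    return float((plan * cost).sum())


rng = np.random.default_rng(0)
x = rng.uniform(size=(16, 2))
same = entropic_ot_cost(x, x)
shifted = entropic_ot_cost(x, x + np.array([1.0, 0.0]))
print(same, shifted)  # the shifted cloud is farther away
```

Unlike MSE, such a cost does not depend on the ordering of the particles, which is why it is suited to comparing particle distributions rather than individual trajectories.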

e_kin(frame: Array) -> float

Compute the kinetic energy of a frame.
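The textbook kinetic energy of a particle system is half the mass times the sum of squared velocities. The sketch below assumes that a frame holds per-particle velocities and that all particles share one mass; kinetic_energy is a hypothetical name, and the library's e_kin may use a different normalization.

```python
import numpy as np


def kinetic_energy(velocities, mass=1.0):
    """0.5 * m * sum(|v|^2) over all particles (equal masses assumed)."""
    return float(0.5 * mass * np.sum(velocities ** 2))


# Two particles in 2D: one moving with speed 5, one at rest.
vel = np.array([[3.0, 4.0], [0.0, 0.0]])
energy = kinetic_energy(vel)
print(energy)  # 12.5
```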

lagrangebench.evaluate.metrics.averaged_metrics(eval_metrics: Dict[str, Dict[str, Array]]) -> Dict[str, float]

Averages the metrics over the rollouts.
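The type signature suggests a nested dictionary (rollout name, then metric name, then array) being reduced to one float per metric. The sketch below shows one plausible reduction, the mean over rollouts of each rollout's mean value. It is an assumption about the behavior, not the library's code, and the "rollout_0"-style keys are hypothetical.

```python
import numpy as np


def averaged_metrics_sketch(eval_metrics):
    """Average each metric first within a rollout, then across rollouts."""
    sums, counts = {}, {}
    for per_rollout in eval_metrics.values():
        for name, values in per_rollout.items():
            sums[name] = sums.get(name, 0.0) + float(np.mean(values))
            counts[name] = counts.get(name, 0) + 1
    return {name: sums[name] / counts[name] for name in sums}


eval_metrics = {
    "rollout_0": {"mse": np.array([1.0, 3.0]), "e_kin": np.array([2.0])},
    "rollout_1": {"mse": np.array([4.0, 6.0]), "e_kin": np.array([4.0])},
}
avg = averaged_metrics_sketch(eval_metrics)
print(avg)  # {'mse': 3.5, 'e_kin': 3.0}
```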

Utils

Utility functions for evaluation.

lagrangebench.evaluate.utils.write_vtk(data_dict, path)

Store a .vtk file for ParaView.

lagrangebench.evaluate.utils.pkl2vtk(src_path, dst_path=None)

Convert a rollout pickle file to a set of vtk files.

Parameters:
  • src_path (str) – Source path to .pkl file.

  • dst_path (str, optional) – Destination directory path. Defaults to None, in which case the vtk files are saved in the same directory as the pkl file.

Example

pkl2vtk("rollout/test/rollout_0.pkl", "rollout/test_vtk") creates the files rollout_0_0.vtk, rollout_0_1.vtk, etc. in the directory "rollout/test_vtk".
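The naming scheme in this example can be reproduced with the standard library alone. The helper below is a sketch of the path logic only (it neither reads the pickle nor writes vtk files); vtk_frame_paths and its n_frames argument are hypothetical and not part of the library.

```python
import os


def vtk_frame_paths(src_path, n_frames, dst_path=None):
    """List the per-frame .vtk paths implied by the naming scheme above."""
    if dst_path is None:
        # Documented default: same directory as the .pkl file.
        dst_path = os.path.dirname(src_path)
    stem = os.path.splitext(os.path.basename(src_path))[0]
    return [os.path.join(dst_path, f"{stem}_{i}.vtk") for i in range(n_frames)]


paths = vtk_frame_paths("rollout/test/rollout_0.pkl", 3, "rollout/test_vtk")
default_paths = vtk_frame_paths("rollout/test/rollout_0.pkl", 2)
print([os.path.basename(p) for p in paths])
```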