LagrangeBench
What is LagrangeBench?
LagrangeBench is a machine learning benchmarking suite for Lagrangian particle problems based on the JAX library. It provides:
Data loading and preprocessing utilities for particle data.
Three neighbor search routines: (a) the original JAX-MD implementation, (b) a memory-efficient version of the JAX-MD implementation, and (c) a wrapper around the matscipy implementation that can handle a variable number of particles.
JAX reimplementations of established graph neural networks: GNS, SEGNN, EGNN, PaiNN.
Training strategies including random-walk additive noise and the pushforward trick.
Evaluation tools consisting of rollout generation and different error metrics: position MSE, kinetic energy MSE, and Sinkhorn distance for the particle distribution.
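To make the role of the neighbor search concrete, here is a minimal brute-force sketch in NumPy. It is not the JAX-MD or matscipy algorithm, and the function name `brute_force_neighbors` is hypothetical. It only illustrates the question every neighbor search answers: which particle pairs lie within a cutoff radius, here assuming a fully periodic box.

```python
import numpy as np


def brute_force_neighbors(positions, box, cutoff):
    """Return (senders, receivers) for all particle pairs within `cutoff`.

    Hypothetical helper for illustration; assumes a fully periodic box.
    """
    # Pairwise displacement vectors, shape (N, N, dim)
    disp = positions[:, None, :] - positions[None, :, :]
    # Map each displacement to its nearest periodic image
    disp -= box * np.round(disp / box)
    dist = np.linalg.norm(disp, axis=-1)
    # Drop self-pairs and keep pairs closer than the cutoff
    mask = (dist < cutoff) & ~np.eye(len(positions), dtype=bool)
    senders, receivers = np.nonzero(mask)
    return senders, receivers


# Three particles in a unit box; the first two are close across the boundary
positions = np.array([[0.05, 0.5], [0.95, 0.5], [0.5, 0.5]])
box = np.array([1.0, 1.0])
senders, receivers = brute_force_neighbors(positions, box, cutoff=0.2)
edges = sorted(zip(senders.tolist(), receivers.tolist()))
print(edges)  # [(0, 1), (1, 0)]
```

This version costs O(N²) time and memory, which is why practical implementations typically rely on more efficient data structures for large particle counts.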
Note
For more details on using LagrangeBench, check out our tutorials.
Data loading and preprocessing
First, we create a dataset class based on torch.utils.data.Dataset.
We then initialize a CaseSetupFn object that takes care of the neighbor search, preprocessing, and time integration.
import numpy as np

import lagrangebench

# Load data
data_train = lagrangebench.RPF2D("train")
data_valid = lagrangebench.RPF2D("valid", extra_seq_length=20)
data_test = lagrangebench.RPF2D("test", extra_seq_length=20)

# Case setup (preprocessing and graph building)
bounds = np.array(data_train.metadata["bounds"])
box = bounds[:, 1] - bounds[:, 0]
case = lagrangebench.case_builder(
    box=box,
    metadata=data_train.metadata,
    input_seq_length=6,
)
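The box computation above can be checked in isolation. The sketch below uses a hypothetical `bounds` value and assumes, as the indexing in the snippet suggests, that the metadata stores one [lower, upper] pair per spatial dimension.

```python
import numpy as np

# Hypothetical metadata entry: one [lower, upper] pair per spatial dimension
bounds = np.array([[0.0, 1.0], [0.0, 2.0]])

# Upper bound minus lower bound gives the box edge length per dimension
box = bounds[:, 1] - bounds[:, 0]
print(box)  # [1. 2.]
```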
Models
Initialize a GNS model.
import haiku as hk

def gns(x):
    return lagrangebench.models.GNS(
        particle_dimension=data_train.metadata["dim"],
        latent_size=16,
        blocks_per_step=2,
        num_mp_steps=4,
        particle_type_embedding_size=8,
    )(x)

# Turn the stateful Haiku module into pure init/apply functions
gns = hk.without_apply_rng(hk.transform_with_state(gns))
Training
The Trainer provides a convenient way to train a model.
trainer = lagrangebench.Trainer(
    model=gns,
    case=case,
    data_train=data_train,
    data_valid=data_valid,
    cfg_eval={"n_rollout_steps": 20, "train": {"metrics": ["mse"]}},
    input_seq_length=6,
)

# Train for 25000 steps
params, state, _ = trainer.train(step_max=25000)
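As background for the random-walk additive noise mentioned among the training strategies, the following NumPy sketch shows one common formulation, in which velocity noise accumulates as a random walk over the input window and is then integrated into position noise. This is an illustration under stated assumptions, not necessarily how LagrangeBench implements it: the function name `random_walk_noise`, the (steps, particles, dim) layout, a unit time step, and the noise scale `3e-4` are all hypothetical.

```python
import numpy as np


def random_walk_noise(num_steps, num_particles, dim, noise_std, rng):
    """Sample position noise whose velocity component follows a random walk.

    Hypothetical helper; velocities are finite differences with unit time step.
    """
    num_vel = num_steps - 1
    # Per-step velocity increments, scaled so the accumulated velocity noise
    # at the last step has standard deviation `noise_std`
    increments = rng.normal(
        scale=noise_std / np.sqrt(num_vel),
        size=(num_vel, num_particles, dim),
    )
    # Accumulate increments over time: the velocity noise is a random walk
    velocity_noise = np.cumsum(increments, axis=0)
    # Integrate velocity noise into position noise; the first frame stays clean
    position_noise = np.concatenate(
        [np.zeros((1, num_particles, dim)), np.cumsum(velocity_noise, axis=0)],
        axis=0,
    )
    return position_noise


rng = np.random.default_rng(0)
positions = np.zeros((6, 100, 2))  # hypothetical input window of 6 frames
noise = random_walk_noise(6, 100, 2, noise_std=3e-4, rng=rng)
noisy_positions = positions + noise
```

Corrupting the input window this way exposes the model during training to the kind of accumulated error it will encounter in its own rollouts.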
Evaluation
When training is done, we can evaluate the model on the test set.
metrics = lagrangebench.infer(
    gns,
    case,
    data_test,
    params,
    state,
    cfg_eval_infer={"metrics": ["mse", "sinkhorn", "e_kin"]},
    n_rollout_steps=20,
)
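To give a feel for what the "mse" and "e_kin" metrics measure, here is a minimal NumPy sketch on toy rollouts. The function names, unit masses, unit time step, and the (steps, particles, dim) layout are assumptions made for illustration. The exact definitions used by LagrangeBench (for example, normalization) may differ, and the Sinkhorn distance is omitted.

```python
import numpy as np


def position_mse(pred, target):
    """Mean squared position error over all steps, particles, and dimensions."""
    return float(np.mean((pred - target) ** 2))


def kinetic_energy(positions, dt=1.0, mass=1.0):
    """Total kinetic energy per step from finite-difference velocities."""
    velocities = (positions[1:] - positions[:-1]) / dt
    return 0.5 * mass * np.sum(velocities**2, axis=(1, 2))


def kinetic_energy_mse(pred, target):
    """Mean squared error between predicted and target kinetic energy."""
    return float(np.mean((kinetic_energy(pred) - kinetic_energy(target)) ** 2))


# Toy rollouts with shape (steps, particles, dim)
target = np.zeros((3, 2, 1))
target[:, 0, 0] = [0.0, 1.0, 2.0]  # particle 0 moves at unit speed
pred = target.copy()
pred[2, 0, 0] = 3.0  # the prediction overshoots at the last step

mse = position_mse(pred, target)
e_kin_mse = kinetic_energy_mse(pred, target)
print(mse, e_kin_mse)
```

A single misplaced particle produces a small position error but a much larger kinetic energy error, which is one reason to report both quantities.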