Train

Trainer

Training utils and functions.

class lagrangebench.train.trainer.Trainer(model: TransformedWithState, case, data_train: H5Dataset, data_valid: H5Dataset, cfg_train: Dict | DictConfig = {'batch_size': 1, 'step_max': 500000, 'num_workers': 4, 'noise_std': 0.0003, 'optimizer': {'lr_start': 0.0001, 'lr_final': 1e-06, 'lr_decay_rate': 0.1, 'lr_decay_steps': 100000.0}, 'pushforward': {'steps': [-1, 20000, 300000, 400000], 'unrolls': [0, 1, 2, 3], 'probs': [18, 2, 1, 1]}, 'loss_weight': {'acc': 1.0, 'vel': 0.0, 'pos': 0.0}}, cfg_eval: Dict | DictConfig = {'n_rollout_steps': 20, 'test': False, 'rollout_dir': None, 'train': {'n_trajs': 50, 'metrics_stride': 10, 'batch_size': 1, 'metrics': ['mse'], 'out_type': 'none'}, 'infer': {'n_trajs': -1, 'metrics_stride': 1, 'batch_size': 2, 'metrics': ['mse', 'e_kin', 'sinkhorn'], 'out_type': 'pkl', 'n_extrap_steps': 0}}, cfg_logging: Dict | DictConfig = {'log_steps': 1000, 'eval_steps': 10000, 'wandb': False, 'wandb_project': None, 'wandb_entity': 'lagrangebench', 'ckp_dir': 'ckp', 'run_name': None}, input_seq_length: int = 6, seed: int = 0)

Trainer class.

Given a model, a case setup, and training and validation datasets, this class automates training and evaluation.

  1. Initializes the model, optimizer and loss function, or restores them from a checkpoint.

  2. Trains the model on data_train, using the given pushforward and noise tricks.

  3. Evaluates the model on data_valid on the specified metrics.

train(step_max: int = 500000, params: Mapping[str, Mapping[str, Array]] | None = None, state: Mapping[str, Mapping[str, Array]] | None = None, opt_state: Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree] | None = None, store_ckp: str | None = None, load_ckp: str | None = None, wandb_config: Dict | None = None) -> Tuple[Mapping[str, Mapping[str, Array]], Mapping[str, Mapping[str, Array]], Array | ndarray | bool_ | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]]

Training loop.

Trains and evaluates the model on the given case and dataset, saving model checkpoints and the best models.

Parameters:
  • step_max – Maximum number of training steps.

  • params – Optional model parameters. If provided, training continues from them.

  • state – Optional model state.

  • opt_state – Optional optimizer state.

  • store_ckp – Checkpoint destination directory. If omitted, parameters are not saved.

  • load_ckp – Initial checkpoint directory. If provided, training resumes from it.

  • wandb_config – Optional configuration to be logged on wandb.

Returns:

Tuple containing the final model parameters, state and optimizer state.

Strategies

Training tricks and strategies, currently: random-walk noise and push forward.

lagrangebench.train.strats.add_gns_noise(key: Array, pos_input: Array, particle_type: Array, input_seq_length: int, noise_std: float, shift_fn: Callable[[Any, Any], Any]) -> Tuple[Array, Array]

GNS-like random-walk noise injection as described by Sanchez-Gonzalez et al.

Applies random-walk noise to the input positions and adjusts the targets accordingly to keep the trajectory consistent. It works by drawing independent samples from \(\mathcal{N}^{(t)}(0, \sigma_v^{(t)})\) for each input state. The noise is accumulated as a random walk and added to the velocity sequence. Each \(\sigma_v^{(t)}\) is set so that the last step of the random walk has \(\sigma_v^{(input\_seq\_length)}=noise\_std\). Based on the noised velocities, positions are adjusted such that \(\dot{p}^{t_k} = p^{t_k} - p^{t_{k-1}}\).

Parameters:
  • key – Random key.

  • pos_input – Clean input positions. Shape: (num_particles_max, input_seq_length + pushforward["unrolls"][-1] + 1, dim)

  • particle_type – Particle type vector. Shape: (num_particles_max,)

  • input_seq_length – Input sequence length, as in the configs.

  • noise_std – Noise standard deviation at the last sequence step.

  • shift_fn – Shift function.
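
The random-walk construction described for add_gns_noise can be illustrated with a minimal JAX sketch. This is not the library's implementation: the helper name `random_walk_noise` is hypothetical, and the sketch assumes equal per-step standard deviations, ignores `particle_type` masking and `shift_fn` (periodic boundaries), and leaves the target positions untouched.

```python
import jax
import jax.numpy as jnp


def random_walk_noise(key, pos_input, input_seq_length, noise_std):
    """Hypothetical helper: add random-walk noise to the input positions.

    pos_input has shape (num_particles, seq_length, dim); only the first
    `input_seq_length` positions are perturbed in this sketch.
    """
    n_particles, _, dim = pos_input.shape
    n_vel = input_seq_length - 1  # velocities between consecutive inputs

    # Assumption: equal per-step std, chosen so that the sum of n_vel
    # independent samples has standard deviation noise_std.
    step_std = noise_std / jnp.sqrt(n_vel)
    vel_noise = step_std * jax.random.normal(key, (n_particles, n_vel, dim))

    # Accumulate the samples into a random walk on the velocities.
    vel_noise = jnp.cumsum(vel_noise, axis=1)

    # Integrate velocity noise into position noise; the first position stays
    # clean, so noisy velocities still equal differences of noisy positions.
    pos_noise = jnp.concatenate(
        [jnp.zeros((n_particles, 1, dim)), jnp.cumsum(vel_noise, axis=1)], axis=1
    )

    noisy_inputs = pos_input[:, :input_seq_length] + pos_noise
    return pos_input.at[:, :input_seq_length].set(noisy_inputs), pos_noise


key = jax.random.PRNGKey(0)
pos = jnp.zeros((1000, 8, 2))  # 1000 particles, 6 inputs + 2 targets, 2D
noisy, noise = random_walk_noise(key, pos, input_seq_length=6, noise_std=3e-4)
print(noisy.shape, noise.shape)
```

In this sketch the velocity noise at the last input step has standard deviation `noise_std`, matching the role of the noise_std parameter above.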

lagrangebench.train.strats.push_forward_sample_steps(key, step, pushforward)

Sample the number of unroll steps based on the current training step and the specified pushforward configuration.

Parameters:
  • key – Random key

  • step – Current training step

  • pushforward – Pushforward configuration
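
One plausible reading of the pushforward configuration is a curriculum: an unroll count becomes eligible once the training step exceeds its entry in "steps", and eligible counts are sampled with relative weights "probs". The sketch below illustrates that reading with the default values from the Trainer signature. The function name `sample_unroll_steps` and this interpretation are assumptions, not the library's code.

```python
import jax
import jax.numpy as jnp


def sample_unroll_steps(key, step, pushforward):
    """Hypothetical sketch of curriculum-style unroll sampling."""
    steps = jnp.array(pushforward["steps"])
    # Number of curriculum stages already unlocked at this training step.
    n_active = int((step > steps).sum())
    unrolls = jnp.array(pushforward["unrolls"][:n_active])
    weights = jnp.array(pushforward["probs"][:n_active], dtype=jnp.float32)
    key, subkey = jax.random.split(key)
    # Normalize the relative weights into probabilities before sampling.
    n_unroll = jax.random.choice(subkey, unrolls, p=weights / weights.sum())
    return key, int(n_unroll)


pushforward = {
    "steps": [-1, 20000, 300000, 400000],
    "unrolls": [0, 1, 2, 3],
    "probs": [18, 2, 1, 1],
}
key = jax.random.PRNGKey(0)
# At step 100 only the first stage is unlocked, so no unrolling is sampled.
key, n_early = sample_unroll_steps(key, step=100, pushforward=pushforward)
# At step 25000 the result is 0 or 1, with relative weights 18:2.
key, n_mid = sample_unroll_steps(key, step=25000, pushforward=pushforward)
```

Under this reading, the default weights keep most updates as plain one-step training, with longer unrolls mixed in only occasionally.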

lagrangebench.train.strats.push_forward_build(model_apply, case)

Build the push forward function, introduced by Brandstetter et al.

Pushforward works by adding a stability “pushforward” loss term, in the form of an adversarial-style loss.

\[L_{pf} = \mathbb{E}_k \mathbb{E}_{u^{k+1} | u^k} \mathbb{E}_{\epsilon} \left[ \mathcal{L}(f(u^k + \epsilon), u^{k+1}) \right]\]

where \(\epsilon\) is the perturbation defined by \(u^k + \epsilon = f(u^{k-1})\), i.e. the clean state \(u^k\) is replaced by the solver's own prediction, which amounts to a 2-step unroll of the solver \(f\) starting from step \(k-1\). The total loss is then \(L_{total}=\mathcal{L}(f(u^k), u^{k+1}) + L_{pf}\). Similarly, for \(S > 2\) pushforward steps, \(L_{pf}\) is extended to \(u^{k-S} \dots u^{k-1}\) with accumulated \(\epsilon\) perturbations.

In practice, this is implemented by unrolling the solver for two steps, but only running gradients through the last unroll step.

Parameters:
  • model_apply – Model apply function

  • case – Case setup function
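
The gradient handling described for push_forward_build (unroll, but backpropagate only through the last step) can be sketched with `jax.lax.stop_gradient`. The toy linear solver `f` and the function `pushforward_loss` are hypothetical stand-ins for illustration only; the library's function operates on the model apply function and case setup instead.

```python
import jax
import jax.numpy as jnp


def f(params, u):
    """Toy one-step solver (hypothetical): a scalar linear map."""
    return params * u


def pushforward_loss(params, u_prev, u_target, n_unroll):
    """Unroll `f` for n_unroll steps without gradients, then take one
    differentiable step and compare the result with the target."""
    u = u_prev
    for _ in range(n_unroll):
        # Predictions fed back as inputs are treated as constants.
        u = jax.lax.stop_gradient(f(params, u))
    pred = f(params, u)
    return jnp.mean((pred - u_target) ** 2)


u_prev = jnp.array([1.0, 2.0])
u_target = jnp.zeros(2)
# Gradient flows only through the final application of f.
grad_pf = jax.grad(pushforward_loss)(0.5, u_prev, u_target, 1)
```

With `params = 0.5`, the perturbed input is the constant `[0.5, 1.0]`, so the gradient is `mean(2 * 0.5 * [0.25, 1.0]) = 0.625`. Differentiating through both steps would instead give `1.25`.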