Train
Trainer
Training utils and functions.
- class lagrangebench.train.trainer.Trainer(model: TransformedWithState, case, data_train: H5Dataset, data_valid: H5Dataset, cfg_train: Dict | DictConfig = {'batch_size': 1, 'step_max': 500000, 'num_workers': 4, 'noise_std': 0.0003, 'optimizer': {'lr_start': 0.0001, 'lr_final': 1e-06, 'lr_decay_rate': 0.1, 'lr_decay_steps': 100000.0}, 'pushforward': {'steps': [-1, 20000, 300000, 400000], 'unrolls': [0, 1, 2, 3], 'probs': [18, 2, 1, 1]}, 'loss_weight': {'acc': 1.0, 'vel': 0.0, 'pos': 0.0}}, cfg_eval: Dict | DictConfig = {'n_rollout_steps': 20, 'test': False, 'rollout_dir': None, 'train': {'n_trajs': 50, 'metrics_stride': 10, 'batch_size': 1, 'metrics': ['mse'], 'out_type': 'none'}, 'infer': {'n_trajs': -1, 'metrics_stride': 1, 'batch_size': 2, 'metrics': ['mse', 'e_kin', 'sinkhorn'], 'out_type': 'pkl', 'n_extrap_steps': 0}}, cfg_logging: Dict | DictConfig = {'log_steps': 1000, 'eval_steps': 10000, 'wandb': False, 'wandb_project': None, 'wandb_entity': 'lagrangebench', 'ckp_dir': 'ckp', 'run_name': None}, input_seq_length: int = 6, seed: int = 0)
Trainer class.
Given a model, a case setup, and training and validation datasets, this class automates training and evaluation.
Initializes (or restores from a checkpoint) the model, optimizer and loss function.
Trains the model on data_train, using the given pushforward and noise tricks.
Evaluates the model on data_valid with the specified metrics.
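The optimizer keys in the default cfg_train (lr_start, lr_final, lr_decay_rate, lr_decay_steps) suggest an exponentially decaying learning rate with a lower bound. The function below is a minimal sketch of that reading only, assuming continuous (non-staircase) decay by a factor of lr_decay_rate every lr_decay_steps steps, clipped at lr_final. The name lr_at_step is hypothetical and the formula is not taken from the library source.

```python
def lr_at_step(step, lr_start=1e-4, lr_final=1e-6, lr_decay_rate=0.1, lr_decay_steps=1e5):
    """Hypothetical learning-rate schedule matching the default cfg_train keys.

    Assumption: continuous exponential decay, multiplied by lr_decay_rate
    every lr_decay_steps steps, never going below lr_final.
    """
    lr = lr_start * lr_decay_rate ** (step / lr_decay_steps)
    # clip at the final learning rate once the decay has reached it
    return max(lr, lr_final)


print(lr_at_step(0), lr_at_step(100_000), lr_at_step(400_000))
```

Under this assumption, the defaults reach lr_final after 200,000 steps and keep it constant for the rest of the 500,000-step run.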
- train(step_max: int = 500000, params: Mapping[str, Mapping[str, Array]] | None = None, state: Mapping[str, Mapping[str, Array]] | None = None, opt_state: Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree] | None = None, store_ckp: str | None = None, load_ckp: str | None = None, wandb_config: Dict | None = None) → Tuple[Mapping[str, Mapping[str, Array]], Mapping[str, Mapping[str, Array]], Array | ndarray | bool | number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]]
Training loop.
Trains and evaluates the model on the given case and dataset, and saves the model checkpoints and the best models.
- Parameters:
step_max – Maximum number of training steps.
params – Optional model parameters. If provided, training continues from them.
state – Optional model state.
opt_state – Optional optimizer state.
store_ckp – Checkpoint destination directory. If not provided, parameters are not saved.
load_ckp – Initial checkpoint directory. If provided, training resumes from this checkpoint.
wandb_config – Optional configuration to be logged on wandb.
- Returns:
Tuple containing the final model parameters, state and optimizer state.
Strategies
Training tricks and strategies, currently random-walk noise and pushforward.
- lagrangebench.train.strats.add_gns_noise(key: Array, pos_input: Array, particle_type: Array, input_seq_length: int, noise_std: float, shift_fn: Callable[[Any, Any], Any]) → Tuple[Array, Array]
GNS-like random-walk noise injection, as described by Sanchez-Gonzalez et al.
Applies random-walk noise to the input positions and adjusts the targets accordingly to keep the trajectory consistent. It works by drawing independent samples from \(\mathcal{N}^{(t)}(0, \sigma_v^{(t)})\) for each input state. The noise is accumulated as a random walk and added to the velocity sequence. Each \(\sigma_v^{(t)}\) is set so that the last step of the random walk has \(\sigma_v^{(input\_seq\_length)} = noise\_std\). Based on the noised velocities, the positions are adjusted such that \(\dot{p}^{t_k} = p^{t_k} - p^{t_{k-1}}\).
- Parameters:
key – Random key.
pos_input – Clean input positions. Shape: (num_particles_max, input_seq_length + pushforward["unrolls"][-1] + 1, dim)
particle_type – Particle type vector. Shape: (num_particles_max,)
input_seq_length – Input sequence length, as in the configs.
noise_std – Noise standard deviation at the last sequence step.
shift_fn – Shift function.
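A minimal NumPy sketch of the random-walk noise described above. It assumes a plain Euclidean domain (noise is added directly instead of going through shift_fn, so periodic boundaries are ignored), noises every particle (the particle_type masking is left out), and assumes the targets are adjusted by shifting every position after the input window with the noise of the last input position. The name random_walk_noise is hypothetical; the library's add_gns_noise is a JAX function with the different signature shown above.

```python
import numpy as np


def random_walk_noise(rng, pos_input, input_seq_length, noise_std):
    """Hypothetical sketch of GNS-style random-walk noise.

    pos_input: (num_particles, seq_len, dim) with seq_len >= input_seq_length.
    Returns noisy positions of the same shape.
    """
    num_particles, seq_len, dim = pos_input.shape
    num_velocities = input_seq_length - 1
    # per-step std so that the accumulated velocity noise at the last input
    # step has standard deviation noise_std (sum of i.i.d. Gaussians)
    step_std = noise_std / np.sqrt(num_velocities)
    vel_steps = rng.normal(0.0, step_std, size=(num_particles, num_velocities, dim))
    vel_noise = np.cumsum(vel_steps, axis=1)  # random walk over velocities
    # integrate velocity noise into position noise; the first position is untouched
    pos_noise = np.concatenate(
        [np.zeros((num_particles, 1, dim)), np.cumsum(vel_noise, axis=1)], axis=1
    )
    # assumption: targets get the noise of the last input position
    n_targets = seq_len - input_seq_length
    target_noise = np.repeat(pos_noise[:, -1:, :], n_targets, axis=1)
    return pos_input + np.concatenate([pos_noise, target_noise], axis=1)
```

Because positions are built by integrating the noised velocities, the relation \(\dot{p}^{t_k} = p^{t_k} - p^{t_{k-1}}\) still holds for the perturbed inputs.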
- lagrangebench.train.strats.push_forward_sample_steps(key, step, pushforward)
Sample the number of unroll steps based on the current training step and the specified pushforward configuration.
- Parameters:
key – Random key.
step – Current training step.
pushforward – Pushforward configuration, with the keys steps, unrolls and probs.
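A minimal NumPy sketch of how such a schedule can be sampled, using the default pushforward configuration from the Trainer signature. It assumes that an unroll count becomes eligible once the training step exceeds its entry in steps, and that eligible counts are drawn with probability proportional to probs. The name sample_unroll_steps is hypothetical, and the library version takes and returns a JAX random key instead of a NumPy generator.

```python
import numpy as np


def sample_unroll_steps(rng, step, pushforward):
    """Hypothetical sketch: sample a pushforward unroll count for a training step."""
    steps = np.asarray(pushforward["steps"])
    assert np.all(np.diff(steps) >= 0), "pushforward steps must be sorted"
    # number of schedule entries already activated at this training step
    n_active = int(np.sum(step > steps))
    unrolls = np.asarray(pushforward["unrolls"][:n_active])
    probs = np.asarray(pushforward["probs"][:n_active], dtype=float)
    # probs are relative weights, so normalize them before sampling
    return int(rng.choice(unrolls, p=probs / probs.sum()))


pushforward = {"steps": [-1, 20000, 300000, 400000], "unrolls": [0, 1, 2, 3], "probs": [18, 2, 1, 1]}
rng = np.random.default_rng(0)
print(sample_unroll_steps(rng, 450_000, pushforward))
```

Under these assumptions, training starts with single-step training only; after step 400,000 all four unroll counts are possible, with 0 unrolls still drawn 18 out of 22 times on average.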
- lagrangebench.train.strats.push_forward_build(model_apply, case)
Build the pushforward function, introduced by Brandstetter et al.
Pushforward works by adding a stability (“pushforward”) loss term in the form of an adversarial-style loss:
\[L_{pf} = \mathbb{E}_k \mathbb{E}_{u^{k+1} | u^k} \mathbb{E}_{\epsilon} \left[ \mathcal{L}(f(u^k + \epsilon), u^{k+1}) \right]\]
where the perturbation \(\epsilon\) is defined by \(u^k + \epsilon = f(u^{k-1})\), i.e. it is the error introduced by the first step of a 2-step unroll of the solver \(f\) (from step \(k-1\) to \(k\)). The total loss is then \(L_{total} = \mathcal{L}(f(u^k), u^{k+1}) + L_{pf}\). Similarly, for \(S > 2\) pushforward steps, \(L_{pf}\) is extended to \(u^{k-S} \dots u^{k-1}\) with accumulated \(\epsilon\) perturbations.
In practice, this is implemented by unrolling the solver for two steps but backpropagating gradients only through the last unroll step.
- Parameters:
model_apply – Model apply function.
case – Case setup function.
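A minimal JAX sketch of the gradient-stopping idea behind pushforward. It uses a hypothetical toy linear solver in place of model_apply and a mean-squared-error loss in place of the case-specific loss. The first unroll_steps solver applications are wrapped in jax.lax.stop_gradient, so the model's own prediction error plays the role of \(\epsilon\) while gradients flow only through the final step. This illustrates the technique and is not the function returned by push_forward_build.

```python
import jax
import jax.numpy as jnp


def toy_solver(params, u):
    # hypothetical one-step solver standing in for model_apply: u_{k+1} = w * u_k + b
    return params["w"] * u + params["b"]


def pushforward_loss(params, u_start, u_target, unroll_steps):
    """Hypothetical sketch: loss after `unroll_steps` gradient-free solver steps.

    u_start:  ground-truth state unroll_steps + 1 steps before u_target
    u_target: ground-truth state the final (differentiated) step should predict
    """
    u = u_start
    for _ in range(unroll_steps):
        # unroll without tracking gradients; the accumulated model error
        # acts as the perturbation epsilon on the next input
        u = jax.lax.stop_gradient(toy_solver(params, u))
    pred = toy_solver(params, u)  # only this last step receives gradients
    return jnp.mean((pred - u_target) ** 2)


params = {"w": jnp.array(2.0), "b": jnp.array(0.0)}
loss, grads = jax.value_and_grad(pushforward_loss)(params, jnp.array([1.0]), jnp.array([0.0]), 1)
```

With w = 2, b = 0, an input of 1, a target of 0 and one pushforward step, the loss is 16 and the gradient with respect to w is 16; without stop_gradient it would be 32, because the loss would also be differentiated through the first unroll step.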