Utils and Defaults

Utils

General utils and config structures.

class lagrangebench.utils.NodeType(value)

Particle types.

lagrangebench.utils.get_kinematic_mask(particle_type)

Return a boolean mask, set to true for all kinematic (obstacle) particles.
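To illustrate how a particle-type enum and a kinematic mask fit together, here is a minimal sketch. The hypothetical `NodeTypeSketch` members and values, and the assumption that solid and moving walls are the kinematic types, are illustrative only. They are not taken from this page and may differ from the library's actual `NodeType`.

```python
import enum

import jax.numpy as jnp


class NodeTypeSketch(enum.IntEnum):
    """Hypothetical particle types; names and values are assumptions."""

    FLUID = 0
    SOLID_WALL = 1
    MOVING_WALL = 2


def kinematic_mask_sketch(particle_type):
    """Return True where a particle has an (assumed) kinematic obstacle type."""
    particle_type = jnp.asarray(particle_type)
    # Assumption: solid walls and moving walls are the kinematic (obstacle) types.
    return jnp.logical_or(
        particle_type == int(NodeTypeSketch.SOLID_WALL),
        particle_type == int(NodeTypeSketch.MOVING_WALL),
    )


# One fluid particle, one solid-wall particle, one moving-wall particle, one fluid particle.
mask = kinematic_mask_sketch(jnp.array([0, 1, 2, 0]))
```

Encoding types as integers keeps `particle_type` a plain integer array, so the mask is a cheap elementwise comparison that works inside jitted code.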

lagrangebench.utils.broadcast_to_batch(sample, batch_size: int)

Broadcast a pytree to a batched one with first dimension batch_size.
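A minimal sketch of this kind of broadcast, assuming it maps over every leaf with `jax.tree_util.tree_map` and uses `jnp.broadcast_to`. The library may use a different array operation (e.g. `jnp.repeat`), and `broadcast_to_batch_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def broadcast_to_batch_sketch(sample, batch_size: int):
    """Give every leaf of `sample` a new leading axis of length `batch_size`."""
    return jax.tree_util.tree_map(
        # Broadcasting to (batch_size, *leaf_shape) prepends an axis and
        # repeats the leaf along it.
        lambda x: jnp.broadcast_to(x, (batch_size,) + jnp.shape(x)),
        sample,
    )


sample = {"pos": jnp.zeros((5, 2)), "particle_type": jnp.zeros((5,), dtype=jnp.int32)}
batch = broadcast_to_batch_sketch(sample, batch_size=3)
```

Because the function works leaf by leaf, it handles arbitrarily nested dicts, tuples and lists of arrays without special cases.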

lagrangebench.utils.broadcast_from_batch(batch, index: int)

Extract the sample at position index from a batched pytree.
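A minimal sketch of extracting one sample, assuming the batch dimension is the leading axis of every leaf (as produced by broadcast_to_batch). `broadcast_from_batch_sketch` is a hypothetical name, not the library function.

```python
import jax
import jax.numpy as jnp


def broadcast_from_batch_sketch(batch, index: int):
    """Pick sample `index` along the leading (batch) axis of every leaf."""
    return jax.tree_util.tree_map(lambda x: x[index], batch)


batch = {"pos": jnp.arange(12.0).reshape(3, 2, 2), "step": jnp.array([10, 20, 30])}
sample = broadcast_from_batch_sketch(batch, index=1)
```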

lagrangebench.utils.save_pytree(ckp_dir: str, pytree_obj, name) → None

Save a pytree to a directory.
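One common way to save a pytree is to write its leaves back to back into a single `.npy` file and to pickle its structure separately. The sketch below follows that pattern. The file names `{name}_array.npy` and `{name}_tree.pkl` and the function `save_pytree_sketch` are hypothetical; the library's on-disk layout may differ.

```python
import os
import pickle
import tempfile

import jax
import numpy as np


def save_pytree_sketch(ckp_dir: str, pytree_obj, name) -> None:
    """Write the leaves and the structure of `pytree_obj` into `ckp_dir`."""
    os.makedirs(ckp_dir, exist_ok=True)
    # Leaves are written one after another in jax's flatten order.
    with open(os.path.join(ckp_dir, f"{name}_array.npy"), "wb") as f:
        for leaf in jax.tree_util.tree_leaves(pytree_obj):
            np.save(f, np.asarray(leaf), allow_pickle=False)
    # The structure is stored as the same tree with placeholder leaves.
    structure = jax.tree_util.tree_map(lambda _: 0, pytree_obj)
    with open(os.path.join(ckp_dir, f"{name}_tree.pkl"), "wb") as f:
        pickle.dump(structure, f)


ckp_dir = tempfile.mkdtemp()
params = {"w": np.ones((2, 3)), "b": np.zeros(3)}
save_pytree_sketch(ckp_dir, params, "params")
```

Keeping the arrays out of the pickle means the numeric data is stored in NumPy's portable format, and only a small skeleton of plain Python containers is pickled.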

lagrangebench.utils.save_haiku(ckp_dir: str, params, state, opt_state, metadata_ckp) → None

Save params, state and optimizer state to ckp_dir.

Additionally, it tracks the best model and saves it to ckp_dir/best.

See: https://github.com/deepmind/dm-haiku/issues/18

lagrangebench.utils.load_pytree(model_dir: str, name)

Load a pytree from a directory.
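A minimal loading sketch for the pattern of leaves written back to back into one `.npy` file plus a pickled placeholder tree. The file names `{name}_array.npy` / `{name}_tree.pkl` and `load_pytree_sketch` are hypothetical assumptions, not the library's documented format.

```python
import os
import pickle

import jax
import numpy as np


def load_pytree_sketch(model_dir: str, name):
    """Rebuild a pytree from a pickled placeholder tree and a file of leaves."""
    with open(os.path.join(model_dir, f"{name}_tree.pkl"), "rb") as f:
        structure = pickle.load(f)  # same tree shape, placeholder leaves
    placeholders, treedef = jax.tree_util.tree_flatten(structure)
    with open(os.path.join(model_dir, f"{name}_array.npy"), "rb") as f:
        # Read one array per leaf, in the same flatten order used when saving.
        leaves = [np.load(f) for _ in placeholders]
    return jax.tree_util.tree_unflatten(treedef, leaves)
```

Reading the structure first tells the loader exactly how many arrays to pull from the file and where each one belongs.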

lagrangebench.utils.load_haiku(model_dir: str)

Load params, state, optimizer state and last training step from model_dir.

See: https://github.com/deepmind/dm-haiku/issues/18

lagrangebench.utils.get_num_params(params)

Get the number of parameters in a Haiku model.
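Haiku stores parameters as a nested mapping of module names to arrays, so a parameter count can be sketched as a sum of leaf sizes. `count_params_sketch` is a hypothetical name, and the library's implementation may differ in detail.

```python
import jax
import numpy as np


def count_params_sketch(params) -> int:
    """Total number of scalar entries across all leaves of a params pytree."""
    return sum(int(np.size(leaf)) for leaf in jax.tree_util.tree_leaves(params))


# A 3x4 weight matrix plus a bias of length 4.
params = {"linear": {"w": np.zeros((3, 4)), "b": np.zeros(4)}}
```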

lagrangebench.utils.set_seed(seed: int) → Tuple[Array, Callable, Generator]

Set seeds for JAX, Python's random module and PyTorch.
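A sketch of what a seeding helper with this return type could look like. It assumes that `Array` is a JAX PRNG key, that `Callable` is a DataLoader `worker_init_fn`, and that `Generator` is a `torch.Generator`. The `seed_worker` pattern follows the PyTorch reproducibility notes; seeding NumPy is an extra assumption. `set_seed_sketch` is hypothetical and not the library function.

```python
import random
from typing import Callable, Tuple

import jax
import numpy as np
import torch


def set_seed_sketch(seed: int) -> Tuple[jax.Array, Callable, torch.Generator]:
    """Seed Python/NumPy/PyTorch; return a JAX key, worker-init fn, generator."""
    random.seed(seed)
    np.random.seed(seed)  # assumption: NumPy is seeded as well
    torch.manual_seed(seed)
    key = jax.random.PRNGKey(seed)

    def seed_worker(worker_id: int) -> None:
        # Derive each DataLoader worker's seed from torch's initial seed so
        # Python and NumPy randomness inside workers is reproducible.
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    generator = torch.Generator()
    generator.manual_seed(seed)
    return key, seed_worker, generator


key, seed_worker, generator = set_seed_sketch(42)
# Typical use: DataLoader(dataset, worker_init_fn=seed_worker, generator=generator)
```

JAX has no global random state, so it gets an explicit key, while the Python, NumPy and PyTorch global generators are seeded in place.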

Defaults

Default lagrangebench configs.

lagrangebench.defaults.set_defaults(cfg: DictConfig = {
    'config': None,
    'load_ckp': None,
    'mode': 'all',
    'seed': 0,
    'dtype': 'float64',
    'gpu': None,
    'xla_mem_fraction': None,
    'dataset': {'src': None, 'name': None},
    'model': {
        'name': None,
        'input_seq_length': 6,
        'num_mp_steps': 10,
        'num_mlp_layers': 2,
        'latent_dim': 128,
        'magnitude_features': False,
        'isotropic_norm': False,
        'lmax_attributes': 1,
        'lmax_hidden': 1,
        'segnn_norm': 'none',
        'velocity_aggregate': 'avg'
    },
    'train': {
        'batch_size': 1,
        'step_max': 500000,
        'num_workers': 4,
        'noise_std': 0.0003,
        'optimizer': {
            'lr_start': 0.0001,
            'lr_final': 1e-06,
            'lr_decay_rate': 0.1,
            'lr_decay_steps': 100000.0
        },
        'pushforward': {
            'steps': [-1, 20000, 300000, 400000],
            'unrolls': [0, 1, 2, 3],
            'probs': [18, 2, 1, 1]
        },
        'loss_weight': {'acc': 1.0, 'vel': 0.0, 'pos': 0.0}
    },
    'eval': {
        'n_rollout_steps': 20,
        'test': False,
        'rollout_dir': None,
        'train': {
            'n_trajs': 50,
            'metrics_stride': 10,
            'batch_size': 1,
            'metrics': ['mse'],
            'out_type': 'none'
        },
        'infer': {
            'n_trajs': -1,
            'metrics_stride': 1,
            'batch_size': 2,
            'metrics': ['mse', 'e_kin', 'sinkhorn'],
            'out_type': 'pkl',
            'n_extrap_steps': 0
        }
    },
    'logging': {
        'log_steps': 1000,
        'eval_steps': 10000,
        'wandb': False,
        'wandb_project': None,
        'wandb_entity': 'lagrangebench',
        'ckp_dir': 'ckp',
        'run_name': None
    },
    'neighbors': {'backend': 'jaxmd_vmap', 'multiplier': 1.25}
}) → DictConfig

Set default lagrangebench configs.

lagrangebench.defaults.check_cfg(cfg: DictConfig)

Check if the configs are valid.