Utils and Defaults
Utils
General utils and config structures.
- lagrangebench.utils.get_kinematic_mask(particle_type)
Return a boolean mask that is True for all kinematic (obstacle) particles.
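A minimal sketch of such a mask, assuming particle types are stored as an integer array and using a hypothetical encoding (FLUID = 0, SOLID_WALL = 1, MOVING_WALL = 2). The real type codes in lagrangebench may differ, and `kinematic_mask_sketch` is an illustrative helper, not the library function.

```python
import jax.numpy as jnp

# Hypothetical integer type codes; the actual lagrangebench encoding may differ.
FLUID = 0
SOLID_WALL = 1
MOVING_WALL = 2


def kinematic_mask_sketch(particle_type):
    """Return True where a particle is a (solid or moving) wall, else False."""
    particle_type = jnp.asarray(particle_type)
    return jnp.logical_or(particle_type == SOLID_WALL, particle_type == MOVING_WALL)


types = jnp.array([FLUID, SOLID_WALL, FLUID, MOVING_WALL, SOLID_WALL])
mask = kinematic_mask_sketch(types)
print(mask.tolist())  # [False, True, False, True, True]
```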
- lagrangebench.utils.broadcast_to_batch(sample, batch_size: int)
Broadcast a pytree to a batched one with first dimension batch_size.
- lagrangebench.utils.broadcast_from_batch(batch, index: int)
Extract the sample at position index from a batched pytree, i.e. index the leading (batch) dimension of every leaf.
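The two broadcast helpers are inverses in the sense that batching a sample and then selecting any index returns the original leaves. A minimal sketch of that round trip with `jax.tree_util.tree_map`, using hypothetical helper names (`broadcast_to_batch_sketch`, `broadcast_from_batch_sketch`) rather than the library's own code:

```python
import jax
import jax.numpy as jnp


def broadcast_to_batch_sketch(sample, batch_size: int):
    """Give every leaf a new leading axis of size batch_size (copies of the sample)."""
    assert batch_size > 0
    return jax.tree_util.tree_map(
        lambda x: jnp.broadcast_to(jnp.asarray(x), (batch_size,) + jnp.shape(x)),
        sample,
    )


def broadcast_from_batch_sketch(batch, index: int):
    """Select entry `index` along the leading (batch) axis of every leaf."""
    return jax.tree_util.tree_map(lambda x: x[index], batch)


sample = {"pos": jnp.zeros((100, 2)), "tag": jnp.arange(100)}
batch = broadcast_to_batch_sketch(sample, batch_size=4)
print(batch["pos"].shape)  # (4, 100, 2)
single = broadcast_from_batch_sketch(batch, index=2)
print(single["tag"].shape)  # (100,)
```

Because `tree_map` preserves the container structure, the same calls work for nested dicts, tuples, or any other registered pytree.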
- lagrangebench.utils.save_pytree(ckp_dir: str, pytree_obj, name) -> None
Save a pytree to a directory.
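The on-disk format used by save_pytree is not described here. As a rough sketch only, one way to persist a pytree is to copy it to host memory with `jax.device_get` and pickle it. The file name `<name>.pkl` and the helpers `save_pytree_sketch` / `load_pytree_sketch` are hypothetical and need not match the library's layout.

```python
import os
import pickle
import tempfile

import jax
import jax.numpy as jnp
import numpy as np


def save_pytree_sketch(ckp_dir: str, pytree_obj, name: str) -> None:
    """Pickle a host (NumPy) copy of the pytree to <ckp_dir>/<name>.pkl."""
    os.makedirs(ckp_dir, exist_ok=True)
    host_tree = jax.device_get(pytree_obj)  # device arrays -> NumPy arrays
    with open(os.path.join(ckp_dir, f"{name}.pkl"), "wb") as f:
        pickle.dump(host_tree, f)


def load_pytree_sketch(ckp_dir: str, name: str):
    """Read back a pytree written by save_pytree_sketch."""
    with open(os.path.join(ckp_dir, f"{name}.pkl"), "rb") as f:
        return pickle.load(f)


params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
with tempfile.TemporaryDirectory() as d:
    save_pytree_sketch(d, params, "params")
    restored = load_pytree_sketch(d, "params")
print(type(restored["w"]).__name__, restored["w"].shape)  # ndarray (2, 3)
```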
- lagrangebench.utils.save_haiku(ckp_dir: str, params, state, opt_state, metadata_ckp) -> None
Save the params, state, and optimizer state to ckp_dir.
It also tracks the best model so far and saves it to ckp_dir/best.
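The "best model" bookkeeping can be pictured as: always write the latest checkpoint, and additionally overwrite ckp_dir/best whenever the new checkpoint improves on the stored one. A minimal stdlib sketch, assuming (hypothetically) that metadata_ckp is a JSON-serializable dict with a "loss" entry where lower is better; the file names and helper names are illustrative, not the library's.

```python
import json
import os
import pickle
import tempfile


def _write_checkpoint(ckp_dir, params, state, opt_state, metadata_ckp):
    """Write one complete checkpoint (hypothetical file names) into ckp_dir."""
    os.makedirs(ckp_dir, exist_ok=True)
    for name, obj in (("params", params), ("state", state), ("opt_state", opt_state)):
        with open(os.path.join(ckp_dir, f"{name}.pkl"), "wb") as f:
            pickle.dump(obj, f)
    with open(os.path.join(ckp_dir, "metadata_ckp.json"), "w") as f:
        json.dump(metadata_ckp, f)


def save_haiku_sketch(ckp_dir, params, state, opt_state, metadata_ckp) -> bool:
    """Save the latest checkpoint and, if its assumed metadata_ckp["loss"] beats
    the stored best, also save it to ckp_dir/best. Returns True if best changed."""
    _write_checkpoint(ckp_dir, params, state, opt_state, metadata_ckp)

    best_dir = os.path.join(ckp_dir, "best")
    best_meta = os.path.join(best_dir, "metadata_ckp.json")
    if os.path.exists(best_meta):
        with open(best_meta) as f:
            best_loss = json.load(f)["loss"]
        if metadata_ckp["loss"] >= best_loss:
            return False  # no improvement: keep the previous best
    _write_checkpoint(best_dir, params, state, opt_state, metadata_ckp)
    return True


with tempfile.TemporaryDirectory() as d:
    print(save_haiku_sketch(d, {"w": 1.0}, {}, {}, {"step": 1000, "loss": 0.5}))  # True
    print(save_haiku_sketch(d, {"w": 2.0}, {}, {}, {"step": 2000, "loss": 0.7}))  # False
    print(save_haiku_sketch(d, {"w": 3.0}, {}, {}, {"step": 3000, "loss": 0.3}))  # True
```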
Defaults
Default lagrangebench configs.
- lagrangebench.defaults.set_defaults(cfg: DictConfig = <default cfg, listed below>) -> DictConfig
Set default lagrangebench configs.
Default value of cfg:
    config: None
    load_ckp: None
    mode: 'all'
    seed: 0
    dtype: 'float64'
    gpu: None
    xla_mem_fraction: None
    dataset:
        src: None
        name: None
    model:
        name: None
        input_seq_length: 6
        num_mp_steps: 10
        num_mlp_layers: 2
        latent_dim: 128
        magnitude_features: False
        isotropic_norm: False
        lmax_attributes: 1
        lmax_hidden: 1
        segnn_norm: 'none'
        velocity_aggregate: 'avg'
    train:
        batch_size: 1
        step_max: 500000
        num_workers: 4
        noise_std: 0.0003
        optimizer:
            lr_start: 0.0001
            lr_final: 1e-06
            lr_decay_rate: 0.1
            lr_decay_steps: 100000.0
        pushforward:
            steps: [-1, 20000, 300000, 400000]
            unrolls: [0, 1, 2, 3]
            probs: [18, 2, 1, 1]
        loss_weight:
            acc: 1.0
            vel: 0.0
            pos: 0.0
    eval:
        n_rollout_steps: 20
        test: False
        rollout_dir: None
        train:
            n_trajs: 50
            metrics_stride: 10
            batch_size: 1
            metrics: ['mse']
            out_type: 'none'
        infer:
            n_trajs: -1
            metrics_stride: 1
            batch_size: 2
            metrics: ['mse', 'e_kin', 'sinkhorn']
            out_type: 'pkl'
            n_extrap_steps: 0
    logging:
        log_steps: 1000
        eval_steps: 10000
        wandb: False
        wandb_project: None
        wandb_entity: 'lagrangebench'
        ckp_dir: 'ckp'
        run_name: None
    neighbors:
        backend: 'jaxmd_vmap'
        multiplier: 1.25
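A user config typically sets only a few fields and relies on these defaults for the rest. The sketch below illustrates that overlay behaviour on plain dicts with a hypothetical stdlib helper, `merge_defaults`, as a stand-in for OmegaConf-style merging (e.g. `OmegaConf.merge`); it uses a small subset of the default values above and makes no claim about how set_defaults works internally. The model name "gns" is only an example value.

```python
import copy


def merge_defaults(defaults: dict, overrides: dict) -> dict:
    """Recursively overlay `overrides` on a deep copy of `defaults`."""
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_defaults(merged[key], value)  # descend into sub-configs
        else:
            merged[key] = copy.deepcopy(value)  # leaf or new key: override wins
    return merged


# A few of the documented defaults (subset only).
defaults = {
    "seed": 0,
    "model": {"name": None, "latent_dim": 128, "num_mp_steps": 10},
    "train": {"batch_size": 1, "optimizer": {"lr_start": 0.0001, "lr_final": 1e-06}},
}
user_cfg = {"model": {"name": "gns"}, "train": {"optimizer": {"lr_start": 5e-4}}}

cfg = merge_defaults(defaults, user_cfg)
print(cfg["model"])  # {'name': 'gns', 'latent_dim': 128, 'num_mp_steps': 10}
print(cfg["train"]["optimizer"])  # {'lr_start': 0.0005, 'lr_final': 1e-06}
```

Deep-copying the defaults keeps the shared default dict unmodified, so it can be reused for the next run.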