Case Setup
Case
Case setup functions.
- class lagrangebench.case_setup.case.CaseSetupFn(allocate: Callable[[Array, Tuple[Array, Array], float, int], Tuple[Array, Dict[str, Array], Dict[str, Array], NeighborList]], preprocess: Callable[[Array, Tuple[Array, Array], float, NeighborList, int], Tuple[Array, Dict[str, Array], Dict[str, Array], NeighborList]], allocate_eval: Callable[[Tuple[Array, Array]], Tuple[Dict[str, Array], NeighborList]], preprocess_eval: Callable[[Tuple[Array, Array], NeighborList], Tuple[Dict[str, Array], NeighborList]], integrate: Callable[[Array, Array], Array], displacement: Callable[[Any, Any], Any], normalization_stats: Dict)[source]
Dataclass that contains all functions required to set up the case and simulate.
- allocate
AllocateFn, runs the preprocessing without having a NeighborList as input.
- Type:
Callable[[jax.Array, Tuple[jax.Array, jax.Array], float, int], Tuple[jax.Array, Dict[str, jax.Array], Dict[str, jax.Array], jax_sph.jax_md.partition.NeighborList]]
- preprocess
PreprocessFn, takes positions from the dataloader, computes velocities, adds random-walk noise if needed, then updates the neighbor list and returns the inputs to the neural network as well as the targets.
- Type:
Callable[[jax.Array, Tuple[jax.Array, jax.Array], float, jax_sph.jax_md.partition.NeighborList, int], Tuple[jax.Array, Dict[str, jax.Array], Dict[str, jax.Array], jax_sph.jax_md.partition.NeighborList]]
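The velocity and noise steps described for preprocess can be illustrated with a minimal NumPy sketch. It is not the library's implementation: the helper names, the array layout (n_steps, n_particles, dim), and the choice to scale the noise so that the accumulated velocity noise at the last step has standard deviation noise_std are all assumptions made for illustration. Periodic wrapping and the neighbor list update are left out.

```python
import numpy as np


def finite_difference_velocities(positions):
    """positions: (n_steps, n_particles, dim) -> (n_steps - 1, n_particles, dim)."""
    return positions[1:] - positions[:-1]


def add_random_walk_noise(positions, noise_std, rng):
    """Perturb a position sequence with noise that accumulates as a random walk.

    Assumption: the accumulated velocity noise at the last step has
    standard deviation noise_std.
    """
    n_vel = positions.shape[0] - 1
    # Per-step velocity noise, scaled so the summed walk ends at noise_std.
    step_noise = rng.normal(
        0.0, noise_std / np.sqrt(n_vel), size=(n_vel,) + positions.shape[1:]
    )
    velocity_noise = np.cumsum(step_noise, axis=0)
    # Integrate velocity noise into position noise; the first frame stays clean.
    position_noise = np.concatenate(
        [np.zeros_like(positions[:1]), np.cumsum(velocity_noise, axis=0)], axis=0
    )
    return positions + position_noise


rng = np.random.default_rng(0)
# 6 frames, 4 particles, 2D, moving 0.01 per frame along each axis.
positions = np.cumsum(np.full((6, 4, 2), 0.01), axis=0)
noisy = add_random_walk_noise(positions, noise_std=3e-4, rng=rng)
velocities = finite_difference_velocities(noisy)
```

With periodic boundaries, the plain subtraction in finite_difference_velocities would have to be replaced by a boundary-aware displacement so that particles crossing the box edge do not produce spurious jumps.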
- allocate_eval
AllocateEvalFn, same as allocate, but without noise addition and without targets.
- Type:
Callable[[Tuple[jax.Array, jax.Array]], Tuple[Dict[str, jax.Array], jax_sph.jax_md.partition.NeighborList]]
- preprocess_eval
PreprocessEvalFn, same as allocate_eval, but jit-able.
- Type:
Callable[[Tuple[jax.Array, jax.Array], jax_sph.jax_md.partition.NeighborList], Tuple[Dict[str, jax.Array], jax_sph.jax_md.partition.NeighborList]]
- integrate
IntegrateFn, semi-implicit Euler integration step respecting all boundary conditions.
- Type:
Callable[[jax.Array, jax.Array], jax.Array]
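For readers unfamiliar with the scheme, a semi-implicit Euler step updates the velocity first and then advances the position with the new velocity. The sketch below shows this in NumPy under stated assumptions: the function name, the explicit dt, box, and periodic arguments, and the separate velocity input are hypothetical and do not match the IntegrateFn signature above.

```python
import numpy as np


def semi_implicit_euler(position, velocity, acceleration, dt, box, periodic):
    """One semi-implicit Euler step (hypothetical helper, not the library API).

    The velocity is updated first; the new velocity then advances the position.
    """
    new_velocity = velocity + dt * acceleration
    new_position = position + dt * new_velocity
    # Wrap only the periodic dimensions back into [0, box).
    wrapped = np.mod(new_position, box)
    new_position = np.where(periodic, wrapped, new_position)
    return new_position, new_velocity


box = np.array([1.0, 1.0])
periodic = np.array([True, False])  # periodic in x, not in y
x = np.array([[0.95, 0.5]])
v = np.array([[0.5, 0.0]])
a = np.array([[5.0, 0.0]])
x_next, v_next = semi_implicit_euler(x, v, a, dt=0.1, box=box, periodic=periodic)
```

In this example the particle leaves the box through the periodic x boundary and re-enters near x = 0.05, while the y coordinate is left untouched.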
- displacement
space.DisplacementFn, displacement function aware of boundary conditions (periodic or non-periodic).
- Type:
Callable[[Any, Any], Any]
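A boundary-aware displacement function is commonly built with the minimum image convention: along periodic dimensions, the difference vector is shifted by whole box lengths until it is the shortest one. The NumPy sketch below illustrates the idea; make_displacement and its arguments are hypothetical names, and the sign convention r_a - r_b is an assumption.

```python
import numpy as np


def make_displacement(box, periodic):
    """Return a displacement function for a box with per-dimension periodicity."""
    box = np.asarray(box, dtype=float)
    periodic = np.asarray(periodic, dtype=bool)

    def displacement(r_a, r_b):
        """Vector r_a - r_b, using the shortest image along periodic dimensions."""
        dr = np.asarray(r_a, dtype=float) - np.asarray(r_b, dtype=float)
        # Minimum image convention: map each component into [-box/2, box/2).
        wrapped = np.mod(dr + 0.5 * box, box) - 0.5 * box
        return np.where(periodic, wrapped, dr)

    return displacement


displacement = make_displacement(box=[1.0, 1.0], periodic=[True, False])
dr = displacement([0.95, 0.95], [0.05, 0.05])
```

Here the x component is taken across the periodic boundary (about -0.1 instead of 0.9), while the non-periodic y component stays at the plain difference.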
- normalization_stats
Dict, normalization statistics for input velocities and output acceleration.
- Type:
Dict
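As an illustration of how such statistics are typically used, the sketch below normalizes input velocities and maps a normalized prediction back to physical accelerations. The dictionary layout and key names are hypothetical and may differ from the actual normalization_stats; the numbers are made up.

```python
import numpy as np

# Hypothetical layout; the real dictionary keys may differ.
normalization_stats = {
    "velocity": {"mean": np.array([0.0, 0.1]), "std": np.array([0.5, 0.2])},
    "acceleration": {"mean": np.array([0.0, 0.0]), "std": np.array([0.01, 0.02])},
}


def normalize(x, stats):
    """Shift and scale to zero mean and unit standard deviation."""
    return (x - stats["mean"]) / stats["std"]


def denormalize(x, stats):
    """Inverse of normalize."""
    return x * stats["std"] + stats["mean"]


velocity = np.array([[1.0, 0.5]])
normalized_velocity = normalize(velocity, normalization_stats["velocity"])

predicted = np.array([[2.0, -1.0]])  # stand-in for a normalized network output
acceleration = denormalize(predicted, normalization_stats["acceleration"])
```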
- __delattr__(name)
Implement delattr(self, name).
- __eq__(other)
Return self==value.
- __hash__()
Return hash(self).
- __init__(allocate: Callable[[Array, Tuple[Array, Array], float, int], Tuple[Array, Dict[str, Array], Dict[str, Array], NeighborList]], preprocess: Callable[[Array, Tuple[Array, Array], float, NeighborList, int], Tuple[Array, Dict[str, Array], Dict[str, Array], NeighborList]], allocate_eval: Callable[[Tuple[Array, Array]], Tuple[Dict[str, Array], NeighborList]], preprocess_eval: Callable[[Tuple[Array, Array], NeighborList], Tuple[Dict[str, Array], NeighborList]], integrate: Callable[[Array, Array], Array], displacement: Callable[[Any, Any], Any], normalization_stats: Dict) None
- __setattr__(name, value)
Implement setattr(self, name, value).
- lagrangebench.case_setup.case.case_builder(box: Tuple[float, float, float], metadata: Dict, input_seq_length: int, cfg_neighbors: Dict | DictConfig = {'backend': 'jaxmd_vmap', 'multiplier': 1.25}, cfg_model: Dict | DictConfig = {'name': None, 'input_seq_length': 6, 'num_mp_steps': 10, 'num_mlp_layers': 2, 'latent_dim': 128, 'magnitude_features': False, 'isotropic_norm': False, 'lmax_attributes': 1, 'lmax_hidden': 1, 'segnn_norm': 'none', 'velocity_aggregate': 'avg'}, noise_std: float = 0.0003, external_force_fn: Callable | None = None, dtype: dtype = 'float64')[source]
Set up a CaseSetupFn that contains every required function besides the model.
Inspired by the partition.neighbor_list function in JAX-MD.
- The core functions are:
allocate, allocate memory for the neighbor list.
preprocess, update the neighbor list.
integrate, semi-implicit Euler respecting periodic boundary conditions.
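The split between allocate and preprocess follows a common pattern for fixed-shape neighbor lists: one function sizes the buffer from the data, and a second function refills a buffer of that fixed size and reports overflow rather than resizing, which is what makes it suitable for jit compilation. The NumPy sketch below illustrates the pattern with a brute-force pair search. All function names are hypothetical, and the use of multiplier as a capacity safety factor is an assumption based on the cfg_neighbors default.

```python
import numpy as np


def _pairs_within(positions, cutoff):
    """All ordered pairs (i, j), i != j, closer than cutoff (brute force)."""
    diff = positions[:, None, :] - positions[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    mask = (dist < cutoff) & ~np.eye(len(positions), dtype=bool)
    return np.argwhere(mask)


def update(positions, cutoff, capacity):
    """Refill a buffer of fixed capacity; report overflow instead of resizing."""
    pairs = _pairs_within(positions, cutoff)
    overflow = len(pairs) > capacity
    # Pad unused slots with an invalid particle index.
    edges = np.full((capacity, 2), len(positions), dtype=int)
    n_kept = min(len(pairs), capacity)
    edges[:n_kept] = pairs[:n_kept]
    return edges, overflow


def allocate(positions, cutoff, multiplier=1.25):
    """Size the buffer from the data; the output shape depends on the input."""
    n_pairs = len(_pairs_within(positions, cutoff))
    capacity = int(np.ceil(multiplier * n_pairs))
    return update(positions, cutoff, capacity)


positions = np.array([[0.0, 0.0], [0.1, 0.0], [0.9, 0.9]])
edges, overflow = allocate(positions, cutoff=0.2)
capacity = len(edges)

# Move the third particle next to the others: more pairs than the buffer holds.
moved = np.array([[0.0, 0.0], [0.1, 0.0], [0.05, 0.05]])
_, overflow_after_move = update(moved, cutoff=0.2, capacity=capacity)
```

When the overflow flag is set, the usual remedy is to call the allocating function again outside the compiled code path so that the buffer grows to fit the new configuration.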
- Parameters:
box – Box xyz sizes of the system.
metadata – Dataset metadata dictionary.
input_seq_length – Length of the input sequence.
cfg_neighbors – Configuration dictionary for the neighbor list.
cfg_model – Configuration dictionary for the model / feature builder.
noise_std – Noise standard deviation.
external_force_fn – External force function.
dtype – Data type.
Featurizer
Feature extraction utilities.
- lagrangebench.case_setup.features.physical_feature_builder(bounds: list, normalization_stats: dict, connectivity_radius: float, displacement_fn: Callable, pbc: List[bool], magnitude_features: bool = False, external_force_fn: Callable | None = None) Callable[source]
Build a physical feature transform function.
- Transform raw coordinates to:
Absolute positions
Historical velocity sequence
Velocity magnitudes
Distance to boundaries
External force field
Relative displacement vectors and distances
- Parameters:
bounds – Each sublist contains the lower and upper bound of a dimension.
normalization_stats – Dict containing mean and std of velocities and targets.
connectivity_radius – Radius of the connectivity graph.
displacement_fn – Displacement function.
pbc – Whether to use periodic boundary conditions, one flag per dimension.
magnitude_features – Whether to include the magnitude of the velocity.
external_force_fn – Function that returns the external force field (optional).
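Several of the listed features can be derived from a position history with a few array operations. The NumPy sketch below builds per-particle features from the velocity history, optional velocity magnitudes, and clipped distances to the domain bounds. It is an illustration only: the function names, the array layout (n_particles, n_steps, dim), the scaling of boundary distances by connectivity_radius, and the clipping range are assumptions, and the external force and relative displacement features are omitted.

```python
import numpy as np


def boundary_distance_features(position, bounds, connectivity_radius):
    """Distances to lower and upper bounds, scaled by the radius, clipped to [-1, 1]."""
    bounds = np.asarray(bounds, dtype=float)  # shape (dim, 2)
    to_lower = position - bounds[:, 0]
    to_upper = bounds[:, 1] - position
    distances = np.concatenate([to_lower, to_upper], axis=-1)
    return np.clip(distances / connectivity_radius, -1.0, 1.0)


def node_features(position_sequence, bounds, connectivity_radius,
                  vel_mean, vel_std, magnitude_features=False):
    """position_sequence: (n_particles, n_steps, dim) -> (n_particles, n_features)."""
    # Historical velocity sequence from finite differences, then normalized.
    velocities = position_sequence[:, 1:] - position_sequence[:, :-1]
    normalized = (velocities - vel_mean) / vel_std
    features = [normalized.reshape(len(position_sequence), -1)]
    if magnitude_features:
        features.append(np.linalg.norm(normalized, axis=-1))
    # Boundary distances are computed from the most recent position.
    features.append(
        boundary_distance_features(
            position_sequence[:, -1], bounds, connectivity_radius
        )
    )
    return np.concatenate(features, axis=-1)


rng = np.random.default_rng(0)
position_sequence = rng.uniform(0.2, 0.8, size=(3, 6, 2))
bounds = [[0.0, 1.0], [0.0, 1.0]]
features = node_features(
    position_sequence, bounds, connectivity_radius=0.1,
    vel_mean=np.zeros(2), vel_std=np.ones(2), magnitude_features=True,
)
near_wall = boundary_distance_features(np.array([[0.05, 0.5]]), bounds, 0.1)
```

Clipping keeps the boundary feature informative only for particles within one connectivity radius of a wall; everywhere else it saturates at 1.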