Case Setup

Case

Case setup functions.

class lagrangebench.case_setup.case.CaseSetupFn(allocate: Callable[[Array, Tuple[Array, Array], float, int], Tuple[Array, Dict[str, Array], Dict[str, Array], NeighborList]], preprocess: Callable[[Array, Tuple[Array, Array], float, NeighborList, int], Tuple[Array, Dict[str, Array], Dict[str, Array], NeighborList]], allocate_eval: Callable[[Tuple[Array, Array]], Tuple[Dict[str, Array], NeighborList]], preprocess_eval: Callable[[Tuple[Array, Array], NeighborList], Tuple[Dict[str, Array], NeighborList]], integrate: Callable[[Array, Array], Array], displacement: Callable[[Any, Any], Any], normalization_stats: Dict)

Dataclass that contains all functions required to set up the case and run the simulation.

allocate

AllocateFn, runs the preprocessing without taking a NeighborList as input and allocates a new neighbor list instead.

Type:

Callable[[jax.Array, Tuple[jax.Array, jax.Array], float, int], Tuple[jax.Array, Dict[str, jax.Array], Dict[str, jax.Array], jax_sph.jax_md.partition.NeighborList]]

preprocess

PreprocessFn, takes positions from the dataloader, computes velocities, adds random-walk noise if needed, updates the neighbor list, and returns the inputs to the neural network as well as the targets.

Type:

Callable[[jax.Array, Tuple[jax.Array, jax.Array], float, jax_sph.jax_md.partition.NeighborList, int], Tuple[jax.Array, Dict[str, jax.Array], Dict[str, jax.Array], jax_sph.jax_md.partition.NeighborList]]
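The velocity and noise part of this step can be sketched in NumPy as follows. It assumes positions of shape (num_particles, seq_len, dim), velocities taken as finite differences of consecutive frames, and a random-walk noise construction in the style of GNS where per-step velocity noise is accumulated along the sequence. The helper names `add_random_walk_noise` and `velocities_from_positions` are hypothetical, and this is not the library's exact code.

```python
import numpy as np


def add_random_walk_noise(positions, noise_std, rng):
    """Hypothetical sketch: perturb a position history with random-walk noise.

    positions: array of shape (num_particles, seq_len, dim).
    """
    num_particles, seq_len, dim = positions.shape
    num_velocities = seq_len - 1
    # Per-step velocity noise, scaled so that the accumulated noise on the
    # last velocity has standard deviation noise_std.
    step_std = noise_std / np.sqrt(num_velocities)
    vel_noise = rng.normal(0.0, step_std, size=(num_particles, num_velocities, dim))
    vel_noise = np.cumsum(vel_noise, axis=1)
    # Integrate the velocity noise into position noise; the first frame of
    # the sequence is left unperturbed.
    pos_noise = np.concatenate(
        [np.zeros_like(vel_noise[:, :1]), np.cumsum(vel_noise, axis=1)], axis=1
    )
    return positions + pos_noise


def velocities_from_positions(positions):
    # Finite differences of consecutive frames (time step folded into units).
    # In a periodic box this difference should go through the
    # boundary-aware displacement function instead of plain subtraction.
    return positions[:, 1:] - positions[:, :-1]
```

With `noise_std=0.0` the positions are returned unchanged, which matches the evaluation path where no noise is added.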

allocate_eval

AllocateEvalFn, same as allocate, but without noise addition and without targets.

Type:

Callable[[Tuple[jax.Array, jax.Array]], Tuple[Dict[str, jax.Array], jax_sph.jax_md.partition.NeighborList]]

preprocess_eval

PreprocessEvalFn, same as allocate_eval, but takes an existing NeighborList and updates it, which makes it jit-able.

Type:

Callable[[Tuple[jax.Array, jax.Array], jax_sph.jax_md.partition.NeighborList], Tuple[Dict[str, jax.Array], jax_sph.jax_md.partition.NeighborList]]

integrate

IntegrateFn, semi-implicit Euler integration step respecting all boundary conditions.

Type:

Callable[[jax.Array, jax.Array], jax.Array]
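A semi-implicit Euler step updates the velocity first and then moves the particles with the new velocity. The NumPy sketch below assumes a time step `dt` of 1 by default, a box anchored at the origin, and per-axis periodicity flags. The library's IntegrateFn takes two arrays and returns one; this hypothetical `semi_implicit_euler` spells out its inputs explicitly and is not that signature.

```python
import numpy as np


def semi_implicit_euler(position, velocity, acceleration, box, pbc, dt=1.0):
    """Hypothetical sketch of one semi-implicit (symplectic) Euler step.

    Coordinates along periodic axes are wrapped back into [0, box).
    """
    # 1) Update the velocity with the (denormalized) acceleration.
    new_velocity = velocity + dt * acceleration
    # 2) Move the particles with the *new* velocity.
    new_position = position + dt * new_velocity
    # 3) Wrap only the periodic axes; non-periodic axes are left as-is.
    box = np.asarray(box, dtype=float)
    pbc = np.asarray(pbc, dtype=bool)
    new_position = np.where(pbc, np.mod(new_position, box), new_position)
    return new_position, new_velocity
```

The ordering matters: an explicit Euler step would move the particles with the old velocity, so a particle at x = 0.9 with v = 0.1 and a = 0.05 would land at 1.0 instead of 1.05 before wrapping.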

displacement

space.DisplacementFn, displacement function aware of the boundary conditions (periodic or non-periodic).

Type:

Callable[[Any, Any], Any]
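For periodic boundaries, a displacement function typically applies the minimum-image convention, so that two particles on opposite sides of the box are treated as close. JAX-MD's `space.periodic` returns a displacement function of this kind; the NumPy code below is only an illustration of the idea, and `periodic_displacement` and `free_displacement` are hypothetical names.

```python
import numpy as np


def periodic_displacement(box):
    """Return a hypothetical minimum-image displacement function for `box`."""
    box = np.asarray(box, dtype=float)

    def displacement(ra, rb):
        dr = ra - rb
        # Map each component into [-box/2, box/2), i.e. the nearest image.
        return np.mod(dr + 0.5 * box, box) - 0.5 * box

    return displacement


def free_displacement(ra, rb):
    # Non-periodic case: plain difference of positions.
    return ra - rb
```

For a unit box, particles at x = 0.95 and x = 0.05 are 0.1 apart through the periodic boundary, not 0.9.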

normalization_stats

Dict, normalization statistics for the input velocities and output accelerations.

Type:

Dict
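Statistics like these are usually applied as a standard score: input velocities are normalized before entering the network, and predicted accelerations are denormalized before integration. The nested layout of `normalization_stats` below (keys "velocity", "acceleration", "mean", "std") and all numeric values are hypothetical placeholders; the real dictionary may use other keys.

```python
import numpy as np

# Hypothetical layout and values, for illustration only.
normalization_stats = {
    "velocity": {"mean": np.array([0.0, -0.01]), "std": np.array([0.02, 0.03])},
    "acceleration": {"mean": np.array([0.0, -0.001]), "std": np.array([0.004, 0.005])},
}


def normalize(x, stats):
    # Standard score per component: (x - mean) / std.
    return (x - stats["mean"]) / stats["std"]


def denormalize(x, stats):
    # Inverse of `normalize`, used on the network's predicted accelerations.
    return x * stats["std"] + stats["mean"]
```

A typical loop would call `normalize(v, normalization_stats["velocity"])` on the features and `denormalize(a_pred, normalization_stats["acceleration"])` before handing the acceleration to the integrator.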

__delattr__(name)

Implement delattr(self, name).

__eq__(other)

Return self==value.

__hash__()

Return hash(self).

__init__(allocate: Callable[[Array, Tuple[Array, Array], float, int], Tuple[Array, Dict[str, Array], Dict[str, Array], NeighborList]], preprocess: Callable[[Array, Tuple[Array, Array], float, NeighborList, int], Tuple[Array, Dict[str, Array], Dict[str, Array], NeighborList]], allocate_eval: Callable[[Tuple[Array, Array]], Tuple[Dict[str, Array], NeighborList]], preprocess_eval: Callable[[Tuple[Array, Array], NeighborList], Tuple[Dict[str, Array], NeighborList]], integrate: Callable[[Array, Array], Array], displacement: Callable[[Any, Any], Any], normalization_stats: Dict) → None

__setattr__(name, value)

Implement setattr(self, name, value).

lagrangebench.case_setup.case.case_builder(box: Tuple[float, float, float], metadata: Dict, input_seq_length: int, cfg_neighbors: Dict | DictConfig = {'backend': 'jaxmd_vmap', 'multiplier': 1.25}, cfg_model: Dict | DictConfig = {'name': None, 'input_seq_length': 6, 'num_mp_steps': 10, 'num_mlp_layers': 2, 'latent_dim': 128, 'magnitude_features': False, 'isotropic_norm': False, 'lmax_attributes': 1, 'lmax_hidden': 1, 'segnn_norm': 'none', 'velocity_aggregate': 'avg'}, noise_std: float = 0.0003, external_force_fn: Callable | None = None, dtype: dtype = 'float64')

Set up a CaseSetupFn that contains every required function besides the model.

Inspired by the partition.neighbor_list function in JAX-MD.

The core functions are:
  • allocate, allocate memory for the neighbor list.

  • preprocess, update the neighbor list.

  • integrate, semi-implicit Euler respecting periodic boundary conditions.
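The allocate/preprocess split follows the fixed-capacity pattern of JAX-MD neighbor lists: allocation inspects the data once and sizes a buffer with some slack (JAX-MD's `partition.neighbor_list` exposes this as `capacity_multiplier`), while the update refills a buffer of the same shape and raises an overflow flag when it is too small. The brute-force NumPy sketch below assumes that the `multiplier` entry of `cfg_neighbors` plays this slack role; `allocate`, `update`, and `_neighbor_mask` here are hypothetical and not the library's functions.

```python
import math

import numpy as np


def _neighbor_mask(positions, radius, displacement):
    # Brute-force O(N^2) search; real backends use cell lists instead.
    dr = displacement(positions[:, None, :], positions[None, :, :])
    dist = np.linalg.norm(dr, axis=-1)
    return (dist < radius) & ~np.eye(len(positions), dtype=bool)


def update(positions, radius, displacement, capacity):
    """Refill a fixed-size (num_particles, capacity) index buffer.

    The buffer shape never changes, which is what makes this step jit-friendly.
    Unused slots hold num_particles as a padding index.
    """
    mask = _neighbor_mask(positions, radius, displacement)
    n = len(positions)
    idx = np.full((n, capacity), n, dtype=int)
    for i in range(n):
        nbrs = np.flatnonzero(mask[i])[:capacity]
        idx[i, : len(nbrs)] = nbrs
    # Flag whether some particle has more neighbors than the buffer holds.
    overflow = bool(mask.sum(axis=1).max() > capacity)
    return idx, overflow


def allocate(positions, radius, displacement, multiplier=1.25):
    """Size the buffer from the current maximum neighbor count plus slack."""
    mask = _neighbor_mask(positions, radius, displacement)
    capacity = math.ceil(multiplier * int(mask.sum(axis=1).max()))
    idx, _ = update(positions, radius, displacement, capacity)
    return idx, capacity
```

When `update` reports an overflow, the caller is expected to call `allocate` again with the current positions, which is the only step whose output shape depends on the data.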

Parameters:
  • box – Size of the simulation box along each axis (x, y, z).

  • metadata – Dataset metadata dictionary.

  • input_seq_length – Length of the input sequence.

  • cfg_neighbors – Configuration dictionary for the neighbor list.

  • cfg_model – Configuration dictionary for the model / feature builder.

  • noise_std – Standard deviation of the random-walk noise added during preprocessing.

  • external_force_fn – External force function.

  • dtype – Data type.

Featurizer

Feature extraction utilities.

lagrangebench.case_setup.features.physical_feature_builder(bounds: list, normalization_stats: dict, connectivity_radius: float, displacement_fn: Callable, pbc: List[bool], magnitude_features: bool = False, external_force_fn: Callable | None = None) → Callable

Build a physical feature transform function.

Transforms raw coordinates into the following features:
  • Absolute positions

  • Historical velocity sequence

  • Velocity magnitudes

  • Distance to boundaries

  • External force field

  • Relative displacement vectors and distances
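Two of these features can be sketched for a graph given as sender/receiver index arrays: per-edge relative displacement vectors with their lengths, and per-frame velocity magnitudes. Scaling the displacements by the connectivity radius follows the GNS convention and is an assumption here, as is the receiver-minus-sender sign convention; `edge_features` and `velocity_magnitudes` are hypothetical names.

```python
import numpy as np


def edge_features(positions, senders, receivers, displacement, connectivity_radius):
    # Displacement from each sender to its receiver, scaled by the radius so
    # that neighbors within the cutoff have distances in [0, 1).
    rel = displacement(positions[receivers], positions[senders]) / connectivity_radius
    dist = np.linalg.norm(rel, axis=-1, keepdims=True)
    return rel, dist


def velocity_magnitudes(velocities):
    # velocities: (num_particles, seq_len, dim) -> (num_particles, seq_len)
    return np.linalg.norm(velocities, axis=-1)
```

Passing a boundary-aware displacement function (for example a minimum-image one) keeps edges that cross a periodic boundary short.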

Parameters:
  • bounds – Each sublist contains the lower and upper bound of a dimension.

  • normalization_stats – Dict containing the mean and standard deviation of velocities and targets.

  • connectivity_radius – Radius of the connectivity graph.

  • displacement_fn – Displacement function.

  • pbc – Whether to use periodic boundary conditions, given as one flag per dimension.

  • magnitude_features – Whether to include the magnitude of the velocity.

  • external_force_fn – Function that returns the external force field (optional).
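The distance-to-boundaries feature can be sketched from `bounds`, `connectivity_radius`, and `pbc`: distances to the lower and upper wall of each non-periodic dimension, expressed in units of the connectivity radius and clipped so that walls farther away than the radius all look the same. The clipping to [0, 1] mirrors the GNS recipe and is an assumption; `boundary_distance_features` is a hypothetical name, not the library's implementation.

```python
import numpy as np


def boundary_distance_features(positions, bounds, connectivity_radius, pbc):
    """Hypothetical sketch of clipped distances to the box walls."""
    bounds = np.asarray(bounds, dtype=float)  # shape (dim, 2): [lower, upper]
    lower, upper = bounds[:, 0], bounds[:, 1]
    # Distances to the lower and upper wall, in units of the radius, clipped.
    d_lower = np.clip((positions - lower) / connectivity_radius, 0.0, 1.0)
    d_upper = np.clip((upper - positions) / connectivity_radius, 0.0, 1.0)
    # Periodic dimensions have no walls, so only non-periodic ones are kept.
    keep = ~np.asarray(pbc, dtype=bool)
    return np.concatenate([d_lower[:, keep], d_upper[:, keep]], axis=-1)
```

For a unit box with radius 0.1, a particle at x = 0.05 gets 0.5 for the lower wall and the clipped value 1.0 for the distant upper wall.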