# Source code for lagrangebench.case_setup.case

"""Case setup functions."""

import warnings
from typing import Callable, Dict, Optional, Tuple, Union

import jax.numpy as jnp
from jax import Array, jit, lax, vmap
from jax_sph.jax_md import space
from jax_sph.jax_md.dataclasses import dataclass, static_field
from jax_sph.jax_md.partition import NeighborList, NeighborListFormat, neighbor_list
from omegaconf import DictConfig, OmegaConf

from lagrangebench.data.utils import get_dataset_stats
from lagrangebench.defaults import defaults
from lagrangebench.train.strats import add_gns_noise

from .features import FeatureDict, TargetDict, physical_feature_builder

# Result of preprocessing in "train" mode, in the order it is returned:
# (key, features, targets, neighbor list).
TrainCaseOut = Tuple[Array, FeatureDict, TargetDict, NeighborList]
# Result of preprocessing in "eval" mode: (features, neighbor list).
EvalCaseOut = Tuple[FeatureDict, NeighborList]
# One sample as consumed by the preprocessing: element 0 is read as the position
# sequence and element 1 as the particle types.
SampleIn = Tuple[jnp.ndarray, jnp.ndarray]

# Allocation entry points (build a fresh neighbor list).
# AllocateFn is called as (key, sample, noise_std, unroll_steps).
AllocateFn = Callable[[Array, SampleIn, float, int], TrainCaseOut]
# AllocateEvalFn is called as (sample,).
AllocateEvalFn = Callable[[SampleIn], EvalCaseOut]

# Update entry points (reuse an existing neighbor list).
# PreprocessFn is called as (key, sample, noise_std, neighbors, unroll_steps).
PreprocessFn = Callable[[Array, SampleIn, float, NeighborList, int], TrainCaseOut]
# PreprocessEvalFn is called as (sample, neighbors).
PreprocessEvalFn = Callable[[SampleIn, NeighborList], EvalCaseOut]

# Integrator called as (normalized_in, position_sequence) and returning positions.
# NOTE(review): the integrator built in `case_builder` looks up the keys
# "pos"/"vel"/"acc" on its first argument, i.e. it is used as a mapping rather
# than a plain array - confirm whether the first type here should be a dict.
IntegrateFn = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]


@dataclass
class CaseSetupFn:
    """Dataclass that contains all functions required to setup the case and simulate.

    All members are declared as static fields.

    Attributes:
        allocate: AllocateFn, runs the preprocessing without having a NeighborList as
            input.
        preprocess: PreprocessFn, takes positions from the dataloader, computes
            velocities, adds random-walk noise if needed, then updates the neighbor
            list, and returns the inputs to the neural network as well as the
            targets.
        allocate_eval: AllocateEvalFn, same as allocate, but without noise addition
            and without targets.
        preprocess_eval: PreprocessEvalFn, same as allocate_eval, but jit-able.
        integrate: IntegrateFn, semi-implicit Euler integration step respecting
            all boundary conditions.
        displacement: space.DisplacementFn, displacement function aware of boundary
            conditions (periodic or non-periodic).
        normalization_stats: Dict, normalization statistics for input velocities and
            output acceleration.
    """

    allocate: AllocateFn = static_field()
    preprocess: PreprocessFn = static_field()
    allocate_eval: AllocateEvalFn = static_field()
    preprocess_eval: PreprocessEvalFn = static_field()
    integrate: IntegrateFn = static_field()
    displacement: space.DisplacementFn = static_field()
    normalization_stats: Dict = static_field()
def case_builder(
    box: Tuple[float, float, float],
    metadata: Dict,
    input_seq_length: int,
    cfg_neighbors: Union[Dict, DictConfig] = defaults.neighbors,
    cfg_model: Union[Dict, DictConfig] = defaults.model,
    noise_std: float = defaults.train.noise_std,
    external_force_fn: Optional[Callable] = None,
    dtype: jnp.dtype = defaults.dtype,
):
    """Set up a CaseSetupFn that contains every required function besides the model.

    Inspired by the `partition.neighbor_list` function in JAX-MD.

    The core functions are:
        * allocate, allocate memory for the neighbors list.
        * preprocess, update the neighbors list.
        * integrate, semi-implicit Euler respecting periodic boundary conditions.

    Args:
        box: Box xyz sizes of the system.
        metadata: Dataset metadata dictionary. The keys read here are
            "periodic_boundary_conditions", "default_connectivity_radius",
            "num_particles_max" and "bounds"; the whole dictionary is also passed
            to `get_dataset_stats`.
        input_seq_length: Length of the input sequence.
        cfg_neighbors: Configuration dictionary for the neighbor list. Missing
            entries are filled in from `defaults.neighbors`.
        cfg_model: Configuration dictionary for the model / feature builder.
            Missing entries are filled in from `defaults.model`.
        noise_std: Noise standard deviation, used for the normalization statistics.
        external_force_fn: External force function.
        dtype: Data type the input positions are cast to.

    Returns:
        A CaseSetupFn holding, in this order, the functions `allocate`,
        `preprocess`, `allocate_eval`, `preprocess_eval`, `integrate`, the
        displacement function and the normalization statistics.
    """
    if isinstance(cfg_neighbors, dict):
        cfg_neighbors = OmegaConf.create(cfg_neighbors)
    if isinstance(cfg_model, dict):
        cfg_model = OmegaConf.create(cfg_model)

    # if one of the cfg_* arguments has a subset of the default configs, merge them
    cfg_neighbors = OmegaConf.merge(defaults.neighbors, cfg_neighbors)
    cfg_model = OmegaConf.merge(defaults.model, cfg_model)

    normalization_stats = get_dataset_stats(
        metadata, cfg_model.isotropic_norm, noise_std
    )

    # apply PBC in all directions or not at all
    if jnp.array(metadata["periodic_boundary_conditions"]).any():
        displacement_fn, shift_fn = space.periodic(side=jnp.array(box))
    else:
        displacement_fn, shift_fn = space.free()

    # pairwise (row-by-row) displacement between two equally sized position sets
    displacement_fn_set = vmap(displacement_fn, in_axes=(0, 0))

    if cfg_neighbors.multiplier < 1.25:
        warnings.warn(
            f"cfg_neighbors.multiplier={cfg_neighbors.multiplier} < 1.25 is very low. "
            "Be especially cautious if you batch training and/or inference as "
            "reallocation might be necessary based on different overflow conditions. "
            "See https://github.com/tumaer/lagrangebench/pull/20#discussion_r1443811262",
            # attribute the warning to the caller of case_builder
            stacklevel=2,
        )

    neighbor_fn = neighbor_list(
        displacement_fn,
        jnp.array(box),
        backend=cfg_neighbors.backend,
        r_cutoff=metadata["default_connectivity_radius"],
        capacity_multiplier=cfg_neighbors.multiplier,
        mask_self=False,
        format=NeighborListFormat.Sparse,
        num_particles_max=metadata["num_particles_max"],
        pbc=metadata["periodic_boundary_conditions"],
    )

    feature_transform = physical_feature_builder(
        bounds=metadata["bounds"],
        normalization_stats=normalization_stats,
        connectivity_radius=metadata["default_connectivity_radius"],
        displacement_fn=displacement_fn,
        pbc=metadata["periodic_boundary_conditions"],
        magnitude_features=cfg_model.magnitude_features,
        external_force_fn=external_force_fn,
    )

    def _compute_target(pos_input: jnp.ndarray) -> TargetDict:
        """Build normalized targets from three consecutive positions per particle.

        Indices 0, 1 and 2 along axis 1 are used for the finite differences; the
        last entry along axis 1 is returned as the target position.
        """
        # displacement(r1, r2) = r1-r2  # without PBC
        current_velocity = displacement_fn_set(pos_input[:, 1], pos_input[:, 0])
        next_velocity = displacement_fn_set(pos_input[:, 2], pos_input[:, 1])
        current_acceleration = next_velocity - current_velocity

        acc_stats = normalization_stats["acceleration"]
        normalized_acceleration = (
            current_acceleration - acc_stats["mean"]
        ) / acc_stats["std"]

        vel_stats = normalization_stats["velocity"]
        normalized_velocity = (next_velocity - vel_stats["mean"]) / vel_stats["std"]
        return {
            "acc": normalized_acceleration,
            "vel": normalized_velocity,
            "pos": pos_input[:, -1],
        }

    def _preprocess(
        sample: Tuple[jnp.ndarray, jnp.ndarray],
        neighbors: Optional[NeighborList] = None,
        is_allocate: bool = False,
        mode: str = "train",
        **kwargs,  # key, noise_std, unroll_steps
    ) -> Union[TrainCaseOut, EvalCaseOut]:
        """Shared implementation behind the allocate/preprocess functions.

        Args:
            sample: Tuple whose first entry is the position sequence and whose
                second entry is the particle types.
            neighbors: Existing neighbor list to update; unused if `is_allocate`.
            is_allocate: If True, allocate a new neighbor list instead of updating.
            mode: Either "train" or "eval".
            **kwargs: In "train" mode, must contain `key`, `noise_std` and
                `unroll_steps`.

        Returns:
            `(key, features, target_dict, neighbors)` in "train" mode and
            `(features, neighbors)` in "eval" mode.

        Raises:
            ValueError: If `mode` is neither "train" nor "eval".
        """
        # fail loudly instead of silently returning None for an unknown mode
        if mode not in ("train", "eval"):
            raise ValueError(
                f"Unsupported mode {mode!r}; expected 'train' or 'eval'."
            )

        pos_input = jnp.asarray(sample[0], dtype=dtype)
        particle_type = jnp.asarray(sample[1])

        if mode == "train":
            key, noise_std = kwargs["key"], kwargs["noise_std"]
            unroll_steps = kwargs["unroll_steps"]
            if pos_input.shape[1] > 1:
                key, pos_input = add_gns_noise(
                    key, pos_input, particle_type, input_seq_length, noise_std, shift_fn
                )

        # allocate the neighbor list
        most_recent_position = pos_input[:, input_seq_length - 1]
        # entries with particle type -1 are not counted as particles
        num_particles = (particle_type != -1).sum()
        if is_allocate:
            neighbors = neighbor_fn.allocate(
                most_recent_position, num_particles=num_particles
            )
        else:
            neighbors = neighbors.update(
                most_recent_position, num_particles=num_particles
            )

        # selected features
        features = feature_transform(pos_input[:, :input_seq_length], neighbors)

        if mode == "eval":
            return features, neighbors

        # compute target acceleration. Inverse of postprocessing step.
        # the "-2" is needed because we need the most recent position and one before
        slice_begin = (0, input_seq_length - 2 + unroll_steps, 0)
        slice_size = (pos_input.shape[0], 3, pos_input.shape[2])
        target_dict = _compute_target(
            lax.dynamic_slice(pos_input, slice_begin, slice_size)
        )
        return key, features, target_dict, neighbors

    def allocate_fn(key, sample, noise_std=0.0, unroll_steps=0):
        """Train-mode preprocessing that allocates a new neighbor list."""
        return _preprocess(
            sample,
            key=key,
            noise_std=noise_std,
            unroll_steps=unroll_steps,
            is_allocate=True,
        )

    @jit
    def preprocess_fn(key, sample, noise_std, neighbors, unroll_steps=0):
        """Train-mode preprocessing that updates the given neighbor list."""
        return _preprocess(
            sample, neighbors, key=key, noise_std=noise_std, unroll_steps=unroll_steps
        )

    def allocate_eval_fn(sample):
        """Eval-mode preprocessing that allocates a new neighbor list."""
        return _preprocess(sample, is_allocate=True, mode="eval")

    @jit
    def preprocess_eval_fn(sample, neighbors):
        """Eval-mode preprocessing that updates the given neighbor list."""
        return _preprocess(sample, neighbors, mode="eval")

    @jit
    def integrate_fn(normalized_in, position_sequence):
        """Euler integrator to get position shift.

        `normalized_in` is looked up by key with precedence "pos" > "vel" > "acc":
        a position is returned as is, a normalized velocity is de-normalized and
        applied to the most recent position, and a normalized acceleration is
        de-normalized and added to the most recent velocity first.
        """
        assert any(k in normalized_in for k in ("pos", "vel", "acc")), (
            "normalized_in must contain at least one of the keys "
            "'pos', 'vel' or 'acc'."
        )

        if "pos" in normalized_in:
            # Zeroth euler step
            return normalized_in["pos"]

        most_recent_position = position_sequence[:, -1]
        if "vel" in normalized_in:
            # invert normalization
            velocity_stats = normalization_stats["velocity"]
            new_velocity = velocity_stats["mean"] + (
                normalized_in["vel"] * velocity_stats["std"]
            )
        elif "acc" in normalized_in:
            # invert normalization.
            acceleration_stats = normalization_stats["acceleration"]
            acceleration = acceleration_stats["mean"] + (
                normalized_in["acc"] * acceleration_stats["std"]
            )
            # Second Euler step
            most_recent_velocity = displacement_fn_set(
                most_recent_position, position_sequence[:, -2]
            )
            new_velocity = most_recent_velocity + acceleration  # * dt = 1

        # First Euler step
        return shift_fn(most_recent_position, new_velocity)

    return CaseSetupFn(
        allocate_fn,
        preprocess_fn,
        allocate_eval_fn,
        preprocess_eval_fn,
        integrate_fn,
        displacement_fn,
        normalization_stats,
    )