"""Training tricks and strategies, currently: random-walk noise and push forward."""
from typing import Tuple
import jax
import jax.numpy as jnp
from jax_sph.jax_md.partition import space
from lagrangebench.utils import get_kinematic_mask
[docs]
def add_gns_noise(
    key: jax.Array,
    pos_input: jnp.ndarray,
    particle_type: jnp.ndarray,
    input_seq_length: int,
    noise_std: float,
    shift_fn: space.ShiftFn,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    r"""Inject GNS-style random-walk noise into a position sequence.

    Follows the scheme described by
    `Sanchez-Gonzalez et al. <https://arxiv.org/abs/2002.09405>`_.

    Random-walk noise is applied to the input positions and the targets are
    adjusted accordingly to keep the trajectory consistent. Independent samples
    are drawn from :math:`\mathcal{N}^{(t)}(0, \sigma_v^{(t)})` for each input
    state, accumulated as a random walk and added to the velocity sequence.
    Each :math:`\sigma_v^{(t)}` is set so that the last step of the random walk
    has :math:`\sigma_v^{(input\_seq\_length)}=noise\_std`. Based on the noised
    velocities, positions are adjusted such that
    :math:`\dot{p}^{t_k} = p^{t_k} - p^{t_{k-1}}`.

    Args:
        key: Random key.
        pos_input: Clean input positions. Shape:
            (num_particles_max, input_seq_length + pushforward["unrolls"][-1] + 1, dim)
        particle_type: Particle type vector. Shape: (num_particles_max,)
        input_seq_length: Input sequence length, as in the configs.
        noise_std: Noise standard deviation at the last sequence step.
        shift_fn: Shift function. It is mapped over the first two axes of
            ``pos_input`` and of the noise, i.e. called per particle and step.

    Returns:
        Tuple ``(key, pos_input_noisy)``: the random key handed back by the
        noise sampler and the positions after ``shift_fn`` has applied the
        noise. Every position past the first ``input_seq_length`` steps is
        shifted by the noise of the last input step.
    """
    # Random-walk noise is sampled for the input window only.
    key, input_noise = _get_random_walk_noise_for_pos_sequence(
        key, pos_input[:, :input_seq_length], noise_std_last_step=noise_std
    )

    # Particles flagged by the kinematic mask receive no noise at all.
    is_kinematic = get_kinematic_mask(particle_type)[:, None, None]
    input_noise = jnp.where(is_kinematic, 0.0, input_noise)

    # All remaining (potential target) steps inherit the noise of the last
    # input step, so the targets stay consistent with the perturbed inputs.
    n_targets = pos_input[:, input_seq_length:].shape[1]
    target_noise = jnp.repeat(input_noise[:, -1:, :], n_targets, axis=1)
    full_noise = jnp.concatenate((input_noise, target_noise), axis=1)

    # Apply shift_fn element-wise over the particle and the time axis.
    shift_all = jax.vmap(jax.vmap(shift_fn, in_axes=(0, 0)), in_axes=(0, 0))
    return key, shift_all(pos_input, full_noise)
def _get_random_walk_noise_for_pos_sequence(
    key: jax.Array, position_sequence: jnp.ndarray, noise_std_last_step: float
) -> Tuple[jax.Array, jnp.ndarray]:
    """Sample random-walk noise for a sequence of positions.

    I.i.d. normal samples are drawn per velocity step, scaled by
    ``noise_std_last_step / sqrt(n_velocities)`` and accumulated along axis 1,
    so that the accumulated velocity noise at the last step has standard
    deviation ``noise_std_last_step``. The position noise is the running sum
    of that velocity noise, with a zero entry prepended for the first position.

    Args:
        key: Random key. It is split once; one half is consumed for sampling.
        position_sequence: Array whose axis 1 is treated as the time axis.
            Only its shape is used.
        noise_std_last_step: Standard deviation of the accumulated velocity
            noise at the last step.

    Returns:
        Tuple ``(key, position_sequence_noise)`` where the noise has the same
        shape as ``position_sequence`` and is zero at the first time index.
    """
    key, subkey = jax.random.split(key)

    velocity_sequence_shape = list(position_sequence.shape)
    velocity_sequence_shape[1] -= 1
    n_velocities = velocity_sequence_shape[1]

    # With fewer than two positions there is no velocity to perturb. Without
    # this guard the scale below divides by zero (or, for an array-valued std,
    # an output with the wrong time length is produced). jnp.zeros uses the
    # same default float dtype as jax.random.normal.
    if n_velocities < 1:
        return key, jnp.zeros(position_sequence.shape)

    velocity_sequence_noise = jax.random.normal(
        subkey, shape=tuple(velocity_sequence_shape)
    )
    # Per-step std chosen such that the sum over n_velocities steps has the
    # requested std at the last step.
    velocity_sequence_noise *= noise_std_last_step / (n_velocities**0.5)

    # Random walk in velocity space.
    velocity_sequence_noise = jnp.cumsum(velocity_sequence_noise, axis=1)

    # Integrate the velocity noise to positions; the first position is clean.
    position_sequence_noise = jnp.concatenate(
        [
            jnp.zeros_like(velocity_sequence_noise[:, 0:1]),
            jnp.cumsum(velocity_sequence_noise, axis=1),
        ],
        axis=1,
    )
    return key, position_sequence_noise
[docs]
def push_forward_sample_steps(key, step, pushforward):
    """Sample the number of unroll steps based on the current training step and the
    specified pushforward configuration.

    The thresholds in ``pushforward.steps`` that lie strictly below ``step``
    are counted; the unroll count is then drawn from the first that many
    entries of ``pushforward.unrolls``, weighted by the matching entries of
    ``pushforward.probs``.

    Args:
        key: Random key
        step: Current training step
        pushforward: Pushforward configuration exposing the sequences
            ``steps`` (non-decreasing thresholds), ``unrolls`` and ``probs``.

    Returns:
        Tuple ``(key, unroll_steps)`` with the new random key and the sampled
        number of unroll steps (a scalar array).

    Raises:
        ValueError: If ``pushforward.steps`` is not sorted in non-decreasing
            order, or if no threshold lies strictly below ``step`` so that
            there is nothing to sample from.
    """
    key, key_unroll = jax.random.split(key, 2)

    # steps needs to be an ordered list; check on the host-side values so the
    # validation neither disappears under `python -O` nor triggers one device
    # comparison per element.
    steps = list(pushforward.steps)
    if any(lower > upper for lower, upper in zip(steps, steps[1:])):
        raise ValueError(
            f"pushforward.steps must be sorted in non-decreasing order, got {steps}"
        )

    # until which index to sample from
    idx = int((step > jnp.array(steps)).sum())
    if idx == 0:
        raise ValueError(
            f"No pushforward threshold lies below step {step}; "
            f"pushforward.steps={steps} leaves no unroll option to sample from"
        )

    # NOTE(review): the truncated probs are passed on unchanged and need not
    # sum to one — confirm jax.random.choice treats them as relative weights.
    unroll_steps = jax.random.choice(
        key_unroll,
        a=jnp.array(pushforward.unrolls[:idx]),
        p=jnp.array(pushforward.probs[:idx]),
    )
    return key, unroll_steps
[docs]
def push_forward_build(model_apply, case):
    r"""Build the push forward function, introduced by
    `Brandstetter et al. <https://arxiv.org/abs/2202.03376>`_.

    Pushforward works by adding a stability "pushforward" loss term, in the form
    of an adversarial style loss.

    .. math::
        L_{pf} = \mathbb{E}_k \mathbb{E}_{u^{k+1} | u^k}
            \mathbb{E}_{\epsilon} \left[ \mathcal{L}(f(u^k + \epsilon), u^{k-1}) \right]

    where :math:`\epsilon` is :math:`u^k + \epsilon = f(u^{k-1})`, i.e. the 2-step
    unroll of the solver :math:`f` (from step :math:`k-1` to :math:`k`).
    The total loss is then :math:`L_{total}=\mathcal{L}(f(u^k), u^{k-1}) + L_{pf}`.
    Similarly, for :math:`S > 2` pushforward steps, :math:`L_{pf}` is extended to
    :math:`u^{k-S} \dots u^{k-1}` with cumulated :math:`\epsilon` perturbations.

    In practice, this is implemented by unrolling the solver for two steps, but
    only running gradients through the last unroll step. The function built
    here performs a single forward step of that unroll; it does not cut any
    gradients itself.

    Args:
        model_apply: Model apply function, called as
            ``model_apply(params, state, (features, particle_type))``.
        case: Case setup object; its ``integrate`` and ``preprocess_eval``
            methods are used.

    Returns:
        The jitted ``push_forward_fn``.
    """

    @jax.jit
    def push_forward_fn(features, current_pos, particle_type, neighbors, params, state):
        """Advance the position window by one model prediction.

        Args:
            features: Input features
            current_pos: Current position window; axis 1 is the time axis
            particle_type: Particle type vector
            neighbors: Neighbor list
            params: Model parameters
            state: Model state

        Returns:
            Tuple ``(current_pos, neighbors, features)``: the window with its
            oldest step dropped and the predicted step appended, plus the
            neighbors and features recomputed by ``case.preprocess_eval``.
        """
        # no buffer overflow check here, since push forward acts on later epochs
        prediction, _ = model_apply(params, state, (features, particle_type))
        predicted_pos = case.integrate(prediction, current_pos)

        # Slide the window: drop the oldest step, append the predicted one.
        shifted_pos = jnp.concatenate(
            (current_pos[:, 1:], predicted_pos[:, None, :]), axis=1
        )

        new_features, new_neighbors = case.preprocess_eval(
            (shifted_pos, particle_type), neighbors
        )
        return shifted_pos, new_neighbors, new_features

    return push_forward_fn