from abc import ABC, abstractmethod
from typing import Dict, Tuple
import haiku as hk
import jax.numpy as jnp
[docs]
class BaseModel(hk.Module, ABC):
"""Base model class. All models must inherit from this class."""
[docs]
@abstractmethod
def __call__(
self, sample: Tuple[Dict[str, jnp.ndarray], jnp.ndarray]
) -> Dict[str, jnp.ndarray]:
"""Forward pass.
We specify the dimensions of the inputs and outputs using the number of nodes N,
the number of edges E, number of historic velocities K (=input_seq_length - 1),
and the dimensionality of the feature vectors dim.
Args:
sample: Tuple with feature dictionary and particle type. Possible features
- "abs_pos" (N, K+1, dim), absolute positions
- "vel_hist" (N, K*dim), historical velocity sequence
- "vel_mag" (N,), velocity magnitudes
- "bound" (N, 2*dim), distance to boundaries
- "force" (N, dim), external force field
- "rel_disp" (E, dim), relative displacement vectors
- "rel_dist" (E, 1), relative distances, i.e. magnitude of displacements
- "senders" (E), sender indices
- "receivers" (E), receiver indices
Returns:
Dict with model output.
The keys must be at least one of the following:
- "acc" (N, dim), (normalized) acceleration
- "vel" (N, dim), (normalized) velocity
- "pos" (N, dim), (absolute) next position
"""
raise NotImplementedError