Models
Base Class
- class lagrangebench.models.base.BaseModel(name: str | None = None)[source]
Base model class. All models must inherit from this class.
- abstract __call__(sample: Tuple[Dict[str, Array], Array]) Dict[str, Array][source]
Forward pass.
We specify the dimensions of the inputs and outputs using the number of nodes N, the number of edges E, number of historic velocities K (=input_seq_length - 1), and the dimensionality of the feature vectors dim.
- Parameters:
sample –
Tuple with feature dictionary and particle type. Possible features:
"abs_pos" (N, K+1, dim), absolute positions
"vel_hist" (N, K*dim), historical velocity sequence
"vel_mag" (N,), velocity magnitudes
"bound" (N, 2*dim), distance to boundaries
"force" (N, dim), external force field
"rel_disp" (E, dim), relative displacement vectors
"rel_dist" (E, 1), relative distances, i.e. magnitude of displacements
"senders" (E,), sender indices
"receivers" (E,), receiver indices
- Returns:
Dict with model output. The keys must be at least one of the following:
"acc" (N, dim), (normalized) acceleration
"vel" (N, dim), (normalized) velocity
"pos" (N, dim), (absolute) next position
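The shapes listed above can be made concrete with a small NumPy sketch. It only illustrates the input/output contract: `toy_model` is a hypothetical stand-in, not a class from the library, the feature values are placeholders, and the particle-type array is assumed to have shape (N,).

```python
import numpy as np

# N nodes, E edges, K historic velocities, dim spatial dimensions
N, E, K, dim = 5, 8, 3, 2
rng = np.random.default_rng(0)

rel_disp = rng.normal(size=(E, dim))
features = {
    "abs_pos": rng.normal(size=(N, K + 1, dim)),
    "vel_hist": rng.normal(size=(N, K * dim)),
    "vel_mag": np.abs(rng.normal(size=(N,))),  # placeholder magnitudes
    "bound": np.abs(rng.normal(size=(N, 2 * dim))),
    "force": rng.normal(size=(N, dim)),
    "rel_disp": rel_disp,
    "rel_dist": np.linalg.norm(rel_disp, axis=-1, keepdims=True),
    "senders": rng.integers(0, N, size=E),
    "receivers": rng.integers(0, N, size=E),
}
particle_type = np.zeros(N, dtype=np.int32)  # assumed shape (N,)
sample = (features, particle_type)


def toy_model(sample):
    """Hypothetical stand-in for a model's __call__: predicts zero acceleration."""
    feats, _ = sample
    n_nodes, _, d = feats["abs_pos"].shape
    return {"acc": np.zeros((n_nodes, d))}


out = toy_model(sample)
```

A real model would replace the body of `toy_model` and could return "vel" or "pos" instead of "acc".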
GNS
Graph Network-based Simulator. GNS model and feature transform.
- class lagrangebench.models.gns.GNS(particle_dimension: int, latent_size: int, blocks_per_step: int, num_mp_steps: int, particle_type_embedding_size: int, num_particle_types: int = NodeType.SIZE)[source]
Graph Network-based Simulator by Sanchez-Gonzalez et al.
GNS is the simplest graph neural network applied to particle dynamics. It is built on the usual Graph Network architecture, with an encoder, a processor, and a decoder.
\[\begin{split}\begin{align} \mathbf{m}_{ij}^{(t+1)} &= \phi \left( \mathbf{m}_{ij}^{(t)}, \mathbf{h}_i^{(t)}, \mathbf{h}_j^{(t)} \right) \\ \mathbf{h}_i^{(t+1)} &= \psi \left( \mathbf{h}_i^{(t)}, \sum_{j \in \mathcal{N}(i)} \mathbf{m}_{ij}^{(t+1)} \right) \\ \end{align}\end{split}\]
- __init__(particle_dimension: int, latent_size: int, blocks_per_step: int, num_mp_steps: int, particle_type_embedding_size: int, num_particle_types: int = NodeType.SIZE)[source]
Initialize the model.
- Parameters:
particle_dimension – Space dimensionality (e.g. 2 or 3).
latent_size – Size of the latent representations.
blocks_per_step – Number of MLP layers per block.
num_mp_steps – Number of message passing steps.
particle_type_embedding_size – Size of the particle type embedding.
num_particle_types – Max number of particle types.
- __call__(sample: Tuple[Dict[str, Array], Array]) Dict[str, Array][source]
Forward pass.
We specify the dimensions of the inputs and outputs using the number of nodes N, the number of edges E, number of historic velocities K (=input_seq_length - 1), and the dimensionality of the feature vectors dim.
- Parameters:
sample –
Tuple with feature dictionary and particle type. Possible features:
"abs_pos" (N, K+1, dim), absolute positions
"vel_hist" (N, K*dim), historical velocity sequence
"vel_mag" (N,), velocity magnitudes
"bound" (N, 2*dim), distance to boundaries
"force" (N, dim), external force field
"rel_disp" (E, dim), relative displacement vectors
"rel_dist" (E, 1), relative distances, i.e. magnitude of displacements
"senders" (E,), sender indices
"receivers" (E,), receiver indices
- Returns:
Dict with model output. The keys must be at least one of the following:
"acc" (N, dim), (normalized) acceleration
"vel" (N, dim), (normalized) velocity
"pos" (N, dim), (absolute) next position
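The GNS message and node update equations can be sketched in NumPy as a single message passing step. This is a minimal illustration under stated assumptions, not the library's implementation: \(\phi\) and \(\psi\) are replaced by hypothetical single-layer tanh networks (`phi`, `psi`), and any residual connections or normalization the real model may use are left out.

```python
import numpy as np

rng = np.random.default_rng(0)
N, E, L = 4, 3, 8  # nodes, edges, latent size

h = rng.normal(size=(N, L))      # node latents h_i
m = rng.normal(size=(E, L))      # edge latents m_ij
senders = np.array([0, 1, 2])    # index j of each edge
receivers = np.array([1, 2, 1])  # index i of each edge

# Hypothetical stand-ins for the learned functions phi and psi.
W_phi = rng.normal(size=(3 * L, L)) / np.sqrt(3 * L)
W_psi = rng.normal(size=(2 * L, L)) / np.sqrt(2 * L)


def phi(z):
    return np.tanh(z @ W_phi)


def psi(z):
    return np.tanh(z @ W_psi)


def message_passing_step(h, m, senders, receivers):
    # Edge update: m_ij <- phi(m_ij, h_i, h_j)
    m_new = phi(np.concatenate([m, h[receivers], h[senders]], axis=-1))
    # Sum incoming messages per receiving node i.
    agg = np.zeros_like(h)
    np.add.at(agg, receivers, m_new)
    # Node update: h_i <- psi(h_i, sum_j m_ij)
    h_new = psi(np.concatenate([h, agg], axis=-1))
    return h_new, m_new


h_new, m_new = message_passing_step(h, m, senders, receivers)
```

Nodes without incoming edges (here nodes 0 and 3) receive a zero aggregate, so their update depends only on their own latent.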
SEGNN
Steerable E(3) equivariant GNN from Brandstetter et al. SEGNN model, layers and feature transform.
Original implementation: https://github.com/RobDHess/Steerable-E3-GNN
Standalone implementation + validation: https://github.com/gerkone/segnn-jax
- class lagrangebench.models.segnn.O3TensorProduct(output_irreps: ~e3nn_jax._src.irreps.Irreps, *, biases: bool = True, name: str | None = None, init_fn: ~typing.Callable = <function uniform_init>, gradient_normalization: str | float = 'element', path_normalization: str | float = 'element')[source]
O(3) equivariant linear parametrized tensor product layer.
Applies a linear (parametrized) tensor product of representations to the input(s).
\[\begin{align} tp(x, y) := \mathbf{x} \otimes_{CG}^{\mathcal{W}} \mathbf{y} \end{align}\]
where \(\mathcal{W}\) are learnable parameters.
Uses tensor_product + Linear instead of FullyConnectedTensorProduct. From e3nn 0.19.2 (https://github.com/e3nn/e3nn-jax/releases/tag/0.19.2), this is as fast as FullyConnectedTensorProduct.
- __init__(output_irreps: ~e3nn_jax._src.irreps.Irreps, *, biases: bool = True, name: str | None = None, init_fn: ~typing.Callable = <function uniform_init>, gradient_normalization: str | float = 'element', path_normalization: str | float = 'element')[source]
Initialize the tensor product.
- Parameters:
output_irreps – Output representation
biases – If set to true, biases are added
name – Name of the linear layer params
init_fn – Weight initialization function. Default is uniform.
gradient_normalization – Gradient normalization method. Default is "element". Note: gradient_normalization="element" is the default in torch and haiku.
path_normalization – Path normalization method. Default is "element".
- __call__(x: IrrepsArray, y: IrrepsArray | None = None) IrrepsArray[source]
Applies an O(3) equivariant linear parametrized tensor product layer.
- Parameters:
x (IrrepsArray) – Left tensor
y (IrrepsArray) – Right tensor. If None it defaults to np.ones.
- Returns:
The output of the weighted tensor product (IrrepsArray).
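To give some intuition for a weighted tensor product, the following NumPy sketch treats the special case of two 3-vectors (l=1 irreps). It is an illustration only and does not use e3nn: the l=0 output path is the dot product and the l=1 path is the cross product, each with its own weight. The l=2 path and the Clebsch-Gordan normalization constants are omitted, and the function and weight names are hypothetical.

```python
import numpy as np


def weighted_tp_l1(x, y, w_scalar, w_vector):
    """Weighted product of two 3-vectors, restricted to the l=0 and l=1 outputs."""
    scalar = w_scalar * np.dot(x, y)    # l=1 x l=1 -> l=0
    vector = w_vector * np.cross(x, y)  # l=1 x l=1 -> l=1
    return scalar, vector


def rotation_z(angle):
    """Rotation matrix about the z axis."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


x = np.array([1.0, 2.0, 3.0])
y = np.array([-0.5, 0.4, 2.0])
R = rotation_z(0.7)

# Apply the product before and after rotating both inputs.
s0, v0 = weighted_tp_l1(x, y, w_scalar=0.3, w_vector=-1.2)
s1, v1 = weighted_tp_l1(R @ x, R @ y, w_scalar=0.3, w_vector=-1.2)
```

Rotating the inputs leaves the scalar output unchanged and rotates the vector output by the same matrix, which is the equivariance property the layer is built around. Under reflections the cross product picks up a sign, which is why parity is tracked separately in O(3) representations.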
- lagrangebench.models.segnn.O3TensorProductGate(output_irreps: Irreps, *, biases: bool = True, scalar_activation: Callable | None = None, gate_activation: Callable | None = None, name: str | None = None, init_fn: Callable | None = None) Callable[source]
Non-linear (gated) O(3) equivariant linear tensor product layer.
It applies a linear tensor product of representations to the input(s) and then a gated nonlinearity by Weiler et al.
The input representation is lifted to have gating scalars.
- Parameters:
output_irreps – Output representation
biases – Add biases
scalar_activation – Activation function for scalars
gate_activation – Activation function for the gates applied to higher-order features
name – Name of the linear layer params
- Returns:
Function that applies the gated tensor product layer.
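A gated nonlinearity can be sketched in NumPy for the case of scalar and vector features. This is a simplified illustration, not the library's code: the scalar activation (SiLU) and gate activation (sigmoid) are assumptions chosen for the example, since the defaults are not stated here, and `gated_nonlinearity` is a hypothetical name.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def gated_nonlinearity(scalars, gates, vectors):
    """scalars: (S,), gates: (V,), vectors: (V, 3).

    Scalars pass through a pointwise activation. Each vector is scaled by
    an activated gating scalar, which keeps its direction intact.
    """
    act_scalars = scalars * sigmoid(scalars)  # SiLU (assumed choice)
    gated_vectors = sigmoid(gates)[:, None] * vectors
    return act_scalars, gated_vectors


scalars = np.array([0.5, -1.0])
gates = np.array([2.0, -3.0, 0.0])
vectors = np.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [1.0, 1.0, 1.0]])

c, s = np.cos(0.4), np.sin(0.4)
R = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

s_a, v_a = gated_nonlinearity(scalars, gates, vectors)
s_b, v_b = gated_nonlinearity(scalars, gates, vectors @ R.T)
```

Because the gate only rescales each vector, applying the nonlinearity commutes with rotating the vectors, whereas a pointwise activation on vector components would break equivariance.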
- lagrangebench.models.segnn.O3Embedding(embed_irreps: Irreps, embed_edges: bool = True) Callable[source]
Linear steerable embedding.
Embeds the graph nodes in the representation space given by embed_irreps.
- Parameters:
embed_irreps – Output representation
embed_edges – If true also embed edges/message passing features
- Returns:
Function to embed graph nodes (and optionally edges)
- lagrangebench.models.segnn.O3Decoder(latent_irreps: Irreps, output_irreps: Irreps, n_blocks: int = 1)[source]
Steerable decoder.
- Parameters:
latent_irreps – Representation from the previous block
output_irreps – Output representation
n_blocks – Number of tensor product blocks in the decoder
- Returns:
Decoded latent feature space to output space.
- class lagrangebench.models.segnn.SEGNNLayer(output_irreps: ~e3nn_jax._src.irreps.Irreps, layer_idx: int, n_blocks: int = 2, norm: str | None = None, aggregate_fn: ~typing.Callable | None = <function segment_sum>)[source]
Steerable E(3) equivariant layer.
Applies a message passing step (GN) with equivariant message and update functions.
- __init__(output_irreps: ~e3nn_jax._src.irreps.Irreps, layer_idx: int, n_blocks: int = 2, norm: str | None = None, aggregate_fn: ~typing.Callable | None = <function segment_sum>)[source]
Initialize the layer.
- Parameters:
output_irreps – Layer output representation
layer_idx – Numbering of the layer
n_blocks – Number of tensor product blocks in the layer
norm – Normalization type. Either None, 'instance' or 'batch'
aggregate_fn – Message aggregation function. Defaults to sum.
- __call__(st_graph: SteerableGraphsTuple) SteerableGraphsTuple[source]
Perform a message passing step.
- Parameters:
st_graph – Input graph
- Returns:
The updated graph
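The default aggregation is a segment sum, which adds up all messages that share a receiver. A NumPy sketch of these semantics follows; `segment_sum_np` is a hypothetical helper written for this illustration and is not the function referenced in the signature.

```python
import numpy as np


def segment_sum_np(data, segment_ids, num_segments):
    """Sum rows of `data` that share the same segment id."""
    out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    # np.add.at accumulates correctly when ids repeat.
    np.add.at(out, segment_ids, data)
    return out


messages = np.array([[1.0, 0.0], [2.0, 1.0], [0.0, 3.0], [4.0, 4.0]])
receivers = np.array([0, 2, 2, 0])  # receiving node of each message
agg = segment_sum_np(messages, receivers, num_segments=3)
# agg[0] = messages[0] + messages[3], agg[1] = 0, agg[2] = messages[1] + messages[2]
```

Passing a different aggregate_fn (for example a mean) changes only this reduction step.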
- lagrangebench.models.segnn.weight_balanced_irreps(scalar_units: int, irreps_right: Irreps, lmax: int | None = None) Irreps[source]
Determine left irreps so that the tensor product with irreps_right has at least scalar_units weights.
- Parameters:
scalar_units – Number of weights
irreps_right – Right irreps
lmax – Maximum L of the left irreps
- Returns:
Left irreps
- class lagrangebench.models.segnn.SEGNN(node_features_irreps: Irreps, edge_features_irreps: Irreps, scalar_units: int, lmax_hidden: int, lmax_attributes: int, output_irreps: Irreps, num_mp_steps: int, n_vels: int, velocity_aggregate: str = 'avg', homogeneous_particles: bool = True, norm: str | None = None, blocks_per_step: int = 2, embed_msg_features: bool = False)[source]
Steerable E(3) equivariant network by Brandstetter et al.
SEGNNs are E(3)-equivariant graph neural networks based around tensor products of representations. By design, SEGNNs allow for flexible scalar/vectorial inputs and outputs for both edges and attributes. The message passing is modified as follows:
\[\begin{split}\begin{align} \mathbf{m}_{ij} &= \textit{M}_{\mathbf{\hat{a}}_{ij}}\left( \mathbf{f}_i, \mathbf{f}_j, \| x_i - x_j \|^2 \right), \\ \mathbf{f}^{\prime}_i &= \textit{U}_{\mathbf{\hat{a}}_i}\left( \mathbf{f}_i, \sum_{j\in\mathcal{N}(i)} \mathbf{m}_{ij} \right) \end{align}\end{split}\]
\(\mathbf{\hat{a}}_{ij}\) and \(\mathbf{\hat{a}}_{i}\) are edge and node attributes and the operators \(\textit{M}_{\mathbf{\hat{a}}_{ij}}\) and \(\textit{U}_{\mathbf{\hat{a}}_{i}}\) are defined as a tensor product of representations \(\otimes_{CG}\) between the input and the attributes:
\[\begin{align} \textit{U}_{\mathbf{\hat{a}}_i} := \mathcal{W^n}_{\mathbf{\hat{a}}_i} (\dots \sigma ( \mathcal{W^0}_{\mathbf{\hat{a}}_i} \mathbf{f} ) ) \quad \text{with} \quad \mathcal{W}_{\mathbf{\hat{a}}_i} \mathbf{f} := \mathbf{f} \otimes_{CG}^{\mathcal{W}} \mathbf{\hat{a}}_i \end{align}\]
where \(\mathbf{f} = [\mathbf{f}_i, \sum_{j\in\mathcal{N}(i)} \mathbf{m}_{ij}]\) are node features concatenated to the aggregated messages, \(\sigma\) is a gated non-linearity and \(\mathcal{W}\) are the tensor product parameters. \(\textit{M}_{\mathbf{\hat{a}}_{ij}}\) is similarly defined, but with the nonlinearity on the last layer, with edge attributes \(\mathbf{\hat{a}}_{ij}\) and \(\mathbf{f} = [ \mathbf{f}_i, \mathbf{f}_j, \| x_i - x_j \|^2 ]\).
- __init__(node_features_irreps: Irreps, edge_features_irreps: Irreps, scalar_units: int, lmax_hidden: int, lmax_attributes: int, output_irreps: Irreps, num_mp_steps: int, n_vels: int, velocity_aggregate: str = 'avg', homogeneous_particles: bool = True, norm: str | None = None, blocks_per_step: int = 2, embed_msg_features: bool = False)[source]
Initialize the network.
- Parameters:
node_features_irreps – Irreps of the node features.
edge_features_irreps – Irreps of the additional message passing features.
scalar_units – Hidden units (lower bound). Actual number depends on lmax.
lmax_hidden – Maximum L of the hidden layer representations.
lmax_attributes – Maximum L of the attributes.
output_irreps – Output representation.
num_mp_steps – Number of message passing layers
n_vels – Number of velocities in the history.
velocity_aggregate – Velocity sequence aggregation method.
homogeneous_particles – If all particles are of homogeneous type.
norm – Normalization type. Either None, ‘instance’ or ‘batch’
blocks_per_step – Number of tensor product blocks in each message passing step
embed_msg_features – Set to true to also embed edges/message passing features
- __call__(sample: Tuple[Dict[str, Array], Array]) Dict[str, Array][source]
Forward pass.
We specify the dimensions of the inputs and outputs using the number of nodes N, the number of edges E, number of historic velocities K (=input_seq_length - 1), and the dimensionality of the feature vectors dim.
- Parameters:
sample –
Tuple with feature dictionary and particle type. Possible features:
"abs_pos" (N, K+1, dim), absolute positions
"vel_hist" (N, K*dim), historical velocity sequence
"vel_mag" (N,), velocity magnitudes
"bound" (N, 2*dim), distance to boundaries
"force" (N, dim), external force field
"rel_disp" (E, dim), relative displacement vectors
"rel_dist" (E, 1), relative distances, i.e. magnitude of displacements
"senders" (E,), sender indices
"receivers" (E,), receiver indices
- Returns:
Dict with model output. The keys must be at least one of the following:
"acc" (N, dim), (normalized) acceleration
"vel" (N, dim), (normalized) velocity
"pos" (N, dim), (absolute) next position
EGNN
E(n) equivariant GNN from Garcia Satorras et al. EGNN model, layers and feature transform.
Original implementation: https://github.com/vgsatorras/egnn
Standalone implementation + validation: https://github.com/gerkone/egnn-jax
- class lagrangebench.models.egnn.EGNNLayer(layer_num: int, hidden_size: int, output_size: int, displacement_fn: ~typing.Callable[[~typing.Any, ~typing.Any], ~typing.Any], shift_fn: ~typing.Callable[[~typing.Any, ~typing.Any], ~typing.Any], blocks: int = 1, act_fn: ~typing.Callable = <PjitFunction of <function silu>>, pos_aggregate_fn: ~typing.Callable | None = <function segment_sum>, msg_aggregate_fn: ~typing.Callable | None = <function segment_sum>, residual: bool = True, attention: bool = False, normalize: bool = False, tanh: bool = False, dt: float = 0.001, eps: float = 1e-08)[source]
E(n)-equivariant EGNN layer.
Applies a message passing step where the positions are corrected with the velocities and a learnable correction term \(\psi_x(\mathbf{h}_i^{(t+1)})\).
- __init__(layer_num: int, hidden_size: int, output_size: int, displacement_fn: ~typing.Callable[[~typing.Any, ~typing.Any], ~typing.Any], shift_fn: ~typing.Callable[[~typing.Any, ~typing.Any], ~typing.Any], blocks: int = 1, act_fn: ~typing.Callable = <PjitFunction of <function silu>>, pos_aggregate_fn: ~typing.Callable | None = <function segment_sum>, msg_aggregate_fn: ~typing.Callable | None = <function segment_sum>, residual: bool = True, attention: bool = False, normalize: bool = False, tanh: bool = False, dt: float = 0.001, eps: float = 1e-08)[source]
Initialize the layer.
- Parameters:
layer_num – layer number
hidden_size – hidden size
output_size – output size
displacement_fn – Displacement function for the acceleration computation.
shift_fn – Shift function for updating positions
blocks – number of blocks in the node and edge MLPs
act_fn – activation function
pos_aggregate_fn – position aggregation function
msg_aggregate_fn – message aggregation function
residual – whether to use residual connections
attention – whether to use attention
normalize – whether to normalize the coordinates
tanh – whether to use tanh in the position update
dt – position update step size
eps – small number to avoid division by zero
- __call__(graph: GraphsTuple, pos: Array, vel: Array, edge_attribute: Array | None = None, node_attribute: Array | None = None) Tuple[GraphsTuple, Array][source]
Apply EGNN layer.
- Parameters:
graph – Graph from previous step
pos – Node position, updated separately
vel – Node velocity
edge_attribute – Edge attribute (optional)
node_attribute – Node attribute (optional)
- Returns:
Updated graph, node position
- class lagrangebench.models.egnn.EGNN(hidden_size: int, output_size: int, dt: float, n_vels: int, displacement_fn: ~typing.Callable[[~typing.Any, ~typing.Any], ~typing.Any], shift_fn: ~typing.Callable[[~typing.Any, ~typing.Any], ~typing.Any], normalization_stats: ~typing.Dict[str, ~jax.Array] | None = None, act_fn: ~typing.Callable = <PjitFunction of <function silu>>, num_mp_steps: int = 4, homogeneous_particles: bool = True, residual: bool = True, attention: bool = False, normalize: bool = False, tanh: bool = False)[source]
E(n) Graph Neural Network by Garcia Satorras et al.
EGNN doesn’t require expensive higher-order representations in intermediate layers; instead it relies on separate scalar and vector channels, which are treated differently by EGNN layers. In this setup, EGNN is similar to a learnable numerical integrator:
\[\begin{split}\begin{align} \mathbf{m}_{ij}^{(t+1)} &= \phi_e \left( \mathbf{m}_{ij}^{(t)}, \mathbf{h}_i^{(t)}, \mathbf{h}_j^{(t)}, ||\mathbf{x}_i^{(t)} - \mathbf{x}_j^{(t)}||^2 \right) \\ \mathbf{\hat{m}}_{ij}^{(t+1)} &= (\mathbf{x}_i^{(t)} - \mathbf{x}_j^{(t)}) \phi_x(\mathbf{m}_{ij}^{(t+1)}) \end{align}\end{split}\]
And the node update with the integrator:
\[\begin{split}\begin{align} \mathbf{h}_i^{(t+1)} &= \psi_h \left( \mathbf{h}_i^{(t)}, \sum_{j \in \mathcal{N}(i)} \mathbf{m}_{ij}^{(t+1)} \right) \\ \mathbf{x}_i^{(t+1)} &= \mathbf{x}_i^{(t)} + \mathbf{\hat{m}}_{ij}^{(t+1)} \psi_x(\mathbf{h}_i^{(t+1)}) \end{align}\end{split}\]
where \(\mathbf{m}_{ij}\) and \(\mathbf{\hat{m}}_{ij}\) are the scalar and vector messages respectively, and \(\mathbf{x}_{i}\) are the positions.
This implementation differs from the original in two places:
Because our datasets can have periodic boundary conditions, we use shift and displacement functions that account for them whenever positions are operated on.
We apply a simple integrator after the last layer to get the acceleration.
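For a fully periodic rectangular box, displacement and shift functions can be sketched as follows. This is a NumPy illustration under assumptions: the two-argument call conventions `displacement(xa, xb)` and `shift(x, dx)` and the helper names are hypothetical, and the minimum-image convention used here is one common way to handle periodicity, not necessarily what the library does internally.

```python
import numpy as np


def periodic_displacement(xa, xb, box):
    """Shortest vector from xb to xa under periodic boundaries (minimum image)."""
    d = xa - xb
    return d - box * np.round(d / box)


def periodic_shift(x, dx, box):
    """Move x by dx and wrap the result back into the box [0, box)."""
    return np.mod(x + dx, box)


box = np.array([1.0, 1.0])
xa = np.array([0.95, 0.5])
xb = np.array([0.05, 0.5])

# The particles are 0.1 apart across the boundary, not 0.9 apart.
d = periodic_displacement(xa, xb, box)

# Stepping past the right edge re-enters from the left.
x_next = periodic_shift(xa, np.array([0.1, 0.0]), box)
```

Using such functions wherever positions are subtracted or updated keeps neighbor distances and position updates consistent across the box boundary.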
- __init__(hidden_size: int, output_size: int, dt: float, n_vels: int, displacement_fn: ~typing.Callable[[~typing.Any, ~typing.Any], ~typing.Any], shift_fn: ~typing.Callable[[~typing.Any, ~typing.Any], ~typing.Any], normalization_stats: ~typing.Dict[str, ~jax.Array] | None = None, act_fn: ~typing.Callable = <PjitFunction of <function silu>>, num_mp_steps: int = 4, homogeneous_particles: bool = True, residual: bool = True, attention: bool = False, normalize: bool = False, tanh: bool = False)[source]
Initialize the network.
- Parameters:
hidden_size – Number of hidden features.
output_size – Number of features for ‘h’ at the output.
dt – Time step for position and velocity integration. Used to rescale the initialization of the correction MLP.
n_vels – Number of velocities in the history.
displacement_fn – Displacement function for the acceleration computation.
shift_fn – Shift function for updating positions.
normalization_stats – Normalization statistics for the input data.
act_fn – Non-linearity.
num_mp_steps – Number of layers for the EGNN.
homogeneous_particles – If all particles are of homogeneous type.
residual – Whether to use residual connections.
attention – Whether to use attention or not.
normalize – Normalizes the coordinate messages such that:
\[x^{l+1}_i = x^{l}_i + \sum (x_i - x_j)\phi_x(m_{ij}) / \|x_i - x_j\|\]
It may help stability or generalization. Not used in the paper.
tanh – Sets a tanh activation function at the output of \(\phi_x(m_{ij})\). It bounds the output of \(\phi_x(m_{ij})\), which improves stability but may decrease accuracy. Not used in the paper.
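The effect of the normalize and tanh options on the coordinate update can be sketched in NumPy. This is a simplified illustration, not the library's layer: `edge_scalars` is a hypothetical stand-in for the learned \(\phi_x(m_{ij})\), the velocity and dt terms are left out, and normalize is read here as dividing each difference vector by its norm.

```python
import numpy as np


def coordinate_update(x, edge_scalars, senders, receivers,
                      normalize=False, tanh=False, eps=1e-8):
    """x_i <- x_i + sum_j (x_i - x_j) * phi_x(m_ij), with optional variants."""
    diff = x[receivers] - x[senders]  # x_i - x_j per edge
    if normalize:
        # Assumed reading: use unit-length difference vectors.
        diff = diff / (np.linalg.norm(diff, axis=-1, keepdims=True) + eps)
    # tanh bounds each edge weight to (-1, 1).
    weights = np.tanh(edge_scalars) if tanh else edge_scalars
    x_new = x.copy()
    np.add.at(x_new, receivers, diff * weights[:, None])
    return x_new


x = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 2.0], [-1.0, 0.5]])
senders = np.array([0, 1, 2, 3])
receivers = np.array([1, 2, 3, 0])
edge_scalars = np.array([0.5, -2.0, 3.0, 0.1])

c, s = np.cos(1.1), np.sin(1.1)
R = np.array([[c, -s], [s, c]])

rotate_then_update = coordinate_update(
    x @ R.T, edge_scalars, senders, receivers, normalize=True, tanh=True)
update_then_rotate = coordinate_update(
    x, edge_scalars, senders, receivers, normalize=True, tanh=True) @ R.T
```

Both options keep the update rotation-equivariant, since they only rescale the difference vectors by rotation-invariant factors.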
- __call__(sample: Tuple[Dict[str, Array], Array]) Dict[str, Array][source]
Forward pass.
We specify the dimensions of the inputs and outputs using the number of nodes N, the number of edges E, number of historic velocities K (=input_seq_length - 1), and the dimensionality of the feature vectors dim.
- Parameters:
sample –
Tuple with feature dictionary and particle type. Possible features:
"abs_pos" (N, K+1, dim), absolute positions
"vel_hist" (N, K*dim), historical velocity sequence
"vel_mag" (N,), velocity magnitudes
"bound" (N, 2*dim), distance to boundaries
"force" (N, dim), external force field
"rel_disp" (E, dim), relative displacement vectors
"rel_dist" (E, 1), relative distances, i.e. magnitude of displacements
"senders" (E,), sender indices
"receivers" (E,), receiver indices
- Returns:
Dict with model output. The keys must be at least one of the following:
"acc" (N, dim), (normalized) acceleration
"vel" (N, dim), (normalized) velocity
"pos" (N, dim), (absolute) next position
PaiNN
Modified PaiNN implementation for general vectorial inputs and outputs, based on Schütt et al. PaiNN model, layers and feature transform.
Original implementation: https://github.com/atomistic-machine-learning/schnetpack
Standalone implementation + validation: https://github.com/gerkone/painn-jax
- class lagrangebench.models.painn.NodeFeatures(s: Array | None = None, v: Array | None = None)[source]
Simple container for PaiNN scalar and vectorial node features.
- __getnewargs__()
Return self as a plain tuple. Used by copy and pickle.
- static __new__(_cls, s: Array | None = None, v: Array | None = None)
Create new instance of NodeFeatures(s, v)
- class lagrangebench.models.painn.GatedEquivariantBlock(hidden_size: int, scalar_out_channels: int, vector_out_channels: int, activation: ~typing.Callable = <PjitFunction of <function silu>>, scalar_activation: ~typing.Callable = None, eps: float = 1e-08, name: str = 'gated_equivariant_block')[source]
Gated equivariant block (restricted to vectorial features).
- __init__(hidden_size: int, scalar_out_channels: int, vector_out_channels: int, activation: ~typing.Callable = <PjitFunction of <function silu>>, scalar_activation: ~typing.Callable | None = None, eps: float = 1e-08, name: str = 'gated_equivariant_block')[source]
Initialize the layer.
- Parameters:
hidden_size – Number of hidden channels.
scalar_out_channels – Number of scalar output channels.
vector_out_channels – Number of vector output channels.
activation – Gate activation function.
scalar_activation – Activation function for the scalar output.
eps – Constant added in norm to prevent derivation instabilities.
name – Name of the module.
- lagrangebench.models.painn.gaussian_rbf(n_rbf: int, cutoff: float, start: float = 0.0, centered: bool = False, trainable: bool = False) Callable[[Array], Callable][source]
Gaussian radial basis functions.
- Parameters:
n_rbf – total number of Gaussian functions, \(N_g\).
cutoff – center of last Gaussian function, \(\mu_{N_g}\)
start – center of first Gaussian function, \(\mu_0\).
trainable – If True, widths and offset of Gaussian functions learnable.
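A Gaussian radial basis expansion with fixed (non-trainable) parameters can be sketched in NumPy. This is an illustration under assumptions: the centers are spaced evenly between start and cutoff, and all Gaussians share a width equal to the center spacing. The width choice and the helper name `gaussian_rbf_np` are hypothetical and may differ from the library.

```python
import numpy as np


def gaussian_rbf_np(r, n_rbf, cutoff, start=0.0):
    """Expand distances r of shape (...,) into n_rbf Gaussian features (..., n_rbf)."""
    centers = np.linspace(start, cutoff, n_rbf)  # mu_0 ... mu_{N_g}
    width = centers[1] - centers[0]              # assumed shared width, needs n_rbf >= 2
    return np.exp(-0.5 * ((r[..., None] - centers) / width) ** 2)


r = np.array([0.0, 0.5, 1.0])
feats = gaussian_rbf_np(r, n_rbf=5, cutoff=1.0)
# Each distance activates most strongly the Gaussian whose center is closest.
```

With these settings the centers are 0, 0.25, 0.5, 0.75 and 1, so each of the three distances sits exactly on one center.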
- lagrangebench.models.painn.cosine_cutoff(cutoff: float) Callable[[Array], Callable][source]
Behler-style cosine cutoff.
\[\begin{split}f(r) = \begin{cases} 0.5 \times \left[1 + \cos\left(\frac{\pi r}{r_\text{cutoff}}\right)\right] & r < r_\text{cutoff} \\ 0 & r \geqslant r_\text{cutoff} \\ \end{cases}\end{split}\]
- Parameters:
cutoff (float) – cutoff radius.
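The cutoff formula translates directly into NumPy. The sketch below evaluates \(f(r)\) for an array of distances; `cosine_cutoff_np` is a hypothetical helper for illustration and, unlike the library function, takes the distances directly rather than returning a configured callable.

```python
import numpy as np


def cosine_cutoff_np(r, cutoff):
    """Behler-style cosine cutoff: smooth decay from 1 at r=0 to 0 at r=cutoff."""
    r = np.asarray(r, dtype=float)
    inside = 0.5 * (1.0 + np.cos(np.pi * r / cutoff))
    return np.where(r < cutoff, inside, 0.0)


cutoff = 2.0
values = cosine_cutoff_np([0.0, 1.0, 2.0, 3.0], cutoff)
```

The function equals 1 at zero distance, 0.5 at half the cutoff, and 0 at and beyond the cutoff, so interactions fade out smoothly rather than being truncated abruptly.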
- lagrangebench.models.painn.PaiNNReadout(hidden_size: int, out_channels: int = 1, activation: ~typing.Callable = <PjitFunction of <function silu>>, blocks: int = 2, eps: float = 1e-08) Callable[[GraphsTuple], Tuple[Array, Array]][source]
PaiNN readout block.
- Parameters:
hidden_size – Number of hidden channels.
out_channels – Number of scalar/vector output channels.
activation – Activation function.
blocks – Number of readout blocks.
- Returns:
Configured readout function.
- class lagrangebench.models.painn.PaiNNLayer(hidden_size: int, layer_num: int, activation: ~typing.Callable = <PjitFunction of <function silu>>, blocks: int = 2, aggregate_fn: ~typing.Callable = <function segment_sum>, eps: float = 1e-08)[source]
PaiNN interaction block.
- __init__(hidden_size: int, layer_num: int, activation: ~typing.Callable = <PjitFunction of <function silu>>, blocks: int = 2, aggregate_fn: ~typing.Callable = <function segment_sum>, eps: float = 1e-08)[source]
Initialize the PaiNN layer, made up of an interaction block and a mixing block.
- Parameters:
hidden_size – Number of node features.
activation – Activation function.
layer_num – Numbering of the layer.
blocks – Number of layers in the context networks.
aggregate_fn – Function to aggregate the neighbors.
eps – Constant added in norm to prevent derivation instabilities.
- class lagrangebench.models.painn.PaiNN(hidden_size: int, output_size: int, num_mp_steps: int, radial_basis_fn: ~typing.Callable, cutoff_fn: ~typing.Callable, n_vels: int, homogeneous_particles: bool = True, activation: ~typing.Callable = <PjitFunction of <function silu>>, shared_interactions: bool = False, shared_filters: bool = False, eps: float = 1e-08)[source]
Polarizable interaction Neural Network by Schütt et al.
In order to accommodate general inputs/outputs, this PaiNN is different from the original in a few ways; the main change is that input vectors are no longer initialized to 0 but to the time average of velocity.
- __init__(hidden_size: int, output_size: int, num_mp_steps: int, radial_basis_fn: ~typing.Callable, cutoff_fn: ~typing.Callable, n_vels: int, homogeneous_particles: bool = True, activation: ~typing.Callable = <PjitFunction of <function silu>>, shared_interactions: bool = False, shared_filters: bool = False, eps: float = 1e-08)[source]
Initialize the model.
- Parameters:
hidden_size – Determines the size of each embedding vector.
output_size – Number of output features.
num_mp_steps – Number of interaction blocks.
radial_basis_fn – Expands inter-particle distances in a basis set.
cutoff_fn – Cutoff function.
n_vels – Number of historical velocities.
homogeneous_particles – If all particles are of homogeneous type.
activation – Activation function.
shared_interactions – If True, share the weights across interaction blocks.
shared_filters – If True, share the weights across filter networks.
eps – Constant added in norm to prevent derivation instabilities.
Linear
Simple baseline linear model.
- class lagrangebench.models.linear.Linear(dim_out)[source]
Model defining linear relation between input nodes and targets.
\(\mathbf{a}_i = \mathbf{W} \mathbf{x}_i\) where \(\mathbf{a}_i\) are the output accelerations, \(\mathbf{W}\) is a learnable weight matrix and \(\mathbf{x}_i\) are input features.
- __call__(sample: Tuple[Dict[str, Array], ndarray]) Dict[str, Array][source]
Forward pass.
We specify the dimensions of the inputs and outputs using the number of nodes N, the number of edges E, number of historic velocities K (=input_seq_length - 1), and the dimensionality of the feature vectors dim.
- Parameters:
sample –
Tuple with feature dictionary and particle type. Possible features:
"abs_pos" (N, K+1, dim), absolute positions
"vel_hist" (N, K*dim), historical velocity sequence
"vel_mag" (N,), velocity magnitudes
"bound" (N, 2*dim), distance to boundaries
"force" (N, dim), external force field
"rel_disp" (E, dim), relative displacement vectors
"rel_dist" (E, 1), relative distances, i.e. magnitude of displacements
"senders" (E,), sender indices
"receivers" (E,), receiver indices
- Returns:
Dict with model output. The keys must be at least one of the following:
"acc" (N, dim), (normalized) acceleration
"vel" (N, dim), (normalized) velocity
"pos" (N, dim), (absolute) next position
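The linear relation \(\mathbf{a}_i = \mathbf{W} \mathbf{x}_i\) applied to all nodes at once is a single matrix product. The NumPy sketch below shows only the shapes involved; which input features are stacked into \(\mathbf{x}_i\) is not specified here, so the feature width `F` is an arbitrary assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
N, F, dim_out = 6, 10, 3  # nodes, per-node input features (assumed), output size

x = rng.normal(size=(N, F))        # one feature row x_i per node
W = rng.normal(size=(dim_out, F))  # learnable weight matrix

# a_i = W x_i for every node, written as one batched product.
acc = x @ W.T
```

The same weight matrix is shared by all nodes, so the model cannot account for interactions between particles, which is what makes it a useful baseline.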
Utils
- class lagrangebench.models.utils.LinearXav(output_size: int, with_bias: bool = True, w_init: Callable[[Sequence[int], Any], Array] | None = None, b_init: Callable[[Sequence[int], Any], Array] | None = None, name: str | None = None)[source]
Linear layer with Xavier init. Avoid distracting ‘w_init’ everywhere.
- __init__(output_size: int, with_bias: bool = True, w_init: Callable[[Sequence[int], Any], Array] | None = None, b_init: Callable[[Sequence[int], Any], Array] | None = None, name: str | None = None)[source]
Constructs the Linear module.
- Parameters:
output_size – Output dimensionality.
with_bias – Whether to add a bias to the output.
w_init – Optional initializer for weights. By default, uses random values from truncated normal, with stddev 1 / sqrt(fan_in). See https://arxiv.org/abs/1502.03167v3.
b_init – Optional initializer for bias. By default, zero.
name – Name of the module.
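For reference, the standard Xavier (Glorot) uniform initialization can be sketched in NumPy. This shows the textbook definition only: whether the layer uses the uniform or the normal variant, and with which scaling, is not stated here, so the helper `xavier_uniform` is a hypothetical illustration rather than the library's initializer.

```python
import numpy as np


def xavier_uniform(rng, fan_in, fan_out):
    """Sample weights from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out))."""
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=(fan_in, fan_out))


rng = np.random.default_rng(0)
fan_in, fan_out = 256, 128
W = xavier_uniform(rng, fan_in, fan_out)

# A uniform distribution on (-limit, limit) has variance limit**2 / 3,
# which equals 2 / (fan_in + fan_out).
target_var = 2.0 / (fan_in + fan_out)
```

The scaling is chosen so that activation and gradient magnitudes stay roughly constant across layers at initialization.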
- class lagrangebench.models.utils.MLPXav(output_sizes: ~typing.Iterable[int], with_bias: bool = True, w_init: ~typing.Callable[[~collections.abc.Sequence[int], ~typing.Any], ~jax.Array] | None = None, b_init: ~typing.Callable[[~collections.abc.Sequence[int], ~typing.Any], ~jax.Array] | None = None, activation: ~typing.Callable = <PjitFunction of <function silu>>, activate_final: bool = False, name: str | None = None)[source]
MLP layer with Xavier init. Avoid distracting ‘w_init’ everywhere.
- __init__(output_sizes: ~typing.Iterable[int], with_bias: bool = True, w_init: ~typing.Callable[[~collections.abc.Sequence[int], ~typing.Any], ~jax.Array] | None = None, b_init: ~typing.Callable[[~collections.abc.Sequence[int], ~typing.Any], ~jax.Array] | None = None, activation: ~typing.Callable = <PjitFunction of <function silu>>, activate_final: bool = False, name: str | None = None)[source]
Constructs an MLP.
- Parameters:
output_sizes – Sequence of layer sizes.
w_init – Initializer for Linear weights.
b_init – Initializer for Linear bias. Must be None if with_bias=False.
with_bias – Whether or not to apply a bias in each layer.
activation – Activation function to apply between Linear layers. Defaults to SiLU, as shown in the signature.
activate_final – Whether or not to activate the final layer of the MLP.
name – Optional name for this module.
- Raises:
ValueError – If with_bias is False and b_init is not None.
- class lagrangebench.models.utils.SteerableGraphsTuple(graph: GraphsTuple, node_attributes: IrrepsArray | None = None, edge_attributes: IrrepsArray | None = None, additional_message_features: IrrepsArray | None = None)[source]
Pack (steerable) node and edge attributes with jraph.GraphsTuple.
- graph
jraph.GraphsTuple, graph structure
- Type:
jraph._src.graph.GraphsTuple
- node_attributes
(N, irreps.dim), node attributes \(\mathbf{\hat{a}}_i\)
- Type:
e3nn_jax._src.irreps_array.IrrepsArray | None
- edge_attributes
(E, irreps.dim), edge attributes \(\mathbf{\hat{a}}_{ij}\)
- Type:
e3nn_jax._src.irreps_array.IrrepsArray | None
- additional_message_features
(E, edge_dim), optional message features
- Type:
e3nn_jax._src.irreps_array.IrrepsArray | None
- lagrangebench.models.utils.node_irreps(metadata: Dict, input_seq_length: int, has_external_force: bool, has_magnitudes: bool, has_homogeneous_particles: bool) str[source]
Compute input node irreps based on which features are available.