Models

Base Class

class lagrangebench.models.base.BaseModel(name: str | None = None)[source]

Base model class. All models must inherit from this class.

abstract __call__(sample: Tuple[Dict[str, Array], Array]) Dict[str, Array][source]

Forward pass.

We specify the dimensions of the inputs and outputs using the number of nodes N, the number of edges E, number of historic velocities K (=input_seq_length - 1), and the dimensionality of the feature vectors dim.

Parameters:

sample

Tuple with feature dictionary and particle type. Possible features:

  • “abs_pos” (N, K+1, dim), absolute positions

  • “vel_hist” (N, K*dim), historical velocity sequence

  • “vel_mag” (N,), velocity magnitudes

  • “bound” (N, 2*dim), distance to boundaries

  • “force” (N, dim), external force field

  • “rel_disp” (E, dim), relative displacement vectors

  • “rel_dist” (E, 1), relative distances, i.e. magnitude of displacements

  • “senders” (E,), sender indices

  • “receivers” (E,), receiver indices

Returns:

Dict with model output. It must contain at least one of the following keys:

  • “acc” (N, dim), (normalized) acceleration

  • “vel” (N, dim), (normalized) velocity

  • “pos” (N, dim), (absolute) next position
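As a rough illustration of this interface, the sketch below builds a toy sample for N = 3 particles in dim = 2 with K = 2 historic velocities and passes it to a stand-in model. It uses plain Python lists instead of JAX arrays. The helpers `shape` and `toy_model`, the zero-valued features, and the per-particle integer type list are assumptions made for this example only and are not part of lagrangebench.

```python
from typing import Dict, List, Tuple

N, K, DIM = 3, 2, 2  # particles, historic velocities, spatial dimension


def shape(x) -> Tuple[int, ...]:
    """Shape of a rectangular nested list (hypothetical helper)."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


# Node features, filled with zeros; only the layout matters here.
features: Dict[str, list] = {
    "abs_pos": [[[0.0] * DIM for _ in range(K + 1)] for _ in range(N)],  # (N, K+1, dim)
    "vel_hist": [[0.0] * (K * DIM) for _ in range(N)],  # (N, K*dim)
    "vel_mag": [0.0] * N,  # (N,)
}
# Edge features for a two-edge graph: 0 -> 1 and 1 -> 2.
features["senders"] = [0, 1]  # (E,)
features["receivers"] = [1, 2]  # (E,)
features["rel_disp"] = [[0.0] * DIM for _ in features["senders"]]  # (E, dim)

particle_type: List[int] = [0] * N  # assumed: one integer type per particle
sample = (features, particle_type)


def toy_model(sample) -> Dict[str, list]:
    """Stand-in for a model's __call__: returns zero accelerations."""
    feats, _ = sample
    n_nodes = len(feats["vel_hist"])
    return {"acc": [[0.0] * DIM for _ in range(n_nodes)]}  # (N, dim)


out = toy_model(sample)
```

A real model would return learned values under one or more of the keys listed above; only the shapes are meaningful in this sketch.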

GNS

Graph Network-based Simulator. GNS model and feature transform.

class lagrangebench.models.gns.GNS(particle_dimension: int, latent_size: int, blocks_per_step: int, num_mp_steps: int, particle_type_embedding_size: int, num_particle_types: int = NodeType.SIZE)[source]

Graph Network-based Simulator by Sanchez-Gonzalez et al.

GNS is the simplest graph neural network applied to particle dynamics. It is built on the usual Graph Network architecture, with an encoder, a processor, and a decoder.

\[\begin{split}\begin{align} \mathbf{m}_{ij}^{(t+1)} &= \phi \left( \mathbf{m}_{ij}^{(t)}, \mathbf{h}_i^{(t)}, \mathbf{h}_j^{(t)} \right) \\ \mathbf{h}_i^{(t+1)} &= \psi \left( \mathbf{h}_i^{(t)}, \sum_{j \in \mathcal{N}(i)} \mathbf{m}_{ij}^{(t+1)} \right) \\ \end{align}\end{split}\]

__init__(particle_dimension: int, latent_size: int, blocks_per_step: int, num_mp_steps: int, particle_type_embedding_size: int, num_particle_types: int = NodeType.SIZE)[source]

Initialize the model.

Parameters:
  • particle_dimension – Space dimensionality (e.g. 2 or 3).

  • latent_size – Size of the latent representations.

  • blocks_per_step – Number of MLP layers per block.

  • num_mp_steps – Number of message passing steps.

  • particle_type_embedding_size – Size of the particle type embedding.

  • num_particle_types – Max number of particle types.
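The processor update from the class description (edge update φ, node update ψ, sum aggregation over neighbors) can be sketched in a few lines. This is a minimal illustration with scalar node and edge states and hand-written stand-ins for the learned functions; the names `mp_step`, `phi`, and `psi`, and the convention that i is the receiver and j the sender, are assumptions for this example. The real model uses MLPs on latent vectors of size latent_size and repeats the step num_mp_steps times.

```python
from typing import Callable, List, Tuple


def mp_step(
    h: List[float],
    m: List[float],
    senders: List[int],
    receivers: List[int],
    phi: Callable[[float, float, float], float],
    psi: Callable[[float, float], float],
) -> Tuple[List[float], List[float]]:
    """One message passing step with sum aggregation over incoming edges."""
    # Edge update: m_ij <- phi(m_ij, h_i, h_j), with i the receiver and j the sender.
    m_new = [phi(m[e], h[receivers[e]], h[senders[e]]) for e in range(len(m))]
    # Aggregate the updated messages at each receiving node.
    agg = [0.0] * len(h)
    for e, i in enumerate(receivers):
        agg[i] += m_new[e]
    # Node update: h_i <- psi(h_i, sum_j m_ij).
    h_new = [psi(h[i], agg[i]) for i in range(len(h))]
    return h_new, m_new


# Toy stand-ins for the learned MLPs (assumptions for illustration only).
phi = lambda m_ij, h_i, h_j: m_ij + h_i * h_j
psi = lambda h_i, agg_i: h_i + agg_i

h, m = [1.0, 2.0, 3.0], [0.0, 0.0]
senders, receivers = [0, 1], [2, 2]  # edges 0 -> 2 and 1 -> 2
for _ in range(2):  # two message passing steps
    h, m = mp_step(h, m, senders, receivers, phi, psi)
```

Only node 2 receives messages in this toy graph, so the states of nodes 0 and 1 stay unchanged across both steps.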

__call__(sample: Tuple[Dict[str, Array], Array]) Dict[str, Array][source]

Forward pass.

We specify the dimensions of the inputs and outputs using the number of nodes N, the number of edges E, number of historic velocities K (=input_seq_length - 1), and the dimensionality of the feature vectors dim.

Parameters:

sample

Tuple with feature dictionary and particle type. Possible features:

  • “abs_pos” (N, K+1, dim), absolute positions

  • “vel_hist” (N, K*dim), historical velocity sequence

  • “vel_mag” (N,), velocity magnitudes

  • “bound” (N, 2*dim), distance to boundaries

  • “force” (N, dim), external force field

  • “rel_disp” (E, dim), relative displacement vectors

  • “rel_dist” (E, 1), relative distances, i.e. magnitude of displacements

  • “senders” (E,), sender indices

  • “receivers” (E,), receiver indices

Returns:

Dict with model output. It must contain at least one of the following keys:

  • “acc” (N, dim), (normalized) acceleration

  • “vel” (N, dim), (normalized) velocity

  • “pos” (N, dim), (absolute) next position

SEGNN

Steerable E(3) equivariant GNN from Brandstetter et al. SEGNN model, layers and feature transform.

Original implementation: https://github.com/RobDHess/Steerable-E3-GNN

Standalone implementation + validation: https://github.com/gerkone/segnn-jax

class lagrangebench.models.segnn.O3TensorProduct(output_irreps: Irreps, *, biases: bool = True, name: str | None = None, init_fn: Callable = <function uniform_init>, gradient_normalization: str | float = 'element', path_normalization: str | float = 'element')[source]

O(3) equivariant linear parametrized tensor product layer.

Applies a linear (parametrized) tensor product of representations to the input(s).

\[\begin{align} tp(x, y) := \mathbf{x} \otimes_{CG}^{\mathcal{W}} \mathbf{y} \end{align}\]

where \(\mathcal{W}\) are learnable parameters.

Uses tensor_product + Linear instead of FullyConnectedTensorProduct. As of e3nn-jax 0.19.2 (https://github.com/e3nn/e3nn-jax/releases/tag/0.19.2), this is as fast as FullyConnectedTensorProduct.

__init__(output_irreps: Irreps, *, biases: bool = True, name: str | None = None, init_fn: Callable = <function uniform_init>, gradient_normalization: str | float = 'element', path_normalization: str | float = 'element')[source]

Initialize the tensor product.

Parameters:
  • output_irreps – Output representation

  • biases – If True, add biases.

  • name – Name of the linear layer params

  • init_fn – Weight initialization function. Default is uniform.

  • gradient_normalization – Gradient normalization method. Default is “element”. NOTE: gradient_normalization=“element” is the default in torch and haiku.

  • path_normalization – Path normalization method. Default is “element”.

__call__(x: IrrepsArray, y: IrrepsArray | None = None) IrrepsArray[source]

Applies an O(3) equivariant linear parametrized tensor product layer.

Parameters:
  • x (IrrepsArray) – Left tensor

  • y (IrrepsArray) – Right tensor. If None it defaults to np.ones.

Returns:

The output of the weighted tensor product (IrrepsArray).
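For intuition, the weighted tensor product can be written out by hand in the special case of two vector (l = 1) inputs, where the Clebsch-Gordan paths to l = 0 and l = 1 reduce, up to normalization constants, to the dot and cross products. The sketch below is not the e3nn-jax implementation: `tp_vectors`, `rot_z`, the per-path weights `w0` and `w1`, and the omission of the l = 2 path are simplifying assumptions.

```python
import math
from typing import List, Tuple

Vec = List[float]


def tp_vectors(x: Vec, y: Vec, w0: float, w1: float) -> Tuple[float, Vec]:
    """Weighted product of two 3D vectors: a scalar path and a vector path."""
    scalar = w0 * sum(a * b for a, b in zip(x, y))  # 1 x 1 -> 0 (dot product)
    cross = [
        x[1] * y[2] - x[2] * y[1],
        x[2] * y[0] - x[0] * y[2],
        x[0] * y[1] - x[1] * y[0],
    ]
    vector = [w1 * c for c in cross]  # 1 x 1 -> 1 (cross product)
    return scalar, vector


def rot_z(v: Vec, angle: float) -> Vec:
    """Rotate a 3D vector about the z axis."""
    c, s = math.cos(angle), math.sin(angle)
    return [c * v[0] - s * v[1], s * v[0] + c * v[1], v[2]]


x, y, angle = [1.0, 2.0, 3.0], [-1.0, 0.5, 2.0], 0.7
s_plain, v_plain = tp_vectors(x, y, w0=0.3, w1=1.5)
s_rot, v_rot = tp_vectors(rot_z(x, angle), rot_z(y, angle), w0=0.3, w1=1.5)
```

Rotating both inputs leaves the scalar output unchanged and rotates the vector output by the same rotation, which is the equivariance property the layer is built around.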

lagrangebench.models.segnn.O3TensorProductGate(output_irreps: Irreps, *, biases: bool = True, scalar_activation: Callable | None = None, gate_activation: Callable | None = None, name: str | None = None, init_fn: Callable | None = None) Callable[source]

Non-linear (gated) O(3) equivariant linear tensor product layer.

It applies a linear tensor product of representations to the input(s) and then a gated nonlinearity by Weiler et al.

The input representation is lifted to have gating scalars.

Parameters:
  • output_irreps – Output representation

  • biases – Add biases

  • scalar_activation – Activation function for scalars

  • gate_activation – Activation function for the gates of higher-order features

  • name – Name of the linear layer params

Returns:

Function that applies the gated tensor product layer.

lagrangebench.models.segnn.O3Embedding(embed_irreps: Irreps, embed_edges: bool = True) Callable[source]

Linear steerable embedding.

Embeds the graph nodes in the representation space embed_irreps.

Parameters:
  • embed_irreps – Output representation

  • embed_edges – If true also embed edges/message passing features

Returns:

Function to embed graph nodes (and optionally edges)

lagrangebench.models.segnn.O3Decoder(latent_irreps: Irreps, output_irreps: Irreps, n_blocks: int = 1)[source]

Steerable decoder.

Parameters:
  • latent_irreps – Representation from the previous block

  • output_irreps – Output representation

  • n_blocks – Number of tensor product blocks in the decoder

Returns:

Decoded latent feature space to output space.

class lagrangebench.models.segnn.SEGNNLayer(output_irreps: Irreps, layer_idx: int, n_blocks: int = 2, norm: str | None = None, aggregate_fn: Callable | None = <function segment_sum>)[source]

Steerable E(3) equivariant layer.

Applies a message passing step (GN) with equivariant message and update functions.

__init__(output_irreps: Irreps, layer_idx: int, n_blocks: int = 2, norm: str | None = None, aggregate_fn: Callable | None = <function segment_sum>)[source]

Initialize the layer.

Parameters:
  • output_irreps – Layer output representation

  • layer_idx – Numbering of the layer

  • n_blocks – Number of tensor product blocks in the layer

  • norm – Normalization type. Either None, ‘instance’ or ‘batch’

  • aggregate_fn – Message aggregation function. Defaults to sum.

__call__(st_graph: SteerableGraphsTuple) SteerableGraphsTuple[source]

Perform a message passing step.

Parameters:

st_graph – Input graph

Returns:

The updated graph

lagrangebench.models.segnn.weight_balanced_irreps(scalar_units: int, irreps_right: Irreps, lmax: int | None = None) Irreps[source]

Determine left irreps so that the tensor product with irreps_right has at least scalar_units weights.

Parameters:
  • scalar_units – Number of weights

  • irreps_right – Right irreps

  • lmax – Maximum L of the left irreps

Returns:

Left irreps

class lagrangebench.models.segnn.SEGNN(node_features_irreps: Irreps, edge_features_irreps: Irreps, scalar_units: int, lmax_hidden: int, lmax_attributes: int, output_irreps: Irreps, num_mp_steps: int, n_vels: int, velocity_aggregate: str = 'avg', homogeneous_particles: bool = True, norm: str | None = None, blocks_per_step: int = 2, embed_msg_features: bool = False)[source]

Steerable E(3) equivariant network by Brandstetter et al.

SEGNNs are E(3)-equivariant graph neural networks based around tensor products of representations. By design, SEGNNs allow for flexible scalar/vectorial inputs and outputs for both edges and attributes. The message passing is modified as follows:

\[\begin{split}\begin{align} \mathbf{m}_{ij} &= \textit{M}_{\mathbf{\hat{a}}_{ij}}\left( \mathbf{f}_i, \mathbf{f}_j, \| x_i - x_j \|^2 \right), \\ \mathbf{f}^{\prime}_i &= \textit{U}_{\mathbf{\hat{a}}_i}\left( \mathbf{f}_i, \sum_{j\in\mathcal{N}(i)} \mathbf{m}_{ij} \right) \end{align}\end{split}\]

\(\mathbf{\hat{a}}_{ij}\) and \(\mathbf{\hat{a}}_{i}\) are edge and node attributes and the operators \(\textit{M}_{\mathbf{\hat{a}}_{ij}}\) and \(\textit{U}_{\mathbf{\hat{a}}_{i}}\) are defined as a tensor product of representations \(\otimes_{CG}\) between the input and the attributes:

\[\begin{align} \textit{U}_{\mathbf{\hat{a}}_i} := \mathcal{W}^n_{\mathbf{\hat{a}}_i} (\dots \sigma ( \mathcal{W}^0_{\mathbf{\hat{a}}_i} \mathbf{f} ) ) \quad \text{with} \quad \mathcal{W}_{\mathbf{\hat{a}}_i} \mathbf{f} := \mathbf{f} \otimes_{CG}^{\mathcal{W}} \mathbf{\hat{a}}_i \end{align}\]

where \(\mathbf{f} = [\mathbf{f}_i, \sum_{j\in\mathcal{N}(i)} \mathbf{m}_{ij}]\) are the node features concatenated with the aggregated messages, \(\sigma\) is a gated non-linearity and \(\mathcal{W}\) are the tensor product parameters. \(\textit{M}_{\mathbf{\hat{a}}_{ij}}\) is defined similarly, but with the non-linearity on the last layer, with edge attributes \(\mathbf{\hat{a}}_{ij}\) and \(\mathbf{f} = [\mathbf{f}_i, \mathbf{f}_j, \| x_i - x_j \|^2]\).

__init__(node_features_irreps: Irreps, edge_features_irreps: Irreps, scalar_units: int, lmax_hidden: int, lmax_attributes: int, output_irreps: Irreps, num_mp_steps: int, n_vels: int, velocity_aggregate: str = 'avg', homogeneous_particles: bool = True, norm: str | None = None, blocks_per_step: int = 2, embed_msg_features: bool = False)[source]

Initialize the network.

Parameters:
  • node_features_irreps – Irreps of the node features.

  • edge_features_irreps – Irreps of the additional message passing features.

  • scalar_units – Hidden units (lower bound). Actual number depends on lmax.

  • lmax_hidden – Maximum L of the hidden layer representations.

  • lmax_attributes – Maximum L of the attributes.

  • output_irreps – Output representation.

  • num_mp_steps – Number of message passing layers

  • n_vels – Number of velocities in the history.

  • velocity_aggregate – Velocity sequence aggregation method.

  • homogeneous_particles – If all particles are of homogeneous type.

  • norm – Normalization type. Either None, ‘instance’ or ‘batch’

  • blocks_per_step – Number of tensor product blocks in each message passing step

  • embed_msg_features – Set to true to also embed edges/message passing features

__call__(sample: Tuple[Dict[str, Array], Array]) Dict[str, Array][source]

Forward pass.

We specify the dimensions of the inputs and outputs using the number of nodes N, the number of edges E, number of historic velocities K (=input_seq_length - 1), and the dimensionality of the feature vectors dim.

Parameters:

sample

Tuple with feature dictionary and particle type. Possible features:

  • “abs_pos” (N, K+1, dim), absolute positions

  • “vel_hist” (N, K*dim), historical velocity sequence

  • “vel_mag” (N,), velocity magnitudes

  • “bound” (N, 2*dim), distance to boundaries

  • “force” (N, dim), external force field

  • “rel_disp” (E, dim), relative displacement vectors

  • “rel_dist” (E, 1), relative distances, i.e. magnitude of displacements

  • “senders” (E,), sender indices

  • “receivers” (E,), receiver indices

Returns:

Dict with model output. It must contain at least one of the following keys:

  • “acc” (N, dim), (normalized) acceleration

  • “vel” (N, dim), (normalized) velocity

  • “pos” (N, dim), (absolute) next position

EGNN

E(n) equivariant GNN from Garcia Satorras et al. EGNN model, layers and feature transform.

Original implementation: https://github.com/vgsatorras/egnn

Standalone implementation + validation: https://github.com/gerkone/egnn-jax

class lagrangebench.models.egnn.EGNNLayer(layer_num: int, hidden_size: int, output_size: int, displacement_fn: Callable[[Any, Any], Any], shift_fn: Callable[[Any, Any], Any], blocks: int = 1, act_fn: Callable = <function silu>, pos_aggregate_fn: Callable | None = <function segment_sum>, msg_aggregate_fn: Callable | None = <function segment_sum>, residual: bool = True, attention: bool = False, normalize: bool = False, tanh: bool = False, dt: float = 0.001, eps: float = 1e-08)[source]

E(n)-equivariant EGNN layer.

Applies a message passing step where the positions are corrected with the velocities and a learnable correction term \(\psi_x(\mathbf{h}_i^{(t+1)})\).

__init__(layer_num: int, hidden_size: int, output_size: int, displacement_fn: Callable[[Any, Any], Any], shift_fn: Callable[[Any, Any], Any], blocks: int = 1, act_fn: Callable = <function silu>, pos_aggregate_fn: Callable | None = <function segment_sum>, msg_aggregate_fn: Callable | None = <function segment_sum>, residual: bool = True, attention: bool = False, normalize: bool = False, tanh: bool = False, dt: float = 0.001, eps: float = 1e-08)[source]

Initialize the layer.

Parameters:
  • layer_num – layer number

  • hidden_size – hidden size

  • output_size – output size

  • displacement_fn – Displacement function for the acceleration computation.

  • shift_fn – Shift function for updating positions

  • blocks – number of blocks in the node and edge MLPs

  • act_fn – activation function

  • pos_aggregate_fn – position aggregation function

  • msg_aggregate_fn – message aggregation function

  • residual – whether to use residual connections

  • attention – whether to use attention

  • normalize – whether to normalize the coordinates

  • tanh – whether to use tanh in the position update

  • dt – position update step size

  • eps – small number to avoid division by zero
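The displacement_fn and shift_fn arguments are what let the layer work with periodic boundaries. As a sketch of what such a pair might look like for a cubic periodic box (the function names and the explicit box-length argument are assumptions for illustration, not the functions lagrangebench passes in):

```python
from typing import List

Vec = List[float]


def periodic_displacement(a: Vec, b: Vec, box: float) -> Vec:
    """Minimum-image displacement a - b in a periodic box of side length `box`."""
    return [((ai - bi + 0.5 * box) % box) - 0.5 * box for ai, bi in zip(a, b)]


def periodic_shift(x: Vec, dx: Vec, box: float) -> Vec:
    """Move x by dx and wrap the result back into [0, box)."""
    return [(xi + di) % box for xi, di in zip(x, dx)]


box = 1.0
a, b = [0.95, 0.5], [0.05, 0.5]
disp = periodic_displacement(a, b, box)  # shortest path crosses the boundary
moved = periodic_shift(a, [0.1, 0.0], box)  # wraps around to the left edge
```

Although the two particles are 0.9 apart inside the box, the minimum-image displacement has magnitude 0.1, because the shorter path crosses the periodic boundary.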

__call__(graph: GraphsTuple, pos: Array, vel: Array, edge_attribute: Array | None = None, node_attribute: Array | None = None) Tuple[GraphsTuple, Array][source]

Apply EGNN layer.

Parameters:
  • graph – Graph from previous step

  • pos – Node position, updated separately

  • vel – Node velocity

  • edge_attribute – Edge attribute (optional)

  • node_attribute – Node attribute (optional)

Returns:

Updated graph, node position

class lagrangebench.models.egnn.EGNN(hidden_size: int, output_size: int, dt: float, n_vels: int, displacement_fn: Callable[[Any, Any], Any], shift_fn: Callable[[Any, Any], Any], normalization_stats: Dict[str, Array] | None = None, act_fn: Callable = <function silu>, num_mp_steps: int = 4, homogeneous_particles: bool = True, residual: bool = True, attention: bool = False, normalize: bool = False, tanh: bool = False)[source]

E(n) Graph Neural Network by Garcia Satorras et al.

EGNN doesn’t require expensive higher-order representations in intermediate layers; instead it relies on separate scalar and vector channels, which are treated differently by EGNN layers. In this setup, EGNN is similar to a learnable numerical integrator:

\[\begin{split}\begin{align} \mathbf{m}_{ij}^{(t+1)} &= \phi_e \left( \mathbf{m}_{ij}^{(t)}, \mathbf{h}_i^{(t)}, \mathbf{h}_j^{(t)}, ||\mathbf{x}_i^{(t)} - \mathbf{x}_j^{(t)}||^2 \right) \\ \mathbf{\hat{m}}_{ij}^{(t+1)} &= (\mathbf{x}_i^{(t)} - \mathbf{x}_j^{(t)}) \phi_x(\mathbf{m}_{ij}^{(t+1)}) \end{align}\end{split}\]

The node update with the integrator is:

\[\begin{split}\begin{align} \mathbf{h}_i^{(t+1)} &= \psi_h \left( \mathbf{h}_i^{(t)}, \sum_{j \in \mathcal{N}(i)} \mathbf{m}_{ij}^{(t+1)} \right) \\ \mathbf{x}_i^{(t+1)} &= \mathbf{x}_i^{(t)} + \sum_{j \in \mathcal{N}(i)} \mathbf{\hat{m}}_{ij}^{(t+1)} \psi_x(\mathbf{h}_i^{(t+1)}) \end{align}\end{split}\]

where \(\mathbf{m}_{ij}\) and \(\mathbf{\hat{m}}_{ij}\) are the scalar and vector messages respectively, and \(\mathbf{x}_{i}\) are the positions.

This implementation differs from the original in two places:

  • because our datasets can have periodic boundary conditions, we use shift and displacement functions that account for them whenever positions are compared or updated.

  • we apply a simple integrator after the last layer to get the acceleration.
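The position update described above can be sketched without any learned parameters. In the block below the vector messages are summed over the neighbors of each node, `phi_x` is a fixed stand-in for the learned function of the scalar message, the node-wise factor ψ_x is set to 1, and free-space (non-periodic) displacements are used; all of these, along with the name `egnn_position_update`, are simplifying assumptions for illustration.

```python
from typing import Callable, List

Vec = List[float]


def egnn_position_update(
    pos: List[Vec],
    senders: List[int],
    receivers: List[int],
    phi_x: Callable[[float], float],
) -> List[Vec]:
    """x_i <- x_i + sum_j (x_i - x_j) * phi_x(|x_i - x_j|^2)."""
    new_pos = [list(p) for p in pos]
    for j, i in zip(senders, receivers):
        diff = [a - b for a, b in zip(pos[i], pos[j])]  # x_i - x_j
        dist2 = sum(d * d for d in diff)  # invariant to rotations and translations
        weight = phi_x(dist2)
        for k, d in enumerate(diff):
            new_pos[i][k] += d * weight
    return new_pos


phi_x = lambda dist2: 0.1 / (1.0 + dist2)  # toy stand-in for the learned MLP

pos = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]]
senders = [0, 1, 1, 2, 2, 0]
receivers = [1, 0, 2, 1, 0, 2]
updated = egnn_position_update(pos, senders, receivers, phi_x)

# Translation equivariance: shifting the inputs shifts the outputs.
shift = [3.0, -1.5]
shifted_pos = [[a + s for a, s in zip(p, shift)] for p in pos]
updated_shifted = egnn_position_update(shifted_pos, senders, receivers, phi_x)
```

Because the update only uses relative positions and their squared norms, translating every input position translates every output position by the same vector.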

__init__(hidden_size: int, output_size: int, dt: float, n_vels: int, displacement_fn: Callable[[Any, Any], Any], shift_fn: Callable[[Any, Any], Any], normalization_stats: Dict[str, Array] | None = None, act_fn: Callable = <function silu>, num_mp_steps: int = 4, homogeneous_particles: bool = True, residual: bool = True, attention: bool = False, normalize: bool = False, tanh: bool = False)[source]

Initialize the network.

Parameters:
  • hidden_size – Number of hidden features.

  • output_size – Number of features for ‘h’ at the output.

  • dt – Time step for position and velocity integration. Used to rescale the initialization of the correction MLP.

  • n_vels – Number of velocities in the history.

  • displacement_fn – Displacement function for the acceleration computation.

  • shift_fn – Shift function for updating positions.

  • normalization_stats – Normalization statistics for the input data.

  • act_fn – Non-linearity.

  • num_mp_steps – Number of EGNN layers.

  • homogeneous_particles – If all particles are of homogeneous type.

  • residual – Whether to use residual connections.

  • attention – Whether to use attention or not.

  • normalize – Normalizes the coordinate messages such that \(x^{l+1}_i = x^{l}_i + \sum_j (x_i - x_j) \phi_x(m_{ij}) / \|x_i - x_j\|\). It may help stability or generalization. Not used in the paper.

  • tanh – Applies a tanh activation to the output of \(\phi_x(m_{ij})\). This bounds the output, which improves stability but may reduce accuracy. Not used in the paper.

__call__(sample: Tuple[Dict[str, Array], Array]) Dict[str, Array][source]

Forward pass.

We specify the dimensions of the inputs and outputs using the number of nodes N, the number of edges E, number of historic velocities K (=input_seq_length - 1), and the dimensionality of the feature vectors dim.

Parameters:

sample

Tuple with feature dictionary and particle type. Possible features:

  • “abs_pos” (N, K+1, dim), absolute positions

  • “vel_hist” (N, K*dim), historical velocity sequence

  • “vel_mag” (N,), velocity magnitudes

  • “bound” (N, 2*dim), distance to boundaries

  • “force” (N, dim), external force field

  • “rel_disp” (E, dim), relative displacement vectors

  • “rel_dist” (E, 1), relative distances, i.e. magnitude of displacements

  • “senders” (E,), sender indices

  • “receivers” (E,), receiver indices

Returns:

Dict with model output. It must contain at least one of the following keys:

  • “acc” (N, dim), (normalized) acceleration

  • “vel” (N, dim), (normalized) velocity

  • “pos” (N, dim), (absolute) next position

PaiNN

Modified PaiNN implementation for general vectorial inputs and outputs, based on Schütt et al. PaiNN model, layers and feature transform.

Original implementation: https://github.com/atomistic-machine-learning/schnetpack

Standalone implementation + validation: https://github.com/gerkone/painn-jax

class lagrangebench.models.painn.NodeFeatures(s: Array | None = None, v: Array | None = None)[source]

Simple container for PaiNN scalar and vectorial node features.

__getnewargs__()

Return self as a plain tuple. Used by copy and pickle.

static __new__(_cls, s: Array | None = None, v: Array | None = None)

Create new instance of NodeFeatures(s, v)

class lagrangebench.models.painn.GatedEquivariantBlock(hidden_size: int, scalar_out_channels: int, vector_out_channels: int, activation: Callable = <function silu>, scalar_activation: Callable | None = None, eps: float = 1e-08, name: str = 'gated_equivariant_block')[source]

Gated equivariant block (restricted to vectorial features).

Image: https://i.imgur.com/EMlg2Qi.png

__init__(hidden_size: int, scalar_out_channels: int, vector_out_channels: int, activation: Callable = <function silu>, scalar_activation: Callable | None = None, eps: float = 1e-08, name: str = 'gated_equivariant_block')[source]

Initialize the layer.

Parameters:
  • hidden_size – Number of hidden channels.

  • scalar_out_channels – Number of scalar output channels.

  • vector_out_channels – Number of vector output channels.

  • activation – Gate activation function.

  • scalar_activation – Activation function for the scalar output.

  • eps – Constant added inside the norm to prevent instabilities when differentiating.

  • name – Name of the module.

__call__(s: Array, v: Array) Tuple[Array, Array][source]

Call self as a function.

lagrangebench.models.painn.gaussian_rbf(n_rbf: int, cutoff: float, start: float = 0.0, centered: bool = False, trainable: bool = False) Callable[[Array], Callable][source]

Gaussian radial basis functions.

Parameters:
  • n_rbf – total number of Gaussian functions, \(N_g\).

  • cutoff – center of last Gaussian function, \(\mu_{N_g}\)

  • start – center of first Gaussian function, \(\mu_0\).

  • trainable – If True, the widths and offsets of the Gaussian functions are learnable.
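A minimal sketch of such an expansion follows. It assumes centers spaced evenly between start and cutoff and a common width equal to that spacing, as is common in SchNet-style implementations; the helper name `gaussian_rbf_sketch` is hypothetical, and the centered and trainable options are not modeled.

```python
import math
from typing import Callable, List


def gaussian_rbf_sketch(
    n_rbf: int, cutoff: float, start: float = 0.0
) -> Callable[[float], List[float]]:
    """Return a function expanding a distance r into n_rbf Gaussian features."""
    spacing = (cutoff - start) / (n_rbf - 1)  # requires n_rbf >= 2
    centers = [start + k * spacing for k in range(n_rbf)]
    coeff = -0.5 / spacing**2  # shared width equal to the center spacing

    def expand(r: float) -> List[float]:
        return [math.exp(coeff * (r - mu) ** 2) for mu in centers]

    return expand


expand = gaussian_rbf_sketch(n_rbf=5, cutoff=2.0)
feats = expand(1.0)  # r sits exactly on the third center (mu = 1.0)
```

Each distance is turned into a smooth feature vector that peaks at the nearest center, which gives the filter networks a richer input than the raw scalar distance.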

lagrangebench.models.painn.cosine_cutoff(cutoff: float) Callable[[Array], Callable][source]

Behler-style cosine cutoff.

\[\begin{split}f(r) = \begin{cases} 0.5 \times \left[1 + \cos\left(\frac{\pi r}{r_\text{cutoff}}\right)\right] & r < r_\text{cutoff} \\ 0 & r \geqslant r_\text{cutoff} \\ \end{cases}\end{split}\]

Parameters:

cutoff (float) – cutoff radius.
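The formula translates directly into code. The sketch below works on a single scalar distance rather than an array, and the name `cosine_cutoff_sketch` is hypothetical:

```python
import math
from typing import Callable


def cosine_cutoff_sketch(cutoff: float) -> Callable[[float], float]:
    """Return f(r) = 0.5 * (1 + cos(pi * r / cutoff)) for r < cutoff, else 0."""

    def f(r: float) -> float:
        if r >= cutoff:
            return 0.0
        return 0.5 * (1.0 + math.cos(math.pi * r / cutoff))

    return f


f = cosine_cutoff_sketch(cutoff=2.0)
values = [f(0.0), f(1.0), f(2.0), f(3.0)]
```

The function decays smoothly from 1 at r = 0 to 0 at the cutoff radius, so contributions from neighbors fade out instead of being switched off abruptly.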

lagrangebench.models.painn.PaiNNReadout(hidden_size: int, out_channels: int = 1, activation: Callable = <function silu>, blocks: int = 2, eps: float = 1e-08) Callable[[GraphsTuple], Tuple[Array, Array]][source]

PaiNN readout block.

Parameters:
  • hidden_size – Number of hidden channels.

  • out_channels – Number of scalar/vector output channels.

  • activation – Activation function.

  • blocks – Number of readout blocks.

Returns:

Configured readout function.

class lagrangebench.models.painn.PaiNNLayer(hidden_size: int, layer_num: int, activation: Callable = <function silu>, blocks: int = 2, aggregate_fn: Callable = <function segment_sum>, eps: float = 1e-08)[source]

PaiNN interaction block.

__init__(hidden_size: int, layer_num: int, activation: Callable = <function silu>, blocks: int = 2, aggregate_fn: Callable = <function segment_sum>, eps: float = 1e-08)[source]

Initialize the PaiNN layer, made up of an interaction block and a mixing block.

Parameters:
  • hidden_size – Number of node features.

  • activation – Activation function.

  • layer_num – Numbering of the layer.

  • blocks – Number of layers in the context networks.

  • aggregate_fn – Function to aggregate the neighbors.

  • eps – Constant added inside the norm to prevent instabilities when differentiating.

__call__(graph: GraphsTuple, Wij: Array)[source]

Compute interaction output.

Parameters:
  • graph (jraph.GraphsTuple) – Input graph.

  • Wij (jnp.ndarray) – Filter.

Returns:

Node features after interaction.

class lagrangebench.models.painn.PaiNN(hidden_size: int, output_size: int, num_mp_steps: int, radial_basis_fn: Callable, cutoff_fn: Callable, n_vels: int, homogeneous_particles: bool = True, activation: Callable = <function silu>, shared_interactions: bool = False, shared_filters: bool = False, eps: float = 1e-08)[source]

Polarizable interaction Neural Network by Schütt et al.

To accommodate general inputs and outputs, this PaiNN differs from the original in a few ways; the main change is that input vectors are no longer initialized to 0 but to the time average of the velocity.
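The time average used to initialize the input vectors can be sketched as follows. The block assumes that each row of “vel_hist” stores the K historic velocities one after another (time-major layout); that layout and the helper name `mean_velocity` are assumptions for illustration, not taken from the library.

```python
from typing import List


def mean_velocity(vel_hist_row: List[float], dim: int) -> List[float]:
    """Average K velocities stored flat as [v_1, ..., v_K], each of length dim."""
    k = len(vel_hist_row) // dim
    steps = [vel_hist_row[t * dim : (t + 1) * dim] for t in range(k)]
    return [sum(v[c] for v in steps) / k for c in range(dim)]


# One particle, K = 3 historic velocities in 2D.
row = [1.0, 0.0, 2.0, 2.0, 3.0, 4.0]
v_init = mean_velocity(row, dim=2)
```

Starting from a non-zero vector gives the equivariant channels directional information from the first layer on, instead of having to build it up from relative positions alone.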

Image: https://i.imgur.com/NxZ2rPi.png

__init__(hidden_size: int, output_size: int, num_mp_steps: int, radial_basis_fn: Callable, cutoff_fn: Callable, n_vels: int, homogeneous_particles: bool = True, activation: Callable = <function silu>, shared_interactions: bool = False, shared_filters: bool = False, eps: float = 1e-08)[source]

Initialize the model.

Parameters:
  • hidden_size – Determines the size of each embedding vector.

  • output_size – Number of output features.

  • num_mp_steps – Number of interaction blocks.

  • radial_basis_fn – Expands inter-particle distances in a basis set.

  • cutoff_fn – Cutoff function.

  • n_vels – Number of historical velocities.

  • homogeneous_particles – If all particles are of homogeneous type.

  • activation – Activation function.

  • shared_interactions – If True, share the weights across interaction blocks.

  • shared_filters – If True, share the weights across filter networks.

  • eps – Constant added inside the norm to prevent instabilities when differentiating.

__call__(sample: Tuple[Dict[str, Array], Array]) Dict[str, Array][source]

Call self as a function.

Linear

Simple baseline linear model.

class lagrangebench.models.linear.Linear(dim_out)[source]

Model defining linear relation between input nodes and targets.

\(\mathbf{a}_i = \mathbf{W} \mathbf{x}_i\) where \(\mathbf{a}_i\) are the output accelerations, \(\mathbf{W}\) is a learnable weight matrix and \(\mathbf{x}_i\) are input features.
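Written out for a single node, the relation is a matrix-vector product. In the sketch below the weight values, the number of input features, and the helper name `linear_forward` are made up for illustration; in the model W is learned.

```python
from typing import List


def linear_forward(W: List[List[float]], x: List[float]) -> List[float]:
    """Compute a = W x for one node."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


W = [[1.0, 0.0, 0.5], [0.0, 2.0, -1.0]]  # (dim_out, n_features) = (2, 3)
x_i = [1.0, 2.0, 4.0]  # input features of node i
a_i = linear_forward(W, x_i)
```

The same W is applied to every node, so the model has no access to neighbor information and serves purely as a baseline.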

__init__(dim_out)[source]

Initialize the model.

Parameters:

dim_out – Output dimensionality.

__call__(sample: Tuple[Dict[str, Array], ndarray]) Dict[str, Array][source]

Forward pass.

We specify the dimensions of the inputs and outputs using the number of nodes N, the number of edges E, number of historic velocities K (=input_seq_length - 1), and the dimensionality of the feature vectors dim.

Parameters:

sample

Tuple with feature dictionary and particle type. Possible features:

  • “abs_pos” (N, K+1, dim), absolute positions

  • “vel_hist” (N, K*dim), historical velocity sequence

  • “vel_mag” (N,), velocity magnitudes

  • “bound” (N, 2*dim), distance to boundaries

  • “force” (N, dim), external force field

  • “rel_disp” (E, dim), relative displacement vectors

  • “rel_dist” (E, 1), relative distances, i.e. magnitude of displacements

  • “senders” (E,), sender indices

  • “receivers” (E,), receiver indices

Returns:

Dict with model output. It must contain at least one of the following keys:

  • “acc” (N, dim), (normalized) acceleration

  • “vel” (N, dim), (normalized) velocity

  • “pos” (N, dim), (absolute) next position

Utils

class lagrangebench.models.utils.LinearXav(output_size: int, with_bias: bool = True, w_init: Callable[[Sequence[int], Any], Array] | None = None, b_init: Callable[[Sequence[int], Any], Array] | None = None, name: str | None = None)[source]

Linear layer with Xavier init. Avoids passing ‘w_init’ everywhere.

__init__(output_size: int, with_bias: bool = True, w_init: Callable[[Sequence[int], Any], Array] | None = None, b_init: Callable[[Sequence[int], Any], Array] | None = None, name: str | None = None)[source]

Constructs the Linear module.

Parameters:
  • output_size – Output dimensionality.

  • with_bias – Whether to add a bias to the output.

  • w_init – Optional initializer for weights. If None, Xavier initialization is used, in line with the class description.

  • b_init – Optional initializer for bias. By default, zero.

  • name – Name of the module.

class lagrangebench.models.utils.MLPXav(output_sizes: Iterable[int], with_bias: bool = True, w_init: Callable[[Sequence[int], Any], Array] | None = None, b_init: Callable[[Sequence[int], Any], Array] | None = None, activation: Callable = <function silu>, activate_final: bool = False, name: str | None = None)[source]

MLP layer with Xavier init. Avoids passing ‘w_init’ everywhere.

__init__(output_sizes: Iterable[int], with_bias: bool = True, w_init: Callable[[Sequence[int], Any], Array] | None = None, b_init: Callable[[Sequence[int], Any], Array] | None = None, activation: Callable = <function silu>, activate_final: bool = False, name: str | None = None)[source]

Constructs an MLP.

Parameters:
  • output_sizes – Sequence of layer sizes.

  • w_init – Initializer for Linear weights.

  • b_init – Initializer for Linear bias. Must be None if with_bias=False.

  • with_bias – Whether or not to apply a bias in each layer.

  • activation – Activation function to apply between Linear layers. Defaults to SiLU.

  • activate_final – Whether or not to activate the final layer of the MLP.

  • name – Optional name for this module.

Raises:

ValueError – If with_bias is False and b_init is not None.

class lagrangebench.models.utils.SteerableGraphsTuple(graph: GraphsTuple, node_attributes: IrrepsArray | None = None, edge_attributes: IrrepsArray | None = None, additional_message_features: IrrepsArray | None = None)[source]

Pack (steerable) node and edge attributes with jraph.GraphsTuple.

graph

jraph.GraphsTuple, graph structure

Type:

jraph._src.graph.GraphsTuple

node_attributes

(N, irreps.dim), node attributes \(\mathbf{\hat{a}}_i\)

Type:

e3nn_jax._src.irreps_array.IrrepsArray | None

edge_attributes

(E, irreps.dim), edge attributes \(\mathbf{\hat{a}}_{ij}\)

Type:

e3nn_jax._src.irreps_array.IrrepsArray | None

additional_message_features

(E, edge_dim), optional message features

Type:

e3nn_jax._src.irreps_array.IrrepsArray | None

lagrangebench.models.utils.node_irreps(metadata: Dict, input_seq_length: int, has_external_force: bool, has_magnitudes: bool, has_homogeneous_particles: bool) str[source]

Compute input node irreps based on which features are available.

lagrangebench.models.utils.build_mlp(latent_size, output_size, num_hidden_layers, is_layer_norm=True, **kwds: Dict)[source]

MLP generation helper using Haiku.

lagrangebench.models.utils.features_2d_to_3d(features)[source]

Add zeros in the z component of 2D features.
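As an illustration of the padding, the sketch below appends a zero z component to a list of 2D vectors. It works on plain lists of per-node vectors; how the library function treats each entry of the feature dictionary is not shown here, and `pad_2d_to_3d` is a hypothetical name.

```python
from typing import List


def pad_2d_to_3d(vectors: List[List[float]]) -> List[List[float]]:
    """Append a zero z component to every 2D vector."""
    return [[x, y, 0.0] for x, y in vectors]


force_2d = [[0.0, -9.81], [0.0, -9.81]]
force_3d = pad_2d_to_3d(force_2d)
```

Padding of this kind lets models that expect 3D vectors, such as those built on O(3) irreps, run on 2D datasets without changing the in-plane components.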