Source code for lagrangebench.models.segnn

"""
Steerable E(3) equivariant GNN from
`Brandstetter et al. <https://arxiv.org/abs/2110.02905>`_.
SEGNN model, layers and feature transform.

Original implementation: https://github.com/RobDHess/Steerable-E3-GNN

Standalone implementation + validation: https://github.com/gerkone/segnn-jax
"""


import warnings
from math import prod
from typing import Any, Callable, Dict, Optional, Tuple, Union

import e3nn_jax as e3nn
import haiku as hk
import jax
import jax.numpy as jnp
import jraph
from e3nn_jax import Irreps, IrrepsArray
from jax.tree_util import Partial, tree_map

from lagrangebench.utils import NodeType

from .base import BaseModel
from .utils import SteerableGraphsTuple, features_2d_to_3d


def uniform_init(
    name: str,
    path_shape: Tuple[int, ...],
    weight_std: float,
    dtype: jnp.dtype = jnp.float32,
) -> jnp.ndarray:
    """Request a haiku parameter initialized uniformly in [-weight_std, weight_std].

    Args:
        name: Name under which the parameter is requested from haiku.
        path_shape: Shape of the parameter.
        weight_std: Half-width of the symmetric uniform sampling interval.
        dtype: Parameter dtype.

    Returns:
        Whatever ``hk.get_parameter`` returns for the requested parameter.
    """
    symmetric_uniform = hk.initializers.RandomUniform(
        minval=-weight_std, maxval=weight_std
    )
    return hk.get_parameter(
        name, shape=path_shape, dtype=dtype, init=symmetric_uniform
    )


class O3TensorProduct(hk.Module):
    r"""
    O(3) equivariant linear parametrized tensor product layer.

    Applies a linear (parametrized) tensor product of representations to the
    input(s).

    .. math::

        \begin{align}
            tp(x, y) := \mathbf{x} \otimes_{CG}^{\mathcal{W}} \mathbf{y}
        \end{align}

    where :math:`\mathcal{W}` are learnable parameters.

    Uses :code:`tensor_product` + :code:`Linear` instead of
    FullyConnectedTensorProduct. From e3nn 0.19.2
    (https://github.com/e3nn/e3nn-jax/releases/tag/0.19.2), this is as fast as
    FullyConnectedTensorProduct.
    """

    def __init__(
        self,
        output_irreps: e3nn.Irreps,
        *,
        biases: bool = True,
        name: Optional[str] = None,
        init_fn: Callable = uniform_init,
        gradient_normalization: Union[str, float] = "element",
        path_normalization: Union[str, float] = "element",
    ):
        """Initialize the tensor product.

        Args:
            output_irreps: Output representation
            biases: If set to true will add biases
            name: Name of the linear layer params
            init_fn: Weight initialization function. Default is uniform.
            gradient_normalization: Gradient normalization method. Default is
                "element".
                NOTE: gradient_normalization="element" is the default in torch
                and haiku.
            path_normalization: Path normalization method. Default is "element"
        """
        super().__init__(name=name)

        if not isinstance(output_irreps, e3nn.Irreps):
            output_irreps = e3nn.Irreps(output_irreps)
        self.output_irreps = output_irreps

        self._linear = e3nn.haiku.Linear(
            self.output_irreps,
            get_parameter=init_fn,
            # biases are only requested when the output contains scalars (0e)
            biases=(biases and "0e" in self.output_irreps),
            gradient_normalization=gradient_normalization,
            path_normalization=path_normalization,
        )

    def _check_input(
        self, x: e3nn.IrrepsArray, y: Optional[e3nn.IrrepsArray] = None
    ) -> Tuple[e3nn.IrrepsArray, e3nn.IrrepsArray]:
        """Fill in the default right operand and warn about unreachable outputs.

        Args:
            x: Left tensor
            y: Right tensor. If None, a single constant "1x0e" scalar equal to
                one (shape (1, 1), dtype of ``x``) is used instead.

        Returns:
            The pair ``(x, y)`` with ``y`` replaced by the default if needed.
        """
        # Compare against None explicitly: truthiness of an array-like goes
        # through its length, so an empty right operand must not be mistaken
        # for "no operand given".
        if y is None:
            y = e3nn.IrrepsArray("1x0e", jnp.ones((1, 1), dtype=x.dtype))

        if x.irreps.lmax == 0 and y.irreps.lmax == 0 and self.output_irreps.lmax > 0:
            warnings.warn(
                f"The specified output irreps ({self.output_irreps}) are not scalars "
                "but both operands are. This can have undesired behaviour (NaN). Try "
                "redistributing them into scalars or choose higher orders."
            )

        # requested output irreps that the product of the operands cannot produce
        miss = self.output_irreps.filter(drop=e3nn.tensor_product(x.irreps, y.irreps))
        if len(miss) > 0:
            warnings.warn(f"Output irreps: '{miss}' are unreachable and were ignored.")

        return x, y

    def __call__(
        self, x: e3nn.IrrepsArray, y: Optional[e3nn.IrrepsArray] = None
    ) -> e3nn.IrrepsArray:
        """Applies an O(3) equivariant linear parametrized tensor product layer.

        Args:
            x (IrrepsArray): Left tensor
            y (IrrepsArray): Right tensor. If None it defaults to np.ones.

        Returns:
            The output to the weighted tensor product (IrrepsArray).
        """
        x, y = self._check_input(x, y)
        # tensor product + linear
        tp = self._linear(e3nn.tensor_product(x, y))
        return tp


def O3TensorProductGate(
    output_irreps: e3nn.Irreps,
    *,
    biases: bool = True,
    scalar_activation: Optional[Callable] = None,
    gate_activation: Optional[Callable] = None,
    name: Optional[str] = None,
    init_fn: Optional[Callable] = None,
) -> Callable:
    r"""Build a gated (non-linear) O(3) equivariant tensor product layer.

    A linear tensor product of representations is applied to the input(s),
    followed by the gated nonlinearity of
    `Weiler et al. <https://arxiv.org/abs/1807.02547>`_. To make gating
    possible, the requested output representation is extended with one extra
    scalar per non-scalar irrep.

    Args:
        output_irreps: Output representation
        biases: Add biases
        scalar_activation: Activation function for scalars. Falls back to
            ``jax.nn.silu`` when not given.
        gate_activation: Activation function for higher order. Falls back to
            ``jax.nn.sigmoid`` when not given.
        name: Name of the linear layer params
        init_fn: Forwarded unchanged to ``O3TensorProduct``.
            NOTE(review): the default ``None`` is forwarded explicitly, so the
            ``uniform_init`` default of ``O3TensorProduct`` is not used for
            gated layers - confirm this is intended.

    Returns:
        Function that applies the gated tensor product layer.
    """
    if not isinstance(output_irreps, e3nn.Irreps):
        output_irreps = e3nn.Irreps(output_irreps)

    # one gating scalar for every irrep that is not an even scalar
    n_gates = output_irreps.num_irreps - output_irreps.count("0e")
    gate_irreps = e3nn.Irreps(f"{n_gates}x0e")
    lifted_irreps = (gate_irreps + output_irreps).regroup()

    tensor_product = O3TensorProduct(
        lifted_irreps,
        biases=biases,
        name=name,
        init_fn=init_fn,
    )

    scalar_activation = scalar_activation or jax.nn.silu
    gate_activation = gate_activation or jax.nn.sigmoid

    def _gated_tensor_product(
        x: e3nn.IrrepsArray, y: Optional[e3nn.IrrepsArray] = None, **kwargs
    ) -> e3nn.IrrepsArray:
        lifted = tensor_product(x, y, **kwargs)
        return e3nn.gate(
            lifted, even_act=scalar_activation, odd_gate_act=gate_activation
        )

    return _gated_tensor_product


def O3Embedding(embed_irreps: Irreps, embed_edges: bool = True) -> Callable:
    """Linear steerable embedding.

    Lifts the graph nodes (and, optionally, the additional message passing
    features) into the representation space given by ``embed_irreps``.

    Args:
        embed_irreps: Output representation
        embed_edges: If true also embed edges/message passing features

    Returns:
        Function to embed graph nodes (and optionally edges)
    """

    def _embedding(
        st_graph: SteerableGraphsTuple,
    ) -> SteerableGraphsTuple:
        inner_graph = st_graph.graph

        # nodes are steered by the node attributes
        node_embedder = O3TensorProduct(embed_irreps, name="embedding_nodes")
        embedded_nodes = node_embedder(inner_graph.nodes, st_graph.node_attributes)
        st_graph = st_graph._replace(graph=inner_graph._replace(nodes=embedded_nodes))

        if not embed_edges:
            return st_graph

        # NOTE edge embedding is not in the original paper but can get good results
        msg_embedder = O3TensorProduct(embed_irreps, name="embedding_msg_features")
        embedded_msg_features = msg_embedder(
            st_graph.additional_message_features, st_graph.edge_attributes
        )
        return st_graph._replace(additional_message_features=embedded_msg_features)

    return _embedding


def O3Decoder(
    latent_irreps: Irreps,
    output_irreps: Irreps,
    n_blocks: int = 1,
):
    """Steerable decoder.

    Args:
        latent_irreps: Representation from the previous block
        output_irreps: Output representation
        n_blocks: Number of tensor product blocks in the decoder

    Returns:
        Function mapping a steerable graph to its decoded node features: the
        latent node features go through ``n_blocks`` gated tensor products and
        a final ungated one, all steered by the node attributes.
    """

    def _decoder(st_graph: SteerableGraphsTuple):
        attributes = st_graph.node_attributes
        latent = st_graph.graph.nodes

        # gated readout blocks in the latent representation
        for block_idx in range(n_blocks):
            readout = O3TensorProductGate(latent_irreps, name=f"readout_{block_idx}")
            latent = readout(latent, attributes)

        # final projection to the output representation, without nonlinearity
        projection = O3TensorProduct(output_irreps, name="output")
        return projection(latent, attributes)

    return _decoder


class SEGNNLayer(hk.Module):
    """
    Steerable E(3) equivariant layer.

    Performs one message passing step (GN) in which both the message and the
    update function are equivariant.
    """

    def __init__(
        self,
        output_irreps: Irreps,
        layer_idx: int,
        n_blocks: int = 2,
        norm: Optional[str] = None,
        aggregate_fn: Optional[Callable] = jraph.segment_sum,
    ):
        """
        Initialize the layer.

        Args:
            output_irreps: Layer output representation
            layer_idx: Numbering of the layer; used in the module name
            n_blocks: Number of tensor product blocks in the layer
            norm: Normalization type. One of None, 'none', 'instance' or 'batch'
            aggregate_fn: Message aggregation function. Defaults to sum.
        """
        super().__init__(f"layer_{layer_idx}")
        assert norm in ("batch", "instance", "none", None), f"Unknown norm '{norm}'"
        self._aggregate_fn = aggregate_fn
        self._norm = norm
        self._n_blocks = n_blocks
        self._output_irreps = output_irreps

    def _message(
        self,
        edge_attribute: IrrepsArray,
        additional_message_features: IrrepsArray,
        edge_features: Any,
        incoming: IrrepsArray,
        outgoing: IrrepsArray,
        globals_: Any,
    ) -> IrrepsArray:
        """Steerable equivariant message function."""
        # edge features and globals are part of the call signature but unused
        del edge_features, globals_

        # messages start as the concatenation of both endpoint features
        msg = e3nn.concatenate([incoming, outgoing], axis=-1)
        if additional_message_features is not None:
            msg = e3nn.concatenate([msg, additional_message_features], axis=-1)

        # message mlp (phi_m in the paper) steered by edge attributes
        for block_idx in range(self._n_blocks):
            gated_tp = O3TensorProductGate(
                self._output_irreps, name=f"tp_{block_idx}"
            )
            msg = gated_tp(msg, edge_attribute)

        # NOTE: original implementation only applied batch norm to messages
        if self._norm == "batch":
            batch_norm = e3nn.haiku.BatchNorm(irreps=self._output_irreps)
            msg = batch_norm(msg)

        return msg

    def _update(
        self,
        node_attribute: IrrepsArray,
        nodes: IrrepsArray,
        senders: Any,
        msg: IrrepsArray,
        globals_: Any,
    ) -> IrrepsArray:
        """Steerable equivariant update function."""
        # sender aggregates and globals are part of the call signature but unused
        del senders, globals_

        hidden = e3nn.concatenate([nodes, msg], axis=-1)

        # update mlp (phi_f in the paper) steered by node attributes
        last_idx = self._n_blocks - 1
        for block_idx in range(last_idx):
            gated_tp = O3TensorProductGate(
                self._output_irreps, name=f"tp_{block_idx}"
            )
            hidden = gated_tp(hidden, node_attribute)

        # last update layer without activation
        final_tp = O3TensorProduct(self._output_irreps, name=f"tp_{last_idx}")
        update = final_tp(hidden, node_attribute)

        # residual connection
        nodes += update

        # normalization of the updated node features
        if self._norm == "batch" or self._norm == "instance":
            node_norm = e3nn.haiku.BatchNorm(
                irreps=self._output_irreps,
                instance=(self._norm == "instance"),
            )
            nodes = node_norm(nodes)

        return nodes

    def __call__(self, st_graph: SteerableGraphsTuple) -> SteerableGraphsTuple:
        """Perform a message passing step.

        Args:
            st_graph: Input graph

        Returns:
            The updated graph
        """
        # NOTE node_attributes, edge_attributes and additional_message_features
        # are never updated within the message passing layers
        edge_fn = Partial(
            self._message,
            st_graph.edge_attributes,
            st_graph.additional_message_features,
        )
        node_fn = Partial(self._update, st_graph.node_attributes)
        message_passing = jraph.GraphNetwork(
            update_node_fn=node_fn,
            update_edge_fn=edge_fn,
            aggregate_edges_for_nodes_fn=self._aggregate_fn,
        )
        return st_graph._replace(graph=message_passing(st_graph.graph))


def weight_balanced_irreps(
    scalar_units: int, irreps_right: Irreps, lmax: Optional[int] = None
) -> Irreps:
    """
    Determine left irreps so that the tensor product with irreps_right has at
    least scalar_units**2 weights.

    The left irreps are the spherical harmonics up to ``lmax`` with a common
    multiplicity ``n``; ``n`` is raised one at a time until the number of
    tensor product weights reaches that of a square linear layer with
    ``scalar_units`` units.

    Args:
        scalar_units: Number of units of the equivalent square linear layer
        irreps_right: Right irreps
        lmax: Maximum L of the left irreps. Defaults to ``irreps_right.lmax``.

    Returns:
        Left irreps

    Raises:
        ValueError: If no tensor product path from (left x right) leads back
            into the left irreps, in which case no multiplicity can ever reach
            the requested number of weights.
    """
    # irrep order
    if lmax is None:
        lmax = irreps_right.lmax
    # linear layer with square weight matrix
    linear_weights = scalar_units**2
    # raise hidden features until enough weights
    n = 0
    while True:
        n += 1
        irreps_left = (Irreps.spherical_harmonics(lmax) * n).sort().irreps.simplify()
        # number of weights: one (mul_left x mul_left x mul_right) block per path
        tp_weights = sum(
            mul_1**2 * mul_2
            for mul_1, ir_1 in irreps_left
            for mul_2, ir_2 in irreps_right
            for _, ir_out in irreps_left
            if ir_out in ir_1 * ir_2
        )
        if tp_weights >= linear_weights:
            break
        if tp_weights == 0:
            # The count scales with the multiplicities only, so if there is no
            # path at all it stays zero forever; fail instead of looping.
            raise ValueError(
                f"No tensor product path between left irreps '{irreps_left}' and "
                f"right irreps '{irreps_right}' maps back to the left irreps; "
                "cannot balance weights."
            )
    return Irreps(irreps_left)


class SEGNN(BaseModel):
    r"""
    Steerable E(3) equivariant network by
    `Brandstetter et al. <https://arxiv.org/abs/2110.02905>`_.

    SEGNNs are E(3)-equivariant graph neural networks based around tensor
    products of representations. By design, SEGNNs allow for flexible
    scalar/vectorial inputs and outputs for both edges and attributes.

    The message passing is modified as follows:

    .. math::

        \begin{align}
            \mathbf{m}_{ij} &= \textit{M}_{\mathbf{\hat{a}}_{ij}}\left(
                \mathbf{f}_i, \mathbf{f}_j, \| x_i - x_j \|^2 \right), \\
            \mathbf{f}^{\prime}_i &= \textit{U}_{\mathbf{\hat{a}}_i}\left(
                \mathbf{f}_i, \sum_{j\in\mathcal{N}(i)} \mathbf{m}_{ij} \right)
        \end{align}

    :math:`\mathbf{\hat{a}}_{ij}` and :math:`\mathbf{\hat{a}}_{i}` are edge and
    node attributes and the operators :math:`\textit{M}_{\mathbf{\hat{a}}_{ij}}`
    and :math:`\textit{U}_{\mathbf{\hat{a}}_{i}}` are defined as a tensor
    product of representations :math:`\otimes_{CG}` between the input and the
    attributes:

    .. math::

        \begin{align}
            \textit{U}_{\mathbf{\hat{a}}_i} :=
                \mathcal{W^n}_{\mathbf{\hat{a}}_i} (\dots \sigma (
                \mathcal{W^0}_{\mathbf{\hat{a}}_i} \mathbf{f} ) )
            \quad \text{with} \quad
            \mathcal{W}_{\mathbf{\hat{a}}_i} \mathbf{f} :=
                \mathbf{f} \otimes_{CG}^{\mathcal{W}} \mathbf{\hat{a}}
        \end{align}

    where
    :math:`\mathbf{f} = [\mathbf{f}_i,\sum_{j\in\mathcal{N}(i)} \mathbf{m}_{ij}]`
    are node features concatenated to the aggregated messages,
    :math:`\sigma` is a gated non-linearity and :math:`\mathcal{W}` are the
    tensor product parameters.
    :math:`\textit{M}_{\mathbf{\hat{a}}_{ij}}` is similarly defined, but with
    the nonlinearity on the last layer, with edge attributes
    :math:`\mathbf{\hat{a}}_{ij}` and
    :math:`\mathbf{f} = [ \mathbf{f}_i, \mathbf{f}_j, \|x_i - x_j\|^2 ]`
    """

    def __init__(
        self,
        node_features_irreps: Irreps,
        edge_features_irreps: Irreps,
        scalar_units: int,
        lmax_hidden: int,
        lmax_attributes: int,
        output_irreps: Irreps,
        num_mp_steps: int,
        n_vels: int,
        velocity_aggregate: str = "avg",
        homogeneous_particles: bool = True,
        norm: Optional[str] = None,
        blocks_per_step: int = 2,
        embed_msg_features: bool = False,
    ):
        """
        Initialize the network.

        Args:
            node_features_irreps: Irreps of the node features.
            edge_features_irreps: Irreps of the additional message passing
                features.
            scalar_units: Hidden units (lower bound). Actual number depends on
                lmax.
            lmax_hidden: Maximum L of the hidden layer representations.
            lmax_attributes: Maximum L of the attributes.
            output_irreps: Output representation.
            num_mp_steps: Number of message passing layers
            n_vels: Number of velocities in the history.
            velocity_aggregate: Velocity sequence aggregation method. Either
                'avg' or 'last'.
            homogeneous_particles: If all particles are of homogeneous type.
            norm: Normalization type. Either None, 'instance' or 'batch'
            blocks_per_step: Number of tensor product blocks in each message
                passing
            embed_msg_features: Set to true to also embed edges/message passing
                features
        """
        super().__init__()

        # network
        self._attribute_irreps = Irreps.spherical_harmonics(lmax_attributes)
        self._hidden_irreps = weight_balanced_irreps(
            scalar_units, self._attribute_irreps, lmax_hidden
        )
        self._output_irreps = output_irreps
        self._num_mp_steps = num_mp_steps
        self._embed_msg_features = embed_msg_features
        self._norm = norm
        self._blocks_per_step = blocks_per_step

        self._embedding = O3Embedding(
            self._hidden_irreps,
            embed_edges=self._embed_msg_features,
        )

        self._decoder = O3Decoder(
            latent_irreps=self._hidden_irreps,
            output_irreps=output_irreps,
            n_blocks=self._blocks_per_step,
        )

        # transform
        # the message lists exactly the values accepted by the check
        assert velocity_aggregate in [
            "avg",
            "last",
        ], "Invalid velocity aggregate. Must be one of 'avg' or 'last'."
        self._node_features_irreps = node_features_irreps
        self._edge_features_irreps = edge_features_irreps
        self._velocity_aggregate = velocity_aggregate
        self._n_vels = n_vels
        self._homogeneous_particles = homogeneous_particles

    def _transform(
        self, features: Dict[str, jnp.ndarray], particle_type: jnp.ndarray
    ) -> Tuple[SteerableGraphsTuple, int]:
        """Convert physical features to SteerableGraphsTuple for segnn.

        Args:
            features: Feature dictionary. Reads "vel_hist", "rel_disp",
                "senders" and "receivers", plus "vel_mag", "bound", "force"
                and "rel_dist" when present. The dictionary passed in is not
                modified.
            particle_type: Particle types; only used (one-hot encoded) when
                the particles are not homogeneous.

        Returns:
            The steerable graph and the spatial dimension (2 or 3) of the input.
        """
        # Work on a shallow copy: "vel_hist" is re-bound below and the caller's
        # dictionary must stay usable for further calls.
        features = dict(features)

        dim = features["vel_hist"].shape[1] // self._n_vels
        assert (
            dim == 3 or dim == 2
        ), "The velocity history should be of shape (n_nodes, n_vels * 3)."

        n_nodes = features["vel_hist"].shape[0]

        features["vel_hist"] = features["vel_hist"].reshape(n_nodes, self._n_vels, dim)

        if dim == 2:
            # add zeros for z component for E(3) equivariance
            features = features_2d_to_3d(features)

        if self._n_vels == 1:
            # drop only the history axis so a single-node graph keeps its node axis
            vel = jnp.squeeze(features["vel_hist"], axis=1)
        elif self._velocity_aggregate == "avg":
            vel = jnp.mean(features["vel_hist"], 1)
        elif self._velocity_aggregate == "last":
            vel = features["vel_hist"][:, -1, :]

        rel_pos = features["rel_disp"]

        edge_attributes = e3nn.spherical_harmonics(
            self._attribute_irreps, rel_pos, normalize=True, normalization="integral"
        )
        vel_embedding = e3nn.spherical_harmonics(
            self._attribute_irreps, vel, normalize=True, normalization="integral"
        )
        # scatter edge attributes to nodes (density)
        scattered_edges = tree_map(
            lambda e: jraph.segment_mean(e, features["receivers"], n_nodes),
            edge_attributes,
        )

        # node attributes as velocities + edge "density". Scalar default to 1.0
        node_attributes = e3nn.IrrepsArray(
            vel_embedding.irreps,
            (vel_embedding + scattered_edges).array.at[:, 0].set(1.0),
        )

        node_features = [features["vel_hist"].reshape(n_nodes, self._n_vels * 3)]
        node_features += [
            features[k] for k in ["vel_mag", "bound", "force"] if k in features
        ]
        node_features = jnp.concatenate(node_features, axis=-1)

        if not self._homogeneous_particles:
            particles = jax.nn.one_hot(particle_type, NodeType.SIZE)
            node_features = jnp.concatenate([node_features, particles], axis=-1)

        edge_features = [features[k] for k in ["rel_disp", "rel_dist"] if k in features]
        edge_features = jnp.concatenate(edge_features, axis=-1)

        feature_graph = jraph.GraphsTuple(
            nodes=IrrepsArray(self._node_features_irreps, node_features),
            edges=None,
            senders=features["senders"],
            receivers=features["receivers"],
            n_node=jnp.array([n_nodes]),
            n_edge=jnp.array([len(features["senders"])]),
            globals=None,
        )
        st_graph = SteerableGraphsTuple(
            graph=feature_graph,
            node_attributes=node_attributes,
            edge_attributes=edge_attributes,
            additional_message_features=IrrepsArray(
                self._edge_features_irreps, edge_features
            ),
        )

        return st_graph, dim

    def _postprocess(self, nodes: IrrepsArray, dim: int) -> Dict[str, jnp.ndarray]:
        """Turn decoded node features into the output dict with key "acc".

        For 2D inputs only the first two components are kept.
        """
        acc = jnp.squeeze(nodes.array)
        if dim == 2:
            acc = acc[:, :2]
        return {"acc": acc}

    def __call__(
        self, sample: Tuple[Dict[str, jnp.ndarray], jnp.ndarray]
    ) -> Dict[str, jnp.ndarray]:
        """Run the network on one sample.

        Args:
            sample: Tuple of (features dictionary, particle types).

        Returns:
            Dictionary with the predicted "acc".
        """
        # feature transformation
        st_graph, dim = self._transform(*sample)
        # node (and edge) embedding
        st_graph = self._embedding(st_graph)
        # message passing
        for n in range(self._num_mp_steps):
            st_graph = SEGNNLayer(
                self._hidden_irreps, n, n_blocks=self._blocks_per_step, norm=self._norm
            )(st_graph)
        # readout
        nodes = self._decoder(st_graph)
        out = self._postprocess(nodes, dim)
        return out