"""
Steerable E(3) equivariant GNN from
`Brandstetter et al. <https://arxiv.org/abs/2110.02905>`_.
SEGNN model, layers and feature transform.
Original implementation: https://github.com/RobDHess/Steerable-E3-GNN
Standalone implementation + validation: https://github.com/gerkone/segnn-jax
"""
import warnings
from math import prod
from typing import Any, Callable, Dict, Optional, Tuple, Union
import e3nn_jax as e3nn
import haiku as hk
import jax
import jax.numpy as jnp
import jraph
from e3nn_jax import Irreps, IrrepsArray
from jax.tree_util import Partial, tree_map
from lagrangebench.utils import NodeType
from .base import BaseModel
from .utils import SteerableGraphsTuple, features_2d_to_3d
def uniform_init(
    name: str,
    path_shape: Tuple[int, ...],
    weight_std: float,
    dtype: jnp.dtype = jnp.float32,
) -> jnp.ndarray:
    """Get a haiku parameter initialized uniformly in [-weight_std, weight_std].

    Args:
        name: Name of the haiku parameter.
        path_shape: Shape of the parameter.
        weight_std: Half-width of the symmetric sampling interval.
        dtype: Data type of the parameter.

    Returns:
        The parameter returned by ``hk.get_parameter``.
    """
    symmetric_uniform = hk.initializers.RandomUniform(
        minval=-weight_std,
        maxval=weight_std,
    )
    return hk.get_parameter(
        name, shape=path_shape, dtype=dtype, init=symmetric_uniform
    )
[docs]
class O3TensorProduct(hk.Module):
    r"""
    O(3) equivariant linear parametrized tensor product layer.

    Applies a linear (parametrized) tensor product of representations to the input(s).

    .. math::
        \begin{align}
        tp(x, y) := \mathbf{x} \otimes_{CG}^{\mathcal{W}} \mathbf{y}
        \end{align}

    where :math:`\mathcal{W}` are learnable parameters.

    Uses :code:`tensor_product` + :code:`Linear` instead of FullyConnectedTensorProduct.
    From e3nn 0.19.2 (https://github.com/e3nn/e3nn-jax/releases/tag/0.19.2), this is
    as fast as FullyConnectedTensorProduct.
    """

    def __init__(
        self,
        output_irreps: e3nn.Irreps,
        *,
        biases: bool = True,
        name: Optional[str] = None,
        init_fn: Callable = uniform_init,
        gradient_normalization: Union[str, float] = "element",
        path_normalization: Union[str, float] = "element",
    ):
        """Initialize the tensor product.

        Args:
            output_irreps: Output representation
            biases: If set to true will add biases
            name: Name of the linear layer params
            init_fn: Weight initialization function. Default is uniform.
            gradient_normalization: Gradient normalization method. Default is "element".
                NOTE: gradient_normalization="element" is the default in torch and haiku.
            path_normalization: Path normalization method. Default is "element"
        """
        super().__init__(name=name)

        if not isinstance(output_irreps, e3nn.Irreps):
            output_irreps = e3nn.Irreps(output_irreps)
        self.output_irreps = output_irreps

        self._linear = e3nn.haiku.Linear(
            self.output_irreps,
            get_parameter=init_fn,
            # biases are only requested when the output contains 0e scalars
            biases=(biases and "0e" in self.output_irreps),
            gradient_normalization=gradient_normalization,
            path_normalization=path_normalization,
        )

    def _check_input(
        self, x: e3nn.IrrepsArray, y: Optional[e3nn.IrrepsArray] = None
    ) -> Tuple[e3nn.IrrepsArray, e3nn.IrrepsArray]:
        """Fill in the default right operand and warn about problematic irreps.

        Args:
            x: Left tensor
            y: Right tensor. If None, a single "1x0e" scalar equal to one is used.

        Returns:
            The pair (x, y) to feed to the tensor product.
        """
        # Explicit None check: a truthiness test would also discard a right
        # operand that was provided but happens to be empty.
        if y is None:
            y = e3nn.IrrepsArray("1x0e", jnp.ones((1, 1), dtype=x.dtype))

        if x.irreps.lmax == 0 and y.irreps.lmax == 0 and self.output_irreps.lmax > 0:
            warnings.warn(
                f"The specified output irreps ({self.output_irreps}) are not scalars "
                "but both operands are. This can have undesired behaviour (NaN). Try "
                "redistributing them into scalars or choose higher orders."
            )

        # output irreps that the tensor product of the inputs cannot produce
        miss = self.output_irreps.filter(drop=e3nn.tensor_product(x.irreps, y.irreps))
        if len(miss) > 0:
            warnings.warn(f"Output irreps: '{miss}' are unreachable and were ignored.")

        return x, y

    def __call__(
        self, x: e3nn.IrrepsArray, y: Optional[e3nn.IrrepsArray] = None
    ) -> e3nn.IrrepsArray:
        """Applies an O(3) equivariant linear parametrized tensor product layer.

        Args:
            x (IrrepsArray): Left tensor
            y (IrrepsArray): Right tensor. If None it defaults to np.ones.

        Returns:
            The output to the weighted tensor product (IrrepsArray).
        """
        x, y = self._check_input(x, y)
        # tensor product + linear
        tp = self._linear(e3nn.tensor_product(x, y))
        return tp
[docs]
def O3TensorProductGate(
    output_irreps: e3nn.Irreps,
    *,
    biases: bool = True,
    scalar_activation: Optional[Callable] = None,
    gate_activation: Optional[Callable] = None,
    name: Optional[str] = None,
    init_fn: Optional[Callable] = None,
) -> Callable:
    r"""Non-linear (gated) O(3) equivariant linear tensor product layer.

    It applies a linear tensor product of representations to the input(s) and then
    a gated nonlinearity by `Weiler et al. <https://arxiv.org/abs/1807.02547>`_.
    The input representation is lifted to have gating scalars.

    Args:
        output_irreps: Output representation
        biases: Add biases
        scalar_activation: Activation function for scalars. Defaults to silu.
        gate_activation: Activation function for the gate scalars that multiply
            the higher order irreps. Defaults to sigmoid.
        name: Name of the linear layer params
        init_fn: Weight initialization function, forwarded unchanged to
            :class:`O3TensorProduct` (including when it is None).

    Returns:
        Function that applies the gated tensor product layer.
    """
    if not isinstance(output_irreps, e3nn.Irreps):
        output_irreps = e3nn.Irreps(output_irreps)

    # lift output with gating scalars: one extra 0e per non-0e output irrep
    gate_irreps = e3nn.Irreps(
        f"{output_irreps.num_irreps - output_irreps.count('0e')}x0e"
    )
    tensor_product = O3TensorProduct(
        (gate_irreps + output_irreps).regroup(),
        biases=biases,
        name=name,
        init_fn=init_fn,
    )

    if scalar_activation is None:
        scalar_activation = jax.nn.silu
    if gate_activation is None:
        gate_activation = jax.nn.sigmoid

    def _gated_tensor_product(
        x: e3nn.IrrepsArray, y: Optional[e3nn.IrrepsArray] = None, **kwargs
    ) -> e3nn.IrrepsArray:
        tp = tensor_product(x, y, **kwargs)
        # The gating scalars added above are even (0e), which e3nn.gate activates
        # with `even_gate_act`; `odd_gate_act` only touches odd (0o) gate scalars.
        # Pass the requested activation to both so it is actually applied.
        return e3nn.gate(
            tp,
            even_act=scalar_activation,
            even_gate_act=gate_activation,
            odd_gate_act=gate_activation,
        )

    return _gated_tensor_product
[docs]
def O3Embedding(embed_irreps: Irreps, embed_edges: bool = True) -> Callable:
    """Linear steerable embedding.

    Embeds the graph nodes in the representation space :param embed_irreps:.

    Args:
        embed_irreps: Output representation
        embed_edges: If true also embed edges/message passing features

    Returns:
        Function to embed graph nodes (and optionally edges)
    """

    def _embedding(
        st_graph: SteerableGraphsTuple,
    ) -> SteerableGraphsTuple:
        graph = st_graph.graph
        # node features steered by the node attributes
        nodes = O3TensorProduct(
            embed_irreps,
            name="embedding_nodes",
        )(graph.nodes, st_graph.node_attributes)
        st_graph = st_graph._replace(graph=graph._replace(nodes=nodes))

        # NOTE edge embedding is not in the original paper but can get good results
        if embed_edges:
            # The layer must be *applied* to the message features (steered by the
            # edge attributes); storing the bare module would put a layer object
            # where an IrrepsArray is expected.
            additional_message_features = O3TensorProduct(
                embed_irreps,
                name="embedding_msg_features",
            )(st_graph.additional_message_features, st_graph.edge_attributes)
            st_graph = st_graph._replace(
                additional_message_features=additional_message_features
            )

        return st_graph

    return _embedding
[docs]
def O3Decoder(
    latent_irreps: Irreps,
    output_irreps: Irreps,
    n_blocks: int = 1,
):
    """Steerable decoder.

    Stacks ``n_blocks`` gated tensor product blocks followed by one plain tensor
    product, all steered by the node attributes.

    Args:
        latent_irreps: Representation from the previous block
        output_irreps: Output representation
        n_blocks: Number of tensor product blocks in the decoder

    Returns:
        Function decoding the latent node features into the output space.
    """

    def _decoder(st_graph: SteerableGraphsTuple):
        attributes = st_graph.node_attributes
        hidden = st_graph.graph.nodes

        # gated readout blocks in the latent representation
        for block_idx in range(n_blocks):
            gated_block = O3TensorProductGate(
                latent_irreps, name=f"readout_{block_idx}"
            )
            hidden = gated_block(hidden, attributes)

        # final projection to the output representation, no nonlinearity
        projection = O3TensorProduct(output_irreps, name="output")
        return projection(hidden, attributes)

    return _decoder
[docs]
class SEGNNLayer(hk.Module):
    """
    Steerable E(3) equivariant layer.

    One message passing step (jraph GraphNetwork) whose edge and node update
    functions are stacks of steerable tensor product blocks.
    """

    def __init__(
        self,
        output_irreps: Irreps,
        layer_idx: int,
        n_blocks: int = 2,
        norm: Optional[str] = None,
        aggregate_fn: Optional[Callable] = jraph.segment_sum,
    ):
        """
        Initialize the layer.

        Args:
            output_irreps: Layer output representation
            layer_idx: Numbering of the layer, used in the module name
            n_blocks: Number of tensor product blocks in the layer
            norm: Normalization type. One of None, 'none', 'instance' or 'batch'
            aggregate_fn: Message aggregation function. Defaults to sum.
        """
        super().__init__(f"layer_{layer_idx}")
        assert norm in ("batch", "instance", "none", None), f"Unknown norm '{norm}'"

        self._output_irreps = output_irreps
        self._n_blocks = n_blocks
        self._norm = norm
        self._aggregate_fn = aggregate_fn

    def _message(
        self,
        edge_attribute: IrrepsArray,
        additional_message_features: IrrepsArray,
        edge_features: Any,
        incoming: IrrepsArray,
        outgoing: IrrepsArray,
        globals_: Any,
    ) -> IrrepsArray:
        """Steerable equivariant message function."""
        # part of the GraphNetwork callback signature, not used here
        del edge_features, globals_

        # message input: both endpoint features, plus optional extra features
        msg = e3nn.concatenate([incoming, outgoing], axis=-1)
        if additional_message_features is not None:
            msg = e3nn.concatenate([msg, additional_message_features], axis=-1)

        # message mlp (phi_m in the paper) steered by the edge attributes
        for block_idx in range(self._n_blocks):
            gated_block = O3TensorProductGate(
                self._output_irreps, name=f"tp_{block_idx}"
            )
            msg = gated_block(msg, edge_attribute)

        # NOTE: original implementation only applied batch norm to messages
        if self._norm == "batch":
            msg = e3nn.haiku.BatchNorm(irreps=self._output_irreps)(msg)

        return msg

    def _update(
        self,
        node_attribute: IrrepsArray,
        nodes: IrrepsArray,
        senders: Any,
        msg: IrrepsArray,
        globals_: Any,
    ) -> IrrepsArray:
        """Steerable equivariant update function."""
        # part of the GraphNetwork callback signature, not used here
        del senders, globals_

        x = e3nn.concatenate([nodes, msg], axis=-1)

        # update mlp (phi_f in the paper) steered by the node attributes;
        # all blocks but the last one are gated
        last_idx = self._n_blocks - 1
        for block_idx in range(last_idx):
            gated_block = O3TensorProductGate(
                self._output_irreps, name=f"tp_{block_idx}"
            )
            x = gated_block(x, node_attribute)

        # last update block without activation
        final_block = O3TensorProduct(self._output_irreps, name=f"tp_{last_idx}")
        update = final_block(x, node_attribute)

        # residual connection
        nodes += update

        # normalization of the updated node features
        if self._norm in ("batch", "instance"):
            use_instance_norm = self._norm == "instance"
            nodes = e3nn.haiku.BatchNorm(
                irreps=self._output_irreps,
                instance=use_instance_norm,
            )(nodes)

        return nodes

    def __call__(self, st_graph: SteerableGraphsTuple) -> SteerableGraphsTuple:
        """Perform a message passing step.

        Args:
            st_graph: Input graph

        Returns:
            The updated graph
        """
        # NOTE node_attributes, edge_attributes and additional_message_features
        # are never updated within the message passing layers; they are bound
        # as leading arguments of the update functions.
        edge_fn = Partial(
            self._message,
            st_graph.edge_attributes,
            st_graph.additional_message_features,
        )
        node_fn = Partial(self._update, st_graph.node_attributes)

        network = jraph.GraphNetwork(
            update_node_fn=node_fn,
            update_edge_fn=edge_fn,
            aggregate_edges_for_nodes_fn=self._aggregate_fn,
        )
        return st_graph._replace(graph=network(st_graph.graph))
[docs]
def weight_balanced_irreps(
    scalar_units: int, irreps_right: Irreps, lmax: int = None
) -> Irreps:
    """
    Determine left irreps so that the tensor product with irreps_right has at least
    as many weights as a square linear layer with scalar_units units.

    Args:
        scalar_units: Number of units of the reference linear layer
        irreps_right: Right irreps
        lmax: Maximum L of the left irreps. Defaults to the lmax of irreps_right.

    Returns:
        Left irreps
    """
    # irrep order
    if lmax is None:
        lmax = irreps_right.lmax

    # weight budget: linear layer with square weight matrix
    target_weights = scalar_units**2

    # raise the multiplicity of the hidden features until there are enough weights
    multiplicity = 0
    while True:
        multiplicity += 1
        irreps_left = (
            (Irreps.spherical_harmonics(lmax) * multiplicity).sort().irreps.simplify()
        )

        # number of weights over all (left, right, out) paths
        tp_weights = 0
        for mul_1, ir_1 in irreps_left:
            for mul_2, ir_2 in irreps_right:
                # materialized once so it can be tested against every output irrep
                reachable = tuple(ir_1 * ir_2)
                for _, ir_out in irreps_left:
                    if ir_out in reachable:
                        tp_weights += mul_1**2 * mul_2

        if tp_weights >= target_weights:
            return Irreps(irreps_left)
[docs]
class SEGNN(BaseModel):
    r"""
    Steerable E(3) equivariant network by
    `Brandstetter et al. <https://arxiv.org/abs/2110.02905>`_.

    SEGNNs are E(3)-equivariant graph neural networks based around tensor products of
    representations. By design, SEGNNs allow for flexible scalar/vectorial inputs and
    outputs for both edges and attributes. The message passing is modified as follows:

    .. math::
        \begin{align}
        \mathbf{m}_{ij} &= \textit{M}_{\mathbf{\hat{a}}_{ij}}\left(
            \mathbf{f}_i, \mathbf{f}_j, \| x_i - x_j \|^2 \right), \\
        \mathbf{f}^{\prime}_i &= \textit{U}_{\mathbf{\hat{a}}_i}\left(
            \mathbf{f}_i, \sum_{j\in\mathcal{N}(i)} \mathbf{m}_{ij} \right)
        \end{align}

    :math:`\mathbf{\hat{a}}_{ij}` and :math:`\mathbf{\hat{a}}_{i}` are edge and
    node attributes and the operators :math:`\textit{M}_{\mathbf{\hat{a}}_{ij}}` and
    :math:`\textit{U}_{\mathbf{\hat{a}}_{i}}` are defined as a tensor product of
    representations :math:`\otimes_{CG}` between the input and the attributes:

    .. math::
        \begin{align}
        \textit{U}_{\mathbf{\hat{a}}_i} :=
            \mathcal{W^n}_{\mathbf{\hat{a}}_i}
            (\dots \sigma ( \mathcal{W^0}_{\mathbf{\hat{a}}_i} \mathbf{f} ) )
        \quad \text{with} \quad
        \mathcal{W}_{\mathbf{\hat{a}}_i} \mathbf{f} :=
            \mathbf{f} \otimes_{CG}^{\mathcal{W}} \mathbf{\hat{a}}
        \end{align}

    where
    :math:`\mathbf{f} = \left[\mathbf{f}_i,\sum_{j\in\mathcal{N}(i)} \mathbf{m}_{ij}\right]`
    are node features concatenated to the aggregated messages, :math:`\sigma` is a gated
    non-linearity and :math:`\mathcal{W}` are the tensor product parameters.
    :math:`\textit{M}_{\mathbf{\hat{a}}_{ij}}` is similarly defined, but with the
    nonlinearity on the last layer, with edge attributes :math:`\mathbf{\hat{a}}_{ij}`
    and :math:`\mathbf{f} = \left[ \mathbf{f}_i, \mathbf{f}_j, \|x_i - x_j\|^2 \right]`
    """

    def __init__(
        self,
        node_features_irreps: Irreps,
        edge_features_irreps: Irreps,
        scalar_units: int,
        lmax_hidden: int,
        lmax_attributes: int,
        output_irreps: Irreps,
        num_mp_steps: int,
        n_vels: int,
        velocity_aggregate: str = "avg",
        homogeneous_particles: bool = True,
        norm: Optional[str] = None,
        blocks_per_step: int = 2,
        embed_msg_features: bool = False,
    ):
        """
        Initialize the network.

        Args:
            node_features_irreps: Irreps of the node features.
            edge_features_irreps: Irreps of the additional message passing features.
            scalar_units: Hidden units (lower bound). Actual number depends on lmax.
            lmax_hidden: Maximum L of the hidden layer representations.
            lmax_attributes: Maximum L of the attributes.
            output_irreps: Output representation.
            num_mp_steps: Number of message passing layers
            n_vels: Number of velocities in the history.
            velocity_aggregate: Velocity sequence aggregation method, 'avg' or 'last'.
            homogeneous_particles: If all particles are of homogeneous type.
            norm: Normalization type. Either None, 'instance' or 'batch'
            blocks_per_step: Number of tensor product blocks in each message passing
            embed_msg_features: Set to true to also embed edges/message passing features
        """
        super().__init__()

        # network
        self._attribute_irreps = Irreps.spherical_harmonics(lmax_attributes)
        self._hidden_irreps = weight_balanced_irreps(
            scalar_units, self._attribute_irreps, lmax_hidden
        )
        self._output_irreps = output_irreps
        self._num_mp_steps = num_mp_steps
        self._embed_msg_features = embed_msg_features
        self._norm = norm
        self._blocks_per_step = blocks_per_step

        self._embedding = O3Embedding(
            self._hidden_irreps,
            embed_edges=self._embed_msg_features,
        )
        self._decoder = O3Decoder(
            latent_irreps=self._hidden_irreps,
            output_irreps=output_irreps,
            n_blocks=self._blocks_per_step,
        )

        # transform
        # the message lists exactly the values that are accepted
        assert velocity_aggregate in [
            "avg",
            "last",
        ], "Invalid velocity aggregate. Must be one of 'avg' or 'last'."
        self._node_features_irreps = node_features_irreps
        self._edge_features_irreps = edge_features_irreps
        self._velocity_aggregate = velocity_aggregate
        self._n_vels = n_vels
        self._homogeneous_particles = homogeneous_particles

    def _transform(
        self, features: Dict[str, jnp.ndarray], particle_type: jnp.ndarray
    ) -> Tuple[SteerableGraphsTuple, int]:
        """Convert physical features to SteerableGraphsTuple for segnn.

        Args:
            features: Feature dictionary. Reads "vel_hist", "rel_disp", "senders"
                and "receivers", plus "vel_mag", "bound", "force" and "rel_dist"
                when present. The dictionary passed in is not modified.
            particle_type: Particle types, one-hot encoded into the node features
                when particles are not homogeneous.

        Returns:
            The steerable graph and the spatial dimension (2 or 3) of the input.
        """
        # Work on a shallow copy: "vel_hist" is re-bound below, and doing that on
        # the caller's dict would break a second call with the same dict.
        features = dict(features)

        dim = features["vel_hist"].shape[1] // self._n_vels
        assert dim == 3 or dim == 2, (
            "The velocity history should be of shape (n_nodes, n_vels * dim), "
            "with dim equal to 2 or 3."
        )
        n_nodes = features["vel_hist"].shape[0]

        features["vel_hist"] = features["vel_hist"].reshape(n_nodes, self._n_vels, dim)

        if dim == 2:
            # add zeros for z component for E(3) equivariance
            features = features_2d_to_3d(features)

        if self._n_vels == 1:
            # drop only the history axis; a bare squeeze would also drop the
            # node axis of a single-node graph
            vel = jnp.squeeze(features["vel_hist"], axis=1)
        elif self._velocity_aggregate == "avg":
            vel = jnp.mean(features["vel_hist"], 1)
        elif self._velocity_aggregate == "last":
            vel = features["vel_hist"][:, -1, :]

        rel_pos = features["rel_disp"]

        edge_attributes = e3nn.spherical_harmonics(
            self._attribute_irreps, rel_pos, normalize=True, normalization="integral"
        )
        vel_embedding = e3nn.spherical_harmonics(
            self._attribute_irreps, vel, normalize=True, normalization="integral"
        )
        # scatter edge attributes to nodes (density)
        scattered_edges = tree_map(
            lambda e: jraph.segment_mean(e, features["receivers"], n_nodes),
            edge_attributes,
        )

        # node attributes as velocities + edge "density". Scalar default to 1.0
        node_attributes = e3nn.IrrepsArray(
            vel_embedding.irreps,
            (vel_embedding + scattered_edges).array.at[:, 0].set(1.0),
        )

        # velocity history is flattened back with 3 components per velocity
        node_features = [features["vel_hist"].reshape(n_nodes, self._n_vels * 3)]
        node_features += [
            features[k] for k in ["vel_mag", "bound", "force"] if k in features
        ]
        node_features = jnp.concatenate(node_features, axis=-1)

        if not self._homogeneous_particles:
            particles = jax.nn.one_hot(particle_type, NodeType.SIZE)
            node_features = jnp.concatenate([node_features, particles], axis=-1)

        edge_features = [features[k] for k in ["rel_disp", "rel_dist"] if k in features]
        edge_features = jnp.concatenate(edge_features, axis=-1)

        feature_graph = jraph.GraphsTuple(
            nodes=IrrepsArray(self._node_features_irreps, node_features),
            edges=None,
            senders=features["senders"],
            receivers=features["receivers"],
            n_node=jnp.array([n_nodes]),
            n_edge=jnp.array([len(features["senders"])]),
            globals=None,
        )
        st_graph = SteerableGraphsTuple(
            graph=feature_graph,
            node_attributes=node_attributes,
            edge_attributes=edge_attributes,
            additional_message_features=IrrepsArray(
                self._edge_features_irreps, edge_features
            ),
        )

        return st_graph, dim

    def _postprocess(self, nodes: IrrepsArray, dim: int) -> Dict[str, jnp.ndarray]:
        """Extract the acceleration array from the decoded node features.

        Args:
            nodes: Decoded node features.
            dim: Spatial dimension of the input; for 2 only the first two
                components are kept.

        Returns:
            Dictionary with the single key "acc".
        """
        acc = jnp.squeeze(nodes.array)
        if dim == 2:
            # drop the z component that was padded in for 2D inputs
            acc = acc[:, :2]
        return {"acc": acc}

    def __call__(
        self, sample: Tuple[Dict[str, jnp.ndarray], jnp.ndarray]
    ) -> Dict[str, jnp.ndarray]:
        """Run the model on a (features, particle_type) sample.

        Args:
            sample: Tuple of the feature dictionary and the particle types.

        Returns:
            Dictionary with the predicted "acc".
        """
        # feature transformation
        st_graph, dim = self._transform(*sample)
        # node (and edge) embedding
        st_graph = self._embedding(st_graph)
        # message passing
        for n in range(self._num_mp_steps):
            st_graph = SEGNNLayer(
                self._hidden_irreps, n, n_blocks=self._blocks_per_step, norm=self._norm
            )(st_graph)
        # readout
        nodes = self._decoder(st_graph)
        out = self._postprocess(nodes, dim)
        return out