"""
E(n) equivariant GNN from `Garcia Satorras et al. <https://arxiv.org/abs/2102.09844>`_.
EGNN model, layers and feature transform.
Original implementation: https://github.com/vgsatorras/egnn
Standalone implementation + validation: https://github.com/gerkone/egnn-jax
"""
from typing import Any, Callable, Dict, Optional, Tuple
import haiku as hk
import jax
import jax.numpy as jnp
import jraph
from jax.tree_util import Partial
from jax_sph.jax_md import space
from lagrangebench.utils import NodeType
from .base import BaseModel
from .utils import LinearXav, MLPXav
[docs]
class EGNNLayer(hk.Module):
r"""E(n)-equivariant EGNN layer.
Applies a message passing step where the positions are corrected with the velocities
and a learnable correction term :math:`\psi_x(\mathbf{h}_i^{(t+1)})`:
"""
[docs]
def __init__(
self,
layer_num: int,
hidden_size: int,
output_size: int,
displacement_fn: space.DisplacementFn,
shift_fn: space.ShiftFn,
blocks: int = 1,
act_fn: Callable = jax.nn.silu,
pos_aggregate_fn: Optional[Callable] = jraph.segment_sum,
msg_aggregate_fn: Optional[Callable] = jraph.segment_sum,
residual: bool = True,
attention: bool = False,
normalize: bool = False,
tanh: bool = False,
dt: float = 0.001,
eps: float = 1e-8,
):
"""Initialize the layer.
Args:
layer_num: layer number
hidden_size: hidden size
output_size: output size
displacement_fn: Displacement function for the acceleration computation.
shift_fn: Shift function for updating positions
blocks: number of blocks in the node and edge MLPs
act_fn: activation function
pos_aggregate_fn: position aggregation function
msg_aggregate_fn: message aggregation function
residual: whether to use residual connections
attention: whether to use attention
normalize: whether to normalize the coordinates
tanh: whether to use tanh in the position update
dt: position update step size
eps: small number to avoid division by zero
"""
super().__init__(f"layer_{layer_num}")
self._displacement_fn = displacement_fn
self._shift_fn = shift_fn
self.pos_aggregate_fn = pos_aggregate_fn
self.msg_aggregate_fn = msg_aggregate_fn
self._residual = residual
self._normalize = normalize
self._eps = eps
# message network
self._edge_mlp = MLPXav(
[hidden_size] * blocks + [hidden_size],
activation=act_fn,
activate_final=True,
)
# update network
self._node_mlp = MLPXav(
[hidden_size] * blocks + [output_size],
activation=act_fn,
activate_final=False,
)
# position update network
net = [LinearXav(hidden_size)] * blocks
# NOTE: from https://github.com/vgsatorras/egnn/blob/main/models/gcl.py#L254
net += [
act_fn,
LinearXav(1, with_bias=False, w_init=hk.initializers.UniformScaling(dt)),
]
if tanh:
net.append(jax.nn.tanh)
self._pos_correction_mlp = hk.Sequential(net)
# velocity integrator network
net = [LinearXav(hidden_size)] * blocks
net += [
act_fn,
LinearXav(1, with_bias=False, w_init=hk.initializers.UniformScaling(dt)),
]
self._vel_correction_mlp = hk.Sequential(net)
# attention
self._attention_mlp = None
if attention:
self._attention_mlp = hk.Sequential(
[LinearXav(hidden_size), jax.nn.sigmoid]
)
def _pos_update(
self,
pos: jnp.ndarray,
graph: jraph.GraphsTuple,
coord_diff: jnp.ndarray,
) -> jnp.ndarray:
trans = coord_diff * self._pos_correction_mlp(graph.edges)
return self.pos_aggregate_fn(trans, graph.senders, num_segments=pos.shape[0])
def _message(
self,
radial: jnp.ndarray,
edge_attribute: jnp.ndarray,
edge_features: Any,
incoming: jnp.ndarray,
outgoing: jnp.ndarray,
globals_: Any,
) -> jnp.ndarray:
_ = edge_features
_ = globals_
msg = jnp.concatenate([incoming, outgoing, radial], axis=-1)
if edge_attribute is not None:
msg = jnp.concatenate([msg, edge_attribute], axis=-1)
msg = self._edge_mlp(msg)
if self._attention_mlp:
att = self._attention_mlp(msg)
msg = msg * att
return msg
def _update(
self,
node_attribute: jnp.ndarray,
nodes: jnp.ndarray,
senders: Any,
msg: jnp.ndarray,
globals_: Any,
) -> jnp.ndarray:
_ = senders
_ = globals_
x = jnp.concatenate([nodes, msg], axis=-1)
if node_attribute is not None:
x = jnp.concatenate([x, node_attribute], axis=-1)
x = self._node_mlp(x)
if self._residual:
x = nodes + x
return x
def _coord2radial(
self, graph: jraph.GraphsTuple, coord: jnp.array
) -> Tuple[jnp.array, jnp.array]:
coord_diff = self._displacement_fn(coord[graph.senders], coord[graph.receivers])
radial = jnp.sum(coord_diff**2, 1)[:, jnp.newaxis]
if self._normalize:
norm = jnp.sqrt(radial)
coord_diff = coord_diff / (norm + self._eps)
return radial, coord_diff
[docs]
def __call__(
self,
graph: jraph.GraphsTuple,
pos: jnp.ndarray,
vel: jnp.ndarray,
edge_attribute: Optional[jnp.ndarray] = None,
node_attribute: Optional[jnp.ndarray] = None,
) -> Tuple[jraph.GraphsTuple, jnp.ndarray]:
"""
Apply EGNN layer.
Args:
graph: Graph from previous step
pos: Node position, updated separately
vel: Node velocity
edge_attribute: Edge attribute (optional)
node_attribute: Node attribute (optional)
Returns:
Updated graph, node position
"""
radial, coord_diff = self._coord2radial(graph, pos)
graph = jraph.GraphNetwork(
update_edge_fn=Partial(self._message, radial, edge_attribute),
update_node_fn=Partial(self._update, node_attribute),
aggregate_edges_for_nodes_fn=self.msg_aggregate_fn,
)(graph)
# update position
pos = self._shift_fn(pos, self._pos_update(pos, graph, coord_diff))
# integrate velocity
pos = self._shift_fn(pos, self._vel_correction_mlp(graph.nodes) * vel)
return graph, pos
[docs]
class EGNN(BaseModel):
r"""
E(n) Graph Neural Network by
`Garcia Satorras et al. <https://arxiv.org/abs/2102.09844>`_.
EGNN doesn't require expensive higher-order representations in intermediate layers;
instead it relies on separate scalar and vector channels, which are treated
differently by EGNN layers. In this setup, EGNN is similar to a learnable numerical
integrator:
.. math::
\begin{align}
\mathbf{m}_{ij}^{(t+1)} &= \phi_e \left(
\mathbf{m}_{ij}^{(t)}, \mathbf{h}_i^{(t)},
\mathbf{h}_j^{(t)}, ||\mathbf{x}_i^{(t)} - \mathbf{x}_j^{(t)}||^2
\right) \\
\mathbf{\hat{m}}_{ij}^{(t+1)} &=
(\mathbf{x}_i^{(t)} - \mathbf{x}_j^{(t)}) \phi_x(\mathbf{m}_{ij}^{(t+1)})
\end{align}
And the node update with the integrator
.. math::
\begin{align}
\mathbf{h}_i^{(t+1)} &= \psi_h \left(
\mathbf{h}_i^{(t)}, \sum_{j \in \mathcal{N}(i)} \mathbf{m}_{ij}^{(t+1)}
\right) \\
\mathbf{x}_i^{(t+1)} &= \mathbf{x}_i^{(t)}
+ \mathbf{\hat{m}}_{ij}^{(t+1)} \psi_x(\mathbf{h}_i^{(t+1)})
\end{align}
where :math:`\mathbf{m}_{ij}` and :math:`\mathbf{\hat{m}}_{ij}` are the scalar and
vector messages respectively, and :math:`\mathbf{x}_{i}` are the positions.
This implementation differs from the original in two places:
- because our datasets can have periodic boundary conditions, we use shift and
displacement functions that take care of it when operations on positions are done.
- we apply a simple integrator after the last layer to get the acceleration.
"""
[docs]
def __init__(
self,
hidden_size: int,
output_size: int,
dt: float,
n_vels: int,
displacement_fn: space.DisplacementFn,
shift_fn: space.ShiftFn,
normalization_stats: Optional[Dict[str, jnp.ndarray]] = None,
act_fn: Callable = jax.nn.silu,
num_mp_steps: int = 4,
homogeneous_particles: bool = True,
residual: bool = True,
attention: bool = False,
normalize: bool = False,
tanh: bool = False,
):
r"""
Initialize the network.
Args:
hidden_size: Number of hidden features.
output_size: Number of features for 'h' at the output.
dt: Time step for position and velocity integration. Used to rescale the
initialization of the correction MLP.
n_vels: Number of velocities in the history.
displacement_fn: Displacement function for the acceleration computation.
shift_fn: Shift function for updating positions.
normalization_stats: Normalization statistics for the input data.
act_fn: Non-linearity.
num_mp_steps: Number of layer for the EGNN
homogeneous_particles: If all particles are of homogeneous type.
residual: Whether to use residual connections.
attention: Whether to use attention or not.
normalize: Normalizes the coordinates messages such that:
``x^{l+1}_i = x^{l}_i + \sum(x_i - x_j)\phi_x(m_{ij})\|x_i - x_j\|``
It may help in the stability or generalization. Not used in the paper.
tanh: Sets a tanh activation function at the output of ``\phi_x(m_{ij})``.
It bounds the output of ``\phi_x(m_{ij})`` which definitely improves in
stability but it may decrease in accuracy. Not used in the paper.
"""
super().__init__()
# network
self._hidden_size = hidden_size
self._output_size = output_size
self._act_fn = act_fn
self._num_mp_steps = num_mp_steps
self._residual = residual
self._attention = attention
self._normalize = normalize
self._tanh = tanh
# integrator
self._dt = dt / self._num_mp_steps
self._displacement_fn = displacement_fn
self._shift_fn = shift_fn
if normalization_stats is None:
normalization_stats = {
"velocity": {"mean": 0.0, "std": 1.0},
"acceleration": {"mean": 0.0, "std": 1.0},
}
self._vel_stats = normalization_stats["velocity"]
self._acc_stats = normalization_stats["acceleration"]
# transform
self._n_vels = n_vels
self._homogeneous_particles = homogeneous_particles
def _transform(
self, features: Dict[str, jnp.ndarray], particle_type: jnp.ndarray
) -> Tuple[jraph.GraphsTuple, Dict[str, jnp.ndarray]]:
props = {}
n_nodes = features["vel_hist"].shape[0]
props["vel"] = jnp.reshape(features["vel_hist"], (n_nodes, self._n_vels, -1))
# most recent position
props["pos"] = features["abs_pos"][:, -1]
# relative distances between particles
props["edge_attr"] = features["rel_dist"]
# force magnitude as node attributes
props["node_attr"] = None
if "force" in features:
props["node_attr"] = jnp.sqrt(
jnp.sum(features["force"] ** 2, axis=-1, keepdims=True)
)
# velocity magnitudes as node features
node_features = jnp.concatenate(
[
jnp.sqrt(jnp.sum(props["vel"][:, i, :] ** 2, axis=-1, keepdims=True))
for i in range(self._n_vels)
],
axis=-1,
)
if not self._homogeneous_particles:
particles = jax.nn.one_hot(particle_type, NodeType.SIZE)
node_features = jnp.concatenate([node_features, particles], axis=-1)
graph = jraph.GraphsTuple(
nodes=node_features,
edges=None,
senders=features["senders"],
receivers=features["receivers"],
n_node=jnp.array([n_nodes]),
n_edge=jnp.array([len(features["senders"])]),
globals=None,
)
return graph, props
def _postprocess(
self, next_pos: jnp.ndarray, props: Dict[str, jnp.ndarray]
) -> Dict[str, jnp.ndarray]:
prev_vel = props["vel"][:, -1, :]
prev_pos = props["pos"]
# first order finite difference
next_vel = self._displacement_fn(next_pos, prev_pos)
acc = next_vel - prev_vel
return {"pos": next_pos, "vel": next_vel, "acc": acc}
[docs]
def __call__(
self, sample: Tuple[Dict[str, jnp.ndarray], jnp.ndarray]
) -> Dict[str, jnp.ndarray]:
graph, props = self._transform(*sample)
# input node embedding
h = LinearXav(self._hidden_size, name="scalar_emb")(graph.nodes)
graph = graph._replace(nodes=h)
prev_vel = props["vel"][:, -1, :]
# egnn works with unnormalized velocities
prev_vel = prev_vel * self._vel_stats["std"] + self._vel_stats["mean"]
# message passing
next_pos = props["pos"].copy()
for n in range(self._num_mp_steps):
graph, next_pos = EGNNLayer(
layer_num=n,
hidden_size=self._hidden_size,
output_size=self._hidden_size,
displacement_fn=self._displacement_fn,
shift_fn=self._shift_fn,
act_fn=self._act_fn,
residual=self._residual,
attention=self._attention,
normalize=self._normalize,
dt=self._dt,
tanh=self._tanh,
)(graph, next_pos, prev_vel, props["edge_attr"], props["node_attr"])
# position finite differencing to get acceleration
out = self._postprocess(next_pos, props)
return out