"""
Modified PaiNN implementation for general vectorial inputs and outputs, based on
`Schütt et al. <https://proceedings.mlr.press/v139/schutt21a.html>`_.
PaiNN model, layers and feature transform.
Original implementation: https://github.com/atomistic-machine-learning/schnetpack
Standalone implementation + validation: https://github.com/gerkone/painn-jax
"""
from typing import Callable, Dict, NamedTuple, Tuple
import haiku as hk
import jax
import jax.numpy as jnp
import jax.tree_util as tree
import jraph
from lagrangebench.utils import NodeType
from .utils import LinearXav
class NodeFeatures(NamedTuple):
    """Pair of per-node feature arrays used by the PaiNN modules.

    Attributes:
        s: Scalar node features, or ``None`` when absent.
        v: Vectorial node features, or ``None`` when absent.
    """

    # scalar channel
    s: jnp.ndarray = None
    # vector channel
    v: jnp.ndarray = None
# A readout maps a graph to a (scalar, vector) output pair.
ReadoutFn = Callable[
    [jraph.GraphsTuple],
    Tuple[jnp.ndarray, jnp.ndarray],
]
# Factory that takes arbitrary configuration arguments and returns a readout.
ReadoutBuilderFn = Callable[..., ReadoutFn]
class GatedEquivariantBlock(hk.Module):
    """Gated equivariant block (restricted to vectorial features).

    .. image:: https://i.imgur.com/EMlg2Qi.png
    """

    def __init__(
        self,
        hidden_size: int,
        scalar_out_channels: int,
        vector_out_channels: int,
        activation: Callable = jax.nn.silu,
        scalar_activation: Callable = None,
        eps: float = 1e-8,
        name: str = "gated_equivariant_block",
    ):
        """Initialize the layer.

        Args:
            hidden_size: Number of hidden channels.
            scalar_out_channels: Number of scalar output channels.
            vector_out_channels: Number of vector output channels.
            activation: Gate activation function.
            scalar_activation: Activation function for the scalar output.
            eps: Constant added in norm to prevent derivation instabilities.
            name: Name of the module.
        """
        super().__init__(name)

        assert scalar_out_channels > 0 and vector_out_channels > 0
        self._scalar_out_channels = scalar_out_channels
        self._vector_out_channels = vector_out_channels
        self._eps = eps

        # Bias-free linear mix of the vector channels; its output is split into
        # two halves (v_l, v_r) in __call__.
        self.vector_mix_net = LinearXav(
            2 * vector_out_channels,
            with_bias=False,
            name="vector_mix_net",
        )
        # Produces the scalar outputs and the per-channel vector gates at once.
        self.gate_block = hk.Sequential(
            [
                LinearXav(hidden_size),
                activation,
                LinearXav(scalar_out_channels + vector_out_channels),
            ],
            name="scalar_gate_net",
        )
        self.scalar_activation = scalar_activation

    def __call__(
        self, s: jnp.ndarray, v: jnp.ndarray
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Apply the gated block.

        Args:
            s: Scalar features, channels on the last axis.
            v: Vector features, spatial components on axis -2 and channels on
                the last axis.

        Returns:
            ``(s, v)`` with ``scalar_out_channels`` scalar channels and
            ``vector_out_channels`` vector channels.
        """
        v_l, v_r = jnp.split(self.vector_mix_net(v), 2, axis=-1)
        # norm over the spatial axis (eps keeps the sqrt differentiable at 0)
        v_r_norm = jnp.sqrt(jnp.sum(v_r**2, axis=-2) + self._eps)
        gating_scalars = jnp.concatenate([s, v_r_norm], axis=-1)
        # The gate output holds scalar_out + vector_out channels. jnp.split takes
        # split *indices*, so a single cut at scalar_out separates the two parts
        # for any combination of channel counts.
        s, v_gate = jnp.split(
            self.gate_block(gating_scalars),
            [self._scalar_out_channels],
            axis=-1,
        )
        # scale the vectors by the gating scalars (broadcast over spatial axis)
        v = v_l * v_gate[:, jnp.newaxis]
        if self.scalar_activation:
            s = self.scalar_activation(s)
        return s, v
def gaussian_rbf(
    n_rbf: int,
    cutoff: float,
    start: float = 0.0,
    centered: bool = False,
    trainable: bool = False,
) -> Callable[[jnp.ndarray], jnp.ndarray]:
    r"""Gaussian radial basis functions.

    .. math::
        \phi_k(x) = \exp\left(-\frac{(x - \mu_k)^2}{2 \sigma_k^2}\right)

    Args:
        n_rbf: total number of Gaussian functions, :math:`N_g`.
        cutoff: center of last Gaussian function, :math:`\mu_{N_g}`
            (width of the last one when ``centered``).
        start: center of first Gaussian function, :math:`\mu_0`
            (width of the first one when ``centered``).
        centered: If True, every Gaussian is centered at 0 and the widths are
            spaced linearly from ``start`` to ``cutoff``. ``start`` should then
            be > 0, since a zero width makes the first function degenerate.
            If False, the centers are spaced linearly from ``start`` to
            ``cutoff`` and share the width ``|cutoff - start| / n_rbf``.
        trainable: If True, widths and offset of Gaussian functions learnable.

    Returns:
        Function evaluating all ``n_rbf`` Gaussians on its input, stacked
        along a new trailing axis.
    """
    if centered:
        # all centers at 0, widths spread over [start, cutoff]
        width = jnp.linspace(start, cutoff, n_rbf)
        offset = jnp.zeros_like(width)
    else:
        offset = jnp.linspace(start, cutoff, n_rbf)
        width = jnp.abs(cutoff - start) / n_rbf * jnp.ones_like(offset)

    if trainable:
        widths = hk.get_parameter(
            "widths", width.shape, width.dtype, init=lambda *_: width
        )
        offsets = hk.get_parameter(
            "offset", offset.shape, offset.dtype, init=lambda *_: offset
        )
    else:
        # Fixed values live in Haiku state, with an extra leading axis of size 1.
        hk.set_state("widths", jnp.array([width]))
        hk.set_state("offsets", jnp.array([offset]))
        widths = hk.get_state("widths")
        offsets = hk.get_state("offsets")

    def _rbf(x: jnp.ndarray) -> jnp.ndarray:
        coeff = -0.5 / jnp.power(widths, 2)
        diff = x[..., jnp.newaxis] - offsets
        return jnp.exp(coeff * jnp.power(diff, 2))

    return _rbf
def cosine_cutoff(cutoff: float) -> Callable[[jnp.ndarray], jnp.ndarray]:
    r"""Behler-style cosine cutoff.

    .. math::
        f(r) = \begin{cases}
        0.5 \times \left[1 + \cos\left(\frac{\pi r}{r_\text{cutoff}}\right)\right]
        & r < r_\text{cutoff} \\
        0 & r \geqslant r_\text{cutoff} \\
        \end{cases}

    Args:
        cutoff (float): cutoff radius.

    Returns:
        Function applying :math:`f` elementwise, keeping the floating dtype
        of the cosine term computed from its input.
    """
    # the radius is stored as Haiku state and read back
    hk.set_state("cutoff", cutoff)
    cutoff = hk.get_state("cutoff")

    def _cutoff(x: jnp.ndarray) -> jnp.ndarray:
        # Compute values of cutoff function
        cuts = 0.5 * (jnp.cos(x * jnp.pi / cutoff) + 1.0)
        # Remove contributions at/beyond the cutoff radius. `where` keeps the
        # dtype of `cuts` (a float32 mask would promote/demote it) and returns
        # exactly 0 even where the cosine term is not finite (e.g. x = inf).
        return jnp.where(x < cutoff, cuts, 0.0)

    return _cutoff
def PaiNNReadout(
    hidden_size: int,
    out_channels: int = 1,
    activation: Callable = jax.nn.silu,
    blocks: int = 2,
    eps: float = 1e-8,
) -> ReadoutFn:
    """
    PaiNN readout block.

    Stacks ``blocks`` gated equivariant blocks: each of the first
    ``blocks - 1`` halves the number of channels, and the last one maps to
    ``out_channels`` scalar and vector channels.

    Args:
        hidden_size: Number of hidden channels.
        out_channels: Number of scalar/vector output channels.
        activation: Activation function.
        blocks: Number of readout blocks; must be at least 1.
        eps: Constant added in norm to prevent derivation instabilities.

    Returns:
        Configured readout function.

    Raises:
        ValueError: If ``blocks < 1``.
    """
    if blocks < 1:
        raise ValueError(f"blocks must be >= 1, got {blocks}")

    def _readout(graph: jraph.GraphsTuple) -> Tuple[jnp.ndarray, jnp.ndarray]:
        s, v = graph.nodes
        # drop all singleton axes of the scalar features
        s = jnp.squeeze(s)
        for i in range(blocks - 1):
            ith_hidden_size = hidden_size // 2 ** (i + 1)
            s, v = GatedEquivariantBlock(
                hidden_size=ith_hidden_size * 2,
                scalar_out_channels=ith_hidden_size,
                vector_out_channels=ith_hidden_size,
                activation=activation,
                eps=eps,
                name=f"readout_block_{i}",
            )(s, v)
        # Hidden width of the output block equals the size of the last
        # intermediate block. Computing it directly (rather than reusing the
        # loop variable) also works when blocks == 1 and the loop never runs.
        out_hidden_size = hidden_size // 2 ** (blocks - 1)
        s, v = GatedEquivariantBlock(
            hidden_size=out_hidden_size,
            scalar_out_channels=out_channels,
            vector_out_channels=out_channels,
            activation=activation,
            eps=eps,
            name="readout_block_out",
        )(s, v)
        return jnp.squeeze(s), jnp.squeeze(v)

    return _readout
class PaiNNLayer(hk.Module):
    """PaiNN interaction block."""

    def __init__(
        self,
        hidden_size: int,
        layer_num: int,
        activation: Callable = jax.nn.silu,
        blocks: int = 2,
        aggregate_fn: Callable = jraph.segment_sum,
        eps: float = 1e-8,
    ):
        """
        Initialize the PaiNN layer, made up of an interaction block and a mixing block.

        Args:
            hidden_size: Number of node features.
            layer_num: Numbering of the layer (used in the module name).
            activation: Activation function.
            blocks: Number of layers in the context networks.
            aggregate_fn: Function to aggregate the neighbors.
            eps: Constant added in norm to prevent derivation instabilities.
        """
        super().__init__(f"layer_{layer_num}")
        self._hidden_size = hidden_size
        self._eps = eps
        self._aggregate_fn = aggregate_fn

        # Each context net is (blocks - 1) x [LinearXav(hidden), activation]
        # followed by LinearXav(3 * hidden). The hidden layers are created in a
        # comprehension so each repetition is a distinct module. Repeating a
        # list with `*` would reuse one LinearXav instance, which shares its
        # weights and, in mixing_block (first input has 2 * hidden features),
        # reuses them on inputs of a different width.

        # inter-particle context net
        self.interaction_block = hk.Sequential(
            [
                layer
                for _ in range(blocks - 1)
                for layer in (LinearXav(hidden_size), activation)
            ]
            + [LinearXav(3 * hidden_size)],
            name="interaction_block",
        )
        # intra-particle context net
        self.mixing_block = hk.Sequential(
            [
                layer
                for _ in range(blocks - 1)
                for layer in (LinearXav(hidden_size), activation)
            ]
            + [LinearXav(3 * hidden_size)],
            name="mixing_block",
        )
        # vector channel mix
        self.vector_mixing_block = LinearXav(
            2 * hidden_size,
            with_bias=False,
            name="vector_mixing_block",
        )

    def _message(
        self,
        s: jnp.ndarray,
        v: jnp.ndarray,
        dir_ij: jnp.ndarray,
        Wij: jnp.ndarray,
        senders: jnp.ndarray,
        receivers: jnp.ndarray,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Message/interaction. Inter-particle.

        Args:
            s (jnp.ndarray): Input scalar features.
            v (jnp.ndarray): Input vector features.
            dir_ij (jnp.ndarray): Direction of the edge.
            Wij (jnp.ndarray): Filter.
            senders (jnp.ndarray): Index of the sender node.
            receivers (jnp.ndarray): Index of the receiver node.

        Returns:
            Node features updated with the aggregated messages.
        """
        x = self.interaction_block(s)
        # gather per-edge context from the receiver end of each edge
        xj = x[receivers]
        vj = v[receivers]
        # filter the context and split it into one scalar and two vector parts
        ds, dv1, dv2 = jnp.split(Wij * xj, 3, axis=-1)
        n_nodes = tree.tree_leaves(s)[0].shape[0]
        dv = dv1 * dir_ij[..., jnp.newaxis] + dv2 * vj
        # aggregate scalars and vectors onto the sender nodes
        ds = self._aggregate_fn(ds, senders, n_nodes)
        dv = self._aggregate_fn(dv, senders, n_nodes)
        # residual update with clipped increments
        s = s + jnp.clip(ds, -1e2, 1e2)
        v = v + jnp.clip(dv, -1e2, 1e2)
        return s, v

    def _update(
        self, s: jnp.ndarray, v: jnp.ndarray
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Update/mixing. Intra-particle.

        Args:
            s (jnp.ndarray): Input scalar features.
            v (jnp.ndarray): Input vector features.

        Returns:
            Node features after update.
        """
        v_l, v_r = jnp.split(self.vector_mixing_block(v), 2, axis=-1)
        v_norm = jnp.sqrt(jnp.sum(v_r**2, axis=-2, keepdims=True) + self._eps)
        ts = jnp.concatenate([s, v_norm], axis=-1)
        ds, dv, dsv = jnp.split(self.mixing_block(ts), 3, axis=-1)
        dv = v_l * dv
        # scalar update also uses the per-channel dot product <v_r, v_l>
        dsv = dsv * jnp.sum(v_r * v_l, axis=1, keepdims=True)
        s = s + jnp.clip(ds + dsv, -1e2, 1e2)
        v = v + jnp.clip(dv, -1e2, 1e2)
        return s, v

    def __call__(
        self,
        graph: jraph.GraphsTuple,
        Wij: jnp.ndarray,
    ):
        """Compute interaction output.

        Args:
            graph (jraph.GraphsTuple): Input graph; ``edges`` hold the edge
                directions and ``nodes`` a ``NodeFeatures`` pair.
            Wij (jnp.ndarray): Filter.

        Returns:
            The graph with its node features replaced by the updated ones.
        """
        s, v = graph.nodes
        s, v = self._message(s, v, graph.edges, Wij, graph.senders, graph.receivers)
        s, v = self._update(s, v)
        return graph._replace(nodes=NodeFeatures(s=s, v=v))
class PaiNN(hk.Module):
    r"""Polarizable interaction Neural Network by
    `Schütt et al. <https://proceedings.mlr.press/v139/schutt21a.html>`_.

    In order to accommodate general inputs/outputs, this PaiNN differs from the
    original in a few ways. The main change is that input vectors are no longer
    initialized to 0. Instead they are built from the velocity history (plus
    external forces and boundary features, when present) and passed through a
    learned, bias-free linear embedding.

    .. image:: https://i.imgur.com/NxZ2rPi.png
    """

    def __init__(
        self,
        hidden_size: int,
        output_size: int,
        num_mp_steps: int,
        radial_basis_fn: Callable,
        cutoff_fn: Callable,
        n_vels: int,
        homogeneous_particles: bool = True,
        activation: Callable = jax.nn.silu,
        shared_interactions: bool = False,
        shared_filters: bool = False,
        eps: float = 1e-8,
    ):
        """Initialize the model.

        Args:
            hidden_size: Determines the size of each embedding vector.
            output_size: Number of output features.
            num_mp_steps: Number of interaction blocks.
            radial_basis_fn: Expands inter-particle distances in a basis set.
            cutoff_fn: Cutoff function scaling the filters; ``None`` disables it.
            n_vels: Number of historical velocities.
            homogeneous_particles: If all particles are of homogeneous type.
            activation: Activation function.
            shared_interactions: If True, share the weights across interaction blocks.
            shared_filters: If True, share the weights across filter networks.
            eps: Constant added in norm to prevent derivation instabilities.
        """
        super().__init__("painn")
        assert radial_basis_fn is not None, "A radial_basis_fn must be provided"

        self._n_vels = n_vels
        self._homogeneous_particles = homogeneous_particles
        self._hidden_size = hidden_size
        self._num_mp_steps = num_mp_steps
        self._eps = eps
        self._shared_filters = shared_filters
        self._shared_interactions = shared_interactions

        self.radial_basis_fn = radial_basis_fn
        self.cutoff_fn = cutoff_fn

        self.scalar_emb = LinearXav(self._hidden_size, name="scalar_embedding")
        # bias-free mix of the input vector channels into hidden_size channels
        self.vector_emb = LinearXav(
            self._hidden_size, with_bias=False, name="vector_embedding"
        )

        if shared_filters:
            self.filter_net = LinearXav(3 * self._hidden_size, name="filter_net")
        else:
            # one 3 * hidden_size chunk per message-passing step
            self.filter_net = LinearXav(
                self._num_mp_steps * 3 * self._hidden_size, name="filter_net"
            )

        if self._shared_interactions:
            # the same layer instance is repeated, so all steps share its weights
            self.layers = [
                PaiNNLayer(self._hidden_size, 0, activation, eps=eps)
            ] * self._num_mp_steps
        else:
            self.layers = [
                PaiNNLayer(self._hidden_size, i, activation, eps=eps)
                for i in range(self._num_mp_steps)
            ]

        self._readout = PaiNNReadout(self._hidden_size, out_channels=output_size)

    def _embed(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Embed the input nodes.

        Scalars are cast to float32, embedded, and given a singleton axis at
        position 1. Vectors go through the bias-free vector embedding.
        """
        # embeds scalar features
        s = jnp.asarray(graph.nodes.s, dtype=jnp.float32)
        if len(s.shape) == 1:
            s = s[:, jnp.newaxis]
        s = self.scalar_emb(s)[:, jnp.newaxis]
        # embeds vector features
        v = self.vector_emb(graph.nodes.v)
        return graph._replace(nodes=NodeFeatures(s=s, v=v))

    def _get_filters(self, norm_ij: jnp.ndarray) -> Tuple[jnp.ndarray, ...]:
        r"""Compute the rotationally invariant filters :math:`W_s`.

        .. math::
            W_s = MLP(RBF(\|\vector{r}_{ij}\|)) * f_{cut}(\|\vector{r}_{ij}\|)

        When no cutoff function is configured, :math:`f_{cut} \equiv 1`.

        Returns:
            One filter array per message-passing step.
        """
        phi_ij = self.radial_basis_fn(norm_ij)
        # compute filters
        filters = self.filter_net(phi_ij)
        if self.cutoff_fn is not None:
            # scale by the cutoff envelope; without a cutoff the filters stay
            # unscaled (they must not be multiplied by the raw distance)
            filters = filters * self.cutoff_fn(norm_ij)[:, jnp.newaxis]
        # split into layer-wise filters
        if self._shared_filters:
            return (filters,) * self._num_mp_steps
        return tuple(jnp.split(filters, self._num_mp_steps, axis=-1))

    def _transform(
        self, features: Dict[str, jnp.ndarray], particle_type: jnp.ndarray
    ) -> jraph.GraphsTuple:
        """Build the input graph from the feature dict.

        Args:
            features: Reads ``"vel_hist"``, ``"vel_mag"``, ``"rel_disp"``,
                ``"senders"``, ``"receivers"`` and, if present, ``"force"`` and
                ``"bound"``.
            particle_type: Per-particle type ids; its length is the node count.

        Returns:
            Graph whose nodes are ``NodeFeatures`` (vector features concatenated
            along the last axis) and whose edges are the relative displacements.
        """
        n_nodes = particle_type.shape[0]

        # node features
        node_scalars = []
        node_vectors = []

        # velocity history -> (n_nodes, -1, n_vels), one vector channel per step
        traj = jnp.reshape(features["vel_hist"], (n_nodes, self._n_vels, -1))
        node_vectors.append(traj.transpose(0, 2, 1))
        if "force" in features:
            node_vectors.append(features["force"][..., jnp.newaxis])
        if "bound" in features:
            bounds = jnp.reshape(features["bound"], (n_nodes, 2, -1))
            node_vectors.append(bounds.transpose(0, 2, 1))

        # velocity magnitudes as node feature
        node_scalars.append(features["vel_mag"])
        if not self._homogeneous_particles:
            particles = jax.nn.one_hot(particle_type, NodeType.SIZE)
            node_scalars.append(particles)

        node_scalars = jnp.concatenate(node_scalars, axis=-1)
        node_vectors = jnp.concatenate(node_vectors, axis=-1)

        return jraph.GraphsTuple(
            nodes=NodeFeatures(s=node_scalars, v=node_vectors),
            edges=features["rel_disp"],
            senders=features["senders"],
            receivers=features["receivers"],
            n_node=jnp.array([n_nodes]),
            n_edge=jnp.array([len(features["senders"])]),
            globals=None,
        )

    def __call__(
        self, sample: Tuple[Dict[str, jnp.ndarray], jnp.ndarray]
    ) -> Dict[str, jnp.ndarray]:
        """Run the model on one sample.

        Args:
            sample: ``(features, particle_type)``, unpacked into ``_transform``.

        Returns:
            Dict holding the vector readout under the key ``"acc"``.
        """
        graph = self._transform(*sample)
        # compute atom and pair features
        norm_ij = jnp.sqrt(jnp.sum(graph.edges**2, axis=1, keepdims=True) + self._eps)
        # edge directions
        dir_ij = graph.edges / (norm_ij + self._eps)
        graph = graph._replace(edges=dir_ij)
        # compute filters (r_ij track in message block from the paper)
        filter_list = self._get_filters(norm_ij)
        # embeds node scalar and vector features
        graph = self._embed(graph)
        # message passing
        for layer, filters in zip(self.layers, filter_list):
            graph = layer(graph, filters)
        _, v = self._readout(graph)
        return {"acc": v}