Source code for lagrangebench.models.utils

from typing import Callable, Dict, Iterable, NamedTuple, Optional

import e3nn_jax as e3nn
import haiku as hk
import jax
import jax.numpy as jnp
import jraph

from lagrangebench.utils import NodeType


class LinearXav(hk.Linear):
    """``hk.Linear`` whose weights default to Xavier (Glorot) uniform init.

    Exists so call sites do not have to repeat ``w_init=...`` everywhere.
    """

    def __init__(
        self,
        output_size: int,
        with_bias: bool = True,
        w_init: Optional[hk.initializers.Initializer] = None,
        b_init: Optional[hk.initializers.Initializer] = None,
        name: Optional[str] = None,
    ):
        """Build the layer, substituting the Xavier initializer if needed.

        Args:
            output_size: Output dimensionality, forwarded to ``hk.Linear``.
            with_bias: Whether to add a bias, forwarded to ``hk.Linear``.
            w_init: Weight initializer. ``None`` selects variance scaling with
                scale 1.0, ``"fan_avg"`` mode and a uniform distribution.
            b_init: Bias initializer, forwarded unchanged to ``hk.Linear``.
            name: Module name, forwarded unchanged to ``hk.Linear``.
        """
        # Only build the default initializer when the caller gave none.
        weight_initializer = (
            hk.initializers.VarianceScaling(
                scale=1.0, mode="fan_avg", distribution="uniform"
            )
            if w_init is None
            else w_init
        )
        super().__init__(
            output_size=output_size,
            with_bias=with_bias,
            w_init=weight_initializer,
            b_init=b_init,
            name=name,
        )
class MLPXav(hk.nets.MLP):
    """``hk.nets.MLP`` whose weights default to Xavier (Glorot) uniform init.

    Exists so call sites do not have to repeat ``w_init=...`` everywhere.
    """

    def __init__(
        self,
        output_sizes: Iterable[int],
        with_bias: bool = True,
        w_init: Optional[hk.initializers.Initializer] = None,
        b_init: Optional[hk.initializers.Initializer] = None,
        activation: Callable = jax.nn.silu,
        activate_final: bool = False,
        name: Optional[str] = None,
    ):
        """Build the MLP, substituting the Xavier initializer if needed.

        Note that the parameter order differs from ``hk.nets.MLP``:
        ``with_bias`` comes right after ``output_sizes`` here.

        Args:
            output_sizes: Layer output sizes, forwarded to ``hk.nets.MLP``.
            with_bias: Whether the layers use a bias.
            w_init: Weight initializer. ``None`` selects variance scaling with
                scale 1.0, ``"fan_avg"`` mode and a uniform distribution.
            b_init: Bias initializer. Discarded when ``with_bias`` is false.
            activation: Activation function; defaults to ``jax.nn.silu``.
            activate_final: Forwarded unchanged to ``hk.nets.MLP``.
            name: Module name, forwarded unchanged to ``hk.nets.MLP``.
        """
        weight_initializer = (
            hk.initializers.VarianceScaling(
                scale=1.0, mode="fan_avg", distribution="uniform"
            )
            if w_init is None
            else w_init
        )
        # A bias initializer is meaningless without a bias, so drop it rather
        # than hand the parent class a contradictory configuration.
        bias_initializer = b_init if with_bias else None
        super().__init__(
            output_sizes=output_sizes,
            w_init=weight_initializer,
            b_init=bias_initializer,
            with_bias=with_bias,
            activation=activation,
            activate_final=activate_final,
            name=name,
        )
class SteerableGraphsTuple(NamedTuple):
    r"""Bundle a ``jraph.GraphsTuple`` with steerable node/edge attributes.

    Attributes:
        graph: jraph.GraphsTuple carrying the graph structure.
        node_attributes: (N, irreps.dim) node attributes
            :math:`\mathbf{\hat{a}}_i`.
        edge_attributes: (E, irreps.dim) edge attributes
            :math:`\mathbf{\hat{a}}_{ij}`.
        additional_message_features: (E, edge_dim) optional message features.
    """

    # Underlying graph structure.
    graph: jraph.GraphsTuple
    # Steerable per-node attributes; optional.
    node_attributes: Optional[e3nn.IrrepsArray] = None
    # Steerable per-edge attributes; optional.
    edge_attributes: Optional[e3nn.IrrepsArray] = None
    # NOTE: additional_message_features lives in its own field because it would
    # otherwise get updated by jraph.GraphNetwork. The actual graph edges are
    # used only for the messages.
    additional_message_features: Optional[e3nn.IrrepsArray] = None
def node_irreps(
    metadata: Dict,
    input_seq_length: int,
    has_external_force: bool,
    has_magnitudes: bool,
    has_homogeneous_particles: bool,
) -> "e3nn.Irreps":
    """Compute input node irreps based on which features are available.

    The result is assembled term by term, in this order:

    * ``(input_seq_length - 1)x1o`` -- always present (presumably the
      velocity history; confirm against the feature builder).
    * ``2x1o`` -- only when no entry of
      ``metadata["periodic_boundary_conditions"]`` is truthy.
    * ``1x1o`` -- only when ``has_external_force`` is set.
    * ``(input_seq_length - 1)x0e`` -- only when ``has_magnitudes`` is set.
    * ``(NodeType.SIZE)x0e`` -- only when particles are not homogeneous.

    Args:
        metadata: Dataset metadata; must contain the key
            ``"periodic_boundary_conditions"`` holding an iterable of flags.
        input_seq_length: Input sequence length; the vector and magnitude
            multiplicities are ``input_seq_length - 1``.
        has_external_force: Whether an external force feature is present.
        has_magnitudes: Whether scalar magnitude features are present.
        has_homogeneous_particles: Whether all particles share one type; if
            not, ``NodeType.SIZE`` scalar channels are added.

    Returns:
        An ``e3nn.Irreps`` object (not a plain string) built from the
        ``"+"``-joined terms.
    """
    n_steps = input_seq_length - 1

    terms = [f"{n_steps}x1o"]
    if not any(metadata["periodic_boundary_conditions"]):
        terms.append("2x1o")
    if has_external_force:
        terms.append("1x1o")
    if has_magnitudes:
        terms.append(f"{n_steps}x0e")
    if not has_homogeneous_particles:
        terms.append(f"{NodeType.SIZE}x0e")

    return e3nn.Irreps("+".join(terms))
def build_mlp(
    latent_size, output_size, num_hidden_layers, is_layer_norm=True, **kwds: Dict
):
    """Assemble a Haiku MLP, optionally followed by layer normalization.

    The MLP has ``num_hidden_layers`` layers in total: ``num_hidden_layers - 1``
    layers of width ``latent_size`` followed by one of width ``output_size``.
    The final layer is never activated.

    Args:
        latent_size: Width of every layer except the last.
        output_size: Width of the last layer.
        num_hidden_layers: Total number of layers; must be at least 1.
        is_layer_norm: If truthy, append an ``hk.LayerNorm`` over the last
            axis (with learned scale and offset) after the MLP.
        **kwds: Extra keyword arguments forwarded to ``hk.nets.MLP``. They
            must not include ``activate_final`` or ``name``, which are set
            here.

    Returns:
        ``hk.Sequential([mlp, layer_norm])`` when ``is_layer_norm`` is truthy,
        otherwise the bare ``hk.nets.MLP``.
    """
    assert num_hidden_layers >= 1

    hidden_widths = [latent_size] * (num_hidden_layers - 1)
    layer_widths = hidden_widths + [output_size]
    mlp = hk.nets.MLP(layer_widths, **kwds, activate_final=False, name="MLP")

    if not is_layer_norm:
        return mlp

    layer_norm = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
    return hk.Sequential([mlp, layer_norm])
def _append_zero_component(array):
    """Return ``array`` with one extra zero entry appended on the last axis."""
    array = jnp.asarray(array)
    # Build the padding in the input's own dtype: a default-dtype zeros array
    # would silently promote e.g. float16/bfloat16 features on concatenation.
    padding = jnp.zeros((*array.shape[:-1], 1), dtype=array.dtype)
    return jnp.concatenate([array, padding], axis=-1)


def features_2d_to_3d(features):
    """Add zeros in the z component of 2D features.

    The entries ``"vel_hist"`` and ``"rel_disp"`` are required; ``"bound"``
    and ``"force"`` are padded only when present. Each of them gets one zero
    appended along its last axis, and the dtype of each entry is preserved.
    All other entries are left untouched.

    Args:
        features: Dict of feature arrays. It is modified in place.

    Returns:
        The same ``features`` dict object, with the padded arrays stored back
        under their original keys.

    Raises:
        KeyError: If ``"vel_hist"`` or ``"rel_disp"`` is missing.
    """
    for required_key in ("vel_hist", "rel_disp"):
        features[required_key] = _append_zero_component(features[required_key])

    for optional_key in ("bound", "force"):
        if optional_key in features:
            features[optional_key] = _append_zero_component(
                features[optional_key]
            )

    return features