Defaults

Listing 1: LagrangeBench default values
import ast
import io
import textwrap
import tokenize
from typing import List


def extract_function_body(source: str, func_name: str = "set_defaults") -> List[str]:
    """Return the body of a top-level function as dedented, non-blank lines.

    The ``def`` header (including a multi-line signature), a leading
    docstring, and a trailing ``return`` statement are stripped. Comments
    inside the body are kept, which is why the body is cut out of the raw
    source text rather than unparsed from the AST.

    Args:
        source: Python source code of the module to inspect.
        func_name: Name of the top-level function whose body is wanted.

    Returns:
        The body lines, with blank lines dropped and the common leading
        indentation removed.

    Raises:
        ValueError: If ``source`` defines no top-level function ``func_name``.
        SyntaxError: If ``source`` is not valid Python.
    """
    tree = ast.parse(source)
    func = next(
        (
            node
            for node in tree.body
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
            and node.name == func_name
        ),
        None,
    )
    if func is None:
        raise ValueError(f"no top-level function named {func_name!r}")

    lines = source.splitlines()
    body = list(func.body)

    # `start` is a 0-based index into `lines` of the first line to keep,
    # which equals the 1-based number of the last dropped line.
    if ast.get_docstring(func, clean=False) is not None:
        start = body[0].end_lineno
        body = body[1:]
    else:
        # The header ends at the first logical NEWLINE token at or after the
        # `def` line (parenthesised line breaks produce NL, not NEWLINE), so
        # this also works for signatures spanning several lines.
        start = func.lineno
        for tok in tokenize.generate_tokens(io.StringIO(source).readline):
            if tok.type == tokenize.NEWLINE and tok.start[0] >= func.lineno:
                start = tok.start[0]
                break

    # Drop a trailing `return ...` statement; keep everything before it.
    end = func.end_lineno
    if body and isinstance(body[-1], ast.Return):
        end = body[-1].lineno - 1

    kept = [line for line in lines[start:end] if line.strip()]
    return textwrap.dedent("\n".join(kept)).splitlines()


if __name__ == "__main__":
    # Python sources are UTF-8 by default (PEP 3120); don't depend on the locale.
    with open("lagrangebench/defaults.py", "r", encoding="utf-8") as file:
        defaults_full = file.read()

    # Keep only the body of set_defaults: no imports, other functions,
    # signature, docstring or final `return cfg`.
    defaults = extract_function_body(defaults_full, "set_defaults")

    print("\n".join(defaults))
### global and hardware-related configs
# configuration file. Either "config" or "load_ckp" must be specified.
# If "config" is specified, "load_ckp" is ignored.
cfg.config = None
# Load checkpointed model from this directory
cfg.load_ckp = None
# One of "train", "infer" or "all" (= both)
cfg.mode = "all"
# random seed
cfg.seed = 0
# data type for preprocessing. One of "float32" or "float64"
cfg.dtype = "float64"
# gpu device. -1 for CPU. Should be specified before importing the library.
cfg.gpu = None
# XLA memory fraction to be preallocated. The JAX default is 0.75.
# Should be specified before importing the library.
cfg.xla_mem_fraction = None
### dataset
cfg.dataset = OmegaConf.create({})
# path to data directory
cfg.dataset.src = None
# dataset name
cfg.dataset.name = None
### model
cfg.model = OmegaConf.create({})
# model architecture name. gns, segnn, egnn
cfg.model.name = None
# Length of the position input sequence
cfg.model.input_seq_length = 6
# Number of message passing steps
cfg.model.num_mp_steps = 10
# Number of MLP layers
cfg.model.num_mlp_layers = 2
# Hidden dimension
cfg.model.latent_dim = 128
# whether to include velocity magnitude features
cfg.model.magnitude_features = False
#  whether to normalize dimensions equally
cfg.model.isotropic_norm = False
# SEGNN only parameters
# steerable attributes level
cfg.model.lmax_attributes = 1
# Level of the hidden layer
cfg.model.lmax_hidden = 1
# SEGNN normalization. instance, batch, none
cfg.model.segnn_norm = "none"
# SEGNN velocity aggregation. avg or last
cfg.model.velocity_aggregate = "avg"
### training
cfg.train = OmegaConf.create({})
# batch size
cfg.train.batch_size = 1
# max number of training steps
cfg.train.step_max = 500_000
# number of workers for data loading
cfg.train.num_workers = 4
# standard deviation of the GNS-style noise
cfg.train.noise_std = 3.0e-4
# optimizer
cfg.train.optimizer = OmegaConf.create({})
# initial learning rate
cfg.train.optimizer.lr_start = 1.0e-4
# final learning rate (after exponential decay)
cfg.train.optimizer.lr_final = 1.0e-6
# learning rate decay rate
cfg.train.optimizer.lr_decay_rate = 0.1
# number of steps to decay learning rate
cfg.train.optimizer.lr_decay_steps = 1.0e5
# pushforward
cfg.train.pushforward = OmegaConf.create({})
# At which training step to introduce next unroll stage
cfg.train.pushforward.steps = [-1, 20000, 300000, 400000]
# For how many steps to unroll
cfg.train.pushforward.unrolls = [0, 1, 2, 3]
# Which probability ratio to keep between the unrolls
cfg.train.pushforward.probs = [18, 2, 1, 1]
# loss weights
cfg.train.loss_weight = OmegaConf.create({})
# weight for acceleration error
cfg.train.loss_weight.acc = 1.0
# weight for velocity error
cfg.train.loss_weight.vel = 0.0
# weight for position error
cfg.train.loss_weight.pos = 0.0
### evaluation
cfg.eval = OmegaConf.create({})
# number of eval rollout steps. -1 is full rollout
cfg.eval.n_rollout_steps = 20
# whether to use the test or valid split
cfg.eval.test = False
# rollouts directory
cfg.eval.rollout_dir = None
# configs for validation during training
cfg.eval.train = OmegaConf.create({})
# number of trajectories to evaluate
cfg.eval.train.n_trajs = 50
# stride for e_kin and sinkhorn
cfg.eval.train.metrics_stride = 10
# batch size
cfg.eval.train.batch_size = 1
# metrics to evaluate
cfg.eval.train.metrics = ["mse"]
# write validation rollouts. One of "none", "vtk", or "pkl"
cfg.eval.train.out_type = "none"
# configs for inference/testing
cfg.eval.infer = OmegaConf.create({})
# number of trajectories to evaluate during inference
cfg.eval.infer.n_trajs = -1
# stride for e_kin and sinkhorn
cfg.eval.infer.metrics_stride = 1
# batch size
cfg.eval.infer.batch_size = 2
# metrics for inference
cfg.eval.infer.metrics = ["mse", "e_kin", "sinkhorn"]
# write inference rollouts. One of "none", "vtk", or "pkl"
cfg.eval.infer.out_type = "pkl"
# number of extrapolation steps during inference
cfg.eval.infer.n_extrap_steps = 0
### logging
cfg.logging = OmegaConf.create({})
# number of steps between loggings
cfg.logging.log_steps = 1000
# number of steps between evaluations and checkpoints
cfg.logging.eval_steps = 10000
# wandb enable
cfg.logging.wandb = False
# wandb project name
cfg.logging.wandb_project = None
# wandb entity name
cfg.logging.wandb_entity = "lagrangebench"
# checkpoint directory
cfg.logging.ckp_dir = "ckp"
# name of training run
cfg.logging.run_name = None
### neighbor list
cfg.neighbors = OmegaConf.create({})
# backend for neighbor list computation
cfg.neighbors.backend = "jaxmd_vmap"
# multiplier for neighbor list capacity
cfg.neighbors.multiplier = 1.25