# Source code for lagrangebench.defaults

"""Default lagrangebench configs."""


from typing import Optional

from omegaconf import DictConfig, OmegaConf


def set_defaults(cfg: Optional[DictConfig] = None) -> DictConfig:
    """Set default lagrangebench configs.

    Every key listed below is (re)assigned on ``cfg``, so a value already
    stored under one of these keys is overwritten. The nested sections
    (``dataset``, ``model``, ``train``, ``eval``, ``logging``, ``neighbors``
    and their sub-sections) are replaced by fresh, empty ``DictConfig`` nodes
    before being filled in.

    Args:
        cfg: Config to populate in place. If ``None`` (the default), a new,
            empty ``DictConfig`` is created for this call.

    Returns:
        The populated config (the same object as ``cfg`` when one is given).
    """
    # Create the container per call instead of using ``OmegaConf.create({})``
    # as the default argument: a default expression is evaluated only once at
    # definition time, and since it is mutated below, every argument-less call
    # would return (and re-populate) the same shared object.
    if cfg is None:
        cfg = OmegaConf.create({})

    ### global and hardware-related configs

    # configuration file. Either "config" or "load_ckp" must be specified.
    # If "config" is specified, "load_ckp" is ignored.
    cfg.config = None
    # Load checkpointed model from this directory
    cfg.load_ckp = None
    # One of "train", "infer" or "all" (= both)
    cfg.mode = "all"
    # random seed
    cfg.seed = 0
    # data type for preprocessing. One of "float32" or "float64"
    cfg.dtype = "float64"
    # gpu device. -1 for CPU. Should be specified before importing the library.
    cfg.gpu = None
    # XLA memory fraction to be preallocated. The JAX default is 0.75.
    # Should be specified before importing the library.
    cfg.xla_mem_fraction = None

    ### dataset
    cfg.dataset = OmegaConf.create({})
    # path to data directory
    cfg.dataset.src = None
    # dataset name
    cfg.dataset.name = None

    ### model
    cfg.model = OmegaConf.create({})
    # model architecture name. gns, segnn, egnn
    cfg.model.name = None
    # Length of the position input sequence
    cfg.model.input_seq_length = 6
    # Number of message passing steps
    cfg.model.num_mp_steps = 10
    # Number of MLP layers
    cfg.model.num_mlp_layers = 2
    # Hidden dimension
    cfg.model.latent_dim = 128
    # whether to include velocity magnitude features
    cfg.model.magnitude_features = False
    # whether to normalize dimensions equally
    cfg.model.isotropic_norm = False

    # SEGNN only parameters
    # steerable attributes level
    cfg.model.lmax_attributes = 1
    # Level of the hidden layer
    cfg.model.lmax_hidden = 1
    # SEGNN normalization. instance, batch, none
    cfg.model.segnn_norm = "none"
    # SEGNN velocity aggregation. avg or last
    cfg.model.velocity_aggregate = "avg"

    ### training
    cfg.train = OmegaConf.create({})
    # batch size
    cfg.train.batch_size = 1
    # max number of training steps
    cfg.train.step_max = 500_000
    # number of workers for data loading
    cfg.train.num_workers = 4
    # standard deviation of the GNS-style noise
    cfg.train.noise_std = 3.0e-4

    # optimizer
    cfg.train.optimizer = OmegaConf.create({})
    # initial learning rate
    cfg.train.optimizer.lr_start = 1.0e-4
    # final learning rate (after exponential decay)
    cfg.train.optimizer.lr_final = 1.0e-6
    # learning rate decay rate
    cfg.train.optimizer.lr_decay_rate = 0.1
    # number of steps to decay learning rate
    cfg.train.optimizer.lr_decay_steps = 1.0e5

    # pushforward
    cfg.train.pushforward = OmegaConf.create({})
    # At which training step to introduce the next unroll stage
    cfg.train.pushforward.steps = [-1, 20000, 300000, 400000]
    # For how many steps to unroll
    cfg.train.pushforward.unrolls = [0, 1, 2, 3]
    # Which probability ratio to keep between the unrolls
    cfg.train.pushforward.probs = [18, 2, 1, 1]

    # loss weights
    cfg.train.loss_weight = OmegaConf.create({})
    # weight for acceleration error
    cfg.train.loss_weight.acc = 1.0
    # weight for velocity error
    cfg.train.loss_weight.vel = 0.0
    # weight for position error
    cfg.train.loss_weight.pos = 0.0

    ### evaluation
    cfg.eval = OmegaConf.create({})
    # number of eval rollout steps. -1 is full rollout
    cfg.eval.n_rollout_steps = 20
    # whether to use the test or valid split
    cfg.eval.test = False
    # rollouts directory
    cfg.eval.rollout_dir = None

    # configs for validation during training
    cfg.eval.train = OmegaConf.create({})
    # number of trajectories to evaluate
    cfg.eval.train.n_trajs = 50
    # stride for e_kin and sinkhorn
    cfg.eval.train.metrics_stride = 10
    # batch size
    cfg.eval.train.batch_size = 1
    # metrics to evaluate
    cfg.eval.train.metrics = ["mse"]
    # write validation rollouts. One of "none", "vtk", or "pkl"
    cfg.eval.train.out_type = "none"

    # configs for inference/testing
    cfg.eval.infer = OmegaConf.create({})
    # number of trajectories to evaluate during inference
    cfg.eval.infer.n_trajs = -1
    # stride for e_kin and sinkhorn
    cfg.eval.infer.metrics_stride = 1
    # batch size
    cfg.eval.infer.batch_size = 2
    # metrics for inference
    cfg.eval.infer.metrics = ["mse", "e_kin", "sinkhorn"]
    # write inference rollouts. One of "none", "vtk", or "pkl"
    cfg.eval.infer.out_type = "pkl"
    # number of extrapolation steps during inference
    cfg.eval.infer.n_extrap_steps = 0

    ### logging
    cfg.logging = OmegaConf.create({})
    # number of training steps between log outputs
    cfg.logging.log_steps = 1000
    # number of steps between evaluations and checkpoints
    cfg.logging.eval_steps = 10000
    # enable wandb logging
    cfg.logging.wandb = False
    # wandb project name
    cfg.logging.wandb_project = None
    # wandb entity name
    cfg.logging.wandb_entity = "lagrangebench"
    # checkpoint directory
    cfg.logging.ckp_dir = "ckp"
    # name of training run
    cfg.logging.run_name = None

    ### neighbor list
    cfg.neighbors = OmegaConf.create({})
    # backend for neighbor list computation
    cfg.neighbors.backend = "jaxmd_vmap"
    # multiplier for neighbor list capacity
    cfg.neighbors.multiplier = 1.25

    return cfg
# Default configuration, built once when this module is imported.
defaults = set_defaults()
def check_cfg(cfg: DictConfig):
    """Check if the configs are valid.

    Only the fields inspected below are validated; ``cfg`` is not modified.
    The field layout is the one produced by ``set_defaults``.

    Args:
        cfg: Full lagrangebench config.

    Raises:
        AssertionError: On the first check that fails. The checks raise
            explicitly instead of using ``assert`` statements, so validation
            still runs when Python is started with ``-O`` (which strips
            ``assert``), while callers keep seeing the same exception type.
    """

    def _require(condition, message: str) -> None:
        # Explicit raise: unlike ``assert``, this survives ``python -O``.
        if not condition:
            raise AssertionError(message)

    metric_choices = ("mse", "e_kin", "sinkhorn")
    out_type_choices = ("none", "vtk", "pkl")

    _require(
        cfg.mode in ("train", "infer", "all"),
        f"mode must be one of 'train', 'infer', 'all', got {cfg.mode!r}.",
    )
    _require(
        cfg.dtype in ("float32", "float64"),
        f"dtype must be one of 'float32', 'float64', got {cfg.dtype!r}.",
    )
    _require(cfg.dataset.src is not None, "dataset.src must be specified.")
    _require(
        cfg.model.input_seq_length >= 2,
        "At least two positions for one past vel.",
    )

    pf = cfg.train.pushforward
    _require(
        len(pf.steps) == len(pf.unrolls) == len(pf.probs),
        "train.pushforward.steps, .unrolls and .probs must have equal length.",
    )
    _require(all(u >= 0 for u in pf.unrolls), "All unrolls must be non-negative.")
    _require(
        all(p >= 0 for p in pf.probs), "All probabilities must be non-negative."
    )

    # Materialize once: the weights are traversed by both checks below.
    loss_weights = list(cfg.train.loss_weight.values())
    _require(
        all(w >= 0 for w in loss_weights), "All loss weights must be non-negative."
    )
    _require(sum(loss_weights) > 0, "At least one loss weight must be non-zero.")

    # Validation-during-training and inference settings share the same checks.
    for split in ("train", "infer"):
        eval_cfg = getattr(cfg.eval, split)
        _require(
            eval_cfg.n_trajs >= -1,
            f"eval.{split}.n_trajs must be >= -1, got {eval_cfg.n_trajs}.",
        )
        unknown = set(eval_cfg.metrics).difference(metric_choices)
        _require(
            not unknown,
            f"eval.{split}.metrics has unknown entries "
            f"{sorted(unknown, key=str)}; allowed: {list(metric_choices)}.",
        )
        _require(
            eval_cfg.out_type in out_type_choices,
            f"eval.{split}.out_type must be one of {list(out_type_choices)}, "
            f"got {eval_cfg.out_type!r}.",
        )