# Source code for lagrangebench.data.data

"""Dataset modules for loading HDF5 simulation trajectories."""

import bisect
import importlib
import json
import os
import os.path as osp
import re
import warnings
import zipfile
from typing import Optional

import h5py
import jax.numpy as jnp
import numpy as np
import wget
from torch.utils.data import Dataset

from lagrangebench.utils import NodeType

# Base URL of the Zenodo record that hosts the dataset archives.
ZENODO_PREFIX = "https://zenodo.org/records/10491868/files/"

# Maps each short dataset name to the download URL of its zip archive.
URLS = {
    short_name: f"{ZENODO_PREFIX}{archive}"
    for short_name, archive in (
        ("tgv2d", "2D_TGV_2500_10kevery100.zip"),
        ("rpf2d", "2D_RPF_3200_20kevery100.zip"),
        ("ldc2d", "2D_LDC_2708_10kevery100.zip"),
        ("dam2d", "2D_DAM_5740_20kevery100.zip"),
        ("tgv3d", "3D_TGV_8000_10kevery100.zip"),
        ("rpf3d", "3D_RPF_8000_10kevery100.zip"),
        ("ldc3d", "3D_LDC_8160_10kevery100.zip"),
    )
}


class H5Dataset(Dataset):
    """Dataset for loading HDF5 simulation trajectories.

    Reference on parallel loading of h5 samples see:
    https://github.com/pytorch/pytorch/issues/11929

    Implementation inspired by:
    https://github.com/Open-Catalyst-Project/ocp/blob/main/ocpmodels/datasets/lmdb_dataset.py
    """

    def __init__(
        self,
        split: str,
        dataset_path: str,
        name: Optional[str] = None,
        input_seq_length: int = 6,
        extra_seq_length: int = 0,
        nl_backend: str = "jaxmd_vmap",
    ):
        """Initialize the dataset. If the dataset is not present, it is downloaded.

        Args:
            split: "train", "valid", or "test"
            dataset_path: Path to the dataset. Download will start automatically if
                dataset_path does not exist.
            name: Name of the dataset. If None, it is inferred from the path.
            input_seq_length: Length of the input sequence. The number of historic
                velocities is input_seq_length - 1. And during training, the returned
                number of past positions is input_seq_length + 1, to compute target
                acceleration.
            extra_seq_length: During training, this is the maximum number of
                pushforward unroll steps. During validation/testing, this specifies
                the largest N-step MSE loss we are interested in, e.g. for best model
                checkpointing.
            nl_backend: Which backend to use for the neighbor list

        Raises:
            AssertionError: If `split` is unknown, `input_seq_length < 2`,
                `extra_seq_length <= 0` for a non-training split, or the stored
                trajectories are shorter than the required subsequence length.
            FileNotFoundError: If the dataset name is one of "dam2d", "rpf2d",
                "rpf3d" and `force.py` is missing from the dataset directory.
        """
        # A bare `import importlib` does not guarantee that the `util` submodule
        # is loaded, so import it explicitly where it is used.
        import importlib.util

        dataset_path = osp.normpath(dataset_path)  # remove potential trailing slash

        if name is None:
            self.name = get_dataset_name_from_path(dataset_path)
        else:
            self.name = name

        if not osp.exists(dataset_path):
            dataset_path = self.download(self.name, dataset_path)

        assert split in ["train", "valid", "test"]
        assert (
            input_seq_length > 1
        ), "To compute at least one past velocity, input_seq_length must be >= 2."

        self.dataset_path = dataset_path
        self.file_path = osp.join(dataset_path, split + ".h5")
        self.input_seq_length = input_seq_length
        self.nl_backend = nl_backend

        force_fn_path = osp.join(dataset_path, "force.py")
        if osp.exists(force_fn_path):
            # load force_fn if `force.py` exists in dataset_path
            # NOTE(review): this executes Python code shipped with the dataset;
            # only point dataset_path at directories you trust.
            spec = importlib.util.spec_from_file_location("force_module", force_fn_path)
            force_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(force_module)
            self.external_force_fn = force_module.force_fn
        else:
            if self.name in ["dam2d", "rpf2d", "rpf3d"]:
                raise FileNotFoundError(
                    f"External force function not found in {dataset_path}. "
                    "Download the latest LagrangeBench dataset from Zenodo."
                )
            self.external_force_fn = None

        # load dataset metadata
        with open(osp.join(dataset_path, "metadata.json"), "r") as f:
            self.metadata = json.load(f)

        # The HDF5 handle is opened lazily on first item access (see _open_hdf5).
        self.db_hdf5 = None

        with h5py.File(self.file_path, "r") as f:
            self.traj_keys = list(f.keys())
            # (num_steps, num_particles, dim) = f["00000/position"].shape
            self.sequence_length = f["00000/position"].shape[0]

        num_trajs = len(self.traj_keys)

        if split == "train":
            # During training, the first input_seq_length steps can only be used as
            # input, and the last one to compute the target acceleration. If we use
            # pushforward, then we need to provide extra_seq_length more steps
            # from the end of a trajectory. Thus, the number of training samples per
            # trajectory becomes:
            self.subseq_length = input_seq_length + 1 + extra_seq_length
            samples_per_traj = self.sequence_length - self.subseq_length + 1

            # Every trajectory contributes the same number of windows, so the
            # cumulative counts are plain integer multiples.
            self._keylen_cumulative = [
                samples_per_traj * (i + 1) for i in range(num_trajs)
            ]
            self.num_samples = samples_per_traj * num_trajs

            self.getter = self.get_window
        else:
            assert (
                extra_seq_length > 0
            ), "extra_seq_length must be > 0 for validation and testing."
            # Compute the number of splits per validation trajectory. If the length
            # of each trajectory is 1000, we want to compute a 20-step MSE, and
            # input_seq_length=6, then we should split the trajectory into
            # _split_valid_traj_into_n = 1000 // (20 + 6) chunks.
            self.subseq_length = input_seq_length + extra_seq_length
            self._split_valid_traj_into_n = self.sequence_length // self.subseq_length
            self.num_samples = self._split_valid_traj_into_n * num_trajs

            self.getter = self.get_trajectory

        assert self.sequence_length >= self.subseq_length, (
            f"# steps in dataset trajectory ({self.sequence_length}) must be >= "
            f"subsequence length ({self.subseq_length}). Reduce either "
            f"input_seq_length or extra_seq_length/max pushforward steps."
        )

    def download(self, name: str, path: str) -> str:
        """Download the dataset.

        Args:
            name: Name of the dataset
            path: Destination path to the downloaded dataset

        Returns:
            `path` with a single trailing slash removed, if one was present.

        Raises:
            AssertionError: If `name` is not a key of `URLS`.
        """
        assert name in URLS, f"Dataset {name} not available."
        url = URLS[name]

        # path could be e.g. "./data/2D_TGV_2500_10kevery100/"
        # remove trailing slash if present and get the root of the datasets
        path = path[:-1] if path.endswith("/") else path
        # e.g. "./data"; fall back to the current directory when `path` has no
        # parent component, because os.makedirs("") raises.
        path_root = osp.split(path)[0] or "."

        # download the dataset as a zip file, e.g. "./data/2D_TGV_2500_10kevery100.zip"
        os.makedirs(path_root, exist_ok=True)
        filename = wget.download(url, out=path_root)
        print(f"\nDataset {name} downloaded to {filename}")

        # unzip the dataset and then remove the zip file; the context manager
        # closes the archive before it is deleted.
        with zipfile.ZipFile(filename, "r") as archive:
            archive.extractall(path_root)
        os.remove(filename)

        return path

    def _open_hdf5(self) -> "h5py.File":
        """Return the cached HDF5 handle, opening the file read-only if needed."""
        if self.db_hdf5 is None:
            return h5py.File(self.file_path, "r")
        return self.db_hdf5

    def _matscipy_pad(self, pos_input, particle_type):
        """Pad positions and particle types along the particle axis.

        Both arrays are extended to `metadata["num_particles_max"]` entries along
        axis 0. Positions are padded with 0.0 and particle types with
        `NodeType.PAD_VALUE`.
        """
        padding_size = self.metadata["num_particles_max"] - pos_input.shape[0]
        pos_input = np.pad(
            pos_input,
            ((0, padding_size), (0, 0), (0, 0)),
            mode="constant",
            constant_values=0.0,
        )
        particle_type = np.pad(
            particle_type,
            (0, padding_size),
            mode="constant",
            constant_values=NodeType.PAD_VALUE,
        )
        return pos_input, particle_type

    def get_trajectory(self, idx: int):
        """Get a (full) trajectory, or one chunk of it, at index idx.

        If trajectories are split into more than one chunk, `idx` selects both
        the trajectory and the chunk of length `subseq_length` inside it.
        Otherwise the whole trajectory `idx` is returned.

        Returns:
            Tuple `(positions, particle_type)`; positions are transposed with
            axes `(1, 0, 2)` relative to the stored array.
        """
        # open the database file
        self.db_hdf5 = self._open_hdf5()

        if self._split_valid_traj_into_n > 1:
            traj_idx = idx // self._split_valid_traj_into_n
            slice_from = (idx % self._split_valid_traj_into_n) * self.subseq_length
            slice_to = slice_from + self.subseq_length
        else:
            traj_idx = idx
            slice_from = 0
            slice_to = self.sequence_length

        # get a pointer to the trajectory. That is not yet the real trajectory.
        traj = self.db_hdf5[self.traj_keys[traj_idx]]
        # get a pointer to the positions of the traj. Still nothing in memory.
        traj_pos = traj["position"]
        # load and transpose the trajectory
        pos_input = traj_pos[slice_from:slice_to].transpose((1, 0, 2))

        particle_type = traj["particle_type"][:]

        if self.nl_backend == "matscipy":
            pos_input, particle_type = self._matscipy_pad(pos_input, particle_type)

        return pos_input, particle_type

    def get_window(self, idx: int):
        """Get a window of `subseq_length` steps of a trajectory at index idx.

        Returns:
            Tuple `(positions, particle_type)`; positions are transposed with
            axes `(1, 0, 2)` relative to the stored array.
        """
        # figure out which trajectory this should be indexed from.
        traj_idx = bisect.bisect(self._keylen_cumulative, idx)
        # extract index of element within that trajectory.
        el_idx = idx
        if traj_idx != 0:
            el_idx = idx - self._keylen_cumulative[traj_idx - 1]
        assert el_idx >= 0

        # open the database file
        self.db_hdf5 = self._open_hdf5()

        # get a pointer to the trajectory. That is not yet the real trajectory.
        traj = self.db_hdf5[self.traj_keys[traj_idx]]
        # get a pointer to the positions of the traj. Still nothing in memory.
        traj_pos = traj["position"]
        # load only a slice of the positions. Now, this is an array in memory.
        pos_input_and_target = traj_pos[el_idx : el_idx + self.subseq_length]
        pos_input_and_target = pos_input_and_target.transpose((1, 0, 2))

        particle_type = traj["particle_type"][:]

        if self.nl_backend == "matscipy":
            pos_input_and_target, particle_type = self._matscipy_pad(
                pos_input_and_target, particle_type
            )

        return pos_input_and_target, particle_type

    def __getitem__(self, idx: int):
        """Get a sequence of positions and the particle types at index idx.

        Dispatches to `get_window` for the training split and to
        `get_trajectory` otherwise.

        Returns:
            Tuple `(positions, particle_type)`. For training, positions hold
            `subseq_length` steps along axis=1: the input sequence followed by
            the positions needed to compute the target acceleration.
        """
        return self.getter(idx)

    def __len__(self):
        """Return the number of samples in this split."""
        return self.num_samples
def get_dataset_name_from_path(path: str) -> str:
    """Infer the dataset name from the provided path.

    Variant 1:
        If the dataset directory contains {2|3}D_{ABC}, then the name is inferred as
        {abc2d|abc3d}. These names are based on the lagrangebench dataset directories:
        {2D|3D}_{TGV|RPF|LDC|DAM}_{num_particles_max}_{num_steps}every{sampling_rate}
        The shorter dataset names then become one of the following:
        {tgv2d|tgv3d|rpf2d|rpf3d|ldc2d|ldc3d|dam2d}
    Variant 2:
        If the condition {2|3}D_{ABC} is not met, the name is the dataset directory

    Args:
        path: Path to the dataset directory. The path is normalized first, so a
            trailing separator does not matter.

    Returns:
        The inferred short name (Variant 1), or the last path component
        unchanged (Variant 2). Variant 2 also emits a warning.
    """
    dir_name = osp.basename(osp.normpath(path))
    match = re.search(r"(2D|3D)_([A-Z]{3})", dir_name)

    if match is None:
        # stacklevel=2 attributes the warning to the caller's line.
        warnings.warn(
            f"Dataset directory {dir_name} does not follow the lagrangebench convention. "
            "Valid name formats: {2D|3D}_{TGV|RPF|LDC|DAM}. Alternatively, you can "
            "specify the dataset name explicitly.",
            stacklevel=2,
        )
        return dir_name

    # lagrangebench convention used: "2D_TGV" -> "tgv2d"
    dim, case = match.groups()
    return f"{case}{dim}".lower()
class TGV2D(H5Dataset):
    """2D Taylor-Green Vortex dataset (2.5K particles)."""

    def __init__(
        self,
        split: str,
        dataset_path: str = "datasets/2D_TGV_2500_10kevery100",
        input_seq_length: int = 6,
        extra_seq_length: int = 0,
        nl_backend: str = "jaxmd_vmap",
    ):
        """Forward all arguments to `H5Dataset` with the name fixed to "tgv2d"."""
        super().__init__(
            split=split,
            dataset_path=dataset_path,
            name="tgv2d",
            input_seq_length=input_seq_length,
            extra_seq_length=extra_seq_length,
            nl_backend=nl_backend,
        )
class TGV3D(H5Dataset):
    """3D Taylor-Green Vortex dataset (8K particles)."""

    def __init__(
        self,
        split: str,
        dataset_path: str = "datasets/3D_TGV_8000_10kevery100",
        input_seq_length: int = 6,
        extra_seq_length: int = 0,
        nl_backend: str = "jaxmd_vmap",
    ):
        """Forward all arguments to `H5Dataset` with the name fixed to "tgv3d"."""
        super().__init__(
            split=split,
            dataset_path=dataset_path,
            name="tgv3d",
            input_seq_length=input_seq_length,
            extra_seq_length=extra_seq_length,
            nl_backend=nl_backend,
        )
class RPF2D(H5Dataset):
    """2D Reverse Poiseuille Flow dataset (3.2K particles)."""

    def __init__(
        self,
        split: str,
        dataset_path: str = "datasets/2D_RPF_3200_20kevery100",
        input_seq_length: int = 6,
        extra_seq_length: int = 0,
        nl_backend: str = "jaxmd_vmap",
    ):
        """Forward all arguments to `H5Dataset` with the name fixed to "rpf2d"."""
        super().__init__(
            split=split,
            dataset_path=dataset_path,
            name="rpf2d",
            input_seq_length=input_seq_length,
            extra_seq_length=extra_seq_length,
            nl_backend=nl_backend,
        )
class RPF3D(H5Dataset):
    """3D Reverse Poiseuille Flow dataset (8K particles)."""

    def __init__(
        self,
        split: str,
        dataset_path: str = "datasets/3D_RPF_8000_10kevery100",
        input_seq_length: int = 6,
        extra_seq_length: int = 0,
        nl_backend: str = "jaxmd_vmap",
    ):
        """Forward all arguments to `H5Dataset` with the name fixed to "rpf3d"."""
        super().__init__(
            split=split,
            dataset_path=dataset_path,
            name="rpf3d",
            input_seq_length=input_seq_length,
            extra_seq_length=extra_seq_length,
            nl_backend=nl_backend,
        )
class LDC2D(H5Dataset):
    """Lid-Driven Cavity 2D dataset. 2.7K particles."""

    def __init__(
        self,
        split: str,
        # Must match the directory name of the archive listed under "ldc2d" in
        # URLS (2D_LDC_2708_10kevery100.zip); otherwise the path returned after
        # an automatic download would not exist.
        dataset_path: str = "datasets/2D_LDC_2708_10kevery100",
        input_seq_length: int = 6,
        extra_seq_length: int = 0,
        nl_backend: str = "jaxmd_vmap",
    ):
        """Forward all arguments to `H5Dataset` with the name fixed to "ldc2d"."""
        super().__init__(
            split,
            dataset_path,
            name="ldc2d",
            input_seq_length=input_seq_length,
            extra_seq_length=extra_seq_length,
            nl_backend=nl_backend,
        )
class LDC3D(H5Dataset):
    """3D Lid-Driven Cavity dataset (8.2K particles)."""

    def __init__(
        self,
        split: str,
        dataset_path: str = "datasets/3D_LDC_8160_10kevery100",
        input_seq_length: int = 6,
        extra_seq_length: int = 0,
        nl_backend: str = "jaxmd_vmap",
    ):
        """Forward all arguments to `H5Dataset` with the name fixed to "ldc3d"."""
        super().__init__(
            split=split,
            dataset_path=dataset_path,
            name="ldc3d",
            input_seq_length=input_seq_length,
            extra_seq_length=extra_seq_length,
            nl_backend=nl_backend,
        )
class DAM2D(H5Dataset):
    """Dam break 2D dataset. 5.7K particles."""

    def __init__(
        self,
        split: str,
        # Must match the directory name of the archive listed under "dam2d" in
        # URLS (2D_DAM_5740_20kevery100.zip); otherwise the path returned after
        # an automatic download would not exist.
        dataset_path: str = "datasets/2D_DAM_5740_20kevery100",
        input_seq_length: int = 6,
        extra_seq_length: int = 0,
        nl_backend: str = "jaxmd_vmap",
    ):
        """Forward all arguments to `H5Dataset` with the name fixed to "dam2d"."""
        super().__init__(
            split,
            dataset_path,
            name="dam2d",
            input_seq_length=input_seq_length,
            extra_seq_length=extra_seq_length,
            nl_backend=nl_backend,
        )