Source code for jax_sph.io_state

"""Input-output utilities."""

import os
import time
from typing import Dict

import h5py
import jax.numpy as jnp
import numpy as np
import pyvista
from omegaconf import DictConfig, OmegaConf


def io_setup(cfg: DictConfig):
    """Prepare the output directory of a run and return its path.

    The base path is ``cfg.io.data_path`` with a trailing ``/`` appended if it
    is missing. Only when ``cfg.io.write_type`` is non-empty and
    ``cfg.case.mode == "sim"`` is a run-specific subdirectory name appended,
    the directory created, and the configuration saved as ``config.yaml``
    inside it.

    Args:
        cfg: Run configuration. The fields read are ``case.name``,
            ``case.dim``, ``case.mode``, ``solver.name``, ``io.data_path``,
            ``io.write_type`` and ``seed``.

    Returns:
        The output directory path as a string.
    """
    case_tag = cfg.case.name.upper()
    solver_tag = cfg.solver.name
    n_dim = cfg.case.dim

    out_dir = cfg.io.data_path
    if not out_dir.endswith("/"):
        out_dir = out_dir + "/"

    writes_enabled = len(cfg.io.write_type) > 0
    if not (writes_enabled and cfg.case.mode == "sim"):
        # Nothing is written in this configuration: no directory is created.
        return out_dir

    # Resulting name looks like "3D_TGV_SPH_17_20221224-084051".
    run_name = "_".join([f"{n_dim}D", case_tag, solver_tag, str(cfg.seed)])
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    out_dir = out_dir + run_name + "_" + timestamp

    os.makedirs(out_dir, exist_ok=True)
    OmegaConf.save(cfg, os.path.join(out_dir, "config.yaml"))
    return out_dir
def write_h5(data_dict: Dict, path: str):
    """Write a dict of numpy or jax arrays to a .h5 file.

    Args:
        data_dict: Mapping from dataset name to array-like. Each value is
            converted with ``np.array`` before being stored.
        path: Destination file path; the file is opened in ``"w"`` mode.
    """
    # The context manager closes the file even if a dataset fails to write.
    with h5py.File(path, "w") as hf:
        for key, value in data_dict.items():
            hf.create_dataset(key, data=np.array(value))
def write_vtk(data_dict: Dict, path: str):
    """Save a state dict as a .vtk file, e.g. for viewing in ParaView.

    Args:
        data_dict: State dict; it is converted with ``dict2pyvista``.
        path: Destination file path handed to the PyVista object's ``save``.
    """
    dict2pyvista(data_dict).save(path)
def write_state(step: int, state: Dict, dir: str, cfg: DictConfig):
    """Dump the current state to .h5 and/or .vtk during a running simulation.

    A file is written in two situations:

    * ``cfg.case.mode == "sim"``: ``step`` is a non-negative multiple of
      ``cfg.io.write_every``. The file stem is ``traj_<step>``, with the step
      zero-padded to the number of digits of ``cfg.solver.sequence_length``.
    * ``cfg.case.mode == "rlx"``: ``step`` equals
      ``cfg.solver.sequence_length - 1``. The file stem is
      ``<case.name>_<case.dim>_<case.dx>_<seed>``, e.g. "tgv_3_0.02_42".

    Args:
        step: Current step index.
        state: State dict passed unchanged to the writers.
        dir: Directory in which the files are placed.
        cfg: Run configuration; ``io.write_type`` selects the formats by
            membership tests for "h5" and "vtk".
    """
    last_step = cfg.solver.sequence_length
    mode = cfg.case.mode
    is_relax = mode == "rlx"

    if mode == "sim":
        should_write = ((step % cfg.io.write_every) == 0) and (step >= 0)
    elif is_relax:
        should_write = step == (last_step - 1)
    else:
        should_write = False

    if not should_write:
        return

    if is_relax:
        stem = "_".join(
            [cfg.case.name, str(cfg.case.dim), str(cfg.case.dx), str(cfg.seed)]
        )
    else:
        pad_width = len(str(last_step))
        stem = "traj_" + str(step).zfill(pad_width)

    requested = cfg.io.write_type
    writers = (("h5", ".h5", write_h5), ("vtk", ".vtk", write_vtk))
    for fmt, suffix, writer in writers:
        if fmt in requested:
            writer(state, os.path.join(dir, stem + suffix))
def read_h5(file_name: str, array_type: str = "jax"):
    """Read an .h5 file and return a dict of numpy or jax arrays.

    Args:
        file_name: Path of the .h5 file to read.
        array_type: ``"jax"`` to convert every dataset with ``jnp.array``,
            ``"numpy"`` to convert with ``np.array``.

    Returns:
        Dict mapping each top-level key of the file to the converted array.

    Raises:
        ValueError: If ``array_type`` is neither ``"jax"`` nor ``"numpy"``.
            The check runs before the file is opened, so an invalid value is
            rejected even when the file holds no datasets.
    """
    if array_type == "jax":
        to_array = jnp.array
    elif array_type == "numpy":
        to_array = np.array
    else:
        raise ValueError('array_type must be either "jax" or "numpy"')

    # The context manager guarantees the file is closed on any exit path.
    with h5py.File(file_name, "r") as hf:
        return {key: to_array(value) for key, value in hf.items()}
def write_vtks_from_h5s(dir_path: str, keep_h5: bool = True):
    """Transform a set of .h5 files to .vtk files.

    Each ``<name>.h5`` directly inside ``dir_path`` is read with ``read_h5``
    and written next to it as ``<name>.vtk`` with ``write_vtk``. Files are
    processed in sorted name order.

    Args:
        dir_path: Path to the directory with .h5 files.
        keep_h5: Whether to keep or delete the original .h5 files.
    """
    # Match on the extension only: a substring test would also pick up names
    # such as "x.h5.bak", for which the derived .vtk path would be wrong.
    h5_names = sorted(f for f in os.listdir(dir_path) if f.endswith(".h5"))

    for filename in h5_names:
        file_path_h5 = os.path.join(dir_path, filename)
        file_path_vtk = os.path.splitext(file_path_h5)[0] + ".vtk"

        state = read_h5(file_path_h5)
        write_vtk(state, file_path_vtk)

        if not keep_h5:
            os.remove(file_path_h5)
def render_data_dict(data_dict):
    """Open a PyVista plot window showing the points of a state dict.

    Args:
        data_dict: State dict with at least the key ``"r"``; it is converted
            with ``dict2pyvista`` before being added to the plotter.
    """
    pl = pyvista.Plotter()
    # Without adding the data the plotter would only display an empty scene.
    pl.add_mesh(dict2pyvista(data_dict))
    pl.show_grid()
    pl.add_bounding_box()
    pl.add_axes()
    pl.camera_position = "xy"
    pl.show()


def dict2pyvista(data_dict):
    """Convert a state dict to a ``pyvista.PolyData`` point set.

    PyVista works only with 3D objects, so 2D inputs are extended by a zero
    third column.

    Args:
        data_dict: Dict with the key ``"r"`` holding an ``(N, dim)`` array of
            positions. Every other entry is attached to the point set under
            its key.

    Returns:
        The ``pyvista.PolyData`` object built from ``"r"``, carrying all other
        entries as plain numpy arrays.
    """
    # N is the number of points and dim the dimension
    r = np.asarray(data_dict["r"])
    N, dim = r.shape

    # PyVista treats the position information differently than the rest
    if dim == 2:
        r = np.hstack([r, np.zeros((N, 1))])
    data_pv = pyvista.PolyData(r)

    # copy all the other information also to pyvista, using plain numpy arrays
    for k, v in data_dict.items():
        # skip r because we already considered it above
        if k == "r":
            continue

        # Convert first so that lists and Python scalars also expose `.ndim`.
        v = np.asarray(v)

        # working in 3D or scalar features do not require special care
        if dim == 2 and v.ndim == 2:
            v = np.hstack([v, np.zeros((N, 1))])
        data_pv[k] = v

    return data_pv


if __name__ == "__main__":
    # write_vtks_from_h5s("./3D_TGV_SPH_17_20221224-084051", False)
    pass