Source code for jax_sph.case_setup

"""Simulation setup."""

import importlib
import os
import warnings
from abc import ABC, abstractmethod

import jax
import jax.numpy as jnp
import numpy as np
from jax import vmap

from jax_sph.eos import RIEMANNEoS, TaitEoS
from jax_sph.io_state import read_h5
from jax_sph.jax_md import space
from jax_sph.utils import (
    Tag,
    compute_nws_jax_wrapper,
    compute_nws_scipy,
    get_noise_masked,
    pos_init_cartesian_2d,
    pos_init_cartesian_3d,
    wall_tags,
)

EPS = jnp.finfo(float).eps


[docs] class SimulationSetup(ABC): """Parent class for all simulation setups.""" def __init__(self, cfg): # If a value is changed in one of these, it also changes in the others self.cfg = cfg self.case = cfg.case self.special = cfg.case.special if cfg.case.r0_type == "relaxed": assert (cfg.case.state0_path is not None) and ("r" in cfg.case.state0_keys) # get the config file name, e.g. "cases/db.yaml" -> "db" cfg.case.name = os.path.splitext(os.path.basename(cfg.config))[0]
[docs] def initialize(self): """Initialize and return everything needed for the numerical setup. This function sets up all physical and numerical quantities including: - reference values and case specific parameters - integration time step - relaxation or simulation mode - equation of state - boundary conditions - state dictionary - function including external force and boundary conditions Returns: A tuple containing the following elements: - cfg (DictConfig): configuration arguments - box_size (ndarray): size of the simulation box starting at 0.0 - state (dict): dictionary containing all field values - g_ext_fn (Callable): external force function - bc_fn (Callable): boundary conditions function (e.g. velocity at walls) - nw_fn (Callable): jit-able wall normal funct. when moving walls, else None - eos (Callable): equation of state function - key (PRNGKey): random key for sampling - displacement_fn (Callable): displacement function for edge features - shift_fn (Callable): shift function for position updates """ cfg = self.cfg dx = cfg.case.dx dim = cfg.case.dim rho_ref = cfg.case.rho_ref viscosity = cfg.case.viscosity u_ref = cfg.case.u_ref cfl = cfg.solver.cfl key_prng = jax.random.PRNGKey(cfg.seed) # Primal: reference density, dynamic viscosity, and velocity # Derived: reference speed of sound, pressure c_ref = cfg.case.c_ref_factor * u_ref p_ref = rho_ref * c_ref**2 / cfg.eos.gamma # for free surface simulation p_background in the EoS has to be 0.0 p_bg = cfg.eos.p_bg_factor * p_ref # calculate volume and mass h = dx volume_ref = h**dim mass_ref = volume_ref * rho_ref # time integration step dt dt_convective = cfl * h / (c_ref + u_ref) dt_viscous = cfl * h**2 * rho_ref / (viscosity + EPS) dt_body_force = cfl * (h / (cfg.case.g_ext_magnitude + EPS)) ** 0.5 dt = np.amin([dt_convective, dt_viscous, dt_body_force]).item() # TODO: consider adaptive time step sizes print("dt_convective :", dt_convective) print("dt_viscous :", dt_viscous) print("dt_body_force :", dt_body_force) print("dt_max_allowed:", dt) if cfg.solver.dt is not None: if cfg.solver.dt > dt: warnings.warn("Explicit dt should comply with CFL.", UserWarning) dt = cfg.solver.dt print("dt_final :", dt) if cfg.case.mode == "rlx": # run a relaxation of randomly initialized state for 500 steps sequence_length = 5000 cfg.solver.t_end = dt * sequence_length # turn background pressure on for homogeneous particle distribution else: sequence_length = int(cfg.solver.t_end / dt) # Equation of state if cfg.solver.name == "RIE": eos = RIEMANNEoS(rho_ref, p_bg, u_ref) else: eos = TaitEoS(p_ref, rho_ref, p_bg, cfg.eos.gamma) # initialize box and positions of particles if dim == 2: box_size = self._box_size2D(cfg.solver.n_walls) r, tag = self._init_pos2D(box_size, dx, cfg.solver.n_walls) elif dim == 3: box_size = self._box_size3D(cfg.solver.n_walls) r, tag = self._init_pos3D(box_size, dx, cfg.solver.n_walls) displacement_fn, shift_fn = space.periodic(side=box_size) num_particles = len(r) print("Total number of particles = ", num_particles) # add noise to the fluid particles to break symmetry key, subkey = jax.random.split(key_prng) if cfg.case.r0_noise_factor != 0.0: noise_std = cfg.case.r0_noise_factor * dx noise = get_noise_masked(r.shape, tag == Tag.FLUID, subkey, std=noise_std) # PBC: move all particles to the box limits after noise addition r = shift_fn(r, noise) # initialize the velocity given the coordinates r with the noise if dim == 2: v = vmap(self._init_velocity2D)(r) elif dim == 3: v = vmap(self._init_velocity3D)(r) # initialize all other field values rho, mass, eta, temperature, kappa, Cp = self._set_field_properties( num_particles, mass_ref, cfg.case ) # whether to compute wall normals is_nw = ( cfg.solver.free_slip or cfg.solver.name == "RIE" ) and cfg.solver.is_bc_trick # calculate wall normals if necessary nw = self._compute_wall_normals("scipy")(r, tag) if is_nw else jnp.zeros_like(r) # initialize the state dictionary state = { "r": r, "tag": tag, "u": v, "v": v, "dudt": jnp.zeros_like(v), "dvdt": jnp.zeros_like(v), "drhodt": jnp.zeros_like(rho), "rho": rho, "p": eos.p_fn(rho), "mass": mass, "eta": eta, "dTdt": jnp.zeros_like(rho), "T": temperature, "kappa": kappa, "Cp": Cp, "nw": nw, } # overwrite the state dictionary with the provided one, only for the fluid if cfg.case.state0_path is not None: _state = read_h5(cfg.case.state0_path) for k in state: if k not in cfg.case.state0_keys: continue assert k in _state, ValueError(f"Key {k} not found in state0 file.") mask, _mask = state["tag"] == Tag.FLUID, _state["tag"] == Tag.FLUID assert state[k][mask].shape == _state[k][_mask].shape, ValueError( f"Shape mismatch for key {k} in state0 file." ) state[k] = state[k].at[mask].set(_state[k][_mask]) # the following arguments are needed for dataset generation cfg.case.c_ref, cfg.case.p_ref, cfg.case.p_bg = c_ref, p_ref, p_bg cfg.solver.dt, cfg.solver.sequence_length = dt, sequence_length cfg.case.num_particles_max = num_particles cfg.case.pbc = [True, True, True] # TODO: matscipy needs 3D cfg.case.bounds = np.array([np.zeros_like(box_size), box_size]).T.tolist() state = self._boundary_conditions_fn(state) g_ext_fn = self._external_acceleration_fn bc_fn = self._boundary_conditions_fn # whether to recompute the wall normals at every integration step is_nw_recompute = (tag == Tag.MOVING_WALL).any() and is_nw if is_nw_recompute: assert cfg.nl.backend != "matscipy", NotImplementedError( "Wall normals not yet implemented for matscipy neighbor list when " "working with moving boundaries. \nIf you work with moving boundaries, " "don't use one of: `nl.backend=matscipy` or `solver.free_slip=True` or " "`solver.name=RIE`." ) kwargs = {"disp_fn": displacement_fn, "box_size": box_size, "state0": state} nw_fn = self._compute_wall_normals("jax", **kwargs) if is_nw_recompute else None return ( cfg, box_size, state, g_ext_fn, bc_fn, nw_fn, eos, key, displacement_fn, shift_fn, )
@abstractmethod def _box_size2D(self, n_walls): pass @abstractmethod def _box_size3D(self, n_walls): pass def _init_pos2D(self, box_size, dx, n_walls): r = pos_init_cartesian_2d(box_size, dx) tag = jnp.full(len(r), Tag.FLUID, dtype=int) return r, tag def _init_pos3D(self, box_size, dx, n_walls): r = pos_init_cartesian_3d(box_size, dx) tag = jnp.full(len(r), Tag.FLUID, dtype=int) return r, tag @abstractmethod def _init_walls_2d(self): """Create all solid walls of a 2D case.""" pass @abstractmethod def _init_walls_3d(self): """Create all solid walls of a 3D case.""" pass @abstractmethod def _init_velocity2D(self, r): pass @abstractmethod def _init_velocity3D(self, r): pass @abstractmethod def _external_acceleration_fn(self, r): pass @abstractmethod def _boundary_conditions_fn(self, state): pass def _get_relaxed_r0(self, box_size, dx): assert hasattr(self, "_load_only_fluid"), AttributeError cfg = self.cfg name = "_".join([cfg.case.name, str(cfg.case.dim), str(dx), str(cfg.seed)]) init_path = os.path.join("data_relaxed", name + ".h5") if not os.path.isfile(init_path): message = ( f"python main.py config={cfg.config} mode=rlx seed={str(cfg.seed)} " f"case.dim={str(cfg.case.dim)} case.dx={str(cfg.case.dx)} " f"solver.name=SPH solver.tvf=1 solver.r0_noise_factor=0.25 " f"io.write_type=['h5'] io.data_path=data_relaxed/" ) raise FileNotFoundError(f"First execute this: \n{message}") state = read_h5(init_path) if self._load_only_fluid: return state["r"][state["tag"] == Tag.FLUID] else: return state["r"], state["tag"] def _set_field_properties(self, num_particles, mass_ref, case): rho = jnp.ones(num_particles) * case.rho_ref mass = jnp.ones(num_particles) * mass_ref eta = jnp.ones(num_particles) * case.viscosity temperature = jnp.ones(num_particles) * case.T_ref kappa = jnp.ones(num_particles) * case.kappa_ref Cp = jnp.ones(num_particles) * case.Cp_ref return rho, mass, eta, temperature, kappa, Cp def _set_default_rlx(self): """Set default values for relaxation case setup. These would only change if the domain is not full. """ self._box_size2D_rlx = self._box_size2D self._box_size3D_rlx = self._box_size3D self._init_pos2D_rlx = self._init_pos2D self._init_pos3D_rlx = self._init_pos3D def _compute_wall_normals(self, backend="scipy", **kwargs): if self.cfg.case.dim == 2: wall_part_fn = self._init_walls_2d elif self.cfg.case.dim == 3: wall_part_fn = self._init_walls_3d else: raise NotImplementedError("1D wall BCs not yet implemented") if backend == "scipy": # If one makes `tag` static (-> `self.tag`), this function can be jitted. # But it is significatly slower than `backend="jax"` due to `pure_callback`. def body(r, tag): return compute_nws_scipy( r, tag, self.cfg.case.dx, self.cfg.solver.n_walls, self.offset_vec, wall_part_fn, ) elif backend == "jax": # This implementation is used in the integrator when having moving walls. body = compute_nws_jax_wrapper( state0=kwargs["state0"], dx=self.cfg.case.dx, n_walls=self.cfg.solver.n_walls, offset_vec=self.offset_vec, box_size=kwargs["box_size"], pbc=self.cfg.case.pbc, cfg_nl=self.cfg.nl, displacement_fn=kwargs["disp_fn"], wall_part_fn=wall_part_fn, ) return body
[docs] def set_relaxation(Case, cfg): """Make a relaxation case from a SimulationSetup instance. Create a child class of a particular SimulationSetup instance and overwrite: - _init_pos{2|3}D - _box_size{2|3}D - _tag{2|3}D - _init_velocity{2|3}D - _external_acceleration_fn - _boundary_conditions_fn """ class Rlx(Case): """Relax particles in a box""" def __init__(self, cfg): super().__init__(cfg) # custom variables related only to this Simulation self.case.g_ext_magnitude = 0.0 # use the relaxation setup from the main case self._init_pos2D = self._init_pos2D_rlx self._init_pos3D = self._init_pos3D_rlx self._box_size2D = self._box_size2D_rlx self._box_size3D = self._box_size3D_rlx def _init_velocity2D(self, r): return jnp.zeros_like(r) def _init_velocity3D(self, r): return self._init_velocity2D(r) def _external_acceleration_fn(self, r): return jnp.zeros_like(r) def _boundary_conditions_fn(self, state): mask1 = jnp.isin(state["tag"], wall_tags)[:, None] state["u"] = jnp.where(mask1, 0.0, state["u"]) state["v"] = jnp.where(mask1, 0.0, state["v"]) state["dudt"] = jnp.where(mask1, 0.0, state["dudt"]) state["dvdt"] = jnp.where(mask1, 0.0, state["dvdt"]) return state return Rlx(cfg)
[docs] def load_case(case_root: str, case_py_file: str) -> SimulationSetup: """Load Case class from Python file. Args: case_root (str): Path to the case root directory, e.g. "cases/". case_py_file (str): Name of the Python case file, e.g. "db.py". """ file_path = os.path.join(case_root, case_py_file) spec = importlib.util.spec_from_file_location("case_module", file_path) case_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(case_module) # the class name has to be capital version of case_py_file without .py extension # e.g. "db.py" -> "DB" class_name = case_py_file.split(".")[0].upper() Case = getattr(case_module, class_name) return Case