Utils and Defaults

Utils

General jax-sph utils.

class jax_sph.utils.Tag(value)

Particle types.

jax_sph.utils.pos_init_cartesian_2d(box_size: array, dx: float)

Create a grid of particles in 2D.

Particles are at the center of the corresponding Cartesian grid cells. Example: if box_size=np.array([1, 1]) and dx=0.1, then the first particle will be at position [0.05, 0.05].
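The cell-centered placement can be reproduced with a short sketch. The helper cell_center_grid_2d below is hypothetical, not the jax-sph implementation. It assumes box_size is an integer multiple of dx, and its particle ordering may differ from pos_init_cartesian_2d.

```python
import jax.numpy as jnp
import numpy as np


def cell_center_grid_2d(box_size, dx):
    """Hypothetical helper: particle positions at the centers of dx-sized cells."""
    # Number of cells per dimension; rounding guards against float error in box_size / dx
    n = np.round(np.asarray(box_size) / dx).astype(int)
    # Cell indices along x and y; "ij" indexing keeps x as the first axis
    ix, iy = jnp.meshgrid(jnp.arange(n[0]), jnp.arange(n[1]), indexing="ij")
    # Shift each index by half a cell so that particles sit at cell centers
    return (jnp.stack([ix.ravel(), iy.ravel()], axis=1) + 0.5) * dx
```

With box_size=np.array([1, 1]) and dx=0.1 the sketch returns 100 positions, the first at [0.05, 0.05] and the last at [0.95, 0.95].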

jax_sph.utils.pos_init_cartesian_3d(box_size: array, dx: float)

Create a grid of particles in 3D.

jax_sph.utils.pos_box_2d(fluid_box: array, dx: float, n_walls: int = 3)

Create an empty box of particles in 2D.

fluid_box is an array of the form [L, H]. The walls have a thickness of n_walls * dx on each side, so the box is of size (L + 2 * n_walls * dx) x (H + 2 * n_walls * dx). The inner part of the box starts at (n_walls * dx, n_walls * dx).
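A minimal NumPy sketch of such a wall box follows. It assumes walls of thickness n_walls * dx on all four sides, particles at cell centers, and a fluid region starting at (n_walls * dx, n_walls * dx) as stated above. The helper wall_box_2d is hypothetical: it fills the whole domain with particles and then removes those inside the fluid region.

```python
import numpy as np


def wall_box_2d(fluid_box, dx, n_walls=3):
    """Hypothetical helper: wall particles around an empty [L, H] fluid region."""
    w = n_walls * dx  # wall thickness on every side
    L, H = float(fluid_box[0]), float(fluid_box[1])
    # Cell-centered grid over the whole domain, walls included
    nx = int(round((L + 2 * w) / dx))
    ny = int(round((H + 2 * w) / dx))
    x = (np.arange(nx) + 0.5) * dx
    y = (np.arange(ny) + 0.5) * dx
    X, Y = np.meshgrid(x, y, indexing="ij")
    r = np.stack([X.ravel(), Y.ravel()], axis=1)
    # Remove particles whose centers lie inside the fluid region starting at (w, w)
    inside = (r[:, 0] > w) & (r[:, 0] < w + L) & (r[:, 1] > w) & (r[:, 1] < w + H)
    return r[~inside]
```

For fluid_box=[1, 1], dx=0.1 and n_walls=3, the full grid has 16 x 16 = 256 cells; removing the 10 x 10 fluid cells leaves 156 wall particles.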

jax_sph.utils.pos_box_3d(fluid_box: array, dx: float, n_walls: int = 3, z_periodic: bool = True)

Create an empty box of particles in 3D, periodic in z by default.

fluid_box is an array of the form [L, H, D]. The box is of size (L + 2 * n_walls * dx) x (H + 2 * n_walls * dx) x D, and its inner part starts at (n_walls * dx, n_walls * dx) in the x-y plane. z_periodic states whether the box is periodic in the z-direction.

jax_sph.utils.get_noise_masked(shape: tuple, mask: array, key: PRNGKey, std: float)

Generate Gaussian noise with standard deviation std where mask is True.
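One plausible way to implement masked noise in JAX is to draw noise everywhere and zero it outside the mask. The function below is a hypothetical sketch, not the jax-sph code: it assumes mask is a boolean per-particle array of shape (N,) and that entries outside the mask should be exactly zero.

```python
import jax
import jax.numpy as jnp


def noise_masked_sketch(shape, mask, key, std):
    """Hypothetical re-implementation: Gaussian noise only where mask is True."""
    noise = std * jax.random.normal(key, shape)
    # Append singleton axes so a per-particle mask (N,) broadcasts over (N, dim)
    mask = jnp.reshape(mask, mask.shape + (1,) * (len(shape) - mask.ndim))
    return jnp.where(mask, noise, 0.0)
```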

jax_sph.utils.get_ekin(state: Dict, dx: float)

Compute the kinetic energy of the fluid from state["v"].
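The kinetic energy can be written as E_kin = 0.5 * sum_i m_i |v_i|^2. The sketch below assumes that every particle carries the same mass m = rho * dx**dim with rho = 1 (the default rho_ref is 1.0), and that only fluid particles are counted. The boolean fluid_mask argument is a hypothetical stand-in for a check on the particle Tag.

```python
import jax.numpy as jnp


def kinetic_energy_sketch(v, fluid_mask, dx, rho=1.0):
    """Hypothetical: E_kin = 0.5 * sum_i m_i |v_i|^2 with m_i = rho * dx**dim."""
    dim = v.shape[1]
    mass = rho * dx**dim  # identical mass for every particle (assumption)
    v2 = jnp.sum(v**2, axis=1)  # squared speed per particle
    v2 = jnp.where(fluid_mask, v2, 0.0)  # ignore non-fluid particles
    return 0.5 * mass * jnp.sum(v2)
```

For example, two fluid particles with speeds 1 and 2 at dx=0.1 in 2D give 0.5 * 0.01 * (1 + 4) = 0.025.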

jax_sph.utils.get_array_stats(state: Dict, var: str = 'u', operation='max')

Extract the min, max, or mean of state[var].

For vector quantities, the statistic is computed over the Euclidean norm of each vector.

Parameters:
  • state – Simulation state dictionary.

  • var – Variable to extract, i.e. dict key.

  • operation – One of "min", "max", "mean".
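The min/max/mean selection, including the norm reduction for vector fields, can be sketched as follows. array_stat_sketch is a hypothetical stand-in for get_array_stats. It assumes vector quantities are stored with shape (N, dim) and scalars with shape (N,).

```python
import jax.numpy as jnp


def array_stat_sketch(state, var="u", operation="max"):
    """Hypothetical re-implementation of the min/max/mean extraction."""
    x = state[var]
    if x.ndim > 1:  # vector quantity: reduce each row to its Euclidean norm
        x = jnp.linalg.norm(x, axis=-1)
    ops = {"min": jnp.min, "max": jnp.max, "mean": jnp.mean}
    if operation not in ops:
        raise ValueError(f"operation must be one of {list(ops)}, got {operation!r}")
    return ops[operation](x)
```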

jax_sph.utils.get_stats(state: Dict, props: list, dx: float)

Extract values from state for printing.

jax_sph.utils.compute_nws_scipy(r, tag, dx, n_walls, offset_vec, wall_part_fn)

Compute the normal vectors of all wall boundaries. The SciPy-based computation is wrapped in a pure_callback, so the function can be used inside jit.
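The pure_callback pattern lets host-side NumPy/SciPy code run inside a jitted function, provided that the output shape and dtype are declared up front. The sketch below illustrates only that pattern with a hypothetical nearest-neighbor distance query based on scipy.spatial.cKDTree. It is not the wall-normal computation of compute_nws_scipy.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.spatial import cKDTree


def _nearest_dist_host(query, points):
    # Runs on the host with NumPy/SciPy arrays; JAX does not trace this function
    query = np.asarray(query)
    dist, _ = cKDTree(np.asarray(points)).query(query, k=1)
    return dist.astype(query.dtype)  # must match the declared dtype exactly


@jax.jit
def nearest_dist(query, points):
    """Hypothetical example: distance from each query point to its nearest point."""
    out_spec = jax.ShapeDtypeStruct((query.shape[0],), query.dtype)
    return jax.pure_callback(_nearest_dist_host, out_spec, query, points)
```

Because the callback runs outside XLA, JAX cannot differentiate through it without extra work (for example a custom_jvp rule), and the returned array must agree with the declared jax.ShapeDtypeStruct.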

jax_sph.utils.compute_nws_jax_wrapper(state0: Dict, dx: float, n_walls: int, offset_vec: Array, box_size: Array, pbc: Array, cfg_nl: DictConfig, displacement_fn: Callable, wall_part_fn: Callable)

Compute wall normal vectors pointing from the wall to the fluid. Jit-able JAX implementation.

For each particle in r_walls, find the closest particle in layer and use it to compute the normal vector of that r_walls particle.
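A naive, brute-force reading of the nearest-particle idea is sketched below. It is an illustration only, not the jax-sph algorithm: it compares every wall particle with every fluid particle (O(N_wall * N_fluid) memory), ignores periodic boundaries, and takes the normal to be the unit vector towards the closest fluid particle. The function name is hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def naive_wall_normals(r_wall, r_fluid):
    """Hypothetical sketch: unit vector from each wall particle to its nearest fluid particle."""
    # Pairwise difference vectors, shape (N_wall, N_fluid, dim)
    diff = r_fluid[None, :, :] - r_wall[:, None, :]
    dist2 = jnp.sum(diff**2, axis=-1)
    nearest = jnp.argmin(dist2, axis=1)  # index of the closest fluid particle
    d = diff[jnp.arange(r_wall.shape[0]), nearest]  # vector wall -> nearest fluid particle
    return d / jnp.linalg.norm(d, axis=-1, keepdims=True)
```

For wall particles directly below a flat layer of fluid particles this yields [0, 1]. Near corners the nearest fluid particle may lie diagonally, so the result is only an approximation there.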

class jax_sph.utils.Logger(dt, dx, print_props, sequence_length)

Logger for printing stats to stdout.

jax_sph.utils.sph_interpolator(cfg: DictConfig, src_path: str, prop_type: str = 'vector')

Interpolate properties from a state to arbitrary coordinates, e.g. a line.

Parameters:
  • cfg – Simulation arguments.

  • src_path – Used only to instantiate the neighbors object.

  • prop_type – Whether the interpolated property is a vector or a scalar.

Returns:

Interpolation function.

Return type:

Callable
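The underlying SPH interpolant evaluates f(x) ≈ sum_j (m_j / rho_j) f_j W(|x - x_j|, h) at arbitrary points. The sketch below computes that sum by brute force with Shepard normalization. It swaps in a 2D Gaussian kernel for jax-sph's default QSK kernel, skips the neighbor list, and uses hypothetical names throughout.

```python
import jax.numpy as jnp


def gaussian_kernel_2d(r, h):
    """Stand-in 2D Gaussian kernel (not jax-sph's default QSK kernel)."""
    return jnp.exp(-((r / h) ** 2)) / (jnp.pi * h**2)


def sph_interpolate(x_target, r, f, mass, rho, h):
    """Hypothetical sketch: f(x) ~ sum_j (m_j / rho_j) f_j W(|x - r_j|, h)."""
    # Distances from every target point to every particle, shape (N_target, N)
    dist = jnp.linalg.norm(x_target[:, None, :] - r[None, :, :], axis=-1)
    w = (mass / rho)[None, :] * gaussian_kernel_2d(dist, h)
    # Shepard normalization compensates for truncated kernel support near boundaries
    w = w / jnp.sum(w, axis=1, keepdims=True)
    # Scalar f: (N,) -> (N_target,); vector f: (N, dim) -> (N_target, dim)
    return w @ f
```

Because the weights are normalized, a constant field is reproduced up to round-off, which makes a quick sanity check.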

Input/Output

Input-output utilities.

jax_sph.io_state.io_setup(cfg: DictConfig)

Set up the output directory and write the arguments to a .txt file.

jax_sph.io_state.write_h5(data_dict: Dict, path: str)

Write a dict of numpy or jax arrays to a .h5 file.

jax_sph.io_state.write_vtk(data_dict: Dict, path: str)

Write data_dict to a .vtk file for visualization in ParaView.

jax_sph.io_state.write_state(step: int, state: Dict, dir: str, cfg: DictConfig)

Write state to .h5 or .vtk file while simulation is running.

jax_sph.io_state.read_h5(file_name: str, array_type: str = 'jax')

Read an .h5 file and return a dict of numpy or jax arrays.

jax_sph.io_state.write_vtks_from_h5s(dir_path: str, keep_h5: bool = True)

Transform a set of .h5 files to .vtk files.

Parameters:
  • dir_path – Path to the directory with the .h5 files.

  • keep_h5 – Whether to keep or delete the original .h5 files.

Defaults

Default jax-sph configs.

jax_sph.defaults.set_defaults(cfg: DictConfig = {'config': None, 'seed': 123, 'no_jit': False, 'gpu': 0, 'dtype': 'float64', 'xla_mem_fraction': 0.75, 'case': {'source': None, 'mode': 'sim', 'dim': 3, 'dx': 0.05, 'state0_path': None, 'state0_keys': ['r'], 'r0_type': 'cartesian', 'r0_noise_factor': 0.0, 'g_ext_magnitude': 0.0, 'viscosity': 0.01, 'u_ref': 1.0, 'c_ref_factor': 10.0, 'rho_ref': 1.0, 'T_ref': 1.0, 'kappa_ref': 0.0, 'Cp_ref': 0.0, 'special': {}}, 'solver': {'name': 'SPH', 'tvf': 0.0, 'cfl': 0.25, 'density_evolution': False, 'density_renormalize': False, 'dt': None, 't_end': 0.2, 'artificial_alpha': 0.0, 'free_slip': False, 'eta_limiter': 3, 'kappa': 0, 'n_walls': 3, 'heat_conduction': False, 'is_bc_trick': False}, 'kernel': {'name': 'QSK', 'h_factor': 1.0}, 'eos': {'name': 'Tait', 'gamma': 1.0, 'p_bg_factor': 0.0}, 'nl': {'backend': 'jaxmd_vmap', 'num_partitions': 1}, 'io': {'write_type': [], 'write_every': 1, 'data_path': './', 'print_props': ['Ekin', 'u_max']}}) -> DictConfig

Set default jax-sph configs.
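The default values in the signature above are grouped into sections (case, solver, kernel, eos, nl, io). One way to picture user settings being layered on top of them is a recursive merge. The sketch below does this with plain dicts, a hypothetical merge helper, and a subset of the default keys; it does not call OmegaConf or jax-sph, which work with DictConfig objects.

```python
import copy


def merge(defaults, overrides):
    """Hypothetical helper: recursively overlay `overrides` onto a copy of `defaults`."""
    out = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = merge(out[key], value)  # descend into a nested section
        else:
            out[key] = value  # a leaf value replaces the default
    return out


# A subset of the defaults listed in the set_defaults signature
DEFAULTS = {
    "seed": 123,
    "case": {"dim": 3, "dx": 0.05, "viscosity": 0.01},
    "solver": {"name": "SPH", "t_end": 0.2, "n_walls": 3},
    "kernel": {"name": "QSK", "h_factor": 1.0},
}

cfg = merge(DEFAULTS, {"case": {"dim": 2, "dx": 0.02}, "solver": {"t_end": 1.0}})
```

Keys that are not overridden, such as case.viscosity or solver.name, keep their default values, and DEFAULTS itself is left unchanged.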

jax_sph.defaults.check_cfg(cfg: DictConfig) -> None

Check if the configs are valid.