Source code for jax_sph.integrator

"""Integrator schemes."""

from typing import Callable, Dict

from jax_sph.utils import Tag


def si_euler(
    tvf: float, model: Callable, shift_fn: Callable, bc_fn: Callable, nw_fn: Callable
):
    """Build a semi-implicit Euler time stepper including Transport Velocity.

    The returned ``advance`` function performs one time step in this order:

    1. Advance ``u`` with ``dudt`` and derive ``v`` from ``u`` and ``dvdt``.
    2. Move the positions ``r`` with the velocity ``v``.
    3. Refresh the neighbor list.
    4. Evaluate the SPH model to obtain new accelerations.
    5. Impose the boundary conditions defined by the case.

    Args:
        tvf: Scalar weight applied to the ``0.5 * dt * dvdt`` contribution to
            ``v`` (presumably the transport-velocity switch; with ``0`` the
            velocity ``v`` equals ``u``).
        model: Called as ``model(state, neighbors)``; must return the state.
        shift_fn: Called as ``shift_fn(r, dr)``; must return the new positions.
        bc_fn: Called as ``bc_fn(state)``; must return the state.
        nw_fn: Called as ``nw_fn(r)`` to recompute ``state["nw"]`` after the
            particles moved, or ``None`` to skip that step.

    Returns:
        A function ``advance(dt, state, neighbors)`` that returns the tuple
        ``(state, neighbors)`` after one step.
    """

    def advance(dt: float, state: Dict, neighbors):
        """Advance ``state`` by one step of size ``dt``.

        The entries ``"u"``, ``"v"``, ``"r"`` (and ``"nw"`` when ``nw_fn`` is
        given) of the passed ``state`` dict are overwritten before the model
        is called.
        """
        # Scalar step factors, grouped exactly as they are applied below.
        full_step = 1.0 * dt
        transport_step = tvf * 0.5 * dt

        # 1. u takes a full dt step with dudt; v is the freshly updated u plus
        #    the tvf-weighted half-step contribution of dvdt.
        state["u"] += full_step * state["dudt"]
        state["v"] = state["u"] + transport_step * state["dvdt"]

        # 2. Move the particles with the velocity v.
        displacement = full_step * state["v"]
        state["r"] = shift_fn(state["r"], displacement)

        # Wall normals depend on the positions, so recompute them if requested.
        if nw_fn is not None:
            state["nw"] = nw_fn(state["r"])

        # 3. Update the neighbor list.
        # The displacement and shift functions from JAX MD are used for
        # computing the relative particle distance and for moving particles
        # across the periodic domain:
        #
        #   dr = displacement_fn(r1, r2) = r1 - r2
        #   respecting PBC, i.e. returns |dr_i| < box_size / 2
        #
        #   r = shift_fn(r, dr) = r + dr
        #   respecting PBC, i.e. new r in [0, box_size]
        is_real_particle = state["tag"] != Tag.PAD_VALUE
        n_real = is_real_particle.sum()
        neighbors = neighbors.update(state["r"], num_particles=n_real)

        # 4. Compute accelerations with the SPH model, then
        # 5. impose boundary conditions on dummy particles (if applicable).
        state = bc_fn(model(state, neighbors))

        return state, neighbors

    return advance