Source code for jax_sph.solver

"""Weakly compressible SPH solver."""

from typing import Callable, Union

import jax.numpy as jnp
from jax import ops, vmap

from jax_sph.eos import RIEMANNEoS, TaitEoS
from jax_sph.jax_md import space
from jax_sph.kernel import (
    CubicKernel,
    GaussianKernel,
    QuinticKernel,
    SuperGaussianKernel,
    WendlandC2Kernel,
    WendlandC4Kernel,
    WendlandC6Kernel,
)
from jax_sph.utils import Tag, wall_tags

EPS = jnp.finfo(float).eps


[docs] def rho_evol_fn(rho, mass, u, grad_w_dist, i_s, j_s, dt, N, **kwargs): """Density evolution according to Adami et al. 2013.""" v_j_s = (mass / rho)[j_s] temp = v_j_s * ((u[i_s] - u[j_s]) * grad_w_dist).sum(axis=1) drhodt = rho * ops.segment_sum(temp, i_s, N) rho = rho + dt * drhodt return rho, drhodt
[docs] def rho_evol_riemann_fn_wrapper(kernel_fn, eos, c_ref): """Density evolution according to Zhang et al. 2017.""" def rho_evol_riemann_fn( e_s, rho_i, rho_j, m_j, u_i, u_j, p_i, p_j, r_ij, d_ij, wall_mask_j, n_w_j, g_ext_i, u_tilde_j, **kwargs, ): # Compute unit vector, above eq. (6), Zhang (2017) e_ij = e_s # Compute kernel gradient kernel_grad = kernel_fn.grad_w(d_ij) * (e_ij) # Compute average states eq. (6)/(12)/(13), Zhang (2017) u_L = jnp.where( jnp.isin(wall_mask_j, wall_tags), jnp.dot(u_i, -n_w_j), jnp.dot(u_i, -e_ij) ) p_L = p_i rho_L = rho_i # u_w from eq. (15), Yang (2020) u_R = jnp.where( jnp.isin(wall_mask_j, wall_tags), -u_L + 2 * jnp.dot(u_j, n_w_j), jnp.dot(u_j, -e_ij), ) p_R = jnp.where( jnp.isin(wall_mask_j, wall_tags), p_L + rho_L * jnp.dot(g_ext_i, -r_ij), p_j ) rho_R = jnp.where(jnp.isin(wall_mask_j, wall_tags), eos.rho_fn(p_R), rho_j) U_avg = (u_L + u_R) / 2 v_avg = (u_i + u_j) / 2 rho_avg = (rho_L + rho_R) / 2 # Compute Riemann states eq. (7) and below eq. (9), Zhang (2017) U_star = U_avg + 0.5 * (p_L - p_R) / (rho_avg * c_ref) v_star = U_star * (-e_ij) + (v_avg - U_avg * (-e_ij)) # Mass conservation with linear Riemann solver eq. (8), Zhang (2017) eq_8 = 2 * rho_i * m_j / rho_j * jnp.dot((u_i - v_star), kernel_grad) return eq_8 return rho_evol_riemann_fn
[docs] def rho_renorm_fn(rho, mass, i_s, j_s, w_dist, N): """Renormalization of density according to Zhang et al. 2017.""" nominator = ops.segment_sum(mass[j_s] * w_dist, i_s, N) rho_denominator = ops.segment_sum((mass / rho)[j_s] * w_dist, i_s, N) rho_denominator = jnp.where(rho_denominator > 1, 1, rho_denominator) rho = nominator / rho_denominator return rho
[docs] def rho_summation_fn(mass, i_s, w_dist, N): """Density summation.""" return mass * ops.segment_sum(w_dist, i_s, N)
[docs] def wall_phi_vec_wrapper(kernel_fn): """Compute the wall phi vector according to Zhang et al. 2017.""" def wall_phi_vec(rho_j, m_j, dr_ij, dist, tag_j, tag_i): # Compute unit vector, above eq. (6), Zhang (2017) e_ij_w = dr_ij / (dist + EPS) # Compute kernel gradient kernel_grad = kernel_fn.grad_w(dist) * (e_ij_w) # compute phi eq. (15), Zhang (2017) phi = -1.0 * m_j / rho_j * kernel_grad * tag_j * tag_i return phi return wall_phi_vec
[docs] def acceleration_tvf_fn_wrapper(kernel_fn): """Transport velocity formulation acceleration according to Adami et al. 2013.""" def acceleration_tvf_fn(r_ij, d_ij, rho_i, rho_j, m_i, m_j, p_bg_i): # compute the common prefactor `_c` _weighted_volume = ((m_i / rho_i) ** 2 + (m_j / rho_j) ** 2) / m_i _kernel_grad = kernel_fn.grad_w(d_ij) _c = _weighted_volume * _kernel_grad / (d_ij + EPS) # (Eq. 13) - or at least the acceleration term a_eq_13 = _c * 1.0 * p_bg_i * r_ij return a_eq_13 return acceleration_tvf_fn
[docs] def tvf_stress_fn(rho: float, u, v): """Transport velocity stress tensor. See 'A' under (Eq. 4) in Adami et al. 2013.""" return jnp.outer(rho * u, v - u)
[docs] def acceleration_standard_fn_wrapper(kernel_fn): """Standard SPH acceleration according to Adami et al. 2012.""" def acceleration_standard_fn( r_ij, d_ij, rho_i, rho_j, u_i, u_j, v_i, v_j, m_i, m_j, eta_i, eta_j, p_i, p_j, ): # (Eq. 6) - inter-particle-averaged shear viscosity (harmonic mean) eta_ij = 2 * eta_i * eta_j / (eta_i + eta_j + EPS) # (Eq. 7) - density-weighted pressure (weighted arithmetic mean) p_ij = (rho_j * p_i + rho_i * p_j) / (rho_i + rho_j) # compute the common prefactor `_c` _weighted_volume = ((m_i / rho_i) ** 2 + (m_j / rho_j) ** 2) / m_i _kernel_grad = kernel_fn.grad_w(d_ij) _c = _weighted_volume * _kernel_grad / (d_ij + EPS) # (Eq. 8): \boldsymbol{e}_{ij} is computed as r_ij/d_ij here. _A = (tvf_stress_fn(rho_i, u_i, v_i) + tvf_stress_fn(rho_j, u_j, v_j)) / 2 _u_ij = u_i - u_j a_eq_8 = _c * (-p_ij * r_ij + jnp.dot(_A, r_ij) + eta_ij * _u_ij) return a_eq_8 return acceleration_standard_fn
[docs] def acceleration_riemann_fn_wrapper(kernel_fn, eos, beta_fn, eta_limiter): """Riemann solver acceleration according to Zhang et al. 2017.""" def acceleration_fn_riemann( e_s, r_ij, d_ij, rho_i, rho_j, m_j, m_i, u_i, u_j, p_i, p_j, eta_i, eta_j, wall_mask_j, mask, n_w_j, g_ext_i, u_tilde_j, ): # Compute unit vector, above eq. (6), Zhang (2017) e_ij = e_s # Compute kernel gradient kernel_part_diff = kernel_fn.grad_w(d_ij) kernel_grad = kernel_part_diff * (e_ij) # Compute average states eq. (6)/(12)/(13), Zhang (2017) u_L = jnp.where( jnp.isin(wall_mask_j, wall_tags), jnp.dot(u_i, -n_w_j), jnp.dot(u_i, -e_ij) ) p_L = p_i rho_L = rho_i # u_w from eq. (15), Yang (2020) u_R = jnp.where( jnp.isin(wall_mask_j, wall_tags), -u_L + 2 * jnp.dot(u_j, n_w_j), jnp.dot(u_j, -e_ij), ) p_R = jnp.where( jnp.isin(wall_mask_j, wall_tags), p_L + rho_L * jnp.dot(g_ext_i, -r_ij), p_j ) rho_R = jnp.where(jnp.isin(wall_mask_j, wall_tags), eos.rho_fn(p_R), rho_j) P_avg = (p_L + p_R) / 2 rho_avg = (rho_L + rho_R) / 2 # Compute inter-particle-averaged shear viscosity (harmonic mean) # eq. (6), Adami (2013) eta_ij = 2 * eta_i * eta_j / (eta_i + eta_j + EPS) # Compute Riemann states eq. (7) and (10), Zhang (2017) P_star = P_avg + 0.5 * rho_avg * (u_L - u_R) * beta_fn(u_L, u_R, eta_limiter) # pressure term with linear Riemann solver eq. (9), Zhang (2017) eq_9 = -2 * m_j * (P_star / (rho_i * rho_j)) * kernel_grad # viscosity term eq. (6), Zhang (2019) u_d = 2 * u_j - u_tilde_j v_ij = jnp.where( jnp.isin(wall_mask_j, wall_tags), u_i - u_d, u_i - u_j, ) eq_6 = 2 * m_j * eta_ij / (rho_i * rho_j) * v_ij / (d_ij + EPS) eq_6 *= kernel_part_diff * mask # compute the prefactor `_c` _weighted_volume = ((m_i / rho_i) ** 2 + (m_j / rho_j) ** 2) / m_i _kernel_grad = kernel_fn.grad_w(d_ij) _c = _weighted_volume * _kernel_grad / (d_ij + EPS) _A = jnp.where( jnp.isin(wall_mask_j, wall_tags), (tvf_stress_fn(rho_i, u_i, u_i) + tvf_stress_fn(rho_j, u_d, u_d)) / 2, (tvf_stress_fn(rho_i, u_i, u_i) + tvf_stress_fn(rho_j, u_j, u_j)) / 2, ) a_eq_8 = _c * jnp.dot(_A, r_ij) return eq_9 + eq_6 + a_eq_8 return acceleration_fn_riemann
[docs] def artificial_viscosity_fn_wrapper(dx, artificial_alpha, u_ref=1.0): """Artificial viscosity according to Adami et al. 2012.""" h_ab = dx # if only artificial viscosity is used, then the following applies # nu = alpha * h * c_ab / 2 / (dim+2) # = 0.1 * 0.02 * 10*1 /2/4= 0.0025 # TODO: parse reference parameters from case setup c_ab = 10.0 * u_ref # c_ref def artificial_viscosity_fn( rho, mass, u, tag, i_s, j_s, dr_i_j, dist, grad_w_dist, N ): rho_ab = (rho[i_s] + rho[j_s]) / 2 numerator = mass[j_s] * artificial_alpha * h_ab * c_ab numerator = numerator * ((u[i_s] - u[j_s]) * dr_i_j).sum(axis=1) numerator = numerator[:, None] * grad_w_dist denominator = (rho_ab * (dist**2 + 0.01 * h_ab**2))[:, None] mask_fluid = tag == Tag.FLUID mask_fluid_edges = mask_fluid[j_s] * mask_fluid[i_s] res = mask_fluid_edges[:, None] * numerator / denominator dudt_artif = ops.segment_sum(res, i_s, N) return dudt_artif return artificial_viscosity_fn
[docs] def gwbc_fn_wrapper(is_free_slip, is_heat_conduction, eos): """Enforce wall boundary conditions by treating boundary particles in a special way. If solid walls -> apply BC tricks Update dummy particles before acceleration computation (if appl.). Steps for boundary particles: - sum pressure over fluid with sheparding - inverse EoS for density - sum velocity over fluid with sheparding and * (-1) - if free-slip: project velocity onto normal vector - subtract that from 2 * u_wall - keeps lid intact Based on: "A generalized wall boundary condition for smoothed particle hydrodynamics", Adami, Hu, Adams, 2012 """ def gwbc_fn(temperature, rho, tag, u, v, p, g_ext, i_s, j_s, w_dist, dr_i_j, nw, N): mask_bc = jnp.isin(tag, wall_tags) def no_slip_bc_fn(x): # for boundary particles, sum over fluid velocities x_wall_unnorm = ops.segment_sum(w_j_s_fluid[:, None] * x[j_s], i_s, N) # eq. 22 from "A Generalized Wall boundary condition for SPH", 2012 x_wall = x_wall_unnorm / (w_i_sum_wf[:, None] + EPS) # eq. 23 from same paper x = jnp.where(mask_bc[:, None], 2 * x - x_wall, x) return x def free_slip_bc_fn(x, wall_inner_normals): # # normal vectors pointing from fluid to wall # (1) implement via summing over fluid particles # wall_inner = ops.segment_sum(dr_i_j * mask_j_s_fluid[:, None], i_s, N) # # (2) implement using color gradient. Requires 2*rc thick wall # # wall_inner = - ops.segment_sum(dr_i_j*mask_j_s_wall[:, None], i_s, N) # normalization = jnp.sqrt((wall_inner**2).sum(axis=1, keepdims=True)) # wall_inner_normals = wall_inner / (normalization + EPS) # wall_inner_normals = jnp.where(mask_bc[:, None], wall_inner_normals, 0.0) # for boundary particles, sum over fluid velocities x_wall_unnorm = ops.segment_sum(w_j_s_fluid[:, None] * x[j_s], i_s, N) # eq. 22 from "A Generalized Wall boundary condition for SPH", 2012 x_wall = x_wall_unnorm / (w_i_sum_wf[:, None] + EPS) x_wall = wall_inner_normals * (x_wall * wall_inner_normals).sum( axis=1, keepdims=True ) # eq. 23 from same paper x = jnp.where(mask_bc[:, None], 2 * x - x_wall, x) return x # require operations with sender fluid and receiver wall/lid mask_j_s_fluid = jnp.where(tag[j_s] == Tag.FLUID, 1.0, 0.0) w_j_s_fluid = w_dist * mask_j_s_fluid # sheparding denominator w_i_sum_wf = ops.segment_sum(w_j_s_fluid, i_s, N) if is_free_slip: # free-slip boundary - ignore viscous interactions with wall u = free_slip_bc_fn(u, -nw) v = free_slip_bc_fn(v, -nw) else: # no-slip boundary condition u = no_slip_bc_fn(u) v = no_slip_bc_fn(v) # eq. 27 from "A Generalized Wall boundary condition for SPH", 2012 # fluid pressure term p_wall_unnorm = ops.segment_sum(w_j_s_fluid * p[j_s], i_s, N) # external fluid acceleration term rho_wf_sum = (rho[j_s] * w_j_s_fluid)[:, None] * dr_i_j rho_wf_sum = ops.segment_sum(rho_wf_sum, i_s, N) p_wall_ext = (g_ext * rho_wf_sum).sum(axis=1) # normalize with sheparding p_wall = (p_wall_unnorm + p_wall_ext) / (w_i_sum_wf + EPS) p = jnp.where(mask_bc, p_wall, p) rho = vmap(eos.rho_fn)(p) if is_heat_conduction: # wall particles without temperature boundary condition obtain the adjacent # fluid temperature t_wall_unnorm = ops.segment_sum(w_j_s_fluid * temperature[j_s], i_s, N) t_wall = t_wall_unnorm / (w_i_sum_wf + EPS) mask = jnp.isin(tag, jnp.array([Tag.SOLID_WALL, Tag.MOVING_WALL])) t_wall = jnp.where(mask, t_wall, temperature) temperature = t_wall return p, rho, u, v, temperature return gwbc_fn
[docs] def gwbc_fn_riemann_wrapper(is_free_slip, is_heat_conduction): """Riemann solver boundary condition for wall particles.""" if is_free_slip: def free_weight(fluid_mask_i, tag_i): return fluid_mask_i def riemann_velocities(u, w_dist, fluid_mask, i_s, j_s, N): return u else: def free_weight(fluid_mask_i, tag_i): return jnp.ones_like(tag_i) def riemann_velocities(u, w_dist, fluid_mask, i_s, j_s, N): w_dist_fluid = w_dist * fluid_mask[j_s] u_wall_nom = ops.segment_sum(w_dist_fluid[:, None] * u[j_s], i_s, N) u_wall_denom = ops.segment_sum(w_dist_fluid, i_s, N) u_tilde = u_wall_nom / (u_wall_denom[:, None] + EPS) return u_tilde if is_heat_conduction: def heat_bc(mask_j_s_fluid, w_dist, temperature, i_s, j_s, tag, N): w_j_s_fluid = w_dist * mask_j_s_fluid # sheparding denominator w_i_sum_wf = ops.segment_sum(w_j_s_fluid, i_s, N) t_wall_unnorm = ops.segment_sum(w_j_s_fluid * temperature[j_s], i_s, N) t_wall = t_wall_unnorm / (w_i_sum_wf + EPS) mask = jnp.isin(tag, jnp.array([Tag.SOLID_WALL, Tag.MOVING_WALL])) t_wall = jnp.where(mask, t_wall, temperature) temperature = t_wall return temperature else: def heat_bc(mask_j_s_fluid, w_dist, temperature, i_s, j_s, tag, N): return temperature return free_weight, riemann_velocities, heat_bc
[docs] def limiter_fn_wrapper(eta_limiter, c_ref): """if eta_limiter != -1, introduce dissipation limiter eq. (11), Zhang (2017).""" if eta_limiter == -1: def beta_fn(u_L, u_R, eta_limiter): return c_ref else: def beta_fn(u_L, u_R, eta_limiter): temp = eta_limiter * jnp.maximum(u_L - u_R, jnp.zeros_like(u_L)) beta = jnp.minimum(temp, jnp.full_like(temp, c_ref)) return beta return beta_fn
[docs] def temperature_derivative_wrapper(kernel_fn): """Temperature derivative according to Cleary 1998.""" def temperature_derivative( e_s, r_ij, d_ij, rho_i, rho_j, m_j, kappa_i, kappa_j, Cp_i, T_i, T_j ): e_ij = e_s _kernel_grad = kernel_fn.grad_w(d_ij) _kernel_grad_vector = _kernel_grad * e_ij _effective_kappa = (kappa_i * kappa_j) / (kappa_i + kappa_j) F_ab = jnp.dot(r_ij, _kernel_grad_vector) / ((d_ij * d_ij) + EPS) # scalar dTdt = (4 * m_j * _effective_kappa * (T_i - T_j) * F_ab) / ( Cp_i * rho_i * rho_j ) return dTdt return temperature_derivative
[docs] class WCSPH: """Weakly compressible SPH solver with transport velocity formulation.""" def __init__( self, displacement_fn: Callable, eos: Union[TaitEoS, RIEMANNEoS], g_ext_fn: Callable, dx: float, dim: int, dt: float, c_ref: float, eta_limiter: float = 3, solver: str = "SPH", kernel: str = "QSK", is_bc_trick: bool = False, is_rho_evol: bool = False, artificial_alpha: float = 0.0, is_free_slip: bool = False, is_rho_renorm: bool = False, is_heat_conduction: bool = False, ): self.displacement_fn = displacement_fn self.solver = solver self.g_ext_fn = g_ext_fn self.is_bc_trick = is_bc_trick self.is_rho_evol = is_rho_evol self.is_rho_renorm = is_rho_renorm self.dt = dt self.eos = eos self.artificial_alpha = artificial_alpha self.is_heat_conduction = is_heat_conduction _beta_fn = limiter_fn_wrapper(eta_limiter, c_ref) match kernel: case "CSK": self._kernel_fn = CubicKernel(h=dx, dim=dim) case "QSK": self._kernel_fn = QuinticKernel(h=dx, dim=dim) case "WC2K": self._kernel_fn = WendlandC2Kernel(h=1.3 * dx, dim=dim) case "WC4K": self._kernel_fn = WendlandC4Kernel(h=1.3 * dx, dim=dim) case "WC6K": self._kernel_fn = WendlandC6Kernel(h=1.3 * dx, dim=dim) case "GK": self._kernel_fn = GaussianKernel(h=dx, dim=dim) case "SGK": self._kernel_fn = SuperGaussianKernel(h=dx, dim=dim) self._gwbc_fn = gwbc_fn_wrapper(is_free_slip, is_heat_conduction, eos) ( self._free_weight, self._riemann_velocities, self._heat_bc, ) = gwbc_fn_riemann_wrapper(is_free_slip, is_heat_conduction) self._acceleration_tvf_fn = acceleration_tvf_fn_wrapper(self._kernel_fn) self._acceleration_riemann_fn = acceleration_riemann_fn_wrapper( self._kernel_fn, eos, _beta_fn, eta_limiter ) self._acceleration_fn = acceleration_standard_fn_wrapper(self._kernel_fn) self._artificial_viscosity_fn = artificial_viscosity_fn_wrapper( dx, artificial_alpha ) self._wall_phi_vec = wall_phi_vec_wrapper(self._kernel_fn) self._rho_evol_riemann_fn = rho_evol_riemann_fn_wrapper( self._kernel_fn, eos, c_ref ) self._temperature_derivative = temperature_derivative_wrapper(self._kernel_fn)
[docs] def forward_wrapper(self): """Wrapper of update step of SPH.""" def forward(state, neighbors): """Update step of SPH solver. Args: state (dict): Flow fields and particle properties. neighbors (_type_): Neighbors object. """ r, tag, mass, eta = state["r"], state["tag"], state["mass"], state["eta"] u, v, dudt, dvdt = state["u"], state["v"], state["dudt"], state["dvdt"] rho, drhodt, p = state["rho"], state["drhodt"], state["p"] nw, kappa, Cp = state["nw"], state["kappa"], state["Cp"] temperature, dTdt = state["T"], state["dTdt"] N = len(r) # precompute displacements `dr` and distances `dist` # the second vector is sorted i_s, j_s = neighbors.idx r_i_s, r_j_s = r[i_s], r[j_s] dr_i_j = vmap(self.displacement_fn)(r_i_s, r_j_s) dist = space.distance(dr_i_j) w_dist = vmap(self._kernel_fn.w)(dist) e_s = dr_i_j / (dist[:, None] + EPS) # currently only for density evolution and with artificial viscosity grad_w_dist_norm = vmap(self._kernel_fn.grad_w)(dist) grad_w_dist = grad_w_dist_norm[:, None] * e_s # external acceleration field g_ext = self.g_ext_fn(r) # e.g. np.array([[0, -1], [0, -1], ...]) # masks wall_mask = jnp.where(jnp.isin(tag, wall_tags), 1.0, 0.0) fluid_mask = jnp.where(tag == Tag.FLUID, 1.0, 0.0) ##### Riemann velocity BCs if self.is_bc_trick and (self.solver == "RIE"): u_tilde = self._riemann_velocities(u, w_dist, fluid_mask, i_s, j_s, N) else: u_tilde = u ##### Density summation or evolution # update evolution if self.is_rho_evol and (self.solver == "SPH"): rho, drhodt = rho_evol_fn( rho, mass, u, grad_w_dist, i_s, j_s, self.dt, N ) if self.is_rho_renorm: rho = rho_renorm_fn(rho, mass, i_s, j_s, w_dist, N) elif self.is_rho_evol and (self.solver == "RIE"): temp = vmap(self._rho_evol_riemann_fn)( e_s, rho[i_s], rho[j_s], mass[j_s], u[i_s], u[j_s], p[i_s], p[j_s], dr_i_j, dist, wall_mask[j_s], nw[j_s], g_ext[i_s], u_tilde[j_s], ) drhodt = ops.segment_sum(temp, i_s, N) * fluid_mask rho = rho + self.dt * drhodt if self.is_rho_renorm: rho = rho_renorm_fn(rho, mass, i_s, j_s, w_dist, N) else: rho_ = rho_summation_fn(mass, i_s, w_dist, N) rho = jnp.where(fluid_mask, rho_, rho) ##### Compute primitives # pressure, and background pressure p = vmap(self.eos.p_fn)(rho) background_pressure_tvf = vmap(self.eos.p_fn)(jnp.zeros_like(p)) ##### Apply BC trick if self.is_bc_trick and (self.solver == "SPH"): p, rho, u, v, temperature = self._gwbc_fn( temperature, rho, tag, u, v, p, g_ext, i_s, j_s, w_dist, dr_i_j, nw, N, ) elif self.is_bc_trick and (self.solver == "RIE"): mask = self._free_weight(fluid_mask[i_s], tag[i_s]) temperature = self._heat_bc( fluid_mask[j_s], w_dist, temperature, i_s, j_s, tag, N ) elif (not self.is_bc_trick) and (self.solver == "RIE"): mask = jnp.ones_like(tag[i_s]) ##### compute heat conduction if self.is_heat_conduction: # integrate the incomming temperature derivative temperature += self.dt * dTdt # compute temperature derivative for next step out = vmap(self._temperature_derivative)( e_s, dr_i_j, dist, rho[i_s], rho[j_s], mass[j_s], kappa[i_s], kappa[j_s], Cp[i_s], temperature[i_s], temperature[j_s], ) dTdt = ops.segment_sum(out, i_s, N) ##### Compute RHS if self.solver == "SPH": out = vmap(self._acceleration_fn)( dr_i_j, dist, rho[i_s], rho[j_s], u[i_s], u[j_s], v[i_s], v[j_s], mass[i_s], mass[j_s], eta[i_s], eta[j_s], p[i_s], p[j_s], ) elif self.solver == "RIE": out = vmap(self._acceleration_riemann_fn)( e_s, dr_i_j, dist, rho[i_s], rho[j_s], mass[j_s], mass[i_s], u[i_s], u[j_s], p[i_s], p[j_s], eta[i_s], eta[j_s], wall_mask[j_s], mask, nw[j_s], g_ext[i_s], u_tilde[j_s], ) dudt = ops.segment_sum(out, i_s, N) out_tv = vmap(self._acceleration_tvf_fn)( dr_i_j, dist, rho[i_s], rho[j_s], mass[i_s], mass[j_s], background_pressure_tvf[i_s], ) dvdt = ops.segment_sum(out_tv, i_s, N) ##### Additional things if self.artificial_alpha != 0.0: dudt += self._artificial_viscosity_fn( rho, mass, u, tag, i_s, j_s, dr_i_j, dist, grad_w_dist, N ) state = { "r": r, "tag": tag, "u": u, "v": v, "drhodt": drhodt, "dudt": dudt + g_ext, "dvdt": dvdt, "rho": rho, "p": p, "mass": mass, "eta": eta, "dTdt": dTdt, "T": temperature, "kappa": kappa, "Cp": Cp, "nw": nw, } return state return forward