# Source code for jax_sph.partition

"""Neighbors search backends."""

from functools import partial
from typing import Optional

import jax
import jax.lax as lax
import jax.numpy as jnp
import numpy as np
import numpy as onp
from jax import jit

from jax_sph.jax_md import space
from jax_sph.jax_md.partition import (
    MaskFn,
    NeighborFn,
    NeighborList,
    NeighborListFns,
    NeighborListFormat,
    PartitionError,
    PartitionErrorCode,
    _displacement_or_metric_to_metric_sq,
    _neighboring_cells,
    cell_list,
    is_format_valid,
    is_sparse,
    shift_array,
)
from jax_sph.jax_md.partition import neighbor_list as vmap_neighbor_list

PEC = PartitionErrorCode


def get_particle_cells(idx, cl_capacity, N):
    """Return, for every particle, the number of the cell that contains it.

    The cell-list id buffer is flattened to ``(num_cells, cl_capacity)`` so that
    every cell gets a running number ``0 .. num_cells - 1``. Sorting the flat
    particle ids then tells us at which slot (and therefore in which cell) each
    particle id is stored.

    Args:
        idx: Cell-list id buffer; reshaped to ``(-1, cl_capacity)``.
        cl_capacity: Number of slots per cell.
        N: Number of particles.

    Returns:
        Array of shape ``(N,)`` holding the cell number of each particle.

    Note:
        Only the first ``N`` entries of the sorted ids are used, so this assumes
        that every particle id ``0 .. N - 1`` occurs exactly once in ``idx`` and
        that unused slots hold a value that sorts after them -- TODO confirm
        against the cell-list implementation.
    """
    # one row per cell, one column per slot
    slots = idx.reshape(-1, cl_capacity)
    num_cells = slots.shape[0]

    # cell number of every slot: [0, 0, ..., 1, 1, ..., num_cells - 1, ...]
    cell_of_slot = jnp.repeat(jnp.arange(num_cells), cl_capacity)

    # slot positions ordered by the particle id stored in them
    slot_of_particle = jnp.argsort(slots.reshape(-1))

    return cell_of_slot[slot_of_particle[:N]]
def _scan_neighbor_list(
    displacement_or_metric: space.DisplacementOrMetricFn,
    box: space.Box,
    r_cutoff: float,
    dr_threshold: float = 0.0,
    capacity_multiplier: float = 1.25,
    disable_cell_list: bool = False,
    mask_self: bool = True,
    custom_mask_function: Optional[MaskFn] = None,
    fractional_coordinates: bool = False,
    format: NeighborListFormat = NeighborListFormat.Sparse,
    num_partitions: int = 8,
    **static_kwargs,
) -> NeighborFn:
    """Modified JAX-MD neighbor list function that uses `lax.scan` to compute the
    distance between particles to save memory.

    Original: https://github.com/jax-md/jax-md/blob/main/jax_md/partition.py

    Returns a function that builds a list neighbors for collections of points.

    Neighbor lists must balance the need to be jit compatible with the fact that
    under a jit the maximum number of neighbors cannot change (owing to static
    shape requirements). To deal with this, our `neighbor_list` returns a
    `NeighborListFns` object that contains two functions: 1)
    `neighbor_fn.allocate` create a new neighbor list and 2) `neighbor_fn.update`
    updates an existing neighbor list. Neighbor lists themselves additionally
    have a convenience `update` member function.

    Note that allocation of a new neighbor list cannot be jit compiled since it
    uses the positions to infer the maximum number of neighbors (along with
    additional space specified by the `capacity_multiplier`). Updating the
    neighbor list can be jit compiled; if the neighbor list capacity is not
    sufficient to store all the neighbors, the `did_buffer_overflow` bit will be
    set to `True` and a new neighbor list will need to be reallocated.

    Here is a typical example of a simulation loop with neighbor lists:

    .. code-block:: python

        init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)
        exact_init_fn, exact_apply_fn = simulate.nve(exact_energy_fn, shift, 1e-3)

        nbrs = neighbor_fn.allocate(R)
        state = init_fn(random.PRNGKey(0), R, neighbor_idx=nbrs.idx)

        def body_fn(i, state):
            state, nbrs = state
            nbrs = nbrs.update(state.position)
            state = apply_fn(state, neighbor_idx=nbrs.idx)
            return state, nbrs

        step = 0
        for _ in range(20):
            new_state, nbrs = lax.fori_loop(0, 100, body_fn, (state, nbrs))
            if nbrs.did_buffer_overflow:
                nbrs = neighbor_fn.allocate(state.position)
            else:
                state = new_state
                step += 1

    Args:
        displacement_or_metric: A function `d(R_a, R_b)` that computes the
            displacement (or metric) between pairs of points.
        box: Either a float specifying the size of the box or an array of shape
            `[spatial_dim]` specifying the box size in each spatial dimension.
            Every box length must be larger than three times
            `r_cutoff + dr_threshold` (asserted below).
        r_cutoff: A scalar specifying the neighborhood radius.
        dr_threshold: A scalar specifying the maximum distance particles can move
            before rebuilding the neighbor list.
        capacity_multiplier: A floating point scalar specifying the fractional
            increase in maximum neighborhood occupancy we allocate compared with
            the maximum in the example positions.
        disable_cell_list: Must be `False`; this backend only works with a cell
            list (asserted below).
        mask_self: An optional boolean. Determines whether points can consider
            themselves to be their own neighbors.
        custom_mask_function: Must be `None`; custom masking is not implemented
            for this backend (asserted below).
        fractional_coordinates: Must be `False`; this backend only works with
            real coordinates (asserted below).
        format: The format of the neighbor list; see the
            :meth:`NeighborListFormat` enum. Must be `NeighborListFormat.Sparse`,
            which is also the default (asserted below).
        num_partitions: Number of chunks the particles are split into; the
            candidate-pair distances are evaluated one chunk per `lax.scan` step.
        **static_kwargs: Accepted for signature compatibility with the other
            backends; not read in this function.

    Returns:
        A NeighborListFns object that contains a method to allocate a new
        neighbor list and a method to update an existing neighbor list.
    """
    assert disable_cell_list is False, "Works only with a cell list"
    assert not fractional_coordinates, "Works only with real coordinates"
    assert format == NeighborListFormat.Sparse, "Works only with sparse neighbor list"
    assert custom_mask_function is None, "Custom masking not implemented"

    is_format_valid(format)
    box = lax.stop_gradient(box)
    r_cutoff = lax.stop_gradient(r_cutoff)
    dr_threshold = lax.stop_gradient(dr_threshold)

    box = jnp.float32(box)

    # candidates are collected up to r_cutoff + dr_threshold; the list is only
    # rebuilt once some particle moved further than dr_threshold / 2
    cutoff = r_cutoff + dr_threshold
    cutoff_sq = cutoff**2
    threshold_sq = (dr_threshold / jnp.float32(2)) ** 2
    metric_sq = _displacement_or_metric_to_metric_sq(displacement_or_metric)

    cell_size = cutoff
    assert jnp.all(cell_size < box / 3.0), "Don't use scan with very few cells"

    def neighbor_list_fn(
        position: jnp.ndarray,
        neighbors: Optional[NeighborList] = None,
        extra_capacity: int = 0,
        **kwargs,
    ) -> NeighborList:
        """Allocate (`neighbors is None`) or conditionally rebuild a neighbor list."""

        def neighbor_fn(position_and_error, max_occupancy=None):
            """Build the sparse edge list for `position`.

            `max_occupancy=None` means allocation: the capacity is then derived
            from the number of edges actually found.
            """
            position, err = position_and_error
            N, dim = position.shape
            cl_fn = None
            cl = None
            cell_size = None

            if neighbors is None:
                # cl.shape = (nx, ny, nz, cell_capacity, dim)
                cell_size = cutoff
                cl_fn = cell_list(box, cell_size, capacity_multiplier)
                cl = cl_fn.allocate(position, extra_capacity=extra_capacity)
            else:
                # reuse the cell list function stored on the existing list
                cell_size = neighbors.cell_size
                cl_fn = neighbors.cell_list_fn
                if cl_fn is not None:
                    cl = cl_fn.update(position, neighbors.cell_list_capacity)

            err = err.update(PEC.CELL_LIST_OVERFLOW, cl.did_buffer_overflow)
            cl_capacity = cl.cell_capacity

            idx = cl.id_buffer

            cell_idx = [idx]  # shape: (nx, ny, nz, cell_capacity, 1)

            # append the id buffers of all adjacent cells (the zero shift, i.e.
            # the cell itself, is already in the list)
            for dindex in _neighboring_cells(dim):
                if onp.all(dindex == 0):
                    continue
                cell_idx += [shift_array(idx, dindex)]

            # one row per cell holding the ids stored in it and in its neighbors
            cell_idx = jnp.concatenate(cell_idx, axis=-2)
            cell_idx = jnp.reshape(cell_idx, (-1, cell_idx.shape[-2]))
            num_cells, considered_neighbors = cell_idx.shape

            particle_cells = get_particle_cells(idx, cl_capacity, N)

            d = partial(metric_sq, **kwargs)
            d = space.map_bond(d)

            # number of particles per partition N_sub
            # np.ceil used to pad last partition with < num_partitions entries
            N_sub = int(np.ceil(N / num_partitions))
            num_pad = N_sub * num_partitions - N
            # padded entries get cell number -1 and are masked out in scan_body
            particle_cells = jnp.pad(
                particle_cells,
                (
                    0,
                    num_pad,
                ),
                constant_values=-1,
            )

            # NOTE(review): volumetric_factor is only assigned for dim 2 and 3;
            # any other dimension fails below with an unbound local variable.
            if dim == 2:
                # the area of a circle with r=1/3 is 0.34907
                volumetric_factor = 0.34907
            elif dim == 3:
                # the volume of a sphere with r=1/3 is 0.15514
                volumetric_factor = 0.15514

            # static number of edge slots kept per partition
            num_edges_sub = int(
                N_sub * considered_neighbors * volumetric_factor * capacity_multiplier
            )

            def scan_body(carry, input):
                """Compute neighbors over a subset of particles

                The largest object here is of size (N_sub*considered_neighbors),
                where considered_neighbors in 3D is 27 * cell_capacity.

                Args:
                    carry: Running count of edges found in previous partitions.
                    input: Index of the first particle of this partition.

                Returns:
                    The updated edge count and a tuple of the
                    `(2, num_edges_sub)` edge array (receivers, senders) and a
                    flag telling whether more than `num_edges_sub` edges were
                    found in this partition.
                """
                occupancy = carry
                slice_from = input

                # cell number of each particle in this partition (-1 = padding)
                _entries = lax.dynamic_slice(particle_cells, (slice_from,), (N_sub,))
                # candidate neighbor ids of each particle
                _idx = cell_idx[_entries]

                if mask_self:
                    # replace a particle's own id by the padding value N
                    particle_idx = slice_from + jnp.arange(N_sub)
                    _idx = jnp.where(_idx == particle_idx[:, None], N, _idx)

                if num_pad > 0:
                    # padded particles get no candidates at all
                    _idx = jnp.where(_entries[:, None] != -1, _idx, N)

                sender_idx = (
                    jnp.broadcast_to(
                        jnp.arange(N_sub, dtype="int32")[:, None], _idx.shape
                    )
                    + slice_from
                )
                if num_pad > 0:
                    # padded particles would otherwise get ids larger than N
                    sender_idx = jnp.clip(sender_idx, a_max=N)

                sender_idx = jnp.reshape(sender_idx, (-1,))
                receiver_idx = jnp.reshape(_idx, (-1,))
                dR = d(position[sender_idx], position[receiver_idx])

                # keep pairs within the cutoff whose receiver is a real particle
                mask = (dR < cutoff_sq) & (receiver_idx < N)
                out_idx = N * jnp.ones(receiver_idx.shape, jnp.int32)

                # compaction: the k-th accepted pair is written to slot k - 1;
                # all rejected pairs target slot considered_neighbors * N - 1
                cumsum = jnp.cumsum(mask)
                index = jnp.where(mask, cumsum - 1, considered_neighbors * N - 1)
                receiver_idx = out_idx.at[index].set(receiver_idx)
                sender_idx = out_idx.at[index].set(sender_idx)
                occupancy += cumsum[-1]

                carry = occupancy
                y = jnp.stack(
                    (receiver_idx[:num_edges_sub], sender_idx[:num_edges_sub])
                )
                overflow = cumsum[-1] > num_edges_sub
                return carry, (y, overflow)

            carry = jnp.array(0)
            xs = jnp.array([i * N_sub for i in range(num_partitions)])
            occupancy, (idx, overflows) = lax.scan(
                scan_body, carry, xs, length=num_partitions
            )
            # NOTE(review): an overflow of the per-partition edge buffer is
            # reported under CELL_LIST_OVERFLOW -- confirm this is intended.
            err = err.update(PEC.CELL_LIST_OVERFLOW, overflows.sum())
            # (num_partitions, 2, num_edges_sub) -> (2, num_edges_sub * num_partitions)
            idx = idx.transpose(1, 2, 0).reshape(2, -1)

            # sort to enable pruning later
            ordering = jnp.argsort(idx[1])
            idx = idx[:, ordering]

            if max_occupancy is None:
                # allocation path: derive the capacity from the edges found
                _extra_capacity = N * extra_capacity
                max_occupancy = int(occupancy * capacity_multiplier + _extra_capacity)
                if max_occupancy > idx.shape[-1]:
                    max_occupancy = idx.shape[-1]
                if not is_sparse(format):
                    capacity_limit = N - 1 if mask_self else N
                elif format is NeighborListFormat.Sparse:
                    capacity_limit = N * (N - 1) if mask_self else N**2
                else:
                    capacity_limit = N * (N - 1) // 2
                if max_occupancy > capacity_limit:
                    max_occupancy = capacity_limit
            idx = idx[:, :max_occupancy]
            update_fn = neighbor_list_fn if neighbors is None else neighbors.update_fn
            return NeighborList(
                idx,
                position,
                err.update(PEC.NEIGHBOR_LIST_OVERFLOW, occupancy > max_occupancy),
                cl_capacity,
                max_occupancy,
                format,
                cell_size,
                cl_fn,
                update_fn,
            )  # pytype: disable=wrong-arg-count

        nbrs = neighbors
        if nbrs is None:
            # allocation: start from an empty error state
            return neighbor_fn((position, PartitionError(jnp.zeros((), jnp.uint8))))

        neighbor_fn = partial(neighbor_fn, max_occupancy=nbrs.max_occupancy)

        d = partial(metric_sq, **kwargs)
        d = jax.vmap(d)

        # rebuild only if some particle moved more than dr_threshold / 2 since
        # the reference positions were stored, otherwise return nbrs unchanged
        # NOTE(review): this is the five-argument call form of lax.cond
        # (pred, true_operand, true_fun, false_operand, false_fun) -- verify it
        # is supported by the installed JAX version.
        return lax.cond(
            jnp.any(d(position, nbrs.reference_position) > threshold_sq),
            (position, nbrs.error),
            neighbor_fn,
            nbrs,
            lambda x: x,
        )

    def allocate_fn(
        position: jnp.ndarray, extra_capacity: int = 0, **kwargs
    ) -> NeighborList:
        """Create a new neighbor list from `position` (not jitted)."""
        return neighbor_list_fn(position, extra_capacity=extra_capacity, **kwargs)

    @jit
    def update_fn(
        position: jnp.ndarray, neighbors: NeighborList, **kwargs
    ) -> NeighborList:
        """Update an existing neighbor list, keeping its capacity (jitted)."""
        return neighbor_list_fn(position, neighbors, **kwargs)

    return NeighborListFns(allocate_fn, update_fn)  # pytype: disable=wrong-arg-count


def _matscipy_neighbor_list(
    displacement_or_metric: space.DisplacementOrMetricFn,
    box_size: space.Box,
    r_cutoff: float,
    dr_threshold: float = 0.0,
    capacity_multiplier: float = 1.25,
    disable_cell_list: bool = False,
    mask_self: bool = True,
    custom_mask_function: Optional[MaskFn] = None,
    fractional_coordinates: bool = False,
    format: NeighborListFormat = NeighborListFormat.Dense,
    **static_kwargs,
) -> NeighborFn:
    """Neighbor list backend that delegates the search to matscipy.

    The edge list is computed by `matscipy.neighbours.neighbour_list`, directly
    in `allocate` and through `jax.pure_callback` in the jitted `update`. Both
    returned functions require a `num_particles` keyword argument; only the
    first `num_particles` rows of `position` are passed to matscipy.

    Args:
        displacement_or_metric: Not read by this backend.
        box_size: Array with 2 or 3 box lengths. A 2D box is extended by a third
            length of 1.0 and the lengths are placed on the diagonal of a 3x3
            cell matrix.
        r_cutoff: Cutoff radius handed to matscipy.
        dr_threshold: Not read by this backend.
        capacity_multiplier: Scales the number of allocated edge slots.
        disable_cell_list: Not read by this backend.
        mask_self: If `False`, self connections are prepended to the edge list.
        custom_mask_function: Not read by this backend.
        fractional_coordinates: Not read by this backend.
        format: Not read by this backend.
        **static_kwargs: Must contain `pbc` (per-dimension periodicity flags, a
            missing third entry is filled with `False`) and `num_particles_max`.

    Returns:
        A NeighborListFns object holding the allocate and update functions.
    """
    pbc = static_kwargs["pbc"]
    num_particles_max = static_kwargs["num_particles_max"]

    # imported lazily so that matscipy is only needed when this backend is used
    from matscipy.neighbours import neighbour_list as matscipy_nl

    assert box_size.ndim == 1 and (len(box_size) in [2, 3])
    if box_size.shape == (2,):
        box_size = np.pad(box_size, (0, 1), mode="constant", constant_values=1.0)
    if box_size.shape != (3, 3):
        box_size = np.diag(box_size)
    if len(pbc) == 2:
        pbc = np.pad(pbc, (0, 1), mode="constant", constant_values=False)
    else:
        pbc = np.asarray(pbc, dtype=bool)

    dtype_idx = jnp.arange(0).dtype  # just to get the correct dtype

    def matscipy_wrapper(position, idx_shape, num_particles):
        """Host-side callback: run matscipy and fit the result into `idx_shape`."""
        position = position[:num_particles]

        if position.shape[1] == 2:
            # add a constant third coordinate to 2D positions
            position = np.pad(
                position, ((0, 0), (0, 1)), mode="constant", constant_values=0.5
            )

        edge_list = matscipy_nl(
            "ij", cutoff=r_cutoff, positions=position, cell=box_size, pbc=pbc
        )
        edge_list = np.asarray(edge_list, dtype=dtype_idx)
        if not mask_self:
            # add self connection, which matscipy does not do
            self_connect = np.arange(num_particles, dtype=dtype_idx)
            self_connect = np.array([self_connect, self_connect])
            edge_list = np.concatenate((self_connect, edge_list), axis=-1)

        if edge_list.shape[1] > idx_shape[1]:
            # overflow true case: truncate and raise the overflow flag
            idx_new = np.asarray(edge_list[:, : idx_shape[1]])
            buffer_overflow = np.array(True)
        else:
            # unused slots are filled with num_particles_max
            idx_new = np.ones(idx_shape, dtype=dtype_idx) * num_particles_max
            idx_new[:, : edge_list.shape[1]] = edge_list
            buffer_overflow = np.array(False)

        return idx_new, buffer_overflow

    @jax.jit
    def update_fn(
        position: jnp.ndarray, neighbors: NeighborList, **kwargs
    ) -> NeighborList:
        """Recompute the edge list with the same shape as `neighbors.idx`."""
        num_particles = kwargs["num_particles"]

        # result shapes/dtypes have to be declared up front for pure_callback
        shape_edgelist = jax.ShapeDtypeStruct(
            neighbors.idx.shape, dtype=neighbors.idx.dtype
        )
        shape_overflow = jax.ShapeDtypeStruct((), dtype=bool)
        shape_out = (shape_edgelist, shape_overflow)
        idx, buffer_overflow = jax.pure_callback(
            matscipy_wrapper, shape_out, position, neighbors.idx.shape, num_particles
        )

        return NeighborList(
            idx,
            position,
            neighbors.error.update(PEC.NEIGHBOR_LIST_OVERFLOW, buffer_overflow),
            None,
            None,
            None,
            None,
            None,
            update_fn,
        )

    def allocate_fn(
        position: jnp.ndarray, extra_capacity: int = 0, **kwargs
    ) -> NeighborList:
        """Create a new neighbor list sized from the edges found in `position`."""
        num_particles = kwargs["num_particles"]
        # NOTE(review): the sliced (and, in 2D, padded) positions are stored in
        # the returned NeighborList, whereas update_fn stores `position` as
        # passed in -- confirm this difference is intended.
        position = position[:num_particles]

        if position.shape[1] == 2:
            # add a constant third coordinate to 2D positions
            position = np.pad(
                position, ((0, 0), (0, 1)), mode="constant", constant_values=0.5
            )

        edge_list = matscipy_nl(
            "ij", cutoff=r_cutoff, positions=position, cell=box_size, pbc=pbc
        )
        edge_list = jnp.asarray(edge_list, dtype=dtype_idx)
        if not mask_self:
            # add self connection, which matscipy does not do
            self_connect = jnp.arange(num_particles, dtype=dtype_idx)
            self_connect = jnp.array([self_connect, self_connect])
            edge_list = jnp.concatenate((self_connect, edge_list), axis=-1)

        # in case this is a (2,M) pair list, we pad with N and capacity_multiplier
        # NOTE(review): unused slots are filled with num_particles here, while
        # matscipy_wrapper fills them with num_particles_max -- verify.
        factor = capacity_multiplier * num_particles_max / num_particles
        res = num_particles * jnp.ones(
            (2, round(edge_list.shape[1] * factor + extra_capacity)),
            dtype_idx,
        )
        res = res.at[:, : edge_list.shape[1]].set(edge_list)
        return NeighborList(
            res,
            position,
            PartitionError(jnp.zeros((), jnp.uint8)),
            None,
            None,
            None,
            None,
            None,
            update_fn,
        )

    return NeighborListFns(allocate_fn, update_fn)


# maps the `backend` name accepted by `neighbor_list` to its implementation
BACKENDS = {
    "jaxmd_vmap": vmap_neighbor_list,
    "jaxmd_scan": _scan_neighbor_list,
    "matscipy": _matscipy_neighbor_list,
}
def neighbor_list(
    displacement_or_metric: space.DisplacementOrMetricFn,
    box_size: space.Box,
    r_cutoff: float,
    backend: str = "jaxmd_vmap",
    dr_threshold: float = 0.0,
    capacity_multiplier: float = 1.25,
    disable_cell_list: bool = False,
    mask_self: bool = True,
    custom_mask_function: Optional[MaskFn] = None,
    fractional_coordinates: bool = False,
    format: NeighborListFormat = NeighborListFormat.Sparse,
    num_particles_max: int = None,
    num_partitions: int = 1,
    pbc: jnp.ndarray = None,
) -> NeighborFn:
    """Neighbor lists wrapper. Its arguments are mainly based on the jax-md ones.

    Looks up ``backend`` in ``BACKENDS`` and forwards all arguments to the
    selected implementation.

    Args:
        displacement_or_metric: A function `d(R_a, R_b)` that computes the
            displacement between pairs of points.
        box_size: Either a float specifying the size of the box or an array of
            shape `[spatial_dim]` specifying the box size in each spatial
            dimension.
        r_cutoff: A scalar specifying the neighborhood radius.
        backend: The backend to use. Can be one of:
            1) ``jaxmd_vmap`` - the default jax-md neighbor list which vectorizes
            the computations.
            2) ``jaxmd_scan`` - a modified jax-md neighbor list which serializes
            the search into ``num_partitions`` chunks to improve the memory
            efficiency.
            3) ``matscipy`` - a jit-able implementation with the matscipy
            neighbor list backend, which runs on CPU and takes variable number of
            particles smaller or equal to ``num_particles_max``.
        dr_threshold: A scalar specifying the maximum distance particles can move
            before rebuilding the neighbor list.
        capacity_multiplier: A floating point scalar specifying the fractional
            increase in maximum neighborhood occupancy we allocate compared with
            the maximum in the example positions.
        disable_cell_list: An optional boolean. If set to `True` then the
            neighbor list is constructed using only distances. This can be useful
            for debugging but should generally be left as `False`.
        mask_self: An optional boolean. Determines whether points can consider
            themselves to be their own neighbors.
        custom_mask_function: An optional function. Takes the neighbor array and
            masks selected elements. Note: The input array to the function is
            `(n_particles, m)` where the index of particle 1 is in index in the
            first dimension of the array, the index of particle 2 is given by the
            value in the array.
        fractional_coordinates: An optional boolean. Specifies whether positions
            will be supplied in fractional coordinates in the unit cube,
            :math:`[0, 1]^d`. If this is set to True then the `box_size` will be
            set to `1.0` and the cell size used in the cell list will be set to
            `cutoff / box_size`.
        format: The format of the neighbor list; see the
            :meth:`NeighborListFormat` enum for details about the different
            choices for formats. Defaults to `Sparse`.
        num_particles_max: only used with the ``matscipy`` backend. Based on the
            largest particles system in a dataset.
        num_partitions: only used with the ``jaxmd_scan`` backend.
        pbc: only used with the ``matscipy`` backend. Defines the boundary
            conditions for each dimension individually. Can have shape (2,) or
            (3,).

    Returns:
        A NeighborListFns object that contains a method to allocate a new
        neighbor list and a method to update an existing neighbor list.
    """
    assert backend in BACKENDS, f"Unknown backend {backend}"

    build_fn = BACKENDS[backend]

    # arguments shared by all backends, passed by position
    common_args = (
        displacement_or_metric,
        box_size,
        r_cutoff,
        dr_threshold,
        capacity_multiplier,
        disable_cell_list,
        mask_self,
        custom_mask_function,
        fractional_coordinates,
        format,
    )
    # backend-specific options, always passed by keyword
    backend_options = {
        "num_particles_max": num_particles_max,
        "num_partitions": num_partitions,
        "pbc": pbc,
    }
    return build_fn(*common_args, **backend_options)
if __name__ == "__main__":
    # Nothing is executed when this module is run as a script. The
    # commented-out snippets below are leftover scratch code and reference names
    # that are not defined in this module (e.g. `matscipy_nl`, `Sparse`, `d`).

    # edge_list = matscipy_nl(
    #     "ij",
    #     cutoff=0.45,
    #     positions=np.array([[0.1, 0.1, 0.1], [0.5, 0.1, 0.1], [0.5, 0.5, 0.1]]),
    #     cell=np.array([[1.0, 0.0, 0.0],[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]),
    #     pbc=np.array([True, True, True])
    # )

    # box_size = np.array([1.0, 1.0])
    # r = np.array([[0.1, 0.1], [0.5, 0.1]])
    # displacement_fn, shift_fn = space.periodic(side=box_size)
    # r_cutoff = 0.45
    # neighbor_fn = neighbor_list(
    #     displacement_fn,
    #     box_size,
    #     r_cutoff=r_cutoff,
    #     backend="jaxmd_vmap",
    #     dr_threshold=r_cutoff * 0.25,
    #     capacity_multiplier=1.25,
    #     mask_self=False,
    #     format=Sparse,
    # )
    # neighbors = neighbor_fn.allocate(r)
    # neighbors = neighbors.update(r)
    # neighbors = neighbors.update(r)

    # a = np.array([1,2,3])
    # f = lambda x: x**2

    # def distance_fn(x, y):
    #     return lax.scan(lambda _, x: (None, d(*x)), None, (x, y))[1]

    # def scan(f, init, xs, length=None):
    #     if xs is None:
    #         xs = [None] * length
    #     carry = init
    #     ys = []
    #     for x in xs:
    #         carry, y = f(carry, x)
    #         ys.append(y)
    #     return carry, np.stack(ys)

    pass