# Source code for shepherd_score.alignment._jax_parallel

"""
Parallel (multi-device) JAX alignment via jax.shard_map.

Each public ``optimize_*_jax_*_shmap`` function accepts flat ``(total, ...)``
arrays (where ``total`` is the number of pairs padded to a multiple of
``len(jax.devices())``) and returns flat ``(total, ...)`` results.  The
leading axis is automatically distributed across devices by
``jax.shard_map`` with ``PartitionSpec('i')``; do **not** pre-reshape to
``(n_devices, B, ...)``.

``XLA_FLAGS=--xla_force_host_platform_device_count=N`` must be set **before
any JAX import** so that ``len(jax.devices()) == N``.
"""
import numpy as np

import jax
from jax import jit, vmap
from jax.sharding import Mesh, PartitionSpec as P

from shepherd_score.alignment._jax import (
    _per_pair_optimize_vol_mask_scan,
    _per_pair_optimize_vol_esp_mask_scan,
    _per_pair_optimize_surf_scan,
    _per_pair_optimize_surf_esp_scan,
    _per_pair_optimize_pharm_mask_scan_factory,
)

# ---------------------------------------------------------------------------
# Volumetric alignment
# ---------------------------------------------------------------------------

# Maps (max_num_steps, n_devices) -> the jit-compiled shard_map function that
# _make_shmap_vol_fn built for that key; filled lazily by the public wrapper.
_shmap_vol_cache: dict = {}


def _make_shmap_vol_fn(max_num_steps: int, mesh: Mesh):
    """Return a ``jit(shard_map(vmap(...)))`` callable for volumetric alignment.

    ``max_num_steps`` is captured by closure, so ``lax.scan`` inside the
    per-pair kernel receives a static Python int. The shard_map is wrapped in
    ``jit`` so that XLA compiles the whole program once (shard_map alone does
    not trigger compilation the way pmap does).

    Positional call signature of the returned function::

        (ref, fit, mask_ref, mask_fit, se3_init, alpha, VAA, VBB, lr)

    ``alpha`` and ``lr`` are replicated on every device; all other arguments
    are split along their leading axis over mesh axis ``'i'``.
    """
    shard = P('i')  # split the leading axis across devices
    repl = P()      # replicate the whole value on every device

    def _shard_body(ref_b, fit_b, mask_r_b, mask_f_b, se3_b, alpha, VAA_b, VBB_b, lr):
        # Runs on a single device's slice of the batch.
        def _one_pair(ref, fit, mask_r, mask_f, se3, vaa, vbb):
            return _per_pair_optimize_vol_mask_scan(
                ref, fit, mask_r, mask_f, se3, alpha, vaa, vbb, lr, max_num_steps
            )

        batched = vmap(_one_pair)
        return batched(ref_b, fit_b, mask_r_b, mask_f_b, se3_b, VAA_b, VBB_b)

    in_specs = (shard,) * 5 + (repl, shard, shard, repl)
    sharded_fn = jax.shard_map(
        _shard_body,
        mesh=mesh,
        in_specs=in_specs,
        out_specs=(shard,) * 3,
        check_vma=False,
    )
    return jit(sharded_fn)


def optimize_ROCS_overlay_jax_vol_shmap(
    ref_batch,
    fit_batch,
    mask_ref_batch,
    mask_fit_batch,
    VAA_batch,
    VBB_batch,
    se3_init_batch,
    alpha: float,
    lr: float,
    max_num_steps: int,
):
    """Volumetric alignment via ``shard_map`` + ``vmap`` across virtual CPU devices.

    All ``*_batch`` arrays use a flat leading axis of size ``total`` (i.e. the
    number of pairs padded to a multiple of ``len(jax.devices())``). Unlike
    ``pmap``, ``shard_map`` automatically distributes the flat leading axis
    across devices; do **not** pre-reshape to ``(n_devices, B, ...)``.

    Pre-compute self-overlaps ``VAA``/``VBB`` and SE(3) initialisations outside
    this function (they are invariant to the optimisation loop).

    ``XLA_FLAGS=--xla_force_host_platform_device_count=N`` must be set **before
    JAX is first imported** so that ``len(jax.devices()) == N``.

    Parameters
    ----------
    ref_batch : (total, N, 3) padded reference positions
    fit_batch : (total, M, 3) padded fit positions
    mask_ref_batch : (total, N)
    mask_fit_batch : (total, M)
    VAA_batch : (total,) pre-computed ref self-overlaps
    VBB_batch : (total,) pre-computed fit self-overlaps
    se3_init_batch : (total, R, 7) pre-initialised SE(3) params
    alpha : float
    lr : float
    max_num_steps : int (Python int; determines compiled kernel)

    Returns
    -------
    aligned_pts : (total, M, 3)
    se3_transform : (total, 4, 4)
    scores : (total,)

    Raises
    ------
    ValueError
        If the leading axis of ``ref_batch`` is not a multiple of
        ``len(jax.devices())``.
    """
    devices = jax.devices()
    n_devices = len(devices)

    # Fail early with an actionable message instead of an opaque sharding error.
    total = np.shape(ref_batch)[0]
    if total % n_devices:
        raise ValueError(
            f"Leading axis of ref_batch has size {total}, which is not a multiple "
            f"of the number of JAX devices ({n_devices}); pad the batch first."
        )

    cache_key = (max_num_steps, n_devices)
    fn = _shmap_vol_cache.get(cache_key)
    if fn is None:
        # Only build a Mesh when a new function must be compiled; the cached
        # function already closes over the mesh it was built with.
        mesh = Mesh(np.array(devices), axis_names=('i',))
        fn = _make_shmap_vol_fn(max_num_steps, mesh)
        _shmap_vol_cache[cache_key] = fn

    return fn(
        ref_batch, fit_batch, mask_ref_batch, mask_fit_batch,
        se3_init_batch, alpha, VAA_batch, VBB_batch, lr,
    )
# ---------------------------------------------------------------------------
# Masked volumetric ESP alignment
# ---------------------------------------------------------------------------

# Maps (max_num_steps, n_devices) -> compiled function from _make_shmap_vol_esp_fn.
_shmap_vol_esp_cache: dict = {}


def _make_shmap_vol_esp_fn(max_num_steps: int, mesh: Mesh):
    """Return a ``jit(shard_map(vmap(...)))`` callable for masked volumetric ESP alignment.

    Positional call signature of the returned function::

        (ref, fit, ref_charges, fit_charges, mask_ref, mask_fit, se3_init,
         alpha, lam, VAA, VBB, lr)

    ``alpha``, ``lam`` and ``lr`` are replicated on every device; everything
    else is split along its leading axis over mesh axis ``'i'``.
    ``max_num_steps`` is closed over so it stays a static Python int.
    """
    shard = P('i')
    repl = P()

    def _shard_body(ref_b, fit_b, ref_ch_b, fit_ch_b, mask_r_b, mask_f_b, se3_b,
                    alpha, lam, VAA_b, VBB_b, lr):
        # Runs on a single device's slice of the batch.
        def _one_pair(ref, fit, ref_ch, fit_ch, mask_r, mask_f, se3, vaa, vbb):
            return _per_pair_optimize_vol_esp_mask_scan(
                ref, fit, ref_ch, fit_ch, mask_r, mask_f, se3,
                alpha, lam, vaa, vbb, lr, max_num_steps,
            )

        batched = vmap(_one_pair)
        return batched(
            ref_b, fit_b, ref_ch_b, fit_ch_b, mask_r_b, mask_f_b, se3_b, VAA_b, VBB_b
        )

    in_specs = (shard,) * 7 + (repl, repl, shard, shard, repl)
    sharded_fn = jax.shard_map(
        _shard_body,
        mesh=mesh,
        in_specs=in_specs,
        out_specs=(shard,) * 3,
        check_vma=False,
    )
    return jit(sharded_fn)
def optimize_ROCS_esp_overlay_jax_vol_esp_shmap(
    ref_batch,
    fit_batch,
    ref_charges_batch,
    fit_charges_batch,
    mask_ref_batch,
    mask_fit_batch,
    VAA_batch,
    VBB_batch,
    se3_init_batch,
    alpha: float,
    lam: float,
    lr: float,
    max_num_steps: int,
):
    """Masked volumetric ESP alignment via ``shard_map`` + ``vmap`` across virtual CPU devices.

    All ``*_batch`` arrays share a flat leading axis of size ``total`` that
    must be a multiple of ``len(jax.devices())``; do not pre-reshape it.

    Parameters
    ----------
    ref_batch : (total, N, 3) padded reference positions
    fit_batch : (total, M, 3) padded fit positions
    ref_charges_batch : (total, N, 1) padded reference charges (column-shaped)
    fit_charges_batch : (total, M, 1) padded fit charges
    mask_ref_batch : (total, N)
    mask_fit_batch : (total, M)
    VAA_batch : (total,) pre-computed ref ESP self-overlaps
    VBB_batch : (total,) pre-computed fit ESP self-overlaps
    se3_init_batch : (total, R, 7) pre-initialised SE(3) params
    alpha, lam, lr : float
    max_num_steps : int (Python int; determines compiled kernel)

    Returns
    -------
    aligned_pts : (total, M, 3)
    se3_transform : (total, 4, 4)
    scores : (total,)

    Raises
    ------
    ValueError
        If the leading axis of ``ref_batch`` is not a multiple of
        ``len(jax.devices())``.
    """
    devices = jax.devices()
    n_devices = len(devices)

    # Fail early with an actionable message instead of an opaque sharding error.
    total = np.shape(ref_batch)[0]
    if total % n_devices:
        raise ValueError(
            f"Leading axis of ref_batch has size {total}, which is not a multiple "
            f"of the number of JAX devices ({n_devices}); pad the batch first."
        )

    cache_key = (max_num_steps, n_devices)
    fn = _shmap_vol_esp_cache.get(cache_key)
    if fn is None:
        # Only build a Mesh when a new function must be compiled.
        mesh = Mesh(np.array(devices), axis_names=('i',))
        fn = _make_shmap_vol_esp_fn(max_num_steps, mesh)
        _shmap_vol_esp_cache[cache_key] = fn

    return fn(
        ref_batch, fit_batch, ref_charges_batch, fit_charges_batch,
        mask_ref_batch, mask_fit_batch, se3_init_batch,
        alpha, lam, VAA_batch, VBB_batch, lr,
    )
# ---------------------------------------------------------------------------
# Non-masked surface alignment
# ---------------------------------------------------------------------------

# Maps (max_num_steps, n_devices) -> compiled function from _make_shmap_surf_fn.
_shmap_surf_cache: dict = {}


def _make_shmap_surf_fn(max_num_steps: int, mesh: Mesh):
    """Return a ``jit(shard_map(vmap(...)))`` callable for non-masked surface alignment.

    Positional call signature of the returned function::

        (ref, fit, se3_init, alpha, VAA, VBB, lr)

    ``alpha`` and ``lr`` are replicated; the rest are split along their
    leading axis over mesh axis ``'i'``. ``max_num_steps`` is closed over so
    it stays a static Python int.
    """
    shard = P('i')
    repl = P()

    def _shard_body(ref_b, fit_b, se3_b, alpha, VAA_b, VBB_b, lr):
        # Runs on a single device's slice of the batch.
        def _one_pair(ref, fit, se3, vaa, vbb):
            return _per_pair_optimize_surf_scan(
                ref, fit, se3, alpha, vaa, vbb, lr, max_num_steps
            )

        batched = vmap(_one_pair)
        return batched(ref_b, fit_b, se3_b, VAA_b, VBB_b)

    in_specs = (shard,) * 3 + (repl, shard, shard, repl)
    sharded_fn = jax.shard_map(
        _shard_body,
        mesh=mesh,
        in_specs=in_specs,
        out_specs=(shard,) * 3,
        check_vma=False,
    )
    return jit(sharded_fn)
def optimize_ROCS_overlay_jax_surf_shmap(
    ref_batch,
    fit_batch,
    VAA_batch,
    VBB_batch,
    se3_init_batch,
    alpha: float,
    lr: float,
    max_num_steps: int,
):
    """Non-masked surface alignment via ``shard_map`` + ``vmap`` across virtual CPU devices.

    Surface arrays are uniform size across all pairs so no masking is needed.
    All ``*_batch`` arrays share a flat leading axis of size ``total`` that
    must be a multiple of ``len(jax.devices())``.

    Parameters
    ----------
    ref_batch : (total, N, 3) stacked reference surface positions
    fit_batch : (total, M, 3) stacked fit surface positions
    VAA_batch : (total,) pre-computed ref self-overlaps
    VBB_batch : (total,) pre-computed fit self-overlaps
    se3_init_batch : (total, R, 7) pre-initialised SE(3) params
    alpha, lr : float
    max_num_steps : int (Python int; determines compiled kernel)

    Returns
    -------
    aligned_pts : (total, M, 3)
    se3_transform : (total, 4, 4)
    scores : (total,)

    Raises
    ------
    ValueError
        If the leading axis of ``ref_batch`` is not a multiple of
        ``len(jax.devices())``.
    """
    devices = jax.devices()
    n_devices = len(devices)

    # Fail early with an actionable message instead of an opaque sharding error.
    total = np.shape(ref_batch)[0]
    if total % n_devices:
        raise ValueError(
            f"Leading axis of ref_batch has size {total}, which is not a multiple "
            f"of the number of JAX devices ({n_devices}); pad the batch first."
        )

    cache_key = (max_num_steps, n_devices)
    fn = _shmap_surf_cache.get(cache_key)
    if fn is None:
        # Only build a Mesh when a new function must be compiled.
        mesh = Mesh(np.array(devices), axis_names=('i',))
        fn = _make_shmap_surf_fn(max_num_steps, mesh)
        _shmap_surf_cache[cache_key] = fn

    return fn(ref_batch, fit_batch, se3_init_batch, alpha, VAA_batch, VBB_batch, lr)
# ---------------------------------------------------------------------------
# Non-masked surface ESP alignment
# ---------------------------------------------------------------------------

# Maps (max_num_steps, n_devices) -> compiled function from _make_shmap_surf_esp_fn.
_shmap_surf_esp_cache: dict = {}


def _make_shmap_surf_esp_fn(max_num_steps: int, mesh: Mesh):
    """Return a ``jit(shard_map(vmap(...)))`` callable for non-masked surface ESP alignment.

    Positional call signature of the returned function::

        (ref, fit, ref_charges, fit_charges, se3_init, alpha, lam, VAA, VBB, lr)

    ``alpha``, ``lam`` and ``lr`` are replicated; the rest are split along
    their leading axis over mesh axis ``'i'``. ``max_num_steps`` is closed
    over so it stays a static Python int.
    """
    shard = P('i')
    repl = P()

    def _shard_body(ref_b, fit_b, ref_ch_b, fit_ch_b, se3_b, alpha, lam, VAA_b, VBB_b, lr):
        # Runs on a single device's slice of the batch.
        def _one_pair(ref, fit, ref_ch, fit_ch, se3, vaa, vbb):
            return _per_pair_optimize_surf_esp_scan(
                ref, fit, ref_ch, fit_ch, se3, alpha, lam, vaa, vbb, lr, max_num_steps
            )

        batched = vmap(_one_pair)
        return batched(ref_b, fit_b, ref_ch_b, fit_ch_b, se3_b, VAA_b, VBB_b)

    in_specs = (shard,) * 5 + (repl, repl, shard, shard, repl)
    sharded_fn = jax.shard_map(
        _shard_body,
        mesh=mesh,
        in_specs=in_specs,
        out_specs=(shard,) * 3,
        check_vma=False,
    )
    return jit(sharded_fn)
def optimize_ROCS_esp_overlay_jax_surf_esp_shmap(
    ref_batch,
    fit_batch,
    ref_charges_batch,
    fit_charges_batch,
    VAA_batch,
    VBB_batch,
    se3_init_batch,
    alpha: float,
    lam: float,
    lr: float,
    max_num_steps: int,
):
    """Non-masked surface ESP alignment via ``shard_map`` + ``vmap`` across virtual CPU devices.

    All ``*_batch`` arrays share a flat leading axis of size ``total`` that
    must be a multiple of ``len(jax.devices())``.

    Parameters
    ----------
    ref_batch : (total, N, 3) stacked reference surface positions
    fit_batch : (total, M, 3) stacked fit surface positions
    ref_charges_batch : (total, N) stacked reference ESP values
    fit_charges_batch : (total, M) stacked fit ESP values
    VAA_batch : (total,) pre-computed ref ESP self-overlaps
    VBB_batch : (total,) pre-computed fit ESP self-overlaps
    se3_init_batch : (total, R, 7) pre-initialised SE(3) params
    alpha, lam, lr : float
    max_num_steps : int

    Returns
    -------
    aligned_pts : (total, M, 3)
    se3_transform : (total, 4, 4)
    scores : (total,)

    Raises
    ------
    ValueError
        If the leading axis of ``ref_batch`` is not a multiple of
        ``len(jax.devices())``.
    """
    devices = jax.devices()
    n_devices = len(devices)

    # Fail early with an actionable message instead of an opaque sharding error.
    total = np.shape(ref_batch)[0]
    if total % n_devices:
        raise ValueError(
            f"Leading axis of ref_batch has size {total}, which is not a multiple "
            f"of the number of JAX devices ({n_devices}); pad the batch first."
        )

    cache_key = (max_num_steps, n_devices)
    fn = _shmap_surf_esp_cache.get(cache_key)
    if fn is None:
        # Only build a Mesh when a new function must be compiled.
        mesh = Mesh(np.array(devices), axis_names=('i',))
        fn = _make_shmap_surf_esp_fn(max_num_steps, mesh)
        _shmap_surf_esp_cache[cache_key] = fn

    return fn(
        ref_batch, fit_batch, ref_charges_batch, fit_charges_batch,
        se3_init_batch, alpha, lam, VAA_batch, VBB_batch, lr,
    )
# ---------------------------------------------------------------------------
# Masked pharmacophore alignment
# ---------------------------------------------------------------------------

# Maps (similarity, extended_points, only_extended, max_num_steps, n_devices)
# -> compiled function from _make_shmap_pharm_fn.
_shmap_pharm_cache: dict = {}


def _make_shmap_pharm_fn(max_num_steps: int, mesh: Mesh, similarity: str,
                         extended_points: bool, only_extended: bool):
    """Return a ``jit(shard_map(vmap(...)))`` callable for masked pharmacophore alignment.

    The per-pair kernel is obtained once from
    ``_per_pair_optimize_pharm_mask_scan_factory`` using the static options
    ``similarity``, ``extended_points`` and ``only_extended``.

    Positional call signature of the returned function::

        (ref_pharms, fit_pharms, ref_anchors, fit_anchors, ref_vectors,
         fit_vectors, mask_ref, mask_fit, se3_init, ref_self, fit_self, lr)

    Only ``lr`` is replicated; all other arguments are split along their
    leading axis over mesh axis ``'i'``. The result has four outputs.
    """
    pair_kernel = _per_pair_optimize_pharm_mask_scan_factory(
        similarity, extended_points, only_extended
    )
    shard = P('i')
    repl = P()

    def _shard_body(ref_pharms_b, fit_pharms_b, ref_ancs_b, fit_ancs_b,
                    ref_vecs_b, fit_vecs_b, mask_r_b, mask_f_b, se3_b,
                    ref_self_b, fit_self_b, lr):
        # Runs on a single device's slice of the batch.
        def _one_pair(rp, fp, ra, fa, rv, fv, mr, mf, se3, rss, fss):
            return pair_kernel(
                rp, fp, ra, fa, rv, fv, mr, mf, se3, rss, fss, lr, max_num_steps
            )

        batched = vmap(_one_pair)
        return batched(
            ref_pharms_b, fit_pharms_b, ref_ancs_b, fit_ancs_b,
            ref_vecs_b, fit_vecs_b, mask_r_b, mask_f_b, se3_b,
            ref_self_b, fit_self_b,
        )

    in_specs = (shard,) * 11 + (repl,)
    sharded_fn = jax.shard_map(
        _shard_body,
        mesh=mesh,
        in_specs=in_specs,
        out_specs=(shard,) * 4,
        check_vma=False,
    )
    return jit(sharded_fn)
def optimize_pharm_overlay_jax_pharm_shmap(
    ref_pharms_batch,
    fit_pharms_batch,
    ref_anchors_batch,
    fit_anchors_batch,
    ref_vectors_batch,
    fit_vectors_batch,
    mask_ref_batch,
    mask_fit_batch,
    ref_self_batch,
    fit_self_batch,
    se3_init_batch,
    similarity: str,
    extended_points: bool,
    only_extended: bool,
    lr: float,
    max_num_steps: int,
):
    """Masked pharmacophore alignment via ``shard_map`` + ``vmap`` across virtual CPU devices.

    All ``*_batch`` arrays share a flat leading axis of size ``total`` that
    must be a multiple of ``len(jax.devices())``.

    Parameters
    ----------
    ref_pharms_batch : (total, N) padded reference pharmacophore type indices (int32)
    fit_pharms_batch : (total, M) padded fit pharmacophore type indices
    ref_anchors_batch : (total, N, 3) padded reference anchor positions
    fit_anchors_batch : (total, M, 3) padded fit anchor positions
    ref_vectors_batch : (total, N, 3) padded reference direction vectors
    fit_vectors_batch : (total, M, 3) padded fit direction vectors
    mask_ref_batch : (total, N) binary masks
    mask_fit_batch : (total, M)
    ref_self_batch : (total,) pre-computed ref self-overlaps
    fit_self_batch : (total,) pre-computed fit self-overlaps
    se3_init_batch : (total, R, 7) pre-initialised SE(3) params
    similarity : str ('tanimoto', 'tversky_ref', 'tversky_fit')
    extended_points : bool
    only_extended : bool
    lr : float
    max_num_steps : int

    Returns
    -------
    aligned_anchors : (total, M, 3)
    aligned_vectors : (total, M, 3)
    se3_transform : (total, 4, 4)
    scores : (total,)

    Raises
    ------
    ValueError
        If the leading axis of ``ref_pharms_batch`` is not a multiple of
        ``len(jax.devices())``.
    """
    devices = jax.devices()
    n_devices = len(devices)

    # Fail early with an actionable message instead of an opaque sharding error.
    total = np.shape(ref_pharms_batch)[0]
    if total % n_devices:
        raise ValueError(
            f"Leading axis of ref_pharms_batch has size {total}, which is not a "
            f"multiple of the number of JAX devices ({n_devices}); pad the batch first."
        )

    cache_key = (similarity, extended_points, only_extended, max_num_steps, n_devices)
    fn = _shmap_pharm_cache.get(cache_key)
    if fn is None:
        # Only build a Mesh when a new function must be compiled.
        mesh = Mesh(np.array(devices), axis_names=('i',))
        fn = _make_shmap_pharm_fn(
            max_num_steps, mesh, similarity, extended_points, only_extended
        )
        _shmap_pharm_cache[cache_key] = fn

    return fn(
        ref_pharms_batch, fit_pharms_batch,
        ref_anchors_batch, fit_anchors_batch,
        ref_vectors_batch, fit_vectors_batch,
        mask_ref_batch, mask_fit_batch,
        se3_init_batch, ref_self_batch, fit_self_batch,
        lr,
    )