JAX Parallel Alignment
Note
XLA_FLAGS must be set before any JAX import so that
len(jax.devices()) equals the desired number of virtual CPU devices.
Place the following lines at the very top of your script:
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'
os.environ['JAX_PLATFORMS'] = 'cpu'
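The ordering requirement in the note can be wrapped in a small guard. The following is a minimal sketch, not part of shepherd_score: `configure_virtual_cpu_devices` and its `modules` argument are hypothetical names introduced here for illustration. It sets the two environment variables shown above and refuses to run if JAX has already been imported.

```python
import os
import sys


def configure_virtual_cpu_devices(n_devices, modules=None):
    """Request `n_devices` virtual CPU devices; call before importing JAX.

    `modules` defaults to `sys.modules`; it is a parameter only so the
    import check can be exercised in isolation (hypothetical helper).
    """
    modules = sys.modules if modules is None else modules
    if "jax" in modules:
        # Per the note above, the flag must be in place before the import.
        raise RuntimeError("JAX is already imported; set XLA_FLAGS earlier.")
    flag = f"--xla_force_host_platform_device_count={n_devices}"
    # Keep unrelated XLA flags, but drop any previous device-count flag.
    kept = [
        f for f in os.environ.get("XLA_FLAGS", "").split()
        if not f.startswith("--xla_force_host_platform_device_count=")
    ]
    os.environ["XLA_FLAGS"] = " ".join(kept + [flag])
    os.environ["JAX_PLATFORMS"] = "cpu"
    return os.environ["XLA_FLAGS"]
```

Calling `configure_virtual_cpu_devices(4)` as the first statement of a script, followed by `import jax`, should then give `len(jax.devices()) == 4`.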
This module provides multi-device volumetric alignment using
jax.shard_map(). It is the backend called by
shepherd_score.container.MoleculePairBatch.align_with_vol() when
use_shmap=True.
Overview
optimize_ROCS_overlay_jax_vol_shmap() distributes a flat batch of
molecule pairs across virtual CPU devices without any Python-level
multiprocessing. Key properties:
- Accepts flat (total, ...) arrays, where total is the number of pairs padded to a multiple of len(jax.devices()). Do not pre-reshape to (n_devices, B, ...).
- Internally wraps _per_pair_optimize_vol_mask_scan (from shepherd_score.alignment._jax) in vmap, then distributes via shard_map with PartitionSpec('i') on the leading axis.
- The compiled function is cached in _shmap_vol_cache, keyed by (max_num_steps, n_devices), so the XLA kernel is compiled only once per unique (steps, device-count) combination.
- Uses lax.scan for a fixed number of steps (no convergence-based early stopping). This enables full ahead-of-time compilation; typical speedup is ~2.8× on 4 CPU cores compared to sequential JAX alignment.
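The padding requirement on the leading axis can be illustrated with plain NumPy. This is a sketch under stated assumptions: `pad_leading_axis` is a hypothetical helper, and repeating the last pair as filler is one possible choice; this page does not say how shepherd_score pads internally.

```python
import numpy as np


def pad_leading_axis(arr, n_devices):
    """Pad axis 0 up to the next multiple of `n_devices` (hypothetical helper).

    Filler rows repeat the last real row; this is an assumption made for
    illustration, not the library's documented behaviour.
    """
    n_pairs = arr.shape[0]
    n_extra = (-n_pairs) % n_devices
    if n_extra == 0:
        return arr
    filler = np.repeat(arr[-1:], n_extra, axis=0)
    return np.concatenate([arr, filler], axis=0)


n_pairs, n_devices = 10, 4
ref = np.random.default_rng(0).normal(size=(n_pairs, 6, 3))
ref_padded = pad_leading_axis(ref, n_devices)  # leading axis grows from 10 to 12

# After alignment, only the first n_pairs rows of each output are meaningful.
real_rows = ref_padded[:n_pairs]
```

The array stays flat: the leading axis is 12, not reshaped to (4, 3, ...).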
Usage via the high-level API
The recommended entry point is
align_with_vol():
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'
os.environ['JAX_PLATFORMS'] = 'cpu'
from shepherd_score.container import MoleculePairBatch
batch = MoleculePairBatch(pairs)
# Default: single pass (n_buckets=1)
scores, aligned = batch.align_with_vol(num_workers=4, use_shmap=True)
# Bucketed: useful for >10k pairs with diverse molecule sizes
scores, aligned = batch.align_with_vol(num_workers=4, use_shmap=True, n_buckets=8)
See MoleculePairBatch for details on bucketing and masking strategy.
Direct usage
Pre-compute self-overlaps and SE(3) initialisations outside the function (they are invariant to the optimisation loop), then call:
from shepherd_score.alignment._jax_parallel import optimize_ROCS_overlay_jax_vol_shmap
aligned_pts, se3_transform, scores = optimize_ROCS_overlay_jax_vol_shmap(
    ref_batch,        # (total, N, 3)
    fit_batch,        # (total, M, 3)
    mask_ref_batch,   # (total, N)
    mask_fit_batch,   # (total, M)
    VAA_batch,        # (total,) pre-computed ref self-overlaps
    VBB_batch,        # (total,) pre-computed fit self-overlaps
    se3_init_batch,   # (total, R, 7) pre-initialised SE(3) params
    alpha=0.81,
    lr=0.1,
    max_num_steps=200,
)
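The padded position arrays and their masks in the call above could be built from variable-size molecules roughly as follows. This is only a sketch: `pad_point_clouds` is a hypothetical helper, and zero-filled padding with a 1.0/0.0 float mask is an assumed convention that this page does not confirm.

```python
import numpy as np


def pad_point_clouds(clouds):
    """Stack ragged (n_i, 3) arrays into (total, N, 3) plus a (total, N) mask.

    Assumed convention: padded rows are zeros, mask is 1.0 for real atoms
    and 0.0 for padding.
    """
    max_len = max(c.shape[0] for c in clouds)
    padded = np.zeros((len(clouds), max_len, 3), dtype=np.float32)
    mask = np.zeros((len(clouds), max_len), dtype=np.float32)
    for i, c in enumerate(clouds):
        padded[i, : c.shape[0]] = c
        mask[i, : c.shape[0]] = 1.0
    return padded, mask


# Three toy "molecules" with 4, 7 and 5 atoms.
clouds = [np.ones((4, 3)), np.ones((7, 3)), np.ones((5, 3))]
ref_batch, mask_ref_batch = pad_point_clouds(clouds)
```

Here N is the largest atom count in the batch (7), so every pair shares one array shape and a single compiled kernel can serve the whole batch.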
API
Parallel (multi-device) JAX alignment via jax.shard_map.
Each public optimize_*_jax_*_shmap function accepts flat (total, ...)
arrays (where total is the number of pairs padded to a multiple of
len(jax.devices())) and returns flat (total, ...) results. The
leading axis is automatically distributed across devices by
jax.shard_map with PartitionSpec('i'); do not pre-reshape to
(n_devices, B, ...).
XLA_FLAGS=--xla_force_host_platform_device_count=N must be set before
any JAX import so that len(jax.devices()) == N.
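The pattern described here (a per-pair function wrapped in vmap, then distributed over a flat leading axis with shard_map and PartitionSpec('i')) can be sketched on a toy problem. Everything below is illustrative: `toy_per_pair` is a hypothetical stand-in for the real per-pair optimiser, and the fallback import covers JAX versions that only expose shard_map under `jax.experimental`.

```python
import os

# Must run before the first JAX import (see the note above).
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")
os.environ.setdefault("JAX_PLATFORMS", "cpu")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    shard_map = jax.shard_map  # newer JAX releases
except AttributeError:
    from jax.experimental.shard_map import shard_map  # older releases


def toy_per_pair(ref, fit):
    # Hypothetical per-pair "score": negative squared centroid distance.
    return -jnp.sum((ref.mean(axis=0) - fit.mean(axis=0)) ** 2)


devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("i",))

# vmap handles the pairs within one shard; shard_map splits the flat
# leading axis across the devices of the 1-D mesh.
batched = jax.vmap(toy_per_pair)
sharded = jax.jit(
    shard_map(batched, mesh=mesh, in_specs=(P("i"), P("i")), out_specs=P("i"))
)

total = 2 * len(devices)          # already a multiple of the device count
ref = jnp.ones((total, 5, 3))     # flat (total, N, 3), no per-device reshape
fit = jnp.zeros((total, 7, 3))    # flat (total, M, 3)
scores = sharded(ref, fit)        # flat (total,)
```

Inputs and outputs keep the flat (total, ...) layout throughout; the split into per-device blocks happens inside shard_map.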
- shepherd_score.alignment._jax_parallel.optimize_ROCS_overlay_jax_vol_shmap(ref_batch, fit_batch, mask_ref_batch, mask_fit_batch, VAA_batch, VBB_batch, se3_init_batch, alpha, lr, max_num_steps)
Volumetric alignment via shard_map + vmap across virtual CPU devices.
All *_batch arrays use a flat leading axis of size total (i.e. the number of pairs padded to a multiple of len(jax.devices())). Unlike pmap, shard_map automatically distributes the flat leading axis across devices; do not pre-reshape to (n_devices, B, ...). Pre-compute self-overlaps VAA/VBB and SE(3) initialisations outside this function (they are invariant to the optimisation loop).
XLA_FLAGS=--xla_force_host_platform_device_count=N must be set before JAX is first imported so that len(jax.devices()) == N.
- Parameters:
ref_batch ((total, N, 3) padded reference positions)
fit_batch ((total, M, 3) padded fit positions)
mask_ref_batch ((total, N))
mask_fit_batch ((total, M))
VAA_batch ((total,) pre-computed ref self-overlaps)
VBB_batch ((total,) pre-computed fit self-overlaps)
se3_init_batch ((total, R, 7) pre-initialised SE(3) params)
alpha (float)
lr (float)
max_num_steps (int (Python int; determines compiled kernel))
- Returns:
aligned_pts ((total, M, 3))
se3_transform ((total, 4, 4))
scores ((total,))
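The remark that max_num_steps is a Python int which determines the compiled kernel fits a cache keyed by (max_num_steps, n_devices). The following is a plain-Python sketch of that idea only; `_compiled_cache`, `get_compiled`, and `fake_build` are hypothetical names, and the real `_shmap_vol_cache` may be organised differently.

```python
_compiled_cache = {}  # hypothetical stand-in for _shmap_vol_cache


def get_compiled(max_num_steps, n_devices, build):
    """Return the cached callable for this key, building it on first use."""
    key = (max_num_steps, n_devices)
    if key not in _compiled_cache:
        _compiled_cache[key] = build(max_num_steps, n_devices)
    return _compiled_cache[key]


build_calls = []


def fake_build(max_num_steps, n_devices):
    # Stand-in for the expensive trace-and-compile step.
    build_calls.append((max_num_steps, n_devices))
    return lambda x: x * max_num_steps


f1 = get_compiled(200, 4, fake_build)
f2 = get_compiled(200, 4, fake_build)  # cache hit, nothing is rebuilt
f3 = get_compiled(100, 4, fake_build)  # different step count, new entry
```

Under this scheme, changing max_num_steps between calls triggers a new compilation, so keeping it fixed across a run avoids repeated compile cost.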
- shepherd_score.alignment._jax_parallel.optimize_ROCS_esp_overlay_jax_vol_esp_shmap(ref_batch, fit_batch, ref_charges_batch, fit_charges_batch, mask_ref_batch, mask_fit_batch, VAA_batch, VBB_batch, se3_init_batch, alpha, lam, lr, max_num_steps)
Masked volumetric ESP alignment via shard_map + vmap across virtual CPU devices.
- Parameters:
ref_batch ((total, N, 3) padded reference positions)
fit_batch ((total, M, 3) padded fit positions)
ref_charges_batch ((total, N, 1) padded reference charges (column-shaped))
fit_charges_batch ((total, M, 1) padded fit charges)
mask_ref_batch ((total, N))
mask_fit_batch ((total, M))
VAA_batch ((total,) pre-computed ref ESP self-overlaps)
VBB_batch ((total,) pre-computed fit ESP self-overlaps)
se3_init_batch ((total, R, 7) pre-initialised SE(3) params)
alpha (float)
lam (float)
lr (float)
max_num_steps (int (Python int; determines compiled kernel))
- Returns:
aligned_pts ((total, M, 3))
se3_transform ((total, 4, 4))
scores ((total,))
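The (total, 4, 4) se3_transform output can be applied to other point sets with a batched matrix product. This sketch assumes each matrix is a homogeneous transform with the rotation in the upper-left 3×3 block and the translation in the last column; that layout is an assumption, since this page does not spell it out, and `apply_se3` is a hypothetical helper.

```python
import numpy as np


def apply_se3(transforms, points):
    """Apply (total, 4, 4) homogeneous transforms to (total, M, 3) points."""
    rot = transforms[:, :3, :3]    # (total, 3, 3) rotation blocks
    trans = transforms[:, :3, 3]   # (total, 3) translation columns
    # For each point x: x' = R @ x + t.
    return np.einsum("bij,bmj->bmi", rot, points) + trans[:, None, :]


# One pair: rotate 90 degrees about z, then shift by +1 along x.
T = np.eye(4)[None].copy()
T[0, :3, :3] = [[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
T[0, :3, 3] = [1.0, 0.0, 0.0]
pts = np.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])  # (1, 2, 3)
moved = apply_se3(T, pts)
```

If the assumed layout holds, applying se3_transform to the unpadded fit coordinates would reproduce the corresponding rows of aligned_pts.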
- shepherd_score.alignment._jax_parallel.optimize_ROCS_overlay_jax_surf_shmap(ref_batch, fit_batch, VAA_batch, VBB_batch, se3_init_batch, alpha, lr, max_num_steps)
Non-masked surface alignment via shard_map + vmap across virtual CPU devices.
Surface arrays are uniform size across all pairs so no masking is needed.
- Parameters:
- Parameters:
ref_batch ((total, N, 3) stacked reference surface positions)
fit_batch ((total, M, 3) stacked fit surface positions)
VAA_batch ((total,) pre-computed ref self-overlaps)
VBB_batch ((total,) pre-computed fit self-overlaps)
se3_init_batch ((total, R, 7) pre-initialised SE(3) params)
alpha (float)
lr (float)
max_num_steps (int (Python int; determines compiled kernel))
- Returns:
aligned_pts ((total, M, 3))
se3_transform ((total, 4, 4))
scores ((total,))
- shepherd_score.alignment._jax_parallel.optimize_ROCS_esp_overlay_jax_surf_esp_shmap(ref_batch, fit_batch, ref_charges_batch, fit_charges_batch, VAA_batch, VBB_batch, se3_init_batch, alpha, lam, lr, max_num_steps)
Non-masked surface ESP alignment via shard_map + vmap across virtual CPU devices.
- Parameters:
ref_batch ((total, N, 3) stacked reference surface positions)
fit_batch ((total, M, 3) stacked fit surface positions)
ref_charges_batch ((total, N) stacked reference ESP values)
fit_charges_batch ((total, M) stacked fit ESP values)
VAA_batch ((total,) pre-computed ref ESP self-overlaps)
VBB_batch ((total,) pre-computed fit ESP self-overlaps)
se3_init_batch ((total, R, 7) pre-initialised SE(3) params)
alpha (float)
lam (float)
lr (float)
max_num_steps (int)
- Returns:
aligned_pts ((total, M, 3))
se3_transform ((total, 4, 4))
scores ((total,))
- shepherd_score.alignment._jax_parallel.optimize_pharm_overlay_jax_pharm_shmap(ref_pharms_batch, fit_pharms_batch, ref_anchors_batch, fit_anchors_batch, ref_vectors_batch, fit_vectors_batch, mask_ref_batch, mask_fit_batch, ref_self_batch, fit_self_batch, se3_init_batch, similarity, extended_points, only_extended, lr, max_num_steps)
Masked pharmacophore alignment via shard_map + vmap across virtual CPU devices.
- Parameters:
ref_pharms_batch ((total, N) padded reference pharmacophore type indices (int32))
fit_pharms_batch ((total, M) padded fit pharmacophore type indices)
ref_anchors_batch ((total, N, 3) padded reference anchor positions)
fit_anchors_batch ((total, M, 3) padded fit anchor positions)
ref_vectors_batch ((total, N, 3) padded reference direction vectors)
fit_vectors_batch ((total, M, 3) padded fit direction vectors)
mask_ref_batch ((total, N) binary masks)
mask_fit_batch ((total, M))
ref_self_batch ((total,) pre-computed ref self-overlaps)
fit_self_batch ((total,) pre-computed fit self-overlaps)
se3_init_batch ((total, R, 7) pre-initialised SE(3) params)
similarity (str ('tanimoto', 'tversky_ref', 'tversky_fit'))
extended_points (bool)
only_extended (bool)
lr (float)
max_num_steps (int)
- Returns:
aligned_anchors ((total, M, 3))
aligned_vectors ((total, M, 3))
se3_transform ((total, 4, 4))
scores ((total,))
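The pre-computed self-overlaps act as normalisers for the similarity score. As an illustration, the standard Tanimoto form combines a cross-overlap with the two self-overlaps as shown below. Whether the 'tanimoto' option here uses exactly this expression is an assumption, and the Tversky variants are not sketched because their weights are not given on this page.

```python
def tanimoto(v_ab, v_aa, v_bb):
    """Standard Tanimoto similarity from a cross-overlap and two self-overlaps."""
    return v_ab / (v_aa + v_bb - v_ab)


identical = tanimoto(5.0, 5.0, 5.0)  # perfect overlap of identical inputs
partial = tanimoto(2.0, 4.0, 6.0)    # 2 / (4 + 6 - 2)
```

Because the self-overlap terms do not change while the fit molecule is moved, they can be computed once per pair outside the optimisation loop, which matches the advice to pre-compute them.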