MoleculePairBatch#
MoleculePairBatch accepts a list of
MoleculePair objects and aligns them
efficiently by padding all atom/pharmacophore arrays to a common maximum
length. Because every call shares the same padded array shape, JAX’s XLA
compiler produces a single compiled kernel that is reused for every
pair in the batch — avoiding the per-pair recompilation overhead that
occurs when array shapes differ.
Note
While shard_map is recommended for all cases, it requires
jax>=0.9.0 and thus python>=3.11. An alternative is to use
multiprocessing with 'spawn' context by setting
use_shmap=False. However, this is known to NOT work on Linux HPC
environments and has only been tested on M-series Macs.
Masking strategy#
_pad_arrays() pads each coordinate (or charge) array to max_len and
produces a binary float32 mask (1.0 = real atom, 0.0 = padding).
An outer-product pair mask
pair_mask = mask_fit[:, None] * mask_ref[None, :]
zeroes out all padding-atom contributions to the Gaussian overlap sum, self-overlaps, and gradients. Since padded shapes are fixed across all pairs, the compiled kernel is reused without recompilation.
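The masking idea above can be sketched in plain NumPy. This is a minimal illustration, not the library's `_pad_arrays()`: the helper names `pad_coords` and `masked_overlap`, the kernel form `exp(-alpha * d^2)`, and the value of `alpha` are assumptions chosen only to show that padded and unpadded inputs give the same overlap sum.

```python
import numpy as np

def pad_coords(coords, max_len):
    """Pad an (N, 3) array to (max_len, 3); return it with a float32 mask."""
    n = coords.shape[0]
    padded = np.zeros((max_len, 3), dtype=np.float32)
    padded[:n] = coords
    mask = np.zeros(max_len, dtype=np.float32)
    mask[:n] = 1.0  # 1.0 = real atom, 0.0 = padding
    return padded, mask

def masked_overlap(fit, mask_fit, ref, mask_ref, alpha=0.8):
    """Sum an illustrative Gaussian kernel over real atom pairs only."""
    sq_dist = np.sum((fit[:, None, :] - ref[None, :, :]) ** 2, axis=-1)
    pair_mask = mask_fit[:, None] * mask_ref[None, :]  # outer-product mask
    return float(np.sum(pair_mask * np.exp(-alpha * sq_dist)))

rng = np.random.default_rng(0)
fit = rng.normal(size=(5, 3)).astype(np.float32)
ref = rng.normal(size=(7, 3)).astype(np.float32)

# Both molecules are padded to the same length, so the array shapes match
# across pairs regardless of the true atom counts.
fit_pad, mask_fit = pad_coords(fit, 10)
ref_pad, mask_ref = pad_coords(ref, 10)

unpadded = masked_overlap(fit, np.ones(5, np.float32), ref, np.ones(7, np.float32))
padded = masked_overlap(fit_pad, mask_fit, ref_pad, mask_ref)
print(unpadded, padded)  # the two sums agree up to float rounding
```

Padding rows sit at the origin, so without the mask they would contribute spurious overlap; multiplying by `pair_mask` removes every term that involves a padding row.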
Parallel volumetric alignment via jax.shard_map#
Note
XLA_FLAGS must be set before any JAX import. Place the
environment variable assignments at the very top of your script, before
any import jax or from shepherd_score ... statements.
align_with_vol() supports
a use_shmap=True path that distributes pairs across virtual CPU devices
using optimize_ROCS_overlay_jax_vol_shmap()
instead of Python multiprocessing:
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'
os.environ['JAX_PLATFORMS'] = 'cpu'
from shepherd_score.container import MoleculePairBatch
batch = MoleculePairBatch(pairs)
# Sequential (default)
scores, aligned = batch.align_with_vol(num_workers=1)
# Parallel via shard_map
scores, aligned = batch.align_with_vol(num_workers=4, use_shmap=True)
Why shard_map instead of multiprocessing on HPC?
multiprocessing with 'spawn' context can be unreliable on Linux HPC (JAX initialisation in subprocesses, resource limits). shard_map distributes work across virtual CPU devices within a single process without forking/spawning.
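The requirement that XLA_FLAGS be set before any JAX import can be guarded with a small helper. The sketch below uses only the standard library. The function name `configure_virtual_cpu_devices` and its `loaded_modules` argument are hypothetical and not part of shepherd_score; the flag and environment variable names are the ones shown in the example above.

```python
import os
import sys

DEVICE_FLAG = "--xla_force_host_platform_device_count"

def configure_virtual_cpu_devices(n_devices, loaded_modules=None):
    """Set XLA_FLAGS for n virtual CPU devices; fail if JAX is already imported."""
    modules = sys.modules if loaded_modules is None else loaded_modules
    if "jax" in modules:
        # Too late: per the note above, the flag must be set before JAX is imported.
        raise RuntimeError("XLA_FLAGS must be set before JAX is imported.")
    existing = os.environ.get("XLA_FLAGS", "").split()
    # Keep unrelated flags, replace any earlier device-count setting.
    kept = [flag for flag in existing if not flag.startswith(DEVICE_FLAG)]
    os.environ["XLA_FLAGS"] = " ".join(kept + [f"{DEVICE_FLAG}={n_devices}"])
    os.environ["JAX_PLATFORMS"] = "cpu"
    return os.environ["XLA_FLAGS"]

# Call this at the very top of the script, before importing jax or shepherd_score.
configure_virtual_cpu_devices(4)
```

Importing shepherd_score would presumably import JAX as well, so checking for `"jax"` in `sys.modules` should catch either ordering mistake.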
Bucketing for heterogeneous molecule sets#
By default (num_buckets=1) all pairs are padded to the global atom-count
maximum and processed in a single shard_map call (one JIT compilation,
lowest overhead). For large heterogeneous sets use num_buckets > 1, which
sorts pairs by (max(ref, fit), min(ref, fit)) via np.lexsort and
processes each bucket with its own local padding maximum. This reduces wasted
computation at the cost of multiple sequential shard_map calls:
# Default: single pass, one compilation
scores, aligned = batch.align_with_vol(num_workers=4, use_shmap=True)
# Bucketed: useful for datasets with diverse molecule sizes
scores, aligned = batch.align_with_vol(num_workers=4, use_shmap=True, num_buckets=4)
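The sort-then-bucket step described above can be sketched with NumPy alone. This is an illustration under assumptions, not the library's implementation: the helper name `bucket_pairs`, the use of `np.array_split` for roughly equal-sized buckets, and the single padding length per bucket are all hypothetical.

```python
import numpy as np

def bucket_pairs(ref_sizes, fit_sizes, num_buckets):
    """Sort pairs by (max, min) atom count and split them into buckets.

    Returns a list of (pair_indices, pad_len) tuples, where pad_len is the
    largest atom count inside that bucket.
    """
    ref_sizes = np.asarray(ref_sizes)
    fit_sizes = np.asarray(fit_sizes)
    larger = np.maximum(ref_sizes, fit_sizes)
    smaller = np.minimum(ref_sizes, fit_sizes)
    # np.lexsort treats the LAST key as the primary sort key.
    order = np.lexsort((smaller, larger))
    buckets = []
    for idx in np.array_split(order, num_buckets):
        if idx.size:  # skip empty buckets when num_buckets > number of pairs
            buckets.append((idx, int(larger[idx].max())))
    return buckets

ref_sizes = [12, 40, 15, 38, 11, 42]
fit_sizes = [10, 35, 18, 41, 13, 30]
buckets = bucket_pairs(ref_sizes, fit_sizes, num_buckets=2)
for idx, pad_len in buckets:
    print(idx.tolist(), pad_len)
# [0, 4, 2] 18
# [1, 3, 5] 42
```

With a single bucket every pair here would be padded to 42 atoms. With two buckets the three small pairs are padded only to 18, at the cost of one extra compiled shape.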
Note
use_shmap=True uses lax.scan (fixed steps, no convergence-based
early stopping). max_num_steps=200 is the default.
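The difference between a fixed-step loop and convergence-based early stopping can be illustrated in plain Python on a toy one-dimensional problem. This sketch does not use `lax.scan` or `while_loop`; the function names, the tolerance `tol`, and the quadratic objective are assumptions made only to show the two loop styles.

```python
def descend_fixed(grad, x0, lr=0.1, max_num_steps=200):
    """Scan-style loop: always runs exactly max_num_steps updates."""
    x = x0
    for _ in range(max_num_steps):
        x = x - lr * grad(x)
    return x, max_num_steps

def descend_early_stop(grad, x0, lr=0.1, max_num_steps=200, tol=1e-6):
    """While-style loop: stops once the update becomes smaller than tol."""
    x = x0
    steps = 0
    while steps < max_num_steps:
        update = lr * grad(x)
        x = x - update
        steps += 1
        if abs(update) < tol:
            break
    return x, steps

def grad(x):
    # Gradient of the toy objective (x - 3)^2, minimised at x = 3.
    return 2.0 * (x - 3.0)

x_fixed, n_fixed = descend_fixed(grad, 0.0)
x_early, n_early = descend_early_stop(grad, 0.0)
print(n_fixed, n_early)  # the early-stopping loop uses fewer steps here
```

Both loops reach the same minimum on this problem. The fixed-step version does the same amount of work for every input, which keeps the computation uniform across pairs but spends steps on inputs that have already converged.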
Available batch alignment methods#
For the low-level parallel kernel see JAX Parallel Alignment. For the full scoring and alignment theory see Representations, Scoring, and Alignment.
Class reference#
- class shepherd_score.container._batch.MoleculePairBatch(pairs)[source]#
Bases: object
Batch of MoleculePair objects for fast sequential JAX alignment.
Pads all atom coordinate arrays to common max shapes so JAX’s XLA compiler reuses the same compiled function for every pair, avoiding recompilation. This modifies each MoleculePair in-place (stores results on the pair).
This is currently optimized for CPU. A GPU-optimized version would benefit from optimizing batches of pairs and using a GPU-optimized alignment.
- Parameters:
pairs (List[MoleculePair])
- align_with_vol(no_H=True, num_repeats=50, trans_init=False, lr=0.1, max_num_steps=200, num_workers=1, use_shmap=True, num_buckets=1, verbose=False)[source]#
Align all pairs using padded masked volumetric similarity via JAX.
Because all padded arrays have the same shape, JAX’s XLA compiler reuses one compiled kernel for every pair — no recompilation overhead.
When num_workers > 1 the pairs are split into size-sorted chunks and processed in parallel. It is recommended to use use_shmap=True instead of multiprocessing for this setting.
Results are stored in-place on each MoleculePair:
- pair.transform_vol_noH / pair.sim_aligned_vol_noH (when no_H=True)
- pair.transform_vol / pair.sim_aligned_vol (when no_H=False)
- Parameters:
no_H (bool) – Whether to exclude hydrogens. Default is True.
num_repeats (int) – Number of SE(3) initializations per pair. Default is 50.
trans_init (bool) – If True, initialize translations to each ref atom position. Default is False.
lr (float) – Optimizer learning rate. Default is 0.1.
max_num_steps (int) – Maximum optimization steps. Default is 200.
num_workers (int) – Number of parallel workers. 1 (default) runs sequentially in-process. When use_shmap=True (the default), this value is informational; actual parallelism equals len(jax.devices()), which is set by XLA_FLAGS before JAX is first imported. When use_shmap=False use multiprocessing with a 'spawn' start method.
use_shmap (bool) – If True and num_workers > 1, use jax.shard_map + vmap to parallelise across virtual CPU devices in a single process. Requires XLA_FLAGS=--xla_force_host_platform_device_count=N to be set before any JAX import. Uses lax.scan (fixed steps, no early stopping) instead of the while_loop-based sequential path. Required on Linux HPC if num_workers > 1, where multiprocessing spawn can be unreliable with JAX. Default is True.
num_buckets (int) – 1 (default) pads all pairs to the global atom-count maximum, which has the lowest overhead for typical use. Values > 1 sort pairs by (max(ref,fit), min(ref,fit)) and process each bucket separately with reduced per-bucket padding, which can be beneficial for large heterogeneous molecule sets.
verbose (bool) – Print scores per pair. Default is False.
- Returns:
scores (np.ndarray) – Scores for each pair. Shape: (N,).
aligned_list (list of np.ndarray) – Aligned fit atom coordinates (unpadded) for each pair.
- Return type:
- align_with_vol_esp(lam, no_H=True, num_repeats=50, trans_init=False, lr=0.1, max_num_steps=200, num_workers=1, use_shmap=True, num_buckets=1, verbose=False)[source]#
Align all pairs using padded masked volumetric ESP similarity via JAX.
Because all padded arrays have the same shape, JAX’s XLA compiler reuses one compiled kernel for every pair — no recompilation overhead.
When num_workers > 1 the pairs are split into size-sorted chunks and processed in parallel. It is recommended to use use_shmap=True instead of multiprocessing for this setting.
Results are stored in-place on each MoleculePair:
- pair.transform_vol_esp_noH / pair.sim_aligned_vol_esp_noH (when no_H=True)
- pair.transform_vol_esp / pair.sim_aligned_vol_esp (when no_H=False)
- Parameters:
lam (float) – Partial charge weighting parameter. Typically 0.1 for volumetric.
no_H (bool) – Whether to exclude hydrogens. Default is True.
num_repeats (int) – Number of SE(3) initializations per pair. Default is 50.
trans_init (bool) – If True, initialize translations to each ref atom position. Default is False.
lr (float) – Optimizer learning rate. Default is 0.1.
max_num_steps (int) – Maximum optimization steps. Default is 200.
num_workers (int) – Number of parallel worker processes. 1 (default) runs sequentially in-process. Values greater than len(self.pairs) are clamped to len(self.pairs).
use_shmap (bool) – If True and num_workers > 1, use jax.shard_map + vmap to parallelise across virtual CPU devices in a single process. Requires XLA_FLAGS=--xla_force_host_platform_device_count=N to be set before any JAX import. Uses lax.scan (fixed steps, no early stopping) instead of the while_loop-based sequential path. Required on Linux HPC if num_workers > 1, where multiprocessing spawn can be unreliable with JAX. Default is True.
num_buckets (int) – 1 (default) pads all pairs to the global atom-count maximum, which has the lowest overhead for typical use. Values > 1 sort pairs by (max(ref,fit), min(ref,fit)) and process each bucket separately with reduced per-bucket padding, which can be beneficial for large heterogeneous molecule sets.
verbose (bool) – Print scores per pair. Default is False.
- Returns:
scores (np.ndarray) – Scores for each pair. Shape: (N,).
aligned_list (list of np.ndarray) – Aligned fit atom coordinates (unpadded) for each pair.
- Return type:
- align_with_surf(alpha, num_repeats=50, trans_init=False, lr=0.1, max_num_steps=200, use_jax=True, use_analytical=True, num_workers=1, use_shmap=False, verbose=False)[source]#
Align all pairs using surface similarity.
Surface arrays are the same size across all pairs, so no padding or size-sorting is needed. For this reason, using multiprocessing is not recommended.
Results are stored in-place on each MoleculePair:
- pair.transform_surf and pair.sim_aligned_surf
- Parameters:
alpha (float) – Gaussian width parameter for overlap.
num_repeats (int) – Number of SE(3) initializations per pair. Default is 50.
trans_init (bool) – Apply translation initialization for alignment. Default is False.
lr (float) – Optimizer learning rate. Default is 0.1.
max_num_steps (int) – Maximum optimization steps. Default is 200.
use_jax (bool) – Whether to use JAX backend. Default is True.
use_analytical (bool) – Whether to use analytical gradients (PyTorch only). Default is True.
num_workers (int) – Number of parallel worker processes. 1 (default) runs sequentially in-process. Values greater than len(self.pairs) are clamped to len(self.pairs).
use_shmap (bool) – Whether to use JAX shard_map for parallel alignment. Default is False. Performance is better when use_shmap is False on CPU.
verbose (bool) – Print scores per pair. Default is False.
- Returns:
scores (np.ndarray) – Scores for each pair. Shape: (N,).
aligned_list (list of np.ndarray) – Aligned fit surface coordinates for each pair.
- Return type:
- align_with_esp(alpha, lam=0.3, num_repeats=50, trans_init=False, lr=0.1, max_num_steps=200, use_jax=True, use_analytical=True, num_workers=1, use_shmap=False, verbose=False)[source]#
Align all pairs using ESP+surface similarity.
Surface arrays are the same size across all pairs, so no padding or size-sorting is needed. For this reason, using multiprocessing is not recommended.
Results are stored in-place on each MoleculePair:
- pair.transform_esp and pair.sim_aligned_esp
- Parameters:
alpha (float) – Gaussian width parameter for overlap.
lam (float) – Weighting factor for ESP scoring. Scaled internally. Default is 0.3.
num_repeats (int) – Number of SE(3) initializations per pair. Default is 50.
trans_init (bool) – Apply translation initialization for alignment. Default is False.
lr (float) – Optimizer learning rate. Default is 0.1.
max_num_steps (int) – Maximum optimization steps. Default is 200.
use_jax (bool) – Whether to use JAX backend. Default is True.
use_analytical (bool) – Whether to use analytical gradients (PyTorch only). Default is True.
num_workers (int) – Number of parallel worker processes. 1 (default) runs sequentially in-process. Values greater than len(self.pairs) are clamped to len(self.pairs).
use_shmap (bool) – Whether to use JAX shard_map for parallel alignment. Default is False. Performance is better when use_shmap is False on CPU.
verbose (bool) – Print scores per pair. Default is False.
- Returns:
scores (np.ndarray) – Scores for each pair. Shape: (N,).
aligned_list (list of np.ndarray) – Aligned fit surface coordinates for each pair.
- Return type:
- align_with_pharm(similarity='tanimoto', extended_points=False, only_extended=False, num_repeats=50, trans_init=False, lr=0.1, max_num_steps=200, num_workers=1, use_shmap=True, num_buckets=1, verbose=False)[source]#
Align all pairs using padded masked pharmacophore similarity via JAX.
Because all padded arrays have the same shape, JAX’s XLA compiler reuses one compiled kernel for every pair — no recompilation overhead.
When num_workers > 1 the pairs are split into size-sorted chunks and processed in parallel. It is recommended to use use_shmap=True instead of multiprocessing for this setting.
Results are stored in-place on each MoleculePair:
- pair.transform_pharm and pair.sim_aligned_pharm
- Parameters:
similarity (str) – One of 'tanimoto', 'tversky', 'tversky_ref', 'tversky_fit'.
extended_points (bool) – Score HBA/HBD with extended-point Gaussians.
only_extended (bool) – When extended_points is True, ignore anchor overlaps.
num_repeats (int) – Number of SE(3) initializations per pair.
trans_init (bool) – If True, initialize translations to each ref pharmacophore anchor.
lr (float) – Optimizer learning rate.
max_num_steps (int) – Maximum optimization steps.
num_workers (int) – Number of parallel worker processes. 1 (default) runs sequentially in-process. Values greater than len(self.pairs) are clamped to len(self.pairs).
use_shmap (bool) – If True and num_workers > 1, use jax.shard_map + vmap to parallelise across virtual CPU devices in a single process. Requires XLA_FLAGS=--xla_force_host_platform_device_count=N to be set before any JAX import. Uses lax.scan (fixed steps, no early stopping) instead of the while_loop-based sequential path. Required on Linux HPC if num_workers > 1, where multiprocessing spawn can be unreliable with JAX. Default is True.
num_buckets (int) – 1 (default) pads all pairs to the global pharmacophore-count maximum, which has the lowest overhead for typical use. Values > 1 sort pairs by (max(ref,fit), min(ref,fit)) and process each bucket separately with reduced per-bucket padding, which can be beneficial for large heterogeneous molecule sets.
verbose (bool) – Print scores per pair.
- Returns:
scores (np.ndarray) – Shape: (N,).
aligned_anchors_list (list of np.ndarray) – Aligned fit pharmacophore anchors (unpadded) for each pair.
aligned_vectors_list (list of np.ndarray) – Aligned fit pharmacophore vectors (unpadded) for each pair.
- Return type: