MoleculePairBatch

MoleculePairBatch accepts a list of MoleculePair objects and aligns them efficiently by padding all atom/pharmacophore arrays to a common maximum length. Because every call shares the same padded array shape, JAX’s XLA compiler produces a single compiled kernel that is reused for every pair in the batch — avoiding the per-pair recompilation overhead that occurs when array shapes differ.

Note

While shard_map is recommended in all cases, it requires jax>=0.9.0 and therefore python>=3.11. Alternatively, set use_shmap=False to use multiprocessing with the 'spawn' context. However, this path is known not to work on Linux HPC environments and has only been tested on M-series Macs.
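Because the shard_map path depends on the version constraints in the note, it can be convenient to decide use_shmap at startup. The following is a minimal sketch that assumes the thresholds stated above (jax>=0.9.0, python>=3.11). The helpers parse_version and pick_use_shmap are hypothetical and not part of shepherd_score. Package metadata is read with importlib.metadata, so JAX itself is not imported and XLA_FLAGS can still be set afterwards.

```python
import re
import sys
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Sequence, Tuple


def parse_version(text: str) -> Tuple[int, ...]:
    """Keep the leading digits of the first three components: '0.10.1rc1' -> (0, 10, 1)."""
    parts = []
    for piece in text.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    return tuple(parts)


def pick_use_shmap(jax_version: Optional[str],
                   python_version: Sequence[int] = sys.version_info) -> bool:
    """Hypothetical helper: True only if jax>=0.9.0 and python>=3.11 (thresholds from the note)."""
    if jax_version is None:  # JAX is not installed at all
        return False
    return (parse_version(jax_version) >= (0, 9, 0)
            and tuple(python_version[:2]) >= (3, 11))


try:
    installed_jax = version("jax")  # reads package metadata only; does not import JAX
except PackageNotFoundError:
    installed_jax = None

use_shmap = pick_use_shmap(installed_jax)
print(f"jax={installed_jax}, use_shmap={use_shmap}")
```

The resulting flag can then be passed as align_with_vol(..., use_shmap=use_shmap); when it is False, the multiprocessing caveats above still apply.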

Masking strategy

_pad_arrays() pads each coordinate (or charge) array to max_len and produces a binary float32 mask (1.0 = real atom, 0.0 = padding). An outer-product pair mask

pair_mask = mask_fit[:, None] * mask_ref[None, :]

zeroes out all padding-atom contributions to the Gaussian overlap sum, self-overlaps, and gradients. Since padded shapes are fixed across all pairs, the compiled kernel is reused without recompilation.
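The masking scheme can be illustrated with a short NumPy sketch. It assumes a simplified Gaussian overlap, the sum over i, j of exp(-alpha/2 * |r_i - r_j|^2), with the constant prefactor omitted and alpha chosen arbitrarily. pad_with_mask and masked_overlap are hypothetical stand-ins, not the library's _pad_arrays() or its actual scoring kernel. The point is that the padded, masked score equals the unpadded score while the array shapes become fixed.

```python
import numpy as np


def pad_with_mask(coords: np.ndarray, max_len: int):
    """Hypothetical stand-in for _pad_arrays(): pad (n, 3) -> (max_len, 3) plus a float32 mask."""
    n = coords.shape[0]
    padded = np.zeros((max_len, coords.shape[1]), dtype=np.float32)
    padded[:n] = coords
    mask = np.zeros(max_len, dtype=np.float32)
    mask[:n] = 1.0  # 1.0 = real atom, 0.0 = padding
    return padded, mask


def masked_overlap(fit, ref, mask_fit, mask_ref, alpha=1.0):
    """Simplified overlap sum_ij exp(-alpha/2 * |r_i - r_j|^2); constant prefactor omitted."""
    d2 = np.sum((fit[:, None, :] - ref[None, :, :]) ** 2, axis=-1)  # (n_fit, n_ref)
    pair_mask = mask_fit[:, None] * mask_ref[None, :]  # outer product: 0 if either atom is padding
    return float(np.sum(pair_mask * np.exp(-0.5 * alpha * d2)))


rng = np.random.default_rng(0)
fit_raw, ref_raw = rng.normal(size=(5, 3)), rng.normal(size=(8, 3))

max_len = 10  # shared padded length, so every pair has the same array shape
fit, mask_fit = pad_with_mask(fit_raw, max_len)
ref, mask_ref = pad_with_mask(ref_raw, max_len)

unpadded_score = masked_overlap(fit_raw, ref_raw, np.ones(5), np.ones(8))
padded_score = masked_overlap(fit, ref, mask_fit, mask_ref)
print(np.isclose(padded_score, unpadded_score))
```

Padded atoms sit at the origin, so without pair_mask they would add spurious overlap terms; the mask removes them exactly.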

Parallel volumetric alignment via jax.shard_map

Note

XLA_FLAGS must be set before any JAX import. Place the environment variable assignments at the very top of your script, before any import jax or from shepherd_score ... statements.

align_with_vol() supports a use_shmap=True path that distributes pairs across virtual CPU devices using optimize_ROCS_overlay_jax_vol_shmap() instead of Python multiprocessing:

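import os
# Set before importing JAX or shepherd_score: expose 4 virtual CPU devices to XLA
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'
os.environ['JAX_PLATFORMS'] = 'cpu'

from shepherd_score.container import MoleculePairBatch

# `pairs` is a list of MoleculePair objects built beforehand
batch = MoleculePairBatch(pairs)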

# Sequential (default)
scores, aligned = batch.align_with_vol(num_workers=1)

# Parallel via shard_map
scores, aligned = batch.align_with_vol(num_workers=4, use_shmap=True)
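The snippet above assigns XLA_FLAGS outright, which discards any XLA flags already set in the environment (for example, by a job script). A minimal sketch of a hypothetical helper, set_host_device_count (not part of shepherd_score), that replaces only the device-count flag and keeps the rest. Like the original assignment, it only has an effect if it runs before JAX is imported.

```python
import os
from typing import MutableMapping, Optional

DEVICE_FLAG = "--xla_force_host_platform_device_count="


def set_host_device_count(n: int, env: Optional[MutableMapping[str, str]] = None) -> str:
    """Hypothetical helper: set the virtual CPU device count in XLA_FLAGS without
    discarding other XLA flags. Only takes effect if called before JAX is imported."""
    env = os.environ if env is None else env
    kept = [f for f in env.get("XLA_FLAGS", "").split() if not f.startswith(DEVICE_FLAG)]
    env["XLA_FLAGS"] = " ".join(kept + [f"{DEVICE_FLAG}{n}"])
    return env["XLA_FLAGS"]


set_host_device_count(4)
os.environ["JAX_PLATFORMS"] = "cpu"
# Only now import JAX-dependent code, e.g.:
# from shepherd_score.container import MoleculePairBatch
```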

Why shard_map instead of multiprocessing on HPC?

  • multiprocessing with 'spawn' context can be unreliable on Linux HPC (JAX initialisation in subprocesses, resource limits).

  • shard_map distributes work across virtual CPU devices within a single process without forking/spawning.
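shard_map splits a sharded array axis evenly across the devices of a mesh, so the number of pairs along that axis has to be divisible by the device count. The following NumPy sketch shows one way to lay out a batch by appending dummy pairs and tracking them with a validity mask. The helper shard_batch is hypothetical, and whether the library pads the pair axis this way is an assumption.

```python
import numpy as np


def shard_batch(batch: np.ndarray, n_devices: int):
    """Hypothetical layout helper: pad the pair axis to a multiple of n_devices with dummy
    pairs, then reshape to (n_devices, per_device, ...). Returns the array and a validity mask."""
    n = batch.shape[0]
    per_device = -(-n // n_devices)  # ceiling division
    total = per_device * n_devices
    dummy = np.zeros((total - n,) + batch.shape[1:], dtype=batch.dtype)
    padded = np.concatenate([batch, dummy], axis=0)
    valid = np.arange(total) < n  # False marks dummy pairs whose results are discarded
    return (padded.reshape((n_devices, per_device) + batch.shape[1:]),
            valid.reshape(n_devices, per_device))


# 10 pairs, each fit array already padded to 12 atoms, spread over 4 virtual devices
coords = np.ones((10, 12, 3), dtype=np.float32)
sharded, valid = shard_batch(coords, 4)
print(sharded.shape, int(valid.sum()))
```

Each device then processes its per_device slice with vmap, and results for dummy pairs are dropped using the mask.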

Bucketing for heterogeneous molecule sets

By default (num_buckets=1), all pairs are padded to the global atom-count maximum and processed in a single shard_map call (one JIT compilation, lowest overhead). For large heterogeneous sets, use num_buckets > 1: pairs are sorted by (max(ref, fit), min(ref, fit)) via np.lexsort, and each bucket is processed with its own local padding maximum. This reduces wasted computation at the cost of multiple sequential shard_map calls:

# Default: single pass, one compilation
scores, aligned = batch.align_with_vol(num_workers=4, use_shmap=True)

# Bucketed: useful for datasets with diverse molecule sizes
scores, aligned = batch.align_with_vol(num_workers=4, use_shmap=True, num_buckets=4)
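The sort order described above can be reproduced with np.lexsort, which treats its last key as the primary key. The following is a minimal sketch that assumes buckets are equal-count splits of the sorted order (via np.array_split) and uses a rough cost proxy of max_len**2 pairwise terms per pair. bucket_pairs and padded_cost are hypothetical helpers, and the library may split buckets differently.

```python
import numpy as np


def bucket_pairs(n_ref: np.ndarray, n_fit: np.ndarray, num_buckets: int):
    """Sort pairs by (max, min) atom count and split into buckets.

    Returns a list of (pair_indices, bucket_max_len). Equal-count splitting is an assumption.
    """
    hi = np.maximum(n_ref, n_fit)
    lo = np.minimum(n_ref, n_fit)
    order = np.lexsort((lo, hi))  # last key is primary: sort by hi, break ties by lo
    buckets = []
    for idx in np.array_split(order, num_buckets):
        if idx.size:
            buckets.append((idx, int(hi[idx].max())))
    return buckets


def padded_cost(buckets):
    """Rough cost proxy: each pair evaluates bucket_max_len**2 pairwise terms."""
    return sum(len(idx) * max_len ** 2 for idx, max_len in buckets)


n_ref = np.array([10, 45, 12, 30, 8, 50])
n_fit = np.array([9, 40, 15, 28, 7, 44])
one = bucket_pairs(n_ref, n_fit, 1)
three = bucket_pairs(n_ref, n_fit, 3)
print(padded_cost(one), padded_cost(three))
```

With one bucket every pair is padded to 50 atoms; with three buckets the small pairs are padded only to 10 or 30, which is where the savings for heterogeneous sets come from.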

Note

With use_shmap=True (and num_workers > 1), optimization runs inside lax.scan for a fixed number of steps with no convergence-based early stopping, so every pair runs max_num_steps iterations (default 200).
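The practical difference between the two loop styles can be seen in a plain-Python analogue. The sketch minimizes (x - 3)^2 by gradient descent; minimize_fixed mimics a lax.scan loop that always runs every step, and minimize_early_stop mimics a while_loop that exits once the update falls below a tolerance. Both functions and the tol value are illustrative assumptions, not the library's optimizer.

```python
def minimize_fixed(grad, x0, lr, num_steps):
    """Analogue of a lax.scan loop: always runs exactly num_steps updates."""
    x = x0
    for _ in range(num_steps):
        x = x - lr * grad(x)
    return x, num_steps


def minimize_early_stop(grad, x0, lr, max_num_steps, tol=1e-6):
    """Analogue of a while_loop: stops as soon as an update is smaller than tol."""
    x = x0
    for step in range(1, max_num_steps + 1):
        update = lr * grad(x)
        x = x - update
        if abs(update) < tol:
            return x, step
    return x, max_num_steps


grad = lambda x: 2.0 * (x - 3.0)  # gradient of (x - 3)^2
x_fixed, steps_fixed = minimize_fixed(grad, 0.0, 0.1, 200)
x_early, steps_early = minimize_early_stop(grad, 0.0, 0.1, 200)
print(steps_fixed, steps_early)
```

Both reach essentially the same minimum, but the fixed-step loop spends all 200 iterations. In exchange, every pair follows an identical control flow, which is what makes the scan form easy to vmap and shard.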

Available batch alignment methods

For the low-level parallel kernel see JAX Parallel Alignment. For the full scoring and alignment theory see Representations, Scoring, and Alignment.

Class reference

class shepherd_score.container._batch.MoleculePairBatch(pairs)

Bases: object

Batch of MoleculePair objects for fast sequential JAX alignment.

Pads all atom coordinate arrays to common maximum shapes so that JAX’s XLA compiler reuses the same compiled function for every pair, avoiding recompilation. Each MoleculePair is modified in place: alignment results are stored on the pair.

The class is currently optimized for CPU. A GPU-optimized version would benefit from optimizing whole batches of pairs at once and from using a GPU-optimized alignment routine.

Parameters:

pairs (List[MoleculePair])

align_with_vol(no_H=True, num_repeats=50, trans_init=False, lr=0.1, max_num_steps=200, num_workers=1, use_shmap=True, num_buckets=1, verbose=False)

Align all pairs using padded masked volumetric similarity via JAX.

Because all padded arrays have the same shape, JAX’s XLA compiler reuses one compiled kernel for every pair — no recompilation overhead.

When num_workers > 1, the pairs are split into size-sorted chunks and processed in parallel. In this setting, use_shmap=True is recommended over multiprocessing.

Results are stored in place on each MoleculePair:

  • pair.transform_vol_noH / pair.sim_aligned_vol_noH (when no_H=True)

  • pair.transform_vol / pair.sim_aligned_vol (when no_H=False)

Parameters:
  • no_H (bool) – Whether to exclude hydrogens. Default is True.

  • num_repeats (int) – Number of SE(3) initializations per pair. Default is 50.

  • trans_init (bool) – If True, initialize translations to each ref atom position. Default is False.

  • lr (float) – Optimizer learning rate. Default is 0.1.

  • max_num_steps (int) – Maximum optimization steps. Default is 200.

  • num_workers (int) – Number of parallel workers. 1 (default) runs sequentially in-process. When use_shmap=True (the default), this value is informational; the actual parallelism equals len(jax.devices()), which is fixed by XLA_FLAGS before JAX is first imported. When use_shmap=False, work is distributed with multiprocessing using the 'spawn' start method.

  • use_shmap (bool) – If True and num_workers > 1, use jax.shard_map + vmap to parallelise across virtual CPU devices in a single process. Requires XLA_FLAGS=--xla_force_host_platform_device_count=N to be set before any JAX import. Uses lax.scan (fixed steps, no early stopping) instead of the while_loop-based sequential path. Required on Linux HPC when num_workers > 1, because multiprocessing with 'spawn' can be unreliable with JAX there. Default is True.

  • num_buckets (int) – 1 (default) pads all pairs to the global atom-count maximum — lowest overhead for typical use. Values > 1 sort pairs by (max(ref,fit), min(ref,fit)) and process each bucket separately with reduced per-bucket padding, which can be beneficial for large heterogeneous molecule sets.

  • verbose (bool) – Print scores per pair. Default is False.

Returns:

  • scores (np.ndarray) – Scores for each pair. Shape: (N,).

  • aligned_list (list of np.ndarray) – Aligned fit atom coordinates (unpadded) for each pair.

Return type:

Tuple[ndarray, List[ndarray]]
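aligned_list holds unpadded coordinates, while the compiled kernel works on padded arrays. The following is a minimal sketch of how per-pair arrays can be recovered from a padded batch using the float masks described under Masking strategy. unpad_batch is a hypothetical helper, not part of the MoleculePairBatch API, and it assumes the real atoms occupy the leading rows of each padded array.

```python
import numpy as np


def unpad_batch(padded: np.ndarray, masks: np.ndarray):
    """Hypothetical helper: turn a padded (N, max_len, 3) batch into a list of (n_i, 3)
    arrays, keeping only rows whose mask is 1.0 (real atoms)."""
    return [coords[mask > 0.5] for coords, mask in zip(padded, masks)]


padded = np.arange(2 * 4 * 3, dtype=np.float32).reshape(2, 4, 3)
masks = np.array([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=np.float32)
aligned_list = unpad_batch(padded, masks)
print([a.shape for a in aligned_list])
```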

align_with_vol_esp(lam, no_H=True, num_repeats=50, trans_init=False, lr=0.1, max_num_steps=200, num_workers=1, use_shmap=True, num_buckets=1, verbose=False)

Align all pairs using padded masked volumetric ESP similarity via JAX.

Because all padded arrays have the same shape, JAX’s XLA compiler reuses one compiled kernel for every pair — no recompilation overhead.

When num_workers > 1, the pairs are split into size-sorted chunks and processed in parallel. In this setting, use_shmap=True is recommended over multiprocessing.

Results are stored in place on each MoleculePair:

  • pair.transform_vol_esp_noH / pair.sim_aligned_vol_esp_noH (when no_H=True)

  • pair.transform_vol_esp / pair.sim_aligned_vol_esp (when no_H=False)

Parameters:
  • lam (float) – Partial charge weighting parameter. Typically 0.1 for volumetric.

  • no_H (bool) – Whether to exclude hydrogens. Default is True.

  • num_repeats (int) – Number of SE(3) initializations per pair. Default is 50.

  • trans_init (bool) – If True, initialize translations to each ref atom position. Default is False.

  • lr (float) – Optimizer learning rate. Default is 0.1.

  • max_num_steps (int) – Maximum optimization steps. Default is 200.

  • num_workers (int) – Number of parallel worker processes. 1 (default) runs sequentially in-process. Values greater than len(self.pairs) are clamped to len(self.pairs).

  • use_shmap (bool) – If True and num_workers > 1, use jax.shard_map + vmap to parallelise across virtual CPU devices in a single process. Requires XLA_FLAGS=--xla_force_host_platform_device_count=N to be set before any JAX import. Uses lax.scan (fixed steps, no early stopping) instead of the while_loop-based sequential path. Required on Linux HPC when num_workers > 1, because multiprocessing with 'spawn' can be unreliable with JAX there. Default is True.

  • num_buckets (int) – 1 (default) pads all pairs to the global atom-count maximum — lowest overhead for typical use. Values > 1 sort pairs by (max(ref,fit), min(ref,fit)) and process each bucket separately with reduced per-bucket padding, which can be beneficial for large heterogeneous molecule sets.

  • verbose (bool) – Print scores per pair. Default is False.

Returns:

  • scores (np.ndarray) – Scores for each pair. Shape: (N,).

  • aligned_list (list of np.ndarray) – Aligned fit atom coordinates (unpadded) for each pair.

Return type:

Tuple[ndarray, List[ndarray]]

align_with_surf(alpha, num_repeats=50, trans_init=False, lr=0.1, max_num_steps=200, use_jax=True, use_analytical=True, num_workers=1, use_shmap=False, verbose=False)

Align all pairs using surface similarity.

Surface arrays have the same size across all pairs, so no padding or size-sorting is needed. For the same reason, multiprocessing is not recommended.

Results are stored in place on each MoleculePair:

  • pair.transform_surf and pair.sim_aligned_surf

Parameters:
  • alpha (float) – Gaussian width parameter for overlap.

  • num_repeats (int) – Number of SE(3) initializations per pair. Default is 50.

  • trans_init (bool) – Apply translation initialization for alignment. Default is False.

  • lr (float) – Optimizer learning rate. Default is 0.1.

  • max_num_steps (int) – Maximum optimization steps. Default is 200.

  • use_jax (bool) – Whether to use JAX backend. Default is True.

  • use_analytical (bool) – Whether to use analytical gradients (PyTorch only). Default is True.

  • num_workers (int) – Number of parallel worker processes. 1 (default) runs sequentially in-process. Values greater than len(self.pairs) are clamped to len(self.pairs).

  • use_shmap (bool) – Whether to use JAX shard_map for parallel alignment. Default is False. On CPU, performance is better with use_shmap=False.

  • verbose (bool) – Print scores per pair. Default is False.

Returns:

  • scores (np.ndarray) – Scores for each pair. Shape: (N,).

  • aligned_list (list of np.ndarray) – Aligned fit surface coordinates for each pair.

Return type:

Tuple[ndarray, List[ndarray]]

align_with_esp(alpha, lam=0.3, num_repeats=50, trans_init=False, lr=0.1, max_num_steps=200, use_jax=True, use_analytical=True, num_workers=1, use_shmap=False, verbose=False)

Align all pairs using ESP+surface similarity.

Surface arrays have the same size across all pairs, so no padding or size-sorting is needed. For the same reason, multiprocessing is not recommended.

Results are stored in place on each MoleculePair:

  • pair.transform_esp and pair.sim_aligned_esp

Parameters:
  • alpha (float) – Gaussian width parameter for overlap.

  • lam (float) – Weighting factor for ESP scoring. Scaled internally. Default is 0.3.

  • num_repeats (int) – Number of SE(3) initializations per pair. Default is 50.

  • trans_init (bool) – Apply translation initialization for alignment. Default is False.

  • lr (float) – Optimizer learning rate. Default is 0.1.

  • max_num_steps (int) – Maximum optimization steps. Default is 200.

  • use_jax (bool) – Whether to use JAX backend. Default is True.

  • use_analytical (bool) – Whether to use analytical gradients (PyTorch only). Default is True.

  • num_workers (int) – Number of parallel worker processes. 1 (default) runs sequentially in-process. Values greater than len(self.pairs) are clamped to len(self.pairs).

  • use_shmap (bool) – Whether to use JAX shard_map for parallel alignment. Default is False. On CPU, performance is better with use_shmap=False.

  • verbose (bool) – Print scores per pair. Default is False.

Returns:

  • scores (np.ndarray) – Scores for each pair. Shape: (N,).

  • aligned_list (list of np.ndarray) – Aligned fit surface coordinates for each pair.

Return type:

Tuple[ndarray, List[ndarray]]

align_with_pharm(similarity='tanimoto', extended_points=False, only_extended=False, num_repeats=50, trans_init=False, lr=0.1, max_num_steps=200, num_workers=1, use_shmap=True, num_buckets=1, verbose=False)

Align all pairs using padded masked pharmacophore similarity via JAX.

Because all padded arrays have the same shape, JAX’s XLA compiler reuses one compiled kernel for every pair — no recompilation overhead.

When num_workers > 1, the pairs are split into size-sorted chunks and processed in parallel. In this setting, use_shmap=True is recommended over multiprocessing.

Results are stored in place on each MoleculePair:

  • pair.transform_pharm and pair.sim_aligned_pharm

Parameters:
  • similarity (str) – One of 'tanimoto', 'tversky', 'tversky_ref', 'tversky_fit'. Default is 'tanimoto'.

  • extended_points (bool) – Score HBA/HBD with extended-point Gaussians. Default is False.

  • only_extended (bool) – When extended_points is True, ignore anchor overlaps. Default is False.

  • num_repeats (int) – Number of SE(3) initializations per pair. Default is 50.

  • trans_init (bool) – If True, initialize translations to each ref pharmacophore anchor. Default is False.

  • lr (float) – Optimizer learning rate. Default is 0.1.

  • max_num_steps (int) – Maximum optimization steps. Default is 200.

  • num_workers (int) – Number of parallel worker processes. 1 (default) runs sequentially in-process. Values greater than len(self.pairs) are clamped to len(self.pairs).

  • use_shmap (bool) – If True and num_workers > 1, use jax.shard_map + vmap to parallelise across virtual CPU devices in a single process. Requires XLA_FLAGS=--xla_force_host_platform_device_count=N to be set before any JAX import. Uses lax.scan (fixed steps, no early stopping) instead of the while_loop-based sequential path. Required on Linux HPC when num_workers > 1, because multiprocessing with 'spawn' can be unreliable with JAX there. Default is True.

  • num_buckets (int) – 1 (default) pads all pairs to the global pharmacophore-count maximum, which has the lowest overhead for typical use. Values > 1 sort pairs by (max(ref,fit), min(ref,fit)) and process each bucket separately with reduced per-bucket padding, which can be beneficial for large heterogeneous molecule sets.

  • verbose (bool) – Print scores per pair. Default is False.

Returns:

  • scores (np.ndarray) – Scores for each pair. Shape: (N,).

  • aligned_anchors_list (list of np.ndarray) – Aligned fit pharmacophore anchors (unpadded) for each pair.

  • aligned_vectors_list (list of np.ndarray) – Aligned fit pharmacophore vectors (unpadded) for each pair.

Return type:

Tuple[ndarray, List[ndarray], List[ndarray]]
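The similarity options combine a cross-overlap O_AB with the self-overlaps O_AA and O_BB. The following is a minimal sketch that uses the standard Gaussian-overlap Tanimoto, O_AB / (O_AA + O_BB - O_AB), and the generic Tversky index with weights alpha and beta. Both functions are hypothetical helpers, the overlap values are made up, and the alpha/beta values that the library uses for 'tversky', 'tversky_ref', and 'tversky_fit' are not specified here, so the weights shown are only illustrative.

```python
def tanimoto(o_ab: float, o_aa: float, o_bb: float) -> float:
    """Gaussian-overlap Tanimoto: O_AB / (O_AA + O_BB - O_AB)."""
    return o_ab / (o_aa + o_bb - o_ab)


def tversky(o_ab: float, o_aa: float, o_bb: float,
            alpha: float = 0.5, beta: float = 0.5) -> float:
    """Generic Tversky index on overlaps. alpha weights A's unmatched overlap, beta weights B's.
    alpha = beta = 1 recovers Tanimoto; alpha = beta = 0.5 gives 2*O_AB / (O_AA + O_BB)."""
    return o_ab / (o_ab + alpha * (o_aa - o_ab) + beta * (o_bb - o_ab))


o_ab, o_aa, o_bb = 2.0, 4.0, 8.0  # made-up overlaps for illustration
print(tanimoto(o_ab, o_aa, o_bb), tversky(o_ab, o_aa, o_bb, alpha=1.0, beta=0.0))
```

Setting one weight to zero makes the score asymmetric: with alpha=1, beta=0 it measures how much of A's self-overlap is matched, regardless of how large B is.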