Source code for shepherd_score.container._batch

"""MoleculePairBatch: batch of MoleculePair objects for fast sequential JAX alignment."""
from importlib.metadata import version as _pkg_version
from typing import List, Tuple

import numpy as np

from shepherd_score.container._core import MoleculePair
from shepherd_score.container._batch_utils import (
    _pad_arrays,
    _dispatch_parallel,
    _align_vol_shmap,
    _align_vol_esp_shmap,
    _align_surf_shmap,
    _align_esp_shmap,
    _align_pharm_shmap,
    _align_vol_worker,
    _align_vol_esp_worker,
    _align_surf_worker,
    _align_esp_worker,
    _align_pharm_worker,
)


def _compute_bucket_splits(sizes_a, sizes_b, num_buckets):
    """Sort pairs by (max(a,b), min(a,b)) and split into buckets.

    Parameters
    ----------
    sizes_a, sizes_b : array-like of int
        Per-pair sizes (e.g. atom counts) for the two molecules.
    num_buckets : int
        Number of buckets.  ``<= 1`` returns a single bucket with all
        pairs in their original order (no sorting).  Values larger than
        the number of pairs are clamped to the number of pairs.

    Returns
    -------
    list of list of int
        Each inner list is a bucket of global pair indices.  When
        ``num_buckets > 1`` and there are no pairs, the result is an
        empty list (there is nothing to put in any bucket).
    """
    n = len(sizes_a)
    if num_buckets <= 1:
        return [list(range(n))]
    if n == 0:
        # min(num_buckets, 0) == 0 and np.array_split rejects zero sections
        # with a ValueError, so short-circuit the empty case.
        return []
    sizes_a = np.asarray(sizes_a)
    sizes_b = np.asarray(sizes_b)
    # np.lexsort uses the LAST key as the primary sort key, so the max size
    # is primary and the min size breaks ties.
    sorted_order = np.lexsort(
        (np.minimum(sizes_a, sizes_b), np.maximum(sizes_a, sizes_b))
    )
    num_buckets_actual = min(num_buckets, n)
    return [
        chunk.tolist()
        for chunk in np.array_split(sorted_order, num_buckets_actual)
        if len(chunk) > 0
    ]


class MoleculePairBatch:
    """Batch of MoleculePair objects for fast sequential JAX alignment.

    Pads all atom coordinate arrays to common max shapes so JAX's XLA compiler
    reuses the same compiled function for every pair, avoiding recompilation.
    This modifies each MoleculePair in-place (stores results on the pair).

    This is currently optimized for CPU. A GPU-optimized version would benefit
    from optimizing batches of pairs and using a GPU-optimized alignment.
    """

    # Index of 'Dummy' in P_TYPES; fills the padded pharmacophore type slots.
    _DUMMY_PHARM_TYPE = 8

    def __init__(self, pairs: "List[MoleculePair]"):
        self.pairs = pairs

    # ------------------------------------------------------------------
    # Private helpers shared by the alignment methods
    # ------------------------------------------------------------------

    @staticmethod
    def _require_jax_for_shmap() -> None:
        """Raise RuntimeError if the installed JAX is older than 0.9."""
        _jax_ver = _pkg_version("jax")
        _jax_ver_tuple = tuple(int(x) for x in _jax_ver.split(".")[:2])
        if _jax_ver_tuple < (0, 9):
            raise RuntimeError(
                f"use_shmap=True requires JAX >= 0.9.0, but found JAX {_jax_ver}. "
                "Either upgrade JAX (which requires Python >= 3.11) or set use_shmap=False."
            )

    @staticmethod
    def _import_jnp(method_name: str):
        """Import and return ``jax.numpy``; raise a helpful ImportError if missing."""
        try:
            import jax.numpy as jnp
        except ImportError as exc:
            raise ImportError(
                f'JAX is required for MoleculePairBatch.{method_name}. '
                'Install it with: pip install "shepherd-score[jax]"'
            ) from exc
        return jnp

    @staticmethod
    def _positions(molec, no_H: bool):
        """Return heavy-atom positions (``no_H``) or all conformer positions as float32."""
        if no_H:
            return molec.atom_pos
        return molec.mol.GetConformer().GetPositions().astype(np.float32)

    @staticmethod
    def _charges(molec, no_H: bool):
        """Return partial charges, restricted to non-H atoms when ``no_H`` is True."""
        if no_H:
            return molec.partial_charges[molec._nonH_atoms_idx]
        return molec.partial_charges

    @staticmethod
    def _size_sort_keys(refs, fits) -> np.ndarray:
        """Build the 2-row (min, max) size key array handed to ``_dispatch_parallel``."""
        ref_sizes = np.array([len(r) for r in refs])
        fit_sizes = np.array([len(f) for f in fits])
        # Primary key: max(ref, fit) — dominates padding; secondary: min.
        return np.array([np.minimum(ref_sizes, fit_sizes),
                         np.maximum(ref_sizes, fit_sizes)])

    @staticmethod
    def _iter_chunked(index_splits, chunk_results):
        """Yield ``(global_index, result)`` from per-chunk index and result lists."""
        for idx_list, chunk_result in zip(index_splits, chunk_results):
            yield from zip(idx_list, chunk_result)

    def _store_result(self, index: int, transform_attr: str, score_attr: str,
                      se3_transform, score) -> None:
        """Write one pair's transform and score onto the pair object in-place."""
        pair = self.pairs[index]
        setattr(pair, transform_attr, se3_transform)
        setattr(pair, score_attr, score)

    def _gather_results(self, indexed_results, transform_attr: str, score_attr: str):
        """Collect ``(index, (score, transform, aligned))`` items into outputs.

        Fills a score array and an aligned-points list at each global index
        and stores the transform/score on the corresponding pair.
        """
        n_pairs = len(self.pairs)
        scores = np.zeros(n_pairs)
        aligned_list = [None] * n_pairs
        for global_i, (score, se3_transform, aligned_pts) in indexed_results:
            scores[global_i] = score
            aligned_list[global_i] = aligned_pts
            self._store_result(global_i, transform_attr, score_attr,
                               se3_transform, score)
        return scores, aligned_list

    def _validate_pharm_data(self) -> None:
        """Raise ValueError if any pair lacks pharmacophore types."""
        for i, pair in enumerate(self.pairs):
            if (pair.ref_molec.pharm_types is None
                    or pair.fit_molec.pharm_types is None):
                raise ValueError(
                    f'Pair {i} is missing pharmacophore data. '
                    'Create Molecule objects with pharm_multi_vector set to True or False.'
                )

    @classmethod
    def _pad_pharm_types(cls, types, orig_len: int, padded_len: int) -> np.ndarray:
        """Return an int32 type array of ``padded_len`` filled with the dummy type."""
        padded = np.full(padded_len, cls._DUMMY_PHARM_TYPE, dtype=np.int32)
        padded[:orig_len] = types
        return padded

    # ------------------------------------------------------------------
    # Volumetric alignment
    # ------------------------------------------------------------------

    def _pad_and_mask_vol(self, no_H: bool = True, include_charges: bool = False):
        """Extract, pad, and create masks for volumetric (and optionally ESP) alignment.

        Does NOT modify the pair objects. Returns padded arrays and masks.

        Parameters
        ----------
        no_H : bool
            If True, use heavy-atom positions (atom_pos). If False, use all atoms.
        include_charges : bool
            If True, also extract and pad partial charge arrays. The returned
            tuple per entry gains two extra elements:
            ``(ref_pos_pad, fit_pos_pad, ref_ch_pad, fit_ch_pad, mask_ref,
            mask_fit, orig_ref, orig_fit)``.
            If False, each entry is
            ``(ref_padded, fit_padded, mask_ref, mask_fit, orig_ref, orig_fit)``.

        Returns
        -------
        entries : list of tuples
        max_ref_len : int
        max_fit_len : int
        """
        ref_pos_arrays = [self._positions(p.ref_molec, no_H) for p in self.pairs]
        fit_pos_arrays = [self._positions(p.fit_molec, no_H) for p in self.pairs]
        if include_charges:
            ref_ch_arrays = [self._charges(p.ref_molec, no_H) for p in self.pairs]
            fit_ch_arrays = [self._charges(p.fit_molec, no_H) for p in self.pairs]

        ref_padded, masks_ref, orig_refs, max_ref_len = _pad_arrays(ref_pos_arrays)
        fit_padded, masks_fit, orig_fits, max_fit_len = _pad_arrays(fit_pos_arrays)

        if include_charges:
            ref_ch_padded, _, _, _ = _pad_arrays(ref_ch_arrays)
            fit_ch_padded, _, _, _ = _pad_arrays(fit_ch_arrays)
            entries = list(zip(
                ref_padded, fit_padded, ref_ch_padded, fit_ch_padded,
                masks_ref, masks_fit, orig_refs, orig_fits,
            ))
        else:
            entries = list(zip(
                ref_padded, fit_padded, masks_ref, masks_fit, orig_refs, orig_fits,
            ))
        return entries, max_ref_len, max_fit_len

    def align_with_vol(self,
                       no_H: bool = True,
                       num_repeats: int = 50,
                       trans_init: bool = False,
                       lr: float = 0.1,
                       max_num_steps: int = 200,
                       num_workers: int = 1,
                       use_shmap: bool = True,
                       num_buckets: int = 1,
                       verbose: bool = False,
                       ) -> Tuple[np.ndarray, List[np.ndarray]]:
        """Align all pairs using padded masked volumetric similarity via JAX.

        Because all padded arrays have the same shape, JAX's XLA compiler reuses
        one compiled kernel for every pair — no recompilation overhead.

        When ``num_workers > 1`` the pairs are split into size-sorted chunks and
        processed in parallel. It is recommended to use ``use_shmap=True``
        instead of ``multiprocessing`` for this setting.

        Results are stored in-place on each MoleculePair:

        - ``pair.transform_vol_noH`` / ``pair.sim_aligned_vol_noH`` (when ``no_H=True``)
        - ``pair.transform_vol`` / ``pair.sim_aligned_vol`` (when ``no_H=False``)

        Parameters
        ----------
        no_H : bool
            Whether to exclude hydrogens. Default is True.
        num_repeats : int
            Number of SE(3) initializations per pair. Default is 50.
        trans_init : bool
            If True, initialize translations to each ref atom position.
            Default is False.
        lr : float
            Optimizer learning rate. Default is 0.1.
        max_num_steps : int
            Maximum optimization steps. Default is 200.
        num_workers : int
            Number of parallel workers. ``1`` (default) runs sequentially
            in-process. When ``use_shmap=True`` (the default), this value is
            informational; actual parallelism equals ``len(jax.devices())``,
            which is set by ``XLA_FLAGS`` **before** JAX is first imported.
            When ``use_shmap=False`` use ``multiprocessing`` with a ``'spawn'``
            start method.
        use_shmap : bool
            If ``True`` and ``num_workers > 1``, use ``jax.shard_map`` + ``vmap``
            to parallelise across virtual CPU devices in a single process.
            Requires ``XLA_FLAGS=--xla_force_host_platform_device_count=N`` to
            be set before any JAX import. Uses ``lax.scan`` (fixed steps, no
            early stopping) instead of the ``while_loop``-based sequential path.
            Required on Linux HPC if num_workers > 1 where ``multiprocessing``
            spawn can be unreliable with JAX. Default is ``True``.
        num_buckets : int
            ``1`` (default) pads all pairs to the global atom-count maximum —
            lowest overhead for typical use. Values > 1 sort pairs by
            ``(max(ref,fit), min(ref,fit))`` and process each bucket separately
            with reduced per-bucket padding, which can be beneficial for large
            heterogeneous molecule sets.
        verbose : bool
            Print scores per pair. Default is False.

        Returns
        -------
        scores : np.ndarray
            Scores for each pair. Shape: (N,).
        aligned_list : list of np.ndarray
            Aligned fit atom coordinates (unpadded) for each pair.

        Raises
        ------
        RuntimeError
            If the shard_map path is requested with JAX older than 0.9.
        ImportError
            If the sequential path is taken and JAX cannot be imported.
        """
        # Raw (unpadded) position arrays for every pair (plain numpy — picklable).
        raw_refs = [self._positions(p.ref_molec, no_H) for p in self.pairs]
        raw_fits = [self._positions(p.fit_molec, no_H) for p in self.pairs]
        trans_centers_list = [ref if trans_init else None for ref in raw_refs]

        if no_H:
            transform_attr, score_attr = 'transform_vol_noH', 'sim_aligned_vol_noH'
        else:
            transform_attr, score_attr = 'transform_vol', 'sim_aligned_vol'

        if num_workers > 1:
            pair_data = list(zip(raw_refs, raw_fits, trans_centers_list))
            if use_shmap:
                # shard_map path (single process, multi-device)
                self._require_jax_for_shmap()
                results = _align_vol_shmap(
                    pair_data, num_workers,
                    num_repeats, lr, max_num_steps, verbose,
                    num_buckets=num_buckets,
                )
                indexed = enumerate(results)
            else:
                # multiprocessing path
                index_splits, chunk_results = _dispatch_parallel(
                    pair_data, self._size_sort_keys(raw_refs, raw_fits),
                    _align_vol_worker, num_workers,
                    (num_repeats, lr, max_num_steps, verbose),
                )
                indexed = self._iter_chunked(index_splits, chunk_results)
            return self._gather_results(indexed, transform_attr, score_attr)

        # sequential path
        jnp = self._import_jnp('align_with_vol')
        from shepherd_score.alignment_jax import optimize_ROCS_overlay_jax_mask

        n_pairs = len(self.pairs)
        scores = np.zeros(n_pairs)
        aligned_list = [None] * n_pairs
        ref_sizes = np.array([len(r) for r in raw_refs])
        fit_sizes = np.array([len(f) for f in raw_fits])

        for bucket_idx_list in _compute_bucket_splits(ref_sizes, fit_sizes, num_buckets):
            if not bucket_idx_list:
                # Empty batch: nothing to pad or align.
                continue
            ref_padded_b, masks_ref_b, _, _ = _pad_arrays(
                [raw_refs[i] for i in bucket_idx_list])
            fit_padded_b, masks_fit_b, orig_fits_b, _ = _pad_arrays(
                [raw_fits[i] for i in bucket_idx_list])

            for local_j, global_i in enumerate(bucket_idx_list):
                aligned_pts, se3_transform, score = optimize_ROCS_overlay_jax_mask(
                    ref_points=jnp.array(ref_padded_b[local_j]),
                    fit_points=jnp.array(fit_padded_b[local_j]),
                    mask_ref=jnp.array(masks_ref_b[local_j]),
                    mask_fit=jnp.array(masks_fit_b[local_j]),
                    alpha=0.81,
                    num_repeats=num_repeats,
                    trans_centers=trans_centers_list[global_i],
                    lr=lr,
                    max_num_steps=max_num_steps,
                    verbose=verbose,
                )
                score = float(np.array(score))
                scores[global_i] = score
                # Drop the padded rows so callers get the original atom count.
                aligned_list[global_i] = np.array(aligned_pts)[:orig_fits_b[local_j]]
                self._store_result(global_i, transform_attr, score_attr,
                                   np.array(se3_transform), score)

        return scores, aligned_list

    def align_with_vol_esp(self,
                           lam: float,
                           no_H: bool = True,
                           num_repeats: int = 50,
                           trans_init: bool = False,
                           lr: float = 0.1,
                           max_num_steps: int = 200,
                           num_workers: int = 1,
                           use_shmap: bool = True,
                           num_buckets: int = 1,
                           verbose: bool = False,
                           ) -> Tuple[np.ndarray, List[np.ndarray]]:
        """Align all pairs using padded masked volumetric ESP similarity via JAX.

        Because all padded arrays have the same shape, JAX's XLA compiler reuses
        one compiled kernel for every pair — no recompilation overhead.

        When ``num_workers > 1`` the pairs are split into size-sorted chunks and
        processed in parallel. It is recommended to use ``use_shmap=True``
        instead of ``multiprocessing`` for this setting.

        Results are stored in-place on each MoleculePair:

        - ``pair.transform_vol_esp_noH`` / ``pair.sim_aligned_vol_esp_noH``
          (when ``no_H=True``)
        - ``pair.transform_vol_esp`` / ``pair.sim_aligned_vol_esp``
          (when ``no_H=False``)

        Parameters
        ----------
        lam : float
            Partial charge weighting parameter. Typically 0.1 for volumetric.
        no_H : bool
            Whether to exclude hydrogens. Default is True.
        num_repeats : int
            Number of SE(3) initializations per pair. Default is 50.
        trans_init : bool
            If True, initialize translations to each ref atom position.
            Default is False.
        lr : float
            Optimizer learning rate. Default is 0.1.
        max_num_steps : int
            Maximum optimization steps. Default is 200.
        num_workers : int
            Number of parallel worker processes. ``1`` (default) runs
            sequentially in-process. Values greater than ``len(self.pairs)``
            are clamped to ``len(self.pairs)``.
        use_shmap : bool
            If ``True`` and ``num_workers > 1``, use ``jax.shard_map`` + ``vmap``
            to parallelise across virtual CPU devices in a single process.
            Requires ``XLA_FLAGS=--xla_force_host_platform_device_count=N`` to
            be set before any JAX import. Uses ``lax.scan`` (fixed steps, no
            early stopping) instead of the ``while_loop``-based sequential path.
            Required on Linux HPC if num_workers > 1 where ``multiprocessing``
            spawn can be unreliable with JAX. Default is ``True``.
        num_buckets : int
            ``1`` (default) pads all pairs to the global atom-count maximum —
            lowest overhead for typical use. Values > 1 sort pairs by
            ``(max(ref,fit), min(ref,fit))`` and process each bucket separately
            with reduced per-bucket padding, which can be beneficial for large
            heterogeneous molecule sets.
        verbose : bool
            Print scores per pair. Default is False.

        Returns
        -------
        scores : np.ndarray
            Scores for each pair. Shape: (N,).
        aligned_list : list of np.ndarray
            Aligned fit atom coordinates (unpadded) for each pair.

        Raises
        ------
        RuntimeError
            If the shard_map path is requested with JAX older than 0.9.
        ImportError
            If the sequential path is taken and JAX cannot be imported.
        """
        # Raw (unpadded) per-pair arrays (plain numpy — picklable).
        raw_refs = [self._positions(p.ref_molec, no_H) for p in self.pairs]
        raw_fits = [self._positions(p.fit_molec, no_H) for p in self.pairs]
        raw_ref_ch = [self._charges(p.ref_molec, no_H) for p in self.pairs]
        raw_fit_ch = [self._charges(p.fit_molec, no_H) for p in self.pairs]
        trans_centers_list = [ref if trans_init else None for ref in raw_refs]

        if no_H:
            transform_attr, score_attr = (
                'transform_vol_esp_noH', 'sim_aligned_vol_esp_noH')
        else:
            transform_attr, score_attr = 'transform_vol_esp', 'sim_aligned_vol_esp'

        if num_workers > 1:
            pair_data = list(zip(raw_refs, raw_fits, raw_ref_ch, raw_fit_ch,
                                 trans_centers_list))
            if use_shmap:
                # shard_map path
                self._require_jax_for_shmap()
                results = _align_vol_esp_shmap(
                    pair_data, num_workers,
                    lam, num_repeats, lr, max_num_steps, verbose,
                    num_buckets=num_buckets,
                )
                indexed = enumerate(results)
            else:
                # multiprocessing path
                index_splits, chunk_results = _dispatch_parallel(
                    pair_data, self._size_sort_keys(raw_refs, raw_fits),
                    _align_vol_esp_worker, num_workers,
                    (lam, num_repeats, lr, max_num_steps, verbose),
                )
                indexed = self._iter_chunked(index_splits, chunk_results)
            return self._gather_results(indexed, transform_attr, score_attr)

        # sequential path
        jnp = self._import_jnp('align_with_vol_esp')
        from shepherd_score.alignment_jax import optimize_ROCS_esp_overlay_jax_mask

        n_pairs = len(self.pairs)
        scores = np.zeros(n_pairs)
        aligned_list = [None] * n_pairs
        ref_sizes = np.array([len(r) for r in raw_refs])
        fit_sizes = np.array([len(f) for f in raw_fits])

        for bucket_idx_list in _compute_bucket_splits(ref_sizes, fit_sizes, num_buckets):
            if not bucket_idx_list:
                # Empty batch: nothing to pad or align.
                continue
            ref_padded_b, masks_ref_b, _, _ = _pad_arrays(
                [raw_refs[i] for i in bucket_idx_list])
            fit_padded_b, masks_fit_b, orig_fits_b, _ = _pad_arrays(
                [raw_fits[i] for i in bucket_idx_list])
            ref_ch_padded_b, _, _, _ = _pad_arrays(
                [raw_ref_ch[i] for i in bucket_idx_list])
            fit_ch_padded_b, _, _, _ = _pad_arrays(
                [raw_fit_ch[i] for i in bucket_idx_list])

            for local_j, global_i in enumerate(bucket_idx_list):
                aligned_pts, se3_transform, score = optimize_ROCS_esp_overlay_jax_mask(
                    ref_points=jnp.array(ref_padded_b[local_j]),
                    fit_points=jnp.array(fit_padded_b[local_j]),
                    ref_charges=jnp.array(ref_ch_padded_b[local_j]),
                    fit_charges=jnp.array(fit_ch_padded_b[local_j]),
                    mask_ref=jnp.array(masks_ref_b[local_j]),
                    mask_fit=jnp.array(masks_fit_b[local_j]),
                    alpha=0.81,
                    lam=lam,
                    num_repeats=num_repeats,
                    trans_centers=trans_centers_list[global_i],
                    lr=lr,
                    max_num_steps=max_num_steps,
                    verbose=verbose,
                )
                score = float(np.array(score))
                scores[global_i] = score
                # Drop the padded rows so callers get the original atom count.
                aligned_list[global_i] = np.array(aligned_pts)[:orig_fits_b[local_j]]
                self._store_result(global_i, transform_attr, score_attr,
                                   np.array(se3_transform), score)

        return scores, aligned_list

    # ------------------------------------------------------------------
    # Surface / ESP alignment
    # ------------------------------------------------------------------

    def _delegate_alignment(self, method_name: str, score_attr: str, **kwargs):
        """Delegate alignment to each MoleculePair's method and collect results.

        Parameters
        ----------
        method_name : str
            Name of the MoleculePair method to call (e.g. 'align_with_surf').
        score_attr : str
            Name of the attribute on MoleculePair where the score is stored
            after alignment.
        **kwargs
            Forwarded to each pair's method.

        Returns
        -------
        scores : np.ndarray
            Shape: (N,).
        aligned_list : list of np.ndarray
            Aligned fit coordinates for each pair.
        """
        aligned_list = []
        scores = np.zeros(len(self.pairs))
        for i, pair in enumerate(self.pairs):
            aligned_pts = getattr(pair, method_name)(**kwargs)
            # The pair's method stores its score on the pair; read it back.
            scores[i] = float(getattr(pair, score_attr))
            aligned_list.append(aligned_pts)
        return scores, aligned_list

    def align_with_surf(self,
                        alpha: float,
                        num_repeats: int = 50,
                        trans_init: bool = False,
                        lr: float = 0.1,
                        max_num_steps: int = 200,
                        use_jax: bool = True,
                        use_analytical: bool = True,
                        num_workers: int = 1,
                        use_shmap: bool = False,
                        verbose: bool = False,
                        ) -> Tuple[np.ndarray, List[np.ndarray]]:
        """Align all pairs using surface similarity.

        Surface arrays are the same size across all pairs so no padding or
        size-sorting is needed. It is not recommended to use multiprocessing
        due to this reason.

        Results are stored in-place on each MoleculePair:

        - ``pair.transform_surf`` and ``pair.sim_aligned_surf``

        Parameters
        ----------
        alpha : float
            Gaussian width parameter for overlap.
        num_repeats : int
            Number of SE(3) initializations per pair. Default is 50.
        trans_init : bool
            Apply translation initialization for alignment. Default is False.
        lr : float
            Optimizer learning rate. Default is 0.1.
        max_num_steps : int
            Maximum optimization steps. Default is 200.
        use_jax : bool
            Whether to use JAX backend. Default is True.
        use_analytical : bool
            Whether to use analytical gradients (PyTorch only). Default is True.
        num_workers : int
            Number of parallel worker processes. ``1`` (default) runs
            sequentially in-process. Values greater than ``len(self.pairs)``
            are clamped to ``len(self.pairs)``.
        use_shmap : bool
            Whether to use JAX shard_map for parallel alignment. Default is
            False. Performance is better when use_shmap is False on cpu.
        verbose : bool
            Print scores per pair. Default is False.

        Returns
        -------
        scores : np.ndarray
            Scores for each pair. Shape: (N,).
        aligned_list : list of np.ndarray
            Aligned fit surface coordinates for each pair.

        Raises
        ------
        RuntimeError
            If the shard_map path is requested with JAX older than 0.9.
        """
        pair_data = [
            (pair.ref_molec.surf_pos,
             pair.fit_molec.surf_pos,
             pair.ref_molec.atom_pos if trans_init else None)
            for pair in self.pairs
        ]

        if num_workers > 1:
            if use_shmap:
                # shard_map path
                self._require_jax_for_shmap()
                results = _align_surf_shmap(
                    pair_data, num_workers,
                    alpha, num_repeats, lr, max_num_steps, verbose,
                )
                indexed = enumerate(results)
            else:
                # multiprocessing path (no sort keys: surfaces share one size)
                index_splits, chunk_results = _dispatch_parallel(
                    pair_data, None, _align_surf_worker, num_workers,
                    (alpha, num_repeats, lr, max_num_steps,
                     use_jax, use_analytical, verbose),
                )
                indexed = self._iter_chunked(index_splits, chunk_results)
            return self._gather_results(indexed, 'transform_surf', 'sim_aligned_surf')

        # sequential path
        return self._delegate_alignment(
            'align_with_surf', 'sim_aligned_surf',
            alpha=alpha,
            num_repeats=num_repeats,
            trans_init=trans_init,
            lr=lr,
            max_num_steps=max_num_steps,
            use_jax=use_jax,
            use_analytical=use_analytical,
            verbose=verbose,
        )

    def align_with_esp(self,
                       alpha: float,
                       lam: float = 0.3,
                       num_repeats: int = 50,
                       trans_init: bool = False,
                       lr: float = 0.1,
                       max_num_steps: int = 200,
                       use_jax: bool = True,
                       use_analytical: bool = True,
                       num_workers: int = 1,
                       use_shmap: bool = False,
                       verbose: bool = False,
                       ) -> Tuple[np.ndarray, List[np.ndarray]]:
        """Align all pairs using ESP+surface similarity.

        Surface arrays are the same size across all pairs so no padding or
        size-sorting is needed. It is not recommended to use multiprocessing
        due to this reason.

        Results are stored in-place on each MoleculePair:

        - ``pair.transform_esp`` and ``pair.sim_aligned_esp``

        Parameters
        ----------
        alpha : float
            Gaussian width parameter for overlap.
        lam : float
            Weighting factor for ESP scoring. Scaled internally. Default is 0.3.
        num_repeats : int
            Number of SE(3) initializations per pair. Default is 50.
        trans_init : bool
            Apply translation initialization for alignment. Default is False.
        lr : float
            Optimizer learning rate. Default is 0.1.
        max_num_steps : int
            Maximum optimization steps. Default is 200.
        use_jax : bool
            Whether to use JAX backend. Default is True.
        use_analytical : bool
            Whether to use analytical gradients (PyTorch only). Default is True.
        num_workers : int
            Number of parallel worker processes. ``1`` (default) runs
            sequentially in-process. Values greater than ``len(self.pairs)``
            are clamped to ``len(self.pairs)``.
        use_shmap : bool
            Whether to use JAX shard_map for parallel alignment. Default is
            False. Performance is better when use_shmap is False on cpu.
        verbose : bool
            Print scores per pair. Default is False.

        Returns
        -------
        scores : np.ndarray
            Scores for each pair. Shape: (N,).
        aligned_list : list of np.ndarray
            Aligned fit surface coordinates for each pair.

        Raises
        ------
        RuntimeError
            If the shard_map path is requested with JAX older than 0.9.
        """
        from shepherd_score.score.constants import LAM_SCALING
        # The batch workers receive the pre-scaled value; the sequential path
        # forwards the raw ``lam`` to each pair's own align_with_esp.
        lam_scaled = float(LAM_SCALING * lam)

        pair_data = [
            (pair.ref_molec.surf_pos,
             pair.fit_molec.surf_pos,
             pair.ref_molec.surf_esp,
             pair.fit_molec.surf_esp,
             pair.ref_molec.atom_pos if trans_init else None)
            for pair in self.pairs
        ]

        if num_workers > 1:
            if use_shmap:
                # shard_map path
                self._require_jax_for_shmap()
                results = _align_esp_shmap(
                    pair_data, num_workers,
                    alpha, lam_scaled, num_repeats, lr, max_num_steps, verbose,
                )
                indexed = enumerate(results)
            else:
                # multiprocessing path (no sort keys: surfaces share one size)
                index_splits, chunk_results = _dispatch_parallel(
                    pair_data, None, _align_esp_worker, num_workers,
                    (alpha, lam_scaled, num_repeats, lr, max_num_steps,
                     use_jax, use_analytical, verbose),
                )
                indexed = self._iter_chunked(index_splits, chunk_results)
            return self._gather_results(indexed, 'transform_esp', 'sim_aligned_esp')

        # sequential path
        return self._delegate_alignment(
            'align_with_esp', 'sim_aligned_esp',
            alpha=alpha,
            lam=lam,
            num_repeats=num_repeats,
            trans_init=trans_init,
            lr=lr,
            max_num_steps=max_num_steps,
            use_jax=use_jax,
            use_analytical=use_analytical,
            verbose=verbose,
        )

    # ------------------------------------------------------------------
    # Pharmacophore alignment
    # ------------------------------------------------------------------

    def _pad_and_mask_pharm(self):
        """Extract, pad, and create masks for pharmacophore alignment.

        Validates that all pairs have pharmacophore data. Does NOT modify the
        pair objects. Returns padded arrays and masks.

        Returns
        -------
        entries : list of tuples
            Each tuple is (ref_ptypes, fit_ptypes, ref_ancs_pad, fit_ancs_pad,
            ref_vecs_pad, fit_vecs_pad, mask_ref, mask_fit, orig_ref_len,
            orig_fit_len).
        max_ref_len : int
        max_fit_len : int
            For an empty batch the result is ``([], 0, 0)``.

        Raises
        ------
        ValueError
            If any pair is missing pharmacophore data.
        """
        self._validate_pharm_data()
        if not self.pairs:
            # max() over an empty sequence would raise; nothing to pad anyway.
            return [], 0, 0

        ref_types_list = [p.ref_molec.pharm_types for p in self.pairs]
        fit_types_list = [p.fit_molec.pharm_types for p in self.pairs]
        max_ref_len = max(t.shape[0] for t in ref_types_list)
        max_fit_len = max(t.shape[0] for t in fit_types_list)

        ref_ancs_padded, masks_ref, orig_refs, _ = _pad_arrays(
            [p.ref_molec.pharm_ancs for p in self.pairs])
        fit_ancs_padded, masks_fit, orig_fits, _ = _pad_arrays(
            [p.fit_molec.pharm_ancs for p in self.pairs])
        ref_vecs_padded, _, _, _ = _pad_arrays(
            [p.ref_molec.pharm_vecs for p in self.pairs])
        fit_vecs_padded, _, _, _ = _pad_arrays(
            [p.fit_molec.pharm_vecs for p in self.pairs])

        entries = []
        for (ref_types, fit_types, ref_ancs_pad, fit_ancs_pad,
             ref_vecs_pad, fit_vecs_pad,
             mask_ref, mask_fit, orig_ref, orig_fit) in zip(
                ref_types_list, fit_types_list,
                ref_ancs_padded, fit_ancs_padded,
                ref_vecs_padded, fit_vecs_padded,
                masks_ref, masks_fit, orig_refs, orig_fits):
            entries.append((
                self._pad_pharm_types(ref_types, orig_ref, max_ref_len),
                self._pad_pharm_types(fit_types, orig_fit, max_fit_len),
                ref_ancs_pad, fit_ancs_pad, ref_vecs_pad, fit_vecs_pad,
                mask_ref, mask_fit, orig_ref, orig_fit,
            ))
        return entries, max_ref_len, max_fit_len

    def align_with_pharm(self,
                         similarity: str = 'tanimoto',
                         extended_points: bool = False,
                         only_extended: bool = False,
                         num_repeats: int = 50,
                         trans_init: bool = False,
                         lr: float = 0.1,
                         max_num_steps: int = 200,
                         num_workers: int = 1,
                         use_shmap: bool = True,
                         num_buckets: int = 1,
                         verbose: bool = False,
                         ) -> Tuple[np.ndarray, List[np.ndarray], List[np.ndarray]]:
        """Align all pairs using padded masked pharmacophore similarity via JAX.

        Because all padded arrays have the same shape, JAX's XLA compiler reuses
        one compiled kernel for every pair — no recompilation overhead.

        When ``num_workers > 1`` the pairs are split into size-sorted chunks and
        processed in parallel. It is recommended to use ``use_shmap=True``
        instead of ``multiprocessing`` for this setting.

        Results are stored in-place on each MoleculePair:

        - ``pair.transform_pharm`` and ``pair.sim_aligned_pharm``

        Parameters
        ----------
        similarity : str
            One of ``'tanimoto'``, ``'tversky'``, ``'tversky_ref'``,
            ``'tversky_fit'``.
        extended_points : bool
            Score HBA/HBD with extended-point Gaussians.
        only_extended : bool
            When ``extended_points`` is True, ignore anchor overlaps.
        num_repeats : int
            Number of SE(3) initializations per pair.
        trans_init : bool
            If True, initialize translations to each ref pharmacophore anchor.
        lr : float
            Optimizer learning rate.
        max_num_steps : int
            Maximum optimization steps.
        num_workers : int
            Number of parallel worker processes. ``1`` (default) runs
            sequentially in-process. Values greater than ``len(self.pairs)``
            are clamped to ``len(self.pairs)``.
        use_shmap : bool
            If ``True`` and ``num_workers > 1``, use ``jax.shard_map`` + ``vmap``
            to parallelise across virtual CPU devices in a single process.
            Requires ``XLA_FLAGS=--xla_force_host_platform_device_count=N`` to
            be set before any JAX import. Uses ``lax.scan`` (fixed steps, no
            early stopping) instead of the ``while_loop``-based sequential path.
            Required on Linux HPC if num_workers > 1 where ``multiprocessing``
            spawn can be unreliable with JAX. Default is ``True``.
        num_buckets : int
            ``1`` (default) pads all pairs to the global pharmacophore-count
            maximum — lowest overhead for typical use. Values > 1 sort pairs by
            ``(max(ref,fit), min(ref,fit))`` and process each bucket separately
            with reduced per-bucket padding, which can be beneficial for large
            heterogeneous molecule sets.
        verbose : bool
            Print scores per pair.

        Returns
        -------
        scores : np.ndarray
            Shape: (N,).
        aligned_anchors_list : list of np.ndarray
            Aligned fit pharmacophore anchors (unpadded) for each pair.
        aligned_vectors_list : list of np.ndarray
            Aligned fit pharmacophore vectors (unpadded) for each pair.

        Raises
        ------
        ValueError
            If any pair is missing pharmacophore data.
        RuntimeError
            If the shard_map path is requested with JAX older than 0.9.
        ImportError
            If the sequential path is taken and JAX cannot be imported.
        """
        self._validate_pharm_data()

        n_pairs = len(self.pairs)
        scores = np.zeros(n_pairs)
        aligned_anchors_list = [None] * n_pairs
        aligned_vectors_list = [None] * n_pairs

        if num_workers > 1:
            # Raw (unpadded) per-pair data tuples (plain numpy — picklable).
            # The ref/fit anchors appear twice: once as the arrays to align and
            # once in the trailing slots (the sequential path passes the same
            # arrays as init_ref_anchors / init_fit_anchors).
            pair_data = [
                (pair.ref_molec.pharm_types,
                 pair.fit_molec.pharm_types,
                 pair.ref_molec.pharm_ancs,
                 pair.fit_molec.pharm_ancs,
                 pair.ref_molec.pharm_vecs,
                 pair.fit_molec.pharm_vecs,
                 pair.ref_molec.pharm_ancs if trans_init else None,
                 pair.ref_molec.pharm_ancs,
                 pair.fit_molec.pharm_ancs)
                for pair in self.pairs
            ]
            if use_shmap:
                # shard_map path
                self._require_jax_for_shmap()
                results = _align_pharm_shmap(
                    pair_data, num_workers,
                    similarity, extended_points, only_extended,
                    num_repeats, lr, max_num_steps, verbose,
                    num_buckets=num_buckets,
                )
                indexed = enumerate(results)
            else:
                # multiprocessing path, size-sorted by anchor counts
                sort_keys = self._size_sort_keys(
                    [d[2] for d in pair_data],  # ref_ancs
                    [d[3] for d in pair_data],  # fit_ancs
                )
                index_splits, chunk_results = _dispatch_parallel(
                    pair_data, sort_keys, _align_pharm_worker, num_workers,
                    (similarity, extended_points, only_extended,
                     num_repeats, lr, max_num_steps, verbose),
                )
                indexed = self._iter_chunked(index_splits, chunk_results)

            for global_i, (score, se3_transform, aligned_ancs, aligned_vecs) in indexed:
                scores[global_i] = score
                aligned_anchors_list[global_i] = aligned_ancs
                aligned_vectors_list[global_i] = aligned_vecs
                self._store_result(global_i, 'transform_pharm', 'sim_aligned_pharm',
                                   se3_transform, score)
            return scores, aligned_anchors_list, aligned_vectors_list

        # sequential path
        jnp = self._import_jnp('align_with_pharm')
        from shepherd_score.alignment_jax import optimize_pharm_overlay_jax_vectorized_mask

        ref_types_list = [p.ref_molec.pharm_types for p in self.pairs]
        fit_types_list = [p.fit_molec.pharm_types for p in self.pairs]
        ref_ancs_list = [p.ref_molec.pharm_ancs for p in self.pairs]
        fit_ancs_list = [p.fit_molec.pharm_ancs for p in self.pairs]
        ref_vecs_list = [p.ref_molec.pharm_vecs for p in self.pairs]
        fit_vecs_list = [p.fit_molec.pharm_vecs for p in self.pairs]

        ref_sizes = np.array([len(a) for a in ref_ancs_list])
        fit_sizes = np.array([len(a) for a in fit_ancs_list])

        for bucket_idx_list in _compute_bucket_splits(ref_sizes, fit_sizes, num_buckets):
            if not bucket_idx_list:
                # Empty batch: nothing to pad or align.
                continue
            ref_ancs_padded, masks_ref, orig_refs_b, max_ref_b = _pad_arrays(
                [ref_ancs_list[i] for i in bucket_idx_list])
            fit_ancs_padded, masks_fit, orig_fits_b, max_fit_b = _pad_arrays(
                [fit_ancs_list[i] for i in bucket_idx_list])
            ref_vecs_padded, _, _, _ = _pad_arrays(
                [ref_vecs_list[i] for i in bucket_idx_list])
            fit_vecs_padded, _, _, _ = _pad_arrays(
                [fit_vecs_list[i] for i in bucket_idx_list])

            for local_j, global_i in enumerate(bucket_idx_list):
                pair = self.pairs[global_i]
                orig_ref = orig_refs_b[local_j]
                orig_fit = orig_fits_b[local_j]
                ref_types_pad = self._pad_pharm_types(
                    ref_types_list[global_i], orig_ref, max_ref_b)
                fit_types_pad = self._pad_pharm_types(
                    fit_types_list[global_i], orig_fit, max_fit_b)
                trans_centers = pair.ref_molec.pharm_ancs if trans_init else None

                aligned_ancs, aligned_vecs, se3_transform, score = \
                    optimize_pharm_overlay_jax_vectorized_mask(
                        ref_pharms=jnp.array(ref_types_pad),
                        fit_pharms=jnp.array(fit_types_pad),
                        ref_anchors=jnp.array(ref_ancs_padded[local_j]),
                        fit_anchors=jnp.array(fit_ancs_padded[local_j]),
                        ref_vectors=jnp.array(ref_vecs_padded[local_j]),
                        fit_vectors=jnp.array(fit_vecs_padded[local_j]),
                        mask_ref=jnp.array(masks_ref[local_j]),
                        mask_fit=jnp.array(masks_fit[local_j]),
                        similarity=similarity,
                        extended_points=extended_points,
                        only_extended=only_extended,
                        num_repeats=num_repeats,
                        trans_centers=trans_centers,
                        init_ref_anchors=pair.ref_molec.pharm_ancs,
                        init_fit_anchors=pair.fit_molec.pharm_ancs,
                        lr=lr,
                        max_num_steps=max_num_steps,
                        verbose=verbose,
                    )
                score = float(np.array(score))
                scores[global_i] = score
                self._store_result(global_i, 'transform_pharm', 'sim_aligned_pharm',
                                   np.array(se3_transform), score)
                # Drop the padded rows so callers get the original feature count.
                aligned_anchors_list[global_i] = np.array(aligned_ancs)[:orig_fit]
                aligned_vectors_list[global_i] = np.array(aligned_vecs)[:orig_fit]

        return scores, aligned_anchors_list, aligned_vectors_list