JAX Parallel Alignment

Note

XLA_FLAGS must be set before any JAX import so that len(jax.devices()) equals the desired number of virtual CPU devices. Place the following lines at the very top of your script:

import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'
os.environ['JAX_PLATFORMS'] = 'cpu'
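To confirm that the flag took effect, query the device count immediately after the import. A minimal sketch (`n_devices` is just a local name):

```python
import os

# Both variables must be in the environment before JAX initialises its backend.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'
os.environ['JAX_PLATFORMS'] = 'cpu'

import jax

n_devices = len(jax.devices())
print(n_devices, [d.platform for d in jax.devices()])
```

If this reports a single device, JAX was most likely initialised before XLA_FLAGS was set, for example by an earlier import in the same process. Setting the variable afterwards in that process has no effect.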

This module provides multi-device alignment (volumetric, surface, ESP and pharmacophore variants) using jax.shard_map(). It is the backend called by shepherd_score.container.MoleculePairBatch.align_with_vol() when use_shmap=True.

Overview

optimize_ROCS_overlay_jax_vol_shmap() distributes a flat batch of molecule pairs across virtual CPU devices without any Python-level multiprocessing. Key properties:

  • Accepts flat (total, ...) arrays where total is the number of pairs padded to a multiple of len(jax.devices()). Do not pre-reshape to (n_devices, B, ...).

  • Internally wraps _per_pair_optimize_vol_mask_scan (from shepherd_score.alignment._jax) in vmap, then distributes via shard_map with PartitionSpec('i') on the leading axis.

  • The compiled function is cached in _shmap_vol_cache, keyed by (max_num_steps, n_devices), so it is built only once per unique (steps, device-count) combination. XLA kernels are still specialised to input shapes, so a new padded size (total, N or M) requires a fresh compilation.

  • Runs the optimisation with lax.scan for exactly max_num_steps steps (no convergence-based early stopping), so the whole loop compiles into a single XLA program. Typical speedup is ~2.8× on 4 CPU cores compared to sequential JAX alignment.

Usage via the high-level API

The recommended entry point is align_with_vol():

import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'
os.environ['JAX_PLATFORMS'] = 'cpu'

from shepherd_score.container import MoleculePairBatch

batch = MoleculePairBatch(pairs)  # pairs: the molecule pairs to align (see MoleculePairBatch)

# Default: single pass (n_buckets=1)
scores, aligned = batch.align_with_vol(num_workers=4, use_shmap=True)

# Bucketed: useful for >10k pairs with diverse molecule sizes
scores, aligned = batch.align_with_vol(num_workers=4, use_shmap=True, n_buckets=8)

See MoleculePairBatch for details on bucketing and masking strategy.
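Counting padded atom slots shows why bucketing helps with diverse molecule sizes. The sketch below is a hypothetical illustration, not the strategy MoleculePairBatch actually uses. It sorts pairs by a per-pair size, splits them into n_buckets contiguous groups, and pads each group only to its own maximum.

```python
def padded_slots(sizes, n_buckets):
    """Total padded slots when pairs are grouped into size buckets.

    sizes: hypothetical per-pair sizes, e.g. max(N_ref, M_fit) for each pair.
    """
    order = sorted(sizes)
    bucket_len = -(-len(order) // n_buckets)  # ceiling division
    total = 0
    for start in range(0, len(order), bucket_len):
        bucket = order[start:start + bucket_len]
        total += len(bucket) * max(bucket)  # every pair padded to its bucket max
    return total


sizes = [12, 15, 18, 20, 45, 50, 60, 80]  # 300 real slots in total
print(padded_slots(sizes, 1))  # 640: everything padded to 80
print(padded_slots(sizes, 4))  # 330
```

With one bucket, 340 of the 640 slots are padding. With four buckets, only 30 are. Each bucket has its own array shapes, so each one triggers a separate compilation. This is presumably why bucketing pays off mainly for large batches.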

Direct usage

Pre-compute self-overlaps and SE(3) initialisations outside the function (they are invariant to the optimisation loop), then call:

from shepherd_score.alignment._jax_parallel import optimize_ROCS_overlay_jax_vol_shmap

aligned_pts, se3_transform, scores = optimize_ROCS_overlay_jax_vol_shmap(
    ref_batch,       # (total, N, 3)
    fit_batch,       # (total, M, 3)
    mask_ref_batch,  # (total, N)
    mask_fit_batch,  # (total, M)
    VAA_batch,       # (total,)  pre-computed ref self-overlaps
    VBB_batch,       # (total,)  pre-computed fit self-overlaps
    se3_init_batch,  # (total, R, 7)  pre-initialised SE(3) params
    alpha=0.81,
    lr=0.1,
    max_num_steps=200,
)
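Before the call, the padded inputs have to be assembled. The following is a minimal sketch under stated assumptions. `pad_and_mask` and `masked_self_overlap` are hypothetical helpers. The self-overlap is assumed to be a plain sum of unit-amplitude Gaussian overlaps, (π/(2α))^{3/2} exp(-α r_ij² / 2), which may differ from the constants shepherd_score uses. Only the reference side is shown; the fit side is built the same way.

```python
import jax
import jax.numpy as jnp


def pad_and_mask(point_sets, n_max):
    """Stack variable-size (n_i, 3) arrays into (total, n_max, 3) plus a 0/1 mask."""
    pts = jnp.stack([jnp.pad(p, ((0, n_max - p.shape[0]), (0, 0))) for p in point_sets])
    mask = jnp.stack([(jnp.arange(n_max) < p.shape[0]).astype(jnp.float32)
                      for p in point_sets])
    return pts, mask


def masked_self_overlap(pts, mask, alpha):
    """Assumed Gaussian self-overlap of one padded point set; padding is masked out."""
    d2 = jnp.sum((pts[:, None, :] - pts[None, :, :]) ** 2, axis=-1)  # (N, N)
    kernel = (jnp.pi / (2 * alpha)) ** 1.5 * jnp.exp(-0.5 * alpha * d2)
    return jnp.sum(mask[:, None] * mask[None, :] * kernel)


# vmap over the flat leading axis; alpha is shared by all pairs.
batched_self_overlap = jax.vmap(masked_self_overlap, in_axes=(0, 0, None))

mols = [jnp.array([[0.0, 0.0, 0.0], [1.2, 0.0, 0.0]]),
        jnp.array([[0.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 1.5, 0.5]])]
ref_batch, mask_ref_batch = pad_and_mask(mols, n_max=4)
VAA_batch = batched_self_overlap(ref_batch, mask_ref_batch, 0.81)
```

Padded points sit at the origin. Without the mask, they would overlap with real atoms and inflate VAA.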

API

Parallel (multi-device) JAX alignment via jax.shard_map.

Each public optimize_*_jax_*_shmap function accepts flat (total, ...) arrays (where total is the number of pairs padded to a multiple of len(jax.devices())) and returns flat (total, ...) results. The leading axis is automatically distributed across devices by jax.shard_map with PartitionSpec('i'); do not pre-reshape to (n_devices, B, ...).

XLA_FLAGS=--xla_force_host_platform_device_count=N must be set before any JAX import so that len(jax.devices()) == N.
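The following sketch pads a flat batch to a multiple of len(jax.devices()) and trims the results afterwards. `pad_to_multiple` is a hypothetical helper. It repeats the last real pair rather than appending zeros. This is based on the assumption that an all-zero pair (empty masks, zero self-overlap) could produce NaN scores in the padded rows, which are discarded either way.

```python
import jax.numpy as jnp


def pad_to_multiple(arrays, n_devices):
    """Pad every array along axis 0 to the next multiple of n_devices.

    Returns the padded arrays and the number of real pairs, so results
    can be trimmed with result[:n_real].
    """
    n_real = arrays[0].shape[0]
    total = -(-n_real // n_devices) * n_devices  # round up to a multiple
    n_pad = total - n_real
    padded = [jnp.concatenate([a, jnp.repeat(a[-1:], n_pad, axis=0)], axis=0)
              for a in arrays]
    return padded, n_real


ref = jnp.ones((10, 5, 3))
VAA = jnp.arange(10.0)
(ref_p, VAA_p), n_real = pad_to_multiple([ref, VAA], n_devices=4)
print(ref_p.shape, VAA_p.shape, n_real)  # (12, 5, 3) (12,) 10
```

After the shard_map call, slicing each output with [:n_real] drops the padded rows.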

shepherd_score.alignment._jax_parallel.optimize_ROCS_overlay_jax_vol_shmap(ref_batch, fit_batch, mask_ref_batch, mask_fit_batch, VAA_batch, VBB_batch, se3_init_batch, alpha, lr, max_num_steps)

Volumetric alignment via shard_map + vmap across virtual CPU devices.

All *_batch arrays use a flat leading axis of size total (i.e. the number of pairs padded to a multiple of len(jax.devices())). Unlike pmap, shard_map automatically distributes the flat leading axis across devices; do not pre-reshape to (n_devices, B, ...). Pre-compute self-overlaps VAA/VBB and SE(3) initialisations outside this function (they are invariant to the optimisation loop).

XLA_FLAGS=--xla_force_host_platform_device_count=N must be set before JAX is first imported so that len(jax.devices()) == N.

Parameters:
  • ref_batch ((total, N, 3) padded reference positions)

  • fit_batch ((total, M, 3) padded fit positions)

  • mask_ref_batch ((total, N) binary mask over the padded reference atoms)

  • mask_fit_batch ((total, M) binary mask over the padded fit atoms)

  • VAA_batch ((total,) pre-computed ref self-overlaps)

  • VBB_batch ((total,) pre-computed fit self-overlaps)

  • se3_init_batch ((total, R, 7) pre-initialised SE(3) params)

  • alpha (float)

  • lr (float)

  • max_num_steps (int; must be a Python int because it determines the compiled kernel)

Returns:

  • aligned_pts ((total, M, 3) aligned fit positions)

  • se3_transform ((total, 4, 4) SE(3) transformation matrices)

  • scores ((total,) alignment scores)
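Applying each 4×4 matrix in se3_transform to the original fit points should reproduce aligned_pts. The same matrices can be applied to any other coordinates of the fit molecule. The following sketch assumes that the matrices are homogeneous transforms acting on column vectors, with rotation in the top-left 3×3 block and translation in the last column. `apply_se3_batch` is a hypothetical helper.

```python
import jax.numpy as jnp


def apply_se3_batch(se3_transform, pts):
    """Apply (total, 4, 4) homogeneous transforms to (total, M, 3) point sets."""
    rot = se3_transform[:, :3, :3]   # (total, 3, 3) rotation blocks
    trans = se3_transform[:, :3, 3]  # (total, 3) translations
    # out[b, m, i] = sum_j rot[b, i, j] * pts[b, m, j] + trans[b, i]
    return jnp.einsum('bij,bmj->bmi', rot, pts) + trans[:, None, :]
```

If the assumed convention holds, apply_se3_batch(se3_transform, fit_batch) should match aligned_pts on the unpadded atoms.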

shepherd_score.alignment._jax_parallel.optimize_ROCS_esp_overlay_jax_vol_esp_shmap(ref_batch, fit_batch, ref_charges_batch, fit_charges_batch, mask_ref_batch, mask_fit_batch, VAA_batch, VBB_batch, se3_init_batch, alpha, lam, lr, max_num_steps)

Masked volumetric ESP alignment via shard_map + vmap across virtual CPU devices.

Parameters:
  • ref_batch ((total, N, 3) padded reference positions)

  • fit_batch ((total, M, 3) padded fit positions)

  • ref_charges_batch ((total, N, 1) padded, column-shaped reference charges)

  • fit_charges_batch ((total, M, 1) padded, column-shaped fit charges)

  • mask_ref_batch ((total, N) binary mask over the padded reference atoms)

  • mask_fit_batch ((total, M) binary mask over the padded fit atoms)

  • VAA_batch ((total,) pre-computed ref ESP self-overlaps)

  • VBB_batch ((total,) pre-computed fit ESP self-overlaps)

  • se3_init_batch ((total, R, 7) pre-initialised SE(3) params)

  • alpha (float)

  • lam (float)

  • lr (float)

  • max_num_steps (int; must be a Python int because it determines the compiled kernel)

Returns:

  • aligned_pts ((total, M, 3) aligned fit positions)

  • se3_transform ((total, 4, 4) SE(3) transformation matrices)

  • scores ((total,) alignment scores)
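Because ref_charges_batch has a column shape, (total, N, 1), computing pairwise charge differences reduces to a simple broadcast. The following single-pair sketch assumes an ESP-weighted Gaussian overlap in which each atom-pair term is damped by exp(-(q_i - q_j)² / lam). That damping form is an assumption, and `esp_overlap` is a hypothetical helper, not the module's kernel.

```python
import jax.numpy as jnp


def esp_overlap(ref, fit, q_ref, q_fit, mask_ref, mask_fit, alpha, lam):
    """Assumed ESP-weighted Gaussian overlap for one padded pair.

    ref: (N, 3), fit: (M, 3), q_ref: (N, 1), q_fit: (M, 1), masks: (N,), (M,).
    """
    d2 = jnp.sum((ref[:, None, :] - fit[None, :, :]) ** 2, axis=-1)  # (N, M)
    shape_term = (jnp.pi / (2 * alpha)) ** 1.5 * jnp.exp(-0.5 * alpha * d2)
    esp_term = jnp.exp(-((q_ref - q_fit.T) ** 2) / lam)  # (N, 1) - (1, M) -> (N, M)
    pair_mask = mask_ref[:, None] * mask_fit[None, :]
    return jnp.sum(pair_mask * shape_term * esp_term)
```

With all charges equal, the ESP factor is 1 and the expression reduces to the shape-only overlap. Mismatched charges on nearby atoms lower the score.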

shepherd_score.alignment._jax_parallel.optimize_ROCS_overlay_jax_surf_shmap(ref_batch, fit_batch, VAA_batch, VBB_batch, se3_init_batch, alpha, lr, max_num_steps)

Non-masked surface alignment via shard_map + vmap across virtual CPU devices.

Surface point clouds have the same number of points for every pair (N for the reference, M for the fit), so no padding or masking is needed.

Parameters:
  • ref_batch ((total, N, 3) stacked reference surface positions)

  • fit_batch ((total, M, 3) stacked fit surface positions)

  • VAA_batch ((total,) pre-computed ref self-overlaps)

  • VBB_batch ((total,) pre-computed fit self-overlaps)

  • se3_init_batch ((total, R, 7) pre-initialised SE(3) params)

  • alpha (float)

  • lr (float)

  • max_num_steps (int; must be a Python int because it determines the compiled kernel)

Returns:

  • aligned_pts ((total, M, 3) aligned fit surface positions)

  • se3_transform ((total, 4, 4) SE(3) transformation matrices)

  • scores ((total,) alignment scores)
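Every reference surface has N points and every fit surface has M points, so building the batch is a plain stack with no padding or mask. A minimal sketch follows. The point counts and random coordinates are placeholders for whatever sampler produced a fixed number of surface points per molecule.

```python
import jax
import jax.numpy as jnp

N, M, total = 200, 200, 8  # placeholder sizes
keys = jax.random.split(jax.random.PRNGKey(0), 2 * total)
ref_surfaces = [jax.random.normal(k, (N, 3)) for k in keys[:total]]
fit_surfaces = [jax.random.normal(k, (M, 3)) for k in keys[total:]]

# Uniform sizes: a plain stack gives (total, N, 3) and (total, M, 3), no mask needed.
ref_batch = jnp.stack(ref_surfaces)
fit_batch = jnp.stack(fit_surfaces)
```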

shepherd_score.alignment._jax_parallel.optimize_ROCS_esp_overlay_jax_surf_esp_shmap(ref_batch, fit_batch, ref_charges_batch, fit_charges_batch, VAA_batch, VBB_batch, se3_init_batch, alpha, lam, lr, max_num_steps)

Non-masked surface ESP alignment via shard_map + vmap across virtual CPU devices.

Parameters:
  • ref_batch ((total, N, 3) stacked reference surface positions)

  • fit_batch ((total, M, 3) stacked fit surface positions)

  • ref_charges_batch ((total, N) stacked reference ESP values)

  • fit_charges_batch ((total, M) stacked fit ESP values)

  • VAA_batch ((total,) pre-computed ref ESP self-overlaps)

  • VBB_batch ((total,) pre-computed fit ESP self-overlaps)

  • se3_init_batch ((total, R, 7) pre-initialised SE(3) params)

  • alpha (float)

  • lam (float)

  • lr (float)

  • max_num_steps (int; must be a Python int because it determines the compiled kernel)

Returns:

  • aligned_pts ((total, M, 3) aligned fit surface positions)

  • se3_transform ((total, 4, 4) SE(3) transformation matrices)

  • scores ((total,) alignment scores)
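The two ESP entry points use different charge layouts. The volumetric version takes column-shaped (total, N, 1) charges, and the surface version takes (total, N). Converting between them is a reshape. The values below are placeholders:

```python
import jax.numpy as jnp

charges_col = jnp.array([[[0.1], [-0.2], [0.3]],
                         [[0.0], [0.4], [-0.1]]])  # (total=2, N=3, 1): volumetric layout
charges_flat = charges_col[..., 0]               # (2, 3): surface layout
charges_back = charges_flat[..., None]           # (2, 3, 1): volumetric layout again
```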

shepherd_score.alignment._jax_parallel.optimize_pharm_overlay_jax_pharm_shmap(ref_pharms_batch, fit_pharms_batch, ref_anchors_batch, fit_anchors_batch, ref_vectors_batch, fit_vectors_batch, mask_ref_batch, mask_fit_batch, ref_self_batch, fit_self_batch, se3_init_batch, similarity, extended_points, only_extended, lr, max_num_steps)

Masked pharmacophore alignment via shard_map + vmap across virtual CPU devices.

Parameters:
  • ref_pharms_batch ((total, N) int32 padded reference pharmacophore type indices)

  • fit_pharms_batch ((total, M) int32 padded fit pharmacophore type indices)

  • ref_anchors_batch ((total, N, 3) padded reference anchor positions)

  • fit_anchors_batch ((total, M, 3) padded fit anchor positions)

  • ref_vectors_batch ((total, N, 3) padded reference direction vectors)

  • fit_vectors_batch ((total, M, 3) padded fit direction vectors)

  • mask_ref_batch ((total, N) binary mask over the padded reference pharmacophores)

  • mask_fit_batch ((total, M) binary mask over the padded fit pharmacophores)

  • ref_self_batch ((total,) pre-computed ref self-overlaps)

  • fit_self_batch ((total,) pre-computed fit self-overlaps)

  • se3_init_batch ((total, R, 7) pre-initialised SE(3) params)

  • similarity (str; one of 'tanimoto', 'tversky_ref', 'tversky_fit')

  • extended_points (bool)

  • only_extended (bool)

  • lr (float)

  • max_num_steps (int; must be a Python int because it determines the compiled kernel)

Returns:

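  • aligned_anchors ((total, M, 3) aligned fit anchor positions)

  • aligned_vectors ((total, M, 3) aligned fit direction vectors)

  • se3_transform ((total, 4, 4) SE(3) transformation matrices)

  • scores ((total,) alignment scores)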
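The similarity argument selects how the raw overlap is normalised against ref_self_batch and fit_self_batch. The sketch below uses the standard Tanimoto form and a two-weight Tversky form. The Tversky weight shepherd_score actually uses is not stated here, so sigma = 0.95 is an assumed placeholder. `similarity_score` is a hypothetical helper; it works elementwise on floats or on (total,) arrays.

```python
def similarity_score(VAB, VAA, VBB, similarity, sigma=0.95):
    """Normalise a raw overlap VAB; sigma is an assumed Tversky weight."""
    if similarity == 'tanimoto':
        return VAB / (VAA + VBB - VAB)
    if similarity == 'tversky_ref':  # weighted towards the reference
        return VAB / (sigma * VAA + (1 - sigma) * VBB)
    if similarity == 'tversky_fit':  # weighted towards the fit molecule
        return VAB / ((1 - sigma) * VAA + sigma * VBB)
    raise ValueError(f'unknown similarity: {similarity}')
```

Because the choice is a Python string, the branch is resolved in Python rather than inside the traced computation. When the reference is fully contained in a larger fit molecule (VAB = VAA), tversky_ref stays close to 1 while Tanimoto is penalised by the extra fit volume.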