JAX Alignment

Sequential JAX-based alignment functions. For multi-device parallel alignment via jax.shard_map, see JAX Parallel Alignment.

Alignment implementation in JAX.

shepherd_score.alignment._jax.apply_SO3_transform_jax(vectors, se3_matrix)

Apply SO(3) transformation (rotation) to a set of vectors.

Parameters:
  • vectors (jax.Array)

  • se3_matrix (jax.Array)

Return type:

jax.Array
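The parameter name se3_matrix suggests a 4×4 homogeneous transform of which only the rotation block is applied, which is what direction vectors need, since they should not be translated. The sketch below illustrates that reading; the (4, 4) shape, the row-vector convention, and the name apply_rotation_sketch are assumptions rather than the library's code.

```python
import jax.numpy as jnp


def apply_rotation_sketch(vectors: jnp.ndarray, se3_matrix: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical stand-in: rotate (N, 3) vectors using the 3x3 rotation block
    of an assumed (4, 4) homogeneous transform. The translation column is ignored."""
    rotation = se3_matrix[:3, :3]
    # Row-vector convention: v' = v @ R^T, i.e. R @ v applied to each row v.
    return vectors @ rotation.T


# 90-degree rotation about z plus a translation that should have no effect here.
T = jnp.array([[0., -1., 0., 5.],
               [1.,  0., 0., 5.],
               [0.,  0., 1., 5.],
               [0.,  0., 0., 1.]])
v = jnp.array([[1., 0., 0.]])
rotated = apply_rotation_sketch(v, T)  # approximately [[0., 1., 0.]]
```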

shepherd_score.alignment._jax.vmap_apply_SO3_transform_jax(vectors, se3_matrix)

Apply SO(3) transformation (rotation) to a set of vectors. Batched with jax.vmap; see apply_SO3_transform_jax.

Parameters:
  • vectors (jax.Array)

  • se3_matrix (jax.Array)

Return type:

jax.Array
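The vmap_ prefix suggests this is the single-instance rotation wrapped in jax.vmap. A minimal sketch of that pattern follows, assuming both arguments carry a leading batch axis (in_axes=(0, 0)); rotate_one is a hypothetical stand-in for the single-instance function.

```python
import jax
import jax.numpy as jnp


def rotate_one(vectors, se3_matrix):
    # Hypothetical single-instance rotation: (N, 3) vectors, assumed (4, 4) transform.
    return vectors @ se3_matrix[:3, :3].T


# Map over a leading batch axis of both arguments.
rotate_batch = jax.vmap(rotate_one, in_axes=(0, 0))

B, N = 8, 5
vectors = jax.random.normal(jax.random.PRNGKey(0), (B, N, 3))
transforms = jnp.tile(jnp.eye(4), (B, 1, 1))  # a batch of identity transforms
out = rotate_batch(vectors, transforms)
print(out.shape)  # (8, 5, 3)
```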

shepherd_score.alignment._jax.batched_obj_ROCS_overlay_helper(se3_params, ref_points, fit_points, alpha)

Objective function to optimize ROCS overlay for a single instance. JAX implementation.

Parameters:
  • se3_params (Array (7,)) – Parameters for the SE(3) transformation. The first 4 values are a quaternion of the form (r,i,j,k) and the last 3 values are the translation (x,y,z).

  • ref_points (Array (N,3)) – Reference points.

  • fit_points (Array (M,3)) – Set of points to apply SE(3) transformations to maximize shape similarity with ref_points.

  • alpha (float) – Gaussian width parameter used in scoring function.

Returns:

Tanimoto overlap score.
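Following the parameter description above, a single-instance objective converts the 7-vector into a rotation and a translation, moves fit_points, and scores the result with a Gaussian Tanimoto overlap, VAB / (VAA + VBB - VAB). The sketch below assumes one shared Gaussian width alpha and the pairwise kernel exp(-alpha/2 · d²); with a shared width every pairwise term carries the same constant prefactor, which cancels in the Tanimoto ratio. The function names are hypothetical and the exact kernel used by shepherd_score may differ.

```python
import jax
import jax.numpy as jnp


def quat_to_rotation(q):
    """(r, i, j, k) quaternion -> 3x3 rotation matrix; q is normalized first."""
    w, x, y, z = q / jnp.linalg.norm(q)
    return jnp.array([
        [1 - 2 * (y**2 + z**2), 2 * (x * y - z * w),   2 * (x * z + y * w)],
        [2 * (x * y + z * w),   1 - 2 * (x**2 + z**2), 2 * (y * z - x * w)],
        [2 * (x * z - y * w),   2 * (y * z + x * w),   1 - 2 * (x**2 + y**2)],
    ])


def gaussian_overlap(a, b, alpha):
    """Assumed kernel: sum of exp(-alpha/2 * d^2) over all pairs (prefactor dropped)."""
    d2 = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return jnp.sum(jnp.exp(-0.5 * alpha * d2))


def tanimoto_sketch(se3_params, ref_points, fit_points, alpha):
    """Hypothetical single-instance Tanimoto score for a (7,) SE(3) parameter vector."""
    R = quat_to_rotation(se3_params[:4])
    t = se3_params[4:]
    moved = fit_points @ R.T + t
    vab = gaussian_overlap(ref_points, moved, alpha)
    vaa = gaussian_overlap(ref_points, ref_points, alpha)
    vbb = gaussian_overlap(moved, moved, alpha)
    return vab / (vaa + vbb - vab)


pts = jax.random.normal(jax.random.PRNGKey(0), (6, 3))
identity = jnp.array([1., 0., 0., 0., 0., 0., 0.])  # unit quaternion, zero translation
score = tanimoto_sketch(identity, pts, pts, 0.81)    # identical clouds give 1.0
grads = jax.grad(tanimoto_sketch)(identity, pts, pts, 0.81)  # differentiable in se3_params
```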

shepherd_score.alignment._jax.vmap_score_ROCS_overlay_with_avoid_jax(ref_points, fit_points, alpha, fit_points_for_avoid, avoid_points, avoid_min_dist, avoid_weight)

See _objective_ROCS_overlay_with_avoid_jax.

Parameters:
  • ref_points (jax.Array)

  • fit_points (jax.Array)

  • alpha (float)

  • fit_points_for_avoid (jax.Array)

  • avoid_points (jax.Array)

  • avoid_min_dist (float)

  • avoid_weight (float)

Return type:

jax.Array

shepherd_score.alignment._jax.batched_obj_ROCS_overlay_with_avoid_helper(se3_params, ref_points, fit_points, alpha, fit_points_for_avoid, avoid_points, avoid_min_dist, avoid_weight)

Objective function to optimize ROCS overlay while avoiding certain points. JAX implementation.

Parameters:
  • se3_params (Array (7,)) – Parameters for the SE(3) transformation. The first 4 values are a quaternion of the form (r,i,j,k) and the last 3 values are the translation (x,y,z).

  • ref_points (Array (N,3)) – Reference points.

  • fit_points (Array (M,3)) – Set of points to apply SE(3) transformations to maximize shape similarity with ref_points.

  • alpha (float) – Gaussian width parameter used in scoring function.

  • fit_points_for_avoid (Array (M,3)) – Points to which the SE(3) transformation is applied before comparing them with avoid_points.

  • avoid_points (Array (K,3)) – Penalize overlap with these points.

  • avoid_min_dist (float) – Minimum distance with no penalization between fit_points_for_avoid and avoid_points.

  • avoid_weight (float) – Weight for the avoid_points term in the scoring function.

Returns:

score – Score in the range [-avoid_weight, 1], where higher is better. 1 means complete fit/ref overlap with no fit/avoid overlap; -avoid_weight means no fit/ref overlap and complete fit/avoid overlap.

Return type:

jax.Array

shepherd_score.alignment._jax.objective_ROCS_overlay_jax(se3_params, ref_points, fit_points, alpha)

Objective function to optimize ROCS overlay. JAX implementation.

Parameters:
  • se3_params (Array (batch, 7)) – Parameters for the SE(3) transformation; a leading batch dimension is expected. The first 4 values in the last dimension are a quaternion of the form (r,i,j,k) and the last 3 values are the translation (x,y,z).

  • ref_points (Array (N,3)) – Reference points (not batched; the same reference points are used for every batch element).

  • fit_points (Array (batch, M,3)) – Batched set of points to which the SE(3) transformations are applied to maximize shape similarity with ref_points. To optimize the same fit_points under a batch of different se3_params, use jnp.tile(fit_points, (batch, 1, 1)).

  • alpha (float) – Gaussian width parameter used in scoring function.

Returns:

loss – 1 - Tanimoto score

Return type:

Array (1,)
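As a concrete example of the tiling tip in the fit_points description, starting 50 optimizations from different SE(3) parameters against one fit cloud gives the following batched shapes. The array contents are placeholders; only the shapes matter here.

```python
import jax.numpy as jnp

batch, M = 50, 12
fit_points = jnp.zeros((M, 3))  # placeholder single fit cloud
# A batch of identity quaternions (r=1) with zero translation.
se3_params = jnp.zeros((batch, 7)).at[:, 0].set(1.0)

# Repeat the same fit cloud once per SE(3) parameter set.
fit_batch = jnp.tile(fit_points, (batch, 1, 1))
print(fit_batch.shape)  # (50, 12, 3)
```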

shepherd_score.alignment._jax.objective_ROCS_overlay_with_avoid_jax(se3_params, ref_points, fit_points, alpha, fit_points_for_avoid, avoid_points, avoid_min_dist, avoid_weight)

Objective function to optimize ROCS overlay while penalizing overlap with a set of avoid points. JAX implementation.

Parameters:
  • se3_params (Array (batch, 7)) – Parameters for the SE(3) transformation; a leading batch dimension is expected. The first 4 values in the last dimension are a quaternion of the form (r,i,j,k) and the last 3 values are the translation (x,y,z).

  • ref_points (Array (N,3)) – Reference points (not batched; the same reference points are used for every batch element).

  • fit_points (Array (M,3)) – Set of points to apply SE(3) transformations to maximize shape similarity with ref_points.

  • alpha (float) – Gaussian width parameter used in scoring function.

  • fit_points_for_avoid (Array (M,3)) – Points to which the SE(3) transformation is applied before comparing them with avoid_points.

  • avoid_points (Array (K,3)) – Penalize overlap with these points.

  • avoid_min_dist (float) – Minimum distance with no penalization between fit_points_for_avoid and avoid_points.

  • avoid_weight (float) – Weight for the avoid_points term in the scoring function.

Returns:

loss – 1 - (Tanimoto score of fit/ref) + avoid_weight * (max pairwise overlap of fit/avoid)

Return type:

Array (1,)

shepherd_score.alignment._jax.batched_obj_ROCS_overlay_precomputed(se3_params, ref_points, fit_points, alpha, VAA, VBB)

Single-instance ROCS objective using precomputed self-overlaps.
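Precomputing self-overlaps works because a self-overlap such as VBB depends only on pairwise distances within one point cloud, and rigid SE(3) motions preserve those distances, so it can be computed once outside the optimization loop instead of at every step. The sketch below checks that invariance numerically with an assumed Gaussian kernel exp(-alpha/2 · d²); self_overlap is a hypothetical helper, not the library's function.

```python
import jax
import jax.numpy as jnp


def self_overlap(points, alpha):
    # Assumed kernel: sum of exp(-alpha/2 * d^2) over all pairs within one cloud.
    d2 = jnp.sum((points[:, None, :] - points[None, :, :]) ** 2, axis=-1)
    return jnp.sum(jnp.exp(-0.5 * alpha * d2))


pts = jax.random.normal(jax.random.PRNGKey(1), (10, 3))
alpha = 0.81  # example value

# Rigid motion: 90-degree rotation about z followed by a translation.
R = jnp.array([[0., -1., 0.], [1., 0., 0.], [0., 0., 1.]])
moved = pts @ R.T + jnp.array([3., -2., 1.])

VBB_before = self_overlap(pts, alpha)
VBB_after = self_overlap(moved, alpha)  # equal up to floating-point rounding
```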

shepherd_score.alignment._jax.objective_ROCS_overlay_precomputed_jax(se3_params, ref, fit, alpha, VAA, VBB)

shepherd_score.alignment._jax.batched_obj_ROCS_overlay_precomputed_mask(se3_params, ref_points, fit_points, mask_ref, mask_fit, alpha, VAA, VBB)

Single-instance masked ROCS objective using precomputed self-overlaps.

shepherd_score.alignment._jax.objective_ROCS_overlay_precomputed_jax_mask(se3_params, ref, fit, mask_ref, mask_fit, alpha, VAA, VBB)

shepherd_score.alignment._jax.batched_obj_ROCS_esp_overlay_precomputed(se3_params, ref_points, fit_points, ref_charges, fit_charges, alpha, lam, VAA, VBB)

Single-instance non-masked ROCS ESP objective using precomputed self-overlaps.

shepherd_score.alignment._jax.objective_ROCS_esp_overlay_precomputed_jax(se3_params, ref, fit, ref_charges, fit_charges, alpha, lam, VAA, VBB)

shepherd_score.alignment._jax.batched_obj_ROCS_esp_overlay_precomputed_mask(se3_params, ref_points, fit_points, ref_charges, fit_charges, mask_ref, mask_fit, alpha, lam, VAA, VBB)

Single-instance masked ROCS ESP objective using precomputed self-overlaps.

shepherd_score.alignment._jax.objective_ROCS_esp_overlay_precomputed_jax_mask(se3_params, ref, fit, ref_charges, fit_charges, mask_ref, mask_fit, alpha, lam, VAA, VBB)

shepherd_score.alignment._jax.optimize_ROCS_esp_overlay_jax_mask(ref_points, fit_points, ref_charges, fit_charges, mask_ref, mask_fit, alpha, lam, *, num_repeats=50, trans_centers=None, lr=0.1, max_num_steps=200, verbose=False)

Optimize alignment of fit_points with respect to ref_points using SE(3) transformations and maximizing masked electrostatic-weighted Gaussian overlap score.

Identical to optimize_ROCS_esp_overlay_jax but accepts binary mask arrays so that padded (zero) entries are excluded from the overlap computation. Padding all arrays to a common maximum shape and passing masks allows JAX’s XLA compiler to reuse a single compiled function across all pairs in a batch, avoiding recompilation overhead.
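The padding-plus-mask idea described above can be sketched as follows: every cloud is zero-padded to one common size so the compiled function always sees the same shapes, and a 0/1 mask zeroes every pair that touches a padded row. Without the mask, the padded zeros at the origin would add spurious overlap terms. The helper names pad_with_mask and masked_overlap are hypothetical, and the exp(-alpha/2 · d²) kernel is an assumption.

```python
import numpy as np
import jax.numpy as jnp


def pad_with_mask(points, max_n):
    """Hypothetical helper: zero-pad an (n, 3) array to (max_n, 3) and build a 0/1 mask."""
    n = points.shape[0]
    padded = np.zeros((max_n, 3), dtype=np.float32)
    padded[:n] = points
    mask = np.zeros(max_n, dtype=np.float32)
    mask[:n] = 1.0
    return jnp.asarray(padded), jnp.asarray(mask)


def masked_overlap(a, b, mask_a, mask_b, alpha):
    d2 = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    weights = mask_a[:, None] * mask_b[None, :]  # zero for any pair involving padding
    return jnp.sum(weights * jnp.exp(-0.5 * alpha * d2))


rng = np.random.default_rng(0)
ref = rng.normal(size=(7, 3)).astype(np.float32)
fit = rng.normal(size=(4, 3)).astype(np.float32)

# Every molecule in the batch is padded to the same size (16 here).
ref_p, mask_ref = pad_with_mask(ref, 16)
fit_p, mask_fit = pad_with_mask(fit, 16)

unpadded = masked_overlap(jnp.asarray(ref), jnp.asarray(fit), jnp.ones(7), jnp.ones(4), 0.81)
padded = masked_overlap(ref_p, fit_p, mask_ref, mask_fit, 0.81)  # matches `unpadded`
```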

Parameters:
  • ref_points (Array (N, 3)) – Reference points (may include zero-padding beyond mask_ref).

  • fit_points (Array (M, 3)) – Fit points (may include zero-padding beyond mask_fit).

  • ref_charges (Array (N,)) – Charges for reference points.

  • fit_charges (Array (M,)) – Charges for fit points.

  • mask_ref (Array (N,)) – Binary mask: 1 for real atoms, 0 for padding.

  • mask_fit (Array (M,)) – Binary mask: 1 for real atoms, 0 for padding.

  • alpha (float) – Gaussian width parameter.

  • lam (float) – Charge weighting parameter.

  • num_repeats (int (default=50))

  • trans_centers (array (P, 3) (default=None))

  • lr (float (default=0.1))

  • max_num_steps (int (default=200))

  • verbose (bool (default=False))

Returns:

aligned_points : Array (M, 3)

SE3_transform : Array (4, 4)

score : Array (1,)

Return type:

tuple

shepherd_score.alignment._jax.batched_obj_ROCS_overlay_with_avoid_precomputed(se3_params, ref_points, fit_points, alpha, VAA, VBB, fit_points_for_avoid, avoid_points, avoid_min_dist, avoid_weight)

shepherd_score.alignment._jax.objective_ROCS_overlay_with_avoid_precomputed_jax(se3_params, ref, fit, alpha, VAA, VBB, fit_for_avoid, avoid, min_dist, weight)

shepherd_score.alignment._jax.optimize_ROCS_overlay_jax(ref_points, fit_points, alpha, *, fit_points_for_avoid=None, avoid_points=None, avoid_min_dist=2.0, avoid_weight=1.0, num_repeats=50, trans_centers=None, lr=0.1, max_num_steps=200, verbose=False)

Optimize alignment of fit_points with respect to ref_points using SE(3) transformations and maximizing Gaussian overlap score.

If num_repeats is 1, the initial guess for alignment is the identity rotation with aligned centers of mass (COMs). If num_repeats is 5 or greater, four of the initial guesses are obtained by aligning principal components.
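One common way to build PCA-based initial rotations is to align the principal axes of the centered fit cloud with those of the centered reference cloud. Each axis is defined only up to sign, and exactly four sign patterns keep the rotation proper (determinant +1), which matches the four PCA guesses mentioned above. The sketch below illustrates this general technique under that assumption; it is not the library's initialization routine, and the function names are hypothetical.

```python
import numpy as np


def principal_axes(points):
    """Columns are principal axes of the centered cloud, forced into a right-handed frame."""
    centered = points - points.mean(axis=0)
    _, vecs = np.linalg.eigh(centered.T @ centered)
    if np.linalg.det(vecs) < 0:
        vecs[:, -1] *= -1  # flip one axis so that det = +1
    return vecs


def pca_rotation_guesses(ref_points, fit_points):
    """Four proper rotations mapping fit's principal axes onto ref's, one per sign pattern."""
    U_ref = principal_axes(ref_points)
    U_fit = principal_axes(fit_points)
    # Sign patterns with an even number of -1 entries keep det(R) = +1.
    signs = [(1, 1, 1), (1, -1, -1), (-1, 1, -1), (-1, -1, 1)]
    return [U_ref @ np.diag(s) @ U_fit.T for s in signs]


rng = np.random.default_rng(0)
ref = rng.normal(size=(20, 3)) * np.array([3.0, 1.5, 0.5])  # anisotropic clouds
fit = rng.normal(size=(15, 3)) * np.array([3.0, 1.5, 0.5])
guesses = pca_rotation_guesses(ref, fit)
```

Each guess would then be applied to the COM-centered fit cloud, so together with the identity rotation this gives five COM-aligned starting points.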

Parameters:
  • ref_points (Array (N,3)) – Reference points.

  • fit_points (Array (M,3)) – Set of points to apply SE(3) transformations to maximize shape similarity with ref_points.

  • alpha (float) – Gaussian width parameter used in scoring function.

  • fit_points_for_avoid (Array (M,3) (default=None)) – Points to which the SE(3) transformation is applied before comparing them with avoid_points.

  • avoid_points (Array (K,3) (default=None)) – If not None, these are points that are used in an additional term in the objective function to penalize overlap with these points.

  • avoid_min_dist (float (default=2.0)) – Minimum distance with no penalization between fit_points_for_avoid and avoid_points.

  • avoid_weight (float (default=1.0)) – Weight for the avoid_points term in the scoring function.

  • num_repeats (int (default=50)) – Number of different random initializations of SE(3) transformation parameters.

  • trans_centers (array (P, 3) (default=None)) – Locations to which fit_points’ center of mass is translated as initial guesses for optimization. At each translation center, 10 rotations are also sampled, so the number of initializations is (# translation centers * 10 + 5), where the 5 comes from the identity and 4 PCA-based guesses with aligned COMs. If None, num_repeats initial rotations are used with aligned COMs.

  • lr (float (default=0.1)) – Learning rate or step size for optimization.

  • max_num_steps (int (default=200)) – Maximum number of steps to optimize over.

  • verbose (bool (default=False)) – Print the initial and final similarity scores, as well as scores every 100 steps during optimization.

Returns:

aligned_points : Array (M,3)

The transformed point cloud for fit_points using the optimized SE(3) transformation for alignment with ref_points.

SE3_transform : Array (4,4)

Optimized SE(3) transformation matrix used to obtain aligned_points from fit_points.

score : Array (1,)

Tanimoto shape similarity score for the optimal transformation.

Return type:

tuple
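At a high level, these optimize_* routines run gradient-based optimization from several SE(3) starting points and keep the best result. A minimal sketch of that multi-start pattern follows. It assumes plain gradient descent on 1 - Tanimoto with a fixed step size and the exp(-alpha/2 · d²) overlap kernel; the library's actual optimizer, initialization scheme, and stopping criteria may differ, and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def quat_to_rotation(q):
    """(r, i, j, k) quaternion -> 3x3 rotation matrix; q is normalized first."""
    w, x, y, z = q / jnp.linalg.norm(q)
    return jnp.array([
        [1 - 2 * (y**2 + z**2), 2 * (x * y - z * w),   2 * (x * z + y * w)],
        [2 * (x * y + z * w),   1 - 2 * (x**2 + z**2), 2 * (y * z - x * w)],
        [2 * (x * z - y * w),   2 * (y * z + x * w),   1 - 2 * (x**2 + y**2)],
    ])


def overlap(a, b, alpha):
    """Assumed kernel: sum of exp(-alpha/2 * d^2) over all point pairs."""
    d2 = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return jnp.sum(jnp.exp(-0.5 * alpha * d2))


def tanimoto(params, ref, fit, alpha):
    moved = fit @ quat_to_rotation(params[:4]).T + params[4:]
    vab = overlap(ref, moved, alpha)
    return vab / (overlap(ref, ref, alpha) + overlap(moved, moved, alpha) - vab)


def optimize_sketch(ref, fit, alpha, starts, lr=0.1, steps=200):
    """Hypothetical multi-start optimizer: gradient descent on 1 - Tanimoto per start."""
    def loss(p):
        return 1.0 - tanimoto(p, ref, fit, alpha)

    batched_grad = jax.jit(jax.vmap(jax.grad(loss)))  # one gradient per start
    params = starts
    for _ in range(steps):
        params = params - lr * batched_grad(params)
    scores = jax.vmap(lambda p: tanimoto(p, ref, fit, alpha))(params)
    best = jnp.argmax(scores)  # keep the best-scoring start
    return params[best], scores[best]


ref = jax.random.normal(jax.random.PRNGKey(0), (8, 3))
fit = ref + jnp.array([0.5, 0.0, 0.0])  # same cloud, shifted
identity = jnp.array([1., 0., 0., 0., 0., 0., 0.])
noise = 0.1 * jax.random.normal(jax.random.PRNGKey(1), (5, 7))
starts = identity + noise.at[0].set(0.0)  # first start is exactly the identity
best_params, best_score = optimize_sketch(ref, fit, 0.81, starts)
```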

shepherd_score.alignment._jax.optimize_ROCS_overlay_jax_mask(ref_points, fit_points, mask_ref, mask_fit, alpha, *, num_repeats=50, trans_centers=None, lr=0.1, max_num_steps=200, verbose=False)

Optimize alignment of fit_points with respect to ref_points using SE(3) transformations and maximizing masked Gaussian overlap score.

Identical to optimize_ROCS_overlay_jax but accepts binary mask arrays so that padded (zero) entries are excluded from the overlap computation. Padding all arrays to a common maximum shape and passing masks allows JAX’s XLA compiler to reuse a single compiled function across all pairs in a batch, avoiding recompilation overhead.

Parameters:
  • ref_points (Array (N, 3)) – Reference points (may include zero-padding beyond mask_ref).

  • fit_points (Array (M, 3)) – Fit points (may include zero-padding beyond mask_fit).

  • mask_ref (Array (N,)) – Binary mask: 1 for real atoms, 0 for padding.

  • mask_fit (Array (M,)) – Binary mask: 1 for real atoms, 0 for padding.

  • alpha (float) – Gaussian width parameter.

  • num_repeats (int (default=50))

  • trans_centers (array (P, 3) (default=None))

  • lr (float (default=0.1))

  • max_num_steps (int (default=200))

  • verbose (bool (default=False))

Returns:

aligned_points : Array (M, 3)

SE3_transform : Array (4, 4)

score : Array (1,)

Return type:

tuple

shepherd_score.alignment._jax.batched_obj_ROCS_esp_overlay_helper(se3_params, ref_points, fit_points, ref_charges, fit_charges, alpha, lam)

Objective function to optimize ROCS ESP overlay. JAX implementation.

Parameters:
  • se3_params (Array (7,)) – Parameters for the SE(3) transformation. The first 4 values are a quaternion of the form (r,i,j,k) and the last 3 values are the translation (x,y,z).

  • ref_points (Array (N,3)) – Reference points.

  • fit_points (Array (M,3)) – Set of points to apply SE(3) transformations to maximize shape similarity with ref_points.

  • alpha (float) – Gaussian width parameter used in scoring function.

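  • ref_charges (Array (N,)) – Electric potential at the corresponding ref_points coordinates.

  • fit_charges (Array (M,)) – Electric potential at the corresponding fit_points coordinates.

  • lam (float) – Scaling term for charges used in the exponential kernel of the ESP scoring function.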

Returns:

loss – 1 - Tanimoto score

Return type:

Array (1,)

shepherd_score.alignment._jax.objective_ROCS_esp_overlay_jax(se3_params, ref_points, fit_points, ref_charges, fit_charges, alpha, lam)

Objective function to optimize ROCS ESP overlay.

Parameters:
  • se3_params (Array (batch, 7)) – Parameters for the SE(3) transformation. The first 4 values in the last dimension are a quaternion of the form (r,i,j,k) and the last 3 values are the translation (x,y,z).

  • ref_points (Array (N,3)) – Reference points.

  • fit_points (Array (M,3)) – Set of points to apply SE(3) transformations to maximize shape similarity with ref_points.

  • ref_charges (Array (N,)) – Electric potential at the corresponding ref_points coordinates.

  • fit_charges (Array (M,)) – Electric potential at the corresponding fit_points coordinates.

  • alpha (float) – Gaussian width parameter used in scoring function.

  • lam (float) – Scaling term for charges used in the exponential kernel of the ESP scoring function.

Returns:

loss – 1 - mean(ESP Tanimoto score).

Return type:

Array (1,)
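The lam parameter is described as a scaling term in an exponential kernel on charges. One plausible form, assumed here and not confirmed by this page, multiplies the shape kernel by exp(-(q_i - q_j)² / lam), so point pairs with similar electrostatic potential contribute more overlap. The sketch below scores identical shapes with matching and with sign-flipped charges; esp_overlay and esp_tanimoto are hypothetical names.

```python
import jax
import jax.numpy as jnp


def esp_overlay(a, b, qa, qb, alpha, lam):
    """Assumed form: shape kernel times a charge-similarity kernel exp(-(qa_i - qb_j)^2 / lam)."""
    d2 = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    dq2 = (qa[:, None] - qb[None, :]) ** 2
    return jnp.sum(jnp.exp(-0.5 * alpha * d2) * jnp.exp(-dq2 / lam))


def esp_tanimoto(a, b, qa, qb, alpha, lam):
    vab = esp_overlay(a, b, qa, qb, alpha, lam)
    vaa = esp_overlay(a, a, qa, qa, alpha, lam)
    vbb = esp_overlay(b, b, qb, qb, alpha, lam)
    return vab / (vaa + vbb - vab)


pts = jax.random.normal(jax.random.PRNGKey(0), (6, 3))
q = jnp.linspace(-0.5, 0.5, 6)
same = esp_tanimoto(pts, pts, q, q, 0.81, 0.3)      # identical shape and charges: 1.0
flipped = esp_tanimoto(pts, pts, q, -q, 0.81, 0.3)  # same shape, opposite charges: lower
```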

shepherd_score.alignment._jax.optimize_ROCS_esp_overlay_jax(ref_points, fit_points, ref_charges, fit_charges, alpha, lam, num_repeats=50, trans_centers=None, lr=0.1, max_num_steps=200, verbose=False)

Optimize alignment of fit_points with respect to ref_points using SE(3) transformations and maximizing electrostatic-weighted Gaussian overlap score.

Parameters:
  • ref_points (Array (N,3)) – Reference points.

  • fit_points (Array (M,3)) – Set of points to apply SE(3) transformations to maximize shape similarity with ref_points.

  • ref_charges (Array (batch, N) or (N,)) – Electric potential at the corresponding ref_points coordinates.

  • fit_charges (Array (batch, M) or (M,)) – Electric potential at the corresponding fit_points coordinates.

  • alpha (float) – Gaussian width parameter used in scoring function.

  • lam (float) – Scaling term for charges used in the exponential kernel of the ESP scoring function.

  • num_repeats (int (default=50)) – Number of different random initializations of SE(3) transformation parameters.

  • trans_centers (array (P, 3) (default=None)) – Locations to which fit_points’ center of mass is translated as initial guesses for optimization. At each translation center, 10 rotations are also sampled, so the number of initializations is (# translation centers * 10 + 5), where the 5 comes from the identity and 4 PCA-based guesses with aligned COMs. If None, num_repeats initial rotations are used with aligned COMs.

  • lr (float (default=0.1)) – Learning rate or step size for optimization.

  • max_num_steps (int (default=200)) – Maximum number of steps to optimize over.

  • verbose (bool (default=False)) – Print the initial and final similarity scores, as well as scores every 100 steps during optimization.

Returns:

aligned_points : Array (M,3)

The transformed point cloud for fit_points using the optimized SE(3) transformation for alignment with ref_points.

SE3_transform : Array (4,4)

Optimized SE(3) transformation matrix used to obtain aligned_points from fit_points.

score : Array (1,)

Tanimoto shape+ESP similarity score for the optimal transformation.

Return type:

tuple

shepherd_score.alignment._jax.batched_obj_esp_combo_score_helper(se3_params, ref_centers_w_H, fit_centers_w_H, ref_centers, fit_centers, ref_points, fit_points, ref_partial_charges, fit_partial_charges, ref_surf_esp, fit_surf_esp, ref_radii, fit_radii, alpha, lam, probe_radii=1.0, esp_weight=0.5)

Helper function that applies the se3_params transformation to all fit-related coordinates and computes the score for that transformation.

Return type:

jax.Array

shepherd_score.alignment._jax.objective_esp_combo_score_overlay_jax(se3_params, ref_centers_w_H, fit_centers_w_H, ref_centers, fit_centers, ref_points, fit_points, ref_partial_charges, fit_partial_charges, ref_surf_esp, fit_surf_esp, ref_radii, fit_radii, alpha, lam, probe_radii=1.0, esp_weight=0.5)

Computes the ESP combo score in batch, takes the mean, and converts it to a loss.

Return type:

jax.Array

shepherd_score.alignment._jax.convert_to_jnp_array(arr)

shepherd_score.alignment._jax.optimize_esp_combo_score_overlay_jax(ref_centers_w_H, fit_centers_w_H, ref_centers, fit_centers, ref_points, fit_points, ref_partial_charges, fit_partial_charges, ref_surf_esp, fit_surf_esp, ref_radii, fit_radii, alpha, lam, probe_radius=1.0, esp_weight=0.5, num_repeats=50, trans_centers=None, lr=0.1, max_num_steps=200, verbose=False)

Optimize alignment of fit_points with respect to ref_points using SE(3) transformations and maximizing ShaEP score.

Parameters:
  • ref_centers_w_H (Array (N + n_H, 3)) – Coordinates of atom centers INCLUDING hydrogens of reference molecule. Used for computing electrostatic potential. Same for fit_centers_w_H except (M + m_H, 3).

  • ref_centers (Array (N, 3) or (n_surf, 3)) – Coordinates of points for the reference molecule used to compute shape similarity. Use atom centers for volumetric similarity. Use surface centers for surface similarity. Same for fit_centers except (M, 3) or (m_surf, 3).

  • ref_points (Array (n_surf, 3)) – Coordinates of surface points for the reference molecule. Same for fit_points except (m_surf, 3).

  • ref_partial_charges (Array (N + n_H,)) – Partial charges corresponding to the atoms in ref_centers_w_H. Same for fit_partial_charges except (M + m_H,).

  • ref_surf_esp (Array (n_surf,)) – The electrostatic potential calculated at each surface point (ref_points). Same for fit_surf_esp except (m_surf,).

  • ref_radii (Array (N + n_H,)) – vdW radii (angstroms) corresponding to the atoms in ref_centers_w_H. Same for fit_radii except (M + m_H,).

  • alpha (float) – Gaussian width parameter used in shape similarity scoring function.

  • lam (float) – Electrostatic potential weighting parameter (smaller = higher weight). A value of 0.001 is suggested, chosen based on empirical observations of the distribution of scores generated by _esp_comparison before summation.

  • probe_radius (float (default = 1.0)) – Surface points found within vdW radii + probe radius will be masked out. Surface generation uses a probe radius of 1.2 (radius of hydrogen), so a slightly smaller radius is used here to be more tolerant.

  • esp_weight (float (default = 0.5)) – Weight placed on electrostatic similarity relative to shape similarity: 0 uses only shape similarity and 1 uses only electrostatic similarity.

  • num_repeats (int (default=50)) – Number of different random initializations of SE(3) transformation parameters.

  • trans_centers (array (P, 3) (default=None)) – Locations to which fit_points’ center of mass is translated as initial guesses for optimization. At each translation center, 10 rotations are also sampled, so the number of initializations is (# translation centers * 10 + 5), where the 5 comes from the identity and 4 PCA-based guesses with aligned COMs. If None, num_repeats initial rotations are used with aligned COMs.

  • lr (float (default=0.1)) – Learning rate or step size for optimization.

  • max_num_steps (int (default=200)) – Maximum number of steps to optimize over.

  • verbose (bool (default=False)) – Print the initial and final similarity scores, as well as scores every 100 steps during optimization.

  • fit_centers_w_H (jax.Array | ndarray)

  • fit_centers (jax.Array | ndarray)

  • fit_points (jax.Array | ndarray)

  • fit_partial_charges (jax.Array | ndarray | List)

  • fit_surf_esp (jax.Array | ndarray)

  • fit_radii (jax.Array | ndarray | List)

Returns:

aligned_points : Array (M,3)

The transformed point cloud for fit_points using the optimized SE(3) transformation for alignment with ref_points.

SE3_transform : Array (4,4)

Optimized SE(3) transformation matrix used to obtain aligned_points from fit_points.

score : Array (1,)

ShaEP similarity score for the optimal transformation.

Return type:

tuple
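The esp_weight description (0 = only shape similarity, 1 = only electrostatic similarity) is consistent with a linear interpolation between the two similarity terms. The sketch below assumes exactly that blend; how shepherd_score combines the terms internally is not stated on this page, and combo_score_sketch is a hypothetical name.

```python
def combo_score_sketch(shape_score: float, esp_score: float, esp_weight: float = 0.5) -> float:
    """Hypothetical linear blend: esp_weight=0 -> shape only, esp_weight=1 -> ESP only."""
    if not 0.0 <= esp_weight <= 1.0:
        raise ValueError("esp_weight must lie in [0, 1]")
    return (1.0 - esp_weight) * shape_score + esp_weight * esp_score


shape_only = combo_score_sketch(0.8, 0.4, esp_weight=0.0)
esp_only = combo_score_sketch(0.8, 0.4, esp_weight=1.0)
balanced = combo_score_sketch(0.8, 0.4)  # default weight of 0.5
```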

shepherd_score.alignment._jax.batched_obj_pharm_overlay_helper(se3_params, ref_pharms, fit_pharms, ref_anchors, fit_anchors, ref_vectors, fit_vectors, similarity='tanimoto', extended_points=False, only_extended=False)

Objective function to optimize pharmacophore overlay for a single instance.

Parameters:
  • se3_params (jax.Array)

  • ref_pharms (jax.Array)

  • fit_pharms (jax.Array)

  • ref_anchors (jax.Array)

  • fit_anchors (jax.Array)

  • ref_vectors (jax.Array)

  • fit_vectors (jax.Array)

  • similarity (Literal['tanimoto', 'tversky', 'tversky_ref', 'tversky_fit'])

  • extended_points (bool)

  • only_extended (bool)

Return type:

jax.Array
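The similarity options name two overlap-based indices. For reference, the Tversky index generalizes Tanimoto by weighting the non-overlapping parts of each set, VAB / (VAB + a·(VAA - VAB) + b·(VBB - VAB)), and it reduces to Tanimoto when a = b = 1. Which weights 'tversky', 'tversky_ref', and 'tversky_fit' correspond to is not documented here, so the sketch leaves them as explicit parameters; both helper names are hypothetical.

```python
def tanimoto_index(vab: float, vaa: float, vbb: float) -> float:
    """Tanimoto similarity from overlap volumes."""
    return vab / (vaa + vbb - vab)


def tversky_index(vab: float, vaa: float, vbb: float, a: float, b: float) -> float:
    """Tversky similarity from overlap volumes; a weights ref-only, b weights fit-only overlap."""
    return vab / (vab + a * (vaa - vab) + b * (vbb - vab))


vab, vaa, vbb = 3.0, 5.0, 4.0  # illustrative overlap values
t = tanimoto_index(vab, vaa, vbb)                 # 3 / 6
tv_equal = tversky_index(vab, vaa, vbb, 1.0, 1.0)  # same as Tanimoto
tv_ref = tversky_index(vab, vaa, vbb, 1.0, 0.0)    # VAB / VAA
tv_fit = tversky_index(vab, vaa, vbb, 0.0, 1.0)    # VAB / VBB
```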

shepherd_score.alignment._jax.objective_pharm_overlay_jax(se3_params, ref_pharms, fit_pharms, ref_anchors, fit_anchors, ref_vectors, fit_vectors, similarity='tanimoto', extended_points=False, only_extended=False)

Objective function to optimize pharmacophore overlay. Batched.

Parameters:
  • se3_params (jax.Array)

  • ref_pharms (jax.Array)

  • fit_pharms (jax.Array)

  • ref_anchors (jax.Array)

  • fit_anchors (jax.Array)

  • ref_vectors (jax.Array)

  • fit_vectors (jax.Array)

  • similarity (Literal['tanimoto', 'tversky', 'tversky_ref', 'tversky_fit'])

  • extended_points (bool)

  • only_extended (bool)

Return type:

jax.Array

shepherd_score.alignment._jax.optimize_pharm_overlay_jax(ref_pharms, fit_pharms, ref_anchors, fit_anchors, ref_vectors, fit_vectors, similarity='tanimoto', extended_points=False, only_extended=False, num_repeats=50, trans_centers=None, lr=0.1, max_num_steps=200, verbose=False)

Optimize alignment of fit_anchors with respect to ref_anchors using SE(3) transformations and maximizing pharmacophore overlap score. JAX implementation.

Parameters:
  • ref_pharms (jax.Array)

  • fit_pharms (jax.Array)

  • ref_anchors (jax.Array)

  • fit_anchors (jax.Array)

  • ref_vectors (jax.Array)

  • fit_vectors (jax.Array)

  • similarity (Literal['tanimoto', 'tversky', 'tversky_ref', 'tversky_fit'])

  • extended_points (bool)

  • only_extended (bool)

  • num_repeats (int)

  • trans_centers (jax.Array | ndarray | None)

  • lr (float)

  • max_num_steps (int)

  • verbose (bool)

Return type:

Tuple[jax.Array, jax.Array, jax.Array, jax.Array]

shepherd_score.alignment._jax.optimize_pharm_overlay_jax_vectorized(ref_pharms, fit_pharms, ref_anchors, fit_anchors, ref_vectors, fit_vectors, similarity='tanimoto', extended_points=False, only_extended=False, num_repeats=50, trans_centers=None, lr=0.1, max_num_steps=200, verbose=False)

Optimize pharmacophore overlay with JAX using the fully vectorized scoring function.

Parameters:
  • ref_pharms (jax.Array)

  • fit_pharms (jax.Array)

  • ref_anchors (jax.Array)

  • fit_anchors (jax.Array)

  • ref_vectors (jax.Array)

  • fit_vectors (jax.Array)

  • similarity (str)

  • extended_points (bool)

  • only_extended (bool)

  • num_repeats (int)

  • trans_centers (jax.Array | ndarray | None)

  • lr (float)

  • max_num_steps (int)

  • verbose (bool)

Return type:

Tuple[jax.Array, jax.Array, jax.Array, jax.Array]

shepherd_score.alignment._jax.optimize_pharm_overlay_jax_vectorized_mask(ref_pharms, fit_pharms, ref_anchors, fit_anchors, ref_vectors, fit_vectors, mask_ref, mask_fit, similarity='tanimoto', extended_points=False, only_extended=False, num_repeats=50, trans_centers=None, init_ref_anchors=None, init_fit_anchors=None, lr=0.1, max_num_steps=200, verbose=False)

Optimize pharmacophore overlay with padded/masked arrays via JAX.

Uses get_overlap_pharm_jax_vectorized_mask so that padding entries never contribute to the overlap. Accepts init_ref_anchors / init_fit_anchors (original unpadded arrays) for PCA/COM-based SE(3) initialization to avoid zero-padding bias.
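The zero-padding bias mentioned above is easy to see for the center of mass (COM): averaging over padded rows pulls the COM toward the origin, while a mask-weighted average (or the original unpadded array) recovers the true value. The anchor coordinates below are illustrative.

```python
import numpy as np

anchors = np.array([[10.0, 0.0, 0.0], [12.0, 0.0, 0.0]])
padded = np.vstack([anchors, np.zeros((6, 3))])  # padded to 8 rows with zeros
mask = np.array([1, 1, 0, 0, 0, 0, 0, 0], dtype=float)

naive_com = padded.mean(axis=0)                                 # biased toward the origin
masked_com = (padded * mask[:, None]).sum(axis=0) / mask.sum()  # matches the unpadded COM
```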

Parameters:
  • ref_pharms (jax.Array)

  • fit_pharms (jax.Array)

  • ref_anchors (jax.Array)

  • fit_anchors (jax.Array)

  • ref_vectors (jax.Array)

  • fit_vectors (jax.Array)

  • mask_ref (jax.Array)

  • mask_fit (jax.Array)

  • similarity (str)

  • extended_points (bool)

  • only_extended (bool)

  • num_repeats (int)

  • trans_centers (jax.Array | ndarray | None)

  • init_ref_anchors (ndarray | None)

  • init_fit_anchors (ndarray | None)

  • lr (float)

  • max_num_steps (int)

  • verbose (bool)

Return type:

Tuple[jax.Array, jax.Array, jax.Array, jax.Array]