"""
Alignment implementation in Jax.
"""
from functools import lru_cache, partial
from typing import Union, List, Tuple, Callable, Optional

import jax
import jax.numpy as jnp
import numpy as np
import optax
import torch
from jax import vmap, jit, value_and_grad, Array, lax

from shepherd_score.score.gaussian_overlap_jax import get_overlap_jax, get_linear_hard_sphere_overlap_jax, VAB_2nd_order_jax, VAB_2nd_order_jax_mask, get_overlap_jax_mask
from shepherd_score.score.electrostatic_scoring_jax import get_overlap_esp_jax, VAB_2nd_order_esp_jax, esp_combo_score_jax, VAB_2nd_order_esp_jax_mask, get_overlap_esp_jax_mask
from shepherd_score.score.pharmacophore_scoring_jax import get_overlap_pharm_jax, get_overlap_pharm_jax_vectorized, get_overlap_pharm_jax_vectorized_mask, _SIM_TYPE
from shepherd_score.alignment.utils.pca_jax import quaternions_for_principal_component_alignment_jax, rotation_axis_jax, vmap_angle_between_vecs_jax, vmap_quaternion_from_axis_angle_jax
from shepherd_score.alignment.utils.se3_jax import get_SE3_transform_jax, apply_SE3_transform_jax, apply_SO3_transform_jax
from shepherd_score.alignment import _initialize_se3_params, _initialize_se3_params_with_translations
# Batched wrappers shared by the optimizers in this module.
# In every ``in_axes`` tuple, ``0`` marks an argument that carries a leading
# batch axis and ``None`` marks an argument shared by the whole batch.

# Shape overlap: only the second argument (the candidate fit points) is batched.
vmap_get_overlap_jax = vmap(get_overlap_jax, in_axes=(None, 0, None))

# ESP overlap: only the second argument is batched.
vmap_get_overlap_esp_jax = vmap(
    get_overlap_esp_jax,
    in_axes=(None, 0, None, None, None, None),
)

# ESP combo score: arguments 1, 3 and 5 are batched; the other 13 are shared.
vmap_esp_combo_score = vmap(
    esp_combo_score_jax,
    in_axes=(None, 0,
             None, 0,
             None, 0,
             None, None,
             None, None,
             None, None,
             None, None, None, None),
)

# One point cloud transformed by a batch of SE(3) matrices.
vmap_apply_SE3_transform_jax = jit(vmap(apply_SE3_transform_jax, in_axes=(None, 0)))

# A batch of parameter vectors converted to a batch of SE(3) matrices.
vmap_get_SE3_transform_jax = jit(vmap(get_SE3_transform_jax, in_axes=0))

# One array transformed by a batch of transforms (second argument batched).
# NOTE: ``apply_SO3_transform_jax`` must be imported at the top of the file;
# without that import this line raises NameError when the module is loaded.
vmap_apply_SO3_transform_jax = vmap(apply_SO3_transform_jax, in_axes=(None, 0))
def _get_points_fibonacci_jax(num_samples: int) -> Array:
"""
Generate points on unit sphere using fibonacci approach. Jax implementation.
Adapted from Morfeus:
https://github.com/digital-chemistry-laboratory/morfeus/blob/main/morfeus/geometry.py
Parameters
----------
num_samples : int
Number of points to sample from the surface of a sphere
Returns
-------
Array (num_samples,3)
Coordinates of the sampled points.
"""
offset = 2.0 / num_samples
increment = jnp.pi * (3.0 - jnp.sqrt(5.0))
i = jnp.arange(num_samples)
y = ((i * offset) - 1) + (offset / 2)
r = jnp.sqrt(1 - jnp.square(y))
phi = jnp.mod((i + 1), num_samples) * increment
x = jnp.cos(phi) * r
z = jnp.sin(phi) * r
points = jnp.column_stack((x, y, z))
return points
def _objective_ROCS_overlay_jax(se3_params: Array,
                                ref_points: Array,
                                fit_points: Array,
                                alpha: float
                                ) -> Array:
    """
    Score a single SE(3) pose of ``fit_points`` against ``ref_points``.
    Jax implementation.

    Parameters
    ----------
    se3_params : Array (7,)
        SE(3) parameters: a quaternion (r,i,j,k) followed by a translation (x,y,z).
    ref_points : Array (N,3)
        Reference points.
    fit_points : Array (M,3)
        Points that are transformed before being compared with ``ref_points``.
    alpha : float
        Gaussian width parameter used in the scoring function.

    Returns
    -------
    Tanimoto overlap score of the transformed ``fit_points`` against ``ref_points``.
    """
    moved_fit = apply_SE3_transform_jax(fit_points, get_SE3_transform_jax(se3_params))
    return get_overlap_jax(ref_points, moved_fit, alpha)


# Evaluates a batch of poses (leading axis of se3_params) on one shared pair of point clouds.
batched_obj_ROCS_overlay_helper = vmap(_objective_ROCS_overlay_jax, in_axes=(0, None, None, None))
def _score_ROCS_overlay_with_avoid_jax(ref_points: Array,
                                       fit_points: Array,
                                       alpha: float,
                                       fit_points_for_avoid: Array,
                                       avoid_points: Array,
                                       avoid_min_dist: float,
                                       avoid_weight: float) -> Array:
    """
    Shape overlap minus a weighted penalty for approaching ``avoid_points``.

    No transformation is applied here: ``fit_points`` and ``fit_points_for_avoid``
    are scored as given. See ``_objective_ROCS_overlay_with_avoid_jax`` for the
    meaning of each argument.
    """
    shape_overlap = get_overlap_jax(ref_points, fit_points, alpha)
    penalty = get_linear_hard_sphere_overlap_jax(avoid_points, fit_points_for_avoid, avoid_min_dist)
    return shape_overlap - avoid_weight * penalty


# Counterpart of vmap_get_overlap_jax for the avoid-points case: both sets of
# fit points (arguments 1 and 3) are batched, everything else is shared.
vmap_score_ROCS_overlay_with_avoid_jax = vmap(
    _score_ROCS_overlay_with_avoid_jax,
    in_axes=(None, 0, None, 0, None, None, None),
)
def _objective_ROCS_overlay_with_avoid_jax(se3_params: Array,
                                           ref_points: Array,
                                           fit_points: Array,
                                           alpha: float,
                                           fit_points_for_avoid: Array,
                                           avoid_points: Array,
                                           avoid_min_dist: float,
                                           avoid_weight: float,
                                           ) -> Array:
    """
    Score a single SE(3) pose while penalizing proximity to ``avoid_points``.
    Jax implementation.

    The same transformation is applied to ``fit_points`` and to
    ``fit_points_for_avoid`` before scoring.

    Parameters
    ----------
    se3_params : Array (7,)
        SE(3) parameters: a quaternion (r,i,j,k) followed by a translation (x,y,z).
    ref_points : Array (N,3)
        Reference points.
    fit_points : Array (M,3)
        Points transformed and compared with ``ref_points`` for shape similarity.
    alpha : float
        Gaussian width parameter used in the scoring function.
    fit_points_for_avoid : Array (M,3)
        Points transformed and compared with ``avoid_points``.
    avoid_points : Array (K,3)
        Overlap with these points is penalized.
    avoid_min_dist : float
        Minimum distance with no penalization between ``fit_points_for_avoid`` and
        ``avoid_points``.
    avoid_weight : float
        Weight of the avoid term in the score.

    Returns
    -------
    Score where higher is better: the shape overlap minus ``avoid_weight`` times the
    avoid penalty. Documented range is [-avoid_weight, 1]: 1 is complete fit/ref overlap
    with no avoid overlap, -avoid_weight is no fit/ref overlap and complete avoid overlap.
    """
    pose = get_SE3_transform_jax(se3_params)
    moved_fit = apply_SE3_transform_jax(fit_points, pose)
    moved_fit_for_avoid = apply_SE3_transform_jax(fit_points_for_avoid, pose)
    return _score_ROCS_overlay_with_avoid_jax(
        ref_points,
        moved_fit,
        alpha,
        moved_fit_for_avoid,
        avoid_points,
        avoid_min_dist,
        avoid_weight,
    )


# Evaluates a batch of poses (leading axis of se3_params); all other inputs are shared.
batched_obj_ROCS_overlay_with_avoid_helper = vmap(
    _objective_ROCS_overlay_with_avoid_jax,
    in_axes=(0, None, None, None, None, None, None, None),
)
[docs]
def objective_ROCS_overlay_jax(se3_params: Array,
                               ref_points: Array,
                               fit_points: Array,
                               alpha: float
                               ) -> Array:
    """
    Loss for optimizing a batch of ROCS overlay poses.
    Jax implementation.

    Parameters
    ----------
    se3_params : Array (batch, 7)
        Batch of SE(3) parameters. In the last dimension the first 4 values are a
        quaternion (r,i,j,k) and the last 3 are a translation (x,y,z).
    ref_points : Array (N,3)
        Reference points, shared by every pose in the batch.
    fit_points : Array (M,3)
        Points to transform. Only ``se3_params`` is mapped over by the batched helper,
        so the same ``fit_points`` are used for every pose and must NOT be tiled.
    alpha : float
        Gaussian width parameter used in the scoring function.

    Returns
    -------
    loss : Array (scalar)
        1 - (mean Tanimoto score over the batch).
    """
    per_pose_scores = batched_obj_ROCS_overlay_helper(se3_params, ref_points, fit_points, alpha)
    mean_score = per_pose_scores.mean()
    return 1 - mean_score
[docs]
def objective_ROCS_overlay_with_avoid_jax(
    se3_params: Array,
    ref_points: Array,
    fit_points: Array,
    alpha: float,
    fit_points_for_avoid: Array,
    avoid_points: Array,
    avoid_min_dist: float,
    avoid_weight: float,
) -> Array:
    """
    Loss for optimizing a batch of ROCS overlay poses with an avoid-points penalty.
    Jax implementation.

    Parameters
    ----------
    se3_params : Array (batch, 7)
        Batch of SE(3) parameters. In the last dimension the first 4 values are a
        quaternion (r,i,j,k) and the last 3 are a translation (x,y,z).
    ref_points : Array (N,3)
        Reference points, shared by every pose in the batch.
    fit_points : Array (M,3)
        Points transformed and compared with ``ref_points`` for shape similarity.
    alpha : float
        Gaussian width parameter used in the scoring function.
    fit_points_for_avoid : Array (M,3)
        Points transformed and compared with ``avoid_points``.
    avoid_points : Array (K,3)
        Overlap with these points is penalized.
    avoid_min_dist : float
        Minimum distance with no penalization between ``fit_points_for_avoid`` and
        ``avoid_points``.
    avoid_weight : float
        Weight of the avoid term in the score.

    Returns
    -------
    loss : Array (scalar)
        1 - mean over the batch of
        (Tanimoto score of fit/ref - avoid_weight * avoid penalty of fit/avoid).
    """
    per_pose_scores = batched_obj_ROCS_overlay_with_avoid_helper(
        se3_params,
        ref_points,
        fit_points,
        alpha,
        fit_points_for_avoid,
        avoid_points,
        avoid_min_dist,
        avoid_weight,
    )
    mean_score = per_pose_scores.mean()
    return 1 - mean_score
def _quats_from_fibo_jax(num_samples: int):
    """
    Build a deterministic, roughly uniform set of rotations as quaternions.
    Jax implementation.

    Each quaternion rotates the unit x-vector (1, 0, 0) onto one of the points
    produced by Fibonacci (golden spiral) sampling of the sphere surface.

    Parameters
    ----------
    num_samples : int
        Number of rotations to generate.

    Returns
    -------
    quaternions : Array (num_samples, 4)
        Quaternion for each rotation.
    """
    targets = _get_points_fibonacci_jax(num_samples)
    # One copy of the x unit vector per target so the batched helpers line up row by row.
    x_axis = jnp.tile(jnp.array([1., 0., 0.]), (num_samples, 1))
    rot_angles = vmap_angle_between_vecs_jax(x_axis, targets)
    rot_axes = rotation_axis_jax(x_axis, targets)
    return vmap_quaternion_from_axis_angle_jax(rot_axes, rot_angles)
def _get_45_fibo_jax() -> Array:
    """
    Return a hard-coded table of 45 rotation quaternions.

    Documented as the precomputed values for se3_params_from_fibo(45), i.e.
    presumably what the Fibonacci quaternion sampler yields for 45 samples
    (TODO confirm against ``_quats_from_fibo_jax(45)``).
    ``_initialize_se3_params_jax`` uses this table when ``num_repeats == 50``
    (5 fixed guesses + 45 sampled rotations) instead of recomputing it.

    Returns
    -------
    Array (45,4)
        Quaternions, one per row. The second component is (+/-)0 in every row.
    """
    return jnp.array([[ 0.6501596 , 0. , -0.10890594, -0.7519521 ],
                      [ 0.71811795, 0. , 0.24900949, -0.64984685],
                      [ 0.79960614, 0. , -0.22734107, -0.5558292 ],
                      [ 0.48607868, 0. , 0.09597147, -0.8686294 ],
                      [ 0.8678287 , 0. , 0.18554172, -0.46092048],
                      [ 0.6441806 , 0. , -0.49103084, -0.58644706],
                      [ 0.58135426, 0. , 0.53663224, -0.61159873],
                      [ 0.9219894 , 0. , -0.13865991, -0.36153716],
                      [ 0.37174237, 0. , -0.4017539 , -0.8368999 ],
                      [ 0.82034767, 0. , 0.4505742 , -0.3521542 ],
                      [ 0.7915699 , 0. , -0.5098301 , -0.3368833 ],
                      [ 0.35016882, 0. , 0.62455714, -0.69807595],
                      [ 0.9682232 , 0. , 0.0993299 , -0.22951545],
                      [ 0.48625368, 0. , -0.7709624 , -0.41130796],
                      [ 0.6632823 , 0. , 0.69872594, -0.26802734],
                      [ 0.92916685, 0. , -0.3295777 , -0.16741402],
                      [ 0.13607754, 0. , -0.1463197 , -0.97983336],
                      [ 0.9195395 , 0. , 0.37396038, -0.12083343],
                      [ 0.6908489 , 0. , -0.71145827, -0.12866619],
                      [ 0.427207 , 0. , 0.89058506, -0.15605238],
                      [ 0.9967814 , 0. , -0.06662399, -0.04458794],
                      [ 0.2999607 , 0. , -0.95107055, -0.07408379],
                      [ 0.78085893, -0. , 0.6247074 , 0. ],
                      [ 0.8650692 , 0. , -0.5009943 , 0.02568838],
                      [ 0.15980992, -0. , 0.9471624 , 0.2781082 ],
                      [ 0.9745988 , -0. , 0.21325576, 0.06840423],
                      [ 0.5568162 , 0. , -0.8151512 , 0.15963776],
                      [ 0.57879627, -0. , 0.79255456, 0.19196929],
                      [ 0.962584 , 0. , -0.23290652, 0.13851605],
                      [ 0.20126757, 0. , -0.60178804, 0.7728793 ],
                      [ 0.86761075, -0. , 0.45306247, 0.20490502],
                      [ 0.7600118 , 0. , -0.5942492 , 0.2631538 ],
                      [ 0.3819389 , -0. , 0.71805334, 0.5818266 ],
                      [ 0.96679044, -0. , 0.03724971, 0.2528412 ],
                      [ 0.46128264, 0. , -0.67306805, 0.57809824],
                      [ 0.7085131 , -0. , 0.575984 , 0.40773967],
                      [ 0.8734927 , 0. , -0.33189464, 0.3561691 ],
                      [ 0.35904366, -0. , 0.09578869, 0.92839223],
                      [ 0.8831887 , -0. , 0.24063599, 0.40258166],
                      [ 0.6643608 , 0. , -0.48505744, 0.56863344],
                      [ 0.58328235, -0. , 0.43531075, 0.6857742 ],
                      [ 0.8708025 , 0. , -0.08129112, 0.4848656 ],
                      [ 0.5442492 , 0. , -0.19216032, 0.81661934],
                      [ 0.74993277, -0. , 0.2244347 , 0.622278 ],
                      [ 0.77770305, 0. , 0. , 0.6286318 ]])
def _initialize_se3_params_jax(ref_points: Array,
                               fit_points: Array,
                               num_repeats: int = 50
                               ) -> Array:
    """
    Build the initial SE(3) parameter guesses. Jax implementation.
    SLOWER THAN TORCH.

    Each row is a quaternion (first four values) followed by a translation
    (last three values). The rows are laid out as follows:

    * Row 0: identity rotation with zero translation (local optimization from
      the input pose; no centroid alignment).
    * Rows 1-4: principal-component alignments with ``ref_points``, translated
      so the rotated ``fit_points`` centroid lands on the ``ref_points`` centroid.
    * Rows 5+: rotations from Fibonacci sampling of a sphere, with the same
      centroid-matching translation.

    Parameters
    ----------
    ref_points : Array (N,3)
        Reference points.
    fit_points : Array (M,3)
        Points the SE(3) transformations will be applied to.
    num_repeats : int (default=50)
        Number of initial guesses. Values below 5 are raised to 5 so that the
        identity and the four principal-component guesses are always present.

    Returns
    -------
    se3_params : Array (max(num_repeats, 5), 7)
        Initial guesses for the SE(3) transformation parameters.
    """
    ref_centroid = ref_points.mean(0)
    fit_centroid = fit_points.mean(0)

    # The identity guess plus the four PCA guesses are always generated.
    num_repeats = max(num_repeats, 5)

    guesses = jnp.zeros((num_repeats, 7))

    # Row 0: identity quaternion; translation intentionally left at zero.
    guesses = guesses.at[0, :4].set(jnp.array([1.0, 0.0, 0.0, 0.0]))

    # Rows 1-4: principal-component alignments.
    pca_quats = quaternions_for_principal_component_alignment_jax(ref_points, fit_points)
    guesses = guesses.at[1:5, :4].set(jnp.array(pca_quats))
    # Translations in these rows are still zero, so these matrices are pure rotations.
    pca_rotations = vmap_get_SE3_transform_jax(guesses.at[1:5].get())
    rotated_centroid = vmap_apply_SE3_transform_jax(fit_centroid, pca_rotations).squeeze()
    # Translate so the rotated fit centroid coincides with the reference centroid.
    guesses = guesses.at[1:5, 4:].set(- rotated_centroid + ref_centroid)

    # Rows 5+: sampled rotations.
    if num_repeats > 5:
        if num_repeats == 50:
            # Precomputed table for the default of 45 sampled rotations.
            sampled_quats = _get_45_fibo_jax()
        else:
            sampled_quats = _quats_from_fibo_jax(num_repeats - 5)
        guesses = guesses.at[5:, :4].set(sampled_quats)

        sampled_rotations = vmap_get_SE3_transform_jax(guesses.at[5:].get())
        rotated_centroid = vmap_apply_SE3_transform_jax(fit_centroid, sampled_rotations).squeeze()
        guesses = guesses.at[5:, 4:].set(- rotated_centroid + ref_centroid)

    return guesses
# TRIED TO REPLACE PYTORCH VERSION BUT NO REAL SPEEDUP
# def _quats_from_fibo_np(num_samples: int):
# """
# Computes the quaternions corresponding to a uniform distribution (deterministic) of
# rotations. Does this by finding out the quaternions necessary to rotate a unit vector
# to points sampled on a sphere from the golden spiral method or Fibonacci sphere surface
# sampling.
# Parameters
# ----------
# num_samples : int
# Number of rotations to generate.
# Returns
# -------
# quaternions : torch.Tensor (num_samples, 4)
# quaternions corresponding to each rotation.
# """
# fibo = _get_points_fibonacci(num_samples)
# unit_v = np.array([1., 0., 0.])
# quaternions = np.zeros((num_samples, 4))
# for i in range(num_samples):
# angles = angle_between_vecs_np(unit_v, fibo[i])
# axes = rotation_axis_np(unit_v, fibo[i])
# quaternions[i] = quaternion_from_axis_angle_np(axes, angles)
# return quaternions
# def _initialize_se3_params_with_translations_np(ref_points: np.ndarray,
# fit_points: np.ndarray,
# trans_centers: np.ndarray,
# num_repeats_per_trans: int = 10
# ) -> np.ndarray:
# """
# Slower than Torch so use Torch version. Scales linearly with num_repeats_per_trans.
# """
# # Initial guess for SE(3) parameters (quaternion followed by translation)
# ref_points_com = ref_points.mean(0)
# fit_points_com = fit_points.mean(0)
# num_repeats = num_repeats_per_trans * trans_centers.shape[0] + 5
# # First guess keeps the original orientation but aligns the COMs
# se3_params = np.zeros((num_repeats, 7))
# se3_params[0, :4] = np.array([1.0, 0.0, 0.0, 0.0])
# se3_params[0, 4:] = -fit_points_com + ref_points_com
# pca_quats = quaternions_for_principal_component_alignment_np(ref_points, fit_points)
# se3_params[1:5, :4] = pca_quats # rotation component for centered points
# fit_points_com = fit_points_com.reshape(1,-1)
# for i in range(1,5):
# SE3_rotation = get_SE3_transform_np(se3_params[i]) # only rotation
# # Rotate translation to COM in original coordinates
# T = apply_SE3_transform_np(fit_points_com, SE3_rotation).squeeze()
# # Apply translation to center COMs by taking into account implicit translation done in PCA
# se3_params[i, 4:] = - T + ref_points_com
# # Do random rotations
# quats = _quats_from_fibo_np(num_repeats_per_trans)
# quats = quats / np.linalg.norm(_quats_from_fibo_np(10), 2, 1, keepdims=True)
# se3_params[5:, :4] = np.tile(quats, (trans_centers.shape[0], 1))
# # Construct SE(3) transformation matrix for rotations
# SE3_rotation = np.eye(4)
# T = np.zeros((num_repeats_per_trans, 3))
# for i in range(num_repeats_per_trans):
# SE3_rotation[:3, :3] = quaternions_to_rotation_matrix_np(quats[i])
# # Adjust translation to COM with the corresponding rotations
# T[i] = apply_SE3_transform_np(fit_points_com, SE3_rotation)
# T = np.tile(T, (trans_centers.shape[0], 1))
# # translation to atoms
# trans_centers_rep = np.repeat(trans_centers, num_repeats_per_trans, 0)
# # Apply translation to center COMs by taking into account implicit translation done with rotations
# se3_params[5:, 4:] = - T + trans_centers_rep
# return se3_params
# Compiled (loss, gradient) evaluators. value_and_grad differentiates with respect
# to the first argument, i.e. the batch of SE(3) parameters.
jit_val_grad_obj_ROCS = jax.jit(jax.value_and_grad(objective_ROCS_overlay_jax))
jit_val_grad_obj_ROCS_with_avoid = jax.jit(
    jax.value_and_grad(objective_ROCS_overlay_with_avoid_jax)
)
def _objective_ROCS_overlay_precomputed_jax(se3_params, ref_points, fit_points, alpha, VAA, VBB):
    """
    Single-pose ROCS objective that reuses precomputed self-overlaps.

    ``VAA`` and ``VBB`` are the self-overlap terms of the reference and fit
    points; they are supplied by the caller so only the cross term ``VAB`` is
    recomputed for each pose. Returns ``VAB / (VAA + VBB - VAB)``.
    """
    moved_fit = apply_SE3_transform_jax(fit_points, get_SE3_transform_jax(se3_params))
    cross_overlap = VAB_2nd_order_jax(ref_points, moved_fit, alpha)
    return cross_overlap / (VAA + VBB - cross_overlap)


# Batched over the leading axis of se3_params only.
batched_obj_ROCS_overlay_precomputed = vmap(
    _objective_ROCS_overlay_precomputed_jax,
    in_axes=(0, None, None, None, None, None),
)
[docs]
def objective_ROCS_overlay_precomputed_jax(se3_params, ref, fit, alpha, VAA, VBB):
    """
    Loss for a batch of poses using precomputed self-overlaps.

    Parameters
    ----------
    se3_params : Array (batch, 7)
        Batch of SE(3) parameters (quaternion followed by translation).
    ref, fit : Array
        Reference and fit points, shared by every pose.
    alpha : float
        Gaussian width parameter.
    VAA, VBB
        Precomputed self-overlaps of ``ref`` and ``fit``.

    Returns
    -------
    1 - (mean score over the batch).
    """
    per_pose_scores = batched_obj_ROCS_overlay_precomputed(se3_params, ref, fit, alpha, VAA, VBB)
    mean_score = per_pose_scores.mean()
    return 1 - mean_score


# Compiled (loss, gradient w.r.t. se3_params) evaluator.
jit_val_grad_obj_ROCS_precomputed = jax.jit(
    jax.value_and_grad(objective_ROCS_overlay_precomputed_jax)
)
def _objective_ROCS_overlay_precomputed_jax_mask(se3_params, ref_points, fit_points,
                                                 mask_ref, mask_fit, alpha, VAA, VBB):
    """
    Single-pose masked ROCS objective that reuses precomputed self-overlaps.

    Same as the unmasked precomputed objective, except the cross term is
    computed with ``VAB_2nd_order_jax_mask`` using ``mask_ref`` and ``mask_fit``.
    Returns ``VAB / (VAA + VBB - VAB)``.
    """
    moved_fit = apply_SE3_transform_jax(fit_points, get_SE3_transform_jax(se3_params))
    cross_overlap = VAB_2nd_order_jax_mask(ref_points, moved_fit, mask_ref, mask_fit, alpha)
    return cross_overlap / (VAA + VBB - cross_overlap)


# Batched over the leading axis of se3_params only.
batched_obj_ROCS_overlay_precomputed_mask = vmap(
    _objective_ROCS_overlay_precomputed_jax_mask,
    in_axes=(0, None, None, None, None, None, None, None),
)
[docs]
def objective_ROCS_overlay_precomputed_jax_mask(se3_params, ref, fit,
                                                mask_ref, mask_fit, alpha, VAA, VBB):
    """
    Loss for a batch of poses using masks and precomputed self-overlaps.

    Parameters
    ----------
    se3_params : Array (batch, 7)
        Batch of SE(3) parameters (quaternion followed by translation).
    ref, fit : Array
        Reference and fit points, shared by every pose.
    mask_ref, mask_fit : Array
        Masks passed through to the masked cross-overlap computation.
    alpha : float
        Gaussian width parameter.
    VAA, VBB
        Precomputed self-overlaps of ``ref`` and ``fit``.

    Returns
    -------
    1 - (mean score over the batch).
    """
    per_pose_scores = batched_obj_ROCS_overlay_precomputed_mask(
        se3_params, ref, fit, mask_ref, mask_fit, alpha, VAA, VBB)
    mean_score = per_pose_scores.mean()
    return 1 - mean_score


# Compiled (loss, gradient w.r.t. se3_params) evaluator.
jit_val_grad_obj_ROCS_precomputed_mask = jax.jit(
    jax.value_and_grad(objective_ROCS_overlay_precomputed_jax_mask)
)

# Masked shape overlap where only the second argument carries a batch axis.
vmap_get_overlap_jax_mask = vmap(get_overlap_jax_mask, in_axes=(None, 0, None, None, None))
def _objective_ROCS_esp_overlay_precomputed_jax(se3_params, ref_points, fit_points,
                                                ref_charges, fit_charges, alpha, lam, VAA, VBB):
    """
    Single-pose, unmasked ROCS ESP objective that reuses precomputed self-overlaps.

    Only the cross term is recomputed for the pose, via ``VAB_2nd_order_esp_jax``.
    Returns ``VAB / (VAA + VBB - VAB)``.
    """
    moved_fit = apply_SE3_transform_jax(fit_points, get_SE3_transform_jax(se3_params))
    cross_overlap = VAB_2nd_order_esp_jax(
        ref_points, moved_fit, ref_charges, fit_charges, alpha, lam)
    return cross_overlap / (VAA + VBB - cross_overlap)


# Batched over the leading axis of se3_params only.
batched_obj_ROCS_esp_overlay_precomputed = vmap(
    _objective_ROCS_esp_overlay_precomputed_jax,
    in_axes=(0, None, None, None, None, None, None, None, None),
)
[docs]
def objective_ROCS_esp_overlay_precomputed_jax(se3_params, ref, fit,
                                               ref_charges, fit_charges, alpha, lam, VAA, VBB):
    """
    Loss for a batch of poses under the ESP-weighted overlap with precomputed self-overlaps.

    Parameters
    ----------
    se3_params : Array (batch, 7)
        Batch of SE(3) parameters (quaternion followed by translation).
    ref, fit : Array
        Reference and fit points, shared by every pose.
    ref_charges, fit_charges : Array
        Charges associated with ``ref`` and ``fit``.
    alpha : float
        Gaussian width parameter.
    lam : float
        Charge weighting parameter.
    VAA, VBB
        Precomputed self-overlaps of the reference and the fit.

    Returns
    -------
    1 - (mean score over the batch).
    """
    per_pose_scores = batched_obj_ROCS_esp_overlay_precomputed(
        se3_params, ref, fit, ref_charges, fit_charges, alpha, lam, VAA, VBB)
    mean_score = per_pose_scores.mean()
    return 1 - mean_score


# Compiled (loss, gradient w.r.t. se3_params) evaluator.
jit_val_grad_obj_ROCS_esp_precomputed = jax.jit(
    jax.value_and_grad(objective_ROCS_esp_overlay_precomputed_jax)
)
def _objective_ROCS_esp_overlay_precomputed_jax_mask(se3_params, ref_points, fit_points,
                                                     ref_charges, fit_charges,
                                                     mask_ref, mask_fit, alpha, lam, VAA, VBB):
    """
    Single-pose, masked ROCS ESP objective that reuses precomputed self-overlaps.

    Only the cross term is recomputed for the pose, via
    ``VAB_2nd_order_esp_jax_mask``. Returns ``VAB / (VAA + VBB - VAB)``.
    """
    moved_fit = apply_SE3_transform_jax(fit_points, get_SE3_transform_jax(se3_params))
    cross_overlap = VAB_2nd_order_esp_jax_mask(
        ref_points, moved_fit, ref_charges, fit_charges,
        mask_ref, mask_fit, alpha, lam)
    return cross_overlap / (VAA + VBB - cross_overlap)


# Batched over the leading axis of se3_params only.
batched_obj_ROCS_esp_overlay_precomputed_mask = vmap(
    _objective_ROCS_esp_overlay_precomputed_jax_mask,
    in_axes=(0, None, None, None, None, None, None, None, None, None, None),
)
[docs]
def objective_ROCS_esp_overlay_precomputed_jax_mask(se3_params, ref, fit,
                                                    ref_charges, fit_charges,
                                                    mask_ref, mask_fit, alpha, lam, VAA, VBB):
    """
    Loss for a batch of poses under the masked ESP-weighted overlap.

    Parameters
    ----------
    se3_params : Array (batch, 7)
        Batch of SE(3) parameters (quaternion followed by translation).
    ref, fit : Array
        Reference and fit points, shared by every pose.
    ref_charges, fit_charges : Array
        Charges associated with ``ref`` and ``fit``.
    mask_ref, mask_fit : Array
        Masks passed through to the masked cross-overlap computation.
    alpha : float
        Gaussian width parameter.
    lam : float
        Charge weighting parameter.
    VAA, VBB
        Precomputed self-overlaps of the reference and the fit.

    Returns
    -------
    1 - (mean score over the batch).
    """
    per_pose_scores = batched_obj_ROCS_esp_overlay_precomputed_mask(
        se3_params, ref, fit, ref_charges, fit_charges, mask_ref, mask_fit, alpha, lam, VAA, VBB)
    mean_score = per_pose_scores.mean()
    return 1 - mean_score


# Compiled (loss, gradient w.r.t. se3_params) evaluator.
jit_val_grad_obj_ROCS_esp_precomputed_mask = jax.jit(
    jax.value_and_grad(objective_ROCS_esp_overlay_precomputed_jax_mask)
)

# Masked ESP overlap where only the second argument (the aligned fit points) is batched.
vmap_get_overlap_esp_jax_mask = vmap(
    get_overlap_esp_jax_mask,
    in_axes=(None, 0, None, None, None, None, None, None),
)
[docs]
def _unpadded_points_np(points, mask) -> np.ndarray:
    """Return the rows of ``points`` whose ``mask`` entry is non-zero, as a NumPy array."""
    points_np = np.array(points)
    keep = np.array(mask).astype(bool)
    # Degenerate all-padding input: keep every row rather than hand an empty
    # point cloud to the initializer.
    if not keep.any():
        return points_np
    return points_np[keep]


def optimize_ROCS_esp_overlay_jax_mask(ref_points: Array,
                                       fit_points: Array,
                                       ref_charges: Array,
                                       fit_charges: Array,
                                       mask_ref: Array,
                                       mask_fit: Array,
                                       alpha: float,
                                       lam: float,
                                       *,
                                       num_repeats: int = 50,
                                       trans_centers: Union[Array, np.ndarray, None] = None,
                                       lr: float = 0.1,
                                       max_num_steps: int = 200,
                                       verbose: bool = False
                                       ) -> Tuple[Array, Array, Array]:
    """
    Optimize alignment of fit_points with respect to ref_points using SE(3) transformations and
    maximizing masked electrostatic-weighted Gaussian overlap score.

    Identical to ``optimize_ROCS_esp_overlay_jax`` but accepts binary mask arrays so that padded
    (zero) entries are excluded from the overlap computation. Padding all arrays to a common
    maximum shape and passing masks allows JAX's XLA compiler to reuse a single compiled
    function across all pairs in a batch, avoiding recompilation overhead.

    The initial SE(3) guesses are computed from the unmasked (real) rows only, so that
    zero-padding does not shift the centroids used by the initializer. The number of
    guesses does not depend on the number of points, so this does not change any
    compiled shapes.

    Parameters
    ----------
    ref_points : Array (N, 3)
        Reference points (may include zero-padding beyond mask_ref).
    fit_points : Array (M, 3)
        Fit points (may include zero-padding beyond mask_fit).
    ref_charges : Array (N,)
        Charges for reference points.
    fit_charges : Array (M,)
        Charges for fit points.
    mask_ref : Array (N,)
        Binary mask: 1 for real atoms, 0 for padding.
    mask_fit : Array (M,)
        Binary mask: 1 for real atoms, 0 for padding.
    alpha : float
        Gaussian width parameter.
    lam : float
        Charge weighting parameter.
    num_repeats : int (default=50)
        Number of initial guesses when ``trans_centers`` is None.
    trans_centers : array (P, 3) (default=None)
        If given, initial guesses are generated around these translation centers with
        10 rotations per center, and ``num_repeats`` is not used.
    lr : float (default=0.1)
        Learning rate of the optimizer.
    max_num_steps : int (default=200)
        Maximum number of optimization steps.
    verbose : bool (default=False)
        Print the score before optimization and the max/mean score after.

    Returns
    -------
    tuple
        aligned_points : Array (M, 3)
            ``fit_points`` (padding rows included) under the best-scoring transformation.
        SE3_transform : Array (4, 4)
            Best-scoring SE(3) transformation matrix.
        score : Array (scalar)
            Masked ESP overlap score of the best transformation.
    """
    # Only real atoms inform the initial guesses; padding rows would otherwise
    # pull the centroids toward the origin.
    ref_real = torch.Tensor(_unpadded_points_np(ref_points, mask_ref))
    fit_real = torch.Tensor(_unpadded_points_np(fit_points, mask_fit))

    if trans_centers is None:
        se3_params = _initialize_se3_params(ref_points=ref_real,
                                            fit_points=fit_real,
                                            num_repeats=num_repeats).detach()
        if num_repeats == 1:
            # A single guess comes back without a batch axis.
            se3_params = se3_params.unsqueeze(0)
    else:
        se3_params = _initialize_se3_params_with_translations(
            ref_points=ref_real,
            fit_points=fit_real,
            trans_centers=torch.Tensor(np.array(trans_centers)),
            num_repeats_per_trans=10).detach()
        if len(se3_params.shape) == 1:
            se3_params = se3_params.unsqueeze(0)
    se3_params = jnp.array(se3_params)

    # Reshape charges to (-1, 1) for jax_sq_cdist inside VAB_2nd_order_esp_jax_mask
    ref_charges_col = jnp.reshape(ref_charges, (-1, 1))
    fit_charges_col = jnp.reshape(fit_charges, (-1, 1))

    # Self-overlaps do not depend on the pose, so compute them once up front.
    VAA = VAB_2nd_order_esp_jax_mask(ref_points, ref_points, ref_charges_col, ref_charges_col,
                                     mask_ref, mask_ref, alpha, lam)
    VBB = VAB_2nd_order_esp_jax_mask(fit_points, fit_points, fit_charges_col, fit_charges_col,
                                     mask_fit, mask_fit, alpha, lam)

    if verbose:
        print(f'Initial score: {get_overlap_esp_jax_mask(ref_points, fit_points, ref_charges, fit_charges, mask_ref, mask_fit, alpha, lam):.3f}')

    data_args = (ref_points, fit_points, ref_charges_col, fit_charges_col,
                 mask_ref, mask_fit, alpha, lam, VAA, VBB)
    se3_opt = _generic_optimize_loop(
        se3_params,
        data_args,
        jit_val_grad_obj_ROCS_esp_precomputed_mask,
        lr,
        max_num_steps
    )

    # Score every optimized guess and keep the best one.
    SE3_transform = vmap_get_SE3_transform_jax(se3_opt)
    aligned_points = vmap_apply_SE3_transform_jax(fit_points, SE3_transform)
    scores = vmap_get_overlap_esp_jax_mask(ref_points, aligned_points,
                                           ref_charges, fit_charges,
                                           mask_ref, mask_fit, alpha, lam)
    if verbose:
        print(f'Optimized score max: {scores.max():.3f} | mean: {scores.mean():.3f}')

    best_idx = jnp.argmax(scores)
    return (aligned_points.at[best_idx].get(),
            SE3_transform.at[best_idx].get(),
            scores.at[best_idx].get())
def _score_ROCS_overlay_with_avoid_precomputed_jax(ref_points, fit_points, alpha,
                                                   VAA, VBB,
                                                   fit_points_for_avoid, avoid_points,
                                                   avoid_min_dist, avoid_weight):
    """
    Shape score from precomputed self-overlaps minus a weighted avoid penalty.

    The shape term is ``VAB / (VAA + VBB - VAB)``, where only ``VAB`` is computed here.
    No transformation is applied: both sets of fit points are scored as given.
    """
    cross_overlap = VAB_2nd_order_jax(ref_points, fit_points, alpha)
    shape_score = cross_overlap / (VAA + VBB - cross_overlap)
    penalty = get_linear_hard_sphere_overlap_jax(avoid_points, fit_points_for_avoid, avoid_min_dist)
    return shape_score - avoid_weight * penalty
def _objective_ROCS_overlay_with_avoid_precomputed_jax(se3_params, ref_points, fit_points, alpha,
                                                       VAA, VBB,
                                                       fit_points_for_avoid, avoid_points,
                                                       avoid_min_dist, avoid_weight):
    """
    Single-pose avoid-aware objective that reuses precomputed self-overlaps.

    Applies the same SE(3) transformation to ``fit_points`` and
    ``fit_points_for_avoid``, then delegates to
    ``_score_ROCS_overlay_with_avoid_precomputed_jax``.
    """
    pose = get_SE3_transform_jax(se3_params)
    moved_fit = apply_SE3_transform_jax(fit_points, pose)
    moved_fit_for_avoid = apply_SE3_transform_jax(fit_points_for_avoid, pose)
    return _score_ROCS_overlay_with_avoid_precomputed_jax(
        ref_points, moved_fit, alpha, VAA, VBB,
        moved_fit_for_avoid, avoid_points, avoid_min_dist, avoid_weight)


# Batched over the leading axis of se3_params only.
batched_obj_ROCS_overlay_with_avoid_precomputed = vmap(
    _objective_ROCS_overlay_with_avoid_precomputed_jax,
    in_axes=(0, None, None, None, None, None, None, None, None, None),
)
[docs]
def objective_ROCS_overlay_with_avoid_precomputed_jax(se3_params, ref, fit, alpha,
                                                      VAA, VBB,
                                                      fit_for_avoid, avoid, min_dist, weight):
    """
    Loss for a batch of poses with an avoid penalty and precomputed self-overlaps.

    Parameters
    ----------
    se3_params : Array (batch, 7)
        Batch of SE(3) parameters (quaternion followed by translation).
    ref, fit : Array
        Reference and fit points, shared by every pose.
    alpha : float
        Gaussian width parameter.
    VAA, VBB
        Precomputed self-overlaps of ``ref`` and ``fit``.
    fit_for_avoid : Array
        Points transformed with each pose and compared with ``avoid``.
    avoid : Array
        Points whose overlap is penalized.
    min_dist : float
        Minimum distance with no penalization.
    weight : float
        Weight of the avoid term.

    Returns
    -------
    1 - (mean score over the batch).
    """
    per_pose_scores = batched_obj_ROCS_overlay_with_avoid_precomputed(
        se3_params, ref, fit, alpha, VAA, VBB, fit_for_avoid, avoid, min_dist, weight)
    mean_score = per_pose_scores.mean()
    return 1 - mean_score


# Compiled (loss, gradient w.r.t. se3_params) evaluator.
jit_val_grad_obj_ROCS_with_avoid_precomputed = jax.jit(
    jax.value_and_grad(objective_ROCS_overlay_with_avoid_precomputed_jax)
)
@partial(jax.jit, static_argnames=['val_and_grad_fn', 'max_num_steps'])
def _generic_optimize_loop(
    se3_params: jnp.ndarray,
    data_args: Tuple,
    val_and_grad_fn: Callable,
    lr: float,
    max_num_steps: int
):
    """
    Run Adam on a batch of SE(3) parameters with early stopping.

    The loop ends after ``max_num_steps`` steps, or earlier once the loss has
    changed by at most 1e-5 on 11 consecutive steps.

    Parameters
    ----------
    se3_params: Array (num_repeats, 7)
        Initial SE3 parameters.
    data_args: Tuple
        All remaining arguments (points, charges, etc.) required by
        ``val_and_grad_fn``, in call order.
    val_and_grad_fn: Callable
        Function returning ``(loss, grads)``; called as
        ``val_and_grad_fn(params, *data_args)``. Static under jit.
    lr: float
        Learning rate.
    max_num_steps: int
        Maximum number of steps. Static under jit.

    Returns
    -------
    se3_params_opt: Array (num_repeats, 7)
        Optimized SE3 parameters.
    """
    adam = optax.adam(learning_rate=lr)

    def keep_going(carry):
        stalled_steps, step = carry[3], carry[4]
        return (step < max_num_steps) & (stalled_steps <= 10)

    def adam_step(carry):
        params, adam_state, prev_loss, stalled_steps, step = carry
        loss, grads = val_and_grad_fn(params, *data_args)
        updates, adam_state = adam.update(grads, adam_state, params)
        params = optax.apply_updates(params, updates)
        # Any meaningful change in the loss resets the stall counter.
        loss_moved = jnp.abs(loss - prev_loss) > 1e-5
        stalled_steps = jnp.where(loss_moved, 0, stalled_steps + 1)
        return (params, adam_state, loss, stalled_steps, step + 1)

    # Carry: (params, optimizer state, previous loss, stall counter, step index).
    start = (se3_params, adam.init(se3_params), jnp.array(1.0), 0, 0)
    return lax.while_loop(keep_going, adam_step, start)[0]
[docs]
def optimize_ROCS_overlay_jax(ref_points: Array,
                              fit_points: Array,
                              alpha: float,
                              *,
                              fit_points_for_avoid: Optional[Array] = None,
                              avoid_points: Optional[Array] = None,
                              avoid_min_dist: float = 2.0,
                              avoid_weight: float = 1.0,
                              num_repeats: int = 50,
                              trans_centers: Union[Array, np.ndarray, None] = None,
                              lr: float = 0.1,
                              max_num_steps: int = 200,
                              verbose: bool = False
                              ) -> Tuple[Array, Array, Array]:
    """
    Optimize alignment of fit_points with respect to ref_points using SE(3) transformations and
    maximizing gaussian overlap score.

    Initial guesses come from the Torch initializers (faster than the Jax version in this
    module). As documented for those initializers: if num_repeats is 1, the initial guess is
    an identity rotation with aligned COMs; if num_repeats is 5 or greater, four of the
    initial guesses are aligned using principal components.

    Parameters
    ----------
    ref_points : Array (N,3)
        Reference points.
    fit_points : Array (M,3)
        Set of points to apply SE(3) transformations to maximize shape similarity with
        ref_points.
    alpha : float
        Gaussian width parameter used in scoring function.
    fit_points_for_avoid : Array (M,3) (default=None)
        Set of points to apply SE(3) transformations to then compare to avoid_points.
        If None, ``fit_points`` is used.
    avoid_points : Array (K,3) (default=None)
        If not None, these are points that are used in an additional term in the objective
        function to penalize overlap with these points.
    avoid_min_dist : float (default=2.0)
        Minimum distance with no penalization between fit_points_for_avoid and avoid_points.
    avoid_weight : float (default=1.0)
        Weight for the avoid_points term in the scoring function.
    num_repeats : int (default=50)
        Number of different initializations of SE(3) transformation parameters.
        Only used when ``trans_centers`` is None.
    trans_centers : array (P, 3) (default=None)
        Locations to translate fit_points' center of mass as initial guesses for optimization.
        At each translation center, 10 rotations are also sampled. So the number of
        initializations scales as (# translation centers * 10 + 5) where 5 is from the
        identity and 4 PCA with aligned COM's.
        If None, then num_repeats rotations are done with aligned COM's.
    lr : float (default=0.1)
        Learning rate or step-size for optimization.
    max_num_steps : int (default=200)
        Maximum number of steps to optimize over.
    verbose : bool (False)
        Print the similarity score before optimization and the max/mean score after.

    Returns
    -------
    tuple
        aligned_points : Array (M,3)
            The transformed point cloud for fit_points using the optimized SE(3)
            transformation for alignment with ref_points.
        SE3_transform : Array (4,4)
            Optimized SE(3) transformation matrix used to obtain aligned_points from
            fit_points.
        score : Array (scalar)
            Score of the optimal transformation: the Tanimoto shape similarity, minus the
            weighted avoid penalty when ``avoid_points`` is given.
    """
    # The Torch initializers expect tensors; convert once and share between branches.
    ref_tensor = torch.Tensor(np.array(ref_points))
    fit_tensor = torch.Tensor(np.array(fit_points))
    if trans_centers is None:
        se3_params = _initialize_se3_params(ref_points=ref_tensor,
                                            fit_points=fit_tensor,
                                            num_repeats=num_repeats).detach()
        if num_repeats == 1:
            # A single guess comes back without a batch axis.
            se3_params = se3_params.unsqueeze(0)
    else:
        se3_params = _initialize_se3_params_with_translations(
            ref_points=ref_tensor,
            fit_points=fit_tensor,
            trans_centers=torch.Tensor(np.array(trans_centers)),
            num_repeats_per_trans=10).detach()
        if len(se3_params.shape) == 1:
            se3_params = se3_params.unsqueeze(0)
    se3_params = jnp.array(se3_params)

    if fit_points_for_avoid is None:
        fit_points_for_avoid = fit_points

    if verbose:
        print(f'Initial score: {get_overlap_jax(ref_points, fit_points, alpha):.3f}')

    use_avoid = avoid_points is not None

    # Self-overlaps do not depend on the pose, so compute them once for either objective.
    VAA = VAB_2nd_order_jax(ref_points, ref_points, alpha)
    VBB = VAB_2nd_order_jax(fit_points, fit_points, alpha)
    if use_avoid:
        grad_fn = jit_val_grad_obj_ROCS_with_avoid_precomputed
        data_args = (ref_points, fit_points, alpha, VAA, VBB,
                     fit_points_for_avoid, avoid_points, avoid_min_dist, avoid_weight)
    else:
        grad_fn = jit_val_grad_obj_ROCS_precomputed
        data_args = (ref_points, fit_points, alpha, VAA, VBB)
    se3_opt = _generic_optimize_loop(
        se3_params,
        data_args,
        grad_fn,
        lr,
        max_num_steps
    )

    # Score every optimized guess and keep the best one.
    SE3_transform = vmap_get_SE3_transform_jax(se3_opt)
    aligned_points = vmap_apply_SE3_transform_jax(fit_points, SE3_transform)
    if use_avoid:
        aligned_points_for_avoid = vmap_apply_SE3_transform_jax(fit_points_for_avoid, SE3_transform)
        scores = vmap_score_ROCS_overlay_with_avoid_jax(
            ref_points, aligned_points, alpha,
            aligned_points_for_avoid, avoid_points, avoid_min_dist, avoid_weight)
    else:
        scores = vmap_get_overlap_jax(ref_points, aligned_points, alpha)

    if verbose:
        print(f'Optimized score max: {scores.max():.3f} | mean: {scores.mean():.3f}')

    best_idx = jnp.argmax(scores)
    return (aligned_points.at[best_idx].get(),
            SE3_transform.at[best_idx].get(),
            scores.at[best_idx].get())
[docs]
def optimize_ROCS_overlay_jax_mask(ref_points: Array,
                                   fit_points: Array,
                                   mask_ref: Array,
                                   mask_fit: Array,
                                   alpha: float,
                                   *,
                                   num_repeats: int = 50,
                                   trans_centers: Union[Array, np.ndarray, None] = None,
                                   lr: float = 0.1,
                                   max_num_steps: int = 200,
                                   verbose: bool = False
                                   ) -> Tuple[Array, Array, Array]:
    """
    Align ``fit_points`` onto ``ref_points`` by maximizing the masked Gaussian overlap score.

    Masked counterpart of ``optimize_ROCS_overlay_jax``: the binary masks mark which rows of
    the (possibly zero-padded) point arrays are real. Padding every pair to a common shape lets
    a single compiled function be reused for the whole batch.

    Parameters
    ----------
    ref_points : Array (N, 3)
        Reference points (may contain padding rows).
    fit_points : Array (M, 3)
        Points to be transformed (may contain padding rows).
    mask_ref : Array (N,)
        Binary mask: 1 for real atoms, 0 for padding.
    mask_fit : Array (M,)
        Binary mask: 1 for real atoms, 0 for padding.
    alpha : float
        Gaussian width parameter.
    num_repeats : int (default=50)
        Number of initial SE(3) guesses requested when ``trans_centers`` is None.
    trans_centers : array (P, 3) (default=None)
        Optional translation centers for the initial guesses; 10 repeats are requested per
        center (``num_repeats_per_trans=10``).
    lr : float (default=0.1)
        Learning rate forwarded to the optimization loop.
    max_num_steps : int (default=200)
        Step budget forwarded to the optimization loop.
    verbose : bool (default=False)
        Print the score before optimization and the max/mean score afterwards.

    Returns
    -------
    tuple
        aligned_points : Array (M, 3)
            ``fit_points`` under the best-scoring transformation.
        SE3_transform : Array (4, 4)
            The best-scoring transformation matrix.
        score : Array
            Masked overlap score of that transformation.
    """
    # Initial poses come from the torch helpers and are moved to JAX afterwards.
    ref_torch = torch.Tensor(np.array(ref_points))
    fit_torch = torch.Tensor(np.array(fit_points))
    if trans_centers is not None:
        init_params = _initialize_se3_params_with_translations(
            ref_points=ref_torch,
            fit_points=fit_torch,
            trans_centers=torch.Tensor(np.array(trans_centers)),
            num_repeats_per_trans=10).detach()
        if init_params.ndim == 1:
            init_params = init_params.unsqueeze(0)
    else:
        init_params = _initialize_se3_params(ref_points=ref_torch,
                                             fit_points=fit_torch,
                                             num_repeats=num_repeats).detach()
        if num_repeats == 1:
            init_params = init_params.unsqueeze(0)
    init_params = jnp.array(init_params)

    # Self-overlaps only depend on the untransformed inputs, so they are computed once.
    self_overlap_ref = VAB_2nd_order_jax_mask(ref_points, ref_points, mask_ref, mask_ref, alpha)
    self_overlap_fit = VAB_2nd_order_jax_mask(fit_points, fit_points, mask_fit, mask_fit, alpha)

    if verbose:
        print(f'Initial score: {get_overlap_jax_mask(ref_points, fit_points, mask_ref, mask_fit, alpha):.3f}')

    optimized_params = _generic_optimize_loop(
        init_params,
        (ref_points, fit_points, mask_ref, mask_fit, alpha, self_overlap_ref, self_overlap_fit),
        jit_val_grad_obj_ROCS_precomputed_mask,
        lr,
        max_num_steps
    )

    # Re-score every optimized pose and keep the best one.
    transforms = vmap_get_SE3_transform_jax(optimized_params)
    candidates = vmap_apply_SE3_transform_jax(fit_points, transforms)
    scores = vmap_get_overlap_jax_mask(ref_points, candidates, mask_ref, mask_fit, alpha)
    if verbose:
        print(f'Optimized score max: {scores.max():.3f} | mean: {scores.mean():.3f}')

    winner = jnp.argmax(scores)
    return (candidates.at[winner].get(),
            transforms.at[winner].get(),
            scores.at[winner].get())
def _per_pair_optimize_vol_mask_scan(
    ref_pts, fit_pts, mask_ref, mask_fit, se3_init,
    alpha, VAA, VBB, lr, max_num_steps,
):
    """Optimize one masked volumetric pair with a fixed-length ``lax.scan`` of Adam steps.

    Exactly ``max_num_steps`` updates are run (no early stopping), after which every
    optimized pose is re-scored and the best one is returned.

    Parameters
    ----------
    ref_pts : (N, 3) padded reference atom positions
    fit_pts : (M, 3) padded fit atom positions
    mask_ref : (N,) binary mask (1 = real atom, 0 = padding)
    mask_fit : (M,) binary mask
    se3_init : (R, 7) pre-initialised SE(3) parameters
    alpha : scalar Gaussian width
    VAA, VBB : scalar pre-computed self-overlaps
    lr : scalar Adam learning rate
    max_num_steps : Python int used as the scan length

    Returns
    -------
    aligned_pts : (M, 3)
    se3_transform : (4, 4)
    score : scalar
    """
    adam = optax.adam(learning_rate=lr)
    loss_and_grad = value_and_grad(objective_ROCS_overlay_precomputed_jax_mask)

    def _adam_step(carry, _unused):
        pose_params, adam_state = carry
        # Only the gradient is needed; the loss value is discarded.
        _, grads = loss_and_grad(
            pose_params, ref_pts, fit_pts, mask_ref, mask_fit, alpha, VAA, VBB
        )
        updates, adam_state = adam.update(grads, adam_state, pose_params)
        return (optax.apply_updates(pose_params, updates), adam_state), None

    final_carry, _ = lax.scan(
        _adam_step, (se3_init, adam.init(se3_init)), None, length=max_num_steps
    )
    optimized_params = final_carry[0]

    transforms = vmap_get_SE3_transform_jax(optimized_params)
    candidates = vmap_apply_SE3_transform_jax(fit_pts, transforms)
    scores = vmap_get_overlap_jax_mask(
        ref_pts, candidates, mask_ref, mask_fit, alpha
    )
    winner = jnp.argmax(scores)
    return candidates[winner], transforms[winner], scores[winner]
def _per_pair_optimize_vol_esp_mask_scan(
    ref_pts, fit_pts, ref_charges, fit_charges,
    mask_ref, mask_fit, se3_init,
    alpha, lam, VAA, VBB, lr, max_num_steps,
):
    """Optimize one masked volumetric ESP pair with a fixed-length ``lax.scan`` of Adam steps.

    Exactly ``max_num_steps`` updates are run (no early stopping), after which every
    optimized pose is re-scored and the best one is returned.

    Parameters
    ----------
    ref_pts : (N, 3) padded reference atom positions
    fit_pts : (M, 3) padded fit atom positions
    ref_charges : (N, 1) padded reference charges (column-shaped)
    fit_charges : (M, 1) padded fit charges
    mask_ref : (N,) binary mask (1 = real atom, 0 = padding)
    mask_fit : (M,) binary mask
    se3_init : (R, 7) pre-initialised SE(3) parameters
    alpha : scalar Gaussian width
    lam : scalar charge-weighting parameter
    VAA, VBB : scalar pre-computed ESP self-overlaps
    lr : scalar Adam learning rate
    max_num_steps : Python int used as the scan length

    Returns
    -------
    aligned_pts : (M, 3)
    se3_transform : (4, 4)
    score : scalar
    """
    adam = optax.adam(learning_rate=lr)
    loss_and_grad = value_and_grad(objective_ROCS_esp_overlay_precomputed_jax_mask)

    def _adam_step(carry, _unused):
        pose_params, adam_state = carry
        # Only the gradient is needed; the loss value is discarded.
        _, grads = loss_and_grad(
            pose_params, ref_pts, fit_pts, ref_charges, fit_charges,
            mask_ref, mask_fit, alpha, lam, VAA, VBB
        )
        updates, adam_state = adam.update(grads, adam_state, pose_params)
        return (optax.apply_updates(pose_params, updates), adam_state), None

    final_carry, _ = lax.scan(
        _adam_step, (se3_init, adam.init(se3_init)), None, length=max_num_steps
    )
    optimized_params = final_carry[0]

    transforms = vmap_get_SE3_transform_jax(optimized_params)
    candidates = vmap_apply_SE3_transform_jax(fit_pts, transforms)
    # The final scorer is given the charges with their trailing axis indexed away.
    ref_charges_flat = ref_charges[..., 0]
    fit_charges_flat = fit_charges[..., 0]
    scores = vmap_get_overlap_esp_jax_mask(
        ref_pts, candidates, ref_charges_flat, fit_charges_flat,
        mask_ref, mask_fit, alpha, lam
    )
    winner = jnp.argmax(scores)
    return candidates[winner], transforms[winner], scores[winner]
def _per_pair_optimize_surf_scan(
    ref_pts, fit_pts, se3_init,
    alpha, VAA, VBB, lr, max_num_steps,
):
    """Optimize one unmasked surface pair with a fixed-length ``lax.scan`` of Adam steps.

    Exactly ``max_num_steps`` updates are run (no early stopping), after which every
    optimized pose is re-scored and the best one is returned.

    Parameters
    ----------
    ref_pts : (N, 3) reference surface points (uniform size, no padding)
    fit_pts : (M, 3) fit surface points
    se3_init : (R, 7) pre-initialised SE(3) parameters
    alpha : scalar Gaussian width
    VAA, VBB : scalar pre-computed self-overlaps
    lr : scalar Adam learning rate
    max_num_steps : Python int used as the scan length

    Returns
    -------
    aligned_pts : (M, 3)
    se3_transform : (4, 4)
    score : scalar
    """
    adam = optax.adam(learning_rate=lr)
    loss_and_grad = value_and_grad(objective_ROCS_overlay_precomputed_jax)

    def _adam_step(carry, _unused):
        pose_params, adam_state = carry
        # Only the gradient is needed; the loss value is discarded.
        _, grads = loss_and_grad(pose_params, ref_pts, fit_pts, alpha, VAA, VBB)
        updates, adam_state = adam.update(grads, adam_state, pose_params)
        return (optax.apply_updates(pose_params, updates), adam_state), None

    final_carry, _ = lax.scan(
        _adam_step, (se3_init, adam.init(se3_init)), None, length=max_num_steps
    )
    optimized_params = final_carry[0]

    transforms = vmap_get_SE3_transform_jax(optimized_params)
    candidates = vmap_apply_SE3_transform_jax(fit_pts, transforms)
    scores = vmap_get_overlap_jax(ref_pts, candidates, alpha)
    winner = jnp.argmax(scores)
    return candidates[winner], transforms[winner], scores[winner]
def _per_pair_optimize_surf_esp_scan(
    ref_pts, fit_pts, ref_charges, fit_charges, se3_init,
    alpha, lam, VAA, VBB, lr, max_num_steps,
):
    """Optimize one unmasked surface ESP pair with a fixed-length ``lax.scan`` of Adam steps.

    Exactly ``max_num_steps`` updates are run (no early stopping), after which every
    optimized pose is re-scored and the best one is returned.

    Parameters
    ----------
    ref_pts : (N, 3) reference surface points (uniform size, no padding)
    fit_pts : (M, 3) fit surface points
    ref_charges : (N,) reference surface ESP values
    fit_charges : (M,) fit surface ESP values
    se3_init : (R, 7) pre-initialised SE(3) parameters
    alpha : scalar Gaussian width
    lam : scalar charge-weighting parameter
    VAA, VBB : scalar pre-computed ESP self-overlaps
    lr : scalar Adam learning rate
    max_num_steps : Python int used as the scan length

    Returns
    -------
    aligned_pts : (M, 3)
    se3_transform : (4, 4)
    score : scalar
    """
    adam = optax.adam(learning_rate=lr)
    loss_and_grad = value_and_grad(objective_ROCS_esp_overlay_precomputed_jax)

    def _adam_step(carry, _unused):
        pose_params, adam_state = carry
        # Only the gradient is needed; the loss value is discarded.
        _, grads = loss_and_grad(
            pose_params, ref_pts, fit_pts, ref_charges, fit_charges, alpha, lam, VAA, VBB
        )
        updates, adam_state = adam.update(grads, adam_state, pose_params)
        return (optax.apply_updates(pose_params, updates), adam_state), None

    final_carry, _ = lax.scan(
        _adam_step, (se3_init, adam.init(se3_init)), None, length=max_num_steps
    )
    optimized_params = final_carry[0]

    transforms = vmap_get_SE3_transform_jax(optimized_params)
    candidates = vmap_apply_SE3_transform_jax(fit_pts, transforms)
    scores = vmap_get_overlap_esp_jax(ref_pts, candidates, ref_charges, fit_charges, alpha, lam)
    winner = jnp.argmax(scores)
    return candidates[winner], transforms[winner], scores[winner]
@lru_cache(maxsize=8)
def _per_pair_optimize_pharm_mask_scan_factory(similarity: str,
                                               extended_points: bool,
                                               only_extended: bool):
    """Build a ``lax.scan``-based per-pair pharmacophore optimizer with the static options baked in.

    The result is memoized with ``lru_cache`` (up to 8 entries), so each
    (similarity, extended_points, only_extended) combination builds its closure once.

    Returns
    -------
    Callable with signature:
        _per_pair(ref_pharms, fit_pharms, ref_anchors, fit_anchors,
                  ref_vectors, fit_vectors, mask_ref, mask_fit,
                  se3_init, ref_self_score, fit_self_score, lr, max_num_steps)
        -> (aligned_ancs (M,3), aligned_vecs (M,3), se3_transform (4,4), score scalar)
    """
    static_kwargs = dict(extended_points=extended_points, only_extended=only_extended)
    eps = 1e-6

    # Raw overlap scorer, batched over the transformed fit anchors and vectors only.
    batched_overlap = vmap(
        partial(get_overlap_pharm_jax_vectorized_mask, **static_kwargs),
        in_axes=(None, None, None, 0, None, 0, None, None)
    )

    def _loss(se3_params, *pair_data):
        # pair_data = (ref_pharms, fit_pharms, ref_anchors, fit_anchors,
        #              ref_vectors, fit_vectors, mask_ref, mask_fit,
        #              ref_self_score, fit_self_score)
        return _loss_fn_pharm_vectorized_mask(
            se3_params, *pair_data, similarity=similarity, **static_kwargs
        )

    loss_and_grad = value_and_grad(_loss)

    def _normalise(vab, ref_self_score, fit_self_score):
        # `similarity` is a Python string, so this branch is resolved while tracing.
        if similarity == 'tanimoto':
            return vab / (ref_self_score + fit_self_score - vab + eps)
        if similarity == 'tversky_ref':
            return vab / (ref_self_score + eps)
        if similarity == 'tversky_fit':
            return vab / (fit_self_score + eps)
        return vab

    def _per_pair(ref_pharms, fit_pharms,
                  ref_anchors, fit_anchors,
                  ref_vectors, fit_vectors,
                  mask_ref, mask_fit,
                  se3_init, ref_self_score, fit_self_score,
                  lr, max_num_steps):
        adam = optax.adam(learning_rate=lr)

        def _adam_step(carry, _unused):
            pose_params, adam_state = carry
            # Only the gradient is needed; the loss value is discarded.
            _, grads = loss_and_grad(
                pose_params,
                ref_pharms, fit_pharms,
                ref_anchors, fit_anchors,
                ref_vectors, fit_vectors,
                mask_ref, mask_fit,
                ref_self_score, fit_self_score,
            )
            updates, adam_state = adam.update(grads, adam_state, pose_params)
            return (optax.apply_updates(pose_params, updates), adam_state), None

        final_carry, _ = lax.scan(
            _adam_step, (se3_init, adam.init(se3_init)), None, length=max_num_steps
        )
        optimized_params = final_carry[0]

        transforms = vmap_get_SE3_transform_jax(optimized_params)
        anchors_all = vmap_apply_SE3_transform_jax(fit_anchors, transforms)
        vectors_all = vmap_apply_SO3_transform_jax(fit_vectors, transforms)
        raw_overlaps = batched_overlap(ref_pharms, fit_pharms, ref_anchors, anchors_all,
                                       ref_vectors, vectors_all, mask_ref, mask_fit)
        final_scores = _normalise(raw_overlaps, ref_self_score, fit_self_score)

        winner = jnp.argmax(final_scores)
        return (anchors_all[winner], vectors_all[winner],
                transforms[winner], final_scores[winner])

    return _per_pair
def _objective_ROCS_esp_overlay_jax(se3_params: Array,
                                    ref_points: Array,
                                    fit_points: Array,
                                    ref_charges: Array,
                                    fit_charges: Array,
                                    alpha: float,
                                    lam: float
                                    ) -> Array:
    """
    Score a single SE(3) pose with the ESP-weighted overlap.
    Jax implementation.

    Parameters
    ----------
    se3_params : Array (7,)
        Parameters for SE(3) transformation: a quaternion of form (r,i,j,k) followed by the
        translation in (x,y,z).
    ref_points : Array (N,3)
        Reference points.
    fit_points : Array (M,3)
        Points that the transformation is applied to before scoring.
    ref_charges : Array
        Charges / potentials associated with ref_points.
    fit_charges : Array
        Charges / potentials associated with fit_points.
    alpha : float
        Gaussian width parameter used in scoring function.
    lam : float
        Charge scaling term forwarded to the ESP scoring function.

    Returns
    -------
    score : Array
        Value of ``get_overlap_esp_jax`` for the transformed fit points. This is the score
        itself; it is turned into a loss by the batched wrapper, not here.
    """
    moved_fit = apply_SE3_transform_jax(fit_points, get_SE3_transform_jax(se3_params))
    return get_overlap_esp_jax(ref_points, moved_fit, ref_charges, fit_charges, alpha, lam)
# Batch the single-pose ESP objective over the leading axis of `se3_params` only;
# the six data arguments are shared by every pose.
batched_obj_ROCS_esp_overlay_helper = vmap(
    _objective_ROCS_esp_overlay_jax,
    in_axes=(0,) + (None,) * 6,
)
[docs]
def objective_ROCS_esp_overlay_jax(se3_params: Array,
                                   ref_points: Array,
                                   fit_points: Array,
                                   ref_charges: Array,
                                   fit_charges: Array,
                                   alpha: float,
                                   lam: float
                                   ) -> Array:
    """
    Batched loss for the ROCS ESP overlay: one minus the mean score over all poses.

    Parameters
    ----------
    se3_params : Array (batch, 7)
        Parameters for SE(3) transformation.
        The first 4 values in the last dimension are quaternions of form (r,i,j,k)
        and the last 3 values of the last dimension are the translations in (x,y,z).
    ref_points : Array (N,3)
        Reference points.
    fit_points : Array (M,3)
        Set of points to apply SE(3) transformations to.
    ref_charges : Array (N,)
        Electric potential at the corresponding ref_points coordinates.
    fit_charges : Array (M,)
        Electric potential at the corresponding fit_points coordinates.
    alpha : float
        Gaussian width parameter used in scoring function.
    lam : float
        Scaling term for charges used in the exponential kernel of the ESP scoring function.

    Returns
    -------
    loss : Array
        1 - mean(ESP score) over the batch of poses.
    """
    per_pose_scores = batched_obj_ROCS_esp_overlay_helper(
        se3_params, ref_points, fit_points, ref_charges, fit_charges, alpha, lam
    )
    mean_score = per_pose_scores.mean()
    return 1 - mean_score
# Jit-compiled evaluator returning (loss, gradient) of `objective_ROCS_esp_overlay_jax`.
# `value_and_grad` is used with its default argnums, so the gradient is taken with
# respect to the first argument (`se3_params`).
jit_val_grad_obj_ROCS_esp = jit(value_and_grad(objective_ROCS_esp_overlay_jax))
[docs]
def optimize_ROCS_esp_overlay_jax(ref_points: Array,
                                  fit_points: Array,
                                  ref_charges: Array,
                                  fit_charges: Array,
                                  alpha: float,
                                  lam: float,
                                  num_repeats: int = 50,
                                  trans_centers: Union[Array, np.ndarray, None] = None,
                                  lr: float = 0.1,
                                  max_num_steps: int = 200,
                                  verbose: bool = False) -> Tuple[Array, Array, Array]:
    """
    Optimize alignment of fit_points with respect to ref_points using SE(3) transformations and
    maximizing electrostatic-weighted gaussian overlap score.

    Parameters
    ----------
    ref_points : Array (N,3)
        Reference points.
    fit_points : Array (M,3)
        Set of points to apply SE(3) transformations to maximize similarity with ref_points.
    ref_charges : Array (N,)
        Electric potential at the corresponding ref_points coordinates.
    fit_charges : Array (M,)
        Electric potential at the corresponding fit_points coordinates.
    alpha : float
        Gaussian width parameter used in scoring function.
    lam : float
        Scaling term for charges used in the exponential kernel of the ESP scoring function.
    num_repeats : int (default=50)
        Number of different initializations of SE(3) transformation parameters. Only used when
        ``trans_centers`` is None.
    trans_centers : array (P, 3) (default=None)
        Locations to translate fit_points' center of mass as initial guesses for optimization.
        At each translation center, 10 rotations are also sampled. So the number of
        initializations scales as (# translation centers * 10 + 5) where 5 is from the identity
        and 4 PCA with aligned COM's.
        If None, then num_repeats rotations are done with aligned COM's.
    lr : float (default=0.1)
        Learning rate or step-size for optimization.
    max_num_steps : int (default=200)
        Maximum number of steps to optimize over.
    verbose : bool (default=False)
        Print the similarity score before optimization and the max / mean score over all
        optimized poses afterwards.

    Returns
    -------
    tuple
        aligned_points : Array (M,3)
            The transformed point cloud for fit_points using the optimized SE(3) transformation
            for alignment with ref_points.
        SE3_transform : Array (4,4)
            Optimized SE(3) transformation matrix used to obtain aligned_points from fit_points.
        score : Array
            Shape+ESP similarity score for the optimal transformation.
    """
    # Initial guess for SE(3) parameters (quaternion followed by translation).
    # The torch initializers are used because they are faster than the JAX version.
    if trans_centers is None:
        se3_params = _initialize_se3_params(ref_points=torch.Tensor(np.array(ref_points)),
                                            fit_points=torch.Tensor(np.array(fit_points)),
                                            num_repeats=num_repeats).detach()
        if num_repeats == 1:
            se3_params = se3_params.unsqueeze(0)
    else:
        se3_params = _initialize_se3_params_with_translations(
            ref_points=torch.Tensor(np.array(ref_points)),
            fit_points=torch.Tensor(np.array(fit_points)),
            trans_centers=torch.Tensor(np.array(trans_centers)),
            num_repeats_per_trans=10).detach()
        if len(se3_params.shape) == 1:
            se3_params = se3_params.unsqueeze(0)
    se3_params = jnp.array(se3_params)

    if verbose:
        # `verbose` used to be accepted but ignored; report like the sibling optimizers do.
        init_score = get_overlap_esp_jax(ref_points, fit_points, ref_charges, fit_charges, alpha, lam)
        print(f'Initial shape+ESP similarity score: {init_score:.3f}')

    data_args = (ref_points, fit_points, ref_charges, fit_charges, alpha, lam)
    se3_opt = _generic_optimize_loop(
        se3_params,
        data_args,
        jit_val_grad_obj_ROCS_esp,
        lr,
        max_num_steps
    )

    # Re-score every optimized pose and keep the best one.
    SE3_transform = vmap_get_SE3_transform_jax(se3_opt)
    aligned_points = vmap_apply_SE3_transform_jax(fit_points, SE3_transform)
    scores = vmap_get_overlap_esp_jax(ref_points, aligned_points, ref_charges, fit_charges, alpha, lam)
    if verbose:
        print(f'Optimized shape+ESP similarity score -- max: {scores.max():.3f} | mean: {scores.mean():.3f}')

    best_idx = jnp.argmax(scores)
    return (aligned_points.at[best_idx].get(),
            SE3_transform.at[best_idx].get(),
            scores.at[best_idx].get())
def _objective_esp_combo_score_overlay_jax(se3_params,
                                           ref_centers_w_H,
                                           fit_centers_w_H,
                                           ref_centers,
                                           fit_centers,
                                           ref_points,
                                           fit_points,
                                           ref_partial_charges,
                                           fit_partial_charges,
                                           ref_surf_esp,
                                           fit_surf_esp,
                                           ref_radii,
                                           fit_radii,
                                           alpha,
                                           lam,
                                           probe_radii=1.0,
                                           esp_weight=0.5) -> Array:
    """
    Score a single SE(3) pose with the ESP combo score.

    The same transformation built from ``se3_params`` is applied to all three fit coordinate
    sets (centers with hydrogens, centers, surface points); the reference arrays and all
    per-point properties are passed through untouched. Returns the score, not a loss.
    """
    pose = get_SE3_transform_jax(se3_params)
    moved_centers_w_H, moved_centers, moved_points = (
        apply_SE3_transform_jax(coords, pose)
        for coords in (fit_centers_w_H, fit_centers, fit_points)
    )
    return esp_combo_score_jax(ref_centers_w_H,
                               moved_centers_w_H,
                               ref_centers,
                               moved_centers,
                               ref_points,
                               moved_points,
                               ref_partial_charges,
                               fit_partial_charges,
                               ref_surf_esp,
                               fit_surf_esp,
                               ref_radii,
                               fit_radii,
                               alpha,
                               lam,
                               probe_radii,
                               esp_weight)
# Batch the single-pose ESP combo objective over the leading axis of `se3_params` only;
# the remaining 16 arguments are shared by every pose.
batched_obj_esp_combo_score_helper = vmap(
    _objective_esp_combo_score_overlay_jax,
    in_axes=(0,) + (None,) * 16,
)
[docs]
def objective_esp_combo_score_overlay_jax(se3_params,
                                          ref_centers_w_H,
                                          fit_centers_w_H,
                                          ref_centers,
                                          fit_centers,
                                          ref_points,
                                          fit_points,
                                          ref_partial_charges,
                                          fit_partial_charges,
                                          ref_surf_esp,
                                          fit_surf_esp,
                                          ref_radii,
                                          fit_radii,
                                          alpha,
                                          lam,
                                          probe_radii=1.0,
                                          esp_weight=0.5) -> Array:
    """
    Batched ESP combo loss: evaluate the score for every pose in ``se3_params`` and return
    one minus the mean score.
    """
    # Everything except the poses is shared across the batch, in the helper's argument order.
    shared_args = (
        ref_centers_w_H, fit_centers_w_H,
        ref_centers, fit_centers,
        ref_points, fit_points,
        ref_partial_charges, fit_partial_charges,
        ref_surf_esp, fit_surf_esp,
        ref_radii, fit_radii,
        alpha, lam, probe_radii, esp_weight,
    )
    per_pose_scores = batched_obj_esp_combo_score_helper(se3_params, *shared_args)
    mean_score = per_pose_scores.mean()
    return 1 - mean_score
# Jit-compiled evaluator returning (loss, gradient) of `objective_esp_combo_score_overlay_jax`.
# `value_and_grad` is used with its default argnums, so the gradient is taken with
# respect to the first argument (`se3_params`).
jit_val_grad_obj_esp_combo_score_overlay = jit(value_and_grad(objective_esp_combo_score_overlay_jax))
[docs]
def convert_to_jnp_array(arr):
    """
    Return ``arr`` as a JAX array.

    Objects that already are ``jax.Array`` instances are returned unchanged (same object);
    anything else is converted with ``jnp.array``.
    """
    return arr if isinstance(arr, Array) else jnp.array(arr)
[docs]
def optimize_esp_combo_score_overlay_jax(ref_centers_w_H: Union[Array, np.ndarray],
                                         fit_centers_w_H: Union[Array, np.ndarray],
                                         ref_centers: Union[Array, np.ndarray],
                                         fit_centers: Union[Array, np.ndarray],
                                         ref_points: Union[Array, np.ndarray],
                                         fit_points: Union[Array, np.ndarray],
                                         ref_partial_charges: Union[Array, np.ndarray, List],
                                         fit_partial_charges: Union[Array, np.ndarray, List],
                                         ref_surf_esp: Union[Array, np.ndarray],
                                         fit_surf_esp: Union[Array, np.ndarray],
                                         ref_radii: Union[Array, np.ndarray, List],
                                         fit_radii: Union[Array, np.ndarray, List],
                                         alpha: float,
                                         lam: float,
                                         probe_radius: float = 1.0,
                                         esp_weight: float = 0.5,
                                         num_repeats: int = 50,
                                         trans_centers: Union[Array, np.ndarray, None] = None,
                                         lr: float = 0.1,
                                         max_num_steps: int = 200,
                                         verbose: bool = False) -> Tuple[Array, Array, Array]:
    """
    Optimize alignment of fit_points with respect to ref_points using SE(3) transformations and
    maximizing the ShaEP-inspired ESP combo score.

    Parameters
    ----------
    ref_centers_w_H : Array (N + n_H, 3)
        Coordinates of atom centers INCLUDING hydrogens of reference molecule.
        Used for computing electrostatic potential.
        Same for fit_centers_w_H except (M + m_H, 3).
    ref_centers : Array (N, 3) or (n_surf, 3)
        Coordinates of points for reference molecule used to compute shape similarity.
        Use atom centers for volumetric similarity. Use surface centers for surface similarity.
        Same for fit_centers except (M, 3) or (m_surf, 3).
    ref_points : Array (n_surf, 3)
        Coordinates of surface points for reference molecule.
        Same for fit_points except (m_surf, 3).
    ref_partial_charges : Array (N + n_H,)
        Partial charges corresponding to the atoms in ref_centers_w_H.
        Same for fit_partial_charges except (M + m_H,).
    ref_surf_esp : Array (n_surf,)
        The electrostatic potential calculated at each surface point (ref_points).
        Same for fit_surf_esp except (m_surf,).
    ref_radii : Array (N + n_H,)
        vdW radii corresponding to the atoms in ref_centers_w_H (angstroms).
        Same for fit_radii except (M + m_H,).
    alpha : float
        Gaussian width parameter used in shape similarity scoring function.
    lam : float
        Electrostatic potential weighting parameter (smaller = higher weight).
        Required argument; it has no default in this function.
    probe_radius : float (default = 1.0)
        Surface points found within vdW radii + probe radius will be masked out. Surface
        generation uses a probe radius of 1.2 (radius of hydrogen) so a slightly lower radius is
        used to be more tolerant.
    esp_weight : float (default = 0.5)
        Weight to be placed on electrostatic similarity with respect to shape similarity.
        0 = only shape similarity
        1 = only electrostatic similarity
    num_repeats : int (default=50)
        Number of different initializations of SE(3) transformation parameters. Only used when
        ``trans_centers`` is None.
    trans_centers : array (P, 3) (default=None)
        Locations to translate fit_points' center of mass as initial guesses for optimization.
        At each translation center, 10 rotations are also sampled. So the number of
        initializations scales as (# translation centers * 10 + 5) where 5 is from the identity
        and 4 PCA with aligned COM's.
        If None, then num_repeats rotations are done with aligned COM's.
    lr : float (default=0.1)
        Learning rate or step-size for optimization.
    max_num_steps : int (default=200)
        Maximum number of steps to optimize over.
    verbose : bool (default=False)
        Print the similarity score before optimization and the max / mean score over all
        optimized poses afterwards.

    Returns
    -------
    tuple
        aligned_points : Array (m_surf, 3)
            fit_points (surface points) transformed by the best-scoring SE(3) transformation.
        SE3_transform : Array (4,4)
            Optimized SE(3) transformation matrix used to obtain aligned_points from fit_points.
        score : Array
            ESP combo similarity score for the optimal transformation.
    """
    # Initial guess for SE(3) parameters (quaternion followed by translation), seeded from the
    # surface point clouds. The torch initializers are used because they are faster than the
    # JAX version.
    if trans_centers is None:
        se3_params = _initialize_se3_params(ref_points=torch.Tensor(np.array(ref_points)),
                                            fit_points=torch.Tensor(np.array(fit_points)),
                                            num_repeats=num_repeats).detach()
        if num_repeats == 1:
            se3_params = se3_params.unsqueeze(0)
    else:
        se3_params = _initialize_se3_params_with_translations(
            ref_points=torch.Tensor(np.array(ref_points)),
            fit_points=torch.Tensor(np.array(fit_points)),
            trans_centers=torch.Tensor(np.array(trans_centers)),
            num_repeats_per_trans=10).detach()
        if len(se3_params.shape) == 1:
            se3_params = se3_params.unsqueeze(0)
    se3_params = jnp.array(se3_params)

    # Make sure every array-like input is a JAX array before it reaches jitted code.
    (ref_centers_w_H, fit_centers_w_H,
     ref_centers, fit_centers,
     ref_points, fit_points,
     ref_partial_charges, fit_partial_charges,
     ref_surf_esp, fit_surf_esp,
     ref_radii, fit_radii) = map(convert_to_jnp_array, (
         ref_centers_w_H, fit_centers_w_H,
         ref_centers, fit_centers,
         ref_points, fit_points,
         ref_partial_charges, fit_partial_charges,
         ref_surf_esp, fit_surf_esp,
         ref_radii, fit_radii,
     ))

    # Arguments in the exact positional order of `objective_esp_combo_score_overlay_jax`
    # (everything after `se3_params`).
    data_args = (
        ref_centers_w_H, fit_centers_w_H,
        ref_centers, fit_centers,
        ref_points, fit_points,
        ref_partial_charges, fit_partial_charges,
        ref_surf_esp, fit_surf_esp,
        ref_radii, fit_radii,
        alpha, lam, probe_radius, esp_weight
    )

    if verbose:
        init_score = esp_combo_score_jax(*data_args)
        print(f'Initial ShaEP-inspired similarity score: {init_score:.3f}')

    se3_opt = _generic_optimize_loop(
        se3_params,
        data_args,
        jit_val_grad_obj_esp_combo_score_overlay,
        lr,
        max_num_steps
    )

    SE3_transform = vmap_get_SE3_transform_jax(se3_opt)

    # Apply each optimized transformation to all fit coordinate sets.
    aligned_points = vmap_apply_SE3_transform_jax(fit_points, SE3_transform)
    aligned_centers_w_H = vmap_apply_SE3_transform_jax(fit_centers_w_H, SE3_transform)
    aligned_centers = vmap_apply_SE3_transform_jax(fit_centers, SE3_transform)

    # Re-score every optimized pose so the best one can be selected.
    scores = vmap_esp_combo_score(
        ref_centers_w_H, aligned_centers_w_H,
        ref_centers, aligned_centers,
        ref_points, aligned_points,
        ref_partial_charges, fit_partial_charges,
        ref_surf_esp, fit_surf_esp,
        ref_radii, fit_radii,
        alpha, lam, probe_radius, esp_weight
    )

    best_idx = jnp.argmax(scores)
    if verbose:
        print(f'Optimized ShaEP inspired similarity score -- max: {scores.max():.3f} | mean: {scores.mean():.3f}')
    return (aligned_points.at[best_idx].get(),
            SE3_transform.at[best_idx].get(),
            scores.at[best_idx].get())
def _objective_pharm_overlay_jax(se3_params: Array,
                                 ref_pharms: Array,
                                 fit_pharms: Array,
                                 ref_anchors: Array,
                                 fit_anchors: Array,
                                 ref_vectors: Array,
                                 fit_vectors: Array,
                                 similarity: _SIM_TYPE = 'tanimoto',
                                 extended_points: bool = False,
                                 only_extended: bool = False
                                 ) -> Array:
    """
    Score a single SE(3) pose with the pharmacophore overlap.

    The fit anchors are moved with ``apply_SE3_transform_jax`` and the fit vectors with
    ``apply_SO3_transform_jax``, both using the matrix built from ``se3_params``. The
    return value is the score from ``get_overlap_pharm_jax`` (not a loss).
    """
    pose = get_SE3_transform_jax(se3_params)
    moved_anchors = apply_SE3_transform_jax(fit_anchors, pose)
    moved_vectors = apply_SO3_transform_jax(fit_vectors, pose)
    return get_overlap_pharm_jax(ptype_1=ref_pharms,
                                 ptype_2=fit_pharms,
                                 anchors_1=ref_anchors,
                                 anchors_2=moved_anchors,
                                 vectors_1=ref_vectors,
                                 vectors_2=moved_vectors,
                                 similarity=similarity,
                                 extended_points=extended_points,
                                 only_extended=only_extended)
# Batch the single-pose pharmacophore objective over the leading axis of `se3_params` only;
# the six data arrays and the three option arguments are shared by every pose.
batched_obj_pharm_overlay_helper = vmap(
    _objective_pharm_overlay_jax,
    in_axes=(0,) + (None,) * 9,
)
[docs]
def objective_pharm_overlay_jax(se3_params: Array,
                                ref_pharms: Array,
                                fit_pharms: Array,
                                ref_anchors: Array,
                                fit_anchors: Array,
                                ref_vectors: Array,
                                fit_vectors: Array,
                                similarity: _SIM_TYPE = 'tanimoto',
                                extended_points: bool = False,
                                only_extended: bool = False
                                ) -> Array:
    """
    Batched pharmacophore overlay loss: evaluate the score for every pose in ``se3_params``
    and return one minus the mean score.
    """
    # Everything except the poses is shared across the batch, in the helper's argument order.
    shared_args = (
        ref_pharms, fit_pharms,
        ref_anchors, fit_anchors,
        ref_vectors, fit_vectors,
        similarity, extended_points, only_extended,
    )
    per_pose_scores = batched_obj_pharm_overlay_helper(se3_params, *shared_args)
    mean_score = per_pose_scores.mean()
    return 1 - mean_score
# Jit-compiled evaluator returning (loss, gradient) of `objective_pharm_overlay_jax`.
# `similarity`, `extended_points` and `only_extended` are declared static, so they are
# treated as compile-time constants rather than traced values. `value_and_grad` is used
# with its default argnums, so the gradient is taken with respect to `se3_params`.
jit_val_grad_obj_pharm_overlay = jit(value_and_grad(objective_pharm_overlay_jax), static_argnames=('similarity', 'extended_points', 'only_extended'))
[docs]
def optimize_pharm_overlay_jax(ref_pharms: Array,
                               fit_pharms: Array,
                               ref_anchors: Array,
                               fit_anchors: Array,
                               ref_vectors: Array,
                               fit_vectors: Array,
                               similarity: _SIM_TYPE = 'tanimoto',
                               extended_points: bool = False,
                               only_extended: bool = False,
                               num_repeats: int = 50,
                               trans_centers: Union[Array, np.ndarray, None] = None,
                               lr: float = 0.1,
                               max_num_steps: int = 200,
                               verbose: bool = False
                               ) -> Tuple[Array, Array, Array, Array]:
    """
    Optimize alignment of fit_anchors with respect to ref_anchors using SE(3) transformations and
    maximizing pharmacophore overlap score. JAX implementation.

    Parameters
    ----------
    ref_pharms, fit_pharms : Array
        Pharmacophore types of the reference / fit molecule.
    ref_anchors, fit_anchors : Array
        Pharmacophore anchor coordinates; the fit anchors are what gets transformed.
    ref_vectors, fit_vectors : Array
        Pharmacophore direction vectors; the fit vectors are moved with the SO(3) transform
        helper using the same matrices as the anchors.
    similarity : str (default='tanimoto')
        Similarity flavour forwarded to the scoring function.
    extended_points, only_extended : bool (default=False)
        Options forwarded to the scoring function.
    num_repeats : int (default=50)
        Number of initial SE(3) guesses requested when ``trans_centers`` is None.
    trans_centers : array (P, 3) (default=None)
        Optional translation centers for the initial guesses; 10 repeats are requested per
        center (``num_repeats_per_trans=10``).
    lr : float (default=0.1)
        Adam learning rate.
    max_num_steps : int (default=200)
        Maximum number of optimization steps.
    verbose : bool (default=False)
        Print the score before optimization and a summary of the scores afterwards.

    Returns
    -------
    tuple
        aligned_anchors : fit_anchors under the best-scoring transformation.
        aligned_vectors : fit_vectors under the best-scoring transformation.
        SE3_transform : Array (4, 4), the best-scoring transformation matrix.
        score : pharmacophore similarity score of that transformation.
    """
    if trans_centers is None:
        se3_params = _initialize_se3_params(ref_points=torch.Tensor(np.array(ref_anchors)),
                                            fit_points=torch.Tensor(np.array(fit_anchors)),
                                            num_repeats=num_repeats).detach()
        if num_repeats == 1:
            se3_params = se3_params.unsqueeze(0)
    else:
        se3_params = _initialize_se3_params_with_translations(
            ref_points=torch.Tensor(np.array(ref_anchors)),
            fit_points=torch.Tensor(np.array(fit_anchors)),
            trans_centers=torch.Tensor(np.array(trans_centers)),
            num_repeats_per_trans=10).detach()
        if len(se3_params.shape) == 1:
            # `unsqueeze` is out-of-place: the result has to be re-assigned
            # (it was previously discarded, leaving a 1-D parameter vector).
            se3_params = se3_params.unsqueeze(0)
    se3_params = jnp.array(se3_params)
    current_num_repeats = se3_params.shape[0]

    optimizer = optax.adam(learning_rate=lr)
    opt_state = optimizer.init(se3_params)

    if verbose:
        init_score = get_overlap_pharm_jax(ref_pharms, fit_pharms, ref_anchors, fit_anchors,
                                           ref_vectors, fit_vectors,
                                           similarity, extended_points, only_extended)
        print(f'Initial pharmacophore similarity score: {init_score:.3f}')

    # Early stopping: quit once the loss has changed by at most 1e-5 for more than
    # 10 consecutive steps.
    last_loss = 1
    counter = 0
    for _ in range(max_num_steps):
        loss, grads = jit_val_grad_obj_pharm_overlay(
            se3_params, ref_pharms, fit_pharms, ref_anchors, fit_anchors,
            ref_vectors, fit_vectors, similarity, extended_points, only_extended)
        updates, opt_state = optimizer.update(grads, opt_state, se3_params)
        se3_params = optax.apply_updates(se3_params, updates)

        if abs(loss - last_loss) > 1e-5:
            counter = 0
        else:
            counter += 1
        last_loss = loss
        if counter > 10:
            break

    # Re-score every optimized pose so the best one can be selected.
    SE3_transform = vmap_get_SE3_transform_jax(se3_params)
    aligned_anchors = vmap_apply_SE3_transform_jax(fit_anchors, SE3_transform)
    aligned_vectors = vmap_apply_SO3_transform_jax(fit_vectors, SE3_transform)
    scores = vmap(get_overlap_pharm_jax, (None, None, None, 0, None, 0, None, None, None))(
        ref_pharms, fit_pharms, ref_anchors, aligned_anchors, ref_vectors, aligned_vectors,
        similarity, extended_points, only_extended
    )

    # Select by index along the pose axis only. A bare `.squeeze()` in the single-pose
    # case would also drop a size-1 pharmacophore axis and change the output shape.
    best_idx = jnp.argmax(scores)
    best_alignment = aligned_anchors[best_idx]
    best_aligned_vectors = aligned_vectors[best_idx]
    best_transform = SE3_transform[best_idx]
    best_score = scores[best_idx]

    if verbose:
        if current_num_repeats == 1:
            print(f'Optimized pharmacophore similarity score: {best_score:.3f}')
        else:
            print(f'Optimized pharmacophore similarity score -- max: {scores.max():.3f} | mean: {scores.mean():.3f} | min: {scores.min():.3f}')

    return best_alignment, best_aligned_vectors, best_transform, best_score
def _loss_fn_pharm_vectorized(
    se3_params,
    ref_pharms, fit_pharms,
    ref_anchors, fit_anchors,
    ref_vectors, fit_vectors,
    ref_self_score, fit_self_score,
    similarity='tanimoto',
    extended_points=False,
    only_extended=False
):
    """
    Mean pharmacophore-overlay loss over a batch of SE(3) parameter sets.

    Each row of ``se3_params`` is converted to a transform, the fit anchors are
    moved with ``apply_SE3_transform_jax`` and the fit vectors with
    ``apply_SO3_transform_jax``, and every resulting pose is scored against the
    reference. The returned loss is ``1 - mean(score)`` over all poses.

    This function is intentionally left un-jitted so that callers can wrap it
    in one stable ``jit(value_and_grad(...))`` object and reuse that object.

    Parameters
    ----------
    se3_params : Array (num_repeats, 7)
        One SE(3) parameter vector per pose.
    ref_pharms, fit_pharms : Array (N,), (M,)
        Pharmacophore type indices.
    ref_anchors, fit_anchors : Array (N,3), (M,3)
        Pharmacophore anchor positions.
    ref_vectors, fit_vectors : Array (N,3), (M,3)
        Pharmacophore direction vectors.
    ref_self_score, fit_self_score : scalar Arrays
        Pre-computed self-overlaps (VAA, VBB); treated as constants here.
    similarity : str
        ``'tanimoto'``, ``'tversky_ref'`` or ``'tversky_fit'`` select the
        normalization. Any other value leaves the raw overlap unnormalized.
    extended_points, only_extended : bool
        Forwarded unchanged to ``get_overlap_pharm_jax_vectorized``.

    Returns
    -------
    Scalar Array
        ``1 - mean(similarity over poses)``.
    """
    transforms = vmap(get_SE3_transform_jax)(se3_params)
    moved_anchors = vmap(apply_SE3_transform_jax, in_axes=(None, 0))(fit_anchors, transforms)
    moved_vectors = vmap(apply_SO3_transform_jax, in_axes=(None, 0))(fit_vectors, transforms)

    def _overlap_for_pose(pose_anchors, pose_vectors):
        # Reference data and flags are identical for every pose, so only the
        # transformed fit anchors/vectors are mapped over.
        return get_overlap_pharm_jax_vectorized(
            ref_pharms, fit_pharms,
            ref_anchors, pose_anchors,
            ref_vectors, pose_vectors,
            extended_points=extended_points,
            only_extended=only_extended,
        )

    overlaps = vmap(_overlap_for_pose)(moved_anchors, moved_vectors)

    # `similarity` is a plain Python string, so this branch is resolved at
    # trace time; eps guards against a zero denominator.
    eps = 1e-6
    if similarity == 'tanimoto':
        denominator = ref_self_score + fit_self_score - overlaps + eps
    elif similarity == 'tversky_ref':
        denominator = ref_self_score + eps
    elif similarity == 'tversky_fit':
        denominator = fit_self_score + eps
    else:
        denominator = None

    similarities = overlaps if denominator is None else overlaps / denominator
    return 1.0 - jnp.mean(similarities)
# Module-level jitted value-and-gradient of `_loss_fn_pharm_vectorized`.
# `value_and_grad` is called with its default `argnums`, so the gradient is
# taken with respect to the first positional argument (`se3_params`).
# The three options listed in `static_argnames` are compile-time constants for
# `jit`: every new combination of their values is traced/compiled separately.
jit_val_grad_pharm_vectorized = jit(
value_and_grad(_loss_fn_pharm_vectorized),
static_argnames=('similarity', 'extended_points', 'only_extended')
)
@lru_cache(maxsize=8)
def _make_jit_val_grad_pharm_vectorized(similarity: str,
extended_points: bool,
only_extended: bool):
"""
Return a JIT-compiled value_and_grad function with the three static
arguments baked in via closure.
lru_cache guarantees that repeated calls with the same
(similarity, extended_points, only_extended) return the *same Python
object*. _generic_optimize_loop takes val_and_grad_fn as a static argument,
so it recompiles whenever it sees a new object.
By caching here we ensure _generic_optimize_loop's JIT cache is hit on every call.
"""
def _loss(se3_params,
ref_pharms, fit_pharms,
ref_anchors, fit_anchors,
ref_vectors, fit_vectors,
ref_self_score, fit_self_score):
return _loss_fn_pharm_vectorized(
se3_params,
ref_pharms, fit_pharms,
ref_anchors, fit_anchors,
ref_vectors, fit_vectors,
ref_self_score, fit_self_score,
similarity=similarity,
extended_points=extended_points,
only_extended=only_extended,
)
return jit(value_and_grad(_loss))
[docs]
def optimize_pharm_overlay_jax_vectorized(
    ref_pharms: Array,
    fit_pharms: Array,
    ref_anchors: Array,
    fit_anchors: Array,
    ref_vectors: Array,
    fit_vectors: Array,
    similarity: str = 'tanimoto',
    extended_points: bool = False,
    only_extended: bool = False,
    num_repeats: int = 50,
    trans_centers: Union[Array, np.ndarray, None] = None,
    lr: float = 0.1,
    max_num_steps: int = 200,
    verbose: bool = False
) -> Tuple[Array, Array, Array, Array]:
    """
    Optimize pharmacophore overlay using the fully-vectorized scoring function using jax.

    A batch of initial SE(3) parameter sets is generated, optimized jointly by
    ``_generic_optimize_loop``, re-scored, and the best-scoring pose is
    returned.

    Parameters
    ----------
    ref_pharms, fit_pharms : Array (N,), (M,)
        Pharmacophore type indices of the reference and fit molecule.
    ref_anchors, fit_anchors : Array (N,3), (M,3)
        Pharmacophore anchor positions.
    ref_vectors, fit_vectors : Array (N,3), (M,3)
        Pharmacophore direction vectors.
    similarity : str
        ``'tanimoto'``, ``'tversky_ref'`` or ``'tversky_fit'`` select how the
        raw overlap is normalized. Any other value leaves the raw overlap
        unnormalized.
    extended_points, only_extended : bool
        Forwarded to the scoring function.
    num_repeats : int
        Number of initial poses requested from ``_initialize_se3_params``.
        Only used when ``trans_centers`` is None.
    trans_centers : Array, np.ndarray or None
        If given, initial poses come from
        ``_initialize_se3_params_with_translations`` with a fixed
        ``num_repeats_per_trans=10`` instead.
    lr : float
        Learning rate forwarded to ``_generic_optimize_loop``.
    max_num_steps : int
        Step limit forwarded to ``_generic_optimize_loop``.
    verbose : bool
        Print the initial and the optimized similarity scores.

    Returns
    -------
    aligned_anchors : Array (M,3)
        ``fit_anchors`` under the best-scoring transform.
    aligned_vectors : Array (M,3)
        ``fit_vectors`` under the best-scoring transform.
    SE3_transform : Array
        Best-scoring output of ``get_SE3_transform_jax`` (presumably a
        (4,4) matrix -- confirm against that helper).
    score : scalar Array
        Similarity score of the best pose.
    """
    # 1. Initialization
    ref_points = torch.Tensor(np.array(ref_anchors))
    fit_points = torch.Tensor(np.array(fit_anchors))
    if trans_centers is None:
        se3_params = _initialize_se3_params(ref_points=ref_points,
                                            fit_points=fit_points,
                                            num_repeats=num_repeats).detach()
    else:
        se3_params = _initialize_se3_params_with_translations(
            ref_points=ref_points,
            fit_points=fit_points,
            trans_centers=torch.Tensor(np.array(trans_centers)),
            num_repeats_per_trans=10).detach()
    # One rank check covers every initializer path: a lone parameter vector
    # gets a leading batch axis, an already-batched result is left untouched
    # (an unconditional unsqueeze would otherwise yield a (1, 1, 7) array).
    if len(se3_params.shape) == 1:
        se3_params = se3_params.unsqueeze(0)
    se3_params = jnp.array(se3_params)

    # 2. Pre-compute self-scores once (scalars, not differentiated)
    ref_self_score = get_overlap_pharm_jax_vectorized(
        ref_pharms, ref_pharms, ref_anchors, ref_anchors, ref_vectors, ref_vectors,
        extended_points=extended_points, only_extended=only_extended
    )
    fit_self_score = get_overlap_pharm_jax_vectorized(
        fit_pharms, fit_pharms, fit_anchors, fit_anchors, fit_vectors, fit_vectors,
        extended_points=extended_points, only_extended=only_extended
    )
    if verbose:
        init_score = get_overlap_pharm_jax(ref_pharms, fit_pharms, ref_anchors, fit_anchors,
                                           ref_vectors, fit_vectors, similarity, extended_points, only_extended)
        print(f'Initial pharmacophore similarity score: {init_score:.3f}')

    # 3. Optimization with lax while loop
    # _make_jit_val_grad_pharm_vectorized is lru_cache'd, so the same callable
    # object is reused and _generic_optimize_loop does not recompile across
    # molecule pairs.
    val_and_grad_fn = _make_jit_val_grad_pharm_vectorized(similarity, extended_points, only_extended)
    data_args = (ref_pharms, fit_pharms, ref_anchors, fit_anchors,
                 ref_vectors, fit_vectors, ref_self_score, fit_self_score)
    se3_params = _generic_optimize_loop(
        se3_params=se3_params,
        data_args=data_args,
        val_and_grad_fn=val_and_grad_fn,
        lr=lr,
        max_num_steps=max_num_steps
    )

    # 4. Final output: re-score every optimized pose individually
    SE3_transform = vmap_get_SE3_transform_jax(se3_params)
    aligned_anchors = vmap_apply_SE3_transform_jax(fit_anchors, SE3_transform)
    aligned_vectors = vmap_apply_SO3_transform_jax(fit_vectors, SE3_transform)
    vab_scores = vmap(
        partial(get_overlap_pharm_jax_vectorized,
                extended_points=extended_points,
                only_extended=only_extended),
        in_axes=(None, None, None, 0, None, 0)
    )(ref_pharms, fit_pharms, ref_anchors, aligned_anchors,
      ref_vectors, aligned_vectors)

    eps = 1e-6  # guards against a zero denominator
    if similarity == 'tanimoto':
        final_scores = vab_scores / (ref_self_score + fit_self_score - vab_scores + eps)
    elif similarity == 'tversky_ref':
        final_scores = vab_scores / (ref_self_score + eps)
    elif similarity == 'tversky_fit':
        final_scores = vab_scores / (fit_self_score + eps)
    else:
        final_scores = vab_scores

    best_idx = jnp.argmax(final_scores)
    if verbose:
        print(f'Optimized pharmacophore similarity score -- max: {final_scores.max():.3f} | mean: {final_scores.mean():.3f}')

    # Index only the leading (pose) axis. An axis-less .squeeze() would also
    # drop a genuine length-1 pharmacophore axis (M == 1), returning (3,)
    # instead of (1, 3). With a single pose, best_idx is simply 0.
    return (aligned_anchors[best_idx], aligned_vectors[best_idx],
            SE3_transform[best_idx], final_scores[best_idx])
def _loss_fn_pharm_vectorized_mask(
    se3_params,
    ref_pharms, fit_pharms,
    ref_anchors, fit_anchors,
    ref_vectors, fit_vectors,
    mask_ref, mask_fit,
    ref_self_score, fit_self_score,
    similarity='tanimoto',
    extended_points=False,
    only_extended=False
):
    """
    Mean masked pharmacophore-overlay loss over a batch of SE(3) parameter sets.

    Each row of ``se3_params`` is converted to a transform, the (padded) fit
    anchors are moved with ``apply_SE3_transform_jax`` and the fit vectors with
    ``apply_SO3_transform_jax``, and every pose is scored with
    ``get_overlap_pharm_jax_vectorized_mask``. The returned loss is
    ``1 - mean(score)`` over all poses.

    Parameters
    ----------
    se3_params : Array (num_repeats, 7)
        One SE(3) parameter vector per pose.
    ref_pharms, fit_pharms : Array (N,), (M,)
        Padded pharmacophore type indices.
    ref_anchors, fit_anchors : Array (N,3), (M,3)
        Padded anchor positions.
    ref_vectors, fit_vectors : Array (N,3), (M,3)
        Padded direction vectors.
    mask_ref : Array (N,)
        1.0 for real entries, 0.0 for padding.
    mask_fit : Array (M,)
        1.0 for real entries, 0.0 for padding.
    ref_self_score, fit_self_score : scalar Arrays
        Pre-computed self-overlaps; treated as constants here.
    similarity : str
        ``'tanimoto'``, ``'tversky_ref'`` or ``'tversky_fit'`` select the
        normalization. Any other value leaves the raw overlap unnormalized.
    extended_points, only_extended : bool
        Forwarded unchanged to the masked scoring function.

    Returns
    -------
    Scalar Array
        ``1 - mean(similarity over poses)``.
    """
    transforms = vmap(get_SE3_transform_jax)(se3_params)
    moved_anchors = vmap(apply_SE3_transform_jax, in_axes=(None, 0))(fit_anchors, transforms)
    moved_vectors = vmap(apply_SO3_transform_jax, in_axes=(None, 0))(fit_vectors, transforms)

    def _masked_overlap_for_pose(pose_anchors, pose_vectors):
        # Masks and reference data are shared by all poses; only the
        # transformed fit anchors/vectors vary along the mapped axis.
        return get_overlap_pharm_jax_vectorized_mask(
            ref_pharms, fit_pharms,
            ref_anchors, pose_anchors,
            ref_vectors, pose_vectors,
            mask_ref, mask_fit,
            extended_points=extended_points,
            only_extended=only_extended,
        )

    overlaps = vmap(_masked_overlap_for_pose)(moved_anchors, moved_vectors)

    # `similarity` is a plain Python string, so this branch is resolved at
    # trace time; eps guards against a zero denominator.
    eps = 1e-6
    if similarity == 'tanimoto':
        denominator = ref_self_score + fit_self_score - overlaps + eps
    elif similarity == 'tversky_ref':
        denominator = ref_self_score + eps
    elif similarity == 'tversky_fit':
        denominator = fit_self_score + eps
    else:
        denominator = None

    similarities = overlaps if denominator is None else overlaps / denominator
    return 1.0 - jnp.mean(similarities)
@lru_cache(maxsize=8)
def _make_jit_val_grad_pharm_vectorized_mask(similarity: str,
extended_points: bool,
only_extended: bool):
"""
Return a JIT-compiled value_and_grad function for masked pharmacophore
alignment, with static args baked in via closure and lru_cache'd so that
_generic_optimize_loop never recompiles across molecule pairs.
"""
def _loss(se3_params,
ref_pharms, fit_pharms,
ref_anchors, fit_anchors,
ref_vectors, fit_vectors,
mask_ref, mask_fit,
ref_self_score, fit_self_score):
return _loss_fn_pharm_vectorized_mask(
se3_params,
ref_pharms, fit_pharms,
ref_anchors, fit_anchors,
ref_vectors, fit_vectors,
mask_ref, mask_fit,
ref_self_score, fit_self_score,
similarity=similarity,
extended_points=extended_points,
only_extended=only_extended,
)
return jit(value_and_grad(_loss))
[docs]
def optimize_pharm_overlay_jax_vectorized_mask(
    ref_pharms: Array,
    fit_pharms: Array,
    ref_anchors: Array,
    fit_anchors: Array,
    ref_vectors: Array,
    fit_vectors: Array,
    mask_ref: Array,
    mask_fit: Array,
    similarity: str = 'tanimoto',
    extended_points: bool = False,
    only_extended: bool = False,
    num_repeats: int = 50,
    trans_centers: Union[Array, np.ndarray, None] = None,
    init_ref_anchors: Union[np.ndarray, None] = None,
    init_fit_anchors: Union[np.ndarray, None] = None,
    lr: float = 0.1,
    max_num_steps: int = 200,
    verbose: bool = False
) -> Tuple[Array, Array, Array, Array]:
    """
    Optimize pharmacophore overlay with padded/masked arrays via JAX.

    Uses ``get_overlap_pharm_jax_vectorized_mask`` so that padding entries never
    contribute to the overlap. Accepts ``init_ref_anchors`` / ``init_fit_anchors``
    (original unpadded arrays) for PCA/COM-based SE(3) initialization to avoid
    zero-padding bias.

    Parameters
    ----------
    ref_pharms, fit_pharms : Array (N,), (M,)
        Padded pharmacophore type indices.
    ref_anchors, fit_anchors : Array (N,3), (M,3)
        Padded anchor positions.
    ref_vectors, fit_vectors : Array (N,3), (M,3)
        Padded direction vectors.
    mask_ref : Array (N,)
        1.0 for real entries, 0.0 for padding.
    mask_fit : Array (M,)
        1.0 for real entries, 0.0 for padding.
    similarity : str
        ``'tanimoto'``, ``'tversky_ref'`` or ``'tversky_fit'`` select how the
        raw overlap is normalized. Any other value leaves the raw overlap
        unnormalized.
    extended_points, only_extended : bool
        Forwarded to the masked scoring function.
    num_repeats : int
        Number of initial poses requested from ``_initialize_se3_params``.
        Only used when ``trans_centers`` is None.
    trans_centers : Array, np.ndarray or None
        If given, initial poses come from
        ``_initialize_se3_params_with_translations`` with a fixed
        ``num_repeats_per_trans=10`` instead.
    init_ref_anchors, init_fit_anchors : np.ndarray or None
        Anchors used only for SE(3) initialization. When None, the
        (padded) ``ref_anchors`` / ``fit_anchors`` are used.
    lr : float
        Learning rate forwarded to ``_generic_optimize_loop``.
    max_num_steps : int
        Step limit forwarded to ``_generic_optimize_loop``.
    verbose : bool
        Print the optimized similarity scores.

    Returns
    -------
    aligned_anchors : Array (M,3)
        Padded ``fit_anchors`` under the best-scoring transform.
    aligned_vectors : Array (M,3)
        Padded ``fit_vectors`` under the best-scoring transform.
    SE3_transform : Array
        Best-scoring output of ``get_SE3_transform_jax`` (presumably a
        (4,4) matrix -- confirm against that helper).
    score : scalar Array
        Similarity score of the best pose.
    """
    # Use original unpadded anchors for SE3 initialization if provided
    se3_ref = init_ref_anchors if init_ref_anchors is not None else np.array(ref_anchors)
    se3_fit = init_fit_anchors if init_fit_anchors is not None else np.array(fit_anchors)
    if trans_centers is None:
        se3_params = _initialize_se3_params(
            ref_points=torch.Tensor(se3_ref),
            fit_points=torch.Tensor(se3_fit),
            num_repeats=num_repeats
        ).detach()
    else:
        se3_params = _initialize_se3_params_with_translations(
            ref_points=torch.Tensor(se3_ref),
            fit_points=torch.Tensor(se3_fit),
            trans_centers=torch.Tensor(np.array(trans_centers)),
            num_repeats_per_trans=10
        ).detach()
    # One rank check covers every initializer path: a lone parameter vector
    # gets a leading batch axis, an already-batched result is left untouched
    # (an unconditional unsqueeze would otherwise yield a (1, 1, 7) array).
    if len(se3_params.shape) == 1:
        se3_params = se3_params.unsqueeze(0)
    se3_params = jnp.array(se3_params)

    # Pre-compute self-overlaps once (scalars)
    ref_self_score = get_overlap_pharm_jax_vectorized_mask(
        ref_pharms, ref_pharms, ref_anchors, ref_anchors, ref_vectors, ref_vectors,
        mask_ref, mask_ref,
        extended_points=extended_points, only_extended=only_extended
    )
    fit_self_score = get_overlap_pharm_jax_vectorized_mask(
        fit_pharms, fit_pharms, fit_anchors, fit_anchors, fit_vectors, fit_vectors,
        mask_fit, mask_fit,
        extended_points=extended_points, only_extended=only_extended
    )

    # Optimization: the lru_cache'd factory returns the same callable object
    # for a given configuration, so _generic_optimize_loop does not recompile.
    val_and_grad_fn = _make_jit_val_grad_pharm_vectorized_mask(similarity, extended_points, only_extended)
    data_args = (ref_pharms, fit_pharms, ref_anchors, fit_anchors,
                 ref_vectors, fit_vectors, mask_ref, mask_fit,
                 ref_self_score, fit_self_score)
    se3_params = _generic_optimize_loop(
        se3_params=se3_params,
        data_args=data_args,
        val_and_grad_fn=val_and_grad_fn,
        lr=lr,
        max_num_steps=max_num_steps
    )

    # Final output: re-score every optimized pose individually
    SE3_transform = vmap_get_SE3_transform_jax(se3_params)
    aligned_anchors = vmap_apply_SE3_transform_jax(fit_anchors, SE3_transform)
    aligned_vectors = vmap_apply_SO3_transform_jax(fit_vectors, SE3_transform)
    # mask_ref / mask_fit are the same for all poses -- not vmapped
    vab_scores = vmap(
        partial(get_overlap_pharm_jax_vectorized_mask,
                extended_points=extended_points,
                only_extended=only_extended),
        in_axes=(None, None, None, 0, None, 0, None, None)
    )(ref_pharms, fit_pharms, ref_anchors, aligned_anchors,
      ref_vectors, aligned_vectors, mask_ref, mask_fit)

    eps = 1e-6  # guards against a zero denominator
    if similarity == 'tanimoto':
        final_scores = vab_scores / (ref_self_score + fit_self_score - vab_scores + eps)
    elif similarity == 'tversky_ref':
        final_scores = vab_scores / (ref_self_score + eps)
    elif similarity == 'tversky_fit':
        final_scores = vab_scores / (fit_self_score + eps)
    else:
        final_scores = vab_scores

    best_idx = jnp.argmax(final_scores)
    if verbose:
        print(f'Optimized pharmacophore similarity score -- max: {final_scores.max():.3f} | mean: {final_scores.mean():.3f}')

    # Index only the leading (pose) axis. An axis-less .squeeze() would also
    # drop a genuine length-1 pharmacophore axis (M == 1), returning (3,)
    # instead of (1, 3). With a single pose, best_idx is simply 0.
    return (aligned_anchors[best_idx], aligned_vectors[best_idx],
            SE3_transform[best_idx], final_scores[best_idx])