Analytical Gradients
PyTorch analytical gradient implementations for shape, ESP, and pharmacophore alignment. These replace autograd with hand-derived gradients, giving a ~2–3.5x speedup over the autograd versions.
Analytical gradients package for shape, ESP, and pharmacophore alignment scoring.
Re-exports all PyTorch implementations from _torch.py for backwards compatibility. JAX implementations are in _jax.py and require JAX as an optional dependency.
- shepherd_score.score.analytical_gradients.rotation_matrix_jacobians_quat(q)
Compute the four 3x3 Jacobians dR/dq_k for k in {w, x, y, z}.
Assumes q is a unit quaternion (or batch of unit quaternions).
- Parameters:
q (torch.Tensor (4,) or (B, 4)) – Unit quaternion(s) in (w, x, y, z) order.
- Returns:
dR_dqw, dR_dqx, dR_dqy, dR_dqz
- Return type:
each torch.Tensor (3,3) or (B,3,3)
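A minimal sketch of these Jacobians, assuming the common unit-quaternion parameterization of R in (w, x, y, z) order written out below. The library's exact formula is not documented here, and the helper names quat_to_rotmat and quat_jacobians_sketch are hypothetical. Every entry of R is quadratic in q, so each dR/dq_k is linear in q; the sketch cross-checks the hand-derived matrices against torch.autograd.functional.jacobian.

```python
import torch


def quat_to_rotmat(q: torch.Tensor) -> torch.Tensor:
    """Hypothetical convention: (w, x, y, z) unit quaternion(s), (4,) or (B, 4) -> (3, 3) or (B, 3, 3)."""
    w, x, y, z = q.unbind(-1)
    rows = [
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ]
    return torch.stack([torch.stack(r, dim=-1) for r in rows], dim=-2)


def quat_jacobians_sketch(q: torch.Tensor):
    """Hand-derived dR/dq_k for k in (w, x, y, z), each (3, 3) or (B, 3, 3)."""
    w, x, y, z = q.unbind(-1)
    o = torch.zeros_like(w)

    def mat(rows):
        # Every entry carries a common factor of 2 from differentiating R.
        return 2 * torch.stack([torch.stack(r, dim=-1) for r in rows], dim=-2)

    dR_dqw = mat([[o, -z, y], [z, o, -x], [-y, x, o]])
    dR_dqx = mat([[o, y, z], [y, -2 * x, -w], [z, w, -2 * x]])
    dR_dqy = mat([[-2 * y, x, w], [x, o, z], [-w, z, -2 * y]])
    dR_dqz = mat([[-2 * z, -w, x], [w, -2 * z, y], [x, y, o]])
    return dR_dqw, dR_dqx, dR_dqy, dR_dqz


q = torch.tensor([0.9, 0.1, -0.3, 0.2], dtype=torch.float64)
q = q / q.norm()
jac = torch.autograd.functional.jacobian(quat_to_rotmat, q)  # shape (3, 3, 4)
max_err = max(
    (dR - jac[..., k]).abs().max().item()
    for k, dR in enumerate(quat_jacobians_sketch(q))
)
```

Because the formula uses the 1 - 2(...) form on the diagonal, it is only a rotation for unit q, which matches the stated unit-quaternion assumption.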
- shepherd_score.score.analytical_gradients.project_grad_R_to_quaternion(G, q)
Project gradient w.r.t. rotation matrix R onto quaternion parameters.
dL/dq_k = Tr(G^T @ dR/dq_k) = sum_{ij} G_ij * (dR/dq_k)_ij
- Parameters:
G (torch.Tensor (3,3) or (B,3,3)) – Gradient w.r.t. rotation matrix.
q (torch.Tensor (4,) or (B,4)) – Unit quaternion.
- Returns:
grad_q
- Return type:
torch.Tensor (4,) or (B,4)
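The projection is a Frobenius inner product between G and each Jacobian. A minimal sketch, assuming the same hypothetical (w, x, y, z) rotation formula and, for brevity, obtaining dR/dq from autograd rather than by hand; the toy loss ||R - I||_F^2 and the helper names are illustrative only.

```python
import torch


def quat_to_rotmat(q: torch.Tensor) -> torch.Tensor:
    # Hypothetical (w, x, y, z) unit-quaternion convention, unbatched.
    w, x, y, z = q.unbind(-1)
    return torch.stack([
        torch.stack([1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)]),
        torch.stack([2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)]),
        torch.stack([2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)]),
    ])


def project_grad_sketch(G: torch.Tensor, dR_dq: torch.Tensor) -> torch.Tensor:
    # dL/dq_k = Tr(G^T dR/dq_k) = sum_ij G_ij (dR/dq_k)_ij; dR_dq has shape (3, 3, 4).
    return torch.einsum("ij,ijk->k", G, dR_dq)


q = torch.tensor([0.8, -0.2, 0.4, 0.1], dtype=torch.float64)
q = q / q.norm()
eye = torch.eye(3, dtype=torch.float64)

R = quat_to_rotmat(q)
G = 2 * (R - eye)  # dL/dR for the toy loss L = ||R - I||_F^2
dR_dq = torch.autograd.functional.jacobian(quat_to_rotmat, q)
grad_q = project_grad_sketch(G, dR_dq)

# Reference: let autograd differentiate L(R(q)) end to end.
q_ad = q.clone().requires_grad_(True)
((quat_to_rotmat(q_ad) - eye) ** 2).sum().backward()
```

For a batch of per-sample Jacobians laid out as (B, 3, 3, 4), the same contraction would be written as the einsum string "bij,bijk->bk".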
- shepherd_score.score.analytical_gradients.build_lookup_tables_cached(device_str, dtype_str)
- shepherd_score.score.analytical_gradients.build_lookup_tables(device, dtype)
Build constant lookup tables for all P_TYPES. Category codes: 0 = _NONDIRECTIONAL, 1 = _DIRECTIONAL, 2 = _AROMATIC, 3 = Dummy.
- Parameters:
device (torch.device)
dtype (torch.dtype)
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
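The returned tables are not described here, but the category codes suggest a per-type lookup that turns per-point type indices into categories with one gather. A minimal sketch of that idea, plus memoization with functools.lru_cache keyed on a plain string as the cached variant's signature suggests. The type list, category assignment, and helper names are all hypothetical.

```python
import functools

import torch

# Hypothetical type list and categories for illustration; the real P_TYPES
# ordering and membership are not given in this reference.
# Category codes as documented for build_lookup_tables:
# 0 nondirectional, 1 directional, 2 aromatic, 3 dummy.
HYPOTHETICAL_P_TYPES = ("Acceptor", "Donor", "Aromatic", "Hydrophobe", "Halogen", "Dummy")
HYPOTHETICAL_CATEGORY = {"Acceptor": 1, "Donor": 1, "Aromatic": 2,
                         "Hydrophobe": 0, "Halogen": 1, "Dummy": 3}


def build_category_table(device: torch.device) -> torch.Tensor:
    # One entry per type index, so per-point categories come from a single gather.
    return torch.tensor([HYPOTHETICAL_CATEGORY[p] for p in HYPOTHETICAL_P_TYPES],
                        dtype=torch.long, device=device)


@functools.lru_cache(maxsize=None)
def build_category_table_cached(device_str: str) -> torch.Tensor:
    # Memoize per device string; repeated calls return the same tensor object.
    return build_category_table(torch.device(device_str))


table = build_category_table_cached("cpu")
ptypes = torch.tensor([0, 2, 3, 4, 5])  # per-point type indices
cats = table[ptypes]                     # per-point categories
directional = cats == 1                  # mask for directional (cat == 1) points
```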
- shepherd_score.score.analytical_gradients.compute_overlap_and_grad_pharm(R, t, ref_pharms, fit_pharms, ref_anchors, fit_anchors_orig, ref_vectors, fit_vectors_orig, extended_points=False, only_extended=False)
Compute the overlap O_AB and its gradients with respect to R and t, fully vectorized across all pharmacophore types.
When extended_points=True, directional types (Acceptor, Donor, Halogen; cat==1) are scored using a plain Gaussian overlap at the extended point positions (anchor + normalized vector) instead of cosine weighting. If only_extended=True, the anchor overlap for those types is skipped entirely; otherwise both the anchor overlap (w=1) and the extended overlap are included. Non-directional (cat==0) and Aromatic (cat==2) types are unaffected.
- Parameters:
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
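Because R preserves vector norms, R(a + v/||v||) + t = (R a + t) + R v/||R v||: an extended point transforms exactly like an ordinary point under the pose, so its Gaussian overlap has the same form of gradient w.r.t. R and t as an anchor overlap. A minimal check of that identity with random data; the helper name extend is hypothetical, and this is a sketch of the geometry, not the library's implementation.

```python
import torch


def extend(anchors: torch.Tensor, vectors: torch.Tensor) -> torch.Tensor:
    # Extended point = anchor + unit-normalized direction vector.
    return anchors + vectors / vectors.norm(dim=-1, keepdim=True)


torch.manual_seed(0)
anchors = torch.randn(5, 3, dtype=torch.float64)
vectors = torch.randn(5, 3, dtype=torch.float64)

# Random proper rotation (det = +1) from a QR decomposition, plus a translation.
R, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))
if torch.linalg.det(R) < 0:
    R[:, 0] = -R[:, 0]
t = torch.randn(3, dtype=torch.float64)

# Path 1: move anchors as points (R a + t) and vectors as directions (R v), then extend.
ext_after = extend(anchors @ R.T + t, vectors @ R.T)
# Path 2: extend in the original frame, then move the extended points as ordinary points.
ext_before = extend(anchors, vectors) @ R.T + t
```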
- shepherd_score.score.analytical_gradients.compute_self_overlaps_pharm(ptype_1, ptype_2, anchors_1, anchors_2, vectors_1, vectors_2, extended_points=False, only_extended=False)
Compute the self-overlaps VAA and VBB, vectorized across all pharmacophore types.
- shepherd_score.score.analytical_gradients.apply_tanimoto_chain_rule(O_AB, U, grad_R, grad_t)
- Parameters:
O_AB (torch.Tensor)
U (torch.Tensor)
grad_R (torch.Tensor)
grad_t (torch.Tensor)
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
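No formula is documented for this chain rule. A minimal sketch, assuming U = VAA + VBB is pose-independent and the Tanimoto similarity is S = O_AB / (U - O_AB); then dS/dO_AB = U / (U - O_AB)^2, and both overlap gradients are scaled by that factor. The helper name and the (S, grad_R, grad_t) return layout are assumptions.

```python
import torch


def tanimoto_chain_rule_sketch(O_AB, U, grad_R, grad_t):
    # Assumed: S = O_AB / (U - O_AB) with U = VAA + VBB constant in (R, t).
    denom = U - O_AB
    S = O_AB / denom
    dS_dO = U / denom ** 2
    # Broadcast a scalar or (B,) factor over (3,3)/(B,3,3) and (3,)/(B,3).
    return S, dS_dO[..., None, None] * grad_R, dS_dO[..., None] * grad_t


U = torch.tensor(4.0, dtype=torch.float64)
O = torch.tensor(1.3, dtype=torch.float64, requires_grad=True)
(O / (U - O)).backward()  # autograd reference for dS/dO, stored in O.grad

grad_R = torch.eye(3, dtype=torch.float64)   # stand-in for dO_AB/dR
grad_t = torch.ones(3, dtype=torch.float64)  # stand-in for dO_AB/dt
S, gR, gt = tanimoto_chain_rule_sketch(O.detach(), U, grad_R, grad_t)
```

The factor U / (U - O_AB)^2 comes from the quotient rule: O_AB appears in both the numerator and the denominator of S.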
- shepherd_score.score.analytical_gradients.apply_tversky_chain_rule(O_AB, D, grad_R, grad_t)
- Parameters:
O_AB (torch.Tensor)
D (torch.Tensor)
grad_R (torch.Tensor)
grad_t (torch.Tensor)
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
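This chain rule is also undocumented. A minimal sketch, assuming S = O_AB / D with a pose-independent denominator D, so the chain rule reduces to dividing O_AB and both gradients by D. The denominator sigma * VAA + (1 - sigma) * VBB in the usage lines is a hypothetical example, not a documented formula.

```python
import torch


def tversky_chain_rule_sketch(O_AB, D, grad_R, grad_t):
    # Assumed: S = O_AB / D with D constant in (R, t), so dS/dO_AB = 1 / D.
    inv_D = 1.0 / D
    return O_AB * inv_D, inv_D[..., None, None] * grad_R, inv_D[..., None] * grad_t


VAA = torch.tensor(3.0)
VBB = torch.tensor(5.0)
sigma = 0.5
D = sigma * VAA + (1 - sigma) * VBB  # hypothetical denominator
S, gR, gt = tversky_chain_rule_sketch(torch.tensor(2.0), D, torch.eye(3), torch.ones(3))
```

Under this assumption, the denominator does not contain O_AB, so unlike the Tanimoto case there is no extra quotient-rule term.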
- shepherd_score.score.analytical_gradients.compute_overlap_and_grad_shape(R, t, ref_points, fit_points_orig, alpha, pair_weights=None)
Compute shape overlap O_AB and gradients w.r.t. rotation matrix R and translation t.
Shape scoring uses uniform weight w=1 and a single alpha, so there is no weight gradient term (unlike pharmacophore scoring).
- Parameters:
R (torch.Tensor (3,3) or (B,3,3))
t (torch.Tensor (3,) or (B,3))
ref_points (torch.Tensor (N,3) or (B,N,3))
fit_points_orig (torch.Tensor (M,3) or (B,M,3))
alpha (float)
pair_weights (torch.Tensor (M,N) or (B,M,N) or None) – Optional per-pair multiplicative weights. If None, uniform weight 1 is used (standard shape scoring). For ESP scoring, pass the charge-based weights exp(-||v_a - v_b||^2 / lam), which are SE(3)-invariant.
- Returns:
O_AB (torch.Tensor scalar or (B,))
grad_R (torch.Tensor (3,3) or (B,3,3))
grad_t (torch.Tensor (3,) or (B,3))
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
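A minimal unbatched sketch of the overlap and its gradients. It assumes a Gaussian pair kernel K exp(-(alpha/2) ||P'_a - P_b||^2) with K = (pi / (2 alpha))^{3/2} and P'_a = R P_a + t, which is the convention stated for compute_avoid_and_grad. Writing g_ab for the (optionally weighted) pair term, dO/dP'_a = -alpha sum_b g_ab (P'_a - P_b). The translation gradient sums these, and the rotation gradient is sum_a (dO/dP'_a) P_a^T. The kernel normalization and helper name are assumptions; the result is checked against autograd.

```python
import math

import torch


def overlap_and_grad_shape_sketch(R, t, ref_points, fit_points_orig, alpha, pair_weights=None):
    # Unbatched: R (3,3), t (3,), ref_points (N,3), fit_points_orig (M,3).
    fit = fit_points_orig @ R.T + t                        # P'_a = R P_a + t, (M,3)
    diff = fit[:, None, :] - ref_points[None, :, :]        # P'_a - P_b, (M,N,3)
    K = (math.pi / (2 * alpha)) ** 1.5                     # assumed kernel normalization
    g = K * torch.exp(-0.5 * alpha * (diff ** 2).sum(-1))  # (M,N)
    if pair_weights is not None:
        g = g * pair_weights                               # constant (M,N) weights, e.g. ESP
    O_AB = g.sum()
    dO_dfit = -alpha * (g[..., None] * diff).sum(dim=1)    # dO/dP'_a, (M,3)
    grad_t = dO_dfit.sum(dim=0)                            # (3,)
    grad_R = dO_dfit.T @ fit_points_orig                   # sum_a dO/dP'_a P_a^T, (3,3)
    return O_AB, grad_R, grad_t


torch.manual_seed(0)
ref = torch.randn(6, 3, dtype=torch.float64)
fit = torch.randn(4, 3, dtype=torch.float64)
R, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))
t = torch.randn(3, dtype=torch.float64)
alpha = 0.81

O_AB, grad_R, grad_t = overlap_and_grad_shape_sketch(R, t, ref, fit, alpha)

# Autograd reference, treating R as an unconstrained 3x3 matrix.
R_ad = R.clone().requires_grad_(True)
t_ad = t.clone().requires_grad_(True)
overlap_and_grad_shape_sketch(R_ad, t_ad, ref, fit, alpha)[0].backward()
```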
- shepherd_score.score.analytical_gradients.compute_self_overlaps_shape(ref_points, fit_points, alpha)
Compute self-overlaps VAA and VBB for shape scoring. Invariant to SE(3).
- Parameters:
ref_points (torch.Tensor (N,3))
fit_points (torch.Tensor (M,3))
alpha (float)
- Returns:
VAA, VBB
- Return type:
torch.Tensor scalars
- shepherd_score.score.analytical_gradients.compute_self_overlaps_esp(ref_points, fit_points, ref_charges, fit_charges, alpha, lam)
Compute ESP self-overlaps VAA and VBB. Invariant to SE(3).
- shepherd_score.score.analytical_gradients.compute_analytical_grad_se3_esp(se3_params, ref_points, fit_points, ref_charges, fit_charges, alpha, lam, VAA_total, VBB_total)
Compute loss (1 - Tanimoto ESP similarity) and its gradient w.r.t. SE(3) parameters using analytical gradients.
The ESP pair weights exp(-||v_a - v_b||^2 / lam) are SE(3)-invariant (charges are fixed to their points), so they are precomputed once and treated as constants.
- Parameters:
se3_params (torch.Tensor (7,) or (B,7)) – [q_w, q_x, q_y, q_z, t_x, t_y, t_z]
ref_points (torch.Tensor (N,3) or (B,N,3))
fit_points (torch.Tensor (M,3) or (B,M,3))
ref_charges (torch.Tensor (N,))
fit_charges (torch.Tensor (M,))
alpha (float)
lam (float)
VAA_total (torch.Tensor scalar) – Precomputed ESP self-overlap for ref_points.
VBB_total (torch.Tensor scalar) – Precomputed ESP self-overlap for fit_points.
- Returns:
loss (torch.Tensor scalar)
grad_se3 (torch.Tensor (7,) or (B,7))
- Return type:
Tuple[torch.Tensor, torch.Tensor]
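The invariance argument can be checked directly: the ESP pair weights depend only on the charges, so only the geometric Gaussian factor changes with the pose. A minimal sketch, assuming scalar charges (so ||v_a - v_b||^2 = (v_a - v_b)^2) and a hypothetical Gaussian kernel normalization (pi / (2 alpha))^{3/2}. The helper names are illustrative. The sketch precomputes the (N, N) weights once and shows that the ESP self-overlap is unchanged by a rigid motion.

```python
import math

import torch


def esp_pair_weights(charges_a, charges_b, lam):
    # exp(-(q_a - q_b)^2 / lam): a function of the charges only, never of the pose.
    return torch.exp(-(charges_a[:, None] - charges_b[None, :]) ** 2 / lam)


def weighted_overlap(points_a, points_b, weights, alpha):
    sq = ((points_a[:, None, :] - points_b[None, :, :]) ** 2).sum(-1)
    K = (math.pi / (2 * alpha)) ** 1.5  # assumed kernel normalization
    return (weights * K * torch.exp(-0.5 * alpha * sq)).sum()


torch.manual_seed(1)
ref_points = torch.randn(7, 3, dtype=torch.float64)
ref_charges = torch.randn(7, dtype=torch.float64)
alpha, lam = 0.81, 0.3

W_AA = esp_pair_weights(ref_charges, ref_charges, lam)  # computed once, reused
VAA = weighted_overlap(ref_points, ref_points, W_AA, alpha)

# Apply a random proper rigid motion and recompute with the same weights.
R, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))
if torch.linalg.det(R) < 0:
    R[:, 0] = -R[:, 0]
moved = ref_points @ R.T + torch.randn(3, dtype=torch.float64)
VAA_moved = weighted_overlap(moved, moved, W_AA, alpha)
```

With all weights equal to one, this reduces to the shape self-overlap. In the alignment loss, the fixed (M, N) fit-to-reference weights play the role of pair_weights in compute_overlap_and_grad_shape.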
- shepherd_score.score.analytical_gradients.compute_analytical_grad_se3_shape(se3_params, ref_points, fit_points, alpha, VAA_total, VBB_total)
Compute loss (1 - Tanimoto shape similarity) and its gradient w.r.t. SE(3) parameters using analytical gradients.
- Parameters:
se3_params (torch.Tensor (7,) or (B,7)) – [q_w, q_x, q_y, q_z, t_x, t_y, t_z]
ref_points (torch.Tensor (N,3) or (B,N,3))
fit_points (torch.Tensor (M,3) or (B,M,3))
alpha (float)
VAA_total (torch.Tensor scalar) – Precomputed self-overlap for ref_points.
VBB_total (torch.Tensor scalar) – Precomputed self-overlap for fit_points.
- Returns:
loss (torch.Tensor scalar)
grad_se3 (torch.Tensor (7,) or (B,7))
- Return type:
Tuple[torch.Tensor, torch.Tensor]
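Putting the pieces together: the loss is 1 - S, so dloss/dO_AB = -U / (U - O_AB)^2 with U = VAA_total + VBB_total. That factor scales the overlap gradients. The rotation part is then projected onto q with the four Jacobians, and the translation part fills the last three entries. This is a minimal unbatched sketch that assumes se3_params already holds a unit quaternion and uses a hypothetical Gaussian kernel normalization and (w, x, y, z) rotation formula. All helper names are hypothetical, and the result is checked against autograd.

```python
import math

import torch


def quat_to_rotmat(q):
    w, x, y, z = q.unbind(-1)
    return torch.stack([
        torch.stack([1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)]),
        torch.stack([2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)]),
        torch.stack([2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)]),
    ])


def quat_jacobians(q):
    # Stacked dR/dq_k for k in (w, x, y, z): shape (4, 3, 3).
    w, x, y, z = q.unbind(-1)
    o = torch.zeros_like(w)

    def mat(rows):
        return 2 * torch.stack([torch.stack(r) for r in rows])

    return torch.stack([
        mat([[o, -z, y], [z, o, -x], [-y, x, o]]),
        mat([[o, y, z], [y, -2 * x, -w], [z, w, -2 * x]]),
        mat([[-2 * y, x, w], [x, o, z], [-w, z, -2 * y]]),
        mat([[-2 * z, -w, x], [w, -2 * z, y], [x, y, o]]),
    ])


def gaussian_pairs(a, b, alpha):
    diff = a[:, None, :] - b[None, :, :]
    K = (math.pi / (2 * alpha)) ** 1.5  # assumed kernel normalization
    return K * torch.exp(-0.5 * alpha * (diff ** 2).sum(-1)), diff


def loss_and_grad_se3_shape_sketch(se3_params, ref_points, fit_points, alpha, VAA, VBB):
    q, t = se3_params[:4], se3_params[4:]
    moved = fit_points @ quat_to_rotmat(q).T + t
    g, diff = gaussian_pairs(moved, ref_points, alpha)
    O_AB = g.sum()
    dO_dmoved = -alpha * (g[..., None] * diff).sum(dim=1)
    grad_R = dO_dmoved.T @ fit_points
    grad_t = dO_dmoved.sum(dim=0)
    U = VAA + VBB
    loss = 1 - O_AB / (U - O_AB)  # 1 - Tanimoto
    dloss_dO = -U / (U - O_AB) ** 2
    grad_q = torch.einsum("ij,kij->k", dloss_dO * grad_R, quat_jacobians(q))
    return loss, torch.cat([grad_q, dloss_dO * grad_t])


torch.manual_seed(0)
ref = torch.randn(8, 3, dtype=torch.float64)
fit = torch.randn(8, 3, dtype=torch.float64) + 0.5
alpha = 0.81
VAA = gaussian_pairs(ref, ref, alpha)[0].sum()  # SE(3)-invariant, computed once
VBB = gaussian_pairs(fit, fit, alpha)[0].sum()

q0 = torch.tensor([0.9, 0.2, -0.1, 0.3], dtype=torch.float64)
se3 = torch.cat([q0 / q0.norm(), torch.tensor([0.1, -0.2, 0.05], dtype=torch.float64)])
loss, grad_se3 = loss_and_grad_se3_shape_sketch(se3, ref, fit, alpha, VAA, VBB)

se3_ad = se3.clone().requires_grad_(True)
loss_and_grad_se3_shape_sketch(se3_ad, ref, fit, alpha, VAA, VBB)[0].backward()
```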
- shepherd_score.score.analytical_gradients.compute_avoid_and_grad(R, t, fit_pts_avoid_orig, avoid_points, min_dist)
Compute the linear hard-sphere avoid term and its gradients w.r.t. R and t.
A = sum_{a in fit_avoid, b in avoid} relu((min_dist - dist(P'_a, P_b)) / min_dist)
where P'_a = R @ P_a + t.
- Parameters:
R (torch.Tensor (3,3) or (B,3,3))
t (torch.Tensor (3,) or (B,3))
fit_pts_avoid_orig (torch.Tensor (M,3) or (B,M,3)) – Fit points to penalize (in original frame, before transformation).
avoid_points (torch.Tensor (K,3)) – Fixed reference points to avoid (never batched).
min_dist (float) – Distance threshold below which overlap is penalized.
- Returns:
A (torch.Tensor scalar or (B,))
grad_R (torch.Tensor (3,3) or (B,3,3))
grad_t (torch.Tensor (3,) or (B,3))
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
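Inside the sphere (dist < min_dist), each pair contributes 1 - dist/min_dist, so dA/dP'_a = -(1/min_dist) sum_b (P'_a - P_b)/||P'_a - P_b|| over active pairs. The chain rule to R and t is then the same as for any transformed point. A minimal unbatched sketch checked against autograd; the helper name is hypothetical. The gradient is undefined when a transformed point coincides exactly with an avoid point, so the sketch clamps the distance to avoid dividing by zero.

```python
import torch


def avoid_and_grad_sketch(R, t, fit_pts_avoid_orig, avoid_points, min_dist):
    moved = fit_pts_avoid_orig @ R.T + t                 # P'_a = R P_a + t, (M,3)
    diff = moved[:, None, :] - avoid_points[None, :, :]  # (M,K,3)
    dist = diff.norm(dim=-1)                             # (M,K)
    A = torch.relu((min_dist - dist) / min_dist).sum()
    active = (dist < min_dist).to(moved.dtype)           # relu has nonzero slope only here
    unit = diff / dist.clamp_min(1e-12)[..., None]       # d dist / d P'_a
    dA_dmoved = -(active[..., None] * unit).sum(dim=1) / min_dist
    grad_R = dA_dmoved.T @ fit_pts_avoid_orig            # sum_a dA/dP'_a P_a^T
    grad_t = dA_dmoved.sum(dim=0)
    return A, grad_R, grad_t


torch.manual_seed(0)
fit_avoid = torch.randn(6, 3, dtype=torch.float64)
R, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))
t = torch.randn(3, dtype=torch.float64)
# Place a few avoid points close to transformed fit points so some pairs are active.
avoid = (fit_avoid[:3] @ R.T + t) + 0.1 * torch.randn(3, 3, dtype=torch.float64)
min_dist = 1.0

A, grad_R, grad_t = avoid_and_grad_sketch(R, t, fit_avoid, avoid, min_dist)

R_ad = R.clone().requires_grad_(True)
t_ad = t.clone().requires_grad_(True)
avoid_and_grad_sketch(R_ad, t_ad, fit_avoid, avoid, min_dist)[0].backward()
```

Because the penalty is linear in distance inside the sphere, each active pair contributes a unit-length push of magnitude 1/min_dist regardless of how deep the overlap is.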
- shepherd_score.score.analytical_gradients.compute_analytical_grad_se3_shape_with_avoid(se3_params, ref_points, fit_points, alpha, VAA_total, VBB_total, fit_points_for_avoid, avoid_points, avoid_min_dist, avoid_weight)
Compute loss and gradient for shape alignment with an avoid-points penalty.
loss = (1 - Tanimoto_shape) + avoid_weight * hard_sphere_overlap(fit_avoid, avoid)
- Parameters:
se3_params (torch.Tensor (7,) or (B,7))
ref_points (torch.Tensor (N,3) or (B,N,3))
fit_points (torch.Tensor (M,3) or (B,M,3))
alpha (float)
VAA_total (torch.Tensor scalar) – Precomputed self-overlap for ref_points.
VBB_total (torch.Tensor scalar) – Precomputed self-overlap for fit_points.
fit_points_for_avoid (torch.Tensor (M2,3) or (B,M2,3))
avoid_points (torch.Tensor (K,3))
avoid_min_dist (float)
avoid_weight (float)
- Returns:
loss (torch.Tensor scalar)
grad_se3 (torch.Tensor (7,) or (B,7))
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- shepherd_score.score.analytical_gradients.compute_analytical_grad_se3(se3_params, ref_pharms, fit_pharms, ref_anchors, fit_anchors, ref_vectors, fit_vectors, VAA_total, VBB_total, similarity='tanimoto', sigma=0.5, extended_points=False, only_extended=False)
- Parameters:
se3_params (torch.Tensor)
ref_pharms (torch.Tensor)
fit_pharms (torch.Tensor)
ref_anchors (torch.Tensor)
fit_anchors (torch.Tensor)
ref_vectors (torch.Tensor)
fit_vectors (torch.Tensor)
VAA_total (torch.Tensor)
VBB_total (torch.Tensor)
similarity (str)
sigma (float)
extended_points (bool)
only_extended (bool)
- Return type:
Tuple[torch.Tensor, torch.Tensor]