Analytical Gradients

PyTorch analytical gradient implementations for shape, ESP, and pharmacophore alignment. These replace autograd with hand-derived gradients, giving roughly a 2–3.5x speedup over autograd.

Analytical gradients package for shape and pharmacophore alignment scoring.

Re-exports all PyTorch implementations from _torch.py for backwards compatibility. JAX implementations live in _jax.py and require JAX, which is an optional dependency.

shepherd_score.score.analytical_gradients.rotation_matrix_jacobians_quat(q)

Compute the four 3x3 Jacobians dR/dq_k for k in {w, x, y, z}.

Assumes q is a unit quaternion (or batch of unit quaternions).

Parameters:

q (torch.Tensor (4,) or (B, 4)) – Unit quaternion(s) in (w, x, y, z) order.

Returns:

dR_dqw, dR_dqx, dR_dqy, dR_dqz

Return type:

each torch.Tensor (3,3) or (B,3,3)
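A minimal sketch of these Jacobians, assuming the common unit-quaternion form R = [[1-2(y²+z²), 2(xy-wz), 2(xz+wy)], ...]. The library's exact parameterization is not shown here, and because different algebraically equivalent forms give different derivatives, this convention is an assumption. The names quat_to_rotmat and rotmat_jacobians_sketch are hypothetical. The check at the end compares the hand-derived matrices against torch.autograd.functional.jacobian.

```python
import torch

def _mat(rows):
    # Assemble a 3x3 nested list of (...)-shaped tensors into a (..., 3, 3) tensor.
    return torch.stack([torch.stack(row, dim=-1) for row in rows], dim=-2)

def quat_to_rotmat(q: torch.Tensor) -> torch.Tensor:
    """Unit quaternion(s) (w, x, y, z), shape (4,) or (B, 4) -> R, shape (3, 3) or (B, 3, 3)."""
    w, x, y, z = q.unbind(-1)
    return _mat([
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ])

def rotmat_jacobians_sketch(q: torch.Tensor):
    """Hand-derived dR/dq_k for k in (w, x, y, z) under the convention above."""
    w, x, y, z = q.unbind(-1)
    zero = torch.zeros_like(w)
    dR_dw = 2 * _mat([[zero, -z, y], [z, zero, -x], [-y, x, zero]])
    dR_dx = 2 * _mat([[zero, y, z], [y, -2 * x, -w], [z, w, -2 * x]])
    dR_dy = 2 * _mat([[-2 * y, x, w], [x, zero, z], [-w, z, -2 * y]])
    dR_dz = 2 * _mat([[-2 * z, -w, x], [w, -2 * z, y], [x, y, zero]])
    return dR_dw, dR_dx, dR_dy, dR_dz

# Compare with autograd: J[i, j, k] = dR_ij / dq_k.
torch.manual_seed(0)
q = torch.nn.functional.normalize(torch.randn(4, dtype=torch.float64), dim=0)
J = torch.autograd.functional.jacobian(quat_to_rotmat, q)  # shape (3, 3, 4)
for k, dR in enumerate(rotmat_jacobians_sketch(q)):
    assert torch.allclose(J[..., k], dR)
```

The same code handles a (B, 4) batch because every entry is built with broadcasting-friendly unbind/stack calls.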

shepherd_score.score.analytical_gradients.project_grad_R_to_quaternion(G, q)

Project gradient w.r.t. rotation matrix R onto quaternion parameters.

dL/dq_k = Tr(G^T @ dR/dq_k) = sum_{ij} G_ij * (dR/dq_k)_ij

Parameters:
  • G (torch.Tensor (3,3) or (B,3,3)) – Gradient w.r.t. rotation matrix.

  • q (torch.Tensor (4,) or (B,4)) – Unit quaternion.

Returns:

grad_q

Return type:

torch.Tensor (4,) or (B,4)
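The projection formula above is a Frobenius inner product between G and each Jacobian, which maps directly onto a single einsum. A minimal sketch follows; project_grad_R_to_quat_sketch and quat_to_rotmat are hypothetical names, the rotation convention is an assumption, and here the Jacobians come from autograd purely as a stand-in for the hand-derived ones.

```python
import torch

def quat_to_rotmat(q: torch.Tensor) -> torch.Tensor:
    # Single unit quaternion (w, x, y, z) -> 3x3 rotation matrix (assumed convention).
    w, x, y, z = q.unbind(0)
    return torch.stack([
        torch.stack([1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)]),
        torch.stack([2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)]),
        torch.stack([2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)]),
    ])

def project_grad_R_to_quat_sketch(G: torch.Tensor, dR_dq: torch.Tensor) -> torch.Tensor:
    # G: (..., 3, 3) = dL/dR; dR_dq: (..., 4, 3, 3) with dR_dq[..., k, :, :] = dR/dq_k.
    # dL/dq_k = Tr(G^T @ dR/dq_k) = sum_ij G_ij * (dR/dq_k)_ij
    return torch.einsum("...ij,...kij->...k", G, dR_dq)

# Check against autograd for L(R) = sum(A * R), whose gradient w.r.t. R is A.
torch.manual_seed(0)
q = torch.nn.functional.normalize(torch.randn(4, dtype=torch.float64), dim=0)
A = torch.randn(3, 3, dtype=torch.float64)
dR_dq = torch.autograd.functional.jacobian(quat_to_rotmat, q).permute(2, 0, 1)  # (4, 3, 3)
grad_q = project_grad_R_to_quat_sketch(A, dR_dq)

q_leaf = q.clone().requires_grad_(True)
(A * quat_to_rotmat(q_leaf)).sum().backward()
assert torch.allclose(grad_q, q_leaf.grad)
```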

shepherd_score.score.analytical_gradients.build_lookup_tables_cached(device_str, dtype_str)
Parameters:
  • device_str (str)

  • dtype_str (str)

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

shepherd_score.score.analytical_gradients.build_lookup_tables(device, dtype)

Build constant lookup tables for all P_TYPES. Categories: 0 = _NONDIRECTIONAL, 1 = _DIRECTIONAL, 2 = _AROMATIC, 3 = Dummy.

Parameters:
  • device (torch.device)

  • dtype (torch.dtype)

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

shepherd_score.score.analytical_gradients.compute_overlap_and_grad_pharm(R, t, ref_pharms, fit_pharms, ref_anchors, fit_anchors_orig, ref_vectors, fit_vectors_orig, extended_points=False, only_extended=False)

Compute the pharmacophore overlap O_AB and its gradients, fully vectorized across all pharmacophore types.

When extended_points=True, directional types (Acceptor, Donor, Halogen; cat==1) are scored with a plain Gaussian overlap at the extended point positions (anchor + normalized vector) instead of cosine weighting. If only_extended=True, the anchor overlap for those types is skipped entirely; otherwise both the anchor (w=1) and extended overlaps are included. Non-directional (cat==0) and Aromatic (cat==2) types are unaffected.

Parameters:
  • R (torch.Tensor)

  • t (torch.Tensor)

  • ref_pharms (torch.Tensor)

  • fit_pharms (torch.Tensor)

  • ref_anchors (torch.Tensor)

  • fit_anchors_orig (torch.Tensor)

  • ref_vectors (torch.Tensor)

  • fit_vectors_orig (torch.Tensor)

  • extended_points (bool)

  • only_extended (bool)

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
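The extended-point rule for directional features can be sketched in a few lines. This is an illustration under stated assumptions, not the library's code: extended_points_sketch and directional_points_to_score are hypothetical helpers, and the small eps guard against zero-length vectors is an added assumption.

```python
import torch

def extended_points_sketch(anchors: torch.Tensor, vectors: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # Extended point = anchor + vector / ||vector|| (eps avoids division by zero).
    unit = vectors / vectors.norm(dim=-1, keepdim=True).clamp_min(eps)
    return anchors + unit

def directional_points_to_score(anchors: torch.Tensor, vectors: torch.Tensor, only_extended: bool = False) -> torch.Tensor:
    # Hypothetical helper: the points that receive a plain Gaussian overlap for
    # directional (cat == 1) features when extended_points=True.
    ext = extended_points_sketch(anchors, vectors)
    if only_extended:
        return ext                               # anchor overlap skipped entirely
    return torch.cat([anchors, ext], dim=-2)      # anchor (w=1) plus extended points

anchors = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
vectors = torch.tensor([[0.0, 0.0, 2.0], [0.0, 3.0, 0.0]])
ext = extended_points_sketch(anchors, vectors)   # -> [[0, 0, 1], [1, 1, 0]]
```

Because only the vector direction matters, the vector lengths (2 and 3 above) do not change where the extended point lands.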

shepherd_score.score.analytical_gradients.compute_self_overlaps_pharm(ptype_1, ptype_2, anchors_1, anchors_2, vectors_1, vectors_2, extended_points=False, only_extended=False)

Compute the self-overlaps VAA and VBB in vectorized form.

Parameters:
  • ptype_1 (torch.Tensor)

  • ptype_2 (torch.Tensor)

  • anchors_1 (torch.Tensor)

  • anchors_2 (torch.Tensor)

  • vectors_1 (torch.Tensor)

  • vectors_2 (torch.Tensor)

  • extended_points (bool)

  • only_extended (bool)

Return type:

Tuple[torch.Tensor, torch.Tensor]

shepherd_score.score.analytical_gradients.apply_tanimoto_chain_rule(O_AB, U, grad_R, grad_t)
Parameters:
  • O_AB (torch.Tensor)

  • U (torch.Tensor)

  • grad_R (torch.Tensor)

  • grad_t (torch.Tensor)

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

shepherd_score.score.analytical_gradients.apply_tversky_chain_rule(O_AB, D, grad_R, grad_t)
Parameters:
  • O_AB (torch.Tensor)

  • D (torch.Tensor)

  • grad_R (torch.Tensor)

  • grad_t (torch.Tensor)

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

shepherd_score.score.analytical_gradients.compute_overlap_and_grad_shape(R, t, ref_points, fit_points_orig, alpha, pair_weights=None)

Compute shape overlap O_AB and gradients w.r.t. rotation matrix R and translation t.

Shape scoring uses uniform weight w=1 and a single alpha, so there is no weight gradient term (unlike pharmacophore scoring).

Parameters:
  • R (torch.Tensor (3,3) or (B,3,3))

  • t (torch.Tensor (3,) or (B,3))

  • ref_points (torch.Tensor (N,3) or (B,N,3))

  • fit_points_orig (torch.Tensor (M,3) or (B,M,3))

  • alpha (float)

  • pair_weights (torch.Tensor (M,N) or (B,M,N) or None) – Optional per-pair multiplicative weights. If None, uniform weight 1 is used (standard shape scoring). For ESP scoring, pass the charge-based weights exp(-||v_a - v_b||^2 / lam), which are SE(3)-invariant.

Returns:

  • O_AB (torch.Tensor scalar or (B,))

  • grad_R (torch.Tensor (3,3) or (B,3,3))

  • grad_t (torch.Tensor (3,) or (B,3))

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
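A minimal sketch of how the shape overlap gradients can be derived by hand, assuming the Gaussian kernel c * exp(-(alpha/2) * ||P'_a - Q_b||^2) with c = (pi / (2 alpha))^{3/2}. That prefactor is a common convention but an assumption here, and shape_overlap_and_grad_sketch is a hypothetical name. With P'_a = R @ p_a + t, each kernel term has gradient -alpha * K_ab * (P'_a - Q_b) with respect to P'_a; summing over a gives grad_t, and the outer products with p_a give grad_R.

```python
import math
import torch

def shape_overlap_and_grad_sketch(R, t, ref_points, fit_points_orig, alpha, pair_weights=None):
    fit = fit_points_orig @ R.transpose(-1, -2) + t.unsqueeze(-2)   # P'_a = R p_a + t, (..., M, 3)
    diff = fit.unsqueeze(-2) - ref_points.unsqueeze(-3)             # (..., M, N, 3)
    sq = (diff * diff).sum(-1)                                      # (..., M, N)
    K = (math.pi / (2.0 * alpha)) ** 1.5 * torch.exp(-0.5 * alpha * sq)
    if pair_weights is not None:
        K = K * pair_weights                                        # e.g. SE(3)-invariant ESP weights
    O_AB = K.sum(dim=(-2, -1))
    g_fit = -alpha * (K.unsqueeze(-1) * diff).sum(-2)               # dO/dP'_a, (..., M, 3)
    grad_t = g_fit.sum(-2)                                          # sum over fit points
    grad_R = g_fit.transpose(-1, -2) @ fit_points_orig              # sum_a g_a p_a^T
    return O_AB, grad_R, grad_t

# Compare with autograd on a near-identity transform.
torch.manual_seed(0)
R = (torch.eye(3, dtype=torch.float64) + 0.1 * torch.randn(3, 3, dtype=torch.float64)).requires_grad_()
t = (0.1 * torch.randn(3, dtype=torch.float64)).requires_grad_()
ref = torch.randn(5, 3, dtype=torch.float64)
fit = torch.randn(4, 3, dtype=torch.float64)
with torch.no_grad():
    O, gR, gt = shape_overlap_and_grad_sketch(R, t, ref, fit, alpha=0.81)
shape_overlap_and_grad_sketch(R, t, ref, fit, alpha=0.81)[0].backward()
assert torch.allclose(gR, R.grad) and torch.allclose(gt, t.grad)
```

Treating R as an unconstrained 3x3 matrix here is deliberate: the constraint to rotations is handled afterwards by projecting grad_R onto the quaternion parameters.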

shepherd_score.score.analytical_gradients.compute_self_overlaps_shape(ref_points, fit_points, alpha)

Compute the self-overlaps VAA and VBB for shape scoring. Both are invariant to SE(3) transformations, so they can be computed once before optimization.

Parameters:
  • ref_points (torch.Tensor (N,3))

  • fit_points (torch.Tensor (M,3))

  • alpha (float)

Returns:

VAA, VBB

Return type:

torch.Tensor scalars
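Since a self-overlap depends only on pairwise distances within one point set, it is unchanged by any rotation plus translation. A minimal sketch, assuming the same c * exp(-(alpha/2) d^2) kernel with c = (pi / (2 alpha))^{3/2} (an assumed convention); shape_self_overlap_sketch is a hypothetical name, and VAA and VBB would be this quantity for ref_points and fit_points respectively.

```python
import math
import torch

def shape_self_overlap_sketch(points: torch.Tensor, alpha: float) -> torch.Tensor:
    # V = sum_{a,b} c * exp(-(alpha/2) * ||p_a - p_b||^2), c = (pi / (2 alpha))^{3/2} (assumed)
    diff = points[:, None, :] - points[None, :, :]
    sq = (diff * diff).sum(-1)
    return ((math.pi / (2.0 * alpha)) ** 1.5 * torch.exp(-0.5 * alpha * sq)).sum()

# Invariance check under a random proper rotation plus translation.
torch.manual_seed(0)
pts = torch.randn(6, 3, dtype=torch.float64)
Q, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))
if torch.linalg.det(Q) < 0:
    Q[:, 0] = -Q[:, 0]   # flip one column so det(Q) = +1 (a proper rotation)
moved = pts @ Q.T + torch.randn(3, dtype=torch.float64)
V_before = shape_self_overlap_sketch(pts, 0.81)
V_after = shape_self_overlap_sketch(moved, 0.81)
assert torch.allclose(V_before, V_after)
```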

shepherd_score.score.analytical_gradients.compute_self_overlaps_esp(ref_points, fit_points, ref_charges, fit_charges, alpha, lam)

Compute the ESP self-overlaps VAA and VBB. Both are invariant to SE(3) transformations.

Parameters:
  • ref_points (torch.Tensor (N,3))

  • fit_points (torch.Tensor (M,3))

  • ref_charges (torch.Tensor (N,))

  • fit_charges (torch.Tensor (M,))

  • alpha (float)

  • lam (float)

Returns:

VAA, VBB

Return type:

torch.Tensor scalars

shepherd_score.score.analytical_gradients.compute_analytical_grad_se3_esp(se3_params, ref_points, fit_points, ref_charges, fit_charges, alpha, lam, VAA_total, VBB_total)

Compute loss (1 - Tanimoto ESP similarity) and its gradient w.r.t. SE(3) parameters using analytical gradients.

The ESP pair weights exp(-||v_a - v_b||^2 / lam) are SE(3)-invariant (charges are fixed to their points), so they are precomputed once and treated as constants.

Parameters:
  • se3_params (torch.Tensor (7,) or (B,7)) – [q_w, q_x, q_y, q_z, t_x, t_y, t_z]

  • ref_points (torch.Tensor (N,3) or (B,N,3))

  • fit_points (torch.Tensor (M,3) or (B,M,3))

  • ref_charges (torch.Tensor (N,))

  • fit_charges (torch.Tensor (M,))

  • alpha (float)

  • lam (float)

  • VAA_total (torch.Tensor scalar)

  • VBB_total (torch.Tensor scalar)

Returns:

  • loss (torch.Tensor scalar)

  • grad_se3 (torch.Tensor (7,) or (B,7))

Return type:

Tuple[torch.Tensor, torch.Tensor]
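The ESP pair weights depend only on the charges, so they can be built once before optimization and reused at every step, for instance as the pair_weights argument that compute_overlap_and_grad_shape documents for ESP scoring. A minimal sketch follows; the function names are hypothetical, the values of alpha and lam are illustrative only, and the kernel c * exp(-(alpha/2) d^2) with c = (pi / (2 alpha))^{3/2} is an assumed convention.

```python
import math
import torch

def esp_pair_weights_sketch(charges_a: torch.Tensor, charges_b: torch.Tensor, lam: float) -> torch.Tensor:
    # w_ab = exp(-(v_a - v_b)^2 / lam); uses charges only, so it is SE(3)-invariant.
    return torch.exp(-(charges_a[:, None] - charges_b[None, :]) ** 2 / lam)

def esp_self_overlap_sketch(points: torch.Tensor, charges: torch.Tensor, alpha: float, lam: float) -> torch.Tensor:
    # Charge-weighted Gaussian self-overlap (assumed kernel, see lead-in).
    diff = points[:, None, :] - points[None, :, :]
    kernel = (math.pi / (2.0 * alpha)) ** 1.5 * torch.exp(-0.5 * alpha * (diff * diff).sum(-1))
    return (esp_pair_weights_sketch(charges, charges, lam) * kernel).sum()

torch.manual_seed(0)
ref_pts, fit_pts = torch.randn(5, 3, dtype=torch.float64), torch.randn(4, 3, dtype=torch.float64)
ref_q, fit_q = torch.randn(5, dtype=torch.float64), torch.randn(4, dtype=torch.float64)
W = esp_pair_weights_sketch(fit_q, ref_q, lam=0.3)        # (M, N) = (4, 5), built once
VAA = esp_self_overlap_sketch(ref_pts, ref_q, alpha=0.81, lam=0.3)
VBB = esp_self_overlap_sketch(fit_pts, fit_q, alpha=0.81, lam=0.3)
```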

shepherd_score.score.analytical_gradients.compute_analytical_grad_se3_shape(se3_params, ref_points, fit_points, alpha, VAA_total, VBB_total)

Compute loss (1 - Tanimoto shape similarity) and its gradient w.r.t. SE(3) parameters using analytical gradients.

Parameters:
  • se3_params (torch.Tensor (7,) or (B,7)) – [q_w, q_x, q_y, q_z, t_x, t_y, t_z]

  • ref_points (torch.Tensor (N,3) or (B,N,3))

  • fit_points (torch.Tensor (M,3) or (B,M,3))

  • alpha (float)

  • VAA_total (torch.Tensor scalar) – Precomputed self-overlap for ref_points.

  • VBB_total (torch.Tensor scalar) – Precomputed self-overlap for fit_points.

Returns:

  • loss (torch.Tensor scalar)

  • grad_se3 (torch.Tensor (7,) or (B,7))

Return type:

Tuple[torch.Tensor, torch.Tensor]
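The chain from overlap gradients to SE(3) parameter gradients can be written out as follows, assuming the standard Gaussian Tanimoto similarity and treating VAA_total and VBB_total as constants (they are SE(3)-invariant). The rotation part then goes through the quaternion projection defined by project_grad_R_to_quaternion, and the translation part passes through unchanged; how the function concatenates them into grad_se3 is assumed to follow the [q_w, q_x, q_y, q_z, t_x, t_y, t_z] layout.

```latex
\begin{aligned}
T &= \frac{O_{AB}}{V_{AA} + V_{BB} - O_{AB}}, \qquad \mathcal{L} = 1 - T, \\
\frac{\partial \mathcal{L}}{\partial O_{AB}} &= -\frac{V_{AA} + V_{BB}}{\left(V_{AA} + V_{BB} - O_{AB}\right)^{2}}, \\
\nabla_{R}\mathcal{L} &= \frac{\partial \mathcal{L}}{\partial O_{AB}}\,\nabla_{R} O_{AB}, \qquad
\nabla_{t}\mathcal{L} = \frac{\partial \mathcal{L}}{\partial O_{AB}}\,\nabla_{t} O_{AB}, \\
\frac{\partial \mathcal{L}}{\partial q_k} &= \operatorname{Tr}\!\left( (\nabla_{R}\mathcal{L})^{\top} \frac{\partial R}{\partial q_k} \right), \qquad k \in \{w, x, y, z\}.
\end{aligned}
```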

shepherd_score.score.analytical_gradients.compute_avoid_and_grad(R, t, fit_pts_avoid_orig, avoid_points, min_dist)

Compute the linear hard-sphere avoid term and its gradients w.r.t. R and t.

A = sum_{a in fit_avoid, b in avoid} relu((min_dist - dist(P'_a, P_b)) / min_dist)

where P'_a = R @ P_a + t.

Parameters:
  • R (torch.Tensor (3,3) or (B,3,3))

  • t (torch.Tensor (3,) or (B,3))

  • fit_pts_avoid_orig (torch.Tensor (M,3) or (B,M,3)) – Fit points to penalize (in original frame, before transformation).

  • avoid_points (torch.Tensor (K,3)) – Fixed reference points to avoid (never batched).

  • min_dist (float) – Distance threshold below which overlap is penalized.

Returns:

  • A (torch.Tensor scalar or (B,))

  • grad_R (torch.Tensor (3,3) or (B,3,3))

  • grad_t (torch.Tensor (3,) or (B,3))

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
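For a pair closer than min_dist, the term (min_dist - dist) / min_dist has gradient -(P'_a - P_b) / (min_dist * dist) with respect to P'_a; pairs farther apart contribute nothing. A minimal sketch of that derivation, checked against autograd; avoid_and_grad_sketch is a hypothetical name and the eps clamp on the distance is an added assumption.

```python
import torch

def avoid_and_grad_sketch(R, t, fit_pts_avoid_orig, avoid_points, min_dist, eps=1e-12):
    fit = fit_pts_avoid_orig @ R.transpose(-1, -2) + t.unsqueeze(-2)  # P'_a = R @ P_a + t
    diff = fit.unsqueeze(-2) - avoid_points                            # (..., M, K, 3)
    dist = diff.norm(dim=-1).clamp_min(eps)                            # (..., M, K)
    viol = (min_dist - dist) / min_dist
    A = torch.relu(viol).sum(dim=(-2, -1))
    # Only pairs with dist < min_dist contribute; relu' = 1 there.
    active = (viol > 0).to(fit.dtype)
    g_fit = (-(active / (min_dist * dist)).unsqueeze(-1) * diff).sum(-2)  # dA/dP'_a, (..., M, 3)
    grad_t = g_fit.sum(-2)
    grad_R = g_fit.transpose(-1, -2) @ fit_pts_avoid_orig                 # sum_a g_a P_a^T
    return A, grad_R, grad_t

# Compare with autograd; some avoid points are placed near fit points so A > 0.
torch.manual_seed(0)
fit = torch.randn(6, 3, dtype=torch.float64)
avoid = fit[:3] + 0.05
R = (torch.eye(3, dtype=torch.float64) + 0.05 * torch.randn(3, 3, dtype=torch.float64)).requires_grad_()
t = (0.05 * torch.randn(3, dtype=torch.float64)).requires_grad_()
with torch.no_grad():
    A, gR, gt = avoid_and_grad_sketch(R, t, fit, avoid, min_dist=2.0)
avoid_and_grad_sketch(R, t, fit, avoid, min_dist=2.0)[0].backward()
assert torch.allclose(gR, R.grad) and torch.allclose(gt, t.grad)
```

Because the penalty is linear in the violation rather than quadratic, its gradient magnitude per active pair is a constant 1 / min_dist, which keeps the push-out force bounded even for deeply overlapping points.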

shepherd_score.score.analytical_gradients.compute_analytical_grad_se3_shape_with_avoid(se3_params, ref_points, fit_points, alpha, VAA_total, VBB_total, fit_points_for_avoid, avoid_points, avoid_min_dist, avoid_weight)

Compute loss and gradient for shape alignment with an avoid-points penalty.

loss = (1 - Tanimoto_shape) + avoid_weight * hard_sphere_overlap(fit_avoid, avoid)

Parameters:
  • se3_params (torch.Tensor (7,) or (B,7))

  • ref_points (torch.Tensor (N,3) or (B,N,3))

  • fit_points (torch.Tensor (M,3) or (B,M,3))

  • alpha (float)

  • VAA_total (torch.Tensor scalar)

  • VBB_total (torch.Tensor scalar)

  • fit_points_for_avoid (torch.Tensor (M2,3) or (B,M2,3))

  • avoid_points (torch.Tensor (K,3))

  • avoid_min_dist (float)

  • avoid_weight (float)

Returns:

  • loss (torch.Tensor scalar)

  • grad_se3 (torch.Tensor (7,) or (B,7))

Return type:

Tuple[torch.Tensor, torch.Tensor]
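Because the loss is a sum, the gradients of the two terms simply add before the quaternion projection. Written out under the same assumption as above (standard Gaussian Tanimoto with constant self-overlaps), with w_avoid = avoid_weight and A the avoid term returned by compute_avoid_and_grad:

```latex
\begin{aligned}
\mathcal{L} &= (1 - T) + w_{\mathrm{avoid}}\, A, \\
\nabla_{R}\mathcal{L} &= -\frac{V_{AA} + V_{BB}}{\left(V_{AA} + V_{BB} - O_{AB}\right)^{2}}\,\nabla_{R} O_{AB} + w_{\mathrm{avoid}}\,\nabla_{R} A, \\
\nabla_{t}\mathcal{L} &= -\frac{V_{AA} + V_{BB}}{\left(V_{AA} + V_{BB} - O_{AB}\right)^{2}}\,\nabla_{t} O_{AB} + w_{\mathrm{avoid}}\,\nabla_{t} A.
\end{aligned}
```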

shepherd_score.score.analytical_gradients.compute_analytical_grad_se3(se3_params, ref_pharms, fit_pharms, ref_anchors, fit_anchors, ref_vectors, fit_vectors, VAA_total, VBB_total, similarity='tanimoto', sigma=0.5, extended_points=False, only_extended=False)
Parameters:
  • se3_params (torch.Tensor)

  • ref_pharms (torch.Tensor)

  • fit_pharms (torch.Tensor)

  • ref_anchors (torch.Tensor)

  • fit_anchors (torch.Tensor)

  • ref_vectors (torch.Tensor)

  • fit_vectors (torch.Tensor)

  • VAA_total (torch.Tensor)

  • VBB_total (torch.Tensor)

  • similarity (str)

  • sigma (float)

  • extended_points (bool)

  • only_extended (bool)

Return type:

Tuple[torch.Tensor, torch.Tensor]