Source code for shepherd_score.score.electrostatic_scoring

"""
Gaussian volume overlap scoring functions combined with continuous electrostatics
PYTORCH VERSION.

Batched and non-batched functionalities

Reference math:
https://doi.org/10.1002/(SICI)1096-987X(19961115)17:14<1653::AID-JCC7>3.0.CO;2-K
https://doi.org/10.1021/j100011a016
"""
import numpy as np
import torch
from shepherd_score.score.constants import COULOMB_SCALING, LAM_SCALING
from shepherd_score.score.gaussian_overlap import get_overlap


def VAB_2nd_order_esp(centers_1: torch.Tensor,
                      centers_2: torch.Tensor,
                      charges_1: torch.Tensor,
                      charges_2: torch.Tensor,
                      alpha: float,
                      lam: float
                      ) -> torch.Tensor:
    """
    Second-order Gaussian volume overlap weighted by electrostatic agreement.
    Torch implementation. Handles single instances and batches.

    For every pair (i, j) of a point from set 1 and a point from set 2 the term

        (pi / (2*alpha))**1.5 * exp(-(alpha/2) * |r_i - r_j|**2) * exp(-|q_i - q_j|**2 / lam)

    is evaluated, and the terms are summed over all pairs.

    Parameters
    ----------
    centers_1 : torch.Tensor (..., N, D)
        Coordinates of the first point set. Any leading dimensions are treated
        as batch dimensions.
    centers_2 : torch.Tensor (..., M, D)
        Coordinates of the second point set.
    charges_1 : torch.Tensor (..., N, C)
        Per-point electrostatic feature of the first set. It must carry a
        trailing feature axis because it is passed to ``torch.cdist``.
    charges_2 : torch.Tensor (..., M, C)
        Per-point electrostatic feature of the second set.
    alpha : float
        Gaussian width parameter.
    lam : float
        Divisor of the squared charge difference inside the exponential.

    Returns
    -------
    torch.Tensor
        0-dim tensor when the inputs have no batch dimension, otherwise one
        value per batch element (shape equal to the batch dimensions).
    """
    # Pairwise squared distances in coordinate space and in charge space.
    R2_sq = torch.cdist(centers_1, centers_2) ** 2.0
    C2_sq = torch.cdist(charges_1, charges_2) ** 2.0

    prefactor = np.pi ** 1.5 / (2 * alpha) ** 1.5
    pair_terms = prefactor * torch.exp(-(alpha / 2) * R2_sq) * torch.exp(-C2_sq / lam)

    # The last two axes index the (set 1, set 2) pairs. Reducing over exactly
    # those axes covers the unbatched case and any number of batch axes, so no
    # per-rank branching (or permutation before a full pair sum) is needed.
    return pair_terms.sum(dim=(-2, -1))
def shape_tanimoto_esp(centers_1: torch.Tensor,
                       centers_2: torch.Tensor,
                       charges_1: torch.Tensor,
                       charges_2: torch.Tensor,
                       alpha: float,
                       lam: float
                       ) -> torch.Tensor:
    """
    Compute the Tanimoto similarity of two electrostatics-weighted point sets.

    The score is ``V_AB / (V_AA + V_BB - V_AB)`` where each ``V`` is obtained
    from ``VAB_2nd_order_esp`` with the given ``alpha`` and ``lam``.

    Parameters
    ----------
    centers_1, centers_2 : torch.Tensor
        Coordinates of the two point sets.
    charges_1, charges_2 : torch.Tensor
        Per-point electrostatic features matching the respective centers.
    alpha : float
        Gaussian width parameter forwarded to ``VAB_2nd_order_esp``.
    lam : float
        Electrostatic weighting parameter forwarded to ``VAB_2nd_order_esp``.

    Returns
    -------
    torch.Tensor
        Tanimoto similarity.
    """
    # Self-overlaps of each set, then the cross-overlap between them.
    self_overlap_1 = VAB_2nd_order_esp(centers_1, centers_1, charges_1, charges_1, alpha, lam)
    self_overlap_2 = VAB_2nd_order_esp(centers_2, centers_2, charges_2, charges_2, alpha, lam)
    cross_overlap = VAB_2nd_order_esp(centers_1, centers_2, charges_1, charges_2, alpha, lam)

    union = self_overlap_1 + self_overlap_2 - cross_overlap
    return cross_overlap / union
def get_overlap_esp(centers_1: torch.Tensor,
                    centers_2: torch.Tensor,
                    charges_1: torch.Tensor,
                    charges_2: torch.Tensor,
                    alpha: float = 0.81,
                    lam: float = 0.3*LAM_SCALING
                    ) -> torch.Tensor:
    """
    Torch implementation.
    Electrostatic similarity: Gaussian volume overlap weighted by electrostatics,
    reported as a Tanimoto score.

    Typically `lam=0.3*LAM_SCALING` is used for surface point clouds and `lam=0.1`
    for partial charge weighted volumetric overlap.

    NumPy array inputs are converted with ``torch.Tensor(...)`` before scoring.

    Parameters
    ----------
    centers_1 : torch.Tensor (batch, N, 3) or (N, 3)
        Coordinates for the sets of points representing molecule 1.
    centers_2 : torch.Tensor (batch, M, 3) or (M, 3)
        Coordinates for the sets of points representing molecule 2.
    charges_1 : torch.Tensor (batch, N) or (N,)
        Electrostatic energy for the sets of points representing molecule 1.
    charges_2 : torch.Tensor (batch, M) or (M,)
        Electrostatic energy for the sets of points representing molecule 2.
    alpha : float
        Parameter controlling the width of the Gaussians.
    lam : float
        Parameter controlling the influence of electrostatics.

    Returns
    -------
    tanimoto_esp : torch.Tensor
        Tanimoto similarity of electrostatics as returned by
        ``shape_tanimoto_esp``: one value per batch element for batched input,
        a single value otherwise.
    """
    def _to_tensor(value):
        # Only NumPy arrays are converted; anything else passes through untouched.
        return torch.Tensor(value) if isinstance(value, np.ndarray) else value

    def _with_feature_axis(charges):
        # (N,) -> (N, 1) and (batch, N) -> (batch, N, 1) so charges can be fed
        # to a pairwise distance routine; other ranks are left as they are.
        rank = len(charges.shape)
        if rank == 1 or rank == 2:
            return charges.unsqueeze(-1)
        return charges

    centers_1 = _to_tensor(centers_1)
    centers_2 = _to_tensor(centers_2)
    charges_1 = _with_feature_axis(_to_tensor(charges_1))
    charges_2 = _with_feature_axis(_to_tensor(charges_2))

    return shape_tanimoto_esp(centers_1, centers_2, charges_1, charges_2, alpha, lam)
def _esp_comparison(points_1: torch.Tensor,
                    centers_w_H_2: torch.Tensor,
                    partial_charges_2: torch.Tensor,
                    points_charges_1: torch.Tensor,
                    radii_2: torch.Tensor,
                    probe_radius: float = 1.0,
                    lam: float = 0.001
                    ) -> torch.Tensor:
    """
    Helper for the electrostatic potential (ESP) component of the ShaEP score.

    At each surface/observer point of molecule 1, the ESP generated by molecule 2
    is computed from its partial charges and compared with the precalculated ESP
    of molecule 1. Observer points that fall inside molecule 2's volume (closer
    to any atom than its radius + ``probe_radius``) are masked out.

    Accepts a single instance or a batch. ONLY the shape of ``points_1`` is
    inspected to decide which, so shape errors in the other tensors may or may
    not be caught.

    Parameters
    ----------
    points_1 : torch.Tensor (N_surf, 3) or (batch, N_surf, 3)
        Surface points of molecule 1 at which ESPs are computed and compared.
    centers_w_H_2 : torch.Tensor (M + m_H, 3) or (batch, M + m_H, 3)
        Atom coordinates (including hydrogens) of molecule 2. Used for the ESP
        at ``points_1`` and for masking points within molecule 2's volume.
    partial_charges_2 : torch.Tensor (M + m_H,) or (batch, M + m_H,)
        Partial charges corresponding to ``centers_w_H_2``.
    points_charges_1 : torch.Tensor (N_surf,) or (batch, N_surf,)
        Precalculated ESPs of molecule 1 corresponding to ``points_1``.
    radii_2 : torch.Tensor (M + m_H,) or (batch, M + m_H,)
        Radius of each atom in ``centers_w_H_2``. Used for masking.
    probe_radius : float (default = 1.0)
        Probe radius in angstroms. Surfaces are assumed to be generated with vdW
        radius and a probe radius of 1.2 angstroms (vdW radius of hydrogen);
        1.0 is used rather than 1.2 as a tolerance.
    lam : float (default = 0.001)
        Electrostatic potential weighting parameter (smaller = higher weight).
        It is multiplied by ``LAM_SCALING`` before use. 0.001 was chosen as the
        default from empirical observation of the score distribution before the
        summation in this function.

    Returns
    -------
    torch.Tensor
        Point-to-point ESP comparison: a 0-dim tensor for a fully unbatched call,
        otherwise shape (batch,). Scores range: [0, N_surf]. The score decreases
        for differences in ESP or due to masking of poorly aligned surface points.

    Raises
    ------
    ValueError
        If ``points_1`` is neither 2-D nor 3-D, if the distance tensor has an
        unexpected rank for unbatched ``points_1``, or if a batched
        ``centers_w_H_2`` disagrees with the batch size of ``points_1``.
    """
    scaled_lam = LAM_SCALING * lam
    dist = torch.cdist(points_1, centers_w_H_2)
    points_rank = len(points_1.shape)

    if points_rank != 2 and points_rank != 3:
        raise ValueError(f"points_1 has unexpected shape {points_1.shape}")

    # ----- points_1 is a single instance: (N_surf, 3) -----
    if points_rank == 2:
        dist_rank = dist.dim()

        if dist_rank == 2:
            # centers_w_H_2 is also single: dist is (N_surf, M+m_H)
            outside = torch.all(dist >= radii_2 + probe_radius, dim=1)
            keep = torch.where(outside, 1., 0.)  # (N_surf,)
            esp_from_2 = torch.matmul(partial_charges_2, 1 / dist.transpose(0, 1)) * COULOMB_SCALING  # (N_surf,)
            sq_diff = torch.square(points_charges_1 - esp_from_2)  # (N_surf,)
            return torch.sum(keep * torch.exp(-sq_diff / scaled_lam))  # scalar

        if dist_rank == 3:
            # centers_w_H_2 is batched: dist is (B, N_surf, M+m_H)
            outside = torch.all(dist >= radii_2.unsqueeze(1) + probe_radius, dim=2)
            keep = torch.where(outside, 1., 0.)  # (B, N_surf)
            esp_from_2 = torch.matmul(partial_charges_2.unsqueeze(1),
                                      1 / dist.transpose(1, 2)) * COULOMB_SCALING  # (B, 1, N_surf)
            sq_diff = torch.square(points_charges_1.unsqueeze(0) - esp_from_2.squeeze(1))  # (B, N_surf)
            return torch.sum(keep * torch.exp(-sq_diff / scaled_lam), dim=1)  # (B,)

        raise ValueError(f"Distances tensor has unexpected dimensions {dist.dim()} when points_1 is single.")

    # ----- points_1 is batched: (B, N_surf, 3) -----
    # dist is (B, N_surf, M+m_H) whether centers_w_H_2 is (B, M, 3) or (M, 3).
    if centers_w_H_2.dim() == 3 and centers_w_H_2.shape[0] != points_1.shape[0]:
        raise ValueError(f"centers_w_H_2 has unexpected shape {centers_w_H_2.shape} when points_1 is batched. "
                         f"points_1: {points_1.shape}.")

    # Give unbatched per-atom tensors a leading axis so they broadcast over the batch.
    radii_b = radii_2.unsqueeze(0) if radii_2.dim() == 1 else radii_2
    charges_b = partial_charges_2.unsqueeze(0) if partial_charges_2.dim() == 1 else partial_charges_2

    # Mask out molecule 1 surface points that are within molecule 2.
    outside = torch.all(dist >= radii_b.unsqueeze(1) + probe_radius, dim=2)
    keep = torch.where(outside, 1., 0.)  # (B, N)

    # Potential generated by molecule 2 at molecule 1's surface points.
    esp_from_2 = torch.matmul(charges_b.unsqueeze(1), 1 / dist.transpose(1, 2)) * COULOMB_SCALING  # (B, 1, N)
    sq_diff = torch.square(points_charges_1.unsqueeze(1) - esp_from_2)  # (B, 1, N)
    return torch.sum(keep * torch.exp(-sq_diff / scaled_lam).squeeze(1), dim=1)  # (B,)
def esp_combo_score(centers_w_H_1: torch.Tensor,
                    centers_w_H_2: torch.Tensor,
                    centers_1: torch.Tensor,
                    centers_2: torch.Tensor,
                    points_1: torch.Tensor,
                    points_2: torch.Tensor,
                    partial_charges_1: torch.Tensor,
                    partial_charges_2: torch.Tensor,
                    point_charges_1: torch.Tensor,
                    point_charges_2: torch.Tensor,
                    radii_1: torch.Tensor,
                    radii_2: torch.Tensor,
                    alpha: float,
                    lam: float=0.001,
                    probe_radius: float=1.0,
                    esp_weight: float=0.5
                    ) -> torch.Tensor:
    """
    Compute the ShaEP-style similarity score: a weighted balance between
    electrostatic similarity and shape similarity.

    Single instance or batch accepted (in the 0th dimension). The ESP helper
    ONLY checks the shape of the surface points to determine whether the input
    is batched, so errors in the shape of the other tensors may or may not be
    caught.

    Parameters
    ----------
    centers_w_H_1 : torch.Tensor (N + n_H, 3) | (batch, N + n_H, 3)
        Coordinates of atom centers INCLUDING hydrogens of molecule 1. Used for
        computing electrostatic potential.
        Same for centers_w_H_2 except (M + m_H, 3).
    centers_1 : torch.Tensor (N, 3) or (n_surf, 3) | (batch, N, 3) or (batch, n_surf, 3)
        Coordinates of points for molecule 1 used to compute shape similarity.
        Use atom centers for volumetric similarity. Use surface centers for
        surface similarity.
        Same for centers_2 except (M, 3) or (m_surf, 3).
    points_1 : torch.Tensor (n_surf, 3) | (batch, n_surf, 3)
        Coordinates of surface points for molecule 1.
        Same for points_2 except (m_surf, 3).
    partial_charges_1 : torch.Tensor (N + n_H,) | (batch, N + n_H,)
        Partial charges corresponding to the atoms in centers_w_H_1.
        Same for partial_charges_2 except (M + m_H,).
    point_charges_1 : torch.Tensor (n_surf,) | (batch, n_surf,)
        The electrostatic potential calculated at each surface point (points_1).
        Same for point_charges_2 except (m_surf,).
    radii_1 : torch.Tensor (N + n_H,) | (batch, N + n_H,)
        vdW radii corresponding to the atoms in centers_w_H_1 (angstroms).
        Same for radii_2 except (M + m_H,).
    alpha : float
        Gaussian width parameter used for shape similarity.
    lam : float (default = 0.001)
        Electrostatic potential weighting parameter (smaller = higher weight).
        0.001 was chosen as default based on empirical observations of the
        distribution of scores generated by _esp_comparison before summation.
    probe_radius : float (default = 1.0)
        Surface points found within vdW radii + probe radius will be masked out.
        Surface generation uses a probe radius of 1.2 (radius of hydrogen), so a
        slightly lower radius is used to be more tolerant.
    esp_weight : float (default = 0.5)
        Weight placed on electrostatic similarity with respect to shape similarity.
        0 = only shape similarity
        1 = only electrostatic similarity

    Returns
    -------
    torch.Tensor
        Similarity score (range: [0, 1]). Higher is more similar.
    """
    # ESP agreement evaluated on each molecule's surface against the field of
    # the other molecule. Atom centers are expected to include hydrogens.
    esp_on_surface_1 = _esp_comparison(points_1, centers_w_H_2, partial_charges_2,
                                       point_charges_1, radii_2, probe_radius, lam)
    esp_on_surface_2 = _esp_comparison(points_2, centers_w_H_1, partial_charges_1,
                                       point_charges_2, radii_1, probe_radius, lam)

    # The point axis is second-to-last for both (n_surf, 3) and (batch, n_surf, 3).
    n_surf_1 = points_1.shape[-2]
    n_surf_2 = points_2.shape[-2]
    total_points = n_surf_1 + n_surf_2

    electrostatic_sim = (esp_on_surface_1 + esp_on_surface_2) / total_points
    volumetric_sim = get_overlap(centers_1, centers_2, alpha)

    shape_weight = 1 - esp_weight
    return esp_weight*electrostatic_sim + shape_weight*volumetric_sim