Gaussian Overlap (Shape Similarity)
PyTorch Implementation
Gaussian volume overlap scoring functions for shape only (i.e., not color).
Supports batched and non-batched inputs.
Reference math: https://doi.org/10.1002/(SICI)1096-987X(19961115)17:14<1653::AID-JCC7>3.0.CO;2-K and https://doi.org/10.1021/j100011a016
- shepherd_score.score.gaussian_overlap.VAB_2nd_order(centers_1, centers_2, alpha)
2nd order volume overlap of AB. PyTorch implementation supporting single instances, matched batches, and broadcasting. Pairwise distances come from torch.cdist, which returns Euclidean (not squared) distances and broadcasts batch dimensions; the result is then squared.
The squared-distance matrix R2_cdist has shape (Batch, N_c1, N_c2) or (N_c1, N_c2), depending on the input shapes. For example:
  - c1=(N,3), c2=(M,3) -> cdist_out=(N,M)
  - c1=(B,N,3), c2=(B,M,3) -> cdist_out=(B,N,M)
  - c1=(N,3), c2=(B,M,3) -> cdist_out=(B,N,M) (c1 broadcast)
  - c1=(1,N,3), c2=(B,M,3) -> cdist_out=(B,N,M) (c1 broadcast)
- Parameters:
centers_1 (torch.Tensor)
centers_2 (torch.Tensor)
alpha (float)
- Return type:
torch.Tensor
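The broadcasting rules above can be reproduced in a minimal NumPy sketch. It assumes each center carries a spherical Gaussian p * exp(-alpha * |r - R|^2) with a shared exponent alpha and an amplitude p = 0.8 (the `prefactor` default of `get_overlap_batch`), so each pair contributes p^2 * (pi / (2 * alpha))^1.5 * exp(-(alpha / 2) * R^2), which follows from the Gaussian product integral in the cited references. The names `sq_dist`, `vab_2nd_order_sketch` and `PREFACTOR` are hypothetical and not part of shepherd_score.

```python
import numpy as np

PREFACTOR = 0.8  # ASSUMED Gaussian amplitude (mirrors get_overlap_batch's default)


def sq_dist(c1: np.ndarray, c2: np.ndarray) -> np.ndarray:
    """Pairwise squared distances; leading batch dims broadcast (like torch.cdist)."""
    diff = c1[..., :, None, :] - c2[..., None, :, :]  # (..., N, M, 3)
    return np.sum(diff**2, axis=-1)                  # (..., N, M)


def vab_2nd_order_sketch(c1: np.ndarray, c2: np.ndarray, alpha: float) -> np.ndarray:
    """Sum of pairwise overlaps of identical Gaussians with exponent alpha."""
    r2 = sq_dist(c1, c2)
    terms = PREFACTOR**2 * (np.pi / (2 * alpha)) ** 1.5 * np.exp(-(alpha / 2) * r2)
    return terms.sum(axis=(-2, -1))  # 0-d for unbatched inputs, (B,) for batched ones


rng = np.random.default_rng(0)
a = rng.normal(size=(5, 3))         # single instance, N = 5
b = rng.normal(size=(7, 3))         # single instance, M = 7
batch = rng.normal(size=(4, 7, 3))  # B = 4 instances, M = 7

print(sq_dist(a, b).shape)                         # (5, 7)
print(sq_dist(a, batch).shape)                     # (4, 5, 7): a is broadcast
print(sq_dist(a[None], batch).shape)               # (4, 5, 7): (1, N, 3) is broadcast
print(vab_2nd_order_sketch(a, batch, 0.81).shape)  # (4,)
```

Summing over the last two axes is what turns an (N, M) matrix into a scalar and a (B, N, M) stack into a (B,) vector.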
- shepherd_score.score.gaussian_overlap.shape_tanimoto(centers_1, centers_2, alpha)
Compute the Tanimoto shape similarity, V_AB / (V_AA + V_BB - V_AB), where V denotes the 2nd order overlap volume.
- Parameters:
centers_1 (torch.Tensor)
centers_2 (torch.Tensor)
alpha (float)
- Return type:
torch.Tensor
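Given the three overlap volumes, the Tanimoto score is a single ratio. The single-instance NumPy sketch below reuses the same assumed Gaussian model (shared alpha, amplitude 0.8); `vab` and `tanimoto` are illustrative names only. Under this model V_AA + V_BB - 2 V_AB is the integral of the squared density difference and therefore nonnegative, which keeps the score in (0, 1].

```python
import numpy as np


def vab(c1: np.ndarray, c2: np.ndarray, alpha: float, p: float = 0.8) -> float:
    """2nd order overlap volume of two single point clouds, (N, 3) and (M, 3)."""
    r2 = np.sum((c1[:, None, :] - c2[None, :, :]) ** 2, axis=-1)  # (N, M) squared distances
    return float(np.sum(p**2 * (np.pi / (2 * alpha)) ** 1.5 * np.exp(-(alpha / 2) * r2)))


def tanimoto(c1: np.ndarray, c2: np.ndarray, alpha: float = 0.81) -> float:
    """Shape Tanimoto: V_AB / (V_AA + V_BB - V_AB)."""
    v_ab = vab(c1, c2, alpha)
    return v_ab / (vab(c1, c1, alpha) + vab(c2, c2, alpha) - v_ab)


mol = np.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [0.0, 1.5, 0.0]])
shifted = mol + np.array([3.0, 0.0, 0.0])  # same shape, translated along x
print(tanimoto(mol, mol))      # 1.0 for identical shapes
print(tanimoto(mol, shifted))  # below 1: the translated copy overlaps only partly
```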
- shepherd_score.score.gaussian_overlap.get_overlap(centers_1, centers_2, alpha=0.81)
Volumetric shape similarity with a tunable Gaussian width parameter, alpha. Handles single instances, matched batches, and broadcasting (e.g., centers_1=(N,3) or (1,N,3) against centers_2=(B,M,3)). PyTorch implementation.
- Parameters:
centers_1 (Union[torch.Tensor, np.ndarray] (batch, N, 3) or (N, 3)) – Coordinates of each point of the first point cloud. Can be (N,3) for a single instance, (B,N,3) for a batch, or (1,N,3) for a single instance to be broadcast against a batch in centers_2.
centers_2 (Union[torch.Tensor, np.ndarray] (batch, M, 3) or (M, 3)) – Coordinates of each point of the second point cloud. Can be (M,3) for a single instance, (B,M,3) for a batch, or (1,M,3) for a single instance to be broadcast against a batch in centers_1.
alpha (float (default=0.81)) – Gaussian width parameter. Lower value corresponds to wider Gaussian (longer tail).
- Returns:
The Tanimoto similarity score: a scalar if both inputs are single instances, or a 1D tensor of shape (batch,) if at least one input is batched.
- Return type:
torch.Tensor
- shepherd_score.score.gaussian_overlap.get_max_overlap(centers_1, centers_2, alpha=0.81)
Maximum overlap volume among any pair of centers (always in [0, 1] range).
- shepherd_score.score.gaussian_overlap.get_linear_hard_sphere_overlap(centers_1, centers_2, min_dist)
Compute linear hard sphere overlap.
See get_linear_hard_sphere_overlap_np for details.
- shepherd_score.score.gaussian_overlap.VAB_2nd_order_mask(centers_1, centers_2, alpha, mask_1, mask_2)
2nd order volume overlap of AB with masking. Torch implementation supporting single instances, matched batches, and broadcasting. Masks are applied to the interaction terms.
- Parameters:
centers_1 (torch.Tensor (N,3) or (B,N,3) or (1,N,3)) – Coordinates for the first set of points.
centers_2 (torch.Tensor (M,3) or (B,M,3) or (1,M,3)) – Coordinates for the second set of points.
alpha (float) – Gaussian width parameter.
mask_1 (torch.Tensor (N,) or (B,N) or (1,N)) – Mask for centers_1. Boolean or float (0/1).
mask_2 (torch.Tensor (M,) or (B,M) or (1,M)) – Mask for centers_2. Boolean or float (0/1).
- Returns:
Scalar or (B,) tensor of overlap scores.
- Return type:
torch.Tensor
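A typical use of masks is padding variable-size point clouds to a common length so they can be batched. The NumPy sketch below assumes masks act by multiplying each pairwise Gaussian term by mask_1[i] * mask_2[j] (consistent with "applied to the interaction terms", but the exact form is an assumption), reuses the amplitude 0.8 model from above, and uses the hypothetical name `vab_masked`. It checks that padded rows contribute nothing.

```python
import numpy as np


def vab_masked(c1, c2, alpha, mask_1, mask_2, p=0.8):
    """Masked 2nd order overlap: the (i, j) term counts only if mask_1[i] and mask_2[j] are set.

    ASSUMPTION: each pairwise Gaussian term is multiplied by mask_1[i] * mask_2[j].
    """
    r2 = np.sum((c1[..., :, None, :] - c2[..., None, :, :]) ** 2, axis=-1)  # (..., N, M)
    terms = p**2 * (np.pi / (2 * alpha)) ** 1.5 * np.exp(-(alpha / 2) * r2)
    m1 = np.asarray(mask_1, dtype=float)[..., :, None]  # (..., N, 1)
    m2 = np.asarray(mask_2, dtype=float)[..., None, :]  # (..., 1, M)
    return np.sum(terms * m1 * m2, axis=(-2, -1))       # 0-d or (B,)


rng = np.random.default_rng(0)
a, b = rng.normal(size=(3, 3)), rng.normal(size=(4, 3))
a_pad = np.vstack([a, np.zeros((2, 3))])  # pad a to 5 rows with dummy centers
mask_a = np.array([True, True, True, False, False])
mask_b = np.ones(4, dtype=bool)

padded = vab_masked(a_pad, b, 0.81, mask_a, mask_b)
unpadded = vab_masked(a, b, 0.81, np.ones(3, dtype=bool), mask_b)
print(np.isclose(padded, unpadded))  # True: padding rows contribute nothing
```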
- shepherd_score.score.gaussian_overlap.VAB_2nd_order_mask_batch(cdist_21, alpha, mask_1, mask_2)
2nd order volume overlap of AB (batched) with masking, using precomputed cdist. Assumes inputs cdist_21, mask_1, mask_2 are already batched and broadcast-compatible.
- Parameters:
cdist_21 (torch.Tensor (B,M,N)) – Precomputed squared Euclidean distances, torch.cdist(centers_2, centers_1)**2.0. Note the order: cdist(c2, c1) has shape (B, M, N), i.e., R_21^2. Equivalently, torch.cdist(centers_1, centers_2).permute(0,2,1)**2.0 gives the same (B, M, N) tensor.
alpha (float) – Gaussian width parameter.
mask_1 (torch.Tensor (B,N) or (1,N)) – Mask for the first set of points (corresponding to N in cdist_21).
mask_2 (torch.Tensor (B,M) or (1,M)) – Mask for the second set of points (corresponding to M in cdist_21).
- Returns:
Batched Tanimoto similarity scores of shape (B,).
- Return type:
torch.Tensor
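To make the (B, M, N) convention concrete, the NumPy analogue below builds the squared distances in both ways described for cdist_21 and shows how a (B, N) mask_1 and a (1, M) mask_2 line up with its axes. It only illustrates array layout; combining the two masks into one pair mask by multiplication is an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
B, N, M = 2, 5, 7
centers_1 = rng.normal(size=(B, N, 3))
centers_2 = rng.normal(size=(B, M, 3))

# (B, M, N) layout expected for cdist_21: rows index centers_2, columns centers_1.
cdist_21 = np.sum((centers_2[:, :, None, :] - centers_1[:, None, :, :]) ** 2, axis=-1)

# The (B, N, M) layout with its last two axes swapped gives the same tensor
# (NumPy counterpart of .permute(0, 2, 1) in PyTorch).
cdist_12 = np.sum((centers_1[:, :, None, :] - centers_2[:, None, :, :]) ** 2, axis=-1)

# mask_1 runs along the N (last) axis, mask_2 along the M axis;
# a (1, M) mask broadcasts over the batch. ASSUMED: masks are combined by product.
mask_1 = np.ones((B, N))
mask_2 = np.ones((1, M))
pair_mask = mask_2[:, :, None] * mask_1[:, None, :]

print(cdist_21.shape)                                      # (2, 7, 5)
print(np.allclose(cdist_21, cdist_12.transpose(0, 2, 1)))  # True
print(pair_mask.shape)                                     # (2, 7, 5)
```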
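- shepherd_score.score.gaussian_overlap.VAB_2nd_order_cosine(centers_1, centers_2, vectors_1, vectors_2, alpha, allow_antiparallel)
2nd order volume overlap of AB weighted by cosine similarity. Torch implementation supporting single instances, matched batches, and broadcasting.
- Parameters:
centers_1 (torch.Tensor (N,3) or (B,N,3) or (1,N,3))
centers_2 (torch.Tensor (M,3) or (B,M,3) or (1,M,3))
vectors_1 (torch.Tensor (N,3) or (B,N,3) or (1,N,3))
vectors_2 (torch.Tensor (M,3) or (B,M,3) or (1,M,3))
alpha (float)
allow_antiparallel (bool)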
- Returns:
Scalar or (B,) tensor of overlap scores.
- Return type:
torch.Tensor
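A single-instance NumPy sketch of cosine weighting follows. It ASSUMES the weight on each pairwise Gaussian term is the plain cosine between the two points' vectors, or its absolute value when allow_antiparallel is True; shepherd_score may transform the cosine differently, and `vab_cosine_sketch` is a hypothetical name. The Gaussian model (shared alpha, amplitude 0.8) is the same assumption as above.

```python
import numpy as np


def vab_cosine_sketch(c1, c2, v1, v2, alpha, allow_antiparallel, p=0.8):
    """Pairwise Gaussian overlaps of (N, 3)/(M, 3) centers, weighted by vector cosines.

    ASSUMPTION: the weight is the plain cosine, or |cosine| when
    allow_antiparallel is True.
    """
    r2 = np.sum((c1[:, None, :] - c2[None, :, :]) ** 2, axis=-1)  # (N, M)
    u1 = v1 / np.linalg.norm(v1, axis=-1, keepdims=True)          # unit vectors
    u2 = v2 / np.linalg.norm(v2, axis=-1, keepdims=True)
    cos = u1 @ u2.T                                               # (N, M) cosines
    weight = np.abs(cos) if allow_antiparallel else cos
    gauss = p**2 * (np.pi / (2 * alpha)) ** 1.5 * np.exp(-(alpha / 2) * r2)
    return float(np.sum(weight * gauss))


c = np.zeros((1, 3))
up = np.array([[0.0, 0.0, 1.0]])
down = np.array([[0.0, 0.0, -1.0]])
side = np.array([[1.0, 0.0, 0.0]])
print(vab_cosine_sketch(c, c, up, down, 0.81, allow_antiparallel=False) < 0)  # True
print(np.isclose(vab_cosine_sketch(c, c, up, down, 0.81, allow_antiparallel=True),
                 vab_cosine_sketch(c, c, up, up, 0.81, allow_antiparallel=False)))  # True
```

With allow_antiparallel=True, opposite vectors score the same as aligned ones; perpendicular vectors contribute nothing in either mode.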
- shepherd_score.score.gaussian_overlap.VAB_2nd_order_cosine_mask(centers_1, centers_2, vectors_1, vectors_2, alpha, allow_antiparallel, mask_1, mask_2)
2nd order volume overlap of AB weighted by cosine similarity, with masking. Torch implementation supporting single instances, matched batches, and broadcasting.
- Parameters:
centers_1 (torch.Tensor (N,3) or (B,N,3) or (1,N,3))
centers_2 (torch.Tensor (M,3) or (B,M,3) or (1,M,3))
vectors_1 (torch.Tensor (N,3) or (B,N,3) or (1,N,3))
vectors_2 (torch.Tensor (M,3) or (B,M,3) or (1,M,3))
alpha (float)
allow_antiparallel (bool)
mask_1 (torch.Tensor (N,) or (B,N) or (1,N))
mask_2 (torch.Tensor (M,) or (B,M) or (1,M))
- Returns:
Scalar or (B,) tensor of overlap scores.
- Return type:
torch.Tensor
- shepherd_score.score.gaussian_overlap.VAB_2nd_order_cosine_mask_batch(cdist_21, vmm_21, alpha, allow_antiparallel, mask_1, mask_2)
2nd order volume overlap of AB (batched) weighted by cosine similarity, with masking, using precomputed cdist and vector dot products (vmm). Assumes inputs cdist_21, vmm_21, mask_1, mask_2 are already batched and broadcast-compatible.
- Parameters:
cdist_21 (torch.Tensor (B,M,N)) – Precomputed squared Euclidean distances, e.g., (torch.cdist(centers_2, centers_1)**2.0).
vmm_21 (torch.Tensor (B,M,N)) – Precomputed dot products between vectors_2 and vectors_1, e.g., torch.matmul(vectors_2, vectors_1.permute(0,2,1)). These equal cosine similarities only if the vectors are normalized.
alpha (float) – Gaussian width parameter.
allow_antiparallel (bool) – If True, the absolute cosine similarity is used, so antiparallel vectors count as aligned.
mask_1 (torch.Tensor (B,N) or (1,N)) – Mask for the first set of points/vectors (N dimension).
mask_2 (torch.Tensor (B,M) or (1,M)) – Mask for the second set of points/vectors (M dimension).
- Returns:
Batched Tanimoto similarity scores of shape (B,).
- Return type:
torch.Tensor
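The precomputed inputs can be assembled as the parameter descriptions suggest. The NumPy sketch below normalizes the vectors, forms vmm_21 with a batched matmul (the counterpart of torch.matmul(vectors_2, vectors_1.permute(0,2,1))), and scores with the same ASSUMED weighting (plain or absolute cosine) and ASSUMED mask product as the sketches above. `vab_cosine_mask_batch_sketch` is a hypothetical name.

```python
import numpy as np


def vab_cosine_mask_batch_sketch(cdist_21, vmm_21, alpha, allow_antiparallel, mask_1, mask_2, p=0.8):
    """Masked, cosine-weighted overlap from precomputed (B, M, N) inputs; returns (B,).

    ASSUMPTIONS: weight = vmm_21 (or |vmm_21| if allow_antiparallel) and
    pair mask = mask_2[:, :, None] * mask_1[:, None, :].
    """
    weight = np.abs(vmm_21) if allow_antiparallel else vmm_21
    gauss = p**2 * (np.pi / (2 * alpha)) ** 1.5 * np.exp(-(alpha / 2) * cdist_21)
    pair_mask = mask_2[:, :, None] * mask_1[:, None, :]  # (B, M, N)
    return np.sum(weight * gauss * pair_mask, axis=(1, 2))


rng = np.random.default_rng(0)
B, N, M = 3, 4, 6
c1, c2 = rng.normal(size=(B, N, 3)), rng.normal(size=(B, M, 3))
v1, v2 = rng.normal(size=(B, N, 3)), rng.normal(size=(B, M, 3))
v1 /= np.linalg.norm(v1, axis=-1, keepdims=True)  # normalize so dot products are cosines
v2 /= np.linalg.norm(v2, axis=-1, keepdims=True)

cdist_21 = np.sum((c2[:, :, None, :] - c1[:, None, :, :]) ** 2, axis=-1)  # (B, M, N)
vmm_21 = np.matmul(v2, v1.transpose(0, 2, 1))                            # (B, M, N)
scores = vab_cosine_mask_batch_sketch(cdist_21, vmm_21, 0.81, True, np.ones((B, N)), np.ones((1, M)))
print(scores.shape)  # (3,)
```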
- shepherd_score.score.gaussian_overlap.VAB_2nd_order_batched(centers_1, centers_2, alphas_1, alphas_2, prefactors_1, prefactors_2)
Calculate the 2nd order volume overlap of AB (batched).
- Parameters:
centers_1 ((torch.Tensor) (batch_size, num_atoms_1, 3)) – Coordinates of atoms in molecule 1
centers_2 ((torch.Tensor) (batch_size, num_atoms_2, 3)) – Coordinates of atoms in molecule 2
alphas_1 ((torch.Tensor) (batch_size, num_atoms_1)) – Alpha values for atoms in molecule 1
alphas_2 ((torch.Tensor) (batch_size, num_atoms_2)) – Alpha values for atoms in molecule 2
prefactors_1 ((torch.Tensor) (batch_size, num_atoms_1)) – Prefactor values for atoms in molecule 1
prefactors_2 ((torch.Tensor) (batch_size, num_atoms_2)) – Prefactor values for atoms in molecule 2
- Returns:
The 2nd order volume overlap of AB for each batch element.
- Return type:
torch.Tensor (batch_size,)
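With per-atom exponents and amplitudes, each pair uses the general overlap of two spherical Gaussians from the cited references: p_i p_j (pi / (a_i + a_j))^1.5 exp(-(a_i a_j / (a_i + a_j)) r_ij^2). The NumPy sketch below applies it batch-wise; whether the library uses exactly this expression is an assumption, and `vab_batched_sketch` is a hypothetical name. With all exponents equal to alpha and all amplitudes equal to 0.8, it reduces to the single-alpha form.

```python
import numpy as np


def vab_batched_sketch(c1, c2, alphas_1, alphas_2, pref_1, pref_2):
    """Batched overlap of Gaussians with per-atom exponents and amplitudes.

    Shapes: c1 (B, N, 3), c2 (B, M, 3), alphas_1/pref_1 (B, N), alphas_2/pref_2 (B, M).
    """
    r2 = np.sum((c1[:, :, None, :] - c2[:, None, :, :]) ** 2, axis=-1)  # (B, N, M)
    a_i, a_j = alphas_1[:, :, None], alphas_2[:, None, :]             # (B, N, 1), (B, 1, M)
    a_sum = a_i + a_j
    terms = (pref_1[:, :, None] * pref_2[:, None, :]
             * (np.pi / a_sum) ** 1.5
             * np.exp(-(a_i * a_j / a_sum) * r2))
    return terms.sum(axis=(1, 2))  # (B,)


rng = np.random.default_rng(0)
B, N, M, alpha = 2, 4, 5, 0.81
c1, c2 = rng.normal(size=(B, N, 3)), rng.normal(size=(B, M, 3))
out = vab_batched_sketch(c1, c2, np.full((B, N), alpha), np.full((B, M), alpha),
                         np.full((B, N), 0.8), np.full((B, M), 0.8))

# Single-alpha reference: p^2 (pi / (2 alpha))^1.5 exp(-(alpha / 2) r^2).
r2 = np.sum((c1[:, :, None, :] - c2[:, None, :, :]) ** 2, axis=-1)
ref = np.sum(0.8**2 * (np.pi / (2 * alpha)) ** 1.5 * np.exp(-(alpha / 2) * r2), axis=(1, 2))
print(out.shape, np.allclose(out, ref))  # (2,) True
```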
- shepherd_score.score.gaussian_overlap.shape_tanimoto_batched(centers_1, centers_2, alphas_1, alphas_2, prefactors_1, prefactors_2)
Calculate the Tanimoto shape similarity between two batches of molecules.
- Parameters:
centers_1 (torch.Tensor)
centers_2 (torch.Tensor)
alphas_1 (torch.Tensor)
alphas_2 (torch.Tensor)
prefactors_1 (torch.Tensor)
prefactors_2 (torch.Tensor)
- Return type:
torch.Tensor
- shepherd_score.score.gaussian_overlap.get_overlap_batch(centers_1, centers_2, prefactor=0.8, alpha=0.81)
Compute the Gaussian overlap for a batch of centers.
- shepherd_score.score.gaussian_overlap.VAB_2nd_order_full(centers_1, centers_2, alphas_1, alphas_2, prefactors_1, prefactors_2)
2nd order volume overlap of AB with per-atom alphas and prefactors.
- Return type:
torch.Tensor
- shepherd_score.score.gaussian_overlap.shape_tanimoto_full(centers_1, centers_2, alphas_1, alphas_2, prefactors_1, prefactors_2)
Compute the Tanimoto shape similarity with per-atom alphas and prefactors.
- Return type:
torch.Tensor
NumPy Implementation
Gaussian volume overlap scoring functions for shape only (i.e., not color), NumPy version.
Single-instance functionality only (no batching).
Reference math: https://doi.org/10.1002/(SICI)1096-987X(19961115)17:14<1653::AID-JCC7>3.0.CO;2-K https://doi.org/10.1021/j100011a016
- shepherd_score.score.gaussian_overlap_np.VAB_2nd_order_np(centers_1, centers_2, alpha)
2nd order volume overlap of AB
- shepherd_score.score.gaussian_overlap_np.shape_tanimoto_np(centers_1, centers_2, alpha)
Compute the Tanimoto shape similarity.
- shepherd_score.score.gaussian_overlap_np.get_overlap_np(centers_1, centers_2, alpha=0.81)
NumPy implementation of shape similarity via Gaussian overlaps (single instance).
- shepherd_score.score.gaussian_overlap_np.get_max_overlap_np(centers_1, centers_2, alpha=0.81)
Maximum overlap volume among any pair of centers (always in [0, 1] range).
- shepherd_score.score.gaussian_overlap_np.get_linear_hard_sphere_overlap_np(centers_1, centers_2, min_dist)
Compute linear hard sphere overlap.
The overlap of a pair of centers is linear in the distance d between them:
  - d > min_dist: 0
  - 0 < d < min_dist: decreases linearly from 1 (at d = 0) to 0 (at d = min_dist)
  - d == 0: 1
- Returns:
np.ndarray of shape (1,) with the sum of the hard sphere overlaps between centers_1 and centers_2.
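The piecewise rule above amounts to max(0, 1 - d / min_dist) per pair, summed over all pairs. A minimal NumPy sketch follows; `linear_hard_sphere_overlap_sketch` is a hypothetical name, and unlike the documented function it returns a plain scalar rather than an array of shape (1,).

```python
import numpy as np


def linear_hard_sphere_overlap_sketch(c1: np.ndarray, c2: np.ndarray, min_dist: float) -> float:
    """Sum over all (i, j) pairs of max(0, 1 - d_ij / min_dist).

    1 at d = 0, decreasing linearly to 0 at d = min_dist, and 0 beyond.
    """
    d = np.linalg.norm(c1[:, None, :] - c2[None, :, :], axis=-1)  # (N, M) distances
    return float(np.clip(1.0 - d / min_dist, 0.0, None).sum())


a = np.array([[0.0, 0.0, 0.0]])
b = np.array([[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [2.0, 0.0, 0.0]])
print(linear_hard_sphere_overlap_sketch(a, b, min_dist=1.0))  # 1.5 = 1 + 0.5 + 0
```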
JAX Implementation
Gaussian volume overlap scoring functions for shape only (i.e., not color), JAX version (~6x faster than NumPy).
Supports batched and non-batched inputs.
Reference math: https://doi.org/10.1002/(SICI)1096-987X(19961115)17:14<1653::AID-JCC7>3.0.CO;2-K https://doi.org/10.1021/j100011a016
- shepherd_score.score.gaussian_overlap_jax.jax_cdist(X_1, X_2)
JAX implementation of pairwise Euclidean distances.
- Parameters:
X_1 (Array (N, P))
X_2 (Array (M, P))
- Returns:
Distance matrix between X_1 and X_2.
- Return type:
Array (N, M)
- shepherd_score.score.gaussian_overlap_jax.jax_sq_cdist(X_1, X_2)
JAX implementation of pairwise squared Euclidean distances.
- Parameters:
X_1 (Array (N, P))
X_2 (Array (M, P))
- Returns:
Distance matrix between X_1 and X_2, squared.
- Return type:
Array (N, M)
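One common way to compute pairwise squared distances without materializing an (N, M, P) difference tensor is the expansion |x|^2 + |y|^2 - 2 x.y. Whether jax_sq_cdist uses it is not stated here; the NumPy sketch below only illustrates the technique, and every operation it uses also exists in jax.numpy. `sq_cdist_expansion` and `cdist_expansion` are hypothetical names.

```python
import numpy as np


def sq_cdist_expansion(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
    """(N, P) x (M, P) -> (N, M) squared Euclidean distances via |x|^2 + |y|^2 - 2 x.y."""
    sq1 = np.sum(x1**2, axis=1, keepdims=True)    # (N, 1)
    sq2 = np.sum(x2**2, axis=1, keepdims=True).T  # (1, M)
    d2 = sq1 + sq2 - 2.0 * x1 @ x2.T              # (N, M)
    return np.maximum(d2, 0.0)                    # clamp tiny negatives from round-off


def cdist_expansion(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
    """Unsquared distances; sqrt is applied after clamping, so it never sees negatives."""
    return np.sqrt(sq_cdist_expansion(x1, x2))


rng = np.random.default_rng(0)
x1, x2 = rng.normal(size=(4, 3)), rng.normal(size=(6, 3))
brute = np.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
print(np.allclose(sq_cdist_expansion(x1, x2), brute))  # True
```

If gradients are taken through the unsquared distance, sqrt at exactly zero distance (a point paired with itself) can produce non-finite gradients, one reason to prefer the squared form when only squared distances are needed.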
- shepherd_score.score.gaussian_overlap_jax.VAB_2nd_order_jax(centers_1, centers_2, alpha)
2nd order volume overlap of AB
- Parameters:
centers_1 (jax.Array)
centers_2 (jax.Array)
alpha (float)
- Return type:
jax.Array
- shepherd_score.score.gaussian_overlap_jax.shape_tanimoto_jax(centers_1, centers_2, alpha)
Compute the Tanimoto shape similarity.
- Parameters:
centers_1 (jax.Array)
centers_2 (jax.Array)
alpha (float)
- Return type:
jax.Array
- shepherd_score.score.gaussian_overlap_jax.get_overlap_jax(centers_1, centers_2, alpha=0.81)
Compute the ROCS Gaussian volume overlap using a jitted JAX function.
- Parameters:
centers_1 (jax.Array)
centers_2 (jax.Array)
alpha (float)
- Return type:
jax.Array
- shepherd_score.score.gaussian_overlap_jax.get_max_overlap_jax(centers_1, centers_2, alpha)
Maximum overlap volume among any pair of centers (always in [0, 1] range).
- Parameters:
centers_1 (jax.Array)
centers_2 (jax.Array)
alpha (float)
- Return type:
jax.Array
- shepherd_score.score.gaussian_overlap_jax.get_linear_hard_sphere_overlap_jax(centers_1, centers_2, min_dist)
Compute linear hard sphere overlap.
See get_linear_hard_sphere_overlap_np for details.
- Parameters:
centers_1 (jax.Array)
centers_2 (jax.Array)
min_dist (float)
- Return type:
jax.Array
- shepherd_score.score.gaussian_overlap_jax.VAB_2nd_order_jax_mask(centers_1, centers_2, mask_1, mask_2, alpha)
2nd order volume overlap of AB with masking.
- Parameters:
centers_1 (jax.Array)
centers_2 (jax.Array)
mask_1 (jax.Array)
mask_2 (jax.Array)
alpha (float)
- Return type:
jax.Array
- shepherd_score.score.gaussian_overlap_jax.shape_tanimoto_jax_mask(centers_1, centers_2, mask_1, mask_2, alpha)
Compute the Tanimoto shape similarity with masking.
- Parameters:
centers_1 (jax.Array)
centers_2 (jax.Array)
mask_1 (jax.Array)
mask_2 (jax.Array)
alpha (float)
- Return type:
jax.Array
- shepherd_score.score.gaussian_overlap_jax.get_overlap_jax_mask(centers_1, centers_2, mask_1, mask_2, alpha=0.81)
Compute the ROCS Gaussian volume overlap with masking, using a jitted JAX function.
- Parameters:
centers_1 (jax.Array)
centers_2 (jax.Array)
mask_1 (jax.Array)
mask_2 (jax.Array)
alpha (float)
- Return type:
jax.Array
- shepherd_score.score.gaussian_overlap_jax.VAB_2nd_order_cosine_jax(centers_1, centers_2, vectors_1, vectors_2, alpha, allow_antiparallel)
2nd order volume overlap of AB weighted by cosine similarity (JAX version) - implementation part.
- shepherd_score.score.gaussian_overlap_jax.VAB_2nd_order_cosine_jax_mask(centers_1, centers_2, vectors_1, vectors_2, mask_1, mask_2, alpha, allow_antiparallel)
2nd order volume overlap of AB weighted by cosine similarity (JAX version) - implementation part. Vectors are assumed to be normalized.