Gaussian Overlap (Shape Similarity)

PyTorch Implementation

Gaussian volume overlap scoring functions – Shape-only (i.e., not color)

Batched and non-batched functionalities

Reference math: https://doi.org/10.1002/(SICI)1096-987X(19961115)17:14<1653::AID-JCC7>3.0.CO;2-K https://doi.org/10.1021/j100011a016

shepherd_score.score.gaussian_overlap.VAB_2nd_order(centers_1, centers_2, alpha)

2nd order volume overlap of AB. Torch implementation supporting single instances, matched batches, and broadcasting scenarios. The function relies on torch.cdist for calculating squared distances, which handles necessary broadcasting for batch dimensions efficiently.

torch.cdist handles the broadcasting of batch dimensions, so the squared-distance tensor R2_cdist has shape (N_c1, N_c2) or (Batch, N_c1, N_c2) depending on the input shapes. For example:
  • c1=(N,3), c2=(M,3) -> cdist_out=(N,M)

  • c1=(B,N,3), c2=(B,M,3) -> cdist_out=(B,N,M)

  • c1=(N,3), c2=(B,M,3) -> cdist_out=(B,N,M) (c1 broadcasted)

  • c1=(1,N,3), c2=(B,M,3) -> cdist_out=(B,N,M) (c1 broadcasted)
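The shape rules above can be checked with a small NumPy sketch. It does not call torch.cdist; the helper pairwise_sq_dists is a hypothetical stand-in that builds the squared distances by explicit broadcasting, assuming the same (…, N, 3) and (…, M, 3) layouts.

```python
import numpy as np

def pairwise_sq_dists(c1, c2):
    """Squared Euclidean distances with broadcast leading batch dimensions.

    c1: (..., N, 3), c2: (..., M, 3) -> (..., N, M)
    """
    # Insert singleton axes so every point in c1 meets every point in c2.
    diff = c1[..., :, None, :] - c2[..., None, :, :]
    return np.sum(diff * diff, axis=-1)

rng = np.random.default_rng(0)
N, M, B = 4, 5, 3
single_1 = rng.normal(size=(N, 3))
single_2 = rng.normal(size=(M, 3))
batch_1 = rng.normal(size=(B, N, 3))
batch_2 = rng.normal(size=(B, M, 3))

shape_single = pairwise_sq_dists(single_1, single_2).shape        # single vs single
shape_matched = pairwise_sq_dists(batch_1, batch_2).shape         # matched batches
shape_bcast = pairwise_sq_dists(single_1, batch_2).shape          # (N,3) vs (B,M,3)
shape_bcast_1 = pairwise_sq_dists(single_1[None], batch_2).shape  # (1,N,3) vs (B,M,3)
```

In the two broadcasting cases the single instance is reused for every batch element without being copied explicitly.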

Parameters:
  • centers_1 (torch.Tensor)

  • centers_2 (torch.Tensor)

  • alpha (float)

Return type:

torch.Tensor
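For orientation, the sketch below evaluates the equal-width second-order Gaussian overlap, V_AB = sum_ij (pi / (2 alpha))^(3/2) exp(-alpha R_ij^2 / 2), for single instances in NumPy. It assumes unit prefactors and one shared alpha; the name vab_2nd_order_sketch is hypothetical, and the library's implementation may differ in constants or batching.

```python
import numpy as np

def vab_2nd_order_sketch(centers_1, centers_2, alpha):
    """Equal-alpha, unit-prefactor pairwise Gaussian overlap (single instance)."""
    diff = centers_1[:, None, :] - centers_2[None, :, :]
    r2 = np.sum(diff * diff, axis=-1)               # (N, M) squared distances
    prefactor = (np.pi / (2.0 * alpha)) ** 1.5      # integral of two coincident Gaussians
    return float(np.sum(prefactor * np.exp(-0.5 * alpha * r2)))

a = np.array([[0.0, 0.0, 0.0]])
b = np.array([[1.0, 0.0, 0.0]])
v_self = vab_2nd_order_sketch(a, a, alpha=0.81)    # coincident centers
v_shift = vab_2nd_order_sketch(a, b, alpha=0.81)   # centers 1 unit apart
```

Moving the second center away scales the overlap by exp(-alpha R^2 / 2), so v_shift is smaller than v_self.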

shepherd_score.score.gaussian_overlap.shape_tanimoto(centers_1, centers_2, alpha)

Compute Tanimoto shape similarity

Parameters:
  • centers_1 (torch.Tensor)

  • centers_2 (torch.Tensor)

  • alpha (float)

Return type:

torch.Tensor
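The Tanimoto form normalizes the cross overlap by the two self overlaps, V_AB / (V_AA + V_BB - V_AB). A minimal NumPy sketch, assuming the equal-alpha, unit-prefactor overlap and using hypothetical helper names:

```python
import numpy as np

def _overlap(c1, c2, alpha):
    """Equal-alpha pairwise Gaussian overlap (assumed form, unit prefactors)."""
    diff = c1[:, None, :] - c2[None, :, :]
    r2 = np.sum(diff * diff, axis=-1)
    return np.sum((np.pi / (2.0 * alpha)) ** 1.5 * np.exp(-0.5 * alpha * r2))

def shape_tanimoto_sketch(c1, c2, alpha):
    """Tanimoto similarity built from cross and self overlaps."""
    v_ab = _overlap(c1, c2, alpha)
    v_aa = _overlap(c1, c1, alpha)
    v_bb = _overlap(c2, c2, alpha)
    return float(v_ab / (v_aa + v_bb - v_ab))

rng = np.random.default_rng(0)
mol_a = rng.normal(size=(6, 3))
mol_b = rng.normal(size=(8, 3))

same = shape_tanimoto_sketch(mol_a, mol_a, alpha=0.81)           # identical clouds
far = shape_tanimoto_sketch(mol_a, mol_a + 100.0, alpha=0.81)    # translated far away
ab = shape_tanimoto_sketch(mol_a, mol_b, alpha=0.81)
ba = shape_tanimoto_sketch(mol_b, mol_a, alpha=0.81)
```

Identical point clouds score 1, well-separated clouds score near 0, and the score does not depend on argument order.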

shepherd_score.score.gaussian_overlap.get_overlap(centers_1, centers_2, alpha=0.81)

Volumetric shape similarity with tunable “alpha” Gaussian width parameter. Handles single instances, matched batches, and broadcasting scenarios (e.g., centers_1=(N,3) or (1,N,3) and centers_2=(B,M,3)). PyTorch implementation.

Parameters:
  • centers_1 (Union[torch.Tensor, np.ndarray] (batch, N, 3) or (N, 3)) – Coordinates of each point of the first point cloud. Can be (N,3) for a single instance, (B,N,3) for a batch, or (1,N,3) for a single instance to be broadcast against a batch in centers_2.

  • centers_2 (Union[torch.Tensor, np.ndarray] (batch, M, 3) or (M, 3)) – Coordinates of each point of the second point cloud. Can be (M,3) for a single instance, (B,M,3) for a batch, or (1,M,3) for a single instance to be broadcast against a batch in centers_1.

  • alpha (float (default=0.81)) – Gaussian width parameter. Lower value corresponds to wider Gaussian (longer tail).

Returns:

torch.Tensor – The Tanimoto similarity score. Returns a scalar if both inputs are single instances, or a 1D tensor of shape (batch,) if at least one input is batched.

Return type:

(batch,) or scalar

shepherd_score.score.gaussian_overlap.get_max_overlap(centers_1, centers_2, alpha=0.81)

Maximum overlap volume among any pair of centers (always in [0, 1] range).

Return type:

torch.Tensor
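One reading consistent with the stated [0, 1] range is the largest unnormalized pair term exp(-alpha R_ij^2 / 2) over all pairs of centers. The sketch below implements that reading only; both the formula and the name max_pair_overlap_sketch are assumptions, not taken from the library source.

```python
import numpy as np

def max_pair_overlap_sketch(centers_1, centers_2, alpha=0.81):
    """Largest unnormalized Gaussian pair term; 1.0 when any two centers coincide."""
    diff = centers_1[:, None, :] - centers_2[None, :, :]
    r2 = np.sum(diff * diff, axis=-1)
    return float(np.exp(-0.5 * alpha * r2).max())

pts_1 = np.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
pts_2 = np.array([[0.0, 0.0, 0.0], [5.0, 5.0, 5.0]])
pts_3 = np.array([[1.0, 1.0, 1.0]])

touching = max_pair_overlap_sketch(pts_1, pts_2)   # one pair coincides
apart = max_pair_overlap_sketch(pts_1, pts_3)      # no pair coincides
```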

shepherd_score.score.gaussian_overlap.get_linear_hard_sphere_overlap(centers_1, centers_2, min_dist)

Compute linear hard sphere overlap.

See get_linear_hard_sphere_overlap_np for details.

Return type:

torch.Tensor

shepherd_score.score.gaussian_overlap.VAB_2nd_order_mask(centers_1, centers_2, alpha, mask_1, mask_2)

2nd order volume overlap of AB with masking. Torch implementation supporting single instances, matched batches, and broadcasting. Masks are applied to the interaction terms.

Parameters:
  • centers_1 (torch.Tensor (N,3) or (B,N,3) or (1,N,3)) – Coordinates for the first set of points.

  • centers_2 (torch.Tensor (M,3) or (B,M,3) or (1,M,3)) – Coordinates for the second set of points.

  • alpha (float) – Gaussian width parameter.

  • mask_1 (torch.Tensor (N,) or (B,N) or (1,N)) – Mask for centers_1. Boolean or float (0/1).

  • mask_2 (torch.Tensor (M,) or (B,M) or (1,M)) – Mask for centers_2. Boolean or float (0/1).

Returns:

Scalar or (B,) tensor of overlap scores.

Return type:

torch.Tensor
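Masking is typically used to ignore padding when point clouds of different sizes share one array shape. The NumPy sketch below applies 0/1 masks to the pairwise interaction terms, as described above, and checks that padded points contribute nothing. The helper name and the equal-alpha overlap form are assumptions.

```python
import numpy as np

def vab_mask_sketch(c1, c2, alpha, mask_1, mask_2):
    """Equal-alpha overlap where pair (i, j) counts only if both masks are 1."""
    diff = c1[:, None, :] - c2[None, :, :]
    r2 = np.sum(diff * diff, axis=-1)                    # (N, M)
    terms = (np.pi / (2.0 * alpha)) ** 1.5 * np.exp(-0.5 * alpha * r2)
    pair_mask = mask_1[:, None] * mask_2[None, :]        # (N, M) interaction mask
    return float(np.sum(terms * pair_mask))

rng = np.random.default_rng(0)
real_1 = rng.normal(size=(3, 3))
real_2 = rng.normal(size=(4, 3))

# Pad the first cloud from 3 to 5 points with zeros and mask the padding out.
padded_1 = np.vstack([real_1, np.zeros((2, 3))])
mask_1 = np.array([1.0, 1.0, 1.0, 0.0, 0.0])
mask_2 = np.ones(4)

masked = vab_mask_sketch(padded_1, real_2, 0.81, mask_1, mask_2)
unpadded = vab_mask_sketch(real_1, real_2, 0.81, np.ones(3), mask_2)
```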

shepherd_score.score.gaussian_overlap.VAB_2nd_order_mask_batch(cdist_21, alpha, mask_1, mask_2)

2nd order volume overlap of AB (batched) with masking, using precomputed cdist. Assumes inputs cdist_21, mask_1, mask_2 are already batched and broadcast-compatible.

Parameters:
  • cdist_21 (torch.Tensor (B,M,N)) – Precomputed squared Euclidean distances: (torch.cdist(centers_2, centers_1)**2.0). Note the order: cdist(c2, c1) gives (B, M, N) which is R_21^2. If cdist(c1,c2).permute(0,2,1) was used, it’s also (B,M,N).

  • alpha (float) – Gaussian width parameter.

  • mask_1 (torch.Tensor (B,N) or (1,N)) – Mask for the first set of points (corresponding to N in cdist_21).

  • mask_2 (torch.Tensor (B,M) or (1,M)) – Mask for the second set of points (corresponding to M in cdist_21).

Returns:

torch.Tensor – Batched Tanimoto similarity scores.

Return type:

(B,)
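The axis order of cdist_21 is easy to get wrong, so the sketch below shows in NumPy that the squared distances from centers_2 to centers_1 have shape (B, M, N) and equal the transposed (B, N, M) result, and how (1, N) and (B, M) masks line up with those axes. The helper names and the equal-alpha overlap form are assumptions for illustration.

```python
import numpy as np

def sq_cdist(x, y):
    """Squared distances: x (B, P, 3), y (B, Q, 3) -> (B, P, Q)."""
    diff = x[:, :, None, :] - y[:, None, :, :]
    return np.sum(diff * diff, axis=-1)

def masked_overlap_from_cdist(cdist_21, alpha, mask_1, mask_2):
    """cdist_21: (B, M, N); mask_1: (B or 1, N); mask_2: (B or 1, M) -> (B,)."""
    pair_mask = mask_2[:, :, None] * mask_1[:, None, :]   # broadcasts to (B, M, N)
    terms = (np.pi / (2.0 * alpha)) ** 1.5 * np.exp(-0.5 * alpha * cdist_21)
    return np.sum(terms * pair_mask, axis=(1, 2))

rng = np.random.default_rng(1)
B, N, M = 2, 4, 3
c1 = rng.normal(size=(B, N, 3))
c2 = rng.normal(size=(B, M, 3))

cdist_21 = sq_cdist(c2, c1)                        # (B, M, N)
cdist_12_t = np.swapaxes(sq_cdist(c1, c2), 1, 2)   # (B, N, M) -> (B, M, N)

mask_1 = np.ones((1, N))        # shared across the batch
mask_2 = np.ones((B, M))
mask_2[1, -1] = 0.0             # drop the last point of the second batch element
scores = masked_overlap_from_cdist(cdist_21, 0.81, mask_1, mask_2)
```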

shepherd_score.score.gaussian_overlap.VAB_2nd_order_cosine(centers_1, centers_2, vectors_1, vectors_2, alpha, allow_antiparallel)

2nd order volume overlap of AB weighted by cosine similarity. Torch implementation supporting single instances, matched batches, and broadcasting.

Parameters:
  • centers_1 (torch.Tensor (N,3) or (B,N,3) or (1,N,3))

  • centers_2 (torch.Tensor (M,3) or (B,M,3) or (1,M,3))

  • vectors_1 (torch.Tensor (N,3) or (B,N,3) or (1,N,3))

  • vectors_2 (torch.Tensor (M,3) or (B,M,3) or (1,M,3))

  • alpha (float)

  • allow_antiparallel (bool)

Returns:

Scalar or (B,) tensor of overlap scores.

Return type:

torch.Tensor
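To illustrate the effect of allow_antiparallel, the sketch below weights each Gaussian pair term by the cosine similarity of the attached vectors. Using the clipped cosine directly as the weight is an assumption (the library may rescale it), as are the helper name and the requirement that all vectors are nonzero.

```python
import numpy as np

def vab_cosine_sketch(c1, c2, v1, v2, alpha, allow_antiparallel):
    """Equal-alpha overlap weighted per pair by vector cosine similarity."""
    diff = c1[:, None, :] - c2[None, :, :]
    r2 = np.sum(diff * diff, axis=-1)
    # Normalize so the dot product is a cosine (vectors assumed nonzero).
    u1 = v1 / np.linalg.norm(v1, axis=1, keepdims=True)
    u2 = v2 / np.linalg.norm(v2, axis=1, keepdims=True)
    cos = u1 @ u2.T
    # Antiparallel pairs count fully when allowed, otherwise negatives are dropped.
    weight = np.abs(cos) if allow_antiparallel else np.clip(cos, 0.0, 1.0)
    gauss = (np.pi / (2.0 * alpha)) ** 1.5 * np.exp(-0.5 * alpha * r2)
    return float(np.sum(weight * gauss))

center = np.zeros((1, 3))
up = np.array([[0.0, 0.0, 1.0]])
down = np.array([[0.0, 0.0, -1.0]])

with_anti = vab_cosine_sketch(center, center, up, down, 0.81, allow_antiparallel=True)
without_anti = vab_cosine_sketch(center, center, up, down, 0.81, allow_antiparallel=False)
```

For two coincident centers with opposite vectors, the pair contributes its full Gaussian overlap when antiparallel alignment is allowed and nothing otherwise.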

shepherd_score.score.gaussian_overlap.VAB_2nd_order_cosine_mask(centers_1, centers_2, vectors_1, vectors_2, alpha, allow_antiparallel, mask_1, mask_2)

2nd order volume overlap of AB weighted by cosine similarity, with masking. Torch implementation supporting single instances, matched batches, and broadcasting.

Parameters:
  • centers_1 (torch.Tensor (N,3) or (B,N,3) or (1,N,3))

  • centers_2 (torch.Tensor (M,3) or (B,M,3) or (1,M,3))

  • vectors_1 (torch.Tensor (N,3) or (B,N,3) or (1,N,3))

  • vectors_2 (torch.Tensor (M,3) or (B,M,3) or (1,M,3))

  • alpha (float)

  • allow_antiparallel (bool)

  • mask_1 (torch.Tensor (N,) or (B,N) or (1,N))

  • mask_2 (torch.Tensor (M,) or (B,M) or (1,M))

Returns:

Scalar or (B,) tensor of overlap scores.

Return type:

torch.Tensor

shepherd_score.score.gaussian_overlap.VAB_2nd_order_cosine_mask_batch(cdist_21, vmm_21, alpha, allow_antiparallel, mask_1, mask_2)

2nd order volume overlap of AB (batched) weighted by cosine similarity, with masking, using precomputed cdist and vector dot products (vmm). Assumes inputs cdist_21, vmm_21, mask_1, mask_2 are already batched and broadcast-compatible.

Parameters:
  • cdist_21 (torch.Tensor (B,M,N)) – Precomputed squared Euclidean distances, e.g., (torch.cdist(centers_2, centers_1)**2.0).

  • vmm_21 (torch.Tensor (B,M,N)) – Precomputed dot products of normalized vectors, e.g., torch.matmul(vectors_2, vectors_1.permute(0,2,1)). This corresponds to cosine similarities if vectors were normalized.

  • alpha (float) – Gaussian width parameter.

  • allow_antiparallel (bool) – If true, absolute cosine similarity is used.

  • mask_1 (torch.Tensor (B,N) or (1,N)) – Mask for the first set of points/vectors (N dimension).

  • mask_2 (torch.Tensor (B,M) or (1,M)) – Mask for the second set of points/vectors (M dimension).

Returns:

torch.Tensor – Batched Tanimoto similarity scores.

Return type:

(B,)

shepherd_score.score.gaussian_overlap.VAB_2nd_order_batched(centers_1, centers_2, alphas_1, alphas_2, prefactors_1, prefactors_2)

Calculate the 2nd order volume overlap of AB – batched functionality

Parameters:
  • centers_1 ((torch.Tensor) (batch_size, num_atoms_1, 3)) – Coordinates of atoms in molecule 1

  • centers_2 ((torch.Tensor) (batch_size, num_atoms_2, 3)) – Coordinates of atoms in molecule 2

  • alphas_1 ((torch.Tensor) (batch_size, num_atoms_1)) – Alpha values for atoms in molecule 1

  • alphas_2 ((torch.Tensor) (batch_size, num_atoms_2)) – Alpha values for atoms in molecule 2

  • prefactors_1 ((torch.Tensor) (batch_size, num_atoms_1)) – Prefactor values for atoms in molecule 1

  • prefactors_2 ((torch.Tensor) (batch_size, num_atoms_2)) – Prefactor values for atoms in molecule 2

Returns:

Representing the 2nd order volume overlap of AB for each batch

Return type:

torch.Tensor (batch_size,)
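With per-atom widths and prefactors, the standard Gaussian product integral gives each pair the term p_i p_j (pi / (a_i + a_j))^(3/2) exp(-a_i a_j R_ij^2 / (a_i + a_j)). The NumPy sketch below evaluates that sum per batch element; the helper name is hypothetical and the library's implementation may organize the computation differently.

```python
import numpy as np

def vab_2nd_order_general(c1, c2, alphas_1, alphas_2, pref_1, pref_2):
    """Per-atom alpha/prefactor overlap. c1: (B, N, 3), c2: (B, M, 3) -> (B,)."""
    diff = c1[:, :, None, :] - c2[:, None, :, :]
    r2 = np.sum(diff * diff, axis=-1)                        # (B, N, M)
    a_sum = alphas_1[:, :, None] + alphas_2[:, None, :]      # a_i + a_j
    a_prod = alphas_1[:, :, None] * alphas_2[:, None, :]     # a_i * a_j
    p_prod = pref_1[:, :, None] * pref_2[:, None, :]         # p_i * p_j
    terms = p_prod * (np.pi / a_sum) ** 1.5 * np.exp(-a_prod / a_sum * r2)
    return terms.sum(axis=(1, 2))

rng = np.random.default_rng(0)
B, N, M, alpha = 2, 4, 5, 0.81
c1 = rng.normal(size=(B, N, 3))
c2 = rng.normal(size=(B, M, 3))

general = vab_2nd_order_general(
    c1, c2,
    np.full((B, N), alpha), np.full((B, M), alpha),
    np.ones((B, N)), np.ones((B, M)),
)

# With one shared alpha and unit prefactors the general form reduces to
# (pi / (2 alpha))^(3/2) * exp(-alpha R^2 / 2).
diff = c1[:, :, None, :] - c2[:, None, :, :]
r2 = np.sum(diff * diff, axis=-1)
equal_alpha = ((np.pi / (2.0 * alpha)) ** 1.5 * np.exp(-0.5 * alpha * r2)).sum(axis=(1, 2))
```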

shepherd_score.score.gaussian_overlap.shape_tanimoto_batched(centers_1, centers_2, alphas_1, alphas_2, prefactors_1, prefactors_2)

Calculate the Tanimoto shape similarity between two batches of molecules.

Parameters:
  • centers_1 (torch.Tensor)

  • centers_2 (torch.Tensor)

  • alphas_1 (torch.Tensor)

  • alphas_2 (torch.Tensor)

  • prefactors_1 (torch.Tensor)

  • prefactors_2 (torch.Tensor)

Return type:

torch.Tensor

shepherd_score.score.gaussian_overlap.get_overlap_batch(centers_1, centers_2, prefactor=0.8, alpha=0.81)

Computes the Gaussian overlap for a batch of centers.

Parameters:
  • centers_1 (torch.Tensor)

  • centers_2 (torch.Tensor)

  • prefactor (float)

  • alpha (float)

Return type:

torch.Tensor

shepherd_score.score.gaussian_overlap.VAB_2nd_order_full(centers_1, centers_2, alphas_1, alphas_2, prefactors_1, prefactors_2)

2nd order volume overlap of AB

Return type:

torch.Tensor

shepherd_score.score.gaussian_overlap.shape_tanimoto_full(centers_1, centers_2, alphas_1, alphas_2, prefactors_1, prefactors_2)

Compute Tanimoto shape similarity

Return type:

torch.Tensor

shepherd_score.score.gaussian_overlap.get_overlap_full(centers_1, centers_2, prefactor=0.8, alpha=0.81)

Computes the Gaussian overlap for a batch of centers with custom prefactor and alpha values.

Parameters:
  • centers_1 (torch.Tensor)

  • centers_2 (torch.Tensor)

  • prefactor (float)

  • alpha (float)

Return type:

torch.Tensor

NumPy Implementation

Gaussian volume overlap scoring functions – Shape-only (i.e., not color), NumPy version

Single instance functionality only.

Reference math: https://doi.org/10.1002/(SICI)1096-987X(19961115)17:14<1653::AID-JCC7>3.0.CO;2-K https://doi.org/10.1021/j100011a016

shepherd_score.score.gaussian_overlap_np.VAB_2nd_order_np(centers_1, centers_2, alpha)

2nd order volume overlap of AB

Return type:

ndarray

shepherd_score.score.gaussian_overlap_np.shape_tanimoto_np(centers_1, centers_2, alpha)

Compute Tanimoto shape similarity

Return type:

ndarray

shepherd_score.score.gaussian_overlap_np.get_overlap_np(centers_1, centers_2, alpha=0.81)

NumPy implementation of shape similarity via Gaussian overlaps (single instance).

Return type:

ndarray

shepherd_score.score.gaussian_overlap_np.get_max_overlap_np(centers_1, centers_2, alpha=0.81)

Maximum overlap volume among any pair of centers (always in [0, 1] range).

Return type:

ndarray

shepherd_score.score.gaussian_overlap_np.get_linear_hard_sphere_overlap_np(centers_1, centers_2, min_dist)

Compute linear hard sphere overlap.

The overlap of each pair is piecewise linear in the distance d between the two centers:
  • d > min_dist: 0

  • 0 < d < min_dist: linear in d, between 1 (at d = 0) and 0 (at d = min_dist)

  • d == 0: 1

Returns:

np.ndarray of shape (1,) with the sum of hard sphere overlaps.

Return type:

ndarray
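The piecewise definition above can be written as a clipped linear ramp summed over all pairs. This is a sketch under that reading, with a hypothetical helper name; it returns a Python float rather than the (1,) array described above.

```python
import numpy as np

def linear_hard_sphere_overlap_sketch(centers_1, centers_2, min_dist):
    """Sum over pairs of max(0, 1 - d / min_dist)."""
    diff = centers_1[:, None, :] - centers_2[None, :, :]
    d = np.sqrt(np.sum(diff * diff, axis=-1))          # (N, M) distances
    per_pair = np.clip(1.0 - d / min_dist, 0.0, 1.0)   # 1 at d = 0, 0 beyond min_dist
    return float(per_pair.sum())

origin = np.zeros((1, 3))
probe = np.array([
    [0.0, 0.0, 0.0],   # d = 0 -> contributes 1.0
    [1.0, 0.0, 0.0],   # d = 1 -> contributes 0.5 when min_dist = 2
    [3.0, 0.0, 0.0],   # d = 3 -> contributes 0.0
])
total = linear_hard_sphere_overlap_sketch(origin, probe, min_dist=2.0)  # 1.0 + 0.5 + 0.0
```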

shepherd_score.score.gaussian_overlap_np.VAB_2nd_order_cosine_np(centers_1, centers_2, vectors_1, vectors_2, alpha, allow_antiparallel)

2nd order volume overlap of AB weighted by cosine similarity. NumPy implementation with single instance functionality.

Return type:

ndarray

JAX Implementation

Gaussian volume overlap scoring functions – Shape-only (i.e., not color), JAX version (~6x faster than NumPy)

Batched and non-batched functionalities

Reference math: https://doi.org/10.1002/(SICI)1096-987X(19961115)17:14<1653::AID-JCC7>3.0.CO;2-K https://doi.org/10.1021/j100011a016

shepherd_score.score.gaussian_overlap_jax.jax_cdist(X_1, X_2)

JAX implementation of pairwise Euclidean distances.

Parameters:
  • X_1 (Array (N, P))

  • X_2 (Array (M, P))

Returns:

Distance matrix between X_1 and X_2.

Return type:

Array (N, M)

shepherd_score.score.gaussian_overlap_jax.jax_sq_cdist(X_1, X_2)

JAX implementation of pairwise squared Euclidean distances.

Parameters:
  • X_1 (Array (N, P))

  • X_2 (Array (M, P))

Returns:

Distance matrix between X_1 and X_2, squared.

Return type:

Array (N, M)
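One common way to compute an (N, M) squared-distance matrix without forming an (N, M, P) intermediate is the expansion |x - y|^2 = |x|^2 + |y|^2 - 2 x·y. The sketch below uses NumPy to stay dependency-light; whether the JAX helpers above use this expansion is not stated here, so treat it as an illustration with hypothetical names.

```python
import numpy as np

def sq_cdist_expanded(x1, x2):
    """Pairwise squared distances via |x|^2 + |y|^2 - 2 x.y. (N, P), (M, P) -> (N, M)."""
    n1 = np.sum(x1 * x1, axis=1)[:, None]    # (N, 1) squared norms
    n2 = np.sum(x2 * x2, axis=1)[None, :]    # (1, M) squared norms
    sq = n1 + n2 - 2.0 * (x1 @ x2.T)
    # Rounding can leave tiny negative values; clamp before any square root.
    return np.maximum(sq, 0.0)

def cdist_expanded(x1, x2):
    """Pairwise distances as the square root of the clamped squared distances."""
    return np.sqrt(sq_cdist_expanded(x1, x2))

rng = np.random.default_rng(0)
x1 = rng.normal(size=(4, 3))
x2 = rng.normal(size=(6, 3))

sq = sq_cdist_expanded(x1, x2)
dist = cdist_expanded(x1, x2)

# Direct reference computation for comparison.
diff = x1[:, None, :] - x2[None, :, :]
sq_direct = np.sum(diff * diff, axis=-1)
```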

shepherd_score.score.gaussian_overlap_jax.VAB_2nd_order_jax(centers_1, centers_2, alpha)

2nd order volume overlap of AB

Parameters:
  • centers_1 (jax.Array)

  • centers_2 (jax.Array)

  • alpha (float)

Return type:

jax.Array

shepherd_score.score.gaussian_overlap_jax.shape_tanimoto_jax(centers_1, centers_2, alpha)

Compute Tanimoto shape similarity

Parameters:
  • centers_1 (jax.Array)

  • centers_2 (jax.Array)

  • alpha (float)

Return type:

jax.Array

shepherd_score.score.gaussian_overlap_jax.get_overlap_jax(centers_1, centers_2, alpha=0.81)

Compute ROCS Gaussian volume overlap using jitted jax function.

Parameters:
  • centers_1 (jax.Array)

  • centers_2 (jax.Array)

  • alpha (float)

Return type:

jax.Array

shepherd_score.score.gaussian_overlap_jax.get_max_overlap_jax(centers_1, centers_2, alpha)

Maximum overlap volume among any pair of centers (always in [0, 1] range).

Parameters:
  • centers_1 (jax.Array)

  • centers_2 (jax.Array)

  • alpha (float)

Return type:

jax.Array

shepherd_score.score.gaussian_overlap_jax.get_linear_hard_sphere_overlap_jax(centers_1, centers_2, min_dist)

Compute linear hard sphere overlap.

See get_linear_hard_sphere_overlap_np for details.

Parameters:
  • centers_1 (jax.Array)

  • centers_2 (jax.Array)

  • min_dist (float)

Return type:

jax.Array

shepherd_score.score.gaussian_overlap_jax.VAB_2nd_order_jax_mask(centers_1, centers_2, mask_1, mask_2, alpha)

2nd order volume overlap of AB

Parameters:
  • centers_1 (jax.Array)

  • centers_2 (jax.Array)

  • mask_1 (jax.Array)

  • mask_2 (jax.Array)

  • alpha (float)

Return type:

jax.Array

shepherd_score.score.gaussian_overlap_jax.shape_tanimoto_jax_mask(centers_1, centers_2, mask_1, mask_2, alpha)

Compute Tanimoto shape similarity

Parameters:
  • centers_1 (jax.Array)

  • centers_2 (jax.Array)

  • mask_1 (jax.Array)

  • mask_2 (jax.Array)

  • alpha (float)

Return type:

jax.Array

shepherd_score.score.gaussian_overlap_jax.get_overlap_jax_mask(centers_1, centers_2, mask_1, mask_2, alpha=0.81)

Compute ROCS Gaussian volume overlap using jitted jax function.

Parameters:
  • centers_1 (jax.Array)

  • centers_2 (jax.Array)

  • mask_1 (jax.Array)

  • mask_2 (jax.Array)

  • alpha (float)

Return type:

jax.Array

shepherd_score.score.gaussian_overlap_jax.VAB_2nd_order_cosine_jax(centers_1, centers_2, vectors_1, vectors_2, alpha, allow_antiparallel)

2nd order volume overlap of AB weighted by cosine similarity (JAX version) - implementation part.

Parameters:
  • centers_1 (jax.Array)

  • centers_2 (jax.Array)

  • vectors_1 (jax.Array)

  • vectors_2 (jax.Array)

  • alpha (float)

  • allow_antiparallel (bool)

Return type:

jax.Array

shepherd_score.score.gaussian_overlap_jax.VAB_2nd_order_cosine_jax_mask(centers_1, centers_2, vectors_1, vectors_2, mask_1, mask_2, alpha, allow_antiparallel)

2nd order volume overlap of AB weighted by cosine similarity (JAX version) - implementation part. Vectors are assumed to be normalized.

Parameters:
  • centers_1 (jax.Array)

  • centers_2 (jax.Array)

  • vectors_1 (jax.Array)

  • vectors_2 (jax.Array)

  • mask_1 (jax.Array)

  • mask_2 (jax.Array)

  • alpha (float)

  • allow_antiparallel (bool)

Return type:

jax.Array