"""
Molecule class to hold molecule geometries and extract interaction profiles.
MoleculePair class facilitates alignment with interaction profiles.
"""
from typing import Union, List, Optional, Tuple
from copy import deepcopy
import sys
import numpy as np
import rdkit.Chem as Chem
from rdkit.Geometry.rdGeometry import Point3D
import torch
from shepherd_score.score.constants import COULOMB_SCALING, LAM_SCALING, ALPHA # noqa: F401
from shepherd_score.generate_point_cloud import get_atom_coords, get_atomic_vdw_radii, get_molecular_surface, get_molecular_surface_const_density
from shepherd_score.score.gaussian_overlap_np import get_overlap_np
from shepherd_score.score.gaussian_overlap import get_overlap
from shepherd_score.score.electrostatic_scoring import get_overlap_esp
from shepherd_score.score.electrostatic_scoring_np import get_overlap_esp_np
from shepherd_score.pharm_utils.pharmacophore import get_pharmacophores
from shepherd_score.score.pharmacophore_scoring_np import get_overlap_pharm_np
from shepherd_score.score.pharmacophore_scoring import _SIM_TYPE, get_overlap_pharm
from shepherd_score.alignment import optimize_ROCS_overlay, optimize_ROCS_overlay_analytical, optimize_ROCS_esp_overlay, optimize_ROCS_esp_overlay_analytical, optimize_esp_combo_score_overlay
from shepherd_score.alignment import optimize_pharm_overlay, optimize_pharm_overlay_analytical
from shepherd_score.alignment.utils.se3_np import apply_SE3_transform_np, apply_SO3_transform_np
[docs]
def update_mol_coordinates(mol: Chem.Mol, coordinates: Union[List, np.ndarray]) -> Chem.Mol:
    """
    Return a deep copy of ``mol`` whose default conformer holds ``coordinates``.

    Parameters
    ----------
    mol : Chem.Mol
        RDKit mol object with 3D coordinates to be replaced. It is not modified.
    coordinates : Union[List, np.ndarray]
        Sequence of new [x, y, z] coordinates, one row per atom of ``mol`` in atom-index
        order. Only the first ``mol.GetNumAtoms()`` rows are read.

    Returns
    -------
    mol_new : Chem.Mol
        Deep-copied RDKit mol object with updated 3D coordinates.
    """
    mol_new = deepcopy(mol)
    conf = mol_new.GetConformer()
    for i in range(mol_new.GetNumAtoms()):
        x, y, z = coordinates[i]
        # Cast explicitly: Point3D is a C++-wrapped constructor taking floats and may reject
        # numpy scalars such as np.float32 (not a Python float subclass), which is the dtype
        # produced by the alignment routines in this module.
        conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z)))
    return mol_new
[docs]
class Molecule:
    """
    Container that holds a molecule geometry and its interaction profiles.

    Depending on the constructor arguments it stores heavy-atom coordinates, per-atom
    partial charges, a molecular surface point cloud, the electrostatic potential (ESP)
    evaluated on that surface, and pharmacophores.
    """

    def __init__(self,
                 mol: Chem.rdchem.Mol,
                 num_surf_points: Optional[int] = None,
                 density: Optional[float] = None,
                 probe_radius: Optional[float] = None,
                 surface_points: Optional[np.ndarray] = None,
                 partial_charges: Optional[np.ndarray] = None,
                 electrostatics: Optional[np.ndarray] = None,
                 pharm_multi_vector: Optional[bool] = None,
                 pharm_types: Optional[np.ndarray] = None,
                 pharm_ancs: Optional[np.ndarray] = None,
                 pharm_vecs: Optional[np.ndarray] = None
                 ):
        """
        Molecule constructor to extract interaction profiles.

        If ``partial_charges`` are not provided, they are computed with MMFF94, which may
        result in subpar performance.

        Parameters
        ----------
        mol : rdkit.Chem.rdchem.Mol
            Molecule with a 3D conformer. Stored by reference (not copied).
        num_surf_points : Optional[int]
            Number of surface points to sample (Python or NumPy integer).
            If ``None`` (and ``density`` is ``None``), the surface point cloud is not
            generated. More efficient if only doing volumetric alignment.
        density : Optional[float]
            Density of points to sample on the molecular surface (Python or NumPy float).
            If ``None`` (and ``num_surf_points`` is ``None``), the surface point cloud is
            not generated. If both are given, ``num_surf_points`` supersedes ``density``.
        probe_radius : Optional[float]
            Radius of a probe atom used to define a "solvent accessible surface".
            Default is 1.2 if ``None`` is passed.
        surface_points : Optional[np.ndarray]
            Surface points if they were previously generated. Shape: (M,3).
            Copied, so later in-place shifts (``center_to``) never modify the caller's array.
        partial_charges : Optional[np.ndarray]
            Partial charges for each atom. Shape: (N,). Lists/tuples are converted to arrays.
            If ``None`` is passed, MMFF94 partial charges are computed.
        electrostatics : Optional[np.ndarray]
            Electrostatic potential at each surface point if previously generated. Shape: (M,).
            Used only when a surface is available; if omitted, it is computed from the
            partial charges.
        pharm_multi_vector : Optional[bool]
            If ``None``, don't generate pharmacophores, else generate pharmacophores
            with/without (``True``/``False``) multi-vectors. Ignored when all three
            ``pharm_*`` arrays are provided.
        pharm_types : Optional[np.ndarray]
            Types of pharmacophores. Shape: (P,).
        pharm_ancs : Optional[np.ndarray]
            Anchor positions of pharmacophores. Shape: (P,3). Copied, like ``surface_points``.
        pharm_vecs : Optional[np.ndarray]
            Unit vectors relative to anchor positions of pharmacophores. Shape: (P,3).

        Raises
        ------
        ValueError
            If ``partial_charges`` is not given and MMFF94 cannot parametrize ``mol``.
        """
        self.mol = mol
        # Coordinates of the molecule after RDKit's RemoveHs (i.e. heavy atoms).
        self.atom_pos = Chem.RemoveHs(mol).GetConformer().GetPositions()
        self.num_surf_points = num_surf_points if surface_points is None else len(surface_points)
        self.density = density

        # Partial charges: accept array-likes, otherwise fall back to MMFF94.
        if isinstance(partial_charges, (list, tuple)):
            partial_charges = np.array(partial_charges)
        if isinstance(partial_charges, np.ndarray):
            self.partial_charges = partial_charges
        else:
            self.partial_charges = self.get_partial_charges()

        self.radii = get_atomic_vdw_radii(mol)
        self.probe_radius = probe_radius if probe_radius is not None else 1.2

        # Surface point cloud: provided > fixed point count > constant density > none.
        self.surf_esp = None
        if surface_points is not None:
            # Copy so in-place shifts in ``center_to`` don't alias the caller's data.
            self.surf_pos = np.array(surface_points)
        elif isinstance(num_surf_points, (int, np.integer)):
            self.surf_pos = self.get_pc()
        elif isinstance(density, (float, np.floating)):
            self.surf_pos = self.get_pc(use_density=True)
        else:  # No surface requested: skip point-cloud generation entirely.
            self.surf_pos = None

        if self.surf_pos is not None and self.partial_charges is not None:
            if isinstance(electrostatics, (list, tuple)):
                electrostatics = np.array(electrostatics)
            if isinstance(electrostatics, np.ndarray):
                self.surf_esp = electrostatics
            else:
                self.surf_esp = self.get_electrostatic_potential()

        # Indices of non-hydrogen atoms. The explicit int dtype keeps the array usable as an
        # index even when the molecule has no heavy atoms (np.array([]) would be float).
        self._nonH_atoms_idx = np.array(
            [a.GetIdx() for a in self.mol.GetAtoms() if a.GetAtomicNum() != 1],
            dtype=int,
        )

        self.pharm_multi_vector = pharm_multi_vector
        if all(isinstance(arr, np.ndarray) for arr in (pharm_types, pharm_ancs, pharm_vecs)):
            # Anchors are shifted in place by ``center_to``; copy to protect the caller.
            self.pharm_types = pharm_types
            self.pharm_ancs = pharm_ancs.copy()
            self.pharm_vecs = pharm_vecs
        else:
            self.pharm_types, self.pharm_ancs, self.pharm_vecs = None, None, None
            if self.pharm_multi_vector is not None:
                self.get_pharmacophore(
                    multi_vector=self.pharm_multi_vector,
                    exclude=[],
                    check_access=False,
                    scale=1.
                )

    def get_partial_charges(self) -> np.ndarray:
        """
        Get the MMFF94 partial charge on each atom of ``self.mol``.

        Returns
        -------
        np.ndarray
            float32 array with one charge per atom of ``self.mol``. Shape: (N,).

        Raises
        ------
        ValueError
            If MMFF94 parameters cannot be assigned to ``self.mol``.
        """
        # Import explicitly rather than relying on ``Chem.AllChem`` having been loaded as a
        # side effect of some other import.
        from rdkit.Chem import AllChem
        molec_props = AllChem.MMFFGetMoleculeProperties(self.mol)
        if molec_props is None:
            raise ValueError(
                'MMFF94 parameters could not be assigned to this molecule; '
                'pass `partial_charges` explicitly.'
            )
        charges = np.array([molec_props.GetMMFFPartialCharge(i) for i in range(self.mol.GetNumAtoms())])
        return charges.astype(np.float32)

    def get_pc(self, use_density=False) -> np.ndarray:
        """
        Sample the molecular surface point cloud.

        Side effect: ``self.mol`` is reassigned to the molecule returned by
        ``get_atom_coords``.

        Parameters
        ----------
        use_density : bool
            If ``True``, sample with constant ``self.density``; otherwise sample
            ``self.num_surf_points`` points.

        Returns
        -------
        np.ndarray
            float32 surface point positions (expected shape (M, 3)).
        """
        self.mol, centers = get_atom_coords(self.mol, MMFF_optimize=False)
        if use_density:
            positions = get_molecular_surface_const_density(centers,
                                                            self.radii,
                                                            self.density,
                                                            probe_radius=self.probe_radius,
                                                            num_samples_per_atom=25)
        else:
            positions = get_molecular_surface(centers,
                                              self.radii,
                                              num_points=self.num_surf_points,
                                              probe_radius=self.probe_radius,
                                              num_samples_per_atom=25)
        return positions.astype(np.float32)

    def get_electrostatic_potential(self) -> np.ndarray:
        """
        Get the electrostatic potential at each surface point.

        Computes ``COULOMB_SCALING * sum_i q_i / |p - r_i|`` over all atoms ``i`` of
        ``self.mol`` (including hydrogens) for every surface point ``p``.

        Returns
        -------
        np.ndarray
            float32 potential per surface point. Shape: (M,). Points whose potential is
            not finite (e.g. a point coinciding with an atom center) are set to 0.
        """
        centers = self.mol.GetConformer().GetPositions()
        surf = np.asarray(self.surf_pos)
        distances = np.linalg.norm(surf[:, np.newaxis] - centers, axis=2)
        # A zero distance gives 1/0 = inf; mixed-sign or zero charges can then turn the sum
        # into NaN, so suppress the warnings here and zero every non-finite result below.
        with np.errstate(divide='ignore', invalid='ignore'):
            E_pot = np.dot(self.partial_charges, 1 / distances.T) * COULOMB_SCALING
        E_pot[~np.isfinite(E_pot)] = 0
        return E_pot.astype(np.float32)

    def center_to(self, xyz_means: np.ndarray) -> None:
        """
        Translate every stored geometry by ``-xyz_means``.

        Shifts ``self.atom_pos``, the conformer of ``self.mol``, and (when present)
        ``self.surf_pos`` and ``self.pharm_ancs`` in place.

        Parameters
        ----------
        xyz_means : np.ndarray
            Point to move to the origin. Shape: (3,). Any array-like is accepted.
        """
        from rdkit.Chem import rdMolTransforms
        # Copy first: ``xyz_means`` may be a view into ``self.atom_pos`` (e.g. one of its
        # rows); the in-place update below would otherwise zero it before it is reused.
        shift = np.array(xyz_means)
        self.atom_pos -= shift
        trans = np.eye(4)
        trans[:3, 3] = -shift
        rdMolTransforms.TransformConformer(self.mol.GetConformer(), trans)
        if self.surf_pos is not None:
            self.surf_pos -= shift
        if self.pharm_ancs is not None:
            self.pharm_ancs -= shift

    def get_pharmacophore(self,
                          multi_vector: bool = True,
                          exclude: Optional[List[int]] = None,
                          check_access: bool = False,
                          scale: float = 1):
        """
        Compute the pharmacophores of ``self.mol``.

        Results are stored in ``self.pharm_types``, ``self.pharm_ancs`` and
        ``self.pharm_vecs``.

        Parameters
        ----------
        multi_vector : bool
            Forwarded to ``get_pharmacophores``.
        exclude : Optional[List[int]]
            Forwarded to ``get_pharmacophores``; ``None`` means an empty list.
        check_access : bool
            Forwarded to ``get_pharmacophores``.
        scale : float
            Forwarded to ``get_pharmacophores``.
        """
        # Fresh list per call instead of a shared mutable default.
        exclude = [] if exclude is None else exclude
        self.pharm_types, self.pharm_ancs, self.pharm_vecs = get_pharmacophores(
            self.mol,
            multi_vector=multi_vector,
            exclude=exclude,
            check_access=check_access,
            scale=scale
        )
[docs]
class MoleculePair:
""" Pair of Molecule objects to facilitate alignment. """
[docs]
def __init__(self,
             ref_mol: Union[Chem.rdchem.Mol, Molecule],
             fit_mol: Union[Chem.rdchem.Mol, Molecule],
             num_surf_points: Optional[int] = None,
             density: Optional[float] = None,
             do_center: bool = False,
             device = -1):
    """
    A pair of molecules: a reference molecule and a fit molecule that can be aligned to
    the reference.

    There are a number of alignments that can be done:
        - Volumetric (with and without hydrogens)
        - Volumetric with partial charge weighting (with and without hydrogens)
        - Surface
        - Surface with electrostatic potential weighting
        - ShaEP scoring (esp-combo)
        - Pharmacophores (with various settings for using extended points rather than vectors)
    Similarly, you can score with surface, Surf+ESP, and pharmacophores.

    Parameters
    ----------
    ref_mol : Union[rdkit.Chem.rdchem.Mol, container.Molecule]
        Reference molecule.
        If an RDKit Mol object is provided, it will be converted to a Molecule
        object. If a Molecule object is given, it will NOT regenerate the surface.
    fit_mol : Union[rdkit.Chem.rdchem.Mol, container.Molecule]
        Molecule to fit to the reference.
        If an RDKit Mol object is provided, it will be converted to a Molecule
        object. If a Molecule object is given, it will NOT regenerate the surface.
    num_surf_points : Optional[int] (default = None)
        Number of surface points to sample if RDKit Mol objects are given.
        Surface and ESP alignment need either this, ``density``, or two Molecule
        objects that already carry surface points.
    density : Optional[float] (default = None)
        Density of points to sample if RDKit Mol objects are given.
        An integer input for ``num_surf_points`` supersedes ``density``.
    do_center : bool (default = False)
        THIS IS CRUCIAL
        Whether to initially align molecule centers together. For global optimizations, set to
        True. For scoring of current alignment or local alignment set to False.
        Centering shifts the Molecule objects (and their RDKit conformers) in place.
    device : torch.device or str (default = -1)
        Device to use if you want to align with PyTorch downstream. A ``torch.device`` or a
        device string (e.g. ``'cpu'``, ``'cuda:1'``) is used as given. Any other value,
        including the default ``-1``, selects ``'cuda:0'`` when CUDA is available and
        ``'cpu'`` otherwise.
    """
    # Build Molecule containers (sampling surfaces) only for raw RDKit inputs;
    # Molecule objects are used as-is.
    if isinstance(ref_mol, Chem.rdchem.Mol):
        self.ref_molec = Molecule(ref_mol, num_surf_points=num_surf_points, density=density)
    else:
        self.ref_molec = ref_mol
    if isinstance(fit_mol, Chem.rdchem.Mol):
        self.fit_molec = Molecule(fit_mol, num_surf_points=num_surf_points, density=density)
    else:
        self.fit_molec = fit_mol

    self.num_surf_points = num_surf_points
    self.density = density
    if num_surf_points is None:
        if density is not None:
            # Density-based sampling has no fixed point count; ``True`` acts as the
            # "a surface exists" flag checked by the surface-based alignment methods.
            self.num_surf_points = True
        elif (getattr(self.ref_molec, 'surf_pos', None) is not None
              and getattr(self.fit_molec, 'surf_pos', None) is not None):
            # Pre-built Molecule objects already carry surfaces: enable surface-based
            # alignment instead of reporting that no surface points exist.
            ref_count = getattr(self.ref_molec, 'num_surf_points', None)
            self.num_surf_points = ref_count if ref_count is not None else True

    if isinstance(device, str):
        # Honor explicit device strings instead of silently overriding them.
        device = torch.device(device)
    elif not isinstance(device, torch.device):
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.device = device

    # Center to origin
    if do_center:
        self.ref_molec.center_to(self.ref_molec.atom_pos.mean(0))
        self.fit_molec.center_to(self.fit_molec.atom_pos.mean(0))

    # Optimal SE(3) transforms (identity until aligned) and aligned similarity scores.
    self.transform_vol = np.eye(4)
    self.sim_aligned_vol = None
    self.transform_vol_noH = np.eye(4)
    self.sim_aligned_vol_noH = None
    self.transform_surf = np.eye(4)
    self.sim_aligned_surf = None
    self.transform_esp = np.eye(4)
    self.sim_aligned_esp = None
    self.transform_vol_esp = np.eye(4)
    self.sim_aligned_vol_esp = None
    self.transform_vol_esp_noH = np.eye(4)
    self.sim_aligned_vol_esp_noH = None
    self.transform_esp_combo = np.eye(4)
    self.sim_aligned_esp_combo = None
    self.transform_pharm = np.eye(4)
    self.sim_aligned_pharm = None
[docs]
def align_with_vol(self,
                   no_H: bool = True,
                   num_repeats: int = 50,
                   trans_init: bool = False,
                   lr: float = 0.1,
                   max_num_steps: int = 200,
                   use_jax: bool = False,
                   use_analytical: bool = True,
                   verbose: bool = False) -> np.ndarray:
    """
    Align fit_molec to ref_molec using volumetric similarity.

    Optimally aligned score found in ``self.sim_aligned_vol`` and the optimal SE(3)
    transformation is at ``self.transform_vol``. If ``no_H`` is ``True``, append '_noH' to them.

    Parameters
    ----------
    no_H : bool
        Whether to exclude hydrogens from the volumetric similarity. Default is ``True``.
    num_repeats : int, optional
        Number of different random initializations of SO(3) transformation parameters. Default is 50.
    trans_init : bool, optional
        Apply translation initialization for alignment. ``fit_molec``'s center of mass (COM) is
        translated to each of ``ref_molec``'s ``atom_pos`` coordinates, with 10 rotations for each
        translation. So the number of initializations scales as (# translation centers * 10 + 5)
        where 5 is from the identity and 4 PCA with aligned COMs. If ``False``, then
        ``num_repeats`` rotations are done with aligned COMs. Default is ``False``.
    lr : float, optional
        Learning rate or step-size for optimization. Default is 0.1.
    max_num_steps : int, optional
        Maximum number of steps to optimize over. Default is 200.
    use_jax : bool, optional
        Whether to use Jax instead of PyTorch. Default is ``False``.
    use_analytical : bool, optional
        Whether to use analytical gradients instead of PyTorch autograd. Ignored if
        ``use_jax=True``. Default is ``True``.
    verbose : bool, optional
        Print initial and final similarity scores with scores every 100 steps. Default is ``False``.

    Returns
    -------
    aligned_fit_points : np.ndarray
        Coordinates of transformed atoms. Shape: (N, 3).
    """
    if no_H:
        ref_atom_pos = self.ref_molec.atom_pos
        fit_atom_pos = self.fit_molec.atom_pos
    else:
        ref_atom_pos = self.ref_molec.mol.GetConformer().GetPositions()
        fit_atom_pos = self.fit_molec.mol.GetConformer().GetPositions()

    if use_jax:  # Use Jax optimization implementation
        try:
            import jax.numpy as jnp
        except ImportError as err:
            raise ImportError('jax.numpy and torch is required for this function. Install Jax or just use Torch.') from err
        from shepherd_score.alignment_jax import optimize_ROCS_overlay_jax
        aligned_fit_points, se3_transform, score = optimize_ROCS_overlay_jax(
            ref_points=jnp.array(ref_atom_pos),
            fit_points=jnp.array(fit_atom_pos),
            alpha=0.81,
            num_repeats=num_repeats,
            trans_centers=self.ref_molec.atom_pos if trans_init else None,
            lr=lr,
            max_num_steps=max_num_steps,
            verbose=verbose
        )
        se3_transform = np.array(se3_transform)
        score = np.array(score)
        aligned_fit_points = np.array(aligned_fit_points)
    else:  # PyTorch
        def _to_torch(arr):
            # float32 tensor on the pair's device.
            return torch.from_numpy(np.asarray(arr)).to(dtype=torch.float32, device=self.device)

        _vol_fn = optimize_ROCS_overlay_analytical if use_analytical else optimize_ROCS_overlay
        aligned_fit_points, se3_transform, score = _vol_fn(
            ref_points=_to_torch(ref_atom_pos),
            fit_points=_to_torch(fit_atom_pos),
            alpha=0.81,
            num_repeats=num_repeats,
            trans_centers=_to_torch(self.ref_molec.atom_pos) if trans_init else None,
            lr=lr,
            max_num_steps=max_num_steps,
            verbose=verbose
        )
        # Tensor.numpy() only works on CPU tensors that don't require grad, and
        # self.device may be CUDA, so detach and move to host first.
        se3_transform = se3_transform.detach().cpu().numpy()
        score = score.detach().cpu().numpy()
        aligned_fit_points = aligned_fit_points.detach().cpu().numpy()

    if no_H:
        self.transform_vol_noH = se3_transform
        self.sim_aligned_vol_noH = score
    else:
        self.transform_vol = se3_transform
        self.sim_aligned_vol = score
    return aligned_fit_points
[docs]
def align_with_vol_esp(self,
                       lam: float,
                       no_H: bool = True,
                       num_repeats: int = 50,
                       trans_init: bool = False,
                       lr: float = 0.1,
                       max_num_steps: int = 200,
                       use_jax: bool = False,
                       use_analytical: bool = True,
                       verbose: bool = False) -> np.ndarray:
    """
    Align fit_molec to ref_molec using volume similarity weighted by partial charge.

    Toggle the ``no_H`` parameter for scoring with or without hydrogens.
    Typically ``lam=0.1`` is used.
    Optimally aligned score found in ``self.sim_aligned_vol_esp`` and the optimal SE(3)
    transformation is at ``self.transform_vol_esp``. If ``no_H`` is ``True``, append '_noH' to them.

    Parameters
    ----------
    lam : float
        Partial charge weighting parameter.
    no_H : bool
        Whether to exclude hydrogens from the volumetric similarity. Default is ``True``.
    num_repeats : int, optional
        Number of different random initializations of SO(3) transformation parameters.
        Default is 50.
    trans_init : bool, optional
        Apply translation initialization for alignment. ``fit_molec``'s center of mass
        (COM) is translated to each of ``ref_molec``'s ``atom_pos`` coordinates, with 10
        rotations for each translation. So the number of initializations scales as
        (# translation centers * 10 + 5) where 5 is from the identity and 4 PCA with aligned
        COMs. If ``False``, then ``num_repeats`` rotations are done with aligned COMs.
        Default is ``False``.
    lr : float, optional
        Learning rate or step-size for optimization. Default is 0.1.
    max_num_steps : int, optional
        Maximum number of steps to optimize over. Default is 200.
    use_jax : bool, optional
        Whether to use Jax instead of PyTorch. Default is ``False``.
    use_analytical : bool, optional
        Whether to use analytical gradients instead of PyTorch autograd. Ignored if
        ``use_jax=True``. Default is ``True``.
    verbose : bool, optional
        Print initial and final similarity scores with scores every 100 steps.
        Default is ``False``.

    Returns
    -------
    aligned_fit_points : np.ndarray
        Coordinates of transformed atoms. Shape: (N, 3).
    """
    if no_H:
        ref_mol_partial_charges = self.ref_molec.partial_charges[self.ref_molec._nonH_atoms_idx]
        fit_mol_partial_charges = self.fit_molec.partial_charges[self.fit_molec._nonH_atoms_idx]
        ref_mol_pos = self.ref_molec.atom_pos
        fit_mol_pos = self.fit_molec.atom_pos
    else:
        ref_mol_partial_charges = self.ref_molec.partial_charges
        fit_mol_partial_charges = self.fit_molec.partial_charges
        ref_mol_pos = self.ref_molec.mol.GetConformer().GetPositions()
        fit_mol_pos = self.fit_molec.mol.GetConformer().GetPositions()

    if use_jax:  # Use Jax optimization implementation
        try:
            import jax.numpy as jnp
        except ImportError as err:
            raise ImportError('jax.numpy and torch is required for this function. Install Jax or just use Torch.') from err
        from shepherd_score.alignment_jax import optimize_ROCS_esp_overlay_jax
        aligned_fit_points, se3_transform, score = optimize_ROCS_esp_overlay_jax(
            ref_points=jnp.array(ref_mol_pos),
            fit_points=jnp.array(fit_mol_pos),
            ref_charges=jnp.array(ref_mol_partial_charges),
            fit_charges=jnp.array(fit_mol_partial_charges),
            alpha=0.81,
            lam=lam,
            num_repeats=num_repeats,
            trans_centers=self.ref_molec.atom_pos if trans_init else None,
            lr=lr,
            max_num_steps=max_num_steps,
            verbose=verbose
        )
        se3_transform = np.array(se3_transform)
        score = np.array(score)
        aligned_fit_points = np.array(aligned_fit_points)
    else:  # Use Torch implementation
        def _to_torch(arr):
            # float32 tensor on the pair's device.
            return torch.from_numpy(np.asarray(arr)).to(dtype=torch.float32, device=self.device)

        _esp_fn = optimize_ROCS_esp_overlay_analytical if use_analytical else optimize_ROCS_esp_overlay
        aligned_fit_points, se3_transform, score = _esp_fn(
            ref_points=_to_torch(ref_mol_pos),
            fit_points=_to_torch(fit_mol_pos),
            ref_charges=_to_torch(ref_mol_partial_charges),
            fit_charges=_to_torch(fit_mol_partial_charges),
            alpha=0.81,
            lam=lam,
            num_repeats=num_repeats,
            trans_centers=_to_torch(self.ref_molec.atom_pos) if trans_init else None,
            lr=lr,
            max_num_steps=max_num_steps,
            verbose=verbose
        )
        # Tensor.numpy() only works on CPU tensors that don't require grad, and
        # self.device may be CUDA, so detach and move to host first.
        se3_transform = se3_transform.detach().cpu().numpy()
        score = score.detach().cpu().numpy()
        aligned_fit_points = aligned_fit_points.detach().cpu().numpy()

    if no_H:
        self.transform_vol_esp_noH = se3_transform
        self.sim_aligned_vol_esp_noH = score
    else:
        self.transform_vol_esp = se3_transform
        self.sim_aligned_vol_esp = score
    return aligned_fit_points
[docs]
def align_with_surf(self,
                    alpha: float,
                    num_repeats: int = 50,
                    trans_init: bool = False,
                    lr: float = 0.1,
                    max_num_steps: int = 200,
                    use_jax: bool = False,
                    use_analytical: bool = True,
                    verbose: bool = False) -> np.ndarray:
    """
    Align fit_molec to ref_molec using surface similarity.

    Optimally aligned score found in ``self.sim_aligned_surf`` and the optimal SE(3)
    transformation is at ``self.transform_surf``.

    Parameters
    ----------
    alpha : float
        Gaussian width parameter for overlap.
    num_repeats : int, optional
        Number of different random initializations of SO(3) transformation parameters.
        Default is 50.
    trans_init : bool, optional
        Apply translation initialization for alignment. ``fit_molec``'s center of mass
        (COM) is translated to each of ``ref_molec``'s ``atom_pos`` coordinates, with 10
        rotations for each translation. So the number of initializations scales as
        (# translation centers * 10 + 5) where 5 is from the identity and 4 PCA with
        aligned COMs. If ``False``, then ``num_repeats`` rotations are done with aligned COMs.
        Default is ``False``.
    lr : float, optional
        Learning rate or step-size for optimization. Default is 0.1.
    max_num_steps : int, optional
        Maximum number of steps to optimize over. Default is 200.
    use_jax : bool, optional
        Whether to use Jax instead of PyTorch. Default is ``False``.
    use_analytical : bool, optional
        Whether to use analytical gradients instead of PyTorch autograd. Ignored if
        ``use_jax=True``. Default is ``True``.
    verbose : bool, optional
        Print initial and final similarity scores with scores every 100 steps. Default is ``False``.

    Returns
    -------
    aligned_fit_points : np.ndarray
        Coordinates of transformed surface points. Shape: (M, 3).

    Raises
    ------
    ValueError
        If the pair has no surface setting or either molecule has no surface points.
    """
    if (self.num_surf_points is None
            or self.ref_molec.surf_pos is None
            or self.fit_molec.surf_pos is None):
        raise ValueError('The Molecule objects were initialized with no surface points so this method cannot be used.')

    if use_jax:  # Use Jax optimization implementation
        try:
            import jax.numpy as jnp
        except ImportError as err:
            raise ImportError('jax.numpy and torch is required for this function. Install Jax or just use Torch.') from err
        from shepherd_score.alignment_jax import optimize_ROCS_overlay_jax
        aligned_fit_points, se3_transform, score = optimize_ROCS_overlay_jax(
            ref_points=jnp.array(self.ref_molec.surf_pos),
            fit_points=jnp.array(self.fit_molec.surf_pos),
            alpha=alpha,
            num_repeats=num_repeats,
            trans_centers=self.ref_molec.atom_pos if trans_init else None,
            lr=lr,
            max_num_steps=max_num_steps,
            verbose=verbose
        )
        self.transform_surf = np.array(se3_transform)
        self.sim_aligned_surf = np.array(score)
        return np.array(aligned_fit_points)

    # Torch
    def _to_torch(arr):
        # float32 tensor on the pair's device.
        return torch.from_numpy(np.asarray(arr)).to(dtype=torch.float32, device=self.device)

    _surf_fn = optimize_ROCS_overlay_analytical if use_analytical else optimize_ROCS_overlay
    aligned_fit_points, se3_transform, score = _surf_fn(
        ref_points=_to_torch(self.ref_molec.surf_pos),
        fit_points=_to_torch(self.fit_molec.surf_pos),
        alpha=alpha,
        num_repeats=num_repeats,
        trans_centers=_to_torch(self.ref_molec.atom_pos) if trans_init else None,
        lr=lr,
        max_num_steps=max_num_steps,
        verbose=verbose
    )
    # Tensor.numpy() only works on CPU tensors that don't require grad, and
    # self.device may be CUDA, so detach and move to host first.
    self.transform_surf = se3_transform.detach().cpu().numpy()
    self.sim_aligned_surf = score.detach().cpu().numpy()
    return aligned_fit_points.detach().cpu().numpy()
[docs]
def align_with_esp(self,
                   alpha: float,
                   lam: float = 0.3,
                   num_repeats: int = 50,
                   trans_init: bool = False,
                   lr: float = 0.1,
                   max_num_steps: int = 200,
                   use_jax: bool = False,
                   use_analytical: bool = True,
                   verbose: bool = False) -> np.ndarray:
    """
    Align fit_molec to ref_molec using ESP+surface similarity.

    ``lam`` is multiplied by ``LAM_SCALING`` (documented as ``(1e4/(4*55.263*np.pi))**2``)
    for correct units. Typically, ``lam=0.3`` is used and is scaled internally.
    Optimally aligned score found in ``self.sim_aligned_esp`` and the optimal SE(3)
    transformation is at ``self.transform_esp``.

    Parameters
    ----------
    alpha : float
        Gaussian width parameter for overlap.
    lam : float, optional
        Weighting factor for ESP scoring. Scaled internally. Default is 0.3.
    num_repeats : int, optional
        Number of different random initializations of SO(3) transformation parameters.
        Default is 50.
    trans_init : bool, optional
        Apply translation initialization for alignment. ``fit_molec``'s COM is translated to
        each of ``ref_molec``'s ``atom_pos`` coordinates, with 10 rotations for each
        translation. So the number of initializations scales as (# translation centers * 10 + 5)
        where 5 is from the identity and 4 PCA with aligned COMs. If ``False``, then
        ``num_repeats`` rotations are done with aligned COMs. Default is ``False``.
    lr : float, optional
        Learning rate or step-size for optimization. Default is 0.1.
    max_num_steps : int, optional
        Maximum number of steps to optimize over. Default is 200.
    use_jax : bool, optional
        Whether to use Jax instead of PyTorch. Default is ``False``.
    use_analytical : bool, optional
        Whether to use analytical gradients instead of PyTorch autograd. Ignored if
        ``use_jax=True``. Default is ``True``.
    verbose : bool, optional
        Print initial and final similarity scores with scores every 100 steps.
        Default is ``False``.

    Returns
    -------
    aligned_fit_points : np.ndarray
        Coordinates of transformed surface points. Shape: (M, 3).

    Raises
    ------
    ValueError
        If the pair has no surface setting or either molecule has no surface points.
    """
    lam_scaled = LAM_SCALING * lam
    if (self.num_surf_points is None
            or self.ref_molec.surf_pos is None
            or self.fit_molec.surf_pos is None):
        raise ValueError('The Molecule objects were initialized with no surface points so this method cannot be used.')

    if use_jax:  # Use Jax optimization implementation
        try:
            import jax.numpy as jnp
        except ImportError as err:
            raise ImportError('jax.numpy and torch is required for this function. Install Jax or just use Torch.') from err
        from shepherd_score.alignment_jax import optimize_ROCS_esp_overlay_jax
        aligned_fit_points, se3_transform, score = optimize_ROCS_esp_overlay_jax(
            ref_points=jnp.array(self.ref_molec.surf_pos),
            fit_points=jnp.array(self.fit_molec.surf_pos),
            ref_charges=jnp.array(self.ref_molec.surf_esp),
            fit_charges=jnp.array(self.fit_molec.surf_esp),
            alpha=alpha,
            lam=lam_scaled,
            num_repeats=num_repeats,
            trans_centers=self.ref_molec.atom_pos if trans_init else None,
            lr=lr,
            max_num_steps=max_num_steps,
            verbose=verbose
        )
        self.transform_esp = np.array(se3_transform)
        self.sim_aligned_esp = np.array(score)
        return np.array(aligned_fit_points)

    # Use Torch implementation
    def _to_torch(arr):
        # float32 tensor on the pair's device.
        return torch.from_numpy(np.asarray(arr)).to(dtype=torch.float32, device=self.device)

    _esp_fn = optimize_ROCS_esp_overlay_analytical if use_analytical else optimize_ROCS_esp_overlay
    aligned_fit_points, se3_transform, score = _esp_fn(
        ref_points=_to_torch(self.ref_molec.surf_pos),
        fit_points=_to_torch(self.fit_molec.surf_pos),
        ref_charges=_to_torch(self.ref_molec.surf_esp),
        fit_charges=_to_torch(self.fit_molec.surf_esp),
        alpha=alpha,
        lam=lam_scaled,
        num_repeats=num_repeats,
        trans_centers=_to_torch(self.ref_molec.atom_pos) if trans_init else None,
        lr=lr,
        max_num_steps=max_num_steps,
        verbose=verbose
    )
    # Tensor.numpy() only works on CPU tensors that don't require grad, and
    # self.device may be CUDA, so detach and move to host first.
    self.transform_esp = se3_transform.detach().cpu().numpy()
    self.sim_aligned_esp = score.detach().cpu().numpy()
    return aligned_fit_points.detach().cpu().numpy()
[docs]
def align_with_esp_combo(self,
                         alpha: float,
                         lam: float = 0.001,
                         probe_radius: float = 1.0,
                         esp_weight: float = 0.5,
                         num_repeats: int = 50,
                         trans_init: bool = False,
                         lr: float = 0.1,
                         max_num_steps: int = 200,
                         use_jax: bool = False,
                         verbose: bool = False):
    """
    Align using the ShaEP similarity score.

    If ``alpha`` is exactly 0.81, the shape term uses the molecules' ``atom_pos``
    (volumetric shape similarity); otherwise it uses the surface points.
    Optimally aligned score found in ``self.sim_aligned_esp_combo`` and the optimal SE(3)
    transformation is at ``self.transform_esp_combo``.

    Parameters
    ----------
    alpha : float
        Gaussian width parameter for overlap.
    lam : float, optional
        ESP weighting parameter. Default is 0.001.
    probe_radius : float, optional
        Surface points found within vdW radii + probe radius will be masked out.
        Surface generation uses a probe radius of 1.2 by default (radius of hydrogen),
        so a slightly lower radius is used to be more tolerant. Default is 1.0.
    esp_weight : float, optional
        How much to weight shape vs esp_combo similarity ([0,1]). Default is 0.5.
    num_repeats : int, optional
        Number of different random initializations of SO(3) transformation parameters. Default is 50.
    trans_init : bool, optional
        Apply translation initialization for alignment. ``fit_molec``'s COM is translated
        to each of ``ref_molec``'s ``atom_pos`` coordinates, with 10 rotations for each
        translation. So the number of initializations scales as (# translation centers * 10 + 5)
        where 5 is from the identity and 4 PCA with aligned COMs. If ``False``, then
        ``num_repeats`` rotations are done with aligned COMs. Default is ``False``.
    lr : float, optional
        Learning rate or step-size for optimization. Default is 0.1.
    max_num_steps : int, optional
        Maximum number of steps to optimize over. Default is 200.
    use_jax : bool, optional
        Whether to use Jax instead of PyTorch. Default is ``False``.
    verbose : bool, optional
        Print initial and final similarity scores with scores every 100 steps.
        Default is ``False``.

    Returns
    -------
    aligned_fit_points : np.ndarray
        Transformed points returned by the optimizer, converted to NumPy.

    Raises
    ------
    ValueError
        If the pair has no surface setting or either molecule has no surface points.
    """
    if (self.num_surf_points is None
            or self.ref_molec.surf_pos is None
            or self.fit_molec.surf_pos is None):
        raise ValueError('The Molecule objects were initialized with no surface points so this method cannot be used.')

    # alpha == 0.81 is the volumetric width used elsewhere in this class; it selects
    # atom centers rather than surface points for the shape term.
    use_atom_centers = alpha == 0.81
    ref_shape_src = self.ref_molec.atom_pos if use_atom_centers else self.ref_molec.surf_pos
    fit_shape_src = self.fit_molec.atom_pos if use_atom_centers else self.fit_molec.surf_pos

    if use_jax:  # Use Jax optimization implementation
        try:
            import jax.numpy as jnp
        except ImportError as err:
            raise ImportError('jax.numpy and torch is required for this function. Install Jax or just use Torch.') from err
        from shepherd_score.alignment_jax import optimize_esp_combo_score_overlay_jax
        aligned_fit_points, se3_transform, score = optimize_esp_combo_score_overlay_jax(
            ref_centers_w_H=jnp.array(self.ref_molec.mol.GetConformer().GetPositions()),
            fit_centers_w_H=jnp.array(self.fit_molec.mol.GetConformer().GetPositions()),
            ref_centers=jnp.array(ref_shape_src),
            fit_centers=jnp.array(fit_shape_src),
            ref_points=jnp.array(self.ref_molec.surf_pos),
            fit_points=jnp.array(self.fit_molec.surf_pos),
            ref_partial_charges=jnp.array(self.ref_molec.partial_charges),
            fit_partial_charges=jnp.array(self.fit_molec.partial_charges),
            ref_surf_esp=jnp.array(self.ref_molec.surf_esp),
            fit_surf_esp=jnp.array(self.fit_molec.surf_esp),
            ref_radii=jnp.array(self.ref_molec.radii),
            fit_radii=jnp.array(self.fit_molec.radii),
            alpha=alpha,
            lam=lam,
            probe_radius=probe_radius,
            esp_weight=esp_weight,
            num_repeats=num_repeats,
            trans_centers=self.ref_molec.atom_pos if trans_init else None,
            lr=lr,
            max_num_steps=max_num_steps,
            verbose=verbose
        )
        self.transform_esp_combo = np.array(se3_transform)
        self.sim_aligned_esp_combo = np.array(score)
        return np.array(aligned_fit_points)

    # Use Torch implementation
    def _to_torch(arr):
        # float32 tensor on the pair's device.
        return torch.from_numpy(np.asarray(arr)).to(dtype=torch.float32, device=self.device)

    aligned_fit_points, se3_transform, score = optimize_esp_combo_score_overlay(
        ref_centers_w_H=_to_torch(self.ref_molec.mol.GetConformer().GetPositions()),
        fit_centers_w_H=_to_torch(self.fit_molec.mol.GetConformer().GetPositions()),
        ref_centers=_to_torch(ref_shape_src),
        fit_centers=_to_torch(fit_shape_src),
        ref_points=_to_torch(self.ref_molec.surf_pos),
        fit_points=_to_torch(self.fit_molec.surf_pos),
        ref_partial_charges=_to_torch(self.ref_molec.partial_charges),
        fit_partial_charges=_to_torch(self.fit_molec.partial_charges),
        ref_surf_esp=_to_torch(self.ref_molec.surf_esp),
        fit_surf_esp=_to_torch(self.fit_molec.surf_esp),
        ref_radii=_to_torch(self.ref_molec.radii),
        fit_radii=_to_torch(self.fit_molec.radii),
        alpha=alpha,
        lam=lam,
        probe_radius=probe_radius,
        esp_weight=esp_weight,
        num_repeats=num_repeats,
        trans_centers=_to_torch(self.ref_molec.atom_pos) if trans_init else None,
        lr=lr,
        max_num_steps=max_num_steps,
        verbose=verbose
    )
    # Tensor.numpy() only works on CPU tensors that don't require grad, and
    # self.device may be CUDA, so detach and move to host first.
    self.transform_esp_combo = se3_transform.detach().cpu().numpy()
    self.sim_aligned_esp_combo = score.detach().cpu().numpy()
    return aligned_fit_points.detach().cpu().numpy()
[docs]
def align_with_pharm(self,
                     similarity: _SIM_TYPE = 'tanimoto',
                     extended_points: bool = False,
                     only_extended: bool = False,
                     num_repeats: int = 50,
                     trans_init: bool = False,
                     lr: float = 0.1,
                     max_num_steps: int = 200,
                     use_jax: bool = False,
                     verbose: bool = False,
                     use_vectorized: bool = True,
                     use_analytical: bool = True,
                     ) -> Tuple[np.ndarray, np.ndarray]:
    """
    Align fit_molec to ref_molec using pharmacophore similarity.

    The optimally aligned score is stored in ``self.sim_aligned_pharm`` and the optimal
    SE(3) transformation in ``self.transform_pharm``.

    Parameters
    ----------
    similarity : str from ('tanimoto', 'tversky', 'tversky_ref', 'tversky_fit')
        Specifies what similarity function to use. Options are:
        'tanimoto' -- symmetric scoring function
        'tversky' -- asymmetric -> Uses OpenEye's formulation 95% normalization by molec 1
        'tversky_ref' -- asymmetric -> Uses Pharao's formulation 100% normalization by molec 1.
        'tversky_fit' -- asymmetric -> Uses Pharao's formulation 100% normalization by molec 2.
    extended_points : bool, optional
        Whether to score HBA/HBD with gaussian overlaps of extended points. Default is ``False``.
    only_extended : bool, optional
        When ``extended_points`` is ``True``, decide whether to only score the extended points
        (ignore anchor overlaps). Default is ``False``.
    num_repeats : int, optional
        Number of different random initializations of SO(3) transformation parameters.
        Default is 50.
    trans_init : bool, optional
        Apply translation initialization for alignment. ``fit_molec``'s COM is translated to
        each of ``ref_molec``'s pharmacophores, with 10 rotations for each translation. So the
        number of initializations scales as (# translation centers * 10 + 5) where 5 is from
        the identity and 4 PCA with aligned COM's. If ``False``, then ``num_repeats`` rotations
        are done with aligned COM's. Default is ``False``.
    lr : float, optional
        Learning rate or step-size for optimization. Default is 0.1.
    max_num_steps : int, optional
        Maximum number of steps to optimize over. Default is 200.
    use_jax : bool, optional
        Whether to use Jax instead of PyTorch. Default is ``False``.
    verbose : bool, optional
        Print initial and final similarity scores with scores every 100 steps.
        Default is ``False``.
    use_vectorized : bool, optional
        Whether to use the vectorized version of the pharmacophore scoring function.
        This is only relevant if ``use_jax=True``. Default is ``True``.
    use_analytical : bool, optional
        Whether to use the analytical version of the pharmacophore scoring function.
        Currently only implemented for PyTorch (ignored when ``use_jax=True``).
        Default is ``True``.

    Returns
    -------
    tuple
        aligned_fit_anchors : np.ndarray
            Aligned coordinates of pharmacophores positions. Shape: (P, 3).
        aligned_fit_vectors : np.ndarray
            Aligned coordinates of pharmacophore vectors. Shape: (P, 3).

    Raises
    ------
    ImportError
        If ``use_jax=True`` but ``jax.numpy`` cannot be imported.
    """
    if use_jax:
        # Python's import machinery already caches modules, so a single guarded
        # import is enough; chain the original error for easier debugging.
        try:
            import jax.numpy as jnp
        except ImportError as err:
            raise ImportError('jax.numpy and torch is required for this function. Install Jax or just use Torch.') from err
        from shepherd_score.alignment_jax import optimize_pharm_overlay_jax, optimize_pharm_overlay_jax_vectorized

        _pharm_fn = optimize_pharm_overlay_jax_vectorized if use_vectorized else optimize_pharm_overlay_jax
        aligned_fit_anchors, aligned_fit_vectors, se3_transform, score = _pharm_fn(
            ref_pharms=jnp.array(self.ref_molec.pharm_types),
            fit_pharms=jnp.array(self.fit_molec.pharm_types),
            ref_anchors=jnp.array(self.ref_molec.pharm_ancs),
            fit_anchors=jnp.array(self.fit_molec.pharm_ancs),
            ref_vectors=jnp.array(self.ref_molec.pharm_vecs),
            fit_vectors=jnp.array(self.fit_molec.pharm_vecs),
            similarity=similarity,
            extended_points=extended_points,
            only_extended=only_extended,
            num_repeats=num_repeats,
            trans_centers=self.ref_molec.pharm_ancs if trans_init else None,
            lr=lr,
            max_num_steps=max_num_steps,
            verbose=verbose
        )
        self.transform_pharm = np.array(se3_transform)
        self.sim_aligned_pharm = np.array(score)
        return np.array(aligned_fit_anchors), np.array(aligned_fit_vectors)

    # PyTorch
    def _to_device(arr: np.ndarray) -> torch.Tensor:
        # Inputs are cast to float32 and placed on the pair's configured device.
        return torch.from_numpy(arr).to(torch.float32).to(self.device)

    def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
        # Results may live on ``self.device`` (e.g. a GPU) and/or carry autograd
        # history; ``Tensor.numpy()`` only works on detached CPU tensors.
        return tensor.detach().cpu().numpy()

    _pharm_fn = optimize_pharm_overlay_analytical if use_analytical else optimize_pharm_overlay
    aligned_fit_anchors, aligned_fit_vectors, se3_transform, score = _pharm_fn(
        ref_pharms=_to_device(self.ref_molec.pharm_types),
        fit_pharms=_to_device(self.fit_molec.pharm_types),
        ref_anchors=_to_device(self.ref_molec.pharm_ancs),
        fit_anchors=_to_device(self.fit_molec.pharm_ancs),
        ref_vectors=_to_device(self.ref_molec.pharm_vecs),
        fit_vectors=_to_device(self.fit_molec.pharm_vecs),
        similarity=similarity,
        extended_points=extended_points,
        only_extended=only_extended,
        num_repeats=num_repeats,
        trans_centers=_to_device(self.ref_molec.pharm_ancs) if trans_init else None,
        lr=lr,
        max_num_steps=max_num_steps,
        verbose=verbose
    )
    self.transform_pharm = _to_numpy(se3_transform)
    self.sim_aligned_pharm = _to_numpy(score)
    return _to_numpy(aligned_fit_anchors), _to_numpy(aligned_fit_vectors)
[docs]
def score_with_surf(self,
                    alpha: float,
                    use: str = 'np'
                    ) -> np.ndarray:
    """
    Compute the surface-overlap similarity of fit_molec against ref_molec.

    No optimization is performed: the two surfaces are scored in their current poses.
    NumPy is used unless another backend is requested.

    Parameters
    ----------
    alpha : float
        Gaussian width parameter for overlap.
    use : str, optional
        Backend to compute with (case-insensitive):
        - 'np' or 'numpy' (numpy implementation)
        - 'jax' or 'jnp' (Jax implementation)
        - 'torch' or 'pytorch' (PyTorch implementation)
        Default is 'np'.

    Returns
    -------
    score : np.ndarray
        Similarity score. Shape: (1,).

    Raises
    ------
    ValueError
        If ``use`` names an unknown backend, or if ``self.num_surf_points`` is ``None``
        (the Molecule objects carry no surface points).
    ImportError
        If the Jax backend is requested but ``jax.numpy`` cannot be imported.
    """
    backend = use.lower()
    accepted_keys = ('jax', 'jnp', 'torch', 'pytorch', 'np', 'numpy')
    if backend not in accepted_keys:
        raise ValueError(f"`use` must be in {accepted_keys}. Instead {backend} was passed.")
    if self.num_surf_points is None:
        raise ValueError('The Molecule objects were initialized with no surface points so this method cannot be used.')

    if backend in ('np', 'numpy'):
        return get_overlap_np(
            centers_1=self.ref_molec.surf_pos,
            centers_2=self.fit_molec.surf_pos,
            alpha=alpha,
        )

    if backend in ('torch', 'pytorch'):
        def _as_tensor(arr):
            # float32 copy placed on the pair's configured device
            return torch.from_numpy(arr).to(torch.float32).to(self.device)

        overlap = get_overlap(
            centers_1=_as_tensor(self.ref_molec.surf_pos),
            centers_2=_as_tensor(self.fit_molec.surf_pos),
            alpha=alpha,
        )
        return overlap.cpu().numpy()

    # Only 'jax' / 'jnp' remain at this point.
    if not all(module in sys.modules for module in ('jax', 'jax.numpy')):
        try:
            import jax.numpy as jnp
        except ImportError:
            raise ImportError('jax.numpy and torch is required for this function. Install Jax or just use Torch.')
    import jax.numpy as jnp
    from shepherd_score.score.gaussian_overlap_jax import get_overlap_jax
    overlap = get_overlap_jax(
        centers_1=jnp.array(self.ref_molec.surf_pos),
        centers_2=jnp.array(self.fit_molec.surf_pos),
        alpha=alpha,
    )
    return np.array(overlap)
[docs]
def score_with_esp(self,
                   alpha: float,
                   lam: float = 0.3,
                   use: str = 'np'
                   ) -> np.ndarray:
    """
    Compute the combined ESP + surface similarity of fit_molec against ref_molec.

    The molecules are scored in their current poses (no alignment is performed).
    ``lam`` is multiplied by ``LAM_SCALING`` (``(1e4/(4*55.263*np.pi))**2``) for
    correct units, so the customary ``lam = 0.3`` is passed unscaled.
    NumPy is used unless another backend is requested.

    Parameters
    ----------
    alpha : float
        Gaussian width parameter for overlap.
    lam : float, optional
        Weighting factor for ESP scoring (before internal scaling). Default is 0.3.
    use : str, optional
        Backend to compute with (case-insensitive):
        - 'np' or 'numpy' (numpy implementation)
        - 'jax' or 'jnp' (Jax implementation)
        - 'torch' or 'pytorch' (PyTorch implementation)
        Default is 'np'.

    Returns
    -------
    score : np.ndarray
        Similarity score. Shape: (1,).

    Raises
    ------
    ValueError
        If ``use`` names an unknown backend, or if ``self.num_surf_points`` is ``None``.
    ImportError
        If the Jax backend is requested but ``jax.numpy`` cannot be imported.
    """
    lam_scaled = LAM_SCALING * lam
    backend = use.lower()
    accepted_keys = ('jax', 'jnp', 'torch', 'pytorch', 'np', 'numpy')
    if backend not in accepted_keys:
        raise ValueError(f"`use` must be in {accepted_keys}. Instead {backend} was passed.")
    if self.num_surf_points is None:
        raise ValueError('The Molecule objects were initialized with no surface points so this method cannot be used.')

    if backend in ('np', 'numpy'):
        return get_overlap_esp_np(
            centers_1=self.ref_molec.surf_pos,
            centers_2=self.fit_molec.surf_pos,
            charges_1=self.ref_molec.surf_esp,
            charges_2=self.fit_molec.surf_esp,
            alpha=alpha,
            lam=lam_scaled,
        )

    if backend in ('torch', 'pytorch'):
        def _as_tensor(arr):
            # float32 copy placed on the pair's configured device
            return torch.from_numpy(arr).to(torch.float32).to(self.device)

        overlap = get_overlap_esp(
            centers_1=_as_tensor(self.ref_molec.surf_pos),
            centers_2=_as_tensor(self.fit_molec.surf_pos),
            charges_1=_as_tensor(self.ref_molec.surf_esp),
            charges_2=_as_tensor(self.fit_molec.surf_esp),
            alpha=alpha,
            lam=lam_scaled,
        )
        return overlap.cpu().numpy()

    # Only 'jax' / 'jnp' remain at this point.
    if not all(module in sys.modules for module in ('jax', 'jax.numpy')):
        try:
            import jax.numpy as jnp
        except ImportError:
            raise ImportError('jax.numpy and torch is required for this function. Install Jax or just use Torch.')
    import jax.numpy as jnp
    from shepherd_score.score.electrostatic_scoring_jax import get_overlap_esp_jax
    overlap = get_overlap_esp_jax(
        centers_1=jnp.array(self.ref_molec.surf_pos),
        centers_2=jnp.array(self.fit_molec.surf_pos),
        charges_1=jnp.array(self.ref_molec.surf_esp),
        charges_2=jnp.array(self.fit_molec.surf_esp),
        alpha=alpha,
        lam=lam_scaled,
    )
    return np.array(overlap)
[docs]
def score_with_pharm(self,
                     similarity: _SIM_TYPE = 'tanimoto',
                     extended_points: bool = False,
                     only_extended: bool = False,
                     use: str = 'np'
                     ) -> np.ndarray:
    """
    Compute the pharmacophore similarity of fit_molec against ref_molec.

    The molecules are scored in their current poses (no alignment is performed).
    NumPy is used unless another backend is requested.

    Parameters
    ----------
    similarity : str from ('tanimoto', 'tversky', 'tversky_ref', 'tversky_fit')
        Specifies what similarity function to use. Options are:
        'tanimoto' -- symmetric scoring function
        'tversky' -- asymmetric -> Uses OpenEye's formulation 95% normalization by molec 1
        'tversky_ref' -- asymmetric -> Uses Pharao's formulation 100% normalization by molec 1.
        'tversky_fit' -- asymmetric -> Uses Pharao's formulation 100% normalization by molec 2.
    extended_points : bool, optional
        Whether to score HBA/HBD with gaussian overlaps of extended points.
        Default is ``False``.
    only_extended : bool, optional
        When ``extended_points`` is ``True``, decide whether to only score the extended
        points (ignore anchor overlaps). Default is ``False``.
    use : str, optional
        Backend to compute with (case-insensitive):
        - 'np' or 'numpy' (numpy implementation)
        - 'jax' or 'jnp' (Jax implementation)
        - 'torch' or 'pytorch' (PyTorch implementation)
        Default is 'np'.

    Returns
    -------
    score : np.ndarray
        Similarity score. Shape: (1,).

    Raises
    ------
    ValueError
        If ``use`` names an unknown backend.
    ImportError
        If the Jax backend is requested but ``jax.numpy`` cannot be imported.
    """
    backend = use.lower()
    accepted_keys = ('jax', 'jnp', 'torch', 'pytorch', 'np', 'numpy')
    if backend not in accepted_keys:
        raise ValueError(f"`use` must be in {accepted_keys}. Instead {backend} was passed.")

    if backend in ('np', 'numpy'):
        return get_overlap_pharm_np(
            ptype_1=self.ref_molec.pharm_types,
            ptype_2=self.fit_molec.pharm_types,
            anchors_1=self.ref_molec.pharm_ancs,
            anchors_2=self.fit_molec.pharm_ancs,
            vectors_1=self.ref_molec.pharm_vecs,
            vectors_2=self.fit_molec.pharm_vecs,
            similarity=similarity,
            extended_points=extended_points,
            only_extended=only_extended
        )

    if backend in ('torch', 'pytorch'):
        def _as_tensor(arr):
            # float32 copy placed on the pair's configured device
            return torch.from_numpy(arr).to(torch.float32).to(self.device)

        overlap = get_overlap_pharm(
            ptype_1=_as_tensor(self.ref_molec.pharm_types),
            ptype_2=_as_tensor(self.fit_molec.pharm_types),
            anchors_1=_as_tensor(self.ref_molec.pharm_ancs),
            anchors_2=_as_tensor(self.fit_molec.pharm_ancs),
            vectors_1=_as_tensor(self.ref_molec.pharm_vecs),
            vectors_2=_as_tensor(self.fit_molec.pharm_vecs),
            similarity=similarity,
            extended_points=extended_points,
            only_extended=only_extended
        )
        return overlap.cpu().numpy()

    # Only 'jax' / 'jnp' remain at this point.
    if not all(module in sys.modules for module in ('jax', 'jax.numpy')):
        try:
            import jax.numpy as jnp
        except ImportError:
            raise ImportError('jax.numpy and torch is required for this function. Install Jax or just use Torch.')
    import jax.numpy as jnp
    from shepherd_score.score.pharmacophore_scoring_jax import get_overlap_pharm_jax
    overlap = get_overlap_pharm_jax(
        ptype_1=jnp.array(self.ref_molec.pharm_types),
        ptype_2=jnp.array(self.fit_molec.pharm_types),
        anchors_1=jnp.array(self.ref_molec.pharm_ancs),
        anchors_2=jnp.array(self.fit_molec.pharm_ancs),
        vectors_1=jnp.array(self.ref_molec.pharm_vecs),
        vectors_2=jnp.array(self.fit_molec.pharm_vecs),
        similarity=similarity,
        extended_points=extended_points,
        only_extended=only_extended
    )
    return np.array(overlap)