Alignment Utilities#

SE3 Transformations (PyTorch)#

Functions used for SE(3) transformations (Torch implementation), with support for batched operations.

Namely, converting quaternions to rotation matrices, getting an SE(3) transform from SE(3) parameters, and applying the SE(3) transformation on a set of points.

Credit to Lewis J. Martin as this was adapted from ljmartin/align and PyTorch’s implementations.

shepherd_score.alignment.utils.se3.quaternions_to_rotation_matrix(quaternions)[source]#

Converts quaternion to a rotation matrix. Supports batched and non-batched inputs. Adapted from PyTorch3D: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#quaternion_to_matrix

Parameters:

quaternions (torch.Tensor (batch, 4) or (4,)) – Quaternion parameters in (r,i,j,k) order. Accepts single set of parameters or a batched set.

Returns:

rotation_matrix – Rotation matrix converted from quaternion in batched or single instance form.

Return type:

torch.Tensor (batch, 3, 3) or (3,3)

shepherd_score.alignment.utils.se3.get_SE3_transform(se3_params)[source]#

Constructs an SE(3) transformation matrix from parameters. Supports batched and non-batched inputs.

Parameters:

se3_params (torch.Tensor (batch, 7) or (7,)) – Parameters for SE(3) transformation. The first 4 values in the last dimension are quaternions of form (r,i,j,k) and the last 3 values of the last dimension are the translations in (x,y,z).

Returns:

se3_matrix – se3_params converted to a 4x4 SE(3) transformation matrix.

Return type:

torch.Tensor (batch, 4, 4) or (4, 4)

shepherd_score.alignment.utils.se3.apply_SE3_transform(points, SE3_transform)[source]#

Takes a point cloud and transforms it according to the provided SE3 transformation matrix. Supports batched and non-batched inputs.

Parameters:
  • points (torch.Tensor (batch, N, 3) or (N, 3)) – Set of coordinates representing a point cloud.

  • SE3_transform (torch.Tensor (batch, 4, 4) or (4, 4)) – SE(3) transformation matrix. If ‘points’ argument is batched, this one should be too.

Returns:

transformed_points – Set of coordinates transformed by the corresponding SE(3) transformation.

Return type:

torch.Tensor (batch, N, 3) or (N, 3)

shepherd_score.alignment.utils.se3.apply_SO3_transform(points, SE3_transform)[source]#

Takes a point cloud and ONLY ROTATES it according to the provided SE3 transformation matrix. Supports batched and non-batched inputs.

Parameters:
  • points (torch.Tensor (batch, N, 3) or (N, 3)) – Set of coordinates representing a point cloud.

  • SE3_transform (torch.Tensor (batch, 4, 4) or (4, 4)) – SE(3) transformation matrix. If ‘points’ argument is batched, this one should be too.

Returns:

rotated_points – Set of coordinates rotated by the rotation component of the SE(3) transformation.

Return type:

torch.Tensor (batch, N, 3) or (N, 3)

SE3 Transformations (NumPy)#

Functions used for SE(3) transformations. (NumPy implementation).

Namely, converting quaternions to rotation matrices, getting an SE(3) transform from SE(3) parameters, and applying the SE(3) transformation on a set of points.

Credit to Lewis J. Martin as this was adapted from ljmartin/align and PyTorch’s implementations.

shepherd_score.alignment.utils.se3_np.quaternions_to_rotation_matrix_np(quaternions)[source]#

Converts quaternion to a rotation matrix. Adapted from PyTorch3D: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#quaternion_to_matrix

Parameters:

quaternions (np.ndarray (4,)) – Quaternion parameters in (r,i,j,k) order.

Returns:

rotation_matrix – Rotation matrix converted from quaternion.

Return type:

np.ndarray (3,3)
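The conversion can be sketched in a few lines of NumPy. This is a minimal illustration that follows the PyTorch3D formula referenced above, not the library's own code; the helper name `quat_to_rotmat` is hypothetical.

```python
import numpy as np

def quat_to_rotmat(quaternion):
    """Sketch: convert an (r, i, j, k) quaternion to a 3x3 rotation matrix."""
    q = np.asarray(quaternion, dtype=float)
    r, i, j, k = q
    # Dividing by the squared norm means the input need not be unit length.
    two_s = 2.0 / np.dot(q, q)
    return np.array([
        [1 - two_s * (j * j + k * k), two_s * (i * j - k * r), two_s * (i * k + j * r)],
        [two_s * (i * j + k * r), 1 - two_s * (i * i + k * k), two_s * (j * k - i * r)],
        [two_s * (i * k - j * r), two_s * (j * k + i * r), 1 - two_s * (i * i + j * j)],
    ])

# A 90-degree rotation about the z-axis maps the x-axis onto the y-axis.
half_angle = np.pi / 4
R = quat_to_rotmat([np.cos(half_angle), 0.0, 0.0, np.sin(half_angle)])
x_rotated = R @ np.array([1.0, 0.0, 0.0])
```

The identity quaternion (1, 0, 0, 0) yields the identity matrix, which makes a convenient sanity check.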

shepherd_score.alignment.utils.se3_np.get_SE3_transform_np(se3_params)[source]#

Constructs an SE(3) transformation matrix from parameters. NumPy implementation.

Parameters:

se3_params (np.ndarray (7,)) – Parameters for SE(3) transformation. The first 4 values in the last dimension are quaternions of form (r,i,j,k) and the last 3 values of the last dimension are the translations in (x,y,z).

Returns:

se3_matrix – se3_params converted to a 4x4 SE(3) transformation matrix.

Return type:

np.ndarray (4, 4)
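As a rough sketch of how the seven parameters map onto a 4x4 matrix: the quaternion fills the upper-left 3x3 rotation block and the translation fills the last column. The column placement is the common column-vector convention and is assumed here rather than taken from the library source; both helper names are hypothetical.

```python
import numpy as np

def quat_to_rotmat(quaternion):
    """Sketch: (r, i, j, k) quaternion to a 3x3 rotation matrix."""
    q = np.asarray(quaternion, dtype=float)
    r, i, j, k = q
    two_s = 2.0 / np.dot(q, q)
    return np.array([
        [1 - two_s * (j * j + k * k), two_s * (i * j - k * r), two_s * (i * k + j * r)],
        [two_s * (i * j + k * r), 1 - two_s * (i * i + k * k), two_s * (j * k - i * r)],
        [two_s * (i * k - j * r), two_s * (j * k + i * r), 1 - two_s * (i * i + j * j)],
    ])

def se3_from_params(se3_params):
    """Sketch: build a 4x4 matrix from (r, i, j, k, x, y, z)."""
    params = np.asarray(se3_params, dtype=float)
    se3_matrix = np.eye(4)
    se3_matrix[:3, :3] = quat_to_rotmat(params[:4])
    se3_matrix[:3, 3] = params[4:]  # assumption: translation in the last column
    return se3_matrix

# Identity rotation combined with a translation of (1, 2, 3).
se3_matrix = se3_from_params([1.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0])
```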

shepherd_score.alignment.utils.se3_np.apply_SE3_transform_np(points, SE3_transform)[source]#

Takes a point cloud and transforms it according to the provided SE3 transformation matrix. NumPy implementation.

Parameters:
  • points (np.ndarray (N, 3)) – Set of coordinates representing a point cloud.

  • SE3_transform (np.ndarray (4, 4)) – SE(3) transformation matrix.

Returns:

transformed_points – Set of coordinates transformed by the corresponding SE(3) transformation.

Return type:

np.ndarray (N, 3)
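Applying the transform amounts to rotating each point and then adding the translation. The sketch below assumes the translation sits in the last column of the matrix and uses a hypothetical helper name; it also shows the equivalent homogeneous-coordinate form.

```python
import numpy as np

def apply_se3(points, se3_matrix):
    """Sketch: rotate, then translate, an (N, 3) point cloud."""
    rotation = se3_matrix[:3, :3]
    translation = se3_matrix[:3, 3]  # assumption: translation in the last column
    return points @ rotation.T + translation

# 90-degree rotation about z followed by a shift of +1 along x.
se3_matrix = np.array([
    [0.0, -1.0, 0.0, 1.0],
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
])
points = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
transformed = apply_se3(points, se3_matrix)

# Equivalent formulation with homogeneous coordinates (x, y, z, 1).
homogeneous = np.hstack([points, np.ones((len(points), 1))])
transformed_h = (homogeneous @ se3_matrix.T)[:, :3]
```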

shepherd_score.alignment.utils.se3_np.apply_SO3_transform_np(points, SE3_transform)[source]#

Takes a point cloud and ONLY ROTATES it according to the provided SE3 transformation matrix. NumPy implementation.

Parameters:
  • points (np.ndarray (N, 3)) – Set of coordinates representing a point cloud.

  • SE3_transform (np.ndarray (4, 4)) – SE(3) transformation matrix.

Returns:

rotated_points – Set of coordinates rotated by the rotation component of the SE(3) transformation.

Return type:

np.ndarray (N, 3)
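Rotation-only application uses just the upper-left 3x3 block and ignores the translation, which is what one generally wants for direction vectors that carry no position. A minimal sketch with a hypothetical helper name, again assuming the translation lives in the last column:

```python
import numpy as np

def apply_so3(points, se3_matrix):
    """Sketch: apply only the rotation block of a 4x4 SE(3) matrix."""
    return points @ se3_matrix[:3, :3].T

# Same matrix layout as a full SE(3) transform; the translation is ignored.
se3_matrix = np.array([
    [0.0, -1.0, 0.0, 1.0],
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
])
vectors = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
rotated = apply_so3(vectors, se3_matrix)

# Adding the translation column back recovers the full SE(3) result.
fully_transformed = rotated + se3_matrix[:3, 3]
```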

SE3 Transformations (JAX)#

Functions used for SE(3) transformations. (Jax implementation).

Namely, converting quaternions to rotation matrices, getting an SE(3) transform from SE(3) parameters, and applying the SE(3) transformation on a set of points.

Credit to Lewis J. Martin as this was adapted from ljmartin/align and PyTorch’s implementations.

shepherd_score.alignment.utils.se3_jax.quaternions_to_rotation_matrix_jax(quaternions)[source]#

Converts quaternion to a rotation matrix. Jax implementation. Adapted from PyTorch3D: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html#quaternion_to_matrix

Parameters:

quaternions (Array (4,)) – Quaternion parameters in (r,i,j,k) order.

Returns:

rotation_matrix – Rotation matrix converted from quaternion.

Return type:

Array (3,3)

shepherd_score.alignment.utils.se3_jax.get_SE3_transform_jax(se3_params)[source]#

Constructs an SE(3) transformation matrix from parameters. Jax implementation.

Parameters:

se3_params (Array (7,)) – Parameters for SE(3) transformation. The first 4 values in the last dimension are quaternions of form (r,i,j,k) and the last 3 values of the last dimension are the translations in (x,y,z).

Returns:

se3_matrix – se3_params converted to a 4x4 SE(3) transformation matrix.

Return type:

Array (4, 4)

shepherd_score.alignment.utils.se3_jax.apply_SE3_transform_jax(points, SE3_transform)[source]#

Takes a point cloud and transforms it according to the provided SE3 transformation matrix. Jax implementation.

Parameters:
  • points (Array (N, 3)) – Set of coordinates representing a point cloud.

  • SE3_transform (Array (4, 4)) – SE(3) transformation matrix.

Returns:

transformed_points – Set of coordinates transformed by the corresponding SE(3) transformation.

Return type:

Array (N, 3)

PCA Utilities (PyTorch)#

Torch implementations for principal component alignment (pca). Written to handle batching.

IT IS RECOMMENDED TO USE THE NUMPY VERSION. The NumPy version of quaternions_for_principal_component_alignment is faster (~2.5 ms vs. ~5 ms). Further, the parallel GPU version is slightly slower (50 µs) than the serial CPU version.

Credit to Lewis J. Martin as this was adapted from ljmartin/align

shepherd_score.alignment.utils.pca.compute_moment_of_inertia(points)[source]#

Computes the moment of inertia of a set of points, using A = x^2 + y^2 + z^2 and B = X^T X.

Parameters:

points (torch.Tensor)

Return type:

torch.Tensor

shepherd_score.alignment.utils.pca.compute_principal_moments_of_interia(points)[source]#

Compute the principal moments of inertia of a set of points.

Parameters:

points (torch.Tensor)

Return type:

torch.Tensor

shepherd_score.alignment.utils.pca.angle_between_vecs(v1, v2)[source]#

Compute the angle in radians between two vectors (already normalized).

shepherd_score.alignment.utils.pca.rotation_axis(v1, v2)[source]#

Calculate the axis about which to rotate v1 to align it with v2 (cross product).

shepherd_score.alignment.utils.pca.quaternion_from_axis_angle(axis, angle)[source]#

Create a Quaternion from a rotation axis and an angle in radians.

Parameters:
  • axis (torch.Tensor (3,)) – Axis to rotate about.

  • angle (torch.Tensor (1,)) – Angle in radians.

Returns:

quaternion

Return type:

torch.Tensor (4,)

shepherd_score.alignment.utils.pca.quaternion_mult(p, q)[source]#

Multiplication of quaternions p and q.

Reference: https://academicflight.com/articles/kinematics/rotation-formalisms/quaternions/

General use case: applying rotation q_1 and then q_2 is equivalent to applying q_3 = q_2*q_1 (order matters).

Parameters:
  • p (torch.Tensor) – The first quaternion with shape (4,) or (batch, 4).

  • q (torch.Tensor) – The second quaternion with shape (4,) or (batch, 4).

Returns:

The product of the two quaternions with shape (4,) or (batch, 4).

Return type:

torch.Tensor

shepherd_score.alignment.utils.pca.quaternions_for_principal_component_alignment(ref_points, fit_points)[source]#

Computes the 4 quaternions required for alignment of the fit mol along the principal components of the reference mol.

The computed quaternions assume that fit_points will be rotated after being centered at COM.

Parameters:
  • ref_points (torch.Tensor)

  • fit_points (torch.Tensor)

Return type:

torch.Tensor

PCA Utilities (NumPy)#

Numpy Implementations for principal component alignment (pca). Using the numpy version of quaternions_for_principal_component_alignment is faster than the torch implementation (~2.5ms vs ~5ms).

Credit to Lewis J. Martin as this was adapted from ljmartin/align

shepherd_score.alignment.utils.pca_np.compute_moment_of_inertia_np(points)[source]#

Computes the moment of inertia tensor for a set of points, using A = x^2 + y^2 + z^2 and B = X^T X. Numpy implementation.

Parameters:

points (ndarray)

Return type:

ndarray
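The two quantities A and B can be combined into the textbook inertia tensor for unit-mass points, I = A·1 − B, where 1 is the 3x3 identity. The sketch below assumes that combination and unit masses; the docstring names A and B but does not state how they are combined, and the helper name is hypothetical.

```python
import numpy as np

def inertia_tensor(points):
    """Sketch: inertia tensor of unit-mass points, I = A * identity - B."""
    X = np.asarray(points, dtype=float)
    A = np.sum(X ** 2)  # sum over all points of x^2 + y^2 + z^2
    B = X.T @ X         # 3x3 matrix of second moments
    return A * np.eye(3) - B

# Two points on the x-axis: no resistance to rotation about x,
# equal resistance to rotation about y and z.
tensor = inertia_tensor([[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]])
```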

shepherd_score.alignment.utils.pca_np.compute_principal_moments_of_interia_np(points)[source]#

Calculate the principal moments of inertia. Numpy implementation.

Parameters:

points (ndarray)

Return type:

ndarray
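Principal moments are conventionally obtained as the eigenvalues of the inertia tensor, with the eigenvectors giving the principal axes. The sketch below shows that standard computation under the assumptions of unit masses and points centered at their center of mass; whether the library function returns the eigenvalues, the eigenvectors, or both is not stated here, and the helper name is hypothetical.

```python
import numpy as np

def principal_moments(points):
    """Sketch: eigen-decomposition of the unit-mass inertia tensor."""
    X = np.asarray(points, dtype=float)
    X = X - X.mean(axis=0)  # assumption: center at the center of mass first
    tensor = np.sum(X ** 2) * np.eye(3) - X.T @ X
    # eigh is appropriate for symmetric matrices; eigenvalues come back ascending.
    moments, axes = np.linalg.eigh(tensor)
    return moments, axes

points = np.array([
    [2.0, 0.0, 0.0], [-2.0, 0.0, 0.0],
    [0.0, 1.0, 0.0], [0.0, -1.0, 0.0],
])
moments, axes = principal_moments(points)
```

The columns of `axes` are the principal axes, ordered to match `moments`.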

shepherd_score.alignment.utils.pca_np.angle_between_vecs_np(v1, v2)[source]#

Compute the angle in radians between two vectors. Numpy implementation.

Return type:

ndarray

shepherd_score.alignment.utils.pca_np.rotation_axis_np(v1, v2)[source]#

Calculate the axis about which to rotate v1 to align it with v2 (cross product). Numpy implementation.

Return type:

ndarray
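Together, the angle and the rotation axis describe the rotation that carries one vector onto another. A minimal sketch of both computations, with hypothetical helper names and with inputs normalized inside the angle helper as a precaution:

```python
import numpy as np

def angle_between(v1, v2):
    """Sketch: angle in radians between two vectors."""
    u1 = np.asarray(v1, dtype=float) / np.linalg.norm(v1)
    u2 = np.asarray(v2, dtype=float) / np.linalg.norm(v2)
    # Clipping guards against round-off pushing the dot product outside [-1, 1].
    return np.arccos(np.clip(np.dot(u1, u2), -1.0, 1.0))

def rotation_axis(v1, v2):
    """Sketch: axis (not normalized) about which v1 rotates toward v2."""
    return np.cross(v1, v2)

x_axis = np.array([1.0, 0.0, 0.0])
y_axis = np.array([0.0, 1.0, 0.0])
angle = angle_between(x_axis, y_axis)
axis = rotation_axis(x_axis, y_axis)
```

Note that the cross product vanishes for parallel or antiparallel vectors, so that case needs separate handling in practice.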

shepherd_score.alignment.utils.pca_np.quaternion_from_axis_angle_np(axis, angle)[source]#

Create a Quaternion from a rotation axis and an angle in radians. Numpy implementation.

Parameters:
  • axis (np.ndarray (3,)) – Axis to rotate about.

  • angle (np.ndarray (1,)) – Angle in radians.

Returns:

quaternion

Return type:

np.ndarray (4,)
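For reference, the standard axis-angle construction is q = (cos(θ/2), sin(θ/2)·n) for a unit axis n. The sketch below follows that formula in (r, i, j, k) order; the helper name is hypothetical, and normalizing the axis inside the function is an assumption rather than documented behavior.

```python
import numpy as np

def quat_from_axis_angle(axis, angle):
    """Sketch: (r, i, j, k) quaternion from a rotation axis and an angle."""
    axis = np.asarray(axis, dtype=float)
    axis = axis / np.linalg.norm(axis)  # assumption: normalize the axis here
    half = 0.5 * np.ravel(angle)[0]     # accepts a scalar or a (1,) array
    return np.concatenate(([np.cos(half)], np.sin(half) * axis))

# 90-degree rotation about the z-axis.
q = quat_from_axis_angle([0.0, 0.0, 1.0], np.array([np.pi / 2]))
```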

shepherd_score.alignment.utils.pca_np.quaternion_mult_np(p, q)[source]#

Multiplication of quaternions p and q. Numpy implementation.

Reference: https://academicflight.com/articles/kinematics/rotation-formalisms/quaternions/

General use case: applying rotation q_1 and then q_2 is equivalent to applying q_3 = q_2*q_1 (order matters).

Parameters:
  • p (np.ndarray) – The first quaternion with shape (4,).

  • q (np.ndarray) – The second quaternion with shape (4,).

Returns:

The product of the two quaternions with shape (4,).

Return type:

np.ndarray
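The ordering rule is easy to check numerically. The sketch below implements the Hamilton product in (r, i, j, k) order and a small helper that rotates a vector by a unit quaternion; both helper names are hypothetical and this is an illustration, not the library's code.

```python
import numpy as np

def quat_mult(p, q):
    """Sketch: Hamilton product of two (r, i, j, k) quaternions."""
    pr, pi, pj, pk = p
    qr, qi, qj, qk = q
    return np.array([
        pr * qr - pi * qi - pj * qj - pk * qk,
        pr * qi + pi * qr + pj * qk - pk * qj,
        pr * qj - pi * qk + pj * qr + pk * qi,
        pr * qk + pi * qj - pj * qi + pk * qr,
    ])

def rotate(q, v):
    """Rotate vector v by unit quaternion q via q * (0, v) * conj(q)."""
    q_conj = q * np.array([1.0, -1.0, -1.0, -1.0])
    return quat_mult(quat_mult(q, np.concatenate(([0.0], v))), q_conj)[1:]

c = np.cos(np.pi / 4)
q_1 = np.array([c, 0.0, 0.0, c])  # 90 degrees about z: x -> y
q_2 = np.array([c, c, 0.0, 0.0])  # 90 degrees about x: y -> z
x_axis = np.array([1.0, 0.0, 0.0])

# q_1 first, then q_2, is the product q_2 * q_1 and sends x to z.
q_1_then_q_2 = rotate(quat_mult(q_2, q_1), x_axis)
# Swapping the factors applies q_2 first and sends x to y instead.
q_2_then_q_1 = rotate(quat_mult(q_1, q_2), x_axis)
```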

shepherd_score.alignment.utils.pca_np.quaternions_for_principal_component_alignment_np(ref_points, fit_points)[source]#

Computes the 4 quaternions required for alignment of the fit mol along the principal components of the reference mol. NumPy implementation.

The computed quaternions assume that fit_points will be rotated after being centered at COM.

Parameters:
  • ref_points (np.ndarray (N, 3)) – Set of reference points that fit_points will be aligned to.

  • fit_points (np.ndarray (M, 3)) – Set of points that will be aligned to ref_points.

Returns:

quaternions – Set of four quaternions corresponding to the alignment of fit_points to ref_points in the four possible principal component combinations.

Return type:

np.ndarray (4, 4)
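To illustrate how a (4, 4) array of quaternions might be consumed, the sketch below centers the fit points at their center of mass, rotates them with each quaternion, and moves the result onto the reference center of mass. The quaternions are placeholders (the identity and 180-degree turns about x, y, and z), not output of this function; the unweighted mean as COM and the final translation onto the reference COM are assumptions, and the helper name is hypothetical.

```python
import numpy as np

def quat_to_rotmat(quaternion):
    """Sketch: (r, i, j, k) quaternion to a 3x3 rotation matrix."""
    q = np.asarray(quaternion, dtype=float)
    r, i, j, k = q
    two_s = 2.0 / np.dot(q, q)
    return np.array([
        [1 - two_s * (j * j + k * k), two_s * (i * j - k * r), two_s * (i * k + j * r)],
        [two_s * (i * j + k * r), 1 - two_s * (i * i + k * k), two_s * (j * k - i * r)],
        [two_s * (i * k - j * r), two_s * (j * k + i * r), 1 - two_s * (i * i + j * j)],
    ])

# Placeholder quaternions standing in for the function's (4, 4) output.
candidate_quaternions = np.array([
    [1.0, 0.0, 0.0, 0.0],  # identity
    [0.0, 1.0, 0.0, 0.0],  # 180 degrees about x
    [0.0, 0.0, 1.0, 0.0],  # 180 degrees about y
    [0.0, 0.0, 0.0, 1.0],  # 180 degrees about z
])

ref_points = np.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
fit_points = np.array([[1.0, 1.0, 1.0], [3.0, 1.0, 1.0], [1.0, 2.0, 1.0], [1.0, 1.0, 4.0]])

ref_com = ref_points.mean(axis=0)                    # assumption: unweighted COM
fit_centered = fit_points - fit_points.mean(axis=0)  # center before rotating

# One candidate pose per quaternion, stacked into shape (4, M, 3).
candidates = np.stack([
    fit_centered @ quat_to_rotmat(q).T + ref_com
    for q in candidate_quaternions
])
```

Each candidate pose could then be scored against the reference to pick the best starting orientation.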

PCA Utilities (JAX)#

Jax Implementations for principal component alignment (pca).

Credit to Lewis J. Martin as this was adapted from ljmartin/align

shepherd_score.alignment.utils.pca_jax.compute_moment_of_inertia_jax(points)[source]#

Computes the moment of inertia tensor for a set of points, using A = x^2 + y^2 + z^2 and B = X^T X. Jax implementation.

Parameters:

points (jax.Array)

Return type:

jax.Array

shepherd_score.alignment.utils.pca_jax.compute_principal_moments_of_interia_jax(points)[source]#

Calculate the principal moments of inertia. Jax implementation.

Parameters:

points (jax.Array)

Return type:

jax.Array

shepherd_score.alignment.utils.pca_jax.angle_between_vecs_jax(v1, v2)[source]#

Compute the angle in radians between two vectors. Jax implementation.

Parameters:
  • v1 (jax.Array)

  • v2 (jax.Array)

Return type:

jax.Array

shepherd_score.alignment.utils.pca_jax.vmap_angle_between_vecs_jax(v1, v2)#

Compute the angle in radians between two vectors. Jax implementation.

Parameters:
  • v1 (jax.Array)

  • v2 (jax.Array)

Return type:

jax.Array

shepherd_score.alignment.utils.pca_jax.rotation_axis_jax(v1, v2)[source]#

Calculate the axis about which to rotate v1 to align it with v2 (cross product). Jax implementation.

Parameters:
  • v1 (jax.Array)

  • v2 (jax.Array)

Return type:

jax.Array

shepherd_score.alignment.utils.pca_jax.quaternion_from_axis_angle_jax(axis, angle)[source]#

Create a Quaternion from a rotation axis and an angle in radians. Jax implementation.

Parameters:
  • axis (Array (3,)) – Axis to rotate about.

  • angle (Array (1,)) – Angle in radians.

Returns:

quaternion

Return type:

Array (4,)

shepherd_score.alignment.utils.pca_jax.vmap_quaternion_from_axis_angle_jax(axis, angle)#

Create a Quaternion from a rotation axis and an angle in radians. Jax implementation.

Parameters:
  • axis (Array (3,)) – Axis to rotate about.

  • angle (Array (1,)) – Angle in radians.

Returns:

quaternion

Return type:

Array (4,)

shepherd_score.alignment.utils.pca_jax.quaternion_mult_jax(p, q)[source]#

Multiplication of quaternions p and q. Jax implementation.

Reference: https://academicflight.com/articles/kinematics/rotation-formalisms/quaternions/

General use case: applying rotation q_1 and then q_2 is equivalent to applying q_3 = q_2*q_1 (order matters).

Parameters:
  • p (Array) – The first quaternion with shape (4,).

  • q (Array) – The second quaternion with shape (4,).

Returns:

The product of the two quaternions with shape (4,).

Return type:

Array

shepherd_score.alignment.utils.pca_jax.vmap_quaternion_mult_jax(p, q)#

Multiplication of quaternions p and q. Jax implementation.

Reference: https://academicflight.com/articles/kinematics/rotation-formalisms/quaternions/

General use case: applying rotation q_1 and then q_2 is equivalent to applying q_3 = q_2*q_1 (order matters).

Parameters:
  • p (Array) – The first quaternion with shape (4,).

  • q (Array) – The second quaternion with shape (4,).

Returns:

The product of the two quaternions with shape (4,).

Return type:

Array

shepherd_score.alignment.utils.pca_jax.quaternions_for_principal_component_alignment_jax(ref_points, fit_points)[source]#

Computes the 4 quaternions required for alignment of the fit mol along the principal components of the reference mol. Jax implementation.

The computed quaternions assume that fit_points will be rotated after being centered at COM.

Parameters:
  • ref_points (Array (N, 3)) – Set of reference points that fit_points will be aligned to.

  • fit_points (Array (M, 3)) – Set of points that will be aligned to ref_points.

Returns:

quaternions – Set of four quaternions corresponding to the alignment of fit_points to ref_points in the four possible principal component combinations.

Return type:

Array (4, 4)