Source code for shepherd_score.alignment.utils.pca_np

"""
Numpy Implementations for principal component alignment (pca).
Using the numpy version of quaternions_for_principal_component_alignment is faster than
the torch implementation (~2.5ms vs ~5ms).

Credit to Lewis J. Martin as this was adapted from
https://github.com/ljmartin/align/blob/main/0.2%20aligning%20principal%20moments%20of%20inertia.ipynb
"""
import numpy as np
from shepherd_score.alignment.utils.se3_np import get_SE3_transform_np, apply_SE3_transform_np

[docs] def compute_moment_of_inertia_np(points: np.ndarray) -> np.ndarray: """ Computes the moment of inertia tensor for a set of points. Numpy implementation. A = x^2 + y^2 + z^2 B = X^T X """ points = points - np.mean(points, axis=0) A = np.sum(points**2) B = points.T @ points eye = np.eye(3) return (A * eye - B) / points.shape[0]
[docs] def compute_principal_moments_of_interia_np(points: np.ndarray) -> np.ndarray: """ Calculate principal moment of inertia. Numpy implementation. """ momint = compute_moment_of_inertia_np(points) eigvals, eigvecs = np.linalg.eigh(momint) indices = np.argsort(-eigvals) #sorting it returns the 'long' axis in index 0. # Return transposed which is more intuitive format return eigvecs[:, indices].T
[docs] def angle_between_vecs_np(v1: np.ndarray, v2: np.ndarray) -> np.ndarray: """ Compute the angle in radians between two vectors. Numpy implementation. """ v1_u = v1 / np.linalg.norm(v1) v2_u = v2 / np.linalg.norm(v2) return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0))
[docs] def rotation_axis_np(v1: np.ndarray, v2: np.ndarray) -> np.ndarray: """ Calculate the vector about which to order to rotate `a` to align with `b` (cross product). Numpy implementation. """ if np.allclose(v1, v2): return np.array([1, 0, 0]) v3 = np.cross(v1, v2) return v3 / np.linalg.norm(v3)
[docs] def quaternion_from_axis_angle_np(axis: np.ndarray, angle: np.ndarray) -> np.ndarray: """ Create a Quaternion from a rotation axis and an angle in radians. Numpy implementation. Parameters ---------- axis : np.ndarray (3,) Axis to rotate about. angle: np.ndarray (1,) Angle in radians. Returns ------- quaternion : np.ndarray (4,) """ mag_sq = np.dot(axis, axis) if mag_sq == 0.0: raise ZeroDivisionError("Provided rotation axis has no length") # Ensure axis is in unit vector form if (abs(1.0 - mag_sq) > 1e-12): axis = axis / np.sqrt(mag_sq) theta = angle / 2.0 r = np.cos(theta) i = axis * np.sin(theta) return np.array([r, i[0], i[1], i[2]])
[docs] def quaternion_mult_np(p: np.ndarray, q: np.ndarray) -> np.ndarray: """ Multiplication of quaternions p and q. Numpy implementation. Reference: https://academicflight.com/articles/kinematics/rotation-formalisms/quaternions/ General use case: The consecutive rotations of q_1 then q_2 is equivalent to q_3 = q_2*q_1. (order matters) Parameters ---------- p : np.ndarray The first quaternion with shape (4,). q : np.ndarray The second quaternion with shape (4,). Returns ------- np.ndarray The product of the two quaternions with shape (4,). """ mat1 = np.array([[p[0], -p[1], -p[2], -p[3]], [p[1], p[0], -p[3], p[2]], [p[2], p[3], p[0], -p[1]], [p[3], -p[2], p[1], p[0]]]) pq = mat1 @ q return pq
[docs] def quaternions_for_principal_component_alignment_np(ref_points: np.ndarray, fit_points: np.ndarray ) -> np.ndarray: """ Computes the 4 quaternions required for alignment of the fit mol along the principal components of the reference mol. NumPy implementation. The computed quaternions assumes that fit_points will be rotated after being centered at COM. Parameters ---------- ref_points : np.ndarray (N, 3) Set of reference points that fit_points will be aligned to. fit_points : np.ndarray (M, 3) Set of points that will be aligned to ref_points. Returns ------- quaternions : np.ndarray (4, 4) Set of four quaternions corresponding to the alignment of fit_points to ref_points in the four possible principal component combinations. """ pmi_ref = compute_principal_moments_of_interia_np(ref_points) quaternions = np.zeros((4,4)) for q_index in range(4): if q_index == 1: # flip orientation of longest axis pmi_ref[0] = -pmi_ref[0] elif q_index == 2: # unflip orientation of longest axis pmi_ref[0] = -pmi_ref[0] # flip orientation of 2nd longest axis pmi_ref[1] = -pmi_ref[1] elif q_index == 3: # flip orientation of both axes pmi_ref[0] = -pmi_ref[0] quat_order = np.zeros((2,4)) # Initially center to COM fit_points_adjust = fit_points - np.mean(fit_points, axis=0) for ax_idx in range(2): pmi_fit = compute_principal_moments_of_interia_np(fit_points_adjust) # Angle between principal axis of fit mol and referencne mol angle = angle_between_vecs_np(pmi_fit[ax_idx], pmi_ref[ax_idx]) # Axis that we are rotating about ax = rotation_axis_np(pmi_fit[ax_idx], pmi_ref[ax_idx]) # Quaternion quat_order[ax_idx] = quaternion_from_axis_angle_np(ax, angle) # get SE(3) transformation matrix se3_params = np.concatenate((quat_order[ax_idx], np.zeros(3))) # get transformed matrix fit_points_adjust = apply_SE3_transform_np(fit_points_adjust, get_SE3_transform_np(se3_params)) quaternions[q_index] = quaternion_mult_np(quat_order[1], quat_order[0]) return quaternions