JAX Alignment
Sequential JAX-based alignment functions. For multi-device parallel
alignment via jax.shard_map, see JAX Parallel Alignment.
Alignment implementation in JAX.
- shepherd_score.alignment._jax.apply_SO3_transform_jax(vectors, se3_matrix)
Apply SO(3) transformation (rotation) to a set of vectors.
- Parameters:
vectors (jax.Array)
se3_matrix (jax.Array)
- Return type:
jax.Array
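As a minimal sketch of what applying only the rotation part of a transform means, the snippet below assumes se3_matrix is a 4x4 homogeneous matrix whose upper-left 3x3 block is the rotation; the helper name apply_rotation is hypothetical. Ignoring the translation column is what distinguishes an SO(3) transform of direction vectors from an SE(3) transform of points.

```python
import jax.numpy as jnp

def apply_rotation(vectors: jnp.ndarray, se3_matrix: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: rotate (N, 3) row vectors by a (4, 4) matrix.

    Assumption: se3_matrix is a homogeneous transform. Only its upper-left
    3x3 block is used, so any translation is ignored.
    """
    rotation = se3_matrix[:3, :3]
    # Row-vector convention: each row v becomes (R @ v)^T == v^T @ R^T
    return vectors @ rotation.T

# 90-degree rotation about z plus a translation that must not move the vectors
se3 = jnp.array([[0., -1., 0., 5.],
                 [1.,  0., 0., 5.],
                 [0.,  0., 1., 5.],
                 [0.,  0., 0., 1.]])
vecs = jnp.array([[1., 0., 0.], [0., 0., 1.]])
rotated = apply_rotation(vecs, se3)   # x-hat -> y-hat, z-hat unchanged
```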
- shepherd_score.alignment._jax.vmap_apply_SO3_transform_jax(vectors, se3_matrix)
Vectorized (jax.vmap) variant of apply_SO3_transform_jax: apply SO(3) transformations (rotations) to sets of vectors.
- Parameters:
vectors (jax.Array)
se3_matrix (jax.Array)
- Return type:
jax.Array
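Under the same 4x4-matrix assumption, a vmap-ed variant can be sketched with jax.vmap, mapping over a leading batch axis of both the vectors and the matrices. The names apply_rotation, batched_apply_rotation and rot_z are hypothetical, and the in_axes choice is an assumption about how the inputs are batched.

```python
import jax
import jax.numpy as jnp

def apply_rotation(vectors, se3_matrix):
    # Rotate (N, 3) row vectors by the 3x3 rotation block of a 4x4 matrix
    return vectors @ se3_matrix[:3, :3].T

# Map over a leading batch axis of both arguments: (B, N, 3) and (B, 4, 4)
batched_apply_rotation = jax.vmap(apply_rotation, in_axes=(0, 0))

def rot_z(theta):
    # 4x4 homogeneous matrix for a rotation of theta radians about z
    c, s = jnp.cos(theta), jnp.sin(theta)
    return jnp.array([[c, -s, 0., 0.],
                      [s,  c, 0., 0.],
                      [0., 0., 1., 0.],
                      [0., 0., 0., 1.]])

matrices = jax.vmap(rot_z)(jnp.array([0.0, jnp.pi / 2]))     # (2, 4, 4)
vectors = jnp.tile(jnp.array([[1.0, 0.0, 0.0]]), (2, 1, 1))   # (2, 1, 3)
out = batched_apply_rotation(vectors, matrices)               # (2, 1, 3)
```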
- shepherd_score.alignment._jax.batched_obj_ROCS_overlay_helper(se3_params, ref_points, fit_points, alpha)
Objective function to optimize ROCS overlay. JAX implementation.
- Parameters:
se3_params (Array (7,)) – Parameters for the SE(3) transformation. The first 4 values are a quaternion of the form (r,i,j,k) and the last 3 values are a translation (x,y,z).
ref_points (Array (N,3)) – Reference points.
fit_points (Array (M,3)) – Set of points to apply SE(3) transformations to maximize shape similarity with ref_points.
alpha (float) – Gaussian width parameter used in scoring function.
- Returns:
Tanimoto overlap score.
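The (7,) layout of se3_params can be turned into a 4x4 transform as sketched below. The helper names are hypothetical; the quaternion is normalized first (a common choice when the parameters are optimized without constraints, assumed here rather than confirmed), and the standard (r, i, j, k) quaternion-to-rotation formula is used.

```python
import jax.numpy as jnp

def se3_params_to_matrix(se3_params):
    """Hypothetical helper: (7,) params -> (4, 4) homogeneous transform.

    se3_params[:4] is a quaternion (r, i, j, k); se3_params[4:] is a
    translation (x, y, z), following the parameter description above.
    """
    q = se3_params[:4] / jnp.linalg.norm(se3_params[:4])  # unit quaternion
    r, i, j, k = q
    rotation = jnp.array([
        [1 - 2 * (j**2 + k**2), 2 * (i * j - k * r),   2 * (i * k + j * r)],
        [2 * (i * j + k * r),   1 - 2 * (i**2 + k**2), 2 * (j * k - i * r)],
        [2 * (i * k - j * r),   2 * (j * k + i * r),   1 - 2 * (i**2 + j**2)],
    ])
    matrix = jnp.eye(4)
    matrix = matrix.at[:3, :3].set(rotation)
    matrix = matrix.at[:3, 3].set(se3_params[4:])
    return matrix

def transform_points(points, matrix):
    # Rotate, then translate (N, 3) points
    return points @ matrix[:3, :3].T + matrix[:3, 3]

identity_params = jnp.array([1., 0., 0., 0., 0., 0., 0.])
# 90 degrees about z: q = (cos 45°, 0, 0, sin 45°), then shift by +2 along x
params = jnp.array([jnp.cos(jnp.pi / 4), 0., 0., jnp.sin(jnp.pi / 4), 2., 0., 0.])
moved = transform_points(jnp.array([[1., 0., 0.]]), se3_params_to_matrix(params))
```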
- shepherd_score.alignment._jax.vmap_score_ROCS_overlay_with_avoid_jax(ref_points, fit_points, alpha, fit_points_for_avoid, avoid_points, avoid_min_dist, avoid_weight)
See _objective_ROCS_overlay_with_avoid_jax.
- shepherd_score.alignment._jax.batched_obj_ROCS_overlay_with_avoid_helper(se3_params, ref_points, fit_points, alpha, fit_points_for_avoid, avoid_points, avoid_min_dist, avoid_weight)
Objective function to optimize ROCS overlay while avoiding certain points. JAX implementation.
- Parameters:
se3_params (Array (7,)) – Parameters for the SE(3) transformation. The first 4 values are a quaternion of the form (r,i,j,k) and the last 3 values are a translation (x,y,z).
ref_points (Array (N,3)) – Reference points.
fit_points (Array (M,3)) – Set of points to apply SE(3) transformations to maximize shape similarity with ref_points.
alpha (float) – Gaussian width parameter used in scoring function.
fit_points_for_avoid (Array (M,3)) – Set of points to apply SE(3) transformations to and then compare to avoid_points.
avoid_points (Array (K,3)) – Penalize overlap with these points.
avoid_min_dist (float) – Minimum distance with no penalization between fit_points_for_avoid and avoid_points.
avoid_weight (float) – Weight for the avoid_points term in the scoring function.
- Returns:
score in the range [-avoid_weight, 1], where higher is better. A score of 1 means complete fit/ref overlap
with no fit/avoid overlap; -avoid_weight means no fit/ref overlap and complete fit/avoid overlap.
- Return type:
jax.Array
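To make the score range concrete, the sketch below combines a Tanimoto score with a penalty in [0, 1], which yields values in [-avoid_weight, 1]. The penalty itself is hypothetical: a linear ramp that is zero beyond avoid_min_dist and reaches 1 when a fit/avoid pair coincides, keeping the largest pairwise value in the spirit of the "max pairwise overlap" described for objective_ROCS_overlay_with_avoid_jax. The library's actual penalty term may have a different functional form.

```python
import jax.numpy as jnp

def avoid_penalty(fit_points_for_avoid, avoid_points, avoid_min_dist):
    """Hypothetical penalty in [0, 1] (not the library's formula).

    0 when every fit/avoid pair is at least avoid_min_dist apart; grows
    linearly to 1 as a pair coincides. The largest pairwise value is returned.
    """
    diff = fit_points_for_avoid[:, None, :] - avoid_points[None, :, :]
    dist = jnp.linalg.norm(diff, axis=-1)                                  # (M, K)
    pairwise = jnp.clip((avoid_min_dist - dist) / avoid_min_dist, 0.0, 1.0)
    return jnp.max(pairwise)

def score_with_avoid(tanimoto, penalty, avoid_weight):
    # tanimoto and penalty both lie in [0, 1], so the score lies in [-avoid_weight, 1]
    return tanimoto - avoid_weight * penalty

fit = jnp.array([[0.0, 0.0, 0.0]])
far = avoid_penalty(fit, jnp.array([[10.0, 0.0, 0.0]]), 2.0)   # beyond min dist
half = avoid_penalty(fit, jnp.array([[1.0, 0.0, 0.0]]), 2.0)   # halfway in
clash = avoid_penalty(fit, jnp.array([[0.0, 0.0, 0.0]]), 2.0)  # coincident
```

If a penalty like this is differentiated, note that the gradient of jnp.linalg.norm is undefined at zero distance; adding a small epsilon inside a square root is a common workaround.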
- shepherd_score.alignment._jax.objective_ROCS_overlay_jax(se3_params, ref_points, fit_points, alpha)
Objective function to optimize ROCS overlay. JAX implementation.
- Parameters:
se3_params (Array (batch, 7)) – Parameters for the SE(3) transformations. Expects a batch. Along the last dimension, the first 4 values are a quaternion of the form (r,i,j,k) and the last 3 values are a translation (x,y,z).
ref_points (Array (N,3)) – Reference points. (NOT batched since it assumes the same reference points).
fit_points (Array (batch, M,3)) – Expects a batch. Set of points to apply SE(3) transformations to maximize shape similarity with ref_points. To optimize the same fit_points under a batch of different se3_params, use jnp.tile(fit_points, (batch, 1, 1)).
alpha (float) – Gaussian width parameter used in scoring function.
- Returns:
loss – 1 - Tanimoto score
- Return type:
Array (1,)
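A batched objective of this shape can be sketched by vmap-ing a single-instance score over se3_params and fit_points while sharing ref_points, then reducing the batch with a mean (an assumption here; the mean is what objective_ROCS_esp_overlay_jax documents) and converting it to a loss. To keep the sketch short, the per-instance parameters are translations only (3 values instead of the 7 SE(3) parameters), the names are hypothetical, and the ROCS-style Gaussian overlap exp(-alpha/2 * d^2) omits its constant prefactor, which cancels in the Tanimoto ratio. The fit batch is built with jnp.tile as the docstring suggests.

```python
import jax
import jax.numpy as jnp

def gaussian_overlap(a, b, alpha):
    # ROCS-style overlap sum_ij exp(-alpha/2 * |a_i - b_j|^2); prefactor dropped
    d2 = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return jnp.sum(jnp.exp(-0.5 * alpha * d2))

def tanimoto(ref, fit, alpha):
    vab = gaussian_overlap(ref, fit, alpha)
    vaa = gaussian_overlap(ref, ref, alpha)
    vbb = gaussian_overlap(fit, fit, alpha)
    return vab / (vaa + vbb - vab)

def single_tanimoto(shift, ref, fit, alpha):
    # Simplification: a (3,) translation stands in for the (7,) SE(3) parameters
    return tanimoto(ref, fit + shift, alpha)

def batched_loss(shifts, ref, fit_batch, alpha):
    # shifts: (batch, 3); ref: (N, 3), shared; fit_batch: (batch, M, 3)
    scores = jax.vmap(single_tanimoto, in_axes=(0, None, 0, None))(shifts, ref, fit_batch, alpha)
    return 1.0 - jnp.mean(scores)

ref = jnp.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]])
fit = ref + jnp.array([3.0, 0.0, 0.0])
batch = 4
fit_batch = jnp.tile(fit, (batch, 1, 1))                       # (4, 2, 3)
shifts = jnp.zeros((batch, 3)).at[0].set(jnp.array([-3.0, 0.0, 0.0]))
loss = batched_loss(shifts, ref, fit_batch, 0.81)
grads = jax.grad(batched_loss)(shifts, ref, fit_batch, 0.81)   # (4, 3)
```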
- shepherd_score.alignment._jax.objective_ROCS_overlay_with_avoid_jax(se3_params, ref_points, fit_points, alpha, fit_points_for_avoid, avoid_points, avoid_min_dist, avoid_weight)
Objective function to optimize ROCS overlay while penalizing overlap with avoid_points. JAX implementation.
- Parameters:
se3_params (Array (batch, 7)) – Parameters for the SE(3) transformations. Expects a batch. Along the last dimension, the first 4 values are a quaternion of the form (r,i,j,k) and the last 3 values are a translation (x,y,z).
ref_points (Array (N,3)) – Reference points. (NOT batched since it assumes the same reference points).
fit_points (Array (M,3)) – Set of points to apply SE(3) transformations to maximize shape similarity with ref_points.
alpha (float) – Gaussian width parameter used in scoring function.
fit_points_for_avoid (Array (M,3)) – Set of points to apply SE(3) transformations to and then compare to avoid_points.
avoid_points (Array (K,3)) – Penalize overlap with these points.
avoid_min_dist (float) – Minimum distance with no penalization between fit_points_for_avoid and avoid_points.
avoid_weight (float) – Weight for the avoid_points term in the scoring function.
- Returns:
loss – 1 - (Tanimoto score of fit/ref) + avoid_weight * (max pairwise overlap of fit/avoid)
- Return type:
Array (1,)
- shepherd_score.alignment._jax.batched_obj_ROCS_overlay_precomputed(se3_params, ref_points, fit_points, alpha, VAA, VBB)
Single-instance ROCS objective using precomputed self-overlaps.
- shepherd_score.alignment._jax.objective_ROCS_overlay_precomputed_jax(se3_params, ref, fit, alpha, VAA, VBB)
- shepherd_score.alignment._jax.batched_obj_ROCS_overlay_precomputed_mask(se3_params, ref_points, fit_points, mask_ref, mask_fit, alpha, VAA, VBB)
Single-instance masked ROCS objective using precomputed self-overlaps.
- shepherd_score.alignment._jax.objective_ROCS_overlay_precomputed_jax_mask(se3_params, ref, fit, mask_ref, mask_fit, alpha, VAA, VBB)
- shepherd_score.alignment._jax.batched_obj_ROCS_esp_overlay_precomputed(se3_params, ref_points, fit_points, ref_charges, fit_charges, alpha, lam, VAA, VBB)
Single-instance non-masked ROCS ESP objective using precomputed self-overlaps.
- shepherd_score.alignment._jax.objective_ROCS_esp_overlay_precomputed_jax(se3_params, ref, fit, ref_charges, fit_charges, alpha, lam, VAA, VBB)
- shepherd_score.alignment._jax.batched_obj_ROCS_esp_overlay_precomputed_mask(se3_params, ref_points, fit_points, ref_charges, fit_charges, mask_ref, mask_fit, alpha, lam, VAA, VBB)
Single-instance masked ROCS ESP objective using precomputed self-overlaps.
- shepherd_score.alignment._jax.objective_ROCS_esp_overlay_precomputed_jax_mask(se3_params, ref, fit, ref_charges, fit_charges, mask_ref, mask_fit, alpha, lam, VAA, VBB)
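Because Gaussian self-overlaps depend only on intra-molecular distances, VAA and VBB are unchanged by any rigid motion, so they can be computed once before optimization while only the cross term VAB is recomputed per pose. The sketch below checks this numerically; the function names and the prefactor-free ROCS-style kernel are assumptions, not the library's internals.

```python
import jax
import jax.numpy as jnp

def gaussian_overlap(a, b, alpha):
    # ROCS-style overlap with the constant prefactor dropped
    d2 = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return jnp.sum(jnp.exp(-0.5 * alpha * d2))

def tanimoto_precomputed(ref, fit_moved, alpha, VAA, VBB):
    # Only the cross term depends on the current pose of the fit molecule
    vab = gaussian_overlap(ref, fit_moved, alpha)
    return vab / (VAA + VBB - vab)

alpha = 0.81
ref = jax.random.normal(jax.random.PRNGKey(0), (6, 3))
fit = jax.random.normal(jax.random.PRNGKey(1), (5, 3))
VAA = gaussian_overlap(ref, ref, alpha)   # computed once
VBB = gaussian_overlap(fit, fit, alpha)   # computed once

# A rigid motion (rotation about z plus translation) of the fit points
c, s = jnp.cos(0.7), jnp.sin(0.7)
R = jnp.array([[c, -s, 0.], [s, c, 0.], [0., 0., 1.]])
fit_moved = fit @ R.T + jnp.array([0.3, -0.2, 0.1])

score = tanimoto_precomputed(ref, fit_moved, alpha, VAA, VBB)
```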
- shepherd_score.alignment._jax.optimize_ROCS_esp_overlay_jax_mask(ref_points, fit_points, ref_charges, fit_charges, mask_ref, mask_fit, alpha, lam, *, num_repeats=50, trans_centers=None, lr=0.1, max_num_steps=200, verbose=False)
Optimize alignment of fit_points with respect to ref_points using SE(3) transformations and maximizing masked electrostatic-weighted Gaussian overlap score.
Identical to optimize_ROCS_esp_overlay_jax but accepts binary mask arrays so that padded (zero) entries are excluded from the overlap computation. Padding all arrays to a common maximum shape and passing masks allows JAX’s XLA compiler to reuse a single compiled function across all pairs in a batch, avoiding recompilation overhead.
- Parameters:
ref_points (Array (N, 3)) – Reference points (may include zero-padding beyond mask_ref).
fit_points (Array (M, 3)) – Fit points (may include zero-padding beyond mask_fit).
ref_charges (Array (N,)) – Charges for reference points.
fit_charges (Array (M,)) – Charges for fit points.
mask_ref (Array (N,)) – Binary mask: 1 for real atoms, 0 for padding.
mask_fit (Array (M,)) – Binary mask: 1 for real atoms, 0 for padding.
alpha (float) – Gaussian width parameter.
lam (float) – Charge weighting parameter.
num_repeats (int (default=50))
trans_centers (array (P, 3) (default=None))
lr (float (default=0.1))
max_num_steps (int (default=200))
verbose (bool (default=False))
- Returns:
- aligned_points : Array (M, 3)
- SE3_transform : Array (4, 4)
- score : Array (1,)
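Padding to a common shape could look like the hypothetical helper below, which zero-pads each per-molecule array along its first axis and returns a matching 0/1 mask. jax.jit compiles one executable per distinct input shape, so feeding every pair at the same padded shape lets the compiled function be reused. The helper name and the use of NumPy for host-side padding are assumptions; the library may provide its own utilities.

```python
import numpy as np

def pad_with_mask(arrays, max_len=None):
    """Hypothetical helper: zero-pad (n_i, ...) arrays to a shared first dimension."""
    if max_len is None:
        max_len = max(a.shape[0] for a in arrays)
    padded, masks = [], []
    for arr in arrays:
        arr = np.asarray(arr, dtype=np.float32)
        n = arr.shape[0]
        pad_width = [(0, max_len - n)] + [(0, 0)] * (arr.ndim - 1)
        padded.append(np.pad(arr, pad_width))   # zeros appended at the end
        mask = np.zeros(max_len, dtype=np.float32)
        mask[:n] = 1.0                          # 1 = real entry, 0 = padding
        masks.append(mask)
    return np.stack(padded), np.stack(masks)

rng = np.random.default_rng(0)
sizes = (12, 7, 20)
points = [rng.normal(size=(n, 3)) for n in sizes]
charges = [rng.normal(size=n) for n in sizes]
points_pad, masks = pad_with_mask(points)   # (3, 20, 3), (3, 20)
charges_pad, _ = pad_with_mask(charges)     # (3, 20)
```

Row i of points_pad, charges_pad and masks would then be passed as one molecule's points, charges and mask.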
- shepherd_score.alignment._jax.batched_obj_ROCS_overlay_with_avoid_precomputed(se3_params, ref_points, fit_points, alpha, VAA, VBB, fit_points_for_avoid, avoid_points, avoid_min_dist, avoid_weight)
- shepherd_score.alignment._jax.objective_ROCS_overlay_with_avoid_precomputed_jax(se3_params, ref, fit, alpha, VAA, VBB, fit_for_avoid, avoid, min_dist, weight)
- shepherd_score.alignment._jax.optimize_ROCS_overlay_jax(ref_points, fit_points, alpha, *, fit_points_for_avoid=None, avoid_points=None, avoid_min_dist=2.0, avoid_weight=1.0, num_repeats=50, trans_centers=None, lr=0.1, max_num_steps=200, verbose=False)
Optimize alignment of fit_points with respect to ref_points using SE(3) transformations and maximizing Gaussian overlap score.
If num_repeats is 1, the initial guess for alignment is an identity rotation and aligned COMs. If num_repeats is 5 or greater, four initial guesses are aligned using principal components.
- Parameters:
ref_points (Array (N,3)) – Reference points.
fit_points (Array (M,3)) – Set of points to apply SE(3) transformations to maximize shape similarity with ref_points.
alpha (float) – Gaussian width parameter used in scoring function.
fit_points_for_avoid (Array (M,3) (default=None)) – Set of points to apply SE(3) transformations to and then compare to avoid_points.
avoid_points (Array (K,3) (default=None)) – If not None, these are points that are used in an additional term in the objective function to penalize overlap with these points.
avoid_min_dist (float (default=2.0)) – Minimum distance with no penalization between fit_points_for_avoid and avoid_points.
avoid_weight (float (default=1.0)) – Weight for the avoid_points term in the scoring function.
num_repeats (int (default=50)) – Number of different random initializations of SE(3) transformation parameters.
trans_centers (array (P, 3) (default=None)) – Locations to which fit_points’ center of mass is translated as initial guesses for optimization. At each translation center, 10 rotations are also sampled, so the number of initializations is (# translation centers * 10 + 5), where the 5 comes from the identity and 4 PCA-based guesses with aligned COMs. If None, num_repeats rotations are used with aligned COMs.
lr (float (default=0.1)) – Learning rate (step size) for optimization.
max_num_steps (int (default=200)) – Maximum number of optimization steps.
verbose (bool (default=False)) – Print the initial and final similarity scores; during optimization, also print scores every 100 steps.
- Returns:
- aligned_points : Array (M,3)
The transformed point cloud for fit_points using the optimized SE(3) transformation for alignment with ref_points.
- SE3_transform : Array (4,4)
Optimized SE(3) transformation matrix used to obtain aligned_points from fit_points.
- score : Array (1,)
Tanimoto shape similarity score for the optimal transformation.
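The multi-start strategy described above can be sketched as plain gradient descent over a batch of starting poses. This is a simplified illustration, not the library's optimizer: the update rule (vanilla gradient descent), the helper names (quat_to_rot, single_loss, multistart_align), the random-quaternion starts and the omission of the PCA guesses are all assumptions. Both clouds are centered first so every start has aligned centers of mass, and the ROCS-style overlap drops its constant prefactor because it cancels in the Tanimoto ratio. The sketch returns the best parameters and score rather than the 4x4 matrix and aligned points.

```python
import jax
import jax.numpy as jnp

def quat_to_rot(q):
    # Normalize, then apply the standard (r, i, j, k) quaternion-to-matrix formula
    r, i, j, k = q / jnp.linalg.norm(q)
    return jnp.array([
        [1 - 2 * (j**2 + k**2), 2 * (i * j - k * r), 2 * (i * k + j * r)],
        [2 * (i * j + k * r), 1 - 2 * (i**2 + k**2), 2 * (j * k - i * r)],
        [2 * (i * k - j * r), 2 * (j * k + i * r), 1 - 2 * (i**2 + j**2)],
    ])

def gaussian_overlap(a, b, alpha):
    d2 = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return jnp.sum(jnp.exp(-0.5 * alpha * d2))

def single_loss(params, ref, fit, alpha):
    # params: (7,) = quaternion (r, i, j, k) + translation (x, y, z)
    moved = fit @ quat_to_rot(params[:4]).T + params[4:]
    vab = gaussian_overlap(ref, moved, alpha)
    vaa = gaussian_overlap(ref, ref, alpha)
    vbb = gaussian_overlap(fit, fit, alpha)
    return 1.0 - vab / (vaa + vbb - vab)

# Per-start gradients, so each start is updated independently
batched_grad = jax.jit(jax.vmap(jax.grad(single_loss), in_axes=(0, None, None, None)))
batched_loss = jax.jit(jax.vmap(single_loss, in_axes=(0, None, None, None)))

def multistart_align(ref, fit, alpha, num_repeats=10, lr=0.1, max_num_steps=200, seed=0):
    ref_c = ref - ref.mean(axis=0)   # centering aligns the COMs of every start
    fit_c = fit - fit.mean(axis=0)
    # Normalized Gaussian 4-vectors are uniformly distributed unit quaternions
    quats = jax.random.normal(jax.random.PRNGKey(seed), (num_repeats, 4))
    quats = quats.at[0].set(jnp.array([1.0, 0.0, 0.0, 0.0]))   # identity start
    params = jnp.concatenate([quats, jnp.zeros((num_repeats, 3))], axis=1)
    for _ in range(max_num_steps):
        params = params - lr * batched_grad(params, ref_c, fit_c, alpha)
    scores = 1.0 - batched_loss(params, ref_c, fit_c, alpha)
    best = jnp.argmax(scores)
    return params[best], scores[best]

ref = jax.random.normal(jax.random.PRNGKey(1), (8, 3)) * 2.0
fit = ref + jnp.array([4.0, -1.0, 0.5])        # same shape, shifted
best_params, best_score = multistart_align(ref, fit, alpha=0.81)
```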
- shepherd_score.alignment._jax.optimize_ROCS_overlay_jax_mask(ref_points, fit_points, mask_ref, mask_fit, alpha, *, num_repeats=50, trans_centers=None, lr=0.1, max_num_steps=200, verbose=False)
Optimize alignment of fit_points with respect to ref_points using SE(3) transformations and maximizing masked Gaussian overlap score.
Identical to optimize_ROCS_overlay_jax but accepts binary mask arrays so that padded (zero) entries are excluded from the overlap computation. Padding all arrays to a common maximum shape and passing masks allows JAX’s XLA compiler to reuse a single compiled function across all pairs in a batch, avoiding recompilation overhead.
- Parameters:
ref_points (Array (N, 3)) – Reference points (may include zero-padding beyond mask_ref).
fit_points (Array (M, 3)) – Fit points (may include zero-padding beyond mask_fit).
mask_ref (Array (N,)) – Binary mask: 1 for real atoms, 0 for padding.
mask_fit (Array (M,)) – Binary mask: 1 for real atoms, 0 for padding.
alpha (float) – Gaussian width parameter.
num_repeats (int (default=50))
trans_centers (array (P, 3) (default=None))
lr (float (default=0.1))
max_num_steps (int (default=200))
verbose (bool (default=False))
- Returns:
- aligned_points : Array (M, 3)
- SE3_transform : Array (4, 4)
- score : Array (1,)
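Why masks make padding harmless can be seen with a masked version of the overlap: each pairwise Gaussian term is multiplied by mask_a[i] * mask_b[j], so padded rows contribute nothing even though they sit at the origin. The sketch below (hypothetical names, prefactor-free ROCS-style kernel assumed) checks that the padded, masked score matches the unpadded score, while ignoring the masks does not.

```python
import jax.numpy as jnp

def masked_overlap(a, b, mask_a, mask_b, alpha):
    d2 = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    pair_mask = mask_a[:, None] * mask_b[None, :]   # 0 for any pair touching padding
    return jnp.sum(pair_mask * jnp.exp(-0.5 * alpha * d2))

def masked_tanimoto(ref, fit, mask_ref, mask_fit, alpha):
    vab = masked_overlap(ref, fit, mask_ref, mask_fit, alpha)
    vaa = masked_overlap(ref, ref, mask_ref, mask_ref, alpha)
    vbb = masked_overlap(fit, fit, mask_fit, mask_fit, alpha)
    return vab / (vaa + vbb - vab)

alpha = 0.81
ref = jnp.array([[0., 0., 0.], [1.4, 0., 0.], [0., 1.4, 0.]])
fit = jnp.array([[0.2, 0., 0.], [1.5, 0.1, 0.]])
# Pad both clouds to 4 rows; padded rows sit at the origin, near real atoms
ref_pad = jnp.zeros((4, 3)).at[:3].set(ref)
fit_pad = jnp.zeros((4, 3)).at[:2].set(fit)
mask_ref = jnp.array([1., 1., 1., 0.])
mask_fit = jnp.array([1., 1., 0., 0.])

padded_score = masked_tanimoto(ref_pad, fit_pad, mask_ref, mask_fit, alpha)
unpadded_score = masked_tanimoto(ref, fit, jnp.ones(3), jnp.ones(2), alpha)
naive_score = masked_tanimoto(ref_pad, fit_pad, jnp.ones(4), jnp.ones(4), alpha)
```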
- shepherd_score.alignment._jax.batched_obj_ROCS_esp_overlay_helper(se3_params, ref_points, fit_points, ref_charges, fit_charges, alpha, lam)
Objective function to optimize ROCS ESP overlay. JAX implementation.
- Parameters:
se3_params (Array (7,)) – Parameters for the SE(3) transformation. The first 4 values are a quaternion of the form (r,i,j,k) and the last 3 values are a translation (x,y,z).
ref_points (Array (N,3)) – Reference points.
fit_points (Array (M,3)) – Set of points to apply SE(3) transformations to maximize shape similarity with ref_points.
alpha (float) – Gaussian width parameter used in scoring function.
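ref_charges (Array (N,)) – Electric potential at the corresponding ref_points coordinates.
fit_charges (Array (M,)) – Electric potential at the corresponding fit_points coordinates.
lam (float) – Scaling term for charges used in the exponential kernel of the ESP scoring function.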
- Returns:
loss – 1 - Tanimoto score
- Return type:
Array (1,)
- shepherd_score.alignment._jax.objective_ROCS_esp_overlay_jax(se3_params, ref_points, fit_points, ref_charges, fit_charges, alpha, lam)
Objective function to optimize ROCS ESP overlay.
- Parameters:
se3_params (Array (batch, 7)) – Parameters for the SE(3) transformations. Along the last dimension, the first 4 values are a quaternion of the form (r,i,j,k) and the last 3 values are a translation (x,y,z).
ref_points (Array (N,3)) – Reference points.
fit_points (Array (M,3)) – Set of points to apply SE(3) transformations to maximize shape similarity with ref_points.
ref_charges (Array (N,)) – Electric potential at the corresponding ref_points coordinates.
fit_charges (Array (M,)) – Electric potential at the corresponding fit_points coordinates.
alpha (float) – Gaussian width parameter used in scoring function.
lam (float) – Scaling term for charges used in the exponential kernel of the ESP scoring function.
- Returns:
loss – 1 - mean(ESP Tanimoto score).
- Return type:
Array (1,)
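The role of lam can be illustrated with an assumed ESP-weighted kernel in which each Gaussian distance term is multiplied by exp(-(q_a - q_b)^2 / lam). Under that assumption a smaller lam penalizes charge mismatch more strongly, which matches the "smaller = higher weight" description of lam elsewhere on this page. The function names are hypothetical and the kernel is not confirmed to be the library's exact formula.

```python
import jax.numpy as jnp

def esp_overlap(a, b, q_a, q_b, alpha, lam):
    # Assumed kernel: Gaussian distance term times exp(-(q_a - q_b)^2 / lam)
    d2 = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    dq2 = (q_a[:, None] - q_b[None, :]) ** 2
    return jnp.sum(jnp.exp(-0.5 * alpha * d2) * jnp.exp(-dq2 / lam))

def esp_tanimoto(ref, fit, q_ref, q_fit, alpha, lam):
    vab = esp_overlap(ref, fit, q_ref, q_fit, alpha, lam)
    vaa = esp_overlap(ref, ref, q_ref, q_ref, alpha, lam)
    vbb = esp_overlap(fit, fit, q_fit, q_fit, alpha, lam)
    return vab / (vaa + vbb - vab)

points = jnp.array([[0., 0., 0.], [10., 0., 0.]])   # far apart: cross terms ~0
q = jnp.array([0.5, -0.3])
same = esp_tanimoto(points, points, q, q, alpha=0.81, lam=0.3)      # identical charges
flipped = esp_tanimoto(points, points, q, -q, alpha=0.81, lam=0.3)  # mismatched charges
strict = esp_tanimoto(points, points, q, -q, alpha=0.81, lam=0.1)
loose = esp_tanimoto(points, points, q, -q, alpha=0.81, lam=1.0)
```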
- shepherd_score.alignment._jax.optimize_ROCS_esp_overlay_jax(ref_points, fit_points, ref_charges, fit_charges, alpha, lam, num_repeats=50, trans_centers=None, lr=0.1, max_num_steps=200, verbose=False)
Optimize alignment of fit_points with respect to ref_points using SE(3) transformations and maximizing electrostatic-weighted Gaussian overlap score.
- Parameters:
ref_points (Array (N,3)) – Reference points.
fit_points (Array (M,3)) – Set of points to apply SE(3) transformations to maximize shape similarity with ref_points.
ref_charges (Array (batch, N) or (N,)) – Electric potential at the corresponding ref_points coordinates.
fit_charges (Array (batch, M) or (M,)) – Electric potential at the corresponding fit_points coordinates.
alpha (float) – Gaussian width parameter used in scoring function.
lam (float) – Scaling term for charges used in the exponential kernel of the ESP scoring function.
num_repeats (int (default=50)) – Number of different random initializations of SE(3) transformation parameters.
trans_centers (array (P, 3) (default=None)) – Locations to which fit_points’ center of mass is translated as initial guesses for optimization. At each translation center, 10 rotations are also sampled, so the number of initializations is (# translation centers * 10 + 5), where the 5 comes from the identity and 4 PCA-based guesses with aligned COMs. If None, num_repeats rotations are used with aligned COMs.
lr (float (default=0.1)) – Learning rate (step size) for optimization.
max_num_steps (int (default=200)) – Maximum number of optimization steps.
verbose (bool (default=False)) – Print the initial and final similarity scores; during optimization, also print scores every 100 steps.
- Returns:
- aligned_points : Array (M,3)
The transformed point cloud for fit_points using the optimized SE(3) transformation for alignment with ref_points.
- SE3_transform : Array (4,4)
Optimized SE(3) transformation matrix used to obtain aligned_points from fit_points.
- score : Array (1,)
Tanimoto shape+ESP similarity score for the optimal transformation.
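The "4 PCA" initial guesses can be sketched as follows, assuming they come from aligning the principal axes of the centered fit cloud with those of the centered reference. Eigenvectors are defined only up to sign, and of the 8 sign combinations exactly 4 give proper rotations (det = +1); the sketch enumerates those. Function names are hypothetical, and the construction assumes distinct covariance eigenvalues (for near-degenerate shapes the axes are ill-defined).

```python
import itertools
import jax
import jax.numpy as jnp

def principal_axes(points):
    # Columns are covariance eigenvectors, sorted by ascending eigenvalue
    centered = points - points.mean(axis=0)
    cov = centered.T @ centered / points.shape[0]
    _, axes = jnp.linalg.eigh(cov)
    return axes

def pca_initial_rotations(ref, fit):
    """Four proper rotations mapping fit's principal axes onto ref's."""
    ref_axes, fit_axes = principal_axes(ref), principal_axes(fit)
    rotations = []
    for s1, s2 in itertools.product([1.0, -1.0], repeat=2):
        signs = jnp.array([s1, s2, 1.0])
        det = jnp.linalg.det((ref_axes * signs) @ fit_axes.T)
        # Flip the third axis when needed so det = +1 (rotation, not reflection)
        signs = signs.at[2].set(jnp.sign(det))
        rotations.append((ref_axes * signs) @ fit_axes.T)
    return jnp.stack(rotations)   # (4, 3, 3)

# Test shape: anisotropic cloud, rotated and shifted copy
c, s = jnp.cos(0.9), jnp.sin(0.9)
Rz = jnp.array([[c, -s, 0.], [s, c, 0.], [0., 0., 1.]])
c, s = jnp.cos(0.4), jnp.sin(0.4)
Rx = jnp.array([[1., 0., 0.], [0., c, -s], [0., s, c]])
R0 = Rz @ Rx
ref = jax.random.normal(jax.random.PRNGKey(0), (30, 3)) * jnp.array([3.0, 2.0, 1.0])
fit = ref @ R0.T + 5.0

rotations = pca_initial_rotations(ref, fit)
ref_c = ref - ref.mean(axis=0)
fit_c = fit - fit.mean(axis=0)
# One of the four guesses should undo R0 (up to float error)
errors = jnp.array([jnp.max(jnp.abs(fit_c @ R.T - ref_c)) for R in rotations])
```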
- shepherd_score.alignment._jax.batched_obj_esp_combo_score_helper(se3_params, ref_centers_w_H, fit_centers_w_H, ref_centers, fit_centers, ref_points, fit_points, ref_partial_charges, fit_partial_charges, ref_surf_esp, fit_surf_esp, ref_radii, fit_radii, alpha, lam, probe_radii=1.0, esp_weight=0.5)
Helper function that applies the SE(3) transformation given by se3_params to all fit-related coordinates and computes the score for that transformation.
- Return type:
jax.Array
- shepherd_score.alignment._jax.objective_esp_combo_score_overlay_jax(se3_params, ref_centers_w_H, fit_centers_w_H, ref_centers, fit_centers, ref_points, fit_points, ref_partial_charges, fit_partial_charges, ref_surf_esp, fit_surf_esp, ref_radii, fit_radii, alpha, lam, probe_radii=1.0, esp_weight=0.5)
Computes the ESP combo score over the batch, takes the mean, and converts it to a loss.
- Return type:
jax.Array
- shepherd_score.alignment._jax.optimize_esp_combo_score_overlay_jax(ref_centers_w_H, fit_centers_w_H, ref_centers, fit_centers, ref_points, fit_points, ref_partial_charges, fit_partial_charges, ref_surf_esp, fit_surf_esp, ref_radii, fit_radii, alpha, lam, probe_radius=1.0, esp_weight=0.5, num_repeats=50, trans_centers=None, lr=0.1, max_num_steps=200, verbose=False)
Optimize alignment of fit_points with respect to ref_points using SE(3) transformations and maximizing ShaEP score.
- Parameters:
ref_centers_w_H (Array (N + n_H, 3)) – Coordinates of atom centers INCLUDING hydrogens of reference molecule. Used for computing electrostatic potential. Same for fit_centers_w_H except (M + m_H, 3).
ref_centers (Array (N, 3) or (n_surf, 3)) – Coordinates of points for the reference molecule used to compute shape similarity. Use atom centers for volumetric similarity and surface centers for surface similarity. Same for fit_centers except (M, 3) or (m_surf, 3).
ref_points (Array (n_surf, 3)) – Coordinates of surface points for the reference molecule. Same for fit_points except (m_surf, 3).
ref_partial_charges (Array (N + n_H,)) – Partial charges corresponding to the atoms in ref_centers_w_H. Same for fit_partial_charges except (M + m_H,).
ref_surf_esp (Array (n_surf,)) – The electrostatic potential calculated at each surface point (ref_points). Same for fit_surf_esp except (m_surf,).
ref_radii (Array (N + n_H,)) – vdW radii (angstroms) corresponding to the atoms in ref_centers_w_H. Same for fit_radii except (M + m_H,).
alpha (float) – Gaussian width parameter used in the shape similarity scoring function.
lam (float (default = 0.001)) – Electrostatic potential weighting parameter (smaller = higher weight). 0.001 was chosen as the default based on empirical observations of the distribution of scores generated by _esp_comparison before summation.
probe_radius (float (default = 1.0)) – Surface points found within vdW radii + probe radius are masked out. Surface generation uses a probe radius of 1.2 (radius of hydrogen), so a slightly lower radius is used here to be more tolerant.
esp_weight (float (default = 0.5)) – Weight placed on electrostatic similarity relative to shape similarity: 0 = only shape similarity; 1 = only electrostatic similarity.
num_repeats (int (default=50)) – Number of different random initializations of SE(3) transformation parameters.
trans_centers (array (P, 3) (default=None)) – Locations to which fit_points’ center of mass is translated as initial guesses for optimization. At each translation center, 10 rotations are also sampled, so the number of initializations is (# translation centers * 10 + 5), where the 5 comes from the identity and 4 PCA-based guesses with aligned COMs. If None, num_repeats rotations are used with aligned COMs.
lr (float (default=0.1)) – Learning rate (step size) for optimization.
max_num_steps (int (default=200)) – Maximum number of optimization steps.
verbose (bool (default=False)) – Print the initial and final similarity scores; during optimization, also print scores every 100 steps.
fit_centers_w_H (jax.Array | ndarray)
fit_centers (jax.Array | ndarray)
fit_points (jax.Array | ndarray)
fit_surf_esp (jax.Array | ndarray)
- Returns:
- aligned_points : Array (M,3)
The transformed point cloud for fit_points using the optimized SE(3) transformation for alignment with ref_points.
- SE3_transform : Array (4,4)
Optimized SE(3) transformation matrix used to obtain aligned_points from fit_points.
- score : Array (1,)
ShaEP similarity score for the optimal transformation.
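How esp_weight trades off the two terms can be sketched as a linear blend, which matches the stated endpoints (0 = only shape, 1 = only electrostatics). The linear form and the function name are assumptions; the library may combine the terms differently between those endpoints.

```python
def esp_combo_score(shape_similarity, esp_similarity, esp_weight=0.5):
    """Hypothetical linear blend of shape and ESP similarity.

    esp_weight = 0 returns shape similarity only; esp_weight = 1 returns
    electrostatic similarity only. Works with Python floats or JAX arrays.
    """
    return (1.0 - esp_weight) * shape_similarity + esp_weight * esp_similarity

shape_only = esp_combo_score(0.8, 0.4, esp_weight=0.0)
esp_only = esp_combo_score(0.8, 0.4, esp_weight=1.0)
balanced = esp_combo_score(0.8, 0.4)   # default esp_weight=0.5
```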
- shepherd_score.alignment._jax.batched_obj_pharm_overlay_helper(se3_params, ref_pharms, fit_pharms, ref_anchors, fit_anchors, ref_vectors, fit_vectors, similarity='tanimoto', extended_points=False, only_extended=False)
Objective function to optimize pharmacophore overlay for a single instance.
- Return type:
jax.Array
- shepherd_score.alignment._jax.objective_pharm_overlay_jax(se3_params, ref_pharms, fit_pharms, ref_anchors, fit_anchors, ref_vectors, fit_vectors, similarity='tanimoto', extended_points=False, only_extended=False)
Objective function to optimize pharmacophore overlay. Batched.
- Return type:
jax.Array
- shepherd_score.alignment._jax.optimize_pharm_overlay_jax(ref_pharms, fit_pharms, ref_anchors, fit_anchors, ref_vectors, fit_vectors, similarity='tanimoto', extended_points=False, only_extended=False, num_repeats=50, trans_centers=None, lr=0.1, max_num_steps=200, verbose=False)
Optimize alignment of fit_anchors with respect to ref_anchors using SE(3) transformations and maximizing pharmacophore overlap score. JAX implementation.
- Parameters:
ref_pharms (jax.Array)
fit_pharms (jax.Array)
ref_anchors (jax.Array)
fit_anchors (jax.Array)
ref_vectors (jax.Array)
fit_vectors (jax.Array)
similarity (Literal['tanimoto', 'tversky', 'tversky_ref', 'tversky_fit'])
extended_points (bool)
only_extended (bool)
num_repeats (int)
trans_centers (jax.Array | ndarray | None)
lr (float)
max_num_steps (int)
verbose (bool)
- Return type:
Tuple[jax.Array, jax.Array, jax.Array, jax.Array]
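The similarity options differ in how the shared overlap is normalized. The sketch below shows the Tanimoto ratio and a generic Tversky index written in overlap form; with a = b = 1 it reduces to Tanimoto, and setting one weight to 0 normalizes by a single molecule's self-overlap. Which (a, b) values the library uses for 'tversky', 'tversky_ref' and 'tversky_fit' is not documented here, so the weights below are illustrative only.

```python
def tanimoto_sim(v_ab, v_aa, v_bb):
    # Shared overlap divided by the union-like total
    return v_ab / (v_aa + v_bb - v_ab)

def tversky_sim(v_ab, v_aa, v_bb, a, b):
    # Generic Tversky index: a weights A's unshared overlap, b weights B's
    return v_ab / (v_ab + a * (v_aa - v_ab) + b * (v_bb - v_ab))

v_ab, v_aa, v_bb = 3.0, 5.0, 4.0
t = tanimoto_sim(v_ab, v_aa, v_bb)               # 3 / 6
sym = tversky_sim(v_ab, v_aa, v_bb, 1.0, 1.0)    # equals Tanimoto
ref_norm = tversky_sim(v_ab, v_aa, v_bb, 1.0, 0.0)   # v_ab / v_aa
fit_norm = tversky_sim(v_ab, v_aa, v_bb, 0.0, 1.0)   # v_ab / v_bb
```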
- shepherd_score.alignment._jax.optimize_pharm_overlay_jax_vectorized(ref_pharms, fit_pharms, ref_anchors, fit_anchors, ref_vectors, fit_vectors, similarity='tanimoto', extended_points=False, only_extended=False, num_repeats=50, trans_centers=None, lr=0.1, max_num_steps=200, verbose=False)
Optimize pharmacophore overlay with JAX, using the fully vectorized scoring function.
- Parameters:
ref_pharms (jax.Array)
fit_pharms (jax.Array)
ref_anchors (jax.Array)
fit_anchors (jax.Array)
ref_vectors (jax.Array)
fit_vectors (jax.Array)
similarity (str)
extended_points (bool)
only_extended (bool)
num_repeats (int)
trans_centers (jax.Array | ndarray | None)
lr (float)
max_num_steps (int)
verbose (bool)
- Return type:
Tuple[jax.Array, jax.Array, jax.Array, jax.Array]
- shepherd_score.alignment._jax.optimize_pharm_overlay_jax_vectorized_mask(ref_pharms, fit_pharms, ref_anchors, fit_anchors, ref_vectors, fit_vectors, mask_ref, mask_fit, similarity='tanimoto', extended_points=False, only_extended=False, num_repeats=50, trans_centers=None, init_ref_anchors=None, init_fit_anchors=None, lr=0.1, max_num_steps=200, verbose=False)
Optimize pharmacophore overlay with padded/masked arrays via JAX.
Uses get_overlap_pharm_jax_vectorized_mask so that padding entries never contribute to the overlap. Accepts init_ref_anchors/init_fit_anchors (the original unpadded arrays) for PCA/COM-based SE(3) initialization to avoid zero-padding bias.
- Parameters:
ref_pharms (jax.Array)
fit_pharms (jax.Array)
ref_anchors (jax.Array)
fit_anchors (jax.Array)
ref_vectors (jax.Array)
fit_vectors (jax.Array)
mask_ref (jax.Array)
mask_fit (jax.Array)
similarity (str)
extended_points (bool)
only_extended (bool)
num_repeats (int)
trans_centers (jax.Array | ndarray | None)
init_ref_anchors (ndarray | None)
init_fit_anchors (ndarray | None)
lr (float)
max_num_steps (int)
verbose (bool)
- Return type:
Tuple[jax.Array, jax.Array, jax.Array, jax.Array]
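The zero-padding bias that init_ref_anchors/init_fit_anchors avoid is easy to see for the center of mass: averaging a padded array pulls the COM toward the origin. The sketch below compares the naive mean with a mask-weighted mean; masked_com is a hypothetical alternative shown only for illustration, whereas this function documents passing the original unpadded arrays instead.

```python
import jax.numpy as jnp

def masked_com(points, mask):
    # Mask-weighted mean: padding rows (mask == 0) do not contribute
    return jnp.sum(points * mask[:, None], axis=0) / jnp.sum(mask)

anchors = jnp.array([[4.0, 0.0, 0.0], [6.0, 0.0, 0.0]])   # original, unpadded
padded = jnp.zeros((5, 3)).at[:2].set(anchors)           # zero-padded to 5 rows
mask = jnp.array([1.0, 1.0, 0.0, 0.0, 0.0])

naive_com = padded.mean(axis=0)     # [2., 0., 0.]: pulled toward the origin
masked = masked_com(padded, mask)   # [5., 0., 0.]: matches anchors.mean(axis=0)
```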