"""
Helper functions to convert different data types.
"""
from typing import Union, Tuple, List
from pathlib import Path
import numpy as np
import rdkit
from rdkit import Chem
import rdkit.Chem.rdDetermineBonds
import pandas as pd
[docs]
def write_xyz_file(atomic_numbers: np.ndarray,
positions: np.ndarray,
path_to_file: Union[str, None] = None
) -> str:
"""
Writes an xyz file of an atomistic structure, given np.ndarray of atomic numbers and coordinates.
Arguments
---------
atomic_numbers : np.ndarray of shape (N,) containing atomic numbers
positions : np.ndarray of shape (N,3) containing atomic coordinates
path_to_file : str specifying file path -- e.g. path_to_file = 'examples/molecule.xyz'. If None, then no output file is written.
Returns
-------
str : xyz block
"""
N = atomic_numbers.shape[0]
xyz = ''
xyz += f'{N}\n\n'
for i in range(0,N):
a = int(atomic_numbers[i])
p = positions[i]
xyz+= f'{rdkit.Chem.Atom(a).GetSymbol()} {p[0]:>15.8f} {p[1]:>15.8f} {p[2]:>15.8f}\n'
xyz+= '\n'
if path_to_file is not None:
with open(f'{path_to_file}', 'w') as f:
f.write(xyz)
return xyz
[docs]
def write_xyz_file_with_dummy(
atomic_numbers: np.ndarray,
positions: np.ndarray,
path_to_file: Union[str, None] = None
) -> Tuple[str, Union[np.ndarray, None]]:
"""
Writes an xyz file of an atomistic structure, given np.ndarray of atomic numbers and coordinates.
Accounts for the presence of dummy atoms.
Arguments
---------
atomic_numbers : np.ndarray of shape (N,) containing atomic numbers
positions : np.ndarray of shape (N,3) containing atomic coordinates
path_to_file : str specifying file path -- e.g. path_to_file = 'examples/molecule.xyz'. If None, then no output file is written.
Returns
-------
Tuple
xyz : str : xyz block
dummy_atom_pos : np.ndarray : positions of dummy atoms
"""
real_atom_inds = np.where(atomic_numbers != 0)[0]
N = len(real_atom_inds)
xyz = ''
xyz += f'{N}\n\n'
for i in real_atom_inds:
a = int(atomic_numbers[i])
p = positions[i]
xyz+= f'{rdkit.Chem.Atom(a).GetSymbol()} {p[0]:>15.8f} {p[1]:>15.8f} {p[2]:>15.8f}\n'
xyz+= '\n'
if path_to_file is not None:
with open(f'{path_to_file}', 'w') as f:
f.write(xyz)
dummy_atom_pos = None
if len(atomic_numbers) - len(real_atom_inds) > 0:
dummy_atom_pos = positions[np.where(atomic_numbers == 0)[0]]
return xyz, dummy_atom_pos
[docs]
def get_xyz_content(atomic_numbers: np.ndarray,
positions: np.ndarray
) -> str:
"""
Get the xyz block of an atomistic structure.
"""
xyz = write_xyz_file(atomic_numbers, positions, path_to_file=None)
return xyz
[docs]
def get_xyz_content_with_dummy(
atomic_numbers: np.ndarray,
positions: np.ndarray
) -> Tuple[str, Union[np.ndarray, None]]:
"""
Get the xyz block of an atomistic structure and remove dummy atoms from the xyz block.
Arguments
---------
atomic_numbers : np.ndarray of shape (N,) containing atomic numbers
positions : np.ndarray of shape (N,3) containing atomic coordinates
Returns
-------
Tuple
xyz : str : xyz block (without dummy atoms)
dummy_atom_pos : np.ndarray : positions of dummy atoms
"""
xyz, dummy_atom_pos = write_xyz_file_with_dummy(atomic_numbers, positions, path_to_file=None)
return xyz, dummy_atom_pos
[docs]
def get_mol_from_atom_pos(atoms: np.ndarray,
positions: np.ndarray
) -> Tuple[Union[Chem.Mol, None], int, str]:
"""
Try to get a RDKit mol object from atom and coordinate arrays.
Arguments
---------
atoms : np.ndarray (N,) of atomic numbers of the generated molecule or (N,M) one-hot
encoding.
positions : np.ndarray (N,3) of coordinates for the generated molecule's atoms.
Returns
-------
Tuple
mol : Chem.Mol or None
charge : int overall charge of molecule
xyz_block : str
"""
if len(atoms.shape) == 2:
atomic_nums = np.argmin(np.abs(atoms - 1.0), axis = -1)
else:
atomic_nums = atoms
xyz_block = write_xyz_file(atomic_nums, positions)
for charge in [0, 1, -1, 2, -2]:
try:
mol = extract_mol_from_xyz_block(xyz_block=xyz_block, charge=charge)
except Exception:
mol = None
if mol is not None:
break
else:
charge = 0
return mol, charge, xyz_block
[docs]
def get_smiles_from_atom_pos(atoms: np.ndarray,
positions: np.ndarray
) -> Union[str, None]:
"""
Try to get a SMILES string from atom and coordinate arrays.
Arguments
---------
atoms : np.ndarray (N,) of atomic numbers of the generated molecule or (N,M) one-hot
encoding.
positions : np.ndarray (N,3) of coordinates for the generated molecule's atoms.
Returns
-------
SMILES str or None
"""
mol, _, _ = get_mol_from_atom_pos(atoms=atoms, positions=positions)
smiles = None
if mol is not None:
smiles = Chem.MolToSmiles(Chem.RemoveHs(mol))
return smiles
[docs]
def load_npz_to_df(npz_path: Union[Path, str],
file_id: bool
) -> pd.DataFrame:
"""
Function to load a single npz file and return a dataframe with expanded zero-dimensional arrays.
This works specifically for files generated by ConditionalEvalPipeline.
"""
data = np.load(npz_path, allow_pickle=True)
df_dict = {}
# Find the first non-zero dimensional array length (assumed to be N_i)
length = None
for key, arr in data.items():
if arr.ndim == 1 and len(arr) < 50: # Non-zero dimensional array
length = len(arr)
break
# Ensure we have a valid length for the file
if length is None:
raise ValueError(f"No 1D array found in {npz_path}")
# Fill in the dictionary with arrays
for key, arr in data.items():
if key in ('ref_surf_resampling_scores', 'ref_surf_esp_resampling_scores', 'ref_mol_morgan_fp'):
continue
if arr.ndim == 0: # Zero-dimensional array
df_dict[key] = np.repeat(arr, length) # Repeat value to match length N_i
elif arr.ndim == 1 and len(arr) == length: # 1D arrays with length N_i
df_dict[key] = arr
else:
raise ValueError(f"Inconsistent array length for {key} in {npz_path}")
if file_id is not None:
df_dict['file_id'] = np.repeat(file_id, length)
return pd.DataFrame(df_dict)
[docs]
def collate_npz_files(npz_files: List[Union[str, Path]],
include_file_id: bool
) -> pd.DataFrame:
"""
Function to collate all npz files into a single dataframe.
Arguments
---------
npz_files : list of file paths
include_file_id : bool Whether to include a column called "file_id" that groups together
rows that came from the same file.
Returns
-------
pd.DataFrame : rows are each sample, columns are each property, and it repeats any 0d arrays.
"""
dfs = []
for i, npz_file in enumerate(npz_files):
if include_file_id:
df = load_npz_to_df(npz_file, file_id=i)
else:
df = load_npz_to_df(npz_file, file_id=None)
dfs.append(df)
# Concatenate all dataframes
return pd.concat(dfs, ignore_index=True)