Coverage for mddb_workflow / tools / get_inchi_keys.py: 62%

138 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 18:45 +0000

1import MDAnalysis 

2import multiprocessing 

3from dataclasses import dataclass, field 

4from mddb_workflow.tools.get_ligands import pubchem_standardization 

5from mddb_workflow.utils.structures import Structure 

6from mddb_workflow.utils.auxiliar import warn, save_json 

7from mddb_workflow.utils.type_hints import * 

8from rdkit import Chem 

9from rdkit.Chem.MolStandardize import rdMolStandardize 

10from rdkit.Chem.rdDetermineBonds import DetermineBondOrders 

11from MDAnalysis.converters.RDKitInferring import MDAnalysisInferrer 

12 

13 

14@dataclass 

15class InChIKeyData: 

16 """Data structure for InChI key information. 

17 

18 Attributes: 

19 inchi (str): The InChI string for the residue. 

20 resindices (list[int]): List of all residue indices with this InChI key. 

21 fragments (list[list[int]]): Lists of residue indices that are connected as a single group. 

22 resnames (set[str]): Set of residue names associated with this InChI key. 

23 molname (str): Representative molecule name for this InChI key. 

24 moltype (Literal['residue', 'fragment']): Type of the molecule. 

25 classification (set): Set of residue classifications for this InChI key. 

26 frag_len (int): Length of the fragments. 1 if no fragments are present. 

27 references (dict): Additional database references related to this InChI key. 

28 

29 """ 

30 inchi: str 

31 resindices: list[int] = field(default_factory=list) 

32 fragments: list[list[int]] = field(default_factory=list) 

33 resnames: set[str] = field(default_factory=set) 

34 molname: str = '' 

35 moltype: Literal['residue', 'fragment'] = 'residue' 

36 classification: set = field(default_factory=set) 

37 frag_len: int = 1 

38 references: dict = field(default_factory=dict) 

39 

40 @classmethod 

41 def load_cache(cls, cache_dict: dict[dict | None]) -> dict[str, 'InChIKeyData']: 

42 """Load data from cache by converting dictionaries to InChIKeyData objects.""" 

43 if len(cache_dict) > 0 and isinstance(next(iter(cache_dict.values())), dict): 

44 inchikeys = {} 

45 for inchikey, data in cache_dict.items(): 

46 inchidata = cls(inchi=data['inchi']) 

47 for key, value in data.items(): 

48 setattr(inchidata, key, value) 

49 inchikeys[inchikey] = inchidata 

50 return inchikeys 

51 else: 

52 return cache_dict 

53 

54 

def is_ferroheme(mda_atoms: 'MDAnalysis.AtomGroup') -> bool:
    """Tell whether an MDAnalysis AtomGroup is a ferroheme molecule.

    The Fe atom and its bonds are removed and two nitrogens are protonated to make up
    for the lost Fe-N bonds. Bond orders are then re-determined, the molecule is
    charge-neutralized, and its InChI is standardized through PubChem.

    Args:
        mda_atoms: Atoms of the candidate residue (expected to contain Fe).

    Returns:
        True when the standardized PubChem CID of the Fe-free molecule is '4971'.
    """
    # Work on an independent copy so bonds can be deleted without touching the source universe
    merged = MDAnalysis.Merge(mda_atoms)
    iron_bonds = merged.select_atoms('element Fe').bonds
    merged.delete_bonds(iron_bonds)
    # Plain RDKit molecule without Fe: no inferred bond orders nor formal charges
    bare_mol = merged.select_atoms('not element Fe').convert_to.rdkit(inferrer=None)

    # Protonate the first two nitrogens in exchange for the removed Fe bonds.
    # Which nitrogens receive the H does not matter, the molecule is standardized later.
    editable = Chem.RWMol(bare_mol)
    nitrogens = [atom for atom in editable.GetAtoms() if atom.GetSymbol() == 'N']
    for nitrogen in nitrogens[:2]:
        hydrogen_idx = editable.AddAtom(Chem.Atom('H'))
        editable.AddBond(nitrogen.GetIdx(), hydrogen_idx, Chem.BondType.SINGLE)
    protonated = editable.GetMol()

    # Total formal charge as inferred by MDAnalysis.
    # NOTE(review): the charge is inferred on `bare_mol` (without the two added H),
    # while bond orders are determined on `protonated` — confirm this is intended.
    inferred = MDAnalysisInferrer()(bare_mol)
    total_charge = Chem.GetFormalCharge(inferred)
    DetermineBondOrders(protonated, charge=total_charge, maxIterations=1000)
    neutral = rdMolStandardize.ChargeParent(protonated)

    pubchem_hits = pubchem_standardization(Chem.MolToInchi(neutral))
    # '4971' is the PubChem CID for ferroheme without Fe
    return pubchem_hits[0]['pubchem'] == '4971'

82 

83 

def residue_to_inchi(task: tuple['MDAnalysis.AtomGroup', list[int]]) -> tuple[str, str, list[int]]:
    """Compute the InChI key and InChI string of one residue or bonded group of residues.

    Used as a ``multiprocessing.Pool.map`` worker, hence the single tuple argument.

    Args:
        task: ``(resatoms, resindices)``. ``resatoms`` are the atoms to convert and
            ``resindices`` are the residue indices they come from, passed through untouched.

    Returns:
        ``(inchikey, inchi, resindices)``.

    Raises:
        NotImplementedError: If the atoms contain Fe (and more than one atom) but are
            not recognized as ferroheme.
    """
    resatoms, resindices = task

    if 'Fe' in set(resatoms.atoms.elements) and len(resatoms.atoms) > 1:
        # Metalloproteins are not handled well by RDKit/InChI, so the InChI key of
        # ferroheme, the only case found so far, is hardcoded. The RDKit conversion
        # is deliberately not attempted here: its result would be discarded.
        if not is_ferroheme(resatoms):
            raise NotImplementedError('Non-ferroheme residues with Fe are not supported.')
        inchikey = 'KABFMIBPWCXCRK-UHFFFAOYSA-L'
        inchi = 'InChI=1S/C34H34N4O4.Fe/c1-7-21-17(3)25-13-26-19(5)23(9-11-33(39)40)31(37-26)16-32-24(10-12-34(41)42)20(6)28(38-32)15-30-22(8-2)18(4)27(36-30)14-29(21)35-25;/h7-8,13-16H,1-2,9-12H2,3-6H3,(H4,35,36,37,38,39,40,41,42);/q;+2/p-2'
        return (inchikey, inchi, resindices)

    # Convert to RDKit (bond orders and charges inferred by MDAnalysis)
    mol = resatoms.convert_to.rdkit(force=True)
    # Net charge from the topology partial charges, when the topology provides them
    formal_charge = (int(resatoms.atoms.charges.sum().round())
                     if hasattr(resatoms.atoms, 'charges')
                     else None)
    # For charged residues DetermineBondOrders is better, as the formal charge of the
    # molecule can be set (needed for A026E). Neutral residues (charge 0) and residues
    # whose inferred charge already matches the topology skip this step.
    if formal_charge and Chem.GetFormalCharge(mol) != formal_charge:
        candidate = Chem.Mol(mol)
        try:
            DetermineBondOrders(candidate, charge=formal_charge)
        except Exception:
            # Best effort: DetermineBondOrders can fail; keep the inferred molecule then
            pass
        else:
            mol = candidate

    inchikey = Chem.MolToInchiKey(mol)
    # rdinchi.MolToInchi so it does not print the warnings; only the InChI string is used
    inchi, *_ = Chem.rdinchi.MolToInchi(mol)
    return (inchikey, inchi, resindices)

117 

118 

def generate_inchikeys(
    universe: 'MDAnalysis.Universe',
    structure: 'Structure',
) -> dict[str, InChIKeyData]:
    """Generate a dictionary mapping InChI keys to residue information for non-standard residues.

    Every bonded fragment (group of residues bonded together) of ``universe`` is
    considered. The following fragments are skipped:

    - fragments containing any 'solvent' residue;
    - multi-residue fragments containing 'dna', 'rna' or 'protein' residues (polymer chains);
    - coarse-grained fragments (atom type 'Cg').

    The remaining fragments are converted to RDKit in parallel to obtain their InChI key
    and InChI string. Atom coordinates are necessary to distinguish stereoisomers.

    Args:
        universe (Universe): The MDAnalysis Universe object containing the structure and topology.
        structure (Structure): The Structure object containing residues. Its residues are
            assumed to be indexed by the same residue indices as ``universe``.

    Returns:
        dict: A dictionary mapping InChI keys to InChIKeyData objects.

    Raises:
        AssertionError: If fragments sharing an InChI key have different numbers of residues.

    Notes:
        The function also performs consistency checks. It warns if multiple residue names
        map to the same InChI key, if one InChI key gets several classifications, or if
        multiple InChI keys map to the same residue name. These can indicate mismatched
        residue definitions or stereoisomers.

    """
    try:
        universe.universe.atoms.charges
    except Exception:
        warn('Topology file does not have charges, InChI keys may be unreliable.')

    residues = structure.residues

    # 1) Prepare one task per bonded fragment for parallel processing
    tasks = []
    for fragment in universe.atoms.fragments:
        resindices = fragment.residues.resindices.tolist()

        # Skip solvent and polymer chains (multi-residue dna/rna/protein fragments)
        classes = {residues[resindex].classification for resindex in resindices}
        if 'solvent' in classes or (
                classes & {'dna', 'rna', 'protein'} and len(resindices) > 1):
            continue

        resatoms = universe.residues[resindices].atoms
        if 'Cg' in resatoms.types:
            # Skip coarse grain residues
            continue
        # Passing an atom selection to a parallel worker pickles its whole MDAnalysis
        # universe, which is slow. To avoid this, the atoms are merged into a small
        # standalone universe and that one is sent instead.
        tasks.append((MDAnalysis.Merge(resatoms).universe.atoms, resindices))

    # Convert to RDKit and get InChI data in parallel (no pool needed without tasks)
    results = []
    if tasks:
        with multiprocessing.Pool() as pool:
            results = pool.map(residue_to_inchi, tasks)

    # 2) Process results and build dictionaries
    inchikeys: dict[str, InChIKeyData] = {}  # To see if different names for same residue
    name_2_key: dict[str, list[str]] = {}  # To see if different residues under same name
    for inchikey, inchi, resindices in results:
        data = inchikeys.setdefault(inchikey, InChIKeyData(inchi=inchi))
        data.resindices.extend(resindices)
        # Multi-residue molecules are named by joining their sorted residue names
        resnames = '-'.join(sorted(residues[index].name for index in resindices))
        data.resnames.add(resnames)
        if len(resindices) > 1:
            # Sorted so the same set of classes always yields the same tuple; otherwise
            # set iteration order could create spurious "different classifications"
            classes = tuple(sorted({residues[index].classification for index in resindices}, key=str))
            data.classification.add(classes)
            # Glucolipids: save the groups of residues that form a 'fragment' to solve
            # a problem with FATSLiM later (ex: A01IR, A01J5)
            data.fragments.append(list(map(int, resindices)))
        else:
            data.classification.add(residues[resindices[0]].classification)
        # Incorrect residue name, stereoisomers, loss of atoms...
        name_2_key.setdefault(resnames, []).append(inchikey)

    # 3) Check data coherence
    for inchikey, data in inchikeys.items():
        if len(data.resnames) > 1:
            warn(f'Same residue with different names:\n {inchikey} -> {str(data.resnames)}')
        # Pick one representative name; min() keeps the choice reproducible across runs
        data.molname = min(data.resnames)
        if len(data.classification) > 1:
            warn('Same residue with different classifications:\n'
                 f'{inchikey} -> {str(data.classification)} for names {str(data.resnames)}')
        # All fragments of the same InChI key must have the same number of residues
        if data.fragments:
            data.moltype = 'fragment'
            frag_lens = {len(fragment) for fragment in data.fragments}
            assert len(frag_lens) == 1, \
                f'Fragments of different lengths for InChI key {inchikey}: {str(frag_lens)}'
            data.frag_len = frag_lens.pop()
        else:
            data.frag_len = 1

    # Check if there are multiple InChI keys for the same name
    for name, keys in name_2_key.items():
        unique_keys = list(dict.fromkeys(keys))  # Deduplicate, keep first-seen order
        if len(unique_keys) < 2:
            continue
        # Number of molecules per key: fragments for multi-residue molecules,
        # residues otherwise (each single residue is one molecule)
        counts = {
            key: (len(inchikeys[key].fragments)
                  if inchikeys[key].fragments
                  else len(inchikeys[key].resindices))
            for key in unique_keys
        }
        key_counts = '\n'.join(f'\t{k}: {c: >4}' for k, c in counts.items())
        warn(f'The fragment {name} has more than one InChi key:\n'
             f'{key_counts}')

    return inchikeys

246 

247 

def generate_inchi_references(
    inchikeys: dict[str, 'InChIKeyData'],
    lipid_references: dict[str, dict],
    ligand_references: dict[str, dict],
    output_file: 'File',
) -> list[dict]:
    """Generate InChI references for the database and a per-key residue map.

    One reference entry is built per InChI key, containing the InChI key and string plus
    SwissLipids, LIPID MAPS and PubChem data. All entries are written to
    ``output_file.path`` with ``save_json``.

    Args:
        inchikeys: InChI key -> InChIKeyData.
        lipid_references: InChI key -> lipid references. Its 'swisslipids' and
            'lipidmaps' entries are used.
        ligand_references: InChI key -> ligand references.
            - Its 'inchikey', 'inchi' and 'resindices' entries, when present, override
              the values computed from the residues (e.g. forced ligands).
            - The whole dict is stored under 'pubchem'.
        output_file: File object whose ``path`` receives the references JSON.

    Returns:
        One dict per InChI key. Each contains the reference InChI key, a representative
        residue name, the residue indices, whether it is a lipid, and the original key
        under ``match.ref.inchikey``.
    """
    inchikey_references = []
    inchikey_map = []
    for inchikey, res_data in inchikeys.items():
        ligand_ref = ligand_references.get(inchikey, {})
        lipid_ref = lipid_references.get(inchikey, {})
        # If there are forced ligands, the InChI key/string may have changed
        ref_inchikey = ligand_ref.get('inchikey', inchikey)
        ref_inchi = ligand_ref.get('inchi', res_data.inchi)
        reference = {
            'inchikey': ref_inchikey,
            'inchi': ref_inchi,
            'swisslipids': lipid_ref.get('swisslipids', {}),
            'lipidmaps': lipid_ref.get('lipidmaps', {}),
            'pubchem': ligand_ref,
        }
        # Sort dictionary entries for consistency when uploading to database.
        # New dicts are built, so the input reference dicts are not modified.
        inchikey_references.append({
            key: dict(sorted(value.items())) if isinstance(value, dict) else value
            for key, value in reference.items()
        })

        # Get residue indices from ligand forced selections if available
        resindices = ligand_ref.get('resindices', list(map(int, res_data.resindices)))
        inchikey_map.append({
            'inchikey': ref_inchikey,
            # Representative name (for RMSDs); fall back to any residue name if unset
            'name': res_data.molname or next(iter(res_data.resnames)),
            'residue_indices': resindices,
            'is_lipid': inchikey in lipid_references,
            'match': {
                'ref': {'inchikey': inchikey}
            }
        })
    save_json(inchikey_references, output_file.path)
    return inchikey_map