Coverage for mddb_workflow/tools/get_bonds.py: 75%

175 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-29 15:48 +0000

1from mddb_workflow.tools.get_pdb_frames import get_pdb_frames 

2from mddb_workflow.utils.auxiliar import load_json, warn, MISSING_TOPOLOGY 

3from mddb_workflow.utils.auxiliar import MISSING_BONDS, JSON_SERIALIZABLE_MISSING_BONDS 

4from mddb_workflow.utils.constants import STANDARD_TOPOLOGY_FILENAME 

5from mddb_workflow.utils.vmd_spells import get_covalent_bonds 

6from mddb_workflow.utils.gmx_spells import get_tpr_bonds as get_tpr_bonds_gromacs 

7from mddb_workflow.utils.gmx_spells import get_tpr_atom_count 

8from mddb_workflow.utils.mda_spells import get_tpr_bonds_mdanalysis 

9from mddb_workflow.utils.type_hints import * 

10import pytraj as pt 

11from collections import Counter 

12 

# Sentinel returned (never raised) by get_tpr_bonds when every extraction method failed
# Callers compare against this exact instance to detect the failure
FAILED_BOND_MINING_EXCEPTION = Exception("Failed to mine bonds")

14 

def get_excluded_atoms_selection (structure : 'Structure', pbc_selection : 'Selection') -> 'Selection':
    """ Select the atoms to be skipped in bonding tests given their "fake" nature.

    Two groups of atoms are excluded:
    - Ions which are not part of the PBC selection. These ions are usually "tweaked" to be bonded
      to another atom although there is no real covalent bond, so they are not taken into account
      when testing coherent bonds or looking for the reference frame.
    - Coarse grain atoms, since their bonds will never be found by a distance/radius guess.
    """
    tweaked_ions_selection = structure.select_ions() - pbc_selection
    return tweaked_ions_selection + structure.select_cg()

24 

def do_bonds_match (
    bonds_1 : list[ list[int] ],
    bonds_2 : list[ list[int] ],
    # A selection of atoms whose bonds are not evaluated
    excluded_atoms_selection : 'Selection',
    # Set verbose as true to show which are the atoms preventing the match
    verbose : bool = False,
    # The rest of inputs are just for logs and debug
    atoms : Optional[ list['Atom'] ] = None,
    # Mismatches are appended here as (atom index, actual bonds, expected bonds)
    counter_list : Optional[ list[ tuple[int, tuple[int, ...], tuple[int, ...]] ] ] = None
) -> bool:
    """ Check if two sets of bonds match perfectly.

    Bonds are given per atom: bonds_1[i] and bonds_2[i] are the indices of the atoms bonded to atom i.
    The order of bonds within an atom is not important.
    Excluded atoms are ignored entirely: their own bonds are not compared and
    bonds pointing to them are removed on both sides before comparing.

    Returns True if all non-excluded atoms have the same bonds in both lists.
    On the first mismatching atom, it optionally prints the mismatch (verbose),
    appends (atom index, sorted actual bonds, sorted expected bonds) to counter_list
    (if given), and returns False.

    Raises ValueError if both lists have a different number of atoms.
    """
    # If the number of atoms in both lists is not matching then there is something very wrong
    if len(bonds_1) != len(bonds_2):
        raise ValueError(f'The number of atoms is not matching in both bond lists ({len(bonds_1)} and {len(bonds_2)})')
    excluded_atom_indices = set(excluded_atoms_selection.atom_indices)
    for atom_index, (atom_bonds_1, atom_bonds_2) in enumerate(zip(bonds_1, bonds_2)):
        # Skip excluded atoms (e.g. ions and coarse grain)
        if atom_index in excluded_atom_indices:
            continue
        atom_bonds_set_1 = set(atom_bonds_1) - excluded_atom_indices
        atom_bonds_set_2 = set(atom_bonds_2) - excluded_atom_indices
        if atom_bonds_set_1 == atom_bonds_set_2:
            continue
        # Sort the bonds so logs are readable and the same mismatch always yields the same tuple
        # Note that tuple(set) order is not canonical, which would split counts in a Counter
        sorted_bonds_1 = sorted(atom_bonds_set_1)
        sorted_bonds_2 = sorted(atom_bonds_set_2)
        if verbose:
            if atoms:
                mismatch_atom_label = atoms[atom_index].label
                print(f' Mismatch in atom {mismatch_atom_label}:')
                it_is_atom_labels = ', '.join([ atoms[index].label for index in sorted_bonds_1 ])
                print(f' It is bonded to atoms {it_is_atom_labels}')
                it_should_be_atom_labels = ', '.join([ atoms[index].label for index in sorted_bonds_2 ])
                print(f' It should be bonded to atoms {it_should_be_atom_labels}')
            else:
                print(f' Mismatch in atom with index {atom_index}:')
                it_is_atom_indices = ','.join([ str(index) for index in sorted_bonds_1 ])
                print(f' It is bonded to atoms with indices {it_is_atom_indices}')
                it_should_be_atom_indices = ','.join([ str(index) for index in sorted_bonds_2 ])
                print(f' It should be bonded to atoms with indices {it_should_be_atom_indices}')
        # Save for failure analysis
        if counter_list is not None:
            counter_list.append((atom_index, tuple(sorted_bonds_1), tuple(sorted_bonds_2)))
        return False
    return True

71 

def get_most_stable_bonds (
    structure_filepath : str,
    trajectory_filepath : str,
    snapshots : int,
    frames_limit : int = 10
) -> list[ list[int] ]:
    """ Get covalent bonds using VMD along different frames and keep only the stable ones.

    Guessing bonds in a single frame may give false positives (2 atoms are very close in that frame
    by accident) or false negatives (2 atoms are very far in that frame by accident).
    Here bonds are guessed in up to frames_limit frames and a bond is kept only when it is
    found in more than half of the analysed frames.

    Returns a list with the (sorted) bonded atom indices for every atom.
    """
    # Get each frame in pdb format to run VMD
    print('Finding most stable bonds')
    frames, _step, _count = get_pdb_frames(structure_filepath, trajectory_filepath,
        snapshots, frames_limit, pbar_bool=True)
    # Find the covalent bonds for every frame
    frame_bonds = [ get_covalent_bonds(frame_pdb) for frame_pdb in frames ]
    # Keep those bonds which are respected in the majority of frames
    # Usually wrong bonds (both false positives and negatives) are formed only in one frame
    # It should not happen that a bond is formed around half of the times
    # Use the number of frames actually analysed, so the majority is computed over real data
    majority_cut = len(frame_bonds) / 2
    atom_count = len(frame_bonds[0])
    most_stable_bonds = []
    for atom_index in range(atom_count):
        # Count in how many frames every bond of this atom appears
        occurrences = Counter()
        for bonds in frame_bonds:
            occurrences.update(bonds[atom_index])
        atom_most_stable_bonds = sorted(bond for bond, n in occurrences.items() if n > majority_cut)
        most_stable_bonds.append(atom_most_stable_bonds)
    return most_stable_bonds

120 

def _print_clash_table (counter_list : list, max_rows : int = 10):
    """ Print a table with the most common bond mismatches found while searching the reference frame. """
    print(' First clash stats:')
    headers = ['Count', 'Atom', 'Is bonding with', 'Should bond with']
    table_data = [ [n, atom_index, bonds, expected_bonds]
        for (atom_index, bonds, expected_bonds), n in Counter(counter_list).most_common(max_rows) ]
    # Every column is as wide as its widest item, header included
    col_widths = [ max(len(str(item)) for item in column) for column in zip(*table_data, headers) ]
    def format_row (row):
        return " | ".join(f"{str(item):>{width}}" for item, width in zip(row, col_widths))
    print(format_row(headers))
    print("-+-".join('-' * width for width in col_widths))
    for row in table_data:
        print(format_row(row))

def get_bonds_reference_frame (
    structure_file : 'File',
    trajectory_file : 'File',
    snapshots : int,
    reference_bonds : list[ list[int] ],
    structure : 'Structure',
    pbc_selection : 'Selection',
    patience : int = 100, # Limit of frames to check before we surrender
    verbose : bool = False,
) -> Optional[int]:
    """ Return a reference frame number where all bonds are exactly as they should (by VMD standards).
    This is the frame used when representing the MD.

    Returns the 0-based index of the first frame (among the first `patience` frames) whose
    guessed bonds match reference_bonds, ignoring excluded atoms (non-PBC ions and coarse grain).
    Returns 0 directly if all atoms are excluded.
    Returns None if no frame matches, after printing the most common mismatches.
    """
    # Set some atoms which are to be skipped from these test given their "fake" nature
    excluded_atoms_selection = get_excluded_atoms_selection(structure, pbc_selection)
    # If all atoms are to be excluded then set the first frame as the reference frame and stop here
    if len(excluded_atoms_selection) == len(structure.atoms): return 0
    # Now that we have the reference bonds, we must find a frame where bonds are exactly the reference ones
    # IMPORTANT: Note that we do not set a frames limit here, so all frames will be read and the step will be 1
    frames, step, frame_count = get_pdb_frames(structure_file.path, trajectory_file.path, snapshots, patience=patience)
    bonds_reference_frame = None
    counter_list = []
    try:
        if step != 1: raise ValueError('If we are skipping frames then the code below will silently return a wrong reference frame')
        print(f'Searching the reference frame for the bonds. Only first {min(patience, frame_count)} frames will be checked.')
        # We stop as soon as we find a match
        # Since step is 1, the enumeration index is the actual frame number
        for frame_number, frame_pdb in enumerate(frames):
            bonds = get_covalent_bonds(frame_pdb)
            if do_bonds_match(bonds, reference_bonds, excluded_atoms_selection, counter_list=counter_list, verbose=verbose):
                bonds_reference_frame = frame_number
                break
    finally:
        # Close the frames iterator even if something failed while checking frames
        frames.close()
    # If no frame has the reference bonds then we return None
    if bonds_reference_frame is None:
        _print_clash_table(counter_list)
        return None
    # Frame numbers are printed 1-based
    print(f' Got it -> Frame {bonds_reference_frame + 1}')
    return bonds_reference_frame

177 

def mine_topology_bonds (bonds_source_file : Union['File', Exception]) -> Optional[ list[ list[int] ] ]:
    """ Extract bonds from a source file and format them per atom.

    Supported sources, in order:
    - Our standard topology (STANDARD_TOPOLOGY_FILENAME): its 'atom_bonds' field is used, with
      JSON_SERIALIZABLE_MISSING_BONDS flags converted back to MISSING_BONDS.
    - Any file pytraj can load as a topology.
    - TPR files, through get_tpr_bonds.

    Returns None when there is no topology, the file does not exist, or no bonds could be mined,
    so bonds are guessed further.
    """
    # If there is no topology then return no bonds at all
    if bonds_source_file == MISSING_TOPOLOGY or not bonds_source_file.exists:
        return None
    print('Mining atom bonds from topology file')
    # If we have the standard topology then get bonds from it
    if bonds_source_file.filename == STANDARD_TOPOLOGY_FILENAME:
        print(f' Bonds in the "{bonds_source_file.filename}" file will be used')
        standard_topology = load_json(bonds_source_file.path)
        # Old standard topologies may have no 'atom_bonds' field (or a null one)
        # Fall back to an empty list so we reach the warning below instead of crashing
        standard_atom_bonds = standard_topology.get('atom_bonds', None) or []
        # Convert missing bonds flags
        # These come from coarse grain (CG) simulations with no topology
        atom_bonds = [ MISSING_BONDS if bonds == JSON_SERIALIZABLE_MISSING_BONDS else bonds
            for bonds in standard_atom_bonds ]
        if atom_bonds: return atom_bonds
        print(' There were no bonds in the topology file. Is this an old file?')
    # In some occasions, bonds may come inside a topology which can be parsed through pytraj
    elif bonds_source_file.is_pytraj_supported():
        print(f' Bonds will be mined from "{bonds_source_file.path}"')
        pt_topology = pt.load_topology(filename=bonds_source_file.path)
        raw_bonds = [ bonds.indices for bonds in pt_topology.bonds ]
        # Sort bonds
        atom_count = pt_topology.n_atoms
        atom_bonds = sort_bonds(raw_bonds, atom_count)
        # If there is any bonding data then return atom bonds
        if any(len(bonds) > 0 for bonds in atom_bonds): return atom_bonds
    # If we have a TPR then use our own tool
    elif bonds_source_file.format == 'tpr':
        print(f' Bonds will be mined from TPR file "{bonds_source_file.path}"')
        raw_bonds = get_tpr_bonds(bonds_source_file.path)
        # Compare by identity: an elementwise '!=' would be ambiguous if raw_bonds were array-like
        if raw_bonds is not FAILED_BOND_MINING_EXCEPTION:
            # Sort bonds
            atom_count = get_tpr_atom_count(bonds_source_file.path)
            atom_bonds = sort_bonds(raw_bonds, atom_count)
            return atom_bonds
    # If we failed to mine bonds then return None and they will be guessed further
    print (' Failed to mine bonds -> They will be guessed from atom distances and radius')
    return None

221 

def get_tpr_bonds (tpr_filepath : str) -> Union[ list[ tuple[int, int] ], Exception ]:
    """ Get TPR bonds as a list of bonded atom index pairs.

    Try 2 different methods and hope 1 of them works:
    first our GROMACS-based tool, then MDAnalysis if the former raises.
    If both fail, FAILED_BOND_MINING_EXCEPTION is returned (not raised).
    """
    try:
        bonds = get_tpr_bonds_gromacs(tpr_filepath)
    # Do not use a bare except: KeyboardInterrupt and SystemExit must not be swallowed
    except Exception:
        print(' Our tool failed to extract bonds. Using MDAnalysis extraction...')
        try:
            bonds = get_tpr_bonds_mdanalysis(tpr_filepath)
            # This happens when the filter selection is what we want to exclude, and not to keep
            if len(bonds) == 0:
                warn('Bonds were mined successfully but it looks like there are no bonds. Is it all ions or CG solvent?')
        except Exception as err:
            print(f' MDAnalysis failed to extract bonds: {err}. Relying on guess...')
            bonds = FAILED_BOND_MINING_EXCEPTION
    return bonds

238 

def sort_bonds (source_bonds : list[ tuple[int, int] ], atom_count : int) -> list[ list[int] ]:
    """ Convert bonds from the usual pair format into our per-atom format.

    The input is a list of (atom index, atom index) pairs, one per bond.
    The output has one entry per atom (atom_count entries in total).
    Entry i holds the indices of the atoms bonded to atom i, in the order the pairs appear.
    Indices are cast to regular Python integers so the result is JSON serializable.
    """
    bonded_atoms = [ [] for _ in range(atom_count) ]
    for first_atom, second_atom in source_bonds:
        bonded_atoms[first_atom].append(int(second_atom))
        bonded_atoms[second_atom].append(int(first_atom))
    return bonded_atoms

250 

def find_safe_bonds (
    topology_file : Union['File', Exception],
    structure_file : 'File',
    trajectory_file : 'File',
    must_check_stable_bonds : bool,
    snapshots : int,
    structure : 'Structure',
    guess_bonds : bool = False,
    # Optional file with bonds sorted according a new atom order
    resorted_bonds_file : Optional['File'] = None
) -> list[list[int]]:
    """ Find reference safe bonds in the system.

    Sources, in order of preference:
    1. The resorted bonds file, if given and existing (very exceptional).
    2. Bonds mined from the topology file, unless guess_bonds forces guessing.
    3. If all atoms are coarse grain, MISSING_BONDS for every atom.
    4. If must_check_stable_bonds, the most stable bonds guessed along the trajectory.
    5. Otherwise, the structure bonds, which are trusted as they are.
    In cases 4 and 5, the bonds of coarse grain atoms are replaced by MISSING_BONDS.
    """
    # If we have a resorted file then use it
    # Note that this is very exceptional
    if resorted_bonds_file is not None and resorted_bonds_file.exists:
        warn('Using resorted safe bonds')
        return load_json(resorted_bonds_file.path)
    # Get a selection including coarse grain atoms in the structure
    cg_selection = structure.select_cg()
    # Try to get bonds from the topology before guessing
    if guess_bonds:
        print('Bonds were forced to be guessed, so bonds in the topology file will be ignored')
    else:
        safe_bonds = mine_topology_bonds(topology_file)
        if safe_bonds:
            return safe_bonds
    # If all atoms are in coarse grain then set all bonds "wrong" already
    if len(cg_selection) == structure.atom_count:
        safe_bonds = [ MISSING_BONDS for atom in range(structure.atom_count) ]
        return safe_bonds
    # If failed to mine topology bonds then guess stable bonds
    print('Bonds will be guessed by atom distances and radii along different frames in the trajectory')
    # Find stable bonds if necessary
    if must_check_stable_bonds:
        # Using the trajectory, find the most stable bonds
        print('Checking bonds along trajectory to determine which are stable')
        safe_bonds = get_most_stable_bonds(structure_file.path, trajectory_file.path, snapshots)
        discard_coarse_grain_bonds(safe_bonds, cg_selection)
        return safe_bonds
    # If we trust stable bonds then simply use structure bonds
    print('Default structure bonds will be used since they have been marked as trusted')
    # Copy the list so discarding CG bonds does not mutate the structure's own bonds
    safe_bonds = list(structure.bonds)
    discard_coarse_grain_bonds(safe_bonds, cg_selection)
    return safe_bonds

298 

def discard_coarse_grain_bonds (bonds : list, cg_selection : 'Selection'):
    """ Replace the bonds of every coarse grain atom with the MISSING_BONDS flag.

    The input list is modified in place and nothing is returned.
    MISSING_BONDS is meant to raise an error when read, so using these wrong bonds
    anywhere further results in a failure instead of silently wrong data.
    """
    cg_atom_indices = cg_selection.atom_indices
    for cg_atom_index in cg_atom_indices:
        bonds[cg_atom_index] = MISSING_BONDS