Coverage for model_workflow/tools/get_bonds.py: 71%

170 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-23 10:54 +0000

1from model_workflow.tools.get_pdb_frames import get_pdb_frames 

2from model_workflow.utils.auxiliar import load_json, warn, MISSING_TOPOLOGY 

3from model_workflow.utils.auxiliar import MISSING_BONDS, JSON_SERIALIZABLE_MISSING_BONDS 

4from model_workflow.utils.constants import STANDARD_TOPOLOGY_FILENAME 

5from model_workflow.utils.vmd_spells import get_covalent_bonds 

6from model_workflow.utils.gmx_spells import get_tpr_bonds as get_tpr_bonds_gromacs 

7from model_workflow.utils.gmx_spells import get_tpr_atom_count 

8from model_workflow.utils.type_hints import * 

9import pytraj as pt 

10from MDAnalysis.topology.TPRParser import TPRParser 

11from collections import Counter 

12 

def get_excluded_atoms_selection (structure : 'Structure', pbc_selection : 'Selection') -> 'Selection':
    """Build the selection of atoms to be skipped in bonding tests given their "fake" nature.

    Two groups of atoms are excluded:
    - Ions which are not in PBC. These ions are usually "tweaked" to be bonded to another atom
      although there is no real covalent bond, so they are not taken into account when testing
      coherent bonds or looking for the reference frame.
    - Coarse grain atoms, since their bonds will never be found by a distance/radius guess.
    """
    ions_outside_pbc = structure.select_ions() - pbc_selection
    coarse_grain_atoms = structure.select_cg()
    return ions_outside_pbc + coarse_grain_atoms

22 

def do_bonds_match (
    bonds_1 : List[ List[int] ],
    bonds_2 : List[ List[int] ],
    # A selection of atoms whose bonds are not evaluated
    excluded_atoms_selection : 'Selection',
    # Set verbose as true to show which are the atoms preventing the match
    verbose : bool = False,
    # The rest of inputs are just for logs and debug
    atoms : Optional[ List['Atom'] ] = None,
    counter_list : Optional[ List[int] ] = None
) -> bool:
    """Check if two sets of bonds match perfectly.

    Both inputs hold, for every atom, the indices of the atoms it is bonded to.
    The order of the bonds of an atom is not important.
    Excluded atoms are skipped and they are also ignored as bonding partners of other atoms.

    The search stops at the first mismatching atom. When this happens:
    - If verbose is set, the mismatch is printed (using atom labels when atoms are passed).
    - If a counter_list is passed, a tuple (atom index, bonds in bonds_1, bonds in bonds_2)
      is appended to it for failure analysis.

    Returns True if every evaluated atom has the same bonds in both lists and False otherwise.
    Raises ValueError if the two bond lists do not have the same length.
    """
    # If the number of atoms in both lists is not matching then there is something very wrong
    if len(bonds_1) != len(bonds_2):
        raise ValueError(f'The number of atoms is not matching in both bond lists ({len(bonds_1)} and {len(bonds_2)})')
    skipped_indices = set(excluded_atoms_selection.atom_indices)
    for atom_index, (current_bonds, expected_bonds) in enumerate(zip(bonds_1, bonds_2)):
        # Bonds of excluded atoms are not evaluated
        if atom_index in skipped_indices:
            continue
        current_set = set(current_bonds) - skipped_indices
        expected_set = set(expected_bonds) - skipped_indices
        # Nothing to report when both sets hold exactly the same atom indices
        if current_set == expected_set:
            continue
        # From this point we are handling the first mismatch
        if verbose:
            if atoms:
                print(f' Mismatch in atom {atoms[atom_index].label}:')
                current_labels = ', '.join(atoms[index].label for index in current_set)
                print(f' It is bonded to atoms {current_labels}')
                expected_labels = ', '.join(atoms[index].label for index in expected_set)
                print(f' It should be bonded to atoms {expected_labels}')
            else:
                print(f' Mismatch in atom with index {atom_index}:')
                current_indices = ','.join(map(str, current_set))
                print(f' It is bonded to atoms with indices {current_indices}')
                expected_indices = ','.join(map(str, expected_set))
                print(f' It should be bonded to atoms with indices {expected_indices}')
        # Save for failure analysis
        if counter_list is not None:
            counter_list.append((atom_index, tuple(current_set), tuple(expected_set)))
        return False
    return True

69 

def get_most_stable_bonds (
    structure_filepath : str,
    trajectory_filepath : str,
    snapshots : int,
    frames_limit : int = 10
) -> List[ List[int] ]:
    """Get covalent bonds using VMD along different frames and keep the most stable ones.

    This way we avoid having false positives because 2 atoms are very close in one frame by accident.
    This way we avoid having false negatives because 2 atoms are very far in one frame by accident.

    Returns, for every atom, the list of atom indices it is bonded to in the majority of frames.
    """
    # Get each frame in pdb format to run VMD
    print('Finding most stable bonds')
    frames, step, count = get_pdb_frames(structure_filepath, trajectory_filepath,
        snapshots, frames_limit, pbar_bool=True)

    # Find the covalent bonds for every frame
    bonds_per_frame = [ get_covalent_bonds(frame_pdb) for frame_pdb in frames ]

    # Then keep those bonds which are respected in the majority of frames
    # Usually wrong bonds (both false positives and negatives) are formed in only one frame
    # It should not happen that a bond is formed around half of the times
    majority_cut = count / 2
    atom_count = len(bonds_per_frame[0])
    most_stable_bonds = []
    for atom_index in range(atom_count):
        # Pool the bonds of the current atom along all frames
        pooled_bonds = [ bond for frame_bonds in bonds_per_frame for bond in frame_bonds[atom_index] ]
        # Keep only those bonds with more occurrences than half the number of frames
        stable_bonds = [ bond for bond in set(pooled_bonds) if pooled_bonds.count(bond) > majority_cut ]
        most_stable_bonds.append(stable_bonds)

    return most_stable_bonds

118 

119 

def _print_clash_stats (mismatches : list, limit : int = 10):
    """Print a table with the most repeated bond mismatches.

    Every mismatch is a tuple (atom index, current bonds, expected bonds),
    as appended by do_bonds_match to its counter list.
    """
    print(' First clash stats:')
    headers = ['Count', 'Atom', 'Is bonding with', 'Should bond with']
    table_data = [
        [occurrences, atom_index, current, expected]
        for (atom_index, current, expected), occurrences in Counter(mismatches).most_common(limit)
    ]
    # Calculate column widths, taking the headers into account as well
    col_widths = [ max(len(str(item)) for item in column) for column in zip(*table_data, headers) ]
    # Format rows
    def format_row (row):
        return " | ".join(f"{str(item):>{width}}" for item, width in zip(row, col_widths))
    # Print table
    print(format_row(headers))
    print("-+-".join('-' * width for width in col_widths))
    for row in table_data:
        print(format_row(row))

def get_bonds_canonical_frame (
    structure_file : 'File',
    trajectory_file : 'File',
    snapshots : int,
    reference_bonds : List[ List[int] ],
    structure : 'Structure',
    pbc_selection : 'Selection',
    patience : int = 100, # Limit of frames to check before we surrender
    verbose : bool = False,
) -> Optional[int]:
    """Return a canonical frame number where all bonds are exactly as they should.
    This is the frame used when representing the MD.

    The returned frame number is 0-based. It is 0 when all atoms are excluded from the test.
    None is returned when no checked frame matches the reference bonds.
    In this case a table with the most repeated mismatches is printed.

    Raises ValueError if frames are not read one by one (step different from 1),
    since the returned frame number would then be silently wrong.
    """
    # Set some atoms which are to be skipped from these tests given their "fake" nature
    excluded_atoms_selection = get_excluded_atoms_selection(structure, pbc_selection)

    # If all atoms are to be excluded then set the first frame as the reference frame and stop here
    if len(excluded_atoms_selection) == len(structure.atoms): return 0

    # Now that we have the reference bonds, we must find a frame where bonds are exactly the canonical ones
    # IMPORTANT: Note that we do not set a frames limit here, so all frames will be read and the step will be 1
    frames, step, count = get_pdb_frames(structure_file.path, trajectory_file.path, snapshots, patience=patience)
    reference_bonds_frame = None
    mismatches = []
    # Make sure frames are closed no matter how we leave the search
    try:
        if step != 1: raise ValueError('If we are skipping frames then the code below will silently return a wrong reference frame')
        print(f'Searching reference bonds canonical frame. Only first {min(patience, count)} frames will be checked.')
        # We check all frames but we stop as soon as we find a match
        for frame_number, frame_pdb in enumerate(frames):
            bonds = get_covalent_bonds(frame_pdb)
            if do_bonds_match(bonds, reference_bonds, excluded_atoms_selection, counter_list=mismatches, verbose=verbose):
                reference_bonds_frame = frame_number
                break
    finally:
        frames.close()
    # If no frame has the canonical bonds then show the first clashes and return None
    if reference_bonds_frame is None:
        _print_clash_stats(mismatches)
        return None
    print(f' Got it -> Frame {reference_bonds_frame + 1}')

    return reference_bonds_frame

176 

def mine_topology_bonds (bonds_source_file : Union['File', Exception]) -> Optional[ List[ List[int] ] ]:
    """Extract bonds from a source file and format them per atom.

    Supported sources are the standard topology file, any topology supported by pytraj and TPR files.
    Returns a list with the bonded atom indices for each atom.
    Returns None if there is no topology or if no bonds could be mined from it,
    so bonds can be guessed further by the caller.
    """
    # If there is no topology then return no bonds at all
    if bonds_source_file == MISSING_TOPOLOGY or not bonds_source_file.exists:
        return None
    print('Mining atom bonds from topology file')
    # If we have the standard topology then get bonds from it
    if bonds_source_file.filename == STANDARD_TOPOLOGY_FILENAME:
        print(f' Bonds in the "{bonds_source_file.filename}" file will be used')
        standard_topology = load_json(bonds_source_file.path)
        # The field may be missing or empty, so fall back to an empty list
        # Otherwise iterating it would fail before we can warn the user
        standard_atom_bonds = standard_topology.get('atom_bonds') or []
        # Convert missing bonds flags
        # These come from coarse grain (CG) simulations with no topology
        atom_bonds = [
            MISSING_BONDS if bonds == JSON_SERIALIZABLE_MISSING_BONDS else bonds
            for bonds in standard_atom_bonds
        ]
        if atom_bonds: return atom_bonds
        print(' There were no bonds in the topology file. Is this an old file?')
    # On some occasions, bonds may come inside a topology which can be parsed through pytraj
    elif bonds_source_file.is_pytraj_supported():
        print(f' Bonds will be mined from "{bonds_source_file.path}"')
        pt_topology = pt.load_topology(filename=bonds_source_file.path)
        raw_bonds = [ bonds.indices for bonds in pt_topology.bonds ]
        # Sort bonds
        atom_count = pt_topology.n_atoms
        atom_bonds = sort_bonds(raw_bonds, atom_count)
        # If there is any bonding data then return atom bonds
        if any(len(bonds) > 0 for bonds in atom_bonds): return atom_bonds
    # If we have a TPR then use our own tool
    elif bonds_source_file.format == 'tpr':
        print(f' Bonds will be mined from TPR file "{bonds_source_file.path}"')
        raw_bonds = get_tpr_bonds(bonds_source_file.path)
        # Sort bonds
        atom_count = get_tpr_atom_count(bonds_source_file.path)
        atom_bonds = sort_bonds(raw_bonds, atom_count)
        # If there is any bonding data then return atom bonds
        if any(len(bonds) > 0 for bonds in atom_bonds): return atom_bonds
    # If we failed to mine bonds then return None and they will be guessed further
    print (' Failed to mine bonds -> They will be guessed from atom distances and radius')
    return None

220 

def get_tpr_bonds (tpr_filepath : str) -> List[ Tuple[int, int] ]:
    """Get the bonds in a TPR file as pairs of bonded atoms.

    Try 2 different methods and hope 1 of them works:
    first the tool imported from gmx_spells and then, if it fails, MDAnalysis.
    """
    try:
        return get_tpr_bonds_gromacs(tpr_filepath)
    # Do not use a bare except here, otherwise a KeyboardInterrupt would also trigger the fallback
    # NOTE(review): SystemExit is still caught in case our tool reports failures this way - confirm
    except (Exception, SystemExit):
        print(' Our tool failed to extract bonds. Using MDAnalysis extraction...')
        return get_tpr_bonds_mdanalysis(tpr_filepath)

230 

def get_tpr_bonds_mdanalysis (tpr_filepath : str) -> List[ Tuple[int, int] ]:
    """Get TPR bonds using MDAnalysis.

    WARNING: Sometimes this function takes additional constraints as actual bonds.
    DANI: If you check topology.bonds.values, these fake bonds come at the end.
    DANI: I can tell because indices are in ascending order and then they start over.
    DANI: I asked for help here https://github.com/MDAnalysis/mdanalysis/pull/463
    """
    return list(TPRParser(tpr_filepath).parse().bonds.values)

241 

def sort_bonds (source_bonds : List[ Tuple[int, int] ], atom_count : int) -> List[ List[int] ]:
    """Sort bonds according to our format: a list with the bonded atom indices for each atom.

    Source data is the usual format to store bonds: a list of tuples with every pair of bonded atoms.
    Every bond is registered in both of its atoms.
    """
    # Start with an independent empty list for every atom
    bonds_per_atom = [ [] for _ in range(atom_count) ]
    for first, second in source_bonds:
        # Make sure atom indices are regular integers so they are JSON serializable
        bonds_per_atom[first].append(int(second))
        bonds_per_atom[second].append(int(first))
    return bonds_per_atom

253 

def find_safe_bonds (
    topology_file : Union['File', Exception],
    structure_file : 'File',
    trajectory_file : 'File',
    must_check_stable_bonds : bool,
    snapshots : int,
    structure : 'Structure',
    # Optional file with bonds sorted according a new atom order
    resorted_bonds_file : Optional['File'] = None
) -> List[List[int]]:
    """Find reference safe bonds in the system.

    Sources are tried in this order:
    1. The resorted bonds file, if it was passed and it exists.
    2. Bonds mined from the topology file.
    3. If every atom is coarse grain, all bonds are flagged as missing.
    4. If stable bonds must be checked, the most stable bonds along the trajectory.
    5. Otherwise the structure bonds, which are trusted.
    In the last two cases, bonds of coarse grain atoms are replaced by the missing bonds flag.
    """
    # If we have a resorted file then use it
    # Note that this is very exceptional
    if resorted_bonds_file is not None and resorted_bonds_file.exists:
        warn('Using resorted safe bonds')
        return load_json(resorted_bonds_file.path)
    # Try to get bonds from the topology before guessing
    safe_bonds = mine_topology_bonds(topology_file)
    if safe_bonds:
        return safe_bonds
    # Get a selection including coarse grain atoms in the structure
    cg_selection = structure.select_cg()
    # If all atoms are coarse grain then set all bonds "wrong" already
    if len(cg_selection) == structure.atom_count:
        return [ MISSING_BONDS for _ in range(structure.atom_count) ]
    # If failed to mine topology bonds then guess stable bonds
    print('Bonds will be guessed by atom distances and radius')
    # Find stable bonds if necessary
    if must_check_stable_bonds:
        # Using the trajectory, find the most stable bonds
        print('Checking bonds along trajectory to determine which are stable')
        safe_bonds = get_most_stable_bonds(structure_file.path, trajectory_file.path, snapshots)
        discard_coarse_grain_bonds(safe_bonds, cg_selection)
        return safe_bonds
    # If we trust stable bonds then simply use structure bonds
    print('Default structure bonds will be used since they have been marked as trusted')
    # NOTE(review): the structure bonds are mutated in place below - confirm this is intended
    safe_bonds = structure.bonds
    discard_coarse_grain_bonds(safe_bonds, cg_selection)
    return safe_bonds

298 

def discard_coarse_grain_bonds (bonds : list, cg_selection : 'Selection'):
    """Given a list of bonds, discard the ones in the coarse grain selection.

    Note that the input list is mutated and nothing is returned.
    The bonds of every coarse grain atom are replaced by the MISSING_BONDS flag,
    which is meant to raise an error when read.
    Thus we make sure using these wrong bonds anywhere further will result in failure.
    """
    for cg_atom_index in cg_selection.atom_indices:
        bonds[cg_atom_index] = MISSING_BONDS