Coverage for mddb_workflow/tools/get_bonds.py: 75%
175 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-29 15:48 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-29 15:48 +0000
1from mddb_workflow.tools.get_pdb_frames import get_pdb_frames
2from mddb_workflow.utils.auxiliar import load_json, warn, MISSING_TOPOLOGY
3from mddb_workflow.utils.auxiliar import MISSING_BONDS, JSON_SERIALIZABLE_MISSING_BONDS
4from mddb_workflow.utils.constants import STANDARD_TOPOLOGY_FILENAME
5from mddb_workflow.utils.vmd_spells import get_covalent_bonds
6from mddb_workflow.utils.gmx_spells import get_tpr_bonds as get_tpr_bonds_gromacs
7from mddb_workflow.utils.gmx_spells import get_tpr_atom_count
8from mddb_workflow.utils.mda_spells import get_tpr_bonds_mdanalysis
9from mddb_workflow.utils.type_hints import *
10import pytraj as pt
11from collections import Counter
# Sentinel object returned by get_tpr_bonds when every bond mining strategy has failed
# Callers detect the failure by comparing the returned value against this very object
FAILED_BOND_MINING_EXCEPTION = Exception("Failed to mine bonds")
def get_excluded_atoms_selection (structure : 'Structure', pbc_selection : 'Selection') -> 'Selection':
    """ Build the selection of atoms to be skipped from bonding tests given their "fake" nature.

    Two groups of atoms are excluded:
    - Ions which are not in the PBC selection. These ions are usually "tweaked" to be bonded to
      another atom although there is no real covalent bond, so they are not taken into account
      when testing coherent bonds or looking for the reference frame.
    - Coarse grain atoms, since their bonds will never be found by a distance/radius guess.

    Returns the union of both groups as a selection.
    """
    ions_outside_pbc = structure.select_ions() - pbc_selection
    coarse_grain_atoms = structure.select_cg()
    return ions_outside_pbc + coarse_grain_atoms
def do_bonds_match (
    bonds_1 : list[ list[int] ],
    bonds_2 : list[ list[int] ],
    # A selection of atoms whose bonds are not evaluated
    excluded_atoms_selection : 'Selection',
    # Set verbose as true to show which are the atoms preventing the match
    verbose : bool = False,
    # The rest of inputs are just for logs and debug
    atoms : Optional[ list['Atom'] ] = None,
    counter_list : Optional[ list[ tuple[int, tuple[int, ...], tuple[int, ...]] ] ] = None
) -> bool:
    """ Check if two sets of bonds match perfectly.

    Bonds are given per atom: element i of each list holds the indices of the atoms bonded to atom i.
    The order of bonds inside each atom is not important.
    Atoms in the excluded selection are skipped, and bonds pointing to them are ignored.

    Args:
        bonds_1: Bonds to be checked.
        bonds_2: Bonds to compare against.
        excluded_atoms_selection: Selection of atoms whose bonds are not evaluated.
            Only its 'atom_indices' attribute is read.
        verbose: If True, print details about the first atom preventing the match.
        atoms: Optional atoms, used only to print atom labels instead of indices in verbose mode.
        counter_list: Optional list where the first mismatch is appended as a tuple:
            (atom index, sorted bonds in bonds_1, sorted bonds in bonds_2).

    Returns:
        True if every non-excluded atom has the same bonds in both lists, False otherwise.
        Evaluation stops at the first mismatch.

    Raises:
        ValueError: If both bond lists do not have the same number of atoms.
    """
    # If the number of atoms in both lists is not matching then there is something very wrong
    if len(bonds_1) != len(bonds_2):
        raise ValueError(f'The number of atoms is not matching in both bond lists ({len(bonds_1)} and {len(bonds_2)})')
    # Find excluded atom indices
    excluded_atom_indices = set(excluded_atoms_selection.atom_indices)
    # For each atom, check bonds to match perfectly
    # Order is not important
    for atom_index, (atom_bonds_1, atom_bonds_2) in enumerate(zip(bonds_1, bonds_2)):
        # Skip excluded atom bonds
        if atom_index in excluded_atom_indices:
            continue
        atom_bonds_set_1 = set(atom_bonds_1) - excluded_atom_indices
        atom_bonds_set_2 = set(atom_bonds_2) - excluded_atom_indices
        if atom_bonds_set_1 == atom_bonds_set_2:
            continue
        # Sort the bonded atom indices so logs are readable and equal mismatches always give equal tuples
        # Set iteration order may differ between equal sets, which would split identical clashes when counted
        sorted_bonds_1 = tuple(sorted(atom_bonds_set_1))
        sorted_bonds_2 = tuple(sorted(atom_bonds_set_2))
        if verbose:
            if atoms:
                mismatch_atom_label = atoms[atom_index].label
                print(f' Mismatch in atom {mismatch_atom_label}:')
                it_is_atom_labels = ', '.join([ atoms[index].label for index in sorted_bonds_1 ])
                print(f' It is bonded to atoms {it_is_atom_labels}')
                it_should_be_atom_labels = ', '.join([ atoms[index].label for index in sorted_bonds_2 ])
                print(f' It should be bonded to atoms {it_should_be_atom_labels}')
            else:
                print(f' Mismatch in atom with index {atom_index}:')
                it_is_atom_indices = ','.join([ str(index) for index in sorted_bonds_1 ])
                print(f' It is bonded to atoms with indices {it_is_atom_indices}')
                it_should_be_atom_indices = ','.join([ str(index) for index in sorted_bonds_2 ])
                print(f' It should be bonded to atoms with indices {it_should_be_atom_indices}')
        # Save for failure analysis
        if counter_list is not None:
            counter_list.append((atom_index, sorted_bonds_1, sorted_bonds_2))
        return False
    return True
def get_most_stable_bonds (
    structure_filepath : str,
    trajectory_filepath : str,
    snapshots : int,
    frames_limit : int = 10
) -> list[ list[int] ]:
    """ Get covalent bonds using VMD along different frames and keep the most stable ones.
    This way we avoid having false positives because 2 atoms are very close in one frame by accident.
    This way we avoid having false negatives because 2 atoms are very far in one frame by accident.

    Args:
        structure_filepath: Path to the structure file.
        trajectory_filepath: Path to the trajectory file.
        snapshots: Number of frames in the trajectory, forwarded to get_pdb_frames.
        frames_limit: Frames limit forwarded to get_pdb_frames.

    Returns:
        For every atom, a sorted list with the indices of the atoms it is bonded to
        in more than half of the analysed frames.

    Raises:
        ValueError: If no frames were obtained from the trajectory.
    """
    # Get each frame in pdb format to run VMD
    print('Finding most stable bonds')
    frames, _, _ = get_pdb_frames(structure_filepath, trajectory_filepath,
        snapshots, frames_limit, pbar_bool=True)

    # Find the covalent bonds for every frame
    frame_bonds = [ get_covalent_bonds(current_frame_pdb) for current_frame_pdb in frames ]
    if not frame_bonds:
        raise ValueError('No frames were found to guess the most stable bonds')

    # Then keep those bonds which are respected in the majority of frames
    # Usually wrong bonds (both false positives and negatives) are formed only in one frame
    # It should not happen that a bond is formed around half of the times
    # Use the number of frames actually analysed so the cut is consistent with the counts below
    majority_cut = len(frame_bonds) / 2
    atom_count = len(frame_bonds[0])
    most_stable_bonds = []
    for atom in range(atom_count):
        # Count how many times every bond appears along frames
        occurrences = Counter()
        for frame in frame_bonds:
            occurrences.update(frame[atom])
        # Keep only those bonds with more occurrences than half the number of frames
        # Sort them so the output order is deterministic
        atom_most_stable_bonds = sorted(bond for bond, n in occurrences.items() if n > majority_cut)
        # Add current atom safe bonds to the total
        most_stable_bonds.append(atom_most_stable_bonds)

    return most_stable_bonds
def get_bonds_reference_frame (
    structure_file : 'File',
    trajectory_file : 'File',
    snapshots : int,
    reference_bonds : list[ list[int] ],
    structure : 'Structure',
    pbc_selection : 'Selection',
    patience : int = 100, # Limit of frames to check before we surrender
    verbose : bool = False,
) -> Optional[int]:
    """ Return a reference frame number where all bonds are exactly as they should (by VMD standards).
    This is the frame used when representing the MD.

    Args:
        structure_file: Structure file.
        trajectory_file: Trajectory file.
        snapshots: Number of frames in the trajectory, forwarded to get_pdb_frames.
        reference_bonds: Per atom bonds which the reference frame must match.
        structure: Structure used to find atoms excluded from the bonding test.
        pbc_selection: Selection of atoms in PBC, used to find excluded ions.
        patience: Limit of frames to check before we surrender.
        verbose: If True, print the mismatches found in every checked frame.

    Returns:
        The 0-based number of the first frame whose bonds match the reference bonds,
        0 if all atoms are excluded from the test, or None if no checked frame matches.
        When no frame matches, a table with the most common first clashes is printed.

    Raises:
        ValueError: If get_pdb_frames skips frames (step other than 1).
    """
    # Set some atoms which are to be skipped from these tests given their "fake" nature
    excluded_atoms_selection = get_excluded_atoms_selection(structure, pbc_selection)

    # If all atoms are to be excluded then set the first frame as the reference frame and stop here
    if len(excluded_atoms_selection) == len(structure.atoms): return 0

    # Now that we have the reference bonds, we must find a frame where bonds are exactly the reference ones
    # IMPORTANT: Note that we do not set a frames limit here, so all frames will be read and the step will be 1
    frames, step, count = get_pdb_frames(structure_file.path, trajectory_file.path, snapshots, patience=patience)
    if step != 1: raise ValueError('If we are skipping frames then the code below will silently return a wrong reference frame')
    print(f'Searching the reference frame for the bonds. Only first {min(patience, count)} frames will be checked.')
    # We check all frames but we stop as soon as we find a match
    bonds_reference_frame = None
    counter_list = []
    # Make sure frames are closed even if bond calculation fails in the middle
    try:
        for frame_number, frame_pdb in enumerate(frames):
            bonds = get_covalent_bonds(frame_pdb)
            if do_bonds_match(bonds, reference_bonds, excluded_atoms_selection, counter_list=counter_list, verbose=verbose):
                bonds_reference_frame = frame_number
                break
    finally:
        frames.close()
    # If no frame has the reference bonds then we return None
    if bonds_reference_frame is None:
        _print_first_clash_stats(counter_list)
        return None
    print(f' Got it -> Frame {bonds_reference_frame + 1}')

    return bonds_reference_frame

def _print_first_clash_stats (counter_list : list, top : int = 10):
    """ Print a table with the most common first clashes found while searching the reference frame.
    Every element in the counter list is a (atom, is bonding with, should bond with) tuple. """
    print(' First clash stats:')
    headers = ['Count', 'Atom', 'Is bonding with', 'Should bond with']
    most_common_clashes = Counter(counter_list).most_common(top)
    table_data = [ [n, atom, bond, should] for (atom, bond, should), n in most_common_clashes ]
    # Calculate column widths, headers included
    col_widths = [ max(len(str(item)) for item in column) for column in zip(*table_data, headers) ]
    # Format rows
    def format_row (row):
        return " | ".join(f"{str(item):>{col_widths[i]}}" for i, item in enumerate(row))
    # Print table
    print(format_row(headers))
    print("-+-".join('-' * width for width in col_widths))
    for row in table_data:
        print(format_row(row))
def mine_topology_bonds (bonds_source_file : Union['File', Exception]) -> Optional[ list[ list[int] ] ]:
    """ Extract bonds from a source file and format them per atom.

    Supported sources, checked in this order:
    - The standard topology file, whose 'atom_bonds' field is used directly.
      Missing bonds flags (from coarse grain simulations) are converted back to MISSING_BONDS.
    - Any topology supported by pytraj.
    - A TPR file, mined through get_tpr_bonds.

    Returns:
        A list with the bonded atom indices for every atom, or None if there is no topology
        or bonds could not be mined. In this last case bonds are to be guessed further.
    """
    # If there is no topology then return no bonds at all
    if bonds_source_file == MISSING_TOPOLOGY or not bonds_source_file.exists:
        return None
    print('Mining atom bonds from topology file')
    # If we have the standard topology then get bonds from it
    if bonds_source_file.filename == STANDARD_TOPOLOGY_FILENAME:
        print(f' Bonds in the "{bonds_source_file.filename}" file will be used')
        standard_topology = load_json(bonds_source_file.path)
        # Old topology files may not have the bonds field at all (or have it as null)
        standard_atom_bonds = standard_topology.get('atom_bonds', None) or []
        # Convert missing bonds flags
        # These come from coarse grain (CG) simulations with no topology
        atom_bonds = [
            MISSING_BONDS if bonds == JSON_SERIALIZABLE_MISSING_BONDS else bonds
            for bonds in standard_atom_bonds
        ]
        if atom_bonds: return atom_bonds
        print(' There were no bonds in the topology file. Is this an old file?')
    # In some ocasions, bonds may come inside a topology which can be parsed through pytraj
    elif bonds_source_file.is_pytraj_supported():
        print(f' Bonds will be mined from "{bonds_source_file.path}"')
        pt_topology = pt.load_topology(filename=bonds_source_file.path)
        raw_bonds = [ bond.indices for bond in pt_topology.bonds ]
        # Sort bonds
        atom_count = pt_topology.n_atoms
        atom_bonds = sort_bonds(raw_bonds, atom_count)
        # If there is any bonding data then return atom bonds
        if any(len(bonds) > 0 for bonds in atom_bonds): return atom_bonds
    # If we have a TPR then use our own tool
    elif bonds_source_file.format == 'tpr':
        print(f' Bonds will be mined from TPR file "{bonds_source_file.path}"')
        raw_bonds = get_tpr_bonds(bonds_source_file.path)
        # Compare by identity: the failure flag is a sentinel object
        # An equality check could be evaluated elementwise if bonds were array-like
        if raw_bonds is not FAILED_BOND_MINING_EXCEPTION:
            # Sort bonds
            atom_count = get_tpr_atom_count(bonds_source_file.path)
            return sort_bonds(raw_bonds, atom_count)
    # If we failed to mine bonds then return None and they will be guessed further
    print (' Failed to mine bonds -> They will be guessed from atom distances and radius')
    return None
def get_tpr_bonds (tpr_filepath : str) -> Union[ list[ tuple[int, int] ], Exception ]:
    """ Get TPR bonds.
    Try 2 different methods and hope 1 of them works.

    First our own GROMACS based tool is used. If it raises, MDAnalysis is used instead.

    Returns:
        The bonds as returned by the successful method, or FAILED_BOND_MINING_EXCEPTION
        if both methods failed.
    """
    try:
        bonds = get_tpr_bonds_gromacs(tpr_filepath)
    # Catch regular errors only, so KeyboardInterrupt and SystemExit still propagate
    except Exception:
        print(' Our tool failed to extract bonds. Using MDAnalysis extraction...')
        try:
            bonds = get_tpr_bonds_mdanalysis(tpr_filepath)
            # This happens when the filter selection is what we want to exclude, and not to keep
            if len(bonds) == 0:
                warn('Bonds were mined successfully but it looks like there are no bonds. Is it all ions or CG solvent?')
        except Exception as err:
            print(f' MDAnalysis failed to extract bonds: {err}. Relying on guess...')
            bonds = FAILED_BOND_MINING_EXCEPTION
    return bonds
def sort_bonds (source_bonds : list[ tuple[int, int] ], atom_count : int) -> list[ list[int] ]:
    """ Sort bonds according to our format: a list with the bonded atom indices for each atom.

    Source data is the usual format to store bonds: a list of tuples with every pair of bonded atoms.
    Every pair is registered in both directions. Bonded indices are cast to regular integers
    so the result is JSON serializable.
    """
    per_atom_bonds = [ [] for _ in range(atom_count) ]
    for first_atom, second_atom in source_bonds:
        per_atom_bonds[first_atom].append(int(second_atom))
        per_atom_bonds[second_atom].append(int(first_atom))
    return per_atom_bonds
def find_safe_bonds (
    topology_file : Union['File', Exception],
    structure_file : 'File',
    trajectory_file : 'File',
    must_check_stable_bonds : bool,
    snapshots : int,
    structure : 'Structure',
    guess_bonds : bool = False,
    # Optional file with bonds sorted according a new atom order
    resorted_bonds_file : Optional['File'] = None
) -> list[list[int]]:
    """ Find reference safe bonds in the system.

    Sources are tried in this order:
    1. If a resorted bonds file exists, its bonds are loaded and returned as they are.
    2. Unless guessing is forced, try to mine bonds from the topology file.
    3. If all atoms are coarse grain, every atom gets the MISSING_BONDS flag.
    4. If stable bonds must be checked, search for the most stable bonds along the trajectory.
    5. Otherwise, stable bonds are trusted and the structure bonds are used.
    In cases 4 and 5 the bonds of coarse grain atoms are replaced by the MISSING_BONDS flag.
    The input structure is never modified.
    """
    # If we have a resorted file then use it
    # Note that this is very exceptional
    if resorted_bonds_file is not None and resorted_bonds_file.exists:
        warn('Using resorted safe bonds')
        return load_json(resorted_bonds_file.path)
    # Get a selection including coarse grain atoms in the structure
    cg_selection = structure.select_cg()
    # Try to get bonds from the topology before guessing
    if guess_bonds:
        print('Bonds were forced to be guessed, so bonds in the topology file will be ignored')
    else:
        safe_bonds = mine_topology_bonds(topology_file)
        if safe_bonds:
            return safe_bonds
    # If all atoms are coarse grain then set all bonds "wrong" already
    if len(cg_selection) == structure.atom_count:
        return [ MISSING_BONDS for _ in range(structure.atom_count) ]
    # If failed to mine topology bonds then guess stable bonds
    print('Bonds will be guessed by atom distances and radii along different frames in the trajectory')
    if must_check_stable_bonds:
        # Using the trajectory, find the most stable bonds
        print('Checking bonds along trajectory to determine which are stable')
        safe_bonds = get_most_stable_bonds(structure_file.path, trajectory_file.path, snapshots)
    else:
        # If we trust stable bonds then simply use structure bonds
        print('Default structure bonds will be used since they have been marked as trusted')
        # Work on a shallow copy since discarding CG bonds mutates the list in place
        # Otherwise the structure's own bonds would be overwritten with missing bonds flags
        safe_bonds = list(structure.bonds)
    discard_coarse_grain_bonds(safe_bonds, cg_selection)
    return safe_bonds
def discard_coarse_grain_bonds (bonds : list, cg_selection : 'Selection'):
    """ Replace the bonds of every atom in the coarse grain selection with the MISSING_BONDS flag.

    The input list is mutated in place and nothing is returned.
    The flag is meant to make any further use of these wrong bonds fail.
    """
    missing_bonds_flag = MISSING_BONDS
    for cg_atom_index in cg_selection.atom_indices:
        bonds[cg_atom_index] = missing_bonds_flag