Coverage for model_workflow/tools/get_bonds.py: 71%
170 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-23 10:54 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-23 10:54 +0000
1from model_workflow.tools.get_pdb_frames import get_pdb_frames
2from model_workflow.utils.auxiliar import load_json, warn, MISSING_TOPOLOGY
3from model_workflow.utils.auxiliar import MISSING_BONDS, JSON_SERIALIZABLE_MISSING_BONDS
4from model_workflow.utils.constants import STANDARD_TOPOLOGY_FILENAME
5from model_workflow.utils.vmd_spells import get_covalent_bonds
6from model_workflow.utils.gmx_spells import get_tpr_bonds as get_tpr_bonds_gromacs
7from model_workflow.utils.gmx_spells import get_tpr_atom_count
8from model_workflow.utils.type_hints import *
9import pytraj as pt
10from MDAnalysis.topology.TPRParser import TPRParser
11from collections import Counter
def get_excluded_atoms_selection (structure : 'Structure', pbc_selection : 'Selection') -> 'Selection':
    """Build the selection of atoms to skip in bonding tests given their "fake" nature.

    Two groups of atoms are excluded:
    - Ions which are not in the PBC selection. These ions are usually "tweaked" to be
      bonded to another atom although there is no real covalent bond, so they are not
      taken into account when testing coherent bonds or looking for the reference frame.
    - Coarse grain atoms, since their bonds will never be found by a distance/radius guess.

    Args:
        structure: The structure providing the ion and coarse grain selections.
        pbc_selection: Atoms under periodic boundary conditions; ions in it are NOT excluded.

    Returns:
        The union of the non-PBC ions and the coarse grain atoms.
    """
    ions_outside_pbc = structure.select_ions() - pbc_selection
    coarse_grain_atoms = structure.select_cg()
    return ions_outside_pbc + coarse_grain_atoms
def do_bonds_match (
    bonds_1 : List[ List[int] ],
    bonds_2 : List[ List[int] ],
    # A selection of atoms whose bonds are not evaluated
    excluded_atoms_selection : 'Selection',
    # Set verbose as true to show which are the atoms preventing the match
    verbose : bool = False,
    # The rest of inputs are just for logs and debug
    atoms : Optional[ List['Atom'] ] = None,
    counter_list : Optional[ List[ Tuple[int, Tuple[int, ...], Tuple[int, ...]] ] ] = None
) -> bool:
    """Check if two sets of per-atom bonds match perfectly. Bond order is not important.

    Args:
        bonds_1: For every atom, the indices of the atoms it is bonded to (bonds under test).
        bonds_2: Same format as bonds_1 (the bonds it should have).
        excluded_atoms_selection: Atoms whose bonds are not evaluated. They are skipped
            and also removed from the bond partners of every other atom.
        verbose: Print the first mismatching atom together with its actual and expected bonds.
        atoms: Optional atoms, used only to print labels instead of indices when verbose.
        counter_list: Optional list where the first mismatch is appended as
            (atom_index, sorted actual partners, sorted expected partners) for failure analysis.

    Returns:
        True if every non-excluded atom has the same non-excluded bond partners in both
        lists, False as soon as one atom does not.

    Raises:
        ValueError: If the number of atoms is not matching in both bond lists.
    """
    # If the number of atoms in both lists is not matching then there is something very wrong
    if len(bonds_1) != len(bonds_2):
        raise ValueError(f'The number of atoms is not matching in both bond lists ({len(bonds_1)} and {len(bonds_2)})')
    # Indices of atoms whose bonds are not evaluated (e.g. non-PBC ions or coarse grain atoms)
    excluded_atom_indices = set(excluded_atoms_selection.atom_indices)
    for atom_index, (atom_bonds_1, atom_bonds_2) in enumerate(zip(bonds_1, bonds_2)):
        # Skip excluded atoms
        if atom_index in excluded_atom_indices:
            continue
        atom_bonds_set_1 = set(atom_bonds_1) - excluded_atom_indices
        atom_bonds_set_2 = set(atom_bonds_2) - excluded_atom_indices
        # Order is not important, so compare as sets
        if atom_bonds_set_1 == atom_bonds_set_2:
            continue
        # Sort partners so logs and counter entries are deterministic
        # Set iteration order is not (e.g. tuple({1, 8}) == (8, 1)), so identical mismatches
        # could otherwise produce different tuples and never be counted together
        is_bonded_to = sorted(atom_bonds_set_1)
        should_be_bonded_to = sorted(atom_bonds_set_2)
        if verbose:
            if atoms:
                mismatch_atom_label = atoms[atom_index].label
                print(f' Mismatch in atom {mismatch_atom_label}:')
                it_is_atom_labels = ', '.join([ atoms[index].label for index in is_bonded_to ])
                print(f' It is bonded to atoms {it_is_atom_labels}')
                it_should_be_atom_labels = ', '.join([ atoms[index].label for index in should_be_bonded_to ])
                print(f' It should be bonded to atoms {it_should_be_atom_labels}')
            else:
                print(f' Mismatch in atom with index {atom_index}:')
                it_is_atom_indices = ','.join([ str(index) for index in is_bonded_to ])
                print(f' It is bonded to atoms with indices {it_is_atom_indices}')
                it_should_be_atom_indices = ','.join([ str(index) for index in should_be_bonded_to ])
                print(f' It should be bonded to atoms with indices {it_should_be_atom_indices}')
        # Save for failure analysis
        if counter_list is not None:
            counter_list.append((atom_index, tuple(is_bonded_to), tuple(should_be_bonded_to)))
        return False
    return True
def get_most_stable_bonds (
    structure_filepath : str,
    trajectory_filepath : str,
    snapshots : int,
    frames_limit : int = 10
) -> List[ List[int] ]:
    """Get covalent bonds using VMD along different frames and keep the most stable ones.

    Checking several frames avoids false positives (2 atoms very close in one frame by
    accident) and false negatives (2 atoms very far in one frame by accident).
    Wrong bonds usually appear in only one frame, so a bond is kept only when it is
    found in more than half of the analyzed frames.

    Args:
        structure_filepath: Path to the structure file.
        trajectory_filepath: Path to the trajectory file.
        snapshots: Number of frames in the trajectory, forwarded to get_pdb_frames.
        frames_limit: Maximum number of frames to analyze, forwarded to get_pdb_frames.

    Returns:
        For every atom, the sorted indices of its stable bonded atoms.

    Raises:
        ValueError: If no frame could be read.
    """
    # Get each frame in pdb format to run VMD
    print('Finding most stable bonds')
    frames, _, _ = get_pdb_frames(structure_filepath, trajectory_filepath,
        snapshots, frames_limit, pbar_bool=True)
    # Find the covalent bonds for every frame
    frame_bonds = [ get_covalent_bonds(frame_pdb) for frame_pdb in frames ]
    if not frame_bonds:
        raise ValueError('No frames were read so bond stability cannot be evaluated')
    # The majority is relative to the frames actually analyzed
    majority_cut = len(frame_bonds) / 2
    atom_count = len(frame_bonds[0])
    most_stable_bonds = []
    for atom_index in range(atom_count):
        # Count in how many frames each bond of this atom appears
        occurrences = Counter(bond for bonds in frame_bonds for bond in bonds[atom_index])
        # Keep only those bonds with more occurrences than half the number of frames
        # Sorted so the output does not depend on hash iteration order
        atom_most_stable_bonds = sorted(bond for bond, hits in occurrences.items() if hits > majority_cut)
        most_stable_bonds.append(atom_most_stable_bonds)
    return most_stable_bonds
def _print_first_clashes (counter_list : list, max_rows : int = 10):
    """Print a table with the most common bond mismatches found while searching the canonical frame."""
    print(' First clash stats:')
    headers = ['Count', 'Atom', 'Is bonding with', 'Should bond with']
    most_common = Counter(counter_list).most_common(max_rows)
    table_data = [ [n, atom, bonding, should] for (atom, bonding, should), n in most_common ]
    # Each column is as wide as its widest cell, header included
    col_widths = [ max(len(str(item)) for item in column) for column in zip(*table_data, headers) ]
    def format_row (row):
        return " | ".join(f"{str(item):>{col_widths[i]}}" for i, item in enumerate(row))
    print(format_row(headers))
    print("-+-".join('-' * width for width in col_widths))
    for row in table_data:
        print(format_row(row))

def get_bonds_canonical_frame (
    structure_file : 'File',
    trajectory_file : 'File',
    snapshots : int,
    reference_bonds : List[ List[int] ],
    structure : 'Structure',
    pbc_selection : 'Selection',
    patience : int = 100, # Limit of frames to check before we surrender
    verbose : bool = False,
) -> Optional[int]:
    """Return a canonical frame number where all bonds are exactly as they should.
    This is the frame used when representing the MD.

    Args:
        structure_file: Structure file used to read the frames.
        trajectory_file: Trajectory file used to read the frames.
        snapshots: Number of frames in the trajectory, forwarded to get_pdb_frames.
        reference_bonds: The bonds every atom should have.
        structure: Structure used to find atoms excluded from the bond tests.
        pbc_selection: Atoms under periodic boundary conditions.
        patience: Limit of frames to check before we surrender.
        verbose: Print details of every mismatching frame.

    Returns:
        The 0-based number of the first frame whose bonds match the reference bonds,
        0 when every atom is excluded from the tests, or None if no checked frame matches
        (a table with the most common clashes is printed in that case).

    Raises:
        ValueError: If frames are being skipped, since frame numbers would then be wrong.
    """
    # Set some atoms which are to be skipped from these test given their "fake" nature
    excluded_atoms_selection = get_excluded_atoms_selection(structure, pbc_selection)
    # If all atoms are to be excluded then set the first frame as the reference frame and stop here
    if len(excluded_atoms_selection) == len(structure.atoms): return 0
    # Now that we have the reference bonds, we must find a frame where bonds are exactly the canonical ones
    # No frames limit is passed so frames are not skipped: the step must be 1 for the enumerated
    # frame number to be the actual frame number
    frames, step, count = get_pdb_frames(structure_file.path, trajectory_file.path, snapshots,patience=patience)
    reference_bonds_frame = None
    counter_list = []
    # Always close the frames generator, even if a check below raises
    try:
        if step != 1: raise ValueError('If we are skipping frames then the code below will silently return a wrong reference frame')
        print(f'Searching reference bonds canonical frame. Only first {min(patience, count)} frames will be checked.')
        # We stop as soon as we find a match
        for frame_number, frame_pdb in enumerate(frames):
            bonds = get_covalent_bonds(frame_pdb)
            if do_bonds_match(bonds, reference_bonds, excluded_atoms_selection, counter_list=counter_list, verbose=verbose):
                reference_bonds_frame = frame_number
                break
    finally:
        frames.close()
    # If no frame has the canonical bonds then we return None
    if reference_bonds_frame is None:
        _print_first_clashes(counter_list)
        return None
    print(f' Got it -> Frame {reference_bonds_frame + 1}')
    return reference_bonds_frame
def mine_topology_bonds (bonds_source_file : Union['File', Exception]) -> Optional[ List[ List[int] ] ]:
    """Extract bonds from a topology file and format them per atom.

    Sources are handled in this order:
    - Our standard topology file (STANDARD_TOPOLOGY_FILENAME): its 'atom_bonds' field is used.
      Entries flagged as JSON_SERIALIZABLE_MISSING_BONDS are converted back to MISSING_BONDS.
      These come from coarse grain (CG) simulations with no topology.
    - Any topology which can be parsed through pytraj.
    - TPR files, using get_tpr_bonds.

    Args:
        bonds_source_file: The topology file, or MISSING_TOPOLOGY when there is none.

    Returns:
        For every atom, the indices of its bonded atoms. None if there is no topology or
        no bonds could be mined, in which case they must be guessed further.
    """
    # If there is no topology then return no bonds at all
    if bonds_source_file == MISSING_TOPOLOGY or not bonds_source_file.exists:
        return None
    print('Mining atom bonds from topology file')
    # If we have the standard topology then get bonds from it
    if bonds_source_file.filename == STANDARD_TOPOLOGY_FILENAME:
        print(f' Bonds in the "{bonds_source_file.filename}" file will be used')
        standard_topology = load_json(bonds_source_file.path)
        # The 'atom_bonds' field may be missing or null (presumably in old files)
        # Fall back to no bonds instead of failing when iterating None
        standard_atom_bonds = standard_topology.get('atom_bonds', None) or []
        # Convert missing bonds flags
        atom_bonds = [ MISSING_BONDS if bonds == JSON_SERIALIZABLE_MISSING_BONDS else bonds
            for bonds in standard_atom_bonds ]
        if atom_bonds: return atom_bonds
        print(' There were no bonds in the topology file. Is this an old file?')
    # In some ocasions, bonds may come inside a topology which can be parsed through pytraj
    elif bonds_source_file.is_pytraj_supported():
        print(f' Bonds will be mined from "{bonds_source_file.path}"')
        pt_topology = pt.load_topology(filename=bonds_source_file.path)
        raw_bonds = [ bond.indices for bond in pt_topology.bonds ]
        atom_bonds = sort_bonds(raw_bonds, pt_topology.n_atoms)
        # If there is any bonding data then return atom bonds
        if any(len(bonds) > 0 for bonds in atom_bonds): return atom_bonds
    # If we have a TPR then use our own tool
    elif bonds_source_file.format == 'tpr':
        print(f' Bonds will be mined from TPR file "{bonds_source_file.path}"')
        raw_bonds = get_tpr_bonds(bonds_source_file.path)
        atom_count = get_tpr_atom_count(bonds_source_file.path)
        atom_bonds = sort_bonds(raw_bonds, atom_count)
        # If there is any bonding data then return atom bonds
        if any(len(bonds) > 0 for bonds in atom_bonds): return atom_bonds
    # If we failed to mine bonds then return None and they will be guessed further
    print (' Failed to mine bonds -> They will be guessed from atom distances and radius')
    return None
def get_tpr_bonds (tpr_filepath : str) -> List[ Tuple[int, int] ]:
    """Get the bonded atom pairs from a TPR file.

    Two different methods are tried: our GROMACS based tool first and, if it raises,
    MDAnalysis as a fallback.

    Args:
        tpr_filepath: Path to the TPR file.

    Returns:
        A list with one pair of atom indices per bond.
    """
    try:
        bonds = get_tpr_bonds_gromacs(tpr_filepath)
    # Catch regular errors only: a bare except would also swallow KeyboardInterrupt and SystemExit
    except Exception:
        print(' Our tool failed to extract bonds. Using MDAnalysis extraction...')
        bonds = get_tpr_bonds_mdanalysis(tpr_filepath)
    return bonds
# WARNING: Sometimes this function takes additional constraints as actual bonds
# DANI: if you look at topology.bonds.values these fake bonds come at the end
# DANI: I can tell because the indices are in ascending order and then start over
# DANI: I have asked for help here https://github.com/MDAnalysis/mdanalysis/pull/463
def get_tpr_bonds_mdanalysis (tpr_filepath : str) -> List[ Tuple[int, int] ]:
    """Get the bonded atom pairs from a TPR file using the MDAnalysis TPR parser.

    Args:
        tpr_filepath: Path to the TPR file.

    Returns:
        A list with the entries of the parsed topology bonds values, one per bond.
    """
    parsed_topology = TPRParser(tpr_filepath).parse()
    return list(parsed_topology.bonds.values)
def sort_bonds (source_bonds : List[ Tuple[int, int] ], atom_count : int) -> List[ List[int] ]:
    """Sort bonds according to our format: a list with the bonded atom indices for each atom.

    Args:
        source_bonds: The usual format to store bonds: one pair of atom indices per bond.
        atom_count: Number of atoms; one (possibly empty) entry is produced per atom.

    Returns:
        For every atom, the indices of the atoms bonded to it, in the order bonds appear.
        Appended indices are converted to regular integers so they are JSON serializable.
    """
    per_atom_bonds = [ [] for _ in range(atom_count) ]
    for first_atom, second_atom in source_bonds:
        # Every bond is registered in both directions
        per_atom_bonds[first_atom].append(int(second_atom))
        per_atom_bonds[second_atom].append(int(first_atom))
    return per_atom_bonds
def find_safe_bonds (
    topology_file : Union['File', Exception],
    structure_file : 'File',
    trajectory_file : 'File',
    must_check_stable_bonds : bool,
    snapshots : int,
    structure : 'Structure',
    # Optional file with bonds sorted according a new atom order
    resorted_bonds_file : Optional['File'] = None
) -> List[List[int]]:
    """Find reference safe bonds in the system.

    Sources are tried in this order:
    1. The resorted bonds file, if given and existing (very exceptional).
    2. Bonds mined from the topology file.
    3. If every atom is coarse grain, all bonds are set as MISSING_BONDS.
    4. If stable bonds must be checked, the most stable bonds along the trajectory.
    5. Otherwise the structure bonds, which have been marked as trusted.
    In cases 4 and 5 the bonds of coarse grain atoms are replaced by MISSING_BONDS.

    Args:
        topology_file: The topology file, or MISSING_TOPOLOGY when there is none.
        structure_file: Structure file, used when checking bonds along the trajectory.
        trajectory_file: Trajectory file, used when checking bonds along the trajectory.
        must_check_stable_bonds: Whether guessed bonds must be checked along the trajectory.
        snapshots: Number of frames in the trajectory.
        structure: The structure providing coarse grain atoms and default bonds.
        resorted_bonds_file: Optional file with bonds sorted according a new atom order.

    Returns:
        For every atom, the indices of its bonded atoms (or MISSING_BONDS).
    """
    # If we have a resorted file then use it
    # Note that this is very exceptional
    if resorted_bonds_file is not None and resorted_bonds_file.exists:
        warn('Using resorted safe bonds')
        return load_json(resorted_bonds_file.path)
    # Try to get bonds from the topology before guessing
    safe_bonds = mine_topology_bonds(topology_file)
    if safe_bonds:
        return safe_bonds
    # Get a selection including coarse grain atoms in the structure
    cg_selection = structure.select_cg()
    # If all atoms are coarse grain then set all bonds as "wrong" already
    if len(cg_selection) == structure.atom_count:
        return [ MISSING_BONDS for _ in range(structure.atom_count) ]
    # If failed to mine topology bonds then guess stable bonds
    print('Bonds will be guessed by atom distances and radius')
    if must_check_stable_bonds:
        # Using the trajectory, find the most stable bonds
        print('Checking bonds along trajectory to determine which are stable')
        safe_bonds = get_most_stable_bonds(structure_file.path, trajectory_file.path, snapshots)
    else:
        # If we trust stable bonds then simply use structure bonds
        print('Default structure bonds will be used since they have been marked as trusted')
        # Work on a copy: discarding coarse grain bonds mutates the list in place and
        # must not alter the structure's own bonds
        safe_bonds = list(structure.bonds)
    discard_coarse_grain_bonds(safe_bonds, cg_selection)
    return safe_bonds
def discard_coarse_grain_bonds (bonds : list, cg_selection : 'Selection'):
    """Given a list of bonds, discard the ones in the coarse grain selection.

    Every coarse grain atom gets its bonds replaced with MISSING_BONDS, which is meant to
    raise an error when read. Thus using these wrong bonds anywhere further results in failure.

    Args:
        bonds: Per-atom bonds list. It is mutated in place and nothing is returned.
        cg_selection: Selection whose atom_indices are the coarse grain atoms.
    """
    cg_atom_indices = cg_selection.atom_indices
    for cg_atom_index in cg_atom_indices:
        bonds[cg_atom_index] = MISSING_BONDS