Coverage for mddb_workflow / analyses / rmsd_check.py: 27%
226 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 18:45 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 18:45 +0000
1import mdtraj as mdt
2import pytraj as pt
3import numpy as np
4import math
5import sys
7# Visual output tools
8from tqdm import tqdm
9import plotext as plt
11from mddb_workflow.utils.auxiliar import delete_previous_log, reprint, TestFailure, warn
12from mddb_workflow.utils.auxiliar import round_to_hundredths
13from mddb_workflow.utils.constants import TRAJECTORY_INTEGRITY_FLAG, RED_HEADER, COLOR_END
14from mddb_workflow.utils.pyt_spells import get_pytraj_trajectory
15from mddb_workflow.utils.type_hints import *
# Whether the standard output is attached to a terminal (TTY) or not
# This is evaluated once, when the module is imported
is_terminal: bool = sys.stdout.isatty()
# LORE
# This test was originally intended to use an RMSD jump cutoff based on the number of atoms and the timestep
# However, after a deep study, it was observed that simulations with similar features may show very different RMSD jumps
# For this reason we now compute RMSD jumps along the whole trajectory and check that the biggest jump is not an outlier
# A value is considered an outlier according to how many standard deviations away from the mean it is
# Look for sudden raises of RMSD values from one frame to another
# To do so, we check the RMSD of every frame using its previous frame as reference
# This was allowed thanks to people from MDtraj https://github.com/mdtraj/mdtraj/issues/1966
def check_trajectory_integrity (
    input_structure_filename : str,
    input_trajectory_filename : str,
    structure : 'Structure',
    pbc_selection : 'Selection',
    mercy : list[str],
    trust: list[str],
    register : 'Register',
    #time_length : float,
    check_selection : str,
    # DANI: I have seen 'correct' jumps go above 6
    # DANI: I have seen 'incorrect' jumps never go below 10
    standard_deviations_cutoff : float,
    snapshots : int) -> bool:
    """Check that no fragment shows sudden RMSD jumps between consecutive frames.

    The RMSD of every frame against its previous frame is computed for every fragment.
    Then, for every fragment, the z-score of every jump is calculated against the mean
    and standard deviation of all jumps of that fragment. A jump whose z-score is above
    the cutoff is counted as a 'bypassed' frame while no earlier jump has been below the
    cutoff, and as an outlier otherwise. Both outliers and bypassed frames fail the test.

    Args:
        input_structure_filename: Topology passed as 'top' to the MDtraj loader.
        input_trajectory_filename: Trajectory to be read frame by frame.
        structure: Used to parse the selection, to split it in fragments (or chains)
            and to name every fragment in the logs.
        pbc_selection: Atoms to be subtracted from the checked selection.
        mercy: If it contains TRAJECTORY_INTEGRITY_FLAG then a failure is registered
            as a warning and False is returned instead of raising.
        trust: If it contains TRAJECTORY_INTEGRITY_FLAG then the test is skipped.
        register: Where test results and warnings are stored.
        check_selection: Atoms to be checked, in VMD selection syntax.
        standard_deviations_cutoff: Maximum allowed z-score for an RMSD jump.
        snapshots: Number of frames in the trajectory. The test is skipped when it is
            below 3. It is also used as the total of the progress bar.

    Returns:
        True if the test is skipped, not applicable or passed.
        False if the test failed but we have mercy.

    Raises:
        Exception: If the check selection is empty.
        TestFailure: If the test failed and we have no mercy.
    """

    # Skip the test if we trust
    if TRAJECTORY_INTEGRITY_FLAG in trust:
        return True

    # Skip the test if it is already passed according to the register
    if register.tests.get(TRAJECTORY_INTEGRITY_FLAG, None):
        return True

    # If the trajectory has only 1 or 2 frames then there is no test to do
    if snapshots < 3:
        register.update_test(TRAJECTORY_INTEGRITY_FLAG, True)
        return True

    # Remove old warnings
    register.remove_warnings(TRAJECTORY_INTEGRITY_FLAG)

    # Parse the selection in VMD selection syntax
    parsed_selection = structure.select(check_selection, syntax='vmd')

    # If there is nothing to check then stop here
    if not parsed_selection:
        raise Exception('WARNING: There are not atoms to be analyzed for the RMSD analysis')

    # Discard PBC residues from the selection to be checked
    parsed_selection -= pbc_selection

    # If there is nothing to check then warn the user and stop here
    if not parsed_selection:
        warn('There are no atoms to be analyzed for the RMSD checking after PBC substraction')
        register.update_test(TRAJECTORY_INTEGRITY_FLAG, 'na')
        return True

    # Get fragments out of the parsed selection
    # Fragments will be analyzed independently
    # A sudden jump of a small fragment may cause a small RMSD perturbation when
    # it is part of a large fragment which is not jumping
    # Jumps or artifacts of partial fragments along boundaries are rare imaging problems, although possible
    # For this reason splitting the test in fragments is a good measure
    # Although big fragments could be split in even smaller parts in the future
    # On the other hand a very small fragment may overcome the cutoff with small RMSD jumps, so beware
    # If the structure is missing bonds then there are no fragments to select
    # In this case we just split the structure in chains
    # NOTE(review): the chain selections are used as they are, they are not intersected with the parsed selection
    if structure.is_missing_any_bonds():
        fragments = [ chain.get_selection() for chain in structure.chains ]
    else:
        fragments = list(structure.find_fragments(parsed_selection))

    print(f'Checking trajectory integrity ({len(fragments)} fragments)')

    # Load the trajectory frame by frame
    trajectory = mdt.iterload(input_trajectory_filename, top=input_structure_filename, chunk=1)

    # Save the previous frame any time
    previous_frame = next(trajectory)

    # Save all RMSD jumps
    fragment_rmsd_jumps = { fragment: [] for fragment in fragments }

    # Initialize progress bar first
    # The first frame was already consumed, thus the initial value is 1
    pbar = tqdm(trajectory, total=snapshots, desc=' Frame', unit='frame', initial=1)

    # Iterate trajectory frames
    for frame in pbar:

        # Iterate over the different fragments
        for fragment in fragment_rmsd_jumps:
            # Calculate RMSD value between previous and current frame
            # DANI: MDtraj centering removes the jump across the boundaries
            # DANI: precentered=True should prevent it, but it is ignored when atom_indices are passed
            rmsd_value = mdt.rmsd(frame, previous_frame, atom_indices=fragment.atom_indices, superpose=False)[0]
            fragment_rmsd_jumps[fragment].append(rmsd_value)

        # Update the previous frame as the current one
        previous_frame = frame

    # Check all fragments and store their test result here before issuing the final report
    fragment_reports = []

    # Iterate over the different fragments
    for fragment, rmsd_jumps in fragment_rmsd_jumps.items():

        # Log the fragment name
        fragment_name = structure.name_selection(fragment)

        # Capture outliers
        # Only the first outliers are logged, but all of them are counted
        outliers_count = 0
        max_z_score = 0
        max_z_score_frame = 0

        # Get the mean and standard deviation of RMSD jumps to later compute z-scores
        mean_rmsd_jump = np.mean(rmsd_jumps)
        stdv_rmsd_jump = np.std(rmsd_jumps)

        # If all jumps are identical then the standard deviation is 0 and no jump can be an outlier
        # Checking this beforehand avoids a division by zero
        has_deviation = stdv_rmsd_jump > 0

        # First frames may not be perfectly equilibrated and thus have stronger RMSD jumps
        # For this reason the first frames are counted apart from regular outliers
        # As soon as one frame is below the cutoff the bypass is finished for the following frames
        # Count the number of bypassed frames for this fragment
        bypassed_frames = 0

        for i, rmsd_jump in enumerate(rmsd_jumps, 1):
            if has_deviation:
                z_score = abs( (rmsd_jump - mean_rmsd_jump) / stdv_rmsd_jump )
            else:
                z_score = 0.0
            # Keep track of the maximum z score
            if z_score > max_z_score:
                max_z_score = z_score
                max_z_score_frame = i
            # If z score bypassed the limit then report it
            if z_score > standard_deviations_cutoff:
                # If there are as many bypassed frames as the index then it means no frame has passed the cutoff yet
                if i - 1 == bypassed_frames:
                    bypassed_frames += 1
                    continue
                # If we are no longer bypassing first frames then outliers mean the test has failed
                # But don't kill the check just yet
                if outliers_count < 4:
                    print(f' FAIL: Sudden RMSD jump in fragment "{fragment_name}" between frames {i} and {i+1} (z score = {round_to_hundredths(z_score)})')
                if outliers_count == 4:
                    print(' etc...')
                outliers_count += 1

        # Save the report for this fragment
        fragment_reports.append({
            'name': fragment_name,
            'score': max_z_score,
            'frame': max_z_score_frame,
            'outliers': outliers_count,
            'bypass': bypassed_frames,
            'jumps' : rmsd_jumps
        })

    # If there was any outliers or any bypassed frames then the test has failed
    # LORE: Back in the day having bypassed frames was not a test failure, but now it is
    any_outliers = any(report['outliers'] > 0 for report in fragment_reports)
    any_bypassed_frames = any(report['bypass'] > 0 for report in fragment_reports)

    # If the test has failed then display a full report
    if any_outliers or any_bypassed_frames:
        print('-- RMSD Check final report --')
        for report in fragment_reports:
            name = report['name']
            max_z_score = round_to_hundredths(report['score'])
            max_z_score_frame = report['frame']
            outliers = report['outliers']
            bypassed_frames = report['bypass']
            passed = outliers == 0 and bypassed_frames == 0
            # If the fragment has no problems
            if passed:
                print(f'Fragment "{name}" PASSED with a maximum z-score of {max_z_score}' + \
                    f' reported between frames {max_z_score_frame} and {max_z_score_frame + 1}')
                continue
            # If the fragment failed to pass the test
            print(f'{RED_HEADER}Fragment "{name}" FAILED with a maximum z-score of {max_z_score}' + \
                f' reported between frames {max_z_score_frame} and {max_z_score_frame + 1}{COLOR_END}')
            if outliers > 0: print(f' Outliers: {outliers}')
            if bypassed_frames > 0: print(f' Bypassed frames: {bypassed_frames}')
        print('*The z-score of a value means how many times the standard deviation away it is from the average')

    # If there were any outlier then the check has failed
    if any_outliers:
        message = 'RMSD check has failed: there may be sudden jumps along the trajectory'
        # In case we have mercy, add a warning, register the failure and return False
        if TRAJECTORY_INTEGRITY_FLAG in mercy:
            register.add_warning(TRAJECTORY_INTEGRITY_FLAG, message)
            register.update_test(TRAJECTORY_INTEGRITY_FLAG, False)
            return False
        # Otherwise kill the process right away, after displaying a graph
        # The graph shows the fragment with the highest z-score (the first one in case of a tie)
        worst_report = max(fragment_reports, key=lambda report: report['score'])
        title = f'RMSD jumps along the trajectory in fragment "{worst_report["name"]}"'
        report_data = worst_report['jumps']
        display_rmsd_jumps_graph(report_data, title)
        raise TestFailure(message)

    # If we had bypassed frames then the check has failed as well
    if any_bypassed_frames:
        # Set the error message
        max_bypassed_frames = max(report['bypass'] for report in fragment_reports)
        message = f'First {max_bypassed_frames} frames may be not equilibrated'
        # In case we have mercy, add a warning, register the failure and return False
        if TRAJECTORY_INTEGRITY_FLAG in mercy:
            register.add_warning(TRAJECTORY_INTEGRITY_FLAG, message)
            register.update_test(TRAJECTORY_INTEGRITY_FLAG, False)
            return False
        # Otherwise kill the process, after displaying a graph
        # The graph shows the fragment with the highest z-score (the first one in case of a tie)
        worst_report = max(fragment_reports, key=lambda report: report['score'])
        graph_frames_limit = max_bypassed_frames + 100
        report_data = worst_report['jumps'][0:graph_frames_limit]
        title = f'RMSD jumps along the trajectory in fragment "{worst_report["name"]}" (first {graph_frames_limit} frames)'
        display_rmsd_jumps_graph(report_data, title)
        raise TestFailure(message)

    print(' Test has passed successfully')
    register.update_test(TRAJECTORY_INTEGRITY_FLAG, True)
    return True
# Display a graph to show the distribution of sudden jumps along the trajectory in the terminal itself
def display_rmsd_jumps_graph (data : list, title : str) -> None:
    """Draw a scatter plot of RMSD jumps in the terminal.

    Nothing is drawn when the output is not a terminal or when there is no data.

    Args:
        data: RMSD jump values, one per pair of consecutive frames.
        title: Title to be shown on top of the graph.
    """
    # Graphs are displayed only when the output is going to a terminal
    if not is_terminal: return
    # If there is no data then there is nothing to display
    # This also prevents the tick step below from being 0, which makes range fail
    n_jumps = len(data)
    if n_jumps == 0: return
    # Display a graph to show the distribution of sudden jumps along the trajectory
    plt.scatter(data)
    plt.title(title)
    plt.xlabel('Frame')
    plt.ylabel('RMSD jump')
    # Set around 5 ticks in the x axis, labeled with 1-based frame numbers
    n_ticks = 5
    tickstep = math.ceil(n_jumps / n_ticks)
    xticks = [ t+1 for t in range(0, n_jumps, tickstep) ]
    xlabels = [ str(t) for t in xticks ]
    plt.xticks(xticks, xlabels)
    plt.show()
# Compute every residue RMSD to check if there are sudden jumps along the trajectory
# HARDCODE: This function is not fully implemented but enabled manually for specific cases
def check_trajectory_integrity_per_residue (
    input_structure_filename : str,
    input_trajectory_filename : str,
    structure : 'Structure',
    pbc_selection : 'Selection',
    mercy : list[str],
    trust: list[str],
    register : 'Register',
    #time_length : float,
    check_selection : str,
    # DANI: I have seen 'correct' jumps go above 11
    # DANI: I have seen 'incorrect' jumps never go below 14
    standard_deviations_cutoff : float) -> bool:
    """Check that no residue shows sudden RMSD jumps between consecutive frames.

    The per-residue RMSD of every frame against its previous frame is computed with pytraj.
    Then, for every residue, the z-score of every jump is calculated against the mean
    and standard deviation of all jumps of that residue. A jump whose z-score is above
    the cutoff is counted as a 'bypassed' frame while no earlier jump has been below the
    cutoff, and as an outlier otherwise. Outliers fail the test, while bypassed frames
    only add a warning.

    Args:
        input_structure_filename: Structure file passed to the pytraj trajectory loader.
        input_trajectory_filename: Trajectory file passed to the pytraj trajectory loader.
        structure: Used to parse the selection and to find residues.
        pbc_selection: Atoms to be subtracted from the checked selection.
        mercy: If it contains TRAJECTORY_INTEGRITY_FLAG then a failure is registered
            as a warning and False is returned instead of raising.
        trust: If it contains TRAJECTORY_INTEGRITY_FLAG then the test is skipped.
        register: Where test results and warnings are stored.
        check_selection: Atoms to be checked, in VMD selection syntax.
        standard_deviations_cutoff: Currently ignored, since it is overwritten by a
            hardcoded value of 12.

    Returns:
        True if the test is skipped, not applicable or passed.
        False if the test failed but we have mercy.

    Raises:
        Exception: If the check selection is empty.
        ValueError: If the number of pytraj results does not match the number of
            residues or if a NaN is found.
        TestFailure: If the test failed and we have no mercy.
    """

    # HARDCODE: The default value does not work for a single residue
    standard_deviations_cutoff = 12

    # Skip the test if we trust
    if TRAJECTORY_INTEGRITY_FLAG in trust:
        return True

    # Skip the test if it is already passed according to the register
    if register.tests.get(TRAJECTORY_INTEGRITY_FLAG, None):
        return True

    # Remove old warnings
    register.remove_warnings(TRAJECTORY_INTEGRITY_FLAG)

    # Parse the selection in VMD selection syntax
    parsed_selection = structure.select(check_selection, syntax='vmd')

    # If there is nothing to check then stop here
    if not parsed_selection:
        raise Exception('WARNING: There are not atoms to be analyzed for the RMSD analysis')

    # Discard PBC residues from the selection to be checked
    parsed_selection -= pbc_selection

    # If there is nothing to check then warn the user and stop here
    if not parsed_selection:
        warn('There are no atoms to be analyzed for the RMSD checking after PBC substraction')
        register.update_test(TRAJECTORY_INTEGRITY_FLAG, 'na')
        return True

    # We must filter out residues which only have 1 atom (e.g. ions)
    # This is because sometimes pytraj does not return results for them and then the number of results and residues does not match
    # More info: https://github.com/Amber-MD/pytraj/issues/1580
    ion_atom_indices = []
    for residue in structure.residues:
        if len(residue.atom_indices) == 1:
            ion_atom_indices += residue.atom_indices
    ions_selection = structure.select_atom_indices(ion_atom_indices)
    parsed_selection -= ions_selection

    # Parse the selection to a pytraj mask
    pytraj_selection = parsed_selection.to_pytraj()

    # Calculate the residue indices of the overall structure remaining in the filtered trajectory
    residue_indices = structure.get_selection_residue_indices(parsed_selection)
    n_residues = len(residue_indices)

    print('Checking trajectory integrity per residue')

    # Parse the trajectory into pytraj and apply the mask
    # NEVER FORGET: The pytraj iterload does not accept a mask, but we apply the mask later in the analysis
    pt_trajectory = get_pytraj_trajectory(input_structure_filename, input_trajectory_filename, atom_selection = parsed_selection)

    # Make sure the expected output number of residues to match with the pytraj results
    # These numbers may not match when ions are included so we better check
    # NEVER FORGET: The pytraj TrajectoryIterator is not an iterator
    first_frame = pt_trajectory[0:1]
    # DANI: When the 'resname' argument is missing it prints "Error: Range::SetRange(None): Range is -1 for None"
    # DANI: However there is no problem and the analysis runs flawlessly
    # DANI: For this reason we call this function with no resname and then we remove the log
    # NOTE(review): the selection is passed here as 'perres_mask' but as 'mask' in the frames loop below - confirm which one is intended
    data_sample = pt.rmsd_perres(first_frame, ref=first_frame, perres_mask=pytraj_selection)
    # We remove the previous error log
    delete_previous_log()
    # We remove the first result, which is meant to be the whole rmsd and whose key is 'RMSD_00001'
    del data_sample[0]
    if n_residues != len(data_sample):
        raise ValueError(f'Number of target residues ({n_residues}) does not match number of residues in data ({len(data_sample)})')

    # Save the RMSD jumps of every residue for every pair of consecutive frames
    # Note that this may take a lot of memory
    rmsd_per_residue_per_frame = []

    # Add an extra breakline before the first log
    print()

    # Iterate trajectory frames
    previous_frame_trajectory = first_frame
    for frame_number, frame in enumerate(pt_trajectory, 1):
        # The first frame is already the initial reference
        # Comparing it against itself would add an artificial jump of 0 and shift all frame numbers by one
        if frame_number == 1:
            continue
        # Update the current frame log
        reprint(f' Frame {frame_number}')
        # Set a pytraj trajectory out of a pytraj frame
        frame_trajectory = pt.Trajectory(top=pt_trajectory.top)
        frame_trajectory.append(frame)
        # Run the analysis in pytraj
        # The result data is a custom pytraj class: pytraj.datasets.datasetlist.DatasetList
        # This class has keys but its attributes can not be accessed through the key
        # They must be accessed through the index
        # DANI: When the 'resname' argument is missing it prints "Error: Range::SetRange(None): Range is -1 for None"
        # DANI: However there is no problem and the analysis runs flawlessly
        # DANI: Adding resrange as a list/range was tried and did not work, only string works
        # DANI: Adding a string resrange however strongly impacts the speed when this function is called repeatedly
        # DANI: For this reason we call this function with no resname and then we remove the log
        rmsd_per_residue = pt.rmsd_perres(frame_trajectory, ref=previous_frame_trajectory, mask=pytraj_selection)
        # We remove the previous error log
        delete_previous_log()
        # We remove the first result, which is meant to be the whole rmsd and whose key is 'RMSD_00001'
        del rmsd_per_residue[0]
        # Check we have no NaNs
        # Note that only the first residue is checked
        if np.isnan(rmsd_per_residue[0][0]):
            raise ValueError(f'We are having NaNs at frame {frame_number}')
        # Add last values to the list
        rmsd_per_residue_per_frame.append(rmsd_per_residue)
        # Update previous coordinates
        previous_frame_trajectory = frame_trajectory

    # If the trajectory has only 1 or 2 frames then there is no test to do
    n_jumps = len(rmsd_per_residue_per_frame)
    if n_jumps <= 1:
        register.update_test(TRAJECTORY_INTEGRITY_FLAG, True)
        return True

    # Keep the overall maximum z score, and its residue and frame for the logs
    overall_max_z_score = 0
    overall_max_z_score_frame = None
    overall_max_z_score_residue = None

    # Keep the overall maximum bypassed frames number
    overall_bypassed_frames = 0

    # Keep the overall count of residues with outliers
    overall_outliered_residues = 0

    # Add an extra breakline before the next log
    print()

    # Now check there are not sudden jumps for each residue separately
    for residue_number in range(n_residues):
        reprint(f' Residue {residue_number+1}')
        # Get the rmsd jumps for each frame for this specific residue
        rmsd_jumps = [ frame_rmsds[residue_number] for frame_rmsds in rmsd_per_residue_per_frame ]

        # Get the mean and standard deviation of RMSD jumps to later compute z-scores
        mean_rmsd_jump = np.mean(rmsd_jumps)
        stdv_rmsd_jump = np.std(rmsd_jumps)

        # If all jumps are identical then the standard deviation is 0 and no jump can be an outlier
        # Checking this beforehand avoids a division by zero
        has_deviation = stdv_rmsd_jump > 0

        # First frames may not be perfectly equilibrated and thus have stronger RMSD jumps
        # For this reason we allow the first frames to bypass the check
        # As soon as one frame is below the cutoff the bypass is finished for the following frames
        # Count the number of bypassed frames and warn the user in case there are any
        bypassed_frames = 0

        # Capture outliers
        # We keep checking after the first outlier just to find and report the highest z score
        outliers_count = 0
        max_z_score = 0
        max_z_score_frame = 0
        for i, rmsd_jump in enumerate(rmsd_jumps, 1):
            if has_deviation:
                z_score = abs( (rmsd_jump - mean_rmsd_jump) / stdv_rmsd_jump )
            else:
                z_score = 0.0
            # Keep track of the maximum z score
            if z_score > max_z_score:
                max_z_score = z_score
                max_z_score_frame = i
            # If z score bypassed the limit then report it
            if z_score > standard_deviations_cutoff:
                # If there are as many bypassed frames as the index then it means no frame has passed the cutoff yet
                if i - 1 == bypassed_frames:
                    bypassed_frames += 1
                    continue
                # Otherwise we consider this as an outlier and thus the test has failed
                outliers_count += 1

        # Update the overall bypassed frames if we overcame it
        if overall_bypassed_frames < bypassed_frames:
            overall_bypassed_frames = bypassed_frames

        # Update the overall max z score if we overcome it
        if max_z_score > overall_max_z_score:
            overall_max_z_score = max_z_score
            overall_max_z_score_frame = max_z_score_frame
            overall_max_z_score_residue = residue_number

        # If there were any outlier then add one to the overall count
        if outliers_count > 0:
            overall_outliered_residues += 1

    # Print the overall maximum z score and its frames and residue
    # There may be no maximum at all (e.g. all jumps are identical), in which case there is nothing to print
    # NOTE(review): the residue number is an index among the checked residues - confirm it matches the pytraj topology residue index
    if overall_max_z_score_residue is not None:
        overall_max_z_score_residue_label = pt_trajectory.top.residue(overall_max_z_score_residue)
        print(f' Maximum z score {overall_max_z_score} reported for residue {overall_max_z_score_residue_label} between frames {overall_max_z_score_frame} and {overall_max_z_score_frame + 1}')

    # If there were any outlier then the check has failed
    if overall_outliered_residues > 0:
        message = 'RMSD check has failed: there may be sudden jumps along the trajectory'
        # In case we have mercy, add a warning, register the failure and return False
        if TRAJECTORY_INTEGRITY_FLAG in mercy:
            register.add_warning(TRAJECTORY_INTEGRITY_FLAG, message)
            register.update_test(TRAJECTORY_INTEGRITY_FLAG, False)
            return False
        # Otherwise kill the process right away
        raise TestFailure(message)

    # Warn the user if we had bypassed frames
    if overall_bypassed_frames > 0:
        register.add_warning(TRAJECTORY_INTEGRITY_FLAG, f'First {overall_bypassed_frames} frames may be not equilibrated')

    print(' Test has passed successfully')
    register.update_test(TRAJECTORY_INTEGRITY_FLAG, True)
    return True