Coverage for mddb_workflow / analyses / rmsd_check.py: 27%

226 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 18:45 +0000

1import mdtraj as mdt 

2import pytraj as pt 

3import numpy as np 

4import math 

5import sys 

6 

7# Visual output tools 

8from tqdm import tqdm 

9import plotext as plt 

10 

11from mddb_workflow.utils.auxiliar import delete_previous_log, reprint, TestFailure, warn 

12from mddb_workflow.utils.auxiliar import round_to_hundredths 

13from mddb_workflow.utils.constants import TRAJECTORY_INTEGRITY_FLAG, RED_HEADER, COLOR_END 

14from mddb_workflow.utils.pyt_spells import get_pytraj_trajectory 

15from mddb_workflow.utils.type_hints import * 

16 

# Detect once, at import time, whether standard output is an interactive terminal
# (as opposed to a file or a pipe); the in-terminal graphs are skipped when it is not
is_terminal = bool(sys.stdout.isatty())

19 

# LORE
# This test was originally intended to use an RMSD jump cutoff based on the number of atoms and the timestep
# However, after a deep study, it was observed that simulations with similar features may show very different RMSD jumps
# For this reason we now compute RMSD jumps along the whole trajectory and check that the biggest jump is not an outlier
# A value is considered an outlier according to how many standard deviations away from the mean it is

25 

# Look for sudden rises of RMSD values from one frame to another
# To do so, we check the RMSD of every frame using its previous frame as reference
# This was made possible thanks to people from MDtraj https://github.com/mdtraj/mdtraj/issues/1966

def _get_fragments_to_check (structure : 'Structure', parsed_selection : 'Selection') -> list:
    """Split the selection to be checked into fragments which are analyzed independently.

    A sudden jump of a small fragment may cause only a small RMSD perturbation when it is part
    of a large fragment which is not jumping. Jumps or artifacts of partial fragments along the
    boundaries are rare imaging problems, although possible, so splitting the test into fragments
    is a good measure. Big fragments could be split into even smaller parts in the future.
    On the other hand, a very small fragment may overcome the cutoff with small RMSD jumps, so beware.

    If the structure is missing bonds then fragments can not be found, so chains are used instead.
    In that case only the chain atoms which belong to the selection are kept, and chains with no
    such atoms are skipped.
    """
    if not structure.is_missing_any_bonds():
        return list(structure.find_fragments(parsed_selection))
    fragments = []
    for chain in structure.chains:
        chain_selection = chain.get_selection()
        # Intersect the chain with the selection using substractions only: A - (A - B) is A & B
        # This way atoms out of the checked selection (e.g. PBC residues) are not checked
        fragment = chain_selection - (chain_selection - parsed_selection)
        if fragment:
            fragments.append(fragment)
    return fragments

def _compute_fragment_rmsd_jumps (
    input_structure_filename : str,
    input_trajectory_filename : str,
    fragments : list,
    snapshots : int) -> list[list[float]]:
    """Compute, for every fragment, the RMSD of every frame using its previous frame as reference.

    Returns one list of RMSD jumps per fragment, in the same order as fragments.
    The jump at position i (0-based) is the one between frames i + 1 and i + 2 (1-based).
    """
    # Load the trajectory frame by frame
    trajectory = mdt.iterload(input_trajectory_filename, top=input_structure_filename, chunk=1)
    # Keep the previous frame at any time to use it as reference
    previous_frame = next(trajectory)
    fragments_atom_indices = [ fragment.atom_indices for fragment in fragments ]
    fragments_rmsd_jumps = [ [] for _ in fragments ]
    # The first frame was already consumed, so the progress bar starts at 1
    pbar = tqdm(trajectory, total=snapshots, desc=' Frame', unit='frame', initial=1)
    for frame in pbar:
        for atom_indices, rmsd_jumps in zip(fragments_atom_indices, fragments_rmsd_jumps):
            # DANI: MDtraj centering removes the jump across the boundaries
            # DANI: precentered=True should avoid it, but it is ignored when atom_indices are passed
            rmsd_value = mdt.rmsd(frame, previous_frame, atom_indices=atom_indices, superpose=False)[0]
            rmsd_jumps.append(rmsd_value)
        previous_frame = frame
    return fragments_rmsd_jumps

def _scan_rmsd_jumps (rmsd_jumps : list, standard_deviations_cutoff : float) -> tuple:
    """Find the RMSD jumps which are outliers according to their z score.

    Jumps are numbered from 1, so jump i is the one between frames i and i + 1.
    First frames may not be perfectly equilibrated and thus have stronger RMSD jumps.
    For this reason, jumps above the cutoff at the very beginning are counted as bypassed frames
    instead of outliers. As soon as one jump is below the cutoff the bypass is over.

    Returns a tuple with:
        - the maximum z score (0 if no jump deviates from the mean)
        - the jump number of the maximum z score (0 if no jump deviates from the mean)
        - the number of bypassed frames
        - a list of (jump number, z score) tuples, one per outlier
    """
    mean_rmsd_jump = np.mean(rmsd_jumps)
    stdv_rmsd_jump = np.std(rmsd_jumps)
    max_z_score = 0
    max_z_score_frame = 0
    bypassed_frames = 0
    outliers = []
    for i, rmsd_jump in enumerate(rmsd_jumps, 1):
        # If the standard deviation is 0 then all jumps are equal and none deviates from the mean
        # Avoid the 0/0 division, which would produce NaN and a numpy runtime warning
        z_score = abs( (rmsd_jump - mean_rmsd_jump) / stdv_rmsd_jump ) if stdv_rmsd_jump > 0 else 0
        # Keep track of the maximum z score
        if z_score > max_z_score:
            max_z_score = z_score
            max_z_score_frame = i
        # If z score bypassed the limit then report it
        if z_score > standard_deviations_cutoff:
            # If there are as many bypassed frames as the previous index then no frame has passed the cutoff yet
            if i - 1 == bypassed_frames:
                bypassed_frames += 1
                continue
            outliers.append((i, z_score))
    return max_z_score, max_z_score_frame, bypassed_frames, outliers

def _print_rmsd_check_report (fragment_reports : list[dict]):
    """Print the final report with the test result of every fragment."""
    print('-- RMSD Check final report --')
    for report in fragment_reports:
        name = report['name']
        max_z_score = round_to_hundredths(report['score'])
        max_z_score_frame = report['frame']
        outliers = report['outliers']
        bypassed_frames = report['bypass']
        # If the fragment has no problems
        if outliers == 0 and bypassed_frames == 0:
            print(f'Fragment "{name}" PASSED with a maximum z-score of {max_z_score}' + \
                f' reported between frames {max_z_score_frame} and {max_z_score_frame + 1}')
            continue
        # If the fragment failed to pass the test
        print(f'{RED_HEADER}Fragment "{name}" FAILED with a maximum z-score of {max_z_score}' + \
            f' reported between frames {max_z_score_frame} and {max_z_score_frame + 1}{COLOR_END}')
        if outliers > 0: print(f' Outliers: {outliers}')
        if bypassed_frames > 0: print(f' Bypassed frames: {bypassed_frames}')
    print('*The z-score of a value means how many times the standard deviation away it is from the average')

def check_trajectory_integrity (
    input_structure_filename : str,
    input_trajectory_filename : str,
    structure : 'Structure',
    pbc_selection : 'Selection',
    mercy : list[str],
    trust: list[str],
    register : 'Register',
    #time_length : float,
    check_selection : str,
    # DANI: I have seen 'correct' jumps go above 6
    # DANI: I have seen 'incorrect' jumps never go below 10
    standard_deviations_cutoff : float,
    snapshots : int) -> bool:
    """Check there are no sudden RMSD jumps between consecutive frames of the trajectory.

    The selection to be checked (excluding PBC atoms) is split into fragments, and every fragment
    is checked separately. A jump fails the test when its z score is above the cutoff. Jumps above
    the cutoff at the beginning of the trajectory are counted as bypassed frames, which also fail the test.

    Args:
        input_structure_filename: Structure (topology) file passed to MDtraj.
        input_trajectory_filename: Trajectory file passed to MDtraj.
        structure: Structure used to parse selections and find fragments.
        pbc_selection: Atoms to be excluded from the check.
        mercy: Test flags whose failure is registered as a warning instead of killing the process.
        trust: Test flags to be skipped.
        register: Register where test results and warnings are stored.
        check_selection: Atoms to be checked, in VMD selection syntax.
        standard_deviations_cutoff: Maximum allowed z score of an RMSD jump.
        snapshots: Number of frames in the trajectory.

    Returns:
        True if the test passed or was skipped, False if it failed and we have mercy.

    Raises:
        Exception: If check_selection selects no atoms at all.
        TestFailure: If the test failed and we have no mercy.
    """
    # Skip the test if we trust
    if TRAJECTORY_INTEGRITY_FLAG in trust:
        return True

    # Skip the test if it is already passed according to the register
    if register.tests.get(TRAJECTORY_INTEGRITY_FLAG, None):
        return True

    # If the trajectory has only 1 or 2 frames then there is no test to do
    if snapshots < 3:
        register.update_test(TRAJECTORY_INTEGRITY_FLAG, True)
        return True

    # Remove old warnings
    register.remove_warnings(TRAJECTORY_INTEGRITY_FLAG)

    # Parse the selection in VMD selection syntax
    parsed_selection = structure.select(check_selection, syntax='vmd')

    # If there is nothing to check then warn the user and stop here
    if not parsed_selection:
        raise Exception('WARNING: There are not atoms to be analyzed for the RMSD analysis')

    # Discard PBC residues from the selection to be checked
    parsed_selection -= pbc_selection

    # If there is nothing to check then warn the user and stop here
    if not parsed_selection:
        warn('There are no atoms to be analyzed for the RMSD checking after PBC substraction')
        register.update_test(TRAJECTORY_INTEGRITY_FLAG, 'na')
        return True

    # Fragments are analyzed independently
    fragments = _get_fragments_to_check(structure, parsed_selection)
    print(f'Checking trajectory integrity ({len(fragments)} fragments)')

    # Compute all RMSD jumps for every fragment
    fragments_rmsd_jumps = _compute_fragment_rmsd_jumps(
        input_structure_filename, input_trajectory_filename, fragments, snapshots)

    # Check all fragments and store their test result here before issuing the final report
    fragment_reports = []
    for fragment, rmsd_jumps in zip(fragments, fragments_rmsd_jumps):
        fragment_name = structure.name_selection(fragment)
        max_z_score, max_z_score_frame, bypassed_frames, outliers = \
            _scan_rmsd_jumps(rmsd_jumps, standard_deviations_cutoff)
        # Outliers mean the test has failed, but don't kill the check just yet
        # Log only the first 4 outliers to avoid flooding the output
        for frame, z_score in outliers[:4]:
            print(f' FAIL: Sudden RMSD jump in fragment "{fragment_name}" between frames {frame} and {frame+1} (z score = {round_to_hundredths(z_score)})')
        if len(outliers) > 4:
            print(' etc...')
        fragment_reports.append({
            'name': fragment_name,
            'score': max_z_score,
            'frame': max_z_score_frame,
            'outliers': len(outliers),
            'bypass': bypassed_frames,
            'jumps' : rmsd_jumps
        })

    # If there were any outliers or any bypassed frames then the test has failed
    # LORE: Back in the day having bypassed frames was not a test failure, but now it is
    any_outliers = any(report['outliers'] > 0 for report in fragment_reports)
    any_bypassed_frames = any(report['bypass'] > 0 for report in fragment_reports)

    # If the test has failed then display a full report
    if any_outliers or any_bypassed_frames:
        _print_rmsd_check_report(fragment_reports)

    # If there were any outliers then the check has failed
    if any_outliers:
        message = 'RMSD check has failed: there may be sudden jumps along the trajectory'
        # Add a warning and return False, since the test failed, in case we have mercy
        if TRAJECTORY_INTEGRITY_FLAG in mercy:
            register.add_warning(TRAJECTORY_INTEGRITY_FLAG, message)
            register.update_test(TRAJECTORY_INTEGRITY_FLAG, False)
            return False
        # Otherwise kill the process right away, after displaying a graph of the worst fragment
        # max returns the first report with the maximum score
        worst_report = max(fragment_reports, key=lambda report: report['score'])
        title = f'RMSD jumps along the trajectory in fragment "{worst_report["name"]}"'
        display_rmsd_jumps_graph(worst_report['jumps'], title)
        raise TestFailure(message)

    # The test has also failed if we had bypassed frames
    if any_bypassed_frames:
        max_bypassed_frames = max(report['bypass'] for report in fragment_reports)
        message = f'First {max_bypassed_frames} frames may be not equilibrated'
        # Add a warning and return False, since the test failed, in case we have mercy
        if TRAJECTORY_INTEGRITY_FLAG in mercy:
            register.add_warning(TRAJECTORY_INTEGRITY_FLAG, message)
            register.update_test(TRAJECTORY_INTEGRITY_FLAG, False)
            return False
        # Otherwise kill the process, after displaying a graph of the first frames of the worst fragment
        worst_report = max(fragment_reports, key=lambda report: report['score'])
        graph_frames_limit = max_bypassed_frames + 100
        report_data = worst_report['jumps'][0:graph_frames_limit]
        title = f'RMSD jumps along the trajectory in fragment "{worst_report["name"]}" (first {graph_frames_limit} frames)'
        display_rmsd_jumps_graph(report_data, title)
        raise TestFailure(message)

    print(' Test has passed successfully')
    register.update_test(TRAJECTORY_INTEGRITY_FLAG, True)
    return True

240 

241# Display a graph to show the distribution of sudden jumps along the trajectory in the terminal itself 

def display_rmsd_jumps_graph (data : list, title : str, n_ticks : int = 5):
    """Display a scatter plot of RMSD jumps in the terminal itself, using plotext.

    Nothing is displayed when standard output is not a terminal or when there is no data.

    Args:
        data: RMSD jump values to plot.
        title: Title of the graph.
        n_ticks: Approximate number of ticks along the frame axis (at least one tick step is used).
    """
    if not is_terminal: return
    # Nothing to plot; this also prevents a zero tick step below, which makes range() raise
    if len(data) == 0: return
    # Display a graph to show the distribution of sudden jumps along the trajectory
    plt.scatter(data)
    plt.title(title)
    plt.xlabel('Frame')
    plt.ylabel('RMSD jump')
    n_jumps = len(data)
    # Round the step up so there are never more than n_ticks ticks
    tickstep = max(1, math.ceil(n_jumps / max(1, n_ticks)))
    # Ticks are 1-based: 1, 1 + tickstep, 1 + 2 * tickstep, ...
    xticks = [ t+1 for t in range(0, n_jumps, tickstep) ]
    xlabels = [ str(t) for t in xticks ]
    plt.xticks(xticks, xlabels)
    plt.show()

256 

257# Compute every residue RMSD to check if there are sudden jumps along the trajectory 

258# HARDCODE: This function is not fully implemented but enabled manually for specific cases 

def check_trajectory_integrity_per_residue (
    input_structure_filename : str,
    input_trajectory_filename : str,
    structure : 'Structure',
    pbc_selection : 'Selection',
    mercy : list[str],
    trust: list[str],
    register : 'Register',
    #time_length : float,
    check_selection : str,
    # DANI: I have seen 'correct' jumps go above 11
    # DANI: I have seen 'incorrect' jumps never go below 14
    standard_deviations_cutoff : float) -> bool:
    """Check there are no sudden RMSD jumps between consecutive frames for every residue separately.

    HARDCODE: This check is not fully implemented and it is only enabled manually for specific cases.
    Residues with a single atom (e.g. ions) and PBC atoms are not checked.

    Args:
        input_structure_filename: Structure (topology) file passed to pytraj.
        input_trajectory_filename: Trajectory file passed to pytraj.
        structure: Structure used to parse selections and residues.
        pbc_selection: Atoms to be excluded from the check.
        mercy: Test flags whose failure is registered as a warning instead of killing the process.
        trust: Test flags to be skipped.
        register: Register where test results and warnings are stored.
        check_selection: Atoms to be checked, in VMD selection syntax.
        standard_deviations_cutoff: Ignored; it is always overridden with a hardcoded value of 12.

    Returns:
        True if the test passed or was skipped, False if it failed and we have mercy.

    Raises:
        Exception: If check_selection selects no atoms at all.
        ValueError: If pytraj results do not match the expected residues or NaNs are found.
        TestFailure: If the test failed and we have no mercy.
    """
    # HARDCODE: The default value does not work for a single residue
    standard_deviations_cutoff = 12

    # Skip the test if we trust
    if TRAJECTORY_INTEGRITY_FLAG in trust:
        return True

    # Skip the test if it is already passed according to the register
    if register.tests.get(TRAJECTORY_INTEGRITY_FLAG, None):
        return True

    # Remove old warnings
    register.remove_warnings(TRAJECTORY_INTEGRITY_FLAG)

    # Parse the selection in VMD selection syntax
    parsed_selection = structure.select(check_selection, syntax='vmd')

    # If there is nothing to check then warn the user and stop here
    if not parsed_selection:
        raise Exception('WARNING: There are not atoms to be analyzed for the RMSD analysis')

    # Discard PBC residues from the selection to be checked
    parsed_selection -= pbc_selection

    # If there is nothing to check then warn the user and stop here
    if not parsed_selection:
        warn('There are no atoms to be analyzed for the RMSD checking after PBC substraction')
        register.update_test(TRAJECTORY_INTEGRITY_FLAG, 'na')
        return True

    # We must filter out residues which only have 1 atom (e.g. ions)
    # This is because sometimes pytraj does not return results for them and then the number of results and residues does not match
    # More info: https://github.com/Amber-MD/pytraj/issues/1580
    ion_atom_indices = []
    for residue in structure.residues:
        if len(residue.atom_indices) == 1:
            ion_atom_indices += residue.atom_indices
    ions_selection = structure.select_atom_indices(ion_atom_indices)
    parsed_selection -= ions_selection

    # If only single atom residues were selected then there is nothing to check
    if not parsed_selection:
        warn('There are no atoms to be analyzed for the RMSD checking after removing single atom residues')
        register.update_test(TRAJECTORY_INTEGRITY_FLAG, 'na')
        return True

    # Parse the selection to a pytraj mask
    pytraj_selection = parsed_selection.to_pytraj()

    # Calculate the residue indices of the overall structure remaining in the filtered trajectory
    residue_indices = structure.get_selection_residue_indices(parsed_selection)
    n_residues = len(residue_indices)

    print('Checking trajectory integrity per residue')

    # Parse the trajectory into pytraj and apply the mask
    # NEVER FORGET: The pytraj iterload does not accept a mask, but we apply the mask later in the analysis
    pt_trajectory = get_pytraj_trajectory(input_structure_filename, input_trajectory_filename, atom_selection = parsed_selection)

    # Make sure the expected output number of residues to match with the pytraj results
    # These numbers may not match when ions are included so we better check
    # NEVER FORGET: The pytraj TrajectoryIterator is not an iterator
    first_frame = pt_trajectory[0:1]
    # DANI: When the 'resname' argument is missing it prints "Error: Range::SetRange(None): Range is -1 for None"
    # DANI: However there is no problem and the analysis runs flawlessly
    # DANI: For this reason we call this function with no resname and then we remove the log
    data_sample = pt.rmsd_perres(first_frame, ref=first_frame, perres_mask=pytraj_selection)
    # We remove the previous error log
    delete_previous_log()
    # We remove the first result, which is meant to be the whole rmsd and whose key is 'RMSD_00001'
    del data_sample[0]
    if n_residues != len(data_sample):
        raise ValueError(f'Number of target residues ({n_residues}) does not match number of residues in data ({len(data_sample)})')

    # Store the plain RMSD jump value of every residue for every pair of consecutive frames
    # Values are stored instead of the pytraj datasets to keep memory usage lower
    rmsd_per_residue_per_frame = []

    # Add an extra breakline before the first log
    print()

    # Iterate trajectory frames
    previous_frame_trajectory = first_frame
    for frame_number, frame in enumerate(pt_trajectory, 1):
        # Update the current frame log
        reprint(f' Frame {frame_number}')
        # The first frame is the initial reference itself
        # Comparing it against itself would add a fake zero jump and shift all frame numbers in the logs
        if frame_number == 1:
            continue
        # Set a pytraj trajectory out of a pytraj frame
        frame_trajectory = pt.Trajectory(top=pt_trajectory.top)
        frame_trajectory.append(frame)
        # Run the analysis in pytraj
        # The result data is a custom pytraj class: pytraj.datasets.datasetlist.DatasetList
        # This class has keys but its attributes can not be accessed through the key
        # They must be accessed through the index
        # DANI: When the 'resname' argument is missing it prints "Error: Range::SetRange(None): Range is -1 for None"
        # DANI: However there is no problem and the analysis runs flawlessly
        # DANI: Adding resrange as a list/range was tried and did not work, only string works
        # DANI: Adding a string resrange however strongly impacts the speed when this function is called repeatedly
        # DANI: For this reason we call this function with no resname and then we remove the log
        # NOTE(review): the residue count check above passes perres_mask while this call passes mask
        # These are different pytraj arguments; confirm which one is intended here
        rmsd_per_residue = pt.rmsd_perres(frame_trajectory, ref=previous_frame_trajectory, mask=pytraj_selection)
        # We remove the previous error log
        delete_previous_log()
        # We remove the first result, which is meant to be the whole rmsd and whose key is 'RMSD_00001'
        del rmsd_per_residue[0]
        # Keep only the single RMSD value of every residue dataset
        residue_values = [ rmsd_per_residue[index][0] for index in range(len(rmsd_per_residue)) ]
        # Check we have no NaNs in any residue
        if np.isnan(residue_values).any():
            raise ValueError(f'We are having NaNs at frame {frame_number}')
        rmsd_per_residue_per_frame.append(residue_values)
        # Update previous coordinates
        previous_frame_trajectory = frame_trajectory

    # If the trajectory has only 1 or 2 frames then there is no test to do
    n_jumps = len(rmsd_per_residue_per_frame)
    if n_jumps <= 1:
        register.update_test(TRAJECTORY_INTEGRITY_FLAG, True)
        return True

    # Keep the overall maximum z score, and its residue and frame for the logs
    overall_max_z_score = 0
    overall_max_z_score_frame = None
    overall_max_z_score_residue = None

    # Keep the overall maximum bypassed frames number
    overall_bypassed_frames = 0

    # Keep the overall count of residues with outliers
    overall_outliered_residues = 0

    # Add an extra breakline before the next log
    print()

    # Now check there are not sudden jumps for each residue separately
    for residue_number in range(n_residues):
        reprint(f' Residue {residue_number+1}')
        # Get the rmsd jumps for each frame for this specific residue
        rmsd_jumps = [ frame_values[residue_number] for frame_values in rmsd_per_residue_per_frame ]

        # Get the maximum RMSD value and check it is a reasonable deviation from the average values
        # Otherwise, if it is an outlier, the test fails
        mean_rmsd_jump = np.mean(rmsd_jumps)
        stdv_rmsd_jump = np.std(rmsd_jumps)

        # First frames may not be perfectly equilibrated and thus have stronger RMSD jumps
        # For this reason we allow the first frames to bypass the check
        # As soon as one frame is below the cutoff the bypass is finished for the following frames
        # Count the number of bypassed frames and warn the user in case there are any
        bypassed_frames = 0

        # Count outliers; all jumps are checked to find and report the highest z score
        outliers_count = 0
        max_z_score = 0
        max_z_score_frame = 0
        for i, rmsd_jump in enumerate(rmsd_jumps, 1):
            # If the standard deviation is 0 then all jumps are equal and none deviates from the mean
            # Avoid the 0/0 division, which would produce NaN and a numpy runtime warning
            z_score = abs( (rmsd_jump - mean_rmsd_jump) / stdv_rmsd_jump ) if stdv_rmsd_jump > 0 else 0
            # Keep track of the maximum z score
            if z_score > max_z_score:
                max_z_score = z_score
                max_z_score_frame = i
            # If z score bypassed the limit then report it
            if z_score > standard_deviations_cutoff:
                # If there are as many bypassed frames as the previous index then no frame has passed the cutoff yet
                if i - 1 == bypassed_frames:
                    bypassed_frames += 1
                    continue
                # Otherwise we consider this as an outlier and thus the test has failed
                outliers_count += 1

        # Update the overall bypassed frames if we overcame it
        overall_bypassed_frames = max(overall_bypassed_frames, bypassed_frames)

        # Update the overall max z score if we overcame it
        if max_z_score > overall_max_z_score:
            overall_max_z_score = max_z_score
            overall_max_z_score_frame = max_z_score_frame
            overall_max_z_score_residue = residue_number

        # If there were any outlier then add one to the overall count
        if outliers_count > 0:
            overall_outliered_residues += 1

    # Always print the overall maximum z score and its frames and residue
    # If every z score is 0 then there is no residue or frame to report
    if overall_max_z_score_residue is None:
        print(' Maximum z score 0: no RMSD jump deviates from the average in any residue')
    else:
        overall_max_z_score_residue_label = pt_trajectory.top.residue(overall_max_z_score_residue)
        print(f' Maximum z score {round_to_hundredths(overall_max_z_score)} reported for residue {overall_max_z_score_residue_label} between frames {overall_max_z_score_frame} and {overall_max_z_score_frame + 1}')

    # If there were any outlier then the check has failed
    if overall_outliered_residues > 0:
        message = 'RMSD check has failed: there may be sudden jumps along the trajectory'
        # Add a warning and return False, since the test failed, in case we have mercy
        if TRAJECTORY_INTEGRITY_FLAG in mercy:
            register.add_warning(TRAJECTORY_INTEGRITY_FLAG, message)
            register.update_test(TRAJECTORY_INTEGRITY_FLAG, False)
            return False
        # Otherwise kill the process right away
        raise TestFailure(message)

    # Warn the user if we had bypassed frames
    if overall_bypassed_frames > 0:
        register.add_warning(TRAJECTORY_INTEGRITY_FLAG, f'First {overall_bypassed_frames} frames may be not equilibrated')

    print(' Test has passed successfully')
    register.update_test(TRAJECTORY_INTEGRITY_FLAG, True)
    return True