Coverage for mddb_workflow/utils/pyt_spells.py: 57%

70 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-29 15:48 +0000

1import pytraj as pyt 

2import math 

3from packaging.version import Version 

4 

5from mddb_workflow.utils.auxiliar import InputError 

6from mddb_workflow.utils.file import File 

7from mddb_workflow.utils.selections import Selection 

8from mddb_workflow.utils.type_hints import * 

9from mddb_workflow.tools.get_reduced_trajectory import calculate_frame_step 

# Formats pytraj can handle, grouped by the role the file plays
# Structure / topology inputs (PRMTOP is strictly a topology, but it is grouped here)
pytraj_supported_structure_formats = set(('prmtop', 'pdb', 'mol2', 'psf', 'cif', 'sdf'))
# Trajectory (coordinates) inputs
pytraj_supported_trajectory_formats = set(('xtc', 'trr', 'crd', 'mdcrd', 'nc', 'dcd'))

13 

def get_pytraj_trajectory (
    input_topology_filename : str,
    input_trajectory_filename : str,
    atom_selection : Optional['Selection'] = None):
    """Lazily load a whole trajectory through pytraj, optionally keeping only some atoms.

    Args:
        input_topology_filename: Path to the topology/structure file. Mandatory.
        input_trajectory_filename: Path to the trajectory file.
        atom_selection: When given (and truthy), every atom whose index is not in
            ``atom_selection.atom_indices`` is stripped from the result.

    Returns:
        The iterator returned by ``pyt.iterload``, or the result of its ``strip``
        call when some atoms had to be removed.

    Raises:
        SystemExit: If no topology filename is provided.
    """
    # pytraj cannot set up a trajectory without a topology
    if not input_topology_filename:
        raise SystemExit('Missing topology file to setup PyTraj trajectory')

    # NEVER FORGET: pytraj's iterload does not accept a mask, so atoms are stripped afterwards
    loaded = pyt.iterload(input_trajectory_filename, input_topology_filename)

    # WARNING: This extra call prevents "Segment violation (core dumped)" with some pdbs
    # that pytraj considers to have 0 Mols (see https://github.com/Amber-MD/cpptraj/pull/820)
    # It is useful in pytraj <= 2.0.5 but makes the code fail from pytraj 2.0.6 onwards
    needs_new_mol_workaround = Version(pyt.__version__) <= Version('2.0.5')
    if needs_new_mol_workaround:
        loaded.top.start_new_mol()

    # No selection means there is nothing to filter
    if not atom_selection:
        return loaded

    # Indices of every atom that is NOT part of the selection
    every_atom = set(range(loaded.top.n_atoms))
    unwanted_atoms = every_atom - set(atom_selection.atom_indices)
    if not unwanted_atoms:
        return loaded

    # Turn the unwanted indices into a pytraj mask and strip them from the iterator
    strip_mask = Selection(unwanted_atoms).to_pytraj()
    return loaded.strip(strip_mask)

48 

def get_reduced_pytraj_trajectory (
    input_topology_filename : str,
    input_trajectory_filename : str,
    snapshots : int,
    reduced_trajectory_frames_limit : int):
    """Get a frame-subsampled pytraj trajectory to be used by heavy analyses.

    WARNING: The final number of frames may be the specified limit or fewer.

    Args:
        input_topology_filename: Path to the topology file.
        input_trajectory_filename: Path to the trajectory file.
        snapshots: Number of frames in the full trajectory, supplied by the caller.
        reduced_trajectory_frames_limit: Maximum number of frames wanted.

    Returns:
        A tuple ``(trajectory, frame_step, frame_count)``. When ``snapshots`` does not
        exceed the limit, the full trajectory is returned with a step of 1.
    """
    full_trajectory = get_pytraj_trajectory(input_topology_filename, input_trajectory_filename)
    # WARNING: Do not read full_trajectory.n_frames to get the number of snapshots:
    # that reads the whole trajectory, which may take a long time for huge files.
    # The snapshots input is used instead.

    # Only subsample when the trajectory is larger than the limit
    if snapshots > reduced_trajectory_frames_limit:
        step, kept_frames = calculate_frame_step(snapshots, reduced_trajectory_frames_limit)
        return full_trajectory[0:snapshots:step], step, kept_frames

    # Already small enough: the full trajectory doubles as the reduced one
    return full_trajectory, 1, snapshots

73 

# LORE: This was tried also with mdtraj's iterload but pytraj was way faster
def get_frames_count (
    structure_file : 'File',
    trajectory_file : 'File') -> int:
    """Get the trajectory frames count.

    The trajectory is opened lazily with ``pyt.iterload`` and its ``n_frames`` is read.

    Args:
        structure_file: Structure/topology file; its ``.exists`` and ``.path`` are used.
        trajectory_file: Trajectory file; its ``.exists`` and ``.path`` are used.

    Returns:
        The number of frames reported by pytraj (always greater than 0).

    Raises:
        InputError: If either file is missing, or if pytraj reports 0 frames.
    """
    print('-> Counting number of frames')

    # Check both inputs before paying for the pytraj load.
    # f-strings keep the message working even if .path is not a plain str.
    if not trajectory_file.exists:
        raise InputError(f'Missing trajectory file when counting frames: {trajectory_file.path}')

    if not structure_file.exists:
        raise InputError(f'Missing topology file when counting frames: {structure_file.path}')

    # Load the trajectory lazily with pytraj
    pyt_trajectory = pyt.iterload(
        trajectory_file.path,
        structure_file.path)

    # Return the frames number
    frames = pyt_trajectory.n_frames
    print(f' Frames: {frames}')

    # If 0 frames were counted then there is something wrong with the file
    if frames == 0:
        raise InputError('Something went wrong when reading the trajectory')

    return frames

# Formats accepted by get_frames_count, keyed by input role
# NOTE(review): these keys do not match the parameter names (structure_file, trajectory_file);
# confirm how the format finder matches them before renaming either side
get_frames_count.format_sets = [dict(
    inputs=dict(
        input_structure_filename=pytraj_supported_structure_formats,
        input_trajectory_filename=pytraj_supported_trajectory_formats,
    ),
)]

111 

def filter_topology (
    input_structure_file : 'File',
    output_structure_file : 'File',
    input_selection : 'Selection'
):
    """Write a copy of a topology that keeps only the atoms in a selection.

    Note that a PRMTOP file is not a structure but a topology. The arguments are
    still called 'structure' because the format finder depends on that naming.

    Args:
        input_structure_file: Source topology/structure file; its ``.path`` is read.
        output_structure_file: Destination file; its ``.path`` is written and its
            ``.exists`` is checked afterwards.
        input_selection: Atoms to keep; converted with ``to_pytraj()`` into a mask.

    Raises:
        SystemExit: If the output file does not exist after saving.
    """
    # Generate a pytraj mask with the desired selection
    mask = input_selection.to_pytraj()

    # Load the topology and apply the filter mask
    topology = pyt.load_topology(input_structure_file.path)
    filtered_topology = topology[mask]

    # Write the filtered topology to disk
    filtered_topology.save(output_structure_file.path)

    # If the output file is missing at this point then something went wrong with PyTraj
    if not output_structure_file.exists:
        raise SystemExit('Something went wrong with PyTraj')

136 

137 

# Formats accepted by filter_topology for its input and output files
filter_topology.format_sets = [dict(
    inputs=dict(input_structure_file=pytraj_supported_structure_formats),
    outputs=dict(output_structure_file=pytraj_supported_structure_formats),
)]

148 

def find_first_corrupted_frame (input_topology_filepath : str, input_trajectory_filepath : str) -> Optional[int]:
    """Find the first corrupted frame of a NetCDF file whose first frames pytraj can still read.

    A frame is treated as corrupted when all coordinates of its last atom are zero.
    Problems while reading may produce frames full of zeros.

    Args:
        input_topology_filepath: Path to the topology file.
        input_trajectory_filepath: Path to the (possibly corrupted) trajectory file.

    Returns:
        The 1-based number of the first corrupted frame, or None if no frame matches.
    """
    # Iterload the trajectory to pytraj
    trajectory = get_pytraj_trajectory(input_topology_filepath, input_trajectory_filepath)
    frames = trajectory.iterframe()
    expected_frames = trajectory.n_frames
    # Frame numbers start at 1 to match the progress message
    for frame_number, frame in enumerate(frames, 1):
        # The carriage return rewrites the same console line for every frame
        print(f'Reading frame {frame_number}/{expected_frames}', end='\r')
        # Make sure there are actual coordinates here
        last_atom_coordinates = frame.xyz[-1]
        if not last_atom_coordinates.any():
            # Terminate the progress line so later output does not overwrite it
            print()
            return frame_number
    print()
    return None

164 

def get_average_structure (structure_file : 'File', trajectory_file : 'File', output_filepath : str):
    """Get an average structure from a trajectory.

    pytraj is used for this rather than Gromacs, since the Gromacs average may be displaced.

    Args:
        structure_file: Structure/topology file; its ``.path`` is read.
        trajectory_file: Trajectory file; its ``.path`` is read.
        output_filepath: Where the single-frame average is written with
            ``pyt.write_traj`` (overwriting any existing file). The output format is
            whatever pytraj infers from this path, presumably its extension — confirm.
    """
    source = get_pytraj_trajectory(structure_file.path, trajectory_file.path)

    # WARNING: Do not pass the argument 'autoimage=True'
    # WARNING: Autoimage makes some trajectories get displaced the same as in Gromacs
    mean_frame = pyt.mean_structure(source())

    # write_traj needs a trajectory, so wrap the average frame in an empty one sharing the topology
    container = pyt.Trajectory(top=source.top)
    container.append(mean_frame)
    pyt.write_traj(output_filepath, container, overwrite=True)