Coverage for model_workflow/utils/pyt_spells.py: 61%

70 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-23 10:54 +0000

1import pytraj as pyt 

2import math 

3from packaging.version import Version 

4 

5from model_workflow.utils.auxiliar import InputError 

6from model_workflow.utils.file import File 

7from model_workflow.utils.selections import Selection 

8from model_workflow.utils.type_hints import * 

9 

# File extensions accepted by the pytraj-based helpers in this module
# Structure/topology formats and trajectory formats are kept in separate sets
pytraj_supported_structure_formats = set('prmtop pdb mol2 psf cif sdf'.split())
pytraj_supported_trajectory_formats = set('xtc trr crd mdcrd nc dcd'.split())

13 

# Get the whole trajectory as a lazily loaded pytraj trajectory (pyt.iterload)
def get_pytraj_trajectory (
    input_topology_filename : str,
    input_trajectory_filename : str,
    atom_selection : Optional['Selection'] = None):
    """Set up a pytraj iterload trajectory, optionally restricted to an atom selection.

    Args:
        input_topology_filename: Path to the topology/structure file. Mandatory.
        input_trajectory_filename: Path to the trajectory file.
        atom_selection: Optional selection. When given (and truthy), every atom
            whose index is NOT in `atom_selection.atom_indices` is stripped.

    Returns:
        The object returned by `pyt.iterload`, or the result of calling its
        `strip` method when some atoms had to be filtered away.

    Raises:
        SystemExit: If no topology filename is provided.
    """
    # Topology is mandatory to set up the pytraj trajectory
    if not input_topology_filename:
        raise SystemExit('Missing topology file to setup PyTraj trajectory')

    # Set up the pytraj trajectory
    # NEVER FORGET: pytraj iterload does not accept a mask, but atoms can be stripped afterwards
    pyt_trajectory = pyt.iterload(input_trajectory_filename, input_topology_filename)

    # WARNING: This extra line prevents the error "Segment violation (core dumped)" in some pdbs
    # This happens with some random pdbs which pytraj considers to have 0 Mols
    # More info: https://github.com/Amber-MD/cpptraj/pull/820
    # DANI: This is useful in pytraj <= 2.0.5 but it makes the code fail from pytraj 2.0.6 onwards
    if Version(pyt.__version__) <= Version('2.0.5'):
        pyt_trajectory.top.start_new_mol()

    # Filter away atoms which are NOT in the atom selection
    if atom_selection:
        n_atoms = pyt_trajectory.top.n_atoms
        keep_atoms = set(atom_selection.atom_indices)
        # Build the strip indices as an ascending list (instead of an unordered set)
        # so the resulting pytraj mask does not depend on set iteration order
        strip_atoms = [index for index in range(n_atoms) if index not in keep_atoms]
        if strip_atoms:
            # Convert the strip atom indices to a pytraj mask string and strip the iterator
            mask = Selection(strip_atoms).to_pytraj()
            pyt_trajectory = pyt_trajectory.strip(mask)

    return pyt_trajectory

48 

# Get a reduced (evenly stepped) version of the trajectory
# WARNING: The final number of frames may be the specified limit or fewer
def get_reduced_pytraj_trajectory (
    input_topology_filename : str,
    input_trajectory_filename : str,
    snapshots : int,
    reduced_trajectory_frames_limit : int):
    """Return a trajectory with at most `reduced_trajectory_frames_limit` frames.

    Args:
        input_topology_filename: Path to the topology file.
        input_trajectory_filename: Path to the trajectory file.
        snapshots: Number of frames in the full trajectory, supplied by the caller.
        reduced_trajectory_frames_limit: Maximum number of frames wanted.

    Returns:
        A tuple `(trajectory, frame_step, frame_count)`. When `snapshots` already
        fits within the limit, the full trajectory is returned with step 1 and
        `snapshots` frames. Otherwise a sliced trajectory is returned with the
        computed step and the resulting frame count.
    """
    full_trajectory = get_pytraj_trajectory(input_topology_filename, input_trajectory_filename)
    # WARNING: Do not read full_trajectory.n_frames to get the number of snapshots or you will read the whole trajectory
    # WARNING: This may take a lot of time for a huge trajectory. Use the snapshots input instead

    # The trajectory already has as many frames as the limit or fewer
    # Reuse it as is and treat it as the reduced trajectory
    if snapshots <= reduced_trajectory_frames_limit:
        return full_trajectory, 1, snapshots

    # Otherwise, pick frames evenly spaced along the trajectory
    # The step must be an integer, so the theoretical step is rounded up with math.ceil
    # Rounding up means the final number of frames is the specified limit or fewer
    # CRITICAL WARNING:
    # The client uses exactly this same formula to request stepped frames from the API
    # The client and the workflow are coordinated, so these formulas must not change
    # If you change this formula (in both the workflow and the client)...
    # you will have to run all the database analyses with reduced trajectories again
    step = math.ceil(snapshots / reduced_trajectory_frames_limit)
    stepped_trajectory = full_trajectory[0:snapshots:step]
    # A slice [0:snapshots:step] contains ceil(snapshots / step) elements
    kept_frames = math.ceil(snapshots / step)
    return stepped_trajectory, step, kept_frames

86 

# LORE: This was also tried with mdtraj's iterload but pytraj was way faster
def get_frames_count (
    structure_file : 'File',
    trajectory_file : 'File') -> int:
    """Get the trajectory frames count.

    Args:
        structure_file: Structure/topology file; its `.exists` and `.path` are used.
        trajectory_file: Trajectory file; its `.exists` and `.path` are used.

    Returns:
        The number of frames reported by pytraj (always greater than 0).

    Raises:
        InputError: If either file is missing, or if pytraj reports 0 frames.
    """
    print('-> Counting number of frames')

    if not trajectory_file.exists:
        raise InputError(f'Missing trajectory file when counting frames: {trajectory_file.path}')

    if not structure_file.exists:
        raise InputError(f'Missing topology file when counting frames: {structure_file.path}')

    # Load the trajectory with pytraj
    pyt_trajectory = pyt.iterload(
        trajectory_file.path,
        structure_file.path)

    # Return the number of frames
    frames = pyt_trajectory.n_frames
    print(f' Frames: {frames}')

    # If 0 frames were counted then there is something wrong with the file
    if frames == 0:
        raise InputError('Something went wrong when reading the trajectory')

    return frames
# Set function supported formats
# NOTE(review): these input keys ('input_structure_filename', 'input_trajectory_filename')
# do not match this function's parameter names ('structure_file', 'trajectory_file'),
# unlike filter_topology, whose key matches its parameter. Confirm how the format finder matches them.
get_frames_count.format_sets = [
    {
        'inputs': {
            'input_structure_filename': pytraj_supported_structure_formats,
            'input_trajectory_filename': pytraj_supported_trajectory_formats
        }
    }
]

124 

# Filter topology atoms
# DANI: Note that a PRMTOP file is not a structure but a topology
# DANI: However it is important that the argument is called 'structure' for the format finder
def filter_topology (
    input_structure_file : 'File',
    output_structure_file : 'File',
    input_selection : 'Selection'
):
    """Write a copy of the input topology that keeps only the selected atoms.

    Args:
        input_structure_file: Input topology/structure file; its `.path` is read.
        output_structure_file: Output file; written at its `.path`, then its
            `.exists` is checked.
        input_selection: Atoms to keep; converted to a pytraj mask with `to_pytraj()`.

    Raises:
        SystemExit: If the output file does not exist after saving.
    """
    # Generate a pytraj mask with the desired selection
    mask = input_selection.to_pytraj()

    # Load the topology
    topology = pyt.load_topology(input_structure_file.path)

    # Apply the filter mask
    filtered_topology = topology[mask]

    # Write the filtered topology to disk
    filtered_topology.save(output_structure_file.path)

    # Check the output file exists at this point
    # If not, then something went wrong with the PyTraj save step
    if not output_structure_file.exists:
        raise SystemExit('Something went wrong with PyTraj')


filter_topology.format_sets = [
    {
        'inputs': {
            'input_structure_file': pytraj_supported_structure_formats,
        },
        'outputs': {
            'output_structure_file': pytraj_supported_structure_formats
        }
    }
]

161 

# Given a corrupted NetCDF file, whose first frames may be read by pytraj, find the first corrupted frame number
def find_first_corrupted_frame (input_topology_filepath, input_trajectory_filepath) -> 'Optional[int]':
    """Find the first frame whose last atom has all-zero coordinates.

    Args:
        input_topology_filepath: Path to the topology file.
        input_trajectory_filepath: Path to the trajectory file.

    Returns:
        The 1-based number of the first frame whose last atom coordinates are
        all zero, or None if every frame read has non-zero coordinates there.
    """
    # Iterload the trajectory with pytraj
    trajectory = get_pytraj_trajectory(input_topology_filepath, input_trajectory_filepath)
    # Iterate frames until we find one frame whose last atom coordinates are all zeros
    # enumerate() already calls iter() on the frame iterator, so no explicit iter() is needed
    frame_iterator = trajectory.iterframe()
    expected_frames = trajectory.n_frames
    for frame_number, frame in enumerate(frame_iterator, 1):
        # '\r' keeps the progress report on a single console line
        print(f'Reading frame {frame_number}/{expected_frames}', end='\r')
        # Make sure there are actual coordinates here
        # If there is any problem we may have frames with coordinates full of zeros
        last_atom_coordinates = frame.xyz[-1]
        if not last_atom_coordinates.any():
            return frame_number
    return None

177 

# The average is computed with pytraj because the Gromacs average may be displaced
def get_average_structure (structure_file : 'File', trajectory_file : 'File', output_filepath : str):
    """Compute the average structure of a trajectory and write it to disk.

    Args:
        structure_file: Topology/structure file; its `.path` is used.
        trajectory_file: Trajectory file; its `.path` is used.
        output_filepath: Destination path passed to `pyt.write_traj` (existing
            files are overwritten). No explicit format is passed, so presumably
            pytraj chooses it from this path — confirm for unusual extensions.
    """
    trajectory = get_pytraj_trajectory(structure_file.path, trajectory_file.path)

    # Average the atom positions over the frames yielded by trajectory()
    # WARNING: Do not pass the argument 'autoimage=True'
    # WARNING: Autoimage makes some trajectories get displaced the same as in Gromacs
    mean_frame = pyt.mean_structure(trajectory())

    # Wrap the single averaged frame in a new in-memory trajectory that shares the
    # input topology, so it can be written with pyt.write_traj
    single_frame_trajectory = pyt.Trajectory(top=trajectory.top)
    single_frame_trajectory.append(mean_frame)
    pyt.write_traj(output_filepath, single_frame_trajectory, overwrite=True)