Coverage for model_workflow/utils/pyt_spells.py: 61%
70 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-23 10:54 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-23 10:54 +0000
1import pytraj as pyt
2import math
3from packaging.version import Version
5from model_workflow.utils.auxiliar import InputError
6from model_workflow.utils.file import File
7from model_workflow.utils.selections import Selection
8from model_workflow.utils.type_hints import *
# File extensions pytraj is able to read, split by role:
# structure/topology files on one side and coordinate trajectories on the other
pytraj_supported_structure_formats = set('prmtop pdb mol2 psf cif sdf'.split())
pytraj_supported_trajectory_formats = set('xtc trr crd mdcrd nc dcd'.split())
# Load the whole trajectory with pytraj, optionally restricted to a set of atoms
def get_pytraj_trajectory (
    input_topology_filename : str,
    input_trajectory_filename : str,
    atom_selection : Optional['Selection'] = None):
    """Open a trajectory with `pyt.iterload`, optionally keeping only some atoms.

    Args:
        input_topology_filename: topology path; it is mandatory for pytraj.
        input_trajectory_filename: trajectory path.
        atom_selection: when given (and truthy), every atom NOT in
            `atom_selection.atom_indices` is stripped from the trajectory.

    Returns:
        The pytraj trajectory object returned by `pyt.iterload`. When a selection
        is given, this is the object returned by its `strip` call.

    Raises:
        SystemExit: if no topology filename is provided.
    """
    # pytraj cannot set up a trajectory without its topology
    if not input_topology_filename:
        raise SystemExit('Missing topology file to setup PyTraj trajectory')

    # NEVER FORGET: pytraj's iterload does not accept a mask, so atoms are stripped afterwards
    trajectory = pyt.iterload(input_trajectory_filename, input_topology_filename)

    # WARNING: This extra call prevents the error "Segment violation (core dumped)" with some pdbs
    # This happens with some random pdbs which pytraj considers to have 0 Mols
    # More info: https://github.com/Amber-MD/cpptraj/pull/820
    # DANI: This is useful in pytraj <= 2.0.5, but it makes the code fail from pytraj 2.0.6 onwards
    is_old_pytraj = Version(pyt.__version__) <= Version('2.0.5')
    if is_old_pytraj:
        trajectory.top.start_new_mol()

    # Without a selection the full trajectory is returned as-is
    if not atom_selection:
        return trajectory

    # Collect the indices of every atom which is NOT part of the selection
    atom_count = trajectory.top.n_atoms
    selected = set(atom_selection.atom_indices)
    atoms_to_strip = {index for index in range(atom_count) if index not in selected}

    # Only strip when there is actually something to remove
    if atoms_to_strip:
        strip_mask = Selection(atoms_to_strip).to_pytraj()
        trajectory = trajectory.strip(strip_mask)
    return trajectory
# Get the reduced trajectory
# WARNING: The final number of frames may be the specified one or less
def get_reduced_pytraj_trajectory (
    input_topology_filename : str,
    input_trajectory_filename : str,
    snapshots : int,
    reduced_trajectory_frames_limit : int):
    """Return a frame-stepped pytraj trajectory meant for heavy analyses.

    Args:
        input_topology_filename: topology path forwarded to `get_pytraj_trajectory`.
        input_trajectory_filename: trajectory path forwarded to `get_pytraj_trajectory`.
        snapshots: number of frames of the full trajectory. The caller supplies it
            so the trajectory does not have to be read just to count its frames.
        reduced_trajectory_frames_limit: maximum number of frames wanted.

    Returns:
        A tuple `(trajectory, frame_step, frame_count)`:
        - `trajectory` is the full trajectory, or a slice of it taken with step `frame_step`;
        - `frame_count` is the number of frames kept.
    """
    full_trajectory = get_pytraj_trajectory(input_topology_filename, input_trajectory_filename)
    # WARNING: Do not read full_trajectory.n_frames to get the number of snapshots or you will read the whole trajectory
    # WARNING: This may take a lot of time for a huge trajectory. Use the snapshots input instead

    # Already at or below the limit: keep every frame and reuse the trajectory as the reduced one
    if snapshots <= reduced_trajectory_frames_limit:
        return full_trajectory, 1, snapshots

    # Otherwise pick frames along the whole trajectory using an integer step
    # WARNING: The theoretical step is rounded UP with math.ceil, so the final
    # number of frames will be the specified one or less
    # CRITICAL WARNING:
    # This formula is exactly the same one the client uses to request stepped frames from the API
    # This means the client and the workflow are coordinated and these formulas must not change
    # If you decide to change this formula (in both workflow and client)...
    # You will have to run all the database analyses with reduced trajectories again
    step = math.ceil(snapshots / reduced_trajectory_frames_limit)
    stepped_trajectory = full_trajectory[0:snapshots:step]
    kept_frames = math.ceil(snapshots / step)
    return stepped_trajectory, step, kept_frames
# LORE: This was tried also with mdtraj's iterload but pytraj was way faster
def get_frames_count (
    structure_file : 'File',
    trajectory_file : 'File') -> int:
    """Get the trajectory frames count.

    The trajectory is opened with `pyt.iterload`, and its `n_frames` attribute is read.

    Args:
        structure_file: structure/topology file used by pytraj to read the trajectory.
        trajectory_file: trajectory file whose frames are counted.

    Returns:
        The number of frames in the trajectory. It is never 0; a 0 count raises instead.

    Raises:
        InputError: if either file is missing, or if pytraj counts 0 frames.
    """
    print('-> Counting number of frames')

    if not trajectory_file.exists:
        raise InputError(f'Missing trajectory file when counting frames: {trajectory_file.path}')

    if not structure_file.exists:
        raise InputError(f'Missing topology file when counting frames: {structure_file.path}')

    # Load the trajectory with pytraj (note: trajectory path first, topology path second)
    pyt_trajectory = pyt.iterload(
        trajectory_file.path,
        structure_file.path)

    # Read the number of frames
    frames = pyt_trajectory.n_frames
    print(f' Frames: {frames}')

    # If 0 frames were counted then there is something wrong with the file
    if frames == 0:
        raise InputError('Something went wrong when reading the trajectory')

    return frames
# Set function supported formats
# NOTE(review): these keys ('input_structure_filename', 'input_trajectory_filename') do not match
# the parameter names (structure_file, trajectory_file) - confirm what the format finder expects
get_frames_count.format_sets = [
    {
        'inputs': {
            'input_structure_filename': pytraj_supported_structure_formats,
            'input_trajectory_filename': pytraj_supported_trajectory_formats
        }
    }
]
# Filter topology atoms
# DANI: Note that a PRMTOP file is not a structure but a topology
# DANI: However it is important that the argument is called 'structure' for the format finder
def filter_topology (
    input_structure_file : 'File',
    output_structure_file : 'File',
    input_selection : 'Selection'
):
    """Write a copy of a topology that keeps only the selected atoms.

    Args:
        input_structure_file: topology/structure file to read. Its `.path` is loaded
            with `pyt.load_topology`.
        output_structure_file: destination file. Its `.path` is written, and its
            `.exists` is checked afterwards.
        input_selection: atoms to keep. It is converted to a pytraj mask with `to_pytraj()`.

    Raises:
        SystemExit: if the output file does not exist after pytraj saved it.
    """
    # Generate a pytraj mask with the desired selection
    mask = input_selection.to_pytraj()

    # Load the topology
    topology = pyt.load_topology(input_structure_file.path)

    # Apply the filter mask: indexing a pytraj topology with a mask keeps only the matching atoms
    filtered_topology = topology[mask]

    # Write the filtered topology to disk
    filtered_topology.save(output_structure_file.path)

    # Check the output file exists at this point
    # If not then it means something went wrong with pytraj while saving
    if not output_structure_file.exists:
        raise SystemExit('Something went wrong with PyTraj')


filter_topology.format_sets = [
    {
        'inputs': {
            'input_structure_file': pytraj_supported_structure_formats,
        },
        'outputs': {
            'output_structure_file': pytraj_supported_structure_formats
        }
    }
]
# Given a corrupted NetCDF file, whose first frames may be read by pytraj, find the first corrupted frame number
def find_first_corrupted_frame (input_topology_filepath, input_trajectory_filepath) -> Optional[int]:
    """Return the 1-based number of the first frame whose last atom has all-zero coordinates.

    A frame counts as corrupted when every coordinate of its last atom is zero.
    While scanning, a progress line is printed and rewritten in place with '\\r'.
    The line is terminated before the function returns.

    Args:
        input_topology_filepath: topology path forwarded to `get_pytraj_trajectory`.
        input_trajectory_filepath: trajectory path forwarded to `get_pytraj_trajectory`.

    Returns:
        The 1-based number of the first corrupted frame, or None if all frames look valid.
    """
    # Iterload the trajectory to pytraj
    trajectory = get_pytraj_trajectory(input_topology_filepath, input_trajectory_filepath)
    frames = trajectory.iterframe()
    expected_frames = trajectory.n_frames
    # Frames are numbered from 1 in both the progress output and the return value
    for f, frame in enumerate(frames, 1):
        print(f'Reading frame {f}/{expected_frames}', end='\r')
        # Make sure there are actual coordinates here
        # If there is any problem we may have frames with coordinates full of zeros
        last_atom_coordinates = frame.xyz[-1]
        if not last_atom_coordinates.any():
            # Terminate the '\r' progress line so later output does not overwrite it
            print()
            return f
    print()
    return None
# The average is computed with pytraj because the Gromacs average may come out displaced
def get_average_structure (structure_file : 'File', trajectory_file : 'File', output_filepath : str):
    """Get an average structure from a trajectory.

    The mean coordinates over all frames are wrapped in a single-frame trajectory
    that shares the input topology. That trajectory is written to `output_filepath`,
    overwriting any existing file. The output format is not forced here; presumably
    pytraj infers it from the file extension (confirm).
    """
    source_trajectory = get_pytraj_trajectory(structure_file.path, trajectory_file.path)

    # WARNING: Do NOT pass 'autoimage=True' to mean_structure:
    # autoimage makes some trajectories get displaced, just like in Gromacs
    mean_frame = pyt.mean_structure(source_trajectory())

    # A frame cannot be written on its own, so build an in-memory trajectory
    # holding only the topology and append the averaged frame to it
    single_frame = pyt.Trajectory(top=source_trajectory.top)
    single_frame.append(mean_frame)
    pyt.write_traj(output_filepath, single_frame, overwrite=True)