Coverage for mddb_workflow/utils/pyt_spells.py: 57%
70 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-29 15:48 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-29 15:48 +0000
1import pytraj as pyt
2import math
3from packaging.version import Version
5from mddb_workflow.utils.auxiliar import InputError
6from mddb_workflow.utils.file import File
7from mddb_workflow.utils.selections import Selection
8from mddb_workflow.utils.type_hints import *
9from mddb_workflow.tools.get_reduced_trajectory import calculate_frame_step
# File formats supported by pytraj, split by the role of the file
# These sets are used to declare the formats supported by the functions in this module
pytraj_supported_structure_formats = {
    'prmtop',
    'pdb',
    'mol2',
    'psf',
    'cif',
    'sdf',
}
pytraj_supported_trajectory_formats = {
    'xtc',
    'trr',
    'crd',
    'mdcrd',
    'nc',
    'dcd',
}
# Get the whole trajectory as a generator
def get_pytraj_trajectory (
    input_topology_filename : str,
    input_trajectory_filename : str,
    atom_selection : Optional['Selection'] = None):
    """Set up a pytraj trajectory with 'pyt.iterload', optionally filtering atoms.

    Args:
        input_topology_filename: Path to the topology file. It is mandatory.
        input_trajectory_filename: Path to the trajectory file.
        atom_selection: When passed and truthy, atoms whose indices are not in
            its 'atom_indices' are stripped from the returned trajectory.

    Returns:
        The object returned by 'pyt.iterload', or the result of calling its
        'strip' method when there is at least one atom to be removed.

    Raises:
        SystemExit: If the topology filename is missing.
    """
    # There is no way to setup the pytraj trajectory without a topology
    if not input_topology_filename:
        raise SystemExit('Missing topology file to setup PyTraj trajectory')

    # NEVER FORGET: The pytraj iterload does not accept a mask, so atoms are stripped afterwards
    loaded_trajectory = pyt.iterload(input_trajectory_filename, input_topology_filename)

    # WARNING: This extra step prevents the error "Segment violation (core dumped)" in some pdbs
    # This happens with some random pdbs which pytraj considers to have 0 Mols
    # More info: https://github.com/Amber-MD/cpptraj/pull/820
    # DANI: This is useful in pytraj <= 2.0.5 but it makes the code fail from pytraj 2.0.6 on
    is_old_pytraj = Version(pyt.__version__) <= Version('2.0.5')
    if is_old_pytraj:
        loaded_trajectory.top.start_new_mol()

    # With no atom selection there is nothing to filter away
    if not atom_selection:
        return loaded_trajectory

    # Find which atoms ARE NOT in the atom selection, since those are the ones to strip
    every_atom = set(range(loaded_trajectory.top.n_atoms))
    discarded_atoms = every_atom.difference(atom_selection.atom_indices)
    if not discarded_atoms:
        return loaded_trajectory

    # Convert the discarded atom indices to a pytraj mask string and strip the iterator
    strip_mask = Selection(discarded_atoms).to_pytraj()
    return loaded_trajectory.strip(strip_mask)
# Get the reduced trajectory
# WARNING: The final number of frames may be the specified or less
def get_reduced_pytraj_trajectory (
    input_topology_filename : str,
    input_trajectory_filename : str,
    snapshots : int,
    reduced_trajectory_frames_limit : int):
    """Get a pytraj trajectory reduced to a limited number of frames.

    Args:
        input_topology_filename: Path to the topology file.
        input_trajectory_filename: Path to the trajectory file.
        snapshots: Number of frames in the trajectory, provided by the caller.
        reduced_trajectory_frames_limit: Maximum number of frames to keep.

    Returns:
        A tuple with the (possibly reduced) trajectory, the frame step and
        the number of frames. When the trajectory is already within the limit
        it is returned as is, with step 1 and the 'snapshots' value.
    """
    full_trajectory = get_pytraj_trajectory(input_topology_filename, input_trajectory_filename)
    # WARNING: Do not read the trajectory 'n_frames' to get the number of snapshots or you will read the whole trajectory
    # WARNING: This may be a lot of time for a huge trajectory. Use the snapshots input instead

    # If the trajectory has already the same or less frames than the limit
    # then there is nothing to reduce, so it is also used as the reduced trajectory
    if snapshots <= reduced_trajectory_frames_limit:
        return full_trajectory, 1, snapshots

    # Otherwise take one frame every 'frame_step' frames, as used for heavy analyses
    frame_step, reduced_frame_count = calculate_frame_step(snapshots, reduced_trajectory_frames_limit)
    reduced_frames = slice(0, snapshots, frame_step)
    return full_trajectory[reduced_frames], frame_step, reduced_frame_count
# LORE: This was tried also with mdtraj's iterload but pytraj was way faster
def get_frames_count (
    structure_file : 'File',
    trajectory_file : 'File') -> int:
    """Get the trajectory frames count.

    Args:
        structure_file: Structure/topology file. Its 'exists' and 'path'
            attributes are read.
        trajectory_file: Trajectory file. Its 'exists' and 'path' attributes
            are read.

    Returns:
        The number of frames reported by pytraj for the trajectory.

    Raises:
        InputError: If any of the files does not exist or if pytraj
            reports 0 frames.
    """
    print('-> Counting number of frames')

    # Both files must exist before trying to load anything
    if not trajectory_file.exists:
        raise InputError('Missing trajectory file when counting frames: ' + trajectory_file.path)

    if not structure_file.exists:
        raise InputError('Missing topology file when counting frames: ' + structure_file.path)

    # Load the trajectory from pytraj
    pyt_trajectory = pyt.iterload(
        trajectory_file.path,
        structure_file.path)

    # Get the frames number and report it
    frames = pyt_trajectory.n_frames
    print(f' Frames: {frames}')

    # If 0 frames were counted then there is something wrong with the file
    if frames == 0:
        raise InputError('Something went wrong when reading the trajectory')

    return frames
# Set function supported formats
# NOTE(review): These keys do not match the function parameter names
# NOTE(review): ('structure_file' and 'trajectory_file') - confirm against the format finder
get_frames_count.format_sets = [
    {
        'inputs': {
            'input_structure_filename': pytraj_supported_structure_formats,
            'input_trajectory_filename': pytraj_supported_trajectory_formats
        }
    }
]
# Filter topology atoms
# DANI: Note that a PRMTOP file is not a structure but a topology
# DANI: However it is important that the argument is called 'structure' for the format finder
def filter_topology (
    input_structure_file : 'File',
    output_structure_file : 'File',
    input_selection : 'Selection'
):
    """Write a copy of a topology which includes only the selected atoms.

    Args:
        input_structure_file: File to load the topology from. Its 'path'
            attribute is read.
        output_structure_file: File to write the filtered topology to. Its
            'path' and 'exists' attributes are read.
        input_selection: Selection of atoms to keep. It is converted to a
            pytraj mask through its 'to_pytraj' method.

    Raises:
        SystemExit: If the output file does not exist after saving.
    """
    # Generate a pytraj mask with the desired selection
    mask = input_selection.to_pytraj()

    # Load the topology
    topology = pyt.load_topology(input_structure_file.path)

    # Apply the filter mask
    filtered_topology = topology[mask]

    # Write the filtered topology to disk
    filtered_topology.save(output_structure_file.path)

    # Check the output file exists at this point
    # If not then it means something went wrong with pytraj
    if not output_structure_file.exists:
        raise SystemExit('Something went wrong with PyTraj')

# Set function supported formats
filter_topology.format_sets = [
    {
        'inputs': {
            'input_structure_file': pytraj_supported_structure_formats,
        },
        'outputs': {
            'output_structure_file': pytraj_supported_structure_formats
        }
    }
]
# Given a corrupted NetCDF file, whose first frames may be read by pytraj, find the first corrupted frame number
def find_first_corrupted_frame (
    input_topology_filepath : str,
    input_trajectory_filepath : str) -> Optional[int]:
    """Find the first frame whose last atom coordinates are all zeros.

    Args:
        input_topology_filepath: Path to the topology file.
        input_trajectory_filepath: Path to the trajectory file.

    Returns:
        The number of the first corrupted frame, counting from 1, or None if
        every frame has non-zero coordinates in its last atom.
    """
    # Iterload the trajectory to pytraj
    trajectory = get_pytraj_trajectory(input_topology_filepath, input_trajectory_filepath)
    # The expected number of frames is used only in the progress message
    expected_frames = trajectory.n_frames
    # Iterate frames until we find one frame whose last atom coordinates are all zeros
    # Note that frames are numbered starting at 1
    for f, frame in enumerate(trajectory.iterframe(), 1):
        print(f'Reading frame {f}/{expected_frames}', end='\r')
        # Make sure there are actual coordinates here
        # If there is any problem we may have frames with coordinates full of zeros
        last_atom_coordinates = frame.xyz[-1]
        if not last_atom_coordinates.any():
            return f
    # No corrupted frame was found
    return None
# This process is carried by pytraj, since the Gromacs average may be displaced
def get_average_structure (structure_file : 'File', trajectory_file : 'File', output_filepath : str):
    """Write the average structure of a trajectory to a file.

    Args:
        structure_file: Structure/topology file. Its 'path' attribute is read.
        trajectory_file: Trajectory file. Its 'path' attribute is read.
        output_filepath: Path of the file where the average is written.
            Any previous file in this path is overwritten.
    """
    # Iterload the trajectory to pytraj
    source_trajectory = get_pytraj_trajectory(structure_file.path, trajectory_file.path)

    # Calculate a single frame with the average positions along the trajectory
    # WARNING: Do not pass the argument 'autoimage=True'
    # WARNING: Autoimage makes some trajectories get displaced the same as in Gromacs
    trajectory_frames = source_trajectory()
    mean_frame = pyt.mean_structure(trajectory_frames)

    # To export the average frame we need a trajectory, so create an empty one with the same topology
    # Then add the average frame to it and write it to the output file path
    single_frame_trajectory = pyt.Trajectory(top=source_trajectory.top)
    single_frame_trajectory.append(mean_frame)
    pyt.write_traj(output_filepath, single_frame_trajectory, overwrite=True)