Coverage for model_workflow/tools/check_inputs.py: 57%
143 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-23 10:54 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-23 10:54 +0000
1from model_workflow.utils.auxiliar import InputError, warn, CaptureOutput, load_json, MISSING_TOPOLOGY
2from model_workflow.utils.constants import STANDARD_TOPOLOGY_FILENAME, GROMACS_EXECUTABLE
3from model_workflow.utils.pyt_spells import find_first_corrupted_frame
4from model_workflow.utils.gmx_spells import mine_system_atoms_count
5from model_workflow.utils.vmd_spells import vmd_to_pdb
6from model_workflow.utils.structures import Structure
7from model_workflow.utils.file import File
9from model_workflow.tools.guess_and_filter import guess_and_filter_topology
11from re import match, search
12from typing import *
13from subprocess import run, PIPE, Popen
14from scipy.io import netcdf_file
15import mdtraj as mdt
16import pytraj as pyt
# Known error messages: plain texts or regular expressions used to recognize specific failures
NETCDF_DTYPE_ERROR = (
    'When changing to a larger dtype, its size must be a divisor of '
    'the total size in bytes of the last axis of the array.'
)
# Capture groups: (1) and (2) are the two atom counts reported in the message
MDTRAJ_ATOM_MISMATCH_ERROR = (
    r'xyz must be shape \(Any, ([0-9]*), 3\). '
    r'You supplied \(1, ([0-9]*), 3\)'
)
# Capture groups: (1) atoms in the XTC file, (2) the word 'topology' or 'parm', (3) atoms in the topology
PYTRAJ_XTC_ATOM_MISMATCH_ERROR = (
    r'Error: # atoms in XTC file \(([0-9]*)\) does not match '
    r'# atoms in (topology|parm) [\w.-]* \(([0-9]*)\)'
)
# Capture groups: (1) atoms in the trajectory file
GROMACS_ATOM_MISMATCH_ERROR = (
    r'is larger than the number of atoms in the\ntrajectory file \(([0-9]*)\). '
    r'There is a mismatch in the contents'
)

# Supported formats for every kind of input file
TOPOLOGY_SUPPORTED_FORMATS = { 'psf', 'prmtop', 'top', 'tpr' }
TRAJECTORY_SUPPORTED_FORMATS = { 'rst7', 'pdb', 'crd', 'dcd', 'nc', 'trr', 'xtc' }
# Any supported topology format is also accepted as structure, plus the pure structure formats
STRUCTURE_SUPPORTED_FORMATS = TOPOLOGY_SUPPORTED_FORMATS | { 'pdb', 'gro' }
GROMACS_TRAJECTORY_SUPPORTED_FORMATS = { 'trr', 'xtc' }

# Auxiliary PDB file which may be generated to load non supported restart files
AUXILIAR_PDB_BILE = '.auxiliar.pdb'

# Exceptions used as keys to flag the fixes applied from this module
PREFILTERED_TOPOLOGY_EXCEPTION = Exception('Prefiltered topology')
# Check input files coherence and integrity
# If there is any problem then raise an input error
# Some exceptional problems may be fixed from here
# In these cases both the exception and the modified file are returned in a final dict
def check_inputs (
    input_structure_file : 'File',
    input_trajectory_files : List['File'],
    input_topology_file : Union['File', Exception]) -> dict:
    """Check the input structure, trajectory and topology files are supported and coherent.

    The checks are: every file has a supported format, NetCDF trajectories can be opened,
    and all files agree in the number of atoms.
    Only the first trajectory file is used as a sample for the format and atom count checks.

    Args:
        input_structure_file: The input structure file.
        input_trajectory_files: The input trajectory files. They must all have the same format.
        input_topology_file: The input topology file or the MISSING_TOPOLOGY flag.

    Returns:
        A dict with the fixes applied, where keys are exception flags and values are the new files.
        Right now the only possible key is PREFILTERED_TOPOLOGY_EXCEPTION.

    Raises:
        InputError: If there are no trajectory files, a format is not supported,
            a trajectory is corrupted or atom counts do not match and this can not be fixed.
        SystemExit: If GROMACS fails for an unknown reason while checking a TPR topology.
    """
    # Set the exceptions dict to be returned at the end
    exceptions = {}

    # Get a sample trajectory file and then check its format
    # All input trajectory files must have the same format
    # Fail with a clear message instead of an IndexError when there are no trajectory files
    if not input_trajectory_files:
        raise InputError('There are no input trajectory files to check')
    trajectory_sample = input_trajectory_files[0]

    # Check input files are supported by the workflow
    if input_topology_file != MISSING_TOPOLOGY and input_topology_file.filename != STANDARD_TOPOLOGY_FILENAME and input_topology_file.format not in TOPOLOGY_SUPPORTED_FORMATS:
        if input_topology_file.format in { 'pdb', 'gro' }:
            raise InputError('A structure file is not supported as topology anymore. If there is no topology then use the argument "-top no"')
        raise InputError(f'Topology {input_topology_file.path} has a not supported format. Try one of these: {", ".join(TOPOLOGY_SUPPORTED_FORMATS)}')
    if trajectory_sample.format not in TRAJECTORY_SUPPORTED_FORMATS:
        raise InputError(f'Trajectory {trajectory_sample.path} has a not supported format. Try one of these: {", ".join(TRAJECTORY_SUPPORTED_FORMATS)}')
    if input_structure_file.format not in STRUCTURE_SUPPORTED_FORMATS:
        raise InputError(f'Structure {input_structure_file.path} has a not supported format. Try one of these: {", ".join(STRUCTURE_SUPPORTED_FORMATS)}')

    # Make sure the trajectory file is not corrupted

    # Check if reading the trajectory raises the following error
    # ValueError: When changing to a larger dtype, its size must be a divisor of the total size in bytes of the last axis of the array.
    # This error may happen with NetCDF files and it is a bit shady
    # Some tools may be able to read the first frames of the corrupted file: VMD and pytraj
    # Some other tools will instantly fail to read it: MDtraj and MDAnalysis
    if trajectory_sample.format == 'nc':
        try:
            # Iterate trajectory files
            for trajectory_file in input_trajectory_files:
                # This does not read the whole trajectory
                # Open it as a context manager so the file handle is closed right away
                with netcdf_file(trajectory_file.path, 'r'):
                    pass
        except Exception as error:
            # If the error message matches with a known error then report the problem
            error_message = str(error)
            if error_message == NETCDF_DTYPE_ERROR:
                warn(f'Corrupted trajectory file {trajectory_file.path}')
                pytraj_input_topology = input_topology_file if input_topology_file != MISSING_TOPOLOGY else input_structure_file
                first_corrupted_frame = find_first_corrupted_frame(pytraj_input_topology.path, trajectory_file.path)
                print(f' However some tools may be able to read the first {first_corrupted_frame} frames: VMD and PyTraj')
                raise InputError('Corrupted input trajectory file') from error
            # If we do not know the error then raise it as is, keeping its traceback
            raise

    # Set a function to get atoms from a structure alone
    def get_structure_atoms (structure_file : 'File') -> int:
        # Get the number of atoms in the input structure
        structure = Structure.from_file(structure_file.path)
        return structure.atom_count

    # Set a function to get atoms from structure and trajectory together
    def get_structure_and_trajectory_atoms (structure_file : 'File', trajectory_file : 'File') -> Tuple[int, int]:
        # Note that declaring the iterator will not fail even when there is a mismatch
        trajectory = mdt.iterload(trajectory_file.path, top=structure_file.path, chunk=1)
        # We must consume the generator first value to make the error raise
        frame = next(trajectory)
        # Now obtain the number of atoms from the frame we just read
        trajectory_atom_count = frame.n_atoms
        # And still, it may happen that the topology has more atoms than the trajectory but it loads
        # MDtraj may silently load as many coordinates as possible and discard the rest of atoms in topology
        # This behaviour has been observed with a gromacs .top topology and a PDB used as trajectory
        # To double check the match, load the topology alone with PyTraj
        topology = pyt.load_topology(structure_file.path)
        structure_atom_count = topology.n_atoms
        return structure_atom_count, trajectory_atom_count

    # Set a function to get atoms from topology and trajectory together
    # Any of the returned counts may be None when it can not be found using the topology
    def get_topology_and_trajectory_atoms (
        topology_file : Union['File', Exception],
        trajectory_file : 'File') -> Tuple[Optional[int], Optional[int]]:
        # To do so rely on different tools depending on the topology format
        # If there is no topology file then there is nothing to count here
        if topology_file == MISSING_TOPOLOGY:
            # We do not have a topology atom count to return
            # Without a valid topology we can not count trajectory atoms either
            return None, None
        # If it is our standard topology then simply count the atom names
        # Trajectory atoms are to be found using the structure instead
        if topology_file.filename == STANDARD_TOPOLOGY_FILENAME:
            # Parse the json and count atoms
            parsed_topology = load_json(topology_file.path)
            topology_atom_count = len(parsed_topology['atom_names'])
            # Without a valid topology we can not count trajectory atoms
            return topology_atom_count, None
        # For a TPR use Gromacs, which is its native tool
        if topology_file.format == 'tpr':
            # Make sure the trajectory is compatible with gromacs
            if trajectory_file.format not in GROMACS_TRAJECTORY_SUPPORTED_FORMATS:
                raise InputError('Why loading a TPR topology with a non-gromacs trajectory?')
            # Run Gromacs just to generate a structure using all atoms in the topology and coordinates in the first frame
            # If atoms do not match then we will see a specific error
            output_sample_file = File('.sample.gro')
            # Answer the Gromacs selection request through stdin
            # These are the same bytes 'echo System' would write, but with no extra process to wait for
            process = run([
                GROMACS_EXECUTABLE,
                "trjconv",
                "-s",
                topology_file.path,
                "-f",
                trajectory_file.path,
                '-o',
                output_sample_file.path,
                "-dump",
                "0",
                '-quiet'
            ], input=b'System\n', stdout=PIPE, stderr=PIPE)
            logs = process.stdout.decode()
            # Always get error logs and mine topology atoms
            # Note that these logs include the output selection request from Gromacs
            # This log should be always there, even if there was a mismatch and then Gromacs failed
            error_logs = process.stderr.decode()
            topology_atom_count = mine_system_atoms_count(error_logs)
            # If the output does not exist at this point it means something went wrong with gromacs
            if not output_sample_file.exists:
                # Check if we know the error
                error_match = search(GROMACS_ATOM_MISMATCH_ERROR, error_logs)
                if error_match:
                    # Get the trajectory atom count
                    trajectory_atom_count = int(error_match[1])
                    return topology_atom_count, trajectory_atom_count
                # Otherwise just print the whole error logs and stop here anyway
                print(logs)
                print(error_logs)
                raise SystemExit('Something went wrong with GROMACS during the checking')
            # If we had an output then it means both topology and trajectory match in the number of atoms
            # Cleanup the file we just created and proceed
            output_sample_file.remove()
            # If there was no problem then it means the trajectory atom count matches the topology atom count
            trajectory_atom_count = topology_atom_count
            return topology_atom_count, trajectory_atom_count
        # For .top files we use PyTraj since MDtraj can not handle it
        if topology_file.format == 'top':
            # Note that calling iterload will print an error log when atoms do not match but will not raise a proper error
            # To capture the error log we must run this command wrapped in a stderr redirect
            trajectory = None
            with CaptureOutput('stderr') as output:
                trajectory = pyt.iterload(trajectory_file.path, top=topology_file.path)
            logs = output.captured_text
            # Search in the whole log, since the error may not be at the very beginning of the captured text
            error_match = search(PYTRAJ_XTC_ATOM_MISMATCH_ERROR, logs)
            if error_match:
                topology_atom_count = int(error_match[3])
                trajectory_atom_count = int(error_match[1])
            # If there was no error then both counts are the number of atoms in the loaded trajectory
            else:
                topology_atom_count = trajectory_atom_count = trajectory.n_atoms
            return topology_atom_count, trajectory_atom_count
        # At this point the topology should be supported by MDtraj
        # However, if the trajectory is a restart file MDtraj will not be able to read it
        # Make the conversion here, since restart files are single-frame trajectories this should be fast
        use_auxiliar_pdb = False
        if trajectory_file.format == 'rst7':
            # Generate the auxiliar PDB file
            # NOTE(review): this auxiliar file is never removed from here, confirm if this is intended
            vmd_to_pdb(topology_file.path, trajectory_file.path, AUXILIAR_PDB_BILE)
            use_auxiliar_pdb = True
        # For any other format use MDtraj
        try:
            # Note that declaring the iterator will not fail even when there is a mismatch
            trajectory_path = AUXILIAR_PDB_BILE if use_auxiliar_pdb else trajectory_file.path
            trajectory = mdt.iterload(trajectory_path, top=topology_file.path, chunk=1)
            # We must consume the generator first value to make the error raise
            frame = next(trajectory)
            # Now obtain the number of atoms from the frame we just read
            trajectory_atom_count = frame.n_atoms
            # And still, it may happen that the topology has more atoms than the trajectory but it loads
            # MDtraj may silently load as many coordinates as possible and discard the rest of atoms in topology
            # This behaviour has been observed with a gromacs .top topology and a PDB used as trajectory
            # To double check the match, load the topology alone with PyTraj
            topology = pyt.load_topology(topology_file.path)
            topology_atom_count = topology.n_atoms
            return topology_atom_count, trajectory_atom_count
        except Exception as error:
            # If the error message matches with a known error then report the problem
            error_message = str(error)
            error_match = match(MDTRAJ_ATOM_MISMATCH_ERROR, error_message)
            if error_match:
                topology_atom_count = int(error_match[1])
                trajectory_atom_count = int(error_match[2])
                return topology_atom_count, trajectory_atom_count
            # If we do not know the error then raise it as is, keeping its traceback
            raise

    # Get topology and trajectory atom counts
    topology_atom_count, trajectory_atom_count = get_topology_and_trajectory_atoms(input_topology_file, trajectory_sample)

    # If we have the trajectory atom count then it means we had a valid topology
    if trajectory_atom_count is not None:

        # Make sure their atom counts match
        if topology_atom_count != trajectory_atom_count:
            warn('Mismatch in the number of atoms between input files:\n' +
                f' Topology "{input_topology_file.path}" -> {topology_atom_count} atoms\n' +
                f' Trajectory "{trajectory_sample.path}" -> {trajectory_atom_count} atoms')
            if topology_atom_count < trajectory_atom_count:
                raise InputError('Trajectory has more atoms than topology, there is no way to fix this.')
            # If the topology has more atoms than the trajectory however we may attempt to guess
            # If we guess which atoms are the ones in the trajectory then we can filter the topology
            prefiltered_topology_filepath = f'{input_topology_file.basepath}/prefiltered.{input_topology_file.format}'
            prefiltered_topology_file = File(prefiltered_topology_filepath)
            guessed = guess_and_filter_topology(
                input_topology_file,
                prefiltered_topology_file,
                trajectory_atom_count)
            if not guessed:
                raise InputError('Could not guess topology atom selection to match trajectory atoms count')
            exceptions[PREFILTERED_TOPOLOGY_EXCEPTION] = prefiltered_topology_file
            # NOTE(review): topology_atom_count still holds the count of the unfiltered topology
            # Thus the structure is compared against the original topology below, confirm this is intended

        # If the topology file is already the structure file then there is no need to check it
        if input_structure_file == input_topology_file:
            print(f'Topology and trajectory files match in number of atoms: {trajectory_atom_count}')
            return exceptions

        # If the counts match then also get the structure atom count and compare
        structure_atom_count = get_structure_atoms(input_structure_file)

        # Make sure it matches the topology and trajectory atom count
        if topology_atom_count != structure_atom_count:
            raise InputError('Mismatch in the structure input file number of atoms:\n'+
                f' Topology and trajectory -> {topology_atom_count} atoms\n' +
                f' Structure "{input_structure_file.path}" -> {structure_atom_count} atoms')

        # If we reached this point then it means everything is matching
        print(f'All input files match in number of atoms: {trajectory_atom_count}')
        return exceptions

    # Otherwise it means we had not a valid topology file
    # We must use the structure to find trajectory atoms
    structure_atom_count, trajectory_atom_count = get_structure_and_trajectory_atoms(input_structure_file, trajectory_sample)

    # Make sure their atom counts match
    if structure_atom_count != trajectory_atom_count:
        raise InputError('Mismatch in the number of atoms between input files:\n' +
            f' Structure "{input_structure_file.path}" -> {structure_atom_count} atoms\n' +
            f' Trajectory "{trajectory_sample.path}" -> {trajectory_atom_count} atoms')

    # If we have a number of topology atoms then make sure it matches the structure and trajectory atoms
    # This may happen if the topology is our standard topology file instead of a valid topology
    if topology_atom_count is not None and topology_atom_count != trajectory_atom_count:
        raise InputError('Mismatch in the number of atoms between input files:\n' +
            f' Structure and trajectory -> {trajectory_atom_count} atoms\n' +
            f' Topology "{input_topology_file.path}" -> {topology_atom_count} atoms')

    # If we made it this far it means all checks are good
    print(f'Input files match in number of atoms: {trajectory_atom_count}')
    return exceptions