Coverage for mddb_workflow / tools / check_inputs.py: 58%

157 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 18:45 +0000

1from mddb_workflow.utils.auxiliar import InputError, ToolError 

2from mddb_workflow.utils.auxiliar import warn, CaptureOutput, load_json, MISSING_TOPOLOGY 

3from mddb_workflow.utils.auxiliar import is_standard_topology 

4from mddb_workflow.utils.pyt_spells import find_first_corrupted_frame 

5from mddb_workflow.utils.gmx_spells import run_gromacs, mine_system_atoms_count, get_atom_count 

6from mddb_workflow.utils.vmd_spells import vmd_to_pdb 

7from mddb_workflow.utils.structures import Structure 

8from mddb_workflow.utils.file import File 

9 

10from mddb_workflow.tools.guess_and_filter import guess_and_filter_topology 

11 

12import re 

13from typing import * 

14from scipy.io import netcdf_file 

15import mdtraj as mdt 

16import pytraj as pyt 

17 

# Set some known message errors
# These patterns are matched against error messages and logs emitted by third party tools

# Error raised by scipy when reading some corrupted NetCDF trajectories
NETCDF_DTYPE_ERROR = 'When changing to a larger dtype, its size must be a divisor of the total size in bytes of the last axis of the array.'
# MDtraj error when topology and trajectory atom counts do not match
# Group 1 is the topology atom count and group 2 is the trajectory atom count
MDTRAJ_ATOM_MISMATCH_ERROR = r'xyz must be shape \(Any, ([0-9]*), 3\)\. You supplied \(1, ([0-9]*), 3\)'
# MDtraj error when the topology has residue numbers with insertion codes
MDTRAJ_INSERTION_CODES_ERROR = r'^Could not convert residue number \[[0-9]*[a-zA-Z]\]$'
# PyTraj error log when XTC and topology atom counts do not match
# Group 1 is the trajectory atom count and group 3 is the topology atom count
# The topology name may include path separators, so accept any non-blank characters for it
PYTRAJ_XTC_ATOM_MISMATCH_ERROR = r'Error: # atoms in XTC file \(([0-9]*)\) does not match # atoms in (topology|parm) \S* \(([0-9]*)\)'
# Gromacs error log when the topology has more atoms than the trajectory
# Group 1 is the trajectory atom count
GROMACS_ATOM_MISMATCH_ERROR = r'is larger than the number of atoms in the\ntrajectory file \(([0-9]*)\)\. There is a mismatch in the contents'
# Atom count line as printed by 'gmx check'
# Accept any amount of whitespace before the number (Gromacs may pad it)
# Require at least one digit, otherwise int() on the captured group would fail
GROMACS_ATOM_COUNT_CHECK = r'# Atoms\s+([0-9]+)'

# List supported formats
TOPOLOGY_SUPPORTED_FORMATS = {'tpr', 'top', 'prmtop', 'psf'}
TRAJECTORY_SUPPORTED_FORMATS = {'xtc', 'trr', 'nc', 'dcd', 'crd', 'pdb', 'rst7'}
STRUCTURE_SUPPORTED_FORMATS = {*TOPOLOGY_SUPPORTED_FORMATS, 'pdb', 'gro'}
GROMACS_TRAJECTORY_SUPPORTED_FORMATS = {'xtc', 'trr'}

# Auxiliar PDB file which may be generated to load non supported restart files
AUXILIAR_PDB_FILE = '.auxiliar.pdb'

# Set exceptions for fixes applied from here
# This exception instance is used as a key in the dict returned by check_inputs
PREFILTERED_TOPOLOGY_EXCEPTION = Exception('Prefiltered topology')

37 

38 

def check_inputs(
    input_structure_file: 'File',
    input_trajectory_files: list['File'],
    input_topology_file: Union['File', Exception]) -> dict:
    """Check input files coherence and integrity.

    If there is any problem then raises an input error.
    Some exceptional problems may be fixed from here.
    In these cases, both the exception and the modified file are returned in a final dict.

    Args:
        input_structure_file: The input structure file.
        input_trajectory_files: The input trajectory files. All of them must share the same format.
        input_topology_file: The input topology file, or MISSING_TOPOLOGY when there is none.

    Returns:
        A dict mapping each applied fix exception (e.g. PREFILTERED_TOPOLOGY_EXCEPTION)
        to the file generated by that fix. It is empty when no fix was applied.

    Raises:
        InputError: When some format is not supported, a trajectory is corrupted
            or atom counts between input files do not match and can not be fixed.
    """
    # Set the exceptions dict to be returned at the end
    exceptions = {}

    # Get a sample trajectory file and then check its format
    # All input trajectory files must have the same format
    trajectory_sample = input_trajectory_files[0]

    # Check input files are supported by the workflow
    _check_supported_formats(input_structure_file, trajectory_sample, input_topology_file)

    # Make sure the trajectory file is not corrupted
    if trajectory_sample.format == 'nc':
        _check_netcdf_integrity(input_trajectory_files, input_structure_file, input_topology_file)

    # Get topology and trajectory atom counts
    topology_atom_count, trajectory_atom_count = get_topology_and_trajectory_atoms(input_topology_file, trajectory_sample)

    # If we have the trajectory atom count then it means we had a valid topology
    if trajectory_atom_count is not None:

        # Make sure their atom counts match
        if topology_atom_count != trajectory_atom_count:
            warn('Mismatch in the number of atoms between input files:\n' +
                f' Topology "{input_topology_file.path}" -> {topology_atom_count} atoms\n' +
                f' Trajectory "{trajectory_sample.path}" -> {trajectory_atom_count} atoms')
            if topology_atom_count < trajectory_atom_count:
                raise InputError('Trajectory has more atoms than topology, there is no way to fix this.')
            # If the topology has more atoms than the trajectory however we may attempt to guess
            # If we guess which atoms are the ones in the trajectory then we can filter the topology
            prefiltered_topology_filepath = f'{input_topology_file.basepath}/prefiltered.{input_topology_file.format}'
            prefiltered_topology_file = File(prefiltered_topology_filepath)
            guessed = guess_and_filter_topology(
                input_topology_file,
                prefiltered_topology_file,
                trajectory_atom_count)
            if not guessed:
                raise InputError('Could not guess topology atom selection to match trajectory atoms count')
            exceptions[PREFILTERED_TOPOLOGY_EXCEPTION] = prefiltered_topology_file

        # If the topology file is already the structure file then there is no need to check it
        if input_structure_file == input_topology_file:
            print(f'Topology and trajectory files match in number of atoms: {trajectory_atom_count}')
            return exceptions

        # Also get the structure atom count and compare
        structure_atom_count = get_structure_atoms(input_structure_file)

        # Compare against the trajectory atom count, not the original topology atom count
        # When the topology has been prefiltered above, its original atom count is no longer
        # the number of atoms in use, while the trajectory atom count always is
        # When no prefiltering happened both counts are equal at this point
        if structure_atom_count != trajectory_atom_count:
            raise InputError('Mismatch in the structure input file number of atoms:\n'+
                f' Topology and trajectory -> {trajectory_atom_count} atoms\n' +
                f' Structure "{input_structure_file.path}" -> {structure_atom_count} atoms')

        # If we reached this point then it means everything is matching
        print(f'All input files match in number of atoms: {trajectory_atom_count}')
        return exceptions

    # Otherwise it means we had not a valid topology file
    # We must use the structure to find trajectory atoms
    structure_atom_count, trajectory_atom_count = get_structure_and_trajectory_atoms(input_structure_file, trajectory_sample)

    # Make sure their atom counts match
    if structure_atom_count != trajectory_atom_count:
        raise InputError('Mismatch in the number of atoms between input files:\n' +
            f' Structure "{input_structure_file.path}" -> {structure_atom_count} atoms\n' +
            f' Trajectory "{trajectory_sample.path}" -> {trajectory_atom_count} atoms')

    # If we have a number of topology atoms then make sure it matches the structure and trajectory atoms
    # This may happen if the topology is our standard topology file instead of a valid topology
    if topology_atom_count is not None and topology_atom_count != trajectory_atom_count:
        raise InputError('Mismatch in the number of atoms between input files:\n' +
            f' Structure and trajectory -> {trajectory_atom_count} atoms\n' +
            f' Topology "{input_topology_file.path}" -> {topology_atom_count} atoms')

    # If we made it this far it means all checkings are good
    print(f'Input files match in number of atoms: {trajectory_atom_count}')
    return exceptions


def _check_supported_formats(
    input_structure_file: 'File',
    trajectory_sample: 'File',
    input_topology_file: Union['File', Exception]) -> None:
    """Raise an InputError if any input file has a format not supported by the workflow."""
    # Supported format lists are sorted so error messages are deterministic
    if input_topology_file != MISSING_TOPOLOGY and not is_standard_topology(input_topology_file) \
            and input_topology_file.format not in TOPOLOGY_SUPPORTED_FORMATS:
        if input_topology_file.format in {'pdb', 'gro'}:
            raise InputError('A structure file is not supported as topology anymore. If there is no topology then use the argument "-top no"')
        raise InputError(f'Topology {input_topology_file.path} has a not supported format. Try one of these: {", ".join(sorted(TOPOLOGY_SUPPORTED_FORMATS))}')
    if trajectory_sample.format not in TRAJECTORY_SUPPORTED_FORMATS:
        raise InputError(f'Trajectory {trajectory_sample.path} has a not supported format. Try one of these: {", ".join(sorted(TRAJECTORY_SUPPORTED_FORMATS))}')
    if input_structure_file.format not in STRUCTURE_SUPPORTED_FORMATS:
        raise InputError(f'Structure {input_structure_file.path} has a not supported format. Try one of these: {", ".join(sorted(STRUCTURE_SUPPORTED_FORMATS))}')


def _check_netcdf_integrity(
    input_trajectory_files: list['File'],
    input_structure_file: 'File',
    input_topology_file: Union['File', Exception]) -> None:
    """Make sure every NetCDF trajectory can be opened, reporting the known corruption error.

    Check if reading the trajectory raises the following error:
    ValueError: When changing to a larger dtype, its size must be a divisor of the total size in bytes of the last axis of the array.
    This error may happen with NetCDF files and it is a bit shady.
    Some tools may be able to read the first frames of the corrupted file: VMD and pytraj.
    Some other tools will instantly fail to read it: MDtraj and MDAnalysis.

    Raises:
        InputError: When a trajectory shows the known corruption error.
        Any other error raised while opening the file is raised as is.
    """
    for trajectory_file in input_trajectory_files:
        try:
            # This does not read the whole trajectory
            # Use a context manager so the file is closed right after the check
            with netcdf_file(trajectory_file.path, 'r'):
                pass
        except Exception as error:
            # If we do not know the error then raise it as is
            if str(error) != NETCDF_DTYPE_ERROR:
                raise
            # Otherwise the error message matches a known error so report the problem
            warn(f'Corrupted trajectory file {trajectory_file.path}')
            pytraj_input_topology = input_topology_file if input_topology_file != MISSING_TOPOLOGY else input_structure_file
            first_corrupted_frame = find_first_corrupted_frame(pytraj_input_topology.path, trajectory_file.path)
            print(f' However some tools may be able to read the first {first_corrupted_frame} frames: VMD and PyTraj')
            raise InputError('Corrupted input trajectory file') from error

154 

155 

def get_topology_and_trajectory_atoms_pytraj(topology_file: 'File', trajectory_file: 'File') -> tuple[int, int]:
    """Get atoms from topology and trajectory together using pytraj.

    This is an alternative method used when MDtraj can not handle it.

    Args:
        topology_file: The topology file.
        trajectory_file: The trajectory file.

    Returns:
        A tuple with the topology atom count and the trajectory atom count.
    """
    # Note that calling iterload will print an error log when atoms do not match but will not raise a proper error
    # To capture the error log we must run this command wrapped in a stderr redirect
    with CaptureOutput('stderr') as output:
        trajectory = pyt.iterload(trajectory_file.path, top=topology_file.path)
    logs = output.captured_text
    # Search the whole log instead of only matching its beginning
    # Other messages may be logged before the mismatch error
    error_match = re.search(PYTRAJ_XTC_ATOM_MISMATCH_ERROR, logs)
    if error_match:
        topology_atom_count = int(error_match[3])
        trajectory_atom_count = int(error_match[1])
        return topology_atom_count, trajectory_atom_count
    # Otherwise counts match, so obtain the number of atoms from the loaded trajectory
    atom_count = trajectory.n_atoms
    return atom_count, atom_count

174 

175 

def get_topology_and_trajectory_atoms(topology_file: 'File', trajectory_file: 'File') -> tuple[Optional[int], Optional[int]]:
    """Get atoms from topology and trajectory together.

    Different tools are used depending on the topology format:
    GROMACS for TPR, PyTraj for TOP and MDtraj for any other format.

    Args:
        topology_file: The topology file, or MISSING_TOPOLOGY when there is none.
        trajectory_file: The trajectory file.

    Returns:
        A tuple with the topology atom count and the trajectory atom count.
        Both are None when there is no topology.
        The trajectory count is None when the topology is our standard topology.

    Raises:
        InputError: When a TPR topology is paired with a non-gromacs trajectory.
        ToolError: When GROMACS fails for an unknown reason.
    """
    # If there is no topology file then just compare structure and trajectory and exit
    if topology_file == MISSING_TOPOLOGY:
        # We do not have a topology atom count to return
        # Without a valid topology we can not count trajectory atoms either
        return None, None
    # If it is our standard topology then simply count the atom names
    # Trajectory atoms will be counted using the structure instead
    if is_standard_topology(topology_file):
        parsed_topology = load_json(topology_file.path)
        topology_atom_count = len(parsed_topology['atom_names'])
        # Without a valid topology we can not count trajectory atoms
        return topology_atom_count, None
    # For a TPR use Gromacs, which is its native tool
    if topology_file.format == 'tpr':
        return _get_topology_and_trajectory_atoms_gromacs(topology_file, trajectory_file)
    # For .top files we use PyTraj since MDtraj can not handle it
    if topology_file.format == 'top':
        return get_topology_and_trajectory_atoms_pytraj(topology_file, trajectory_file)
    # At this point the topology should be supported by MDtraj
    return _get_topology_and_trajectory_atoms_mdtraj(topology_file, trajectory_file)


def _get_topology_and_trajectory_atoms_gromacs(topology_file: 'File', trajectory_file: 'File') -> tuple[int, int]:
    """Get topology and trajectory atom counts for a TPR topology using GROMACS."""
    # Make sure the trajectory is compatible with gromacs
    if trajectory_file.format not in GROMACS_TRAJECTORY_SUPPORTED_FORMATS:
        raise InputError('Why loading a TPR topology with a non-gromacs trajectory?')
    # Run Gromacs just to generate a structure using all atoms in the topology and coordinates in the first frame
    # If trajectory atoms are fewer than topology atoms then we will see a specific error
    output_sample_gro_file = File('.sample.gro')
    # Remove any leftover from a previous interrupted run
    # Otherwise its existence would be wrongly taken as a Gromacs success below
    if output_sample_gro_file.exists:
        output_sample_gro_file.remove()
    output_logs, error_logs = run_gromacs(
        f'trjconv -s {topology_file.path} -f {trajectory_file.path} -o {output_sample_gro_file.path} -dump 0',
        user_input='System', expected_output_filepath=None)
    # Always get error logs and mine topology atoms
    # Note that these logs include the output selection request from Gromacs
    # This log should be always there, even if there was a mismatch and then Gromacs failed
    topology_atom_count = mine_system_atoms_count(error_logs)
    # If the output does not exist at this point it means something went wrong with gromacs
    if not output_sample_gro_file.exists:
        # Check if we know the error
        error_match = re.search(GROMACS_ATOM_MISMATCH_ERROR, error_logs)
        if error_match:
            trajectory_atom_count = int(error_match[1])
            return topology_atom_count, trajectory_atom_count
        # Otherwise just print the whole error logs and stop here anyway
        print(output_logs)
        print(error_logs)
        raise ToolError('Something went wrong with GROMACS during the checking')
    # If we had an output then the topology atoms are not more than the trajectory atoms
    # Cleanup the file we just created and proceed
    output_sample_gro_file.remove()
    # Now make sure trajectory atoms are not more than topology atoms
    # Easiest way to print trajectory atoms is using gmx check
    # However if we feed this command with the whole trajectory it will read it all
    # To prevent this we must create a single frame before
    output_sample_xtc_file = File('.sample.xtc')
    try:
        # Note that we do NOT pass the -s argument here
        # Otherwise the structure/topology would eclipse the actual number of atoms in the trajectory
        run_gromacs(f'trjconv -f {trajectory_file.path} -o {output_sample_xtc_file.path} -dump 0',
            user_input='System', expected_output_filepath=output_sample_xtc_file.path)
        # Now read the number of atoms
        output_logs, error_logs = run_gromacs(f'check -f {output_sample_xtc_file.path}')
    finally:
        # Cleanup the sample file even if something failed
        if output_sample_xtc_file.exists:
            output_sample_xtc_file.remove()
    search_results = re.search(GROMACS_ATOM_COUNT_CHECK, error_logs)
    if not search_results:
        print(error_logs)
        raise RuntimeError('Something went wrong when reading trajectory atoms')
    trajectory_atom_count = int(search_results[1])
    return topology_atom_count, trajectory_atom_count


def _get_topology_and_trajectory_atoms_mdtraj(topology_file: 'File', trajectory_file: 'File') -> tuple[int, int]:
    """Get topology and trajectory atom counts using MDtraj, falling back to PyTraj on insertion codes."""
    # If the trajectory is a restart file MDtraj will not be able to read it
    # Make the conversion here, since restart files are single-frame trajectories this should be fast
    trajectory_path = trajectory_file.path
    if trajectory_file.format == 'rst7':
        vmd_to_pdb(topology_file.path, trajectory_file.path, AUXILIAR_PDB_FILE)
        trajectory_path = AUXILIAR_PDB_FILE
    try:
        # Note that declaring the iterator will not fail even when there is a mismatch
        trajectory = mdt.iterload(trajectory_path, top=topology_file.path, chunk=1)
        # We must consume the generator first value to make the error raise
        frame = next(trajectory)
        # Now obtain the number of atoms from the frame we just read
        trajectory_atom_count = frame.n_atoms
        # And still, it may happen that the topology has more atoms than the trajectory but it loads
        # MDtraj may silently load as many coordinates as possible and discard the rest of atoms in topology
        # This behaviour has been observed with a gromacs .top topology and a PDB used as trajectory
        # To double check the match, load the topology alone with PyTraj
        topology = pyt.load_topology(topology_file.path)
        topology_atom_count = topology.n_atoms
        return topology_atom_count, trajectory_atom_count
    except Exception as error:
        # If the error message matches with a known error then report the problem
        error_message = str(error)
        error_match = re.match(MDTRAJ_ATOM_MISMATCH_ERROR, error_message)
        if error_match:
            topology_atom_count = int(error_match[1])
            trajectory_atom_count = int(error_match[2])
            return topology_atom_count, trajectory_atom_count
        if re.match(MDTRAJ_INSERTION_CODES_ERROR, error_message):
            warn('The input topology has insertion codes.\n' +
                ' Some tools may crash when reading the topology (MDtraj).\n' +
                ' Some tools may ignore insertion codes when reading the topology (MDAnalysis, PyTraj, VMD).')
            # Use other tool to read the topology
            # Other tools could ignore the insertion codes
            # However this is not a problem here, where we only care about the number of atoms
            return get_topology_and_trajectory_atoms_pytraj(topology_file, trajectory_file)
        # If we do not know the error then raise it as is
        raise

288 

289 

def get_structure_atoms(structure_file: 'File') -> int:
    """Count the atoms of a structure file on its own.

    GRO files are not handled by Structure, so they are counted with the
    dedicated gromacs helper. Any other format is parsed as a Structure.
    """
    if structure_file.format != 'gro':
        parsed_structure = Structure.from_file(structure_file.path)
        return parsed_structure.atom_count
    return get_atom_count(structure_file)

298 

299 

def get_structure_and_trajectory_atoms(structure_file: 'File', trajectory_file: 'File') -> tuple[int, int]:
    """Get atoms from structure and trajectory together.

    The structure is used as topology to load the first trajectory frame with MDtraj.

    Args:
        structure_file: The structure file.
        trajectory_file: The trajectory file.

    Returns:
        A tuple with the structure atom count and the trajectory atom count.
        On a known MDtraj atom mismatch error, both counts are parsed from the
        error message and returned so the caller can report the mismatch.
    """
    try:
        # Note that declaring the iterator will not fail even when there is a mismatch
        trajectory = mdt.iterload(trajectory_file.path, top=structure_file.path, chunk=1)
        # We must consume the generator first value to make the error raise
        frame = next(trajectory)
    except Exception as error:
        # MDtraj raises when atom counts do not match
        # Parse the counts from the known message so the caller reports a clear input error
        error_match = re.match(MDTRAJ_ATOM_MISMATCH_ERROR, str(error))
        if error_match:
            structure_atom_count = int(error_match[1])
            trajectory_atom_count = int(error_match[2])
            return structure_atom_count, trajectory_atom_count
        # If we do not know the error then raise it as is
        raise
    # Now obtain the number of atoms from the frame we just read
    trajectory_atom_count = frame.n_atoms
    # And still, it may happen that the structure has more atoms than the trajectory but it loads
    # MDtraj may silently load as many coordinates as possible and discard the rest of atoms in topology
    # This behaviour has been observed with a gromacs .top topology and a PDB used as trajectory
    # To double check the match, load the structure alone with PyTraj
    topology = pyt.load_topology(structure_file.path)
    structure_atom_count = topology.n_atoms
    return structure_atom_count, trajectory_atom_count