Coverage for model_workflow/tools/check_inputs.py: 57%

143 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-23 10:54 +0000

1from model_workflow.utils.auxiliar import InputError, warn, CaptureOutput, load_json, MISSING_TOPOLOGY 

2from model_workflow.utils.constants import STANDARD_TOPOLOGY_FILENAME, GROMACS_EXECUTABLE 

3from model_workflow.utils.pyt_spells import find_first_corrupted_frame 

4from model_workflow.utils.gmx_spells import mine_system_atoms_count 

5from model_workflow.utils.vmd_spells import vmd_to_pdb 

6from model_workflow.utils.structures import Structure 

7from model_workflow.utils.file import File 

8 

9from model_workflow.tools.guess_and_filter import guess_and_filter_topology 

10 

11from re import match, search 

12from typing import * 

13from subprocess import run, PIPE, Popen 

14from scipy.io import netcdf_file 

15import mdtraj as mdt 

16import pytraj as pyt 

17 

# Known error messages, used to recognize specific failures reported by third-party tools
# The regular expressions capture atom counts, so every count group requires at least one digit:
# an empty capture would otherwise make the int() conversion of the group fail

# Exact message raised by scipy when reading a corrupted NetCDF file
NETCDF_DTYPE_ERROR = 'When changing to a larger dtype, its size must be a divisor of the total size in bytes of the last axis of the array.'
# MDtraj error message: group 1 is the topology atom count, group 2 is the trajectory atom count
MDTRAJ_ATOM_MISMATCH_ERROR = r'xyz must be shape \(Any, ([0-9]+), 3\)\. You supplied \(1, ([0-9]+), 3\)'
# PyTraj error log: group 1 is the trajectory atom count, group 3 is the topology atom count
PYTRAJ_XTC_ATOM_MISMATCH_ERROR = r'Error: # atoms in XTC file \(([0-9]+)\) does not match # atoms in (topology|parm) [\w.-]* \(([0-9]+)\)'
# Gromacs error log: group 1 is the trajectory atom count
GROMACS_ATOM_MISMATCH_ERROR = r'is larger than the number of atoms in the\ntrajectory file \(([0-9]+)\)\. There is a mismatch in the contents'

# List supported formats
TOPOLOGY_SUPPORTED_FORMATS = { 'tpr', 'top', 'prmtop', 'psf' }
TRAJECTORY_SUPPORTED_FORMATS = { 'xtc', 'trr', 'nc', 'dcd', 'crd', 'pdb', 'rst7' }
STRUCTURE_SUPPORTED_FORMATS = { *TOPOLOGY_SUPPORTED_FORMATS, 'pdb', 'gro' }
GROMACS_TRAJECTORY_SUPPORTED_FORMATS = { 'xtc', 'trr' }

# Auxiliar PDB file which may be generated to load non supported restart files
AUXILIAR_PDB_FILE = '.auxiliar.pdb'
# Misspelled historical name, kept as an alias so existing references keep working
AUXILIAR_PDB_BILE = AUXILIAR_PDB_FILE

# Set exceptions for fixes applied from here
# This object is used as a key in the dict returned by check_inputs
PREFILTERED_TOPOLOGY_EXCEPTION = Exception('Prefiltered topology')

35 

def check_inputs (
    input_structure_file : 'File',
    input_trajectory_files : List['File'],
    input_topology_file : Union['File', Exception]) -> dict:
    """Check input files coherence and integrity.

    The checks are: file formats are supported, NetCDF trajectories can be opened and
    the number of atoms is the same in topology, trajectory and structure.
    Only the first trajectory file is used as a sample for the atom count checks.

    Some exceptional problems may be fixed from here. In these cases both the exception
    and the modified file are returned in the final dict.

    Args:
        input_structure_file: The input structure file.
        input_trajectory_files: The input trajectory files. All of them must have the same format.
        input_topology_file: The input topology file or MISSING_TOPOLOGY when there is none.

    Returns:
        A dict where keys are exception objects (e.g. PREFILTERED_TOPOLOGY_EXCEPTION)
        and values are the files generated to fix the problem. Empty if nothing was fixed.

    Raises:
        InputError: If there are no trajectory files, if any format is not supported,
            if a trajectory file is corrupted or if atom counts do not match and can not be fixed.
        SystemExit: If GROMACS fails for an unknown reason while checking a TPR topology.
    """

    # Set the exceptions dict to be returned at the end
    exceptions = {}

    # Without trajectory files there is nothing to check
    # Fail with a proper input error instead of a bare IndexError
    if not input_trajectory_files:
        raise InputError('No input trajectory files were provided')

    # Get a sample trajectory file and then check its format
    # All input trajectory files must have the same format
    trajectory_sample = input_trajectory_files[0]

    # Check input files are supported by the workflow
    if input_topology_file != MISSING_TOPOLOGY and input_topology_file.filename != STANDARD_TOPOLOGY_FILENAME and input_topology_file.format not in TOPOLOGY_SUPPORTED_FORMATS:
        if input_topology_file.format in { 'pdb', 'gro' }:
            raise InputError('A structure file is not supported as topology anymore. If there is no topology then use the argument "-top no"')
        raise InputError(f'Topology {input_topology_file.path} has a not supported format. Try one of these: {", ".join(TOPOLOGY_SUPPORTED_FORMATS)}')
    if trajectory_sample.format not in TRAJECTORY_SUPPORTED_FORMATS:
        raise InputError(f'Trajectory {trajectory_sample.path} has a not supported format. Try one of these: {", ".join(TRAJECTORY_SUPPORTED_FORMATS)}')
    if input_structure_file.format not in STRUCTURE_SUPPORTED_FORMATS:
        raise InputError(f'Structure {input_structure_file.path} has a not supported format. Try one of these: {", ".join(STRUCTURE_SUPPORTED_FORMATS)}')

    # Make sure the trajectory file is not corrupted

    # Check if reading the trajectory raises the following error
    # ValueError: When changing to a larger dtype, its size must be a divisor of the total size in bytes of the last axis of the array.
    # This error may happen with NetCDF files and it is a bit shady
    # Some tools may be able to read the first frames of the corrupted file: VMD and pytraj
    # Some other tools will instantly fail to read it: MDtraj and MDAnalysis
    if trajectory_sample.format == 'nc':
        # Iterate trajectory files
        for trajectory_file in input_trajectory_files:
            try:
                # This does not read the whole trajectory
                # Open it in a context so the file handle is closed right away
                with netcdf_file(trajectory_file.path, 'r'):
                    pass
            except Exception as error:
                # If we do not know the error then raise it as is
                if str(error) != NETCDF_DTYPE_ERROR:
                    raise
                # If the error message matches with a known error then report the problem
                warn(f'Corrupted trajectory file {trajectory_file.path}')
                pytraj_input_topology = input_topology_file if input_topology_file != MISSING_TOPOLOGY else input_structure_file
                first_corrupted_frame = find_first_corrupted_frame(pytraj_input_topology.path, trajectory_file.path)
                print(f' However some tools may be able to read the first {first_corrupted_frame} frames: VMD and PyTraj')
                raise InputError('Corrupted input trajectory file') from error

    # Set a function to get atoms from a structure alone
    def get_structure_atoms (structure_file : 'File') -> int:
        # Get the number of atoms in the input structure
        structure = Structure.from_file(structure_file.path)
        return structure.atom_count

    # Set a function to get atoms from structure and trajectory together
    def get_structure_and_trajectory_atoms (structure_file : 'File', trajectory_file : 'File') -> Tuple[int, int]:
        # Note that declaring the iterator will not fail even when there is a mismatch
        trajectory = mdt.iterload(trajectory_file.path, top=structure_file.path, chunk=1)
        # We must consume the generator first value to make the error raise
        frame = next(trajectory)
        # Now obtain the number of atoms from the frame we just read
        trajectory_atom_count = frame.n_atoms
        # And still, it may happen that the topology has more atoms than the trajectory but it loads
        # MDtraj may silently load as many coordinates as possible and discard the rest of atoms in topology
        # This behaviour has been observed with a gromacs .top topology and a PDB used as trajectory
        # To double check the match, load the topology alone with PyTraj
        topology = pyt.load_topology(structure_file.path)
        structure_atom_count = topology.n_atoms
        return structure_atom_count, trajectory_atom_count

    # Set a function to get atoms from topology and trajectory together
    # Any of the returned counts may be None when it could not be obtained
    def get_topology_and_trajectory_atoms (topology_file : 'File', trajectory_file : 'File') -> Tuple[Optional[int], Optional[int]]:
        # To do so rely on different tools depending on the topology format
        # If there is no topology file then just compare structure and trajectory and exit
        if topology_file == MISSING_TOPOLOGY:
            # We do not have a topology atom count to return
            # Without a valid topology we can not count trajectory atoms either
            return None, None
        # If it is our standard topology then simply count the atom names
        # Get trajectory atoms using the structure instead
        if topology_file.filename == STANDARD_TOPOLOGY_FILENAME:
            # Parse the json and count atoms
            parsed_topology = load_json(topology_file.path)
            topology_atom_count = len(parsed_topology['atom_names'])
            # Without a valid topology we can not count trajectory atoms
            return topology_atom_count, None
        # For a TPR use Gromacs, which is its native tool
        if topology_file.format == 'tpr':
            # Make sure the trajectory is compatible with gromacs
            if trajectory_file.format not in GROMACS_TRAJECTORY_SUPPORTED_FORMATS:
                raise InputError('Why loading a TPR topology with a non-gromacs trajectory?')
            # Run Gromacs just to generate a structure using all atoms in the topology and coordinates in the first frame
            # If atoms do not match then we will see a specific error
            output_sample_file = File('.sample.gro')
            # Remove any leftover from a previous interrupted run
            # Otherwise a failed run could be mistaken for a success further down
            if output_sample_file.exists:
                output_sample_file.remove()
            # Feed the group selection through stdin directly
            # This avoids spawning an 'echo' process which would never be waited for
            process = run([
                GROMACS_EXECUTABLE,
                "trjconv",
                "-s",
                topology_file.path,
                "-f",
                trajectory_file.path,
                '-o',
                output_sample_file.path,
                "-dump",
                "0",
                '-quiet'
            ], input=b'System\n', stdout=PIPE, stderr=PIPE)
            logs = process.stdout.decode()
            # Always get error logs and mine topology atoms
            # Note that this logs include the output selection request from Gromacs
            # This log should be always there, even if there was a mismatch and then Gromacs failed
            error_logs = process.stderr.decode()
            topology_atom_count = mine_system_atoms_count(error_logs)
            # If the output does not exist at this point it means something went wrong with gromacs
            if not output_sample_file.exists:
                # Check if we know the error
                error_match = search(GROMACS_ATOM_MISMATCH_ERROR, error_logs)
                if error_match:
                    # Get the trajectory atom count
                    trajectory_atom_count = int(error_match[1])
                    return topology_atom_count, trajectory_atom_count
                # Otherwise just print the whole error logs and stop here anyway
                print(logs)
                print(error_logs)
                raise SystemExit('Something went wrong with GROMACS during the checking')
            # If we had an output then it means both topology and trajectory match in the number of atoms
            # Cleanup the file we just created and proceed
            output_sample_file.remove()
            # If there was no problem then it means the trajectory atom count matches the topology atom count
            return topology_atom_count, topology_atom_count
        # For .top files we use PyTraj since MDtraj can not handle it
        if topology_file.format == 'top':
            # Note that calling iterload will print an error log when atoms do not match but will not raise a proper error
            # To capture the error log we must throw this command wrapped in a stderr redirect
            with CaptureOutput('stderr') as output:
                trajectory = pyt.iterload(trajectory_file.path, top=topology_file.path)
            logs = output.captured_text
            # Search along the whole log, since the error may not be the first thing in it
            error_match = search(PYTRAJ_XTC_ATOM_MISMATCH_ERROR, logs)
            if error_match:
                topology_atom_count = int(error_match[3])
                trajectory_atom_count = int(error_match[1])
            # If there was no error then get the number of atoms from the loaded trajectory
            else:
                topology_atom_count = trajectory_atom_count = trajectory.n_atoms
            return topology_atom_count, trajectory_atom_count
        # At this point the topology should be supported by MDtraj
        # However, if the trajectory is a restart file MDtraj will not be able to read it
        # Make the conversion here, since restart files are single-frame trajectories this should be fast
        trajectory_path = trajectory_file.path
        if trajectory_file.format == 'rst7':
            # Generate the auxiliar PDB file and use it as the trajectory
            vmd_to_pdb(topology_file.path, trajectory_file.path, AUXILIAR_PDB_BILE)
            trajectory_path = AUXILIAR_PDB_BILE
        # For any other format use MDtraj
        try:
            # Note that declaring the iterator will not fail even when there is a mismatch
            trajectory = mdt.iterload(trajectory_path, top=topology_file.path, chunk=1)
            # We must consume the generator first value to make the error raise
            frame = next(trajectory)
            # Now obtain the number of atoms from the frame we just read
            trajectory_atom_count = frame.n_atoms
            # And still, it may happen that the topology has more atoms than the trajectory but it loads
            # MDtraj may silently load as many coordinates as possible and discard the rest of atoms in topology
            # This behaviour has been observed with a gromacs .top topology and a PDB used as trajectory
            # To double check the match, load the topology alone with PyTraj
            topology = pyt.load_topology(topology_file.path)
            topology_atom_count = topology.n_atoms
            return topology_atom_count, trajectory_atom_count
        except Exception as error:
            # If the error message matches with a known error then report the problem
            error_match = match(MDTRAJ_ATOM_MISMATCH_ERROR, str(error))
            # If we do not know the error then raise it as is
            if not error_match:
                raise
            topology_atom_count = int(error_match[1])
            trajectory_atom_count = int(error_match[2])
            return topology_atom_count, trajectory_atom_count

    # Get topology and trajectory atom counts
    topology_atom_count, trajectory_atom_count = get_topology_and_trajectory_atoms(input_topology_file, trajectory_sample)

    # If we have the trajectory atom count then it means we had a valid topology
    if trajectory_atom_count is not None:

        # Make sure their atom counts match
        if topology_atom_count != trajectory_atom_count:
            warn('Mismatch in the number of atoms between input files:\n' +
                f' Topology "{input_topology_file.path}" -> {topology_atom_count} atoms\n' +
                f' Trajectory "{trajectory_sample.path}" -> {trajectory_atom_count} atoms')
            if topology_atom_count < trajectory_atom_count:
                raise InputError('Trajectory has more atoms than topology, there is no way to fix this.')
            # If the topology has more atoms than the trajectory however we may attempt to guess
            # If we guess which atoms are the ones in the trajectory then we can filter the topology
            prefiltered_topology_filepath = f'{input_topology_file.basepath}/prefiltered.{input_topology_file.format}'
            prefiltered_topology_file = File(prefiltered_topology_filepath)
            guessed = guess_and_filter_topology(
                input_topology_file,
                prefiltered_topology_file,
                trajectory_atom_count)
            if not guessed:
                raise InputError('Could not guess topology atom selection to match trajectory atoms count')
            exceptions[PREFILTERED_TOPOLOGY_EXCEPTION] = prefiltered_topology_file

        # If the topology file is already the structure file then there is no need to check it
        if input_structure_file == input_topology_file:
            print(f'Topology and trajectory files match in number of atoms: {trajectory_atom_count}')
            return exceptions

        # If the counts match then also get the structure atom count and compare
        structure_atom_count = get_structure_atoms(input_structure_file)

        # Make sure it matches the topology and trajectory atom count
        # NOTE(review): after a prefiltered topology this still compares against the unfiltered topology count
        # It is not clear from here if the structure is expected to match the original topology - TODO confirm
        if topology_atom_count != structure_atom_count:
            raise InputError('Mismatch in the structure input file number of atoms:\n'+
                f' Topology and trajectory -> {topology_atom_count} atoms\n' +
                f' Structure "{input_structure_file.path}" -> {structure_atom_count} atoms')

        # If we reached this point then it means everything is matching
        print(f'All input files match in number of atoms: {trajectory_atom_count}')
        return exceptions

    # Otherwise it means we had not a valid topology file
    # We must use the structure to find trajectory atoms
    structure_atom_count, trajectory_atom_count = get_structure_and_trajectory_atoms(input_structure_file, trajectory_sample)

    # Make sure their atom counts match
    if structure_atom_count != trajectory_atom_count:
        raise InputError('Mismatch in the number of atoms between input files:\n' +
            f' Structure "{input_structure_file.path}" -> {structure_atom_count} atoms\n' +
            f' Trajectory "{trajectory_sample.path}" -> {trajectory_atom_count} atoms')

    # If we have a number of topology atoms then make sure it matches the structure and trajectory atoms
    # This may happen if the topology is our standard topology file instead of a valid topology
    if topology_atom_count is not None and topology_atom_count != trajectory_atom_count:
        raise InputError('Mismatch in the number of atoms between input files:\n' +
            f' Structure and trajectory -> {trajectory_atom_count} atoms\n' +
            f' Topology "{input_topology_file.path}" -> {topology_atom_count} atoms')

    # If we made it this far it means all checkings are good
    print(f'Input files match in number of atoms: {trajectory_atom_count}')
    return exceptions