Coverage for model_workflow/tools/filter_atoms.py: 38%

133 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-23 10:54 +0000

1from os import remove 

2from os.path import exists 

3from subprocess import run, PIPE, Popen 

4import json 

5 

6import pytraj as pt 

7 

8from model_workflow.utils.constants import STANDARD_TOPOLOGY_FILENAME, RAW_CHARGES_FILENAME 

9from model_workflow.utils.structures import Structure 

10from model_workflow.utils.auxiliar import save_json, MISSING_TOPOLOGY 

11from model_workflow.utils.gmx_spells import get_tpr_atom_count, tpr_filter, xtc_filter, pdb_filter 

12from model_workflow.utils.type_hints import * 

13from model_workflow.tools.get_charges import get_raw_charges 

14 

# Gromacs index (.ndx) file written temporarily while filtering; removed again by filter_atoms
index_filename : str = 'filter.ndx'
# Group name handed to the gromacs filter helpers to pick the atoms to keep out of the index file
filter_group_name : str = "not_water_or_counter_ions"

19 

def filter_atoms (
    input_structure_file : 'File',
    input_trajectory_file : 'File',
    input_topology_file : Union['File', Exception],
    output_structure_file : 'File',
    output_trajectory_file : 'File',
    output_topology_file : Union['File', Exception],
    # Reference structure used to parse the actual selection
    reference_structure : 'Structure',
    # Filter selection may be a custom selection or true
    # If true then we run a default filtering of water and counter ions
    filter_selection : Union[bool, str],
    filter_selection_syntax : str = 'vmd',
):
    """Remove unwanted atoms from the input structure, trajectory and topology.

    If filter_selection is True, the atoms returned by
    reference_structure.select_water_and_counter_ions() are removed. Otherwise
    filter_selection is parsed with reference_structure.select() using
    filter_selection_syntax, and the selected atoms are removed. A falsy
    filter_selection means nothing is done at all. At the end the filtered
    structure and topology must match in atom count.

    Outputs that did not need filtering are created as symlinks to their inputs.

    Raises:
        ValueError: if the topology format is not supported; if a tpr topology has
            more atoms than the filtered structure but no index file was produced in
            this call to filter it; or if the filtered structure and topology atom
            counts differ.
    """
    # Handle missing filter selection
    if not filter_selection:
        return

    # Parse the selection of atoms to be removed
    # WARNING: Note that the structure is not corrected at this point and there may be limitations
    if filter_selection is True:
        parsed_filter_selection = reference_structure.select_water_and_counter_ions()
    else:
        parsed_filter_selection = reference_structure.select(filter_selection, syntax=filter_selection_syntax)

    # Invert the parsed selection to get the atoms to remain
    keep_selection = reference_structure.invert_selection(parsed_filter_selection)

    # Pytraj mask selecting the atoms to keep
    filter_mask = keep_selection.to_pytraj()

    # Load the structure and trajectory and count atoms before and after applying the mask
    trajectory = pt.iterload(input_trajectory_file.path, input_structure_file.path)
    pt_topology = trajectory.topology
    atoms_count = pt_topology.n_atoms
    filtered_structure_atoms_count = pt_topology[filter_mask].n_atoms
    print(f'Total number of atoms: {atoms_count}')
    print(f'Filtered number of atoms: {filtered_structure_atoms_count}')

    # Remember whether the index file was written by THIS call, so a stale file
    # left behind by an earlier (failed) run is never used to filter a tpr
    index_created = False
    try:
        # If the filtered structure has fewer atoms, filter both structure and trajectory
        if filtered_structure_atoms_count < atoms_count:
            print('Filtering structure and trajectory...')
            # Filter both structure and trajectory using Gromacs, since it is more efficient than pytraj
            # Set up an index file with all atom indices manually
            # As long as indices are for atoms and not residues there should never be any incompatibility
            keep_selection.to_ndx_file(output_filepath = index_filename)
            index_created = True
            # Filter the trajectory
            _unlink_if_symlink(output_trajectory_file)
            xtc_filter(input_structure_file.path,
                input_trajectory_file.path,
                output_trajectory_file.path,
                index_filename,
                filter_group_name)
            # Filter the structure
            _unlink_if_symlink(output_structure_file)
            pdb_filter(
                input_structure_file.path,
                output_structure_file.path,
                index_filename,
                filter_group_name)

        # Filter topology according to the file format
        if input_topology_file != MISSING_TOPOLOGY and input_topology_file.exists:
            filtered_topology_atoms_count = _filter_topology(
                input_topology_file,
                output_topology_file,
                reference_structure,
                keep_selection,
                filter_mask,
                filtered_structure_atoms_count,
                index_created)
            print(f'Filtered topology atoms: {filtered_topology_atoms_count}')
            # Both filtered structure and topology must have the same number of atoms
            if filtered_structure_atoms_count != filtered_topology_atoms_count:
                raise ValueError(
                    f'Filtered atom counts in structure ({filtered_structure_atoms_count}) '
                    f'and topology ({filtered_topology_atoms_count}) do not match')
    finally:
        # Remove the index file even on failure so it cannot leak into later runs
        if exists(index_filename):
            remove(index_filename)

    # Any output file which does not exist yet means there was nothing to filter
    # However the output file is expected, so we make a symlink to the input
    if not output_structure_file.exists:
        output_structure_file.set_symlink_to(input_structure_file)
    if not output_trajectory_file.exists:
        output_trajectory_file.set_symlink_to(input_trajectory_file)
    if output_topology_file != MISSING_TOPOLOGY and not output_topology_file.exists:
        output_topology_file.set_symlink_to(input_topology_file)

def _unlink_if_symlink (file : 'File'):
    """Remove file if it is a symlink, so that writing to its path does not overwrite the symlink origin."""
    if file.is_symlink():
        file.remove()

def _filter_topology (
    input_topology_file : 'File',
    output_topology_file : 'File',
    reference_structure : 'Structure',
    keep_selection : 'Selection',
    filter_mask : str,
    filtered_structure_atoms_count : int,
    index_created : bool,
) -> int:
    """Filter the topology according to its format and return its filtered atom count.

    The output topology is only written when filtering is actually needed.
    A tpr topology can only be filtered when index_created is True, i.e. when
    this run already wrote the gromacs index file.

    Raises:
        ValueError: for unsupported formats, or a tpr that needs filtering without an index file.
    """
    print('Filtering topology...')
    # Pytraj supported formats
    if input_topology_file.is_pytraj_supported():
        # Load the topology and count its atoms
        pt_topology = pt.load_topology(filename=input_topology_file.path)
        topology_atoms_count = pt_topology.n_atoms
        print(f'Topology atoms count: {topology_atoms_count}')
        # Filter the desired atoms using the mask and then count them
        filtered_pt_topology = pt_topology[filter_mask]
        filtered_topology_atoms_count = filtered_pt_topology.n_atoms
        # If there is a difference in atom counts then write the filtered topology
        if filtered_topology_atoms_count < topology_atoms_count:
            # WARNING: If the output topology is a symlink writing it would overwrite the origin
            _unlink_if_symlink(output_topology_file)
            pt.write_parm(
                filename=output_topology_file.path,
                top=filtered_pt_topology,
                format=input_topology_file.get_pytraj_parm_format(),
                overwrite=True
            )
        return filtered_topology_atoms_count
    # Gromacs format
    if input_topology_file.format == 'tpr':
        topology_atoms_count = get_tpr_atom_count(input_topology_file.path)
        print(f'Topology atoms count: {topology_atoms_count}')
        # Nothing to filter unless the tpr has more atoms than the filtered structure
        if topology_atoms_count <= filtered_structure_atoms_count:
            return topology_atoms_count
        if not index_created:
            # In order to filter the tpr we need the filter.ndx file
            # This must be generated from a pytraj supported topology that matches the number of atoms in the tpr file
            raise ValueError('Topology atoms number does not match the structure atoms number and tpr files can not be filtered alone')
        _unlink_if_symlink(output_topology_file)
        tpr_filter(
            input_topology_file.path,
            output_topology_file.path,
            index_filename,
            filter_group_name)
        # Get the output tpr atom count
        return get_tpr_atom_count(output_topology_file.path)
    # Standard topology
    if input_topology_file.filename == STANDARD_TOPOLOGY_FILENAME:
        with open(input_topology_file.path, 'r') as file:
            standard_topology = json.load(file)
        topology_atoms_count = len(standard_topology['atom_names'])
        print(f'Topology atoms count: {topology_atoms_count}')
        # If the number of atoms does not match the filtered structure then filter the topology
        if topology_atoms_count != filtered_structure_atoms_count:
            # save_json follows symlinks, so make sure the input file is never overwritten
            _unlink_if_symlink(output_topology_file)
            # standard_topology_filter KEEPS the atoms of the selection it receives
            standard_topology_filter(input_topology_file, reference_structure, keep_selection, output_topology_file)
        # Report the structure count: the topology either already matched it or was just filtered to it
        return filtered_structure_atoms_count
    # Raw charges
    if input_topology_file.filename == RAW_CHARGES_FILENAME:
        charges = get_raw_charges(input_topology_file.path)
        # Nothing to do here: raw charges must already match the filtered structure
        filtered_topology_atoms_count = len(charges)
        print(f'Topology atoms count: {filtered_topology_atoms_count}')
        return filtered_topology_atoms_count
    raise ValueError(f'Topology file ({input_topology_file.filename}) is in a non supported format')

171 

def standard_topology_filter (
    input_topology_file : 'File',
    reference_structure : 'Structure',
    parsed_filter_selection : 'Selection',
    output_topology_file : 'File'):
    """Write a copy of a standard (JSON) topology restricted to the atoms of a selection.

    The atoms in parsed_filter_selection are KEPT and all others are dropped.
    Residues and chains are taken from
    reference_structure.get_selection_residue_indices / get_selection_chain_indices
    (presumably those containing at least one kept atom). Every field that stores
    atom, residue, chain or reference indices is remapped so it points into the
    new, shorter lists.

    WARNING: This function has not been checked in depth to work properly.

    Args:
        input_topology_file: standard topology file to read.
        reference_structure: structure used to resolve residues and chains of the selection.
        parsed_filter_selection: selection of the atoms to keep.
        output_topology_file: file where the filtered topology is written.
    """
    # Load the topology
    with open(input_topology_file.path, 'r') as file:
        topology = json.load(file)

    # Atom, residue and chain indices (in the original topology) to keep
    atom_indices = parsed_filter_selection.atom_indices
    residue_indices = reference_structure.get_selection_residue_indices(parsed_filter_selection)
    chain_indices = reference_structure.get_selection_chain_indices(parsed_filter_selection)

    # Map every kept old index to its position in the filtered lists
    atom_backmapping = { old_index: new_index for new_index, old_index in enumerate(atom_indices) }
    residue_backmapping = { old_index: new_index for new_index, old_index in enumerate(residue_indices) }
    chain_backmapping = { old_index: new_index for new_index, old_index in enumerate(chain_indices) }

    # Pick the values at the given indices, in order; a missing (None) field stays None
    def filter_by_indices (values : list, indices : List[int]) -> list:
        if values is None:
            return None
        return [ values[i] for i in indices ]

    # Filter atomwise fields
    atom_names = filter_by_indices(topology['atom_names'], atom_indices)
    atom_elements = filter_by_indices(topology.get('atom_elements'), atom_indices)
    atom_charges = filter_by_indices(topology.get('atom_charges'), atom_indices)

    # Handle atom fields which require backmapping
    old_atom_residue_indices = filter_by_indices(topology['atom_residue_indices'], atom_indices)
    atom_residue_indices = [ residue_backmapping[index] for index in old_atom_residue_indices ]
    atom_bonds = None
    raw_atom_bonds = topology.get('atom_bonds', None)
    if raw_atom_bonds:
        # Presumably one list of bonded atom indices per atom
        old_atom_bonds = filter_by_indices(raw_atom_bonds, atom_indices)
        # Bonds to atoms which were filtered out have no new index, so they are dropped
        atom_bonds = [
            [ atom_backmapping[bond] for bond in bonds if bond in atom_backmapping ]
            for bonds in old_atom_bonds
        ]

    # Filter residuewise fields
    residue_names = filter_by_indices(topology['residue_names'], residue_indices)
    residue_numbers = filter_by_indices(topology['residue_numbers'], residue_indices)

    # Handle residue fields which require backmapping
    old_residue_chain_indices = filter_by_indices(topology['residue_chain_indices'], residue_indices)
    residue_chain_indices = [ chain_backmapping[index] for index in old_residue_chain_indices ]
    # Handle icodes if they exist
    residue_icodes = None
    raw_residue_icodes = topology.get('residue_icodes', None)
    if raw_residue_icodes:
        # JSON object keys are always strings, so convert them back to residue indices
        residue_icodes = {
            residue_backmapping[int(index)]: icode
            for index, icode in raw_residue_icodes.items()
            if int(index) in residue_backmapping
        }
    # Handle PBC residues, keeping only those which survive the filter
    pbc_residues = [
        residue_backmapping[index]
        for index in (topology.get('pbc_residues') or [])
        if index in residue_backmapping
    ]

    # Handle chainwise fields
    chain_names = filter_by_indices(topology['chain_names'], chain_indices)

    # Handle references
    references = None
    reference_types = None
    residue_reference_indices = None
    residue_reference_numbers = None
    # If they exist
    raw_references = topology.get('references', None)
    if raw_references:
        residue_reference_numbers = filter_by_indices(topology.get('residue_reference_numbers'), residue_indices)
        old_residue_reference_indices = filter_by_indices(topology['residue_reference_indices'], residue_indices)
        # Keep only references still used by some residue, preserving their original order
        reference_indices = sorted({ index for index in old_residue_reference_indices if index is not None })
        references_backmapping = { old_index: new_index for new_index, old_index in enumerate(reference_indices) }
        references = [ raw_references[old_index] for old_index in reference_indices ]
        raw_reference_types = topology.get('reference_types', None)
        if raw_reference_types:
            reference_types = [ raw_reference_types[old_index] for old_index in reference_indices ]
        # Residues with no reference keep None
        references_backmapping[None] = None
        residue_reference_indices = [ references_backmapping[index] for index in old_residue_reference_indices ]

    # Set the filtered topology
    output_topology = {
        'atom_names': atom_names,
        'atom_elements': atom_elements,
        'atom_charges': atom_charges,
        'atom_residue_indices': atom_residue_indices,
        'atom_bonds': atom_bonds,
        'residue_names': residue_names,
        'residue_numbers': residue_numbers,
        'residue_icodes': residue_icodes,
        'residue_chain_indices': residue_chain_indices,
        'chain_names': chain_names,
        'references': references,
        'reference_types': reference_types,
        'residue_reference_indices': residue_reference_indices,
        'residue_reference_numbers': residue_reference_numbers,
        'pbc_residues': pbc_residues,
    }
    # Write the new topology
    save_json(output_topology, output_topology_file.path)