Coverage for mddb_workflow / tools / filter_atoms.py: 37%

132 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 18:45 +0000

1from os import remove 

2from os.path import exists 

3import json 

4 

5import pytraj as pt 

6 

7from mddb_workflow.utils.constants import RAW_CHARGES_FILENAME 

8from mddb_workflow.utils.structures import Structure 

9from mddb_workflow.utils.auxiliar import save_json, MISSING_TOPOLOGY, is_standard_topology 

10from mddb_workflow.utils.gmx_spells import get_tpr_atom_count, tpr_filter, xtc_filter, pdb_filter 

11from mddb_workflow.utils.type_hints import * 

12from mddb_workflow.tools.get_charges import get_raw_charges 

13 

# Name of the temporary GROMACS index (.ndx) file used while filtering.
# It is a relative path, so it is written to the current working directory.
index_filename: str = 'filter.ndx'
# Name of the atom group looked up inside the index file by the GROMACS filters
# NOTE(review): must match the group name Selection.to_ndx_file writes by default — confirm
filter_group_name: str = "not_water_or_counter_ions"

18 

19def filter_atoms ( 

20 input_structure_file : 'File', 

21 input_trajectory_file : 'File', 

22 input_topology_file : Union['File', Exception], 

23 output_structure_file : 'File', 

24 output_trajectory_file : 'File', 

25 output_topology_file : Union['File', Exception], 

26 # Reference structure used to parse the actual selection 

27 reference_structure : 'Structure', 

28 # Filter selection may be a custom selection or true 

29 # If true then we run a default filtering of water and counter ions 

30 filter_selection : bool | str, 

31 filter_selection_syntax : str = 'vmd', 

32): 

33 """ Filter atoms of all input topologies by removing atoms and ions. 

34 As an exception, some water and ions may be not removed if specified. 

35 At the end, all topologies must match in atoms count. """ 

36 

37 # Handle missing filter selection 

38 if not filter_selection: 

39 return 

40 

41 # Parse the selection to be filtered 

42 # WARNING: Note that the structure is not corrected at this point and there may be limitations 

43 parsed_filter_selection = None 

44 if filter_selection == True: parsed_filter_selection = reference_structure.select_water_and_counter_ions() 

45 else: parsed_filter_selection = reference_structure.select(filter_selection, syntax=filter_selection_syntax) 

46 

47 # Invert the parsed selection to get the atoms to remain 

48 keep_selection = reference_structure.invert_selection(parsed_filter_selection) 

49 

50 # Set the pytraj mask to filter the desired atoms from the structure 

51 filter_mask = keep_selection.to_pytraj() 

52 

53 # Load the structure and trajectory 

54 trajectory = pt.iterload(input_trajectory_file.path, input_structure_file.path) 

55 pt_topology = trajectory.topology 

56 atoms_count = pt_topology.n_atoms 

57 

58 # Set the filtered structure 

59 filtered_pt_topology = pt_topology[filter_mask] 

60 filtered_structure_atoms_count = filtered_pt_topology.n_atoms 

61 

62 # Check if both the normal and the filtered topologies have the same number of atoms 

63 # If not, filter the whole trajectory and overwrite both topologies and trajectory 

64 print(f'Total number of atoms: {atoms_count}') 

65 print(f'Filtered number of atoms: {filtered_structure_atoms_count}') 

66 if filtered_structure_atoms_count < atoms_count: 

67 print('Filtering structure and trajectory...') 

68 # Filter both structure and trajectory using Gromacs, since is more efficient than pytraj 

69 # Set up an index file with all atom indices manually 

70 # As long as indices are for atoms and not residues there should never be any incompatibility 

71 keep_selection.to_ndx_file(output_filepath = index_filename) 

72 # Filter the trajectory 

73 xtc_filter(input_structure_file.path, 

74 input_trajectory_file.path, 

75 output_trajectory_file.path, 

76 index_filename, 

77 filter_group_name) 

78 # Filter the structure 

79 pdb_filter( 

80 input_structure_file.path, 

81 output_structure_file.path, 

82 index_filename, 

83 filter_group_name) 

84 

85 # Filter topology according to the file format 

86 if input_topology_file != MISSING_TOPOLOGY and input_topology_file.exists: 

87 print('Filtering topology...') 

88 # Pytraj supported formats 

89 if input_topology_file.is_pytraj_supported(): 

90 # Load the topology and count its atoms 

91 pt_topology = pt.load_topology(filename=input_topology_file.path) 

92 topology_atoms_count = pt_topology.n_atoms 

93 print(f'Topology atoms count: {topology_atoms_count}') 

94 # Filter the desired atoms using the mask and then count them 

95 filtered_pt_topology = pt_topology[filter_mask] 

96 filtered_topology_atoms_count = filtered_pt_topology.n_atoms 

97 # If there is a difference in atom counts then write the filtered topology 

98 if filtered_topology_atoms_count < topology_atoms_count: 

99 # WARNING: If the output topology is a symlink it will try to overwrite the origin 

100 # Remove it to avoid overwriting input data 

101 if output_topology_file.is_symlink(): output_topology_file.remove() 

102 # Now write the filtered topology 

103 pt.write_parm( 

104 filename=output_topology_file.path, 

105 top=filtered_pt_topology, 

106 format=input_topology_file.get_pytraj_parm_format(), 

107 overwrite=True 

108 ) 

109 # Gromacs format format 

110 elif input_topology_file.format == 'tpr': 

111 # Get the input tpr atom count 

112 topology_atoms_count = get_tpr_atom_count(input_topology_file.path) 

113 print(f'Topology atoms count: {topology_atoms_count}') 

114 # If the number of atoms is greater than expected then filter the tpr file 

115 if topology_atoms_count > filtered_structure_atoms_count: 

116 if not exists(index_filename): 

117 # In order to filter the tpr we need the filter.ndx file 

118 # This must be generated from a pytraj supported topology that matches the number of atoms in the tpr file 

119 raise ValueError('Topology atoms number does not match the structure atoms number and tpr files can not be filtered alone') 

120 tpr_filter( 

121 input_topology_file.path, 

122 output_topology_file.path, 

123 index_filename, 

124 filter_group_name) 

125 # Get the output tpr atom count 

126 filtered_topology_atoms_count = get_tpr_atom_count(output_topology_file.path) 

127 else: 

128 filtered_topology_atoms_count = topology_atoms_count 

129 # Standard topology 

130 elif is_standard_topology(input_topology_file): 

131 standard_topology = None 

132 with open(input_topology_file.path, 'r') as file: 

133 standard_topology = json.load(file) 

134 topology_atoms_count = len(standard_topology['atom_names']) 

135 print(f'Topology atoms count: {topology_atoms_count}') 

136 # Make it match since there is no problem when these 2 do not match 

137 filtered_topology_atoms_count = filtered_structure_atoms_count 

138 # If the number of charges does not match the number of atoms then filter the topology 

139 if topology_atoms_count != filtered_structure_atoms_count: 

140 standard_topology_filter(input_topology_file, reference_structure, parsed_filter_selection, output_topology_file) 

141 # Raw charges 

142 elif input_topology_file.filename == RAW_CHARGES_FILENAME: 

143 charges = get_raw_charges(input_topology_file.path) 

144 # Nothing to do here. It better matches by defualt or we have a problem 

145 filtered_topology_atoms_count = len(charges) 

146 print(f'Topology atoms count: {filtered_topology_atoms_count}') 

147 else: 

148 raise ValueError(f'Topology file ({input_topology_file.filename}) is in a non supported format') 

149 

150 print(f'Filtered topology atoms: {filtered_topology_atoms_count}') 

151 

152 # Both filtered structure and topology must have the same number of atoms 

153 if filtered_structure_atoms_count != filtered_topology_atoms_count: 

154 print(f'Filtered structure atoms: {filtered_structure_atoms_count}') 

155 raise ValueError('Filtered atom counts in topology and charges does not match') 

156 

157 # Remove the index file in case it was created 

158 if exists(index_filename): 

159 remove(index_filename) 

160 

161 # Check if any of the output files does not exist 

162 # If so, then it means there was nothing to filter 

163 # However the output file is expected, so me make symlink 

164 if not output_structure_file.exists: 

165 output_structure_file.set_symlink_to(input_structure_file) 

166 if not output_trajectory_file.exists: 

167 output_trajectory_file.set_symlink_to(input_trajectory_file) 

168 if output_topology_file != MISSING_TOPOLOGY and not output_topology_file.exists: 

169 output_topology_file.set_symlink_to(input_topology_file) 

170 

# Set a function to filter the standard topology file
# WARNING: This function has not been checked in depth to work properly
def standard_topology_filter(
    input_topology_file: 'File',
    reference_structure: 'Structure',
    parsed_filter_selection: 'Selection',
    output_topology_file: 'File'):
    """Write a copy of a standard topology restricted to the selected atoms.

    Atom-wise fields are reduced to the atoms in ``parsed_filter_selection``.
    Residue-wise and chain-wise fields are reduced to the residues and chains
    of those atoms, as reported by ``reference_structure``. Every cross
    reference is renumbered to the new compacted indices:
    atom -> residue, residue -> chain, bonds, icodes, PBC residues and references.

    Args:
        input_topology_file: Standard topology JSON file to read.
        reference_structure: Structure used to map the selection to residue and chain indices.
        parsed_filter_selection: Selection of the atoms to KEEP in the output topology.
        output_topology_file: File where the filtered topology is written.
    """

    # Load the topology
    with open(input_topology_file.path, 'r') as file:
        topology = json.load(file)

    # Get the atom, residue and chain indices to be kept
    atom_indices = parsed_filter_selection.atom_indices
    residue_indices = reference_structure.get_selection_residue_indices(parsed_filter_selection)
    chain_indices = reference_structure.get_selection_chain_indices(parsed_filter_selection)

    # Map old indices to their new (compacted) indices
    atom_backmapping = {old_index: new_index for new_index, old_index in enumerate(atom_indices)}
    residue_backmapping = {old_index: new_index for new_index, old_index in enumerate(residue_indices)}
    chain_backmapping = {old_index: new_index for new_index, old_index in enumerate(chain_indices)}

    # Pick the values of a list at the given indices (None passes through)
    def filter_by_indices(values: list, indices: list[int]) -> list:
        if values is None:
            return None
        return [values[i] for i in indices]

    # Filter atomwise fields
    atom_names = filter_by_indices(topology['atom_names'], atom_indices)
    atom_elements = filter_by_indices(topology['atom_elements'], atom_indices)
    atom_charges = filter_by_indices(topology['atom_charges'], atom_indices)

    # Handle atom fields which require backmapping
    old_atom_residue_indices = filter_by_indices(topology['atom_residue_indices'], atom_indices)
    atom_residue_indices = [residue_backmapping[index] for index in old_atom_residue_indices]
    atom_bonds = None
    raw_atom_bonds = topology.get('atom_bonds', None)
    if raw_atom_bonds:
        old_atom_bonds = filter_by_indices(raw_atom_bonds, atom_indices)
        # Bonds to atoms which were filtered out are dropped, since those atoms no longer exist
        atom_bonds = [
            [atom_backmapping[bonded] for bonded in bonds if bonded in atom_backmapping]
            for bonds in old_atom_bonds
        ]

    # Filter residuewise fields
    residue_names = filter_by_indices(topology['residue_names'], residue_indices)
    residue_numbers = filter_by_indices(topology['residue_numbers'], residue_indices)

    # Handle residue fields which require backmapping
    old_residue_chain_indices = filter_by_indices(topology['residue_chain_indices'], residue_indices)
    residue_chain_indices = [chain_backmapping[index] for index in old_residue_chain_indices]
    # Handle icodes if they exist
    residue_icodes = None
    raw_residue_icodes = topology.get('residue_icodes', None)
    if raw_residue_icodes:
        # JSON object keys are always strings, so residue indices must be parsed back to integers
        residue_icodes = {
            residue_backmapping[int(index)]: icode
            for index, icode in raw_residue_icodes.items()
            if int(index) in residue_backmapping
        }
    # Handle PBC residues (the field may be missing or null)
    pbc_residues = [
        residue_backmapping[index]
        for index in (topology.get('pbc_residues') or [])
        if index in residue_backmapping
    ]

    # Handle chainwise fields
    chain_names = filter_by_indices(topology['chain_names'], chain_indices)

    # Handle references
    references = None
    reference_types = None
    residue_reference_indices = None
    residue_reference_numbers = None
    # If they exist
    raw_references = topology.get('references', None)
    if raw_references:
        residue_reference_numbers = filter_by_indices(topology['residue_reference_numbers'], residue_indices)
        old_residue_reference_indices = filter_by_indices(topology['residue_reference_indices'], residue_indices)
        # Keep only references still used by a remaining residue, in their original order
        reference_indices = sorted({index for index in old_residue_reference_indices if index is not None})
        references_backmapping = {old_index: new_index for new_index, old_index in enumerate(reference_indices)}
        references = [raw_references[old_index] for old_index in reference_indices]
        raw_reference_types = topology.get('reference_types', None)
        if raw_reference_types:
            reference_types = [raw_reference_types[old_index] for old_index in reference_indices]
        references_backmapping[None] = None  # To residues with no reference
        residue_reference_indices = [references_backmapping[index] for index in old_residue_reference_indices]

    # Set the filtered topology
    output_topology = {
        'atom_names': atom_names,
        'atom_elements': atom_elements,
        'atom_charges': atom_charges,
        'atom_residue_indices': atom_residue_indices,
        'atom_bonds': atom_bonds,
        'residue_names': residue_names,
        'residue_numbers': residue_numbers,
        'residue_icodes': residue_icodes,
        'residue_chain_indices': residue_chain_indices,
        'chain_names': chain_names,
        'references': references,
        'reference_types': reference_types,
        'residue_reference_indices': residue_reference_indices,
        'residue_reference_numbers': residue_reference_numbers,
        'pbc_residues': pbc_residues,
    }
    # Write the new topology
    save_json(output_topology, output_topology_file.path)