Coverage for mddb_workflow / tools / filter_atoms.py: 37%
132 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 18:45 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 18:45 +0000
1from os import remove
2from os.path import exists
3import json
5import pytraj as pt
7from mddb_workflow.utils.constants import RAW_CHARGES_FILENAME
8from mddb_workflow.utils.structures import Structure
9from mddb_workflow.utils.auxiliar import save_json, MISSING_TOPOLOGY, is_standard_topology
10from mddb_workflow.utils.gmx_spells import get_tpr_atom_count, tpr_filter, xtc_filter, pdb_filter
11from mddb_workflow.utils.type_hints import *
12from mddb_workflow.tools.get_charges import get_raw_charges
# Temporary GROMACS index (.ndx) file; a relative path, so it lands in the current working directory
index_filename: str = "filter.ndx"
# Index group name handed to the GROMACS filtering helpers (xtc/pdb/tpr) together with the index file
filter_group_name: str = "not_water_or_counter_ions"
19def filter_atoms (
20 input_structure_file : 'File',
21 input_trajectory_file : 'File',
22 input_topology_file : Union['File', Exception],
23 output_structure_file : 'File',
24 output_trajectory_file : 'File',
25 output_topology_file : Union['File', Exception],
26 # Reference structure used to parse the actual selection
27 reference_structure : 'Structure',
28 # Filter selection may be a custom selection or true
29 # If true then we run a default filtering of water and counter ions
30 filter_selection : bool | str,
31 filter_selection_syntax : str = 'vmd',
32):
33 """ Filter atoms of all input topologies by removing atoms and ions.
34 As an exception, some water and ions may be not removed if specified.
35 At the end, all topologies must match in atoms count. """
37 # Handle missing filter selection
38 if not filter_selection:
39 return
41 # Parse the selection to be filtered
42 # WARNING: Note that the structure is not corrected at this point and there may be limitations
43 parsed_filter_selection = None
44 if filter_selection == True: parsed_filter_selection = reference_structure.select_water_and_counter_ions()
45 else: parsed_filter_selection = reference_structure.select(filter_selection, syntax=filter_selection_syntax)
47 # Invert the parsed selection to get the atoms to remain
48 keep_selection = reference_structure.invert_selection(parsed_filter_selection)
50 # Set the pytraj mask to filter the desired atoms from the structure
51 filter_mask = keep_selection.to_pytraj()
53 # Load the structure and trajectory
54 trajectory = pt.iterload(input_trajectory_file.path, input_structure_file.path)
55 pt_topology = trajectory.topology
56 atoms_count = pt_topology.n_atoms
58 # Set the filtered structure
59 filtered_pt_topology = pt_topology[filter_mask]
60 filtered_structure_atoms_count = filtered_pt_topology.n_atoms
62 # Check if both the normal and the filtered topologies have the same number of atoms
63 # If not, filter the whole trajectory and overwrite both topologies and trajectory
64 print(f'Total number of atoms: {atoms_count}')
65 print(f'Filtered number of atoms: {filtered_structure_atoms_count}')
66 if filtered_structure_atoms_count < atoms_count:
67 print('Filtering structure and trajectory...')
68 # Filter both structure and trajectory using Gromacs, since is more efficient than pytraj
69 # Set up an index file with all atom indices manually
70 # As long as indices are for atoms and not residues there should never be any incompatibility
71 keep_selection.to_ndx_file(output_filepath = index_filename)
72 # Filter the trajectory
73 xtc_filter(input_structure_file.path,
74 input_trajectory_file.path,
75 output_trajectory_file.path,
76 index_filename,
77 filter_group_name)
78 # Filter the structure
79 pdb_filter(
80 input_structure_file.path,
81 output_structure_file.path,
82 index_filename,
83 filter_group_name)
85 # Filter topology according to the file format
86 if input_topology_file != MISSING_TOPOLOGY and input_topology_file.exists:
87 print('Filtering topology...')
88 # Pytraj supported formats
89 if input_topology_file.is_pytraj_supported():
90 # Load the topology and count its atoms
91 pt_topology = pt.load_topology(filename=input_topology_file.path)
92 topology_atoms_count = pt_topology.n_atoms
93 print(f'Topology atoms count: {topology_atoms_count}')
94 # Filter the desired atoms using the mask and then count them
95 filtered_pt_topology = pt_topology[filter_mask]
96 filtered_topology_atoms_count = filtered_pt_topology.n_atoms
97 # If there is a difference in atom counts then write the filtered topology
98 if filtered_topology_atoms_count < topology_atoms_count:
99 # WARNING: If the output topology is a symlink it will try to overwrite the origin
100 # Remove it to avoid overwriting input data
101 if output_topology_file.is_symlink(): output_topology_file.remove()
102 # Now write the filtered topology
103 pt.write_parm(
104 filename=output_topology_file.path,
105 top=filtered_pt_topology,
106 format=input_topology_file.get_pytraj_parm_format(),
107 overwrite=True
108 )
109 # Gromacs format format
110 elif input_topology_file.format == 'tpr':
111 # Get the input tpr atom count
112 topology_atoms_count = get_tpr_atom_count(input_topology_file.path)
113 print(f'Topology atoms count: {topology_atoms_count}')
114 # If the number of atoms is greater than expected then filter the tpr file
115 if topology_atoms_count > filtered_structure_atoms_count:
116 if not exists(index_filename):
117 # In order to filter the tpr we need the filter.ndx file
118 # This must be generated from a pytraj supported topology that matches the number of atoms in the tpr file
119 raise ValueError('Topology atoms number does not match the structure atoms number and tpr files can not be filtered alone')
120 tpr_filter(
121 input_topology_file.path,
122 output_topology_file.path,
123 index_filename,
124 filter_group_name)
125 # Get the output tpr atom count
126 filtered_topology_atoms_count = get_tpr_atom_count(output_topology_file.path)
127 else:
128 filtered_topology_atoms_count = topology_atoms_count
129 # Standard topology
130 elif is_standard_topology(input_topology_file):
131 standard_topology = None
132 with open(input_topology_file.path, 'r') as file:
133 standard_topology = json.load(file)
134 topology_atoms_count = len(standard_topology['atom_names'])
135 print(f'Topology atoms count: {topology_atoms_count}')
136 # Make it match since there is no problem when these 2 do not match
137 filtered_topology_atoms_count = filtered_structure_atoms_count
138 # If the number of charges does not match the number of atoms then filter the topology
139 if topology_atoms_count != filtered_structure_atoms_count:
140 standard_topology_filter(input_topology_file, reference_structure, parsed_filter_selection, output_topology_file)
141 # Raw charges
142 elif input_topology_file.filename == RAW_CHARGES_FILENAME:
143 charges = get_raw_charges(input_topology_file.path)
144 # Nothing to do here. It better matches by defualt or we have a problem
145 filtered_topology_atoms_count = len(charges)
146 print(f'Topology atoms count: {filtered_topology_atoms_count}')
147 else:
148 raise ValueError(f'Topology file ({input_topology_file.filename}) is in a non supported format')
150 print(f'Filtered topology atoms: {filtered_topology_atoms_count}')
152 # Both filtered structure and topology must have the same number of atoms
153 if filtered_structure_atoms_count != filtered_topology_atoms_count:
154 print(f'Filtered structure atoms: {filtered_structure_atoms_count}')
155 raise ValueError('Filtered atom counts in topology and charges does not match')
157 # Remove the index file in case it was created
158 if exists(index_filename):
159 remove(index_filename)
161 # Check if any of the output files does not exist
162 # If so, then it means there was nothing to filter
163 # However the output file is expected, so me make symlink
164 if not output_structure_file.exists:
165 output_structure_file.set_symlink_to(input_structure_file)
166 if not output_trajectory_file.exists:
167 output_trajectory_file.set_symlink_to(input_trajectory_file)
168 if output_topology_file != MISSING_TOPOLOGY and not output_topology_file.exists:
169 output_topology_file.set_symlink_to(input_topology_file)
# Set a function to filter the standard topology file
# WARNING: This function has not been checked in depth to work properly
def standard_topology_filter (
    input_topology_file : 'File',
    reference_structure : 'Structure',
    parsed_filter_selection : 'Selection',
    output_topology_file : 'File'):
    """Write a copy of a standard (JSON) topology without the filtered atoms.

    Args:
        input_topology_file: Standard topology file to read.
        reference_structure: Structure used to map atoms to residues and chains.
            Assumed to share atom/residue/chain ordering with the input topology.
        parsed_filter_selection: Selection of the atoms to REMOVE (as built in
            ``filter_atoms``, e.g. water and counter ions). It is inverted here to
            get the atoms to keep.
        output_topology_file: Destination of the filtered topology.
    """
    # Load the topology
    with open(input_topology_file.path, 'r') as file:
        topology = json.load(file)

    # The selection lists the atoms to be removed, so invert it to get the atoms to keep
    keep_selection = reference_structure.invert_selection(parsed_filter_selection)

    # Get kept atom, residue and chain indices, in their original order
    atom_indices = sorted(keep_selection.atom_indices)
    residue_indices = sorted(reference_structure.get_selection_residue_indices(keep_selection))
    chain_indices = sorted(reference_structure.get_selection_chain_indices(keep_selection))

    # Map every kept old index to its new index
    atom_backmapping = { old_index: new_index for new_index, old_index in enumerate(atom_indices) }
    residue_backmapping = { old_index: new_index for new_index, old_index in enumerate(residue_indices) }
    chain_backmapping = { old_index: new_index for new_index, old_index in enumerate(chain_indices) }

    def filter_by_indices (values : list | None, indices : list[int]) -> list | None:
        """Return the values at the given indices, or None when there are no values."""
        if values is None:
            return None
        return [ values[i] for i in indices ]

    # Filter atomwise fields
    atom_names = filter_by_indices(topology['atom_names'], atom_indices)
    atom_elements = filter_by_indices(topology['atom_elements'], atom_indices)
    atom_charges = filter_by_indices(topology.get('atom_charges', None), atom_indices)

    # Handle atom fields which require backmapping
    old_atom_residue_indices = filter_by_indices(topology['atom_residue_indices'], atom_indices)
    atom_residue_indices = [ residue_backmapping[index] for index in old_atom_residue_indices ]
    atom_bonds = None
    raw_atom_bonds = topology.get('atom_bonds', None)
    if raw_atom_bonds:
        old_atom_bonds = filter_by_indices(raw_atom_bonds, atom_indices)
        # Bonds to removed atoms have no new index, so they are dropped
        atom_bonds = [
            [ atom_backmapping[bond] for bond in bonds if bond in atom_backmapping ]
            for bonds in old_atom_bonds
        ]

    # Filter residuewise fields
    residue_names = filter_by_indices(topology['residue_names'], residue_indices)
    residue_numbers = filter_by_indices(topology['residue_numbers'], residue_indices)

    # Handle residue fields which require backmapping
    old_residue_chain_indices = filter_by_indices(topology['residue_chain_indices'], residue_indices)
    residue_chain_indices = [ chain_backmapping[index] for index in old_residue_chain_indices ]
    # Handle icodes if they exist
    residue_icodes = None
    raw_residue_icodes = topology.get('residue_icodes', None)
    if raw_residue_icodes:
        # JSON object keys are always strings, so cast them back to residue indices
        residue_icodes = {
            residue_backmapping[int(index)]: icode
            for index, icode in raw_residue_icodes.items()
            if int(index) in residue_backmapping
        }
    # Handle PBC residues
    raw_pbc_residues = topology.get('pbc_residues', None) or []
    pbc_residues = [ residue_backmapping[index] for index in raw_pbc_residues if index in residue_backmapping ]

    # Handle chainwise fields
    chain_names = filter_by_indices(topology['chain_names'], chain_indices)

    # Handle references
    references = None
    reference_types = None
    residue_reference_indices = None
    residue_reference_numbers = None
    # If they exist
    raw_references = topology.get('references', None)
    if raw_references:
        residue_reference_numbers = filter_by_indices(topology['residue_reference_numbers'], residue_indices)
        old_residue_reference_indices = filter_by_indices(topology['residue_reference_indices'], residue_indices)
        # Keep only references still used by some residue, in their original order
        reference_indices = sorted({ index for index in old_residue_reference_indices if index is not None })
        references_backmapping = { old_index: new_index for new_index, old_index in enumerate(reference_indices) }
        references = [ raw_references[old_index] for old_index in reference_indices ]
        raw_reference_types = topology.get('reference_types', None)
        if raw_reference_types:
            reference_types = [ raw_reference_types[old_index] for old_index in reference_indices ]
        references_backmapping[None] = None # To residues with no reference
        residue_reference_indices = [ references_backmapping[index] for index in old_residue_reference_indices ]

    # Set the filtered topology
    output_topology = {
        'atom_names': atom_names,
        'atom_elements': atom_elements,
        'atom_charges': atom_charges,
        'atom_residue_indices': atom_residue_indices,
        'atom_bonds': atom_bonds,
        'residue_names': residue_names,
        'residue_numbers': residue_numbers,
        'residue_icodes': residue_icodes,
        'residue_chain_indices': residue_chain_indices,
        'chain_names': chain_names,
        'references': references,
        'reference_types': reference_types,
        'residue_reference_indices': residue_reference_indices,
        'residue_reference_numbers': residue_reference_numbers,
        'pbc_residues': pbc_residues,
    }
    # WARNING: If the output topology is a symlink (e.g. to the input) writing would overwrite the origin
    # Remove it to avoid overwriting input data
    if output_topology_file.is_symlink(): output_topology_file.remove()
    # Write the new topology
    save_json(output_topology, output_topology_file.path)