Coverage for model_workflow/tools/filter_atoms.py: 38%
133 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-23 10:54 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-23 10:54 +0000
1from os import remove
2from os.path import exists
3from subprocess import run, PIPE, Popen
4import json
6import pytraj as pt
8from model_workflow.utils.constants import STANDARD_TOPOLOGY_FILENAME, RAW_CHARGES_FILENAME
9from model_workflow.utils.structures import Structure
10from model_workflow.utils.auxiliar import save_json, MISSING_TOPOLOGY
11from model_workflow.utils.gmx_spells import get_tpr_atom_count, tpr_filter, xtc_filter, pdb_filter
12from model_workflow.utils.type_hints import *
13from model_workflow.tools.get_charges import get_raw_charges
# Name of the temporary GROMACS index (.ndx) file written while filtering
index_filename: str = 'filter.ndx'
# Group name passed, together with the index file, to the GROMACS filtering helpers
filter_group_name: str = "not_water_or_counter_ions"
# Filter atoms of all input files (structure, trajectory and topology) by removing water and counter ions
# As an exception, a custom selection may be provided to choose which atoms are removed instead
# At the end, all topologies must match in atoms count
def filter_atoms (
    input_structure_file : 'File',
    input_trajectory_file : 'File',
    input_topology_file : Union['File', Exception],
    output_structure_file : 'File',
    output_trajectory_file : 'File',
    output_topology_file : Union['File', Exception],
    # Reference structure used to parse the actual selection
    reference_structure : 'Structure',
    # Filter selection may be a custom selection or true
    # If true then we run a default filtering of water and counter ions
    filter_selection : Union[bool, str],
    filter_selection_syntax : str = 'vmd',
):
    """Remove the selected atoms from structure, trajectory and topology.

    If `filter_selection` is falsy nothing is done. If it is True, the atoms returned by
    `reference_structure.select_water_and_counter_ions()` are removed; otherwise
    `filter_selection` is parsed with `filter_selection_syntax` and the matching atoms are removed.

    Structure and trajectory are only rewritten (through Gromacs) when the selection removes
    atoms. The topology is filtered according to its format: pytraj-supported formats, tpr,
    the standard topology JSON, or raw charges (which are only checked). Every output which
    did not need to be written is created as a symlink to its input.

    Raises:
        ValueError: if a tpr topology has more atoms than the filtered structure but no index
            file could be generated, if the topology format is not supported, or if the filtered
            structure and topology atom counts do not match.
    """

    # Handle missing filter selection
    if not filter_selection:
        return

    # Parse the selection of atoms to be filtered out
    # WARNING: Note that the structure is not corrected at this point and there may be limitations
    if filter_selection is True:
        parsed_filter_selection = reference_structure.select_water_and_counter_ions()
    else:
        parsed_filter_selection = reference_structure.select(filter_selection, syntax=filter_selection_syntax)

    # Invert the parsed selection to get the atoms to remain
    keep_selection = reference_structure.invert_selection(parsed_filter_selection)

    # Set the pytraj mask to filter the desired atoms from the structure
    filter_mask = keep_selection.to_pytraj()

    # Load the structure and trajectory
    trajectory = pt.iterload(input_trajectory_file.path, input_structure_file.path)
    pt_topology = trajectory.topology
    atoms_count = pt_topology.n_atoms

    # Count the atoms which remain after filtering the structure
    filtered_structure_atoms_count = pt_topology[filter_mask].n_atoms

    # Check if both the normal and the filtered topologies have the same number of atoms
    # If not, filter the whole trajectory and overwrite both topologies and trajectory
    print(f'Total number of atoms: {atoms_count}')
    print(f'Filtered number of atoms: {filtered_structure_atoms_count}')

    # The index file is temporary: remove it even when an error is raised
    # Otherwise a stale file would be taken as a valid index by the tpr branch in a later run
    try:
        if filtered_structure_atoms_count < atoms_count:
            print('Filtering structure and trajectory...')
            # Filter both structure and trajectory using Gromacs, since it is more efficient than pytraj
            # Set up an index file with all atom indices manually
            # As long as indices are for atoms and not residues there should never be any incompatibility
            keep_selection.to_ndx_file(output_filepath = index_filename)
            # Filter the trajectory
            xtc_filter(input_structure_file.path,
                input_trajectory_file.path,
                output_trajectory_file.path,
                index_filename,
                filter_group_name)
            # Filter the structure
            pdb_filter(
                input_structure_file.path,
                output_structure_file.path,
                index_filename,
                filter_group_name)

        # Filter topology according to the file format
        if input_topology_file != MISSING_TOPOLOGY and input_topology_file.exists:
            print('Filtering topology...')
            # Pytraj supported formats
            if input_topology_file.is_pytraj_supported():
                # Load the topology and count its atoms
                pt_topology = pt.load_topology(filename=input_topology_file.path)
                topology_atoms_count = pt_topology.n_atoms
                print(f'Topology atoms count: {topology_atoms_count}')
                # Filter the desired atoms using the mask and then count them
                filtered_pt_topology = pt_topology[filter_mask]
                filtered_topology_atoms_count = filtered_pt_topology.n_atoms
                # If there is a difference in atom counts then write the filtered topology
                if filtered_topology_atoms_count < topology_atoms_count:
                    # WARNING: If the output topology is a symlink it will try to overwrite the origin
                    # Remove it to avoid overwriting input data
                    if output_topology_file.is_symlink(): output_topology_file.remove()
                    # Now write the filtered topology
                    pt.write_parm(
                        filename=output_topology_file.path,
                        top=filtered_pt_topology,
                        format=input_topology_file.get_pytraj_parm_format(),
                        overwrite=True
                    )
            # Gromacs format
            elif input_topology_file.format == 'tpr':
                # Get the input tpr atom count
                topology_atoms_count = get_tpr_atom_count(input_topology_file.path)
                print(f'Topology atoms count: {topology_atoms_count}')
                # If the number of atoms is greater than expected then filter the tpr file
                if topology_atoms_count > filtered_structure_atoms_count:
                    if not exists(index_filename):
                        # In order to filter the tpr we need the filter.ndx file
                        # This must be generated from a pytraj supported topology that matches the number of atoms in the tpr file
                        raise ValueError('Topology atoms number does not match the structure atoms number and tpr files can not be filtered alone')
                    tpr_filter(
                        input_topology_file.path,
                        output_topology_file.path,
                        index_filename,
                        filter_group_name)
                    # Get the output tpr atom count
                    filtered_topology_atoms_count = get_tpr_atom_count(output_topology_file.path)
                else:
                    filtered_topology_atoms_count = topology_atoms_count
            # Standard topology
            elif input_topology_file.filename == STANDARD_TOPOLOGY_FILENAME:
                with open(input_topology_file.path, 'r') as file:
                    standard_topology = json.load(file)
                topology_atoms_count = len(standard_topology['atom_names'])
                print(f'Topology atoms count: {topology_atoms_count}')
                # Make it match since there is no problem when these 2 do not match
                filtered_topology_atoms_count = filtered_structure_atoms_count
                # If the number of atoms does not match the filtered structure then filter the topology
                if topology_atoms_count != filtered_structure_atoms_count:
                    # standard_topology_filter keeps exactly the atoms of the selection it receives,
                    # so it must be given the atoms to keep, not the atoms to remove
                    standard_topology_filter(input_topology_file, reference_structure, keep_selection, output_topology_file)
            # Raw charges
            elif input_topology_file.filename == RAW_CHARGES_FILENAME:
                charges = get_raw_charges(input_topology_file.path)
                # Nothing to do here. It better matches by default or we have a problem
                filtered_topology_atoms_count = len(charges)
                print(f'Topology atoms count: {filtered_topology_atoms_count}')
            else:
                raise ValueError(f'Topology file ({input_topology_file.filename}) is in a non supported format')

            print(f'Filtered topology atoms: {filtered_topology_atoms_count}')

            # Both filtered structure and topology must have the same number of atoms
            if filtered_structure_atoms_count != filtered_topology_atoms_count:
                print(f'Filtered structure atoms: {filtered_structure_atoms_count}')
                raise ValueError('Filtered atom counts in topology and charges does not match')
    finally:
        # Remove the index file in case it was created
        if exists(index_filename):
            remove(index_filename)

    # Check if any of the output files does not exist
    # If so, then it means there was nothing to filter
    # However the output file is expected, so we make a symlink
    if not output_structure_file.exists:
        output_structure_file.set_symlink_to(input_structure_file)
    if not output_trajectory_file.exists:
        output_trajectory_file.set_symlink_to(input_trajectory_file)
    if output_topology_file != MISSING_TOPOLOGY and not output_topology_file.exists:
        output_topology_file.set_symlink_to(input_topology_file)
186 atom_indices = parsed_filter_selection.atom_indices
187 residue_indices = reference_structure.get_selection_residue_indices(parsed_filter_selection)
188 chain_indices = reference_structure.get_selection_chain_indices(parsed_filter_selection)
190 # Set backmapping
191 atom_backmapping = { old_index: new_index for new_index, old_index in enumerate(atom_indices) }
192 residue_backmapping = { old_index: new_index for new_index, old_index in enumerate(residue_indices) }
193 chain_backmapping = { old_index: new_index for new_index, old_index in enumerate(chain_indices) }
195 # Set a function to get substract specific values of a list given by its indices
196 def filter_by_indices (values : list, indices : List[int]) -> list:
197 if values == None:
198 return None
199 return [ values[i] for i in indices ]
201 # Filter atomwise fields
202 atom_names = filter_by_indices(topology['atom_names'], atom_indices)
203 atom_elements = filter_by_indices(topology['atom_elements'], atom_indices)
204 atom_charges = filter_by_indices(topology['atom_charges'], atom_indices)
206 # Handle atom fields which require backmapping
207 old_atom_residue_indices = filter_by_indices(topology['atom_residue_indices'], atom_indices)
208 atom_residue_indices = [ residue_backmapping[index] for index in old_atom_residue_indices ]
209 atom_bonds = None
210 raw_atom_bonds = topology.get('atom_bonds', None)
211 if raw_atom_bonds:
212 old_atom_bonds = filter_by_indices(raw_atom_bonds, atom_indices)
213 atom_bonds = [ [ atom_backmapping(bond) for bond in bonds ] for bonds in old_atom_bonds ]
215 # Filter residuewise fields
216 residue_names = filter_by_indices(topology['residue_names'], residue_indices)
217 residue_numbers = filter_by_indices(topology['residue_numbers'], residue_indices)
219 # Handle residue fields which require backmapping
220 residue_indices_set = set(residue_indices)
221 old_residue_chain_indices = filter_by_indices(topology['residue_chain_indices'], residue_indices)
222 residue_chain_indices = [ chain_backmapping[index] for index in old_residue_chain_indices ]
223 # Handle icodes if they exist
224 residue_icodes = None
225 raw_residue_icodes = topology['residue_icodes']
226 if raw_residue_icodes:
227 # Filter icodes
228 old_residue_icodes = { index: icode for index, icode in raw_residue_icodes.items if index in residue_indices_set }
229 # Backmap icodes
230 residue_icodes = { residue_backmapping(index): icode for index, icode in old_residue_icodes.items() }
231 # Handle PBC residues
232 pbc_residues = [ residue_backmapping(index) for index in topology.get('pbc_residues', []) if index in residue_indices_set ]
234 # Handle chainwise fields
235 chain_names = filter_by_indices(topology['chain_names'], chain_indices)
237 # Handle references
238 references = None
239 reference_types = None
240 residue_reference_indices = None
241 residue_reference_numbers = None
242 # If they exist
243 raw_references = topology['references']
244 if raw_references:
245 residue_reference_numbers = filter_by_indices(topology['residue_reference_numbers'], residue_indices)
246 old_residue_reference_indices = filter_by_indices(topology['residue_reference_indices'], residue_indices)
247 reference_indices = [ index for index in set(old_residue_reference_indices) if index != None ]
248 references_backmapping = { old_index: new_index for new_index, old_index in enumerate(reference_indices) }
249 references = [ raw_references[old_index] for old_index in references_backmapping.keys() ]
250 raw_reference_types = topology.get('reference_types', None)
251 if raw_reference_types:
252 reference_types = [ raw_reference_types[old_index] for old_index in references_backmapping.keys() ]
253 references_backmapping[None] = None # To residues with no reference
254 residue_reference_indices = [ references_backmapping[index] for index in old_residue_reference_indices ]
256 # Set the filtered topology
257 output_topology = {
258 'atom_names': atom_names,
259 'atom_elements': atom_elements,
260 'atom_charges': atom_charges,
261 'atom_residue_indices': atom_residue_indices,
262 'atom_bonds': atom_bonds,
263 'residue_names': residue_names,
264 'residue_numbers': residue_numbers,
265 'residue_icodes': residue_icodes,
266 'residue_chain_indices': residue_chain_indices,
267 'chain_names': chain_names,
268 'references': references,
269 'reference_types': reference_types,
270 'residue_reference_indices': residue_reference_indices,
271 'residue_reference_numbers': residue_reference_numbers,
272 'pbc_residues': pbc_residues,
273 }
274 # Wrtie the new topology
275 save_json(output_topology, output_topology_file.path)