Coverage for model_workflow/analyses/clusters.py: 96% (126 statements)

from os.path import exists

import numpy as np

import mdtraj as mdt

from model_workflow.utils.auxiliar import round_to_thousandths, save_json, otherwise
from model_workflow.utils.auxiliar import numerate_filename, get_analysis_name
from model_workflow.utils.auxiliar import reprint, delete_previous_log
from model_workflow.utils.constants import OUTPUT_CLUSTERS_FILENAME, OUTPUT_CLUSTER_SCREENSHOT_FILENAMES
from model_workflow.tools.get_screenshot import get_screenshot
from model_workflow.tools.get_reduced_trajectory import get_reduced_trajectory
from model_workflow.utils.type_hints import *

def clusters_analysis (
    structure_file : 'File',
    trajectory_file : 'File',
    interactions : list,
    structure : 'Structure',
    snapshots : int,
    pbc_selection : 'Selection',
    output_directory : str,
    # Set the maximum number of frames
    frames_limit : int = 1000,
    # Set the number of steps between the minimum and maximum RMSD, thus setting how many cutoffs are tried and how far apart they are
    n_steps : int = 100,
    # Set the final number of desired clusters
    desired_n_clusters : int = 20,
    # Set the atom selection for the overall clustering
    overall_selection : str = "name CA or name C5'",
):
32 """Run the cluster analysis."""
33 # The cluster analysis is run for the overall structure and then once more for every interaction
34 # We must set the atom selection of every run in atom indices, for MDtraj
35 runs = []
37 # Start with the overall selection
38 parsed_overall_selection = structure.select(overall_selection)
39 # If the default selection is empty then use all heavy atoms
40 if not parsed_overall_selection:
41 parsed_overall_selection = structure.select_heavy_atoms()
42 # Substract PBC atoms
43 # Note that this is essential since sudden jumps across boundaries would eclipse the actual clusters
44 parsed_overall_selection -= pbc_selection
45 if parsed_overall_selection:
46 runs.append({
47 'name': 'Overall',
48 'selection': parsed_overall_selection
49 })

    # Now set up the interaction runs
    for interaction in interactions:
        # Get the interface selection
        interface_residue_indices = interaction['interface_residue_indices_1'] \
            + interaction['interface_residue_indices_2']
        interface_selection = structure.select_residue_indices(interface_residue_indices)
        heavy_atoms_selection = structure.select_heavy_atoms()
        # Keep only heavy atoms for the distance calculation
        final_selection = interface_selection & heavy_atoms_selection
        # Subtract PBC atoms
        final_selection -= pbc_selection
        if final_selection:
            runs.append({
                'name': interaction['name'],
                'selection': final_selection
            })

    # If there are no runs at this point then stop here
    # This may happen when the whole system is in PBC
    if len(runs) == 0:
        print(' No clusters to analyze')
        return

    # If the number of trajectory frames exceeds the limit then create a reduced trajectory
    reduced_trajectory_filepath, step, frames = get_reduced_trajectory(
        structure_file,
        trajectory_file,
        snapshots,
        frames_limit,
    )
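    # Note that 'step' is the stride between frames kept in the reduced trajectory, so a reduced
    # frame index maps back to the original trajectory as reduced_index * step (this is how the
    # representative 'main' frames are reported further below)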

    # Set output filepaths
    output_analysis_filepath = f'{output_directory}/{OUTPUT_CLUSTERS_FILENAME}'
    output_screenshot_filepath = f'{output_directory}/{OUTPUT_CLUSTER_SCREENSHOT_FILENAMES}'

    # Load the whole reduced trajectory into memory
    traj = mdt.load(reduced_trajectory_filepath, top=structure_file.path)

    # Set the target number of clusters
    # This should be the desired number of clusters unless there are fewer frames than that
    target_n_clusters = min([ desired_n_clusters, snapshots ])

    # Copy the structure to further mutate its coordinates without affecting the original
    auxiliar_structure = structure.copy()

    # Set the final output, which is actually a summary used to find every run analysis
    output_summary = []

    # Now iterate over the different runs
    for r, run in enumerate(runs):
        # Set the output analysis filename from the input template
        # e.g. replica_1/mda.clusters_*.json -> replica_1/mda.clusters_01.json
        numbered_output_analysis_filepath = numerate_filename(output_analysis_filepath, r)
        # Get the run name
        name = run['name']
        # Get the analysis name from the numbered output analysis filename
        analysis_name = get_analysis_name(numbered_output_analysis_filepath)
        # Add this run to the final summary
        output_summary.append({
            'name': name,
            'analysis': analysis_name
        })
        # If the output file already exists then skip this iteration
        if exists(numbered_output_analysis_filepath):
            continue

        print(f'Calculating distances for {name} -> {analysis_name}')
        # Get the run selection atom indices
        atom_indices = run['selection'].atom_indices

        # Calculate the RMSD matrix
        distance_matrix = np.empty((traj.n_frames, traj.n_frames))
        for i in range(traj.n_frames):
            print(f' Frame {i+1} out of {traj.n_frames}', end='\r')
            # Calculate the RMSD between every frame in the trajectory and the frame 'i'
            distance_matrix[i] = mdt.rmsd(traj, traj, i, atom_indices=atom_indices)
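        # Note that mdt.rmsd superposes frames onto the reference by default and returns
        # distances in nanometers, so row 'i' holds the RMSD from frame 'i' to every other frame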

        # Get the maximum RMSD value in the whole matrix
        maximum_rmsd = np.max(distance_matrix)
        # Get the minimum RMSD value in the whole matrix
        # Discard 0s from frames against themselves
        minimum_rmsd = np.min(distance_matrix[distance_matrix != 0])

        # Set the difference between the maximum and minimum to determine the cutoffs step
        rmsd_difference = maximum_rmsd - minimum_rmsd
        rmsd_step = rmsd_difference / n_steps

        # Set the initial RMSD cutoff halfway between the minimum and maximum
        cutoff = round_to_thousandths(minimum_rmsd + rmsd_difference / 2)
        # Keep a register of already tried cutoffs so we do not repeat
        already_tried_cutoffs = set()

        # Adjust the RMSD cutoff until we get the desired number of clusters
        # Note that final clusters will be ordered by the time they appear
        clusters = None
        n_clusters = 0
        while n_clusters != target_n_clusters:
            # Find clusters
            print(f' Trying with cutoff {cutoff}', end='')
            clusters = clustering(distance_matrix, cutoff)
            n_clusters = len(clusters)
            print(f' -> Found {n_clusters} clusters')
            # Update the cutoff
            already_tried_cutoffs.add(cutoff)
            if n_clusters > target_n_clusters:
                cutoff = round_to_thousandths(cutoff + rmsd_step)
            if n_clusters < target_n_clusters:
                cutoff = round_to_thousandths(cutoff - rmsd_step)
            # If we already tried the updated cutoff then we are close enough to the desired number of clusters
            if cutoff in already_tried_cutoffs:
                break
            # Erase previous log and write in the same line
            delete_previous_log()
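        # Note that a larger cutoff merges more frames (fewer clusters) while a smaller cutoff
        # splits them apart, so this stepping search oscillates around the target count and
        # stops at the last attempt when no cutoff on the grid hits the target exactly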

        # Count the number of frames per cluster
        cluster_lengths = [ len(cluster) for cluster in clusters ]

        # Reorganize clusters into a "cluster per frame" structure
        frame_clusters = np.empty(traj.n_frames, dtype=int)
        for c, cluster in enumerate(clusters):
            for frame in cluster:
                frame_clusters[frame] = c

        # Count the transitions between clusters
        transitions = []

        # Iterate over the different frames
        previous_cluster = frame_clusters[0]
        for cluster in frame_clusters[1:]:
            # If this is the same cluster then there is no transition here
            if previous_cluster == cluster:
                continue
            # Otherwise save the transition
            transition = previous_cluster, cluster
            transitions.append(transition)
            previous_cluster = cluster

        print(f' Found {len(transitions)} transitions')

        # Count every different transition
        transition_counts = {}
        for transition in transitions:
            current_count = transition_counts.get(transition, 0)
            transition_counts[transition] = current_count + 1
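        # e.g. { (0, 1): 3, (1, 0): 2 } means 3 jumps from cluster 0 to cluster 1 and 2 jumps back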

        # Now for every cluster find the most representative frame (i.e. the one with the lowest total RMSD distance to its neighbours)
        # Then make a screenshot for this specific frame
        representative_frames = []
        # Save the screenshot parameters so we can keep images coherent between clusters
        screenshot_parameters = None
        print(' Generating cluster screenshots') # This will be reprinted
        for c, cluster in enumerate(clusters):
            most_representative_frame = None
            min_distance = float('inf') # Positive infinity
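            # Note that 'otherwise' presumably yields every frame in the cluster paired with the list of remaining frames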
            for frame, neighbour_frames in otherwise(cluster):
                # Calculate the sum of all RMSD distances
                total_distance = 0
                for neighbour_frame in neighbour_frames:
                    total_distance += distance_matrix[frame][neighbour_frame]
                # If the distance is lower than the current minimum then set this frame as the most representative
                if total_distance < min_distance:
                    most_representative_frame = frame
                    min_distance = total_distance
            # Save the most representative frame in the list
            representative_frames.append(most_representative_frame)
            # Once we have the most representative frame we take a screenshot
            # These screenshots will then be uploaded to the database as well
            # Generate coordinates from the most representative frame
            mdt_frame = traj[most_representative_frame]
            coordinates = mdt_frame.xyz[0] * 10 # We multiply by 10 to convert nanometers back to Ångstroms
            # WARNING: a PDB generated by MDtraj may have problems, thus leading to artifacts in the screenshot
            # WARNING: to avoid this we set the coordinates in our own structure instead
            auxiliar_structure.set_new_coordinates(coordinates)
            # Set the screenshot filename from the input template
            # Note that '*' is replaced by the run number and '??' by the cluster number
            screenshot_filename = output_screenshot_filepath.replace('*', str(r).zfill(2)).replace('??', str(c).zfill(2))
            # Generate the screenshot
            reprint(f' Generating cluster screenshot {c+1}/{n_clusters}')
            screenshot_parameters = get_screenshot(auxiliar_structure, screenshot_filename,
                parameters=screenshot_parameters)

        # Set the output clusters, which include all frames in the cluster and the main (most representative) frame
        output_clusters = []
        for frames, most_representative_frame in zip(clusters, representative_frames):
            output_clusters.append({ 'frames': frames, 'main': most_representative_frame * step })
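        # Note that 'frames' are indices in the reduced trajectory while 'main' is already mapped
        # back to the original trajectory; the 'step' field below presumably lets consumers map 'frames' as well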

        # Set the output transitions in a hashable and json parseable way
        output_transitions = []
        for transition, count in transition_counts.items():
            # Cast the numpy cluster indices to regular ints to make them json serializable
            output_transitions.append({ 'from': int(transition[0]), 'to': int(transition[1]), 'count': count })

        # Set the output analysis
        output_analysis = {
            'name': name,
            'cutoff': cutoff,
            'clusters': output_clusters,
            'transitions': output_transitions,
            'step': step,
            'version': '0.1.0',
        }

        # The output filename must be different for every run to avoid overwriting previous results
        # However the filename is not important regarding the database, since this analysis is found by its 'run'
        save_json(output_analysis, numbered_output_analysis_filepath)

    # Save the final summary
    save_json(output_summary, output_analysis_filepath)

# Set a function to cluster frames in an RMSD matrix given an RMSD cutoff
# https://github.com/boneta/RMSD-Clustering/blob/master/rmsd_clustering/clustering.py
def clustering (rmsd_matrix : np.ndarray, cutoff : float) -> list:
    """Greedily assign every frame to the first existing cluster whose members are
    all closer than the cutoff, or open a new cluster when no such cluster exists."""
    clusters = []
    for i in range(rmsd_matrix.shape[0]):
        for cluster in clusters:
            if all(rmsd_matrix[i,j] < cutoff for j in cluster):
                cluster.append(i)
                break
        else:
            clusters.append([i])
    return clusters
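
# A minimal sketch of the greedy behaviour above on a toy 3x3 matrix (illustrative only):
#   toy = np.array([[0.0, 0.1, 0.9],
#                   [0.1, 0.0, 0.9],
#                   [0.9, 0.9, 0.0]])
#   clustering(toy, cutoff=0.5)  # -> [[0, 1], [2]]
# Frames 0 and 1 are within the cutoff of each other so they share a cluster, while frame 2 is not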