Coverage for mddb_workflow/analyses/clusters.py: 95% (129 statements)
from os.path import exists
import numpy as np
import mdtraj as mdt

from mddb_workflow.utils.auxiliar import round_to_thousandths, save_json, otherwise
from mddb_workflow.utils.auxiliar import numerate_filename, get_analysis_name
from mddb_workflow.utils.auxiliar import reprint, delete_previous_log
from mddb_workflow.utils.constants import OUTPUT_CLUSTERS_FILENAME, OUTPUT_CLUSTER_SCREENSHOT_FILENAMES
from mddb_workflow.utils.file import File
from mddb_workflow.tools.get_screenshot import get_screenshot
from mddb_workflow.tools.get_reduced_trajectory import get_reduced_trajectory
from mddb_workflow.utils.type_hints import *


def clusters_analysis(
    structure_file: 'File',
    trajectory_file: 'File',
    interactions: list,
    structure: 'Structure',
    snapshots: int,
    pbc_selection: 'Selection',
    output_directory: str,
    # Set the maximum number of frames
    frames_limit: int = 1000,
    # Set the number of steps between the minimum and maximum RMSD,
    # thus setting how many cutoffs are tried and how far apart they are
    n_steps: int = 100,
    # Set the final number of desired clusters
    desired_n_clusters: int = 20,
    # Set the atom selection for the overall clustering
    overall_selection: str = "name CA or name C5'",
):
32 """Run the cluster analysis."""
33 # The cluster analysis is run for the overall structure and then once more for every interaction
34 # We must set the atom selection of every run in atom indices, for MDtraj
35 runs = []
37 # Start with the overall selection
38 parsed_overall_selection = structure.select(overall_selection)
39 # If the default selection is empty then use all heavy atoms
40 if not parsed_overall_selection:
41 parsed_overall_selection = structure.select_heavy_atoms()
42 # Substract PBC atoms
43 # Note that this is essential since sudden jumps across boundaries would eclipse the actual clusters
44 parsed_overall_selection -= pbc_selection
45 if parsed_overall_selection:
46 runs.append({
47 'name': 'Overall',
48 'selection': parsed_overall_selection
49 })
    # Now set up the interaction runs
    for interaction in interactions:
        # Get the interface selection
        interface_residue_indices = interaction['interface_residue_indices_1'] \
            + interaction['interface_residue_indices_2']
        interface_selection = structure.select_residue_indices(interface_residue_indices)
        heavy_atoms_selection = structure.select_heavy_atoms()
        # Keep only heavy atoms for the distance calculation
        final_selection = interface_selection & heavy_atoms_selection
        # Subtract PBC atoms
        final_selection -= pbc_selection
        if final_selection:
            runs.append({
                'name': interaction['name'],
                'selection': final_selection
            })
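
    # Note that Selection objects appear to support set-like algebra in the setups above:
    # '&' intersects two selections and '-=' subtracts the atoms of another selection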

    # If there are no runs at this point then stop here
    # This may happen when the whole system is in PBC
    if len(runs) == 0:
        print(' No clusters to analyze')
        return
    # If the number of trajectory frames is bigger than the limit then create a reduced trajectory
    reduced_trajectory_filepath, step, frames = get_reduced_trajectory(
        structure_file,
        trajectory_file,
        snapshots,
        frames_limit,
    )
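    # Note that 'step' is presumably the stride between reduced frames, so reduced frame 'f'
    # maps back to frame 'f * step' of the original trajectory (used below for the 'main' frame)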

    # Set output filepaths
    output_analysis_filepath = f'{output_directory}/{OUTPUT_CLUSTERS_FILENAME}'
    output_screenshot_filepath = f'{output_directory}/{OUTPUT_CLUSTER_SCREENSHOT_FILENAMES}'

    # Load the whole trajectory
    traj = mdt.load(reduced_trajectory_filepath, top=structure_file.path)

    # Set the target number of clusters
    # This should be the desired number of clusters, unless there are fewer frames than that
    target_n_clusters = min([desired_n_clusters, snapshots])

    # Copy the structure so we can mutate its coordinates without affecting the original
    auxiliar_structure = structure.copy()

    # Set the final analysis, which is actually a summary used to find every run
    output_summary = []

    # Now iterate over the different runs
    for r, run in enumerate(runs):
        # Set the output analysis filename from the input template
        # e.g. replica_1/mda.clusters_*.json -> replica_1/mda.clusters_01.json
        numbered_output_analysis_filepath = numerate_filename(output_analysis_filepath, r)
        # Get the run name
        name = run['name']
        # Get the root of the output analysis filename, which identifies this analysis
        analysis_name = get_analysis_name(numbered_output_analysis_filepath)
        # Add this run to the final summary
        output_summary.append({
            'name': name,
            'analysis': analysis_name
        })
        # If the output file already exists then skip this iteration
        if exists(numbered_output_analysis_filepath):
            continue

        print(f'Calculating distances for {name} -> {analysis_name}')
        # Get the run selection atom indices
        atom_indices = run['selection'].atom_indices

        # Calculate the RMSD matrix
        distance_matrix = np.empty((traj.n_frames, traj.n_frames))
        for i in range(traj.n_frames):
            print(f' Frame {i+1} out of {traj.n_frames}', end='\r')
            # Calculate the RMSD between every frame in the trajectory and frame 'i'
            distance_matrix[i] = mdt.rmsd(traj, traj, i, atom_indices=atom_indices)
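        # Note that mdt.rmsd returns, for every frame of its first argument, the RMSD against
        # frame 'i' of the second argument after optimal superposition, so every iteration
        # above fills one full row of the matrix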

        # Get the maximum RMSD value in the whole matrix
        maximum_rmsd = np.max(distance_matrix)
        # Get the minimum RMSD value in the whole matrix
        # Discard 0s from frames against themselves
        minimum_rmsd = np.min(distance_matrix[distance_matrix != 0])

        # Set the difference between the minimum and maximum to determine the cutoffs step
        rmsd_difference = maximum_rmsd - minimum_rmsd
        rmsd_step = rmsd_difference / n_steps
        # Since cutoffs are then rounded to thousandths, make sure the same cutoff is not tried twice
        if rmsd_step < 0.001: rmsd_step = 0.001

        # Set the initial RMSD cutoff
        cutoff = round_to_thousandths(minimum_rmsd + rmsd_difference / 2)
        # Keep a register of already tried cutoffs so we do not repeat
        already_tried_cutoffs = set()
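
        # For instance (made-up values), with a minimum RMSD of 0.2 and a maximum of 1.2
        # and n_steps = 100, the search below starts at a cutoff of 0.7 and then moves up
        # or down in steps of 0.01 until it hits the target number of clusters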
        # Adjust the RMSD cutoff until we get the desired number of clusters
        # Note that final clusters will be ordered by the time they appear
        clusters = None
        n_clusters = 0
        while n_clusters != target_n_clusters:
            # Find clusters
            print(f' Trying with cutoff {cutoff}', end='')
            clusters = clustering(distance_matrix, cutoff)
            n_clusters = len(clusters)
            print(f' -> Found {n_clusters} clusters')
            # Update the cutoff
            already_tried_cutoffs.add(cutoff)
            if n_clusters > target_n_clusters:
                cutoff = round_to_thousandths(cutoff + rmsd_step)
            if n_clusters < target_n_clusters:
                cutoff = round_to_thousandths(cutoff - rmsd_step)
            # If we already tried the updated cutoff then we are as close as we will get to the desired number of clusters
            if cutoff in already_tried_cutoffs:
                break
            # Erase the previous log so we write in the same line
            delete_previous_log()

        # Count the number of frames per cluster
        cluster_lengths = [len(cluster) for cluster in clusters]

        # Rearrange clusters into a "cluster per frame" structure
        frame_clusters = np.empty(traj.n_frames, dtype=int)
        for c, cluster in enumerate(clusters):
            for frame in cluster:
                frame_clusters[frame] = c

        # Count the transitions between clusters
        transitions = []

        # Iterate over the different frames
        previous_cluster = frame_clusters[0]
        for cluster in frame_clusters[1:]:
            # If this is the same cluster then there is no transition here
            if previous_cluster == cluster:
                continue
            # Otherwise save the transition
            transition = previous_cluster, cluster
            transitions.append(transition)
            previous_cluster = cluster

        print(f' Found {len(transitions)} transitions')

        # Count every different transition
        transition_counts = {}
        for transition in transitions:
            current_count = transition_counts.get(transition, 0)
            transition_counts[transition] = current_count + 1
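        # Note that this explicit count is equivalent to collections.Counter(transitions)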

        # Now for every cluster find the most representative frame (i.e. the one with the least RMSD distance to its neighbours)
        # Then take a screenshot of this specific frame
        representative_frames = []
        # Save the screenshot parameters so we can keep images coherent between clusters
        screenshot_parameters = None
        print(' Generating cluster screenshots') # This will be reprinted
        for c, cluster in enumerate(clusters):
            most_representative_frame = None
            min_distance = float('inf') # Positive infinity
            for frame, neighbour_frames in otherwise(cluster):
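                # Note that 'otherwise' (from auxiliar) presumably yields every frame in the
                # cluster paired with the rest of the frames in that same cluster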
                # Calculate the sum of all RMSD distances
                total_distance = 0
                for neighbour_frame in neighbour_frames:
                    total_distance += distance_matrix[frame][neighbour_frame]
                # If the distance is lower than the current minimum then set this frame as the most representative
                if total_distance < min_distance:
                    most_representative_frame = frame
                    min_distance = total_distance
            # Save the most representative frame in the list
            representative_frames.append(most_representative_frame)
            # Once we have the most representative frame we take a screenshot
            # These screenshots will then be uploaded to the database as well
            # Generate a pdb with coordinates from the most representative frame
            mdt_frame = traj[most_representative_frame]
            coordinates = mdt_frame.xyz[0] * 10 # We multiply by 10 to restore Ångstroms, since MDtraj works in nanometers
            # WARNING: a PDB generated by MDtraj may have problems thus leading to artifacts in the screenshot
            # WARNING: to avoid this we add the coordinates to the structure instead
            # coordinates.save(AUXILIAR_PDB_FILENAME)
            auxiliar_structure.set_new_coordinates(coordinates)
            # Set the screenshot filename from the input template
            screenshot_filepath = output_screenshot_filepath.replace('*', str(r).zfill(2)).replace('??', str(c).zfill(2))
            screenshot_file = File(screenshot_filepath)
            # Generate the screenshot
            reprint(f' Generating cluster screenshot {c+1}/{n_clusters}')
            screenshot_parameters = get_screenshot(auxiliar_structure, screenshot_file,
                                                   parameters=screenshot_parameters)
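            # Note that feeding the returned parameters back into get_screenshot presumably
            # reuses the same camera setup, so all cluster screenshots share one orientation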

        # Set the output clusters, which include all frames in the cluster and the main (most representative) frame
        output_clusters = []
        for frames, most_representative_frame in zip(clusters, representative_frames):
            output_clusters.append({'frames': frames, 'main': most_representative_frame * step})

        # Set the output transitions in a hashable and JSON-parseable way
        output_transitions = []
        for transition, count in transition_counts.items():
            # Convert numpy integers to regular ints to make them JSON-serializable
            output_transitions.append({'from': int(transition[0]), 'to': int(transition[1]), 'count': count})

        # Set the output analysis
        output_analysis = {
            'name': name,
            'cutoff': cutoff,
            'clusters': output_clusters,
            'transitions': output_transitions,
            'step': step,
            'version': '0.1.0',
        }

        # The output filename must be different for every run to avoid overwriting previous results
        # However the filename does not matter to the database, since this analysis is found by its 'run'
        save_json(output_analysis, numbered_output_analysis_filepath)

    # Save the final summary
    save_json(output_summary, output_analysis_filepath)

# Set a function to cluster frames in an RMSD matrix given an RMSD cutoff
# https://github.com/boneta/RMSD-Clustering/blob/master/rmsd_clustering/clustering.py
def clustering(rmsd_matrix: np.ndarray, cutoff: float) -> list:
    clusters = []
    for i in range(rmsd_matrix.shape[0]):
        # Join frame 'i' to the first existing cluster where it is within the cutoff of every member
        for cluster in clusters:
            if all(rmsd_matrix[i, j] < cutoff for j in cluster):
                cluster.append(i)
                break
        # If no existing cluster admits frame 'i' then the loop ends without a 'break'
        # and the 'else' clause opens a new cluster for it
        else:
            clusters.append([i])
    return clusters
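
# A minimal sketch of how 'clustering' behaves, using a made-up 3-frame RMSD matrix:
# frames 0 and 1 are close (RMSD 0.1) while frame 2 is far from both, so a cutoff
# of 0.5 should yield two clusters, ordered by the time their first frame appears
#   >>> matrix = np.array([[0.0, 0.1, 0.9],
#   ...                    [0.1, 0.0, 0.8],
#   ...                    [0.9, 0.8, 0.0]])
#   >>> clustering(matrix, cutoff=0.5)
#   [[0, 1], [2]]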