Coverage for mddb_workflow/analyses/clusters.py: 95% (129 statements)


from os.path import exists
import numpy as np
import mdtraj as mdt

from mddb_workflow.utils.auxiliar import round_to_thousandths, save_json, otherwise
from mddb_workflow.utils.auxiliar import numerate_filename, get_analysis_name
from mddb_workflow.utils.auxiliar import reprint, delete_previous_log
from mddb_workflow.utils.constants import OUTPUT_CLUSTERS_FILENAME, OUTPUT_CLUSTER_SCREENSHOT_FILENAMES
from mddb_workflow.utils.file import File
from mddb_workflow.tools.get_screenshot import get_screenshot
from mddb_workflow.tools.get_reduced_trajectory import get_reduced_trajectory
from mddb_workflow.utils.type_hints import *


def clusters_analysis(
    structure_file: 'File',
    trajectory_file: 'File',
    interactions: list,
    structure: 'Structure',
    snapshots: int,
    pbc_selection: 'Selection',
    output_directory: str,
    # Set the maximum number of frames
    frames_limit: int = 1000,
    # Set the number of steps between the minimum and maximum RMSD, which
    # determines how many cutoffs are tried and how far apart they are
    n_steps: int = 100,
    # Set the final number of desired clusters
    desired_n_clusters: int = 20,
    # Set the atom selection for the overall clustering
    overall_selection: str = "name CA or name C5'",
):

32 """Run the cluster analysis.""" 

33 # The cluster analysis is run for the overall structure and then once more for every interaction 

34 # We must set the atom selection of every run in atom indices, for MDtraj 

35 runs = [] 

36 

37 # Start with the overall selection 

38 parsed_overall_selection = structure.select(overall_selection) 

39 # If the default selection is empty then use all heavy atoms 

40 if not parsed_overall_selection: 

41 parsed_overall_selection = structure.select_heavy_atoms() 

42 # Substract PBC atoms 

43 # Note that this is essential since sudden jumps across boundaries would eclipse the actual clusters 

44 parsed_overall_selection -= pbc_selection 

45 if parsed_overall_selection: 

46 runs.append({ 

47 'name': 'Overall', 

48 'selection': parsed_overall_selection 

49 }) 

    # Now set up the interaction runs
    for interaction in interactions:
        # Get the interface selection
        interface_residue_indices = interaction['interface_residue_indices_1'] \
            + interaction['interface_residue_indices_2']
        interface_selection = structure.select_residue_indices(interface_residue_indices)
        heavy_atoms_selection = structure.select_heavy_atoms()
        # Keep only heavy atoms for the distance calculation
        final_selection = interface_selection & heavy_atoms_selection
        # Subtract PBC atoms
        final_selection -= pbc_selection
        if final_selection:
            runs.append({
                'name': interaction['name'],
                'selection': final_selection
            })

    # If there are no runs at this point then stop here
    # This may happen when the whole system is in PBC
    if len(runs) == 0:
        print(' No clusters to analyze')
        return

    # If the number of trajectory frames exceeds the limit then we create a reduced trajectory
    reduced_trajectory_filepath, step, frames = get_reduced_trajectory(
        structure_file,
        trajectory_file,
        snapshots,
        frames_limit,
    )

    # Set output filepaths
    output_analysis_filepath = f'{output_directory}/{OUTPUT_CLUSTERS_FILENAME}'
    output_screenshot_filepath = f'{output_directory}/{OUTPUT_CLUSTER_SCREENSHOT_FILENAMES}'

    # Load the whole trajectory
    traj = mdt.load(reduced_trajectory_filepath, top=structure_file.path)

    # Set the target number of clusters
    # This should be the desired number of clusters unless there are fewer frames than that
    target_n_clusters = min([desired_n_clusters, snapshots])

    # Copy the structure to further mutate its coordinates without affecting the original
    auxiliar_structure = structure.copy()

    # Set the final analysis, which is actually a summary used to find every run
    output_summary = []

    # Now iterate over the different runs
    for r, run in enumerate(runs):
        # Set the output analysis filename from the input template
        # e.g. replica_1/mda.clusters_*.json -> replica_1/mda.clusters_01.json
        numbered_output_analysis_filepath = numerate_filename(output_analysis_filepath, r)
        # Get the run name
        name = run['name']
        # Get the root of the output analysis filename
        analysis_name = get_analysis_name(numbered_output_analysis_filepath)
        # Add this run to the final summary
        output_summary.append({
            'name': name,
            'analysis': analysis_name
        })
        # If the output file already exists then skip this iteration
        if exists(numbered_output_analysis_filepath):
            continue

        print(f'Calculating distances for {name} -> {analysis_name}')
        # Get the run selection atom indices
        atom_indices = run['selection'].atom_indices

        # Calculate the RMSD matrix
        distance_matrix = np.empty((traj.n_frames, traj.n_frames))
        for i in range(traj.n_frames):
            print(f' Frame {i+1} out of {traj.n_frames}', end='\r')
            # Calculate the RMSD between every frame in the trajectory and the frame 'i'
            distance_matrix[i] = mdt.rmsd(traj, traj, i, atom_indices=atom_indices)
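        # Note: mdt.rmsd optimally superposes every frame onto frame 'i' before
        # measuring, so the matrix should be symmetric up to numerical noise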

        # Get the maximum RMSD value in the whole matrix
        maximum_rmsd = np.max(distance_matrix)
        # Get the minimum RMSD value in the whole matrix
        # Discard 0s from frames against themselves
        minimum_rmsd = np.min(distance_matrix[distance_matrix != 0])

        # Use the difference between the minimum and maximum RMSD to determine the cutoff step
        rmsd_difference = maximum_rmsd - minimum_rmsd
        rmsd_step = rmsd_difference / n_steps
        # Since cutoffs are rounded to thousandths, make sure the step is not so small that a cutoff repeats
        if rmsd_step < 0.001: rmsd_step = 0.001

        # Set the initial RMSD cutoff
        cutoff = round_to_thousandths(minimum_rmsd + rmsd_difference / 2)
        # Keep a register of already tried cutoffs so we do not repeat any
        already_tried_cutoffs = set()

        # Adjust the RMSD cutoff until we get the desired number of clusters
        # Note that final clusters will be ordered by the time they appear
        clusters = None
        n_clusters = 0
        while n_clusters != target_n_clusters:
            # Find clusters
            print(f' Trying with cutoff {cutoff}', end='')
            clusters = clustering(distance_matrix, cutoff)
            n_clusters = len(clusters)
            print(f' -> Found {n_clusters} clusters')
            # Update the cutoff
            # A larger cutoff merges more frames together and thus yields fewer clusters
            already_tried_cutoffs.add(cutoff)
            if n_clusters > target_n_clusters:
                cutoff = round_to_thousandths(cutoff + rmsd_step)
            if n_clusters < target_n_clusters:
                cutoff = round_to_thousandths(cutoff - rmsd_step)
            # If we already tried the updated cutoff then we are as close to the desired number of clusters as we can get
            if cutoff in already_tried_cutoffs:
                break
            # Erase the previous log and write in the same line
            delete_previous_log()

        # Count the number of frames per cluster
        cluster_lengths = [len(cluster) for cluster in clusters]

        # Rearrange clusters into a "cluster per frame" structure
        frame_clusters = np.empty(traj.n_frames, dtype=int)
        for c, cluster in enumerate(clusters):
            for frame in cluster:
                frame_clusters[frame] = c
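        # At this point every frame maps to its cluster index:
        # e.g. clusters [[0, 1, 3], [2]] yield frame_clusters [0, 0, 1, 0]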

        # Count the transitions between clusters
        transitions = []

        # Iterate over the different frames
        previous_cluster = frame_clusters[0]
        for cluster in frame_clusters[1:]:
            # If this is the same cluster then there is no transition here
            if previous_cluster == cluster:
                continue
            # Otherwise save the transition
            transition = previous_cluster, cluster
            transitions.append(transition)
            previous_cluster = cluster

        print(f' Found {len(transitions)} transitions')

        # Count every different transition
        transition_counts = {}
        for transition in transitions:
            current_count = transition_counts.get(transition, 0)
            transition_counts[transition] = current_count + 1
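        # This is equivalent to collections.Counter(transitions)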

        # Now for every cluster find the most representative frame (i.e. the one with the least RMSD distance to its neighbours)
        # Then make a screenshot of this specific frame
        representative_frames = []
        # Save the screenshot parameters so we can keep images coherent between clusters
        screenshot_parameters = None
        print(' Generating cluster screenshots')  # This will be reprinted
        for c, cluster in enumerate(clusters):
            most_representative_frame = None
            min_distance = float('inf')  # Positive infinity
            # Iterate over every frame in the cluster paired with the rest of its frames
            for frame, neighbour_frames in otherwise(cluster):
                # Calculate the sum of all RMSD distances
                total_distance = 0
                for neighbour_frame in neighbour_frames:
                    total_distance += distance_matrix[frame][neighbour_frame]
                # If the distance is lower than the current minimum then set this frame as the most representative
                if total_distance < min_distance:
                    most_representative_frame = frame
                    min_distance = total_distance
            # Save the most representative frame in the list
            representative_frames.append(most_representative_frame)
            # Once we have the most representative frame we take a screenshot
            # These screenshots will then be uploaded to the database as well
            # Generate a pdb with coordinates from the most representative frame
            mdt_frame = traj[most_representative_frame]
            coordinates = mdt_frame.xyz[0] * 10  # Multiply by 10 to restore Ångstroms, since MDtraj works in nanometers
            # WARNING: a PDB generated by MDtraj may have problems, thus leading to artifacts in the screenshot
            # WARNING: to avoid this we add the coordinates to the structure
            # coordinates.save(AUXILIAR_PDB_FILENAME)
            auxiliar_structure.set_new_coordinates(coordinates)
            # Set the screenshot filename from the input template
            screenshot_filepath = output_screenshot_filepath.replace('*', str(r).zfill(2)).replace('??', str(c).zfill(2))
            screenshot_file = File(screenshot_filepath)
            # Generate the screenshot
            reprint(f' Generating cluster screenshot {c+1}/{n_clusters}')
            screenshot_parameters = get_screenshot(auxiliar_structure, screenshot_file,
                                                   parameters=screenshot_parameters)

        # Set the output clusters, which include all frames in the cluster and the main (most representative) frame
        output_clusters = []
        for frames, most_representative_frame in zip(clusters, representative_frames):
            output_clusters.append({'frames': frames, 'main': most_representative_frame * step})
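        # Note that the main frame is multiplied by 'step' to map its index in the
        # reduced trajectory back to a frame index in the original trajectory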

        # Set the output transitions in a JSON-parseable way (tuple keys are not valid in JSON)
        output_transitions = []
        for transition, count in transition_counts.items():
            # Cast cluster indices to regular ints to make them JSON serializable
            output_transitions.append({'from': int(transition[0]), 'to': int(transition[1]), 'count': count})

        # Set the output analysis
        output_analysis = {
            'name': name,
            'cutoff': cutoff,
            'clusters': output_clusters,
            'transitions': output_transitions,
            'step': step,
            'version': '0.1.0',
        }

        # The output filename must be different for every run to avoid overwriting previous results
        # However the filename is not important regarding the database since this analysis is found by its 'run'
        save_json(output_analysis, numbered_output_analysis_filepath)

    # Save the final summary
    save_json(output_summary, output_analysis_filepath)


# Set a function to cluster frames in an RMSD matrix given an RMSD cutoff
# https://github.com/boneta/RMSD-Clustering/blob/master/rmsd_clustering/clustering.py
def clustering(rmsd_matrix: np.ndarray, cutoff: float) -> list:
    """Greedily assign every frame to the first cluster whose members are all
    within the RMSD cutoff of it, or open a new cluster otherwise."""
    clusters = []
    for i in range(rmsd_matrix.shape[0]):
        # Try to fit frame 'i' in an existing cluster
        for cluster in clusters:
            if all(rmsd_matrix[i, j] < cutoff for j in cluster):
                cluster.append(i)
                break
        # If frame 'i' does not fit anywhere then start a new cluster with it
        else:
            clusters.append([i])
    return clusters
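
# Illustrative sketch (not part of the original module): a quick sanity check of
# the clustering helper on a toy 4-frame RMSD matrix. Frames 0 and 1 fall within
# a 0.5 cutoff of each other, as do frames 2 and 3, so two clusters are expected.
if __name__ == '__main__':
    toy_matrix = np.array([
        [0.0, 0.2, 0.9, 0.8],
        [0.2, 0.0, 0.8, 0.9],
        [0.9, 0.8, 0.0, 0.1],
        [0.8, 0.9, 0.1, 0.0],
    ])
    print(clustering(toy_matrix, cutoff=0.5))  # -> [[0, 1], [2, 3]]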