Coverage for model_workflow/analyses/clusters.py: 96%

126 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-23 10:54 +0000

1from os.path import exists 

2 

3import numpy as np 

4 

5import mdtraj as mdt 

6 

7from model_workflow.utils.auxiliar import round_to_thousandths, save_json, otherwise 

8from model_workflow.utils.auxiliar import numerate_filename, get_analysis_name 

9from model_workflow.utils.auxiliar import reprint, delete_previous_log 

10from model_workflow.utils.constants import OUTPUT_CLUSTERS_FILENAME, OUTPUT_CLUSTER_SCREENSHOT_FILENAMES 

11from model_workflow.tools.get_screenshot import get_screenshot 

12from model_workflow.tools.get_reduced_trajectory import get_reduced_trajectory 

13from model_workflow.utils.type_hints import * 

14 

def clusters_analysis (
    structure_file : 'File',
    trajectory_file : 'File',
    interactions : list,
    structure : 'Structure',
    snapshots : int,
    pbc_selection : 'Selection',
    output_directory : str,
    # Set the maximum number of frames
    frames_limit : int = 1000,
    # Set the number of steps between the maximum and minimum RMSD so set how many cutoff are tried and how far they are
    n_steps : int = 100,
    # Set the final amount of desired clusters
    desired_n_clusters : int = 20,
    # Set the atom selection for the overall clustering
    overall_selection : str = "name CA or name C5'",
):
    """Run the cluster analysis.

    One clustering "run" is made for the overall structure and one more for
    every interaction interface. For each run a pairwise RMSD matrix is built
    with MDTraj, an RMSD cutoff is searched so that the number of clusters
    approaches the target, transitions between clusters along the trajectory
    are counted, and a screenshot of the most representative frame of every
    cluster is generated.

    Args:
        structure_file: File whose ``path`` is used as topology when loading
            the trajectory with MDTraj.
        trajectory_file: Trajectory file, handed to ``get_reduced_trajectory``.
        interactions: List of dicts. Each one must provide the keys ``name``,
            ``interface_residue_indices_1`` and ``interface_residue_indices_2``.
        structure: Structure used to parse selections and to hold the
            coordinates of representative frames (through a copy).
        snapshots: Handed to ``get_reduced_trajectory``; presumably the number
            of frames in the whole trajectory (verify against caller).
        pbc_selection: Selection subtracted from every run selection.
        output_directory: Directory where output files are written.
        frames_limit: Maximum number of frames, handed to
            ``get_reduced_trajectory``.
        n_steps: Number of divisions of the RMSD range (maximum - minimum).
            It sets how much the cutoff changes between attempts.
        desired_n_clusters: Desired final number of clusters.
        overall_selection: Atom selection for the overall run. When it
            matches nothing, all heavy atoms are used instead.

    Returns:
        None. One JSON file per run and a final summary JSON are written with
        ``save_json``. Runs whose numbered output file already exists are
        listed in the summary but not recalculated.
    """
    # The cluster analysis is run for the overall structure and then once more for every interaction
    # We must set the atom selection of every run in atom indices, for MDtraj
    runs = []

    # Start with the overall selection
    parsed_overall_selection = structure.select(overall_selection)
    # If the default selection is empty then use all heavy atoms
    if not parsed_overall_selection:
        parsed_overall_selection = structure.select_heavy_atoms()
    # Subtract PBC atoms
    # Note that this is essential since sudden jumps across boundaries would eclipse the actual clusters
    parsed_overall_selection -= pbc_selection
    if parsed_overall_selection:
        runs.append({
            'name': 'Overall',
            'selection': parsed_overall_selection
        })

    # Now setup the interaction runs
    for interaction in interactions:
        # Get the interface selection
        interface_residue_indices = interaction['interface_residue_indices_1'] \
            + interaction['interface_residue_indices_2']
        interface_selection = structure.select_residue_indices(interface_residue_indices)
        heavy_atoms_selection = structure.select_heavy_atoms()
        # Keep only heavy atoms for the distance calculation
        final_selection = interface_selection & heavy_atoms_selection
        # Subtract PBC atoms
        final_selection -= pbc_selection
        if final_selection:
            runs.append({
                'name': interaction['name'],
                'selection': final_selection
            })

    # If there are no runs at this point then stop here
    # This may happen when the whole system is in PBC
    if not runs:
        print(' No clusters to analyze')
        return

    # If trajectory frames number is bigger than the limit we create a reduced trajectory
    # The third returned value is not needed here
    reduced_trajectory_filepath, step, _ = get_reduced_trajectory(
        structure_file,
        trajectory_file,
        snapshots,
        frames_limit,
    )

    # Set output filepaths
    output_analysis_filepath = f'{output_directory}/{OUTPUT_CLUSTERS_FILENAME}'
    output_screenshot_filepath = f'{output_directory}/{OUTPUT_CLUSTER_SCREENSHOT_FILENAMES}'

    # Load the whole (reduced) trajectory
    traj = mdt.load(reduced_trajectory_filepath, top=structure_file.path)

    # Set the target number of clusters
    # This should be the desired number of clusters unless there are less frames than that
    # Note that frames are counted in the loaded trajectory, which is the one actually clustered
    target_n_clusters = min(desired_n_clusters, traj.n_frames)

    # Copy the structure to further mutate its coordinates without affecting the original
    auxiliar_structure = structure.copy()

    # Set the final analysis which is actually a summary to find every run
    output_summary = []

    # Now iterate over the different runs
    for r, run in enumerate(runs):
        # Set the output analysis filename from the input template
        # e.g. replica_1/mda.clusters_*.json -> replica_1/mda.clusters_01.json
        numbered_output_analysis_filepath = numerate_filename(output_analysis_filepath, r)
        # Get the run name
        name = run['name']
        # Add the root of the output analysis filename to the run data
        analysis_name = get_analysis_name(numbered_output_analysis_filepath)
        # Add this run to the final summary
        output_summary.append({
            'name': name,
            'analysis': analysis_name
        })
        # If the output file already exists then skip this iteration
        if exists(numbered_output_analysis_filepath):
            continue

        print(f'Calculating distances for {name} -> {analysis_name}')
        # Get the run selection atom indices
        atom_indices = run['selection'].atom_indices

        # Calculate the RMSD matrix
        distance_matrix = np.empty((traj.n_frames, traj.n_frames))
        for i in range(traj.n_frames):
            print(f' Frame {i+1} out of {traj.n_frames}', end='\r')
            # Calculate the RMSD between every frame in the trajectory and the frame 'i'
            distance_matrix[i] = mdt.rmsd(traj, traj, i, atom_indices=atom_indices)

        # Get the maximum RMSD value in the whole matrix
        maximum_rmsd = np.max(distance_matrix)
        # Get the minimum RMSD value in the whole matrix
        # Discard 0s from frames against themselves
        # If every value is 0 (e.g. a single frame) there is nothing left, so fall back to 0
        # Otherwise np.min would raise a ValueError on the empty array
        non_zero_distances = distance_matrix[distance_matrix != 0]
        minimum_rmsd = np.min(non_zero_distances) if non_zero_distances.size > 0 else maximum_rmsd * 0

        # Set the difference between the minimum and maximum to determine the cutoffs step
        rmsd_difference = maximum_rmsd - minimum_rmsd
        rmsd_step = rmsd_difference / n_steps

        # Set the initial RMSD cutoff
        cutoff = round_to_thousandths(minimum_rmsd + rmsd_difference / 2)
        # Keep a register of already tried cutoffs so we do not repeat
        already_tried_cutoffs = set()

        # Adjust the RMSD cutoff until we get the desired amount of clusters
        # Note that final clusters will be ordered by the time they appear
        clusters = None
        # Keep track of the cutoff which produced the current clusters
        # Note that 'cutoff' is updated after every attempt so it may not match the final clusters
        clusters_cutoff = cutoff
        n_clusters = 0
        while n_clusters != target_n_clusters:
            # Find clusters
            print(f' Trying with cutoff {cutoff}', end='')
            clusters = clustering(distance_matrix, cutoff)
            clusters_cutoff = cutoff
            n_clusters = len(clusters)
            print(f' -> Found {n_clusters} clusters')
            # Update the cutoff
            already_tried_cutoffs.add(cutoff)
            if n_clusters > target_n_clusters:
                cutoff = round_to_thousandths(cutoff + rmsd_step)
            elif n_clusters < target_n_clusters:
                cutoff = round_to_thousandths(cutoff - rmsd_step)
            # If we already tried the updated cutoff then we are close enough to the desired number of clusters
            # Note that this is also true when the target is reached, since the cutoff is not updated then
            if cutoff in already_tried_cutoffs:
                break
            # Erase previous log and write in the same line
            delete_previous_log()

        # Resort clusters in a "cluster per frame" structure
        frame_clusters = np.empty(traj.n_frames, dtype=int)
        for c, cluster in enumerate(clusters):
            for frame in cluster:
                frame_clusters[frame] = c

        # Count the transitions between clusters
        transitions = []

        # Iterate over the different frames
        previous_cluster = frame_clusters[0]
        for cluster in frame_clusters[1:]:
            # If this is the same cluster then there is no transition here
            if previous_cluster == cluster:
                continue
            # Otherwise save the transition
            transition = previous_cluster, cluster
            transitions.append(transition)
            previous_cluster = cluster

        print(f' Found {len(transitions)} transitions')

        # Count every different transition
        transition_counts = {}
        for transition in transitions:
            current_count = transition_counts.get(transition, 0)
            transition_counts[transition] = current_count + 1

        # Now for every cluster find the most representative frame (i.e. the one with less RMSD distance to its neighbours)
        # Then make a screenshot for this specific frame
        representative_frames = []
        # Save the screenshot parameters so we can keep images coherent between clusters
        screenshot_parameters = None
        print(' Generating cluster screenshots') # This will be reprinted
        for c, cluster in enumerate(clusters):
            most_representative_frame = None
            min_distance = float('inf') # Positive infinity
            for frame, neighbour_frames in otherwise(cluster):
                # Calculate the sum of all rmsd distances
                total_distance = sum(distance_matrix[frame][neighbour_frame] for neighbour_frame in neighbour_frames)
                # If the distance is inferior the current minimum then set this frame as the most representative
                if total_distance < min_distance:
                    most_representative_frame = frame
                    min_distance = total_distance
            # Save the most representative frame in the list
            representative_frames.append(most_representative_frame)
            # Once we have the most representative frame we take a screenshot
            # This screenshots will be then uploaded to the database as well
            # Get the coordinates from the most representative frame
            mdt_frame = traj[most_representative_frame]
            # MDTraj coordinates are in nanometers, so we multiply by 10 to restore Ångstroms
            coordinates = mdt_frame.xyz[0] * 10
            # WARNING: a PDB generated by MDtraj may have problems thus leading to artifacts in the screenshot
            # WARNING: to avoid this we add the coordinates to the structure
            auxiliar_structure.set_new_coordinates(coordinates)
            # Set the screenshot filename from the input template
            # The '*' is replaced by the run number and the '??' is replaced by the cluster number
            screenshot_filename = output_screenshot_filepath.replace('*', str(r).zfill(2)).replace('??', str(c).zfill(2))
            # Generate the screenshot
            reprint(f' Generating cluster screenshot {c+1}/{n_clusters}')
            screenshot_parameters = get_screenshot(auxiliar_structure, screenshot_filename,
                parameters=screenshot_parameters)

        # Set the output clusters which include all frames in the cluster and the main or more representative frame
        # Note that the main frame is multiplied by the step while cluster frames are kept as they are
        output_clusters = []
        for cluster_frames, most_representative_frame in zip(clusters, representative_frames):
            output_clusters.append({ 'frames': cluster_frames, 'main': most_representative_frame * step })

        # Set the output transitions in a hashable and json parseable way
        output_transitions = []
        for transition, count in transition_counts.items():
            # Set frames as regular ints to make them json serializable
            output_transitions.append({ 'from': int(transition[0]), 'to': int(transition[1]), 'count': count })

        # Set the output analysis
        # Note that the saved cutoff is the one which produced the saved clusters
        output_analysis = {
            'name': name,
            'cutoff': clusters_cutoff,
            'clusters': output_clusters,
            'transitions': output_transitions,
            'step': step,
            'version': '0.1.0',
        }

        # The output filename must be different for every run to avoid overwritting previous results
        # However the filename is not important regarding the database since this analysis is found by its 'run'
        save_json(output_analysis, numbered_output_analysis_filepath)

    # Save the final summary
    save_json(output_summary, output_analysis_filepath)

259 

# Cluster frames in a RMSD matrix given a RMSD cutoff
# https://github.com/boneta/RMSD-Clustering/blob/master/rmsd_clustering/clustering.py
def clustering (rmsd_matrix : np.ndarray, cutoff : float) -> list:
    """Group frame indices so all members of a group are closer than the cutoff.

    Frames are visited in order. Every frame joins the first existing group
    in which its RMSD to every current member is strictly below the cutoff.
    When there is no such group the frame starts a new one.

    Args:
        rmsd_matrix: Square matrix where element [i, j] is the RMSD between
            frames i and j. Only its first dimension sets the number of frames.
        cutoff: RMSD threshold. The comparison is strict (<).

    Returns:
        A list of groups, each one a list of frame indices in ascending
        order. Groups are sorted by the order in which they were started.
    """
    groups = []
    n_frames = rmsd_matrix.shape[0]
    for frame in range(n_frames):
        # Distances from the current frame to every other frame
        distances = rmsd_matrix[frame]
        # Find the first group whose members are all within the cutoff, if any
        host_group = next(
            (group for group in groups if all(distances[member] < cutoff for member in group)),
            None,
        )
        if host_group is None:
            # No group accepts this frame so it starts its own group
            groups.append([frame])
        else:
            host_group.append(frame)
    return groups