Coverage for model_workflow/analyses/pca_contacts.py: 0%

62 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-23 10:54 +0000

1from itertools import product 

2 

3import mdtraj as mdt 

4import pytraj as pt 

5import numpy as np 

6# from scipy.special import expit 

7from sklearn.decomposition import PCA 

8 

9from model_workflow.utils.auxiliar import save_json 

10 

11 

def pytraj_residue_pairs(resid_1, resid_2):
    """Build one mask-pair string per combination of the two residue groups.

    Every element of ``resid_1`` is combined with every element of
    ``resid_2`` (cartesian product, ``resid_1`` varying slowest) and
    formatted as ``":<i> :<j>"``.

    Returns:
        A list of strings, one per residue pair.
    """
    masks = []
    for first, second in product(resid_1, resid_2):
        masks.append(":{} :{}".format(first, second))
    return masks

17 

18 

def mdtraj_residue_pairs(resid_1, resid_2):
    """Return every (i, j) combination of the two residue groups as tuples.

    The result is the cartesian product of ``resid_1`` and ``resid_2``,
    with ``resid_1`` varying slowest.
    """
    # Materialize both inputs up front so one-shot iterables are handled
    # the same way as sequences
    group_1 = tuple(resid_1)
    group_2 = tuple(resid_2)
    return [(first, second) for first in group_1 for second in group_2]

22 

23 

def pytraj_distances(traj, residue_pairs):
    """Compute the distance of every residue pair in every frame with pytraj.

    Args:
        traj: Trajectory object accepted by ``pt.distance``.
        residue_pairs: Mask-pair strings such as ``":3 :7"``, one per pair.

    Returns:
        A 2D array with one row per frame and one column per pair, which is
        the (samples, features) layout the rest of this module works with.
    """
    dists = pt.distance(traj, residue_pairs)
    # pytraj documents the result for a list of masks as (n_pairs, n_frames).
    # Swapping the axes needs a transpose: reshape(-1, n_pairs) keeps the
    # flat memory order and would mix values of different pairs and frames.
    # atleast_2d covers a 1D result (presumably returned when there is a
    # single mask -- verify against the installed pytraj version).
    return np.atleast_2d(dists).T

29 

30 

def mdtraj_distances(traj, residue_pairs):
    """Compute residue-residue contact distances with mdtraj.

    Calls ``mdt.compute_contacts`` on ``traj`` for the given
    ``residue_pairs`` and returns only the distances, discarding the
    second element of the returned pair.

    NOTE(review): mdtraj documents its distances in nanometers, which may
    not match the units of the pytraj code path -- confirm before comparing
    against a shared threshold.
    """
    distances, _pairs = mdt.compute_contacts(traj, contacts=residue_pairs)
    return distances

34 

35 

def get_most_frequent_pairs(
        traj,
        dists,
        residue_pairs,
        frequency_threshold,
        distance_threshold):
    """Keep the residue pairs that are in contact often enough.

    Args:
        traj: Only its length is used, as the number of frames.
        dists: Array of distances; contacts are counted along axis 0.
        residue_pairs: Pairs matching the entries counted along axis 0.
        frequency_threshold: Fraction of ``len(traj)`` that the number of
            contact frames must strictly exceed.
        distance_threshold: A pair counts as in contact in a frame when its
            distance is strictly below this value.

    Returns:
        A numpy array holding the selected entries of ``residue_pairs``.
    """
    # Minimum number of frames (exclusive) in which the pair must be close
    min_frames = frequency_threshold * len(traj)

    # Number of frames in which each pair is closer than the cutoff
    frames_in_contact = (dists < distance_threshold).sum(axis=0)

    # Positions of the pairs whose contact count exceeds the minimum
    frequent_positions = np.nonzero(frames_in_contact > min_frames)[0]

    return np.array(residue_pairs)[frequent_positions]

57 

58 

def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` without overflow.

    The direct formula evaluates ``exp(-x)``, which overflows for large
    negative ``x``. Here the exponential is only taken of non-positive
    values, so the result is the same but no overflow can occur.

    Args:
        x: A scalar or a numpy array.

    Returns:
        A numpy scalar for scalar input, otherwise an array shaped like ``x``.
    """
    # exp(-|x|) lies in (0, 1], so it can never overflow
    decay = np.exp(-np.abs(x))
    # For x >= 0: 1 / (1 + exp(-x)); for x < 0 the equivalent
    # exp(x) / (1 + exp(x))
    result = np.where(x >= 0, 1 / (1 + decay), decay / (1 + decay))
    # np.where returns a 0-d array for scalar input; [()] unwraps it to a
    # numpy scalar and leaves n-d arrays unchanged
    return result[()]

61 

62 

def _pair_to_residue_numbers(pair):
    """Convert one residue pair into a JSON-friendly list of two ints."""
    if isinstance(pair, str):
        # pytraj pairs are mask strings such as ":3 :7"
        return [int(token) for token in pair.replace(":", "").split()]
    # mdtraj pairs are sequences of two residue numbers
    return [int(residue) for residue in pair]


def pca_contacts(
        trajectory: str,
        topology: str,
        interactions: list,
        output_analysis_filename: str,
        use_pytraj=True,
        n_components=2,
        distance_threshold=15.0,
        frequency_threshold=0.05,
        smooth=5.0):
    """Run a PCA over smoothed residue-residue contacts of each interaction.

    For every interaction, the distances between all residue pairs of its
    two groups are computed, the pairs that are in contact frequently
    enough are kept, their distances are turned into smooth contact values
    with a sigmoid, and a PCA is fitted on those values. The results of all
    interactions are written to a single JSON file.

    Args:
        trajectory: Path of the trajectory file to load.
        topology: Path of the topology file to load.
        interactions: List of dicts, each providing the keys
            "pt_residues_1" and "pt_residues_2".
        output_analysis_filename: Path of the JSON file to write.
        use_pytraj: Use pytraj when true, mdtraj otherwise.
        n_components: Number of principal components passed to ``PCA``.
        distance_threshold: Distance below which a pair is in contact. It
            is also the midpoint of the smoothing sigmoid.
        frequency_threshold: Fraction of frames a pair must be in contact
            for in order to be kept.
        smooth: Steepness of the smoothing sigmoid.

    Returns:
        None. Nothing is written when ``interactions`` is empty.

    NOTE(review): the same ``distance_threshold`` is applied to both
    backends, while mdtraj documents its distances in nanometers -- confirm
    that the default is meaningful on the mdtraj path.
    """
    print('-> Running PCA contacts analysis')

    # Return before doing anything if there are no interactions
    if not interactions:
        return

    # DANI: These fields are no longer in interactions
    # DANI: They can be recovered just as it is done in distance_per_residue
    # DANI: I did not do it back then because this analysis has never
    # DANI: actually been used
    # Read the residue groups first so that a missing field fails before
    # the trajectory is loaded
    residue_lists_per_interaction = [
        (interaction["pt_residues_1"], interaction["pt_residues_2"])
        for interaction in interactions]

    # The trajectory does not depend on the interaction: load it only once
    if use_pytraj:
        traj = pt.load(trajectory, top=topology)
        build_pairs = pytraj_residue_pairs
        compute_distances = pytraj_distances
    else:
        traj = mdt.load(trajectory, top=topology)
        build_pairs = mdtraj_residue_pairs
        compute_distances = mdtraj_distances

    output_analysis = []
    for residue_lists in residue_lists_per_interaction:
        # get list of residue pairs and their distances
        residue_pairs = build_pairs(*residue_lists)
        dists = compute_distances(traj, residue_pairs)

        frequent_pairs = get_most_frequent_pairs(
            traj,
            dists,
            residue_pairs,
            frequency_threshold,
            distance_threshold)

        # compute the distances again, now only for the frequent pairs
        dists = compute_distances(traj, frequent_pairs)

        # smooth distances: close to 1 well below the threshold, close to 0
        # well above it
        smooth_distances = 1 - sigmoid(smooth * (dists - distance_threshold))

        # compute PCA
        pca = PCA(n_components=n_components)
        transformed = pca.fit_transform(smooth_distances)
        components = pca.components_
        n_fitted = components.shape[0]

        # for each component, pair indices sorted by descending weight
        descending = np.argsort(components, axis=1)[:, ::-1]

        # writing data
        # the transformed distances can be plotted individually
        # as a function of the trajectory time (frames)
        transformed_dists = {
            f"transformed_dist_{i+1}": transformed[:, i].tolist()
            for i in range(n_fitted)}

        # each of the components (axes in the feature space)
        # sorted in descending order, matching the order of the pairs below
        components_values = {
            f"component_{i+1}": components[i][descending[i]].tolist()
            for i in range(n_fitted)}

        # residue pairs sorted according to the corresponding components;
        # each component keeps its own ranked list of pairs
        ordered_residues = {
            f"ordered_residues_{i+1}": [
                _pair_to_residue_numbers(pair)
                for pair in frequent_pairs[descending[i]]]
            for i in range(n_fitted)}

        # residues that are part of both interaction groups used for the analysis
        interaction_residues = {
            "interaction_residues": residue_lists}

        output_analysis.append(transformed_dists)
        output_analysis.append(components_values)
        output_analysis.append(ordered_residues)
        output_analysis.append(interaction_residues)

    if output_analysis:
        # Export the analysis in json format
        save_json(output_analysis, output_analysis_filename)
153 # Export the analysis in json format 

154 save_json(output_analysis, output_analysis_filename)