Coverage for model_workflow/analyses/pca_contacts.py: 0%
62 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-23 10:54 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-23 10:54 +0000
1from itertools import product
3import mdtraj as mdt
4import pytraj as pt
5import numpy as np
6# from scipy.special import expit
7from sklearn.decomposition import PCA
9from model_workflow.utils.auxiliar import save_json
def pytraj_residue_pairs(resid_1, resid_2):
    """Build one mask string per residue pair across two groups.

    Every element of ``resid_1`` is combined with every element of
    ``resid_2`` (cartesian product, ``resid_1`` varying slowest) and each
    combination is rendered as the string ``":<i> :<j>"``.

    Returns:
        A list of mask strings, empty if either group is empty.
    """
    masks = []
    for first, second in product(resid_1, resid_2):
        masks.append(f":{first} :{second}")
    return masks
def mdtraj_residue_pairs(resid_1, resid_2):
    """Return every (first, second) combination of the two residue groups.

    The result is the cartesian product of ``resid_1`` and ``resid_2`` as a
    list of 2-tuples, with ``resid_1`` varying slowest.
    """
    return [*product(resid_1, resid_2)]
def pytraj_distances(traj, residue_pairs):
    """Compute pytraj distances and return them with one row per frame.

    Args:
        traj: trajectory object accepted by ``pt.distance``.
        residue_pairs: mask strings, one per residue pair (e.g. ``":1 :5"``).

    Returns:
        2D array of shape (n_frames, n_pairs), i.e. the same orientation
        that ``mdtraj_distances`` produces.
    """
    # NOTE(review): the original comment called these "minimal" distances;
    # confirm what pt.distance measures for multi-atom masks
    dists = pt.distance(traj, residue_pairs)
    # pt.distance yields one row per pair (a flat array for a single pair).
    # The data must be transposed, not reshaped: reshape keeps the memory
    # order and would mix values of different frames and pairs in each row.
    return np.atleast_2d(dists).T
def mdtraj_distances(traj, residue_pairs):
    """Return the distances computed by ``mdt.compute_contacts``.

    ``compute_contacts`` returns a pair of values; only the first one (the
    distances) is kept and the second is discarded.
    """
    # NOTE(review): mdtraj presumably reports nanometers; confirm that the
    # thresholds applied by callers use the same unit
    distances, _unused_pairs = mdt.compute_contacts(
        traj, contacts=residue_pairs)
    return distances
def get_most_frequent_pairs(
        traj,
        dists,
        residue_pairs,
        frequency_threshold,
        distance_threshold):
    """Select the residue pairs that are in contact often enough.

    Args:
        traj: trajectory; only its length (``len(traj)``) is used.
        dists: distances indexed as [frame, pair].
        residue_pairs: pairs, in the same order as the columns of ``dists``.
        frequency_threshold: fraction of ``len(traj)`` that the number of
            contact frames must strictly exceed.
        distance_threshold: a frame counts as a contact when the distance
            is strictly below this value.

    Returns:
        numpy array with the entries of ``residue_pairs`` that pass.
    """
    # Minimum number of frames (exclusive) a pair must spend in contact
    min_frames = frequency_threshold * len(traj)
    # Per-pair count of frames below the contact distance
    frames_in_contact = np.sum(dists < distance_threshold, axis=0)
    # Positions of the pairs whose count exceeds the minimum
    kept = np.nonzero(frames_in_contact > min_frames)[0]
    return np.array(residue_pairs)[kept]
def sigmoid(x):
    """Evaluate the logistic function ``1 / (1 + exp(-x))``.

    Works on scalars and, through numpy, element-wise on arrays.
    """
    denominator = 1 + np.exp(-x)
    return 1 / denominator
def _pair_to_residue_numbers(pair):
    """Convert one residue pair into a plain list of Python integers."""
    if isinstance(pair, str):
        # pytraj pairs are mask strings such as ":12 :48"
        return [int(token) for token in pair.replace(":", "").split()]
    # mdtraj pairs are sequences of residue indices
    return [int(value) for value in pair]


def pca_contacts(
        trajectory: str,
        topology: str,
        interactions: list,
        output_analysis_filename: str,
        use_pytraj=True,
        n_components=2,
        distance_threshold=15.0,
        frequency_threshold=0.05,
        smooth=5.0):
    """Run a PCA over smoothed residue-residue contacts of each interaction.

    For every interaction the residue pairs between its two residue groups
    are built, the pairs in contact often enough are kept, their distances
    are turned into smooth contact values with a sigmoid, and a PCA is
    fitted on those values. Results are written with ``save_json``.

    Args:
        trajectory: trajectory file, passed to ``pt.load`` / ``mdt.load``.
        topology: topology file, passed as ``top`` to the loader.
        interactions: list of dicts; each one is read through the keys
            ``"pt_residues_1"`` and ``"pt_residues_2"``.
        output_analysis_filename: path handed to ``save_json``.
        use_pytraj: use the pytraj helpers when true, mdtraj otherwise.
        n_components: number of principal components to compute.
        distance_threshold: contact distance, also the sigmoid midpoint.
            NOTE(review): pytraj and mdtraj presumably report distances in
            different units (angstroms vs nanometers); confirm the default
            is meaningful when ``use_pytraj`` is false.
        frequency_threshold: fraction of frames a pair must exceed in
            contact to be kept.
        smooth: steepness of the sigmoid applied to the distances.

    Returns:
        None. Nothing is written when there are no interactions or when no
        interaction produced results.
    """
    print('-> Running PCA contacts analysis')

    # Return before doing anything if there are no interactions
    if not interactions:
        return

    # The trajectory does not depend on the interaction: load it only once
    if use_pytraj:
        traj = pt.load(trajectory, top=topology)
        build_pairs = pytraj_residue_pairs
        compute_distances = pytraj_distances
    else:
        traj = mdt.load(trajectory, top=topology)
        build_pairs = mdtraj_residue_pairs
        compute_distances = mdtraj_distances

    output_analysis = []
    for interaction in interactions:
        # DANI: These fields are no longer present in interactions
        # DANI: They can be recovered the same way distance_per_residue does
        # DANI: This was never done because this analysis has never been used
        residue_lists = (interaction["pt_residues_1"],
                         interaction["pt_residues_2"])

        # Distances for every residue pair, used only to rank the pairs
        residue_pairs = build_pairs(*residue_lists)
        dists = compute_distances(traj, residue_pairs)

        frequent_pairs = get_most_frequent_pairs(
            traj,
            dists,
            residue_pairs,
            frequency_threshold,
            distance_threshold)

        # PCA cannot extract more components than there are features
        if len(frequent_pairs) < n_components:
            print(f'   Skipping interaction: {len(frequent_pairs)} frequent '
                  f'contact pair(s) found but {n_components} PCA components '
                  'were requested')
            continue

        # Recompute distances using only the frequent pairs
        dists = compute_distances(traj, frequent_pairs)

        # Smooth contacts: close to 1 well below the threshold, close to 0
        # well above it
        smooth_distances = 1 - sigmoid(smooth * (dists - distance_threshold))

        # Compute PCA
        pca = PCA(n_components=n_components)
        transformed = pca.fit_transform(smooth_distances)

        # Per component, pair indices sorted by descending component value.
        # The same order is used for values and residue pairs so that the
        # two outputs stay aligned entry by entry.
        descending_order = pca.components_.argsort(axis=1)[:, ::-1]

        # The transformed distances can be plotted individually
        # as a function of the trajectory time (frames)
        transformed_dists = {
            f"transformed_dist_{i+1}": transformed[:, i].tolist()
            for i in range(n_components)}

        # Each of the components (axes in the feature space)
        # sorted in descending order
        components_values = {
            f"component_{i+1}": pca.components_[i][descending_order[i]].tolist()
            for i in range(n_components)}

        # Residue pairs sorted according to the corresponding components
        ordered_residues = {
            f"ordered_residues_{i+1}": [
                _pair_to_residue_numbers(pair)
                for pair in frequent_pairs[descending_order[i]]]
            for i in range(n_components)}

        # Residues of both interaction groups used for the analysis
        interaction_residues = {
            "interaction_residues": residue_lists}

        output_analysis.append(transformed_dists)
        output_analysis.append(components_values)
        output_analysis.append(ordered_residues)
        output_analysis.append(interaction_residues)

    if output_analysis:
        # Export the analysis in json format
        save_json(output_analysis, output_analysis_filename)