Coverage for model_workflow/utils/heatmaps_nassa.py: 57%
337 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-23 10:54 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-23 10:54 +0000
1import pathlib
2import itertools
4import numpy as np
5import pandas as pd
6import matplotlib as mpl
7import matplotlib.pyplot as plt
8from matplotlib.tri import Triangulation
9from matplotlib.colors import ListedColormap
10from model_workflow.utils.nucleicacid import NucleicAcid
def get_axes(subunit_len, base):
    """Create the axis labels and the subunit ordering for a subunit length.

    Parameters
    ----------
    subunit_len : int
        Length of the sequence subunit. Supported values are 3, 4, 5 and 6.
    base : str
        Letter interpolated into the labels next to ``A``, ``C`` and ``G``.

    Returns
    -------
    tuple of (list of str, list of str, list of str)
        ``(xaxis, yaxis, nucleotide_order)``. ``nucleotide_order`` lists every
        subunit of the requested length; the two axis lists hold the labels
        derived from it.

    Raises
    ------
    ValueError
        If ``subunit_len`` is not 3, 4, 5 or 6 (this used to surface as an
        ``UnboundLocalError`` at the ``return`` statement).
    """
    dinucleotides = [
        f"{base}{base}",
        f"{base}C",
        f"C{base}",
        "CC",
        f"G{base}",
        "GC",
        f"A{base}",
        "AC",
        f"{base}G",
        f"{base}A",
        "CG",
        "CA",
        "GG",
        "GA",
        "AG",
        "AA"]

    if subunit_len == 3:
        xaxis = f"G A C {base}".split()
        yaxis = dinucleotides
        # str.join puts the central base between the two flanking letters.
        nucleotide_order = [b.join(f) for f in yaxis for b in xaxis]

    elif subunit_len == 4:
        xaxis = dinucleotides[::-1]
        yaxis = dinucleotides
        # Joining a dinucleotide inside a two-letter flank yields a tetramer.
        nucleotide_order = [b.join(f) for f in yaxis for b in xaxis]

    elif subunit_len == 5:
        letters = ['A', 'C', 'G', base]
        nucleotide_order = [
            "".join(combo) for combo in itertools.product(letters, repeat=5)]
        # NOTE(review): as in the previous implementation, xaxis keeps one
        # (repeated) central trimer per pentamer and yaxis is the unique set
        # of those trimers -- confirm this is the intended pair of axes.
        xaxis = [unit[1:-1] for unit in nucleotide_order]
        # dict.fromkeys removes duplicates while keeping first-seen order, so
        # the labels no longer depend on string hash randomisation.
        yaxis = list(dict.fromkeys(xaxis))

    elif subunit_len == 6:
        nucleotide_order = [
            f"{n1}{n2}{n3}"
            for n1 in dinucleotides
            for n2 in dinucleotides
            for n3 in dinucleotides]
        # Flanking pairs on the y axis, central pair on the x axis; both are
        # de-duplicated in a deterministic, first-seen order.
        yaxis = list(dict.fromkeys(
            unit[:2] + "_" + unit[-2:] for unit in nucleotide_order))
        xaxis = list(dict.fromkeys(unit[2:4] for unit in nucleotide_order))

    else:
        raise ValueError(
            f"Unsupported subunit length {subunit_len!r}; "
            "expected 3, 4, 5 or 6.")

    return xaxis, yaxis, nucleotide_order
def reorder_labels_rotated_plot(df, subunit_name, tetramer_order):
    """Order the data and build the axis labels for the rotated plot.

    Every subunit listed in ``tetramer_order`` is left-merged with ``df`` so
    the result has one row per known subunit (missing ones carry NaN), then
    the rows are sorted by their ``yaxis`` and ``xaxis`` labels.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with a column named ``subunit_name``.
    subunit_name : str
        ``'hexamer'`` or ``'pentamer'``.
    tetramer_order : list
        All subunits to lay out. The list is no longer sorted in place.

    Returns
    -------
    tuple of (pandas.DataFrame, list of str, list of str)
        ``(df1, xaxis1, yaxis1)``: the merged frame sorted by
        ``['yaxis', 'xaxis']`` and the unique x / y labels.

    Raises
    ------
    ValueError
        If ``subunit_name`` is not supported or ``tetramer_order`` is empty.
        Other subunit names previously ended in an ``UnboundLocalError``.
    """
    if subunit_name not in ('hexamer', 'pentamer'):
        raise ValueError(
            f"Unsupported subunit {subunit_name!r} for the rotated plot; "
            "expected 'hexamer' or 'pentamer'.")

    # Work on a sorted copy so the caller's list is left untouched.
    units = ["".join(unit) for unit in sorted(tetramer_order)]
    if not units:
        raise ValueError("tetramer_order must contain at least one subunit.")

    # Position of the first occurrence of every subunit (what list.index
    # returned), looked up in O(1) instead of scanning the list per row.
    rank = {}
    for position, unit in enumerate(units):
        rank.setdefault(unit, position)

    merged_df = pd.merge(
        pd.DataFrame({subunit_name: units}), df, on=subunit_name, how='left')
    merged_df['subunit_rank'] = merged_df[subunit_name].map(rank)
    merged_df = merged_df.sort_values(by='subunit_rank').reset_index(drop=True)

    if subunit_name == 'hexamer':
        centre = len(units[0]) // 2
        # Outer letters label the rows, the four central letters the columns.
        merged_df['yaxis'] = merged_df[subunit_name].apply(
            lambda unit: unit[:centre - 2] + "...." + unit[centre + 2:])
        merged_df['xaxis'] = merged_df[subunit_name].apply(
            lambda unit: unit[centre - 2:centre + 2])
    else:
        # Pentamer: central trimer on the columns, first/last letter on rows.
        merged_df['xaxis'] = merged_df[subunit_name].apply(
            lambda unit: unit[1:-1])
        merged_df['yaxis'] = merged_df[subunit_name].apply(
            lambda unit: unit[0] + "..." + unit[-1:])

    # Unique labels in first-seen order of the rank-sorted rows. The former
    # loop indexed rows by label after sorting, which visited them in this
    # same rank order.
    xaxis1 = list(dict.fromkeys(merged_df['xaxis']))
    yaxis1 = list(dict.fromkeys(merged_df['yaxis']))

    df1 = merged_df.sort_values(by=['yaxis', 'xaxis'], ascending=True)
    return df1, xaxis1, yaxis1
def reorder_labels_straight_plot(df, subunit_name, tetramer_order):
    """Order the data and build the axis labels for the straight plot.

    Every subunit listed in ``tetramer_order`` is left-merged with ``df`` so
    the result has one row per known subunit (missing ones carry NaN), then
    the rows are sorted by their ``yaxis`` and ``xaxis`` labels.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with a column named ``subunit_name``.
    subunit_name : str
        ``'hexamer'`` or ``'pentamer'``.
    tetramer_order : list
        All subunits to lay out. The list is no longer sorted in place.

    Returns
    -------
    tuple of (pandas.DataFrame, list of str, list of str)
        ``(df, xaxis, yaxis)``: the merged frame sorted by
        ``['yaxis', 'xaxis']`` and the unique x / y labels.

    Raises
    ------
    ValueError
        If ``subunit_name`` is not supported or ``tetramer_order`` is empty.
        Other subunit names previously ended in an ``UnboundLocalError``.
    """
    if subunit_name not in ('hexamer', 'pentamer'):
        raise ValueError(
            f"Unsupported subunit {subunit_name!r} for the straight plot; "
            "expected 'hexamer' or 'pentamer'.")

    # Work on a sorted copy so the caller's list is left untouched.
    units = ["".join(unit) for unit in sorted(tetramer_order)]
    if not units:
        raise ValueError("tetramer_order must contain at least one subunit.")

    # Position of the first occurrence of every subunit (what list.index
    # returned), looked up in O(1) instead of scanning the list per row.
    rank = {}
    for position, unit in enumerate(units):
        rank.setdefault(unit, position)

    merged_df = pd.merge(
        pd.DataFrame({subunit_name: units}), df, on=subunit_name, how='left')
    merged_df['subunit_rank'] = merged_df[subunit_name].map(rank)
    merged_df = merged_df.sort_values(by='subunit_rank').reset_index(drop=True)

    if subunit_name == 'hexamer':
        centre = len(units[0]) // 2
        # Flanking pairs label the rows, the central pair labels the columns.
        merged_df['yaxis'] = merged_df[subunit_name].apply(
            lambda unit: unit[:centre - 1] + "_" + unit[centre + 1:])
        merged_df['xaxis'] = merged_df[subunit_name].apply(
            lambda unit: unit[centre - 1:centre + 1])
    else:
        # Pentamer: central trimer on the rows, first/last letter on columns.
        merged_df['yaxis'] = merged_df[subunit_name].apply(
            lambda unit: unit[1:-1])
        merged_df['xaxis'] = merged_df[subunit_name].apply(
            lambda unit: unit[0] + "..." + unit[-1:])

    # Unique labels in first-seen order of the rank-sorted rows. The former
    # loop indexed rows by label after sorting, which visited them in this
    # same rank order.
    xaxis = list(dict.fromkeys(merged_df['xaxis']))
    yaxis = list(dict.fromkeys(merged_df['yaxis']))

    df = merged_df.sort_values(by=['yaxis', 'xaxis'], ascending=True)
    return df, xaxis, yaxis
def _draw_arlequin_panel(
        frame,
        xlabels,
        ylabels,
        n_cols,
        n_rows,
        figsize,
        xlabel_style,
        title,
        colorbar_labels,
        file_path,
        label_offset):
    """Draw one arlequin heatmap (two triangles per cell) and save it as PDF."""
    # One value per triangle, taken in the row order of ``frame``.
    first_values = frame["col1"].to_numpy()
    second_values = frame["col2"].to_numpy()

    # Vertices of an (n_rows x n_cols) grid; vertex (i, j) has the flat
    # index i + j * stride.
    grid_x, grid_y = np.meshgrid(np.arange(n_cols + 1), np.arange(n_rows + 1))
    stride = n_cols + 1
    cells = [(i, j) for j in range(n_rows) for i in range(n_cols)]
    first_half = [
        (i + j * stride, i + 1 + j * stride, i + 1 + (j + 1) * stride)
        for i, j in cells]
    second_half = [
        (i + j * stride, i + 1 + (j + 1) * stride, i + (j + 1) * stride)
        for i, j in cells]
    first_triang = Triangulation(grid_x.ravel(), grid_y.ravel(), first_half)
    second_triang = Triangulation(grid_x.ravel(), grid_y.ravel(), second_half)

    fig, ax = plt.subplots(
        1,
        1,
        figsize=figsize,
        dpi=300,
        tight_layout=True)

    colormap = plt.get_cmap("bwr", 3).reversed()
    colormap.set_bad(color="grey")
    image = ax.tripcolor(
        first_triang, first_values, cmap=colormap, vmin=-1, vmax=1)
    ax.tripcolor(
        second_triang, second_values, cmap=colormap, vmin=-1, vmax=1)

    ax.grid()
    xlocs = np.arange(len(xlabels))
    ylocs = np.arange(len(ylabels))
    # Major ticks only carry the grid lines; the text sits on minor ticks
    # shifted by label_offset so it lands inside the cells.
    ax.set_xticks(xlocs)
    ax.set_xticklabels("")
    ax.set_yticks(ylocs)
    ax.set_yticklabels("")
    ax.set_xticks(xlocs + label_offset, minor=True)
    ax.set_xticklabels(xlabels, minor=True, **xlabel_style)
    ax.set_yticks(ylocs + label_offset, minor=True)
    ax.set_yticklabels(ylabels, minor=True, fontsize=6)

    ax.set_xlim(0, n_cols)
    ax.set_ylim(0, n_rows)
    ax.set_title(title)
    cbar = fig.colorbar(image, ax=ax, ticks=[-1, 0, 1], shrink=0.4)
    cbar.ax.set_yticklabels(colorbar_labels)

    fig.savefig(fname=file_path, format="pdf")
    return fig, ax


def arlequin_plot(
        df,
        global_mean,
        global_std,
        helpar,
        save_path,
        unit_name,
        unit_len,
        base,
        label_offset=0.5):
    """Draw the straight and the rotated arlequin plots and save both as PDF.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with a ``unit_name`` column plus ``col1`` and ``col2``, whose
        values colour the two triangles of every cell (range -1..1).
    global_mean, global_std : float
        Only used to build the colour bar tick labels.
    helpar : str
        Used as plot title (upper-cased) and as output file stem.
    save_path : str or path-like
        Directory receiving ``<helpar>.pdf`` and ``<helpar>_rotated.pdf``.
    unit_name : str
        ``'hexamer'`` or ``'pentamer'``.
    unit_len : int
        Subunit length handed to ``get_axes``.
    base : str
        Base letter handed to ``get_axes``.
    label_offset : float, optional
        Shift of the tick labels relative to the grid lines.

    Returns
    -------
    tuple
        ``(fig, axs, fig_1, axs_1)``: figure and axes of the straight plot
        followed by those of the rotated plot. The figures are left open
        because they are returned to the caller.

    Raises
    ------
    ValueError
        If ``unit_name`` is neither ``'hexamer'`` nor ``'pentamer'``.
    """
    if unit_name == 'hexamer':
        short_side, long_side = 4 ** 2, 4 ** (unit_len - 2)
    elif unit_name == 'pentamer':
        short_side, long_side = 4 * 4, 4 ** 3
    else:
        # The grid size was left undefined for any other subunit name.
        raise ValueError(
            f"Unsupported unit name {unit_name!r}; "
            "expected 'hexamer' or 'pentamer'.")

    _, _, tetramer_order = get_axes(unit_len, base)
    df, xaxis, yaxis = reorder_labels_straight_plot(
        df, unit_name, tetramer_order)
    df1, xaxis1, yaxis1 = reorder_labels_rotated_plot(
        df, unit_name, tetramer_order)

    # "\\pm" keeps the runtime text "$\pm$" without the invalid "\p" escape.
    colorbar_labels = [
        f"< {global_mean:.2f}-{global_std:.2f}",
        f"{global_mean:.2f}$\\pm${global_std:.2f}",
        f"> {global_mean:.2f}+{global_std:.2f}"]
    out_dir = pathlib.Path(save_path)
    title = helpar.upper()

    # STRAIGHT PLOT: narrow and tall.
    fig, axs = _draw_arlequin_panel(
        df,
        xaxis,
        yaxis,
        short_side,
        long_side,
        (8, 18),
        {"fontsize": 8},
        title,
        colorbar_labels,
        out_dir / f"{helpar}.pdf",
        label_offset)

    # ROTATED PLOT: same grid with the two sides swapped.
    fig_1, axs_1 = _draw_arlequin_panel(
        df1,
        xaxis1,
        yaxis1,
        long_side,
        short_side,
        (22, 6),
        {"fontsize": 4, "rotation": 90},
        title,
        colorbar_labels,
        out_dir / f"{helpar}_rotated.pdf",
        label_offset)

    return fig, axs, fig_1, axs_1
def bconf_heatmap(df, fname, save_path, subunit_len, base="T", label_offset=0.05):
    """Plot the ``pct`` column of ``df`` as a heatmap and save it as PDF.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with a ``pct`` column and a ``trimer``, ``tetramer`` or
        ``hexamer`` column, depending on ``subunit_len``.
    fname : str
        Used in the title and as stem of ``<fname>_percentages.pdf``.
    save_path : str or path-like
        Output directory.
    subunit_len : int
        3, 4 or 6.
    base : str, optional
        Letter interpolated into the labels next to ``A``, ``C`` and ``G``.
    label_offset : float, optional
        Unused; kept for interface compatibility.

    Raises
    ------
    ValueError
        If ``subunit_len`` is 5 or any other unsupported length (the latter
        used to fail with a ``NameError``).
    """
    print('subunit len: ', subunit_len)
    pairs = [
        f"{base}{base}",
        f"{base}C",
        f"C{base}",
        "CC",
        f"G{base}",
        "GC",
        f"A{base}",
        "AC",
        f"{base}G",
        f"{base}A",
        "CG",
        "CA",
        "GG",
        "GA",
        "AG",
        "AA"]
    subplot_kwargs = {}

    if subunit_len == 3:
        column = "trimer"
        yaxis = pairs
        xaxis = f"G A C {base}".split()
        # Rows are flanking pairs, columns the central base.
        subunit_order = [b.join(f) for f in yaxis for b in xaxis]

    elif subunit_len == 4:
        column = "tetramer"
        xaxis = [
            "GG",
            "GA",
            "AG",
            "AA",
            "GC",
            f"G{base}",
            f"A{base}",
            "AC",
            "CA",
            f"{base}A",
            f"{base}G",
            "CG",
            "CC",
            f"C{base}",
            f"{base}C",
            f"{base}{base}"]
        yaxis = xaxis.copy()
        subunit_order = [b.join(f) for f in yaxis for b in xaxis]

    elif subunit_len == 5:
        raise ValueError('The length of the subunit 5 is not a valid option for bconf analyses. Try with 4 or 6.')

    elif subunit_len == 6:
        column = "hexamer"
        # Central pair on the 16 rows, flanking pairs on the 256 columns.
        # The labels used to come from set() (random order) while the data
        # were laid out alphabetically, so labels and cells did not match;
        # the data order is now generated from the labels themselves.
        yaxis = pairs
        xaxis = [f"{left}_{right}" for left in pairs for right in pairs]
        subunit_order = [
            flank[:2] + centre + flank[-2:]
            for centre in yaxis
            for flank in xaxis]
        subplot_kwargs = dict(
            nrows=1,
            ncols=1,
            figsize=(22, 12),
            dpi=300,
            tight_layout=True)

    else:
        raise ValueError(
            f"Unsupported subunit length {subunit_len!r}; expected 3, 4 or 6.")

    # A right merge keeps the order of subunit_order and yields NaN (drawn
    # grey) for subunits missing from df.
    layout = pd.DataFrame(subunit_order, columns=[column])
    df = df.merge(layout, how="right", on=column)
    fig, ax = plt.subplots(**subplot_kwargs)

    colormap = ListedColormap([
        "darkblue",
        "blue",
        "lightblue",
        "lightgreen",
        "lime",
        "orange",
        "red",
        "crimson"])
    colormap.set_bad(color="grey")

    # plot
    grid = df["pct"].to_numpy().reshape((len(yaxis), len(xaxis)))
    im = ax.imshow(grid, cmap=colormap)
    fig.colorbar(im, ax=ax)

    # axes: imshow centres the cells on integer positions, so the labels go
    # on the major ticks. They used to be assigned to minor ticks that were
    # never placed and therefore never drawn.
    xlocs = np.arange(len(xaxis))
    ylocs = np.arange(len(yaxis))
    x_style = {"fontsize": 4}
    if subunit_len == 6:
        # 256 five-character labels only fit when written vertically.
        x_style["rotation"] = 90
    ax.set_xticks(xlocs)
    ax.set_xticklabels(xaxis, **x_style)
    ax.set_yticks(ylocs)
    ax.set_yticklabels(yaxis, fontsize=4)
    ax.set_title((fname + " conformations").upper())

    # save as pdf, then release the figure (nothing is returned)
    file_path = pathlib.Path(save_path) / f"{fname}_percentages.pdf"
    fig.savefig(fname=file_path, format="pdf")
    plt.close(fig)
def correlation_plot(data, fname, save_path, base="T", label_offset=0.05):
    """Save one heatmap per coordinate pair plus one for the whole matrix.

    Parameters
    ----------
    data : pandas.DataFrame
        Matrix whose rows and columns are both selectable by the values of
        index level 0 (assumes a two-level index on both axes -- TODO
        confirm against the caller).
    fname : str
        Stem of the PDF holding the full matrix.
    save_path : str or path-like
        Output directory; pair plots are saved as ``<crd1>_<crd2>.pdf``.
    base : str, optional
        Unused; kept for interface compatibility.
    label_offset : float, optional
        Unused; kept for interface compatibility.
    """
    # define colormap
    cmap = mpl.colors.ListedColormap([
        "blue",
        "cornflowerblue",
        "lightskyblue",
        "white",
        "mistyrose",
        "tomato",
        "red"])
    bounds = [-1.0, -.73, -.53, -.3, .3, .53, .73, 1.0]
    norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
    cmap.set_bad(color="gainsboro")

    # reorder dataset
    # Sorted so the pair order, and therefore the output file names, no
    # longer depend on set iteration order.
    coordinates = sorted(set(data.index.get_level_values(0)))
    # NOTE(review): sort_index(level=1) orders primarily by the second index
    # level; confirm that the block tick positions of the full-matrix plot
    # below match the resulting layout.
    data = data.loc[coordinates][coordinates].sort_index(
        level=1, axis=0).sort_index(level=1, axis=1)

    for crd1, crd2 in itertools.combinations_with_replacement(
            coordinates,
            r=2):
        crd_data = data.loc[crd1][crd2]

        # plot
        fig, ax = plt.subplots(
            1,
            1,
            dpi=300,
            tight_layout=True)
        im = ax.imshow(crd_data, cmap=cmap, norm=norm, aspect='auto')
        fig.colorbar(im, ax=ax)

        # axes: labels follow the actual column / row order of the matrix
        # (they used to come from a set, i.e. in arbitrary order).
        col_labels = list(crd_data.columns)
        row_labels = list(crd_data.index)
        ax.set_xticks(np.arange(len(col_labels)))
        ax.set_xticklabels(col_labels, rotation=90)
        ax.set_yticks(np.arange(len(row_labels)))
        ax.set_yticklabels(row_labels)
        ax.set_title(f"rows: {crd1} | columns: {crd2}")
        fig.tight_layout()

        # save as pdf
        file_path = pathlib.Path(save_path) / f"{crd1}_{crd2}.pdf"
        fig.savefig(fname=file_path, format="pdf")

        plt.close(fig)

    # plot
    fig, ax = plt.subplots(
        1,
        1,
        dpi=300,
        tight_layout=True)
    im = ax.imshow(data, cmap=cmap, norm=norm, aspect='auto')
    fig.colorbar(im, ax=ax)

    # axes: one tick in the middle of each equally sized coordinate block
    start = len(data) // (2 * len(coordinates))
    step = 2 * start
    locs = np.arange(start, len(data) - 1, step)
    ax.set_xticks(locs)
    ax.set_yticks(locs)
    ax.set_xticklabels(coordinates, rotation=90)
    ax.set_yticklabels(coordinates)

    fig.tight_layout()

    # save as pdf, then release the figure (nothing is returned)
    file_path = pathlib.Path(save_path) / f"{fname}.pdf"
    fig.savefig(fname=file_path, format="pdf")
    plt.close(fig)
def basepair_plot(
        data,
        fname,
        save_path,
        base="T",
        label_offset=0.05):
    """Save one heatmap per base-pair group plus one for all rows.

    Rows are grouped by the first four characters of their index label.

    Parameters
    ----------
    data : pandas.DataFrame
        Matrix to plot; its index labels must be sliceable strings. The
        frame is no longer modified (a ``category`` column used to be added
        to the caller's object).
    fname : str
        Stem of the PDF holding all rows.
    save_path : str or path-like
        Output directory; group plots are saved as ``<category>.pdf``.
    base : str, optional
        Unused; kept for interface compatibility.
    label_offset : float, optional
        Unused; kept for interface compatibility.
    """
    # define colormap
    cmap = mpl.colors.ListedColormap([
        "blue",
        "cornflowerblue",
        "lightskyblue",
        "white",
        "mistyrose",
        "tomato",
        "red"])
    bounds = [-.6, -.5, -.4, -.3, .3, .4, .5, .6]
    norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
    cmap.set_bad(color="gainsboro")

    category = data.index.to_series().apply(lambda s: s[0:4])
    # assign() returns a new frame, leaving the caller's data untouched.
    tagged = data.assign(category=category)

    for cat in category.unique():
        cat_df = tagged[tagged["category"] == cat].drop("category", axis=1)

        # plot
        fig, ax = plt.subplots(
            1,
            1,
            dpi=300,
            figsize=(8, 7),
            tight_layout=True)
        im = ax.imshow(cat_df, cmap=cmap, norm=norm, aspect='auto')
        fig.colorbar(im, ax=ax)

        # axes
        ax.set_xticks(np.arange(len(cat_df.columns)))
        ax.set_xticklabels(cat_df.columns.to_list(), rotation=90)
        ax.set_yticks(np.arange(len(cat_df.index)))
        ax.set_yticklabels(cat_df.index.to_list(), fontsize=4)

        ax.set_title(
            f"Correlation for basepair group {cat}")
        fig.tight_layout()

        # save as pdf
        file_path = pathlib.Path(save_path) / f"{cat}.pdf"
        fig.savefig(fname=file_path, format="pdf")

        plt.close(fig)

    # All rows, grouped by category; the helper column is dropped again
    # before plotting.
    ordered = tagged.sort_values(by="category").drop("category", axis=1)

    # plot
    fig, ax = plt.subplots(
        1,
        1,
        dpi=300,
        figsize=(7.5, 5),
        tight_layout=True)
    im = ax.imshow(ordered, cmap=cmap, norm=norm, aspect='auto')
    fig.colorbar(im, ax=ax)

    # axes (the y axis keeps matplotlib's default ticks)
    ax.set_xticks(np.arange(len(ordered.columns)))
    ax.set_xticklabels(ordered.columns.to_list(), rotation=90)

    ax.set_title("Correlation for all basepairs")
    fig.tight_layout()

    # save as pdf, then release the figure (nothing is returned)
    file_path = pathlib.Path(save_path) / f"{fname}.pdf"
    fig.savefig(fname=file_path, format="pdf")
    plt.close(fig)