Coverage for model_workflow/utils/heatmaps_nassa.py: 57%

337 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-23 10:54 +0000

1import pathlib 

2import itertools 

3 

4import numpy as np 

5import pandas as pd 

6import matplotlib as mpl 

7import matplotlib.pyplot as plt 

8from matplotlib.tri import Triangulation 

9from matplotlib.colors import ListedColormap 

10from model_workflow.utils.nucleicacid import NucleicAcid 

11 

def get_axes(subunit_len, base):
    '''Create the axis labels and the subunit ordering for a given subunit length.

    Parameters:
        subunit_len: number of nucleotides per subunit; 3, 4, 5 and 6 are supported.
        base: letter used as the fourth nucleotide next to A, C and G in every label.

    Returns:
        A tuple (xaxis, yaxis, nucleotide_order) of lists of strings.
        nucleotide_order lists every subunit of the requested length.
        For lengths 5 and 6 the axis labels are unique and listed in the order
        in which they first appear in nucleotide_order.

    Raises:
        ValueError: if subunit_len is not 3, 4, 5 or 6.
    '''
    dinucleotides = [
        f"{base}{base}",
        f"{base}C",
        f"C{base}",
        "CC",
        f"G{base}",
        "GC",
        f"A{base}",
        "AC",
        f"{base}G",
        f"{base}A",
        "CG",
        "CA",
        "GG",
        "GA",
        "AG",
        "AA",
    ]

    if subunit_len == 3:
        xaxis = f"G A C {base}".split()
        yaxis = dinucleotides
        # b.join(f) places the x label between the two letters of the y label
        nucleotide_order = [b.join(f) for f in yaxis for b in xaxis]

    elif subunit_len == 4:
        yaxis = dinucleotides
        xaxis = yaxis[::-1]
        nucleotide_order = [b.join(f) for f in yaxis for b in xaxis]

    elif subunit_len == 5:
        nucl = ['A', 'C', 'G', base]
        nucleotide_order = ["".join(letters) for letters in itertools.product(nucl, repeat=5)]
        # dict.fromkeys drops repeated labels while keeping a reproducible order.
        # x holds the central trimer, y the two outer bases.
        xaxis = list(dict.fromkeys(unit[1:-1] for unit in nucleotide_order))
        yaxis = list(dict.fromkeys(unit[0] + "..." + unit[-1:] for unit in nucleotide_order))

    elif subunit_len == 6:
        nucleotide_order = [
            f"{n1}{n2}{n3}"
            for n1 in dinucleotides
            for n2 in dinucleotides
            for n3 in dinucleotides
        ]
        # y holds the two outer dinucleotides, x the central dinucleotide
        yaxis = list(dict.fromkeys(unit[:2] + "_" + unit[-2:] for unit in nucleotide_order))
        xaxis = list(dict.fromkeys(unit[2:4] for unit in nucleotide_order))

    else:
        raise ValueError(
            f"Unsupported subunit length {subunit_len}: expected 3, 4, 5 or 6.")

    return xaxis, yaxis, nucleotide_order

64 

65 

def reorder_labels_rotated_plot(df, subunit_name, tetramer_order):
    '''For the rotated plot: match the nucleotide labels to the values of the dataframe.

    Every subunit of tetramer_order gets a row (subunits missing from df are
    left-merged in, so their value columns hold NaN). Each row is tagged with
    an x label (the central part of the subunit) and a y label (its outer
    bases), and the rows are sorted by y label and then by x label.

    Parameters:
        df: DataFrame with a column named subunit_name holding the subunit strings.
        subunit_name: 'hexamer' or 'pentamer'; also the name of the subunit column.
        tetramer_order: list with every subunit to show. It is not modified.

    Returns:
        A tuple (df1, xaxis1, yaxis1): the sorted frame (with the extra columns
        'subunit_rank', 'xaxis' and 'yaxis') and the unique x and y labels,
        listed in the order they appear among the alphabetically sorted subunits.

    Raises:
        ValueError: if subunit_name is neither 'hexamer' nor 'pentamer'.
    '''
    if subunit_name not in ('hexamer', 'pentamer'):
        raise ValueError(
            f"Unsupported subunit '{subunit_name}': expected 'hexamer' or 'pentamer'.")

    # sorted() builds a new list, so the caller's ordering is left untouched
    unit_names = ["".join(unit) for unit in sorted(tetramer_order)]

    # Rank lookup table; setdefault keeps the first position of a repeated name
    rank_of = {}
    for position, name in enumerate(unit_names):
        rank_of.setdefault(name, position)

    all_units = pd.DataFrame({subunit_name: unit_names})
    merged_df = pd.merge(all_units, df, on=subunit_name, how='left')
    merged_df['subunit_rank'] = merged_df[subunit_name].map(rank_of)
    merged_df = merged_df.sort_values(by='subunit_rank').reset_index(drop=True)

    if subunit_name == 'hexamer':
        centre = len(unit_names[0]) // 2 if unit_names else 0
        # four central bases on x, the remaining outer bases on y
        merged_df['yaxis'] = merged_df[subunit_name].apply(
            lambda x: x[:centre-2] + "...." + x[centre+2:])
        merged_df['xaxis'] = merged_df[subunit_name].apply(
            lambda x: x[centre-2:centre+2])
    else:
        # central trimer on x, first and last base on y
        merged_df['xaxis'] = merged_df[subunit_name].apply(lambda x: x[1:-1])
        merged_df['yaxis'] = merged_df[subunit_name].apply(lambda x: x[0] + "..." + x[-1:])

    # Unique labels are collected before the final sort, i.e. while the rows
    # are still in alphabetical subunit order
    xaxis1 = list(dict.fromkeys(merged_df['xaxis']))
    yaxis1 = list(dict.fromkeys(merged_df['yaxis']))

    df1 = merged_df.sort_values(by=['yaxis', 'xaxis'], ascending=True)
    return df1, xaxis1, yaxis1

149 

def reorder_labels_straight_plot(df, subunit_name, tetramer_order):
    '''For the straight plot: match the nucleotide labels to the values of the dataframe.

    Every subunit of tetramer_order gets a row (subunits missing from df are
    left-merged in, so their value columns hold NaN). Each row is tagged with
    an x label and a y label, and the rows are sorted by y label and then by
    x label. For hexamers x is the central dinucleotide and y the outer
    dinucleotides; for pentamers x is the first and last base and y the
    central trimer.

    Parameters:
        df: DataFrame with a column named subunit_name holding the subunit strings.
        subunit_name: 'hexamer' or 'pentamer'; also the name of the subunit column.
        tetramer_order: list with every subunit to show. It is not modified.

    Returns:
        A tuple (df, xaxis, yaxis): the sorted frame (with the extra columns
        'subunit_rank', 'yaxis' and 'xaxis') and the unique x and y labels,
        listed in the order they appear among the alphabetically sorted subunits.

    Raises:
        ValueError: if subunit_name is neither 'hexamer' nor 'pentamer'.
    '''
    if subunit_name not in ('hexamer', 'pentamer'):
        raise ValueError(
            f"Unsupported subunit '{subunit_name}': expected 'hexamer' or 'pentamer'.")

    # sorted() builds a new list, so the caller's ordering is left untouched
    unit_names = ["".join(unit) for unit in sorted(tetramer_order)]

    # Rank lookup table; setdefault keeps the first position of a repeated name
    rank_of = {}
    for position, name in enumerate(unit_names):
        rank_of.setdefault(name, position)

    all_units = pd.DataFrame({subunit_name: unit_names})
    merged_df = pd.merge(all_units, df, on=subunit_name, how='left')
    merged_df['subunit_rank'] = merged_df[subunit_name].map(rank_of)
    merged_df = merged_df.sort_values(by='subunit_rank').reset_index(drop=True)

    if subunit_name == 'hexamer':
        centre = len(unit_names[0]) // 2 if unit_names else 0
        merged_df['yaxis'] = merged_df[subunit_name].apply(
            lambda x: x[:centre-1] + "_" + x[centre+1:])
        merged_df['xaxis'] = merged_df[subunit_name].apply(
            lambda x: x[centre-1:centre+1])
    else:
        merged_df['yaxis'] = merged_df[subunit_name].apply(lambda x: x[1:-1])
        merged_df['xaxis'] = merged_df[subunit_name].apply(lambda x: x[0] + "..." + x[-1:])

    # Unique labels are collected before the final sort, i.e. while the rows
    # are still in alphabetical subunit order
    xaxis = list(dict.fromkeys(merged_df['xaxis']))
    yaxis = list(dict.fromkeys(merged_df['yaxis']))

    df = merged_df.sort_values(by=['yaxis', 'xaxis'], ascending=True)
    return df, xaxis, yaxis

234 

def _draw_arlequin_panel(
        col1_values,
        col2_values,
        n_cols,
        n_rows,
        xlabels,
        ylabels,
        title,
        colorbar_labels,
        figsize,
        label_offset,
        xlabel_options):
    '''Draw one heatmap whose cells are split into two triangles, one per value column.'''
    xs, ys = np.meshgrid(np.arange(n_cols + 1), np.arange(n_rows + 1))
    stride = n_cols + 1  # number of grid nodes per row of the mesh

    # Each cell (i, j) yields two triangles sharing the node-to-diagonal edge
    col1_triangles = [(i + j*stride, i+1 + j*stride, i+1 + (j+1)*stride)
                      for j in range(n_rows) for i in range(n_cols)]
    col2_triangles = [(i + j*stride, i+1 + (j+1)*stride, i + (j+1)*stride)
                      for j in range(n_rows) for i in range(n_cols)]
    col1_mesh = Triangulation(xs.ravel(), ys.ravel(), col1_triangles)
    col2_mesh = Triangulation(xs.ravel(), ys.ravel(), col2_triangles)

    fig, axs = plt.subplots(
        1,
        1,
        figsize=figsize,
        dpi=300,
        tight_layout=True)

    colormap = plt.get_cmap("bwr", 3).reversed()
    colormap.set_bad(color="grey")
    image = axs.tripcolor(col1_mesh, col1_values, cmap=colormap, vmin=-1, vmax=1)
    axs.tripcolor(col2_mesh, col2_values, cmap=colormap, vmin=-1, vmax=1)

    axs.grid()
    xlocs = np.arange(len(xlabels))
    ylocs = np.arange(len(ylabels))
    # Major ticks stay unlabelled; the labels sit on minor ticks shifted by
    # label_offset
    axs.set_xticks(xlocs)
    axs.set_xticklabels("")
    axs.set_yticks(ylocs)
    axs.set_yticklabels("")
    axs.set_xticks(xlocs + label_offset, minor=True)
    axs.set_xticklabels(xlabels, minor=True, **xlabel_options)
    axs.set_yticks(ylocs + label_offset, minor=True)
    axs.set_yticklabels(ylabels, minor=True, fontsize=6)

    axs.set_xlim(0, n_cols)
    axs.set_ylim(0, n_rows)
    axs.set_title(title)
    cbar = fig.colorbar(image, ax=axs, ticks=[-1, 0, 1], shrink=0.4)
    cbar.ax.set_yticklabels(colorbar_labels)
    return fig, axs


def arlequin_plot(
        df,
        global_mean,
        global_std,
        helpar,
        save_path,
        unit_name,
        unit_len,
        base,
        label_offset=0.5):
    '''Draw the straight and the rotated split-cell heatmaps of a helical parameter.

    Both figures are saved as PDF files in save_path: "<helpar>.pdf" and
    "<helpar>_rotated.pdf". The triangles are coloured from the "col1" and
    "col2" columns of df on a three-colour scale between -1 and 1.

    Parameters:
        df: DataFrame with the subunit column (named unit_name), "col1" and "col2".
        global_mean: value shown in the colorbar labels.
        global_std: value shown in the colorbar labels.
        helpar: name of the helical parameter; used for the title and file names.
        save_path: directory where the PDF files are written.
        unit_name: 'hexamer' or 'pentamer'.
        unit_len: subunit length passed to get_axes.
        base: fourth nucleotide letter passed to get_axes.
        label_offset: shift applied to the minor ticks that carry the labels.

    Returns:
        A tuple (fig, axs, fig_1, axs_1): figure and axes of the straight plot
        followed by those of the rotated plot. The figures are left open.

    Raises:
        ValueError: if unit_name is neither 'hexamer' nor 'pentamer'.
    '''
    if unit_name == 'hexamer':
        n_cols = 4**2
        n_rows = 4**(unit_len - 2)
    elif unit_name == 'pentamer':
        n_cols = 4 * 4
        n_rows = 4 ** 3
    else:
        raise ValueError(
            f"Unsupported subunit '{unit_name}': expected 'hexamer' or 'pentamer'.")

    xaxis, yaxis, tetramer_order = get_axes(unit_len, base)
    df, xaxis, yaxis = reorder_labels_straight_plot(df, unit_name, tetramer_order)
    df1, xaxis1, yaxis1 = reorder_labels_rotated_plot(df, unit_name, tetramer_order)

    title = helpar.upper()
    # "\\pm" keeps the backslash for mathtext without an invalid escape sequence
    colorbar_labels = [
        f"< {global_mean:.2f}-{global_std:.2f}",
        f"{global_mean:.2f}$\\pm${global_std:.2f}",
        f"> {global_mean:.2f}+{global_std:.2f}"]

    # STRAIGHT PLOT
    fig, axs = _draw_arlequin_panel(
        df["col1"].to_numpy(),
        df["col2"].to_numpy(),
        n_cols,
        n_rows,
        xaxis,
        yaxis,
        title,
        colorbar_labels,
        figsize=(8, 18),
        label_offset=label_offset,
        xlabel_options={"fontsize": 8})
    file_path = pathlib.Path(save_path) / f"{helpar}.pdf"
    fig.savefig(fname=file_path, format="pdf")

    # ROTATED PLOT: same grid with rows and columns swapped
    fig_1, axs_1 = _draw_arlequin_panel(
        df1["col1"].to_numpy(),
        df1["col2"].to_numpy(),
        n_rows,
        n_cols,
        xaxis1,
        yaxis1,
        title,
        colorbar_labels,
        figsize=(22, 6),
        label_offset=label_offset,
        xlabel_options={"fontsize": 4, "rotation": 90})
    file_path1 = pathlib.Path(save_path) / f"{helpar}_rotated.pdf"
    fig_1.savefig(fname=file_path1, format="pdf")
    return fig, axs, fig_1, axs_1

365 

366 

def bconf_heatmap(df, fname, save_path, subunit_len, base="T", label_offset=0.05):
    '''Plot the values of df["pct"] as a heatmap and save it as "<fname>_percentages.pdf".

    The rows of df are aligned (right merge) to a fixed list of subunits, so
    subunits missing from df end up as NaN and are drawn in grey.

    Parameters:
        df: DataFrame with a "pct" column and a subunit column named
            "trimer", "tetramer" or "hexamer" according to subunit_len.
        fname: name used in the plot title and in the output file name.
        save_path: directory where the PDF file is written.
        subunit_len: 3, 4 or 6.
        base: fourth nucleotide letter used next to A, C and G.
        label_offset: accepted for signature compatibility; not used here.

    Raises:
        ValueError: if subunit_len is not 3, 4 or 6.
    '''
    print('subunit len: ', subunit_len)
    if subunit_len == 5:
        raise ValueError('The length of the subunit 5 is not a valid option for bconf analyses. Try with 4 or 6.')
    if subunit_len not in (3, 4, 6):
        raise ValueError(
            f"Unsupported subunit length {subunit_len}: expected 3, 4 or 6.")

    if subunit_len == 3:
        yaxis = [
            f"{base}{base}",
            f"{base}C",
            f"C{base}",
            "CC",
            f"G{base}",
            "GC",
            f"A{base}",
            "AC",
            f"{base}G",
            f"{base}A",
            "CG",
            "CA",
            "GG",
            "GA",
            "AG",
            "AA"]
        xaxis = f"G A C {base}".split()
        unit_column = "trimer"
        unit_order = [b.join(f) for f in yaxis for b in xaxis]
        grid_shape = (16, 4)
        figure_options = {}

    elif subunit_len == 4:
        xaxis = [
            "GG",
            "GA",
            "AG",
            "AA",
            "GC",
            f"G{base}",
            f"A{base}",
            "AC",
            "CA",
            f"{base}A",
            f"{base}G",
            "CG",
            "CC",
            f"C{base}",
            f"{base}C",
            f"{base}{base}"]
        yaxis = xaxis.copy()
        unit_column = "tetramer"
        unit_order = [b.join(f) for f in yaxis for b in xaxis]
        grid_shape = (16, 16)
        figure_options = {}

    else:
        dinucleotides = [
            f"{base}{base}",
            f"{base}C",
            f"C{base}",
            "CC",
            f"G{base}",
            "GC",
            f"A{base}",
            "AC",
            f"{base}G",
            f"{base}A",
            "CG",
            "CA",
            "GG",
            "GA",
            "AG",
            "AA"]
        hexamers = [
            f"{n1}{n2}{n3}"
            for n1 in dinucleotides
            for n2 in dinucleotides
            for n3 in dinucleotides]
        # dict.fromkeys keeps the labels unique and in a reproducible order.
        # NOTE(review): the data grid below is built from the alphabetically
        # sorted hexamers, so these labels may not line up with its rows and
        # columns - confirm the intended layout.
        xaxis = list(dict.fromkeys(unit[:2] + "_" + unit[-2:] for unit in hexamers))
        yaxis = list(dict.fromkeys(unit[2:4] for unit in hexamers))
        unit_column = "hexamer"
        unit_order = sorted(hexamers)
        grid_shape = (16, 256)
        figure_options = {"figsize": (22, 12), "dpi": 300, "tight_layout": True}

    # The right merge keeps the rows in the order of unit_order
    ordered_units = pd.DataFrame(unit_order, columns=[unit_column])
    df = df.merge(ordered_units, how="right", on=unit_column)
    fig, ax = plt.subplots(1, 1, **figure_options)

    colormap = ListedColormap([
        "darkblue",
        "blue",
        "lightblue",
        "lightgreen",
        "lime",
        "orange",
        "red",
        "crimson"])
    colormap.set_bad(color="grey")

    # plot
    im = ax.imshow(df["pct"].to_numpy().reshape(grid_shape), cmap=colormap)
    plt.colorbar(im)

    # axes
    # NOTE(review): the labels are attached to minor ticks while only major
    # ticks are positioned here - confirm they show up as intended.
    xlocs = np.arange(len(xaxis))
    ylocs = np.arange(len(yaxis))
    ax.set_xticks(xlocs)
    ax.set_xticklabels(xaxis, minor=True, fontsize=4)
    ax.set_yticks(ylocs)
    ax.set_yticklabels(yaxis, minor=True, fontsize=4)
    ax.set_title((fname + " conformations").upper())

    # save as pdf
    file_path = pathlib.Path(save_path) / f"{fname}_percentages.pdf"
    fig.savefig(fname=file_path, format="pdf")

    # nothing is returned, so release the figure instead of leaving it in pyplot
    plt.close(fig)

492 

493 

def correlation_plot(data, fname, save_path, base="T", label_offset=0.05):
    '''Plot correlation heatmaps per pair of coordinates and for the whole dataset.

    One PDF named "<crd1>_<crd2>.pdf" is written for every pair of first-level
    index values (including each value paired with itself), followed by
    "<fname>.pdf" with the complete matrix. Values are coloured in seven
    bands between -1 and 1.

    Parameters:
        data: DataFrame whose rows and columns can both be selected by the
            first-level values of its row index and sorted on index level 1.
        fname: file name (without extension) of the overall plot.
        save_path: directory where the PDF files are written.
        base: accepted for signature compatibility; not used here.
        label_offset: accepted for signature compatibility; not used here.
    '''
    # define colormap
    cmap = mpl.colors.ListedColormap([
        "blue",
        "cornflowerblue",
        "lightskyblue",
        "white",
        "mistyrose",
        "tomato",
        "red"])
    bounds = [-1.0, -.73, -.53, -.3, .3, .53, .73, 1.0]
    norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
    cmap.set_bad(color="gainsboro")

    # reorder dataset
    # dict.fromkeys keeps the coordinates unique and in first-seen order, so
    # the pairs (and the output file names) are the same on every run
    coordinates = list(dict.fromkeys(data.index.get_level_values(0)))
    data = data.loc[coordinates][coordinates].sort_index(
        level=1, axis=0).sort_index(level=1, axis=1)

    for crd1, crd2 in itertools.combinations_with_replacement(
            coordinates,
            r=2):
        crd_data = data.loc[crd1][crd2]

        # plot
        fig, ax = plt.subplots(
            1,
            1,
            dpi=300,
            tight_layout=True)
        im = ax.imshow(crd_data, cmap=cmap, norm=norm, aspect='auto')
        plt.colorbar(im)

        # axes
        # labels follow the row order of the plotted block
        units = list(dict.fromkeys(crd_data.index))
        xlocs = np.arange(len(units))
        ax.set_xticks(xlocs)
        ax.set_xticklabels(units, rotation=90)

        ylocs = np.arange(len(units))
        ax.set_yticks(ylocs)
        ax.set_yticklabels(units)
        ax.set_title(f"rows: {crd1} | columns: {crd2}")
        plt.tight_layout()

        # save as pdf
        file_path = pathlib.Path(save_path) / f"{crd1}_{crd2}.pdf"
        fig.savefig(fname=file_path, format="pdf")

        plt.close(fig)

    # plot
    fig, ax = plt.subplots(
        1,
        1,
        dpi=300,
        tight_layout=True)
    im = ax.imshow(data, cmap=cmap, norm=norm, aspect='auto')
    plt.colorbar(im)

    # axes
    # one tick per coordinate, starting half a block in and spaced one block apart
    start = len(data) // (2 * len(coordinates))
    step = 2 * start
    locs = np.arange(start, len(data)-1, step)
    ax.set_xticks(locs)
    ax.set_yticks(locs)
    ax.set_xticklabels(coordinates, rotation=90)
    ax.set_yticklabels(coordinates)

    plt.tight_layout()

    # save as pdf
    file_path = pathlib.Path(save_path) / f"{fname}.pdf"
    fig.savefig(fname=file_path, format="pdf")

    # nothing is returned, so release the last figure as well
    plt.close(fig)

568 

569 

def basepair_plot(
        data,
        fname,
        save_path,
        base="T",
        label_offset=0.05):
    '''Plot correlation heatmaps per basepair group and for all basepairs.

    Rows are grouped by the first four characters of their index label. One
    PDF named "<group>.pdf" is written per group, followed by "<fname>.pdf"
    with all rows sorted by group. Values are coloured in seven bands between
    -0.6 and 0.6.

    Parameters:
        data: DataFrame with string index labels. It is not modified.
        fname: file name (without extension) of the overall plot.
        save_path: directory where the PDF files are written.
        base: accepted for signature compatibility; not used here.
        label_offset: accepted for signature compatibility; not used here.
    '''
    # define colormap
    cmap = mpl.colors.ListedColormap([
        "blue",
        "cornflowerblue",
        "lightskyblue",
        "white",
        "mistyrose",
        "tomato",
        "red"])
    bounds = [-.6, -.5, -.4, -.3, .3, .4, .5, .6]
    norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
    cmap.set_bad(color="gainsboro")

    category = data.index.to_series().apply(lambda s: s[0:4])
    # assign() returns a new frame, so the helper column never reaches the
    # caller's DataFrame
    data = data.assign(category=category)

    for cat in category.unique():
        cat_df = data[data["category"] == cat].drop("category", axis=1)

        # plot
        fig, ax = plt.subplots(
            1,
            1,
            dpi=300,
            figsize=(8, 7),
            tight_layout=True)
        im = ax.imshow(cat_df, cmap=cmap, norm=norm, aspect='auto')
        plt.colorbar(im)

        # axes
        xlocs = np.arange(len(cat_df.columns))
        ax.set_xticks(xlocs)
        ax.set_xticklabels(cat_df.columns.to_list(), rotation=90)

        ylocs = np.arange(len(cat_df.index))
        ax.set_yticks(ylocs)
        ax.set_yticklabels(cat_df.index.to_list(), fontsize=4)

        ax.set_title(
            f"Correlation for basepair group {cat}")
        plt.tight_layout()

        # save as pdf
        file_path = pathlib.Path(save_path) / f"{cat}.pdf"
        fig.savefig(fname=file_path, format="pdf")

        plt.close(fig)

    data = data.sort_values(by="category").drop("category", axis=1)

    # plot
    fig, ax = plt.subplots(
        1,
        1,
        dpi=300,
        figsize=(7.5, 5),
        tight_layout=True)
    im = ax.imshow(data, cmap=cmap, norm=norm, aspect='auto')
    plt.colorbar(im)

    # axes (only the columns are labelled in the overall plot)
    xlocs = np.arange(len(data.columns))
    ax.set_xticks(xlocs)
    ax.set_xticklabels(data.columns.to_list(), rotation=90)

    ax.set_title("Correlation for all basepairs")
    plt.tight_layout()

    # save as pdf
    file_path = pathlib.Path(save_path) / f"{fname}.pdf"
    fig.savefig(fname=file_path, format="pdf")

    # nothing is returned, so release the last figure as well
    plt.close(fig)