Coverage for mddb_workflow / analyses / density.py: 58%

57 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 18:45 +0000

1from mddb_workflow.utils.pyt_spells import get_reduced_pytraj_trajectory 

2from mddb_workflow.utils.auxiliar import save_json, load_json 

3from mddb_workflow.utils.constants import OUTPUT_DENSITY_FILENAME 

4from mddb_workflow.utils.type_hints import * 

5import pytraj as pt 

6 

7 

8def density( 

9 structure_file: 'File', 

10 trajectory_file: 'File', 

11 output_directory: str, 

12 membrane_map: dict, 

13 structure: 'Structure', 

14 snapshots: int, 

15 density_types=['number', 'mass', 'charge', 'electron'], 

16 frames_limit=1000 

17): 

18 """Membrane density analysis.""" 

19 if membrane_map is None or membrane_map['n_mems'] == 0: 

20 print('-> Skipping density analysis') 

21 return 

22 

23 # Set the main output filepath 

24 output_analysis_filepath = f'{output_directory}/{OUTPUT_DENSITY_FILENAME}' 

25 

26 # Load 

27 tj, frame_step, frames_count = get_reduced_pytraj_trajectory( 

28 structure_file.path, trajectory_file.path, snapshots, frames_limit) 

29 

30 # Set every selections to be analyzed separately 

31 components = [] 

32 for chain in structure.chains: 

33 components.append({ 

34 'name': chain.name, 

35 'selection': chain.get_selection(), 

36 'number': {}, 

37 'mass': {}, 

38 'charge': {}, # charge will be all 0 because we cannot add charges to pytraj topology 

39 'electron': {} 

40 }) 

41 # Parse selections to pytraj masks 

42 pytraj_masks = [component['selection'].to_pytraj() for component in components] 

43 # Add polar atoms selection 

44 polar_atoms = [] 

45 for n in range(membrane_map['n_mems']): 

46 polar_atoms.extend(membrane_map['mems'][str(n)]['polar_atoms']['top']) 

47 polar_atoms.extend(membrane_map['mems'][str(n)]['polar_atoms']['bot']) 

48 components.append({ 

49 'name': 'polar', 

50 'selection': polar_atoms, 

51 'number': {}, 'mass': {}, 'charge': {}, 'electron': {} 

52 }) 

53 pytraj_masks.append('@' + ', '.join(map(str, polar_atoms))) 

54 

55 # Run pytraj 

56 for density_type in density_types: 

57 out = pt.density(tj, pytraj_masks, density_type) 

58 # Iterate pytraj results 

59 results = iter(out.values()) 

60 for component in components: 

61 # Mine pytraj results 

62 component[density_type]['dens'] = list(next(results)) 

63 component[density_type]['stdv'] = list(next(results)) 

64 

65 # Parse the selection to atom indices 

66 # Selections could be removed to make the file smaller 

67 for component in components: 

68 if component['name'] == 'polar': continue 

69 component['selection'] = component['selection'].atom_indices 

70 # Export results 

71 data = {'data': {'comps': components, 'z': list(out['z'])}} 

72 save_json(data, output_analysis_filepath) 

73 

74 

def plot_density(output_analysis_filepath):
    """Plot density analysis grouped by density type.

    Draw one panel per density type ('number', 'mass', 'charge', 'electron')
    in a 2x2 grid. Each component gets a line, plus a shaded band of one
    standard deviation around it when deviations are available. The figure
    is then shown.

    Args:
        output_analysis_filepath: Path to the JSON file written by ``density``.
    """
    import matplotlib.pyplot as plt

    # Load the density analysis results
    data = load_json(output_analysis_filepath)
    components = data['data']['comps']
    z = data['data']['z']

    # Group plots by density type in a 2x2 grid
    density_types = ['number', 'mass', 'charge', 'electron']
    _, axes = plt.subplots(2, 2, figsize=(12, 10))

    for ax, density_type in zip(axes.flatten(), density_types):
        for component in components:
            values = component.get(density_type, {})
            dens = values.get('dens', [])
            if not dens:
                continue
            ax.plot(z, dens, label=f"{component['name']}")
            stdv = values.get('stdv', [])
            # Skip the band when some points lack a deviation; fill_between
            # would otherwise receive arrays of mismatched lengths
            if len(stdv) >= len(dens):
                lower = [d - s for d, s in zip(dens, stdv)]
                upper = [d + s for d, s in zip(dens, stdv)]
                ax.fill_between(z, lower, upper, alpha=0.2)

        title = density_type.capitalize()
        ax.set_title(f"{title} Density Analysis")
        ax.set_xlabel("Z-axis")
        ax.set_ylabel(f"{title} Density")
        # Only add a legend when something was plotted (e.g. a density type
        # that was not computed leaves the panel empty)
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        ax.grid()

    plt.tight_layout()
    plt.show()