Coverage for mddb_workflow/core/dataset.py: 13%

146 statements  

coverage.py v7.11.0, created at 2025-10-29 15:48 +0000

from mddb_workflow.utils.auxiliar import load_yaml, is_glob
from mddb_workflow.mwf import workflow
import subprocess
import os
import glob
import yaml
import jinja2
import pandas as pd


class Dataset:
    """
    Class to manage and process a dataset of MDDB projects.
    """

    def __init__(self, dataset_yaml_path: str):
        """
        Initializes the Dataset object.

        Args:
            dataset_yaml_path (str): Path to the dataset YAML file.
        """
        self.root_path = os.path.dirname(os.path.abspath(dataset_yaml_path))
        self.config = load_yaml(dataset_yaml_path)
        # Properties cache
        self._project_directories = None
        self._status = None
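    # A minimal sketch of the dataset YAML this class expects, inferred from the
    # config lookups in this module; only 'global' and 'project_directories' appear
    # in the code, so the example entries below are hypothetical:
    #
    #   global:
    #     project_directories:
    #       - projects/replica_*    # glob pattern, expanded under the dataset root
    #       - projects/p001         # literal path, resolved relative to the root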

    @property
    def project_directories(self) -> list[str]:
        """
        Retrieves the list of project directories from the dataset configuration.

        Returns:
            list: List of resolved project directory paths.
        """
        if self._project_directories is not None:
            return self._project_directories

        config_project_directories = self.config.get('global', {}).get('project_directories', [])
        self._project_directories = []
        for dir_pattern in config_project_directories:
            if is_glob(dir_pattern):
                matched_dirs = glob.glob(os.path.join(self.root_path, dir_pattern))
                # Keep only directories
                self._project_directories.extend([p for p in matched_dirs if os.path.isdir(p)])
            else:
                p = dir_pattern
                if not os.path.isabs(p):
                    p = os.path.abspath(os.path.join(self.root_path, p))
                # Ensure the resolved path is under the dataset root. Note os.path.commonpath
                # compares whole path components, whereas os.path.commonprefix compares raw
                # strings and would wrongly accept e.g. '/data/rootX' against '/data/root'.
                if os.path.commonpath([p, self.root_path]) != self.root_path:
                    raise ValueError(f"Project directory '{p}' is outside the dataset root '{self.root_path}'")
                self._project_directories.append(p)

        return self._project_directories
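    # Note: `is_glob` is imported from mddb_workflow.utils.auxiliar; the assumption
    # above is that it reports whether a pattern contains glob metacharacters such
    # as '*', '?' or '[', so 'projects/*' is expanded while 'projects/p001' is not.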

    def generate_inputs_yaml(self, inputs_template_path: str, input_generator: callable, overwrite: bool = False):
        """
        Generates an inputs.yaml file in each project directory based on the dataset configuration.

        Args:
            inputs_template_path (str): The file path to the Jinja2 template file that will be
                used to generate the inputs YAML files.
            input_generator (callable): A callable used to generate input values. It is called
                with the project directory name (DIR) as its only argument and its return
                value is passed to the template as the 'title' variable.
            overwrite (bool): Whether to overwrite existing inputs.yaml files. Default is False.
        """
        # Load the template
        with open(inputs_template_path, 'r') as f:
            template_str = f.read()

        template = jinja2.Template(template_str)

        for project_dir in self.project_directories:
            inputs_yaml_path = os.path.join(project_dir, 'inputs.yaml')
            if os.path.exists(inputs_yaml_path) and not overwrite:
                continue

            # Get the directory name
            DIR = os.path.basename(os.path.normpath(project_dir))

            # Render the template with project defaults
            rendered_yaml = template.render(DIR=DIR, title=input_generator(DIR))

            # Write the rendered YAML to inputs.yaml
            with open(inputs_yaml_path, 'w') as f:
                f.write(rendered_yaml)
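    # A minimal sketch of an inputs template compatible with the render call above.
    # Only the DIR and title variables are supplied, so anything else must be
    # literal YAML; the field names here are hypothetical:
    #
    #   name: {{ DIR }}
    #   title: {{ title }}
    #   description: Auto-generated inputs for project {{ DIR }}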

    @property
    def status(self) -> pd.DataFrame:
        """
        Retrieves the last line of the most recent log in each project directory
        and returns the result as a pandas DataFrame.

        Returns:
            pd.DataFrame: Index is the project directory relative to the dataset root;
                columns: state, message, log_file, err_file, group.
        """
        if self._status is not None:
            return self._status

        rows = []
        for project_dir in self.project_directories:
            # RUBEN: for now we use the same log pattern as in launch_workflow;
            # in the future the state will be recovered from .register/.mwf_cache
            log_files = glob.glob(os.path.join(project_dir, 'logs', 'mwf*[0-9].out'))
            err_files = glob.glob(os.path.join(project_dir, 'logs', 'mwf*[0-9].err'))

            if log_files:
                log_files.sort()
                log_file = log_files[-1]  # Take the most recent one
                with open(log_file, 'r') as f:
                    lines = f.read().splitlines()
                # Guard against an empty log file
                last_line = lines[-1].strip() if lines else ''

                state = 'done' if last_line == 'Done!' else 'error'
                message = last_line
            else:
                state, message, log_file = 'not_run', 'No output log available', None

            # Handle error log files
            if err_files:
                err_files.sort()
                err_file = err_files[-1]  # Take the most recent one
            else:
                err_file = None

            rows.append({
                'rel_path': os.path.relpath(project_dir, self.root_path),
                'state': state,
                'message': message,
                'log_file': os.path.relpath(log_file, project_dir) if log_file else '',
                'err_file': os.path.relpath(err_file, project_dir) if err_file else ''
            })

        df = pd.DataFrame(rows).set_index('rel_path').sort_index()
        # Assign an integer group id for identical messages, with messages sorted first
        unique_messages = sorted(df['message'].unique())
        if 'Done!' in unique_messages:
            # Ensure 'Done!' is always group 0, then assign the other messages
            unique_messages.remove('Done!')
            mapping = {'Done!': 0}
            mapping.update({msg: idx + 1 for idx, msg in enumerate(unique_messages)})
        else:
            mapping = {msg: idx for idx, msg in enumerate(unique_messages)}
        df['group'] = df['message'].map(mapping).astype(int)
        self._status = df
        return self._status
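    # Example of the resulting frame (paths and messages are hypothetical):
    #
    #   rel_path  state    message                  log_file         err_file         group
    #   proj_001  done     Done!                    logs/mwf-42.out  logs/mwf-42.err  0
    #   proj_002  error    Traceback (most recent…  logs/mwf-43.out  logs/mwf-43.err  1
    #   proj_003  not_run  No output log available                                    2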

    def show_groups(self, cmd=False):
        """
        Displays the groups of projects based on their status messages.

        Args:
            cmd (bool): If True, print each group's message and project list to
                stdout and return None; otherwise return a DataFrame with one row
                per group, holding its message and its project count.
        """
        if cmd:
            status = self.status
            grouped = status.groupby('group')
            for group_id, group_df in grouped:
                print(f"Group {group_id}:")
                print(f"Message: {group_df['message'].iloc[0]}")
                print("Projects:")
                for rel_path in group_df.index:
                    print(f" - {rel_path}")
                print()
        else:
            grouped = self.status.groupby('group').agg({
                'message': 'first',
                'state': 'count'
            }).rename(columns={'state': 'count'})
            return grouped
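    # Usage sketch, assuming a Dataset instance `ds`:
    #   ds.show_groups()          # returns a summary DataFrame: message and count per group
    #   ds.show_groups(cmd=True)  # prints every project path, grouped by status message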

    def status_with_links(self) -> pd.DataFrame:
        """
        Returns the status DataFrame with clickable log file links.
        """
        df = self.status.copy()

        # Create a clickable link for the given log column, pointing at the local file
        def make_link(row, column):
            if row[column] and row['state'] != 'not_run':
                project_dir = os.path.join(self.root_path, row.name)
                log_path = os.path.join(project_dir, row[column])
                # Create a file:// URL for local files
                file_url = f"file://{log_path}"
                return f'<a href="{file_url}" target="_blank">{row[column]}</a>'
            return row[column]

        df['log_file_link'] = df.apply(make_link, axis=1, column='log_file')
        df['err_file_link'] = df.apply(make_link, axis=1, column='err_file')
        return df

    def display_status_with_links(self):
        """
        Display the status DataFrame with clickable links in Jupyter.
        """
        from IPython.display import HTML, display

        df = self.status_with_links()
        # Drop the original log_file columns and rename the link columns
        df_display = df.drop(['log_file', 'err_file'], axis=1)
        df_display = df_display.rename(columns={
            'log_file_link': 'log_file',
            'err_file_link': 'err_file'
        })

        # Convert to HTML and display
        html = df_display.to_html(escape=False)
        display(HTML(html))

    def launch_workflow(self,
                        include_groups: list[int] = None,
                        exclude_groups: list[int] = None,
                        slurm: bool = False,
                        job_template: str = None):
        """
        Launches the workflow for each project directory in the dataset.

        Args:
            include_groups (list[int]):
                List of group IDs to be run.
            exclude_groups (list[int]):
                List of group IDs to be excluded.
            slurm (bool):
                Whether to submit the workflow to SLURM.
            job_template (str):
                Path to the SLURM job template file. You can use Jinja2
                templating to customize the job script using the fields of
                the inputs YAML and the columns of the project status DataFrame.
        """
        # Avoid mutable default arguments
        include_groups = include_groups or []
        exclude_groups = exclude_groups or []
        # Include/exclude groups should not intersect
        if include_groups and exclude_groups:
            intersection = set(include_groups).intersection(set(exclude_groups))
            if intersection:
                raise ValueError(f"include_groups and exclude_groups intersect: {intersection}")

        if slurm and not job_template:
            raise ValueError("job_template must be provided when slurm is True")
        for project_dir in self.project_directories:
            project_status = self.status.loc[os.path.relpath(project_dir, self.root_path)].to_dict()
            # Check group inclusion/exclusion
            group_id = project_status['group']
            if group_id in exclude_groups or (include_groups and group_id not in include_groups):
                continue
            # Launch workflow
            if slurm:
                # SLURM execution
                inputs_yaml_path = os.path.join(project_dir, 'inputs.yaml')
                if not os.path.exists(inputs_yaml_path):
                    print(f"Warning: {inputs_yaml_path} not found. Skipping {project_dir}")
                    continue

                inputs_config = load_yaml(inputs_yaml_path)

                with open(job_template, 'r') as f:
                    template_str = f.read()

                template = jinja2.Template(template_str)
                rendered_script = template.render(**inputs_config, **project_status)

                job_script_path = os.path.join(project_dir, 'mwf_slurm_job.sh')
                log_dir = os.path.join(project_dir, 'logs')
                os.makedirs(log_dir, exist_ok=True)
                with open(job_script_path, 'w') as f:
                    f.write(rendered_script)

                os.chmod(job_script_path, 0o755)

                print(f"Submitting SLURM job for {project_dir}")
                # sbatch options must precede the script name; options placed after
                # it would be passed to the job script instead of to sbatch
                subprocess.run(['sbatch',
                                '--output', os.path.join(log_dir, 'mwf-%j.out'),
                                '--error', os.path.join(log_dir, 'mwf-%j.err'),
                                'mwf_slurm_job.sh'],
                               cwd=project_dir)

            else:
                # Normal Python execution
                raise NotImplementedError("Python execution is not implemented yet.")
                # workflow(working_directory=project_dir)  # unreachable until the raise above is removed
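
# A minimal sketch of a SLURM job template usable with launch_workflow. The render
# call exposes every field of the project's inputs.yaml plus the status columns
# (state, message, log_file, err_file, group); 'name' below is a hypothetical
# inputs.yaml field and the launch command is left as a placeholder:
#
#   #!/bin/bash
#   #SBATCH --job-name=mwf-{{ name }}
#   #SBATCH --time=24:00:00
#   echo "Previous state: {{ state }}"
#   <command that runs the workflow goes here>
#
# Usage sketch, assuming a dataset.yaml at a hypothetical path:
#
#   ds = Dataset('/data/mddb/dataset.yaml')
#   ds.generate_inputs_yaml('inputs.yaml.j2', input_generator=str.title)
#   ds.show_groups(cmd=True)
#   ds.launch_workflow(exclude_groups=[0], slurm=True, job_template='job_template.sh.j2')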