Coverage for mddb_workflow / core / dataset.py: 11%

188 statements  

coverage.py v7.12.0, created at 2025-12-03 18:45 +0000

from mddb_workflow.utils.auxiliar import load_yaml, is_glob, warn
from mddb_workflow.utils.type_hints import *
import pandas as pd
import subprocess
import jinja2
import time
import glob
import os

class Dataset:
    """Class to manage and process a dataset of MDDB projects."""

    def __init__(self, dataset_yaml_path: str):
        """Initialize the Dataset object.

        Args:
            dataset_yaml_path (str): Path to the dataset YAML file.

        """
        self.root_path = os.path.dirname(os.path.abspath(dataset_yaml_path))
        self.config = load_yaml(dataset_yaml_path)
        # Properties cache
        self.project_directories = self.get_project_directories()
        self.groups = self.get_groups()
        self._status = None
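    # Illustrative sketch (not part of the original module): a dataset YAML with the
    # keys read by this class ('projects', 'ignore', 'groups'). The directory names
    # and glob patterns below are hypothetical; paths resolve relative to the YAML's
    # own directory (self.root_path).
    #
    #   projects:
    #     - "simulations/*"        # glob pattern, only matched directories are kept
    #     - some_project           # plain relative path, must stay under the root
    #   ignore:
    #     - "simulations/broken_*"
    #   groups:
    #     replicas:
    #       - "simulations/rep_*"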

    def _resolve_directory_patterns(self, dir_patterns: list[str]) -> list[str]:
        """Resolve directory patterns (glob or absolute/relative paths).

        Validates that resolved paths are under the dataset root.

        Args:
            dir_patterns (list[str]): List of directory patterns to resolve.

        Returns:
            list[str]: List of resolved directory paths.

        Raises:
            ValueError: If a resolved path is outside the dataset root.

        """
        directories = []
        for dir_pattern in dir_patterns:
            if is_glob(dir_pattern):
                matched_dirs = glob.glob(os.path.join(self.root_path, dir_pattern))
                # Keep only directories
                directories.extend([p for p in matched_dirs if os.path.isdir(p)])
            else:
                p = dir_pattern
                if not os.path.isabs(p):
                    p = os.path.abspath(os.path.join(self.root_path, p))
                # Ensure the resolved path is under the dataset root

                # (commonpath, rather than commonprefix, avoids accepting sibling
                # directories that merely share a name prefix with the root)
                if os.path.commonpath([p, self.root_path]) != self.root_path:
                    raise ValueError(f"Project directory '{p}' is outside the dataset root '{self.root_path}'")

                directories.append(p)
        return directories

    def get_project_directories(self) -> list[str]:
        """Retrieve the list of project directories from the dataset configuration."""
        dirs = self._resolve_directory_patterns(self.config.get('projects', []))
        ignore_dirs = self._resolve_directory_patterns(self.config.get('ignore', []))
        dirs = [d for d in dirs if d not in ignore_dirs]
        if not dirs:
            raise ValueError("No project directories found in the dataset configuration.")
        return dirs

    def get_groups(self) -> dict[str, list[str]]:
        """Retrieve the groups of project directories from the dataset configuration."""
        groups = {}
        for group in self.config.get('groups', []):
            groups[group] = self._resolve_directory_patterns(self.config['groups'][group])
        return groups

    def generate_inputs_yaml(self,
                             inputs_template_path: str,
                             input_generator: Callable,
                             overwrite: bool = False
                             ):
        """Generate an inputs.yaml file in each project directory based on the dataset configuration.

        Args:
            inputs_template_path (str): Path to the Jinja2 template file used to
                generate the inputs YAML files.
            input_generator (Callable): A callable that generates input values. It is
                currently called with the project directory name (DIR) as its only argument.
            overwrite (bool): Whether to overwrite existing inputs.yaml files. Default is False.

        """

        # Load the template
        with open(inputs_template_path, 'r') as f:
            template_str = f.read()

        template = jinja2.Template(template_str)
        skipped = 0
        generated = 0
        for project_dir in self.project_directories:
            inputs_yaml_path = os.path.join(project_dir, 'inputs.yaml')
            if os.path.exists(inputs_yaml_path) and not overwrite:
                skipped += 1
                continue
            generated += 1
            # Get the directory name
            DIR = os.path.basename(os.path.normpath(project_dir))
            # Render the template with project defaults
            rendered_yaml = template.render(DIR=DIR, title=input_generator(DIR))
            # Write the rendered YAML to inputs.yaml
            with open(inputs_yaml_path, 'w') as f:
                f.write(rendered_yaml)
        print(f"Generated {generated} inputs.yaml files. Skipped {skipped} existing files.")
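    # Illustrative sketch (not part of the original module): a minimal Jinja2 inputs
    # template and a call to generate_inputs_yaml. Only DIR and title are passed to
    # template.render() above, so only those placeholders are available; the template
    # path and the lambda below are hypothetical.
    #
    #   # inputs_template.yaml.j2
    #   name: "{{ DIR }}"
    #   title: "{{ title }}"
    #
    #   dataset = Dataset('dataset.yaml')
    #   dataset.generate_inputs_yaml('inputs_template.yaml.j2',
    #                                input_generator=lambda d: f"MD simulation of {d}",
    #                                overwrite=False)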

    @property
    def status(self) -> pd.DataFrame:
        """Retrieve the last log line from every project directory as a pandas DataFrame.

        Returns:
            pd.DataFrame: Indexed by project rel_path; columns: state, message,
                log_file, err_file, last_modified and group.

        """

        if self._status is not None:
            return self._status

        rows = []
        for project_dir in self.project_directories:

            # RUBEN: for now we use the same log pattern as in launch_workflow;
            # in the future the state will be recovered from .register/.mwf_cache

            log_files = glob.glob(os.path.join(project_dir, 'logs', 'mwf*[0-9].out'))
            err_files = glob.glob(os.path.join(project_dir, 'logs', 'mwf*[0-9].err'))

            if log_files:
                log_files.sort(key=lambda x: os.path.getmtime(x))
                log_files = [log_files[-1]]  # Take the most recent one
                with open(log_files[0], 'r') as f:
                    lines = f.read().splitlines()
                if lines:
                    last_line = lines[-1].strip()
                    if len(last_line) > 80:
                        last_line = last_line[:80] + '...'
                else:
                    last_line = ''

                if last_line == 'Done!':
                    state, message, log_file = 'done', last_line, log_files[0]
                else:
                    state, message, log_file = 'error', last_line if last_line else 'Empty log file', log_files[0]
            else:
                state, message, log_file = 'not_run', 'No output log available', None

            # Handle error log files

            if err_files:
                err_files.sort(key=lambda x: os.path.getmtime(x))
                err_file = err_files[-1]  # Take the most recent one
            else:
                err_file = None

            # Check if files were modified recently (within the last 5 minutes) to detect a running state
            current_time = time.time()
            recently_modified_threshold = 300  # 5 minutes in seconds

            last_modified = ''
            if (log_file and os.path.exists(log_file) and (current_time - os.path.getmtime(log_file)) < recently_modified_threshold) or \
               (err_file and os.path.exists(err_file) and (current_time - os.path.getmtime(err_file)) < recently_modified_threshold):
                if state != 'done':
                    state = 'running'
            else:
                # Save the last modification time if not running
                if log_file and os.path.exists(log_file):
                    last_modified = time.strftime('%H:%M:%S %d/%m/%y', time.localtime(os.path.getmtime(log_file)))

            rows.append({
                'rel_path': os.path.relpath(project_dir, self.root_path),
                'state': state,
                'message': message,
                'log_file': os.path.relpath(log_file, project_dir) if log_file else '',
                'err_file': os.path.relpath(err_file, project_dir) if err_file else '',
                'last_modified': last_modified
            })

        df = pd.DataFrame(rows).set_index('rel_path').sort_index()
        # Assign an integer group id for identical messages, with messages sorted first
        unique_messages = sorted(df['message'].unique())
        if 'Done!' in unique_messages:
            # Ensure 'Done!' is always group 0, then assign the other messages
            unique_messages.remove('Done!')
            mapping = {'Done!': 0}
            mapping.update({msg: idx + 1 for idx, msg in enumerate(unique_messages)})
        else:
            mapping = {msg: idx for idx, msg in enumerate(unique_messages)}
        df['group'] = df['message'].map(mapping).astype(int)
        self._status = df
        return self._status
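    # Illustrative sketch (not part of the original module): typical reads of the
    # cached status DataFrame built above. The column names match the rows appended
    # in this property; the filter value is just one of the states it assigns.
    #
    #   dataset.status[dataset.status['state'] == 'error']  # projects whose last log line is not 'Done!'
    #   dataset.status['group'].value_counts()               # how many projects share each message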

    def show_groups(self, cmd=False):
        """Display the groups of projects based on their status messages."""
        if cmd:
            print("Project groups based on status messages:\n")
            status = self.status
            grouped = status.groupby('group')
            for group_id, group_df in grouped:
                print(f"Group {group_id}:")
                print(f"Message: {group_df['message'].iloc[0]}")
                print("Projects:")
                for rel_path in group_df.index:
                    print(f" - {rel_path}")
                print()
        else:
            grouped = self.status.groupby('group').agg({
                'message': 'first',
                'state': 'count'
            }).rename(columns={'state': 'count'})
            return grouped
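    # Illustrative sketch (not part of the original module): show_groups(cmd=True)
    # prints one block per group to stdout, while the default call returns the
    # per-group summary DataFrame (first message plus project count) built above.
    #
    #   dataset.show_groups(cmd=True)    # human-readable listing
    #   summary = dataset.show_groups()  # DataFrame with columns 'message' and 'count'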

    def status_with_links(self) -> pd.DataFrame:
        """Return the status DataFrame with clickable log file links."""
        self._status = None  # Force reload
        df = self.status.copy()

        # Create clickable links for log files
        def make_out_link(row):
            if row['log_file'] and row['state'] != 'not_run':
                project_dir = os.path.join(self.root_path, row.name)
                log_path = os.path.join(project_dir, row['log_file'])
                # Create a file:// URL for local files
                file_url = f"file://{log_path}"
                return f'<a href="{file_url}" target="_blank">{row["log_file"].split("/")[-1]}</a>'
            return row['log_file']

        # Create clickable links for error log files
        def make_error_link(row):
            if row['err_file'] and row['state'] != 'not_run':
                project_dir = os.path.join(self.root_path, row.name)
                error_log_path = os.path.join(project_dir, row['err_file'])
                # Create a file:// URL for local files
                file_url = f"file://{error_log_path}"
                return f'<a href="{file_url}" target="_blank">{row["err_file"].split("/")[-1]}</a>'
            return row['err_file']

        df['log_file_link'] = df.apply(make_out_link, axis=1)
        df['err_file_link'] = df.apply(make_error_link, axis=1)
        return df

    def display_status_with_links(self):
        """Display the status DataFrame with clickable links in Jupyter."""
        from IPython.display import HTML, display

        df = self.status_with_links()
        # Drop the original log_file columns and rename the link columns
        df_display = df.drop(['log_file', 'err_file'], axis=1)
        df_display = df_display.rename(columns={
            'log_file_link': 'log_file',
            'err_file_link': 'err_file'
        })

        # Convert to HTML and display
        html = df_display.to_html(escape=False)
        display(HTML(html))

    def launch_workflow(self,
                        include_groups: list[int] = [],
                        exclude_groups: list[int] = [],
                        n_jobs: int = 0,
                        slurm: bool = False,
                        job_template: str = None,
                        debug: bool = False):
        """Launch the workflow for each project directory in the dataset.

        Args:
            include_groups (list[int]):
                List of group IDs to be run.
            exclude_groups (list[int]):
                List of group IDs to be excluded.
            n_jobs (int):
                Number of jobs to launch. If 0, all jobs are launched.
            slurm (bool):
                Whether to submit the workflow to SLURM.
            job_template (str):
                Path to the bash script or SLURM job template file. You can use Jinja2
                templating to customize the job script using the fields of the input
                YAML and the columns of the project status DataFrame.
            debug (bool):
                Only print the commands without executing them.

        """

        # Include/exclude groups should not intersect
        if include_groups and exclude_groups:
            intersection = set(include_groups).intersection(set(exclude_groups))
            if intersection:
                raise ValueError(f"include_groups and exclude_groups intersect: {intersection}")

        # A job template is required both for SLURM and for local bash execution
        if not job_template:
            raise ValueError("job_template must be provided")

        n = 0
        for project_dir in self.project_directories:
            rel_path = os.path.relpath(project_dir, self.root_path)
            project_status = self.status.loc[rel_path].to_dict()
            project_status['rel_path'] = rel_path
            # Check group inclusion/exclusion
            if project_status['group'] in exclude_groups or \
               (include_groups and project_status['group'] not in include_groups):
                continue
            n += 1
            if n_jobs > 0 and n > n_jobs:
                break
            inputs_yaml_path = os.path.join(project_dir, 'inputs.yaml')
            if not os.path.exists(inputs_yaml_path):
                warn(f"{inputs_yaml_path} not found. Skipping {project_dir}")
                continue

            inputs_config = load_yaml(inputs_yaml_path)

            with open(job_template, 'r') as f:
                template_str = f.read()

            template = jinja2.Template(template_str)
            rendered_script = template.render(**inputs_config, **project_status,
                                              DIR=os.path.basename(os.path.normpath(project_dir)))

            job_script_path = os.path.join(project_dir, 'mwf_slurm_job.sh')
            log_dir = os.path.join(project_dir, 'logs')
            os.makedirs(log_dir, exist_ok=True)
            with open(job_script_path, 'w') as f:
                f.write(rendered_script)
            os.chmod(job_script_path, 0o755)
            # Launch workflow
            if slurm:
                # SLURM execution
                if debug:
                    print(f"cd {project_dir}")
                    print("sbatch --output=logs/mwf_%j.out --error=logs/mwf_%j.err mwf_slurm_job.sh")
                else:
                    print(f"Submitting SLURM job for {project_dir}")
                    subprocess.run(['sbatch',
                                    '--output=logs/mwf_%j.out',
                                    '--error=logs/mwf_%j.err',
                                    job_script_path],
                                   cwd=project_dir)

            else:
                # Local (non-SLURM) bash execution
                if debug:
                    print(f"cd {project_dir}")
                    print(f"bash {project_dir}/{os.path.basename(job_script_path)}")
                    continue
                print(f"Running job for {project_dir}")
                log_file = os.path.join(log_dir, f'mwf_{int(time.time())}.out')
                subprocess.run(
                    f"{job_script_path} 2>&1 | tee {log_file}",
                    cwd=project_dir,
                    shell=True
                )
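# Illustrative sketch (not part of the original module): a minimal SLURM job template
# and a launch call. The Jinja2 placeholders come from the project's inputs.yaml fields
# and the status columns passed to template.render() above (plus DIR); the template
# body, file names and the workflow command are hypothetical.
#
#   # mwf_job.sh.j2
#   #!/bin/bash
#   #SBATCH --job-name=mwf_{{ DIR }}
#   #SBATCH --partition=standard
#   mwf run  # assumes the workflow CLI is available on PATH
#
#   dataset = Dataset('dataset.yaml')
#   dataset.launch_workflow(include_groups=[1],      # only projects in status group 1
#                           n_jobs=5,                # launch at most 5 jobs
#                           slurm=True,
#                           job_template='mwf_job.sh.j2',
#                           debug=True)              # print the sbatch commands only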