Coverage for mddb_workflow/core/dataset.py: 11% (188 statements)
from mddb_workflow.utils.auxiliar import load_yaml, is_glob, warn
from mddb_workflow.utils.type_hints import *
import pandas as pd
import subprocess
import jinja2
import time
import glob
import os


class Dataset:
    """Class to manage and process a dataset of MDDB projects."""

    def __init__(self, dataset_yaml_path: str):
        """Initialize the Dataset object.

        Args:
            dataset_yaml_path (str): Path to the dataset YAML file.
        """
        self.root_path = os.path.dirname(os.path.abspath(dataset_yaml_path))
        self.config = load_yaml(dataset_yaml_path)
        # Properties cache
        self.project_directories = self.get_project_directories()
        self.groups = self.get_groups()
        self._status = None
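
    # Illustrative sketch of the dataset YAML consumed above (hypothetical
    # layout; directory names are made up). 'projects', 'ignore' and 'groups'
    # are the keys read by the methods below; values are globs or paths
    # relative to the YAML's directory:
    #
    #   projects:
    #     - "replica_*"
    #   ignore:
    #     - "replica_broken"
    #   groups:
    #     membrane:
    #       - "replica_1"
    #       - "replica_2"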

    def _resolve_directory_patterns(self, dir_patterns: list[str]) -> list[str]:
        """Resolve directory patterns (glob or absolute/relative paths).

        Validates that resolved paths are under the dataset root.

        Args:
            dir_patterns (list[str]): List of directory patterns to resolve.

        Returns:
            list[str]: List of resolved directory paths.

        Raises:
            ValueError: If a resolved path is outside the dataset root.
        """
        directories = []
        for dir_pattern in dir_patterns:
            if is_glob(dir_pattern):
                matched_dirs = glob.glob(os.path.join(self.root_path, dir_pattern))
                # Keep only directories
                directories.extend([p for p in matched_dirs if os.path.isdir(p)])
            else:
                p = dir_pattern
                if not os.path.isabs(p):
                    p = os.path.abspath(os.path.join(self.root_path, p))
                # Ensure the resolved path is under the dataset root.
                # os.path.commonpath compares whole path components, unlike
                # os.path.commonprefix, which is character-based and would
                # wrongly accept e.g. '/data/ds2' under root '/data/ds'.
                if os.path.commonpath([p, self.root_path]) != self.root_path:
                    raise ValueError(f"Project directory '{p}' is outside the dataset root '{self.root_path}'")
                directories.append(p)
        return directories
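
    # Illustrative resolution under a hypothetical dataset rooted at /data/ds:
    #   "replica_*"    -> ["/data/ds/replica_1", "/data/ds/replica_2"]  (glob)
    #   "extra/run1"   -> ["/data/ds/extra/run1"]                       (relative path)
    #   "/tmp/outside" -> ValueError (resolved path outside the dataset root)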

    def get_project_directories(self) -> list[str]:
        """Retrieve the list of project directories from the dataset configuration."""
        dirs = self._resolve_directory_patterns(self.config.get('projects', []))
        ignore_dirs = self._resolve_directory_patterns(self.config.get('ignore', []))
        dirs = [d for d in dirs if d not in ignore_dirs]
        if not dirs:
            raise ValueError("No project directories found in the dataset configuration.")
        return dirs

    def get_groups(self) -> dict[str, list[str]]:
        """Retrieve the groups of project directories from the dataset configuration."""
        groups = {}
        for group, patterns in self.config.get('groups', {}).items():
            groups[group] = self._resolve_directory_patterns(patterns)
        return groups

    def generate_inputs_yaml(self,
                             inputs_template_path: str,
                             input_generator: Callable,
                             overwrite: bool = False
                             ):
        """Generate an inputs.yaml file in each project directory based on the dataset configuration.

        Args:
            inputs_template_path (str): Path to the Jinja2 template file used to
                generate the inputs YAML files.
            input_generator (Callable): A callable that generates input values.
                Currently it is called with the project directory name (DIR) as
                its only argument.
            overwrite (bool): Whether to overwrite existing inputs.yaml files.
                Default is False.
        """
        # Load the template
        with open(inputs_template_path, 'r') as f:
            template_str = f.read()
        template = jinja2.Template(template_str)
        skipped = 0
        generated = 0
        for project_dir in self.project_directories:
            inputs_yaml_path = os.path.join(project_dir, 'inputs.yaml')
            if os.path.exists(inputs_yaml_path) and not overwrite:
                skipped += 1
                continue
            generated += 1
            # Get the directory name
            DIR = os.path.basename(os.path.normpath(project_dir))
            # Render the template with project defaults
            rendered_yaml = template.render(DIR=DIR, title=input_generator(DIR))
            # Write the rendered YAML to inputs.yaml
            with open(inputs_yaml_path, 'w') as f:
                f.write(rendered_yaml)
        print(f"Generated {generated} inputs.yaml files. Skipped {skipped} existing files.")

    @property
    def status(self) -> pd.DataFrame:
        """Retrieve the last log line from every project directory as a pandas DataFrame.

        Returns:
            pd.DataFrame: Indexed by project directory; columns: state, message,
                log_file, err_file, last_modified, group.
        """
        if self._status is not None:
            return self._status

        rows = []
        for project_dir in self.project_directories:
            # RUBEN: for now we use the same log pattern as in launch_workflow;
            # in the future the state will be recovered from .register/.mwf_cache
            log_files = glob.glob(os.path.join(project_dir, 'logs', 'mwf*[0-9].out'))
            err_files = glob.glob(os.path.join(project_dir, 'logs', 'mwf*[0-9].err'))

            if log_files:
                # Take the most recently modified log file
                log_file = max(log_files, key=os.path.getmtime)
                with open(log_file, 'r') as f:
                    lines = f.read().splitlines()
                if lines:
                    last_line = lines[-1].strip()
                    if len(last_line) > 80:
                        last_line = last_line[:80] + '...'
                else:
                    last_line = ''
                if last_line == 'Done!':
                    state, message = 'done', last_line
                else:
                    state, message = 'error', last_line if last_line else 'Empty log file'
            else:
                state, message, log_file = 'not_run', 'No output log available', None

            # Handle error log files: take the most recently modified one
            if err_files:
                err_file = max(err_files, key=os.path.getmtime)
            else:
                err_file = None

            # Check whether files were modified recently (within the last
            # 5 minutes) to detect a running state
            current_time = time.time()
            recently_modified_threshold = 300  # 5 minutes in seconds

            last_modified = ''
            if (log_file and os.path.exists(log_file) and (current_time - os.path.getmtime(log_file)) < recently_modified_threshold) or \
               (err_file and os.path.exists(err_file) and (current_time - os.path.getmtime(err_file)) < recently_modified_threshold):
                if state != 'done':
                    state = 'running'
            else:
                # Save the last modification time if not running
                if log_file and os.path.exists(log_file):
                    last_modified = time.strftime('%H:%M:%S %d/%m/%y', time.localtime(os.path.getmtime(log_file)))

            rows.append({
                'rel_path': os.path.relpath(project_dir, self.root_path),
                'state': state,
                'message': message,
                'log_file': os.path.relpath(log_file, project_dir) if log_file else '',
                'err_file': os.path.relpath(err_file, project_dir) if err_file else '',
                'last_modified': last_modified
            })

        df = pd.DataFrame(rows).set_index('rel_path').sort_index()
        # Assign an integer group id to identical messages, with messages sorted first
        unique_messages = sorted(df['message'].unique())
        if 'Done!' in unique_messages:
            # Ensure 'Done!' is always group 0, then assign the other messages
            unique_messages.remove('Done!')
            mapping = {'Done!': 0}
            mapping.update({msg: idx + 1 for idx, msg in enumerate(unique_messages)})
        else:
            mapping = {msg: idx for idx, msg in enumerate(unique_messages)}
        df['group'] = df['message'].map(mapping).astype(int)
        self._status = df
        return self._status
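
    # Illustrative `status` frame for two hypothetical projects (all values
    # made up):
    #
    #              state  message         log_file    err_file    last_modified      group
    #   rel_path
    #   replica_1  done   Done!           mwf_42.out  mwf_42.err  10:02:13 03/12/25  0
    #   replica_2  error  Empty log file  mwf_43.out  mwf_43.err  09:51:07 03/12/25  1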

    def show_groups(self, cmd=False):
        """Display the groups of projects based on their status messages.

        If cmd is False, return a DataFrame with one row per group (message and
        project count) instead of printing.
        """
        if cmd:
            print("Project groups based on status messages:\n")
            status = self.status
            grouped = status.groupby('group')
            for group_id, group_df in grouped:
                print(f"Group {group_id}:")
                print(f"Message: {group_df['message'].iloc[0]}")
                print("Projects:")
                for rel_path in group_df.index:
                    print(f"  - {rel_path}")
                print()
        else:
            grouped = self.status.groupby('group').agg({
                'message': 'first',
                'state': 'count'
            }).rename(columns={'state': 'count'})
            return grouped

    def status_with_links(self) -> pd.DataFrame:
        """Return the status DataFrame with clickable log file links."""
        self._status = None  # Force reload
        df = self.status.copy()

        # Create clickable links for log files
        def make_out_link(row):
            if row['log_file'] and row['state'] != 'not_run':
                project_dir = os.path.join(self.root_path, row.name)
                log_path = os.path.join(project_dir, row['log_file'])
                # Create a file:// URL for local files
                file_url = f"file://{log_path}"
                return f'<a href="{file_url}" target="_blank">{os.path.basename(row["log_file"])}</a>'
            return row['log_file']

        # Create clickable links for error log files
        def make_error_link(row):
            if row['err_file'] and row['state'] != 'not_run':
                project_dir = os.path.join(self.root_path, row.name)
                error_log_path = os.path.join(project_dir, row['err_file'])
                # Create a file:// URL for local files
                file_url = f"file://{error_log_path}"
                return f'<a href="{file_url}" target="_blank">{os.path.basename(row["err_file"])}</a>'
            return row['err_file']

        df['log_file_link'] = df.apply(make_out_link, axis=1)
        df['err_file_link'] = df.apply(make_error_link, axis=1)
        return df

    def display_status_with_links(self):
        """Display the status DataFrame with clickable links in Jupyter."""
        from IPython.display import HTML, display

        df = self.status_with_links()
        # Drop the original log file columns and rename the link columns
        df_display = df.drop(['log_file', 'err_file'], axis=1)
        df_display = df_display.rename(columns={
            'log_file_link': 'log_file',
            'err_file_link': 'err_file'
        })

        # Convert to HTML and display
        html = df_display.to_html(escape=False)
        display(HTML(html))

    def launch_workflow(self,
                        include_groups: list[int] = [],
                        exclude_groups: list[int] = [],
                        n_jobs: int = 0,
                        slurm: bool = False,
                        job_template: str = None,
                        debug: bool = False):
        """Launch the workflow for each project directory in the dataset.

        Args:
            include_groups (list[int]):
                List of group IDs to be run.
            exclude_groups (list[int]):
                List of group IDs to be excluded.
            n_jobs (int):
                Number of jobs to launch. If 0, all jobs are launched.
            slurm (bool):
                Whether to submit the workflow to SLURM.
            job_template (str):
                Path to the bash script or SLURM job template file (required).
                You can use Jinja2 templating to customize the job script with
                the fields of the input YAML and the columns of the project
                status DataFrame.
            debug (bool):
                Only print the commands without executing them.
        """
        # Include/exclude groups must not intersect
        if include_groups and exclude_groups:
            intersection = set(include_groups).intersection(set(exclude_groups))
            if intersection:
                raise ValueError(f"include_groups and exclude_groups intersect: {intersection}")
        # The job template is rendered for both SLURM and local execution
        if not job_template:
            raise ValueError("job_template must be provided")
        # Load the job template once; it is rendered per project inside the loop
        with open(job_template, 'r') as f:
            template_str = f.read()
        template = jinja2.Template(template_str)
        n = 0
        for project_dir in self.project_directories:
            rel_path = os.path.relpath(project_dir, self.root_path)
            project_status = self.status.loc[rel_path].to_dict()
            project_status['rel_path'] = rel_path
            # Check group inclusion/exclusion
            if project_status['group'] in exclude_groups or \
               (include_groups and project_status['group'] not in include_groups):
                continue
            inputs_yaml_path = os.path.join(project_dir, 'inputs.yaml')
            if not os.path.exists(inputs_yaml_path):
                warn(f"{inputs_yaml_path} not found. Skipping {project_dir}")
                continue
            # Count only projects that will actually be launched
            n += 1
            if n_jobs > 0 and n > n_jobs:
                break

            inputs_config = load_yaml(inputs_yaml_path)
            rendered_script = template.render(**inputs_config, **project_status,
                                              DIR=os.path.basename(os.path.normpath(project_dir)))

            job_script_path = os.path.join(project_dir, 'mwf_slurm_job.sh')
            log_dir = os.path.join(project_dir, 'logs')
            os.makedirs(log_dir, exist_ok=True)
            with open(job_script_path, 'w') as f:
                f.write(rendered_script)
            os.chmod(job_script_path, 0o755)
            # Launch workflow
            if slurm:
                # SLURM execution
                if debug:
                    print(f"cd {project_dir}")
                    print("sbatch --output=logs/mwf_%j.out --error=logs/mwf_%j.err mwf_slurm_job.sh")
                else:
                    print(f"Submitting SLURM job for {project_dir}")
                    subprocess.run(['sbatch',
                                    '--output=logs/mwf_%j.out',
                                    '--error=logs/mwf_%j.err',
                                    job_script_path],
                                   cwd=project_dir)
            else:
                # Normal local execution of the rendered bash script
                if debug:
                    print(f"cd {project_dir}")
                    print(f"bash {project_dir}/{os.path.basename(job_script_path)}")
                    continue
                print(f"Running job for {project_dir}")
                log_file = os.path.join(log_dir, f'mwf_{int(time.time())}.out')
                subprocess.run(
                    f"{job_script_path} 2>&1 | tee {log_file}",
                    cwd=project_dir,
                    shell=True
                )
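

# Minimal usage sketch (illustrative only): the paths, the title generator and
# the job template referenced below are hypothetical, not part of this module.
# The job template may use any field of the project's inputs.yaml, any column
# of its status row (e.g. rel_path) and DIR; the actual workflow command is
# left as a placeholder:
#
#   #!/bin/bash
#   #SBATCH --job-name=mwf_{{ DIR }}
#   # ... workflow command for {{ rel_path }} goes here ...
if __name__ == "__main__":
    dataset = Dataset("dataset.yaml")
    # Generate inputs.yaml files, deriving a human-readable title from each
    # project directory name
    dataset.generate_inputs_yaml(
        inputs_template_path="inputs.yaml.j2",
        input_generator=lambda d: d.replace("_", " ").title(),
    )
    # Inspect the grouped status, then dry-run the first two eligible jobs
    print(dataset.show_groups())
    dataset.launch_workflow(n_jobs=2, job_template="mwf_job.sh.j2", debug=True)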