Coverage for mddb_workflow/core/dataset.py: 13%
146 statements
from mddb_workflow.utils.auxiliar import load_yaml, is_glob
from mddb_workflow.mwf import workflow
import subprocess
import os
import glob
import jinja2
import pandas as pd


class Dataset:
    """
    Class to manage and process a dataset of MDDB projects.
    """

    def __init__(self, dataset_yaml_path: str):
        """
        Initializes the Dataset object.

        Args:
            dataset_yaml_path (str): Path to the dataset YAML file.
        """
        self.root_path = os.path.dirname(os.path.abspath(dataset_yaml_path))
        self.config = load_yaml(dataset_yaml_path)
        # Properties cache
        self._project_directories = None
        self._status = None

    @property
    def project_directories(self) -> list[str]:
        """
        Retrieves the list of project directories from the dataset configuration.

        Returns:
            list[str]: Resolved project directory paths.
        """
        if self._project_directories is not None:
            return self._project_directories

        config_project_directories = self.config.get('global', {}).get('project_directories', [])
        self._project_directories = []
        for dir_pattern in config_project_directories:
            if is_glob(dir_pattern):
                matched_dirs = glob.glob(os.path.join(self.root_path, dir_pattern))
                # Keep only directories
                self._project_directories.extend([p for p in matched_dirs if os.path.isdir(p)])
            else:
                p = dir_pattern
                if not os.path.isabs(p):
                    p = os.path.abspath(os.path.join(self.root_path, p))
                # Ensure the resolved path is under the dataset root; commonpath
                # compares whole path components, unlike commonprefix, which matches
                # raw characters and would accept e.g. '/data/abc' under '/data/ab'
                if os.path.commonpath([p, self.root_path]) != self.root_path:
                    raise ValueError(f"Project directory '{p}' is outside the dataset root '{self.root_path}'")
                self._project_directories.append(p)

        return self._project_directories
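
    # The dataset YAML read above is expected to look roughly like this
    # (a minimal sketch; the directory names and glob are illustrative):
    #
    #   global:
    #     project_directories:
    #       - replica_*           # glob pattern, expanded under the dataset root
    #       - extra/project_01    # plain path, must resolve inside the dataset root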

    def generate_inputs_yaml(self, inputs_template_path: str, input_generator: callable, overwrite: bool = False):
        """
        Generates an inputs.yaml file in each project directory based on the dataset configuration.

        Args:
            inputs_template_path (str): The file path to the Jinja2 template file that will be
                used to generate the inputs YAML files.
            input_generator (callable): A callable that generates input values. It is called
                with the project directory name (DIR) as its argument and its result is
                rendered as the template's 'title' variable.
            overwrite (bool): Whether to overwrite existing inputs.yaml files. Default is False.
        """
        # Load the template
        with open(inputs_template_path, 'r') as f:
            template_str = f.read()

        template = jinja2.Template(template_str)

        for project_dir in self.project_directories:
            inputs_yaml_path = os.path.join(project_dir, 'inputs.yaml')
            if os.path.exists(inputs_yaml_path) and not overwrite:
                continue

            # Get the directory name
            DIR = os.path.basename(os.path.normpath(project_dir))

            # Render the template with project defaults
            rendered_yaml = template.render(DIR=DIR, title=input_generator(DIR))

            # Write the rendered YAML to inputs.yaml
            with open(inputs_yaml_path, 'w') as f:
                f.write(rendered_yaml)
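
    # Example use (a sketch; the template file name and body are hypothetical,
    # DIR and title being the only variables this method passes to the template):
    #
    #   inputs.yaml.j2:
    #       name: {{ DIR }}
    #       title: {{ title }}
    #
    #   ds.generate_inputs_yaml('inputs.yaml.j2', lambda d: f'Project {d}')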

    @property
    def status(self) -> pd.DataFrame:
        """
        Retrieves the last line of the most recent log from each project directory.

        Returns:
            pd.DataFrame: Indexed by project directory (relative to the dataset root);
                columns: state, message, log_file, err_file, group.
        """
        if self._status is not None:
            return self._status

        rows = []
        for project_dir in self.project_directories:
            # RUBEN: for now we use the same log pattern as in launch_workflow;
            # in the future the state will be recovered from .register/.mwf_cache
            log_files = glob.glob(os.path.join(project_dir, 'logs', 'mwf*[0-9].out'))
            err_files = glob.glob(os.path.join(project_dir, 'logs', 'mwf*[0-9].err'))

            if log_files:
                log_file = sorted(log_files)[-1]  # Take the most recent one
                with open(log_file, 'r') as f:
                    lines = f.read().splitlines()
                # Guard against an empty log file
                last_line = lines[-1].strip() if lines else ''
                if last_line == 'Done!':
                    state, message = 'done', last_line
                else:
                    state, message = 'error', last_line
            else:
                state, message, log_file = 'not_run', 'No output log available', None

            # Handle error log files
            err_file = sorted(err_files)[-1] if err_files else None  # Take the most recent one

            rows.append({
                'rel_path': os.path.relpath(project_dir, self.root_path),
                'state': state,
                'message': message,
                'log_file': os.path.relpath(log_file, project_dir) if log_file else '',
                'err_file': os.path.relpath(err_file, project_dir) if err_file else ''
            })

        df = pd.DataFrame(rows).set_index('rel_path').sort_index()
        # Assign an integer group id for identical messages, with messages sorted first
        unique_messages = sorted(df['message'].unique())
        if 'Done!' in unique_messages:
            # Ensure 'Done!' is always group 0, then assign the other messages
            unique_messages.remove('Done!')
            mapping = {'Done!': 0}
            mapping.update({msg: idx + 1 for idx, msg in enumerate(unique_messages)})
        else:
            mapping = {msg: idx for idx, msg in enumerate(unique_messages)}
        df['group'] = df['message'].map(mapping).astype(int)
        self._status = df
        return self._status
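
    # Illustrative status output (every value is hypothetical):
    #
    #   rel_path   state    message                  log_file    err_file    group
    #   proj_01    done     Done!                    mwf-7.out   mwf-7.err   0
    #   proj_02    error    Killed                   mwf-8.out   mwf-8.err   1
    #   proj_03    not_run  No output log available                          2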

    def show_groups(self, cmd=False):
        """
        Displays the groups of projects based on their status messages.

        Args:
            cmd (bool): If True, print each group and its projects to stdout and
                return None; otherwise return a DataFrame with one row per group,
                holding the group's message and its project count.
        """
        if cmd:
            status = self.status
            grouped = status.groupby('group')
            for group_id, group_df in grouped:
                print(f"Group {group_id}:")
                print(f"Message: {group_df['message'].iloc[0]}")
                print("Projects:")
                for rel_path in group_df.index:
                    print(f" - {rel_path}")
                print()
        else:
            grouped = self.status.groupby('group').agg({
                'message': 'first',
                'state': 'count'
            }).rename(columns={'state': 'count'})
            return grouped
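
    # For the hypothetical status above, show_groups() would return:
    #
    #   group  message                  count
    #   0      Done!                    1
    #   1      Killed                   1
    #   2      No output log available  1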

    def status_with_links(self) -> pd.DataFrame:
        """
        Returns the status DataFrame with clickable log file links.
        """
        df = self.status.copy()

        # Build a clickable link for a local log column; the output and error
        # log columns share this logic
        def make_link(row, column):
            if row[column] and row['state'] != 'not_run':
                project_dir = os.path.join(self.root_path, row.name)
                log_path = os.path.join(project_dir, row[column])
                # Create a file:// URL for local files
                file_url = f"file://{log_path}"
                return f'<a href="{file_url}" target="_blank">{row[column]}</a>'
            return row[column]

        df['log_file_link'] = df.apply(make_link, axis=1, column='log_file')
        df['err_file_link'] = df.apply(make_link, axis=1, column='err_file')
        return df

    def display_status_with_links(self):
        """
        Display the status DataFrame with clickable links in Jupyter.
        """
        from IPython.display import HTML, display

        df = self.status_with_links()
        # Drop the original log_file columns and rename the link columns
        df_display = df.drop(['log_file', 'err_file'], axis=1)
        df_display = df_display.rename(columns={
            'log_file_link': 'log_file',
            'err_file_link': 'err_file'
        })

        # Convert to HTML and display
        html = df_display.to_html(escape=False)
        display(HTML(html))

    def launch_workflow(self,
                        include_groups: list[int] = [],
                        exclude_groups: list[int] = [],
                        slurm: bool = False,
                        job_template: str = None):
        """
        Launches the workflow for each project directory in the dataset.

        Args:
            include_groups (list[int]):
                Group IDs to run; empty means all groups.
            exclude_groups (list[int]):
                Group IDs to skip.
            slurm (bool):
                Whether to submit the workflow to SLURM.
            job_template (str):
                Path to the SLURM job template file. You can use Jinja2
                templating to customize the job script using the fields of
                the project's inputs.yaml and the columns of the project
                status DataFrame.
        """
        # Include/exclude groups should not intersect
        if include_groups and exclude_groups:
            intersection = set(include_groups).intersection(set(exclude_groups))
            if intersection:
                raise ValueError(f"include_groups and exclude_groups intersect: {intersection}")

        if slurm and not job_template:
            raise ValueError("job_template must be provided when slurm is True")
        for project_dir in self.project_directories:
            project_status = self.status.loc[os.path.relpath(project_dir, self.root_path)].to_dict()
            # Check group inclusion/exclusion
            group_id = project_status['group']
            if group_id in exclude_groups or (include_groups and group_id not in include_groups):
                continue
            # Launch workflow
            if slurm:
                # SLURM execution
                inputs_yaml_path = os.path.join(project_dir, 'inputs.yaml')
                if not os.path.exists(inputs_yaml_path):
                    print(f"Warning: {inputs_yaml_path} not found. Skipping {project_dir}")
                    continue

                inputs_config = load_yaml(inputs_yaml_path)

                with open(job_template, 'r') as f:
                    template_str = f.read()

                template = jinja2.Template(template_str)
                rendered_script = template.render(**inputs_config, **project_status)

                job_script_path = os.path.join(project_dir, 'mwf_slurm_job.sh')
                log_dir = os.path.join(project_dir, 'logs')
                os.makedirs(log_dir, exist_ok=True)
                with open(job_script_path, 'w') as f:
                    f.write(rendered_script)

                os.chmod(job_script_path, 0o755)

                print(f"Submitting SLURM job for {project_dir}")
                # sbatch options must come before the script name; anything after
                # the script is passed to the script itself, not to sbatch
                subprocess.run(['sbatch',
                                '--output', os.path.join(log_dir, 'mwf-%j.out'),
                                '--error', os.path.join(log_dir, 'mwf-%j.err'),
                                'mwf_slurm_job.sh'],
                               cwd=project_dir)
            else:
                # Normal Python execution
                raise NotImplementedError("Python execution is not implemented yet.")
                # workflow(working_directory=project_dir)  # unreachable until implemented
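

# Example SLURM job template for launch_workflow (a minimal sketch, not a file
# shipped with the package; {{ name }} assumes a 'name' field in the project's
# inputs.yaml, {{ group }} and {{ message }} are columns of the status
# DataFrame, and the final command line is illustrative only):
#
#   #!/bin/bash
#   #SBATCH --job-name=mwf-{{ name }}
#   #SBATCH --cpus-per-task=4
#   echo "Group {{ group }}: {{ message }}"
#   mwf run inputs.yaml   # hypothetical CLI invocation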