Coverage for mddb_workflow / utils / auxiliar.py: 72%
343 statements
coverage.py v7.12.0, created at 2025-12-03 18:45 +0000

# Auxiliary generic functions and classes used along the workflow

from mddb_workflow import __path__, __version__
from mddb_workflow.utils.constants import RESIDUE_NAME_LETTERS, PROTEIN_RESIDUE_NAME_LETTERS
from mddb_workflow.utils.constants import YELLOW_HEADER, COLOR_END
from mddb_workflow.utils.constants import STANDARD_TOPOLOGY_FILENAME
from mddb_workflow.utils.type_hints import *

import os
from os.path import isfile, exists
import re
import sys
import json
import yaml
from glob import glob
from struct import pack
# NEVER FORGET: GraphQL has a problem with urllib.parse -> It will always return error 400 (Bad request)
# We must use requests instead
import requests
import urllib.request
import urllib.error
from subprocess import run, PIPE
from dataclasses import asdict, is_dataclass


def is_imported(module_name: str) -> bool:
    """Check if a module has been imported."""
    return module_name in sys.modules


class QuietException (Exception):
    """Exception whose traceback is not to be printed.
    These are used when the problem is not in our code.
    """
    pass


class InputError (QuietException):
    """Quiet exception for when user input is wrong."""
    pass


class TestFailure (QuietException):
    """Quiet exception for when MD data has not passed a quality check test."""
    pass


class EnvironmentError (QuietException):
    """Quiet exception for when the problem is not in the code but in the environment."""
    pass


class ToolError (QuietException):
    """Quiet exception for when the problem comes from a third party dependency."""
    pass


class RemoteServiceError (QuietException):
    """Quiet exception for when the problem comes from a remote service."""
    pass


class ForcedStop (QuietException):
    """Quiet exception for when we stop the workflow on purpose."""
    pass


class NoReferableException (Exception):
    """No referable exception for PDB synthetic constructs or chimeric entities."""
    def __str__(self): return f'No referable sequence {self.sequence}'
    def __repr__(self): return self.__str__()
    def get_sequence(self) -> str:
        return self.args[0]
    sequence = property(get_sequence, None, None, 'Amino acids sequence')


def custom_excepthook(exception_class, message, traceback):
    """Handle exceptions with a custom hook where our quiet exceptions do not print a traceback."""
    # Quiet behaviour if it is one of our quiet exceptions
    if issubclass(exception_class, QuietException):
        print('{0}: {1}'.format(exception_class.__name__, message))  # Only print Error Type and Message
        return
    # Default behaviour otherwise
    sys.__excepthook__(exception_class, message, traceback)


sys.excepthook = custom_excepthook

# Set special exceptions for when the topology is missing
MISSING_TOPOLOGY = Exception('Missing topology')
MISSING_CHARGES = Exception('Missing atom charges')
MISSING_BONDS = Exception('Missing atom bonds')
JSON_SERIALIZABLE_MISSING_BONDS = 'MB'

# Keep all exceptions in a set
STANDARD_EXCEPTIONS = {MISSING_TOPOLOGY, MISSING_CHARGES, MISSING_BONDS}

# Letters to pick new letters from
# Note that this must be an ordered sequence, not a set, for the alphabetical order to be reliable
letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'


def get_new_letter(current_letters: set) -> Optional[str]:
    """Get the next letter not in the input set, in alphabetical order.
    Return None if we run out of letters.
    """
    return next((letter for letter in letters if letter not in current_letters), None)


def all_cases(names: list[str] | set[str]) -> set[str]:
    """Given a list or set of names, return a set with all case possibilities:
    - All upper case
    - All lower case
    - First upper case and the rest lower case
    """
    all_names = []
    for name in names:
        all_upper = name.upper()
        all_lower = name.lower()
        one_upper = name[0].upper() + name[1:].lower()
        all_names += [all_upper, all_lower, one_upper]
    return set(all_names)
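
# Illustrative example (hypothetical residue name): the three case variants
# >>> sorted(all_cases(['Cl-']))
# ['CL-', 'Cl-', 'cl-']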


def residue_name_to_letter(residue_name: str) -> str:
    """Given a residue name, return its single letter."""
    return RESIDUE_NAME_LETTERS.get(residue_name, 'X')


def protein_residue_name_to_letter(residue_name: str) -> str:
    """Given a protein residue name, return its single letter."""
    return PROTEIN_RESIDUE_NAME_LETTERS.get(residue_name, 'X')


OBJECT_TYPES = {dict, list, tuple}


def recursive_transformer(target_object: dict | list | tuple, transformer: Optional[Callable] = None) -> dict | list | tuple:
    """Recursive transformer for nested dicts, lists and tuples.
    Note that a new object is created to prevent mutating the original.
    If no transformer function is passed then it becomes a recursive cloner.
    WARNING: the transformer function could mutate some types of the original object if not done properly.
    """
    object_type = type(target_object)
    # Get a starting object and an entries iterator depending on the object type
    if object_type is dict:
        clone = {}
        entries = target_object.items()
    elif object_type is list or object_type is tuple:
        # Note that if it is a tuple we make a list anyway and convert it to a tuple at the end
        # WARNING: Note that tuples do not support reassigning items
        clone = [None for i in range(len(target_object))]
        entries = enumerate(target_object)
    else: raise ValueError(f'The recursive transformer should only be applied to dicts, lists and tuples, not to {object_type}')
    # Iterate the different entries in the object
    for index_or_key, value in entries:
        # Get the value type
        value_type = type(value)
        # If it is a dict, list or tuple then call the transformer recursively
        if value_type in OBJECT_TYPES: clone[index_or_key] = recursive_transformer(value, transformer)
        # If it is not an object type then apply the transformer to it
        else: clone[index_or_key] = transformer(value) if transformer else value
    # If it was a tuple then make the conversion now
    if object_type == tuple: clone = tuple(clone)
    return clone
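
# A minimal usage sketch (hypothetical values): clone a nested structure while
# upper-casing every leaf string, leaving the original object untouched
# >>> recursive_transformer({'a': ['x', ('y',)]}, lambda value: value.upper())
# {'a': ['X', ('Y',)]}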

# Set a header for the serializer to mark serialized exceptions
EXCEPTION_HEADER = 'Exception: '


# LORE: This was originally intended to support exceptions in the cache
def json_serializer(object: dict | list | tuple) -> dict | list | tuple:
    """Serialize a standard JSON with support for additional types."""
    def serializer(value):
        # If we have exceptions then convert them to text with an appropriate header
        if type(value) is Exception:
            return f'{EXCEPTION_HEADER}{value}'
        # This must be done before the set check because asdict() will create dicts with sets inside
        if is_dataclass(value) and not isinstance(value, type):
            dict_value = asdict(value)
            return recursive_transformer(dict_value, serializer)
        if isinstance(value, set):
            return list(value)
        # If the type is not among the ones we check then assume it is already serializable
        return value
    object_clone = recursive_transformer(object, serializer)
    return object_clone


def json_deserializer(object: dict | list | tuple) -> dict | list | tuple:
    """Deserialize a standard JSON with support for additional types."""
    def deserializer(value):
        # Check if there is any value which was adapted to become JSON serializable and restore it
        if type(value) is str and value.startswith(EXCEPTION_HEADER):
            # WARNING: Do not declare new exceptions here but use the constant ones
            # WARNING: Otherwise further equality comparisons will fail
            exception_message = value[len(EXCEPTION_HEADER):]
            standard_exception = next((exception for exception in STANDARD_EXCEPTIONS if str(exception) == exception_message), None)
            if standard_exception is None:
                raise ValueError(f'Exception "{exception_message}" is not among standard exceptions')
            return standard_exception
        # If the type is not among the ones we check then assume it is already deserialized
        return value
    object_clone = recursive_transformer(object, deserializer)
    return object_clone
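
# A minimal round-trip sketch (hypothetical values): exceptions travel through
# JSON as headed strings and are restored as the very same constant objects
# >>> serialized = json_serializer({'bonds': MISSING_BONDS, 'ids': {1, 2}})
# >>> serialized['bonds']
# 'Exception: Missing atom bonds'
# >>> json_deserializer(serialized)['bonds'] is MISSING_BONDS
# True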


def load_json(filepath: str, replaces: Optional[list[tuple]] = None) -> dict:
    """Load a JSON with additional logic to better handle problems."""
    try:
        with open(filepath, 'r') as file:
            content = file.read()
            # Make pure text replacements
            for target, replacement in (replaces or []):
                content = content.replace(target, replacement)
            # Parse the content to JSON
            parsed_content = json.loads(content)
            # Deserialize some types like Exceptions, which are stored in JSONs in this context
            deserialized_content = json_deserializer(parsed_content)
            return deserialized_content
    except Exception as error:
        raise Exception(f'Something went wrong when loading JSON file {filepath}: {str(error)}')


def save_json(content, filepath: str, indent: Optional[int] = None):
    """Save a JSON with additional logic to better handle problems."""
    try:
        with open(filepath, 'w') as file:
            serialized_content = json_serializer(content)
            json.dump(serialized_content, file, indent=indent)
    except Exception as error:
        # Rename the JSON file since it will be half written, thus giving problems when loaded
        if exists(filepath): os.rename(filepath, filepath + '.wrong')
        raise Exception(f'Something went wrong when saving JSON file {filepath}: {str(error)}')


# DANI: For some reason, yaml.load also works with files in JSON format
def load_yaml(filepath: str, replaces: Optional[list[tuple]] = None) -> dict:
    """Load a YAML with additional logic to better handle problems.
    The replaces argument allows replacing file content before it is processed.
    Every replace is a tuple with two values: the target and the replacement.
    """
    try:
        with open(filepath, 'r') as file:
            content = file.read()
            # Pre-process fields that commonly contain colons (DOIs, URLs, etc.)
            # Match field: value where value isn't already quoted and may span multiple lines
            content = re.sub(
                r'^(citation:\s*)(?!")([^\n]+(?:\n(?!\w+:)[^\n]+)*)$',
                r'\1"\2"',
                content,
                flags=re.MULTILINE
            )
            for target, replacement in (replaces or []):
                content = content.replace(target, replacement)
            parsed_content = yaml.load(content, Loader=yaml.CLoader)
            return parsed_content
    except Exception as error:
        warn(str(error).replace('\n', ' '))
        raise InputError('Something went wrong when loading YAML file ' + filepath)


def save_yaml(content, filepath: str):
    """Save a YAML with additional logic to better handle problems."""
    with open(filepath, 'w') as file:
        yaml.dump(content, file)


# Set a few constants to erase previous logs in the terminal
CURSOR_UP_ONE = '\x1b[1A'
ERASE_LINE = '\x1b[2K'
ERASE_PREVIOUS_LINES = CURSOR_UP_ONE + ERASE_LINE + CURSOR_UP_ONE


def delete_previous_log():
    """Remove the previous line in the terminal."""
    print(ERASE_PREVIOUS_LINES)


def reprint(text: str):
    """Reprint text in the same line in the terminal."""
    delete_previous_log()
    print(text)


def warn(message: str):
    """Print a message with a colored warning header."""
    print(YELLOW_HEADER + '⚠ WARNING: ' + COLOR_END + message)


def mean(values: list[float]) -> float:
    """Get the mean/average of a list of values."""
    return sum(values) / len(values)


def round_to_hundredths(number: float) -> float:
    """Round a number to hundredths."""
    return round(number * 100) / 100


def round_to_thousandths(number: float) -> float:
    """Round a number to thousandths."""
    return round(number * 1000) / 1000


def ranger(numbers: list[int]) -> str:
    """Given a list of numbers, create a string where consecutive numbers are represented as ranges.

    Example:
        [1, 3, 5, 6, 7, 8] => "1,3,5-8"
    """
    # Remove duplicates and sort numbers
    sorted_numbers = sorted(set(numbers))
    # Get the number of numbers in the list
    number_count = len(sorted_numbers)
    # If there is only one number then finish here
    if number_count == 1:
        return str(sorted_numbers[0])
    # Start the parsing otherwise
    ranged = ''
    last_index = -1
    # Iterate numbers
    for i, number in enumerate(sorted_numbers):
        # Skip this number if it was already included in a previous series
        if i <= last_index: continue
        # Add the current number to the ranged string
        ranged += ',' + str(number)
        # Now iterate numbers after the current number
        next_index = i + 1
        for j, next_number in enumerate(sorted_numbers[next_index:], next_index):
            # Set the length of the series
            length = j - i
            # End of the series
            if next_number - number != length:
                # The length here is the length which broke the series
                # i.e. if the length here is 2 then the actual series length is 1
                series_length = length - 1
                if series_length > 1:
                    last_series_index = j - 1
                    previous_number = sorted_numbers[last_series_index]
                    ranged += '-' + str(previous_number)
                    last_index = last_series_index
                break
            # End of the selection
            if j == number_count - 1:
                if length > 1:
                    ranged += '-' + str(next_number)
                last_index = j
    # Remove the first comma before returning the ranged string
    return ranged[1:]
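
# Illustrative examples (hypothetical values):
# >>> ranger([1, 3, 5, 6, 7, 8])
# '1,3,5-8'
# >>> ranger([4, 2, 3, 2])
# '2-4'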


def otherwise(values: list) -> Generator[tuple, None, None]:
    """Iterate over a list in a special way:
    for each value, yield a tuple with that value and a new list with all the other values.
    """
    for v, value in enumerate(values):
        others = values[0:v] + values[v+1:]
        yield value, others
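
# Illustrative example (hypothetical values): each value paired with the rest
# >>> list(otherwise(['a', 'b', 'c']))
# [('a', ['b', 'c']), ('b', ['a', 'c']), ('c', ['a', 'b'])]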


def list_files(directory: str) -> list[str]:
    """List files in a directory."""
    return [f for f in os.listdir(directory) if isfile(f'{directory}/{f}')]


def is_directory_empty(directory: str) -> bool:
    """Check if a directory is empty."""
    return len(os.listdir(directory)) == 0


GLOB_CHARACTERS = ['*', '?', '[']


def is_glob(path: str) -> bool:
    """Check if a string has patterns to be parsed by a glob function.

    Note that this is not trivial, but this function should be good enough for our case.
    https://stackoverflow.com/questions/42283009/check-if-string-is-a-glob-pattern
    """
    # Find unescaped glob characters
    for c, character in enumerate(path):
        if character not in GLOB_CHARACTERS:
            continue
        if c == 0:
            return True
        previous_character = path[c-1]
        if previous_character != '\\':
            return True
    return False
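
# Illustrative examples (hypothetical paths): escaped glob characters do not count
# >>> is_glob('frames_*.pdb')
# True
# >>> is_glob('frames\\*.pdb')
# False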


def parse_glob(path: str) -> list[str]:
    """Parse a glob path into one or several results.

    If the path has no glob characters then return it as it is.
    Otherwise parse the glob pattern.
    """
    # If there is no glob pattern then just return the string as is
    if not is_glob(path):
        return [path]
    # If there is a glob pattern then parse it
    parsed_filepaths = glob(path)
    return parsed_filepaths


def is_url(path: str) -> bool:
    """Return whether the passed string is a URL or not."""
    return path.startswith('http')


def url_to_source_filename(url: str) -> str:
    """Set the filename of an input file downloaded from an input URL.

    In this scenario we are free to set our own paths or filenames.
    Note that the original name will usually be the very same output filename.
    In order to avoid both filenames being the same, we add a header here.
    """
    original_filename = url.split('/')[-1]
    return 'source_' + original_filename


def download_file(request_url: str, output_file: 'File'):
    """Download a file from a specific URL."""
    print(f'Downloading file "{output_file.path}" from {request_url}\n')
    try:
        urllib.request.urlretrieve(request_url, output_file.path)
    except urllib.error.HTTPError as error:
        if error.code == 404:
            raise Exception(f'Missing remote file "{output_file.filename}"')
        # If we don't know the error then simply say something went wrong
        raise Exception(f'Something went wrong when downloading file "{output_file.filename}" from {request_url}')


def is_standard_topology(file: 'File') -> bool:
    """Check if a file is a standard topology.
    Note that the filename may include the source header.
    """
    return file.filename.endswith(STANDARD_TOPOLOGY_FILENAME)


# Supported byte sizes and their struct format letters
SUPPORTED_BYTE_SIZES = {
    2: 'e',
    4: 'f',
    8: 'd'
}


def store_binary_data(data: list[float], byte_size: int, filepath: str):
    """Store binary data to a file.

    Args:
        data: A list of numeric values
        byte_size: The number of bytes to be occupied by each value in data
        filepath: The output file path
    """
    # Check the byte size makes sense
    letter = SUPPORTED_BYTE_SIZES.get(byte_size, None)
    if not letter:
        supported = ', '.join(str(size) for size in SUPPORTED_BYTE_SIZES)
        raise ValueError(f'Not supported byte size {byte_size}, please select one of these: {supported}')
    # Set the binary format
    # '<' stands for little endian
    byte_flag = f'<{letter}'
    # Start writing the output file
    with open(filepath, 'wb') as file:
        # Iterate over data list values
        for value in data:
            value = float(value)
            file.write(pack(byte_flag, value))
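
# A minimal sketch of the inverse read (not part of this module), assuming the
# file was written by store_binary_data with byte_size 4 (struct format '<f'):
# from struct import iter_unpack
# with open(filepath, 'rb') as file:
#     values = [value for (value,) in iter_unpack('<f', file.read())]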


class CaptureOutput (object):
    """Capture all stdout or stderr within a code region, even if it comes from another non-python threaded process.

    https://stackoverflow.com/questions/24277488/in-python-how-to-capture-the-stdout-from-a-c-shared-library-to-a-variable
    """
    escape_char = "\b"
    def __init__(self, stream: str = 'stdout'):
        # Get sys stdout or stderr
        if not hasattr(sys, stream):
            raise ValueError(f'Unknown stream "{stream}". Expected stream value is "stdout" or "stderr"')
        self.original_stream = getattr(sys, stream)
        self.original_streamfd = self.original_stream.fileno()
        self.captured_text = ""
        # Create a pipe so the stream can be captured:
        self.pipe_out, self.pipe_in = os.pipe()
    def __enter__(self):
        self.captured_text = ""
        # Save a copy of the stream:
        self.streamfd = os.dup(self.original_streamfd)
        # Replace the original stream with our write pipe:
        os.dup2(self.pipe_in, self.original_streamfd)
        return self
    def __exit__(self, type, value, traceback):
        # Print the escape character to make the readOutput method stop:
        self.original_stream.write(self.escape_char)
        # Flush the stream to make sure all our data goes in before
        # the escape character:
        self.original_stream.flush()
        self.readOutput()
        # Close the pipe:
        os.close(self.pipe_in)
        os.close(self.pipe_out)
        # Restore the original stream:
        os.dup2(self.streamfd, self.original_streamfd)
        # Close the duplicate stream:
        os.close(self.streamfd)
    def readOutput(self):
        # Read the pipe one byte at a time until the escape character is found
        # WARNING: decoding byte by byte means multi-byte characters may fail to decode
        while True:
            char = os.read(self.pipe_out, 1).decode(self.original_stream.encoding)
            if not char or self.escape_char in char:
                break
            self.captured_text += char


def request_pdb_data(pdb_id: str, query: str) -> dict:
    """Request data from the PDB GraphQL API.

    Note that this function may be used for either PDB ids or PDB molecule ids, depending on the query.
    The query parameter may be constructed using the following page:
    https://data.rcsb.org/graphql/index.html
    """
    # Make sure the PDB id is valid as we set the correct key to mine the response data
    if len(pdb_id) == 4: data_key = 'entry'
    elif len(pdb_id) < 4: data_key = 'chem_comp'
    else: raise ValueError(f'Wrong PDB id "{pdb_id}". It must be 4 (entries) or fewer (ligands) characters long')
    # Set the request URL
    request_url = 'https://data.rcsb.org/graphql'
    # Set the POST data
    post_data = {
        "query": query,
        "variables": {"id": pdb_id}
    }
    # Send the request
    try:
        response = requests.post(request_url, json=post_data)
    except requests.exceptions.ConnectionError:
        raise ConnectionError('No internet connection :(') from None
    # Get the response
    parsed_response = json.loads(response.text)['data'][data_key]
    if parsed_response is None:
        new_pdb_id = request_replaced_pdb(pdb_id)
        if new_pdb_id:
            parsed_response = request_pdb_data(new_pdb_id, query)
        else:
            print(f'PDB id {pdb_id} not found')
    return parsed_response
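
# A minimal usage sketch with a hypothetical query (check the GraphQL index
# page referenced above for the actual schema):
# query = 'query ($id: String!) { entry (entry_id: $id) { struct { title } } }'
# pdb_data = request_pdb_data('2VGB', query)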


def request_replaced_pdb(pdb_id):
    """Use the RCSB REST API to get the replaced PDB id.

    This is useful when the PDB is obsolete and has been replaced.
    """
    query_url = 'https://data.rcsb.org/rest/v1/holdings/removed/' + pdb_id
    response = requests.get(query_url, headers={'Content-Type': 'application/json'})
    # Check if the response is OK
    if response.status_code == 200:
        try:
            return response.json()['rcsb_repository_holdings_removed']['id_codes_replaced_by'][0]
        except (KeyError, IndexError):
            print(f'Error when mining replaced PDB id for {pdb_id}')
            return None
    else:
        return None


def numerate_filename(filename: str, number: int, zeros: int = 2, separator: str = '_') -> str:
    """Given a filename, set a suffix number on it, right before the extension.

    Args:
        filename: The original filename
        number: The number to add as a suffix
        zeros: The width to zero-fill the number to
        separator: The separator between filename and number
    """
    splits = filename.split('.')
    suffix = separator + str(number).zfill(zeros)
    return '.'.join(splits[0:-1]) + suffix + '.' + splits[-1]


def glob_filename(filename: str, separator: str = '_') -> str:
    """Given a filename, set a suffix including '*', right before the extension.
    This should match all filenames obtained through the numerate_filename function when used in bash.
    """
    splits = filename.split('.')
    suffix = separator + '*'
    return '.'.join(splits[0:-1]) + suffix + '.' + splits[-1]
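
# Illustrative examples (hypothetical filename):
# >>> numerate_filename('frames.dcd', 3)
# 'frames_03.dcd'
# >>> glob_filename('frames.dcd')
# 'frames_*.dcd'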


def purge_glob(filename: str):
    """Delete all files matched by the glob_filename function."""
    glob_pattern = glob_filename(filename)
    existing_outputs = glob(glob_pattern)
    for existing_output in existing_outputs:
        if exists(existing_output): os.remove(existing_output)


def get_analysis_name(filename: str) -> str:
    """Given a filename with the pattern 'mda.xxxx.json', get the 'xxxx' out of it."""
    name_search = re.search(r'/mda\.([A-Za-z0-9_-]*)\.json$', filename)
    if not name_search:
        raise ValueError(f'Wrong expected format in filename {filename}')
    # To make it coherent with the rest of analyses, the analysis name is parsed when loaded in the database:
    # every '_' is replaced by '-', so we must keep the analysis name coherent or the web client will not find it
    return name_search[1].replace('_', '-')
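
# Illustrative example (hypothetical path):
# >>> get_analysis_name('/path/to/mda.rmsd_pairwise.json')
# 'rmsd-pairwise'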


# DANI: Do not use getattr with a default argument or hasattr
# DANI: If you do, you will lose any further AttributeError(s)
# DANI: Thus you will have very silent errors every time you have a silly typo
# DANI: This is an unresolved Python issue https://bugs.python.org/issue39865
def safe_hasattr(instance, attribute_name: str) -> bool:
    """Use a safe alternative to hasattr."""
    return attribute_name in dir(instance)


def safe_getattr(instance, attribute_name: str, default):
    """Use a safe alternative to getattr."""
    if not safe_hasattr(instance, attribute_name): return default
    return getattr(instance, attribute_name)
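
# Illustrative pitfall (hypothetical class): a typo inside a property getter
# raises an AttributeError which the builtin hasattr silently swallows
# >>> class Sample:
# ...     @property
# ...     def label(self): return self.misspelled_field
# >>> hasattr(Sample(), 'label')
# False
# >>> safe_hasattr(Sample(), 'label')
# True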


def read_ndict(nested_dict: dict, nested_key: str, placeholder=KeyError('Missing nested key')):
    """Read a value in a nested dictionary using a single combined key.
    Return the placeholder if any key in the path does not exist.
    """
    keys = nested_key.split('.')
    value = nested_dict
    for key in keys:
        # Support list indices
        if key.isdigit():
            if type(value) is not list: return placeholder
            index = int(key)
            if index >= len(value): return placeholder
            value = value[index]
        # Support dict keys
        else:
            if type(value) is not dict: return placeholder
            value = value.get(key, placeholder)
            if value is placeholder: return placeholder
    return value
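
# Illustrative examples (hypothetical values): dots separate keys and digits
# are used as list indices
# >>> read_ndict({'metadata': {'chains': ['A', 'B']}}, 'metadata.chains.1')
# 'B'
# >>> read_ndict({'metadata': {}}, 'metadata.missing', placeholder=None) is None
# True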


def write_ndict(nested_dict: dict, nested_key: str, value):
    """Write a value in a nested dictionary using a single combined key.
    Raise an error if any key in the path is missing.
    """
    keys = nested_key.split('.')
    nested_keys = keys[0:-1]
    next_target = nested_dict
    for k, key in enumerate(nested_keys):
        # Support list indices
        if key.isdigit():
            if type(next_target) is not list:
                raise ValueError(f'{".".join(nested_keys[0:k])} should be a list, but it is {next_target}')
            index = int(key)
            next_target = next_target[index]
        # Support dict keys
        else:
            if type(next_target) is not dict:
                raise ValueError(f'{".".join(nested_keys[0:k])} should be a dict, but it is {next_target}')
            missing_key_error = KeyError(f'Missing nested key {key}')
            next_target = next_target.get(key, missing_key_error)
            if next_target is missing_key_error: raise missing_key_error
    field = keys[-1]
    next_target[field] = value
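
# Illustrative example (hypothetical values):
# >>> config = {'pbc': {'selection': None}}
# >>> write_ndict(config, 'pbc.selection', 'protein')
# >>> config
# {'pbc': {'selection': 'protein'}}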


def get_git_version() -> str:
    """Get the current git version."""
    git_command = f"git -C {__path__[0]} describe"
    process = run(git_command, shell=True, stdout=PIPE)
    return process.stdout.decode().replace('\n', '') or __version__