Coverage for mddb_workflow/utils/auxiliar.py: 71%
315 statements

# Auxiliary generic functions and classes used along the workflow

from mddb_workflow import __path__
from mddb_workflow.utils.constants import RESIDUE_NAME_LETTERS, PROTEIN_RESIDUE_NAME_LETTERS
from mddb_workflow.utils.constants import YELLOW_HEADER, COLOR_END

import os
from os.path import isfile, exists
import re
import sys
import json
import yaml
from glob import glob
from typing import Optional, Generator, Callable
from struct import pack
# NEVER FORGET: GraphQL has a problem with urllib.parse -> It will always return error 400 (Bad request)
# We must use requests instead
import requests
from subprocess import run, PIPE

# Check if a module has been imported
def is_imported (module_name : str) -> bool:
    return module_name in sys.modules

# Set custom exceptions which are not to print a traceback
# They are used when the problem is not in our code
class QuietException (Exception):
    pass

# Set a custom quiet exception for when user input is wrong
class InputError (QuietException):
    pass

# Set a custom quiet exception for when MD data has not passed a quality check test
class TestFailure (QuietException):
    pass

# Set a custom quiet exception for when the problem comes from a third party dependency
class ToolError (QuietException):
    pass

# Set a custom quiet exception for when the problem comes from a remote service
class RemoteServiceError (QuietException):
    pass

# Set a non-referable exception for PDB synthetic constructs or chimeric entities
class NoReferableException (Exception):
    def __str__ (self): return f'No referable sequence {self.sequence}'
    def __repr__ (self): return self.__str__()
    def get_sequence (self) -> str:
        return self.args[0]
    sequence = property(get_sequence, None, None, 'Amino acid sequence')

# Set a custom exception handler where our quiet exceptions have a quiet behaviour
def custom_excepthook (exception_class, message, traceback):
    # Quiet behaviour if it is one of our quiet exceptions
    # Note that issubclass covers any level of inheritance, not only direct subclasses
    if issubclass(exception_class, QuietException):
        print('{0}: {1}'.format(exception_class.__name__, message)) # Only print Error Type and Message
        return
    # Default behaviour otherwise
    sys.__excepthook__(exception_class, message, traceback)
sys.excepthook = custom_excepthook

# Set special exceptions for when the topology is missing
MISSING_TOPOLOGY = Exception('Missing topology')
MISSING_CHARGES = Exception('Missing atom charges')
MISSING_BONDS = Exception('Missing atom bonds')
JSON_SERIALIZABLE_MISSING_BONDS = 'MB'

# Keep all standard exceptions in a set
STANDARD_EXCEPTIONS = { MISSING_TOPOLOGY, MISSING_CHARGES, MISSING_BONDS }

# Set a function to get the next letter from an input letter in alphabetic order
# Return None if we run out of letters
# Note that letters must be kept in an ordered sequence, not a set, since set iteration order is arbitrary
letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
def get_new_letter(current_letters : set) -> Optional[str]:
    return next((letter for letter in letters if letter not in current_letters), None)

# Given a list or set of names, return a set with all case possibilities:
# All upper case
# All lower case
# First upper case and the rest lower case
def all_cases (names : list[str] | set[str]) -> set[str]:
    all_names = []
    for name in names:
        all_upper = name.upper()
        all_lower = name.lower()
        one_upper = name[0].upper() + name[1:].lower()
        all_names += [ all_upper, all_lower, one_upper ]
    return set(all_names)
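# e.g. all_cases(['Na+']) -> {'NA+', 'na+', 'Na+'}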

# Given a residue name, return its single letter
def residue_name_to_letter (residue_name : str) -> str:
    return RESIDUE_NAME_LETTERS.get(residue_name, 'X')

# Given a protein residue name, return its single letter
def protein_residue_name_to_letter (residue_name : str) -> str:
    return PROTEIN_RESIDUE_NAME_LETTERS.get(residue_name, 'X')

# Set a recursive transformer for nested dicts and lists
# Note that a new object is created to prevent mutating the original
# If no transformer function is passed then it becomes a recursive cloner
# WARNING: the transformer function could mutate some types of the original object if not done properly
OBJECT_TYPES = { dict, list, tuple }
def recursive_transformer (target_object : dict | list | tuple, transformer : Optional[Callable] = None) -> dict | list | tuple:
    object_type = type(target_object)
    # Get a starting object and entries iterator depending on the object type
    if object_type == dict:
        clone = {}
        entries = target_object.items()
    elif object_type == list or object_type == tuple:
        # Note that if it is a tuple we make a list anyway and we convert it to tuple at the end
        # WARNING: Note that tuples do not support reassigning items
        clone = [ None for i in range(len(target_object)) ]
        entries = enumerate(target_object)
    else: raise ValueError(f'The recursive transformer should only be applied to dicts, lists and tuples, not to {object_type}')
    # Iterate the different entries in the object
    for index_or_key, value in entries:
        # Get the value type
        value_type = type(value)
        # If it is a dict, list or tuple then call the transformer recursively
        # Note that the transformer function must be passed down or nested values would not be transformed
        if value_type in OBJECT_TYPES: clone[index_or_key] = recursive_transformer(value, transformer)
        # If it is not an object type then apply the transformer to it
        else: clone[index_or_key] = transformer(value) if transformer else value
    # If it was a tuple then make the conversion now
    if object_type == tuple: clone = tuple(clone)
    return clone
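# e.g. recursive_transformer({'a': [1, (2, 3)]}, transformer=str) -> {'a': ['1', ('2', '3')]}
# The transformer is applied to leaf values only, at any nesting depth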

# Set some headers for the serializer
EXCEPTION_HEADER = 'Exception: '
# Set a standard JSON serializer/unserializer to support additional types
# LORE: This was originally intended to support exceptions in the cache
def json_serializer (object : dict | list | tuple) -> dict | list | tuple:
    def serializer (value):
        # If we have exceptions then convert them to text with an appropriate header
        if type(value) == Exception:
            return f'{EXCEPTION_HEADER}{value}'
        # If the type is not among the ones we check then assume it is already serializable
        return value
    object_clone = recursive_transformer(object, serializer)
    return object_clone
def json_deserializer (object : dict | list | tuple) -> dict | list | tuple:
    def deserializer (value):
        # Check if there is any value which was adapted to become JSON serialized and restore it
        if type(value) == str and value.startswith(EXCEPTION_HEADER):
            # WARNING: Do not declare new exceptions here but use the constant ones
            # WARNING: Otherwise further equality comparisons will fail
            exception_message = value[len(EXCEPTION_HEADER):]
            standard_exception = next((exception for exception in STANDARD_EXCEPTIONS if str(exception) == exception_message), None)
            if standard_exception is None:
                raise ValueError(f'Exception "{exception_message}" is not among standard exceptions')
            return standard_exception
        # If the type is not among the ones we check then assume it is already deserialized
        return value
    object_clone = recursive_transformer(object, deserializer)
    return object_clone
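# e.g. json_serializer({'bonds': MISSING_BONDS}) -> {'bonds': 'Exception: Missing atom bonds'}
# Deserializing restores the very same constant, so identity comparisons still work:
# json_deserializer({'bonds': 'Exception: Missing atom bonds'})['bonds'] is MISSING_BONDS -> True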

# Set a JSON loader with additional logic to better handle problems
def load_json (filepath : str, replaces : Optional[list[tuple]] = []) -> dict:
    try:
        with open(filepath, 'r') as file:
            content = file.read()
            # Make pure text replacements
            for replace in replaces:
                target, replacement = replace
                content = content.replace(target, replacement)
            # Parse the content to JSON
            parsed_content = json.loads(content)
            # Deserialize some types like Exceptions, which are stored in JSONs in this context
            deserialized_content = json_deserializer(parsed_content)
            return deserialized_content
    except Exception as error:
        raise Exception(f'Something went wrong when loading JSON file {filepath}: {str(error)}')

# Set a JSON saver with additional logic to better handle problems
def save_json (content, filepath : str, indent : Optional[int] = None):
    try:
        with open(filepath, 'w') as file:
            serialized_content = json_serializer(content)
            json.dump(serialized_content, file, indent=indent)
    except Exception as error:
        # Rename the JSON file since it will be half written thus giving problems when loaded
        # Note that the file may not exist at all if the open itself failed
        if exists(filepath): os.rename(filepath, filepath + '.wrong')
        raise Exception(f'Something went wrong when saving JSON file {filepath}: {str(error)}')

# Set a YAML loader with additional logic to better handle problems
# The argument replaces allows to replace file content before it is processed
# Every replace is a tuple with two values: the target and the replacement
# DANI: For some reason yaml.load also works with files in JSON format
def load_yaml (filepath : str, replaces : Optional[list[tuple]] = []) -> dict:
    try:
        with open(filepath, 'r') as file:
            content = file.read()
            for replace in replaces:
                target, replacement = replace
                content = content.replace(target, replacement)
            parsed_content = yaml.load(content, Loader=yaml.CLoader)
            return parsed_content
    except Exception as error:
        warn(str(error).replace('\n', ' '))
        raise InputError('Something went wrong when loading YAML file ' + filepath)

# Set a YAML saver with additional logic to better handle problems
def save_yaml (content, filepath : str):
    with open(filepath, 'w') as file:
        yaml.dump(content, file)

# Set a few constants to erase previous logs in the terminal
CURSOR_UP_ONE = '\x1b[1A'
ERASE_LINE = '\x1b[2K'
ERASE_PREVIOUS_LINES = CURSOR_UP_ONE + ERASE_LINE + CURSOR_UP_ONE

# Set a function to remove the previous line
def delete_previous_log ():
    print(ERASE_PREVIOUS_LINES)

# Set a function to reprint in the same line
def reprint (text : str):
    delete_previous_log()
    print(text)

# Set a function to print a message with a colored warning header
def warn (message : str):
    print(YELLOW_HEADER + '⚠ WARNING: ' + COLOR_END + message)

# Get the mean/average of a list of values
def mean (values : list[float]) -> float:
    return sum(values) / len(values)

# Round a number to hundredths
def round_to_hundredths (number : float) -> float:
    return round(number * 100) / 100

# Round a number to thousandths
def round_to_thousandths (number : float) -> float:
    return round(number * 1000) / 1000

# Given a list of numbers, create a string where consecutive numbers are collapsed into ranges
# e.g. [1, 3, 5, 6, 7, 8] => "1,3,5-8"
def ranger (numbers : list[int]) -> str:
    # Remove duplicates and sort numbers
    sorted_numbers = sorted(list(set(numbers)))
    # Get the number of numbers in the list
    number_count = len(sorted_numbers)
    # If there is only one number then finish here
    if number_count == 1:
        return str(sorted_numbers[0])
    # Start the parsing otherwise
    ranged = ''
    last_number = -1
    # Iterate numbers
    for i, number in enumerate(sorted_numbers):
        # Skip this number if it was already included in a previous series
        if i <= last_number: continue
        # Add current number to the ranged string
        ranged += ',' + str(number)
        # Now iterate numbers after the current number
        next_index = i+1
        for j, next_number in enumerate(sorted_numbers[next_index:], next_index):
            # Set the length of the series
            length = j - i
            # End of the series
            if next_number - number != length:
                # The length here is the length which broke the series
                # i.e. if the length here is 2 the actual series length is 1
                series_length = length - 1
                if series_length > 1:
                    last_series_number = j - 1
                    previous_number = sorted_numbers[last_series_number]
                    ranged += '-' + str(previous_number)
                    last_number = last_series_number
                break
            # End of the selection
            if j == number_count - 1:
                if length > 1:
                    ranged += '-' + str(next_number)
                    last_number = j
    # Remove the first comma before returning the ranged string
    return ranged[1:]

# Set a special iteration system
# Return one value of the array and a new array with all other values for each value
def otherwise (values : list) -> Generator[tuple, None, None]:
    for v, value in enumerate(values):
        others = values[0:v] + values[v+1:]
        yield value, others
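# e.g. list(otherwise([1, 2, 3])) -> [(1, [2, 3]), (2, [1, 3]), (3, [1, 2])]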

# List files in a directory
def list_files (directory : str) -> list[str]:
    return [f for f in os.listdir(directory) if isfile(f'{directory}/{f}')]

# Check if a directory is empty
def is_directory_empty (directory : str) -> bool:
    return len(os.listdir(directory)) == 0

# Set a function to check if a string has patterns to be parsed by a glob function
# Note that this is not trivial, but this function should be good enough for our case
# https://stackoverflow.com/questions/42283009/check-if-string-is-a-glob-pattern
GLOB_CHARACTERS = ['*', '?', '[']
def is_glob (path : str) -> bool:
    # Find unescaped glob characters
    for c, character in enumerate(path):
        if character not in GLOB_CHARACTERS:
            continue
        if c == 0:
            return True
        previous_character = path[c-1]
        if previous_character != '\\':
            return True
    return False

# Parse a glob path into one or several results
# If the path has no glob characters then return it as it is
# Otherwise return every path matched by the glob pattern
def parse_glob (path : str) -> list[str]:
    # If there is no glob pattern then just return the string as is
    if not is_glob(path):
        return [ path ]
    # If there is glob pattern then parse it
    parsed_filepaths = glob(path)
    return parsed_filepaths

# Supported byte sizes and their struct format letters
SUPPORTED_BYTE_SIZES = {
    2: 'e',
    4: 'f',
    8: 'd'
}

# Data is a list of numeric values
# Byte size is the number of bytes to be occupied by each value in data
def store_binary_data (data : list[float], byte_size : int, filepath : str):
    # Check the byte size makes sense
    letter = SUPPORTED_BYTE_SIZES.get(byte_size, None)
    if not letter:
        supported = ", ".join(str(size) for size in SUPPORTED_BYTE_SIZES.keys())
        raise ValueError(f'Not supported byte size {byte_size}, please select one of these: {supported}')
    # Set the binary format
    # '<' stands for little endian
    byte_flag = f'<{letter}'
    # Start writing the output file
    with open(filepath, 'wb') as file:
        # Iterate over data list values
        for value in data:
            value = float(value)
            file.write(pack(byte_flag, value))
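# e.g. store_binary_data([1.0, 2.5], byte_size=4, filepath='values.bin') writes 8 bytes
# which may be read back with struct.iter_unpack ('values.bin' is just an illustrative path):
# [ v for (v,) in iter_unpack('<f', open('values.bin', 'rb').read()) ] -> [1.0, 2.5]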

# Capture all stdout or stderr within a code region even if it comes from another non-python threaded process
# https://stackoverflow.com/questions/24277488/in-python-how-to-capture-the-stdout-from-a-c-shared-library-to-a-variable
class CaptureOutput (object):
    escape_char = "\b"
    def __init__(self, stream : str = 'stdout'):
        # Get sys stdout or stderr
        if not hasattr(sys, stream):
            raise ValueError(f'Unknown stream "{stream}". Expected stream value is "stdout" or "stderr"')
        self.original_stream = getattr(sys, stream)
        self.original_streamfd = self.original_stream.fileno()
        self.captured_text = ""
        # Create a pipe so the stream can be captured:
        self.pipe_out, self.pipe_in = os.pipe()
    def __enter__(self):
        self.captured_text = ""
        # Save a copy of the stream:
        self.streamfd = os.dup(self.original_streamfd)
        # Replace the original stream with our write pipe:
        os.dup2(self.pipe_in, self.original_streamfd)
        return self
    def __exit__(self, type, value, traceback):
        # Print the escape character to make the readOutput method stop:
        self.original_stream.write(self.escape_char)
        # Flush the stream to make sure all our data goes in before
        # the escape character:
        self.original_stream.flush()
        self.readOutput()
        # Close the pipe:
        os.close(self.pipe_in)
        os.close(self.pipe_out)
        # Restore the original stream:
        os.dup2(self.streamfd, self.original_streamfd)
        # Close the duplicate stream:
        os.close(self.streamfd)
    def readOutput(self):
        while True:
            char = os.read(self.pipe_out, 1).decode(self.original_stream.encoding)
            if not char or self.escape_char in char:
                break
            self.captured_text += char
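# e.g. capture everything printed to stdout within the with block:
# with CaptureOutput('stdout') as capture:
#     print('hello')
# capture.captured_text -> 'hello\n'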

# Set a function to request data to the PDB GraphQL API
# Note that this function may be used for either PDB ids or PDB molecule ids, depending on the query
# The query parameter may be constructed using the following page:
# https://data.rcsb.org/graphql/index.html
def request_pdb_data (pdb_id : str, query : str) -> dict:
    # Make sure the PDB id is valid as we set the correct key to mine the response data
    if len(pdb_id) == 4: data_key = 'entry'
    elif len(pdb_id) < 4: data_key = 'chem_comp'
    else: raise ValueError(f'Wrong PDB id "{pdb_id}". It must be 4 (entries) or less (ligands) characters long')
    # Set the request URL
    request_url = 'https://data.rcsb.org/graphql'
    # Set the POST data
    post_data = {
        "query": query,
        "variables": { "id": pdb_id }
    }
    # Send the request
    try:
        response = requests.post(request_url, json=post_data)
    except requests.exceptions.ConnectionError:
        raise ConnectionError('No internet connection :(') from None
    # Get the response
    parsed_response = json.loads(response.text)['data'][data_key]
    if parsed_response is None:
        new_pdb_id = request_replaced_pdb(pdb_id)
        if new_pdb_id:
            parsed_response = request_pdb_data(new_pdb_id, query)
        else:
            print(f'PDB id {pdb_id} not found')
    return parsed_response
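# e.g. a minimal query to get an entry title, where the top level field ('entry') must match
# the data key above and the GraphQL variable must be named 'id' as set in the POST data:
# query = 'query ($id: String!) { entry (entry_id: $id) { struct { title } } }'
# request_pdb_data('1UBQ', query) -> {'struct': {'title': ...}}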

# Use the RCSB REST API to get the replaced PDB id
# This is useful when the PDB entry is obsolete and has been replaced
def request_replaced_pdb (pdb_id):
    query_url = 'https://data.rcsb.org/rest/v1/holdings/removed/' + pdb_id
    response = requests.get(query_url, headers={'Content-Type': 'application/json'})
    # Check if the response is OK
    if response.status_code == 200:
        try:
            return response.json()['rcsb_repository_holdings_removed']['id_codes_replaced_by'][0]
        # Avoid a bare except so we do not hide unrelated errors
        except (KeyError, IndexError):
            print(f'Error when mining replaced PDB id for {pdb_id}')
            return None
    else:
        return None

# Given a filename, set a suffix number on it, right before the extension
# Set also the number of zeros to fill the name
def numerate_filename (filename : str, number : int, zeros : int = 2, separator : str = '_') -> str:
    splits = filename.split('.')
    suffix = separator + str(number).zfill(zeros)
    return '.'.join(splits[0:-1]) + suffix + '.' + splits[-1]
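# e.g. numerate_filename('frames/output.trr', 3) -> 'frames/output_03.trr'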

# Given a filename, set a suffix including '*', right before the extension
# This should match all filenames obtained through the numerate_filename function when used in bash
def glob_filename (filename : str, separator : str = '_') -> str:
    splits = filename.split('.')
    suffix = separator + '*'
    return '.'.join(splits[0:-1]) + suffix + '.' + splits[-1]

# Delete all files matched by the glob_filename function
def purge_glob (filename : str):
    glob_pattern = glob_filename(filename)
    existing_outputs = glob(glob_pattern)
    for existing_output in existing_outputs:
        if exists(existing_output): os.remove(existing_output)

# Given a filename with the pattern 'mda.xxxx.json', get the 'xxxx' out of it
def get_analysis_name (filename : str) -> str:
    # Note that dots must be escaped or they would match any character
    name_search = re.search(r'/mda\.([A-Za-z0-9_-]*)\.json$', filename)
    if not name_search:
        raise ValueError(f'Wrong expected format in filename {filename}')
    # To make it coherent with the rest of analyses, the analysis name is parsed when loaded in the database
    # Every '_' is replaced by '-' so we must keep the analysis name coherent or the web client will not find it
    return name_search[1].replace('_', '-')
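# e.g. get_analysis_name('/path/to/mda.rmsd_pairwise.json') -> 'rmsd-pairwise'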

# Use a safe alternative to hasattr/getattr
# DANI: Do not use getattr with a default argument or hasattr
# DANI: If you do, you will lose any further AttributeError(s)
# DANI: Thus you will have very silent errors every time you have a silly typo
# DANI: This is an unresolved error in Python itself https://bugs.python.org/issue39865
def safe_hasattr (instance, attribute_name : str) -> bool:
    return attribute_name in dir(instance)
def safe_getattr (instance, attribute_name : str, default):
    if not safe_hasattr(instance, attribute_name): return default
    return getattr(instance, attribute_name)
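# e.g. given a property whose getter contains a typo raising an inner AttributeError,
# getattr(instance, 'prop', None) would silently return None and hide the typo,
# while safe_getattr(instance, 'prop', None) lets the inner AttributeError raise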

# Functions to read and write a nested dict value using a single combined key

# Read a value in a nested dictionary and return the placeholder if any key in the path does not exist
def read_ndict (nested_dict : dict, nested_key : str, placeholder = KeyError('Missing nested key')):
    keys = nested_key.split('.')
    value = nested_dict
    for key in keys:
        # Support list indices
        if key.isdigit():
            if type(value) != list: return placeholder
            index = int(key)
            # Out of range indices also return the placeholder
            if index >= len(value): return placeholder
            value = value[index]
        # Support dict keys
        else:
            if type(value) != dict: return placeholder
            value = value.get(key, placeholder)
            if value is placeholder: return placeholder
    return value
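# e.g. read_ndict({'metadata': {'boxes': [10, 20]}}, 'metadata.boxes.1') -> 20
# while read_ndict({'metadata': {}}, 'metadata.boxes', placeholder=None) -> None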

# Write a value in a nested dictionary and raise an error if any key in the path is missing
def write_ndict (nested_dict : dict, nested_key : str, value):
    keys = nested_key.split('.')
    nested_keys = keys[0:-1]
    next_target = nested_dict
    for k, key in enumerate(nested_keys):
        # Support list indices
        if key.isdigit():
            if type(next_target) != list:
                raise ValueError(f'{".".join(nested_keys[0:k])} should be a list, but it is {next_target}')
            index = int(key)
            next_target = next_target[index]
        # Support dict keys
        else:
            if type(next_target) != dict:
                raise ValueError(f'{".".join(nested_keys[0:k])} should be a dict, but it is {next_target}')
            missing_key_error = KeyError(f'Missing nested key {key}')
            next_target = next_target.get(key, missing_key_error)
            if next_target is missing_key_error: raise missing_key_error
    # Set the final field, supporting list indices here as well
    field = keys[-1]
    if field.isdigit() and type(next_target) == list: next_target[int(field)] = value
    else: next_target[field] = value
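# e.g. given config = {'mds': [{'name': 'old'}]},
# write_ndict(config, 'mds.0.name', 'new') leaves config as {'mds': [{'name': 'new'}]}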

# Get the current git version
def get_git_version () -> str:
    git_command = f"git -C {__path__[0]} describe"
    process = run(git_command, shell=True, stdout=PIPE)
    return process.stdout.decode().replace('\n', '')