# Auxiliary generic functions and classes used along the workflow

from model_workflow.utils.constants import RESIDUE_NAME_LETTERS, PROTEIN_RESIDUE_NAME_LETTERS
from model_workflow.utils.constants import YELLOW_HEADER, COLOR_END

import os
from os import rename, listdir, remove
from os.path import isfile, exists
import re
import sys
import json
import yaml
from glob import glob
from typing import Optional, List, Set, Generator, Union
from struct import pack
# NEVER FORGET: GraphQL has a problem with urllib.parse -> It will always return error 400 (Bad request)
# We must use requests instead
import requests

# Check if a module has been imported
def is_imported (module_name : str) -> bool:
    return module_name in sys.modules

# Set custom exceptions which do not print a traceback
# They are used when the problem is not in our code
class QuietException (Exception):
    pass

# Set a custom quiet exception for when user input is wrong
class InputError (QuietException):
    pass

# Set a custom quiet exception for when MD data has not passed a quality check test
class TestFailure (QuietException):
    pass

# Set a custom quiet exception for when the problem comes from a third party dependency
class ToolError (QuietException):
    pass

# Set a custom quiet exception for when the problem comes from a remote service
class RemoteServiceError (QuietException):
    pass

# Set a non-referable exception for PDB synthetic constructs or chimeric entities
class NoReferableException (Exception):
    def __str__ (self): return f'No referable sequence {self.sequence}'
    def __repr__ (self): return self.__str__()
    def get_sequence (self) -> str:
        return self.args[0]
    sequence = property(get_sequence, None, None, 'Amino acid sequence')

# Set a custom exception handler where our quiet exceptions have a quiet behaviour
def custom_excepthook (exception_class, message, traceback):
    # Quiet behaviour if it is one of our quiet exceptions
    # Note that issubclass covers any inheritance depth, while checking __bases__ would miss deeper subclasses
    if issubclass(exception_class, QuietException):
        print('{0}: {1}'.format(exception_class.__name__, message)) # Only print the error type and message
        return
    # Default behaviour otherwise
    sys.__excepthook__(exception_class, message, traceback)
sys.excepthook = custom_excepthook

# Set special exceptions for when the topology is missing
MISSING_TOPOLOGY = Exception('Missing topology')
MISSING_CHARGES = Exception('Missing atom charges')
MISSING_BONDS = Exception('Missing atom bonds')
JSON_SERIALIZABLE_MISSING_BONDS = 'MB'

# Set a function to get the next available letter, in alphabetic order, given the letters already in use
# Return None if we run out of letters
# Note that letters are kept in an ordered sequence (not a set) so iteration follows alphabetic order
letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
def get_new_letter (current_letters : set) -> Optional[str]:
    return next((letter for letter in letters if letter not in current_letters), None)
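
# Usage sketch (illustrative values):
#   get_new_letter({'A', 'B', 'C'})   # -> 'D'
#   get_new_letter(set(letters))      # -> None, since every letter is already in use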

# Given a list or set of names, return a set with all case possibilities:
# All upper case
# All lower case
# First letter upper case and the rest lower case
def all_cases (names : Union[List[str], Set[str]]) -> Set[str]:
    all_names = []
    for name in names:
        all_upper = name.upper()
        all_lower = name.lower()
        one_upper = name[0].upper() + name[1:].lower()
        all_names += [ all_upper, all_lower, one_upper ]
    return set(all_names)
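
# Usage sketch (illustrative input):
#   all_cases(['Zn'])   # -> {'ZN', 'zn', 'Zn'}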

# Given a residue name, return its single letter
def residue_name_to_letter (residue_name : str) -> str:
    return RESIDUE_NAME_LETTERS.get(residue_name, 'X')

# Given a protein residue name, return its single letter
def protein_residue_name_to_letter (residue_name : str) -> str:
    return PROTEIN_RESIDUE_NAME_LETTERS.get(residue_name, 'X')

# Set a JSON loader with additional logic to better handle problems
def load_json (filepath : str) -> dict:
    try:
        with open(filepath, 'r') as file:
            content = json.load(file)
        return content
    except Exception as error:
        raise Exception(f'Something went wrong when loading JSON file {filepath}: {str(error)}')

# Set a JSON saver with additional logic to better handle problems
def save_json (content, filepath : str, indent : Optional[int] = None):
    try:
        with open(filepath, 'w') as file:
            json.dump(content, file, indent=indent)
    except Exception as error:
        # Rename the JSON file since it may be half written, thus giving problems when loaded
        rename(filepath, filepath + '.wrong')
        raise Exception(f'Something went wrong when saving JSON file {filepath}: {str(error)}')

# Set a YAML loader with additional logic to better handle problems
# DANI: For some reason, yaml.load also works with JSON-formatted files
def load_yaml (filepath : str):
    try:
        with open(filepath, 'r') as file:
            content = yaml.load(file, Loader=yaml.CLoader)
        return content
    except Exception as error:
        warn(str(error).replace('\n', ' '))
        raise InputError('Something went wrong when loading YAML file ' + filepath)

# Set a YAML saver with additional logic to better handle problems
def save_yaml (content, filepath : str):
    with open(filepath, 'w') as file:
        yaml.dump(content, file)

# Set a few constants to erase previous logs in the terminal
CURSOR_UP_ONE = '\x1b[1A'
ERASE_LINE = '\x1b[2K'
ERASE_PREVIOUS_LINES = CURSOR_UP_ONE + ERASE_LINE + CURSOR_UP_ONE

# Set a function to remove the previous line
def delete_previous_log ():
    print(ERASE_PREVIOUS_LINES)

# Set a function to reprint on the same line
def reprint (text : str):
    delete_previous_log()
    print(text)

# Set a function to print a message with a colored warning header
def warn (message : str):
    print(YELLOW_HEADER + '⚠ WARNING: ' + COLOR_END + message)

# Get the mean/average of a list of values
def mean (values : List[float]) -> float:
    return sum(values) / len(values)

# Round a number to hundredths
def round_to_hundredths (number : float) -> float:
    return round(number * 100) / 100

# Round a number to thousandths
def round_to_thousandths (number : float) -> float:
    return round(number * 1000) / 1000

# Given a list of numbers, create a string where consecutive numbers are represented as ranges
# e.g. [1, 3, 5, 6, 7, 8] => "1,3,5-8"
def ranger (numbers : List[int]) -> str:
    # Remove duplicates and sort numbers
    sorted_numbers = sorted(set(numbers))
    # Get the number of numbers in the list
    number_count = len(sorted_numbers)
    # If there is only one number then finish here
    if number_count == 1:
        return str(sorted_numbers[0])
    # Start the parsing otherwise
    ranged = ''
    last_number = -1
    # Iterate numbers
    for i, number in enumerate(sorted_numbers):
        # Skip this number if it was already included in a previous series
        if i <= last_number: continue
        # Add current number to the ranged string
        ranged += ',' + str(number)
        # Now iterate numbers after the current number
        next_index = i + 1
        for j, next_number in enumerate(sorted_numbers[next_index:], next_index):
            # Set the length of the series
            length = j - i
            # End of the series
            if next_number - number != length:
                # The length here is the length which broke the series
                # i.e. if the length here is 2 then the actual series length is 1
                series_length = length - 1
                if series_length > 1:
                    last_series_number = j - 1
                    previous_number = sorted_numbers[last_series_number]
                    ranged += '-' + str(previous_number)
                    last_number = last_series_number
                break
            # End of the selection
            if j == number_count - 1:
                if length > 1:
                    ranged += '-' + str(next_number)
                    last_number = j
    # Remove the first comma before returning the ranged string
    return ranged[1:]
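
# Usage sketch, matching the example above:
#   ranger([1, 3, 5, 6, 7, 8])   # -> '1,3,5-8'
#   ranger([2, 2, 1])            # -> '1,2' (duplicates removed; pairs are not ranged)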

# Set a special iteration system
# For each value in the array, yield the value and a new array with all the other values
def otherwise (values : list) -> Generator[tuple, None, None]:
    for v, value in enumerate(values):
        others = values[0:v] + values[v+1:]
        yield value, others
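
# Usage sketch (illustrative input):
#   for value, others in otherwise([1, 2, 3]):
#       print(value, others)   # -> 1 [2, 3], then 2 [1, 3], then 3 [1, 2]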

# List files in a directory
def list_files (directory : str) -> List[str]:
    return [f for f in listdir(directory) if isfile(f'{directory}/{f}')]

# Check if a directory is empty
def is_directory_empty (directory : str) -> bool:
    return len(listdir(directory)) == 0

# Set a function to check if a string has patterns to be parsed by a glob function
# Note that this is not trivial, but this function should be good enough for our case
# https://stackoverflow.com/questions/42283009/check-if-string-is-a-glob-pattern
GLOB_CHARACTERS = ['*', '?', '[']
def is_glob (path : str) -> bool:
    # Find unescaped glob characters
    for c, character in enumerate(path):
        if character not in GLOB_CHARACTERS:
            continue
        if c == 0:
            return True
        previous_character = path[c-1]
        if previous_character != '\\':
            return True
    return False

# Parse a glob path into one or several results
# If the path has no glob characters then return it as it is
# Otherwise return every path matching the glob pattern
def parse_glob (path : str) -> List[str]:
    # If there is no glob pattern then just return the string as is
    if not is_glob(path):
        return [ path ]
    # If there is a glob pattern then parse it
    parsed_filepaths = glob(path)
    return parsed_filepaths
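
# Usage sketch (hypothetical paths, for illustration only):
#   parse_glob('topology.pdb')      # -> ['topology.pdb'], returned as is
#   parse_glob('replica_*/md.xtc')  # -> every matching path, e.g. ['replica_1/md.xtc', 'replica_2/md.xtc']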

# Supported byte sizes and their corresponding struct format letters
SUPPORTED_BYTE_SIZES = {
    2: 'e',
    4: 'f',
    8: 'd'
}

# Data is a list of numeric values
# Byte size is the number of bytes each value in data is to occupy
def store_binary_data (data : List[float], byte_size : int, filepath : str):
    # Check the byte size makes sense
    letter = SUPPORTED_BYTE_SIZES.get(byte_size, None)
    if not letter:
        # Note that keys must be converted to strings before joining them
        supported_sizes = ", ".join(str(size) for size in SUPPORTED_BYTE_SIZES.keys())
        raise ValueError(f'Not supported byte size {byte_size}, please select one of these: {supported_sizes}')
    # Set the binary format
    # '<' stands for little endian
    byte_flag = f'<{letter}'
    # Start writing the output file
    with open(filepath, 'wb') as file:
        # Iterate over data list values
        for value in data:
            value = float(value)
            file.write(pack(byte_flag, value))
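
# Usage sketch (hypothetical file, for illustration only)
# Values may be read back with struct.unpack, the counterpart of the struct.pack call above
# (this would need 'from struct import unpack'):
#   store_binary_data([0.5, 1.5, 2.5], 4, 'values.bin')
#   with open('values.bin', 'rb') as file:
#       values = [ unpack('<f', file.read(4))[0] for _ in range(3) ]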

# Capture all stdout or stderr within a code region, even if it comes from another non-python threaded process
# https://stackoverflow.com/questions/24277488/in-python-how-to-capture-the-stdout-from-a-c-shared-library-to-a-variable
class CaptureOutput (object):
    escape_char = "\b"
    def __init__ (self, stream : str = 'stdout'):
        # Get sys stdout or stderr
        if not hasattr(sys, stream):
            raise ValueError(f'Unknown stream "{stream}". Expected stream value is "stdout" or "stderr"')
        self.original_stream = getattr(sys, stream)
        self.original_streamfd = self.original_stream.fileno()
        self.captured_text = ""
        # Create a pipe so the stream can be captured:
        self.pipe_out, self.pipe_in = os.pipe()
    def __enter__ (self):
        self.captured_text = ""
        # Save a copy of the stream:
        self.streamfd = os.dup(self.original_streamfd)
        # Replace the original stream with our write pipe:
        os.dup2(self.pipe_in, self.original_streamfd)
        return self
    def __exit__ (self, type, value, traceback):
        # Print the escape character to make the readOutput method stop:
        self.original_stream.write(self.escape_char)
        # Flush the stream to make sure all our data goes in before
        # the escape character:
        self.original_stream.flush()
        self.readOutput()
        # Close the pipe:
        os.close(self.pipe_in)
        os.close(self.pipe_out)
        # Restore the original stream:
        os.dup2(self.streamfd, self.original_streamfd)
        # Close the duplicate stream:
        os.close(self.streamfd)
    def readOutput (self):
        while True:
            char = os.read(self.pipe_out, 1).decode(self.original_stream.encoding)
            if not char or self.escape_char in char:
                break
            self.captured_text += char
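
# Usage sketch: the captured text is available once the 'with' block exits
#   with CaptureOutput('stdout') as capture:
#       print('something noisy')
#   capture.captured_text   # -> 'something noisy\n'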

# Set a function to request data from the PDB GraphQL API
# Note that this function may be used for either PDB ids or PDB molecule ids, depending on the query
# The query parameter may be constructed using the following page:
# https://data.rcsb.org/graphql/index.html
def request_pdb_data (pdb_id : str, query : str) -> dict:
    # Make sure the PDB id is valid as we set the correct key to mine the response data
    if len(pdb_id) == 4: data_key = 'entry'
    elif len(pdb_id) < 4: data_key = 'chem_comp'
    else: raise ValueError(f'Wrong PDB id "{pdb_id}". It must be 4 (entries) or fewer (ligands) characters long')
    # Set the request URL
    request_url = 'https://data.rcsb.org/graphql'
    # Set the POST data
    post_data = {
        "query": query,
        "variables": { "id": pdb_id }
    }
    # Send the request
    try:
        response = requests.post(request_url, json=post_data)
    except requests.exceptions.ConnectionError:
        raise ConnectionError('No internet connection :(') from None
    # Mine the response data
    parsed_response = json.loads(response.text)['data'][data_key]
    # If there is no data then the PDB id may be obsolete, so we try to find its replacement
    if parsed_response is None:
        new_pdb_id = request_replaced_pdb(pdb_id)
        if new_pdb_id:
            parsed_response = request_pdb_data(new_pdb_id, query)
        else:
            print(f'PDB id {pdb_id} not found')
    return parsed_response
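
# Usage sketch (a minimal query; the exact field names are an assumption, check the GraphQL page above):
#   query = 'query ($id: String!) { entry (entry_id: $id) { struct { title } } }'
#   request_pdb_data('2cpk', query)   # -> e.g. {'struct': {'title': ...}}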

# Use the RCSB REST API to get the replaced PDB id
# This is useful when the PDB entry is obsolete and has been replaced
def request_replaced_pdb (pdb_id):
    query_url = 'https://data.rcsb.org/rest/v1/holdings/removed/' + pdb_id
    response = requests.get(query_url, headers={'Content-Type': 'application/json'})
    # Check if the response is OK
    if response.status_code == 200:
        try:
            return response.json()['rcsb_repository_holdings_removed']['id_codes_replaced_by'][0]
        except (KeyError, IndexError):
            print(f'Error when mining replaced PDB id for {pdb_id}')
            return None
    else:
        return None

# Given a filename, set a suffix number on it, right before the extension
# Set also the number of zeros to fill the name
def numerate_filename (filename : str, number : int, zeros : int = 2, separator : str = '_') -> str:
    splits = filename.split('.')
    suffix = separator + str(number).zfill(zeros)
    return '.'.join(splits[0:-1]) + suffix + '.' + splits[-1]

# Given a filename, set a suffix including '*', right before the extension
# This should match all filenames obtained through the numerate_filename function when used in bash
def glob_filename (filename : str, separator : str = '_') -> str:
    splits = filename.split('.')
    suffix = separator + '*'
    return '.'.join(splits[0:-1]) + suffix + '.' + splits[-1]
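
# Usage sketch (hypothetical filename):
#   numerate_filename('frames.dcd', 3)   # -> 'frames_03.dcd'
#   glob_filename('frames.dcd')          # -> 'frames_*.dcd'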

# Delete all files matched by the glob_filename function
def purge_glob (filename : str):
    glob_pattern = glob_filename(filename)
    existing_outputs = glob(glob_pattern)
    for existing_output in existing_outputs:
        if exists(existing_output): remove(existing_output)

# Given a filename with the pattern 'mda.xxxx.json', get the 'xxxx' out of it
def get_analysis_name (filename : str) -> str:
    # Note that dots are escaped so they match literal dots only
    name_search = re.search(r'/mda\.([A-Za-z0-9_-]*)\.json$', filename)
    if not name_search:
        raise ValueError(f'Wrong expected format in filename {filename}')
    # To make it coherent with the rest of analyses, the analysis name is parsed when loaded into the database
    # Every '_' is replaced by '-' so we must keep the analysis name coherent or the web client will not find it
    return name_search[1].replace('_', '-')
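
# Usage sketch (hypothetical path):
#   get_analysis_name('/path/to/mda.rmsd_pairwise.json')   # -> 'rmsd-pairwise'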

# Use a safe alternative to hasattr/getattr
# DANI: Do not use getattr with a default argument or hasattr
# DANI: If you do, you will lose any further AttributeError(s)
# DANI: Thus you will have very silent errors every time you have a silly typo
# DANI: This is an unresolved Python issue https://bugs.python.org/issue39865
def safe_hasattr (instance, attribute_name : str) -> bool:
    return attribute_name in set(dir(instance))
def safe_getattr (instance, attribute_name : str, default):
    if not safe_hasattr(instance, attribute_name): return default
    return getattr(instance, attribute_name)
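
# Usage sketch (illustrative class, not part of the workflow):
#   class Example:
#       name = 'demo'
#   safe_hasattr(Example(), 'name')            # -> True
#   safe_getattr(Example(), 'missing', None)   # -> None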