Coverage for mddb_workflow/utils/auxiliar.py: 71%

315 statements  

coverage.py v7.11.0, created at 2025-10-29 15:48 +0000

# Auxiliary generic functions and classes used along the workflow

from mddb_workflow import __path__
from mddb_workflow.utils.constants import RESIDUE_NAME_LETTERS, PROTEIN_RESIDUE_NAME_LETTERS
from mddb_workflow.utils.constants import YELLOW_HEADER, COLOR_END

import os
from os.path import isfile, exists
import re
import sys
import json
import yaml
from glob import glob
from typing import Optional, Generator, Callable
from struct import pack
# NEVER FORGET: GraphQL has a problem with urllib.parse -> It will always return error 400 (Bad request)
# We must use requests instead
import requests
from subprocess import run, PIPE


# Check if a module has been imported
def is_imported (module_name : str) -> bool:
    return module_name in sys.modules


# Set a custom exception which does not print a traceback
# These exceptions are used when the problem is not in our code
class QuietException (Exception):
    pass

# Set a custom quiet exception for when user input is wrong
class InputError (QuietException):
    pass

# Set a custom quiet exception for when MD data has not passed a quality check test
class TestFailure (QuietException):
    pass

# Set a custom quiet exception for when the problem comes from a third party dependency
class ToolError (QuietException):
    pass

# Set a custom quiet exception for when the problem comes from a remote service
class RemoteServiceError (QuietException):
    pass

# Set a 'no referable' exception for PDB synthetic constructs or chimeric entities
class NoReferableException (Exception):
    def __str__ (self): return f'No referable sequence {self.sequence}'
    def __repr__ (self): return self.__str__()
    def get_sequence (self) -> str:
        return self.args[0]
    sequence = property(get_sequence, None, None, 'Amino acid sequence')


# Set a custom exception handler where our quiet exceptions have a quiet behaviour
def custom_excepthook (exception_class, message, traceback):
    # Quiet behaviour if it is one of our quiet exceptions
    # Note that issubclass also covers subclasses deeper than one level
    if issubclass(exception_class, QuietException):
        print('{0}: {1}'.format(exception_class.__name__, message)) # Only print Error Type and Message
        return
    # Default behaviour otherwise
    sys.__excepthook__(exception_class, message, traceback)
sys.excepthook = custom_excepthook

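# For instance, raising one of the quiet exceptions above prints a single clean line
# instead of a full traceback (hypothetical example):
#     raise InputError('Missing input topology file')
#     ->  InputError: Missing input topology file
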

# Set special exceptions for when the topology is missing
MISSING_TOPOLOGY = Exception('Missing topology')
MISSING_CHARGES = Exception('Missing atom charges')
MISSING_BONDS = Exception('Missing atom bonds')
JSON_SERIALIZABLE_MISSING_BONDS = 'MB'

# Keep all these exceptions in a set
STANDARD_EXCEPTIONS = { MISSING_TOPOLOGY, MISSING_CHARGES, MISSING_BONDS }


# Set a function to get the first letter not yet in use, in alphabetic order
# Return None if we run out of letters
# WARNING: letters must be an ordered container (not a set) for the alphabetic order to hold
letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
def get_new_letter(current_letters : set) -> Optional[str]:
    return next((letter for letter in letters if letter not in current_letters), None)

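# A quick usage sketch (hypothetical values):
#     get_new_letter({'A', 'B'})  ->  'C'
#     get_new_letter(set(letters))  ->  None
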

# Given a list or set of names, return a set with all case possibilities:
# All upper case
# All lower case
# First letter upper case and the rest lower case
def all_cases (names : list[str] | set[str]) -> set[str]:
    all_names = []
    for name in names:
        all_upper = name.upper()
        all_lower = name.lower()
        one_upper = name[0].upper() + name[1:].lower()
        all_names += [ all_upper, all_lower, one_upper ]
    return set(all_names)

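# A quick usage sketch (hypothetical values):
#     all_cases(['His'])  ->  {'HIS', 'his', 'His'}
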

# Given a residue name, return its single letter
def residue_name_to_letter (residue_name : str) -> str:
    return RESIDUE_NAME_LETTERS.get(residue_name, 'X')

# Given a protein residue name, return its single letter
def protein_residue_name_to_letter (residue_name : str) -> str:
    return PROTEIN_RESIDUE_NAME_LETTERS.get(residue_name, 'X')


# Set a recursive transformer for nested dicts, lists and tuples
# Note that a new object is created to prevent mutating the original
# If no transformer function is passed then it becomes a recursive cloner
# WARNING: the transformer function could mutate some types of the original object if not done properly
OBJECT_TYPES = { dict, list, tuple }
def recursive_transformer (target_object : dict | list | tuple, transformer : Optional[Callable] = None) -> dict | list | tuple:
    object_type = type(target_object)
    # Get a starting object and entries iterator depending on the object type
    if object_type == dict:
        clone = {}
        entries = target_object.items()
    elif object_type == list or object_type == tuple:
        # Note that if it is a tuple we make a list anyway and we convert it to tuple at the end
        # WARNING: Note that tuples do not support reassigning items
        clone = [ None for i in range(len(target_object)) ]
        entries = enumerate(target_object)
    else: raise ValueError(f'The recursive transformer should only be applied to dicts, lists and tuples, not to {object_type}')
    # Iterate the different entries in the object
    for index_or_key, value in entries:
        # Get the value type
        value_type = type(value)
        # If it is a dict, list or tuple then call the transformer recursively
        # Note that the transformer must be forwarded or nested values would be simply cloned
        if value_type in OBJECT_TYPES: clone[index_or_key] = recursive_transformer(value, transformer)
        # If it is not an object type then apply the transformer to it
        else: clone[index_or_key] = transformer(value) if transformer else value
    # If it was a tuple then make the conversion now
    if object_type == tuple: clone = tuple(clone)
    return clone

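# A quick usage sketch, doubling every leaf value (hypothetical values):
#     recursive_transformer({'a': 1, 'b': [2, 3]}, lambda value: value * 2)  ->  {'a': 2, 'b': [4, 6]}
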

# Set some headers for the serializer
EXCEPTION_HEADER = 'Exception: '
# Set a standard JSON serializer/deserializer to support additional types
# LORE: This was originally intended to support exceptions in the cache
def json_serializer (object : dict | list | tuple) -> dict | list | tuple:
    def serializer (value):
        # If we have exceptions then convert them to text with an appropriate header
        if type(value) == Exception:
            return f'{EXCEPTION_HEADER}{value}'
        # If the type is not among the ones we check then assume it is already serializable
        return value
    object_clone = recursive_transformer(object, serializer)
    return object_clone
def json_deserializer (object : dict | list | tuple) -> dict | list | tuple:
    def deserializer (value):
        # Check if there is any value which was adapted to become JSON serializable and restore it
        if type(value) == str and value.startswith(EXCEPTION_HEADER):
            # WARNING: Do not declare new exceptions here but use the constant ones
            # WARNING: Otherwise further equality comparisons will fail
            exception_message = value[len(EXCEPTION_HEADER):]
            standard_exception = next((exception for exception in STANDARD_EXCEPTIONS if str(exception) == exception_message), None)
            if standard_exception is None:
                raise ValueError(f'Exception "{exception_message}" is not among standard exceptions')
            return standard_exception
        # If the type is not among the ones we check then assume it is already deserialized
        return value
    object_clone = recursive_transformer(object, deserializer)
    return object_clone

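# A quick round-trip sketch with one of the standard exceptions (hypothetical values):
#     serialized = json_serializer({'bonds': MISSING_BONDS})  ->  {'bonds': 'Exception: Missing atom bonds'}
#     json_deserializer(serialized)['bonds'] is MISSING_BONDS  ->  True
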

# Set a JSON loader with additional logic to better handle problems
def load_json (filepath : str, replaces : Optional[list[tuple]] = []) -> dict:
    try:
        with open(filepath, 'r') as file:
            content = file.read()
            # Make pure text replacements
            for replace in replaces:
                target, replacement = replace
                content = content.replace(target, replacement)
            # Parse the content to JSON
            parsed_content = json.loads(content)
            # Deserialize some types like Exceptions, which are stored in JSONs in this context
            deserialized_content = json_deserializer(parsed_content)
            return deserialized_content
    except Exception as error:
        raise Exception(f'Something went wrong when loading JSON file {filepath}: {str(error)}')


# Set a JSON saver with additional logic to better handle problems
def save_json (content, filepath : str, indent : Optional[int] = None):
    try:
        with open(filepath, 'w') as file:
            serialized_content = json_serializer(content)
            json.dump(serialized_content, file, indent=indent)
    except Exception as error:
        # Rename the JSON file since it will be half written, thus giving problems when loaded
        # Note that the file may not exist at all if the problem was in the open itself
        if exists(filepath): os.rename(filepath, filepath + '.wrong')
        raise Exception(f'Something went wrong when saving JSON file {filepath}: {str(error)}')


# Set a YAML loader with additional logic to better handle problems
# The replaces argument allows to replace file content before it is processed
# Every replace is a tuple with two values: the target and the replacement
# DANI: For some reason yaml.load also works with JSON-formatted files
def load_yaml (filepath : str, replaces : Optional[list[tuple]] = []) -> dict:
    try:
        with open(filepath, 'r') as file:
            content = file.read()
            for replace in replaces:
                target, replacement = replace
                content = content.replace(target, replacement)
            parsed_content = yaml.load(content, Loader=yaml.CLoader)
            return parsed_content
    except Exception as error:
        warn(str(error).replace('\n', ' '))
        raise InputError('Something went wrong when loading YAML file ' + filepath)


# Set a YAML saver with additional logic to better handle problems
def save_yaml (content, filepath : str):
    with open(filepath, 'w') as file:
        yaml.dump(content, file)


# Set a few constants to erase previous logs in the terminal
CURSOR_UP_ONE = '\x1b[1A'
ERASE_LINE = '\x1b[2K'
ERASE_PREVIOUS_LINES = CURSOR_UP_ONE + ERASE_LINE + CURSOR_UP_ONE

# Set a function to remove the previous line
def delete_previous_log ():
    print(ERASE_PREVIOUS_LINES)

# Set a function to reprint in the same line
def reprint (text : str):
    delete_previous_log()
    print(text)

# Set a function to print a message with a colored warning header
def warn (message : str):
    print(YELLOW_HEADER + '⚠ WARNING: ' + COLOR_END + message)


# Get the mean/average of a list of values
def mean (values : list[float]) -> float:
    return sum(values) / len(values)

# Round a number to hundredths
def round_to_hundredths (number : float) -> float:
    return round(number * 100) / 100

# Round a number to thousandths
def round_to_thousandths (number : float) -> float:
    return round(number * 1000) / 1000


# Given a list of numbers, create a string where consecutive numbers are collapsed into ranges
# e.g. [1, 3, 5, 6, 7, 8] -> "1,3,5-8"
def ranger (numbers : list[int]) -> str:
    # Remove duplicates and sort numbers
    sorted_numbers = sorted(list(set(numbers)))
    # Get the number of numbers in the list
    number_count = len(sorted_numbers)
    # If there is only one number then finish here
    if number_count == 1:
        return str(sorted_numbers[0])
    # Start the parsing otherwise
    ranged = ''
    last_number = -1
    # Iterate numbers
    for i, number in enumerate(sorted_numbers):
        # Skip this number if it was already included in a previous series
        if i <= last_number: continue
        # Add current number to the ranged string
        ranged += ',' + str(number)
        # Now iterate numbers after the current number
        next_index = i+1
        for j, next_number in enumerate(sorted_numbers[next_index:], next_index):
            # Set the length of the series
            length = j - i
            # End of the series
            if next_number - number != length:
                # The length here is the length which broke the series
                # i.e. if the length here is 2 the actual series length is 1
                series_length = length - 1
                if series_length > 1:
                    last_series_number = j - 1
                    previous_number = sorted_numbers[last_series_number]
                    ranged += '-' + str(previous_number)
                    last_number = last_series_number
                break
            # End of the selection
            if j == number_count - 1:
                if length > 1:
                    ranged += '-' + str(next_number)
                    last_number = j
    # Remove the first comma before returning the ranged string
    return ranged[1:]


# Set a special iteration system
# For each value in the array, yield the value together with a new array holding all the other values
def otherwise (values : list) -> Generator[tuple, None, None]:
    for v, value in enumerate(values):
        others = values[0:v] + values[v+1:]
        yield value, others

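# A quick usage sketch (hypothetical values):
#     list(otherwise([1, 2, 3]))  ->  [(1, [2, 3]), (2, [1, 3]), (3, [1, 2])]
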

# List files in a directory
def list_files (directory : str) -> list[str]:
    return [f for f in os.listdir(directory) if isfile(f'{directory}/{f}')]

# Check if a directory is empty
def is_directory_empty (directory : str) -> bool:
    return len(os.listdir(directory)) == 0


# Set a function to check if a string has patterns to be parsed by a glob function
# Note that this is not trivial, but this function should be good enough for our case
# https://stackoverflow.com/questions/42283009/check-if-string-is-a-glob-pattern
GLOB_CHARACTERS = ['*', '?', '[']
def is_glob (path : str) -> bool:
    # Find unescaped glob characters
    for c, character in enumerate(path):
        if character not in GLOB_CHARACTERS:
            continue
        if c == 0:
            return True
        previous_character = path[c-1]
        if previous_character != '\\':
            return True
    return False

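# A quick usage sketch (hypothetical filenames):
#     is_glob('frames_*.pdb')   ->  True
#     is_glob('frames_01.pdb')  ->  False
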

# Parse a glob path into one or several results
# If the path has no glob characters then return it as it is
# Otherwise parse it and return every matching path
def parse_glob (path : str) -> list[str]:
    # If there is no glob pattern then just return the string as is
    if not is_glob(path):
        return [ path ]
    # If there is a glob pattern then parse it
    parsed_filepaths = glob(path)
    return parsed_filepaths


# Supported byte sizes and their corresponding struct format letters
SUPPORTED_BYTE_SIZES = {
    2: 'e',
    4: 'f',
    8: 'd'
}

# Data is a list of numeric values
# Byte size is the number of bytes each value in data is to occupy
def store_binary_data (data : list[float], byte_size : int, filepath : str):
    # Check the byte size makes sense
    letter = SUPPORTED_BYTE_SIZES.get(byte_size, None)
    if not letter:
        supported = ', '.join(str(size) for size in SUPPORTED_BYTE_SIZES.keys())
        raise ValueError(f'Not supported byte size {byte_size}, please select one of these: {supported}')
    # Set the binary format
    # '<' stands for little endian
    byte_flag = f'<{letter}'
    # Start writing the output file
    with open(filepath, 'wb') as file:
        # Iterate over data list values
        for value in data:
            value = float(value)
            file.write(pack(byte_flag, value))

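# A quick usage sketch, storing three values as 4-byte floats (hypothetical values and filename):
#     store_binary_data([1.0, 2.5, -3.0], 4, 'values.bin')  ->  writes 12 bytes, little endian
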

# Capture all stdout or stderr within a code region, even if it comes from another non-python threaded process
# https://stackoverflow.com/questions/24277488/in-python-how-to-capture-the-stdout-from-a-c-shared-library-to-a-variable
class CaptureOutput (object):
    escape_char = "\b"
    def __init__(self, stream : str = 'stdout'):
        # Get sys stdout or stderr
        if not hasattr(sys, stream):
            raise ValueError(f'Unknown stream "{stream}". Expected stream value is "stdout" or "stderr"')
        self.original_stream = getattr(sys, stream)
        self.original_streamfd = self.original_stream.fileno()
        self.captured_text = ""
        # Create a pipe so the stream can be captured:
        self.pipe_out, self.pipe_in = os.pipe()
    def __enter__(self):
        self.captured_text = ""
        # Save a copy of the stream:
        self.streamfd = os.dup(self.original_streamfd)
        # Replace the original stream with our write pipe:
        os.dup2(self.pipe_in, self.original_streamfd)
        return self
    def __exit__(self, type, value, traceback):
        # Print the escape character to make the readOutput method stop:
        self.original_stream.write(self.escape_char)
        # Flush the stream to make sure all our data goes in before
        # the escape character:
        self.original_stream.flush()
        self.readOutput()
        # Close the pipe:
        os.close(self.pipe_in)
        os.close(self.pipe_out)
        # Restore the original stream:
        os.dup2(self.streamfd, self.original_streamfd)
        # Close the duplicate stream:
        os.close(self.streamfd)
    def readOutput(self):
        while True:
            char = os.read(self.pipe_out, 1).decode(self.original_stream.encoding)
            if not char or self.escape_char in char:
                break
            self.captured_text += char

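# A quick usage sketch (hypothetical values):
#     with CaptureOutput('stdout') as capture:
#         print('hello')
#     capture.captured_text  ->  'hello\n'
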

# Set a function to request data from the PDB GraphQL API
# Note that this function may be used for either PDB ids or PDB molecule ids, depending on the query
# The query parameter may be constructed using the following page:
# https://data.rcsb.org/graphql/index.html
def request_pdb_data (pdb_id : str, query : str) -> dict:
    # Make sure the PDB id is valid as we set the correct key to mine the response data
    if len(pdb_id) == 4: data_key = 'entry'
    elif len(pdb_id) < 4: data_key = 'chem_comp'
    else: raise ValueError(f'Wrong PDB id "{pdb_id}". It must be 4 (entries) or less (ligands) characters long')
    # Set the request URL
    request_url = 'https://data.rcsb.org/graphql'
    # Set the POST data
    post_data = {
        "query": query,
        "variables": { "id": pdb_id }
    }
    # Send the request
    try:
        response = requests.post(request_url, json=post_data)
    except requests.exceptions.ConnectionError as error:
        raise ConnectionError('No internet connection :(') from None
    # Get the response
    parsed_response = json.loads(response.text)['data'][data_key]
    if parsed_response is None:
        new_pdb_id = request_replaced_pdb(pdb_id)
        if new_pdb_id:
            parsed_response = request_pdb_data(new_pdb_id, query)
        else:
            print(f'PDB id {pdb_id} not found')
    return parsed_response

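# A quick usage sketch asking for an entry title (hypothetical query and id, assuming the current RCSB GraphQL schema):
#     query = 'query($id: String!) { entry(entry_id: $id) { struct { title } } }'
#     request_pdb_data('4YDF', query)  ->  {'struct': {'title': ...}}
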

# Use the RCSB REST API to get the replaced PDB id
# This is useful when the PDB entry is obsolete and has been replaced
def request_replaced_pdb (pdb_id):
    query_url = 'https://data.rcsb.org/rest/v1/holdings/removed/' + pdb_id
    response = requests.get(query_url, headers={'Content-Type': 'application/json'})
    # Check if the response is OK
    if response.status_code == 200:
        try:
            return response.json()['rcsb_repository_holdings_removed']['id_codes_replaced_by'][0]
        except Exception:
            print(f'Error when mining replaced PDB id for {pdb_id}')
            return None
    else:
        return None


# Given a filename, add a suffix number to it, right before the extension
# Set also the number of zeros to fill the name
def numerate_filename (filename : str, number : int, zeros : int = 2, separator : str = '_') -> str:
    splits = filename.split('.')
    suffix = separator + str(number).zfill(zeros)
    return '.'.join(splits[0:-1]) + suffix + '.' + splits[-1]

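# A quick usage sketch (hypothetical filename):
#     numerate_filename('frame.pdb', 3)  ->  'frame_03.pdb'
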

# Given a filename, add a suffix including '*', right before the extension
# This should match all filenames obtained through the numerate_filename function when used in bash
def glob_filename (filename : str, separator : str = '_') -> str:
    splits = filename.split('.')
    suffix = separator + '*'
    return '.'.join(splits[0:-1]) + suffix + '.' + splits[-1]

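# A quick usage sketch (hypothetical filename):
#     glob_filename('frame.pdb')  ->  'frame_*.pdb'
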

# Delete all files matched by the glob_filename function
def purge_glob (filename : str):
    glob_pattern = glob_filename(filename)
    existing_outputs = glob(glob_pattern)
    for existing_output in existing_outputs:
        if exists(existing_output): os.remove(existing_output)


# Given a filename with the pattern 'mda.xxxx.json', get the 'xxxx' out of it
def get_analysis_name (filename : str) -> str:
    name_search = re.search(r'/mda\.([A-Za-z0-9_-]*)\.json$', filename)
    if not name_search:
        raise ValueError(f'Wrong expected format in filename {filename}')
    # To make it coherent with the rest of analyses, the analysis name gets parsed when loaded in the database
    # Every '_' is replaced by '-' so we must keep the analysis name coherent or the web client will not find it
    return name_search[1].replace('_', '-')

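# A quick usage sketch (hypothetical filename):
#     get_analysis_name('/path/to/mda.dist_perres.json')  ->  'dist-perres'
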

# Use a safe alternative to hasattr/getattr
# DANI: Do not use getattr with a default argument or hasattr
# DANI: If you do, you will lose any further AttributeError(s)
# DANI: Thus you will have very silent errors every time you have a silly typo
# DANI: This is an unresolved error of python itself https://bugs.python.org/issue39865
def safe_hasattr (instance, attribute_name : str) -> bool:
    return attribute_name in set(dir(instance))
def safe_getattr (instance, attribute_name : str, default):
    if not safe_hasattr(instance, attribute_name): return default
    return getattr(instance, attribute_name)

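# A quick usage sketch (hypothetical values):
#     safe_hasattr(1, 'real')     ->  True
#     safe_getattr(1, 'fake', 0)  ->  0
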

# Functions to read and write a nested dict value using a single dotted key

# Read a value in a nested dictionary and return the placeholder if any key in the path does not exist
def read_ndict (nested_dict : dict, nested_key : str, placeholder = KeyError('Missing nested key')):
    keys = nested_key.split('.')
    value = nested_dict
    for key in keys:
        # Support list indices
        if key.isdigit():
            if type(value) != list: return placeholder
            index = int(key)
            value = value[index]
        # Support dict keys
        else:
            if type(value) != dict: return placeholder
            value = value.get(key, placeholder)
            if value is placeholder: return placeholder
    return value

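# A quick usage sketch (hypothetical values):
#     read_ndict({'a': {'b': [10, 20]}}, 'a.b.1')  ->  20
#     read_ndict({'a': {}}, 'a.b', placeholder=None)  ->  None
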

# Write a value in a nested dictionary and raise an error if any key in the path is missing
def write_ndict (nested_dict : dict, nested_key : str, value):
    keys = nested_key.split('.')
    nested_keys = keys[0:-1]
    next_target = nested_dict
    for k, key in enumerate(nested_keys):
        # Support list indices
        if key.isdigit():
            if type(next_target) != list:
                raise ValueError(f'{".".join(nested_keys[0:k])} should be a list, but it is {next_target}')
            index = int(key)
            next_target = next_target[index]
        # Support dict keys
        else:
            if type(next_target) != dict:
                raise ValueError(f'{".".join(nested_keys[0:k])} should be a dict, but it is {next_target}')
            missing_key_error = KeyError(f'Missing nested key {key}')
            next_target = next_target.get(key, missing_key_error)
            if next_target is missing_key_error: raise missing_key_error
    field = keys[-1]
    # Support a list index also in the final key
    if field.isdigit() and type(next_target) == list: next_target[int(field)] = value
    else: next_target[field] = value

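# A quick usage sketch (hypothetical values):
#     config = {'a': {'b': [10, 20]}}
#     write_ndict(config, 'a.b.0', 99)  ->  config becomes {'a': {'b': [99, 20]}}
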

# Get the current git version
def get_git_version () -> str:
    git_command = f"git -C {__path__[0]} describe"
    process = run(git_command, shell=True, stdout=PIPE)
    return process.stdout.decode().replace('\n', '')