Coverage for model_workflow/utils/auxiliar.py: 69%

234 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-23 10:54 +0000

# Auxiliary generic functions and classes used along the workflow

2 

3from model_workflow.utils.constants import RESIDUE_NAME_LETTERS, PROTEIN_RESIDUE_NAME_LETTERS 

4from model_workflow.utils.constants import YELLOW_HEADER, COLOR_END 

5 

6import os 

7from os import rename, listdir, remove 

8from os.path import isfile, exists 

9import re 

10import sys 

11import json 

12import yaml 

13from glob import glob 

14from typing import Optional, List, Set, Generator, Union 

15from struct import pack 

16# NEVER FORGET: GraphQL has a problem with urllib.parse -> It will always return error 400 (Bad request) 

17# We must use requests instead 

18import requests 

19 

20 

# Tell whether a module has already been imported in this process
def is_imported (module_name : str) -> bool:
    # sys.modules holds every module imported so far, keyed by name
    return module_name in sys.modules

24 

# Custom exception whose traceback is not to be printed
# Used when the problem is not in our code
class QuietException (Exception):
    pass

# Quiet exception raised when user input is wrong
class InputError (QuietException):
    pass

# Quiet exception raised when MD data has not passed a quality check test
class TestFailure (QuietException):
    pass

# Quiet exception raised when the problem comes from a third party dependency
class ToolError (QuietException):
    pass

# Quiet exception raised when the problem comes from a remote service
class RemoteServiceError (QuietException):
    pass

45 

# Non-referable exception for PDB synthetic constructs or chimeric entities
# The first constructor argument is expected to be the aminoacids sequence
class NoReferableException (Exception):
    def get_sequence (self) -> str:
        # The sequence is the first argument passed to the constructor
        return self.args[0]
    # Expose the sequence as a read-only property
    sequence = property(get_sequence, None, None, 'Aminoacids sequence')
    def __str__ (self):
        return f'No referable sequence {self.sequence}'
    def __repr__ (self):
        return self.__str__()

53 

# Custom exception handler giving our quiet exceptions a quiet behaviour
def custom_excepthook (exception_class, message, traceback):
    # Quiet behaviour for QuietException and any of its subclasses
    # Note: issubclass is used instead of checking __bases__ directly,
    # so that deeper descendants of QuietException are also handled quietly
    if issubclass(exception_class, QuietException):
        # Only print the error type and message, no traceback
        print('{0}: {1}'.format(exception_class.__name__, message))
        return
    # Default behaviour otherwise
    sys.__excepthook__(exception_class, message, traceback)
sys.excepthook = custom_excepthook

63 

# Special exception instances used to flag missing topology data
MISSING_TOPOLOGY = Exception('Missing topology')
MISSING_CHARGES = Exception('Missing atom charges')
MISSING_BONDS = Exception('Missing atom bonds')
# Short token representing missing bonds in JSON-serializable data
JSON_SERIALIZABLE_MISSING_BONDS = 'MB'

69 

# Alphabet of candidate letters
letters = { 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
    'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' }
# Get the next letter, in alphabetic order, which is not among the input letters
# Return None if we run out of letters
def get_new_letter(current_letters : set) -> Optional[str]:
    # Iterate in sorted order: plain set iteration order is arbitrary,
    # which would make the "next" letter non-deterministic
    return next((letter for letter in sorted(letters) if letter not in current_letters), None)

76 

# Given a list or set of names, return a set with all case possibilities:
# - All upper case
# - All lower case
# - First letter upper case, rest lower case
def all_cases (names : Union[List[str], Set[str]]) -> Set[str]:
    cases = set()
    for name in names:
        cases.add(name.upper())
        cases.add(name.lower())
        cases.add(name[0].upper() + name[1:].lower())
    return cases

89 

# Translate a residue name to its single-letter code
# Unknown residue names map to 'X'
def residue_name_to_letter (residue_name : str) -> str:
    letter = RESIDUE_NAME_LETTERS.get(residue_name, 'X')
    return letter

93 

# Translate a protein residue name to its single-letter code
# Unknown residue names map to 'X'
def protein_residue_name_to_letter (residue_name : str) -> str:
    letter = PROTEIN_RESIDUE_NAME_LETTERS.get(residue_name, 'X')
    return letter

97 

# JSON loader with additional logic to better handle problems
def load_json (filepath : str) -> dict:
    try:
        with open(filepath, 'r') as file:
            return json.load(file)
    except Exception as error:
        # Chain the original error so the root cause is not lost
        raise Exception(f'Something went wrong when loading JSON file {filepath}: {str(error)}') from error

106 

# JSON saver with additional logic to better handle problems
def save_json (content, filepath : str, indent : Optional[int] = None):
    try:
        with open(filepath, 'w') as file:
            json.dump(content, file, indent=indent)
    except Exception as error:
        # Rename the JSON file since it may be half written, thus giving problems when loaded
        # Only rename if the file was actually created: the open itself may have failed,
        # in which case the rename would raise and mask the original error
        if exists(filepath):
            rename(filepath, filepath + '.wrong')
        raise Exception(f'Something went wrong when saving JSON file {filepath}: {str(error)}') from error

116 

# YAML loader with additional logic to better handle problems
# DANI: For some reason yaml.load also works with files in JSON format
def load_yaml (filepath : str):
    try:
        with open(filepath, 'r') as file:
            # Fall back to the pure-python Loader when LibYAML (CLoader) is not
            # available: yaml.CLoader raises AttributeError in that case
            loader = getattr(yaml, 'CLoader', yaml.Loader)
            content = yaml.load(file, Loader=loader)
        return content
    except Exception as error:
        warn(str(error).replace('\n', ' '))
        raise InputError('Something went wrong when loading YAML file ' + filepath)

127 

# YAML saver counterpart of load_yaml
# Any error is left to propagate to the caller
def save_yaml (content, filepath : str):
    with open(filepath, 'w') as yaml_file:
        yaml.dump(content, yaml_file)

132 

# ANSI escape sequences used to erase previous logs in the terminal
CURSOR_UP_ONE = '\x1b[1A'
ERASE_LINE = '\x1b[2K'
# Move up, clear that line, then move up again
ERASE_PREVIOUS_LINES = CURSOR_UP_ONE + ERASE_LINE + CURSOR_UP_ONE

137 

# Erase the previous line printed in the terminal
def delete_previous_log ():
    # Printing the escape sequence moves the cursor up and clears the line
    print(ERASE_PREVIOUS_LINES)

141 

# Reprint over the previous terminal line
def reprint (text : str):
    # Erase the previous log and then print the new text in its place
    delete_previous_log()
    print(text)

146 

# Print a message with a colored warning header
def warn (message : str):
    print(YELLOW_HEADER + '⚠ WARNING: ' + COLOR_END + message)

150 

# Get the mean/average of a list of values
# Note: raises ZeroDivisionError on an empty list
def mean(values : List[float]) -> float:
    total = sum(values)
    return total / len(values)

154 

# Round a number to hundredths
def round_to_hundredths (number : float) -> float:
    # Scale up, round to the nearest integer, then scale back down
    scaled = round(number * 100)
    return scaled / 100

158 

# Round a number to thousandths
def round_to_thousandths (number : float) -> float:
    # Scale up, round to the nearest integer, then scale back down
    scaled = round(number * 1000)
    return scaled / 1000

162 

# Given a list with numbers, create a string where numbers in a row are represented rangedly
# e.g. [1, 3, 5, 6, 7, 8] => "1,3,5-8"
# Duplicates are removed and numbers are sorted first
# Only runs of 3 or more consecutive numbers are compressed: pairs stay as "1,2"
def ranger (numbers : List[int]) -> str:
    # Remove duplicates and sort numbers
    sorted_numbers = sorted(set(numbers))
    number_count = len(sorted_numbers)
    # Collect the string pieces (single numbers or "a-b" ranges) in a single pass
    pieces = []
    index = 0
    while index < number_count:
        # Advance 'end' to the last number of the consecutive run starting at 'index'
        end = index
        while end + 1 < number_count and sorted_numbers[end + 1] == sorted_numbers[end] + 1:
            end += 1
        # Compress the run only when it has at least 3 numbers
        if end - index >= 2:
            pieces.append(f'{sorted_numbers[index]}-{sorted_numbers[end]}')
            index = end + 1
        else:
            pieces.append(str(sorted_numbers[index]))
            index += 1
    return ','.join(pieces)

205 

# Special iteration system:
# for each value in the array, yield the value together with a new list
# holding all the other values
def otherwise (values : list) -> Generator[tuple, None, None]:
    for index in range(len(values)):
        remaining = values[:index] + values[index + 1:]
        yield values[index], remaining

212 

# List regular files in a directory, discarding subdirectories
def list_files (directory : str) -> List[str]:
    entries = listdir(directory)
    return [entry for entry in entries if isfile(f'{directory}/{entry}')]

216 

# Tell whether a directory has no entries at all
def is_directory_empty (directory : str) -> bool:
    return not listdir(directory)

220 

# Check if a string has patterns to be parsed by a glob function
# Note that this is not trivial, but this function should be good enough for our case
# https://stackoverflow.com/questions/42283009/check-if-string-is-a-glob-pattern
GLOB_CHARACTERS = ['*', '?', '[']
def is_glob (path : str) -> bool:
    # Look for glob characters which are not escaped with a backslash
    for index, character in enumerate(path):
        if character not in GLOB_CHARACTERS:
            continue
        # A glob character at position 0 cannot be escaped
        if index == 0 or path[index - 1] != '\\':
            return True
    return False

236 

# Parse a glob path into one or several results
# A plain path (no glob characters) is returned untouched, wrapped in a list
def parse_glob (path : str) -> List[str]:
    if is_glob(path):
        return glob(path)
    return [ path ]

247 

# Supported byte sizes mapped to their struct format letters
# 2 -> half precision, 4 -> single precision, 8 -> double precision
SUPPORTED_BYTE_SIZES = {
    2: 'e',
    4: 'f',
    8: 'd'
}

# Write a list of numeric values to a binary file
# data is a list of numeric values
# byte_size is the number of bytes each value occupies in the output file
def store_binary_data (data : List[float], byte_size : int, filepath : str):
    # Check the byte size to make sense
    letter = SUPPORTED_BYTE_SIZES.get(byte_size, None)
    if not letter:
        # Keys are integers so they must be converted to strings before joining,
        # otherwise the join itself raises a TypeError
        supported = ", ".join(str(size) for size in SUPPORTED_BYTE_SIZES.keys())
        raise ValueError(f'Not supported byte size {byte_size}, please select one of these: {supported}')
    # Set the binary format
    # '<' stands for little endian
    byte_flag = f'<{letter}'
    # Start writing the output file
    with open(filepath, 'wb') as file:
        for value in data:
            file.write(pack(byte_flag, float(value)))

271 

# Capture all stdout or stderr within a code region, even when it comes
# from another non-python threaded process
# https://stackoverflow.com/questions/24277488/in-python-how-to-capture-the-stdout-from-a-c-shared-library-to-a-variable
class CaptureOutput (object):
    # Sentinel character written to mark the end of the captured stream
    escape_char = "\b"
    def __init__(self, stream : str = 'stdout'):
        # Resolve sys.stdout or sys.stderr
        if not hasattr(sys, stream):
            raise ValueError(f'Unknown stream "{stream}". Expected stream value is "stdout" or "stderr"')
        self.original_stream = getattr(sys, stream)
        self.original_streamfd = self.original_stream.fileno()
        self.captured_text = ""
        # Pipe through which the stream will be redirected and read back
        self.pipe_out, self.pipe_in = os.pipe()
    def __enter__(self):
        self.captured_text = ""
        # Keep a duplicate of the original file descriptor to restore it later
        self.streamfd = os.dup(self.original_streamfd)
        # Point the stream file descriptor at our write pipe
        os.dup2(self.pipe_in, self.original_streamfd)
        return self
    def __exit__(self, type, value, traceback):
        # Write the sentinel so readOutput knows when to stop
        self.original_stream.write(self.escape_char)
        # Flush so all pending data arrives before the sentinel
        self.original_stream.flush()
        self.readOutput()
        # Close the pipe ends
        os.close(self.pipe_in)
        os.close(self.pipe_out)
        # Restore the original stream file descriptor
        os.dup2(self.streamfd, self.original_streamfd)
        # Close the duplicate descriptor
        os.close(self.streamfd)
    def readOutput(self):
        # Read one character at a time until the sentinel (or EOF) shows up
        while True:
            char = os.read(self.pipe_out, 1).decode(self.original_stream.encoding)
            if not char or self.escape_char in char:
                break
            self.captured_text += char

312 

# Request data from the PDB GraphQL API
# This function may be used for either PDB ids or PDB molecule ids, depending on the query
# The query parameter may be constructed using the following page:
# https://data.rcsb.org/graphql/index.html
def request_pdb_data (pdb_id : str, query : str) -> dict:
    # The id length decides which key holds the data in the response
    if len(pdb_id) == 4: data_key = 'entry'
    elif len(pdb_id) < 4: data_key = 'chem_comp'
    # Longer ids are not valid (the original elif covered every non-4 length,
    # which made this error branch unreachable)
    else: raise ValueError(f'Wrong PDB id "{pdb_id}". It must be 4 (entries) or less (ligands) characters long')
    # Set the request URL
    request_url = 'https://data.rcsb.org/graphql'
    # Set the POST data
    post_data = {
        "query": query,
        "variables": { "id": pdb_id }
    }
    # Send the request
    # NEVER FORGET: GraphQL has a problem with urllib.parse -> error 400, so requests is used
    try:
        response = requests.post(request_url, json=post_data)
    except requests.exceptions.ConnectionError as error:
        raise ConnectionError('No internet connection :(') from None
    # Mine the response
    parsed_response = json.loads(response.text)['data'][data_key]
    if parsed_response is None:
        # The id may be obsolete: try with its replacement, if any
        new_pdb_id = request_replaced_pdb(pdb_id)
        if new_pdb_id:
            parsed_response = request_pdb_data(new_pdb_id, query)
        else:
            print(f'PDB id {pdb_id} not found')
    return parsed_response

343 

# Use the RCSB REST API to get the replacing PDB id
# This is useful when the PDB entry is obsolete and has been replaced
def request_replaced_pdb(pdb_id):
    query_url = 'https://data.rcsb.org/rest/v1/holdings/removed/' + pdb_id
    response = requests.get(query_url, headers={'Content-Type': 'application/json'})
    # A non-200 status means there is no removed-entry record for this id
    if response.status_code != 200:
        return None
    try:
        return response.json()['rcsb_repository_holdings_removed']['id_codes_replaced_by'][0]
    # Catch only errors coming from an unexpected response layout
    # (ValueError covers a non-JSON body) so real bugs are not silently swallowed
    except (KeyError, IndexError, ValueError):
        print(f'Error when mining replaced PDB id for {pdb_id}')
        return None

358 

# Given a filename, set a numeric suffix on it, right before the extension
# zeros is the zero-padding width of the number
def numerate_filename (filename : str, number : int, zeros : int = 2, separator : str = '_') -> str:
    # Split off the extension (the text after the last dot)
    *name_parts, extension = filename.split('.')
    sufix = separator + str(number).zfill(zeros)
    return '.'.join(name_parts) + sufix + '.' + extension

365 

# Given a filename, set a '*' suffix right before the extension
# The result matches all filenames produced by numerate_filename, e.g. in bash or glob
def glob_filename (filename : str, separator : str = '_') -> str:
    # Split off the extension (the text after the last dot)
    *name_parts, extension = filename.split('.')
    return '.'.join(name_parts) + separator + '*' + '.' + extension

372 

# Delete every file matched by the glob_filename pattern
def purge_glob (filename : str):
    pattern = glob_filename(filename)
    # The exists check guards against files removed between the glob and the remove
    for matched_file in glob(pattern):
        if exists(matched_file):
            remove(matched_file)

379 

# Given a filepath with the pattern 'mda.xxxx.json', get the 'xxxx' out of it
def get_analysis_name (filename : str) -> str:
    name_search = re.search(r'/mda.([A-Za-z0-9_-]*).json$', filename)
    if not name_search:
        # Include the offending filename in the message
        # (the original f-string had no placeholder at all)
        raise ValueError(f'Wrong expected format in filename {filename}')
    # To make it coherent with the rest of analyses, the analysis name becomes parsed when loaded in the database
    # Every '_' is replaced by '-' so we must keep the analysis name coherent or the web client will not find it
    return name_search[1].replace('_', '-')

388 

# Safe alternatives to hasattr/getattr
# DANI: Do not use getattr with a default argument or hasattr
# DANI: If you do, you will lose any further AttributeError(s)
# DANI: Thus you will have very silent errors every time you have a silly typo
# DANI: This is a python itself unresolved error https://bugs.python.org/issue39865
def safe_hasattr (instance, attribute_name : str) -> bool:
    # dir lists every attribute name without triggering __getattr__
    return attribute_name in dir(instance)
# NOTE(review): the 'defualt' parameter name is a typo, kept so keyword callers do not break
def safe_getattr (instance, attribute_name : str, defualt):
    if safe_hasattr(instance, attribute_name):
        return getattr(instance, attribute_name)
    return defualt

399