Coverage for src/bob/pipelines/dataset/protocols/retrieve.py: 66%

139 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2024-07-12 21:32 +0200

1"""Allows to find a protocol definition file locally, or download it if needed. 

2 

3 

4Expected protocol structure: 

5 

6``base_dir / subdir / database_filename / protocol_name / group_name`` 

7 

8 

9By default, ``base_dir`` will be pointed by the ``bob_data_dir`` config. 

10``subdir`` is provided as a way to use a directory inside ``base_dir`` when 

11using its default. 

12 

13Here are some valid example paths (``bob_data_dir=/home/username/bob_data``): 

14 

15In a "raw" directory (not an archive): 

16 

17``/home/username/bob_data/protocols/my_db/my_protocol/my_group`` 

18 

19In an archive: 

20 

21``/home/username/bob_data/protocols/my_db.tar.gz/my_protocol/my_group`` 

22 

23In an archive with the database name as top-level (some legacy db used that): 

24 

25``/home/username/bob_data/protocols/my_db.tar.gz/my_db/my_protocol/my_group`` 

26 

27""" 

28 

29import glob 

30 

31from logging import getLogger 

32from os import PathLike 

33from pathlib import Path 

34from typing import Any, Callable, Optional, TextIO, Union 

35 

36import requests 

37 

38from clapper.rc import UserDefaults 

39 

40from bob.pipelines.dataset.protocols import archive, hashing 

41 

42logger = getLogger(__name__) 

43 

44 

def _get_local_data_directory() -> Path:
    """Return the base directory for local data (``bob_data_dir`` config).

    Falls back to ``~/bob_data`` when the config key is not set.
    """
    defaults = UserDefaults("bobrc.toml")
    data_dir = defaults.get("bob_data_dir", default=Path.home() / "bob_data")
    return Path(data_dir)

51 

52 

53def _infer_filename_from_urls(urls=Union[list[str], str]) -> str: 

54 """Retrieves the remote filename from the URLs. 

55 

56 Parameters 

57 ---------- 

58 urls 

59 One or multiple URLs pointing to files with the same name. 

60 

61 Returns 

62 ------- 

63 The remote file name. 

64 

65 Raises 

66 ------ 

67 ValueError 

68 When urls point to files with different names. 

69 """ 

70 if isinstance(urls, str): 

71 return urls.split("/")[-1] 

72 

73 # Check that all urls point to the same file name 

74 names = [u.split("/")[-1] for u in urls] 

75 if not all(n == names[0] for n in names): 

76 raise ValueError( 

77 f"Cannot infer file name when urls point to different files ({names=})." 

78 ) 

79 return urls[0].split("/")[-1] 

80 

81 

def retrieve_protocols(
    urls: list[str],
    destination_filename: Optional[str] = None,
    base_dir: Union[PathLike[str], str, None] = None,
    subdir: Union[PathLike[str], str] = "protocol",
    checksum: Union[str, None] = None,
) -> Path:
    """Automatically downloads the necessary protocol definition files."""
    # Default to the configured bob data directory when no base is given.
    target_base = _get_local_data_directory() if base_dir is None else base_dir

    remote_filename = _infer_filename_from_urls(urls)
    local_filename = (
        remote_filename if destination_filename is None else destination_filename
    )
    # A custom local name must keep the remote extension(s) intact.
    if Path(remote_filename).suffixes != Path(local_filename).suffixes:
        raise ValueError(
            "Local dataset protocol definition files must have the same "
            f"extension as the remote ones ({remote_filename=})"
        )

    return download_protocol_definition(
        urls=urls,
        destination_base_dir=target_base,
        destination_subdir=subdir,
        destination_filename=local_filename,
        checksum=checksum,
        force=False,
    )

110 

111 

def list_protocol_paths(
    database_name: str,
    base_dir: Union[PathLike[str], str, None] = None,
    subdir: Union[PathLike[str], str] = "protocol",
    database_filename: Union[str, None] = None,
) -> list[Path]:
    """Returns the paths of each protocol in a database definition file.

    Parameters
    ----------
    database_name
        Name of the database; used as directory/file name unless
        ``database_filename`` is given.
    base_dir
        Base path of data files (defaults to the ``bob_data_dir`` config).
    subdir
        Sub directory for the protocols inside ``base_dir``.
        NOTE(review): defaults to ``"protocol"`` here while sibling functions
        default to ``"protocols"`` — confirm which is intended before relying
        on the default.
    database_filename
        Overrides the file/directory name looked up under ``base_dir/subdir``.

    Returns
    -------
    A list of paths, one per protocol. For archives, each path uses the
    ``/path/to/archive.tar.gz:inner/dir`` notation.
    """
    if base_dir is None:
        base_dir = _get_local_data_directory()
    final_dir = Path(base_dir) / subdir
    final_dir /= (
        database_name if database_filename is None else database_filename
    )

    if archive.is_archive(final_dir):
        protocols = archive.list_dirs(final_dir, show_files=False)
        # Handle legacy archives with database_name as the top-level directory.
        if len(protocols) == 1 and protocols[0].name == database_name:
            protocols = archive.list_dirs(
                final_dir, database_name, show_files=False
            )

        archive_path, inner_dir = archive.path_and_subdir(final_dir)
        if inner_dir is None:
            return [
                Path(f"{archive_path.as_posix()}:{p.as_posix().lstrip('/')}")
                for p in protocols
            ]

        return [
            Path(f"{archive_path.as_posix()}:{inner_dir.as_posix()}/{p.name}")
            for p in protocols
        ]

    # Not an archive: materialize the generator so the declared
    # ``list[Path]`` return type actually holds (iterdir() is lazy).
    return list(final_dir.iterdir())

147 

148 

def get_protocol_path(
    database_name: str,
    protocol: str,
    base_dir: Union[PathLike[str], str, None] = None,
    subdir: Union[PathLike[str], str] = "protocols",
    database_filename: Optional[str] = None,
) -> Union[Path, None]:
    """Returns the path of a specific protocol.

    Will look for ``protocol`` in ``base_dir / subdir / database_(file)name``.

    Returns
    -------
    Path
        The required protocol's path for the database, or ``None`` when no
        protocol with that name is found (a warning is logged).
    """
    candidates = list_protocol_paths(
        database_name=database_name,
        base_dir=base_dir,
        subdir=subdir,
        database_filename=database_filename,
    )
    for candidate in candidates:
        # For archive entries the protocol name is the inner directory name;
        # otherwise it is the path's own last component.
        if archive.is_archive(candidate):
            _, inner = archive.path_and_subdir(candidate)
            candidate_name = inner.name
        else:
            candidate_name = candidate.name
        if candidate_name == protocol:
            return candidate
    logger.warning(f"Protocol {protocol} not found in {database_name}.")
    return None

180 

181 

def list_protocol_names(
    database_name: str,
    base_dir: Union[PathLike[str], str, None] = None,
    subdir: Union[PathLike[str], str, None] = "protocols",
    database_filename: Union[str, None] = None,
) -> list[str]:
    """Returns the names of the protocol directories for a given database.

    Archives are also accepted, either if the file name is the same as
    ``database_name`` with a ``.tar.gz`` extension or by specifying the
    filename in ``database_filename``.

    This will look in ``base_dir/subdir`` for ``database_filename``, then
    ``database_name``, then ``database_name+".tar.gz"``.

    Parameters
    ----------
    database_name
        The database name used to infer ``database_filename`` if not specified.
    base_dir
        The base path of data files (defaults to the ``bob_data_dir`` config,
        or ``~/bob_data`` if not configured).
    subdir
        A sub directory for the protocols in ``base_dir``.
    database_filename
        If the file/directory name of the protocols is not the same as the
        name of the database, this can be set to look in the correct file.

    Returns
    -------
    A list of protocol names
        The different protocols available for that database.
    """

    root = (
        Path(base_dir) if base_dir is not None else _get_local_data_directory()
    )
    root = root / (subdir if subdir is not None else ".")

    if database_filename is None:
        # Try a raw directory named after the database first, then fall back
        # to an archive with the conventional .tar.gz extension.
        database_filename = database_name
        if not (root / database_filename).is_dir():
            database_filename = database_name + ".tar.gz"

    final_path: Path = root / database_filename

    if archive.is_archive(final_path):
        top_level_dirs = archive.list_dirs(final_path, show_files=False)
        # Handle a database archive having database_name as top-level directory
        if len(top_level_dirs) == 1 and top_level_dirs[0].name == database_name:
            inner_dirs = archive.list_dirs(
                final_path, inner_dir=database_name, show_files=False
            )
            return [p.name for p in inner_dirs]
        return [p.name for p in top_level_dirs]

    # Not an archive: list the dirs
    return [p.name for p in final_path.iterdir() if p.is_dir()]

243 

244 

def open_definition_file(
    search_pattern: Union[PathLike[str], str],
    database_name: str,
    protocol: str,
    base_dir: Union[PathLike[str], str, None] = None,
    subdir: Union[PathLike[str], str, None] = "protocols",
    database_filename: Optional[str] = None,
) -> Union[TextIO, None]:
    """Opens a protocol definition file inside a protocol directory.

    Also handles protocols inside an archive.

    Returns
    -------
    An open text stream on the first matching file, or ``None`` when the
    protocol or the file could not be found.
    """
    search_path = get_protocol_path(
        database_name, protocol, base_dir, subdir, database_filename
    )

    # get_protocol_path returns None (and warns) for an unknown protocol;
    # without this guard, archive.is_archive would receive None and fail.
    if search_path is None:
        return None

    if archive.is_archive(search_path):
        return archive.search_and_open(
            search_pattern=search_pattern,
            archive_path=search_path,
        )

    search_pattern = Path(search_pattern)

    # we prepend './' to search_pattern because it might start with '/'
    pattern = search_path / "**" / f"./{search_pattern.as_posix()}"
    for path in glob.iglob(pattern.as_posix(), recursive=True):
        if not Path(path).is_file():
            continue
        return open(path, mode="rt")
    logger.info(f"Unable to locate and open a file that matches '{pattern}'.")
    return None

277 

278 

def list_group_paths(
    database_name: str,
    protocol: str,
    base_dir: Union[PathLike[str], str, None] = None,
    subdir: Union[PathLike[str], str] = "protocols",
    database_filename: Optional[str] = None,
) -> list[Path]:
    """Returns the file paths of the groups in protocol.

    Returns
    -------
    A list of paths, one per group (archive members use the
    ``archive.tar.gz:inner/dir`` notation). Empty when the protocol is not
    found.
    """
    protocol_path = get_protocol_path(
        database_name=database_name,
        protocol=protocol,
        base_dir=base_dir,
        subdir=subdir,
        database_filename=database_filename,
    )
    # Unknown protocol: get_protocol_path already warned; nothing to list.
    if protocol_path is None:
        return []
    if archive.is_archive(protocol_path):
        groups_inner = archive.list_dirs(protocol_path)
        archive_path, inner_path = archive.path_and_subdir(protocol_path)
        return [
            Path(f"{archive_path.as_posix()}:{inner_path.as_posix()}/{g}")
            for g in groups_inner
        ]
    # Materialize the generator so the declared ``list[Path]`` return type
    # actually holds (iterdir() is lazy).
    return list(protocol_path.iterdir())

302 

303 

def list_group_names(
    database_name: str,
    protocol: str,
    base_dir: Union[PathLike[str], str, None] = None,
    subdir: Union[PathLike[str], str] = "protocols",
    database_filename: Optional[str] = None,
) -> list[str]:
    """Returns the group names of a protocol."""
    group_paths = list_group_paths(
        database_name=database_name,
        protocol=protocol,
        base_dir=base_dir,
        subdir=subdir,
        database_filename=database_filename,
    )
    # Groups may be stored as files or directories; ``stem`` drops any
    # extension. ! This means group can't include a '.'
    return [group_path.stem for group_path in group_paths]

321 

322 

def _can_reuse_local_file(
    local_file: Path,
    checksum: Union[str, None],
    checksum_fct: Callable[[Any, int], str],
    force: bool,
) -> bool:
    """Returns True when an existing local file can be kept as-is (no download)."""
    if force or not local_file.is_file():
        return False
    if checksum is None:
        logger.info(
            f"File {local_file} already exists, skipping download ({force=})."
        )
        return True
    if hashing.verify_file(local_file, checksum, checksum_fct):
        logger.info(
            f"File {local_file} already exists and checksum is valid."
        )
        return True
    # File exists but its checksum does not match: re-download it.
    return False


def _fetch_from_any(urls: Union[list[str], str]) -> "requests.Response":
    """Tries each URL in order, returning the first successful response.

    Raises
    ------
    RuntimeError
        When none of the URLs could be retrieved (connection errors or
        non-OK HTTP statuses).
    """
    if isinstance(urls, str):
        urls = [urls]

    for tries, url in enumerate(urls):
        logger.debug(f"Retrieving file from '{url}'.")
        try:
            response = requests.get(url=url, timeout=10)
        except requests.exceptions.ConnectionError as e:
            if tries < len(urls) - 1:
                logger.info(
                    f"Could not connect to {url}. Trying other URLs."
                )
            logger.debug(e)
            continue

        logger.debug(
            f"http response: '{response.status_code}: {response.reason}'."
        )

        if response.ok:
            logger.debug(f"Got file from {url}.")
            return response
        if tries < len(urls) - 1:
            logger.info(
                f"Failed to get file from {url}, trying other URLs."
            )
            logger.debug(f"requests.response was\n{response}")

    # Reached only when every URL failed; raising here guarantees the caller
    # never touches an unbound/failed response.
    raise RuntimeError(
        f"Could not retrieve file from any of the provided URLs! ({urls=})"
    )


def download_protocol_definition(
    urls: Union[list[str], str],
    destination_base_dir: Union[PathLike, None] = None,
    destination_subdir: Union[str, None] = None,
    destination_filename: Union[str, None] = None,
    checksum: Union[str, None] = None,
    checksum_fct: Callable[[Any, int], str] = hashing.sha256_hash,
    force: bool = False,
    makedirs: bool = True,
) -> Path:
    """Downloads a remote file locally.

    Parameters
    ----------
    urls
        The remote location of the server. If multiple addresses are given, we will try
        to download from them in order until one succeeds.
    destination_base_dir
        A path to a local directory where the file will be saved. If omitted, the file
        will be saved in the folder pointed by the ``bob_data_dir`` config.
    destination_subdir
        An additional layer added to the destination directory (useful when using
        ``destination_base_dir=None``).
    destination_filename
        The final name of the local file. If omitted, the file will keep the name of
        the remote file.
    checksum
        When provided, will compute the file's checksum and compare to this.
    checksum_fct
        A callable that takes a ``reader`` and returns a hash.
    force
        Re-download and overwrite any existing file with the same name.
    makedirs
        Automatically make the parent directories of the new local file.

    Returns
    -------
    The path to the new local file.

    Raises
    ------
    RuntimeError
        When the URLs provided are all invalid.
    ValueError
        When ``destination_filename`` is omitted and URLs point to files with different
        names.
        When the checksum of the file does not correspond to the provided ``checksum``.
    """

    if destination_filename is None:
        destination_filename = _infer_filename_from_urls(urls=urls)

    if destination_base_dir is None:
        destination_base_dir = _get_local_data_directory()

    destination_base_dir = Path(destination_base_dir)

    if destination_subdir is not None:
        destination_base_dir = destination_base_dir / destination_subdir

    local_file = destination_base_dir / destination_filename
    needs_download = not _can_reuse_local_file(
        local_file, checksum, checksum_fct, force
    )

    if needs_download:
        # _fetch_from_any either returns an OK response or raises, so the
        # write below never sees a failed/unbound response.
        response = _fetch_from_any(urls)

        if makedirs:
            local_file.parent.mkdir(parents=True, exist_ok=True)

        with local_file.open("wb") as f:
            f.write(response.content)

    if checksum is not None:
        if not hashing.verify_file(local_file, checksum, hash_fct=checksum_fct):
            if not needs_download:
                raise ValueError(
                    f"The local file hash does not correspond to '{checksum}' "
                    f"and {force=} prevents overwriting."
                )
            raise ValueError(
                "The downloaded file hash ('"
                f"{hashing.compute_crc(local_file, hash_fct=checksum_fct)}') does "
                f"not correspond to '{checksum}'."
            )

    return local_file