Coverage for src/bob/pipelines/dataset/protocols/retrieve.py: 66% (139 statements)
1"""Allows to find a protocol definition file locally, or download it if needed.
4Expected protocol structure:
6``base_dir / subdir / database_filename / protocol_name / group_name``
9By default, ``base_dir`` will be pointed by the ``bob_data_dir`` config.
10``subdir`` is provided as a way to use a directory inside ``base_dir`` when
11using its default.
13Here are some valid example paths (``bob_data_dir=/home/username/bob_data``):
15In a "raw" directory (not an archive):
17``/home/username/bob_data/protocols/my_db/my_protocol/my_group``
19In an archive:
21``/home/username/bob_data/protocols/my_db.tar.gz/my_protocol/my_group``
23In an archive with the database name as top-level (some legacy db used that):
25``/home/username/bob_data/protocols/my_db.tar.gz/my_db/my_protocol/my_group``
27"""

import glob

from logging import getLogger
from os import PathLike
from pathlib import Path
from typing import Any, Callable, Optional, TextIO, Union

import requests

from clapper.rc import UserDefaults

from bob.pipelines.dataset.protocols import archive, hashing

logger = getLogger(__name__)

def _get_local_data_directory() -> Path:
    """Returns the local directory for data (``bob_data_dir`` config)."""
    user_config = UserDefaults("bobrc.toml")
    return Path(
        user_config.get("bob_data_dir", default=Path.home() / "bob_data")
    )

def _infer_filename_from_urls(urls: Union[list[str], str]) -> str:
    """Retrieves the remote filename from the URLs.

    Parameters
    ----------
    urls
        One or multiple URLs pointing to files with the same name.

    Returns
    -------
    The remote file name.

    Raises
    ------
    ValueError
        When the URLs point to files with different names.
    """
    if isinstance(urls, str):
        return urls.split("/")[-1]

    # Check that all URLs point to the same file name.
    names = [u.split("/")[-1] for u in urls]
    if not all(n == names[0] for n in names):
        raise ValueError(
            f"Cannot infer file name when urls point to different files ({names=})."
        )
    return names[0]
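
# Illustration of the inference rule above (hypothetical URLs):
#
#   >>> _infer_filename_from_urls("https://example.com/a/defs.tar.gz")
#   'defs.tar.gz'
#   >>> _infer_filename_from_urls(
#   ...     ["https://a.example.com/defs.tar.gz", "https://b.example.com/defs.tar.gz"]
#   ... )
#   'defs.tar.gz'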

def retrieve_protocols(
    urls: list[str],
    destination_filename: Optional[str] = None,
    base_dir: Union[PathLike[str], str, None] = None,
    subdir: Union[PathLike[str], str] = "protocols",
    checksum: Union[str, None] = None,
) -> Path:
    """Automatically downloads the necessary protocol definition files."""
    if base_dir is None:
        base_dir = _get_local_data_directory()

    remote_filename = _infer_filename_from_urls(urls)
    if destination_filename is None:
        destination_filename = remote_filename
    elif Path(remote_filename).suffixes != Path(destination_filename).suffixes:
        raise ValueError(
            "Local dataset protocol definition files must have the same "
            f"extension as the remote ones ({remote_filename=})."
        )

    return download_protocol_definition(
        urls=urls,
        destination_base_dir=base_dir,
        destination_subdir=subdir,
        destination_filename=destination_filename,
        checksum=checksum,
        force=False,
    )
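
# A minimal call sketch (the URL and checksum are placeholders, not real
# values; pass the checksum published for your database):
#
#   >>> retrieve_protocols(
#   ...     urls=["https://example.com/protocols/my_db.tar.gz"],
#   ...     checksum="aaaa1111",  # hypothetical sha256 value
#   ... )
#   PosixPath('/home/username/bob_data/protocols/my_db.tar.gz')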

def list_protocol_paths(
    database_name: str,
    base_dir: Union[PathLike[str], str, None] = None,
    subdir: Union[PathLike[str], str] = "protocols",
    database_filename: Union[str, None] = None,
) -> list[Path]:
    """Returns the paths of each protocol in a database definition file."""
    if base_dir is None:
        base_dir = _get_local_data_directory()
    final_dir = Path(base_dir) / subdir
    final_dir /= (
        database_name if database_filename is None else database_filename
    )

    if archive.is_archive(final_dir):
        protocols = archive.list_dirs(final_dir, show_files=False)
        # Handle a database archive having database_name as top-level directory.
        if len(protocols) == 1 and protocols[0].name == database_name:
            protocols = archive.list_dirs(
                final_dir, database_name, show_files=False
            )

        archive_path, inner_dir = archive.path_and_subdir(final_dir)
        if inner_dir is None:
            return [
                Path(f"{archive_path.as_posix()}:{p.as_posix().lstrip('/')}")
                for p in protocols
            ]
        return [
            Path(f"{archive_path.as_posix()}:{inner_dir.as_posix()}/{p.name}")
            for p in protocols
        ]

    # Not an archive
    return list(final_dir.iterdir())
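
# For archives, the returned paths join the archive file and the member
# directory with a colon. A sketch with a hypothetical one-protocol archive:
#
#   >>> list_protocol_paths("my_db", database_filename="my_db.tar.gz")
#   [PosixPath('/home/username/bob_data/protocols/my_db.tar.gz:my_protocol')]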

def get_protocol_path(
    database_name: str,
    protocol: str,
    base_dir: Union[PathLike[str], str, None] = None,
    subdir: Union[PathLike[str], str] = "protocols",
    database_filename: Optional[str] = None,
) -> Union[Path, None]:
    """Returns the path of a specific protocol.

    Will look for ``protocol`` in ``base_dir / subdir / database_(file)name``.

    Returns
    -------
    Path
        The required protocol's path for the database, or ``None`` when the
        protocol cannot be found.
    """
    protocol_paths = list_protocol_paths(
        database_name=database_name,
        base_dir=base_dir,
        subdir=subdir,
        database_filename=database_filename,
    )
    for protocol_path in protocol_paths:
        if archive.is_archive(protocol_path):
            _base, inner = archive.path_and_subdir(protocol_path)
            if inner.name == protocol:
                return protocol_path
        elif protocol_path.name == protocol:
            return protocol_path
    logger.warning(f"Protocol {protocol} not found in {database_name}.")
    return None
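
# Example lookup, assuming the hypothetical raw (non-archive) layout from the
# module docstring:
#
#   >>> get_protocol_path("my_db", "my_protocol")
#   PosixPath('/home/username/bob_data/protocols/my_db/my_protocol')
#   >>> get_protocol_path("my_db", "missing") is None  # logs a warning
#   True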

def list_protocol_names(
    database_name: str,
    base_dir: Union[PathLike[str], str, None] = None,
    subdir: Union[PathLike[str], str, None] = "protocols",
    database_filename: Union[str, None] = None,
) -> list[str]:
    """Returns the names of the protocols for a given database.

    Archives are also accepted, either when the file name is the same as
    ``database_name`` with a ``.tar.gz`` extension or by specifying the file
    name in ``database_filename``.

    This will look in ``base_dir/subdir`` for ``database_filename``, then
    ``database_name``, then ``database_name + ".tar.gz"``.

    Parameters
    ----------
    database_name
        The database name, used to infer ``database_filename`` if not specified.
    base_dir
        The base path of data files (defaults to the ``bob_data_dir`` config, or
        ``~/bob_data`` if not configured).
    subdir
        A subdirectory for the protocols in ``base_dir``.
    database_filename
        If the file/directory name of the protocols is not the same as the
        name of the database, this can be set to look in the correct file.

    Returns
    -------
    A list of protocol names
        The different protocols available for that database.
    """
    if base_dir is None:
        base_dir = _get_local_data_directory()

    if subdir is None:
        subdir = "."

    if database_filename is None:
        database_filename = database_name
        final_path: Path = Path(base_dir) / subdir / database_filename
        if not final_path.is_dir():
            database_filename = database_name + ".tar.gz"

    final_path: Path = Path(base_dir) / subdir / database_filename

    if archive.is_archive(final_path):
        top_level_dirs = archive.list_dirs(final_path, show_files=False)
        # Handle a database archive having database_name as top-level directory.
        if len(top_level_dirs) == 1 and top_level_dirs[0].name == database_name:
            return [
                p.name
                for p in archive.list_dirs(
                    final_path, inner_dir=database_name, show_files=False
                )
            ]
        return [p.name for p in top_level_dirs]

    # Not an archive: list the directories.
    return [p.name for p in final_path.iterdir() if p.is_dir()]
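
# With the same hypothetical ``my_db.tar.gz`` archive, both lookup spellings
# resolve to the same protocol list thanks to the ``.tar.gz`` fallback:
#
#   >>> list_protocol_names("my_db")
#   ['my_protocol']
#   >>> list_protocol_names("my_db", database_filename="my_db.tar.gz")
#   ['my_protocol']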

def open_definition_file(
    search_pattern: Union[PathLike[str], str],
    database_name: str,
    protocol: str,
    base_dir: Union[PathLike[str], str, None] = None,
    subdir: Union[PathLike[str], str, None] = "protocols",
    database_filename: Optional[str] = None,
) -> Union[TextIO, None]:
    """Opens a protocol definition file inside a protocol directory.

    Also handles protocols inside an archive.
    """
    search_path = get_protocol_path(
        database_name, protocol, base_dir, subdir, database_filename
    )
    if search_path is None:
        return None

    if archive.is_archive(search_path):
        return archive.search_and_open(
            search_pattern=search_pattern,
            archive_path=search_path,
        )

    search_pattern = Path(search_pattern)

    # We prepend './' to search_pattern because it might start with '/'.
    pattern = search_path / "**" / f"./{search_pattern.as_posix()}"
    for path in glob.iglob(pattern.as_posix(), recursive=True):
        if not Path(path).is_file():
            continue
        return open(path, mode="rt")
    logger.info(f"Unable to locate and open a file that matches '{pattern}'.")
    return None
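
# Typical use: read a CSV definition file from the hypothetical raw layout
# above ("dev.csv" is an assumed file name, not a fixed convention):
#
#   >>> f = open_definition_file("dev.csv", "my_db", "my_protocol")
#   >>> if f is not None:
#   ...     header = f.readline()
#   ...     f.close()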

def list_group_paths(
    database_name: str,
    protocol: str,
    base_dir: Union[PathLike[str], str, None] = None,
    subdir: Union[PathLike[str], str] = "protocols",
    database_filename: Optional[str] = None,
) -> list[Path]:
    """Returns the file paths of the groups in a protocol."""
    protocol_path = get_protocol_path(
        database_name=database_name,
        protocol=protocol,
        base_dir=base_dir,
        subdir=subdir,
        database_filename=database_filename,
    )
    if protocol_path is None:
        # get_protocol_path already warned about the missing protocol.
        return []

    if archive.is_archive(protocol_path):
        groups_inner = archive.list_dirs(protocol_path)
        archive_path, inner_path = archive.path_and_subdir(protocol_path)
        return [
            Path(f"{archive_path.as_posix()}:{inner_path.as_posix()}/{g}")
            for g in groups_inner
        ]
    return list(protocol_path.iterdir())

def list_group_names(
    database_name: str,
    protocol: str,
    base_dir: Union[PathLike[str], str, None] = None,
    subdir: Union[PathLike[str], str] = "protocols",
    database_filename: Optional[str] = None,
) -> list[str]:
    """Returns the group names of a protocol."""
    paths = list_group_paths(
        database_name=database_name,
        protocol=protocol,
        base_dir=base_dir,
        subdir=subdir,
        database_filename=database_filename,
    )
    # Supports groups as files or dirs. Note: using ``stem`` means a group
    # name can't include a '.'.
    return [p.stem for p in paths]
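
# Group listing on the hypothetical raw layout, where the protocol directory
# holds group files such as "dev.csv" and "eval.csv"; ``stem`` maps them to
# the group names:
#
#   >>> list_group_names("my_db", "my_protocol")
#   ['dev', 'eval']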

def download_protocol_definition(
    urls: Union[list[str], str],
    destination_base_dir: Union[PathLike, None] = None,
    destination_subdir: Union[str, None] = None,
    destination_filename: Union[str, None] = None,
    checksum: Union[str, None] = None,
    checksum_fct: Callable[[Any, int], str] = hashing.sha256_hash,
    force: bool = False,
    makedirs: bool = True,
) -> Path:
    """Downloads a remote file locally.

    Parameters
    ----------
    urls
        The remote location of the server. If multiple addresses are given,
        they will be tried in order until one succeeds.
    destination_base_dir
        A path to a local directory where the file will be saved. If omitted,
        the file will be saved in the folder pointed to by the
        ``bob_data_dir`` config.
    destination_subdir
        An additional layer added to the destination directory (useful when
        using ``destination_base_dir=None``).
    destination_filename
        The final name of the local file. If omitted, the file will keep the
        name of the remote file.
    checksum
        When provided, the file's checksum will be computed and compared to
        this value.
    checksum_fct
        A callable that takes a ``reader`` and returns a hash.
    force
        Re-download and overwrite any existing file with the same name.
    makedirs
        Automatically create the parent directories of the new local file.

    Returns
    -------
    The path to the new local file.

    Raises
    ------
    RuntimeError
        When the URLs provided are all invalid.
    ValueError
        When ``destination_filename`` is omitted and the URLs point to files
        with different names, or when the checksum of the file does not
        correspond to the provided ``checksum``.
    """
    if destination_filename is None:
        destination_filename = _infer_filename_from_urls(urls=urls)

    if destination_base_dir is None:
        destination_base_dir = _get_local_data_directory()

    destination_base_dir = Path(destination_base_dir)

    if destination_subdir is not None:
        destination_base_dir = destination_base_dir / destination_subdir

    local_file = destination_base_dir / destination_filename
    needs_download = True

    if not force and local_file.is_file():
        if checksum is None:
            logger.info(
                f"File {local_file} already exists, skipping download ({force=})."
            )
            needs_download = False
        elif hashing.verify_file(local_file, checksum, checksum_fct):
            logger.info(
                f"File {local_file} already exists and checksum is valid."
            )
            needs_download = False

    if needs_download:
        if isinstance(urls, str):
            urls = [urls]

        for tries, url in enumerate(urls):
            logger.debug(f"Retrieving file from '{url}'.")
            try:
                response = requests.get(url=url, timeout=10)
            except requests.exceptions.ConnectionError as e:
                if tries < len(urls) - 1:
                    logger.info(
                        f"Could not connect to {url}. Trying other URLs."
                    )
                logger.debug(e)
                continue

            logger.debug(
                f"http response: '{response.status_code}: {response.reason}'."
            )

            if response.ok:
                logger.debug(f"Got file from {url}.")
                break
            if tries < len(urls) - 1:
                logger.info(
                    f"Failed to get file from {url}, trying other URLs."
                )
                logger.debug(f"requests.response was\n{response}")
        else:
            raise RuntimeError(
                f"Could not retrieve file from any of the provided URLs! ({urls=})"
            )

        if makedirs:
            local_file.parent.mkdir(parents=True, exist_ok=True)

        with local_file.open("wb") as f:
            f.write(response.content)

    if checksum is not None:
        if not hashing.verify_file(local_file, checksum, hash_fct=checksum_fct):
            if not needs_download:
                raise ValueError(
                    f"The local file hash does not correspond to '{checksum}' "
                    f"and {force=} prevents overwriting."
                )
            raise ValueError(
                "The downloaded file hash ('"
                f"{hashing.compute_crc(local_file, hash_fct=checksum_fct)}') does "
                f"not correspond to '{checksum}'."
            )

    return local_file
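
# Direct download sketch (all values are placeholders): mirrors are tried in
# order, and the checksum is verified when given.
#
#   >>> download_protocol_definition(
#   ...     urls=[
#   ...         "https://mirror1.example.com/defs/my_db.tar.gz",
#   ...         "https://mirror2.example.com/defs/my_db.tar.gz",
#   ...     ],
#   ...     destination_subdir="protocols",
#   ...     checksum="aaaa1111",  # hypothetical sha256 value
#   ... )
#   PosixPath('/home/username/bob_data/protocols/my_db.tar.gz')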