Skip to content

Commit 3ccc73a

Browse files
fineguyThe TensorFlow Datasets Authors
authored and
The TensorFlow Datasets Authors
committed
Append dataset name to download dir.
PiperOrigin-RevId: 681832795
1 parent 2638dc2 commit 3ccc73a

File tree

2 files changed

+79
-34
lines changed

2 files changed

+79
-34
lines changed

tensorflow_datasets/core/download/download_manager.py

Lines changed: 74 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -316,13 +316,24 @@ def downloaded_size(self) -> int:
316316
return sum(url_info.size for url_info in self._recorded_url_infos.values())
317317

318318
def _get_dl_path(
319-
self, resource: resource_lib.Resource, checksum: str | None = None
319+
self,
320+
resource: resource_lib.Resource,
321+
checksum: str | None = None,
322+
legacy_mode: bool = False,
320323
) -> epath.Path:
321-
return (
322-
self._download_dir
323-
/ resource.relative_download_dir
324-
/ resource_lib.get_dl_fname(resource.url, checksum)
325-
)
324+
"""Returns the path where the resource should be downloaded.
325+
326+
Args:
327+
resource: The resource to download.
328+
checksum: The checksum of the resource.
329+
legacy_mode: If True, returns path in the legacy format without dataset
330+
name in the path.
331+
"""
332+
download_dir = self._download_dir
333+
if not legacy_mode:
334+
download_dir /= self._dataset_name
335+
download_dir /= resource.relative_download_dir
336+
return download_dir / resource_lib.get_dl_fname(resource.url, checksum)
326337

327338
@property
328339
def register_checksums(self):
@@ -353,6 +364,50 @@ def _get_manually_downloaded_path(
353364

354365
return manual_path
355366

367+
def _get_checksum_dl_result(
368+
self, resource: resource_lib.Resource, legacy_mode: bool = False
369+
) -> downloader.DownloadResult | None:
370+
"""Checks if the download has been cached and checksum is known."""
371+
expected_url_info = self._url_infos.get(resource.url)
372+
373+
if not expected_url_info:
374+
return None
375+
376+
checksum_path = self._get_dl_path(
377+
resource, expected_url_info.checksum, legacy_mode=legacy_mode
378+
)
379+
if not resource_lib.is_locally_cached(checksum_path):
380+
return None
381+
382+
return downloader.DownloadResult(
383+
path=checksum_path, url_info=expected_url_info
384+
)
385+
386+
def _get_url_dl_result(
387+
self, resource: resource_lib.Resource, legacy_mode: bool = False
388+
) -> downloader.DownloadResult | None:
389+
"""Checks if the download has been cached and checksum is unknown."""
390+
url_path = self._get_dl_path(resource, legacy_mode=legacy_mode)
391+
if not resource_lib.is_locally_cached(url_path):
392+
return None
393+
394+
expected_url_info = self._url_infos.get(resource.url)
395+
url_info = downloader.read_url_info(url_path)
396+
397+
if expected_url_info and expected_url_info != url_info:
398+
# If checksums are registered but do not match, trigger a new
399+
# download (e.g. previous file corrupted, checksums updated)
400+
return None
401+
elif self._is_checksum_registered(url=resource.url):
402+
# Checksums were registered: Rename -> checksum_path
403+
path = self._get_dl_path(resource, url_info.checksum)
404+
resource_lib.replace_info_file(url_path, path)
405+
url_path.replace(path)
406+
return downloader.DownloadResult(path=path, url_info=url_info)
407+
else:
408+
# Checksums not registered: -> do nothing
409+
return downloader.DownloadResult(path=url_path, url_info=url_info)
410+
356411
# Synchronize and memoize decorators ensure same resource will only be
357412
# processed once, even if passed twice to download_manager.
358413
@utils.build_synchronize_decorator()
@@ -399,32 +454,20 @@ def _download_or_get_cache(
399454
dl_result = None
400455

401456
# Download has been cached (checksum known)
402-
elif expected_url_info and resource_lib.is_locally_cached(
403-
checksum_path := self._get_dl_path(resource, expected_url_info.checksum)
404-
):
405-
dl_result = downloader.DownloadResult(
406-
path=checksum_path, url_info=expected_url_info
407-
)
457+
elif dl_result := self._get_checksum_dl_result(resource):
458+
pass
459+
460+
# Download has been cached (checksum known, legacy mode)
461+
elif dl_result := self._get_checksum_dl_result(resource, legacy_mode=True):
462+
pass
408463

409464
# Download has been cached (checksum unknown)
410-
elif resource_lib.is_locally_cached(
411-
url_path := self._get_dl_path(resource)
412-
):
413-
url_info = downloader.read_url_info(url_path)
414-
415-
if expected_url_info and expected_url_info != url_info:
416-
# If checksums are registered but do not match, trigger a new
417-
# download (e.g. previous file corrupted, checksums updated)
418-
dl_result = None
419-
elif self._is_checksum_registered(url=url):
420-
# Checksums were registered: Rename -> checksum_path
421-
path = self._get_dl_path(resource, url_info.checksum)
422-
resource_lib.replace_info_file(url_path, path)
423-
url_path.replace(path)
424-
dl_result = downloader.DownloadResult(path=path, url_info=url_info)
425-
else:
426-
# Checksums not registered: -> do nothing
427-
dl_result = downloader.DownloadResult(path=url_path, url_info=url_info)
465+
elif dl_result := self._get_url_dl_result(resource):
466+
pass
467+
468+
# Download has been cached (checksum unknown, legacy mode)
469+
elif dl_result := self._get_url_dl_result(resource, legacy_mode=True):
470+
pass
428471

429472
# Cache not found
430473
else:
@@ -504,7 +547,7 @@ def _download(
504547
download_tmp_dir = (
505548
url_path.parent / f'{url_path.name}.tmp.{uuid.uuid4().hex}'
506549
)
507-
download_tmp_dir.mkdir()
550+
download_tmp_dir.mkdir(parents=True, exist_ok=True)
508551
logging.info(f'Downloading {url} into {download_tmp_dir}...')
509552
future = self._downloader.download(
510553
url, download_tmp_dir, verify=self._verify_ssl

tensorflow_datasets/core/download/download_manager_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
TAR = resource_lib.ExtractMethod.TAR
3434
NO_EXTRACT = resource_lib.ExtractMethod.NO_EXTRACT
3535

36+
_DATASET_NAME = 'mnist'
37+
3638
_CHECKSUMS_DIR = epath.Path('/checksums')
3739
_CHECKSUMS_PATH = _CHECKSUMS_DIR / 'checksums.tsv'
3840

@@ -60,10 +62,10 @@ def __init__(
6062
)
6163

6264
self.file_name = resource_lib.get_dl_fname(self.url, self.url_info.checksum)
63-
self.file_path = _DOWNLOAD_DIR / self.file_name
65+
self.file_path = _DOWNLOAD_DIR / _DATASET_NAME / self.file_name
6466

6567
self.url_name = resource_lib.get_dl_fname(self.url)
66-
self.url_path = _DOWNLOAD_DIR / self.url_name
68+
self.url_path = _DOWNLOAD_DIR / _DATASET_NAME / self.url_name
6769

6870
self.manual_path = _MANUAL_DIR / name
6971
extract_method = resource_lib.guess_extract_method(name)
@@ -177,7 +179,7 @@ def _get_manager(
177179
**kwargs,
178180
):
179181
manager = dm.DownloadManager(
180-
dataset_name='mnist',
182+
dataset_name=_DATASET_NAME,
181183
download_dir=dl_dir,
182184
extract_dir=extract_dir,
183185
manual_dir=manual_dir,

0 commit comments

Comments
 (0)