Last active
June 10, 2024 06:03
-
-
Save ddh0/f215cd97d84741ae6d2aa6d96bac76de to your computer and use it in GitHub Desktop.
Python script to download Hugging Face repos with an optional download speed limit
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# hfget.py
# Python 3.11.2
# Hugging Face Filesystem API:
# https://huggingface.co/docs/huggingface_hub/en/guides/hf_file_system
import os
import sys
import time
from typing import Union

try:
    from huggingface_hub import HfFileSystem
except ImportError as exc:
    exc.add_note("HINT: Try `pip install --upgrade huggingface_hub`")
    raise exc
# change this to your actual Hugging Face token
HF_TOKEN = "hf_XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"

# approximate download speed cap, in megabytes (10**6 bytes) per second
# change this if you don't want to limit download speed
# use a high number like 100_000 to effectively disable the limit
APPROX_MAX_DOWNLOAD_SPEED_MEGABYTES_PER_SECOND: int = 10

# do not change these unless you know what you're doing
# number of read-then-sleep cycles the download loop aims for each second
_DL_FETCH_PER_SECOND: int = 100
# multiplier applied to the per-cycle chunk size
# NOTE(review): presumably an empirical correction for time spent inside the
# read itself (which the fixed sleep does not account for) -- confirm
_DL_COMPENSATION_FACTOR: Union[int, float] = 1.5

# speed cap converted from megabytes to bytes
_dl_bytes_per_second: int = int(APPROX_MAX_DOWNLOAD_SPEED_MEGABYTES_PER_SECOND * 10**6)

# bytes requested per read; clamped to at least 1 because a chunk size of 0
# would make every read return nothing (the download loop would never
# advance) and a negative size would read the whole file in one go
_dl_bytes_per_fetch: int = max(
    1,
    int(_DL_COMPENSATION_FACTOR * (_dl_bytes_per_second / _DL_FETCH_PER_SECOND)),
)
# ask where to put the downloaded files; surrounding whitespace is dropped
# and a leading "~" is expanded to the user's home directory
dest_folder = os.path.expanduser(
    input('Please enter the destination folder path\n > ').strip()
)

# an empty answer would otherwise turn into just the path separator,
# i.e. the filesystem root
if not dest_folder:
    raise ValueError("Destination path must not be empty")

# this check has to look at the path WITHOUT a trailing separator: a regular
# file named "x" is not reported as a file when queried as "x/"
if os.path.isfile(dest_folder.rstrip(os.sep)):
    raise NotADirectoryError(
        "Destination path is a file, but it should be a directory"
    )

# the rest of the script builds local paths by plain string concatenation,
# so the folder must end with a separator
if not dest_folder.endswith(os.sep):
    dest_folder = dest_folder + os.sep

if not os.path.exists(dest_folder):
    # makedirs also creates missing parent directories and raises OSError
    # itself on failure, so no separate existence re-check is needed
    os.makedirs(dest_folder, exist_ok=True)
    print(f"Created directory at {dest_folder!r}")
elif (num_files_in_dst := len(os.listdir(dest_folder))) > 0:
    print(
        f"Warning: Destination directory is not empty ({num_files_in_dst} files)",
        file=sys.stderr
    )
# ask for the repository; accepts either a bare "owner/name" id or a URL
repo_id = input('Please enter the Hugging Face repo_id or URL\n > ').strip()

# URL fragments to strip so that only "owner/name" remains
# every entry needs its own trailing comma: adjacent string literals without
# one are silently joined into a single string that never matches
prefixes = [
    'https://huggingface.co/',
    'huggingface.co/',
    'https://hf.co/',
    'hf.co/',
    '/',
]
suffixes = [
    '/tree/main/',
    '/tree/main',
    '/',
]

# removeprefix / removesuffix leave the string untouched when there is no
# match, so each pattern is simply applied in order
for prefix in prefixes:
    repo_id = repo_id.removeprefix(prefix)
for suffix in suffixes:
    repo_id = repo_id.removesuffix(suffix)

if not repo_id:
    raise ValueError("repo_id must not be empty")
# optional substring filter: remote paths containing it are skipped
exclude_str: str | None = input('Please enter a string. Paths containing this string will be skipped. Press ENTER to not skip any.\n > ')

# any blank answer (empty, spaces, tabs) means "skip nothing"; a non-blank
# answer is kept exactly as typed, including any surrounding whitespace
if not exclude_str.strip():
    exclude_str = None
print("Connecting to Hugging Face Filesystem API...")
HfFS = HfFileSystem(token=HF_TOKEN)

print("Getting directory contents...")
# find() walks the repository recursively and returns files only; a plain
# ls() would also return sub-directories, which cannot be opened for reading
remote_files: list[str] = HfFS.find(
    repo_id,
    detail=False
)
if not remote_files:
    print(f"Warning: No files found in {repo_id!r}", file=sys.stderr)

print("Start operation")
for remote_path in remote_files:
    if exclude_str is not None and exclude_str in remote_path:
        print(f"Skipping: {remote_path!r}")
        continue
    print(f"In progress: {remote_path!r}")

    # destination folder should have same structure as remote
    relative_path = remote_path.removeprefix(repo_id + HfFS.sep)
    dest_path = str(dest_folder + relative_path).replace(HfFS.sep, os.sep)
    # files inside remote sub-folders need their local parent folder first
    os.makedirs(os.path.dirname(dest_path) or os.curdir, exist_ok=True)

    with HfFS.open(remote_path, 'rb') as src:
        # the remote size decides whether an existing local file is complete
        if os.path.exists(dest_path):
            if os.stat(dest_path).st_size == src.size:
                print('Already downloaded (exact size match)')
                continue
            os.remove(dest_path)
            print("Removed partial download")

        with open(dest_path, 'wb') as dst:
            print("Downloading...")
            try:
                # throttled copy: read one chunk, then pause, so throughput
                # stays near the configured limit
                while src.loc < src.size:
                    chunk = src.read(_dl_bytes_per_fetch)
                    if not chunk:
                        # the remote stopped delivering data early; stop here
                        # and let the size check below report the mismatch
                        # instead of looping forever
                        break
                    dst.write(chunk)
                    time.sleep(1/_DL_FETCH_PER_SECOND)
            except EOFError:
                print("EOFError (safe to ignore)")
            except KeyboardInterrupt:
                print("\nOperation aborted due to KeyboardInterrupt", file=sys.stderr)
                sys.exit(1)

        # verify only after the local file is closed, so that every buffered
        # byte has reached the disk before its size is measured
        if (dst_size := os.stat(dest_path).st_size) != src.size:
            print(
                f"Remote file: {remote_path!r}\n"
                f"Downloaded file: {dest_path!r}\n"
                f"Remote size: {src.size} bytes\n"
                f"Downloaded size: {dst_size} bytes\n"
                "Expected downloaded file size to match remote file size after completion (corrupt download)\n"
                "Operation aborted due to size mismatch",
                file=sys.stderr
            )
            sys.exit(1)
        print("Success (exact size match)")

print("Operation completed successfully")
sys.exit(0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment