Last active
January 23, 2024 07:15
-
-
Save pszemraj/5bc8dcc59d99f6a8f5cd8c3f784d5b08 to your computer and use it in GitHub Desktop.
huggingface hub - download a full snapshot of a repository without using git
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
hf_hub_download.py | |
This script allows you to download a snapshot repository from the Hugging Face Hub to a local directory without needing Git or loading the model. | |
Usage: | |
python hf_hub_download.py <repo_id> [options] | |
Arguments: | |
<repo_id> Repository ID in the format "organization/repository". | |
Options: | |
--revision <str> Revision of the repository (commit/tag/branch). Default: None. | |
--cache_dir <str> Directory to store the downloaded files. Default: "~/.cache/huggingface/transformers". | |
--library_name <str> Name of the library associated with the download. Default: None. | |
--library_version <str> Version of the library associated with the download. Default: None. | |
--user_agent <str> User agent string. Default: None. | |
--ignore_files <str> List of file patterns to ignore. Default: None. | |
--use_auth_token <str> Authentication token for private repositories. Default: None. | |
""" | |
import logging | |
from pathlib import Path | |
from typing import Dict, List, Optional, Union | |
from fnmatch import fnmatch | |
from packaging import version | |
from tqdm.auto import tqdm | |
import fire | |
import huggingface_hub | |
from huggingface_hub import HfApi, HfFolder, cached_download, hf_hub_url | |
# Fallback download root used by snapshot_download() when it gets no cache_dir.
HUGGINGFACE_HUB_CACHE = Path("~").expanduser() / ".cache" / "huggingface" / "transformers"
# Default download root for the CLI entry point: a folder in the current working directory.
DEFAULT_CACHE = Path.cwd().joinpath("downloaded-models")
def setup_logging():
    """Configure root logging at INFO level and hand back this module's logger.

    Returns:
        logging.Logger: The logger registered under this module's ``__name__``.
    """
    logging.basicConfig(level=logging.INFO)
    return logging.getLogger(__name__)
def snapshot_download(
    repo_id: str,
    revision: Optional[str] = None,
    cache_dir: Optional[Union[str, Path]] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    user_agent: Union[Dict, str, None] = None,
    ignore_files: Optional[List[str]] = None,
    use_auth_token: Union[bool, str, None] = None,
) -> str:
    """Download every file of a Hub model repository into a local folder.

    Files are fetched one at a time with ``cached_download`` into
    ``<cache_dir>/<repo_id with "/" replaced by "_">``, keeping the
    repository's directory layout. No git checkout is involved.

    Args:
        repo_id: Repository ID, e.g. ``"organization/repository"``.
        revision: Branch, tag or commit passed to ``HfApi.model_info``; every
            file is then downloaded at the resolved commit sha.
        cache_dir: Root directory for downloads. Falls back to
            ``HUGGINGFACE_HUB_CACHE`` when falsy.
        library_name: Forwarded to ``cached_download``.
        library_version: Forwarded to ``cached_download``.
        user_agent: Forwarded to ``cached_download``.
        ignore_files: ``fnmatch``-style patterns; repository paths matching any
            of them are skipped. A single pattern string is also accepted.
        use_auth_token: A token string to authenticate with, ``True`` to use the
            locally saved Hub token, or ``None``/``False`` for anonymous access.
            The raw value is also forwarded to ``cached_download``.

    Returns:
        str: Path of the folder the repository was downloaded into.

    Note:
        NOTE(review): ``cached_download`` is deprecated in newer
        ``huggingface_hub`` releases (and removed in later ones) - TODO confirm
        the supported version range before upgrading the dependency.
    """
    cache_dir = Path(cache_dir) if cache_dir else HUGGINGFACE_HUB_CACHE

    # A bare string (e.g. `--ignore_files "*.bin"` parsed by fire) would otherwise
    # be iterated character by character, and its "*" would match every file.
    if isinstance(ignore_files, str):
        ignore_files = [ignore_files]
    ignore_patterns = list(ignore_files) if ignore_files else []

    # Use an explicit token string as-is; only a bare truthy flag means
    # "read the locally saved token".
    if isinstance(use_auth_token, str) and use_auth_token:
        token = use_auth_token
    elif use_auth_token:
        token = HfFolder.get_token()
    else:
        token = None

    api = HfApi()
    model_info = api.model_info(repo_id=repo_id, revision=revision, token=token)
    storage_folder = cache_dir / repo_id.replace("/", "_")

    # Work on a copy so reordering does not mutate the ModelInfo returned by the API.
    all_files = list(model_info.siblings)
    # Fetch modules.json last (ordering kept from sentence-transformers; presumably
    # so its presence signals a finished download - TODO confirm).
    modules_json_file = next(
        (file for file in all_files if file.rfilename == "modules.json"), None
    )
    if modules_json_file is not None:
        all_files.remove(modules_json_file)
        all_files.append(modules_json_file)

    # Filter before building the progress bar so its total counts real downloads.
    files_to_fetch = [
        file
        for file in all_files
        if not any(fnmatch(file.rfilename, pattern) for pattern in ignore_patterns)
    ]

    # Arguments shared by every cached_download call. `legacy_cache_layout` is only
    # added on huggingface_hub >= 0.8.1 (same gate as sentence-transformers): older
    # releases predate the keyword, so passing it unconditionally breaks them.
    download_kwargs = {
        "library_name": library_name,
        "library_version": library_version,
        "user_agent": user_agent,
        "use_auth_token": use_auth_token,
    }
    if version.parse(huggingface_hub.__version__) >= version.parse("0.8.1"):
        download_kwargs["legacy_cache_layout"] = True

    logger = setup_logging()
    # Context manager closes the progress bar even if a download raises.
    with tqdm(files_to_fetch, desc="Downloading files", unit="file") as pbar:
        for model_file in pbar:
            relative_filepath = Path(model_file.rfilename)
            # Recreate the repository's sub-directory layout before downloading.
            (storage_folder / relative_filepath.parent).mkdir(parents=True, exist_ok=True)
            url = hf_hub_url(
                repo_id, filename=model_file.rfilename, revision=model_info.sha
            )
            path = cached_download(
                url=url,
                cache_dir=storage_folder,
                force_filename=str(relative_filepath),
                **download_kwargs,
            )
            # Remove a leftover "<file>.lock" next to the downloaded file, if any.
            lock_file = Path(f"{path}.lock")
            if lock_file.exists():
                lock_file.unlink()

    logger.info("Download completed.")
    return str(storage_folder)
def main(
    repo_id: str,
    revision: Optional[str] = None,
    cache_dir: Optional[Union[str, Path]] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    user_agent: Union[Dict, str, None] = None,
    ignore_files: Optional[List[str]] = None,
    use_auth_token: Union[bool, str, None] = None,
):
    """
    Command-line entry point: download a repository snapshot and print its location.

    Thin wrapper around ``snapshot_download``, which downloads a repo to a local
    directory without needing git or loading the model through an
    ``AutoModelFor*`` class.
    **Credit to sentence-transformers**

    Args:
        repo_id (str): Repository ID in the format "organization/repository".
        revision (str, optional): Revision of the repository (commit/tag/branch). Defaults to None.
        cache_dir (Union[str, Path, None], optional): Directory to store the downloaded files.
            Defaults to ``DEFAULT_CACHE`` (``downloaded-models`` in the current working directory).
        library_name (str, optional): Name of the library associated with the download. Defaults to None.
        library_version (str, optional): Version of the library associated with the download. Defaults to None.
        user_agent (Union[Dict, str, None], optional): User agent string. Defaults to None.
        ignore_files (List[str], optional): fnmatch-style file patterns to ignore. Defaults to None.
        use_auth_token (Union[bool, str, None], optional): Authentication for private repositories.
            Defaults to None.

    Returns:
        None: the storage folder path is printed, not returned.
    """
    # The CLI defaults to a folder in the cwd rather than snapshot_download's own fallback.
    cache_dir = Path(cache_dir) if cache_dir else DEFAULT_CACHE
    storage_folder = snapshot_download(
        repo_id=repo_id,
        revision=revision,
        cache_dir=cache_dir,
        library_name=library_name,
        library_version=library_version,
        user_agent=user_agent,
        ignore_files=ignore_files,
        use_auth_token=use_auth_token,
    )
    # Printed rather than returned: fire would additionally echo a returned value.
    print(f"Snapshot repository downloaded to: {storage_folder}")
if __name__ == "__main__":
    # Expose main()'s parameters as positional arguments / --flags via python-fire.
    fire.Fire(component=main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment