Last active
March 20, 2025 11:18
-
-
Save luuil/d1ce5b7c4cc6ceb8c39095909a4b919c to your computer and use it in GitHub Desktop.
hfdown: Download a single file or an entire repo from a Hugging Face URL
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
# @ Author: Lu Liu | |
# @ Create Time: 2024-12-05 14:28:43 | |
# @ Modified by: Lu Liu | |
# @ Modified time: 2025-03-20 18:56:59 | |
# @ Description: 从huggingface下载 | |
# - 整个repo | |
# - 单个文件 | |
# huya海聪平台外网加速器 https://ai.huya.com/docs/QA/common.html#%E9%80%9A%E7%94%A8%E5%A4%96%E7%BD%91%E4%B8%8B%E8%BD%BD%E5%8A%A0%E9%80%9F%E5%99%A8 | |
# 使用时命令前加上 `hai run`, 如 hai run git clone xx | |
pip3 install hai --no-cache-dir -U -i https://pypi.huya.info/simple/ | |
pip install huggingface_hub[hf_transfer] | |
""" | |
import argparse
import importlib.util
import os
import re
import unittest
from enum import IntEnum

# Alternative endpoint (mirror). Like the flag below, it only takes effect when
# set before huggingface_hub is imported:
# os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

# Faster downloads via hf_transfer:
# https://huggingface.co/docs/huggingface_hub/guides/download#faster-downloads
# huggingface_hub copies its HF_* environment variables into module-level
# constants when it is first imported, so this flag must be set *before* the
# import below; setting it after the import has no effect.
# Enable it only when hf_transfer is installed: with the flag on and the package
# missing, huggingface_hub raises an error at download time. setdefault() keeps
# an explicit value from the user's environment (e.g. "0") untouched.
if importlib.util.find_spec("hf_transfer") is not None:
    os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

from huggingface_hub import hf_hub_download, snapshot_download  # noqa: E402
# Matches Hugging Face web URLs for a repo, or for a single file inside a repo.
# Compiled once at import time rather than on every parse_url() call.
_HF_URL_PATTERN = re.compile(
    r"""
    ^https?://(?:www\.)?huggingface\.co/
    (?:(?P<repo_type>datasets|spaces)/)?            # optional repo type prefix
    (?P<repo_id>[^/?\#]+/[^/?\#]+)                  # required repo id (owner/name)
    (?:/(?:blob|resolve)/(?P<revision>[^/?\#]+)     # optional blob/resolve revision ...
        (?:/(?P<filename>[^?\#]*))?)?               # ... with an optional file path
    /?                                              # optional trailing slash
    (?:[?\#].*)?$                                   # ignore query string / fragment
    """,
    re.VERBOSE,
)


def parse_url(url):
    """Split a Hugging Face URL into the parts needed to download it.

    Accepted forms (``http`` or ``https``, optional ``www.``, optional trailing
    slash; any query string or fragment is ignored):

    - ``https://huggingface.co/<owner>/<name>``            (model repo)
    - ``https://huggingface.co/datasets/<owner>/<name>``   (dataset repo)
    - ``https://huggingface.co/spaces/<owner>/<name>``     (space repo)
    - any of the above followed by ``/blob/<revision>/<path>`` or
      ``/resolve/<revision>/<path>`` to address a single file.

    Args:
        url: URL copied from the Hugging Face website; surrounding whitespace
            is ignored.

    Returns:
        A dict with keys ``repo_id``, ``repo_type``, ``revision`` and
        ``filename``, in that order. Missing parts are ``None``; every value is
        ``None`` when *url* is not a recognised Hugging Face URL.
        ``repo_type`` is the URL prefix exactly as written (``"datasets"`` /
        ``"spaces"``). ``revision`` and ``filename`` are percent-decoded, so
        ``refs%2Fpr%2F1`` becomes ``refs/pr/1``.
    """
    from urllib.parse import unquote  # local import keeps this block self-contained

    match = _HF_URL_PATTERN.match(url.strip())
    if match is None:
        return {"repo_id": None, "repo_type": None, "revision": None, "filename": None}

    revision = match.group("revision")
    filename = match.group("filename")
    # Keep this key order stable: callers may unpack .values() positionally.
    return {
        "repo_id": match.group("repo_id"),
        "repo_type": match.group("repo_type") or None,
        "revision": unquote(revision) if revision else None,
        # An empty path (".../blob/main/") is reported as "no file".
        "filename": unquote(filename) if filename else None,
    }
# URL path prefix (as produced by parse_url) -> repo_type value accepted by huggingface_hub.
_URL_TO_HUB_REPO_TYPE = {"datasets": "dataset", "spaces": "space"}


def normalize_repo_type(repo_type):
    """Translate a URL-style repo type into the form huggingface_hub accepts.

    URLs use the plural prefixes ``datasets/`` and ``spaces/``. However,
    ``snapshot_download`` and ``hf_hub_download`` only accept ``None``,
    ``"model"``, ``"dataset"`` or ``"space"`` and reject anything else.
    Values that are already in library form, including ``None``, are
    returned unchanged.
    """
    return _URL_TO_HUB_REPO_TYPE.get(repo_type, repo_type)


def download_repo(local_dir, repo_id, repo_type, revision):
    """Download a whole repository snapshot into *local_dir*.

    Args:
        local_dir: Directory the repository files are written to.
        repo_id: ``"<owner>/<name>"`` identifier.
        repo_type: ``None``/``"model"``, ``"dataset"``/``"datasets"`` or
            ``"space"``/``"spaces"``; plural URL forms are normalised.
        revision: Branch, tag or commit, or ``None`` to let huggingface_hub
            pick its default.

    Returns:
        Whatever ``snapshot_download`` returns (per its docs, the local folder path).
    """
    return snapshot_download(
        repo_id=repo_id,
        repo_type=normalize_repo_type(repo_type),
        local_dir=local_dir,
        revision=revision,
    )


def download_file(local_dir, repo_id, repo_type, revision, filename):
    """Download a single file of a repository into *local_dir*.

    Args:
        local_dir: Base directory; *filename*'s sub-folders are recreated below it.
        repo_id: ``"<owner>/<name>"`` identifier.
        repo_type: Same accepted values as in ``download_repo``.
        revision: Branch, tag or commit, or ``None`` for the library default.
        filename: Path of the file inside the repository.

    Returns:
        Whatever ``hf_hub_download`` returns (per its docs, the local file path).
    """
    return hf_hub_download(
        repo_id=repo_id,
        repo_type=normalize_repo_type(repo_type),
        filename=filename,
        local_dir=local_dir,
        revision=revision,
    )
class URLType(IntEnum):
    """Kind of download a parsed URL asks for: a whole repository or one file in it."""

    # Unpacking range(2) gives REPO == 0 and FILE == 1, in declaration order.
    REPO, FILE = range(2)
def main(argv=None):
    """Command-line entry point: download the repo or file named by a Hugging Face URL.

    Args:
        argv: Argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``, as with ``argparse.ArgumentParser.parse_args``.

    Without ``--local_dir``, the download goes to ``./<repo_id>``. With it, files
    go directly into the given directory, including an explicit ``--local_dir .``.
    If ``parse_url`` cannot recognise the URL, argparse prints an error and exits
    with status 2.
    """
    default_dir = "."
    parser = argparse.ArgumentParser(description="Download from Hugging Face")
    parser.add_argument("url", type=str, help="Url of repo or file to be download.")
    # default=None (not ".") so an explicit "--local_dir ." can be told apart
    # from "no directory given".
    parser.add_argument("--local_dir", type=str, default=None, help="Local directory to save files")
    args = parser.parse_args(argv)

    # Read the keys by name instead of unpacking .values() positionally.
    parts = parse_url(args.url)
    repo_id = parts["repo_id"]
    repo_type = parts["repo_type"]
    revision = parts["revision"]
    filename = parts["filename"]
    if repo_id is None:
        # parser.error() rather than assert: asserts are stripped under `python -O`.
        parser.error("Only url from huggingface")

    utype = URLType.REPO if filename is None else URLType.FILE
    if args.local_dir is None:
        local_dir = os.path.join(default_dir, repo_id)
    else:
        local_dir = args.local_dir
    local_path = local_dir if filename is None else os.path.join(local_dir, filename)

    # One summary for both cases; utype.name is "REPO" or "FILE".
    print(
        f"downloading {utype.name}:",
        f'\t{"repo_id":<15}{repo_id}',
        f'\t{"repo_type":<15}{repo_type}',
        f'\t{"revision":<15}{revision}',
        f'\t{"filename":<15}{filename}',
        f'\t{"local_path":<15}{local_path}',
        sep="\n",
    )
    if utype == URLType.REPO:
        download_repo(local_dir, repo_id, repo_type, revision)
    else:
        download_file(local_dir, repo_id, repo_type, revision, filename)
def _expected(repo_id, repo_type=None, revision=None, filename=None):
    """Build the dict parse_url() should return, with keys in parse_url's order."""
    return {
        "repo_id": repo_id,
        "repo_type": repo_type,
        "revision": revision,
        "filename": filename,
    }


class TestURLParser(unittest.TestCase):
    """Table-driven tests of parse_url() for model, dataset and space URLs."""

    # Each row: (case name, URL, expected parse_url() result).
    test_cases = [
        # Plain model repo.
        (
            "基础仓库",
            "https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503",
            _expected("mistralai/Mistral-Small-3.1-24B-Instruct-2503"),
        ),
        # File inside a model repo.
        (
            "基础仓库下文件",
            "https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/figures/benchmark.jpg",
            _expected(
                "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
                revision="main",
                filename="figures/benchmark.jpg",
            ),
        ),
        # File in a deeper sub-folder of a model repo.
        (
            "基础仓库下的更深路径文件",
            "https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged/blob/main/split_files/text_encoders/llava_llama3_fp8_scaled.safetensors",
            _expected(
                "Comfy-Org/HunyuanVideo_repackaged",
                revision="main",
                filename="split_files/text_encoders/llava_llama3_fp8_scaled.safetensors",
            ),
        ),
        # Dataset repo.
        (
            "数据集仓库",
            "https://huggingface.co/datasets/nvidia/Llama-Nemotron-Post-Training-Dataset-v1",
            _expected("nvidia/Llama-Nemotron-Post-Training-Dataset-v1", repo_type="datasets"),
        ),
        # File inside a dataset repo.
        (
            "数据集下的文件",
            "https://huggingface.co/datasets/nvidia/Llama-Nemotron-Post-Training-Dataset-v1/blob/main/figures/benchmark.jpg",
            _expected(
                "nvidia/Llama-Nemotron-Post-Training-Dataset-v1",
                repo_type="datasets",
                revision="main",
                filename="figures/benchmark.jpg",
            ),
        ),
        # Space repo.
        (
            "Spaces仓库",
            "https://huggingface.co/spaces/smolagents/smolagents-leaderboard",
            _expected("smolagents/smolagents-leaderboard", repo_type="spaces"),
        ),
        # File inside a space repo.
        (
            "Spaces下的文件",
            "https://huggingface.co/spaces/smolagents/smolagents-leaderboard/blob/main/frontend/src/logo.svg",
            _expected(
                "smolagents/smolagents-leaderboard",
                repo_type="spaces",
                revision="main",
                filename="frontend/src/logo.svg",
            ),
        ),
    ]

    def test_all_cases(self):
        """Check every row as its own subTest, echoing a coloured PASS/FAIL line."""
        for name, link, want in self.test_cases:
            with self.subTest(case_name=name):
                got = parse_url(link)
                try:
                    self.assertEqual(got, want)
                except AssertionError:
                    # Show the inputs and both results before re-raising.
                    print(f"\033[31m× FAIL: {name}\033[0m")
                    print(f"URL: {link}")
                    print(f"期望: {want}")
                    print(f"实际: {got}\n")
                    raise
                else:
                    print(f"\033[32m✓ PASS: {name}\033[0m")
if __name__ == "__main__":
    # Example run. On Huya's Haicong platform, prefixing the command with
    # `hai run2` routes the download through the accelerator and speeds it up:
    #   python3 hf_fast_download.py https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev/blob/main/flux1-canny-dev.safetensors
    # NOTE(review): the module docstring says `hai run`, while this note says
    # `hai run2`; confirm which command is current.
    # To run the URL-parser tests instead of downloading, call unittest.main() here.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment