Skip to content

Instantly share code, notes, and snippets.

@luuil
Last active March 20, 2025 11:18
Show Gist options
  • Save luuil/d1ce5b7c4cc6ceb8c39095909a4b919c to your computer and use it in GitHub Desktop.
Save luuil/d1ce5b7c4cc6ceb8c39095909a4b919c to your computer and use it in GitHub Desktop.
hfdown: download a single file or a whole repo from a Hugging Face URL
"""
# @ Author: Lu Liu
# @ Create Time: 2024-12-05 14:28:43
# @ Modified by: Lu Liu
# @ Modified time: 2025-03-20 18:56:59
# @ Description: Download from Hugging Face either
# - a whole repo, or
# - a single file.
# Huya Haicong platform external-network download accelerator: https://ai.huya.com/docs/QA/common.html#%E9%80%9A%E7%94%A8%E5%A4%96%E7%BD%91%E4%B8%8B%E8%BD%BD%E5%8A%A0%E9%80%9F%E5%99%A8
# To use it, prefix the command with `hai run`, e.g. `hai run git clone xx`.
pip3 install hai --no-cache-dir -U -i https://pypi.huya.info/simple/
pip install "huggingface_hub[hf_transfer]"
"""
import argparse
import importlib.util
import os
import re
import unittest
from enum import IntEnum

# huggingface_hub reads its HF_* environment variables into
# huggingface_hub.constants when it is first imported, so they must be set
# BEFORE the import below; assigning them after the import has no effect.
#
# Uncomment to route downloads through a mirror endpoint:
# os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
#
# Faster downloads via the optional `hf_transfer` backend:
# https://huggingface.co/docs/huggingface_hub/guides/download#faster-downloads
# Enabled only when the package is importable (with the flag on and the
# package missing, huggingface_hub errors out instead of downloading), and via
# setdefault so an explicit HF_HUB_ENABLE_HF_TRANSFER=0 from the caller wins.
if importlib.util.find_spec("hf_transfer") is not None:
    os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")

from huggingface_hub import hf_hub_download, snapshot_download  # noqa: E402
def parse_url(url):
    """Split a Hugging Face web URL into its repo / revision / file parts.

    Recognised shapes (``http`` or ``https``, optional trailing slash; any
    ``?query`` or ``#fragment`` such as ``?download=true`` is ignored)::

        https://huggingface.co/<owner>/<repo>
        https://huggingface.co/<owner>/<repo>/tree/<revision>
        https://huggingface.co/<owner>/<repo>/blob/<revision>/<path/in/repo>
        https://huggingface.co/<owner>/<repo>/resolve/<revision>/<path/in/repo>

    Each shape may carry a ``datasets/`` or ``spaces/`` segment right after
    the host.

    Args:
        url: URL copied from the Hugging Face website.

    Returns:
        dict with keys ``repo_id``, ``repo_type``, ``revision``, ``filename``
        (in that order). ``repo_type`` is the URL segment as written
        (``"datasets"`` / ``"spaces"``), or None for model repos. Parts absent
        from the URL are None. When the URL is not recognised, every value
        is None.
    """
    pattern = r"""
        ^https?://huggingface\.co/
        (?:(?P<repo_type>datasets|spaces)/)?          # optional repo-type prefix
        (?P<repo_id>[^/?\#]+/[^/?\#]+)                # required owner/name
        (?:
            /(?:blob|resolve)/(?P<revision>[^/?\#]+)  # file page or direct link,
            (?:/(?P<filename>[^?\#]*))?               # then the path inside the repo
          |
            /tree/(?P<tree_revision>[^/?\#]+)         # whole repo at a revision
        )?
        /?                                            # optional trailing slash
        (?:[?\#].*)?                                  # ignore ?query / #fragment
        $
    """
    match = re.match(pattern, url, re.VERBOSE)
    if match is None:
        return {"repo_id": None, "repo_type": None, "revision": None, "filename": None}
    parts = match.groupdict()
    # `or None` folds empty captures (e.g. ".../blob/main/") into None.
    return {
        "repo_id": parts["repo_id"],
        "repo_type": parts["repo_type"] or None,
        "revision": parts["revision"] or parts["tree_revision"] or None,
        "filename": parts["filename"] or None,
    }
def download_repo(local_dir, repo_id, repo_type, revision):
    """Download a full snapshot of a Hub repository into ``local_dir``.

    Thin wrapper over ``huggingface_hub.snapshot_download``; every argument is
    forwarded unchanged as a keyword. Returns None (the path returned by the
    hub call is not propagated).
    """
    hub_kwargs = dict(
        repo_id=repo_id,
        repo_type=repo_type,
        revision=revision,
        local_dir=local_dir,
    )
    snapshot_download(**hub_kwargs)
def download_file(local_dir, repo_id, repo_type, revision, filename):
    """Download one file of a Hub repository into ``local_dir``.

    Thin wrapper over ``huggingface_hub.hf_hub_download``; every argument is
    forwarded unchanged as a keyword. ``filename`` is the path inside the
    repo and may contain ``/``. Returns None (the path returned by the hub
    call is not propagated).
    """
    request = {
        "repo_id": repo_id,
        "repo_type": repo_type,
        "revision": revision,
        "filename": filename,
        "local_dir": local_dir,
    }
    hf_hub_download(**request)
class URLType(IntEnum):
    """What a parsed Hugging Face URL points at: a whole repo or a single file."""

    # Tuple-unpacking numbers the members in declaration order: REPO=0, FILE=1.
    REPO, FILE = range(2)
def main():
    """Command-line entry point: download a repo or one file from a Hub URL.

    Reads a positional ``url`` and an optional ``--local_dir`` from the command
    line and prints a summary of what will be fetched. It then calls
    ``download_repo`` when the URL has no file path, and ``download_file``
    otherwise.

    Target directory: when ``--local_dir`` equals its default ``"."``, files go
    to ``./<owner>/<repo>``; any other value is used verbatim.

    Exits through ``argparse`` (usage message, status 2) when the URL is not a
    recognised Hugging Face URL.
    """
    model_dir = "."
    parser = argparse.ArgumentParser(description="Download from Hugging Face")
    parser.add_argument("url", type=str, help="Url of repo or file to be download.")
    parser.add_argument("--local_dir", type=str, default=model_dir, help="Local directory to save files")
    args = parser.parse_args()

    # Look fields up by key instead of unpacking .values(), which silently
    # depends on the insertion order of parse_url's dict.
    parsed = parse_url(args.url)
    repo_id = parsed["repo_id"]
    repo_type = parsed["repo_type"]
    revision = parsed["revision"]
    filename = parsed["filename"]
    # parser.error rather than `assert`: asserts are stripped under `python -O`.
    if repo_id is None:
        parser.error(f"Only url from huggingface: {args.url}")

    # URLs use the plural segments "datasets"/"spaces", but huggingface_hub
    # only accepts repo_type in {None, "model", "dataset", "space"} and
    # rejects anything else with "Invalid repo type".
    hub_repo_type = {"datasets": "dataset", "spaces": "space"}.get(repo_type, repo_type)

    utype = URLType.REPO if filename is None else URLType.FILE
    local_dir = os.path.join(args.local_dir, repo_id) if args.local_dir == model_dir else args.local_dir
    local_path = local_dir if filename is None else os.path.join(local_dir, filename)

    print(
        f"downloading {utype.name}:",
        f'\t{"repo_id":<15}{repo_id}',
        f'\t{"repo_type":<15}{hub_repo_type}',
        f'\t{"revision":<15}{revision}',
        f'\t{"filename":<15}{filename}',
        f'\t{"local_path":<15}{local_path}',
        sep="\n",
    )
    if utype == URLType.REPO:
        download_repo(local_dir, repo_id, hub_repo_type, revision)
    else:
        download_file(local_dir, repo_id, hub_repo_type, revision, filename)
class TestURLParser(unittest.TestCase):
    """Table-driven checks for ``parse_url``.

    Each ``test_cases`` entry is ``(case_name, url, expected_dict)``. Case
    names are Chinese labels, roughly: basic repo / file in repo / deeper
    path / dataset repo / file in dataset / Space repo / file in Space.
    """

    test_cases = [
        (
            "基础仓库",
            "https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503",
            {
                "repo_id": "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
                "repo_type": None,
                "revision": None,
                "filename": None,
            },
        ),
        (
            "基础仓库下文件",
            "https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/figures/benchmark.jpg",
            {
                "repo_id": "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
                "repo_type": None,
                "revision": "main",
                "filename": "figures/benchmark.jpg",
            },
        ),
        (
            "基础仓库下的更深路径文件",
            "https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged/blob/main/split_files/text_encoders/llava_llama3_fp8_scaled.safetensors",
            {
                "repo_id": "Comfy-Org/HunyuanVideo_repackaged",
                "repo_type": None,
                "revision": "main",
                "filename": "split_files/text_encoders/llava_llama3_fp8_scaled.safetensors",
            },
        ),
        (
            "数据集仓库",
            "https://huggingface.co/datasets/nvidia/Llama-Nemotron-Post-Training-Dataset-v1",
            {
                "repo_id": "nvidia/Llama-Nemotron-Post-Training-Dataset-v1",
                "repo_type": "datasets",
                "revision": None,
                "filename": None,
            },
        ),
        (
            "数据集下的文件",
            "https://huggingface.co/datasets/nvidia/Llama-Nemotron-Post-Training-Dataset-v1/blob/main/figures/benchmark.jpg",
            {
                "repo_id": "nvidia/Llama-Nemotron-Post-Training-Dataset-v1",
                "repo_type": "datasets",
                "revision": "main",
                "filename": "figures/benchmark.jpg",
            },
        ),
        (
            "Spaces仓库",
            "https://huggingface.co/spaces/smolagents/smolagents-leaderboard",
            {
                "repo_id": "smolagents/smolagents-leaderboard",
                "repo_type": "spaces",
                "revision": None,
                "filename": None,
            },
        ),
        (
            "Spaces下的文件",
            "https://huggingface.co/spaces/smolagents/smolagents-leaderboard/blob/main/frontend/src/logo.svg",
            {
                "repo_id": "smolagents/smolagents-leaderboard",
                "repo_type": "spaces",
                "revision": "main",
                "filename": "frontend/src/logo.svg",
            },
        ),
    ]

    def test_all_cases(self):
        """Run each table entry as its own subTest and print a coloured verdict."""
        for entry in self.test_cases:
            name, link, want = entry
            with self.subTest(case_name=name):
                got = parse_url(link)
                try:
                    self.assertEqual(got, want)
                except AssertionError:
                    # Dump URL / expected ("期望") / actual ("实际"), then fail.
                    print(f"\033[31m× FAIL: {name}\033[0m")
                    print(f"URL: {link}")
                    print(f"期望: {want}")
                    print(f"实际: {got}\n")
                    raise
                else:
                    print(f"\033[32m✓ PASS: {name}\033[0m")
if __name__ == "__main__":
    # Example run. On Huya's Haicong platform, prefix the command with
    # `hai run2` to use the download accelerator (the module header says
    # `hai run`; NOTE(review): confirm which command is current):
    #   python3 hf_fast_download.py https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev/blob/main/flux1-canny-dev.safetensors
    # To run the parse_url self-tests instead, call unittest.main() here.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment