Skip to content

Instantly share code, notes, and snippets.

@ActiveTK
Last active October 12, 2025 13:37
Show Gist options
  • Select an option

  • Save ActiveTK/16aa6fee916734678a337c1bbd2373a0 to your computer and use it in GitHub Desktop.

Select an option

Save ActiveTK/16aa6fee916734678a337c1bbd2373a0 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
import math
import os
import hashlib
import shutil
import subprocess
import sys
import time
from typing import Dict, List, Tuple
# ---- Filtering thresholds ----
MIN_EXCLUSIVE_VCPUS = 16  # offers must have STRICTLY more effective vCPUs than this
MIN_NET_MBPS = 200.0  # minimum upload AND download bandwidth (Mbps)
# ---- vastai CLI query settings ----
ORDER = "+dph_total,+id"  # sort ascending by hourly price, then by offer id
PER_QUERY_LIMIT = 1000  # rows per CLI call; reaching it marks a band as saturated
CMD_TIMEOUT_SEC = 30  # hard timeout (s) for each vastai invocation
SLEEP_BETWEEN_QUERIES = 0.0  # optional politeness delay (s) between split queries
# ---- Output / ranking ----
TOP_N = 300  # number of top-ranked offers to print and save
CUM_MILESTONES = [10, 20, 30, 40, 50, 75, 100, 125, 150, 175, 200, 250, 300]  # ranks where running totals are printed
# ---- Adaptive price-band search ----
INITIAL_PRICE_MIN = 0.00  # $/hr lower bound of the initial search band
INITIAL_PRICE_MAX = 5.00  # $/hr upper bound of the initial search band
MIN_PRICE_BAND_WIDTH = 0.025  # stop splitting bands narrower than this ($/hr)
# ---- Caching ----
CACHE_TTL_SEC = 30 * 60  # cached offer lists are reused for up to 30 minutes
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
CACHE_DIR = os.path.join(SCRIPT_DIR, "cache")
os.makedirs(CACHE_DIR, exist_ok=True)  # ensure cache dir exists at import time
RESULTS_JSON = os.path.join(SCRIPT_DIR, "results.json")
# ---- Geographic exclusions (consumed by pass_geo) ----
EXCLUDED_ISO2 = {"CN", "HK"}
EXCLUDED_ISO3 = {"CHN", "HKG"}
EXCLUDED_NAMES = {"china", "hong kong", "hongkong"}
def log(msg: str):
    """Emit one diagnostic line to stderr so stdout stays clean for table output."""
    sys.stderr.write(msg + "\n")
    sys.stderr.flush()
def to_gb(val) -> float:
    """Normalize a RAM figure to GB; values above 4096 are assumed to be MB."""
    if val is None:
        return 0.0
    amount = float(val)
    return amount / 1024.0 if amount > 4096 else amount
def metric(o: dict) -> float:
    """Score an offer as effective vCPUs per dollar-hour; -inf when ineligible."""
    price = float(o.get("dph_total") or 0.0)
    cores = float(o.get("cpu_cores_effective") or 0.0)
    # Guard clauses replace the original positive test: a non-positive price or
    # too few cores disqualifies the offer outright.
    if price <= 0 or cores <= MIN_EXCLUSIVE_VCPUS:
        return -math.inf
    return cores / price
def pass_network(o: dict) -> bool:
    """True when both measured up- and down-link speeds meet MIN_NET_MBPS."""
    slowest = min(float(o.get("inet_up") or 0.0), float(o.get("inet_down") or 0.0))
    return slowest >= MIN_NET_MBPS
def pass_cpu(o: dict) -> bool:
    """True when the offer exposes strictly more effective vCPUs than the floor."""
    return MIN_EXCLUSIVE_VCPUS < float(o.get("cpu_cores_effective") or 0.0)
def pass_geo(o: dict) -> bool:
    """Reject offers located in excluded regions (CN/HK).

    Gathers every geolocation-ish field present on the offer, then applies:
      * a token match against excluded ISO-2/ISO-3 country codes, and
      * a case-insensitive substring match against excluded country names.
    Offers with no geolocation information at all are rejected (fail closed).
    """
    vals = []
    for k in ("geolocation", "country_code", "country", "geolocation_country", "geolocation_country_code"):
        v = str(o.get(k) or "").strip()
        if v:
            vals.append(v)
    if not vals:
        return False
    joined = " ".join(vals)
    # BUGFIX: geolocation strings are frequently of the form "City, CC", so a
    # plain whitespace split produced tokens like "CN," that never equaled the
    # bare code "CN", letting excluded regions slip through. Strip common
    # punctuation from each token before comparing.
    tokens = {t.strip(".,;:()[]") for t in joined.upper().split()}
    if tokens & (EXCLUDED_ISO2 | EXCLUDED_ISO3):
        return False
    s_lo = joined.lower()
    if any(name in s_lo for name in EXCLUDED_NAMES):
        return False
    return True
def fmt_num(x, digits=1):
    """Render *x* as a fixed-point string with *digits* decimal places."""
    return format(x, f".{digits}f")
def row(o: dict) -> str:
    """Format one offer as a fixed-width table line matching the printed header."""
    vcpus = float(o.get("cpu_cores_effective") or 0.0)
    price = float(o.get("dph_total") or 0.0)
    gpu_model = str(o.get("gpu_name") or "-")
    columns = (
        f"{str(o.get('id', '')):>9}",
        f"[{fmt_num(metric(o), 1):>8}]",
        f"{fmt_num(vcpus, 0):>6}",
        f"{fmt_num(to_gb(o.get('cpu_ram')), 0):>7}GB",
        f"{price:>8.3f}",
        f"{int(o.get('num_gpus') or 0):>2}",
        f"{gpu_model[:18]:<18}",
    )
    return " ".join(columns)
def _cache_key() -> str:
    """Derive a stable SHA-256 key from every setting that affects query results.

    Any configuration change therefore invalidates the cache automatically.
    """
    settings = {
        "min_vcpu_exclusive": MIN_EXCLUSIVE_VCPUS,
        "min_net": MIN_NET_MBPS,
        "order": ORDER,
        "per_query_limit": PER_QUERY_LIMIT,
        "timeout": CMD_TIMEOUT_SEC,
        "sleep": SLEEP_BETWEEN_QUERIES,
        "price_min": INITIAL_PRICE_MIN,
        "price_max": INITIAL_PRICE_MAX,
        "min_band": MIN_PRICE_BAND_WIDTH,
        "top_n": TOP_N,
        "milestones": CUM_MILESTONES,
        "geo_excluded": sorted(EXCLUDED_ISO2 | EXCLUDED_ISO3 | EXCLUDED_NAMES),
    }
    serialized = json.dumps(settings, sort_keys=True)
    return hashlib.sha256(serialized.encode("utf-8")).hexdigest()
def _cache_path() -> str:
    """Return the cache file path, unique per configuration (see _cache_key)."""
    filename = f"offers_{_cache_key()}.json"
    return os.path.join(CACHE_DIR, filename)
def load_cache() -> List[dict] | None:
    """Return cached offers when a fresh (within TTL) cache file exists, else None.

    Any read/parse problem is logged and treated as a cache miss.
    """
    path = _cache_path()
    if not os.path.exists(path):
        log(f"[CACHE] MISS path={path}")
        return None
    try:
        with open(path, "r", encoding="utf-8") as fh:
            payload = json.load(fh)
        age = time.time() - float(payload.get("ts", 0))
        if age > CACHE_TTL_SEC:
            log(f"[CACHE] STALE age={int(age)}s path={path}")
            return None
        log(f"[CACHE] HIT age={int(age)}s path={path}")
        offers = payload.get("offers", [])
        return offers if isinstance(offers, list) else None
    except Exception as e:
        log(f"[CACHE] ERROR {e}")
        return None
def save_cache(offers: List[dict]) -> None:
    """Write *offers* plus a timestamp to the cache file; failures are only logged."""
    path = _cache_path()
    tmp_path = path + ".tmp"
    try:
        payload = {"ts": time.time(), "offers": offers}
        with open(tmp_path, "w", encoding="utf-8") as fh:
            json.dump(payload, fh, ensure_ascii=False)
        # Write-then-rename so readers never observe a half-written cache file.
        os.replace(tmp_path, path)
        log(f"[CACHE] SAVED path={path} offers={len(offers)}")
    except Exception as e:
        log(f"[CACHE] SAVE ERROR: {e}")
def run_vast_query(query: str) -> List[dict]:
    """Run ``vastai search offers`` for *query* and return the parsed offer rows.

    Returns an empty list on timeout, non-zero exit, or unparsable output.
    Exits the whole process when the ``vastai`` CLI is not installed.
    """
    if not shutil.which("vastai"):
        log("ERROR: 'vastai' not found in PATH.")
        sys.exit(1)
    cmd = [
        "vastai", "search", "offers",
        "-n", query,
        "--raw",
        "--limit", str(PER_QUERY_LIMIT),
        "-o", ORDER,
    ]
    log("----------------------------------------------------------------")
    log(f"[QUERY] {query}")
    log(f"[CMD ] {' '.join(cmd)}")
    started = time.monotonic()
    try:
        proc = subprocess.run(cmd, check=True, capture_output=True, text=True, timeout=CMD_TIMEOUT_SEC)
    except subprocess.TimeoutExpired:
        log(f"[TIMEOUT] >{CMD_TIMEOUT_SEC}s ({time.monotonic() - started:.2f}s)")
        return []
    except subprocess.CalledProcessError as e:
        log(f"[ERROR] vastai failed ({time.monotonic() - started:.2f}s)\nSTDERR:\n{e.stderr.strip()}")
        return []
    elapsed = time.monotonic() - started
    try:
        data = json.loads(proc.stdout)
    except json.JSONDecodeError:
        log(f"[WARN ] JSON parse failed ({elapsed:.2f}s)\nSTDERR:\n{proc.stderr.strip()}")
        return []
    # The CLI may emit either {"offers": [...]} or a bare list, depending on version.
    if isinstance(data, dict) and "offers" in data:
        offers = data["offers"]
    elif isinstance(data, list):
        offers = data
    else:
        offers = []
    log(f"[OK ] {len(offers)} rows ({elapsed:.2f}s)")
    return offers
def build_base_tail() -> str:
    """Return the common query-filter tail shared by every price-band search."""
    parts = [
        "verified=any reliability=any external=any rented=any rentable=True",
        f"inet_up>={int(MIN_NET_MBPS)} inet_down>={int(MIN_NET_MBPS)}",
        f"cpu_cores_effective>{MIN_EXCLUSIVE_VCPUS}",
        # Server-side region exclusion; pass_geo re-checks client-side.
        "geolocation!=HK geolocation!=CN",
    ]
    return " ".join(parts)
def fetch_price_band(lo: float, hi: float) -> Tuple[List[dict], bool]:
    """Query one dph_total price band; also report whether the row limit was hit."""
    query = f"{build_base_tail()} dph_total>={lo:.3f} dph_total<={hi:.3f}"
    rows = run_vast_query(query)
    return rows, len(rows) >= PER_QUERY_LIMIT
def adaptive_collect(lo: float, hi: float, offers_by_id: Dict[int, dict]):
    """Recursively query the [lo, hi] price band, bisecting whenever a band
    saturates PER_QUERY_LIMIT, and accumulate accepted offers into the map."""
    # Bands narrower than the floor are taken as-is even when saturated,
    # guaranteeing the recursion terminates.
    if hi - lo < MIN_PRICE_BAND_WIDTH:
        rows, _ = fetch_price_band(lo, hi)
        _keep_rows(rows, offers_by_id)
        return
    rows, saturated = fetch_price_band(lo, hi)
    if not saturated:
        _keep_rows(rows, offers_by_id)
        return
    mid = (lo + hi) / 2.0
    log(f"[SPLIT] {lo:.3f}-{hi:.3f} -> {lo:.3f}-{mid:.3f} / {mid:.3f}-{hi:.3f}")
    for sub_lo, sub_hi in ((lo, mid), (mid, hi)):
        if SLEEP_BETWEEN_QUERIES > 0:
            time.sleep(SLEEP_BETWEEN_QUERIES)
        adaptive_collect(sub_lo, sub_hi, offers_by_id)
def _keep_rows(rows: List[dict], offers_by_id: Dict[int, dict]):
    """Filter *rows* through all predicates and merge survivors into the map,
    keeping the best-scoring copy per offer id."""
    accepted = 0
    for offer in rows:
        priced = (offer.get("dph_total") or 0) > 0
        # Same short-circuit order as before: network, cpu, price, geo.
        if not (pass_network(offer) and pass_cpu(offer) and priced and pass_geo(offer)):
            continue
        oid = int(offer.get("id"))
        previous = offers_by_id.get(oid)
        if previous is None or metric(offer) > metric(previous):
            offers_by_id[oid] = offer
        accepted += 1
    log(f"[KEEP ] {accepted} accepted, unique {len(offers_by_id)}")
def main():
    """Fetch (or reuse cached) offers, rank them by the vCPU-per-dollar metric,
    print a fixed-width table with cumulative milestone totals, and persist the
    top rows to results.json."""
    cached = load_cache()
    if cached is not None:
        offers_list = cached
        log(f"[CACHE] USING {len(offers_list)} rows")
    else:
        # Cache miss/stale: run the adaptive band search and cache the result.
        log("======== searching ========")
        log(f"[BAND0] dph_total {INITIAL_PRICE_MIN:.2f} - {INITIAL_PRICE_MAX:.2f}")
        offers_by_id: Dict[int, dict] = {}
        adaptive_collect(INITIAL_PRICE_MIN, INITIAL_PRICE_MAX, offers_by_id)
        log("=================================")
        offers_list = list(offers_by_id.values())
        save_cache(offers_list)
    # Rank best-first by the vCPU/$ score and keep only the top slice.
    ranked = sorted(offers_list, key=metric, reverse=True)
    top = ranked[:TOP_N]
    # Column widths must match row() exactly.
    header = (
        f"{'ID':>9} "
        f"{'[SCORE]':>10} "
        f"{'vCPUs':>6} "
        f"{'RAM':>7} "
        f"{'$/hr':>8} "
        f"{'N':>2} "
        f"{'MODEL':<18}"
    )
    print(header)
    print("-" * len(header))
    # Running totals accumulated while printing, reported at milestone ranks.
    sum_cost = 0.0
    sum_vcpu = 0.0
    sum_ram_gb = 0.0
    milestones_set = set(CUM_MILESTONES)
    compact_rows = []  # trimmed-down offer dicts destined for results.json
    for idx, o in enumerate(top, 1):
        print(row(o))
        dph = float(o.get("dph_total") or 0.0)
        vcpu = float(o.get("cpu_cores_effective") or 0.0)
        ramg = to_gb(o.get("cpu_ram"))
        sum_cost += dph
        sum_vcpu += vcpu
        sum_ram_gb += ramg
        compact_rows.append({
            "id": int(o.get("id")),
            "score": metric(o),
            "vcpus": vcpu,
            "ram_gb": ramg,
            "dph_total": dph,
            "num_gpus": int(o.get("num_gpus") or 0),
            "gpu_name": o.get("gpu_name") or "-",
            "inet_up": float(o.get("inet_up") or 0.0),
            "inet_down": float(o.get("inet_down") or 0.0),
        })
        # At each configured milestone, print a dash-framed cumulative TOTAL row.
        if idx in milestones_set:
            print("-" * len(header))
            print(
                f"{('TOTAL('+str(idx)+')'):>9} "
                f"{'':>10} "
                f"{fmt_num(sum_vcpu, 0):>6} "
                f"{fmt_num(sum_ram_gb, 0):>7}GB "
                f"{sum_cost:>8.3f} "
                f"{'':>2} "
                f"{''}"
            )
            print("-" * len(header))
    # Best-effort save of the compact rows; a failure must not crash the report.
    try:
        with open(RESULTS_JSON, "w", encoding="utf-8") as f:
            json.dump(compact_rows, f, ensure_ascii=False, indent=2)
        log(f"[SAVE ] results.json -> {RESULTS_JSON} ({len(compact_rows)} rows)")
    except Exception as e:
        log(f"[ERROR] results.json save failed: {e}")
    if not top:
        log("[INFO] No candidates. Adjust price range or bandwidth thresholds.")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment