Ollama command line interface with Markdown rendering
#!/usr/bin/env python3
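"""Query one or more Ollama models from the command line and render the streamed
response as Markdown in the terminal, using the rich library.

The prompt is taken from the positional arguments, or read from stdin if none are
given. Example invocations (the filename is whatever you saved this script as, and
the model names can be any models you have pulled locally):

    ./ollama-md.py --model llama3 --stats "Explain how context windows work"
    cat notes.txt | ./ollama-md.py --models llama3,mistral

The Ollama server address is read from the OLLAMA_HOST environment variable and
defaults to http://localhost:11434.
"""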
import argparse
import io
import json
import os
import sys
import urllib.parse
import urllib.request
from typing import Optional, NamedTuple

import rich.console
import rich.live
import rich.markdown
import rich.panel
import rich.spinner
import rich.text
import rich.table
import rich.pretty
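
# The Ollama server address comes from the OLLAMA_HOST environment variable
# (default http://localhost:11434); requests go to its /api/generate endpoint.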
query_url = urllib.parse.urljoin(
    os.environ.get("OLLAMA_HOST", "http://localhost:11434"), "/api/generate"
)
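

# Per-run options that main() passes down to run_ollama().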
class ExecutionParams(NamedTuple):
    raw: bool
    stats: bool
    seed: Optional[int] = None
    temperature: Optional[float] = None


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        "--model",
        default="llama3",
        help="The model(s) to run. Multiple models can be specified, separated by commas.",
    )
    parser.add_argument(
        "--raw",
        action="store_true",
        help="Do not render the results as Markdown; write the raw output to stdout.",
    )
    parser.add_argument(
        "--stats",
        action="store_true",
        help="Print some statistics about the generation at the end.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        help="The seed to use in the model evaluation. "
        "For deterministic output you must also set --temperature=0.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        help="The temperature to use in the model evaluation. A value of 0 means the model"
        " will always pick the highest probability token. Values above 0 add some"
        " randomness to token selection and generally make the model more creative."
        " For deterministic output, set this to 0 and set a seed to a non-zero constant."
        " Default in ollama is 0.8.",
    )
    parser.add_argument("prompt", nargs="*")
    args = vars(parser.parse_args())
raw = args["raw"] | |
if args["prompt"]: | |
prompt = " ".join(args["prompt"]) | |
else: | |
prompt = sys.stdin.read() | |
console = rich.console.Console(stderr=raw) | |
if not raw: | |
console.print(rich.panel.Panel.fit("[bold]Prompt")) | |
console.print(rich.text.Text(prompt)) | |
models = args["models"].split(",") | |
for model in models: | |
run_ollama( | |
model, | |
prompt, | |
console, | |
ExecutionParams( | |
raw=raw, | |
stats=args["stats"], | |
seed=args["seed"], | |
temperature=args["temperature"], | |
), | |
) | |


def run_ollama(model, prompt, console: rich.console.Console, params: ExecutionParams):
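    """Send a single /api/generate request for one model and stream the response.

    In raw mode the text is written straight to stdout; otherwise it is rendered
    incrementally as Markdown via a rich Live display. If requested, a table of
    generation statistics is printed afterwards.
    """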
    raw = params.raw
    if not raw:
        console.print()
        console.print(rich.panel.Panel.fit("[bold]" + model))

    req_data = {
        "model": model,
        "prompt": prompt,
        "options": {},
    }
    if params.temperature is not None:
        req_data["options"]["temperature"] = params.temperature
    if params.seed is not None:
        req_data["options"]["seed"] = params.seed

    request = urllib.request.Request(
        query_url,
        data=json.dumps(req_data).encode("utf-8"),
        headers={
            "Content-Type": "application/json; charset=utf-8",
        },
    )
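
    # Ollama streams the response as one JSON object per line. In raw mode each
    # chunk is written straight to stdout; otherwise the accumulated text is
    # re-rendered as Markdown on every chunk, with a spinner shown until the
    # first chunk arrives.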
    output = io.StringIO()
    spinner = rich.spinner.Spinner(
        "dots", text="Loading model...", style="status.spinner", speed=1.0
    )
    stats_data: Optional[dict] = None
    with rich.live.Live(
        spinner,
        console=console,
        vertical_overflow="ellipsis",
        refresh_per_second=12.5,
        transient=raw,
    ) as live:
        response = urllib.request.urlopen(request)
        with response:
            response_buf = io.TextIOWrapper(response, encoding="utf-8")
            while True:
                buf = response_buf.readline()
                if not buf:
                    break
                data = json.loads(buf)
                if data["done"]:
                    stats_data = data
                    if raw:
                        sys.stdout.write("\n")
                    break
                if raw:
                    live.stop()
                    sys.stdout.write(data["response"])
                else:
                    output.write(data["response"])
                    live.update(rich.markdown.Markdown(output.getvalue()))
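
    # The final "done" message carries generation statistics. Durations are
    # reported by the API in nanoseconds, hence the division by 10**9 below.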
    if params.stats and stats_data:
        console.print()
        table = rich.table.Table("Generation Stats", box=None)
        table.add_column()
        table.add_column()
        def format_int(val):
            return format(val, "n")

        def format_float(val):
            return format(val, ".2f")
table.add_row("[bold]Model", stats_data.get("model")) | |
table.add_row("[bold]Created At", stats_data.get("created_at")) | |
table.add_row( | |
"[bold]Total Duration (s)", | |
format_float(int(stats_data.get("total_duration")) / 10**9), | |
) | |
table.add_row( | |
"[bold]Load Duration (s)", | |
format_float(int(stats_data.get("load_duration")) / 10**9), | |
) | |
if "prompt_eval_count" in stats_data: | |
table.add_row( | |
"[bold]Input Tokens", format_int(int(stats_data.get("prompt_eval_count"))) | |
) | |
table.add_row( | |
"[bold]Prompt Evaluation Duration (s)", | |
format_float(int(stats_data.get("prompt_eval_duration")) / 10**9), | |
) | |
table.add_row( | |
"[bold]Response Tokens", format_int(int(stats_data.get("eval_count"))) | |
) | |
table.add_row( | |
"[bold]Response Evaluation Duration (s)", | |
format_float(int(stats_data.get("eval_duration")) / 10**9), | |
) | |
table.add_row( | |
"[bold]Tokens per second", | |
format_float( | |
int(stats_data.get("eval_count")) | |
/ int(stats_data.get("eval_duration")) | |
* 10**9 | |
), | |
) | |
console.print(table) | |


if __name__ == "__main__":
    main()