Ollama command line interface with Markdown rendering
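The script below is a small command line client for a locally running Ollama server. It posts the prompt to the /api/generate endpoint, reads the streamed response, and renders it live in the terminal as Markdown using the rich library; it also supports running the same prompt against several models, a --raw mode that writes plain text to stdout, and a --stats summary of the generation. The endpoint streams newline-delimited JSON, one object per generated chunk; illustrative (not verbatim) lines look roughly like this:

    {"model": "llama3", "created_at": "2025-01-28T10:00:00Z", "response": "Hello", "done": false}
    {"model": "llama3", "created_at": "2025-01-28T10:00:05Z", "response": "", "done": true, "total_duration": 5000000000, "load_duration": 400000000, "prompt_eval_count": 12, "prompt_eval_duration": 200000000, "eval_count": 48, "eval_duration": 4000000000}

Each intermediate object carries a "response" fragment; the final object has "done": true and the timing counters (in nanoseconds) that the --stats table reports.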
#!/usr/bin/env python3
import argparse
import io
import json
import os
import sys
import urllib.parse
import urllib.request
from typing import Optional, NamedTuple

import rich.console
import rich.live
import rich.markdown
import rich.panel
import rich.spinner
import rich.text
import rich.table
import rich.pretty

# Ollama endpoint; OLLAMA_HOST can point at a non-default server.
query_url = urllib.parse.urljoin(
    os.environ.get("OLLAMA_HOST", "http://localhost:11434"), "/api/generate"
)


class ExecutionParams(NamedTuple):
    raw: bool
    stats: bool
    seed: Optional[int] = None
    temperature: Optional[float] = None


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        "--model",
        default="llama3",
        help="The name of the models to run. Multiple models can be specified separated by commas",
    )
    parser.add_argument(
        "--raw",
        action="store_true",
        help="Do not format the results with markdown. Output the results to stdout",
    )
    parser.add_argument(
        "--stats",
        action="store_true",
        help="Output some stats about the generation at the end",
    )
    parser.add_argument(
        "--seed",
        type=int,
        help="The seed to use in the model evaluation. "
        "For deterministic output you must also set temperature=0.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        help="The temperature to use in the model evaluation. A value of 0 means the model"
        " will always pick the highest probability token. Values above 0 add some"
        " randomness to token selection and generally make the model more creative."
        " For deterministic output, set this to 0 and set a seed to a non-zero constant."
        " Default in ollama is 0.8.",
    )
    parser.add_argument("prompt", nargs="*")
    args = vars(parser.parse_args())

    raw = args["raw"]
    # The prompt comes from the positional arguments, falling back to stdin.
    if args["prompt"]:
        prompt = " ".join(args["prompt"])
    else:
        prompt = sys.stdin.read()

    # In raw mode all rich output goes to stderr so stdout carries only the model text.
    console = rich.console.Console(stderr=raw)
    if not raw:
        console.print(rich.panel.Panel.fit("[bold]Prompt"))
        console.print(rich.text.Text(prompt))

    models = args["models"].split(",")
    for model in models:
        run_ollama(
            model,
            prompt,
            console,
            ExecutionParams(
                raw=raw,
                stats=args["stats"],
                seed=args["seed"],
                temperature=args["temperature"],
            ),
        )


def run_ollama(model, prompt, console: rich.console.Console, params: ExecutionParams):
    raw = params.raw
    if not raw:
        console.print()
        console.print(rich.panel.Panel.fit("[bold]" + model))

    req_data = {
        "model": model,
        "prompt": prompt,
        "options": {},
    }
    if params.temperature is not None:
        req_data["options"]["temperature"] = params.temperature
    if params.seed is not None:
        req_data["options"]["seed"] = params.seed

    request = urllib.request.Request(
        query_url,
        data=json.dumps(req_data).encode("utf-8"),
        headers={
            "Content-Type": "application/json; charset=utf-8",
        },
    )

    output = io.StringIO()
    spinner = rich.spinner.Spinner(
        "dots", text="Loading model...", style="status.spinner", speed=1.0
    )
    stats_data: Optional[dict] = None

    with rich.live.Live(
        spinner,
        console=console,
        vertical_overflow="ellipsis",
        refresh_per_second=12.5,
        transient=raw,
    ) as live:
        response = urllib.request.urlopen(request)
        with response:
            # The endpoint streams one JSON object per line.
            response_buf = io.TextIOWrapper(response, encoding="utf-8")
            while True:
                buf = response_buf.readline()
                if not buf:
                    break
                data = json.loads(buf)
                if data["done"]:
                    # The final message carries the generation statistics.
                    stats_data = data
                    if raw:
                        sys.stdout.write("\n")
                    break
                if raw:
                    live.stop()
                    sys.stdout.write(data["response"])
                else:
                    output.write(data["response"])
                    live.update(rich.markdown.Markdown(output.getvalue()))

    if params.stats and stats_data:
        console.print()
        table = rich.table.Table("Generation Stats", box=None)
        table.add_column()
        table.add_column()

        def format_int(val):
            return format(val, "n")

        def format_float(val):
            return format(val, ".2f")

        # Durations reported by the API are in nanoseconds.
        table.add_row("[bold]Model", stats_data.get("model"))
        table.add_row("[bold]Created At", stats_data.get("created_at"))
        table.add_row(
            "[bold]Total Duration (s)",
            format_float(int(stats_data.get("total_duration")) / 10**9),
        )
        table.add_row(
            "[bold]Load Duration (s)",
            format_float(int(stats_data.get("load_duration")) / 10**9),
        )
        if "prompt_eval_count" in stats_data:
            table.add_row(
                "[bold]Input Tokens", format_int(int(stats_data.get("prompt_eval_count")))
            )
            table.add_row(
                "[bold]Prompt Evaluation Duration (s)",
                format_float(int(stats_data.get("prompt_eval_duration")) / 10**9),
            )
        table.add_row(
            "[bold]Response Tokens", format_int(int(stats_data.get("eval_count")))
        )
        table.add_row(
            "[bold]Response Evaluation Duration (s)",
            format_float(int(stats_data.get("eval_duration")) / 10**9),
        )
        table.add_row(
            "[bold]Tokens per second",
            format_float(
                int(stats_data.get("eval_count"))
                / int(stats_data.get("eval_duration"))
                * 10**9
            ),
        )
        console.print(table)


if __name__ == "__main__":
    main()
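Assuming the script is saved as, say, ollama_md.py and made executable (the filename is not part of the gist), typical invocations might look like:

    ./ollama_md.py --model llama3 --stats "Summarize the CAP theorem in one paragraph"
    cat prompt.txt | ./ollama_md.py --models llama3,mistral
    ./ollama_md.py --raw "Write release notes for v1.2" > notes.md

The first form renders the answer as live-updating Markdown and prints the generation statistics table; the second reads the prompt from stdin and runs it against each listed model in turn; the third skips Markdown rendering and writes the plain response to stdout so it can be redirected to a file. Setting the OLLAMA_HOST environment variable points the script at a non-default server.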