|
#!/usr/bin/env python |
|
# -*- coding: utf-8 -*- |
|
""" |
|
Standalone Asynchronous Nanonets-OCR-s Inference Script using vLLM and PyMuPDF. |
|
|
|
This script processes PDF files from an input directory using the |
|
nanonets/Nanonets-OCR-s model served locally by vLLM via its OpenAI-compatible API. |
|
It renders each page, sends API requests concurrently for OCR, extracts the |
|
structured markdown/HTML text, and saves the combined text for each PDF into a
corresponding .md file in the specified output directory.
|
|
|
This version uses asyncio and the AsyncOpenAI client to significantly speed up |
|
processing by sending multiple page OCR requests to the vLLM server concurrently. |
|
|
|
**IMPORTANT:** Requires a separate vLLM server running with the Nanonets-OCR-s model. |
|
Start the server BEFORE running this script, for example: |
|
|
|
vllm serve nanonets/Nanonets-OCR-s --max-num-seqs 256 --gpu-memory-utilization 0.9 |
|
|
|
Dependencies (vLLM - see vLLM docs for specific CUDA versions): |
|
pip install ninja vllm flash-attn |
|
|
|
Dependencies (Script): |
|
pip install "openai>=1.0" PyMuPDF Pillow fire tqdm pypdf "tqdm[asyncio]" joblib |
|
|
|
Example Usage: |
|
# 1. Start the vLLM server in a separate terminal: |
|
# vllm serve nanonets/Nanonets-OCR-s |
|
|
|
# 2. Run this script: |
|
python nanonets_pipeline.py \ |
|
--input_dir ./my_pdfs \ |
|
--output_dir ./output_text \ |
|
--model_id nanonets/Nanonets-OCR-s \ |
|
--max_pages 100 \ |
|
--overwrite \ |
|
--api_base_url http://localhost:8000/v1 \ |
|
--concurrency_limit 16 |
|
""" |
|
|
|
import asyncio |
|
import base64 |
|
import io |
|
import logging |
|
import os |
|
import re |
|
from pathlib import Path |
|
from typing import Any, Dict, List, Optional, Tuple |
|
|
|
import fire |
|
|
|
# REMOVED: mdformat is no longer needed as Nanonets produces structured output. |
|
# import mdformat |
|
from joblib import Parallel, delayed |
|
from PIL import Image |
|
from pypdf import PdfReader |
|
from pypdf.errors import PdfReadError |
|
from tqdm import tqdm |
|
from tqdm.asyncio import tqdm_asyncio |
|
|
|
try: |
|
from openai import APIConnectionError, APIStatusError, AsyncOpenAI, RateLimitError |
|
except ImportError: |
|
print("=" * 80) |
|
print("ERROR: openai library >= 1.0 not found.") |
|
print("Please install it: pip install 'openai>=1.0'") |
|
print("=" * 80) |
|
exit(1) |
|
|
|
try: |
|
import fitz # PyMuPDF |
|
except ImportError: |
|
print("=" * 80) |
|
print("ERROR: PyMuPDF library not found.") |
|
print("Please install it: pip install PyMuPDF") |
|
print("=" * 80) |
|
exit(1) |
|
|
|
|
|
# --- Configuration ---
# Root logging setup: timestamped records that name the emitting function.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - [%(funcName)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)

# Reduce noise from underlying libraries
for _noisy_logger_name in ("httpx", "openai", "httpcore"):
    logging.getLogger(_noisy_logger_name).setLevel(logging.WARNING)
del _noisy_logger_name

# Model identifier, OCR prompt and default request parameters for Nanonets-OCR-s.
DEFAULT_MODEL_ID: str = "nanonets/Nanonets-OCR-s"
# The prompt is split per sentence for readability; the concatenated value is
# a single line of text.
NANONETS_PROMPT: str = (
    "Extract the text from the above document as if you were reading it naturally. "
    "Return the tables in html format. "
    "Return the equations in LaTeX representation. "
    "If there is an image in the document and image caption is not present, "
    "add a small description of the image inside the <img></img> tag; "
    "otherwise, add the image caption inside <img></img>. "
    "Watermarks should be wrapped in brackets. "
    "Ex: <watermark>OFFICIAL COPY</watermark>. "
    "Page numbers should be wrapped in brackets. "
    "Ex: <page_number>14</page_number> or <page_number>9/22</page_number>. "
    "Prefer using ☐ and ☑ for check boxes."
)
DEFAULT_TARGET_IMAGE_DIM: int = 1024
DEFAULT_API_BASE_URL: str = "http://localhost:8000/v1"
DEFAULT_API_KEY: str = "EMPTY"
DEFAULT_CONCURRENCY_LIMIT: int = 16
DEFAULT_MAX_TOKENS_PER_PAGE: int = 10000
DEFAULT_TEMPERATURE: float = 0.0
DEFAULT_FREQ_PENALTY: float = 0.1
|
|
|
|
|
def render_pdf_page_to_pil_fitz(
    pdf_path: Path,
    page_num: int,
    target_longest_image_dim: int = DEFAULT_TARGET_IMAGE_DIM,
) -> Optional[Image.Image]:
    """
    Renders a single page of a PDF to a PIL Image using PyMuPDF (fitz).

    The page is rendered at zoom 1.0 unless its longest side (as reported by
    ``page.rect``) exceeds target_longest_image_dim, in which case it is
    scaled down so the longest side matches the target. Pages are never
    upscaled.

    Args:
        pdf_path: Path to the PDF file.
        page_num: The 1-based page number to render.
        target_longest_image_dim: Target size for the longest dimension.

    Returns:
        A PIL Image object of the rendered page, or None if rendering fails
        for any reason (missing file, invalid page number, degenerate page
        size, or a PyMuPDF error). Failures are logged, never raised.
    """
    # Resolve the "file not found" exception classes defensively. An except
    # clause's expression is only evaluated while an exception is already
    # propagating, so naming an attribute that the installed PyMuPDF release
    # does not provide (e.g. ``fitz.fitz``) would raise AttributeError at that
    # point and bypass the generic handler below.
    legacy_module = getattr(fitz, "fitz", None)
    missing_file_errors = tuple(
        candidate
        for candidate in (
            FileNotFoundError,
            getattr(fitz, "FileNotFoundError", None),
            getattr(legacy_module, "FileNotFoundError", None),
        )
        if isinstance(candidate, type) and issubclass(candidate, BaseException)
    )

    doc: Optional[fitz.Document] = None
    try:
        doc = fitz.open(pdf_path)
        if not 0 < page_num <= doc.page_count:
            logger.error(
                f"Invalid page number {page_num} for {pdf_path.name} "
                f"({doc.page_count} pages)."
            )
            return None

        page: fitz.Page = doc.load_page(page_num - 1)  # fitz uses 0-based index
        page_rect: fitz.Rect = page.rect
        width, height = page_rect.width, page_rect.height
        longest_side = max(width, height)

        if longest_side <= 0:
            logger.error(
                f"Invalid page dimensions ({width}x{height}) for "
                f"{pdf_path.name} page {page_num}."
            )
            return None

        # Only shrink oversized pages; smaller pages keep their native size.
        zoom_factor: float = 1.0
        if longest_side > target_longest_image_dim:
            zoom_factor = target_longest_image_dim / longest_side

        matrix: fitz.Matrix = fitz.Matrix(zoom_factor, zoom_factor)
        pix: fitz.Pixmap = page.get_pixmap(matrix=matrix, alpha=False)

        if pix.width == 0 or pix.height == 0:
            logger.error(
                f"Rendered pixmap has zero dimension for {pdf_path.name} "
                f"page {page_num}."
            )
            return None

        img: Image.Image = Image.frombytes(
            "RGB", (pix.width, pix.height), pix.samples
        )
        return img

    except missing_file_errors:
        logger.error(f"PyMuPDF could not find file: {pdf_path}")
        return None
    except Exception as e:
        logger.error(
            f"PyMuPDF error rendering {pdf_path.name} page {page_num}: "
            f"{type(e).__name__} - {e}"
        )
        return None
    finally:
        # Compare against None explicitly: relying on truthiness could skip
        # the close for a document object that evaluates as falsy (e.g. one
        # whose length is its page count and that has zero pages).
        if doc is not None:
            try:
                doc.close()
            except Exception as e:
                logger.warning(f"Error closing PDF {pdf_path.name}: {e}")
|
|
|
|
|
def get_pdf_page_count(pdf_path: Path) -> Optional[int]:
    """
    Determines how many pages a PDF has, preferring pypdf and falling back to fitz.

    fitz is consulted in two situations: when pypdf reports zero pages, and
    when pypdf raises PdfReadError.

    Args:
        pdf_path: Path to the PDF file.

    Returns:
        The page count as an integer; 0 when pypdf reports zero pages and
        fitz cannot open the file; None when the count cannot be determined
        (both libraries failed, the file is missing, or an unexpected error
        occurred).
    """
    try:
        pypdf_pages = len(PdfReader(pdf_path, strict=False).pages)
        if pypdf_pages != 0:
            return pypdf_pages

        # pypdf saw no pages: ask fitz for a second opinion.
        try:
            with fitz.open(pdf_path) as pdf_doc:
                return pdf_doc.page_count
        except Exception:
            logger.warning(
                f"pypdf reported 0 pages, fitz failed to open "
                f"{pdf_path.name}. Assuming 0 pages."
            )
            return 0
    except PdfReadError as read_err:
        logger.error(f"pypdf failed to read {pdf_path.name}: {read_err}. Trying fitz.")
        try:
            with fitz.open(pdf_path) as pdf_doc:
                return pdf_doc.page_count
        except Exception as fitz_e:
            logger.error(
                f"Both pypdf and fitz failed page count for {pdf_path.name}: {fitz_e}"
            )
            return None
    except FileNotFoundError:
        logger.error(f"File not found for page count: {pdf_path}")
        return None
    except Exception as unexpected:
        logger.error(
            f"Unexpected error getting page count for {pdf_path.name}: {unexpected}"
        )
        return None
|
|
|
|
|
def encode_pil_to_base64(image: Image.Image, format: str = "PNG") -> str:
    """
    Serializes a PIL image into an in-memory buffer and base64-encodes it.

    Args:
        image: The PIL Image object.
        format: The image format passed to ``image.save`` (e.g., "PNG", "JPEG").

    Returns:
        The base64 text (decoded as UTF-8) of the saved image bytes.
    """
    sink = io.BytesIO()
    image.save(sink, format=format)
    return base64.b64encode(sink.getvalue()).decode("utf-8")
|
|
|
|
|
async def ocr_page_api(
    client: AsyncOpenAI,
    model_id: str,
    img_base64: str,
    page_num: int,
    pdf_name: str,
    semaphore: asyncio.Semaphore,
    temperature: float = DEFAULT_TEMPERATURE,
    max_tokens: int = DEFAULT_MAX_TOKENS_PER_PAGE,
    frequency_penalty: float = DEFAULT_FREQ_PENALTY,
) -> str:
    """
    Runs OCR on one page image through the OpenAI-compatible chat API.

    The request is made while holding ``semaphore``, which caps how many
    requests are in flight at once. The image is sent as a PNG data URL
    together with NANONETS_PROMPT in a single user message.

    Args:
        client: The initialized AsyncOpenAI client.
        model_id: The model identifier for the API call.
        img_base64: The base64 encoded string of the page image.
        page_num: The 1-based page number (for logging).
        pdf_name: The name of the PDF file (for logging).
        semaphore: The asyncio.Semaphore to control concurrency.
        temperature: Sampling temperature for the model.
        max_tokens: Maximum tokens to generate for the page.
        frequency_penalty: Frequency penalty forwarded to the API call.

    Returns:
        The stripped text content of the first choice, or a bracketed marker
        string when no usable text is produced: "[API_EMPTY_RESPONSE]",
        "[API_CONNECTION_ERROR]", "[API_RATE_LIMIT_ERROR]",
        "[API_STATUS_ERROR_<code>]" or "[API_UNEXPECTED_ERROR]".
    """
    async with semaphore:  # Hold a slot for the whole request
        try:
            image_part = {
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{img_base64}"},
            }
            prompt_part = {"type": "text", "text": NANONETS_PROMPT}
            completion = await client.chat.completions.create(
                model=model_id,
                messages=[{"role": "user", "content": [image_part, prompt_part]}],
                temperature=temperature,
                max_tokens=max_tokens,
                frequency_penalty=frequency_penalty,
            )
            page_text = completion.choices[0].message.content
            if not page_text:
                return "[API_EMPTY_RESPONSE]"
            return page_text.strip()
        except APIConnectionError as conn_err:
            logger.error(
                f"API Connect Error page {page_num} ({pdf_name}): {conn_err}. "
                f"Is server at {client.base_url} running?"
            )
            return "[API_CONNECTION_ERROR]"
        except RateLimitError as rate_err:
            logger.warning(
                f"API Rate Limit Error page {page_num} ({pdf_name}): {rate_err}. "
                f"Server busy or concurrency too high? Retrying may be needed."
            )
            # Brief pause while still holding the slot; no retry is attempted.
            await asyncio.sleep(2)
            return "[API_RATE_LIMIT_ERROR]"
        except APIStatusError as status_err:
            logger.error(
                f"API Status Error page {page_num} ({pdf_name}): "
                f"Status={status_err.status_code}, Response={status_err.response}"
            )
            return f"[API_STATUS_ERROR_{status_err.status_code}]"
        except Exception as unexpected:
            logger.exception(
                f"Unexpected API Error page {page_num} ({pdf_name}): {unexpected}"
            )
            return "[API_UNEXPECTED_ERROR]"
|
|
|
|
|
def render_and_encode_single_page(
    pdf_file: Path, page_num: int, target_image_dim: int, pdf_name: str
) -> tuple:
    """
    Renders one PDF page and base64-encodes it; used as the parallel work unit.

    Args:
        pdf_file: Path to the PDF file
        page_num: Page number to render (1-based)
        target_image_dim: Target size for longest dimension
        pdf_name: Name of PDF file (for logging)

    Returns:
        tuple: (page_num, payload) where payload is the base64 string on
        success, "[PAGE_RENDER_ERROR]" if rendering produced no image, or
        "[IMAGE_ENCODE_ERROR]" if encoding raised.
    """
    rendered_page = render_pdf_page_to_pil_fitz(pdf_file, page_num, target_image_dim)
    if rendered_page:
        try:
            return page_num, encode_pil_to_base64(rendered_page)
        except Exception as encode_err:
            logger.error(f"Failed to encode page {page_num} ({pdf_name}): {encode_err}")
            return page_num, "[IMAGE_ENCODE_ERROR]"

    logger.warning(f"Failed to render page {page_num} ({pdf_name})")
    return page_num, "[PAGE_RENDER_ERROR]"
|
|
|
|
|
# --- Main Processing Logic --- |
|
|
|
|
|
async def process_directory(
    input_dir: str,
    output_dir: Optional[str] = None,
    model_id: str = DEFAULT_MODEL_ID,
    api_base_url: str = DEFAULT_API_BASE_URL,
    api_key: str = DEFAULT_API_KEY,
    target_image_dim: int = DEFAULT_TARGET_IMAGE_DIM,
    max_pages: Optional[int] = None,
    temperature: float = DEFAULT_TEMPERATURE,
    max_tokens_per_page: int = DEFAULT_MAX_TOKENS_PER_PAGE,
    overwrite: bool = False,
    concurrency_limit: int = DEFAULT_CONCURRENCY_LIMIT,
) -> None:
    """
    Processes PDF files asynchronously using Nanonets-OCR-s via vLLM's OpenAI API.

    For every ``*.pdf`` directly inside input_dir: renders and encodes the
    pages in parallel worker jobs, sends one OCR request per successfully
    rendered page (at most concurrency_limit in flight), and writes the page
    texts, in page order and separated by a form feed, to ``<stem>.md`` in
    the output directory.

    Args:
        input_dir: Path to the directory containing input PDF files.
        output_dir: Path to the directory for output .md files. If None, a
            directory named ``output-pdftotext-<input dir name>`` is created
            next to input_dir.
        model_id: Model ID for the vLLM server API.
        api_base_url: Base URL of the vLLM OpenAI-compatible API endpoint.
        api_key: API key for the endpoint (usually 'EMPTY' for local vLLM).
        target_image_dim: Target size for the longest dimension of page images.
        max_pages: Max pages to process per PDF (None for all pages).
        temperature: Sampling temperature for the model (0.0 recommended for Nanonets).
        max_tokens_per_page: Max tokens the model can generate per page.
        overwrite: If True, overwrite existing output files; otherwise PDFs
            whose output file already exists are skipped.
        concurrency_limit: Maximum number of concurrent API requests. Must be
            at least 1.

    Raises:
        AssertionError: If input_dir is not an existing directory.
        ValueError: If concurrency_limit is less than 1.
    """
    input_path = Path(input_dir).resolve()
    assert (
        input_path.is_dir()
    ), f"Input directory not found or is not a directory: {input_path}"

    # A semaphore with zero permits would leave every OCR task waiting
    # forever, so reject non-positive limits before doing any work.
    if concurrency_limit < 1:
        raise ValueError(
            f"concurrency_limit must be a positive integer, got {concurrency_limit!r}"
        )

    output_path = (
        Path(output_dir).resolve()
        if output_dir is not None
        else input_path.parent / f"output-pdftotext-{input_path.name}"
    )
    output_path.mkdir(parents=True, exist_ok=True)

    logger.info(f"Input directory: {input_path}")
    logger.info(f"Output directory: {output_path}")
    logger.info(f"Model API: {model_id} at {api_base_url}")
    logger.info(f"Concurrency: {concurrency_limit}")
    logger.info(f"Target Image Dim: {target_image_dim}")
    logger.info(f"Overwrite: {overwrite}")

    # Loop invariants, computed once instead of once per PDF.
    # Matches a page result consisting solely of a marker like "[PAGE_RENDER_ERROR]".
    error_pattern = re.compile(r"^\s*\[[A-Z0-9_]+\]\s*$")
    n_jobs = min(64, os.cpu_count() or 1)

    client: Optional[AsyncOpenAI] = None
    try:
        client = AsyncOpenAI(api_key=api_key, base_url=api_base_url)
        logger.info(f"AsyncOpenAI client initialized for {api_base_url}")

        pdf_files: List[Path] = sorted(input_path.glob("*.pdf"))
        if not pdf_files:
            logger.warning(f"No PDF files found in {input_path}")
            return
        logger.info(f"Found {len(pdf_files)} PDF files.")

        semaphore = asyncio.Semaphore(concurrency_limit)

        for pdf_file in tqdm(
            pdf_files, desc="Processing PDFs", unit="pdf", mininterval=1.0
        ):
            output_txt_path = output_path / (pdf_file.stem + ".md")

            if not overwrite and output_txt_path.exists():
                logger.info(f"Skipping {pdf_file.name}, output exists.")
                continue

            logger.info(f"Starting processing for {pdf_file.name}")

            page_count = get_pdf_page_count(pdf_file)
            if page_count is None:
                logger.warning(f"Skipping {pdf_file.name}, failed to get page count.")
                output_txt_path.write_text("[ERROR_READING_PDF]", encoding="utf-8")
                continue
            if page_count == 0:
                logger.warning(f"Skipping {pdf_file.name}, contains 0 pages.")
                output_txt_path.write_text("", encoding="utf-8")  # Empty file
                continue

            num_pages_to_process = page_count
            if max_pages is not None and 0 < max_pages < page_count:
                logger.info(f"Limiting to first {max_pages} pages of {pdf_file.name}")
                num_pages_to_process = max_pages

            # --- Preprocessing: Render and Encode Pages ---
            logger.debug(
                f"Rendering/encoding {num_pages_to_process} pages for {pdf_file.name} in parallel"
            )
            logger.info(f"Using {n_jobs} cores for parallel page rendering")
            parallel_results = Parallel(n_jobs=n_jobs, verbose=0)(
                delayed(render_and_encode_single_page)(
                    pdf_file, page_num, target_image_dim, pdf_file.name
                )
                for page_num in range(1, num_pages_to_process + 1)
            )

            # page_num -> base64 image data, or a bracketed error marker.
            page_render_encode_data: Dict[int, str] = {}
            valid_pages_for_api = 0
            for page_num, result in parallel_results:
                page_render_encode_data[page_num] = result
                if not result.startswith("["):
                    valid_pages_for_api += 1

            all_page_texts: Dict[int, str]
            if valid_pages_for_api == 0:
                logger.warning(
                    f"No pages successfully rendered/encoded for {pdf_file.name}. "
                    "Skipping API calls."
                )
                # Every entry is an error marker; carry them over unchanged.
                all_page_texts = dict(page_render_encode_data)
            else:
                # --- Asynchronous API Calls ---
                tasks: List[Tuple[int, asyncio.Task[str]]] = []
                logger.info(
                    f"Submitting {valid_pages_for_api} pages to API for {pdf_file.name}"
                )
                for page_num in range(1, num_pages_to_process + 1):
                    img_data = page_render_encode_data.get(page_num)
                    if img_data and not img_data.startswith("["):
                        task = asyncio.create_task(
                            ocr_page_api(
                                client=client,
                                model_id=model_id,
                                img_base64=img_data,
                                page_num=page_num,
                                pdf_name=pdf_file.name,
                                semaphore=semaphore,
                                temperature=temperature,
                                max_tokens=max_tokens_per_page,
                            ),
                            name=f"OCR_{pdf_file.stem}_p{page_num}",
                        )
                        tasks.append((page_num, task))

                api_results: List[str] = await tqdm_asyncio.gather(
                    *(task for _, task in tasks),
                    desc=f"  OCR Pages ({pdf_file.name[:20]})",
                    unit="page",
                    leave=False,
                    mininterval=5.0,  # Update every 5 seconds max
                )

                # --- Combine Results ---
                # Start from the render/encode error markers, then add the OCR
                # result of each submitted page; results are indexed in the
                # same order as the tasks were passed to gather.
                all_page_texts = {
                    pn: data
                    for pn, data in page_render_encode_data.items()
                    if data.startswith("[")
                }
                for (page_num, _), page_text in zip(tasks, api_results):
                    all_page_texts[page_num] = page_text

            if not all_page_texts:
                logger.warning(f"No text results generated for {pdf_file.name}.")
                output_txt_path.write_text("", encoding="utf-8")
                continue

            ordered_texts: List[str] = [
                all_page_texts.get(pn, f"[PAGE_{pn}_MISSING_UNEXPECTEDLY]")
                for pn in range(1, num_pages_to_process + 1)
            ]

            # Pages that carry real text (non-empty and not just an error marker).
            filtered_texts: List[str] = [
                text
                for text in ordered_texts
                if text.strip() and not error_pattern.match(text.strip())
            ]

            if not filtered_texts:
                logger.warning(f"All pages were filtered out for {pdf_file.name}.")
                output_txt_path.write_text("", encoding="utf-8")
                continue

            # Use form feed character (\f) as page separator.
            # NOTE(review): the unfiltered list is written, so error markers
            # stay in the output at their page position; filtered_texts only
            # decides whether anything is worth writing. Confirm this is the
            # intended behavior.
            final_text: str = "\n\f\n".join(ordered_texts)

            try:
                output_txt_path.write_text(final_text, encoding="utf-8")
                logger.info(f"Successfully wrote output: {output_txt_path.name}")
            except Exception as e:
                logger.error(f"Failed to write output file {output_txt_path}: {e}")

    except Exception as e:
        # Top-level boundary: log with traceback; remaining PDFs are not processed.
        logger.exception(f"An unexpected error occurred during processing: {e}")
    finally:
        if client:
            await client.close()
            logger.info("AsyncOpenAI client closed.")
        logger.info("Processing run finished.")
|
|
|
|
|
def main(**kwargs: Any) -> None:
    """
    Synchronous command-line entry point that drives `process_directory`.

    The fire library turns command-line flags into keyword arguments, so any
    parameter of `process_directory` can be given on the command line, e.g.
    `--input_dir ./pdfs --max_pages 5`.

    Args:
        **kwargs: Arguments passed from the command line via fire and
            forwarded unchanged to `process_directory`.
    """
    try:
        pipeline = process_directory(**kwargs)
        asyncio.run(pipeline)
    except KeyboardInterrupt:
        # Ctrl-C is an expected way to stop a long run; exit quietly.
        logger.info("Processing interrupted by user.")


if __name__ == "__main__":
    fire.Fire(main)
I had Claude format the install instructions:
vLLM OCR Server Setup Guide
Prerequisites
This guide sets up a vLLM server with Nanonets OCR model on Ubuntu with CUDA support.
1. Install Build Tools
2. Install Python Development Headers
3. Install Go and pget
4. Install Image Processing Libraries
# Install required image processing libraries for Pillow
sudo apt update
sudo apt install libjpeg-turbo8-dev zlib1g-dev libpng-dev \
    libfreetype6-dev liblcms2-dev libopenjp2-7-dev \
    libtiff5-dev libwebp-dev
5. Install UV Package Manager
6. Setup Python Environment and Install Packages
7. Optimize Pillow with SIMD
8. Start vLLM Server
# Launch vLLM server with Nanonets OCR model
vllm serve nanonets/Nanonets-OCR-s \
    --tensor-parallel-size 1 \
    --pipeline-parallel-size 1 \
    --dtype auto \
    --gpu-memory-utilization 0.95 \
    --max-num-seqs 256 \
    --max-num-batched-tokens 256000 \
    --disable-log-requests \
    --enable-prefix-caching \
    --swap-space 0
Environment Variables to Persist
Add these to your ~/.bashrc or ~/.zshrc:
Notes
The --gpu-memory-utilization 0.95 setting uses 95% of available GPU memory.