Skip to content

Instantly share code, notes, and snippets.

@abetlen
Created September 29, 2024 21:16
Show Gist options
  • Save abetlen/e3ff8f5a7538c1e1f50368baef9a8117 to your computer and use it in GitHub Desktop.
llama-cpp-python image embeddings
from __future__ import annotations
import os
import ctypes
import contextlib
import numpy as np
import llama_cpp
import llama_cpp.llava_cpp as llava_cpp
class LlavaEmbedding:
    """Owns a native llava image embedding and frees it on close().

    The wrapped ``llava_image_embed`` struct is allocated by llava.cpp.
    The original code registered the free callback on an ExitStack but
    never closed it, so the native memory could never be released; this
    version adds :meth:`close` and context-manager support.
    """

    def __init__(
        self,
        embedding: ctypes._Pointer[llava_cpp.llava_image_embed],
        hidden_size: int,
    ):
        # Pointer to the native embedding struct; freed via close().
        self._embedding = embedding
        self._exit_stack = contextlib.ExitStack()
        # Hidden size of the clip model that produced this embedding.
        self.hidden_size = hidden_size

        def llava_image_embed_free():
            llava_cpp.llava_image_embed_free(self._embedding)

        self._exit_stack.callback(llava_image_embed_free)

    def close(self) -> None:
        """Free the native embedding memory. Safe to call more than once."""
        self._exit_stack.close()

    def __enter__(self) -> "LlavaEmbedding":
        return self

    def __exit__(self, *exc) -> None:
        self.close()

    @property
    def n_image_pos(self) -> int:
        # Number of image token positions stored in the native struct.
        return self._embedding.contents.n_image_pos

    def embed(
        self, llama_ctx: llama_cpp.llama_context_p, n_tokens: int, n_batch: int
    ) -> int:
        """Evaluate the image embedding into *llama_ctx* starting at *n_tokens*.

        Returns the updated past-token count (n_tokens advanced by the
        image positions consumed).
        """
        n_past = ctypes.c_int(n_tokens)
        llava_cpp.llava_eval_image_embed(
            llama_ctx,
            self._embedding,
            n_batch,
            ctypes.pointer(n_past),
        )
        return n_past.value

    def as_numpy(self, n_image_pos: int, embedding_dim: int):
        """View the raw buffer as an (n_image_pos, embedding_dim) float array.

        NOTE(review): the returned array aliases native memory and becomes
        invalid after close() — copy it if it must outlive this object.
        """
        return np.ctypeslib.as_array(
            self._embedding.contents.embed,
            shape=(n_image_pos, embedding_dim),
        )
class LlavaModel:
    """Loads a clip model via llava.cpp and produces image embeddings.

    The clip context is freed when :meth:`close` is called (or the
    instance is used as a context manager); the original code registered
    the free callback but never closed its ExitStack, leaking the context.
    """

    def __init__(self, path: str, n_threads: int = 1):
        self._path = path
        self._n_threads = n_threads
        self._exit_stack = contextlib.ExitStack()

        if not os.path.exists(self._path):
            raise ValueError(f"Clip model path does not exist: {self._path}")

        clip_ctx = llava_cpp.clip_model_load(self._path.encode(), 0)
        # A NULL result may surface as None or as a falsy ctypes pointer
        # depending on the declared restype — `not clip_ctx` covers both,
        # where the original `is None` check could miss a NULL pointer.
        if not clip_ctx:
            raise ValueError(f"Failed to load clip model: {self._path}")
        self._clip_ctx = clip_ctx

        def clip_free():
            llava_cpp.clip_free(self._clip_ctx)
            print("Clip model freed")

        self._exit_stack.callback(clip_free)

    def close(self) -> None:
        """Free the native clip context. Safe to call more than once."""
        self._exit_stack.close()

    def __enter__(self) -> "LlavaModel":
        return self

    def __exit__(self, *exc) -> None:
        self.close()

    def embed_bytes(self, image_bytes: bytes):
        """Embed an encoded image (e.g. JPEG/PNG bytes) and wrap the result.

        Returns a LlavaEmbedding that owns the native embedding memory.
        """
        embed = llava_cpp.llava_image_embed_make_with_bytes(
            self._clip_ctx,
            self._n_threads,
            # from_buffer requires a writable buffer, hence the bytearray copy.
            (ctypes.c_uint8 * len(image_bytes)).from_buffer(bytearray(image_bytes)),
            len(image_bytes),
        )
        return LlavaEmbedding(embed, hidden_size=self.hidden_size)

    @property
    def hidden_size(self):
        # Embedding dimension reported by the loaded clip model.
        return llava_cpp.clip_hidden_size(self._clip_ctx)
if __name__ == "__main__":
    import argparse

    # required=True turns a missing path into a clear argparse error
    # instead of an opaque failure inside LlavaModel(None).
    parser = argparse.ArgumentParser(description="Print llava image embeddings")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--image_path", type=str, required=True)
    # Optional: defaults to the clip model's own hidden size when omitted.
    parser.add_argument("--embedding_dim", type=int, default=None)
    args = parser.parse_args()

    model = LlavaModel(args.model_path)
    with open(args.image_path, "rb") as f:
        image_bytes = f.read()
    embedding = model.embed_bytes(image_bytes)
    embedding_dim = (
        args.embedding_dim if args.embedding_dim is not None else model.hidden_size
    )
    embedding_numpy = embedding.as_numpy(embedding.n_image_pos, embedding_dim)
    print(embedding_numpy)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment