Created January 13, 2024 02:18
Save killfill/52c8e1bb5e5f9baa7a36e68be97d8817 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/autogen/hooks_utils.py b/autogen/hooks_utils.py
new file mode 100644
index 00000000..7de0e876
--- /dev/null
+++ b/autogen/hooks_utils.py
@@ -0,0 +1,50 @@
+"""Minimal registry of named hooks, e.g. the ``after_llm`` event."""
+
+from collections import defaultdict
+from typing import Any, Callable, DefaultDict, List
+
+
+class Hooks:
+    """Map a hook name to the callbacks run when that name is triggered.
+
+    ``Hooks.instance()`` returns one shared registry per class; constructing
+    ``Hooks()`` directly yields an independent, empty registry.
+    """
+
+    def __init__(self) -> None:
+        # Per-instance storage: a class-level dict would be silently shared by
+        # every instance and every subclass.
+        self._hooks: DefaultDict[str, List[Callable[..., Any]]] = defaultdict(list)
+
+    def register(self, name: str, func: Callable[..., Any]) -> None:
+        """Append ``func`` to the callbacks run by ``trigger(name, ...)``."""
+        self._hooks[name].append(func)
+
+    def trigger(self, name: str, *args: Any, **kwargs: Any) -> None:
+        """Call each callback registered under ``name``, in registration order.
+
+        ``*args``/``**kwargs`` are forwarded unchanged; return values are ignored
+        and any exception raised by a callback propagates to the caller. A name
+        with no callbacks is a no-op.
+        """
+        # .get() avoids inserting an empty entry for unknown names; the tuple
+        # snapshot stops a callback that registers more hooks from extending
+        # (or endlessly growing) the loop currently running.
+        for func in tuple(self._hooks.get(name, ())):
+            func(*args, **kwargs)
+
+    @classmethod
+    def instance(cls) -> "Hooks":
+        """Return the lazily created shared instance of ``cls``."""
+        # Check only cls's own namespace so a subclass never reuses the
+        # instance already created for its parent class.
+        if "_instance" not in cls.__dict__:
+            cls._instance = cls()
+        return cls._instance
+
+
+if __name__ == "__main__":
+    hooks = Hooks.instance()
+    hooks.register("test", lambda x: print('1', x))
+    hooks.register("test", lambda x: print('2', x))
+    hooks.trigger('test', "p1")
diff --git a/autogen/oai/client.py b/autogen/oai/client.py
index 1bdfd835..e191b174 100644
--- a/autogen/oai/client.py
+++ b/autogen/oai/client.py
@@ -15,6 +15,8 @@ from autogen.oai.openai_utils import get_key, OAI_PRICE1K
 from autogen.token_count_utils import count_token
 from autogen._pydantic import model_dump
 
+from autogen.hooks_utils import Hooks
+
 TOOL_ENABLED = False
 try:
     import openai
@@ -287,6 +289,8 @@ class OpenAIWrapper:
             else:
                 # add cost calculation before caching no matter filter is passed or not
                 response.cost = self.cost(response)
                 self._update_usage_summary(response, use_cache=False)
+                # Run hooks after usage is recorded so a failing hook cannot skip accounting.
+                Hooks.instance().trigger("after_llm", input=full_config, output=response)
                 if cache_seed is not None:
                     # Cache the response
diff --git a/test/test_hooks.py b/test/test_hooks.py
new file mode 100644
index 00000000..01e9c577
--- /dev/null
+++ b/test/test_hooks.py
@@ -0,0 +1,38 @@
+"""Integration test for the ``after_llm`` hook fired by ``OpenAIWrapper.create``."""
+
+import pytest
+
+from autogen import config_list_from_json, OpenAIWrapper
+from autogen.hooks_utils import Hooks
+
+
+def test_after_llm_hook_receives_response():
+    # Load the config inside the test so that merely collecting this module
+    # never reads credentials or calls the API.
+    config_list = config_list_from_json(env_or_file="OAI_CONFIG_LIST")
+    if not config_list:
+        pytest.skip("OAI_CONFIG_LIST is not configured")
+    client = OpenAIWrapper(config_list=config_list)
+
+    calls = []
+
+    def record(input, output):
+        # The after_llm trigger passes ``input=`` and ``output=`` by keyword,
+        # so these parameter names are fixed even though ``input`` shadows a builtin.
+        calls.append((input, output))
+
+    Hooks.instance().register("after_llm", record)
+
+    response = client.create(
+        messages=[
+            {"role": "system", "content": "you are a cat"},
+            {"role": "user", "content": "2+1="},
+        ],
+        model="gpt-3.5-turbo",
+        cache_seed=None,
+    )
+
+    # The hook fired once and received the very response object create() returned.
+    assert len(calls) == 1
+    _, hook_output = calls[0]
+    assert hook_output is response
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.