langchain_textgen_tempfix
# llm version
import json
import logging
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

import requests
from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.pydantic_v1 import Field
from langchain.schema.output import GenerationChunk

logger = logging.getLogger(__name__)


class TextGen(LLM):
    """text-generation-webui models.

    To use, you should have the text-generation-webui installed, a model loaded,
    and --api added as a command-line option.

    Suggested installation, use one-click installer for your OS:
    https://github.com/oobabooga/text-generation-webui#one-click-installers

    Parameters below taken from the text-generation-webui API example:
    https://github.com/oobabooga/text-generation-webui/blob/main/api-examples/api-example.py

    Example:
        .. code-block:: python

            from langchain.llms import TextGen

            llm = TextGen(model_url="http://localhost:8500")
    """

    model_url: str
    """The full URL to the textgen webui including http[s]://host:port"""
    preset: Optional[str] = None
    """The preset to use in the textgen webui"""
    max_tokens: Optional[int] = 250
    """The maximum number of tokens to generate."""
    do_sample: bool = Field(True, alias="do_sample")
    """Do sample"""
    temperature: Optional[float] = 1.3
    """Primary factor to control randomness of outputs. 0 = deterministic
    (only the most likely token is used). Higher value = more randomness."""
    top_p: Optional[float] = 0.1
    """If not set to 1, select tokens with probabilities adding up to less than this
    number. Higher value = higher range of possible random results."""
    typical_p: Optional[float] = 1
    """If not set to 1, select only tokens that are at least this much more likely to
    appear than random tokens, given the prior text."""
    epsilon_cutoff: Optional[float] = 0  # In units of 1e-4
    """Epsilon cutoff"""
    eta_cutoff: Optional[float] = 0  # In units of 1e-4
    """ETA cutoff"""
    repetition_penalty: Optional[float] = 1.18
    """Exponential penalty factor for repeating prior tokens. 1 means no penalty,
    higher value = less repetition, lower value = more repetition."""
    top_k: Optional[float] = 40
    """Similar to top_p, but select instead only the top_k most likely tokens.
    Higher value = higher range of possible random results."""
    min_length: Optional[int] = 0
    """Minimum generation length in tokens."""
    no_repeat_ngram_size: Optional[int] = 0
    """If not set to 0, specifies the length of token sets that are completely blocked
    from repeating at all. Higher values = blocks larger phrases,
    lower values = blocks words or letters from repeating.
    Only 0 or high values are a good idea in most cases."""
    num_beams: Optional[int] = 1
    """Number of beams"""
    penalty_alpha: Optional[float] = 0
    """Penalty Alpha"""
    length_penalty: Optional[float] = 1
    """Length Penalty"""
    early_stopping: bool = Field(False, alias="early_stopping")
    """Early stopping"""
    seed: int = Field(-1, alias="seed")
    """Seed (-1 for random)"""
    add_bos_token: bool = Field(True, alias="add_bos_token")
    """Add the bos_token to the beginning of prompts.
    Disabling this can make the replies more creative."""
    truncation_length: Optional[int] = 2048
    """Truncate the prompt up to this length. The leftmost tokens are removed if
    the prompt exceeds this length. Most models require this to be at most 2048."""
    ban_eos_token: bool = Field(False, alias="ban_eos_token")
    """Ban the eos_token. Forces the model to never end the generation prematurely."""
    skip_special_tokens: bool = Field(True, alias="skip_special_tokens")
    """Skip special tokens. Some specific models need this unset."""
    stopping_strings: Optional[List[str]] = []
    """A list of strings to stop generation when encountered."""
    streaming: bool = False
    """Whether to stream the results, token by token."""
    legacy_api: bool = False
    """Whether to use the legacy REST API."""

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling textgen."""
        return {
            "max_tokens": self.max_tokens,
            "do_sample": self.do_sample,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "typical_p": self.typical_p,
            "epsilon_cutoff": self.epsilon_cutoff,
            "eta_cutoff": self.eta_cutoff,
            "repetition_penalty": self.repetition_penalty,
            "top_k": self.top_k,
            "min_length": self.min_length,
            "no_repeat_ngram_size": self.no_repeat_ngram_size,
            "num_beams": self.num_beams,
            "penalty_alpha": self.penalty_alpha,
            "length_penalty": self.length_penalty,
            "early_stopping": self.early_stopping,
            "seed": self.seed,
            "add_bos_token": self.add_bos_token,
            "truncation_length": self.truncation_length,
            "ban_eos_token": self.ban_eos_token,
            "skip_special_tokens": self.skip_special_tokens,
            "stopping_strings": self.stopping_strings,
            "legacy_api": self.legacy_api,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {**{"model_url": self.model_url}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "textgen"

    def _get_parameters(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Performs sanity check, preparing parameters in format needed by textgen.

        Args:
            stop (Optional[List[str]]): List of stop sequences for textgen.

        Returns:
            Dictionary containing the combined parameters.
        """
        # Raise error if stop sequences are in both input and default params
        # if self.stop and stop is not None:
        if self.stopping_strings and stop is not None:
            raise ValueError("`stop` found in both the input and default params.")

        if self.preset is None:
            params = self._default_params
        else:
            params = {"preset": self.preset}

        # then set it as configured, or default to an empty list:
        params["stopping_strings"] = self.stopping_strings or stop or []
        return params

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the textgen web API and return the output.

        Args:
            prompt: The prompt to use for generation.
            stop: A list of strings to stop generation when encountered.

        Returns:
            The generated text.

        Example:
            .. code-block:: python

                from langchain.llms import TextGen

                llm = TextGen(model_url="http://localhost:5000")
                llm("Write a story about llamas.")
        """
        if self.streaming:
            combined_text_output = ""
            for chunk in self._stream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                combined_text_output += chunk.text
            result = combined_text_output
        else:
            params = self._get_parameters(stop)
            if params.get("legacy_api"):
                url = f"{self.model_url}/api/v1/generate"
                resname = "results"
            else:
                url = f"{self.model_url}/v1/completions"
                resname = "choices"
            request = params.copy()
            request["prompt"] = prompt
            response = requests.post(url, json=request)

            if response.status_code == 200:
                result = response.json()[resname][0]["text"]
            else:
                print(f"ERROR: response: {response}")
                result = ""

        return result

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the textgen web API and return the output.

        Args:
            prompt: The prompt to use for generation.
            stop: A list of strings to stop generation when encountered.

        Returns:
            The generated text.

        Example:
            .. code-block:: python

                from langchain.llms import TextGen

                llm = TextGen(model_url="http://localhost:5000")
                llm("Write a story about llamas.")
        """
        if self.streaming:
            combined_text_output = ""
            # use the async stream here (iterating the sync _stream with
            # `async for` would raise a TypeError)
            async for chunk in self._astream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                combined_text_output += chunk.text
            result = combined_text_output
        else:
            params = self._get_parameters(stop)
            if params.get("legacy_api"):
                url = f"{self.model_url}/api/v1/generate"
                resname = "results"
            else:
                url = f"{self.model_url}/v1/completions"
                resname = "choices"
            request = params.copy()
            request["prompt"] = prompt
            response = requests.post(url, json=request)

            if response.status_code == 200:
                result = response.json()[resname][0]["text"]
            else:
                print(f"ERROR: response: {response}")
                result = ""

        return result

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Yields result objects as they are generated in real time.

        It also calls the callback manager's on_llm_new_token event with
        similar parameters to the OpenAI LLM class method of the same name.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            A generator representing the stream of tokens being generated.

        Yields:
            Dictionary-like objects each containing a string token and metadata.
            See text-generation-webui docs and below for more.

        Example:
            .. code-block:: python

                from langchain.llms import TextGen

                llm = TextGen(
                    model_url="ws://localhost:5005",
                    streaming=True,
                )
                for chunk in llm.stream("Ask 'Hi, how are you?' like a pirate:'",
                                        stop=["'", "\n"]):
                    print(chunk, end="", flush=True)
        """
        params = {**self._get_parameters(stop), **kwargs}
        request = params.copy()
        request["prompt"] = prompt

        if params.get("legacy_api"):
            try:
                import websocket
            except ImportError:
                raise ImportError(
                    "The `websocket-client` package is required for streaming "
                    "(pip install websocket-client)."
                )
            url = f"{self.model_url}/api/v1/stream"
            websocket_client = websocket.WebSocket()
            websocket_client.connect(url)
            websocket_client.send(json.dumps(request))

            while True:
                result = websocket_client.recv()
                result = json.loads(result)

                if result["event"] == "text_stream":
                    chunk = GenerationChunk(
                        text=result["text"],
                        generation_info=None,
                    )
                    yield chunk
                    if run_manager:
                        run_manager.on_llm_new_token(token=chunk.text)
                elif result["event"] == "stream_end":
                    websocket_client.close()
                    return
        else:
            url = f"{self.model_url}/v1/completions"
            request["stream"] = True
            req_session = requests.Session()
            with req_session.post(url, json=request, stream=True) as response:
                if response.status_code != 200:
                    print(f"ERROR: response: {response}")
                    return
                # the OpenAI-compatible endpoint streams Server-Sent Events
                for chunk in response.iter_lines():
                    chunk = chunk.decode("utf8").strip()
                    if not chunk:
                        continue
                    if chunk.startswith("data: "):
                        chunk = chunk[6:]
                    chunk = json.loads(chunk)
                    chunk_text = chunk["choices"][0]["text"]
                    yield GenerationChunk(text=chunk_text, generation_info=None)
                    if run_manager:
                        run_manager.on_llm_new_token(token=chunk_text)

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        """Yields result objects as they are generated in real time.

        It also calls the callback manager's on_llm_new_token event with
        similar parameters to the OpenAI LLM class method of the same name.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            A generator representing the stream of tokens being generated.

        Yields:
            Dictionary-like objects each containing a string token and metadata.
            See text-generation-webui docs and below for more.

        Example:
            .. code-block:: python

                from langchain.llms import TextGen

                llm = TextGen(
                    model_url="ws://localhost:5005",
                    streaming=True,
                )
                async for chunk in llm.astream("Ask 'Hi, how are you?' like a pirate:'",
                                               stop=["'", "\n"]):
                    print(chunk, end="", flush=True)
        """
        params = {**self._get_parameters(stop), **kwargs}
        request = params.copy()
        request["prompt"] = prompt

        if params.get("legacy_api"):
            try:
                import websocket
            except ImportError:
                raise ImportError(
                    "The `websocket-client` package is required for streaming "
                    "(pip install websocket-client)."
                )
            url = f"{self.model_url}/api/v1/stream"
            websocket_client = websocket.WebSocket()
            websocket_client.connect(url)
            websocket_client.send(json.dumps(request))

            while True:
                result = websocket_client.recv()
                result = json.loads(result)

                if result["event"] == "text_stream":
                    chunk = GenerationChunk(
                        text=result["text"],
                        generation_info=None,
                    )
                    yield chunk
                    if run_manager:
                        await run_manager.on_llm_new_token(token=chunk.text)
                elif result["event"] == "stream_end":
                    websocket_client.close()
                    return
        else:
            url = f"{self.model_url}/v1/completions"
            request["stream"] = True
            req_session = requests.Session()
            with req_session.post(url, json=request, stream=True) as response:
                if response.status_code != 200:
                    print(f"ERROR: response: {response}")
                    return
                # parse the SSE stream the same way as the sync _stream
                # (iter_content returned raw bytes and never decoded the chunks)
                for chunk in response.iter_lines():
                    chunk = chunk.decode("utf8").strip()
                    if not chunk:
                        continue
                    if chunk.startswith("data: "):
                        chunk = chunk[6:]
                    chunk = json.loads(chunk)
                    chunk_text = chunk["choices"][0]["text"]
                    yield GenerationChunk(text=chunk_text, generation_info=None)
                    if run_manager:
                        await run_manager.on_llm_new_token(token=chunk_text)
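

# A minimal usage sketch for the TextGen class above, assuming text-generation-webui
# is running locally with its API enabled on port 5000; the URL, prompt, and
# max_tokens value are placeholders, not part of the original gist.
if __name__ == "__main__":
    llm = TextGen(model_url="http://localhost:5000", max_tokens=64)

    # one-shot completion via /v1/completions (or /api/v1/generate if legacy_api=True)
    print(llm("Write a haiku about llamas."))

    # token-by-token streaming over the SSE /v1/completions endpoint
    for chunk in llm.stream("Write a haiku about llamas."):
        print(chunk, end="", flush=True)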

# chat version
import json
import logging
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

import requests
from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models.base import SimpleChatModel
from langchain.pydantic_v1 import Field
from langchain.schema.output import ChatGenerationChunk
from langchain.schema.messages import (
    AIMessageChunk,
    BaseMessage,
    HumanMessageChunk,
    SystemMessageChunk,
)

logger = logging.getLogger(__name__)

# map LangChain message types to the chat roles expected by the API
TYPE_MAP = dict(
    human="user",
    ai="assistant",
    chat="user",
    function="user",
    tool="user",
    system="system",
)

# map chat roles returned by the API back to LangChain message chunk classes
INVTYPE_MAP = dict(
    user=HumanMessageChunk,
    assistant=AIMessageChunk,
    system=SystemMessageChunk,
)


class ChatTextGen(SimpleChatModel):
    """text-generation-webui chat models.

    To use, you should have the text-generation-webui installed, a model loaded,
    and --api added as a command-line option.

    Suggested installation, use one-click installer for your OS:
    https://github.com/oobabooga/text-generation-webui#one-click-installers

    Parameters below taken from the text-generation-webui API example:
    https://github.com/oobabooga/text-generation-webui/blob/main/api-examples/api-example.py

    Example:
        .. code-block:: python

            llm = ChatTextGen(model_url="http://localhost:8500")
    """

    model_url: str
    """The full URL to the textgen webui including http[s]://host:port"""
    preset: Optional[str] = None
    """The preset to use in the textgen webui"""
    max_tokens: Optional[int] = 250
    """The maximum number of tokens to generate."""
    do_sample: bool = Field(True, alias="do_sample")
    """Do sample"""
    temperature: Optional[float] = 1.3
    """Primary factor to control randomness of outputs. 0 = deterministic
    (only the most likely token is used). Higher value = more randomness."""
    top_p: Optional[float] = 0.1
    """If not set to 1, select tokens with probabilities adding up to less than this
    number. Higher value = higher range of possible random results."""
    typical_p: Optional[float] = 1
    """If not set to 1, select only tokens that are at least this much more likely to
    appear than random tokens, given the prior text."""
    epsilon_cutoff: Optional[float] = 0  # In units of 1e-4
    """Epsilon cutoff"""
    eta_cutoff: Optional[float] = 0  # In units of 1e-4
    """ETA cutoff"""
    repetition_penalty: Optional[float] = 1.18
    """Exponential penalty factor for repeating prior tokens. 1 means no penalty,
    higher value = less repetition, lower value = more repetition."""
    top_k: Optional[float] = 40
    """Similar to top_p, but select instead only the top_k most likely tokens.
    Higher value = higher range of possible random results."""
    min_length: Optional[int] = 0
    """Minimum generation length in tokens."""
    no_repeat_ngram_size: Optional[int] = 0
    """If not set to 0, specifies the length of token sets that are completely blocked
    from repeating at all. Higher values = blocks larger phrases,
    lower values = blocks words or letters from repeating.
    Only 0 or high values are a good idea in most cases."""
    num_beams: Optional[int] = 1
    """Number of beams"""
    penalty_alpha: Optional[float] = 0
    """Penalty Alpha"""
    length_penalty: Optional[float] = 1
    """Length Penalty"""
    early_stopping: bool = Field(False, alias="early_stopping")
    """Early stopping"""
    seed: int = Field(-1, alias="seed")
    """Seed (-1 for random)"""
    add_bos_token: bool = Field(True, alias="add_bos_token")
    """Add the bos_token to the beginning of prompts.
    Disabling this can make the replies more creative."""
    truncation_length: Optional[int] = 2048
    """Truncate the prompt up to this length. The leftmost tokens are removed if
    the prompt exceeds this length. Most models require this to be at most 2048."""
    ban_eos_token: bool = Field(False, alias="ban_eos_token")
    """Ban the eos_token. Forces the model to never end the generation prematurely."""
    skip_special_tokens: bool = Field(True, alias="skip_special_tokens")
    """Skip special tokens. Some specific models need this unset."""
    stopping_strings: Optional[List[str]] = []
    """A list of strings to stop generation when encountered."""
    streaming: bool = False
    """Whether to stream the results, token by token."""
    legacy_api: bool = False
    """Whether to use the legacy REST API."""

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling textgen."""
        return {
            "max_tokens": self.max_tokens,
            "do_sample": self.do_sample,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "typical_p": self.typical_p,
            "epsilon_cutoff": self.epsilon_cutoff,
            "eta_cutoff": self.eta_cutoff,
            "repetition_penalty": self.repetition_penalty,
            "top_k": self.top_k,
            "min_length": self.min_length,
            "no_repeat_ngram_size": self.no_repeat_ngram_size,
            "num_beams": self.num_beams,
            "penalty_alpha": self.penalty_alpha,
            "length_penalty": self.length_penalty,
            "early_stopping": self.early_stopping,
            "seed": self.seed,
            "add_bos_token": self.add_bos_token,
            "truncation_length": self.truncation_length,
            "ban_eos_token": self.ban_eos_token,
            "skip_special_tokens": self.skip_special_tokens,
            "stopping_strings": self.stopping_strings,
            "legacy_api": self.legacy_api,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {**{"model_url": self.model_url}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "textgen"

    def _get_parameters(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Performs sanity check, preparing parameters in format needed by textgen.

        Args:
            stop (Optional[List[str]]): List of stop sequences for textgen.

        Returns:
            Dictionary containing the combined parameters.
        """
        # Raise error if stop sequences are in both input and default params
        # if self.stop and stop is not None:
        if self.stopping_strings and stop is not None:
            raise ValueError("`stop` found in both the input and default params.")

        if self.preset is None:
            params = self._default_params
        else:
            params = {"preset": self.preset}

        # then set it as configured, or default to an empty list:
        params["stopping_strings"] = self.stopping_strings or stop or []
        return params

    def _call(
        self,
        prompt: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the textgen web API and return the output.

        Args:
            prompt: The chat messages to use for generation.
            stop: A list of strings to stop generation when encountered.

        Returns:
            The generated text.

        Example:
            .. code-block:: python

                llm = ChatTextGen(model_url="http://localhost:5000")
                llm.predict("Write a story about llamas.")
        """
        if self.streaming:
            combined_text_output = ""
            for chunk in self._stream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                combined_text_output += chunk.text
            result = combined_text_output
        else:
            params = self._get_parameters(stop)
            request = params.copy()
            if params.get("legacy_api"):
                # the legacy API takes a flat prompt string
                url = f"{self.model_url}/api/v1/generate"
                request["prompt"] = str.join("\n", [
                    f"{TYPE_MAP.get(message.type, 'user')}: {message.content}"
                    for message in prompt
                ])
            else:
                # the OpenAI-compatible API takes a list of role/content messages
                url = f"{self.model_url}/v1/chat/completions"
                request["messages"] = [
                    {
                        "content": message.content,
                        "role": TYPE_MAP.get(message.type, "user"),
                    }
                    for message in prompt
                ]
            response = requests.post(url, json=request)

            if response.status_code == 200:
                if params.get("legacy_api"):
                    # legacy responses carry plain text, not a chat message
                    result = response.json()["results"][0]["text"]
                else:
                    result = response.json()["choices"][0]["message"]["content"]
            else:
                print(f"ERROR: response: {response}")
                result = ""

        return result

    async def _acall(
        self,
        prompt: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the textgen web API and return the output.

        Args:
            prompt: The chat messages to use for generation.
            stop: A list of strings to stop generation when encountered.

        Returns:
            The generated text.

        Example:
            .. code-block:: python

                llm = ChatTextGen(model_url="http://localhost:5000")
                await llm.apredict("Write a story about llamas.")
        """
        if self.streaming:
            combined_text_output = ""
            # use the async stream here (iterating the sync _stream with
            # `async for` would raise a TypeError)
            async for chunk in self._astream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                combined_text_output += chunk.text
            result = combined_text_output
        else:
            params = self._get_parameters(stop)
            request = params.copy()
            if params.get("legacy_api"):
                url = f"{self.model_url}/api/v1/generate"
                request["prompt"] = str.join("\n", [
                    f"{TYPE_MAP.get(message.type, 'user')}: {message.content}"
                    for message in prompt
                ])
            else:
                url = f"{self.model_url}/v1/chat/completions"
                request["messages"] = [
                    {"content": message.content, "role": TYPE_MAP.get(message.type, "user")}
                    for message in prompt
                ]
            response = requests.post(url, json=request)

            if response.status_code == 200:
                if params.get("legacy_api"):
                    result = response.json()["results"][0]["text"]
                else:
                    result = response.json()["choices"][0]["message"]["content"]
            else:
                print(f"ERROR: response: {response}")
                result = ""

        return result

    def _stream(
        self,
        prompt: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Yields result objects as they are generated in real time.

        It also calls the callback manager's on_llm_new_token event with
        similar parameters to the OpenAI LLM class method of the same name.

        Args:
            prompt: The chat messages to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            A generator representing the stream of tokens being generated.

        Yields:
            Dictionary-like objects each containing a string token and metadata.
            See text-generation-webui docs and below for more.

        Example:
            .. code-block:: python

                llm = ChatTextGen(
                    model_url="ws://localhost:5005",
                    streaming=True,
                )
                for chunk in llm.stream("Ask 'Hi, how are you?' like a pirate:'",
                                        stop=["'", "\n"]):
                    print(chunk, end="", flush=True)
        """
        params = {**self._get_parameters(stop), **kwargs}
        request = params.copy()

        if params.get("legacy_api"):
            try:
                import websocket
            except ImportError:
                raise ImportError(
                    "The `websocket-client` package is required for streaming "
                    "(pip install websocket-client)."
                )
            url = f"{self.model_url}/api/v1/stream"
            request["prompt"] = str.join("\n", [
                f"{TYPE_MAP.get(message.type, 'user')}: {message.content}"
                for message in prompt
            ])
            websocket_client = websocket.WebSocket()
            websocket_client.connect(url)
            websocket_client.send(json.dumps(request))

            while True:
                result = websocket_client.recv()
                result = json.loads(result)

                if result["event"] == "text_stream":
                    chunk = ChatGenerationChunk(
                        text=result["text"],
                        # ChatGenerationChunk requires a message; the legacy API
                        # only streams text, so wrap it as an AI message chunk
                        message=AIMessageChunk(content=result["text"]),
                        generation_info=None,
                    )
                    yield chunk
                    if run_manager:
                        run_manager.on_llm_new_token(token=chunk.text)
                elif result["event"] == "stream_end":
                    websocket_client.close()
                    return
        else:
            url = f"{self.model_url}/v1/chat/completions"
            request["stream"] = True
            request["messages"] = [
                {"content": message.content, "role": TYPE_MAP.get(message.type, "user")}
                for message in prompt
            ]
            req_session = requests.Session()
            with req_session.post(url, json=request, stream=True) as response:
                if response.status_code != 200:
                    print(f"ERROR: response: {response}")
                    return
                for chunk in response.iter_lines():
                    chunk = chunk.decode("utf8").strip()
                    if not chunk:
                        continue
                    if chunk.startswith("data: "):
                        chunk = chunk[6:]
                    chunk = json.loads(chunk)
                    chunk_msg = chunk["choices"][0]["message"]
                    yield ChatGenerationChunk(
                        text=chunk_msg["content"],
                        message=INVTYPE_MAP.get(chunk_msg["role"], HumanMessageChunk)(
                            content=chunk_msg["content"]
                        ),
                        generation_info=None,
                    )
                    if run_manager:
                        run_manager.on_llm_new_token(token=chunk_msg["content"])

    async def _astream(
        self,
        prompt: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Yields result objects as they are generated in real time.

        It also calls the callback manager's on_llm_new_token event with
        similar parameters to the OpenAI LLM class method of the same name.

        Args:
            prompt: The chat messages to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            A generator representing the stream of tokens being generated.

        Yields:
            Dictionary-like objects each containing a string token and metadata.
            See text-generation-webui docs and below for more.

        Example:
            .. code-block:: python

                llm = ChatTextGen(
                    model_url="ws://localhost:5005",
                    streaming=True,
                )
                async for chunk in llm.astream("Ask 'Hi, how are you?' like a pirate:'",
                                               stop=["'", "\n"]):
                    print(chunk, end="", flush=True)
        """
        params = {**self._get_parameters(stop), **kwargs}
        request = params.copy()

        if params.get("legacy_api"):
            try:
                import websocket
            except ImportError:
                raise ImportError(
                    "The `websocket-client` package is required for streaming "
                    "(pip install websocket-client)."
                )
            url = f"{self.model_url}/api/v1/stream"
            request["prompt"] = str.join("\n", [
                f"{TYPE_MAP.get(message.type, 'user')}: {message.content}"
                for message in prompt
            ])
            websocket_client = websocket.WebSocket()
            websocket_client.connect(url)
            websocket_client.send(json.dumps(request))

            while True:
                result = websocket_client.recv()
                result = json.loads(result)

                if result["event"] == "text_stream":
                    chunk = ChatGenerationChunk(
                        text=result["text"],
                        # ChatGenerationChunk requires a message; wrap the streamed
                        # text as an AI message chunk
                        message=AIMessageChunk(content=result["text"]),
                        generation_info=None,
                    )
                    yield chunk
                    if run_manager:
                        await run_manager.on_llm_new_token(token=chunk.text)
                elif result["event"] == "stream_end":
                    websocket_client.close()
                    return
        else:
            url = f"{self.model_url}/v1/chat/completions"
            request["stream"] = True
            request["messages"] = [
                {"content": message.content, "role": TYPE_MAP.get(message.type, "user")}
                for message in prompt
            ]
            req_session = requests.Session()
            with req_session.post(url, json=request, stream=True) as response:
                if response.status_code != 200:
                    print(f"ERROR: response: {response}")
                    return
                for chunk in response.iter_lines():
                    chunk = chunk.decode("utf8").strip()
                    if not chunk:
                        continue
                    if chunk.startswith("data: "):
                        chunk = chunk[6:]
                    chunk = json.loads(chunk)
                    chunk_msg = chunk["choices"][0]["message"]
                    yield ChatGenerationChunk(
                        text=chunk_msg["content"],
                        message=INVTYPE_MAP.get(chunk_msg["role"], HumanMessageChunk)(
                            content=chunk_msg["content"]
                        ),
                        generation_info=None,
                    )
                    if run_manager:
                        await run_manager.on_llm_new_token(token=chunk_msg["content"])
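

# A minimal usage sketch for the chat model, assuming text-generation-webui is
# running locally with its OpenAI-compatible API on port 5000; the URL, messages,
# and max_tokens value are placeholders, not part of the original gist.
if __name__ == "__main__":
    from langchain.schema.messages import HumanMessage, SystemMessage

    chat = ChatTextGen(model_url="http://localhost:5000", max_tokens=64)
    messages = [
        SystemMessage(content="You are a terse assistant."),
        HumanMessage(content="Write a haiku about llamas."),
    ]

    # one-shot call: SimpleChatModel wires _call into the chat interface
    print(chat(messages).content)

    # streaming: goes through _stream and the /v1/chat/completions endpoint
    for chunk in chat.stream(messages):
        print(chunk.content, end="", flush=True)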