Last active
July 12, 2024 13:47
-
-
Save datasciencemonkey/a8df9d4002f27f3df27fd1a6787463c7 to your computer and use it in GitHub Desktop.
DSPy module for an LLM pipeline
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
import random

from utils import get_country_data_and_corr
class WriteEssayAfterParsingUserQuery(dspy.Module):
    """Answer a country-specific question with a data-grounded essay.

    Pipeline, as implemented in ``forward``:
      1. Extract a country name from the question with ``ExtractCountry``
         running on the ``qwen2`` LM.
      2. Fetch that country's data and a correlation table via
         ``get_country_data_and_corr`` and serialize both to JSON.
      3. Write the essay with a chain-of-thought ``EssayOnePass`` call on the
         Databricks ``openai-chat-endpoint-sg-4o`` chat endpoint.
      4. Attach a ``dspy.Suggest`` readability ("persona") check.
      5. Print the consensus ``textstat`` reading grade of the essay.
    """

    # Databricks serving endpoint and generation settings for the essay LM.
    _WRITER_MODEL = 'openai-chat-endpoint-sg-4o'
    _WRITER_MAX_TOKENS = 2000
    _WRITER_BASE_TEMPERATURE = 0.7

    def __init__(self):
        super().__init__()
        self.country = dspy.Predict(ExtractCountry)
        self.essay_writer = dspy.ChainOfThought(EssayOnePass)

    @classmethod
    def _writer_lm(cls):
        """Build a fresh Databricks chat LM for essay writing and checking.

        Each call returns a new LM whose temperature is the base temperature
        plus a random offset in [0.0001, 0.0100], rounded to 4 decimals.
        NOTE(review): the jitter presumably exists to avoid cached responses
        so repeated runs produce different essays — confirm intent.
        """
        temperature = round(
            cls._WRITER_BASE_TEMPERATURE + random.randint(1, 100) / 10000, 4
        )
        return dspy.Databricks(
            model=cls._WRITER_MODEL,
            model_type='chat',
            api_key=API_KEY,
            api_base=API_BASE,
            max_tokens=cls._WRITER_MAX_TOKENS,
            temperature=temperature,
        )

    def forward(self, question, constraint, persona):
        """Run the pipeline for ``question``.

        Args:
            question: The user's question; expected to name a country.
            constraint: Message passed to ``dspy.Suggest`` alongside the
                readability check (the feedback used for self-refinement).
            persona: Target reader, interpolated into the readability check.

        Returns:
            ``(essay_text, rationale)`` on success, or the plain string
            ``"country unknown. please try again."`` when no country could be
            extracted.
            NOTE(review): the two return shapes differ; callers that unpack a
            tuple must check for the string case first.
        """
        # Call QWEN2 for extracting the country name.
        with dspy.settings.context(lm=qwen2):
            raw_country = self.country(question=question).result
        # LLM output is free text: tolerate surrounding whitespace, a missing
        # value, different casing ("Unknown") or a trailing period, so the
        # sentinel is not accidentally sent to the database as a country.
        country_name = (raw_country or '').strip()
        print(f"extracted country ==> {country_name}")
        if not country_name or country_name.lower().rstrip('.') == 'unknown':
            return "country unknown. please try again."

        # Run the dbsql query for the country, also get a correlation matrix,
        # and serialize both to JSON strings for the prompt.
        country_data, corr_df = get_country_data_and_corr(country_name=country_name)
        stringified_tbl = json.dumps(country_data.to_dict(orient="records"))
        stringified_corr_tbl = json.dumps(corr_df.to_dict())

        with dspy.settings.context(lm=self._writer_lm()):
            essay = self.essay_writer(
                trends_context=stringified_tbl,
                correlation_context=stringified_corr_tbl,
                question=question,
            )

        # Impose a suggestion, offering the opportunity for self-refinement
        # (persona check). Runs under its own freshly built writer LM.
        # NOTE(review): assumes result_check evaluates the statement against
        # the essay (presumably via the active LM) and that assertions are
        # activated on this module elsewhere — confirm.
        with dspy.settings.context(lm=self._writer_lm()):
            dspy.Suggest(
                result_check(
                    essay.result,
                    statement=f'Essay has only simple sentences easily read by {persona}',
                ),
                constraint,
                target_module=EssayOnePass,
            )

        # Consensus text grade of the essay; printed only, not returned.
        text_standard = textstat.text_standard(essay.result, float_output=True)
        print(f"Consensus text standard: {text_standard}")
        return (essay.result, essay.rationale)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment