Run STORM against a local Qdrant instance.
""" | |
This STORM Wiki pipeline powered by Groq (Llama3-70B) and local Qdrant vector store with Ollama embeddings. | |
You need to set up the following environment variables to run this script: | |
- GROQ_API_KEY: Groq Cloud API key | |
- QDRANT_API_KEY: Qdrant API key (optional, only needed if using authenticated Qdrant) | |
Requirements: | |
1. Local Qdrant instance running (default: http://localhost:6333) | |
2. Ollama serving the embedding model (default: all-minilm:l6-v2) | |
The script uses an existing Qdrant collection named 'storm_collection' by default. | |
The collection should contain documents with these fields: | |
- content: Main text content | |
- title: Document title | |
- url: Unique document identifier | |
- description: Optional description | |
Output structure: | |
args.output_dir/ | |
topic_name/ # Underscore-formatted topic name | |
conversation_log.json # Information-seeking conversation log | |
raw_search_results.json # Raw retrieval results | |
direct_gen_outline.txt # Initial LLM-generated outline | |
storm_gen_outline.txt # Refined outline with collected info | |
url_to_info.json # Sources used in final article | |
storm_gen_article.txt # Final generated article | |
storm_gen_article_polished.txt # Polished article (if args.do_polish_article) | |
""" | |
import os
from argparse import ArgumentParser

from knowledge_storm import (
    STORMWikiRunnerArguments,
    STORMWikiRunner,
    STORMWikiLMConfigs,
)
from knowledge_storm.rm import VectorRM
from knowledge_storm.lm import GroqModel
from knowledge_storm.utils import load_api_key  # , QdrantVectorStoreManager
def main(args):
    # Load the API keys from the specified toml file path.
    load_api_key(toml_file_path=os.path.expanduser("~/.storm_secrets/secrets.toml"))
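    # The secrets file is assumed to hold keys as top-level toml entries that
    # load_api_key exports as environment variables, e.g.:
    #   GROQ_API_KEY = "gsk_..."
    #   QDRANT_API_KEY = "..."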
    # Initialize the language model configurations.
    engine_lm_configs = STORMWikiLMConfigs()
    groq_kwargs = {
        "api_key": os.getenv("GROQ_API_KEY"),  # set via secrets.toml above
        "model": "llama3-70b-8192",  # or e.g. "mixtral-8x7b-32768", depending on your preference
        "temperature": 1.0,
        "top_p": 0.9,
    }
    # STORM is an LM system, so different components can be powered by different
    # models. For a good balance between cost and quality, you can choose a
    # cheaper/faster model for conv_simulator_lm, which is used to split queries
    # and synthesize answers in the conversation. We recommend stronger models
    # for outline_gen_lm, which is responsible for organizing the collected
    # information, and article_gen_lm, which is responsible for generating
    # sections with citations.
    conv_simulator_lm = GroqModel(max_tokens=500, **groq_kwargs)
    question_asker_lm = GroqModel(max_tokens=500, **groq_kwargs)
    outline_gen_lm = GroqModel(max_tokens=400, **groq_kwargs)
    article_gen_lm = GroqModel(max_tokens=700, **groq_kwargs)
    article_polish_lm = GroqModel(max_tokens=4000, **groq_kwargs)

    engine_lm_configs.set_conv_simulator_lm(conv_simulator_lm)
    engine_lm_configs.set_question_asker_lm(question_asker_lm)
    engine_lm_configs.set_outline_gen_lm(outline_gen_lm)
    engine_lm_configs.set_article_gen_lm(article_gen_lm)
    engine_lm_configs.set_article_polish_lm(article_polish_lm)
    # Initialize the engine arguments.
    engine_args = STORMWikiRunnerArguments(
        output_dir=args.output_dir,
        max_conv_turn=args.max_conv_turn,
        max_perspective=args.max_perspective,
        search_top_k=args.search_top_k,
        max_thread_num=args.max_thread_num,
    )

    # Set up VectorRM to retrieve information from your own data.
    rm = VectorRM(
        collection_name=args.collection_name,
        embedding_model=args.embedding_model,
        device=args.device,
        k=engine_args.search_top_k,
    )
    # Initialize the vector store, either online (the db lives on a Qdrant
    # server) or offline (the db is stored locally):
    if args.vector_db_mode == "offline":
        rm.init_offline_vector_db(vector_store_path=args.offline_vector_db_dir)
    elif args.vector_db_mode == "online":
        rm.init_online_vector_db(
            url=args.online_vector_db_url, api_key=os.getenv("QDRANT_API_KEY")
        )
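    # For an authenticated, hosted Qdrant cluster the same call would look like
    # this (the URL is a hypothetical placeholder, not part of this script):
    #   rm.init_online_vector_db(
    #       url="https://YOUR-CLUSTER.cloud.qdrant.io:6333",
    #       api_key=os.getenv("QDRANT_API_KEY"),
    #   )
    # Optional sanity check before a full run; this assumes VectorRM exposes the
    # dspy-style forward() interface that the runner uses internally:
    #   print(rm.forward("test query", exclude_urls=[]))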
    # Initialize the STORM Wiki runner.
    runner = STORMWikiRunner(engine_args, engine_lm_configs, rm)

    # Run the pipeline.
    topic = input("Topic: ")
    runner.run(
        topic=topic,
        do_research=args.do_research,
        do_generate_outline=args.do_generate_outline,
        do_generate_article=args.do_generate_article,
        do_polish_article=args.do_polish_article,
    )
    runner.post_run()
    runner.summary()
if __name__ == "__main__":
    parser = ArgumentParser()
    # Global arguments.
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./results/gpt_retrieval",
        help="Directory to store the outputs.",
    )
    parser.add_argument(
        "--max-thread-num",
        type=int,
        default=3,
        help="Maximum number of threads to use. The information-seeking part and the "
        "article generation part can be sped up by using multiple threads. Consider "
        'reducing it if you keep getting an "Exceed rate limit" error when calling the LM API.',
    )
    # Local corpus and vector db setup.
    parser.add_argument(
        "--collection-name",
        type=str,
        default="storm_collection",
        help="The collection name for the vector store.",
    )
    # Stages of the pipeline.
    parser.add_argument(
        "--do-research",
        action="store_true",
        help="If True, simulate conversation to research the topic; otherwise, load the results.",
    )
    parser.add_argument(
        "--do-generate-outline",
        action="store_true",
        help="If True, generate an outline for the topic; otherwise, load the results.",
    )
    parser.add_argument(
        "--do-generate-article",
        action="store_true",
        help="If True, generate an article for the topic; otherwise, load the results.",
    )
    parser.add_argument(
        "--do-polish-article",
        action="store_true",
        help="If True, polish the article by adding a summarization section and "
        "(optionally) removing duplicate content.",
    )
    # Hyperparameters for the pre-writing stage.
    parser.add_argument(
        "--max-conv-turn",
        type=int,
        default=3,
        help="Maximum number of questions in conversational question asking.",
    )
    parser.add_argument(
        "--max-perspective",
        type=int,
        default=3,
        help="Maximum number of perspectives to consider in perspective-guided question asking.",
    )
    parser.add_argument(
        "--search-top-k",
        type=int,
        default=5,
        help="Top-k search results to consider for each search query.",
    )
    # Hyperparameters for the writing stage.
    parser.add_argument(
        "--retrieve-top-k",
        type=int,
        default=5,
        help="Top-k collected references for each section title.",
    )
    parser.add_argument(
        "--remove-duplicate",
        action="store_true",
        help="If True, remove duplicate content from the article.",
    )
    parser.add_argument(
        "--vector-db-mode",
        type=str,
        choices=["online", "offline"],
        default="online",
        help="Whether to use an online (Qdrant server) or offline (local) vector database.",
    )
    parser.add_argument(
        "--online-vector-db-url",
        type=str,
        default="http://localhost:6333",
        help="URL of the online Qdrant vector database.",
    )
    parser.add_argument(
        "--offline-vector-db-dir",
        type=str,
        default="./vector_db",
        help="Directory for offline vector database storage.",
    )
    parser.add_argument(
        "--embedding-model",
        type=str,
        default="sentence-transformers/all-MiniLM-L6-v2",
        help="Embedding model to use for vector retrieval.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="Device to run the embedding model on.",
    )
    # Note: --include-citations and --show-urls are parsed but not currently
    # consumed in main(); with action="store_true" and default=True the
    # --include-citations flag is also effectively always on.
    parser.add_argument(
        "--include-citations",
        action="store_true",
        default=True,
        help="Include [1][2]-style citations in the generated text.",
    )
    parser.add_argument(
        "--show-urls",
        type=str,
        choices=["inline", "footnotes", "bibliography", "none"],
        default="inline",
        help="Control how URLs appear in the generated text.",
    )
    args = parser.parse_args()
    main(args)

# Example run:
#   python run_storm_wiki_gpt_with_VectorRM.py --do-research --do-generate-outline --do-generate-article --output-dir ./results
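# Offline vector store variant (flags as defined above; the paths are illustrative):
#   python run_storm_wiki_gpt_with_VectorRM.py --do-research --do-generate-outline \
#       --do-generate-article --do-polish-article \
#       --vector-db-mode offline --offline-vector-db-dir ./vector_db --output-dir ./results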