Created
March 19, 2024 10:04
-
-
Save kshitijcode/af9f179db7298bcad1572a1ce068aea9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import aiohttp | |
import argparse | |
import asyncio | |
import json | |
import logging | |
import re | |
import tiktoken | |
import time | |
from status_tracker import StatusTracker | |
from api_request import APIRequest | |
import os | |
import csv | |
async def process_api_requests_from_file(
    requests_filepath: str,
    save_filepath: str,
    request_url: str,
    api_key: str,
    token_encoding_name: str,
    max_attempts: int,
    logging_level: int,
    status_tracker: StatusTracker,
):
    """Send every request in a JSONL file to the API concurrently.

    Each non-blank line of ``requests_filepath`` is parsed as a JSON object
    and wrapped in an ``APIRequest``. An optional top-level ``"metadata"`` key
    is removed from the payload and handed to the request separately.
    Requests found on the retry queue (which is passed to
    ``APIRequest.call_api``) are dispatched before new lines are read. Each
    loop iteration dispatches at most one request, then sleeps 0.1 s.

    Args:
        requests_filepath: Path to the input JSONL file (one request per line).
        save_filepath: Output path, forwarded to ``APIRequest.call_api``.
        request_url: Endpoint URL, forwarded to ``APIRequest.call_api``.
        api_key: Sent as ``Authorization: Bearer <api_key>``.
        token_encoding_name: Not used by this function; kept for interface
            compatibility.
        max_attempts: Initial ``attempts_left`` for each new request.
        logging_level: Level passed to ``logging.basicConfig``.
        status_tracker: Shared counters. This function increments
            ``num_tasks_started`` and ``num_tasks_in_progress``. The loop only
            exits once the file is exhausted and ``num_tasks_in_progress`` is
            back to zero (presumably decremented by ``APIRequest.call_api``).
    """
    logging.basicConfig(level=logging_level)
    logging.debug(f"Logging initialized at level {logging_level}")

    # The endpoint itself is not needed for dispatch. Computing it makes an
    # unrecognized URL shape raise here, before any work is started.
    api_endpoint = api_endpoint_from_url(request_url)
    logging.debug(f"Inferred API endpoint: {api_endpoint}")
    request_header = {"Authorization": f"Bearer {api_key}"}

    queue_of_requests_to_retry = asyncio.Queue()
    task_id_generator = task_id_generator_function()
    # The event loop keeps only weak references to tasks. Hold strong
    # references until each task finishes, or it may be garbage-collected
    # mid-execution.
    background_tasks = set()
    next_request = None
    file_not_finished = True

    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(requests_filepath, encoding="utf-8") as file:
        requests = iter(file)
        logging.debug("File opened. Entering main loop")
        async with aiohttp.ClientSession() as session:
            while True:
                # Retries take priority over reading new lines from the file.
                if next_request is None and not queue_of_requests_to_retry.empty():
                    next_request = queue_of_requests_to_retry.get_nowait()
                    logging.debug(f"Retrying request {next_request.task_id}")
                elif next_request is None and file_not_finished:
                    try:
                        line = next(requests)
                    except StopIteration:
                        file_not_finished = False
                    else:
                        # Skip blank lines (e.g. a trailing newline) instead of
                        # crashing json.loads on them.
                        if line.strip():
                            request_json = json.loads(line)
                            metadata = request_json.pop("metadata", None)
                            next_request = APIRequest(
                                task_id=next(task_id_generator),
                                request_json=request_json,
                                token_consumption=0,  # Token consumption is not tracked
                                attempts_left=max_attempts,
                                metadata=metadata,
                            )
                            status_tracker.num_tasks_started += 1
                            status_tracker.num_tasks_in_progress += 1

                if next_request is not None:
                    next_request.attempts_left -= 1
                    task = asyncio.create_task(
                        next_request.call_api(
                            session=session,
                            request_url=request_url,
                            request_header=request_header,
                            retry_queue=queue_of_requests_to_retry,
                            save_filepath=save_filepath,
                            status_tracker=status_tracker,
                        )
                    )
                    background_tasks.add(task)
                    task.add_done_callback(background_tasks.discard)
                    next_request = None

                if status_tracker.num_tasks_in_progress == 0 and not file_not_finished:
                    break
                await asyncio.sleep(0.1)  # Sleep to allow other tasks to process

            # A task may still be wrapping up after the counters reach zero.
            # Let it finish before the session it uses is closed.
            if background_tasks:
                await asyncio.wait(set(background_tasks))

    logging.info(f"Parallel processing complete. Results saved to {save_filepath}")
def api_endpoint_from_url(request_url):
    """Extract the API endpoint path (e.g. ``"chat/completions"``) from a URL.

    Two URL shapes are recognized:

    * OpenAI-style ``https://<host>/v<N>/<endpoint>``
    * Azure-style ``https://<host>/openai/deployments/<deployment>/<endpoint>``

    For both shapes, any query string (such as ``?api-version=...``) is
    excluded from the returned endpoint.

    Raises:
        ValueError: If ``request_url`` matches neither shape.
    """
    patterns = (
        r"^https://[^/]+/v\d+/(.+?)(?:\?|$)",
        r"^https://[^/]+/openai/deployments/[^/]+/(.+?)(?:\?|$)",
    )
    for pattern in patterns:
        match = re.search(pattern, request_url)
        if match is not None:
            return match[1]
    raise ValueError(f"Unrecognized API request URL: {request_url!r}")
def task_id_generator_function(start: int = 0):
    """Yield consecutive integer task IDs ``start, start + 1, ...`` forever.

    Args:
        start: First ID to yield. The default of 0 matches the original
            behavior.
    """
    import itertools

    yield from itertools.count(start)
def default_save_filepath(requests_filepath: str) -> str:
    """Derive the results path from the requests path.

    ``data/x.jsonl`` becomes ``data/x_results.jsonl``. Only the final
    extension is considered. A path without an extension gets
    ``_results.jsonl`` appended. This ensures the output path never equals
    the input path, which a plain ``str.replace(".jsonl", ...)`` would
    return for inputs lacking ``.jsonl``.
    """
    root, ext = os.path.splitext(requests_filepath)
    return f"{root}_results{ext or '.jsonl'}"


def parse_logging_level(value) -> int:
    """Convert a ``--logging_level`` value to a numeric logging level.

    Accepts an int, a numeric string (``"10"``), or a level name in any case
    (``"debug"``).

    Raises:
        argparse.ArgumentTypeError: If ``value`` names no known level.
    """
    if isinstance(value, int):
        return value
    text = str(value).strip()
    if text.lstrip("-").isdigit():
        return int(text)
    # getLevelName maps a registered level name back to its number; unknown
    # names come back as the string "Level <name>".
    level = logging.getLevelName(text.upper())
    if not isinstance(level, int):
        raise argparse.ArgumentTypeError(f"unknown logging level: {value!r}")
    return level


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--requests_filepath", default='data/real_data_modified.jsonl')
    parser.add_argument("--save_filepath", default=None)
    parser.add_argument("--request_url", default="https://openai-spike-wm.openai.azure.com/openai/deployments/gpt-4-32-k/chat/completions?api-version=2023-07-01-preview")
    parser.add_argument("--api_key", default="")
    parser.add_argument("--token_encoding_name", default="cl100k_base")
    parser.add_argument("--max_attempts", type=int, default=30)
    parser.add_argument("--logging_level", type=parse_logging_level, default=logging.INFO)
    args = parser.parse_args()
    if args.save_filepath is None:
        args.save_filepath = default_save_filepath(args.requests_filepath)

    status_tracker = StatusTracker()
    asyncio.run(
        process_api_requests_from_file(
            requests_filepath=args.requests_filepath,
            save_filepath=args.save_filepath,
            request_url=args.request_url,
            api_key=args.api_key,
            token_encoding_name=args.token_encoding_name,
            max_attempts=args.max_attempts,
            logging_level=args.logging_level,
            status_tracker=status_tracker,
        )
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment