Created
January 5, 2025 16:26
-
-
Save LK/c55b827f7b82aa7f9fd40dd9754f19d2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from openai import OpenAI | |
import csv | |
import json | |
from typing import List, Optional, Literal, Union | |
from pydantic import BaseModel, Field, constr | |
import tempfile | |
import shutil | |
from pathlib import Path | |
import os | |
from io import StringIO | |
import argparse | |
# Fill in API keys here | |
PERPLEXITY_API_KEY = "" | |
OPENAI_API_KEY = "" | |
# Define valid subspecialties and practice environments | |
Subspecialty = Literal[ | |
"Oculoplastics", | |
"Cornea", | |
"Glaucoma", | |
"Uveitis", | |
"Pediatrics", | |
"Neuro", | |
"Med Retina", | |
"Surg Retina", | |
"Oncology", | |
"Comprehensive", | |
] | |
PracticeEnvironment = Literal["private", "academic"] | |
class OphthalmologistInfo(BaseModel): | |
state: List[str] = Field( | |
default_factory=list, description="List of 2-letter state codes" | |
) | |
state_citations: List[int] = Field( | |
default_factory=list, description="Citation indices for state" | |
) | |
gender: Optional[Literal["M", "F"]] = Field( | |
None, description="Gender of the ophthalmologist" | |
) | |
gender_citations: List[int] = Field( | |
default_factory=list, description="Citation indices for gender" | |
) | |
subspecialties: List[Subspecialty] = Field( | |
default_factory=list, description="List of subspecialties" | |
) | |
subspecialties_citations: List[int] = Field( | |
default_factory=list, description="Citation indices for subspecialties" | |
) | |
practice_environment: List[PracticeEnvironment] = Field( | |
default_factory=list, description="List of practice environments" | |
) | |
practice_environment_citations: List[int] = Field( | |
default_factory=list, description="Citation indices for practice environment" | |
) | |
perplexity_client = OpenAI( | |
api_key=PERPLEXITY_API_KEY, base_url="https://api.perplexity.ai" | |
) | |
openai_client = OpenAI(api_key=OPENAI_API_KEY) | |
class DualLogger: | |
"""Logger that writes to both console and a file.""" | |
def __init__(self, log_dir: str, index: int, npi: str, name: str): | |
"""Initialize logger with log directory and doctor info. | |
Args: | |
log_dir: Directory to store log files | |
index: Index of the doctor being processed (1-based) | |
npi: NPI number of the doctor | |
name: Name of the doctor (will be sanitized for filename) | |
""" | |
self.console_buffer = StringIO() | |
self.index = index | |
# Create logs directory if it doesn't exist | |
os.makedirs(log_dir, exist_ok=True) | |
# Sanitize name for filename (remove special characters and convert to lowercase) | |
first_name = "".join(c.lower() for c in name.split("_")[0] if c.isalnum()) | |
last_name = "".join(c.lower() for c in name.split("_")[1] if c.isalnum()) | |
safe_name = f"{first_name}_{last_name}" | |
# Create filename with format: 0001_<npi>_name.txt | |
filename = f"{index:04d}_{npi}_{safe_name}.txt" | |
self.log_path = os.path.join(log_dir, filename) | |
# Open file for writing | |
self.file = open(self.log_path, "w", encoding="utf-8") | |
def print(self, *args, **kwargs): | |
"""Print to both console and file.""" | |
# Print to console | |
print(*args, **kwargs) | |
# Print to file | |
print(*args, file=self.file, **kwargs) | |
self.file.flush() # Ensure it's written immediately | |
def close(self): | |
"""Close the log file.""" | |
self.file.close() | |
def clean_ophth_data(raw_text: str) -> OphthalmologistInfo: | |
response = openai_client.beta.chat.completions.parse( | |
model="gpt-4o", | |
temperature=0, | |
messages=[ | |
{ | |
"role": "system", | |
"content": """Extract structured information about ophthalmologists from raw text into a JSON object with these fields: | |
- state: list of 2-letter state codes where they practice. If they primarily practice in only one state, prefer to just put that one. Only list multiple when they do not have one obvious primary state. | |
- state_citations: list of citation indices for state information | |
- gender: "M" or "F" (null if not found) | |
- gender_citations: list of citation indices for gender information | |
- subspecialties: list of specialties from [Oculoplastics, Cornea, Glaucoma, Uveitis, Pediatrics, Neuro, Med Retina, Surg Retina, Oncology, Comprehensive]. Use Comprehensive if no specialization is mentioned. If they have a more specific subspecialty than Comprehensive, DO NOT include Comprehensive as a subspecialty. | |
- subspecialties_citations: list of citation indices for subspecialties information | |
- practice_environment: list containing "private" and/or "academic" | |
- practice_environment_citations: list of citation indices for practice environment information | |
For each field, include the indices of citations that support that information. Citation indices start at 1 and correspond to the citations in the text marked with [X]. | |
""", | |
}, | |
{"role": "user", "content": raw_text}, | |
], | |
response_format=OphthalmologistInfo, | |
) | |
try: | |
raw_data = json.loads(response.choices[0].message.content) | |
return OphthalmologistInfo(**raw_data) | |
except (json.JSONDecodeError, IndexError, AttributeError): | |
return OphthalmologistInfo() | |
def get_specialization_info(first_name, last_name, npi, general_info=None): | |
messages = [ | |
{ | |
"role": "system", | |
"content": ( | |
"You are a focused researcher looking ONLY for an ophthalmologist's specialization(s). " | |
"You must find explicit evidence of their specialization(s) from one of these categories: " | |
"[Oculoplastics, Cornea, Glaucoma, Uveitis, Pediatrics, Neuro, Medical Retina, Surgical Retina, Oncology]. " | |
"Focus on finding their fellowship training or board certifications in these areas. " | |
"If you cannot find explicit evidence of one of these specializations, do not guess - " | |
"they should be considered Comprehensive. " | |
"Look especially for terms like 'fellowship trained in X' or 'specializes in X' where X is one of our categories. " | |
"DO NOT infer specialization from conditions they treat or general descriptions. " | |
"If you cannot find explicit specialization evidence, return the string COMPREHENSIVE." | |
"Note that it's *possible* to have more than one specialization." | |
), | |
}, | |
{ | |
"role": "user", | |
"content": ( | |
f"Find ONLY specialization information for ophthalmologist {first_name} {last_name} (NPI {npi}). " | |
f"Focus on fellowship training and board certifications.\n\n" | |
f"Additional biographical information from previous research:\n{general_info if general_info else 'None provided'}" | |
), | |
}, | |
] | |
response = perplexity_client.chat.completions.create( | |
model="llama-3.1-sonar-large-128k-online", | |
messages=messages, | |
temperature=0, | |
) | |
return { | |
"content": response.choices[0].message.content, | |
"citations": response.citations, | |
} | |
def get_general_info(first_name, last_name, npi): | |
messages = [ | |
{ | |
"role": "system", | |
"content": ( | |
"You are a happy and incredibly successful researcher who is finding information on ophthalmologists. " | |
"You DO NOT GIVE UP unless the answer is impossible to find. " | |
"You are going to be given a name and NPI number for an ophthalmologist. " | |
"You must generate a profile including their state of practice, gender, and practice type. " | |
"For practice type:\n" | |
"- Only mark as 'academic' if they have a CURRENT faculty position or academic appointment (e.g., Assistant Professor, Associate Professor, Professor), OR if they practice in an academic setting.\n" | |
"- Only mark as 'private' if you can find information about their specific private practice that is not affiliated with an academic institution.\n" | |
"- Past academic positions, teaching affiliations, or having trained at academic institutions do NOT qualify as 'academic'\n" | |
"- If they only work in private practice or there's no clear evidence of a current academic appointment, mark as 'private'\n" | |
"- If they have both private practice and practice in an academic setting, they may be marked as both 'private' and 'academic', although this is relatively rare.\n" | |
"- It's ok to mark this as UNSURE if you cannot determine.\n" | |
"DO NOT include specialization information as this will be handled separately. " | |
"If you cannot find the ophthalmologist OR you cannot be certain that it's the right one, you must return the string UNCERTAIN. " | |
"The NPI number *must* match the provider you are describing, else you must return the string UNCERTAIN." | |
), | |
}, | |
{ | |
"role": "user", | |
"content": ( | |
f"Please find the relevant information (excluding specialization) of ophthalmologist {first_name} {last_name} (NPI {npi})." | |
), | |
}, | |
] | |
response = perplexity_client.chat.completions.create( | |
model="llama-3.1-sonar-large-128k-online", | |
messages=messages, | |
temperature=0, | |
) | |
return { | |
"content": response.choices[0].message.content, | |
"citations": response.citations, | |
} | |
def format_list_for_csv(items: list, is_sources: bool = False) -> str: | |
"""Format a list for CSV storage in a clean, readable way. | |
Args: | |
items: List of items to format | |
is_sources: If True, format as "[1]url [2]url" instead of "url; url" | |
""" | |
if not items: | |
return "" | |
if is_sources: | |
return " ".join(f"[{i+1}]{url}" for i, url in enumerate(items)) | |
return "; ".join(str(item) for item in items) | |
def format_value_with_citations( | |
value: Union[str, List[str]], citations: List[int] | |
) -> tuple[str, str]: | |
"""Format a value and its citations separately. | |
Args: | |
value: Either a string or list of strings to format | |
citations: List of citation indices | |
Returns: | |
Tuple of (formatted_value, formatted_citations) | |
""" | |
formatted_value = ( | |
value if isinstance(value, str) else "; ".join(value) if value else "" | |
) | |
formatted_citations = format_list_for_csv(sorted(citations)) if citations else "" | |
return formatted_value, formatted_citations | |
def is_row_processed(row): | |
"""Check if a row has already been processed by checking if any computed fields have values.""" | |
computed_fields = [ | |
"Computed_State", | |
"Computed_Gender", | |
"Computed_Subspecialties", | |
"Computed_Practice_Environment", | |
"Computed_Sources", | |
] | |
return any(row.get(field, "").strip() for field in computed_fields) | |
def update_csv_with_new_data( | |
input_file: str, | |
start_row: int = 0, | |
num_rows: int = 10, | |
continue_processing: bool = False, | |
): | |
# First create a backup of the original file | |
backup_file = Path(input_file).with_suffix(".bak") | |
shutil.copy2(input_file, backup_file) | |
# Create logs directory next to the input file | |
log_dir = os.path.join(os.path.dirname(os.path.abspath(input_file)), "reasoning") | |
# Keep the backup file around in case we need it | |
print(f"\nBackup of original file saved as: {backup_file}") | |
# Get total number of rows | |
total_rows = sum(1 for line in open(backup_file)) - 1 # subtract header row | |
# Define which original columns to keep | |
original_columns = [ | |
"Physician_NPI", | |
"Physician_Last_Name", | |
"Physician_First_Name", | |
"Physician State", | |
"Physician_Gender", | |
"Subspecialty", | |
"Practice Environment", | |
] | |
# Define new columns to add | |
new_columns = [ | |
"Computed_State", | |
"Computed_State_Citations", | |
"Computed_Gender", | |
"Computed_Gender_Citations", | |
"Computed_Subspecialties", | |
"Computed_Subspecialties_Citations", | |
"Computed_Practice_Environment", | |
"Computed_Practice_Environment_Citations", | |
"Computed_Sources", | |
] | |
# Create new fieldnames list | |
new_fieldnames = original_columns + new_columns | |
if not continue_processing: | |
# First, write the entire file with empty computed fields only for unprocessed rows | |
with open(backup_file, "r") as file_in, open( | |
input_file, "w", newline="" | |
) as file_out: | |
reader = csv.DictReader(file_in) | |
writer = csv.DictWriter(file_out, fieldnames=new_fieldnames) | |
writer.writeheader() | |
for row in reader: | |
row = {k.strip(): v.strip() for k, v in row.items()} | |
new_row = {col: row.get(col, "") for col in original_columns} | |
# Preserve existing computed fields if they exist, otherwise initialize as empty | |
for col in new_columns: | |
new_row[col] = row.get(col, "").strip() | |
writer.writerow(new_row) | |
# Read all rows | |
with open(input_file, "r") as file: | |
reader = csv.DictReader(file) | |
rows = list(reader) | |
# Determine which rows need processing | |
rows_to_process = [] | |
for i in range(start_row, min(start_row + num_rows, total_rows)): | |
if not continue_processing or not is_row_processed(rows[i]): | |
rows_to_process.append(i) | |
print(f"Found {len(rows_to_process)} rows to process") | |
# Now process and update rows one by one | |
for i in rows_to_process: | |
# Create logger for this doctor - use i+1 for the actual CSV row number (1-based) | |
logger = DualLogger( | |
log_dir, | |
i + 1, # Use actual CSV row number (1-based) | |
rows[i]["Physician_NPI"], | |
f"{rows[i]['Physician_First_Name']}_{rows[i]['Physician_Last_Name']}", | |
) | |
try: | |
logger.print("=" * 80) | |
logger.print(f"Processing Doctor at row {i+1} of {total_rows}") | |
logger.print( | |
f"Name: {rows[i]['Physician_First_Name']} {rows[i]['Physician_Last_Name']}" | |
) | |
logger.print(f"NPI: {rows[i]['Physician_NPI']}") | |
logger.print("=" * 80) | |
# Get general information from Perplexity | |
info = get_general_info( | |
rows[i]["Physician_First_Name"], | |
rows[i]["Physician_Last_Name"], | |
rows[i]["Physician_NPI"], | |
) | |
raw_info = info["content"] | |
citations = info["citations"] | |
# Get specialization information from Perplexity, passing along general info | |
spec_info = get_specialization_info( | |
rows[i]["Physician_First_Name"], | |
rows[i]["Physician_Last_Name"], | |
rows[i]["Physician_NPI"], | |
raw_info, # Pass the general info here | |
) | |
spec_raw_info = spec_info["content"] | |
spec_citations = spec_info["citations"] | |
# Create citation mapping for deduplication | |
citation_map = {citation: idx + 1 for idx, citation in enumerate(citations)} | |
next_index = len(citations) + 1 | |
spec_citation_map = {} | |
adjusted_spec_citations = [] | |
for spec_citation in spec_citations: | |
if spec_citation in citation_map: | |
# Reuse the existing index for this citation | |
spec_citation_map[spec_citation] = citation_map[spec_citation] | |
adjusted_spec_citations.append(citation_map[spec_citation]) | |
else: | |
# Add new citation with next available index | |
citation_map[spec_citation] = next_index | |
spec_citation_map[spec_citation] = next_index | |
adjusted_spec_citations.append(next_index) | |
next_index += 1 | |
# Create final deduplicated citations list | |
all_citations = list(dict.fromkeys(citations + spec_citations)) | |
# Adjust citation numbers in both texts | |
general_info_text = raw_info | |
adjusted_spec_text = spec_raw_info | |
for citation_idx in range(len(spec_citations), 0, -1): | |
original_citation = spec_citations[citation_idx - 1] | |
new_index = spec_citation_map[original_citation] | |
# Replace [i] with the deduplicated index | |
adjusted_spec_text = adjusted_spec_text.replace( | |
f"[{citation_idx}]", f"[{new_index}]" | |
) | |
# Print responses with corrected citation indices | |
logger.print("\nGeneral Information:") | |
logger.print("-" * 40) | |
logger.print(general_info_text) | |
logger.print("\nSpecialization Analysis:") | |
logger.print("-" * 40) | |
logger.print(adjusted_spec_text) | |
# Print consolidated sources list | |
if all_citations: | |
logger.print("\nSources:") | |
logger.print("-" * 40) | |
for idx, citation in enumerate(all_citations, 1): | |
logger.print(f"{idx}. {citation}") | |
if "UNCERTAIN" in raw_info: | |
logger.print("\nStatus: Could not find reliable information") | |
rows[i]["Computed_State"] = "" | |
rows[i]["Computed_State_Citations"] = "" | |
rows[i]["Computed_Gender"] = "" | |
rows[i]["Computed_Gender_Citations"] = "" | |
rows[i]["Computed_Subspecialties"] = "" | |
rows[i]["Computed_Subspecialties_Citations"] = "" | |
rows[i]["Computed_Practice_Environment"] = "" | |
rows[i]["Computed_Practice_Environment_Citations"] = "" | |
rows[i]["Computed_Sources"] = "" | |
else: | |
# Create combined text for OpenAI processing | |
combined_text = f"General Information:\n{general_info_text}\n\nSpecialization Information:\n{adjusted_spec_text}" | |
structured_info = clean_ophth_data(combined_text) | |
# Print structured information | |
logger.print("\nStructured Information:") | |
logger.print("-" * 40) | |
logger.print(f"State: {structured_info.state or 'Not found'}") | |
if structured_info.state_citations: | |
logger.print( | |
f" Citations: {format_list_for_csv(map(str, structured_info.state_citations))}" | |
) | |
logger.print(f"Gender: {structured_info.gender or 'Not found'}") | |
if structured_info.gender_citations: | |
logger.print( | |
f" Citations: {format_list_for_csv(map(str, structured_info.gender_citations))}" | |
) | |
# Format subspecialties for display | |
subspecialties = sorted(structured_info.subspecialties) | |
if subspecialties: | |
logger.print("Subspecialties:") | |
for idx, specialty in enumerate(subspecialties, 1): | |
logger.print(f" {idx}. {specialty}") | |
if structured_info.subspecialties_citations: | |
logger.print( | |
f" Citations: {format_list_for_csv(map(str, structured_info.subspecialties_citations))}" | |
) | |
else: | |
logger.print("Subspecialties: None") | |
# Format practice environment for display | |
practice_env = sorted(structured_info.practice_environment) | |
if practice_env: | |
logger.print("Practice Environment:") | |
for idx, env in enumerate(practice_env, 1): | |
logger.print(f" {idx}. {env}") | |
if structured_info.practice_environment_citations: | |
logger.print( | |
f" Citations: {format_list_for_csv(map(str, structured_info.practice_environment_citations))}" | |
) | |
else: | |
logger.print("Practice Environment: None") | |
# Update row with new data | |
state_value, state_citations = format_value_with_citations( | |
structured_info.state or "", structured_info.state_citations | |
) | |
rows[i]["Computed_State"] = state_value | |
rows[i]["Computed_State_Citations"] = state_citations | |
gender_value, gender_citations = format_value_with_citations( | |
structured_info.gender or "", structured_info.gender_citations | |
) | |
rows[i]["Computed_Gender"] = gender_value | |
rows[i]["Computed_Gender_Citations"] = gender_citations | |
subspecialties_value, subspecialties_citations = ( | |
format_value_with_citations( | |
structured_info.subspecialties, | |
structured_info.subspecialties_citations, | |
) | |
) | |
rows[i]["Computed_Subspecialties"] = subspecialties_value | |
rows[i]["Computed_Subspecialties_Citations"] = subspecialties_citations | |
practice_env_value, practice_env_citations = ( | |
format_value_with_citations( | |
structured_info.practice_environment, | |
structured_info.practice_environment_citations, | |
) | |
) | |
rows[i]["Computed_Practice_Environment"] = practice_env_value | |
rows[i][ | |
"Computed_Practice_Environment_Citations" | |
] = practice_env_citations | |
rows[i]["Computed_Sources"] = format_list_for_csv( | |
all_citations, is_sources=True | |
) | |
except Exception as e: | |
logger.print(f"\nError processing doctor: {str(e)}") | |
# On error, ensure computed fields are empty | |
rows[i]["Computed_State"] = "" | |
rows[i]["Computed_State_Citations"] = "" | |
rows[i]["Computed_Gender"] = "" | |
rows[i]["Computed_Gender_Citations"] = "" | |
rows[i]["Computed_Subspecialties"] = "" | |
rows[i]["Computed_Subspecialties_Citations"] = "" | |
rows[i]["Computed_Practice_Environment"] = "" | |
rows[i]["Computed_Practice_Environment_Citations"] = "" | |
rows[i]["Computed_Sources"] = "" | |
finally: | |
# Always close the logger | |
logger.close() | |
# Write all rows back to the file after each row is processed | |
with open(input_file, "w", newline="") as file_out: | |
writer = csv.DictWriter(file_out, fieldnames=new_fieldnames) | |
writer.writeheader() | |
writer.writerows(rows) | |
def main(): | |
parser = argparse.ArgumentParser(description="Process ophthalmologist data") | |
parser.add_argument( | |
"--continue", | |
dest="continue_processing", | |
action="store_true", | |
help="Continue processing from where we left off, skipping already processed rows", | |
) | |
parser.add_argument( | |
"--start", | |
type=int, | |
default=0, | |
help="Start processing from this row number (0-based)", | |
) | |
parser.add_argument( | |
"--limit", | |
type=int, | |
default=10, | |
help="Number of rows to process (ignored if --continue is used)", | |
) | |
args = parser.parse_args() | |
# If --continue is used, set num_rows to a very large number to process all remaining rows | |
num_rows = 1_000_000 if args.continue_processing else args.limit | |
update_csv_with_new_data( | |
"ophth.csv", args.start, num_rows, args.continue_processing | |
) | |
if __name__ == "__main__": | |
main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment