from openai import OpenAI
import csv
import json
from typing import List, Optional, Literal, Union
from pydantic import BaseModel, Field, ValidationError
import shutil
from pathlib import Path
import os
import argparse

# Fill in API keys here
PERPLEXITY_API_KEY = ""
OPENAI_API_KEY = ""

# Define valid subspecialties and practice environments
Subspecialty = Literal[
    "Oculoplastics",
    "Cornea",
    "Glaucoma",
    "Uveitis",
    "Pediatrics",
    "Neuro",
    "Med Retina",
    "Surg Retina",
    "Oncology",
    "Comprehensive",
]
PracticeEnvironment = Literal["private", "academic"]

class OphthalmologistInfo(BaseModel):
    state: List[str] = Field(
        default_factory=list, description="List of 2-letter state codes"
    )
    state_citations: List[int] = Field(
        default_factory=list, description="Citation indices for state"
    )
    gender: Optional[Literal["M", "F"]] = Field(
        None, description="Gender of the ophthalmologist"
    )
    gender_citations: List[int] = Field(
        default_factory=list, description="Citation indices for gender"
    )
    subspecialties: List[Subspecialty] = Field(
        default_factory=list, description="List of subspecialties"
    )
    subspecialties_citations: List[int] = Field(
        default_factory=list, description="Citation indices for subspecialties"
    )
    practice_environment: List[PracticeEnvironment] = Field(
        default_factory=list, description="List of practice environments"
    )
    practice_environment_citations: List[int] = Field(
        default_factory=list, description="Citation indices for practice environment"
    )

perplexity_client = OpenAI(
    api_key=PERPLEXITY_API_KEY, base_url="https://api.perplexity.ai"
)
openai_client = OpenAI(api_key=OPENAI_API_KEY)

class DualLogger:
    """Logger that writes to both console and a file."""

    def __init__(self, log_dir: str, index: int, npi: str, name: str):
        """Initialize logger with log directory and doctor info.

        Args:
            log_dir: Directory to store log files
            index: Index of the doctor being processed (1-based)
            npi: NPI number of the doctor
            name: Name of the doctor as "First_Last" (sanitized for the filename)
        """
        self.index = index
        # Create the logs directory if it doesn't exist
        os.makedirs(log_dir, exist_ok=True)
        # Sanitize the name for the filename (keep alphanumerics, lowercase).
        # Split on the first underscore only, so underscores inside the last
        # name don't truncate it.
        first, _, last = name.partition("_")
        first_name = "".join(c.lower() for c in first if c.isalnum())
        last_name = "".join(c.lower() for c in last if c.isalnum())
        safe_name = f"{first_name}_{last_name}"
        # Create filename with format: 0001_<npi>_name.txt
        filename = f"{index:04d}_{npi}_{safe_name}.txt"
        self.log_path = os.path.join(log_dir, filename)
        # Open file for writing
        self.file = open(self.log_path, "w", encoding="utf-8")

    def print(self, *args, **kwargs):
        """Print to both console and file."""
        print(*args, **kwargs)
        print(*args, file=self.file, **kwargs)
        self.file.flush()  # Ensure it's written immediately

    def close(self):
        """Close the log file."""
        self.file.close()
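
# Usage sketch (hypothetical values): DualLogger("reasoning", 1, "1234567890",
# "Jane_Doe") writes to reasoning/0001_1234567890_jane_doe.txt; call
# logger.print(...) wherever print(...) would go, then logger.close().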

def clean_ophth_data(raw_text: str) -> OphthalmologistInfo:
    response = openai_client.beta.chat.completions.parse(
        model="gpt-4o",
        temperature=0,
        messages=[
            {
                "role": "system",
                "content": """Extract structured information about ophthalmologists from raw text into a JSON object with these fields:
- state: list of 2-letter state codes where they practice. If they primarily practice in only one state, prefer to just put that one. Only list multiple when they do not have one obvious primary state.
- state_citations: list of citation indices for state information
- gender: "M" or "F" (null if not found)
- gender_citations: list of citation indices for gender information
- subspecialties: list of specialties from [Oculoplastics, Cornea, Glaucoma, Uveitis, Pediatrics, Neuro, Med Retina, Surg Retina, Oncology, Comprehensive]. Use Comprehensive if no specialization is mentioned. If they have a more specific subspecialty than Comprehensive, DO NOT include Comprehensive as a subspecialty.
- subspecialties_citations: list of citation indices for subspecialties information
- practice_environment: list containing "private" and/or "academic"
- practice_environment_citations: list of citation indices for practice environment information
For each field, include the indices of citations that support that information. Citation indices start at 1 and correspond to the citations in the text marked with [X].
""",
            },
            {"role": "user", "content": raw_text},
        ],
        response_format=OphthalmologistInfo,
    )
    try:
        raw_data = json.loads(response.choices[0].message.content)
        return OphthalmologistInfo(**raw_data)
    except (json.JSONDecodeError, TypeError, ValidationError, IndexError, AttributeError):
        # Fall back to an empty model if the response is missing or malformed
        # (content can be None on a refusal, and validation can fail)
        return OphthalmologistInfo()
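
# Expected behavior, sketched with a hypothetical input: text like
# "Dr. Jane Doe practices in Boston, MA [1] and is fellowship trained in
# cornea [2]" should come back roughly as state=["MA"], state_citations=[1],
# subspecialties=["Cornea"], subspecialties_citations=[2]; an unparseable
# response falls back to an empty OphthalmologistInfo().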

def get_specialization_info(first_name, last_name, npi, general_info=None):
    messages = [
        {
            "role": "system",
            "content": (
                "You are a focused researcher looking ONLY for an ophthalmologist's specialization(s). "
                "You must find explicit evidence of their specialization(s) from one of these categories: "
                "[Oculoplastics, Cornea, Glaucoma, Uveitis, Pediatrics, Neuro, Medical Retina, Surgical Retina, Oncology]. "
                "Focus on finding their fellowship training or board certifications in these areas. "
                "If you cannot find explicit evidence of one of these specializations, do not guess - "
                "they should be considered Comprehensive. "
                "Look especially for terms like 'fellowship trained in X' or 'specializes in X' where X is one of our categories. "
                "DO NOT infer specialization from conditions they treat or general descriptions. "
                "If you cannot find explicit specialization evidence, return the string COMPREHENSIVE. "
                "Note that it's *possible* to have more than one specialization."
            ),
        },
        {
            "role": "user",
            "content": (
                f"Find ONLY specialization information for ophthalmologist {first_name} {last_name} (NPI {npi}). "
                f"Focus on fellowship training and board certifications.\n\n"
                f"Additional biographical information from previous research:\n{general_info if general_info else 'None provided'}"
            ),
        },
    ]
    response = perplexity_client.chat.completions.create(
        model="llama-3.1-sonar-large-128k-online",
        messages=messages,
        temperature=0,
    )
    return {
        "content": response.choices[0].message.content,
        "citations": response.citations,
    }

def get_general_info(first_name, last_name, npi):
    messages = [
        {
            "role": "system",
            "content": (
                "You are a happy and incredibly successful researcher who is finding information on ophthalmologists. "
                "You DO NOT GIVE UP unless the answer is impossible to find. "
                "You are going to be given a name and NPI number for an ophthalmologist. "
                "You must generate a profile including their state of practice, gender, and practice type. "
                "For practice type:\n"
                "- Only mark as 'academic' if they have a CURRENT faculty position or academic appointment (e.g., Assistant Professor, Associate Professor, Professor), OR if they practice in an academic setting.\n"
                "- Only mark as 'private' if you can find information about their specific private practice that is not affiliated with an academic institution.\n"
                "- Past academic positions, teaching affiliations, or having trained at academic institutions do NOT qualify as 'academic'\n"
                "- If they only work in private practice or there's no clear evidence of a current academic appointment, mark as 'private'\n"
                "- If they have both private practice and practice in an academic setting, they may be marked as both 'private' and 'academic', although this is relatively rare.\n"
                "- It's ok to mark this as UNSURE if you cannot determine.\n"
                "DO NOT include specialization information as this will be handled separately. "
                "If you cannot find the ophthalmologist OR you cannot be certain that it's the right one, you must return the string UNCERTAIN. "
                "The NPI number *must* match the provider you are describing, else you must return the string UNCERTAIN."
            ),
        },
        {
            "role": "user",
            "content": (
                f"Please find the relevant information (excluding specialization) of ophthalmologist {first_name} {last_name} (NPI {npi})."
            ),
        },
    ]
    response = perplexity_client.chat.completions.create(
        model="llama-3.1-sonar-large-128k-online",
        messages=messages,
        temperature=0,
    )
    return {
        "content": response.choices[0].message.content,
        "citations": response.citations,
    }

def format_list_for_csv(items: list, is_sources: bool = False) -> str:
    """Format a list for CSV storage in a clean, readable way.

    Args:
        items: List of items to format
        is_sources: If True, format as "[1]url [2]url" instead of "url; url"
    """
    if not items:
        return ""
    if is_sources:
        return " ".join(f"[{i+1}]{url}" for i, url in enumerate(items))
    return "; ".join(str(item) for item in items)

def format_value_with_citations(
    value: Union[str, List[str]], citations: List[int]
) -> tuple[str, str]:
    """Format a value and its citations separately.

    Args:
        value: Either a string or list of strings to format
        citations: List of citation indices

    Returns:
        Tuple of (formatted_value, formatted_citations)
    """
    formatted_value = (
        value if isinstance(value, str) else "; ".join(value) if value else ""
    )
    formatted_citations = format_list_for_csv(sorted(citations)) if citations else ""
    return formatted_value, formatted_citations
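
# Example: format_value_with_citations(["CA", "NY"], [2, 1]) -> ("CA; NY", "1; 2")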

def is_row_processed(row):
    """Check if a row has already been processed by checking if any computed fields have values."""
    computed_fields = [
        "Computed_State",
        "Computed_Gender",
        "Computed_Subspecialties",
        "Computed_Practice_Environment",
        "Computed_Sources",
    ]
    return any(row.get(field, "").strip() for field in computed_fields)

def update_csv_with_new_data(
    input_file: str,
    start_row: int = 0,
    num_rows: int = 10,
    continue_processing: bool = False,
):
    # First create a backup of the original file
    backup_file = Path(input_file).with_suffix(".bak")
    shutil.copy2(input_file, backup_file)
    # Create logs directory next to the input file
    log_dir = os.path.join(os.path.dirname(os.path.abspath(input_file)), "reasoning")
    # Keep the backup file around in case we need it
    print(f"\nBackup of original file saved as: {backup_file}")
    # Define which original columns to keep
    original_columns = [
        "Physician_NPI",
        "Physician_Last_Name",
        "Physician_First_Name",
        "Physician State",
        "Physician_Gender",
        "Subspecialty",
        "Practice Environment",
    ]
    # Define new columns to add
    new_columns = [
        "Computed_State",
        "Computed_State_Citations",
        "Computed_Gender",
        "Computed_Gender_Citations",
        "Computed_Subspecialties",
        "Computed_Subspecialties_Citations",
        "Computed_Practice_Environment",
        "Computed_Practice_Environment_Citations",
        "Computed_Sources",
    ]
    # Create new fieldnames list
    new_fieldnames = original_columns + new_columns
    if not continue_processing:
        # Normalize the file to the new schema, preserving any computed
        # fields that already have values
        with open(backup_file, "r") as file_in, open(
            input_file, "w", newline=""
        ) as file_out:
            reader = csv.DictReader(file_in)
            writer = csv.DictWriter(file_out, fieldnames=new_fieldnames)
            writer.writeheader()
            for row in reader:
                row = {k.strip(): v.strip() for k, v in row.items()}
                new_row = {col: row.get(col, "") for col in original_columns}
                # Preserve existing computed fields if present, else initialize empty
                for col in new_columns:
                    new_row[col] = row.get(col, "").strip()
                writer.writerow(new_row)
    # Read all rows
    with open(input_file, "r") as file:
        reader = csv.DictReader(file)
        rows = list(reader)
    # Count parsed records rather than raw lines, so quoted newlines inside
    # fields don't skew the total
    total_rows = len(rows)
    # Determine which rows need processing
    rows_to_process = []
    for i in range(start_row, min(start_row + num_rows, total_rows)):
        if not continue_processing or not is_row_processed(rows[i]):
            rows_to_process.append(i)
    print(f"Found {len(rows_to_process)} rows to process")
    # Now process and update rows one by one
    for i in rows_to_process:
        # Create a logger for this doctor; i + 1 is the CSV row number (1-based)
        logger = DualLogger(
            log_dir,
            i + 1,
            rows[i]["Physician_NPI"],
            f"{rows[i]['Physician_First_Name']}_{rows[i]['Physician_Last_Name']}",
        )
        try:
            logger.print("=" * 80)
            logger.print(f"Processing Doctor at row {i+1} of {total_rows}")
            logger.print(
                f"Name: {rows[i]['Physician_First_Name']} {rows[i]['Physician_Last_Name']}"
            )
            logger.print(f"NPI: {rows[i]['Physician_NPI']}")
            logger.print("=" * 80)
            # Get general information from Perplexity
            info = get_general_info(
                rows[i]["Physician_First_Name"],
                rows[i]["Physician_Last_Name"],
                rows[i]["Physician_NPI"],
            )
            raw_info = info["content"]
            citations = info["citations"]
            # Get specialization information from Perplexity, passing along general info
            spec_info = get_specialization_info(
                rows[i]["Physician_First_Name"],
                rows[i]["Physician_Last_Name"],
                rows[i]["Physician_NPI"],
                raw_info,  # Pass the general info here
            )
            spec_raw_info = spec_info["content"]
            spec_citations = spec_info["citations"]
            # Create citation mapping for deduplication
            citation_map = {citation: idx + 1 for idx, citation in enumerate(citations)}
            next_index = len(citations) + 1
            spec_citation_map = {}
            adjusted_spec_citations = []
            for spec_citation in spec_citations:
                if spec_citation in citation_map:
                    # Reuse the existing index for this citation
                    spec_citation_map[spec_citation] = citation_map[spec_citation]
                    adjusted_spec_citations.append(citation_map[spec_citation])
                else:
                    # Add new citation with next available index
                    citation_map[spec_citation] = next_index
                    spec_citation_map[spec_citation] = next_index
                    adjusted_spec_citations.append(next_index)
                    next_index += 1
            # Create final deduplicated citations list
            all_citations = list(dict.fromkeys(citations + spec_citations))
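            # Worked example: if the general pass cites [urlA, urlB] and the
            # specialization pass cites [urlB, urlC], urlB keeps index 2,
            # urlC becomes index 3, and all_citations is [urlA, urlB, urlC].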
            # Adjust citation numbers in the specialization text. Go through a
            # placeholder pass first so that an already-renumbered index can
            # never be re-matched and rewritten by a later replacement.
            general_info_text = raw_info
            adjusted_spec_text = spec_raw_info
            for citation_idx in range(1, len(spec_citations) + 1):
                adjusted_spec_text = adjusted_spec_text.replace(
                    f"[{citation_idx}]", f"\x00{citation_idx}\x00"
                )
            for citation_idx, original_citation in enumerate(spec_citations, 1):
                new_index = spec_citation_map[original_citation]
                adjusted_spec_text = adjusted_spec_text.replace(
                    f"\x00{citation_idx}\x00", f"[{new_index}]"
                )
            # Print responses with corrected citation indices
            logger.print("\nGeneral Information:")
            logger.print("-" * 40)
            logger.print(general_info_text)
            logger.print("\nSpecialization Analysis:")
            logger.print("-" * 40)
            logger.print(adjusted_spec_text)
            # Print consolidated sources list
            if all_citations:
                logger.print("\nSources:")
                logger.print("-" * 40)
                for idx, citation in enumerate(all_citations, 1):
                    logger.print(f"{idx}. {citation}")
if "UNCERTAIN" in raw_info:
logger.print("\nStatus: Could not find reliable information")
rows[i]["Computed_State"] = ""
rows[i]["Computed_State_Citations"] = ""
rows[i]["Computed_Gender"] = ""
rows[i]["Computed_Gender_Citations"] = ""
rows[i]["Computed_Subspecialties"] = ""
rows[i]["Computed_Subspecialties_Citations"] = ""
rows[i]["Computed_Practice_Environment"] = ""
rows[i]["Computed_Practice_Environment_Citations"] = ""
rows[i]["Computed_Sources"] = ""
            else:
                # Create combined text for OpenAI processing
                combined_text = f"General Information:\n{general_info_text}\n\nSpecialization Information:\n{adjusted_spec_text}"
                structured_info = clean_ophth_data(combined_text)
                # Print structured information
                logger.print("\nStructured Information:")
                logger.print("-" * 40)
                logger.print(f"State: {structured_info.state or 'Not found'}")
                if structured_info.state_citations:
                    logger.print(
                        f"  Citations: {format_list_for_csv([str(c) for c in structured_info.state_citations])}"
                    )
                logger.print(f"Gender: {structured_info.gender or 'Not found'}")
                if structured_info.gender_citations:
                    logger.print(
                        f"  Citations: {format_list_for_csv([str(c) for c in structured_info.gender_citations])}"
                    )
                # Format subspecialties for display
                subspecialties = sorted(structured_info.subspecialties)
                if subspecialties:
                    logger.print("Subspecialties:")
                    for idx, specialty in enumerate(subspecialties, 1):
                        logger.print(f"  {idx}. {specialty}")
                    if structured_info.subspecialties_citations:
                        logger.print(
                            f"  Citations: {format_list_for_csv([str(c) for c in structured_info.subspecialties_citations])}"
                        )
                else:
                    logger.print("Subspecialties: None")
                # Format practice environment for display
                practice_env = sorted(structured_info.practice_environment)
                if practice_env:
                    logger.print("Practice Environment:")
                    for idx, env in enumerate(practice_env, 1):
                        logger.print(f"  {idx}. {env}")
                    if structured_info.practice_environment_citations:
                        logger.print(
                            f"  Citations: {format_list_for_csv([str(c) for c in structured_info.practice_environment_citations])}"
                        )
                else:
                    logger.print("Practice Environment: None")
                # Update row with new data
                state_value, state_citations = format_value_with_citations(
                    structured_info.state or "", structured_info.state_citations
                )
                rows[i]["Computed_State"] = state_value
                rows[i]["Computed_State_Citations"] = state_citations
                gender_value, gender_citations = format_value_with_citations(
                    structured_info.gender or "", structured_info.gender_citations
                )
                rows[i]["Computed_Gender"] = gender_value
                rows[i]["Computed_Gender_Citations"] = gender_citations
                subspecialties_value, subspecialties_citations = (
                    format_value_with_citations(
                        structured_info.subspecialties,
                        structured_info.subspecialties_citations,
                    )
                )
                rows[i]["Computed_Subspecialties"] = subspecialties_value
                rows[i]["Computed_Subspecialties_Citations"] = subspecialties_citations
                practice_env_value, practice_env_citations = (
                    format_value_with_citations(
                        structured_info.practice_environment,
                        structured_info.practice_environment_citations,
                    )
                )
                rows[i]["Computed_Practice_Environment"] = practice_env_value
                rows[i]["Computed_Practice_Environment_Citations"] = (
                    practice_env_citations
                )
                rows[i]["Computed_Sources"] = format_list_for_csv(
                    all_citations, is_sources=True
                )
        except Exception as e:
            logger.print(f"\nError processing doctor: {e}")
            # On error, ensure computed fields are empty
            for col in new_columns:
                rows[i][col] = ""
        finally:
            # Always close the logger
            logger.close()
        # Write all rows back to the file after each row is processed, so
        # progress survives an interruption
        with open(input_file, "w", newline="") as file_out:
            writer = csv.DictWriter(file_out, fieldnames=new_fieldnames)
            writer.writeheader()
            writer.writerows(rows)

def main():
    parser = argparse.ArgumentParser(description="Process ophthalmologist data")
    parser.add_argument(
        "--continue",
        dest="continue_processing",
        action="store_true",
        help="Continue processing from where we left off, skipping already processed rows",
    )
    parser.add_argument(
        "--start",
        type=int,
        default=0,
        help="Start processing from this row number (0-based)",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=10,
        help="Number of rows to process (ignored if --continue is used)",
    )
    args = parser.parse_args()
    # If --continue is used, set num_rows to a very large number to process all remaining rows
    num_rows = 1_000_000 if args.continue_processing else args.limit
    update_csv_with_new_data(
        "ophth.csv", args.start, num_rows, args.continue_processing
    )


if __name__ == "__main__":
    main()
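
# Example invocations, assuming this script is saved as ophth_enrich.py (the
# filename is hypothetical) with ophth.csv in the working directory:
#   python ophth_enrich.py --start 0 --limit 25   # process rows 0-24
#   python ophth_enrich.py --continue             # resume, skipping done rows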