Last active
April 2, 2024 23:31
-
-
Save palewire/db35df6c637372b8d484f20549ac3782 to your computer and use it in GitHub Desktop.
Python helpers for running Amazon Athena queries
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Utilities for working Amazon Athena. | |
Example: | |
To run a query and get the results as a pandas DataFrame: | |
>>> query = "SELECT * FROM my_table" | |
>>> df = get_df_from_athena(query) | |
>>> df.head() | |
Required Environment Variables: | |
AWS_ACCESS_KEY_ID (str): AWS access key ID. | |
AWS_SECRET_ACCESS_KEY (str): AWS secret access key. | |
AWS_REGION_NAME (str): AWS region name. | |
AWS_S3_BUCKET_NAME (str): AWS S3 bucket name. | |
""" | |
from __future__ import annotations | |
import io | |
import os | |
import time | |
import boto3 | |
import pandas as pd | |
from rich import print | |
def get_df_from_athena(
    query: str,
    verbose: bool = False,
    athena_directory: str = "athena-workspace",
    **kwargs,
) -> pd.DataFrame:
    """Get a pandas DataFrame from an Amazon Athena query.

    Args:
        query : str
            formatted string containing athena sql query
        verbose : bool
            whether to print verbose output
        athena_directory : str
            directory in S3 bucket used to store query results
        kwargs
            additional arguments to pass to the read_csv method

    Returns:
        pd.DataFrame : pandas DataFrame containing query results
    """
    # Run the query and block until it finishes. The directory must be
    # forwarded so the results are written where we read them back from;
    # previously a non-default `athena_directory` broke the read below.
    job_id = run_athena_query(
        query, verbose=verbose, athena_directory=athena_directory
    )

    # Athena writes the result set as a CSV named after the execution id.
    # Note: `**kwargs` is always a dict (never None), so no guard is needed.
    df = read_csv_from_s3(f"{athena_directory}/{job_id}.csv", verbose=verbose, **kwargs)
    if verbose:
        print(f"Retrieved dataframe with shape: {df.shape}")

    # Return the DataFrame
    return df
def run_athena_query(
    query: str,
    wait: int = 10,
    verbose: bool = False,
    athena_directory: str = "athena-workspace",
) -> str:
    """Execute a query on Amazon Athena and block until it completes.

    Args:
        query : str
            formatted string containing athena sql query
        wait : int
            number of seconds to wait between checking query status
        verbose : bool
            whether to print verbose output
        athena_directory : str
            directory in S3 bucket to store query results

    Returns:
        str : query execution id

    Raises:
        RuntimeError: if the query finishes in any state other than
            SUCCEEDED (e.g. FAILED or CANCELLED).
    """
    # Create the Athena client
    client = boto3.client("athena", region_name=os.getenv("AWS_REGION_NAME"))

    # Set the destination as our temporary S3 workspace folder
    s3_destination = f"s3://{os.getenv('AWS_S3_BUCKET_NAME')}/{athena_directory}/"

    # Execute the query
    if verbose:
        print(f"Running query: {query}")
    start_response = client.start_query_execution(
        QueryString=query,
        ResultConfiguration={
            "OutputLocation": s3_destination,
        },
    )

    # Get the query execution id
    query_id = start_response["QueryExecutionId"]
    if verbose:
        print(f"Query ID: {query_id}")

    # Poll until the query leaves the RUNNING/QUEUED states
    while True:
        state = client.get_query_execution(QueryExecutionId=query_id)[
            "QueryExecution"
        ]["Status"]["State"]
        if state not in ("RUNNING", "QUEUED"):
            break
        if verbose:
            print(f"Query state: {state}. Waiting {wait} seconds...")
        time.sleep(wait)

    # Make sure it finished successfully. Raise rather than assert:
    # `assert` is stripped under `python -O`, which would silently let
    # failed queries through to the result-reading step.
    if verbose:
        print(f"Query finished with state: {state}")
    if state != "SUCCEEDED":
        raise RuntimeError(f"query state is {state}")

    # Return the query id
    return query_id
def read_csv_from_s3(
    key_name: str,
    verbose: bool = False,
    **kwargs,
) -> pd.DataFrame:
    """Read a CSV file from S3.

    Args:
        key_name (str): The key name to read.
        verbose (bool, optional): Whether to print verbose output. Defaults to False.
        kwargs: Additional kwargs to pass to pd.read_csv.

    Returns:
        pd.DataFrame: The dataframe.

    Raises:
        RuntimeError: if the AWS_S3_BUCKET_NAME environment variable is unset.
    """
    bucket = os.getenv("AWS_S3_BUCKET_NAME")
    # Raise rather than assert: `assert` disappears under `python -O`,
    # and a bare AssertionError gives the caller no hint about the cause.
    if not bucket:
        raise RuntimeError("AWS_S3_BUCKET_NAME environment variable is not set")
    if verbose:
        print(f"Reading {key_name} from S3 bucket {bucket}")

    # Connect to S3
    s3 = get_s3_client()

    # Download the file
    response = s3.get_object(
        Bucket=bucket,
        Key=key_name,
    )

    # Read the file body into pandas
    df = pd.read_csv(io.BytesIO(response["Body"].read()), **kwargs)

    # Return the dataframe
    return df
def get_s3_client() -> boto3.client:
    """Build and return a boto3 S3 client.

    Credentials are read from the AWS_ACCESS_KEY_ID and
    AWS_SECRET_ACCESS_KEY environment variables.

    Returns:
        boto3.client: The boto3 client.
    """
    access_key = os.getenv("AWS_ACCESS_KEY_ID")
    secret_key = os.getenv("AWS_SECRET_ACCESS_KEY")
    client = boto3.client(
        "s3",
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
    )
    return client
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment