Skip to content

Instantly share code, notes, and snippets.

@a-agmon
Last active March 21, 2024 21:43
Show Gist options
  • Save a-agmon/65fe8e6f065404f039937befbbfa401e to your computer and use it in GitHub Desktop.
Save a-agmon/65fe8e6f065404f039937befbbfa401e to your computer and use it in GitHub Desktop.
use arrow::array::{array, Scalar, StringArray};
use parquet::arrow::arrow_reader::{ArrowPredicateFn, ParquetRecordBatchReaderBuilder, RowFilter};
use parquet::arrow::async_reader::ParquetRecordBatchStreamBuilder;
use parquet::arrow::ProjectionMask;
use parquet::schema::types::SchemaDescriptor;
use tokio::fs::File;
use tokio_stream::StreamExt;
/// Runs the row-filter benchmark over a single Parquet file.
///
/// The file path is taken from the first command-line argument; when no
/// argument is given, the original hard-coded sample path is used so existing
/// invocations keep working.
///
/// For each state ("KR", then "US") the `ltv_country` column is scanned twice:
/// once with a row filter pushed into the reader and once without.
///
/// # Errors
///
/// Returns an error if the file cannot be opened or cloned, or if any of the
/// counting runs fails.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let parq_file_path = std::env::args().nth(1).unwrap_or_else(|| {
        "/Users/alonagmon/Downloads/part-00000-ef48f0ca-94f3-4b26-97b8-995c162f73af-c000.gz.parquet"
            .to_string()
    });

    // Name the path in the error: the default only exists on the author's machine.
    let prq_file = tokio::fs::File::open(&parq_file_path)
        .await
        .map_err(|e| anyhow::anyhow!("failed to open parquet file `{parq_file_path}`: {e}"))?;

    // Same order as before: KR filtered, KR unfiltered, US filtered, US unfiltered.
    for state in ["KR", "US"] {
        for with_filter in [true, false] {
            // Each run consumes its own handle, so hand it a clone.
            let handle = prq_file.try_clone().await?;
            count_state_occurrences(handle, "ltv_country", state, with_filter).await?;
        }
    }
    Ok(())
}
/// Counts the rows of `prq_file` whose `field_name` column equals `state`,
/// printing the count and the elapsed wall-clock time.
///
/// # Arguments
///
/// * `prq_file` - open Parquet file; consumed by the stream reader.
/// * `field_name` - name of the string column to inspect.
/// * `state` - value to count.
/// * `with_filter` - when `true`, an equality row filter (see
///   `get_row_filter`) is attached to the reader; when `false`, every row is
///   read and compared here.
///
/// Null entries are never counted as a match.
///
/// # Errors
///
/// Returns an error if the Parquet metadata or any record batch cannot be
/// read, if `field_name` is not present in the schema, or if the column is
/// not backed by a `StringArray`.
async fn count_state_occurrences(
    prq_file: File,
    field_name: &str,
    state: &str,
    with_filter: bool,
) -> anyhow::Result<usize> {
    let mut state_count = 0usize;
    let ts = std::time::Instant::now();

    let stream_builder = ParquetRecordBatchStreamBuilder::new(prq_file).await?;
    // Cloned so the schema descriptor stays usable after the builder is moved below.
    let file_metadata = stream_builder.metadata().file_metadata().clone();
    // Resolved on both paths so an unknown column is reported before any data is read.
    let column_index = stream_builder.schema().index_of(field_name)?;

    // Only build the row filter when it is actually going to be used.
    let mut stream = if with_filter {
        let row_filter = get_row_filter(file_metadata.schema_descr(), column_index, state)?;
        stream_builder.with_row_filter(row_filter).build()?
    } else {
        stream_builder.build()?
    };

    while let Some(batch) = stream.next().await {
        let batch = batch?;
        let column = batch.column_by_name(field_name).ok_or_else(|| {
            anyhow::anyhow!("column `{field_name}` is missing from a record batch")
        })?;
        let strings = column
            .as_any()
            .downcast_ref::<array::StringArray>()
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "column `{field_name}` is not a Utf8 string column (found {:?})",
                    column.data_type()
                )
            })?;
        // Compare against `Some(state)` so a null never matches; the previous
        // `unwrap_or_default()` turned nulls into "" and counted them when
        // `state` was empty.
        state_count += strings.iter().filter(|v| *v == Some(state)).count();
    }

    println!(
        "State: {} row count: {} with_filter: {} => time taken: {:?}",
        state,
        state_count,
        with_filter,
        ts.elapsed()
    );
    Ok(state_count)
}
/// Builds a [`RowFilter`] holding a single predicate that compares one column
/// for equality with `val`.
///
/// # Arguments
///
/// * `schema` - Parquet schema descriptor the projection mask is built from.
/// * `col_index` - root column index to project for the predicate.
/// * `val` - string value each row is compared against.
///
/// # Errors
///
/// The current implementation always returns `Ok`; the `Result` return type
/// is part of the existing signature.
fn get_row_filter(
    schema: &SchemaDescriptor,
    col_index: usize,
    val: &str,
) -> anyhow::Result<RowFilter> {
    // One-element string array wrapped as a scalar for the comparison kernel.
    let needle = Scalar::new(StringArray::from(vec![val]));

    // The predicate is handed only the column selected by this mask, which is
    // why the closure reads it back as column 0.
    let mask = ProjectionMask::roots(schema, [col_index]);

    // The closure stays inline so its argument type is inferred from
    // `ArrowPredicateFn::new`.
    let equals_val =
        ArrowPredicateFn::new(mask, move |batch| arrow_ord::cmp::eq(batch.column(0), &needle));

    Ok(RowFilter::new(vec![Box::new(equals_val)]))
}
/*
-- RESULTS --
State: KR row count: 12660 with_filter: true => time taken: 656.518875ms
State: KR row count: 12660 with_filter: false => time taken: 844.822917ms
State: US row count: 158015 with_filter: true => time taken: 1.085824833s
State: US row count: 158015 with_filter: false => time taken: 862.845125ms
*/
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment