Last active
March 21, 2024 21:43
-
-
Save a-agmon/65fe8e6f065404f039937befbbfa401e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use arrow::array::{array, Scalar, StringArray}; | |
use parquet::arrow::arrow_reader::{ArrowPredicateFn, ParquetRecordBatchReaderBuilder, RowFilter}; | |
use parquet::arrow::async_reader::ParquetRecordBatchStreamBuilder; | |
use parquet::arrow::ProjectionMask; | |
use parquet::schema::types::SchemaDescriptor; | |
use tokio::fs::File; | |
use tokio_stream::StreamExt; | |
#[tokio::main] | |
async fn main() -> anyhow::Result<()> { | |
let parq_file_path = "/Users/alonagmon/Downloads/part-00000-ef48f0ca-94f3-4b26-97b8-995c162f73af-c000.gz.parquet"; | |
let prq_file = tokio::fs::File::open(parq_file_path).await?; | |
count_state_occurrences(prq_file.try_clone().await?, "ltv_country", "KR", true).await?; | |
count_state_occurrences(prq_file.try_clone().await?, "ltv_country", "KR", false).await?; | |
count_state_occurrences(prq_file.try_clone().await?, "ltv_country", "US", true).await?; | |
count_state_occurrences(prq_file.try_clone().await?, "ltv_country", "US", false).await?; | |
Ok(()) | |
} | |
async fn count_state_occurrences( | |
prq_file: File, | |
field_name: &str, | |
state: &str, | |
with_filter: bool, | |
) -> anyhow::Result<usize> { | |
let mut state_count = 0; | |
let ts = std::time::Instant::now(); | |
let stream_builder = ParquetRecordBatchStreamBuilder::new(prq_file).await?; | |
let file_metadata = stream_builder.metadata().file_metadata().clone(); | |
let schema = stream_builder.schema(); | |
let column_index = schema.index_of(field_name)?; | |
let schema = file_metadata.schema_descr(); | |
let row_filter = get_row_filter(schema, column_index, state)?; | |
let mut stream = match with_filter { | |
true => stream_builder.with_row_filter(row_filter).build()?, | |
false => stream_builder.build()?, | |
}; | |
while let Some(batch) = stream.next().await { | |
let batch = batch?; | |
let column = batch.column_by_name(field_name).unwrap(); | |
let column = column | |
.as_any() | |
.downcast_ref::<array::StringArray>() | |
.unwrap(); | |
for v in column.iter() { | |
if v.unwrap_or_default() == state { | |
state_count += 1; | |
} | |
} | |
} | |
println!( | |
"State: {} row count: {} with_filter: {} => time taken: {:?}", | |
state, | |
state_count, | |
with_filter, | |
ts.elapsed() | |
); | |
Ok(state_count) | |
} | |
fn get_row_filter( | |
schema: &SchemaDescriptor, | |
col_index: usize, | |
val: &str, | |
) -> anyhow::Result<RowFilter> { | |
let predicate = StringArray::from(vec![val]); | |
let predicate = Scalar::new(predicate); | |
let filter = ArrowPredicateFn::new(ProjectionMask::roots(schema, [col_index]), move |batch| { | |
arrow_ord::cmp::eq(batch.column(0), &predicate) | |
}); | |
Ok(RowFilter::new(vec![Box::new(filter)])) | |
} | |
/*
-- RESULTS --
State: KR row count: 12660 with_filter: true => time taken: 656.518875ms
State: KR row count: 12660 with_filter: false => time taken: 844.822917ms
State: US row count: 158015 with_filter: true => time taken: 1.085824833s
State: US row count: 158015 with_filter: false => time taken: 862.845125ms
*/
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment