Created
January 29, 2025 19:26
-
-
Save deanm0000/fc5e80c411f90b197ce815ca5e56c754 to your computer and use it in GitHub Desktop.
Polars to PostgreSQL via tokio-postgres binary COPY, in Rust
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#[tokio::main] | |
async fn main() { | |
let mut client = make_postgres().await; // tokio_postgres::connect with spawn | |
let transaction = client.transaction().await.unwrap(); | |
// make an example df | |
let utc=PlSmallStr::from_str("UTC"); | |
let ava = AnyValue::Datetime(Utc | |
.with_ymd_and_hms(2010, 1, 1, 1, 0, 0) | |
.unwrap() | |
.timestamp_micros(), TimeUnit::Microseconds, Some(&utc)); | |
let avb = AnyValue::Datetime(Utc | |
.with_ymd_and_hms(2010, 1, 1, 2, 0, 0) | |
.unwrap() | |
.timestamp_micros(), TimeUnit::Microseconds, Some(&utc)); | |
let mut df = DataFrame::new(vec![ | |
Column::new("nodes".into(), ["apple", "banana"]), | |
Series::from_any_values("utcbegin".into(), &[ava, avb], true) | |
.unwrap() | |
.into(), | |
Column::new("pricedate".into(), [ | |
NaiveDate::from_ymd_opt(2010, 1, 1), | |
NaiveDate::from_ymd_opt(2010, 1, 1), | |
]), | |
Column::new("hour".into(), [2u8, 2u8]), | |
Column::new("measurement_name".into(), ["abc", "abc"]), | |
Column::new("v".into(), [1.1f64, 2.2]), | |
Column::new("inverter".into(), [1u8, 1]), | |
]) | |
.unwrap(); | |
eprintln!("{}", df); | |
// make temp table in postgres | |
let temp_name = "some_table"; | |
let create_temp = format!("CREATE TEMPORARY TABLE IF NOT EXISTS {} | |
( | |
node character varying NOT NULL, | |
utcbegin timestamp with time zone NOT NULL, | |
pricedate date NOT NULL, | |
hour integer NOT NULL, | |
measurement_name character varying NOT NULL, | |
v double precision NOT NULL, | |
inverter integer NOT NULL | |
)", temp_name); | |
transaction.execute(create_temp.as_str(), &[]).await.unwrap(); | |
eprintln!("created table"); | |
// create copy_in | |
let copy_str = format!("COPY {} FROM STDIN BINARY", temp_name); | |
let sink: CopyInSink<Bytes> = transaction.copy_in(copy_str.as_str()).await.unwrap(); | |
// derive column types from AnyValues, might be better to do with schema but I was already | |
// using AnyValue Enum to extract values so I kept the theme | |
let col_types:Vec<Type> = df.get_row(0).unwrap().0.into_iter().map(|val| match val { | |
AnyValue::Binary(_)=>Type::BYTEA, | |
AnyValue::BinaryOwned(_)=>Type::BYTEA, | |
AnyValue::Boolean(_)=>Type::BOOL, | |
AnyValue::Date(_)=>Type::DATE, | |
AnyValue::Datetime(_, _, tz)=>{ | |
match tz { | |
Some(_)=>Type::TIMESTAMPTZ, | |
None=>Type::TIMESTAMP | |
} | |
}, | |
AnyValue::DatetimeOwned(_, _, tz)=>{ | |
match tz { | |
Some(_)=>Type::TIMESTAMPTZ, | |
None=>Type::TIMESTAMP | |
} | |
}, | |
AnyValue::Duration(..)=>panic!("i dunno duration"), | |
AnyValue::Float32(_)=>Type::FLOAT4, | |
AnyValue::Float64(_)=>Type::FLOAT8, | |
AnyValue::Int128(_)=>panic!("i dunno int128"), | |
AnyValue::Int16(_)=>Type::INT2, | |
AnyValue::Int32(_)=>Type::INT4, | |
AnyValue::Int64(_)=>Type::INT8, | |
AnyValue::Int8(_)=>Type::CHAR, | |
AnyValue::UInt16(_)=>Type::INT2, | |
AnyValue::UInt32(_)=>Type::INT4, | |
AnyValue::UInt64(_)=>Type::INT8, | |
AnyValue::UInt8(_)=>Type::INT4, | |
AnyValue::List(..)=>panic!("guess can make recursive with postgres arrays but oof that's rough or maybe just jsonb"), | |
AnyValue::Null=>Type::VOID, | |
AnyValue::String(_)=>Type::VARCHAR, | |
AnyValue::StringOwned(_)=>Type::VARCHAR, | |
AnyValue::Time(_)=>Type::TIME, | |
AnyValue::Struct(_,_,_)=>Type::JSONB, //maybe map to custom types | |
AnyValue::StructOwned(_)=>Type::JSONB, //maybe map to custom types | |
_=>panic!("can't do it") | |
}).collect(); | |
let writer = BinaryCopyInWriter::new(sink, &col_types); | |
pin_mut!(writer); | |
// now convert the df rows into pg binary | |
for i in 0..df.height() { | |
let row_vals = df.get_row(i).unwrap().0; | |
let row: Vec<Box<dyn ToSql + Sync>> = row_vals | |
.into_iter() | |
.map(|val| match val { | |
AnyValue::String(s) => Box::new(s) as Box<dyn ToSql + Sync>, | |
AnyValue::Datetime(timestamp, time_unit, tz) => { | |
let (seconds, nanoseconds) = match time_unit { | |
TimeUnit::Milliseconds => { | |
(timestamp / 1000, (timestamp % 1000 * 1_000_000) as u32) | |
} | |
TimeUnit::Microseconds => ( | |
timestamp / 1_000_000, | |
(timestamp % 1_000_000 * 1_000) as u32, | |
), | |
TimeUnit::Nanoseconds => ( | |
timestamp / 1_000_000_000, | |
(timestamp % 1_000_000_000) as u32, | |
), | |
}; | |
let raw_utc = DateTime::from_timestamp(seconds, nanoseconds) | |
.unwrap(); | |
let naive = raw_utc.naive_utc(); | |
eprintln!("{} {} {} or {}", seconds, nanoseconds, raw_utc, naive); | |
match tz { | |
None => Box::new(raw_utc) as Box<dyn ToSql + Sync>, | |
Some(tz) => { | |
let timezone: chrono_tz::Tz = tz.parse().unwrap(); | |
let tz_aware = raw_utc.with_timezone(&timezone); | |
let tz_aware_offset=tz_aware.fixed_offset(); | |
eprintln!("tz_aware {} utc {}", tz_aware, tz_aware_offset); | |
Box::new(tz_aware_offset) as Box<dyn ToSql + Sync> | |
} | |
} | |
} | |
AnyValue::Boolean(b) => Box::new(b) as Box<dyn ToSql + Sync>, | |
AnyValue::Date(i) => Box::new(NaiveDate::from_ymd_opt(1970,1,1).unwrap().checked_add_days(Days::new(i as u64)).unwrap()) | |
as Box<dyn ToSql + Sync>, | |
AnyValue::UInt8(i) => Box::new(i as i32) as Box<dyn ToSql + Sync>, | |
AnyValue::Float64(f) => Box::new(f) as Box<dyn ToSql + Sync>, | |
_ => { | |
eprintln!("{}", val); | |
panic!("other type") | |
}, | |
}) | |
.collect(); | |
let row_refs: Vec<&(dyn ToSql + Sync)> = row.iter().map(|b| &**b).collect(); | |
writer.as_mut().write(&row_refs).await.unwrap(); | |
eprintln!("{:?}",row_refs); | |
} | |
writer.finish().await.unwrap(); | |
eprintln!("finished writing"); | |
eprintln!("finished sink"); | |
// do query of what was just copied to see if it's there | |
// in real usage you'd want to do something like | |
// INSERT INTO real_table select * from temp_table | |
let out = transaction.query("select * from some_table", &[]).await.unwrap(); | |
for row in out { | |
eprintln!("{:?}", row); | |
} | |
transaction.commit().await.unwrap(); | |
eprintln!("committed"); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment