Last active
March 4, 2022 14:52
-
-
Save IhostVlad/a505bef43efb38a29d4e906fd8da13f4 to your computer and use it in GitHub Desktop.
Micro eventsourcing Rust + AWS serverless
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[package]
name = "lambda_runtime"
version = "0.4.1"
edition = "2021"

[dependencies]
# "time" is required because the code calls tokio::time::sleep; relying on another
# dependency to enable it transitively is fragile.
tokio = { version = "1.0", features = ["macros", "io-util", "sync", "rt-multi-thread", "time"] }
hyper = { version = "0.14", features = ["http1", "client", "server", "stream", "runtime"] }
serde = { version = "1", features = ["derive"] }
serde_json = "^1"
aws-sdk-sqs = "0.6.0"
aws-sdk-lambda = "0.6.0"
aws-config = "0.6.0"
bytes = "1.0"
http = "0.2"
async-stream = "0.3"
lambda_runtime = "0.4.1"
tracing = { version = "0.1", features = ["log"] }
tower = { version = "0.4", features = ["util"] }
custom_error = "1.9.2"
tokio-stream = "0.1.2"
tower-service = "0.3"
tracing-subscriber = "0.3"
tokio-postgres = "0.7.5"
rand = "0.8.5"
lazy_static = "1.4.0"
regex = "1.5.4"
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/zsh
# Cross-compiles the function for x86_64 musl and packages it as a Lambda
# custom-runtime archive (a zip containing an executable named `bootstrap`).

# Stop at the first failing command so a failed build never packages a stale binary.
set -euo pipefail

export CC_x86_64_unknown_linux_gnu=x86_64-unknown-linux-gnu-gcc
export CXX_x86_64_unknown_linux_gnu=x86_64-unknown-linux-gnu-g++
export AR_x86_64_unknown_linux_gnu=x86_64-unknown-linux-gnu-ar
export CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_LINKER=x86_64-unknown-linux-gnu-gcc
export CC_x86_64_unknown_linux_musl=x86_64-unknown-linux-musl-gcc
export CXX_x86_64_unknown_linux_musl=x86_64-unknown-linux-musl-g++
export AR_x86_64_unknown_linux_musl=x86_64-unknown-linux-musl-ar
export CARGO_TARGET_X86_64_UNKNOWN_LINUX_MUSL_LINKER=x86_64-unknown-linux-musl-gcc

cargo build --release --target x86_64-unknown-linux-musl
cp ./target/x86_64-unknown-linux-musl/release/lambda_runtime ./bootstrap
chmod +x bootstrap
zip lambda.zip ./bootstrap
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use lambda_runtime; | |
use serde_json; | |
use tracing_subscriber; | |
use tokio_postgres; | |
use aws_sdk_lambda; | |
use aws_sdk_sqs; | |
use aws_config; | |
#[tokio::main] | |
async fn main() -> Result<(), lambda_runtime::Error> { | |
tracing_subscriber::fmt() | |
.with_max_level(tracing::Level::INFO) | |
.with_ansi(false) | |
.without_time() | |
.init(); | |
let func = lambda_runtime::handler_fn(my_handler); | |
lambda_runtime::run(func).await?; | |
Ok(()) | |
} | |
#[derive(serde::Deserialize)] | |
struct Request { | |
command: String, | |
part: Option<u8>, | |
} | |
#[derive(serde::Serialize)] | |
struct Response { | |
msg: String, | |
} | |
/// Error returned by `my_handler`.
///
/// It carries a free-form message, which is also its entire `Display` output.
#[derive(Debug)]
enum LambdaError {
    Error { message: std::string::String },
}

impl std::fmt::Display for LambdaError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            LambdaError::Error { message } => write!(formatter, "{}", message),
        }
    }
}

impl std::error::Error for LambdaError {}
fn is_passthrough_error (error: &tokio_postgres::Error) -> bool { | |
lazy_static::lazy_static! { | |
static ref RE: regex::Regex = regex::Regex::new(r#"(?:(?i)(?:(?:terminating connection because backend initialization completed past serverless scale point)|(?:Connection rate is too high, please reduce connection rate)|(?:terminating connection due to serverless scale event timeout)|(?:terminating connection due to administrator command)(?:canceling statement due to statement timeout)|(?:Remaining connection slots are reserved)|(?:Connection terminated unexpectedly)|(?:Too many clients already)|(?:in a read-only transaction)|(?:Query read timeout)|(?:Connection terminated)|(?:timeout expired)|(?:connection closed)|(?:error communicating with the server)|(?:Too many connection errors)|(?:database is not available)|(?:ECONNRESET)|(?:ETIMEDOUT)|(?:getaddrinfo)))"# ).unwrap(); | |
} | |
RE.is_match(format!("{}", error).as_str()) | |
} | |
fn is_lock_busy_error (error: &tokio_postgres::Error) -> bool { | |
lazy_static::lazy_static! { | |
static ref RE: regex::Regex = regex::Regex::new(r#"(?:(?i)(?:(?:could not obtain lock on row in relation)))"#).unwrap(); | |
} | |
RE.is_match(format!("{}", error).as_str()) | |
} | |
fn is_zero_runtime_error (error: &tokio_postgres::Error) -> bool { | |
lazy_static::lazy_static! { | |
static ref RE: regex::Regex = regex::Regex::new(r#"(?:(?i)(?:(?:division by zero)))"#).unwrap(); | |
} | |
RE.is_match(format!("{}", error).as_str()) | |
} | |
type MaybePostgresClient = std::option::Option::<(tokio_postgres::Client, tokio::sync::oneshot::Receiver<tokio_postgres::Error>)>; | |
async fn maybe_close_connection(maybe_postgres_client: &mut MaybePostgresClient) { | |
if let Some((client, _)) = maybe_postgres_client.take() { | |
drop(client); | |
}; | |
} | |
async fn execute_query_auto_retry<'e>(maybe_postgres_client: &'e mut MaybePostgresClient, connection_string: &std::string::String, query_consumed: std::string::String) -> Result<(), tokio_postgres::Error> { | |
let query = &query_consumed; | |
loop { | |
if maybe_postgres_client.is_none() { | |
let (sender, receiver) = tokio::sync::oneshot::channel(); | |
match tokio_postgres::connect(connection_string.as_str(), tokio_postgres::NoTls).await { | |
Ok((client, connection)) => { | |
*maybe_postgres_client = Some((client, receiver )); | |
tokio::spawn(async move { | |
if let Err(err) = connection.await { | |
match sender.send(err) { _ => {} }; | |
} | |
}); | |
}, | |
Err(err) => { | |
if is_passthrough_error(&err) { | |
maybe_close_connection(maybe_postgres_client).await; | |
continue; | |
} else { | |
return Err(err); | |
} | |
} | |
} | |
} | |
if let Some((ref mut client, ref mut connection_errors_receiver)) = maybe_postgres_client { | |
let empty_receiver_fallthrough = tokio::time::sleep(tokio::time::Duration::from_micros(1)); | |
tokio::pin!(empty_receiver_fallthrough); | |
tokio::select! { | |
msg = connection_errors_receiver => { | |
match msg { | |
Ok(err) => { | |
if is_passthrough_error(&err) { | |
maybe_close_connection(maybe_postgres_client).await; | |
continue; | |
} else { | |
return Err(err) | |
} | |
}, | |
_ => {}, | |
} | |
}, | |
_ = &mut empty_receiver_fallthrough => {} | |
} | |
match client.batch_execute(query).await { | |
Ok(_) => { return Ok(()) }, | |
Err(err) => { | |
if is_passthrough_error(&err) { | |
maybe_close_connection(maybe_postgres_client).await; | |
continue; | |
} else { | |
return Err(err); | |
} | |
} | |
} | |
} | |
} | |
} | |
async fn my_handler(input_event: serde_json::Value, context: lambda_runtime::Context) -> Result<serde_json::Value, lambda_runtime::Error> { | |
let region_provider = aws_config::meta::region::RegionProviderChain::default_provider(); | |
let shared_config = aws_config::from_env().region(region_provider).load().await; | |
let mut error_messages = String::from(""); | |
let concurrent_senders_count = std::env::var("CONCURRENT_SENDERS_COUNT").unwrap_or("".into()).parse::<u32>().unwrap_or(u32::MAX); | |
if concurrent_senders_count == u32::MAX { | |
error_messages.insert_str(error_messages.len(), format!("Mission CONCURRENT_SENDERS_COUNT env var or not U32 integer").as_str()); | |
} | |
let incoming_messages_per_sender_count = std::env::var("INCOMING_MESSAGES_PER_SENDER").unwrap_or("".into()).parse::<u32>().unwrap_or(u32::MAX); | |
if incoming_messages_per_sender_count == u32::MAX { | |
error_messages.insert_str(error_messages.len(), format!("Mission INCOMING_MESSAGES_PER_SENDER env var or not U32 integer").as_str()); | |
} | |
let concurrent_parts_count = std::env::var("CONCURRENT_PARTS_COUNT").unwrap_or("".into()).parse::<u32>().unwrap_or(u32::MAX); | |
if concurrent_parts_count == u32::MAX { | |
error_messages.insert_str(error_messages.len(), format!("Mission CONCURRENT_PARTS_COUNT env var or not U32 integer").as_str()); | |
} | |
let parts_per_message = std::env::var("PARTS_PER_MESSAGE").unwrap_or("".into()).parse::<u32>().unwrap_or(u32::MAX); | |
if parts_per_message == u32::MAX || parts_per_message > concurrent_parts_count { | |
error_messages.insert_str(error_messages.len(), format!("Mission PARTS_PER_MESSAGE env var or not U32 integer or large than concurrent parts").as_str()); | |
} | |
let sqs_queue_url = std::env::var("SQS_QUEUE_URL").unwrap_or("".into()); | |
if sqs_queue_url == "" { | |
error_messages.insert_str(error_messages.len(), format!("Missing SQS_QUEUE_URL env var").as_str()); | |
} | |
let postgres_connection_string = std::env::var("POSTGRES_CONNECTION_STRING").unwrap_or("".into()); | |
if postgres_connection_string == "" { | |
error_messages.insert_str(error_messages.len(), format!("Missing POSTGRES_CONNECTION_STRING env var").as_str()); | |
} | |
let postgres_schema_name = std::env::var("POSTGRES_SCHEMA_NAME").unwrap_or("".into()); | |
if postgres_schema_name == "" || postgres_schema_name.contains("\"") || postgres_schema_name.contains("'") || postgres_schema_name.contains("\\") { | |
error_messages.insert_str(error_messages.len(), format!("Missing POSTGRES_SCHEMA_NAME env var or containing bad symbols").as_str()); | |
} | |
if error_messages != "" { | |
return Err(Box::new(LambdaError::Error { | |
message: error_messages | |
})); | |
} | |
let event = match serde_json::from_value::<Request>(input_event) { | |
Ok(value) => value, | |
Err(err) => { | |
return Err(Box::new(LambdaError::Error { | |
message: format!("Input event parse failure {}\n", err) | |
})); | |
} | |
}; | |
let lambda_client = aws_sdk_lambda::Client::new(&shared_config); | |
let sqs_client = aws_sdk_sqs::Client::new(&shared_config); | |
let command = event.command; | |
if command == "invoke-postgres" || command == "invoke-sqs" { | |
let payload = match command.as_str() { | |
"invoke-postgres" => aws_sdk_lambda::Blob::new(format!("{{ \"command\": \"send-postgres\" }}")), | |
"invoke-sqs" => aws_sdk_lambda::Blob::new(format!("{{ \"command\": \"send-sqs\" }}")), | |
_ => unreachable!() | |
}; | |
let mut maybe_postgres_client: MaybePostgresClient = MaybePostgresClient::None; | |
match execute_query_auto_retry(&mut maybe_postgres_client, &postgres_connection_string, format!(r###" | |
DROP TABLE IF EXISTS "{0}"."incoming"; | |
DROP TABLE IF EXISTS "{0}"."processing"; | |
CREATE TABLE "{0}"."incoming"("timeguid" bigint primary key, "parts" integer[] not null); | |
CREATE TABLE "{0}"."processing"( | |
"part" integer primary key, | |
"border_timeguid" bigint not null, | |
"radix_timeguids" bigint[] not null, | |
"pid" integer null | |
); | |
INSERT INTO "{0}"."processing"("part", "border_timeguid", "radix_timeguids") | |
SELECT "i" as "part", 0 as "border_timeguid", array[]::bigint[] as "radix_timeguids" | |
FROM generate_series(0, {1}) "i"; | |
"###, postgres_schema_name, concurrent_parts_count - 1)).await { | |
Ok(_) => {}, | |
Err(err) => { | |
maybe_close_connection(&mut maybe_postgres_client).await; | |
return Err(Box::new(LambdaError::Error { | |
message: format!("Postgres resource management failure {}", err) | |
})); | |
} | |
}; | |
let invoke_self = || lambda_client.invoke() | |
.function_name(&context.invoked_function_arn) | |
.invocation_type(aws_sdk_lambda::model::InvocationType::Event) | |
.payload(payload.clone()) | |
.send(); | |
let mut promises = vec!(invoke_self()); | |
for _ in 0..concurrent_senders_count - 1 { | |
promises.push(invoke_self()); | |
} | |
let mut success_invokes = 0; | |
let mut failed_invokes = 0; | |
for promise in promises { | |
match promise.await { | |
Ok(_) => { success_invokes = success_invokes + 1; }, | |
Err(_) => { failed_invokes = failed_invokes + 1; } | |
}; | |
} | |
return Ok(serde_json::to_value(Response { | |
msg: format!("Successfully invoked {} and failed {} lambdas", success_invokes, failed_invokes) | |
}).unwrap()) | |
} else if command == "send-postgres" || command == "send-sqs" { | |
let mut maybe_postgres_client: MaybePostgresClient = MaybePostgresClient::None; | |
for _ in 0..incoming_messages_per_sender_count { | |
let mut parts = std::vec::Vec::<u32>::with_capacity(parts_per_message as usize); | |
for _ in 0..parts_per_message { | |
parts.push(rand::random::<u32>() % concurrent_parts_count); | |
} | |
match execute_query_auto_retry(&mut maybe_postgres_client, &postgres_connection_string, format!(r###" | |
INSERT INTO "{0}"."incoming"("parts", "timeguid") VALUES(array{1:?}, | |
CAST(CAST(extract(epoch from clock_timestamp()) * 1000 AS BIGINT) * 1000000 AS BIGINT) + CAST(random() * 1000000 AS BIGINT) | |
) | |
"###, postgres_schema_name, parts)).await { | |
Ok(_) => {}, | |
Err(err) => { | |
maybe_close_connection(&mut maybe_postgres_client).await; | |
return Err(Box::new(LambdaError::Error { | |
message: format!("Send incoming entries failed {}", err) | |
})); | |
} | |
}; | |
for i in 0..parts.len() { | |
match command.as_str() { | |
"send-postgres" => { | |
match execute_query_auto_retry(&mut maybe_postgres_client, &postgres_connection_string, format!(r###" | |
WITH "lock_count" AS ( | |
SELECT CASE WHEN Count(*) > 0 THEN 0 ELSE 1 END AS "ActiveLocksCount" FROM pg_locks | |
WHERE database = (SELECT oid FROM pg_database WHERE datname = current_database()) | |
AND relation = (SELECT oid FROM pg_class WHERE relname = 'processing' | |
AND relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = '{0}')) | |
AND pid = (SELECT "E"."pid" FROM "{0}"."processing" "E" WHERE "E"."part" = {1}) | |
) | |
SELECT 1/(SELECT "lock_count"."ActiveLocksCount" FROM "lock_count") AS "cnt" | |
"###, postgres_schema_name, parts[i] | |
)).await { | |
Ok(_) => { | |
match lambda_client.invoke() | |
.function_name(&context.invoked_function_arn) | |
.invocation_type(aws_sdk_lambda::model::InvocationType::Event) | |
.payload(aws_sdk_lambda::Blob::new(format!("{{ \"command\": \"work-postgres\", \"part\": {} }}", parts[i]))) | |
.send() | |
.await | |
{ _ => {} }; | |
}, | |
Err(err) => { | |
if !is_zero_runtime_error(&err) { | |
maybe_close_connection(&mut maybe_postgres_client).await; | |
return Err(Box::new(LambdaError::Error { | |
message: format!("Send incoming entries failed {}", err) | |
})); | |
} | |
} | |
}; | |
}, | |
"send-sqs" => { | |
match sqs_client.send_message() | |
.queue_url(sqs_queue_url.as_str()) | |
.message_body(format!("{{ \"command\": \"work-sqs\", \"part\": {} }}", parts[i])) | |
.message_group_id(format!("{}", parts[i])) | |
.send() | |
.await | |
{ _ => {} }; | |
}, | |
_ => unreachable!() | |
}; | |
} | |
} | |
maybe_close_connection(&mut maybe_postgres_client).await; | |
return Ok(serde_json::to_value(Response { | |
msg: format!("Successfully invoked insert incoming messages") | |
}).unwrap()) | |
} else if command == "work-postgres" { | |
let current_part = match event.part { | |
Some(part) => part, | |
None => { | |
return Err(Box::new(LambdaError::Error { | |
message: format!("Partition number is not provided for worker process") | |
})); | |
} | |
}; | |
let mut maybe_postgres_client: MaybePostgresClient = MaybePostgresClient::None; | |
'worker_loop: loop { | |
match execute_query_auto_retry(&mut maybe_postgres_client, &postgres_connection_string, "SELECT 0".into()).await { | |
Ok(_) => {}, | |
Err(err) => { | |
return Err(Box::new(LambdaError::Error { | |
message: format!("Worker database connection error {}", err) | |
})); | |
} | |
}; | |
if let Some((ref mut client, _)) = maybe_postgres_client { | |
match client.batch_execute(format!(r###" | |
WITH "maybe_lock_part" AS ( | |
SELECT * FROM "{0}"."processing" WHERE "part" = {1} FOR UPDATE NOWAIT | |
) | |
UPDATE "{0}"."processing" SET "pid" = pg_backend_pid() | |
WHERE (SELECT Count("maybe_lock_part".*) FROM "maybe_lock_part") <> 0 | |
AND "part" = {1} | |
"###, postgres_schema_name, current_part).as_str()).await { | |
Ok(_) => {}, | |
Err(err) => { | |
if !is_passthrough_error(&err) { | |
maybe_close_connection(&mut maybe_postgres_client).await; | |
if is_lock_busy_error(&err) { | |
return Ok(serde_json::to_value(Response { | |
msg: format!("Successfully skipped worker") | |
}).unwrap()) | |
} else { | |
return Err(Box::new(LambdaError::Error { | |
message: format!("Worker process failed {}", err) | |
})); | |
} | |
} else { | |
continue 'worker_loop; | |
} | |
} | |
} | |
match client.batch_execute(format!(r###" | |
BEGIN TRANSACTION; | |
SELECT * FROM "{0}"."processing" WHERE "part" = {1} | |
AND 1 / ( | |
SELECT Count("E".*) FROM "{0}"."processing" "E" | |
WHERE "E"."part" = {1} AND "E"."pid" = pg_backend_pid() | |
) <> 2 | |
FOR UPDATE NOWAIT; | |
"###, postgres_schema_name, current_part).as_str()).await { | |
Ok(_) => {}, | |
Err(err) => { | |
if !is_passthrough_error(&err) { | |
maybe_close_connection(&mut maybe_postgres_client).await; | |
if is_lock_busy_error(&err) || is_zero_runtime_error(&err) { | |
return Ok(serde_json::to_value(Response { | |
msg: format!("Successfully skipped worker part {}", current_part) | |
}).unwrap()) | |
} else { | |
return Err(Box::new(LambdaError::Error { | |
message: format!("Worker process failed {}", err) | |
})); | |
} | |
} else { | |
continue 'worker_loop; | |
} | |
} | |
} | |
match client.batch_execute(format!(r###" | |
WITH "current_part" AS ( | |
SELECT * FROM "{0}"."processing" WHERE "part" = {1} LIMIT 1 | |
), "current_radix_column" AS ( | |
SELECT unnest("current_part"."radix_timeguids") AS "timeguid" FROM "current_part" | |
), "available_messages" AS ( | |
SELECT * FROM "{0}"."incoming" WHERE "timeguid" + 500000000 > ( | |
SELECT "border_timeguid" FROM "current_part" LIMIT 1 | |
) AND "timeguid" NOT IN ( | |
SELECT "current_radix_column"."timeguid" FROM "current_radix_column" | |
) AND array_position("parts", {1}) IS NOT NULL | |
ORDER BY "timeguid" | |
), "next_radix_column" AS ( | |
SELECT "current_radix_column"."timeguid" FROM "current_radix_column" | |
WHERE "current_radix_column"."timeguid" + 500000000 > ( | |
SELECT max("available_messages"."timeguid") FROM "available_messages" | |
) UNION ALL | |
SELECT "available_messages"."timeguid" FROM "available_messages" | |
WHERE "available_messages"."timeguid" + 500000000 > ( | |
SELECT max("available_messages"."timeguid") FROM "available_messages" | |
) | |
) | |
UPDATE "{0}"."processing" SET "border_timeguid" = ( | |
SELECT max("available_messages"."timeguid") FROM "available_messages" | |
), "radix_timeguids" = ( | |
SELECT array_agg("next_radix_column"."timeguid") FROM "next_radix_column" | |
) | |
WHERE "part" = {1} AND ( | |
SELECT count("available_messages".*) FROM "available_messages" | |
) > 0 | |
"###, postgres_schema_name, current_part).as_str()).await { | |
Ok(_) => {}, | |
Err(err) => { | |
if !is_passthrough_error(&err) { | |
maybe_close_connection(&mut maybe_postgres_client).await; | |
return Err(Box::new(LambdaError::Error { | |
message: format!("Worker process failed {}", err) | |
})); | |
} else { | |
continue 'worker_loop; | |
} | |
} | |
}; | |
match client.batch_execute(format!(r###" | |
COMMIT | |
"###).as_str()).await { | |
Ok(_) => {}, | |
Err(err) => { | |
if !is_passthrough_error(&err) { | |
maybe_close_connection(&mut maybe_postgres_client).await; | |
return Err(Box::new(LambdaError::Error { | |
message: format!("Worker process failed {}", err) | |
})); | |
} else { | |
continue 'worker_loop; | |
} | |
} | |
} | |
} else { | |
unreachable!(); | |
} | |
break 'worker_loop; | |
}; | |
maybe_close_connection(&mut maybe_postgres_client).await; | |
return Ok(serde_json::to_value(Response { | |
msg: format!("Successfully done worker part {}", current_part) | |
}).unwrap()); | |
} else { | |
return Err(Box::new(LambdaError::Error { | |
message: format!("Unknown command {}", command) | |
})); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment