@IhostVlad
Last active March 4, 2022 14:52
Micro event sourcing: Rust + AWS serverless
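# Cargo.toml: manifest for the Lambda binary (Tokio, AWS SDK clients, tokio-postgres, lambda_runtime)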
[package]
name = "lambda_runtime"
version = "0.1.0" # kept different from the lambda_runtime dependency version below to avoid a package name/version collision
edition = "2021"
[dependencies]
tokio = { version = "1.0", features = ["macros", "io-util", "sync", "rt-multi-thread"] }
hyper = { version = "0.14", features = ["http1", "client", "server", "stream", "runtime"] }
serde = { version = "1", features = ["derive"] }
serde_json = "^1"
aws-sdk-sqs = "0.6.0"
aws-sdk-lambda = "0.6.0"
aws-config = "0.6.0"
bytes = "1.0"
http = "0.2"
async-stream = "0.3"
lambda_runtime = "0.4.1"
tracing = { version = "0.1", features = ["log"] }
tower = { version = "0.4", features = ["util"] }
custom_error = "1.9.2"
tokio-stream = "0.1.2"
tower-service = "0.3"
tracing-subscriber = "0.3"
tokio-postgres = "0.7.5"
rand = "0.8.5"
lazy_static = "1.4.0"
regex = "1.5.4"
#!/bin/zsh
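# Cross-compile a static musl binary and package it as `bootstrap` inside lambda.zip (the layout expected by a custom Lambda runtime)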
export CC_x86_64_unknown_linux_gnu=x86_64-unknown-linux-gnu-gcc
export CXX_x86_64_unknown_linux_gnu=x86_64-unknown-linux-gnu-g++
export AR_x86_64_unknown_linux_gnu=x86_64-unknown-linux-gnu-ar
export CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_LINKER=x86_64-unknown-linux-gnu-gcc
export CC_x86_64_unknown_linux_musl=x86_64-unknown-linux-musl-gcc
export CXX_x86_64_unknown_linux_musl=x86_64-unknown-linux-musl-g++
export AR_x86_64_unknown_linux_musl=x86_64-unknown-linux-musl-ar
export CARGO_TARGET_X86_64_UNKNOWN_LINUX_MUSL_LINKER=x86_64-unknown-linux-musl-gcc
cargo build --release --target x86_64-unknown-linux-musl
cp ./target/x86_64-unknown-linux-musl/release/lambda_runtime ./bootstrap
chmod +x bootstrap
zip lambda.zip ./bootstrap
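The resulting lambda.zip contains a custom-runtime bootstrap, so it can be deployed to a provided.al2 Lambda. Below is a minimal deployment sketch with the AWS CLI; the function name, role ARN, and environment values are illustrative placeholders, not part of the gist.
# Create the function from lambda.zip (placeholder name/role; adjust region, memory and timeout as needed)
aws lambda create-function \
  --function-name micro-eventsourcing \
  --runtime provided.al2 \
  --handler bootstrap \
  --zip-file fileb://lambda.zip \
  --role arn:aws:iam::123456789012:role/lambda-exec-role \
  --timeout 900 \
  --environment '{"Variables":{"CONCURRENT_SENDERS_COUNT":"4","INCOMING_MESSAGES_PER_SENDER":"100","CONCURRENT_PARTS_COUNT":"16","PARTS_PER_MESSAGE":"4","SQS_QUEUE_URL":"https://sqs.us-east-1.amazonaws.com/123456789012/example.fifo","POSTGRES_CONNECTION_STRING":"host=... user=... password=... dbname=...","POSTGRES_SCHEMA_NAME":"public"}}'
# Kick off the fan-out (AWS CLI v2 needs the raw payload flag)
aws lambda invoke \
  --function-name micro-eventsourcing \
  --cli-binary-format raw-in-base64-out \
  --payload '{ "command": "invoke-postgres" }' \
  response.json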
use lambda_runtime;
use serde_json;
use tracing_subscriber;
use tokio_postgres;
use aws_sdk_lambda;
use aws_sdk_sqs;
use aws_config;
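// Entry point: install a tracing subscriber and hand my_handler to the Lambda runtime.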
#[tokio::main]
async fn main() -> Result<(), lambda_runtime::Error> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
.with_ansi(false)
.without_time()
.init();
let func = lambda_runtime::handler_fn(my_handler);
lambda_runtime::run(func).await?;
Ok(())
}
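// Incoming event payload: a command name plus an optional partition number (used by the worker command).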
#[derive(serde::Deserialize)]
struct Request {
command: String,
part: Option<u8>,
}
#[derive(serde::Serialize)]
struct Response {
msg: String,
}
custom_error::custom_error!{LambdaError
Error{message:std::string::String} = "{message}"
}
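// Helpers that classify Postgres errors by their message text: transient errors worth a reconnect-and-retry,
// row-lock contention, and the deliberate division-by-zero used as a signal in the SQL below.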
fn is_passthrough_error (error: &tokio_postgres::Error) -> bool {
lazy_static::lazy_static! {
static ref RE: regex::Regex = regex::Regex::new(r#"(?:(?i)(?:(?:terminating connection because backend initialization completed past serverless scale point)|(?:Connection rate is too high, please reduce connection rate)|(?:terminating connection due to serverless scale event timeout)|(?:terminating connection due to administrator command)|(?:canceling statement due to statement timeout)|(?:Remaining connection slots are reserved)|(?:Connection terminated unexpectedly)|(?:Too many clients already)|(?:in a read-only transaction)|(?:Query read timeout)|(?:Connection terminated)|(?:timeout expired)|(?:connection closed)|(?:error communicating with the server)|(?:Too many connection errors)|(?:database is not available)|(?:ECONNRESET)|(?:ETIMEDOUT)|(?:getaddrinfo)))"#).unwrap();
}
RE.is_match(format!("{}", error).as_str())
}
fn is_lock_busy_error (error: &tokio_postgres::Error) -> bool {
lazy_static::lazy_static! {
static ref RE: regex::Regex = regex::Regex::new(r#"(?:(?i)(?:(?:could not obtain lock on row in relation)))"#).unwrap();
}
RE.is_match(format!("{}", error).as_str())
}
fn is_zero_runtime_error (error: &tokio_postgres::Error) -> bool {
lazy_static::lazy_static! {
static ref RE: regex::Regex = regex::Regex::new(r#"(?:(?i)(?:(?:division by zero)))"#).unwrap();
}
RE.is_match(format!("{}", error).as_str())
}
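// A lazily (re)created Postgres client, paired with a oneshot receiver that surfaces errors from the background connection task.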
type MaybePostgresClient = std::option::Option::<(tokio_postgres::Client, tokio::sync::oneshot::Receiver<tokio_postgres::Error>)>;
async fn maybe_close_connection(maybe_postgres_client: &mut MaybePostgresClient) {
if let Some((client, _)) = maybe_postgres_client.take() {
drop(client);
};
}
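// Run a query, transparently reconnecting and retrying whenever the failure is a transient (passthrough) error.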
async fn execute_query_auto_retry<'e>(maybe_postgres_client: &'e mut MaybePostgresClient, connection_string: &std::string::String, query_consumed: std::string::String) -> Result<(), tokio_postgres::Error> {
let query = &query_consumed;
loop {
if maybe_postgres_client.is_none() {
let (sender, receiver) = tokio::sync::oneshot::channel();
match tokio_postgres::connect(connection_string.as_str(), tokio_postgres::NoTls).await {
Ok((client, connection)) => {
*maybe_postgres_client = Some((client, receiver ));
tokio::spawn(async move {
if let Err(err) = connection.await {
match sender.send(err) { _ => {} };
}
});
},
Err(err) => {
if is_passthrough_error(&err) {
maybe_close_connection(maybe_postgres_client).await;
continue;
} else {
return Err(err);
}
}
}
}
if let Some((ref mut client, ref mut connection_errors_receiver)) = maybe_postgres_client {
let empty_receiver_fallthrough = tokio::time::sleep(tokio::time::Duration::from_micros(1));
tokio::pin!(empty_receiver_fallthrough);
tokio::select! {
msg = connection_errors_receiver => {
match msg {
Ok(err) => {
if is_passthrough_error(&err) {
maybe_close_connection(maybe_postgres_client).await;
continue;
} else {
return Err(err)
}
},
_ => {},
}
},
_ = &mut empty_receiver_fallthrough => {}
}
match client.batch_execute(query).await {
Ok(_) => { return Ok(()) },
Err(err) => {
if is_passthrough_error(&err) {
maybe_close_connection(maybe_postgres_client).await;
continue;
} else {
return Err(err);
}
}
}
}
}
}
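// The single Lambda handler: dispatches on event.command (invoke-postgres / invoke-sqs fan out senders,
// send-postgres / send-sqs produce messages, work-postgres processes a single partition).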
async fn my_handler(input_event: serde_json::Value, context: lambda_runtime::Context) -> Result<serde_json::Value, lambda_runtime::Error> {
let region_provider = aws_config::meta::region::RegionProviderChain::default_provider();
let shared_config = aws_config::from_env().region(region_provider).load().await;
let mut error_messages = String::from("");
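// Read and validate the required configuration from environment variables, collecting all problems before failing.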
let concurrent_senders_count = std::env::var("CONCURRENT_SENDERS_COUNT").unwrap_or("".into()).parse::<u32>().unwrap_or(u32::MAX);
if concurrent_senders_count == u32::MAX {
error_messages.insert_str(error_messages.len(), format!("Missing CONCURRENT_SENDERS_COUNT env var or not a u32 integer\n").as_str());
}
let incoming_messages_per_sender_count = std::env::var("INCOMING_MESSAGES_PER_SENDER").unwrap_or("".into()).parse::<u32>().unwrap_or(u32::MAX);
if incoming_messages_per_sender_count == u32::MAX {
error_messages.insert_str(error_messages.len(), format!("Missing INCOMING_MESSAGES_PER_SENDER env var or not a u32 integer\n").as_str());
}
let concurrent_parts_count = std::env::var("CONCURRENT_PARTS_COUNT").unwrap_or("".into()).parse::<u32>().unwrap_or(u32::MAX);
if concurrent_parts_count == u32::MAX {
error_messages.insert_str(error_messages.len(), format!("Missing CONCURRENT_PARTS_COUNT env var or not a u32 integer\n").as_str());
}
let parts_per_message = std::env::var("PARTS_PER_MESSAGE").unwrap_or("".into()).parse::<u32>().unwrap_or(u32::MAX);
if parts_per_message == u32::MAX || parts_per_message > concurrent_parts_count {
error_messages.insert_str(error_messages.len(), format!("Missing PARTS_PER_MESSAGE env var, not a u32 integer, or larger than CONCURRENT_PARTS_COUNT\n").as_str());
}
let sqs_queue_url = std::env::var("SQS_QUEUE_URL").unwrap_or("".into());
if sqs_queue_url == "" {
error_messages.insert_str(error_messages.len(), format!("Missing SQS_QUEUE_URL env var\n").as_str());
}
let postgres_connection_string = std::env::var("POSTGRES_CONNECTION_STRING").unwrap_or("".into());
if postgres_connection_string == "" {
error_messages.insert_str(error_messages.len(), format!("Missing POSTGRES_CONNECTION_STRING env var\n").as_str());
}
let postgres_schema_name = std::env::var("POSTGRES_SCHEMA_NAME").unwrap_or("".into());
if postgres_schema_name == "" || postgres_schema_name.contains("\"") || postgres_schema_name.contains("'") || postgres_schema_name.contains("\\") {
error_messages.insert_str(error_messages.len(), format!("Missing POSTGRES_SCHEMA_NAME env var or it contains forbidden characters (quotes or backslashes)\n").as_str());
}
if error_messages != "" {
return Err(Box::new(LambdaError::Error {
message: error_messages
}));
}
let event = match serde_json::from_value::<Request>(input_event) {
Ok(value) => value,
Err(err) => {
return Err(Box::new(LambdaError::Error {
message: format!("Input event parse failure {}\n", err)
}));
}
};
let lambda_client = aws_sdk_lambda::Client::new(&shared_config);
let sqs_client = aws_sdk_sqs::Client::new(&shared_config);
let command = event.command;
if command == "invoke-postgres" || command == "invoke-sqs" {
let payload = match command.as_str() {
"invoke-postgres" => aws_sdk_lambda::Blob::new(format!("{{ \"command\": \"send-postgres\" }}")),
"invoke-sqs" => aws_sdk_lambda::Blob::new(format!("{{ \"command\": \"send-sqs\" }}")),
_ => unreachable!()
};
let mut maybe_postgres_client: MaybePostgresClient = MaybePostgresClient::None;
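// Reset the incoming/processing tables and seed one processing row per partition.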
match execute_query_auto_retry(&mut maybe_postgres_client, &postgres_connection_string, format!(r###"
DROP TABLE IF EXISTS "{0}"."incoming";
DROP TABLE IF EXISTS "{0}"."processing";
CREATE TABLE "{0}"."incoming"("timeguid" bigint primary key, "parts" integer[] not null);
CREATE TABLE "{0}"."processing"(
"part" integer primary key,
"border_timeguid" bigint not null,
"radix_timeguids" bigint[] not null,
"pid" integer null
);
INSERT INTO "{0}"."processing"("part", "border_timeguid", "radix_timeguids")
SELECT "i" as "part", 0 as "border_timeguid", array[]::bigint[] as "radix_timeguids"
FROM generate_series(0, {1}) "i";
"###, postgres_schema_name, concurrent_parts_count - 1)).await {
Ok(_) => {},
Err(err) => {
maybe_close_connection(&mut maybe_postgres_client).await;
return Err(Box::new(LambdaError::Error {
message: format!("Postgres resource management failure {}", err)
}));
}
};
let invoke_self = || lambda_client.invoke()
.function_name(&context.invoked_function_arn)
.invocation_type(aws_sdk_lambda::model::InvocationType::Event)
.payload(payload.clone())
.send();
let mut promises = vec!(invoke_self());
for _ in 0..concurrent_senders_count - 1 {
promises.push(invoke_self());
}
let mut success_invokes = 0;
let mut failed_invokes = 0;
for promise in promises {
match promise.await {
Ok(_) => { success_invokes = success_invokes + 1; },
Err(_) => { failed_invokes = failed_invokes + 1; }
};
}
return Ok(serde_json::to_value(Response {
msg: format!("Successfully invoked {} and failed {} lambdas", success_invokes, failed_invokes)
}).unwrap())
} else if command == "send-postgres" || command == "send-sqs" {
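// Sender: insert incoming messages, then notify one worker per referenced partition
// (by self-invoking the Lambda or by pushing to the SQS FIFO queue).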
let mut maybe_postgres_client: MaybePostgresClient = MaybePostgresClient::None;
for _ in 0..incoming_messages_per_sender_count {
let mut parts = std::vec::Vec::<u32>::with_capacity(parts_per_message as usize);
for _ in 0..parts_per_message {
parts.push(rand::random::<u32>() % concurrent_parts_count);
}
match execute_query_auto_retry(&mut maybe_postgres_client, &postgres_connection_string, format!(r###"
INSERT INTO "{0}"."incoming"("parts", "timeguid") VALUES(array{1:?},
CAST(CAST(extract(epoch from clock_timestamp()) * 1000 AS BIGINT) * 1000000 AS BIGINT) + CAST(random() * 1000000 AS BIGINT)
)
"###, postgres_schema_name, parts)).await {
Ok(_) => {},
Err(err) => {
maybe_close_connection(&mut maybe_postgres_client).await;
return Err(Box::new(LambdaError::Error {
message: format!("Send incoming entries failed {}", err)
}));
}
};
for i in 0..parts.len() {
match command.as_str() {
"send-postgres" => {
match execute_query_auto_retry(&mut maybe_postgres_client, &postgres_connection_string, format!(r###"
WITH "lock_count" AS (
SELECT CASE WHEN Count(*) > 0 THEN 0 ELSE 1 END AS "ActiveLocksCount" FROM pg_locks
WHERE database = (SELECT oid FROM pg_database WHERE datname = current_database())
AND relation = (SELECT oid FROM pg_class WHERE relname = 'processing'
AND relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = '{0}'))
AND pid = (SELECT "E"."pid" FROM "{0}"."processing" "E" WHERE "E"."part" = {1})
)
SELECT 1/(SELECT "lock_count"."ActiveLocksCount" FROM "lock_count") AS "cnt"
"###, postgres_schema_name, parts[i]
)).await {
Ok(_) => {
match lambda_client.invoke()
.function_name(&context.invoked_function_arn)
.invocation_type(aws_sdk_lambda::model::InvocationType::Event)
.payload(aws_sdk_lambda::Blob::new(format!("{{ \"command\": \"work-postgres\", \"part\": {} }}", parts[i])))
.send()
.await
{ _ => {} };
},
Err(err) => {
if !is_zero_runtime_error(&err) {
maybe_close_connection(&mut maybe_postgres_client).await;
return Err(Box::new(LambdaError::Error {
message: format!("Send incoming entries failed {}", err)
}));
}
}
};
},
"send-sqs" => {
match sqs_client.send_message()
.queue_url(sqs_queue_url.as_str())
.message_body(format!("{{ \"command\": \"work-sqs\", \"part\": {} }}", parts[i]))
.message_group_id(format!("{}", parts[i]))
.send()
.await
{ _ => {} };
},
_ => unreachable!()
};
}
}
maybe_close_connection(&mut maybe_postgres_client).await;
return Ok(serde_json::to_value(Response {
msg: format!("Successfully invoked insert incoming messages")
}).unwrap())
} else if command == "work-postgres" {
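// Worker: claim one partition with FOR UPDATE NOWAIT (skipping if another worker already holds it),
// then advance its border_timeguid over newly arrived messages.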
let current_part = match event.part {
Some(part) => part,
None => {
return Err(Box::new(LambdaError::Error {
message: format!("Partition number is not provided for worker process")
}));
}
};
let mut maybe_postgres_client: MaybePostgresClient = MaybePostgresClient::None;
'worker_loop: loop {
match execute_query_auto_retry(&mut maybe_postgres_client, &postgres_connection_string, "SELECT 0".into()).await {
Ok(_) => {},
Err(err) => {
return Err(Box::new(LambdaError::Error {
message: format!("Worker database connection error {}", err)
}));
}
};
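// Record this backend's pid on the part row, but only if the row lock can be taken without waiting.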
if let Some((ref mut client, _)) = maybe_postgres_client {
match client.batch_execute(format!(r###"
WITH "maybe_lock_part" AS (
SELECT * FROM "{0}"."processing" WHERE "part" = {1} FOR UPDATE NOWAIT
)
UPDATE "{0}"."processing" SET "pid" = pg_backend_pid()
WHERE (SELECT Count("maybe_lock_part".*) FROM "maybe_lock_part") <> 0
AND "part" = {1}
"###, postgres_schema_name, current_part).as_str()).await {
Ok(_) => {},
Err(err) => {
if !is_passthrough_error(&err) {
maybe_close_connection(&mut maybe_postgres_client).await;
if is_lock_busy_error(&err) {
return Ok(serde_json::to_value(Response {
msg: format!("Successfully skipped worker")
}).unwrap())
} else {
return Err(Box::new(LambdaError::Error {
message: format!("Worker process failed {}", err)
}));
}
} else {
continue 'worker_loop;
}
}
}
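// Re-take the row lock inside an explicit transaction; the 1/Count(...) guard aborts with
// division by zero if this backend no longer owns the part.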
match client.batch_execute(format!(r###"
BEGIN TRANSACTION;
SELECT * FROM "{0}"."processing" WHERE "part" = {1}
AND 1 / (
SELECT Count("E".*) FROM "{0}"."processing" "E"
WHERE "E"."part" = {1} AND "E"."pid" = pg_backend_pid()
) <> 2
FOR UPDATE NOWAIT;
"###, postgres_schema_name, current_part).as_str()).await {
Ok(_) => {},
Err(err) => {
if !is_passthrough_error(&err) {
maybe_close_connection(&mut maybe_postgres_client).await;
if is_lock_busy_error(&err) || is_zero_runtime_error(&err) {
return Ok(serde_json::to_value(Response {
msg: format!("Successfully skipped worker part {}", current_part)
}).unwrap())
} else {
return Err(Box::new(LambdaError::Error {
message: format!("Worker process failed {}", err)
}));
}
} else {
continue 'worker_loop;
}
}
}
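// Advance the partition: select not-yet-seen messages near or past the stored border,
// merge their timeguids into the sliding radix window, and move border_timeguid forward.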
match client.batch_execute(format!(r###"
WITH "current_part" AS (
SELECT * FROM "{0}"."processing" WHERE "part" = {1} LIMIT 1
), "current_radix_column" AS (
SELECT unnest("current_part"."radix_timeguids") AS "timeguid" FROM "current_part"
), "available_messages" AS (
SELECT * FROM "{0}"."incoming" WHERE "timeguid" + 500000000 > (
SELECT "border_timeguid" FROM "current_part" LIMIT 1
) AND "timeguid" NOT IN (
SELECT "current_radix_column"."timeguid" FROM "current_radix_column"
) AND array_position("parts", {1}) IS NOT NULL
ORDER BY "timeguid"
), "next_radix_column" AS (
SELECT "current_radix_column"."timeguid" FROM "current_radix_column"
WHERE "current_radix_column"."timeguid" + 500000000 > (
SELECT max("available_messages"."timeguid") FROM "available_messages"
) UNION ALL
SELECT "available_messages"."timeguid" FROM "available_messages"
WHERE "available_messages"."timeguid" + 500000000 > (
SELECT max("available_messages"."timeguid") FROM "available_messages"
)
)
UPDATE "{0}"."processing" SET "border_timeguid" = (
SELECT max("available_messages"."timeguid") FROM "available_messages"
), "radix_timeguids" = (
SELECT array_agg("next_radix_column"."timeguid") FROM "next_radix_column"
)
WHERE "part" = {1} AND (
SELECT count("available_messages".*) FROM "available_messages"
) > 0
"###, postgres_schema_name, current_part).as_str()).await {
Ok(_) => {},
Err(err) => {
if !is_passthrough_error(&err) {
maybe_close_connection(&mut maybe_postgres_client).await;
return Err(Box::new(LambdaError::Error {
message: format!("Worker process failed {}", err)
}));
} else {
continue 'worker_loop;
}
}
};
match client.batch_execute(format!(r###"
COMMIT
"###).as_str()).await {
Ok(_) => {},
Err(err) => {
if !is_passthrough_error(&err) {
maybe_close_connection(&mut maybe_postgres_client).await;
return Err(Box::new(LambdaError::Error {
message: format!("Worker process failed {}", err)
}));
} else {
continue 'worker_loop;
}
}
}
} else {
unreachable!();
}
break 'worker_loop;
};
maybe_close_connection(&mut maybe_postgres_client).await;
return Ok(serde_json::to_value(Response {
msg: format!("Successfully done worker part {}", current_part)
}).unwrap());
} else {
return Err(Box::new(LambdaError::Error {
message: format!("Unknown command {}", command)
}));
}
}