Created
August 3, 2021 20:09
-
-
Save film42/c56a0e9b6c7b4da4936249813034e05d to your computer and use it in GitHub Desktop.
Forked tokio diesel
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use async_trait::async_trait; | |
use diesel::{ | |
connection::SimpleConnection, | |
dsl::Limit, | |
query_dsl::{ | |
methods::{ExecuteDsl, LimitDsl, LoadQuery}, | |
RunQueryDsl, | |
}, | |
r2d2::{ConnectionManager, Pool, R2D2Connection}, | |
result::QueryResult, | |
}; | |
use std::{error::Error as StdError, fmt}; | |
use tokio::task; | |
pub type AsyncResult<R> = Result<R, AsyncError>; | |
#[derive(Debug)] | |
pub enum AsyncError { | |
// Failed to checkout a connection | |
Checkout(r2d2::Error), | |
// The query failed in some way | |
Error(diesel::result::Error), | |
} | |
pub trait OptionalExtension<T> { | |
fn optional(self) -> Result<Option<T>, AsyncError>; | |
} | |
impl<T> OptionalExtension<T> for AsyncResult<T> { | |
fn optional(self) -> Result<Option<T>, AsyncError> { | |
match self { | |
Ok(value) => Ok(Some(value)), | |
Err(AsyncError::Error(diesel::result::Error::NotFound)) => Ok(None), | |
Err(e) => Err(e), | |
} | |
} | |
} | |
impl fmt::Display for AsyncError { | |
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | |
match *self { | |
AsyncError::Checkout(ref err) => err.fmt(f), | |
AsyncError::Error(ref err) => err.fmt(f), | |
} | |
} | |
} | |
impl StdError for AsyncError { | |
fn source(&self) -> Option<&(dyn StdError + 'static)> { | |
match *self { | |
AsyncError::Checkout(ref err) => Some(err), | |
AsyncError::Error(ref err) => Some(err), | |
} | |
} | |
} | |
#[async_trait] | |
pub trait AsyncSimpleConnection<Conn> | |
where | |
Conn: 'static + SimpleConnection, | |
{ | |
async fn batch_execute_async(&self, query: &str) -> AsyncResult<()>; | |
} | |
#[async_trait] | |
impl<Conn> AsyncSimpleConnection<Conn> for Pool<ConnectionManager<Conn>> | |
where | |
Conn: 'static + R2D2Connection, | |
{ | |
#[inline] | |
async fn batch_execute_async(&self, query: &str) -> AsyncResult<()> { | |
let self_ = self.clone(); | |
let query = query.to_string(); | |
task::spawn_blocking(move || { | |
let mut conn = self_.get().map_err(AsyncError::Checkout)?; | |
conn.batch_execute(&query).map_err(AsyncError::Error) | |
}) | |
.await | |
.expect("task has panicked") | |
} | |
} | |
#[async_trait] | |
pub trait AsyncConnection<Conn>: AsyncSimpleConnection<Conn> | |
where | |
Conn: 'static + R2D2Connection, | |
{ | |
async fn run<R, Func>(&self, f: Func) -> AsyncResult<R> | |
where | |
R: 'static + Send, | |
Func: 'static + FnOnce(&mut Conn) -> QueryResult<R> + Send; | |
async fn transaction<R, Func>(&self, f: Func) -> AsyncResult<R> | |
where | |
R: 'static + Send, | |
Func: 'static + FnOnce(&mut Conn) -> QueryResult<R> + Send; | |
} | |
#[async_trait] | |
impl<Conn> AsyncConnection<Conn> for Pool<ConnectionManager<Conn>> | |
where | |
Conn: 'static + R2D2Connection, | |
{ | |
#[inline] | |
async fn run<R, Func>(&self, f: Func) -> AsyncResult<R> | |
where | |
R: 'static + Send, | |
Func: 'static + FnOnce(&mut Conn) -> QueryResult<R> + Send, | |
{ | |
let self_ = self.clone(); | |
task::spawn_blocking(move || { | |
let mut conn = self_.get().map_err(AsyncError::Checkout)?; | |
f(&mut *conn).map_err(AsyncError::Error) | |
}) | |
.await | |
.expect("task has panicked") | |
} | |
#[inline] | |
async fn transaction<R, Func>(&self, f: Func) -> AsyncResult<R> | |
where | |
R: 'static + Send, | |
Func: 'static + FnOnce(&mut Conn) -> QueryResult<R> + Send, | |
{ | |
let self_ = self.clone(); | |
task::spawn_blocking(move || { | |
let mut conn = self_.get().map_err(AsyncError::Checkout)?; | |
conn.transaction(|txn| f(&mut *txn)) | |
.map_err(AsyncError::Error) | |
}) | |
.await | |
.expect("task has panicked") | |
} | |
} | |
#[async_trait] | |
pub trait AsyncRunQueryDsl<Conn, AsyncConn> | |
where | |
Conn: 'static + R2D2Connection, | |
{ | |
async fn execute_async(self, asc: &AsyncConn) -> AsyncResult<usize> | |
where | |
Self: ExecuteDsl<Conn>; | |
async fn load_async<U>(self, asc: &AsyncConn) -> AsyncResult<Vec<U>> | |
where | |
U: 'static + Send, | |
Self: LoadQuery<Conn, U>; | |
async fn get_result_async<U>(self, asc: &AsyncConn) -> AsyncResult<U> | |
where | |
U: 'static + Send, | |
Self: LoadQuery<Conn, U>; | |
async fn get_results_async<U>(self, asc: &AsyncConn) -> AsyncResult<Vec<U>> | |
where | |
U: 'static + Send, | |
Self: LoadQuery<Conn, U>; | |
async fn first_async<U>(self, asc: &AsyncConn) -> AsyncResult<U> | |
where | |
U: 'static + Send, | |
Self: LimitDsl, | |
Limit<Self>: LoadQuery<Conn, U>; | |
} | |
#[async_trait] | |
impl<T, Conn> AsyncRunQueryDsl<Conn, Pool<ConnectionManager<Conn>>> for T | |
where | |
T: 'static + Send + RunQueryDsl<Conn>, | |
Conn: 'static + R2D2Connection, | |
{ | |
async fn execute_async(self, asc: &Pool<ConnectionManager<Conn>>) -> AsyncResult<usize> | |
where | |
Self: ExecuteDsl<Conn>, | |
{ | |
asc.run(|conn| self.execute(&mut *conn)).await | |
} | |
async fn load_async<U>(self, asc: &Pool<ConnectionManager<Conn>>) -> AsyncResult<Vec<U>> | |
where | |
U: 'static + Send, | |
Self: LoadQuery<Conn, U>, | |
{ | |
asc.run(|conn| self.load(&mut *conn)).await | |
} | |
async fn get_result_async<U>(self, asc: &Pool<ConnectionManager<Conn>>) -> AsyncResult<U> | |
where | |
U: 'static + Send, | |
Self: LoadQuery<Conn, U>, | |
{ | |
asc.run(|conn| self.get_result(&mut *conn)).await | |
} | |
async fn get_results_async<U>(self, asc: &Pool<ConnectionManager<Conn>>) -> AsyncResult<Vec<U>> | |
where | |
U: 'static + Send, | |
Self: LoadQuery<Conn, U>, | |
{ | |
asc.run(|conn| self.get_results(&mut *conn)).await | |
} | |
async fn first_async<U>(self, asc: &Pool<ConnectionManager<Conn>>) -> AsyncResult<U> | |
where | |
U: 'static + Send, | |
Self: LimitDsl, | |
Limit<Self>: LoadQuery<Conn, U>, | |
{ | |
asc.run(|conn| self.first(&mut *conn)).await | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment