Last active
February 5, 2021 20:41
-
-
Save rodoufu/faf1bdd27d14a01c00b83f6a5c3f52da to your computer and use it in GitHub Desktop.
Rust UDP and TCP
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use chrono::{ | |
Timelike, | |
Utc, | |
}; | |
use std::{ | |
io::{ | |
self, | |
Read, | |
Error as IoError, | |
ErrorKind, | |
Write, | |
}, | |
net::{ | |
TcpListener, | |
TcpStream, | |
UdpSocket, | |
}, | |
time::Duration, | |
}; | |
/// Transport protocol used by a [`Connection`].
///
/// Variants are ordered `TCP < UDP` by the derived `Ord`.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum ConnectionType {
    /// Stream-oriented transport (`TcpListener` / `TcpStream`).
    TCP,
    /// Datagram transport (`UdpSocket`).
    UDP,
}
pub trait Connection { | |
fn connection_type(&self) -> ConnectionType; | |
fn new(endpoint: &str, port: u16, is_server: bool) -> io::Result<Self> where Self: Sized; | |
fn send_timeout(&mut self, buf: &[u8], timeout: Option<Duration>) -> io::Result<usize>; | |
fn send(&mut self, buf: &[u8]) -> io::Result<usize>; | |
fn recv_timeout(&mut self, buf: &mut [u8], timeout: Option<Duration>) -> io::Result<usize>; | |
fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize>; | |
fn reconnect(&mut self) -> io::Result<()>; | |
} | |
/// [`Connection`] implementation backed by a [`UdpSocket`].
///
/// Besides the socket it keeps the arguments it was created from, which
/// `reconnect` uses.
struct UdpConnection {
    /// Role requested at construction time (`true` = server).
    is_server: bool,
    /// Host name or address given at construction time.
    endpoint: String,
    /// Port given at construction time.
    port: u16,
    /// Socket that all traffic goes through.
    socket: UdpSocket,
}
impl Connection for UdpConnection { | |
fn connection_type(&self) -> ConnectionType { ConnectionType::UDP } | |
fn new(endpoint: &str, port: u16, is_server: bool) -> io::Result<Self> where Self: Sized { | |
let socket = UdpSocket::bind( | |
format!("0.0.0.0:{}", if is_server { port } else { 0 }) | |
)?; | |
if !is_server { | |
socket.connect(format!("{}:{}", endpoint, port))?; | |
} | |
Ok(UdpConnection { | |
socket, | |
endpoint: endpoint.to_string(), | |
port, | |
is_server, | |
}) | |
} | |
fn send_timeout(&mut self, buf: &[u8], timeout: Option<Duration>) -> io::Result<usize> { | |
self.socket.set_write_timeout(timeout)?; | |
self.send(buf) | |
} | |
fn send(&mut self, buf: &[u8]) -> io::Result<usize> { | |
self.socket.send(buf) | |
} | |
fn recv_timeout(&mut self, buf: &mut [u8], timeout: Option<Duration>) -> io::Result<usize> { | |
self.socket.set_read_timeout(timeout)?; | |
self.recv(buf) | |
} | |
fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> { | |
// self.socket.recv_from(buf).map(|count_addr| count_addr.0) | |
self.socket.recv(buf) | |
} | |
fn reconnect(&mut self) -> io::Result<()> { | |
*self = Self::new(&self.endpoint, self.port, self.is_server)?; | |
Ok(()) | |
} | |
} | |
/// TCP [`Connection`]: either the listening (server) side or the connecting (client) side.
pub enum TcpConnection {
    /// Server side: a bound listener plus the accepted peer stream.
    Listener {
        /// Address the listener was bound to.
        endpoint: String,
        /// Port the listener was bound to.
        port: u16,
        /// Socket listening on `endpoint:port`.
        listener: TcpListener,
        /// Accepted peer; `None` until a send/recv accepts one.
        stream: Option<TcpStream>,
    },
    /// Client side: a stream connected to `endpoint:port`.
    Sender {
        /// Address that was connected to.
        endpoint: String,
        /// Port that was connected to.
        port: u16,
        /// The connected stream.
        stream: TcpStream,
    },
}
impl TcpConnection { | |
fn get_stream(&mut self) -> io::Result<&mut TcpStream> { | |
match self { | |
TcpConnection::Listener { listener, stream, .. } => { | |
if let None = stream { | |
*stream = Some(listener.incoming().next().unwrap()?); | |
} | |
if let Some(stream) = stream { | |
Ok(stream) | |
} else { | |
Err(IoError::new(ErrorKind::Other, "No stream to write")) | |
} | |
} | |
TcpConnection::Sender { stream, .. } => Ok(stream) | |
} | |
} | |
} | |
impl Connection for TcpConnection { | |
fn connection_type(&self) -> ConnectionType { ConnectionType::TCP } | |
fn new(endpoint: &str, port: u16, is_server: bool) -> io::Result<Self> where Self: Sized { | |
let connect_to = format!("{}:{}", endpoint, port); | |
if is_server { | |
Ok(Self::Listener { | |
listener: TcpListener::bind(connect_to)?, | |
stream: None, | |
endpoint: endpoint.to_string(), | |
port, | |
}) | |
} else { | |
Ok(Self::Sender { | |
stream: TcpStream::connect(connect_to)?, | |
endpoint: endpoint.to_string(), | |
port, | |
}) | |
} | |
} | |
fn send_timeout(&mut self, buf: &[u8], timeout: Option<Duration>) -> io::Result<usize> { | |
let stream = self.get_stream()?; | |
stream.set_write_timeout(timeout)?; | |
self.send(buf) | |
} | |
fn send(&mut self, buf: &[u8]) -> io::Result<usize> { | |
let stream = self.get_stream()?; | |
stream.write(buf) | |
} | |
fn recv_timeout(&mut self, buf: &mut [u8], timeout: Option<Duration>) -> io::Result<usize> { | |
let stream = self.get_stream()?; | |
stream.set_read_timeout(timeout)?; | |
self.recv(buf) | |
} | |
fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> { | |
let stream = self.get_stream()?; | |
stream.read(buf) | |
} | |
fn reconnect(&mut self) -> io::Result<()> { | |
*self = match self { | |
TcpConnection::Listener { endpoint, port, .. } => | |
Self::new(endpoint, *port, true)?, | |
TcpConnection::Sender { endpoint, port, .. } => | |
Self::new(endpoint, *port, false)?, | |
}; | |
Ok(()) | |
} | |
} | |
pub fn new_connection( | |
connection_type: ConnectionType, endpoint: &str, port: u16, bind: bool, | |
) -> io::Result<Box<dyn Connection + '_>> { | |
match connection_type { | |
ConnectionType::TCP => Ok(Box::new(TcpConnection::new(endpoint, port, bind)?)), | |
ConnectionType::UDP => Ok(Box::new(UdpConnection::new(endpoint, port, bind)?)) | |
} | |
} | |
pub fn precise_time_ns() -> i64 { | |
let now = Utc::now(); | |
let seconds: i64 = now.timestamp(); | |
let nanoseconds: i64 = now.nanosecond() as i64; | |
(seconds * 1000 * 1000 * 1000) + nanoseconds | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use clap::{ | |
Arg, | |
App, | |
AppSettings, | |
}; | |
use std::{ | |
io::ErrorKind, | |
time::Duration, | |
thread, | |
}; | |
use udp_tcp::{ | |
ConnectionType, | |
new_connection, | |
precise_time_ns, | |
}; | |
const VERSION: &'static str = env!("CARGO_PKG_VERSION"); | |
/// Name (and long flag) of the argument selecting TCP or UDP.
const CONNECTION_TYPE_ARG: &str = "connection-type";
fn main() -> Result<(), String> { | |
let app = cli_arguments(); | |
let matches = app.get_matches(); | |
let connection_type = if let Some(value) = matches.value_of(CONNECTION_TYPE_ARG) { | |
match value { | |
"udp" | "UDP" => Ok(ConnectionType::UDP), | |
"tcp" | "TCP" => Ok(ConnectionType::TCP), | |
_ => Err(format!("Invalid connection type {}", value)), | |
} | |
} else { | |
Ok(ConnectionType::TCP) | |
}?; | |
let endpoint = matches.value_of("endpoint").unwrap_or("0.0.0.0"); | |
let port = matches.value_of("port").unwrap_or("3333") | |
.parse::<u16>().map_err(|e| format!("Error getting port number {:?}", e))?; | |
let repeat = matches.value_of("repeat").unwrap_or("1") | |
.parse::<u32>().map_err(|e| format!("Error getting repeat {:?}", e))?; | |
let timeout_ms = matches.value_of("timeout").unwrap_or("100") | |
.parse::<u64>().map_err(|e| format!("Error getting timeout {:?}", e))?; | |
let delay_ms = matches.value_of("delay").unwrap_or("10") | |
.parse::<u64>().map_err(|e| format!("Error getting delay {:?}", e))?; | |
let lost_connection_delay_ms = matches.value_of("lost-connection-delay") | |
.unwrap_or("1000") | |
.parse::<u64>().map_err(|e| format!("Error getting lost connection delay {:?}", e))?; | |
let delay_each_n_msg = matches.value_of("delay-each-n-msg").unwrap_or("1") | |
.parse::<u32>().map_err(|e| format!("Error getting number of messages to wait between delays {:?}", e))?; | |
let is_server = matches.is_present("server"); | |
let is_client = matches.is_present("client"); | |
let is_follow = matches.is_present("follow"); | |
let is_verbose = matches.is_present("verbose"); | |
let mut connection = new_connection( | |
connection_type, endpoint, port, is_server, | |
).map_err(|e| format!("Error creating connection {:?}", e))?; | |
println!("Repeat {}", repeat); | |
let delay_duration = Duration::from_millis(delay_ms); | |
let lost_conn_delay = Duration::from_millis(lost_connection_delay_ms); | |
let mut is_timeout = false; | |
let mut before = None; | |
let mut count_messages = 0; | |
let mut buf: [u8; 13] = [0; 13]; | |
if is_server { | |
println!("Listening at {}", port); | |
} else if is_client { | |
println!("Connecting to {}:{}", endpoint, port); | |
} | |
for i in 0..repeat { | |
if is_timeout { | |
break; | |
} | |
count_messages = i; | |
let timeout = if repeat > 1 && i > 0 && !is_follow { | |
Some(Duration::from_millis(timeout_ms)) | |
} else { | |
None | |
}; | |
if is_server { | |
let mut id: u32; | |
loop { | |
let count = match connection.recv_timeout(&mut buf, timeout) { | |
Ok(v) => v, | |
Err(err) => { | |
if ErrorKind::WouldBlock == err.kind() { | |
is_timeout = true; | |
break; | |
} else { | |
return Err(format!("Error receiving data {:?}", err)); | |
} | |
} | |
}; | |
if before.is_none() { | |
before = Some(precise_time_ns()); | |
} | |
id = 0; | |
for pos in 0..4 { | |
id *= 256; | |
id += buf[3 - pos] as u32; | |
} | |
if is_verbose { | |
println!("{}/{} Received {} bytes id: {} - {:?}", i, repeat, count, id, buf); | |
} | |
if !is_follow { | |
break; | |
} | |
} | |
} else if is_client { | |
if i % delay_each_n_msg == 0 { | |
thread::sleep(delay_duration); | |
} | |
if before.is_none() { | |
before = Some(precise_time_ns()); | |
} | |
let mut it = i; | |
for pos in 0..4 { | |
buf[pos] = (it % 256) as u8; | |
it /= 256; | |
} | |
let mut count = None; | |
let mut reconnect = false; | |
for _ in 0..5 { | |
if reconnect { | |
thread::sleep(lost_conn_delay); | |
connection.reconnect() | |
.map_err(|e| format!("Error reconnecting again {:?}", e))?; | |
} | |
reconnect = false; | |
let retry; | |
let resp_count = match connection.send_timeout(&buf, timeout) { | |
Ok(v) => { | |
retry = false; | |
Ok(v) | |
} | |
Err(err) => { | |
match err.kind() { | |
ErrorKind::BrokenPipe | ErrorKind::ConnectionRefused | | |
ErrorKind::ConnectionReset | ErrorKind::ConnectionAborted | | |
ErrorKind::TimedOut | ErrorKind::Interrupted | ErrorKind::UnexpectedEof => { | |
// Try again | |
retry = true; | |
println!("Trying again after: {:?}", err); | |
thread::sleep(lost_conn_delay); | |
if let Err(err) = connection.reconnect() | |
.map_err(|e| format!("Error reconnecting {:?}", e)) { | |
eprintln!("{}", err); | |
reconnect = true; | |
} | |
Err(err) | |
} | |
_ => { | |
retry = false; | |
Err(err) | |
} | |
}.map_err(|e| format!("Error sending data {:?}", e)) | |
} | |
}; | |
if !retry { | |
count = Some(resp_count?); | |
break; | |
} | |
} | |
if is_verbose { | |
println!( | |
"{}/{} Sent {} bytes id: {} - {:?}", i, repeat, count.unwrap_or(0), i, | |
buf, | |
); | |
} | |
} | |
} | |
let after = precise_time_ns(); | |
let time_diff = after - before.unwrap(); | |
count_messages += 1; | |
println!( | |
"Process took {} milliseconds{} to process {}/{} ({}%) messages", | |
time_diff as f64 / 1e6 - if is_timeout { timeout_ms as f64 } else { 0f64 }, | |
if is_timeout { " after timeout" } else { "" }, | |
count_messages, repeat, (100 * count_messages) as f64 / repeat as f64, | |
); | |
Ok(()) | |
} | |
fn cli_arguments() -> App<'static, 'static> { | |
App::new("udp_tcp").version(VERSION) | |
.arg( | |
Arg::with_name(CONNECTION_TYPE_ARG).long(CONNECTION_TYPE_ARG) | |
.help("Connection type TCP/UDP").takes_value(true) | |
.possible_values(&["udp", "tcp", "UDP", "TCP"]) | |
) | |
.arg( | |
Arg::with_name("server").short("s").long("server") | |
.help("Start server") | |
) | |
.arg( | |
Arg::with_name("client").short("c").long("client") | |
.help("Start client") | |
) | |
.arg( | |
Arg::with_name("listener").long("listener") | |
.help("Start listener") | |
) | |
.arg( | |
Arg::with_name("sender").long("sender") | |
.help("Start sender") | |
) | |
.arg( | |
Arg::with_name("endpoint").short("e").long("endpoint").takes_value(true) | |
.help("Endpoint") | |
) | |
.arg( | |
Arg::with_name("port").short("p").long("port").takes_value(true) | |
.help("Port number") | |
) | |
.arg( | |
Arg::with_name("repeat").short("r").long("repeat").takes_value(true) | |
.help("Repeat N times") | |
) | |
.arg( | |
Arg::with_name("timeout").short("t").long("timeout").takes_value(true) | |
.help("Timeout in milliseconds") | |
) | |
.arg( | |
Arg::with_name("delay").short("d").long("delay").takes_value(true) | |
.help("Delay between messages in milliseconds") | |
) | |
.arg( | |
Arg::with_name("lost-connection-delay") | |
.long("lost-connection-delay").takes_value(true) | |
.help("Lost connection delay in milliseconds") | |
) | |
.arg( | |
Arg::with_name("delay-each-n-msg").long("delay-each-n-msg") | |
.takes_value(true) | |
.help("number of messages to wait between delays") | |
) | |
.arg( | |
Arg::with_name("follow").short("f") | |
.help("Keep receiving info") | |
) | |
.arg( | |
Arg::with_name("verbose").short("v").long("verbose") | |
.help("Verbose") | |
) | |
.setting(AppSettings::ArgRequiredElseHelp) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment