Skip to content

Instantly share code, notes, and snippets.

@rodoufu
Last active February 5, 2021 20:41
Show Gist options
  • Save rodoufu/faf1bdd27d14a01c00b83f6a5c3f52da to your computer and use it in GitHub Desktop.
Save rodoufu/faf1bdd27d14a01c00b83f6a5c3f52da to your computer and use it in GitHub Desktop.
Rust UDP and TCP
use chrono::{
Timelike,
Utc,
};
use std::{
io::{
self,
Read,
Error as IoError,
ErrorKind,
Write,
},
net::{
TcpListener,
TcpStream,
UdpSocket,
},
time::Duration,
};
/// Transport protocol used by a [`Connection`].
///
/// Variant order matters: the derived `Ord` sorts `TCP` before `UDP`.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum ConnectionType {
    /// Stream transport, backed by `TcpListener` / `TcpStream`.
    TCP,
    /// Datagram transport, backed by `UdpSocket`.
    UDP,
}
/// A byte channel over TCP or UDP that can be re-established after a failure.
///
/// The trait is object safe (`new` requires `Self: Sized`), so it can be used as
/// `Box<dyn Connection>`.
pub trait Connection {
    /// Creates a connection for `endpoint:port`.
    ///
    /// The implementations in this crate bind `port` locally when `is_server` is true,
    /// and connect to `endpoint:port` otherwise.
    fn new(endpoint: &str, port: u16, is_server: bool) -> io::Result<Self> where Self: Sized;

    /// Which transport this connection uses.
    fn connection_type(&self) -> ConnectionType;

    /// Sends `buf` and returns the number of bytes sent.
    fn send(&mut self, buf: &[u8]) -> io::Result<usize>;

    /// Like [`Connection::send`], after applying `timeout` as the write timeout
    /// (`None` means block indefinitely).
    fn send_timeout(&mut self, buf: &[u8], timeout: Option<Duration>) -> io::Result<usize>;

    /// Receives into `buf` and returns the number of bytes read.
    fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize>;

    /// Like [`Connection::recv`], after applying `timeout` as the read timeout
    /// (`None` means block indefinitely).
    fn recv_timeout(&mut self, buf: &mut [u8], timeout: Option<Duration>) -> io::Result<usize>;

    /// Re-establishes the connection after an error.
    fn reconnect(&mut self) -> io::Result<()>;
}
/// UDP implementation of [`Connection`].
///
/// The constructor arguments are kept alongside the socket so the connection can
/// rebuild itself in `reconnect`.
struct UdpConnection {
    /// Host given to `new`; clients connect to it, servers only store it.
    endpoint: String,
    /// Remote port (client) or locally bound port (server).
    port: u16,
    /// `true` when the socket was bound as a server rather than connected as a client.
    is_server: bool,
    /// The underlying datagram socket.
    socket: UdpSocket,
}
impl Connection for UdpConnection {
    fn connection_type(&self) -> ConnectionType { ConnectionType::UDP }

    /// Creates the UDP socket.
    ///
    /// A server binds `0.0.0.0:port`; the `endpoint` is stored but not used for binding.
    /// A client binds an OS-chosen port on `0.0.0.0` and connects the socket to
    /// `endpoint:port`, which sets the default destination for `send`.
    ///
    /// # Errors
    /// Propagates failures from `UdpSocket::bind` / `UdpSocket::connect`.
    fn new(endpoint: &str, port: u16, is_server: bool) -> io::Result<Self> where Self: Sized {
        let socket = UdpSocket::bind(
            format!("0.0.0.0:{}", if is_server { port } else { 0 })
        )?;
        if !is_server {
            socket.connect(format!("{}:{}", endpoint, port))?;
        }
        Ok(UdpConnection {
            socket,
            endpoint: endpoint.to_string(),
            port,
            is_server,
        })
    }

    /// Sets the socket write timeout (`None` blocks indefinitely), then sends `buf`.
    fn send_timeout(&mut self, buf: &[u8], timeout: Option<Duration>) -> io::Result<usize> {
        self.socket.set_write_timeout(timeout)?;
        self.send(buf)
    }

    /// Sends `buf` as one datagram to the connected peer.
    ///
    /// NOTE(review): server sockets are never connected, so this presumably fails for
    /// servers (std documents `send` as targeting the connected address). Confirm this
    /// if server-side replies are ever needed.
    fn send(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.socket.send(buf)
    }

    /// Sets the socket read timeout (`None` blocks indefinitely), then receives one datagram.
    fn recv_timeout(&mut self, buf: &mut [u8], timeout: Option<Duration>) -> io::Result<usize> {
        self.socket.set_read_timeout(timeout)?;
        self.recv(buf)
    }

    /// Receives one datagram into `buf` and returns the number of bytes read.
    fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.socket.recv(buf)
    }

    /// Rebuilds a client socket and leaves a server socket untouched.
    ///
    /// A bound UDP server socket has no peer to reconnect to. Rebuilding it would also
    /// fail: the replacement socket is created before the old one in `self.socket` is
    /// dropped, so binding the same port returns `AddrInUse`.
    fn reconnect(&mut self) -> io::Result<()> {
        if self.is_server {
            return Ok(());
        }
        *self = Self::new(&self.endpoint, self.port, self.is_server)?;
        Ok(())
    }
}
/// TCP implementation of [`Connection`].
///
/// A server is a `Listener` that starts without a peer stream. A client is a `Sender`
/// that holds a stream connected when it was created.
pub enum TcpConnection {
    /// Server side: a bound listener plus the client stream once one is accepted.
    Listener {
        /// Host the listener was bound to.
        endpoint: String,
        /// Port the listener was bound to.
        port: u16,
        /// The bound listening socket.
        listener: TcpListener,
        /// Accepted client stream; `None` until a client has been accepted.
        stream: Option<TcpStream>,
    },
    /// Client side: a stream connected to `endpoint:port`.
    Sender {
        /// Host the stream connected to.
        endpoint: String,
        /// Port the stream connected to.
        port: u16,
        /// The connected stream.
        stream: TcpStream,
    },
}
impl TcpConnection {
    /// Returns the stream to read from and write to.
    ///
    /// For a `Sender` this is its connected stream. For a `Listener` the first call
    /// blocks in `TcpListener::accept` until a client connects; the accepted stream is
    /// cached and returned by later calls.
    ///
    /// # Errors
    /// Propagates any error from `TcpListener::accept`.
    fn get_stream(&mut self) -> io::Result<&mut TcpStream> {
        match self {
            TcpConnection::Listener { listener, stream, .. } => {
                if stream.is_none() {
                    // The peer address returned by `accept` is not needed.
                    let (accepted, _peer) = listener.accept()?;
                    *stream = Some(accepted);
                }
                // Always `Some` at this point; the error keeps the original message
                // instead of panicking.
                stream
                    .as_mut()
                    .ok_or_else(|| IoError::new(ErrorKind::Other, "No stream to write"))
            }
            TcpConnection::Sender { stream, .. } => Ok(stream),
        }
    }
}
impl Connection for TcpConnection {
    fn connection_type(&self) -> ConnectionType { ConnectionType::TCP }

    /// Binds a listener on `endpoint:port` (server) or connects to it (client).
    ///
    /// A server does not accept a client here; that happens on its first I/O call.
    fn new(endpoint: &str, port: u16, is_server: bool) -> io::Result<Self> where Self: Sized {
        let connect_to = format!("{}:{}", endpoint, port);
        if is_server {
            Ok(Self::Listener {
                listener: TcpListener::bind(connect_to)?,
                stream: None,
                endpoint: endpoint.to_string(),
                port,
            })
        } else {
            Ok(Self::Sender {
                stream: TcpStream::connect(connect_to)?,
                endpoint: endpoint.to_string(),
                port,
            })
        }
    }

    /// Sets the write timeout (`None` blocks indefinitely), then sends `buf`.
    fn send_timeout(&mut self, buf: &[u8], timeout: Option<Duration>) -> io::Result<usize> {
        let stream = self.get_stream()?;
        stream.set_write_timeout(timeout)?;
        self.send(buf)
    }

    /// Writes all of `buf` and returns `buf.len()`.
    fn send(&mut self, buf: &[u8]) -> io::Result<usize> {
        let stream = self.get_stream()?;
        // A single `write` may accept only part of `buf`; never truncate a message silently.
        stream.write_all(buf)?;
        Ok(buf.len())
    }

    /// Sets the read timeout (`None` blocks indefinitely), then reads into `buf`.
    fn recv_timeout(&mut self, buf: &mut [u8], timeout: Option<Duration>) -> io::Result<usize> {
        let stream = self.get_stream()?;
        stream.set_read_timeout(timeout)?;
        self.recv(buf)
    }

    /// Reads whatever is available into `buf` and returns the byte count (0 at end of stream).
    fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let stream = self.get_stream()?;
        stream.read(buf)
    }

    /// Re-establishes the connection.
    ///
    /// A `Listener` keeps its listening socket and drops the accepted stream, so the next
    /// I/O call accepts a new client. Rebinding the port while the old listener is still
    /// open would fail with `AddrInUse`. A `Sender` opens a fresh connection to
    /// `endpoint:port`.
    fn reconnect(&mut self) -> io::Result<()> {
        match self {
            TcpConnection::Listener { stream, .. } => {
                *stream = None;
            }
            TcpConnection::Sender { stream, endpoint, port } => {
                *stream = TcpStream::connect(format!("{}:{}", endpoint, port))?;
            }
        }
        Ok(())
    }
}
/// Opens a boxed connection of the requested type.
///
/// `bind` selects server mode (bind/listen on `port`) instead of connecting to
/// `endpoint:port`. Both implementations copy `endpoint` into an owned `String`, so the
/// returned box does not borrow it and can outlive the caller's `endpoint`.
///
/// # Errors
/// Propagates the I/O error from creating the underlying socket.
pub fn new_connection(
    connection_type: ConnectionType, endpoint: &str, port: u16, bind: bool,
) -> io::Result<Box<dyn Connection>> {
    match connection_type {
        ConnectionType::TCP => Ok(Box::new(TcpConnection::new(endpoint, port, bind)?)),
        ConnectionType::UDP => Ok(Box::new(UdpConnection::new(endpoint, port, bind)?)),
    }
}
/// Current wall-clock time as nanoseconds since the Unix epoch (UTC).
///
/// The value comes from `chrono::Utc::now`, i.e. the system clock, so it can jump if
/// the clock is adjusted.
pub fn precise_time_ns() -> i64 {
    const NANOS_PER_SECOND: i64 = 1_000_000_000;
    let now = Utc::now();
    let whole_seconds = now.timestamp();
    let sub_second = i64::from(now.nanosecond());
    whole_seconds * NANOS_PER_SECOND + sub_second
}
use clap::{
Arg,
App,
AppSettings,
};
use std::{
io::ErrorKind,
time::Duration,
thread,
};
use udp_tcp::{
ConnectionType,
new_connection,
precise_time_ns,
};
const VERSION: &'static str = env!("CARGO_PKG_VERSION");
const CONNECTION_TYPE_ARG: &'static str = "connection-type";
fn main() -> Result<(), String> {
let app = cli_arguments();
let matches = app.get_matches();
let connection_type = if let Some(value) = matches.value_of(CONNECTION_TYPE_ARG) {
match value {
"udp" | "UDP" => Ok(ConnectionType::UDP),
"tcp" | "TCP" => Ok(ConnectionType::TCP),
_ => Err(format!("Invalid connection type {}", value)),
}
} else {
Ok(ConnectionType::TCP)
}?;
let endpoint = matches.value_of("endpoint").unwrap_or("0.0.0.0");
let port = matches.value_of("port").unwrap_or("3333")
.parse::<u16>().map_err(|e| format!("Error getting port number {:?}", e))?;
let repeat = matches.value_of("repeat").unwrap_or("1")
.parse::<u32>().map_err(|e| format!("Error getting repeat {:?}", e))?;
let timeout_ms = matches.value_of("timeout").unwrap_or("100")
.parse::<u64>().map_err(|e| format!("Error getting timeout {:?}", e))?;
let delay_ms = matches.value_of("delay").unwrap_or("10")
.parse::<u64>().map_err(|e| format!("Error getting delay {:?}", e))?;
let lost_connection_delay_ms = matches.value_of("lost-connection-delay")
.unwrap_or("1000")
.parse::<u64>().map_err(|e| format!("Error getting lost connection delay {:?}", e))?;
let delay_each_n_msg = matches.value_of("delay-each-n-msg").unwrap_or("1")
.parse::<u32>().map_err(|e| format!("Error getting number of messages to wait between delays {:?}", e))?;
let is_server = matches.is_present("server");
let is_client = matches.is_present("client");
let is_follow = matches.is_present("follow");
let is_verbose = matches.is_present("verbose");
let mut connection = new_connection(
connection_type, endpoint, port, is_server,
).map_err(|e| format!("Error creating connection {:?}", e))?;
println!("Repeat {}", repeat);
let delay_duration = Duration::from_millis(delay_ms);
let lost_conn_delay = Duration::from_millis(lost_connection_delay_ms);
let mut is_timeout = false;
let mut before = None;
let mut count_messages = 0;
let mut buf: [u8; 13] = [0; 13];
if is_server {
println!("Listening at {}", port);
} else if is_client {
println!("Connecting to {}:{}", endpoint, port);
}
for i in 0..repeat {
if is_timeout {
break;
}
count_messages = i;
let timeout = if repeat > 1 && i > 0 && !is_follow {
Some(Duration::from_millis(timeout_ms))
} else {
None
};
if is_server {
let mut id: u32;
loop {
let count = match connection.recv_timeout(&mut buf, timeout) {
Ok(v) => v,
Err(err) => {
if ErrorKind::WouldBlock == err.kind() {
is_timeout = true;
break;
} else {
return Err(format!("Error receiving data {:?}", err));
}
}
};
if before.is_none() {
before = Some(precise_time_ns());
}
id = 0;
for pos in 0..4 {
id *= 256;
id += buf[3 - pos] as u32;
}
if is_verbose {
println!("{}/{} Received {} bytes id: {} - {:?}", i, repeat, count, id, buf);
}
if !is_follow {
break;
}
}
} else if is_client {
if i % delay_each_n_msg == 0 {
thread::sleep(delay_duration);
}
if before.is_none() {
before = Some(precise_time_ns());
}
let mut it = i;
for pos in 0..4 {
buf[pos] = (it % 256) as u8;
it /= 256;
}
let mut count = None;
let mut reconnect = false;
for _ in 0..5 {
if reconnect {
thread::sleep(lost_conn_delay);
connection.reconnect()
.map_err(|e| format!("Error reconnecting again {:?}", e))?;
}
reconnect = false;
let retry;
let resp_count = match connection.send_timeout(&buf, timeout) {
Ok(v) => {
retry = false;
Ok(v)
}
Err(err) => {
match err.kind() {
ErrorKind::BrokenPipe | ErrorKind::ConnectionRefused |
ErrorKind::ConnectionReset | ErrorKind::ConnectionAborted |
ErrorKind::TimedOut | ErrorKind::Interrupted | ErrorKind::UnexpectedEof => {
// Try again
retry = true;
println!("Trying again after: {:?}", err);
thread::sleep(lost_conn_delay);
if let Err(err) = connection.reconnect()
.map_err(|e| format!("Error reconnecting {:?}", e)) {
eprintln!("{}", err);
reconnect = true;
}
Err(err)
}
_ => {
retry = false;
Err(err)
}
}.map_err(|e| format!("Error sending data {:?}", e))
}
};
if !retry {
count = Some(resp_count?);
break;
}
}
if is_verbose {
println!(
"{}/{} Sent {} bytes id: {} - {:?}", i, repeat, count.unwrap_or(0), i,
buf,
);
}
}
}
let after = precise_time_ns();
let time_diff = after - before.unwrap();
count_messages += 1;
println!(
"Process took {} milliseconds{} to process {}/{} ({}%) messages",
time_diff as f64 / 1e6 - if is_timeout { timeout_ms as f64 } else { 0f64 },
if is_timeout { " after timeout" } else { "" },
count_messages, repeat, (100 * count_messages) as f64 / repeat as f64,
);
Ok(())
}
/// Builds the command-line interface definition (a clap 2.x `App`).
///
/// Without any argument clap prints the help text (`ArgRequiredElseHelp`).
/// NOTE(review): `--listener` and `--sender` are accepted, but `main` never reads them.
/// Confirm whether they were meant as aliases for `--server` / `--client`.
fn cli_arguments() -> App<'static, 'static> {
    App::new("udp_tcp").version(VERSION)
        .arg(
            Arg::with_name(CONNECTION_TYPE_ARG).long(CONNECTION_TYPE_ARG)
                .help("Connection type TCP/UDP").takes_value(true)
                .possible_values(&["udp", "tcp", "UDP", "TCP"])
        )
        .arg(
            Arg::with_name("server").short("s").long("server")
                .help("Start server")
        )
        .arg(
            Arg::with_name("client").short("c").long("client")
                .help("Start client")
        )
        .arg(
            Arg::with_name("listener").long("listener")
                .help("Start listener")
        )
        .arg(
            Arg::with_name("sender").long("sender")
                .help("Start sender")
        )
        .arg(
            Arg::with_name("endpoint").short("e").long("endpoint").takes_value(true)
                .help("Endpoint")
        )
        .arg(
            Arg::with_name("port").short("p").long("port").takes_value(true)
                .help("Port number")
        )
        .arg(
            Arg::with_name("repeat").short("r").long("repeat").takes_value(true)
                .help("Repeat N times")
        )
        .arg(
            Arg::with_name("timeout").short("t").long("timeout").takes_value(true)
                .help("Timeout in milliseconds")
        )
        .arg(
            Arg::with_name("delay").short("d").long("delay").takes_value(true)
                .help("Delay between messages in milliseconds")
        )
        .arg(
            Arg::with_name("lost-connection-delay")
                .long("lost-connection-delay").takes_value(true)
                .help("Lost connection delay in milliseconds")
        )
        .arg(
            Arg::with_name("delay-each-n-msg").long("delay-each-n-msg")
                .takes_value(true)
                .help("number of messages to wait between delays")
        )
        .arg(
            // Every other flag has a long form; `-f` gets `--follow` for consistency.
            Arg::with_name("follow").short("f").long("follow")
                .help("Keep receiving info")
        )
        .arg(
            Arg::with_name("verbose").short("v").long("verbose")
                .help("Verbose")
        )
        .setting(AppSettings::ArgRequiredElseHelp)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment