diff --git a/Cargo.lock b/Cargo.lock index a4d754d..facd298 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1603,6 +1603,7 @@ dependencies = [ "hyper-rustls", "log", "pem", + "rand", "reqwest", "rustls", "rustls-pemfile", diff --git a/Cargo.toml b/Cargo.toml index 4a1925e..a789475 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,4 +24,7 @@ urlencoding = "2.1.0" hyper-rustls = { version = "0.23.0", features = ["rustls-native-certs"] } bytes = "1.2.1" rustls-pemfile = "1.0.1" -rustls = "0.20.6" \ No newline at end of file +rustls = "0.20.6" + +[dev-dependencies] +rand = "0.8.5" \ No newline at end of file diff --git a/src/base/cert_utils.rs b/src/base/cert_utils.rs index 620de2d..9dac6c9 100644 --- a/src/base/cert_utils.rs +++ b/src/base/cert_utils.rs @@ -50,10 +50,10 @@ pub fn parse_pem_private_key(privkey: &[u8]) -> Result, - #[clap(skip)] - _root_certificate_cache: Option>, - /// TLS certificate for TLS authentication. #[clap(long)] pub tls_cert: Option, - #[clap(skip)] - _tls_cert_cache: Option>, - /// TLS key for TLS authentication. #[clap(long)] pub tls_key: Option, + #[clap(skip)] + pub _keys_cache: KeysCache, +} + +#[derive(Parser, Debug, Clone, Default)] +pub struct KeysCache { + #[clap(skip)] + _root_certificate_cache: Option>, + #[clap(skip)] + _tls_cert_cache: Option>, #[clap(skip)] _tls_key_cache: Option>, } @@ -42,20 +46,20 @@ pub struct ClientConfig { impl ClientConfig { /// Load certificates and put them in cache pub fn load_certificates(&mut self) { - self._root_certificate_cache = self - .root_certificate - .as_ref() - .map(|c| std::fs::read(c).expect("Failed to read root certificate!")); - - self._tls_cert_cache = self - .tls_cert - .as_ref() - .map(|c| std::fs::read(c).expect("Failed to read client certificate!")); - - self._tls_key_cache = self - .tls_key - .as_ref() - .map(|c| std::fs::read(c).expect("Failed to read client key!")); + self._keys_cache = KeysCache { + _root_certificate_cache: self + .root_certificate + .as_ref() + .map(|c| std::fs::read(c).expect("Failed to read root certificate!")), + _tls_cert_cache: self + .tls_cert + .as_ref() + .map(|c| std::fs::read(c).expect("Failed to read client certificate!")), + _tls_key_cache: self + .tls_key + .as_ref() + .map(|c| std::fs::read(c).expect("Failed to read client key!")), + }; } /// Get client token, returning a dummy token if none was specified @@ -65,12 +69,15 @@ impl ClientConfig { /// Get root certificate content pub fn get_root_certificate(&self) -> Option> { - self._root_certificate_cache.clone() + self._keys_cache._root_certificate_cache.clone() } /// Get client certificate & key pair, if available pub fn get_client_keypair(&self) -> Option<(&Vec, &Vec)> { - if let (Some(cert), Some(key)) = (&self._tls_cert_cache, &self._tls_key_cache) { + if let (Some(cert), Some(key)) = ( + &self._keys_cache._tls_cert_cache, + &self._keys_cache._tls_key_cache, + ) { Some((cert, key)) } else { None @@ -90,7 +97,7 @@ impl ClientConfig { #[cfg(test)] mod test { - use crate::client_config::ClientConfig; + use crate::tcp_relay_client::client_config::ClientConfig; #[test] fn verify_cli() { diff --git a/src/tcp_relay_server/server_config.rs b/src/tcp_relay_server/server_config.rs index 838ca40..9e3e94d 100644 --- a/src/tcp_relay_server/server_config.rs +++ b/src/tcp_relay_server/server_config.rs @@ -74,7 +74,7 @@ impl ServerConfig { #[cfg(test)] mod test { - use crate::server_config::ServerConfig; + use crate::tcp_relay_server::server_config::ServerConfig; #[test] fn verify_cli() { diff --git a/src/test/dummy_tcp_sockets.rs b/src/test/dummy_tcp_sockets.rs new file mode 100644 index 0000000..348a5cd --- /dev/null +++ b/src/test/dummy_tcp_sockets.rs @@ -0,0 +1,250 @@ +use std::time::Duration; + +use rand::Rng; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::time; + +use crate::test::LOCALHOST; + +pub struct DummyTCPServer(TcpListener); + +impl DummyTCPServer { + pub async fn start(port: u16) -> Self { + let addr = format!("{}:{}", LOCALHOST, port); + println!("[DUMMY TCP] Listen on {}", addr); + let listener = TcpListener::bind(addr) + .await + .expect("Failed to bind dummy TCP listener!"); + Self(listener) + } + + /// Receive chunk of data from following connection + pub async fn read_next_connection(&self) -> Vec { + let (mut conn, _addr) = self + .0 + .accept() + .await + .expect("Could not open next connection!"); + + let mut buff = Vec::with_capacity(100); + conn.read_to_end(&mut buff).await.unwrap(); + + buff + } + + /// Receive chunk of data from following connection + pub async fn read_then_write_next_connection(&self, to_send: &[u8]) -> Vec { + let (mut conn, _addr) = self + .0 + .accept() + .await + .expect("Could not open next connection!"); + + let mut buff: [u8; 100] = [0; 100]; + let size = conn.read(&mut buff).await.unwrap(); + + conn.write_all(to_send).await.unwrap(); + + buff[0..size].to_vec() + } + + /// Receive chunk of data from following connection + pub async fn write_next_connection(&self, to_send: &[u8]) { + let (mut conn, _addr) = self + .0 + .accept() + .await + .expect("Could not open next connection!"); + + conn.write_all(to_send).await.unwrap() + } + + /// Perform complex exchange: receive numbers from client and respond with their square + pub async fn next_conn_square_operations(&self) { + let (mut conn, _addr) = self + .0 + .accept() + .await + .expect("Could not open next connection!"); + + let mut buff: [u8; 100] = [0; 100]; + loop { + let size = conn.read(&mut buff).await.unwrap(); + if size == 0 { + break; + } + + let content = String::from_utf8_lossy(&buff[0..size]) + .to_string() + .parse::() + .unwrap(); + + conn.write_all((content * content).to_string().as_bytes()) + .await + .unwrap(); + } + } +} + +pub async fn dummy_tcp_client_read_conn(port: u16) -> Vec { + let mut socket = TcpStream::connect(format!("127.0.0.1:{}", port)) + .await + .expect("Failed to connect to dummy TCP server!"); + + let mut buff = Vec::with_capacity(100); + socket.read_to_end(&mut buff).await.unwrap(); + + buff +} + +pub async fn dummy_tcp_client_write_then_read_conn(port: u16, data: &[u8]) -> Vec { + let mut socket = TcpStream::connect(format!("127.0.0.1:{}", port)) + .await + .expect("Failed to connect to dummy TCP server!"); + + socket.write_all(data).await.unwrap(); + + let mut buff: [u8; 100] = [0; 100]; + let size = socket.read(&mut buff).await.unwrap(); + + buff[0..size].to_vec() +} + +pub async fn dummy_tcp_client_write_conn(port: u16, data: &[u8]) { + let mut socket = TcpStream::connect(format!("127.0.0.1:{}", port)) + .await + .expect("Failed to connect to dummy TCP server!"); + + socket.write_all(data).await.unwrap() +} + +pub async fn dummy_tcp_client_square_root_requests(port: u16, num_exchanges: usize) { + let mut socket = TcpStream::connect(format!("127.0.0.1:{}", port)) + .await + .expect("Failed to connect to dummy TCP server!"); + + let mut rng = rand::thread_rng(); + let mut buff: [u8; 100] = [0; 100]; + + for _ in 0..num_exchanges { + let num = rng.gen::() % 100; + socket.write_all(num.to_string().as_bytes()).await.unwrap(); + + let size = socket.read(&mut buff).await.unwrap(); + + if size == 0 { + panic!("Got empty response!"); + } + + let got = String::from_utf8_lossy(&buff[0..size]) + .to_string() + .parse::() + .unwrap(); + println!("{} * {} = {} (based on server response)", num, num, got); + assert_eq!((num * num) as i64, got); + } +} + +/// Check whether a given port is open or not +pub async fn is_port_open(port: u16) -> bool { + match TcpStream::connect(("127.0.0.1", port)).await { + Ok(_) => true, + Err(_) => false, + } +} + +/// Wait for a port to become available +pub async fn wait_for_port(port: u16) { + for _ in 0..50 { + if is_port_open(port).await { + return; + } + time::sleep(Duration::from_millis(10)).await; + } + + eprintln!("Port {} did not open in time!", port); + std::process::exit(2); +} + +mod test { + use crate::test::dummy_tcp_sockets::{ + dummy_tcp_client_read_conn, dummy_tcp_client_square_root_requests, + dummy_tcp_client_write_conn, dummy_tcp_client_write_then_read_conn, DummyTCPServer, + }; + use crate::test::{get_port_number, PortsAllocation}; + + fn port(index: u16) -> u16 { + get_port_number(PortsAllocation::DummyTCPServer, index) + } + + #[tokio::test] + async fn socket_read_from_server() { + const DATA: &[u8] = "Hello world!!!".as_bytes(); + + let listener = DummyTCPServer::start(port(0)).await; + let handle = tokio::spawn(async move { + listener.write_next_connection(DATA).await; + }); + let data = dummy_tcp_client_read_conn(port(0)).await; + assert_eq!(data, DATA); + + handle.await.unwrap(); + } + + #[tokio::test] + async fn socket_write_to_server() { + const DATA: &[u8] = "Hello world 2".as_bytes(); + + let listener = DummyTCPServer::start(port(1)).await; + tokio::spawn(async move { + dummy_tcp_client_write_conn(port(1), DATA).await; + }); + let data = listener.read_next_connection().await; + assert_eq!(data, DATA); + } + + #[tokio::test] + async fn socket_read_and_write_to_server() { + const DATA_1: &[u8] = "Hello world 3a".as_bytes(); + const DATA_2: &[u8] = "Hello world 3b".as_bytes(); + + let listener = DummyTCPServer::start(port(2)).await; + let handle = tokio::spawn(async move { + println!("client handle"); + let data = dummy_tcp_client_write_then_read_conn(port(2), DATA_1).await; + assert_eq!(data, DATA_2); + }); + let h2 = tokio::spawn(async move { + println!("server handle"); + let data = listener.read_then_write_next_connection(DATA_2).await; + assert_eq!(data, DATA_1); + }); + + handle.await.unwrap(); + h2.await.unwrap(); + } + + #[tokio::test] + async fn socket_dummy_root_calculator() { + let listener = DummyTCPServer::start(port(3)).await; + let handle = tokio::spawn(async move { + listener.next_conn_square_operations().await; + }); + let data = dummy_tcp_client_write_then_read_conn(port(3), "5".as_bytes()).await; + assert_eq!(data, "25".as_bytes()); + + handle.await.unwrap(); + } + + #[tokio::test] + async fn socket_dummy_root_calculator_multiple() { + let listener = DummyTCPServer::start(port(4)).await; + let handle = tokio::spawn(async move { + listener.next_conn_square_operations().await; + }); + dummy_tcp_client_square_root_requests(port(4), 10).await; + + handle.await.unwrap(); + } +} diff --git a/src/test/mod.rs b/src/test/mod.rs new file mode 100644 index 0000000..cc0f23b --- /dev/null +++ b/src/test/mod.rs @@ -0,0 +1,15 @@ +#[non_exhaustive] +enum PortsAllocation { + DummyTCPServer, + ValidWithTokenAuth, +} + +fn get_port_number(alloc: PortsAllocation, index: u16) -> u16 { + 2100 + 20 * (alloc as u16) + index +} + +const LOCALHOST: &str = "127.0.0.1"; + +mod dummy_tcp_sockets; + +mod valid_with_token_auth; diff --git a/src/test/valid_with_token_auth.rs b/src/test/valid_with_token_auth.rs new file mode 100644 index 0000000..da3aba6 --- /dev/null +++ b/src/test/valid_with_token_auth.rs @@ -0,0 +1,61 @@ +use tokio::task; + +use crate::tcp_relay_client::client_config::ClientConfig; +use crate::tcp_relay_server::server_config::ServerConfig; +use crate::test::dummy_tcp_sockets::{ + dummy_tcp_client_square_root_requests, dummy_tcp_client_write_then_read_conn, wait_for_port, + DummyTCPServer, +}; +use crate::test::{get_port_number, PortsAllocation, LOCALHOST}; + +const VALID_TOKEN: &str = "AvalidTOKEN"; + +const DATA_1: &[u8] = "DATA1".as_bytes(); +const DATA_2: &[u8] = "DATA2".as_bytes(); + +fn port(index: u16) -> u16 { + get_port_number(PortsAllocation::ValidWithTokenAuth, index) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 5)] +async fn valid_with_token_auth() { + let _ = env_logger::builder().is_test(true).try_init(); + + tokio::spawn(async move { + // Start internal service + let local_server = DummyTCPServer::start(port(1)).await; + local_server.next_conn_square_operations().await; + }); + + let local_set = task::LocalSet::new(); + local_set + .run_until(async move { + wait_for_port(port(1)).await; + + // Start server relay + task::spawn_local(crate::tcp_relay_server::run_app(ServerConfig { + tokens: vec![VALID_TOKEN.to_string()], + tokens_file: None, + ports: vec![port(1)], + upstream_server: "127.0.0.1".to_string(), + listen_address: format!("127.0.0.1:{}", port(0)), + increment_ports: 1, + tls_cert: None, + tls_key: None, + tls_client_auth_root_cert: None, + tls_revocation_list: None, + })); + wait_for_port(port(0)).await; + + // Start client relay + task::spawn(crate::tcp_relay_client::run_app(ClientConfig { + token: Some(VALID_TOKEN.to_string()), + relay_url: format!("http://{}:{}", LOCALHOST, port(0)), + listen_address: LOCALHOST.to_string(), + root_certificate: None, + ..Default::default() + })); + wait_for_port(port(2)).await; + }) + .await; +}