diff --git a/src/base/err_utils.rs b/src/base/err_utils.rs index d9c10a7..1bb1915 100644 --- a/src/base/err_utils.rs +++ b/src/base/err_utils.rs @@ -1,7 +1,12 @@ -use std::error::Error; +use std::fmt::Display; use std::io::ErrorKind; /// Encapsulate errors in [`std::io::Error`] with a message -pub fn encpasulate_error(e: E, msg: &str) -> std::io::Error { +pub fn encpasulate_error(e: E, msg: &str) -> std::io::Error { std::io::Error::new(ErrorKind::Other, format!("{}: {}", msg, e)) } + +/// Create a new [`std::io::Error`] +pub fn new_err(msg: &str) -> std::io::Error { + std::io::Error::new(ErrorKind::Other, msg.to_string()) +} diff --git a/src/tcp_relay_server/mod.rs b/src/tcp_relay_server/mod.rs index 8905c72..39234dc 100644 --- a/src/tcp_relay_server/mod.rs +++ b/src/tcp_relay_server/mod.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use actix_web::web::Data; use actix_web::{middleware, web, App, HttpRequest, HttpResponse, HttpServer, Responder}; -use crate::base::err_utils::encpasulate_error; +use crate::base::err_utils::{encpasulate_error, new_err}; use crate::base::{cert_utils, RelayedPort}; use crate::tcp_relay_server::relay_ws::relay_ws; use crate::tcp_relay_server::server_config::ServerConfig; @@ -47,8 +47,7 @@ pub async fn config_route(req: HttpRequest, data: Data>) -> im pub async fn run_app(mut config: ServerConfig) -> std::io::Result<()> { // Check if no port are to be forwarded if config.ports.is_empty() { - log::error!("No port to forward!"); - std::process::exit(2); + return Err(new_err("No port to forward!")); } // Read tokens from file, if any @@ -61,13 +60,17 @@ pub async fn run_app(mut config: ServerConfig) -> std::io::Result<()> { } if !config.has_auth() { - log::error!("No authentication method specified!"); - std::process::exit(3); + return Err(new_err("No authentication method specified!")); + } + + if config.tls_cert.is_some() != config.tls_key.is_some() { + return Err(new_err("Incomplete server TLS configuration!")); } if config.has_tls_client_auth() && !config.has_tls_config() { - log::error!("Cannot provide client auth without TLS configuration!"); - panic!(); + return Err(new_err( + "Cannot provide client auth without TLS configuration!", + )); } let args = Arc::new(config); @@ -75,16 +78,18 @@ pub async fn run_app(mut config: ServerConfig) -> std::io::Result<()> { // Load TLS configuration, if any let tls_config = if let (Some(cert), Some(key)) = (&args.tls_cert, &args.tls_key) { // Load TLS certificate & private key - let cert_file = std::fs::read(cert).expect("Failed to read certificate file"); - let key_file = std::fs::read(key).expect("Failed to read server private key"); + let cert_file = std::fs::read(cert) + .map_err(|e| encpasulate_error(e, "Failed to read certificate file"))?; + let key_file = std::fs::read(key) + .map_err(|e| encpasulate_error(e, "Failed to read server private key"))?; // Get certificates chain - let cert_chain = - cert_utils::parse_pem_certificates(&cert_file).expect("Failed to extract certificates"); + let cert_chain = cert_utils::parse_pem_certificates(&cert_file) + .map_err(|e| encpasulate_error(e, "Failed to extract certificates"))?; // Get private key - let key = - cert_utils::parse_pem_private_key(&key_file).expect("Failed to extract private key!"); + let key = cert_utils::parse_pem_private_key(&key_file) + .map_err(|e| encpasulate_error(e, "Failed to extract private key!"))?; let config = rustls::ServerConfig::builder().with_safe_defaults(); diff --git a/src/test/dummy_tcp_sockets.rs b/src/test/dummy_tcp_sockets.rs index 69cfd36..93b807d 100644 --- a/src/test/dummy_tcp_sockets.rs +++ b/src/test/dummy_tcp_sockets.rs @@ -130,10 +130,10 @@ pub async fn dummy_tcp_client_read_conn(port: u16) -> Vec { .await .expect("Failed to connect to dummy TCP server!"); - let mut buff = Vec::with_capacity(100); - socket.read_to_end(&mut buff).await.unwrap(); + let mut buff: [u8; 100] = [0; 100]; + let size = socket.read(&mut buff).await.unwrap(); - buff + buff[0..size].to_vec() } pub async fn dummy_tcp_client_write_then_read_conn(port: u16, data: &[u8]) -> Vec { diff --git a/src/test/invalid_token_file.rs b/src/test/invalid_token_file.rs index b6f3564..de5846f 100644 --- a/src/test/invalid_token_file.rs +++ b/src/test/invalid_token_file.rs @@ -4,7 +4,7 @@ use crate::test::{get_port_number, PortsAllocation}; const INVALID_TOKEN: &str = "/tmp/a/token/file/that/does/not/exists"; fn port(index: u16) -> u16 { - get_port_number(PortsAllocation::InvalidTokenFile, index) + get_port_number(PortsAllocation::TestsWithoutPortOpened, index) } #[tokio::test(flavor = "multi_thread", worker_threads = 5)] diff --git a/src/test/mod.rs b/src/test/mod.rs index c0f132d..37d8331 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -1,11 +1,11 @@ #[non_exhaustive] enum PortsAllocation { + TestsWithoutPortOpened, DummyTCPServer, ValidWithTokenAuth, InvalidWithTokenAuth, ValidWithMultipleTokenAuth, ValidWithTokenFile, - InvalidTokenFile, ClientTryTLSWhileThereIsNoTLS, ValidTokenWithCustomIncrement, ValidWithTokenAuthMultiplePorts, @@ -15,7 +15,7 @@ enum PortsAllocation { } fn get_port_number(alloc: PortsAllocation, index: u16) -> u16 { - 2100 + 20 * (alloc as u16) + index + 30000 + 20 * (alloc as u16) + index } const LOCALHOST_IP: &str = "127.0.0.1"; @@ -27,6 +27,10 @@ mod test_files_utils; mod client_try_tls_while_there_is_no_tls; mod invalid_token_file; mod invalid_with_token_auth; +mod server_invalid_tls_config_invalid_cert; +mod server_invalid_tls_config_invalid_key; +mod server_invalid_tls_config_invalid_paths; +mod server_invalid_tls_config_missing_key; mod valid_token_with_custom_increment; mod valid_with_multiple_token_auth; mod valid_with_token_auth; diff --git a/src/test/server_invalid_tls_config_invalid_cert.rs b/src/test/server_invalid_tls_config_invalid_cert.rs new file mode 100644 index 0000000..f33cc8c --- /dev/null +++ b/src/test/server_invalid_tls_config_invalid_cert.rs @@ -0,0 +1,31 @@ +use crate::tcp_relay_server::server_config::ServerConfig; +use crate::test::pki::Pki; +use crate::test::{get_port_number, PortsAllocation}; + +const TOKEN: &str = "mytok"; + +fn port(index: u16) -> u16 { + get_port_number(PortsAllocation::TestsWithoutPortOpened, index) +} + +#[tokio::test] +async fn test() { + let _ = env_logger::builder().is_test(true).try_init(); + + let pki = Pki::load(); + + crate::tcp_relay_server::run_app(ServerConfig { + tokens: vec![TOKEN.to_string()], + tokens_file: None, + ports: vec![port(1)], + upstream_server: "127.0.0.1".to_string(), + listen_address: format!("127.0.0.1:{}", port(0)), + increment_ports: 1, + tls_cert: Some(pki.root_ca_crl.file_path()), + tls_key: Some(pki.localhost_key.file_path()), + tls_client_auth_root_cert: None, + tls_revocation_list: None, + }) + .await + .unwrap_err(); +} diff --git a/src/test/server_invalid_tls_config_invalid_key.rs b/src/test/server_invalid_tls_config_invalid_key.rs new file mode 100644 index 0000000..6bd0788 --- /dev/null +++ b/src/test/server_invalid_tls_config_invalid_key.rs @@ -0,0 +1,31 @@ +use crate::tcp_relay_server::server_config::ServerConfig; +use crate::test::pki::Pki; +use crate::test::{get_port_number, PortsAllocation}; + +const TOKEN: &str = "mytok"; + +fn port(index: u16) -> u16 { + get_port_number(PortsAllocation::TestsWithoutPortOpened, index) +} + +#[tokio::test] +async fn test() { + let _ = env_logger::builder().is_test(true).try_init(); + + let pki = Pki::load(); + + crate::tcp_relay_server::run_app(ServerConfig { + tokens: vec![TOKEN.to_string()], + tokens_file: None, + ports: vec![port(1)], + upstream_server: "127.0.0.1".to_string(), + listen_address: format!("127.0.0.1:{}", port(0)), + increment_ports: 1, + tls_cert: Some(pki.root_ca_crt.file_path()), + tls_key: Some(pki.root_ca_crt.file_path()), + tls_client_auth_root_cert: None, + tls_revocation_list: None, + }) + .await + .unwrap_err(); +} diff --git a/src/test/server_invalid_tls_config_invalid_paths.rs b/src/test/server_invalid_tls_config_invalid_paths.rs new file mode 100644 index 0000000..fe9922b --- /dev/null +++ b/src/test/server_invalid_tls_config_invalid_paths.rs @@ -0,0 +1,53 @@ +use crate::tcp_relay_server::server_config::ServerConfig; +use crate::test::pki::Pki; +use crate::test::{get_port_number, PortsAllocation}; + +const TOKEN: &str = "mytok"; + +fn port(index: u16) -> u16 { + get_port_number(PortsAllocation::TestsWithoutPortOpened, index) +} + +#[tokio::test] +async fn invalid_key_path() { + let _ = env_logger::builder().is_test(true).try_init(); + + let pki = Pki::load(); + + crate::tcp_relay_server::run_app(ServerConfig { + tokens: vec![TOKEN.to_string()], + tokens_file: None, + ports: vec![port(1)], + upstream_server: "127.0.0.1".to_string(), + listen_address: format!("127.0.0.1:{}", port(0)), + increment_ports: 1, + tls_cert: Some(pki.localhost_crt.file_path()), + tls_key: Some("/bad/path/to/key/file".to_string()), + tls_client_auth_root_cert: None, + tls_revocation_list: None, + }) + .await + .unwrap_err(); +} + +#[tokio::test] +async fn invalid_cert_path() { + let _ = env_logger::builder().is_test(true).try_init(); + + let pki = Pki::load(); + + crate::tcp_relay_server::run_app(ServerConfig { + tokens: vec![TOKEN.to_string()], + tokens_file: None, + ports: vec![port(1)], + upstream_server: "127.0.0.1".to_string(), + listen_address: format!("127.0.0.1:{}", port(0)), + increment_ports: 1, + tls_cert: Some("/bad/path/to/key/file".to_string()), + tls_key: Some(pki.localhost_key.file_path()), + tls_client_auth_root_cert: None, + tls_revocation_list: None, + }) + .await + .unwrap_err(); +} diff --git a/src/test/server_invalid_tls_config_missing_key.rs b/src/test/server_invalid_tls_config_missing_key.rs new file mode 100644 index 0000000..709a3d0 --- /dev/null +++ b/src/test/server_invalid_tls_config_missing_key.rs @@ -0,0 +1,31 @@ +use crate::tcp_relay_server::server_config::ServerConfig; +use crate::test::pki::Pki; +use crate::test::{get_port_number, PortsAllocation}; + +const TOKEN: &str = "mytok"; + +fn port(index: u16) -> u16 { + get_port_number(PortsAllocation::TestsWithoutPortOpened, index) +} + +#[tokio::test(flavor = "multi_thread")] +async fn test() { + let _ = env_logger::builder().is_test(true).try_init(); + + let pki = Pki::load(); + + crate::tcp_relay_server::run_app(ServerConfig { + tokens: vec![TOKEN.to_string()], + tokens_file: None, + ports: vec![port(1)], + upstream_server: "127.0.0.1".to_string(), + listen_address: format!("127.0.0.1:{}", port(0)), + increment_ports: 1, + tls_cert: Some(pki.root_ca_crt.file_path()), + tls_key: None, + tls_client_auth_root_cert: None, + tls_revocation_list: None, + }) + .await + .unwrap_err(); +}