From 27b50d2333ea44494e5d82c4a0626af593e0fa1e Mon Sep 17 00:00:00 2001 From: Pierre Hubert Date: Thu, 1 Sep 2022 09:08:45 +0200 Subject: [PATCH] Refactor project to make it easier to test --- base/src/cert_utils.rs | 54 ++++++++++- tcp_relay_client/src/lib.rs | 95 ++++++++++++++++++- tcp_relay_client/src/main.rs | 88 +----------------- tcp_relay_server/src/lib.rs | 126 +++++++++++++++++++++++++- tcp_relay_server/src/main.rs | 118 +----------------------- tcp_relay_server/src/server_config.rs | 11 +++ 6 files changed, 283 insertions(+), 209 deletions(-) diff --git a/base/src/cert_utils.rs b/base/src/cert_utils.rs index acf6efc..620de2d 100644 --- a/base/src/cert_utils.rs +++ b/base/src/cert_utils.rs @@ -5,11 +5,23 @@ use rustls::{Certificate, PrivateKey}; use rustls_pemfile::{read_one, Item}; /// Parse PEM certificates bytes into a [`rustls::Certificate`] structure +/// +/// An error is returned if not any certificate could be found pub fn parse_pem_certificates(certs: &[u8]) -> Result, Box> { - Ok(rustls_pemfile::certs(&mut Cursor::new(certs))? + let certs = rustls_pemfile::certs(&mut Cursor::new(certs))? .into_iter() .map(Certificate) - .collect()) + .collect::>(); + + if certs.is_empty() { + Err(std::io::Error::new( + ErrorKind::InvalidData, + "Could not find any certificate!", + ))?; + unreachable!(); + } + + Ok(certs) } /// Parse PEM private key bytes into a [`rustls::PrivateKey`] structure @@ -35,3 +47,41 @@ pub fn parse_pem_private_key(privkey: &[u8]) -> Result Result> { + let url = format!("{}/config", conf.relay_url); + log::info!("Retrieving configuration on {}", url); + + let mut client = reqwest::Client::builder(); + + // Specify root certificate, if any was specified in the command line + if let Some(cert) = conf.get_root_certificate() { + client = client.add_root_certificate(Certificate::from_pem(&cert)?); + } + + // Specify client certificate, if any + if let Some(kp) = conf.get_merged_client_keypair() { + let identity = Identity::from_pem(&kp).expect("Failed to load certificates for reqwest!"); + client = client.identity(identity).use_rustls_tls(); + } + + let client = client.build().expect("Failed to build reqwest client"); + + let req = client + .get(url) + .header("Authorization", format!("Bearer {}", conf.get_auth_token())) + .send() + .await?; + if req.status().as_u16() != 200 { + log::error!( + "Could not retrieve configuration! (got status {})", + req.status() + ); + std::process::exit(2); + } + + Ok(req.json::().await?) +} + +/// Core logic of the application +pub async fn run_app(mut args: ClientConfig) -> Result<(), Box> { + args.load_certificates(); + let args = Arc::new(args); + + // Check arguments coherence + if args.tls_cert.is_some() != args.tls_key.is_some() { + log::error!( + "If you specify one of TLS certificate / key, you must then specify the other!" + ); + panic!(); + } + + if args.get_client_keypair().is_some() { + log::info!("Using client-side authentication"); + } + + // Get server relay configuration (fetch the list of port to forward) + let remote_conf = get_server_config(&args).await?; + + // Start to listen port + let mut handles = vec![]; + for port in remote_conf { + let listen_address = format!("{}:{}", args.listen_address, port.port); + + let h = tokio::spawn(relay_client( + format!( + "{}/ws?id={}&token={}", + args.relay_url, + port.id, + urlencoding::encode(args.get_auth_token()) + ) + .replace("http", "ws"), + listen_address, + args.clone(), + )); + handles.push(h); + } + + join_all(handles).await; + + Ok(()) +} diff --git a/tcp_relay_client/src/main.rs b/tcp_relay_client/src/main.rs index 2df316e..9234c5e 100644 --- a/tcp_relay_client/src/main.rs +++ b/tcp_relay_client/src/main.rs @@ -1,97 +1,11 @@ -extern crate core; - use std::error::Error; -use std::sync::Arc; use clap::Parser; -use futures::future::join_all; -use reqwest::{Certificate, Identity}; -use base::RemoteConfig; use tcp_relay_client::client_config::ClientConfig; -use tcp_relay_client::relay_client::relay_client; - -async fn get_server_config(config: &ClientConfig) -> Result> { - let url = format!("{}/config", config.relay_url); - log::info!("Retrieving configuration on {}", url); - - let mut client = reqwest::Client::builder(); - - // Specify root certificate, if any was specified in the command line - if let Some(cert) = config.get_root_certificate() { - client = client.add_root_certificate(Certificate::from_pem(&cert)?); - } - - // Specify client certificate, if any - if let Some(kp) = config.get_merged_client_keypair() { - let identity = Identity::from_pem(&kp).expect("Failed to load certificates for reqwest!"); - client = client.identity(identity).use_rustls_tls(); - } - - let client = client.build().expect("Failed to build reqwest client"); - - let req = client - .get(url) - .header( - "Authorization", - format!("Bearer {}", config.get_auth_token()), - ) - .send() - .await?; - if req.status().as_u16() != 200 { - log::error!( - "Could not retrieve configuration! (got status {})", - req.status() - ); - std::process::exit(2); - } - - Ok(req.json::().await?) -} #[tokio::main] async fn main() -> Result<(), Box> { env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); - - let mut args: ClientConfig = ClientConfig::parse(); - args.load_certificates(); - let args = Arc::new(args); - - // Check arguments coherence - if args.tls_cert.is_some() != args.tls_key.is_some() { - log::error!( - "If you specify one of TLS certificate / key, you must then specify the other!" - ); - panic!(); - } - - if args.get_client_keypair().is_some() { - log::info!("Using client-side authentication"); - } - - // Get server relay configuration (fetch the list of port to forward) - let conf = get_server_config(&args).await?; - - // Start to listen port - let mut handles = vec![]; - for port in conf { - let listen_address = format!("{}:{}", args.listen_address, port.port); - - let h = tokio::spawn(relay_client( - format!( - "{}/ws?id={}&token={}", - args.relay_url, - port.id, - urlencoding::encode(args.get_auth_token()) - ) - .replace("http", "ws"), - listen_address, - args.clone(), - )); - handles.push(h); - } - - join_all(handles).await; - - Ok(()) + tcp_relay_client::run_app(ClientConfig::parse()).await } diff --git a/tcp_relay_server/src/lib.rs b/tcp_relay_server/src/lib.rs index a042584..7fa1859 100644 --- a/tcp_relay_server/src/lib.rs +++ b/tcp_relay_server/src/lib.rs @@ -1,3 +1,125 @@ -pub mod relay_ws; +use std::sync::Arc; + +use actix_web::web::Data; +use actix_web::{middleware, web, App, HttpRequest, HttpResponse, HttpServer, Responder}; + +use base::{cert_utils, RelayedPort}; + +use crate::relay_ws::relay_ws; +use crate::server_config::ServerConfig; +use crate::tls_cert_client_verifier::CustomCertClientVerifier; + +mod relay_ws; pub mod server_config; -pub mod tls_cert_client_verifier; +mod tls_cert_client_verifier; + +pub async fn hello_route() -> &'static str { + "Hello world!" +} + +pub async fn config_route(req: HttpRequest, data: Data>) -> impl Responder { + if data.has_token_auth() { + let token = req + .headers() + .get("Authorization") + .map(|t| t.to_str().unwrap_or_default()) + .unwrap_or_default() + .strip_prefix("Bearer ") + .unwrap_or_default(); + + if !data.tokens.iter().any(|t| t.eq(token)) { + return HttpResponse::Unauthorized().json("Missing / invalid token"); + } + } + + HttpResponse::Ok().json( + data.ports + .iter() + .enumerate() + .map(|(id, port)| RelayedPort { + id, + port: port + data.increment_ports, + }) + .collect::>(), + ) +} + +pub async fn run_app(mut config: ServerConfig) -> std::io::Result<()> { + // Check if no port are to be forwarded + if config.ports.is_empty() { + log::error!("No port to forward!"); + std::process::exit(2); + } + + // Read tokens from file, if any + if let Some(file) = &config.tokens_file { + std::fs::read_to_string(file) + .expect("Failed to read tokens file!") + .split('\n') + .filter(|l| !l.is_empty()) + .for_each(|t| config.tokens.push(t.to_string())); + } + + if !config.has_auth() { + log::error!("No authentication method specified!"); + std::process::exit(3); + } + + if config.has_tls_client_auth() && !config.has_tls_config() { + log::error!("Cannot provide client auth without TLS configuration!"); + panic!(); + } + + let args = Arc::new(config); + + // Load TLS configuration, if any + let tls_config = if let (Some(cert), Some(key)) = (&args.tls_cert, &args.tls_key) { + // Load TLS certificate & private key + let cert_file = std::fs::read(cert).expect("Failed to read certificate file"); + let key_file = std::fs::read(key).expect("Failed to read server private key"); + + // Get certificates chain + let cert_chain = + cert_utils::parse_pem_certificates(&cert_file).expect("Failed to extract certificates"); + + // Get private key + let key = + cert_utils::parse_pem_private_key(&key_file).expect("Failed to extract private key!"); + + let config = rustls::ServerConfig::builder().with_safe_defaults(); + + let config = match args.has_tls_client_auth() { + true => config + .with_client_cert_verifier(Arc::new(CustomCertClientVerifier::new(args.clone()))), + false => config.with_no_client_auth(), + }; + + let config = config + .with_single_cert(cert_chain, key) + .expect("Failed to load TLS certificate!"); + + Some(config) + } else { + None + }; + + log::info!("Starting relay on http://{}", args.listen_address); + + let args_clone = args.clone(); + let server = HttpServer::new(move || { + App::new() + .wrap(middleware::Logger::default()) + .app_data(Data::new(args_clone.clone())) + .route("/", web::get().to(hello_route)) + .route("/config", web::get().to(config_route)) + .route("/ws", web::get().to(relay_ws)) + }); + + if let Some(tls_conf) = tls_config { + server.bind_rustls(&args.listen_address, tls_conf)? + } else { + server.bind(&args.listen_address)? + } + .run() + .await +} diff --git a/tcp_relay_server/src/main.rs b/tcp_relay_server/src/main.rs index 4b4d63a..3636c5e 100644 --- a/tcp_relay_server/src/main.rs +++ b/tcp_relay_server/src/main.rs @@ -1,126 +1,10 @@ -use std::sync::Arc; - -use actix_web::web::Data; -use actix_web::{middleware, web, App, HttpRequest, HttpResponse, HttpServer, Responder}; use clap::Parser; -use base::{cert_utils, RelayedPort}; -use tcp_relay_server::relay_ws::relay_ws; use tcp_relay_server::server_config::ServerConfig; -use tcp_relay_server::tls_cert_client_verifier::CustomCertClientVerifier; - -pub async fn hello_route() -> &'static str { - "Hello world!" -} - -pub async fn config_route(req: HttpRequest, data: Data>) -> impl Responder { - if data.has_token_auth() { - let token = req - .headers() - .get("Authorization") - .map(|t| t.to_str().unwrap_or_default()) - .unwrap_or_default() - .strip_prefix("Bearer ") - .unwrap_or_default(); - - if !data.tokens.iter().any(|t| t.eq(token)) { - return HttpResponse::Unauthorized().json("Missing / invalid token"); - } - } - - HttpResponse::Ok().json( - data.ports - .iter() - .enumerate() - .map(|(id, port)| RelayedPort { - id, - port: port + data.increment_ports, - }) - .collect::>(), - ) -} #[actix_web::main] async fn main() -> std::io::Result<()> { env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); - let mut args: ServerConfig = ServerConfig::parse(); - - // Check if no port are to be forwarded - if args.ports.is_empty() { - log::error!("No port to forward!"); - std::process::exit(2); - } - - // Read tokens from file, if any - if let Some(file) = &args.tokens_file { - std::fs::read_to_string(file) - .expect("Failed to read tokens file!") - .split('\n') - .filter(|l| !l.is_empty()) - .for_each(|t| args.tokens.push(t.to_string())); - } - - if !args.has_auth() { - log::error!("No authentication method specified!"); - std::process::exit(3); - } - - if args.has_tls_client_auth() && !args.has_tls_config() { - log::error!("Cannot provide client auth without TLS configuration!"); - panic!(); - } - - let args = Arc::new(args); - - // Load TLS configuration, if any - let tls_config = if let (Some(cert), Some(key)) = (&args.tls_cert, &args.tls_key) { - // Load TLS certificate & private key - let cert_file = std::fs::read(cert).expect("Failed to read certificate file"); - let key_file = std::fs::read(key).expect("Failed to read server private key"); - - // Get certificates chain - let cert_chain = - cert_utils::parse_pem_certificates(&cert_file).expect("Failed to extract certificates"); - - // Get private key - let key = - cert_utils::parse_pem_private_key(&key_file).expect("Failed to extract private key!"); - - let config = rustls::ServerConfig::builder().with_safe_defaults(); - - let config = match args.has_tls_client_auth() { - true => config - .with_client_cert_verifier(Arc::new(CustomCertClientVerifier::new(args.clone()))), - false => config.with_no_client_auth(), - }; - - let config = config - .with_single_cert(cert_chain, key) - .expect("Failed to load TLS certificate!"); - - Some(config) - } else { - None - }; - - log::info!("Starting relay on http://{}", args.listen_address); - - let args_clone = args.clone(); - let server = HttpServer::new(move || { - App::new() - .wrap(middleware::Logger::default()) - .app_data(Data::new(args_clone.clone())) - .route("/", web::get().to(hello_route)) - .route("/config", web::get().to(config_route)) - .route("/ws", web::get().to(relay_ws)) - }); - - if let Some(tls_conf) = tls_config { - server.bind_rustls(&args.listen_address, tls_conf)? - } else { - server.bind(&args.listen_address)? - } - .run() - .await + tcp_relay_server::run_app(ServerConfig::parse()).await } diff --git a/tcp_relay_server/src/server_config.rs b/tcp_relay_server/src/server_config.rs index 70e655c..82b5201 100644 --- a/tcp_relay_server/src/server_config.rs +++ b/tcp_relay_server/src/server_config.rs @@ -71,3 +71,14 @@ impl ServerConfig { self.has_token_auth() || self.has_tls_client_auth() } } + +#[cfg(test)] +mod test { + use crate::server_config::ServerConfig; + + #[test] + fn verify_cli() { + use clap::CommandFactory; + ServerConfig::command().debug_assert() + } +}