Refactor project to make it easier to test
This commit is contained in:
		| @@ -5,11 +5,23 @@ use rustls::{Certificate, PrivateKey}; | ||||
| use rustls_pemfile::{read_one, Item}; | ||||
|  | ||||
| /// Parse PEM certificates bytes into a [`rustls::Certificate`] structure | ||||
| /// | ||||
| /// An error is returned if not any certificate could be found | ||||
| pub fn parse_pem_certificates(certs: &[u8]) -> Result<Vec<Certificate>, Box<dyn Error>> { | ||||
|     Ok(rustls_pemfile::certs(&mut Cursor::new(certs))? | ||||
|     let certs = rustls_pemfile::certs(&mut Cursor::new(certs))? | ||||
|         .into_iter() | ||||
|         .map(Certificate) | ||||
|         .collect()) | ||||
|         .collect::<Vec<_>>(); | ||||
|  | ||||
|     if certs.is_empty() { | ||||
|         Err(std::io::Error::new( | ||||
|             ErrorKind::InvalidData, | ||||
|             "Could not find any certificate!", | ||||
|         ))?; | ||||
|         unreachable!(); | ||||
|     } | ||||
|  | ||||
|     Ok(certs) | ||||
| } | ||||
|  | ||||
| /// Parse PEM private key bytes into a [`rustls::PrivateKey`] structure | ||||
| @@ -35,3 +47,41 @@ pub fn parse_pem_private_key(privkey: &[u8]) -> Result<PrivateKey, Box<dyn Error | ||||
|  | ||||
|     Ok(PrivateKey(key)) | ||||
| } | ||||
|  | ||||
| #[cfg(test)] | ||||
| mod test { | ||||
|     use crate::cert_utils::{parse_pem_certificates, parse_pem_private_key}; | ||||
|  | ||||
|     const SAMPLE_CERT: &[u8] = include_bytes!("../samples/TCPTunnelTest.crt"); | ||||
|     const SAMPLE_KEY: &[u8] = include_bytes!("../samples/TCPTunnelTest.key"); | ||||
|  | ||||
|     #[test] | ||||
|     fn parse_valid_cert() { | ||||
|         parse_pem_certificates(SAMPLE_CERT).unwrap(); | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn parse_invalid_cert_1() { | ||||
|         parse_pem_certificates("Random content".as_bytes()).unwrap_err(); | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn parse_invalid_cert_2() { | ||||
|         parse_pem_certificates(SAMPLE_KEY).unwrap_err(); | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn parse_valid_key() { | ||||
|         parse_pem_private_key(SAMPLE_KEY).unwrap(); | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn parse_invalid_key_1() { | ||||
|         parse_pem_private_key("Random content".as_bytes()).unwrap_err(); | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn parse_invalid_key_2() { | ||||
|         parse_pem_private_key(SAMPLE_CERT).unwrap_err(); | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -1,2 +1,95 @@ | ||||
| extern crate core; | ||||
|  | ||||
| use std::error::Error; | ||||
| use std::sync::Arc; | ||||
|  | ||||
| use futures::future::join_all; | ||||
| use reqwest::{Certificate, Identity}; | ||||
|  | ||||
| use base::RemoteConfig; | ||||
|  | ||||
| use crate::client_config::ClientConfig; | ||||
| use crate::relay_client::relay_client; | ||||
|  | ||||
| pub mod client_config; | ||||
| pub mod relay_client; | ||||
| mod relay_client; | ||||
|  | ||||
| /// Get remote server config i.e. get the list of forwarded ports | ||||
| async fn get_server_config(conf: &ClientConfig) -> Result<RemoteConfig, Box<dyn Error>> { | ||||
|     let url = format!("{}/config", conf.relay_url); | ||||
|     log::info!("Retrieving configuration on {}", url); | ||||
|  | ||||
|     let mut client = reqwest::Client::builder(); | ||||
|  | ||||
|     // Specify root certificate, if any was specified in the command line | ||||
|     if let Some(cert) = conf.get_root_certificate() { | ||||
|         client = client.add_root_certificate(Certificate::from_pem(&cert)?); | ||||
|     } | ||||
|  | ||||
|     // Specify client certificate, if any | ||||
|     if let Some(kp) = conf.get_merged_client_keypair() { | ||||
|         let identity = Identity::from_pem(&kp).expect("Failed to load certificates for reqwest!"); | ||||
|         client = client.identity(identity).use_rustls_tls(); | ||||
|     } | ||||
|  | ||||
|     let client = client.build().expect("Failed to build reqwest client"); | ||||
|  | ||||
|     let req = client | ||||
|         .get(url) | ||||
|         .header("Authorization", format!("Bearer {}", conf.get_auth_token())) | ||||
|         .send() | ||||
|         .await?; | ||||
|     if req.status().as_u16() != 200 { | ||||
|         log::error!( | ||||
|             "Could not retrieve configuration! (got status {})", | ||||
|             req.status() | ||||
|         ); | ||||
|         std::process::exit(2); | ||||
|     } | ||||
|  | ||||
|     Ok(req.json::<RemoteConfig>().await?) | ||||
| } | ||||
|  | ||||
| /// Core logic of the application | ||||
| pub async fn run_app(mut args: ClientConfig) -> Result<(), Box<dyn Error>> { | ||||
|     args.load_certificates(); | ||||
|     let args = Arc::new(args); | ||||
|  | ||||
|     // Check arguments coherence | ||||
|     if args.tls_cert.is_some() != args.tls_key.is_some() { | ||||
|         log::error!( | ||||
|             "If you specify one of TLS certificate / key, you must then specify the other!" | ||||
|         ); | ||||
|         panic!(); | ||||
|     } | ||||
|  | ||||
|     if args.get_client_keypair().is_some() { | ||||
|         log::info!("Using client-side authentication"); | ||||
|     } | ||||
|  | ||||
|     // Get server relay configuration (fetch the list of port to forward) | ||||
|     let remote_conf = get_server_config(&args).await?; | ||||
|  | ||||
|     // Start to listen port | ||||
|     let mut handles = vec![]; | ||||
|     for port in remote_conf { | ||||
|         let listen_address = format!("{}:{}", args.listen_address, port.port); | ||||
|  | ||||
|         let h = tokio::spawn(relay_client( | ||||
|             format!( | ||||
|                 "{}/ws?id={}&token={}", | ||||
|                 args.relay_url, | ||||
|                 port.id, | ||||
|                 urlencoding::encode(args.get_auth_token()) | ||||
|             ) | ||||
|             .replace("http", "ws"), | ||||
|             listen_address, | ||||
|             args.clone(), | ||||
|         )); | ||||
|         handles.push(h); | ||||
|     } | ||||
|  | ||||
|     join_all(handles).await; | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|   | ||||
| @@ -1,97 +1,11 @@ | ||||
| extern crate core; | ||||
|  | ||||
| use std::error::Error; | ||||
| use std::sync::Arc; | ||||
|  | ||||
| use clap::Parser; | ||||
| use futures::future::join_all; | ||||
| use reqwest::{Certificate, Identity}; | ||||
|  | ||||
| use base::RemoteConfig; | ||||
| use tcp_relay_client::client_config::ClientConfig; | ||||
| use tcp_relay_client::relay_client::relay_client; | ||||
|  | ||||
| async fn get_server_config(config: &ClientConfig) -> Result<RemoteConfig, Box<dyn Error>> { | ||||
|     let url = format!("{}/config", config.relay_url); | ||||
|     log::info!("Retrieving configuration on {}", url); | ||||
|  | ||||
|     let mut client = reqwest::Client::builder(); | ||||
|  | ||||
|     // Specify root certificate, if any was specified in the command line | ||||
|     if let Some(cert) = config.get_root_certificate() { | ||||
|         client = client.add_root_certificate(Certificate::from_pem(&cert)?); | ||||
|     } | ||||
|  | ||||
|     // Specify client certificate, if any | ||||
|     if let Some(kp) = config.get_merged_client_keypair() { | ||||
|         let identity = Identity::from_pem(&kp).expect("Failed to load certificates for reqwest!"); | ||||
|         client = client.identity(identity).use_rustls_tls(); | ||||
|     } | ||||
|  | ||||
|     let client = client.build().expect("Failed to build reqwest client"); | ||||
|  | ||||
|     let req = client | ||||
|         .get(url) | ||||
|         .header( | ||||
|             "Authorization", | ||||
|             format!("Bearer {}", config.get_auth_token()), | ||||
|         ) | ||||
|         .send() | ||||
|         .await?; | ||||
|     if req.status().as_u16() != 200 { | ||||
|         log::error!( | ||||
|             "Could not retrieve configuration! (got status {})", | ||||
|             req.status() | ||||
|         ); | ||||
|         std::process::exit(2); | ||||
|     } | ||||
|  | ||||
|     Ok(req.json::<RemoteConfig>().await?) | ||||
| } | ||||
|  | ||||
| #[tokio::main] | ||||
| async fn main() -> Result<(), Box<dyn Error>> { | ||||
|     env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); | ||||
|  | ||||
|     let mut args: ClientConfig = ClientConfig::parse(); | ||||
|     args.load_certificates(); | ||||
|     let args = Arc::new(args); | ||||
|  | ||||
|     // Check arguments coherence | ||||
|     if args.tls_cert.is_some() != args.tls_key.is_some() { | ||||
|         log::error!( | ||||
|             "If you specify one of TLS certificate / key, you must then specify the other!" | ||||
|         ); | ||||
|         panic!(); | ||||
|     } | ||||
|  | ||||
|     if args.get_client_keypair().is_some() { | ||||
|         log::info!("Using client-side authentication"); | ||||
|     } | ||||
|  | ||||
|     // Get server relay configuration (fetch the list of port to forward) | ||||
|     let conf = get_server_config(&args).await?; | ||||
|  | ||||
|     // Start to listen port | ||||
|     let mut handles = vec![]; | ||||
|     for port in conf { | ||||
|         let listen_address = format!("{}:{}", args.listen_address, port.port); | ||||
|  | ||||
|         let h = tokio::spawn(relay_client( | ||||
|             format!( | ||||
|                 "{}/ws?id={}&token={}", | ||||
|                 args.relay_url, | ||||
|                 port.id, | ||||
|                 urlencoding::encode(args.get_auth_token()) | ||||
|             ) | ||||
|             .replace("http", "ws"), | ||||
|             listen_address, | ||||
|             args.clone(), | ||||
|         )); | ||||
|         handles.push(h); | ||||
|     } | ||||
|  | ||||
|     join_all(handles).await; | ||||
|  | ||||
|     Ok(()) | ||||
|     tcp_relay_client::run_app(ClientConfig::parse()).await | ||||
| } | ||||
|   | ||||
| @@ -1,3 +1,125 @@ | ||||
| pub mod relay_ws; | ||||
| use std::sync::Arc; | ||||
|  | ||||
| use actix_web::web::Data; | ||||
| use actix_web::{middleware, web, App, HttpRequest, HttpResponse, HttpServer, Responder}; | ||||
|  | ||||
| use base::{cert_utils, RelayedPort}; | ||||
|  | ||||
| use crate::relay_ws::relay_ws; | ||||
| use crate::server_config::ServerConfig; | ||||
| use crate::tls_cert_client_verifier::CustomCertClientVerifier; | ||||
|  | ||||
| mod relay_ws; | ||||
| pub mod server_config; | ||||
| pub mod tls_cert_client_verifier; | ||||
| mod tls_cert_client_verifier; | ||||
|  | ||||
| pub async fn hello_route() -> &'static str { | ||||
|     "Hello world!" | ||||
| } | ||||
|  | ||||
| pub async fn config_route(req: HttpRequest, data: Data<Arc<ServerConfig>>) -> impl Responder { | ||||
|     if data.has_token_auth() { | ||||
|         let token = req | ||||
|             .headers() | ||||
|             .get("Authorization") | ||||
|             .map(|t| t.to_str().unwrap_or_default()) | ||||
|             .unwrap_or_default() | ||||
|             .strip_prefix("Bearer ") | ||||
|             .unwrap_or_default(); | ||||
|  | ||||
|         if !data.tokens.iter().any(|t| t.eq(token)) { | ||||
|             return HttpResponse::Unauthorized().json("Missing / invalid token"); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     HttpResponse::Ok().json( | ||||
|         data.ports | ||||
|             .iter() | ||||
|             .enumerate() | ||||
|             .map(|(id, port)| RelayedPort { | ||||
|                 id, | ||||
|                 port: port + data.increment_ports, | ||||
|             }) | ||||
|             .collect::<Vec<_>>(), | ||||
|     ) | ||||
| } | ||||
|  | ||||
| pub async fn run_app(mut config: ServerConfig) -> std::io::Result<()> { | ||||
|     // Check if no port are to be forwarded | ||||
|     if config.ports.is_empty() { | ||||
|         log::error!("No port to forward!"); | ||||
|         std::process::exit(2); | ||||
|     } | ||||
|  | ||||
|     // Read tokens from file, if any | ||||
|     if let Some(file) = &config.tokens_file { | ||||
|         std::fs::read_to_string(file) | ||||
|             .expect("Failed to read tokens file!") | ||||
|             .split('\n') | ||||
|             .filter(|l| !l.is_empty()) | ||||
|             .for_each(|t| config.tokens.push(t.to_string())); | ||||
|     } | ||||
|  | ||||
|     if !config.has_auth() { | ||||
|         log::error!("No authentication method specified!"); | ||||
|         std::process::exit(3); | ||||
|     } | ||||
|  | ||||
|     if config.has_tls_client_auth() && !config.has_tls_config() { | ||||
|         log::error!("Cannot provide client auth without TLS configuration!"); | ||||
|         panic!(); | ||||
|     } | ||||
|  | ||||
|     let args = Arc::new(config); | ||||
|  | ||||
|     // Load TLS configuration, if any | ||||
|     let tls_config = if let (Some(cert), Some(key)) = (&args.tls_cert, &args.tls_key) { | ||||
|         // Load TLS certificate & private key | ||||
|         let cert_file = std::fs::read(cert).expect("Failed to read certificate file"); | ||||
|         let key_file = std::fs::read(key).expect("Failed to read server private key"); | ||||
|  | ||||
|         // Get certificates chain | ||||
|         let cert_chain = | ||||
|             cert_utils::parse_pem_certificates(&cert_file).expect("Failed to extract certificates"); | ||||
|  | ||||
|         // Get private key | ||||
|         let key = | ||||
|             cert_utils::parse_pem_private_key(&key_file).expect("Failed to extract private key!"); | ||||
|  | ||||
|         let config = rustls::ServerConfig::builder().with_safe_defaults(); | ||||
|  | ||||
|         let config = match args.has_tls_client_auth() { | ||||
|             true => config | ||||
|                 .with_client_cert_verifier(Arc::new(CustomCertClientVerifier::new(args.clone()))), | ||||
|             false => config.with_no_client_auth(), | ||||
|         }; | ||||
|  | ||||
|         let config = config | ||||
|             .with_single_cert(cert_chain, key) | ||||
|             .expect("Failed to load TLS certificate!"); | ||||
|  | ||||
|         Some(config) | ||||
|     } else { | ||||
|         None | ||||
|     }; | ||||
|  | ||||
|     log::info!("Starting relay on http://{}", args.listen_address); | ||||
|  | ||||
|     let args_clone = args.clone(); | ||||
|     let server = HttpServer::new(move || { | ||||
|         App::new() | ||||
|             .wrap(middleware::Logger::default()) | ||||
|             .app_data(Data::new(args_clone.clone())) | ||||
|             .route("/", web::get().to(hello_route)) | ||||
|             .route("/config", web::get().to(config_route)) | ||||
|             .route("/ws", web::get().to(relay_ws)) | ||||
|     }); | ||||
|  | ||||
|     if let Some(tls_conf) = tls_config { | ||||
|         server.bind_rustls(&args.listen_address, tls_conf)? | ||||
|     } else { | ||||
|         server.bind(&args.listen_address)? | ||||
|     } | ||||
|     .run() | ||||
|     .await | ||||
| } | ||||
|   | ||||
| @@ -1,126 +1,10 @@ | ||||
| use std::sync::Arc; | ||||
|  | ||||
| use actix_web::web::Data; | ||||
| use actix_web::{middleware, web, App, HttpRequest, HttpResponse, HttpServer, Responder}; | ||||
| use clap::Parser; | ||||
|  | ||||
| use base::{cert_utils, RelayedPort}; | ||||
| use tcp_relay_server::relay_ws::relay_ws; | ||||
| use tcp_relay_server::server_config::ServerConfig; | ||||
| use tcp_relay_server::tls_cert_client_verifier::CustomCertClientVerifier; | ||||
|  | ||||
| pub async fn hello_route() -> &'static str { | ||||
|     "Hello world!" | ||||
| } | ||||
|  | ||||
| pub async fn config_route(req: HttpRequest, data: Data<Arc<ServerConfig>>) -> impl Responder { | ||||
|     if data.has_token_auth() { | ||||
|         let token = req | ||||
|             .headers() | ||||
|             .get("Authorization") | ||||
|             .map(|t| t.to_str().unwrap_or_default()) | ||||
|             .unwrap_or_default() | ||||
|             .strip_prefix("Bearer ") | ||||
|             .unwrap_or_default(); | ||||
|  | ||||
|         if !data.tokens.iter().any(|t| t.eq(token)) { | ||||
|             return HttpResponse::Unauthorized().json("Missing / invalid token"); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     HttpResponse::Ok().json( | ||||
|         data.ports | ||||
|             .iter() | ||||
|             .enumerate() | ||||
|             .map(|(id, port)| RelayedPort { | ||||
|                 id, | ||||
|                 port: port + data.increment_ports, | ||||
|             }) | ||||
|             .collect::<Vec<_>>(), | ||||
|     ) | ||||
| } | ||||
|  | ||||
| #[actix_web::main] | ||||
| async fn main() -> std::io::Result<()> { | ||||
|     env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); | ||||
|  | ||||
|     let mut args: ServerConfig = ServerConfig::parse(); | ||||
|  | ||||
|     // Check if no port are to be forwarded | ||||
|     if args.ports.is_empty() { | ||||
|         log::error!("No port to forward!"); | ||||
|         std::process::exit(2); | ||||
|     } | ||||
|  | ||||
|     // Read tokens from file, if any | ||||
|     if let Some(file) = &args.tokens_file { | ||||
|         std::fs::read_to_string(file) | ||||
|             .expect("Failed to read tokens file!") | ||||
|             .split('\n') | ||||
|             .filter(|l| !l.is_empty()) | ||||
|             .for_each(|t| args.tokens.push(t.to_string())); | ||||
|     } | ||||
|  | ||||
|     if !args.has_auth() { | ||||
|         log::error!("No authentication method specified!"); | ||||
|         std::process::exit(3); | ||||
|     } | ||||
|  | ||||
|     if args.has_tls_client_auth() && !args.has_tls_config() { | ||||
|         log::error!("Cannot provide client auth without TLS configuration!"); | ||||
|         panic!(); | ||||
|     } | ||||
|  | ||||
|     let args = Arc::new(args); | ||||
|  | ||||
|     // Load TLS configuration, if any | ||||
|     let tls_config = if let (Some(cert), Some(key)) = (&args.tls_cert, &args.tls_key) { | ||||
|         // Load TLS certificate & private key | ||||
|         let cert_file = std::fs::read(cert).expect("Failed to read certificate file"); | ||||
|         let key_file = std::fs::read(key).expect("Failed to read server private key"); | ||||
|  | ||||
|         // Get certificates chain | ||||
|         let cert_chain = | ||||
|             cert_utils::parse_pem_certificates(&cert_file).expect("Failed to extract certificates"); | ||||
|  | ||||
|         // Get private key | ||||
|         let key = | ||||
|             cert_utils::parse_pem_private_key(&key_file).expect("Failed to extract private key!"); | ||||
|  | ||||
|         let config = rustls::ServerConfig::builder().with_safe_defaults(); | ||||
|  | ||||
|         let config = match args.has_tls_client_auth() { | ||||
|             true => config | ||||
|                 .with_client_cert_verifier(Arc::new(CustomCertClientVerifier::new(args.clone()))), | ||||
|             false => config.with_no_client_auth(), | ||||
|         }; | ||||
|  | ||||
|         let config = config | ||||
|             .with_single_cert(cert_chain, key) | ||||
|             .expect("Failed to load TLS certificate!"); | ||||
|  | ||||
|         Some(config) | ||||
|     } else { | ||||
|         None | ||||
|     }; | ||||
|  | ||||
|     log::info!("Starting relay on http://{}", args.listen_address); | ||||
|  | ||||
|     let args_clone = args.clone(); | ||||
|     let server = HttpServer::new(move || { | ||||
|         App::new() | ||||
|             .wrap(middleware::Logger::default()) | ||||
|             .app_data(Data::new(args_clone.clone())) | ||||
|             .route("/", web::get().to(hello_route)) | ||||
|             .route("/config", web::get().to(config_route)) | ||||
|             .route("/ws", web::get().to(relay_ws)) | ||||
|     }); | ||||
|  | ||||
|     if let Some(tls_conf) = tls_config { | ||||
|         server.bind_rustls(&args.listen_address, tls_conf)? | ||||
|     } else { | ||||
|         server.bind(&args.listen_address)? | ||||
|     } | ||||
|     .run() | ||||
|     .await | ||||
|     tcp_relay_server::run_app(ServerConfig::parse()).await | ||||
| } | ||||
|   | ||||
| @@ -71,3 +71,14 @@ impl ServerConfig { | ||||
|         self.has_token_auth() || self.has_tls_client_auth() | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[cfg(test)] | ||||
| mod test { | ||||
|     use crate::server_config::ServerConfig; | ||||
|  | ||||
|     #[test] | ||||
|     fn verify_cli() { | ||||
|         use clap::CommandFactory; | ||||
|         ServerConfig::command().debug_assert() | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user