diff --git a/Cargo.lock b/Cargo.lock index dfdd0e0..9baec6d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1596,6 +1596,7 @@ dependencies = [ "rustls-pemfile", "serde", "tokio", + "webpki", ] [[package]] diff --git a/tcp_relay_client/src/client_config.rs b/tcp_relay_client/src/client_config.rs index 123396a..e09a690 100644 --- a/tcp_relay_client/src/client_config.rs +++ b/tcp_relay_client/src/client_config.rs @@ -8,7 +8,7 @@ static mut ROOT_CERT: Option> = None; pub struct ClientConfig { /// Access token #[clap(short, long)] - pub token: String, + pub token: Option, /// Relay server #[clap(short, long, default_value = "http://127.0.0.1:8000")] @@ -24,6 +24,11 @@ pub struct ClientConfig { } impl ClientConfig { + /// Get client token, returning a dummy token if none was specified + pub fn get_auth_token(&self) -> &str { + self.token.as_deref().unwrap_or("none") + } + /// Load root certificate pub fn get_root_certificate(&self) -> Option> { self.root_certificate.as_ref()?; diff --git a/tcp_relay_client/src/main.rs b/tcp_relay_client/src/main.rs index 1030969..814243a 100644 --- a/tcp_relay_client/src/main.rs +++ b/tcp_relay_client/src/main.rs @@ -23,7 +23,7 @@ async fn get_server_config(config: &ClientConfig) -> Result Result<(), Box> { let h = tokio::spawn(relay_client( format!("{}/ws?id={}&token={}", - args.relay_url, port.id, urlencoding::encode(&args.token)) + args.relay_url, port.id, urlencoding::encode(args.get_auth_token())) .replace("http", "ws"), listen_address, args.clone(), diff --git a/tcp_relay_server/Cargo.toml b/tcp_relay_server/Cargo.toml index f22f418..32c9df7 100644 --- a/tcp_relay_server/Cargo.toml +++ b/tcp_relay_server/Cargo.toml @@ -16,4 +16,5 @@ serde = { version = "1.0.144", features = ["derive"] } tokio = { version = "1", features = ["full"] } futures = "0.3.24" rustls = "0.20.6" -rustls-pemfile = "1.0.1" \ No newline at end of file +rustls-pemfile = "1.0.1" +webpki = "0.22.0" \ No newline at end of file diff --git a/tcp_relay_server/src/lib.rs b/tcp_relay_server/src/lib.rs index 42b3602..44db330 100644 --- a/tcp_relay_server/src/lib.rs +++ b/tcp_relay_server/src/lib.rs @@ -1,2 +1,3 @@ pub mod server_config; -pub mod relay_ws; \ No newline at end of file +pub mod relay_ws; +pub mod tls_cert_client_verifier; \ No newline at end of file diff --git a/tcp_relay_server/src/main.rs b/tcp_relay_server/src/main.rs index 072935b..9622947 100644 --- a/tcp_relay_server/src/main.rs +++ b/tcp_relay_server/src/main.rs @@ -11,20 +11,23 @@ use rustls_pemfile::{certs, Item, read_one}; use base::RelayedPort; use tcp_relay_server::relay_ws::relay_ws; use tcp_relay_server::server_config::ServerConfig; +use tcp_relay_server::tls_cert_client_verifier::CustomCertClientVerifier; pub async fn hello_route() -> &'static str { "Hello world!" } pub async fn config_route(req: HttpRequest, data: Data>) -> impl Responder { - let token = req.headers().get("Authorization") - .map(|t| t.to_str().unwrap_or_default()) - .unwrap_or_default() - .strip_prefix("Bearer ") - .unwrap_or_default(); + if data.has_token_auth() { + let token = req.headers().get("Authorization") + .map(|t| t.to_str().unwrap_or_default()) + .unwrap_or_default() + .strip_prefix("Bearer ") + .unwrap_or_default(); - if !data.tokens.iter().any(|t| t.eq(token)) { - return HttpResponse::Unauthorized().json("Missing / invalid token"); + if !data.tokens.iter().any(|t| t.eq(token)) { + return HttpResponse::Unauthorized().json("Missing / invalid token"); + } } HttpResponse::Ok().json( @@ -41,11 +44,33 @@ async fn main() -> std::io::Result<()> { let mut args: ServerConfig = ServerConfig::parse(); + // Check if no port are to be forwarded if args.ports.is_empty() { log::error!("No port to forward!"); std::process::exit(2); } + // Read tokens from file, if any + if let Some(file) = &args.tokens_file { + std::fs::read_to_string(file) + .expect("Failed to read tokens file!") + .split('\n') + .filter(|l| !l.is_empty()) + .for_each(|t| args.tokens.push(t.to_string())); + } + + if !args.has_auth() { + log::error!("No authentication method specified!"); + std::process::exit(3); + } + + if args.has_tls_client_auth() && !args.has_tls_config() { + log::error!("Cannot provide client auth without TLS configuration!"); + panic!(); + } + + let args = Arc::new(args); + // Load TLS configuration, if any let tls_config = if let (Some(cert), Some(key)) = (&args.tls_cert, &args.tls_key) { @@ -74,31 +99,22 @@ async fn main() -> std::io::Result<()> { }; let config = rustls::ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(cert_chain, PrivateKey(key)) + .with_safe_defaults(); + + let config = match args.has_tls_client_auth() { + true => config.with_client_cert_verifier( + Arc::new(CustomCertClientVerifier::new(args.clone()))), + false => config.with_no_client_auth() + }; + + let config = config.with_single_cert(cert_chain, PrivateKey(key)) .expect("Failed to load TLS certificate!"); Some(config) } else { None }; - // Read tokens from file, if any - if let Some(file) = &args.tokens_file { - std::fs::read_to_string(file) - .expect("Failed to read tokens file!") - .split('\n') - .filter(|l| !l.is_empty()) - .for_each(|t| args.tokens.push(t.to_string())); - } - - if args.tokens.is_empty() { - log::error!("No tokens specified!"); - std::process::exit(3); - } - log::info!("Starting relay on http://{}", args.listen_address); - let args = Arc::new(args); let args_clone = args.clone(); let server = HttpServer::new(move || { App::new() diff --git a/tcp_relay_server/src/relay_ws.rs b/tcp_relay_server/src/relay_ws.rs index b605878..3d2c57c 100644 --- a/tcp_relay_server/src/relay_ws.rs +++ b/tcp_relay_server/src/relay_ws.rs @@ -94,7 +94,6 @@ impl Actor for RelayWS { } log::info!("Exited read loop"); - // TODO : notify context }; tokio::spawn(future); @@ -148,13 +147,14 @@ impl Handler for RelayWS { #[derive(serde::Deserialize)] pub struct WebSocketQuery { id: usize, - token: String, + token: Option, } pub async fn relay_ws(req: HttpRequest, stream: web::Payload, query: web::Query, conf: web::Data>) -> Result { - if !conf.tokens.contains(&query.token) { + if conf.has_token_auth() && + !conf.tokens.iter().any(|t| t == query.token.as_deref().unwrap_or_default()) { log::error!("Rejected WS request from {:?} due to invalid token!", req.peer_addr()); return Ok(HttpResponse::Unauthorized().json("Invalid / missing token!")); } diff --git a/tcp_relay_server/src/server_config.rs b/tcp_relay_server/src/server_config.rs index ae084c3..f235920 100644 --- a/tcp_relay_server/src/server_config.rs +++ b/tcp_relay_server/src/server_config.rs @@ -3,7 +3,7 @@ use clap::Parser; /// TCP relay server #[derive(Parser, Debug, Clone)] #[clap(author, version, about, - long_about = "TCP-over-HTTP server. This program might be configured behind a reverse-proxy.")] +long_about = "TCP-over-HTTP server. This program might be configured behind a reverse-proxy.")] pub struct ServerConfig { /// Access tokens #[clap(short, long)] @@ -37,4 +37,29 @@ pub struct ServerConfig { /// TLS private key. Specify also certificate to use HTTPS/TLS instead of HTTP #[clap(long)] pub tls_key: Option, + + /// Restrict TLS client authentication to certificates signed directly or indirectly by the + /// provided root certificates + /// + /// This option automatically enable TLS client authentication + #[clap(long)] + pub tls_client_auth_root_cert: Option, +} + +impl ServerConfig { + pub fn has_token_auth(&self) -> bool { + !self.tokens.is_empty() + } + + pub fn has_tls_config(&self) -> bool { + self.tls_cert.is_some() && self.tls_key.is_some() + } + + pub fn has_tls_client_auth(&self) -> bool { + self.tls_client_auth_root_cert.is_some() + } + + pub fn has_auth(&self) -> bool { + self.has_token_auth() || self.has_tls_client_auth() + } } \ No newline at end of file diff --git a/tcp_relay_server/src/tls_cert_client_verifier.rs b/tcp_relay_server/src/tls_cert_client_verifier.rs new file mode 100644 index 0000000..a025775 --- /dev/null +++ b/tcp_relay_server/src/tls_cert_client_verifier.rs @@ -0,0 +1,61 @@ +use std::fs::File; +use std::io::BufReader; +use std::sync::Arc; +use std::time::SystemTime; + +use rustls::{Certificate, DistinguishedNames, Error, RootCertStore}; +use rustls::server::{AllowAnyAuthenticatedClient, ClientCertVerified, ClientCertVerifier}; +use rustls_pemfile::certs; + +use crate::server_config::ServerConfig; + +pub struct CustomCertClientVerifier { + upstream_cert_verifier: Box>, +} + +impl CustomCertClientVerifier { + pub fn new(conf: Arc) -> Self { + let cert_path = conf.tls_client_auth_root_cert.as_deref() + .expect("No root certificates for client authentication provided!"); + let cert_file = &mut BufReader::new(File::open(cert_path) + .expect("Failed to read root certificates for client authentication!")); + + let root_certs = certs(cert_file).unwrap() + .into_iter() + .map(Certificate) + .collect::>(); + + if root_certs.is_empty() { + log::error!("No certificates found for client authentication!"); + panic!(); + } + + let mut store = RootCertStore::empty(); + for cert in root_certs { + store.add(&cert).expect("Failed to add certificate to root store"); + } + + Self { + upstream_cert_verifier: Box::new(AllowAnyAuthenticatedClient::new(store)), + } + } +} + +impl ClientCertVerifier for CustomCertClientVerifier { + fn offer_client_auth(&self) -> bool { + true + } + + fn client_auth_mandatory(&self) -> Option { + Some(true) + } + + fn client_auth_root_subjects(&self) -> Option { + Some(vec![]) + } + + fn verify_client_cert(&self, end_entity: &Certificate, intermediates: &[Certificate], now: SystemTime) -> Result { + self.upstream_cert_verifier.verify_client_cert(end_entity, intermediates, now) + } +} +