diff --git a/.gitignore b/.gitignore index 82eaf26..0cab0fc 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ target .idea *.crt *.key +pki diff --git a/Cargo.lock b/Cargo.lock index 995b79c..dfdd0e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -837,6 +837,21 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d87c48c02e0dc5e3b849a2041db3029fd066650f8f717c07bf8ed78ccb895cac" +dependencies = [ + "http", + "hyper", + "log", + "rustls", + "rustls-native-certs", + "tokio", + "tokio-rustls", +] + [[package]] name = "hyper-tls" version = "0.5.0" @@ -1279,6 +1294,7 @@ dependencies = [ "http", "http-body", "hyper", + "hyper-rustls", "hyper-tls", "ipnet", "js-sys", @@ -1288,16 +1304,20 @@ dependencies = [ "native-tls", "percent-encoding", "pin-project-lite", + "rustls", + "rustls-pemfile", "serde", "serde_json", "serde_urlencoded", "tokio", "tokio-native-tls", + "tokio-rustls", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", "web-sys", + "webpki-roots", "winreg", ] @@ -1337,6 +1357,18 @@ dependencies = [ "webpki", ] +[[package]] +name = "rustls-native-certs" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0167bac7a9f490495f3c33013e7722b53cb087ecbe082fb0c6387c96f634ea50" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pemfile" version = "1.0.1" @@ -1537,8 +1569,11 @@ dependencies = [ "clap", "env_logger", "futures", + "hyper-rustls", "log", "reqwest", + "rustls", + "rustls-pemfile", "tokio", "tokio-tungstenite", "urlencoding", @@ -1706,8 +1741,12 @@ checksum = "f714dd15bead90401d77e04243611caec13726c2408afd5b31901dfcdcb3b181" dependencies = [ "futures-util", "log", + "rustls", + "rustls-native-certs", "tokio", + "tokio-rustls", "tungstenite", + "webpki", ] [[package]] @@ -1770,10 +1809,12 @@ dependencies = [ "httparse", "log", "rand", + "rustls", "sha-1", "thiserror", "url", "utf-8", + "webpki", ] [[package]] diff --git a/tcp_relay_client/Cargo.toml b/tcp_relay_client/Cargo.toml index a35978d..cc87933 100644 --- a/tcp_relay_client/Cargo.toml +++ b/tcp_relay_client/Cargo.toml @@ -8,8 +8,11 @@ base = { path = "../base" } clap = { version = "3.2.18", features = ["derive", "env"] } log = "0.4.17" env_logger = "0.9.0" -reqwest = { version = "0.11", features = ["json"] } +reqwest = { version = "0.11", features = ["json", "rustls-tls"] } tokio = { version = "1", features = ["full"] } futures = "0.3.24" -tokio-tungstenite = "0.17.2" -urlencoding = "2.1.0" \ No newline at end of file +tokio-tungstenite = { version = "0.17.2", features = ["__rustls-tls", "rustls-tls-native-roots"] } +urlencoding = "2.1.0" +rustls = { version = "0.20.6" } +hyper-rustls = { version = "0.23.0", features = ["rustls-native-certs"] } +rustls-pemfile = { version = "1.0.1" } \ No newline at end of file diff --git a/tcp_relay_client/src/main.rs b/tcp_relay_client/src/main.rs index 430c92b..492108a 100644 --- a/tcp_relay_client/src/main.rs +++ b/tcp_relay_client/src/main.rs @@ -1,7 +1,9 @@ +use std::error::Error; use std::sync::Arc; use clap::Parser; use futures::future::join_all; +use reqwest::Certificate; use base::RemoteConfig; use tcp_relay_client::relay_client::relay_client; @@ -21,6 +23,35 @@ pub struct Args { /// Listen address #[clap(short, long, default_value = "127.0.0.1")] pub listen_address: String, + + /// Optional root certificate to use for server authentication + #[clap(short = 'c', long)] + pub root_certificate: Option, +} + +async fn get_server_config(config: &Args, root_cert: &Option>) -> Result> { + let url = format!("{}/config", config.relay_url); + log::info!("Retrieving configuration on {}", url); + + let mut client = reqwest::Client::builder(); + + // Specify root certificate, if any was specified in the command line + if let Some(cert) = root_cert { + client = client.add_root_certificate(Certificate::from_pem(cert)?); + } + + let client = client.build().expect("Failed to build reqwest client"); + + let req = client.get(url) + .header("Authorization", format!("Bearer {}", config.token)) + .send() + .await?; + if req.status().as_u16() != 200 { + log::error!("Could not retrieve configuration! (got status {})", req.status()); + std::process::exit(2); + } + + Ok(req.json::().await?) } #[tokio::main] @@ -30,19 +61,11 @@ async fn main() -> Result<(), Box> { let args: Args = Args::parse(); let args = Arc::new(args); + let root_cert = args.root_certificate.as_ref().map(|c| std::fs::read(c) + .expect("Failed to read root certificate!")); + // Get server relay configuration (fetch the list of port to forward) - let url = format!("{}/config", args.relay_url); - log::info!("Retrieving configuration on {}", url); - let req = reqwest::Client::new().get(url) - .header("Authorization", format!("Bearer {}", args.token)) - .send() - .await?; - if req.status().as_u16() != 200 { - log::error!("Could not retrieve configuration! (got status {})", req.status()); - std::process::exit(2); - } - let conf = req.json::() - .await?; + let conf = get_server_config(&args, &root_cert).await?; // Start to listen port let mut handles = vec![]; @@ -54,6 +77,7 @@ async fn main() -> Result<(), Box> { args.relay_url, port.id, urlencoding::encode(&args.token)) .replace("http", "ws"), listen_address, + root_cert.clone(), )); handles.push(h); } diff --git a/tcp_relay_client/src/relay_client.rs b/tcp_relay_client/src/relay_client.rs index 2bd27e6..2a1391b 100644 --- a/tcp_relay_client/src/relay_client.rs +++ b/tcp_relay_client/src/relay_client.rs @@ -1,9 +1,14 @@ +use std::io::Cursor; +use std::sync::Arc; + use futures::{SinkExt, StreamExt}; +use hyper_rustls::ConfigBuilderExt; +use rustls::{Certificate, RootCertStore}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio_tungstenite::tungstenite::Message; -pub async fn relay_client(ws_url: String, listen_address: String) { +pub async fn relay_client(ws_url: String, listen_address: String, root_cert: Option>) { log::info!("Start to listen on {}", listen_address); let listener = match TcpListener::bind(&listen_address).await { Ok(l) => l, @@ -17,7 +22,7 @@ pub async fn relay_client(ws_url: String, listen_address: String) { let (socket, _) = listener.accept().await .expect("Failed to accept new connection!"); - tokio::spawn(relay_connection(ws_url.clone(), socket)); + tokio::spawn(relay_connection(ws_url.clone(), socket, root_cert.clone())); } } @@ -25,10 +30,44 @@ pub async fn relay_client(ws_url: String, listen_address: String) { /// /// WS read => TCP write /// TCP read => WS write -async fn relay_connection(ws_url: String, socket: TcpStream) { +async fn relay_connection(ws_url: String, socket: TcpStream, root_cert: Option>) { log::debug!("Connecting to {}...", ws_url); - let (ws_stream, _) = tokio_tungstenite::connect_async(ws_url) - .await.expect("Failed to connect to server relay!"); + + let ws_stream = if ws_url.starts_with("wss") { + let config = rustls::ClientConfig::builder() + .with_safe_defaults(); + + let config = match root_cert { + None => config.with_native_roots(), + Some(cert) => { + log::debug!("Using custom root certificates"); + let mut store = RootCertStore::empty(); + rustls_pemfile::certs(&mut Cursor::new(cert)) + .expect("Failed to parse root certificates!") + .into_iter() + .map(Certificate) + .for_each(|c| store.add(&c).expect("Failed to add certificate to chain!")); + + config.with_root_certificates(store) + } + }; + + let config = config.with_no_client_auth(); + let connector = tokio_tungstenite::Connector::Rustls(Arc::new(config)); + + let (ws_stream, _) = tokio_tungstenite::connect_async_tls_with_config( + ws_url, + None, + Some(connector)) + .await.expect("Failed to connect to server relay!"); + + ws_stream + } else { + let (ws_stream, _) = tokio_tungstenite::connect_async(ws_url) + .await.expect("Failed to connect to server relay!"); + + ws_stream + }; let (mut tcp_read, mut tcp_write) = socket.into_split();