diff --git a/tcp_relay_client/src/client_config.rs b/tcp_relay_client/src/client_config.rs new file mode 100644 index 0000000..123396a --- /dev/null +++ b/tcp_relay_client/src/client_config.rs @@ -0,0 +1,40 @@ +use clap::Parser; + +static mut ROOT_CERT: Option> = None; + +/// TCP relay client +#[derive(Parser, Debug, Clone)] +#[clap(author, version, about, long_about = None)] +pub struct ClientConfig { + /// Access token + #[clap(short, long)] + pub token: String, + + /// Relay server + #[clap(short, long, default_value = "http://127.0.0.1:8000")] + pub relay_url: String, + + /// Listen address + #[clap(short, long, default_value = "127.0.0.1")] + pub listen_address: String, + + /// Optional root certificate to use for server authentication + #[clap(short = 'c', long)] + pub root_certificate: Option, +} + +impl ClientConfig { + /// Load root certificate + pub fn get_root_certificate(&self) -> Option> { + self.root_certificate.as_ref()?; + + if unsafe { ROOT_CERT.is_none() } { + log::info!("Loading root certificate from disk"); + let cert = self.root_certificate.as_ref().map(|c| std::fs::read(c) + .expect("Failed to read root certificate!")); + unsafe { ROOT_CERT = cert } + } + + unsafe { ROOT_CERT.clone() } + } +} \ No newline at end of file diff --git a/tcp_relay_client/src/lib.rs b/tcp_relay_client/src/lib.rs index dfc2cc9..2efbe10 100644 --- a/tcp_relay_client/src/lib.rs +++ b/tcp_relay_client/src/lib.rs @@ -1 +1,2 @@ +pub mod client_config; pub mod relay_client; \ No newline at end of file diff --git a/tcp_relay_client/src/main.rs b/tcp_relay_client/src/main.rs index 492108a..1030969 100644 --- a/tcp_relay_client/src/main.rs +++ b/tcp_relay_client/src/main.rs @@ -6,38 +6,18 @@ use futures::future::join_all; use reqwest::Certificate; use base::RemoteConfig; +use tcp_relay_client::client_config::ClientConfig; use tcp_relay_client::relay_client::relay_client; -/// TCP relay client -#[derive(Parser, Debug, Clone)] -#[clap(author, version, about, long_about = None)] -pub struct Args { - /// Access token - #[clap(short, long)] - pub token: String, - - /// Relay server - #[clap(short, long, default_value = "http://127.0.0.1:8000")] - pub relay_url: String, - - /// Listen address - #[clap(short, long, default_value = "127.0.0.1")] - pub listen_address: String, - - /// Optional root certificate to use for server authentication - #[clap(short = 'c', long)] - pub root_certificate: Option, -} - -async fn get_server_config(config: &Args, root_cert: &Option>) -> Result> { +async fn get_server_config(config: &ClientConfig) -> Result> { let url = format!("{}/config", config.relay_url); log::info!("Retrieving configuration on {}", url); let mut client = reqwest::Client::builder(); // Specify root certificate, if any was specified in the command line - if let Some(cert) = root_cert { - client = client.add_root_certificate(Certificate::from_pem(cert)?); + if let Some(cert) = config.get_root_certificate() { + client = client.add_root_certificate(Certificate::from_pem(&cert)?); } let client = client.build().expect("Failed to build reqwest client"); @@ -55,17 +35,14 @@ async fn get_server_config(config: &Args, root_cert: &Option>) -> Result } #[tokio::main] -async fn main() -> Result<(), Box> { +async fn main() -> Result<(), Box> { env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); - let args: Args = Args::parse(); + let args: ClientConfig = ClientConfig::parse(); let args = Arc::new(args); - let root_cert = args.root_certificate.as_ref().map(|c| std::fs::read(c) - .expect("Failed to read root certificate!")); - // Get server relay configuration (fetch the list of port to forward) - let conf = get_server_config(&args, &root_cert).await?; + let conf = get_server_config(&args).await?; // Start to listen port let mut handles = vec![]; @@ -77,7 +54,7 @@ async fn main() -> Result<(), Box> { args.relay_url, port.id, urlencoding::encode(&args.token)) .replace("http", "ws"), listen_address, - root_cert.clone(), + args.clone(), )); handles.push(h); } diff --git a/tcp_relay_client/src/relay_client.rs b/tcp_relay_client/src/relay_client.rs index 2a1391b..6875bd2 100644 --- a/tcp_relay_client/src/relay_client.rs +++ b/tcp_relay_client/src/relay_client.rs @@ -8,7 +8,9 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio_tungstenite::tungstenite::Message; -pub async fn relay_client(ws_url: String, listen_address: String, root_cert: Option>) { +use crate::client_config::ClientConfig; + +pub async fn relay_client(ws_url: String, listen_address: String, config: Arc) { log::info!("Start to listen on {}", listen_address); let listener = match TcpListener::bind(&listen_address).await { Ok(l) => l, @@ -22,7 +24,7 @@ pub async fn relay_client(ws_url: String, listen_address: String, root_cert: Opt let (socket, _) = listener.accept().await .expect("Failed to accept new connection!"); - tokio::spawn(relay_connection(ws_url.clone(), socket, root_cert.clone())); + tokio::spawn(relay_connection(ws_url.clone(), socket, config.clone())); } } @@ -30,14 +32,14 @@ pub async fn relay_client(ws_url: String, listen_address: String, root_cert: Opt /// /// WS read => TCP write /// TCP read => WS write -async fn relay_connection(ws_url: String, socket: TcpStream, root_cert: Option>) { +async fn relay_connection(ws_url: String, socket: TcpStream, conf: Arc) { log::debug!("Connecting to {}...", ws_url); let ws_stream = if ws_url.starts_with("wss") { let config = rustls::ClientConfig::builder() .with_safe_defaults(); - let config = match root_cert { + let config = match conf.get_root_certificate() { None => config.with_native_roots(), Some(cert) => { log::debug!("Using custom root certificates"); diff --git a/tcp_relay_server/src/lib.rs b/tcp_relay_server/src/lib.rs index 4e996c7..42b3602 100644 --- a/tcp_relay_server/src/lib.rs +++ b/tcp_relay_server/src/lib.rs @@ -1,2 +1,2 @@ -pub mod args; +pub mod server_config; pub mod relay_ws; \ No newline at end of file diff --git a/tcp_relay_server/src/main.rs b/tcp_relay_server/src/main.rs index ec0cf9b..072935b 100644 --- a/tcp_relay_server/src/main.rs +++ b/tcp_relay_server/src/main.rs @@ -5,18 +5,18 @@ use std::sync::Arc; use actix_web::{App, HttpRequest, HttpResponse, HttpServer, middleware, Responder, web}; use actix_web::web::Data; use clap::Parser; -use rustls::{Certificate, PrivateKey, ServerConfig}; +use rustls::{Certificate, PrivateKey}; use rustls_pemfile::{certs, Item, read_one}; use base::RelayedPort; -use tcp_relay_server::args::ProgramArgs; use tcp_relay_server::relay_ws::relay_ws; +use tcp_relay_server::server_config::ServerConfig; pub async fn hello_route() -> &'static str { "Hello world!" } -pub async fn config_route(req: HttpRequest, data: Data>) -> impl Responder { +pub async fn config_route(req: HttpRequest, data: Data>) -> impl Responder { let token = req.headers().get("Authorization") .map(|t| t.to_str().unwrap_or_default()) .unwrap_or_default() @@ -39,7 +39,7 @@ pub async fn config_route(req: HttpRequest, data: Data>) -> imp async fn main() -> std::io::Result<()> { env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); - let mut args: ProgramArgs = ProgramArgs::parse(); + let mut args: ServerConfig = ServerConfig::parse(); if args.ports.is_empty() { log::error!("No port to forward!"); @@ -73,7 +73,7 @@ async fn main() -> std::io::Result<()> { } }; - let config = ServerConfig::builder() + let config = rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() .with_single_cert(cert_chain, PrivateKey(key)) diff --git a/tcp_relay_server/src/relay_ws.rs b/tcp_relay_server/src/relay_ws.rs index 45bfdbd..b605878 100644 --- a/tcp_relay_server/src/relay_ws.rs +++ b/tcp_relay_server/src/relay_ws.rs @@ -9,7 +9,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::TcpStream; -use crate::args::ProgramArgs; +use crate::server_config::ServerConfig; /// How often heartbeat pings are sent const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); @@ -153,7 +153,7 @@ pub struct WebSocketQuery { pub async fn relay_ws(req: HttpRequest, stream: web::Payload, query: web::Query, - conf: web::Data>) -> Result { + conf: web::Data>) -> Result { if !conf.tokens.contains(&query.token) { log::error!("Rejected WS request from {:?} due to invalid token!", req.peer_addr()); return Ok(HttpResponse::Unauthorized().json("Invalid / missing token!")); diff --git a/tcp_relay_server/src/args.rs b/tcp_relay_server/src/server_config.rs similarity index 97% rename from tcp_relay_server/src/args.rs rename to tcp_relay_server/src/server_config.rs index b1096d3..ae084c3 100644 --- a/tcp_relay_server/src/args.rs +++ b/tcp_relay_server/src/server_config.rs @@ -4,7 +4,7 @@ use clap::Parser; #[derive(Parser, Debug, Clone)] #[clap(author, version, about, long_about = "TCP-over-HTTP server. This program might be configured behind a reverse-proxy.")] -pub struct ProgramArgs { +pub struct ServerConfig { /// Access tokens #[clap(short, long)] pub tokens: Vec,