Optimize root certificate management on client side

This commit is contained in:
Pierre HUBERT 2022-08-31 11:21:23 +02:00
parent 723ed5e390
commit 1b95b10553
8 changed files with 64 additions and 44 deletions

View File

@ -0,0 +1,40 @@
use clap::Parser;
static mut ROOT_CERT: Option<Vec<u8>> = None;
/// TCP relay client
#[derive(Parser, Debug, Clone)]
#[clap(author, version, about, long_about = None)]
pub struct ClientConfig {
/// Access token
#[clap(short, long)]
pub token: String,
/// Relay server
#[clap(short, long, default_value = "http://127.0.0.1:8000")]
pub relay_url: String,
/// Listen address
#[clap(short, long, default_value = "127.0.0.1")]
pub listen_address: String,
/// Optional root certificate to use for server authentication
#[clap(short = 'c', long)]
pub root_certificate: Option<String>,
}
impl ClientConfig {
/// Load root certificate
pub fn get_root_certificate(&self) -> Option<Vec<u8>> {
self.root_certificate.as_ref()?;
if unsafe { ROOT_CERT.is_none() } {
log::info!("Loading root certificate from disk");
let cert = self.root_certificate.as_ref().map(|c| std::fs::read(c)
.expect("Failed to read root certificate!"));
unsafe { ROOT_CERT = cert }
}
unsafe { ROOT_CERT.clone() }
}
}

View File

@ -1 +1,2 @@
pub mod client_config;
pub mod relay_client; pub mod relay_client;

View File

@ -6,38 +6,18 @@ use futures::future::join_all;
use reqwest::Certificate; use reqwest::Certificate;
use base::RemoteConfig; use base::RemoteConfig;
use tcp_relay_client::client_config::ClientConfig;
use tcp_relay_client::relay_client::relay_client; use tcp_relay_client::relay_client::relay_client;
/// TCP relay client async fn get_server_config(config: &ClientConfig) -> Result<RemoteConfig, Box<dyn Error>> {
#[derive(Parser, Debug, Clone)]
#[clap(author, version, about, long_about = None)]
pub struct Args {
/// Access token
#[clap(short, long)]
pub token: String,
/// Relay server
#[clap(short, long, default_value = "http://127.0.0.1:8000")]
pub relay_url: String,
/// Listen address
#[clap(short, long, default_value = "127.0.0.1")]
pub listen_address: String,
/// Optional root certificate to use for server authentication
#[clap(short = 'c', long)]
pub root_certificate: Option<String>,
}
async fn get_server_config(config: &Args, root_cert: &Option<Vec<u8>>) -> Result<RemoteConfig, Box<dyn Error>> {
let url = format!("{}/config", config.relay_url); let url = format!("{}/config", config.relay_url);
log::info!("Retrieving configuration on {}", url); log::info!("Retrieving configuration on {}", url);
let mut client = reqwest::Client::builder(); let mut client = reqwest::Client::builder();
// Specify root certificate, if any was specified in the command line // Specify root certificate, if any was specified in the command line
if let Some(cert) = root_cert { if let Some(cert) = config.get_root_certificate() {
client = client.add_root_certificate(Certificate::from_pem(cert)?); client = client.add_root_certificate(Certificate::from_pem(&cert)?);
} }
let client = client.build().expect("Failed to build reqwest client"); let client = client.build().expect("Failed to build reqwest client");
@ -55,17 +35,14 @@ async fn get_server_config(config: &Args, root_cert: &Option<Vec<u8>>) -> Result
} }
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> Result<(), Box<dyn Error>> {
env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); env_logger::init_from_env(env_logger::Env::new().default_filter_or("info"));
let args: Args = Args::parse(); let args: ClientConfig = ClientConfig::parse();
let args = Arc::new(args); let args = Arc::new(args);
let root_cert = args.root_certificate.as_ref().map(|c| std::fs::read(c)
.expect("Failed to read root certificate!"));
// Get server relay configuration (fetch the list of port to forward) // Get server relay configuration (fetch the list of port to forward)
let conf = get_server_config(&args, &root_cert).await?; let conf = get_server_config(&args).await?;
// Start to listen port // Start to listen port
let mut handles = vec![]; let mut handles = vec![];
@ -77,7 +54,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
args.relay_url, port.id, urlencoding::encode(&args.token)) args.relay_url, port.id, urlencoding::encode(&args.token))
.replace("http", "ws"), .replace("http", "ws"),
listen_address, listen_address,
root_cert.clone(), args.clone(),
)); ));
handles.push(h); handles.push(h);
} }

View File

@ -8,7 +8,9 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::tungstenite::Message;
pub async fn relay_client(ws_url: String, listen_address: String, root_cert: Option<Vec<u8>>) { use crate::client_config::ClientConfig;
pub async fn relay_client(ws_url: String, listen_address: String, config: Arc<ClientConfig>) {
log::info!("Start to listen on {}", listen_address); log::info!("Start to listen on {}", listen_address);
let listener = match TcpListener::bind(&listen_address).await { let listener = match TcpListener::bind(&listen_address).await {
Ok(l) => l, Ok(l) => l,
@ -22,7 +24,7 @@ pub async fn relay_client(ws_url: String, listen_address: String, root_cert: Opt
let (socket, _) = listener.accept().await let (socket, _) = listener.accept().await
.expect("Failed to accept new connection!"); .expect("Failed to accept new connection!");
tokio::spawn(relay_connection(ws_url.clone(), socket, root_cert.clone())); tokio::spawn(relay_connection(ws_url.clone(), socket, config.clone()));
} }
} }
@ -30,14 +32,14 @@ pub async fn relay_client(ws_url: String, listen_address: String, root_cert: Opt
/// ///
/// WS read => TCP write /// WS read => TCP write
/// TCP read => WS write /// TCP read => WS write
async fn relay_connection(ws_url: String, socket: TcpStream, root_cert: Option<Vec<u8>>) { async fn relay_connection(ws_url: String, socket: TcpStream, conf: Arc<ClientConfig>) {
log::debug!("Connecting to {}...", ws_url); log::debug!("Connecting to {}...", ws_url);
let ws_stream = if ws_url.starts_with("wss") { let ws_stream = if ws_url.starts_with("wss") {
let config = rustls::ClientConfig::builder() let config = rustls::ClientConfig::builder()
.with_safe_defaults(); .with_safe_defaults();
let config = match root_cert { let config = match conf.get_root_certificate() {
None => config.with_native_roots(), None => config.with_native_roots(),
Some(cert) => { Some(cert) => {
log::debug!("Using custom root certificates"); log::debug!("Using custom root certificates");

View File

@ -1,2 +1,2 @@
pub mod args; pub mod server_config;
pub mod relay_ws; pub mod relay_ws;

View File

@ -5,18 +5,18 @@ use std::sync::Arc;
use actix_web::{App, HttpRequest, HttpResponse, HttpServer, middleware, Responder, web}; use actix_web::{App, HttpRequest, HttpResponse, HttpServer, middleware, Responder, web};
use actix_web::web::Data; use actix_web::web::Data;
use clap::Parser; use clap::Parser;
use rustls::{Certificate, PrivateKey, ServerConfig}; use rustls::{Certificate, PrivateKey};
use rustls_pemfile::{certs, Item, read_one}; use rustls_pemfile::{certs, Item, read_one};
use base::RelayedPort; use base::RelayedPort;
use tcp_relay_server::args::ProgramArgs;
use tcp_relay_server::relay_ws::relay_ws; use tcp_relay_server::relay_ws::relay_ws;
use tcp_relay_server::server_config::ServerConfig;
pub async fn hello_route() -> &'static str { pub async fn hello_route() -> &'static str {
"Hello world!" "Hello world!"
} }
pub async fn config_route(req: HttpRequest, data: Data<Arc<ProgramArgs>>) -> impl Responder { pub async fn config_route(req: HttpRequest, data: Data<Arc<ServerConfig>>) -> impl Responder {
let token = req.headers().get("Authorization") let token = req.headers().get("Authorization")
.map(|t| t.to_str().unwrap_or_default()) .map(|t| t.to_str().unwrap_or_default())
.unwrap_or_default() .unwrap_or_default()
@ -39,7 +39,7 @@ pub async fn config_route(req: HttpRequest, data: Data<Arc<ProgramArgs>>) -> imp
async fn main() -> std::io::Result<()> { async fn main() -> std::io::Result<()> {
env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); env_logger::init_from_env(env_logger::Env::new().default_filter_or("info"));
let mut args: ProgramArgs = ProgramArgs::parse(); let mut args: ServerConfig = ServerConfig::parse();
if args.ports.is_empty() { if args.ports.is_empty() {
log::error!("No port to forward!"); log::error!("No port to forward!");
@ -73,7 +73,7 @@ async fn main() -> std::io::Result<()> {
} }
}; };
let config = ServerConfig::builder() let config = rustls::ServerConfig::builder()
.with_safe_defaults() .with_safe_defaults()
.with_no_client_auth() .with_no_client_auth()
.with_single_cert(cert_chain, PrivateKey(key)) .with_single_cert(cert_chain, PrivateKey(key))

View File

@ -9,7 +9,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use crate::args::ProgramArgs; use crate::server_config::ServerConfig;
/// How often heartbeat pings are sent /// How often heartbeat pings are sent
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
@ -153,7 +153,7 @@ pub struct WebSocketQuery {
pub async fn relay_ws(req: HttpRequest, stream: web::Payload, pub async fn relay_ws(req: HttpRequest, stream: web::Payload,
query: web::Query<WebSocketQuery>, query: web::Query<WebSocketQuery>,
conf: web::Data<Arc<ProgramArgs>>) -> Result<HttpResponse, Error> { conf: web::Data<Arc<ServerConfig>>) -> Result<HttpResponse, Error> {
if !conf.tokens.contains(&query.token) { if !conf.tokens.contains(&query.token) {
log::error!("Rejected WS request from {:?} due to invalid token!", req.peer_addr()); log::error!("Rejected WS request from {:?} due to invalid token!", req.peer_addr());
return Ok(HttpResponse::Unauthorized().json("Invalid / missing token!")); return Ok(HttpResponse::Unauthorized().json("Invalid / missing token!"));

View File

@ -4,7 +4,7 @@ use clap::Parser;
#[derive(Parser, Debug, Clone)] #[derive(Parser, Debug, Clone)]
#[clap(author, version, about, #[clap(author, version, about,
long_about = "TCP-over-HTTP server. This program might be configured behind a reverse-proxy.")] long_about = "TCP-over-HTTP server. This program might be configured behind a reverse-proxy.")]
pub struct ProgramArgs { pub struct ServerConfig {
/// Access tokens /// Access tokens
#[clap(short, long)] #[clap(short, long)]
pub tokens: Vec<String>, pub tokens: Vec<String>,