Add client TLS auth on server side

This commit is contained in:
Pierre HUBERT 2022-08-31 12:24:54 +02:00
parent 1b95b10553
commit 27b52dfcb7
9 changed files with 144 additions and 34 deletions

1
Cargo.lock generated
View File

@ -1596,6 +1596,7 @@ dependencies = [
"rustls-pemfile", "rustls-pemfile",
"serde", "serde",
"tokio", "tokio",
"webpki",
] ]
[[package]] [[package]]

View File

@ -8,7 +8,7 @@ static mut ROOT_CERT: Option<Vec<u8>> = None;
pub struct ClientConfig { pub struct ClientConfig {
/// Access token /// Access token
#[clap(short, long)] #[clap(short, long)]
pub token: String, pub token: Option<String>,
/// Relay server /// Relay server
#[clap(short, long, default_value = "http://127.0.0.1:8000")] #[clap(short, long, default_value = "http://127.0.0.1:8000")]
@ -24,6 +24,11 @@ pub struct ClientConfig {
} }
impl ClientConfig { impl ClientConfig {
/// Get client token, returning a dummy token if none was specified
pub fn get_auth_token(&self) -> &str {
self.token.as_deref().unwrap_or("none")
}
/// Load root certificate /// Load root certificate
pub fn get_root_certificate(&self) -> Option<Vec<u8>> { pub fn get_root_certificate(&self) -> Option<Vec<u8>> {
self.root_certificate.as_ref()?; self.root_certificate.as_ref()?;

View File

@ -23,7 +23,7 @@ async fn get_server_config(config: &ClientConfig) -> Result<RemoteConfig, Box<dy
let client = client.build().expect("Failed to build reqwest client"); let client = client.build().expect("Failed to build reqwest client");
let req = client.get(url) let req = client.get(url)
.header("Authorization", format!("Bearer {}", config.token)) .header("Authorization", format!("Bearer {}", config.get_auth_token()))
.send() .send()
.await?; .await?;
if req.status().as_u16() != 200 { if req.status().as_u16() != 200 {
@ -51,7 +51,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
let h = tokio::spawn(relay_client( let h = tokio::spawn(relay_client(
format!("{}/ws?id={}&token={}", format!("{}/ws?id={}&token={}",
args.relay_url, port.id, urlencoding::encode(&args.token)) args.relay_url, port.id, urlencoding::encode(args.get_auth_token()))
.replace("http", "ws"), .replace("http", "ws"),
listen_address, listen_address,
args.clone(), args.clone(),

View File

@ -17,3 +17,4 @@ tokio = { version = "1", features = ["full"] }
futures = "0.3.24" futures = "0.3.24"
rustls = "0.20.6" rustls = "0.20.6"
rustls-pemfile = "1.0.1" rustls-pemfile = "1.0.1"
webpki = "0.22.0"

View File

@ -1,2 +1,3 @@
pub mod server_config; pub mod server_config;
pub mod relay_ws; pub mod relay_ws;
pub mod tls_cert_client_verifier;

View File

@ -11,12 +11,14 @@ use rustls_pemfile::{certs, Item, read_one};
use base::RelayedPort; use base::RelayedPort;
use tcp_relay_server::relay_ws::relay_ws; use tcp_relay_server::relay_ws::relay_ws;
use tcp_relay_server::server_config::ServerConfig; use tcp_relay_server::server_config::ServerConfig;
use tcp_relay_server::tls_cert_client_verifier::CustomCertClientVerifier;
pub async fn hello_route() -> &'static str { pub async fn hello_route() -> &'static str {
"Hello world!" "Hello world!"
} }
pub async fn config_route(req: HttpRequest, data: Data<Arc<ServerConfig>>) -> impl Responder { pub async fn config_route(req: HttpRequest, data: Data<Arc<ServerConfig>>) -> impl Responder {
if data.has_token_auth() {
let token = req.headers().get("Authorization") let token = req.headers().get("Authorization")
.map(|t| t.to_str().unwrap_or_default()) .map(|t| t.to_str().unwrap_or_default())
.unwrap_or_default() .unwrap_or_default()
@ -26,6 +28,7 @@ pub async fn config_route(req: HttpRequest, data: Data<Arc<ServerConfig>>) -> im
if !data.tokens.iter().any(|t| t.eq(token)) { if !data.tokens.iter().any(|t| t.eq(token)) {
return HttpResponse::Unauthorized().json("Missing / invalid token"); return HttpResponse::Unauthorized().json("Missing / invalid token");
} }
}
HttpResponse::Ok().json( HttpResponse::Ok().json(
data.ports.iter() data.ports.iter()
@ -41,11 +44,33 @@ async fn main() -> std::io::Result<()> {
let mut args: ServerConfig = ServerConfig::parse(); let mut args: ServerConfig = ServerConfig::parse();
// Check if no port are to be forwarded
if args.ports.is_empty() { if args.ports.is_empty() {
log::error!("No port to forward!"); log::error!("No port to forward!");
std::process::exit(2); std::process::exit(2);
} }
// Read tokens from file, if any
if let Some(file) = &args.tokens_file {
std::fs::read_to_string(file)
.expect("Failed to read tokens file!")
.split('\n')
.filter(|l| !l.is_empty())
.for_each(|t| args.tokens.push(t.to_string()));
}
if !args.has_auth() {
log::error!("No authentication method specified!");
std::process::exit(3);
}
if args.has_tls_client_auth() && !args.has_tls_config() {
log::error!("Cannot provide client auth without TLS configuration!");
panic!();
}
let args = Arc::new(args);
// Load TLS configuration, if any // Load TLS configuration, if any
let tls_config = if let (Some(cert), Some(key)) = (&args.tls_cert, &args.tls_key) { let tls_config = if let (Some(cert), Some(key)) = (&args.tls_cert, &args.tls_key) {
@ -74,31 +99,22 @@ async fn main() -> std::io::Result<()> {
}; };
let config = rustls::ServerConfig::builder() let config = rustls::ServerConfig::builder()
.with_safe_defaults() .with_safe_defaults();
.with_no_client_auth()
.with_single_cert(cert_chain, PrivateKey(key)) let config = match args.has_tls_client_auth() {
true => config.with_client_cert_verifier(
Arc::new(CustomCertClientVerifier::new(args.clone()))),
false => config.with_no_client_auth()
};
let config = config.with_single_cert(cert_chain, PrivateKey(key))
.expect("Failed to load TLS certificate!"); .expect("Failed to load TLS certificate!");
Some(config) Some(config)
} else { None }; } else { None };
// Read tokens from file, if any
if let Some(file) = &args.tokens_file {
std::fs::read_to_string(file)
.expect("Failed to read tokens file!")
.split('\n')
.filter(|l| !l.is_empty())
.for_each(|t| args.tokens.push(t.to_string()));
}
if args.tokens.is_empty() {
log::error!("No tokens specified!");
std::process::exit(3);
}
log::info!("Starting relay on http://{}", args.listen_address); log::info!("Starting relay on http://{}", args.listen_address);
let args = Arc::new(args);
let args_clone = args.clone(); let args_clone = args.clone();
let server = HttpServer::new(move || { let server = HttpServer::new(move || {
App::new() App::new()

View File

@ -94,7 +94,6 @@ impl Actor for RelayWS {
} }
log::info!("Exited read loop"); log::info!("Exited read loop");
// TODO : notify context
}; };
tokio::spawn(future); tokio::spawn(future);
@ -148,13 +147,14 @@ impl Handler<TCPReadEndClosed> for RelayWS {
#[derive(serde::Deserialize)] #[derive(serde::Deserialize)]
pub struct WebSocketQuery { pub struct WebSocketQuery {
id: usize, id: usize,
token: String, token: Option<String>,
} }
pub async fn relay_ws(req: HttpRequest, stream: web::Payload, pub async fn relay_ws(req: HttpRequest, stream: web::Payload,
query: web::Query<WebSocketQuery>, query: web::Query<WebSocketQuery>,
conf: web::Data<Arc<ServerConfig>>) -> Result<HttpResponse, Error> { conf: web::Data<Arc<ServerConfig>>) -> Result<HttpResponse, Error> {
if !conf.tokens.contains(&query.token) { if conf.has_token_auth() &&
!conf.tokens.iter().any(|t| t == query.token.as_deref().unwrap_or_default()) {
log::error!("Rejected WS request from {:?} due to invalid token!", req.peer_addr()); log::error!("Rejected WS request from {:?} due to invalid token!", req.peer_addr());
return Ok(HttpResponse::Unauthorized().json("Invalid / missing token!")); return Ok(HttpResponse::Unauthorized().json("Invalid / missing token!"));
} }

View File

@ -3,7 +3,7 @@ use clap::Parser;
/// TCP relay server /// TCP relay server
#[derive(Parser, Debug, Clone)] #[derive(Parser, Debug, Clone)]
#[clap(author, version, about, #[clap(author, version, about,
long_about = "TCP-over-HTTP server. This program might be configured behind a reverse-proxy.")] long_about = "TCP-over-HTTP server. This program might be configured behind a reverse-proxy.")]
pub struct ServerConfig { pub struct ServerConfig {
/// Access tokens /// Access tokens
#[clap(short, long)] #[clap(short, long)]
@ -37,4 +37,29 @@ pub struct ServerConfig {
/// TLS private key. Specify also certificate to use HTTPS/TLS instead of HTTP /// TLS private key. Specify also certificate to use HTTPS/TLS instead of HTTP
#[clap(long)] #[clap(long)]
pub tls_key: Option<String>, pub tls_key: Option<String>,
/// Restrict TLS client authentication to certificates signed directly or indirectly by the
/// provided root certificates
///
/// This option automatically enable TLS client authentication
#[clap(long)]
pub tls_client_auth_root_cert: Option<String>,
}
impl ServerConfig {
pub fn has_token_auth(&self) -> bool {
!self.tokens.is_empty()
}
pub fn has_tls_config(&self) -> bool {
self.tls_cert.is_some() && self.tls_key.is_some()
}
pub fn has_tls_client_auth(&self) -> bool {
self.tls_client_auth_root_cert.is_some()
}
pub fn has_auth(&self) -> bool {
self.has_token_auth() || self.has_tls_client_auth()
}
} }

View File

@ -0,0 +1,61 @@
use std::fs::File;
use std::io::BufReader;
use std::sync::Arc;
use std::time::SystemTime;
use rustls::{Certificate, DistinguishedNames, Error, RootCertStore};
use rustls::server::{AllowAnyAuthenticatedClient, ClientCertVerified, ClientCertVerifier};
use rustls_pemfile::certs;
use crate::server_config::ServerConfig;
pub struct CustomCertClientVerifier {
upstream_cert_verifier: Box<Arc<dyn ClientCertVerifier>>,
}
impl CustomCertClientVerifier {
pub fn new(conf: Arc<ServerConfig>) -> Self {
let cert_path = conf.tls_client_auth_root_cert.as_deref()
.expect("No root certificates for client authentication provided!");
let cert_file = &mut BufReader::new(File::open(cert_path)
.expect("Failed to read root certificates for client authentication!"));
let root_certs = certs(cert_file).unwrap()
.into_iter()
.map(Certificate)
.collect::<Vec<_>>();
if root_certs.is_empty() {
log::error!("No certificates found for client authentication!");
panic!();
}
let mut store = RootCertStore::empty();
for cert in root_certs {
store.add(&cert).expect("Failed to add certificate to root store");
}
Self {
upstream_cert_verifier: Box::new(AllowAnyAuthenticatedClient::new(store)),
}
}
}
impl ClientCertVerifier for CustomCertClientVerifier {
fn offer_client_auth(&self) -> bool {
true
}
fn client_auth_mandatory(&self) -> Option<bool> {
Some(true)
}
fn client_auth_root_subjects(&self) -> Option<DistinguishedNames> {
Some(vec![])
}
fn verify_client_cert(&self, end_entity: &Certificate, intermediates: &[Certificate], now: SystemTime) -> Result<ClientCertVerified, Error> {
self.upstream_cert_verifier.verify_client_cert(end_entity, intermediates, now)
}
}