Add client TLS auth on server side
This commit is contained in:
parent
1b95b10553
commit
27b52dfcb7
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -1596,6 +1596,7 @@ dependencies = [
|
|||||||
"rustls-pemfile",
|
"rustls-pemfile",
|
||||||
"serde",
|
"serde",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"webpki",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -8,7 +8,7 @@ static mut ROOT_CERT: Option<Vec<u8>> = None;
|
|||||||
pub struct ClientConfig {
|
pub struct ClientConfig {
|
||||||
/// Access token
|
/// Access token
|
||||||
#[clap(short, long)]
|
#[clap(short, long)]
|
||||||
pub token: String,
|
pub token: Option<String>,
|
||||||
|
|
||||||
/// Relay server
|
/// Relay server
|
||||||
#[clap(short, long, default_value = "http://127.0.0.1:8000")]
|
#[clap(short, long, default_value = "http://127.0.0.1:8000")]
|
||||||
@ -24,6 +24,11 @@ pub struct ClientConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ClientConfig {
|
impl ClientConfig {
|
||||||
|
/// Get client token, returning a dummy token if none was specified
|
||||||
|
pub fn get_auth_token(&self) -> &str {
|
||||||
|
self.token.as_deref().unwrap_or("none")
|
||||||
|
}
|
||||||
|
|
||||||
/// Load root certificate
|
/// Load root certificate
|
||||||
pub fn get_root_certificate(&self) -> Option<Vec<u8>> {
|
pub fn get_root_certificate(&self) -> Option<Vec<u8>> {
|
||||||
self.root_certificate.as_ref()?;
|
self.root_certificate.as_ref()?;
|
||||||
|
@ -23,7 +23,7 @@ async fn get_server_config(config: &ClientConfig) -> Result<RemoteConfig, Box<dy
|
|||||||
let client = client.build().expect("Failed to build reqwest client");
|
let client = client.build().expect("Failed to build reqwest client");
|
||||||
|
|
||||||
let req = client.get(url)
|
let req = client.get(url)
|
||||||
.header("Authorization", format!("Bearer {}", config.token))
|
.header("Authorization", format!("Bearer {}", config.get_auth_token()))
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
if req.status().as_u16() != 200 {
|
if req.status().as_u16() != 200 {
|
||||||
@ -51,7 +51,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|||||||
|
|
||||||
let h = tokio::spawn(relay_client(
|
let h = tokio::spawn(relay_client(
|
||||||
format!("{}/ws?id={}&token={}",
|
format!("{}/ws?id={}&token={}",
|
||||||
args.relay_url, port.id, urlencoding::encode(&args.token))
|
args.relay_url, port.id, urlencoding::encode(args.get_auth_token()))
|
||||||
.replace("http", "ws"),
|
.replace("http", "ws"),
|
||||||
listen_address,
|
listen_address,
|
||||||
args.clone(),
|
args.clone(),
|
||||||
|
@ -17,3 +17,4 @@ tokio = { version = "1", features = ["full"] }
|
|||||||
futures = "0.3.24"
|
futures = "0.3.24"
|
||||||
rustls = "0.20.6"
|
rustls = "0.20.6"
|
||||||
rustls-pemfile = "1.0.1"
|
rustls-pemfile = "1.0.1"
|
||||||
|
webpki = "0.22.0"
|
@ -1,2 +1,3 @@
|
|||||||
pub mod server_config;
|
pub mod server_config;
|
||||||
pub mod relay_ws;
|
pub mod relay_ws;
|
||||||
|
pub mod tls_cert_client_verifier;
|
@ -11,20 +11,23 @@ use rustls_pemfile::{certs, Item, read_one};
|
|||||||
use base::RelayedPort;
|
use base::RelayedPort;
|
||||||
use tcp_relay_server::relay_ws::relay_ws;
|
use tcp_relay_server::relay_ws::relay_ws;
|
||||||
use tcp_relay_server::server_config::ServerConfig;
|
use tcp_relay_server::server_config::ServerConfig;
|
||||||
|
use tcp_relay_server::tls_cert_client_verifier::CustomCertClientVerifier;
|
||||||
|
|
||||||
pub async fn hello_route() -> &'static str {
|
pub async fn hello_route() -> &'static str {
|
||||||
"Hello world!"
|
"Hello world!"
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn config_route(req: HttpRequest, data: Data<Arc<ServerConfig>>) -> impl Responder {
|
pub async fn config_route(req: HttpRequest, data: Data<Arc<ServerConfig>>) -> impl Responder {
|
||||||
let token = req.headers().get("Authorization")
|
if data.has_token_auth() {
|
||||||
.map(|t| t.to_str().unwrap_or_default())
|
let token = req.headers().get("Authorization")
|
||||||
.unwrap_or_default()
|
.map(|t| t.to_str().unwrap_or_default())
|
||||||
.strip_prefix("Bearer ")
|
.unwrap_or_default()
|
||||||
.unwrap_or_default();
|
.strip_prefix("Bearer ")
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
if !data.tokens.iter().any(|t| t.eq(token)) {
|
if !data.tokens.iter().any(|t| t.eq(token)) {
|
||||||
return HttpResponse::Unauthorized().json("Missing / invalid token");
|
return HttpResponse::Unauthorized().json("Missing / invalid token");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
HttpResponse::Ok().json(
|
HttpResponse::Ok().json(
|
||||||
@ -41,11 +44,33 @@ async fn main() -> std::io::Result<()> {
|
|||||||
|
|
||||||
let mut args: ServerConfig = ServerConfig::parse();
|
let mut args: ServerConfig = ServerConfig::parse();
|
||||||
|
|
||||||
|
// Check if no port are to be forwarded
|
||||||
if args.ports.is_empty() {
|
if args.ports.is_empty() {
|
||||||
log::error!("No port to forward!");
|
log::error!("No port to forward!");
|
||||||
std::process::exit(2);
|
std::process::exit(2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Read tokens from file, if any
|
||||||
|
if let Some(file) = &args.tokens_file {
|
||||||
|
std::fs::read_to_string(file)
|
||||||
|
.expect("Failed to read tokens file!")
|
||||||
|
.split('\n')
|
||||||
|
.filter(|l| !l.is_empty())
|
||||||
|
.for_each(|t| args.tokens.push(t.to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if !args.has_auth() {
|
||||||
|
log::error!("No authentication method specified!");
|
||||||
|
std::process::exit(3);
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.has_tls_client_auth() && !args.has_tls_config() {
|
||||||
|
log::error!("Cannot provide client auth without TLS configuration!");
|
||||||
|
panic!();
|
||||||
|
}
|
||||||
|
|
||||||
|
let args = Arc::new(args);
|
||||||
|
|
||||||
// Load TLS configuration, if any
|
// Load TLS configuration, if any
|
||||||
let tls_config = if let (Some(cert), Some(key)) = (&args.tls_cert, &args.tls_key) {
|
let tls_config = if let (Some(cert), Some(key)) = (&args.tls_cert, &args.tls_key) {
|
||||||
|
|
||||||
@ -74,31 +99,22 @@ async fn main() -> std::io::Result<()> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let config = rustls::ServerConfig::builder()
|
let config = rustls::ServerConfig::builder()
|
||||||
.with_safe_defaults()
|
.with_safe_defaults();
|
||||||
.with_no_client_auth()
|
|
||||||
.with_single_cert(cert_chain, PrivateKey(key))
|
let config = match args.has_tls_client_auth() {
|
||||||
|
true => config.with_client_cert_verifier(
|
||||||
|
Arc::new(CustomCertClientVerifier::new(args.clone()))),
|
||||||
|
false => config.with_no_client_auth()
|
||||||
|
};
|
||||||
|
|
||||||
|
let config = config.with_single_cert(cert_chain, PrivateKey(key))
|
||||||
.expect("Failed to load TLS certificate!");
|
.expect("Failed to load TLS certificate!");
|
||||||
|
|
||||||
Some(config)
|
Some(config)
|
||||||
} else { None };
|
} else { None };
|
||||||
|
|
||||||
// Read tokens from file, if any
|
|
||||||
if let Some(file) = &args.tokens_file {
|
|
||||||
std::fs::read_to_string(file)
|
|
||||||
.expect("Failed to read tokens file!")
|
|
||||||
.split('\n')
|
|
||||||
.filter(|l| !l.is_empty())
|
|
||||||
.for_each(|t| args.tokens.push(t.to_string()));
|
|
||||||
}
|
|
||||||
|
|
||||||
if args.tokens.is_empty() {
|
|
||||||
log::error!("No tokens specified!");
|
|
||||||
std::process::exit(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
log::info!("Starting relay on http://{}", args.listen_address);
|
log::info!("Starting relay on http://{}", args.listen_address);
|
||||||
|
|
||||||
let args = Arc::new(args);
|
|
||||||
let args_clone = args.clone();
|
let args_clone = args.clone();
|
||||||
let server = HttpServer::new(move || {
|
let server = HttpServer::new(move || {
|
||||||
App::new()
|
App::new()
|
||||||
|
@ -94,7 +94,6 @@ impl Actor for RelayWS {
|
|||||||
}
|
}
|
||||||
|
|
||||||
log::info!("Exited read loop");
|
log::info!("Exited read loop");
|
||||||
// TODO : notify context
|
|
||||||
};
|
};
|
||||||
|
|
||||||
tokio::spawn(future);
|
tokio::spawn(future);
|
||||||
@ -148,13 +147,14 @@ impl Handler<TCPReadEndClosed> for RelayWS {
|
|||||||
#[derive(serde::Deserialize)]
|
#[derive(serde::Deserialize)]
|
||||||
pub struct WebSocketQuery {
|
pub struct WebSocketQuery {
|
||||||
id: usize,
|
id: usize,
|
||||||
token: String,
|
token: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn relay_ws(req: HttpRequest, stream: web::Payload,
|
pub async fn relay_ws(req: HttpRequest, stream: web::Payload,
|
||||||
query: web::Query<WebSocketQuery>,
|
query: web::Query<WebSocketQuery>,
|
||||||
conf: web::Data<Arc<ServerConfig>>) -> Result<HttpResponse, Error> {
|
conf: web::Data<Arc<ServerConfig>>) -> Result<HttpResponse, Error> {
|
||||||
if !conf.tokens.contains(&query.token) {
|
if conf.has_token_auth() &&
|
||||||
|
!conf.tokens.iter().any(|t| t == query.token.as_deref().unwrap_or_default()) {
|
||||||
log::error!("Rejected WS request from {:?} due to invalid token!", req.peer_addr());
|
log::error!("Rejected WS request from {:?} due to invalid token!", req.peer_addr());
|
||||||
return Ok(HttpResponse::Unauthorized().json("Invalid / missing token!"));
|
return Ok(HttpResponse::Unauthorized().json("Invalid / missing token!"));
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,7 @@ use clap::Parser;
|
|||||||
/// TCP relay server
|
/// TCP relay server
|
||||||
#[derive(Parser, Debug, Clone)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
#[clap(author, version, about,
|
#[clap(author, version, about,
|
||||||
long_about = "TCP-over-HTTP server. This program might be configured behind a reverse-proxy.")]
|
long_about = "TCP-over-HTTP server. This program might be configured behind a reverse-proxy.")]
|
||||||
pub struct ServerConfig {
|
pub struct ServerConfig {
|
||||||
/// Access tokens
|
/// Access tokens
|
||||||
#[clap(short, long)]
|
#[clap(short, long)]
|
||||||
@ -37,4 +37,29 @@ pub struct ServerConfig {
|
|||||||
/// TLS private key. Specify also certificate to use HTTPS/TLS instead of HTTP
|
/// TLS private key. Specify also certificate to use HTTPS/TLS instead of HTTP
|
||||||
#[clap(long)]
|
#[clap(long)]
|
||||||
pub tls_key: Option<String>,
|
pub tls_key: Option<String>,
|
||||||
|
|
||||||
|
/// Restrict TLS client authentication to certificates signed directly or indirectly by the
|
||||||
|
/// provided root certificates
|
||||||
|
///
|
||||||
|
/// This option automatically enable TLS client authentication
|
||||||
|
#[clap(long)]
|
||||||
|
pub tls_client_auth_root_cert: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ServerConfig {
|
||||||
|
pub fn has_token_auth(&self) -> bool {
|
||||||
|
!self.tokens.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn has_tls_config(&self) -> bool {
|
||||||
|
self.tls_cert.is_some() && self.tls_key.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn has_tls_client_auth(&self) -> bool {
|
||||||
|
self.tls_client_auth_root_cert.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn has_auth(&self) -> bool {
|
||||||
|
self.has_token_auth() || self.has_tls_client_auth()
|
||||||
|
}
|
||||||
}
|
}
|
61
tcp_relay_server/src/tls_cert_client_verifier.rs
Normal file
61
tcp_relay_server/src/tls_cert_client_verifier.rs
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
use std::fs::File;
|
||||||
|
use std::io::BufReader;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::SystemTime;
|
||||||
|
|
||||||
|
use rustls::{Certificate, DistinguishedNames, Error, RootCertStore};
|
||||||
|
use rustls::server::{AllowAnyAuthenticatedClient, ClientCertVerified, ClientCertVerifier};
|
||||||
|
use rustls_pemfile::certs;
|
||||||
|
|
||||||
|
use crate::server_config::ServerConfig;
|
||||||
|
|
||||||
|
pub struct CustomCertClientVerifier {
|
||||||
|
upstream_cert_verifier: Box<Arc<dyn ClientCertVerifier>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CustomCertClientVerifier {
|
||||||
|
pub fn new(conf: Arc<ServerConfig>) -> Self {
|
||||||
|
let cert_path = conf.tls_client_auth_root_cert.as_deref()
|
||||||
|
.expect("No root certificates for client authentication provided!");
|
||||||
|
let cert_file = &mut BufReader::new(File::open(cert_path)
|
||||||
|
.expect("Failed to read root certificates for client authentication!"));
|
||||||
|
|
||||||
|
let root_certs = certs(cert_file).unwrap()
|
||||||
|
.into_iter()
|
||||||
|
.map(Certificate)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
if root_certs.is_empty() {
|
||||||
|
log::error!("No certificates found for client authentication!");
|
||||||
|
panic!();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut store = RootCertStore::empty();
|
||||||
|
for cert in root_certs {
|
||||||
|
store.add(&cert).expect("Failed to add certificate to root store");
|
||||||
|
}
|
||||||
|
|
||||||
|
Self {
|
||||||
|
upstream_cert_verifier: Box::new(AllowAnyAuthenticatedClient::new(store)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClientCertVerifier for CustomCertClientVerifier {
|
||||||
|
fn offer_client_auth(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
fn client_auth_mandatory(&self) -> Option<bool> {
|
||||||
|
Some(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn client_auth_root_subjects(&self) -> Option<DistinguishedNames> {
|
||||||
|
Some(vec![])
|
||||||
|
}
|
||||||
|
|
||||||
|
fn verify_client_cert(&self, end_entity: &Certificate, intermediates: &[Certificate], now: SystemTime) -> Result<ClientCertVerified, Error> {
|
||||||
|
self.upstream_cert_verifier.verify_client_cert(end_entity, intermediates, now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user