Merged all workspace projects into a single binary project

This commit is contained in:
2022-09-01 10:11:24 +02:00
parent 3be4c5a68e
commit b24e8ba68b
19 changed files with 96 additions and 123 deletions

87
src/base/cert_utils.rs Normal file
View File

@ -0,0 +1,87 @@
use std::error::Error;
use std::io::{Cursor, ErrorKind};
use rustls::{Certificate, PrivateKey};
use rustls_pemfile::{read_one, Item};
/// Parse PEM certificates bytes into a [`rustls::Certificate`] structure
///
/// An error is returned if not any certificate could be found
pub fn parse_pem_certificates(certs: &[u8]) -> Result<Vec<Certificate>, Box<dyn Error>> {
let certs = rustls_pemfile::certs(&mut Cursor::new(certs))?
.into_iter()
.map(Certificate)
.collect::<Vec<_>>();
if certs.is_empty() {
Err(std::io::Error::new(
ErrorKind::InvalidData,
"Could not find any certificate!",
))?;
unreachable!();
}
Ok(certs)
}
/// Parse PEM private key bytes into a [`rustls::PrivateKey`] structure
pub fn parse_pem_private_key(privkey: &[u8]) -> Result<PrivateKey, Box<dyn Error>> {
let key = match read_one(&mut Cursor::new(privkey))? {
None => {
Err(std::io::Error::new(
ErrorKind::Other,
"Failed to extract private key!",
))?;
unreachable!()
}
Some(Item::PKCS8Key(key)) => key,
Some(Item::RSAKey(key)) => key,
_ => {
Err(std::io::Error::new(
ErrorKind::Other,
"Unsupported private key type!",
))?;
unreachable!();
}
};
Ok(PrivateKey(key))
}
#[cfg(test)]
mod test {
use crate::cert_utils::{parse_pem_certificates, parse_pem_private_key};
const SAMPLE_CERT: &[u8] = include_bytes!("../samples/TCPTunnelTest.crt");
const SAMPLE_KEY: &[u8] = include_bytes!("../samples/TCPTunnelTest.key");
#[test]
fn parse_valid_cert() {
parse_pem_certificates(SAMPLE_CERT).unwrap();
}
#[test]
fn parse_invalid_cert_1() {
parse_pem_certificates("Random content".as_bytes()).unwrap_err();
}
#[test]
fn parse_invalid_cert_2() {
parse_pem_certificates(SAMPLE_KEY).unwrap_err();
}
#[test]
fn parse_valid_key() {
parse_pem_private_key(SAMPLE_KEY).unwrap();
}
#[test]
fn parse_invalid_key_1() {
parse_pem_private_key("Random content".as_bytes()).unwrap_err();
}
#[test]
fn parse_invalid_key_2() {
parse_pem_private_key(SAMPLE_CERT).unwrap_err();
}
}

4
src/base/mod.rs Normal file
View File

@ -0,0 +1,4 @@
pub mod cert_utils;
mod structs;
pub use structs::{RelayedPort, RemoteConfig};

7
src/base/structs.rs Normal file
View File

@ -0,0 +1,7 @@
#[derive(serde::Serialize, serde::Deserialize, Copy, Clone, Debug)]
pub struct RelayedPort {
pub id: usize,
pub port: u16,
}
pub type RemoteConfig = Vec<RelayedPort>;

3
src/lib.rs Normal file
View File

@ -0,0 +1,3 @@
mod base;
pub mod tcp_relay_client;
pub mod tcp_relay_server;

38
src/main.rs Normal file
View File

@ -0,0 +1,38 @@
use clap::{Parser, Subcommand};
use tcp_over_http::tcp_relay_client::client_config::ClientConfig;
use tcp_over_http::tcp_relay_server::server_config::ServerConfig;
#[derive(Parser, Debug)]
#[clap(
author,
version,
about,
long_about = "Encapsulate TCP sockets inside HTTP WebSockets"
)]
struct CliArgs {
#[clap(subcommand)]
command: SubCommands,
}
#[derive(Subcommand, Debug)]
enum SubCommands {
/// Run as server
Server(ServerConfig),
/// Run as client
Client(ClientConfig),
}
#[actix_web::main]
async fn main() -> std::io::Result<()> {
env_logger::init_from_env(env_logger::Env::new().default_filter_or("info"));
let args: CliArgs = CliArgs::parse();
// Dispatch the request to the appropriate part of the program
match args.command {
SubCommands::Server(c) => tcp_over_http::tcp_relay_server::run_app(c).await,
SubCommands::Client(c) => tcp_over_http::tcp_relay_client::run_app(c).await,
}
}

View File

@ -0,0 +1,100 @@
use bytes::BufMut;
use clap::Parser;
/// TCP relay client
#[derive(Parser, Debug, Clone)]
#[clap(author, version, about, long_about = None)]
pub struct ClientConfig {
/// Access token
#[clap(short, long)]
pub token: Option<String>,
/// Relay server
#[clap(short, long, default_value = "http://127.0.0.1:8000")]
pub relay_url: String,
/// Listen address
#[clap(short, long, default_value = "127.0.0.1")]
pub listen_address: String,
/// Alternative root certificate to use for server authentication
#[clap(short = 'c', long)]
pub root_certificate: Option<String>,
#[clap(skip)]
_root_certificate_cache: Option<Vec<u8>>,
/// TLS certificate for TLS authentication.
#[clap(long)]
pub tls_cert: Option<String>,
#[clap(skip)]
_tls_cert_cache: Option<Vec<u8>>,
/// TLS key for TLS authentication.
#[clap(long)]
pub tls_key: Option<String>,
#[clap(skip)]
_tls_key_cache: Option<Vec<u8>>,
}
impl ClientConfig {
/// Load certificates and put them in cache
pub fn load_certificates(&mut self) {
self._root_certificate_cache = self
.root_certificate
.as_ref()
.map(|c| std::fs::read(c).expect("Failed to read root certificate!"));
self._tls_cert_cache = self
.tls_cert
.as_ref()
.map(|c| std::fs::read(c).expect("Failed to read client certificate!"));
self._tls_key_cache = self
.tls_key
.as_ref()
.map(|c| std::fs::read(c).expect("Failed to read client key!"));
}
/// Get client token, returning a dummy token if none was specified
pub fn get_auth_token(&self) -> &str {
self.token.as_deref().unwrap_or("none")
}
/// Get root certificate content
pub fn get_root_certificate(&self) -> Option<Vec<u8>> {
self._root_certificate_cache.clone()
}
/// Get client certificate & key pair, if available
pub fn get_client_keypair(&self) -> Option<(&Vec<u8>, &Vec<u8>)> {
if let (Some(cert), Some(key)) = (&self._tls_cert_cache, &self._tls_key_cache) {
Some((cert, key))
} else {
None
}
}
/// Get client certificate & key pair, in a single memory buffer
pub fn get_merged_client_keypair(&self) -> Option<Vec<u8>> {
self.get_client_keypair().map(|(c, k)| {
let mut out = k.to_vec();
out.put_slice("\n".as_bytes());
out.put_slice(c);
out
})
}
}
#[cfg(test)]
mod test {
use crate::client_config::ClientConfig;
#[test]
fn verify_cli() {
use clap::CommandFactory;
ClientConfig::command().debug_assert()
}
}

100
src/tcp_relay_client/mod.rs Normal file
View File

@ -0,0 +1,100 @@
extern crate core;
use std::error::Error;
use std::sync::Arc;
use futures::future::join_all;
use reqwest::{Certificate, Identity};
use crate::base::RemoteConfig;
use crate::tcp_relay_client::client_config::ClientConfig;
use crate::tcp_relay_client::relay_client::relay_client;
pub mod client_config;
mod relay_client;
/// Get remote server config i.e. get the list of forwarded ports
async fn get_server_config(conf: &ClientConfig) -> Result<RemoteConfig, Box<dyn Error>> {
let url = format!("{}/config", conf.relay_url);
log::info!("Retrieving configuration on {}", url);
let mut client = reqwest::Client::builder();
// Specify root certificate, if any was specified in the command line
if let Some(cert) = conf.get_root_certificate() {
client = client.add_root_certificate(Certificate::from_pem(&cert)?);
}
// Specify client certificate, if any
if let Some(kp) = conf.get_merged_client_keypair() {
let identity = Identity::from_pem(&kp).expect("Failed to load certificates for reqwest!");
client = client.identity(identity).use_rustls_tls();
}
let client = client.build().expect("Failed to build reqwest client");
let req = client
.get(url)
.header("Authorization", format!("Bearer {}", conf.get_auth_token()))
.send()
.await?;
if req.status().as_u16() != 200 {
log::error!(
"Could not retrieve configuration! (got status {})",
req.status()
);
std::process::exit(2);
}
Ok(req.json::<RemoteConfig>().await?)
}
/// Core logic of the application
pub async fn run_app(mut args: ClientConfig) -> std::io::Result<()> {
args.load_certificates();
let args = Arc::new(args);
// Check arguments coherence
if args.tls_cert.is_some() != args.tls_key.is_some() {
log::error!(
"If you specify one of TLS certificate / key, you must then specify the other!"
);
panic!();
}
if args.get_client_keypair().is_some() {
log::info!("Using client-side authentication");
}
// Get server relay configuration (fetch the list of port to forward)
let remote_conf = match get_server_config(&args).await {
Ok(c) => c,
Err(e) => {
log::error!("Failed to fetch relay configuration from server! {}", e);
panic!();
}
};
// Start to listen port
let mut handles = vec![];
for port in remote_conf {
let listen_address = format!("{}:{}", args.listen_address, port.port);
let h = tokio::spawn(relay_client(
format!(
"{}/ws?id={}&token={}",
args.relay_url,
port.id,
urlencoding::encode(args.get_auth_token())
)
.replace("http", "ws"),
listen_address,
args.clone(),
));
handles.push(h);
}
join_all(handles).await;
Ok(())
}

View File

@ -0,0 +1,146 @@
use std::sync::Arc;
use futures::{SinkExt, StreamExt};
use hyper_rustls::ConfigBuilderExt;
use rustls::RootCertStore;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio_tungstenite::tungstenite::Message;
use crate::base::cert_utils;
use crate::tcp_relay_client::client_config::ClientConfig;
pub async fn relay_client(ws_url: String, listen_address: String, config: Arc<ClientConfig>) {
log::info!("Start to listen on {}", listen_address);
let listener = match TcpListener::bind(&listen_address).await {
Ok(l) => l,
Err(e) => {
log::error!("Failed to start to listen on {}! {}", listen_address, e);
std::process::exit(3);
}
};
loop {
let (socket, _) = listener
.accept()
.await
.expect("Failed to accept new connection!");
tokio::spawn(relay_connection(ws_url.clone(), socket, config.clone()));
}
}
/// Relay connection
///
/// WS read => TCP write
/// TCP read => WS write
async fn relay_connection(ws_url: String, socket: TcpStream, conf: Arc<ClientConfig>) {
log::debug!("Connecting to {}...", ws_url);
let ws_stream = if ws_url.starts_with("wss") {
let config = rustls::ClientConfig::builder().with_safe_defaults();
let config = match conf.get_root_certificate() {
None => config.with_native_roots(),
Some(cert) => {
log::debug!("Using custom root certificates");
let mut store = RootCertStore::empty();
cert_utils::parse_pem_certificates(&cert)
.unwrap()
.iter()
.for_each(|c| store.add(c).expect("Failed to add certificate to chain!"));
config.with_root_certificates(store)
}
};
let config = match conf.get_client_keypair() {
None => config.with_no_client_auth(),
Some((certs, key)) => {
let certs = cert_utils::parse_pem_certificates(certs)
.expect("Failed to parse client certificate!");
let key = cert_utils::parse_pem_private_key(key)
.expect("Failed to parse client auth private key!");
config
.with_single_cert(certs, key)
.expect("Failed to set client certificate!")
}
};
let connector = tokio_tungstenite::Connector::Rustls(Arc::new(config));
let (ws_stream, _) =
tokio_tungstenite::connect_async_tls_with_config(ws_url, None, Some(connector))
.await
.expect("Failed to connect to server relay!");
ws_stream
} else {
let (ws_stream, _) = tokio_tungstenite::connect_async(ws_url)
.await
.expect("Failed to connect to server relay!");
ws_stream
};
let (mut tcp_read, mut tcp_write) = socket.into_split();
let (mut ws_write, mut ws_read) = ws_stream.split();
// TCP read -> WS write
let future = async move {
let mut buff: [u8; 5000] = [0; 5000];
loop {
match tcp_read.read(&mut buff).await {
Ok(s) => {
if let Err(e) = ws_write.send(Message::Binary(Vec::from(&buff[0..s]))).await {
log::error!(
"Failed to write to WS connection! {:?} Exiting TCP read -> WS write loop...",e);
break;
}
if s == 0 {
log::info!("Got empty read TCP buffer. Stopping...");
break;
}
}
Err(e) => {
log::error!(
"Failed to read from TCP connection! {:?} Exitin TCP read -> WS write loop...",
e
);
break;
}
}
}
};
tokio::spawn(future);
// WS read -> TCP write
while let Some(m) = ws_read.next().await {
match m {
Err(e) => {
log::error!(
"Failed to read from WebSocket. Breaking read loop... {:?}",
e
);
break;
}
Ok(Message::Binary(b)) => {
if let Err(e) = tcp_write.write_all(&b).await {
log::error!(
"Failed to forward message to websocket. Closing reading end... {:?}",
e
);
break;
};
}
Ok(Message::Close(_r)) => {
log::info!("Server asked to close this WebSocket connection");
break;
}
Ok(m) => log::info!("{:?}", m),
}
}
}

125
src/tcp_relay_server/mod.rs Normal file
View File

@ -0,0 +1,125 @@
use std::sync::Arc;
use actix_web::web::Data;
use actix_web::{middleware, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
use crate::base::{cert_utils, RelayedPort};
use crate::tcp_relay_server::relay_ws::relay_ws;
use crate::tcp_relay_server::server_config::ServerConfig;
use crate::tcp_relay_server::tls_cert_client_verifier::CustomCertClientVerifier;
mod relay_ws;
pub mod server_config;
mod tls_cert_client_verifier;
pub async fn hello_route() -> &'static str {
"Hello world!"
}
pub async fn config_route(req: HttpRequest, data: Data<Arc<ServerConfig>>) -> impl Responder {
if data.has_token_auth() {
let token = req
.headers()
.get("Authorization")
.map(|t| t.to_str().unwrap_or_default())
.unwrap_or_default()
.strip_prefix("Bearer ")
.unwrap_or_default();
if !data.tokens.iter().any(|t| t.eq(token)) {
return HttpResponse::Unauthorized().json("Missing / invalid token");
}
}
HttpResponse::Ok().json(
data.ports
.iter()
.enumerate()
.map(|(id, port)| RelayedPort {
id,
port: port + data.increment_ports,
})
.collect::<Vec<_>>(),
)
}
pub async fn run_app(mut config: ServerConfig) -> std::io::Result<()> {
// Check if no port are to be forwarded
if config.ports.is_empty() {
log::error!("No port to forward!");
std::process::exit(2);
}
// Read tokens from file, if any
if let Some(file) = &config.tokens_file {
std::fs::read_to_string(file)
.expect("Failed to read tokens file!")
.split('\n')
.filter(|l| !l.is_empty())
.for_each(|t| config.tokens.push(t.to_string()));
}
if !config.has_auth() {
log::error!("No authentication method specified!");
std::process::exit(3);
}
if config.has_tls_client_auth() && !config.has_tls_config() {
log::error!("Cannot provide client auth without TLS configuration!");
panic!();
}
let args = Arc::new(config);
// Load TLS configuration, if any
let tls_config = if let (Some(cert), Some(key)) = (&args.tls_cert, &args.tls_key) {
// Load TLS certificate & private key
let cert_file = std::fs::read(cert).expect("Failed to read certificate file");
let key_file = std::fs::read(key).expect("Failed to read server private key");
// Get certificates chain
let cert_chain =
cert_utils::parse_pem_certificates(&cert_file).expect("Failed to extract certificates");
// Get private key
let key =
cert_utils::parse_pem_private_key(&key_file).expect("Failed to extract private key!");
let config = rustls::ServerConfig::builder().with_safe_defaults();
let config = match args.has_tls_client_auth() {
true => config
.with_client_cert_verifier(Arc::new(CustomCertClientVerifier::new(args.clone()))),
false => config.with_no_client_auth(),
};
let config = config
.with_single_cert(cert_chain, key)
.expect("Failed to load TLS certificate!");
Some(config)
} else {
None
};
log::info!("Starting relay on http://{}", args.listen_address);
let args_clone = args.clone();
let server = HttpServer::new(move || {
App::new()
.wrap(middleware::Logger::default())
.app_data(Data::new(args_clone.clone()))
.route("/", web::get().to(hello_route))
.route("/config", web::get().to(config_route))
.route("/ws", web::get().to(relay_ws))
});
if let Some(tls_conf) = tls_config {
server.bind_rustls(&args.listen_address, tls_conf)?
} else {
server.bind(&args.listen_address)?
}
.run()
.await
}

View File

@ -0,0 +1,206 @@
use std::sync::Arc;
use std::time::{Duration, Instant};
use actix::{Actor, ActorContext, AsyncContext, Handler, Message, StreamHandler};
use actix_web::{web, Error, HttpRequest, HttpResponse};
use actix_web_actors::ws;
use actix_web_actors::ws::{CloseCode, CloseReason};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::TcpStream;
use crate::tcp_relay_server::server_config::ServerConfig;
/// How often heartbeat pings are sent
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
/// How long before lack of client response causes a timeout
const CLIENT_TIMEOUT: Duration = Duration::from_secs(60);
#[derive(Message)]
#[rtype(result = "bool")]
pub struct DataForWebSocket(Vec<u8>);
#[derive(Message)]
#[rtype(result = "()")]
pub struct TCPReadEndClosed;
/// Define HTTP actor
struct RelayWS {
tcp_read: Option<OwnedReadHalf>,
tcp_write: OwnedWriteHalf,
// Client must respond to ping at a specific interval, otherwise we drop connection
hb: Instant,
// TODO : handle socket close
}
impl RelayWS {
/// helper method that sends ping to client every second.
///
/// also this method checks heartbeats from client
fn hb(&self, ctx: &mut ws::WebsocketContext<Self>) {
ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| {
// check client heartbeats
if Instant::now().duration_since(act.hb) > CLIENT_TIMEOUT {
// heartbeat timed out
log::warn!("WebSocket Client heartbeat failed, disconnecting!");
// stop actor
ctx.stop();
// don't try to send a ping
return;
}
log::debug!("Send ping message...");
ctx.ping(b"");
});
}
}
impl Actor for RelayWS {
type Context = ws::WebsocketContext<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
self.hb(ctx);
// Start to read on remote socket
let mut read_half = self.tcp_read.take().unwrap();
let addr = ctx.address();
let future = async move {
let mut buff: [u8; 5000] = [0; 5000];
loop {
match read_half.read(&mut buff).await {
Ok(l) => {
if l == 0 {
log::info!("Got empty read. Closing read end...");
addr.do_send(TCPReadEndClosed);
return;
}
let to_send = DataForWebSocket(Vec::from(&buff[0..l]));
if let Err(e) = addr.send(to_send).await {
log::error!("Failed to send to websocket. Stopping now... {:?}", e);
return;
}
}
Err(e) => {
log::error!("Failed to read from remote socket. Stopping now... {:?}", e);
break;
}
};
}
log::info!("Exited read loop");
};
tokio::spawn(future);
}
}
/// Handler for ws::Message message
impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for RelayWS {
fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
match msg {
Ok(ws::Message::Ping(msg)) => ctx.pong(&msg),
Ok(ws::Message::Pong(_)) => self.hb = Instant::now(),
Ok(ws::Message::Text(text)) => ctx.text(text),
Ok(ws::Message::Close(_reason)) => ctx.stop(),
Ok(ws::Message::Binary(data)) => {
if let Err(e) =
futures::executor::block_on(self.tcp_write.write_all(&data.to_vec()))
{
log::error!("Failed to forward some data, closing connection! {:?}", e);
ctx.stop();
}
if data.is_empty() {
log::info!("Got empty binary message. Closing websocket...");
ctx.stop();
}
}
_ => (),
}
}
}
impl Handler<DataForWebSocket> for RelayWS {
type Result = bool;
fn handle(&mut self, msg: DataForWebSocket, ctx: &mut Self::Context) -> Self::Result {
ctx.binary(msg.0);
true
}
}
impl Handler<TCPReadEndClosed> for RelayWS {
type Result = ();
fn handle(&mut self, _msg: TCPReadEndClosed, ctx: &mut Self::Context) -> Self::Result {
ctx.close(Some(CloseReason {
code: CloseCode::Away,
description: Some("TCP read end closed.".to_string()),
}));
}
}
#[derive(serde::Deserialize)]
pub struct WebSocketQuery {
id: usize,
token: Option<String>,
}
pub async fn relay_ws(
req: HttpRequest,
stream: web::Payload,
query: web::Query<WebSocketQuery>,
conf: web::Data<Arc<ServerConfig>>,
) -> Result<HttpResponse, Error> {
if conf.has_token_auth()
&& !conf
.tokens
.iter()
.any(|t| t == query.token.as_deref().unwrap_or_default())
{
log::error!(
"Rejected WS request from {:?} due to invalid token!",
req.peer_addr()
);
return Ok(HttpResponse::Unauthorized().json("Invalid / missing token!"));
}
if conf.ports.len() <= query.id {
log::error!(
"Rejected WS request from {:?} due to invalid port number!",
req.peer_addr()
);
return Ok(HttpResponse::BadRequest().json("Invalid port number!"));
}
let upstream_addr = format!("{}:{}", conf.upstream_server, conf.ports[query.id]);
let (tcp_read, tcp_write) = match TcpStream::connect(&upstream_addr).await {
Ok(s) => s.into_split(),
Err(e) => {
log::error!(
"Failed to establish connection with upstream server! {:?}",
e
);
return Ok(HttpResponse::InternalServerError().json("Failed to establish connection!"));
}
};
let relay = RelayWS {
tcp_read: Some(tcp_read),
tcp_write,
hb: Instant::now(),
};
let resp = ws::start(relay, &req, stream);
log::info!(
"Opening new WS connection for {:?} to {}",
req.peer_addr(),
upstream_addr
);
resp
}

View File

@ -0,0 +1,84 @@
use clap::Parser;
/// TCP relay server mode
#[derive(Parser, Debug, Clone)]
#[clap(
author,
version,
about,
long_about = "TCP-over-HTTP server. This program can be configured behind a reverse-proxy (without TLS authentication)."
)]
pub struct ServerConfig {
/// Access tokens
#[clap(short, long)]
pub tokens: Vec<String>,
/// Access tokens stored in a file, one token per line
#[clap(long)]
pub tokens_file: Option<String>,
/// Forwarded ports
#[clap(short, long)]
pub ports: Vec<u16>,
/// Upstream server
#[clap(short, long, default_value = "127.0.0.1")]
pub upstream_server: String,
/// HTTP server listen address
#[clap(short, long, default_value = "0.0.0.0:8000")]
pub listen_address: String,
/// Increment ports on client. Useful for debugging and running both client and server
/// on the same machine
#[clap(short, long, default_value_t = 0)]
pub increment_ports: u16,
/// TLS certificate. Specify also private key to use HTTPS/TLS instead of HTTP
#[clap(long)]
pub tls_cert: Option<String>,
/// TLS private key. Specify also certificate to use HTTPS/TLS instead of HTTP
#[clap(long)]
pub tls_key: Option<String>,
/// Restrict TLS client authentication to certificates signed directly or indirectly by the
/// provided root certificates
///
/// This option automatically enable TLS client authentication
#[clap(long)]
pub tls_client_auth_root_cert: Option<String>,
/// TLS client authentication revocation list (CRL file)
#[clap(long)]
pub tls_revocation_list: Option<String>,
}
impl ServerConfig {
pub fn has_token_auth(&self) -> bool {
!self.tokens.is_empty()
}
pub fn has_tls_config(&self) -> bool {
self.tls_cert.is_some() && self.tls_key.is_some()
}
pub fn has_tls_client_auth(&self) -> bool {
self.tls_client_auth_root_cert.is_some()
}
pub fn has_auth(&self) -> bool {
self.has_token_auth() || self.has_tls_client_auth()
}
}
#[cfg(test)]
mod test {
use crate::server_config::ServerConfig;
#[test]
fn verify_cli() {
use clap::CommandFactory;
ServerConfig::command().debug_assert()
}
}

View File

@ -0,0 +1,103 @@
use std::sync::Arc;
use std::time::SystemTime;
use rustls::internal::msgs::enums::AlertDescription;
use rustls::server::{AllowAnyAuthenticatedClient, ClientCertVerified, ClientCertVerifier};
use rustls::{Certificate, DistinguishedNames, Error, RootCertStore};
use x509_parser::prelude::{CertificateRevocationList, FromDer, X509Certificate};
use crate::base::cert_utils::parse_pem_certificates;
use crate::tcp_relay_server::server_config::ServerConfig;
pub struct CustomCertClientVerifier {
upstream_cert_verifier: Box<Arc<dyn ClientCertVerifier>>,
crl: Option<Vec<u8>>,
}
impl CustomCertClientVerifier {
pub fn new(conf: Arc<ServerConfig>) -> Self {
// Build root certifications list
let cert_path = conf
.tls_client_auth_root_cert
.as_deref()
.expect("No root certificates for client authentication provided!");
let cert_file = std::fs::read(cert_path)
.expect("Failed to read root certificates for client authentication!");
let root_certs = parse_pem_certificates(&cert_file)
.expect("Failed to read root certificates for server authentication!");
if root_certs.is_empty() {
log::error!("No certificates found for client authentication!");
panic!();
}
let mut store = RootCertStore::empty();
for cert in root_certs {
store
.add(&cert)
.expect("Failed to add certificate to root store");
}
// Parse CRL file (if any)
let crl = if let Some(crl_file) = &conf.tls_revocation_list {
let crl_file = std::fs::read(crl_file).expect("Failed to open / read CRL file!");
let parsed_crl = pem::parse(crl_file).expect("Failed to decode CRL file!");
Some(parsed_crl.contents)
} else {
None
};
Self {
upstream_cert_verifier: Box::new(AllowAnyAuthenticatedClient::new(store)),
crl,
}
}
}
impl ClientCertVerifier for CustomCertClientVerifier {
fn offer_client_auth(&self) -> bool {
true
}
fn client_auth_mandatory(&self) -> Option<bool> {
Some(true)
}
fn client_auth_root_subjects(&self) -> Option<DistinguishedNames> {
Some(vec![])
}
fn verify_client_cert(
&self,
end_entity: &Certificate,
intermediates: &[Certificate],
now: SystemTime,
) -> Result<ClientCertVerified, Error> {
// Check the certificates sent by the client has been revoked
if let Some(crl) = &self.crl {
let (_rem, crl) =
CertificateRevocationList::from_der(crl).expect("Failed to read CRL!");
let (_rem, cert) =
X509Certificate::from_der(&end_entity.0).expect("Failed to read certificate!");
for revoked in crl.iter_revoked_certificates() {
if revoked.user_certificate == cert.serial {
log::error!(
"Client attempted to use a revoked certificate! Serial={:?} Subject={}",
cert.serial,
cert.subject
);
return Err(Error::AlertReceived(AlertDescription::CertificateRevoked));
}
}
}
self.upstream_cert_verifier
.verify_client_cert(end_entity, intermediates, now)
}
}