cargo fmt

This commit is contained in:
Pierre HUBERT 2022-08-31 14:36:07 +02:00
parent cd0f6fea94
commit 3cbbd72a14
10 changed files with 170 additions and 99 deletions

View File

@ -4,4 +4,4 @@ pub struct RelayedPort {
pub port: u16, pub port: u16,
} }
pub type RemoteConfig = Vec<RelayedPort>; pub type RemoteConfig = Vec<RelayedPort>;

View File

@ -42,17 +42,20 @@ pub struct ClientConfig {
impl ClientConfig { impl ClientConfig {
/// Load certificates and put them in cache /// Load certificates and put them in cache
pub fn load_certificates(&mut self) { pub fn load_certificates(&mut self) {
self._root_certificate_cache = self.root_certificate.as_ref() self._root_certificate_cache = self
.map(|c| std::fs::read(c) .root_certificate
.expect("Failed to read root certificate!")); .as_ref()
.map(|c| std::fs::read(c).expect("Failed to read root certificate!"));
self._tls_cert_cache = self.tls_cert.as_ref() self._tls_cert_cache = self
.map(|c| std::fs::read(c) .tls_cert
.expect("Failed to read client certificate!")); .as_ref()
.map(|c| std::fs::read(c).expect("Failed to read client certificate!"));
self._tls_key_cache = self.tls_key.as_ref() self._tls_key_cache = self
.map(|c| std::fs::read(c) .tls_key
.expect("Failed to read client key!")); .as_ref()
.map(|c| std::fs::read(c).expect("Failed to read client key!"));
} }
/// Get client token, returning a dummy token if none was specified /// Get client token, returning a dummy token if none was specified
@ -69,18 +72,19 @@ impl ClientConfig {
pub fn get_client_keypair(&self) -> Option<(&Vec<u8>, &Vec<u8>)> { pub fn get_client_keypair(&self) -> Option<(&Vec<u8>, &Vec<u8>)> {
if let (Some(cert), Some(key)) = (&self._tls_cert_cache, &self._tls_key_cache) { if let (Some(cert), Some(key)) = (&self._tls_cert_cache, &self._tls_key_cache) {
Some((cert, key)) Some((cert, key))
} else { None } } else {
None
}
} }
/// Get client certificate & key pair, in a single memory buffer /// Get client certificate & key pair, in a single memory buffer
pub fn get_merged_client_keypair(&self) -> Option<Vec<u8>> { pub fn get_merged_client_keypair(&self) -> Option<Vec<u8>> {
self.get_client_keypair() self.get_client_keypair().map(|(c, k)| {
.map(|(c, k)| { let mut out = k.to_vec();
let mut out = k.to_vec(); out.put_slice("\n".as_bytes());
out.put_slice("\n".as_bytes()); out.put_slice(c);
out.put_slice(c); out
out })
})
} }
} }
@ -93,4 +97,4 @@ mod test {
use clap::CommandFactory; use clap::CommandFactory;
ClientConfig::command().debug_assert() ClientConfig::command().debug_assert()
} }
} }

View File

@ -1,2 +1,2 @@
pub mod client_config; pub mod client_config;
pub mod relay_client; pub mod relay_client;

View File

@ -24,20 +24,25 @@ async fn get_server_config(config: &ClientConfig) -> Result<RemoteConfig, Box<dy
// Specify client certificate, if any // Specify client certificate, if any
if let Some(kp) = config.get_merged_client_keypair() { if let Some(kp) = config.get_merged_client_keypair() {
let identity = Identity::from_pem(&kp) let identity = Identity::from_pem(&kp).expect("Failed to load certificates for reqwest!");
.expect("Failed to load certificates for reqwest!"); client = client.identity(identity).use_rustls_tls();
client = client.identity(identity)
.use_rustls_tls();
} }
let client = client.build().expect("Failed to build reqwest client"); let client = client.build().expect("Failed to build reqwest client");
let req = client.get(url) let req = client
.header("Authorization", format!("Bearer {}", config.get_auth_token())) .get(url)
.header(
"Authorization",
format!("Bearer {}", config.get_auth_token()),
)
.send() .send()
.await?; .await?;
if req.status().as_u16() != 200 { if req.status().as_u16() != 200 {
log::error!("Could not retrieve configuration! (got status {})", req.status()); log::error!(
"Could not retrieve configuration! (got status {})",
req.status()
);
std::process::exit(2); std::process::exit(2);
} }
@ -54,7 +59,9 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Check arguments coherence // Check arguments coherence
if args.tls_cert.is_some() != args.tls_key.is_some() { if args.tls_cert.is_some() != args.tls_key.is_some() {
log::error!("If you specify one of TLS certificate / key, you must then specify the other!"); log::error!(
"If you specify one of TLS certificate / key, you must then specify the other!"
);
panic!(); panic!();
} }
@ -71,9 +78,13 @@ async fn main() -> Result<(), Box<dyn Error>> {
let listen_address = format!("{}:{}", args.listen_address, port.port); let listen_address = format!("{}:{}", args.listen_address, port.port);
let h = tokio::spawn(relay_client( let h = tokio::spawn(relay_client(
format!("{}/ws?id={}&token={}", format!(
args.relay_url, port.id, urlencoding::encode(args.get_auth_token())) "{}/ws?id={}&token={}",
.replace("http", "ws"), args.relay_url,
port.id,
urlencoding::encode(args.get_auth_token())
)
.replace("http", "ws"),
listen_address, listen_address,
args.clone(), args.clone(),
)); ));
@ -83,4 +94,4 @@ async fn main() -> Result<(), Box<dyn Error>> {
join_all(handles).await; join_all(handles).await;
Ok(()) Ok(())
} }

View File

@ -4,7 +4,7 @@ use std::sync::Arc;
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
use hyper_rustls::ConfigBuilderExt; use hyper_rustls::ConfigBuilderExt;
use rustls::{Certificate, PrivateKey, RootCertStore}; use rustls::{Certificate, PrivateKey, RootCertStore};
use rustls_pemfile::{Item, read_one}; use rustls_pemfile::{read_one, Item};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::tungstenite::Message;
@ -22,7 +22,9 @@ pub async fn relay_client(ws_url: String, listen_address: String, config: Arc<Cl
}; };
loop { loop {
let (socket, _) = listener.accept().await let (socket, _) = listener
.accept()
.await
.expect("Failed to accept new connection!"); .expect("Failed to accept new connection!");
tokio::spawn(relay_connection(ws_url.clone(), socket, config.clone())); tokio::spawn(relay_connection(ws_url.clone(), socket, config.clone()));
@ -37,8 +39,7 @@ async fn relay_connection(ws_url: String, socket: TcpStream, conf: Arc<ClientCon
log::debug!("Connecting to {}...", ws_url); log::debug!("Connecting to {}...", ws_url);
let ws_stream = if ws_url.starts_with("wss") { let ws_stream = if ws_url.starts_with("wss") {
let config = rustls::ClientConfig::builder() let config = rustls::ClientConfig::builder().with_safe_defaults();
.with_safe_defaults();
let config = match conf.get_root_certificate() { let config = match conf.get_root_certificate() {
None => config.with_native_roots(), None => config.with_native_roots(),
@ -65,7 +66,8 @@ async fn relay_connection(ws_url: String, socket: TcpStream, conf: Arc<ClientCon
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let key = match read_one(&mut Cursor::new(key)) let key = match read_one(&mut Cursor::new(key))
.expect("Failed to read client private key!") { .expect("Failed to read client private key!")
{
None => { None => {
log::error!("Failed to extract private key!"); log::error!("Failed to extract private key!");
panic!(); panic!();
@ -78,30 +80,29 @@ async fn relay_connection(ws_url: String, socket: TcpStream, conf: Arc<ClientCon
} }
}; };
config.with_single_cert(certs, PrivateKey(key)) config
.with_single_cert(certs, PrivateKey(key))
.expect("Failed to set client certificate!") .expect("Failed to set client certificate!")
} }
}; };
let connector = tokio_tungstenite::Connector::Rustls(Arc::new(config)); let connector = tokio_tungstenite::Connector::Rustls(Arc::new(config));
let (ws_stream, _) = tokio_tungstenite::connect_async_tls_with_config( let (ws_stream, _) =
ws_url, tokio_tungstenite::connect_async_tls_with_config(ws_url, None, Some(connector))
None, .await
Some(connector)) .expect("Failed to connect to server relay!");
.await.expect("Failed to connect to server relay!");
ws_stream ws_stream
} else { } else {
let (ws_stream, _) = tokio_tungstenite::connect_async(ws_url) let (ws_stream, _) = tokio_tungstenite::connect_async(ws_url)
.await.expect("Failed to connect to server relay!"); .await
.expect("Failed to connect to server relay!");
ws_stream ws_stream
}; };
let (mut tcp_read, mut tcp_write) = socket.into_split(); let (mut tcp_read, mut tcp_write) = socket.into_split();
let (mut ws_write, mut ws_read) = let (mut ws_write, mut ws_read) = ws_stream.split();
ws_stream.split();
// TCP read -> WS write // TCP read -> WS write
let future = async move { let future = async move {
@ -136,12 +137,18 @@ async fn relay_connection(ws_url: String, socket: TcpStream, conf: Arc<ClientCon
while let Some(m) = ws_read.next().await { while let Some(m) = ws_read.next().await {
match m { match m {
Err(e) => { Err(e) => {
log::error!("Failed to read from WebSocket. Breaking read loop... {:?}", e); log::error!(
"Failed to read from WebSocket. Breaking read loop... {:?}",
e
);
break; break;
} }
Ok(Message::Binary(b)) => { Ok(Message::Binary(b)) => {
if let Err(e) = tcp_write.write_all(&b).await { if let Err(e) = tcp_write.write_all(&b).await {
log::error!("Failed to forward message to websocket. Closing reading end... {:?}", e); log::error!(
"Failed to forward message to websocket. Closing reading end... {:?}",
e
);
break; break;
}; };
} }
@ -149,7 +156,7 @@ async fn relay_connection(ws_url: String, socket: TcpStream, conf: Arc<ClientCon
log::info!("Server asked to close this WebSocket connection"); log::info!("Server asked to close this WebSocket connection");
break; break;
} }
Ok(m) => log::info!("{:?}", m) Ok(m) => log::info!("{:?}", m),
} }
} }
} }

View File

@ -1,3 +1,3 @@
pub mod server_config;
pub mod relay_ws; pub mod relay_ws;
pub mod tls_cert_client_verifier; pub mod server_config;
pub mod tls_cert_client_verifier;

View File

@ -2,11 +2,11 @@ use std::fs::File;
use std::io::BufReader; use std::io::BufReader;
use std::sync::Arc; use std::sync::Arc;
use actix_web::{App, HttpRequest, HttpResponse, HttpServer, middleware, Responder, web};
use actix_web::web::Data; use actix_web::web::Data;
use actix_web::{middleware, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
use clap::Parser; use clap::Parser;
use rustls::{Certificate, PrivateKey}; use rustls::{Certificate, PrivateKey};
use rustls_pemfile::{certs, Item, read_one}; use rustls_pemfile::{certs, read_one, Item};
use base::RelayedPort; use base::RelayedPort;
use tcp_relay_server::relay_ws::relay_ws; use tcp_relay_server::relay_ws::relay_ws;
@ -19,7 +19,9 @@ pub async fn hello_route() -> &'static str {
pub async fn config_route(req: HttpRequest, data: Data<Arc<ServerConfig>>) -> impl Responder { pub async fn config_route(req: HttpRequest, data: Data<Arc<ServerConfig>>) -> impl Responder {
if data.has_token_auth() { if data.has_token_auth() {
let token = req.headers().get("Authorization") let token = req
.headers()
.get("Authorization")
.map(|t| t.to_str().unwrap_or_default()) .map(|t| t.to_str().unwrap_or_default())
.unwrap_or_default() .unwrap_or_default()
.strip_prefix("Bearer ") .strip_prefix("Bearer ")
@ -31,10 +33,14 @@ pub async fn config_route(req: HttpRequest, data: Data<Arc<ServerConfig>>) -> im
} }
HttpResponse::Ok().json( HttpResponse::Ok().json(
data.ports.iter() data.ports
.iter()
.enumerate() .enumerate()
.map(|(id, port)| RelayedPort { id, port: port + data.increment_ports }) .map(|(id, port)| RelayedPort {
.collect::<Vec<_>>() id,
port: port + data.increment_ports,
})
.collect::<Vec<_>>(),
) )
} }
@ -73,13 +79,13 @@ async fn main() -> std::io::Result<()> {
// Load TLS configuration, if any // Load TLS configuration, if any
let tls_config = if let (Some(cert), Some(key)) = (&args.tls_cert, &args.tls_key) { let tls_config = if let (Some(cert), Some(key)) = (&args.tls_cert, &args.tls_key) {
// Load TLS certificate & private key // Load TLS certificate & private key
let cert_file = &mut BufReader::new(File::open(cert).unwrap()); let cert_file = &mut BufReader::new(File::open(cert).unwrap());
let key_file = &mut BufReader::new(File::open(key).unwrap()); let key_file = &mut BufReader::new(File::open(key).unwrap());
// Get certificates chain // Get certificates chain
let cert_chain = certs(cert_file).unwrap() let cert_chain = certs(cert_file)
.unwrap()
.into_iter() .into_iter()
.map(Certificate) .map(Certificate)
.collect(); .collect();
@ -98,20 +104,22 @@ async fn main() -> std::io::Result<()> {
} }
}; };
let config = rustls::ServerConfig::builder() let config = rustls::ServerConfig::builder().with_safe_defaults();
.with_safe_defaults();
let config = match args.has_tls_client_auth() { let config = match args.has_tls_client_auth() {
true => config.with_client_cert_verifier( true => config
Arc::new(CustomCertClientVerifier::new(args.clone()))), .with_client_cert_verifier(Arc::new(CustomCertClientVerifier::new(args.clone()))),
false => config.with_no_client_auth() false => config.with_no_client_auth(),
}; };
let config = config.with_single_cert(cert_chain, PrivateKey(key)) let config = config
.with_single_cert(cert_chain, PrivateKey(key))
.expect("Failed to load TLS certificate!"); .expect("Failed to load TLS certificate!");
Some(config) Some(config)
} else { None }; } else {
None
};
log::info!("Starting relay on http://{}", args.listen_address); log::info!("Starting relay on http://{}", args.listen_address);
@ -129,6 +137,7 @@ async fn main() -> std::io::Result<()> {
server.bind_rustls(&args.listen_address, tls_conf)? server.bind_rustls(&args.listen_address, tls_conf)?
} else { } else {
server.bind(&args.listen_address)? server.bind(&args.listen_address)?
}.run() }
.await .run()
} .await
}

View File

@ -2,7 +2,7 @@ use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use actix::{Actor, ActorContext, AsyncContext, Handler, Message, StreamHandler}; use actix::{Actor, ActorContext, AsyncContext, Handler, Message, StreamHandler};
use actix_web::{Error, HttpRequest, HttpResponse, web}; use actix_web::{web, Error, HttpRequest, HttpResponse};
use actix_web_actors::ws; use actix_web_actors::ws;
use actix_web_actors::ws::{CloseCode, CloseReason}; use actix_web_actors::ws::{CloseCode, CloseReason};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
@ -32,7 +32,6 @@ struct RelayWS {
// Client must respond to ping at a specific interval, otherwise we drop connection // Client must respond to ping at a specific interval, otherwise we drop connection
hb: Instant, hb: Instant,
// TODO : handle socket close // TODO : handle socket close
} }
@ -109,7 +108,9 @@ impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for RelayWS {
Ok(ws::Message::Text(text)) => ctx.text(text), Ok(ws::Message::Text(text)) => ctx.text(text),
Ok(ws::Message::Close(_reason)) => ctx.stop(), Ok(ws::Message::Close(_reason)) => ctx.stop(),
Ok(ws::Message::Binary(data)) => { Ok(ws::Message::Binary(data)) => {
if let Err(e) = futures::executor::block_on(self.tcp_write.write_all(&data.to_vec())) { if let Err(e) =
futures::executor::block_on(self.tcp_write.write_all(&data.to_vec()))
{
log::error!("Failed to forward some data, closing connection! {:?}", e); log::error!("Failed to forward some data, closing connection! {:?}", e);
ctx.stop(); ctx.stop();
} }
@ -150,17 +151,30 @@ pub struct WebSocketQuery {
token: Option<String>, token: Option<String>,
} }
pub async fn relay_ws(req: HttpRequest, stream: web::Payload, pub async fn relay_ws(
query: web::Query<WebSocketQuery>, req: HttpRequest,
conf: web::Data<Arc<ServerConfig>>) -> Result<HttpResponse, Error> { stream: web::Payload,
if conf.has_token_auth() && query: web::Query<WebSocketQuery>,
!conf.tokens.iter().any(|t| t == query.token.as_deref().unwrap_or_default()) { conf: web::Data<Arc<ServerConfig>>,
log::error!("Rejected WS request from {:?} due to invalid token!", req.peer_addr()); ) -> Result<HttpResponse, Error> {
if conf.has_token_auth()
&& !conf
.tokens
.iter()
.any(|t| t == query.token.as_deref().unwrap_or_default())
{
log::error!(
"Rejected WS request from {:?} due to invalid token!",
req.peer_addr()
);
return Ok(HttpResponse::Unauthorized().json("Invalid / missing token!")); return Ok(HttpResponse::Unauthorized().json("Invalid / missing token!"));
} }
if conf.ports.len() <= query.id { if conf.ports.len() <= query.id {
log::error!("Rejected WS request from {:?} due to invalid port number!", req.peer_addr()); log::error!(
"Rejected WS request from {:?} due to invalid port number!",
req.peer_addr()
);
return Ok(HttpResponse::BadRequest().json("Invalid port number!")); return Ok(HttpResponse::BadRequest().json("Invalid port number!"));
} }
@ -169,14 +183,24 @@ pub async fn relay_ws(req: HttpRequest, stream: web::Payload,
let (tcp_read, tcp_write) = match TcpStream::connect(&upstream_addr).await { let (tcp_read, tcp_write) = match TcpStream::connect(&upstream_addr).await {
Ok(s) => s.into_split(), Ok(s) => s.into_split(),
Err(e) => { Err(e) => {
log::error!("Failed to establish connection with upstream server! {:?}", e); log::error!(
return Ok(HttpResponse::InternalServerError() "Failed to establish connection with upstream server! {:?}",
.json("Failed to establish connection!")); e
);
return Ok(HttpResponse::InternalServerError().json("Failed to establish connection!"));
} }
}; };
let relay = RelayWS { tcp_read: Some(tcp_read), tcp_write, hb: Instant::now() }; let relay = RelayWS {
tcp_read: Some(tcp_read),
tcp_write,
hb: Instant::now(),
};
let resp = ws::start(relay, &req, stream); let resp = ws::start(relay, &req, stream);
log::info!("Opening new WS connection for {:?} to {}", req.peer_addr(), upstream_addr); log::info!(
"Opening new WS connection for {:?} to {}",
req.peer_addr(),
upstream_addr
);
resp resp
} }

View File

@ -2,8 +2,12 @@ use clap::Parser;
/// TCP relay server /// TCP relay server
#[derive(Parser, Debug, Clone)] #[derive(Parser, Debug, Clone)]
#[clap(author, version, about, #[clap(
long_about = "TCP-over-HTTP server. This program might be configured behind a reverse-proxy.")] author,
version,
about,
long_about = "TCP-over-HTTP server. This program might be configured behind a reverse-proxy."
)]
pub struct ServerConfig { pub struct ServerConfig {
/// Access tokens /// Access tokens
#[clap(short, long)] #[clap(short, long)]
@ -62,4 +66,4 @@ impl ServerConfig {
pub fn has_auth(&self) -> bool { pub fn has_auth(&self) -> bool {
self.has_token_auth() || self.has_tls_client_auth() self.has_token_auth() || self.has_tls_client_auth()
} }
} }

View File

@ -3,8 +3,8 @@ use std::io::BufReader;
use std::sync::Arc; use std::sync::Arc;
use std::time::SystemTime; use std::time::SystemTime;
use rustls::{Certificate, DistinguishedNames, Error, RootCertStore};
use rustls::server::{AllowAnyAuthenticatedClient, ClientCertVerified, ClientCertVerifier}; use rustls::server::{AllowAnyAuthenticatedClient, ClientCertVerified, ClientCertVerifier};
use rustls::{Certificate, DistinguishedNames, Error, RootCertStore};
use rustls_pemfile::certs; use rustls_pemfile::certs;
use crate::server_config::ServerConfig; use crate::server_config::ServerConfig;
@ -15,12 +15,17 @@ pub struct CustomCertClientVerifier {
impl CustomCertClientVerifier { impl CustomCertClientVerifier {
pub fn new(conf: Arc<ServerConfig>) -> Self { pub fn new(conf: Arc<ServerConfig>) -> Self {
let cert_path = conf.tls_client_auth_root_cert.as_deref() let cert_path = conf
.tls_client_auth_root_cert
.as_deref()
.expect("No root certificates for client authentication provided!"); .expect("No root certificates for client authentication provided!");
let cert_file = &mut BufReader::new(File::open(cert_path) let cert_file = &mut BufReader::new(
.expect("Failed to read root certificates for client authentication!")); File::open(cert_path)
.expect("Failed to read root certificates for client authentication!"),
);
let root_certs = certs(cert_file).unwrap() let root_certs = certs(cert_file)
.unwrap()
.into_iter() .into_iter()
.map(Certificate) .map(Certificate)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -32,7 +37,9 @@ impl CustomCertClientVerifier {
let mut store = RootCertStore::empty(); let mut store = RootCertStore::empty();
for cert in root_certs { for cert in root_certs {
store.add(&cert).expect("Failed to add certificate to root store"); store
.add(&cert)
.expect("Failed to add certificate to root store");
} }
Self { Self {
@ -54,8 +61,13 @@ impl ClientCertVerifier for CustomCertClientVerifier {
Some(vec![]) Some(vec![])
} }
fn verify_client_cert(&self, end_entity: &Certificate, intermediates: &[Certificate], now: SystemTime) -> Result<ClientCertVerified, Error> { fn verify_client_cert(
self.upstream_cert_verifier.verify_client_cert(end_entity, intermediates, now) &self,
end_entity: &Certificate,
intermediates: &[Certificate],
now: SystemTime,
) -> Result<ClientCertVerified, Error> {
self.upstream_cert_verifier
.verify_client_cert(end_entity, intermediates, now)
} }
} }