cargo fmt
This commit is contained in:
@ -1,3 +1,3 @@
|
||||
pub mod server_config;
|
||||
pub mod relay_ws;
|
||||
pub mod tls_cert_client_verifier;
|
||||
pub mod server_config;
|
||||
pub mod tls_cert_client_verifier;
|
||||
|
@ -2,11 +2,11 @@ use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
use std::sync::Arc;
|
||||
|
||||
use actix_web::{App, HttpRequest, HttpResponse, HttpServer, middleware, Responder, web};
|
||||
use actix_web::web::Data;
|
||||
use actix_web::{middleware, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
|
||||
use clap::Parser;
|
||||
use rustls::{Certificate, PrivateKey};
|
||||
use rustls_pemfile::{certs, Item, read_one};
|
||||
use rustls_pemfile::{certs, read_one, Item};
|
||||
|
||||
use base::RelayedPort;
|
||||
use tcp_relay_server::relay_ws::relay_ws;
|
||||
@ -19,7 +19,9 @@ pub async fn hello_route() -> &'static str {
|
||||
|
||||
pub async fn config_route(req: HttpRequest, data: Data<Arc<ServerConfig>>) -> impl Responder {
|
||||
if data.has_token_auth() {
|
||||
let token = req.headers().get("Authorization")
|
||||
let token = req
|
||||
.headers()
|
||||
.get("Authorization")
|
||||
.map(|t| t.to_str().unwrap_or_default())
|
||||
.unwrap_or_default()
|
||||
.strip_prefix("Bearer ")
|
||||
@ -31,10 +33,14 @@ pub async fn config_route(req: HttpRequest, data: Data<Arc<ServerConfig>>) -> im
|
||||
}
|
||||
|
||||
HttpResponse::Ok().json(
|
||||
data.ports.iter()
|
||||
data.ports
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(id, port)| RelayedPort { id, port: port + data.increment_ports })
|
||||
.collect::<Vec<_>>()
|
||||
.map(|(id, port)| RelayedPort {
|
||||
id,
|
||||
port: port + data.increment_ports,
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
}
|
||||
|
||||
@ -73,13 +79,13 @@ async fn main() -> std::io::Result<()> {
|
||||
|
||||
// Load TLS configuration, if any
|
||||
let tls_config = if let (Some(cert), Some(key)) = (&args.tls_cert, &args.tls_key) {
|
||||
|
||||
// Load TLS certificate & private key
|
||||
let cert_file = &mut BufReader::new(File::open(cert).unwrap());
|
||||
let key_file = &mut BufReader::new(File::open(key).unwrap());
|
||||
|
||||
// Get certificates chain
|
||||
let cert_chain = certs(cert_file).unwrap()
|
||||
let cert_chain = certs(cert_file)
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(Certificate)
|
||||
.collect();
|
||||
@ -98,20 +104,22 @@ async fn main() -> std::io::Result<()> {
|
||||
}
|
||||
};
|
||||
|
||||
let config = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults();
|
||||
let config = rustls::ServerConfig::builder().with_safe_defaults();
|
||||
|
||||
let config = match args.has_tls_client_auth() {
|
||||
true => config.with_client_cert_verifier(
|
||||
Arc::new(CustomCertClientVerifier::new(args.clone()))),
|
||||
false => config.with_no_client_auth()
|
||||
true => config
|
||||
.with_client_cert_verifier(Arc::new(CustomCertClientVerifier::new(args.clone()))),
|
||||
false => config.with_no_client_auth(),
|
||||
};
|
||||
|
||||
let config = config.with_single_cert(cert_chain, PrivateKey(key))
|
||||
let config = config
|
||||
.with_single_cert(cert_chain, PrivateKey(key))
|
||||
.expect("Failed to load TLS certificate!");
|
||||
|
||||
Some(config)
|
||||
} else { None };
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
log::info!("Starting relay on http://{}", args.listen_address);
|
||||
|
||||
@ -129,6 +137,7 @@ async fn main() -> std::io::Result<()> {
|
||||
server.bind_rustls(&args.listen_address, tls_conf)?
|
||||
} else {
|
||||
server.bind(&args.listen_address)?
|
||||
}.run()
|
||||
.await
|
||||
}
|
||||
}
|
||||
.run()
|
||||
.await
|
||||
}
|
||||
|
@ -2,7 +2,7 @@ use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use actix::{Actor, ActorContext, AsyncContext, Handler, Message, StreamHandler};
|
||||
use actix_web::{Error, HttpRequest, HttpResponse, web};
|
||||
use actix_web::{web, Error, HttpRequest, HttpResponse};
|
||||
use actix_web_actors::ws;
|
||||
use actix_web_actors::ws::{CloseCode, CloseReason};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
@ -32,7 +32,6 @@ struct RelayWS {
|
||||
|
||||
// Client must respond to ping at a specific interval, otherwise we drop connection
|
||||
hb: Instant,
|
||||
|
||||
// TODO : handle socket close
|
||||
}
|
||||
|
||||
@ -109,7 +108,9 @@ impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for RelayWS {
|
||||
Ok(ws::Message::Text(text)) => ctx.text(text),
|
||||
Ok(ws::Message::Close(_reason)) => ctx.stop(),
|
||||
Ok(ws::Message::Binary(data)) => {
|
||||
if let Err(e) = futures::executor::block_on(self.tcp_write.write_all(&data.to_vec())) {
|
||||
if let Err(e) =
|
||||
futures::executor::block_on(self.tcp_write.write_all(&data.to_vec()))
|
||||
{
|
||||
log::error!("Failed to forward some data, closing connection! {:?}", e);
|
||||
ctx.stop();
|
||||
}
|
||||
@ -150,17 +151,30 @@ pub struct WebSocketQuery {
|
||||
token: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn relay_ws(req: HttpRequest, stream: web::Payload,
|
||||
query: web::Query<WebSocketQuery>,
|
||||
conf: web::Data<Arc<ServerConfig>>) -> Result<HttpResponse, Error> {
|
||||
if conf.has_token_auth() &&
|
||||
!conf.tokens.iter().any(|t| t == query.token.as_deref().unwrap_or_default()) {
|
||||
log::error!("Rejected WS request from {:?} due to invalid token!", req.peer_addr());
|
||||
pub async fn relay_ws(
|
||||
req: HttpRequest,
|
||||
stream: web::Payload,
|
||||
query: web::Query<WebSocketQuery>,
|
||||
conf: web::Data<Arc<ServerConfig>>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
if conf.has_token_auth()
|
||||
&& !conf
|
||||
.tokens
|
||||
.iter()
|
||||
.any(|t| t == query.token.as_deref().unwrap_or_default())
|
||||
{
|
||||
log::error!(
|
||||
"Rejected WS request from {:?} due to invalid token!",
|
||||
req.peer_addr()
|
||||
);
|
||||
return Ok(HttpResponse::Unauthorized().json("Invalid / missing token!"));
|
||||
}
|
||||
|
||||
if conf.ports.len() <= query.id {
|
||||
log::error!("Rejected WS request from {:?} due to invalid port number!", req.peer_addr());
|
||||
log::error!(
|
||||
"Rejected WS request from {:?} due to invalid port number!",
|
||||
req.peer_addr()
|
||||
);
|
||||
return Ok(HttpResponse::BadRequest().json("Invalid port number!"));
|
||||
}
|
||||
|
||||
@ -169,14 +183,24 @@ pub async fn relay_ws(req: HttpRequest, stream: web::Payload,
|
||||
let (tcp_read, tcp_write) = match TcpStream::connect(&upstream_addr).await {
|
||||
Ok(s) => s.into_split(),
|
||||
Err(e) => {
|
||||
log::error!("Failed to establish connection with upstream server! {:?}", e);
|
||||
return Ok(HttpResponse::InternalServerError()
|
||||
.json("Failed to establish connection!"));
|
||||
log::error!(
|
||||
"Failed to establish connection with upstream server! {:?}",
|
||||
e
|
||||
);
|
||||
return Ok(HttpResponse::InternalServerError().json("Failed to establish connection!"));
|
||||
}
|
||||
};
|
||||
|
||||
let relay = RelayWS { tcp_read: Some(tcp_read), tcp_write, hb: Instant::now() };
|
||||
let relay = RelayWS {
|
||||
tcp_read: Some(tcp_read),
|
||||
tcp_write,
|
||||
hb: Instant::now(),
|
||||
};
|
||||
let resp = ws::start(relay, &req, stream);
|
||||
log::info!("Opening new WS connection for {:?} to {}", req.peer_addr(), upstream_addr);
|
||||
log::info!(
|
||||
"Opening new WS connection for {:?} to {}",
|
||||
req.peer_addr(),
|
||||
upstream_addr
|
||||
);
|
||||
resp
|
||||
}
|
||||
}
|
||||
|
@ -2,8 +2,12 @@ use clap::Parser;
|
||||
|
||||
/// TCP relay server
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
#[clap(author, version, about,
|
||||
long_about = "TCP-over-HTTP server. This program might be configured behind a reverse-proxy.")]
|
||||
#[clap(
|
||||
author,
|
||||
version,
|
||||
about,
|
||||
long_about = "TCP-over-HTTP server. This program might be configured behind a reverse-proxy."
|
||||
)]
|
||||
pub struct ServerConfig {
|
||||
/// Access tokens
|
||||
#[clap(short, long)]
|
||||
@ -62,4 +66,4 @@ impl ServerConfig {
|
||||
pub fn has_auth(&self) -> bool {
|
||||
self.has_token_auth() || self.has_tls_client_auth()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3,8 +3,8 @@ use std::io::BufReader;
|
||||
use std::sync::Arc;
|
||||
use std::time::SystemTime;
|
||||
|
||||
use rustls::{Certificate, DistinguishedNames, Error, RootCertStore};
|
||||
use rustls::server::{AllowAnyAuthenticatedClient, ClientCertVerified, ClientCertVerifier};
|
||||
use rustls::{Certificate, DistinguishedNames, Error, RootCertStore};
|
||||
use rustls_pemfile::certs;
|
||||
|
||||
use crate::server_config::ServerConfig;
|
||||
@ -15,12 +15,17 @@ pub struct CustomCertClientVerifier {
|
||||
|
||||
impl CustomCertClientVerifier {
|
||||
pub fn new(conf: Arc<ServerConfig>) -> Self {
|
||||
let cert_path = conf.tls_client_auth_root_cert.as_deref()
|
||||
let cert_path = conf
|
||||
.tls_client_auth_root_cert
|
||||
.as_deref()
|
||||
.expect("No root certificates for client authentication provided!");
|
||||
let cert_file = &mut BufReader::new(File::open(cert_path)
|
||||
.expect("Failed to read root certificates for client authentication!"));
|
||||
let cert_file = &mut BufReader::new(
|
||||
File::open(cert_path)
|
||||
.expect("Failed to read root certificates for client authentication!"),
|
||||
);
|
||||
|
||||
let root_certs = certs(cert_file).unwrap()
|
||||
let root_certs = certs(cert_file)
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(Certificate)
|
||||
.collect::<Vec<_>>();
|
||||
@ -32,7 +37,9 @@ impl CustomCertClientVerifier {
|
||||
|
||||
let mut store = RootCertStore::empty();
|
||||
for cert in root_certs {
|
||||
store.add(&cert).expect("Failed to add certificate to root store");
|
||||
store
|
||||
.add(&cert)
|
||||
.expect("Failed to add certificate to root store");
|
||||
}
|
||||
|
||||
Self {
|
||||
@ -54,8 +61,13 @@ impl ClientCertVerifier for CustomCertClientVerifier {
|
||||
Some(vec![])
|
||||
}
|
||||
|
||||
fn verify_client_cert(&self, end_entity: &Certificate, intermediates: &[Certificate], now: SystemTime) -> Result<ClientCertVerified, Error> {
|
||||
self.upstream_cert_verifier.verify_client_cert(end_entity, intermediates, now)
|
||||
fn verify_client_cert(
|
||||
&self,
|
||||
end_entity: &Certificate,
|
||||
intermediates: &[Certificate],
|
||||
now: SystemTime,
|
||||
) -> Result<ClientCertVerified, Error> {
|
||||
self.upstream_cert_verifier
|
||||
.verify_client_cert(end_entity, intermediates, now)
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user