Add embedded TLS server
This commit is contained in:
@ -9,8 +9,11 @@ clap = { version = "3.2.18", features = ["derive", "env"] }
|
||||
log = "0.4.17"
|
||||
env_logger = "0.9.0"
|
||||
actix = "0.13.0"
|
||||
actix-web = "4"
|
||||
actix-web = { version = "4", features = ["rustls"] }
|
||||
actix-web-actors = "4.1.0"
|
||||
actix-tls = "3.0.3"
|
||||
serde = { version = "1.0.144", features = ["derive"] }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
futures = "0.3.24"
|
||||
futures = "0.3.24"
|
||||
rustls = "0.20.6"
|
||||
rustls-pemfile = "1.0.1"
|
@ -2,8 +2,9 @@ use clap::Parser;
|
||||
|
||||
/// TCP relay server
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
#[clap(author, version, about, long_about = None)]
|
||||
pub struct Args {
|
||||
#[clap(author, version, about,
|
||||
long_about = "TCP-over-HTTP server. This program might be configured behind a reverse-proxy.")]
|
||||
pub struct ProgramArgs {
|
||||
/// Access tokens
|
||||
#[clap(short, long)]
|
||||
pub tokens: Vec<String>,
|
||||
@ -28,4 +29,12 @@ pub struct Args {
|
||||
/// on the same machine
|
||||
#[clap(short, long, default_value_t = 0)]
|
||||
pub increment_ports: u16,
|
||||
|
||||
/// TLS certificate. Specify also private key to use HTTPS/TLS instead of HTTP
|
||||
#[clap(long)]
|
||||
pub tls_cert: Option<String>,
|
||||
|
||||
/// TLS private key. Specify also certificate to use HTTPS/TLS instead of HTTP
|
||||
#[clap(long)]
|
||||
pub tls_key: Option<String>,
|
||||
}
|
@ -1,18 +1,22 @@
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
use std::sync::Arc;
|
||||
|
||||
use actix_web::{App, HttpRequest, HttpResponse, HttpServer, middleware, Responder, web};
|
||||
use actix_web::web::Data;
|
||||
use clap::Parser;
|
||||
use rustls::{Certificate, PrivateKey, ServerConfig};
|
||||
use rustls_pemfile::{certs, Item, read_one};
|
||||
|
||||
use base::RelayedPort;
|
||||
use tcp_relay_server::args::Args;
|
||||
use tcp_relay_server::args::ProgramArgs;
|
||||
use tcp_relay_server::relay_ws::relay_ws;
|
||||
|
||||
pub async fn hello_route() -> &'static str {
|
||||
"Hello world!"
|
||||
}
|
||||
|
||||
pub async fn config_route(req: HttpRequest, data: Data<Arc<Args>>) -> impl Responder {
|
||||
pub async fn config_route(req: HttpRequest, data: Data<Arc<ProgramArgs>>) -> impl Responder {
|
||||
let token = req.headers().get("Authorization")
|
||||
.map(|t| t.to_str().unwrap_or_default())
|
||||
.unwrap_or_default()
|
||||
@ -35,13 +39,49 @@ pub async fn config_route(req: HttpRequest, data: Data<Arc<Args>>) -> impl Respo
|
||||
async fn main() -> std::io::Result<()> {
|
||||
env_logger::init_from_env(env_logger::Env::new().default_filter_or("info"));
|
||||
|
||||
let mut args: Args = Args::parse();
|
||||
let mut args: ProgramArgs = ProgramArgs::parse();
|
||||
|
||||
if args.ports.is_empty() {
|
||||
log::error!("No port to forward!");
|
||||
std::process::exit(2);
|
||||
}
|
||||
|
||||
// Load TLS configuration, if any
|
||||
let tls_config = if let (Some(cert), Some(key)) = (&args.tls_cert, &args.tls_key) {
|
||||
|
||||
// Load TLS certificate & private key
|
||||
let cert_file = &mut BufReader::new(File::open(cert).unwrap());
|
||||
let key_file = &mut BufReader::new(File::open(key).unwrap());
|
||||
|
||||
// Get certificates chain
|
||||
let cert_chain = certs(cert_file).unwrap()
|
||||
.into_iter()
|
||||
.map(Certificate)
|
||||
.collect();
|
||||
|
||||
// Get private key
|
||||
let key = match read_one(key_file).expect("Failed to read private key!") {
|
||||
None => {
|
||||
log::error!("Failed to extract private key!");
|
||||
panic!();
|
||||
}
|
||||
Some(Item::PKCS8Key(key)) => key,
|
||||
Some(Item::RSAKey(key)) => key,
|
||||
_ => {
|
||||
log::error!("Unsupported private key type!");
|
||||
panic!();
|
||||
}
|
||||
};
|
||||
|
||||
let config = ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(cert_chain, PrivateKey(key))
|
||||
.expect("Failed to load TLS certificate!");
|
||||
|
||||
Some(config)
|
||||
} else { None };
|
||||
|
||||
// Read tokens from file, if any
|
||||
if let Some(file) = &args.tokens_file {
|
||||
std::fs::read_to_string(file)
|
||||
@ -60,15 +100,19 @@ async fn main() -> std::io::Result<()> {
|
||||
|
||||
let args = Arc::new(args);
|
||||
let args_clone = args.clone();
|
||||
HttpServer::new(move || {
|
||||
let server = HttpServer::new(move || {
|
||||
App::new()
|
||||
.wrap(middleware::Logger::default())
|
||||
.app_data(Data::new(args_clone.clone()))
|
||||
.route("/", web::get().to(hello_route))
|
||||
.route("/config", web::get().to(config_route))
|
||||
.route("/ws", web::get().to(relay_ws))
|
||||
})
|
||||
.bind(&args.listen_address)?
|
||||
.run()
|
||||
});
|
||||
|
||||
if let Some(tls_conf) = tls_config {
|
||||
server.bind_rustls(&args.listen_address, tls_conf)?
|
||||
} else {
|
||||
server.bind(&args.listen_address)?
|
||||
}.run()
|
||||
.await
|
||||
}
|
@ -9,7 +9,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use crate::args::Args;
|
||||
use crate::args::ProgramArgs;
|
||||
|
||||
/// How often heartbeat pings are sent
|
||||
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
|
||||
@ -153,7 +153,7 @@ pub struct WebSocketQuery {
|
||||
|
||||
pub async fn relay_ws(req: HttpRequest, stream: web::Payload,
|
||||
query: web::Query<WebSocketQuery>,
|
||||
conf: web::Data<Arc<Args>>) -> Result<HttpResponse, Error> {
|
||||
conf: web::Data<Arc<ProgramArgs>>) -> Result<HttpResponse, Error> {
|
||||
if !conf.tokens.contains(&query.token) {
|
||||
log::error!("Rejected WS request from {:?} due to invalid token!", req.peer_addr());
|
||||
return Ok(HttpResponse::Unauthorized().json("Invalid / missing token!"));
|
||||
|
Reference in New Issue
Block a user