diff --git a/src/main.rs b/src/main.rs index 9e9b784..bbd869f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,17 +1,9 @@ use clap::Parser; use rand::distr::{Alphanumeric, SampleString}; -use rustls_pki_types::ServerName; use std::error::Error; -use std::fs::OpenOptions; -use std::io::Write; -use std::path::Path; -use std::sync::Arc; -use std::time::{SystemTime, UNIX_EPOCH}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpListener; use tokio::net::TcpStream; -use tokio_rustls::TlsConnector; -use tokio_rustls::rustls::{ClientConfig, RootCertStore}; /// Simple program that redirect requests to a given server in request #[derive(Parser, Debug)] @@ -27,7 +19,7 @@ struct Args { /// Name of optional header that contains path to add to the request #[arg(short, long, default_value = "x-path-prefix")] - path_prefix_heder: String, + path_prefix_header: String, } lazy_static::lazy_static! { @@ -56,61 +48,62 @@ async fn main() -> Result<(), Box> { tokio::spawn(async move { let conn_id = rand_str(5); - log::debug!( - "Handle new connection from {}", + log::info!( + "[{conn_id}] Handle new connection from {}", client_socket.peer_addr().unwrap() ); - - let stream = TcpStream::connect(TODO) - .await - .expect("Failed to connect to upstream"); - let mut upstream = connector - .connect(dnsname, stream) - .await - .expect("Failed to establish TLS connection"); - let (mut client_read, mut client_write) = client_socket.split(); - let mut buf_client = [0u8; 1024]; let mut buf_server = [0u8; 1024]; - let mut modified_headers = false; + // Perform first read operation manually to manipulate path and determine destination server + let count = match client_read.read(&mut buf_client).await { + Ok(count) => count, + Err(e) => { + log::error!( + "[{conn_id}] Failed to read initial data from client, closing connection! {e}" + ); + return; + } + }; + + if count < 10 { + log::error!("[{conn_id}] Initial read too small (count={count}), cannot continue!"); + return; + } + + let headers_processed = process_headers(&buf_client[..count]); + + // Transfer modified headers to upstream + let mut upstream = match TcpStream::connect(headers_processed.remote_host).await { + Ok(upstream) => upstream, + Err(e) => { + log::error!("Could not connect to upstream! {e}"); + let _ = client_write.write_all(ERR_NOT_PROXIFIALBE).await; + return; + } + }; loop { tokio::select! { count = client_read.read(&mut buf_client) => { let count = match count{ Ok(count) => count, Err(e) => { log::error!("[{conn_id}] Failed to read data from client, closing connection! {e}"); - return; + break; }}; - log::info!("[{conn_id}] Got a new client read {count} - {base_file_name}"); + log::debug!("[{conn_id}] Got a new client read {count}"); if count == 0 { log::warn!("[{conn_id}] infinite loop (client), closing connection"); drop(upstream); - return; + break; } - // We need to modify some headers (if not done already) to adapt the request to the server - let buff = if !modified_headers { - - // Check for URL prefix - if let Some(prefix) = &args.prefix - && !String::from_utf8_lossy(&buf_client[..count]).split_once('\n').map(|l|l.0).unwrap_or("").contains(&format!(" {prefix}")) { - client_write.write_all(ERR_NOT_PROXIFIABLE).await.expect("Failed to respond to client"); - client_write.flush().await.expect("Failed to flush response to client!"); - return; - } - - modified_headers = true; - manipulate_headers(&buf_client[..count], &args.upstream_dns) + if let Err(e)=upstream.write_all(&buf_client[..count]).await { + log::error!("[{conn_id}] Failed to write to upstream! {e}"); + break; } - else { - buf_client[..count].to_vec() - }; - - upstream.write_all(&buff).await.unwrap_or_else(|_| panic!("[{conn_id}] Failed to write to upstream")); } count = upstream.read(&mut buf_server) => { @@ -118,26 +111,36 @@ async fn main() -> Result<(), Box> { Ok(count) => count, Err(e) => { log::error!("[{conn_id}] Failed to read from upstream! {e}"); - return; + break; } }; if count == 0 { log::warn!("[{conn_id}] infinite loop (upstream), closing connection"); drop(upstream); - return; + break; } - log::info!("[{conn_id}] Got a new upstream read {count} - {base_file_name}"); - client_write.write_all(&buf_server[..count]).await.expect("Failed to write to client"); + log::debug!("[{conn_id}] Got a new upstream read {count}"); + if let Err(e) = client_write.write_all(&buf_server[..count]).await { + log::error!("Failed to write to upstream! {e}"); + break; + }; } } } + + log::info!("[{conn_id}] Connection finished."); }); } } -fn manipulate_headers(buff: &[u8], host: &str) -> Vec { +struct ProcessHeadersResult { + buff: Vec, + remote_host: String, +} + +fn process_headers(buff: &[u8]) -> ProcessHeadersResult { let mut out = Vec::with_capacity(buff.len()); let mut i = 0; @@ -158,7 +161,7 @@ fn manipulate_headers(buff: &[u8], host: &str) -> Vec { } } - out + todo!() } #[cfg(test)]