use clap::Parser;
use rand::distr::{Alphanumeric, SampleString};
use std::error::Error;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::net::TcpStream;

/// Simple program that redirects requests to the server specified in the request
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
    /// The address the server will listen on
    #[arg(short, long, env, default_value = "0.0.0.0:8000")]
    listen_address: String,

    /// The name of the header that contains the target host and port
    #[arg(short, long, env, default_value = "x-target-host")]
    target_host_port_header: String,

    /// Name of an optional header that contains a path prefix to add to the request.
    ///
    /// If this value is defined, all client packets are inspected looking for a path to
    /// manipulate
    #[arg(short, long, env)]
    path_prefix_header: Option<String>,
}

lazy_static::lazy_static! {
    static ref ARGS: Args = Args::parse();
}

const ERR_NOT_PROXIFIABLE: &[u8; 44] = b"HTTP/1.1 403 Forbidden\r\n\r\nNot proxifiable.\r\n";

pub fn rand_str(len: usize) -> String {
    Alphanumeric.sample_string(&mut rand::rng(), len)
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    env_logger::init_from_env(env_logger::Env::new().default_filter_or("info"));

    log::info!("Will start to listen on {}", ARGS.listen_address);
    let listener = TcpListener::bind(&ARGS.listen_address).await?;

    loop {
        // Asynchronously wait for an inbound socket.
        let (mut client_socket, _) = listener.accept().await?;

        tokio::spawn(async move {
            let conn_id = rand_str(5);
            log::info!(
                "[{conn_id}] Handle new connection from {}",
                client_socket.peer_addr().unwrap()
            );
            let (mut client_read, mut client_write) = client_socket.split();
            let mut buf_client = [0u8; 10000];
            let mut buf_server = [0u8; 1024];

            // Perform the first read manually to manipulate the path and determine the target server
            let mut total = 0;
            let headers_processed = loop {
                let count = match client_read.read(&mut buf_client[total..]).await {
                    Ok(count) => count,
                    Err(e) => {
                        log::error!(
                            "[{conn_id}] Failed to read initial data from client, closing connection! {e}"
                        );
                        return;
                    }
                };
                if count == 0 {
                    log::error!("[{conn_id}] read from client returned 0 bytes, cannot continue!");
                    let _ = client_write.write_all(ERR_NOT_PROXIFIABLE).await;
                    return;
                }
                total += count;
                match process_headers(&buf_client[..total]) {
                    Ok(None) => {
                        log::debug!("[{conn_id}] Insufficient amount of data, need to continue");
                        continue;
                    }
                    Ok(Some(res)) => break res,
                    Err(e) => {
                        log::error!("[{conn_id}] failed to parse initial request headers! {e}");
                        let _ = client_write.write_all(ERR_NOT_PROXIFIABLE).await;
                        return;
                    }
                }
            };

            // Connect to upstream
            let mut upstream = match TcpStream::connect(headers_processed.target_host).await {
                Ok(upstream) => upstream,
                Err(e) => {
                    log::error!("[{conn_id}] Could not connect to upstream! {e}");
                    let _ = client_write.write_all(ERR_NOT_PROXIFIABLE).await;
                    return;
                }
            };

            // Transfer modified headers to upstream
            if let Err(e) = upstream.write_all(&headers_processed.buff).await {
                log::error!("[{conn_id}] Could not forward initial bytes to upstream! {e}");
                let _ = client_write.write_all(ERR_NOT_PROXIFIABLE).await;
                return;
            }

            // Enter the forwarding loop for the remaining data
            loop {
                tokio::select! {
                    count = client_read.read(&mut buf_client) => {
                        let count = match count {
                            Ok(count) => count,
                            Err(e) => {
                                log::error!("[{conn_id}] Failed to read data from client, closing connection! {e}");
                                break;
                            }
                        };
                        log::debug!("[{conn_id}] Got a new client read {count}");
                        if count == 0 {
                            log::warn!("[{conn_id}] client read returned 0 (EOF), closing connection");
                            drop(upstream);
                            break;
                        }
                        // In case of connection reuse, we need to reanalyze the data so the path
                        // prefix is applied again; otherwise forward the bytes as-is.
                        if ARGS.path_prefix_header.is_some()
                            && let Ok(Some(res)) = process_headers(&buf_client[..count])
                        {
                            if let Err(e) = upstream.write_all(&res.buff).await {
                                log::error!("[{conn_id}] Failed to write to upstream! {e}");
                                break;
                            }
                        } else if let Err(e) = upstream.write_all(&buf_client[..count]).await {
                            log::error!("[{conn_id}] Failed to write to upstream! {e}");
                            break;
                        }
                    }
                    count = upstream.read(&mut buf_server) => {
                        let count = match count {
                            Ok(count) => count,
                            Err(e) => {
                                log::error!("[{conn_id}] Failed to read from upstream! {e}");
                                break;
                            }
                        };
                        if count == 0 {
                            log::warn!("[{conn_id}] upstream read returned 0 (EOF), closing connection");
                            drop(upstream);
                            break;
                        }
                        log::debug!("[{conn_id}] Got a new upstream read {count}");
                        if let Err(e) = client_write.write_all(&buf_server[..count]).await {
                            log::error!("[{conn_id}] Failed to write to client! {e}");
                            break;
                        };
                    }
                }
            }
            log::info!("[{conn_id}] Connection finished.");
        });
    }
}

struct ProcessHeadersResult {
    buff: Vec<u8>,
    target_host: String,
}

fn process_headers(buff: &[u8]) -> anyhow::Result<Option<ProcessHeadersResult>> {
    let mut headers = [httparse::EMPTY_HEADER; 64];
    let mut req = httparse::Request::new(&mut headers);
    let parsing_res = req.parse(buff)?;

    let target_host = headers
        .iter()
        .find(|h| h.name.eq_ignore_ascii_case(&ARGS.target_host_port_header))
        .map(|h| String::from_utf8_lossy(h.value));
    log::debug!("Request headers: {:?}", headers);

    let Some(target_host) = target_host else {
        if parsing_res.is_partial() {
            return Ok(None);
        } else {
            anyhow::bail!(
                "Request is complete without header {}",
                ARGS.target_host_port_header
            )
        }
    };

    // Check if the path needs to be prefixed
    let prefix_path = if let Some(hname) = &ARGS.path_prefix_header {
        headers
            .iter()
            .find(|h| h.name.eq_ignore_ascii_case(hname))
            .map(|h| String::from_utf8_lossy(h.value).replace(['\n', '\r', '\t', ' '], ""))
    } else {
        None
    };

    // Perform prefix injection: insert the prefix right after the first space,
    // i.e. between the method and the request path
    let mut buff = buff.to_vec();
    if let Some(prefix) = prefix_path {
        let pos = buff.iter().position(|c| c == &b' ');
        log::debug!("Add path prefix to request {prefix}");
        if let Some(pos) = pos {
            for (num, c) in prefix.as_bytes().iter().enumerate() {
                buff.insert(pos + 1 + num, *c);
            }
        } else {
            log::warn!("Unable to inject prefix!");
        }
    }
    log::trace!("Final request: {}", String::from_utf8_lossy(&buff));

    Ok(Some(ProcessHeadersResult {
        target_host: target_host.to_string(),
        buff,
    }))
}

#[cfg(test)]
mod test {
    use crate::Args;

    #[test]
    fn verify_cli() {
        use clap::CommandFactory;
        Args::command().debug_assert()
    }
}