diff --git a/Cargo.lock b/Cargo.lock index afcf225..1236462 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -61,6 +61,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "anyhow" +version = "1.0.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" + [[package]] name = "bitflags" version = "2.10.0" @@ -164,8 +170,10 @@ dependencies = [ name = "header_proxy" version = "0.1.0" dependencies = [ + "anyhow", "clap", "env_logger", + "httparse", "lazy_static", "log", "rand", @@ -178,6 +186,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + [[package]] name = "is_terminal_polyfill" version = "1.70.2" diff --git a/Cargo.toml b/Cargo.toml index 98431c5..63d7d5f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,3 +10,5 @@ clap = { version = "4.5.53", features = ["env", "derive"] } tokio = { version = "1.48.0", features = ["full"] } rand = "0.9.2" lazy_static = "1.5.0" +httparse = "1.10.1" +anyhow = "1.0.100" \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index bbd869f..3764630 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,16 +10,16 @@ use tokio::net::TcpStream; #[command(version, about, long_about = None)] struct Args { /// The address the server will listen to - #[arg(short, long, default_value = "0.0.0.0:8000")] + #[arg(short, long, env, default_value = "0.0.0.0:8000")] listen_address: String, /// The name of the header that contain target host and port - #[arg(short, long, default_value = "x-target-host")] + #[arg(short, long, env, default_value = "x-target-host")] target_host_port_header: String, /// Name of optional header that contains path to add to the request - #[arg(short, long, default_value = "x-path-prefix")] - path_prefix_header: String, + #[arg(short, long, env)] + path_prefix_header: Option, } lazy_static::lazy_static! { @@ -54,37 +54,62 @@ async fn main() -> Result<(), Box> { ); let (mut client_read, mut client_write) = client_socket.split(); - let mut buf_client = [0u8; 1024]; + let mut buf_client = [0u8; 10000]; let mut buf_server = [0u8; 1024]; - // Perform first read operation manually to manipulate path and determine destination server - let count = match client_read.read(&mut buf_client).await { - Ok(count) => count, - Err(e) => { - log::error!( - "[{conn_id}] Failed to read initial data from client, closing connection! {e}" - ); + // Perform first read operation manually to manipulate path and determine target server + let mut total = 0; + let headers_processed = loop { + let count = match client_read.read(&mut buf_client[total..]).await { + Ok(count) => count, + Err(e) => { + log::error!( + "[{conn_id}] Failed to read initial data from client, closing connection! {e}" + ); + return; + } + }; + + if count == 0 { + log::error!("[{conn_id}] read from client count is null, cannot continue!"); + let _ = client_write.write_all(ERR_NOT_PROXIFIALBE).await; return; } + + total += count; + + match process_headers(&buf_client[..total]) { + Ok(None) => { + log::debug!("[{conn_id}] Insufficient amount of data, need to continue"); + continue; + } + Ok(Some(res)) => break res, + Err(e) => { + log::error!("[{conn_id}] failed to parse initial request headers! {e}"); + let _ = client_write.write_all(ERR_NOT_PROXIFIALBE).await; + return; + } + } }; - if count < 10 { - log::error!("[{conn_id}] Initial read too small (count={count}), cannot continue!"); - return; - } - - let headers_processed = process_headers(&buf_client[..count]); - - // Transfer modified headers to upstream - let mut upstream = match TcpStream::connect(headers_processed.remote_host).await { + // Connect to upstream + let mut upstream = match TcpStream::connect(headers_processed.target_host).await { Ok(upstream) => upstream, Err(e) => { - log::error!("Could not connect to upstream! {e}"); + log::error!("[{conn_id}] Could not connect to upstream! {e}"); let _ = client_write.write_all(ERR_NOT_PROXIFIALBE).await; return; } }; + // Transfer modified headers to upstream + if let Err(e) = upstream.write_all(&headers_processed.buff).await { + log::error!("[{conn_id}] Could not forward initial bytes to upstream! {e}"); + let _ = client_write.write_all(ERR_NOT_PROXIFIALBE).await; + return; + } + + // Enter in loop to forward remaining data loop { tokio::select! { count = client_read.read(&mut buf_client) => { @@ -137,31 +162,55 @@ async fn main() -> Result<(), Box> { struct ProcessHeadersResult { buff: Vec, - remote_host: String, + target_host: String, } -fn process_headers(buff: &[u8]) -> ProcessHeadersResult { - let mut out = Vec::with_capacity(buff.len()); +fn process_headers(buff: &[u8]) -> anyhow::Result> { + let mut headers = [httparse::EMPTY_HEADER; 64]; + let mut req = httparse::Request::new(&mut headers); + let parsing_res = req.parse(buff)?; - let mut i = 0; - while i < buff.len() { - if buff[i] != b'\n' || i + 1 == buff.len() || !buff[i + 1..].starts_with(b"Host:") { - out.push(buff[i]); - i += 1; - continue; + let target_host = headers + .iter() + .find(|h| h.name.eq_ignore_ascii_case(&ARGS.target_host_port_header)) + .map(|h| String::from_utf8_lossy(h.value)); + + let Some(target_host) = target_host else { + if parsing_res.is_partial() { + return Ok(None); + } else { + anyhow::bail!( + "Request is complete without header {}", + ARGS.target_host_port_header + ) } + }; - i += 1; - out.push(b'\n'); + // Check if path needs to be prefixed + let prefix_path = if let Some(hname) = &ARGS.path_prefix_header { + headers + .iter() + .find(|h| h.name.eq_ignore_ascii_case(hname)) + .map(|h| String::from_utf8_lossy(h.value).replace(['\n', '\r', '\t', ' '], "")) + } else { + None + }; - out.append(&mut format!("Host: {host}").as_bytes().to_vec()); - - while buff[i] != b'\r' && buff[i] != b'\n' { - i += 1; + // Perform prefix injection + let mut buff = buff.to_vec(); + if let Some(prefix) = prefix_path { + let pos = buff.iter().position(|c| c == &b' '); + if let Some(pos) = pos { + for (num, c) in prefix.as_bytes().iter().enumerate() { + buff.insert(pos + 1 + num, *c); + } } } - todo!() + Ok(Some(ProcessHeadersResult { + target_host: target_host.to_string(), + buff, + })) } #[cfg(test)]