All checks were successful
continuous-integration/drone/push Build is passing
244 lines
8.6 KiB
Rust
244 lines
8.6 KiB
Rust
use clap::Parser;
|
|
use rand::distr::{Alphanumeric, SampleString};
|
|
use std::error::Error;
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
use tokio::net::TcpListener;
|
|
use tokio::net::TcpStream;
|
|
|
|
/// Simple program that redirect requests to a given server in request
|
|
#[derive(Parser, Debug)]
|
|
#[command(version, about, long_about = None)]
|
|
struct Args {
|
|
/// The address the server will listen to
|
|
#[arg(short, long, env, default_value = "0.0.0.0:8000")]
|
|
listen_address: String,
|
|
|
|
/// The name of the header that contain target host and port
|
|
#[arg(short, long, env, default_value = "x-target-host")]
|
|
target_host_port_header: String,
|
|
|
|
/// Name of optional header that contains path to add to the request.
|
|
///
|
|
/// If this value is defined, all clients packets are inspected in research for path to
|
|
/// manipulate
|
|
#[arg(short, long, env)]
|
|
path_prefix_header: Option<String>,
|
|
}
|
|
|
|
lazy_static::lazy_static! {
|
|
static ref ARGS: Args = {
|
|
Args::parse()
|
|
};
|
|
}
|
|
|
|
/// Minimal HTTP response sent to a client whose connection cannot be proxied
/// (unusable request headers, unreachable upstream, ...).
///
/// `400` is paired with its standard reason phrase ("Bad Request"). The body
/// length is declared explicitly and `Connection: close` is sent, because the
/// connection handler returns, and so drops the socket, right after writing it.
///
/// The (misspelled) name is kept because other code refers to it.
const ERR_NOT_PROXIFIALBE: &[u8] = b"HTTP/1.1 400 Bad Request\r\n\
    Content-Type: text/plain\r\n\
    Content-Length: 18\r\n\
    Connection: close\r\n\
    \r\n\
    Not proxifiable.\r\n";
|
|
|
|
pub fn rand_str(len: usize) -> String {
|
|
Alphanumeric.sample_string(&mut rand::rng(), len)
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<(), Box<dyn Error>> {
|
|
env_logger::init_from_env(env_logger::Env::new().default_filter_or("info"));
|
|
|
|
log::info!("Will start to listen on {}", ARGS.listen_address);
|
|
let listener = TcpListener::bind(&ARGS.listen_address).await?;
|
|
|
|
loop {
|
|
// Asynchronously wait for an inbound socket.
|
|
let (mut client_socket, _) = listener.accept().await?;
|
|
|
|
tokio::spawn(async move {
|
|
let conn_id = rand_str(5);
|
|
|
|
log::info!(
|
|
"[{conn_id}] Handle new connection from {}",
|
|
client_socket.peer_addr().unwrap()
|
|
);
|
|
|
|
let (mut client_read, mut client_write) = client_socket.split();
|
|
let mut buf_client = [0u8; 10000];
|
|
let mut buf_server = [0u8; 1024];
|
|
|
|
// Perform first read operation manually to manipulate path and determine target server
|
|
let mut total = 0;
|
|
let headers_processed = loop {
|
|
let count = match client_read.read(&mut buf_client[total..]).await {
|
|
Ok(count) => count,
|
|
Err(e) => {
|
|
log::error!(
|
|
"[{conn_id}] Failed to read initial data from client, closing connection! {e}"
|
|
);
|
|
return;
|
|
}
|
|
};
|
|
|
|
if count == 0 {
|
|
log::error!("[{conn_id}] read from client count is null, cannot continue!");
|
|
let _ = client_write.write_all(ERR_NOT_PROXIFIALBE).await;
|
|
return;
|
|
}
|
|
|
|
total += count;
|
|
|
|
match process_headers(&buf_client[..total]) {
|
|
Ok(None) => {
|
|
log::debug!("[{conn_id}] Insufficient amount of data, need to continue");
|
|
continue;
|
|
}
|
|
Ok(Some(res)) => break res,
|
|
Err(e) => {
|
|
log::error!("[{conn_id}] failed to parse initial request headers! {e}");
|
|
let _ = client_write.write_all(ERR_NOT_PROXIFIALBE).await;
|
|
return;
|
|
}
|
|
}
|
|
};
|
|
|
|
// Connect to upstream
|
|
let mut upstream = match TcpStream::connect(headers_processed.target_host).await {
|
|
Ok(upstream) => upstream,
|
|
Err(e) => {
|
|
log::error!("[{conn_id}] Could not connect to upstream! {e}");
|
|
let _ = client_write.write_all(ERR_NOT_PROXIFIALBE).await;
|
|
return;
|
|
}
|
|
};
|
|
|
|
// Transfer modified headers to upstream
|
|
if let Err(e) = upstream.write_all(&headers_processed.buff).await {
|
|
log::error!("[{conn_id}] Could not forward initial bytes to upstream! {e}");
|
|
let _ = client_write.write_all(ERR_NOT_PROXIFIALBE).await;
|
|
return;
|
|
}
|
|
|
|
// Enter in loop to forward remaining data
|
|
loop {
|
|
tokio::select! {
|
|
count = client_read.read(&mut buf_client) => {
|
|
let count = match count{ Ok(count) => count, Err(e) => {
|
|
log::error!("[{conn_id}] Failed to read data from client, closing connection! {e}");
|
|
break;
|
|
}};
|
|
log::debug!("[{conn_id}] Got a new client read {count}");
|
|
|
|
if count == 0 {
|
|
log::warn!("[{conn_id}] infinite loop (client), closing connection");
|
|
drop(upstream);
|
|
break;
|
|
}
|
|
|
|
// In case of connection reuse, we need to reanalyze data
|
|
if ARGS.path_prefix_header.is_some() &&
|
|
let Ok(Some(res))= process_headers(&buf_client[..count])
|
|
&& let Err(e) = upstream.write_all(&res.buff).await {
|
|
log::error!("[{conn_id}] Failed to write to upstream! {e}");
|
|
break;
|
|
}
|
|
|
|
if let Err(e)=upstream.write_all(&buf_client[..count]).await {
|
|
log::error!("[{conn_id}] Failed to write to upstream! {e}");
|
|
break;
|
|
}
|
|
}
|
|
|
|
count = upstream.read(&mut buf_server) => {
|
|
let count = match count {
|
|
Ok(count) => count,
|
|
Err(e) => {
|
|
log::error!("[{conn_id}] Failed to read from upstream! {e}");
|
|
break;
|
|
}
|
|
};
|
|
|
|
if count == 0 {
|
|
log::warn!("[{conn_id}] infinite loop (upstream), closing connection");
|
|
drop(upstream);
|
|
break;
|
|
}
|
|
|
|
log::debug!("[{conn_id}] Got a new upstream read {count}");
|
|
if let Err(e) = client_write.write_all(&buf_server[..count]).await {
|
|
log::error!("Failed to write to upstream! {e}");
|
|
break;
|
|
};
|
|
}
|
|
}
|
|
}
|
|
|
|
log::info!("[{conn_id}] Connection finished.");
|
|
});
|
|
}
|
|
}
|
|
|
|
/// Outcome of a successful [`process_headers`] call.
struct ProcessHeadersResult {
    /// Upstream address (host and port) read from the target host header.
    target_host: String,
    /// Request bytes to forward upstream, with the path prefix injected if any.
    buff: Vec<u8>,
}
|
|
|
|
fn process_headers(buff: &[u8]) -> anyhow::Result<Option<ProcessHeadersResult>> {
|
|
let mut headers = [httparse::EMPTY_HEADER; 64];
|
|
let mut req = httparse::Request::new(&mut headers);
|
|
let parsing_res = req.parse(buff)?;
|
|
|
|
let target_host = headers
|
|
.iter()
|
|
.find(|h| h.name.eq_ignore_ascii_case(&ARGS.target_host_port_header))
|
|
.map(|h| String::from_utf8_lossy(h.value));
|
|
|
|
log::debug!("Request headers: {:?}", headers);
|
|
|
|
let Some(target_host) = target_host else {
|
|
if parsing_res.is_partial() {
|
|
return Ok(None);
|
|
} else {
|
|
anyhow::bail!(
|
|
"Request is complete without header {}",
|
|
ARGS.target_host_port_header
|
|
)
|
|
}
|
|
};
|
|
|
|
// Check if path needs to be prefixed
|
|
let prefix_path = if let Some(hname) = &ARGS.path_prefix_header {
|
|
headers
|
|
.iter()
|
|
.find(|h| h.name.eq_ignore_ascii_case(hname))
|
|
.map(|h| String::from_utf8_lossy(h.value).replace(['\n', '\r', '\t', ' '], ""))
|
|
} else {
|
|
None
|
|
};
|
|
|
|
// Perform prefix injection
|
|
let mut buff = buff.to_vec();
|
|
if let Some(prefix) = prefix_path {
|
|
let pos = buff.iter().position(|c| c == &b' ');
|
|
log::debug!("Add path prefix to request {prefix}");
|
|
if let Some(pos) = pos {
|
|
for (num, c) in prefix.as_bytes().iter().enumerate() {
|
|
buff.insert(pos + 1 + num, *c);
|
|
}
|
|
} else {
|
|
log::warn!("Unable to inject prefix!");
|
|
}
|
|
}
|
|
|
|
log::trace!("Final request: {}", String::from_utf8_lossy(&buff));
|
|
|
|
Ok(Some(ProcessHeadersResult {
|
|
target_host: target_host.to_string(),
|
|
buff,
|
|
}))
|
|
}
|
|
|
|
/// Unit tests for the command-line interface definition.
#[cfg(test)]
mod test {
    use super::Args;
    use clap::CommandFactory;

    /// Run clap's own consistency checks on the `Args` definition.
    #[test]
    fn verify_cli() {
        let cmd = Args::command();
        cmd.debug_assert();
    }
}
|