use clap::Parser; use rand::distr::{Alphanumeric, SampleString}; use rustls_pki_types::ServerName; use std::error::Error; use std::fs::OpenOptions; use std::io::Write; use std::path::Path; use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio_rustls::TlsConnector; use tokio_rustls::rustls::{ClientConfig, RootCertStore}; /// Simple program that proxify requests and save responses #[derive(Parser, Debug)] #[command(version, about, long_about = None)] struct Args { /// The address the server will listen to #[arg(short, long, default_value = "0.0.0.0:8000")] listen_address: String, /// Upstream address this server will connect to #[arg(short, long, default_value = "communiquons.org")] upstream_dns: String, /// Upstream address this server will connect to #[arg(short('I'), long, default_value = "10.0.1.10:443")] upstream_ip: String, /// The path on the server this server will save requests and responses #[arg(short, long, default_value = "storage")] storage_path: String, } /// Get the current time since epoch pub fn time() -> u64 { SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs() } pub fn rand_str(len: usize) -> String { Alphanumeric.sample_string(&mut rand::rng(), len) } #[tokio::main] async fn main() -> Result<(), Box> { env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); let args = Args::parse(); log::info!("Will start to listen on {}", args.listen_address); let listener = TcpListener::bind(&args.listen_address).await?; std::fs::create_dir_all(Path::new(args.storage_path.as_str())).unwrap(); loop { // Asynchronously wait for an inbound socket. let (mut client_socket, _) = listener.accept().await?; tokio::spawn(async move { let args = Args::parse(); let base_file_name = format!( "{}-{}-{}", client_socket.peer_addr().unwrap().ip(), time(), rand_str(10) ); let mut req_file = OpenOptions::new() .create(true) .write(true) .open(Path::new(&args.storage_path).join(format!("{base_file_name}-req"))) .expect("Failed to create req file"); let mut res_file = OpenOptions::new() .create(true) .write(true) .open(Path::new(&args.storage_path).join(format!("{base_file_name}-res"))) .expect("Failed to create req file"); let mut root_cert_store = RootCertStore::empty(); root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); let config = ClientConfig::builder() .with_root_certificates(root_cert_store) .with_no_client_auth(); let connector = TlsConnector::from(Arc::new(config)); let dnsname = ServerName::try_from(args.upstream_dns.to_string()).unwrap(); let stream = TcpStream::connect(args.upstream_ip) .await .expect("Failed to connect to upstream"); let mut upstream = connector .connect(dnsname, stream) .await .expect("Failed to establish TLS connection"); let (mut client_read, mut client_write) = client_socket.split(); let mut buf_client = [0u8; 1024]; let mut buf_server = [0u8; 1024]; let mut modified_headers = false; loop { tokio::select! { count = client_read.read(&mut buf_client) => { let count = match count{ Ok(count) => count, Err(e) => { log::error!("Failed to read data from client, closing connection! {e}"); return; }}; log::info!("Got a new client read {count}"); if count == 0 { log::warn!("infinite loop"); return; } // We need to modify some headers (if not done already) to adapt the request to the server let buff = if !modified_headers { modified_headers = true; manipulate_headers(&buf_client[..count], &args.upstream_dns) } else { buf_client[..count].to_vec() }; upstream.write_all(&buff).await.expect("Failed to write to upstream"); req_file.write_all(&buff).expect("Failed to write to req"); } count = upstream.read(&mut buf_server) => { let count = match count{ Ok(count) => count,Err(e) => { log::error!("Failed to read from upstream! {e}"); return; }}; log::info!("Got a new upstream read {count}"); client_write.write_all(&buf_server[..count]).await.expect("Failed to write to client"); res_file.write_all(&buf_server[..count]).expect("Failed to write to res"); } } } }); } } fn manipulate_headers(buff: &[u8], host: &str) -> Vec { // return buff.to_vec(); let mut out = Vec::with_capacity(buff.len()); let mut i = 0; while i < buff.len() { if buff[i] != b'\n' || i + 1 == buff.len() || !buff[i + 1..].starts_with(b"Host:") { out.push(buff[i]); i += 1; continue; } i += 1; out.push(b'\n'); out.append(&mut format!("Host: {host}").as_bytes().to_vec()); while buff[i] != b'\r' && buff[i] != b'\n' { i += 1; } } out } #[cfg(test)] mod test { use crate::Args; #[test] fn verify_cli() { use clap::CommandFactory; Args::command().debug_assert() } }