From 715aed3a3cbd4018637b4e2f6c599b1a55d74e88 Mon Sep 17 00:00:00 2001 From: Pierre HUBERT Date: Tue, 25 Feb 2025 11:13:05 +0100 Subject: [PATCH] Log requests and responses --- src/main.rs | 118 ++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 97 insertions(+), 21 deletions(-) diff --git a/src/main.rs b/src/main.rs index 804c2c9..8504f5b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,19 +1,19 @@ +use clap::Parser; +use rand::distr::{Alphanumeric, SampleString}; +use rustls_pki_types::ServerName; use std::error::Error; -use std::fs::{File, OpenOptions}; +use std::fs::OpenOptions; use std::io::Write; use std::path::Path; -use clap::Parser; -use tokio::net::TcpListener; -use rustls_pki_types::ServerName; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use rand::distr::{Alphanumeric, SampleString}; -use tokio::net::TcpStream; -use tokio_rustls::rustls::{ClientConfig, RootCertStore}; -use tokio_rustls::TlsConnector; use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::{pin}; +use tokio::net::TcpListener; +use tokio::net::TcpStream; +use tokio::pin; use tokio::time::interval; +use tokio_rustls::TlsConnector; +use tokio_rustls::rustls::{ClientConfig, RootCertStore}; /// Simple program that proxify requests and save responses #[derive(Parser, Debug)] @@ -51,7 +51,8 @@ pub fn rand_str(len: usize) -> String { } #[tokio::main] -async fn main() -> Result<(), Box> { env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); +async fn main() -> Result<(), Box> { + env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); let args = Args::parse(); @@ -60,7 +61,6 @@ async fn main() -> Result<(), Box> { env_logger::init_from_env(env std::fs::create_dir_all(Path::new(args.storage_path.as_str())).unwrap(); - loop { // Asynchronously wait for an inbound socket. let (mut client_socket, _) = listener.accept().await?; @@ -68,11 +68,22 @@ async fn main() -> Result<(), Box> { env_logger::init_from_env(env tokio::spawn(async move { let args = Args::parse(); - let base_file_name = format!("{}-{}-{}", client_socket.peer_addr().unwrap().ip(), time(), rand_str(10)); + let base_file_name = format!( + "{}-{}-{}", + client_socket.peer_addr().unwrap().ip(), + time(), + rand_str(10) + ); - let mut req_file = OpenOptions::new().create(true).write(true).open(Path::new(&args.storage_path).join(format!("req-{base_file_name}"))) + let mut req_file = OpenOptions::new() + .create(true) + .write(true) + .open(Path::new(&args.storage_path).join(format!("req-{base_file_name}"))) .expect("Failed to create req file"); - let mut res_file = OpenOptions::new().create(true).write(true).open(Path::new(&args.storage_path).join(format!("res-{base_file_name}"))) + let mut res_file = OpenOptions::new() + .create(true) + .write(true) + .open(Path::new(&args.storage_path).join(format!("res-{base_file_name}"))) .expect("Failed to create req file"); let mut root_cert_store = RootCertStore::empty(); @@ -81,10 +92,15 @@ async fn main() -> Result<(), Box> { env_logger::init_from_env(env .with_root_certificates(root_cert_store) .with_no_client_auth(); let connector = TlsConnector::from(Arc::new(config)); - let dnsname = ServerName::try_from(args.upstream_dns).unwrap(); + let dnsname = ServerName::try_from(args.upstream_dns.to_string()).unwrap(); - let stream = TcpStream::connect(args.upstream_ip).await.expect("Failed to connect to upstream"); - let mut upstream = connector.connect(dnsname, stream).await.expect("Failed to establish TLS connection"); + let stream = TcpStream::connect(args.upstream_ip) + .await + .expect("Failed to connect to upstream"); + let mut upstream = connector + .connect(dnsname, stream) + .await + .expect("Failed to establish TLS connection"); let (mut client_read, mut client_write) = client_socket.split(); @@ -92,6 +108,8 @@ async fn main() -> Result<(), Box> { env_logger::init_from_env(env let mut buf_client = [0u8; 1024]; let mut buf_server = [0u8; 1024]; + let mut modified_headers = false; + loop { let tick = interval.tick(); // required for select() @@ -99,15 +117,35 @@ async fn main() -> Result<(), Box> { env_logger::init_from_env(env tokio::select! { count = client_read.read(&mut buf_client) => { - let count = count.expect("Failed to read from client socket"); + let count = match count{ Ok(count) => count, Err(e) => { + log::error!("Failed to read data from client, closing connection! {e}"); + return; + }}; log::info!("Got a new client read {count}"); - upstream.write_all(&buf_client[..count]).await.expect("Failed to write to upstream"); - req_file.write_all(&buf_client[..count]).expect("Failed to write to req"); + if count == 0 { + log::warn!("infinite loop"); + return; + } + + // We need to modify some headers (if not done already) to adapt the request to the server + let buff = if !modified_headers { + modified_headers = true; + manipulate_headers(&buf_client[..count], &args.upstream_dns) + } + else { + buf_client[..count].to_vec() + }; + + upstream.write_all(&buff).await.expect("Failed to write to upstream"); + req_file.write_all(&buff).expect("Failed to write to req"); } count = upstream.read(&mut buf_server) => { - let count = count.expect("Failed to read from server socket"); + let count = match count{ Ok(count) => count,Err(e) => { + log::error!("Failed to read from upstream! {e}"); + return; + }}; log::info!("Got a new upstream read {count}"); client_write.write_all(&buf_server[..count]).await.expect("Failed to write to client"); res_file.write_all(&buf_server[..count]).expect("Failed to write to res"); @@ -117,3 +155,41 @@ async fn main() -> Result<(), Box> { env_logger::init_from_env(env }); } } + +fn manipulate_headers(buff: &[u8], host: &str) -> Vec { + // return buff.to_vec(); + + let mut out = Vec::with_capacity(buff.len()); + + let mut i = 0; + while i < buff.len() { + if buff[i] != b'\n' || i + 1 == buff.len() || !buff[i + 1..].starts_with(b"Host:") { + out.push(buff[i]); + i += 1; + continue; + } + + i += 1; + out.push(b'\n'); + + out.append(&mut format!("Host: {host}").as_bytes().to_vec()); + + while buff[i] != b'\r' && buff[i] != b'\n' { + i += 1; + } + } + + out + + /*// Work line per line + buff.to_vec() + .split(&b'\n') + .map(|l| { + if l.starts_with(b"Host:") { + format!("Host: {host}\r") + } else { + l.to_owned() + } + }) + .collect::>()*/ +}