Start exchange
This commit is contained in:
69
src/main.rs
69
src/main.rs
@ -1,7 +1,15 @@
|
||||
use std::error::Error;
|
||||
use clap::Parser;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
use rustls_pki_types::ServerName;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::rustls::{ClientConfig, RootCertStore};
|
||||
use tokio_rustls::TlsConnector;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::{pin};
|
||||
use tokio::time::interval;
|
||||
|
||||
/// Simple program that proxify requests and save responses
|
||||
#[derive(Parser, Debug)]
|
||||
@ -12,8 +20,12 @@ struct Args {
|
||||
listen_address: String,
|
||||
|
||||
/// Upstream address this server will connect to
|
||||
#[arg(short, long, default_value = "httpbin.org")]
|
||||
upstream: String,
|
||||
#[arg(short, long, default_value = "communiquons.org")]
|
||||
upstream_dns: String,
|
||||
|
||||
/// Upstream address this server will connect to
|
||||
#[arg(short('I'), long, default_value = "10.0.1.10:443")]
|
||||
upstream_ip: String,
|
||||
|
||||
/// The path on the server this server will save requests and responses
|
||||
#[arg(short, long, default_value = "storage")]
|
||||
@ -31,26 +43,49 @@ async fn main() -> Result<(), Box<dyn Error>> { env_logger::init_from_env(env
|
||||
|
||||
loop {
|
||||
// Asynchronously wait for an inbound socket.
|
||||
let (mut socket, _) = listener.accept().await?;
|
||||
let (mut client_socket, _) = listener.accept().await?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut buf = vec![0; 1024];
|
||||
let args = Args::parse();
|
||||
|
||||
let mut root_cert_store = RootCertStore::empty();
|
||||
root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
|
||||
let config = ClientConfig::builder()
|
||||
.with_root_certificates(root_cert_store)
|
||||
.with_no_client_auth();
|
||||
let connector = TlsConnector::from(Arc::new(config));
|
||||
let dnsname = ServerName::try_from(args.upstream_dns).unwrap();
|
||||
|
||||
let stream = TcpStream::connect(args.upstream_ip).await.expect("Failed to connect to upstream");
|
||||
let mut upstream = connector.connect(dnsname, stream).await.expect("Failed to establish TLS connection");
|
||||
|
||||
let (mut client_read, mut client_write) = client_socket.split();
|
||||
|
||||
|
||||
|
||||
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
|
||||
let mut interval = interval(HEARTBEAT_INTERVAL);
|
||||
let mut buf_client = [0u8; 1024];
|
||||
let mut buf_server = [0u8; 1024];
|
||||
|
||||
// In a loop, read data from the socket and write the data back.
|
||||
loop {
|
||||
let n = socket
|
||||
.read(&mut buf)
|
||||
.await
|
||||
.expect("failed to read data from socket");
|
||||
let tick = interval.tick();
|
||||
// required for select()
|
||||
pin!(tick);
|
||||
|
||||
if n == 0 {
|
||||
return;
|
||||
tokio::select! {
|
||||
count = client_read.read(&mut buf_client) => {
|
||||
let count = count.expect("Failed to read from client socket");
|
||||
log::info!("Got a new client read {count}");
|
||||
upstream.write_all(&buf_client[..count]).await.expect("Failed to write to upstream");
|
||||
}
|
||||
|
||||
count = upstream.read(&mut buf_server) => {
|
||||
let count = count.expect("Failed to read from server socket");
|
||||
log::info!("Got a new upstream read {count}");
|
||||
client_write.write_all(&buf_server[..count]).await.expect("Failed to write to client");
|
||||
}
|
||||
}
|
||||
|
||||
socket
|
||||
.write_all(&buf[0..n])
|
||||
.await
|
||||
.expect("failed to write data to socket");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
Reference in New Issue
Block a user