diff --git a/Cargo.lock b/Cargo.lock index 6ca9533..b30e3f6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1463,8 +1463,10 @@ dependencies = [ "base", "clap", "env_logger", + "futures", "log", "serde", + "tokio", ] [[package]] diff --git a/tcp_relay_server/Cargo.toml b/tcp_relay_server/Cargo.toml index b7a4d3f..4b5ebc8 100644 --- a/tcp_relay_server/Cargo.toml +++ b/tcp_relay_server/Cargo.toml @@ -11,4 +11,6 @@ env_logger = "0.9.0" actix = "0.13.0" actix-web = "4" actix-web-actors = "4.1.0" -serde = { version = "1.0.144", features = ["derive"] } \ No newline at end of file +serde = { version = "1.0.144", features = ["derive"] } +tokio = { version = "1", features = ["full"] } +futures = "0.3.24" \ No newline at end of file diff --git a/tcp_relay_server/src/relay_ws.rs b/tcp_relay_server/src/relay_ws.rs index 1bc2907..0e67fa7 100644 --- a/tcp_relay_server/src/relay_ws.rs +++ b/tcp_relay_server/src/relay_ws.rs @@ -1,16 +1,62 @@ use std::sync::Arc; -use actix::{Actor, StreamHandler}; +use actix::{Actor, ActorContext, Addr, ArbiterHandle, AsyncContext, Context, Handler, Message, Running, StreamHandler}; use actix_web::{Error, HttpRequest, HttpResponse, web}; +use actix_web::web::Data; use actix_web_actors::ws; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; +use tokio::net::TcpStream; +use tokio::sync::mpsc; use crate::args::Args; +#[derive(Message)] +#[rtype(result = "bool")] +pub struct DataForWebSocket(Vec); + + /// Define HTTP actor -struct RelayWS; +struct RelayWS { + tcp_read: Option, + tcp_write: OwnedWriteHalf, + + // TODO : add disconnect after ping timeout + + // TODO : handle socket close +} impl Actor for RelayWS { type Context = ws::WebsocketContext; + + fn started(&mut self, ctx: &mut Self::Context) { + // Start to read on remote socket + let mut read_half = self.tcp_read.take().unwrap(); + let addr = ctx.address(); + let future = async move { + let mut buff: [u8; 5000] = [0; 5000]; + loop { + match read_half.read(&mut buff).await { + Ok(l) => { + let to_send = DataForWebSocket(Vec::from(&buff[0..l])); + if let Err(e) = addr.send(to_send).await { + log::error!("Failed to send to websocket. Stopping now... {:?}", e); + return; + } + } + Err(e) => { + log::error!("Failed to read from remote socket. Stopping now... {:?}", e); + break; + } + }; + } + + log::info!("Exited read loop"); + // TODO : notify context + }; + + tokio::spawn(future); + } } /// Handler for ws::Message message @@ -19,12 +65,28 @@ impl StreamHandler> for RelayWS { match msg { Ok(ws::Message::Ping(msg)) => ctx.pong(&msg), Ok(ws::Message::Text(text)) => ctx.text(text), - Ok(ws::Message::Binary(bin)) => ctx.binary(bin), + Ok(ws::Message::Close(_reason)) => ctx.stop(), + Ok(ws::Message::Binary(data)) => { + if let Err(e) = futures::executor::block_on(self.tcp_write.write_all(&data.to_vec())) { + log::error!("Failed to forward some data, closing connection!"); + ctx.stop(); + } + } _ => (), } } } +impl Handler for RelayWS { + type Result = bool; + + fn handle(&mut self, msg: DataForWebSocket, ctx: &mut Self::Context) -> Self::Result { + ctx.binary(msg.0); + true + } +} + + #[derive(serde::Deserialize)] pub struct WebSocketQuery { id: usize, @@ -46,7 +108,17 @@ pub async fn relay_ws(req: HttpRequest, stream: web::Payload, let upstream_addr = format!("{}:{}", conf.upstream_server, conf.ports[query.id]); - let resp = ws::start(RelayWS {}, &req, stream); + let (tcp_read, tcp_write) = match TcpStream::connect(&upstream_addr).await { + Ok(s) => s.into_split(), + Err(e) => { + log::error!("Failed to establish connection with upstream server! {:?}", e); + return Ok(HttpResponse::InternalServerError() + .json("Failed to establish connection!")); + } + }; + + let relay = RelayWS { tcp_read: Some(tcp_read), tcp_write }; + let resp = ws::start(relay, &req, stream); log::info!("Opening new WS connection for {:?} to {}", req.peer_addr(), upstream_addr); resp } \ No newline at end of file