use std::sync::Arc; use std::time::{Duration, Instant}; use actix::{Actor, ActorContext, AsyncContext, Handler, Message, StreamHandler}; use actix_web::{web, Error, HttpRequest, HttpResponse}; use actix_web_actors::ws; use actix_web_actors::ws::{CloseCode, CloseReason}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::TcpStream; use crate::tcp_relay_server::server_config::ServerConfig; /// How often heartbeat pings are sent const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); /// How long before lack of client response causes a timeout const CLIENT_TIMEOUT: Duration = Duration::from_secs(60); #[derive(Message)] #[rtype(result = "bool")] pub struct DataForWebSocket(Vec<u8>); #[derive(Message)] #[rtype(result = "()")] pub struct TCPReadEndClosed; /// Define HTTP actor struct RelayWS { tcp_read: Option<OwnedReadHalf>, tcp_write: OwnedWriteHalf, // Client must respond to ping at a specific interval, otherwise we drop connection hb: Instant, // TODO : handle socket close } impl RelayWS { /// helper method that sends ping to client every second. /// /// also this method checks heartbeats from client fn hb(&self, ctx: &mut ws::WebsocketContext<Self>) { ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| { // check client heartbeats if Instant::now().duration_since(act.hb) > CLIENT_TIMEOUT { // heartbeat timed out log::warn!("WebSocket Client heartbeat failed, disconnecting!"); // stop actor ctx.stop(); // don't try to send a ping return; } log::debug!("Send ping message..."); ctx.ping(b""); }); } } impl Actor for RelayWS { type Context = ws::WebsocketContext<Self>; fn started(&mut self, ctx: &mut Self::Context) { self.hb(ctx); // Start to read on remote socket let mut read_half = self.tcp_read.take().unwrap(); let addr = ctx.address(); let future = async move { let mut buff: [u8; 5000] = [0; 5000]; loop { match read_half.read(&mut buff).await { Ok(l) => { if l == 0 { log::info!("Got empty read. Closing read end..."); addr.do_send(TCPReadEndClosed); return; } let to_send = DataForWebSocket(Vec::from(&buff[0..l])); if let Err(e) = addr.send(to_send).await { log::error!("Failed to send to websocket. Stopping now... {:?}", e); return; } } Err(e) => { log::error!("Failed to read from remote socket. Stopping now... {:?}", e); break; } }; } log::info!("Exited read loop"); }; tokio::spawn(future); } } /// Handler for ws::Message message impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for RelayWS { fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) { match msg { Ok(ws::Message::Ping(msg)) => ctx.pong(&msg), Ok(ws::Message::Pong(_)) => self.hb = Instant::now(), Ok(ws::Message::Text(text)) => ctx.text(text), Ok(ws::Message::Close(_reason)) => ctx.stop(), Ok(ws::Message::Binary(data)) => { if let Err(e) = futures::executor::block_on(self.tcp_write.write_all(&data)) { log::error!("Failed to forward some data, closing connection! {:?}", e); ctx.stop(); } if data.is_empty() { log::info!("Got empty binary message. Closing websocket..."); ctx.stop(); } } _ => (), } } } impl Handler<DataForWebSocket> for RelayWS { type Result = bool; fn handle(&mut self, msg: DataForWebSocket, ctx: &mut Self::Context) -> Self::Result { ctx.binary(msg.0); true } } impl Handler<TCPReadEndClosed> for RelayWS { type Result = (); fn handle(&mut self, _msg: TCPReadEndClosed, ctx: &mut Self::Context) -> Self::Result { ctx.close(Some(CloseReason { code: CloseCode::Away, description: Some("TCP read end closed.".to_string()), })); } } #[derive(serde::Deserialize)] pub struct WebSocketQuery { id: usize, token: Option<String>, } pub async fn relay_ws( req: HttpRequest, stream: web::Payload, query: web::Query<WebSocketQuery>, conf: web::Data<Arc<ServerConfig>>, ) -> Result<HttpResponse, Error> { if conf.has_token_auth() && !conf .tokens .iter() .any(|t| t == query.token.as_deref().unwrap_or_default()) { log::error!( "Rejected WS request from {:?} due to invalid token!", req.peer_addr() ); return Ok(HttpResponse::Unauthorized().json("Invalid / missing token!")); } if conf.ports.len() <= query.id { log::error!( "Rejected WS request from {:?} due to invalid port number!", req.peer_addr() ); return Ok(HttpResponse::BadRequest().json("Invalid port number!")); } let upstream_addr = format!("{}:{}", conf.upstream_server, conf.ports[query.id]); let (tcp_read, tcp_write) = match TcpStream::connect(&upstream_addr).await { Ok(s) => s.into_split(), Err(e) => { log::error!( "Failed to establish connection with upstream server! {:?}", e ); return Ok(HttpResponse::InternalServerError().json("Failed to establish connection!")); } }; let relay = RelayWS { tcp_read: Some(tcp_read), tcp_write, hb: Instant::now(), }; let resp = ws::start(relay, &req, stream); log::info!( "Opening new WS connection:\ * for {:?}\ * to {}\ * token {:?}", req.peer_addr(), upstream_addr, query.token ); resp }