From a866deb3e4a4d4169d3d1b9d00e753b245a71034 Mon Sep 17 00:00:00 2001 From: Pierre Hubert Date: Tue, 30 Aug 2022 14:47:16 +0200 Subject: [PATCH] Automatically close unresponsive websockets --- tcp_relay_server/src/relay_ws.rs | 39 ++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/tcp_relay_server/src/relay_ws.rs b/tcp_relay_server/src/relay_ws.rs index 9247380..913c2d2 100644 --- a/tcp_relay_server/src/relay_ws.rs +++ b/tcp_relay_server/src/relay_ws.rs @@ -1,4 +1,5 @@ use std::sync::Arc; +use std::time::{Duration, Instant}; use actix::{Actor, ActorContext, AsyncContext, Handler, Message, StreamHandler}; use actix_web::{Error, HttpRequest, HttpResponse, web}; @@ -9,6 +10,12 @@ use tokio::net::TcpStream; use crate::args::Args; +/// How often heartbeat pings are sent +const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); + +/// How long before lack of client response causes a timeout +const CLIENT_TIMEOUT: Duration = Duration::from_secs(60); + #[derive(Message)] #[rtype(result = "bool")] pub struct DataForWebSocket(Vec); @@ -19,15 +26,42 @@ struct RelayWS { tcp_read: Option, tcp_write: OwnedWriteHalf, - // TODO : add disconnect after ping timeout + // Client must respond to ping at a specific interval, otherwise we drop connection + hb: Instant, // TODO : handle socket close } +impl RelayWS { + /// helper method that sends ping to client every second. + /// + /// also this method checks heartbeats from client + fn hb(&self, ctx: &mut ws::WebsocketContext) { + ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| { + // check client heartbeats + if Instant::now().duration_since(act.hb) > CLIENT_TIMEOUT { + // heartbeat timed out + log::warn!("WebSocket Client heartbeat failed, disconnecting!"); + + // stop actor + ctx.stop(); + + // don't try to send a ping + return; + } + + log::debug!("Send ping message..."); + ctx.ping(b""); + }); + } +} + impl Actor for RelayWS { type Context = ws::WebsocketContext; fn started(&mut self, ctx: &mut Self::Context) { + self.hb(ctx); + // Start to read on remote socket let mut read_half = self.tcp_read.take().unwrap(); let addr = ctx.address(); @@ -67,6 +101,7 @@ impl StreamHandler> for RelayWS { fn handle(&mut self, msg: Result, ctx: &mut Self::Context) { match msg { Ok(ws::Message::Ping(msg)) => ctx.pong(&msg), + Ok(ws::Message::Pong(_)) => self.hb = Instant::now(), Ok(ws::Message::Text(text)) => ctx.text(text), Ok(ws::Message::Close(_reason)) => ctx.stop(), Ok(ws::Message::Binary(data)) => { @@ -125,7 +160,7 @@ pub async fn relay_ws(req: HttpRequest, stream: web::Payload, } }; - let relay = RelayWS { tcp_read: Some(tcp_read), tcp_write }; + let relay = RelayWS { tcp_read: Some(tcp_read), tcp_write, hb: Instant::now() }; let resp = ws::start(relay, &req, stream); log::info!("Opening new WS connection for {:?} to {}", req.peer_addr(), upstream_addr); resp