use crate::broadcast_messages::{BroadcastMessage, SyncEvent}; use crate::constants::{WS_CLIENT_TIMEOUT, WS_HEARTBEAT_INTERVAL}; use crate::extractors::client_auth::APIClientAuth; use crate::server::HttpResult; use actix_web::dev::Payload; use actix_web::{web, FromRequest, HttpRequest}; use actix_ws::Message; use futures_util::StreamExt; use std::time::Instant; use tokio::select; use tokio::sync::broadcast; use tokio::sync::broadcast::Receiver; use tokio::time::interval; /// Messages send to the client #[derive(Debug, serde::Deserialize, serde::Serialize)] #[serde(tag = "type")] pub enum WsMessage { Sync(SyncEvent), } /// Main WS route pub async fn ws( req: HttpRequest, stream: web::Payload, tx: web::Data>, ) -> HttpResult { // Forcefully ignore request payload by manually extracting authentication information let auth = APIClientAuth::from_request(&req, &mut Payload::None).await?; let (res, session, msg_stream) = actix_ws::handle(&req, stream)?; // Ask for sync client to be started if let Err(e) = tx.send(BroadcastMessage::StartSyncTaskForUser( auth.user.user_id.clone(), )) { log::error!("Failed to send StartSyncTaskForUser: {}", e); } let rx = tx.subscribe(); // spawn websocket handler (and don't await it) so that the response is returned immediately actix_web::rt::spawn(ws_handler(session, msg_stream, auth, rx)); Ok(res) } pub async fn ws_handler( mut session: actix_ws::Session, mut msg_stream: actix_ws::MessageStream, auth: APIClientAuth, mut rx: Receiver, ) { log::info!("WS connected"); let mut last_heartbeat = Instant::now(); let mut interval = interval(WS_HEARTBEAT_INTERVAL); let reason = loop { // waits for either `msg_stream` to receive a message from the client, the broadcast channel // to send a message, or the heartbeat interval timer to tick, yielding the value of // whichever one is ready first select! { ws_msg = rx.recv() => { let msg = match ws_msg { Ok(msg) => msg, Err(broadcast::error::RecvError::Closed) => break None, Err(broadcast::error::RecvError::Lagged(_)) => continue, }; match msg { BroadcastMessage::CloseClientSession(id) => { if let Some(client) = &auth.client { if client.id == id { log::info!( "closing client session {id:?} of user {:?} as requested", auth.user.user_id ); break None; } } }, BroadcastMessage::CloseAllUserSessions(userid) => { if userid == auth.user.user_id { log::info!( "closing WS session of user {userid:?} as requested" ); break None; } } BroadcastMessage::SyncEvent(userid, event) => { if userid != auth.user.user_id { continue; } // Send the message to the websocket if let Ok(msg) = serde_json::to_string(&WsMessage::Sync(event)) { if let Err(e) = session.text(msg).await { log::error!("Failed to send SyncEvent: {}", e); } } } _ => {}}; } // heartbeat interval ticked _tick = interval.tick() => { // if no heartbeat ping/pong received recently, close the connection if Instant::now().duration_since(last_heartbeat) > WS_CLIENT_TIMEOUT { log::info!( "client has not sent heartbeat in over {WS_CLIENT_TIMEOUT:?}; disconnecting" ); break None; } // send heartbeat ping let _ = session.ping(b"").await; }, msg = msg_stream.next() => { let msg = match msg { // received message from WebSocket client Some(Ok(msg)) => msg, // client WebSocket stream error Some(Err(err)) => { log::error!("{err}"); break None; } // client WebSocket stream ended None => break None }; log::debug!("msg: {msg:?}"); match msg { Message::Text(s) => { log::info!("Text message: {s}"); } Message::Binary(_) => { // drop client's binary messages } Message::Close(reason) => { break reason; } Message::Ping(bytes) => { last_heartbeat = Instant::now(); let _ = session.pong(&bytes).await; } Message::Pong(_) => { last_heartbeat = Instant::now(); } Message::Continuation(_) => { log::warn!("no support for continuation frames"); } // no-op; ignore Message::Nop => {} }; } } }; // attempt to close connection gracefully let _ = session.close(reason).await; log::info!("WS disconnected"); }