177 lines
5.9 KiB
Rust

use crate::broadcast_messages::{BroadcastMessage, SyncEvent};
use crate::constants::{WS_CLIENT_TIMEOUT, WS_HEARTBEAT_INTERVAL};
use crate::extractors::client_auth::APIClientAuth;
use crate::server::HttpResult;
use actix_web::dev::Payload;
use actix_web::{web, FromRequest, HttpRequest};
use actix_ws::Message;
use futures_util::StreamExt;
use std::time::Instant;
use tokio::select;
use tokio::sync::broadcast;
use tokio::sync::broadcast::Receiver;
use tokio::time::interval;
/// Messages sent to the client over the WebSocket.
///
/// Uses serde's internally-tagged representation (`#[serde(tag = "type")]`): the variant name
/// (e.g. `"Sync"`) is written to a `type` field of the serialized object.
#[derive(Debug, serde::Deserialize, serde::Serialize)]
#[serde(tag = "type")]
pub enum WsMessage {
    /// A synchronization event for the connected user.
    Sync(SyncEvent),
}
/// Main WS route.
///
/// Authenticates the client, upgrades the request to a WebSocket, asks for the user's sync
/// task to be started and spawns [`ws_handler`] to serve the connection in the background.
///
/// # Errors
///
/// Returns an error if client authentication fails or if the request cannot be upgraded to a
/// WebSocket.
pub async fn ws(
    req: HttpRequest,
    stream: web::Payload,
    tx: web::Data<broadcast::Sender<BroadcastMessage>>,
) -> HttpResult {
    // Forcefully ignore request payload by manually extracting authentication information
    let auth = APIClientAuth::from_request(&req, &mut Payload::None).await?;

    let (res, session, msg_stream) = actix_ws::handle(&req, stream)?;

    // Subscribe *before* asking for the sync task to start: a broadcast receiver only sees
    // messages sent after it was created, so subscribing afterwards could miss the first sync
    // events for this user. Holding a live receiver also means `send` below always has at least
    // one subscriber to deliver to.
    let rx = tx.subscribe();

    // Ask for sync client to be started
    if let Err(e) = tx.send(BroadcastMessage::StartSyncTaskForUser(
        auth.user.user_id.clone(),
    )) {
        log::error!("Failed to send StartSyncTaskForUser: {}", e);
    }

    // spawn websocket handler (and don't await it) so that the response is returned immediately
    actix_web::rt::spawn(ws_handler(session, msg_stream, auth, rx));
    Ok(res)
}
/// Serves one WebSocket connection until it ends, then closes the session (best effort).
///
/// The loop waits on whichever of these is ready first:
/// - `rx`, the broadcast channel:
///   - This user's [`BroadcastMessage::SyncEvent`]s are serialized as [`WsMessage::Sync`] JSON
///     text messages and sent to the client.
///   - The loop stops when the channel is closed, on a `CloseClientSession` naming this
///     client, or on a `CloseAllUserSessions` naming this user.
///   - Other broadcast messages are ignored.
/// - the heartbeat timer (`WS_HEARTBEAT_INTERVAL`): pings the client, and disconnects it if
///   no ping or pong has been received from it for longer than `WS_CLIENT_TIMEOUT`.
/// - `msg_stream`, the client's messages: pings are answered and, like pongs, refresh the
///   heartbeat. A close frame, a stream error or the end of the stream stops the loop.
///
/// The loop also stops as soon as sending to the client fails, because the session is then
/// already closed.
pub async fn ws_handler(
    mut session: actix_ws::Session,
    mut msg_stream: actix_ws::MessageStream,
    auth: APIClientAuth,
    mut rx: Receiver<BroadcastMessage>,
) {
    log::info!("WS connected");
    let mut last_heartbeat = Instant::now();
    let mut interval = interval(WS_HEARTBEAT_INTERVAL);

    let reason = loop {
        // waits for either `msg_stream` to receive a message from the client, the broadcast channel
        // to send a message, or the heartbeat interval timer to tick, yielding the value of
        // whichever one is ready first
        select! {
            ws_msg = rx.recv() => {
                let msg = match ws_msg {
                    Ok(msg) => msg,
                    Err(broadcast::error::RecvError::Closed) => break None,
                    Err(broadcast::error::RecvError::Lagged(skipped)) => {
                        // This receiver fell behind the channel capacity; the skipped messages
                        // (possibly including sync events for this user) are lost.
                        log::warn!("WS broadcast receiver lagged, {skipped} message(s) skipped");
                        continue;
                    }
                };
                match msg {
                    BroadcastMessage::CloseClientSession(id) => {
                        if let Some(client) = &auth.client {
                            if client.id == id {
                                log::info!(
                                    "closing client session {id:?} of user {:?} as requested", auth.user.user_id
                                );
                                break None;
                            }
                        }
                    }
                    BroadcastMessage::CloseAllUserSessions(userid) => {
                        if userid == auth.user.user_id {
                            log::info!(
                                "closing WS session of user {userid:?} as requested"
                            );
                            break None;
                        }
                    }
                    BroadcastMessage::SyncEvent(userid, event) => {
                        if userid != auth.user.user_id {
                            continue;
                        }
                        let msg = match serde_json::to_string(&WsMessage::Sync(event)) {
                            Ok(msg) => msg,
                            Err(e) => {
                                log::error!("Failed to serialize SyncEvent: {}", e);
                                continue;
                            }
                        };
                        // Send the message to the websocket. A failed send means the session
                        // is closed, so there is nothing left to serve.
                        if let Err(e) = session.text(msg).await {
                            log::error!("Failed to send SyncEvent: {}", e);
                            break None;
                        }
                    }
                    // Other broadcast messages are not handled by this session.
                    _ => {}
                }
            }
            // heartbeat interval ticked
            _tick = interval.tick() => {
                // if no heartbeat ping/pong received recently, close the connection
                if last_heartbeat.elapsed() > WS_CLIENT_TIMEOUT {
                    log::info!(
                        "client has not sent heartbeat in over {WS_CLIENT_TIMEOUT:?}; disconnecting"
                    );
                    break None;
                }
                // send heartbeat ping; failure means the session is already closed
                if session.ping(b"").await.is_err() {
                    break None;
                }
            }
            msg = msg_stream.next() => {
                let msg = match msg {
                    // received message from WebSocket client
                    Some(Ok(msg)) => msg,
                    // client WebSocket stream error
                    Some(Err(err)) => {
                        log::error!("{err}");
                        break None;
                    }
                    // client WebSocket stream ended
                    None => break None,
                };
                log::debug!("msg: {msg:?}");
                match msg {
                    Message::Text(s) => {
                        log::info!("Text message: {s}");
                    }
                    Message::Binary(_) => {
                        // drop client's binary messages
                    }
                    Message::Close(reason) => {
                        break reason;
                    }
                    Message::Ping(bytes) => {
                        last_heartbeat = Instant::now();
                        // failure to answer means the session is already closed
                        if session.pong(&bytes).await.is_err() {
                            break None;
                        }
                    }
                    Message::Pong(_) => {
                        last_heartbeat = Instant::now();
                    }
                    Message::Continuation(_) => {
                        log::warn!("no support for continuation frames");
                    }
                    // no-op; ignore
                    Message::Nop => {}
                }
            }
        }
    };

    // attempt to close connection gracefully
    let _ = session.close(reason).await;
    log::info!("WS disconnected");
}