diff --git a/matrixgw_backend/Cargo.lock b/matrixgw_backend/Cargo.lock index 33a12ff..46d938c 100644 --- a/matrixgw_backend/Cargo.lock +++ b/matrixgw_backend/Cargo.lock @@ -241,6 +241,20 @@ dependencies = [ "syn", ] +[[package]] +name = "actix-ws" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3a1fb4f9f2794b0aadaf2ba5f14a6f034c7e86957b458c506a8cb75953f2d99" +dependencies = [ + "actix-codec", + "actix-http", + "actix-web", + "bytestring", + "futures-core", + "tokio", +] + [[package]] name = "adler2" version = "2.0.1" @@ -3026,6 +3040,7 @@ dependencies = [ "actix-remote-ip", "actix-session", "actix-web", + "actix-ws", "anyhow", "base16ct 0.3.0", "bytes", diff --git a/matrixgw_backend/Cargo.toml b/matrixgw_backend/Cargo.toml index 9474195..f3b32d1 100644 --- a/matrixgw_backend/Cargo.toml +++ b/matrixgw_backend/Cargo.toml @@ -32,4 +32,5 @@ matrix-sdk = "0.14.0" url = "2.5.7" ractor = "0.15.9" serde_json = "1.0.145" -lazy-regex = "3.4.2" \ No newline at end of file +lazy-regex = "3.4.2" +actix-ws = "0.3.0" \ No newline at end of file diff --git a/matrixgw_backend/src/constants.rs b/matrixgw_backend/src/constants.rs index 2b8b86a..93ee5b7 100644 --- a/matrixgw_backend/src/constants.rs +++ b/matrixgw_backend/src/constants.rs @@ -1,3 +1,5 @@ +use std::time::Duration; + /// Auth header pub const API_AUTH_HEADER: &str = "x-client-auth"; @@ -16,3 +18,11 @@ pub mod sessions { /// Authenticated ID pub const USER_ID: &str = "uid"; } + +/// How often heartbeat pings are sent. +/// +/// Should be half (or less) of the acceptable client timeout. +pub const WS_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); + +/// How long before lack of client response causes a timeout. +pub const WS_CLIENT_TIMEOUT: Duration = Duration::from_secs(10); diff --git a/matrixgw_backend/src/controllers/mod.rs b/matrixgw_backend/src/controllers/mod.rs index 01d8545..f457c29 100644 --- a/matrixgw_backend/src/controllers/mod.rs +++ b/matrixgw_backend/src/controllers/mod.rs @@ -7,6 +7,7 @@ pub mod matrix_link_controller; pub mod matrix_sync_thread_controller; pub mod server_controller; pub mod tokens_controller; +pub mod ws_controller; #[derive(thiserror::Error, Debug)] pub enum HttpFailure { @@ -18,6 +19,8 @@ pub enum HttpFailure { OpenID(Box), #[error("an unspecified internal error occurred: {0}")] InternalError(#[from] anyhow::Error), + #[error("Actix web error: {0}")] + ActixError(#[from] actix_web::Error), } impl ResponseError for HttpFailure { diff --git a/matrixgw_backend/src/controllers/ws_controller.rs b/matrixgw_backend/src/controllers/ws_controller.rs new file mode 100644 index 0000000..9ac0092 --- /dev/null +++ b/matrixgw_backend/src/controllers/ws_controller.rs @@ -0,0 +1,198 @@ +use crate::broadcast_messages::BroadcastMessage; +use crate::constants; +use crate::controllers::HttpResult; +use crate::extractors::auth_extractor::{AuthExtractor, AuthenticatedMethod}; +use crate::extractors::matrix_client_extractor::MatrixClientExtractor; +use crate::matrix_connection::matrix_client::MatrixClient; +use crate::matrix_connection::matrix_manager::MatrixManagerMsg; +use actix_web::dev::Payload; +use actix_web::{FromRequest, HttpRequest, web}; +use actix_ws::Message; +use futures_util::StreamExt; +use matrix_sdk::ruma::OwnedRoomId; +use matrix_sdk::ruma::events::room::message::RoomMessageEventContent; +use ractor::ActorRef; +use std::time::Instant; +use tokio::sync::broadcast; +use tokio::sync::broadcast::Receiver; +use tokio::time::interval; + +/// Messages sent to the client +#[derive(Debug, serde::Serialize)] +#[serde(tag = "type")] +pub enum WsMessage { + /// Room message event + RoomMessageEvent { + event: RoomMessageEventContent, + room_id: OwnedRoomId, + }, +} + +/// Main WS route +pub async fn ws( + req: HttpRequest, + stream: web::Payload, + tx: web::Data>, + manager: web::Data>, +) -> HttpResult { + // Forcefully ignore request payload by manually extracting authentication information + let client = MatrixClientExtractor::from_request(&req, &mut Payload::None).await?; + + // Ensure sync thread is started + ractor::cast!( + manager, + MatrixManagerMsg::StartSyncThread(client.auth.user.email.clone()) + ) + .expect("Failed to start sync thread prior to running WebSocket!"); + + let rx = tx.subscribe(); + + let (res, session, msg_stream) = actix_ws::handle(&req, stream)?; + + // spawn websocket handler (and don't await it) so that the response is returned immediately + actix_web::rt::spawn(ws_handler( + session, + msg_stream, + client.auth, + client.client, + rx, + )); + + Ok(res) +} + +pub async fn ws_handler( + mut session: actix_ws::Session, + mut msg_stream: actix_ws::MessageStream, + auth: AuthExtractor, + client: MatrixClient, + mut rx: Receiver, +) { + log::info!( + "WS connected for user {:?} / auth method={}", + client.email, + auth.method.light_str() + ); + + let mut last_heartbeat = Instant::now(); + let mut interval = interval(constants::WS_HEARTBEAT_INTERVAL); + + let reason = loop { + // waits for either `msg_stream` to receive a message from the client, the broadcast channel + // to send a message, or the heartbeat interval timer to tick, yielding the value of + // whichever one is ready first + tokio::select! { + ws_msg = rx.recv() => { + let msg = match ws_msg { + Ok(msg) => msg, + Err(broadcast::error::RecvError::Closed) => break None, + Err(broadcast::error::RecvError::Lagged(_)) => continue, + }; + + match msg { + BroadcastMessage::APITokenDeleted(t) => { + match &auth.method{ + AuthenticatedMethod::Token(tok) if tok.id == t.id => { + log::info!( + "closing WS session of user {:?} as associated token was deleted {:?}", + client.email, + t.base.name + ); + break None; + } + _=>{} + } + + }, + BroadcastMessage::UserDisconnectedFromMatrix(mail) if mail == auth.user.email => { + log::info!( + "closing WS session of user {mail:?} as user was disconnected from Matrix" + ); + break None; + } + + BroadcastMessage::RoomMessageEvent{user, event, room} if user == auth.user.email => { + // Send the message to the websocket + if let Ok(msg) = serde_json::to_string(&WsMessage::RoomMessageEvent { + event:event.content, + room_id: room.room_id().to_owned(), + }) && let Err(e) = session.text(msg).await { + log::error!("Failed to send SyncEvent: {e}"); + } + } + _ => {} + }; + + } + + // heartbeat interval ticked + _tick = interval.tick() => { + // if no heartbeat ping/pong received recently, close the connection + if Instant::now().duration_since(last_heartbeat) > constants::WS_CLIENT_TIMEOUT { + log::info!( + "client has not sent heartbeat in over {:?}; disconnecting",constants::WS_CLIENT_TIMEOUT + ); + + break None; + } + + // send heartbeat ping + let _ = session.ping(b"").await; + }, + + // Websocket messages + msg = msg_stream.next() => { + let msg = match msg { + // received message from WebSocket client + Some(Ok(msg)) => msg, + + // client WebSocket stream error + Some(Err(err)) => { + log::error!("{err}"); + break None; + } + + // client WebSocket stream ended + None => break None + }; + + log::debug!("msg: {msg:?}"); + + match msg { + Message::Text(s) => { + log::info!("Text message from WS: {s}"); + } + + Message::Binary(_) => { + // drop client's binary messages + } + + Message::Close(reason) => { + break reason; + } + + Message::Ping(bytes) => { + last_heartbeat = Instant::now(); + let _ = session.pong(&bytes).await; + } + + Message::Pong(_) => { + last_heartbeat = Instant::now(); + } + + Message::Continuation(_) => { + log::warn!("no support for continuation frames"); + } + + // no-op; ignore + Message::Nop => {} + }; + } + } + }; + + // attempt to close connection gracefully + let _ = session.close(reason).await; + + log::info!("WS disconnected for user {:?}", client.email); +} diff --git a/matrixgw_backend/src/extractors/auth_extractor.rs b/matrixgw_backend/src/extractors/auth_extractor.rs index b3a390f..339ca5d 100644 --- a/matrixgw_backend/src/extractors/auth_extractor.rs +++ b/matrixgw_backend/src/extractors/auth_extractor.rs @@ -28,6 +28,16 @@ pub enum AuthenticatedMethod { Token(APIToken), } +impl AuthenticatedMethod { + pub fn light_str(&self) -> String { + match self { + AuthenticatedMethod::Cookie => "Cookie".to_string(), + AuthenticatedMethod::Dev => "DevAuthentication".to_string(), + AuthenticatedMethod::Token(t) => format!("Token({:?} - {})", t.id, t.base.name), + } + } +} + pub struct AuthExtractor { pub user: User, pub method: AuthenticatedMethod, diff --git a/matrixgw_backend/src/main.rs b/matrixgw_backend/src/main.rs index f873bf8..75a52ef 100644 --- a/matrixgw_backend/src/main.rs +++ b/matrixgw_backend/src/main.rs @@ -11,7 +11,7 @@ use matrixgw_backend::broadcast_messages::BroadcastMessage; use matrixgw_backend::constants; use matrixgw_backend::controllers::{ auth_controller, matrix_link_controller, matrix_sync_thread_controller, server_controller, - tokens_controller, + tokens_controller, ws_controller, }; use matrixgw_backend::matrix_connection::matrix_manager::MatrixManagerActor; use matrixgw_backend::users::User; @@ -133,6 +133,7 @@ async fn main() -> std::io::Result<()> { "/api/matrix_sync/status", web::get().to(matrix_sync_thread_controller::status), ) + .service(web::resource("/api/ws").route(web::get().to(ws_controller::ws))) }) .workers(4) .bind(&AppConfig::get().listen_address)?