diff --git a/Cargo.lock b/Cargo.lock index 219a896..733e3a6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2072,7 +2072,6 @@ dependencies = [ "serde_json", "sha2 0.11.0-pre.4", "thiserror 2.0.11", - "time", "tokio", "urlencoding", "uuid", diff --git a/Cargo.toml b/Cargo.toml index a02c20d..ac04c9c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,5 +32,4 @@ sha2 = "0.11.0-pre.4" base16ct = "0.2.0" ruma = { version = "0.12.0", features = ["client-api-c", "client-ext-client-api", "client-hyper-native-tls", "rand"] } actix-ws = "0.3.0" -tokio = { version = "1.43.0", features = ["rt", "time", "macros"] } -time = "0.3.37" \ No newline at end of file +tokio = { version = "1.43.0", features = ["rt", "time", "macros", "rt-multi-thread"] } \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index cdb1d71..89183c5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,10 +4,11 @@ use actix_session::{storage::RedisSessionStore, SessionMiddleware}; use actix_web::cookie::Key; use actix_web::{web, App, HttpServer}; use matrix_gateway::app_config::AppConfig; +use matrix_gateway::server::api::ws::WsMessage; use matrix_gateway::server::{api, web_ui}; use matrix_gateway::user::UserConfig; -#[actix_web::main] +#[tokio::main] async fn main() -> std::io::Result<()> { env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); @@ -21,6 +22,8 @@ async fn main() -> std::io::Result<()> { .await .expect("Failed to connect to Redis!"); + let (ws_tx, _) = tokio::sync::broadcast::channel::(16); + log::info!( "Starting to listen on {} for {}", AppConfig::get().listen_address, @@ -38,6 +41,7 @@ async fn main() -> std::io::Result<()> { .app_data(web::Data::new(RemoteIPConfig { proxy: AppConfig::get().proxy_ip.clone(), })) + .app_data(web::Data::new(ws_tx.clone())) // Web configuration routes .route("/assets/{tail:.*}", web::get().to(web_ui::static_file)) .route("/", web::get().to(web_ui::home)) @@ -48,8 +52,9 @@ async fn main() -> std::io::Result<()> { .route("/api", web::get().to(api::api_home)) .route("/api", web::post().to(api::api_home)) .route("/api/account/whoami", web::get().to(api::account::who_am_i)) - .service(web::resource("/api/ws").route(web::get().to(api::ws))) + .service(web::resource("/api/ws").route(web::get().to(api::ws::ws))) }) + .workers(4) .bind(&AppConfig::get().listen_address)? .run() .await diff --git a/src/server/api/mod.rs b/src/server/api/mod.rs index f0389b2..f515e3e 100644 --- a/src/server/api/mod.rs +++ b/src/server/api/mod.rs @@ -1,112 +1,11 @@ -use crate::constants::{WS_CLIENT_TIMEOUT, WS_HEARTBEAT_INTERVAL}; use crate::extractors::client_auth::APIClientAuth; use crate::server::HttpResult; -use actix_web::dev::Payload; -use actix_web::{web, FromRequest, HttpRequest, HttpResponse}; -use actix_ws::Message; -use futures_util::future::Either; -use futures_util::{future, StreamExt}; -use std::time::Instant; -use tokio::{pin, time::interval}; +use actix_web::HttpResponse; pub mod account; +pub mod ws; /// API Home route pub async fn api_home(auth: APIClientAuth) -> HttpResult { Ok(HttpResponse::Ok().body(format!("Welcome user {}!", auth.user.user_id.0))) } - -/// Main WS route -pub async fn ws(req: HttpRequest, stream: web::Payload) -> HttpResult { - // Forcefully ignore request payload by manually extracting authentication information - let auth = APIClientAuth::from_request(&req, &mut Payload::None).await?; - - let (res, session, msg_stream) = actix_ws::handle(&req, stream)?; - - // spawn websocket handler (and don't await it) so that the response is returned immediately - actix_web::rt::spawn(ws_handler(session, msg_stream)); - - Ok(res) -} - -pub async fn ws_handler(mut session: actix_ws::Session, mut msg_stream: actix_ws::MessageStream) { - log::info!("WS connected"); - - let mut last_heartbeat = Instant::now(); - let mut interval = interval(WS_HEARTBEAT_INTERVAL); - - let reason = loop { - // create "next client timeout check" future - let tick = interval.tick(); - // required for select() - pin!(tick); - - // waits for either `msg_stream` to receive a message from the client or the heartbeat - // interval timer to tick, yielding the value of whichever one is ready first - match future::select(msg_stream.next(), tick).await { - // received message from WebSocket client - Either::Left((Some(Ok(msg)), _)) => { - log::debug!("msg: {msg:?}"); - - match msg { - Message::Text(text) => { - session.text(text).await.unwrap(); - } - - Message::Binary(bin) => { - session.binary(bin).await.unwrap(); - } - - Message::Close(reason) => { - break reason; - } - - Message::Ping(bytes) => { - last_heartbeat = Instant::now(); - let _ = session.pong(&bytes).await; - } - - Message::Pong(_) => { - last_heartbeat = Instant::now(); - } - - Message::Continuation(_) => { - log::warn!("no support for continuation frames"); - } - - // no-op; ignore - Message::Nop => {} - }; - } - - // client WebSocket stream error - Either::Left((Some(Err(err)), _)) => { - log::error!("{}", err); - break None; - } - - // client WebSocket stream ended - Either::Left((None, _)) => break None, - - // heartbeat interval ticked - Either::Right((_inst, _)) => { - // if no heartbeat ping/pong received recently, close the connection - if Instant::now().duration_since(last_heartbeat) > WS_CLIENT_TIMEOUT { - log::info!( - "client has not sent heartbeat in over {WS_CLIENT_TIMEOUT:?}; disconnecting" - ); - - break None; - } - - // send heartbeat ping - let _ = session.ping(b"").await; - } - } - }; - - // attempt to close connection gracefully - let _ = session.close(reason).await; - - log::info!("WS disconnected"); -} diff --git a/src/server/api/ws.rs b/src/server/api/ws.rs new file mode 100644 index 0000000..8b2b86b --- /dev/null +++ b/src/server/api/ws.rs @@ -0,0 +1,149 @@ +use crate::constants::{WS_CLIENT_TIMEOUT, WS_HEARTBEAT_INTERVAL}; +use crate::extractors::client_auth::APIClientAuth; +use crate::server::HttpResult; +use crate::user::{APIClientID, UserID}; +use actix_web::dev::Payload; +use actix_web::{web, FromRequest, HttpRequest}; +use actix_ws::Message; +use futures_util::StreamExt; +use std::time::Instant; +use tokio::select; +use tokio::sync::broadcast; +use tokio::sync::broadcast::Receiver; +use tokio::time::interval; + +/// WebSocket message +#[derive(Debug, Clone)] +pub enum WsMessage { + /// Request to close the session of a specific client + CloseClientSession(APIClientID), + /// Close all the sessions of a given user + CloseAllUserSessions(UserID), +} + +/// Main WS route +pub async fn ws( + req: HttpRequest, + stream: web::Payload, + tx: web::Data>, +) -> HttpResult { + // Forcefully ignore request payload by manually extracting authentication information + let auth = APIClientAuth::from_request(&req, &mut Payload::None).await?; + + let (res, session, msg_stream) = actix_ws::handle(&req, stream)?; + + let rx = tx.subscribe(); + + // spawn websocket handler (and don't await it) so that the response is returned immediately + actix_web::rt::spawn(ws_handler(session, msg_stream, auth, rx)); + + Ok(res) +} + +pub async fn ws_handler( + mut session: actix_ws::Session, + mut msg_stream: actix_ws::MessageStream, + auth: APIClientAuth, + mut rx: Receiver, +) { + log::info!("WS connected"); + + let mut last_heartbeat = Instant::now(); + let mut interval = interval(WS_HEARTBEAT_INTERVAL); + + let reason = loop { + // waits for either `msg_stream` to receive a message from the client, the broadcast channel + // to send a message, or the heartbeat interval timer to tick, yielding the value of + // whichever one is ready first + select! { + ws_msg = rx.recv() => { + let msg = match ws_msg { + Ok(msg) => msg, + Err(broadcast::error::RecvError::Closed) => break None, + Err(broadcast::error::RecvError::Lagged(_)) => continue, + }; + + match msg { + WsMessage::CloseClientSession(_) => todo!(), + WsMessage::CloseAllUserSessions(userid) => { + if userid == auth.user.user_id { + log::info!( + "closing WS session of user {userid:?} as requested" + ); + break None; + } + } + }; + + } + + // heartbeat interval ticked + _tick = interval.tick() => { + // if no heartbeat ping/pong received recently, close the connection + if Instant::now().duration_since(last_heartbeat) > WS_CLIENT_TIMEOUT { + log::info!( + "client has not sent heartbeat in over {WS_CLIENT_TIMEOUT:?}; disconnecting" + ); + + break None; + } + + // send heartbeat ping + let _ = session.ping(b"").await; + }, + + msg = msg_stream.next() => { + let msg = match msg { + // received message from WebSocket client + Some(Ok(msg)) => msg, + + // client WebSocket stream error + Some(Err(err)) => { + log::error!("{err}"); + break None; + } + + // client WebSocket stream ended + None => break None + }; + + log::debug!("msg: {msg:?}"); + + match msg { + Message::Text(s) => { + log::info!("Text message: {s}"); + } + + Message::Binary(_) => { + // drop client's binary messages + } + + Message::Close(reason) => { + break reason; + } + + Message::Ping(bytes) => { + last_heartbeat = Instant::now(); + let _ = session.pong(&bytes).await; + } + + Message::Pong(_) => { + last_heartbeat = Instant::now(); + } + + Message::Continuation(_) => { + log::warn!("no support for continuation frames"); + } + + // no-op; ignore + Message::Nop => {} + }; + } + } + }; + + // attempt to close connection gracefully + let _ = session.close(reason).await; + + log::info!("WS disconnected"); +} diff --git a/src/server/web_ui.rs b/src/server/web_ui.rs index 9da87dc..adc0462 100644 --- a/src/server/web_ui.rs +++ b/src/server/web_ui.rs @@ -1,5 +1,6 @@ use crate::app_config::AppConfig; use crate::constants::{STATE_KEY, USER_SESSION_KEY}; +use crate::server::api::ws::WsMessage; use crate::server::{HttpFailure, HttpResult}; use crate::user::{APIClient, APIClientID, User, UserConfig, UserID}; use crate::utils; @@ -9,6 +10,7 @@ use askama::Template; use ipnet::IpNet; use light_openid::primitives::OpenIDConfig; use std::str::FromStr; +use tokio::sync::broadcast; /// Static assets #[derive(rust_embed::Embed)] @@ -60,7 +62,11 @@ pub struct FormRequest { } /// Main route -pub async fn home(session: Session, form_req: Option>) -> HttpResult { +pub async fn home( + session: Session, + form_req: Option>, + tx: web::Data>, +) -> HttpResult { // Get user information, requesting authentication if information is missing let Some(user): Option = session.get(USER_SESSION_KEY)? else { // Generate auth state @@ -93,10 +99,14 @@ pub async fn home(session: Session, form_req: Option>) -> if t.len() < 3 { error_message = Some("Specified Matrix token is too short!".to_string()); } else { - // TODO : invalidate all existing connections config.matrix_token = t; config.save().await?; success_message = Some("Matrix token was successfully updated!".to_string()); + + // Invalidate all Ws connections + if let Err(e) = tx.send(WsMessage::CloseAllUserSessions(user.id.clone())) { + log::error!("Failed to send CloseAllUserSessions: {}", e); + } } } diff --git a/src/user.rs b/src/user.rs index f2c035e..25005db 100644 --- a/src/user.rs +++ b/src/user.rs @@ -19,7 +19,7 @@ pub enum UserError { MissingMatrixToken, } -#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] +#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq)] pub struct UserID(pub String); impl UserID {