Can close all user sessions when Matrix token is changed

This commit is contained in:
Pierre HUBERT 2025-02-06 22:39:53 +01:00
parent 4ff72e073e
commit 558d5cda3f
7 changed files with 172 additions and 111 deletions

1
Cargo.lock generated
View File

@ -2072,7 +2072,6 @@ dependencies = [
"serde_json", "serde_json",
"sha2 0.11.0-pre.4", "sha2 0.11.0-pre.4",
"thiserror 2.0.11", "thiserror 2.0.11",
"time",
"tokio", "tokio",
"urlencoding", "urlencoding",
"uuid", "uuid",

View File

@ -32,5 +32,4 @@ sha2 = "0.11.0-pre.4"
base16ct = "0.2.0" base16ct = "0.2.0"
ruma = { version = "0.12.0", features = ["client-api-c", "client-ext-client-api", "client-hyper-native-tls", "rand"] } ruma = { version = "0.12.0", features = ["client-api-c", "client-ext-client-api", "client-hyper-native-tls", "rand"] }
actix-ws = "0.3.0" actix-ws = "0.3.0"
tokio = { version = "1.43.0", features = ["rt", "time", "macros"] } tokio = { version = "1.43.0", features = ["rt", "time", "macros", "rt-multi-thread"] }
time = "0.3.37"

View File

@ -4,10 +4,11 @@ use actix_session::{storage::RedisSessionStore, SessionMiddleware};
use actix_web::cookie::Key; use actix_web::cookie::Key;
use actix_web::{web, App, HttpServer}; use actix_web::{web, App, HttpServer};
use matrix_gateway::app_config::AppConfig; use matrix_gateway::app_config::AppConfig;
use matrix_gateway::server::api::ws::WsMessage;
use matrix_gateway::server::{api, web_ui}; use matrix_gateway::server::{api, web_ui};
use matrix_gateway::user::UserConfig; use matrix_gateway::user::UserConfig;
#[actix_web::main] #[tokio::main]
async fn main() -> std::io::Result<()> { async fn main() -> std::io::Result<()> {
env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); env_logger::init_from_env(env_logger::Env::new().default_filter_or("info"));
@ -21,6 +22,8 @@ async fn main() -> std::io::Result<()> {
.await .await
.expect("Failed to connect to Redis!"); .expect("Failed to connect to Redis!");
let (ws_tx, _) = tokio::sync::broadcast::channel::<WsMessage>(16);
log::info!( log::info!(
"Starting to listen on {} for {}", "Starting to listen on {} for {}",
AppConfig::get().listen_address, AppConfig::get().listen_address,
@ -38,6 +41,7 @@ async fn main() -> std::io::Result<()> {
.app_data(web::Data::new(RemoteIPConfig { .app_data(web::Data::new(RemoteIPConfig {
proxy: AppConfig::get().proxy_ip.clone(), proxy: AppConfig::get().proxy_ip.clone(),
})) }))
.app_data(web::Data::new(ws_tx.clone()))
// Web configuration routes // Web configuration routes
.route("/assets/{tail:.*}", web::get().to(web_ui::static_file)) .route("/assets/{tail:.*}", web::get().to(web_ui::static_file))
.route("/", web::get().to(web_ui::home)) .route("/", web::get().to(web_ui::home))
@ -48,8 +52,9 @@ async fn main() -> std::io::Result<()> {
.route("/api", web::get().to(api::api_home)) .route("/api", web::get().to(api::api_home))
.route("/api", web::post().to(api::api_home)) .route("/api", web::post().to(api::api_home))
.route("/api/account/whoami", web::get().to(api::account::who_am_i)) .route("/api/account/whoami", web::get().to(api::account::who_am_i))
.service(web::resource("/api/ws").route(web::get().to(api::ws))) .service(web::resource("/api/ws").route(web::get().to(api::ws::ws)))
}) })
.workers(4)
.bind(&AppConfig::get().listen_address)? .bind(&AppConfig::get().listen_address)?
.run() .run()
.await .await

View File

@ -1,112 +1,11 @@
use crate::constants::{WS_CLIENT_TIMEOUT, WS_HEARTBEAT_INTERVAL};
use crate::extractors::client_auth::APIClientAuth; use crate::extractors::client_auth::APIClientAuth;
use crate::server::HttpResult; use crate::server::HttpResult;
use actix_web::dev::Payload; use actix_web::HttpResponse;
use actix_web::{web, FromRequest, HttpRequest, HttpResponse};
use actix_ws::Message;
use futures_util::future::Either;
use futures_util::{future, StreamExt};
use std::time::Instant;
use tokio::{pin, time::interval};
pub mod account; pub mod account;
pub mod ws;
/// API Home route /// API Home route
pub async fn api_home(auth: APIClientAuth) -> HttpResult { pub async fn api_home(auth: APIClientAuth) -> HttpResult {
Ok(HttpResponse::Ok().body(format!("Welcome user {}!", auth.user.user_id.0))) Ok(HttpResponse::Ok().body(format!("Welcome user {}!", auth.user.user_id.0)))
} }
/// Main WS route
pub async fn ws(req: HttpRequest, stream: web::Payload) -> HttpResult {
// Forcefully ignore request payload by manually extracting authentication information
let auth = APIClientAuth::from_request(&req, &mut Payload::None).await?;
let (res, session, msg_stream) = actix_ws::handle(&req, stream)?;
// spawn websocket handler (and don't await it) so that the response is returned immediately
actix_web::rt::spawn(ws_handler(session, msg_stream));
Ok(res)
}
pub async fn ws_handler(mut session: actix_ws::Session, mut msg_stream: actix_ws::MessageStream) {
log::info!("WS connected");
let mut last_heartbeat = Instant::now();
let mut interval = interval(WS_HEARTBEAT_INTERVAL);
let reason = loop {
// create "next client timeout check" future
let tick = interval.tick();
// required for select()
pin!(tick);
// waits for either `msg_stream` to receive a message from the client or the heartbeat
// interval timer to tick, yielding the value of whichever one is ready first
match future::select(msg_stream.next(), tick).await {
// received message from WebSocket client
Either::Left((Some(Ok(msg)), _)) => {
log::debug!("msg: {msg:?}");
match msg {
Message::Text(text) => {
session.text(text).await.unwrap();
}
Message::Binary(bin) => {
session.binary(bin).await.unwrap();
}
Message::Close(reason) => {
break reason;
}
Message::Ping(bytes) => {
last_heartbeat = Instant::now();
let _ = session.pong(&bytes).await;
}
Message::Pong(_) => {
last_heartbeat = Instant::now();
}
Message::Continuation(_) => {
log::warn!("no support for continuation frames");
}
// no-op; ignore
Message::Nop => {}
};
}
// client WebSocket stream error
Either::Left((Some(Err(err)), _)) => {
log::error!("{}", err);
break None;
}
// client WebSocket stream ended
Either::Left((None, _)) => break None,
// heartbeat interval ticked
Either::Right((_inst, _)) => {
// if no heartbeat ping/pong received recently, close the connection
if Instant::now().duration_since(last_heartbeat) > WS_CLIENT_TIMEOUT {
log::info!(
"client has not sent heartbeat in over {WS_CLIENT_TIMEOUT:?}; disconnecting"
);
break None;
}
// send heartbeat ping
let _ = session.ping(b"").await;
}
}
};
// attempt to close connection gracefully
let _ = session.close(reason).await;
log::info!("WS disconnected");
}

149
src/server/api/ws.rs Normal file
View File

@ -0,0 +1,149 @@
use crate::constants::{WS_CLIENT_TIMEOUT, WS_HEARTBEAT_INTERVAL};
use crate::extractors::client_auth::APIClientAuth;
use crate::server::HttpResult;
use crate::user::{APIClientID, UserID};
use actix_web::dev::Payload;
use actix_web::{web, FromRequest, HttpRequest};
use actix_ws::Message;
use futures_util::StreamExt;
use std::time::Instant;
use tokio::select;
use tokio::sync::broadcast;
use tokio::sync::broadcast::Receiver;
use tokio::time::interval;
/// WebSocket message
#[derive(Debug, Clone)]
pub enum WsMessage {
/// Request to close the session of a specific client
CloseClientSession(APIClientID),
/// Close all the sessions of a given user
CloseAllUserSessions(UserID),
}
/// Main WS route
pub async fn ws(
req: HttpRequest,
stream: web::Payload,
tx: web::Data<broadcast::Sender<WsMessage>>,
) -> HttpResult {
// Forcefully ignore request payload by manually extracting authentication information
let auth = APIClientAuth::from_request(&req, &mut Payload::None).await?;
let (res, session, msg_stream) = actix_ws::handle(&req, stream)?;
let rx = tx.subscribe();
// spawn websocket handler (and don't await it) so that the response is returned immediately
actix_web::rt::spawn(ws_handler(session, msg_stream, auth, rx));
Ok(res)
}
pub async fn ws_handler(
mut session: actix_ws::Session,
mut msg_stream: actix_ws::MessageStream,
auth: APIClientAuth,
mut rx: Receiver<WsMessage>,
) {
log::info!("WS connected");
let mut last_heartbeat = Instant::now();
let mut interval = interval(WS_HEARTBEAT_INTERVAL);
let reason = loop {
// waits for either `msg_stream` to receive a message from the client, the broadcast channel
// to send a message, or the heartbeat interval timer to tick, yielding the value of
// whichever one is ready first
select! {
ws_msg = rx.recv() => {
let msg = match ws_msg {
Ok(msg) => msg,
Err(broadcast::error::RecvError::Closed) => break None,
Err(broadcast::error::RecvError::Lagged(_)) => continue,
};
match msg {
WsMessage::CloseClientSession(_) => todo!(),
WsMessage::CloseAllUserSessions(userid) => {
if userid == auth.user.user_id {
log::info!(
"closing WS session of user {userid:?} as requested"
);
break None;
}
}
};
}
// heartbeat interval ticked
_tick = interval.tick() => {
// if no heartbeat ping/pong received recently, close the connection
if Instant::now().duration_since(last_heartbeat) > WS_CLIENT_TIMEOUT {
log::info!(
"client has not sent heartbeat in over {WS_CLIENT_TIMEOUT:?}; disconnecting"
);
break None;
}
// send heartbeat ping
let _ = session.ping(b"").await;
},
msg = msg_stream.next() => {
let msg = match msg {
// received message from WebSocket client
Some(Ok(msg)) => msg,
// client WebSocket stream error
Some(Err(err)) => {
log::error!("{err}");
break None;
}
// client WebSocket stream ended
None => break None
};
log::debug!("msg: {msg:?}");
match msg {
Message::Text(s) => {
log::info!("Text message: {s}");
}
Message::Binary(_) => {
// drop client's binary messages
}
Message::Close(reason) => {
break reason;
}
Message::Ping(bytes) => {
last_heartbeat = Instant::now();
let _ = session.pong(&bytes).await;
}
Message::Pong(_) => {
last_heartbeat = Instant::now();
}
Message::Continuation(_) => {
log::warn!("no support for continuation frames");
}
// no-op; ignore
Message::Nop => {}
};
}
}
};
// attempt to close connection gracefully
let _ = session.close(reason).await;
log::info!("WS disconnected");
}

View File

@ -1,5 +1,6 @@
use crate::app_config::AppConfig; use crate::app_config::AppConfig;
use crate::constants::{STATE_KEY, USER_SESSION_KEY}; use crate::constants::{STATE_KEY, USER_SESSION_KEY};
use crate::server::api::ws::WsMessage;
use crate::server::{HttpFailure, HttpResult}; use crate::server::{HttpFailure, HttpResult};
use crate::user::{APIClient, APIClientID, User, UserConfig, UserID}; use crate::user::{APIClient, APIClientID, User, UserConfig, UserID};
use crate::utils; use crate::utils;
@ -9,6 +10,7 @@ use askama::Template;
use ipnet::IpNet; use ipnet::IpNet;
use light_openid::primitives::OpenIDConfig; use light_openid::primitives::OpenIDConfig;
use std::str::FromStr; use std::str::FromStr;
use tokio::sync::broadcast;
/// Static assets /// Static assets
#[derive(rust_embed::Embed)] #[derive(rust_embed::Embed)]
@ -60,7 +62,11 @@ pub struct FormRequest {
} }
/// Main route /// Main route
pub async fn home(session: Session, form_req: Option<web::Form<FormRequest>>) -> HttpResult { pub async fn home(
session: Session,
form_req: Option<web::Form<FormRequest>>,
tx: web::Data<broadcast::Sender<WsMessage>>,
) -> HttpResult {
// Get user information, requesting authentication if information is missing // Get user information, requesting authentication if information is missing
let Some(user): Option<User> = session.get(USER_SESSION_KEY)? else { let Some(user): Option<User> = session.get(USER_SESSION_KEY)? else {
// Generate auth state // Generate auth state
@ -93,10 +99,14 @@ pub async fn home(session: Session, form_req: Option<web::Form<FormRequest>>) ->
if t.len() < 3 { if t.len() < 3 {
error_message = Some("Specified Matrix token is too short!".to_string()); error_message = Some("Specified Matrix token is too short!".to_string());
} else { } else {
// TODO : invalidate all existing connections
config.matrix_token = t; config.matrix_token = t;
config.save().await?; config.save().await?;
success_message = Some("Matrix token was successfully updated!".to_string()); success_message = Some("Matrix token was successfully updated!".to_string());
// Invalidate all Ws connections
if let Err(e) = tx.send(WsMessage::CloseAllUserSessions(user.id.clone())) {
log::error!("Failed to send CloseAllUserSessions: {}", e);
}
} }
} }

View File

@ -19,7 +19,7 @@ pub enum UserError {
MissingMatrixToken, MissingMatrixToken,
} }
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] #[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq)]
pub struct UserID(pub String); pub struct UserID(pub String);
impl UserID { impl UserID {