Can close all user sessions when Matrix token is changed
This commit is contained in:
		
							
								
								
									
										1
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										1
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							| @@ -2072,7 +2072,6 @@ dependencies = [ | |||||||
|  "serde_json", |  "serde_json", | ||||||
|  "sha2 0.11.0-pre.4", |  "sha2 0.11.0-pre.4", | ||||||
|  "thiserror 2.0.11", |  "thiserror 2.0.11", | ||||||
|  "time", |  | ||||||
|  "tokio", |  "tokio", | ||||||
|  "urlencoding", |  "urlencoding", | ||||||
|  "uuid", |  "uuid", | ||||||
|   | |||||||
| @@ -32,5 +32,4 @@ sha2 = "0.11.0-pre.4" | |||||||
| base16ct = "0.2.0" | base16ct = "0.2.0" | ||||||
| ruma = { version = "0.12.0", features = ["client-api-c", "client-ext-client-api", "client-hyper-native-tls", "rand"] } | ruma = { version = "0.12.0", features = ["client-api-c", "client-ext-client-api", "client-hyper-native-tls", "rand"] } | ||||||
| actix-ws = "0.3.0" | actix-ws = "0.3.0" | ||||||
| tokio = { version = "1.43.0", features = ["rt", "time", "macros"] } | tokio = { version = "1.43.0", features = ["rt", "time", "macros", "rt-multi-thread"] } | ||||||
| time = "0.3.37" |  | ||||||
| @@ -4,10 +4,11 @@ use actix_session::{storage::RedisSessionStore, SessionMiddleware}; | |||||||
| use actix_web::cookie::Key; | use actix_web::cookie::Key; | ||||||
| use actix_web::{web, App, HttpServer}; | use actix_web::{web, App, HttpServer}; | ||||||
| use matrix_gateway::app_config::AppConfig; | use matrix_gateway::app_config::AppConfig; | ||||||
|  | use matrix_gateway::server::api::ws::WsMessage; | ||||||
| use matrix_gateway::server::{api, web_ui}; | use matrix_gateway::server::{api, web_ui}; | ||||||
| use matrix_gateway::user::UserConfig; | use matrix_gateway::user::UserConfig; | ||||||
|  |  | ||||||
| #[actix_web::main] | #[tokio::main] | ||||||
| async fn main() -> std::io::Result<()> { | async fn main() -> std::io::Result<()> { | ||||||
|     env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); |     env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); | ||||||
|  |  | ||||||
| @@ -21,6 +22,8 @@ async fn main() -> std::io::Result<()> { | |||||||
|         .await |         .await | ||||||
|         .expect("Failed to connect to Redis!"); |         .expect("Failed to connect to Redis!"); | ||||||
|  |  | ||||||
|  |     let (ws_tx, _) = tokio::sync::broadcast::channel::<WsMessage>(16); | ||||||
|  |  | ||||||
|     log::info!( |     log::info!( | ||||||
|         "Starting to listen on {} for {}", |         "Starting to listen on {} for {}", | ||||||
|         AppConfig::get().listen_address, |         AppConfig::get().listen_address, | ||||||
| @@ -38,6 +41,7 @@ async fn main() -> std::io::Result<()> { | |||||||
|             .app_data(web::Data::new(RemoteIPConfig { |             .app_data(web::Data::new(RemoteIPConfig { | ||||||
|                 proxy: AppConfig::get().proxy_ip.clone(), |                 proxy: AppConfig::get().proxy_ip.clone(), | ||||||
|             })) |             })) | ||||||
|  |             .app_data(web::Data::new(ws_tx.clone())) | ||||||
|             // Web configuration routes |             // Web configuration routes | ||||||
|             .route("/assets/{tail:.*}", web::get().to(web_ui::static_file)) |             .route("/assets/{tail:.*}", web::get().to(web_ui::static_file)) | ||||||
|             .route("/", web::get().to(web_ui::home)) |             .route("/", web::get().to(web_ui::home)) | ||||||
| @@ -48,8 +52,9 @@ async fn main() -> std::io::Result<()> { | |||||||
|             .route("/api", web::get().to(api::api_home)) |             .route("/api", web::get().to(api::api_home)) | ||||||
|             .route("/api", web::post().to(api::api_home)) |             .route("/api", web::post().to(api::api_home)) | ||||||
|             .route("/api/account/whoami", web::get().to(api::account::who_am_i)) |             .route("/api/account/whoami", web::get().to(api::account::who_am_i)) | ||||||
|             .service(web::resource("/api/ws").route(web::get().to(api::ws))) |             .service(web::resource("/api/ws").route(web::get().to(api::ws::ws))) | ||||||
|     }) |     }) | ||||||
|  |     .workers(4) | ||||||
|     .bind(&AppConfig::get().listen_address)? |     .bind(&AppConfig::get().listen_address)? | ||||||
|     .run() |     .run() | ||||||
|     .await |     .await | ||||||
|   | |||||||
| @@ -1,112 +1,11 @@ | |||||||
| use crate::constants::{WS_CLIENT_TIMEOUT, WS_HEARTBEAT_INTERVAL}; |  | ||||||
| use crate::extractors::client_auth::APIClientAuth; | use crate::extractors::client_auth::APIClientAuth; | ||||||
| use crate::server::HttpResult; | use crate::server::HttpResult; | ||||||
| use actix_web::dev::Payload; | use actix_web::HttpResponse; | ||||||
| use actix_web::{web, FromRequest, HttpRequest, HttpResponse}; |  | ||||||
| use actix_ws::Message; |  | ||||||
| use futures_util::future::Either; |  | ||||||
| use futures_util::{future, StreamExt}; |  | ||||||
| use std::time::Instant; |  | ||||||
| use tokio::{pin, time::interval}; |  | ||||||
|  |  | ||||||
| pub mod account; | pub mod account; | ||||||
|  | pub mod ws; | ||||||
|  |  | ||||||
| /// API Home route | /// API Home route | ||||||
| pub async fn api_home(auth: APIClientAuth) -> HttpResult { | pub async fn api_home(auth: APIClientAuth) -> HttpResult { | ||||||
|     Ok(HttpResponse::Ok().body(format!("Welcome user {}!", auth.user.user_id.0))) |     Ok(HttpResponse::Ok().body(format!("Welcome user {}!", auth.user.user_id.0))) | ||||||
| } | } | ||||||
|  |  | ||||||
| /// Main WS route |  | ||||||
| pub async fn ws(req: HttpRequest, stream: web::Payload) -> HttpResult { |  | ||||||
|     // Forcefully ignore request payload by manually extracting authentication information |  | ||||||
|     let auth = APIClientAuth::from_request(&req, &mut Payload::None).await?; |  | ||||||
|  |  | ||||||
|     let (res, session, msg_stream) = actix_ws::handle(&req, stream)?; |  | ||||||
|  |  | ||||||
|     // spawn websocket handler (and don't await it) so that the response is returned immediately |  | ||||||
|     actix_web::rt::spawn(ws_handler(session, msg_stream)); |  | ||||||
|  |  | ||||||
|     Ok(res) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| pub async fn ws_handler(mut session: actix_ws::Session, mut msg_stream: actix_ws::MessageStream) { |  | ||||||
|     log::info!("WS connected"); |  | ||||||
|  |  | ||||||
|     let mut last_heartbeat = Instant::now(); |  | ||||||
|     let mut interval = interval(WS_HEARTBEAT_INTERVAL); |  | ||||||
|  |  | ||||||
|     let reason = loop { |  | ||||||
|         // create "next client timeout check" future |  | ||||||
|         let tick = interval.tick(); |  | ||||||
|         // required for select() |  | ||||||
|         pin!(tick); |  | ||||||
|  |  | ||||||
|         // waits for either `msg_stream` to receive a message from the client or the heartbeat |  | ||||||
|         // interval timer to tick, yielding the value of whichever one is ready first |  | ||||||
|         match future::select(msg_stream.next(), tick).await { |  | ||||||
|             // received message from WebSocket client |  | ||||||
|             Either::Left((Some(Ok(msg)), _)) => { |  | ||||||
|                 log::debug!("msg: {msg:?}"); |  | ||||||
|  |  | ||||||
|                 match msg { |  | ||||||
|                     Message::Text(text) => { |  | ||||||
|                         session.text(text).await.unwrap(); |  | ||||||
|                     } |  | ||||||
|  |  | ||||||
|                     Message::Binary(bin) => { |  | ||||||
|                         session.binary(bin).await.unwrap(); |  | ||||||
|                     } |  | ||||||
|  |  | ||||||
|                     Message::Close(reason) => { |  | ||||||
|                         break reason; |  | ||||||
|                     } |  | ||||||
|  |  | ||||||
|                     Message::Ping(bytes) => { |  | ||||||
|                         last_heartbeat = Instant::now(); |  | ||||||
|                         let _ = session.pong(&bytes).await; |  | ||||||
|                     } |  | ||||||
|  |  | ||||||
|                     Message::Pong(_) => { |  | ||||||
|                         last_heartbeat = Instant::now(); |  | ||||||
|                     } |  | ||||||
|  |  | ||||||
|                     Message::Continuation(_) => { |  | ||||||
|                         log::warn!("no support for continuation frames"); |  | ||||||
|                     } |  | ||||||
|  |  | ||||||
|                     // no-op; ignore |  | ||||||
|                     Message::Nop => {} |  | ||||||
|                 }; |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             // client WebSocket stream error |  | ||||||
|             Either::Left((Some(Err(err)), _)) => { |  | ||||||
|                 log::error!("{}", err); |  | ||||||
|                 break None; |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             // client WebSocket stream ended |  | ||||||
|             Either::Left((None, _)) => break None, |  | ||||||
|  |  | ||||||
|             // heartbeat interval ticked |  | ||||||
|             Either::Right((_inst, _)) => { |  | ||||||
|                 // if no heartbeat ping/pong received recently, close the connection |  | ||||||
|                 if Instant::now().duration_since(last_heartbeat) > WS_CLIENT_TIMEOUT { |  | ||||||
|                     log::info!( |  | ||||||
|                         "client has not sent heartbeat in over {WS_CLIENT_TIMEOUT:?}; disconnecting" |  | ||||||
|                     ); |  | ||||||
|  |  | ||||||
|                     break None; |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 // send heartbeat ping |  | ||||||
|                 let _ = session.ping(b"").await; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     // attempt to close connection gracefully |  | ||||||
|     let _ = session.close(reason).await; |  | ||||||
|  |  | ||||||
|     log::info!("WS disconnected"); |  | ||||||
| } |  | ||||||
|   | |||||||
							
								
								
									
										149
									
								
								src/server/api/ws.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										149
									
								
								src/server/api/ws.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,149 @@ | |||||||
|  | use crate::constants::{WS_CLIENT_TIMEOUT, WS_HEARTBEAT_INTERVAL}; | ||||||
|  | use crate::extractors::client_auth::APIClientAuth; | ||||||
|  | use crate::server::HttpResult; | ||||||
|  | use crate::user::{APIClientID, UserID}; | ||||||
|  | use actix_web::dev::Payload; | ||||||
|  | use actix_web::{web, FromRequest, HttpRequest}; | ||||||
|  | use actix_ws::Message; | ||||||
|  | use futures_util::StreamExt; | ||||||
|  | use std::time::Instant; | ||||||
|  | use tokio::select; | ||||||
|  | use tokio::sync::broadcast; | ||||||
|  | use tokio::sync::broadcast::Receiver; | ||||||
|  | use tokio::time::interval; | ||||||
|  |  | ||||||
|  | /// WebSocket message | ||||||
|  | #[derive(Debug, Clone)] | ||||||
|  | pub enum WsMessage { | ||||||
|  |     /// Request to close the session of a specific client | ||||||
|  |     CloseClientSession(APIClientID), | ||||||
|  |     /// Close all the sessions of a given user | ||||||
|  |     CloseAllUserSessions(UserID), | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /// Main WS route | ||||||
|  | pub async fn ws( | ||||||
|  |     req: HttpRequest, | ||||||
|  |     stream: web::Payload, | ||||||
|  |     tx: web::Data<broadcast::Sender<WsMessage>>, | ||||||
|  | ) -> HttpResult { | ||||||
|  |     // Forcefully ignore request payload by manually extracting authentication information | ||||||
|  |     let auth = APIClientAuth::from_request(&req, &mut Payload::None).await?; | ||||||
|  |  | ||||||
|  |     let (res, session, msg_stream) = actix_ws::handle(&req, stream)?; | ||||||
|  |  | ||||||
|  |     let rx = tx.subscribe(); | ||||||
|  |  | ||||||
|  |     // spawn websocket handler (and don't await it) so that the response is returned immediately | ||||||
|  |     actix_web::rt::spawn(ws_handler(session, msg_stream, auth, rx)); | ||||||
|  |  | ||||||
|  |     Ok(res) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub async fn ws_handler( | ||||||
|  |     mut session: actix_ws::Session, | ||||||
|  |     mut msg_stream: actix_ws::MessageStream, | ||||||
|  |     auth: APIClientAuth, | ||||||
|  |     mut rx: Receiver<WsMessage>, | ||||||
|  | ) { | ||||||
|  |     log::info!("WS connected"); | ||||||
|  |  | ||||||
|  |     let mut last_heartbeat = Instant::now(); | ||||||
|  |     let mut interval = interval(WS_HEARTBEAT_INTERVAL); | ||||||
|  |  | ||||||
|  |     let reason = loop { | ||||||
|  |         // waits for either `msg_stream` to receive a message from the client, the broadcast channel | ||||||
|  |         // to send a message, or the heartbeat interval timer to tick, yielding the value of | ||||||
|  |         // whichever one is ready first | ||||||
|  |         select! { | ||||||
|  |             ws_msg = rx.recv() => { | ||||||
|  |                 let msg = match ws_msg { | ||||||
|  |                     Ok(msg) => msg, | ||||||
|  |                     Err(broadcast::error::RecvError::Closed) => break None, | ||||||
|  |                     Err(broadcast::error::RecvError::Lagged(_)) => continue, | ||||||
|  |                 }; | ||||||
|  |  | ||||||
|  |                 match msg { | ||||||
|  |                     WsMessage::CloseClientSession(_) => todo!(), | ||||||
|  |                     WsMessage::CloseAllUserSessions(userid) => { | ||||||
|  |                         if userid == auth.user.user_id { | ||||||
|  |                             log::info!( | ||||||
|  |                                 "closing WS session of user {userid:?} as requested" | ||||||
|  |                             ); | ||||||
|  |                             break None; | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|  |                 }; | ||||||
|  |  | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             // heartbeat interval ticked | ||||||
|  |             _tick = interval.tick() => { | ||||||
|  |                 // if no heartbeat ping/pong received recently, close the connection | ||||||
|  |                 if Instant::now().duration_since(last_heartbeat) > WS_CLIENT_TIMEOUT { | ||||||
|  |                     log::info!( | ||||||
|  |                         "client has not sent heartbeat in over {WS_CLIENT_TIMEOUT:?}; disconnecting" | ||||||
|  |                     ); | ||||||
|  |  | ||||||
|  |                     break None; | ||||||
|  |                 } | ||||||
|  |  | ||||||
|  |                 // send heartbeat ping | ||||||
|  |                 let _ = session.ping(b"").await; | ||||||
|  |             }, | ||||||
|  |  | ||||||
|  |             msg = msg_stream.next() => { | ||||||
|  |                 let msg = match msg { | ||||||
|  |                     // received message from WebSocket client | ||||||
|  |                     Some(Ok(msg)) => msg, | ||||||
|  |  | ||||||
|  |                     // client WebSocket stream error | ||||||
|  |                     Some(Err(err)) => { | ||||||
|  |                         log::error!("{err}"); | ||||||
|  |                         break None; | ||||||
|  |                     } | ||||||
|  |  | ||||||
|  |                     // client WebSocket stream ended | ||||||
|  |                     None => break None | ||||||
|  |                 }; | ||||||
|  |  | ||||||
|  |                 log::debug!("msg: {msg:?}"); | ||||||
|  |  | ||||||
|  |                 match msg { | ||||||
|  |                     Message::Text(s) => { | ||||||
|  |                         log::info!("Text message: {s}"); | ||||||
|  |                     } | ||||||
|  |  | ||||||
|  |                     Message::Binary(_) => { | ||||||
|  |                         // drop client's binary messages | ||||||
|  |                     } | ||||||
|  |  | ||||||
|  |                     Message::Close(reason) => { | ||||||
|  |                         break reason; | ||||||
|  |                     } | ||||||
|  |  | ||||||
|  |                     Message::Ping(bytes) => { | ||||||
|  |                         last_heartbeat = Instant::now(); | ||||||
|  |                         let _ = session.pong(&bytes).await; | ||||||
|  |                     } | ||||||
|  |  | ||||||
|  |                     Message::Pong(_) => { | ||||||
|  |                         last_heartbeat = Instant::now(); | ||||||
|  |                     } | ||||||
|  |  | ||||||
|  |                     Message::Continuation(_) => { | ||||||
|  |                         log::warn!("no support for continuation frames"); | ||||||
|  |                     } | ||||||
|  |  | ||||||
|  |                     // no-op; ignore | ||||||
|  |                     Message::Nop => {} | ||||||
|  |                 }; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     // attempt to close connection gracefully | ||||||
|  |     let _ = session.close(reason).await; | ||||||
|  |  | ||||||
|  |     log::info!("WS disconnected"); | ||||||
|  | } | ||||||
| @@ -1,5 +1,6 @@ | |||||||
| use crate::app_config::AppConfig; | use crate::app_config::AppConfig; | ||||||
| use crate::constants::{STATE_KEY, USER_SESSION_KEY}; | use crate::constants::{STATE_KEY, USER_SESSION_KEY}; | ||||||
|  | use crate::server::api::ws::WsMessage; | ||||||
| use crate::server::{HttpFailure, HttpResult}; | use crate::server::{HttpFailure, HttpResult}; | ||||||
| use crate::user::{APIClient, APIClientID, User, UserConfig, UserID}; | use crate::user::{APIClient, APIClientID, User, UserConfig, UserID}; | ||||||
| use crate::utils; | use crate::utils; | ||||||
| @@ -9,6 +10,7 @@ use askama::Template; | |||||||
| use ipnet::IpNet; | use ipnet::IpNet; | ||||||
| use light_openid::primitives::OpenIDConfig; | use light_openid::primitives::OpenIDConfig; | ||||||
| use std::str::FromStr; | use std::str::FromStr; | ||||||
|  | use tokio::sync::broadcast; | ||||||
|  |  | ||||||
| /// Static assets | /// Static assets | ||||||
| #[derive(rust_embed::Embed)] | #[derive(rust_embed::Embed)] | ||||||
| @@ -60,7 +62,11 @@ pub struct FormRequest { | |||||||
| } | } | ||||||
|  |  | ||||||
| /// Main route | /// Main route | ||||||
| pub async fn home(session: Session, form_req: Option<web::Form<FormRequest>>) -> HttpResult { | pub async fn home( | ||||||
|  |     session: Session, | ||||||
|  |     form_req: Option<web::Form<FormRequest>>, | ||||||
|  |     tx: web::Data<broadcast::Sender<WsMessage>>, | ||||||
|  | ) -> HttpResult { | ||||||
|     // Get user information, requesting authentication if information is missing |     // Get user information, requesting authentication if information is missing | ||||||
|     let Some(user): Option<User> = session.get(USER_SESSION_KEY)? else { |     let Some(user): Option<User> = session.get(USER_SESSION_KEY)? else { | ||||||
|         // Generate auth state |         // Generate auth state | ||||||
| @@ -93,10 +99,14 @@ pub async fn home(session: Session, form_req: Option<web::Form<FormRequest>>) -> | |||||||
|             if t.len() < 3 { |             if t.len() < 3 { | ||||||
|                 error_message = Some("Specified Matrix token is too short!".to_string()); |                 error_message = Some("Specified Matrix token is too short!".to_string()); | ||||||
|             } else { |             } else { | ||||||
|                 // TODO : invalidate all existing connections |  | ||||||
|                 config.matrix_token = t; |                 config.matrix_token = t; | ||||||
|                 config.save().await?; |                 config.save().await?; | ||||||
|                 success_message = Some("Matrix token was successfully updated!".to_string()); |                 success_message = Some("Matrix token was successfully updated!".to_string()); | ||||||
|  |  | ||||||
|  |                 // Invalidate all Ws connections | ||||||
|  |                 if let Err(e) = tx.send(WsMessage::CloseAllUserSessions(user.id.clone())) { | ||||||
|  |                     log::error!("Failed to send CloseAllUserSessions: {}", e); | ||||||
|  |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -19,7 +19,7 @@ pub enum UserError { | |||||||
|     MissingMatrixToken, |     MissingMatrixToken, | ||||||
| } | } | ||||||
|  |  | ||||||
| #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] | #[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq)] | ||||||
| pub struct UserID(pub String); | pub struct UserID(pub String); | ||||||
|  |  | ||||||
| impl UserID { | impl UserID { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user