Compare commits
2 Commits
c573d2f74a
...
558d5cda3f
| Author | SHA1 | Date | |
|---|---|---|---|
| 558d5cda3f | |||
| 4ff72e073e |
28
Cargo.lock
generated
28
Cargo.lock
generated
@@ -214,6 +214,20 @@ dependencies = [
|
|||||||
"syn",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "actix-ws"
|
||||||
|
version = "0.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a3a1fb4f9f2794b0aadaf2ba5f14a6f034c7e86957b458c506a8cb75953f2d99"
|
||||||
|
dependencies = [
|
||||||
|
"actix-codec",
|
||||||
|
"actix-http",
|
||||||
|
"actix-web",
|
||||||
|
"bytestring",
|
||||||
|
"futures-core",
|
||||||
|
"tokio",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "addr2line"
|
name = "addr2line"
|
||||||
version = "0.24.2"
|
version = "0.24.2"
|
||||||
@@ -2035,6 +2049,7 @@ dependencies = [
|
|||||||
"actix-remote-ip",
|
"actix-remote-ip",
|
||||||
"actix-session",
|
"actix-session",
|
||||||
"actix-web",
|
"actix-web",
|
||||||
|
"actix-ws",
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"askama",
|
"askama",
|
||||||
"base16ct",
|
"base16ct",
|
||||||
@@ -2057,6 +2072,7 @@ dependencies = [
|
|||||||
"serde_json",
|
"serde_json",
|
||||||
"sha2 0.11.0-pre.4",
|
"sha2 0.11.0-pre.4",
|
||||||
"thiserror 2.0.11",
|
"thiserror 2.0.11",
|
||||||
|
"tokio",
|
||||||
"urlencoding",
|
"urlencoding",
|
||||||
"uuid",
|
"uuid",
|
||||||
]
|
]
|
||||||
@@ -3490,9 +3506,21 @@ dependencies = [
|
|||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"signal-hook-registry",
|
"signal-hook-registry",
|
||||||
"socket2",
|
"socket2",
|
||||||
|
"tokio-macros",
|
||||||
"windows-sys 0.52.0",
|
"windows-sys 0.52.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokio-macros"
|
||||||
|
version = "2.5.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tokio-native-tls"
|
name = "tokio-native-tls"
|
||||||
version = "0.3.1"
|
version = "0.3.1"
|
||||||
|
|||||||
@@ -24,10 +24,12 @@ urlencoding = "2.1.3"
|
|||||||
uuid = { version = "1.12.1", features = ["v4", "serde"] }
|
uuid = { version = "1.12.1", features = ["v4", "serde"] }
|
||||||
ipnet = { version = "2.11.0", features = ["serde"] }
|
ipnet = { version = "2.11.0", features = ["serde"] }
|
||||||
chrono = "0.4.39"
|
chrono = "0.4.39"
|
||||||
futures-util = "0.3.31"
|
futures-util = { version = "0.3.31", features = ["sink"] }
|
||||||
jwt-simple = { version = "0.12.11", default-features=false, features=["pure-rust"] }
|
jwt-simple = { version = "0.12.11", default-features = false, features = ["pure-rust"] }
|
||||||
actix-remote-ip = "0.1.0"
|
actix-remote-ip = "0.1.0"
|
||||||
bytes = "1.9.0"
|
bytes = "1.9.0"
|
||||||
sha2 = "0.11.0-pre.4"
|
sha2 = "0.11.0-pre.4"
|
||||||
base16ct = "0.2.0"
|
base16ct = "0.2.0"
|
||||||
ruma = { version = "0.12.0", features = ["client-api-c", "client-ext-client-api", "client-hyper-native-tls", "rand"] }
|
ruma = { version = "0.12.0", features = ["client-api-c", "client-ext-client-api", "client-hyper-native-tls", "rand"] }
|
||||||
|
actix-ws = "0.3.0"
|
||||||
|
tokio = { version = "1.43.0", features = ["rt", "time", "macros", "rt-multi-thread"] }
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
/// Session key for OpenID login state
|
/// Session key for OpenID login state
|
||||||
pub const STATE_KEY: &str = "oidc-state";
|
pub const STATE_KEY: &str = "oidc-state";
|
||||||
|
|
||||||
@@ -6,3 +8,11 @@ pub const USER_SESSION_KEY: &str = "user";
|
|||||||
|
|
||||||
/// Token length
|
/// Token length
|
||||||
pub const TOKEN_LEN: usize = 20;
|
pub const TOKEN_LEN: usize = 20;
|
||||||
|
|
||||||
|
/// How often heartbeat pings are sent.
|
||||||
|
///
|
||||||
|
/// Should be half (or less) of the acceptable client timeout.
|
||||||
|
pub const WS_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
|
||||||
|
|
||||||
|
/// How long before lack of client response causes a timeout.
|
||||||
|
pub const WS_CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
|
||||||
|
|||||||
@@ -4,10 +4,11 @@ use actix_session::{storage::RedisSessionStore, SessionMiddleware};
|
|||||||
use actix_web::cookie::Key;
|
use actix_web::cookie::Key;
|
||||||
use actix_web::{web, App, HttpServer};
|
use actix_web::{web, App, HttpServer};
|
||||||
use matrix_gateway::app_config::AppConfig;
|
use matrix_gateway::app_config::AppConfig;
|
||||||
|
use matrix_gateway::server::api::ws::WsMessage;
|
||||||
use matrix_gateway::server::{api, web_ui};
|
use matrix_gateway::server::{api, web_ui};
|
||||||
use matrix_gateway::user::UserConfig;
|
use matrix_gateway::user::UserConfig;
|
||||||
|
|
||||||
#[actix_web::main]
|
#[tokio::main]
|
||||||
async fn main() -> std::io::Result<()> {
|
async fn main() -> std::io::Result<()> {
|
||||||
env_logger::init_from_env(env_logger::Env::new().default_filter_or("info"));
|
env_logger::init_from_env(env_logger::Env::new().default_filter_or("info"));
|
||||||
|
|
||||||
@@ -21,6 +22,8 @@ async fn main() -> std::io::Result<()> {
|
|||||||
.await
|
.await
|
||||||
.expect("Failed to connect to Redis!");
|
.expect("Failed to connect to Redis!");
|
||||||
|
|
||||||
|
let (ws_tx, _) = tokio::sync::broadcast::channel::<WsMessage>(16);
|
||||||
|
|
||||||
log::info!(
|
log::info!(
|
||||||
"Starting to listen on {} for {}",
|
"Starting to listen on {} for {}",
|
||||||
AppConfig::get().listen_address,
|
AppConfig::get().listen_address,
|
||||||
@@ -38,6 +41,7 @@ async fn main() -> std::io::Result<()> {
|
|||||||
.app_data(web::Data::new(RemoteIPConfig {
|
.app_data(web::Data::new(RemoteIPConfig {
|
||||||
proxy: AppConfig::get().proxy_ip.clone(),
|
proxy: AppConfig::get().proxy_ip.clone(),
|
||||||
}))
|
}))
|
||||||
|
.app_data(web::Data::new(ws_tx.clone()))
|
||||||
// Web configuration routes
|
// Web configuration routes
|
||||||
.route("/assets/{tail:.*}", web::get().to(web_ui::static_file))
|
.route("/assets/{tail:.*}", web::get().to(web_ui::static_file))
|
||||||
.route("/", web::get().to(web_ui::home))
|
.route("/", web::get().to(web_ui::home))
|
||||||
@@ -48,7 +52,9 @@ async fn main() -> std::io::Result<()> {
|
|||||||
.route("/api", web::get().to(api::api_home))
|
.route("/api", web::get().to(api::api_home))
|
||||||
.route("/api", web::post().to(api::api_home))
|
.route("/api", web::post().to(api::api_home))
|
||||||
.route("/api/account/whoami", web::get().to(api::account::who_am_i))
|
.route("/api/account/whoami", web::get().to(api::account::who_am_i))
|
||||||
|
.service(web::resource("/api/ws").route(web::get().to(api::ws::ws)))
|
||||||
})
|
})
|
||||||
|
.workers(4)
|
||||||
.bind(&AppConfig::get().listen_address)?
|
.bind(&AppConfig::get().listen_address)?
|
||||||
.run()
|
.run()
|
||||||
.await
|
.await
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ use crate::server::HttpResult;
|
|||||||
use actix_web::HttpResponse;
|
use actix_web::HttpResponse;
|
||||||
|
|
||||||
pub mod account;
|
pub mod account;
|
||||||
|
pub mod ws;
|
||||||
|
|
||||||
/// API Home route
|
/// API Home route
|
||||||
pub async fn api_home(auth: APIClientAuth) -> HttpResult {
|
pub async fn api_home(auth: APIClientAuth) -> HttpResult {
|
||||||
|
|||||||
149
src/server/api/ws.rs
Normal file
149
src/server/api/ws.rs
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
use crate::constants::{WS_CLIENT_TIMEOUT, WS_HEARTBEAT_INTERVAL};
|
||||||
|
use crate::extractors::client_auth::APIClientAuth;
|
||||||
|
use crate::server::HttpResult;
|
||||||
|
use crate::user::{APIClientID, UserID};
|
||||||
|
use actix_web::dev::Payload;
|
||||||
|
use actix_web::{web, FromRequest, HttpRequest};
|
||||||
|
use actix_ws::Message;
|
||||||
|
use futures_util::StreamExt;
|
||||||
|
use std::time::Instant;
|
||||||
|
use tokio::select;
|
||||||
|
use tokio::sync::broadcast;
|
||||||
|
use tokio::sync::broadcast::Receiver;
|
||||||
|
use tokio::time::interval;
|
||||||
|
|
||||||
|
/// WebSocket message
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum WsMessage {
|
||||||
|
/// Request to close the session of a specific client
|
||||||
|
CloseClientSession(APIClientID),
|
||||||
|
/// Close all the sessions of a given user
|
||||||
|
CloseAllUserSessions(UserID),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Main WS route
|
||||||
|
pub async fn ws(
|
||||||
|
req: HttpRequest,
|
||||||
|
stream: web::Payload,
|
||||||
|
tx: web::Data<broadcast::Sender<WsMessage>>,
|
||||||
|
) -> HttpResult {
|
||||||
|
// Forcefully ignore request payload by manually extracting authentication information
|
||||||
|
let auth = APIClientAuth::from_request(&req, &mut Payload::None).await?;
|
||||||
|
|
||||||
|
let (res, session, msg_stream) = actix_ws::handle(&req, stream)?;
|
||||||
|
|
||||||
|
let rx = tx.subscribe();
|
||||||
|
|
||||||
|
// spawn websocket handler (and don't await it) so that the response is returned immediately
|
||||||
|
actix_web::rt::spawn(ws_handler(session, msg_stream, auth, rx));
|
||||||
|
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn ws_handler(
|
||||||
|
mut session: actix_ws::Session,
|
||||||
|
mut msg_stream: actix_ws::MessageStream,
|
||||||
|
auth: APIClientAuth,
|
||||||
|
mut rx: Receiver<WsMessage>,
|
||||||
|
) {
|
||||||
|
log::info!("WS connected");
|
||||||
|
|
||||||
|
let mut last_heartbeat = Instant::now();
|
||||||
|
let mut interval = interval(WS_HEARTBEAT_INTERVAL);
|
||||||
|
|
||||||
|
let reason = loop {
|
||||||
|
// waits for either `msg_stream` to receive a message from the client, the broadcast channel
|
||||||
|
// to send a message, or the heartbeat interval timer to tick, yielding the value of
|
||||||
|
// whichever one is ready first
|
||||||
|
select! {
|
||||||
|
ws_msg = rx.recv() => {
|
||||||
|
let msg = match ws_msg {
|
||||||
|
Ok(msg) => msg,
|
||||||
|
Err(broadcast::error::RecvError::Closed) => break None,
|
||||||
|
Err(broadcast::error::RecvError::Lagged(_)) => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
match msg {
|
||||||
|
WsMessage::CloseClientSession(_) => todo!(),
|
||||||
|
WsMessage::CloseAllUserSessions(userid) => {
|
||||||
|
if userid == auth.user.user_id {
|
||||||
|
log::info!(
|
||||||
|
"closing WS session of user {userid:?} as requested"
|
||||||
|
);
|
||||||
|
break None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// heartbeat interval ticked
|
||||||
|
_tick = interval.tick() => {
|
||||||
|
// if no heartbeat ping/pong received recently, close the connection
|
||||||
|
if Instant::now().duration_since(last_heartbeat) > WS_CLIENT_TIMEOUT {
|
||||||
|
log::info!(
|
||||||
|
"client has not sent heartbeat in over {WS_CLIENT_TIMEOUT:?}; disconnecting"
|
||||||
|
);
|
||||||
|
|
||||||
|
break None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// send heartbeat ping
|
||||||
|
let _ = session.ping(b"").await;
|
||||||
|
},
|
||||||
|
|
||||||
|
msg = msg_stream.next() => {
|
||||||
|
let msg = match msg {
|
||||||
|
// received message from WebSocket client
|
||||||
|
Some(Ok(msg)) => msg,
|
||||||
|
|
||||||
|
// client WebSocket stream error
|
||||||
|
Some(Err(err)) => {
|
||||||
|
log::error!("{err}");
|
||||||
|
break None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// client WebSocket stream ended
|
||||||
|
None => break None
|
||||||
|
};
|
||||||
|
|
||||||
|
log::debug!("msg: {msg:?}");
|
||||||
|
|
||||||
|
match msg {
|
||||||
|
Message::Text(s) => {
|
||||||
|
log::info!("Text message: {s}");
|
||||||
|
}
|
||||||
|
|
||||||
|
Message::Binary(_) => {
|
||||||
|
// drop client's binary messages
|
||||||
|
}
|
||||||
|
|
||||||
|
Message::Close(reason) => {
|
||||||
|
break reason;
|
||||||
|
}
|
||||||
|
|
||||||
|
Message::Ping(bytes) => {
|
||||||
|
last_heartbeat = Instant::now();
|
||||||
|
let _ = session.pong(&bytes).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
Message::Pong(_) => {
|
||||||
|
last_heartbeat = Instant::now();
|
||||||
|
}
|
||||||
|
|
||||||
|
Message::Continuation(_) => {
|
||||||
|
log::warn!("no support for continuation frames");
|
||||||
|
}
|
||||||
|
|
||||||
|
// no-op; ignore
|
||||||
|
Message::Nop => {}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// attempt to close connection gracefully
|
||||||
|
let _ = session.close(reason).await;
|
||||||
|
|
||||||
|
log::info!("WS disconnected");
|
||||||
|
}
|
||||||
@@ -12,6 +12,8 @@ pub enum HttpFailure {
|
|||||||
Forbidden,
|
Forbidden,
|
||||||
#[error("this resource was not found")]
|
#[error("this resource was not found")]
|
||||||
NotFound,
|
NotFound,
|
||||||
|
#[error("Actix web error")]
|
||||||
|
ActixError(#[from] actix_web::Error),
|
||||||
#[error("an unhandled session insert error occurred")]
|
#[error("an unhandled session insert error occurred")]
|
||||||
SessionInsertError(#[from] actix_session::SessionInsertError),
|
SessionInsertError(#[from] actix_session::SessionInsertError),
|
||||||
#[error("an unhandled session error occurred")]
|
#[error("an unhandled session error occurred")]
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
use crate::app_config::AppConfig;
|
use crate::app_config::AppConfig;
|
||||||
use crate::constants::{STATE_KEY, USER_SESSION_KEY};
|
use crate::constants::{STATE_KEY, USER_SESSION_KEY};
|
||||||
|
use crate::server::api::ws::WsMessage;
|
||||||
use crate::server::{HttpFailure, HttpResult};
|
use crate::server::{HttpFailure, HttpResult};
|
||||||
use crate::user::{APIClient, APIClientID, User, UserConfig, UserID};
|
use crate::user::{APIClient, APIClientID, User, UserConfig, UserID};
|
||||||
use crate::utils;
|
use crate::utils;
|
||||||
@@ -9,6 +10,7 @@ use askama::Template;
|
|||||||
use ipnet::IpNet;
|
use ipnet::IpNet;
|
||||||
use light_openid::primitives::OpenIDConfig;
|
use light_openid::primitives::OpenIDConfig;
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
use tokio::sync::broadcast;
|
||||||
|
|
||||||
/// Static assets
|
/// Static assets
|
||||||
#[derive(rust_embed::Embed)]
|
#[derive(rust_embed::Embed)]
|
||||||
@@ -60,7 +62,11 @@ pub struct FormRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Main route
|
/// Main route
|
||||||
pub async fn home(session: Session, form_req: Option<web::Form<FormRequest>>) -> HttpResult {
|
pub async fn home(
|
||||||
|
session: Session,
|
||||||
|
form_req: Option<web::Form<FormRequest>>,
|
||||||
|
tx: web::Data<broadcast::Sender<WsMessage>>,
|
||||||
|
) -> HttpResult {
|
||||||
// Get user information, requesting authentication if information is missing
|
// Get user information, requesting authentication if information is missing
|
||||||
let Some(user): Option<User> = session.get(USER_SESSION_KEY)? else {
|
let Some(user): Option<User> = session.get(USER_SESSION_KEY)? else {
|
||||||
// Generate auth state
|
// Generate auth state
|
||||||
@@ -93,10 +99,14 @@ pub async fn home(session: Session, form_req: Option<web::Form<FormRequest>>) ->
|
|||||||
if t.len() < 3 {
|
if t.len() < 3 {
|
||||||
error_message = Some("Specified Matrix token is too short!".to_string());
|
error_message = Some("Specified Matrix token is too short!".to_string());
|
||||||
} else {
|
} else {
|
||||||
// TODO : invalidate all existing connections
|
|
||||||
config.matrix_token = t;
|
config.matrix_token = t;
|
||||||
config.save().await?;
|
config.save().await?;
|
||||||
success_message = Some("Matrix token was successfully updated!".to_string());
|
success_message = Some("Matrix token was successfully updated!".to_string());
|
||||||
|
|
||||||
|
// Invalidate all Ws connections
|
||||||
|
if let Err(e) = tx.send(WsMessage::CloseAllUserSessions(user.id.clone())) {
|
||||||
|
log::error!("Failed to send CloseAllUserSessions: {}", e);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ pub enum UserError {
|
|||||||
MissingMatrixToken,
|
MissingMatrixToken,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
|
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq)]
|
||||||
pub struct UserID(pub String);
|
pub struct UserID(pub String);
|
||||||
|
|
||||||
impl UserID {
|
impl UserID {
|
||||||
|
|||||||
Reference in New Issue
Block a user