diff --git a/Cargo.lock b/Cargo.lock index 53e72a7..219a896 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -214,6 +214,20 @@ dependencies = [ "syn", ] +[[package]] +name = "actix-ws" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3a1fb4f9f2794b0aadaf2ba5f14a6f034c7e86957b458c506a8cb75953f2d99" +dependencies = [ + "actix-codec", + "actix-http", + "actix-web", + "bytestring", + "futures-core", + "tokio", +] + [[package]] name = "addr2line" version = "0.24.2" @@ -2035,6 +2049,7 @@ dependencies = [ "actix-remote-ip", "actix-session", "actix-web", + "actix-ws", "anyhow", "askama", "base16ct", @@ -2057,6 +2072,8 @@ dependencies = [ "serde_json", "sha2 0.11.0-pre.4", "thiserror 2.0.11", + "time", + "tokio", "urlencoding", "uuid", ] @@ -3490,9 +3507,21 @@ dependencies = [ "pin-project-lite", "signal-hook-registry", "socket2", + "tokio-macros", "windows-sys 0.52.0", ] +[[package]] +name = "tokio-macros" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tokio-native-tls" version = "0.3.1" diff --git a/Cargo.toml b/Cargo.toml index cb42f5e..a02c20d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,10 +24,13 @@ urlencoding = "2.1.3" uuid = { version = "1.12.1", features = ["v4", "serde"] } ipnet = { version = "2.11.0", features = ["serde"] } chrono = "0.4.39" -futures-util = "0.3.31" -jwt-simple = { version = "0.12.11", default-features=false, features=["pure-rust"] } +futures-util = { version = "0.3.31", features = ["sink"] } +jwt-simple = { version = "0.12.11", default-features = false, features = ["pure-rust"] } actix-remote-ip = "0.1.0" bytes = "1.9.0" sha2 = "0.11.0-pre.4" base16ct = "0.2.0" -ruma = { version = "0.12.0", features = ["client-api-c", "client-ext-client-api", "client-hyper-native-tls", "rand"] } \ No newline at end of file +ruma = { version = "0.12.0", features = ["client-api-c", "client-ext-client-api", "client-hyper-native-tls", "rand"] } +actix-ws = "0.3.0" +tokio = { version = "1.43.0", features = ["rt", "time", "macros"] } +time = "0.3.37" \ No newline at end of file diff --git a/src/constants.rs b/src/constants.rs index 2e1fcce..539f496 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -1,3 +1,5 @@ +use std::time::Duration; + /// Session key for OpenID login state pub const STATE_KEY: &str = "oidc-state"; @@ -6,3 +8,11 @@ pub const USER_SESSION_KEY: &str = "user"; /// Token length pub const TOKEN_LEN: usize = 20; + +/// How often heartbeat pings are sent. +/// +/// Should be half (or less) of the acceptable client timeout. +pub const WS_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); + +/// How long before lack of client response causes a timeout. +pub const WS_CLIENT_TIMEOUT: Duration = Duration::from_secs(10); diff --git a/src/main.rs b/src/main.rs index 0239672..cdb1d71 100644 --- a/src/main.rs +++ b/src/main.rs @@ -48,6 +48,7 @@ async fn main() -> std::io::Result<()> { .route("/api", web::get().to(api::api_home)) .route("/api", web::post().to(api::api_home)) .route("/api/account/whoami", web::get().to(api::account::who_am_i)) + .service(web::resource("/api/ws").route(web::get().to(api::ws))) }) .bind(&AppConfig::get().listen_address)? .run() diff --git a/src/server/api/mod.rs b/src/server/api/mod.rs index 499de08..f0389b2 100644 --- a/src/server/api/mod.rs +++ b/src/server/api/mod.rs @@ -1,6 +1,13 @@ +use crate::constants::{WS_CLIENT_TIMEOUT, WS_HEARTBEAT_INTERVAL}; use crate::extractors::client_auth::APIClientAuth; use crate::server::HttpResult; -use actix_web::HttpResponse; +use actix_web::dev::Payload; +use actix_web::{web, FromRequest, HttpRequest, HttpResponse}; +use actix_ws::Message; +use futures_util::future::Either; +use futures_util::{future, StreamExt}; +use std::time::Instant; +use tokio::{pin, time::interval}; pub mod account; @@ -8,3 +15,98 @@ pub mod account; pub async fn api_home(auth: APIClientAuth) -> HttpResult { Ok(HttpResponse::Ok().body(format!("Welcome user {}!", auth.user.user_id.0))) } + +/// Main WS route +pub async fn ws(req: HttpRequest, stream: web::Payload) -> HttpResult { + // Forcefully ignore request payload by manually extracting authentication information + let auth = APIClientAuth::from_request(&req, &mut Payload::None).await?; + + let (res, session, msg_stream) = actix_ws::handle(&req, stream)?; + + // spawn websocket handler (and don't await it) so that the response is returned immediately + actix_web::rt::spawn(ws_handler(session, msg_stream)); + + Ok(res) +} + +pub async fn ws_handler(mut session: actix_ws::Session, mut msg_stream: actix_ws::MessageStream) { + log::info!("WS connected"); + + let mut last_heartbeat = Instant::now(); + let mut interval = interval(WS_HEARTBEAT_INTERVAL); + + let reason = loop { + // create "next client timeout check" future + let tick = interval.tick(); + // required for select() + pin!(tick); + + // waits for either `msg_stream` to receive a message from the client or the heartbeat + // interval timer to tick, yielding the value of whichever one is ready first + match future::select(msg_stream.next(), tick).await { + // received message from WebSocket client + Either::Left((Some(Ok(msg)), _)) => { + log::debug!("msg: {msg:?}"); + + match msg { + Message::Text(text) => { + session.text(text).await.unwrap(); + } + + Message::Binary(bin) => { + session.binary(bin).await.unwrap(); + } + + Message::Close(reason) => { + break reason; + } + + Message::Ping(bytes) => { + last_heartbeat = Instant::now(); + let _ = session.pong(&bytes).await; + } + + Message::Pong(_) => { + last_heartbeat = Instant::now(); + } + + Message::Continuation(_) => { + log::warn!("no support for continuation frames"); + } + + // no-op; ignore + Message::Nop => {} + }; + } + + // client WebSocket stream error + Either::Left((Some(Err(err)), _)) => { + log::error!("{}", err); + break None; + } + + // client WebSocket stream ended + Either::Left((None, _)) => break None, + + // heartbeat interval ticked + Either::Right((_inst, _)) => { + // if no heartbeat ping/pong received recently, close the connection + if Instant::now().duration_since(last_heartbeat) > WS_CLIENT_TIMEOUT { + log::info!( + "client has not sent heartbeat in over {WS_CLIENT_TIMEOUT:?}; disconnecting" + ); + + break None; + } + + // send heartbeat ping + let _ = session.ping(b"").await; + } + } + }; + + // attempt to close connection gracefully + let _ = session.close(reason).await; + + log::info!("WS disconnected"); +} diff --git a/src/server/mod.rs b/src/server/mod.rs index 9833631..a1c79c7 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -12,6 +12,8 @@ pub enum HttpFailure { Forbidden, #[error("this resource was not found")] NotFound, + #[error("Actix web error")] + ActixError(#[from] actix_web::Error), #[error("an unhandled session insert error occurred")] SessionInsertError(#[from] actix_session::SessionInsertError), #[error("an unhandled session error occurred")]