Add basic ping-pong websocket
This commit is contained in:
		
							
								
								
									
										29
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										29
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							@@ -214,6 +214,20 @@ dependencies = [
 | 
				
			|||||||
 "syn",
 | 
					 "syn",
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[[package]]
 | 
				
			||||||
 | 
					name = "actix-ws"
 | 
				
			||||||
 | 
					version = "0.3.0"
 | 
				
			||||||
 | 
					source = "registry+https://github.com/rust-lang/crates.io-index"
 | 
				
			||||||
 | 
					checksum = "a3a1fb4f9f2794b0aadaf2ba5f14a6f034c7e86957b458c506a8cb75953f2d99"
 | 
				
			||||||
 | 
					dependencies = [
 | 
				
			||||||
 | 
					 "actix-codec",
 | 
				
			||||||
 | 
					 "actix-http",
 | 
				
			||||||
 | 
					 "actix-web",
 | 
				
			||||||
 | 
					 "bytestring",
 | 
				
			||||||
 | 
					 "futures-core",
 | 
				
			||||||
 | 
					 "tokio",
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[[package]]
 | 
					[[package]]
 | 
				
			||||||
name = "addr2line"
 | 
					name = "addr2line"
 | 
				
			||||||
version = "0.24.2"
 | 
					version = "0.24.2"
 | 
				
			||||||
@@ -2035,6 +2049,7 @@ dependencies = [
 | 
				
			|||||||
 "actix-remote-ip",
 | 
					 "actix-remote-ip",
 | 
				
			||||||
 "actix-session",
 | 
					 "actix-session",
 | 
				
			||||||
 "actix-web",
 | 
					 "actix-web",
 | 
				
			||||||
 | 
					 "actix-ws",
 | 
				
			||||||
 "anyhow",
 | 
					 "anyhow",
 | 
				
			||||||
 "askama",
 | 
					 "askama",
 | 
				
			||||||
 "base16ct",
 | 
					 "base16ct",
 | 
				
			||||||
@@ -2057,6 +2072,8 @@ dependencies = [
 | 
				
			|||||||
 "serde_json",
 | 
					 "serde_json",
 | 
				
			||||||
 "sha2 0.11.0-pre.4",
 | 
					 "sha2 0.11.0-pre.4",
 | 
				
			||||||
 "thiserror 2.0.11",
 | 
					 "thiserror 2.0.11",
 | 
				
			||||||
 | 
					 "time",
 | 
				
			||||||
 | 
					 "tokio",
 | 
				
			||||||
 "urlencoding",
 | 
					 "urlencoding",
 | 
				
			||||||
 "uuid",
 | 
					 "uuid",
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
@@ -3490,9 +3507,21 @@ dependencies = [
 | 
				
			|||||||
 "pin-project-lite",
 | 
					 "pin-project-lite",
 | 
				
			||||||
 "signal-hook-registry",
 | 
					 "signal-hook-registry",
 | 
				
			||||||
 "socket2",
 | 
					 "socket2",
 | 
				
			||||||
 | 
					 "tokio-macros",
 | 
				
			||||||
 "windows-sys 0.52.0",
 | 
					 "windows-sys 0.52.0",
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[[package]]
 | 
				
			||||||
 | 
					name = "tokio-macros"
 | 
				
			||||||
 | 
					version = "2.5.0"
 | 
				
			||||||
 | 
					source = "registry+https://github.com/rust-lang/crates.io-index"
 | 
				
			||||||
 | 
					checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8"
 | 
				
			||||||
 | 
					dependencies = [
 | 
				
			||||||
 | 
					 "proc-macro2",
 | 
				
			||||||
 | 
					 "quote",
 | 
				
			||||||
 | 
					 "syn",
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[[package]]
 | 
					[[package]]
 | 
				
			||||||
name = "tokio-native-tls"
 | 
					name = "tokio-native-tls"
 | 
				
			||||||
version = "0.3.1"
 | 
					version = "0.3.1"
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -24,10 +24,13 @@ urlencoding = "2.1.3"
 | 
				
			|||||||
uuid = { version = "1.12.1", features = ["v4", "serde"] }
 | 
					uuid = { version = "1.12.1", features = ["v4", "serde"] }
 | 
				
			||||||
ipnet = { version = "2.11.0", features = ["serde"] }
 | 
					ipnet = { version = "2.11.0", features = ["serde"] }
 | 
				
			||||||
chrono = "0.4.39"
 | 
					chrono = "0.4.39"
 | 
				
			||||||
futures-util = "0.3.31"
 | 
					futures-util = { version = "0.3.31", features = ["sink"] }
 | 
				
			||||||
jwt-simple = { version = "0.12.11", default-features=false, features=["pure-rust"] }
 | 
					jwt-simple = { version = "0.12.11", default-features = false, features = ["pure-rust"] }
 | 
				
			||||||
actix-remote-ip = "0.1.0"
 | 
					actix-remote-ip = "0.1.0"
 | 
				
			||||||
bytes = "1.9.0"
 | 
					bytes = "1.9.0"
 | 
				
			||||||
sha2 = "0.11.0-pre.4"
 | 
					sha2 = "0.11.0-pre.4"
 | 
				
			||||||
base16ct = "0.2.0"
 | 
					base16ct = "0.2.0"
 | 
				
			||||||
ruma = { version = "0.12.0", features = ["client-api-c", "client-ext-client-api", "client-hyper-native-tls", "rand"] }
 | 
					ruma = { version = "0.12.0", features = ["client-api-c", "client-ext-client-api", "client-hyper-native-tls", "rand"] }
 | 
				
			||||||
 | 
					actix-ws = "0.3.0"
 | 
				
			||||||
 | 
					tokio = { version = "1.43.0", features = ["rt", "time", "macros"] }
 | 
				
			||||||
 | 
					time = "0.3.37"
 | 
				
			||||||
@@ -1,3 +1,5 @@
 | 
				
			|||||||
 | 
					use std::time::Duration;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/// Session key for OpenID login state
 | 
					/// Session key for OpenID login state
 | 
				
			||||||
pub const STATE_KEY: &str = "oidc-state";
 | 
					pub const STATE_KEY: &str = "oidc-state";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -6,3 +8,11 @@ pub const USER_SESSION_KEY: &str = "user";
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
/// Token length
 | 
					/// Token length
 | 
				
			||||||
pub const TOKEN_LEN: usize = 20;
 | 
					pub const TOKEN_LEN: usize = 20;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/// How often heartbeat pings are sent.
 | 
				
			||||||
 | 
					///
 | 
				
			||||||
 | 
					/// Should be half (or less) of the acceptable client timeout.
 | 
				
			||||||
 | 
					pub const WS_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/// How long before lack of client response causes a timeout.
 | 
				
			||||||
 | 
					pub const WS_CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -48,6 +48,7 @@ async fn main() -> std::io::Result<()> {
 | 
				
			|||||||
            .route("/api", web::get().to(api::api_home))
 | 
					            .route("/api", web::get().to(api::api_home))
 | 
				
			||||||
            .route("/api", web::post().to(api::api_home))
 | 
					            .route("/api", web::post().to(api::api_home))
 | 
				
			||||||
            .route("/api/account/whoami", web::get().to(api::account::who_am_i))
 | 
					            .route("/api/account/whoami", web::get().to(api::account::who_am_i))
 | 
				
			||||||
 | 
					            .service(web::resource("/api/ws").route(web::get().to(api::ws)))
 | 
				
			||||||
    })
 | 
					    })
 | 
				
			||||||
    .bind(&AppConfig::get().listen_address)?
 | 
					    .bind(&AppConfig::get().listen_address)?
 | 
				
			||||||
    .run()
 | 
					    .run()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,6 +1,13 @@
 | 
				
			|||||||
 | 
					use crate::constants::{WS_CLIENT_TIMEOUT, WS_HEARTBEAT_INTERVAL};
 | 
				
			||||||
use crate::extractors::client_auth::APIClientAuth;
 | 
					use crate::extractors::client_auth::APIClientAuth;
 | 
				
			||||||
use crate::server::HttpResult;
 | 
					use crate::server::HttpResult;
 | 
				
			||||||
use actix_web::HttpResponse;
 | 
					use actix_web::dev::Payload;
 | 
				
			||||||
 | 
					use actix_web::{web, FromRequest, HttpRequest, HttpResponse};
 | 
				
			||||||
 | 
					use actix_ws::Message;
 | 
				
			||||||
 | 
					use futures_util::future::Either;
 | 
				
			||||||
 | 
					use futures_util::{future, StreamExt};
 | 
				
			||||||
 | 
					use std::time::Instant;
 | 
				
			||||||
 | 
					use tokio::{pin, time::interval};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
pub mod account;
 | 
					pub mod account;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -8,3 +15,98 @@ pub mod account;
 | 
				
			|||||||
pub async fn api_home(auth: APIClientAuth) -> HttpResult {
 | 
					pub async fn api_home(auth: APIClientAuth) -> HttpResult {
 | 
				
			||||||
    Ok(HttpResponse::Ok().body(format!("Welcome user {}!", auth.user.user_id.0)))
 | 
					    Ok(HttpResponse::Ok().body(format!("Welcome user {}!", auth.user.user_id.0)))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/// Main WS route
 | 
				
			||||||
 | 
					pub async fn ws(req: HttpRequest, stream: web::Payload) -> HttpResult {
 | 
				
			||||||
 | 
					    // Forcefully ignore request payload by manually extracting authentication information
 | 
				
			||||||
 | 
					    let auth = APIClientAuth::from_request(&req, &mut Payload::None).await?;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let (res, session, msg_stream) = actix_ws::handle(&req, stream)?;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // spawn websocket handler (and don't await it) so that the response is returned immediately
 | 
				
			||||||
 | 
					    actix_web::rt::spawn(ws_handler(session, msg_stream));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Ok(res)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					pub async fn ws_handler(mut session: actix_ws::Session, mut msg_stream: actix_ws::MessageStream) {
 | 
				
			||||||
 | 
					    log::info!("WS connected");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let mut last_heartbeat = Instant::now();
 | 
				
			||||||
 | 
					    let mut interval = interval(WS_HEARTBEAT_INTERVAL);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let reason = loop {
 | 
				
			||||||
 | 
					        // create "next client timeout check" future
 | 
				
			||||||
 | 
					        let tick = interval.tick();
 | 
				
			||||||
 | 
					        // required for select()
 | 
				
			||||||
 | 
					        pin!(tick);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // waits for either `msg_stream` to receive a message from the client or the heartbeat
 | 
				
			||||||
 | 
					        // interval timer to tick, yielding the value of whichever one is ready first
 | 
				
			||||||
 | 
					        match future::select(msg_stream.next(), tick).await {
 | 
				
			||||||
 | 
					            // received message from WebSocket client
 | 
				
			||||||
 | 
					            Either::Left((Some(Ok(msg)), _)) => {
 | 
				
			||||||
 | 
					                log::debug!("msg: {msg:?}");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                match msg {
 | 
				
			||||||
 | 
					                    Message::Text(text) => {
 | 
				
			||||||
 | 
					                        session.text(text).await.unwrap();
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    Message::Binary(bin) => {
 | 
				
			||||||
 | 
					                        session.binary(bin).await.unwrap();
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    Message::Close(reason) => {
 | 
				
			||||||
 | 
					                        break reason;
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    Message::Ping(bytes) => {
 | 
				
			||||||
 | 
					                        last_heartbeat = Instant::now();
 | 
				
			||||||
 | 
					                        let _ = session.pong(&bytes).await;
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    Message::Pong(_) => {
 | 
				
			||||||
 | 
					                        last_heartbeat = Instant::now();
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    Message::Continuation(_) => {
 | 
				
			||||||
 | 
					                        log::warn!("no support for continuation frames");
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    // no-op; ignore
 | 
				
			||||||
 | 
					                    Message::Nop => {}
 | 
				
			||||||
 | 
					                };
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            // client WebSocket stream error
 | 
				
			||||||
 | 
					            Either::Left((Some(Err(err)), _)) => {
 | 
				
			||||||
 | 
					                log::error!("{}", err);
 | 
				
			||||||
 | 
					                break None;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            // client WebSocket stream ended
 | 
				
			||||||
 | 
					            Either::Left((None, _)) => break None,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            // heartbeat interval ticked
 | 
				
			||||||
 | 
					            Either::Right((_inst, _)) => {
 | 
				
			||||||
 | 
					                // if no heartbeat ping/pong received recently, close the connection
 | 
				
			||||||
 | 
					                if Instant::now().duration_since(last_heartbeat) > WS_CLIENT_TIMEOUT {
 | 
				
			||||||
 | 
					                    log::info!(
 | 
				
			||||||
 | 
					                        "client has not sent heartbeat in over {WS_CLIENT_TIMEOUT:?}; disconnecting"
 | 
				
			||||||
 | 
					                    );
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                    break None;
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                // send heartbeat ping
 | 
				
			||||||
 | 
					                let _ = session.ping(b"").await;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // attempt to close connection gracefully
 | 
				
			||||||
 | 
					    let _ = session.close(reason).await;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    log::info!("WS disconnected");
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -12,6 +12,8 @@ pub enum HttpFailure {
 | 
				
			|||||||
    Forbidden,
 | 
					    Forbidden,
 | 
				
			||||||
    #[error("this resource was not found")]
 | 
					    #[error("this resource was not found")]
 | 
				
			||||||
    NotFound,
 | 
					    NotFound,
 | 
				
			||||||
 | 
					    #[error("Actix web error")]
 | 
				
			||||||
 | 
					    ActixError(#[from] actix_web::Error),
 | 
				
			||||||
    #[error("an unhandled session insert error occurred")]
 | 
					    #[error("an unhandled session insert error occurred")]
 | 
				
			||||||
    SessionInsertError(#[from] actix_session::SessionInsertError),
 | 
					    SessionInsertError(#[from] actix_session::SessionInsertError),
 | 
				
			||||||
    #[error("an unhandled session error occurred")]
 | 
					    #[error("an unhandled session error occurred")]
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user