diff --git a/examples/api_curl.rs b/examples/api_curl.rs new file mode 100644 index 0000000..3f078b0 --- /dev/null +++ b/examples/api_curl.rs @@ -0,0 +1,81 @@ +use clap::Parser; +use jwt_simple::algorithms::HS256Key; +use jwt_simple::prelude::{Clock, Duration, JWTClaims, MACLike}; +use matrix_gateway::extractors::client_auth::TokenClaims; +use matrix_gateway::utils::rand_str; +use std::ops::Add; +use std::os::unix::prelude::CommandExt; +use std::process::Command; + +/// cURL wrapper to query MatrixGW +#[derive(Parser, Debug)] +#[command(version, about, long_about = None)] +struct Args { + /// URL of Matrix GW + #[arg(short('U'), long, env, default_value = "http://localhost:8000")] + matrix_gw_url: String, + + /// Token ID + #[arg(short('i'), long, env)] + token_id: String, + + /// User ID + #[arg(short('u'), long, env)] + user_id: String, + + /// Token secret + #[arg(short('t'), long, env)] + token_secret: String, + + /// Request verb + #[arg(short('X'), long, default_value = "GET")] + method: String, + + /// Request URI + uri: String, + + /// Command line arguments to pass to cURL + #[clap(trailing_var_arg = true, allow_hyphen_values = true)] + run: Vec, +} + +fn main() { + let args: Args = Args::parse(); + + let full_url = format!("{}{}", args.matrix_gw_url, args.uri); + log::debug!("Full URL: {full_url}"); + + let key = HS256Key::from_bytes(args.token_secret.as_bytes()); + + let claims = JWTClaims:: { + issued_at: Some(Clock::now_since_epoch()), + expires_at: Some(Clock::now_since_epoch().add(Duration::from_mins(15))), + invalid_before: None, + issuer: None, + subject: None, + audiences: None, + jwt_id: None, + nonce: Some(rand_str(10)), + custom: TokenClaims { + method: args.method.to_string(), + uri: args.uri, + }, + }; + + let jwt = key + .with_key_id(&format!( + "{}#{}", + urlencoding::encode(&args.user_id), + urlencoding::encode(&args.token_id) + )) + .authenticate(claims) + .expect("Failed to sign JWT!"); + + let _ = Command::new("curl") + .args(["-X", &args.method]) + .args(["-H", &format!("x-client-auth: {jwt}")]) + .args(args.run) + .arg(full_url) + .exec(); + panic!("Failed to run curl!") +} diff --git a/src/extractors/client_auth.rs b/src/extractors/client_auth.rs index ab974d8..e6f673c 100644 --- a/src/extractors/client_auth.rs +++ b/src/extractors/client_auth.rs @@ -2,7 +2,7 @@ use crate::user::{APIClient, APIClientID, UserConfig, UserID}; use actix_web::dev::Payload; use actix_web::{FromRequest, HttpRequest}; use jwt_simple::common::VerificationOptions; -use jwt_simple::prelude::{HS256Key, MACLike}; +use jwt_simple::prelude::{Duration, HS256Key, MACLike}; use std::str::FromStr; pub struct APIClientAuth { @@ -12,7 +12,10 @@ pub struct APIClientAuth { } #[derive(Debug, serde::Serialize, serde::Deserialize)] -struct JWTClaims {} +pub struct TokenClaims { + pub method: String, + pub uri: String, +} impl APIClientAuth { async fn extract_auth(req: &HttpRequest) -> Result { @@ -78,17 +81,31 @@ impl APIClientAuth { // Decode JWT let key = HS256Key::from_bytes(client.secret.as_bytes()); - let claims = - match key.verify_token::(jwt_token, Some(VerificationOptions::default())) { - Ok(t) => t, - Err(e) => { - log::error!("JWT validation failed! {e}"); - return Err(actix_web::error::ErrorForbidden("JWT validation failed!")); - } - }; + let mut verif = VerificationOptions::default(); + verif.max_validity = Some(Duration::from_mins(15)); + let claims = match key.verify_token::(jwt_token, Some(verif)) { + Ok(t) => t, + Err(e) => { + log::error!("JWT validation failed! {e}"); + return Err(actix_web::error::ErrorForbidden("JWT validation failed!")); + } + }; + + // Check for nonce + if claims.nonce.is_none() { + return Err(actix_web::error::ErrorBadRequest( + "A nonce is required in JWT!", + )); + } + + // Check URI & verb + if claims.custom.uri != req.uri().to_string() { + return Err(actix_web::error::ErrorBadRequest("URI mismatch!")); + } + if claims.custom.method != req.method().to_string() { + return Err(actix_web::error::ErrorBadRequest("Method mismatch!")); + } - // TODO : check timing - // TODO : check URI & verb // TODO : handle payload // TODO : check read only access // TODO : update last use (if required) diff --git a/src/main.rs b/src/main.rs index 294f4fb..a364918 100644 --- a/src/main.rs +++ b/src/main.rs @@ -42,7 +42,8 @@ async fn main() -> std::io::Result<()> { .route("/oidc_cb", web::get().to(web_ui::oidc_cb)) .route("/sign_out", web::get().to(web_ui::sign_out)) // API routes - .route("/api/", web::get().to(api::api_home)) + .route("/api", web::get().to(api::api_home)) + .route("/api", web::post().to(api::api_home)) }) .bind(&AppConfig::get().listen_address)? .run()