diff --git a/virtweb_backend/src/api_tokens.rs b/virtweb_backend/src/api_tokens.rs index 91aef97..f216974 100644 --- a/virtweb_backend/src/api_tokens.rs +++ b/virtweb_backend/src/api_tokens.rs @@ -5,11 +5,19 @@ use crate::constants; use crate::utils::jwt_utils; use crate::utils::jwt_utils::{TokenPrivKey, TokenPubKey}; use crate::utils::time_utils::time; +use actix_http::Method; use std::path::Path; #[derive(serde::Serialize, serde::Deserialize, Clone, Copy, Debug)] pub struct TokenID(pub uuid::Uuid); +impl TokenID { + /// Parse a string as a token id + pub fn parse(t: &str) -> anyhow::Result { + Ok(Self(uuid::Uuid::parse_str(t)?)) + } +} + #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] pub struct TokenRight { verb: TokenVerb, @@ -29,9 +37,9 @@ pub struct Token { #[serde(skip_serializing_if = "TokenPubKey::is_invalid")] pub pub_key: TokenPubKey, pub rights: TokenRights, - pub last_used: Option, + pub last_used: u64, pub ip_restriction: Option, - pub delete_after_inactivity: Option, + pub max_inactivity: Option, } impl Token { @@ -45,9 +53,25 @@ impl Token { } /// Load token information from a file - pub fn load_from_file(path: &Path) -> anyhow::Result { + fn load_from_file(path: &Path) -> anyhow::Result { Ok(serde_json::from_str(&std::fs::read_to_string(path)?)?) } + + /// Check whether a token is expired or not + pub fn is_expired(&self) -> bool { + if let Some(max_inactivity) = self.max_inactivity { + if max_inactivity + self.last_used < time() { + return true; + } + } + + false + } + + /// Check whether last_used shall be updated or not + pub fn should_update_last_activity(&self) -> bool { + self.last_used + 3600 < time() + } } #[derive(serde::Serialize, serde::Deserialize, Debug, Clone, Copy, Eq, PartialEq)] @@ -59,6 +83,18 @@ pub enum TokenVerb { DELETE, } +impl TokenVerb { + pub fn as_method(&self) -> Method { + match self { + TokenVerb::GET => Method::GET, + TokenVerb::POST => Method::POST, + TokenVerb::PUT => Method::PUT, + TokenVerb::PATCH => Method::PATCH, + TokenVerb::DELETE => Method::DELETE, + } + } +} + /// Structure used to create a token #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] pub struct NewToken { @@ -125,9 +161,9 @@ pub async fn create(t: &NewToken) -> anyhow::Result<(Token, TokenPrivKey)> { updated: time(), pub_key, rights: t.rights.clone(), - last_used: Some(time()), + last_used: time(), ip_restriction: t.ip_restriction, - delete_after_inactivity: t.delete_after_inactivity, + max_inactivity: t.delete_after_inactivity, }; token.save()?; diff --git a/virtweb_backend/src/extractors/api_auth_extractor.rs b/virtweb_backend/src/extractors/api_auth_extractor.rs new file mode 100644 index 0000000..bed5da3 --- /dev/null +++ b/virtweb_backend/src/extractors/api_auth_extractor.rs @@ -0,0 +1,123 @@ +use crate::api_tokens::{Token, TokenID, TokenVerb}; + +use crate::api_tokens; +use crate::utils::jwt_utils; +use crate::utils::time_utils::time; +use actix_web::dev::Payload; +use actix_web::error::ErrorBadRequest; +use actix_web::{Error, FromRequest, HttpRequest}; +use std::future::Future; +use std::pin::Pin; + +#[derive(serde::Serialize, serde::Deserialize, Debug)] +pub struct TokenClaims { + pub sub: String, + pub iat: usize, + pub exp: usize, + pub verb: TokenVerb, + pub path: String, + pub nonce: String, +} + +pub struct ApiAuthExtractor { + pub token: Token, + pub claims: TokenClaims, +} + +impl FromRequest for ApiAuthExtractor { + type Error = Error; + type Future = Pin>>>; + + fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future { + let req = req.clone(); + + Box::pin(async move { + let (token_id, token_jwt) = match ( + req.headers().get("x-token-id"), + req.headers().get("x-token-content"), + ) { + (Some(id), Some(jwt)) => ( + id.to_str().unwrap_or("").to_string(), + jwt.to_str().unwrap_or("").to_string(), + ), + (_, _) => { + return Err(ErrorBadRequest("API auth headers were not all specified!")); + } + }; + + let token_id = match TokenID::parse(&token_id) { + Ok(t) => t, + Err(e) => { + log::error!("Failed to parse token id! {e}"); + return Err(ErrorBadRequest("Unable to validate token ID!")); + } + }; + + let token = match api_tokens::get_single(token_id).await { + Ok(t) => t, + Err(e) => { + log::error!("Failed to retrieve token: {e}"); + return Err(ErrorBadRequest("Unable to validate token!")); + } + }; + + if token.is_expired() { + log::error!("Token has expired (not been used for too long)!"); + return Err(ErrorBadRequest("Unable to validate token!")); + } + + let claims = match jwt_utils::validate_jwt::(&token.pub_key, &token_jwt) { + Ok(c) => c, + Err(e) => { + log::error!("Failed to validate JWT: {e}"); + return Err(ErrorBadRequest("Unable to validate token!")); + } + }; + + if claims.sub != token.id.0.to_string() { + log::error!("JWT sub mismatch (should equal to token id)!"); + return Err(ErrorBadRequest( + "JWT sub mismatch (should equal to token id)!", + )); + } + + if time() + 60 * 15 < claims.iat as u64 { + log::error!("iat is in the future!"); + return Err(ErrorBadRequest("iat is in the future!")); + } + + if claims.exp < claims.iat { + log::error!("exp shall not be smaller than iat!"); + return Err(ErrorBadRequest("exp shall not be smaller than iat!")); + } + + if claims.exp - claims.iat > 1800 { + log::error!("JWT shall not be valid more than 30 minutes!"); + return Err(ErrorBadRequest( + "JWT shall not be valid more than 30 minutes!", + )); + } + + if claims.path != req.path() { + log::error!("JWT path mismatch!"); + return Err(ErrorBadRequest("JWT path mismatch!")); + } + + if claims.verb.as_method() != req.method() { + log::error!("JWT method mismatch!"); + return Err(ErrorBadRequest("JWT method mismatch!")); + } + + // TODO : check if route is authorized with token + // TODO : check for ip restriction + + // TODO : manually validate all checks + + if token.should_update_last_activity() { + // TODO : update last activity + } + + Ok(ApiAuthExtractor { token, claims }) + }) + } +} diff --git a/virtweb_backend/src/extractors/mod.rs b/virtweb_backend/src/extractors/mod.rs index d7e49e1..d284d19 100644 --- a/virtweb_backend/src/extractors/mod.rs +++ b/virtweb_backend/src/extractors/mod.rs @@ -1,2 +1,3 @@ +pub mod api_auth_extractor; pub mod auth_extractor; pub mod local_auth_extractor; diff --git a/virtweb_backend/src/middlewares/auth_middleware.rs b/virtweb_backend/src/middlewares/auth_middleware.rs index f279216..7479b4d 100644 --- a/virtweb_backend/src/middlewares/auth_middleware.rs +++ b/virtweb_backend/src/middlewares/auth_middleware.rs @@ -3,6 +3,7 @@ use std::rc::Rc; use crate::app_config::AppConfig; use crate::constants; +use crate::extractors::api_auth_extractor::ApiAuthExtractor; use crate::extractors::auth_extractor::AuthExtractor; use actix_web::body::EitherBody; use actix_web::dev::Payload; @@ -68,8 +69,28 @@ where let auth_disabled = AppConfig::get().unsecure_disable_auth; - // Check authentication, if required - if !auth_disabled + // Check API authentication + if req.headers().get("x-token-id").is_some() { + let auth = + match ApiAuthExtractor::from_request(req.request(), &mut Payload::None).await { + Ok(auth) => auth, + Err(e) => { + log::error!( + "Failed to extract API authentication information from request! {e}" + ); + return Ok(req + .into_response(HttpResponse::PreconditionFailed().finish()) + .map_into_right_body()); + } + }; + + log::info!( + "Using API token '{}' to perform the request", + auth.token.name + ); + } + // Check user authentication, if required + else if !auth_disabled && !constants::ROUTES_WITHOUT_AUTH.contains(&req.path()) && req.path().starts_with("/api/") {