diff --git a/virtweb_backend/src/constants.rs b/virtweb_backend/src/constants.rs index 87507d7..f3c3549 100644 --- a/virtweb_backend/src/constants.rs +++ b/virtweb_backend/src/constants.rs @@ -6,3 +6,6 @@ pub const MAX_INACTIVITY_DURATION: u64 = 60 * 30; /// Maximum session duration (6 hours) pub const MAX_SESSION_DURATION: u64 = 3600 * 6; + +/// The routes that can be accessed without authentication +pub const ROUTES_WITHOUT_AUTH: [&str; 3] = ["/", "/api/server/static_config", "/api/auth/local"]; diff --git a/virtweb_backend/src/controllers/auth_controller.rs b/virtweb_backend/src/controllers/auth_controller.rs index 5191119..5dce683 100644 --- a/virtweb_backend/src/controllers/auth_controller.rs +++ b/virtweb_backend/src/controllers/auth_controller.rs @@ -1,5 +1,5 @@ use crate::app_config::AppConfig; -use crate::extractors::auth_extractor::AuthChecker; +use crate::extractors::auth_extractor::AuthExtractor; use crate::extractors::local_auth_extractor::LocalAuthEnabled; use actix_web::{web, HttpResponse, Responder}; @@ -13,7 +13,7 @@ pub struct LocalAuthReq { pub async fn local_auth( local_auth_enabled: LocalAuthEnabled, req: web::Json, - auth: AuthChecker, + auth: AuthExtractor, ) -> impl Responder { if !*local_auth_enabled { log::error!("Local auth attempt while this authentication method is disabled!"); @@ -29,3 +29,15 @@ pub async fn local_auth( HttpResponse::Accepted().json("Welcome") } + +#[derive(serde::Serialize)] +struct CurrentUser { + id: String, +} + +/// Get current authenticated user +pub async fn current_user(auth: AuthExtractor) -> impl Responder { + HttpResponse::Ok().json(CurrentUser { + id: auth.id().unwrap(), + }) +} diff --git a/virtweb_backend/src/extractors/auth_extractor.rs b/virtweb_backend/src/extractors/auth_extractor.rs index eea58b2..26c2570 100644 --- a/virtweb_backend/src/extractors/auth_extractor.rs +++ b/virtweb_backend/src/extractors/auth_extractor.rs @@ -4,24 +4,24 @@ use actix_web::{Error, FromRequest, HttpMessage, HttpRequest}; use futures_util::future::{ready, Ready}; use std::fmt::Display; -pub struct AuthChecker { +pub struct AuthExtractor { identity: Option, request: HttpRequest, } -impl AuthChecker { +impl AuthExtractor { /// Check whether the user is authenticated or not pub fn is_authenticated(&self) -> bool { self.identity.is_some() } /// Authenticate the user - pub fn authenticate(&self, username: impl Display) { - Identity::login(&self.request.extensions(), username.to_string()) + pub fn authenticate(&self, id: impl Display) { + Identity::login(&self.request.extensions(), id.to_string()) .expect("Unable to set authentication!"); } - pub fn user_name(&self) -> Option { + pub fn id(&self) -> Option { self.identity.as_ref().map(|i| i.id().unwrap()) } @@ -32,7 +32,7 @@ impl AuthChecker { } } -impl FromRequest for AuthChecker { +impl FromRequest for AuthExtractor { type Error = Error; type Future = Ready>; diff --git a/virtweb_backend/src/lib.rs b/virtweb_backend/src/lib.rs index fd23042..d706a9c 100644 --- a/virtweb_backend/src/lib.rs +++ b/virtweb_backend/src/lib.rs @@ -2,3 +2,4 @@ pub mod app_config; pub mod constants; pub mod controllers; pub mod extractors; +pub mod middlewares; diff --git a/virtweb_backend/src/main.rs b/virtweb_backend/src/main.rs index f9e4aa0..d623027 100644 --- a/virtweb_backend/src/main.rs +++ b/virtweb_backend/src/main.rs @@ -12,6 +12,7 @@ use virtweb_backend::constants::{ MAX_INACTIVITY_DURATION, MAX_SESSION_DURATION, SESSION_COOKIE_NAME, }; use virtweb_backend::controllers::{auth_controller, server_controller}; +use virtweb_backend::middlewares::auth_middleware::AuthChecker; #[actix_web::main] async fn main() -> std::io::Result<()> { @@ -37,6 +38,7 @@ async fn main() -> std::io::Result<()> { App::new() .wrap(Logger::default()) + .wrap(AuthChecker) .wrap(identity_middleware) .wrap(session_mw) .app_data(web::Data::new(RemoteIPConfig { @@ -53,6 +55,10 @@ async fn main() -> std::io::Result<()> { "/api/auth/local", web::post().to(auth_controller::local_auth), ) + .route( + "/api/auth/user", + web::get().to(auth_controller::current_user), + ) }) .bind(&AppConfig::get().listen_address)? .run() diff --git a/virtweb_backend/src/middlewares.rs b/virtweb_backend/src/middlewares.rs new file mode 100644 index 0000000..2a709c0 --- /dev/null +++ b/virtweb_backend/src/middlewares.rs @@ -0,0 +1 @@ +pub mod auth_middleware; diff --git a/virtweb_backend/src/middlewares/auth_middleware.rs b/virtweb_backend/src/middlewares/auth_middleware.rs new file mode 100644 index 0000000..669abcf --- /dev/null +++ b/virtweb_backend/src/middlewares/auth_middleware.rs @@ -0,0 +1,95 @@ +use std::future::{ready, Ready}; +use std::rc::Rc; + +use crate::constants; +use crate::extractors::auth_extractor::AuthExtractor; +use actix_web::body::EitherBody; +use actix_web::dev::Payload; +use actix_web::{ + dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}, + Error, FromRequest, HttpResponse, +}; +use futures_util::future::LocalBoxFuture; + +// There are two steps in middleware processing. +// 1. Middleware initialization, middleware factory gets called with +// next service in chain as parameter. +// 2. Middleware's call method gets called with normal request. +#[derive(Default)] +pub struct AuthChecker; + +// Middleware factory is `Transform` trait +// `S` - type of the next service +// `B` - type of response's body +impl Transform for AuthChecker +where + S: Service, Error = Error> + 'static, + S::Future: 'static, + B: 'static, +{ + type Response = ServiceResponse>; + type Error = Error; + type InitError = (); + type Transform = AuthMiddleware; + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ready(Ok(AuthMiddleware { + service: Rc::new(service), + })) + } +} + +pub struct AuthMiddleware { + service: Rc, +} + +impl Service for AuthMiddleware +where + S: Service, Error = Error> + 'static, + S::Future: 'static, + B: 'static, +{ + type Response = ServiceResponse>; + type Error = Error; + type Future = LocalBoxFuture<'static, Result>; + + forward_ready!(service); + + fn call(&self, req: ServiceRequest) -> Self::Future { + let service = Rc::clone(&self.service); + + Box::pin(async move { + // Check authentication, if required + if !constants::ROUTES_WITHOUT_AUTH.contains(&req.path()) { + let auth = match AuthExtractor::from_request(req.request(), &mut Payload::None) + .into_inner() + { + Ok(auth) => auth, + Err(e) => { + log::error!( + "Failed to extract authentication information from request! {e}" + ); + return Ok(req + .into_response(HttpResponse::InternalServerError().finish()) + .map_into_right_body()); + } + }; + + if !auth.is_authenticated() { + log::error!( + "User attempted to access privileged route without authentication!" + ); + return Ok(req + .into_response(HttpResponse::Unauthorized().json("Please authenticate!")) + .map_into_right_body()); + } + } + + service + .call(req) + .await + .map(ServiceResponse::map_into_left_body) + }) + } +}