diff --git a/geneit_backend/Cargo.lock b/geneit_backend/Cargo.lock index e68b7d6..bd1efac 100644 --- a/geneit_backend/Cargo.lock +++ b/geneit_backend/Cargo.lock @@ -796,6 +796,7 @@ dependencies = [ "redis", "serde", "serde_json", + "thiserror", ] [[package]] @@ -1740,6 +1741,26 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "thiserror" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.16", +] + [[package]] name = "time" version = "0.3.21" diff --git a/geneit_backend/Cargo.toml b/geneit_backend/Cargo.toml index 951ded6..8554c2a 100644 --- a/geneit_backend/Cargo.toml +++ b/geneit_backend/Cargo.toml @@ -22,4 +22,5 @@ redis = "0.23.0" lettre = "0.10.4" rand = "0.8.5" bcrypt = "0.14.0" -light-openid = "1.0.1" \ No newline at end of file +light-openid = "1.0.1" +thiserror = "1.0.40" \ No newline at end of file diff --git a/geneit_backend/src/connections/redis_connection.rs b/geneit_backend/src/connections/redis_connection.rs index 0d58aad..d3f8692 100644 --- a/geneit_backend/src/connections/redis_connection.rs +++ b/geneit_backend/src/connections/redis_connection.rs @@ -49,3 +49,10 @@ where Ok(()) } + +/// Remove a value from Redis +pub async fn remove_value(key: &str) -> anyhow::Result<()> { + execute_request(|conn| Ok(conn.del(key)?))?; + + Ok(()) +} diff --git a/geneit_backend/src/controllers/auth_controller.rs b/geneit_backend/src/controllers/auth_controller.rs index 7b536cc..9c2e999 100644 --- a/geneit_backend/src/controllers/auth_controller.rs +++ b/geneit_backend/src/controllers/auth_controller.rs @@ -254,3 +254,53 @@ pub async fn start_openid_login( Ok(HttpResponse::Ok().json(StartOpenIDLoginResponse { url })) } + +#[derive(serde::Deserialize)] +pub struct FinishOpenIDLoginQuery { + code: String, + state: String, +} + +/// Finish OpenID login +pub async fn finish_openid_login( + remote_ip: RemoteIP, + req: web::Json, +) -> HttpResult { + let user_info = openid_service::finish_login(remote_ip.0, &req.code, &req.state).await?; + + if user_info.email_verified != Some(true) { + log::error!("Email is not verified!"); + return Ok( + HttpResponse::Unauthorized().json("Email non vérifié par le fournisseur d'identité !") + ); + } + + let mail = match user_info.email { + Some(m) => m, + None => { + return Ok(HttpResponse::Unauthorized() + .json("Email non spécifié par le fournisseur d'identité !")); + } + }; + + // Create the account, if required + if !users_service::exists_email(&mail).await? { + let name = match (user_info.name, user_info.given_name, user_info.family_name) { + (Some(name), _, _) => name, + (None, Some(g), Some(f)) => format!("{g} {f}"), + (_, _, _) => { + return Ok(HttpResponse::Unauthorized() + .json("Nom non spécifié par le fournisseur d'identité !")); + } + }; + + users_service::create_account(&name, &mail).await?; + } + + let user = users_service::get_by_mail(&mail).await?; + + // OpenID auth is enough to validate accounts + users_service::validate_account(&user).await?; + + finish_login(&user).await +} diff --git a/geneit_backend/src/main.rs b/geneit_backend/src/main.rs index 134b8a4..b8e2912 100644 --- a/geneit_backend/src/main.rs +++ b/geneit_backend/src/main.rs @@ -47,6 +47,10 @@ async fn main() -> std::io::Result<()> { "/auth/start_openid_login", web::post().to(auth_controller::start_openid_login), ) + .route( + "/auth/finish_openid_login", + web::post().to(auth_controller::finish_openid_login), + ) // User controller .route("/user/info", web::get().to(user_controller::auth_info)) }) diff --git a/geneit_backend/src/services/openid_service.rs b/geneit_backend/src/services/openid_service.rs index df92921..43607a5 100644 --- a/geneit_backend/src/services/openid_service.rs +++ b/geneit_backend/src/services/openid_service.rs @@ -8,12 +8,28 @@ use crate::utils::time_utils::time; use light_openid::primitives::OpenIDConfig; use std::cell::RefCell; use std::collections::HashMap; -use std::io::ErrorKind; use std::net::IpAddr; thread_local! { static CONFIG_CACHES: RefCell> = RefCell::new(Default::default()); +} +#[derive(thiserror::Error, Debug)] +enum OpenIDServiceError { + #[error("Given provider not found!")] + FindProvider, + #[error("Failed to get provider configuration: {0}")] + GetProviderConfiguration(String), + #[error("Provided state does not exists!")] + NonExistingState, + #[error("The state has expired!")] + ExpiredState, + #[error("Invalid IP address")] + InvalidIP, + #[error("Failed to query token endpoint: {0}")] + QueryTokenEndpoint(String), + #[error("Failed to query user info endpoint: {0}")] + QueryUserInfoEndpoint(String), } struct OpenIDClient<'a> { @@ -53,7 +69,7 @@ async fn load_provider_info(prov_id: &str) -> anyhow::Result { .openid_providers() .into_iter() .find(|p| p.id.eq(prov_id)) - .ok_or_else(|| std::io::Error::new(ErrorKind::Other, "Provider not found!"))?; + .ok_or(OpenIDServiceError::FindProvider)?; if let Some(conf) = CONFIG_CACHES.with(|i| i.borrow().get(prov_id).cloned()) { return Ok(OpenIDClient { prov, conf }); @@ -61,7 +77,7 @@ async fn load_provider_info(prov_id: &str) -> anyhow::Result { let conf = OpenIDConfig::load_from_url(prov.configuration_url) .await - .map_err(|e| std::io::Error::new(ErrorKind::Other, e.to_string()))?; + .map_err(|e| OpenIDServiceError::GetProviderConfiguration(e.to_string()))?; CONFIG_CACHES.with(|i| { i.borrow_mut() @@ -83,3 +99,50 @@ pub async fn start_login(prov_id: &str, ip: IpAddr) -> anyhow::Result { &AppConfig::get().oidc_redirect_url, )) } + +/// Finish OpenID login +pub async fn finish_login( + ip: IpAddr, + code: &str, + state_key: &str, +) -> anyhow::Result { + // Consume state + let state = redis_connection::get_value::(&redis_key(state_key)) + .await? + .ok_or(OpenIDServiceError::NonExistingState)?; + redis_connection::remove_value(&redis_key(state_key)).await?; + + if state.expire < time() { + return Err(OpenIDServiceError::ExpiredState.into()); + } + + if state.ip != ip { + log::error!( + "Mismatching IP addresses (expected {} / got {}", + state.ip, + ip + ); + return Err(OpenIDServiceError::InvalidIP.into()); + } + + // Query provider + let prov = load_provider_info(&state.prov_id).await?; + let (token, _) = prov + .conf + .request_token( + prov.prov.client_id, + prov.prov.client_secret, + code, + &AppConfig::get().oidc_redirect_url, + ) + .await + .map_err(|e| OpenIDServiceError::QueryTokenEndpoint(e.to_string()))?; + + let (user_info, _) = prov + .conf + .request_user_info(&token) + .await + .map_err(|e| OpenIDServiceError::QueryUserInfoEndpoint(e.to_string()))?; + + Ok(user_info) +}