//! # OpenID service use crate::app_config::{AppConfig, OIDCProvider}; use crate::connections::redis_connection; use crate::constants::OPEN_ID_STATE_DURATION; use crate::utils::string_utils; use crate::utils::time_utils::time; use light_openid::primitives::OpenIDConfig; use std::cell::RefCell; use std::collections::HashMap; use std::net::IpAddr; thread_local! { static CONFIG_CACHES: RefCell> = RefCell::new(Default::default()); } #[derive(thiserror::Error, Debug)] enum OpenIDServiceError { #[error("Given provider not found!")] FindProvider, #[error("Failed to get provider configuration: {0}")] GetProviderConfiguration(String), #[error("Provided state does not exists!")] NonExistingState, #[error("The state has expired!")] ExpiredState, #[error("Invalid IP address")] InvalidIP, #[error("Failed to query token endpoint: {0}")] QueryTokenEndpoint(String), #[error("Failed to query user info endpoint: {0}")] QueryUserInfoEndpoint(String), } struct OpenIDClient<'a> { prov: OIDCProvider<'a>, conf: OpenIDConfig, } #[derive(serde::Serialize, serde::Deserialize)] struct OpenIDState { #[serde(rename = "i")] ip: IpAddr, #[serde(rename = "e")] expire: u64, #[serde(rename = "p")] prov_id: String, } impl OpenIDState { pub fn new(ip: IpAddr, client: &OpenIDClient) -> (String, Self) { ( string_utils::rand_str(30), Self { ip, expire: time() + OPEN_ID_STATE_DURATION.as_secs(), prov_id: client.prov.id.to_string(), }, ) } } fn redis_key(state: &str) -> String { format!("oidc-state-{state}") } async fn load_provider_info(prov_id: &str) -> anyhow::Result { let prov = AppConfig::get() .openid_providers() .into_iter() .find(|p| p.id.eq(prov_id)) .ok_or(OpenIDServiceError::FindProvider)?; if let Some(conf) = CONFIG_CACHES.with(|i| i.borrow().get(prov_id).cloned()) { return Ok(OpenIDClient { prov, conf }); } let conf = OpenIDConfig::load_from_url(prov.configuration_url) .await .map_err(|e| OpenIDServiceError::GetProviderConfiguration(e.to_string()))?; CONFIG_CACHES.with(|i| { i.borrow_mut() .insert(prov.configuration_url.to_string(), conf.clone()) }); Ok(OpenIDClient { prov, conf }) } /// Get the URL where a user should be redirected for login pub async fn start_login(prov_id: &str, ip: IpAddr) -> anyhow::Result { let prov = load_provider_info(prov_id).await?; let (state_key, state) = OpenIDState::new(ip, &prov); redis_connection::set_value(&redis_key(&state_key), &state, OPEN_ID_STATE_DURATION).await?; Ok(prov.conf.gen_authorization_url( prov.prov.client_id, &state_key, &AppConfig::get().oidc_redirect_url, )) } /// Finish OpenID login pub async fn finish_login( ip: IpAddr, code: &str, state_key: &str, ) -> anyhow::Result { // Consume state let state = redis_connection::get_value::(&redis_key(state_key)) .await? .ok_or(OpenIDServiceError::NonExistingState)?; redis_connection::remove_value(&redis_key(state_key)).await?; if state.expire < time() { return Err(OpenIDServiceError::ExpiredState.into()); } if state.ip != ip { log::error!( "Mismatching IP addresses (expected {} / got {}", state.ip, ip ); return Err(OpenIDServiceError::InvalidIP.into()); } // Query provider let prov = load_provider_info(&state.prov_id).await?; let (token, _) = prov .conf .request_token( prov.prov.client_id, prov.prov.client_secret, code, &AppConfig::get().oidc_redirect_url, ) .await .map_err(|e| OpenIDServiceError::QueryTokenEndpoint(e.to_string()))?; let (user_info, _) = prov .conf .request_user_info(&token) .await .map_err(|e| OpenIDServiceError::QueryUserInfoEndpoint(e.to_string()))?; Ok(user_info) }