GeneIT/geneit_backend/src/services/openid_service.rs

149 lines
4.2 KiB
Rust
Raw Normal View History

2023-06-02 09:49:18 +00:00
//! # OpenID service
use crate::app_config::{AppConfig, OIDCProvider};
use crate::connections::redis_connection;
use crate::constants::OPEN_ID_STATE_DURATION;
use crate::utils::string_utils;
use crate::utils::time_utils::time;
use light_openid::primitives::OpenIDConfig;
use std::cell::RefCell;
use std::collections::HashMap;
use std::net::IpAddr;
thread_local! {
static CONFIG_CACHES: RefCell<HashMap<String, OpenIDConfig>> = RefCell::new(Default::default());
2023-06-02 13:04:49 +00:00
}
2023-06-02 09:49:18 +00:00
2023-06-02 13:04:49 +00:00
#[derive(thiserror::Error, Debug)]
enum OpenIDServiceError {
#[error("Given provider not found!")]
FindProvider,
#[error("Failed to get provider configuration: {0}")]
GetProviderConfiguration(String),
#[error("Provided state does not exists!")]
NonExistingState,
#[error("The state has expired!")]
ExpiredState,
#[error("Invalid IP address")]
InvalidIP,
#[error("Failed to query token endpoint: {0}")]
QueryTokenEndpoint(String),
#[error("Failed to query user info endpoint: {0}")]
QueryUserInfoEndpoint(String),
2023-06-02 09:49:18 +00:00
}
/// An OpenID provider entry from the application configuration, paired
/// with the OpenID configuration loaded for it.
struct OpenIDClient<'a> {
    /// Provider entry taken from `AppConfig::openid_providers()`.
    prov: OIDCProvider<'a>,
    /// Configuration loaded from the provider's `configuration_url`.
    conf: OpenIDConfig,
}
/// Login state stored in Redis between the start and the end of an
/// OpenID login.
///
/// Field names are renamed to single letters when serialized (presumably to
/// keep the stored value small).
#[derive(serde::Deserialize, serde::Serialize)]
struct OpenIDState {
    /// IP address of the client that started the login.
    #[serde(rename = "i")]
    ip: IpAddr,
    /// Expiration time, compared against `time()` when the login finishes.
    #[serde(rename = "e")]
    expire: u64,
    /// ID of the provider the login was started with.
    #[serde(rename = "p")]
    prov_id: String,
}
impl OpenIDState {
    /// Creates a new login state for a client at `ip` using `client`'s provider.
    ///
    /// Returns the key generated by `string_utils::rand_str(30)` together
    /// with the state, whose `expire` is `time()` plus
    /// `OPEN_ID_STATE_DURATION` expressed in seconds.
    pub fn new(ip: IpAddr, client: &OpenIDClient) -> (String, Self) {
        let key = string_utils::rand_str(30);
        let expire = time() + OPEN_ID_STATE_DURATION.as_secs();
        let prov_id = client.prov.id.to_string();
        (key, Self { ip, expire, prov_id })
    }
}
/// Builds the Redis key under which the login state identified by `state`
/// is stored.
fn redis_key(state: &str) -> String {
    let mut key = String::from("oidc-state-");
    key.push_str(state);
    key
}
async fn load_provider_info(prov_id: &str) -> anyhow::Result<OpenIDClient> {
let prov = AppConfig::get()
.openid_providers()
.into_iter()
.find(|p| p.id.eq(prov_id))
2023-06-02 13:04:49 +00:00
.ok_or(OpenIDServiceError::FindProvider)?;
2023-06-02 09:49:18 +00:00
if let Some(conf) = CONFIG_CACHES.with(|i| i.borrow().get(prov_id).cloned()) {
return Ok(OpenIDClient { prov, conf });
}
let conf = OpenIDConfig::load_from_url(prov.configuration_url)
.await
2023-06-02 13:04:49 +00:00
.map_err(|e| OpenIDServiceError::GetProviderConfiguration(e.to_string()))?;
2023-06-02 09:49:18 +00:00
CONFIG_CACHES.with(|i| {
i.borrow_mut()
.insert(prov.configuration_url.to_string(), conf.clone())
});
Ok(OpenIDClient { prov, conf })
}
/// Get the URL where a user should be redirected for login
pub async fn start_login(prov_id: &str, ip: IpAddr) -> anyhow::Result<String> {
let prov = load_provider_info(prov_id).await?;
let (state_key, state) = OpenIDState::new(ip, &prov);
redis_connection::set_value(&redis_key(&state_key), &state, OPEN_ID_STATE_DURATION).await?;
Ok(prov.conf.gen_authorization_url(
prov.prov.client_id,
&state_key,
2023-06-06 13:47:30 +00:00
&AppConfig::get().oidc_redirect_url(),
2023-06-02 09:49:18 +00:00
))
}
2023-06-02 13:04:49 +00:00
/// Finish OpenID login
pub async fn finish_login(
ip: IpAddr,
code: &str,
state_key: &str,
) -> anyhow::Result<light_openid::primitives::OpenIDUserInfo> {
// Consume state
let state = redis_connection::get_value::<OpenIDState>(&redis_key(state_key))
.await?
.ok_or(OpenIDServiceError::NonExistingState)?;
redis_connection::remove_value(&redis_key(state_key)).await?;
if state.expire < time() {
return Err(OpenIDServiceError::ExpiredState.into());
}
if state.ip != ip {
log::error!(
"Mismatching IP addresses (expected {} / got {}",
state.ip,
ip
);
return Err(OpenIDServiceError::InvalidIP.into());
}
// Query provider
let prov = load_provider_info(&state.prov_id).await?;
let (token, _) = prov
.conf
.request_token(
prov.prov.client_id,
prov.prov.client_secret,
code,
2023-06-06 13:47:30 +00:00
&AppConfig::get().oidc_redirect_url(),
2023-06-02 13:04:49 +00:00
)
.await
.map_err(|e| OpenIDServiceError::QueryTokenEndpoint(e.to_string()))?;
let (user_info, _) = prov
.conf
.request_user_info(&token)
.await
.map_err(|e| OpenIDServiceError::QueryUserInfoEndpoint(e.to_string()))?;
Ok(user_info)
}