2023-06-02 09:49:18 +00:00
|
|
|
//! # OpenID service
|
|
|
|
|
|
|
|
use crate::app_config::{AppConfig, OIDCProvider};
|
|
|
|
use crate::connections::redis_connection;
|
|
|
|
use crate::constants::OPEN_ID_STATE_DURATION;
|
|
|
|
use crate::utils::string_utils;
|
|
|
|
use crate::utils::time_utils::time;
|
|
|
|
use light_openid::primitives::OpenIDConfig;
|
|
|
|
use std::cell::RefCell;
|
|
|
|
use std::collections::HashMap;
|
|
|
|
use std::net::IpAddr;
|
|
|
|
|
|
|
|
thread_local! {
|
|
|
|
static CONFIG_CACHES: RefCell<HashMap<String, OpenIDConfig>> = RefCell::new(Default::default());
|
2023-06-02 13:04:49 +00:00
|
|
|
}
|
2023-06-02 09:49:18 +00:00
|
|
|
|
2023-06-02 13:04:49 +00:00
|
|
|
#[derive(thiserror::Error, Debug)]
|
|
|
|
enum OpenIDServiceError {
|
|
|
|
#[error("Given provider not found!")]
|
|
|
|
FindProvider,
|
|
|
|
#[error("Failed to get provider configuration: {0}")]
|
|
|
|
GetProviderConfiguration(String),
|
|
|
|
#[error("Provided state does not exists!")]
|
|
|
|
NonExistingState,
|
|
|
|
#[error("The state has expired!")]
|
|
|
|
ExpiredState,
|
|
|
|
#[error("Invalid IP address")]
|
|
|
|
InvalidIP,
|
|
|
|
#[error("Failed to query token endpoint: {0}")]
|
|
|
|
QueryTokenEndpoint(String),
|
|
|
|
#[error("Failed to query user info endpoint: {0}")]
|
|
|
|
QueryUserInfoEndpoint(String),
|
2023-06-02 09:49:18 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
struct OpenIDClient<'a> {
|
|
|
|
prov: OIDCProvider<'a>,
|
|
|
|
conf: OpenIDConfig,
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(serde::Serialize, serde::Deserialize)]
|
|
|
|
struct OpenIDState {
|
|
|
|
#[serde(rename = "i")]
|
|
|
|
ip: IpAddr,
|
|
|
|
#[serde(rename = "e")]
|
|
|
|
expire: u64,
|
|
|
|
#[serde(rename = "p")]
|
|
|
|
prov_id: String,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl OpenIDState {
|
|
|
|
pub fn new(ip: IpAddr, client: &OpenIDClient) -> (String, Self) {
|
|
|
|
(
|
|
|
|
string_utils::rand_str(30),
|
|
|
|
Self {
|
|
|
|
ip,
|
|
|
|
expire: time() + OPEN_ID_STATE_DURATION.as_secs(),
|
|
|
|
prov_id: client.prov.id.to_string(),
|
|
|
|
},
|
|
|
|
)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Build the Redis key under which the login state `state` is stored.
fn redis_key(state: &str) -> String {
    ["oidc-state-", state].concat()
}
|
|
|
|
|
|
|
|
async fn load_provider_info(prov_id: &str) -> anyhow::Result<OpenIDClient> {
|
|
|
|
let prov = AppConfig::get()
|
|
|
|
.openid_providers()
|
|
|
|
.into_iter()
|
|
|
|
.find(|p| p.id.eq(prov_id))
|
2023-06-02 13:04:49 +00:00
|
|
|
.ok_or(OpenIDServiceError::FindProvider)?;
|
2023-06-02 09:49:18 +00:00
|
|
|
|
|
|
|
if let Some(conf) = CONFIG_CACHES.with(|i| i.borrow().get(prov_id).cloned()) {
|
|
|
|
return Ok(OpenIDClient { prov, conf });
|
|
|
|
}
|
|
|
|
|
|
|
|
let conf = OpenIDConfig::load_from_url(prov.configuration_url)
|
|
|
|
.await
|
2023-06-02 13:04:49 +00:00
|
|
|
.map_err(|e| OpenIDServiceError::GetProviderConfiguration(e.to_string()))?;
|
2023-06-02 09:49:18 +00:00
|
|
|
|
|
|
|
CONFIG_CACHES.with(|i| {
|
|
|
|
i.borrow_mut()
|
|
|
|
.insert(prov.configuration_url.to_string(), conf.clone())
|
|
|
|
});
|
|
|
|
|
|
|
|
Ok(OpenIDClient { prov, conf })
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Get the URL where a user should be redirected for login
|
|
|
|
pub async fn start_login(prov_id: &str, ip: IpAddr) -> anyhow::Result<String> {
|
|
|
|
let prov = load_provider_info(prov_id).await?;
|
|
|
|
let (state_key, state) = OpenIDState::new(ip, &prov);
|
|
|
|
redis_connection::set_value(&redis_key(&state_key), &state, OPEN_ID_STATE_DURATION).await?;
|
|
|
|
|
|
|
|
Ok(prov.conf.gen_authorization_url(
|
|
|
|
prov.prov.client_id,
|
|
|
|
&state_key,
|
2023-06-06 13:47:30 +00:00
|
|
|
&AppConfig::get().oidc_redirect_url(),
|
2023-06-02 09:49:18 +00:00
|
|
|
))
|
|
|
|
}
|
2023-06-02 13:04:49 +00:00
|
|
|
|
|
|
|
/// Finish OpenID login
|
|
|
|
pub async fn finish_login(
|
|
|
|
ip: IpAddr,
|
|
|
|
code: &str,
|
|
|
|
state_key: &str,
|
|
|
|
) -> anyhow::Result<light_openid::primitives::OpenIDUserInfo> {
|
|
|
|
// Consume state
|
|
|
|
let state = redis_connection::get_value::<OpenIDState>(&redis_key(state_key))
|
|
|
|
.await?
|
|
|
|
.ok_or(OpenIDServiceError::NonExistingState)?;
|
|
|
|
redis_connection::remove_value(&redis_key(state_key)).await?;
|
|
|
|
|
|
|
|
if state.expire < time() {
|
|
|
|
return Err(OpenIDServiceError::ExpiredState.into());
|
|
|
|
}
|
|
|
|
|
|
|
|
if state.ip != ip {
|
|
|
|
log::error!(
|
|
|
|
"Mismatching IP addresses (expected {} / got {}",
|
|
|
|
state.ip,
|
|
|
|
ip
|
|
|
|
);
|
|
|
|
return Err(OpenIDServiceError::InvalidIP.into());
|
|
|
|
}
|
|
|
|
|
|
|
|
// Query provider
|
|
|
|
let prov = load_provider_info(&state.prov_id).await?;
|
|
|
|
let (token, _) = prov
|
|
|
|
.conf
|
|
|
|
.request_token(
|
|
|
|
prov.prov.client_id,
|
|
|
|
prov.prov.client_secret,
|
|
|
|
code,
|
2023-06-06 13:47:30 +00:00
|
|
|
&AppConfig::get().oidc_redirect_url(),
|
2023-06-02 13:04:49 +00:00
|
|
|
)
|
|
|
|
.await
|
|
|
|
.map_err(|e| OpenIDServiceError::QueryTokenEndpoint(e.to_string()))?;
|
|
|
|
|
|
|
|
let (user_info, _) = prov
|
|
|
|
.conf
|
|
|
|
.request_user_info(&token)
|
|
|
|
.await
|
|
|
|
.map_err(|e| OpenIDServiceError::QueryUserInfoEndpoint(e.to_string()))?;
|
|
|
|
|
|
|
|
Ok(user_info)
|
|
|
|
}
|