149 lines
		
	
	
		
			4.2 KiB
		
	
	
	
		
			Rust
		
	
	
	
	
	
			
		
		
	
	
			149 lines
		
	
	
		
			4.2 KiB
		
	
	
	
		
			Rust
		
	
	
	
	
	
| //! # OpenID service
 | |
| 
 | |
| use crate::app_config::{AppConfig, OIDCProvider};
 | |
| use crate::connections::redis_connection;
 | |
| use crate::constants::OPEN_ID_STATE_DURATION;
 | |
| use crate::utils::string_utils;
 | |
| use crate::utils::time_utils::time;
 | |
| use light_openid::primitives::OpenIDConfig;
 | |
| use std::cell::RefCell;
 | |
| use std::collections::HashMap;
 | |
| use std::net::IpAddr;
 | |
| 
 | |
| thread_local! {
 | |
|     static CONFIG_CACHES: RefCell<HashMap<String, OpenIDConfig>> = RefCell::new(Default::default());
 | |
| }
 | |
| 
 | |
| #[derive(thiserror::Error, Debug)]
 | |
| enum OpenIDServiceError {
 | |
|     #[error("Given provider not found!")]
 | |
|     FindProvider,
 | |
|     #[error("Failed to get provider configuration: {0}")]
 | |
|     GetProviderConfiguration(String),
 | |
|     #[error("Provided state does not exists!")]
 | |
|     NonExistingState,
 | |
|     #[error("The state has expired!")]
 | |
|     ExpiredState,
 | |
|     #[error("Invalid IP address")]
 | |
|     InvalidIP,
 | |
|     #[error("Failed to query token endpoint: {0}")]
 | |
|     QueryTokenEndpoint(String),
 | |
|     #[error("Failed to query user info endpoint: {0}")]
 | |
|     QueryUserInfoEndpoint(String),
 | |
| }
 | |
| 
 | |
| struct OpenIDClient<'a> {
 | |
|     prov: OIDCProvider<'a>,
 | |
|     conf: OpenIDConfig,
 | |
| }
 | |
| 
 | |
| #[derive(serde::Serialize, serde::Deserialize)]
 | |
| struct OpenIDState {
 | |
|     #[serde(rename = "i")]
 | |
|     ip: IpAddr,
 | |
|     #[serde(rename = "e")]
 | |
|     expire: u64,
 | |
|     #[serde(rename = "p")]
 | |
|     prov_id: String,
 | |
| }
 | |
| 
 | |
| impl OpenIDState {
 | |
|     pub fn new(ip: IpAddr, client: &OpenIDClient) -> (String, Self) {
 | |
|         (
 | |
|             string_utils::rand_str(30),
 | |
|             Self {
 | |
|                 ip,
 | |
|                 expire: time() + OPEN_ID_STATE_DURATION.as_secs(),
 | |
|                 prov_id: client.prov.id.to_string(),
 | |
|             },
 | |
|         )
 | |
|     }
 | |
| }
 | |
| 
 | |
/// Get the Redis key under which the login state identified by `state` is stored
fn redis_key(state: &str) -> String {
    let mut key = String::from("oidc-state-");
    key.push_str(state);
    key
}
 | |
| 
 | |
| async fn load_provider_info(prov_id: &str) -> anyhow::Result<OpenIDClient<'_>> {
 | |
|     let prov = AppConfig::get()
 | |
|         .openid_providers()
 | |
|         .into_iter()
 | |
|         .find(|p| p.id.eq(prov_id))
 | |
|         .ok_or(OpenIDServiceError::FindProvider)?;
 | |
| 
 | |
|     if let Some(conf) = CONFIG_CACHES.with(|i| i.borrow().get(prov_id).cloned()) {
 | |
|         return Ok(OpenIDClient { prov, conf });
 | |
|     }
 | |
| 
 | |
|     let conf = OpenIDConfig::load_from_url(prov.configuration_url)
 | |
|         .await
 | |
|         .map_err(|e| OpenIDServiceError::GetProviderConfiguration(e.to_string()))?;
 | |
| 
 | |
|     CONFIG_CACHES.with(|i| {
 | |
|         i.borrow_mut()
 | |
|             .insert(prov.configuration_url.to_string(), conf.clone())
 | |
|     });
 | |
| 
 | |
|     Ok(OpenIDClient { prov, conf })
 | |
| }
 | |
| 
 | |
| /// Get the URL where a user should be redirected for login
 | |
| pub async fn start_login(prov_id: &str, ip: IpAddr) -> anyhow::Result<String> {
 | |
|     let prov = load_provider_info(prov_id).await?;
 | |
|     let (state_key, state) = OpenIDState::new(ip, &prov);
 | |
|     redis_connection::set_value(&redis_key(&state_key), &state, OPEN_ID_STATE_DURATION).await?;
 | |
| 
 | |
|     Ok(prov.conf.gen_authorization_url(
 | |
|         prov.prov.client_id,
 | |
|         &state_key,
 | |
|         &AppConfig::get().oidc_redirect_url(),
 | |
|     ))
 | |
| }
 | |
| 
 | |
| /// Finish OpenID login
 | |
| pub async fn finish_login(
 | |
|     ip: IpAddr,
 | |
|     code: &str,
 | |
|     state_key: &str,
 | |
| ) -> anyhow::Result<light_openid::primitives::OpenIDUserInfo> {
 | |
|     // Consume state
 | |
|     let state = redis_connection::get_value::<OpenIDState>(&redis_key(state_key))
 | |
|         .await?
 | |
|         .ok_or(OpenIDServiceError::NonExistingState)?;
 | |
|     redis_connection::remove_value(&redis_key(state_key)).await?;
 | |
| 
 | |
|     if state.expire < time() {
 | |
|         return Err(OpenIDServiceError::ExpiredState.into());
 | |
|     }
 | |
| 
 | |
|     if state.ip != ip {
 | |
|         log::error!(
 | |
|             "Mismatching IP addresses (expected {} / got {}",
 | |
|             state.ip,
 | |
|             ip
 | |
|         );
 | |
|         return Err(OpenIDServiceError::InvalidIP.into());
 | |
|     }
 | |
| 
 | |
|     // Query provider
 | |
|     let prov = load_provider_info(&state.prov_id).await?;
 | |
|     let (token, _) = prov
 | |
|         .conf
 | |
|         .request_token(
 | |
|             prov.prov.client_id,
 | |
|             prov.prov.client_secret,
 | |
|             code,
 | |
|             &AppConfig::get().oidc_redirect_url(),
 | |
|         )
 | |
|         .await
 | |
|         .map_err(|e| OpenIDServiceError::QueryTokenEndpoint(e.to_string()))?;
 | |
| 
 | |
|     let (user_info, _) = prov
 | |
|         .conf
 | |
|         .request_user_info(&token)
 | |
|         .await
 | |
|         .map_err(|e| OpenIDServiceError::QueryUserInfoEndpoint(e.to_string()))?;
 | |
| 
 | |
|     Ok(user_info)
 | |
| }
 |