Can finish open id login
This commit is contained in:
		@@ -49,3 +49,10 @@ where
 | 
			
		||||
 | 
			
		||||
    Ok(())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Remove a value from Redis
 | 
			
		||||
pub async fn remove_value(key: &str) -> anyhow::Result<()> {
 | 
			
		||||
    execute_request(|conn| Ok(conn.del(key)?))?;
 | 
			
		||||
 | 
			
		||||
    Ok(())
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -254,3 +254,53 @@ pub async fn start_openid_login(
 | 
			
		||||
 | 
			
		||||
    Ok(HttpResponse::Ok().json(StartOpenIDLoginResponse { url }))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(serde::Deserialize)]
 | 
			
		||||
pub struct FinishOpenIDLoginQuery {
 | 
			
		||||
    code: String,
 | 
			
		||||
    state: String,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Finish OpenID login
 | 
			
		||||
pub async fn finish_openid_login(
 | 
			
		||||
    remote_ip: RemoteIP,
 | 
			
		||||
    req: web::Json<FinishOpenIDLoginQuery>,
 | 
			
		||||
) -> HttpResult {
 | 
			
		||||
    let user_info = openid_service::finish_login(remote_ip.0, &req.code, &req.state).await?;
 | 
			
		||||
 | 
			
		||||
    if user_info.email_verified != Some(true) {
 | 
			
		||||
        log::error!("Email is not verified!");
 | 
			
		||||
        return Ok(
 | 
			
		||||
            HttpResponse::Unauthorized().json("Email non vérifié par le fournisseur d'identité !")
 | 
			
		||||
        );
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    let mail = match user_info.email {
 | 
			
		||||
        Some(m) => m,
 | 
			
		||||
        None => {
 | 
			
		||||
            return Ok(HttpResponse::Unauthorized()
 | 
			
		||||
                .json("Email non spécifié par le fournisseur d'identité !"));
 | 
			
		||||
        }
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    // Create the account, if required
 | 
			
		||||
    if !users_service::exists_email(&mail).await? {
 | 
			
		||||
        let name = match (user_info.name, user_info.given_name, user_info.family_name) {
 | 
			
		||||
            (Some(name), _, _) => name,
 | 
			
		||||
            (None, Some(g), Some(f)) => format!("{g} {f}"),
 | 
			
		||||
            (_, _, _) => {
 | 
			
		||||
                return Ok(HttpResponse::Unauthorized()
 | 
			
		||||
                    .json("Nom non spécifié par le fournisseur d'identité !"));
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        users_service::create_account(&name, &mail).await?;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    let user = users_service::get_by_mail(&mail).await?;
 | 
			
		||||
 | 
			
		||||
    // OpenID auth is enough to validate accounts
 | 
			
		||||
    users_service::validate_account(&user).await?;
 | 
			
		||||
 | 
			
		||||
    finish_login(&user).await
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -47,6 +47,10 @@ async fn main() -> std::io::Result<()> {
 | 
			
		||||
                "/auth/start_openid_login",
 | 
			
		||||
                web::post().to(auth_controller::start_openid_login),
 | 
			
		||||
            )
 | 
			
		||||
            .route(
 | 
			
		||||
                "/auth/finish_openid_login",
 | 
			
		||||
                web::post().to(auth_controller::finish_openid_login),
 | 
			
		||||
            )
 | 
			
		||||
            // User controller
 | 
			
		||||
            .route("/user/info", web::get().to(user_controller::auth_info))
 | 
			
		||||
    })
 | 
			
		||||
 
 | 
			
		||||
@@ -8,12 +8,28 @@ use crate::utils::time_utils::time;
 | 
			
		||||
use light_openid::primitives::OpenIDConfig;
 | 
			
		||||
use std::cell::RefCell;
 | 
			
		||||
use std::collections::HashMap;
 | 
			
		||||
use std::io::ErrorKind;
 | 
			
		||||
use std::net::IpAddr;
 | 
			
		||||
 | 
			
		||||
thread_local! {
 | 
			
		||||
    static CONFIG_CACHES: RefCell<HashMap<String, OpenIDConfig>> = RefCell::new(Default::default());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(thiserror::Error, Debug)]
 | 
			
		||||
enum OpenIDServiceError {
 | 
			
		||||
    #[error("Given provider not found!")]
 | 
			
		||||
    FindProvider,
 | 
			
		||||
    #[error("Failed to get provider configuration: {0}")]
 | 
			
		||||
    GetProviderConfiguration(String),
 | 
			
		||||
    #[error("Provided state does not exists!")]
 | 
			
		||||
    NonExistingState,
 | 
			
		||||
    #[error("The state has expired!")]
 | 
			
		||||
    ExpiredState,
 | 
			
		||||
    #[error("Invalid IP address")]
 | 
			
		||||
    InvalidIP,
 | 
			
		||||
    #[error("Failed to query token endpoint: {0}")]
 | 
			
		||||
    QueryTokenEndpoint(String),
 | 
			
		||||
    #[error("Failed to query user info endpoint: {0}")]
 | 
			
		||||
    QueryUserInfoEndpoint(String),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct OpenIDClient<'a> {
 | 
			
		||||
@@ -53,7 +69,7 @@ async fn load_provider_info(prov_id: &str) -> anyhow::Result<OpenIDClient> {
 | 
			
		||||
        .openid_providers()
 | 
			
		||||
        .into_iter()
 | 
			
		||||
        .find(|p| p.id.eq(prov_id))
 | 
			
		||||
        .ok_or_else(|| std::io::Error::new(ErrorKind::Other, "Provider not found!"))?;
 | 
			
		||||
        .ok_or(OpenIDServiceError::FindProvider)?;
 | 
			
		||||
 | 
			
		||||
    if let Some(conf) = CONFIG_CACHES.with(|i| i.borrow().get(prov_id).cloned()) {
 | 
			
		||||
        return Ok(OpenIDClient { prov, conf });
 | 
			
		||||
@@ -61,7 +77,7 @@ async fn load_provider_info(prov_id: &str) -> anyhow::Result<OpenIDClient> {
 | 
			
		||||
 | 
			
		||||
    let conf = OpenIDConfig::load_from_url(prov.configuration_url)
 | 
			
		||||
        .await
 | 
			
		||||
        .map_err(|e| std::io::Error::new(ErrorKind::Other, e.to_string()))?;
 | 
			
		||||
        .map_err(|e| OpenIDServiceError::GetProviderConfiguration(e.to_string()))?;
 | 
			
		||||
 | 
			
		||||
    CONFIG_CACHES.with(|i| {
 | 
			
		||||
        i.borrow_mut()
 | 
			
		||||
@@ -83,3 +99,50 @@ pub async fn start_login(prov_id: &str, ip: IpAddr) -> anyhow::Result<String> {
 | 
			
		||||
        &AppConfig::get().oidc_redirect_url,
 | 
			
		||||
    ))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Finish OpenID login
 | 
			
		||||
pub async fn finish_login(
 | 
			
		||||
    ip: IpAddr,
 | 
			
		||||
    code: &str,
 | 
			
		||||
    state_key: &str,
 | 
			
		||||
) -> anyhow::Result<light_openid::primitives::OpenIDUserInfo> {
 | 
			
		||||
    // Consume state
 | 
			
		||||
    let state = redis_connection::get_value::<OpenIDState>(&redis_key(state_key))
 | 
			
		||||
        .await?
 | 
			
		||||
        .ok_or(OpenIDServiceError::NonExistingState)?;
 | 
			
		||||
    redis_connection::remove_value(&redis_key(state_key)).await?;
 | 
			
		||||
 | 
			
		||||
    if state.expire < time() {
 | 
			
		||||
        return Err(OpenIDServiceError::ExpiredState.into());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if state.ip != ip {
 | 
			
		||||
        log::error!(
 | 
			
		||||
            "Mismatching IP addresses (expected {} / got {}",
 | 
			
		||||
            state.ip,
 | 
			
		||||
            ip
 | 
			
		||||
        );
 | 
			
		||||
        return Err(OpenIDServiceError::InvalidIP.into());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Query provider
 | 
			
		||||
    let prov = load_provider_info(&state.prov_id).await?;
 | 
			
		||||
    let (token, _) = prov
 | 
			
		||||
        .conf
 | 
			
		||||
        .request_token(
 | 
			
		||||
            prov.prov.client_id,
 | 
			
		||||
            prov.prov.client_secret,
 | 
			
		||||
            code,
 | 
			
		||||
            &AppConfig::get().oidc_redirect_url,
 | 
			
		||||
        )
 | 
			
		||||
        .await
 | 
			
		||||
        .map_err(|e| OpenIDServiceError::QueryTokenEndpoint(e.to_string()))?;
 | 
			
		||||
 | 
			
		||||
    let (user_info, _) = prov
 | 
			
		||||
        .conf
 | 
			
		||||
        .request_user_info(&token)
 | 
			
		||||
        .await
 | 
			
		||||
        .map_err(|e| OpenIDServiceError::QueryUserInfoEndpoint(e.to_string()))?;
 | 
			
		||||
 | 
			
		||||
    Ok(user_info)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user