diff --git a/src/main.rs b/src/main.rs index 1663a2a..f9849c0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,6 @@ -use actix_web::{get, App, HttpResponse, HttpServer}; +use actix_web::{get, web, App, HttpResponse, HttpServer}; use askama::Template; +use actix_web::middleware::Logger; use oidc_test_client::app_config::AppConfig; use oidc_test_client::openid_primitives::OpenIDConfig; @@ -26,6 +27,27 @@ struct HomeTemplate { redirect_url: String, } +#[derive(Template)] +#[template(path = "result.html")] +struct ResultTemplate { + token: String, + user_info: String, +} + +#[derive(Template)] +#[template(path = "error.html")] +struct ErrorTemplate<'a> { + message: &'a str, +} + +impl<'a> ErrorTemplate<'a> { + pub fn build(message: &'a str) -> HttpResponse { + HttpResponse::Unauthorized() + .content_type("text/html") + .body(Self { message }.render().unwrap()) + } +} + #[get("/")] async fn home() -> HttpResponse { HttpResponse::Ok().content_type("text/html").body( @@ -39,13 +61,25 @@ async fn home() -> HttpResponse { #[get("/start")] async fn start(remote_ip: RemoteIP) -> HttpResponse { - let config = OpenIDConfig::load_from(&AppConfig::get().configuration_url) - .await - .expect("Failed to load provider configuration!"); + let config = match OpenIDConfig::load_from(&AppConfig::get().configuration_url).await { + Ok(c) => c, + Err(e) => { + log::error!("Failed to load OpenID configuration! {e}"); + return ErrorTemplate::build("Failed to load OpenID configuration!"); + } + }; + + let state = match StateManager::gen_state(&remote_ip) { + Ok(s) => s, + Err(e) => { + log::error!("Failed to generate state! {:?}", e); + return ErrorTemplate::build("Failed to generate state!"); + } + }; let authorization_url = config.authorization_url( &AppConfig::get().client_id, - &StateManager::gen_state(&remote_ip).expect("Failed to generate state!"), + &state, &AppConfig::get().redirect_url(), ); @@ -54,6 +88,65 @@ async fn start(remote_ip: RemoteIP) -> HttpResponse { .finish() } +#[derive(serde::Deserialize)] +struct RedirectQuery { + state: String, + code: String, +} + +#[get("/redirect")] +async fn redirect(remote_ip: RemoteIP, query: web::Query) -> HttpResponse { + // First, validate state + if let Err(e) = StateManager::validate_state(&remote_ip, &query.state) { + log::error!("Failed to validate state {}: {:?}", query.state, e); + return ErrorTemplate::build("State could not be validated!"); + } + + // Then, load OpenID configuration + let config = match OpenIDConfig::load_from(&AppConfig::get().configuration_url).await { + Ok(c) => c, + Err(e) => { + log::error!("Failed to load OpenID configuration! {e}"); + return ErrorTemplate::build("Failed to load OpenID configuration!"); + } + }; + + // Query token endpoint + let token = match config + .request_token( + &AppConfig::get().client_id, + &AppConfig::get().client_secret, + &query.code, + &AppConfig::get().redirect_url(), + ) + .await + { + Ok(t) => t, + Err(e) => { + log::error!("Failed to retrieve token! {}", e); + return ErrorTemplate::build("Failed to retrieve access token!"); + } + }; + + // Query userinfo endpoint + let user_info = match config.request_user_info(&token).await { + Ok(t) => t, + Err(e) => { + log::error!("Failed to retrieve user info! {}", e); + return ErrorTemplate::build("Failed to retrieve user info!"); + } + }; + + HttpResponse::Ok().content_type("text/html").body( + ResultTemplate { + token: format!("{:#?}", token), + user_info: format!("{:#?}", user_info), + } + .render() + .unwrap(), + ) +} + #[actix_web::main] async fn main() -> std::io::Result<()> { env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); @@ -65,10 +158,12 @@ async fn main() -> std::io::Result<()> { HttpServer::new(|| { App::new() + .wrap(Logger::default()) .service(bootstrap) .service(cover) .service(home) .service(start) + .service(redirect) }) .bind(&AppConfig::get().listen_addr) .expect("Failed to bind server!") diff --git a/src/openid_primitives.rs b/src/openid_primitives.rs index bfc4b40..80e178a 100644 --- a/src/openid_primitives.rs +++ b/src/openid_primitives.rs @@ -1,3 +1,7 @@ +use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; +use base64::Engine; +use std::collections::HashMap; + use crate::Res; #[derive(Debug, Clone, serde::Deserialize)] @@ -7,6 +11,35 @@ pub struct OpenIDConfig { pub userinfo_endpoint: String, } +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct TokenResponse { + pub access_token: String, + pub token_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub refresh_token: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub expires_in: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub id_token: Option, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct UserInfo { + pub sub: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub given_name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub family_name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub preferred_username: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub email: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub email_verified: Option, +} + impl OpenIDConfig { /// Load OpenID configuration from a given URL pub async fn load_from(url: &str) -> Res { @@ -21,4 +54,40 @@ impl OpenIDConfig { format!("{}?response_type=code&scope=openid%20profile%20email&client_id={client_id}&state={state}&redirect_uri={redirect_uri}", self.authorization_endpoint) } + + /// Query the token endpoint + pub async fn request_token( + &self, + client_id: &str, + client_secret: &str, + code: &str, + redirect_uri: &str, + ) -> Res { + let authorization = BASE64_STANDARD.encode(format!("{}:{}", client_id, client_secret)); + + let mut params = HashMap::new(); + params.insert("grant_type", "authorization_code"); + params.insert("code", code); + params.insert("redirect_uri", redirect_uri); + + Ok(reqwest::Client::new() + .post(&self.token_endpoint) + .header("Authorization", format!("Basic {authorization}")) + .form(¶ms) + .send() + .await? + .json() + .await?) + } + + /// Query the UserInfo endpoint + pub async fn request_user_info(&self, token: &TokenResponse) -> Res { + Ok(reqwest::Client::new() + .get(&self.userinfo_endpoint) + .header("Authorization", format!("Bearer {}", token.access_token)) + .send() + .await? + .json() + .await?) + } } diff --git a/src/state_manager.rs b/src/state_manager.rs index d15521d..b92f208 100644 --- a/src/state_manager.rs +++ b/src/state_manager.rs @@ -1,3 +1,6 @@ +use std::error::Error; +use std::fmt; + use crate::crypto_wrapper::CryptoWrapper; use crate::remote_ip::RemoteIP; use crate::time_utils::time; @@ -24,6 +27,20 @@ impl State { } } +#[derive(Debug, Copy, Clone)] +enum StateError { + InvalidIp, + Expired, +} + +impl Error for StateError {} + +impl fmt::Display for StateError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "StateManager error {:?}", self) + } +} + impl StateManager { pub fn init() { unsafe { @@ -37,4 +54,19 @@ impl StateManager { unsafe { WRAPPER.as_ref().unwrap() }.encrypt(&state) } + + /// Validate generated state + pub fn validate_state(ip: &RemoteIP, state: &str) -> Res { + let state: State = unsafe { WRAPPER.as_ref().unwrap() }.decrypt(state)?; + + if state.ip != ip.0 { + return Err(Box::new(StateError::InvalidIp)); + } + + if state.expire < time() { + return Err(Box::new(StateError::Expired)); + } + + Ok(()) + } } diff --git a/templates/error.html b/templates/error.html new file mode 100644 index 0000000..7752fc4 --- /dev/null +++ b/templates/error.html @@ -0,0 +1,9 @@ +{% extends "base_page.html" %} +{% block content %} + + + +Start again +{% endblock content %} \ No newline at end of file diff --git a/templates/result.html b/templates/result.html new file mode 100644 index 0000000..755eb75 --- /dev/null +++ b/templates/result.html @@ -0,0 +1,29 @@ +{% extends "base_page.html" %} +{% block content %} + + + + + +
+
+
Token response
+
{{ token }}
+
+
+ +
+
+
User info
+
{{ user_info }}
+
+
+ +Start again +{% endblock content %} \ No newline at end of file