Implement server #3

Merged
pierre merged 5 commits from feat-server into master 2023-04-28 07:15:10 +00:00
5 changed files with 239 additions and 5 deletions
Showing only changes of commit 9a3c44e840 - Show all commits

View File

@ -1,5 +1,6 @@
use actix_web::{get, App, HttpResponse, HttpServer};
use actix_web::{get, web, App, HttpResponse, HttpServer};
use askama::Template;
use actix_web::middleware::Logger;
use oidc_test_client::app_config::AppConfig;
use oidc_test_client::openid_primitives::OpenIDConfig;
@ -26,6 +27,27 @@ struct HomeTemplate {
redirect_url: String,
}
#[derive(Template)]
#[template(path = "result.html")]
struct ResultTemplate {
token: String,
user_info: String,
}
#[derive(Template)]
#[template(path = "error.html")]
struct ErrorTemplate<'a> {
message: &'a str,
}
impl<'a> ErrorTemplate<'a> {
pub fn build(message: &'a str) -> HttpResponse {
HttpResponse::Unauthorized()
.content_type("text/html")
.body(Self { message }.render().unwrap())
}
}
#[get("/")]
async fn home() -> HttpResponse {
HttpResponse::Ok().content_type("text/html").body(
@ -39,13 +61,25 @@ async fn home() -> HttpResponse {
#[get("/start")]
async fn start(remote_ip: RemoteIP) -> HttpResponse {
let config = OpenIDConfig::load_from(&AppConfig::get().configuration_url)
.await
.expect("Failed to load provider configuration!");
let config = match OpenIDConfig::load_from(&AppConfig::get().configuration_url).await {
Ok(c) => c,
Err(e) => {
log::error!("Failed to load OpenID configuration! {e}");
return ErrorTemplate::build("Failed to load OpenID configuration!");
}
};
let state = match StateManager::gen_state(&remote_ip) {
Ok(s) => s,
Err(e) => {
log::error!("Failed to generate state! {:?}", e);
return ErrorTemplate::build("Failed to generate state!");
}
};
let authorization_url = config.authorization_url(
&AppConfig::get().client_id,
&StateManager::gen_state(&remote_ip).expect("Failed to generate state!"),
&state,
&AppConfig::get().redirect_url(),
);
@ -54,6 +88,65 @@ async fn start(remote_ip: RemoteIP) -> HttpResponse {
.finish()
}
#[derive(serde::Deserialize)]
struct RedirectQuery {
state: String,
code: String,
}
#[get("/redirect")]
async fn redirect(remote_ip: RemoteIP, query: web::Query<RedirectQuery>) -> HttpResponse {
// First, validate state
if let Err(e) = StateManager::validate_state(&remote_ip, &query.state) {
log::error!("Failed to validate state {}: {:?}", query.state, e);
return ErrorTemplate::build("State could not be validated!");
}
// Then, load OpenID configuration
let config = match OpenIDConfig::load_from(&AppConfig::get().configuration_url).await {
Ok(c) => c,
Err(e) => {
log::error!("Failed to load OpenID configuration! {e}");
return ErrorTemplate::build("Failed to load OpenID configuration!");
}
};
// Query token endpoint
let token = match config
.request_token(
&AppConfig::get().client_id,
&AppConfig::get().client_secret,
&query.code,
&AppConfig::get().redirect_url(),
)
.await
{
Ok(t) => t,
Err(e) => {
log::error!("Failed to retrieve token! {}", e);
return ErrorTemplate::build("Failed to retrieve access token!");
}
};
// Query userinfo endpoint
let user_info = match config.request_user_info(&token).await {
Ok(t) => t,
Err(e) => {
log::error!("Failed to retrieve user info! {}", e);
return ErrorTemplate::build("Failed to retrieve user info!");
}
};
HttpResponse::Ok().content_type("text/html").body(
ResultTemplate {
token: format!("{:#?}", token),
user_info: format!("{:#?}", user_info),
}
.render()
.unwrap(),
)
}
#[actix_web::main]
async fn main() -> std::io::Result<()> {
env_logger::init_from_env(env_logger::Env::new().default_filter_or("info"));
@ -65,10 +158,12 @@ async fn main() -> std::io::Result<()> {
HttpServer::new(|| {
App::new()
.wrap(Logger::default())
.service(bootstrap)
.service(cover)
.service(home)
.service(start)
.service(redirect)
})
.bind(&AppConfig::get().listen_addr)
.expect("Failed to bind server!")

View File

@ -1,3 +1,7 @@
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use base64::Engine;
use std::collections::HashMap;
use crate::Res;
#[derive(Debug, Clone, serde::Deserialize)]
@ -7,6 +11,35 @@ pub struct OpenIDConfig {
pub userinfo_endpoint: String,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TokenResponse {
pub access_token: String,
pub token_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub refresh_token: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub expires_in: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub id_token: Option<String>,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct UserInfo {
pub sub: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub given_name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub family_name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub preferred_username: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub email: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub email_verified: Option<bool>,
}
impl OpenIDConfig {
/// Load OpenID configuration from a given URL
pub async fn load_from(url: &str) -> Res<Self> {
@ -21,4 +54,40 @@ impl OpenIDConfig {
format!("{}?response_type=code&scope=openid%20profile%20email&client_id={client_id}&state={state}&redirect_uri={redirect_uri}", self.authorization_endpoint)
}
/// Query the token endpoint
pub async fn request_token(
&self,
client_id: &str,
client_secret: &str,
code: &str,
redirect_uri: &str,
) -> Res<TokenResponse> {
let authorization = BASE64_STANDARD.encode(format!("{}:{}", client_id, client_secret));
let mut params = HashMap::new();
params.insert("grant_type", "authorization_code");
params.insert("code", code);
params.insert("redirect_uri", redirect_uri);
Ok(reqwest::Client::new()
.post(&self.token_endpoint)
.header("Authorization", format!("Basic {authorization}"))
.form(&params)
.send()
.await?
.json()
.await?)
}
/// Query the UserInfo endpoint
pub async fn request_user_info(&self, token: &TokenResponse) -> Res<UserInfo> {
Ok(reqwest::Client::new()
.get(&self.userinfo_endpoint)
.header("Authorization", format!("Bearer {}", token.access_token))
.send()
.await?
.json()
.await?)
}
}

View File

@ -1,3 +1,6 @@
use std::error::Error;
use std::fmt;
use crate::crypto_wrapper::CryptoWrapper;
use crate::remote_ip::RemoteIP;
use crate::time_utils::time;
@ -24,6 +27,20 @@ impl State {
}
}
#[derive(Debug, Copy, Clone)]
enum StateError {
InvalidIp,
Expired,
}
impl Error for StateError {}
impl fmt::Display for StateError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "StateManager error {:?}", self)
}
}
impl StateManager {
pub fn init() {
unsafe {
@ -37,4 +54,19 @@ impl StateManager {
unsafe { WRAPPER.as_ref().unwrap() }.encrypt(&state)
}
/// Validate generated state
pub fn validate_state(ip: &RemoteIP, state: &str) -> Res {
let state: State = unsafe { WRAPPER.as_ref().unwrap() }.decrypt(state)?;
if state.ip != ip.0 {
return Err(Box::new(StateError::InvalidIp));
}
if state.expire < time() {
return Err(Box::new(StateError::Expired));
}
Ok(())
}
}

9
templates/error.html Normal file
View File

@ -0,0 +1,9 @@
{% extends "base_page.html" %}
{% block content %}
<div class="alert alert-danger" role="alert">
{{ message }}
</div>
<a class="btn btn-primary" href="/start" role="button">Start again</a>
{% endblock content %}

29
templates/result.html Normal file
View File

@ -0,0 +1,29 @@
{% extends "base_page.html" %}
{% block content %}
<style>
.card {
text-align: left;
}
</style>
<div class="alert alert-success" role="alert">
Login successful
</div>
<div class="card">
<div class="card-body">
<h5 class="card-title">Token response</h5>
<pre class="card-text">{{ token }}</pre>
</div>
</div>
<div class="card">
<div class="card-body">
<h5 class="card-title">User info</h5>
<pre class="card-text">{{ user_info }}</pre>
</div>
</div>
<a class="btn btn-primary" href="/start" role="button">Start again</a>
{% endblock content %}