131 lines
3.6 KiB
Rust
131 lines
3.6 KiB
Rust
|
//! # Providers state actor
//!
//! This actor stores, grouped by IP address, the temporary login states
//! created during authentication with upstream providers.
|
||
|
|
||
|
use crate::constants::{
|
||
|
MAX_OIDC_PROVIDERS_STATES, OIDC_PROVIDERS_STATE_DURATION, OIDC_PROVIDERS_STATE_LEN,
|
||
|
OIDC_STATES_CLEANUP_INTERVAL,
|
||
|
};
|
||
|
use actix::{Actor, AsyncContext, Context, Handler, Message};
|
||
|
use std::collections::hash_map::Entry;
|
||
|
use std::collections::HashMap;
|
||
|
use std::net::IpAddr;
|
||
|
|
||
|
use crate::data::login_redirect::LoginRedirect;
|
||
|
use crate::data::provider::ProviderID;
|
||
|
use crate::utils::string_utils::rand_str;
|
||
|
use crate::utils::time::time;
|
||
|
|
||
|
#[derive(Debug, Clone)]
|
||
|
pub struct ProviderLoginState {
|
||
|
pub provider_id: ProviderID,
|
||
|
pub state_id: String,
|
||
|
pub redirect: LoginRedirect,
|
||
|
pub expire: u64,
|
||
|
}
|
||
|
|
||
|
impl ProviderLoginState {
|
||
|
pub fn new(prov_id: &ProviderID, redirect: LoginRedirect) -> Self {
|
||
|
Self {
|
||
|
provider_id: prov_id.clone(),
|
||
|
state_id: rand_str(OIDC_PROVIDERS_STATE_LEN),
|
||
|
redirect,
|
||
|
expire: time() + OIDC_PROVIDERS_STATE_DURATION,
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[derive(Message)]
|
||
|
#[rtype(result = "()")]
|
||
|
pub struct RecordState {
|
||
|
pub ip: IpAddr,
|
||
|
pub state: ProviderLoginState,
|
||
|
}
|
||
|
|
||
|
#[derive(Message)]
|
||
|
#[rtype(result = "Option<ProviderLoginState>")]
|
||
|
pub struct ConsumeState {
|
||
|
pub ip: IpAddr,
|
||
|
pub state_id: String,
|
||
|
}
|
||
|
|
||
|
#[derive(Debug, Default)]
|
||
|
pub struct ProvidersStatesActor {
|
||
|
states: HashMap<IpAddr, Vec<ProviderLoginState>>,
|
||
|
}
|
||
|
|
||
|
impl ProvidersStatesActor {
|
||
|
/// Clean outdated states
|
||
|
fn clean_old_states(&mut self) {
|
||
|
#[allow(clippy::map_clone)]
|
||
|
let keys = self.states.keys().map(|i| *i).collect::<Vec<_>>();
|
||
|
|
||
|
for ip in keys {
|
||
|
// Remove old states
|
||
|
let states = self.states.get_mut(&ip).unwrap();
|
||
|
states.retain(|i| i.expire < time());
|
||
|
|
||
|
// Remove empty entry keys
|
||
|
if states.is_empty() {
|
||
|
self.states.remove(&ip);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
/// Add a new provider login state
|
||
|
pub fn insert_state(&mut self, ip: IpAddr, state: ProviderLoginState) {
|
||
|
if let Entry::Vacant(e) = self.states.entry(ip) {
|
||
|
e.insert(vec![state]);
|
||
|
} else {
|
||
|
let states = self.states.get_mut(&ip).unwrap();
|
||
|
|
||
|
// We limit the number of states per IP address
|
||
|
if states.len() > MAX_OIDC_PROVIDERS_STATES {
|
||
|
states.remove(0);
|
||
|
}
|
||
|
|
||
|
states.push(state);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
/// Get & consume a login state
|
||
|
pub fn consume_state(&mut self, ip: IpAddr, state_id: &str) -> Option<ProviderLoginState> {
|
||
|
let idx = self
|
||
|
.states
|
||
|
.get(&ip)?
|
||
|
.iter()
|
||
|
.position(|val| val.state_id.as_str() == state_id)?;
|
||
|
|
||
|
Some(self.states.get_mut(&ip)?.remove(idx))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl Actor for ProvidersStatesActor {
|
||
|
type Context = Context<Self>;
|
||
|
|
||
|
fn started(&mut self, ctx: &mut Self::Context) {
|
||
|
// Clean up at a regular interval failed attempts
|
||
|
ctx.run_interval(OIDC_STATES_CLEANUP_INTERVAL, |act, _ctx| {
|
||
|
log::trace!("Cleaning up old states");
|
||
|
act.clean_old_states();
|
||
|
});
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl Handler<RecordState> for ProvidersStatesActor {
|
||
|
type Result = ();
|
||
|
|
||
|
fn handle(&mut self, req: RecordState, _ctx: &mut Self::Context) -> Self::Result {
|
||
|
self.insert_state(req.ip, req.state);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl Handler<ConsumeState> for ProvidersStatesActor {
|
||
|
type Result = Option<ProviderLoginState>;
|
||
|
|
||
|
fn handle(&mut self, req: ConsumeState, _ctx: &mut Self::Context) -> Self::Result {
|
||
|
self.consume_state(req.ip, &req.state_id)
|
||
|
}
|
||
|
}
|