//! # Providers state actor //! //! This actor stores the content of the states //! during authentication with upstream providers use crate::constants::{ MAX_OIDC_PROVIDERS_STATES, OIDC_PROVIDERS_STATE_DURATION, OIDC_PROVIDERS_STATE_LEN, OIDC_STATES_CLEANUP_INTERVAL, }; use actix::{Actor, AsyncContext, Context, Handler, Message}; use std::collections::hash_map::Entry; use std::collections::HashMap; use std::net::IpAddr; use crate::data::login_redirect::LoginRedirect; use crate::data::provider::ProviderID; use crate::utils::string_utils::rand_str; use crate::utils::time::time; #[derive(Debug, Clone)] pub struct ProviderLoginState { pub provider_id: ProviderID, pub state_id: String, pub redirect: LoginRedirect, pub expire: u64, } impl ProviderLoginState { pub fn new(prov_id: &ProviderID, redirect: LoginRedirect) -> Self { Self { provider_id: prov_id.clone(), state_id: rand_str(OIDC_PROVIDERS_STATE_LEN), redirect, expire: time() + OIDC_PROVIDERS_STATE_DURATION, } } } #[derive(Message)] #[rtype(result = "()")] pub struct RecordState { pub ip: IpAddr, pub state: ProviderLoginState, } #[derive(Message)] #[rtype(result = "Option")] pub struct ConsumeState { pub ip: IpAddr, pub state_id: String, } #[derive(Debug, Default)] pub struct ProvidersStatesActor { states: HashMap>, } impl ProvidersStatesActor { /// Clean outdated states fn clean_old_states(&mut self) { #[allow(clippy::map_clone)] let keys = self.states.keys().map(|i| *i).collect::>(); for ip in keys { // Remove old states let states = self.states.get_mut(&ip).unwrap(); states.retain(|i| i.expire < time()); // Remove empty entry keys if states.is_empty() { self.states.remove(&ip); } } } /// Add a new provider login state pub fn insert_state(&mut self, ip: IpAddr, state: ProviderLoginState) { if let Entry::Vacant(e) = self.states.entry(ip) { e.insert(vec![state]); } else { let states = self.states.get_mut(&ip).unwrap(); // We limit the number of states per IP address if states.len() > MAX_OIDC_PROVIDERS_STATES { states.remove(0); } states.push(state); } } /// Get & consume a login state pub fn consume_state(&mut self, ip: IpAddr, state_id: &str) -> Option { let idx = self .states .get(&ip)? .iter() .position(|val| val.state_id.as_str() == state_id)?; Some(self.states.get_mut(&ip)?.remove(idx)) } } impl Actor for ProvidersStatesActor { type Context = Context; fn started(&mut self, ctx: &mut Self::Context) { // Clean up at a regular interval failed attempts ctx.run_interval(OIDC_STATES_CLEANUP_INTERVAL, |act, _ctx| { log::trace!("Cleaning up old states"); act.clean_old_states(); }); } } impl Handler for ProvidersStatesActor { type Result = (); fn handle(&mut self, req: RecordState, _ctx: &mut Self::Context) -> Self::Result { self.insert_state(req.ip, req.state); } } impl Handler for ProvidersStatesActor { type Result = Option; fn handle(&mut self, req: ConsumeState, _ctx: &mut Self::Context) -> Self::Result { self.consume_state(req.ip, &req.state_id) } }