use crate::connections::redis_connection; use crate::utils::time_utils::time; use std::net::IpAddr; use std::time::Duration; #[derive(Debug, Copy, Clone)] pub enum RatedAction { CreateAccount, CheckResetPasswordTokenFailed, RequestNewPasswordResetLink, } impl RatedAction { fn id(&self) -> &'static str { match self { RatedAction::CreateAccount => "create-account", RatedAction::CheckResetPasswordTokenFailed => "check-reset-password-token", RatedAction::RequestNewPasswordResetLink => "req-pwd-reset-lnk", } } fn limit(&self) -> usize { match self { RatedAction::CreateAccount => 5, RatedAction::CheckResetPasswordTokenFailed => 100, RatedAction::RequestNewPasswordResetLink => 5, } } fn keep_seconds(&self) -> u64 { 3600 } fn key(&self, ip: IpAddr) -> String { format!("rate-{}-{}", self.id(), ip) } } /// Keep track of the time the action was executed by the user #[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)] struct ActionRecord(Vec); impl ActionRecord { pub fn clean(&mut self, action: RatedAction) { self.0.retain(|e| e + action.keep_seconds() > time()); } } /// Record a new action of the user pub async fn record_action(ip: IpAddr, action: RatedAction) -> anyhow::Result<()> { let key = action.key(ip); let mut record = redis_connection::get_value::(&key) .await? .unwrap_or_default(); record.clean(action); record.0.push(time()); redis_connection::set_value(&key, &record, Duration::from_secs(action.keep_seconds())).await?; Ok(()) } /// Check whether an action should be blocked, due to too much attempts from the user pub async fn should_block_action(ip: IpAddr, action: RatedAction) -> anyhow::Result { let mut record = redis_connection::get_value::(&action.key(ip)) .await? .unwrap_or_default(); record.clean(action); Ok(record.0.len() >= action.limit()) }