Add rate limiting

This commit is contained in:
Pierre HUBERT 2023-05-26 17:55:19 +02:00
parent 4ba4d10fce
commit c84c2ef3c5
14 changed files with 242 additions and 31 deletions

View File

@ -440,6 +440,16 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
[[package]]
name = "combine"
version = "4.6.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35ed6e9d84f0b51a7f52daf1c7d71dd136fd7a3f41a8462b8cdb8c78d920fad4"
dependencies = [
"bytes",
"memchr",
]
[[package]] [[package]]
name = "convert_case" name = "convert_case"
version = "0.4.0" version = "0.4.0"
@ -666,7 +676,9 @@ dependencies = [
"lazy_static", "lazy_static",
"log", "log",
"mailchecker", "mailchecker",
"redis",
"serde", "serde",
"serde_json",
] ]
[[package]] [[package]]
@ -1083,6 +1095,20 @@ dependencies = [
"getrandom", "getrandom",
] ]
[[package]]
name = "redis"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ea8c51b5dc1d8e5fd3350ec8167f464ec0995e79f2e90a075b63371500d557f"
dependencies = [
"combine",
"itoa",
"percent-encoding",
"ryu",
"sha1_smol",
"url",
]
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.2.16" version = "0.2.16"
@ -1204,6 +1230,12 @@ dependencies = [
"digest", "digest",
] ]
[[package]]
name = "sha1_smol"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012"
[[package]] [[package]]
name = "signal-hook-registry" name = "signal-hook-registry"
version = "1.4.1" version = "1.4.1"

View File

@ -14,5 +14,7 @@ anyhow = "1.0.71"
actix-web = "4.3.1" actix-web = "4.3.1"
diesel = { version = "2.0.4", features = ["postgres"] } diesel = { version = "2.0.4", features = ["postgres"] }
serde = { version = "1.0.163", features = ["derive"] } serde = { version = "1.0.163", features = ["derive"] }
serde_json = "1.0.96"
actix-remote-ip = "0.1.0" actix-remote-ip = "0.1.0"
mailchecker = "5.0.9" mailchecker = "5.0.9"
redis = "0.23.0"

View File

@ -35,6 +35,26 @@ pub struct AppConfig {
/// PostgreSQL database name /// PostgreSQL database name
#[clap(long, env, default_value = "geneit")] #[clap(long, env, default_value = "geneit")]
db_name: String, db_name: String,
/// Redis connection hostname
#[clap(long, env, default_value = "localhost")]
redis_hostname: String,
/// Redis connection port
#[clap(long, env, default_value_t = 6379)]
redis_port: u16,
/// Redis database number
#[clap(long, env, default_value_t = 0)]
redis_db_number: i64,
/// Redis username
#[clap(long, env)]
redis_username: Option<String>,
/// Redis password
#[clap(long, env, default_value = "secretredis")]
redis_password: String,
} }
lazy_static::lazy_static! { lazy_static::lazy_static! {
@ -56,4 +76,16 @@ impl AppConfig {
self.db_username, self.db_password, self.db_host, self.db_port, self.db_name self.db_username, self.db_password, self.db_host, self.db_port, self.db_name
) )
} }
/// Get Redis connection configuration
pub fn redis_connection_config(&self) -> redis::ConnectionInfo {
redis::ConnectionInfo {
addr: redis::ConnectionAddr::Tcp(self.redis_hostname.clone(), self.redis_port),
redis: redis::RedisConnectionInfo {
db: self.redis_db_number,
username: self.redis_username.clone(),
password: Some(self.redis_password.clone()),
},
}
}
} }

View File

@ -0,0 +1,4 @@
//! # External services connections
pub mod db_connection;
pub mod redis_connection;

View File

@ -0,0 +1,51 @@
//! # Redis connection management
use crate::app_config::AppConfig;
use redis::Commands;
use serde::de::DeserializeOwned;
use std::cell::RefCell;
use std::time::Duration;
thread_local! {
static REDIS_CONNECTION: RefCell<Option<redis::Client>> = RefCell::new(None);
}
/// Execute a request on Redis
fn execute_request<E, I>(cb: E) -> anyhow::Result<I>
where
E: FnOnce(&mut redis::Client) -> anyhow::Result<I>,
{
// Establish connection if required
if REDIS_CONNECTION.with(|i| i.borrow().is_none()) {
let conn = redis::Client::open(AppConfig::get().redis_connection_config())?;
REDIS_CONNECTION.with(|i| *i.borrow_mut() = Some(conn))
}
REDIS_CONNECTION.with(|i| cb(i.borrow_mut().as_mut().unwrap()))
}
/// Get a value stored on Redis
pub async fn get_value<E>(key: &str) -> anyhow::Result<Option<E>>
where
E: DeserializeOwned,
{
let value: Option<String> = execute_request(|conn| Ok(conn.get(key)?))?;
Ok(match value {
None => None,
Some(v) => serde_json::from_str(&v)?,
})
}
/// Set a new value on Redis
pub async fn set_value<E>(key: &str, value: &E, lifetime: Duration) -> anyhow::Result<()>
where
E: serde::Serialize,
{
let value_str = serde_json::to_string(value)?;
execute_request(|conn| Ok(conn.set_ex(key, value_str, lifetime.as_secs() as usize)?))?;
Ok(())
}

View File

@ -1,4 +0,0 @@
//! # API controller
pub mod auth_controller;
pub mod config_controller;

View File

@ -1,7 +1,8 @@
use crate::constants::StaticConstraints; use crate::constants::StaticConstraints;
use crate::services::users_service; use crate::controllers::HttpResult;
use crate::services::rate_limiter_service::RatedAction;
use crate::services::{rate_limiter_service, users_service};
use actix_remote_ip::RemoteIP; use actix_remote_ip::RemoteIP;
use actix_web::error::ErrorInternalServerError;
use actix_web::{web, HttpResponse}; use actix_web::{web, HttpResponse};
#[derive(serde::Deserialize)] #[derive(serde::Deserialize)]
@ -11,11 +12,12 @@ pub struct CreateAccountBody {
} }
/// Create a new account /// Create a new account
pub async fn create_account( pub async fn create_account(remote_ip: RemoteIP, req: web::Json<CreateAccountBody>) -> HttpResult {
_remote_ip: RemoteIP, // Rate limiting
req: web::Json<CreateAccountBody>, if rate_limiter_service::should_block_action(remote_ip.0, RatedAction::CreateAccount).await? {
) -> actix_web::Result<HttpResponse> { return Ok(HttpResponse::TooManyRequests().finish());
// TODO : rate limiting }
rate_limiter_service::record_action(remote_ip.0, RatedAction::CreateAccount).await?;
// Check if email is valid // Check if email is valid
if !mailchecker::is_valid(&req.email) { if !mailchecker::is_valid(&req.email) {
@ -30,25 +32,14 @@ pub async fn create_account(
} }
// Check if email is already attached to an account // Check if email is already attached to an account
match users_service::exists_email(&req.email).await { if users_service::exists_email(&req.email).await? {
Ok(false) => {} return Ok(
Ok(true) => { HttpResponse::Conflict().json("An account with the same email address already exists!")
return Ok(HttpResponse::Conflict() );
.json("An account with the same email address already exists!"));
}
Err(e) => {
log::error!("Failed to check email existence! {}", e);
return Err(ErrorInternalServerError(e));
}
} }
// Create the account // Create the account
let user_id = users_service::create_account(&req.name, &req.email) let user_id = users_service::create_account(&req.name, &req.email).await?;
.await
.map_err(|e| {
log::error!("Failed to create user! {e}");
ErrorInternalServerError(e)
})?;
// TODO : trigger reset password (send mail) // TODO : trigger reset password (send mail)

View File

@ -0,0 +1,35 @@
//! # API controller
use actix_web::body::BoxBody;
use actix_web::HttpResponse;
use std::fmt::{Debug, Display, Formatter};
pub mod auth_controller;
pub mod config_controller;
/// Custom error to ease controller writing
#[derive(Debug)]
pub struct HttpErr {
err: anyhow::Error,
}
impl Display for HttpErr {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(&self.err, f)
}
}
impl actix_web::error::ResponseError for HttpErr {
fn error_response(&self) -> HttpResponse<BoxBody> {
log::error!("Error while processing request! {}", self);
HttpResponse::InternalServerError().body("Failed to execute request!")
}
}
impl From<anyhow::Error> for HttpErr {
fn from(err: anyhow::Error) -> HttpErr {
HttpErr { err }
}
}
pub type HttpResult = Result<HttpResponse, HttpErr>;

View File

@ -1,10 +1,9 @@
pub mod app_config; pub mod app_config;
pub mod connections;
pub mod constants; pub mod constants;
pub mod controllers; pub mod controllers;
pub mod services; pub mod services;
pub mod utils; pub mod utils;
// Diesel specific
pub mod db_connection;
pub mod models; pub mod models;
pub mod schema; pub mod schema;

View File

@ -1,3 +1,4 @@
//! # Backend services //! # Backend services
pub mod rate_limiter_service;
pub mod users_service; pub mod users_service;

View File

@ -0,0 +1,68 @@
use crate::connections::redis_connection;
use crate::utils::time_utils::time;
use std::net::IpAddr;
use std::time::Duration;
#[derive(Debug, Copy, Clone)]
pub enum RatedAction {
CreateAccount,
}
impl RatedAction {
fn id(&self) -> &'static str {
match self {
RatedAction::CreateAccount => "create-account",
}
}
fn limit(&self) -> usize {
match self {
RatedAction::CreateAccount => 5,
}
}
fn keep_seconds(&self) -> u64 {
match self {
RatedAction::CreateAccount => 3600,
}
}
fn key(&self, ip: IpAddr) -> String {
format!("rate-{}-{}", self.id(), ip)
}
}
/// Keep track of the time the action was executed by the user
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
struct ActionRecord(Vec<u64>);
impl ActionRecord {
pub fn clean(&mut self, action: RatedAction) {
self.0.retain(|e| e + action.keep_seconds() > time());
}
}
/// Record a new action of the user
pub async fn record_action(ip: IpAddr, action: RatedAction) -> anyhow::Result<()> {
let key = action.key(ip);
let mut record = redis_connection::get_value::<ActionRecord>(&key)
.await?
.unwrap_or_default();
record.clean(action);
record.0.push(time());
redis_connection::set_value(&key, &record, Duration::from_secs(action.keep_seconds())).await?;
Ok(())
}
/// Check whether an action should be blocked, due to too much attempts from the user
pub async fn should_block_action(ip: IpAddr, action: RatedAction) -> anyhow::Result<bool> {
let mut record = redis_connection::get_value::<ActionRecord>(&action.key(ip))
.await?
.unwrap_or_default();
record.clean(action);
Ok(record.0.len() >= action.limit())
}

View File

@ -1,6 +1,6 @@
//! # Users service //! # Users service
use crate::db_connection; use crate::connections::db_connection;
use crate::models::{NewUser, User}; use crate::models::{NewUser, User};
use crate::schema::users; use crate::schema::users;
use crate::utils::time_utils::time; use crate::utils::time_utils::time;