use std::fs::File; use std::io::{ErrorKind, Write}; use std::path::{Path, PathBuf}; use std::pin::Pin; use std::sync::mpsc; use std::task::{Context, Poll}; use actix::dev::Stream; use actix_web::HttpServer; use actix_web::dev::ServiceRequest; use actix_web::middleware::Logger; use actix_web::web::Bytes; use actix_web::{App, HttpResponse, web}; use actix_web_httpauth::extractors; use actix_web_httpauth::extractors::AuthenticationError; use actix_web_httpauth::extractors::basic::BasicAuth; use actix_web_httpauth::middleware::HttpAuthentication; use askama::Template; use clap::Parser; /// Simple heavy file server #[derive(Parser, Debug)] struct Args { /// The URL this service will listen to #[arg(short, long, env, default_value = "0.0.0.0:5000")] listen_url: String, /// Directory that contains served files #[arg(short, long, env)] target_dir: String, /// Access token used to secure access to this service #[arg(short, long, env)] access_token: Option, } lazy_static::lazy_static! { static ref ARGS: Args = Args::parse(); } async fn validator( req: ServiceRequest, creds: BasicAuth, ) -> Result { if creds.password().eq(&ARGS.access_token.as_deref()) { Ok(req) } else { let config = extractors::basic::Config::default(); Err((AuthenticationError::from(config).into(), req)) } } fn recurse_scan>(dir: B) -> Vec { let dir = dir.as_ref(); if dir.is_file() { return vec![dir.to_path_buf()]; } let mut list = vec![]; for file in dir.read_dir().unwrap() { let file = file.unwrap(); list.append(&mut recurse_scan(file.path())); } list } /// Get the list of files to download fn files_list() -> Vec { recurse_scan(&ARGS.target_dir) } async fn bootstrap_css() -> HttpResponse { HttpResponse::Ok() .insert_header(("content-type", "text/css")) .body(include_str!("../assets/bootstrap.min.css")) } #[derive(Template)] #[template(path = "../templates/index.html")] struct IndexTemplate { files: Vec, app_title: &'static str, } async fn index() -> HttpResponse { HttpResponse::Ok() .insert_header(("content-type", "text/html")) .body( IndexTemplate { files: files_list(), app_title: "Heavy file server", } .render() .unwrap(), ) } struct SendWrapper(mpsc::SyncSender>); impl Write for SendWrapper { fn write(&mut self, buf: &[u8]) -> std::io::Result { if let Err(e) = self.0.send(buf.to_vec()) { log::error!("Failed to send a chunk of data! {}", e); return Err(std::io::Error::new( ErrorKind::Other, "Failed to send a chunk of data!", )); } Ok(buf.len()) } fn flush(&mut self) -> std::io::Result<()> { Ok(()) } } struct FileStreamer { receive: mpsc::Receiver>, } impl FileStreamer { pub fn start() -> Self { let (send, receive) = mpsc::sync_channel(1); std::thread::spawn(move || { let mut tar = tar::Builder::new(SendWrapper(send)); for file in files_list() { let file_path = &file.to_str().unwrap().replace(&ARGS.target_dir, "")[1..]; log::debug!("Add {} to archive", file_path); tar.append_file( file_path, &mut File::open(&file).expect("Failed to open file"), ) .unwrap(); } tar.finish().unwrap(); }); Self { receive } } } impl Stream for FileStreamer { type Item = Result; fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { match self.receive.recv() { Ok(d) => Poll::Ready(Some(Ok(Bytes::copy_from_slice(&d)))), Err(e) => { log::error!("Recv error: {}", e); Poll::Ready(None) } } } } async fn download() -> HttpResponse { HttpResponse::Ok() .insert_header(("Content-Disposition", " attachment; filename=\"files.tar\"")) .streaming(FileStreamer::start()) } #[actix_web::main] async fn main() -> std::io::Result<()> { env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); log::info!("Start to listen on {}", ARGS.listen_url); log::info!("File are served from {}", ARGS.target_dir); HttpServer::new(|| { App::new() .wrap(HttpAuthentication::basic(validator)) .wrap(Logger::default()) .route("/assets/bootstrap.min.css", web::get().to(bootstrap_css)) .route("/", web::get().to(index)) .route("/download", web::get().to(download)) }) .bind(ARGS.listen_url.to_string())? .run() .await } #[cfg(test)] mod test { use crate::Args; #[test] fn verify_cli() { use clap::CommandFactory; Args::command().debug_assert() } }