use std::error::Error; use std::io::ErrorKind; use aes_gcm::aead::{Aead, OsRng}; use aes_gcm::{Aes256Gcm, Key, KeyInit, Nonce}; use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; use base64::Engine as _; use bincode::{Decode, Encode}; use rand::Rng; /// The lenght of the nonce used to initialize encryption const NONCE_LEN: usize = 12; /// CryptoWrapper is a library that can be used to encrypt and decrypt some data marked /// that derives [Encode] and [Decode] traits using AES encryption pub struct CryptoWrapper { key: Key, } impl CryptoWrapper { /// Generate a new memory wrapper pub fn new_random() -> Self { Self { key: Aes256Gcm::generate_key(&mut OsRng), } } /// Encrypt some data, returning the result as a base64-encoded string pub fn encrypt(&self, data: &T) -> Result> { let aes_key = Aes256Gcm::new(&self.key); let nonce_bytes = rand::rng().random::<[u8; NONCE_LEN]>(); let serialized_data = bincode::encode_to_vec(data, bincode::config::standard())?; let mut enc = aes_key .encrypt(Nonce::from_slice(&nonce_bytes), serialized_data.as_slice()) .unwrap(); enc.extend_from_slice(&nonce_bytes); Ok(BASE64_STANDARD.encode(enc)) } /// Decrypt some data previously encrypted using the [`CryptoWrapper::encrypt`] method pub fn decrypt(&self, input: &str) -> Result> { let bytes = BASE64_STANDARD.decode(input)?; if bytes.len() < NONCE_LEN { return Err(Box::new(std::io::Error::new( ErrorKind::Other, "Input string is smaller than nonce!", ))); } let (enc, nonce) = bytes.split_at(bytes.len() - NONCE_LEN); assert_eq!(nonce.len(), NONCE_LEN); let aes_key = Aes256Gcm::new(&self.key); let dec = match aes_key.decrypt(Nonce::from_slice(nonce), enc) { Ok(d) => d, Err(e) => { log::error!("Failed to decrypt wrapped data! {:#?}", e); return Err(Box::new(std::io::Error::new( ErrorKind::Other, "Failed to decrypt wrapped data!", ))); } }; Ok(bincode::decode_from_slice(&dec, bincode::config::standard())?.0) } } #[cfg(test)] mod test { use crate::crypto_wrapper::CryptoWrapper; use bincode::{Decode, Encode}; #[derive(Encode, Decode, Eq, PartialEq, Debug)] struct Message(String); #[test] fn encrypt_and_decrypt() { let wrapper = CryptoWrapper::new_random(); let msg = Message("Pierre was here".to_string()); let enc = wrapper.encrypt(&msg).unwrap(); let dec: Message = wrapper.decrypt(&enc).unwrap(); assert_eq!(dec, msg) } #[test] fn encrypt_and_decrypt_invalid() { let wrapper_1 = CryptoWrapper::new_random(); let wrapper_2 = CryptoWrapper::new_random(); let msg = Message("Pierre was here".to_string()); let enc = wrapper_1.encrypt(&msg).unwrap(); wrapper_2.decrypt::(&enc).unwrap_err(); } }