use aes_gcm::aead::{Aead, OsRng}; use aes_gcm::{Aes256Gcm, Key, KeyInit, Nonce}; use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; use base64::Engine as _; use rand::Rng; use rkyv::api::high::{HighSerializer, HighValidator}; use rkyv::bytecheck::CheckBytes; use rkyv::de::Pool; use rkyv::rancor::Strategy; use rkyv::ser::allocator::ArenaHandle; use rkyv::util::AlignedVec; use rkyv::{Archive, Deserialize, Serialize}; use std::error::Error; /// The length of the nonce used to initialize encryption const NONCE_LEN: usize = 12; /// CryptoWrapper is a library that can be used to encrypt and decrypt some data marked /// that derives [SchemaWrite] and [SchemaRead] traits using AES encryption pub struct CryptoWrapper { key: Key, } impl CryptoWrapper { /// Generate a new memory wrapper pub fn new_random() -> Self { Self { key: Aes256Gcm::generate_key(&mut OsRng), } } /// Encrypt some data, returning the result as a base64-encoded string pub fn encrypt( &self, data: &impl for<'a> Serialize, rkyv::rancor::Error>>, ) -> Result> { let aes_key = Aes256Gcm::new(&self.key); let nonce_bytes = rand::rng().random::<[u8; NONCE_LEN]>(); let serialized_data = rkyv::to_bytes(data)?; let mut enc = aes_key .encrypt(Nonce::from_slice(&nonce_bytes), serialized_data.as_slice()) .unwrap(); enc.extend_from_slice(&nonce_bytes); Ok(BASE64_STANDARD.encode(enc)) } /// Decrypt some data previously encrypted using the [`CryptoWrapper::encrypt`] method pub fn decrypt(&self, input: &str) -> Result> where T: Archive, T::Archived: for<'a> CheckBytes> + Deserialize>, { let bytes = BASE64_STANDARD.decode(input)?; if bytes.len() < NONCE_LEN { return Err(Box::new(std::io::Error::other( "Input string is smaller than nonce!", ))); } let (enc, nonce) = bytes.split_at(bytes.len() - NONCE_LEN); assert_eq!(nonce.len(), NONCE_LEN); let aes_key = Aes256Gcm::new(&self.key); let dec = match aes_key.decrypt(Nonce::from_slice(nonce), enc) { Ok(d) => d, Err(e) => { log::error!("Failed to decrypt wrapped data! {e:#?}"); return Err(Box::new(std::io::Error::other( "Failed to decrypt wrapped data!", ))); } }; Ok(rkyv::from_bytes(&dec)?) } } #[cfg(test)] mod test { use crate::crypto_wrapper::CryptoWrapper; #[derive(Eq, PartialEq, Debug, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)] struct Message(String); #[test] fn encrypt_and_decrypt() { let wrapper = CryptoWrapper::new_random(); let msg = Message("Pierre was here".to_string()); let enc = wrapper.encrypt(&msg).unwrap(); let dec: Message = wrapper.decrypt(&enc).unwrap(); assert_eq!(dec, msg) } #[test] fn encrypt_and_decrypt_invalid() { let wrapper_1 = CryptoWrapper::new_random(); let wrapper_2 = CryptoWrapper::new_random(); let msg = Message("Pierre was here".to_string()); let enc = wrapper_1.encrypt(&msg).unwrap(); wrapper_2.decrypt::(&enc).unwrap_err(); } }