use elliptic_curve::pkcs8::EncodePublicKey; use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Validation}; use p384::ecdsa::{SigningKey, VerifyingKey}; use p384::pkcs8::{EncodePrivateKey, LineEnding}; use rand::rngs::OsRng; use serde::de::DeserializeOwned; use serde::Serialize; #[derive(serde::Serialize, serde::Deserialize, Clone, Debug, Eq, PartialEq)] #[serde(tag = "alg")] pub enum TokenPubKey { /// This variant DOES make crash the program. It MUST NOT used to validate JWT. /// /// It is a hack to hide public key when getting the list of tokens None, /// ECDSA with SHA2-384 variant ES384 { r#pub: String }, } impl TokenPubKey { pub fn is_invalid(&self) -> bool { self == &TokenPubKey::None } } #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] #[serde(tag = "alg")] pub enum TokenPrivKey { ES384 { r#priv: String }, } /// Generate a new token keypair pub fn generate_key_pair() -> anyhow::Result<(TokenPubKey, TokenPrivKey)> { let signing_key = SigningKey::random(&mut OsRng); let priv_pem = signing_key .to_pkcs8_der()? .to_pem("PRIVATE KEY", LineEnding::LF)? .to_string(); let pub_key = VerifyingKey::from(signing_key); let pub_pem = pub_key.to_public_key_pem(LineEnding::LF)?; Ok(( TokenPubKey::ES384 { r#pub: pub_pem }, TokenPrivKey::ES384 { r#priv: priv_pem }, )) } /// Sign JWT with a private key pub fn sign_jwt(key: &TokenPrivKey, claims: &C) -> anyhow::Result { match key { TokenPrivKey::ES384 { r#priv } => { let encoding_key = EncodingKey::from_ec_pem(r#priv.as_bytes())?; Ok(jsonwebtoken::encode( &jsonwebtoken::Header::new(Algorithm::ES384), &claims, &encoding_key, )?) } } } /// Validate a given JWT pub fn validate_jwt(key: &TokenPubKey, token: &str) -> anyhow::Result { match key { TokenPubKey::ES384 { r#pub } => { let decoding_key = DecodingKey::from_ec_pem(r#pub.as_bytes())?; let validation = Validation::new(Algorithm::ES384); Ok(jsonwebtoken::decode::(token, &decoding_key, &validation)?.claims) } TokenPubKey::None => { panic!("A public key is required!") } } } #[cfg(test)] mod test { use crate::utils::jwt_utils::{generate_key_pair, sign_jwt, validate_jwt}; use crate::utils::time_utils::time; use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Deserialize, Eq, PartialEq)] pub struct Claims { sub: String, exp: u64, } impl Default for Claims { fn default() -> Self { Self { sub: "my-sub".to_string(), exp: time() + 100, } } } #[test] fn jwt_encode_sign_verify_valid() { let (pub_key, priv_key) = generate_key_pair().unwrap(); let claims = Claims::default(); let jwt = sign_jwt(&priv_key, &claims).expect("Failed to sign JWT!"); let claims_out = validate_jwt(&pub_key, &jwt).expect("Failed to validate JWT!"); assert_eq!(claims, claims_out) } #[test] fn jwt_encode_sign_verify_invalid_key() { let (_pub_key, priv_key) = generate_key_pair().unwrap(); let (pub_key_2, _priv_key_2) = generate_key_pair().unwrap(); let claims = Claims::default(); let jwt = sign_jwt(&priv_key, &claims).expect("Failed to sign JWT!"); validate_jwt::(&pub_key_2, &jwt).expect_err("JWT should not have validated!"); } #[test] fn jwt_verify_random_string() { let (pub_key, _priv_key) = generate_key_pair().unwrap(); validate_jwt::(&pub_key, "random_string") .expect_err("JWT should not have validated!"); } #[test] fn jwt_expired() { let (pub_key, priv_key) = generate_key_pair().unwrap(); let claims = Claims { exp: time() - 100, ..Default::default() }; let jwt = sign_jwt(&priv_key, &claims).expect("Failed to sign JWT!"); validate_jwt::(&pub_key, &jwt).expect_err("JWT should not have validated!"); } #[test] fn jwt_invalid_signature() { let (pub_key, priv_key) = generate_key_pair().unwrap(); let claims = Claims::default(); let jwt = sign_jwt(&priv_key, &claims).expect("Failed to sign JWT!"); validate_jwt::(&pub_key, &format!("{jwt}bad")) .expect_err("JWT should not have validated!"); } }