#![allow(incomplete_features)] #![feature(box_syntax, const_generics, fixed_size_array)] mod byte_array; mod byte_vec; mod key_ops; #[cfg(test)] mod tests; mod utils; use std::array::FixedSizeArray; use serde::{Deserialize, Serialize}; use zeroize::Zeroize; pub use byte_array::ByteArray; pub use byte_vec::ByteVec; pub use key_ops::KeyOps; #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct JsonWebKey { #[serde(flatten)] pub key: Box, #[serde(default, rename = "use", skip_serializing_if = "Option::is_none")] pub key_use: Option, #[serde(default, skip_serializing_if = "KeyOps::is_empty")] pub key_ops: KeyOps, #[serde(default, rename = "kid", skip_serializing_if = "Option::is_none")] pub key_id: Option, #[serde(default, rename = "alg", skip_serializing_if = "Option::is_none")] pub algorithm: Option, } impl JsonWebKey { pub fn new(key: Key) -> Self { Self { key: box key, key_use: None, key_ops: KeyOps::empty(), key_id: None, algorithm: None, } } pub fn set_algorithm(&mut self, alg: JsonWebAlgorithm) -> Result<(), Error> { Self::validate_algorithm(alg, &*self.key)?; self.algorithm = Some(alg); Ok(()) } pub fn from_slice(bytes: impl AsRef<[u8]>) -> Result { Ok(serde_json::from_slice(bytes.as_ref())?) } fn validate_algorithm(alg: JsonWebAlgorithm, key: &Key) -> Result<(), Error> { use JsonWebAlgorithm::*; use Key::*; match (alg, key) { ( ES256, EC { curve: Curve::P256 { .. }, }, ) | (RS256, RSA { .. }) | (HS256, Symmetric { .. }) => Ok(()), _ => Err(Error::MismatchedAlgorithm), } } } impl std::str::FromStr for JsonWebKey { type Err = Error; fn from_str(json: &str) -> Result { let jwk = Self::from_slice(json.as_bytes())?; let alg = match jwk.algorithm { Some(alg) => alg, None => return Ok(jwk), }; Self::validate_algorithm(alg, &*jwk.key).map(|_| jwk) } } impl std::fmt::Display for JsonWebKey { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { if f.alternate() { write!(f, "{}", serde_json::to_string_pretty(self).unwrap()) } else { write!(f, "{}", serde_json::to_string(self).unwrap()) } } } #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(tag = "kty")] pub enum Key { /// An elliptic curve, as per [RFC 7518 §6.2](https://tools.ietf.org/html/rfc7518#section-6.2). EC { #[serde(flatten)] curve: Curve, }, /// An elliptic curve, as per [RFC 7518 §6.3](https://tools.ietf.org/html/rfc7518#section-6.3). /// See also: [RFC 3447](https://tools.ietf.org/html/rfc3447). RSA { #[serde(flatten)] public: RsaPublic, #[serde(flatten, default, skip_serializing_if = "Option::is_none")] private: Option, }, /// A symmetric key, as per [RFC 7518 §6.4](https://tools.ietf.org/html/rfc7518#section-6.4). #[serde(rename = "oct")] Symmetric { #[serde(rename = "k")] key: ByteVec, }, } impl Key { /// Returns true iff this key only contains private components (i.e. a private asymmetric /// key or a symmetric key). fn is_private(&self) -> bool { match self { Self::Symmetric { .. } | Self::EC { curve: Curve::P256 { d: Some(_), .. }, .. } | Self::RSA { private: Some(_), .. } => true, _ => false, } } /// Returns true iff this key only contains non-private components. pub fn is_public(&self) -> bool { !self.is_private() } /// Returns the public part of this key, if it's symmetric. pub fn to_public(&self) -> Option { if self.is_public() { return Some(self.clone()); } Some(match self { Self::Symmetric { .. } => return None, Self::EC { curve: Curve::P256 { x, y, .. }, } => Self::EC { curve: Curve::P256 { x: x.clone(), y: y.clone(), d: None, }, }, Self::RSA { public, .. } => Self::RSA { public: public.clone(), private: None, }, }) } /// If this key is asymmetric, encodes it as PKCS#8. #[cfg(feature = "convert")] pub fn to_der(&self) -> Result, PkcsConvertError> { use num_bigint::BigUint; use yasna::{models::ObjectIdentifier, DERWriter, DERWriterSeq, Tag}; use crate::utils::pkcs8; if let Self::Symmetric { .. } = self { return Err(PkcsConvertError::NotAsymmetric); } Ok(match self { Self::EC { curve: Curve::P256 { d, x, y }, } => { let ec_public_oid = ObjectIdentifier::from_slice(&[1, 2, 840, 10045, 2, 1]); let prime256v1_oid = ObjectIdentifier::from_slice(&[1, 2, 840, 10045, 3, 1, 7]); let oids = &[Some(&ec_public_oid), Some(&prime256v1_oid)]; let write_public = |writer: DERWriter| { let public_bytes: Vec = [0x04 /* uncompressed */] .iter() .chain(x.iter()) .chain(y.iter()) .copied() .collect(); writer.write_bitvec_bytes(&public_bytes, 8 * (32 * 2 + 1)); }; match d { Some(private_point) => { pkcs8::write_private(oids, |writer: &mut DERWriterSeq| { writer.next().write_i8(1); // version writer.next().write_bytes(private_point.as_slice()); writer.next().write_tagged(Tag::context(0), |writer| { writer.write_oid(&prime256v1_oid) }); writer.next().write_tagged(Tag::context(1), write_public); }) } None => pkcs8::write_public(oids, write_public), } } Self::RSA { public, private } => { let rsa_encryption_oid = ObjectIdentifier::from_slice(&[ 1, 2, 840, 113549, 1, 1, 1, // rsaEncryption ]); let oids = &[Some(&rsa_encryption_oid), None]; let write_bytevec = |writer: DERWriter, vec: &ByteVec| { let bigint = BigUint::from_bytes_be(vec.as_slice()); writer.write_biguint(&bigint); }; let write_public = |writer: &mut DERWriterSeq| { write_bytevec(writer.next(), &public.n); writer.next().write_u32(PUBLIC_EXPONENT); }; let write_private = |writer: &mut DERWriterSeq, private: &RsaPrivate| { // https://tools.ietf.org/html/rfc3447#appendix-A.1.2 writer.next().write_i8(0); // version (two-prime) write_public(writer); write_bytevec(writer.next(), &private.d); macro_rules! write_opt_bytevecs { ($($param:ident),+) => {{ $(write_bytevec(writer.next(), private.$param.as_ref().unwrap());)+ }}; } write_opt_bytevecs!(p, q, dp, dq, qi); }; match private { Some( private @ RsaPrivate { d: _, p: Some(_), q: Some(_), dp: Some(_), dq: Some(_), qi: Some(_), }, ) => pkcs8::write_private(oids, |writer| write_private(writer, private)), Some(_) => return Err(PkcsConvertError::MissingRsaParams), None => pkcs8::write_public(oids, |writer| { let body = yasna::construct_der(|writer| writer.write_sequence(write_public)); writer.write_bitvec_bytes(&body, body.len() * 8); }), } } Self::Symmetric { .. } => unreachable!("checked above"), }) } /// If this key is asymmetric, encodes it as PKCS#8 with PEM armoring. #[cfg(feature = "convert")] pub fn to_pem(&self) -> Result { use std::fmt::Write; let der_b64 = base64::encode(self.to_der()?); let key_ty = if self.is_private() { "PRIVATE" } else { "PUBLIC" }; let mut pem = String::new(); writeln!(&mut pem, "-----BEGIN {} KEY-----", key_ty).unwrap(); const MAX_LINE_LEN: usize = 64; for i in (0..der_b64.len()).step_by(MAX_LINE_LEN) { writeln!( &mut pem, "{}", &der_b64[i..std::cmp::min(i + MAX_LINE_LEN, der_b64.len())] ) .unwrap(); } writeln!(&mut pem, "-----END {} KEY-----", key_ty).unwrap(); Ok(pem) } } #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(tag = "crv")] pub enum Curve { /// prime256v1 #[serde(rename = "P-256")] P256 { /// Private point. #[serde(skip_serializing_if = "Option::is_none")] d: Option>, x: ByteArray<32>, y: ByteArray<32>, }, } #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct RsaPublic { /// Public exponent. Must be 65537. pub e: PublicExponent, /// Modulus, p*q. pub n: ByteVec, } const PUBLIC_EXPONENT: u32 = 65537; const PUBLIC_EXPONENT_B64: &str = "AQAB"; // little-endian, strip zeros const PUBLIC_EXPONENT_B64_PADDED: &str = "AQABAA=="; #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct PublicExponent; impl Serialize for PublicExponent { fn serialize(&self, s: S) -> Result { PUBLIC_EXPONENT_B64.serialize(s) } } impl<'de> Deserialize<'de> for PublicExponent { fn deserialize>(d: D) -> Result { let e = String::deserialize(d)?; if e == PUBLIC_EXPONENT_B64 || e == PUBLIC_EXPONENT_B64_PADDED { Ok(Self) } else { Err(serde::de::Error::custom(&format!( "public exponent must be {}", PUBLIC_EXPONENT ))) } } } #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct RsaPrivate { /// Private exponent. pub d: ByteVec, /// First prime factor. #[serde(default, skip_serializing_if = "Option::is_none")] pub p: Option, /// Second prime factor. #[serde(default, skip_serializing_if = "Option::is_none")] pub q: Option, /// First factor Chinese Remainder Theorem (CRT) exponent. #[serde(default, skip_serializing_if = "Option::is_none")] pub dp: Option, /// Second factor Chinese Remainder Theorem (CRT) exponent. #[serde(default, skip_serializing_if = "Option::is_none")] pub dq: Option, /// First CRT coefficient. #[serde(default, skip_serializing_if = "Option::is_none")] pub qi: Option, } #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum KeyUse { #[serde(rename = "sig")] Signing, #[serde(rename = "enc")] Encryption, } #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Zeroize)] pub enum JsonWebAlgorithm { HS256, RS256, ES256, } #[cfg(any(test, feature = "jsonwebtoken"))] impl Into for JsonWebAlgorithm { fn into(self) -> jsonwebtoken::Algorithm { match self { Self::HS256 => jsonwebtoken::Algorithm::HS256, Self::ES256 => jsonwebtoken::Algorithm::ES256, Self::RS256 => jsonwebtoken::Algorithm::RS256, } } } #[derive(Debug, thiserror::Error)] pub enum Error { #[error(transparent)] Serde(#[from] serde_json::Error), #[error(transparent)] Base64Decode(#[from] base64::DecodeError), #[error("mismatched algorithm for key type")] MismatchedAlgorithm, } #[derive(Debug, thiserror::Error)] pub enum PkcsConvertError { #[error("encoding RSA JWK as PKCS#8 requires specifing all of p, q, dp, dq, qi")] MissingRsaParams, #[error("a symmetric key can not be encoded using PKCS#8")] NotAsymmetric, }