1
0
mirror of https://github.com/BitskiCo/jwk-rs synced 2025-09-20 05:08:47 +00:00

Initial commit

This commit is contained in:
Nick Hynes
2020-07-12 18:57:57 +00:00
committed by Nick Hynes
commit 188772365a
11 changed files with 798 additions and 0 deletions

97
src/byte_array.rs Normal file
View File

@@ -0,0 +1,97 @@
use std::{array::FixedSizeArray, fmt};
use derive_more::{AsRef, Deref, From};
use serde::{
de::{self, Deserialize, Deserializer},
ser::{Serialize, Serializer},
};
use zeroize::{Zeroize, Zeroizing};
use crate::utils::{deserialize_base64, serialize_base64};
#[derive(Zeroize, Deref, AsRef, From)]
#[zeroize(drop)]
pub struct ByteArray<const N: usize>(pub [u8; N]);
impl<const N: usize> fmt::Debug for ByteArray<N> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if cfg!(debug_assertions) {
write!(f, "{}", base64::encode(self.0.as_slice()))
} else {
write!(f, "ByteArray<{}>", N)
}
}
}
impl<const N: usize> PartialEq for ByteArray<N> {
fn eq(&self, other: &Self) -> bool {
self.0.as_slice() == other.0.as_slice()
}
}
impl<const N: usize> Eq for ByteArray<N> {}
impl<const N: usize> ByteArray<N> {
pub fn try_from_slice(bytes: impl AsRef<[u8]>) -> Result<Self, String> {
let mut arr = Self([0u8; N]);
let bytes = bytes.as_ref();
if bytes.len() != N {
Err(format!("expected {} bytes but got {}", N, bytes.len()))
} else {
arr.0.copy_from_slice(bytes);
Ok(arr)
}
}
}
impl<const N: usize> Serialize for ByteArray<N> {
fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
serialize_base64(self.0.as_slice(), s)
}
}
impl<'de, const N: usize> Deserialize<'de> for ByteArray<N> {
fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let bytes = Zeroizing::new(deserialize_base64(d)?);
Self::try_from_slice(&*bytes).map_err(|_| {
de::Error::invalid_length(bytes.len(), &format!("{} base64-encoded bytes", N).as_str())
})
}
}
#[cfg(test)]
mod tests {
use super::*;
static BYTES: &[u8] = &[1, 2, 3, 4, 5, 6, 7];
static BASE64_JSON: &str = "\"AQIDBAUGBw==\"";
fn get_de() -> serde_json::Deserializer<serde_json::de::StrRead<'static>> {
serde_json::Deserializer::from_str(&BASE64_JSON)
}
#[test]
fn test_serde_byte_array_good() {
let arr = ByteArray::<7>::try_from_slice(BYTES).unwrap();
let b64 = serde_json::to_string(&arr).unwrap();
assert_eq!(b64, BASE64_JSON);
let bytes: ByteArray<7> = serde_json::from_str(&b64).unwrap();
assert_eq!(bytes.as_ref(), BYTES);
}
#[test]
fn test_serde_deserialize_byte_array_invalid() {
let mut de = serde_json::Deserializer::from_str("\"Z\"");
ByteArray::<0>::deserialize(&mut de).unwrap_err();
}
#[test]
fn test_serde_base64_deserialize_array_long() {
ByteArray::<6>::deserialize(&mut get_de()).unwrap_err();
}
#[test]
fn test_serde_base64_deserialize_array_short() {
ByteArray::<8>::deserialize(&mut get_de()).unwrap_err();
}
}

52
src/byte_vec.rs Normal file
View File

@@ -0,0 +1,52 @@
use std::fmt;
use derive_more::{AsRef, Deref, From};
use serde::{
de::{Deserialize, Deserializer},
ser::{Serialize, Serializer},
};
use zeroize::Zeroize;
use crate::utils::{deserialize_base64, serialize_base64};
#[derive(PartialEq, Eq, Zeroize, Deref, AsRef, From)]
#[zeroize(drop)]
pub struct ByteVec(pub Vec<u8>);
impl fmt::Debug for ByteVec {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if cfg!(debug_assertions) {
write!(f, "{:?}", self.0)
} else {
write!(f, "ByteVec")
}
}
}
impl Serialize for ByteVec {
fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
serialize_base64(&self.0, s)
}
}
impl<'de> Deserialize<'de> for ByteVec {
fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
Ok(Self(deserialize_base64(d)?))
}
}
#[cfg(test)]
mod tests {
use super::*;
static BYTES: &[u8] = &[1, 2, 3, 4, 5, 6, 7];
static BASE64_JSON: &str = "\"AQIDBAUGBw==\"";
#[test]
fn test_serde_byte_vec() {
let b64 = serde_json::to_string(&ByteVec(BYTES.to_vec())).unwrap();
assert_eq!(b64, BASE64_JSON);
let bytes: ByteVec = serde_json::from_str(&b64).unwrap();
assert_eq!(bytes.as_slice(), BYTES);
}
}

58
src/key_ops.rs Normal file
View File

@@ -0,0 +1,58 @@
use serde::{
de::{self, Deserialize, Deserializer},
ser::{Serialize, SerializeSeq, Serializer},
};
macro_rules! impl_key_ops {
($(($key_op:ident, $i:literal)),+,) => {
paste::item! {
bitflags::bitflags! {
#[derive(Default)]
pub struct KeyOps: u16 {
$(const [<$key_op:upper>] = $i;)*
}
}
}
impl Serialize for KeyOps {
fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
let mut seq = s.serialize_seq(Some(self.bits().count_ones() as usize))?;
$(
if self.contains(paste::expr! { KeyOps::[<$key_op:upper>] }) {
seq.serialize_element(stringify!($key_op))?;
}
)+
seq.end()
}
}
impl<'de> Deserialize<'de> for KeyOps {
fn deserialize<D: Deserializer<'de>>(d: D) -> Result<KeyOps, D::Error> {
let op_strs: Vec<String> = Deserialize::deserialize(d)?;
let mut ops = KeyOps::default();
for op_str in op_strs {
$(
if op_str == stringify!($key_op) {
ops |= paste::expr! { KeyOps::[<$key_op:upper>] };
continue;
}
)+
return Err(de::Error::custom(&format!("invalid key op: `{}`", op_str)));
}
Ok(ops)
}
}
};
}
#[rustfmt::skip]
impl_key_ops!(
(sign, 0b00000001),
(verify, 0b00000010),
(encrypt, 0b00000100),
(decrypt, 0b00001000),
(wrapKey, 0b00010000),
(unwrapKey, 0b00100000),
(deriveKey, 0b01000000),
(deriveBits, 0b10000000),
);

193
src/lib.rs Normal file
View File

@@ -0,0 +1,193 @@
#![allow(incomplete_features)]
#![feature(box_syntax, const_generics, fixed_size_array)]
mod byte_array;
mod byte_vec;
mod key_ops;
#[cfg(test)]
mod tests;
mod utils;
use serde::{Deserialize, Serialize};
use zeroize::Zeroize;
pub use byte_array::ByteArray;
pub use byte_vec::ByteVec;
pub use key_ops::KeyOps;
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct JsonWebKey {
#[serde(flatten)]
pub key_type: Box<KeyType>,
#[serde(default, rename = "use", skip_serializing_if = "Option::is_none")]
pub key_use: Option<KeyUse>,
#[serde(default, skip_serializing_if = "KeyOps::is_empty")]
pub key_ops: KeyOps,
#[serde(default, rename = "kid", skip_serializing_if = "Option::is_none")]
pub key_id: Option<String>,
#[serde(default, rename = "alg", skip_serializing_if = "Option::is_none")]
pub algorithm: Option<JsonWebAlgorithm>,
}
impl JsonWebKey {
pub fn from_slice(bytes: impl AsRef<[u8]>) -> Result<Self, Error> {
Ok(serde_json::from_slice(bytes.as_ref())?)
}
}
impl std::str::FromStr for JsonWebKey {
type Err = Error;
fn from_str(json: &str) -> Result<Self, Self::Err> {
let jwk = Self::from_slice(json.as_bytes())?;
// Validate alg.
use JsonWebAlgorithm::*;
use KeyType::*;
let alg = match &jwk.algorithm {
Some(alg) => alg,
None => return Ok(jwk),
};
match (alg, &*jwk.key_type) {
(
ES256,
EC {
params: Curve::P256 { .. },
},
)
| (RS256, RSA { .. })
| (HS256, Symmetric { .. }) => Ok(jwk),
_ => Err(Error::MismatchedAlgorithm),
}
}
}
impl std::fmt::Display for JsonWebKey {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
if f.alternate() {
write!(f, "{}", serde_json::to_string_pretty(self).unwrap())
} else {
write!(f, "{}", serde_json::to_string(self).unwrap())
}
}
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kty")]
pub enum KeyType {
EC {
#[serde(flatten)]
params: Curve,
},
RSA {
#[serde(flatten)]
public: RsaPublic,
#[serde(flatten, default, skip_serializing_if = "Option::is_none")]
private: Option<RsaPrivate>,
},
#[serde(rename = "oct")]
Symmetric {
#[serde(rename = "k")]
key: ByteVec,
},
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "crv")]
pub enum Curve {
/// prime256v1
#[serde(rename = "P-256")]
P256 {
/// Private point.
#[serde(skip_serializing_if = "Option::is_none")]
d: Option<ByteArray<32>>,
x: ByteArray<32>,
y: ByteArray<32>,
},
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct RsaPublic {
/// Public exponent. Must be 65537.
pub e: PublicExponent,
/// Modulus, p*q.
pub n: ByteVec,
}
const PUBLIC_EXPONENT: u32 = 65537;
const PUBLIC_EXPONENT_B64: &str = "AQAB"; // little-endian, strip zeros
const PUBLIC_EXPONENT_B64_PADDED: &str = "AQABAA==";
#[derive(Debug, PartialEq, Eq)]
pub struct PublicExponent;
impl Serialize for PublicExponent {
fn serialize<S: serde::ser::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
PUBLIC_EXPONENT_B64.serialize(s)
}
}
impl<'de> Deserialize<'de> for PublicExponent {
fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let e = String::deserialize(d)?;
if e == PUBLIC_EXPONENT_B64 || e == PUBLIC_EXPONENT_B64_PADDED {
Ok(Self)
} else {
Err(serde::de::Error::custom(&format!(
"public exponent must be {}",
PUBLIC_EXPONENT
)))
}
}
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct RsaPrivate {
/// Private exponent.
pub d: ByteVec,
/// First prime factor.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub p: Option<ByteVec>,
/// Second prime factor.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub q: Option<ByteVec>,
/// First factor Chinese Remainder Theorem (CRT) exponent.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub dp: Option<ByteVec>,
/// Second factor Chinese Remainder Theorem (CRT) exponent.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub dq: Option<ByteVec>,
/// First CRT coefficient.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub qi: Option<ByteVec>,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum KeyUse {
#[serde(rename = "sig")]
Signing,
#[serde(rename = "enc")]
Encryption,
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Zeroize)]
pub enum JsonWebAlgorithm {
HS256,
RS256,
ES256,
}
#[derive(thiserror::Error)]
#[cfg_attr(debug_assertions, derive(Debug))]
pub enum Error {
#[error(transparent)]
Serde(#[from] serde_json::Error),
#[error(transparent)]
Base64Decode(#[from] base64::DecodeError),
#[error("mismatched algorithm for key type")]
MismatchedAlgorithm,
}

263
src/tests.rs Normal file
View File

@@ -0,0 +1,263 @@
use super::*;
use std::str::FromStr;
#[test]
fn deserialize_es256() {
// Generated using https://mkjwk.org
let jwk_str = r#"{
"kty": "EC",
"d": "ZoKQ9j4dhIBlMRVrv-QG8P_T9sutv3_95eio9MtpgKg",
"use": "enc",
"crv": "P-256",
"kid": "a key",
"x": "QOMHmv96tVlJv-uNqprnDSKIj5AiLTXKRomXYnav0N0",
"y": "TjYZoHnctatEE6NCrKmXQdJJPnNzZEX8nBmZde3AY4k",
"alg": "ES256"
}"#;
let jwk = JsonWebKey::from_str(jwk_str).unwrap();
assert_eq!(
jwk,
JsonWebKey {
key_type: box KeyType::EC {
// The parameters were decoded using a 10-liner Rust script.
params: Curve::P256 {
d: Some(
[
102, 130, 144, 246, 62, 29, 132, 128, 101, 49, 21, 107, 191, 228, 6,
240, 255, 211, 246, 203, 173, 191, 127, 253, 229, 232, 168, 244, 203,
105, 128, 168
]
.into()
),
x: [
64, 227, 7, 154, 255, 122, 181, 89, 73, 191, 235, 141, 170, 154, 231, 13,
34, 136, 143, 144, 34, 45, 53, 202, 70, 137, 151, 98, 118, 175, 208, 221
]
.into(),
y: [
78, 54, 25, 160, 121, 220, 181, 171, 68, 19, 163, 66, 172, 169, 151, 65,
210, 73, 62, 115, 115, 100, 69, 252, 156, 25, 153, 117, 237, 192, 99, 137
]
.into(),
},
},
algorithm: Some(JsonWebAlgorithm::ES256),
key_id: Some("a key".into()),
key_ops: KeyOps::empty(),
key_use: Some(KeyUse::Encryption),
}
);
}
#[test]
fn serialize_es256() {
let jwk = JsonWebKey {
key_type: box KeyType::EC {
params: Curve::P256 {
d: None,
x: [1u8; 32].into(),
y: [2u8; 32].into(),
},
},
key_id: None,
algorithm: None,
key_ops: KeyOps::empty(),
key_use: None,
};
assert_eq!(
jwk.to_string(),
r#"{"kty":"EC","crv":"P-256","x":"AQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQE=","y":"AgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgI="}"#
);
}
#[test]
fn deserialize_hs256() {
let jwk_str = r#"{
"kty": "oct",
"k": "tAON6Q",
"alg": "HS256",
"key_ops": ["verify", "sign"]
}"#;
let jwk = JsonWebKey::from_str(jwk_str).unwrap();
assert_eq!(
jwk,
JsonWebKey {
key_type: box KeyType::Symmetric {
// The parameters were decoded using a 10-liner Rust script.
key: vec![180, 3, 141, 233].into(),
},
algorithm: Some(JsonWebAlgorithm::HS256),
key_id: None,
key_ops: KeyOps::SIGN | KeyOps::VERIFY,
key_use: None,
}
);
}
#[test]
fn serialize_hs256() {
let jwk = JsonWebKey {
key_type: box KeyType::Symmetric {
key: vec![42; 16].into(),
},
key_id: None,
algorithm: None,
key_ops: KeyOps::empty(),
key_use: None,
};
assert_eq!(
jwk.to_string(),
r#"{"kty":"oct","k":"KioqKioqKioqKioqKioqKg=="}"#
);
}
#[test]
fn deserialize_rs256() {
let jwk_str = r#"{
"p": "_LSip5o4eaGf25uvwyUq9ubRtKemrCaoCxumoj63Au0",
"kty": "RSA",
"q": "l20iLpicEW3uja0Zg2xP6DjZa86bD4IQ3wFXCcKCf1c",
"d": "Xo0VAHtfV38HwJbAI6X-Fu7vuyoQjnuiSlQhcSjxn0BZfLP_DKxdJ2ANgTGVE0x243YHqhWRHLobbmDcnUuMOQ",
"e": "AQAB",
"qi": "2mzAaSr7I1D3vDtOhbWKS9-9ELRHKbAHz4dhn4DSCBo",
"dp": "-kyswxeVEpyM6wdU2xRobu-HDMn145PSZFY6AX_e460",
"alg": "RS256",
"dq": "OqMWE3khJlatg8s-D_hHUSOCfg65WN4C7ng0XiEmK20",
"n": "lXpGmBoIxj56TpptApaac6V19_7WWbq0a14a5UHBBlkc54NwIUa2X4p9OeK2sy6rLQ_1g1AcSwfsVUy8MP-Riw"
}"#;
let jwk = JsonWebKey::from_str(jwk_str).unwrap();
assert_eq!(
jwk,
JsonWebKey {
key_type: box KeyType::RSA {
public: RsaPublic {
e: PublicExponent,
n: vec![
149, 122, 70, 152, 26, 8, 198, 62, 122, 78, 154, 109, 2, 150, 154, 115,
165, 117, 247, 254, 214, 89, 186, 180, 107, 94, 26, 229, 65, 193, 6, 89,
28, 231, 131, 112, 33, 70, 182, 95, 138, 125, 57, 226, 182, 179, 46, 171,
45, 15, 245, 131, 80, 28, 75, 7, 236, 85, 76, 188, 48, 255, 145, 139
]
.into()
},
private: Some(RsaPrivate {
d: vec![
94, 141, 21, 0, 123, 95, 87, 127, 7, 192, 150, 192, 35, 165, 254, 22, 238,
239, 187, 42, 16, 142, 123, 162, 74, 84, 33, 113, 40, 241, 159, 64, 89,
124, 179, 255, 12, 172, 93, 39, 96, 13, 129, 49, 149, 19, 76, 118, 227,
118, 7, 170, 21, 145, 28, 186, 27, 110, 96, 220, 157, 75, 140, 57
]
.into(),
p: Some(
vec![
252, 180, 162, 167, 154, 56, 121, 161, 159, 219, 155, 175, 195, 37, 42,
246, 230, 209, 180, 167, 166, 172, 38, 168, 11, 27, 166, 162, 62, 183,
2, 237
]
.into()
),
q: Some(
vec![
151, 109, 34, 46, 152, 156, 17, 109, 238, 141, 173, 25, 131, 108, 79,
232, 56, 217, 107, 206, 155, 15, 130, 16, 223, 1, 87, 9, 194, 130, 127,
87
]
.into()
),
dp: Some(
vec![
250, 76, 172, 195, 23, 149, 18, 156, 140, 235, 7, 84, 219, 20, 104,
110, 239, 135, 12, 201, 245, 227, 147, 210, 100, 86, 58, 1, 127, 222,
227, 173
]
.into()
),
dq: Some(
vec![
58, 163, 22, 19, 121, 33, 38, 86, 173, 131, 203, 62, 15, 248, 71, 81,
35, 130, 126, 14, 185, 88, 222, 2, 238, 120, 52, 94, 33, 38, 43, 109
]
.into()
),
qi: Some(
vec![
218, 108, 192, 105, 42, 251, 35, 80, 247, 188, 59, 78, 133, 181, 138,
75, 223, 189, 16, 180, 71, 41, 176, 7, 207, 135, 97, 159, 128, 210, 8,
26
]
.into()
)
})
},
algorithm: Some(JsonWebAlgorithm::RS256),
key_id: None,
key_ops: KeyOps::empty(),
key_use: None,
}
);
}
#[test]
fn serialize_rs256() {
let jwk = JsonWebKey {
key_type: box KeyType::RSA {
public: RsaPublic {
e: PublicExponent,
n: vec![105, 183, 62].into(),
},
private: Some(RsaPrivate {
d: vec![105, 183, 63].into(),
p: None,
q: None,
dp: None,
dq: None,
qi: None,
}),
},
key_id: None,
algorithm: None,
key_ops: KeyOps::empty(),
key_use: None,
};
assert_eq!(
jwk.to_string(),
r#"{"kty":"RSA","e":"AQAB","n":"abc-","d":"abc_"}"#
);
}
#[test]
fn mismatched_algorithm() {
macro_rules! assert_mismatched_alg {
($jwk_str:literal) => {
match JsonWebKey::from_str($jwk_str) {
Err(Error::MismatchedAlgorithm) => {}
v => panic!("expected MismatchedAlgorithm, got {:?}", v),
}
};
}
assert_mismatched_alg!(r#"{ "kty": "oct", "k": "tAON6Q", "alg": "ES256" }"#);
assert_mismatched_alg!(r#"{ "kty": "oct", "k": "tAON6Q", "alg": "RS256" }"#);
assert_mismatched_alg!(
r#"{
"kty": "EC",
"d": "ZoKQ9j4dhIBlMRVrv-QG8P_T9sutv3_95eio9MtpgKg",
"crv": "P-256",
"x": "QOMHmv96tVlJv-uNqprnDSKIj5AiLTXKRomXYnav0N0",
"y": "TjYZoHnctatEE6NCrKmXQdJJPnNzZEX8nBmZde3AY4k",
"alg": "RS256"
}"#
);
assert_mismatched_alg!(
r#"{
"kty": "EC",
"d": "ZoKQ9j4dhIBlMRVrv-QG8P_T9sutv3_95eio9MtpgKg",
"crv": "P-256",
"x": "QOMHmv96tVlJv-uNqprnDSKIj5AiLTXKRomXYnav0N0",
"y": "TjYZoHnctatEE6NCrKmXQdJJPnNzZEX8nBmZde3AY4k",
"alg": "HS256"
}"#
);
}

32
src/utils.rs Normal file
View File

@@ -0,0 +1,32 @@
use serde::{
de::{self, Deserialize, Deserializer},
ser::{Serialize, Serializer},
};
use zeroize::Zeroizing;
fn base64_config() -> base64::Config {
base64::Config::new(base64::CharacterSet::UrlSafe, true /* pad */)
}
fn base64_encode(bytes: impl AsRef<[u8]>) -> String {
base64::encode_config(bytes, base64_config())
}
fn base64_decode(b64: impl AsRef<[u8]>) -> Result<Vec<u8>, base64::DecodeError> {
base64::decode_config(b64, base64_config())
}
pub fn serialize_base64<S: Serializer>(bytes: impl AsRef<[u8]>, s: S) -> Result<S::Ok, S::Error> {
base64_encode(bytes).serialize(s)
}
pub fn deserialize_base64<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
let base64_str = Zeroizing::new(String::deserialize(d)?);
base64_decode(&*base64_str).map_err(|e| {
#[cfg(debug_assertions)]
let err_msg = e.to_string().to_lowercase();
#[cfg(not(debug_assertions))]
let err_msg = "invalid base64";
de::Error::custom(err_msg.strip_suffix(".").unwrap_or(&err_msg))
})
}