Skip to content

Commit 26afa2c

Browse files
Frederik Rothenbergeritsyaasir
Frederik Rothenberger
andauthored
Add support for custom JWS algorithms (#1410)
* Add support for custom JWS algorithms This PR introduces a feature `custom_alg` to `identity_jose` (disabled by default) that allows it to process JWS with custom `alg` values. Switching on `custom_alg` makes quite a few changes to `JwsAlgorithm`: - The type is no longer `Copy` - `name()` takes only a reference and returns a `String` rather than `&'static str` - The constant `ALL` is removed as it is no longer possible to enumerate all variants * fmt * Add comment * Nightly fmt * chore: add template for custom_alg file * Split implementation of Display --------- Co-authored-by: Yasir <yasir@shariff.dev>
1 parent 13acb23 commit 26afa2c

File tree

8 files changed

+173
-13
lines changed

8 files changed

+173
-13
lines changed

identity_jose/Cargo.toml

+7
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,10 @@ test = true
3434

3535
[lints]
3636
workspace = true
37+
38+
[features]
39+
custom_alg = []
40+
41+
[[test]]
42+
name = "custom_alg"
43+
required-features = ["custom_alg"]

identity_jose/src/jwk/key.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -395,9 +395,9 @@ impl Jwk {
395395
// ===========================================================================
396396

397397
/// Checks if the `alg` claim of the JWK is equal to `expected`.
398-
pub fn check_alg(&self, expected: &str) -> Result<()> {
398+
pub fn check_alg(&self, expected: impl AsRef<str>) -> Result<()> {
399399
match self.alg() {
400-
Some(value) if value == expected => Ok(()),
400+
Some(value) if value == expected.as_ref() => Ok(()),
401401
Some(_) => Err(Error::InvalidClaim("alg")),
402402
None => Ok(()),
403403
}

identity_jose/src/jws/algorithm.rs

+47-4
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,11 @@ use core::fmt::Formatter;
66
use core::fmt::Result;
77
use std::str::FromStr;
88

9-
use crate::error::Error;
10-
119
/// Supported algorithms for the JSON Web Signatures `alg` claim.
1210
///
1311
/// [More Info](https://www.iana.org/assignments/jose/jose.xhtml#web-signature-encryption-algorithms)
14-
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, serde::Deserialize, serde::Serialize)]
12+
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, serde::Deserialize, serde::Serialize)]
13+
#[cfg_attr(not(feature = "custom_alg"), derive(Copy))]
1514
#[allow(non_camel_case_types)]
1615
pub enum JwsAlgorithm {
1716
/// HMAC using SHA-256
@@ -45,10 +44,19 @@ pub enum JwsAlgorithm {
4544
NONE,
4645
/// EdDSA signature algorithms
4746
EdDSA,
47+
/// Custom algorithm
48+
#[cfg(feature = "custom_alg")]
49+
#[serde(untagged)]
50+
Custom(String),
4851
}
4952

5053
impl JwsAlgorithm {
5154
/// A slice of all supported [`JwsAlgorithm`]s.
55+
///
56+
/// Not available when feature `custom_alg` is enabled
57+
/// as it is not possible to enumerate all variants when
58+
/// supporting arbitrary `alg` values.
59+
#[cfg(not(feature = "custom_alg"))]
5260
pub const ALL: &'static [Self] = &[
5361
Self::HS256,
5462
Self::HS384,
@@ -68,6 +76,7 @@ impl JwsAlgorithm {
6876
];
6977

7078
/// Returns the JWS algorithm as a `str` slice.
79+
#[cfg(not(feature = "custom_alg"))]
7180
pub const fn name(self) -> &'static str {
7281
match self {
7382
Self::HS256 => "HS256",
@@ -87,6 +96,29 @@ impl JwsAlgorithm {
8796
Self::EdDSA => "EdDSA",
8897
}
8998
}
99+
100+
/// Returns the JWS algorithm as a `str` slice.
101+
#[cfg(feature = "custom_alg")]
102+
pub fn name(&self) -> String {
103+
match self {
104+
Self::HS256 => "HS256".to_string(),
105+
Self::HS384 => "HS384".to_string(),
106+
Self::HS512 => "HS512".to_string(),
107+
Self::RS256 => "RS256".to_string(),
108+
Self::RS384 => "RS384".to_string(),
109+
Self::RS512 => "RS512".to_string(),
110+
Self::PS256 => "PS256".to_string(),
111+
Self::PS384 => "PS384".to_string(),
112+
Self::PS512 => "PS512".to_string(),
113+
Self::ES256 => "ES256".to_string(),
114+
Self::ES384 => "ES384".to_string(),
115+
Self::ES512 => "ES512".to_string(),
116+
Self::ES256K => "ES256K".to_string(),
117+
Self::NONE => "none".to_string(),
118+
Self::EdDSA => "EdDSA".to_string(),
119+
Self::Custom(name) => name.clone(),
120+
}
121+
}
90122
}
91123

92124
impl FromStr for JwsAlgorithm {
@@ -109,13 +141,24 @@ impl FromStr for JwsAlgorithm {
109141
"ES256K" => Ok(Self::ES256K),
110142
"none" => Ok(Self::NONE),
111143
"EdDSA" => Ok(Self::EdDSA),
112-
_ => Err(Error::JwsAlgorithmParsingError),
144+
#[cfg(feature = "custom_alg")]
145+
value => Ok(Self::Custom(value.to_string())),
146+
#[cfg(not(feature = "custom_alg"))]
147+
_ => Err(crate::error::Error::JwsAlgorithmParsingError),
113148
}
114149
}
115150
}
116151

152+
#[cfg(not(feature = "custom_alg"))]
117153
impl Display for JwsAlgorithm {
118154
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
119155
f.write_str(self.name())
120156
}
121157
}
158+
159+
#[cfg(feature = "custom_alg")]
160+
impl Display for JwsAlgorithm {
161+
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
162+
f.write_str(&(*self).name())
163+
}
164+
}

identity_jose/src/jws/header.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ impl JwsHeader {
6767

6868
/// Returns the value for the algorithm claim (alg).
6969
pub fn alg(&self) -> Option<JwsAlgorithm> {
70-
self.alg.as_ref().copied()
70+
self.alg.as_ref().cloned()
7171
}
7272

7373
/// Sets a value for the algorithm claim (alg).

identity_jose/tests/custom_alg.rs

+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
// Copyright 2020-2024 IOTA Stiftung
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
use std::ops::Deref;
5+
use std::time::SystemTime;
6+
7+
use crypto::signatures::ed25519::PublicKey;
8+
use crypto::signatures::ed25519::SecretKey;
9+
use crypto::signatures::ed25519::Signature;
10+
use identity_jose::jwk::EdCurve;
11+
use identity_jose::jwk::Jwk;
12+
use identity_jose::jwk::JwkParamsOkp;
13+
use identity_jose::jwk::JwkType;
14+
use identity_jose::jws::CompactJwsEncoder;
15+
use identity_jose::jws::Decoder;
16+
use identity_jose::jws::JwsAlgorithm;
17+
use identity_jose::jws::JwsHeader;
18+
use identity_jose::jws::JwsVerifierFn;
19+
use identity_jose::jws::SignatureVerificationError;
20+
use identity_jose::jws::SignatureVerificationErrorKind;
21+
use identity_jose::jws::VerificationInput;
22+
use identity_jose::jwt::JwtClaims;
23+
use identity_jose::jwu;
24+
use jsonprooftoken::encoding::base64url_decode;
25+
26+
#[test]
27+
fn custom_alg_roundtrip() {
28+
let secret_key = SecretKey::generate().unwrap();
29+
let public_key = secret_key.public_key();
30+
31+
let mut header: JwsHeader = JwsHeader::new();
32+
header.set_alg(JwsAlgorithm::Custom("test".to_string()));
33+
let kid = "did:iota:0x123#signing-key";
34+
header.set_kid(kid);
35+
36+
let mut claims: JwtClaims<serde_json::Value> = JwtClaims::new();
37+
claims.set_iss("issuer");
38+
claims.set_iat(
39+
SystemTime::now()
40+
.duration_since(SystemTime::UNIX_EPOCH)
41+
.unwrap()
42+
.as_secs() as i64,
43+
);
44+
claims.set_custom(serde_json::json!({"num": 42u64}));
45+
46+
let claims_bytes: Vec<u8> = serde_json::to_vec(&claims).unwrap();
47+
48+
let encoder: CompactJwsEncoder<'_> = CompactJwsEncoder::new(&claims_bytes, &header).unwrap();
49+
let signing_input: &[u8] = encoder.signing_input();
50+
let signature = secret_key.sign(signing_input).to_bytes();
51+
let jws = encoder.into_jws(&signature);
52+
53+
let header = jws.split(".").next().unwrap();
54+
let header_json = String::from_utf8(base64url_decode(header.as_bytes())).expect("failed to decode header");
55+
assert_eq!(header_json, r#"{"kid":"did:iota:0x123#signing-key","alg":"test"}"#);
56+
57+
let verifier = JwsVerifierFn::from(|input: VerificationInput, key: &Jwk| {
58+
if input.alg != JwsAlgorithm::Custom("test".to_string()) {
59+
panic!("invalid algorithm");
60+
}
61+
verify(input, key)
62+
});
63+
let decoder = Decoder::new();
64+
let mut public_key_jwk = Jwk::new(JwkType::Okp);
65+
public_key_jwk.set_kid(kid);
66+
public_key_jwk
67+
.set_params(JwkParamsOkp {
68+
crv: "Ed25519".into(),
69+
x: jwu::encode_b64(public_key.as_slice()),
70+
d: None,
71+
})
72+
.unwrap();
73+
74+
let token = decoder
75+
.decode_compact_serialization(jws.as_bytes(), None)
76+
.and_then(|decoded| decoded.verify(&verifier, &public_key_jwk))
77+
.unwrap();
78+
79+
let recovered_claims: JwtClaims<serde_json::Value> = serde_json::from_slice(&token.claims).unwrap();
80+
81+
assert_eq!(token.protected.alg(), Some(JwsAlgorithm::Custom("test".to_string())));
82+
assert_eq!(claims, recovered_claims);
83+
}
84+
85+
fn verify(verification_input: VerificationInput, jwk: &Jwk) -> Result<(), SignatureVerificationError> {
86+
let public_key = expand_public_jwk(jwk);
87+
88+
let signature_arr = <[u8; Signature::LENGTH]>::try_from(verification_input.decoded_signature.deref())
89+
.map_err(|err| err.to_string())
90+
.unwrap();
91+
92+
let signature = Signature::from_bytes(signature_arr);
93+
if public_key.verify(&signature, &verification_input.signing_input) {
94+
Ok(())
95+
} else {
96+
Err(SignatureVerificationErrorKind::InvalidSignature.into())
97+
}
98+
}
99+
100+
fn expand_public_jwk(jwk: &Jwk) -> PublicKey {
101+
let params: &JwkParamsOkp = jwk.try_okp_params().unwrap();
102+
103+
if params.try_ed_curve().unwrap() != EdCurve::Ed25519 {
104+
panic!("expected an ed25519 jwk");
105+
}
106+
107+
let pk: [u8; PublicKey::LENGTH] = jwu::decode_b64(params.x.as_str()).unwrap().try_into().unwrap();
108+
109+
PublicKey::try_from(pk).unwrap()
110+
}

identity_storage/src/key_storage/memstore.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ impl JwkStorage for JwkMemStore {
5858
async fn generate(&self, key_type: KeyType, alg: JwsAlgorithm) -> KeyStorageResult<JwkGenOutput> {
5959
let key_type: MemStoreKeyType = MemStoreKeyType::try_from(&key_type)?;
6060

61-
check_key_alg_compatibility(key_type, alg)?;
61+
check_key_alg_compatibility(key_type, &alg)?;
6262

6363
let (private_key, public_key) = match key_type {
6464
MemStoreKeyType::Ed25519 => {
@@ -102,7 +102,7 @@ impl JwkStorage for JwkMemStore {
102102
Some(alg) => {
103103
let alg: JwsAlgorithm = JwsAlgorithm::from_str(alg)
104104
.map_err(|err| KeyStorageError::new(KeyStorageErrorKind::UnsupportedSignatureAlgorithm).with_source(err))?;
105-
check_key_alg_compatibility(key_type, alg)?;
105+
check_key_alg_compatibility(key_type, &alg)?;
106106
}
107107
None => {
108108
return Err(
@@ -291,7 +291,7 @@ fn random_key_id() -> KeyId {
291291
}
292292

293293
/// Check that the key type can be used with the algorithm.
294-
fn check_key_alg_compatibility(key_type: MemStoreKeyType, alg: JwsAlgorithm) -> KeyStorageResult<()> {
294+
fn check_key_alg_compatibility(key_type: MemStoreKeyType, alg: &JwsAlgorithm) -> KeyStorageResult<()> {
295295
match (key_type, alg) {
296296
(MemStoreKeyType::Ed25519, JwsAlgorithm::EdDSA) => Ok(()),
297297
(key_type, alg) => Err(

identity_stronghold/src/storage/stronghold_jwk_storage.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ impl JwkStorage for StrongholdStorage {
3636

3737
let client = get_client(&stronghold)?;
3838
let key_type = StrongholdKeyType::try_from(&key_type)?;
39-
check_key_alg_compatibility(key_type, alg)?;
39+
check_key_alg_compatibility(key_type, &alg)?;
4040

4141
let keytype: ProceduresKeyType = match key_type {
4242
StrongholdKeyType::Ed25519 => ProceduresKeyType::Ed25519,
@@ -106,7 +106,7 @@ impl JwkStorage for StrongholdStorage {
106106
Some(alg) => {
107107
let alg: JwsAlgorithm = JwsAlgorithm::from_str(alg)
108108
.map_err(|err| KeyStorageError::new(KeyStorageErrorKind::UnsupportedSignatureAlgorithm).with_source(err))?;
109-
check_key_alg_compatibility(key_type, alg)?;
109+
check_key_alg_compatibility(key_type, &alg)?;
110110
}
111111
None => {
112112
return Err(

identity_stronghold/src/utils.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ pub fn random_key_id() -> KeyId {
2424
}
2525

2626
/// Check that the key type can be used with the algorithm.
27-
pub fn check_key_alg_compatibility(key_type: StrongholdKeyType, alg: JwsAlgorithm) -> KeyStorageResult<()> {
27+
pub fn check_key_alg_compatibility(key_type: StrongholdKeyType, alg: &JwsAlgorithm) -> KeyStorageResult<()> {
2828
match (key_type, alg) {
2929
(StrongholdKeyType::Ed25519, JwsAlgorithm::EdDSA) => Ok(()),
3030
(key_type, alg) => Err(

0 commit comments

Comments
 (0)