Skip to content

Commit

Permalink
Add support for custom JWS algorithms (#1410)
Browse files Browse the repository at this point in the history
* Add support for custom JWS algorithms

This PR introduces a feature `custom_alg` to `identity_jose` (disabled
by default) that allows it to process JWS with custom `alg` values.

Switching on `custom_alg` makes quite a few changes to `JwsAlgorithm`:
- The type is no longer `Copy`
- `name()` takes only a reference and returns a `String` rather than
  `&'static str`
- The constant `ALL` is removed as it is no longer possible to enumerate
  all variants

* fmt

* Add comment

* Nightly fmt

* chore: add template for custom_alg file

* Split implementation of Display

---------

Co-authored-by: Yasir <[email protected]>
  • Loading branch information
frederikrothenberger and itsyaasir authored Sep 17, 2024
1 parent 13acb23 commit 26afa2c
Show file tree
Hide file tree
Showing 8 changed files with 173 additions and 13 deletions.
7 changes: 7 additions & 0 deletions identity_jose/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,10 @@ test = true

[lints]
workspace = true

[features]
custom_alg = []

[[test]]
name = "custom_alg"
required-features = ["custom_alg"]
4 changes: 2 additions & 2 deletions identity_jose/src/jwk/key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,9 +395,9 @@ impl Jwk {
// ===========================================================================

/// Checks if the `alg` claim of the JWK is equal to `expected`.
pub fn check_alg(&self, expected: &str) -> Result<()> {
pub fn check_alg(&self, expected: impl AsRef<str>) -> Result<()> {
match self.alg() {
Some(value) if value == expected => Ok(()),
Some(value) if value == expected.as_ref() => Ok(()),
Some(_) => Err(Error::InvalidClaim("alg")),
None => Ok(()),
}
Expand Down
51 changes: 47 additions & 4 deletions identity_jose/src/jws/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@ use core::fmt::Formatter;
use core::fmt::Result;
use std::str::FromStr;

use crate::error::Error;

/// Supported algorithms for the JSON Web Signatures `alg` claim.
///
/// [More Info](https://www.iana.org/assignments/jose/jose.xhtml#web-signature-encryption-algorithms)
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, serde::Deserialize, serde::Serialize)]
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, serde::Deserialize, serde::Serialize)]
#[cfg_attr(not(feature = "custom_alg"), derive(Copy))]
#[allow(non_camel_case_types)]
pub enum JwsAlgorithm {
/// HMAC using SHA-256
Expand Down Expand Up @@ -45,10 +44,19 @@ pub enum JwsAlgorithm {
NONE,
/// EdDSA signature algorithms
EdDSA,
/// Custom algorithm
#[cfg(feature = "custom_alg")]
#[serde(untagged)]
Custom(String),
}

impl JwsAlgorithm {
/// A slice of all supported [`JwsAlgorithm`]s.
///
/// Not available when feature `custom_alg` is enabled
/// as it is not possible to enumerate all variants when
/// supporting arbitrary `alg` values.
#[cfg(not(feature = "custom_alg"))]
pub const ALL: &'static [Self] = &[
Self::HS256,
Self::HS384,
Expand All @@ -68,6 +76,7 @@ impl JwsAlgorithm {
];

/// Returns the JWS algorithm as a `str` slice.
#[cfg(not(feature = "custom_alg"))]
pub const fn name(self) -> &'static str {
match self {
Self::HS256 => "HS256",
Expand All @@ -87,6 +96,29 @@ impl JwsAlgorithm {
Self::EdDSA => "EdDSA",
}
}

/// Returns the JWS algorithm as a `str` slice.
#[cfg(feature = "custom_alg")]
pub fn name(&self) -> String {
match self {
Self::HS256 => "HS256".to_string(),
Self::HS384 => "HS384".to_string(),
Self::HS512 => "HS512".to_string(),
Self::RS256 => "RS256".to_string(),
Self::RS384 => "RS384".to_string(),
Self::RS512 => "RS512".to_string(),
Self::PS256 => "PS256".to_string(),
Self::PS384 => "PS384".to_string(),
Self::PS512 => "PS512".to_string(),
Self::ES256 => "ES256".to_string(),
Self::ES384 => "ES384".to_string(),
Self::ES512 => "ES512".to_string(),
Self::ES256K => "ES256K".to_string(),
Self::NONE => "none".to_string(),
Self::EdDSA => "EdDSA".to_string(),
Self::Custom(name) => name.clone(),
}
}
}

impl FromStr for JwsAlgorithm {
Expand All @@ -109,13 +141,24 @@ impl FromStr for JwsAlgorithm {
"ES256K" => Ok(Self::ES256K),
"none" => Ok(Self::NONE),
"EdDSA" => Ok(Self::EdDSA),
_ => Err(Error::JwsAlgorithmParsingError),
#[cfg(feature = "custom_alg")]
value => Ok(Self::Custom(value.to_string())),
#[cfg(not(feature = "custom_alg"))]
_ => Err(crate::error::Error::JwsAlgorithmParsingError),
}
}
}

#[cfg(not(feature = "custom_alg"))]
impl Display for JwsAlgorithm {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
f.write_str(self.name())
}
}

#[cfg(feature = "custom_alg")]
impl Display for JwsAlgorithm {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
f.write_str(&(*self).name())
}
}
2 changes: 1 addition & 1 deletion identity_jose/src/jws/header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl JwsHeader {

/// Returns the value for the algorithm claim (alg).
pub fn alg(&self) -> Option<JwsAlgorithm> {
self.alg.as_ref().copied()
self.alg.as_ref().cloned()
}

/// Sets a value for the algorithm claim (alg).
Expand Down
110 changes: 110 additions & 0 deletions identity_jose/tests/custom_alg.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Copyright 2020-2024 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

use std::ops::Deref;
use std::time::SystemTime;

use crypto::signatures::ed25519::PublicKey;
use crypto::signatures::ed25519::SecretKey;
use crypto::signatures::ed25519::Signature;
use identity_jose::jwk::EdCurve;
use identity_jose::jwk::Jwk;
use identity_jose::jwk::JwkParamsOkp;
use identity_jose::jwk::JwkType;
use identity_jose::jws::CompactJwsEncoder;
use identity_jose::jws::Decoder;
use identity_jose::jws::JwsAlgorithm;
use identity_jose::jws::JwsHeader;
use identity_jose::jws::JwsVerifierFn;
use identity_jose::jws::SignatureVerificationError;
use identity_jose::jws::SignatureVerificationErrorKind;
use identity_jose::jws::VerificationInput;
use identity_jose::jwt::JwtClaims;
use identity_jose::jwu;
use jsonprooftoken::encoding::base64url_decode;

#[test]
fn custom_alg_roundtrip() {
let secret_key = SecretKey::generate().unwrap();
let public_key = secret_key.public_key();

let mut header: JwsHeader = JwsHeader::new();
header.set_alg(JwsAlgorithm::Custom("test".to_string()));
let kid = "did:iota:0x123#signing-key";
header.set_kid(kid);

let mut claims: JwtClaims<serde_json::Value> = JwtClaims::new();
claims.set_iss("issuer");
claims.set_iat(
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs() as i64,
);
claims.set_custom(serde_json::json!({"num": 42u64}));

let claims_bytes: Vec<u8> = serde_json::to_vec(&claims).unwrap();

let encoder: CompactJwsEncoder<'_> = CompactJwsEncoder::new(&claims_bytes, &header).unwrap();
let signing_input: &[u8] = encoder.signing_input();
let signature = secret_key.sign(signing_input).to_bytes();
let jws = encoder.into_jws(&signature);

let header = jws.split(".").next().unwrap();
let header_json = String::from_utf8(base64url_decode(header.as_bytes())).expect("failed to decode header");
assert_eq!(header_json, r#"{"kid":"did:iota:0x123#signing-key","alg":"test"}"#);

let verifier = JwsVerifierFn::from(|input: VerificationInput, key: &Jwk| {
if input.alg != JwsAlgorithm::Custom("test".to_string()) {
panic!("invalid algorithm");
}
verify(input, key)
});
let decoder = Decoder::new();
let mut public_key_jwk = Jwk::new(JwkType::Okp);
public_key_jwk.set_kid(kid);
public_key_jwk
.set_params(JwkParamsOkp {
crv: "Ed25519".into(),
x: jwu::encode_b64(public_key.as_slice()),
d: None,
})
.unwrap();

let token = decoder
.decode_compact_serialization(jws.as_bytes(), None)
.and_then(|decoded| decoded.verify(&verifier, &public_key_jwk))
.unwrap();

let recovered_claims: JwtClaims<serde_json::Value> = serde_json::from_slice(&token.claims).unwrap();

assert_eq!(token.protected.alg(), Some(JwsAlgorithm::Custom("test".to_string())));
assert_eq!(claims, recovered_claims);
}

fn verify(verification_input: VerificationInput, jwk: &Jwk) -> Result<(), SignatureVerificationError> {
let public_key = expand_public_jwk(jwk);

let signature_arr = <[u8; Signature::LENGTH]>::try_from(verification_input.decoded_signature.deref())
.map_err(|err| err.to_string())
.unwrap();

let signature = Signature::from_bytes(signature_arr);
if public_key.verify(&signature, &verification_input.signing_input) {
Ok(())
} else {
Err(SignatureVerificationErrorKind::InvalidSignature.into())
}
}

fn expand_public_jwk(jwk: &Jwk) -> PublicKey {
let params: &JwkParamsOkp = jwk.try_okp_params().unwrap();

if params.try_ed_curve().unwrap() != EdCurve::Ed25519 {
panic!("expected an ed25519 jwk");
}

let pk: [u8; PublicKey::LENGTH] = jwu::decode_b64(params.x.as_str()).unwrap().try_into().unwrap();

PublicKey::try_from(pk).unwrap()
}
6 changes: 3 additions & 3 deletions identity_storage/src/key_storage/memstore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl JwkStorage for JwkMemStore {
async fn generate(&self, key_type: KeyType, alg: JwsAlgorithm) -> KeyStorageResult<JwkGenOutput> {
let key_type: MemStoreKeyType = MemStoreKeyType::try_from(&key_type)?;

check_key_alg_compatibility(key_type, alg)?;
check_key_alg_compatibility(key_type, &alg)?;

let (private_key, public_key) = match key_type {
MemStoreKeyType::Ed25519 => {
Expand Down Expand Up @@ -102,7 +102,7 @@ impl JwkStorage for JwkMemStore {
Some(alg) => {
let alg: JwsAlgorithm = JwsAlgorithm::from_str(alg)
.map_err(|err| KeyStorageError::new(KeyStorageErrorKind::UnsupportedSignatureAlgorithm).with_source(err))?;
check_key_alg_compatibility(key_type, alg)?;
check_key_alg_compatibility(key_type, &alg)?;
}
None => {
return Err(
Expand Down Expand Up @@ -291,7 +291,7 @@ fn random_key_id() -> KeyId {
}

/// Check that the key type can be used with the algorithm.
fn check_key_alg_compatibility(key_type: MemStoreKeyType, alg: JwsAlgorithm) -> KeyStorageResult<()> {
fn check_key_alg_compatibility(key_type: MemStoreKeyType, alg: &JwsAlgorithm) -> KeyStorageResult<()> {
match (key_type, alg) {
(MemStoreKeyType::Ed25519, JwsAlgorithm::EdDSA) => Ok(()),
(key_type, alg) => Err(
Expand Down
4 changes: 2 additions & 2 deletions identity_stronghold/src/storage/stronghold_jwk_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl JwkStorage for StrongholdStorage {

let client = get_client(&stronghold)?;
let key_type = StrongholdKeyType::try_from(&key_type)?;
check_key_alg_compatibility(key_type, alg)?;
check_key_alg_compatibility(key_type, &alg)?;

let keytype: ProceduresKeyType = match key_type {
StrongholdKeyType::Ed25519 => ProceduresKeyType::Ed25519,
Expand Down Expand Up @@ -106,7 +106,7 @@ impl JwkStorage for StrongholdStorage {
Some(alg) => {
let alg: JwsAlgorithm = JwsAlgorithm::from_str(alg)
.map_err(|err| KeyStorageError::new(KeyStorageErrorKind::UnsupportedSignatureAlgorithm).with_source(err))?;
check_key_alg_compatibility(key_type, alg)?;
check_key_alg_compatibility(key_type, &alg)?;
}
None => {
return Err(
Expand Down
2 changes: 1 addition & 1 deletion identity_stronghold/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub fn random_key_id() -> KeyId {
}

/// Check that the key type can be used with the algorithm.
pub fn check_key_alg_compatibility(key_type: StrongholdKeyType, alg: JwsAlgorithm) -> KeyStorageResult<()> {
pub fn check_key_alg_compatibility(key_type: StrongholdKeyType, alg: &JwsAlgorithm) -> KeyStorageResult<()> {
match (key_type, alg) {
(StrongholdKeyType::Ed25519, JwsAlgorithm::EdDSA) => Ok(()),
(key_type, alg) => Err(
Expand Down

0 comments on commit 26afa2c

Please sign in to comment.