Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

did-simple: improve varint decode, expand testcases #97

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions crates/did-simple/src/key_algos.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use crate::varint::encode_varint;

/// A key algorithm.
pub trait KeyAlgo {
fn pub_key_size(&self) -> usize;
Expand All @@ -8,6 +10,8 @@ pub trait KeyAlgo {
pub trait StaticKeyAlgo: KeyAlgo {
const PUB_KEY_SIZE: usize;
const MULTICODEC_VALUE: u16;
const MULTICODEC_VALUE_ENCODED: &'static [u8] =
encode_varint(Self::MULTICODEC_VALUE).as_slice();
}

impl<T: StaticKeyAlgo> KeyAlgo for T {
Expand Down
28 changes: 21 additions & 7 deletions crates/did-simple/src/methods/key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
use std::fmt::Display;

use crate::{
key_algos::{DynKeyAlgo, Ed25519, StaticKeyAlgo},
key_algos::{DynKeyAlgo, Ed25519, KeyAlgo, StaticKeyAlgo},
uri::{DidMethod, DidUri},
utf8bytes::Utf8Bytes,
varint::decode_varint,
Expand All @@ -20,6 +20,8 @@ pub struct DidKey<A = DynKeyAlgo> {
/// The decoded multibase portion of the DID.
mb_value: Vec<u8>,
key_algo: A,
/// The index into [`Self::mb_value`] that is the public key.
pubkey_bytes: std::ops::RangeFrom<usize>,
}

pub const PREFIX: &str = "did:key:";
Expand Down Expand Up @@ -93,20 +95,30 @@ impl TryFrom<DidUri> for DidKey {
);

let s = value.as_utf8_bytes().clone();
let mut decoded = Vec::new();
decode_multibase(&s, &mut decoded)?;
let mut decoded_multibase = Vec::new();
decode_multibase(&s, &mut decoded_multibase)?;

// TODO: Instead of comparing decoded versions which requires running the decode
// function at runtime, compare the encoded versions. We can do the encode at
// compile time.
let multicodec_key_type = decode_varint(&decoded[0..2])?;
let key_algo = match multicodec_key_type {
let (multicodec_key_algo, pubkey_bytes) = decode_varint(&decoded_multibase)?;
let key_algo = match multicodec_key_algo {
Ed25519::MULTICODEC_VALUE => DynKeyAlgo::Ed25519,
_ => return Err(FromUriError::UnknownKeyAlgo(multicodec_key_type)),
_ => return Err(FromUriError::UnknownKeyAlgo(multicodec_key_algo)),
};

let pubkey_len = pubkey_bytes.len();
if pubkey_len != key_algo.pub_key_size() {
return Err(FromUriError::MismatchedPubkeyLen(key_algo, pubkey_len));
}

let pubkey_bytes = decoded_multibase.len() - pubkey_len..;

Ok(Self {
s,
mb_value: decoded,
mb_value: decoded_multibase,
key_algo,
pubkey_bytes,
})
}
}
Expand All @@ -121,6 +133,8 @@ pub enum FromUriError {
UnknownKeyAlgo(u16),
#[error(transparent)]
Varint(#[from] crate::varint::DecodeError),
#[error("{0:?} requires pubkeys of length {} but got {1} bytes", .0.pub_key_size())]
MismatchedPubkeyLen(DynKeyAlgo, usize),
}

impl Display for DidKey {
Expand Down
104 changes: 78 additions & 26 deletions crates/did-simple/src/varint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,30 @@ const fn msb_is_1(val: u8) -> bool {

#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
pub(crate) struct VarintEncoding {
buf: [u8; 3],
buf: [u8; Self::MAX_LEN],
len: u8,
}

impl VarintEncoding {
#[allow(dead_code)]
pub const MAX_LEN: usize = 3;

pub const fn as_slice(&self) -> &[u8] {
self.buf.split_at(self.len as usize).0
}
}

/// Encodes a value as a varint.
/// Returns an array as well as the length of the array to slice., along well as an array.
#[allow(dead_code)]
pub(crate) const fn encode_varint(value: u16) -> VarintEncoding {
let mut out_buf = [0; 3];
let mut out_buf = [0; VarintEncoding::MAX_LEN];
// ilog2 can be used to get the (0-indexed from LSB position) of the MSB.
// We then add one to it, since we are trying to indicate length, not position.
let in_bit_length: u16 = if let Some(x) = value.checked_ilog2() {
(x + 1) as u16
} else {
return VarintEncoding {
buf: out_buf,
len: 0,
len: 1,
};
};

Expand Down Expand Up @@ -76,33 +76,34 @@ pub(crate) const fn encode_varint(value: u16) -> VarintEncoding {
}
}

pub(crate) const fn decode_varint(encoded: &[u8]) -> Result<u16, DecodeError> {
// TODO: Technically, some three byte encodings could fit into a u16, we
// should support those in the future.
// Luckily none of them are used for did:key afaik.
if encoded.len() > 2 {
return Err(DecodeError::WouldOverflow);
}
/// Returns tuple of decoded varint and rest of buffer
pub(crate) const fn decode_varint(encoded: &[u8]) -> Result<(u16, &[u8]), DecodeError> {
if encoded.is_empty() {
return Err(DecodeError::MissingBytes);
}

let a = encoded[0];
let mut result: u16 = (a & LSB_7) as u16;
if msb_is_1(a) {
// There is another 7 bits to decode.
if encoded.len() < 2 {
let mut decoded: u16 = 0;
let mut current_encoded_idx = 0;
let mut current_decoded_bit = 0;
while current_decoded_bit < u16::BITS {
if current_encoded_idx >= encoded.len() {
return Err(DecodeError::MissingBytes);
}
let b = encoded[1];

result |= ((b & LSB_7) as u16) << 7;
if msb_is_1(b) {
// We were provided a varint that ought to have had at least another byte.
return Err(DecodeError::MissingBytes);
let current_byte = encoded[current_encoded_idx];
let Some(shifted) =
((current_byte & LSB_7) as u16).checked_shl(current_decoded_bit)
else {
return Err(DecodeError::WouldOverflow);
};
decoded |= shifted;
current_decoded_bit += 7;
current_encoded_idx += 1;
if !msb_is_1(current_byte) {
break;
}
}
Ok(result)
let bytes_in_varint = current_encoded_idx;

Ok((decoded, encoded.split_at(bytes_in_varint).1))
}

#[derive(thiserror::Error, Debug, Eq, PartialEq)]
Expand All @@ -119,6 +120,20 @@ pub enum DecodeError {
mod test {
use super::*;

fn test_roundtrip(decoded: u16) {
let extra_bytes = [1, 2, 3, 4].as_slice();
let empty = [].as_slice();
let encoded = encode_varint(decoded);
println!("encoded {:?}", encoded.as_slice());

let round_tripped = decode_varint(encoded.as_slice());
assert_eq!(Ok((decoded, empty)), round_tripped);
let mut encoded_with_extra = encoded.as_slice().to_vec();
encoded_with_extra.extend_from_slice(extra_bytes);
let round_tripped_with_extra = decode_varint(&encoded_with_extra);
assert_eq!(Ok((decoded, extra_bytes)), round_tripped_with_extra)
}

#[test]
fn test_known_examples() {
// See https://github.com/multiformats/unsigned-varint/blob/16bf9f7d3ff78c10c1ab26d397c03c91205cd4ee/README.md
Expand All @@ -138,9 +153,46 @@ mod test {
(0xe7, &[0xe7, 0x01]), // Secp256k1
];

let empty: &[u8] = &[];
let extra_bytes = [1, 2, 3, 4];
for (decoded, encoded) in examples1.into_iter().chain(examples2) {
assert_eq!(Ok(decoded), decode_varint(encoded));
assert_eq!(encoded, encode_varint(decoded).as_slice());
assert_eq!(Ok((decoded, empty)), decode_varint(encoded));
// Do another check with some additional extra bytes in the encoding
let mut extended_encoded = encoded.to_vec();
extended_encoded.extend_from_slice(&extra_bytes);
assert_eq!(
Ok((decoded, extra_bytes.as_slice())),
decode_varint(&extended_encoded)
);
test_roundtrip(decoded);
}
}

#[test]
fn test_all_u16_roundtrip() {
for i in 0..u16::MAX {
test_roundtrip(i);
}
}

#[test]
fn test_decode_boundary_conditions() {
let examples = [
([].as_slice(), Err(DecodeError::MissingBytes)),
(&[0], Ok((0, [].as_slice()))),
(&[0, 0], Ok((0, &[0]))),
(&[0, 1], Ok((0, &[1]))),
(&[0, u8::MAX], Ok((0, &[u8::MAX]))),
(&[u8::MAX], Err(DecodeError::MissingBytes)),
(&[1 << 7], Err(DecodeError::MissingBytes)),
(&[LSB_7], Ok((LSB_7 as u16, &[]))),
(&[LSB_7, 0], Ok((LSB_7 as u16, &[0]))),
(&[LSB_7, u8::MAX], Ok((LSB_7 as u16, &[u8::MAX]))),
];

for (encoded, result) in examples {
assert_eq!(decode_varint(encoded), result, "decoded from {encoded:?}");
}
}
}
Loading