Skip to content

Commit

Permalink
rewrite decode_varint and greatly expand testcases
Browse files Browse the repository at this point in the history
  • Loading branch information
TheButlah committed May 19, 2024
1 parent 07afa0f commit 0c19d66
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 33 deletions.
4 changes: 4 additions & 0 deletions crates/did-simple/src/key_algos.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use crate::varint::encode_varint;

/// A key algorithm.
pub trait KeyAlgo {
fn pub_key_size(&self) -> usize;
Expand All @@ -8,6 +10,8 @@ pub trait KeyAlgo {
pub trait StaticKeyAlgo: KeyAlgo {
/// Public key size for this algorithm.
/// NOTE(review): presumably measured in bytes and the value reported by
/// `KeyAlgo::pub_key_size` — confirm against the blanket `KeyAlgo` impl.
const PUB_KEY_SIZE: usize;
/// The multicodec code identifying this algorithm, as a plain (decoded) integer.
const MULTICODEC_VALUE: u16;
/// [`Self::MULTICODEC_VALUE`] encoded as a varint by `encode_varint`.
const MULTICODEC_VALUE_ENCODED: &'static [u8] =
encode_varint(Self::MULTICODEC_VALUE).as_slice();
}

impl<T: StaticKeyAlgo> KeyAlgo for T {
Expand Down
28 changes: 21 additions & 7 deletions crates/did-simple/src/methods/key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
use std::fmt::Display;

use crate::{
key_algos::{DynKeyAlgo, Ed25519, StaticKeyAlgo},
key_algos::{DynKeyAlgo, Ed25519, KeyAlgo, StaticKeyAlgo},
uri::{DidMethod, DidUri},
utf8bytes::Utf8Bytes,
varint::decode_varint,
Expand All @@ -20,6 +20,8 @@ pub struct DidKey<A = DynKeyAlgo> {
/// The decoded multibase portion of the DID.
mb_value: Vec<u8>,
key_algo: A,
/// The index into [`Self::mb_value`] that is the public key.
pubkey_bytes: std::ops::RangeFrom<usize>,
}

pub const PREFIX: &str = "did:key:";
Expand Down Expand Up @@ -93,20 +95,30 @@ impl TryFrom<DidUri> for DidKey {
);

let s = value.as_utf8_bytes().clone();
let mut decoded = Vec::new();
decode_multibase(&s, &mut decoded)?;
let mut decoded_multibase = Vec::new();
decode_multibase(&s, &mut decoded_multibase)?;

// TODO: Instead of comparing decoded versions which requires running the decode
// function at runtime, compare the encoded versions. We can do the encode at
// compile time.
let multicodec_key_type = decode_varint(&decoded[0..2])?;
let key_algo = match multicodec_key_type {
let (multicodec_key_algo, pubkey_bytes) = decode_varint(&decoded_multibase)?;
let key_algo = match multicodec_key_algo {
Ed25519::MULTICODEC_VALUE => DynKeyAlgo::Ed25519,
_ => return Err(FromUriError::UnknownKeyAlgo(multicodec_key_type)),
_ => return Err(FromUriError::UnknownKeyAlgo(multicodec_key_algo)),
};

let pubkey_len = pubkey_bytes.len();
if pubkey_len != key_algo.pub_key_size() {
return Err(FromUriError::MismatchedPubkeyLen(key_algo, pubkey_len));
}

let pubkey_bytes = decoded_multibase.len() - pubkey_len..;

Ok(Self {
s,
mb_value: decoded,
mb_value: decoded_multibase,
key_algo,
pubkey_bytes,
})
}
}
Expand All @@ -121,6 +133,8 @@ pub enum FromUriError {
UnknownKeyAlgo(u16),
#[error(transparent)]
Varint(#[from] crate::varint::DecodeError),
#[error("{0:?} requires pubkeys of length {} but got {1} bytes", .0.pub_key_size())]
MismatchedPubkeyLen(DynKeyAlgo, usize),
}

impl Display for DidKey {
Expand Down
104 changes: 78 additions & 26 deletions crates/did-simple/src/varint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,30 @@ const fn msb_is_1(val: u8) -> bool {

#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
pub(crate) struct VarintEncoding {
buf: [u8; 3],
buf: [u8; Self::MAX_LEN],
len: u8,
}

impl VarintEncoding {
#[allow(dead_code)]
pub const MAX_LEN: usize = 3;

pub const fn as_slice(&self) -> &[u8] {
self.buf.split_at(self.len as usize).0
}
}

/// Encodes a value as a varint.
/// Returns a [`VarintEncoding`]: a fixed-size buffer together with the number of
/// leading bytes in it that are valid. Use [`VarintEncoding::as_slice`] to get just those bytes.
#[allow(dead_code)]
pub(crate) const fn encode_varint(value: u16) -> VarintEncoding {
let mut out_buf = [0; 3];
let mut out_buf = [0; VarintEncoding::MAX_LEN];
// ilog2 can be used to get the position (0-indexed from the LSB) of the MSB.
// We then add one to it, since we are trying to indicate length, not position.
let in_bit_length: u16 = if let Some(x) = value.checked_ilog2() {
(x + 1) as u16
} else {
return VarintEncoding {
buf: out_buf,
len: 0,
len: 1,
};
};

Expand Down Expand Up @@ -76,33 +76,34 @@ pub(crate) const fn encode_varint(value: u16) -> VarintEncoding {
}
}

pub(crate) const fn decode_varint(encoded: &[u8]) -> Result<u16, DecodeError> {
// TODO: Technically, some three byte encodings could fit into a u16, we
// should support those in the future.
// Luckily none of them are used for did:key afaik.
if encoded.len() > 2 {
return Err(DecodeError::WouldOverflow);
}
/// Returns tuple of decoded varint and rest of buffer
pub(crate) const fn decode_varint(encoded: &[u8]) -> Result<(u16, &[u8]), DecodeError> {
if encoded.is_empty() {
return Err(DecodeError::MissingBytes);
}

let a = encoded[0];
let mut result: u16 = (a & LSB_7) as u16;
if msb_is_1(a) {
// There is another 7 bits to decode.
if encoded.len() < 2 {
let mut decoded: u16 = 0;
let mut current_encoded_idx = 0;
let mut current_decoded_bit = 0;
while current_decoded_bit < u16::BITS {
if current_encoded_idx >= encoded.len() {
return Err(DecodeError::MissingBytes);
}
let b = encoded[1];

result |= ((b & LSB_7) as u16) << 7;
if msb_is_1(b) {
// We were provided a varint that ought to have had at least another byte.
return Err(DecodeError::MissingBytes);
let current_byte = encoded[current_encoded_idx];
let Some(shifted) =
((current_byte & LSB_7) as u16).checked_shl(current_decoded_bit)
else {
return Err(DecodeError::WouldOverflow);
};
decoded |= shifted;
current_decoded_bit += 7;
current_encoded_idx += 1;
if !msb_is_1(current_byte) {
break;
}
}
Ok(result)
let bytes_in_varint = current_encoded_idx;

Ok((decoded, encoded.split_at(bytes_in_varint).1))
}

#[derive(thiserror::Error, Debug, Eq, PartialEq)]
Expand All @@ -119,6 +120,20 @@ pub enum DecodeError {
mod test {
use super::*;

/// Asserts that `decoded` survives an encode -> decode round trip, both when the
/// encoding is the entire input and when unrelated bytes follow it.
fn test_roundtrip(decoded: u16) {
    let encoding = encode_varint(decoded);
    let encoded_bytes = encoding.as_slice();
    println!("encoded {:?}", encoded_bytes);

    // Exactly the encoding: nothing may be left over after decoding.
    let no_remainder: &[u8] = &[];
    assert_eq!(Ok((decoded, no_remainder)), decode_varint(encoded_bytes));

    // Encoding followed by a trailer: the trailer must be handed back untouched.
    let trailer: &[u8] = &[1, 2, 3, 4];
    let padded = [encoded_bytes, trailer].concat();
    assert_eq!(Ok((decoded, trailer)), decode_varint(&padded));
}

#[test]
fn test_known_examples() {
// See https://github.com/multiformats/unsigned-varint/blob/16bf9f7d3ff78c10c1ab26d397c03c91205cd4ee/README.md
Expand All @@ -138,9 +153,46 @@ mod test {
(0xe7, &[0xe7, 0x01]), // Secp256k1
];

let empty: &[u8] = &[];
let extra_bytes = [1, 2, 3, 4];
for (decoded, encoded) in examples1.into_iter().chain(examples2) {
assert_eq!(Ok(decoded), decode_varint(encoded));
assert_eq!(encoded, encode_varint(decoded).as_slice());
assert_eq!(Ok((decoded, empty)), decode_varint(encoded));
// Do another check with some additional extra bytes in the encoding
let mut extended_encoded = encoded.to_vec();
extended_encoded.extend_from_slice(&extra_bytes);
assert_eq!(
Ok((decoded, extra_bytes.as_slice())),
decode_varint(&extended_encoded)
);
test_roundtrip(decoded);
}
}

/// Round-trips every representable `u16` through `encode_varint` / `decode_varint`.
#[test]
fn test_all_u16_roundtrip() {
    // The range must be inclusive: a half-open `0..u16::MAX` silently skips
    // `u16::MAX` itself, the largest value and the one most likely to expose
    // an overflow in the encoder or decoder.
    for value in 0..=u16::MAX {
        test_roundtrip(value);
    }
}

/// Exercises `decode_varint` on empty input, one-byte varints with and without
/// trailing bytes, and inputs whose last byte still has its continuation bit set.
#[test]
fn test_decode_boundary_conditions() {
    type Decoded<'a> = Result<(u16, &'a [u8]), DecodeError>;

    const NO_BYTES: &[u8] = &[];
    let low_bits = LSB_7 as u16;

    // Each case pairs an encoded input with the exact result expected from decoding it.
    let cases: [(&[u8], Decoded<'_>); 10] = [
        (NO_BYTES, Err(DecodeError::MissingBytes)),
        (&[0], Ok((0, NO_BYTES))),
        (&[0, 0], Ok((0, &[0]))),
        (&[0, 1], Ok((0, &[1]))),
        (&[0, u8::MAX], Ok((0, &[u8::MAX]))),
        // Continuation bit set on the final byte: the varint is unterminated.
        (&[u8::MAX], Err(DecodeError::MissingBytes)),
        (&[1 << 7], Err(DecodeError::MissingBytes)),
        (&[LSB_7], Ok((low_bits, NO_BYTES))),
        (&[LSB_7, 0], Ok((low_bits, &[0]))),
        (&[LSB_7, u8::MAX], Ok((low_bits, &[u8::MAX]))),
    ];

    for (encoded, result) in cases {
        assert_eq!(decode_varint(encoded), result, "decoded from {encoded:?}");
    }
}
}

0 comments on commit 0c19d66

Please sign in to comment.