Skip to content

Commit

Permalink
More progress on eval
Browse files Browse the repository at this point in the history
  • Loading branch information
gabriel-barrett committed Aug 16, 2023
1 parent d574786 commit 2e71c28
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 25 deletions.
37 changes: 36 additions & 1 deletion src/lem/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ use bellperson::{
use crate::circuit::gadgets::{
constraints::{
add, alloc_equal, alloc_equal_const, alloc_is_zero, allocate_is_negative, and,
boolean_to_num, div, enforce_selector_with_premise, implies_equal, mul, pick, sub,
boolean_to_num, div, enforce_pack, enforce_selector_with_premise, implies_equal, mul, pick,
sub,
},
data::{allocate_constant, hash_poseidon},
pointer::AllocatedPtr,
Expand Down Expand Up @@ -698,6 +699,35 @@ impl Func {
let c = AllocatedPtr::from_parts(tag, lt.clone());
bound_allocations.insert(tgt.clone(), c);
}
Op::BitAnd(tgt, a, b) => {
let a = bound_allocations.get(a)?;
let a_bits = a.hash().to_bits_le(&mut cs.namespace(|| "bitwise_and"))?;
let mut trunc_bits = Vec::with_capacity(64);
let mut b_rest = *b;
for a_bit in a_bits {
let b_bit = b_rest & 1;
if b_bit == 1 {
trunc_bits.push(a_bit.clone());
} else {
trunc_bits.push(Boolean::Constant(false))
}
b_rest /= 2;
}
let trunc = AllocatedNum::alloc(cs.namespace(|| ""), || {
let val = a
.hash()
.get_value()
.map(|a| F::from_u64(a.to_u64_unchecked() & b))
.unwrap();
Ok(val)
})?;
enforce_pack(&mut cs.namespace(|| "enforce_trunc"), &trunc_bits, &trunc)?;
let tag = g
.global_allocator
.get_or_alloc_const(cs, Tag::Expr(Num).to_field())?;
let c = AllocatedPtr::from_parts(tag, trunc);
bound_allocations.insert(tgt.clone(), c);
}
Op::Emit(_) => (),
Op::Hide(tgt, sec, pay) => {
let sec = bound_allocations.get(sec)?;
Expand Down Expand Up @@ -1062,6 +1092,11 @@ impl Func {
globals.insert(FWrap(Tag::Expr(Num).to_field()));
num_constraints += 2;
}
Op::BitAnd(_, _, _) => {
globals.insert(FWrap(Tag::Expr(Num).to_field()));
// bit decomposition + enforce_pack
num_constraints += 257;
}
Op::Emit(_) => (),
Op::Hash2(_, tag, _) => {
// tag for the image
Expand Down
42 changes: 29 additions & 13 deletions src/lem/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -632,11 +632,9 @@ fn apply_cont() -> Func {
Symbol("u64") => {
match result.tag {
Expr::Num => {
// TODO we also need to use `Mod` to truncate
// But 2^64 is out-of-range of u64, so we will
// maybe use u128
// let limit = Num(18446744073709551616);
let cast = cast(result, Expr::U64);
// The limit is 2**64 - 1
let trunc = bitwise_and(result, 18446744073709551615);
let cast = cast(trunc, Expr::U64);
return(cast, env, continuation, makethunk)
}
Expr::U64 => {
Expand All @@ -656,12 +654,15 @@ fn apply_cont() -> Func {
}
Symbol("char") => {
match result.tag {
Expr::Num | Expr::Char => {
// TODO we also need to use `Mod` to truncate
// let limit = Num(4294967296);
let cast = cast(result, Expr::Num);
Expr::Num => {
// The limit is 2**32 - 1
let trunc = bitwise_and(result, 4294967295);
let cast = cast(trunc, Expr::Char);
return(cast, env, continuation, makethunk)
}
Expr::Char => {
return(result, env, continuation, makethunk)
}
};
return(result, env, err, errctrl)
}
Expand Down Expand Up @@ -802,8 +803,11 @@ fn apply_cont() -> Func {
return (val, env, continuation, makethunk)
}
Num(2) => {
// TODO
return (result, env, err, errctrl)
let val = mul(evaled_arg, result);
// The limit is 2**64 - 1
let trunc = bitwise_and(val, 18446744073709551615);
let cast = cast(trunc, Expr::U64);
return (cast, env, continuation, makethunk)
}
}
}
Expand Down Expand Up @@ -955,8 +959,8 @@ mod tests {
use blstrs::Scalar as Fr;

const NUM_INPUTS: usize = 1;
const NUM_AUX: usize = 8885;
const NUM_CONSTRAINTS: usize = 11139;
const NUM_AUX: usize = 9655;
const NUM_CONSTRAINTS: usize = 11912;
const NUM_SLOTS: SlotsCounter = SlotsCounter {
hash2: 16,
hash3: 4,
Expand Down Expand Up @@ -1009,6 +1013,14 @@ mod tests {
}

fn expr_in_expr_out_pairs(s: &mut Store<Fr>) -> Vec<(Ptr<Fr>, Ptr<Fr>)> {
let u64_1 = s.read("(u64 100000000)").unwrap();
let u64_1_res = s.read("100000000u64").unwrap();
let u64_2 = s.read("(u64 1000000000000000000000000)").unwrap();
let u64_2_res = s.read("2003764205206896640u64").unwrap();
let mul_overflow = s.read("(* 1000000000000u64 100000000000000u64)").unwrap();
let mul_overflow_res = s.read("15908979783594147840u64").unwrap();
let char_conv = s.read("(char 97)").unwrap();
let char_conv_res = s.read("'a'").unwrap();
let t = s.read("t").unwrap();
let nil = s.read("nil").unwrap();
let le1 = s.read("(<= 4 8)").unwrap();
Expand Down Expand Up @@ -1048,6 +1060,10 @@ mod tests {
.unwrap();
let fold_res = s.read("55").unwrap();
vec![
(u64_1, u64_1_res),
(u64_2, u64_2_res),
(mul_overflow, mul_overflow_res),
(char_conv, char_conv_res),
(le1, t),
(le2, t),
(le3, nil),
Expand Down
27 changes: 16 additions & 11 deletions src/lem/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,7 @@ impl Block {
let a = bindings.get(a)?;
let b = bindings.get(b)?;
let c = match (a, b) {
(Ptr::Leaf(Tag::Expr(Num), f), Ptr::Leaf(Tag::Expr(Num), g)) => {
Ptr::Leaf(Tag::Expr(Num), *f + *g)
}
(Ptr::Leaf(_, f), Ptr::Leaf(_, g)) => Ptr::Leaf(Tag::Expr(Num), *f + *g),
_ => bail!("Addition only works on numbers"),
};
bindings.insert(tgt.clone(), c);
Expand All @@ -150,9 +148,7 @@ impl Block {
let a = bindings.get(a)?;
let b = bindings.get(b)?;
let c = match (a, b) {
(Ptr::Leaf(Tag::Expr(Num), f), Ptr::Leaf(Tag::Expr(Num), g)) => {
Ptr::Leaf(Tag::Expr(Num), *f - *g)
}
(Ptr::Leaf(_, f), Ptr::Leaf(_, g)) => Ptr::Leaf(Tag::Expr(Num), *f - *g),
_ => bail!("Addition only works on numbers"),
};
bindings.insert(tgt.clone(), c);
Expand All @@ -161,9 +157,7 @@ impl Block {
let a = bindings.get(a)?;
let b = bindings.get(b)?;
let c = match (a, b) {
(Ptr::Leaf(Tag::Expr(Num), f), Ptr::Leaf(Tag::Expr(Num), g)) => {
Ptr::Leaf(Tag::Expr(Num), *f * *g)
}
(Ptr::Leaf(_, f), Ptr::Leaf(_, g)) => Ptr::Leaf(Tag::Expr(Num), *f * *g),
_ => bail!("Addition only works on numbers"),
};
bindings.insert(tgt.clone(), c);
Expand All @@ -172,7 +166,7 @@ impl Block {
let a = bindings.get(a)?;
let b = bindings.get(b)?;
let c = match (a, b) {
(Ptr::Leaf(Tag::Expr(Num), f), Ptr::Leaf(Tag::Expr(Num), g)) => {
(Ptr::Leaf(_, f), Ptr::Leaf(_, g)) => {
Ptr::Leaf(Tag::Expr(Num), *f * g.invert().unwrap())
}
_ => bail!("Division only works on numbers"),
Expand All @@ -183,7 +177,7 @@ impl Block {
let a = bindings.get(a)?;
let b = bindings.get(b)?;
let c = match (a, b) {
(Ptr::Leaf(Tag::Expr(Num), f), Ptr::Leaf(Tag::Expr(Num), g)) => {
(Ptr::Leaf(_, f), Ptr::Leaf(_, g)) => {
preimages
.is_diff_neg
.push(Some(PreimageData::FPair(*f, *g)));
Expand All @@ -196,6 +190,17 @@ impl Block {
};
bindings.insert(tgt.clone(), c);
}
Op::BitAnd(tgt, a, b) => {
let a = bindings.get(a)?;

let c = match a {
Ptr::Leaf(_, f) => {
Ptr::Leaf(Tag::Expr(Num), F::from_u64(f.to_u64_unchecked() & b))
}
_ => bail!("`&` only works on numbers"),
};
bindings.insert(tgt.clone(), c);
}
Op::Emit(a) => {
let a = bindings.get(a)?;
println!("{}", a.dgb_display(store))
Expand Down
17 changes: 17 additions & 0 deletions src/lem/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,13 @@ macro_rules! op {
$crate::var!($b),
)
};
( let $tgt:ident = bitwise_and($a:ident, $b:literal) ) => {
$crate::lem::Op::BitAnd(
$crate::var!($tgt),
$crate::var!($a),
$b,
)
};
( emit($v:ident) ) => {
$crate::lem::Op::Emit($crate::var!($v))
};
Expand Down Expand Up @@ -334,6 +341,16 @@ macro_rules! block {
$($tail)*
)
};
(@seq {$($limbs:expr)*}, let $tgt:ident = bitwise_and($a:ident, $b:literal) ; $($tail:tt)*) => {
$crate::block! (
@seq
{
$($limbs)*
$crate::op!(let $tgt = bitwise_and($a, $b))
},
$($tail)*
)
};
(@seq {$($limbs:expr)*}, emit($v:ident) ; $($tail:tt)*) => {
$crate::block! (
@seq
Expand Down
11 changes: 11 additions & 0 deletions src/lem/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@ pub enum Op {
Div(Var, Var, Var),
/// `Lt(y, a, b)` binds `y` to `1` if `a < b`, or to `0` otherwise
Lt(Var, Var, Var),
/// `BitAnd(y, a, b)` binds `y` to `a & b`
BitAnd(Var, Var, u64),
/// `Emit(v)` simply prints out the value of `v` when interpreting the code
Emit(Var),
/// `Hash2(x, t, ys)` binds `x` to a `Ptr` with tag `t` and 2 children `ys`
Expand Down Expand Up @@ -357,6 +359,10 @@ impl Func {
is_bound(b, map)?;
is_unique(tgt, map);
}
Op::BitAnd(tgt, a, _) => {
is_bound(a, map)?;
is_unique(tgt, map);
}
Op::Emit(a) => {
is_bound(a, map)?;
}
Expand Down Expand Up @@ -602,6 +608,11 @@ impl Block {
let tgt = insert_one(map, uniq, &tgt);
ops.push(Op::Lt(tgt, a, b))
}
Op::BitAnd(tgt, a, b) => {
let a = map.get_cloned(&a)?;
let tgt = insert_one(map, uniq, &tgt);
ops.push(Op::BitAnd(tgt, a, b))
}
Op::Emit(a) => {
let a = map.get_cloned(&a)?;
ops.push(Op::Emit(a))
Expand Down

0 comments on commit 2e71c28

Please sign in to comment.