diff --git a/src/lem/circuit.rs b/src/lem/circuit.rs index a6a2beff95..1d9dbb2e7b 100644 --- a/src/lem/circuit.rs +++ b/src/lem/circuit.rs @@ -35,7 +35,8 @@ use bellperson::{ use crate::circuit::gadgets::{ constraints::{ add, alloc_equal, alloc_equal_const, alloc_is_zero, allocate_is_negative, and, - boolean_to_num, div, enforce_selector_with_premise, implies_equal, mul, pick, sub, + boolean_to_num, div, enforce_pack, enforce_selector_with_premise, implies_equal, mul, pick, + sub, }, data::{allocate_constant, hash_poseidon}, pointer::AllocatedPtr, @@ -698,6 +699,35 @@ impl Func { let c = AllocatedPtr::from_parts(tag, lt.clone()); bound_allocations.insert(tgt.clone(), c); } + Op::BitAnd(tgt, a, b) => { + let a = bound_allocations.get(a)?; + let a_bits = a.hash().to_bits_le(&mut cs.namespace(|| "bitwise_and"))?; + let mut trunc_bits = Vec::with_capacity(64); + let mut b_rest = *b; + for a_bit in a_bits { + let b_bit = b_rest & 1; + if b_bit == 1 { + trunc_bits.push(a_bit.clone()); + } else { + trunc_bits.push(Boolean::Constant(false)) + } + b_rest /= 2; + } + let trunc = AllocatedNum::alloc(cs.namespace(|| ""), || { + let val = a + .hash() + .get_value() + .map(|a| F::from_u64(a.to_u64_unchecked() & b)) + .unwrap(); + Ok(val) + })?; + enforce_pack(&mut cs.namespace(|| "enforce_trunc"), &trunc_bits, &trunc)?; + let tag = g + .global_allocator + .get_or_alloc_const(cs, Tag::Expr(Num).to_field())?; + let c = AllocatedPtr::from_parts(tag, trunc); + bound_allocations.insert(tgt.clone(), c); + } Op::Emit(_) => (), Op::Hide(tgt, sec, pay) => { let sec = bound_allocations.get(sec)?; @@ -1062,6 +1092,11 @@ impl Func { globals.insert(FWrap(Tag::Expr(Num).to_field())); num_constraints += 2; } + Op::BitAnd(_, _, _) => { + globals.insert(FWrap(Tag::Expr(Num).to_field())); + // bit decomposition + enforce_pack + num_constraints += 257; + } Op::Emit(_) => (), Op::Hash2(_, tag, _) => { // tag for the image diff --git a/src/lem/eval.rs b/src/lem/eval.rs index 6e0eb8632c..a3c828d841 100644 --- a/src/lem/eval.rs +++ b/src/lem/eval.rs @@ -632,11 +632,9 @@ fn apply_cont() -> Func { Symbol("u64") => { match result.tag { Expr::Num => { - // TODO we also need to use `Mod` to truncate - // But 2^64 is out-of-range of u64, so we will - // maybe use u128 - // let limit = Num(18446744073709551616); - let cast = cast(result, Expr::U64); + // The limit is 2**64 - 1 + let trunc = bitwise_and(result, 18446744073709551615); + let cast = cast(trunc, Expr::U64); return(cast, env, continuation, makethunk) } Expr::U64 => { @@ -656,12 +654,15 @@ fn apply_cont() -> Func { } Symbol("char") => { match result.tag { - Expr::Num | Expr::Char => { - // TODO we also need to use `Mod` to truncate - // let limit = Num(4294967296); - let cast = cast(result, Expr::Num); + Expr::Num => { + // The limit is 2**32 - 1 + let trunc = bitwise_and(result, 4294967295); + let cast = cast(trunc, Expr::Char); return(cast, env, continuation, makethunk) } + Expr::Char => { + return(result, env, continuation, makethunk) + } }; return(result, env, err, errctrl) } @@ -802,8 +803,11 @@ fn apply_cont() -> Func { return (val, env, continuation, makethunk) } Num(2) => { - // TODO - return (result, env, err, errctrl) + let val = mul(evaled_arg, result); + // The limit is 2**64 - 1 + let trunc = bitwise_and(val, 18446744073709551615); + let cast = cast(trunc, Expr::U64); + return (cast, env, continuation, makethunk) } } } @@ -955,8 +959,8 @@ mod tests { use blstrs::Scalar as Fr; const NUM_INPUTS: usize = 1; - const NUM_AUX: usize = 8885; - const NUM_CONSTRAINTS: usize = 11139; + const NUM_AUX: usize = 9655; + const NUM_CONSTRAINTS: usize = 11912; const NUM_SLOTS: SlotsCounter = SlotsCounter { hash2: 16, hash3: 4, @@ -1009,6 +1013,14 @@ mod tests { } fn expr_in_expr_out_pairs(s: &mut Store) -> Vec<(Ptr, Ptr)> { + let u64_1 = s.read("(u64 100000000)").unwrap(); + let u64_1_res = s.read("100000000u64").unwrap(); + let u64_2 = s.read("(u64 1000000000000000000000000)").unwrap(); + let u64_2_res = s.read("2003764205206896640u64").unwrap(); + let mul_overflow = s.read("(* 1000000000000u64 100000000000000u64)").unwrap(); + let mul_overflow_res = s.read("15908979783594147840u64").unwrap(); + let char_conv = s.read("(char 97)").unwrap(); + let char_conv_res = s.read("'a'").unwrap(); let t = s.read("t").unwrap(); let nil = s.read("nil").unwrap(); let le1 = s.read("(<= 4 8)").unwrap(); @@ -1048,6 +1060,10 @@ mod tests { .unwrap(); let fold_res = s.read("55").unwrap(); vec![ + (u64_1, u64_1_res), + (u64_2, u64_2_res), + (mul_overflow, mul_overflow_res), + (char_conv, char_conv_res), (le1, t), (le2, t), (le3, nil), diff --git a/src/lem/interpreter.rs b/src/lem/interpreter.rs index fb5c2f8020..01b988d022 100644 --- a/src/lem/interpreter.rs +++ b/src/lem/interpreter.rs @@ -139,9 +139,7 @@ impl Block { let a = bindings.get(a)?; let b = bindings.get(b)?; let c = match (a, b) { - (Ptr::Leaf(Tag::Expr(Num), f), Ptr::Leaf(Tag::Expr(Num), g)) => { - Ptr::Leaf(Tag::Expr(Num), *f + *g) - } + (Ptr::Leaf(_, f), Ptr::Leaf(_, g)) => Ptr::Leaf(Tag::Expr(Num), *f + *g), _ => bail!("Addition only works on numbers"), }; bindings.insert(tgt.clone(), c); @@ -150,9 +148,7 @@ impl Block { let a = bindings.get(a)?; let b = bindings.get(b)?; let c = match (a, b) { - (Ptr::Leaf(Tag::Expr(Num), f), Ptr::Leaf(Tag::Expr(Num), g)) => { - Ptr::Leaf(Tag::Expr(Num), *f - *g) - } + (Ptr::Leaf(_, f), Ptr::Leaf(_, g)) => Ptr::Leaf(Tag::Expr(Num), *f - *g), _ => bail!("Addition only works on numbers"), }; bindings.insert(tgt.clone(), c); @@ -161,9 +157,7 @@ impl Block { let a = bindings.get(a)?; let b = bindings.get(b)?; let c = match (a, b) { - (Ptr::Leaf(Tag::Expr(Num), f), Ptr::Leaf(Tag::Expr(Num), g)) => { - Ptr::Leaf(Tag::Expr(Num), *f * *g) - } + (Ptr::Leaf(_, f), Ptr::Leaf(_, g)) => Ptr::Leaf(Tag::Expr(Num), *f * *g), _ => bail!("Addition only works on numbers"), }; bindings.insert(tgt.clone(), c); @@ -172,7 +166,7 @@ impl Block { let a = bindings.get(a)?; let b = bindings.get(b)?; let c = match (a, b) { - (Ptr::Leaf(Tag::Expr(Num), f), Ptr::Leaf(Tag::Expr(Num), g)) => { + (Ptr::Leaf(_, f), Ptr::Leaf(_, g)) => { Ptr::Leaf(Tag::Expr(Num), *f * g.invert().unwrap()) } _ => bail!("Division only works on numbers"), @@ -183,7 +177,7 @@ impl Block { let a = bindings.get(a)?; let b = bindings.get(b)?; let c = match (a, b) { - (Ptr::Leaf(Tag::Expr(Num), f), Ptr::Leaf(Tag::Expr(Num), g)) => { + (Ptr::Leaf(_, f), Ptr::Leaf(_, g)) => { preimages .is_diff_neg .push(Some(PreimageData::FPair(*f, *g))); @@ -196,6 +190,17 @@ impl Block { }; bindings.insert(tgt.clone(), c); } + Op::BitAnd(tgt, a, b) => { + let a = bindings.get(a)?; + + let c = match a { + Ptr::Leaf(_, f) => { + Ptr::Leaf(Tag::Expr(Num), F::from_u64(f.to_u64_unchecked() & b)) + } + _ => bail!("`&` only works on numbers"), + }; + bindings.insert(tgt.clone(), c); + } Op::Emit(a) => { let a = bindings.get(a)?; println!("{}", a.dgb_display(store)) diff --git a/src/lem/macros.rs b/src/lem/macros.rs index 977fdfe600..65031ba7c1 100644 --- a/src/lem/macros.rs +++ b/src/lem/macros.rs @@ -107,6 +107,13 @@ macro_rules! op { $crate::var!($b), ) }; + ( let $tgt:ident = bitwise_and($a:ident, $b:literal) ) => { + $crate::lem::Op::BitAnd( + $crate::var!($tgt), + $crate::var!($a), + $b, + ) + }; ( emit($v:ident) ) => { $crate::lem::Op::Emit($crate::var!($v)) }; @@ -334,6 +341,16 @@ macro_rules! block { $($tail)* ) }; + (@seq {$($limbs:expr)*}, let $tgt:ident = bitwise_and($a:ident, $b:literal) ; $($tail:tt)*) => { + $crate::block! ( + @seq + { + $($limbs)* + $crate::op!(let $tgt = bitwise_and($a, $b)) + }, + $($tail)* + ) + }; (@seq {$($limbs:expr)*}, emit($v:ident) ; $($tail:tt)*) => { $crate::block! ( @seq diff --git a/src/lem/mod.rs b/src/lem/mod.rs index 2f6e75746f..7a199b4156 100644 --- a/src/lem/mod.rs +++ b/src/lem/mod.rs @@ -251,6 +251,8 @@ pub enum Op { Div(Var, Var, Var), /// `Lt(y, a, b)` binds `y` to `1` if `a < b`, or to `0` otherwise Lt(Var, Var, Var), + /// `BitAnd(y, a, b)` binds `y` to `a & b` + BitAnd(Var, Var, u64), /// `Emit(v)` simply prints out the value of `v` when interpreting the code Emit(Var), /// `Hash2(x, t, ys)` binds `x` to a `Ptr` with tag `t` and 2 children `ys` @@ -357,6 +359,10 @@ impl Func { is_bound(b, map)?; is_unique(tgt, map); } + Op::BitAnd(tgt, a, _) => { + is_bound(a, map)?; + is_unique(tgt, map); + } Op::Emit(a) => { is_bound(a, map)?; } @@ -602,6 +608,11 @@ impl Block { let tgt = insert_one(map, uniq, &tgt); ops.push(Op::Lt(tgt, a, b)) } + Op::BitAnd(tgt, a, b) => { + let a = map.get_cloned(&a)?; + let tgt = insert_one(map, uniq, &tgt); + ops.push(Op::BitAnd(tgt, a, b)) + } Op::Emit(a) => { let a = map.get_cloned(&a)?; ops.push(Op::Emit(a))