Skip to content

Commit

Permalink
Added equality operator
Browse files Browse the repository at this point in the history
  • Loading branch information
gabriel-barrett committed Aug 16, 2023
1 parent 6a01555 commit f1f9bda
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 15 deletions.
37 changes: 30 additions & 7 deletions src/lem/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -585,10 +585,35 @@ impl Func {
let allocated_ptr = AllocatedPtr::from_parts(tag, src.hash().clone());
bound_allocations.insert(tgt.clone(), allocated_ptr);
}
Op::EqTag(tgt, a, b) => {
let a = bound_allocations.get(a)?;
let b = bound_allocations.get(b)?;
let a_num = a.tag();
let b_num = b.tag();
let eq = alloc_equal(&mut cs.namespace(|| "equal_tag"), a_num, b_num)?;
let c_num = boolean_to_num(&mut cs.namespace(|| "equal_tag.to_num"), &eq)?;
let tag = g
.global_allocator
.get_or_alloc_const(cs, Tag::Expr(Num).to_field())?;
let c = AllocatedPtr::from_parts(tag, c_num);
bound_allocations.insert(tgt.clone(), c);
}
Op::EqVal(tgt, a, b) => {
let a = bound_allocations.get(a)?;
let b = bound_allocations.get(b)?;
let a_num = a.hash();
let b_num = b.hash();
let eq = alloc_equal(&mut cs.namespace(|| "equal_val"), a_num, b_num)?;
let c_num = boolean_to_num(&mut cs.namespace(|| "equal_val.to_num"), &eq)?;
let tag = g
.global_allocator
.get_or_alloc_const(cs, Tag::Expr(Num).to_field())?;
let c = AllocatedPtr::from_parts(tag, c_num);
bound_allocations.insert(tgt.clone(), c);
}
Op::Add(tgt, a, b) => {
let a = bound_allocations.get(a)?;
let b = bound_allocations.get(b)?;
// TODO check that the tags are correct
let a_num = a.hash();
let b_num = b.hash();
let c_num = add(&mut cs.namespace(|| "add"), a_num, b_num)?;
Expand All @@ -601,7 +626,6 @@ impl Func {
Op::Sub(tgt, a, b) => {
let a = bound_allocations.get(a)?;
let b = bound_allocations.get(b)?;
// TODO check that the tags are correct
let a_num = a.hash();
let b_num = b.hash();
let c_num = sub(&mut cs.namespace(|| "sub"), a_num, b_num)?;
Expand All @@ -614,7 +638,6 @@ impl Func {
Op::Mul(tgt, a, b) => {
let a = bound_allocations.get(a)?;
let b = bound_allocations.get(b)?;
// TODO check that the tags are correct
let a_num = a.hash();
let b_num = b.hash();
let c_num = mul(&mut cs.namespace(|| "mul"), a_num, b_num)?;
Expand All @@ -627,7 +650,6 @@ impl Func {
Op::Div(tgt, a, b) => {
let a = bound_allocations.get(a)?;
let b = bound_allocations.get(b)?;
// TODO check that the tags are correct
let a_num = a.hash();
let b_num = b.hash();

Expand All @@ -652,7 +674,6 @@ impl Func {
Op::Lt(tgt, a, b) => {
let a = bound_allocations.get(a)?;
let b = bound_allocations.get(b)?;
// TODO check that the tags are correct
let tag = g
.global_allocator
.get_or_alloc_const(cs, Tag::Expr(Num).to_field())?;
Expand Down Expand Up @@ -782,10 +803,8 @@ impl Func {
Ctrl::IfEq(x, y, eq_block, else_block) => {
let x = bound_allocations.get(x)?.hash();
let y = bound_allocations.get(y)?.hash();
// TODO should we check whether the tags are equal too?
let eq = alloc_equal(&mut cs.namespace(|| "if_eq.alloc_equal"), x, y)?;
let not_eq = eq.not();
// TODO is this the most efficient way of doing if statements?
let not_dummy_and_eq = and(&mut cs.namespace(|| "if_eq.and"), not_dummy, &eq)?;
let not_dummy_and_not_eq =
and(&mut cs.namespace(|| "if_eq.and.2"), not_dummy, &not_eq)?;
Expand Down Expand Up @@ -1027,6 +1046,10 @@ impl Func {
Op::Cast(_tgt, tag, _src) => {
globals.insert(FWrap(tag.to_field()));
}
Op::EqTag(_, _, _) | Op::EqVal(_, _, _) => {
globals.insert(FWrap(Tag::Expr(Num).to_field()));
num_constraints += 5;
}
Op::Add(_, _, _) | Op::Sub(_, _, _) | Op::Mul(_, _, _) => {
globals.insert(FWrap(Tag::Expr(Num).to_field()));
num_constraints += 1;
Expand Down
18 changes: 12 additions & 6 deletions src/lem/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -725,11 +725,17 @@ fn apply_cont() -> Func {
return(hidden, env, continuation, makethunk)
}
Symbol("eq") => {
// TODO should we check whether the tags are also equal?
if evaled_arg == result {
return (t, env, continuation, makethunk)
let eq_tag = eq_tag(evaled_arg, result);
let eq_val = eq_val(evaled_arg, result);
let eq = mul(eq_tag, eq_val);
match eq.val {
Num(0) => {
return (nil, env, continuation, makethunk)
}
Num(1) => {
return (t, env, continuation, makethunk)
}
}
return (nil, env, continuation, makethunk)
}
Symbol("+") => {
match args_num_type.val {
Expand Down Expand Up @@ -924,8 +930,8 @@ mod tests {
use blstrs::Scalar as Fr;

const NUM_INPUTS: usize = 1;
const NUM_AUX: usize = 8781;
const NUM_CONSTRAINTS: usize = 10875;
const NUM_AUX: usize = 8868;
const NUM_CONSTRAINTS: usize = 11096;
const NUM_SLOTS: SlotsCounter = SlotsCounter {
hash2: 16,
hash3: 4,
Expand Down
24 changes: 24 additions & 0 deletions src/lem/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,30 @@ impl Block {
let tgt_ptr = src_ptr.cast(*tag);
bindings.insert(tgt.clone(), tgt_ptr);
}
Op::EqTag(tgt, a, b) => {
let a = bindings.get(a)?;
let b = bindings.get(b)?;
let c = if a.tag() == b.tag() {
Ptr::Leaf(Tag::Expr(Num), F::ONE)
} else {
Ptr::Leaf(Tag::Expr(Num), F::ZERO)
};
bindings.insert(tgt.clone(), c);
}
Op::EqVal(tgt, a, b) => {
let a = bindings.get(a)?;
let b = bindings.get(b)?;
// In order to compare Ptrs, we *must* resolve the hashes. Otherwise, we risk failing to recognize equality of
// compound data with opaque data in either element's transitive closure.
let a_hash = store.hash_ptr(a)?.hash;
let b_hash = store.hash_ptr(b)?.hash;
let c = if a_hash == b_hash {
Ptr::Leaf(Tag::Expr(Num), F::ONE)
} else {
Ptr::Leaf(Tag::Expr(Num), F::ZERO)
};
bindings.insert(tgt.clone(), c);
}
Op::Add(tgt, a, b) => {
let a = bindings.get(a)?;
let b = bindings.get(b)?;
Expand Down
34 changes: 34 additions & 0 deletions src/lem/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,20 @@ macro_rules! op {
$crate::var!($src),
)
};
( let $tgt:ident = eq_tag($a:ident, $b:ident) ) => {
$crate::lem::Op::EqTag(
$crate::var!($tgt),
$crate::var!($a),
$crate::var!($b),
)
};
( let $tgt:ident = eq_val($a:ident, $b:ident) ) => {
$crate::lem::Op::EqVal(
$crate::var!($tgt),
$crate::var!($a),
$crate::var!($b),
)
};
( let $tgt:ident = add($a:ident, $b:ident) ) => {
$crate::lem::Op::Add(
$crate::var!($tgt),
Expand Down Expand Up @@ -250,6 +264,26 @@ macro_rules! block {
$($tail)*
)
};
(@seq {$($limbs:expr)*}, let $tgt:ident = eq_tag($a:ident, $b:ident) ; $($tail:tt)*) => {
$crate::block! (
@seq
{
$($limbs)*
$crate::op!(let $tgt = eq_tag($a, $b))
},
$($tail)*
)
};
(@seq {$($limbs:expr)*}, let $tgt:ident = eq_val($a:ident, $b:ident) ; $($tail:tt)*) => {
$crate::block! (
@seq
{
$($limbs)*
$crate::op!(let $tgt = eq_val($a, $b))
},
$($tail)*
)
};
(@seq {$($limbs:expr)*}, let $tgt:ident = add($a:ident, $b:ident) ; $($tail:tt)*) => {
$crate::block! (
@seq
Expand Down
22 changes: 20 additions & 2 deletions src/lem/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,18 @@ pub enum Op {
/// `Cast(y, t, x)` binds `y` to a pointer with tag `t` and the hash of `x`
Cast(Var, Tag, Var),
/// `Add(y, a, b)` binds `y` to the sum of `a` and `b`
EqTag(Var, Var, Var),
/// `EqVal(y, a, b)` binds `y` to `1` if `a.val != b.val`, or to `0` otherwise
EqVal(Var, Var, Var),
/// `Lt(y, a, b)` binds `y` to `1` if `a < b`, or to `0` otherwise
Add(Var, Var, Var),
/// `Sub(y, a, b)` binds `y` to the sum of `a` and `b`
Sub(Var, Var, Var),
/// `Mul(y, a, b)` binds `y` to the sum of `a` and `b`
Mul(Var, Var, Var),
/// `Div(y, a, b)` binds `y` to the sum of `a` and `b`
Div(Var, Var, Var),
/// `Lt(y, a, b)` binds `y` to `t` if `a < b`, or to `nil` otherwise
/// `Lt(y, a, b)` binds `y` to `1` if `a < b`, or to `0` otherwise
Lt(Var, Var, Var),
/// `Emit(v)` simply prints out the value of `v` when interpreting the code
Emit(Var),
Expand Down Expand Up @@ -342,7 +346,9 @@ impl Func {
is_bound(src, map)?;
is_unique(tgt, map);
}
Op::Add(tgt, a, b)
Op::EqTag(tgt, a, b)
| Op::EqVal(tgt, a, b)
| Op::Add(tgt, a, b)
| Op::Sub(tgt, a, b)
| Op::Mul(tgt, a, b)
| Op::Div(tgt, a, b)
Expand Down Expand Up @@ -554,6 +560,18 @@ impl Block {
let tgt = insert_one(map, uniq, &tgt);
ops.push(Op::Cast(tgt, tag, src))
}
Op::EqTag(tgt, a, b) => {
let a = map.get_cloned(&a)?;
let b = map.get_cloned(&b)?;
let tgt = insert_one(map, uniq, &tgt);
ops.push(Op::EqTag(tgt, a, b))
}
Op::EqVal(tgt, a, b) => {
let a = map.get_cloned(&a)?;
let b = map.get_cloned(&b)?;
let tgt = insert_one(map, uniq, &tgt);
ops.push(Op::EqVal(tgt, a, b))
}
Op::Add(tgt, a, b) => {
let a = map.get_cloned(&a)?;
let b = map.get_cloned(&b)?;
Expand Down

0 comments on commit f1f9bda

Please sign in to comment.