Skip to content

Commit

Permalink
use existing euclidean division impl
Browse files Browse the repository at this point in the history
  • Loading branch information
dark64 committed Nov 15, 2023
1 parent 940ccb8 commit 11bee6a
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 82 deletions.
118 changes: 38 additions & 80 deletions zokrates_codegen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1278,16 +1278,10 @@ impl<'ast, T: Field> Flattener<'ast, T> {
fn euclidean_division(
&mut self,
statements_flattened: &mut FlatStatements<'ast, T>,
target_bitwidth: UBitwidth,
left: UExpression<'ast, T>,
right: UExpression<'ast, T>,
target_bitwidth: usize,
left_flattened: FlatExpression<T>,
right_flattened: FlatExpression<T>,
) -> (FlatExpression<T>, FlatExpression<T>) {
let left_flattened = self
.flatten_uint_expression(statements_flattened, left)
.get_field_unchecked();
let right_flattened = self
.flatten_uint_expression(statements_flattened, right)
.get_field_unchecked();
let n = if left_flattened.is_linear() {
left_flattened
} else {
Expand All @@ -1313,8 +1307,6 @@ impl<'ast, T: Field> Flattener<'ast, T> {
vec![n.clone(), d.clone()],
));

let target_bitwidth = target_bitwidth.to_usize();

// q in range
let _ = self.get_bits_unchecked(
&FlatUExpression::with_field(FlatExpression::from(q)),
Expand All @@ -1337,7 +1329,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let _ = self.get_bits_unchecked(
&FlatUExpression::with_field(FlatExpression::add(
FlatExpression::sub(r.into(), d.clone()),
FlatExpression::value(T::from(2_u128.pow(target_bitwidth as u32))),
FlatExpression::value(T::from(2).pow(target_bitwidth)),
)),
target_bitwidth,
target_bitwidth,
Expand Down Expand Up @@ -1558,21 +1550,35 @@ impl<'ast, T: Field> Flattener<'ast, T> {
FlatUExpression::with_field(FlatExpression::identifier(res))
}
UExpressionInner::Div(e) => {
let left_flattened = self
.flatten_uint_expression(statements_flattened, *e.left)
.get_field_unchecked();
let right_flattened = self
.flatten_uint_expression(statements_flattened, *e.right)
.get_field_unchecked();

let (q, _) = self.euclidean_division(
statements_flattened,
target_bitwidth,
*e.left,
*e.right,
target_bitwidth.to_usize(),
left_flattened,
right_flattened,
);

FlatUExpression::with_field(q)
}
UExpressionInner::Rem(e) => {
let left_flattened = self
.flatten_uint_expression(statements_flattened, *e.left)
.get_field_unchecked();
let right_flattened = self
.flatten_uint_expression(statements_flattened, *e.right)
.get_field_unchecked();

let (_, r) = self.euclidean_division(
statements_flattened,
target_bitwidth,
*e.left,
*e.right,
target_bitwidth.to_usize(),
left_flattened,
right_flattened,
);

FlatUExpression::with_field(r)
Expand Down Expand Up @@ -2196,76 +2202,28 @@ impl<'ast, T: Field> Flattener<'ast, T> {
FieldElementExpression::IDiv(e) => {
let left_flattened = self.flatten_field_expression(statements_flattened, *e.left);
let right_flattened = self.flatten_field_expression(statements_flattened, *e.right);
let new_left: FlatExpression<T> = {
let id = self.use_sym();
statements_flattened.push_back(FlatStatement::definition(id, left_flattened));
id.into()
};
let new_right: FlatExpression<T> = {
let id = self.use_sym();
statements_flattened.push_back(FlatStatement::definition(id, right_flattened));
id.into()
};

// q is the quotient and r is the remainder
// we simply need to constrain `a = b * q + r` and `0 <= r < b` to hold
let q = self.use_sym();
let r = self.use_sym();

// # q, r = a \ b
statements_flattened.push_back(FlatStatement::directive(
vec![q, r],
Solver::EuclideanDiv,
vec![new_left.clone(), new_right.clone()],
));

// q * b == a - r
statements_flattened.push_back(FlatStatement::condition(
FlatExpression::sub(new_left, r.into()),
FlatExpression::mul(q.into(), new_right),
RuntimeError::Euclidean,
));

// todo: enforce `r < b`
let (q, _) = self.euclidean_division(
statements_flattened,
T::get_required_bits() - 2,
left_flattened,
right_flattened,
);

q.into()
q
}
FieldElementExpression::Rem(e) => {
let left_flattened = self.flatten_field_expression(statements_flattened, *e.left);
let right_flattened = self.flatten_field_expression(statements_flattened, *e.right);
let new_left: FlatExpression<T> = {
let id = self.use_sym();
statements_flattened.push_back(FlatStatement::definition(id, left_flattened));
id.into()
};
let new_right: FlatExpression<T> = {
let id = self.use_sym();
statements_flattened.push_back(FlatStatement::definition(id, right_flattened));
id.into()
};

// q is the quotient and r is the remainder
// we simply need to constrain `a = b * q + r` and `0 <= r < b` to hold
let q = self.use_sym();
let r = self.use_sym();

// # q, r = a \ b
statements_flattened.push_back(FlatStatement::directive(
vec![q, r],
Solver::EuclideanDiv,
vec![new_left.clone(), new_right.clone()],
));

// q * b == a - r
statements_flattened.push_back(FlatStatement::condition(
FlatExpression::sub(new_left, r.into()),
FlatExpression::mul(q.into(), new_right),
RuntimeError::Euclidean,
));

// todo: enforce `r < b`
let (_, r) = self.euclidean_division(
statements_flattened,
T::get_required_bits() - 2,
left_flattened,
right_flattened,
);

r.into()
r
}
FieldElementExpression::Pow(e) => {
match e.right.into_inner() {
Expand Down
2 changes: 1 addition & 1 deletion zokrates_core_test/tests/tests/assembly/idiv.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"curves": ["Bn128"],
"max_constraint_count": 261,
"max_constraint_count": 766,
"tests": [
{
"input": {
Expand Down
2 changes: 1 addition & 1 deletion zokrates_core_test/tests/tests/idiv.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"max_constraint_count": 3,
"max_constraint_count": 1131,
"curves": ["Bn128", "Bls12_381", "Bls12_377", "Bw6_761"],
"tests": [
{
Expand Down

0 comments on commit 11bee6a

Please sign in to comment.