Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ZIR-194: Fix and test BigInt type inference #45

Merged
merged 40 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
81b17c5
Add TODO to bigint docs
tzerrell Sep 27, 2024
012d31e
Add BigInt lit tests about type inference
tzerrell Sep 27, 2024
f006d18
Organize and label bigint lit tests
tzerrell Sep 27, 2024
87bda1f
Fix typo of add for sub in test
tzerrell Sep 27, 2024
b9adeb6
Fix type inference for add/sub max_pos/max_neg
tzerrell Sep 27, 2024
245cb31
Fix type inference add min_bits
tzerrell Sep 27, 2024
ec6fd73
Fix type inference mul coeffs
tzerrell Sep 27, 2024
950fd16
Expand add tests
tzerrell Sep 27, 2024
ccc9481
Add tests for sub coeffs & min_bits
tzerrell Sep 27, 2024
bec9cfb
Fix type inference max terms feeding mul coeff
tzerrell Sep 27, 2024
a7d4fce
Add mul tests
tzerrell Sep 27, 2024
1c88883
Move comments near the tests
tzerrell Sep 30, 2024
14c3379
Add lit tests for nondet_quot
tzerrell Oct 1, 2024
ec3b1f1
Rename BigInt type inference lit test file
tzerrell Oct 1, 2024
dc6e844
Add lit tests for nondet_rem
tzerrell Oct 7, 2024
02b0b2c
Add lit tests for nondet_invmod
tzerrell Oct 7, 2024
bcf6e67
Clean up comments
tzerrell Oct 7, 2024
d36c565
Add inverse to nondet_invmod lit tests
tzerrell Oct 7, 2024
72c5a0e
Add reduce to the nondet_rem lit tests
tzerrell Oct 7, 2024
4b18728
Clean comments
tzerrell Oct 7, 2024
0b29308
Fix BigInt getMaxBits
tzerrell Oct 8, 2024
674fb7a
Fix type inference for nondet_rem/reduce on small lhs
tzerrell Oct 8, 2024
ed69d27
Add lit test for coefficient carrying in normalization
tzerrell Oct 8, 2024
76ed0d0
Fix getMaxBits for carries
tzerrell Oct 8, 2024
05199cc
Improve comments and fix off-by-one
tzerrell Oct 8, 2024
c3a621a
Rename getMaxBits -> getMaxPosBits
tzerrell Oct 8, 2024
1868e3d
Update r1cs mul bigint lit test for new type infer
tzerrell Oct 8, 2024
d5c92df
Update r1cs tests for new type inference
tzerrell Oct 8, 2024
56f5061
Add tests for negative nondets; clean comments
tzerrell Oct 10, 2024
97e4d06
Update BigInt inverse names
tzerrell Oct 10, 2024
76c06f7
Check for BigInt overflow and minBit negatives
tzerrell Oct 9, 2024
d3cda58
Remove unused code
tzerrell Oct 10, 2024
cd72649
Clean up comments
tzerrell Oct 10, 2024
091e1c1
Skip failing (overflowing) circom tests
tzerrell Oct 10, 2024
8ee1add
Update MulOp range to not overflow 64 bits
tzerrell Oct 10, 2024
3d4a0bd
Format
tzerrell Oct 10, 2024
49bd818
Format with different clang-format version
tzerrell Oct 10, 2024
830f15e
Shorter test names
tzerrell Oct 10, 2024
e727b7a
Improve BigInt type inference comments & tests
tzerrell Oct 11, 2024
9e23be2
Expand comments further
tzerrell Oct 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions zirgen/Dialect/BigInt/IR/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ cc_library(
"Dialect.cpp",
"Eval.cpp",
"Ops.cpp",
"Types.cpp",
],
hdrs = [
"BigInt.h",
Expand Down
6 changes: 3 additions & 3 deletions zirgen/Dialect/BigInt/IR/Eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ BytePoly nondetRem(const BytePoly& lhs, const BytePoly& rhs, size_t coeffs) {
return fromAPInt(rem, coeffs);
}

BytePoly nondetInvMod(const BytePoly& lhs, const BytePoly& rhs, size_t coeffs) {
BytePoly nondetInv(const BytePoly& lhs, const BytePoly& rhs, size_t coeffs) {
// Uses the formula n^(p-2) * n = 1 (mod p) to invert `lhs` (mod `rhs`)
// (via the square and multiply technique)
auto lhsInt = toAPInt(lhs);
Expand Down Expand Up @@ -226,9 +226,9 @@ EvalOutput eval(func::FuncOp inFunc, ArrayRef<APInt> witnessValues) {
polys[op.getOut()] = poly;
ret.privateWitness.push_back(poly);
})
.Case<NondetInvModOp>([&](auto op) {
.Case<NondetInvOp>([&](auto op) {
uint32_t coeffs = op.getOut().getType().getCoeffs();
auto poly = nondetInvMod(polys[op.getLhs()], polys[op.getRhs()], coeffs);
auto poly = nondetInv(polys[op.getLhs()], polys[op.getRhs()], coeffs);
polys[op.getOut()] = poly;
ret.privateWitness.push_back(poly);
})
Expand Down
95 changes: 65 additions & 30 deletions zirgen/Dialect/BigInt/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
using namespace mlir;
using risc0::ceilDiv;

// Additional comments on how type inference works for the BigInt dialect can be found in
// `test/type_infer.mlir`, including descriptions at the beginning of each op's suite of tests.

namespace zirgen::BigInt {

// Type inference
Expand Down Expand Up @@ -66,9 +69,10 @@ LogicalResult AddOp::inferReturnTypes(MLIRContext* ctx,
auto lhsType = adaptor.getLhs().getType().cast<BigIntType>();
auto rhsType = adaptor.getRhs().getType().cast<BigIntType>();
size_t maxCoeffs = std::max(lhsType.getCoeffs(), rhsType.getCoeffs());
size_t maxPos = std::max(lhsType.getMaxPos(), rhsType.getMaxPos());
size_t maxNeg = std::max(lhsType.getMaxNeg(), rhsType.getMaxNeg());
size_t minBits = std::max(lhsType.getMinBits(), rhsType.getMinBits());
size_t maxPos = lhsType.getMaxPos() + rhsType.getMaxPos();
size_t maxNeg = lhsType.getMaxNeg() + rhsType.getMaxNeg();
// TODO: We could be more clever on minBits, but probably doesn't matter
size_t minBits = maxNeg > 0 ? 0 : std::max(lhsType.getMinBits(), rhsType.getMinBits());
out.push_back(BigIntType::get(ctx, maxCoeffs, maxPos, maxNeg, minBits));
return success();
}
Expand All @@ -80,8 +84,8 @@ LogicalResult SubOp::inferReturnTypes(MLIRContext* ctx,
auto lhsType = adaptor.getLhs().getType().cast<BigIntType>();
auto rhsType = adaptor.getRhs().getType().cast<BigIntType>();
size_t maxCoeffs = std::max(lhsType.getCoeffs(), rhsType.getCoeffs());
size_t maxPos = std::max(lhsType.getMaxPos(), rhsType.getMaxNeg());
size_t maxNeg = std::max(lhsType.getMaxNeg(), rhsType.getMaxPos());
size_t maxPos = lhsType.getMaxPos() + rhsType.getMaxNeg();
size_t maxNeg = lhsType.getMaxNeg() + rhsType.getMaxPos();
// TODO: We could be more clever on minBits, but probably doesn't matter
out.push_back(BigIntType::get(ctx, maxCoeffs, maxPos, maxNeg, 0));
return success();
Expand All @@ -93,30 +97,56 @@ LogicalResult MulOp::inferReturnTypes(MLIRContext* ctx,
SmallVectorImpl<Type>& out) {
auto lhsType = adaptor.getLhs().getType().cast<BigIntType>();
auto rhsType = adaptor.getRhs().getType().cast<BigIntType>();
size_t maxCoeffs = std::max(lhsType.getCoeffs(), rhsType.getCoeffs());
size_t totCoeffs = lhsType.getCoeffs() + rhsType.getCoeffs();
size_t maxPos = std::max(lhsType.getMaxPos() * rhsType.getMaxPos(),
lhsType.getMaxNeg() * rhsType.getMaxNeg()) *
maxCoeffs;
size_t maxNeg = std::max(lhsType.getMaxPos() * rhsType.getMaxNeg(),
lhsType.getMaxNeg() * rhsType.getMaxPos()) *
maxCoeffs;
size_t coeffs = lhsType.getCoeffs() + rhsType.getCoeffs() - 1;
// The maximum number of coefficient pairs from the inputs used to calculate an output coefficient
size_t maxCoeffs = std::min(lhsType.getCoeffs(), rhsType.getCoeffs());
// This calculation could overflow if size_t is 32 bits, so cast to 64 bits
uint64_t maxPos = std::max((uint64_t)lhsType.getMaxPos() * rhsType.getMaxPos(),
(uint64_t)lhsType.getMaxNeg() * rhsType.getMaxNeg());
// The next step can potentially overflow even 64 bits; but if we're already above 32 bits we'll
// fail validation anyway. Therefore, skip this if we're above 32 bits
if (maxPos < (uint64_t)1 << 32) {
maxPos *= maxCoeffs;
}
// Clamp to size_t
if (maxPos > std::numeric_limits<size_t>::max()) {
maxPos = std::numeric_limits<size_t>::max();
}
// As with maxPos, this could overflow if size_t is 32 bits, so cast to 64 bits
uint64_t maxNeg = std::max((uint64_t)lhsType.getMaxPos() * rhsType.getMaxNeg(),
(uint64_t)lhsType.getMaxNeg() * rhsType.getMaxPos());
// The next step can potentially overflow even 64 bits; but if we're already above 32 bits we'll
// fail validation anyway. Therefore, skip this if we're above 32 bits
if (maxNeg < (uint64_t)1 << 32) {
maxNeg *= maxCoeffs;
}
// Clamp to size_t
if (maxNeg > std::numeric_limits<size_t>::max()) {
maxNeg = std::numeric_limits<size_t>::max();
}
size_t minBits;
if (lhsType.getMinBits() == 0 || rhsType.getMinBits() == 0) {
// Note that this catches _both_ cases where the input might be zero _and_ cases where the input
// might be negative, as type verification enforces that when minBits is zero, so is maxNeg.
minBits = 0;
} else {
minBits = lhsType.getMinBits() + rhsType.getMinBits() - 1;
}
out.push_back(BigIntType::get(ctx, totCoeffs, maxPos, maxNeg, minBits));
out.push_back(BigIntType::get(ctx, coeffs, maxPos, maxNeg, minBits));
return success();
}

LogicalResult NondetRemOp::inferReturnTypes(MLIRContext* ctx,
std::optional<Location> loc,
Adaptor adaptor,
SmallVectorImpl<Type>& out) {
auto lhsType = adaptor.getLhs().getType().cast<BigIntType>();
auto rhsType = adaptor.getRhs().getType().cast<BigIntType>();
size_t coeffsWidth = ceilDiv(rhsType.getMaxBits(), kBitsPerCoeff);
auto outBits = lhsType.getMaxPosBits();
if (rhsType.getMaxPosBits() < outBits) {
outBits = rhsType.getMaxPosBits();
}
size_t coeffsWidth = ceilDiv(outBits, kBitsPerCoeff);
out.push_back(BigIntType::get(ctx,
/*coeffs=*/coeffsWidth,
/*maxPos=*/(1 << kBitsPerCoeff) - 1,
Expand All @@ -131,26 +161,26 @@ LogicalResult NondetQuotOp::inferReturnTypes(MLIRContext* ctx,
SmallVectorImpl<Type>& out) {
auto lhsType = adaptor.getLhs().getType().cast<BigIntType>();
auto rhsType = adaptor.getRhs().getType().cast<BigIntType>();
size_t outBits = lhsType.getMaxBits();
size_t outBits = lhsType.getMaxPosBits();
if (rhsType.getMinBits() > 0) {
outBits -= rhsType.getMinBits() - 1;
}
size_t coeffsWidth = ceilDiv(outBits, kBitsPerCoeff);
// TODO: We could be more clever on minBits, but probably doesn't matter
out.push_back(BigIntType::get(ctx,
/*coeffs=*/coeffsWidth,
/*maxPos=*/(1 << kBitsPerCoeff) - 1,
/*maxNeg=*/0,
/*minBits=*/0 /*TODO: maybe better bound? */
));
/*minBits=*/0));
return success();
}

LogicalResult NondetInvModOp::inferReturnTypes(MLIRContext* ctx,
std::optional<Location> loc,
Adaptor adaptor,
SmallVectorImpl<Type>& out) {
LogicalResult NondetInvOp::inferReturnTypes(MLIRContext* ctx,
std::optional<Location> loc,
Adaptor adaptor,
SmallVectorImpl<Type>& out) {
auto rhsType = adaptor.getRhs().getType().cast<BigIntType>();
size_t coeffsWidth = ceilDiv(rhsType.getMaxBits(), kBitsPerCoeff);
size_t coeffsWidth = ceilDiv(rhsType.getMaxPosBits(), kBitsPerCoeff);
out.push_back(BigIntType::get(ctx,
/*coeffs=*/coeffsWidth,
/*maxPos=*/(1 << kBitsPerCoeff) - 1,
Expand All @@ -159,12 +189,12 @@ LogicalResult NondetInvModOp::inferReturnTypes(MLIRContext* ctx,
return success();
}

LogicalResult ModularInvOp::inferReturnTypes(MLIRContext* ctx,
std::optional<Location> loc,
Adaptor adaptor,
SmallVectorImpl<Type>& out) {
LogicalResult InvOp::inferReturnTypes(MLIRContext* ctx,
std::optional<Location> loc,
Adaptor adaptor,
SmallVectorImpl<Type>& out) {
auto rhsType = adaptor.getRhs().getType().cast<BigIntType>();
size_t coeffsWidth = ceilDiv(rhsType.getMaxBits(), kBitsPerCoeff);
size_t coeffsWidth = ceilDiv(rhsType.getMaxPosBits(), kBitsPerCoeff);
out.push_back(BigIntType::get(ctx,
/*coeffs=*/coeffsWidth,
/*maxPos=*/(1 << kBitsPerCoeff) - 1,
Expand All @@ -177,8 +207,13 @@ LogicalResult ReduceOp::inferReturnTypes(MLIRContext* ctx,
std::optional<Location> loc,
Adaptor adaptor,
SmallVectorImpl<Type>& out) {
auto lhsType = adaptor.getLhs().getType().cast<BigIntType>();
auto rhsType = adaptor.getRhs().getType().cast<BigIntType>();
size_t coeffsWidth = ceilDiv(rhsType.getMaxBits(), kBitsPerCoeff);
auto outBits = lhsType.getMaxPosBits();
if (rhsType.getMaxPosBits() < outBits) {
outBits = rhsType.getMaxPosBits();
}
size_t coeffsWidth = ceilDiv(outBits, kBitsPerCoeff);
out.push_back(BigIntType::get(ctx,
/*coeffs=*/coeffsWidth,
/*maxPos=*/(1 << kBitsPerCoeff) - 1,
Expand Down Expand Up @@ -254,7 +289,7 @@ void NondetQuotOp::emitExpr(codegen::CodegenEmitter& cg) {
{getLhs(), getRhs(), toConstantValue(cg, getContext(), getType().getCoeffs())});
}

void NondetInvModOp::emitExpr(codegen::CodegenEmitter& cg) {
void NondetInvOp::emitExpr(codegen::CodegenEmitter& cg) {
cg.emitInvokeMacro(
cg.getStringAttr("bigint_nondet_inv"),
/*contextArgs=*/{"ctx"},
Expand Down
4 changes: 2 additions & 2 deletions zirgen/Dialect/BigInt/IR/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def SubOp : BinaryOp<"sub", [Pure, DeclareOpInterfaceMethods<CodegenExprOpInterf
def MulOp : BinaryOp<"mul", [Pure, Commutative, DeclareOpInterfaceMethods<CodegenExprOpInterface>]> {}
def NondetRemOp : BinaryOp<"nondet_rem", [DeclareOpInterfaceMethods<CodegenExprOpInterface>]> {}
def NondetQuotOp : BinaryOp<"nondet_quot", [DeclareOpInterfaceMethods<CodegenExprOpInterface>]> {}
def NondetInvModOp : BinaryOp<"nondet_invmod", [DeclareOpInterfaceMethods<CodegenExprOpInterface>]> {}
def ModularInvOp : BinaryOp<"inv", []> {}
def NondetInvOp : BinaryOp<"nondet_inv", [DeclareOpInterfaceMethods<CodegenExprOpInterface>]> {}
def InvOp : BinaryOp<"inv", []> {}
def ReduceOp : BinaryOp<"reduce", []> {}

def EqualZeroOp : BigIntOp<"eqz", [DeclareOpInterfaceMethods<CodegenExprOpInterface>]> {
Expand Down
41 changes: 41 additions & 0 deletions zirgen/Dialect/BigInt/IR/Types.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright 2024 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "risc0/fp/fp.h"
#include "zirgen/Dialect/BigInt/IR/BigInt.h"
#include "zirgen/Dialect/BigInt/IR/Types.h.inc"

using namespace mlir;

namespace zirgen::BigInt {

LogicalResult BigIntType::verify(function_ref<InFlightDiagnostic()> emitError,
size_t coeffs,
size_t maxPos,
size_t maxNeg,
size_t minBits) {
if (maxNeg > 0 && minBits > 0) {
return emitError() << "BigInts with positive minBits must be positive: maxNeg: " << maxNeg
<< ", minBits: " << minBits;
}
// TODO: Think through whether maxPos / maxNeg can ever overflow their attribute type, which would
// cause problems here
if (maxPos + maxNeg >= risc0::Fp::P) {
return emitError() << "Cannot create BigInt with coefficients overflowing BabyBear: maxPos: "
<< maxPos << " + maxNeg: " << maxNeg;
}
return success();
}

} // namespace zirgen::BigInt
26 changes: 23 additions & 3 deletions zirgen/Dialect/BigInt/IR/Types.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,37 @@ def BigInt : BigIntType<"BigInt", "bigint", [
DeclareTypeInterfaceMethods<CodegenTypeInterface, ["getTypeName", "allowDuplicateTypeNames", "emitTypeDefinition"]>,
CodegenNeedsCloneType
]> {
let summary = "A big interger value represented as a polynomial";
let summary = "A big integer value represented as a polynomial";
let parameters = (ins
"size_t": $coeffs, // Number of polynomial coefficents
"size_t": $maxPos, // Maximum positive coefficient value
"size_t": $maxNeg, // Maximum negative coefficient value
"size_t": $minBits // If minBits == 0, no constraint, otherwise N >= 2^(minBits - 1)
);
let assemblyFormat = "`<` $coeffs `,` $maxPos `,` $maxNeg `,` $minBits `>`";
let genVerifyDecl = 1;
let extraClassDeclaration = [{
size_t getMaxBits() {
size_t extraBits = risc0::log2Ceil(getMaxPos() / (1 << kBitsPerCoeff));
size_t getMaxPosBits() {
// Because 2^k requires k+1 bits to represent, we add 1 to getMaxPos before log2Ceil
size_t extraBits = risc0::log2Ceil(getMaxPos() + 1);
if (extraBits <= kBitsPerCoeff) {
// When maxPos fits in a coeff, no extra bits are needed
extraBits = 0;
} else {
extraBits -= kBitsPerCoeff;
// Carries can sometimes lead to 1 extra bit so add 1
// Specifically, we know that the max value is
// getMaxPos() * sum over i in [0, getCoeffs()) of (2^kBitsPerCoeff)^i
// which is bounded above by
// getMaxPos() * 2^((kBitsPerCoeff * (getCoeffs() - 1)) + 1)
// which has bitwidth maxPosBitwidth + kBitsPerCoeff * (getCoeffs() - 1) + 1 which is
// kBitsPerCoeff * getCoeffs + extraBits
// where
// extraBits = maxPosBitwidth - kBitsPerCoeff + 1
if (getCoeffs() > 1) {
extraBits += 1;
}
}
return kBitsPerCoeff * getCoeffs() + extraBits;
}
size_t getCarryOffset() {
Expand Down
4 changes: 4 additions & 0 deletions zirgen/Dialect/BigInt/IR/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ package(
default_visibility = ["//visibility:public"],
)

load("//bazel/rules/lit:defs.bzl", "glob_lit_tests")

glob_lit_tests()

cc_binary(
name = "test",
srcs = [
Expand Down
Loading
Loading