Skip to content

Commit

Permalink
clean up eval a bit: check operand bounds & pull eqz out into its own…
Browse files Browse the repository at this point in the history
… function
  • Loading branch information
mars-risc0 committed Oct 23, 2024
1 parent faf48a0 commit 9fb7a2e
Showing 1 changed file with 72 additions and 49 deletions.
121 changes: 72 additions & 49 deletions zirgen/Dialect/BigInt/Bytecode/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ namespace zirgen::BigInt::Bytecode {
// We must unfortunately redefine these constants whose original definitions
// live in zirgen/Dialect/BigInt/IR/BigInt.h
constexpr size_t kBitsPerCoeff = 8;
constexpr size_t kCoeffsPerPoly = 16;

BytePoly fromBQInt(BQInt value, size_t coeffs) {
BytePoly out(coeffs);
Expand Down Expand Up @@ -140,6 +139,68 @@ size_t getCarryBytes(const Type& type) {
return 4;
}

std::vector<BytePoly> eqz(const BytePoly &poly, const Type &type) {
if (toBQInt(poly) != 0) {
throw std::runtime_error("NONZERO");
}
uint32_t coeffs = type.coeffs;
int32_t carryOffset = getCarryOffset(type);
size_t carryBytes = getCarryBytes(type);
std::vector<BytePoly> carryPolys;
for (size_t i = 0; i < carryBytes; i++) {
carryPolys.emplace_back(coeffs);
};
int32_t carry = 0;
for (size_t i = 0; i < coeffs; i++) {
carry = (poly[i] + carry) / 256;
uint32_t carryU = carry + carryOffset;
carryPolys[0][i] = carryU & 0xff;
if (carryBytes > 1) {
carryPolys[1][i] = ((carryU >> 8) & 0xff);
}
if (carryBytes > 2) {
carryPolys[2][i] = ((carryU >> 16) & 0xff);
carryPolys[3][i] = ((carryU >> 16) & 0xff) * 4;
}
}
// Verify carry computation
BytePoly bigCarry(coeffs);
for (size_t i = 0; i < coeffs; i++) {
bigCarry[i] = carryPolys[0][i];
if (carryBytes > 1) {
bigCarry[i] += 256 * carryPolys[1][i];
}
if (carryBytes > 2) {
bigCarry[i] += 65536 * carryPolys[2][i];
}
bigCarry[i] -= carryOffset;
}
for (size_t i = 0; i < coeffs; i++) {
int32_t shouldBeZero = poly[i];
shouldBeZero -= 256 * bigCarry[i];
if (i != 0) {
shouldBeZero += bigCarry[i - 1];
}
if (shouldBeZero != 0) {
throw std::runtime_error("INVALID CARRY");
}
}
return carryPolys;
}

void checkOperandA(const Op& op, size_t opIndex) {
if (op.operandA >= opIndex) {
throw std::runtime_error("Forward reference to undefined value");
}
}

void checkOperands(const Op& op, size_t opIndex) {
// Have both operandA and operandB been defined?
if (op.operandA >= opIndex || op.operandB >= opIndex) {
throw std::runtime_error("Forward reference to undefined value");
}
}

} // namespace

EvalOutput eval(const Program& inFunc, std::vector<BQInt>& witnessValues) {
Expand Down Expand Up @@ -177,23 +238,27 @@ EvalOutput eval(const Program& inFunc, std::vector<BQInt>& witnessValues) {
ret.constantWitness.push_back(poly);
} break;
case Op::Add: {
checkOperands(op, opIndex);
auto lhs = polys[op.operandA];
auto rhs = polys[op.operandB];
polys[opIndex] = add(lhs, rhs);
} break;
case Op::Sub: {
checkOperands(op, opIndex);
auto lhs = polys[op.operandA];
auto rhs = polys[op.operandB];
polys[opIndex] = sub(lhs, rhs);
} break;
case Op::Mul: {
checkOperands(op, opIndex);
auto lhs = polys[op.operandA];
auto rhs = polys[op.operandB];
polys[opIndex] = mul(lhs, rhs);
} break;
case Op::Rem: {
const Type& type = inFunc.types[op.type];
uint32_t coeffs = type.coeffs;
checkOperands(op, opIndex);
auto lhs = polys[op.operandA];
auto rhs = polys[op.operandB];
auto poly = nondetRem(lhs, rhs, coeffs);
Expand All @@ -203,6 +268,7 @@ EvalOutput eval(const Program& inFunc, std::vector<BQInt>& witnessValues) {
case Op::Quo: {
const Type& type = inFunc.types[op.type];
uint32_t coeffs = type.coeffs;
checkOperands(op, opIndex);
auto lhs = polys[op.operandA];
auto rhs = polys[op.operandB];
auto poly = nondetQuot(lhs, rhs, coeffs);
Expand All @@ -212,63 +278,20 @@ EvalOutput eval(const Program& inFunc, std::vector<BQInt>& witnessValues) {
case Op::Inv: {
const Type& type = inFunc.types[op.type];
uint32_t coeffs = type.coeffs;
checkOperands(op, opIndex);
auto lhs = polys[op.operandA];
auto rhs = polys[op.operandB];
auto poly = nondetInvMod(lhs, rhs, coeffs);
polys[opIndex] = poly;
ret.privateWitness.push_back(poly);
} break;
case Op::Eqz: {
checkOperandA(op, opIndex);
auto poly = polys[op.operandA];
if (toBQInt(poly) != 0) {
throw std::runtime_error("NONZERO");
}
const Type& type = inFunc.types[op.type];
uint32_t coeffs = type.coeffs;
int32_t carryOffset = getCarryOffset(type);
size_t carryBytes = getCarryBytes(type);
std::vector<BytePoly> carryPolys;
for (size_t i = 0; i < carryBytes; i++) {
carryPolys.emplace_back(coeffs);
};
int32_t carry = 0;
for (size_t i = 0; i < coeffs; i++) {
carry = (poly[i] + carry) / 256;
uint32_t carryU = carry + carryOffset;
carryPolys[0][i] = carryU & 0xff;
if (carryBytes > 1) {
carryPolys[1][i] = ((carryU >> 8) & 0xff);
}
if (carryBytes > 2) {
carryPolys[2][i] = ((carryU >> 16) & 0xff);
carryPolys[3][i] = ((carryU >> 16) & 0xff) * 4;
}
}
// Verify carry computation
BytePoly bigCarry(coeffs);
for (size_t i = 0; i < coeffs; i++) {
bigCarry[i] = carryPolys[0][i];
if (carryBytes > 1) {
bigCarry[i] += 256 * carryPolys[1][i];
}
if (carryBytes > 2) {
bigCarry[i] += 65536 * carryPolys[2][i];
}
bigCarry[i] -= carryOffset;
}
for (size_t i = 0; i < coeffs; i++) {
int32_t shouldBeZero = poly[i];
shouldBeZero -= 256 * bigCarry[i];
if (i != 0) {
shouldBeZero += bigCarry[i - 1];
}
if (shouldBeZero != 0) {
throw std::runtime_error("INVALID CARRY");
}
}
// Store the results
for (size_t i = 0; i < carryPolys.size(); i++) {
ret.privateWitness.push_back(carryPolys[i]);
auto carryPolys = eqz(poly, type);
for (auto &p: carryPolys) {
ret.privateWitness.push_back(p);
}
} break;
default: {
Expand Down

0 comments on commit 9fb7a2e

Please sign in to comment.