Skip to content

Commit

Permalink
feat: witness caching now operates on a bounded buffer (#1215)
Browse files Browse the repository at this point in the history
* refactor: Refactor prove_recursively to accept IntoIterator

- Update `prove_recursively` functions across several files (`prove.rs`, `nova.rs`, `mod.rs`, `supernova.rs`) to accept an `IntoIterator` instead of a `Vec`, allowing for more flexible input types.
- Add `ExactSizeIterator` bound to the `IntoIterator` across the mentioned modules, enhancing function's safety and aiding in potential optimization.
- Refactor iteration over steps by converting Vector to an Iterator, reducing the ownership requirements and accommodating iterator input.
- Adjust the checking of the first element's arity in `nova.rs` to avoid consuming the iterator.

* Refactor Store and lifetimes handling with Arc

- Updated various places across the project which were using `Store` object directly, replacing it with the `Arc` type for safer shared ownership. This is further reflected in updated object initialization (e.g., `Store::default()` is now `Arc::new(Store::default())`).
- Removed explicit lifetime specifier `'a` from multiple structures and functions, making the implementation more readable and efficient. This change has affected objects and files such as `NovaProver`, `SuperNovaProver`, `LurkProofWrapper` etc.
- Modified various existing function signatures to eliminate direct references to the `Store` object and use `Arc<Store>` instead. Similarly, certain method arguments are updated to use references to the `Arc` instead.
- Changes in the `MultiFrame` struct and `InterpretationData` to use Arc, resulting in improved lifetime management. This update required adjustments in function implementation as well.
- Overall, these alterations do not affect logic or functionality but lead to better thread-safety and memory management in the system. The lifetime specification removal enhances code simplicity and readability.

* refactor: Refactor proof module to employ a fixed-sized buffer of cached witnesses

- Revise `prove_recursively` function to replace Mutex-protected steps iteration with a synchronous channel, streamlining parallelism.
- Add a constant for maximum buffered frames in the proving process in `mod.rs`.

* chore: adapt trie bench
  • Loading branch information
huitseeker authored Mar 19, 2024
1 parent 0d2ae13 commit 090931a
Show file tree
Hide file tree
Showing 22 changed files with 951 additions and 782 deletions.
15 changes: 7 additions & 8 deletions benches/end2end.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ fn end2end_benchmark(c: &mut Criterion) {
let lang = Lang::<Bn>::new();
let lang_rc = Arc::new(lang.clone());

let store = Store::default();
let prover: NovaProver<'_, Bn, Coproc<Bn>> = NovaProver::new(reduction_count, lang_rc.clone());
let store = Arc::new(Store::default());
let prover: NovaProver<Bn, Coproc<Bn>> = NovaProver::new(reduction_count, lang_rc.clone());

// use cached public params
let instance = Instance::new(reduction_count, lang_rc, true, Kind::NovaPublicParams);
Expand Down Expand Up @@ -195,7 +195,7 @@ fn prove_benchmark(c: &mut Criterion) {
let limit = 1_000_000_000;
let reduction_count = DEFAULT_REDUCTION_COUNT;

let store = Store::default();
let store = Arc::new(Store::default());

let size = (10, 0);
let benchmark_id = BenchmarkId::new("prove_go_base_nova", format!("_{}_{}", size.0, size.1));
Expand All @@ -216,8 +216,7 @@ fn prove_benchmark(c: &mut Criterion) {

group.bench_with_input(benchmark_id, &size, |b, &s| {
let ptr = go_base::<Bn>(&store, state.clone(), s.0, s.1);
let prover: NovaProver<'_, Bn, Coproc<Bn>> =
NovaProver::new(reduction_count, lang_rc.clone());
let prover: NovaProver<Bn, Coproc<Bn>> = NovaProver::new(reduction_count, lang_rc.clone());
let frames =
evaluate::<Bn, Coproc<Bn>>(None, ptr, &store, limit, &dummy_terminal()).unwrap();

Expand All @@ -242,7 +241,7 @@ fn prove_compressed_benchmark(c: &mut Criterion) {

set_bench_config();
let limit = 1_000_000_000;
let store = Store::default();
let store = Arc::new(Store::default());
let reduction_count = DEFAULT_REDUCTION_COUNT;

let size = (10, 0);
Expand Down Expand Up @@ -293,7 +292,7 @@ fn verify_benchmark(c: &mut Criterion) {

set_bench_config();
let limit = 1_000_000_000;
let store = Store::default();
let store = Arc::new(Store::default());
let reduction_count = DEFAULT_REDUCTION_COUNT;

let state = State::init_lurk_state().rccell();
Expand Down Expand Up @@ -348,7 +347,7 @@ fn verify_compressed_benchmark(c: &mut Criterion) {

set_bench_config();
let limit = 1_000_000_000;
let store = Store::default();
let store = Arc::new(Store::default());
let reduction_count = DEFAULT_REDUCTION_COUNT;

let state = State::init_lurk_state().rccell();
Expand Down
2 changes: 1 addition & 1 deletion benches/fibonacci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ fn fibonacci_prove<M: measurement::Measurement>(
true,
Kind::NovaPublicParams,
);
let store = Store::default();
let store = Arc::new(Store::default());
let pp = public_params(&instance).unwrap();

// Track the number of `Lurk frames / sec`
Expand Down
24 changes: 12 additions & 12 deletions benches/sha256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ fn sha256_ivc_prove<M: measurement::Measurement>(

let limit = 10000;

let store = &Store::<Bn>::default();
let store = Arc::new(Store::<Bn>::default());
let cproc_sym = user_sym(&format!("sha256_ivc_{arity}"));

let mut lang = Lang::<Bn, Sha256Coproc<Bn>>::new();
Expand All @@ -125,7 +125,7 @@ fn sha256_ivc_prove<M: measurement::Measurement>(
&prove_params,
|b, prove_params| {
let ptr = sha256_ivc(
store,
&*store,
state.clone(),
black_box(prove_params.arity),
black_box(prove_params.n),
Expand All @@ -137,7 +137,7 @@ fn sha256_ivc_prove<M: measurement::Measurement>(
let frames = &evaluate(
Some((&lurk_step, &[], &lang)),
ptr,
store,
&*store,
limit,
&dummy_terminal(),
)
Expand All @@ -146,7 +146,7 @@ fn sha256_ivc_prove<M: measurement::Measurement>(
b.iter_batched(
|| frames,
|frames| {
let result = prover.prove_from_frames(&pp, frames, store, None);
let result = prover.prove_from_frames(&pp, frames, &store, None);
let _ = black_box(result);
},
BatchSize::LargeInput,
Expand Down Expand Up @@ -190,7 +190,7 @@ fn sha256_ivc_prove_compressed<M: measurement::Measurement>(

let limit = 10000;

let store = &Store::<Bn>::default();
let store = Arc::new(Store::<Bn>::default());
let cproc_sym = user_sym(&format!("sha256_ivc_{arity}"));

let mut lang = Lang::<Bn, Sha256Coproc<Bn>>::new();
Expand All @@ -213,7 +213,7 @@ fn sha256_ivc_prove_compressed<M: measurement::Measurement>(
&prove_params,
|b, prove_params| {
let ptr = sha256_ivc(
store,
&*store,
state.clone(),
black_box(prove_params.arity),
black_box(prove_params.n),
Expand All @@ -225,7 +225,7 @@ fn sha256_ivc_prove_compressed<M: measurement::Measurement>(
let frames = &evaluate(
Some((&lurk_step, &[], &lang)),
ptr,
store,
&*store,
limit,
&dummy_terminal(),
)
Expand All @@ -235,7 +235,7 @@ fn sha256_ivc_prove_compressed<M: measurement::Measurement>(
|| frames,
|frames| {
let (proof, _, _, _) =
prover.prove_from_frames(&pp, frames, store, None).unwrap();
prover.prove_from_frames(&pp, frames, &store, None).unwrap();
let compressed_result = proof.compress(&pp).unwrap();

let _ = black_box(compressed_result);
Expand Down Expand Up @@ -281,7 +281,7 @@ fn sha256_nivc_prove<M: measurement::Measurement>(

let limit = 10000;

let store = &Store::<Bn>::default();
let store = Arc::new(Store::<Bn>::default());
let cproc_sym = user_sym(&format!("sha256_ivc_{arity}"));

let mut lang = Lang::<Bn, Sha256Coproc<Bn>>::new();
Expand All @@ -305,7 +305,7 @@ fn sha256_nivc_prove<M: measurement::Measurement>(
&prove_params,
|b, prove_params| {
let ptr = sha256_ivc(
store,
&*store,
state.clone(),
black_box(prove_params.arity),
black_box(prove_params.n),
Expand All @@ -317,7 +317,7 @@ fn sha256_nivc_prove<M: measurement::Measurement>(
let frames = &evaluate(
Some((&lurk_step, &cprocs, &lang)),
ptr,
store,
&*store,
limit,
&dummy_terminal(),
)
Expand All @@ -326,7 +326,7 @@ fn sha256_nivc_prove<M: measurement::Measurement>(
b.iter_batched(
|| frames,
|frames| {
let result = prover.prove_from_frames(&pp, frames, store, None);
let result = prover.prove_from_frames(&pp, frames, &store, None);
let _ = black_box(result);
},
BatchSize::LargeInput,
Expand Down
2 changes: 1 addition & 1 deletion benches/synthesis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ fn synthesize<M: measurement::Measurement>(
BenchmarkId::new(name.to_string(), reduction_count),
&reduction_count,
|b, reduction_count| {
let store = Store::default();
let store = Arc::new(Store::default());
let fib_n = (reduction_count / 3) as u64; // Heuristic, since one fib is 35 iterations.
let ptr = fib::<Bn>(&store, state.clone(), black_box(fib_n));
let frames =
Expand Down
4 changes: 2 additions & 2 deletions benches/trie_nivc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ fn prove<M: measurement::Measurement>(
name: &str,
reduction_count: usize,
lang: &Arc<Lang<Fr, TrieCoproc<Fr>>>,
store: &Store<Fr>,
store: &Arc<Store<Fr>>,
frames: &[Frame],
c: &mut BenchmarkGroup<'_, M>,
) {
Expand Down Expand Up @@ -67,7 +67,7 @@ fn trie_nivc(c: &mut Criterion) {
install(&state, &mut lang);
let lang = Arc::new(lang);

let store = Store::<Fr>::default();
let store = Arc::new(Store::<Fr>::default());
let expr = store.read(state, CODE).unwrap();

let lurk_step = make_eval_step_from_config(&EvalConfig::new_nivc(&lang));
Expand Down
21 changes: 10 additions & 11 deletions chain-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,18 @@ use chain_prover::{

use chain_server::{ChainRequestData, ChainResponseData};

struct ChainProverService<'a, F: CurveCycleEquipped, C: Coprocessor<F>> {
struct ChainProverService<F: CurveCycleEquipped, C: Coprocessor<F>> {
callable: Arc<Mutex<Ptr>>,
store: Store<F>, // TODO: add the store to the state to allow memory cleansing
store: Arc<Store<F>>, // TODO: add the store to the state to allow memory cleansing
limit: usize,
lurk_step: Func,
cprocs: Vec<Func>,
prover: SuperNovaProver<'a, F, C>,
prover: SuperNovaProver<F, C>,
public_params: OnceCell<PublicParams<F>>,
session: Option<Utf8PathBuf>,
}

impl<'a, F: CurveCycleEquipped, C: Coprocessor<F>> ChainProverService<'a, F, C> {
impl<F: CurveCycleEquipped, C: Coprocessor<F>> ChainProverService<F, C> {
fn new(
callable: Ptr,
store: Store<F>,
Expand All @@ -75,7 +75,7 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor<F>> ChainProverService<'a, F, C>
let prover = SuperNovaProver::<_, C>::new(rc, Arc::new(lang));
Self {
callable: Arc::new(Mutex::new(callable)),
store,
store: Arc::new(store),
limit,
lurk_step,
cprocs,
Expand All @@ -90,7 +90,7 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor<F>> ChainProverService<'a, F, C>
impl<
F: CurveCycleEquipped + DeserializeOwned + Serialize,
C: Coprocessor<F> + Serialize + DeserializeOwned + 'static,
> ChainProver for ChainProverService<'static, F, C>
> ChainProver for ChainProverService<F, C>
where
<F as ff::PrimeField>::Repr: Abomonation,
<Dual<F> as ff::PrimeField>::Repr: Abomonation,
Expand Down Expand Up @@ -220,8 +220,8 @@ struct SessionData<F: LurkField, S> {
rc: usize,
}

impl<'a, F: CurveCycleEquipped, C: Coprocessor<F>> SessionData<F, C> {
fn pack(data: &ChainProverService<'a, F, C>, callable: &Ptr) -> Self {
impl<F: CurveCycleEquipped, C: Coprocessor<F>> SessionData<F, C> {
fn pack(data: &ChainProverService<F, C>, callable: &Ptr) -> Self {
let ChainProverService {
store,
limit,
Expand All @@ -241,7 +241,7 @@ impl<'a, F: CurveCycleEquipped, C: Coprocessor<F>> SessionData<F, C> {
}
}

fn unpack(self, session: Utf8PathBuf) -> Result<ChainProverService<'a, F, C>> {
fn unpack(self, session: Utf8PathBuf) -> Result<ChainProverService<F, C>> {
let Self {
callable,
z_store,
Expand Down Expand Up @@ -344,10 +344,9 @@ struct ResumeArgs {
}

fn get_service_and_address<
'a,
F: CurveCycleEquipped + DeserializeOwned,
C: Coprocessor<F> + DeserializeOwned,
>() -> Result<(ChainProverService<'a, F, C>, SocketAddr), Box<dyn std::error::Error>> {
>() -> Result<(ChainProverService<F, C>, SocketAddr), Box<dyn std::error::Error>> {
let Cli { command } = Cli::parse();
let local_ip = |port| SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), port));
match command {
Expand Down
8 changes: 4 additions & 4 deletions examples/keccak.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,12 +337,12 @@ fn main() {
let args = std::env::args().collect::<Vec<_>>();

// Initialize store, responsible for handling variables in the lurk context
let store: &Store<pallas::Scalar> = &Store::default();
let store: Arc<Store<pallas::Scalar>> = Arc::new(Store::default());

// Define the symbol that will call upon our Coprocessor
let str_to_le_bits_sym = user_sym("str_to_le_bits");
let keccak_sym = user_sym("keccak_hash");
let program = lurk_program(store, &args[1]);
let program = lurk_program(&*store, &args[1]);

// Create the Lang. ie the list of corprocessor that will be accessible in our program
let mut lang = Lang::<pallas::Scalar, KeccakExampleCoproc<pallas::Scalar>>::new();
Expand All @@ -360,7 +360,7 @@ fn main() {
let frames = evaluate(
Some((&lurk_step, &cprocs, &lang)),
program,
store,
&*store,
1000,
&dummy_terminal(),
)
Expand Down Expand Up @@ -395,7 +395,7 @@ fn main() {
let proof_start = Instant::now();

let (proof, z0, zi, _) = supernova_prover
.prove_from_frames(&pp, &frames, store, None)
.prove_from_frames(&pp, &frames, &store, None)
.unwrap();

let proof_end = proof_start.elapsed();
Expand Down
6 changes: 3 additions & 3 deletions examples/sha256_ivc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ fn main() {
let args = std::env::args().collect::<Vec<_>>();
let n = args.get(1).unwrap_or(&"1".into()).parse().unwrap();

let store = &Store::default();
let store = Arc::new(Store::default());
let cproc_sym = user_sym(&format!("sha256_ivc_{n}"));

let call = sha256_ivc(store, n, &(0..n).collect::<Vec<_>>());
let call = sha256_ivc(&*store, n, &(0..n).collect::<Vec<_>>());

let mut lang = Lang::<Bn, Sha256Coproc<Bn>>::new();
lang.add_coprocessor(cproc_sym, Sha256Coprocessor::new(n));
Expand All @@ -92,7 +92,7 @@ fn main() {
&pp,
call,
store.intern_empty_env(),
store,
&store,
10000,
&dummy_terminal(),
)
Expand Down
8 changes: 4 additions & 4 deletions examples/sha256_nivc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ fn main() {
let args = std::env::args().collect::<Vec<_>>();
let n = args.get(1).unwrap_or(&"1".into()).parse().unwrap();

let store = &Store::default();
let store = Arc::new(Store::default());
let cproc_sym = user_sym(&format!("sha256_nivc_{n}"));

let call = sha256_nivc(store, n, &(0..n).collect::<Vec<_>>());
let call = sha256_nivc(&*store, n, &(0..n).collect::<Vec<_>>());

let mut lang = Lang::<Bn, Sha256Coproc<Bn>>::new();
lang.add_coprocessor(cproc_sym, Sha256Coprocessor::new(n));
Expand All @@ -81,7 +81,7 @@ fn main() {
let frames = evaluate(
Some((&lurk_step, &cprocs, &lang)),
call,
store,
&*store,
1000,
&dummy_terminal(),
)
Expand All @@ -104,7 +104,7 @@ fn main() {
let (proof, z0, zi, _num_steps) = tracing_texray::examine(tracing::info_span!("bang!"))
.in_scope(|| {
supernova_prover
.prove_from_frames(&pp, &frames, store, None)
.prove_from_frames(&pp, &frames, &store, None)
.unwrap()
});
let proof_end = proof_start.elapsed();
Expand Down
4 changes: 2 additions & 2 deletions examples/tp_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ fn main() {

let limit = n_iters(max_n_folds, *max_rc);

let store = Store::default();
let store = Arc::new(Store::default());
let program = store.read_with_default_state(PROGRAM).unwrap();

let frames =
Expand All @@ -164,7 +164,7 @@ fn main() {
let mut data = Vec::with_capacity(rc_vec.len());

for rc in rc_vec.clone() {
let prover: NovaProver<'_, _, _> = NovaProver::new(rc, lang_arc.clone());
let prover: NovaProver<_, _> = NovaProver::new(rc, lang_arc.clone());
println!("Getting public params for rc={rc}");
// TODO: use cache once it's fixed
let pp: PublicParams<_> = public_params(rc, lang_arc.clone());
Expand Down
Loading

1 comment on commit 090931a

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmarks

Table of Contents

Overview

This benchmark report shows the Fibonacci GPU benchmark.
NVIDIA L4
Intel(R) Xeon(R) CPU @ 2.20GHz
32 vCPUs
125 GB RAM
Workflow run: https://github.com/lurk-lab/lurk-rs/actions/runs/8347157740

Benchmark Results

LEM Fibonacci Prove - rc = 100

ref=0d2ae13ceb55963aa4be94b65fb3d0194f92b411 ref=090931a24dbf871464ccb41eace5dc3b1f70f993
num-100 1.46 s (✅ 1.00x) 1.51 s (✅ 1.03x slower)
num-200 2.78 s (✅ 1.00x) 2.86 s (✅ 1.03x slower)

LEM Fibonacci Prove - rc = 600

ref=0d2ae13ceb55963aa4be94b65fb3d0194f92b411 ref=090931a24dbf871464ccb41eace5dc3b1f70f993
num-100 1.86 s (✅ 1.00x) 1.93 s (✅ 1.04x slower)
num-200 3.04 s (✅ 1.00x) 3.15 s (✅ 1.04x slower)

Made with criterion-table

Please sign in to comment.