From 14e70f2891f45aed785ab78ba9ecb8197a5674d1 Mon Sep 17 00:00:00 2001 From: Piotr Dyraga Date: Fri, 23 Sep 2022 04:44:11 +0200 Subject: [PATCH] Improve DLN proof verification performance for large signing groups (#203) * Benchmark tests for DLN proof verification The DLN proof verification is one of the most expensive parts of the key generation protocol. This benchmark allows to check how expensive the Validate call is. * Control the concurrency level when verifying DLN proofs Control the concurrency level when verifying DLN proofs Verification of discrete logarithm proofs is the most expensive part of threshold ECDSA key generation for large groups. In round 2 of key generation, the local party needs to verify proofs received from all other parties. The cost of a call to `dlnProof.Verify` measured on Darwin/arm64 Apple M1 max is 341758528 ns/op - see `BenchmarkDLNProofVerification`. There are two proofs that need to be verified during the key generation so assuming there are two cores available exclusively for this work, the verification of 100 messages takes about 35 seconds. For a group size of 1000, the verification takes about 350 seconds. The verification is performed in separate goroutines with no control over the number of goroutines created. When executing the protocol locally, during the development, for a group size of 100, 100*99*2 = 19 800 goroutines for DLN proof verification are created more or less at the same time. Even such a powerful CPU as Apple M1 Max struggles with computing the proofs - it takes more than 16 minutes on all available cores and all other processes are starved. To optimize the code to allow it to be executed for larger groups, the number of goroutines created at the same time for DLN proof verification is throttled so that all other processes are not perpetually denied necessary CPU time to perform their work. This is achieved by introducing the `DlnProofVerifier` that limits the concurrency level, by default to the number of CPUs (cores) available. * Added benchmarks for DlnProofVerifier functions The benchmarks are promising and shows the the validator does not add any significant overhead over the DLN verification itself: BenchmarkDlnProof_Verify-10 3 342581417 ns/op 1766010 B/op 3790 allocs/op BenchmarkDlnVerifier_VerifyProof1-10 3 342741028 ns/op 1859093 B/op 4320 allocs/op BenchmarkDlnVerifier_VerifyProof2-10 3 341878361 ns/op 1851984 B/op 4311 allocs/op * Allow configuring key generation concurrency in params The concurrency defaults to `GOMAXPROCS` and can be updated with a call to `SetConcurrency`. The concurrency level is applied to the pre-params generator and DLN proof validator. Since there are two optional values now when constructing parameters, instead of passing safe prime gen timeout as the last value of `NewParameters`, all expected parameters should be configured with `Set*` functions. * Use DLN verifier for resharing protocol DLN proof verification is the most computationally expensive operation of the protocol when working in large groups. DLN verifier allows to throttle the number of goroutines verifying the proofs at the same time so that other processes do not get starved. DLN verifier is already applied to key generation protocol. Here, it is getting applied to resharing as well. * Always expect concurrency level to be passed to NewDlnProofVerifier The concurrency level is now available in all rounds constructing the verifier and the optional concurrency feature is never used. * Ensure that the concurrency level is non-zero `tss.NewParameters` is not validating provided values and I did not want to make a breaking change there. Instead, I added a comment to `SetConcurrency` and added a panic in the DLN proof verifier ensuring the protocol fails with a clear message instead of hanging. This is aligned with how `keygen.GeneratePreParamsWithContext` deals with an invalid value for `optionalConcurrency` param. * Log on DEBUG level concurrency level for DLN verification --- ecdsa/keygen/dln_verifier.go | 73 ++++++++ ecdsa/keygen/dln_verifier_test.go | 241 ++++++++++++++++++++++++++ ecdsa/keygen/round_1.go | 2 +- ecdsa/keygen/round_2.go | 29 +++- ecdsa/resharing/round_2_new_step_1.go | 2 +- ecdsa/resharing/round_4_new_step_2.go | 30 ++-- tss/params.go | 30 ++-- 7 files changed, 374 insertions(+), 33 deletions(-) create mode 100644 ecdsa/keygen/dln_verifier.go create mode 100644 ecdsa/keygen/dln_verifier_test.go diff --git a/ecdsa/keygen/dln_verifier.go b/ecdsa/keygen/dln_verifier.go new file mode 100644 index 00000000..cc6be8bf --- /dev/null +++ b/ecdsa/keygen/dln_verifier.go @@ -0,0 +1,73 @@ +// Copyright © 2019 Binance +// +// This file is part of Binance. The full Binance copyright notice, including +// terms governing use, modification, and redistribution, is contained in the +// file LICENSE at the root of the source code distribution tree. + +package keygen + +import ( + "errors" + "math/big" + + "github.com/bnb-chain/tss-lib/crypto/dlnproof" +) + +type DlnProofVerifier struct { + semaphore chan interface{} +} + +type message interface { + UnmarshalDLNProof1() (*dlnproof.Proof, error) + UnmarshalDLNProof2() (*dlnproof.Proof, error) +} + +func NewDlnProofVerifier(concurrency int) *DlnProofVerifier { + if concurrency == 0 { + panic(errors.New("NewDlnProofverifier: concurrency level must not be zero")) + } + + semaphore := make(chan interface{}, concurrency) + + return &DlnProofVerifier{ + semaphore: semaphore, + } +} + +func (dpv *DlnProofVerifier) VerifyDLNProof1( + m message, + h1, h2, n *big.Int, + onDone func(bool), +) { + dpv.semaphore <- struct{}{} + go func() { + defer func() { <-dpv.semaphore }() + + dlnProof, err := m.UnmarshalDLNProof1() + if err != nil { + onDone(false) + return + } + + onDone(dlnProof.Verify(h1, h2, n)) + }() +} + +func (dpv *DlnProofVerifier) VerifyDLNProof2( + m message, + h1, h2, n *big.Int, + onDone func(bool), +) { + dpv.semaphore <- struct{}{} + go func() { + defer func() { <-dpv.semaphore }() + + dlnProof, err := m.UnmarshalDLNProof2() + if err != nil { + onDone(false) + return + } + + onDone(dlnProof.Verify(h1, h2, n)) + }() +} diff --git a/ecdsa/keygen/dln_verifier_test.go b/ecdsa/keygen/dln_verifier_test.go new file mode 100644 index 00000000..738bbf6b --- /dev/null +++ b/ecdsa/keygen/dln_verifier_test.go @@ -0,0 +1,241 @@ +// Copyright © 2019 Binance +// +// This file is part of Binance. The full Binance copyright notice, including +// terms governing use, modification, and redistribution, is contained in the +// file LICENSE at the root of the source code distribution tree. + +package keygen + +import ( + "math/big" + "runtime" + "testing" + + "github.com/bnb-chain/tss-lib/crypto/dlnproof" +) + +func BenchmarkDlnProof_Verify(b *testing.B) { + localPartySaveData, _, err := LoadKeygenTestFixtures(1) + if err != nil { + b.Fatal(err) + } + + params := localPartySaveData[0].LocalPreParams + + proof := dlnproof.NewDLNProof( + params.H1i, + params.H2i, + params.Alpha, + params.P, + params.Q, + params.NTildei, + ) + + b.ResetTimer() + for n := 0; n < b.N; n++ { + proof.Verify(params.H1i, params.H2i, params.NTildei) + } +} + +func BenchmarkDlnVerifier_VerifyProof1(b *testing.B) { + preParams, proof := prepareProofB(b) + message := &KGRound1Message{ + Dlnproof_1: proof, + } + + verifier := NewDlnProofVerifier(runtime.GOMAXPROCS(0)) + + b.ResetTimer() + for n := 0; n < b.N; n++ { + resultChan := make(chan bool) + verifier.VerifyDLNProof1(message, preParams.H1i, preParams.H2i, preParams.NTildei, func(result bool) { + resultChan <- result + }) + <-resultChan + } +} + +func BenchmarkDlnVerifier_VerifyProof2(b *testing.B) { + preParams, proof := prepareProofB(b) + message := &KGRound1Message{ + Dlnproof_2: proof, + } + + verifier := NewDlnProofVerifier(runtime.GOMAXPROCS(0)) + + b.ResetTimer() + for n := 0; n < b.N; n++ { + resultChan := make(chan bool) + verifier.VerifyDLNProof2(message, preParams.H1i, preParams.H2i, preParams.NTildei, func(result bool) { + resultChan <- result + }) + <-resultChan + } +} + +func TestVerifyDLNProof1_Success(t *testing.T) { + preParams, proof := prepareProofT(t) + message := &KGRound1Message{ + Dlnproof_1: proof, + } + + verifier := NewDlnProofVerifier(runtime.GOMAXPROCS(0)) + + resultChan := make(chan bool) + + verifier.VerifyDLNProof1(message, preParams.H1i, preParams.H2i, preParams.NTildei, func(result bool) { + resultChan <- result + }) + + success := <-resultChan + if !success { + t.Fatal("expected positive verification") + } +} + +func TestVerifyDLNProof1_MalformedMessage(t *testing.T) { + preParams, proof := prepareProofT(t) + message := &KGRound1Message{ + Dlnproof_1: proof[:len(proof)-1], // truncate + } + + verifier := NewDlnProofVerifier(runtime.GOMAXPROCS(0)) + + resultChan := make(chan bool) + + verifier.VerifyDLNProof1(message, preParams.H1i, preParams.H2i, preParams.NTildei, func(result bool) { + resultChan <- result + }) + + success := <-resultChan + if success { + t.Fatal("expected negative verification") + } +} + +func TestVerifyDLNProof1_IncorrectProof(t *testing.T) { + preParams, proof := prepareProofT(t) + message := &KGRound1Message{ + Dlnproof_1: proof, + } + + verifier := NewDlnProofVerifier(runtime.GOMAXPROCS(0)) + + resultChan := make(chan bool) + + wrongH1i := preParams.H1i.Sub(preParams.H1i, big.NewInt(1)) + verifier.VerifyDLNProof1(message, wrongH1i, preParams.H2i, preParams.NTildei, func(result bool) { + resultChan <- result + }) + + success := <-resultChan + if success { + t.Fatal("expected negative verification") + } +} + +func TestVerifyDLNProof2_Success(t *testing.T) { + preParams, proof := prepareProofT(t) + message := &KGRound1Message{ + Dlnproof_2: proof, + } + + verifier := NewDlnProofVerifier(runtime.GOMAXPROCS(0)) + + resultChan := make(chan bool) + + verifier.VerifyDLNProof2(message, preParams.H1i, preParams.H2i, preParams.NTildei, func(result bool) { + resultChan <- result + }) + + success := <-resultChan + if !success { + t.Fatal("expected positive verification") + } +} + +func TestVerifyDLNProof2_MalformedMessage(t *testing.T) { + preParams, proof := prepareProofT(t) + message := &KGRound1Message{ + Dlnproof_2: proof[:len(proof)-1], // truncate + } + + verifier := NewDlnProofVerifier(runtime.GOMAXPROCS(0)) + + resultChan := make(chan bool) + + verifier.VerifyDLNProof2(message, preParams.H1i, preParams.H2i, preParams.NTildei, func(result bool) { + resultChan <- result + }) + + success := <-resultChan + if success { + t.Fatal("expected negative verification") + } +} + +func TestVerifyDLNProof2_IncorrectProof(t *testing.T) { + preParams, proof := prepareProofT(t) + message := &KGRound1Message{ + Dlnproof_2: proof, + } + + verifier := NewDlnProofVerifier(runtime.GOMAXPROCS(0)) + + resultChan := make(chan bool) + + wrongH2i := preParams.H2i.Add(preParams.H2i, big.NewInt(1)) + verifier.VerifyDLNProof2(message, preParams.H1i, wrongH2i, preParams.NTildei, func(result bool) { + resultChan <- result + }) + + success := <-resultChan + if success { + t.Fatal("expected negative verification") + } +} + +func prepareProofT(t *testing.T) (*LocalPreParams, [][]byte) { + preParams, serialized, err := prepareProof() + if err != nil { + t.Fatal(err) + } + + return preParams, serialized +} + +func prepareProofB(b *testing.B) (*LocalPreParams, [][]byte) { + preParams, serialized, err := prepareProof() + if err != nil { + b.Fatal(err) + } + + return preParams, serialized +} + +func prepareProof() (*LocalPreParams, [][]byte, error) { + localPartySaveData, _, err := LoadKeygenTestFixtures(1) + if err != nil { + return nil, [][]byte{}, err + } + + preParams := localPartySaveData[0].LocalPreParams + + proof := dlnproof.NewDLNProof( + preParams.H1i, + preParams.H2i, + preParams.Alpha, + preParams.P, + preParams.Q, + preParams.NTildei, + ) + + serialized, err := proof.Serialize() + if err != nil { + if err != nil { + return nil, [][]byte{}, err + } + } + + return &preParams, serialized, nil +} diff --git a/ecdsa/keygen/round_1.go b/ecdsa/keygen/round_1.go index 8eea72d4..9cfda2df 100644 --- a/ecdsa/keygen/round_1.go +++ b/ecdsa/keygen/round_1.go @@ -74,7 +74,7 @@ func (round *round1) Start() *tss.Error { } else if round.save.LocalPreParams.ValidateWithProof() { preParams = &round.save.LocalPreParams } else { - preParams, err = GeneratePreParams(round.SafePrimeGenTimeout(), 3) + preParams, err = GeneratePreParams(round.SafePrimeGenTimeout(), round.Concurrency()) if err != nil { return round.WrapError(errors.New("pre-params generation failed"), Pi) } diff --git a/ecdsa/keygen/round_2.go b/ecdsa/keygen/round_2.go index b41ec78c..e72b6be2 100644 --- a/ecdsa/keygen/round_2.go +++ b/ecdsa/keygen/round_2.go @@ -9,9 +9,9 @@ package keygen import ( "encoding/hex" "errors" - "math/big" "sync" + "github.com/bnb-chain/tss-lib/common" "github.com/bnb-chain/tss-lib/tss" ) @@ -27,6 +27,13 @@ func (round *round2) Start() *tss.Error { round.started = true round.resetOK() + common.Logger.Debugf( + "%s Setting up DLN verification with concurrency level of %d", + round.PartyID(), + round.Concurrency(), + ) + dlnVerifier := NewDlnProofVerifier(round.Concurrency()) + i := round.PartyID().Index // 6. verify dln proofs, store r1 message pieces, ensure uniqueness of h1j, h2j @@ -58,19 +65,23 @@ func (round *round2) Start() *tss.Error { return round.WrapError(errors.New("this h2j was already used by another party"), msg.GetFrom()) } h1H2Map[h1JHex], h1H2Map[h2JHex] = struct{}{}, struct{}{} + wg.Add(2) - go func(j int, msg tss.ParsedMessage, r1msg *KGRound1Message, H1j, H2j, NTildej *big.Int) { - if dlnProof1, err := r1msg.UnmarshalDLNProof1(); err != nil || !dlnProof1.Verify(H1j, H2j, NTildej) { - dlnProof1FailCulprits[j] = msg.GetFrom() + _j := j + _msg := msg + + dlnVerifier.VerifyDLNProof1(r1msg, H1j, H2j, NTildej, func(isValid bool) { + if !isValid { + dlnProof1FailCulprits[_j] = _msg.GetFrom() } wg.Done() - }(j, msg, r1msg, H1j, H2j, NTildej) - go func(j int, msg tss.ParsedMessage, r1msg *KGRound1Message, H1j, H2j, NTildej *big.Int) { - if dlnProof2, err := r1msg.UnmarshalDLNProof2(); err != nil || !dlnProof2.Verify(H2j, H1j, NTildej) { - dlnProof2FailCulprits[j] = msg.GetFrom() + }) + dlnVerifier.VerifyDLNProof2(r1msg, H2j, H1j, NTildej, func(isValid bool) { + if !isValid { + dlnProof2FailCulprits[_j] = _msg.GetFrom() } wg.Done() - }(j, msg, r1msg, H1j, H2j, NTildej) + }) } wg.Wait() for _, culprit := range append(dlnProof1FailCulprits, dlnProof2FailCulprits...) { diff --git a/ecdsa/resharing/round_2_new_step_1.go b/ecdsa/resharing/round_2_new_step_1.go index 365e20a3..c93c2861 100644 --- a/ecdsa/resharing/round_2_new_step_1.go +++ b/ecdsa/resharing/round_2_new_step_1.go @@ -49,7 +49,7 @@ func (round *round2) Start() *tss.Error { preParams = &round.save.LocalPreParams } else { var err error - preParams, err = keygen.GeneratePreParams(round.SafePrimeGenTimeout()) + preParams, err = keygen.GeneratePreParams(round.SafePrimeGenTimeout(), round.Concurrency()) if err != nil { return round.WrapError(errors.New("pre-params generation failed"), Pi) } diff --git a/ecdsa/resharing/round_4_new_step_2.go b/ecdsa/resharing/round_4_new_step_2.go index 40ef9448..9bff552d 100644 --- a/ecdsa/resharing/round_4_new_step_2.go +++ b/ecdsa/resharing/round_4_new_step_2.go @@ -18,6 +18,7 @@ import ( "github.com/bnb-chain/tss-lib/crypto" "github.com/bnb-chain/tss-lib/crypto/commitments" "github.com/bnb-chain/tss-lib/crypto/vss" + "github.com/bnb-chain/tss-lib/ecdsa/keygen" "github.com/bnb-chain/tss-lib/tss" ) @@ -36,6 +37,13 @@ func (round *round4) Start() *tss.Error { return nil } + common.Logger.Debugf( + "%s Setting up DLN verification with concurrency level of %d", + round.PartyID(), + round.Concurrency(), + ) + dlnVerifier := keygen.NewDlnProofVerifier(round.Concurrency()) + Pi := round.PartyID() i := Pi.Index @@ -71,20 +79,22 @@ func (round *round4) Start() *tss.Error { } wg.Done() }(j, msg, r2msg1) - go func(j int, msg tss.ParsedMessage, r2msg1 *DGRound2Message1, H1j, H2j, NTildej *big.Int) { - if dlnProof1, err := r2msg1.UnmarshalDLNProof1(); err != nil || !dlnProof1.Verify(H1j, H2j, NTildej) { - dlnProof1FailCulprits[j] = msg.GetFrom() - common.Logger.Warningf("dln proof 1 verify failed for party %s", msg.GetFrom(), err) + _j := j + _msg := msg + dlnVerifier.VerifyDLNProof1(r2msg1, H1j, H2j, NTildej, func(isValid bool) { + if !isValid { + dlnProof1FailCulprits[_j] = _msg.GetFrom() + common.Logger.Warningf("dln proof 1 verify failed for party %s", _msg.GetFrom()) } wg.Done() - }(j, msg, r2msg1, H1j, H2j, NTildej) - go func(j int, msg tss.ParsedMessage, r2msg1 *DGRound2Message1, H1j, H2j, NTildej *big.Int) { - if dlnProof2, err := r2msg1.UnmarshalDLNProof2(); err != nil || !dlnProof2.Verify(H2j, H1j, NTildej) { - dlnProof2FailCulprits[j] = msg.GetFrom() - common.Logger.Warningf("dln proof 2 verify failed for party %s", msg.GetFrom(), err) + }) + dlnVerifier.VerifyDLNProof2(r2msg1, H2j, H1j, NTildej, func(isValid bool) { + if !isValid { + dlnProof2FailCulprits[_j] = _msg.GetFrom() + common.Logger.Warningf("dln proof 2 verify failed for party %s", _msg.GetFrom()) } wg.Done() - }(j, msg, r2msg1, H1j, H2j, NTildej) + }) } wg.Wait() for _, culprit := range append(append(paiProofCulprits, dlnProof1FailCulprits...), dlnProof2FailCulprits...) { diff --git a/tss/params.go b/tss/params.go index 8cb33368..8bf74148 100644 --- a/tss/params.go +++ b/tss/params.go @@ -8,7 +8,7 @@ package tss import ( "crypto/elliptic" - "errors" + "runtime" "time" ) @@ -19,6 +19,7 @@ type ( parties *PeerContext partyCount int threshold int + concurrency int safePrimeGenTimeout time.Duration } @@ -35,23 +36,15 @@ const ( ) // Exported, used in `tss` client -func NewParameters(ec elliptic.Curve, ctx *PeerContext, partyID *PartyID, partyCount, threshold int, optionalSafePrimeGenTimeout ...time.Duration) *Parameters { - var safePrimeGenTimeout time.Duration - if 0 < len(optionalSafePrimeGenTimeout) { - if 1 < len(optionalSafePrimeGenTimeout) { - panic(errors.New("GeneratePreParams: expected 0 or 1 item in `optionalSafePrimeGenTimeout`")) - } - safePrimeGenTimeout = optionalSafePrimeGenTimeout[0] - } else { - safePrimeGenTimeout = defaultSafePrimeGenTimeout - } +func NewParameters(ec elliptic.Curve, ctx *PeerContext, partyID *PartyID, partyCount, threshold int) *Parameters { return &Parameters{ ec: ec, parties: ctx, partyID: partyID, partyCount: partyCount, threshold: threshold, - safePrimeGenTimeout: safePrimeGenTimeout, + concurrency: runtime.GOMAXPROCS(0), + safePrimeGenTimeout: defaultSafePrimeGenTimeout, } } @@ -75,10 +68,23 @@ func (params *Parameters) Threshold() int { return params.threshold } +func (params *Parameters) Concurrency() int { + return params.concurrency +} + func (params *Parameters) SafePrimeGenTimeout() time.Duration { return params.safePrimeGenTimeout } +// The concurrency level must be >= 1. +func (params *Parameters) SetConcurrency(concurrency int) { + params.concurrency = concurrency +} + +func (params *Parameters) SetSafePrimeGenTimeout(timeout time.Duration) { + params.safePrimeGenTimeout = timeout +} + // ----- // // Exported, used in `tss` client