Skip to content

Commit

Permalink
Validate improper TLS RSA certificates early-on
Browse files Browse the repository at this point in the history
This commit makes validation of TLS certificates with either too big RSA keys, or the wrong exponent, fail as soon as the remote node presents its TLS certificate, in contrast to after the TLS handshake.

Signed-off-by: Yacov Manevich <yacov.manevich@avalabs.org>
  • Loading branch information
yacovm committed Oct 24, 2024
1 parent dcb2f70 commit 86cde37
Show file tree
Hide file tree
Showing 3 changed files with 243 additions and 12 deletions.
23 changes: 23 additions & 0 deletions network/peer/tls_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
package peer

import (
"crypto/rsa"
"crypto/tls"
"errors"
"io"

"github.com/ava-labs/avalanchego/staking"
)

// TLSConfig returns the TLS config that will allow secure connections to other
Expand All @@ -26,5 +30,24 @@ func TLSConfig(cert tls.Certificate, keyLogWriter io.Writer) *tls.Config {
InsecureSkipVerify: true, //#nosec G402
MinVersion: tls.VersionTLS13,
KeyLogWriter: keyLogWriter,
VerifyConnection: validateRSACertificate,
}
}

func validateRSACertificate(cs tls.ConnectionState) error {
if len(cs.PeerCertificates) == 0 {
return errors.New("no certificates sent by peer")
}

pk := cs.PeerCertificates[0].PublicKey
if pk == nil {
return errors.New("no public key sent by peer")
}

switch rsaKey := pk.(type) {
case *rsa.PublicKey:
return staking.ValidateRSAPublicKeyIsWellFormed(rsaKey)
default:
return nil
}
}
202 changes: 202 additions & 0 deletions network/peer/upgrader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package peer_test

import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"math/big"
"net"
"sync"
"testing"
"time"

"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/require"

"github.com/ava-labs/avalanchego/network/peer"
"github.com/ava-labs/avalanchego/staking"
)

func TestBlockClientsWithIncorrectRSAKeys(t *testing.T) {
privKey2048, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

privKey4096, err := rsa.GenerateKey(rand.Reader, 4096)
require.NoError(t, err)

privKey8192, err := rsa.GenerateKey(rand.Reader, 8192)
require.NoError(t, err)

clientCert2048 := makeTLSCert(t, privKey2048)
clientCert4096 := makeTLSCert(t, privKey4096)
clientCert8192 := makeTLSCert(t, privKey8192)
clientCertBad := makeTLSCert(t, nonStandardRSAKey(t))

for _, testCase := range []struct {
description string
clientTLSCert tls.Certificate
shouldSucceed bool
expectedErr error
}{
{
description: "Proper key size and private key - 2048",
clientTLSCert: clientCert2048,
shouldSucceed: true,
},
{
description: "Proper key size and private key - 4096",
clientTLSCert: clientCert4096,
shouldSucceed: true,
},
{
description: "Too big key",
clientTLSCert: clientCert8192,
expectedErr: staking.ErrUnsupportedRSAModulusBitLen,
},
{
description: "Improper public exponent",
clientTLSCert: clientCertBad,
expectedErr: staking.ErrUnsupportedRSAPublicExponent,
},
} {
t.Run(testCase.description, func(t *testing.T) {
serverKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

serverCert := makeTLSCert(t, serverKey)

config := peer.TLSConfig(serverCert, nil)

c := prometheus.NewCounter(prometheus.CounterOpts{})

// Initialize upgrader with a mock that fails when it's incremented.
failOnIncrementCounter := &mockPrometheusCounter{
Counter: c,
t: t,
onIncrement: func() {
require.FailNow(t, "should not have invoked")
},
}
upgrader := peer.NewTLSServerUpgrader(config, failOnIncrementCounter)

clientConfig := tls.Config{
ClientAuth: tls.RequireAnyClientCert,
InsecureSkipVerify: true, //#nosec G402
MinVersion: tls.VersionTLS13,
Certificates: []tls.Certificate{testCase.clientTLSCert},
}

listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer listener.Close()

var wg sync.WaitGroup
wg.Add(1)

go func() {
defer wg.Done()
conn, err := listener.Accept()
require.NoError(t, err)

_, _, _, err = upgrader.Upgrade(conn)

if testCase.shouldSucceed {
require.NoError(t, err)
} else {
require.ErrorIs(t, err, testCase.expectedErr)
}
}()

conn, err := tls.Dial("tcp", listener.Addr().String(), &clientConfig)
require.NoError(t, err)

err = conn.Handshake()
require.NoError(t, err)

wg.Wait()
})
}
}

func nonStandardRSAKey(t *testing.T) *rsa.PrivateKey {
sk, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

sk.Precomputed = rsa.PrecomputedValues{}

// We want a non-standard E, so let's use E = 257 and derive D again.
e := 257
sk.PublicKey.E = e
sk.E = e

p := sk.Primes[0]
q := sk.Primes[1]

pminus1 := new(big.Int).Sub(p, big.NewInt(1))
qminus1 := new(big.Int).Sub(q, big.NewInt(1))

phiN := big.NewInt(0).Mul(pminus1, qminus1)

sk.D = big.NewInt(0).ModInverse(big.NewInt(int64(e)), phiN)

return sk
}

func makeTLSCert(t *testing.T, privKey *rsa.PrivateKey) tls.Certificate {
x509Cert := makeRSACertAndKey(t, privKey)

rawX509PEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: x509Cert.cert.Raw})
privateKeyInDER, err := x509.MarshalPKCS8PrivateKey(x509Cert.key)
require.NoError(t, err)

privateKeyInPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyInDER})

tlsCertServer, err := tls.X509KeyPair(rawX509PEM, privateKeyInPEM)
require.NoError(t, err)

return tlsCertServer
}

type certAndKey struct {
cert x509.Certificate
key *rsa.PrivateKey
}

func makeRSACertAndKey(t *testing.T, privKey *rsa.PrivateKey) certAndKey {
// Create a self-signed cert
basicCert := basicCert()
certBytes, err := x509.CreateCertificate(rand.Reader, basicCert, basicCert, &privKey.PublicKey, privKey)
require.NoError(t, err)

cert, err := x509.ParseCertificate(certBytes)
require.NoError(t, err)

return certAndKey{
cert: *cert,
key: privKey,
}
}

func basicCert() *x509.Certificate {
return &x509.Certificate{
SerialNumber: big.NewInt(0).SetInt64(100),
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour).UTC(),
BasicConstraintsValid: true,
}
}

type mockPrometheusCounter struct {
t *testing.T
prometheus.Counter
onIncrement func()
}

func (m *mockPrometheusCounter) Inc() {
m.onIncrement()
}
30 changes: 18 additions & 12 deletions staking/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,8 @@ func parsePublicKey(oid asn1.ObjectIdentifier, publicKey asn1.BitString) (crypto
return nil, ErrInvalidRSAPublicExponent
}

if pub.N.Sign() <= 0 {
return nil, ErrRSAModulusNotPositive
}

if bitLen := pub.N.BitLen(); bitLen != allowedRSALargeModulusLen && bitLen != allowedRSASmallModulusLen {
return nil, fmt.Errorf("%w: %d", ErrUnsupportedRSAModulusBitLen, bitLen)
}
if pub.N.Bit(0) == 0 {
return nil, ErrRSAModulusIsEven
}
if pub.E != allowedRSAPublicExponentValue {
return nil, fmt.Errorf("%w: %d", ErrUnsupportedRSAPublicExponent, pub.E)
if err := ValidateRSAPublicKeyIsWellFormed(pub); err != nil {
return nil, err
}
return pub, nil
case oid.Equal(oidPublicKeyECDSA):
Expand All @@ -165,3 +155,19 @@ func parsePublicKey(oid asn1.ObjectIdentifier, publicKey asn1.BitString) (crypto
return nil, ErrUnknownPublicKeyAlgorithm
}
}

func ValidateRSAPublicKeyIsWellFormed(pub *rsa.PublicKey) error {
if pub.N.Sign() <= 0 {
return ErrRSAModulusNotPositive
}
if bitLen := pub.N.BitLen(); bitLen != allowedRSALargeModulusLen && bitLen != allowedRSASmallModulusLen {
return fmt.Errorf("%w: %d", ErrUnsupportedRSAModulusBitLen, bitLen)
}
if pub.N.Bit(0) == 0 {
return ErrRSAModulusIsEven
}
if pub.E != allowedRSAPublicExponentValue {
return fmt.Errorf("%w: %d", ErrUnsupportedRSAPublicExponent, pub.E)
}
return nil
}

0 comments on commit 86cde37

Please sign in to comment.