From bc11d1ca4db672071c43bb671e4329d5fedbd99d Mon Sep 17 00:00:00 2001 From: Alex Richards Date: Mon, 22 Jul 2024 14:31:00 +1200 Subject: [PATCH] feat: support detached payloads in COSESign and COSESign1 --- bench_test.go | 180 +++++++++++++++++----------------- errors.go | 1 + sign.go | 58 +++++++++-- sign1.go | 110 ++++++++++++++++++--- sign1_test.go | 2 +- sign_test.go | 260 ++++++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 498 insertions(+), 113 deletions(-) diff --git a/bench_test.go b/bench_test.go index 9a95f96..a60d1aa 100644 --- a/bench_test.go +++ b/bench_test.go @@ -1,90 +1,90 @@ -package cose_test - -import ( - "io" - "testing" - - "github.com/veraison/go-cose" -) - -func newSign1Message() *cose.Sign1Message { - return &cose.Sign1Message{ - Headers: cose.Headers{ - Protected: cose.ProtectedHeader{ - cose.HeaderLabelAlgorithm: cose.AlgorithmES256, - }, - Unprotected: cose.UnprotectedHeader{ - cose.HeaderLabelKeyID: []byte{0x01}, - }, - }, - Payload: make([]byte, 100), - Signature: make([]byte, 32), - } -} - -type noSigner struct{} - -func (noSigner) Algorithm() cose.Algorithm { - return cose.AlgorithmES256 -} - -func (noSigner) Sign(_ io.Reader, digest []byte) ([]byte, error) { - return digest, nil -} - -func (noSigner) Verify(_, _ []byte) error { - return nil -} - -func BenchmarkSign1Message_MarshalCBOR(b *testing.B) { - msg := newSign1Message() - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, err := msg.MarshalCBOR() - if err != nil { - b.Fatal(err) - } - } -} - -func BenchmarkSign1Message_UnmarshalCBOR(b *testing.B) { - data, err := newSign1Message().MarshalCBOR() - if err != nil { - b.Fatal(err) - } - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - var m cose.Sign1Message - err = m.UnmarshalCBOR(data) - if err != nil { - b.Fatal(err) - } - } -} - -func BenchmarkSign1Message_Sign(b *testing.B) { - msg := newSign1Message() - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - msg.Signature = nil - err := msg.Sign(zeroSource{}, nil, noSigner{}) - if err != nil { - b.Fatal(err) - } - } -} - -func BenchmarkSign1Message_Verify(b *testing.B) { - msg := newSign1Message() - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - err := msg.Verify(nil, noSigner{}) - if err != nil { - b.Fatal(err) - } - } -} +package cose_test + +import ( + "io" + "testing" + + "github.com/veraison/go-cose" +) + +func newSign1Message() *cose.Sign1Message { + return &cose.Sign1Message{ + Headers: cose.Headers{ + Protected: cose.ProtectedHeader{ + cose.HeaderLabelAlgorithm: cose.AlgorithmES256, + }, + Unprotected: cose.UnprotectedHeader{ + cose.HeaderLabelKeyID: []byte{0x01}, + }, + }, + Payload: make([]byte, 100), + Signature: make([]byte, 32), + } +} + +type noSigner struct{} + +func (noSigner) Algorithm() cose.Algorithm { + return cose.AlgorithmES256 +} + +func (noSigner) Sign(_ io.Reader, digest []byte) ([]byte, error) { + return digest, nil +} + +func (noSigner) Verify(_, _ []byte) error { + return nil +} + +func BenchmarkSign1Message_MarshalCBOR(b *testing.B) { + msg := newSign1Message() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := msg.MarshalCBOR() + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkSign1Message_UnmarshalCBOR(b *testing.B) { + data, err := newSign1Message().MarshalCBOR() + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var m cose.Sign1Message + err = m.UnmarshalCBOR(data) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkSign1Message_Sign(b *testing.B) { + msg := newSign1Message() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + msg.Signature = nil + err := msg.Sign(zeroSource{}, nil, noSigner{}) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkSign1Message_Verify(b *testing.B) { + msg := newSign1Message() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := msg.Verify(nil, noSigner{}) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/errors.go b/errors.go index 770dc9e..88bd0b7 100644 --- a/errors.go +++ b/errors.go @@ -10,6 +10,7 @@ var ( ErrEmptySignature = errors.New("empty signature") ErrInvalidAlgorithm = errors.New("invalid algorithm") ErrMissingPayload = errors.New("missing payload") + ErrMultiplePayloads = errors.New("multiple payloads") ErrNoSignatures = errors.New("no signatures attached") ErrUnavailableHashFunc = errors.New("hash function is not available") ErrVerification = errors.New("verification error") diff --git a/sign.go b/sign.go index a2bb6c0..8700ae2 100644 --- a/sign.go +++ b/sign.go @@ -398,12 +398,17 @@ func (m *SignMessage) UnmarshalCBOR(data []byte) error { // Notice: The COSE Sign API is EXPERIMENTAL and may be changed or removed in a // later release. func (m *SignMessage) Sign(rand io.Reader, external []byte, signers ...Signer) error { + return m.sign(rand, external, nil, signers) +} + +func (m *SignMessage) SignDetached(rand io.Reader, external, detachedPayload []byte, signers ...Signer) error { + return m.sign(rand, external, detachedPayload, signers) +} + +func (m *SignMessage) sign(rand io.Reader, external, detachedPayload []byte, signers []Signer) error { if m == nil { return errors.New("signing nil SignMessage") } - if m.Payload == nil { - return ErrMissingPayload - } switch len(m.Signatures) { case 0: return ErrNoSignatures @@ -413,16 +418,21 @@ func (m *SignMessage) Sign(rand io.Reader, external []byte, signers ...Signer) e return fmt.Errorf("%d signers for %d signatures", len(signers), len(m.Signatures)) } + payload, err := resolvePayload(m.Payload, detachedPayload) + if err != nil { + return err + } + // populate common parameters var protected cbor.RawMessage - protected, err := m.Headers.MarshalProtected() + protected, err = m.Headers.MarshalProtected() if err != nil { return err } // sign message accordingly for i, signature := range m.Signatures { - if err := signature.Sign(rand, signers[i], protected, m.Payload, external); err != nil { + if err := signature.Sign(rand, signers[i], protected, payload, external); err != nil { return err } } @@ -443,12 +453,17 @@ func (m *SignMessage) Sign(rand io.Reader, external []byte, signers ...Signer) e // Notice: The COSE Sign API is EXPERIMENTAL and may be changed or removed in a // later release. func (m *SignMessage) Verify(external []byte, verifiers ...Verifier) error { + return m.verify(external, nil, verifiers...) +} + +func (m *SignMessage) VerifyDetached(external, detachedPayload []byte, verifiers ...Verifier) error { + return m.verify(external, detachedPayload, verifiers...) +} + +func (m *SignMessage) verify(external, detachedPayload []byte, verifiers ...Verifier) error { if m == nil { return errors.New("verifying nil SignMessage") } - if m.Payload == nil { - return ErrMissingPayload - } switch len(m.Signatures) { case 0: return ErrNoSignatures @@ -458,18 +473,41 @@ func (m *SignMessage) Verify(external []byte, verifiers ...Verifier) error { return fmt.Errorf("%d verifiers for %d signatures", len(verifiers), len(m.Signatures)) } + payload, err := resolvePayload(m.Payload, detachedPayload) + if err != nil { + return err + } + // populate common parameters var protected cbor.RawMessage - protected, err := m.Headers.MarshalProtected() + protected, err = m.Headers.MarshalProtected() if err != nil { return err } // verify message accordingly for i, signature := range m.Signatures { - if err := signature.Verify(verifiers[i], protected, m.Payload, external); err != nil { + if err := signature.Verify(verifiers[i], protected, payload, external); err != nil { return err } } return nil } + +func resolvePayload(payloads ...[]byte) ([]byte, error) { + var payload []byte + for _, candidatePayload := range payloads { + if candidatePayload != nil { + if payload == nil { + payload = candidatePayload + } else { + return nil, ErrMultiplePayloads + } + } + } + if payload == nil { + return nil, ErrMissingPayload + } else { + return payload, nil + } +} diff --git a/sign1.go b/sign1.go index e1bd4d0..7e02e46 100644 --- a/sign1.go +++ b/sign1.go @@ -87,26 +87,45 @@ func (m *Sign1Message) UnmarshalCBOR(data []byte) error { // // Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-4.4 func (m *Sign1Message) Sign(rand io.Reader, external []byte, signer Signer) error { + return m.sign(rand, external, nil, signer) +} + +// Sign signs a Sign1Message using the provided Signer. +// The signature is stored in m.Signature. +// +// Note that m.Signature is only valid as long as m.Headers.Protected +// remains unchanged after calling this method. +// It is possible to modify m.Headers.Unprotected after signing, +// i.e., add counter signatures or timestamps. +// +// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-4.4 +func (m *Sign1Message) SignDetached(rand io.Reader, external, detachedPayload []byte, signer Signer) error { + return m.sign(rand, external, detachedPayload, signer) +} + +func (m *Sign1Message) sign(rand io.Reader, external, detachedPayload []byte, signer Signer) error { if m == nil { return errors.New("signing nil Sign1Message") } - if m.Payload == nil { - return ErrMissingPayload - } if len(m.Signature) > 0 { return errors.New("Sign1Message signature already has signature bytes") } + payload, err := resolvePayload(m.Payload, detachedPayload) + if err != nil { + return err + } + // check algorithm if present. // `alg` header MUST be present if there is no externally supplied data. alg := signer.Algorithm() - err := m.Headers.ensureSigningAlgorithm(alg, external) + err = m.Headers.ensureSigningAlgorithm(alg, external) if err != nil { return err } // sign the message - toBeSigned, err := m.toBeSigned(external) + toBeSigned, err := m.toBeSigned(external, payload) if err != nil { return err } @@ -124,26 +143,40 @@ func (m *Sign1Message) Sign(rand io.Reader, external []byte, signer Signer) erro // // Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-4.4 func (m *Sign1Message) Verify(external []byte, verifier Verifier) error { + return m.verify(external, nil, verifier) +} + +// Verify verifies the signature on the Sign1Message returning nil on success or +// a suitable error if verification fails. +// +// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-4.4 +func (m *Sign1Message) VerifyDetached(external, detachedPayload []byte, verifier Verifier) error { + return m.verify(external, detachedPayload, verifier) +} + +func (m *Sign1Message) verify(external, detachedPayload []byte, verifier Verifier) error { if m == nil { return errors.New("verifying nil Sign1Message") } - if m.Payload == nil { - return ErrMissingPayload - } if len(m.Signature) == 0 { return ErrEmptySignature } + payload, err := resolvePayload(m.Payload, detachedPayload) + if err != nil { + return err + } + // check algorithm if present. // `alg` header MUST present if there is no externally supplied data. alg := verifier.Algorithm() - err := m.Headers.ensureVerificationAlgorithm(alg, external) + err = m.Headers.ensureVerificationAlgorithm(alg, external) if err != nil { return err } // verify the message - toBeSigned, err := m.toBeSigned(external) + toBeSigned, err := m.toBeSigned(external, payload) if err != nil { return err } @@ -153,7 +186,7 @@ func (m *Sign1Message) Verify(external []byte, verifier Verifier) error { // toBeSigned constructs Sig_structure, computes and returns ToBeSigned. // // Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-4.4 -func (m *Sign1Message) toBeSigned(external []byte) ([]byte, error) { +func (m *Sign1Message) toBeSigned(external []byte, payload []byte) ([]byte, error) { // create a Sig_structure and populate it with the appropriate fields. // // Sig_structure = [ @@ -178,7 +211,7 @@ func (m *Sign1Message) toBeSigned(external []byte) ([]byte, error) { "Signature1", // context protected, // body_protected external, // external_aad - m.Payload, // payload + payload, // payload } // create the value ToBeSigned by encoding the Sig_structure to a byte @@ -250,6 +283,22 @@ func Sign1(rand io.Reader, signer Signer, headers Headers, payload []byte, exter return msg.MarshalCBOR() } +// Sign1 signs a Sign1Message using the provided Signer. +// +// This method is a wrapper of `Sign1Message.SignDetached()`. +// +// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-4.4 +func Sign1Detached(rand io.Reader, signer Signer, headers Headers, detachedPayload []byte, external []byte) ([]byte, error) { + msg := Sign1Message{ + Headers: headers, + } + err := msg.SignDetached(rand, external, detachedPayload, signer) + if err != nil { + return nil, err + } + return msg.MarshalCBOR() +} + type UntaggedSign1Message Sign1Message // MarshalCBOR encodes UntaggedSign1Message into a COSE_Sign1 object. @@ -293,6 +342,19 @@ func (m *UntaggedSign1Message) Sign(rand io.Reader, external []byte, signer Sign return (*Sign1Message)(m).Sign(rand, external, signer) } +// Sign signs an UnttaggedSign1Message using the provided Signer. +// The signature is stored in m.Signature. +// +// Note that m.Signature is only valid as long as m.Headers.Protected +// remains unchanged after calling this method. +// It is possible to modify m.Headers.Unprotected after signing, +// i.e., add counter signatures or timestamps. +// +// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-4.4 +func (m *UntaggedSign1Message) SignDetached(rand io.Reader, external, detachedPayload []byte, signer Signer) error { + return (*Sign1Message)(m).SignDetached(rand, external, detachedPayload, signer) +} + // Verify verifies the signature on the UntaggedSign1Message returning nil on success or // a suitable error if verification fails. // @@ -301,6 +363,14 @@ func (m *UntaggedSign1Message) Verify(external []byte, verifier Verifier) error return (*Sign1Message)(m).Verify(external, verifier) } +// Verify verifies the signature on the UntaggedSign1Message returning nil on success or +// a suitable error if verification fails. +// +// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-4.4 +func (m *UntaggedSign1Message) VerifyDetached(external, detachedPayload []byte, verifier Verifier) error { + return (*Sign1Message)(m).VerifyDetached(external, detachedPayload, verifier) +} + // Sign1Untagged signs an UntaggedSign1Message using the provided Signer. // // This method is a wrapper of `UntaggedSign1Message.Sign()`. @@ -317,3 +387,19 @@ func Sign1Untagged(rand io.Reader, signer Signer, headers Headers, payload []byt } return msg.MarshalCBOR() } + +// Sign1Untagged signs an UntaggedSign1Message using the provided Signer. +// +// This method is a wrapper of `UntaggedSign1Message.SignDetached()`. +// +// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-4.4 +func Sign1UntaggedDetached(rand io.Reader, signer Signer, headers Headers, detachedPayload []byte, external []byte) ([]byte, error) { + msg := UntaggedSign1Message{ + Headers: headers, + } + err := msg.SignDetached(rand, external, detachedPayload, signer) + if err != nil { + return nil, err + } + return msg.MarshalCBOR() +} diff --git a/sign1_test.go b/sign1_test.go index 8bf2054..9260cfa 100644 --- a/sign1_test.go +++ b/sign1_test.go @@ -967,7 +967,7 @@ func TestSign1Message_toBeSigned(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.m.toBeSigned(tt.external) + got, err := tt.m.toBeSigned(tt.external, tt.m.Payload) if (err != nil) != tt.wantErr { t.Errorf("Sign1Message.toBeSigned() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/sign_test.go b/sign_test.go index 4b1f701..fd20d09 100644 --- a/sign_test.go +++ b/sign_test.go @@ -1968,6 +1968,133 @@ func TestSignMessage_Sign(t *testing.T) { }) } + // detached payloads + detachedTests := []struct { + name string + msg *SignMessage + detachedPayload []byte + wantErr string + }{ + { + name: "valid message", + msg: &SignMessage{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelContentType: "text/plain", + }, + Unprotected: UnprotectedHeader{ + "extra": "test", + }, + }, + Signatures: []*Signature{ + { + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: AlgorithmES256, + }, + Unprotected: UnprotectedHeader{ + HeaderLabelKeyID: []byte("42"), + }, + }, + }, + { + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: AlgorithmES512, + }, + }, + }, + }, + }, + detachedPayload: []byte("lorem ipsum"), + }, + { + name: "multiple payloads", + msg: &SignMessage{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelContentType: "text/plain", + }, + Unprotected: UnprotectedHeader{ + "extra": "test", + }, + }, + Payload: []byte("lorem ipsum"), + Signatures: []*Signature{ + { + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: AlgorithmES256, + }, + Unprotected: UnprotectedHeader{ + HeaderLabelKeyID: []byte("42"), + }, + }, + }, + { + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: AlgorithmES512, + }, + }, + }, + }, + }, + detachedPayload: []byte("lorem ipsum"), + wantErr: "multiple payloads", + }, + { + name: "missing payload", + msg: &SignMessage{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelContentType: "text/plain", + }, + Unprotected: UnprotectedHeader{ + "extra": "test", + }, + }, + Signatures: []*Signature{ + { + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: AlgorithmES256, + }, + Unprotected: UnprotectedHeader{ + HeaderLabelKeyID: []byte("42"), + }, + }, + }, + { + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: AlgorithmES512, + }, + }, + }, + }, + }, + wantErr: "missing payload", + }, + } + for _, tt := range detachedTests { + t.Run(tt.name, func(t *testing.T) { + err := tt.msg.SignDetached(rand.Reader, nil, tt.detachedPayload, signers...) + if err != nil { + if err.Error() != tt.wantErr { + t.Errorf("SignMessage.Sign() error = %v, wantErr %v", err, tt.wantErr) + } + return + } else if tt.wantErr != "" { + t.Errorf("SignMessage.Sign() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err := tt.msg.VerifyDetached(nil, tt.detachedPayload, verifiers...); err != nil { + t.Errorf("SignMessage.Verify() error = %v", err) + } + }) + } + // special cases t.Run("no signer", func(t *testing.T) { msg := &SignMessage{ @@ -2167,6 +2294,77 @@ func TestSignMessage_Verify(t *testing.T) { }) } + // detached payloads + detachedTests := []struct { + name string + detachedPayloadOnSign []byte + detachedPayloadOnVerify []byte + wantErr string + }{ + { + name: "round trip on valid detached message", + detachedPayloadOnSign: []byte("lorem ipsum"), + detachedPayloadOnVerify: []byte("lorem ipsum"), + }, + { + name: "missing payload", + detachedPayloadOnSign: []byte("lorem ipsum"), + wantErr: "missing payload", + }, + { + name: "changes payload", + detachedPayloadOnSign: []byte("lorem ipsum"), + detachedPayloadOnVerify: []byte("lorem ipsum dolor sit amet"), + wantErr: "verification error", + }, + } + for _, tt := range detachedTests { + t.Run(tt.name, func(t *testing.T) { + // generate message and sign + msg := &SignMessage{ + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelContentType: "text/plain", + }, + Unprotected: UnprotectedHeader{ + "extra": "test", + }, + }, + Signatures: []*Signature{ + { + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: AlgorithmES256, + }, + Unprotected: UnprotectedHeader{ + HeaderLabelKeyID: []byte("42"), + }, + }, + }, + { + Headers: Headers{ + Protected: ProtectedHeader{ + HeaderLabelAlgorithm: AlgorithmES512, + }, + }, + }, + }, + } + if err := msg.SignDetached(rand.Reader, nil, tt.detachedPayloadOnSign, signers...); err != nil { + t.Errorf("SignMessage.SignDetached() error = %v", err) + return + } + + // verify message + err := msg.VerifyDetached(nil, tt.detachedPayloadOnVerify, verifiers...) + if err != nil && (err.Error() != tt.wantErr) { + t.Errorf("SignMessage.VerifyDetached() error = %v, wantErr %v", err, tt.wantErr) + } else if err == nil && (tt.wantErr != "") { + t.Errorf("SignMessage.VerifyDetached() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } + // special cases t.Run("nil payload", func(t *testing.T) { // payload is detached msg := &SignMessage{ @@ -2310,3 +2508,65 @@ func TestSignature_toBeSigned(t *testing.T) { }) } } + +func TestSign_resolvePayload(t *testing.T) { + tests := []struct { + name string + payloads [][]byte + want []byte + wantErr error + }{ + { + name: "nil payloads", + wantErr: ErrMissingPayload, + }, + { + name: "empty payloads", + payloads: [][]byte{}, + wantErr: ErrMissingPayload, + }, + { + name: "single nil payload", + payloads: [][]byte{ + nil, + }, + wantErr: ErrMissingPayload, + }, + { + name: "single payload", + payloads: [][]byte{ + {1}, + }, + want: []byte{1}, + }, + { + name: "single payload with nil payloads", + payloads: [][]byte{ + nil, + {1}, + nil, + }, + want: []byte{1}, + }, + { + name: "multiple payloads", + payloads: [][]byte{ + {1}, + {2}, + }, + wantErr: ErrMultiplePayloads, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := resolvePayload(tt.payloads...) + if err != tt.wantErr { + t.Fatalf("resolvePayload: err = %v, wantErr = %v", err, tt.wantErr) + } + if !reflect.DeepEqual(tt.want, got) { + t.Fatalf("resolvePayload: got = %v, want = %v", got, tt.want) + } + }) + } +}