From b252e1038d3b464d16e349eb8c61b75d214556d4 Mon Sep 17 00:00:00 2001 From: Nick Rosbrook Date: Fri, 23 Aug 2024 11:41:18 -0400 Subject: [PATCH] message: implement "packet" directly in Message type This really does not need to be its own type: it's really just a message header. Absorb this functionality into the Message type directly to consolidate. --- vici/client_conn.go | 44 ++++---- vici/events.go | 32 +++--- vici/message.go | 114 +++++++++++++++++++-- vici/message_test.go | 125 +++++++++++++++++++++++ vici/message_unmarshal_test.go | 68 ++++++------- vici/packet.go | 177 --------------------------------- vici/packet_test.go | 137 ------------------------- vici/session.go | 48 ++++++--- 8 files changed, 345 insertions(+), 400 deletions(-) delete mode 100644 vici/packet.go delete mode 100644 vici/packet_test.go diff --git a/vici/client_conn.go b/vici/client_conn.go index 361e7d3..bd6f962 100644 --- a/vici/client_conn.go +++ b/vici/client_conn.go @@ -25,6 +25,7 @@ import ( "context" "encoding/binary" "errors" + "fmt" "io" "net" "time" @@ -46,7 +47,7 @@ type clientConn struct { conn net.Conn } -func (cc *clientConn) packetWrite(ctx context.Context, p *packet) error { +func (cc *clientConn) packetWrite(ctx context.Context, m *Message) error { if err := cc.conn.SetWriteDeadline(time.Time{}); err != nil { return err } @@ -55,7 +56,7 @@ func (cc *clientConn) packetWrite(ctx context.Context, p *packet) error { case <-ctx.Done(): err := cc.conn.SetWriteDeadline(time.Now()) return errors.Join(err, ctx.Err()) - case err := <-cc.awaitPacketWrite(p): + case err := <-cc.awaitPacketWrite(m): if err != nil { return err } @@ -63,7 +64,7 @@ func (cc *clientConn) packetWrite(ctx context.Context, p *packet) error { } } -func (cc *clientConn) packetRead(ctx context.Context) (*packet, error) { +func (cc *clientConn) packetRead(ctx context.Context) (*Message, error) { if err := cc.conn.SetReadDeadline(time.Time{}); err != nil { return nil, err } @@ -72,21 +73,26 @@ func (cc *clientConn) packetRead(ctx context.Context) (*packet, error) { case <-ctx.Done(): err := cc.conn.SetReadDeadline(time.Now()) return nil, errors.Join(err, ctx.Err()) - case p := <-cc.awaitPacketRead(): - if p.err != nil { - return nil, p.err + case v := <-cc.awaitPacketRead(): + switch v.(type) { + case error: + return nil, v.(error) + case *Message: + return v.(*Message), nil + default: + // This is a programmer error. + return nil, fmt.Errorf("%v: invalid packet read", errEncoding) } - return p, nil } } -func (cc *clientConn) awaitPacketWrite(p *packet) <-chan error { +func (cc *clientConn) awaitPacketWrite(m *Message) <-chan error { r := make(chan error, 1) buf := bytes.NewBuffer([]byte{}) go func() { defer close(r) - b, err := p.bytes() + b, err := m.encode() if err != nil { r <- err return @@ -111,18 +117,17 @@ func (cc *clientConn) awaitPacketWrite(p *packet) <-chan error { return r } -func (cc *clientConn) awaitPacketRead() <-chan *packet { - r := make(chan *packet, 1) +func (cc *clientConn) awaitPacketRead() <-chan any { + r := make(chan any, 1) go func() { defer close(r) - p := &packet{} + m := NewMessage() buf := make([]byte, headerLength) _, err := io.ReadFull(cc.conn, buf) if err != nil { - p.err = err - r <- p + r <- err return } pl := binary.BigEndian.Uint32(buf) @@ -130,13 +135,16 @@ func (cc *clientConn) awaitPacketRead() <-chan *packet { buf = make([]byte, int(pl)) _, err = io.ReadFull(cc.conn, buf) if err != nil { - p.err = err - r <- p + r <- err + return + } + + if err := m.decode(buf); err != nil { + r <- err return } - p.parse(buf) - r <- p + r <- m }() return r diff --git a/vici/events.go b/vici/events.go index 9811c76..1881cbb 100644 --- a/vici/events.go +++ b/vici/events.go @@ -37,7 +37,7 @@ type eventListener struct { // Packet channel used to communicate event registration // results. - pc chan *packet + pc chan *Message muChans sync.Mutex chans map[chan<- Event]struct{} @@ -65,7 +65,7 @@ type Event struct { func newEventListener(cc *clientConn) *eventListener { el := &eventListener{ cc: cc, - pc: make(chan *packet, 4), + pc: make(chan *Message, 4), chans: make(map[chan<- Event]struct{}), } @@ -93,18 +93,18 @@ func (el *eventListener) listen() { defer el.closeAllChans() for { - p, err := el.cc.packetRead(context.Background()) + m, err := el.cc.packetRead(context.Background()) if err != nil { return } ts := time.Now() - switch p.ptype { + switch m.header.ptype { case pktEvent: e := Event{ - Name: p.name, - Message: p.msg, + Name: m.header.name, + Message: m, Timestamp: ts, } @@ -114,7 +114,7 @@ func (el *eventListener) listen() { // requests from the event listener. Forward them over // the packet channel. case pktEventConfirm, pktEventUnknown: - el.pc <- p + el.pc <- m } } } @@ -219,25 +219,33 @@ func (el *eventListener) unregisterEvents(events []string, all bool) error { } func (el *eventListener) eventRequest(ptype uint8, event string) error { - p := newPacket(ptype, event, nil) + m := &Message{ + header: &struct { + ptype uint8 + name string + }{ + ptype: ptype, + name: event, + }, + } - if err := el.cc.packetWrite(context.Background(), p); err != nil { + if err := el.cc.packetWrite(context.Background(), m); err != nil { return err } // The response packet is read by listen(), and written over pc. - p, ok := <-el.pc + m, ok := <-el.pc if !ok { return io.ErrClosedPipe } - switch p.ptype { + switch m.header.ptype { case pktEventConfirm: return nil case pktEventUnknown: return fmt.Errorf("%v: %v", errEventUnknown, event) default: - return fmt.Errorf("%v: %v", errUnexpectedResponse, p.ptype) + return fmt.Errorf("%v: %v", errUnexpectedResponse, m.header.ptype) } } diff --git a/vici/message.go b/vici/message.go index dd1df55..1800a96 100644 --- a/vici/message.go +++ b/vici/message.go @@ -52,6 +52,32 @@ const ( msgListEnd ) +const ( + // A name request message + pktCmdRequest uint8 = iota + + // An unnamed response message for a request + pktCmdResponse + + // An unnamed response if requested command is unknown + pktCmdUnkown + + // A named event registration request + pktEventRegister + + // A name event deregistration request + pktEventUnregister + + // An unnamed response for successful event (de-)registration + pktEventConfirm + + // An unnamed response if event (de-)registration failed + pktEventUnknown + + // A named event message + pktEvent +) + var ( // Generic encoding/decoding and marshaling/unmarshaling errors errEncoding = errors.New("vici: error encoding message") @@ -69,6 +95,7 @@ var ( errMalformedMessage = errors.New("vici: malformed message") // Malformed message errors + errBadName = fmt.Errorf("%v: expected name length does not match actual length", errDecoding) errBadKey = fmt.Errorf("%v: expected key length does not match actual length", errMalformedMessage) errBadValue = fmt.Errorf("%v: expected value length does not match actual length", errMalformedMessage) errEndOfBuffer = fmt.Errorf("%v: unexpected end of buffer", errMalformedMessage) @@ -97,16 +124,21 @@ var ( // for convenience, and may have rules on how they are converted to an appropriate internal message // element type. See Message.Set and MarshalMessage for details. type Message struct { + // Packet header. Set only for reading and writing message packets. + header *struct { + ptype uint8 + name string + } keys []string - data map[string]any } // NewMessage returns an empty Message. func NewMessage() *Message { return &Message{ - keys: make([]string, 0), - data: make(map[string]any), + header: nil, + keys: make([]string, 0), + data: make(map[string]any), } } @@ -225,6 +257,33 @@ func (m *Message) Err() error { return nil } +// packetIsNamed returns a bool indicating the packet is a named type +func (m *Message) packetIsNamed() bool { + if m.header == nil { + return false + } + + switch m.header.ptype { + case /* Named packet types */ + pktCmdRequest, + pktEventRegister, + pktEventUnregister, + pktEvent: + + return true + + case /* Un-named packet types */ + pktCmdResponse, + pktCmdUnkown, + pktEventConfirm, + pktEventUnknown: + + return false + } + + return false +} + func (m *Message) addItem(key string, value any) error { // Check if the key is already set in the message _, exists := m.data[key] @@ -317,6 +376,24 @@ func safePutUint32(buf *bytes.Buffer, val int) error { func (m *Message) encode() ([]byte, error) { buf := bytes.NewBuffer([]byte{}) + if m.header != nil { + if err := buf.WriteByte(m.header.ptype); err != nil { + return nil, fmt.Errorf("%v: %v", errEncoding, err) + } + + if m.packetIsNamed() { + err := safePutUint8(buf, len(m.header.name)) + if err != nil { + return nil, fmt.Errorf("%v: %v", errEncoding, err) + } + + _, err = buf.WriteString(m.header.name) + if err != nil { + return nil, fmt.Errorf("%v: %v", errEncoding, err) + } + } + } + for k, v := range m.elements() { rv := reflect.ValueOf(v) @@ -369,14 +446,38 @@ func (m *Message) encode() ([]byte, error) { } func (m *Message) decode(data []byte) error { + m.header = &struct { + ptype uint8 + name string + }{} buf := bytes.NewBuffer(data) + // Parse the message header first. b, err := buf.ReadByte() - if err != nil && err != io.EOF { + if err != nil { return fmt.Errorf("%v: %v", errDecoding, err) } + m.header.ptype = b + + if m.packetIsNamed() { + l, err := buf.ReadByte() + if err != nil { + return fmt.Errorf("%v: %v", errDecoding, err) + } + + if name := buf.Next(int(l)); len(name) != int(l) { + return errBadName + } else { + m.header.name = string(name) + } + } for buf.Len() > 0 { + b, err = buf.ReadByte() + if err != nil && err != io.EOF { + return fmt.Errorf("%v: %v", errDecoding, err) + } + // Determine the next message element switch b { case msgKeyValue: @@ -400,11 +501,6 @@ func (m *Message) decode(data []byte) error { } buf.Next(n) } - - b, err = buf.ReadByte() - if err != nil && err != io.EOF { - return fmt.Errorf("%v: %v", errDecoding, err) - } } return nil diff --git a/vici/message_test.go b/vici/message_test.go index 9d80ea5..46ab89c 100644 --- a/vici/message_test.go +++ b/vici/message_test.go @@ -29,8 +29,66 @@ import ( ) var ( + goldNamedPacket = &Message{ + header: &struct { + ptype uint8 + name string + }{ + ptype: pktCmdRequest, + name: "install", + }, + keys: []string{"child", "ike"}, + data: map[string]any{ + "child": "test-CHILD_SA", + "ike": "test-IKE_SA", + }, + } + + goldNamedPacketBytes = []byte{ + // Packet type + 0, + // Length of "install" + 7, + // "install" in bytes + 105, 110, 115, 116, 97, 108, 108, + // Encoded message bytes + 3, 5, 99, 104, 105, 108, 100, 0, 13, 116, 101, 115, 116, + 45, 67, 72, 73, 76, 68, 95, 83, 65, 3, 3, 105, 107, 101, + 0, 11, 116, 101, 115, 116, 45, 73, 75, 69, 95, 83, 65, + } + + goldUnnamedPacket = &Message{ + header: &struct { + ptype uint8 + name string + }{ + ptype: pktCmdResponse, + }, + keys: []string{"success", "errmsg"}, + data: map[string]any{ + "success": "no", + "errmsg": "failed to install CHILD_SA", + }, + } + + goldUnnamedPacketBytes = []byte{ + // Packet type + 1, + // Encoded message bytes + 3, 7, 115, 117, 99, 99, 101, 115, 115, 0, 2, 110, 111, 3, 6, + 101, 114, 114, 109, 115, 103, 0, 26, 102, 97, 105, 108, 101, + 100, 32, 116, 111, 32, 105, 110, 115, 116, 97, 108, 108, 32, + 67, 72, 73, 76, 68, 95, 83, 65, + } + // Gold message goldMessage = &Message{ + header: &struct { + ptype uint8 + name string + }{ + ptype: pktCmdResponse, + }, keys: []string{"key1", "section1"}, data: map[string]any{ "key1": "value1", @@ -53,6 +111,8 @@ var ( // Expected byte stream from encoding testMessage goldMessageBytes = []byte{ + // pktCmdResponse + 1, // key1 = value1 3, 4, 'k', 'e', 'y', '1', 0, 6, 'v', 'a', 'l', 'u', 'e', '1', // section1 @@ -123,6 +183,71 @@ var ( } ) +func TestPacketParse(t *testing.T) { + m := NewMessage() + + if err := m.decode(goldNamedPacketBytes); err != nil { + t.Fatalf("Error parsing packet: %v", err) + } + + if !reflect.DeepEqual(m, goldNamedPacket) { + t.Fatalf("Parsed named packet does not equal gold packet.\nExpected: %v\nReceived: %v", goldNamedPacket, m) + } + + m = NewMessage() + + if err := m.decode(goldUnnamedPacketBytes); err != nil { + t.Fatalf("Error parsing packet: %v", err) + } + + if !reflect.DeepEqual(m, goldUnnamedPacket) { + t.Fatalf("Parsed unnamed packet does not equal gold packet.\nExpected: %v\nReceived: %v", goldUnnamedPacket, m) + } +} + +func TestPacketBytes(t *testing.T) { + b, err := goldNamedPacket.encode() + if err != nil { + t.Fatalf("Unexpected error getting packet bytes: %v", err) + } + + if !bytes.Equal(b, goldNamedPacketBytes) { + t.Fatalf("Encoded packet does not equal gold bytes.\nExpected: %v\nReceived: %v", goldNamedPacketBytes, b) + } + + b, err = goldUnnamedPacket.encode() + if err != nil { + t.Fatalf("Unexpected error getting packet bytes: %v", err) + } + + if !bytes.Equal(b, goldUnnamedPacketBytes) { + t.Fatalf("Encoded packet does not equal gold bytes.\nExpected: %v\nReceived: %v", goldUnnamedPacketBytes, b) + } +} + +func TestPacketTooLong(t *testing.T) { + tooLong := make([]byte, 256) + + for i := range tooLong { + tooLong[i] = 'a' + } + + m := &Message{ + header: &struct { + ptype uint8 + name string + }{ + ptype: pktCmdRequest, + name: string(tooLong), + }, + } + + _, err := m.encode() + if err == nil { + t.Fatalf("Expected packet-too-long error due to %s", m.header.name) + } +} + type testMessage struct { Key string `vici:"key"` Empty string `vici:"empty"` diff --git a/vici/message_unmarshal_test.go b/vici/message_unmarshal_test.go index 719cc1f..cabb105 100644 --- a/vici/message_unmarshal_test.go +++ b/vici/message_unmarshal_test.go @@ -32,8 +32,8 @@ func TestUnmarshalBoolTrue(t *testing.T) { } m := &Message{ - []string{"field"}, - map[string]any{ + keys: []string{"field"}, + data: map[string]any{ "field": "yes", }, } @@ -56,8 +56,8 @@ func TestUnmarshalBoolFalse(t *testing.T) { } m := &Message{ - []string{"field"}, - map[string]any{ + keys: []string{"field"}, + data: map[string]any{ "field": "no", }, } @@ -80,8 +80,8 @@ func TestUnmarshalBoolInvalid(t *testing.T) { } m := &Message{ - []string{"field"}, - map[string]any{ + keys: []string{"field"}, + data: map[string]any{ "field": "invalid-not-a-bool", }, } @@ -100,8 +100,8 @@ func TestUnmarshalBoolTruePtr(t *testing.T) { } m := &Message{ - []string{"field"}, - map[string]any{ + keys: []string{"field"}, + data: map[string]any{ "field": "yes", }, } @@ -128,8 +128,8 @@ func TestUnmarshalBoolFalsePtr(t *testing.T) { } m := &Message{ - []string{"field"}, - map[string]any{ + keys: []string{"field"}, + data: map[string]any{ "field": "no", }, } @@ -156,8 +156,8 @@ func TestUnmarshalInt(t *testing.T) { } m := &Message{ - []string{"field"}, - map[string]any{ + keys: []string{"field"}, + data: map[string]any{ "field": "23", }, } @@ -180,8 +180,8 @@ func TestUnmarshalInt2(t *testing.T) { } m := &Message{ - []string{"field"}, - map[string]any{ + keys: []string{"field"}, + data: map[string]any{ "field": "-23", }, } @@ -204,8 +204,8 @@ func TestUnmarshalInt8(t *testing.T) { } m := &Message{ - []string{"field"}, - map[string]any{ + keys: []string{"field"}, + data: map[string]any{ "field": "23", }, } @@ -228,8 +228,8 @@ func TestUnmarshalInt8Overflow(t *testing.T) { } m := &Message{ - []string{"field"}, - map[string]any{ + keys: []string{"field"}, + data: map[string]any{ "field": "1001", }, } @@ -252,8 +252,8 @@ func TestUnmarshalUint(t *testing.T) { } m := &Message{ - []string{"field"}, - map[string]any{ + keys: []string{"field"}, + data: map[string]any{ "field": "23", }, } @@ -276,8 +276,8 @@ func TestUnmarshalUintInvalid(t *testing.T) { } m := &Message{ - []string{"field"}, - map[string]any{ + keys: []string{"field"}, + data: map[string]any{ "field": "-1", }, } @@ -297,8 +297,8 @@ func TestUnmarshalEnumType(t *testing.T) { }{} m := &Message{ - []string{"field"}, - map[string]any{ + keys: []string{"field"}, + data: map[string]any{ "field": "test-value", }, } @@ -325,11 +325,11 @@ func TestUnmarshalEmbeddedStruct(t *testing.T) { }{} m := &Message{ - []string{"embedded"}, - map[string]any{ + keys: []string{"embedded"}, + data: map[string]any{ "embedded": &Message{ - []string{"field"}, - map[string]any{ + keys: []string{"field"}, + data: map[string]any{ "field": testValue, }, }, @@ -358,8 +358,8 @@ func TestUnmarshalInline(t *testing.T) { }{} m := &Message{ - []string{"field"}, - map[string]any{ + keys: []string{"field"}, + data: map[string]any{ "field": testValue, }, } @@ -380,8 +380,8 @@ func TestUnmarshalInlineInvalidType(t *testing.T) { }{} m := &Message{ - []string{"field"}, - map[string]any{ + keys: []string{"field"}, + data: map[string]any{ "field": "test-value", }, } @@ -406,8 +406,8 @@ func TestUnmarshalInlineComposite(t *testing.T) { }{} m := &Message{ - []string{"field", "other"}, - map[string]any{ + keys: []string{"field", "other"}, + data: map[string]any{ "field": testValue, "other": otherValue, }, diff --git a/vici/packet.go b/vici/packet.go deleted file mode 100644 index c6649f1..0000000 --- a/vici/packet.go +++ /dev/null @@ -1,177 +0,0 @@ -// Copyright (C) 2019 Nick Rosbrook -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -package vici - -import ( - "bytes" - "errors" - "fmt" -) - -const ( - // A name request message - pktCmdRequest uint8 = iota - - // An unnamed response message for a request - pktCmdResponse - - // An unnamed response if requested command is unknown - pktCmdUnkown - - // A named event registration request - pktEventRegister - - // A name event deregistration request - pktEventUnregister - - // An unnamed response for successful event (de-)registration - pktEventConfirm - - // An unnamed response if event (de-)registration failed - pktEventUnknown - - // A named event message - pktEvent -) - -var ( - // Generic packet writing error - errPacketWrite = errors.New("vici: error writing packet") - - // Generic packet parsing error - errPacketParse = errors.New("vici: error parsing packet") - - errBadName = fmt.Errorf("%v: expected name length does not match actual length", errPacketParse) -) - -// A packet has a required type (an 8-bit identifier), a name (only required for named types), -// and and an optional message field. -type packet struct { - ptype uint8 - name string - - msg *Message - err error -} - -func newPacket(ptype uint8, name string, msg *Message) *packet { - return &packet{ - ptype: ptype, - name: name, - msg: msg, - } -} - -// isNamed returns a bool indicating the packet is a named type -func (p *packet) isNamed() bool { - switch p.ptype { - case /* Named packet types */ - pktCmdRequest, - pktEventRegister, - pktEventUnregister, - pktEvent: - - return true - - case /* Un-named packet types */ - pktCmdResponse, - pktCmdUnkown, - pktEventConfirm, - pktEventUnknown: - - return false - } - - return false -} - -// bytes formats the packet and returns it as a byte slice -func (p *packet) bytes() ([]byte, error) { - // Create a new buffer with the first byte indicating the packet type - buf := bytes.NewBuffer([]byte{p.ptype}) - - // Write the name, preceded by its length - if p.isNamed() { - err := safePutUint8(buf, len(p.name)) - if err != nil { - return nil, fmt.Errorf("%v: %v", errPacketWrite, err) - } - - _, err = buf.WriteString(p.name) - if err != nil { - return nil, fmt.Errorf("%v: %v", errPacketWrite, err) - } - } - - if p.msg != nil { - b, err := p.msg.encode() - if err != nil { - return nil, err - } - - _, err = buf.Write(b) - if err != nil { - return nil, fmt.Errorf("%v: %v", errPacketWrite, err) - } - } - - return buf.Bytes(), nil -} - -// parse will parse the given bytes and populate its fields with that data. If -// there is an error, then the error filed will be set. -func (p *packet) parse(data []byte) { - buf := bytes.NewBuffer(data) - - // Read the packet type - b, err := buf.ReadByte() - if err != nil { - p.err = fmt.Errorf("%v: %v", errPacketParse, err) - return - } - p.ptype = b - - if p.isNamed() { - // Get the length of the name - l, err := buf.ReadByte() - if err != nil { - p.err = fmt.Errorf("%v: %v", errPacketParse, err) - return - } - - // Read the name - name := buf.Next(int(l)) - if len(name) != int(l) { - p.err = errBadName - return - } - p.name = string(name) - } - - // Decode the message field - m := NewMessage() - err = m.decode(buf.Bytes()) - if err != nil { - p.err = err - return - } - p.msg = m -} diff --git a/vici/packet_test.go b/vici/packet_test.go deleted file mode 100644 index c16301d..0000000 --- a/vici/packet_test.go +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright (C) 2019 Nick Rosbrook -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -package vici - -import ( - "bytes" - "reflect" - "testing" -) - -var ( - goldNamedPacket = &packet{ - ptype: pktCmdRequest, - name: "install", - msg: &Message{ - keys: []string{"child", "ike"}, - data: map[string]any{ - "child": "test-CHILD_SA", - "ike": "test-IKE_SA", - }, - }, - } - - goldNamedPacketBytes = []byte{ - // Packet type - 0, - // Length of "install" - 7, - // "install" in bytes - 105, 110, 115, 116, 97, 108, 108, - // Encoded message bytes - 3, 5, 99, 104, 105, 108, 100, 0, 13, 116, 101, 115, 116, - 45, 67, 72, 73, 76, 68, 95, 83, 65, 3, 3, 105, 107, 101, - 0, 11, 116, 101, 115, 116, 45, 73, 75, 69, 95, 83, 65, - } - - goldUnnamedPacket = &packet{ - ptype: pktCmdResponse, - msg: &Message{ - keys: []string{"success", "errmsg"}, - data: map[string]any{ - "success": "no", - "errmsg": "failed to install CHILD_SA", - }, - }, - } - - goldUnnamedPacketBytes = []byte{ - // Packet type - 1, - // Encoded message bytes - 3, 7, 115, 117, 99, 99, 101, 115, 115, 0, 2, 110, 111, 3, 6, - 101, 114, 114, 109, 115, 103, 0, 26, 102, 97, 105, 108, 101, - 100, 32, 116, 111, 32, 105, 110, 115, 116, 97, 108, 108, 32, - 67, 72, 73, 76, 68, 95, 83, 65, - } -) - -func TestPacketParse(t *testing.T) { - p := &packet{} - p.parse(goldNamedPacketBytes) - - if p.err != nil { - t.Fatalf("Error parsing packet: %v", p.err) - } - - if !reflect.DeepEqual(p, goldNamedPacket) { - t.Fatalf("Parsed named packet does not equal gold packet.\nExpected: %v\nReceived: %v", goldNamedPacket, p) - } - - p = &packet{} - p.parse(goldUnnamedPacketBytes) - - if p.err != nil { - t.Fatalf("Error parsing packet: %v", p.err) - } - - if !reflect.DeepEqual(p, goldUnnamedPacket) { - t.Fatalf("Parsed unnamed packet does not equal gold packet.\nExpected: %v\nReceived: %v", goldUnnamedPacket, p) - } -} - -func TestPacketBytes(t *testing.T) { - b, err := goldNamedPacket.bytes() - if err != nil { - t.Fatalf("Unexpected error getting packet bytes: %v", err) - } - - if !bytes.Equal(b, goldNamedPacketBytes) { - t.Fatalf("Encoded packet does not equal gold bytes.\nExpected: %v\nReceived: %v", goldNamedPacketBytes, b) - } - - b, err = goldUnnamedPacket.bytes() - if err != nil { - t.Fatalf("Unexpected error getting packet bytes: %v", err) - } - - if !bytes.Equal(b, goldUnnamedPacketBytes) { - t.Fatalf("Encoded packet does not equal gold bytes.\nExpected: %v\nReceived: %v", goldUnnamedPacketBytes, b) - } -} - -func TestPacketTooLong(t *testing.T) { - tooLong := make([]byte, 256) - - for i := range tooLong { - tooLong[i] = 'a' - } - - p := &packet{ - ptype: pktCmdRequest, - name: string(tooLong), - } - - _, err := p.bytes() - if err == nil { - t.Fatalf("Expected packet-too-long error due to %s", p.name) - } -} diff --git a/vici/session.go b/vici/session.go index 3364bad..8636dae 100644 --- a/vici/session.go +++ b/vici/session.go @@ -285,7 +285,15 @@ func (s *Session) CallStreaming(ctx context.Context, cmd string, event string, i } }() - if err := s.cc.packetWrite(ctx, newPacket(pktCmdRequest, cmd, in)); err != nil { + in.header = &struct { + ptype uint8 + name string + }{ + ptype: pktCmdRequest, + name: cmd, + } + + if err := s.cc.packetWrite(ctx, in); err != nil { return nil, err } @@ -301,16 +309,16 @@ func (s *Session) CallStreaming(ctx context.Context, cmd string, event string, i return } - switch p.ptype { + switch p.header.ptype { case pktEvent: - if !yield(p.msg, p.msg.Err()) { + if !yield(p, p.Err()) { return } case pktCmdResponse: - yield(p.msg, p.msg.Err()) + yield(p, p.Err()) return // End of event stream default: - yield(nil, fmt.Errorf("%v: %v", errUnexpectedResponse, p.ptype)) + yield(nil, fmt.Errorf("%v: %v", errUnexpectedResponse, p.header.ptype)) return } } @@ -361,9 +369,15 @@ func (s *Session) StopEvents(c chan<- Event) { } func (s *Session) request(ctx context.Context, cmd string, in *Message) (*Message, error) { - p := newPacket(pktCmdRequest, cmd, in) + in.header = &struct { + ptype uint8 + name string + }{ + ptype: pktCmdRequest, + name: cmd, + } - if err := s.cc.packetWrite(ctx, p); err != nil { + if err := s.cc.packetWrite(ctx, in); err != nil { return nil, err } @@ -372,15 +386,23 @@ func (s *Session) request(ctx context.Context, cmd string, in *Message) (*Messag return nil, err } - if p.ptype != pktCmdResponse { - return nil, fmt.Errorf("%v: %v", errUnexpectedResponse, p.ptype) + if p.header.ptype != pktCmdResponse { + return nil, fmt.Errorf("%v: %v", errUnexpectedResponse, p.header.ptype) } - return p.msg, nil + return p, nil } func (s *Session) eventRequest(ctx context.Context, ptype uint8, event string) error { - p := newPacket(ptype, event, nil) + p := &Message{ + header: &struct { + ptype uint8 + name string + }{ + ptype: ptype, + name: event, + }, + } if err := s.cc.packetWrite(ctx, p); err != nil { return err @@ -391,13 +413,13 @@ func (s *Session) eventRequest(ctx context.Context, ptype uint8, event string) e return err } - switch p.ptype { + switch p.header.ptype { case pktEventConfirm: return nil case pktEventUnknown: return fmt.Errorf("%v: %v", errEventUnknown, event) default: - return fmt.Errorf("%v: %v", errUnexpectedResponse, p.ptype) + return fmt.Errorf("%v: %v", errUnexpectedResponse, p.header.ptype) } }