Skip to content

Commit

Permalink
all: re-work client connection to be context-aware
Browse files Browse the repository at this point in the history
The current API does not accept context.Context in functions that may be
blocking on network I/O, and the underlying transport is not written in
a way that this could easily be added.

Re-write the underlying packet connection logic using a new clientConn
type. This new type has similar functionality to the previous transport
type, but is built with context support.

Adapt the session and event listener code to use this new type, but just
pass context.Background() for now. The existing API can be changed to
accept context.Context later, or "Context" variants can be added.

Signed-off-by: Nick Rosbrook <nr@enr0n.net>
  • Loading branch information
enr0n committed Aug 1, 2023
1 parent 51587ca commit c1c16ad
Show file tree
Hide file tree
Showing 8 changed files with 315 additions and 308 deletions.
146 changes: 146 additions & 0 deletions vici/client_conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
// Copyright (C) 2023 Nick Rosbrook
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

package vici

import (
"bytes"
"context"
"encoding/binary"
"errors"
"io"
"net"
"time"
)

const (
headerLength = 4
)

var (
// Received unexpected response from server
errUnexpectedResponse = errors.New("vici: unexpected response type")

// Received EVENT_UNKNOWN from server
errEventUnknown = errors.New("vici: unknown event type")
)

type clientConn struct {
conn net.Conn
}

func (cc *clientConn) packetWrite(ctx context.Context, p *packet) error {
if err := cc.conn.SetWriteDeadline(time.Time{}); err != nil {
return err
}

select {
case <-ctx.Done():
err := cc.conn.SetWriteDeadline(time.Now())
return errors.Join(err, ctx.Err())
case err := <-cc.awaitPacketWrite(p):
if err != nil {
return err
}
return nil
}
}

func (cc *clientConn) packetRead(ctx context.Context) (*packet, error) {
if err := cc.conn.SetReadDeadline(time.Time{}); err != nil {
return nil, err
}

select {
case <-ctx.Done():
err := cc.conn.SetReadDeadline(time.Now())
return nil, errors.Join(err, ctx.Err())
case p := <-cc.awaitPacketRead():
if p.err != nil {
return nil, p.err
}
return p, nil
}
}

func (cc *clientConn) awaitPacketWrite(p *packet) <-chan error {
r := make(chan error, 1)
buf := bytes.NewBuffer([]byte{})

go func() {
defer close(r)
b, err := p.bytes()
if err != nil {
r <- err
return
}

// Write the packet length
pl := make([]byte, headerLength)
binary.BigEndian.PutUint32(pl, uint32(len(b)))
_, err = buf.Write(pl)
if err != nil {
r <- err
return
}

// Write the payload
_, err = buf.Write(b)
if err != nil {
r <- err
return
}
_, err = cc.conn.Write(buf.Bytes())
r <- err
}()

return r
}

func (cc *clientConn) awaitPacketRead() <-chan *packet {
r := make(chan *packet, 1)

go func() {
defer close(r)
p := &packet{}

buf := make([]byte, headerLength)
_, err := io.ReadFull(cc.conn, buf)
if err != nil {
p.err = err
r <- p
return
}
pl := binary.BigEndian.Uint32(buf)

buf = make([]byte, int(pl))
_, err = io.ReadFull(cc.conn, buf)
if err != nil {
p.err = err
r <- p
return
}

p.parse(buf)
r <- p
}()

return r
}
13 changes: 7 additions & 6 deletions vici/transport_test.go → vici/client_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,19 @@ package vici

import (
"bytes"
"context"
"encoding/binary"
"net"
"reflect"
"testing"
)

func TestTransportSend(t *testing.T) {
func TestPacketWrite(t *testing.T) {
client, srvr := net.Pipe()
defer client.Close()
defer srvr.Close()

tr := &transport{
cc := &clientConn{
conn: client,
}

Expand Down Expand Up @@ -70,20 +71,20 @@ func TestTransportSend(t *testing.T) {
}
}()

err := tr.send(goldNamedPacket)
err := cc.packetWrite(context.Background(), goldNamedPacket)
if err != nil {
t.Fatalf("Unexpected error sending packet: %v", err)
}

<-done
}

func TestTransportRecv(t *testing.T) {
func TestPacketRead(t *testing.T) {
client, srvr := net.Pipe()
defer client.Close()
defer srvr.Close()

tr := &transport{
cc := &clientConn{
conn: client,
}

Expand All @@ -94,7 +95,7 @@ func TestTransportRecv(t *testing.T) {
go func() {
defer close(done)

p, err := tr.recv()
p, err := cc.packetRead(context.Background())
if err != nil {
t.Errorf("Unexpected error receiving packet: %v", err)
}
Expand Down
66 changes: 31 additions & 35 deletions vici/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@
package vici

import (
"context"
"fmt"
"io"
"sync"
"time"
)

type eventListener struct {
*transport
cc *clientConn

// Lock events when registering and unregistering.
mu sync.Mutex
Expand Down Expand Up @@ -61,11 +62,11 @@ type Event struct {
Timestamp time.Time
}

func newEventListener(t *transport) *eventListener {
func newEventListener(cc *clientConn) *eventListener {
el := &eventListener{
transport: t,
pc: make(chan *packet, 4),
chans: make(map[chan<- Event]struct{}),
cc: cc,
pc: make(chan *packet, 4),
chans: make(map[chan<- Event]struct{}),
}

return el
Expand All @@ -80,7 +81,7 @@ func (el *eventListener) Close() error {
return err
}

el.conn.Close()
el.cc.conn.Close()

return nil
}
Expand All @@ -92,7 +93,7 @@ func (el *eventListener) listen() {
defer el.closeAllChans()

for {
p, err := el.recv()
p, err := el.cc.packetRead(context.Background())
if err != nil {
return
}
Expand Down Expand Up @@ -177,7 +178,7 @@ func (el *eventListener) registerEvents(events []string) error {
continue
}

if err := el.eventRegisterUnregister(event, true); err != nil {
if err := el.register(event); err != nil {
return err
}

Expand All @@ -196,13 +197,13 @@ func (el *eventListener) unregisterEvents(events []string, all bool) error {
events = el.events
}

for _, e := range events {
if err := el.eventRegisterUnregister(e, false); err != nil {
for _, event := range events {
if err := el.unregister(event); err != nil {
return err
}

for i, registered := range el.events {
if e != registered {
if event != registered {
continue
}

Expand All @@ -217,38 +218,33 @@ func (el *eventListener) unregisterEvents(events []string, all bool) error {
return nil
}

func (el *eventListener) eventRegisterUnregister(event string, register bool) error {
ptype := pktEventRegister
if !register {
ptype = pktEventUnregister
}
func (el *eventListener) eventRequest(ptype uint8, event string) error {
p := newPacket(ptype, event, nil)

p, err := el.eventTransportCommunicate(newPacket(ptype, event, nil))
if err != nil {
if err := el.cc.packetWrite(context.Background(), p); err != nil {
return err
}

if p.ptype == pktEventUnknown {
return fmt.Errorf("%v: %v", errEventUnknown, event)
// The response packet is read by listen(), and written over pc.
p, ok := <-el.pc
if !ok {
return io.ErrClosedPipe
}

if p.ptype != pktEventConfirm {
return fmt.Errorf("%v:%v", errUnexpectedResponse, p.ptype)
switch p.ptype {
case pktEventConfirm:
return nil
case pktEventUnknown:
return fmt.Errorf("%v: %v", errEventUnknown, event)
default:
return fmt.Errorf("%v: %v", errUnexpectedResponse, p.ptype)
}

return nil
}

func (el *eventListener) eventTransportCommunicate(pkt *packet) (*packet, error) {
err := el.send(pkt)
if err != nil {
return nil, err
}

p, ok := <-el.pc
if !ok {
return nil, io.ErrClosedPipe
}
func (el *eventListener) register(event string) error {
return el.eventRequest(pktEventRegister, event)
}

return p, nil
func (el *eventListener) unregister(event string) error {
return el.eventRequest(pktEventUnregister, event)
}
20 changes: 12 additions & 8 deletions vici/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type packet struct {
name string

msg *Message
err error
}

func newPacket(ptype uint8, name string, msg *Message) *packet {
Expand Down Expand Up @@ -135,28 +136,32 @@ func (p *packet) bytes() ([]byte, error) {
return buf.Bytes(), nil
}

// parse will parse the given bytes and populate its fields with that data
func (p *packet) parse(data []byte) error {
// parse will parse the given bytes and populate its fields with that data. If
// there is an error, then the error filed will be set.
func (p *packet) parse(data []byte) {
buf := bytes.NewBuffer(data)

// Read the packet type
b, err := buf.ReadByte()
if err != nil {
return fmt.Errorf("%v: %v", errPacketParse, err)
p.err = fmt.Errorf("%v: %v", errPacketParse, err)
return
}
p.ptype = b

if p.isNamed() {
// Get the length of the name
l, err := buf.ReadByte()
if err != nil {
return fmt.Errorf("%v: %v", errPacketParse, err)
p.err = fmt.Errorf("%v: %v", errPacketParse, err)
return
}

// Read the name
name := buf.Next(int(l))
if len(name) != int(l) {
return errBadName
p.err = errBadName
return
}
p.name = string(name)
}
Expand All @@ -165,9 +170,8 @@ func (p *packet) parse(data []byte) error {
m := NewMessage()
err = m.decode(buf.Bytes())
if err != nil {
return err
p.err = err
return
}
p.msg = m

return nil
}
Loading

0 comments on commit c1c16ad

Please sign in to comment.