Skip to content

Commit

Permalink
Merge pull request #3 from thomasstevens89/master
Browse files Browse the repository at this point in the history
Fix shutdown method and make it public. Implement configurable timeouts
  • Loading branch information
checksum0 authored Mar 27, 2022
2 parents 2dcd823 + c518c8c commit 8416d48
Showing 1 changed file with 48 additions and 15 deletions.
63 changes: 48 additions & 15 deletions electrum/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ const (
// ProtocolVersion identifies the support protocol version to the remote server
ProtocolVersion = "1.4"

// TCP connection and server request timeout duration.
connTimeout = 30 * time.Second
nl = byte('\n')
nl = byte('\n')
)

var (
Expand Down Expand Up @@ -50,6 +48,7 @@ type Transport interface {
SendMessage([]byte) error
Responses() <-chan []byte
Errors() <-chan error
Close() error
}

// TCPTransport store informations about the TCP transport.
Expand All @@ -60,8 +59,8 @@ type TCPTransport struct {
}

// NewTCPTransport opens a new TCP connection to the remote server.
func NewTCPTransport(addr string) (*TCPTransport, error) {
conn, err := net.DialTimeout("tcp", addr, connTimeout)
func NewTCPTransport(addr string, timeout time.Duration) (*TCPTransport, error) {
conn, err := net.DialTimeout("tcp", addr, timeout)
if err != nil {
return nil, err
}
Expand All @@ -78,9 +77,9 @@ func NewTCPTransport(addr string) (*TCPTransport, error) {
}

// NewSSLTransport opens a new SSL connection to the remote server.
func NewSSLTransport(addr string, config *tls.Config) (*TCPTransport, error) {
func NewSSLTransport(addr string, config *tls.Config, timeout time.Duration) (*TCPTransport, error) {
dialer := net.Dialer{
Timeout: connTimeout,
Timeout: timeout,
}
conn, err := tls.DialWithDialer(&dialer, "tcp", addr, config)
if err != nil {
Expand Down Expand Up @@ -136,14 +135,24 @@ func (t *TCPTransport) Errors() <-chan error {
return t.errors
}

func (t *TCPTransport) Close() error {
return t.conn.Close()
}

type container struct {
content []byte
err error
}

type ServerOptions struct {
ConnTimeout time.Duration
ReqTimeout time.Duration
}

// Server stores information about the remote server.
type Server struct {
transport Transport
opts *ServerOptions

handlers map[uint64]chan *container
handlersLock sync.RWMutex
Expand All @@ -158,13 +167,15 @@ type Server struct {
}

// NewServer initialize a new remote server.
func NewServer() *Server {
func NewServer(opts *ServerOptions) *Server {
s := &Server{
handlers: make(map[uint64]chan *container),
pushHandlers: make(map[string][]chan *container),

Error: make(chan error),
quit: make(chan struct{}),

opts: opts,
}

return s
Expand All @@ -176,7 +187,7 @@ func (s *Server) ConnectTCP(addr string) error {
return ErrServerConnected
}

transport, err := NewTCPTransport(addr)
transport, err := NewTCPTransport(addr, s.opts.ConnTimeout)
if err != nil {
return err
}
Expand All @@ -193,7 +204,7 @@ func (s *Server) ConnectSSL(addr string, config *tls.Config) error {
return ErrServerConnected
}

transport, err := NewSSLTransport(addr, config)
transport, err := NewSSLTransport(addr, config, s.opts.ConnTimeout)
if err != nil {
return err
}
Expand Down Expand Up @@ -221,10 +232,18 @@ type response struct {

func (s *Server) listen() {
for {
if s.IsShutdown() {
break
}
if s.transport == nil {
break
}
select {
case <-s.quit:
break
case err := <-s.transport.Errors():
s.Error <- err
s.shutdown()
s.Shutdown()
case bytes := <-s.transport.Responses():
result := &container{
content: bytes,
Expand Down Expand Up @@ -302,6 +321,7 @@ func (s *Server) request(method string, params []interface{}, v interface{}) err

err = s.transport.SendMessage(bytes)
if err != nil {
s.Shutdown()
return err
}

Expand All @@ -314,7 +334,7 @@ func (s *Server) request(method string, params []interface{}, v interface{}) err
var resp *container
select {
case resp = <-c:
case <-time.After(connTimeout):
case <-time.After(s.opts.ReqTimeout):
return ErrTimeout
}

Expand All @@ -336,10 +356,23 @@ func (s *Server) request(method string, params []interface{}, v interface{}) err
return nil
}

func (s *Server) shutdown() {
close(s.quit)

func (s *Server) Shutdown() {
if !s.IsShutdown() {
close(s.quit)
}
if s.transport != nil {
_ = s.transport.Close()
}
s.transport = nil
s.handlers = nil
s.pushHandlers = nil
}

func (s *Server) IsShutdown() bool {
select {
case <-s.quit:
return true
default:
}
return false
}

0 comments on commit 8416d48

Please sign in to comment.