diff --git a/README.md b/README.md index c27b42d..3718d23 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ web socket command execution protocol. It can be thought of as SSH without encryption. It's useful in cases where you want to provide a command exec interface into a remote environment. It's implemented -with WebSocket so it may be used directly by a browser frontend. Its symmetric design satisfies +with WebSocket so it may be used directly by a browser frontend. Its symmetric design satisfies `wsep.Execer` for local and remote execution. ## Examples @@ -54,8 +54,8 @@ go run ./dev/server Start a client: ```sh -go run ./dev/client tty bash -go run ./dev/client notty ls +go run ./dev/client tty --id 1 -- bash +go run ./dev/client notty -- ls -la ``` ### Local performance cost @@ -63,10 +63,10 @@ go run ./dev/client notty ls Local `sh` through a local `wsep` connection ```shell script -$ head -c 100000000 /dev/urandom > /tmp/random; cat /tmp/random | pv | time ./bin/client notty sh -c "cat > /dev/null" +$ head -c 100000000 /dev/urandom > /tmp/random; cat /tmp/random | pv | time ./bin/client notty -- sh -c "cat > /dev/null" 95.4MiB 0:00:00 [ 269MiB/s] [ <=> ] -./bin/client notty sh -c "cat > /dev/null" 0.32s user 0.31s system 31% cpu 2.019 total +./bin/client notty -- sh -c "cat > /dev/null" 0.32s user 0.31s system 31% cpu 2.019 total ``` Local `sh` directly diff --git a/ci/image/Dockerfile b/ci/image/Dockerfile index 890ff8b..5b5c0fb 100644 --- a/ci/image/Dockerfile +++ b/ci/image/Dockerfile @@ -3,6 +3,6 @@ FROM golang:1 ENV GOFLAGS="-mod=readonly" ENV CI=true -RUN go get golang.org/x/tools/cmd/goimports -RUN go get golang.org/x/lint/golint -RUN go get github.com/mattn/goveralls +RUN go install golang.org/x/tools/cmd/goimports@latest +RUN go install golang.org/x/lint/golint@latest +RUN go install github.com/mattn/goveralls@latest diff --git a/client.go b/client.go index e47c6bf..4a22270 100644 --- a/client.go +++ b/client.go @@ -27,6 +27,8 @@ func 
RemoteExecer(conn *websocket.Conn) Execer { // Command represents an external command to be run type Command struct { + // ID allows reconnecting commands that have a TTY. + ID string Command string Args []string TTY bool @@ -39,6 +41,7 @@ type Command struct { func (r remoteExec) Start(ctx context.Context, c Command) (Process, error) { header := proto.ClientStartHeader{ + ID: c.ID, Command: mapToProtoCmd(c), Type: proto.TypeStart, } diff --git a/client_test.go b/client_test.go index 4d03316..77276da 100644 --- a/client_test.go +++ b/client_test.go @@ -49,14 +49,14 @@ func TestRemoteStdin(t *testing.T) { } } -func mockConn(ctx context.Context, t *testing.T) (*websocket.Conn, *httptest.Server) { +func mockConn(ctx context.Context, t *testing.T, options *Options) (*websocket.Conn, *httptest.Server) { mockServerHandler := func(w http.ResponseWriter, r *http.Request) { ws, err := websocket.Accept(w, r, nil) if err != nil { w.WriteHeader(http.StatusInternalServerError) return } - err = Serve(r.Context(), ws, LocalExecer{}) + err = Serve(r.Context(), ws, LocalExecer{}, options) if err != nil { t.Errorf("failed to serve execer: %v", err) ws.Close(websocket.StatusAbnormalClosure, "failed to serve execer") @@ -77,7 +77,7 @@ func TestRemoteExec(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() - ws, server := mockConn(ctx, t) + ws, server := mockConn(ctx, t, nil) defer server.Close() execer := RemoteExecer(ws) @@ -89,7 +89,7 @@ func TestRemoteExecFail(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() - ws, server := mockConn(ctx, t) + ws, server := mockConn(ctx, t, nil) defer server.Close() execer := RemoteExecer(ws) @@ -123,7 +123,7 @@ func TestStderrVsStdout(t *testing.T) { stderr bytes.Buffer ) - ws, server := mockConn(ctx, t) + ws, server := mockConn(ctx, t, nil) defer server.Close() execer := RemoteExecer(ws) diff --git a/dev/client/main.go b/dev/client/main.go 
index 031b7c9..0c28fd4 100644 --- a/dev/client/main.go +++ b/dev/client/main.go @@ -23,35 +23,38 @@ type notty struct { } func (c *notty) Run(fl *pflag.FlagSet) { - do(fl, false) + do(fl, false, "") } func (c *notty) Spec() cli.CommandSpec { return cli.CommandSpec{ - Name: "notty", - Usage: "[flags]", - Desc: `Run a command without tty enabled.`, - RawArgs: true, + Name: "notty", + Usage: "[flags]", + Desc: `Run a command without tty enabled.`, } } type tty struct { + id string } func (c *tty) Run(fl *pflag.FlagSet) { - do(fl, true) + do(fl, true, c.id) } func (c *tty) Spec() cli.CommandSpec { return cli.CommandSpec{ - Name: "tty", - Usage: "[flags]", - Desc: `Run a command with tty enabled.`, - RawArgs: true, + Name: "tty", + Usage: "[id] [flags]", + Desc: `Run a command with tty enabled. Use the same ID to reconnect.`, } } -func do(fl *pflag.FlagSet, tty bool) { +func (c *tty) RegisterFlags(fl *pflag.FlagSet) { + fl.StringVar(&c.id, "id", "", "sets id for reconnection") +} + +func do(fl *pflag.FlagSet, tty bool, id string) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -71,6 +74,7 @@ func do(fl *pflag.FlagSet, tty bool) { args = fl.Args()[1:] } process, err := executor.Start(ctx, wsep.Command{ + ID: id, Command: fl.Arg(0), Args: args, TTY: tty, diff --git a/dev/server/main.go b/dev/server/main.go index 87dcdf2..66020d5 100644 --- a/dev/server/main.go +++ b/dev/server/main.go @@ -23,7 +23,7 @@ func serve(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) return } - err = wsep.Serve(r.Context(), ws, wsep.LocalExecer{}) + err = wsep.Serve(r.Context(), ws, wsep.LocalExecer{}, nil) if err != nil { flog.Error("failed to serve execer: %v", err) ws.Close(websocket.StatusAbnormalClosure, "failed to serve execer") diff --git a/go.mod b/go.mod index b699e88..67ac33f 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,10 @@ go 1.14 require ( cdr.dev/slog v1.3.0 + github.com/armon/circbuf 
v0.0.0-20190214190532-5111143e8da2 github.com/creack/pty v1.1.11 github.com/google/go-cmp v0.4.0 + github.com/google/uuid v1.3.0 github.com/spf13/pflag v1.0.5 go.coder.com/cli v0.4.0 go.coder.com/flog v0.0.0-20190906214207-47dd47ea0512 diff --git a/go.sum b/go.sum index 55cf453..8703f19 100644 --- a/go.sum +++ b/go.sum @@ -30,6 +30,8 @@ github.com/alecthomas/kong v0.2.1-0.20190708041108-0548c6b1afae/go.mod h1:+inYUS github.com/alecthomas/kong-hcl v0.1.8-0.20190615233001-b21fea9723c8/go.mod h1:MRgZdU3vrFd05IQ89AxUZ0aYdF39BYoNFa324SodPCA= github.com/alecthomas/repr v0.0.0-20180818092828-117648cd9897 h1:p9Sln00KOTlrYkxI1zYWl1QLnEqAqEARBEYa8FQnQcY= github.com/alecthomas/repr v0.0.0-20180818092828-117648cd9897/go.mod h1:xTS7Pm1pD1mvyM075QCDSRqH6qRLXylzS24ZTpRiSzQ= +github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2 h1:7Ip0wMmLHLRJdrloDxZfhMm0xrLXZS8+COSu2bXmEQs= +github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/creack/pty v1.1.11 h1:07n33Z8lZxZ2qwegKbObQohDhXDQxiMMz1NOUGYlesw= @@ -92,6 +94,8 @@ github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXi github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod 
h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/gorilla/csrf v1.6.0/go.mod h1:7tSf8kmjNYr7IWDCYhd3U8Ck34iQ/Yw5CJu7bAkHEGI= diff --git a/internal/proto/clientmsg.go b/internal/proto/clientmsg.go index 610a7bf..ae05ad5 100644 --- a/internal/proto/clientmsg.go +++ b/internal/proto/clientmsg.go @@ -18,6 +18,7 @@ type ClientResizeHeader struct { // ClientStartHeader specifies a request to start command type ClientStartHeader struct { Type string `json:"type"` + ID string `json:"id"` Command Command `json:"command"` } diff --git a/server.go b/server.go index 224a6c6..c9f4045 100644 --- a/server.go +++ b/server.go @@ -7,6 +7,11 @@ import ( "errors" "io" "net" + "sync" + "time" + + "github.com/armon/circbuf" + "github.com/google/uuid" "go.coder.com/flog" "golang.org/x/sync/errgroup" @@ -16,24 +21,33 @@ import ( "cdr.dev/wsep/internal/proto" ) +var reconnectingProcesses sync.Map + +// Options allows configuring the server. +type Options struct { + ReconnectingProcessTimeout time.Duration +} + // Serve runs the server-side of wsep. // The execer may be another wsep connection for chaining. // Use LocalExecer for local command execution. 
-func Serve(ctx context.Context, c *websocket.Conn, execer Execer) error { +func Serve(ctx context.Context, c *websocket.Conn, execer Execer, options *Options) error { ctx, cancel := context.WithCancel(ctx) defer cancel() + if options == nil { + options = &Options{} + } + if options.ReconnectingProcessTimeout <= 0 { + options.ReconnectingProcessTimeout = 5 * time.Minute + } + c.SetReadLimit(maxMessageSize) var ( header proto.Header process Process wsNetConn = websocket.NetConn(ctx, c, websocket.MessageBinary) ) - defer func() { - if process != nil { - process.Close() - } - }() for { if err := ctx.Err(); err != nil { return err @@ -61,36 +75,183 @@ func Serve(ctx context.Context, c *websocket.Conn, execer Execer) error { switch header.Type { case proto.TypeStart: + if process != nil { + return errors.New("command already started") + } + var header proto.ClientStartHeader err = json.Unmarshal(byt, &header) if err != nil { return xerrors.Errorf("unmarshal start header: %w", err) } - process, err = execer.Start(ctx, mapToClientCmd(header.Command)) - if err != nil { - return err - } - _ = sendPID(ctx, process.Pid(), wsNetConn) - - var outputgroup errgroup.Group - outputgroup.Go(func() error { - return copyWithHeader(process.Stdout(), wsNetConn, proto.Header{Type: proto.TypeStdout}) - }) - outputgroup.Go(func() error { - return copyWithHeader(process.Stderr(), wsNetConn, proto.Header{Type: proto.TypeStderr}) - }) - - go func() { - defer wsNetConn.Close() - _ = outputgroup.Wait() - err = process.Wait() - if exitErr, ok := err.(ExitError); ok { - _ = sendExitCode(ctx, exitErr.Code, wsNetConn) - return + command := mapToClientCmd(header.Command) + + // Only allow TTYs with IDs to be reconnected. + if command.TTY && header.ID != "" { + // IDs are expected to be UUIDs; anything else is logged but still accepted. + _, err := uuid.Parse(header.ID) + if err != nil { + flog.Error("%s is not a valid uuid: %v", header.ID, err) + } + + // Get an existing process or create a new one. 
+ var rprocess *reconnectingProcess + rawRProcess, ok := reconnectingProcesses.Load(header.ID) + if ok { + rprocess, ok = rawRProcess.(*reconnectingProcess) + if !ok { + return xerrors.Errorf("found invalid type in reconnecting process map for ID %s", header.ID) + } + process = rprocess.process + } else { + // The process will be kept alive as long as this context does not + // finish (and as long as the process does not exit on its own). This + // is a new context since the parent context finishes when the request + // ends which would kill the process prematurely. + ctx, cancel := context.WithCancel(context.Background()) + + // The process will be killed if the provided context ends. + process, err = execer.Start(ctx, command) + if err != nil { + cancel() + return err + } + + // Keep the most recent 64KB of output so it can be replayed on reconnect. + ringBuffer, err := circbuf.NewBuffer(64 * 1024) + if err != nil { + cancel() + return xerrors.Errorf("unable to create ring buffer %w", err) + } + + rprocess = &reconnectingProcess{ + activeConns: make(map[string]net.Conn), + process: process, + // Timeouts created with AfterFunc can be reset. + timeout: time.AfterFunc(options.ReconnectingProcessTimeout, cancel), + ringBuffer: ringBuffer, + } + reconnectingProcesses.Store(header.ID, rprocess) + + // If the process exits send the exit code to all listening + // connections then close everything. + go func() { + err := process.Wait() + code := 0 + if exitErr, ok := err.(ExitError); ok { + code = exitErr.Code + } + rprocess.activeConnsMutex.Lock() + for _, conn := range rprocess.activeConns { + _ = sendExitCode(ctx, code, conn) + } + rprocess.activeConnsMutex.Unlock() + rprocess.Close() + reconnectingProcesses.Delete(header.ID) + }() + + // Write to the ring buffer and all connections as we receive stdout. + go func() { + buffer := make([]byte, 32*1024) + for { + read, err := rprocess.process.Stdout().Read(buffer) + if err != nil { + // When the process is closed this is triggered. 
+ break + } + part := buffer[:read] + _, err = rprocess.ringBuffer.Write(part) + if err != nil { + flog.Error("reconnecting process %s write buffer: %v", header.ID, err) + cancel() + break + } + rprocess.activeConnsMutex.Lock() + for _, conn := range rprocess.activeConns { + _ = sendOutput(ctx, part, conn) + } + rprocess.activeConnsMutex.Unlock() + } + }() } - _ = sendExitCode(ctx, 0, wsNetConn) - }() + + err = sendPID(ctx, process.Pid(), wsNetConn) + if err != nil { + flog.Error("failed to send pid %d: %v", process.Pid(), err) + } + + // Write out the initial contents in the ring buffer. + err = sendOutput(ctx, rprocess.ringBuffer.Bytes(), wsNetConn) + if err != nil { + return xerrors.Errorf("write reconnecting process %s buffer: %w", header.ID, err) + } + + // Store this connection on the reconnecting process. All connections + // stored on the process will receive the process's stdout. + connectionID := uuid.NewString() + rprocess.activeConnsMutex.Lock() + rprocess.activeConns[connectionID] = wsNetConn + rprocess.activeConnsMutex.Unlock() + + // Keep resetting the inactivity timer while this connection is alive. + rprocess.timeout.Reset(options.ReconnectingProcessTimeout) + heartbeat := time.NewTicker(options.ReconnectingProcessTimeout / 2) + defer heartbeat.Stop() + go func() { + for { + select { + // Stop looping once this request finishes. + case <-ctx.Done(): + return + case <-heartbeat.C: + } + rprocess.timeout.Reset(options.ReconnectingProcessTimeout) + } + }() + + // Remove this connection from the process's connection list once the + // connection ends so data is no longer sent to it. + defer func() { + wsNetConn.Close() // REVIEW@asher: Not sure if necessary. 
+ rprocess.activeConnsMutex.Lock() + delete(rprocess.activeConns, connectionID) + rprocess.activeConnsMutex.Unlock() + }() + } else { + process, err = execer.Start(ctx, command) + if err != nil { + return err + } + + err = sendPID(ctx, process.Pid(), wsNetConn) + if err != nil { + flog.Error("failed to send pid %d: %v", process.Pid(), err) + } + + var outputgroup errgroup.Group + outputgroup.Go(func() error { + return copyWithHeader(process.Stdout(), wsNetConn, proto.Header{Type: proto.TypeStdout}) + }) + outputgroup.Go(func() error { + return copyWithHeader(process.Stderr(), wsNetConn, proto.Header{Type: proto.TypeStderr}) + }) + + go func() { + defer wsNetConn.Close() + _ = outputgroup.Wait() + err := process.Wait() + if exitErr, ok := err.(ExitError); ok { + _ = sendExitCode(ctx, exitErr.Code, wsNetConn) + return + } + _ = sendExitCode(ctx, 0, wsNetConn) + }() + + defer func() { + process.Close() + }() + } case proto.TypeResize: if process == nil { return errors.New("resize sent before command started") @@ -143,6 +304,15 @@ func sendPID(_ context.Context, pid int, conn net.Conn) error { return err } +func sendOutput(_ context.Context, data []byte, conn net.Conn) error { + header, err := json.Marshal(proto.Header{Type: proto.TypeStdout}) + if err != nil { + return err + } + _, err = proto.WithHeader(conn, header).Write(data) + return err +} + func copyWithHeader(r io.Reader, w io.Writer, header proto.Header) error { headerByt, err := json.Marshal(header) if err != nil { @@ -155,3 +325,24 @@ func copyWithHeader(r io.Reader, w io.Writer, header proto.Header) error { } return nil } + +type reconnectingProcess struct { + activeConnsMutex sync.Mutex + activeConns map[string]net.Conn + + ringBuffer *circbuf.Buffer + timeout *time.Timer + process Process +} + +// Close ends all connections to the reconnecting process and clears the ring +// buffer. 
+func (r *reconnectingProcess) Close() { + r.activeConnsMutex.Lock() + defer r.activeConnsMutex.Unlock() + for _, conn := range r.activeConns { + _ = conn.Close() + } + _ = r.process.Close() + r.ringBuffer.Reset() +} diff --git a/tty_test.go b/tty_test.go index 913fb62..bfc20f3 100644 --- a/tty_test.go +++ b/tty_test.go @@ -1,6 +1,7 @@ package wsep import ( + "bufio" "context" "io/ioutil" "strings" @@ -9,6 +10,7 @@ import ( "time" "cdr.dev/slog/sloggers/slogtest/assert" + "github.com/google/uuid" "nhooyr.io/websocket" ) @@ -18,7 +20,7 @@ func TestTTY(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - ws, server := mockConn(ctx, t) + ws, server := mockConn(ctx, t, nil) defer ws.Close(websocket.StatusInternalError, "") defer server.Close() @@ -60,3 +62,117 @@ func testTTY(ctx context.Context, t *testing.T, e Execer) { process.Close() wg.Wait() } + +func TestReconnectTTY(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + ws, server := mockConn(ctx, t, &Options{ + ReconnectingProcessTimeout: time.Second, + }) + defer server.Close() + + command := Command{ + ID: uuid.NewString(), + Command: "sh", + TTY: true, + Stdin: true, + } + execer := RemoteExecer(ws) + process, err := execer.Start(ctx, command) + assert.Success(t, "start sh", err) + + // Write some unique output. + echoCmd := "echo test:$((1+1))" + data := []byte(echoCmd + "\r\n") + _, err = process.Stdin().Write(data) + assert.Success(t, "write to stdin", err) + expected := []string{echoCmd, "test:2"} + + findEcho := func(expected []string) bool { + scanner := bufio.NewScanner(process.Stdout()) + outer: + for _, str := range expected { + for scanner.Scan() { + line := scanner.Text() + t.Logf("bash tty stdout = %s", line) + if strings.Contains(line, str) { + continue outer + } + } + return false // Reached the end of output without finding str. 
+ } + return true + } + + assert.True(t, "find echo", findEcho(expected)) + + // Test disconnecting then reconnecting. + ws.Close(websocket.StatusNormalClosure, "disconnected") + server.Close() + + ws, server = mockConn(ctx, t, &Options{ + ReconnectingProcessTimeout: time.Second, + }) + defer server.Close() + + execer = RemoteExecer(ws) + process, err = execer.Start(ctx, command) + assert.Success(t, "attach sh", err) + + // The inactivity timeout should not have been triggered. + time.Sleep(time.Second) + + echoCmd = "echo test:$((2+2))" + data = []byte(echoCmd + "\r\n") + _, err = process.Stdin().Write(data) + assert.Success(t, "write to stdin", err) + expected = append(expected, echoCmd, "test:4") + + assert.True(t, "find echo", findEcho(expected)) + + // Test disconnecting while another connection is active. + ws2, server2 := mockConn(ctx, t, &Options{ + // Divide the time to test that the heartbeat keeps it open through multiple + // intervals. + ReconnectingProcessTimeout: time.Second / 4, + }) + defer server2.Close() + + execer = RemoteExecer(ws2) + process, err = execer.Start(ctx, command) + assert.Success(t, "attach sh", err) + + ws.Close(websocket.StatusNormalClosure, "disconnected") + server.Close() + time.Sleep(time.Second) + + // This connection should still be up. + echoCmd = "echo test:$((3+3))" + data = []byte(echoCmd + "\r\n") + _, err = process.Stdin().Write(data) + assert.Success(t, "write to stdin", err) + expected = append(expected, echoCmd, "test:6") + + assert.True(t, "find echo", findEcho(expected)) + + // Close the remaining connection and wait for inactivity. + ws2.Close(websocket.StatusNormalClosure, "disconnected") + server2.Close() + time.Sleep(time.Second) + + // The next connection should start a new process. 
+ ws, server = mockConn(ctx, t, &Options{ + ReconnectingProcessTimeout: time.Second, + }) + defer server.Close() + + execer = RemoteExecer(ws) + process, err = execer.Start(ctx, command) + assert.Success(t, "attach sh", err) + + // This time no echo since it is a new process. + assert.True(t, "find echo", !findEcho(expected)) +}