From 2e8d01ca5c71cf56c0a3103f769d71032030c611 Mon Sep 17 00:00:00 2001 From: dogukanoksuz Date: Mon, 18 Dec 2023 19:03:53 +0000 Subject: [PATCH] feature: Internal key usage on ssh tunnels --- app/handlers/tunnel.go | 17 +++++++++-- internal/bridge/ssh_tunnel.go | 57 ++++++++++++++++++++++++++++++----- 2 files changed, 65 insertions(+), 9 deletions(-) diff --git a/app/handlers/tunnel.go b/app/handlers/tunnel.go index 6864ea6a..b0a944ae 100644 --- a/app/handlers/tunnel.go +++ b/app/handlers/tunnel.go @@ -8,7 +8,7 @@ import ( // OpenTunnel opens ssh tunnel on unix sockets or ports func OpenTunnel(c *fiber.Ctx) error { - params := []string{"remote_host", "remote_port", "username", "password"} + params := []string{"remote_host", "remote_port"} for _, param := range params { if len(c.FormValue(param)) < 1 { @@ -21,13 +21,26 @@ func OpenTunnel(c *fiber.Ctx) error { sshPort = "22" } - port := bridge.CreateTunnel( + if len(c.FormValue("username")) < 1 { + port, err := bridge.CreateTunnelInternalKey(c) + if err != nil { + return logger.FiberError(fiber.StatusInternalServerError, err.Error()) + } + + return c.JSON(port) + } + + port, err := bridge.CreateTunnel( c.FormValue("remote_host"), c.FormValue("remote_port"), c.FormValue("username"), c.FormValue("password"), sshPort, + "ssh", ) + if err != nil { + return logger.FiberError(fiber.StatusInternalServerError, err.Error()) + } return c.JSON(port) } diff --git a/internal/bridge/ssh_tunnel.go b/internal/bridge/ssh_tunnel.go index 71cdc627..74b58082 100644 --- a/internal/bridge/ssh_tunnel.go +++ b/internal/bridge/ssh_tunnel.go @@ -2,6 +2,7 @@ package bridge import ( "context" + "errors" "fmt" "io" "net" @@ -11,6 +12,9 @@ import ( "time" "github.com/avast/retry-go" + "github.com/gofiber/fiber/v2" + "github.com/limanmys/render-engine/app/models" + "github.com/limanmys/render-engine/internal/liman" "github.com/limanmys/render-engine/pkg/logger" "github.com/phayes/freeport" "golang.org/x/crypto/ssh" @@ -301,8 +305,30 @@ func (t *Tunnel) String() string { return fmt.Sprintf("%s@%s | %s %s %s", t.user, t.hostAddr, left, mode, right) } +func CreateTunnelInternalKey(c *fiber.Ctx) (int, error) { + credentials, err := liman.GetCredentials(&models.User{ + ID: c.Locals("user_id").(string), + }, &models.Server{ + ID: c.FormValue("server_id"), + }) + if err != nil { + return 0, err + } + + port, err := CreateTunnel( + c.FormValue("remote_host"), + c.FormValue("remote_port"), + credentials.Username, + credentials.Key, + credentials.Port, + credentials.Type, + ) + + return port, err +} + // CreateTunnel starts a new tunnel instance and sets it into TunnelPool -func CreateTunnel(remoteHost, remotePort, username, password, sshPort string) int { +func CreateTunnel(remoteHost, remotePort, username, password, sshPort, connType string) (int, error) { // Creating a tunnel cannot exceed 25 seconds ch := make(chan int) time.AfterFunc(25*time.Second, func() { @@ -315,7 +341,7 @@ func CreateTunnel(remoteHost, remotePort, username, password, sshPort string) in // Check if existing tunnel started, if not wait until starts (max: 25sec) if err == nil { if t.password != password { - return 0 + return 0, errors.New("password mismatch") } startedLoop: @@ -336,14 +362,14 @@ func CreateTunnel(remoteHost, remotePort, username, password, sshPort string) in } t.LastConnection = time.Now() - return t.Port + return t.Port, nil } // This part from now creates a new tunnel port, err := freeport.GetFreePort() if err != nil { logger.Sugar().Errorw(err.Error()) - return 0 + return 0, errors.New("couldnt find a free port") } dial := net.JoinHostPort("127.0.0.1", remotePort) @@ -355,8 +381,8 @@ func CreateTunnel(remoteHost, remotePort, username, password, sshPort string) in } ctx, cancel := context.WithCancel(context.Background()) + sshTunnel := &Tunnel{ - auth: []ssh.AuthMethod{ssh.RetryableAuthMethod(ssh.Password(password), 3)}, hostKeys: ssh.InsecureIgnoreHostKey(), user: username, mode: '>', @@ -376,6 +402,23 @@ func CreateTunnel(remoteHost, remotePort, username, password, sshPort string) in cancel: cancel, } + if connType == "ssh" { + sshTunnel.auth = []ssh.AuthMethod{ + ssh.RetryableAuthMethod(ssh.Password(password), 3), + } + } + + if connType == "ssh_certificate" { + key, err := ssh.ParsePrivateKey([]byte(password)) + if err != nil { + return 0, errors.New("an error occured while parsing ssh key") + } + + sshTunnel.auth = []ssh.AuthMethod{ + ssh.RetryableAuthMethod(ssh.PublicKeys(key), 3), + } + } + Tunnels.Set(remoteHost, remotePort, username, sshTunnel) go sshTunnel.Start() @@ -398,8 +441,8 @@ loop: if !sshTunnel.Started { cancel() - return 0 + return 0, errors.New("cannot start tunnel") } - return sshTunnel.Port + return sshTunnel.Port, nil }