diff --git a/assert_test.go b/assert_test.go index 871c43c21..b49213691 100644 --- a/assert_test.go +++ b/assert_test.go @@ -46,6 +46,10 @@ func assertStringContainsE(t *testing.T, actual string, expectedToContain string errorOnNonEmpty(t, validateStringContains(actual, expectedToContain, descriptions...)) } +func assertStringContainsF(t *testing.T, actual string, expectedToContain string, descriptions ...string) { + fatalOnNonEmpty(t, validateStringContains(actual, expectedToContain, descriptions...)) +} + func assertHasPrefixE(t *testing.T, actual string, expectedPrefix string, descriptions ...string) { errorOnNonEmpty(t, validateHasPrefix(actual, expectedPrefix, descriptions...)) } diff --git a/auth.go b/auth.go index 9493459f9..c844b2e66 100644 --- a/auth.go +++ b/auth.go @@ -501,6 +501,10 @@ func authenticateWithConfig(sc *snowflakeConn) error { if sc.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue { fillCachedIDToken(sc) } + // Disable console login by default + if sc.cfg.DisableConsoleLogin == configBoolNotSet { + sc.cfg.DisableConsoleLogin = ConfigBoolTrue + } } if sc.cfg.Authenticator == AuthTypeUsernamePasswordMFA { @@ -524,7 +528,8 @@ func authenticateWithConfig(sc *snowflakeConn) error { sc.cfg.Account, sc.cfg.User, sc.cfg.Password, - sc.cfg.ExternalBrowserTimeout) + sc.cfg.ExternalBrowserTimeout, + sc.cfg.DisableConsoleLogin) if err != nil { sc.cleanup() return err diff --git a/auth_test.go b/auth_test.go index 4a6fd0e9f..25e2d0d40 100644 --- a/auth_test.go +++ b/auth_test.go @@ -686,6 +686,22 @@ func TestUnitAuthenticateWithConfigOkta(t *testing.T) { assertEqualE(t, err.Error(), "failed to get SAML response") } +func TestUnitAuthenticateWithConfigExternalBrowser(t *testing.T) { + var err error + sr := &snowflakeRestful{ + FuncPostAuthSAML: postAuthSAMLError, + TokenAccessor: getSimpleTokenAccessor(), + } + sc := getDefaultSnowflakeConn() + sc.cfg.Authenticator = AuthTypeExternalBrowser + sc.cfg.ExternalBrowserTimeout = defaultExternalBrowserTimeout + sc.rest = sr + sc.ctx = context.Background() + err = authenticateWithConfig(sc) + assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.") + assertEqualE(t, err.Error(), "failed to get SAML response") +} + func TestUnitAuthenticateExternalBrowser(t *testing.T) { var err error sr := &snowflakeRestful{ diff --git a/authexternalbrowser.go b/authexternalbrowser.go index 373173f5b..ac53c3707 100644 --- a/authexternalbrowser.go +++ b/authexternalbrowser.go @@ -5,6 +5,7 @@ package gosnowflake import ( "bytes" "context" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -70,11 +71,11 @@ func createLocalTCPListener() (*net.TCPListener, error) { return tcpListener, nil } -// Opens a browser window (or new tab) with the configured IDP Url. +// Opens a browser window (or new tab) with the configured login Url. // This can / will fail if running inside a shell with no display, ie // ssh'ing into a box attempting to authenticate via external browser. -func openBrowser(idpURL string) error { - err := browser.OpenURL(idpURL) +func openBrowser(loginURL string) error { + err := browser.OpenURL(loginURL) if err != nil { logger.Infof("failed to open a browser. err: %v", err) return err @@ -91,6 +92,7 @@ func getIdpURLProofKey( authenticator string, application string, account string, + user string, callbackPort int) (string, string, error) { headers := make(map[string]string) @@ -108,6 +110,7 @@ func getIdpURLProofKey( ClientAppID: clientType, ClientAppVersion: SnowflakeGoDriverVersion, AccountName: account, + LoginName: user, ClientEnvironment: clientEnvironment, Authenticator: authenticator, BrowserModeRedirectPort: strconv.Itoa(callbackPort), @@ -144,6 +147,24 @@ func getIdpURLProofKey( return respd.Data.SSOURL, respd.Data.ProofKey, nil } +// Gets the login URL for multiple SAML +func getLoginURL(sr *snowflakeRestful, user string, callbackPort int) (string, string, error) { + proofKey := generateProofKey() + + params := &url.Values{} + params.Add("login_name", user) + params.Add("browser_mode_redirect_port", strconv.Itoa(callbackPort)) + params.Add("proof_key", proofKey) + url := sr.getFullURL(consoleLoginRequestPath, params) + + return url.String(), proofKey, nil +} + +func generateProofKey() string { + randomness := getSecureRandom(32) + return base64.StdEncoding.WithPadding(base64.StdPadding).EncodeToString(randomness) +} + // The response returned from Snowflake looks like so: // GET /?token=encodedSamlToken // Host: localhost:54001 @@ -187,10 +208,11 @@ func authenticateByExternalBrowser( user string, password string, externalBrowserTimeout time.Duration, + disableConsoleLogin ConfigBool, ) ([]byte, []byte, error) { resultChan := make(chan authenticateByExternalBrowserResult, 1) go func() { - resultChan <- doAuthenticateByExternalBrowser(ctx, sr, authenticator, application, account, user, password) + resultChan <- doAuthenticateByExternalBrowser(ctx, sr, authenticator, application, account, user, password, disableConsoleLogin) }() select { case <-time.After(externalBrowserTimeout): @@ -204,7 +226,7 @@ func authenticateByExternalBrowser( // - the golang snowflake driver communicates to Snowflake that the user wishes to // authenticate via external browser // - snowflake sends back the IDP Url configured at the Snowflake side for the -// provided account +// provided account, or use the multiple SAML way via console login // - the default browser is opened to that URL // - user authenticates at the IDP, and is redirected to Snowflake // - Snowflake directs the user back to the driver @@ -217,6 +239,7 @@ func doAuthenticateByExternalBrowser( account string, user string, password string, + disableConsoleLogin ConfigBool, ) authenticateByExternalBrowserResult { l, err := createLocalTCPListener() if err != nil { @@ -225,13 +248,22 @@ func doAuthenticateByExternalBrowser( defer l.Close() callbackPort := l.Addr().(*net.TCPAddr).Port - idpURL, proofKey, err := getIdpURLProofKey( - ctx, sr, authenticator, application, account, callbackPort) + + var loginURL string + var proofKey string + if disableConsoleLogin == ConfigBoolTrue { + // Gets the IDP URL and Proof Key from Snowflake + loginURL, proofKey, err = getIdpURLProofKey(ctx, sr, authenticator, application, account, user, callbackPort) + } else { + // Multiple SAML way to do authentication via console login + loginURL, proofKey, err = getLoginURL(sr, user, callbackPort) + } + if err != nil { return authenticateByExternalBrowserResult{nil, nil, err} } - if err = openBrowser(idpURL); err != nil { + if err = openBrowser(loginURL); err != nil { return authenticateByExternalBrowserResult{nil, nil, err} } diff --git a/authexternalbrowser_test.go b/authexternalbrowser_test.go index ea1a19ac9..a839d389d 100644 --- a/authexternalbrowser_test.go +++ b/authexternalbrowser_test.go @@ -5,6 +5,7 @@ package gosnowflake import ( "context" "errors" + "net/url" "strings" "testing" "time" @@ -91,17 +92,17 @@ func TestUnitAuthenticateByExternalBrowser(t *testing.T) { FuncPostAuthSAML: postAuthExternalBrowserError, TokenAccessor: getSimpleTokenAccessor(), } - _, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout) + _, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout, ConfigBoolTrue) if err == nil { t.Fatal("should have failed.") } sr.FuncPostAuthSAML = postAuthExternalBrowserFail - _, _, err = authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout) + _, _, err = authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout, ConfigBoolTrue) if err == nil { t.Fatal("should have failed.") } sr.FuncPostAuthSAML = postAuthExternalBrowserFailWithCode - _, _, err = authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout) + _, _, err = authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout, ConfigBoolTrue) if err == nil { t.Fatal("should have failed.") } @@ -128,7 +129,7 @@ func TestAuthenticationTimeout(t *testing.T) { FuncPostAuthSAML: postAuthExternalBrowserError, TokenAccessor: getSimpleTokenAccessor(), } - _, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout) + _, _, err := authenticateByExternalBrowser(context.Background(), sr, authenticator, application, account, user, password, timeout, ConfigBoolTrue) if err.Error() != "authentication timed out" { t.Fatal("should have timed out") } @@ -146,3 +147,29 @@ func Test_createLocalTCPListener(t *testing.T) { // Close the listener after the test. defer listener.Close() } + +func TestUnitGetLoginURL(t *testing.T) { + expectedScheme := "https" + expectedHost := "abc.com:443" + user := "u" + callbackPort := 123 + sr := &snowflakeRestful{ + Protocol: "https", + Host: "abc.com", + Port: 443, + TokenAccessor: getSimpleTokenAccessor(), + } + + loginURL, proofKey, err := getLoginURL(sr, user, callbackPort) + assertNilF(t, err, "failed to get login URL") + assertNotNilF(t, len(proofKey), "proofKey should be non-empty string") + + urlPtr, err := url.Parse(loginURL) + assertNilF(t, err, "failed to parse the login URL") + assertEqualF(t, urlPtr.Scheme, expectedScheme) + assertEqualF(t, urlPtr.Host, expectedHost) + assertEqualF(t, urlPtr.Path, consoleLoginRequestPath) + assertStringContainsF(t, urlPtr.RawQuery, "login_name") + assertStringContainsF(t, urlPtr.RawQuery, "browser_mode_redirect_port") + assertStringContainsF(t, urlPtr.RawQuery, "proof_key") +} diff --git a/dsn.go b/dsn.go index 341e8e2b4..e2c541a47 100644 --- a/dsn.go +++ b/dsn.go @@ -105,6 +105,8 @@ type Config struct { IncludeRetryReason ConfigBool // Should retried request contain retry reason ClientConfigFile string // File path to the client configuration json file + + DisableConsoleLogin ConfigBool // Indicates whether console login should be disabled } // Validate enables testing if config is correct. @@ -262,6 +264,9 @@ func DSN(cfg *Config) (dsn string, err error) { if cfg.ClientConfigFile != "" { params.Add("clientConfigFile", cfg.ClientConfigFile) } + if cfg.DisableConsoleLogin != configBoolNotSet { + params.Add("disableConsoleLogin", strconv.FormatBool(cfg.DisableConsoleLogin != ConfigBoolFalse)) + } dsn = fmt.Sprintf("%v:%v@%v:%v", url.QueryEscape(cfg.User), url.QueryEscape(cfg.Password), cfg.Host, cfg.Port) if params.Encode() != "" { @@ -754,6 +759,17 @@ func parseDSNParams(cfg *Config, params string) (err error) { } case "clientConfigFile": cfg.ClientConfigFile = value + case "disableConsoleLogin": + var vv bool + vv, err = strconv.ParseBool(value) + if err != nil { + return + } + if vv { + cfg.DisableConsoleLogin = ConfigBoolTrue + } else { + cfg.DisableConsoleLogin = ConfigBoolFalse + } default: if cfg.Params == nil { cfg.Params = make(map[string]*string) diff --git a/dsn_test.go b/dsn_test.go index e133d8e8b..dfa2a960f 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -714,6 +714,40 @@ func TestParseDSN(t *testing.T) { dsn: "u:p@a.snowflakecomputing.com:443?authenticator=http%3A%2F%2Fsc.okta.com&ocspFailOpen=true&validateDefaultParameters=true", err: errFailedToParseAuthenticator(), }, + { + dsn: "u:p@a.snowflake.local:9876?account=a&protocol=http&authenticator=EXTERNALBROWSER&disableConsoleLogin=true", + config: &Config{ + Account: "a", User: "u", Password: "p", + Authenticator: AuthTypeExternalBrowser, + Protocol: "http", Host: "a.snowflake.local", Port: 9876, + OCSPFailOpen: OCSPFailOpenTrue, + ValidateDefaultParameters: ConfigBoolTrue, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, + DisableConsoleLogin: ConfigBoolTrue, + }, + ocspMode: ocspModeFailOpen, + err: nil, + }, + { + dsn: "u:p@a.snowflake.local:9876?account=a&protocol=http&authenticator=EXTERNALBROWSER&disableConsoleLogin=false", + config: &Config{ + Account: "a", User: "u", Password: "p", + Authenticator: AuthTypeExternalBrowser, + Protocol: "http", Host: "a.snowflake.local", Port: 9876, + OCSPFailOpen: OCSPFailOpenTrue, + ValidateDefaultParameters: ConfigBoolTrue, + ClientTimeout: defaultClientTimeout, + JWTClientTimeout: defaultJWTClientTimeout, + ExternalBrowserTimeout: defaultExternalBrowserTimeout, + IncludeRetryReason: ConfigBoolTrue, + DisableConsoleLogin: ConfigBoolFalse, + }, + ocspMode: ocspModeFailOpen, + err: nil, + }, } for _, at := range []AuthType{AuthTypeExternalBrowser, AuthTypeOAuth} { @@ -873,6 +907,9 @@ func TestParseDSN(t *testing.T) { if test.config.IncludeRetryReason != cfg.IncludeRetryReason { t.Fatalf("%v: Failed to match IncludeRetryReason. expected: %v, got: %v", i, test.config.IncludeRetryReason, cfg.IncludeRetryReason) } + if test.config.DisableConsoleLogin != cfg.DisableConsoleLogin { + t.Fatalf("%v: Failed to match DisableConsoleLogin. expected: %v, got: %v", i, test.config.DisableConsoleLogin, cfg.DisableConsoleLogin) + } assertEqualF(t, cfg.ClientConfigFile, test.config.ClientConfigFile, "client config file") case test.err != nil: driverErrE, okE := test.err.(*SnowflakeError) @@ -1322,6 +1359,26 @@ func TestDSN(t *testing.T) { }, dsn: "u:p@a.b.c.snowflakecomputing.com:443?clientConfigFile=c%3A%5CUsers%5Cuser%5Cconfig.json&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + Authenticator: AuthTypeExternalBrowser, + DisableConsoleLogin: ConfigBoolTrue, + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?authenticator=externalbrowser&disableConsoleLogin=true&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", + }, + { + cfg: &Config{ + User: "u", + Password: "p", + Account: "a.b.c", + Authenticator: AuthTypeExternalBrowser, + DisableConsoleLogin: ConfigBoolFalse, + }, + dsn: "u:p@a.b.c.snowflakecomputing.com:443?authenticator=externalbrowser&disableConsoleLogin=false&ocspFailOpen=true®ion=b.c&validateDefaultParameters=true", + }, } for _, test := range testcases { t.Run(test.dsn, func(t *testing.T) { diff --git a/restful.go b/restful.go index c92d9c762..9b69c6700 100644 --- a/restful.go +++ b/restful.go @@ -40,6 +40,7 @@ const ( monitoringQueriesPath = "/monitoring/queries" sessionRequestPath = "/session" heartBeatPath = "/session/heartbeat" + consoleLoginRequestPath = "/console/login" ) type (