From 9a3a82886315233c0714fe8b7f84b60ffe8e23be Mon Sep 17 00:00:00 2001 From: Steve Jones Date: Wed, 25 Sep 2024 08:02:48 -0700 Subject: [PATCH] Service specific endpoints compatible resolver --- auth/auth.go | 3 + utils/environment_endpoint_resolver.go | 74 ++++++++++++++++ utils/environment_endpoint_resolver_test.go | 94 +++++++++++++++++++++ 3 files changed, 171 insertions(+) create mode 100644 utils/environment_endpoint_resolver.go create mode 100644 utils/environment_endpoint_resolver_test.go diff --git a/auth/auth.go b/auth/auth.go index e070d3e..c0de413 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -19,6 +19,7 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sts" "github.com/aws/aws-sdk-go/service/sts/stsiface" + "github.com/aws/secrets-store-csi-driver-provider-aws/utils" authv1 "k8s.io/api/authentication/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -83,6 +84,7 @@ func NewAuth( // Get an initial session to use for STS calls. sess, err := session.NewSession(aws.NewConfig(). + WithEndpointResolver(utils.EnvironmentEndpointResolver()). WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint). WithRegion(region), ) @@ -140,6 +142,7 @@ func (p Auth) GetAWSSession() (awsSession *session.Session, e error) { fetcher := &authTokenFetcher{p.nameSpace, p.svcAcc, p.k8sClient} ar := stscreds.NewWebIdentityRoleProviderWithToken(p.stsClient, *roleArn, ProviderName, fetcher) config := aws.NewConfig(). + WithEndpointResolver(utils.EnvironmentEndpointResolver()). WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint). // Use regional STS endpoint WithRegion(p.region). WithCredentials(credentials.NewCredentials(ar)) diff --git a/utils/environment_endpoint_resolver.go b/utils/environment_endpoint_resolver.go new file mode 100644 index 0000000..345adc8 --- /dev/null +++ b/utils/environment_endpoint_resolver.go @@ -0,0 +1,74 @@ +package utils + +import ( + "os" + "strings" + + "github.com/aws/aws-sdk-go/aws/endpoints" +) + +const ( + envVarDisable = "AWS_IGNORE_CONFIGURED_ENDPOINT_URLS" + envVarUrlDefault = "AWS_ENDPOINT_URL" + envVarUrlPrefix = "AWS_ENDPOINT_URL_" +) + +// non-standard endpoint service name to environment variable suffix mappings +var serviceToEnv = map[string]string{ + "secretsmanager": "SECRETS_MANAGER", +} + +var envResolver = endpoints.ResolverFunc(envResolve) + +// EnvironmentEndpointResolver uses environment variables to locate endpoints. +// +// Uses environment variables compatible with the service specific endpoints +// feature to locate service endpoints: +// +// - AWS_ENDPOINT_URL - default endpoint +// - AWS_ENDPOINT_URL_ - service specific endpoint +// - AWS_IGNORE_CONFIGURED_ENDPOINT_URLS - "true" to ignore configured +// +// When AWS_IGNORE_CONFIGURED_ENDPOINT_URLS is "true" all environment +// variables are ignored. +// +// When an endpoint is not configured via environment the default resolver +// is used. +func EnvironmentEndpointResolver() endpoints.Resolver { + return envResolver +} + +// envResolveEnabled should environment endpoints be used +func envResolveEnabled() bool { + return "true" != os.Getenv(envVarDisable) +} + +// serviceUrlEnvVar look up the custom mapping or use standard transform +func serviceUrlEnvVar(service string) string { + envVarSuffix, ok := serviceToEnv[service] + if !ok { + envVarSuffix = strings.ReplaceAll(strings.ToUpper(service), "-", "_") + } + return envVarUrlPrefix + envVarSuffix +} + +// urlFromEnvironment lookup url from service specific or default environment variable +func urlFromEnvironment(service string) string { + url := os.Getenv(serviceUrlEnvVar(service)) + if url == "" { + url = os.Getenv(envVarUrlDefault) + } + return url +} + +// envResolve lookup service endpoint via environment variables if enabled +func envResolve(service string, region string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { + if envResolveEnabled() { + if url := urlFromEnvironment(service); url != "" { + return endpoints.ResolvedEndpoint{ + URL: url, + }, nil + } + } + return endpoints.DefaultResolver().EndpointFor(service, region, opts...) +} diff --git a/utils/environment_endpoint_resolver_test.go b/utils/environment_endpoint_resolver_test.go new file mode 100644 index 0000000..b3533c7 --- /dev/null +++ b/utils/environment_endpoint_resolver_test.go @@ -0,0 +1,94 @@ +package utils + +import ( + "os" + "testing" + + "github.com/aws/aws-sdk-go/aws/endpoints" + "github.com/stretchr/testify/assert" +) + +func TestEnvironmentEndpointResolver_EndpointFor_Disabled(t *testing.T) { + err := os.Setenv("AWS_IGNORE_CONFIGURED_ENDPOINT_URLS", "true") + assert.NoError(t, err) + + err = os.Setenv("AWS_ENDPOINT_URL", "https://127.0.0.1:443") // should be ignored + assert.NoError(t, err) + + endpoint, err := EnvironmentEndpointResolver(). + EndpointFor("sts", "us-west-1", endpoints.STSRegionalEndpointOption) + assert.NoError(t, err) + + assert.Equal(t, "aws", endpoint.PartitionID) + assert.Equal(t, "v4", endpoint.SigningMethod) + assert.Equal(t, "sts", endpoint.SigningName) + assert.Equal(t, true, endpoint.SigningNameDerived) + assert.Equal(t, "us-west-1", endpoint.SigningRegion) + assert.Equal(t, "https://sts.us-west-1.amazonaws.com", endpoint.URL) +} + +func TestEnvironmentEndpointResolver_EndpointFor_Default(t *testing.T) { + err := os.Unsetenv("AWS_IGNORE_CONFIGURED_ENDPOINT_URLS") + assert.NoError(t, err) + + err = os.Unsetenv("AWS_ENDPOINT_URL_STS") + assert.NoError(t, err) + + err = os.Setenv("AWS_ENDPOINT_URL", "https://127.0.0.1:443") + assert.NoError(t, err) + + endpoint, err := EnvironmentEndpointResolver(). + EndpointFor("sts", "us-west-1", endpoints.STSRegionalEndpointOption) + assert.NoError(t, err) + + assert.Equal(t, "", endpoint.PartitionID) + assert.Equal(t, "", endpoint.SigningMethod) + assert.Equal(t, "", endpoint.SigningName) + assert.Equal(t, false, endpoint.SigningNameDerived) + assert.Equal(t, "", endpoint.SigningRegion) + assert.Equal(t, "https://127.0.0.1:443", endpoint.URL) +} + +func TestEnvironmentEndpointResolver_EndpointFor_ServiceSpecific(t *testing.T) { + err := os.Setenv("AWS_IGNORE_CONFIGURED_ENDPOINT_URLS", "false") + assert.NoError(t, err) + + err = os.Setenv("AWS_ENDPOINT_URL", "https://127.0.0.1:443/default") + assert.NoError(t, err) + + err = os.Setenv("AWS_ENDPOINT_URL_STS", "https://127.0.0.1:443/service-specific") + assert.NoError(t, err) + + endpoint, err := EnvironmentEndpointResolver(). + EndpointFor("sts", "us-west-1", endpoints.STSRegionalEndpointOption) + assert.NoError(t, err) + + assert.Equal(t, "", endpoint.PartitionID) + assert.Equal(t, "", endpoint.SigningMethod) + assert.Equal(t, "", endpoint.SigningName) + assert.Equal(t, false, endpoint.SigningNameDerived) + assert.Equal(t, "", endpoint.SigningRegion) + assert.Equal(t, "https://127.0.0.1:443/service-specific", endpoint.URL) +} + +func TestEnvironmentEndpointResolver_EndpointFor_ServiceSpecificCustom(t *testing.T) { + err := os.Setenv("AWS_IGNORE_CONFIGURED_ENDPOINT_URLS", "false") + assert.NoError(t, err) + + err = os.Setenv("AWS_ENDPOINT_URL", "https://127.0.0.1:443/default") + assert.NoError(t, err) + + err = os.Setenv("AWS_ENDPOINT_URL_SECRETS_MANAGER", "https://127.0.0.1:443/service-specific") + assert.NoError(t, err) + + endpoint, err := EnvironmentEndpointResolver(). + EndpointFor("secretsmanager", "us-west-1", endpoints.STSRegionalEndpointOption) + assert.NoError(t, err) + + assert.Equal(t, "", endpoint.PartitionID) + assert.Equal(t, "", endpoint.SigningMethod) + assert.Equal(t, "", endpoint.SigningName) + assert.Equal(t, false, endpoint.SigningNameDerived) + assert.Equal(t, "", endpoint.SigningRegion) + assert.Equal(t, "https://127.0.0.1:443/service-specific", endpoint.URL) +}