Skip to content

Commit

Permalink
add oauth
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-alhuang committed Jul 24, 2023
1 parent 52f0a8a commit c49a670
Show file tree
Hide file tree
Showing 18 changed files with 728 additions and 44 deletions.
Binary file modified .github/scripts/profile.json.gpg
Binary file not shown.
Binary file modified .github/scripts/profile_azure.json.gpg
Binary file not shown.
Binary file modified .github/scripts/profile_gcs.json.gpg
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,11 @@ public Config validate(Map<String, String> connectorConfigs) {
}

// If private key or private key passphrase is provided through file, skip validation
if (connectorConfigs.getOrDefault(Utils.SF_PRIVATE_KEY, "").contains("${file:")
|| connectorConfigs.getOrDefault(Utils.PRIVATE_KEY_PASSPHRASE, "").contains("${file:"))
if (connectorConfigs
.getOrDefault(Utils.SF_AUTHENTICATOR, Utils.SNOWFLAKE_JWT)
.equals(Utils.SNOWFLAKE_JWT)
&& (connectorConfigs.getOrDefault(Utils.SF_PRIVATE_KEY, "").contains("${file:")
|| connectorConfigs.getOrDefault(Utils.PRIVATE_KEY_PASSPHRASE, "").contains("${file:")))
return result;

// We don't validate name, since it is not included in the return value
Expand Down Expand Up @@ -244,6 +247,28 @@ public Config validate(Map<String, String> connectorConfigs) {
case "0013":
Utils.updateConfigErrorMessage(result, Utils.SF_PRIVATE_KEY, " must be non-empty");
break;
case "0026":
Utils.updateConfigErrorMessage(
result,
Utils.SF_OAUTH_CLIENT_ID,
" must be non-empty when using oauth authenticator");
break;
case "0027":
Utils.updateConfigErrorMessage(
result,
Utils.SF_OAUTH_CLIENT_SECRET,
" must be non-empty when using oauth authenticator");
break;
case "0028":
Utils.updateConfigErrorMessage(
result,
Utils.SF_OAUTH_REFRESH_TOKEN,
" must be non-empty when using oauth authenticator");
break;
case "0029":
Utils.updateConfigErrorMessage(
result, Utils.SF_AUTHENTICATOR, " is not a valid authenticator");
break;
case "0002":
Utils.updateConfigErrorMessage(
result, Utils.SF_PRIVATE_KEY, " must be a valid PEM RSA private key");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ public class SnowflakeSinkConnectorConfig {
static final String SNOWFLAKE_DATABASE = Utils.SF_DATABASE;
static final String SNOWFLAKE_SCHEMA = Utils.SF_SCHEMA;
static final String SNOWFLAKE_PRIVATE_KEY_PASSPHRASE = Utils.PRIVATE_KEY_PASSPHRASE;
static final String AUTHENTICATOR = Utils.SF_AUTHENTICATOR;
static final String OAUTH_CLIENT_ID = Utils.SF_OAUTH_CLIENT_ID;
static final String OAUTH_CLIENT_SECRET = Utils.SF_OAUTH_CLIENT_SECRET;
static final String OAUTH_REFRESH_TOKEN = Utils.SF_OAUTH_REFRESH_TOKEN;

// For Snowpipe Streaming client
public static final String SNOWFLAKE_ROLE = Utils.SF_ROLE;
Expand Down
207 changes: 202 additions & 5 deletions src/main/java/com/snowflake/kafka/connector/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,42 @@
import com.snowflake.kafka.connector.internal.BufferThreshold;
import com.snowflake.kafka.connector.internal.KCLogger;
import com.snowflake.kafka.connector.internal.SnowflakeErrors;
import com.snowflake.kafka.connector.internal.SnowflakeURL;
import com.snowflake.kafka.connector.internal.streaming.IngestionMethodConfig;
import com.snowflake.kafka.connector.internal.streaming.StreamingUtils;
import java.io.BufferedReader;
import java.io.File;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.net.Authenticator;
import java.net.PasswordAuthentication;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLConnection;
import java.net.URLEncoder;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import net.snowflake.client.jdbc.internal.apache.http.HttpHeaders;
import net.snowflake.client.jdbc.internal.apache.http.client.methods.CloseableHttpResponse;
import net.snowflake.client.jdbc.internal.apache.http.client.methods.HttpPost;
import net.snowflake.client.jdbc.internal.apache.http.client.utils.URIBuilder;
import net.snowflake.client.jdbc.internal.apache.http.entity.ContentType;
import net.snowflake.client.jdbc.internal.apache.http.entity.StringEntity;
import net.snowflake.client.jdbc.internal.apache.http.impl.client.CloseableHttpClient;
import net.snowflake.client.jdbc.internal.apache.http.impl.client.HttpClientBuilder;
import net.snowflake.client.jdbc.internal.apache.http.util.EntityUtils;
import net.snowflake.client.jdbc.internal.google.api.client.http.HttpStatusCodes;
import net.snowflake.client.jdbc.internal.google.gson.JsonObject;
import net.snowflake.client.jdbc.internal.google.gson.JsonParser;
import org.apache.kafka.common.config.Config;
import org.apache.kafka.common.config.ConfigException;
import org.apache.kafka.common.config.ConfigValue;
Expand All @@ -62,6 +81,10 @@ public class Utils {
public static final String SF_SSL = "sfssl"; // for test only
public static final String SF_WAREHOUSE = "sfwarehouse"; // for test only
public static final String PRIVATE_KEY_PASSPHRASE = "snowflake.private.key" + ".passphrase";
public static final String SF_AUTHENTICATOR = "snowflake.authenticator";
public static final String SF_OAUTH_CLIENT_ID = "snowflake.oauth.client.id";
public static final String SF_OAUTH_CLIENT_SECRET = "snowflake.oauth.client.secret";
public static final String SF_OAUTH_REFRESH_TOKEN = "snowflake.refresh.token";

/**
* This value should be present if ingestion method is {@link
Expand Down Expand Up @@ -105,6 +128,19 @@ public class Utils {
public static final String GET_EXCEPTION_MISSING_MESSAGE = "missing exception message";
public static final String GET_EXCEPTION_MISSING_CAUSE = "missing exception cause";

// OAuth
public static final String TOKEN_REQUEST_ENDPOINT = "/oauth/token-request";
public static final String OAUTH_CONTENT_TYPE_HEADER = "application/x-www-form-urlencoded";
public static final String BASIC_AUTH_HEADER_PREFIX = "Basic ";
public static final String GRANT_TYPE_PARAM = "grant_type";
public static final String REFRESH_TOKEN = "refresh_token";
public static final String ACCESS_TOKEN = "access_token";
public static final String SNOWFLAKE_JWT = "snowflake_jwt";
public static final String OAUTH = "oauth";
public static final String REDIRECT_URI = "redirect_uri";
public static final String DEFAULT_REDIRECT_URI = "https://localhost.com/oauth";
public static final int OAUTH_MAX_RETRY = 5;

private static final KCLogger LOGGER = new KCLogger(Utils.class.getName());

/**
Expand Down Expand Up @@ -440,11 +476,51 @@ && parseTopicToTableMap(config.get(SnowflakeSinkConnectorConfig.TOPICS_TABLES_MA
Utils.formatString("{} cannot be empty.", SnowflakeSinkConnectorConfig.SNOWFLAKE_SCHEMA));
}

if (!config.containsKey(SnowflakeSinkConnectorConfig.SNOWFLAKE_PRIVATE_KEY)) {
invalidConfigParams.put(
SnowflakeSinkConnectorConfig.SNOWFLAKE_PRIVATE_KEY,
Utils.formatString(
"{} cannot be empty.", SnowflakeSinkConnectorConfig.SNOWFLAKE_PRIVATE_KEY));
switch (config.getOrDefault(SnowflakeSinkConnectorConfig.AUTHENTICATOR, SNOWFLAKE_JWT)) {
case SNOWFLAKE_JWT:
if (!config.containsKey(SnowflakeSinkConnectorConfig.SNOWFLAKE_PRIVATE_KEY)) {
invalidConfigParams.put(
SnowflakeSinkConnectorConfig.SNOWFLAKE_PRIVATE_KEY,
Utils.formatString(
"{} cannot be empty when using {} authenticator.",
SnowflakeSinkConnectorConfig.SNOWFLAKE_PRIVATE_KEY,
SNOWFLAKE_JWT));
}
break;
case OAUTH:
if (!config.containsKey(SnowflakeSinkConnectorConfig.OAUTH_CLIENT_ID)) {
invalidConfigParams.put(
SnowflakeSinkConnectorConfig.OAUTH_CLIENT_ID,
Utils.formatString(
"{} cannot be empty when using {} authenticator.",
SnowflakeSinkConnectorConfig.OAUTH_CLIENT_ID,
OAUTH));
}
if (!config.containsKey(SnowflakeSinkConnectorConfig.OAUTH_CLIENT_SECRET)) {
invalidConfigParams.put(
SnowflakeSinkConnectorConfig.OAUTH_CLIENT_SECRET,
Utils.formatString(
"{} cannot be empty when using {} authenticator.",
SnowflakeSinkConnectorConfig.OAUTH_CLIENT_SECRET,
OAUTH));
}
if (!config.containsKey(SnowflakeSinkConnectorConfig.OAUTH_REFRESH_TOKEN)) {
invalidConfigParams.put(
SnowflakeSinkConnectorConfig.OAUTH_REFRESH_TOKEN,
Utils.formatString(
"{} cannot be empty when using {} authenticator.",
SnowflakeSinkConnectorConfig.OAUTH_REFRESH_TOKEN,
OAUTH));
}
break;
default:
invalidConfigParams.put(
SnowflakeSinkConnectorConfig.AUTHENTICATOR,
Utils.formatString(
"{} should be one of {} or {}.",
SnowflakeSinkConnectorConfig.AUTHENTICATOR,
SNOWFLAKE_JWT,
OAUTH));
}

if (!config.containsKey(SnowflakeSinkConnectorConfig.SNOWFLAKE_USER)) {
Expand Down Expand Up @@ -704,6 +780,127 @@ public static String formatString(String format, Object... vars) {
return format;
}

/**
* Get OAuth access token given refresh token
*
* @param url OAuth server url
* @param clientId OAuth clientId
* @param clientSecret OAuth clientSecret
* @param refreshToken OAuth refresh token
* @return OAuth access token
*/
public static String getSnowflakeOAuthAccessToken(
SnowflakeURL url, String clientId, String clientSecret, String refreshToken) {
return getSnowflakeOAuthToken(
url, clientId, clientSecret, refreshToken, REFRESH_TOKEN, REFRESH_TOKEN, ACCESS_TOKEN);
}

/**
* Get OAuth token given integration info <a
* href="https://docs.snowflake.com/en/user-guide/oauth-snowflake-overview">Snowflake OAuth
* Overview</a>
*
* @param url OAuth server url
* @param clientId OAuth clientId
* @param clientSecret OAuth clientSecret
* @param credential OAuth credential, either az code or refresh token
* @param grantType OAuth grant type, either authorization_code or refresh_token
* @param credentialType OAuth credential key, either code or refresh_token
* @param tokenType type of OAuth token to get, either access_token or refresh_token
* @return OAuth token
*/
public static String getSnowflakeOAuthToken(
SnowflakeURL url,
String clientId,
String clientSecret,
String credential,
String grantType,
String credentialType,
String tokenType) {
Map<String, String> headers = new HashMap<>();
headers.put(HttpHeaders.CONTENT_TYPE, OAUTH_CONTENT_TYPE_HEADER);
headers.put(
HttpHeaders.AUTHORIZATION,
BASIC_AUTH_HEADER_PREFIX
+ Base64.getEncoder().encodeToString((clientId + ":" + clientSecret).getBytes()));

Map<String, String> payload = new HashMap<>();
payload.put(GRANT_TYPE_PARAM, grantType);
payload.put(credentialType, credential);
payload.put(REDIRECT_URI, DEFAULT_REDIRECT_URI);

// Encode and convert payload into string entity
String payloadString =
payload.entrySet().stream()
.map(
e -> {
try {
return e.getKey() + "=" + URLEncoder.encode(e.getValue(), "UTF-8");
} catch (UnsupportedEncodingException ex) {
throw new RuntimeException(ex);
}
})
.collect(Collectors.joining("&"));
final StringEntity entity =
new StringEntity(payloadString, ContentType.APPLICATION_FORM_URLENCODED);

HttpPost post = makeOAuthHttpPost(url, TOKEN_REQUEST_ENDPOINT, headers, entity);

// Request access token
CloseableHttpClient client = HttpClientBuilder.create().build();
for (int retries = 0; retries < OAUTH_MAX_RETRY; retries++) {
try (CloseableHttpResponse httpResponse = client.execute(post)) {
String respBodyString = EntityUtils.toString(httpResponse.getEntity());

if (httpResponse.getStatusLine().getStatusCode() == HttpStatusCodes.STATUS_CODE_OK) {
JsonObject respBody = JsonParser.parseString(respBodyString).getAsJsonObject();
if (respBody.has(tokenType)) {
// Trim surrounding quotation marks
return respBody.get(tokenType).toString().replaceAll("^\"|\"$", "");
}
}
} catch (Exception e) {
// Exponential backoff retires
try {
Thread.sleep((1L << retries) * 1000L);
} catch (InterruptedException ex) {
throw SnowflakeErrors.ERROR_1004.getException(ex);
}
}
}
throw SnowflakeErrors.ERROR_1004.getException("Failed to get access token");
}

/**
* Build OAuth http post request base on headers and payload
*
* @param url target url
* @param headers headers key value pairs
* @param entity payload entity
* @return HttpPost request for OAuth
*/
public static HttpPost makeOAuthHttpPost(
SnowflakeURL url, String path, Map<String, String> headers, StringEntity entity) {
// Build post request
URI uri;
try {
uri =
new URIBuilder().setHost(url.toString()).setScheme(url.getScheme()).setPath(path).build();
} catch (URISyntaxException e) {
throw SnowflakeErrors.ERROR_1004.getException(e);
}

// Add headers
HttpPost post = new HttpPost(uri);
for (Map.Entry<String, String> e : headers.entrySet()) {
post.addHeader(e.getKey(), e.getValue());
}

post.setEntity(entity);

return post;
}

/**
* Get the message and cause of a missing exception, handling the null or empty cases of each
*
Expand Down
Loading

0 comments on commit c49a670

Please sign in to comment.