diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java index 61184f9d..2ab200d8 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java @@ -97,4 +97,7 @@ public interface ConfigurationOptions { */ String DORIS_IGNORE_TYPE = "doris.ignore-type"; + String DORIS_SINK_ENABLE_2PC = "doris.sink.enable-2pc"; + boolean DORIS_SINK_ENABLE_2PC_DEFAULT = false; + } diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java index d2e845a0..798ec8cf 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java @@ -62,6 +62,17 @@ public Integer getIntegerProperty(String name, Integer defaultValue) { return defaultValue; } + public Boolean getBooleanProperty(String name) { + return getBooleanProperty(name, null); + } + + public Boolean getBooleanProperty(String name, Boolean defaultValue) { + if (getProperty(name) != null) { + return Boolean.valueOf(getProperty(name)); + } + return defaultValue; + } + public Settings merge(Properties properties) { if (properties == null || properties.isEmpty()) { return this; diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java index 5341f67e..c40420d8 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java @@ -23,8 +23,10 @@ import org.apache.doris.spark.rest.models.BackendV2; import org.apache.doris.spark.rest.models.RespContent; import org.apache.doris.spark.util.ListUtils; +import org.apache.doris.spark.util.ResponseUtil; import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; @@ -34,6 +36,7 @@ import org.apache.http.HttpHeaders; import org.apache.http.HttpResponse; import org.apache.http.HttpStatus; +import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.client.methods.HttpPut; import org.apache.http.entity.BufferedHttpEntity; import org.apache.http.entity.StringEntity; @@ -74,6 +77,9 @@ public class DorisStreamLoad implements Serializable { private final static List<String> DORIS_SUCCESS_STATUS = new ArrayList<>(Arrays.asList("Success", "Publish Timeout")); private static String loadUrlPattern = "http://%s/api/%s/%s/_stream_load?"; + + private static String abortUrlPattern = "http://%s/api/%s/%s/_stream_load_2pc?"; + private String user; private String passwd; private String loadUrlStr; @@ -99,9 +105,7 @@ public DorisStreamLoad(SparkSettings settings) { this.columns = settings.getProperty(ConfigurationOptions.DORIS_WRITE_FIELDS); this.maxFilterRatio = settings.getProperty(ConfigurationOptions.DORIS_MAX_FILTER_RATIO); this.streamLoadProp = getStreamLoadProp(settings); - cache = CacheBuilder.newBuilder() - .expireAfterWrite(cacheExpireTimeout, TimeUnit.MINUTES) - .build(new
BackendCacheLoader(settings)); + cache = CacheBuilder.newBuilder().expireAfterWrite(cacheExpireTimeout, TimeUnit.MINUTES).build(new BackendCacheLoader(settings)); fileType = streamLoadProp.getOrDefault("format", "csv"); if ("csv".equals(fileType)) { FIELD_DELIMITER = escapeString(streamLoadProp.getOrDefault("column_separator", "\t")); @@ -121,13 +125,13 @@ public String getLoadUrlStr() { } return loadUrlStr; } + private CloseableHttpClient getHttpClient() { - HttpClientBuilder httpClientBuilder = HttpClientBuilder.create() - .disableRedirectHandling(); + HttpClientBuilder httpClientBuilder = HttpClientBuilder.create().disableRedirectHandling(); return httpClientBuilder.build(); } - private HttpPut getHttpPut(String label, String loadUrlStr) { + private HttpPut getHttpPut(String label, String loadUrlStr, Boolean enable2PC) { HttpPut httpPut = new HttpPut(loadUrlStr); httpPut.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + authEncoded); httpPut.setHeader(HttpHeaders.EXPECT, "100-continue"); @@ -139,6 +143,9 @@ private HttpPut getHttpPut(String label, String loadUrlStr) { if (StringUtils.isNotBlank(maxFilterRatio)) { httpPut.setHeader("max_filter_ratio", maxFilterRatio); } + if (enable2PC) { + httpPut.setHeader("two_phase_commit", "true"); + } if (MapUtils.isNotEmpty(streamLoadProp)) { streamLoadProp.forEach(httpPut::setHeader); } @@ -158,100 +165,162 @@ public LoadResponse(int status, String respMsg, String respContent) { @Override public String toString() { - return "status: " + status + - ", resp msg: " + respMsg + - ", resp content: " + respContent; + return "status: " + status + ", resp msg: " + respMsg + ", resp content: " + respContent; } } - public String listToString(List<List<Object>> rows) { - return rows.stream().map(row -> - row.stream().map(field -> field == null ? NULL_VALUE : field.toString()) - .collect(Collectors.joining(FIELD_DELIMITER)) - ).collect(Collectors.joining(LINE_DELIMITER)); - } + public List<Integer> loadV2(List<List<Object>> rows, String[] dfColumns, Boolean enable2PC) throws StreamLoadException, JsonProcessingException { + List<String> loadData = parseLoadData(rows, dfColumns); + List<Integer> txnIds = new ArrayList<>(loadData.size()); - public void loadV2(List<List<Object>> rows, String[] dfColumns) throws StreamLoadException, JsonProcessingException { - if (fileType.equals("csv")) { - load(listToString(rows)); - } else if(fileType.equals("json")) { - List<Map<Object, Object>> dataList = new ArrayList<>(); - try { - for (List<Object> row : rows) { - Map<Object, Object> dataMap = new HashMap<>(); - if (dfColumns.length == row.size()) { - for (int i = 0; i < dfColumns.length; i++) { - Object col = row.get(i); - if (col instanceof Timestamp) { - dataMap.put(dfColumns[i], col.toString()); - continue; - } - dataMap.put(dfColumns[i], col); - } - } - dataList.add(dataMap); - } - } catch (Exception e) { - throw new StreamLoadException("The number of configured columns does not match the number of data columns."); + try { + for (String data : loadData) { + txnIds.add(load(data, enable2PC)); } - // splits large collections to normal collection to avoid the "Requested array size exceeds VM limit" exception - List<String> serializedList = ListUtils.getSerializedList(dataList, readJsonByLine ?
LINE_DELIMITER : null); - for (String serializedRows : serializedList) { - load(serializedRows); + } catch (StreamLoadException e) { + if (enable2PC && !txnIds.isEmpty()) { + LOG.error("load batch failed, abort previously pre-committed transactions"); + for (Integer txnId : txnIds) { + abort(txnId); + } } - } else { - throw new StreamLoadException(String.format("Unsupported file format in stream load: %s.", fileType)); + throw e; } + + return txnIds; + } - public void load(String value) throws StreamLoadException { - LoadResponse loadResponse = loadBatch(value); + public int load(String value, Boolean enable2PC) throws StreamLoadException { + + String label = generateLoadLabel(); + + LoadResponse loadResponse; + int responseHttpStatus = -1; + try (CloseableHttpClient httpClient = getHttpClient()) { + String loadUrlStr = String.format(loadUrlPattern, getBackend(), db, tbl); + LOG.debug("Stream load Request: {}, Body: {}", loadUrlStr, value); + // only to record the BE node in case of an exception + this.loadUrlStr = loadUrlStr; + + HttpPut httpPut = getHttpPut(label, loadUrlStr, enable2PC); + httpPut.setEntity(new StringEntity(value, StandardCharsets.UTF_8)); + HttpResponse httpResponse = httpClient.execute(httpPut); + responseHttpStatus = httpResponse.getStatusLine().getStatusCode(); + String respMsg = httpResponse.getStatusLine().getReasonPhrase(); + String response = EntityUtils.toString(new BufferedHttpEntity(httpResponse.getEntity()), StandardCharsets.UTF_8); + loadResponse = new LoadResponse(responseHttpStatus, respMsg, response); + } catch (IOException e) { + e.printStackTrace(); + String err = "http request exception, load url: " + loadUrlStr + ", failed to execute spark stream load with label: " + label; + LOG.warn(err, e); + loadResponse = new LoadResponse(responseHttpStatus, e.getMessage(), err); + } + if (loadResponse.status != HttpStatus.SC_OK) { - LOG.info("Streamload Response HTTP Status Error:{}", loadResponse); - throw new StreamLoadException("stream load error: " + loadResponse.respContent); + LOG.info("Stream load Response HTTP Status Error:{}", loadResponse); + // throw new StreamLoadException("stream load error: " + loadResponse.respContent); + throw new StreamLoadException("stream load error"); } else { ObjectMapper obj = new ObjectMapper(); try { RespContent respContent = obj.readValue(loadResponse.respContent, RespContent.class); if (!DORIS_SUCCESS_STATUS.contains(respContent.getStatus())) { - LOG.error("Streamload Response RES STATUS Error:{}", loadResponse); - throw new StreamLoadException("stream load error: " + loadResponse); + LOG.error("Stream load Response RES STATUS Error:{}", loadResponse); + throw new StreamLoadException("stream load error"); } - LOG.info("Streamload Response:{}", loadResponse); + LOG.info("Stream load Response:{}", loadResponse); + return respContent.getTxnId(); } catch (IOException e) { throw new StreamLoadException(e); } } + } - private LoadResponse loadBatch(String value) { - Calendar calendar = Calendar.getInstance(); - String label = String.format("spark_streamload_%s%02d%02d_%02d%02d%02d_%s", - calendar.get(Calendar.YEAR), calendar.get(Calendar.MONTH) + 1, calendar.get(Calendar.DAY_OF_MONTH), - calendar.get(Calendar.HOUR_OF_DAY), calendar.get(Calendar.MINUTE), calendar.get(Calendar.SECOND), - UUID.randomUUID().toString().replaceAll("-", "")); + public void commit(int txnId) throws StreamLoadException { - int responseHttpStatus = -1; - try (CloseableHttpClient httpClient = getHttpClient()) { - String loadUrlStr =
String.format(loadUrlPattern, getBackend(), db, tbl); - LOG.debug("Streamload Request:{} ,Body:{}", loadUrlStr, value); - //only to record the BE node in case of an exception - this.loadUrlStr = loadUrlStr; + try (CloseableHttpClient client = getHttpClient()) { + + String backend = getBackend(); + String commitUrl = String.format(abortUrlPattern, backend, db, tbl); + HttpPut httpPut = new HttpPut(commitUrl); + httpPut.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + authEncoded); + httpPut.setHeader(HttpHeaders.EXPECT, "100-continue"); + httpPut.setHeader(HttpHeaders.CONTENT_TYPE, "text/plain; charset=UTF-8"); + httpPut.setHeader("txn_operation", "commit"); + httpPut.setHeader("txn_id", String.valueOf(txnId)); + + CloseableHttpResponse response = client.execute(httpPut); + int statusCode = response.getStatusLine().getStatusCode(); + if (statusCode != 200 || response.getEntity() == null) { + LOG.warn("commit transaction response: " + response.getStatusLine().toString()); + throw new StreamLoadException("Failed to commit transaction " + txnId + " with url " + commitUrl); + } + + statusCode = response.getStatusLine().getStatusCode(); + String reasonPhrase = response.getStatusLine().getReasonPhrase(); + if (statusCode != 200) { + LOG.warn("commit failed with {}, reason {}", backend, reasonPhrase); + throw new StreamLoadException("stream load error: " + reasonPhrase); + } + + ObjectMapper mapper = new ObjectMapper(); + if (response.getEntity() != null) { + String loadResult = EntityUtils.toString(response.getEntity()); + Map<String, String> res = mapper.readValue(loadResult, new TypeReference<Map<String, String>>() { + }); + if (res.get("status").equals("Fail") && !ResponseUtil.isCommitted(res.get("msg"))) { + throw new StreamLoadException("Commit failed " + loadResult); + } else { + LOG.info("load result {}", loadResult); + } + } - HttpPut httpPut = getHttpPut(label, loadUrlStr); - httpPut.setEntity(new StringEntity(value, StandardCharsets.UTF_8)); - HttpResponse httpResponse = httpClient.execute(httpPut); - responseHttpStatus = httpResponse.getStatusLine().getStatusCode(); - String respMsg = httpResponse.getStatusLine().getReasonPhrase(); - String response = EntityUtils.toString(new BufferedHttpEntity(httpResponse.getEntity()), StandardCharsets.UTF_8); - return new LoadResponse(responseHttpStatus, respMsg, response); } catch (IOException e) { - e.printStackTrace(); - String err = "http request exception,load url : " + loadUrlStr + ",failed to execute spark streamload with label: " + label; - LOG.warn(err, e); - return new LoadResponse(responseHttpStatus, e.getMessage(), err); + throw new StreamLoadException(e); } + + } + + public void abort(int txnId) throws StreamLoadException { + + LOG.info("start abort transaction {}.", txnId); + + try (CloseableHttpClient client = getHttpClient()) { + String abortUrl = String.format(abortUrlPattern, getBackend(), db, tbl); + HttpPut httpPut = new HttpPut(abortUrl); + httpPut.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + authEncoded); + httpPut.setHeader(HttpHeaders.EXPECT, "100-continue"); + httpPut.setHeader(HttpHeaders.CONTENT_TYPE, "text/plain; charset=UTF-8"); + httpPut.setHeader("txn_operation", "abort"); + httpPut.setHeader("txn_id", String.valueOf(txnId)); + + CloseableHttpResponse response = client.execute(httpPut); + int statusCode = response.getStatusLine().getStatusCode(); + if (statusCode != 200 || response.getEntity() == null) { + LOG.warn("abort transaction response: " + response.getStatusLine().toString()); + throw new StreamLoadException("Failed to abort transaction " + txnId + " 
with url " + abortUrl); + } + + ObjectMapper mapper = new ObjectMapper(); + String loadResult = EntityUtils.toString(response.getEntity()); + Map<String, String> res = mapper.readValue(loadResult, new TypeReference<Map<String, String>>() { + }); + if (!"Success".equals(res.get("status"))) { + if (ResponseUtil.isCommitted(res.get("msg"))) { + throw new StreamLoadException("trying to abort a committed transaction, " + "did you recover from an old savepoint?"); + } + LOG.warn("Failed to abort transaction. txnId: {}, error: {}", txnId, res.get("msg")); + } + + } catch (IOException e) { + throw new StreamLoadException(e); + } + + LOG.info("abort transaction {} succeeded.", txnId); + } public Map<String, String> getStreamLoadProp(SparkSettings sparkSettings) { @@ -268,7 +337,7 @@ public Map<String, String> getStreamLoadProp(SparkSettings sparkSettings) { private String getBackend() { try { - //get backends from cache + // get backends from cache List<BackendV2.BackendRowV2> backends = cache.get("backends"); Collections.shuffle(backends); BackendV2.BackendRowV2 backend = backends.get(0); @@ -301,6 +370,54 @@ public List<BackendV2.BackendRowV2> load(String key) throws Exception { } + private List<String> parseLoadData(List<List<Object>> rows, String[] dfColumns) throws StreamLoadException, JsonProcessingException { + + List<String> loadDataList; + + switch (fileType.toUpperCase()) { + + case "CSV": + loadDataList = Collections.singletonList(rows.stream().map(row -> row.stream().map(field -> field == null ? NULL_VALUE : field.toString()).collect(Collectors.joining(FIELD_DELIMITER))).collect(Collectors.joining(LINE_DELIMITER))); + break; + case "JSON": + List<Map<Object, Object>> dataList = new ArrayList<>(); + try { + for (List<Object> row : rows) { + Map<Object, Object> dataMap = new HashMap<>(); + if (dfColumns.length == row.size()) { + for (int i = 0; i < dfColumns.length; i++) { + Object col = row.get(i); + if (col instanceof Timestamp) { + dataMap.put(dfColumns[i], col.toString()); + continue; + } + dataMap.put(dfColumns[i], col); + } + } + dataList.add(dataMap); + } + } catch (Exception e) { + throw new StreamLoadException("The number of configured columns does not match the number of data columns."); + } + // splits large collections to normal collection to avoid the "Requested array size exceeds VM limit" exception + loadDataList = ListUtils.getSerializedList(dataList, readJsonByLine ?
LINE_DELIMITER : null); + break; + default: + throw new StreamLoadException(String.format("Unsupported file format in stream load: %s.", fileType)); + + } + + return loadDataList; + + } + + private String generateLoadLabel() { + + Calendar calendar = Calendar.getInstance(); + return String.format("spark_streamload_%s%02d%02d_%02d%02d%02d_%s", calendar.get(Calendar.YEAR), calendar.get(Calendar.MONTH) + 1, calendar.get(Calendar.DAY_OF_MONTH), calendar.get(Calendar.HOUR_OF_DAY), calendar.get(Calendar.MINUTE), calendar.get(Calendar.SECOND), UUID.randomUUID().toString().replaceAll("-", "")); + + } + private String escapeString(String hexData) { if (hexData.startsWith("\\x") || hexData.startsWith("\\X")) { try { @@ -314,7 +431,7 @@ private String escapeString(String hexData) { } return stringBuilder.toString(); } catch (Exception e) { - throw new RuntimeException("escape column_separator or line_delimiter error.{}" , e); + throw new RuntimeException("escape column_separator or line_delimiter error.{}", e); } } return hexData; diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/models/RespContent.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/models/RespContent.java index f7fa6ff4..7829cc23 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/models/RespContent.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/models/RespContent.java @@ -75,6 +75,10 @@ public class RespContent { @JsonProperty(value = "ErrorURL") private String ErrorURL; + public int getTxnId() { + return TxnId; + } + public String getStatus() { return Status; } diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ResponseUtil.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ResponseUtil.java new file mode 100644 index 00000000..1b6a66b1 --- /dev/null +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ResponseUtil.java @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.spark.util; + +import java.util.regex.Pattern; + +public class ResponseUtil { + public static final Pattern LABEL_EXIST_PATTERN = + Pattern.compile("errCode = 2, detailMessage = Label \\[(.*)\\] " + + "has already been used, relate to txn \\[(\\d+)\\]"); + public static final Pattern COMMITTED_PATTERN = + Pattern.compile("errCode = 2, detailMessage = transaction \\[(\\d+)\\] " + + "is already \\b(COMMITTED|committed|VISIBLE|visible)\\b, not pre-committed."); + + public static boolean isCommitted(String msg) { + return COMMITTED_PATTERN.matcher(msg).matches(); + } +} \ No newline at end of file diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala new file mode 100644 index 00000000..a36e634a --- /dev/null +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.spark.listener + +import org.apache.doris.spark.load.DorisStreamLoad +import org.apache.doris.spark.sql.Utils +import org.apache.spark.scheduler._ +import org.apache.spark.util.CollectionAccumulator +import org.slf4j.{Logger, LoggerFactory} + +import java.time.Duration +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.util.{Failure, Success} + +class DorisTransactionListener(preCommittedTxnAcc: CollectionAccumulator[Int], dorisStreamLoad: DorisStreamLoad) + extends SparkListener { + + val logger: Logger = LoggerFactory.getLogger(classOf[DorisTransactionListener]) + + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + val txnIds: mutable.Buffer[Int] = preCommittedTxnAcc.value.asScala + val failedTxnIds = mutable.Buffer[Int]() + jobEnd.jobResult match { + // if the job succeeded, commit all transactions + case JobSucceeded => + if (txnIds.isEmpty) { + logger.warn("job run succeeded, but there are no pre-committed txn ids") + return + } + logger.info("job run succeeded, start committing transactions") + txnIds.foreach(txnId => + Utils.retry(3, Duration.ofSeconds(1), logger) { + dorisStreamLoad.commit(txnId) + } match { + case Success(_) => + case Failure(_) => failedTxnIds += txnId + } + ) + + if (failedTxnIds.nonEmpty) { + logger.error("uncommitted txn ids: {}", failedTxnIds.mkString(",")) + } else { + logger.info("commit transactions succeeded") + } + // if the job failed, abort all pre-committed transactions + case _ => + if (txnIds.isEmpty) { + logger.warn("job run failed, but there are no pre-committed txn ids") + return + } + logger.info("job run failed, start aborting transactions") + txnIds.foreach(txnId => + Utils.retry(3, Duration.ofSeconds(1), logger) { + dorisStreamLoad.abort(txnId) + } match { + case Success(_) => + case Failure(_) => failedTxnIds += txnId + }) + if (failedTxnIds.nonEmpty) { + logger.error("not aborted txn ids: {}", failedTxnIds.mkString(",")) + } else { + logger.info("abort transactions succeeded") + } + } + } + +} diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala index 2f3a5bb0..2b9c3c18 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala @@ -174,7 +174,7 @@ private[spark] object Utils { Success(result) case Failure(exception: T) if retryTimes > 0 => logger.warn(s"Execution failed caused by: ", exception) - logger.warn(s"$retryTimes times retry remaining, the next will be in ${interval.toMillis}ms") + logger.warn(s"$retryTimes retries remaining, the next attempt will be in ${interval.toMillis} ms") LockSupport.parkNanos(interval.toNanos) retry(retryTimes - 1, interval, logger)(f) case Failure(exception) => Failure(exception) diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala index 3839ff71..2b918e88 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala @@ -18,6 +18,7 @@ package org.apache.doris.spark.writer import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings} +import org.apache.doris.spark.listener.DorisTransactionListener import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, 
DorisStreamLoad} import org.apache.doris.spark.sql.Utils import org.apache.spark.sql.DataFrame @@ -28,6 +29,7 @@ import java.time.Duration import java.util import java.util.Objects import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.util.{Failure, Success} class DorisWriter(settings: SparkSettings) extends Serializable { @@ -44,9 +46,19 @@ class DorisWriter(settings: SparkSettings) extends Serializable { private val batchInterValMs: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_INTERVAL_MS, ConfigurationOptions.DORIS_SINK_BATCH_INTERVAL_MS_DEFAULT) + private val enable2PC: Boolean = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_ENABLE_2PC, + ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT); + private val dorisStreamLoader: DorisStreamLoad = CachedDorisStreamLoadClient.getOrCreate(settings) def write(dataFrame: DataFrame): Unit = { + + val sc = dataFrame.sqlContext.sparkContext + val preCommittedTxnAcc = sc.collectionAccumulator[Int]("preCommittedTxnAcc") + if (enable2PC) { + sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, dorisStreamLoader)) + } + var resultRdd = dataFrame.rdd val dfColumns = dataFrame.columns if (Objects.nonNull(sinkTaskPartitionSize)) { @@ -65,11 +77,27 @@ class DorisWriter(settings: SparkSettings) extends Serializable { * */ def flush(batch: Iterable[util.List[Object]], dfColumns: Array[String]): Unit = { - Utils.retry[Unit, Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) { - dorisStreamLoader.loadV2(batch.toList.asJava, dfColumns) + Utils.retry[util.List[Integer], Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) { + dorisStreamLoader.loadV2(batch.toList.asJava, dfColumns, enable2PC) } match { - case Success(_) => + case Success(txnIds) => if (enable2PC) txnIds.asScala.foreach(txnId => preCommittedTxnAcc.add(txnId)) case Failure(e) => + if (enable2PC) { + // if task run failed, acc value will not be returned to driver, + // should abort all pre committed transactions inside the task + logger.info("load task failed, start aborting previously pre-committed transactions") + val abortFailedTxnIds = mutable.Buffer[Int]() + preCommittedTxnAcc.value.asScala.foreach(txnId => { + Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) { + dorisStreamLoader.abort(txnId) + } match { + case Success(_) => + case Failure(_) => abortFailedTxnIds += txnId + } + }) + if (abortFailedTxnIds.nonEmpty) logger.warn("not aborted txn ids: {}", abortFailedTxnIds.mkString(",")) + preCommittedTxnAcc.reset() + } throw new IOException( s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e) }
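Usage note: with this patch the sink can run under two-phase commit. Each task pre-commits its batches through stream load (the two_phase_commit header), the returned transaction ids are collected in an accumulator, and DorisTransactionListener commits them when the job succeeds or aborts them when it fails; a failed task also aborts its own already pre-committed transactions in DorisWriter.flush. A minimal sketch of enabling the behavior from user code follows. Only doris.sink.enable-2pc is introduced by this patch (default false); the other option keys (doris.fenodes, doris.table.identifier, user, password) and the existence of a DataFrame named df are assumptions, shown here for context, not part of this diff.

    // Sketch only: enable two-phase commit on the Doris Spark sink.
    // doris.sink.enable-2pc is the new option from this patch; the remaining
    // options are assumed standard connector settings.
    df.write
      .format("doris")
      .option("doris.fenodes", "fe_host:8030")             // assumed FE HTTP address option
      .option("doris.table.identifier", "example_db.tbl")  // assumed target table option
      .option("user", "root")
      .option("password", "")
      .option("doris.sink.enable-2pc", "true")              // new in this patch, defaults to false
      .save()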