From a4e99e57ca7528361888c6940ecde2c81abcd39e Mon Sep 17 00:00:00 2001
From: gnehil
Date: Fri, 8 Sep 2023 18:18:40 +0800
Subject: [PATCH] convert internal row value manually

---
 .../doris/spark/load/DorisStreamLoad.java     | 10 ++--
 .../spark/load/RecordBatchInputStream.java    | 20 ++-----
 .../org/apache/doris/spark/util/DataUtil.java | 45 +++++---------
 .../apache/doris/spark/sql/SchemaUtils.scala  | 58 ++++++++++++++++++-
 .../doris/spark/writer/DorisWriter.scala      | 11 ++--
 .../apache/doris/spark/util/DataUtilTest.java | 32 ----------
 .../doris/spark/sql/SchemaUtilsTest.scala     | 37 ++++++++++++
 7 files changed, 121 insertions(+), 92 deletions(-)
 delete mode 100644 spark-doris-connector/src/test/java/org/apache/doris/spark/util/DataUtilTest.java
 create mode 100644 spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/SchemaUtilsTest.scala

diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
index 23aab298..9ecfa405 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
@@ -43,9 +43,7 @@ import org.apache.http.impl.client.CloseableHttpClient;
 import org.apache.http.impl.client.HttpClientBuilder;
 import org.apache.http.util.EntityUtils;
-import org.apache.spark.sql.Row;
 import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer;
 import org.apache.spark.sql.types.StructType;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -177,7 +175,7 @@ public String toString() {
         }
     }
 
-    public int load(Iterator<InternalRow> rows, StructType schema, Deserializer<Row> deserializer)
+    public int load(Iterator<InternalRow> rows, StructType schema)
             throws StreamLoadException, JsonProcessingException {
 
         String label = generateLoadLabel();
@@ -191,7 +189,7 @@ public int load(Iterator<InternalRow> rows, StructType schema, Deserializer<Row> deserializer)
                     .format(fileType)
                     .sep(FIELD_DELIMITER)
                     .delim(LINE_DELIMITER)
-                    .schema(schema).build(), deserializer, streamingPassthrough);
+                    .schema(schema).build(), streamingPassthrough);
             httpPut.setEntity(new InputStreamEntity(recodeBatchInputStream));
             HttpResponse httpResponse = httpClient.execute(httpPut);
             loadResponse = new LoadResponse(httpResponse);
@@ -218,12 +216,12 @@ public int load(Iterator<InternalRow> rows, StructType schema, Deserializer<Row> deserializer)
 
     }
 
-    public Integer loadStream(Iterator<InternalRow> rows, StructType schema, Deserializer<Row> deserializer)
+    public Integer loadStream(Iterator<InternalRow> rows, StructType schema)
             throws StreamLoadException, JsonProcessingException {
         if (this.streamingPassthrough) {
             handleStreamPassThrough();
         }
-        return load(rows, schema, deserializer);
+        return load(rows, schema);
     }
 
     public void commit(int txnId) throws StreamLoadException {
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
index 830f5d91..6d7686fd 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
@@ -6,9 +6,7 @@ import org.apache.doris.spark.util.DataUtil;
 
 import com.fasterxml.jackson.core.JsonProcessingException;
-import org.apache.spark.sql.Row;
 import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 
@@ -52,14 +50,8 @@ public class RecordBatchInputStream extends InputStream {
      */
    private final boolean passThrough;
 
-    /**
-     * deserializer for converting InternalRow to Row
-     */
-    private final ExpressionEncoder.Deserializer<Row> deserializer;
-
-    public RecordBatchInputStream(RecordBatch recordBatch, ExpressionEncoder.Deserializer<Row> deserializer, boolean passThrough) {
+    public RecordBatchInputStream(RecordBatch recordBatch, boolean passThrough) {
         this.recordBatch = recordBatch;
-        this.deserializer = deserializer;
         this.passThrough = passThrough;
     }
 
@@ -176,16 +168,14 @@ private int calculateNewCapacity(int capacity, int minCapacity) {
     /**
      * Convert Spark row data to byte array
      *
-     * @param internalRow row data
+     * @param row row data
      * @return byte array
      * @throws DorisException
      */
-    private byte[] rowToByte(InternalRow internalRow) throws DorisException {
+    private byte[] rowToByte(InternalRow row) throws DorisException {
         byte[] bytes;
 
-        Row row = deserializer.apply(internalRow.copy());
-
         if (passThrough) {
             bytes = row.getString(0).getBytes(StandardCharsets.UTF_8);
             return bytes;
         }
@@ -193,11 +183,11 @@ private byte[] rowToByte(InternalRow internalRow) throws DorisException {
 
         switch (recordBatch.getFormat().toLowerCase()) {
             case "csv":
-                bytes = DataUtil.rowToCsvBytes(row, recordBatch.getSep());
+                bytes = DataUtil.rowToCsvBytes(row, recordBatch.getSchema(), recordBatch.getSep());
                 break;
             case "json":
                 try {
-                    bytes = DataUtil.rowToJsonBytes(row, recordBatch.getSchema().fieldNames());
+                    bytes = DataUtil.rowToJsonBytes(row, recordBatch.getSchema());
                 } catch (JsonProcessingException e) {
                     throw new DorisException("parse row to json bytes failed", e);
                 }
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java
index 270266bd..aea6ddee 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java
@@ -17,14 +17,15 @@
 
 package org.apache.doris.spark.util;
 
+import org.apache.doris.spark.sql.SchemaUtils;
+
 import com.fasterxml.jackson.core.JsonProcessingException;
 import com.fasterxml.jackson.databind.ObjectMapper;
-import org.apache.spark.sql.Row;
-import scala.collection.mutable.WrappedArray;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
 
 import java.nio.charset.StandardCharsets;
-import java.sql.Date;
-import java.sql.Timestamp;
 import java.util.HashMap;
 import java.util.Map;
 
@@ -34,44 +35,28 @@ public class DataUtil {
 
     public static final String NULL_VALUE = "\\N";
 
-    public static Object handleColumnValue(Object value) {
-
-        if (value == null) {
-            return NULL_VALUE;
-        }
-
-        if (value instanceof Date || value instanceof Timestamp) {
-            return value.toString();
-        }
-
-        if (value instanceof WrappedArray) {
-            return String.format("[%s]", ((WrappedArray) value).mkString(","));
-        }
-
-        return value;
-
-    }
-
-    public static byte[] rowToCsvBytes(Row row, String sep) {
+    public static byte[] rowToCsvBytes(InternalRow row, StructType schema, String sep) {
         StringBuilder builder = new StringBuilder();
-        int n = row.size();
+        StructField[] fields = schema.fields();
+        int n = row.numFields();
         if (n > 0) {
-            builder.append(handleColumnValue(row.get(0)));
+            builder.append(SchemaUtils.rowColumnValue(row, 0, fields[0].dataType()));
             int i = 1;
             while (i < n) {
                 builder.append(sep);
-                builder.append(handleColumnValue(row.get(i)));
+                builder.append(SchemaUtils.rowColumnValue(row, i, fields[i].dataType()));
                 i++;
             }
         }
         return builder.toString().getBytes(StandardCharsets.UTF_8);
     }
 
-    public static byte[] rowToJsonBytes(Row row, String[] columns)
+    public static byte[] rowToJsonBytes(InternalRow row, StructType schema)
             throws JsonProcessingException {
-        Map<String, Object> rowMap = new HashMap<>(row.size());
-        for (int i = 0; i < columns.length; i++) {
-            rowMap.put(columns[i], handleColumnValue(row.get(i)));
+        StructField[] fields = schema.fields();
+        Map<String, Object> rowMap = new HashMap<>(row.numFields());
+        for (int i = 0; i < fields.length; i++) {
+            rowMap.put(fields[i].name(), SchemaUtils.rowColumnValue(row, i, fields[i].dataType()));
         }
         return MAPPER.writeValueAsBytes(rowMap);
     }
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
index c8aa0349..f5a6a159 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
@@ -18,16 +18,23 @@
 package org.apache.doris.spark.sql
 
 import org.apache.doris.sdk.thrift.TScanColumnDesc
-
-import scala.collection.JavaConversions._
+import org.apache.doris.spark.cfg.ConfigurationOptions.{DORIS_IGNORE_TYPE, DORIS_READ_FIELD}
 import org.apache.doris.spark.cfg.Settings
 import org.apache.doris.spark.exception.DorisException
 import org.apache.doris.spark.rest.RestService
 import org.apache.doris.spark.rest.models.{Field, Schema}
-import org.apache.doris.spark.cfg.ConfigurationOptions.{DORIS_IGNORE_TYPE, DORIS_READ_FIELD}
+import org.apache.doris.spark.util.DataUtil
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 import org.slf4j.LoggerFactory
 
+import java.sql.Timestamp
+import java.time.{LocalDateTime, ZoneOffset}
+import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
 private[spark] object SchemaUtils {
 
   private val logger = LoggerFactory.getLogger(SchemaUtils.getClass.getSimpleName.stripSuffix("$"))
@@ -137,4 +144,49 @@ private[spark] object SchemaUtils {
     tscanColumnDescs.foreach(desc => schema.put(new Field(desc.getName, desc.getType.name, "", 0, 0, "")))
     schema
   }
+
+  def rowColumnValue(row: SpecializedGetters, ordinal: Int, dataType: DataType): Any = {
+
+    dataType match {
+      case NullType => DataUtil.NULL_VALUE
+      case BooleanType => row.getBoolean(ordinal)
+      case ByteType => row.getByte(ordinal)
+      case ShortType => row.getShort(ordinal)
+      case IntegerType => row.getInt(ordinal)
+      case LongType => row.getLong(ordinal)
+      case FloatType => row.getFloat(ordinal)
+      case DoubleType => row.getDouble(ordinal)
+      case StringType => row.getUTF8String(ordinal).toString
+      case TimestampType =>
+        LocalDateTime.ofEpochSecond(row.getLong(ordinal) / 100000, (row.getLong(ordinal) % 1000).toInt, ZoneOffset.UTC)
+        new Timestamp(row.getLong(ordinal) / 1000).toString
+      case DateType => DateTimeUtils.toJavaDate(row.getInt(ordinal)).toString
+      case BinaryType => row.getBinary(ordinal)
+      case dt: DecimalType => row.getDecimal(ordinal, dt.precision, dt.scale)
+      case at: ArrayType =>
+        val arrayData = row.getArray(ordinal)
+        var i = 0
+        val buffer = mutable.Buffer[Any]()
+        while (i < arrayData.numElements()) {
+          if (arrayData.isNullAt(i)) buffer += null else buffer += rowColumnValue(arrayData, i, at.elementType)
+          i += 1
+        }
+        s"[${buffer.mkString(",")}]"
+      case mt: MapType =>
+        val mapData = row.getMap(ordinal)
+        val keys = mapData.keyArray()
+        val values = mapData.valueArray()
+        var i = 0
+        val map = mutable.Map[Any, Any]()
+        while (i < keys.numElements()) {
+          map += rowColumnValue(keys, i, mt.keyType) -> rowColumnValue(values, i, mt.valueType)
+          i += 1
+        }
+        map.toMap.asJava
+      case st: StructType => row.getStruct(ordinal, st.length)
+      case _ => throw new DorisException(s"Unsupported spark type: ${dataType.typeName}")
+    }
+
+  }
+
 }
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
index 9f9f99b3..b278a385 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
@@ -21,10 +21,8 @@ import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
 import org.apache.doris.spark.listener.DorisTransactionListener
 import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad}
 import org.apache.doris.spark.sql.Utils
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.CollectionAccumulator
 import org.slf4j.{Logger, LoggerFactory}
@@ -62,7 +60,9 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
     doWrite(dataFrame, dorisStreamLoader.loadStream)
   }
 
-  private def doWrite(dataFrame: DataFrame, loadFunc: (util.Iterator[InternalRow], StructType, Deserializer[Row]) => Int): Unit = {
+  private def doWrite(dataFrame: DataFrame, loadFunc: (util.Iterator[InternalRow], StructType) => Int): Unit = {
+
+
     val sc = dataFrame.sqlContext.sparkContext
     val preCommittedTxnAcc = sc.collectionAccumulator[Int]("preCommittedTxnAcc")
 
@@ -72,7 +72,6 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
 
     var resultRdd = dataFrame.queryExecution.toRdd
     val schema = dataFrame.schema
-    val deserializer = RowEncoder(schema).resolveAndBind().createDeserializer()
     if (Objects.nonNull(sinkTaskPartitionSize)) {
       resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize)
     }
@@ -80,7 +79,7 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
       while (iterator.hasNext) {
         // do load batch with retries
         Utils.retry[Int, Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) {
-          loadFunc(iterator.asJava, schema, deserializer)
+          loadFunc(iterator.asJava, schema)
         } match {
           case Success(txnId) => if (enable2PC) handleLoadSuccess(txnId, preCommittedTxnAcc)
           case Failure(e) =>
diff --git a/spark-doris-connector/src/test/java/org/apache/doris/spark/util/DataUtilTest.java b/spark-doris-connector/src/test/java/org/apache/doris/spark/util/DataUtilTest.java
deleted file mode 100644
index 0f6fb36b..00000000
--- a/spark-doris-connector/src/test/java/org/apache/doris/spark/util/DataUtilTest.java
+++ /dev/null
@@ -1,32 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.spark.util;
-
-import junit.framework.TestCase;
-import org.junit.Assert;
-import scala.collection.mutable.WrappedArray;
-
-import java.sql.Timestamp;
-
-public class DataUtilTest extends TestCase {
-
-    public void testHandleColumnValue() {
-        Assert.assertEquals("2023-08-14 18:00:00.0", DataUtil.handleColumnValue(Timestamp.valueOf("2023-08-14 18:00:00")));
-        Assert.assertEquals("[1,2,3]", DataUtil.handleColumnValue(WrappedArray.make(new Integer[]{1,2,3})));
-    }
-}
\ No newline at end of file
diff --git a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/SchemaUtilsTest.scala b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/SchemaUtilsTest.scala
new file mode 100644
index 00000000..fb729cef
--- /dev/null
+++ b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/SchemaUtilsTest.scala
@@ -0,0 +1,37 @@
+package org.apache.doris.spark.sql
+
+import org.apache.spark.sql.SparkSession
+import org.junit.{Assert, Ignore, Test}
+
+import java.sql.{Date, Timestamp}
+import scala.collection.JavaConverters._
+
+@Ignore
+class SchemaUtilsTest {
+
+  @Test
+  def rowColumnValueTest(): Unit = {
+
+    val spark = SparkSession.builder().master("local").getOrCreate()
+
+    val df = spark.createDataFrame(Seq(
+      (1, Date.valueOf("2023-09-08"), Timestamp.valueOf("2023-09-08 17:00:00"), Array(1, 2, 3), Map[String, String]("a" -> "1"))
+    )).toDF("c1", "c2", "c3", "c4", "c5")
+
+    val schema = df.schema
+
+    df.queryExecution.toRdd.foreach(row => {
+
+      val fields = schema.fields
+      Assert.assertEquals(1, SchemaUtils.rowColumnValue(row, 0, fields(0).dataType))
+      Assert.assertEquals("2023-09-08", SchemaUtils.rowColumnValue(row, 1, fields(1).dataType))
+      Assert.assertEquals("2023-09-08 17:00:00.0", SchemaUtils.rowColumnValue(row, 2, fields(2).dataType))
+      Assert.assertEquals("[1,2,3]", SchemaUtils.rowColumnValue(row, 3, fields(3).dataType))
+      println(SchemaUtils.rowColumnValue(row, 4, fields(4).dataType))
+      Assert.assertEquals(Map("a" -> "1").asJava, SchemaUtils.rowColumnValue(row, 4, fields(4).dataType))
+
+    })
+
+  }
+
+}