convert internal row value manually
gnehil committed Sep 8, 2023
1 parent c53bd9a commit a4e99e5
Showing 7 changed files with 121 additions and 92 deletions.
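
At a high level, the change drops the ExpressionEncoder-based InternalRow-to-Row deserialization from the stream-load write path and instead reads typed values straight off each InternalRow using the DataFrame schema (the new SchemaUtils.rowColumnValue below). A minimal Scala sketch of the two approaches; the class and method names are taken from this diff, but the snippet itself is illustrative only (oldPath/newPath are made-up names) and assumes a Spark 3.x runtime where RowEncoder(schema) is available:

import org.apache.doris.spark.sql.SchemaUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.StructType

// Before: each InternalRow was decoded into an external Row via the encoder, then serialized.
def oldPath(row: InternalRow, schema: StructType): Any = {
  val deserializer = RowEncoder(schema).resolveAndBind().createDeserializer()
  deserializer(row.copy()).get(0)
}

// After: values are read directly off the InternalRow with the matching DataType,
// which is what the new SchemaUtils.rowColumnValue does for every supported type.
def newPath(row: InternalRow, schema: StructType): Any =
  SchemaUtils.rowColumnValue(row, 0, schema.fields(0).dataType)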
@@ -43,9 +43,7 @@
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.util.EntityUtils;
-import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -177,7 +175,7 @@ public String toString() {
}
}

-public int load(Iterator<InternalRow> rows, StructType schema, Deserializer<Row> deserializer)
+public int load(Iterator<InternalRow> rows, StructType schema)
throws StreamLoadException, JsonProcessingException {

String label = generateLoadLabel();
@@ -191,7 +189,7 @@ public int load(Iterator<InternalRow> rows, StructType schema, Deserializer<Row>
.format(fileType)
.sep(FIELD_DELIMITER)
.delim(LINE_DELIMITER)
-.schema(schema).build(), deserializer, streamingPassthrough);
+.schema(schema).build(), streamingPassthrough);
httpPut.setEntity(new InputStreamEntity(recodeBatchInputStream));
HttpResponse httpResponse = httpClient.execute(httpPut);
loadResponse = new LoadResponse(httpResponse);
@@ -218,12 +216,12 @@ public int load(Iterator<InternalRow> rows, StructType schema, Deserializer<Row>

}

-public Integer loadStream(Iterator<InternalRow> rows, StructType schema, Deserializer<Row> deserializer)
+public Integer loadStream(Iterator<InternalRow> rows, StructType schema)
throws StreamLoadException, JsonProcessingException {
if (this.streamingPassthrough) {
handleStreamPassThrough();
}
-return load(rows, schema, deserializer);
+return load(rows, schema);
}

public void commit(int txnId) throws StreamLoadException {
@@ -6,9 +6,7 @@
import org.apache.doris.spark.util.DataUtil;

import com.fasterxml.jackson.core.JsonProcessingException;
-import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -52,14 +50,8 @@ public class RecordBatchInputStream extends InputStream {
*/
private final boolean passThrough;

-/**
-* deserializer for converting InternalRow to Row
-*/
-private final ExpressionEncoder.Deserializer<Row> deserializer;

-public RecordBatchInputStream(RecordBatch recordBatch, ExpressionEncoder.Deserializer<Row> deserializer, boolean passThrough) {
+public RecordBatchInputStream(RecordBatch recordBatch, boolean passThrough) {
this.recordBatch = recordBatch;
-this.deserializer = deserializer;
this.passThrough = passThrough;
}

@@ -176,28 +168,26 @@ private int calculateNewCapacity(int capacity, int minCapacity) {
/**
* Convert Spark row data to byte array
*
-* @param internalRow row data
+* @param row row data
* @return byte array
* @throws DorisException
*/
-private byte[] rowToByte(InternalRow internalRow) throws DorisException {
+private byte[] rowToByte(InternalRow row) throws DorisException {

byte[] bytes;

-Row row = deserializer.apply(internalRow.copy());

if (passThrough) {
bytes = row.getString(0).getBytes(StandardCharsets.UTF_8);
return bytes;
}

switch (recordBatch.getFormat().toLowerCase()) {
case "csv":
-bytes = DataUtil.rowToCsvBytes(row, recordBatch.getSep());
+bytes = DataUtil.rowToCsvBytes(row, recordBatch.getSchema(), recordBatch.getSep());
break;
case "json":
try {
-bytes = DataUtil.rowToJsonBytes(row, recordBatch.getSchema().fieldNames());
+bytes = DataUtil.rowToJsonBytes(row, recordBatch.getSchema());
} catch (JsonProcessingException e) {
throw new DorisException("parse row to json bytes failed", e);
}
@@ -17,14 +17,15 @@

package org.apache.doris.spark.util;

+import org.apache.doris.spark.sql.SchemaUtils;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
-import org.apache.spark.sql.Row;
-import scala.collection.mutable.WrappedArray;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;

import java.nio.charset.StandardCharsets;
-import java.sql.Date;
-import java.sql.Timestamp;
import java.util.HashMap;
import java.util.Map;

@@ -34,44 +35,28 @@ public class DataUtil {

public static final String NULL_VALUE = "\\N";

-public static Object handleColumnValue(Object value) {
-
-if (value == null) {
-return NULL_VALUE;
-}
-
-if (value instanceof Date || value instanceof Timestamp) {
-return value.toString();
-}
-
-if (value instanceof WrappedArray) {
-return String.format("[%s]", ((WrappedArray<?>) value).mkString(","));
-}
-
-return value;
-
-}

-public static byte[] rowToCsvBytes(Row row, String sep) {
+public static byte[] rowToCsvBytes(InternalRow row, StructType schema, String sep) {
StringBuilder builder = new StringBuilder();
-int n = row.size();
+StructField[] fields = schema.fields();
+int n = row.numFields();
if (n > 0) {
-builder.append(handleColumnValue(row.get(0)));
+builder.append(SchemaUtils.rowColumnValue(row, 0, fields[0].dataType()));
int i = 1;
while (i < n) {
builder.append(sep);
-builder.append(handleColumnValue(row.get(i)));
+builder.append(SchemaUtils.rowColumnValue(row, i, fields[i].dataType()));
i++;
}
}
return builder.toString().getBytes(StandardCharsets.UTF_8);
}

-public static byte[] rowToJsonBytes(Row row, String[] columns)
+public static byte[] rowToJsonBytes(InternalRow row, StructType schema)
throws JsonProcessingException {
-Map<String, Object> rowMap = new HashMap<>(row.size());
-for (int i = 0; i < columns.length; i++) {
-rowMap.put(columns[i], handleColumnValue(row.get(i)));
+StructField[] fields = schema.fields();
+Map<String, Object> rowMap = new HashMap<>(row.numFields());
+for (int i = 0; i < fields.length; i++) {
+rowMap.put(fields[i].name(), SchemaUtils.rowColumnValue(row, i, fields[i].dataType()));
}
return MAPPER.writeValueAsBytes(rowMap);
}
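
A hedged usage sketch of the reworked DataUtil helpers above, serializing a single InternalRow against its schema; the two-column row and the tab separator are made up for illustration, and the JSON key order depends on the underlying HashMap:

import java.nio.charset.StandardCharsets

import org.apache.doris.spark.util.DataUtil
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

val schema = StructType(Seq(StructField("id", IntegerType), StructField("name", StringType)))
val row = InternalRow(1, UTF8String.fromString("doris"))

// CSV: columns joined by the caller-supplied separator.
new String(DataUtil.rowToCsvBytes(row, schema, "\t"), StandardCharsets.UTF_8)   // "1\tdoris"
// JSON: one object per row, keyed by the schema field names.
new String(DataUtil.rowToJsonBytes(row, schema), StandardCharsets.UTF_8)        // e.g. {"id":1,"name":"doris"}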
@@ -18,16 +18,23 @@
package org.apache.doris.spark.sql

import org.apache.doris.sdk.thrift.TScanColumnDesc

-import scala.collection.JavaConversions._
+import org.apache.doris.spark.cfg.ConfigurationOptions.{DORIS_IGNORE_TYPE, DORIS_READ_FIELD}
import org.apache.doris.spark.cfg.Settings
import org.apache.doris.spark.exception.DorisException
import org.apache.doris.spark.rest.RestService
import org.apache.doris.spark.rest.models.{Field, Schema}
-import org.apache.doris.spark.cfg.ConfigurationOptions.{DORIS_IGNORE_TYPE, DORIS_READ_FIELD}
+import org.apache.doris.spark.util.DataUtil
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.slf4j.LoggerFactory

+import java.sql.Timestamp
+import java.time.{LocalDateTime, ZoneOffset}
+import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+import scala.collection.mutable

private[spark] object SchemaUtils {
private val logger = LoggerFactory.getLogger(SchemaUtils.getClass.getSimpleName.stripSuffix("$"))

@@ -137,4 +144,49 @@ private[spark] object SchemaUtils {
tscanColumnDescs.foreach(desc => schema.put(new Field(desc.getName, desc.getType.name, "", 0, 0, "")))
schema
}

def rowColumnValue(row: SpecializedGetters, ordinal: Int, dataType: DataType): Any = {

dataType match {
case NullType => DataUtil.NULL_VALUE
case BooleanType => row.getBoolean(ordinal)
case ByteType => row.getByte(ordinal)
case ShortType => row.getShort(ordinal)
case IntegerType => row.getInt(ordinal)
case LongType => row.getLong(ordinal)
case FloatType => row.getFloat(ordinal)
case DoubleType => row.getDouble(ordinal)
case StringType => row.getUTF8String(ordinal).toString
case TimestampType =>
new Timestamp(row.getLong(ordinal) / 1000).toString
case DateType => DateTimeUtils.toJavaDate(row.getInt(ordinal)).toString
case BinaryType => row.getBinary(ordinal)
case dt: DecimalType => row.getDecimal(ordinal, dt.precision, dt.scale)
case at: ArrayType =>
val arrayData = row.getArray(ordinal)
var i = 0
val buffer = mutable.Buffer[Any]()
while (i < arrayData.numElements()) {
if (arrayData.isNullAt(i)) buffer += null else buffer += rowColumnValue(arrayData, i, at.elementType)
i += 1
}
s"[${buffer.mkString(",")}]"
case mt: MapType =>
val mapData = row.getMap(ordinal)
val keys = mapData.keyArray()
val values = mapData.valueArray()
var i = 0
val map = mutable.Map[Any, Any]()
while (i < keys.numElements()) {
map += rowColumnValue(keys, i, mt.keyType) -> rowColumnValue(values, i, mt.valueType)
i += 1
}
map.toMap.asJava
case st: StructType => row.getStruct(ordinal, st.length)
case _ => throw new DorisException(s"Unsupported spark type: ${dataType.typeName}")
}

}

}
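
For context on the TimestampType and DateType branches of rowColumnValue above: Spark stores TimestampType internally as microseconds since the epoch and DateType as days since 1970-01-01, so converting back to the JDBC types is a unit change plus Spark's own date helper. A small illustrative sketch (the literal values are made up for the example):

import java.sql.Timestamp

import org.apache.spark.sql.catalyst.util.DateTimeUtils

// TimestampType: microseconds since the epoch; java.sql.Timestamp expects milliseconds.
val micros = 1694163600000000L          // an instant on 2023-09-08 (UTC)
val ts = new Timestamp(micros / 1000)   // ts.toString renders in the JVM's default time zone

// DateType: days since 1970-01-01, converted with Spark's helper.
val days = 19608                        // 2023-09-08
val d = DateTimeUtils.toJavaDate(days)  // d.toString == "2023-09-08"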
@@ -21,10 +21,8 @@ import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
import org.apache.doris.spark.listener.DorisTransactionListener
import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad}
import org.apache.doris.spark.sql.Utils
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.CollectionAccumulator
import org.slf4j.{Logger, LoggerFactory}
@@ -62,7 +60,9 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
doWrite(dataFrame, dorisStreamLoader.loadStream)
}

-private def doWrite(dataFrame: DataFrame, loadFunc: (util.Iterator[InternalRow], StructType, Deserializer[Row]) => Int): Unit = {
+private def doWrite(dataFrame: DataFrame, loadFunc: (util.Iterator[InternalRow], StructType) => Int): Unit = {



val sc = dataFrame.sqlContext.sparkContext
val preCommittedTxnAcc = sc.collectionAccumulator[Int]("preCommittedTxnAcc")
@@ -72,15 +72,14 @@ class DorisWriter(settings: SparkSettings) extends Serializable {

var resultRdd = dataFrame.queryExecution.toRdd
val schema = dataFrame.schema
-val deserializer = RowEncoder(schema).resolveAndBind().createDeserializer()
if (Objects.nonNull(sinkTaskPartitionSize)) {
resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize)
}
resultRdd.foreachPartition(iterator => {
while (iterator.hasNext) {
// do load batch with retries
Utils.retry[Int, Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) {
-loadFunc(iterator.asJava, schema, deserializer)
+loadFunc(iterator.asJava, schema)
} match {
case Success(txnId) => if (enable2PC) handleLoadSuccess(txnId, preCommittedTxnAcc)
case Failure(e) =>

This file was deleted.

@@ -0,0 +1,37 @@
package org.apache.doris.spark.sql

import org.apache.spark.sql.SparkSession
import org.junit.{Assert, Ignore, Test}

import java.sql.{Date, Timestamp}
import scala.collection.JavaConverters._

@Ignore
class SchemaUtilsTest {

@Test
def rowColumnValueTest(): Unit = {

val spark = SparkSession.builder().master("local").getOrCreate()

val df = spark.createDataFrame(Seq(
(1, Date.valueOf("2023-09-08"), Timestamp.valueOf("2023-09-08 17:00:00"), Array(1, 2, 3), Map[String, String]("a" -> "1"))
)).toDF("c1", "c2", "c3", "c4", "c5")

val schema = df.schema

df.queryExecution.toRdd.foreach(row => {

val fields = schema.fields
Assert.assertEquals(1, SchemaUtils.rowColumnValue(row, 0, fields(0).dataType))
Assert.assertEquals("2023-09-08", SchemaUtils.rowColumnValue(row, 1, fields(1).dataType))
Assert.assertEquals("2023-09-08 17:00:00.0", SchemaUtils.rowColumnValue(row, 2, fields(2).dataType))
Assert.assertEquals("[1,2,3]", SchemaUtils.rowColumnValue(row, 3, fields(3).dataType))
println(SchemaUtils.rowColumnValue(row, 4, fields(4).dataType))
Assert.assertEquals(Map("a" -> "1").asJava, SchemaUtils.rowColumnValue(row, 4, fields(4).dataType))

})

}

}
