[Fix](reader) fix arrow timezone bug #232

Merged: 5 commits, Sep 25, 2024
Changes from 2 commits
@@ -17,11 +17,6 @@

package org.apache.doris.spark.serialization;

import org.apache.doris.sdk.thrift.TScanBatchResult;
import org.apache.doris.spark.exception.DorisException;
import org.apache.doris.spark.rest.models.Schema;
import org.apache.doris.spark.util.IPUtils;

import com.google.common.base.Preconditions;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BaseIntVector;
@@ -48,8 +43,11 @@
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.Types.MinorType;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.doris.sdk.thrift.TScanBatchResult;
import org.apache.doris.spark.exception.DorisException;
import org.apache.doris.spark.rest.models.Schema;
import org.apache.doris.spark.util.IPUtils;
import org.apache.spark.sql.types.Decimal;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -385,13 +383,6 @@ public void convertArrowToRowBatch() throws DorisException {
break;
case "DATETIME":
case "DATETIMEV2":

Preconditions.checkArgument(
mt.equals(Types.MinorType.TIMESTAMPMICRO) || mt.equals(MinorType.VARCHAR) ||
mt.equals(MinorType.TIMESTAMPMILLI) || mt.equals(MinorType.TIMESTAMPSEC),
typeMismatchMessage(currentType, mt));
typeMismatchMessage(currentType, mt);

if (mt.equals(Types.MinorType.VARCHAR)) {
VarCharVector varCharVector = (VarCharVector) curFieldVector;
for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
@@ -404,18 +395,18 @@
}
} else if (curFieldVector instanceof TimeStampVector) {
TimeStampVector timeStampVector = (TimeStampVector) curFieldVector;

for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
if (timeStampVector.isNull(rowIndex)) {

addValueToRow(rowIndex, null);
continue;
}
LocalDateTime dateTime = getDateTime(rowIndex, timeStampVector);
String formatted = DATE_TIME_FORMATTER.format(dateTime);
addValueToRow(rowIndex, formatted);
}

} else {
String errMsg = String.format("Unsupported type for DATETIMEV2, minorType %s, class is %s",
        mt.name(), curFieldVector.getClass());
throw new IllegalArgumentException(errMsg);
}
break;
case "CHAR":
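For context on the hunk above: the DATETIME/DATETIMEV2 branch no longer hard-checks the Arrow minor type up front; any TimeStampVector subtype is converted through the getDateTime helper, and anything else falls through to the explicit IllegalArgumentException. The helper's body is not part of this diff, so the following is only a minimal sketch of what such a conversion involves, assuming the unit and timezone are read from the field's ArrowType.Timestamp and that timezone-less values fall back to the JVM default zone. Names and structure here are illustrative, not the connector's actual code.

import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneId;

import org.apache.arrow.vector.TimeStampVector;
import org.apache.arrow.vector.types.pojo.ArrowType;

final class TimestampConversionSketch {

    // Sketch of a getDateTime-style helper: every TimeStamp*Vector and
    // TimeStamp*TZVector extends TimeStampVector and stores a long epoch
    // value; the unit and timezone live on the field's ArrowType.Timestamp.
    static LocalDateTime getDateTime(int rowIndex, TimeStampVector vector) {
        ArrowType.Timestamp type = (ArrowType.Timestamp) vector.getField().getType();
        long raw = vector.get(rowIndex);

        Instant instant;
        switch (type.getUnit()) {
            case SECOND:
                instant = Instant.ofEpochSecond(raw);
                break;
            case MILLISECOND:
                instant = Instant.ofEpochMilli(raw);
                break;
            case MICROSECOND:
                instant = Instant.ofEpochSecond(raw / 1_000_000L, (raw % 1_000_000L) * 1_000L);
                break;
            default: // NANOSECOND
                instant = Instant.ofEpochSecond(raw / 1_000_000_000L, raw % 1_000_000_000L);
                break;
        }

        // A timezone-less Arrow timestamp is "naive"; rendering it in the
        // JVM's default zone is one plausible policy, not necessarily the PR's.
        String tz = type.getTimezone();
        ZoneId zone = (tz == null) ? ZoneId.systemDefault() : ZoneId.of(tz);
        return LocalDateTime.ofInstant(instant, zone);
    }
}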
@@ -17,32 +17,28 @@

package org.apache.doris.spark.serialization;

import org.apache.arrow.vector.DateDayVector;
import org.apache.arrow.vector.TimeStampMicroVector;
import org.apache.arrow.vector.UInt4Vector;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.doris.sdk.thrift.TScanBatchResult;
import org.apache.doris.sdk.thrift.TStatus;
import org.apache.doris.sdk.thrift.TStatusCode;
import org.apache.doris.spark.exception.DorisException;
import org.apache.doris.spark.rest.RestService;
import org.apache.doris.spark.rest.models.Schema;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.DateDayVector;
import org.apache.arrow.vector.DecimalVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.FixedSizeBinaryVector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.SmallIntVector;
import org.apache.arrow.vector.TimeStampMicroTZVector;
import org.apache.arrow.vector.TimeStampMicroVector;
import org.apache.arrow.vector.TimeStampMilliTZVector;
import org.apache.arrow.vector.TimeStampMilliVector;
import org.apache.arrow.vector.TimeStampSecTZVector;
import org.apache.arrow.vector.TimeStampSecVector;
import org.apache.arrow.vector.TinyIntVector;
import org.apache.arrow.vector.UInt4Vector;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
@@ -52,11 +48,19 @@
import org.apache.arrow.vector.complex.impl.UnionMapWriter;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.doris.sdk.thrift.TScanBatchResult;
import org.apache.doris.sdk.thrift.TStatus;
import org.apache.doris.sdk.thrift.TStatusCode;
import org.apache.doris.spark.exception.DorisException;
import org.apache.doris.spark.rest.RestService;
import org.apache.doris.spark.rest.models.Schema;
import org.apache.spark.sql.types.Decimal;
import org.junit.Assert;
import org.junit.Rule;
@@ -74,6 +78,7 @@
import java.sql.Date;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.NoSuchElementException;
@@ -1163,4 +1168,185 @@ public void testIPv6() throws DorisException, IOException {
thrown.expectMessage(startsWith("Get row offset:"));
rowBatch.next();
}

@Test
public void timestampVector() throws IOException, DorisException {
List<Field> childrenBuilder = new ArrayList<>();
childrenBuilder.add(
new Field(
"k0",
FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)),
null));
childrenBuilder.add(
new Field(
"k1",
FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, null)),
null));
childrenBuilder.add(
new Field(
"k2",
FieldType.nullable(new ArrowType.Timestamp(TimeUnit.SECOND, null)),
null));
childrenBuilder.add(
new Field(
"k3",
FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC+8")),
null));
childrenBuilder.add(
new Field(
"k4",
FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, "UTC+8")),
null));
childrenBuilder.add(
new Field(
"k5",
FieldType.nullable(new ArrowType.Timestamp(TimeUnit.SECOND, "UTC+8")),
null));

VectorSchemaRoot root =
VectorSchemaRoot.create(
new org.apache.arrow.vector.types.pojo.Schema(childrenBuilder, null),
new RootAllocator(Integer.MAX_VALUE));
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
ArrowStreamWriter arrowStreamWriter =
new ArrowStreamWriter(
root, new DictionaryProvider.MapDictionaryProvider(), outputStream);

arrowStreamWriter.start();
root.setRowCount(1);

FieldVector vector = root.getVector("k0");
TimeStampMicroVector microVec = (TimeStampMicroVector) vector;
microVec.allocateNew(1);
microVec.setIndexDefined(0);
microVec.setSafe(0, 1721892143586123L);
vector.setValueCount(1);

vector = root.getVector("k1");
TimeStampMilliVector milliVector = (TimeStampMilliVector) vector;
milliVector.allocateNew(1);
milliVector.setIndexDefined(0);
milliVector.setSafe(0, 1721892143586L);
vector.setValueCount(1);

vector = root.getVector("k2");
TimeStampSecVector secVector = (TimeStampSecVector) vector;
secVector.allocateNew(1);
secVector.setIndexDefined(0);
secVector.setSafe(0, 1721892143L);
vector.setValueCount(1);

vector = root.getVector("k3");
TimeStampMicroTZVector microTZVec = (TimeStampMicroTZVector) vector;
microTZVec.allocateNew(1);
microTZVec.setIndexDefined(0);
microTZVec.setSafe(0, 1721892143586123L);
vector.setValueCount(1);

vector = root.getVector("k4");
TimeStampMilliTZVector milliTZVector = (TimeStampMilliTZVector) vector;
milliTZVector.allocateNew(1);
milliTZVector.setIndexDefined(0);
milliTZVector.setSafe(0, 1721892143586L);
vector.setValueCount(1);

vector = root.getVector("k5");
TimeStampSecTZVector secTZVector = (TimeStampSecTZVector) vector;
secTZVector.allocateNew(1);
secTZVector.setIndexDefined(0);
secTZVector.setSafe(0, 1721892143L);
vector.setValueCount(1);

arrowStreamWriter.writeBatch();

arrowStreamWriter.end();
arrowStreamWriter.close();

TStatus status = new TStatus();
status.setStatusCode(TStatusCode.OK);
TScanBatchResult scanBatchResult = new TScanBatchResult();
scanBatchResult.setStatus(status);
scanBatchResult.setEos(false);
scanBatchResult.setRows(outputStream.toByteArray());

String schemaStr =
"{\"properties\":[{\"type\":\"DATETIME\",\"name\":\"k0\",\"comment\":\"\"}, {\"type\":\"DATETIME\",\"name\":\"k1\",\"comment\":\"\"}, {\"type\":\"DATETIME\",\"name\":\"k2\",\"comment\":\"\"}, {\"type\":\"DATETIME\",\"name\":\"k3\",\"comment\":\"\"}, {\"type\":\"DATETIME\",\"name\":\"k4\",\"comment\":\"\"}, {\"type\":\"DATETIME\",\"name\":\"k5\",\"comment\":\"\"}],"
+ "\"status\":200}";

Schema schema = RestService.parseSchema(schemaStr, logger);
RowBatch rowBatch = new RowBatch(scanBatchResult, schema);
List<Object> next = rowBatch.next();
Assert.assertEquals(6, next.size());
Assert.assertEquals("2024-07-25 15:22:23.586123", next.get(0));
Assert.assertEquals("2024-07-25 15:22:23.586", next.get(1));
Assert.assertEquals("2024-07-25 15:22:23", next.get(2));
Assert.assertEquals("2024-07-25 15:22:23.586123", next.get(3));
Assert.assertEquals("2024-07-25 15:22:23.586", next.get(4));
Assert.assertEquals("2024-07-25 15:22:23", next.get(5));
}

@Test
public void timestampTypeNotMatch() throws IOException, DorisException {
List<Field> childrenBuilder = new ArrayList<>();
childrenBuilder.add(
new Field(
"k0",
FieldType.nullable(new ArrowType.Int(32, false)),
null));

VectorSchemaRoot root =
VectorSchemaRoot.create(
new org.apache.arrow.vector.types.pojo.Schema(childrenBuilder, null),
new RootAllocator(Integer.MAX_VALUE));
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
ArrowStreamWriter arrowStreamWriter =
new ArrowStreamWriter(
root, new DictionaryProvider.MapDictionaryProvider(), outputStream);

arrowStreamWriter.start();
root.setRowCount(1);

FieldVector vector = root.getVector("k0");
UInt4Vector uInt4Vector = (UInt4Vector) vector;
uInt4Vector.setInitialCapacity(1);
uInt4Vector.allocateNew();
uInt4Vector.setIndexDefined(0);
uInt4Vector.setSafe(0, 0);

vector.setValueCount(1);
arrowStreamWriter.writeBatch();

arrowStreamWriter.end();
arrowStreamWriter.close();

TStatus status = new TStatus();
status.setStatusCode(TStatusCode.OK);
TScanBatchResult scanBatchResult = new TScanBatchResult();
scanBatchResult.setStatus(status);
scanBatchResult.setEos(false);
scanBatchResult.setRows(outputStream.toByteArray());

String schemaStr =
"{\"properties\":["
+ "{\"type\":\"DATETIMEV2\",\"name\":\"k0\",\"comment\":\"\"}"
+ "], \"status\":200}";

Schema schema = RestService.parseSchema(schemaStr, logger);
thrown.expect(DorisException.class);
thrown.expectMessage(startsWith("Unsupported type for DATETIMEV2"));
new RowBatch(scanBatchResult, schema);
}

}
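As a sanity check on the strings asserted in timestampVector: epoch second 1721892143 is 2024-07-25 07:22:23 UTC, i.e. 15:22:23 in UTC+8. A standalone java.time snippet, independent of the connector, reproduces that value:

import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;

public class EpochSanityCheck {
    public static void main(String[] args) {
        // 1721892143 s == 1721892143586 ms == 1721892143586123 us, the
        // values written into the k0..k5 vectors in timestampVector.
        LocalDateTime utc8 = LocalDateTime.ofInstant(
                Instant.ofEpochSecond(1721892143L), ZoneId.of("UTC+8"));
        // Prints "2024-07-25 15:22:23", matching the second-precision assertions.
        System.out.println(utc8.format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")));
    }
}

Note that the timezone-less vectors k0 to k2 are expected to render the same wall-clock time as the UTC+8 vectors k3 to k5, which holds only if the JVM running the test defaults to a UTC+8 zone; that appears to be the test's working assumption.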