Skip to content

Commit

Permalink
[feature]support read ipv4/ipv6 data type (#199)
Browse files Browse the repository at this point in the history
  • Loading branch information
vinlee19 authored May 7, 2024
1 parent 37c5175 commit a3cf8a6
Show file tree
Hide file tree
Showing 4 changed files with 489 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import com.google.common.base.Preconditions;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BaseIntVector;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.DateDayVector;
Expand All @@ -31,6 +32,7 @@
import org.apache.arrow.vector.SmallIntVector;
import org.apache.arrow.vector.TimeStampMicroVector;
import org.apache.arrow.vector.TinyIntVector;
import org.apache.arrow.vector.UInt4Vector;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
Expand All @@ -44,6 +46,7 @@
import org.apache.doris.sdk.thrift.TScanBatchResult;
import org.apache.doris.spark.exception.DorisException;
import org.apache.doris.spark.rest.models.Schema;
import org.apache.doris.spark.util.IPUtils;
import org.apache.spark.sql.types.Decimal;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -69,6 +72,8 @@
import java.util.NoSuchElementException;
import java.util.Objects;

import static org.apache.doris.spark.util.IPUtils.convertLongToIPv4Address;

/**
* row batch data container.
*/
Expand Down Expand Up @@ -246,6 +251,20 @@ public void convertArrowToRowBatch() throws DorisException {
}
}
break;
case "IPV4":
Preconditions.checkArgument(mt.equals(Types.MinorType.UINT4) || mt.equals(Types.MinorType.INT),
typeMismatchMessage(currentType, mt));
BaseIntVector ipv4Vector;
if (mt.equals(Types.MinorType.INT)) {
ipv4Vector = (IntVector) curFieldVector;
} else {
ipv4Vector = (UInt4Vector) curFieldVector;
}
for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
Object fieldValue = ipv4Vector.isNull(rowIndex) ? null : convertLongToIPv4Address(ipv4Vector.getValueAsLong(rowIndex));
addValueToRow(rowIndex, fieldValue);
}
break;
case "FLOAT":
Preconditions.checkArgument(mt.equals(Types.MinorType.FLOAT4),
typeMismatchMessage(currentType, mt));
Expand Down Expand Up @@ -314,7 +333,7 @@ public void convertArrowToRowBatch() throws DorisException {
case "DATE":
case "DATEV2":
Preconditions.checkArgument(mt.equals(Types.MinorType.VARCHAR)
|| mt.equals(Types.MinorType.DATEDAY), typeMismatchMessage(currentType, mt));
|| mt.equals(Types.MinorType.DATEDAY), typeMismatchMessage(currentType, mt));
if (mt.equals(Types.MinorType.VARCHAR)) {
VarCharVector date = (VarCharVector) curFieldVector;
for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
Expand Down Expand Up @@ -392,6 +411,20 @@ public void convertArrowToRowBatch() throws DorisException {
addValueToRow(rowIndex, value);
}
break;
case "IPV6":
Preconditions.checkArgument(mt.equals(Types.MinorType.VARCHAR),
typeMismatchMessage(currentType, mt));
VarCharVector ipv6VarcharVector = (VarCharVector) curFieldVector;
for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
if (ipv6VarcharVector.isNull(rowIndex)) {
addValueToRow(rowIndex, null);
break;
}
String ipv6Str = new String(ipv6VarcharVector.get(rowIndex));
String ipv6Address = IPUtils.fromBigInteger(new BigInteger(ipv6Str));
addValueToRow(rowIndex, ipv6Address);
}
break;
case "ARRAY":
Preconditions.checkArgument(mt.equals(Types.MinorType.LIST),
typeMismatchMessage(currentType, mt));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ private[spark] object SchemaUtils {
case "MAP" => MapType(DataTypes.StringType, DataTypes.StringType)
case "STRUCT" => DataTypes.StringType
case "VARIANT" => DataTypes.StringType
case "IPV4" => DataTypes.StringType
case "IPV6" => DataTypes.StringType
case "HLL" =>
throw new DorisException("Unsupported type " + dorisType)
case _ =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.spark.util;

import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.nio.LongBuffer;
import java.util.Arrays;

/**
*
*/
public class IPUtils {
/**
* Create an IPv6 address from a (positive) {@link java.math.BigInteger}. The magnitude of the
* {@link java.math.BigInteger} represents the IPv6 address value. Or in other words, the {@link
* java.math.BigInteger} with value N defines the Nth possible IPv6 address.
*
* @param bigInteger {@link java.math.BigInteger} value
* @return IPv6 address
*/
public static String fromBigInteger(BigInteger bigInteger) {
byte[] bytes = bigInteger.toByteArray();
if (bytes[0] == 0) {
bytes = Arrays.copyOfRange(bytes, 1, bytes.length); // Skip leading zero byte
}
bytes = prefixWithZeroBytes(bytes);
long[] ipv6Bits = fromByteArray(bytes);
return toIPv6String(ipv6Bits[0], ipv6Bits[1]);
}

private static byte[] prefixWithZeroBytes(byte[] original) {
byte[] target = new byte[16];
System.arraycopy(original, 0, target, 16 - original.length, original.length);
return target;
}

/**
* Create an IPv6 address from a byte array.
*
* @param bytes byte array with 16 bytes (interpreted unsigned)
* @return IPv6 address
*/
public static long[] fromByteArray(byte[] bytes) {
if (bytes == null || bytes.length != 16) {
throw new IllegalArgumentException("Byte array must be exactly 16 bytes long");
}
ByteBuffer buf = ByteBuffer.wrap(bytes);
LongBuffer longBuffer = buf.asLongBuffer();
return new long[] {longBuffer.get(), longBuffer.get()};
}

private static String toShortHandNotationString(long highBits, long lowBits) {
String[] strings = toArrayOfShortStrings(highBits, lowBits);
StringBuilder result = new StringBuilder();
int[] shortHandNotationPositionAndLength =
startAndLengthOfLongestRunOfZeroes(highBits, lowBits);
int shortHandNotationPosition = shortHandNotationPositionAndLength[0];
int shortHandNotationLength = shortHandNotationPositionAndLength[1];
boolean useShortHandNotation = shortHandNotationLength > 1;

for (int i = 0; i < strings.length; ++i) {
if (useShortHandNotation && i == shortHandNotationPosition) {
if (i == 0) {
result.append("::");
} else {
result.append(":");
}
} else if (i <= shortHandNotationPosition
|| i >= shortHandNotationPosition + shortHandNotationLength) {
result.append(strings[i]);
if (i < 7) {
result.append(":");
}
}
}

return result.toString().toLowerCase();
}

private static String[] toArrayOfShortStrings(long highBits, long lowBits) {
short[] shorts = toShortArray(highBits, lowBits);
String[] strings = new String[shorts.length];

for (int i = 0; i < shorts.length; ++i) {
strings[i] = String.format("%x", shorts[i]);
}

return strings;
}

private static short[] toShortArray(long highBits, long lowBits) {
short[] shorts = new short[8];

for (int i = 0; i < 8; ++i) {
if (inHighRange(i)) {
shorts[i] = (short) ((int) (highBits << i * 16 >>> 48 & 0xFFFF));
} else {
shorts[i] = (short) ((int) (lowBits << i * 16 >>> 48 & 0xFFFF));
}
}

return shorts;
}

private static int[] startAndLengthOfLongestRunOfZeroes(long highBits, long lowBits) {
int longestConsecutiveZeroes = 0;
int longestConsecutiveZeroesPos = -1;
short[] shorts = toShortArray(highBits, lowBits);

for (int pos = 0; pos < shorts.length; ++pos) {
int consecutiveZeroesAtCurrentPos = countConsecutiveZeroes(shorts, pos);
if (consecutiveZeroesAtCurrentPos > longestConsecutiveZeroes) {
longestConsecutiveZeroes = consecutiveZeroesAtCurrentPos;
longestConsecutiveZeroesPos = pos;
}
}

return new int[] {longestConsecutiveZeroesPos, longestConsecutiveZeroes};
}

private static boolean inHighRange(int shortNumber) {
return shortNumber >= 0 && shortNumber < 4;
}

private static int countConsecutiveZeroes(short[] shorts, int offset) {
int count = 0;

for (int i = offset; i < shorts.length && shorts[i] == 0; ++i) {
++count;
}

return count;
}

public static String toIPv6String(long highBits, long lowBits) {

if (isIPv4Mapped(highBits, lowBits)) {
return toIPv4MappedAddressString(lowBits);
} else if (isIPv4Compatibility(highBits, lowBits)) {
return toIPv4CompatibilityAddressString(lowBits);
}

return toShortHandNotationString(highBits, lowBits);
}

public static String convertLongToIPv4Address(long lowBits) {
return String.format(
"%d.%d.%d.%d",
(lowBits >> 24) & 0xff,
(lowBits >> 16) & 0xff,
(lowBits >> 8) & 0xff,
lowBits & 0xff);
}

private static String toIPv4MappedAddressString(long lowBits) {
return "::ffff:" + convertLongToIPv4Address(lowBits);
}

private static String toIPv4CompatibilityAddressString(long lowBits) {
return "::" + convertLongToIPv4Address(lowBits);
}

/**
* Returns true if the address is an IPv4-mapped IPv6 address. In these addresses, the first 80
* bits are zero, the next 16 bits are one, and the remaining 32 bits are the IPv4 address.
*
* @return true if the address is an IPv4-mapped IPv6 addresses.
*/
private static boolean isIPv4Mapped(long highBits, long lowBits) {
return highBits == 0
&& (lowBits & 0xFFFF000000000000L) == 0
&& (lowBits & 0x0000FFFF00000000L) == 0x0000FFFF00000000L;
}

/**
* Checks if the given IPv6 address is in IPv4 compatibility format. IPv4 compatibility format
* is characterized by having the high 96 bits of the IPv6 address set to zero, while the low 32
* bits represent an IPv4 address. The criteria for determining IPv4 compatibility format are as
* follows: 1. The high 96 bits of the IPv6 address are all zeros. 2. The low 32 bits are within
* the range from 0 to 4294967295 (0xFFFFFFFF). 3. The first 16 bits of the low 32 bits are all
* ones (0xFFFF), indicating the special identifier for IPv4 compatibility format.
*
* @return True if the given IPv6 address is in IPv4 compatibility format; otherwise, false.
*/
private static boolean isIPv4Compatibility(long highBits, long lowBits) {
return highBits == 0L && lowBits <= 0xFFFFFFFFL && (lowBits & 65536L) == 65536L;
}
}
Loading

0 comments on commit a3cf8a6

Please sign in to comment.