details) {
super(message, cause);
this.status = status;
this.sqlState = sqlState;
@@ -94,14 +93,14 @@ public int getVendorCode() {
}
/**
- * Get extra driver-specific binary error details.
+ * Get extra driver-specific error details.
*
* This allows drivers to return custom, structured error information (for example, JSON or
* Protocol Buffers) that can be optionally parsed by clients, beyond the standard AdbcError
* fields, without having to encode it in the error message. The encoding of the data is
* driver-defined.
*/
- public Collection getDetails() {
+ public Collection getDetails() {
return details;
}
@@ -115,7 +114,7 @@ public AdbcException withCause(Throwable cause) {
/**
* Copy this exception with different details (a convenience for use with the static factories).
*/
- public AdbcException withDetails(Collection details) {
+ public AdbcException withDetails(Collection details) {
return new AdbcException(getMessage(), getCause(), status, sqlState, vendorCode, details);
}
diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcOptions.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcOptions.java
index efd8eab77b..5a8e78b08f 100644
--- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcOptions.java
+++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcOptions.java
@@ -27,7 +27,7 @@ public interface AdbcOptions {
* @return The option value.
* @param The option value type.
*/
- default T getOption(AdbcOptionKey key) throws AdbcException {
+ default T getOption(TypedKey key) throws AdbcException {
throw AdbcException.notImplemented("Unsupported option " + key);
}
@@ -39,7 +39,7 @@ default T getOption(AdbcOptionKey key) throws AdbcException {
* @param value The option value.
* @param The option value type.
*/
- default void setOption(AdbcOptionKey key, T value) throws AdbcException {
+ default void setOption(TypedKey key, T value) throws AdbcException {
throw AdbcException.notImplemented("Unsupported option " + key);
}
}
diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcStatement.java b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcStatement.java
index 27708e1bbf..07c7eab126 100644
--- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcStatement.java
+++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcStatement.java
@@ -58,7 +58,7 @@ default void cancel() throws AdbcException {
/**
* Set a generic query option.
*
- * @deprecated Prefer {@link #setOption(AdbcOptionKey, Object)}.
+ * @deprecated Prefer {@link #setOption(TypedKey, Object)}.
*/
default void setOption(String key, Object value) throws AdbcException {
throw AdbcException.notImplemented("Unsupported option " + key);
diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/ErrorDetail.java b/java/core/src/main/java/org/apache/arrow/adbc/core/ErrorDetail.java
new file mode 100644
index 0000000000..13521fb82e
--- /dev/null
+++ b/java/core/src/main/java/org/apache/arrow/adbc/core/ErrorDetail.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.arrow.adbc.core;
+
+import java.util.Objects;
+
+/** Additional details (not necessarily human-readable) contained in an {@link AdbcException}. */
+public class ErrorDetail {
+ private final String key;
+ private final Object value;
+
+ public ErrorDetail(String key, Object value) {
+ this.key = Objects.requireNonNull(key);
+ this.value = Objects.requireNonNull(value);
+ }
+
+ public String getKey() {
+ return key;
+ }
+
+ public Object getValue() {
+ return value;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ ErrorDetail that = (ErrorDetail) o;
+ return Objects.equals(getKey(), that.getKey()) && Objects.equals(getValue(), that.getValue());
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(getKey(), getValue());
+ }
+
+ @Override
+ public String toString() {
+ return "ErrorDetail{" + "key='" + key + '\'' + ", value=" + value + '}';
+ }
+}
diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/StandardSchemas.java b/java/core/src/main/java/org/apache/arrow/adbc/core/StandardSchemas.java
index a14c04c700..c1e5594b87 100644
--- a/java/core/src/main/java/org/apache/arrow/adbc/core/StandardSchemas.java
+++ b/java/core/src/main/java/org/apache/arrow/adbc/core/StandardSchemas.java
@@ -19,6 +19,8 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
+import org.apache.arrow.vector.types.FloatingPointPrecision;
+import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.UnionMode;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
@@ -30,10 +32,13 @@ private StandardSchemas() {
throw new AssertionError("Do not instantiate this class");
}
- private static final ArrowType INT16 = new ArrowType.Int(16, true);
- private static final ArrowType INT32 = new ArrowType.Int(32, true);
- private static final ArrowType INT64 = new ArrowType.Int(64, true);
+ private static final ArrowType INT16 = Types.MinorType.SMALLINT.getType();
+ private static final ArrowType INT32 = Types.MinorType.INT.getType();
+ private static final ArrowType INT64 = Types.MinorType.BIGINT.getType();
private static final ArrowType UINT32 = new ArrowType.Int(32, false);
+ private static final ArrowType UINT64 = new ArrowType.Int(64, false);
+ private static final ArrowType FLOAT64 =
+ new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE);
/** The schema of the result set of {@link AdbcConnection#getInfo(int[])}}. */
public static final Schema GET_INFO_SCHEMA =
@@ -83,11 +88,11 @@ private StandardSchemas() {
Field.notNullable("constraint_type", ArrowType.Utf8.INSTANCE),
new Field(
"constraint_column_names",
- FieldType.notNullable(ArrowType.List.INSTANCE),
+ FieldType.nullable(ArrowType.List.INSTANCE),
Collections.singletonList(Field.nullable("item", new ArrowType.Utf8()))),
new Field(
"constraint_column_usage",
- FieldType.notNullable(ArrowType.List.INSTANCE),
+ FieldType.nullable(ArrowType.List.INSTANCE),
Collections.singletonList(
new Field("item", FieldType.nullable(ArrowType.Struct.INSTANCE), USAGE_SCHEMA))));
@@ -119,12 +124,12 @@ private StandardSchemas() {
new Field("table_type", FieldType.notNullable(ArrowType.Utf8.INSTANCE), null),
new Field(
"table_columns",
- FieldType.notNullable(ArrowType.List.INSTANCE),
+ FieldType.nullable(ArrowType.List.INSTANCE),
Collections.singletonList(
new Field("item", FieldType.nullable(ArrowType.Struct.INSTANCE), COLUMN_SCHEMA))),
new Field(
"table_constraints",
- FieldType.notNullable(ArrowType.List.INSTANCE),
+ FieldType.nullable(ArrowType.List.INSTANCE),
Collections.singletonList(
new Field(
"item", FieldType.nullable(ArrowType.Struct.INSTANCE), CONSTRAINT_SCHEMA))));
@@ -134,20 +139,76 @@ private StandardSchemas() {
new Field("db_schema_name", FieldType.notNullable(ArrowType.Utf8.INSTANCE), null),
new Field(
"db_schema_tables",
- FieldType.notNullable(ArrowType.List.INSTANCE),
+ FieldType.nullable(ArrowType.List.INSTANCE),
Collections.singletonList(
new Field("item", FieldType.nullable(ArrowType.Struct.INSTANCE), TABLE_SCHEMA))));
+ /**
+ * The schema of the result of {@link AdbcConnection#getObjects(AdbcConnection.GetObjectsDepth,
+ * String, String, String, String[], String)}.
+ */
public static final Schema GET_OBJECTS_SCHEMA =
new Schema(
Arrays.asList(
new Field("catalog_name", FieldType.notNullable(ArrowType.Utf8.INSTANCE), null),
new Field(
"catalog_db_schemas",
- FieldType.notNullable(ArrowType.List.INSTANCE),
+ FieldType.nullable(ArrowType.List.INSTANCE),
Collections.singletonList(
new Field(
"item",
FieldType.nullable(ArrowType.Struct.INSTANCE),
DB_SCHEMA_SCHEMA)))));
+
+ public static final List STATISTICS_VALUE_SCHEMA =
+ Arrays.asList(
+ Field.nullable("int64", INT64),
+ Field.nullable("uint64", UINT64),
+ Field.nullable("float64", FLOAT64),
+ Field.nullable("binary", ArrowType.Binary.INSTANCE));
+
+ public static final List STATISTICS_SCHEMA =
+ Arrays.asList(
+ Field.notNullable("table_name", ArrowType.Utf8.INSTANCE),
+ Field.nullable("column_name", ArrowType.Utf8.INSTANCE),
+ Field.notNullable("statistic_key", INT16),
+ new Field(
+ "statistic_value",
+ FieldType.notNullable(new ArrowType.Union(UnionMode.Dense, new int[] {0, 1, 2, 3})),
+ STATISTICS_VALUE_SCHEMA),
+ Field.notNullable("statistic_is_approximate", ArrowType.Bool.INSTANCE));
+
+ public static final List STATISTICS_DB_SCHEMA_SCHEMA =
+ Arrays.asList(
+ new Field("db_schema_name", FieldType.notNullable(ArrowType.Utf8.INSTANCE), null),
+ new Field(
+ "db_schema_statistics",
+ FieldType.nullable(ArrowType.List.INSTANCE),
+ Collections.singletonList(
+ new Field(
+ "item", FieldType.nullable(ArrowType.Struct.INSTANCE), STATISTICS_SCHEMA))));
+
+ /**
+ * The schema of the result of {@link AdbcConnection#getStatistics(String, String, String,
+ * boolean)}.
+ */
+ public static final Schema GET_STATISTICS_SCHEMA =
+ new Schema(
+ Arrays.asList(
+ new Field("catalog_name", FieldType.notNullable(ArrowType.Utf8.INSTANCE), null),
+ new Field(
+ "catalog_db_schemas",
+ FieldType.nullable(ArrowType.List.INSTANCE),
+ Collections.singletonList(
+ new Field(
+ "item",
+ FieldType.nullable(ArrowType.Struct.INSTANCE),
+ STATISTICS_DB_SCHEMA_SCHEMA)))));
+
+ /** The schema of the result of {@link AdbcConnection#getStatisticNames()}. */
+ public static final Schema GET_STATISTIC_NAMES_SCHEMA =
+ new Schema(
+ Arrays.asList(
+ Field.notNullable("statistic_name", ArrowType.Utf8.INSTANCE),
+ Field.notNullable("statistic_name", INT16)));
}
diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/StandardStatistics.java b/java/core/src/main/java/org/apache/arrow/adbc/core/StandardStatistics.java
index 5412c645c3..f5097f4413 100644
--- a/java/core/src/main/java/org/apache/arrow/adbc/core/StandardStatistics.java
+++ b/java/core/src/main/java/org/apache/arrow/adbc/core/StandardStatistics.java
@@ -32,39 +32,39 @@ public enum StandardStatistics {
*
* For example, this is roughly the average length of a string for a string column.
*/
- AVERAGE_BYTE_WIDTH("adbc.statistic.byte_width", 0),
+ AVERAGE_BYTE_WIDTH("adbc.statistic.byte_width", (short) 0),
/**
* The distinct value count (NDV) statistic. The number of distinct values in the column. Value
* type is int64 (when not approximate) or float64 (when approximate).
*/
- DISTINCT_COUNT("adbc.statistic.distinct_count", 1),
+ DISTINCT_COUNT("adbc.statistic.distinct_count", (short) 1),
/**
* The max byte width statistic. The maximum size in bytes of a row in the column. Value type is
* int64 (when not approximate) or float64 (when approximate).
*
*
For example, this is the maximum length of a string for a string column.
*/
- MAX_BYTE_WIDTH("adbc.statistic.byte_width", 2),
+ MAX_BYTE_WIDTH("adbc.statistic.byte_width", (short) 2),
/** The max value statistic. Value type is column-dependent. */
- MAX_VALUE_NAME("adbc.statistic.byte_width", 3),
+ MAX_VALUE("adbc.statistic.byte_width", (short) 3),
/** The min value statistic. Value type is column-dependent. */
- MIN_VALUE_NAME("adbc.statistic.byte_width", 4),
+ MIN_VALUE("adbc.statistic.byte_width", (short) 4),
/**
* The null count statistic. The number of values that are null in the column. Value type is int64
* (when not approximate) or float64 (when approximate).
*/
- NULL_COUNT_NAME("adbc.statistic.null_count", 5),
+ NULL_COUNT("adbc.statistic.null_count", (short) 5),
/**
* The row count statistic. The number of rows in the column or table. Value type is int64 (when
* not approximate) or float64 (when approximate).
*/
- ROW_COUNT_NAME("adbc.statistic.row_count", 6),
+ ROW_COUNT("adbc.statistic.row_count", (short) 6),
;
private final String name;
- private final int key;
+ private final short key;
- StandardStatistics(String name, int key) {
+ StandardStatistics(String name, short key) {
this.name = Objects.requireNonNull(name);
this.key = key;
}
@@ -75,7 +75,7 @@ public String getName() {
}
/** Get the dictionary-encoded name. */
- public int getKey() {
+ public short getKey() {
return key;
}
}
diff --git a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcOptionKey.java b/java/core/src/main/java/org/apache/arrow/adbc/core/TypedKey.java
similarity index 78%
rename from java/core/src/main/java/org/apache/arrow/adbc/core/AdbcOptionKey.java
rename to java/core/src/main/java/org/apache/arrow/adbc/core/TypedKey.java
index d594703688..21523bb429 100644
--- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcOptionKey.java
+++ b/java/core/src/main/java/org/apache/arrow/adbc/core/TypedKey.java
@@ -26,15 +26,33 @@
* @since ADBC API revision 1.1.0
* @param The option value type.
*/
-public final class AdbcOptionKey {
+public final class TypedKey {
private final String key;
private final Class type;
- public AdbcOptionKey(String key, Class type) {
+ public TypedKey(String key, Class type) {
this.key = Objects.requireNonNull(key);
this.type = Objects.requireNonNull(type);
}
+ /** Get the option key. */
+ public String getKey() {
+ return key;
+ }
+
+ /**
+ * Get the option value (if it was set) and check the type.
+ *
+ * @throws ClassCastException if the value is of the wrong type.
+ */
+ public T get(Map options) {
+ Object value = options.get(key);
+ if (value == null) {
+ return null;
+ }
+ return type.cast(value);
+ }
+
/**
* Set this option in an options map (like for {@link AdbcDriver#open(Map)}.
*
@@ -53,7 +71,7 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) {
return false;
}
- AdbcOptionKey> that = (AdbcOptionKey>) o;
+ TypedKey> that = (TypedKey>) o;
return Objects.equals(key, that.key) && Objects.equals(type, that.type);
}
diff --git a/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlQuirks.java b/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlQuirks.java
index 43a6df99c9..d3f79889ec 100644
--- a/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlQuirks.java
+++ b/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlQuirks.java
@@ -47,7 +47,7 @@ public AdbcDatabase initDatabase(BufferAllocator allocator) throws AdbcException
String url = getFlightLocation();
final Map parameters = new HashMap<>();
- parameters.put(AdbcDriver.PARAM_URL, url);
+ AdbcDriver.PARAM_URI.set(parameters, url);
return new FlightSqlDriver(allocator).open(parameters);
}
diff --git a/java/driver/flight-sql/pom.xml b/java/driver/flight-sql/pom.xml
index 432967963b..9b78b4da24 100644
--- a/java/driver/flight-sql/pom.xml
+++ b/java/driver/flight-sql/pom.xml
@@ -66,5 +66,17 @@
org.apache.arrow.adbc
adbc-sql
+
+
+
+ org.assertj
+ assertj-core
+ test
+
+
+ org.junit.jupiter
+ junit-jupiter
+ test
+
diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java
index 30fc460b8e..5015ecfcfe 100644
--- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java
+++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java
@@ -43,17 +43,22 @@ public class FlightSqlDriver implements AdbcDriver {
@Override
public AdbcDatabase open(Map parameters) throws AdbcException {
- Object target = parameters.get("adbc.url");
- if (!(target instanceof String)) {
- throw AdbcException.invalidArgument(
- "[Flight SQL] Must provide String " + PARAM_URL + " parameter");
+ String uri = PARAM_URI.get(parameters);
+ if (uri == null) {
+ Object target = parameters.get("adbc.url");
+ if (!(target instanceof String)) {
+ throw AdbcException.invalidArgument(
+ "[Flight SQL] Must provide String " + PARAM_URI + " parameter");
+ }
+ uri = (String) target;
}
+
Location location;
try {
- location = new Location((String) target);
+ location = new Location(uri);
} catch (URISyntaxException e) {
throw AdbcException.invalidArgument(
- String.format("[Flight SQL] Location %s is invalid: %s", target, e))
+ String.format("[Flight SQL] Location %s is invalid: %s", uri, e))
.withCause(e);
}
Object quirks = parameters.get(PARAM_SQL_QUIRKS);
diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java
index cb6b3038f8..45b42df2ee 100644
--- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java
+++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriverUtil.java
@@ -17,8 +17,11 @@
package org.apache.arrow.adbc.driver.flightsql;
import java.sql.SQLException;
+import java.util.ArrayList;
+import java.util.List;
import org.apache.arrow.adbc.core.AdbcException;
import org.apache.arrow.adbc.core.AdbcStatusCode;
+import org.apache.arrow.adbc.core.ErrorDetail;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightStatusCode;
@@ -72,7 +75,24 @@ static AdbcStatusCode fromFlightStatusCode(FlightStatusCode code) {
}
static AdbcException fromFlightException(FlightRuntimeException e) {
+ List errorDetails = new ArrayList<>();
+ for (String key : e.status().metadata().keys()) {
+ if (key.endsWith("-bin")) {
+ for (byte[] value : e.status().metadata().getAllByte(key)) {
+ errorDetails.add(new ErrorDetail(key, value));
+ }
+ } else {
+ for (String value : e.status().metadata().getAll(key)) {
+ errorDetails.add(new ErrorDetail(key, value));
+ }
+ }
+ }
return new AdbcException(
- e.getMessage(), e.getCause(), fromFlightStatusCode(e.status().code()), null, 0);
+ e.getMessage(),
+ e.getCause(),
+ fromFlightStatusCode(e.status().code()),
+ null,
+ 0,
+ errorDetails);
}
}
diff --git a/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/DetailsTest.java b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/DetailsTest.java
new file mode 100644
index 0000000000..c617f664fa
--- /dev/null
+++ b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/DetailsTest.java
@@ -0,0 +1,381 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.arrow.adbc.driver.flightsql;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+import io.grpc.Metadata;
+import io.grpc.Status;
+import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+import org.apache.arrow.adbc.core.AdbcConnection;
+import org.apache.arrow.adbc.core.AdbcDatabase;
+import org.apache.arrow.adbc.core.AdbcDriver;
+import org.apache.arrow.adbc.core.AdbcException;
+import org.apache.arrow.adbc.core.AdbcStatement;
+import org.apache.arrow.adbc.core.ErrorDetail;
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.Criteria;
+import org.apache.arrow.flight.ErrorFlightMetadata;
+import org.apache.arrow.flight.FlightDescriptor;
+import org.apache.arrow.flight.FlightInfo;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.flight.PutResult;
+import org.apache.arrow.flight.Result;
+import org.apache.arrow.flight.SchemaResult;
+import org.apache.arrow.flight.sql.FlightSqlProducer;
+import org.apache.arrow.flight.sql.impl.FlightSql;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+/** Test that gRPC error details make it through. */
+class DetailsTest {
+ static BufferAllocator allocator;
+ static Producer producer;
+ static FlightServer server;
+ static AdbcDriver driver;
+ static AdbcDatabase database;
+ AdbcConnection connection;
+ AdbcStatement statement;
+
+ @BeforeAll
+ static void beforeAll() throws Exception {
+ allocator = new RootAllocator();
+ producer = new Producer();
+ server =
+ FlightServer.builder()
+ .allocator(allocator)
+ .producer(producer)
+ .location(Location.forGrpcInsecure("localhost", 0))
+ .build();
+ server.start();
+ driver = new FlightSqlDriver(allocator);
+ Map parameters = new HashMap<>();
+ AdbcDriver.PARAM_URI.set(
+ parameters, Location.forGrpcInsecure("localhost", server.getPort()).getUri().toString());
+ database = driver.open(parameters);
+ }
+
+ @BeforeEach
+ void beforeEach() throws Exception {
+ connection = database.connect();
+ statement = connection.createStatement();
+ }
+
+ @AfterEach
+ void afterEach() throws Exception {
+ AutoCloseables.close(statement, connection);
+ }
+
+ @AfterAll
+ static void afterAll() throws Exception {
+ AutoCloseables.close(database, server, allocator);
+ }
+
+ @Test
+ void flightDetails() throws Exception {
+ statement.setSqlQuery("flight");
+
+ AdbcException exception =
+ assertThrows(
+ AdbcException.class,
+ () -> {
+ try (AdbcStatement.QueryResult result = statement.executeQuery()) {}
+ });
+
+ assertThat(exception.getDetails()).contains(new ErrorDetail("x-foo", "text"));
+ Optional binaryKey =
+ exception.getDetails().stream().filter(x -> x.getKey().equals("x-foo-bin")).findAny();
+ assertThat(binaryKey)
+ .get()
+ .extracting(ErrorDetail::getValue)
+ .isEqualTo("text".getBytes(StandardCharsets.UTF_8));
+ }
+
+ @Test
+ void grpcDetails() throws Exception {
+ statement.setSqlQuery("grpc");
+
+ AdbcException exception =
+ assertThrows(
+ AdbcException.class,
+ () -> {
+ try (AdbcStatement.QueryResult result = statement.executeQuery()) {}
+ });
+
+ assertThat(exception.getDetails()).contains(new ErrorDetail("x-foo", "text"));
+ Optional binaryKey =
+ exception.getDetails().stream().filter(x -> x.getKey().equals("x-foo-bin")).findAny();
+ assertThat(binaryKey)
+ .get()
+ .extracting(ErrorDetail::getValue)
+ .isEqualTo("text".getBytes(StandardCharsets.UTF_8));
+ }
+
+ static class Producer implements FlightSqlProducer {
+ Metadata.Key BINARY_KEY = Metadata.Key.of("x-foo-bin", Metadata.BINARY_BYTE_MARSHALLER);
+ Metadata.Key TEXT_KEY = Metadata.Key.of("x-foo", Metadata.ASCII_STRING_MARSHALLER);
+
+ @Override
+ public FlightInfo getFlightInfoStatement(
+ FlightSql.CommandStatementQuery commandStatementQuery,
+ CallContext callContext,
+ FlightDescriptor flightDescriptor) {
+ if (commandStatementQuery.getQuery().equals("flight")) {
+ // Using Flight path
+ ErrorFlightMetadata metadata = new ErrorFlightMetadata();
+ metadata.insert("x-foo", "text");
+ metadata.insert("x-foo-bin", "text".getBytes(StandardCharsets.UTF_8));
+ throw CallStatus.UNKNOWN
+ .withDescription("Expected")
+ .withMetadata(metadata)
+ .toRuntimeException();
+ } else if (commandStatementQuery.getQuery().equals("grpc")) {
+ // Using gRPC path
+ Metadata trailers = new Metadata();
+ trailers.put(TEXT_KEY, "text");
+ trailers.put(BINARY_KEY, "text".getBytes(StandardCharsets.UTF_8));
+ throw Status.UNKNOWN.asRuntimeException(trailers);
+ }
+
+ throw CallStatus.UNIMPLEMENTED.toRuntimeException();
+ }
+
+ // No-op implementations
+
+ @Override
+ public void createPreparedStatement(
+ FlightSql.ActionCreatePreparedStatementRequest actionCreatePreparedStatementRequest,
+ CallContext callContext,
+ StreamListener streamListener) {}
+
+ @Override
+ public void closePreparedStatement(
+ FlightSql.ActionClosePreparedStatementRequest actionClosePreparedStatementRequest,
+ CallContext callContext,
+ StreamListener streamListener) {}
+
+ @Override
+ public FlightInfo getFlightInfoPreparedStatement(
+ FlightSql.CommandPreparedStatementQuery commandPreparedStatementQuery,
+ CallContext callContext,
+ FlightDescriptor flightDescriptor) {
+ return null;
+ }
+
+ @Override
+ public SchemaResult getSchemaStatement(
+ FlightSql.CommandStatementQuery commandStatementQuery,
+ CallContext callContext,
+ FlightDescriptor flightDescriptor) {
+ return null;
+ }
+
+ @Override
+ public void getStreamStatement(
+ FlightSql.TicketStatementQuery ticketStatementQuery,
+ CallContext callContext,
+ ServerStreamListener serverStreamListener) {}
+
+ @Override
+ public void getStreamPreparedStatement(
+ FlightSql.CommandPreparedStatementQuery commandPreparedStatementQuery,
+ CallContext callContext,
+ ServerStreamListener serverStreamListener) {}
+
+ @Override
+ public Runnable acceptPutStatement(
+ FlightSql.CommandStatementUpdate commandStatementUpdate,
+ CallContext callContext,
+ FlightStream flightStream,
+ StreamListener streamListener) {
+ return null;
+ }
+
+ @Override
+ public Runnable acceptPutPreparedStatementUpdate(
+ FlightSql.CommandPreparedStatementUpdate commandPreparedStatementUpdate,
+ CallContext callContext,
+ FlightStream flightStream,
+ StreamListener streamListener) {
+ return null;
+ }
+
+ @Override
+ public Runnable acceptPutPreparedStatementQuery(
+ FlightSql.CommandPreparedStatementQuery commandPreparedStatementQuery,
+ CallContext callContext,
+ FlightStream flightStream,
+ StreamListener streamListener) {
+ return null;
+ }
+
+ @Override
+ public FlightInfo getFlightInfoSqlInfo(
+ FlightSql.CommandGetSqlInfo commandGetSqlInfo,
+ CallContext callContext,
+ FlightDescriptor flightDescriptor) {
+ return null;
+ }
+
+ @Override
+ public void getStreamSqlInfo(
+ FlightSql.CommandGetSqlInfo commandGetSqlInfo,
+ CallContext callContext,
+ ServerStreamListener serverStreamListener) {}
+
+ @Override
+ public FlightInfo getFlightInfoTypeInfo(
+ FlightSql.CommandGetXdbcTypeInfo commandGetXdbcTypeInfo,
+ CallContext callContext,
+ FlightDescriptor flightDescriptor) {
+ return null;
+ }
+
+ @Override
+ public void getStreamTypeInfo(
+ FlightSql.CommandGetXdbcTypeInfo commandGetXdbcTypeInfo,
+ CallContext callContext,
+ ServerStreamListener serverStreamListener) {}
+
+ @Override
+ public FlightInfo getFlightInfoCatalogs(
+ FlightSql.CommandGetCatalogs commandGetCatalogs,
+ CallContext callContext,
+ FlightDescriptor flightDescriptor) {
+ return null;
+ }
+
+ @Override
+ public void getStreamCatalogs(
+ CallContext callContext, ServerStreamListener serverStreamListener) {}
+
+ @Override
+ public FlightInfo getFlightInfoSchemas(
+ FlightSql.CommandGetDbSchemas commandGetDbSchemas,
+ CallContext callContext,
+ FlightDescriptor flightDescriptor) {
+ return null;
+ }
+
+ @Override
+ public void getStreamSchemas(
+ FlightSql.CommandGetDbSchemas commandGetDbSchemas,
+ CallContext callContext,
+ ServerStreamListener serverStreamListener) {}
+
+ @Override
+ public FlightInfo getFlightInfoTables(
+ FlightSql.CommandGetTables commandGetTables,
+ CallContext callContext,
+ FlightDescriptor flightDescriptor) {
+ return null;
+ }
+
+ @Override
+ public void getStreamTables(
+ FlightSql.CommandGetTables commandGetTables,
+ CallContext callContext,
+ ServerStreamListener serverStreamListener) {}
+
+ @Override
+ public FlightInfo getFlightInfoTableTypes(
+ FlightSql.CommandGetTableTypes commandGetTableTypes,
+ CallContext callContext,
+ FlightDescriptor flightDescriptor) {
+ return null;
+ }
+
+ @Override
+ public void getStreamTableTypes(
+ CallContext callContext, ServerStreamListener serverStreamListener) {}
+
+ @Override
+ public FlightInfo getFlightInfoPrimaryKeys(
+ FlightSql.CommandGetPrimaryKeys commandGetPrimaryKeys,
+ CallContext callContext,
+ FlightDescriptor flightDescriptor) {
+ return null;
+ }
+
+ @Override
+ public void getStreamPrimaryKeys(
+ FlightSql.CommandGetPrimaryKeys commandGetPrimaryKeys,
+ CallContext callContext,
+ ServerStreamListener serverStreamListener) {}
+
+ @Override
+ public FlightInfo getFlightInfoExportedKeys(
+ FlightSql.CommandGetExportedKeys commandGetExportedKeys,
+ CallContext callContext,
+ FlightDescriptor flightDescriptor) {
+ return null;
+ }
+
+ @Override
+ public FlightInfo getFlightInfoImportedKeys(
+ FlightSql.CommandGetImportedKeys commandGetImportedKeys,
+ CallContext callContext,
+ FlightDescriptor flightDescriptor) {
+ return null;
+ }
+
+ @Override
+ public FlightInfo getFlightInfoCrossReference(
+ FlightSql.CommandGetCrossReference commandGetCrossReference,
+ CallContext callContext,
+ FlightDescriptor flightDescriptor) {
+ return null;
+ }
+
+ @Override
+ public void getStreamExportedKeys(
+ FlightSql.CommandGetExportedKeys commandGetExportedKeys,
+ CallContext callContext,
+ ServerStreamListener serverStreamListener) {}
+
+ @Override
+ public void getStreamImportedKeys(
+ FlightSql.CommandGetImportedKeys commandGetImportedKeys,
+ CallContext callContext,
+ ServerStreamListener serverStreamListener) {}
+
+ @Override
+ public void getStreamCrossReference(
+ FlightSql.CommandGetCrossReference commandGetCrossReference,
+ CallContext callContext,
+ ServerStreamListener serverStreamListener) {}
+
+ @Override
+ public void close() throws Exception {}
+
+ @Override
+ public void listFlights(
+ CallContext callContext, Criteria criteria, StreamListener streamListener) {}
+ }
+}
diff --git a/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/PostgresqlQuirks.java b/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/PostgresqlQuirks.java
index ccce7db70d..fce9ff134d 100644
--- a/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/PostgresqlQuirks.java
+++ b/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/PostgresqlQuirks.java
@@ -37,6 +37,9 @@ public class PostgresqlQuirks extends SqlValidationQuirks {
static final String POSTGRESQL_URL_ENV_VAR = "ADBC_JDBC_POSTGRESQL_URL";
static final String POSTGRESQL_USER_ENV_VAR = "ADBC_JDBC_POSTGRESQL_USER";
static final String POSTGRESQL_PASSWORD_ENV_VAR = "ADBC_JDBC_POSTGRESQL_PASSWORD";
+ static final String POSTGRESQL_DATABASE_ENV_VAR = "ADBC_JDBC_POSTGRESQL_DATABASE";
+
+ String catalog = "postgres";
static String makeJdbcUrl() {
final String postgresUrl = System.getenv(POSTGRESQL_URL_ENV_VAR);
@@ -49,12 +52,21 @@ static String makeJdbcUrl() {
return String.format("jdbc:postgresql://%s?user=%s&password=%s", postgresUrl, user, password);
}
+ public Connection getJdbcConnection() throws SQLException {
+ return DriverManager.getConnection(makeJdbcUrl());
+ }
+
@Override
public AdbcDatabase initDatabase(BufferAllocator allocator) throws AdbcException {
String url = makeJdbcUrl();
+ final String catalog = System.getenv(POSTGRESQL_DATABASE_ENV_VAR);
+ Assumptions.assumeFalse(
+ catalog == null, "PostgreSQL catalog not found, set " + POSTGRESQL_DATABASE_ENV_VAR);
+ this.catalog = catalog;
+
final Map parameters = new HashMap<>();
- parameters.put(AdbcDriver.PARAM_URL, url);
+ AdbcDriver.PARAM_URI.set(parameters, url);
parameters.put(JdbcDriver.PARAM_JDBC_QUIRKS, StandardJdbcQuirks.POSTGRESQL);
return new JdbcDriver(allocator).open(parameters);
}
@@ -71,8 +83,12 @@ public void cleanupTable(String name) throws Exception {
@Override
public String defaultCatalog() {
- // XXX: this should really come from configuration
- return "postgres";
+ return catalog;
+ }
+
+ @Override
+ public String defaultDbSchema() {
+ return "public";
}
@Override
@@ -94,4 +110,9 @@ public TimeUnit defaultTimeUnit() {
public TimeUnit defaultTimestampUnit() {
return TimeUnit.MICROSECOND;
}
+
+ @Override
+ public boolean supportsCurrentCatalog() {
+ return true;
+ }
}
diff --git a/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/StatisticsTest.java b/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/StatisticsTest.java
new file mode 100644
index 0000000000..13ca0ee191
--- /dev/null
+++ b/java/driver/jdbc-validation-postgresql/src/test/java/org/apache/arrow/adbc/driver/jdbc/postgresql/StatisticsTest.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.arrow.adbc.driver.jdbc.postgresql;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.entry;
+
+import java.sql.Connection;
+import java.sql.DatabaseMetaData;
+import java.sql.ResultSet;
+import java.sql.ResultSetMetaData;
+import java.sql.Statement;
+import java.util.Map;
+import org.apache.arrow.adbc.core.AdbcConnection;
+import org.apache.arrow.adbc.core.AdbcDatabase;
+import org.apache.arrow.adbc.core.StandardStatistics;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.ipc.ArrowReader;
+import org.apache.arrow.vector.util.Text;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+
+class StatisticsTest {
+ static PostgresqlQuirks quirks;
+
+ @BeforeAll
+ static void beforeAll() {
+ quirks = new PostgresqlQuirks();
+ }
+
+ @Test
+ void adbc() throws Exception {
+ try (Connection connection = quirks.getJdbcConnection();
+ Statement statement = connection.createStatement()) {
+ statement.executeUpdate("DROP TABLE IF EXISTS adbcpkeytest");
+ statement.executeUpdate("CREATE TABLE adbcpkeytest (key SERIAL PRIMARY KEY, value INT)");
+ statement.executeUpdate("INSERT INTO adbcpkeytest (value) VALUES (0), (1), (2)");
+ statement.executeUpdate("ANALYZE adbcpkeytest");
+ }
+
+ try (BufferAllocator allocator = new RootAllocator();
+ AdbcDatabase database = quirks.initDatabase(allocator);
+ AdbcConnection connection = database.connect();
+ ArrowReader reader = connection.getStatistics(null, null, "adbcpkeytest", true)) {
+ assertThat(reader.loadNextBatch()).isTrue();
+ VectorSchemaRoot vsr = reader.getVectorSchemaRoot();
+ assertThat(vsr.getRowCount()).isEqualTo(1);
+
+ ListVector catalogDbSchemas = (ListVector) vsr.getVector(1);
+ assertThat(catalogDbSchemas.getValueCount()).isEqualTo(1);
+
+ StructVector catalogDbSchema = (StructVector) catalogDbSchemas.getDataVector();
+ ListVector dbSchemaStatistics = (ListVector) catalogDbSchema.getVectorById(1);
+ assertThat(dbSchemaStatistics.getValueCount()).isEqualTo(1);
+
+ @SuppressWarnings("unchecked")
+ Map statistic = (Map) dbSchemaStatistics.getObject(0).get(0);
+ assertThat(statistic)
+ .contains(
+ entry("table_name", new Text("adbcpkeytest")),
+ entry("statistic_key", StandardStatistics.DISTINCT_COUNT.getKey()),
+ entry("statistic_value", 3L));
+
+ assertThat(reader.loadNextBatch()).isFalse();
+ }
+ }
+
+ /** Validate what PostgreSQL does. */
+ @Test
+ void jdbc() throws Exception {
+ try (Connection connection = quirks.getJdbcConnection();
+ Statement statement = connection.createStatement()) {
+ statement.executeUpdate("DROP TABLE IF EXISTS adbcpkeytest");
+ statement.executeUpdate("CREATE TABLE adbcpkeytest (key SERIAL PRIMARY KEY, value INT)");
+ statement.executeUpdate("INSERT INTO adbcpkeytest (value) VALUES (0), (1), (2)");
+ statement.executeUpdate("ANALYZE adbcpkeytest");
+
+ int count = 0;
+ try (ResultSet rs =
+ connection.getMetaData().getIndexInfo(null, null, "adbcpkeytest", false, true)) {
+ ResultSetMetaData rsmd = rs.getMetaData();
+ while (rs.next()) {
+ // For debugging
+ for (int i = 1; i <= rsmd.getColumnCount(); i++) {
+ System.out.println(rsmd.getColumnName(i) + " => " + rs.getObject(i));
+ }
+ System.out.println("===");
+
+ // TABLE_NAME
+ assertThat(rs.getString(3)).isEqualTo("adbcpkeytest");
+ // TYPE
+ assertThat(rs.getShort(7)).isEqualTo(DatabaseMetaData.tableIndexOther);
+ // CARDINALITY
+ assertThat(rs.getLong(11)).isEqualTo(3);
+
+ count++;
+ }
+ }
+
+ assertThat(count).isEqualTo(1);
+ }
+ }
+}
diff --git a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/InfoMetadataBuilder.java b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/InfoMetadataBuilder.java
index 02c2ccac22..ae5ec226fc 100644
--- a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/InfoMetadataBuilder.java
+++ b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/InfoMetadataBuilder.java
@@ -25,10 +25,12 @@
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
+import org.apache.arrow.adbc.core.AdbcDriver;
import org.apache.arrow.adbc.core.AdbcInfoCode;
import org.apache.arrow.adbc.core.StandardSchemas;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.UInt4Vector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
@@ -37,6 +39,7 @@
/** Helper class to track state needed to build up the info structure. */
final class InfoMetadataBuilder implements AutoCloseable {
private static final byte STRING_VALUE_TYPE_ID = (byte) 0;
+ private static final byte BIGINT_VALUE_TYPE_ID = (byte) 2;
private static final Map SUPPORTED_CODES = new HashMap<>();
private final Collection requestedCodes;
private final DatabaseMetaData dbmd;
@@ -45,6 +48,7 @@ final class InfoMetadataBuilder implements AutoCloseable {
final UInt4Vector infoCodes;
final DenseUnionVector infoValues;
final VarCharVector stringValues;
+ final BigIntVector bigIntValues;
@FunctionalInterface
interface AddInfo {
@@ -74,6 +78,11 @@ interface AddInfo {
final String driverVersion = b.dbmd.getDriverVersion() + " (ADBC Driver Version 0.0.1)";
b.setStringValue(idx, driverVersion);
});
+ SUPPORTED_CODES.put(
+ AdbcInfoCode.DRIVER_ADBC_VERSION.getValue(),
+ (b, idx) -> {
+ b.setBigIntValue(idx, AdbcDriver.ADBC_VERSION_1_1_0);
+ });
}
InfoMetadataBuilder(BufferAllocator allocator, Connection connection, int[] infoCodes)
@@ -86,7 +95,18 @@ interface AddInfo {
this.dbmd = connection.getMetaData();
this.infoCodes = (UInt4Vector) root.getVector(0);
this.infoValues = (DenseUnionVector) root.getVector(1);
- this.stringValues = this.infoValues.getVarCharVector((byte) 0);
+ this.stringValues = this.infoValues.getVarCharVector(STRING_VALUE_TYPE_ID);
+ this.bigIntValues = this.infoValues.getBigIntVector(BIGINT_VALUE_TYPE_ID);
+ }
+
+ void setBigIntValue(int index, long value) {
+ infoValues.setValueCount(index + 1);
+ infoValues.setTypeId(index, BIGINT_VALUE_TYPE_ID);
+ bigIntValues.setSafe(index, value);
+ infoValues
+ .getOffsetBuffer()
+ .setInt((long) index * DenseUnionVector.OFFSET_WIDTH, bigIntValues.getValueCount());
+ bigIntValues.setValueCount(bigIntValues.getValueCount() + 1);
}
void setStringValue(int index, final String value) {
diff --git a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcArrowReader.java b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcArrowReader.java
index 1ddbf1c88a..aba972a9a2 100644
--- a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcArrowReader.java
+++ b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcArrowReader.java
@@ -42,12 +42,7 @@ public class JdbcArrowReader extends ArrowReader {
JdbcArrowReader(BufferAllocator allocator, ResultSet resultSet, Schema overrideSchema)
throws AdbcException {
super(allocator);
- final JdbcToArrowConfig config =
- new JdbcToArrowConfigBuilder()
- .setAllocator(allocator)
- .setCalendar(JdbcToArrowUtils.getUtcCalendar())
- .setTargetBatchSize(1024)
- .build();
+ final JdbcToArrowConfig config = makeJdbcConfig(allocator);
try {
this.delegate = JdbcToArrow.sqlToArrowVectorIterator(resultSet, config);
} catch (SQLException e) {
@@ -75,6 +70,14 @@ public class JdbcArrowReader extends ArrowReader {
}
}
+ static JdbcToArrowConfig makeJdbcConfig(BufferAllocator allocator) {
+ return new JdbcToArrowConfigBuilder()
+ .setAllocator(allocator)
+ .setCalendar(JdbcToArrowUtils.getUtcCalendar())
+ .setTargetBatchSize(1024)
+ .build();
+ }
+
@Override
public boolean loadNextBatch() {
if (!delegate.hasNext()) return false;
diff --git a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcConnection.java b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcConnection.java
index 398ef6d42e..8f66c154fa 100644
--- a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcConnection.java
+++ b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcConnection.java
@@ -21,7 +21,9 @@
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import org.apache.arrow.adbc.core.AdbcConnection;
import org.apache.arrow.adbc.core.AdbcException;
import org.apache.arrow.adbc.core.AdbcStatement;
@@ -29,15 +31,24 @@
import org.apache.arrow.adbc.core.BulkIngestMode;
import org.apache.arrow.adbc.core.IsolationLevel;
import org.apache.arrow.adbc.core.StandardSchemas;
+import org.apache.arrow.adbc.core.StandardStatistics;
import org.apache.arrow.adbc.driver.jdbc.adapter.JdbcFieldInfoExtra;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.vector.BitVector;
+import org.apache.arrow.vector.SmallIntVector;
+import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.complex.DenseUnionVector;
+import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.holders.NullableBigIntHolder;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.arrow.vector.util.Text;
public class JdbcConnection implements AdbcConnection {
private final BufferAllocator allocator;
@@ -117,6 +128,165 @@ public ArrowReader getObjects(
}
}
+ static final class Statistic {
+ String table;
+ String column;
+ short key;
+ long value;
+ boolean multiColumn = false;
+ }
+
+ @Override
+ public ArrowReader getStatistics(
+ String catalogPattern, String dbSchemaPattern, String tableNamePattern, boolean approximate)
+ throws AdbcException {
+ if (tableNamePattern == null) {
+ throw AdbcException.notImplemented(
+ JdbcDriverUtil.prefixExceptionMessage("getStatistics: must supply table name"));
+ }
+
+ try (final VectorSchemaRoot root =
+ VectorSchemaRoot.create(StandardSchemas.GET_STATISTICS_SCHEMA, allocator);
+ ResultSet rs =
+ connection
+ .getMetaData()
+ .getIndexInfo(
+ catalogPattern,
+ dbSchemaPattern,
+ tableNamePattern, /*unique*/
+ false,
+ approximate)) {
+ // Build up the statistics in-memory and then return a constant reader.
+ // We have to read and sort the data first because the ordering is not by the catalog/etc.
+
+ // {catalog: {schema: {index_name: statistic}}}
+ Map>> allStatistics = new HashMap<>();
+
+ while (rs.next()) {
+ String catalog = rs.getString(1);
+ String schema = rs.getString(2);
+ String table = rs.getString(3);
+ String index = rs.getString(6);
+ short statisticType = rs.getShort(7);
+ String column = rs.getString(9);
+ long cardinality = rs.getLong(11);
+
+ if (!allStatistics.containsKey(catalog)) {
+ allStatistics.put(catalog, new HashMap<>());
+ }
+
+ Map> catalogStats = allStatistics.get(catalog);
+ if (!catalogStats.containsKey(schema)) {
+ catalogStats.put(schema, new HashMap<>());
+ }
+
+ Map schemaStats = catalogStats.get(schema);
+ Statistic statistic = schemaStats.getOrDefault(index, new Statistic());
+ if (schemaStats.containsKey(index)) {
+ // Multi-column index, ignore it
+ statistic.multiColumn = true;
+ continue;
+ }
+
+ statistic.column = column;
+ statistic.table = table;
+ statistic.key =
+ statisticType == DatabaseMetaData.tableIndexStatistic
+ ? StandardStatistics.ROW_COUNT.getKey()
+ : StandardStatistics.DISTINCT_COUNT.getKey();
+ statistic.value = cardinality;
+ schemaStats.put(index, statistic);
+ }
+
+ VarCharVector catalogNames = (VarCharVector) root.getVector(0);
+ ListVector catalogDbSchemas = (ListVector) root.getVector(1);
+ StructVector dbSchemas = (StructVector) catalogDbSchemas.getDataVector();
+ VarCharVector dbSchemaNames = (VarCharVector) dbSchemas.getVectorById(0);
+ ListVector dbSchemaStatistics = (ListVector) dbSchemas.getVectorById(1);
+ StructVector statistics = (StructVector) dbSchemaStatistics.getDataVector();
+ VarCharVector tableNames = (VarCharVector) statistics.getVectorById(0);
+ VarCharVector columnNames = (VarCharVector) statistics.getVectorById(1);
+ SmallIntVector statisticKeys = (SmallIntVector) statistics.getVectorById(2);
+ DenseUnionVector statisticValues = (DenseUnionVector) statistics.getVectorById(3);
+ BitVector statisticIsApproximate = (BitVector) statistics.getVectorById(4);
+
+ // Build up the Arrow result
+ Text text = new Text();
+ NullableBigIntHolder holder = new NullableBigIntHolder();
+ int catalogIndex = 0;
+ int schemaIndex = 0;
+ int statisticIndex = 0;
+ for (String catalog : allStatistics.keySet()) {
+ Map> schemas = allStatistics.get(catalog);
+
+ if (catalog == null) {
+ catalogNames.setNull(catalogIndex);
+ } else {
+ text.set(catalog);
+ catalogNames.setSafe(catalogIndex, text);
+ }
+ catalogDbSchemas.startNewValue(catalogIndex);
+
+ int schemaCount = 0;
+ for (String schema : schemas.keySet()) {
+ if (schema == null) {
+ dbSchemaNames.setNull(schemaIndex);
+ } else {
+ text.set(schema);
+ dbSchemaNames.setSafe(schemaIndex, text);
+ }
+
+ dbSchemaStatistics.startNewValue(schemaIndex);
+
+ Map indices = schemas.get(schema);
+ int statisticCount = 0;
+ for (Statistic statistic : indices.values()) {
+ if (statistic.multiColumn) {
+ continue;
+ }
+
+ text.set(statistic.table);
+ tableNames.setSafe(statisticIndex, text);
+ if (statistic.column == null) {
+ columnNames.setNull(statisticIndex);
+ } else {
+ text.set(statistic.column);
+ columnNames.setSafe(statisticIndex, text);
+ }
+ statisticKeys.setSafe(statisticIndex, statistic.key);
+ statisticValues.setTypeId(statisticIndex, (byte) 0);
+ holder.isSet = 1;
+ holder.value = statistic.value;
+ statisticValues.setSafe(statisticIndex, holder);
+ statisticIsApproximate.setSafe(statisticIndex, approximate ? 1 : 0);
+
+ statistics.setIndexDefined(statisticIndex++);
+ statisticCount++;
+ }
+
+ dbSchemaStatistics.endValue(schemaIndex, statisticCount);
+
+ dbSchemas.setIndexDefined(schemaIndex++);
+ schemaCount++;
+ }
+
+ catalogDbSchemas.endValue(catalogIndex, schemaCount);
+ catalogIndex++;
+ }
+ root.setRowCount(catalogIndex);
+
+ return RootArrowReader.fromRoot(allocator, root);
+ } catch (SQLException e) {
+ throw JdbcDriverUtil.fromSqlException(e);
+ }
+ }
+
+ @Override
+ public ArrowReader getStatisticNames() throws AdbcException {
+ // TODO:
+ return AdbcConnection.super.getStatisticNames();
+ }
+
@Override
public Schema getTableSchema(String catalog, String dbSchema, String tableName)
throws AdbcException {
@@ -211,6 +381,42 @@ public void setAutoCommit(boolean enableAutoCommit) throws AdbcException {
}
}
+ @Override
+ public String getCurrentCatalog() throws AdbcException {
+ try {
+ return connection.getCatalog();
+ } catch (SQLException e) {
+ throw JdbcDriverUtil.fromSqlException(e);
+ }
+ }
+
+ @Override
+ public void setCurrentCatalog(String catalog) throws AdbcException {
+ try {
+ connection.setCatalog(catalog);
+ } catch (SQLException e) {
+ throw JdbcDriverUtil.fromSqlException(e);
+ }
+ }
+
+ @Override
+ public String getCurrentDbSchema() throws AdbcException {
+ try {
+ return connection.getSchema();
+ } catch (SQLException e) {
+ throw JdbcDriverUtil.fromSqlException(e);
+ }
+ }
+
+ @Override
+ public void setCurrentDbSchema(String dbSchema) throws AdbcException {
+ try {
+ connection.setSchema(dbSchema);
+ } catch (SQLException e) {
+ throw JdbcDriverUtil.fromSqlException(e);
+ }
+ }
+
@Override
public boolean getReadOnly() throws AdbcException {
try {
diff --git a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java
index 95b3775f68..fd39e6d08b 100644
--- a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java
+++ b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java
@@ -30,6 +30,7 @@
import java.util.stream.LongStream;
import org.apache.arrow.adapter.jdbc.JdbcFieldInfo;
import org.apache.arrow.adapter.jdbc.JdbcParameterBinder;
+import org.apache.arrow.adapter.jdbc.JdbcToArrowConfig;
import org.apache.arrow.adapter.jdbc.JdbcToArrowUtils;
import org.apache.arrow.adbc.core.AdbcException;
import org.apache.arrow.adbc.core.AdbcStatement;
@@ -263,6 +264,41 @@ public QueryResult executeQuery() throws AdbcException {
return new QueryResult(/*affectedRows=*/ -1, reader);
}
+ @Override
+ public Schema executeSchema() throws AdbcException {
+ if (bulkOperation != null) {
+ throw AdbcException.invalidState("[JDBC] Call executeUpdate() for bulk operations");
+ } else if (sqlQuery == null) {
+ throw AdbcException.invalidState("[JDBC] Must setSqlQuery() first");
+ }
+ try {
+ invalidatePriorQuery();
+ final PreparedStatement preparedStatement;
+ final PreparedStatement ownedStatement;
+ if (statement instanceof PreparedStatement) {
+ preparedStatement = (PreparedStatement) statement;
+ if (bindRoot != null) {
+ JdbcParameterBinder.builder(preparedStatement, bindRoot).bindAll().build().next();
+ }
+ ownedStatement = null;
+ } else {
+ // new statement
+ preparedStatement = connection.prepareStatement(sqlQuery);
+ ownedStatement = preparedStatement;
+ }
+
+ final JdbcToArrowConfig config = JdbcArrowReader.makeJdbcConfig(allocator);
+ final Schema schema =
+ JdbcToArrowUtils.jdbcToArrowSchema(preparedStatement.getMetaData(), config);
+ if (ownedStatement != null) {
+ ownedStatement.close();
+ }
+ return schema;
+ } catch (SQLException e) {
+ throw JdbcDriverUtil.fromSqlException(e);
+ }
+ }
+
@Override
public Schema getParameterSchema() throws AdbcException {
if (statement instanceof PreparedStatement) {
diff --git a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionTest.java b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionTest.java
index 9915636d69..54e6059046 100644
--- a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionTest.java
+++ b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionTest.java
@@ -18,6 +18,7 @@
package org.apache.arrow.adbc.driver.testsuite;
import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assumptions.assumeThat;
import org.apache.arrow.adbc.core.AdbcConnection;
import org.apache.arrow.adbc.core.AdbcDatabase;
@@ -48,6 +49,19 @@ public void afterEach() throws Exception {
AutoCloseables.close(connection, database, allocator);
}
+ @Test
+ void currentCatalog() throws Exception {
+ assumeThat(quirks.supportsCurrentCatalog()).isTrue();
+
+ assertThat(connection.getCurrentCatalog()).isEqualTo(quirks.defaultCatalog());
+ connection.setCurrentCatalog(quirks.defaultCatalog());
+ assertThat(connection.getCurrentCatalog()).isEqualTo(quirks.defaultCatalog());
+
+ assertThat(connection.getCurrentDbSchema()).isEqualTo(quirks.defaultDbSchema());
+ connection.setCurrentDbSchema(quirks.defaultDbSchema());
+ assertThat(connection.getCurrentDbSchema()).isEqualTo(quirks.defaultDbSchema());
+ }
+
@Test
void multipleConnections() throws Exception {
try (final AdbcConnection ignored = database.connect()) {}
diff --git a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java
index e7a1a5743a..4d9184a4bb 100644
--- a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java
+++ b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java
@@ -19,6 +19,7 @@
import static org.apache.arrow.adbc.driver.testsuite.ArrowAssertions.assertField;
import static org.apache.arrow.adbc.driver.testsuite.ArrowAssertions.assertRoot;
+import static org.apache.arrow.adbc.driver.testsuite.ArrowAssertions.assertSchema;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;
@@ -239,6 +240,62 @@ public void bulkIngestCreateConflict() throws Exception {
}
}
+ @Test
+ public void executeSchema() throws Exception {
+ util.ingestTableIntsStrs(allocator, connection, tableName);
+ final String name = quirks.caseFoldColumnName("STRS");
+ try (final AdbcStatement stmt = connection.createStatement()) {
+ stmt.setSqlQuery("SELECT " + name + " FROM " + tableName);
+ final Schema actualSchema = stmt.executeSchema();
+ assertSchema(actualSchema)
+ .isEqualTo(
+ new Schema(
+ Collections.singletonList(
+ Field.nullable(name, Types.MinorType.VARCHAR.getType()))));
+ }
+ }
+
+ @Test
+ public void executeSchemaPrepared() throws Exception {
+ util.ingestTableIntsStrs(allocator, connection, tableName);
+ final String name = quirks.caseFoldColumnName("STRS");
+ try (final AdbcStatement stmt = connection.createStatement()) {
+ stmt.setSqlQuery("SELECT " + name + " FROM " + tableName);
+ stmt.prepare();
+ final Schema actualSchema = stmt.executeSchema();
+ assertSchema(actualSchema)
+ .isEqualTo(
+ new Schema(
+ Collections.singletonList(
+ Field.nullable(name, Types.MinorType.VARCHAR.getType()))));
+ }
+ }
+
+ @Test
+ public void executeSchemaParams() throws Exception {
+ try (final AdbcStatement stmt = connection.createStatement()) {
+ stmt.setSqlQuery("SELECT ? AS FOO");
+ stmt.prepare();
+ Schema actualSchema = stmt.executeSchema();
+ // Actual type unknown
+ assertThat(actualSchema.getFields().size()).isEqualTo(1);
+
+ final Schema schema =
+ new Schema(
+ Collections.singletonList(
+ Field.nullable(
+ quirks.caseFoldColumnName("foo"), Types.MinorType.VARCHAR.getType())));
+ try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ ((VarCharVector) root.getVector(0)).setSafe(0, "foo".getBytes(StandardCharsets.UTF_8));
+ root.setRowCount(1);
+ stmt.bind(root);
+
+ actualSchema = stmt.executeSchema();
+ assertSchema(actualSchema).isEqualTo(schema);
+ }
+ }
+ }
+
@Test
public void prepareQuery() throws Exception {
final Schema expectedSchema = util.ingestTableIntsStrs(allocator, connection, tableName);
diff --git a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlValidationQuirks.java b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlValidationQuirks.java
index 120ecab255..a5da97f658 100644
--- a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlValidationQuirks.java
+++ b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlValidationQuirks.java
@@ -35,6 +35,11 @@ public void cleanupTable(String name) throws Exception {}
/** Get the name of the default catalog. */
public abstract String defaultCatalog();
+ /** Get the name of the default schema. */
+ public String defaultDbSchema() {
+ return "";
+ }
+
/** Normalize a table name. */
public String caseFoldTableName(String name) {
return name;
@@ -110,4 +115,8 @@ public ArrowType defaultTimeType() {
public TimeUnit defaultTimestampUnit() {
return TimeUnit.MILLISECOND;
}
+
+ public boolean supportsCurrentCatalog() {
+ return false;
+ }
}