diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc index fe14bac6ef..b60cc43b23 100644 --- a/c/driver/postgresql/statement.cc +++ b/c/driver/postgresql/statement.cc @@ -626,15 +626,19 @@ int TupleReader::GetNext(struct ArrowArray* out) { // Check the server-side response result_ = PQgetResult(conn_); - const int pq_status = PQresultStatus(result_); + const ExecStatusType pq_status = PQresultStatus(result_); if (pq_status != PGRES_COMMAND_OK) { - StringBuilderAppend(&error_builder_, "[libpq] Query failed [%d]: %s", pq_status, - PQresultErrorMessage(result_)); + const char* sqlstate = PQresultErrorField(result_, PG_DIAG_SQLSTATE); + StringBuilderAppend(&error_builder_, "[libpq] Query failed [%s]: %s", + PQresStatus(pq_status), PQresultErrorMessage(result_)); if (tmp.release != nullptr) { tmp.release(&tmp); } + if (sqlstate != nullptr && std::strcmp(sqlstate, "57014") == 0) { + return ECANCELED; + } return EIO; } @@ -1038,7 +1042,7 @@ AdbcStatusCode PostgresStatement::GetOption(const char* key, char* value, size_t } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) { result = std::to_string(reader_.batch_size_hint_bytes_); } else { - SetError(error, "[libq] Unknown statement option '%s'", key); + SetError(error, "[libpq] Unknown statement option '%s'", key); return ADBC_STATUS_NOT_FOUND; } @@ -1052,13 +1056,13 @@ AdbcStatusCode PostgresStatement::GetOption(const char* key, char* value, size_t AdbcStatusCode PostgresStatement::GetOptionBytes(const char* key, uint8_t* value, size_t* length, struct AdbcError* error) { - SetError(error, "[libq] Unknown statement option '%s'", key); + SetError(error, "[libpq] Unknown statement option '%s'", key); return ADBC_STATUS_NOT_FOUND; } AdbcStatusCode PostgresStatement::GetOptionDouble(const char* key, double* value, struct AdbcError* error) { - SetError(error, "[libq] Unknown statement option '%s'", key); + SetError(error, "[libpq] Unknown statement option '%s'", key); return ADBC_STATUS_NOT_FOUND; } @@ -1069,7 +1073,7 @@ AdbcStatusCode PostgresStatement::GetOptionInt(const char* key, int64_t* value, *value = reader_.batch_size_hint_bytes_; return ADBC_STATUS_OK; } - SetError(error, "[libq] Unknown statement option '%s'", key); + SetError(error, "[libpq] Unknown statement option '%s'", key); return ADBC_STATUS_NOT_FOUND; } @@ -1133,7 +1137,7 @@ AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value, this->reader_.batch_size_hint_bytes_ = int_value; } else { - SetError(error, "[libq] Unknown statement option '%s'", key); + SetError(error, "[libpq] Unknown statement option '%s'", key); return ADBC_STATUS_NOT_IMPLEMENTED; } return ADBC_STATUS_OK; @@ -1141,13 +1145,13 @@ AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value, AdbcStatusCode PostgresStatement::SetOptionBytes(const char* key, const uint8_t* value, size_t length, struct AdbcError* error) { - SetError(error, "%s%s", "[libpq] Unknown option ", key); + SetError(error, "%s%s", "[libpq] Unknown statement option ", key); return ADBC_STATUS_NOT_IMPLEMENTED; } AdbcStatusCode PostgresStatement::SetOptionDouble(const char* key, double value, struct AdbcError* error) { - SetError(error, "%s%s", "[libpq] Unknown option ", key); + SetError(error, "%s%s", "[libpq] Unknown statement option ", key); return ADBC_STATUS_NOT_IMPLEMENTED; } @@ -1162,7 +1166,7 @@ AdbcStatusCode PostgresStatement::SetOptionInt(const char* key, int64_t value, this->reader_.batch_size_hint_bytes_ = value; return ADBC_STATUS_OK; } - SetError(error, "%s%s", "[libpq] Unknown option ", key); + SetError(error, "[libpq] Unknown statement option '%s'", key); return ADBC_STATUS_NOT_IMPLEMENTED; } diff --git a/docs/source/python/api/adbc_driver_manager.rst b/docs/source/python/api/adbc_driver_manager.rst index c0d22b62ec..7023af6ace 100644 --- a/docs/source/python/api/adbc_driver_manager.rst +++ b/docs/source/python/api/adbc_driver_manager.rst @@ -31,9 +31,11 @@ Constants & Enums .. autoclass:: adbc_driver_manager.AdbcStatusCode :members: + :undoc-members: .. autoclass:: adbc_driver_manager.GetObjectsDepth :members: + :undoc-members: .. autoclass:: adbc_driver_manager.ConnectionOptions :members: diff --git a/python/adbc_driver_manager/adbc_driver_manager/__init__.py b/python/adbc_driver_manager/adbc_driver_manager/__init__.py index e2eaee5701..25b821eb80 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/__init__.py +++ b/python/adbc_driver_manager/adbc_driver_manager/__init__.py @@ -90,6 +90,8 @@ class DatabaseOptions(enum.Enum): #: Set the password to use for username-password authentication. PASSWORD = "password" + #: The URI to connect to. + URI = "uri" #: Set the username to use for username-password authentication. USERNAME = "username" @@ -100,6 +102,10 @@ class ConnectionOptions(enum.Enum): Not all drivers support all options. """ + #: Get/set the current catalog. + CURRENT_CATALOG = "adbc.connection.catalog" + #: Get/set the current schema. + CURRENT_DB_SCHEMA = "adbc.connection.db_schema" #: Set the transaction isolation level. ISOLATION_LEVEL = "adbc.connection.transaction.isolation_level" @@ -110,7 +116,11 @@ class StatementOptions(enum.Enum): Not all drivers support all options. """ + #: Enable incremental execution on ExecutePartitions. + INCREMENTAL = "adbc.statement.exec.incremental" #: For bulk ingestion, whether to create or append to the table. INGEST_MODE = INGEST_OPTION_MODE #: For bulk ingestion, the table to ingest into. INGEST_TARGET_TABLE = INGEST_OPTION_TARGET_TABLE + #: Get progress of a query. + PROGRESS = "adbc.statement.exec.progress" diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi index 8f107369ea..7723df1772 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi +++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi @@ -26,15 +26,22 @@ import typing INGEST_OPTION_MODE: str INGEST_OPTION_MODE_APPEND: str INGEST_OPTION_MODE_CREATE: str +INGEST_OPTION_MODE_CREATE_APPEND: str +INGEST_OPTION_MODE_REPLACE: str INGEST_OPTION_TARGET_TABLE: str class AdbcConnection(_AdbcHandle): def __init__(self, database: "AdbcDatabase", **kwargs: str) -> None: ... + def cancel(self) -> None: ... def close(self) -> None: ... def commit(self) -> None: ... def get_info( self, info_codes: Optional[List[Union[int, "AdbcInfoCode"]]] = None ) -> "ArrowArrayStreamHandle": ... + def get_option(self, key: str) -> str: ... + def get_option_bytes(self, key: str) -> bytes: ... + def get_option_float(self, key: str) -> float: ... + def get_option_int(self, key: str) -> int: ... def get_objects( self, depth: "GetObjectsDepth", @@ -54,12 +61,16 @@ class AdbcConnection(_AdbcHandle): def read_partition(self, partition: bytes) -> "ArrowArrayStreamHandle": ... def rollback(self) -> None: ... def set_autocommit(self, enabled: bool) -> None: ... - def set_options(self, **kwargs: str) -> None: ... + def set_options(self, **kwargs: Union[bytes, float, int, str]) -> None: ... class AdbcDatabase(_AdbcHandle): def __init__(self, **kwargs: str) -> None: ... def close(self) -> None: ... - def set_options(self, **kwargs: str) -> None: ... + def get_option(self, key: str) -> str: ... + def get_option_bytes(self, key: str) -> bytes: ... + def get_option_float(self, key: str) -> float: ... + def get_option_int(self, key: str) -> int: ... + def set_options(self, **kwargs: Union[bytes, float, int, str]) -> None: ... class AdbcInfoCode(enum.IntEnum): DRIVER_ARROW_VERSION = ... @@ -73,13 +84,19 @@ class AdbcStatement(_AdbcHandle): def __init__(self, *args, **kwargs) -> None: ... def bind(self, *args, **kwargs) -> Any: ... def bind_stream(self, *args, **kwargs) -> Any: ... + def cancel(self) -> None: ... def close(self) -> None: ... def execute_partitions(self, *args, **kwargs) -> Any: ... def execute_query(self, *args, **kwargs) -> Any: ... + def execute_schema(self) -> "ArrowSchemaHandle": ... def execute_update(self, *args, **kwargs) -> Any: ... + def get_option(self, key: str) -> str: ... + def get_option_bytes(self, key: str) -> bytes: ... + def get_option_float(self, key: str) -> float: ... + def get_option_int(self, key: str) -> int: ... def get_parameter_schema(self, *args, **kwargs) -> Any: ... def prepare(self, *args, **kwargs) -> Any: ... - def set_options(self, *args, **kwargs) -> Any: ... + def set_options(self, **kwargs: Union[bytes, float, int, str]) -> None: ... def set_sql_query(self, *args, **kwargs) -> Any: ... def set_substrait_plan(self, *args, **kwargs) -> Any: ... def __reduce__(self) -> Any: ... diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx index 406d577887..a5ccc23be6 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx +++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx @@ -69,6 +69,8 @@ cdef extern from "adbc.h" nogil: cdef const char* ADBC_INGEST_OPTION_MODE cdef const char* ADBC_INGEST_OPTION_MODE_APPEND cdef const char* ADBC_INGEST_OPTION_MODE_CREATE + cdef const char* ADBC_INGEST_OPTION_MODE_REPLACE + cdef const char* ADBC_INGEST_OPTION_MODE_CREATE_APPEND cdef int ADBC_OBJECT_DEPTH_ALL cdef int ADBC_OBJECT_DEPTH_CATALOGS @@ -112,11 +114,22 @@ cdef extern from "adbc.h" nogil: CAdbcPartitionsRelease release CAdbcStatusCode AdbcDatabaseNew(CAdbcDatabase* database, CAdbcError* error) + CAdbcStatusCode AdbcDatabaseGetOption( + CAdbcDatabase*, const char*, char*, size_t*, CAdbcError*); + CAdbcStatusCode AdbcDatabaseGetOptionBytes( + CAdbcDatabase*, const char*, uint8_t*, size_t*, CAdbcError*); + CAdbcStatusCode AdbcDatabaseGetOptionDouble( + CAdbcDatabase*, const char*, double*, CAdbcError*); + CAdbcStatusCode AdbcDatabaseGetOptionInt( + CAdbcDatabase*, const char*, int64_t*, CAdbcError*); CAdbcStatusCode AdbcDatabaseSetOption( - CAdbcDatabase* database, - const char* key, - const char* value, - CAdbcError* error) + CAdbcDatabase*, const char*, const char*, CAdbcError*) + CAdbcStatusCode AdbcDatabaseSetOptionBytes( + CAdbcDatabase*, const char*, const uint8_t*, size_t, CAdbcError*) + CAdbcStatusCode AdbcDatabaseSetOptionDouble( + CAdbcDatabase*, const char*, double, CAdbcError*) + CAdbcStatusCode AdbcDatabaseSetOptionInt( + CAdbcDatabase*, const char*, int64_t, CAdbcError*) CAdbcStatusCode AdbcDatabaseInit(CAdbcDatabase* database, CAdbcError* error) CAdbcStatusCode AdbcDatabaseRelease(CAdbcDatabase* database, CAdbcError* error) @@ -126,6 +139,7 @@ cdef extern from "adbc.h" nogil: CAdbcDriverInitFunc init_func, CAdbcError* error) + CAdbcStatusCode AdbcConnectionCancel(CAdbcConnection*, CAdbcError*) CAdbcStatusCode AdbcConnectionCommit( CAdbcConnection* connection, CAdbcError* error) @@ -154,6 +168,19 @@ cdef extern from "adbc.h" nogil: const char* column_name, CArrowArrayStream* stream, CAdbcError* error) + CAdbcStatusCode AdbcConnectionGetOption( + CAdbcConnection*, const char*, char*, size_t*, CAdbcError*); + CAdbcStatusCode AdbcConnectionGetOptionBytes( + CAdbcConnection*, const char*, uint8_t*, size_t*, CAdbcError*); + CAdbcStatusCode AdbcConnectionGetOptionDouble( + CAdbcConnection*, const char*, double*, CAdbcError*); + CAdbcStatusCode AdbcConnectionGetOptionInt( + CAdbcConnection*, const char*, int64_t*, CAdbcError*); + CAdbcStatusCode AdbcConnectionGetStatistics( + CAdbcConnection*, const char*, const char*, const char*, + char, CArrowArrayStream*, CAdbcError*); + CAdbcStatusCode AdbcConnectionGetStatisticNames( + CAdbcConnection*, CArrowArrayStream*, CAdbcError*); CAdbcStatusCode AdbcConnectionGetTableSchema( CAdbcConnection* connection, const char* catalog, @@ -172,20 +199,24 @@ cdef extern from "adbc.h" nogil: CAdbcStatusCode AdbcConnectionNew( CAdbcConnection* connection, CAdbcError* error) - CAdbcStatusCode AdbcConnectionSetOption( - CAdbcConnection* connection, - const char* key, - const char* value, - CAdbcError* error) CAdbcStatusCode AdbcConnectionRelease( CAdbcConnection* connection, CAdbcError* error) - + CAdbcStatusCode AdbcConnectionSetOption( + CAdbcConnection*, const char*, const char*, CAdbcError*) + CAdbcStatusCode AdbcConnectionSetOptionBytes( + CAdbcConnection*, const char*, const uint8_t*, size_t, CAdbcError*) + CAdbcStatusCode AdbcConnectionSetOptionDouble( + CAdbcConnection*, const char*, double, CAdbcError*) + CAdbcStatusCode AdbcConnectionSetOptionInt( + CAdbcConnection*, const char*, int64_t, CAdbcError*) CAdbcStatusCode AdbcStatementBind( CAdbcStatement* statement, CArrowArray*, CArrowSchema*, CAdbcError* error) + + CAdbcStatusCode AdbcStatementCancel(CAdbcStatement*, CAdbcError*) CAdbcStatusCode AdbcStatementBindStream( CAdbcStatement* statement, CArrowArrayStream*, @@ -199,6 +230,16 @@ cdef extern from "adbc.h" nogil: CAdbcStatement* statement, CArrowArrayStream* out, int64_t* rows_affected, CAdbcError* error) + CAdbcStatusCode AdbcStatementExecuteSchema( + CAdbcStatement*, CArrowSchema*, CAdbcError*) + CAdbcStatusCode AdbcStatementGetOption( + CAdbcStatement*, const char*, char*, size_t*, CAdbcError*); + CAdbcStatusCode AdbcStatementGetOptionBytes( + CAdbcStatement*, const char*, uint8_t*, size_t*, CAdbcError*); + CAdbcStatusCode AdbcStatementGetOptionDouble( + CAdbcStatement*, const char*, double*, CAdbcError*); + CAdbcStatusCode AdbcStatementGetOptionInt( + CAdbcStatement*, const char*, int64_t*, CAdbcError*); CAdbcStatusCode AdbcStatementGetParameterSchema( CAdbcStatement* statement, CArrowSchema* schema, @@ -211,10 +252,13 @@ cdef extern from "adbc.h" nogil: CAdbcStatement* statement, CAdbcError* error) CAdbcStatusCode AdbcStatementSetOption( - CAdbcStatement* statement, - const char* key, - const char* value, - CAdbcError* error) + CAdbcStatement*, const char*, const char*, CAdbcError*) + CAdbcStatusCode AdbcStatementSetOptionBytes( + CAdbcStatement*, const char*, const uint8_t*, size_t, CAdbcError*) + CAdbcStatusCode AdbcStatementSetOptionDouble( + CAdbcStatement*, const char*, double, CAdbcError*) + CAdbcStatusCode AdbcStatementSetOptionInt( + CAdbcStatement*, const char*, int64_t, CAdbcError*) CAdbcStatusCode AdbcStatementSetSqlQuery( CAdbcStatement* statement, const char* query, @@ -348,6 +392,8 @@ NotSupportedError.__module__ = "adbc_driver_manager" INGEST_OPTION_MODE = ADBC_INGEST_OPTION_MODE.decode("utf-8") INGEST_OPTION_MODE_APPEND = ADBC_INGEST_OPTION_MODE_APPEND.decode("utf-8") INGEST_OPTION_MODE_CREATE = ADBC_INGEST_OPTION_MODE_CREATE.decode("utf-8") +INGEST_OPTION_MODE_REPLACE = ADBC_INGEST_OPTION_MODE_REPLACE.decode("utf-8") +INGEST_OPTION_MODE_CREATE_APPEND = ADBC_INGEST_OPTION_MODE_CREATE_APPEND.decode("utf-8") INGEST_OPTION_TARGET_TABLE = ADBC_INGEST_OPTION_TARGET_TABLE.decode("utf-8") @@ -521,6 +567,11 @@ class GetObjectsDepth(enum.IntEnum): COLUMNS = ADBC_OBJECT_DEPTH_COLUMNS +# Assume a driver won't return more than 128 MiB of option data at +# once. +_MAX_OPTION_SIZE = 2**27 + + cdef class AdbcDatabase(_AdbcHandle): """ An instance of a database. @@ -581,15 +632,102 @@ cdef class AdbcDatabase(_AdbcHandle): status = AdbcDatabaseRelease(&self.database, &c_error) check_error(status, &c_error) + def get_option(self, key: str) -> str: + """Get the value of a string option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = key.encode("utf-8") + cdef char* c_key = key_bytes + cdef uint8_t* c_value = NULL + cdef size_t c_len = 0 + + buf = bytearray(1024) + while True: + c_value = buf + c_len = len(buf) + check_error( + AdbcDatabaseGetOption( + &self.database, c_key, buf, &c_len, &c_error), + &c_error) + if c_len <= len(buf): + # Entire value read + break + else: + # Buffer too small + new_len = len(buf) * 2 + if new_len > _MAX_OPTION_SIZE: + raise RuntimeError( + f"Could not read option {key}: " + f"would need more than {len(buf)} bytes") + buf = bytearray(new_len) + + # Remove trailing null terminator + if c_len > 0: + c_len -= 1 + return buf[:c_len].decode("utf-8") + + def get_option_bytes(self, key: str) -> bytes: + """Get the value of a binary option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = key.encode("utf-8") + cdef char* c_key = key_bytes + cdef uint8_t* c_value = NULL + cdef size_t c_len = 0 + + buf = bytearray(1024) + while True: + c_value = buf + c_len = len(buf) + check_error( + AdbcDatabaseGetOptionBytes( + &self.database, c_key, buf, &c_len, &c_error), + &c_error) + if c_len <= len(buf): + # Entire value read + break + else: + # Buffer too small + new_len = len(buf) * 2 + if new_len > _MAX_OPTION_SIZE: + raise RuntimeError( + f"Could not read option {key}: " + f"would need more than {len(buf)} bytes") + buf = bytearray(new_len) + + return bytes(buf[:c_len]) + + def get_option_float(self, key: str) -> float: + """Get the value of a floating-point option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = key.encode("utf-8") + cdef char* c_key = key_bytes + cdef double c_value = 0.0 + check_error( + AdbcDatabaseGetOptionDouble( + &self.database, c_key, &c_value, &c_error), + &c_error) + return c_value + + def get_option_int(self, key: str) -> int: + """Get the value of an integer option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = key.encode("utf-8") + cdef char* c_key = key_bytes + cdef int64_t c_value = 0 + check_error( + AdbcDatabaseGetOptionInt( + &self.database, c_key, &c_value, &c_error), + &c_error) + return c_value + def set_options(self, **kwargs) -> None: - """Set arbitrary key-value options. + """ + Set arbitrary key-value options. Note, not all drivers support setting options after creation. See Also -------- adbc_driver_manager.DatabaseOptions : Standard option names. - """ cdef CAdbcError c_error = empty_error() cdef char* c_key = NULL @@ -600,12 +738,28 @@ cdef class AdbcDatabase(_AdbcHandle): if value is None: c_value = NULL - else: + status = AdbcDatabaseSetOption( + &self.database, c_key, c_value, &c_error) + elif isinstance(value, str): value = value.encode("utf-8") c_value = value + status = AdbcDatabaseSetOption( + &self.database, c_key, c_value, &c_error) + elif isinstance(value, bytes): + c_value = value + status = AdbcDatabaseSetOptionBytes( + &self.database, c_key, c_value, len(value), &c_error) + elif isinstance(value, float): + status = AdbcDatabaseSetOptionDouble( + &self.database, c_key, value, &c_error) + elif isinstance(value, int): + status = AdbcDatabaseSetOptionInt( + &self.database, c_key, value, &c_error) + else: + raise ValueError( + f"Unsupported type {type(value)} for value {value!r} " + f"of option {key}") - status = AdbcDatabaseSetOption( - &self.database, c_key, c_value, &c_error) check_error(status, &c_error) @@ -659,6 +813,14 @@ cdef class AdbcConnection(_AdbcHandle): database._open_child() + def cancel(self) -> None: + """Attempt to cancel any ongoing operations on the connection.""" + cdef CAdbcError c_error = empty_error() + cdef CAdbcStatusCode status + with nogil: + status = AdbcConnectionCancel(&self.connection, &c_error) + check_error(status, &c_error) + def commit(self) -> None: """Commit the current transaction.""" cdef CAdbcError c_error = empty_error() @@ -747,6 +909,93 @@ cdef class AdbcConnection(_AdbcHandle): return stream + def get_option(self, key: str) -> str: + """Get the value of a string option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = key.encode("utf-8") + cdef char* c_key = key_bytes + cdef uint8_t* c_value = NULL + cdef size_t c_len = 0 + + buf = bytearray(1024) + while True: + c_value = buf + c_len = len(buf) + check_error( + AdbcConnectionGetOption( + &self.connection, c_key, buf, &c_len, &c_error), + &c_error) + if c_len <= len(buf): + # Entire value read + break + else: + # Buffer too small + new_len = len(buf) * 2 + if new_len > _MAX_OPTION_SIZE: + raise RuntimeError( + f"Could not read option {key}: " + f"would need more than {len(buf)} bytes") + buf = bytearray(new_len) + + # Remove trailing null terminator + if c_len > 0: + c_len -= 1 + return buf[:c_len].decode("utf-8") + + def get_option_bytes(self, key: str) -> bytes: + """Get the value of a binary option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = key.encode("utf-8") + cdef char* c_key = key_bytes + cdef uint8_t* c_value = NULL + cdef size_t c_len = 0 + + buf = bytearray(1024) + while True: + c_value = buf + c_len = len(buf) + check_error( + AdbcConnectionGetOptionBytes( + &self.connection, c_key, buf, &c_len, &c_error), + &c_error) + if c_len <= len(buf): + # Entire value read + break + else: + # Buffer too small + new_len = len(buf) * 2 + if new_len > _MAX_OPTION_SIZE: + raise RuntimeError( + f"Could not read option {key}: " + f"would need more than {len(buf)} bytes") + buf = bytearray(new_len) + + return bytes(buf[:c_len]) + + def get_option_float(self, key: str) -> float: + """Get the value of a floating-point option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = key.encode("utf-8") + cdef char* c_key = key_bytes + cdef double c_value = 0.0 + check_error( + AdbcConnectionGetOptionDouble( + &self.connection, c_key, &c_value, &c_error), + &c_error) + return c_value + + def get_option_int(self, key: str) -> int: + """Get the value of an integer option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = key.encode("utf-8") + cdef char* c_key = key_bytes + cdef int64_t c_value = 0 + check_error( + AdbcConnectionGetOptionInt( + &self.connection, c_key, &c_value, &c_error), + &c_error) + return c_value + def get_table_schema(self, catalog, db_schema, table_name) -> ArrowSchemaHandle: """ Get the Arrow schema of a table. @@ -854,12 +1103,28 @@ cdef class AdbcConnection(_AdbcHandle): if value is None: c_value = NULL - else: + status = AdbcConnectionSetOption( + &self.connection, c_key, c_value, &c_error) + elif isinstance(value, str): value = value.encode("utf-8") c_value = value + status = AdbcConnectionSetOption( + &self.connection, c_key, c_value, &c_error) + elif isinstance(value, bytes): + c_value = value + status = AdbcConnectionSetOptionBytes( + &self.connection, c_key, c_value, len(value), &c_error) + elif isinstance(value, float): + status = AdbcConnectionSetOptionDouble( + &self.connection, c_key, value, &c_error) + elif isinstance(value, int): + status = AdbcConnectionSetOptionInt( + &self.connection, c_key, value, &c_error) + else: + raise ValueError( + f"Unsupported type {type(value)} for value {value!r} " + f"of option {key}") - status = AdbcConnectionSetOption( - &self.connection, c_key, c_value, &c_error) check_error(status, &c_error) def close(self) -> None: @@ -970,7 +1235,16 @@ cdef class AdbcStatement(_AdbcHandle): &c_error) check_error(status, &c_error) + def cancel(self) -> None: + """Attempt to cancel any ongoing operations on the connection.""" + cdef CAdbcError c_error = empty_error() + cdef CAdbcStatusCode status + with nogil: + status = AdbcStatementCancel(&self.statement, &c_error) + check_error(status, &c_error) + def close(self) -> None: + """Release the handle to the statement.""" cdef CAdbcError c_error = empty_error() cdef CAdbcStatusCode status self.connection._close_child() @@ -1044,6 +1318,25 @@ cdef class AdbcStatement(_AdbcHandle): return (partitions, schema, rows_affected) + def execute_schema(self) -> ArrowSchemaHandle: + """ + Get the schema of the result set without executing the query. + + Returns + ------- + ArrowSchemaHandle + The schema of the result set. + """ + cdef CAdbcError c_error = empty_error() + cdef ArrowSchemaHandle schema = ArrowSchemaHandle() + with nogil: + status = AdbcStatementExecuteSchema( + &self.statement, + &schema.schema, + &c_error) + check_error(status, &c_error) + return schema + def execute_update(self) -> int: """ Execute the query without a result set. @@ -1064,6 +1357,93 @@ cdef class AdbcStatement(_AdbcHandle): check_error(status, &c_error) return rows_affected + def get_option(self, key: str) -> str: + """Get the value of a string option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = key.encode("utf-8") + cdef char* c_key = key_bytes + cdef uint8_t* c_value = NULL + cdef size_t c_len = 0 + + buf = bytearray(1024) + while True: + c_value = buf + c_len = len(buf) + check_error( + AdbcStatementGetOption( + &self.statement, c_key, buf, &c_len, &c_error), + &c_error) + if c_len <= len(buf): + # Entire value read + break + else: + # Buffer too small + new_len = len(buf) * 2 + if new_len > _MAX_OPTION_SIZE: + raise RuntimeError( + f"Could not read option {key}: " + f"would need more than {len(buf)} bytes") + buf = bytearray(new_len) + + # Remove trailing null terminator + if c_len > 0: + c_len -= 1 + return buf[:c_len].decode("utf-8") + + def get_option_bytes(self, key: str) -> bytes: + """Get the value of a binary option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = key.encode("utf-8") + cdef char* c_key = key_bytes + cdef uint8_t* c_value = NULL + cdef size_t c_len = 0 + + buf = bytearray(1024) + while True: + c_value = buf + c_len = len(buf) + check_error( + AdbcStatementGetOptionBytes( + &self.statement, c_key, buf, &c_len, &c_error), + &c_error) + if c_len <= len(buf): + # Entire value read + break + else: + # Buffer too small + new_len = len(buf) * 2 + if new_len > _MAX_OPTION_SIZE: + raise RuntimeError( + f"Could not read option {key}: " + f"would need more than {len(buf)} bytes") + buf = bytearray(new_len) + + return bytes(buf[:c_len]) + + def get_option_float(self, key: str) -> float: + """Get the value of a floating-point option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = key.encode("utf-8") + cdef char* c_key = key_bytes + cdef double c_value = 0.0 + check_error( + AdbcStatementGetOptionDouble( + &self.statement, c_key, &c_value, &c_error), + &c_error) + return c_value + + def get_option_int(self, key: str) -> int: + """Get the value of an integer option.""" + cdef CAdbcError c_error = empty_error() + key_bytes = key.encode("utf-8") + cdef char* c_key = key_bytes + cdef int64_t c_value = 0 + check_error( + AdbcStatementGetOptionInt( + &self.statement, c_key, &c_value, &c_error), + &c_error) + return c_value + def get_parameter_schema(self) -> ArrowSchemaHandle: """Get the Arrow schema for bound parameters. @@ -1104,7 +1484,8 @@ cdef class AdbcStatement(_AdbcHandle): check_error(status, &c_error) def set_options(self, **kwargs) -> None: - """Set arbitrary key-value options. + """ + Set arbitrary key-value options. See Also -------- @@ -1119,12 +1500,28 @@ cdef class AdbcStatement(_AdbcHandle): if value is None: c_value = NULL - else: + status = AdbcStatementSetOption( + &self.statement, c_key, c_value, &c_error) + elif isinstance(value, str): value = value.encode("utf-8") c_value = value + status = AdbcStatementSetOption( + &self.statement, c_key, c_value, &c_error) + elif isinstance(value, bytes): + c_value = value + status = AdbcStatementSetOptionBytes( + &self.statement, c_key, c_value, len(value), &c_error) + elif isinstance(value, float): + status = AdbcStatementSetOptionDouble( + &self.statement, c_key, value, &c_error) + elif isinstance(value, int): + status = AdbcStatementSetOptionInt( + &self.statement, c_key, value, &c_error) + else: + raise ValueError( + f"Unsupported type {type(value)} for value {value!r} " + f"of option {key}") - status = AdbcStatementSetOption( - &self.statement, c_key, c_value, &c_error) check_error(status, &c_error) def set_sql_query(self, str query not None) -> None: diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py index 31e4392ae5..d9b1f5540a 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py +++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py @@ -43,6 +43,8 @@ except ImportError as e: raise ImportError("PyArrow is required for the DBAPI-compatible interface") from e +import adbc_driver_manager + from . import _lib if typing.TYPE_CHECKING: @@ -78,6 +80,7 @@ 100: "driver_name", 101: "driver_version", 102: "driver_arrow_version", + 103: "driver_adbc_version", } # ---------------------------------------------------------- @@ -344,6 +347,16 @@ def __del__(self) -> None: # API Extensions # ------------------------------------------------------------ + def adbc_cancel(self) -> None: + """ + Cancel any ongoing operations on this connection. + + Notes + ----- + This is an extension and not part of the DBAPI standard. + """ + self._conn.cancel() + def adbc_clone(self) -> "Connection": """ Create a new Connection sharing the same underlying database. @@ -479,6 +492,40 @@ def adbc_connection(self) -> _lib.AdbcConnection: """ return self._conn + @property + def adbc_current_catalog(self) -> str: + """ + The name of the current catalog. + + Notes + ----- + This is an extension and not part of the DBAPI standard. + """ + key = adbc_driver_manager.ConnectionOptions.CURRENT_CATALOG.value + return self._conn.get_option(key) + + @adbc_current_catalog.setter + def adbc_current_catalog(self, catalog: str) -> None: + key = adbc_driver_manager.ConnectionOptions.CURRENT_CATALOG.value + self._conn.set_options(**{key: catalog}) + + @property + def adbc_current_db_schema(self) -> str: + """ + The name of the current schema. + + Notes + ----- + This is an extension and not part of the DBAPI standard. + """ + key = adbc_driver_manager.ConnectionOptions.CURRENT_DB_SCHEMA.value + return self._conn.get_option(key) + + @adbc_current_db_schema.setter + def adbc_current_db_schema(self, db_schema: str) -> None: + key = adbc_driver_manager.ConnectionOptions.CURRENT_DB_SCHEMA.value + self._conn.set_options(**{key: db_schema}) + @property def adbc_database(self) -> _lib.AdbcDatabase: """ @@ -729,11 +776,21 @@ def __next__(self): # API Extensions # ------------------------------------------------------------ + def adbc_cancel(self) -> None: + """ + Cancel any ongoing operations on this statement. + + Notes + ----- + This is an extension and not part of the DBAPI standard. + """ + self._stmt.cancel() + def adbc_ingest( self, table_name: str, data: Union[pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader], - mode: Literal["append", "create"] = "create", + mode: Literal["append", "create", "replace", "append_create"] = "create", ) -> int: """ Ingest Arrow data into a database table. @@ -748,7 +805,12 @@ def adbc_ingest( data The Arrow data to insert. mode - Whether to append data to an existing table, or create a new table. + How to deal with existing data: + + - 'append': append to a table (error if table does not exist) + - 'create': create a table and insert (error if table exists) + - 'create_append': create a table (if not exists) and insert + - 'replace': drop existing table (if any), then same as 'create' Returns ------- @@ -764,6 +826,10 @@ def adbc_ingest( c_mode = _lib.INGEST_OPTION_MODE_APPEND elif mode == "create": c_mode = _lib.INGEST_OPTION_MODE_CREATE + elif mode == "create_append": + c_mode = _lib.INGEST_OPTION_MODE_CREATE_APPEND + elif mode == "replace": + c_mode = _lib.INGEST_OPTION_MODE_REPLACE else: raise ValueError(f"Invalid value for 'mode': {mode}") self._stmt.set_options( @@ -810,6 +876,23 @@ def adbc_execute_partitions( partitions, schema, self._rowcount = self._stmt.execute_partitions() return partitions, pyarrow.Schema._import_from_c(schema.address) + def adbc_execute_schema(self, operation, parameters=None) -> pyarrow.Schema: + """ + Get the schema of the result set of a query without executing it. + + Returns + ------- + pyarrow.Schema + The schema of the result set. + + Notes + ----- + This is an extension and not part of the DBAPI standard. + """ + self._prepare_execute(operation, parameters) + schema = self._stmt.execute_schema() + return pyarrow.Schema._import_from_c(schema.address) + def adbc_prepare(self, operation: Union[bytes, str]) -> Optional[pyarrow.Schema]: """ Prepare a query without executing it. diff --git a/python/adbc_driver_postgresql/tests/test_dbapi.py b/python/adbc_driver_postgresql/tests/test_dbapi.py index c50cad1e85..e3f86a4447 100644 --- a/python/adbc_driver_postgresql/tests/test_dbapi.py +++ b/python/adbc_driver_postgresql/tests/test_dbapi.py @@ -17,6 +17,7 @@ from typing import Generator +import pyarrow import pytest from adbc_driver_postgresql import StatementOptions, dbapi @@ -28,6 +29,32 @@ def postgres(postgres_uri: str) -> Generator[dbapi.Connection, None, None]: yield conn +def test_conn_current_catalog(postgres: dbapi.Connection) -> None: + assert postgres.adbc_current_catalog != "" + + +def test_conn_current_db_schema(postgres: dbapi.Connection) -> None: + assert postgres.adbc_current_db_schema == "public" + + +def test_conn_change_db_schema(postgres: dbapi.Connection) -> None: + assert postgres.adbc_current_db_schema == "public" + + with postgres.cursor() as cur: + cur.execute("CREATE SCHEMA IF NOT EXISTS dbapischema") + + assert postgres.adbc_current_db_schema == "public" + postgres.adbc_current_db_schema = "dbapischema" + assert postgres.adbc_current_db_schema == "dbapischema" + + +def test_conn_get_info(postgres: dbapi.Connection) -> None: + info = postgres.adbc_get_info() + assert info["driver_name"] == "ADBC PostgreSQL Driver" + assert info["driver_adbc_version"] == 1_001_000 + assert info["vendor_name"] == "PostgreSQL" + + def test_query_batch_size(postgres: dbapi.Connection): with postgres.cursor() as cur: cur.execute("DROP TABLE IF EXISTS test_batch_size") @@ -47,6 +74,12 @@ def test_query_batch_size(postgres: dbapi.Connection): cur.adbc_statement.set_options( **{StatementOptions.BATCH_SIZE_HINT_BYTES.value: "1"} ) + assert ( + cur.adbc_statement.get_option_int( + StatementOptions.BATCH_SIZE_HINT_BYTES.value + ) + == 1 + ) cur.execute("SELECT * FROM test_batch_size") table = cur.fetch_arrow_table() assert len(table.to_batches()) == 65536 @@ -54,17 +87,55 @@ def test_query_batch_size(postgres: dbapi.Connection): cur.adbc_statement.set_options( **{StatementOptions.BATCH_SIZE_HINT_BYTES.value: "4096"} ) + assert ( + cur.adbc_statement.get_option_int( + StatementOptions.BATCH_SIZE_HINT_BYTES.value + ) + == 4096 + ) cur.execute("SELECT * FROM test_batch_size") table = cur.fetch_arrow_table() assert 64 <= len(table.to_batches()) <= 256 +def test_query_cancel(postgres: dbapi.Connection) -> None: + with postgres.cursor() as cur: + cur.execute("DROP TABLE IF EXISTS test_batch_size") + cur.execute("CREATE TABLE test_batch_size (ints INT)") + cur.execute( + """ + INSERT INTO test_batch_size (ints) + SELECT generated :: INT + FROM GENERATE_SERIES(1, 65536) temp(generated) + """ + ) + + cur.execute("SELECT * FROM test_batch_size") + cur.adbc_cancel() + # XXX(https://github.com/apache/arrow-adbc/issues/940): + # PyArrow swallows the errno and doesn't set it into the + # OSError, so we have no clue what happened here. (Though the + # driver does properly return ECANCELED.) + with pytest.raises(OSError, match="canceling statement"): + cur.fetchone() + + +def test_query_execute_schema(postgres: dbapi.Connection) -> None: + with postgres.cursor() as cur: + schema = cur.adbc_execute_schema("SELECT 1 AS foo") + assert schema == pyarrow.schema([("foo", "int32")]) + + def test_query_trivial(postgres: dbapi.Connection): with postgres.cursor() as cur: cur.execute("SELECT 1") assert cur.fetchone() == (1,) +def test_stmt_ingest(postgres: dbapi.Connection) -> None: + pass + + def test_ddl(postgres: dbapi.Connection): with postgres.cursor() as cur: cur.execute("DROP TABLE IF EXISTS test_ddl")