diff --git a/c/driver/postgresql/connection.cc b/c/driver/postgresql/connection.cc index 17732810fc..b5f12ca73f 100644 --- a/c/driver/postgresql/connection.cc +++ b/c/driver/postgresql/connection.cc @@ -17,6 +17,7 @@ #include "connection.h" +#include #include #include #include @@ -175,6 +176,13 @@ class PostgresGetObjectsHelper : public adbc::driver::GetObjectsHelper { all_constraints_(conn, kConstraintsQueryAll), some_constraints_(conn, ConstraintsQuery()) {} + // Allow Redshift to execute this query without constraints + // TODO(paleolimbot): Investigate to see if we can simplify the constraits query so that + // it works on both! + void SetEnableConstraints(bool enable_constraints) { + enable_constraints_ = enable_constraints; + } + Status Load(adbc::driver::GetObjectsDepth depth, std::optional catalog_filter, std::optional schema_filter, @@ -262,16 +270,23 @@ class PostgresGetObjectsHelper : public adbc::driver::GetObjectsHelper { std::optional column_filter) override { if (column_filter.has_value()) { UNWRAP_STATUS(some_columns_.Execute( - {std::string(schema), std::string(table), std::string(*column_filter)})) - UNWRAP_STATUS(some_constraints_.Execute( - {std::string(schema), std::string(table), std::string(*column_filter)})) + {std::string(schema), std::string(table), std::string(*column_filter)})); next_column_ = some_columns_.Row(-1); - next_constraint_ = some_constraints_.Row(-1); } else { - UNWRAP_STATUS(all_columns_.Execute({std::string(schema), std::string(table)})) - UNWRAP_STATUS(all_constraints_.Execute({std::string(schema), std::string(table)})) + UNWRAP_STATUS(all_columns_.Execute({std::string(schema), std::string(table)})); next_column_ = all_columns_.Row(-1); - next_constraint_ = all_constraints_.Row(-1); + } + + if (enable_constraints_) { + if (column_filter.has_value()) { + UNWRAP_STATUS(some_constraints_.Execute( + {std::string(schema), std::string(table), std::string(*column_filter)})) + next_constraint_ = some_constraints_.Row(-1); + } else { + UNWRAP_STATUS( + all_constraints_.Execute({std::string(schema), std::string(table)})); + next_constraint_ = all_constraints_.Row(-1); + } } return Status::Ok(); @@ -348,6 +363,9 @@ class PostgresGetObjectsHelper : public adbc::driver::GetObjectsHelper { PqResultHelper all_constraints_; PqResultHelper some_constraints_; + // On Redshift, the constraints query fails + bool enable_constraints_{true}; + // Iterator state for the catalogs/schema/table/column queries PqResultRow next_catalog_; PqResultRow next_schema_; @@ -478,19 +496,30 @@ AdbcStatusCode PostgresConnection::GetInfo(struct AdbcConnection* connection, for (size_t i = 0; i < info_codes_length; i++) { switch (info_codes[i]) { case ADBC_INFO_VENDOR_NAME: - infos.push_back({info_codes[i], "PostgreSQL"}); + infos.push_back({info_codes[i], std::string(VendorName())}); break; case ADBC_INFO_VENDOR_VERSION: { - const char* stmt = "SHOW server_version_num"; - auto result_helper = PqResultHelper{conn_, std::string(stmt)}; - RAISE_STATUS(error, result_helper.Execute()); - auto it = result_helper.begin(); - if (it == result_helper.end()) { - SetError(error, "[libpq] PostgreSQL returned no rows for '%s'", stmt); - return ADBC_STATUS_INTERNAL; + if (VendorName() == "Redshift") { + const std::array& version = VendorVersion(); + std::string version_string = std::to_string(version[0]) + "." + + std::to_string(version[1]) + "." + + std::to_string(version[2]); + infos.push_back({info_codes[i], std::move(version_string)}); + + } else { + // Gives a version in the form 140000 instead of 14.0.0 + const char* stmt = "SHOW server_version_num"; + auto result_helper = PqResultHelper{conn_, std::string(stmt)}; + RAISE_STATUS(error, result_helper.Execute()); + auto it = result_helper.begin(); + if (it == result_helper.end()) { + SetError(error, "[libpq] PostgreSQL returned no rows for '%s'", stmt); + return ADBC_STATUS_INTERNAL; + } + const char* server_version_num = (*it)[0].data; + infos.push_back({info_codes[i], server_version_num}); } - const char* server_version_num = (*it)[0].data; - infos.push_back({info_codes[i], server_version_num}); + break; } case ADBC_INFO_DRIVER_NAME: @@ -520,7 +549,8 @@ AdbcStatusCode PostgresConnection::GetObjects( struct AdbcConnection* connection, int c_depth, const char* catalog, const char* db_schema, const char* table_name, const char** table_type, const char* column_name, struct ArrowArrayStream* out, struct AdbcError* error) { - PostgresGetObjectsHelper new_helper(conn_); + PostgresGetObjectsHelper helper(conn_); + helper.SetEnableConstraints(VendorName() != "Redshift"); const auto catalog_filter = catalog ? std::make_optional(std::string_view(catalog)) : std::nullopt; @@ -559,9 +589,9 @@ AdbcStatusCode PostgresConnection::GetObjects( .ToAdbc(error); } - auto status = BuildGetObjects(&new_helper, depth, catalog_filter, schema_filter, + auto status = BuildGetObjects(&helper, depth, catalog_filter, schema_filter, table_filter, column_filter, table_type_filter, out); - RAISE_STATUS(error, new_helper.Close()); + RAISE_STATUS(error, helper.Close()); RAISE_STATUS(error, status); return ADBC_STATUS_OK; @@ -573,11 +603,12 @@ AdbcStatusCode PostgresConnection::GetOption(const char* option, char* value, if (std::strcmp(option, ADBC_CONNECTION_OPTION_CURRENT_CATALOG) == 0) { output = PQdb(conn_); } else if (std::strcmp(option, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA) == 0) { - PqResultHelper result_helper{conn_, "SELECT CURRENT_SCHEMA"}; + PqResultHelper result_helper{conn_, "SELECT CURRENT_SCHEMA()"}; RAISE_STATUS(error, result_helper.Execute()); auto it = result_helper.begin(); if (it == result_helper.end()) { - SetError(error, "[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_SCHEMA'"); + SetError(error, + "[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_SCHEMA()'"); return ADBC_STATUS_INTERNAL; } output = (*it)[0].data; @@ -989,7 +1020,6 @@ AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog, CHECK_NA(INTERNAL, ArrowSchemaSetTypeStruct(uschema.get(), result_helper.NumRows()), error); - ArrowError na_error; int row_counter = 0; for (auto row : result_helper) { const char* colname = row[0].data; @@ -997,14 +1027,15 @@ AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog, static_cast(std::strtol(row[1].data, /*str_end=*/nullptr, /*base=*/10)); PostgresType pg_type; - if (type_resolver_->Find(pg_oid, &pg_type, &na_error) != NANOARROW_OK) { - SetError(error, "%s%d%s%s%s%" PRIu32, "Column #", row_counter + 1, " (\"", colname, - "\") has unknown type code ", pg_oid); + if (type_resolver_->FindWithDefault(pg_oid, &pg_type) != NANOARROW_OK) { + SetError(error, "%s%d%s%s%s%" PRIu32, "Error resolving type code for column #", + row_counter + 1, " (\"", colname, "\") with oid ", pg_oid); final_status = ADBC_STATUS_NOT_IMPLEMENTED; break; } CHECK_NA(INTERNAL, - pg_type.WithFieldName(colname).SetSchema(uschema->children[row_counter]), + pg_type.WithFieldName(colname).SetSchema(uschema->children[row_counter], + std::string(VendorName())), error); row_counter++; } @@ -1136,4 +1167,10 @@ AdbcStatusCode PostgresConnection::SetOptionInt(const char* key, int64_t value, return ADBC_STATUS_NOT_IMPLEMENTED; } +std::string_view PostgresConnection::VendorName() { return database_->VendorName(); } + +const std::array& PostgresConnection::VendorVersion() { + return database_->VendorVersion(); +} + } // namespace adbcpq diff --git a/c/driver/postgresql/connection.h b/c/driver/postgresql/connection.h index 787a7dcdae..7683875b5f 100644 --- a/c/driver/postgresql/connection.h +++ b/c/driver/postgresql/connection.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include @@ -73,6 +74,8 @@ class PostgresConnection { return type_resolver_; } bool autocommit() const { return autocommit_; } + std::string_view VendorName(); + const std::array& VendorVersion(); private: std::shared_ptr database_; diff --git a/c/driver/postgresql/copy/postgres_copy_reader_test.cc b/c/driver/postgresql/copy/postgres_copy_reader_test.cc index 60e0b6aaf1..7b9fe230f8 100644 --- a/c/driver/postgresql/copy/postgres_copy_reader_test.cc +++ b/c/driver/postgresql/copy/postgres_copy_reader_test.cc @@ -27,7 +27,7 @@ class PostgresCopyStreamTester { public: ArrowErrorCode Init(const PostgresType& root_type, ArrowError* error = nullptr) { NANOARROW_RETURN_NOT_OK(reader_.Init(root_type)); - NANOARROW_RETURN_NOT_OK(reader_.InferOutputSchema(error)); + NANOARROW_RETURN_NOT_OK(reader_.InferOutputSchema("PostgreSQL Tester", error)); NANOARROW_RETURN_NOT_OK(reader_.InitFieldReaders(error)); return NANOARROW_OK; } diff --git a/c/driver/postgresql/copy/reader.h b/c/driver/postgresql/copy/reader.h index 983f392264..07f91d545e 100644 --- a/c/driver/postgresql/copy/reader.h +++ b/c/driver/postgresql/copy/reader.h @@ -972,10 +972,11 @@ class PostgresCopyStreamReader { return NANOARROW_OK; } - ArrowErrorCode InferOutputSchema(ArrowError* error) { + ArrowErrorCode InferOutputSchema(const std::string& vendor_name, ArrowError* error) { schema_.reset(); ArrowSchemaInit(schema_.get()); - NANOARROW_RETURN_NOT_OK(root_reader_.InputType().SetSchema(schema_.get())); + NANOARROW_RETURN_NOT_OK( + root_reader_.InputType().SetSchema(schema_.get(), vendor_name)); return NANOARROW_OK; } diff --git a/c/driver/postgresql/database.cc b/c/driver/postgresql/database.cc index 97242ad58a..cdbad7535f 100644 --- a/c/driver/postgresql/database.cc +++ b/c/driver/postgresql/database.cc @@ -17,6 +17,8 @@ #include "database.h" +#include +#include #include #include #include @@ -28,6 +30,7 @@ #include #include "driver/common/utils.h" +#include "result_helper.h" namespace adbcpq { @@ -54,8 +57,19 @@ AdbcStatusCode PostgresDatabase::GetOptionDouble(const char* option, double* val } AdbcStatusCode PostgresDatabase::Init(struct AdbcError* error) { - // Connect to validate the parameters. - return RebuildTypeResolver(error); + // Connect to initialize the version information and build the type table + PGconn* conn = nullptr; + RAISE_ADBC(Connect(&conn, error)); + + Status status = InitVersions(conn); + if (!status.ok()) { + RAISE_ADBC(Disconnect(&conn, nullptr)); + return status.ToAdbc(error); + } + + status = RebuildTypeResolver(conn); + RAISE_ADBC(Disconnect(&conn, nullptr)); + return status.ToAdbc(error); } AdbcStatusCode PostgresDatabase::Release(struct AdbcError* error) { @@ -123,20 +137,87 @@ AdbcStatusCode PostgresDatabase::Disconnect(PGconn** conn, struct AdbcError* err return ADBC_STATUS_OK; } -// Helpers for building the type resolver from queries -static inline int32_t InsertPgAttributeResult( - PGresult* result, const std::shared_ptr& resolver); +namespace { + +// Parse an individual version in the form of "xxx.xxx.xxx". +// If the version components aren't numeric, they will be zero. +std::array ParseVersion(std::string_view version) { + std::array out{}; + size_t component = 0; + size_t component_begin = 0; + size_t component_end = 0; + + // While there are remaining version components and we haven't reached the end of the + // string + while (component_begin < version.size() && component < out.size()) { + // Find the next character that marks a version component separation or the end of the + // string + component_end = version.find_first_of(".-", component_begin); + if (component_end == version.npos) { + component_end = version.size(); + } -static inline int32_t InsertPgTypeResult( - PGresult* result, const std::shared_ptr& resolver); + // Try to parse the component as an integer (assigning zero if this fails) + int value = 0; + std::from_chars(version.data() + component_begin, version.data() + component_end, + value); + out[component] = value; -AdbcStatusCode PostgresDatabase::RebuildTypeResolver(struct AdbcError* error) { - PGconn* conn = nullptr; - AdbcStatusCode final_status = Connect(&conn, error); - if (final_status != ADBC_STATUS_OK) { - return final_status; + // Move on to the next component + component_begin = component_end + 1; + component_end = component_begin; + component++; + } + + return out; +} + +// Parse the PostgreSQL version() string that looks like: +// PostgreSQL 8.0.2 on i686-pc-linux-gnu, compiled by GCC gcc (GCC) 3.4.2 20041017 (Red +// Hat 3.4.2-6.fc3), Redshift 1.0.77467 +std::array ParsePrefixedVersion(std::string_view version_info, + std::string_view prefix) { + size_t pos = version_info.find(prefix); + if (pos == version_info.npos) { + return {0, 0, 0}; } + // Skip the prefix and any leading whitespace + pos = version_info.find_first_not_of(' ', pos + prefix.size()); + if (pos == version_info.npos) { + return {0, 0, 0}; + } + + return ParseVersion(version_info.substr(pos)); +} + +} // namespace + +Status PostgresDatabase::InitVersions(PGconn* conn) { + PqResultHelper helper(conn, "SELECT version();"); + UNWRAP_STATUS(helper.Execute()); + if (helper.NumRows() != 1 || helper.NumColumns() != 1) { + return Status::Internal("Expected 1 row and 1 column for SELECT version(); but got ", + helper.NumRows(), "/", helper.NumColumns()); + } + + std::string_view version_info = helper.Row(0)[0].value(); + postgres_server_version_ = ParsePrefixedVersion(version_info, "PostgreSQL"); + redshift_server_version_ = ParsePrefixedVersion(version_info, "Redshift"); + + return Status::Ok(); +} + +// Helpers for building the type resolver from queries +static std::string BuildPgTypeQuery(bool has_typarray); + +static Status InsertPgAttributeResult( + const PqResultHelper& result, const std::shared_ptr& resolver); + +static Status InsertPgTypeResult(const PqResultHelper& result, + const std::shared_ptr& resolver); + +Status PostgresDatabase::RebuildTypeResolver(PGconn* conn) { // We need a few queries to build the resolver. The current strategy might // fail for some recursive definitions (e.g., arrays of records of arrays). // First, one on the pg_attribute table to resolve column names/oids for @@ -156,147 +237,131 @@ ORDER BY // recursive definitions (e.g., record types with array column). This currently won't // handle range types because those rows don't have child OID information. Arrays types // are inserted after a successful insert of the element type. - const std::string kTypeQuery = R"( -SELECT - oid, - typname, - typreceive, - typbasetype, - typarray, - typrelid -FROM - pg_catalog.pg_type -WHERE - (typreceive != 0 OR typname = 'aclitem') AND typtype != 'r' AND typreceive::TEXT != 'array_recv' -ORDER BY - oid -)"; + std::string type_query = + BuildPgTypeQuery(/*has_typarray*/ redshift_server_version_[0] == 0); // Create a new type resolver (this instance's type_resolver_ member // will be updated at the end if this succeeds). auto resolver = std::make_shared(); // Insert record type definitions (this includes table schemas) - PGresult* result = PQexec(conn, kColumnsQuery.c_str()); - ExecStatusType pq_status = PQresultStatus(result); - if (pq_status == PGRES_TUPLES_OK) { - InsertPgAttributeResult(result, resolver); - } else { - SetError(error, "%s%s", - "[libpq] Failed to build type mapping table: ", PQerrorMessage(conn)); - final_status = ADBC_STATUS_IO; - } - - PQclear(result); + PqResultHelper columns(conn, kColumnsQuery.c_str()); + UNWRAP_STATUS(columns.Execute()); + UNWRAP_STATUS(InsertPgAttributeResult(columns, resolver)); // Attempt filling the resolver a few times to handle recursive definitions. int32_t max_attempts = 3; + PqResultHelper types(conn, type_query); for (int32_t i = 0; i < max_attempts; i++) { - result = PQexec(conn, kTypeQuery.c_str()); - ExecStatusType pq_status = PQresultStatus(result); - if (pq_status == PGRES_TUPLES_OK) { - InsertPgTypeResult(result, resolver); - } else { - SetError(error, "%s%s", - "[libpq] Failed to build type mapping table: ", PQerrorMessage(conn)); - final_status = ADBC_STATUS_IO; - } - - PQclear(result); - if (final_status != ADBC_STATUS_OK) { - break; - } + UNWRAP_STATUS(types.Execute()); + UNWRAP_STATUS(InsertPgTypeResult(types, resolver)); } - // Disconnect since PostgreSQL connections can be heavy. - { - AdbcStatusCode status = Disconnect(&conn, error); - if (status != ADBC_STATUS_OK) final_status = status; - } + type_resolver_ = std::move(resolver); + return Status::Ok(); +} - if (final_status == ADBC_STATUS_OK) { - type_resolver_ = std::move(resolver); +static std::string BuildPgTypeQuery(bool has_typarray) { + std::string maybe_typarray_col; + std::string maybe_array_recv_filter; + if (has_typarray) { + maybe_typarray_col = ", typarray"; + maybe_array_recv_filter = "AND typreceive::TEXT != 'array_recv'"; } - return final_status; + return std::string() + "SELECT oid, typname, typreceive, typbasetype, typrelid" + + maybe_typarray_col + " FROM pg_catalog.pg_type " + + " WHERE (typreceive != 0 OR typsend != 0) AND typtype != 'r' " + + maybe_array_recv_filter; } -static inline int32_t InsertPgAttributeResult( - PGresult* result, const std::shared_ptr& resolver) { - int num_rows = PQntuples(result); +static Status InsertPgAttributeResult( + const PqResultHelper& result, const std::shared_ptr& resolver) { + int num_rows = result.NumRows(); std::vector> columns; - uint32_t current_type_oid = 0; - int32_t n_added = 0; + int64_t current_type_oid = 0; + + if (result.NumColumns() != 3) { + return Status::Internal( + "Expected 3 columns from type resolver pg_attribute query but got ", + result.NumColumns()); + } for (int row = 0; row < num_rows; row++) { - const uint32_t type_oid = static_cast( - std::strtol(PQgetvalue(result, row, 0), /*str_end=*/nullptr, /*base=*/10)); - const char* col_name = PQgetvalue(result, row, 1); - const uint32_t col_oid = static_cast( - std::strtol(PQgetvalue(result, row, 2), /*str_end=*/nullptr, /*base=*/10)); + PqResultRow item = result.Row(row); + UNWRAP_RESULT(int64_t type_oid, item[0].ParseInteger()); + std::string_view col_name = item[1].value(); + UNWRAP_RESULT(int64_t col_oid, item[2].ParseInteger()); if (type_oid != current_type_oid && !columns.empty()) { resolver->InsertClass(current_type_oid, columns); columns.clear(); current_type_oid = type_oid; - n_added++; } - columns.push_back({col_name, col_oid}); + columns.push_back({std::string(col_name), static_cast(col_oid)}); } if (!columns.empty()) { - resolver->InsertClass(current_type_oid, columns); - n_added++; + resolver->InsertClass(static_cast(current_type_oid), columns); } - return n_added; + return Status::Ok(); } -static inline int32_t InsertPgTypeResult( - PGresult* result, const std::shared_ptr& resolver) { - int num_rows = PQntuples(result); - PostgresTypeResolver::Item item; - int32_t n_added = 0; +static Status InsertPgTypeResult(const PqResultHelper& result, + const std::shared_ptr& resolver) { + if (result.NumColumns() != 5 && result.NumColumns() != 6) { + return Status::Internal( + "Expected 5 or 6 columns from type resolver pg_type query but got ", + result.NumColumns()); + } + + int num_rows = result.NumRows(); + int num_cols = result.NumColumns(); + PostgresTypeResolver::Item type_item; for (int row = 0; row < num_rows; row++) { - const uint32_t oid = static_cast( - std::strtol(PQgetvalue(result, row, 0), /*str_end=*/nullptr, /*base=*/10)); - const char* typname = PQgetvalue(result, row, 1); - const char* typreceive = PQgetvalue(result, row, 2); - const uint32_t typbasetype = static_cast( - std::strtol(PQgetvalue(result, row, 3), /*str_end=*/nullptr, /*base=*/10)); - const uint32_t typarray = static_cast( - std::strtol(PQgetvalue(result, row, 4), /*str_end=*/nullptr, /*base=*/10)); - const uint32_t typrelid = static_cast( - std::strtol(PQgetvalue(result, row, 5), /*str_end=*/nullptr, /*base=*/10)); + PqResultRow item = result.Row(row); + UNWRAP_RESULT(int64_t oid, item[0].ParseInteger()); + const char* typname = item[1].data; + const char* typreceive = item[2].data; + UNWRAP_RESULT(int64_t typbasetype, item[3].ParseInteger()); + UNWRAP_RESULT(int64_t typrelid, item[4].ParseInteger()); + + int64_t typarray; + if (num_cols == 6) { + UNWRAP_RESULT(typarray, item[5].ParseInteger()); + } else { + typarray = 0; + } // Special case the aclitem because it shows up in a bunch of internal tables if (strcmp(typname, "aclitem") == 0) { typreceive = "aclitem_recv"; } - item.oid = oid; - item.typname = typname; - item.typreceive = typreceive; - item.class_oid = typrelid; - item.base_oid = typbasetype; + type_item.oid = static_cast(oid); + type_item.typname = typname; + type_item.typreceive = typreceive; + type_item.class_oid = static_cast(typrelid); + type_item.base_oid = static_cast(typbasetype); - int result = resolver->Insert(item, nullptr); + int result = resolver->Insert(type_item, nullptr); // If there's an array type and the insert succeeded, add that now too if (result == NANOARROW_OK && typarray != 0) { std::string array_typname = "_" + std::string(typname); - item.oid = typarray; - item.typname = array_typname.c_str(); - item.typreceive = "array_recv"; - item.child_oid = oid; + type_item.oid = typarray; + type_item.typname = array_typname.c_str(); + type_item.typreceive = "array_recv"; + type_item.child_oid = static_cast(oid); - resolver->Insert(item, nullptr); + resolver->Insert(type_item, nullptr); } } - return n_added; + return Status::Ok(); } } // namespace adbcpq diff --git a/c/driver/postgresql/database.h b/c/driver/postgresql/database.h index d246ea04a4..e0a00267e3 100644 --- a/c/driver/postgresql/database.h +++ b/c/driver/postgresql/database.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include #include @@ -24,9 +25,12 @@ #include #include +#include "driver/framework/status.h" #include "postgres_type.h" namespace adbcpq { +using adbc::driver::Status; + class PostgresDatabase { public: PostgresDatabase(); @@ -58,12 +62,29 @@ class PostgresDatabase { return type_resolver_; } - AdbcStatusCode RebuildTypeResolver(struct AdbcError* error); + Status InitVersions(PGconn* conn); + Status RebuildTypeResolver(PGconn* conn); + std::string_view VendorName() { + if (redshift_server_version_[0] != 0) { + return "Redshift"; + } else { + return "PostgreSQL"; + } + } + const std::array& VendorVersion() { + if (redshift_server_version_[0] != 0) { + return redshift_server_version_; + } else { + return postgres_server_version_; + } + } private: int32_t open_connections_; std::string uri_; std::shared_ptr type_resolver_; + std::array postgres_server_version_{}; + std::array redshift_server_version_{}; }; } // namespace adbcpq diff --git a/c/driver/postgresql/postgres_type.h b/c/driver/postgresql/postgres_type.h index b3cfc209ff..d2a5356293 100644 --- a/c/driver/postgresql/postgres_type.h +++ b/c/driver/postgresql/postgres_type.h @@ -111,7 +111,11 @@ enum class PostgresTypeId { kXid8, kXid, kXml, - kUserDefined + kUserDefined, + // This is not an actual type, but there are cases where all we have is an Oid + // that was not inserted into the type resolver. We can't use "unknown" or "opaque" + // or "void" because those names show up in actual pg_type tables. + kUnnamedArrowOpaque }; // Returns the receive function name as defined in the typrecieve column @@ -139,6 +143,11 @@ class PostgresType { PostgresType() : PostgresType(PostgresTypeId::kUninitialized) {} + static PostgresType Unnamed(uint32_t oid) { + return PostgresType(PostgresTypeId::kUnnamedArrowOpaque) + .WithPgTypeInfo(oid, "unnamed"); + } + void AppendChild(const std::string& field_name, const PostgresType& type) { PostgresType child(type); children_.push_back(child.WithFieldName(field_name)); @@ -204,7 +213,8 @@ class PostgresType { // do not have a corresponding Arrow type are returned as Binary with field // metadata ADBC:posgresql:typname. These types can be represented as their // binary COPY representation in the output. - ArrowErrorCode SetSchema(ArrowSchema* schema) const { + ArrowErrorCode SetSchema(ArrowSchema* schema, + const std::string& vendor_name = "PostgreSQL") const { switch (type_id_) { // ---- Primitive types -------------------- case PostgresTypeId::kBool: @@ -235,7 +245,7 @@ class PostgresType { // ---- Numeric/Decimal------------------- case PostgresTypeId::kNumeric: NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(schema, NANOARROW_TYPE_STRING)); - NANOARROW_RETURN_NOT_OK(AddPostgresTypeMetadata(schema)); + NANOARROW_RETURN_NOT_OK(AddPostgresTypeMetadata(schema, vendor_name)); break; @@ -290,13 +300,14 @@ class PostgresType { case PostgresTypeId::kRecord: NANOARROW_RETURN_NOT_OK(ArrowSchemaSetTypeStruct(schema, n_children())); for (int64_t i = 0; i < n_children(); i++) { - NANOARROW_RETURN_NOT_OK(children_[i].SetSchema(schema->children[i])); + NANOARROW_RETURN_NOT_OK( + children_[i].SetSchema(schema->children[i], vendor_name)); } break; case PostgresTypeId::kArray: NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(schema, NANOARROW_TYPE_LIST)); - NANOARROW_RETURN_NOT_OK(children_[0].SetSchema(schema->children[0])); + NANOARROW_RETURN_NOT_OK(children_[0].SetSchema(schema->children[0], vendor_name)); break; case PostgresTypeId::kUserDefined: @@ -305,7 +316,7 @@ class PostgresType { // can still return the bytes postgres gives us and attach the type name as // metadata NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(schema, NANOARROW_TYPE_BINARY)); - NANOARROW_RETURN_NOT_OK(AddPostgresTypeMetadata(schema)); + NANOARROW_RETURN_NOT_OK(AddPostgresTypeMetadata(schema, vendor_name)); break; } @@ -329,7 +340,8 @@ class PostgresType { static constexpr const char* kOpaqueExtensionName = "arrow.opaque"; static constexpr const char* kExtensionMetadata = "ARROW:extension:metadata"; - ArrowErrorCode AddPostgresTypeMetadata(ArrowSchema* schema) const { + ArrowErrorCode AddPostgresTypeMetadata(ArrowSchema* schema, + const std::string& vendor_name) const { // the typname_ may not always be set: an instance of this class can be // created with just the type id. That's why there is this here fallback to // resolve the type name of built-in types. @@ -346,7 +358,7 @@ class PostgresType { // Add the Opaque extension type metadata std::string metadata = R"({"type_name": ")"; metadata += typname; - metadata += R"(", "vendor_name": "PostgreSQL"})"; + metadata += R"(", "vendor_name": ")" + vendor_name + R"("})"; NANOARROW_RETURN_NOT_OK( ArrowMetadataBuilderAppend(buffer.get(), ArrowCharView(kExtensionName), ArrowCharView(kOpaqueExtensionName))); @@ -395,7 +407,18 @@ class PostgresTypeResolver { return EINVAL; } - *type_out = (*result).second; + *type_out = result->second; + return NANOARROW_OK; + } + + ArrowErrorCode FindWithDefault(uint32_t oid, PostgresType* type_out) { + auto result = mapping_.find(oid); + if (result == mapping_.end()) { + *type_out = PostgresType::Unnamed(oid); + } else { + *type_out = result->second; + } + return NANOARROW_OK; } diff --git a/c/driver/postgresql/postgres_type_test.cc b/c/driver/postgresql/postgres_type_test.cc index 2e713204f4..2c76f4c1f4 100644 --- a/c/driver/postgresql/postgres_type_test.cc +++ b/c/driver/postgresql/postgres_type_test.cc @@ -337,6 +337,11 @@ TEST(PostgresTypeTest, PostgresTypeResolver) { EXPECT_EQ(resolver.Find(123, &type, &error), EINVAL); EXPECT_STREQ(ArrowErrorMessage(&error), "Postgres type with oid 123 not found"); + EXPECT_EQ(resolver.FindWithDefault(123, &type), NANOARROW_OK); + EXPECT_EQ(type.oid(), 123); + EXPECT_EQ(type.type_id(), PostgresTypeId::kUnnamedArrowOpaque); + EXPECT_EQ(type.typname(), "unnamed"); + // Check error for Array with unknown child item.oid = 123; item.typname = "some_array"; diff --git a/c/driver/postgresql/result_helper.h b/c/driver/postgresql/result_helper.h index 612573edad..7eb2c27ad5 100644 --- a/c/driver/postgresql/result_helper.h +++ b/c/driver/postgresql/result_helper.h @@ -167,7 +167,7 @@ class PqResultHelper { return PQfname(result_, column_number); } Oid FieldType(int column_number) const { return PQftype(result_, column_number); } - PqResultRow Row(int i) { return PqResultRow(result_, i); } + PqResultRow Row(int i) const { return PqResultRow(result_, i); } class iterator { const PqResultHelper& outer_; diff --git a/c/driver/postgresql/result_reader.cc b/c/driver/postgresql/result_reader.cc index c350ab8a3e..464bad74a7 100644 --- a/c/driver/postgresql/result_reader.cc +++ b/c/driver/postgresql/result_reader.cc @@ -174,10 +174,10 @@ Status PqResultArrayReader::Initialize(int64_t* rows_affected) { for (int i = 0; i < helper_.NumColumns(); i++) { PostgresType child_type; - UNWRAP_NANOARROW(na_error_, Internal, - type_resolver_->Find(helper_.FieldType(i), &child_type, &na_error_)); + UNWRAP_ERRNO(Internal, + type_resolver_->FindWithDefault(helper_.FieldType(i), &child_type)); - UNWRAP_ERRNO(Internal, child_type.SetSchema(schema_->children[i])); + UNWRAP_ERRNO(Internal, child_type.SetSchema(schema_->children[i], vendor_name_)); UNWRAP_ERRNO(Internal, ArrowSchemaSetName(schema_->children[i], helper_.FieldName(i))); diff --git a/c/driver/postgresql/result_reader.h b/c/driver/postgresql/result_reader.h index 5c36dccb2f..90b35baf06 100644 --- a/c/driver/postgresql/result_reader.h +++ b/c/driver/postgresql/result_reader.h @@ -58,6 +58,10 @@ class PqResultArrayReader { bind_stream_->SetBind(stream); } + void SetVendorName(std::string_view vendor_name) { + vendor_name_ = std::string(vendor_name); + } + int GetSchema(struct ArrowSchema* out); int GetNext(struct ArrowArray* out); const char* GetLastError(); @@ -74,6 +78,7 @@ class PqResultArrayReader { std::vector> field_readers_; nanoarrow::UniqueSchema schema_; bool autocommit_; + std::string vendor_name_; struct AdbcError error_; struct ArrowError na_error_; diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc index 32558b4946..129ddebff8 100644 --- a/c/driver/postgresql/statement.cc +++ b/c/driver/postgresql/statement.cc @@ -459,6 +459,7 @@ AdbcStatusCode PostgresStatement::ExecuteBind(struct ArrowArrayStream* stream, PqResultArrayReader reader(connection_->conn(), type_resolver_, query_); reader.SetAutocommit(connection_->autocommit()); reader.SetBind(&bind_); + reader.SetVendorName(connection_->VendorName()); RAISE_STATUS(error, reader.ToArrayStream(rows_affected, stream)); return ADBC_STATUS_OK; } @@ -485,8 +486,9 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream, // If we have been requested to avoid COPY or there is no output requested, // execute using the PqResultArrayReader. - if (!stream || !use_copy_) { + if (!stream || !UseCopy()) { PqResultArrayReader reader(connection_->conn(), type_resolver_, query_); + reader.SetVendorName(connection_->VendorName()); RAISE_STATUS(error, reader.ToArrayStream(rows_affected, stream)); return ADBC_STATUS_OK; } @@ -505,6 +507,7 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream, if (root_type.n_children() == 0) { // Could/should move the helper into the reader instead of repreparing PqResultArrayReader reader(connection_->conn(), type_resolver_, query_); + reader.SetVendorName(connection_->VendorName()); RAISE_STATUS(error, reader.ToArrayStream(rows_affected, stream)); return ADBC_STATUS_OK; } @@ -512,8 +515,10 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream, struct ArrowError na_error; reader_.copy_reader_ = std::make_unique(); CHECK_NA(INTERNAL, reader_.copy_reader_->Init(root_type), error); - CHECK_NA_DETAIL(INTERNAL, reader_.copy_reader_->InferOutputSchema(&na_error), &na_error, - error); + CHECK_NA_DETAIL(INTERNAL, + reader_.copy_reader_->InferOutputSchema( + std::string(connection_->VendorName()), &na_error), + &na_error, error); CHECK_NA_DETAIL(INTERNAL, reader_.copy_reader_->InitFieldReaders(&na_error), &na_error, error); @@ -574,7 +579,9 @@ AdbcStatusCode PostgresStatement::ExecuteSchema(struct ArrowSchema* schema, nanoarrow::UniqueSchema tmp; ArrowSchemaInit(tmp.get()); - CHECK_NA(INTERNAL, output_type.SetSchema(tmp.get()), error); + CHECK_NA(INTERNAL, + output_type.SetSchema(tmp.get(), std::string(connection_->VendorName())), + error); tmp.move(schema); return ADBC_STATUS_OK; @@ -597,11 +604,12 @@ AdbcStatusCode PostgresStatement::ExecuteIngest(struct ArrowArrayStream* stream, // This is a little unfortunate; we need another DB roundtrip std::string current_schema; { - PqResultHelper result_helper{connection_->conn(), "SELECT CURRENT_SCHEMA"}; + PqResultHelper result_helper{connection_->conn(), "SELECT CURRENT_SCHEMA()"}; RAISE_STATUS(error, result_helper.Execute()); auto it = result_helper.begin(); if (it == result_helper.end()) { - SetError(error, "[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_SCHEMA'"); + SetError(error, + "[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_SCHEMA()'"); return ADBC_STATUS_INTERNAL; } current_schema = (*it)[0].data; @@ -666,7 +674,7 @@ AdbcStatusCode PostgresStatement::GetOption(const char* key, char* value, size_t } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) { result = std::to_string(reader_.batch_size_hint_bytes_); } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_USE_COPY) == 0) { - if (use_copy_) { + if (UseCopy()) { result = "true"; } else { result = "false"; @@ -838,4 +846,12 @@ void PostgresStatement::ClearResult() { reader_.Release(); } +int PostgresStatement::UseCopy() { + if (use_copy_ == -1) { + return connection_->VendorName() != "Redshift"; + } else { + return use_copy_; + } +} + } // namespace adbcpq diff --git a/c/driver/postgresql/statement.h b/c/driver/postgresql/statement.h index 9e79f41ed8..60ada992b0 100644 --- a/c/driver/postgresql/statement.h +++ b/c/driver/postgresql/statement.h @@ -97,7 +97,7 @@ class PostgresStatement { : connection_(nullptr), query_(), prepared_(false), - use_copy_(true), + use_copy_(-1), reader_(nullptr) { std::memset(&bind_, 0, sizeof(bind_)); } @@ -161,7 +161,7 @@ class PostgresStatement { }; // Options - bool use_copy_; + int use_copy_; struct { std::string db_schema; @@ -171,5 +171,7 @@ class PostgresStatement { } ingest_; TupleReader reader_; + + int UseCopy(); }; } // namespace adbcpq