From 4ff7bb2b0591697a76d90f4a39b9e09de23ed7db Mon Sep 17 00:00:00 2001 From: Lubo Slivka Date: Sat, 3 Feb 2024 19:44:56 +0100 Subject: [PATCH] feat(c/driver/postgresql): customize numeric conversion - introduces statement-level option `adbc.postgresql.numeric_conversion` - the option is used to tell the result reader what strategy to use when converting numeric values to Arrow data; since this cannot be done one-to-one, the reader has to convert to another data type - clients can use this option to specify the strategy - value can be either `to_string` or `to_double` - when not specified defaults to `to_string` - `to_string` -> numerics are converted losslessly to their string representation - `to_double` -> numerics are converted to double (with possible loss of precision) --- .../copy/postgres_copy_reader_test.cc | 111 ++++++- c/driver/postgresql/copy/reader.h | 277 +++++++++++++----- c/driver/postgresql/postgres_type.h | 20 +- c/driver/postgresql/statement.cc | 15 +- c/driver/postgresql/statement.h | 7 + 5 files changed, 346 insertions(+), 84 deletions(-) diff --git a/c/driver/postgresql/copy/postgres_copy_reader_test.cc b/c/driver/postgresql/copy/postgres_copy_reader_test.cc index 0d85c256ec..c1656c17a2 100644 --- a/c/driver/postgresql/copy/postgres_copy_reader_test.cc +++ b/c/driver/postgresql/copy/postgres_copy_reader_test.cc @@ -16,6 +16,7 @@ // under the License. 
#include +#include #include #include "postgres_copy_test_common.h" @@ -25,8 +26,11 @@ namespace adbcpq { class PostgresCopyStreamTester { public: - ArrowErrorCode Init(const PostgresType& root_type, ArrowError* error = nullptr) { - NANOARROW_RETURN_NOT_OK(reader_.Init(root_type)); + ArrowErrorCode Init( + const PostgresType& root_type, + NumericConversionStrategy numeric_conversion = NumericConversionStrategy::kToString, + ArrowError* error = nullptr) { + NANOARROW_RETURN_NOT_OK(reader_.Init(root_type, numeric_conversion)); NANOARROW_RETURN_NOT_OK(reader_.InferOutputSchema(error)); NANOARROW_RETURN_NOT_OK(reader_.InitFieldReaders(error)); return NANOARROW_OK; @@ -373,6 +377,59 @@ TEST(PostgresCopyUtilsTest, PostgresCopyReadNumeric) { EXPECT_EQ(std::string(item.data, item.size_bytes), "inf"); } +TEST(PostgresCopyUtilsTest, PostgresCopyReadNumericToDouble) { + ArrowBufferView data; + data.data.as_uint8 = kTestPgCopyNumeric; + data.size_bytes = sizeof(kTestPgCopyNumeric); + + auto col_type = PostgresType(PostgresTypeId::kNumeric); + PostgresType input_type(PostgresTypeId::kRecord); + input_type.AppendChild("col", col_type); + + PostgresCopyStreamTester tester; + ASSERT_EQ(tester.Init(input_type, NumericConversionStrategy::kToDouble), NANOARROW_OK); + ASSERT_EQ(tester.ReadAll(&data), ENODATA); + ASSERT_EQ(data.data.as_uint8 - kTestPgCopyNumeric, sizeof(kTestPgCopyNumeric)); + ASSERT_EQ(data.size_bytes, 0); + + nanoarrow::UniqueArray array; + ASSERT_EQ(tester.GetArray(array.get()), NANOARROW_OK); + ASSERT_EQ(array->length, 9); + ASSERT_EQ(array->n_children, 1); + + nanoarrow::UniqueSchema schema; + tester.GetSchema(schema.get()); + + nanoarrow::UniqueArrayView array_view; + ASSERT_EQ(ArrowArrayViewInitFromSchema(array_view.get(), schema.get(), nullptr), + NANOARROW_OK); + ASSERT_EQ(array_view->children[0]->storage_type, NANOARROW_TYPE_DOUBLE); + ASSERT_EQ(ArrowArrayViewSetArray(array_view.get(), array.get(), nullptr), NANOARROW_OK); + + auto validity = 
reinterpret_cast(array->children[0]->buffers[0]); + auto data_buffer = reinterpret_cast(array->children[0]->buffers[1]); + ASSERT_NE(validity, nullptr); + ASSERT_NE(data_buffer, nullptr); + ASSERT_TRUE(ArrowBitGet(validity, 0)); + ASSERT_TRUE(ArrowBitGet(validity, 1)); + ASSERT_TRUE(ArrowBitGet(validity, 2)); + ASSERT_TRUE(ArrowBitGet(validity, 3)); + ASSERT_TRUE(ArrowBitGet(validity, 4)); + ASSERT_TRUE(ArrowBitGet(validity, 5)); + ASSERT_TRUE(ArrowBitGet(validity, 6)); + ASSERT_TRUE(ArrowBitGet(validity, 7)); + ASSERT_FALSE(ArrowBitGet(validity, 8)); + + ASSERT_DOUBLE_EQ(data_buffer[0], 1000000); + ASSERT_DOUBLE_EQ(data_buffer[1], 0.00001234); + ASSERT_DOUBLE_EQ(data_buffer[2], 1.0); + ASSERT_DOUBLE_EQ(data_buffer[3], -123.456); + ASSERT_DOUBLE_EQ(data_buffer[4], 123.456); + ASSERT_TRUE(std::isnan(data_buffer[5])); + ASSERT_TRUE(data_buffer[6] == -std::numeric_limits::infinity()); + ASSERT_TRUE(data_buffer[7] == std::numeric_limits::infinity()); +} + TEST(PostgresCopyUtilsTest, PostgresCopyReadNumeric16_10) { ArrowBufferView data; data.data.as_uint8 = kTestPgCopyNumeric16_10; @@ -427,6 +484,56 @@ TEST(PostgresCopyUtilsTest, PostgresCopyReadNumeric16_10) { EXPECT_EQ(std::string(item.data, item.size_bytes), "nan"); } +TEST(PostgresCopyUtilsTest, PostgresCopyReadNumeric16_10ToDouble) { + ArrowBufferView data; + data.data.as_uint8 = kTestPgCopyNumeric16_10; + data.size_bytes = sizeof(kTestPgCopyNumeric16_10); + + auto col_type = PostgresType(PostgresTypeId::kNumeric); + PostgresType input_type(PostgresTypeId::kRecord); + input_type.AppendChild("col", col_type); + + PostgresCopyStreamTester tester; + ASSERT_EQ(tester.Init(input_type, NumericConversionStrategy::kToDouble), NANOARROW_OK); + ASSERT_EQ(tester.ReadAll(&data), ENODATA); + ASSERT_EQ(data.data.as_uint8 - kTestPgCopyNumeric16_10, + sizeof(kTestPgCopyNumeric16_10)); + ASSERT_EQ(data.size_bytes, 0); + + nanoarrow::UniqueArray array; + ASSERT_EQ(tester.GetArray(array.get()), NANOARROW_OK); + 
ASSERT_EQ(array->length, 7); + ASSERT_EQ(array->n_children, 1); + + nanoarrow::UniqueSchema schema; + tester.GetSchema(schema.get()); + + nanoarrow::UniqueArrayView array_view; + ASSERT_EQ(ArrowArrayViewInitFromSchema(array_view.get(), schema.get(), nullptr), + NANOARROW_OK); + ASSERT_EQ(array_view->children[0]->storage_type, NANOARROW_TYPE_DOUBLE); + ASSERT_EQ(ArrowArrayViewSetArray(array_view.get(), array.get(), nullptr), NANOARROW_OK); + + auto validity = reinterpret_cast(array->children[0]->buffers[0]); + auto data_buffer = reinterpret_cast(array->children[0]->buffers[1]); + ASSERT_NE(validity, nullptr); + ASSERT_NE(data_buffer, nullptr); + ASSERT_TRUE(ArrowBitGet(validity, 0)); + ASSERT_TRUE(ArrowBitGet(validity, 1)); + ASSERT_TRUE(ArrowBitGet(validity, 2)); + ASSERT_TRUE(ArrowBitGet(validity, 3)); + ASSERT_TRUE(ArrowBitGet(validity, 4)); + ASSERT_TRUE(ArrowBitGet(validity, 5)); + ASSERT_FALSE(ArrowBitGet(validity, 6)); + + ASSERT_DOUBLE_EQ(data_buffer[0], 0.0); + ASSERT_DOUBLE_EQ(data_buffer[1], 1.01234); + ASSERT_DOUBLE_EQ(data_buffer[2], 1.0123456789); + ASSERT_DOUBLE_EQ(data_buffer[3], -1.0123400000); + ASSERT_DOUBLE_EQ(data_buffer[4], -1.0123456789); + ASSERT_TRUE(std::isnan(data_buffer[5])); +} + TEST(PostgresCopyUtilsTest, PostgresCopyReadTimestamp) { ArrowBufferView data; data.data.as_uint8 = kTestPgCopyTimestamp; diff --git a/c/driver/postgresql/copy/reader.h b/c/driver/postgresql/copy/reader.h index c3c9acb326..1035e56d43 100644 --- a/c/driver/postgresql/copy/reader.h +++ b/c/driver/postgresql/copy/reader.h @@ -18,7 +18,9 @@ #pragma once #include +#include #include +#include #include #include #include @@ -242,10 +244,9 @@ class PostgresCopyIntervalFieldReader : public PostgresCopyFieldReader { } }; -// // Converts COPY resulting from the Postgres NUMERIC type into a string. 
-// Rewritten based on the Postgres implementation of NUMERIC cast to string in -// src/backend/utils/adt/numeric.c : get_str_from_var() (Note that in the initial source, -// DEC_DIGITS is always 4 and DBASE is always 10000). +// Base class for readers of Postgres NUMERIC type. Code in this class provides +// common utility methods that are useful for both conversion of NUMERIC to +// Arrow string and NUMERIC to Arrow double. // // Briefly, the Postgres representation of "numeric" is an array of int16_t ("digits") // from most significant to least significant. Each "digit" is a value between 0000 and @@ -253,15 +254,48 @@ class PostgresCopyIntervalFieldReader : public PostgresCopyFieldReader { // decimal point. Both of those values can be zero or negative. A "sign" component // encodes the positive or negativeness of the value and is also used to encode special // values (inf, -inf, and nan). +// +// The methods implemented here are responsible for reading input data and preparing the +// string representation of the value. +// +// The conversion methods are rewritten based on the Postgres implementation of +// NUMERIC cast to string in src/backend/utils/adt/numeric.c : get_str_from_var() ( +// Note that in the initial source, DEC_DIGITS is always 4 and DBASE is always 10000). 
class PostgresCopyNumericFieldReader : public PostgresCopyFieldReader { - public: - ArrowErrorCode Read(ArrowBufferView* data, int32_t field_size_bytes, ArrowArray* array, - ArrowError* error) override { - // -1 for NULL - if (field_size_bytes < 0) { - return ArrowArrayAppendNull(array, 1); - } + protected: + // Number of decimal digits per Postgres digit + static const int kDecDigits = 4; + // The "base" of the Postgres representation (i.e., each "digit" is 0 to 9999) + static const int kNBase = 10000; + // Valid values for the sign component + static const uint16_t kNumericPos = 0x0000; + static const uint16_t kNumericNeg = 0x4000; + static const uint16_t kNumericNAN = 0xC000; + static const uint16_t kNumericPinf = 0xD000; + static const uint16_t kNumericNinf = 0xF000; + + int16_t ndigits_; + int16_t weight_; + uint16_t sign_; + uint16_t dscale_; + std::vector digits_; + // Returns maximum number of characters required to hold + // string representation of NUMERIC value. + int64_t max_chars_required_() const { + int64_t max_chars_required = std::max(1, (weight_ + 1) * kDecDigits); + max_chars_required += dscale_ + kDecDigits + 2; + + return max_chars_required; + } + + // Reads all data for a single NUMERIC value. + // + // If the input has issues, returns non-zero error code and sets the + // Arrow error. + // + // On success, populates ndigits_, weight_, sign_, dscale_ and digits_. 
+ ArrowErrorCode ReadInputDigit(ArrowBufferView* data, ArrowError* error) { // Read the input if (data->size_bytes < static_cast(4 * sizeof(int16_t))) { ArrowErrorSet(error, @@ -272,64 +306,44 @@ class PostgresCopyNumericFieldReader : public PostgresCopyFieldReader { return EINVAL; } - int16_t ndigits = ReadUnsafe(data); - int16_t weight = ReadUnsafe(data); - uint16_t sign = ReadUnsafe(data); - uint16_t dscale = ReadUnsafe(data); + ndigits_ = ReadUnsafe(data); + weight_ = ReadUnsafe(data); + sign_ = ReadUnsafe(data); + dscale_ = ReadUnsafe(data); - if (data->size_bytes < static_cast(ndigits * sizeof(int16_t))) { + if (data->size_bytes < static_cast(ndigits_ * sizeof(int16_t))) { ArrowErrorSet(error, "Expected at least %d bytes of field data for numeric digits copy " "data but only %d bytes of input remain", - static_cast(ndigits * sizeof(int16_t)), + static_cast(ndigits_ * sizeof(int16_t)), static_cast(data->size_bytes)); // NOLINT(runtime/int) return EINVAL; } digits_.clear(); - for (int16_t i = 0; i < ndigits; i++) { + for (int16_t i = 0; i < ndigits_; i++) { digits_.push_back(ReadUnsafe(data)); } - // Handle special values - std::string special_value; - switch (sign) { - case kNumericNAN: - special_value = std::string("nan"); - break; - case kNumericPinf: - special_value = std::string("inf"); - break; - case kNumericNinf: - special_value = std::string("-inf"); - break; - case kNumericPos: - case kNumericNeg: - special_value = std::string(""); - break; - default: - ArrowErrorSet(error, - "Unexpected value for sign read from Postgres numeric field: %d", - static_cast(sign)); - return EINVAL; - } - - if (!special_value.empty()) { - NANOARROW_RETURN_NOT_OK( - ArrowBufferAppend(data_, special_value.data(), special_value.size())); - NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(offsets_, data_->size_bytes)); - return AppendValid(array); - } + return 0; + } - // Calculate string space requirement - int64_t max_chars_required = std::max(1, (weight + 1) * kDecDigits); - 
max_chars_required += dscale + kDecDigits + 2; - NANOARROW_RETURN_NOT_OK(ArrowBufferReserve(data_, max_chars_required)); - char* out0 = reinterpret_cast(data_->data + data_->size_bytes); - char* out = out0; + // Converts NUMERIC value to string. The result is written to the target and will + // not be null-terminated. + // + // This code is a rewrite of PostgreSQL function get_str_from_var() found in + // src/backend/utils/adt/numeric.c. + // + // This method has two assumptions: + // + // - the target buffer is allocated and large enough to hold the result + // - the NUMERIC value is non-special value (e.g. not +/- infinity or NaN) + int64_t DigitsToString(char** target) { + char* out = *target; + char* out0 = *target; // Build output string in-place, starting with the negative sign - if (sign == kNumericNeg) { + if (sign_ == kNumericNeg) { *out++ = '-'; } @@ -338,12 +352,12 @@ class PostgresCopyNumericFieldReader : public PostgresCopyFieldReader { int d1; int16_t dig; - if (weight < 0) { - d = weight + 1; + if (weight_ < 0) { + d = weight_ + 1; *out++ = '0'; } else { - for (d = 0; d <= weight; d++) { - if (d < ndigits) { + for (d = 0; d <= weight_; d++) { + if (d < ndigits_) { dig = digits_[d]; } else { dig = 0; @@ -370,12 +384,12 @@ class PostgresCopyNumericFieldReader : public PostgresCopyFieldReader { // keep here. int64_t actual_chars_required = out - out0; - if (dscale > 0) { + if (dscale_ > 0) { *out++ = '.'; - actual_chars_required += dscale + 1; + actual_chars_required += dscale_ + 1; - for (int i = 0; i < dscale; d++, i += kDecDigits) { - if (d >= 0 && d < ndigits) { + for (int i = 0; i < dscale_; d++, i += kDecDigits) { + if (d >= 0 && d < ndigits_) { dig = digits_[d]; } else { dig = 0; @@ -391,25 +405,126 @@ class PostgresCopyNumericFieldReader : public PostgresCopyFieldReader { } } + return actual_chars_required; + } +}; + +// Converts COPY resulting from the Postgres NUMERIC type into a string. 
+class PostgresCopyNumericToStrFieldReader : public PostgresCopyNumericFieldReader { + public: + ArrowErrorCode Read(ArrowBufferView* data, int32_t field_size_bytes, ArrowArray* array, + ArrowError* error) override { + // -1 for NULL + if (field_size_bytes < 0) { + return ArrowArrayAppendNull(array, 1); + } + + ArrowErrorCode digit_error = ReadInputDigit(data, error); + if (digit_error) { + return digit_error; + } + + // Handle special values + std::string special_value; + switch (sign_) { + case kNumericNAN: + special_value = std::string("nan"); + break; + case kNumericPinf: + special_value = std::string("inf"); + break; + case kNumericNinf: + special_value = std::string("-inf"); + break; + case kNumericPos: + case kNumericNeg: + special_value = std::string(""); + break; + default: + ArrowErrorSet(error, + "Unexpected value for sign read from Postgres numeric field: %d", + static_cast(sign_)); + return EINVAL; + } + + if (!special_value.empty()) { + NANOARROW_RETURN_NOT_OK( + ArrowBufferAppend(data_, special_value.data(), special_value.size())); + NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(offsets_, data_->size_bytes)); + return AppendValid(array); + } + + int64_t max_chars_required = max_chars_required_(); + NANOARROW_RETURN_NOT_OK(ArrowBufferReserve(data_, max_chars_required)); + char* out = reinterpret_cast(data_->data + data_->size_bytes); + + int64_t actual_chars_required = DigitsToString(&out); + // Update data buffer size and add offsets data_->size_bytes += actual_chars_required; NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(offsets_, data_->size_bytes)); return AppendValid(array); } +}; - private: - std::vector digits_; +// Converts COPY resulting from the Postgres NUMERIC type into a double. +// +// Similar to Postgres numericvar_to_double_no_overflow() method found in +// src/backend/utils/adt/numeric.c, this reader will first convert the NUMERIC +// to string and then use strtod() to get a double value. 
+class PostgresCopyNumericToDoubleFieldReader : public PostgresCopyNumericFieldReader {
+ public:
+  // Reads one NUMERIC field from the COPY data and appends it to the output
+  // as a double.
+  //
+  // - A negative field_size_bytes appends a null.
+  // - NaN and +/- infinity are mapped to the corresponding double values.
+  // - Finite values are rendered to decimal text and parsed with
+  //   std::from_chars(); magnitudes outside the double range saturate to
+  //   +/- infinity (too large) or +/- 0 (too small), which is what strtod()
+  //   would produce.
+  //
+  // Returns EINVAL and sets error if the input is malformed.
+  ArrowErrorCode Read(ArrowBufferView* data, int32_t field_size_bytes, ArrowArray* array,
+                      ArrowError* error) override {
+    // -1 for NULL
+    if (field_size_bytes < 0) {
+      return ArrowArrayAppendNull(array, 1);
+    }

-  // Number of decimal digits per Postgres digit
-  static const int kDecDigits = 4;
-  // The "base" of the Postgres representation (i.e., each "digit" is 0 to 9999)
-  static const int kNBase = 10000;
-  // Valid values for the sign component
-  static const uint16_t kNumericPos = 0x0000;
-  static const uint16_t kNumericNeg = 0x4000;
-  static const uint16_t kNumericNAN = 0xC000;
-  static const uint16_t kNumericPinf = 0xD000;
-  static const uint16_t kNumericNinf = 0xF000;
+    ArrowErrorCode digit_error = ReadInputDigit(data, error);
+    if (digit_error) {
+      return digit_error;
+    }
+
+    // Initialized up front so that no path can append an indeterminate value.
+    double value = 0.0;
+    bool special_value = false;
+    switch (sign_) {
+      case kNumericNAN:
+        value = std::numeric_limits<double>::quiet_NaN();
+        special_value = true;
+        break;
+      case kNumericPinf:
+        value = std::numeric_limits<double>::infinity();
+        special_value = true;
+        break;
+      case kNumericNinf:
+        value = -std::numeric_limits<double>::infinity();
+        special_value = true;
+        break;
+      case kNumericPos:
+      case kNumericNeg:
+        break;
+      default:
+        ArrowErrorSet(error,
+                      "Unexpected value for sign read from Postgres numeric field: %d",
+                      static_cast<int>(sign_));
+        return EINVAL;
+    }
+
+    if (special_value) {
+      NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(data_, &value, sizeof(double)));
+      return AppendValid(array);
+    }
+
+    // Render the decimal text into a scratch buffer that is released on every
+    // exit path. DigitsToString() writes through *target but does not advance
+    // it, so the text starts at scratch.data().
+    std::vector<char> scratch(static_cast<size_t>(max_chars_required_()));
+    char* target = scratch.data();
+    const int64_t n_chars = DigitsToString(&target);
+    const char* first = scratch.data();
+    const char* last = first + n_chars;
+
+    const std::from_chars_result parsed = std::from_chars(first, last, value);
+    if (parsed.ec == std::errc::result_out_of_range) {
+      // NUMERIC can hold magnitudes far outside the double range, and
+      // from_chars() does not store a result in that case. A non-negative
+      // weight means the value has an integer part, so the range error is an
+      // overflow; otherwise the value is a pure fraction that underflowed.
+      const double magnitude =
+          weight_ >= 0 ? std::numeric_limits<double>::infinity() : 0.0;
+      value = sign_ == kNumericNeg ? -magnitude : magnitude;
+    } else if (parsed.ec != std::errc() || parsed.ptr != last) {
+      ArrowErrorSet(error, "Failed to convert Postgres numeric field to double");
+      return EINVAL;
+    }
+
+    NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(data_, &value, sizeof(double)));
+    return AppendValid(array);
+  }
 };

 // Reader for Pg->Arrow conversions whose
Arrow representation is simply the @@ -761,6 +876,9 @@ static inline ArrowErrorCode MakeCopyFieldReader( case PostgresTypeId::kFloat8: *out = std::make_unique>(); return NANOARROW_OK; + case PostgresTypeId::kNumeric: + *out = std::make_unique(); + return NANOARROW_OK; default: return ErrorCantConvert(error, pg_type, schema_view); } @@ -776,7 +894,7 @@ static inline ArrowErrorCode MakeCopyFieldReader( *out = std::make_unique(); return NANOARROW_OK; case PostgresTypeId::kNumeric: - *out = std::make_unique(); + *out = std::make_unique(); return NANOARROW_OK; default: return ErrorCantConvert(error, pg_type, schema_view); @@ -885,7 +1003,8 @@ static inline ArrowErrorCode MakeCopyFieldReader( class PostgresCopyStreamReader { public: - ArrowErrorCode Init(PostgresType pg_type) { + ArrowErrorCode Init(PostgresType pg_type, NumericConversionStrategy numeric_conversion = + NumericConversionStrategy::kToString) { if (pg_type.type_id() != PostgresTypeId::kRecord) { return EINVAL; } @@ -893,6 +1012,8 @@ class PostgresCopyStreamReader { pg_type_ = std::move(pg_type); root_reader_.Init(pg_type_); array_size_approx_bytes_ = 0; + numeric_conversion_ = numeric_conversion; + return NANOARROW_OK; } @@ -924,7 +1045,8 @@ class PostgresCopyStreamReader { ArrowErrorCode InferOutputSchema(ArrowError* error) { schema_.reset(); ArrowSchemaInit(schema_.get()); - NANOARROW_RETURN_NOT_OK(root_reader_.InputType().SetSchema(schema_.get())); + NANOARROW_RETURN_NOT_OK( + root_reader_.InputType().SetSchema(schema_.get(), numeric_conversion_)); return NANOARROW_OK; } @@ -1023,6 +1145,7 @@ class PostgresCopyStreamReader { nanoarrow::UniqueSchema schema_; nanoarrow::UniqueArray array_; int64_t array_size_approx_bytes_; + NumericConversionStrategy numeric_conversion_ = NumericConversionStrategy::kToString; }; } // namespace adbcpq diff --git a/c/driver/postgresql/postgres_type.h b/c/driver/postgresql/postgres_type.h index dc5d38784e..1017721f62 100644 --- a/c/driver/postgresql/postgres_type.h +++ 
b/c/driver/postgresql/postgres_type.h @@ -130,6 +130,9 @@ static inline std::vector PostgresTypeIdAll(bool nested = true); class PostgresTypeResolver; +/// \brief Strategy to use when converting received NUMERIC values. +enum class NumericConversionStrategy { kToString, kToDouble }; + // An abstraction of a (potentially nested and/or parameterized) Postgres // data type. This class is where default type conversion to/from Arrow // is defined. It is intentionally copyable. @@ -191,7 +194,9 @@ class PostgresType { // do not have a corresponding Arrow type are returned as Binary with field // metadata ADBC:posgresql:typname. These types can be represented as their // binary COPY representation in the output. - ArrowErrorCode SetSchema(ArrowSchema* schema) const { + ArrowErrorCode SetSchema(ArrowSchema* schema, + NumericConversionStrategy numeric_conversion = + NumericConversionStrategy::kToString) const { switch (type_id_) { // ---- Primitive types -------------------- case PostgresTypeId::kBool: @@ -217,7 +222,12 @@ class PostgresType { // ---- Numeric/Decimal------------------- case PostgresTypeId::kNumeric: - NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(schema, NANOARROW_TYPE_STRING)); + if (numeric_conversion == NumericConversionStrategy::kToDouble) { + NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(schema, NANOARROW_TYPE_DOUBLE)); + } else { + NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(schema, NANOARROW_TYPE_STRING)); + } + NANOARROW_RETURN_NOT_OK(AddPostgresTypeMetadata(schema)); break; @@ -271,13 +281,15 @@ class PostgresType { case PostgresTypeId::kRecord: NANOARROW_RETURN_NOT_OK(ArrowSchemaSetTypeStruct(schema, n_children())); for (int64_t i = 0; i < n_children(); i++) { - NANOARROW_RETURN_NOT_OK(children_[i].SetSchema(schema->children[i])); + NANOARROW_RETURN_NOT_OK( + children_[i].SetSchema(schema->children[i], numeric_conversion)); } break; case PostgresTypeId::kArray: NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(schema, NANOARROW_TYPE_LIST)); - 
NANOARROW_RETURN_NOT_OK(children_[0].SetSchema(schema->children[0])); + NANOARROW_RETURN_NOT_OK( + children_[0].SetSchema(schema->children[0], numeric_conversion)); break; case PostgresTypeId::kUserDefined: diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc index c599832904..a96f9478cd 100644 --- a/c/driver/postgresql/statement.cc +++ b/c/driver/postgresql/statement.cc @@ -1361,6 +1361,10 @@ AdbcStatusCode PostgresStatement::GetOption(const char* key, char* value, size_t } } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) { result = std::to_string(reader_.batch_size_hint_bytes_); + } else if (std::strcmp(key, ADBC_POSTGRESQL_NUMERIC_CONVERSION) == 0) { + result = numeric_conversion_ == NumericConversionStrategy::kToDouble + ? ADBC_POSTGRESQL_NC_OPTION_TO_DOUBLE + : ADBC_POSTGRESQL_NC_OPTION_TO_STRING; } else { SetError(error, "[libpq] Unknown statement option '%s'", key); return ADBC_STATUS_NOT_FOUND; @@ -1480,6 +1484,15 @@ AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value, } this->reader_.batch_size_hint_bytes_ = int_value; + } else if (std::strcmp(key, ADBC_POSTGRESQL_NUMERIC_CONVERSION) == 0) { + if (std::strcmp(value, ADBC_POSTGRESQL_NC_OPTION_TO_STRING) == 0) { + numeric_conversion_ = NumericConversionStrategy::kToString; + } else if (std::strcmp(value, ADBC_POSTGRESQL_NC_OPTION_TO_DOUBLE) == 0) { + numeric_conversion_ = NumericConversionStrategy::kToDouble; + } else { + SetError(error, "[libpq] Invalid value '%s' for option '%s'", value, key); + return ADBC_STATUS_INVALID_ARGUMENT; + } } else { SetError(error, "[libpq] Unknown statement option '%s'", key); return ADBC_STATUS_NOT_IMPLEMENTED; @@ -1548,7 +1561,7 @@ AdbcStatusCode PostgresStatement::SetupReader(struct AdbcError* error) { // Initialize the copy reader and infer the output schema (i.e., error for // unsupported types before issuing the COPY query) reader_.copy_reader_.reset(new PostgresCopyStreamReader()); - 
reader_.copy_reader_->Init(root_type);
+  reader_.copy_reader_->Init(root_type, numeric_conversion_);
   struct ArrowError na_error;
   int na_res = reader_.copy_reader_->InferOutputSchema(&na_error);
   if (na_res != NANOARROW_OK) {
diff --git a/c/driver/postgresql/statement.h b/c/driver/postgresql/statement.h
index d469ca112a..dbe4399d99 100644
--- a/c/driver/postgresql/statement.h
+++ b/c/driver/postgresql/statement.h
@@ -33,6 +33,12 @@
 #define ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES \
   "adbc.postgresql.batch_size_hint_bytes"

+#define ADBC_POSTGRESQL_NUMERIC_CONVERSION "adbc.postgresql.numeric_conversion"
+
+#define ADBC_POSTGRESQL_NC_OPTION_TO_STRING "to_string"
+
+#define ADBC_POSTGRESQL_NC_OPTION_TO_DOUBLE "to_double"
+
 namespace adbcpq {
 class PostgresConnection;
 class PostgresStatement;
@@ -162,5 +168,8 @@ class PostgresStatement {
   } ingest_;

   TupleReader reader_;
+  // How NUMERIC columns are converted when reading results. Default-initialized
+  // so GetOption()/SetupReader() never read an indeterminate value.
+  NumericConversionStrategy numeric_conversion_ = NumericConversionStrategy::kToString;
 };
 }  // namespace adbcpq