Skip to content

Commit

Permalink
feat(c/driver): Date32 support (#948)
Browse files Browse the repository at this point in the history
  • Loading branch information
WillAyd authored Aug 1, 2023
1 parent 2de52f3 commit 995a02d
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 31 deletions.
6 changes: 3 additions & 3 deletions c/driver/postgresql/postgresql_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -576,9 +576,9 @@ class PostgresStatementTest : public ::testing::Test,
}

protected:
void ValidateIngestedTemporalData(struct ArrowArrayView* values,
enum ArrowTimeUnit unit,
const char* timezone) override {
void ValidateIngestedTimestampData(struct ArrowArrayView* values,
enum ArrowTimeUnit unit,
const char* timezone) override {
std::vector<std::optional<int64_t>> expected;
switch (unit) {
case (NANOARROW_TIME_UNIT_SECOND):
Expand Down
23 changes: 23 additions & 0 deletions c/driver/postgresql/statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,10 @@ struct BindStream {
type_id = PostgresTypeId::kBytea;
param_lengths[i] = 0;
break;
case ArrowType::NANOARROW_TYPE_DATE32:
type_id = PostgresTypeId::kDate;
param_lengths[i] = 4;
break;
case ArrowType::NANOARROW_TYPE_TIMESTAMP:
type_id = PostgresTypeId::kTimestamp;
param_lengths[i] = 8;
Expand Down Expand Up @@ -389,6 +393,22 @@ struct BindStream {
param_values[col] = const_cast<char*>(view.data.as_char);
break;
}
case ArrowType::NANOARROW_TYPE_DATE32: {
  // The bound value is rebased so that day 0 is 2000-01-01, which is
  // 10957 days after 1970-01-01.
  constexpr int32_t kPostgresDateEpoch = 10957;
  const int32_t raw_value =
      array_view->children[col]->buffer_views[1].data.as_int32[row];
  // Guard the subtraction below: anything smaller than this would wrap
  // around int32 once the epoch offset is removed.
  if (raw_value < INT32_MIN + kPostgresDateEpoch) {
    // The trailing message carries its own leading space; without it the
    // row number and the text run together ("Row #3has value ...").
    SetError(error, "[libpq] Field #%" PRId64 "%s%s%s%" PRId64 "%s", col + 1,
             "('", bind_schema->children[col]->name, "') Row #", row + 1,
             " has value which exceeds postgres date limits");
    return ADBC_STATUS_INVALID_ARGUMENT;
  }

  // Convert to network byte order and copy into this column's parameter
  // buffer (4 bytes, matching sizeof(int32_t)).
  const uint32_t value = ToNetworkInt32(raw_value - kPostgresDateEpoch);
  std::memcpy(param_values[col], &value, sizeof(int32_t));
  break;
}
case ArrowType::NANOARROW_TYPE_TIMESTAMP: {
int64_t val = array_view->children[col]->buffer_views[1].data.as_int64[row];

Expand Down Expand Up @@ -801,6 +821,9 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
case ArrowType::NANOARROW_TYPE_BINARY:
create += " BYTEA";
break;
case ArrowType::NANOARROW_TYPE_DATE32:
create += " DATE";
break;
case ArrowType::NANOARROW_TYPE_TIMESTAMP:
if (strcmp("", source_schema_fields[i].timezone)) {
create += " TIMESTAMPTZ";
Expand Down
6 changes: 3 additions & 3 deletions c/driver/snowflake/snowflake_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,9 @@ class SnowflakeStatementTest : public ::testing::Test,
}

protected:
void ValidateIngestedTemporalData(struct ArrowArrayView* values,
enum ArrowTimeUnit unit,
const char* timezone) override {
void ValidateIngestedTimestampData(struct ArrowArrayView* values,
enum ArrowTimeUnit unit,
const char* timezone) override {
std::vector<std::optional<int64_t>> expected;
switch (unit) {
case NANOARROW_TIME_UNIT_SECOND:
Expand Down
7 changes: 4 additions & 3 deletions c/driver/sqlite/sqlite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class SqliteQuirks : public adbc_validation::DriverQuirks {
case NANOARROW_TYPE_FLOAT:
case NANOARROW_TYPE_DOUBLE:
return NANOARROW_TYPE_DOUBLE;
case NANOARROW_TYPE_DATE32:
case NANOARROW_TYPE_TIMESTAMP:
return NANOARROW_TYPE_STRING;
default:
Expand Down Expand Up @@ -200,9 +201,9 @@ class SqliteStatementTest : public ::testing::Test,
}

protected:
void ValidateIngestedTemporalData(struct ArrowArrayView* values,
enum ArrowTimeUnit unit,
const char* timezone) override {
void ValidateIngestedTimestampData(struct ArrowArrayView* values,
enum ArrowTimeUnit unit,
const char* timezone) override {
std::vector<std::optional<std::string>> expected;
switch (unit) {
case (NANOARROW_TIME_UNIT_SECOND):
Expand Down
75 changes: 75 additions & 0 deletions c/driver/sqlite/statement_reader.c
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,59 @@ AdbcStatusCode AdbcSqliteBinderSetArrayStream(struct AdbcSqliteBinder* binder,
return AdbcSqliteBinderSet(binder, error);
}

#define SECONDS_PER_DAY 86400

/*
  Formats a DATE32 value (a day count, converted to seconds and interpreted
  by gmtime) as a "YYYY-MM-DD" string.

  value: number of days relative to the time_t epoch.
  buf:   on success receives a newly malloc'd, NUL-terminated string.
         Caller is responsible for freeing.
  error: populated on failure.

  Returns ADBC_STATUS_OK on success. On failure returns
  ADBC_STATUS_INVALID_ARGUMENT (value not representable / not convertible) or
  ADBC_STATUS_IO (allocation failure); the contents of buf are then undefined.
*/
static AdbcStatusCode ArrowDate32ToIsoString(int32_t value, char** buf,
                                             struct AdbcError* error) {
  /* Length of "YYYY-MM-DD"; deliberately not named `strlen` so the library
     function is not shadowed. */
  const size_t iso_date_len = 10;

#if SIZEOF_TIME_T < 8
  /* Narrow time_t: reject day counts whose second count would not fit in
     32 bits. */
  if ((value > INT32_MAX / SECONDS_PER_DAY) ||
      (value < INT32_MIN / SECONDS_PER_DAY)) {
    SetError(error, "Date %" PRId32 " exceeds platform time_t bounds", value);

    return ADBC_STATUS_INVALID_ARGUMENT;
  }
#endif

  /* Widen to time_t *before* multiplying; `value * SECONDS_PER_DAY` on its
     own is evaluated in int and overflows past roughly +/-24855 days. */
  const time_t seconds = (time_t)value * SECONDS_PER_DAY;

  struct tm broken_down_time;

#if defined(_WIN32)
  if (gmtime_s(&broken_down_time, &seconds) != 0) {
    SetError(error, "Could not convert date %" PRId32 " to broken down time", value);

    return ADBC_STATUS_INVALID_ARGUMENT;
  }
#else
  if (gmtime_r(&seconds, &broken_down_time) != &broken_down_time) {
    SetError(error, "Could not convert date %" PRId32 " to broken down time", value);

    return ADBC_STATUS_INVALID_ARGUMENT;
  }
#endif

  char* tsstr = (char*)malloc(iso_date_len + 1);
  if (tsstr == NULL) {
    SetError(error, "Could not allocate memory for date %" PRId32, value);
    return ADBC_STATUS_IO;
  }

  /* strftime returns 0 when the result (including the NUL) does not fit,
     e.g. for years that need more than four digits. */
  if (strftime(tsstr, iso_date_len + 1, "%Y-%m-%d", &broken_down_time) == 0) {
    SetError(error, "Call to strftime for date %" PRId32 " with failed", value);
    free(tsstr);
    return ADBC_STATUS_INVALID_ARGUMENT;
  }

  *buf = tsstr;
  return ADBC_STATUS_OK;
}

/*
Allocates to buf on success. Caller is responsible for freeing.
On failure sets error and contents of buf are undefined.
Expand Down Expand Up @@ -300,6 +353,28 @@ AdbcStatusCode AdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, sqlite3
SQLITE_STATIC);
break;
}
// DATE32 values are bound to SQLite as text produced by
// ArrowDate32ToIsoString.
case NANOARROW_TYPE_DATE32: {
// The accessor returns a 64-bit integer; it is range-checked below before
// being narrowed to int32_t.
int64_t value =
ArrowArrayViewGetIntUnsafe(binder->batch.children[col], binder->next_row);
// Receives a heap-allocated string from ArrowDate32ToIsoString; freed below.
char* tsstr;

if ((value > INT32_MAX) || (value < INT32_MIN)) {
SetError(error,
"Column %d has value %" PRId64
" which exceeds the expected range "
"for an Arrow DATE32 type",
col, value);
return ADBC_STATUS_INVALID_DATA;
}

// RAISE_ADBC presumably returns early on a non-OK status, so tsstr is only
// used after a successful conversion -- confirm against the macro definition.
RAISE_ADBC(ArrowDate32ToIsoString((int32_t)value, &tsstr, error));
// SQLITE_TRANSIENT ensures the value is copied during bind
status =
sqlite3_bind_text(stmt, col + 1, tsstr, strlen(tsstr), SQLITE_TRANSIENT);

// The string can be released right away because the bind above copied it.
free(tsstr);
break;
}
case NANOARROW_TYPE_TIMESTAMP: {
struct ArrowSchemaView bind_schema_view;
RAISE_ADBC(ArrowSchemaViewInit(&bind_schema_view, binder->schema.children[col],
Expand Down
1 change: 1 addition & 0 deletions c/driver_manager/adbc_driver_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ class SqliteStatementTest : public ::testing::Test,

void TestSqlIngestUInt64() { GTEST_SKIP() << "Cannot ingest UINT64 (out of range)"; }
void TestSqlIngestBinary() { GTEST_SKIP() << "Cannot ingest BINARY (not implemented)"; }
void TestSqlIngestDate32() { GTEST_SKIP() << "Cannot ingest DATE (not implemented)"; }
void TestSqlIngestTimestamp() {
GTEST_SKIP() << "Cannot ingest TIMESTAMP (not implemented)";
}
Expand Down
45 changes: 27 additions & 18 deletions c/validation/adbc_validation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,10 @@ void StatementTest::TestSqlIngestNumericType(ArrowType type) {
// values. Likely a bug on our side, but for now, avoid them.
values.push_back(static_cast<CType>(-1.5));
values.push_back(static_cast<CType>(1.5));
} else if (type == ArrowType::NANOARROW_TYPE_DATE32) {
// Windows does not seem to support negative date values
values.push_back(static_cast<CType>(0));
values.push_back(static_cast<CType>(42));
} else {
values.push_back(std::numeric_limits<CType>::lowest());
values.push_back(std::numeric_limits<CType>::max());
Expand Down Expand Up @@ -1095,8 +1099,12 @@ void StatementTest::TestSqlIngestBinary() {
NANOARROW_TYPE_BINARY, {std::nullopt, "", "\x00\x01\x02\x04", "\xFE\xFF"}));
}

void StatementTest::TestSqlIngestDate32() {
ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType<int32_t>(NANOARROW_TYPE_DATE32));
}

template <enum ArrowTimeUnit TU>
void StatementTest::TestSqlIngestTemporalType(const char* timezone) {
void StatementTest::TestSqlIngestTimestampType(const char* timezone) {
if (!quirks()->supports_bulk_ingest()) {
GTEST_SKIP();
}
Expand Down Expand Up @@ -1155,7 +1163,7 @@ void StatementTest::TestSqlIngestTemporalType(const char* timezone) {
ASSERT_EQ(values.size(), reader.array->length);
ASSERT_EQ(1, reader.array->n_children);

ValidateIngestedTemporalData(reader.array_view->children[0], TU, timezone);
ValidateIngestedTimestampData(reader.array_view->children[0], TU, timezone);

ASSERT_NO_FATAL_FAILURE(reader.Next());
ASSERT_EQ(nullptr, reader.array->release);
Expand All @@ -1164,33 +1172,34 @@ void StatementTest::TestSqlIngestTemporalType(const char* timezone) {
ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
}

void StatementTest::ValidateIngestedTemporalData(struct ArrowArrayView* values,
enum ArrowTimeUnit unit,
const char* timezone) {
FAIL() << "ValidateIngestedTemporalData is not implemented in the base class";
void StatementTest::ValidateIngestedTimestampData(struct ArrowArrayView* values,
enum ArrowTimeUnit unit,
const char* timezone) {
FAIL() << "ValidateIngestedTimestampData is not implemented in the base class";
}

void StatementTest::TestSqlIngestTimestamp() {
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_SECOND>(nullptr));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MILLI>(nullptr));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MICRO>(nullptr));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_NANO>(nullptr));
ASSERT_NO_FATAL_FAILURE(
TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_SECOND>(nullptr));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_MILLI>(nullptr));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_MICRO>(nullptr));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_NANO>(nullptr));
}

void StatementTest::TestSqlIngestTimestampTz() {
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_SECOND>("UTC"));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MILLI>("UTC"));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MICRO>("UTC"));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_NANO>("UTC"));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_SECOND>("UTC"));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_MILLI>("UTC"));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_MICRO>("UTC"));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_NANO>("UTC"));

ASSERT_NO_FATAL_FAILURE(
TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_SECOND>("America/Los_Angeles"));
TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_SECOND>("America/Los_Angeles"));
ASSERT_NO_FATAL_FAILURE(
TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MILLI>("America/Los_Angeles"));
TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_MILLI>("America/Los_Angeles"));
ASSERT_NO_FATAL_FAILURE(
TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MICRO>("America/Los_Angeles"));
TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_MICRO>("America/Los_Angeles"));
ASSERT_NO_FATAL_FAILURE(
TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_NANO>("America/Los_Angeles"));
TestSqlIngestTimestampType<NANOARROW_TIME_UNIT_NANO>("America/Los_Angeles"));
}

void StatementTest::TestSqlIngestInterval() {
Expand Down
10 changes: 6 additions & 4 deletions c/validation/adbc_validation.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ class StatementTest {
void TestSqlIngestBinary();

// Temporal
void TestSqlIngestDate32();
void TestSqlIngestTimestamp();
void TestSqlIngestTimestampTz();
void TestSqlIngestInterval();
Expand Down Expand Up @@ -277,11 +278,11 @@ class StatementTest {
void TestSqlIngestNumericType(ArrowType type);

template <enum ArrowTimeUnit TU>
void TestSqlIngestTemporalType(const char* timezone);
void TestSqlIngestTimestampType(const char* timezone);

virtual void ValidateIngestedTemporalData(struct ArrowArrayView* values,
enum ArrowTimeUnit unit,
const char* timezone);
virtual void ValidateIngestedTimestampData(struct ArrowArrayView* values,
enum ArrowTimeUnit unit,
const char* timezone);
};

#define ADBCV_TEST_STATEMENT(FIXTURE) \
Expand All @@ -301,6 +302,7 @@ class StatementTest {
TEST_F(FIXTURE, SqlIngestFloat64) { TestSqlIngestFloat64(); } \
TEST_F(FIXTURE, SqlIngestString) { TestSqlIngestString(); } \
TEST_F(FIXTURE, SqlIngestBinary) { TestSqlIngestBinary(); } \
TEST_F(FIXTURE, SqlIngestDate32) { TestSqlIngestDate32(); } \
TEST_F(FIXTURE, SqlIngestTimestamp) { TestSqlIngestTimestamp(); } \
TEST_F(FIXTURE, SqlIngestTimestampTz) { TestSqlIngestTimestampTz(); } \
TEST_F(FIXTURE, SqlIngestInterval) { TestSqlIngestInterval(); } \
Expand Down

0 comments on commit 995a02d

Please sign in to comment.