diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go index 968ca942d3..eb9f2a2994 100644 --- a/go/adbc/driver/snowflake/driver_test.go +++ b/go/adbc/driver/snowflake/driver_test.go @@ -2006,3 +2006,28 @@ func (suite *SnowflakeTests) TestJwtPrivateKey() { defer os.Remove(binKey) verifyKey(binKey) } + +func (suite *SnowflakeTests) TestMetadataOnlyQuery() { + // force more than one chunk for `SHOW FUNCTIONS` which will return + // JSON data instead of arrow, even though we ask for Arrow + suite.Require().NoError(suite.stmt.SetSqlQuery(`ALTER SESSION SET CLIENT_RESULT_CHUNK_SIZE = 50`)) + _, err := suite.stmt.ExecuteUpdate(suite.ctx) + suite.Require().NoError(err) + + // since we lowered the CLIENT_RESULT_CHUNK_SIZE this will return at least + // 1 chunk in addition to the first one. Metadata queries will return JSON + // no matter what currently. + suite.Require().NoError(suite.stmt.SetSqlQuery(`SHOW FUNCTIONS`)) + rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx) + suite.Require().NoError(err) + defer rdr.Release() + + recv := int64(0) + for rdr.Next() { + recv += rdr.Record().NumRows() + } + + // verify that we got the exepected number of rows if we sum up + // all the rows from each record in the stream. + suite.Equal(n, recv) +} diff --git a/go/adbc/driver/snowflake/record_reader.go b/go/adbc/driver/snowflake/record_reader.go index acf86bd0d7..1a24b91d96 100644 --- a/go/adbc/driver/snowflake/record_reader.go +++ b/go/adbc/driver/snowflake/record_reader.go @@ -18,9 +18,12 @@ package snowflake import ( + "bytes" "context" "encoding/hex" + "encoding/json" "fmt" + "io" "math" "strconv" "strings" @@ -300,7 +303,7 @@ func integerToDecimal128(ctx context.Context, a arrow.Array, dt *arrow.Decimal12 return result, err } -func rowTypesToArrowSchema(ctx context.Context, ld gosnowflake.ArrowStreamLoader, useHighPrecision bool) (*arrow.Schema, error) { +func rowTypesToArrowSchema(_ context.Context, ld gosnowflake.ArrowStreamLoader, useHighPrecision bool) (*arrow.Schema, error) { var loc *time.Location metadata := ld.RowTypes() @@ -360,8 +363,7 @@ func extractTimestamp(src *string) (sec, nsec int64, err error) { return } -func jsonDataToArrow(ctx context.Context, bldr *array.RecordBuilder, ld gosnowflake.ArrowStreamLoader) (arrow.Record, error) { - rawData := ld.JSONData() +func jsonDataToArrow(_ context.Context, bldr *array.RecordBuilder, rawData [][]*string) (arrow.Record, error) { fieldBuilders := bldr.Fields() for _, rec := range rawData { for i, col := range rec { @@ -471,7 +473,12 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake return nil, errToAdbcErr(adbc.StatusInternal, err) } - if len(batches) == 0 { + // if the first chunk was JSON, that means this was a metadata query which + // is only returning JSON data rather than Arrow + rawData := ld.JSONData() + if len(rawData) > 0 { + // construct an Arrow schema based on reading the JSON metadata description of the + // result type schema schema, err := rowTypesToArrowSchema(ctx, ld, useHighPrecision) if err != nil { return nil, adbc.Error{ @@ -480,20 +487,87 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake } } + if ld.TotalRows() == 0 { + return array.NewRecordReader(schema, []arrow.Record{}) + } + bldr := array.NewRecordBuilder(alloc, schema) defer bldr.Release() - rec, err := jsonDataToArrow(ctx, bldr, ld) + rec, err := jsonDataToArrow(ctx, bldr, rawData) if err != nil { return nil, err } defer rec.Release() - if ld.TotalRows() != 0 { - return array.NewRecordReader(schema, []arrow.Record{rec}) - } else { - return array.NewRecordReader(schema, []arrow.Record{}) + results := []arrow.Record{rec} + for _, b := range batches { + rdr, err := b.GetStream(ctx) + if err != nil { + return nil, adbc.Error{ + Msg: err.Error(), + Code: adbc.StatusInternal, + } + } + defer rdr.Close() + + // the "JSON" data returned isn't valid JSON. Instead it is a list of + // comma-delimited JSON lists containing every value as a string, except + // for a JSON null to represent nulls. Thus we can't just use the existing + // JSON parsing code in Arrow. + data, err := io.ReadAll(rdr) + if err != nil { + return nil, adbc.Error{ + Msg: err.Error(), + Code: adbc.StatusInternal, + } + } + + if cap(rawData) >= int(b.NumRows()) { + rawData = rawData[:b.NumRows()] + } else { + rawData = make([][]*string, b.NumRows()) + } + bldr.Reserve(int(b.NumRows())) + + // we grab the entire JSON message and create a bytes reader + offset, buf := int64(0), bytes.NewReader(data) + for i := 0; i < int(b.NumRows()); i++ { + // we construct a decoder from the bytes.Reader to read the next JSON list + // of columns (one row) from the input + dec := json.NewDecoder(buf) + if err = dec.Decode(&rawData[i]); err != nil { + return nil, adbc.Error{ + Msg: err.Error(), + Code: adbc.StatusInternal, + } + } + + // dec.InputOffset() now represents the index of the ',' so we skip the comma + offset += dec.InputOffset() + 1 + // then seek the buffer to that spot. we have to seek based on the start + // because json.Decoder can read from the buffer more than is necessary to + // process the JSON data. + if _, err = buf.Seek(offset, 0); err != nil { + return nil, adbc.Error{ + Msg: err.Error(), + Code: adbc.StatusInternal, + } + } + } + + // now that we have our [][]*string of JSON data, we can pass it to get converted + // to an Arrow record batch and appended to our slice of batches + rec, err := jsonDataToArrow(ctx, bldr, rawData) + if err != nil { + return nil, err + } + defer rec.Release() + + results = append(results, rec) } + + return array.NewRecordReader(schema, results) } ch := make(chan arrow.Record, bufferSize)