From 935cb2a25b0df4ed5d8f2f6ebca02b6389c92869 Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 20 Jul 2023 11:07:04 -0400
Subject: [PATCH] feat(go/adbc/driver/flightsql): add context to gRPC errors
See #862.
---
go/adbc/driver/flightsql/flightsql_adbc.go | 48 +++++++++----------
.../flightsql/flightsql_adbc_server_test.go | 2 +-
.../driver/flightsql/flightsql_statement.go | 10 ++--
go/adbc/driver/flightsql/record_reader.go | 4 +-
go/adbc/driver/flightsql/utils.go | 7 ++-
5 files changed, 37 insertions(+), 34 deletions(-)
diff --git a/go/adbc/driver/flightsql/flightsql_adbc.go b/go/adbc/driver/flightsql/flightsql_adbc.go
index e038354cc5..1ae99a6a55 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc.go
@@ -892,10 +892,10 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
info, err := c.cl.GetSqlInfo(ctx, translated, c.timeouts)
if err == nil {
- for _, endpoint := range info.Endpoint {
+ for i, endpoint := range info.Endpoint {
rdr, err := doGet(ctx, c.cl, endpoint, c.clientCache, c.timeouts)
if err != nil {
- return nil, adbcFromFlightStatus(err)
+ return nil, adbcFromFlightStatus(err, "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location)
}
for rdr.Next() {
@@ -922,11 +922,11 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re
}
if rdr.Err() != nil {
- return nil, adbcFromFlightStatus(rdr.Err())
+ return nil, adbcFromFlightStatus(rdr.Err(), "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location)
}
}
} else if grpcstatus.Code(err) != grpccodes.Unimplemented {
- return nil, adbcFromFlightStatus(err)
+ return nil, adbcFromFlightStatus(err, "GetInfo(GetSqlInfo)")
}
final := bldr.NewRecord()
@@ -1032,12 +1032,12 @@ func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *
// To avoid an N+1 query problem, we assume result sets here will fit in memory and build up a single response.
info, err := c.cl.GetCatalogs(ctx)
if err != nil {
- return nil, adbcFromFlightStatus(err)
+ return nil, adbcFromFlightStatus(err, "GetObjects(GetCatalogs)")
}
rdr, err := c.readInfo(ctx, schema_ref.Catalogs, info)
if err != nil {
- return nil, adbcFromFlightStatus(err)
+ return nil, adbcFromFlightStatus(err, "GetObjects(GetCatalogs)")
}
defer rdr.Release()
@@ -1058,7 +1058,7 @@ func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *
}
if err = rdr.Err(); err != nil {
- return nil, adbcFromFlightStatus(err)
+ return nil, adbcFromFlightStatus(err, "GetObjects(GetCatalogs)")
}
return g.Finish()
@@ -1069,7 +1069,7 @@ func (c *cnxn) readInfo(ctx context.Context, expectedSchema *arrow.Schema, info
// use a default queueSize for the reader
rdr, err := newRecordReader(ctx, c.db.alloc, c.cl, info, c.clientCache, 5)
if err != nil {
- return nil, adbcFromFlightStatus(err)
+ return nil, adbcFromFlightStatus(err, "DoGet")
}
if !rdr.Schema().Equal(expectedSchema) {
@@ -1091,12 +1091,12 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth,
// Pre-populate the map of which schemas are in which catalogs
info, err := c.cl.GetDBSchemas(ctx, &flightsql.GetDBSchemasOpts{DbSchemaFilterPattern: dbSchema})
if err != nil {
- return nil, adbcFromFlightStatus(err)
+ return nil, adbcFromFlightStatus(err, "GetObjects(GetDBSchemas)")
}
rdr, err := c.readInfo(ctx, schema_ref.DBSchemas, info)
if err != nil {
- return nil, adbcFromFlightStatus(err)
+ return nil, adbcFromFlightStatus(err, "GetObjects(GetDBSchemas)")
}
defer rdr.Release()
@@ -1117,7 +1117,7 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth,
if rdr.Err() != nil {
result = nil
- err = adbcFromFlightStatus(rdr.Err())
+ err = adbcFromFlightStatus(rdr.Err(), "GetObjects(GetDBSchemas)")
}
return
}
@@ -1137,7 +1137,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat
IncludeSchema: includeSchema,
})
if err != nil {
- return nil, adbcFromFlightStatus(err)
+ return nil, adbcFromFlightStatus(err, "GetObjects(GetTables)")
}
expectedSchema := schema_ref.Tables
@@ -1146,7 +1146,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat
}
rdr, err := c.readInfo(ctx, expectedSchema, info)
if err != nil {
- return nil, adbcFromFlightStatus(err)
+ return nil, adbcFromFlightStatus(err, "GetObjects(GetTables)")
}
defer rdr.Release()
@@ -1195,7 +1195,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat
if rdr.Err() != nil {
result = nil
- err = adbcFromFlightStatus(rdr.Err())
+ err = adbcFromFlightStatus(rdr.Err(), "GetObjects(GetTables)")
}
return
}
@@ -1211,12 +1211,12 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
info, err := c.cl.GetTables(ctx, opts, c.timeouts)
if err != nil {
- return nil, adbcFromFlightStatus(err)
+ return nil, adbcFromFlightStatus(err, "GetTableSchema(GetTables)")
}
rdr, err := doGet(ctx, c.cl, info.Endpoint[0], c.clientCache, c.timeouts)
if err != nil {
- return nil, adbcFromFlightStatus(err)
+ return nil, adbcFromFlightStatus(err, "GetTableSchema(DoGet)")
}
defer rdr.Release()
@@ -1228,7 +1228,7 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st
Code: adbc.StatusNotFound,
}
}
- return nil, adbcFromFlightStatus(err)
+ return nil, adbcFromFlightStatus(err, "GetTableSchema(DoGet)")
}
if rec.NumRows() == 0 {
@@ -1246,7 +1246,7 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st
schemaBytes := rec.Column(4).(*array.Binary).Value(0)
s, err := flight.DeserializeSchema(schemaBytes, c.db.alloc)
if err != nil {
- return nil, adbcFromFlightStatus(err)
+ return nil, adbcFromFlightStatus(err, "GetTableSchema")
}
return s, nil
}
@@ -1262,7 +1262,7 @@ func (c *cnxn) GetTableTypes(ctx context.Context) (array.RecordReader, error) {
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
info, err := c.cl.GetTableTypes(ctx, c.timeouts)
if err != nil {
- return nil, adbcFromFlightStatus(err)
+ return nil, adbcFromFlightStatus(err, "GetTableTypes")
}
return newRecordReader(ctx, c.db.alloc, c.cl, info, c.clientCache, 5)
@@ -1289,12 +1289,12 @@ func (c *cnxn) Commit(ctx context.Context) error {
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
err := c.txn.Commit(ctx, c.timeouts)
if err != nil {
- return adbcFromFlightStatus(err)
+ return adbcFromFlightStatus(err, "Commit")
}
c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts)
if err != nil {
- return adbcFromFlightStatus(err)
+ return adbcFromFlightStatus(err, "BeginTransaction")
}
return nil
}
@@ -1320,12 +1320,12 @@ func (c *cnxn) Rollback(ctx context.Context) error {
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
err := c.txn.Rollback(ctx, c.timeouts)
if err != nil {
- return adbcFromFlightStatus(err)
+ return adbcFromFlightStatus(err, "Rollback")
}
c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts)
if err != nil {
- return adbcFromFlightStatus(err)
+ return adbcFromFlightStatus(err, "BeginTransaction")
}
return nil
}
@@ -1428,7 +1428,7 @@ func (c *cnxn) ReadPartition(ctx context.Context, serializedPartition []byte) (r
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
rdr, err = doGet(ctx, c.cl, info.Endpoint[0], c.clientCache, c.timeouts)
if err != nil {
- return nil, adbcFromFlightStatus(err)
+ return nil, adbcFromFlightStatus(err, "ReadPartition(DoGet)")
}
return rdr, nil
}
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index 61a46db105..dd6171c4cd 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -338,7 +338,7 @@ func (ts *TimeoutTests) TestDoActionTimeout() {
ts.ErrorAs(stmt.Prepare(context.Background()), &adbcErr)
ts.Equal(adbc.StatusTimeout, adbcErr.Code, adbcErr.Error())
// Exact match - we don't want extra fluff in the message
- ts.Equal("context deadline exceeded", adbcErr.Msg)
+ ts.Equal("[FlightSQL] context deadline exceeded (DeadlineExceeded; Prepare)", adbcErr.Msg)
}
func (ts *TimeoutTests) TestDoGetTimeout() {
diff --git a/go/adbc/driver/flightsql/flightsql_statement.go b/go/adbc/driver/flightsql/flightsql_statement.go
index c7f074a800..3e7d20e1c4 100644
--- a/go/adbc/driver/flightsql/flightsql_statement.go
+++ b/go/adbc/driver/flightsql/flightsql_statement.go
@@ -239,7 +239,7 @@ func (s *statement) ExecuteQuery(ctx context.Context) (rdr array.RecordReader, n
}
if err != nil {
- return nil, -1, adbcFromFlightStatus(err)
+ return nil, -1, adbcFromFlightStatus(err, "ExecuteQuery")
}
nrec = info.TotalRecords
@@ -259,7 +259,7 @@ func (s *statement) ExecuteUpdate(ctx context.Context) (n int64, err error) {
}
if err != nil {
- err = adbcFromFlightStatus(err)
+ err = adbcFromFlightStatus(err, "ExecuteUpdate")
}
return
@@ -271,7 +271,7 @@ func (s *statement) Prepare(ctx context.Context) error {
ctx = metadata.NewOutgoingContext(ctx, s.hdrs)
prep, err := s.query.prepare(ctx, s.cnxn, s.timeouts)
if err != nil {
- return adbcFromFlightStatus(err)
+ return adbcFromFlightStatus(err, "Prepare")
}
s.prepared = prep
return nil
@@ -394,13 +394,13 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.
}
if err != nil {
- return nil, out, -1, adbcFromFlightStatus(err)
+ return nil, out, -1, adbcFromFlightStatus(err, "ExecutePartitions")
}
if len(info.Schema) > 0 {
sc, err = flight.DeserializeSchema(info.Schema, s.alloc)
if err != nil {
- return nil, out, -1, adbcFromFlightStatus(err)
+ return nil, out, -1, adbcFromFlightStatus(err, "ExecutePartitions: could not deserialize FlightInfo schema:")
}
}
diff --git a/go/adbc/driver/flightsql/record_reader.go b/go/adbc/driver/flightsql/record_reader.go
index 409ce58e61..c2721a7af3 100644
--- a/go/adbc/driver/flightsql/record_reader.go
+++ b/go/adbc/driver/flightsql/record_reader.go
@@ -90,7 +90,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql.
} else {
rdr, err := doGet(ctx, cl, endpoints[0], clCache, opts...)
if err != nil {
- return nil, adbcFromFlightStatus(err)
+ return nil, adbcFromFlightStatus(err, "DoGet: endpoint 0: remote: %s", endpoints[0].Location)
}
schema = rdr.Schema()
group.Go(func() error {
@@ -135,7 +135,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql.
rdr, err := doGet(ctx, cl, endpoint, clCache, opts...)
if err != nil {
- return err
+ return adbcFromFlightStatus(err, "DoGet: endpoint %d: %s", endpointIndex, endpoint.Location)
}
defer rdr.Release()
diff --git a/go/adbc/driver/flightsql/utils.go b/go/adbc/driver/flightsql/utils.go
index cbf9048fc0..e4cf276807 100644
--- a/go/adbc/driver/flightsql/utils.go
+++ b/go/adbc/driver/flightsql/utils.go
@@ -18,12 +18,14 @@
package flightsql
import (
+ "fmt"
+
"github.com/apache/arrow-adbc/go/adbc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
-func adbcFromFlightStatus(err error) error {
+func adbcFromFlightStatus(err error, context string, args ...any) error {
if _, ok := err.(adbc.Error); ok {
return err
}
@@ -70,8 +72,9 @@ func adbcFromFlightStatus(err error) error {
adbcCode = adbc.StatusUnknown
}
+ // People don't read error messages, so backload the context and frontload the server error
return adbc.Error{
- Msg: grpcStatus.Message(),
+ Msg: fmt.Sprintf("[FlightSQL] %s (%s; %s)", grpcStatus.Message(), grpcStatus.Code(), fmt.Sprintf(context, args...)),
Code: adbcCode,
}
}