From 935cb2a25b0df4ed5d8f2f6ebca02b6389c92869 Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 20 Jul 2023 11:07:04 -0400 Subject: [PATCH] feat(go/adbc/driver/flightsql): add context to gRPC errors See #862. --- go/adbc/driver/flightsql/flightsql_adbc.go | 48 +++++++++---------- .../flightsql/flightsql_adbc_server_test.go | 2 +- .../driver/flightsql/flightsql_statement.go | 10 ++-- go/adbc/driver/flightsql/record_reader.go | 4 +- go/adbc/driver/flightsql/utils.go | 7 ++- 5 files changed, 37 insertions(+), 34 deletions(-) diff --git a/go/adbc/driver/flightsql/flightsql_adbc.go b/go/adbc/driver/flightsql/flightsql_adbc.go index e038354cc5..1ae99a6a55 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc.go +++ b/go/adbc/driver/flightsql/flightsql_adbc.go @@ -892,10 +892,10 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re ctx = metadata.NewOutgoingContext(ctx, c.hdrs) info, err := c.cl.GetSqlInfo(ctx, translated, c.timeouts) if err == nil { - for _, endpoint := range info.Endpoint { + for i, endpoint := range info.Endpoint { rdr, err := doGet(ctx, c.cl, endpoint, c.clientCache, c.timeouts) if err != nil { - return nil, adbcFromFlightStatus(err) + return nil, adbcFromFlightStatus(err, "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location) } for rdr.Next() { @@ -922,11 +922,11 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re } if rdr.Err() != nil { - return nil, adbcFromFlightStatus(rdr.Err()) + return nil, adbcFromFlightStatus(rdr.Err(), "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location) } } } else if grpcstatus.Code(err) != grpccodes.Unimplemented { - return nil, adbcFromFlightStatus(err) + return nil, adbcFromFlightStatus(err, "GetInfo(GetSqlInfo)") } final := bldr.NewRecord() @@ -1032,12 +1032,12 @@ func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog * // To avoid an N+1 query problem, we assume result sets here will fit in memory and build up a single response. info, err := c.cl.GetCatalogs(ctx) if err != nil { - return nil, adbcFromFlightStatus(err) + return nil, adbcFromFlightStatus(err, "GetObjects(GetCatalogs)") } rdr, err := c.readInfo(ctx, schema_ref.Catalogs, info) if err != nil { - return nil, adbcFromFlightStatus(err) + return nil, adbcFromFlightStatus(err, "GetObjects(GetCatalogs)") } defer rdr.Release() @@ -1058,7 +1058,7 @@ func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog * } if err = rdr.Err(); err != nil { - return nil, adbcFromFlightStatus(err) + return nil, adbcFromFlightStatus(err, "GetObjects(GetCatalogs)") } return g.Finish() @@ -1069,7 +1069,7 @@ func (c *cnxn) readInfo(ctx context.Context, expectedSchema *arrow.Schema, info // use a default queueSize for the reader rdr, err := newRecordReader(ctx, c.db.alloc, c.cl, info, c.clientCache, 5) if err != nil { - return nil, adbcFromFlightStatus(err) + return nil, adbcFromFlightStatus(err, "DoGet") } if !rdr.Schema().Equal(expectedSchema) { @@ -1091,12 +1091,12 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, // Pre-populate the map of which schemas are in which catalogs info, err := c.cl.GetDBSchemas(ctx, &flightsql.GetDBSchemasOpts{DbSchemaFilterPattern: dbSchema}) if err != nil { - return nil, adbcFromFlightStatus(err) + return nil, adbcFromFlightStatus(err, "GetObjects(GetDBSchemas)") } rdr, err := c.readInfo(ctx, schema_ref.DBSchemas, info) if err != nil { - return nil, adbcFromFlightStatus(err) + return nil, adbcFromFlightStatus(err, "GetObjects(GetDBSchemas)") } defer rdr.Release() @@ -1117,7 +1117,7 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, if rdr.Err() != nil { result = nil - err = adbcFromFlightStatus(rdr.Err()) + err = adbcFromFlightStatus(rdr.Err(), "GetObjects(GetDBSchemas)") } return } @@ -1137,7 +1137,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat IncludeSchema: includeSchema, }) if err != nil { - return nil, adbcFromFlightStatus(err) + return nil, adbcFromFlightStatus(err, "GetObjects(GetTables)") } expectedSchema := schema_ref.Tables @@ -1146,7 +1146,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat } rdr, err := c.readInfo(ctx, expectedSchema, info) if err != nil { - return nil, adbcFromFlightStatus(err) + return nil, adbcFromFlightStatus(err, "GetObjects(GetTables)") } defer rdr.Release() @@ -1195,7 +1195,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat if rdr.Err() != nil { result = nil - err = adbcFromFlightStatus(rdr.Err()) + err = adbcFromFlightStatus(rdr.Err(), "GetObjects(GetTables)") } return } @@ -1211,12 +1211,12 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st ctx = metadata.NewOutgoingContext(ctx, c.hdrs) info, err := c.cl.GetTables(ctx, opts, c.timeouts) if err != nil { - return nil, adbcFromFlightStatus(err) + return nil, adbcFromFlightStatus(err, "GetTableSchema(GetTables)") } rdr, err := doGet(ctx, c.cl, info.Endpoint[0], c.clientCache, c.timeouts) if err != nil { - return nil, adbcFromFlightStatus(err) + return nil, adbcFromFlightStatus(err, "GetTableSchema(DoGet)") } defer rdr.Release() @@ -1228,7 +1228,7 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st Code: adbc.StatusNotFound, } } - return nil, adbcFromFlightStatus(err) + return nil, adbcFromFlightStatus(err, "GetTableSchema(DoGet)") } if rec.NumRows() == 0 { @@ -1246,7 +1246,7 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st schemaBytes := rec.Column(4).(*array.Binary).Value(0) s, err := flight.DeserializeSchema(schemaBytes, c.db.alloc) if err != nil { - return nil, adbcFromFlightStatus(err) + return nil, adbcFromFlightStatus(err, "GetTableSchema") } return s, nil } @@ -1262,7 +1262,7 @@ func (c *cnxn) GetTableTypes(ctx context.Context) (array.RecordReader, error) { ctx = metadata.NewOutgoingContext(ctx, c.hdrs) info, err := c.cl.GetTableTypes(ctx, c.timeouts) if err != nil { - return nil, adbcFromFlightStatus(err) + return nil, adbcFromFlightStatus(err, "GetTableTypes") } return newRecordReader(ctx, c.db.alloc, c.cl, info, c.clientCache, 5) @@ -1289,12 +1289,12 @@ func (c *cnxn) Commit(ctx context.Context) error { ctx = metadata.NewOutgoingContext(ctx, c.hdrs) err := c.txn.Commit(ctx, c.timeouts) if err != nil { - return adbcFromFlightStatus(err) + return adbcFromFlightStatus(err, "Commit") } c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts) if err != nil { - return adbcFromFlightStatus(err) + return adbcFromFlightStatus(err, "BeginTransaction") } return nil } @@ -1320,12 +1320,12 @@ func (c *cnxn) Rollback(ctx context.Context) error { ctx = metadata.NewOutgoingContext(ctx, c.hdrs) err := c.txn.Rollback(ctx, c.timeouts) if err != nil { - return adbcFromFlightStatus(err) + return adbcFromFlightStatus(err, "Rollback") } c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts) if err != nil { - return adbcFromFlightStatus(err) + return adbcFromFlightStatus(err, "BeginTransaction") } return nil } @@ -1428,7 +1428,7 @@ func (c *cnxn) ReadPartition(ctx context.Context, serializedPartition []byte) (r ctx = metadata.NewOutgoingContext(ctx, c.hdrs) rdr, err = doGet(ctx, c.cl, info.Endpoint[0], c.clientCache, c.timeouts) if err != nil { - return nil, adbcFromFlightStatus(err) + return nil, adbcFromFlightStatus(err, "ReadPartition(DoGet)") } return rdr, nil } diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go index 61a46db105..dd6171c4cd 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go @@ -338,7 +338,7 @@ func (ts *TimeoutTests) TestDoActionTimeout() { ts.ErrorAs(stmt.Prepare(context.Background()), &adbcErr) ts.Equal(adbc.StatusTimeout, adbcErr.Code, adbcErr.Error()) // Exact match - we don't want extra fluff in the message - ts.Equal("context deadline exceeded", adbcErr.Msg) + ts.Equal("[FlightSQL] context deadline exceeded (DeadlineExceeded; Prepare)", adbcErr.Msg) } func (ts *TimeoutTests) TestDoGetTimeout() { diff --git a/go/adbc/driver/flightsql/flightsql_statement.go b/go/adbc/driver/flightsql/flightsql_statement.go index c7f074a800..3e7d20e1c4 100644 --- a/go/adbc/driver/flightsql/flightsql_statement.go +++ b/go/adbc/driver/flightsql/flightsql_statement.go @@ -239,7 +239,7 @@ func (s *statement) ExecuteQuery(ctx context.Context) (rdr array.RecordReader, n } if err != nil { - return nil, -1, adbcFromFlightStatus(err) + return nil, -1, adbcFromFlightStatus(err, "ExecuteQuery") } nrec = info.TotalRecords @@ -259,7 +259,7 @@ func (s *statement) ExecuteUpdate(ctx context.Context) (n int64, err error) { } if err != nil { - err = adbcFromFlightStatus(err) + err = adbcFromFlightStatus(err, "ExecuteUpdate") } return @@ -271,7 +271,7 @@ func (s *statement) Prepare(ctx context.Context) error { ctx = metadata.NewOutgoingContext(ctx, s.hdrs) prep, err := s.query.prepare(ctx, s.cnxn, s.timeouts) if err != nil { - return adbcFromFlightStatus(err) + return adbcFromFlightStatus(err, "Prepare") } s.prepared = prep return nil @@ -394,13 +394,13 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc. } if err != nil { - return nil, out, -1, adbcFromFlightStatus(err) + return nil, out, -1, adbcFromFlightStatus(err, "ExecutePartitions") } if len(info.Schema) > 0 { sc, err = flight.DeserializeSchema(info.Schema, s.alloc) if err != nil { - return nil, out, -1, adbcFromFlightStatus(err) + return nil, out, -1, adbcFromFlightStatus(err, "ExecutePartitions: could not deserialize FlightInfo schema:") } } diff --git a/go/adbc/driver/flightsql/record_reader.go b/go/adbc/driver/flightsql/record_reader.go index 409ce58e61..c2721a7af3 100644 --- a/go/adbc/driver/flightsql/record_reader.go +++ b/go/adbc/driver/flightsql/record_reader.go @@ -90,7 +90,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql. } else { rdr, err := doGet(ctx, cl, endpoints[0], clCache, opts...) if err != nil { - return nil, adbcFromFlightStatus(err) + return nil, adbcFromFlightStatus(err, "DoGet: endpoint 0: remote: %s", endpoints[0].Location) } schema = rdr.Schema() group.Go(func() error { @@ -135,7 +135,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql. rdr, err := doGet(ctx, cl, endpoint, clCache, opts...) if err != nil { - return err + return adbcFromFlightStatus(err, "DoGet: endpoint %d: %s", endpointIndex, endpoint.Location) } defer rdr.Release() diff --git a/go/adbc/driver/flightsql/utils.go b/go/adbc/driver/flightsql/utils.go index cbf9048fc0..e4cf276807 100644 --- a/go/adbc/driver/flightsql/utils.go +++ b/go/adbc/driver/flightsql/utils.go @@ -18,12 +18,14 @@ package flightsql import ( + "fmt" + "github.com/apache/arrow-adbc/go/adbc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) -func adbcFromFlightStatus(err error) error { +func adbcFromFlightStatus(err error, context string, args ...any) error { if _, ok := err.(adbc.Error); ok { return err } @@ -70,8 +72,9 @@ func adbcFromFlightStatus(err error) error { adbcCode = adbc.StatusUnknown } + // People don't read error messages, so backload the context and frontload the server error return adbc.Error{ - Msg: grpcStatus.Message(), + Msg: fmt.Sprintf("[FlightSQL] %s (%s; %s)", grpcStatus.Message(), grpcStatus.Code(), fmt.Sprintf(context, args...)), Code: adbcCode, } }