From 25f2fc8a9044f30c11b8b54d5af6f3f36375ace0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Ch=C3=A1vez?= Date: Thu, 15 Nov 2018 14:40:22 +0200 Subject: [PATCH 1/3] test: adds exec and query test with sqlite3. --- driver.go | 48 ++++++++++--------- driver_test.go | 123 +++++++++++++++++++++++++++++++++++++++++++++++++ options.go | 17 +++++++ 3 files changed, 165 insertions(+), 23 deletions(-) create mode 100644 driver_test.go diff --git a/driver.go b/driver.go index a4ed276..e6ef05f 100644 --- a/driver.go +++ b/driver.go @@ -169,7 +169,7 @@ func (c zConn) Exec(query string, args []driver.Value) (res driver.Result, err e func (c zConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (res driver.Result, err error) { if execCtx, ok := c.driver.(driver.ExecerContext); ok { - if zipkin.SpanFromContext(ctx) == nil { + if zipkin.SpanFromContext(ctx) == nil && !c.options.AllowRootSpan { return execCtx.ExecContext(ctx, query, args) } @@ -187,7 +187,7 @@ func (c zConn) ExecContext(ctx context.Context, query string, args []driver.Name return nil, err } - return zResult{driver: res, ctx: ctx, options: c.options}, nil + return zResult{driver: res, tracer: c.tracer, ctx: ctx, options: c.options}, nil } return nil, driver.ErrSkip @@ -203,11 +203,11 @@ func (c zConn) Query(query string, args []driver.Value) (rows driver.Rows, err e func (c zConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { if queryerCtx, ok := c.driver.(driver.QueryerContext); ok { - if zipkin.SpanFromContext(ctx) == nil { + if zipkin.SpanFromContext(ctx) == nil && !c.options.AllowRootSpan { return queryerCtx.QueryContext(ctx, query, args) } - span, _ := c.tracer.StartSpanFromContext(ctx, "sql/exec", zipkin.Kind(zipkinmodel.Client)) + span, _ := c.tracer.StartSpanFromContext(ctx, "sql/query", zipkin.Kind(zipkinmodel.Client)) defer span.Finish() if c.options.TagQuery { @@ -254,7 +254,7 @@ func (c *zConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, } func (c *zConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { - if zipkin.SpanFromContext(ctx) == nil { + if zipkin.SpanFromContext(ctx) == nil && !c.options.AllowRootSpan { if connBeginTx, ok := c.driver.(driver.ConnBeginTx); ok { return connBeginTx.BeginTx(ctx, opts) } @@ -311,7 +311,7 @@ func (r zResult) LastInsertId() (int64, error) { func (r zResult) RowsAffected() (cnt int64, err error) { zipkin.SpanFromContext(r.ctx) - if r.options.RowsAffectedSpan && zipkin.SpanFromContext(r.ctx) != nil { + if r.options.RowsAffectedSpan && (r.options.AllowRootSpan || zipkin.SpanFromContext(r.ctx) != nil) { span, _ := r.tracer.StartSpanFromContext(r.ctx, "sql/rows_affected", zipkin.Kind(zipkinmodel.Client)) setSpanDefaultTags(span, r.options.DefaultTags) defer func() { @@ -350,7 +350,7 @@ func (s zStmt) Query(args []driver.Value) (driver.Rows, error) { } func (s zStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res driver.Result, err error) { - if zipkin.SpanFromContext(ctx) == nil { + if zipkin.SpanFromContext(ctx) == nil && !s.options.AllowRootSpan { return s.driver.(driver.StmtExecContext).ExecContext(ctx, args) } @@ -378,12 +378,12 @@ func (s zStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res d } } - res, err = zResult{driver: res, ctx: ctx, options: s.options}, nil + res, err = zResult{driver: res, tracer: s.tracer, ctx: ctx, options: s.options}, nil return } func (s zStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) { - if zipkin.SpanFromContext(ctx) == nil { + if zipkin.SpanFromContext(ctx) == nil && !s.options.AllowRootSpan { return s.driver.(driver.StmtQueryContext).QueryContext(ctx, args) } @@ -443,25 +443,27 @@ type zTx struct { } func (t zTx) Commit() (err error) { - span, _ := t.tracer.StartSpanFromContext(t.ctx, "sql/commit", zipkin.Kind(zipkinmodel.Client)) - defer func() { - setSpanDefaultTags(span, t.options.DefaultTags) - setSpanError(span, err) - span.Finish() - }() - + if zipkin.SpanFromContext(t.ctx) != nil || t.options.AllowRootSpan { + span, _ := t.tracer.StartSpanFromContext(t.ctx, "sql/commit", zipkin.Kind(zipkinmodel.Client)) + defer func() { + setSpanDefaultTags(span, t.options.DefaultTags) + setSpanError(span, err) + span.Finish() + }() + } err = t.driver.Commit() return } func (t zTx) Rollback() (err error) { - span, _ := t.tracer.StartSpanFromContext(t.ctx, "sql/rollback", zipkin.Kind(zipkinmodel.Client)) - defer func() { - setSpanDefaultTags(span, t.options.DefaultTags) - setSpanError(span, err) - span.Finish() - }() - + if zipkin.SpanFromContext(t.ctx) != nil || t.options.AllowRootSpan { + span, _ := t.tracer.StartSpanFromContext(t.ctx, "sql/rollback", zipkin.Kind(zipkinmodel.Client)) + defer func() { + setSpanDefaultTags(span, t.options.DefaultTags) + setSpanError(span, err) + span.Finish() + }() + } err = t.driver.Rollback() return } diff --git a/driver_test.go b/driver_test.go new file mode 100644 index 0000000..89e5acb --- /dev/null +++ b/driver_test.go @@ -0,0 +1,123 @@ +package zipkinsql + +import ( + "context" + "database/sql" + "testing" + + _ "github.com/mattn/go-sqlite3" + zipkin "github.com/openzipkin/zipkin-go" + zipkinreporter "github.com/openzipkin/zipkin-go/reporter/recorder" +) + +func createDB(t *testing.T, opts ...TraceOption) (*sql.DB, *zipkinreporter.ReporterRecorder) { + reporter := zipkinreporter.NewReporter() + tracer, _ := zipkin.NewTracer(reporter) + + driverName, err := Register("sqlite3", tracer, opts...) + if err != nil { + t.Fatalf("unable to register driver") + } + + db, err := sql.Open(driverName, "file:test.db?cache=shared&mode=memory") + if err != nil { + t.Fatal(err) + } + + return db, reporter +} + +type testCase struct { + opts []TraceOption + expectedSpans int +} + +func TestQuerySuccess(t *testing.T) { + ctx := context.Background() + testCases := []testCase{ + {[]TraceOption{WithAllowRootSpan(false)}, 0}, + {[]TraceOption{WithAllowRootSpan(true)}, 1}, + } + for _, c := range testCases { + db, recorder := createDB(t, c.opts...) + + rows, err := db.QueryContext(ctx, "SELECT 1") + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + defer rows.Close() + + for rows.Next() { + var n int + if err = rows.Scan(&n); err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + } + if err = rows.Err(); err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + + spans := recorder.Flush() + if want, have := c.expectedSpans, len(spans); want != have { + t.Fatalf("unexpected number of spans, want: %d, have: %d", want, have) + } + + if c.expectedSpans > 0 { + if want, have := "sql/query", spans[0].Name; want != have { + t.Fatalf("unexpected span name, want: %s, have: %s", want, have) + } + } + + db.Close() + recorder.Close() + } +} + +func TestExecSuccess(t *testing.T) { + ctx := context.Background() + + testCases := []testCase{ + {[]TraceOption{WithAllowRootSpan(false)}, 0}, + {[]TraceOption{WithAllowRootSpan(true)}, 1}, + {[]TraceOption{WithAllowRootSpan(true), WithLastInsertIDSpan(true)}, 2}, + {[]TraceOption{WithAllowRootSpan(true), WithRowsAffectedSpan(true)}, 2}, + {[]TraceOption{WithAllowRootSpan(true), WithLastInsertIDSpan(true), WithRowsAffectedSpan(true)}, 3}, + } + for _, c := range testCases { + db, recorder := createDB(t, c.opts...) + + sqlStmt := ` + create table foo (id integer not null primary key, name text); + delete from foo; + ` + + res, err := db.ExecContext(ctx, sqlStmt) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + + _, err = res.LastInsertId() + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + + _, err = res.RowsAffected() + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + + spans := recorder.Flush() + if want, have := c.expectedSpans, len(spans); want != have { + t.Fatalf("unexpected number of spans, want: %d, have: %d", want, have) + } + + if c.expectedSpans > 0 { + if want, have := "sql/exec", spans[0].Name; want != have { + t.Fatalf("unexpected span name, want: %s, have: %s", want, have) + } + } + + db.Close() + recorder.Close() + } +} diff --git a/options.go b/options.go index d48ec80..46fae00 100644 --- a/options.go +++ b/options.go @@ -8,6 +8,12 @@ type TraceOption func(o *TraceOptions) // a wrapped driver and provide the most sensible default with both performance // and security in mind. type TraceOptions struct { + // AllowRoot, if set to true, will allow zipkinsql to create root spans in + // absence of existing spans or even context. + // Default is to not trace zipkinsql calls if no existing parent span is found + // in context or when using methods not taking context. + AllowRootSpan bool + // LastInsertIDSpan, if set to true, will enable the creation of spans on // LastInsertId calls. LastInsertIDSpan bool @@ -36,6 +42,7 @@ func WithAllTraceOptions() TraceOption { // AllTraceOptions has all tracing options enabled. var AllTraceOptions = TraceOptions{ + AllowRootSpan: true, RowsAffectedSpan: true, LastInsertIDSpan: true, TagQuery: true, @@ -50,6 +57,16 @@ func WithOptions(options TraceOptions) TraceOption { } } +// WithAllowRootSpan if set to true, will allow zipkinsql to create root spans in +// absence of exisiting spans or even context. +// Default is to not trace zipkinsql calls if no existing parent span is found +// in context or when using methods not taking context. +func WithAllowRootSpan(b bool) TraceOption { + return func(o *TraceOptions) { + o.AllowRootSpan = b + } +} + // WithRowsAffectedSpan if set to true, will enable the creation of spans on // RowsAffected calls. func WithRowsAffectedSpan(b bool) TraceOption { From 79cca89d7b12709fbd7a246140179ea8bfabd0fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Ch=C3=A1vez?= Date: Thu, 15 Nov 2018 14:46:47 +0200 Subject: [PATCH 2/3] test: removes 1.8 support and execute tests on CI. --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 9fec69a..008eb06 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,6 @@ sudo: false language: go go: - - 1.8.x - 1.9.x - 1.10.x - 1.11.x @@ -13,3 +12,4 @@ before_script: script: - go vet ./... - golint ./.. + - go test ./... From 24b349c675be7ff354061d94132cedfdc51bfc37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Ch=C3=A1vez?= Date: Thu, 15 Nov 2018 15:43:31 +0200 Subject: [PATCH 3/3] test: adds tests for Tx and propagation. --- driver.go | 142 +++++++++++++++++++--------------- driver_test.go | 205 +++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 282 insertions(+), 65 deletions(-) diff --git a/driver.go b/driver.go index e6ef05f..eb1eb09 100644 --- a/driver.go +++ b/driver.go @@ -27,7 +27,6 @@ var ( _ driver.Driver = &zDriver{} _ conn = &zConn{} _ driver.Result = &zResult{} - _ driver.Rows = &zRows{} ) var ( @@ -79,11 +78,11 @@ func wrapDriver(d driver.Driver, t *zipkin.Tracer, o TraceOptions) driver.Driver } func wrapConn(c driver.Conn, t *zipkin.Tracer, options TraceOptions) driver.Conn { - return &zConn{driver: c, tracer: t, options: options} + return &zConn{conn: c, tracer: t, options: options} } func wrapStmt(stmt driver.Stmt, query string, tracer *zipkin.Tracer, options TraceOptions) driver.Stmt { - s := zStmt{driver: stmt, query: query, options: options, tracer: tracer} + s := zStmt{stmt: stmt, query: query, options: options, tracer: tracer} _, hasExeCtx := stmt.(driver.StmtExecContext) _, hasQryCtx := stmt.(driver.StmtQueryContext) c, hasColCnv := stmt.(driver.ColumnConverter) @@ -147,20 +146,20 @@ func (d zDriver) Open(name string) (driver.Conn, error) { // zConn implements driver.Conn type zConn struct { - driver driver.Conn + conn driver.Conn tracer *zipkin.Tracer options TraceOptions } func (c zConn) Ping(ctx context.Context) (err error) { - if pinger, ok := c.driver.(driver.Pinger); ok { + if pinger, ok := c.conn.(driver.Pinger); ok { err = pinger.Ping(ctx) } return } func (c zConn) Exec(query string, args []driver.Value) (res driver.Result, err error) { - if exec, ok := c.driver.(driver.Execer); ok { + if exec, ok := c.conn.(driver.Execer); ok { return exec.Exec(query, args) } @@ -168,7 +167,7 @@ func (c zConn) Exec(query string, args []driver.Value) (res driver.Result, err e } func (c zConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (res driver.Result, err error) { - if execCtx, ok := c.driver.(driver.ExecerContext); ok { + if execCtx, ok := c.conn.(driver.ExecerContext); ok { if zipkin.SpanFromContext(ctx) == nil && !c.options.AllowRootSpan { return execCtx.ExecContext(ctx, query, args) } @@ -187,14 +186,14 @@ func (c zConn) ExecContext(ctx context.Context, query string, args []driver.Name return nil, err } - return zResult{driver: res, tracer: c.tracer, ctx: ctx, options: c.options}, nil + return zResult{result: res, tracer: c.tracer, ctx: ctx, options: c.options}, nil } return nil, driver.ErrSkip } func (c zConn) Query(query string, args []driver.Value) (rows driver.Rows, err error) { - if queryer, ok := c.driver.(driver.Queryer); ok { + if queryer, ok := c.conn.(driver.Queryer); ok { return queryer.Query(query, args) } @@ -202,7 +201,7 @@ func (c zConn) Query(query string, args []driver.Value) (rows driver.Rows, err e } func (c zConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { - if queryerCtx, ok := c.driver.(driver.QueryerContext); ok { + if queryerCtx, ok := c.conn.(driver.QueryerContext); ok { if zipkin.SpanFromContext(ctx) == nil && !c.options.AllowRootSpan { return queryerCtx.QueryContext(ctx, query, args) } @@ -221,14 +220,14 @@ func (c zConn) QueryContext(ctx context.Context, query string, args []driver.Nam return nil, err } - return zRows{driver: rows, ctx: ctx, options: c.options}, nil + return rows, nil } return nil, driver.ErrSkip } func (c zConn) Prepare(query string) (stmt driver.Stmt, err error) { - stmt, err = c.driver.Prepare(query) + stmt, err = c.conn.Prepare(query) if err != nil { return nil, err } @@ -238,7 +237,7 @@ func (c zConn) Prepare(query string) (stmt driver.Stmt, err error) { } func (c *zConn) Close() error { - return c.driver.Close() + return c.conn.Close() } func (c *zConn) Begin() (driver.Tx, error) { @@ -246,20 +245,20 @@ func (c *zConn) Begin() (driver.Tx, error) { } func (c *zConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - if prepCtx, ok := c.driver.(driver.ConnPrepareContext); ok { + if prepCtx, ok := c.conn.(driver.ConnPrepareContext); ok { return prepCtx.PrepareContext(ctx, query) } - return c.driver.Prepare(query) + return c.conn.Prepare(query) } func (c *zConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { if zipkin.SpanFromContext(ctx) == nil && !c.options.AllowRootSpan { - if connBeginTx, ok := c.driver.(driver.ConnBeginTx); ok { + if connBeginTx, ok := c.conn.(driver.ConnBeginTx); ok { return connBeginTx.BeginTx(ctx, opts) } - return c.driver.Begin() + return c.conn.Begin() } span, _ := c.tracer.StartSpanFromContext(ctx, "sql/begin_transaction", zipkin.Kind(zipkinmodel.Client)) @@ -267,27 +266,27 @@ func (c *zConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, setSpanDefaultTags(span, c.options.DefaultTags) - if connBeginTx, ok := c.driver.(driver.ConnBeginTx); ok { + if connBeginTx, ok := c.conn.(driver.ConnBeginTx); ok { tx, err := connBeginTx.BeginTx(ctx, opts) setSpanError(span, err) if err != nil { return nil, err } - return zTx{driver: tx, ctx: ctx}, nil + return zTx{tx: tx, ctx: ctx, tracer: c.tracer, options: c.options}, nil } - tx, err := c.driver.Begin() + tx, err := c.conn.Begin() setSpanError(span, err) if err != nil { return nil, err } - return zTx{driver: tx, ctx: ctx, tracer: c.tracer}, nil + return zTx{tx: tx, ctx: ctx, tracer: c.tracer, options: c.options}, nil } // zResult implements driver.Result type zResult struct { - driver driver.Result + result driver.Result ctx context.Context tracer *zipkin.Tracer options TraceOptions @@ -295,7 +294,7 @@ type zResult struct { func (r zResult) LastInsertId() (int64, error) { if !r.options.LastInsertIDSpan { - return r.driver.LastInsertId() + return r.result.LastInsertId() } span, _ := r.tracer.StartSpanFromContext(r.ctx, "sql/last_insert_id", zipkin.Kind(zipkinmodel.Client)) @@ -303,7 +302,7 @@ func (r zResult) LastInsertId() (int64, error) { setSpanDefaultTags(span, r.options.DefaultTags) - id, err := r.driver.LastInsertId() + id, err := r.result.LastInsertId() setSpanError(span, err) return id, err @@ -321,37 +320,80 @@ func (r zResult) RowsAffected() (cnt int64, err error) { }() } - cnt, err = r.driver.RowsAffected() + cnt, err = r.result.RowsAffected() return } // zStmt implements driver.Stmt type zStmt struct { - driver driver.Stmt + stmt driver.Stmt query string tracer *zipkin.Tracer options TraceOptions } -func (s zStmt) Exec(args []driver.Value) (driver.Result, error) { - return s.driver.Exec(args) +func (s zStmt) Exec(args []driver.Value) (res driver.Result, err error) { + if !s.options.AllowRootSpan { + return s.stmt.Exec(args) + } + + span, ctx := s.tracer.StartSpanFromContext(context.Background(), "sql:exec", zipkin.Kind(zipkinmodel.Client)) + setSpanDefaultTags(span, s.options.DefaultTags) + + if s.options.TagQuery { + span.Tag("sql.query", s.query) + } + + defer func() { + setSpanError(span, err) + span.Finish() + }() + + res, err = s.stmt.Exec(args) + if err != nil { + return nil, err + } + + res, err = zResult{result: res, ctx: ctx, tracer: s.tracer, options: s.options}, nil + return } func (s zStmt) Close() error { - return s.driver.Close() + return s.stmt.Close() } func (s zStmt) NumInput() int { - return s.driver.NumInput() + return s.stmt.NumInput() } -func (s zStmt) Query(args []driver.Value) (driver.Rows, error) { - return s.driver.Query(args) +func (s zStmt) Query(args []driver.Value) (rows driver.Rows, err error) { + if !s.options.AllowRootSpan { + return s.stmt.Query(args) + } + + span, _ := s.tracer.StartSpanFromContext(context.Background(), "sql:query", zipkin.Kind(zipkinmodel.Client)) + setSpanDefaultTags(span, s.options.DefaultTags) + + if s.options.TagQuery { + span.Tag("sql.query", s.query) + } + + defer func() { + setSpanError(span, err) + span.Finish() + }() + + rows, err = s.stmt.Query(args) + if err != nil { + return nil, err + } + + return } func (s zStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res driver.Result, err error) { if zipkin.SpanFromContext(ctx) == nil && !s.options.AllowRootSpan { - return s.driver.(driver.StmtExecContext).ExecContext(ctx, args) + return s.stmt.(driver.StmtExecContext).ExecContext(ctx, args) } span, ctx := s.tracer.StartSpanFromContext(ctx, "sql/exec", zipkin.Kind(zipkinmodel.Client)) @@ -366,7 +408,7 @@ func (s zStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res d setSpanDefaultTags(span, s.options.DefaultTags) - execContext := s.driver.(driver.StmtExecContext) + execContext := s.stmt.(driver.StmtExecContext) res, err = execContext.ExecContext(ctx, args) if err != nil { return nil, err @@ -378,13 +420,13 @@ func (s zStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res d } } - res, err = zResult{driver: res, tracer: s.tracer, ctx: ctx, options: s.options}, nil + res, err = zResult{result: res, tracer: s.tracer, ctx: ctx, options: s.options}, nil return } func (s zStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) { if zipkin.SpanFromContext(ctx) == nil && !s.options.AllowRootSpan { - return s.driver.(driver.StmtQueryContext).QueryContext(ctx, args) + return s.stmt.(driver.StmtQueryContext).QueryContext(ctx, args) } span, ctx := s.tracer.StartSpanFromContext(ctx, "sql/query", zipkin.Kind(zipkinmodel.Client)) @@ -405,38 +447,18 @@ func (s zStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows }() // we already tested driver to implement StmtQueryContext - queryContext := s.driver.(driver.StmtQueryContext) + queryContext := s.stmt.(driver.StmtQueryContext) rows, err = queryContext.QueryContext(ctx, args) if err != nil { return nil, err } - rows, err = zRows{driver: rows, ctx: ctx, options: s.options}, nil return } -// zRows implements driver.Rows. -type zRows struct { - driver driver.Rows - ctx context.Context - options TraceOptions -} - -func (r zRows) Columns() []string { - return r.driver.Columns() -} - -func (r zRows) Close() error { - return r.driver.Close() -} - -func (r zRows) Next(dest []driver.Value) error { - return r.driver.Next(dest) -} - // zTx implemens driver.Tx type zTx struct { - driver driver.Tx + tx driver.Tx ctx context.Context tracer *zipkin.Tracer options TraceOptions @@ -451,7 +473,7 @@ func (t zTx) Commit() (err error) { span.Finish() }() } - err = t.driver.Commit() + err = t.tx.Commit() return } @@ -464,7 +486,7 @@ func (t zTx) Rollback() (err error) { span.Finish() }() } - err = t.driver.Rollback() + err = t.tx.Rollback() return } diff --git a/driver_test.go b/driver_test.go index 89e5acb..bbf5eba 100644 --- a/driver_test.go +++ b/driver_test.go @@ -3,6 +3,7 @@ package zipkinsql import ( "context" "database/sql" + "fmt" "testing" _ "github.com/mattn/go-sqlite3" @@ -10,7 +11,7 @@ import ( zipkinreporter "github.com/openzipkin/zipkin-go/reporter/recorder" ) -func createDB(t *testing.T, opts ...TraceOption) (*sql.DB, *zipkinreporter.ReporterRecorder) { +func createDB(t *testing.T, opts ...TraceOption) (*sql.DB, *zipkin.Tracer, *zipkinreporter.ReporterRecorder) { reporter := zipkinreporter.NewReporter() tracer, _ := zipkin.NewTracer(reporter) @@ -24,7 +25,7 @@ func createDB(t *testing.T, opts ...TraceOption) (*sql.DB, *zipkinreporter.Repor t.Fatal(err) } - return db, reporter + return db, tracer, reporter } type testCase struct { @@ -33,13 +34,52 @@ type testCase struct { } func TestQuerySuccess(t *testing.T) { + testCases := []testCase{ + {[]TraceOption{WithAllowRootSpan(false)}, 0}, + {[]TraceOption{WithAllowRootSpan(true)}, 1}, + } + for _, c := range testCases { + db, _, recorder := createDB(t, c.opts...) + + rows, err := db.Query("SELECT 1") + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + defer rows.Close() + + for rows.Next() { + var n int + if err = rows.Scan(&n); err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + } + if err = rows.Err(); err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + + spans := recorder.Flush() + if want, have := c.expectedSpans, len(spans); want != have { + t.Fatalf("unexpected number of spans, want: %d, have: %d", want, have) + } + + if c.expectedSpans > 0 { + if want, have := "sql/query", spans[0].Name; want != have { + t.Fatalf("unexpected span name, want: %s, have: %s", want, have) + } + } + + db.Close() + recorder.Close() + } +} +func TestQueryContextSuccess(t *testing.T) { ctx := context.Background() testCases := []testCase{ {[]TraceOption{WithAllowRootSpan(false)}, 0}, {[]TraceOption{WithAllowRootSpan(true)}, 1}, } for _, c := range testCases { - db, recorder := createDB(t, c.opts...) + db, _, recorder := createDB(t, c.opts...) rows, err := db.QueryContext(ctx, "SELECT 1") if err != nil { @@ -73,7 +113,44 @@ func TestQuerySuccess(t *testing.T) { } } -func TestExecSuccess(t *testing.T) { +func TestQueryContextPropagationSuccess(t *testing.T) { + ctx := context.Background() + db, tracer, recorder := createDB(t, WithAllowRootSpan(false)) + + span, ctx := tracer.StartSpanFromContext(ctx, "root") + + rows, err := db.QueryContext(ctx, "SELECT 1") + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + defer rows.Close() + + for rows.Next() { + var n int + if err = rows.Scan(&n); err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + } + if err = rows.Err(); err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + + span.Finish() + + spans := recorder.Flush() + if want, have := 2, len(spans); want != have { + t.Fatalf("unexpected number of spans, want: %d, have: %d", want, have) + } + + if want, have := "sql/query", spans[0].Name; want != have { + t.Fatalf("unexpected span name, want: %s, have: %s", want, have) + } + + db.Close() + recorder.Close() +} + +func TestExecContextSuccess(t *testing.T) { ctx := context.Background() testCases := []testCase{ @@ -84,7 +161,7 @@ func TestExecSuccess(t *testing.T) { {[]TraceOption{WithAllowRootSpan(true), WithLastInsertIDSpan(true), WithRowsAffectedSpan(true)}, 3}, } for _, c := range testCases { - db, recorder := createDB(t, c.opts...) + db, _, recorder := createDB(t, c.opts...) sqlStmt := ` create table foo (id integer not null primary key, name text); @@ -121,3 +198,121 @@ func TestExecSuccess(t *testing.T) { recorder.Close() } } + +func TestTxWithCommitSuccess(t *testing.T) { + ctx := context.Background() + + testCases := []testCase{ + {[]TraceOption{WithAllowRootSpan(false)}, 0}, + {[]TraceOption{WithAllowRootSpan(true)}, 3}, + } + + for _, c := range testCases { + db, _, recorder := createDB(t, c.opts...) + + sqlStmt := ` + create table foo (id integer not null primary key, name text); + delete from foo; +` + + _, err := db.ExecContext(ctx, sqlStmt) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + stmt, err := tx.Prepare("insert into foo(id, name) values(?, ?)") + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + defer stmt.Close() + for i := 0; i < 100; i++ { + _, err = stmt.Exec(i, fmt.Sprintf("こんにちわ世界%03d", i)) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + } + tx.Commit() + + spans := recorder.Flush() + if want, have := c.expectedSpans, len(spans); want != have { + t.Fatalf("unexpected number of spans, want: %d, have: %d", want, have) + } + + if c.expectedSpans > 0 { + if want, have := "sql/exec", spans[0].Name; want != have { + t.Fatalf("unexpected first span name, want: %s, have: %s", want, have) + } + if want, have := "sql/begin_transaction", spans[1].Name; want != have { + t.Fatalf("unexpected first span name, want: %s, have: %s", want, have) + } + if want, have := "sql/commit", spans[2].Name; want != have { + t.Fatalf("unexpected second span name, want: %s, have: %s", want, have) + } + } + db.Close() + recorder.Close() + } +} + +func TestTxWithRollbackSuccess(t *testing.T) { + ctx := context.Background() + + testCases := []testCase{ + {[]TraceOption{WithAllowRootSpan(false)}, 0}, + {[]TraceOption{WithAllowRootSpan(true)}, 3}, + } + + for _, c := range testCases { + db, _, recorder := createDB(t, c.opts...) + + sqlStmt := ` + create table foo (id integer not null primary key, name text); + delete from foo; +` + + _, err := db.ExecContext(ctx, sqlStmt) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + stmt, err := tx.Prepare("insert into foo(id, name) values(?, ?)") + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + defer stmt.Close() + for i := 0; i < 100; i++ { + _, err = stmt.Exec(i, fmt.Sprintf("こんにちわ世界%03d", i)) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + } + tx.Rollback() + + spans := recorder.Flush() + if want, have := c.expectedSpans, len(spans); want != have { + t.Fatalf("unexpected number of spans, want: %d, have: %d", want, have) + } + + if c.expectedSpans > 0 { + if want, have := "sql/exec", spans[0].Name; want != have { + t.Fatalf("unexpected first span name, want: %s, have: %s", want, have) + } + if want, have := "sql/begin_transaction", spans[1].Name; want != have { + t.Fatalf("unexpected first span name, want: %s, have: %s", want, have) + } + if want, have := "sql/rollback", spans[2].Name; want != have { + t.Fatalf("unexpected second span name, want: %s, have: %s", want, have) + } + } + db.Close() + recorder.Close() + } +}