Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: trace sharding log #29

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions conn_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sharding
import (
"context"
"database/sql"
"sync"

"gorm.io/gorm"
)
Expand All @@ -12,6 +13,7 @@ type ConnPool struct {
// db, This is global db instance
sharding *Sharding
gorm.ConnPool
settings *sync.Map
}

func (pool *ConnPool) String() string {
Expand All @@ -29,6 +31,7 @@ func (pool ConnPool) ExecContext(ctx context.Context, query string, args ...inte
}

pool.sharding.querys.Store("last_query", stQuery)
pool.settings.Store(ShardingQueryStoreKey, stQuery)

if table != "" {
if r, ok := pool.sharding.configs[table]; ok {
Expand All @@ -49,6 +52,7 @@ func (pool ConnPool) QueryContext(ctx context.Context, query string, args ...int
}

pool.sharding.querys.Store("last_query", stQuery)
pool.settings.Store(ShardingQueryStoreKey, stQuery)

if table != "" {
if r, ok := pool.sharding.configs[table]; ok {
Expand All @@ -64,6 +68,7 @@ func (pool ConnPool) QueryContext(ctx context.Context, query string, args ...int
func (pool ConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
_, query, _, _ = pool.sharding.resolve(query, args...)
pool.sharding.querys.Store("last_query", query)
pool.settings.Store(ShardingQueryStoreKey, query)

return pool.ConnPool.QueryRowContext(ctx, query, args...)
}
Expand Down
19 changes: 18 additions & 1 deletion sharding.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ var (

var (
ShardingIgnoreStoreKey = "sharding_ignore"
ShardingQueryStoreKey = "sharding_query"
)

type Sharding struct {
Expand Down Expand Up @@ -256,18 +257,34 @@ func (s *Sharding) registerCallbacks(db *gorm.DB) {
s.Callback().Delete().Before("*").Register("gorm:sharding", s.switchConn)
s.Callback().Row().Before("*").Register("gorm:sharding", s.switchConn)
s.Callback().Raw().Before("*").Register("gorm:sharding", s.switchConn)

s.Callback().Create().After("*").Register("gorm:sharding:mocksql", s.mockShardingSQL)
s.Callback().Query().After("*").Register("gorm:sharding:mocksql", s.mockShardingSQL)
s.Callback().Update().After("*").Register("gorm:sharding:mocksql", s.mockShardingSQL)
s.Callback().Delete().After("*").Register("gorm:sharding:mocksql", s.mockShardingSQL)
s.Callback().Row().After("*").Register("gorm:sharding:mocksql", s.mockShardingSQL)
s.Callback().Raw().After("*").Register("gorm:sharding:mocksql", s.mockShardingSQL)
}

func (s *Sharding) switchConn(db *gorm.DB) {
// Support ignore sharding in some case, like:
// When DoubleWrite is enabled, we need to query database schema
// information by table name during the migration.
if _, ok := db.Get(ShardingIgnoreStoreKey); !ok {
s.ConnPool = &ConnPool{ConnPool: db.Statement.ConnPool, sharding: s}
s.ConnPool = &ConnPool{ConnPool: db.Statement.ConnPool, sharding: s, settings: &db.Statement.Settings}
db.Statement.ConnPool = s.ConnPool
}
}

// When all callbacks are executed, SQL is only used for log tracing, so there are no side effects.
// stmt.Settings will cloned even in save association.
func (s *Sharding) mockShardingSQL(db *gorm.DB) {
if sql, ok := db.Get(ShardingQueryStoreKey); ok {
db.Statement.SQL.Reset()
db.Statement.SQL.WriteString(sql.(string))
}
}

// resolve split the old query to full table query and sharding table query
func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, tableName string, err error) {
ftQuery = query
Expand Down
26 changes: 25 additions & 1 deletion sharding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"regexp"
"sort"
"strings"
"sync"
"testing"

"github.com/bwmarrin/snowflake"
Expand Down Expand Up @@ -363,11 +364,13 @@ func TestReadWriteSplitting(t *testing.T) {
})
}

// plugin register order
// https://github.com/go-gorm/gorm/pull/5304
db.Use(middleware)
db.Use(dbresolver.Register(dbresolver.Config{
Sources: []gorm.Dialector{dbWrite.Dialector},
Replicas: []gorm.Dialector{dbRead.Dialector},
}))
db.Use(middleware)

var order Order
db.Model(&Order{}).Where("user_id", 100).Find(&order)
Expand All @@ -381,6 +384,27 @@ func TestReadWriteSplitting(t *testing.T) {
assert.Equal(t, "iPhone", order.Product)
}

func TestTraceSqlLog(t *testing.T) {
tx := db.Session(&gorm.Session{NewDB: true})
// we can set mode to view logs
// tx.Logger = tx.Logger.LogMode(logger.Info)

expected := `INSERT INTO orders_0 ("user_id", "product", id) VALUES`

wg := sync.WaitGroup{}
var mockSql string
tx.Callback().Create().After("gorm:sharding:mocksql").Register("gorm:TestTraceSQL", func(d *gorm.DB) {
mockSql = d.Statement.SQL.String()
tx.Callback().Create().Remove("gorm:TestTraceSQL")
wg.Done()
})

wg.Add(1)
tx.Create(&Order{UserID: 1000, Product: "TestTraceSQL"})
wg.Wait()
assert.Equal(t, expected, mockSql[0:len(expected)])
}

func assertQueryResult(t *testing.T, expected string, tx *gorm.DB) {
t.Helper()
assert.Equal(t, toDialect(expected), middleware.LastQuery())
Expand Down