From a4c8e930f1876064289fa49e85293470194b5506 Mon Sep 17 00:00:00 2001 From: a631807682 <631807682@qq.com> Date: Thu, 28 Apr 2022 21:35:58 +0800 Subject: [PATCH 1/3] feat: trace sharding log --- conn_pool.go | 5 +++++ sharding.go | 20 +++++++++++++++++++- sharding_test.go | 22 ++++++++++++++++++++++ 3 files changed, 46 insertions(+), 1 deletion(-) diff --git a/conn_pool.go b/conn_pool.go index 3beaa7e..cc04368 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -3,6 +3,7 @@ package sharding import ( "context" "database/sql" + "sync" "gorm.io/gorm" ) @@ -12,6 +13,7 @@ type ConnPool struct { // db, This is global db instance sharding *Sharding gorm.ConnPool + settings *sync.Map } func (pool *ConnPool) String() string { @@ -29,6 +31,7 @@ func (pool ConnPool) ExecContext(ctx context.Context, query string, args ...inte } pool.sharding.querys.Store("last_query", stQuery) + pool.settings.Store(ShardingQueryStoreKey, stQuery) if table != "" { if r, ok := pool.sharding.configs[table]; ok { @@ -49,6 +52,7 @@ func (pool ConnPool) QueryContext(ctx context.Context, query string, args ...int } pool.sharding.querys.Store("last_query", stQuery) + pool.settings.Store(ShardingQueryStoreKey, stQuery) if table != "" { if r, ok := pool.sharding.configs[table]; ok { @@ -64,6 +68,7 @@ func (pool ConnPool) QueryContext(ctx context.Context, query string, args ...int func (pool ConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { _, query, _, _ = pool.sharding.resolve(query, args...) pool.sharding.querys.Store("last_query", query) + pool.settings.Store(ShardingQueryStoreKey, query) return pool.ConnPool.QueryRowContext(ctx, query, args...) } diff --git a/sharding.go b/sharding.go index cbb8d1e..88864bc 100644 --- a/sharding.go +++ b/sharding.go @@ -18,6 +18,8 @@ var ( ErrInvalidID = errors.New("invalid id format") ) +var ShardingQueryStoreKey = "sharding_query" + type Sharding struct { *gorm.DB ConnPool *ConnPool @@ -252,13 +254,29 @@ func (s *Sharding) registerCallbacks(db *gorm.DB) { s.Callback().Delete().Before("*").Register("gorm:sharding", s.switchConn) s.Callback().Row().Before("*").Register("gorm:sharding", s.switchConn) s.Callback().Raw().Before("*").Register("gorm:sharding", s.switchConn) + + s.Callback().Create().After("*").Register("gorm:sharding:mocksql", s.mockShardingSQL) + s.Callback().Query().After("*").Register("gorm:sharding:mocksql", s.mockShardingSQL) + s.Callback().Update().After("*").Register("gorm:sharding:mocksql", s.mockShardingSQL) + s.Callback().Delete().After("*").Register("gorm:sharding:mocksql", s.mockShardingSQL) + s.Callback().Row().After("*").Register("gorm:sharding:mocksql", s.mockShardingSQL) + s.Callback().Raw().After("*").Register("gorm:sharding:mocksql", s.mockShardingSQL) } func (s *Sharding) switchConn(db *gorm.DB) { - s.ConnPool = &ConnPool{ConnPool: db.Statement.ConnPool, sharding: s} + s.ConnPool = &ConnPool{ConnPool: db.Statement.ConnPool, sharding: s, settings: &db.Statement.Settings} db.Statement.ConnPool = s.ConnPool } +// When all callbacks are executed, SQL is only used for log tracing, so there are no side effects. +// stmt.Settings will cloned even in save association. +func (s *Sharding) mockShardingSQL(db *gorm.DB) { + if sql, ok := db.Get(ShardingQueryStoreKey); ok { + db.Statement.SQL.Reset() + db.Statement.SQL.WriteString(sql.(string)) + } +} + // resolve split the old query to full table query and sharding table query func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, tableName string, err error) { ftQuery = query diff --git a/sharding_test.go b/sharding_test.go index ac0c8dc..c8bb2ed 100644 --- a/sharding_test.go +++ b/sharding_test.go @@ -6,6 +6,7 @@ import ( "regexp" "sort" "strings" + "sync" "testing" "github.com/bwmarrin/snowflake" @@ -377,6 +378,27 @@ func TestReadWriteSplitting(t *testing.T) { assert.Equal(t, "iPhone", order.Product) } +func TestTraceSQL(t *testing.T) { + tx := db.Session(&gorm.Session{NewDB: true}) + // we can set mode to view logs + // tx.Logger = tx.Logger.LogMode(logger.Info) + + expected := `INSERT INTO orders_0 ("user_id", "product", id) VALUES` + + wg := sync.WaitGroup{} + var mockSql string + tx.Callback().Create().After("gorm:sharding:mocksql").Register("gorm:TestTraceSQL", func(d *gorm.DB) { + mockSql = d.Statement.SQL.String() + tx.Callback().Create().Remove("gorm:TestTraceSQL") + wg.Done() + }) + + wg.Add(1) + tx.Create(&Order{UserID: 1000, Product: "TestTraceSQL"}) + wg.Wait() + assert.Equal(t, expected, mockSql[0:len(expected)]) +} + func assertQueryResult(t *testing.T, expected string, tx *gorm.DB) { t.Helper() assert.Equal(t, toDialect(expected), middleware.LastQuery()) From babfe7c7c33c263d690cfcb2f6c838fef89030fc Mon Sep 17 00:00:00 2001 From: a631807682 <631807682@qq.com> Date: Thu, 28 Apr 2022 21:47:14 +0800 Subject: [PATCH 2/3] chore: rename test --- sharding_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sharding_test.go b/sharding_test.go index c8bb2ed..4aad28a 100644 --- a/sharding_test.go +++ b/sharding_test.go @@ -378,7 +378,7 @@ func TestReadWriteSplitting(t *testing.T) { assert.Equal(t, "iPhone", order.Product) } -func TestTraceSQL(t *testing.T) { +func TestTraceSqlLog(t *testing.T) { tx := db.Session(&gorm.Session{NewDB: true}) // we can set mode to view logs // tx.Logger = tx.Logger.LogMode(logger.Info) From 7d1bb316a428415507b752b6305d7382386cb977 Mon Sep 17 00:00:00 2001 From: a631807682 <631807682@qq.com> Date: Sat, 30 Apr 2022 13:33:03 +0800 Subject: [PATCH 3/3] fix: plugin register callback order --- sharding_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sharding_test.go b/sharding_test.go index 4aad28a..02b972b 100644 --- a/sharding_test.go +++ b/sharding_test.go @@ -360,11 +360,13 @@ func TestReadWriteSplitting(t *testing.T) { }) } + // plugin register order + // https://github.com/go-gorm/gorm/pull/5304 + db.Use(middleware) db.Use(dbresolver.Register(dbresolver.Config{ Sources: []gorm.Dialector{dbWrite.Dialector}, Replicas: []gorm.Dialector{dbRead.Dialector}, })) - db.Use(middleware) var order Order db.Model(&Order{}).Where("user_id", 100).Find(&order)