From 6b50981f70b43939e5b8e94b54e34b99149ba39b Mon Sep 17 00:00:00 2001 From: CyJaySong Date: Fri, 17 Jan 2025 13:59:25 +0800 Subject: [PATCH] =?UTF-8?q?feat(database/gdb)=20Begin=E5=BC=80=E5=90=AF?= =?UTF-8?q?=E4=BA=8B=E5=8A=A1=E5=85=81=E8=AE=B8tx.GetCtx()=E7=94=A8?= =?UTF-8?q?=E4=BA=8E=E4=BA=8B=E5=8A=A1=E4=BC=A0=E9=80=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../mysql/mysql_z_unit_transaction_test.go | 41 +++++++++++++++++++ database/gdb/gdb_core_transaction.go | 2 + database/gdb/gdb_core_underlying.go | 7 +++- 3 files changed, 48 insertions(+), 2 deletions(-) diff --git a/contrib/drivers/mysql/mysql_z_unit_transaction_test.go b/contrib/drivers/mysql/mysql_z_unit_transaction_test.go index 89ed7eba7a5..e9c3315b447 100644 --- a/contrib/drivers/mysql/mysql_z_unit_transaction_test.go +++ b/contrib/drivers/mysql/mysql_z_unit_transaction_test.go @@ -1708,3 +1708,44 @@ func Test_Transaction_Isolation(t *testing.T) { t.AssertNil(err) }) } + +func Test_Transaction_Spread(t *testing.T) { + table := createTable() + defer dropTable(table) + + db.SetDebug(true) + defer db.SetDebug(false) + + gtest.C(t, func(t *gtest.T) { + var ( + err error + ctx = context.TODO() + ) + tx, err := db.Begin(ctx) + t.AssertNil(err) + err = db.Transaction(tx.GetCtx(), func(ctx context.Context, tx gdb.TX) error { + _, err = db.Model(table).Ctx(ctx).Data(g.Map{ + "id": 1, + "passport": "USER_1", + "password": "PASS_1", + "nickname": "NAME_1", + "create_time": gtime.Now().String(), + }).Insert() + return err + }) + t.AssertNil(err) + + all, err := tx.Model(table).All() + t.AssertNil(err) + + t.Assert(len(all), 1) + t.Assert(all[0]["id"], 1) + + err = tx.Rollback() + t.AssertNil(err) + + all, err = db.Ctx(ctx).Model(table).All() + t.AssertNil(err) + t.Assert(len(all), 0) + }) +} diff --git a/database/gdb/gdb_core_transaction.go b/database/gdb/gdb_core_transaction.go index 897b179b1d2..1faa73a6205 100644 --- a/database/gdb/gdb_core_transaction.go +++ b/database/gdb/gdb_core_transaction.go @@ -257,12 +257,14 @@ func WithTX(ctx context.Context, tx TX) context.Context { } // Inject transaction object and id into context. ctx = context.WithValue(ctx, transactionKeyForContext(group), tx) + ctx = context.WithValue(ctx, transactionIdForLoggerCtx, tx.GetCtx().Value(transactionIdForLoggerCtx)) return ctx } // WithoutTX removed transaction object from context and returns a new context. func WithoutTX(ctx context.Context, group string) context.Context { ctx = context.WithValue(ctx, transactionKeyForContext(group), nil) + ctx = context.WithValue(ctx, transactionIdForLoggerCtx, nil) return ctx } diff --git a/database/gdb/gdb_core_underlying.go b/database/gdb/gdb_core_underlying.go index 25c60a4baf7..6e06ff0a50f 100644 --- a/database/gdb/gdb_core_underlying.go +++ b/database/gdb/gdb_core_underlying.go @@ -180,14 +180,17 @@ func (c *Core) DoCommit(ctx context.Context, in DoCommitInput) (out DoCommitOutp formattedSql, in.TxOptions.Isolation.String(), in.TxOptions.ReadOnly, ) if sqlTx, err = in.Db.BeginTx(ctx, &in.TxOptions); err == nil { - out.Tx = &TXCore{ + tx := &TXCore{ db: c.db, tx: sqlTx, - ctx: context.WithValue(ctx, transactionIdForLoggerCtx, transactionIdGenerator.Add(1)), + ctx: ctx, master: in.Db, transactionId: guid.S(), cancelFunc: cancelFuncForTimeout, } + tx.ctx = context.WithValue(ctx, transactionKeyForContext(tx.db.GetGroup()), tx) + tx.ctx = context.WithValue(tx.ctx, transactionIdForLoggerCtx, transactionIdGenerator.Add(1)) + out.Tx = tx ctx = out.Tx.GetCtx() } out.RawResult = sqlTx