2025-06-18 17:50:02 +08:00

57 lines
1004 B
Go

package repository
import (
	"context"
	"fmt"

	"management/internal/pkg/sqldb"

	"github.com/drhin/logger"
	"github.com/jmoiron/sqlx"
)
// core is the database handle that Transaction begins transactions on.
// NewStore assigns it from the first non-nil *sqlx.DB it receives.
var core *sqlx.DB

type (
	// txCtxKey is the context key under which Transaction stores the active
	// transaction. Using an unexported struct type keeps the key from
	// colliding with context keys defined by other packages.
	txCtxKey struct{}

	// Store hands out the executor repository queries should use; see DB.
	Store struct {
		db sqlx.ExtContext // executor used when the context carries no transaction
	}
)
// NewStore returns a Store whose default executor is db.
//
// As a side effect, the first call that passes a non-nil db records it in the
// package-level core handle used by Transaction; later calls leave core as is.
// core is written without any synchronization.
func NewStore(db *sqlx.DB) *Store {
	store := new(Store)
	store.db = db

	// Only the first handle becomes the connection Transaction begins on.
	if core == nil {
		core = db
	}
	return store
}
// DB returns the executor a query made with ctx should run against.
//
// If ctx carries a transaction stored by Transaction, the executor obtained
// from sqldb.GetExtContext for that transaction is returned, so the query
// takes part in it. If there is no such value, or GetExtContext reports an
// error, the store's own executor is returned instead.
//
// NOTE(review): the fallback on a GetExtContext error silently runs the query
// outside the transaction; confirm that is intended.
func (s *Store) DB(ctx context.Context) sqlx.ExtContext {
	tx, inTx := ctx.Value(txCtxKey{}).(sqldb.CommitRollbacker)
	if !inTx {
		return s.db
	}

	ext, err := sqldb.GetExtContext(tx)
	if err != nil {
		return s.db
	}
	return ext
}
// Transaction begins a transaction on the package-level core handle (the
// first non-nil *sqlx.DB passed to NewStore). It stores the transaction in a
// context derived from ctx and runs fn with that context, so Store.DB calls
// made through it use the transaction.
//
// Outcomes:
//   - fn returns nil: the transaction is committed and Commit's error (if
//     any) is returned.
//   - fn returns an error: the transaction is rolled back and fn's error is
//     returned. If the rollback also fails, the rollback error is logged and
//     appended to the message, while fn's error stays reachable through
//     errors.Is / errors.As.
//   - fn panics: the transaction is rolled back and the panic continues to
//     propagate unchanged.
//
// NOTE(review): ctx is not passed to Begin, so cancellation does not affect
// starting the transaction. A transaction already present in ctx is not
// reused: fn always runs in a fresh one.
func Transaction(ctx context.Context, log *logger.Logger, fn func(c context.Context) error) error {
	beginner := sqldb.NewBeginner(core)
	tx, err := beginner.Begin()
	if err != nil {
		log.Error("begin transaction error", err)
		return err
	}

	// settled becomes true once fn has returned normally. After that, the
	// explicit Commit/Rollback below owns the transaction. If fn panics (or
	// calls runtime.Goexit), settled is still false, so the deferred Rollback
	// releases the transaction instead of leaking it. The panic is not
	// recovered, so it keeps unwinding with its original stack.
	settled := false
	defer func() {
		if settled {
			return
		}
		if rbErr := tx.Rollback(); rbErr != nil {
			log.Error("rollback transaction error", rbErr)
		}
	}()

	fnErr := fn(context.WithValue(ctx, txCtxKey{}, tx))
	settled = true

	if fnErr == nil {
		return tx.Commit()
	}
	if rbErr := tx.Rollback(); rbErr != nil {
		log.Error("rollback transaction error", rbErr)
		// Keep fn's error as the wrapped cause. It explains why the
		// transaction was abandoned, which the rollback error alone does not.
		return fmt.Errorf("%w (rollback transaction: %v)", fnErr, rbErr)
	}
	return fnErr
}