package repository import ( "context" "fmt" "time" "management/internal/pkg/config" "github.com/drhin/logger" "gorm.io/driver/postgres" "gorm.io/gorm" gl "gorm.io/gorm/logger" ) type txCtxKey struct{} type Repository struct { db *gorm.DB logger *logger.Logger } func NewRepository(db *gorm.DB, logger *logger.Logger) *Repository { return &Repository{ db: db, logger: logger, } } type Transaction interface { Transaction(ctx context.Context, fn func(ctx context.Context) error) error } func NewTransaction(r *Repository) Transaction { return r } func (r *Repository) DB(ctx context.Context) *gorm.DB { v := ctx.Value(txCtxKey{}) if v != nil { if tx, ok := v.(*gorm.DB); ok { return tx } } return r.db.WithContext(ctx) } func (r *Repository) Transaction(ctx context.Context, fn func(ctx context.Context) error) error { return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { ctx = context.WithValue(ctx, txCtxKey{}, tx) return fn(ctx) }) } func NewDB(log *logger.Logger, config *config.Config) (*gorm.DB, func(), error) { dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=disable TimeZone=Asia/Shanghai", config.DB.Host, config.DB.Username, config.DB.Password, config.DB.DBName, config.DB.Port, ) pgConfig := postgres.Config{ DSN: dsn, PreferSimpleProtocol: true, // disables implicit prepared statement usage } db, err := gorm.Open(postgres.New(pgConfig), &gorm.Config{ Logger: getGormLogger(config), }) if err != nil { return nil, nil, err } // db.Debug 会默认显示日志 //db = db.Debug() // Connection Pool config sqlDB, err := db.DB() if err != nil { return nil, nil, err } ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() if err := sqlDB.PingContext(ctx); err != nil { return nil, nil, err } // 设置最大空闲连接数(默认 2) sqlDB.SetMaxIdleConns(config.DB.MaxIdleConns) // 设置最大打开连接数(默认 0 无限制) sqlDB.SetMaxOpenConns(config.DB.MaxOpenConns) // 设置连接最大存活时间 sqlDB.SetConnMaxLifetime(config.DB.ConnMaxLifetime) // 设置连接最大空闲时间 sqlDB.SetConnMaxIdleTime(config.DB.ConnMaxIdleTime) cleanup := func() { if err := sqlDB.Close(); err != nil { log.Error("sql db close error", err) } } return db, cleanup, nil } func getGormLogger(config *config.Config) gl.Interface { if config.DB.LogMode { return gl.Default.LogMode(gl.Info) // 开发环境显示日志 } return gl.Default.LogMode(gl.Silent) } //var ( // once sync.Once // // 全局变量,方便其它包直接调用已初始化好的 datastore 实例. // engine *datastore //) // //type Store interface { // DB(ctx context.Context) *gorm.DB // TX(ctx context.Context, fn func(ctx context.Context) error) error //} // //// transactionKey 用于在 context.Context 中存储事务上下文的键. //type transactionKey struct{} // //// datastore 是 Store 的具体实现. //type datastore struct { // core *gorm.DB //} // //// NewStore 创建一个 Store 类型的实例. //func NewStore(db *gorm.DB) Store { // // 确保 engine 只被初始化一次 // once.Do(func() { // engine = &datastore{db} // }) // // return engine //} // //func (store *datastore) DB(ctx context.Context) *gorm.DB { // db := store.core // // 从上下文中提取事务实例 // if tx, ok := ctx.Value(transactionKey{}).(*gorm.DB); ok { // db = tx // } // // return db //} // //func (store *datastore) TX(ctx context.Context, fn func(ctx context.Context) error) error { // return store.core.WithContext(ctx).Transaction( // func(tx *gorm.DB) error { // ctx = context.WithValue(ctx, transactionKey{}, tx) // return fn(ctx) // }, // ) //}