2025-06-17 10:50:08 +08:00

165 lines
3.7 KiB
Go

package repository
import (
"context"
"fmt"
"time"
"management/internal/pkg/config"
"github.com/drhin/logger"
"gorm.io/driver/postgres"
"gorm.io/gorm"
gl "gorm.io/gorm/logger"
)
type txCtxKey struct{}
type Repository struct {
db *gorm.DB
logger *logger.Logger
}
func NewRepository(db *gorm.DB, logger *logger.Logger) *Repository {
return &Repository{
db: db,
logger: logger,
}
}
type Transaction interface {
Transaction(ctx context.Context, fn func(ctx context.Context) error) error
}
func NewTransaction(r *Repository) Transaction {
return r
}
func (r *Repository) DB(ctx context.Context) *gorm.DB {
v := ctx.Value(txCtxKey{})
if v != nil {
if tx, ok := v.(*gorm.DB); ok {
return tx
}
}
return r.db.WithContext(ctx)
}
func (r *Repository) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
ctx = context.WithValue(ctx, txCtxKey{}, tx)
return fn(ctx)
})
}
func NewDB(config *config.Config, log *logger.Logger) (*gorm.DB, func(), error) {
dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=disable TimeZone=Asia/Shanghai",
config.DB.Host,
config.DB.Username,
config.DB.Password,
config.DB.DBName,
config.DB.Port,
)
pgConfig := postgres.Config{
DSN: dsn,
PreferSimpleProtocol: true, // disables implicit prepared statement usage
}
db, err := gorm.Open(postgres.New(pgConfig), &gorm.Config{
Logger: getGormLogger(config),
})
if err != nil {
return nil, nil, err
}
// db.Debug 会默认显示日志
//db = db.Debug()
// Connection Pool config
sqlDB, err := db.DB()
if err != nil {
return nil, nil, err
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
if err := sqlDB.PingContext(ctx); err != nil {
return nil, nil, err
}
// 设置最大空闲连接数(默认 2)
sqlDB.SetMaxIdleConns(config.DB.MaxIdleConns)
// 设置最大打开连接数(默认 0 无限制)
sqlDB.SetMaxOpenConns(config.DB.MaxOpenConns)
// 设置连接最大存活时间
sqlDB.SetConnMaxLifetime(config.DB.ConnMaxLifetime)
// 设置连接最大空闲时间
sqlDB.SetConnMaxIdleTime(config.DB.ConnMaxIdleTime)
cleanup := func() {
if err := sqlDB.Close(); err != nil {
log.Error("sql db close error", err)
}
}
return db, cleanup, nil
}
func getGormLogger(config *config.Config) gl.Interface {
if config.DB.LogMode {
return gl.Default.LogMode(gl.Info) // 开发环境显示日志
}
return gl.Default.LogMode(gl.Silent)
}
//var (
// once sync.Once
// // 全局变量,方便其它包直接调用已初始化好的 datastore 实例.
// engine *datastore
//)
//
//type Store interface {
// DB(ctx context.Context) *gorm.DB
// TX(ctx context.Context, fn func(ctx context.Context) error) error
//}
//
//// transactionKey 用于在 context.Context 中存储事务上下文的键.
//type transactionKey struct{}
//
//// datastore 是 Store 的具体实现.
//type datastore struct {
// core *gorm.DB
//}
//
//// NewStore 创建一个 Store 类型的实例.
//func NewStore(db *gorm.DB) Store {
// // 确保 engine 只被初始化一次
// once.Do(func() {
// engine = &datastore{db}
// })
//
// return engine
//}
//
//func (store *datastore) DB(ctx context.Context) *gorm.DB {
// db := store.core
// // 从上下文中提取事务实例
// if tx, ok := ctx.Value(transactionKey{}).(*gorm.DB); ok {
// db = tx
// }
//
// return db
//}
//
//func (store *datastore) TX(ctx context.Context, fn func(ctx context.Context) error) error {
// return store.core.WithContext(ctx).Transaction(
// func(tx *gorm.DB) error {
// ctx = context.WithValue(ctx, transactionKey{}, tx)
// return fn(ctx)
// },
// )
//}