165 lines
3.7 KiB
Go
165 lines
3.7 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"management/internal/pkg/config"
|
|
|
|
"github.com/drhin/logger"
|
|
"gorm.io/driver/postgres"
|
|
"gorm.io/gorm"
|
|
gl "gorm.io/gorm/logger"
|
|
)
|
|
|
|
type txCtxKey struct{}
|
|
|
|
type Repository struct {
|
|
db *gorm.DB
|
|
logger *logger.Logger
|
|
}
|
|
|
|
func NewRepository(db *gorm.DB, logger *logger.Logger) *Repository {
|
|
return &Repository{
|
|
db: db,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
type Transaction interface {
|
|
Transaction(ctx context.Context, fn func(ctx context.Context) error) error
|
|
}
|
|
|
|
func NewTransaction(r *Repository) Transaction {
|
|
return r
|
|
}
|
|
|
|
func (r *Repository) DB(ctx context.Context) *gorm.DB {
|
|
v := ctx.Value(txCtxKey{})
|
|
if v != nil {
|
|
if tx, ok := v.(*gorm.DB); ok {
|
|
return tx
|
|
}
|
|
}
|
|
return r.db.WithContext(ctx)
|
|
}
|
|
|
|
func (r *Repository) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
|
|
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
ctx = context.WithValue(ctx, txCtxKey{}, tx)
|
|
return fn(ctx)
|
|
})
|
|
}
|
|
|
|
func NewDB(config *config.Config, log *logger.Logger) (*gorm.DB, func(), error) {
|
|
dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=disable TimeZone=Asia/Shanghai",
|
|
config.DB.Host,
|
|
config.DB.Username,
|
|
config.DB.Password,
|
|
config.DB.DBName,
|
|
config.DB.Port,
|
|
)
|
|
pgConfig := postgres.Config{
|
|
DSN: dsn,
|
|
PreferSimpleProtocol: true, // disables implicit prepared statement usage
|
|
}
|
|
db, err := gorm.Open(postgres.New(pgConfig), &gorm.Config{
|
|
Logger: getGormLogger(config),
|
|
})
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// db.Debug 会默认显示日志
|
|
//db = db.Debug()
|
|
|
|
// Connection Pool config
|
|
sqlDB, err := db.DB()
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
defer cancel()
|
|
|
|
if err := sqlDB.PingContext(ctx); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// 设置最大空闲连接数(默认 2)
|
|
sqlDB.SetMaxIdleConns(config.DB.MaxIdleConns)
|
|
|
|
// 设置最大打开连接数(默认 0 无限制)
|
|
sqlDB.SetMaxOpenConns(config.DB.MaxOpenConns)
|
|
|
|
// 设置连接最大存活时间
|
|
sqlDB.SetConnMaxLifetime(config.DB.ConnMaxLifetime)
|
|
|
|
// 设置连接最大空闲时间
|
|
sqlDB.SetConnMaxIdleTime(config.DB.ConnMaxIdleTime)
|
|
|
|
cleanup := func() {
|
|
if err := sqlDB.Close(); err != nil {
|
|
log.Error("sql db close error", err)
|
|
}
|
|
}
|
|
|
|
return db, cleanup, nil
|
|
}
|
|
|
|
func getGormLogger(config *config.Config) gl.Interface {
|
|
if config.DB.LogMode {
|
|
return gl.Default.LogMode(gl.Info) // 开发环境显示日志
|
|
}
|
|
return gl.Default.LogMode(gl.Silent)
|
|
}
|
|
|
|
//var (
|
|
// once sync.Once
|
|
// // 全局变量,方便其它包直接调用已初始化好的 datastore 实例.
|
|
// engine *datastore
|
|
//)
|
|
//
|
|
//type Store interface {
|
|
// DB(ctx context.Context) *gorm.DB
|
|
// TX(ctx context.Context, fn func(ctx context.Context) error) error
|
|
//}
|
|
//
|
|
//// transactionKey 用于在 context.Context 中存储事务上下文的键.
|
|
//type transactionKey struct{}
|
|
//
|
|
//// datastore 是 Store 的具体实现.
|
|
//type datastore struct {
|
|
// core *gorm.DB
|
|
//}
|
|
//
|
|
//// NewStore 创建一个 Store 类型的实例.
|
|
//func NewStore(db *gorm.DB) Store {
|
|
// // 确保 engine 只被初始化一次
|
|
// once.Do(func() {
|
|
// engine = &datastore{db}
|
|
// })
|
|
//
|
|
// return engine
|
|
//}
|
|
//
|
|
//func (store *datastore) DB(ctx context.Context) *gorm.DB {
|
|
// db := store.core
|
|
// // 从上下文中提取事务实例
|
|
// if tx, ok := ctx.Value(transactionKey{}).(*gorm.DB); ok {
|
|
// db = tx
|
|
// }
|
|
//
|
|
// return db
|
|
//}
|
|
//
|
|
//func (store *datastore) TX(ctx context.Context, fn func(ctx context.Context) error) error {
|
|
// return store.core.WithContext(ctx).Transaction(
|
|
// func(tx *gorm.DB) error {
|
|
// ctx = context.WithValue(ctx, transactionKey{}, tx)
|
|
// return fn(ctx)
|
|
// },
|
|
// )
|
|
//}
|